Spaces:
Sleeping
Sleeping
| import os | |
| import hashlib | |
| import base64 | |
| import asyncio | |
| import wave | |
| import io | |
| import httpx | |
| from piper.voice import PiperVoice | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
# ── Secrets (set these in HuggingFace Space → Settings → Variables and secrets) ──
# Every value is read from the environment once, at import time. A variable
# that is not set yields None, except ALLOWED_DOMAIN which has a default.
EXPECTED_HASH = os.getenv("HASH_VALUE")
SERVER_MAIN = os.getenv("SERVER_MAIN_URL")
SERV_CODE = os.getenv("SERV_CODE")
# Reserved for Cloudflare token verification if needed.
CF_SECRET_KEY = os.getenv("CF_SECRET_KEY")
ALLOWED_DOMAIN = os.getenv("ALLOWED_DOMAIN", "buildwithsupratim.github.io")
# ── Piper TTS config ──
# The voice model is looked up in a "models" directory next to this file.
_MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
PIPER_MODEL_PATH = os.path.join(_MODELS_DIR, "en_US-lessac-medium.onnx")

# Load voice once at startup (heavy; reuse across requests)
_piper_voice: PiperVoice = PiperVoice.load(PIPER_MODEL_PATH)
# ASGI application object; route handlers and middleware attach to it.
app = FastAPI(title="Maria Middleware", version="1.0.0")

# The CORS origin is derived from ALLOWED_DOMAIN so that it always agrees with
# the Referer/Origin fallback check in _check_second. With the default value
# this is exactly "https://buildwithsupratim.github.io".
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"https://{ALLOWED_DOMAIN}"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ───────────────────────────────────────────────────────────────
# AUTH HELPERS
# ───────────────────────────────────────────────────────────────
| def _hash_auth_code(auth_code: str) -> str: | |
| return hashlib.sha256(auth_code.encode()).hexdigest() | |
def _check_first(request: Request) -> bool:
    """Primary check: the hashed auth-code header must match EXPECTED_HASH.

    The code is read from the ``auth_code`` header, falling back to
    ``Auth-Code``. Its SHA-256 hex digest is compared against EXPECTED_HASH.

    Args:
        request: Incoming request whose headers carry the auth code.

    Returns:
        True only when a non-empty auth code is present, EXPECTED_HASH is
        configured, and the digests are equal. False otherwise.
    """
    # Local import keeps this block self-contained; hmac is standard library.
    import hmac

    auth_code = request.headers.get("auth_code") or request.headers.get("Auth-Code")
    # Fail closed when the header is missing or the secret was never configured.
    if not auth_code or not EXPECTED_HASH:
        return False

    # Constant-time comparison so response timing does not reveal how many
    # leading characters of the digest matched. Bytes are compared because
    # compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(
        _hash_auth_code(auth_code).encode("utf-8"),
        EXPECTED_HASH.encode("utf-8"),
    )
def _check_second(request: Request) -> bool:
    """Fallback check: the request must originate from the allowed domain.

    The hostname of the ``Referer`` and ``Origin`` headers is compared
    against ALLOWED_DOMAIN. Only the hostname is compared, so any path under
    the allowed site passes, while URLs that merely *contain* the domain
    (in a query string, a path, or as a prefix of a longer hostname) do not.

    NOTE(review): both headers are supplied by the client and can be forged
    by non-browser callers, so this is a weak check by design.

    Args:
        request: Incoming request whose headers are inspected.

    Returns:
        True when either header's hostname equals the allowed hostname
        (case-insensitive). False otherwise, including when ALLOWED_DOMAIN
        is empty.
    """
    # Local import keeps this block self-contained; urllib is standard library.
    from urllib.parse import urlsplit

    def _hostname(value: str) -> str:
        """Return the lowercase hostname of a URL or bare host, or ''."""
        candidate = value.strip()
        # urlsplit only recognises a host after a "//" authority marker, so a
        # bare "example.com" needs one added.
        if "//" not in candidate:
            candidate = "//" + candidate
        try:
            return (urlsplit(candidate).hostname or "").lower()
        except ValueError:
            # Malformed authority (e.g. unbalanced IPv6 brackets).
            return ""

    allowed_host = _hostname(ALLOWED_DOMAIN or "")
    # An empty allow-list entry must never match anything.
    if not allowed_host:
        return False

    candidates = (
        request.headers.get("referer", ""),
        request.headers.get("origin", ""),
    )
    return any(value and _hostname(value) == allowed_host for value in candidates)
async def _authorize(request: Request) -> None:
    """Reject the request with HTTP 403 unless one of the two checks accepts it.

    Raises:
        HTTPException: status 403 when both _check_first and _check_second
            return a falsy result.
    """
    # `or` short-circuits: the domain fallback only runs when the auth-code
    # check did not pass.
    is_allowed = _check_first(request) or _check_second(request)
    if not is_allowed:
        raise HTTPException(status_code=403, detail="Forbidden: invalid auth_code and domain not allowed.")
# ───────────────────────────────────────────────────────────────
# TTS HELPER
# ───────────────────────────────────────────────────────────────
def _generate_tts_base64(text: str) -> str:
    """Synthesize *text* with the shared Piper voice and return it as base64.

    The audio is written through a ``wave`` writer into an in-memory buffer;
    the buffer's full contents are then base64-encoded.

    Args:
        text: Text handed to ``_piper_voice.synthesize``.

    Returns:
        The base64 encoding of the WAV bytes, as a UTF-8 decoded string.
    """
    audio_buffer = io.BytesIO()
    with wave.open(audio_buffer, "wb") as writer:
        _piper_voice.synthesize(text, writer)
    # Read the buffer only after the writer has been closed by the `with`.
    raw_wav = audio_buffer.getvalue()
    encoded = base64.b64encode(raw_wav)
    return encoded.decode("utf-8")
# ───────────────────────────────────────────────────────────────
# ROUTES
# ───────────────────────────────────────────────────────────────
async def root():
    """Root endpoint for HuggingFace health checks.

    Returns:
        A dict with the single key ``"status"`` set to ``"alive"``.
    """
    # NOTE(review): no route decorator is visible above this handler in this
    # view of the file — confirm it is registered on `app`.
    alive = {"status": "alive"}
    return alive
async def ping():
    """Health-check endpoint. Wakes the Space if it was sleeping.

    Returns:
        A dict with the single key ``"status"`` set to ``"alive"``.
    """
    # NOTE(review): no route decorator is visible above this handler in this
    # view of the file — confirm it is registered on `app`.
    return dict(status="alive")
async def chat_start(request: Request):
    """Proxy a chat request to SERVER_MAIN and attach Piper TTS audio.

    Steps:
        1. Authenticate the caller.
        2. Forward the payload to SERVER_MAIN with serv_code in headers.
        3. Override / add audio_output with Piper TTS base64 WAV audio.
        4. Return the final response.

    NOTE(review): no route decorator is visible above this handler in this
    view of the file — confirm it is registered on `app`.

    Args:
        request: Incoming request; its JSON body is forwarded unchanged.

    Returns:
        JSONResponse carrying SERVER_MAIN's JSON with
        ``query.response_message.audio_output`` replaced.

    Raises:
        HTTPException: 403 when authorization fails; 400 when the body is not
            valid JSON; 500 when SERVER_MAIN / SERV_CODE are not configured;
            the downstream status code when SERVER_MAIN answers with an HTTP
            error; 502 when SERVER_MAIN is unreachable, returns a non-JSON
            body, or returns an unexpected schema.
    """
    await _authorize(request)

    # Report missing configuration explicitly instead of letting httpx raise
    # an opaque TypeError for a None URL or None header value.
    if not SERVER_MAIN or SERV_CODE is None:
        raise HTTPException(
            status_code=500,
            detail="Server misconfigured: SERVER_MAIN_URL or SERV_CODE is not set.",
        )

    # ── Parse incoming JSON ──
    try:
        payload = await request.json()
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON body.") from exc

    # ── Forward to SERVER_MAIN ──
    forward_headers = {
        "Content-Type": "application/json",
        "serv_code": SERV_CODE,
    }
    async with httpx.AsyncClient(timeout=600.0) as client:
        try:
            server_response = await client.post(
                SERVER_MAIN,
                json=payload,
                headers=forward_headers,
            )
            server_response.raise_for_status()
            response_data = server_response.json()
        except httpx.HTTPStatusError as exc:
            # Extract as much info as possible from the proxy/server
            error_details = {
                "status_code": exc.response.status_code,
                "reason_phrase": exc.response.reason_phrase,
                "url": str(exc.request.url),
                "response_body": exc.response.text[:500],  # first 500 chars of the error page
                "headers": dict(exc.response.headers),  # shows which server produced the error
            }
            print(f"DEBUG 502: {error_details}")  # This will show in Hugging Face Logs
            raise HTTPException(
                status_code=exc.response.status_code,
                detail=f"Downstream Error: {error_details['response_body']}",
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(status_code=502, detail=f"SERVER_MAIN unreachable: {str(exc)}") from exc
        except ValueError as exc:
            # A successful status with a body that is not JSON
            # (json.JSONDecodeError is a ValueError subclass).
            raise HTTPException(status_code=502, detail="SERVER_MAIN returned a non-JSON response.") from exc

    # ── Extract text for TTS ──
    try:
        response_message = response_data["query"]["response_message"]
        tts_text = response_message.get("text", "")
    except (KeyError, TypeError, AttributeError) as exc:
        # AttributeError covers a response_message that is not a mapping
        # (for example null, a list, or a string).
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.") from exc

    # ── Generate TTS and override audio_output ──
    if tts_text:
        # Run blocking Piper synthesis in a thread pool to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        audio_b64 = await loop.run_in_executor(None, _generate_tts_base64, tts_text)
        response_message["audio_output"] = audio_b64
    else:
        # No text → clear audio_output to avoid stale base64
        response_message["audio_output"] = ""

    return JSONResponse(content=response_data)
# ───────────────────────────────────────────────────────────────
# ENTRY POINT
# ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Imported here so uvicorn is only required when this file is run as a script.
    import uvicorn

    # Serve the "app:app" import target on all interfaces, port 7860, without auto-reload.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )