"""Maria Middleware: authenticated proxy in front of SERVER_MAIN that adds Piper TTS audio."""

import os
import hashlib
import base64
import asyncio
import hmac
import logging
import wave
import io
from urllib.parse import urlsplit

import httpx
from piper.voice import PiperVoice
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

logger = logging.getLogger(__name__)

# ── Secrets (set these in HuggingFace Space → Settings → Variables and secrets) ──
EXPECTED_HASH = os.environ.get("HASH_VALUE")
SERVER_MAIN = os.environ.get("SERVER_MAIN_URL")
SERV_CODE = os.environ.get("SERV_CODE")
CF_SECRET_KEY = os.environ.get("CF_SECRET_KEY")  # reserved for Cloudflare token verification if needed
# Normalised so it can be compared against parsed (lower-case) hostnames.
ALLOWED_DOMAIN = os.environ.get("ALLOWED_DOMAIN", "buildwithsupratim.github.io").strip().lower()

# ── Piper TTS config ──
PIPER_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "en_US-lessac-medium.onnx")

# Load voice once at import time (heavy; reused across requests).
_piper_voice: PiperVoice = PiperVoice.load(PIPER_MODEL_PATH)

app = FastAPI(title="Maria Middleware", version="1.0.0")

# The CORS origin is derived from ALLOWED_DOMAIN so the browser-side policy and
# the server-side Referer/Origin check (_check_second) cannot drift apart.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"https://{ALLOWED_DOMAIN}"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


# ═══════════════════════════════════════════════════════════════
# AUTH HELPERS
# ═══════════════════════════════════════════════════════════════

def _hash_auth_code(auth_code: str) -> str:
    """Return the hex SHA-256 digest of *auth_code* (UTF-8 encoded)."""
    return hashlib.sha256(auth_code.encode()).hexdigest()


def _check_first(request: Request) -> bool:
    """Primary check: SHA-256 of the ``auth_code`` header must equal EXPECTED_HASH.

    Returns False when no hash is configured or the header is absent/empty.
    """
    if not EXPECTED_HASH:
        return False  # nothing configured → this check can never pass
    auth_code = request.headers.get("auth_code") or request.headers.get("Auth-Code")
    if not auth_code:
        return False
    # Constant-time comparison so response timing does not reveal how much of the hash matched.
    return hmac.compare_digest(_hash_auth_code(auth_code).encode(), EXPECTED_HASH.encode())


def _hostname(url: str) -> str:
    """Return the lower-cased hostname of *url*, or "" if it has none or is malformed."""
    try:
        return urlsplit(url).hostname or ""
    except ValueError:  # e.g. an unbalanced "[" in an IPv6 literal
        return ""


def _check_second(request: Request) -> bool:
    """Fallback check: the Referer or Origin header must point at ALLOWED_DOMAIN.

    The hostname is parsed from each header and must equal ALLOWED_DOMAIN (or be
    a subdomain of it), so any path on the allowed site passes while look-alikes
    such as ``https://buildwithsupratim.github.io.evil.com`` or
    ``https://evil.com/?buildwithsupratim.github.io`` are rejected.

    NOTE: Referer/Origin are supplied by the client; non-browser clients can
    forge them, so this check only filters ordinary browser traffic.
    """
    if not ALLOWED_DOMAIN:
        return False  # an empty domain would otherwise "match" a missing hostname
    for header in ("referer", "origin"):
        host = _hostname(request.headers.get(header, ""))
        if host and (host == ALLOWED_DOMAIN or host.endswith("." + ALLOWED_DOMAIN)):
            return True
    return False


async def _authorize(request: Request) -> None:
    """Raise HTTPException(403) unless either the hash check or the domain check passes."""
    if _check_first(request):
        return
    if _check_second(request):
        return
    raise HTTPException(status_code=403, detail="Forbidden: invalid auth_code and domain not allowed.")


# ═══════════════════════════════════════════════════════════════
# TTS HELPER
# ═══════════════════════════════════════════════════════════════

def _generate_tts_base64(text: str) -> str:
    """Synthesize *text* with Piper TTS and return the WAV file bytes base64-encoded as str.

    Blocking/CPU-bound; call it from a worker thread when inside the event loop.
    """
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        _piper_voice.synthesize(text, wav_file)
    return base64.b64encode(wav_buffer.getvalue()).decode("utf-8")


# ═══════════════════════════════════════════════════════════════
# ROUTES
# ═══════════════════════════════════════════════════════════════

@app.get("/")
async def root():
    """Root endpoint for HuggingFace health checks."""
    return {"status": "alive"}


@app.get("/ping")
async def ping():
    """Health-check endpoint. Wakes the Space if it was sleeping."""
    return {"status": "alive"}


async def _forward_to_server_main(payload: object) -> object:
    """POST *payload* to SERVER_MAIN (with the serv_code header) and return its decoded JSON.

    Raises:
        HTTPException: 500 if SERVER_MAIN_URL / SERV_CODE are not configured;
            the downstream status code if SERVER_MAIN answers 4xx/5xx (502 for
            any other non-2xx status); 502 if it is unreachable or its body is
            not valid JSON.
    """
    if not SERVER_MAIN or SERV_CODE is None:
        raise HTTPException(
            status_code=500,
            detail="Server misconfigured: SERVER_MAIN_URL / SERV_CODE not set.",
        )

    forward_headers = {
        "Content-Type": "application/json",
        "serv_code": SERV_CODE,
    }

    async with httpx.AsyncClient(timeout=600.0) as client:
        try:
            server_response = await client.post(SERVER_MAIN, json=payload, headers=forward_headers)
            server_response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            status = exc.response.status_code
            body_snippet = exc.response.text[:500]  # first 500 chars of the error page
            # The 'server' header helps tell whether the error came from the app
            # (e.g. uvicorn) or a proxy in front of it (e.g. nginx).
            logger.warning(
                "SERVER_MAIN error: status=%s reason=%s url=%s server=%s body=%r",
                status,
                exc.response.reason_phrase,
                exc.request.url,
                exc.response.headers.get("server"),
                body_snippet,
            )
            # raise_for_status() also raises for 1xx/3xx responses; relaying those
            # codes as an error response is meaningless, so report them as 502.
            raise HTTPException(
                status_code=status if status >= 400 else 502,
                detail=f"Downstream Error: {body_snippet}",
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(status_code=502, detail=f"SERVER_MAIN unreachable: {str(exc)}") from exc

    try:
        return server_response.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=502, detail="SERVER_MAIN returned a non-JSON response.") from None


async def _attach_tts_audio(response_data: object) -> None:
    """Set response_data["query"]["response_message"]["audio_output"] in place.

    The value is Piper TTS audio (base64 WAV) for response_message["text"], or ""
    when there is no non-empty text string.

    Raises:
        HTTPException: 502 if response_data has no dict at query.response_message.
    """
    query = response_data.get("query") if isinstance(response_data, dict) else None
    response_message = query.get("response_message") if isinstance(query, dict) else None
    if not isinstance(response_message, dict):
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.")

    tts_text = response_message.get("text", "")
    if isinstance(tts_text, str) and tts_text:
        # Run blocking Piper synthesis in the default thread pool to keep the event loop free.
        loop = asyncio.get_running_loop()
        response_message["audio_output"] = await loop.run_in_executor(
            None, _generate_tts_base64, tts_text
        )
    else:
        # No text → clear audio_output so stale base64 from SERVER_MAIN is not returned.
        response_message["audio_output"] = ""


@app.post("/chat_start")
async def chat_start(request: Request):
    """
    1. Authenticate the caller (403 on failure).
    2. Forward the JSON payload to SERVER_MAIN with serv_code in headers.
    3. Override / add audio_output with Piper TTS base64 WAV audio.
    4. Return the final response.
    """
    await _authorize(request)

    # ── Parse incoming JSON ──────────────────────────────────────
    try:
        payload = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid JSON body.") from None

    response_data = await _forward_to_server_main(payload)
    await _attach_tts_audio(response_data)
    return JSONResponse(content=response_data)


# ═══════════════════════════════════════════════════════════════
# ENTRY POINT
# ═══════════════════════════════════════════════════════════════

if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string: the string form
    # re-imports this file as a second module named "app", which would load the
    # Piper model a second time.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)