# chatbot / app.py
# Hosting-page residue kept for provenance: "digifreely's picture" / "Update app.py" / "6801044 verified"
import asyncio
import base64
import hashlib
import hmac
import io
import os
import wave
from urllib.parse import urlsplit

import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from piper.voice import PiperVoice
# ── Secrets (set these in HuggingFace Space → Settings → Variables and secrets) ──
# Value the SHA-256 hex digest of the caller's auth code is compared against.
EXPECTED_HASH = os.getenv("HASH_VALUE")
# URL that /chat_start forwards the incoming payload to.
SERVER_MAIN = os.getenv("SERVER_MAIN_URL")
# Sent to SERVER_MAIN in the "serv_code" request header.
SERV_CODE = os.getenv("SERV_CODE")
# Reserved for Cloudflare token verification if needed (currently unused here).
CF_SECRET_KEY = os.getenv("CF_SECRET_KEY")
# Domain the Referer / Origin headers are matched against in the fallback auth check.
ALLOWED_DOMAIN = os.getenv("ALLOWED_DOMAIN", "buildwithsupratim.github.io")
# ── Piper TTS config ──
# The voice model is looked up under "models/" next to this file.
PIPER_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "en_US-lessac-medium.onnx")

# Load voice once at startup (heavy; reuse across requests)
_piper_voice: PiperVoice = PiperVoice.load(PIPER_MODEL_PATH)

app = FastAPI(title="Maria Middleware", version="1.0.0")

# Derive the CORS origin from ALLOWED_DOMAIN so that overriding the environment
# variable keeps CORS and the Referer/Origin auth check in agreement. With the
# default value this is exactly "https://buildwithsupratim.github.io".
_cors_origin = ALLOWED_DOMAIN if "://" in ALLOWED_DOMAIN else f"https://{ALLOWED_DOMAIN}"

app.add_middleware(
    CORSMiddleware,
    allow_origins=[_cors_origin],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ═══════════════════════════════════════════════════════════════
# AUTH HELPERS
# ═══════════════════════════════════════════════════════════════
def _hash_auth_code(auth_code: str) -> str:
return hashlib.sha256(auth_code.encode()).hexdigest()
def _check_first(request: Request) -> bool:
    """Primary check: the hash of the auth code header must match EXPECTED_HASH.

    The code is read from the ``auth_code`` header, falling back to ``Auth-Code``.

    Returns:
        True when a code was supplied and its SHA-256 hex digest equals
        EXPECTED_HASH; False when the header is missing or empty, when
        EXPECTED_HASH is not configured, or when the digests differ.
    """
    auth_code = request.headers.get("auth_code") or request.headers.get("Auth-Code")
    if not auth_code or not EXPECTED_HASH:
        # Nothing to compare (no code sent, or no hash configured): fail closed.
        return False
    # Compare in constant time so response timing does not reveal how much of
    # the digest matched. Bytes are used because compare_digest rejects
    # non-ASCII str arguments with a TypeError.
    return hmac.compare_digest(
        _hash_auth_code(auth_code).encode(),
        EXPECTED_HASH.encode(),
    )
def _url_host(value: str) -> str:
    """Return the lower-cased host of *value* (a URL or bare host name), or "" if none."""
    value = value.strip()
    if not value:
        return ""
    if "://" not in value:
        # Make urlsplit treat a bare "host[:port][/path]" as a network location.
        value = "//" + value
    try:
        return (urlsplit(value).hostname or "").lower()
    except ValueError:
        # urlsplit rejects some malformed inputs (e.g. a broken IPv6 literal).
        return ""


def _check_second(request: Request) -> bool:
    """Fallback check: request must originate from the allowed domain.

    The ``Referer`` and ``Origin`` headers are parsed as URLs and their host is
    compared with the host of ALLOWED_DOMAIN, so any path under the allowed
    site passes. A plain substring test is deliberately avoided: it would also
    accept values such as ``https://evil.example/?x=<allowed domain>`` or
    ``https://<allowed domain>.evil.example``.

    NOTE: both headers are supplied by the client, so this check only
    constrains well-behaved browsers; it is not a substitute for the
    auth-code check.

    Returns:
        True if either header's host equals the allowed host, otherwise False
        (including when ALLOWED_DOMAIN is empty).
    """
    allowed_host = _url_host(ALLOWED_DOMAIN)
    if not allowed_host:
        return False
    for header_name in ("referer", "origin"):
        if _url_host(request.headers.get(header_name, "")) == allowed_host:
            return True
    return False
async def _authorize(request: Request) -> None:
    """Allow the request if either auth check accepts it; otherwise raise HTTP 403."""
    # The auth-code check runs first; the domain check is only consulted when it fails.
    is_allowed = _check_first(request) or _check_second(request)
    if not is_allowed:
        raise HTTPException(status_code=403, detail="Forbidden: invalid auth_code and domain not allowed.")
# ═══════════════════════════════════════════════════════════════
# TTS HELPER
# ═══════════════════════════════════════════════════════════════
def _generate_tts_base64(text: str) -> str:
    """Render *text* to WAV with the module-level Piper voice and return it base64-encoded.

    The audio is written to an in-memory buffer; nothing touches the disk.
    """
    audio_buffer = io.BytesIO()
    # Leaving the ``with`` block closes the WAV writer before the bytes are read back.
    with wave.open(audio_buffer, "wb") as wav_writer:
        # NOTE(review): newer piper-tts releases appear to expose WAV output as
        # ``synthesize_wav`` -- confirm this call against the installed version.
        _piper_voice.synthesize(text, wav_writer)
    wav_bytes = audio_buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("utf-8")
# ═══════════════════════════════════════════════════════════════
# ROUTES
# ═══════════════════════════════════════════════════════════════
@app.get("/")
async def root():
    """Answer requests to the site root with a liveness payload (HuggingFace health checks)."""
    liveness = {"status": "alive"}
    return liveness
@app.get("/ping")
async def ping():
    """Health-check endpoint; intended to be called to wake the Space if it was sleeping."""
    liveness = {"status": "alive"}
    return liveness
@app.post("/chat_start")
async def chat_start(request: Request):
    """Proxy a chat request to SERVER_MAIN and attach Piper TTS audio to the reply.

    1. Authenticate the caller.
    2. Forward the payload to SERVER_MAIN with serv_code in headers.
    3. Override / add ``query.response_message.audio_output`` with base64 WAV
       audio synthesized from ``query.response_message.text`` (or "" when the
       text is empty).
    4. Return the final response.

    Raises:
        HTTPException: 403 when the caller is not authorized; 400 when the
            request body is not valid JSON; 500 when SERVER_MAIN_URL or
            SERV_CODE is not configured; the downstream status code when
            SERVER_MAIN answers with an HTTP error; 502 when SERVER_MAIN is
            unreachable, returns a non-JSON body, or returns an unexpected
            schema.
    """
    await _authorize(request)

    # ── Parse incoming JSON ──────────────────────────────────────
    try:
        payload = await request.json()
    except Exception as exc:
        # Any failure to read/parse the body is reported as a client error.
        raise HTTPException(status_code=400, detail="Invalid JSON body.") from exc

    # ── Forward to SERVER_MAIN ───────────────────────────────────
    if not SERVER_MAIN or SERV_CODE is None:
        # httpx would otherwise fail with an opaque TypeError on a None URL/header.
        raise HTTPException(
            status_code=500,
            detail="Server misconfigured: SERVER_MAIN_URL or SERV_CODE is not set.",
        )

    forward_headers = {
        "Content-Type": "application/json",
        "serv_code": SERV_CODE,
    }

    async with httpx.AsyncClient(timeout=600.0) as client:
        try:
            server_response = await client.post(
                SERVER_MAIN,
                json=payload,
                headers=forward_headers,
            )
            server_response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Extract as much info as possible from the proxy/server
            error_details = {
                "status_code": exc.response.status_code,
                "reason_phrase": exc.response.reason_phrase,
                "url": str(exc.request.url),
                "response_body": exc.response.text[:500],  # first 500 chars of the error page
                "headers": dict(exc.response.headers),  # e.g. shows whether 'server' is 'uvicorn' or 'nginx'
            }
            print(f"DEBUG 502: {error_details}")  # This will show in Hugging Face Logs
            raise HTTPException(
                status_code=exc.response.status_code,
                detail=f"Downstream Error: {error_details['response_body']}",
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(status_code=502, detail=f"SERVER_MAIN unreachable: {str(exc)}") from exc

    # A 2xx reply with a non-JSON body is a downstream fault, not an internal error.
    try:
        response_data = server_response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="SERVER_MAIN returned a non-JSON response.") from exc

    # ── Extract text for TTS ─────────────────────────────────────
    try:
        response_message = response_data["query"]["response_message"]
        tts_text = response_message.get("text", "")
    except (KeyError, TypeError, AttributeError) as exc:
        # AttributeError covers a response_message that is not a JSON object.
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.") from exc

    # ── Generate TTS and override audio_output ───────────────────
    if tts_text:
        # Run blocking Piper synthesis in a thread pool to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        audio_b64 = await loop.run_in_executor(None, _generate_tts_base64, tts_text)
        response_message["audio_output"] = audio_b64
    else:
        # No text → clear audio_output to avoid stale base64
        response_message["audio_output"] = ""

    return JSONResponse(content=response_data)
# ═══════════════════════════════════════════════════════════════
# ENTRY POINT
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "reload": False}
    uvicorn.run("app:app", **server_options)