Spaces:
Sleeping
Sleeping
File size: 7,848 Bytes
b9cacf3 9664f83 b9cacf3 6801044 b9cacf3 6801044 b9cacf3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | import os
import hashlib
import hmac
import base64
import asyncio
import wave
import io
from urllib.parse import urlsplit

import httpx
from piper.voice import PiperVoice
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# ── Secrets (set these in HuggingFace Space → Settings → Variables and secrets) ──
# Every lookup yields None when the variable is unset, except ALLOWED_DOMAIN,
# which falls back to the public GitHub Pages host.
EXPECTED_HASH = os.getenv("HASH_VALUE")  # compared against the SHA-256 hex digest of the client's auth_code
SERVER_MAIN = os.getenv("SERVER_MAIN_URL")  # upstream URL that /chat_start forwards payloads to
SERV_CODE = os.getenv("SERV_CODE")  # sent upstream in the "serv_code" header
CF_SECRET_KEY = os.getenv("CF_SECRET_KEY")  # reserved for Cloudflare token verification if needed
ALLOWED_DOMAIN = os.getenv("ALLOWED_DOMAIN", "buildwithsupratim.github.io")
# ── Piper TTS config ──
_PIPER_MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
PIPER_MODEL_PATH = os.path.join(_PIPER_MODELS_DIR, "en_US-lessac-medium.onnx")
# Loading the voice is heavy, so it happens once at import time; the single
# instance is then reused by every TTS call.
_piper_voice: PiperVoice = PiperVoice.load(PIPER_MODEL_PATH)
app = FastAPI(title="Maria Middleware", version="1.0.0")

# Allow browser calls only from the configured front-end origin. The origin is
# derived from ALLOWED_DOMAIN so that the CORS policy and the Referer/Origin
# check in _check_second cannot drift apart. With the default ALLOWED_DOMAIN,
# the origin is exactly the previously hard-coded one.
# NOTE(review): assumes ALLOWED_DOMAIN is a bare host (no scheme/path).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[f"https://{ALLOWED_DOMAIN}"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
# ───────────────────────────────────────────────────────────────
# AUTH HELPERS
# ───────────────────────────────────────────────────────────────
def _hash_auth_code(auth_code: str) -> str:
    """Return the lowercase hex SHA-256 digest of *auth_code* (UTF-8 encoded)."""
    hasher = hashlib.sha256()
    hasher.update(auth_code.encode("utf-8"))
    return hasher.hexdigest()
def _check_first(request: Request) -> bool:
    """Primary check: SHA-256 of the ``auth_code`` header must equal EXPECTED_HASH.

    Both the ``auth_code`` and ``Auth-Code`` header spellings are accepted.

    Returns:
        True only when a non-empty auth code is present and its hex digest
        matches EXPECTED_HASH. Returns False when the header is missing or
        when HASH_VALUE is not configured, so an unset secret never authorizes.
    """
    if not EXPECTED_HASH:
        return False
    auth_code = request.headers.get("auth_code") or request.headers.get("Auth-Code")
    if not auth_code:
        return False
    # Compare as bytes in constant time. This is the standard practice for
    # secret comparisons, and it also avoids compare_digest's TypeError on
    # non-ASCII str input.
    return hmac.compare_digest(
        _hash_auth_code(auth_code).encode("utf-8"),
        EXPECTED_HASH.encode("utf-8"),
    )
def _check_second(request: Request) -> bool:
    """Fallback check: the Referer or Origin header must point at ALLOWED_DOMAIN.

    The host of each header value is parsed and compared exactly
    (case-insensitively) with ALLOWED_DOMAIN. Any path under the allowed site
    therefore passes. Look-alike hosts such as
    ``buildwithsupratim.github.io.evil.example`` are rejected, as are URLs
    that merely contain the domain somewhere in their path or query.

    Note: Referer/Origin are sent by the client. Browsers set them honestly,
    but non-browser clients can forge them, so this only limits casual
    cross-site use.
    """
    allowed_host = ALLOWED_DOMAIN.lower()
    for header_name in ("referer", "origin"):
        value = request.headers.get(header_name, "")
        if not value:
            continue
        try:
            host = urlsplit(value).hostname  # already lowercased; None if absent
        except ValueError:
            # Malformed URL (e.g. unbalanced IPv6 brackets): treat as no match.
            continue
        if host and host == allowed_host:
            return True
    return False
async def _authorize(request: Request) -> None:
    """Raise HTTP 403 unless the request passes at least one auth check.

    The auth-code hash check runs first. The Referer/Origin domain check is
    consulted only when the hash check fails.
    """
    is_allowed = _check_first(request) or _check_second(request)
    if not is_allowed:
        raise HTTPException(status_code=403, detail="Forbidden: invalid auth_code and domain not allowed.")
# ───────────────────────────────────────────────────────────────
# TTS HELPER
# ───────────────────────────────────────────────────────────────
def _generate_tts_base64(text: str) -> str:
    """Synthesize *text* with the shared Piper voice; return the WAV file as base64.

    The WAV container is built entirely in memory. The result is the base64
    text of the complete WAV file (header + audio frames).
    """
    in_memory_wav = io.BytesIO()
    # Leaving the ``with`` closes the wave writer, which finalizes the WAV
    # header. It does not close a file object it did not open, so the BytesIO
    # is still readable afterwards.
    with wave.open(in_memory_wav, "wb") as wav_writer:
        # NOTE(review): assumes the installed piper-tts version's
        # synthesize(text, wave_writer) sets the WAV params and writes the
        # frames itself. Confirm on upgrade.
        _piper_voice.synthesize(text, wav_writer)
    wav_bytes = in_memory_wav.getvalue()
    return base64.b64encode(wav_bytes).decode("utf-8")
# ───────────────────────────────────────────────────────────────
# ROUTES
# ───────────────────────────────────────────────────────────────
@app.get("/")
async def root():
    """Report liveness at the root path, which HuggingFace health checks hit."""
    status_payload = {"status": "alive"}
    return status_payload
@app.get("/ping")
async def ping():
    """Lightweight health check; calling it also wakes the Space if it was asleep."""
    return dict(status="alive")
async def _forward_to_server_main(payload):
    """POST *payload* as JSON to SERVER_MAIN and return its decoded JSON reply.

    SERV_CODE is sent in the ``serv_code`` header. A fresh
    ``httpx.AsyncClient`` with a 600 s timeout is used for each call.

    Raises:
        HTTPException:
            - 500 if SERVER_MAIN_URL or SERV_CODE is not configured.
            - The downstream status code if SERVER_MAIN answers with a
              non-success status.
            - 502 if SERVER_MAIN is unreachable or its body is not valid JSON.
    """
    # Fail with a clear message instead of an opaque httpx TypeError/URL error
    # when the Space secrets are missing.
    if not SERVER_MAIN or SERV_CODE is None:
        raise HTTPException(
            status_code=500,
            detail="Server misconfigured: SERVER_MAIN_URL and SERV_CODE must be set.",
        )

    forward_headers = {
        "Content-Type": "application/json",
        "serv_code": SERV_CODE,
    }
    async with httpx.AsyncClient(timeout=600.0) as client:
        try:
            server_response = await client.post(
                SERVER_MAIN,
                json=payload,
                headers=forward_headers,
            )
            server_response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Extract as much info as possible from the proxy/server for the logs.
            error_details = {
                "status_code": exc.response.status_code,
                "reason_phrase": exc.response.reason_phrase,
                "url": str(exc.request.url),
                "response_body": exc.response.text[:500],  # first 500 chars of the error page
                "headers": dict(exc.response.headers),  # e.g. shows whether 'server' is uvicorn or nginx
            }
            print(f"DEBUG 502: {error_details}")  # This will show in Hugging Face Logs
            raise HTTPException(
                status_code=exc.response.status_code,
                detail=f"Downstream Error: {error_details['response_body']}",
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(status_code=502, detail=f"SERVER_MAIN unreachable: {str(exc)}") from exc

    # client.post read the whole body, so decoding after the client is closed
    # is safe. A non-JSON 2xx reply is a bad gateway, not an unhandled 500.
    try:
        return server_response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="SERVER_MAIN returned a non-JSON response.") from exc


def _get_response_message(response_data) -> dict:
    """Return ``response_data["query"]["response_message"]`` from SERVER_MAIN's reply.

    Raises:
        HTTPException: 502 if the path is missing or the value is not a JSON object.
    """
    try:
        response_message = response_data["query"]["response_message"]
    except (KeyError, TypeError):
        response_message = None
    # The caller reads "text" from this value and writes "audio_output" into
    # it, so it must be a dict. A list/str/null used to escape as an uncaught
    # AttributeError (HTTP 500).
    if not isinstance(response_message, dict):
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.")
    return response_message


@app.post("/chat_start")
async def chat_start(request: Request):
    """Authenticate, proxy the chat payload to SERVER_MAIN, and attach Piper TTS audio.

    1. Authenticate the caller (``_authorize`` raises 403 on failure).
    2. Forward the JSON payload to SERVER_MAIN with serv_code in the headers.
    3. Set ``query.response_message.audio_output`` to base64 WAV audio
       synthesized from ``query.response_message.text``, or to ``""`` when
       there is no text.
    4. Return the modified downstream JSON.

    Raises:
        HTTPException:
            - 400 for a non-JSON request body.
            - 500 for missing server configuration.
            - The downstream status for a non-success SERVER_MAIN reply.
            - 502 when SERVER_MAIN is unreachable, returns non-JSON, or
              returns an unexpected schema.
    """
    await _authorize(request)

    # ── Parse incoming JSON ──
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body.")

    # ── Forward to SERVER_MAIN ──
    response_data = await _forward_to_server_main(payload)

    # ── Extract text for TTS ──
    response_message = _get_response_message(response_data)
    # A null "text" is treated like a missing one: no audio is produced.
    tts_text = response_message.get("text") or ""
    if not isinstance(tts_text, str):
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.")

    # ── Generate TTS and override audio_output ──
    if tts_text:
        # Piper synthesis is blocking. Run it in the default thread pool so the
        # event loop keeps serving other requests. get_running_loop() is the
        # supported way to get the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        response_message["audio_output"] = await loop.run_in_executor(
            None, _generate_tts_base64, tts_text
        )
    else:
        # No text: clear audio_output so any value returned by SERVER_MAIN is
        # not passed through.
        response_message["audio_output"] = ""

    return JSONResponse(content=response_data)
# ───────────────────────────────────────────────────────────────
# ENTRY POINT
# ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. With the
    # string, uvicorn imports this file a second time under the module name
    # "app" (this run is "__main__"), which loads the Piper voice model twice.
    # The string form also only works when the file is literally named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)
|