File size: 7,848 Bytes
b9cacf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9664f83
b9cacf3
 
 
 
 
 
 
 
 
6801044
 
 
 
 
 
 
 
 
 
 
b9cacf3
 
6801044
b9cacf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import asyncio
import base64
import hashlib
import hmac
import io
import os
import wave
from urllib.parse import urlsplit

import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from piper.voice import PiperVoice

# ── Secrets (set these in HuggingFace Space → Settings → Variables and secrets) ──
# Every value is read once, at import time; an unset variable yields None.
EXPECTED_HASH = os.getenv("HASH_VALUE")       # SHA-256 hex digest the auth_code header must hash to
SERVER_MAIN = os.getenv("SERVER_MAIN_URL")    # URL that /chat_start forwards payloads to
SERV_CODE = os.getenv("SERV_CODE")            # sent to SERVER_MAIN as the "serv_code" header
CF_SECRET_KEY = os.getenv("CF_SECRET_KEY")    # reserved for Cloudflare token verification if needed

# Domain accepted by the Referer/Origin fallback check.
ALLOWED_DOMAIN = os.getenv("ALLOWED_DOMAIN", "buildwithsupratim.github.io")

# ── Piper TTS config ──
# The ONNX voice model is resolved relative to this file's directory, under "models/".
PIPER_MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "en_US-lessac-medium.onnx")

# Load voice once at startup (heavy; reuse across requests)
# NOTE(review): this executes at import time, so importing the module presumably
# fails when the model file is missing — confirm that is the desired failure mode.
_piper_voice: PiperVoice = PiperVoice.load(PIPER_MODEL_PATH)

# ASGI application object; the route decorators further down register handlers on it.
app = FastAPI(title="Maria Middleware", version="1.0.0")

# CORS policy: browsers may call this API only from the listed origin, using GET/POST.
# NOTE(review): the origin is hard-coded here and is not derived from ALLOWED_DOMAIN,
# so overriding ALLOWED_DOMAIN does not change the CORS policy — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://buildwithsupratim.github.io"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# ═══════════════════════════════════════════════════════════════
#  AUTH HELPERS
# ═══════════════════════════════════════════════════════════════

def _hash_auth_code(auth_code: str) -> str:
    """Return the lowercase hex SHA-256 digest of *auth_code* (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(auth_code.encode("utf-8"))
    return digest.hexdigest()


def _check_first(request: Request) -> bool:
    """Primary check: hash of auth_code header must match EXPECTED_HASH.

    Args:
        request: Incoming request. The ``auth_code`` header is read first,
            falling back to ``Auth-Code``.

    Returns:
        True when a non-empty auth code is supplied and its SHA-256 hex digest
        equals EXPECTED_HASH. False otherwise, including when EXPECTED_HASH is
        not configured.
    """
    auth_code = request.headers.get("auth_code") or request.headers.get("Auth-Code")
    if not auth_code or not EXPECTED_HASH:
        return False
    # Constant-time comparison, so response timing does not reveal how much of
    # the digest matched. Bytes are compared because compare_digest rejects
    # str arguments containing non-ASCII characters.
    return hmac.compare_digest(
        _hash_auth_code(auth_code).encode("utf-8"),
        EXPECTED_HASH.encode("utf-8"),
    )


def _check_second(request: Request) -> bool:
    """Fallback check: request must originate from the allowed domain.

    The host of the ``Referer`` and ``Origin`` headers is parsed and compared
    with ALLOWED_DOMAIN, so any path under the allowed site passes. Values that
    merely contain the domain text (in a path, in a query string, or as the
    prefix of another host name) are rejected. Sub-domains of ALLOWED_DOMAIN
    are accepted.

    NOTE: both headers are supplied by the client, so this check only
    restrains browsers; a non-browser client can set them to any value.

    Args:
        request: Incoming request whose ``referer`` / ``origin`` headers are read.

    Returns:
        True if either header names a host equal to, or a sub-domain of,
        ALLOWED_DOMAIN. False otherwise, including when ALLOWED_DOMAIN is empty.
    """
    # Normalise the configured value to a bare lowercase host, tolerating an
    # accidental scheme, port or trailing dot in the environment variable.
    configured = (ALLOWED_DOMAIN or "").strip()
    try:
        allowed_host = urlsplit(configured if "//" in configured else f"//{configured}").hostname
    except ValueError:
        allowed_host = None
    allowed_host = (allowed_host or "").rstrip(".")
    if not allowed_host:
        # An empty domain must not match: with a substring test it matched everything.
        return False

    for header_name in ("referer", "origin"):
        value = request.headers.get(header_name, "")
        if not value:
            continue
        try:
            host = urlsplit(value).hostname  # already lowercased by urlsplit
        except ValueError:
            # Malformed URL (e.g. a broken IPv6 literal): treat as not allowed.
            continue
        if not host:
            # No network location, e.g. an Origin of "null".
            continue
        host = host.rstrip(".")
        if host == allowed_host or host.endswith("." + allowed_host):
            return True
    return False


async def _authorize(request: Request) -> None:
    """Reject the request with HTTP 403 unless one of the two checks accepts it.

    The auth-code hash check is tried first; the Referer/Origin domain check
    is only consulted when the first one fails.
    """
    # `or` short-circuits, preserving the "first check, then fallback" order.
    is_allowed = _check_first(request) or _check_second(request)
    if not is_allowed:
        raise HTTPException(status_code=403, detail="Forbidden: invalid auth_code and domain not allowed.")


# ═══════════════════════════════════════════════════════════════
#  TTS HELPER
# ═══════════════════════════════════════════════════════════════

def _generate_tts_base64(text: str) -> str:
    """Render *text* to speech with the shared Piper voice.

    The audio is written as WAV into an in-memory buffer and returned as a
    base64 string. This call is synchronous.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        _piper_voice.synthesize(text, writer)
    # Read the buffer only after the wave writer has been closed by the
    # `with` block, so the complete WAV data is captured.
    wav_bytes = buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode("utf-8")


# ═══════════════════════════════════════════════════════════════
#  ROUTES
# ═══════════════════════════════════════════════════════════════

@app.get("/")
async def root():
    """Report liveness at the root path (used for HuggingFace health checks)."""
    status_payload = {"status": "alive"}
    return status_payload


@app.get("/ping")
async def ping():
    """Liveness probe; calling it wakes the Space if it was sleeping."""
    return dict(status="alive")


@app.post("/chat_start")
async def chat_start(request: Request):
    """Proxy a chat request to SERVER_MAIN and attach Piper TTS audio.

    1. Authenticate the caller.
    2. Forward the payload to SERVER_MAIN with serv_code in headers.
    3. Override / add audio_output with Piper TTS base64 WAV audio.
    4. Return the final response.

    Args:
        request: Incoming request; its JSON body is forwarded unchanged.

    Returns:
        JSONResponse with the upstream JSON, where
        ``query.response_message.audio_output`` holds the base64 WAV for
        ``query.response_message.text`` (or "" when there is no text).

    Raises:
        HTTPException: 403 when authorization fails; 400 when the body is not
            valid JSON; 500 when SERVER_MAIN_URL or SERV_CODE is not
            configured; the upstream status code when SERVER_MAIN answers with
            an error status; 502 when SERVER_MAIN is unreachable, returns a
            non-JSON body, or returns an unexpected schema.
    """
    await _authorize(request)

    # ── Parse incoming JSON ──────────────────────────────────────
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body.")

    # ── Validate configuration ───────────────────────────────────
    # Report a clear error instead of letting an unset URL / header value
    # fail inside httpx as an unhandled exception.
    if SERVER_MAIN is None or SERV_CODE is None:
        raise HTTPException(
            status_code=500,
            detail="Server misconfigured: SERVER_MAIN_URL or SERV_CODE is not set.",
        )

    # ── Forward to SERVER_MAIN ───────────────────────────────────
    forward_headers = {
        "Content-Type": "application/json",
        "serv_code":    SERV_CODE,
    }

    async with httpx.AsyncClient(timeout=600.0) as client:
        try:
            server_response = await client.post(
                SERVER_MAIN,
                json=payload,
                headers=forward_headers,
            )
            server_response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Extract as much info as possible from the proxy/server
            error_details = {
                "status_code": exc.response.status_code,
                "reason_phrase": exc.response.reason_phrase,
                "url": str(exc.request.url),
                "response_body": exc.response.text[:500],  # Get first 500 chars of the error page
                "headers": dict(exc.response.headers)      # Useful to see if 'server' is 'uvicorn' or 'nginx'
            }

            print(f"DEBUG 502: {error_details}") # This will show in Hugging Face Logs

            # Relay the upstream status code together with the truncated body.
            raise HTTPException(
                status_code=exc.response.status_code,
                detail=f"Downstream Error: {error_details['response_body']}"
            ) from exc
        except httpx.RequestError as exc:
            raise HTTPException(status_code=502, detail=f"SERVER_MAIN unreachable: {str(exc)}") from exc

    # The body was fully read by client.post(), so it can be decoded after the
    # client is closed. A non-JSON body raises a ValueError subclass.
    try:
        response_data = server_response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="SERVER_MAIN returned a non-JSON response.") from exc

    # ── Extract text for TTS ─────────────────────────────────────
    try:
        response_message = response_data["query"]["response_message"]
        tts_text = response_message.get("text", "")
    except (KeyError, TypeError, AttributeError) as exc:
        # AttributeError covers a response_message that is null or otherwise
        # not an object (no .get method).
        raise HTTPException(status_code=502, detail="Unexpected response schema from SERVER_MAIN.") from exc

    # ── Generate TTS and override audio_output ───────────────────
    if tts_text:
        # Run blocking Piper synthesis in a thread pool to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        audio_b64 = await loop.run_in_executor(None, _generate_tts_base64, tts_text)
        response_message["audio_output"] = audio_b64
    else:
        # No text → clear audio_output to avoid stale base64
        response_message["audio_output"] = ""

    return JSONResponse(content=response_data)


# ═══════════════════════════════════════════════════════════════
#  ENTRY POINT
# ═══════════════════════════════════════════════════════════════

if __name__ == "__main__":
    # Script entry point: serve the application with uvicorn on all interfaces, port 7860.
    # NOTE(review): the "app:app" import string assumes this module is importable
    # as `app` (i.e. the file is named app.py) — confirm against the deployment layout.
    import uvicorn  # imported here so it is only needed when launched as a script
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)