Spaces:
Sleeping
Sleeping
"""
Data Engineering Knowledge Assistant — FastAPI Server
─────────────────────────────────────────────────────
Routes:
  POST /api/chat          → streaming SSE (with keep-alive to defeat HF proxy buffering)
  POST /api/chat-simple   → plain JSON fallback (no streaming)
  POST /api/transcribe    → voice → text via Groq Whisper
  GET  /api/health        → readiness probe
  POST /api/search        → raw vector search (debug)
  *    /                  → PWA frontend (static/)
"""
| from __future__ import annotations | |
| import os | |
| import json | |
| import asyncio | |
| import tempfile | |
| from contextlib import asynccontextmanager | |
| from typing import List, Optional | |
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
# ---------------------------------------------------------------------------
# Global state — both names start empty and are filled in by the lifespan
# handler at startup; request handlers check them before use.
# ---------------------------------------------------------------------------
agent = None
rag_pipeline = None
# ---------------------------------------------------------------------------
# Lifespan — initialise on startup
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG pipeline and the agent before the server takes traffic.

    Reads ``PDF_PATH`` (default ``knowledge/data_engineering_patterns.pdf``)
    and ``GROQ_API_KEY`` (default empty string) from the environment, stores
    the initialised objects in the module-level ``rag_pipeline`` and
    ``agent`` globals, then yields for the lifetime of the application.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).
    """
    global rag_pipeline, agent
    # Project modules are imported here rather than at module top, so they
    # are only loaded when the application actually starts.
    from rag import DataEngineeringRAG
    from agent import DataEngineeringAgent

    pdf_path = os.environ.get("PDF_PATH", "knowledge/data_engineering_patterns.pdf")
    groq_key = os.environ.get("GROQ_API_KEY", "")

    # NOTE(review): the non-ASCII characters in these log messages look
    # mis-decoded in this copy (likely emoji / dashes) — confirm the file's
    # encoding before editing them.
    if not groq_key:
        print("β οΈ GROQ_API_KEY not set β get a free key at https://console.groq.com")
    else:
        # Log only the key's length, never the key itself.
        print(f"β GROQ_API_KEY detected ({len(groq_key)} chars)")

    print("π Starting Data Engineering Knowledge Assistant β¦")
    rag_pipeline = DataEngineeringRAG(pdf_path=pdf_path, groq_api_key=groq_key)
    rag_pipeline.initialize()
    agent = DataEngineeringAgent(rag=rag_pipeline, groq_api_key=groq_key)
    print("β Agent ready β listening for requests")

    try:
        yield
    finally:
        # Runs on a normal shutdown and also if the server task is cancelled.
        print("π Shutting down")
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="DE Knowledge Assistant",
    version="1.1.0",
    lifespan=lifespan,
)

# Wide-open CORS: any origin may call the API with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Schemas
# ---------------------------------------------------------------------------
class SearchRequest(BaseModel):
    """Request body for the raw vector-search endpoint."""

    query: str
    k: int = 5  # how many hits to ask the vector search for


class ChatMessage(BaseModel):
    """A single earlier turn of the conversation."""

    role: str  # presumably "user" / "assistant" — not validated here
    content: str


class ChatRequest(BaseModel):
    """Request body shared by the streaming and non-streaming chat endpoints."""

    message: str
    # Optional: clients may omit it or send null explicitly.
    history: Optional[List[ChatMessage]] = []
    stream: bool = True
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/api/health")
async def health():
    """Readiness probe: report status plus a few configuration facts.

    ``vectorstore_docs`` is 0 until the lifespan handler has built the RAG
    pipeline; ``groq_key_set`` only reports whether the variable is non-empty.
    """
    return {
        "status": "healthy",
        "model": "llama-3.1-8b-instant (Groq)",
        "vectorstore_docs": rag_pipeline.get_doc_count() if rag_pipeline else 0,
        "groq_key_set": bool(os.environ.get("GROQ_API_KEY")),
        "version": "1.1.0",
    }
@app.post("/api/chat")
async def chat(req: ChatRequest):
    """Stream the agent's answer to ``req.message`` as Server-Sent Events.

    Each answer fragment is sent as ``data: {"chunk": "..."}`` and the stream
    ends with ``data: [DONE]``. An agent failure is reported in-band as a
    final chunk instead of breaking the HTTP response.

    Critical for HF Spaces: bytes must be flushed immediately. Otherwise
    Cloudflare/nginx buffer the whole response until the generator finishes,
    which defeats streaming entirely. An SSE comment is emitted every second
    as a heartbeat so the proxy forwards the response chunk by chunk.

    Raises:
        HTTPException: 503 if the agent has not been initialised yet.
    """
    if not agent:
        raise HTTPException(503, "Agent not initialised")
    # ``history`` is Optional in the schema, so a client may send null.
    history = [m.model_dump() for m in (req.history or [])]

    async def event_stream():
        # An immediate SSE comment makes the proxy start sending right away.
        yield ": keep-alive\n\n"

        queue: asyncio.Queue = asyncio.Queue()
        heartbeat_stop = asyncio.Event()

        async def heartbeat():
            """Enqueue an SSE comment every second until told to stop."""
            while not heartbeat_stop.is_set():
                try:
                    await asyncio.wait_for(heartbeat_stop.wait(), timeout=1.0)
                except asyncio.TimeoutError:
                    await queue.put(": ping\n\n")

        async def producer():
            """Forward agent chunks to the queue, always finishing with [DONE]."""
            try:
                async for chunk in agent.astream(message=req.message, history=history):
                    await queue.put(f"data: {json.dumps({'chunk': chunk})}\n\n")
            except Exception as exc:
                # NOTE(review): the non-ASCII characters in this message look
                # mis-decoded in this copy — confirm the intended text.
                err = json.dumps({"chunk": f"\n\nβ οΈ **Server error:** {type(exc).__name__}: {exc}"})
                await queue.put(f"data: {err}\n\n")
            finally:
                await queue.put("data: [DONE]\n\n")
                heartbeat_stop.set()

        hb_task = asyncio.create_task(heartbeat())
        prod_task = asyncio.create_task(producer())
        try:
            while True:
                item = await queue.get()
                yield item
                if item.startswith("data: [DONE]"):
                    break
        finally:
            # Also reached when the client disconnects mid-stream.
            hb_task.cancel()
            prod_task.cancel()

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
@app.post("/api/chat-simple")
async def chat_simple(req: ChatRequest):
    """Non-streaming fallback: return the agent's whole answer as JSON.

    Used by the PWA when SSE is blocked (corporate proxies, some mobile
    networks, etc.). Response shape: ``{"response": "<full answer>"}``.

    Raises:
        HTTPException: 503 if the agent has not been initialised yet;
            500 if the agent fails while generating the answer.
    """
    if not agent:
        raise HTTPException(503, "Agent not initialised")
    # ``history`` is Optional in the schema, so a client may send null.
    history = [m.model_dump() for m in (req.history or [])]
    try:
        # Drain the async generator into a list of fragments.
        chunks = [c async for c in agent.astream(message=req.message, history=history)]
    except Exception as exc:
        raise HTTPException(500, f"{type(exc).__name__}: {exc}") from exc
    return {"response": "".join(chunks)}
@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """Voice-to-text via Groq Whisper (free tier).

    Accepts webm/ogg/mp3/wav. Returns ``{"text": "..."}``.

    Why this beats the browser Web Speech API:
      • Works in Brave / Safari / iOS PWA (Web Speech needs Google's proxy)
      • Works inside the HF Spaces iframe (no cross-origin STT issues)
      • Whisper accuracy >> webkitSpeechRecognition

    Raises:
        HTTPException: 500 if GROQ_API_KEY is unset or transcription fails;
            400 if the uploaded file is empty.
    """
    groq_key = os.environ.get("GROQ_API_KEY", "")
    if not groq_key:
        raise HTTPException(500, "GROQ_API_KEY not configured")
    try:
        from groq import Groq

        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(400, "Empty audio upload")
        filename = audio.filename or "audio.webm"
        client = Groq(api_key=groq_key)

        def _call_whisper():
            # The SDK accepts a (filename, bytes) tuple, so the upload is sent
            # straight from memory; no temp file means nothing is left to clean up.
            return client.audio.transcriptions.create(
                file=(filename, audio_bytes),
                model="whisper-large-v3-turbo",  # fastest, free-tier friendly
                language="en",
                temperature=0.0,
            )

        # The Groq client call is synchronous; run it in a worker thread so a
        # slow transcription does not stall every other request on the loop.
        result = await asyncio.get_running_loop().run_in_executor(None, _call_whisper)
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(500, f"Transcription failed: {type(exc).__name__}: {exc}") from exc
    return {"text": result.text}
@app.post("/api/search")
async def search(req: SearchRequest):
    """Debug endpoint: return the raw vector-search hits for ``req.query``.

    Response shape: ``{"query": <query>, "results": <rag_pipeline.search(...) output>}``.

    Raises:
        HTTPException: 503 if the RAG pipeline has not been initialised yet.
    """
    if not rag_pipeline:
        raise HTTPException(503, "RAG not initialised")
    return {"query": req.query, "results": rag_pipeline.search(req.query, k=req.k)}
# ---------------------------------------------------------------------------
# Static frontend. The catch-all "/" mount is registered last so that the
# /api routes defined above are matched first.
# ---------------------------------------------------------------------------
app.mount(
    "/",
    StaticFiles(html=True, directory="static"),
    name="static",
)
if __name__ == "__main__":
    # Direct-run entry point. The PORT environment variable overrides the
    # default port 7860.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": int(os.environ.get("PORT", 7860)),
        "log_level": "info",
    }
    uvicorn.run("app:app", **server_options)