# focustiki's picture
# Update app.py
# a5ba31a verified
"""
Data Engineering Knowledge Assistant — FastAPI Server
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Routes:
    POST /api/chat         → streaming SSE (with keep-alive to defeat HF proxy buffering)
    POST /api/chat-simple  → plain JSON fallback (no streaming)
    POST /api/transcribe   → voice → text via Groq Whisper
    GET  /api/health       → readiness probe
    POST /api/search       → raw vector search (debug)
    *    /                 → PWA frontend (static/)
"""
from __future__ import annotations
import os
import json
import asyncio
import tempfile
from contextlib import asynccontextmanager
from typing import List, Optional
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# ──────────────────────────────────────────────────────────────────────────────
# Global state
# ──────────────────────────────────────────────────────────────────────────────
# Both stay None until the `lifespan` startup hook has built them; the route
# handlers check for that and answer 503 (or report 0 documents) meanwhile.
rag_pipeline = None  # set to a DataEngineeringRAG instance at startup
agent = None  # set to a DataEngineeringAgent instance at startup
# ──────────────────────────────────────────────────────────────────────────────
# Lifespan — init on startup
# ──────────────────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the RAG pipeline and the agent before the app serves traffic.

    Configuration is read from the environment:

    * ``PDF_PATH`` - knowledge-base PDF; defaults to
      ``knowledge/data_engineering_patterns.pdf``.
    * ``GROQ_API_KEY`` - passed to both the RAG pipeline and the agent. A
      missing key only produces a warning; startup continues regardless.

    The initialised objects are published through the module-level
    ``rag_pipeline`` and ``agent`` globals, which the route handlers read.
    Code after the ``yield`` runs at shutdown and only logs a message.
    """
    global rag_pipeline, agent

    # Project modules are imported here, i.e. only once the app really starts.
    from rag import DataEngineeringRAG
    from agent import DataEngineeringAgent

    pdf_path = os.environ.get("PDF_PATH", "knowledge/data_engineering_patterns.pdf")
    groq_key = os.environ.get("GROQ_API_KEY", "")

    if not groq_key:
        print("⚠️ GROQ_API_KEY not set — get a free key at https://console.groq.com")
    else:
        # Only the length is logged, never the key itself.
        print(f"✅ GROQ_API_KEY detected ({len(groq_key)} chars)")

    print("🚀 Starting Data Engineering Knowledge Assistant …")
    rag_pipeline = DataEngineeringRAG(pdf_path=pdf_path, groq_api_key=groq_key)
    rag_pipeline.initialize()
    agent = DataEngineeringAgent(rag=rag_pipeline, groq_api_key=groq_key)
    print("✅ Agent ready — listening for requests")

    yield

    print("👋 Shutting down")
# ──────────────────────────────────────────────────────────────────────────────
# App
# ──────────────────────────────────────────────────────────────────────────────
# The ASGI application; `lifespan` performs the startup initialisation.
app = FastAPI(
    title="DE Knowledge Assistant",
    version="1.1.0",
    lifespan=lifespan,
)

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ──────────────────────────────────────────────────────────────────────────────
# Schemas
# ──────────────────────────────────────────────────────────────────────────────
class ChatMessage(BaseModel):
    """A single earlier turn of the conversation, as sent in ``history``."""

    # Speaker of the turn; any string is accepted (no validation here).
    role: str
    # Text of the turn.
    content: str
class ChatRequest(BaseModel):
    """Request body shared by ``/api/chat`` and ``/api/chat-simple``."""

    # The user's new message.
    message: str
    # Earlier turns; defaults to an empty list. The field is declared
    # Optional, so an explicit null is accepted and arrives as None.
    history: Optional[List[ChatMessage]] = []
    # Defaults to True. No handler visible in this file reads it.
    stream: bool = True
class SearchRequest(BaseModel):
    """Request body for the ``/api/search`` debug endpoint."""

    # Text handed to ``rag_pipeline.search``.
    query: str
    # Forwarded to ``rag_pipeline.search`` as ``k``; defaults to 5.
    k: int = 5
# ──────────────────────────────────────────────────────────────────────────────
# Routes
# ──────────────────────────────────────────────────────────────────────────────
@app.get("/api/health")
async def health():
    """Report service status, model label, document count and key presence.

    The document count is 0 while the RAG pipeline has not been initialised.
    ``groq_key_set`` reflects the environment at request time.
    """
    doc_count = rag_pipeline.get_doc_count() if rag_pipeline else 0
    key_present = bool(os.environ.get("GROQ_API_KEY"))
    return {
        "status": "healthy",
        "model": "llama-3.1-8b-instant (Groq)",
        "vectorstore_docs": doc_count,
        "groq_key_set": key_present,
        "version": "1.1.0",
    }
@app.post("/api/chat")
async def chat(req: ChatRequest):
    """
    Streaming chat endpoint (Server-Sent Events).

    Critical for HF Spaces: we must flush bytes immediately — Cloudflare/nginx
    will otherwise buffer the whole response until the generator finishes,
    defeating streaming entirely. We emit an SSE comment every second as a
    heartbeat so the proxy flushes the response chunk-by-chunk.

    Wire format:
      * ``: keep-alive`` once at the start and ``: ping`` once per second
        (SSE comments).
      * ``data: {"chunk": "..."}`` for every chunk yielded by ``agent.astream``.
      * ``data: [DONE]`` as the final event.

    An exception raised by the agent is reported in-band as one more chunk
    holding the error text, followed by ``[DONE]``.

    Raises:
        HTTPException: 503 when the agent has not been initialised.
    """
    if not agent:
        raise HTTPException(503, "Agent not initialised")

    # `history` is Optional in the schema, so an explicit JSON null arrives
    # as None; treat that the same as "no history" instead of crashing.
    history = [m.model_dump() for m in (req.history or [])]

    async def event_stream():
        # Force the proxy to start sending immediately.
        yield ": keep-alive\n\n"

        # Heartbeat and producer both feed this queue; the loop below is the
        # only consumer.
        queue: asyncio.Queue = asyncio.Queue()
        heartbeat_stop = asyncio.Event()

        async def heartbeat():
            # Enqueue a ping whenever a full second passes without the stop
            # event being set.
            while not heartbeat_stop.is_set():
                try:
                    await asyncio.wait_for(heartbeat_stop.wait(), timeout=1.0)
                except asyncio.TimeoutError:
                    await queue.put(": ping\n\n")

        async def producer():
            try:
                async for chunk in agent.astream(message=req.message, history=history):
                    await queue.put(f"data: {json.dumps({'chunk': chunk})}\n\n")
            except Exception as exc:
                # Surface the failure to the client inside the stream itself.
                err = json.dumps({"chunk": f"\n\n⚠️ **Server error:** {type(exc).__name__}: {exc}"})
                await queue.put(f"data: {err}\n\n")
            finally:
                await queue.put("data: [DONE]\n\n")
                heartbeat_stop.set()

        hb_task = asyncio.create_task(heartbeat())
        prod_task = asyncio.create_task(producer())
        try:
            while True:
                item = await queue.get()
                yield item
                if item.startswith("data: [DONE]"):
                    break
        finally:
            # Runs on normal completion and when the client goes away.
            hb_task.cancel()
            prod_task.cancel()
            # Wait for both tasks to finish unwinding so none is left pending
            # and no exception goes unretrieved.
            await asyncio.gather(hb_task, prod_task, return_exceptions=True)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
@app.post("/api/chat-simple")
async def chat_simple(req: ChatRequest):
    """
    Non-streaming fallback. Used by the PWA when SSE is blocked
    (corporate proxies, some mobile networks, etc.)

    Returns:
        ``{"response": "<full text>"}`` - every chunk from ``agent.astream``
        joined into one string.

    Raises:
        HTTPException: 503 when the agent has not been initialised; 500 with
            the exception type and message when the agent fails.
    """
    if not agent:
        raise HTTPException(503, "Agent not initialised")

    # `history` is Optional in the schema, so an explicit JSON null arrives
    # as None; treat that the same as "no history" instead of crashing.
    history = [m.model_dump() for m in (req.history or [])]

    try:
        # Drain the async generator into a single string.
        chunks = [c async for c in agent.astream(message=req.message, history=history)]
        return {"response": "".join(chunks)}
    except Exception as exc:
        raise HTTPException(500, f"{type(exc).__name__}: {exc}") from exc
@app.post("/api/transcribe")
async def transcribe(audio: UploadFile = File(...)):
    """
    Voice-to-text via Groq Whisper (free tier).
    Accepts webm/ogg/mp3/wav. Returns {"text": "..."}.

    Why this beats the browser Web Speech API:
      - Works in Brave / Safari / iOS PWA (Web Speech needs Google's proxy)
      - Works inside HF Spaces iframe (no cross-origin STT issues)
      - Whisper accuracy >> webkitSpeechRecognition

    Raises:
        HTTPException: 500 when ``GROQ_API_KEY`` is not set, or when reading
            the upload or the transcription call fails.
    """
    groq_key = os.environ.get("GROQ_API_KEY", "")
    if not groq_key:
        raise HTTPException(500, "GROQ_API_KEY not configured")

    try:
        from groq import Groq

        client = Groq(api_key=groq_key)

        # The SDK call receives (filename, bytes), so the upload is passed on
        # directly from memory. No temp file is written, hence nothing can be
        # left behind on disk when the request fails.
        filename = audio.filename or "audio.webm"
        payload = await audio.read()

        def _call_groq():
            return client.audio.transcriptions.create(
                file=(filename, payload),
                model="whisper-large-v3-turbo",  # fastest, free-tier friendly
                language="en",
                temperature=0.0,
            )

        # The Groq client is synchronous; run it in the default executor so
        # the event loop (and the SSE heartbeats it drives) keeps running.
        result = await asyncio.get_running_loop().run_in_executor(None, _call_groq)
        return {"text": result.text}
    except Exception as exc:
        raise HTTPException(500, f"Transcription failed: {type(exc).__name__}: {exc}") from exc
@app.post("/api/search")
async def search(req: SearchRequest):
    """Run a raw vector search and echo the query next to its results.

    Raises:
        HTTPException: 503 when the RAG pipeline has not been initialised.
    """
    if rag_pipeline:
        hits = rag_pipeline.search(req.query, k=req.k)
        return {"query": req.query, "results": hits}
    raise HTTPException(503, "RAG not initialised")
# ──────────────────────────────────────────────────────────────────────────────
# Static frontend — mount LAST so API routes take priority
# ──────────────────────────────────────────────────────────────────────────────
# Serve the files under static/ at the root path; html=True enables HTML mode.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)
if __name__ == "__main__":
    # Direct execution: serve with uvicorn on all interfaces. The port comes
    # from the PORT environment variable and falls back to 7860.
    import uvicorn

    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")