Spaces:
Running
Running
| """ | |
| FastAPI application — routes for chat (SSE streaming), auth, settings, and static files. | |
| """ | |
| import os | |
| import json | |
| import time | |
| import io | |
| from typing import Optional | |
| from fastapi import FastAPI, Request, UploadFile, File | |
| from fastapi.responses import FileResponse, StreamingResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from langchain_core.messages import HumanMessage, AIMessage | |
| import db | |
| import guard | |
| import retriever | |
| import auth as auth_module | |
| import feed as feed_module | |
| from config import GOOGLE_CLIENT_ID | |
| from graph import chatbot | |
| from tools import run_web_search, fetch_yt_transcript | |
app = FastAPI()

# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive for a browser-facing API — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Directory that contains this module; static/ and assets/ live next to it.
_HERE = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(_HERE, "static")
ASSETS_DIR = os.path.join(_HERE, "assets")
# --- Request models ---
# Pydantic bodies for the JSON routes in this module. Field order and defaults
# define the validation / OpenAPI schema, so they are documented with comments.


class GoogleAuthRequest(BaseModel):
    # Google ID token from the frontend; checked by
    # auth_module.verify_google_token in google_login().
    token: str


class ChatRequest(BaseModel):
    # User text; may begin with a frontend "[System Instruction: ...]"
    # search-mode sentinel that chat() strips before use.
    message: str
    thread_id: str  # keys both the LangGraph state and the db thread record
    persona: str = "nerd"  # forwarded to the graph config as "persona"
    language: str = "auto"  # forwarded to the graph config as "language"
    username: str = ""  # forwarded to the graph config as "username"
    user_id: str = ""  # when set, chat() loads this user's keys and profile
    image: str = ""  # base64 encoded image data
    doc_text: str = ""  # extracted text from attached document (see upload_doc)
    doc_name: str = ""  # exact name of the document
    doc_bytes: str = ""  # base64 encoded bytes of the document


class RenameRequest(BaseModel):
    thread_id: str
    title: str  # new title passed to db.rename_thread


class ApiKeyRequest(BaseModel):
    user_id: str
    # Presumably the user's OpenRouter key: get_me derives has_api_key from
    # the stored "openrouter_key" field.
    key: str


class TavilyKeyRequest(BaseModel):
    user_id: str
    key: str  # Tavily key; chat() passes it to run_web_search as api_key


class ProfileRequest(BaseModel):
    user_id: str
    # Free-text profile; presumably what chat() reads back as
    # user_data["student_profile"].
    profile: str
# --- Feedback Endpoint ---
class FeedbackRequest(BaseModel):
    # Body for submit_feedback(); every field is forwarded to
    # feed_module.send_feedback.
    user_id: str
    user_email: Optional[str] = ""  # None is turned into "" by submit_feedback
    user_name: Optional[str] = ""  # None is turned into "" by submit_feedback
    category: str
    # Ratings. No range constraint is declared on these fields, so any int is
    # accepted here — presumably the frontend limits the scale.
    overall: int = 5
    ease: int = 4
    quality: int = 4
    message: str
    attachments: Optional[list[dict]] = None  # [{filename, content (b64), content_type}]
async def submit_feedback(req: FeedbackRequest):
    """Hand a feedback submission to ``feed_module.send_feedback``.

    Missing e-mail / name values are sent as empty strings. The HTTP status
    is always 200. The body's ``status`` is ``"success"`` when
    ``send_feedback`` returns a truthy value and ``"logged"`` otherwise.
    """
    payload = dict(
        user_id=req.user_id,
        user_email=req.user_email or "",
        user_name=req.user_name or "",
        overall=req.overall,
        ease=req.ease,
        quality=req.quality,
        category=req.category,
        message=req.message,
        attachments=req.attachments,
    )
    delivered = feed_module.send_feedback(**payload)
    outcome = "success" if delivered else "logged"
    return JSONResponse(content={"status": outcome})
# --- SSE helpers ---
def _sse_json_frame(payload: dict) -> str:
    """Serialise *payload* to JSON and wrap it in one SSE ``data:`` frame."""
    body = json.dumps(payload)
    return "data: " + body + "\n\n"


def sse_token(token: str) -> str:
    """Frame one chunk of streamed model output."""
    return _sse_json_frame({"token": token})


def sse_error(message: str) -> str:
    """Frame a user-facing error message."""
    return _sse_json_frame({"error": message})


def sse_done() -> str:
    """Frame the terminal marker; it is raw text, not JSON."""
    return "data: [DONE]\n\n"


def sse_tool_event(event: str, tool_name: str) -> str:
    """Frame a tool_start / tool_end notice so the frontend can show progress."""
    return _sse_json_frame({"tool_event": event, "tool": tool_name})
# --- Auth routes ---
def google_login(req: GoogleAuthRequest):
    """Verify a Google ID token, then create or refresh the matching user.

    Returns 401 ``{"error": "Invalid token"}`` if verification fails.
    Otherwise returns the upserted user and whether an API key and a Tavily
    key are stored for them.
    """
    claims = auth_module.verify_google_token(req.token)
    if not claims:
        return JSONResponse({"error": "Invalid token"}, status_code=401)

    google_id = claims["sub"]
    user = db.upsert_user(
        google_id=google_id,
        email=claims.get("email", ""),
        name=claims.get("name", ""),
        picture=claims.get("picture", ""),
    )
    # NOTE(review): if upsert_user returns the stored key fields, they are
    # sent to the browser here — confirm what it returns.
    return {
        "user": user,
        "has_api_key": bool(db.get_user_api_key(google_id)),
        "has_tavily_key": bool(db.get_tavily_key(google_id)),
    }
def get_me(user_id: str = ""):
    """Return the stored profile for *user_id* plus which API keys are configured.

    Returns 400 when ``user_id`` is empty and 404 when no user matches.
    The stored OpenRouter and Tavily keys are reported only as booleans
    (``has_api_key`` / ``has_tavily_key``). The raw secrets are removed from
    the ``user`` payload so they are never sent to the browser.

    NOTE(review): the caller is identified only by the ``user_id`` it sends,
    so this route trusts that value.
    """
    if not user_id:
        return JSONResponse({"error": "user_id required"}, status_code=400)
    user = db.get_user(user_id)
    if not user:
        return JSONResponse({"error": "User not found"}, status_code=404)

    # Secret fields on the stored record: exposed only as has_* flags.
    secret_fields = ("openrouter_key", "tavily_key")
    public_user = {k: v for k, v in dict(user).items() if k not in secret_fields}
    return {
        "user": public_user,
        "has_api_key": bool(user.get("openrouter_key")),
        "has_tavily_key": bool(user.get("tavily_key")),
    }
def get_client_id():
    """Expose the configured Google OAuth client ID for frontend sign-in setup."""
    payload = {"client_id": GOOGLE_CLIENT_ID}
    return payload
# --- User settings routes ---
def save_api_key(req: ApiKeyRequest):
    """Store the user's LLM API key.

    Leading and trailing whitespace (e.g. a newline picked up on copy-paste)
    is stripped before saving, so the stored key is usable as-is. A
    whitespace-only submission is passed on as an empty string.
    ``has_api_key`` mirrors the flag returned by the Tavily route; ``ok`` is
    kept for existing callers.
    """
    key = req.key.strip()
    db.save_user_api_key(req.user_id, key)
    return {"ok": True, "has_api_key": bool(key)}
def save_tavily_key(req: TavilyKeyRequest):
    """Store the user's Tavily web-search key.

    The key is stripped before saving, so the value stored and the returned
    ``has_tavily_key`` flag agree. Previously a whitespace-only key was saved
    verbatim but reported as absent.
    """
    key = req.key.strip()
    db.save_tavily_key(req.user_id, key)
    return {"ok": True, "has_tavily_key": bool(key)}
def save_profile(req: ProfileRequest):
    """Persist the user's free-text student profile via ``db.save_student_profile``."""
    owner, text = req.user_id, req.profile
    db.save_student_profile(owner, text)
    return {"ok": True}
async def upload_doc(file: UploadFile = File(...)):
    """Extract plain text from an uploaded PDF, Word or text file.

    The frontend sends the returned text back as ``doc_text`` in the chat
    request. Parsing runs in a worker thread (``asyncio.to_thread``) because
    pdfplumber / python-docx are synchronous and CPU-bound. Running them
    directly in this ``async`` handler would stall every other request on
    the event loop.

    Returns:
        ``{"filename": ..., "text": ...}`` with at most 20,000 characters of
        text, or a JSON error: 400 if the upload cannot be read, and 422 if
        it cannot be parsed or contains no readable text.
    """
    import asyncio

    filename = file.filename or ""
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    try:
        raw_bytes = await file.read()
    except Exception as exc:
        print(f"[UPLOAD ERROR] Failed to read uploaded file: {exc}")
        return JSONResponse({"error": f"Failed to read file: {exc}"}, status_code=400)

    try:
        text = await asyncio.to_thread(_extract_upload_text, ext, raw_bytes)
    except Exception as exc:
        print(f"[PARSE ERROR] Failed to parse document {filename}: {exc}")
        detail = f"Could not parse file content: {exc}"
        if ext == "doc":
            # python-docx reads only the zip-based .docx format.
            detail += " (legacy .doc files are not supported; please save the file as .docx)"
        return JSONResponse({"error": detail}, status_code=422)

    text = text.strip()
    if not text:
        return JSONResponse({"error": "No readable text found in file."}, status_code=422)
    # Return first 20,000 chars of extracted text
    return {"filename": filename, "text": text[:20_000]}


def _extract_upload_text(ext: str, raw_bytes: bytes) -> str:
    """Return the raw text of an upload, choosing a parser by file extension."""
    if ext == "pdf":
        import pdfplumber  # type: ignore
        with pdfplumber.open(io.BytesIO(raw_bytes)) as pdf:
            return "\n".join(page.extract_text() or "" for page in pdf.pages)
    if ext in ("docx", "doc"):
        from docx import Document as DocxDocument
        document = DocxDocument(io.BytesIO(raw_bytes))
        return "\n".join(p.text for p in document.paragraphs)
    # Plain text / markdown / code files. "utf-8-sig" drops a leading BOM,
    # which str.strip() would otherwise leave in the text.
    return raw_bytes.decode("utf-8-sig", errors="replace")
# --- Core routes ---
APP_DIR = os.path.dirname(os.path.abspath(__file__))


def serve_sw():
    """Serve sw.js from the app directory.

    A service worker must be served from the root path to control the whole
    PWA scope, so it is not placed under /static.
    """
    worker_path = os.path.join(APP_DIR, "sw.js")
    return FileResponse(worker_path, media_type="application/javascript")


def serve_manifest():
    """Serve the web-app manifest from STATIC_DIR with its dedicated MIME type."""
    manifest_path = os.path.join(STATIC_DIR, "manifest.json")
    return FileResponse(manifest_path, media_type="application/manifest+json")


def serve_index():
    """Serve the single-page frontend's index.html."""
    index_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_path)
def get_threads(user_id: str = ""):
    """List the thread records ``db.get_threads`` returns for *user_id*.

    NOTE(review): what an empty ``user_id`` returns is decided by db.get_threads.
    """
    threads = db.get_threads(user_id)
    return {"threads": threads}
def get_history(thread_id: str):
    """Return the user-visible transcript of *thread_id* from the LangGraph state.

    Only human messages (role ``"user"``) and AI messages with non-empty
    content (role ``"assistant"``) are returned. This matches what chat()
    streams: it forwards only ``AIMessage`` chunks that carry content. Other
    state entries (tool results, system prompts, tool-call-only AI turns) are
    internal plumbing. Previously they were mislabelled as assistant replies.
    """
    config = {"configurable": {"thread_id": thread_id}}
    state = chatbot.get_state(config)
    messages = state.values.get("messages", [])

    result = []
    for msg in messages:
        if isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage) and msg.content:
            role = "assistant"
        else:
            continue
        result.append({"role": role, "content": msg.content})
    return {"messages": result}
# Prefixes the frontend prepends to a message to switch on a search mode; the
# user's own text follows the "User Query: " marker.
_WEB_SENTINEL = "[System Instruction: Web Search is enabled."
_YT_SENTINEL = "[System Instruction: YouTube Video Search is enabled."
_USER_QUERY_MARKER = "User Query: "


def _parse_search_mode(raw_message: str) -> tuple[str, bool, bool]:
    """Split a frontend search-mode sentinel off *raw_message*.

    Returns ``(clean_message, web_search_enabled, yt_search_enabled)``. If a
    sentinel is present but the "User Query: " marker is missing, the raw
    message is used unchanged as the query.
    """
    if raw_message.startswith(_WEB_SENTINEL):
        web, yt = True, False
    elif raw_message.startswith(_YT_SENTINEL):
        web, yt = False, True
    else:
        return raw_message, False, False
    parts = raw_message.split(_USER_QUERY_MARKER, 1)
    clean = parts[1].strip() if len(parts) > 1 else raw_message
    return clean, web, yt


def _friendly_stream_error(exc: Exception) -> str:
    """Map an LLM/provider exception to a short user-facing message.

    An explicit ``status_code`` attribute (set by OpenAI-style client errors)
    is used when present; otherwise matching falls back to substrings of the
    message. The phrases are specific ("rate limit", not "rate", which also
    matches "generate"). Auth (401) is tested before 404 because an invalid
    key can produce text containing "not found" (reportedly OpenRouter's
    "User not found.").
    """
    status = getattr(exc, "status_code", None)
    text = str(exc)
    lowered = text.lower()

    def matches(code: int, *phrases: str) -> bool:
        return status == code or str(code) in text or any(p in lowered for p in phrases)

    if matches(429, "rate limit", "rate_limit", "too many requests"):
        return "Rate limited. Please wait a moment and try again."
    if matches(402, "payment", "credits"):
        return "Free credits exhausted on OpenRouter."
    if matches(401, "unauthorized", "invalid api key", "invalid_api_key", "no auth credentials"):
        return "Invalid API key."
    if matches(404, "not found", "no endpoints"):
        return "Model unavailable. Please try again."
    return "Something went wrong. Please try again."


def chat(request: ChatRequest, req: Request):
    """Stream an assistant reply for *request* as Server-Sent Events.

    Before streaming: upserts the thread record and loads the user's API
    keys and student profile (when ``user_id`` is set). The stream then:

    1. runs the input guard (a rejection is sent as a normal token);
    2. pulls retrieval context, plus web-search results or a YouTube
       transcript when the frontend sentinel enabled them (wrapped in
       tool_start / tool_end events);
    3. appends attached document text (first 20,000 characters);
    4. streams the LangGraph reply token by token.

    Provider errors become a single ``error`` event. Every path ends with
    ``[DONE]``.
    """
    clean_message, web_search_enabled, yt_search_enabled = _parse_search_mode(request.message)

    # Record the thread with the user's own query rather than the raw message,
    # which may start with the frontend's "[System Instruction: ...]" sentinel
    # (presumably used for the thread title).
    db.upsert_thread(request.thread_id, clean_message, time.time(), request.user_id)

    # Resolve user-specific data
    user_api_key = ""
    tavily_key = ""
    student_profile = ""
    if request.user_id:
        user_api_key = db.get_user_api_key(request.user_id)
        tavily_key = db.get_tavily_key(request.user_id)
        user_data = db.get_user(request.user_id)
        if user_data:
            student_profile = user_data.get("student_profile", "")

    def stream():
        # Step 1: Guard (checks the message exactly as the frontend sent it)
        is_ok, rejection = guard.check_input(request.message)
        if not is_ok:
            yield sse_token(rejection)
            yield sse_done()
            return

        # Step 2: Retrieve context from Pinecone (resilient to errors)
        try:
            results = retriever.search(clean_message)
            context = retriever.format_context(results)
        except Exception as exc:
            print(f"[PINECONE ERROR] Failed to retrieve context: {exc}")
            context = "No relevant context found due to a temporary search error."

        # Step 3: Execute search/transcript tools programmatically
        if web_search_enabled:
            yield sse_tool_event("tool_start", "web_search")
            try:
                web_results = run_web_search(clean_message, api_key=tavily_key)
                print(f"[WEB SEARCH] Fetched {len(web_results)} chars for: {clean_message[:60]}", flush=True)
                context += f"\n\n[Web Search Results]\n{web_results}\n[End Web Search Results]\n"
            except Exception as e:
                print(f"[SEARCH ERROR] {e}", flush=True)
                context += "\n\nWeb search failed to complete.\n"
            yield sse_tool_event("tool_end", "web_search")
        elif yt_search_enabled:
            yield sse_tool_event("tool_start", "yt_transcript")
            try:
                transcript = fetch_yt_transcript(clean_message)
                if transcript.startswith("TRANSCRIPT_UNAVAILABLE"):
                    print(f"[YT TRANSCRIPT] Unavailable: {transcript[:120]}", flush=True)
                    context += f"\n\n[YouTube Transcript Status]\n{transcript}\n"
                else:
                    print(f"[YT TRANSCRIPT] Fetched {len(transcript)} chars", flush=True)
                    context += f"\n\n[YouTube Video Transcript]\n{transcript[:15_000]}\n[End Transcript]\n"
            except Exception as e:
                print(f"[YT TRANSCRIPT ERROR] {e}", flush=True)
                context += "\n\n[YouTube Transcript Status]\nTRANSCRIPT_UNAVAILABLE: extraction failed unexpectedly.\n"
            yield sse_tool_event("tool_end", "yt_transcript")

        # Step 4: Handle document context
        doc_context = ""
        if request.doc_text:
            doc_label = f" (Name: {request.doc_name})" if request.doc_name else ""
            doc_context = (
                f"\n\n[ATTACHED DOCUMENT{doc_label}]\n"
                + request.doc_text[:20_000]
                + "\n[END DOCUMENT]\n"
            )

        # Step 5: Build message (text or multimodal with image)
        if request.image:
            # Accept either bare base64 (assumed JPEG, as before) or a value
            # that is already a data URL, which must not be prefixed twice.
            if request.image.startswith("data:"):
                image_url = request.image
            else:
                image_url = f"data:image/jpeg;base64,{request.image}"
            msg_content = [
                {"type": "text", "text": clean_message + doc_context},
                {"type": "image_url", "image_url": {"url": image_url}},
            ]
        else:
            msg_content = clean_message + doc_context

        # Step 6: Stream LLM response via LangGraph
        has_image = bool(request.image)
        # Resolve the target model up front so it can be logged and passed
        # to the graph config.
        try:
            from graph import _classify, _pick
            category = _classify(clean_message)
            target_model, _ = _pick(category, has_image=has_image)
        except Exception as e:
            print(f"[MODEL PICK ERROR] {e}", flush=True)
            target_model = ""
        print(f"[ROUTER CHAT] Routing query to model: {target_model}", flush=True)

        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "persona": request.persona,
                "context": context,
                "language": request.language,
                "username": request.username,
                "student_profile": student_profile,
                "user_api_key": user_api_key,
                "has_image": has_image,
                "search_enabled": (web_search_enabled or yt_search_enabled),
                "model": target_model,
                "doc_name": request.doc_name,
                "doc_bytes": request.doc_bytes,
            }
        }
        try:
            for chunk, _metadata in chatbot.stream(
                {"messages": [HumanMessage(content=msg_content)]},
                config=config,
                stream_mode="messages",
            ):
                if isinstance(chunk, AIMessage) and chunk.content:
                    yield sse_token(chunk.content)
        except Exception as e:
            # Previously swallowed silently; keep the cause in the server log.
            print(f"[STREAM ERROR] {e}", flush=True)
            yield sse_error(_friendly_stream_error(e))
        yield sse_done()

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"X-Accel-Buffering": "no"},
    )
def rename_thread(request: RenameRequest):
    """Set a new title on a thread via ``db.rename_thread``."""
    thread_id, new_title = request.thread_id, request.title
    db.rename_thread(thread_id, new_title)
    return {"ok": True}
def delete_thread(thread_id: str):
    """Remove a thread record via ``db.delete_thread``."""
    db.delete_thread(thread_id)
    response = {"ok": True}
    return response
# --- Health check ---
def health():
    """Report whether the database and the Pinecone index are reachable.

    Returns 200 with ``status: "ok"`` when both probes succeed, otherwise 503
    with ``status: "degraded"``. Each check's boolean is included in the body.
    A failed probe is logged with its exception, so a degraded status can be
    diagnosed from the server log instead of being silently discarded.
    """
    checks = {"db": False, "pinecone": False}
    try:
        db.conn.execute("SELECT 1")
        checks["db"] = True
    except Exception as exc:
        print(f"[HEALTH] db check failed: {exc}", flush=True)
    try:
        retriever._index.describe_index_stats()
        checks["pinecone"] = True
    except Exception as exc:
        print(f"[HEALTH] pinecone check failed: {exc}", flush=True)

    ok = all(checks.values())
    return JSONResponse(
        {"status": "ok" if ok else "degraded", "checks": checks},
        status_code=200 if ok else 503,
    )
# --- Static file serving ---
# (url prefix, directory, must exist). /assets is mounted only if present;
# /static is always mounted.
for _prefix, _directory, _required in (
    ("/assets", ASSETS_DIR, False),
    ("/static", STATIC_DIR, True),
):
    if _required or os.path.isdir(_directory):
        app.mount(_prefix, StaticFiles(directory=_directory), name=_prefix.lstrip("/"))
del _prefix, _directory, _required
def serve_static(filepath: str):
    """Serve a file from STATIC_DIR, or a JSON 404 if it is not there.

    *filepath* is presumably taken from the request URL (the route decorator
    is not visible here), so it is untrusted. The joined path is resolved
    and must stay inside STATIC_DIR. Without that check, ``..`` segments
    could read arbitrary files on the server. So could an absolute path,
    which makes ``os.path.join`` discard STATIC_DIR entirely.
    """
    static_root = os.path.realpath(STATIC_DIR)
    full_path = os.path.realpath(os.path.join(static_root, filepath))
    if full_path.startswith(static_root + os.sep) and os.path.isfile(full_path):
        return FileResponse(full_path)
    return JSONResponse({"error": "not found"}, status_code=404)