"""
FastAPI application — routes for chat (SSE streaming), auth, settings, and static files.
"""

import os
import json
import time
import io
from typing import Optional

from fastapi import FastAPI, Request, UploadFile, File
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage

import db
import guard
import retriever
import auth as auth_module
import feed as feed_module
from config import GOOGLE_CLIENT_ID
from graph import chatbot
from tools import run_web_search, fetch_yt_transcript

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True may allow
# credentialed cross-origin calls from any site; confirm this is intended
# outside local development and consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")


# --- Request / Response models ---

# Body of POST /auth/google: the Google OAuth ID token issued to the frontend.
class GoogleAuthRequest(BaseModel):
    token: str


# Body of POST /chat.
class ChatRequest(BaseModel):
    message: str
    thread_id: str
    persona: str = "nerd"
    language: str = "auto"
    username: str = ""
    user_id: str = ""
    image: str = ""  # base64 encoded image data
    doc_text: str = ""  # extracted text from attached document
    doc_name: str = ""  # exact name of the document
    doc_bytes: str = ""  # base64 encoded bytes of the document


# Body of POST /rename.
class RenameRequest(BaseModel):
    thread_id: str
    title: str


# Body of POST /user/apikey.
class ApiKeyRequest(BaseModel):
    user_id: str
    key: str


# Body of POST /user/tavilykey.
class TavilyKeyRequest(BaseModel):
    user_id: str
    key: str


# Body of POST /user/profile.
class ProfileRequest(BaseModel):
    user_id: str
    profile: str


# --- Feedback Endpoint ---

# Body of POST /feedback.
class FeedbackRequest(BaseModel):
    user_id: str
    user_email: Optional[str] = ""
    user_name: Optional[str] = ""
    category: str
    overall: int = 5
    ease: int = 4
    quality: int = 4
    message: str
    attachments: Optional[list[dict]] = None  # [{filename, content (b64), content_type}]


@app.post("/feedback")
def submit_feedback(req: FeedbackRequest):
    """Forward user feedback (and optional attachments) via ``feed_module.send_feedback``.

    Declared as a plain ``def`` rather than ``async def``: ``send_feedback`` is
    a synchronous call, so FastAPI then runs this handler in its worker
    threadpool instead of blocking the event loop while it executes.

    Returns ``{"status": "success"}`` when ``send_feedback`` returns a truthy
    value, otherwise ``{"status": "logged"}``.
    """
    ok = feed_module.send_feedback(
        user_id=req.user_id,
        user_email=req.user_email or "",
        user_name=req.user_name or "",
        overall=req.overall,
        ease=req.ease,
        quality=req.quality,
        category=req.category,
        message=req.message,
        attachments=req.attachments,
    )
    return JSONResponse(content={"status": "success" if ok else "logged"})


# --- SSE helpers ---

def sse_token(token: str) -> str:
    """Format one streamed text token as a Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps({'token': token})}\n\n"


def sse_error(message: str) -> str:
    """Format a user-facing error message as an SSE ``data:`` frame."""
    return f"data: {json.dumps({'error': message})}\n\n"


def sse_done() -> str:
    """Return the terminal SSE frame that tells the frontend the stream is over."""
    return "data: [DONE]\n\n"


def sse_tool_event(event: str, tool_name: str) -> str:
    """Emit tool_start / tool_end events so the frontend can show a progress bar."""
    return f"data: {json.dumps({'tool_event': event, 'tool': tool_name})}\n\n"


# --- Auth routes ---

# Fields of the stored user record that hold third-party API keys. The client
# only ever learns whether they are set (has_api_key / has_tavily_key); the
# values themselves are never echoed back in a response.
_SECRET_USER_FIELDS = ("openrouter_key", "tavily_key")


def _public_user(user):
    """Return a shallow copy of *user* without the stored API-key fields.

    Non-dict values (including ``None``) are returned unchanged.
    """
    if not isinstance(user, dict):
        return user
    return {key: value for key, value in user.items() if key not in _SECRET_USER_FIELDS}


@app.post("/auth/google")
def google_login(req: GoogleAuthRequest):
    """Verify Google OAuth ID token and create/update user.

    Responds 401 when ``auth_module.verify_google_token`` rejects the token.
    Otherwise returns the upserted user (without stored API keys) plus
    booleans saying whether an API key and a Tavily key are on file.
    """
    idinfo = auth_module.verify_google_token(req.token)
    if not idinfo:
        return JSONResponse({"error": "Invalid token"}, status_code=401)
    user = db.upsert_user(
        google_id=idinfo["sub"],
        email=idinfo.get("email", ""),
        name=idinfo.get("name", ""),
        picture=idinfo.get("picture", ""),
    )
    has_key = bool(db.get_user_api_key(idinfo["sub"]))
    has_tavily = bool(db.get_tavily_key(idinfo["sub"]))
    return {"user": _public_user(user), "has_api_key": has_key, "has_tavily_key": has_tavily}


@app.get("/auth/me")
def get_me(user_id: str = ""):
    """Get current user data.

    Responds 400 without ``user_id`` and 404 for an unknown user. Stored API
    keys are reported only as the ``has_api_key`` / ``has_tavily_key`` flags,
    because this endpoint is reachable with nothing but a user id.
    """
    if not user_id:
        return JSONResponse({"error": "user_id required"}, status_code=400)
    user = db.get_user(user_id)
    if not user:
        return JSONResponse({"error": "User not found"}, status_code=404)
    return {
        "user": _public_user(user),
        "has_api_key": bool(user.get("openrouter_key")),
        "has_tavily_key": bool(user.get("tavily_key")),
    }


@app.get("/auth/client_id")
def get_client_id():
    """Return the Google Client ID for frontend OAuth init."""
    return {"client_id": GOOGLE_CLIENT_ID}


# --- User settings routes ---
@app.post("/user/apikey")
def save_api_key(req: ApiKeyRequest):
    """Persist the user's model-provider API key via ``db.save_user_api_key``."""
    db.save_user_api_key(req.user_id, req.key)
    return {"ok": True}


@app.post("/user/tavilykey")
def save_tavily_key(req: TavilyKeyRequest):
    """Persist the user's Tavily key; a blank/whitespace key reports as not set."""
    db.save_tavily_key(req.user_id, req.key)
    return {"ok": True, "has_tavily_key": bool(req.key.strip())}


@app.post("/user/profile")
def save_profile(req: ProfileRequest):
    """Persist the free-text student profile later passed to the chat graph."""
    db.save_student_profile(req.user_id, req.profile)
    return {"ok": True}


# Maximum number of extracted document characters returned by /upload-doc and
# embedded into the prompt by /chat.
_MAX_DOC_CHARS = 20_000


def _extract_doc_text(ext: str, raw_bytes: bytes) -> str:
    """Extract plain text from *raw_bytes*, choosing a parser by extension *ext*.

    PDFs go through pdfplumber, Word files through python-docx, and anything
    else is decoded as UTF-8 with undecodable bytes replaced. Parser (and
    parser-import) errors propagate to the caller.
    """
    if ext == "pdf":
        import pdfplumber  # type: ignore

        with pdfplumber.open(io.BytesIO(raw_bytes)) as pdf:
            return "\n".join(page.extract_text() or "" for page in pdf.pages)
    if ext in ("docx", "doc"):
        # NOTE(review): python-docx reads the OOXML (.docx) format; legacy
        # binary .doc uploads are expected to fail here and surface as a 422.
        from docx import Document as DocxDocument

        document = DocxDocument(io.BytesIO(raw_bytes))
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    # Plain text / markdown / code files
    return raw_bytes.decode("utf-8", errors="replace")


@app.post("/upload-doc")
async def upload_doc(file: UploadFile = File(...)):
    """
    Accept a PDF, DOCX, or plain-text file and return its extracted text.
    The frontend sends this text back as `doc_text` in the chat request.

    Responds 400 when the upload cannot be read, 422 when parsing fails or no
    text is found; otherwise returns the filename and at most
    ``_MAX_DOC_CHARS`` characters of stripped text.
    """
    import asyncio

    filename = file.filename or ""
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""

    try:
        raw_bytes = await file.read()
    except Exception as exc:
        print(f"[UPLOAD ERROR] Failed to read uploaded file: {exc}")
        return JSONResponse({"error": f"Failed to read file: {exc}"}, status_code=400)

    try:
        # Parsing is synchronous and can be slow for large PDFs; run it in a
        # worker thread so it does not stall the event loop for other requests.
        text = await asyncio.to_thread(_extract_doc_text, ext, raw_bytes)
    except Exception as exc:
        print(f"[PARSE ERROR] Failed to parse document {filename}: {exc}")
        return JSONResponse({"error": f"Could not parse file content: {exc}"}, status_code=422)

    text = text.strip()
    if not text:
        return JSONResponse({"error": "No readable text found in file."}, status_code=422)

    # Return first 20,000 chars of extracted text
    return {"filename": filename, "text": text[:_MAX_DOC_CHARS]}


# --- Core routes ---

APP_DIR = os.path.dirname(os.path.abspath(__file__))


@app.get("/sw.js")
def serve_sw():
    """Service worker must be served from root for PWA scope."""
    return FileResponse(os.path.join(APP_DIR, "sw.js"), media_type="application/javascript")


@app.get("/manifest.json")
def serve_manifest():
    """Serve the PWA manifest with its dedicated media type."""
    return FileResponse(os.path.join(STATIC_DIR, "manifest.json"), media_type="application/manifest+json")


@app.get("/")
def serve_index():
    """Serve the single-page frontend."""
    return FileResponse(os.path.join(STATIC_DIR, "index.html"))


@app.get("/threads")
def get_threads(user_id: str = ""):
    """List the chat threads stored for *user_id*."""
    return {"threads": db.get_threads(user_id)}


@app.get("/history/{thread_id}")
def get_history(thread_id: str):
    """Return the messages checkpointed by the LangGraph chatbot for *thread_id*.

    HumanMessages get role ``"user"``; every other message type is reported
    as ``"assistant"``.
    """
    # NOTE(review): tool/system messages, if the graph stores any, are also
    # labelled "assistant" here — confirm that is what the UI expects.
    config = {"configurable": {"thread_id": thread_id}}
    state = chatbot.get_state(config)
    messages = state.values.get("messages", [])
    result = [
        {"role": "user" if isinstance(msg, HumanMessage) else "assistant", "content": msg.content}
        for msg in messages
    ]
    return {"messages": result}


# Prefixes the frontend prepends to the message when a search mode is active;
# the user's actual text follows the "User Query: " marker.
_WEB_SENTINEL = "[System Instruction: Web Search is enabled."
_YT_SENTINEL = "[System Instruction: YouTube Video Search is enabled."
_USER_QUERY_MARKER = "User Query: "


def _parse_search_mode(raw_message: str) -> tuple[str, bool, bool]:
    """Split a frontend search-mode sentinel off *raw_message*.

    Returns ``(clean_message, web_search_enabled, yt_search_enabled)``. With
    no sentinel the message is returned unchanged; with a sentinel but no
    "User Query: " marker, the raw message is kept as the query.
    """
    web_enabled = raw_message.startswith(_WEB_SENTINEL)
    yt_enabled = not web_enabled and raw_message.startswith(_YT_SENTINEL)
    if not (web_enabled or yt_enabled):
        return raw_message, False, False
    _, found, query = raw_message.partition(_USER_QUERY_MARKER)
    clean_message = query.strip() if found else raw_message
    return clean_message, web_enabled, yt_enabled


# Ordered (needles, user-facing message) rules for errors raised while
# streaming the LLM reply; the first rule with a needle found in the
# lower-cased exception text wins. Needles are deliberately specific: a bare
# "rate" would also match words like "generate" or "moderate", and a bare
# "invalid" would report any validation error as a bad API key.
_STREAM_ERROR_RULES = (
    (
        ("429", "rate limit", "rate-limit", "rate_limit", "ratelimit", "too many requests"),
        "Rate limited. Please wait a moment and try again.",
    ),
    (("402", "payment", "credits"), "Free credits exhausted on OpenRouter."),
    (("404", "not found", "no endpoints"), "Model unavailable. Please try again."),
    (
        ("401", "unauthorized", "invalid api key", "invalid_api_key", "incorrect api key"),
        "Invalid API key.",
    ),
)


def _stream_error_message(exc: Exception) -> str:
    """Map an exception raised while streaming the LLM reply to a user-facing message."""
    error_text = str(exc).lower()
    for needles, message in _STREAM_ERROR_RULES:
        if any(needle in error_text for needle in needles):
            return message
    return "Something went wrong. Please try again."


def _image_data_url(image: str) -> str:
    """Return a data URL for *image*.

    Bare base64 is wrapped as JPEG (the historical behaviour); a value that is
    already a ``data:`` URL is passed through instead of being double-prefixed.
    """
    if image.startswith("data:"):
        return image
    return f"data:image/jpeg;base64,{image}"


@app.post("/chat")
def chat(request: ChatRequest, req: Request):
    """Stream an assistant reply to *request* as Server-Sent Events.

    The thread is upserted and per-user keys/profile are loaded up front; the
    returned generator then runs the input guard, context retrieval, optional
    web/YouTube tools, and finally streams the LangGraph chatbot's tokens.
    """
    now = time.time()
    db.upsert_thread(request.thread_id, request.message, now, request.user_id)

    # Resolve user-specific data
    user_api_key = ""
    tavily_key = ""
    student_profile = ""
    if request.user_id:
        user_api_key = db.get_user_api_key(request.user_id)
        tavily_key = db.get_tavily_key(request.user_id)
        user_data = db.get_user(request.user_id)
        if user_data:
            student_profile = user_data.get("student_profile", "")

    def stream():
        # Step 1: Guard
        is_ok, rejection = guard.check_input(request.message)
        if not is_ok:
            yield sse_token(rejection)
            yield sse_done()
            return

        # Step 2: Extract search mode sentinel injected by frontend, clean query
        clean_message, web_search_enabled, yt_search_enabled = _parse_search_mode(request.message)

        # Step 3: Retrieve context from Pinecone (resilient to errors)
        try:
            results = retriever.search(clean_message)
            context = retriever.format_context(results)
        except Exception as exc:
            print(f"[PINECONE ERROR] Failed to retrieve context: {exc}")
            context = "No relevant context found due to a temporary search error."

        # Step 3.5: Execute search/transcript tools programmatically
        if web_search_enabled:
            yield sse_tool_event("tool_start", "web_search")
            try:
                web_results = run_web_search(clean_message, api_key=tavily_key)
                print(f"[WEB SEARCH] Fetched {len(web_results)} chars for: {clean_message[:60]}", flush=True)
                context += f"\n\n[Web Search Results]\n{web_results}\n[End Web Search Results]\n"
            except Exception as e:
                print(f"[SEARCH ERROR] {e}", flush=True)
                context += "\n\nWeb search failed to complete.\n"
            yield sse_tool_event("tool_end", "web_search")
        elif yt_search_enabled:
            yield sse_tool_event("tool_start", "yt_transcript")
            try:
                transcript = fetch_yt_transcript(clean_message)
                if transcript.startswith("TRANSCRIPT_UNAVAILABLE"):
                    print(f"[YT TRANSCRIPT] Unavailable: {transcript[:120]}", flush=True)
                    context += f"\n\n[YouTube Transcript Status]\n{transcript}\n"
                else:
                    print(f"[YT TRANSCRIPT] Fetched {len(transcript)} chars", flush=True)
                    context += f"\n\n[YouTube Video Transcript]\n{transcript[:15_000]}\n[End Transcript]\n"
            except Exception as e:
                print(f"[YT TRANSCRIPT ERROR] {e}", flush=True)
                context += "\n\n[YouTube Transcript Status]\nTRANSCRIPT_UNAVAILABLE: extraction failed unexpectedly.\n"
            yield sse_tool_event("tool_end", "yt_transcript")

        # Step 4: Handle document context
        doc_context = ""
        if request.doc_text:
            doc_label = f" (Name: {request.doc_name})" if request.doc_name else ""
            doc_context = (
                f"\n\n[ATTACHED DOCUMENT{doc_label}]\n"
                + request.doc_text[:_MAX_DOC_CHARS]
                + "\n[END DOCUMENT]\n"
            )

        # Step 5: Build message (text or multimodal with image)
        if request.image:
            msg_content = [
                {"type": "text", "text": clean_message + doc_context},
                {
                    "type": "image_url",
                    "image_url": {"url": _image_data_url(request.image)},
                },
            ]
        else:
            msg_content = clean_message + doc_context

        # Step 6: Stream LLM response via LangGraph
        has_image = bool(request.image)

        # Resolve target model dynamically to print/trace and pass to LangGraph config
        try:
            from graph import _classify, _pick

            category = _classify(clean_message)
            target_model, _ = _pick(category, has_image=has_image)
        except Exception as e:
            print(f"[MODEL PICK ERROR] {e}", flush=True)
            target_model = ""

        print(f"[ROUTER CHAT] Routing query to model: {target_model}", flush=True)

        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "persona": request.persona,
                "context": context,
                "language": request.language,
                "username": request.username,
                "student_profile": student_profile,
                "user_api_key": user_api_key,
                "has_image": has_image,
                "search_enabled": (web_search_enabled or yt_search_enabled),
                "model": target_model,
                "doc_name": request.doc_name,
                "doc_bytes": request.doc_bytes,
            }
        }

        try:
            for chunk, _metadata in chatbot.stream(
                {"messages": [HumanMessage(content=msg_content)]},
                config=config,
                stream_mode="messages",
            ):
                if isinstance(chunk, AIMessage) and chunk.content:
                    yield sse_token(chunk.content)
        except Exception as e:
            # Log the underlying failure; the client only sees a friendly message.
            print(f"[CHAT STREAM ERROR] {e}", flush=True)
            yield sse_error(_stream_error_message(e))

        yield sse_done()

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={"X-Accel-Buffering": "no"},
    )


@app.post("/rename")
def rename_thread(request: RenameRequest):
    """Rename a chat thread."""
    db.rename_thread(request.thread_id, request.title)
    return {"ok": True}


@app.delete("/thread/{thread_id}")
def delete_thread(thread_id: str):
    """Delete a chat thread."""
    db.delete_thread(thread_id)
    return {"ok": True}


# --- Health check ---

@app.get("/health")
def health():
    """Probe the database and the Pinecone index.

    Returns 200 with status "ok" when both succeed, otherwise 503 with status
    "degraded"; ``checks`` reports each probe individually.
    """
    checks = {"db": False, "pinecone": False}
    try:
        db.conn.execute("SELECT 1")
        checks["db"] = True
    except Exception as exc:
        print(f"[HEALTH] DB check failed: {exc}", flush=True)
    try:
        retriever._index.describe_index_stats()
        checks["pinecone"] = True
    except Exception as exc:
        print(f"[HEALTH] Pinecone check failed: {exc}", flush=True)
    ok = all(checks.values())
    return JSONResponse(
        {"status": "ok" if ok else "degraded", "checks": checks},
        status_code=200 if ok else 503,
    )


# --- Static file serving ---

if os.path.isdir(ASSETS_DIR):
    app.mount("/assets", StaticFiles(directory=ASSETS_DIR), name="assets")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Normalised root used to confine the catch-all route below to STATIC_DIR.
_STATIC_ROOT = os.path.abspath(STATIC_DIR)


@app.get("/{filepath:path}")
def serve_static(filepath: str):
    """Serve a file from STATIC_DIR, or 404 if it is missing or outside it.

    ``filepath`` comes straight from the URL, so it may contain ``../``
    segments or be absolute (e.g. a request for ``//etc/passwd``), and
    ``os.path.join`` discards the base on an absolute component. The joined
    path is therefore normalised and rejected unless it stays under
    STATIC_DIR.
    """
    full_path = os.path.abspath(os.path.join(_STATIC_ROOT, filepath))
    inside_root = full_path.startswith(_STATIC_ROOT + os.sep)
    if inside_root and os.path.isfile(full_path):
        return FileResponse(full_path)
    return JSONResponse({"error": "not found"}, status_code=404)