"""GeneSeek V2 HTTP API: document upload/ingestion and RAG-backed chat."""

import logging
import os
import shutil
from contextlib import asynccontextmanager
from pathlib import Path
from typing import BinaryIO, Dict, List, Optional

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from src import config
from src.chain import get_rag_chain
from src.ingest import ingest_file

logger = logging.getLogger(__name__)

# File extensions accepted by the /upload endpoint.
ALLOWED_EXTENSIONS = frozenset({".txt", ".pdf"})

# Callable RAG chain shared by all requests; None until get_rag_chain() succeeds
# (at startup or after the first successful upload).
rag_chain = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Try to build the RAG chain once at application startup.

    If ``get_rag_chain`` raises (the original log message attributes this to
    no collection existing yet), the app still starts and ``rag_chain`` stays
    ``None`` until an upload rebuilds it.
    """
    global rag_chain
    logger.info("Initializing RAG chain at startup...")
    try:
        rag_chain = get_rag_chain()
    except Exception as e:
        # Deliberately best-effort: /chat answers 503 while rag_chain is None.
        logger.warning("RAG chain not initialized (no collection yet): %s", e)
        rag_chain = None
    else:
        logger.info("RAG chain ready.")
    yield


app = FastAPI(title="GeneSeek V2 API", lifespan=lifespan)


@app.get("/")
async def root():
    """Return a minimal greeting payload."""
    return {"message": "GeneSeek"}


@app.get("/health")
async def health_check():
    """Liveness probe; does not check whether the RAG chain is initialized."""
    return {"status": "ok", "service": "GeneSeek V2 API"}


class ChatRequest(BaseModel):
    """Body of POST /chat."""

    question: str


class ChatResponse(BaseModel):
    """Response of POST /chat: the answer plus the contexts the chain returned."""

    answer: str
    contexts: List[Dict]


def _safe_filename(raw_name: Optional[str]) -> str:
    """Reduce a client-supplied filename to its final path component.

    Both ``/`` and ``\\`` are treated as separators, so names such as
    ``../../x.txt`` or ``..\\x.txt`` cannot escape ``RAW_DATA_DIR``.
    Returns ``""`` when the client sent no name.
    """
    return os.path.basename((raw_name or "").replace("\\", "/"))


def _save_upload(src: BinaryIO, dest: Path) -> None:
    """Stream *src* into *dest*, creating the destination directory if needed."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    with open(dest, "wb") as buffer:
        shutil.copyfileobj(src, buffer)


@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Save an uploaded .txt/.pdf into RAW_DATA_DIR, ingest it, and reload the chain.

    Returns ``status: "skipped"`` when ``ingest_file`` returns ``False``
    (reported as "already ingested"), otherwise ``status: "success"``.

    Raises:
        HTTPException(400): missing name or extension not in ALLOWED_EXTENSIONS.
        HTTPException(500): saving, ingesting, or rebuilding the chain failed.
    """
    global rag_chain

    # Never trust the client's path: keep only the base name.
    filename = _safe_filename(file.filename)
    ext = os.path.splitext(filename)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(400, f"Invalid file. Allowed: {allowed}")

    file_path = Path(config.RAW_DATA_DIR) / filename

    # Disk I/O, ingestion and chain construction are synchronous; run them in
    # the threadpool so they do not block the event loop for other requests.
    try:
        await run_in_threadpool(_save_upload, file.file, file_path)
    except Exception as e:
        raise HTTPException(500, f"Save failed: {e}") from e

    try:
        result = await run_in_threadpool(ingest_file, str(file_path))
    except Exception as e:
        raise HTTPException(500, f"Ingestion failed: {e}") from e

    if result is False:
        return {"message": "File already ingested. Skipping.", "status": "skipped"}

    try:
        rag_chain = await run_in_threadpool(get_rag_chain)
    except Exception as e:
        raise HTTPException(500, f"Reloading RAG chain failed: {e}") from e

    return {"message": f"Successfully indexed {filename}", "status": "success"}


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer a question with the current RAG chain.

    Raises:
        HTTPException(503): no chain has been built yet.
        HTTPException(500): the chain call failed or its result lacked
            the ``answer``/``contexts`` keys.
    """
    # Snapshot the global so a concurrent upload swapping the chain cannot
    # change it between the None check and the call.
    chain = rag_chain
    if chain is None:
        raise HTTPException(503, "No documents ingested yet. Please upload a file first.")
    try:
        # The chain is a blocking callable; keep it off the event loop.
        result = await run_in_threadpool(chain, request.question)
        return ChatResponse(answer=result["answer"], contexts=result["contexts"])
    except Exception as e:
        raise HTTPException(500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=False)