# geneseek / main.py
# Last commit e8ade4e by prabhal: "added updates geneseek"
import shutil
import os
from pydantic import BaseModel
from typing import List, Dict
from src.ingest import ingest_file
from src.chain import get_rag_chain
from src import config
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, HTTPException
# Shared RAG chain. Set by ``lifespan`` at startup, rebuilt by the /upload
# handler after a successful ingest, and left as None when it could not be built.
rag_chain = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Try to build the RAG chain once, before the app starts serving requests.

    Any exception from ``get_rag_chain`` is printed and swallowed so the server
    still starts (for example before any document has been ingested). In that
    case ``rag_chain`` stays None and /chat answers 503 until an upload
    rebuilds it. Nothing is torn down after the ``yield`` (shutdown).
    """
    global rag_chain
    print("Initializing RAG chain at startup...")
    try:
        chain = get_rag_chain()
        print("RAG chain ready.")
    except Exception as exc:
        # Every failure is treated as "not ready yet". The message names the
        # expected cause, but the actual exception text is appended to it.
        print(f"RAG chain not initialized (no collection yet): {exc}")
        chain = None
    rag_chain = chain
    yield
# The ASGI application. ``lifespan`` attempts to build the shared RAG chain once at startup.
app = FastAPI(lifespan=lifespan, title="GeneSeek V2 API")
@app.get("/")
async def root():
    """Landing endpoint: return a static greeting payload."""
    greeting = {"message": "GeneSeek"}
    return greeting
@app.get("/health")
async def health_check():
    """Liveness probe: always report ``ok`` without touching the RAG chain."""
    return dict(status="ok", service="GeneSeek V2 API")
class ChatRequest(BaseModel):
    """JSON request body for ``POST /chat``."""

    # Passed unchanged as the only argument to the RAG chain call.
    question: str
class ChatResponse(BaseModel):
    """Response model for ``POST /chat``, filled from the RAG chain's result dict."""

    # Copied from the chain result's "answer" entry.
    answer: str
    # Copied from the chain result's "contexts" entry. The per-item dict schema
    # is defined by the chain, not here. TODO confirm its keys.
    contexts: List[Dict]
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Save an uploaded ``.txt``/``.pdf`` into ``config.RAW_DATA_DIR``, ingest it, and refresh the chain.

    Returns:
        ``{"message": ..., "status": "success"}`` after ingesting the file and
        rebuilding the module-level ``rag_chain``.
        ``{"message": ..., "status": "skipped"}`` when ``ingest_file`` returns
        ``False`` (already ingested).

    Raises:
        HTTPException: 400 when the filename is missing or its extension is not
            in ``allowed``; 500 when saving or ingesting the file raises.
    """
    global rag_chain
    allowed = {".txt", ".pdf"}

    # ``file.filename`` is client-controlled and may be None. Keep only its
    # final path component, and treat "\" as a separator as well because some
    # clients send full Windows paths. This stops names such as "../../x.txt"
    # from writing outside RAW_DATA_DIR. An empty result has no extension, so
    # it fails the check below with a 400.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    ext = os.path.splitext(safe_name)[1].lower()
    if ext not in allowed:
        raise HTTPException(400, f"Invalid file. Allowed: {allowed}")

    file_path = config.RAW_DATA_DIR / safe_name
    # NOTE(review): the save, ingest_file() and get_rag_chain() below are called
    # synchronously inside this async handler, so they block the event loop
    # while they run. Offloading them (e.g. fastapi.concurrency.run_in_threadpool)
    # would also allow uploads to overlap. Confirm ingest_file is safe for that first.
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        raise HTTPException(500, f"Save failed: {e}") from e

    try:
        result = ingest_file(str(file_path))
    except Exception as e:
        raise HTTPException(500, f"Ingestion failed: {e}") from e

    # ingest_file signals "already ingested" with a literal False.
    if result is False:
        return {"message": "File already ingested. Skipping.", "status": "skipped"}

    # Rebuild the shared chain after a new ingest, presumably so /chat picks up
    # the newly indexed content.
    rag_chain = get_rag_chain()
    return {"message": f"Successfully indexed {safe_name}", "status": "success"}
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Answer ``request.question`` with the current RAG chain.

    The chain is called with the question string. Its result must be indexable
    by ``"answer"`` and ``"contexts"``.

    Raises:
        HTTPException: 503 when no chain exists yet (nothing ingested, or
            startup initialisation failed). 500 when the chain call raises or
            its result cannot be turned into a ``ChatResponse``; the
            exception's text is returned as the detail.
    """
    # This function only reads the module global, so no ``global`` statement
    # is needed. A single local snapshot is used for both the check and the call.
    chain = rag_chain
    if chain is None:
        raise HTTPException(503, "No documents ingested yet. Please upload a file first.")
    try:
        result = chain(request.question)
        return ChatResponse(answer=result["answer"], contexts=result["contexts"])
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients.
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(500, detail=str(e)) from e
if __name__ == "__main__":
    # uvicorn is only needed when this file is run directly, so import it here.
    # The "main:app" import string makes uvicorn import this file (main.py)
    # under the module name ``main`` and serve its ``app``.
    from uvicorn import run

    run("main:app", reload=False, port=8000, host="127.0.0.1")