"""FastAPI service for chatting with an uploaded PDF.

``POST /upload`` indexes a PDF (dense vectors + BM25) under a fresh session id.
``POST /chat`` retrieves context for a question from that session's index and
streams a model answer back as ``data: ...`` server-sent-event frames.
"""
import logging
import os
import re
import tempfile
import time
import uuid

from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from rank_bm25 import BM25Okapi

from chunker import Chunker
from embedder import Embedder
from loader import Loader
from retriever import Retriever
from vector import VectorStorage

app = FastAPI()

load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read leniently so a missing token surfaces as the in-stream
# "HF_TOKEN not configured" message instead of a KeyError at import time
# (which previously made that message unreachable).
token = os.environ.get("HF_TOKEN")
if not token:
    logger.warning("HF_TOKEN is not set; /chat will report a configuration error.")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Models are tried in this order; the first one that streams wins.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "HuggingFaceH4/zephyr-7b-beta",
]

# Session lifetime in seconds, measured from the session's "created_at".
SESSION_TIMEOUT = 3600

# session_id -> {"store", "embedder", "bm25", "created_at"}
sessions: dict = {}

SYSTEM_PROMPT = (
    "You are a helpful study assistant. Answer the user's question based ONLY on the provided context.\n\n"
    "FORMATTING RULES (STRICT):\n"
    "You MUST format your entire response using valid Markdown.\n"
    "1. Use `##` for main section headings.\n"
    "2. Use `**bold text**` for subheadings.\n"
    "3. Use `- ` (a hyphen followed by a space) for bullet points.\n"
    "4. CRITICAL: You MUST leave a completely blank line (two newline characters) before every heading and every bullet point.\n"
    "5. Do not write long paragraphs. Keep points concise.\n\n"
    "EXAMPLE OF EXACT OUTPUT FORMAT:\n"
    "## 1. The Treaty of Vienna\n\n"
    "- Signed in 1815\n"
    "- Aimed to restore order\n\n"
    "## 2. Key Provisions\n\n"
    "- Created the German Confederation\n"
)

embedder = Embedder()
print("Embedding Model ready")


def cleanup_expired_sessions():
    """Remove sessions older than SESSION_TIMEOUT.

    A session without a ``created_at`` entry is treated as brand new and kept.
    """
    now = time.time()
    # Snapshot the items: /chat is a sync endpoint, so this can run on several
    # worker threads at once and the dict may change while we scan it.
    expired = [
        sid
        for sid, data in list(sessions.items())
        if now - data.get("created_at", now) > SESSION_TIMEOUT
    ]
    for sid in expired:
        # pop() rather than del: another thread may already have removed it.
        if sessions.pop(sid, None) is not None:
            logger.info("Cleaned up expired session: %s", sid)


def _index_pdf(pdf_path):
    """Extract, chunk, embed and index the PDF stored at ``pdf_path``.

    Returns:
        A ``(store, bm25)`` tuple: the populated ``VectorStorage`` and the
        ``BM25Okapi`` index built over the same chunks.

    Raises:
        HTTPException: 400 when the chunker produces no chunks, since there is
            nothing to index (``vectors[0]`` would otherwise raise IndexError).
    """
    text = Loader(pdf_path).load()
    print("Text extracted")

    chunks = Chunker().chunker(text)
    print("Chunked the text")
    if not chunks:
        raise HTTPException(
            status_code=400,
            detail="No extractable text found in the PDF.",
        )

    tokenized_chunks = [re.findall(r"\w+", chunk.lower()) for chunk in chunks]
    print("Tokenized for bm25")

    print("Embedding started....")
    vectors = embedder.embed(chunks)
    print("Finally embedded the text")

    # The store dimension is taken from the first embedding vector.
    store = VectorStorage(dimension=len(vectors[0]))
    store.add(vectors, chunks)
    print("Embedding stored")

    bm25 = BM25Okapi(tokenized_chunks)
    print("bm25 intialized")
    return store, bm25


@app.post("/upload")
async def upload_pdf(file: UploadFile = File(...)):
    """Index an uploaded PDF and open a new chat session for it.

    Returns:
        ``{"session_id": ..., "message": ...}`` where ``session_id`` is the key
        to pass to ``/chat``.

    Raises:
        HTTPException: 400 if the upload has no ``.pdf`` filename (checked
            case-insensitively) or yields no text chunks.
    """
    cleanup_expired_sessions()

    # filename can be None for some multipart uploads; treat that as "not a PDF".
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    # Read before creating the temp file so a failed read leaves nothing behind.
    payload = await file.read()

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp_path = tmp.name
            tmp.write(payload)
        print("Temproary path defined")

        # Extraction and embedding are blocking; run them off the event loop so
        # other requests (including in-flight chat streams) keep being served.
        store, bm25 = await run_in_threadpool(_index_pdf, tmp_path)
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)

    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "store": store,
        "embedder": embedder,
        "bm25": bm25,
        "created_at": time.time(),
    }
    return {"session_id": session_id, "message": "PDF indexed. Ready to chat!"}


class ChatRequest(BaseModel):
    """Request body for ``POST /chat``."""

    session_id: str  # id returned by /upload
    message: str  # the user's question
    history: list  # earlier chat messages, inserted after the system prompt


def _sse(payload):
    """Wrap ``payload`` in a single ``data:`` server-sent-event frame."""
    # NOTE(review): a payload containing a newline is emitted verbatim, which
    # splits the frame. Left unchanged because the client-side parser is not
    # visible from this file -- confirm before changing the wire format.
    return f"data: {payload}\n\n"


def _stream_completion(messages):
    """Yield SSE frames for a chat completion, falling back across MODELS.

    Each model is tried in order. A model that fails before producing any text
    is skipped; one that fails after producing text ends the stream, because
    restarting on another model would duplicate output. The stream always ends
    with a ``[DONE]`` frame.
    """
    if not token:
        yield _sse("Error: HF_TOKEN not configured")
        yield _sse("[DONE]")
        return

    for model in MODELS:
        streamed_any = False
        try:
            client = InferenceClient(model, token=token)
            logger.info("Streaming with: %s", model)
            for chunk in client.chat_completion(messages, max_tokens=512, stream=True):
                # Skip chunks that carry no choices instead of letting an
                # IndexError be mistaken for a model failure.
                if not chunk.choices:
                    continue
                text = chunk.choices[0].delta.content
                if text:
                    streamed_any = True
                    yield _sse(text)
            yield _sse("[DONE]")
            return
        except Exception as exc:  # boundary: any provider error triggers fallback
            if streamed_any:
                logger.warning(
                    "Streaming from %s interrupted after partial output: %s",
                    model,
                    exc,
                )
                yield _sse("[DONE]")
                return
            logger.warning("Streaming failed for %s: %s", model, exc)

    yield _sse("Sorry, all models are currently unavailable. Try again later.")
    yield _sse("[DONE]")


@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a question about the session's PDF as an event stream.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``, or a plain
        ``{"response": ...}`` dict when retrieval finds no context.

    Raises:
        HTTPException: 404 if the session id is unknown or has expired.
    """
    cleanup_expired_sessions()

    session = sessions.get(req.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found.")

    print("Unpacking session id....")
    store = session["store"]
    session_embedder = session["embedder"]
    bm25 = session["bm25"]
    print("Session id unpacked")

    retriever = Retriever(store, session_embedder, bm25)
    print("Retriever intialized")

    context_chunks = retriever.retrieve(req.message)
    print("Chunks fetched")
    if not context_chunks:
        return {"response": "I couldn't find relevant information in the document."}

    context_text = "\n\n".join(context_chunks)

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(req.history)
    messages.append(
        {
            "role": "user",
            "content": f"Context:\n{context_text}\n\nQuestion: {req.message}",
        }
    )

    print("Started streaming...")
    return StreamingResponse(_stream_completion(messages), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)