Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel | |
| from huggingface_hub import InferenceClient | |
| import tempfile, os, uuid, logging | |
| from rank_bm25 import BM25Okapi | |
| import re | |
| import time | |
| from dotenv import load_dotenv | |
| from loader import Loader | |
| from chunker import Chunker | |
| from embedder import Embedder | |
| from vector import VectorStorage | |
| from retriever import Retriever | |
app = FastAPI()

# Pull variables from a local .env file (if present) into os.environ before
# any configuration is read below.
load_dotenv()

# Read the token leniently: indexing PDFs does not need it, and the chat stream
# has a graceful "HF_TOKEN not configured" reply for the missing case. A hard
# os.environ["HF_TOKEN"] lookup would crash the whole app at import instead.
token = os.environ.get("HF_TOKEN")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wide-open CORS; restrict to the frontend origin(s) if known.
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face model ids, tried in this order by the chat stream's fallback loop.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "HuggingFaceH4/zephyr-7b-beta",
]

# Seconds after a session's "created_at" at which it becomes eligible for cleanup.
SESSION_TIMEOUT = 3600

# In-memory session registry:
# session_id -> {"store", "embedder", "bm25", "created_at"}.
sessions: dict = {}

# Shared embedding model, constructed once at import and reused by every session.
embedder = Embedder()
logger.info("Embedding Model ready")
def cleanup_expired_sessions():
    """Remove sessions older than SESSION_TIMEOUT seconds from ``sessions``.

    A session's age is measured from its ``created_at`` timestamp. Sessions
    without ``created_at`` are treated as created "now" and never expire.

    Safe to run while another request is also cleaning up or inserting. The
    items are snapshotted before scanning, and each expired id is removed with
    ``pop(..., None)``. As a result, two overlapping cleanups that pick the
    same expired session cannot fail with KeyError.
    """
    current_time = time.time()
    # list(...) snapshots the entries so a concurrent insert cannot raise
    # "dictionary changed size during iteration" mid-scan.
    expired = [
        sid
        for sid, data in list(sessions.items())
        if current_time - data.get("created_at", current_time) > SESSION_TIMEOUT
    ]
    for sid in expired:
        # Another request may already have removed this session.
        if sessions.pop(sid, None) is not None:
            logger.info("Cleaned up expired session: %s", sid)
async def upload_pdf(file: UploadFile = File(...)):
    """Index an uploaded PDF and open a new chat session for it.

    Expired sessions are purged first. The upload then goes through these steps:

    - written to a temporary file and loaded with ``Loader``;
    - split by ``Chunker``;
    - indexed twice: lower-cased word tokens feed a ``BM25Okapi`` index, and
      embeddings from the shared module-level ``embedder`` go into a
      ``VectorStorage``.

    Both indexes are stored in ``sessions`` under a new UUID4 session id.

    NOTE(review): no route decorator is visible on this coroutine; presumably it
    is registered as the upload endpoint elsewhere, or the decorator was lost —
    confirm.

    Args:
        file: The uploaded file. Its name must end in ".pdf" (case-insensitive).

    Returns:
        dict: ``{"session_id": <str>, "message": "PDF indexed. Ready to chat!"}``.

    Raises:
        HTTPException: 400 if the file is not a PDF, or if no text chunks could
            be extracted from it.
    """
    cleanup_expired_sessions()
    # ``filename`` can be None, and uploads like "Notes.PDF" are common; handle
    # both instead of crashing with AttributeError or rejecting a valid PDF.
    if not (file.filename or "").lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    contents = await file.read()
    tmp_path = None
    try:
        # delete=False: the file must outlive this ``with`` so Loader can open it
        # by path. It is removed in the ``finally`` below, even when writing fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp_path = tmp.name
            tmp.write(contents)
        logger.info("Temporary path defined")

        text = Loader(tmp_path).load()
        logger.info("Text extracted")
        chunks = Chunker().chunker(text)
        logger.info("Chunked the text")
        # Nothing to index (presumably an empty or image-only PDF). Fail with a
        # clear 400 instead of an IndexError on ``vectors[0]`` below; rank_bm25
        # also cannot build an index over an empty corpus.
        if not chunks:
            raise HTTPException(
                status_code=400, detail="No extractable text found in the PDF."
            )

        # Lexical index: lower-cased word tokens per chunk.
        tokenized_chunks = [re.findall(r"\w+", chunk.lower()) for chunk in chunks]
        bm25 = BM25Okapi(tokenized_chunks)
        logger.info("bm25 initialized")

        # Dense index: one embedding per chunk; the store's dimension is taken
        # from the first vector.
        # NOTE(review): loading/embedding run synchronously inside this
        # ``async def`` and so block the event loop for large PDFs; consider
        # ``fastapi.concurrency.run_in_threadpool``.
        logger.info("Embedding started....")
        vectors = embedder.embed(chunks)
        store = VectorStorage(dimension=len(vectors[0]))
        store.add(vectors, chunks)
        logger.info("Embedding stored")
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)

    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "store": store,
        "embedder": embedder,
        "bm25": bm25,
        "created_at": time.time(),
    }
    return {"session_id": session_id, "message": "PDF indexed. Ready to chat!"}
class ChatRequest(BaseModel):
    """Request body for ``chat``: which indexed-PDF session to query, the new
    question, and the conversation so far."""

    # Session id previously returned by ``upload_pdf``.
    session_id: str
    # The user's new question; also used as the retrieval query.
    message: str
    # Prior conversation turns, used to build the chat-model prompt.
    # NOTE(review): untyped list; presumably items look like
    # {"role": "user"|"assistant", "content": str} — confirm with the frontend.
    history: list
_SYSTEM_PROMPT = (
    "You are a helpful study assistant. Answer the user's question based ONLY on the provided context.\n\n"
    "FORMATTING RULES (STRICT):\n"
    "You MUST format your entire response using valid Markdown.\n"
    "1. Use `##` for main section headings.\n"
    "2. Use `**bold text**` for subheadings.\n"
    "3. Use `- ` (a hyphen followed by a space) for bullet points.\n"
    "4. CRITICAL: You MUST leave a completely blank line (two newline characters) before every heading and every bullet point.\n"
    "5. Do not write long paragraphs. Keep points concise.\n\n"
    "EXAMPLE OF EXACT OUTPUT FORMAT:\n"
    "## 1. The Treaty of Vienna\n\n"
    "- Signed in 1815\n"
    "- Aimed to restore order\n\n"
    "## 2. Key Provisions\n\n"
    "- Created the German Confederation\n"
)

# Roles a client may supply in ``history``; "system" is reserved for the server prompt.
_ALLOWED_HISTORY_ROLES = ("user", "assistant")

# Line terminators recognised by the Server-Sent Events format (CRLF, LF, CR).
_SSE_LINE_BREAK = re.compile(r"\r\n|\r|\n")


def _sanitize_history(history):
    """Keep only well-formed {"role": "user"|"assistant", "content": str} turns.

    ``history`` comes straight from the request body. Any other entry (an
    injected "system" turn, a non-dict, non-string content) is dropped rather
    than forwarded to the model.
    """
    return [
        {"role": turn["role"], "content": turn["content"]}
        for turn in history
        if isinstance(turn, dict)
        and turn.get("role") in _ALLOWED_HISTORY_ROLES
        and isinstance(turn.get("content"), str)
    ]


def _sse_event(text):
    """Frame ``text`` as a single Server-Sent Event.

    Every line of ``text`` gets its own ``data:`` field, so embedded newlines
    survive the trip. A spec-compliant SSE client re-joins the data lines with
    a newline. Single-line text is framed exactly as ``"data: <text>\\n\\n"``.
    """
    return "".join(f"data: {line}\n" for line in _SSE_LINE_BREAK.split(text)) + "\n"


def chat(req: ChatRequest):
    """Answer ``req.message`` from the session's PDF as a Server-Sent Event stream.

    Retrieves context chunks for the question with the session's
    ``Retriever`` and builds the prompt (system prompt, sanitized history, then
    context + question). The answer is streamed from the first model in
    ``MODELS`` that produces output. The stream always ends with a
    ``data: [DONE]`` event.

    NOTE(review): no route decorator is visible on this function; presumably it
    is registered as the chat endpoint elsewhere, or the decorator was lost —
    confirm.

    Args:
        req: Session id, new message and prior conversation history.

    Returns:
        A ``text/event-stream`` StreamingResponse. When retrieval finds nothing,
        a plain ``{"response": ...}`` dict is returned instead.

    Raises:
        HTTPException: 404 if ``req.session_id`` is unknown or has expired.
    """
    cleanup_expired_sessions()
    session = sessions.get(req.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found.")

    retriever = Retriever(session["store"], session["embedder"], session["bm25"])
    context_chunks = retriever.retrieve(req.message)
    logger.info("Chunks fetched")
    if not context_chunks:
        # NOTE(review): this path answers with plain JSON while the normal path
        # streams SSE; clients must handle both shapes.
        return {"response": "I couldn't find relevant information in the document."}

    context_text = "\n\n".join(context_chunks)
    messages = [{"role": "system", "content": _SYSTEM_PROMPT}]
    messages.extend(_sanitize_history(req.history))
    messages.append(
        {"role": "user", "content": f"Context:\n{context_text}\n\nQuestion: {req.message}"}
    )

    def token_stream():
        # Checked once up front: without a token every model fails the same way.
        if not token:
            yield _sse_event("Error: HF_TOKEN not configured")
            yield _sse_event("[DONE]")
            return
        for model in MODELS:
            streamed_any = False
            try:
                client = InferenceClient(model, token=token)
                logger.info("Streaming with: %s", model)
                for chunk in client.chat_completion(messages, max_tokens=512, stream=True):
                    # Skip chunks that carry no choices instead of raising IndexError.
                    if not chunk.choices:
                        continue
                    text = chunk.choices[0].delta.content
                    if text:
                        streamed_any = True
                        yield _sse_event(text)
                yield _sse_event("[DONE]")
                return
            except Exception as e:
                if streamed_any:
                    # Part of an answer already reached the client; falling back to
                    # another model would splice two answers together, so stop here.
                    logger.warning("Stream from %s interrupted after partial output: %s", model, e)
                    yield _sse_event("[DONE]")
                    return
                logger.warning("Streaming failed for %s: %s", model, e)
        yield _sse_event("Sorry, all models are currently unavailable. Try again later.")
        yield _sse_event("[DONE]")

    logger.info("Started streaming...")
    return StreamingResponse(token_stream(), media_type="text/event-stream")
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn.
    import uvicorn

    # 0.0.0.0 binds every network interface, not just localhost.
    listen_host, listen_port = "0.0.0.0", 7860
    uvicorn.run(app, host=listen_host, port=listen_port)