documind / app.py
Aaravkumar's picture
Update app.py
31b44ff verified
Raw
History Blame Contribute Delete
6.15 kB
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import tempfile, os, uuid, logging
from rank_bm25 import BM25Okapi
import re
import time
from dotenv import load_dotenv
from loader import Loader
from chunker import Chunker
from embedder import Embedder
from vector import VectorStorage
from retriever import Retriever
# ---------------------------------------------------------------------------
# Application bootstrap (runs once, at import time)
# ---------------------------------------------------------------------------
app = FastAPI()

# Load variables from a .env file before HF_TOKEN is read below.
load_dotenv()
# Subscript access: raises KeyError at import time when HF_TOKEN is unset.
token = os.environ["HF_TOKEN"]

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CORS is fully open: every origin, method and header is allowed.
app.add_middleware(
    CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
)

# Model ids tried by the /chat handler, in list order.
MODELS = [
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "HuggingFaceH4/zephyr-7b-beta",
]

# Maximum session age; compared against differences of time.time() values.
SESSION_TIMEOUT = 3600

# In-memory session registry: session_id -> dict holding the keys
# "store", "embedder", "bm25" and "created_at" (written by /upload).
sessions: dict = {}

# One embedder instance is created here and reused by every upload.
embedder = Embedder()
print("Embedding Model ready")
def cleanup_expired_sessions():
    """Remove sessions older than SESSION_TIMEOUT.

    A session counts as expired when more than ``SESSION_TIMEOUT`` has
    elapsed since its ``created_at`` timestamp. Entries that have no
    ``created_at`` key are treated as just created and are kept.

    The function mutates the module-level ``sessions`` dict in place and
    returns ``None``.
    """
    current_time = time.time()
    # Iterate over a snapshot: this function is invoked from request
    # handlers that may run concurrently, and walking the live dict while
    # another caller adds or removes a session raises
    # "dictionary changed size during iteration".
    expired = [
        sid
        for sid, data in list(sessions.items())
        if current_time - data.get("created_at", current_time) > SESSION_TIMEOUT
    ]
    for sid in expired:
        # pop() with a default tolerates the id having been removed by a
        # concurrent cleanup; only log when this call did the removal.
        if sessions.pop(sid, None) is not None:
            logger.info("Cleaned up expired session: %s", sid)
@app.post("/upload")
async def upload_pdf(file: UploadFile = File(...)):
    """Index an uploaded PDF and register a new chat session for it.

    The upload is written to a temporary file, loaded with ``Loader``,
    split with ``Chunker``, embedded with the shared ``embedder`` into a
    new ``VectorStorage``, and tokenized for a ``BM25Okapi`` index. The
    indexes are stored in ``sessions`` under a freshly generated UUID.

    Args:
        file: The uploaded file. Its name must end in ``.pdf``
            (case-insensitive).

    Returns:
        A dict with the new ``session_id`` and a confirmation ``message``.

    Raises:
        HTTPException: 400 when the upload has no ``.pdf`` file name, or
            when no text chunks could be produced from the document.
    """
    cleanup_expired_sessions()

    # The filename can be absent on an upload; treat that as "not a PDF"
    # instead of failing with AttributeError. Compare case-insensitively so
    # "REPORT.PDF" is accepted as well.
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    # Read the body before creating the temp file so a failed read cannot
    # leave a file behind.
    payload = await file.read()

    tmp_path = None
    try:
        # delete=False: the file is closed when the `with` exits and is then
        # handed to Loader by path, so it must outlive the handle. The
        # `finally` below removes it on every exit path, including a failed
        # write.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
            tmp_path = tmp.name
            tmp.write(payload)
        print("Temproary path defined")

        text = Loader(tmp_path).load()
        print("Text extracted")
        chunks = Chunker().chunker(text)
        print("Chunked the text")
        if not chunks:
            # Without chunks, `vectors[0]` below raises IndexError and
            # BM25Okapi cannot be built from an empty corpus; report a
            # client error instead of a 500.
            raise HTTPException(
                status_code=400,
                detail="No extractable text found in the PDF.",
            )

        tokenized_chunks = [re.findall(r"\w+", chunk.lower()) for chunk in chunks]
        print("Tokenized for bm25")
        print("Embedding started....")
        vectors = embedder.embed(chunks)
        print("Finally embedded the text")
        # The store dimension is taken from the first embedding vector.
        store = VectorStorage(dimension=len(vectors[0]))
        store.add(vectors, chunks)
        print("Embedding stored")
        bm25 = BM25Okapi(tokenized_chunks)
        print("bm25 intialized")
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)

    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "store": store,
        "embedder": embedder,
        "bm25": bm25,
        "created_at": time.time(),  # read by cleanup_expired_sessions()
    }
    return {"session_id": session_id, "message": "PDF indexed. Ready to chat!"}
class ChatRequest(BaseModel):
    # JSON body accepted by POST /chat. Documented with comments rather than
    # a docstring so the class definition itself stays exactly as it was.

    # Key into the module-level `sessions` dict, as returned by /upload.
    session_id: str
    # The question for this turn; used both for retrieval and in the prompt.
    message: str
    # Earlier conversation messages; the /chat handler appends them to the
    # prompt as-is, after the system message.
    history: list
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a question about an indexed PDF, streaming the model's reply.

    Looks up the session, retrieves context chunks for ``req.message`` with
    ``Retriever``, builds a chat prompt (system prompt, ``req.history``,
    then the context plus question) and streams the completion from the
    first entry of ``MODELS`` that produces output.

    Args:
        req: Session id, current user message and prior history.

    Returns:
        A ``StreamingResponse`` of ``text/event-stream`` whose events are
        ``data: <text>`` fragments terminated by ``data: [DONE]``. When the
        retriever returns nothing, a plain dict ``{"response": ...}`` is
        returned instead of a stream.

    Raises:
        HTTPException: 404 when ``req.session_id`` is unknown or expired.
    """
    cleanup_expired_sessions()
    session = sessions.get(req.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found.")

    print("Unpacking session id....")
    store = session["store"]
    embedder = session["embedder"]
    bm25 = session["bm25"]
    print("Session id unpacked")

    retriever = Retriever(store, embedder, bm25)
    print("Retriever intialized")
    context_chunks = retriever.retrieve(req.message)
    print("Chunks fetched")
    if not context_chunks:
        return {"response": "I couldn't find relevant information in the document."}

    context_text = "\n\n".join(context_chunks)
    system_prompt = (
        "You are a helpful study assistant. Answer the user's question based ONLY on the provided context.\n\n"
        "FORMATTING RULES (STRICT):\n"
        "You MUST format your entire response using valid Markdown.\n"
        "1. Use `##` for main section headings.\n"
        "2. Use `**bold text**` for subheadings.\n"
        "3. Use `- ` (a hyphen followed by a space) for bullet points.\n"
        "4. CRITICAL: You MUST leave a completely blank line (two newline characters) before every heading and every bullet point.\n"
        "5. Do not write long paragraphs. Keep points concise.\n\n"
        "EXAMPLE OF EXACT OUTPUT FORMAT:\n"
        "## 1. The Treaty of Vienna\n\n"
        "- Signed in 1815\n"
        "- Aimed to restore order\n\n"
        "## 2. Key Provisions\n\n"
        "- Created the German Confederation\n"
    )

    messages = [{"role": "system", "content": system_prompt}]
    # NOTE(review): history comes straight from the request body and is
    # forwarded unvalidated, so a client can supply arbitrary roles
    # (including "system"). Confirm whether that is acceptable.
    messages.extend(req.history)
    messages.append(
        {"role": "user", "content": f"Context:\n{context_text}\n\nQuestion: {req.message}"}
    )
    print("Started streaming...")

    def token_stream():
        # The token does not depend on the model, so check it once instead
        # of on every loop iteration.
        if "HF_TOKEN" not in os.environ:
            yield "data: Error: HF_TOKEN not configured\n\n"
            yield "data: [DONE]\n\n"
            return

        for model in MODELS:
            # Becomes True once any text has been sent for this model; after
            # that we must not fall back to another model, or the client
            # would receive two spliced answers.
            streamed_any = False
            try:
                client = InferenceClient(model, token=token)
                logger.info("Streaming with: %s", model)
                for chunk in client.chat_completion(messages, max_tokens=512, stream=True):
                    # Skip chunks that carry no choices rather than letting
                    # an IndexError abort the stream.
                    if not chunk.choices:
                        continue
                    text = chunk.choices[0].delta.content
                    if text:
                        streamed_any = True
                        # NOTE(review): `text` is written verbatim into one
                        # `data:` field. In the SSE format a newline ends
                        # the field, so fragments containing "\n" may be
                        # split or truncated by a spec-compliant client.
                        # Confirm how the frontend parses the stream before
                        # changing the framing.
                        yield f"data: {text}\n\n"
                yield "data: [DONE]\n\n"
                return
            except Exception as e:
                if streamed_any:
                    # Part of the answer is already with the client: end the
                    # stream, but record why it was cut short.
                    logger.warning("Streaming interrupted mid-response for %s: %s", model, e)
                    yield "data: [DONE]\n\n"
                    return
                # Nothing was sent yet, so it is safe to try the next model.
                logger.warning("Streaming failed for %s: %s", model, e)
                continue

        yield "data: Sorry, all models are currently unavailable. Try again later.\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(token_stream(), media_type="text/event-stream")
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the file is executed
    # directly, not when the module is imported.
    import uvicorn

    # Serve `app` on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)