# RAG / api / app.py
# rohitdeshmukh318's picture
# Deploy clean HF snapshot without binary PDF history
# f499d4b
import os

# Force single thread for torch to save CPU/memory overhead on small instances.
# These variables are read when the native threading runtime is initialised, so
# they are set before the project imports below, any of which may pull in
# torch/numpy at import time.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import gc  # noqa: E402
import logging  # noqa: E402
import shutil  # noqa: E402
import uuid  # noqa: E402
from contextlib import asynccontextmanager  # noqa: E402
from typing import Optional, List  # noqa: E402

from fastapi import Depends, FastAPI, UploadFile, File, HTTPException  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from pydantic import BaseModel  # noqa: E402

# Project imports keep their original relative order.
from pipeline.ingest_document import ingest_document  # noqa: E402
from pipeline.query_pipeline import run_query_pipeline  # noqa: E402
from pipeline.conversation_memory import ConversationMemory, ConversationTurn  # noqa: E402
from api import dependencies  # noqa: E402
from api.auth import (  # noqa: E402
    AuthError,
    AuthenticatedUser,
    create_access_token,
    get_current_user,
    user_store,
)
from configs.settings import settings  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
)
gc.collect()
logger = logging.getLogger(__name__)
# In-memory session store: user_id -> session_id -> ConversationMemory.
# This is a plain module-level dict, so its contents live only in this process.
# NOTE(review): nothing in this file ever removes an entry, so the store grows
# for the lifetime of the process — confirm whether eviction happens elsewhere.
_sessions: dict[str, dict[str, ConversationMemory]] = {}
def _get_user_sessions(user_id: str) -> dict[str, ConversationMemory]:
    """Return the session mapping owned by ``user_id``.

    An empty mapping is created and stored the first time a user is seen, so
    the returned dict is always the live object held in ``_sessions``.
    """
    if user_id not in _sessions:
        _sessions[user_id] = {}
    return _sessions[user_id]
def _create_session(user_id: str, session_name: str) -> tuple[str, ConversationMemory]:
    """Create and store a new session for ``user_id``.

    The session id is the first 16 hex characters of a random UUID. A falsy
    ``session_name`` is replaced by "New Session".

    Returns:
        ``(session_id, memory)`` for the newly stored session.
    """
    new_id = uuid.uuid4().hex[:16]
    display_name = session_name if session_name else "New Session"
    memory = ConversationMemory(session_name=display_name)
    owned = _get_user_sessions(user_id)
    owned[new_id] = memory
    return new_id, memory
def _resolve_session(
    user_id: str,
    session_id: Optional[str],
    *,
    create_if_missing: bool,
) -> tuple[str, ConversationMemory]:
    """Find the session a request should operate on.

    Resolution order:
      * no ``session_id`` given: the user's most recently inserted session,
        or a brand-new "Research Session" when the user has none;
      * ``session_id`` known: that session;
      * ``session_id`` unknown: a new session stored under that exact id when
        ``create_if_missing`` is true, otherwise a 404.

    Returns:
        ``(session_id, memory)``.

    Raises:
        HTTPException: 404 when an unknown id is given and creation is disabled.
    """
    owned = _get_user_sessions(user_id)

    if not session_id:
        if not owned:
            return _create_session(user_id, "Research Session")
        # Dicts keep insertion order, so the last key is the newest session.
        newest_id = next(reversed(owned))
        return newest_id, owned[newest_id]

    if session_id in owned:
        return session_id, owned[session_id]

    if not create_if_missing:
        raise HTTPException(status_code=404, detail="Session not found")

    # Unknown id supplied by the caller: adopt it as the key of a new session.
    fresh = ConversationMemory(session_name=f"Session {session_id[:6]}")
    owned[session_id] = fresh
    return session_id, fresh
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log application startup and shutdown around the serving phase.

    No models are loaded here: they load lazily on first request to avoid
    deployment timeouts.
    """
    # Lazy %-style arguments, matching the other logger calls in this module.
    logger.info("ResearchMind RAG system starting on %s", settings.PROJECT_NAME)
    yield
    logger.info("Application shutting down.")
# Application instance. Interactive docs and the OpenAPI schema are served at
# the paths configured here; startup/shutdown logging comes from `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title=settings.PROJECT_NAME,
    version="2.0.0",
    openapi_url="/openapi.json",
    docs_url="/docs",
)

# -- CORS Configuration -------------------------------------------------------
# Every origin, method and header is accepted; credentialed requests are off.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Make sure the upload root exists as soon as the module is imported.
settings.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
# Landing endpoint: static welcome payload pointing at the interactive docs.
@app.get("/")
def root():
    welcome = {
        "message": "Welcome to ResearchMind RAG API",
        "version": "2.0.0",
        "status": "online",
        "docs": "/docs",
    }
    return welcome
# -- Request / Response Models ------------------------------------------------
# Request body for POST /ask.
class QuestionRequest(BaseModel):
    # The user's question; /ask rejects it when it is empty after stripping.
    question: str
    # Session to use; when omitted /ask falls back to the user's latest session.
    session_id: Optional[str] = None
    # When set, restricts the query to this one document.
    doc_id: Optional[str] = None
    # Forwarded to the query pipeline as `enable_nli`.
    enable_nli: bool = False
    # When false, no conversation or pinned context is passed to the pipeline.
    use_memory_context: bool = True
# Request body for POST /session/{session_id}/pin.
class PinRequest(BaseModel):
    # Not read by the pin endpoint, which takes the session id from the URL path.
    session_id: Optional[str] = None
    # The insight text to pin.
    text: str
    # The question the insight came from.
    source_question: str
    # Optional originating document, passed through as `from_doc`.
    from_doc: Optional[str] = None
# Request body for POST /session/create.
class SessionRequest(BaseModel):
    # Display name for the session; a falsy value is replaced by "New Session".
    session_name: Optional[str] = "New Session"
# Credentials body shared by POST /auth/register and POST /auth/login.
class AuthRequest(BaseModel):
    username: str
    password: str
# -- Auth ---------------------------------------------------------------------
# Create a user account and return an access token for it.
# An AuthError from the user store is reported as HTTP 400 with its message.
@app.post("/auth/register")
def register(req: AuthRequest):
    try:
        new_user = user_store.create_user(req.username, req.password)
    except AuthError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {
        "token": create_access_token(new_user),
        "user": new_user.to_dict(),
    }
# Verify credentials and return an access token.
# An AuthError from the user store is reported as HTTP 401 with its message.
@app.post("/auth/login")
def login(req: AuthRequest):
    try:
        known_user = user_store.authenticate(req.username, req.password)
    except AuthError as exc:
        raise HTTPException(status_code=401, detail=str(exc)) from exc
    else:
        issued = create_access_token(known_user)
        return {"token": issued, "user": known_user.to_dict()}
# Return the profile of the user resolved by the `get_current_user` dependency.
@app.get("/auth/me")
def auth_me(current_user: AuthenticatedUser = Depends(get_current_user)):
    profile = current_user.to_dict()
    return {"user": profile}
def _discard_upload(path) -> None:
    """Best-effort removal of a stored upload that no response will reference."""
    try:
        path.unlink(missing_ok=True)
    except OSError as cleanup_error:
        # Cleanup must never mask the error that triggered it.
        logger.warning("Could not remove orphaned upload %s: %s", path, cleanup_error)


@app.post("/upload")
def upload_pdf(
    file: UploadFile = File(...),
    session_id: Optional[str] = None,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Upload and index a PDF document."""
    # Parameters:
    #   file        - uploaded file; its name must end in ".pdf" (any case).
    #   session_id  - optional session to register the document in; an unknown
    #                 id is created, and no id falls back to the latest session.
    # Returns a dict with status, doc_id, filename and the index counters.
    # Errors: 400 for a missing or non-PDF filename, 500 when the file cannot
    # be saved or ingested.

    # The upload may arrive without a filename; treat that as "not a PDF"
    # rather than failing with an AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")

    # Store under a random name inside the user's own directory.
    file_id = f"{uuid.uuid4().hex}.pdf"
    upload_dir = settings.UPLOAD_DIR / current_user.user_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    file_path = upload_dir / file_id

    try:
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
    except Exception as e:
        logger.error("File save error: %s", e)
        # Do not leave a truncated file behind.
        _discard_upload(file_path)
        raise HTTPException(status_code=500, detail="Failed to save file") from e

    try:
        encoder = dependencies.get_encoder()
        index = dependencies.get_index(current_user.user_id)
        logger.info("Ingesting new document: %s", filename)
        doc_id = ingest_document(file_path, encoder, index)
    except Exception as e:
        logger.error("Ingestion error: %s", e, exc_info=True)
        # The caller never receives a doc_id, so the randomly named file would
        # be unreachable; remove it instead of leaking disk space.
        _discard_upload(file_path)
        raise HTTPException(status_code=500, detail="Failed to ingest document") from e

    # Register doc into session memory.
    # NOTE(review): `index.chunk_count` is read from the user's index object; if
    # that counter spans every indexed document, the figure registered for this
    # document is the index total — confirm.
    _, session = _resolve_session(current_user.user_id, session_id, create_if_missing=True)
    session.register_document(doc_id, filename, index.chunk_count)

    # Explicitly clear memory after heavy operation
    gc.collect()

    return {
        "status": "indexed",
        "doc_id": doc_id,
        "filename": filename,
        "chunk_count": index.chunk_count,
        "document_count": index.document_count,
    }
# -- Ask ----------------------------------------------------------------------
@app.post("/ask")
def ask_question(
    req: QuestionRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Ask a question against the indexed documents with optional memory context."""
    # Flow: validate -> resolve session -> rewrite query -> run pipeline ->
    # record the turn (unless refused) -> annotate and return the result.
    # A blank question yields HTTP 400; any pipeline failure yields a "refused"
    # payload with HTTP 200 rather than an error status.
    asked = req.question
    if not asked.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    _, memory = _resolve_session(current_user.user_id, req.session_id, create_if_missing=True)

    # The query is rewritten from conversation history on every request.
    effective_question = memory.rewrite_query_with_context(asked)

    # Context strings injected into the prompt; both empty when memory is off.
    if req.use_memory_context:
        history_prompt = memory.build_context_prompt()
        pins_prompt = memory.get_pins_context()
    else:
        history_prompt = ""
        pins_prompt = ""

    try:
        encoder = dependencies.get_encoder()
        index = dependencies.get_index(current_user.user_id)
        reranker = dependencies.get_reranker()
        llm = dependencies.get_llm_client()

        # Document scope: an explicit doc_id wins over the session's documents;
        # None means no restriction is passed to the pipeline.
        if req.doc_id:
            scope = [req.doc_id]
        elif memory.active_docs:
            scope = [d["doc_id"] for d in memory.active_docs]
        else:
            scope = None

        result = run_query_pipeline(
            question=effective_question,
            encoder=encoder,
            index=index,
            reranker=reranker,
            llm=llm,
            enable_nli=req.enable_nli,
            conversation_context=history_prompt,
            pinned_context=pins_prompt,
            doc_ids=scope,
        )
    except Exception as e:
        logger.error("Query error: %s", e, exc_info=True)
        return {
            "answer": "I cannot answer due to an internal error.",
            "verdict": "refused",
            "citations": [],
            "evidence": [],
        }

    # Refused answers are not remembered; stored turns drop the raw evidence
    # to save memory.
    if result.get("verdict") != "refused":
        memory.add_turn(
            ConversationTurn(
                question=asked,
                answer=result.get("answer", ""),
                verdict=result.get("verdict", "refused"),
                confidence=result.get("confidence", 0.0),
                citations=result.get("citations", []),
                evidence=[],
                support_ratio=result.get("support_ratio", 0.0),
                citation_coverage=result.get("citation_coverage", 0.0),
            )
        )

    # Echo the original question for UI tracking; the rewritten form is only
    # reported when it actually differs.
    result["original_question"] = asked
    if effective_question != asked:
        result["rewritten_question"] = effective_question
    else:
        result["rewritten_question"] = None
    return result
# -- Session & Memory Endpoints -----------------------------------------------
@app.post("/session/create")
def create_session(
    req: SessionRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Create a new conversation session."""
    # A missing or empty name falls back to "New Session"; the response echoes
    # the name that was actually used together with the generated id.
    chosen_name = req.session_name if req.session_name else "New Session"
    created = _create_session(current_user.user_id, chosen_name)
    return {"session_id": created[0], "session_name": chosen_name}
@app.get("/session/{session_id}")
def get_session(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Get session metadata, history, and stats."""
    # Unknown session ids produce a 404 (creation is disabled here).
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    name = memory.session_name
    stats = memory.get_stats()
    turns = [turn.to_dict() for turn in memory.turns]
    pins = [pin.to_dict() for pin in memory.pins]
    return {
        "session_id": session_id,
        "session_name": name,
        "stats": stats,
        "turns": turns,
        "pins": pins,
        "active_docs": memory.active_docs,
    }
@app.get("/session/{session_id}/history")
def get_history(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Return conversation history for a session."""
    # Unknown session ids produce a 404 (creation is disabled here).
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    serialized = []
    for turn in memory.turns:
        serialized.append(turn.to_dict())
    return {"turns": serialized}
@app.post("/session/{session_id}/pin")
def pin_insight(
    session_id: str,
    req: PinRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Pin a key insight to session memory."""
    # The session comes from the URL path; unknown ids produce a 404.
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    pin_fields = {
        "text": req.text,
        "source_question": req.source_question,
        "from_doc": req.from_doc,
    }
    created_pin = memory.add_pin(**pin_fields)
    return {"status": "pinned", "pin": created_pin.to_dict()}
@app.delete("/session/{session_id}/pin/{pin_id}")
def remove_pin(
    session_id: str,
    pin_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Remove a pinned insight."""
    # 404 is returned both for an unknown session and for an unknown pin.
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    if memory.remove_pin(pin_id):
        return {"status": "removed", "pin_id": pin_id}
    raise HTTPException(status_code=404, detail="Pin not found")
@app.delete("/session/{session_id}/clear")
def clear_session(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Clear session history and pins (keeps documents)."""
    # The session itself stays in the store; unknown ids produce a 404.
    resolved = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    memory = resolved[1]
    memory.clear()
    return {"status": "cleared"}
@app.get("/sessions")
def list_sessions(current_user: AuthenticatedUser = Depends(get_current_user)):
    """List all active sessions with basic info."""
    # One summary per session, in the store's insertion order.
    summaries = []
    for sid, mem in _get_user_sessions(current_user.user_id).items():
        summary = {
            "session_id": sid,
            "session_name": mem.session_name,
            "turn_count": len(mem.turns),
            "pin_count": len(mem.pins),
            "doc_count": len(mem.active_docs),
            "created_at": mem.created_at.strftime("%Y-%m-%d %H:%M"),
        }
        summaries.append(summary)
    return summaries
# -- Index Management ---------------------------------------------------------
# Clear the current user's index and forget the documents attached to each of
# their sessions. Any failure is logged with its traceback and reported as 500.
@app.delete("/index")
def clear_index(current_user: AuthenticatedUser = Depends(get_current_user)):
    try:
        dependencies.clear_index(current_user.user_id)
        # Sessions must not keep pointing at documents that no longer exist.
        for session in _get_user_sessions(current_user.user_id).values():
            session.active_docs.clear()
    except Exception as e:
        # exc_info keeps the traceback, as the ingestion and query handlers do.
        logger.error("Clear index error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to clear index") from e
    return {"status": "cleared", "message": "Index cleared successfully."}
# Report counters and document ids for the current user's index.
@app.get("/index/status")
def index_status(current_user: AuthenticatedUser = Depends(get_current_user)):
    user_index = dependencies.get_index(current_user.user_id)
    report = {
        "chunk_count": user_index.chunk_count,
        "document_count": user_index.document_count,
        "doc_ids": list(user_index.get_doc_ids()),
    }
    return report
# Liveness probe: static payload, no authentication and no dependency access.
@app.get("/health")
def health_check():
    status_payload = {"status": "ok", "version": "2.0.0"}
    return status_payload