Spaces:
Sleeping
Sleeping
| from contextlib import asynccontextmanager | |
| from fastapi import Depends, FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| import shutil | |
| import uuid | |
| import logging | |
| from pipeline.ingest_document import ingest_document | |
| from pipeline.query_pipeline import run_query_pipeline | |
| from pipeline.conversation_memory import ConversationMemory, ConversationTurn | |
| from api import dependencies | |
| from api.auth import ( | |
| AuthError, | |
| AuthenticatedUser, | |
| create_access_token, | |
| get_current_user, | |
| user_store, | |
| ) | |
| from configs.settings import settings | |
| import gc | |
| import os | |
# Cap the OpenMP and MKL thread pools at one thread each (intended for torch)
# to save CPU/memory overhead on small instances.
# NOTE(review): these are assigned after the `pipeline` imports earlier in the
# file; if those modules load torch/numpy at import time the limits may be
# read too late to take effect - confirm.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})

# Root logger: INFO level, pipe-separated "time | logger | level | message".
logging.basicConfig(
    format="%(asctime)s | %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# One explicit collection pass once the imports above have run.
gc.collect()

logger = logging.getLogger(__name__)
# Process-local conversation store: user_id -> session_id -> ConversationMemory.
_sessions: dict[str, dict[str, ConversationMemory]] = {}


def _get_user_sessions(user_id: str) -> dict[str, ConversationMemory]:
    """Return the mutable session mapping that belongs to ``user_id``.

    An empty mapping is registered in ``_sessions`` the first time a user is
    seen, so the returned dict is always the live one held by the store.
    """
    # setdefault performs the lookup and the first-time insert in one call.
    bucket = _sessions.setdefault(user_id, {})
    return bucket
def _create_session(user_id: str, session_name: str) -> tuple[str, ConversationMemory]:
    """Create a session for ``user_id`` and store it under a new random id.

    Args:
        user_id: Owner of the new session.
        session_name: Display name; a falsy value falls back to "New Session".

    Returns:
        ``(session_id, session)`` where ``session_id`` is the first 16 hex
        characters of a fresh UUID4.
    """
    new_id = uuid.uuid4().hex[:16]
    display_name = session_name if session_name else "New Session"
    memory = ConversationMemory(session_name=display_name)
    _get_user_sessions(user_id)[new_id] = memory
    return new_id, memory
def _resolve_session(
    user_id: str,
    session_id: Optional[str],
    *,
    create_if_missing: bool,
) -> tuple[str, ConversationMemory]:
    """Find (or create) the session a request should operate on.

    Args:
        user_id: Owner whose sessions are searched.
        session_id: Requested session id; ``None`` or an empty string means
            "no preference".
        create_if_missing: Whether an unknown, explicitly requested id should
            be created instead of rejected.

    Returns:
        ``(session_id, session)`` for the resolved session.

    Raises:
        HTTPException: 404 when ``session_id`` is given, unknown, and
            ``create_if_missing`` is false.
    """
    sessions = _get_user_sessions(user_id)

    if not session_id:
        # No explicit id: reuse the most recently inserted session, or start
        # a default one when the user has none yet.
        if sessions:
            newest_id = next(reversed(sessions))
            return newest_id, sessions[newest_id]
        return _create_session(user_id, "Research Session")

    if session_id in sessions:
        return session_id, sessions[session_id]

    if not create_if_missing:
        raise HTTPException(status_code=404, detail="Session not found")

    # Unknown id that the caller allows us to create: keep the supplied id and
    # derive the display name from its first six characters.
    fresh = ConversationMemory(session_name=f"Session {session_id[:6]}")
    sessions[session_id] = fresh
    return session_id, fresh
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: log on startup and again on shutdown.

    Decorated with ``asynccontextmanager`` so it is a real async context
    manager when passed as ``FastAPI(lifespan=...)``; a bare async generator
    function is only accepted as a deprecated form.

    Args:
        app: The application instance being started (unused here).
    """
    # Models will load lazily on first request to avoid deployment timeouts
    logger.info("ResearchMind RAG system starting on %s", settings.PROJECT_NAME)
    yield
    logger.info("Application shutting down.")
# Application object; interactive docs and the OpenAPI schema are served at
# the explicit paths given below.
app = FastAPI(
    lifespan=lifespan,
    title=settings.PROJECT_NAME,
    version="2.0.0",
    docs_url="/docs",
    openapi_url="/openapi.json",
)

# -- CORS Configuration -------------------------------------------------------
# Every origin, method and header is accepted; credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=False,
    allow_origins=["*"],
)

# Make sure the upload root exists at import time (no error if it already does).
settings.UPLOAD_DIR.mkdir(exist_ok=True, parents=True)
# Static welcome payload: API name, version, status and the docs path.
def root():
    return dict(
        message="Welcome to ResearchMind RAG API",
        version="2.0.0",
        status="online",
        docs="/docs",
    )
| # -- Request / Response Models ------------------------------------------------ | |
# Request body for asking a question.
class QuestionRequest(BaseModel):
    # Question text as typed by the user.
    question: str
    # Target session; when omitted the server picks or creates one.
    session_id: Optional[str] = None
    # Restrict the query to a single document instead of the session's documents.
    doc_id: Optional[str] = None
    # Forwarded to the query pipeline as `enable_nli`.
    enable_nli: bool = False
    # When false, no conversation or pinned context is injected into the prompt.
    use_memory_context: bool = True
# Request body for pinning an insight to a session.
class PinRequest(BaseModel):
    # NOTE(review): not read by the visible pin handler, which takes the
    # session id as its own argument - confirm whether this field is still needed.
    session_id: Optional[str] = None
    # The insight text to pin.
    text: str
    # The question the insight came from.
    source_question: str
    # Optional originating document, passed through as `from_doc`.
    from_doc: Optional[str] = None
# Request body for creating a session.
class SessionRequest(BaseModel):
    # Display name of the new session.
    session_name: Optional[str] = "New Session"
# Credentials body shared by the register and login handlers.
class AuthRequest(BaseModel):
    username: str
    password: str
| # -- Auth & Upload ------------------------------------------------------------ | |
def register(req: AuthRequest):
    # Create a user account and return a freshly issued access token together
    # with the user's dict form. An AuthError from the user store becomes HTTP 400.
    try:
        new_user = user_store.create_user(req.username, req.password)
    except AuthError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    else:
        access_token = create_access_token(new_user)
        return {"token": access_token, "user": new_user.to_dict()}
def login(req: AuthRequest):
    # Verify the credentials and return an access token together with the
    # user's dict form. An AuthError from the user store becomes HTTP 401.
    try:
        known_user = user_store.authenticate(req.username, req.password)
    except AuthError as exc:
        raise HTTPException(status_code=401, detail=str(exc)) from exc
    else:
        access_token = create_access_token(known_user)
        return {"token": access_token, "user": known_user.to_dict()}
def auth_me(current_user: AuthenticatedUser = Depends(get_current_user)):
    # Echo back the user resolved by the get_current_user dependency.
    profile = current_user.to_dict()
    return {"user": profile}
def upload_pdf(
    file: UploadFile = File(...),
    session_id: Optional[str] = None,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Upload and index a PDF document."""
    # Flow: validate the name, store the upload under the user's directory,
    # ingest it into the user's index, then register it with a session.
    # Raises HTTPException 400 for a missing or non-".pdf" filename and 500
    # when saving or ingesting fails.

    # The upload may arrive without a filename; treat that as "not a PDF"
    # instead of failing on None.lower().
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are allowed")

    # Stored under a random name, so the client-supplied name never reaches
    # the filesystem path.
    file_id = f"{uuid.uuid4().hex}.pdf"
    upload_dir = settings.UPLOAD_DIR / current_user.user_id
    upload_dir.mkdir(parents=True, exist_ok=True)
    file_path = upload_dir / file_id

    try:
        with open(file_path, "wb") as f:
            shutil.copyfileobj(file.file, f)
    except Exception as e:
        logger.error("File save error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to save file") from e

    try:
        encoder = dependencies.get_encoder()
        index = dependencies.get_index(current_user.user_id)
        logger.info("Ingesting new document: %s", filename)
        doc_id = ingest_document(file_path, encoder, index)
    except Exception as e:
        logger.error("Ingestion error: %s", e, exc_info=True)
        # NOTE(review): the saved file stays on disk when ingestion fails -
        # confirm whether it should be removed here.
        raise HTTPException(status_code=500, detail="Failed to ingest document") from e

    # Register doc into session memory
    _, session = _resolve_session(current_user.user_id, session_id, create_if_missing=True)
    session.register_document(doc_id, filename, index.chunk_count)

    # Explicitly clear memory after heavy operation
    gc.collect()

    return {
        "status": "indexed",
        "doc_id": doc_id,
        "filename": filename,
        "chunk_count": index.chunk_count,
        "document_count": index.document_count,
    }
| # -- Ask ---------------------------------------------------------------------- | |
def ask_question(
    req: QuestionRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Ask a question against the indexed documents with optional memory context."""
    original_question = req.question
    if not original_question.strip():
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    _, session = _resolve_session(current_user.user_id, req.session_id, create_if_missing=True)

    # The query is always rewritten from conversation history; the
    # use_memory_context flag only controls the prompt context strings below.
    question = session.rewrite_query_with_context(original_question)

    if req.use_memory_context:
        conversation_context = session.build_context_prompt()
        pinned_context = session.get_pins_context()
    else:
        conversation_context = ""
        pinned_context = ""

    try:
        encoder = dependencies.get_encoder()
        index = dependencies.get_index(current_user.user_id)
        reranker = dependencies.get_reranker()
        llm = dependencies.get_llm_client()

        # Document scope: a request-level doc_id wins, then the session's
        # active documents, otherwise no restriction (None).
        if req.doc_id:
            doc_ids = [req.doc_id]
        elif session.active_docs:
            doc_ids = [doc["doc_id"] for doc in session.active_docs]
        else:
            doc_ids = None

        result = run_query_pipeline(
            question=question,
            encoder=encoder,
            index=index,
            reranker=reranker,
            llm=llm,
            enable_nli=req.enable_nli,
            conversation_context=conversation_context,
            pinned_context=pinned_context,
            doc_ids=doc_ids,
        )
    except Exception as e:
        # Best effort: any pipeline failure is logged and answered with a
        # refusal payload rather than an HTTP error.
        logger.error("Query error: %s", e, exc_info=True)
        return {
            "answer": "I cannot answer due to an internal error.",
            "verdict": "refused",
            "citations": [],
            "evidence": [],
        }

    # Only non-refused answers are remembered; evidence is dropped so the
    # stored turn stays small.
    if result.get("verdict") != "refused":
        session.add_turn(
            ConversationTurn(
                question=original_question,
                answer=result.get("answer", ""),
                verdict=result.get("verdict", "refused"),
                confidence=result.get("confidence", 0.0),
                citations=result.get("citations", []),
                evidence=[],
                support_ratio=result.get("support_ratio", 0.0),
                citation_coverage=result.get("citation_coverage", 0.0),
            )
        )

    # Include original question in response for UI tracking; the rewritten
    # form is reported only when it differs.
    result["original_question"] = original_question
    if question != original_question:
        result["rewritten_question"] = question
    else:
        result["rewritten_question"] = None
    return result
| # -- Session & Memory Endpoints ----------------------------------------------- | |
def create_session(
    req: SessionRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Create a new conversation session."""
    # A missing or empty name falls back to the default label.
    chosen_name = req.session_name if req.session_name else "New Session"
    new_id, _ = _create_session(current_user.user_id, chosen_name)
    return {"session_id": new_id, "session_name": chosen_name}
def get_session(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Get session metadata, history, and stats."""
    # An unknown session id yields a 404 from _resolve_session.
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)

    payload = {"session_id": session_id, "session_name": memory.session_name}
    payload["stats"] = memory.get_stats()
    payload["turns"] = [turn.to_dict() for turn in memory.turns]
    payload["pins"] = [pin.to_dict() for pin in memory.pins]
    payload["active_docs"] = memory.active_docs
    return payload
def get_history(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Return conversation history for a session."""
    # An unknown session id yields a 404 from _resolve_session.
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    serialized_turns = [turn.to_dict() for turn in memory.turns]
    return {"turns": serialized_turns}
def pin_insight(
    session_id: str,
    req: PinRequest,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Pin a key insight to session memory."""
    # The session comes from the session_id argument; an unknown id yields a
    # 404 from _resolve_session.
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    created = memory.add_pin(
        text=req.text,
        source_question=req.source_question,
        from_doc=req.from_doc,
    )
    serialized_pin = created.to_dict()
    return {"status": "pinned", "pin": serialized_pin}
def remove_pin(
    session_id: str,
    pin_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Remove a pinned insight."""
    _, memory = _resolve_session(current_user.user_id, session_id, create_if_missing=False)

    # A falsy result from remove_pin is reported as a 404.
    if memory.remove_pin(pin_id):
        return {"status": "removed", "pin_id": pin_id}
    raise HTTPException(status_code=404, detail="Pin not found")
def clear_session(
    session_id: str,
    current_user: AuthenticatedUser = Depends(get_current_user),
):
    """Clear session history and pins (keeps documents)."""
    # An unknown session id yields a 404 from _resolve_session.
    resolved = _resolve_session(current_user.user_id, session_id, create_if_missing=False)
    memory = resolved[1]
    memory.clear()
    return {"status": "cleared"}
def list_sessions(current_user: AuthenticatedUser = Depends(get_current_user)):
    """List all active sessions with basic info."""
    summaries = []
    # One summary row per session, in the store's insertion order.
    for sid, mem in _get_user_sessions(current_user.user_id).items():
        row = {
            "session_id": sid,
            "session_name": mem.session_name,
            "turn_count": len(mem.turns),
            "pin_count": len(mem.pins),
            "doc_count": len(mem.active_docs),
            "created_at": mem.created_at.strftime("%Y-%m-%d %H:%M"),
        }
        summaries.append(row)
    return summaries
| # -- Index Management --------------------------------------------------------- | |
def clear_index(current_user: AuthenticatedUser = Depends(get_current_user)):
    """Clear the user's index and detach documents from all of their sessions."""
    # Returns a status payload on success; any failure is logged with its
    # traceback and re-raised as HTTP 500 with the original error chained.
    try:
        dependencies.clear_index(current_user.user_id)
        # Sessions keep their history and pins; only the document lists are
        # emptied so they no longer point at removed content.
        for session in _get_user_sessions(current_user.user_id).values():
            session.active_docs.clear()
    except Exception as e:
        logger.error("Clear index error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to clear index") from e
    return {"status": "cleared", "message": "Index cleared successfully."}
def index_status(current_user: AuthenticatedUser = Depends(get_current_user)):
    # Report the size and the document ids of the caller's index.
    user_index = dependencies.get_index(current_user.user_id)
    chunk_total = user_index.chunk_count
    document_total = user_index.document_count
    known_doc_ids = list(user_index.get_doc_ids())
    return {
        "chunk_count": chunk_total,
        "document_count": document_total,
        "doc_ids": known_doc_ids,
    }
# Liveness payload: fixed status and API version.
def health_check():
    return dict(status="ok", version="2.0.0")