Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Application for Multimodal RAG System | |
| US Army Medical Research Papers Q&A | |
| """ | |
| import os | |
| import logging | |
| from typing import List, Dict, Optional, Union | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| # Import from query_index (standalone) | |
| from query_index import MultimodalRAGSystem | |
# Configure logging: timestamped records, INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger named after this module; used by the code below.
logger = logging.getLogger(__name__)
# Global variables (module-level state read and written by the handlers).
# The RAG engine instance; None whenever no instance is available.
rag_system: Optional[MultimodalRAGSystem] = None
# Text of the last question-answer pair, kept for a simple follow-up turn.
last_qa_context: Optional[str] = None
# Lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the RAG system at startup and release it at shutdown.

    The function is handed to ``FastAPI(lifespan=...)``, which expects an
    async context manager, hence the ``asynccontextmanager`` wrapper. Code
    before ``yield`` runs before the application starts serving; code after
    it runs when the application stops.

    Args:
        app: The FastAPI application being started (not used directly).

    Yields:
        None. Control returns to the framework while the app is running.

    Initialization errors are logged and swallowed on purpose: the global
    ``rag_system`` is left as ``None`` so the process still starts.
    """
    global rag_system
    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
    except Exception as e:
        # Best-effort startup: keep the message text, but log the traceback
        # too so the root cause is visible in the logs.
        logger.exception("Error during initialization: %s", e)
        rag_system = None
    else:
        logger.info("RAG system initialized successfully!")
    try:
        yield
    finally:
        # Drop the reference even if the context is exited via an exception.
        logger.info("Shutting down RAG system...")
        rag_system = None
# Create FastAPI app; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="Multimodal RAG API",
    version="2.0.0",
    description="Q&A system for US Army medical research papers (Text + Images)",
)
# CORS middleware
# Every origin, method and header is accepted. Credentialed cross-origin
# requests are therefore switched off: a wildcard origin must not be combined
# with allow_credentials=True (the CORS rules forbid it, and it would let any
# site make cookie-bearing calls). To support credentials, list the allowed
# origins explicitly instead of using "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files (always mounted; no existence check is applied here).
app.mount("/static", StaticFiles(directory="static"), name="static")

# Mount extracted images
# This allows the frontend to load images via /extracted_images/filename.jpg
# isdir (not exists) is used so that a stray regular file with this name is
# skipped instead of being handed to StaticFiles as a directory.
if os.path.isdir("extracted_images"):
    app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")

# Mount PDF documents (same directory-only guard as above).
if os.path.isdir("WHEC_Documents"):
    app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")
# Pydantic models
class QueryRequest(BaseModel):
    # Request body carrying one question, limited to 1-1000 characters.
    question: str = Field(
        ...,
        description="Question to ask",
        min_length=1,
        max_length=1000,
    )
class ImageSource(BaseModel):
    # One image entry of a query response.
    # Every field is nullable and now carries an explicit None default, in
    # line with `link`. Without the default, Pydantic v2 treats an
    # Optional[...] field as required, so an entry missing one key would
    # fail validation; Pydantic v1 behaviour is unchanged.
    path: Optional[str] = None
    filename: Optional[str] = None
    score: Optional[float] = None
    page: Optional[Union[str, int]] = None  # could be int or str depending on metadata
    file: Optional[str] = None
    link: Optional[str] = None
class TextSource(BaseModel):
    # One text entry of a query response.
    # `text` and `score` are mandatory. The nullable fields carry an explicit
    # None default, in line with `link`, so that Pydantic v2 does not treat
    # them as required; Pydantic v1 behaviour is unchanged.
    text: str
    score: float
    page: Optional[Union[str, int]] = None  # int or str, as in ImageSource
    file: Optional[str] = None
    link: Optional[str] = None
class QueryResponse(BaseModel):
    # Response payload built by the query handler.
    question: str  # the question as submitted in the request
    answer: str  # answer text taken from the RAG result
    images: List[ImageSource]  # image entries of the RAG result
    texts: List[TextSource]  # text entries of the RAG result
class HealthResponse(BaseModel):
    # Payload returned by the health check.
    rag_initialized: bool  # True when the global RAG system is not None
    status: str  # status label; the health check sends "healthy"
# API Endpoints
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm that it is registered on `app`.
async def root():
    """Serve the frontend application"""
    index_page = 'static/index.html'
    return FileResponse(index_page)
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm that it is registered on `app`.
async def health_check():
    """Health check endpoint"""
    # The status label is constant; only the flag reflects the RAG state.
    is_ready = rag_system is not None
    return HealthResponse(status="healthy", rag_initialized=is_ready)
# NOTE(review): no route decorator is visible on this handler in this view of
# the file - confirm that it is registered on `app`.
async def query_rag(request: QueryRequest):
    """Query the RAG system.

    When a previous question/answer pair is stored, it is prepended to the
    prompt so the new question is treated as a follow-up. The stored pair is
    a single module-level value, so it is shared by every caller of this
    handler.

    Args:
        request: Validated request body holding the question.

    Returns:
        QueryResponse with the answer, image sources, text sources and the
        original question.

    Raises:
        HTTPException: 503 when the RAG system is not available; 500 when
            querying or building the response fails.
    """
    global last_qa_context

    # Identity check, consistent with health_check: only a missing instance
    # counts as "not initialized", whatever its truthiness may be.
    if rag_system is None:
        raise HTTPException(
            status_code=503,
            detail="RAG system not initialized. Check logs for errors."
        )

    question = request.question
    try:
        # Build prompt using previous Q/A if available
        if last_qa_context:
            prompt = (
                f"Previous question and answer:\n"
                f"{last_qa_context}\n\n"
                f"Follow up question:\n"
                f"{question}"
            )
        else:
            prompt = question

        # Query RAG system
        result = rag_system.ask(prompt)
        answer = result['answer']
        response = QueryResponse(
            answer=answer,
            images=result['images'],
            texts=result['texts'],
            question=question
        )
    except Exception as e:
        # logger.exception keeps the traceback; `from e` keeps the cause chain.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}") from e

    # Save current Q/A as context for next turn - only once the response has
    # been built, so a failed turn does not become follow-up context.
    last_qa_context = (
        f"Question: {question}\n"
        f"Answer: {answer}"
    )
    return response
if __name__ == "__main__":
    # Direct execution: serve `app` with uvicorn on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)