""" FastAPI Application for Multimodal RAG System US Army Medical Research Papers Q&A """ """ FastAPI Application for Multimodal RAG System US Army Medical Research Papers Q&A """ import os import io import logging from typing import List, Dict, Optional, Union from contextlib import asynccontextmanager from datetime import datetime from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field # Import from query_index (standalone) from query_index import MultimodalRAGSystem # PDF generation libraries from reportlab.lib.pagesizes import letter from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle from reportlab.lib.units import inch from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak, Image as RLImage, Table, TableStyle from reportlab.lib import colors import requests # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # Global variables rag_system: Optional[MultimodalRAGSystem] = None # Keep short conversation history chat_history: List[Dict[str, str]] = [] MAX_HISTORY_TURNS = 3 # Lifecycle management @asynccontextmanager async def lifespan(app: FastAPI): """Initialize and cleanup RAG system""" global rag_system logger.info("Starting RAG system initialization...") try: rag_system = MultimodalRAGSystem() logger.info("RAG system initialized successfully!") except Exception as e: logger.error(f"Error during initialization: {str(e)}") rag_system = None yield logger.info("Shutting down RAG system...") rag_system = None # Create FastAPI app app = FastAPI( title="Multimodal RAG API", description="Q&A system for US Army medical research papers (Text + Images)", version="2.0.0", lifespan=lifespan ) 
# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS — confirm acceptable for deployment
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Mount extracted images
# This allows the frontend to load images via /extracted_images/filename.jpg
if os.path.exists("extracted_images"):
    app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")

# Mount PDF documents
if os.path.exists("WHEC_Documents"):
    app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")


# ---------------------------------------------------------------------------
# Pydantic models
# ---------------------------------------------------------------------------
class QueryRequest(BaseModel):
    # The user's question; bounded to avoid oversized prompts.
    question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")


class ImageSource(BaseModel):
    # Metadata for a retrieved image; fields may be absent in RAG results.
    path: Optional[str]
    filename: Optional[str]
    score: Optional[float]
    page: Optional[Union[str, int]]
    file: Optional[str]
    link: Optional[str] = None


class TextSource(BaseModel):
    # A retrieved text passage with its relevance score and provenance.
    text: str
    score: float
    page: Optional[Union[str, int]]
    file: Optional[str]
    link: Optional[str] = None


class QueryResponse(BaseModel):
    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    question: str


class HealthResponse(BaseModel):
    status: str
    rag_initialized: bool


class ConversationItem(BaseModel):
    # One Q&A turn as sent back by the frontend for report generation.
    question: str
    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    timestamp: str


class ReportRequest(BaseModel):
    conversations: List[ConversationItem]


# ---------------------------------------------------------------------------
# API Endpoints
# ---------------------------------------------------------------------------
@app.get("/", tags=["Root"])
async def root():
    """Serve the frontend application"""
    return FileResponse('static/index.html')


@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        rag_initialized=rag_system is not None
    )


@app.post("/query", response_model=QueryResponse, tags=["Query"])
async def query_rag(request: QueryRequest):
    """Answer a question with the RAG system, keeping a short rolling history.

    Returns 503 while the RAG system is not initialized and 500 on any
    failure inside the retrieval/answering pipeline.
    """
    global chat_history
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized.")
    try:
        result = rag_system.ask(
            query_str=request.question,
            chat_history=chat_history
        )

        # Update history
        chat_history.append({"role": "user", "content": request.question})
        chat_history.append({"role": "assistant", "content": result["answer"]})

        # Trim history safely: keep only the last MAX_HISTORY_TURNS exchanges
        # (each turn is a user + assistant pair, hence the factor of 2).
        if len(chat_history) > MAX_HISTORY_TURNS * 2:
            chat_history = chat_history[-MAX_HISTORY_TURNS * 2:]

        return QueryResponse(
            answer=result["answer"],
            images=result["images"],
            texts=result["texts"],
            question=request.question
        )
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate-report", tags=["Report"])
async def generate_report(request: ReportRequest):
    """Generate a PDF report from conversation data.

    Builds a ReportLab document in memory and streams it back as an
    attachment. Returns 500 with a descriptive message on any failure.
    """
    # FIX: ReportLab's Paragraph parses its text argument as XML-like markup,
    # so raw '&' or '<' in user/RAG content used to raise a parse error and
    # fail the whole request. All user-derived text is now escaped first.
    from xml.sax.saxutils import escape

    try:
        buffer = io.BytesIO()

        # Create PDF document
        doc = SimpleDocTemplate(
            buffer,
            pagesize=letter,
            rightMargin=0.75*inch,
            leftMargin=0.75*inch,
            topMargin=1*inch,
            bottomMargin=0.75*inch
        )

        # Container for the 'Flowable' objects
        story = []

        # Define styles
        styles = getSampleStyleSheet()

        # Custom styles
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=12,
            alignment=TA_CENTER,
            fontName='Helvetica-Bold'
        )
        subtitle_style = ParagraphStyle(
            'CustomSubtitle',
            parent=styles['Normal'],
            fontSize=12,
            textColor=colors.HexColor('#64748b'),
            spaceAfter=20,
            alignment=TA_CENTER
        )
        question_style = ParagraphStyle(
            'QuestionStyle',
            parent=styles['Heading2'],
            fontSize=14,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=10,
            fontName='Helvetica-Bold'
        )
        answer_style = ParagraphStyle(
            'AnswerStyle',
            parent=styles['Normal'],
            fontSize=11,
            alignment=TA_JUSTIFY,
            spaceAfter=12
        )
        source_title_style = ParagraphStyle(
            'SourceTitle',
            parent=styles['Heading3'],
            fontSize=11,
            textColor=colors.HexColor('#8b5cf6'),
            spaceAfter=6,
            fontName='Helvetica-Bold'
        )

        # Add title
        story.append(Paragraph("WHEC Research Assistant", title_style))
        story.append(Paragraph("Conversation Report", title_style))
        story.append(Spacer(1, 0.2*inch))

        # Add metadata
        current_time = datetime.now().strftime("%B %d, %Y at %I:%M %p")
        story.append(Paragraph(f"Generated: {current_time}", subtitle_style))
        story.append(Paragraph("Source: WHEC (Warrior Heat- and Exertion-Related Events Collaborative)", subtitle_style))
        story.append(Paragraph(f"Total Questions: {len(request.conversations)}", subtitle_style))
        story.append(Spacer(1, 0.3*inch))

        # Add horizontal line
        story.append(Spacer(1, 0.1*inch))

        # Process each conversation
        for idx, conv in enumerate(request.conversations, 1):
            # Question
            story.append(Paragraph(f"Question {idx}", question_style))
            story.append(Paragraph(escape(conv.question), answer_style))
            story.append(Spacer(1, 0.15*inch))

            # Answer
            story.append(Paragraph("Answer", source_title_style))
            story.append(Paragraph(escape(conv.answer), answer_style))
            story.append(Spacer(1, 0.15*inch))

            # Text Sources
            if conv.texts:
                story.append(Paragraph("Referenced Text Sources", source_title_style))
                for i, txt in enumerate(conv.texts, 1):
                    source_text = f"[{i}] {txt.file or 'Unknown Document'} (Page {txt.page or 'N/A'}, {round((txt.score or 0) * 100)}% match)"
                    story.append(Paragraph(escape(source_text), styles['Normal']))
                    excerpt = f'"{txt.text[:300]}..."'
                    story.append(Paragraph(escape(excerpt), styles['Normal']))
                    story.append(Spacer(1, 0.1*inch))
                story.append(Spacer(1, 0.1*inch))

            # Image Sources (only those above the relevance threshold)
            if conv.images:
                relevant_images = [img for img in conv.images if (img.score or 0) >= 0.3]
                if relevant_images:
                    story.append(Paragraph("Referenced Images", source_title_style))
                    for i, img in enumerate(relevant_images[:3], 1):  # Limit to 3 images
                        # Add image metadata
                        img_text = f"[{i}] {img.filename or 'Unknown'} from {img.file or 'Unknown Document'} (Page {img.page or 'N/A'}, {round((img.score or 0) * 100)}% match)"
                        story.append(Paragraph(escape(img_text), styles['Normal']))

                        # Try to add the actual image; a missing/corrupt file
                        # only logs a warning rather than failing the report.
                        if img.path:
                            try:
                                # Construct the full path to the image
                                img_path = img.path.replace('/extracted_images/', '')
                                full_img_path = os.path.join('extracted_images', img_path)
                                if os.path.exists(full_img_path):
                                    # Add image with max width of 4 inches
                                    rl_img = RLImage(full_img_path, width=4*inch, height=3*inch, kind='proportional')
                                    story.append(rl_img)
                                    story.append(Spacer(1, 0.1*inch))
                            except Exception as e:
                                logger.warning(f"Could not add image to PDF: {e}")
                    story.append(Spacer(1, 0.15*inch))

            # Add separator between conversations (except after last one)
            if idx < len(request.conversations):
                story.append(Spacer(1, 0.2*inch))
                # Add a horizontal line as separator
                line_table = Table([['']], colWidths=[6.5*inch])
                line_table.setStyle(TableStyle([
                    ('LINEABOVE', (0, 0), (-1, 0), 1, colors.HexColor('#334155')),
                ]))
                story.append(line_table)
                story.append(Spacer(1, 0.2*inch))

        # Build PDF
        doc.build(story)

        # Get PDF from buffer
        buffer.seek(0)

        # Return PDF as download
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={
                "Content-Disposition": f"attachment; filename=WHEC_Report_{datetime.now().strftime('%Y-%m-%d')}.pdf"
            }
        )
    except Exception as e:
        logger.error(f"Error generating PDF report: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating report: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)