Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI Application for Multimodal RAG System | |
| US Army Medical Research Papers Q&A | |
| """ | |
| """ | |
| FastAPI Application for Multimodal RAG System | |
| US Army Medical Research Papers Q&A | |
| """ | |
| import os | |
| import io | |
| import logging | |
| from typing import List, Dict, Optional, Union | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| # Import from query_index (standalone) | |
| from query_index import MultimodalRAGSystem | |
| # PDF generation libraries | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | |
| from reportlab.lib.units import inch | |
| from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY | |
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak, Image as RLImage, Table, TableStyle | |
| from reportlab.lib import colors | |
| import requests | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables
# Singleton RAG system; populated during app startup (lifespan), left as None
# when initialization fails so endpoints can respond with 503.
rag_system: Optional[MultimodalRAGSystem] = None

# Keep short conversation history
# NOTE(review): this history is module-global and therefore shared by ALL
# clients of the API — confirm a single-user deployment is intended.
chat_history: List[Dict[str, str]] = []
MAX_HISTORY_TURNS = 3  # each turn is two messages (user + assistant)
# Lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG system on startup and release it on shutdown.

    Registered via ``FastAPI(lifespan=...)``. The ``@asynccontextmanager``
    decorator is required: ``asynccontextmanager`` was imported but unused,
    and Starlette deprecates passing a bare async generator function.
    Initialization failures are logged and leave ``rag_system`` as None so
    endpoints can return 503 instead of crashing the whole app.

    Args:
        app: the FastAPI application instance (unused, required signature).
    """
    global rag_system
    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
        logger.info("RAG system initialized successfully!")
    except Exception as e:
        # Best-effort startup: keep serving (health endpoint reports status).
        logger.error(f"Error during initialization: {str(e)}")
        rag_system = None
    yield  # application handles requests while suspended here
    logger.info("Shutting down RAG system...")
    rag_system = None
# Create FastAPI app
app = FastAPI(
    title="Multimodal RAG API",
    description="Q&A system for US Army medical research papers (Text + Images)",
    version="2.0.0",
    lifespan=lifespan  # startup/shutdown hooks defined above
)
# CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers and is overly permissive — confirm the intended
# origin list for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount static files
# NOTE(review): unlike the mounts below, this one is not guarded by an
# existence check — startup fails if ./static is missing; confirm intended.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Mount extracted images
# This allows the frontend to load images via /extracted_images/filename.jpg
if os.path.exists("extracted_images"):
    app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")
# Mount PDF documents
if os.path.exists("WHEC_Documents"):
    app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")
# Pydantic models
class QueryRequest(BaseModel):
    """Request body for the query endpoint."""
    question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")
class ImageSource(BaseModel):
    """Metadata for an image retrieved as a supporting source.

    All fields default to None so partially populated retrieval results
    still validate: under Pydantic v2, ``Optional[...]`` WITHOUT a default
    is a required field, which would reject results missing any key.
    Adding ``= None`` is backward compatible with callers that supply
    every field.
    """
    path: Optional[str] = None              # URL path (e.g. /extracted_images/...)
    filename: Optional[str] = None
    score: Optional[float] = None           # retrieval similarity score
    page: Optional[Union[str, int]] = None  # page number; sometimes a string
    file: Optional[str] = None              # source document name
    link: Optional[str] = None
class TextSource(BaseModel):
    """A retrieved text passage backing an answer.

    ``text`` and ``score`` remain required; the optional metadata fields get
    explicit ``= None`` defaults so they stay optional under Pydantic v2
    (where ``Optional[...]`` without a default is a required field).
    """
    text: str       # the passage content
    score: float    # retrieval similarity score
    page: Optional[Union[str, int]] = None  # page number; sometimes a string
    file: Optional[str] = None              # source document name
    link: Optional[str] = None
class QueryResponse(BaseModel):
    """Response body for the query endpoint."""
    answer: str                # generated answer text
    images: List[ImageSource]  # supporting image sources
    texts: List[TextSource]    # supporting text sources
    question: str              # echo of the original question
class HealthResponse(BaseModel):
    """Response body for the health-check endpoint."""
    status: str
    rag_initialized: bool  # False when startup initialization failed
class ConversationItem(BaseModel):
    """One question/answer exchange, as submitted by the client for reporting."""
    question: str
    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    timestamp: str  # client-supplied timestamp string; format not validated here
class ReportRequest(BaseModel):
    """Request body for the PDF report endpoint: the full conversation transcript."""
    conversations: List[ConversationItem]
# API Endpoints
@app.get("/")
async def root():
    """Serve the frontend application.

    Without a route decorator this handler was never registered with the
    app; ``@app.get("/")`` restores it as the site root.
    """
    return FileResponse('static/index.html')
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports whether the RAG system initialized successfully; the route
    decorator was missing, leaving the handler unregistered.
    """
    return HealthResponse(
        status="healthy",
        rag_initialized=rag_system is not None
    )
@app.post("/query", response_model=QueryResponse)
async def query_rag(request: QueryRequest):
    """Answer a question against the indexed research papers.

    Passes the rolling (globally shared) chat history to the RAG system,
    records the new exchange, trims the history to the last
    MAX_HISTORY_TURNS turns, and returns the answer with its supporting
    text and image sources. The route decorator was missing, leaving the
    handler unregistered.

    Raises:
        HTTPException: 503 when the RAG system failed to initialize;
            500 for any error raised while answering.
    """
    global chat_history
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized.")
    try:
        result = rag_system.ask(
            query_str=request.question,
            chat_history=chat_history
        )
        # Update history
        chat_history.append({"role": "user", "content": request.question})
        chat_history.append({"role": "assistant", "content": result["answer"]})
        # Trim history safely (each turn = user message + assistant message)
        if len(chat_history) > MAX_HISTORY_TURNS * 2:
            chat_history = chat_history[-MAX_HISTORY_TURNS * 2:]
        return QueryResponse(
            answer=result["answer"],
            images=result["images"],
            texts=result["texts"],
            question=request.question
        )
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate-report")
async def generate_report(request: ReportRequest):
    """Generate a PDF report from conversation data.

    Builds a ReportLab document in memory (title and metadata, then one
    section per Q&A turn with text excerpts and up to three relevant
    images) and streams it back as a file attachment. The route decorator
    was missing, leaving the handler unregistered.

    Raises:
        HTTPException: 500 for any failure while composing or rendering.
    """
    try:
        buffer = io.BytesIO()

        # Create PDF document
        doc = SimpleDocTemplate(
            buffer,
            pagesize=letter,
            rightMargin=0.75*inch,
            leftMargin=0.75*inch,
            topMargin=1*inch,
            bottomMargin=0.75*inch
        )

        # Container for the 'Flowable' objects
        story = []

        # Define styles
        styles = getSampleStyleSheet()

        # Custom styles
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=12,
            alignment=TA_CENTER,
            fontName='Helvetica-Bold'
        )
        subtitle_style = ParagraphStyle(
            'CustomSubtitle',
            parent=styles['Normal'],
            fontSize=12,
            textColor=colors.HexColor('#64748b'),
            spaceAfter=20,
            alignment=TA_CENTER
        )
        question_style = ParagraphStyle(
            'QuestionStyle',
            parent=styles['Heading2'],
            fontSize=14,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=10,
            fontName='Helvetica-Bold'
        )
        answer_style = ParagraphStyle(
            'AnswerStyle',
            parent=styles['Normal'],
            fontSize=11,
            alignment=TA_JUSTIFY,
            spaceAfter=12
        )
        source_title_style = ParagraphStyle(
            'SourceTitle',
            parent=styles['Heading3'],
            fontSize=11,
            textColor=colors.HexColor('#8b5cf6'),
            spaceAfter=6,
            fontName='Helvetica-Bold'
        )

        # Add title
        story.append(Paragraph("WHEC Research Assistant", title_style))
        story.append(Paragraph("Conversation Report", title_style))
        story.append(Spacer(1, 0.2*inch))

        # Add metadata
        current_time = datetime.now().strftime("%B %d, %Y at %I:%M %p")
        story.append(Paragraph(f"Generated: {current_time}", subtitle_style))
        story.append(Paragraph("Source: WHEC (Warrior Heat- and Exertion-Related Events Collaborative)", subtitle_style))
        story.append(Paragraph(f"Total Questions: {len(request.conversations)}", subtitle_style))
        story.append(Spacer(1, 0.3*inch))

        # Add horizontal line
        story.append(Spacer(1, 0.1*inch))

        # Process each conversation
        # NOTE(review): Paragraph() parses its text as mini-XML, so raw '<'
        # or '&' in questions/answers/excerpts would raise — consider
        # escaping user-derived text before rendering.
        for idx, conv in enumerate(request.conversations, 1):
            # Question
            story.append(Paragraph(f"Question {idx}", question_style))
            story.append(Paragraph(conv.question, answer_style))
            story.append(Spacer(1, 0.15*inch))

            # Answer
            story.append(Paragraph("Answer", source_title_style))
            story.append(Paragraph(conv.answer, answer_style))
            story.append(Spacer(1, 0.15*inch))

            # Text Sources
            if conv.texts:
                story.append(Paragraph("Referenced Text Sources", source_title_style))
                for i, txt in enumerate(conv.texts, 1):
                    source_text = f"[{i}] {txt.file or 'Unknown Document'} (Page {txt.page or 'N/A'}, {round((txt.score or 0) * 100)}% match)"
                    story.append(Paragraph(source_text, styles['Normal']))
                    excerpt = f'<i>"{txt.text[:300]}..."</i>'
                    story.append(Paragraph(excerpt, styles['Normal']))
                    story.append(Spacer(1, 0.1*inch))
                story.append(Spacer(1, 0.1*inch))

            # Image Sources
            if conv.images:
                # Only include reasonably relevant images (score >= 0.3)
                relevant_images = [img for img in conv.images if (img.score or 0) >= 0.3]
                if relevant_images:
                    story.append(Paragraph("Referenced Images", source_title_style))
                    for i, img in enumerate(relevant_images[:3], 1):  # Limit to 3 images
                        # Add image metadata
                        img_text = f"[{i}] {img.filename or 'Unknown'} from {img.file or 'Unknown Document'} (Page {img.page or 'N/A'}, {round((img.score or 0) * 100)}% match)"
                        story.append(Paragraph(img_text, styles['Normal']))
                        # Try to add the actual image
                        if img.path:
                            try:
                                # Construct the full path to the image
                                img_path = img.path.replace('/extracted_images/', '')
                                full_img_path = os.path.join('extracted_images', img_path)
                                if os.path.exists(full_img_path):
                                    # Add image with max width of 4 inches
                                    rl_img = RLImage(full_img_path, width=4*inch, height=3*inch, kind='proportional')
                                    story.append(rl_img)
                                    story.append(Spacer(1, 0.1*inch))
                            except Exception as e:
                                # Best-effort: a broken image must not abort the report
                                logger.warning(f"Could not add image to PDF: {e}")
                story.append(Spacer(1, 0.15*inch))

            # Add separator between conversations (except after last one)
            if idx < len(request.conversations):
                story.append(Spacer(1, 0.2*inch))
                # Add a horizontal line as separator
                line_table = Table([['']], colWidths=[6.5*inch])
                line_table.setStyle(TableStyle([
                    ('LINEABOVE', (0, 0), (-1, 0), 1, colors.HexColor('#334155')),
                ]))
                story.append(line_table)
                story.append(Spacer(1, 0.2*inch))

        # Build PDF
        doc.build(story)

        # Get PDF from buffer
        buffer.seek(0)

        # Return PDF as download
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={
                "Content-Disposition": f"attachment; filename=WHEC_Report_{datetime.now().strftime('%Y-%m-%d')}.pdf"
            }
        )
    except Exception as e:
        logger.error(f"Error generating PDF report: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating report: {str(e)}")
if __name__ == "__main__":
    # Local/dev entry point; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)