|
|
""" |
|
|
FastAPI Application for Multimodal RAG System |
|
|
US Army Medical Research Papers Q&A |
|
|
""" |
|
|
|
|
|
""" |
|
|
FastAPI Application for Multimodal RAG System |
|
|
US Army Medical Research Papers Q&A |
|
|
""" |
|
|
|
|
|
import os |
|
|
import io |
|
|
import logging |
|
|
from typing import List, Dict, Optional, Union |
|
|
from contextlib import asynccontextmanager |
|
|
from datetime import datetime |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import FileResponse, StreamingResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
from query_index import MultimodalRAGSystem |
|
|
|
|
|
|
|
|
from reportlab.lib.pagesizes import letter |
|
|
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle |
|
|
from reportlab.lib.units import inch |
|
|
from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY |
|
|
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak, Image as RLImage, Table, TableStyle |
|
|
from reportlab.lib import colors |
|
|
import requests |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so every module that calls
# logging.getLogger inherits the same level and format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Singleton RAG engine, created by the lifespan handler below.
# Stays None if initialization failed; endpoints must check before use.
rag_system: Optional[MultimodalRAGSystem] = None

# NOTE(review): this history is module-global, so it is shared by ALL clients
# of the API and mutated without any locking — acceptable for a single-user
# demo, but a multi-user deployment needs per-session history. Confirm intent.
chat_history: List[Dict[str, str]] = []
# Number of user/assistant exchange *pairs* retained as context (each pair
# contributes two entries to chat_history).
MAX_HISTORY_TURNS = 3
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the RAG engine on startup, drop it on shutdown.

    On initialization failure the error is logged and ``rag_system`` stays
    ``None``; the API still starts and endpoints report the degraded state.
    """
    global rag_system

    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
    except Exception as e:
        logger.error(f"Error during initialization: {str(e)}")
        rag_system = None
    else:
        logger.info("RAG system initialized successfully!")

    # Hand control to the running application.
    yield

    # Shutdown: release the engine reference.
    logger.info("Shutting down RAG system...")
    rag_system = None
|
|
|
|
|
|
|
|
# FastAPI application instance; startup/shutdown is handled by `lifespan` above.
app = FastAPI(
    title="Multimodal RAG API",
    description="Q&A system for US Army medical research papers (Text + Images)",
    version="2.0.0",
    lifespan=lifespan
)
|
|
|
|
|
|
|
|
# Allow cross-origin requests from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory — Starlette's CORSMiddleware will not echo a wildcard origin
# for credentialed requests, so browser requests with cookies/auth may be
# blocked. If credentials are actually needed, list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
app.mount("/static", StaticFiles(directory="static"), name="static") |
|
|
|
|
|
|
|
|
|
|
|
if os.path.exists("extracted_images"): |
|
|
app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images") |
|
|
|
|
|
|
|
|
if os.path.exists("WHEC_Documents"): |
|
|
app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents") |
|
|
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for POST /query."""
    question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")
|
|
|
|
|
class ImageSource(BaseModel):
    """An image retrieved as supporting evidence for an answer.

    Every field defaults to ``None`` so partially-populated retrieval
    results validate cleanly: under Pydantic v2 an ``Optional`` annotation
    without a default is still a *required* field (under v1 the defaults
    are implied, so this is a no-op there).
    """
    path: Optional[str] = None                 # public URL path (under /extracted_images)
    filename: Optional[str] = None             # image file name
    score: Optional[float] = None              # retrieval similarity score
    page: Optional[Union[str, int]] = None     # source page (may be unparsed text)
    file: Optional[str] = None                 # source document name
    link: Optional[str] = None                 # optional citation link
|
|
|
|
|
class TextSource(BaseModel):
    """A text passage retrieved as supporting evidence for an answer.

    ``text`` and ``score`` are required; the provenance fields default to
    ``None`` so results with missing metadata validate under Pydantic v2
    (where an ``Optional`` annotation without a default remains required).
    """
    text: str                                  # the retrieved passage
    score: float                               # retrieval similarity score
    page: Optional[Union[str, int]] = None     # source page (may be unparsed text)
    file: Optional[str] = None                 # source document name
    link: Optional[str] = None                 # optional citation link
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response body for POST /query: the answer plus its supporting evidence."""
    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    question: str
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str
    rag_initialized: bool
|
|
|
|
|
class ConversationItem(BaseModel):
    """One question/answer turn, as sent back by the client for report generation."""
    question: str
    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    timestamp: str
|
|
|
|
|
class ReportRequest(BaseModel):
    """Request body for POST /generate-report."""
    conversations: List[ConversationItem]
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", tags=["Root"]) |
|
|
async def root(): |
|
|
"""Serve the frontend application""" |
|
|
return FileResponse('static/index.html') |
|
|
|
|
|
@app.get("/health", response_model=HealthResponse, tags=["Health"]) |
|
|
async def health_check(): |
|
|
"""Health check endpoint""" |
|
|
return HealthResponse( |
|
|
status="healthy", |
|
|
rag_initialized=rag_system is not None |
|
|
) |
|
|
|
|
|
@app.post("/query", response_model=QueryResponse, tags=["Query"]) |
|
|
async def query_rag(request: QueryRequest): |
|
|
global chat_history |
|
|
|
|
|
if not rag_system: |
|
|
raise HTTPException(status_code=503, detail="RAG system not initialized.") |
|
|
|
|
|
try: |
|
|
result = rag_system.ask( |
|
|
query_str=request.question, |
|
|
chat_history=chat_history |
|
|
) |
|
|
|
|
|
|
|
|
chat_history.append({"role": "user", "content": request.question}) |
|
|
chat_history.append({"role": "assistant", "content": result["answer"]}) |
|
|
|
|
|
|
|
|
if len(chat_history) > MAX_HISTORY_TURNS * 2: |
|
|
chat_history = chat_history[-MAX_HISTORY_TURNS * 2:] |
|
|
|
|
|
return QueryResponse( |
|
|
answer=result["answer"], |
|
|
images=result["images"], |
|
|
texts=result["texts"], |
|
|
question=request.question |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error processing query: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
@app.post("/generate-report", tags=["Report"]) |
|
|
async def generate_report(request: ReportRequest): |
|
|
"""Generate a PDF report from conversation data""" |
|
|
try: |
|
|
buffer = io.BytesIO() |
|
|
|
|
|
|
|
|
doc = SimpleDocTemplate( |
|
|
buffer, |
|
|
pagesize=letter, |
|
|
rightMargin=0.75*inch, |
|
|
leftMargin=0.75*inch, |
|
|
topMargin=1*inch, |
|
|
bottomMargin=0.75*inch |
|
|
) |
|
|
|
|
|
|
|
|
story = [] |
|
|
|
|
|
|
|
|
styles = getSampleStyleSheet() |
|
|
|
|
|
|
|
|
title_style = ParagraphStyle( |
|
|
'CustomTitle', |
|
|
parent=styles['Heading1'], |
|
|
fontSize=24, |
|
|
textColor=colors.HexColor('#3b82f6'), |
|
|
spaceAfter=12, |
|
|
alignment=TA_CENTER, |
|
|
fontName='Helvetica-Bold' |
|
|
) |
|
|
|
|
|
subtitle_style = ParagraphStyle( |
|
|
'CustomSubtitle', |
|
|
parent=styles['Normal'], |
|
|
fontSize=12, |
|
|
textColor=colors.HexColor('#64748b'), |
|
|
spaceAfter=20, |
|
|
alignment=TA_CENTER |
|
|
) |
|
|
|
|
|
question_style = ParagraphStyle( |
|
|
'QuestionStyle', |
|
|
parent=styles['Heading2'], |
|
|
fontSize=14, |
|
|
textColor=colors.HexColor('#3b82f6'), |
|
|
spaceAfter=10, |
|
|
fontName='Helvetica-Bold' |
|
|
) |
|
|
|
|
|
answer_style = ParagraphStyle( |
|
|
'AnswerStyle', |
|
|
parent=styles['Normal'], |
|
|
fontSize=11, |
|
|
alignment=TA_JUSTIFY, |
|
|
spaceAfter=12 |
|
|
) |
|
|
|
|
|
source_title_style = ParagraphStyle( |
|
|
'SourceTitle', |
|
|
parent=styles['Heading3'], |
|
|
fontSize=11, |
|
|
textColor=colors.HexColor('#8b5cf6'), |
|
|
spaceAfter=6, |
|
|
fontName='Helvetica-Bold' |
|
|
) |
|
|
|
|
|
|
|
|
story.append(Paragraph("WHEC Research Assistant", title_style)) |
|
|
story.append(Paragraph("Conversation Report", title_style)) |
|
|
story.append(Spacer(1, 0.2*inch)) |
|
|
|
|
|
|
|
|
current_time = datetime.now().strftime("%B %d, %Y at %I:%M %p") |
|
|
story.append(Paragraph(f"Generated: {current_time}", subtitle_style)) |
|
|
story.append(Paragraph("Source: WHEC (Warrior Heat- and Exertion-Related Events Collaborative)", subtitle_style)) |
|
|
story.append(Paragraph(f"Total Questions: {len(request.conversations)}", subtitle_style)) |
|
|
|
|
|
story.append(Spacer(1, 0.3*inch)) |
|
|
|
|
|
|
|
|
story.append(Spacer(1, 0.1*inch)) |
|
|
|
|
|
|
|
|
for idx, conv in enumerate(request.conversations, 1): |
|
|
|
|
|
story.append(Paragraph(f"Question {idx}", question_style)) |
|
|
story.append(Paragraph(conv.question, answer_style)) |
|
|
story.append(Spacer(1, 0.15*inch)) |
|
|
|
|
|
|
|
|
story.append(Paragraph("Answer", source_title_style)) |
|
|
story.append(Paragraph(conv.answer, answer_style)) |
|
|
story.append(Spacer(1, 0.15*inch)) |
|
|
|
|
|
|
|
|
if conv.texts: |
|
|
story.append(Paragraph("Referenced Text Sources", source_title_style)) |
|
|
for i, txt in enumerate(conv.texts, 1): |
|
|
source_text = f"[{i}] {txt.file or 'Unknown Document'} (Page {txt.page or 'N/A'}, {round((txt.score or 0) * 100)}% match)" |
|
|
story.append(Paragraph(source_text, styles['Normal'])) |
|
|
|
|
|
excerpt = f'<i>"{txt.text[:300]}..."</i>' |
|
|
story.append(Paragraph(excerpt, styles['Normal'])) |
|
|
story.append(Spacer(1, 0.1*inch)) |
|
|
|
|
|
story.append(Spacer(1, 0.1*inch)) |
|
|
|
|
|
|
|
|
if conv.images: |
|
|
relevant_images = [img for img in conv.images if (img.score or 0) >= 0.3] |
|
|
if relevant_images: |
|
|
story.append(Paragraph("Referenced Images", source_title_style)) |
|
|
|
|
|
for i, img in enumerate(relevant_images[:3], 1): |
|
|
|
|
|
img_text = f"[{i}] {img.filename or 'Unknown'} from {img.file or 'Unknown Document'} (Page {img.page or 'N/A'}, {round((img.score or 0) * 100)}% match)" |
|
|
story.append(Paragraph(img_text, styles['Normal'])) |
|
|
|
|
|
|
|
|
if img.path: |
|
|
try: |
|
|
|
|
|
img_path = img.path.replace('/extracted_images/', '') |
|
|
full_img_path = os.path.join('extracted_images', img_path) |
|
|
|
|
|
if os.path.exists(full_img_path): |
|
|
|
|
|
rl_img = RLImage(full_img_path, width=4*inch, height=3*inch, kind='proportional') |
|
|
story.append(rl_img) |
|
|
story.append(Spacer(1, 0.1*inch)) |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not add image to PDF: {e}") |
|
|
|
|
|
story.append(Spacer(1, 0.15*inch)) |
|
|
|
|
|
|
|
|
if idx < len(request.conversations): |
|
|
story.append(Spacer(1, 0.2*inch)) |
|
|
|
|
|
line_table = Table([['']], colWidths=[6.5*inch]) |
|
|
line_table.setStyle(TableStyle([ |
|
|
('LINEABOVE', (0, 0), (-1, 0), 1, colors.HexColor('#334155')), |
|
|
])) |
|
|
story.append(line_table) |
|
|
story.append(Spacer(1, 0.2*inch)) |
|
|
|
|
|
|
|
|
doc.build(story) |
|
|
|
|
|
|
|
|
buffer.seek(0) |
|
|
|
|
|
|
|
|
return StreamingResponse( |
|
|
buffer, |
|
|
media_type="application/pdf", |
|
|
headers={ |
|
|
"Content-Disposition": f"attachment; filename=WHEC_Report_{datetime.now().strftime('%Y-%m-%d')}.pdf" |
|
|
} |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error generating PDF report: {str(e)}") |
|
|
raise HTTPException(status_code=500, detail=f"Error generating report: {str(e)}") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860) |