|
|
""" |
|
|
FastAPI Application for Multimodal RAG System |
|
|
US Army Medical Research Papers Q&A |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from typing import List, Dict, Optional, Union |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import FileResponse |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
from query_index import MultimodalRAGSystem |
|
|
|
|
|
|
|
|
# Configure the root logger at import time: INFO and above, with a
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger named after this module; used by the handlers defined below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# Process-wide RAG engine. It is assigned by the lifespan handler and is
# None before startup, after shutdown, or when construction failed.
rag_system: Optional[MultimodalRAGSystem] = None

# Text of the most recent question/answer pair; the /query handler prepends
# it to the next prompt. None until the first successful query.
last_qa_context: Optional[str] = None
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the RAG system on startup and release it on shutdown.

    Startup is best-effort: if ``MultimodalRAGSystem()`` raises, the failure
    is logged together with its traceback and the application still starts,
    with the module-level ``rag_system`` left as ``None``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used by the body.

    Yields:
        None. Control is handed back to the server while the app is running.
    """
    global rag_system

    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
    except Exception as e:
        # Deliberately broad: a failed initialization must not prevent the
        # server from starting. logger.exception records the traceback,
        # which a plain logger.error call would drop.
        logger.exception("Error during initialization: %s", e)
        rag_system = None
    else:
        logger.info("RAG system initialized successfully!")

    try:
        yield
    finally:
        # Release the reference even when an exception is thrown into the
        # generator at the yield point.
        logger.info("Shutting down RAG system...")
        rag_system = None
|
|
|
|
|
|
|
|
# The ASGI application. Startup and shutdown work is delegated to the
# ``lifespan`` context manager defined above.
app = FastAPI(
    lifespan=lifespan,
    title="Multimodal RAG API",
    version="2.0.0",
    description="Q&A system for US Army medical research papers (Text + Images)",
)
|
|
|
|
|
|
|
|
# Cross-origin policy: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # Credentials stay disabled while the origin list is a wildcard. The
    # FastAPI CORS documentation states that allow_origins must not be
    # ["*"] when allow_credentials is True; to allow credentialed
    # cross-origin requests, list the trusted origins explicitly first.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# The frontend assets are mounted unconditionally.
app.mount("/static", StaticFiles(directory="static"), name="static")

# The image and document folders are optional: each one is exposed only if
# its path exists (relative to the working directory) at import time.
for _route, _folder, _mount_name in (
    ("/extracted_images", "extracted_images", "images"),
    ("/documents", "WHEC_Documents", "documents"),
):
    if os.path.exists(_folder):
        app.mount(_route, StaticFiles(directory=_folder), name=_mount_name)
|
|
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /query``."""

    # Required free-text question, limited to 1..1000 characters.
    question: str = Field(
        ...,
        description="Question to ask",
        max_length=1000,
        min_length=1,
    )
|
|
|
|
|
class ImageSource(BaseModel):
    """One image entry in the ``images`` list of a query response.

    Every field is nullable and now carries an explicit ``None`` default.
    Without the default, Pydantic v2 treats an ``Optional[...]`` field as
    required (a missing key fails validation), whereas Pydantic v1 silently
    defaulted it to ``None``. Spelling the default out gives the same
    behavior on both major versions and matches the existing ``link`` field.
    """

    path: Optional[str] = None
    filename: Optional[str] = None
    score: Optional[float] = None
    # The page may be given either as a string or as an integer.
    page: Optional[Union[str, int]] = None
    file: Optional[str] = None
    link: Optional[str] = None
|
|
|
|
|
class TextSource(BaseModel):
    """One text entry in the ``texts`` list of a query response.

    ``text`` and ``score`` are required. The nullable fields carry an
    explicit ``None`` default: without it, Pydantic v2 treats an
    ``Optional[...]`` field as required, whereas Pydantic v1 defaulted it to
    ``None``. The explicit default behaves the same on both major versions
    and matches the existing ``link`` field.
    """

    text: str
    score: float
    # The page may be given either as a string or as an integer.
    page: Optional[Union[str, int]] = None
    file: Optional[str] = None
    link: Optional[str] = None
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response body returned by ``POST /query``."""

    # Answer text taken from the RAG result.
    answer: str
    # Image entries taken from the RAG result.
    images: List[ImageSource]
    # Text entries taken from the RAG result.
    texts: List[TextSource]
    # The question exactly as it was submitted in the request.
    question: str
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body returned by ``GET /health``."""

    # Service status string reported by the health endpoint.
    status: str
    # True when the module-level RAG system object has been created.
    rag_initialized: bool
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", tags=["Root"])
async def root():
    """Return the frontend entry page, ``static/index.html``."""
    index_page = "static/index.html"
    return FileResponse(index_page)
|
|
|
|
|
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Report the service status and whether the RAG system is available.

    ``status`` is always ``"healthy"``; ``rag_initialized`` reflects whether
    the module-level ``rag_system`` is currently set.
    """
    is_ready = rag_system is not None
    return HealthResponse(status="healthy", rag_initialized=is_ready)
|
|
|
|
|
@app.post("/query", response_model=QueryResponse, tags=["Query"])
async def query_rag(request: QueryRequest):
    """Answer a question with the RAG system.

    If a previous question/answer pair is stored, it is prepended to the
    prompt so that the new question is treated as a follow-up. After the RAG
    system has answered, the stored pair is replaced with the current one.

    NOTE(review): ``last_qa_context`` is a single module-level value, so the
    follow-up context is shared by every client of this process. Confirm
    that this is intended before serving more than one user.

    Args:
        request: Validated request body carrying the question.

    Returns:
        A ``QueryResponse`` with the answer, the image and text sources from
        the RAG result, and the question as submitted.

    Raises:
        HTTPException: 503 when the RAG system has not been initialized;
            500 when answering the question fails for any reason.
    """
    global last_qa_context

    # Identity check, consistent with the health endpoint: a RAG object
    # that happens to evaluate as falsy must not be reported as missing.
    if rag_system is None:
        raise HTTPException(
            status_code=503,
            detail="RAG system not initialized. Check logs for errors."
        )

    try:
        if last_qa_context:
            prompt = (
                "Previous question and answer:\n"
                f"{last_qa_context}\n\n"
                "Follow up question:\n"
                f"{request.question}"
            )
        else:
            prompt = request.question

        result = rag_system.ask(prompt)

        # Only the bare question is remembered, not the combined prompt, so
        # the stored context does not grow from one request to the next.
        last_qa_context = (
            f"Question: {request.question}\n"
            f"Answer: {result['answer']}"
        )

        return QueryResponse(
            answer=result['answer'],
            images=result['images'],
            texts=result['texts'],
            question=request.question
        )

    except Exception as e:
        # logger.exception keeps the traceback in the log; chaining with
        # "from e" preserves the original cause on the raised HTTPException.
        logger.exception("Error processing query: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error processing query: {str(e)}"
        ) from e
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so that uvicorn is only needed when this file is run
    # directly as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )