File size: 4,254 Bytes
2266343
 
 
 
 
 
 
6aee128
2266343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ba7476
 
 
 
2266343
 
 
 
 
 
 
 
6aee128
2266343
8ba7476
2266343
 
 
 
6aee128
2266343
8ba7476
2266343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""

FastAPI Application for Multimodal RAG System

US Army Medical Research Papers Q&A

"""

import os
import logging
import threading
from typing import List, Dict, Optional, Union
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

# Import from query_index (standalone)
from query_index import MultimodalRAGSystem

# Process-wide logging: emit INFO and above, each record prefixed with a
# timestamp, the logger name and the level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Global variables
# Shared RAG engine. ``lifespan`` assigns it at startup and resets it to None
# at shutdown; it also stays None when construction raised during startup.
rag_system: Optional[MultimodalRAGSystem] = None

# Lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG system on startup and release it on shutdown.

    A failed initialization is logged with its traceback and leaves
    ``rag_system`` as ``None``. The API still starts: ``/query`` answers 503
    and ``/health`` reports ``rag_initialized=False``.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global rag_system

    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
        logger.info("RAG system initialized successfully!")
    except Exception as e:
        # Deliberately keep serving in degraded mode instead of aborting
        # startup, but record the full traceback for diagnosis.
        logger.exception("Error during initialization: %s", e)
        rag_system = None

    try:
        yield
    finally:
        # The finally also runs when an exception is thrown into the lifespan
        # at the yield, so the reference is always dropped.
        logger.info("Shutting down RAG system...")
        rag_system = None

# Create FastAPI app
# The application object; ``lifespan`` builds the RAG system on startup and
# drops it on shutdown.
app = FastAPI(
    title="Multimodal RAG API",
    description="Q&A system for US Army medical research papers (Text + Images)",
    version="2.0.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True — as far
# as I know Starlette then echoes the caller's Origin for credentialed
# requests, so any site could make them. Confirm this is intended, or list the
# real frontend origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files
# Frontend assets; ``root`` serves static/index.html from this directory.
# NOTE(review): unlike the mounts below, this one is not guarded by an existence
# check — StaticFiles presumably refuses a missing directory at import time;
# confirm "static/" always ships with the app.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Mount extracted images
# This allows the frontend to load images via /extracted_images/filename.jpg
# Only mounted when the directory exists, so the app still starts without it.
if os.path.exists("extracted_images"):
    app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")

# Mount PDF documents
# Serves every file under WHEC_Documents/ publicly at /documents/<name>
# (optional, like the images mount above).
if os.path.exists("WHEC_Documents"):
    app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")

# Pydantic models
class QueryRequest(BaseModel):
    """Request body of ``POST /query``."""

    question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")

class ImageSource(BaseModel):
    """An image retrieved as supporting evidence for an answer.

    Every field defaults to ``None``. Under Pydantic v2 an ``Optional[...]``
    annotation without a default is still *required*, so a source dict that
    merely omits a key would fail ``QueryResponse`` validation. That would turn
    the whole query into a 500. Under Pydantic v1 the explicit default changes
    nothing.
    """

    path: Optional[str] = None
    filename: Optional[str] = None
    score: Optional[float] = None
    page: Optional[Union[str, int]] = None  # could be int or str depending on metadata
    file: Optional[str] = None
    link: Optional[str] = None

class TextSource(BaseModel):
    """A text chunk retrieved as supporting evidence for an answer.

    ``text`` and ``score`` are required. The location fields default to
    ``None`` for the same reason as in ``ImageSource``.
    """

    text: str
    score: float
    page: Optional[Union[str, int]] = None  # could be int or str depending on metadata
    file: Optional[str] = None
    link: Optional[str] = None

class QueryResponse(BaseModel):
    """Response of ``POST /query``: the answer plus the sources used, echoing the question."""

    answer: str
    images: List[ImageSource]
    texts: List[TextSource]
    question: str

class HealthResponse(BaseModel):
    """Response of ``GET /health``."""

    status: str
    rag_initialized: bool  # False when RAG startup failed or after shutdown

# API Endpoints

@app.get("/", tags=["Root"])
async def root():
    """Return the single-page frontend, ``static/index.html``."""
    index_html = 'static/index.html'
    return FileResponse(index_html)

@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Report liveness and whether the RAG engine is loaded.

    ``status`` is always ``"healthy"``. ``rag_initialized`` is ``False`` when
    startup initialization failed, or after shutdown.
    """
    initialized = rag_system is not None
    return HealthResponse(rag_initialized=initialized, status="healthy")

# ``ask`` runs in a worker thread (see query_rag) so it no longer blocks the
# event loop. This lock keeps calls one-at-a-time, as they were when ``ask``
# ran directly on the loop, because the engine's thread-safety is not known.
_ask_lock = threading.Lock()

def _ask_serialized(system, question):
    """Call ``system.ask(question)`` while holding ``_ask_lock``."""
    with _ask_lock:
        return system.ask(question)

@app.post("/query", response_model=QueryResponse, tags=["Query"])
async def query_rag(request: QueryRequest):
    """
    Query the RAG system.

    Passes the question to ``rag_system.ask``. Returns its answer together with
    the retrieved image and text sources, echoing the question.

    Raises:
        HTTPException: 503 if the RAG system was not initialized at startup;
            500 if answering or building the response fails for any reason.
    """
    # Snapshot the global: shutdown resets it to None, and an identity check
    # avoids depending on the engine object's truthiness.
    system = rag_system
    if system is None:
        raise HTTPException(
            status_code=503,
            detail="RAG system not initialized. Check logs for errors."
        )

    try:
        # ``ask`` is a plain synchronous call. Running it in the thread pool
        # keeps health checks and static files responsive during long queries.
        result = await run_in_threadpool(_ask_serialized, system, request.question)

        return QueryResponse(
            answer=result['answer'],
            images=result['images'],
            texts=result['texts'],
            question=request.question
        )

    except Exception as e:
        logger.exception("Error processing query: %s", e)
        # NOTE(review): the raw exception text is returned to the client;
        # consider a generic message if internals must not leak.
        raise HTTPException(status_code=500, detail=f"Error processing query: {e}") from e

if __name__ == "__main__":
    import uvicorn

    # Bind on all network interfaces (not just localhost) on port 7860.
    listen_host = "0.0.0.0"
    listen_port = 7860
    uvicorn.run(app, host=listen_host, port=listen_port)