File size: 11,955 Bytes
da2d4b2
 
 
 
 
ef5aecf
 
 
 
 
da2d4b2
ef5aecf
da2d4b2
 
 
ef5aecf
da2d4b2
 
 
ef5aecf
da2d4b2
 
 
 
 
 
ef5aecf
 
 
 
 
 
 
 
 
da2d4b2
 
 
 
 
 
 
 
 
 
31b06d3
 
 
da2d4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef5aecf
da2d4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef5aecf
 
 
 
 
 
 
 
 
 
da2d4b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31b06d3
9fd10cc
da2d4b2
9fd10cc
da2d4b2
9fd10cc
 
ef5aecf
31b06d3
9fd10cc
da2d4b2
31b06d3
 
 
 
 
 
 
9fd10cc
da2d4b2
9fd10cc
 
 
da2d4b2
 
9fd10cc
da2d4b2
 
9fd10cc
 
ef5aecf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31b06d3
da2d4b2
 
2266343
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""
FastAPI Application for Multimodal RAG System
US Army Medical Research Papers Q&A
"""

"""
FastAPI Application for Multimodal RAG System
US Army Medical Research Papers Q&A
"""

import os
import io
import logging
from typing import List, Dict, Optional, Union
from contextlib import asynccontextmanager
from datetime import datetime

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

# Import from query_index (standalone)
from query_index import MultimodalRAGSystem

# PDF generation libraries
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.lib.enums import TA_CENTER, TA_JUSTIFY
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak, Image as RLImage, Table, TableStyle
from reportlab.lib import colors
import requests

# Configure logging: INFO level, timestamped records tagged with logger name
# and severity.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global variables
# The single RAG system instance; assigned in `lifespan` and left as None when
# construction fails, which the endpoints treat as "not initialized".
rag_system: Optional[MultimodalRAGSystem] = None

# Keep a short conversation history.
# NOTE: this is one module-level list, so every caller of /query reads and
# appends to the same history. Entries are {"role": ..., "content": ...} dicts.
chat_history: List[Dict[str, str]] = []
# Number of question/answer turns retained; the list is trimmed to twice this
# many entries because each turn stores a user and an assistant message.
MAX_HISTORY_TURNS = 3

# Lifecycle management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the global RAG system on startup and release it on shutdown.

    Construction failures are deliberately not fatal: the error is logged
    (with traceback) and ``rag_system`` stays ``None`` so the app still
    starts and the endpoints can report that the system is unavailable.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global rag_system

    logger.info("Starting RAG system initialization...")
    try:
        rag_system = MultimodalRAGSystem()
    except Exception as e:
        # logger.exception keeps the traceback, which the plain error log lost.
        logger.exception("Error during initialization: %s", e)
        rag_system = None
    else:
        logger.info("RAG system initialized successfully!")

    try:
        yield
    finally:
        # Run cleanup even when an exception is thrown into the generator.
        logger.info("Shutting down RAG system...")
        rag_system = None

# Create FastAPI app; `lifespan` builds and tears down the global RAG system.
app = FastAPI(
    title="Multimodal RAG API",
    description="Q&A system for US Army medical research papers (Text + Images)",
    version="2.0.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is very permissive — confirm this is intended
# outside local development and list explicit origins otherwise.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files
# NOTE(review): unlike the two mounts below, this one is not guarded by an
# existence check; presumably the "static" directory is always shipped since
# the root endpoint serves static/index.html — confirm.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Mount extracted images
# This allows the frontend to load images via /extracted_images/filename.jpg
# Skipped when the directory does not exist at import time.
if os.path.exists("extracted_images"):
    app.mount("/extracted_images", StaticFiles(directory="extracted_images"), name="images")

# Mount PDF documents (served under /documents); skipped when the directory
# does not exist at import time.
if os.path.exists("WHEC_Documents"):
    app.mount("/documents", StaticFiles(directory="WHEC_Documents"), name="documents")

# Pydantic models
class QueryRequest(BaseModel):
    """Request body for POST /query."""

    # Required, between 1 and 1000 characters long.
    question: str = Field(..., min_length=1, max_length=1000, description="Question to ask")

class ImageSource(BaseModel):
    """An image cited as a source for an answer.

    Every field is optional and defaults to ``None``. The explicit defaults
    keep the model tolerant of missing keys under both Pydantic v1 and v2
    (v2 treats an ``Optional`` annotation without a default as required),
    and make the other fields consistent with ``link``.
    """

    # Image location; the report generator strips a leading
    # '/extracted_images/' and looks the rest up under extracted_images/.
    path: Optional[str] = None
    # Display name of the image file.
    filename: Optional[str] = None
    # Match score; the report shows it as a percentage (score * 100).
    score: Optional[float] = None
    # Page reference, either numeric or textual.
    page: Optional[Union[str, int]] = None
    # Name of the document the image came from.
    file: Optional[str] = None
    # Optional link associated with the source.
    link: Optional[str] = None

class TextSource(BaseModel):
    """A text passage cited as a source for an answer.

    ``text`` and ``score`` are required. The optional fields default to
    ``None`` explicitly so they behave the same under Pydantic v1 and v2
    (v2 treats an ``Optional`` annotation without a default as required)
    and are consistent with ``link``.
    """

    # The passage itself; the report prints at most its first 300 characters.
    text: str
    # Match score; the report shows it as a percentage (score * 100).
    score: float
    # Page reference, either numeric or textual.
    page: Optional[Union[str, int]] = None
    # Name of the document the passage came from.
    file: Optional[str] = None
    # Optional link associated with the source.
    link: Optional[str] = None

class QueryResponse(BaseModel):
    """Response body for POST /query."""

    # Answer text taken from the RAG system's result.
    answer: str
    # Image sources returned alongside the answer.
    images: List[ImageSource]
    # Text sources returned alongside the answer.
    texts: List[TextSource]
    # The question as it was submitted, echoed back.
    question: str

class HealthResponse(BaseModel):
    """Response body for GET /health."""

    # Service status string; the health endpoint always sends "healthy".
    status: str
    # True when the global RAG system object exists.
    rag_initialized: bool

class ConversationItem(BaseModel):
    """One question/answer exchange, with its sources, to include in a report."""

    question: str
    answer: str
    # Sources shown with the answer; same shapes as in QueryResponse.
    images: List[ImageSource]
    texts: List[TextSource]
    # Client-supplied timestamp string; accepted as-is, no format is enforced here.
    timestamp: str

class ReportRequest(BaseModel):
    """Request body for POST /generate-report."""

    # Exchanges to render, in the order they should appear in the PDF.
    conversations: List[ConversationItem]

# API Endpoints

@app.get("/", tags=["Root"])
async def root():
    """Serve the frontend application"""
    return FileResponse('static/index.html')

@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        rag_initialized=rag_system is not None
    )

@app.post("/query", response_model=QueryResponse, tags=["Query"])
async def query_rag(request: QueryRequest):
    global chat_history

    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized.")

    try:
        result = rag_system.ask(
            query_str=request.question,
            chat_history=chat_history
        )

        # Update history
        chat_history.append({"role": "user", "content": request.question})
        chat_history.append({"role": "assistant", "content": result["answer"]})

        # Trim history safely
        if len(chat_history) > MAX_HISTORY_TURNS * 2:
            chat_history = chat_history[-MAX_HISTORY_TURNS * 2:]

        return QueryResponse(
            answer=result["answer"],
            images=result["images"],
            texts=result["texts"],
            question=request.question
        )

    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate-report", tags=["Report"])
async def generate_report(request: ReportRequest):
    """Generate a PDF report from conversation data"""
    try:
        buffer = io.BytesIO()
        
        # Create PDF document
        doc = SimpleDocTemplate(
            buffer,
            pagesize=letter,
            rightMargin=0.75*inch,
            leftMargin=0.75*inch,
            topMargin=1*inch,
            bottomMargin=0.75*inch
        )
        
        # Container for the 'Flowable' objects
        story = []
        
        # Define styles
        styles = getSampleStyleSheet()
        
        # Custom styles
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=12,
            alignment=TA_CENTER,
            fontName='Helvetica-Bold'
        )
        
        subtitle_style = ParagraphStyle(
            'CustomSubtitle',
            parent=styles['Normal'],
            fontSize=12,
            textColor=colors.HexColor('#64748b'),
            spaceAfter=20,
            alignment=TA_CENTER
        )
        
        question_style = ParagraphStyle(
            'QuestionStyle',
            parent=styles['Heading2'],
            fontSize=14,
            textColor=colors.HexColor('#3b82f6'),
            spaceAfter=10,
            fontName='Helvetica-Bold'
        )
        
        answer_style = ParagraphStyle(
            'AnswerStyle',
            parent=styles['Normal'],
            fontSize=11,
            alignment=TA_JUSTIFY,
            spaceAfter=12
        )
        
        source_title_style = ParagraphStyle(
            'SourceTitle',
            parent=styles['Heading3'],
            fontSize=11,
            textColor=colors.HexColor('#8b5cf6'),
            spaceAfter=6,
            fontName='Helvetica-Bold'
        )
        
        # Add title
        story.append(Paragraph("WHEC Research Assistant", title_style))
        story.append(Paragraph("Conversation Report", title_style))
        story.append(Spacer(1, 0.2*inch))
        
        # Add metadata
        current_time = datetime.now().strftime("%B %d, %Y at %I:%M %p")
        story.append(Paragraph(f"Generated: {current_time}", subtitle_style))
        story.append(Paragraph("Source: WHEC (Warrior Heat- and Exertion-Related Events Collaborative)", subtitle_style))
        story.append(Paragraph(f"Total Questions: {len(request.conversations)}", subtitle_style))
        
        story.append(Spacer(1, 0.3*inch))
        
        # Add horizontal line
        story.append(Spacer(1, 0.1*inch))
        
        # Process each conversation
        for idx, conv in enumerate(request.conversations, 1):
            # Question
            story.append(Paragraph(f"Question {idx}", question_style))
            story.append(Paragraph(conv.question, answer_style))
            story.append(Spacer(1, 0.15*inch))
            
            # Answer
            story.append(Paragraph("Answer", source_title_style))
            story.append(Paragraph(conv.answer, answer_style))
            story.append(Spacer(1, 0.15*inch))
            
            # Text Sources
            if conv.texts:
                story.append(Paragraph("Referenced Text Sources", source_title_style))
                for i, txt in enumerate(conv.texts, 1):
                    source_text = f"[{i}] {txt.file or 'Unknown Document'} (Page {txt.page or 'N/A'}, {round((txt.score or 0) * 100)}% match)"
                    story.append(Paragraph(source_text, styles['Normal']))
                    
                    excerpt = f'<i>"{txt.text[:300]}..."</i>'
                    story.append(Paragraph(excerpt, styles['Normal']))
                    story.append(Spacer(1, 0.1*inch))
                
                story.append(Spacer(1, 0.1*inch))
            
            # Image Sources
            if conv.images:
                relevant_images = [img for img in conv.images if (img.score or 0) >= 0.3]
                if relevant_images:
                    story.append(Paragraph("Referenced Images", source_title_style))
                    
                    for i, img in enumerate(relevant_images[:3], 1):  # Limit to 3 images
                        # Add image metadata
                        img_text = f"[{i}] {img.filename or 'Unknown'} from {img.file or 'Unknown Document'} (Page {img.page or 'N/A'}, {round((img.score or 0) * 100)}% match)"
                        story.append(Paragraph(img_text, styles['Normal']))
                        
                        # Try to add the actual image
                        if img.path:
                            try:
                                # Construct the full path to the image
                                img_path = img.path.replace('/extracted_images/', '')
                                full_img_path = os.path.join('extracted_images', img_path)
                                
                                if os.path.exists(full_img_path):
                                    # Add image with max width of 4 inches
                                    rl_img = RLImage(full_img_path, width=4*inch, height=3*inch, kind='proportional')
                                    story.append(rl_img)
                                    story.append(Spacer(1, 0.1*inch))
                            except Exception as e:
                                logger.warning(f"Could not add image to PDF: {e}")
                    
                    story.append(Spacer(1, 0.15*inch))
            
            # Add separator between conversations (except after last one)
            if idx < len(request.conversations):
                story.append(Spacer(1, 0.2*inch))
                # Add a horizontal line as separator
                line_table = Table([['']], colWidths=[6.5*inch])
                line_table.setStyle(TableStyle([
                    ('LINEABOVE', (0, 0), (-1, 0), 1, colors.HexColor('#334155')),
                ]))
                story.append(line_table)
                story.append(Spacer(1, 0.2*inch))
        
        # Build PDF
        doc.build(story)
        
        # Get PDF from buffer
        buffer.seek(0)
        
        # Return PDF as download
        return StreamingResponse(
            buffer,
            media_type="application/pdf",
            headers={
                "Content-Disposition": f"attachment; filename=WHEC_Report_{datetime.now().strftime('%Y-%m-%d')}.pdf"
            }
        )
        
    except Exception as e:
        logger.error(f"Error generating PDF report: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating report: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)