# pgn-medical-qa / app.py
# (Hugging Face Space page header: dev2004v — "Update app.py", commit 9702570, verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spacy
import logging
# Configure module-level logging so startup and inference progress is
# visible in the service logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Project-local model wrapper and QA pipeline (defined in model_handler.py).
from model_handler import PointerGeneratorT5, MedicalQAProcessor

# FastAPI application metadata (shown in the auto-generated /docs UI).
app = FastAPI(
    title="Medical Q&A Pointer-Generator API",
    description="UMLS-based Pointer-Generator Network for Medical Question Answering",
    version="1.0.0"
)

# Allow any origin/method/header so browser front-ends hosted anywhere can
# call the API.  NOTE(review): wildcard origins combined with
# allow_credentials=True is very permissive — tighten for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model state, populated once by the startup hook below.
device = None        # torch.device chosen at startup (cuda or cpu)
tokenizer = None     # T5Tokenizer
model = None         # PointerGeneratorT5 instance
nlp = None           # spaCy pipeline used for entity display
qa_processor = None  # MedicalQAProcessor; None means "not ready yet"

# Generation defaults shared by both answer endpoints.
DEFAULT_MAX_LENGTH = 100
DEFAULT_SENTENCE_STRUCTURE = True
class QuestionRequest(BaseModel):
    """Request body for /api/answer (and each item of /api/batch)."""
    question: str = Field(..., description="Medical question to answer", example="What are the symptoms of diabetes?")
    context: str = Field(..., description="Medical context/passage", example="Diabetes is characterized by high blood sugar...")
class QuestionResponse(BaseModel):
    """Response body: generated answer plus diagnostic metadata."""
    question: str                 # echoed back from the request
    answer: str                   # text produced by the pointer-generator model
    p_gen_score: str              # pointer/generator mixing score (string as produced by the processor)
    entities_detected: List[dict] # entities found in the question: {"text": ..., "label": ...}
class HealthResponse(BaseModel):
    """Response body for /health."""
    status: str        # "healthy" or "unhealthy"
    model_loaded: bool # True once the startup hook finished initializing qa_processor
    device: str        # str(torch.device), e.g. "cuda" or "cpu"
@app.on_event("startup")
async def load_models():
    """Load the model stack on application startup.

    Populates the module-level globals (device, tokenizer, model, nlp,
    qa_processor) that the request handlers read.  Any failure is logged
    and re-raised so the app does not start half-initialized.

    Fixes over the previous version:
    - invoke pip/spacy via ``sys.executable -m`` instead of relying on
      bare ``pip``/``python`` binaries being on PATH;
    - catch the specific ``OSError`` that ``spacy.load`` raises for a
      missing model instead of a bare ``except:``;
    - run the spaCy download with ``check=True`` so a failed download
      surfaces immediately rather than as a confusing load error.
    """
    global device, tokenizer, model, nlp, qa_processor
    try:
        logger.info("Loading models...")

        # Prefer GPU when one is available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # T5 tokenizer + pointer-generator wrapper; eval mode because this
        # service only performs inference.
        logger.info("Loading T5 model")
        tokenizer = T5Tokenizer.from_pretrained('t5-base')
        model = PointerGeneratorT5('t5-base')
        model = model.to(device)
        model.eval()
        logger.info("T5 model loaded successfully")

        # Load a spaCy pipeline, preferring the scispaCy medical model.
        logger.info("Loading medical NER model")
        import subprocess
        import sys
        try:
            # NOTE(review): installing a package at runtime is fragile;
            # en_core_sci_sm should ideally be a build-time dependency.
            # check=False keeps a failed install non-fatal — spacy.load
            # below decides whether the model is actually usable.
            subprocess.run(
                [sys.executable, "-m", "pip", "install",
                 "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz"],
                check=False,
            )
            nlp = spacy.load("en_core_sci_sm")
            logger.info("scispaCy medical model loaded successfully")
        except Exception as e:
            logger.warning(f"scispaCy failed to load: {e}")
            logger.info("Falling back to standard spaCy model...")
            try:
                nlp = spacy.load("en_core_web_sm")
                logger.info("Standard spaCy model loaded")
            except OSError:
                # spacy.load raises OSError when the model package is not
                # installed: download it, then retry the load.
                subprocess.run(
                    [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
                    check=True,
                )
                nlp = spacy.load("en_core_web_sm")
                logger.info("Standard spaCy model downloaded and loaded")

        # Wire everything into the processor the endpoints use.
        qa_processor = MedicalQAProcessor(model, tokenizer, device, nlp)
        logger.info("Medical Q&A Processor initialized successfully")
        logger.info("API READY TO SERVE REQUESTS")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise
@app.get("/", response_model=dict)
async def root():
    """Return basic service metadata and the available endpoint paths."""
    available = {
        "health": "/health",
        "answer": "/api/answer",
    }
    return {
        "message": "Medical Q&A Pointer-Generator API",
        "version": "1.0.0",
        "endpoints": available,
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model stack finished loading and which device is used."""
    ready = qa_processor is not None
    return {
        "status": "healthy" if ready else "unhealthy",
        "model_loaded": ready,
        "device": str(device),
    }
@app.post("/api/answer", response_model=QuestionResponse)
async def answer_question(request: QuestionRequest):
    """Answer a single medical question against the supplied context.

    Raises 503 when the model is not loaded, 400 on blank inputs, and
    500 (with the message in the detail) on unexpected failures.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Reject blank inputs before running the expensive model.
        if not request.question.strip():
            raise HTTPException(status_code=400, detail="Question cannot be empty")
        if not request.context.strip():
            raise HTTPException(status_code=400, detail="Context cannot be empty")

        # Named entities from the question, returned for display purposes.
        detected = [
            {"text": ent.text, "label": ent.label_}
            for ent in nlp(request.question).ents
        ]

        generation = qa_processor.generate_answer(
            question=request.question,
            context=request.context,
            max_length=DEFAULT_MAX_LENGTH,
            use_sentence_structure=DEFAULT_SENTENCE_STRUCTURE,
        )

        return QuestionResponse(
            question=request.question,
            answer=generation['answer'],
            p_gen_score=generation['p_gen_score'],
            entities_detected=detected,
        )
    except HTTPException:
        # Re-raise intentional HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/api/batch", response_model=List[QuestionResponse])
async def answer_batch(requests: List[QuestionRequest]):
    """Answer up to 10 questions in a single call.

    Mirrors /api/answer for each item: validates inputs, extracts display
    entities from the question, and generates an answer.  Raises 503 when
    the model is not loaded, 400 on bad input (oversized batch or a blank
    question/context), and 500 on unexpected failures.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        if len(requests) > 10:
            raise HTTPException(status_code=400, detail="Maximum 10 questions per batch")
        results = []
        for i, req in enumerate(requests):
            # Validate each item the same way the single-question endpoint
            # does, so batch callers get a clear 400 instead of a model error.
            if not req.question.strip():
                raise HTTPException(status_code=400, detail=f"Question at index {i} cannot be empty")
            if not req.context.strip():
                raise HTTPException(status_code=400, detail=f"Context at index {i} cannot be empty")
            # Entities from the question, for display in the response.
            doc = nlp(req.question)
            entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
            result = qa_processor.generate_answer(
                question=req.question,
                context=req.context,
                max_length=DEFAULT_MAX_LENGTH,
                use_sentence_structure=DEFAULT_SENTENCE_STRUCTURE,
            )
            results.append(QuestionResponse(
                question=req.question,
                answer=result['answer'],
                p_gen_score=result['p_gen_score'],
                entities_detected=entities,
            ))
        return results
    except HTTPException:
        # Re-raise intentional HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"Error processing batch: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# Run a local development server when executed directly.
# Port 7860 is the Hugging Face Spaces convention.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)