"""Medical Q&A Pointer-Generator API.

FastAPI service wrapping a UMLS-based Pointer-Generator T5 network for
medical question answering. Models are loaded once at startup into
module-level globals and served through /api/answer and /api/batch.
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, List
import torch
import torch.nn as nn
from transformers import T5Tokenizer, T5ForConditionalGeneration
import spacy
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import model handler (project-local module providing the network + QA facade)
from model_handler import PointerGeneratorT5, MedicalQAProcessor

# Initialize FastAPI
app = FastAPI(
    title="Medical Q&A Pointer-Generator API",
    description="UMLS-based Pointer-Generator Network for Medical Question Answering",
    version="1.0.0"
)

# Add CORS middleware
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec -- restrict origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global singletons populated once by load_models() at startup.
device = None        # torch.device chosen at startup (cuda if available)
tokenizer = None     # T5 tokenizer
model = None         # PointerGeneratorT5 instance (eval mode)
nlp = None           # spaCy pipeline used for entity display
qa_processor = None  # MedicalQAProcessor facade over model/tokenizer/nlp

DEFAULT_MAX_LENGTH = 100           # max generated answer length passed to the processor
DEFAULT_SENTENCE_STRUCTURE = True  # request sentence-structured output from the processor


# Pydantic models for request/response
class QuestionRequest(BaseModel):
    question: str = Field(..., description="Medical question to answer",
                          example="What are the symptoms of diabetes?")
    context: str = Field(..., description="Medical context/passage",
                         example="Diabetes is characterized by high blood sugar...")


class QuestionResponse(BaseModel):
    question: str
    answer: str
    p_gen_score: str
    entities_detected: List[dict]


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    device: str


def _load_nlp():
    """Load the spaCy pipeline, preferring the scispaCy medical model.

    Tries a best-effort install + load of ``en_core_sci_sm``; on any failure
    falls back to ``en_core_web_sm``, downloading it if it is not installed.

    Returns:
        The loaded spaCy ``Language`` pipeline.
    """
    import subprocess
    import sys

    try:
        # Best-effort install of scispaCy's small medical model. check=False:
        # an install failure is handled by the fallback below, not raised here.
        # Use sys.executable -m pip so the install targets the serving
        # interpreter's environment (a bare "pip" on PATH may not).
        subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz"],
            check=False,
        )
        pipeline = spacy.load("en_core_sci_sm")
        logger.info("scispaCy medical model loaded successfully")
        return pipeline
    except Exception as e:
        logger.warning(f"scispaCy failed to load: {e}")
        logger.info("Falling back to standard spaCy model...")

    try:
        pipeline = spacy.load("en_core_web_sm")
        logger.info("Standard spaCy model loaded")
    except OSError:
        # spacy.load raises OSError when the model package is missing.
        # Download via the current interpreter; check=True so a failed
        # download surfaces here instead of as a confusing second load error.
        subprocess.run(
            [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
            check=True,
        )
        pipeline = spacy.load("en_core_web_sm")
        logger.info("Standard spaCy model downloaded and loaded")
    return pipeline


@app.on_event("startup")
async def load_models():
    """Load models on startup.

    Populates the module-level globals (device, tokenizer, model, nlp,
    qa_processor). Any failure is logged and re-raised so the server fails
    fast rather than serving with a broken model.
    """
    global device, tokenizer, model, nlp, qa_processor
    try:
        logger.info("Loading models...")

        # Set device
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # Load T5 tokenizer and model
        logger.info("Loading T5 model")
        tokenizer = T5Tokenizer.from_pretrained('t5-base')
        model = PointerGeneratorT5('t5-base').to(device)
        model.eval()  # inference only: disable dropout etc.
        logger.info("T5 model loaded successfully")

        # Load spacy model with scispaCy priority
        logger.info("Loading medical NER model")
        nlp = _load_nlp()

        # Initialize QA processor
        qa_processor = MedicalQAProcessor(model, tokenizer, device, nlp)
        logger.info("Medical Q&A Processor initialized successfully")
        logger.info("API READY TO SERVE REQUESTS")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise


def _answer_one(question: str, context: str) -> QuestionResponse:
    """Run NER + answer generation for a single question/context pair.

    Shared by /api/answer and /api/batch; assumes the startup hook has
    populated ``nlp`` and ``qa_processor``.
    """
    doc = nlp(question)
    entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
    result = qa_processor.generate_answer(
        question=question,
        context=context,
        max_length=DEFAULT_MAX_LENGTH,
        use_sentence_structure=DEFAULT_SENTENCE_STRUCTURE,
    )
    return QuestionResponse(
        question=question,
        answer=result['answer'],
        p_gen_score=result['p_gen_score'],
        entities_detected=entities,
    )


@app.get("/", response_model=dict)
async def root():
    """Root endpoint"""
    return {
        "message": "Medical Q&A Pointer-Generator API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "answer": "/api/answer"
        }
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy" if qa_processor is not None else "unhealthy",
        "model_loaded": qa_processor is not None,
        "device": str(device)
    }


@app.post("/api/answer", response_model=QuestionResponse)
async def answer_question(request: QuestionRequest):
    """Answer a single medical question against the supplied context.

    Raises:
        HTTPException: 503 if models are not loaded, 400 on empty
            question/context, 500 on any processing failure.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        # Validate inputs
        if not request.question.strip():
            raise HTTPException(status_code=400, detail="Question cannot be empty")
        if not request.context.strip():
            raise HTTPException(status_code=400, detail="Context cannot be empty")

        return _answer_one(request.question, request.context)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.post("/api/batch", response_model=List[QuestionResponse])
async def answer_batch(requests: List[QuestionRequest]):
    """Answer up to 10 questions in one request.

    Raises:
        HTTPException: 503 if models are not loaded, 400 if more than 10
            questions are submitted, 500 on any processing failure.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        if len(requests) > 10:
            raise HTTPException(status_code=400, detail="Maximum 10 questions per batch")

        return [_answer_one(req.question, req.context) for req in requests]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing batch: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)