# Medical Q&A Pointer-Generator API — FastAPI service entry point.
# (The "Spaces: Sleeping" banner above was Hugging Face page residue from a
# copy-paste, not application code.)
# Standard library
import logging
import sys
from typing import List, Optional

# Third-party
import spacy
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Project-local model handler
from model_handler import PointerGeneratorT5, MedicalQAProcessor

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize the FastAPI application with OpenAPI metadata.
app = FastAPI(
    title="Medical Q&A Pointer-Generator API",
    description="UMLS-based Pointer-Generator Network for Medical Question Answering",
    version="1.0.0"
)

# Allow cross-origin requests from any origin (browser demo clients).
# NOTE(review): per the CORS spec, browsers reject the combination of a
# wildcard origin with allow_credentials=True — confirm whether credentialed
# requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for models, populated by load_models() at startup.
# Endpoint handlers treat `qa_processor is None` as "service not ready".
device = None        # torch.device chosen at startup ('cuda' if available, else 'cpu')
tokenizer = None     # T5Tokenizer loaded from 't5-base'
model = None         # PointerGeneratorT5 wrapper, moved to `device`, eval mode
nlp = None           # spaCy pipeline (scispaCy if available) for entity extraction
qa_processor = None  # MedicalQAProcessor combining model, tokenizer, device, nlp

# Generation defaults shared by the single and batch answer endpoints.
DEFAULT_MAX_LENGTH = 100           # max length passed to generate_answer
DEFAULT_SENTENCE_STRUCTURE = True  # presumably toggles sentence-level shaping — confirm in model_handler
# Pydantic models for request/response payloads.
class QuestionRequest(BaseModel):
    """Request payload for the answer endpoints: a question plus its context."""
    question: str = Field(..., description="Medical question to answer", example="What are the symptoms of diabetes?")
    context: str = Field(..., description="Medical context/passage", example="Diabetes is characterized by high blood sugar...")
class QuestionResponse(BaseModel):
    """Response payload for one answered question."""
    # Echo of the submitted question.
    question: str
    # Generated answer text from the pointer-generator model.
    answer: str
    # Pointer-generator score as returned by generate_answer(); kept as a
    # string — TODO confirm its exact format in model_handler.
    p_gen_score: str
    # Entities found in the question: [{"text": ..., "label": ...}, ...]
    entities_detected: List[dict]
class HealthResponse(BaseModel):
    """Response payload for the health endpoint."""
    status: str        # "healthy" once models are loaded, else "unhealthy"
    model_loaded: bool # True when qa_processor has been initialized
    device: str        # string form of the torch device, e.g. "cuda" or "cpu"
@app.on_event("startup")
async def load_models():
    """Load the T5 pointer-generator model, tokenizer, and NER pipeline.

    Registered as a FastAPI startup hook so the models are ready before the
    first request is served (without the hook this function was never
    invoked). Populates the module-level globals read by the endpoint
    handlers. Re-raises on any fatal load error so the app fails fast.
    """
    global device, tokenizer, model, nlp, qa_processor
    try:
        logger.info("Loading models...")

        # Prefer GPU when available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {device}")

        # Load the T5 tokenizer and the pointer-generator wrapper.
        logger.info("Loading T5 model")
        tokenizer = T5Tokenizer.from_pretrained('t5-base')
        model = PointerGeneratorT5('t5-base')
        model = model.to(device)
        model.eval()  # inference only: disables dropout etc.
        logger.info("T5 model loaded successfully")

        # Load a spaCy pipeline, preferring the scispaCy medical model.
        logger.info("Loading medical NER model")
        try:
            # Best-effort runtime install of the scispaCy model.
            # NOTE(review): installing packages at runtime is fragile; prefer
            # baking this into requirements/the image. Use the running
            # interpreter's pip (sys.executable -m pip) so the install targets
            # the active environment rather than whichever `pip` is on PATH.
            import subprocess
            subprocess.run(
                [
                    sys.executable, "-m", "pip", "install",
                    "https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_sm-0.5.4.tar.gz",
                ],
                check=False,  # on failure, spacy.load below raises and we fall back
            )
            nlp = spacy.load("en_core_sci_sm")
            logger.info("scispaCy medical model loaded successfully")
        except Exception as e:
            logger.warning(f"scispaCy failed to load: {e}")
            logger.info("Falling back to standard spaCy model...")
            try:
                nlp = spacy.load("en_core_web_sm")
                logger.info("Standard spaCy model loaded")
            except OSError:
                # spacy.load raises OSError when the model package is absent:
                # download it with the current interpreter, then retry.
                import subprocess
                subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"])
                nlp = spacy.load("en_core_web_sm")
                logger.info("Standard spaCy model downloaded and loaded")

        # Wire everything into the request-time processor.
        qa_processor = MedicalQAProcessor(model, tokenizer, device, nlp)
        logger.info("Medical Q&A Processor initialized successfully")
        logger.info("API READY TO SERVE REQUESTS")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise
@app.get("/")
async def root():
    """Root endpoint: service banner, version, and a map of available routes.

    The route decorator was missing, so this handler was never registered
    with the application.
    """
    return {
        "message": "Medical Q&A Pointer-Generator API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "answer": "/api/answer"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check: reports whether models finished loading and the device.

    The route decorator was missing, so the /health path advertised by the
    root endpoint was never actually served.
    """
    return {
        # qa_processor is the last global set by load_models(), so its
        # presence implies the whole pipeline is ready.
        "status": "healthy" if qa_processor is not None else "unhealthy",
        "model_loaded": qa_processor is not None,
        "device": str(device)
    }
@app.post("/api/answer", response_model=QuestionResponse)
async def answer_question(request: QuestionRequest):
    """Answer a single medical question against the supplied context.

    The route decorator was missing, so the /api/answer path advertised by
    the root endpoint was never actually served; the path is restored from
    that endpoint map.

    Raises:
        HTTPException: 503 if models are not loaded, 400 for empty
            question/context, 500 for unexpected processing errors.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        # Validate inputs early with explicit 400s.
        if not request.question.strip():
            raise HTTPException(status_code=400, detail="Question cannot be empty")
        if not request.context.strip():
            raise HTTPException(status_code=400, detail="Context cannot be empty")
        # Entities detected in the question, returned for display only.
        doc = nlp(request.question)
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
        # Generate the answer via the pointer-generator processor.
        result = qa_processor.generate_answer(
            question=request.question,
            context=request.context,
            max_length=DEFAULT_MAX_LENGTH,
            use_sentence_structure=DEFAULT_SENTENCE_STRUCTURE
        )
        return QuestionResponse(
            question=request.question,
            answer=result['answer'],
            p_gen_score=result['p_gen_score'],
            entities_detected=entities
        )
    except HTTPException:
        raise  # preserve the deliberate status codes raised above
    except Exception as e:
        logger.error(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/api/answer/batch", response_model=List[QuestionResponse])
async def answer_batch(requests: List[QuestionRequest]):
    """Answer up to 10 questions in a single request, preserving input order.

    The route decorator was missing, so this handler was never registered.
    NOTE(review): the path is chosen by convention — it is not listed in the
    root endpoint map; confirm against API consumers before relying on it.

    Raises:
        HTTPException: 503 if models are not loaded, 400 if the batch
            exceeds 10 items, 500 for unexpected processing errors.
    """
    try:
        if qa_processor is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        if len(requests) > 10:
            raise HTTPException(status_code=400, detail="Maximum 10 questions per batch")
        results = []
        for req in requests:
            # Entities detected in the question, for display only.
            doc = nlp(req.question)
            entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
            result = qa_processor.generate_answer(
                question=req.question,
                context=req.context,
                max_length=DEFAULT_MAX_LENGTH,
                use_sentence_structure=DEFAULT_SENTENCE_STRUCTURE
            )
            results.append(QuestionResponse(
                question=req.question,
                answer=result['answer'],
                p_gen_score=result['p_gen_score'],
                entities_detected=entities
            ))
        return results
    except HTTPException:
        raise  # preserve the deliberate status codes raised above
    except Exception as e:
        logger.error(f"Error processing batch: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
    # Run a development server directly. Port 7860 is the conventional
    # Hugging Face Spaces port — presumably this app is deployed there.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)