| | from fastapi import FastAPI, HTTPException, status |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import JSONResponse |
| | from pydantic import BaseModel |
| | from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoConfig |
| | import torch |
| | import os |
| | import sys |
| | import traceback |
| | from typing import Optional, Dict, Any |
| | from accelerate import Accelerator |
| | import time |
| | import psutil |
| | from loguru import logger |
| |
|
| | |
# Logging setup: drop every handler registered so far, then emit INFO and
# above to stderr using a single colourised line format.
_LOG_FORMAT = (
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
    "<level>{level: <8}</level> | "
    "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)

logger.remove()
logger.add(sys.stderr, level="INFO", format=_LOG_FORMAT)
| |
|
| | |
# FastAPI application object. Swagger UI is served under /documentation and
# ReDoc under /redoc.
app = FastAPI(
    title="Clinical Report Generator API",
    description="Production API for generating clinical report summaries using T5",
    version="1.0.0",
    docs_url="/documentation",
    redoc_url="/redoc",
)

# CORS: only the listed origin may call the API from a browser, with
# credentials, using GET or POST; preflight results are cached for an hour.
_CORS_SETTINGS = {
    "allow_origins": ["https://pdarleyjr.github.io"],
    "allow_credentials": True,
    "allow_methods": ["POST", "GET"],
    "allow_headers": ["*"],
    "max_age": 3600,
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
| |
|
| | |
# Identifier handed to every ``from_pretrained`` call (tokenizer, config and
# model weights) in ModelManager.load_model.
MODEL_ID = "pdarleyjr/iplc-t5-clinical"
| |
|
class ModelManager:
    """Hold the T5 model/tokenizer pair and the logic that loads them.

    ``model`` and ``tokenizer`` stay ``None`` until :meth:`load_model`
    succeeds, and are reset to ``None`` if a load attempt fails.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.accelerator = Accelerator()
        # ``time.time()`` timestamp of the last successful load, else None.
        self.last_load_time = None
        # Plain re-entrancy flag: True while ``load_model`` is running.
        self.load_lock = False

    @staticmethod
    def _log_memory(system_label: str, cuda_label: str) -> None:
        """Log host memory usage and, if CUDA is available, allocated GPU memory."""
        gib = 1024 * 1024 * 1024
        memory = psutil.virtual_memory()
        logger.info(
            f"{system_label}: {memory.percent}% used, "
            f"{memory.available / gib:.2f}GB available"
        )
        if torch.cuda.is_available():
            logger.info(
                f"{cuda_label}: {torch.cuda.memory_allocated() / gib:.2f}GB allocated"
            )

    async def load_model(self) -> bool:
        """Load the tokenizer and model identified by ``MODEL_ID``.

        Returns:
            ``True`` when both objects were loaded and prepared; ``False``
            when a load is already flagged as in progress or when any step
            raised (the exception is logged, not propagated).

        NOTE: this coroutine contains no ``await``; the whole load runs
        synchronously and blocks the event loop until it finishes.
        """
        if self.load_lock:
            logger.warning("Model load already in progress")
            return False

        self.load_lock = True
        try:
            logger.info("Starting model and tokenizer loading process...")
            self._log_memory("System memory", "CUDA memory")

            logger.info("Initializing tokenizer...")
            tokenizer = T5Tokenizer.from_pretrained(
                MODEL_ID,
                use_fast=True,
                model_max_length=512,
            )
            logger.success("Tokenizer loaded successfully")

            logger.info("Fetching model configuration...")
            config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=False)
            logger.success("Model configuration loaded successfully")

            logger.info("Loading model (this may take a few minutes)...")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device}")

            # Half precision on GPU, full precision on CPU.
            dtype = torch.float16 if device == "cuda" else torch.float32
            model = T5ForConditionalGeneration.from_pretrained(
                MODEL_ID,
                config=config,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            ).to(device)
            logger.success("Model loaded successfully")

            model = self.accelerator.prepare_model(model)
            logger.success("Model prepared with accelerator")

            self._log_memory("Final memory usage", "Final CUDA memory")

            # Publish only once everything is ready, so ``is_loaded()`` can
            # never report True for a model that has not been prepared yet.
            self.tokenizer = tokenizer
            self.model = model
            self.last_load_time = time.time()
            return True

        except Exception:
            logger.exception("Error loading model")
            self.model = None
            self.tokenizer = None
            return False

        finally:
            self.load_lock = False

    def is_loaded(self) -> bool:
        """Return True only when both the model and the tokenizer are set."""
        return self.model is not None and self.tokenizer is not None

    def get_load_time(self) -> Optional[float]:
        """Return the timestamp of the last successful load, or None."""
        return self.last_load_time
| |
|
| | |
# Module-level singleton shared by the request handlers and the startup hook.
model_manager = ModelManager()
| |
|
class PredictRequest(BaseModel):
    """Body accepted by the /predict endpoint."""

    # Free-form input; the endpoint prefixes it with "summarize: ".
    text: str

    class Config:
        # Example payload exposed through the generated schema.
        schema_extra = {
            "example": {
                "text": "evaluation type: initial. primary diagnosis: F84.0. severity: mild. primary language: english",
            },
        }
| |
|
@app.post(
    "/predict",
    response_model=Dict[str, Any],
    status_code=status.HTTP_200_OK,
    responses={
        400: {"description": "Empty input text"},
        500: {"description": "Internal server error"},
        503: {"description": "Service unavailable - model loading"},
    },
)
async def predict(request: PredictRequest) -> JSONResponse:
    """Generate a clinical report summary for ``request.text``.

    Returns a JSON envelope:

    * 200 - ``{"success": True, "data": <summary>, "error": None,
      "metrics": {"process_time": <seconds>}}``
    * 400 - the text is empty or whitespace-only.
    * 503 - the model is not loaded and loading it failed, or CUDA ran out
      of memory during generation.
    * 500 - any other exception (logged with traceback, not exposed).

    NOTE: tokenization and generation run synchronously inside this
    coroutine, so the event loop is blocked while a summary is produced.
    """
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    started = time.perf_counter()

    try:
        if not request.text.strip():
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content={
                    "success": False,
                    "error": "Input text must not be empty.",
                },
            )

        # Lazy load, in case the load attempted at startup did not succeed.
        if not model_manager.is_loaded():
            logger.warning("Model not loaded, attempting to load...")
            if not await model_manager.load_model():
                return JSONResponse(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    content={
                        "success": False,
                        "error": "Model is initializing. Please try again in a few moments.",
                    },
                )

        tokenizer = model_manager.tokenizer
        model = model_manager.model

        # Calling the tokenizer (rather than ``encode``) also yields the
        # attention mask, which is passed to ``generate`` explicitly.
        encoded = tokenizer(
            "summarize: " + request.text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        )

        try:
            device = next(model.parameters()).device
            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            # Beam search without sampling; ``temperature`` is deliberately
            # not passed because it only applies when ``do_sample=True``.
            with torch.no_grad(), model_manager.accelerator.autocast():
                outputs = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=512,
                    num_beams=5,
                    no_repeat_ngram_size=3,
                    length_penalty=2.0,
                    early_stopping=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

            process_time = time.perf_counter() - started
            logger.info(f"Summary generated in {process_time:.2f} seconds")

            return JSONResponse(
                content={
                    "success": True,
                    "data": summary,
                    "error": None,
                    "metrics": {"process_time": process_time},
                }
            )

        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory error - clearing cache")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                logger.info(
                    f"CUDA memory after cleanup: {torch.cuda.memory_allocated() / (1024*1024*1024):.2f}GB allocated"
                )
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "success": False,
                    "error": "Server is currently overloaded. Please try again later.",
                },
            )

    except Exception:
        # Top-level boundary: log the traceback, return a generic message.
        logger.exception("Error in predict endpoint")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "success": False,
                "error": "An unexpected error occurred. Please try again later.",
            },
        )
| |
|
@app.get(
    "/health",
    response_model=Dict[str, Any],
    status_code=status.HTTP_200_OK,
)
async def health_check() -> JSONResponse:
    """Report API and model status.

    Returns 200 with ``status``, ``model_loaded``, ``last_load_time``,
    ``version``, ``gpu_available`` and ``gpu_name``; returns 500 with
    ``status: "unhealthy"`` and the error text if gathering that
    information raises.
    """
    try:
        gpu_available = torch.cuda.is_available()
        return JSONResponse(
            content={
                "status": "healthy",
                "model_loaded": model_manager.is_loaded(),
                "last_load_time": model_manager.get_load_time(),
                # Taken from the app object so it cannot drift from the
                # version declared when the application was created.
                "version": app.version,
                "gpu_available": gpu_available,
                "gpu_name": torch.cuda.get_device_name(0) if gpu_available else None,
            }
        )
    except Exception as exc:
        # logger.exception keeps the traceback, matching the other handlers.
        logger.exception("Error in health check")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={
                "status": "unhealthy",
                "error": str(exc),
            },
        )
| |
|
@app.on_event("startup")
async def startup_event() -> None:
    """Log host resources and try to load the model when the app starts."""
    logger.info("Starting application in production mode...")

    # The first cpu_percent() call without an interval returns a meaningless
    # 0.0, so sample over a short window to log a real figure.
    cpu_percent = psutil.cpu_percent(interval=0.1)
    memory_percent = psutil.virtual_memory().percent
    logger.info(f"System resources - CPU: {cpu_percent}%, Memory: {memory_percent}%")

    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # load_model() reports failure through its return value, not by raising.
    if not await model_manager.load_model():
        logger.warning(
            "Model could not be loaded at startup; /predict will retry on the next request"
        )
| |
|
@app.on_event("shutdown")
async def shutdown_event() -> None:
    """Log final resource figures and clear the CUDA cache on shutdown."""
    logger.info("Initiating graceful shutdown...")

    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / (1024 * 1024 * 1024)
        logger.info(f"Final CUDA memory before cleanup: {allocated_gb:.2f}GB")
        torch.cuda.empty_cache()
        logger.info("CUDA cache cleared")

    cpu_pct = psutil.cpu_percent()
    mem_pct = psutil.virtual_memory().percent
    logger.info(f"Final system stats - CPU: {cpu_pct}%, Memory: {mem_pct}%")
    logger.success("Application shutdown complete")
| |
|
| | |
# Direct execution: serve the app with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
| |
|