Spaces:
Sleeping
Sleeping
| """ | |
| Nivra ClinicalBERT Text Classifier - FastAPI Backend | |
| HuggingFace Space Inference API for Symptom Text Classification | |
| """ | |
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import List, Optional, Dict, Any

import torch
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================
# Root logging: INFO and above, formatted as "<time> - <logger> - <level> - <message>".
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
# =============================================================================
# GLOBAL MODEL VARIABLES
# =============================================================================
# Checkpoint identifier passed to from_pretrained() by lifespan().
MODEL_NAME = "datdevsteve/clinicalbert-nivra-finetuned"

# Assigned by lifespan() at startup; None until the model has been loaded.
model: Optional[Any] = None
tokenizer: Optional[Any] = None
# Class index -> label name, copied from the model config at startup.
id2label: Dict[int, str] = {}
# =============================================================================
# LIFESPAN CONTEXT MANAGER (Model Loading)
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model before serving and release them on shutdown.

    Populates the module-level ``model``, ``tokenizer`` and ``id2label`` globals.
    A load failure is logged with its traceback and re-raised, so the app refuses
    to start instead of serving requests without a model.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    global model, tokenizer, id2label
    logger.info(f"[STARTUP] Loading model: {MODEL_NAME}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        model.eval()  # inference mode (e.g. disables dropout)
        id2label = getattr(model.config, "id2label", None) or {}
        logger.info("[STARTUP] Model loaded successfully!")
    except Exception as e:
        logger.exception(f"[STARTUP ERROR] Failed to load model: {e}")
        raise
    try:
        yield  # Application runs here
    finally:
        # Runs even if the server stops because of an error while serving.
        logger.info("[SHUTDOWN] Cleaning up resources...")
        # Drop the references so the weights can be garbage-collected.
        model = None
        tokenizer = None
        id2label = {}
# =============================================================================
# FASTAPI APP INITIALIZATION
# =============================================================================
# Metadata shown in the generated OpenAPI schema and the /docs and /redoc UIs.
_APP_METADATA = {
    "title": "Nivra ClinicalBERT Text Classifier API",
    "description": "AI-powered symptom text classification for Indian Healthcare using ClinicalBERT",
    "version": "1.0.0",
}

app = FastAPI(
    **_APP_METADATA,
    docs_url="/docs",
    redoc_url="/redoc",
    # Model loading/unloading happens in the lifespan context manager.
    lifespan=lifespan,
)
# =============================================================================
# CORS MIDDLEWARE
# =============================================================================
# Comma-separated allow-list, e.g. "https://app.example.com,https://admin.example.com".
# Defaults to "*" (any origin), which keeps the API open as before.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Browsers reject "Access-Control-Allow-Origin: *" on credentialed requests, and
# Starlette's workaround is to reflect the caller's Origin, which would let any
# site make credentialed requests. No endpoint in this module reads cookies or
# auth headers, so credentials are only enabled for an explicit allow-list.
_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# PYDANTIC MODELS
# =============================================================================
class SymptomTextRequest(BaseModel):
    """Request body for classifying a single symptom description."""

    # Length bounds apply to the raw text, before validate_text strips it.
    text: str = Field(
        ...,
        min_length=5,
        max_length=1000,
        description="Patient symptom description",
        example="Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days"
    )
    top_k: Optional[int] = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of top predictions to return"
    )

    # Without this decorator pydantic never called the method, so whitespace-only
    # text reached the model and nothing was stripped.
    @validator("text")
    def validate_text(cls, v):
        """Reject whitespace-only text and return it with surrounding whitespace removed.

        Raises:
            ValueError: if the text is empty or whitespace only (pydantic reports 422).
        """
        if not v or v.strip() == "":
            raise ValueError("Text cannot be empty")
        return v.strip()
class BatchSymptomRequest(BaseModel):
    """Request body for classifying up to 10 symptom descriptions in one call."""

    texts: List[str] = Field(
        ...,
        min_items=1,
        max_items=10,
        description="List of symptom descriptions to classify"
    )
    top_k: Optional[int] = Field(
        default=3,
        ge=1,
        le=10,
        description="Number of top predictions per text"
    )

    @validator("texts")
    def validate_texts(cls, v):
        """Strip every text and reject blank entries.

        This applies the same blank-text rule as ``SymptomTextRequest.validate_text``,
        so batch items cannot bypass it.

        Raises:
            ValueError: if any entry is empty or whitespace only (pydantic reports 422).
        """
        cleaned = [t.strip() for t in v]
        if any(not t for t in cleaned):
            raise ValueError("Texts cannot contain empty entries")
        return cleaned
class PredictionResult(BaseModel):
    """One ranked label and its score, as produced by predict_symptoms()."""

    # Label name from id2label, or a "LABEL_<idx>" fallback when the index is unmapped
    label: str = Field(..., description="Predicted disease/condition")
    # Softmax probability; pydantic enforces the [0, 1] bounds
    score: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
class TextClassificationResponse(BaseModel):
    """Classification result for one text; also used per item in batch responses."""

    # False only for the placeholder results predict_batch emits when a text fails
    success: bool = Field(default=True, description="Request success status")
    # Label of the highest-scoring prediction ("error" on batch failure placeholders)
    primary_classification: str = Field(..., description="Top predicted condition")
    # Score of the highest-scoring prediction
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    # NOTE(review): holds only the top_k predictions, best first, despite the
    # "All predictions" description; empty on batch failure placeholders.
    predictions: List[PredictionResult] = Field(..., description="All predictions")
    # Set to MODEL_NAME by the callers in this file
    model: str = Field(..., description="Model identifier")
    processing_time_ms: float = Field(..., description="Inference time in milliseconds")
    # NOTE(review): predict_symptoms truncates this to 100 characters plus "...",
    # so it is not always the full original text the description promises.
    input_text: str = Field(..., description="Original input text")
class BatchClassificationResponse(BaseModel):
    """Batch result: one TextClassificationResponse per input text, in input order."""

    # predict_batch always sets True; per-text failures show up as results[i].success
    success: bool = Field(default=True)
    batch_size: int = Field(..., description="Number of texts processed")
    results: List[TextClassificationResponse] = Field(..., description="Individual results")
    # Wall-clock time for the whole batch, in milliseconds
    total_processing_time_ms: float = Field(..., description="Total processing time")
class HealthResponse(BaseModel):
    """Payload returned by the health-check endpoint."""

    # NOTE(review): under pydantic v2, field names starting with "model_" trigger a
    # protected-namespace warning; confirm which pydantic version is deployed.

    # "healthy" when the model is loaded, otherwise "unhealthy"
    status: str
    model_loaded: bool
    model_name: str
    # ISO-8601 timestamp string
    timestamp: str
class ErrorResponse(BaseModel):
    """Error payload shape: success=False, a short error message and optional detail."""

    # NOTE(review): not referenced elsewhere in this file; the exception handlers
    # build the same dict shape by hand.
    success: bool = False
    # Short error message
    error: str
    # Optional extra information (the catch-all handler puts the exception text here)
    detail: Optional[str] = None
| # ============================================================================= | |
| # HELPER FUNCTIONS | |
| # ============================================================================= | |
| def predict_symptoms(text: str, top_k: int = 5) -> Dict[str, Any]: | |
| """ | |
| Classify symptom text to predict diseases | |
| Args: | |
| text: Patient's symptom description | |
| top_k: Number of top predictions to return | |
| Returns: | |
| Dictionary with predictions and metadata | |
| """ | |
| try: | |
| start_time = time.time() | |
| # Tokenize input | |
| inputs = tokenizer( | |
| text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=512, | |
| padding=True | |
| ) | |
| # Get predictions | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| logits = outputs.logits | |
| probabilities = torch.softmax(logits, dim=-1)[0] | |
| # Format predictions | |
| predictions = [] | |
| for idx, prob in enumerate(probabilities): | |
| label = id2label.get(idx, f"LABEL_{idx}") | |
| score = float(prob) | |
| predictions.append({ | |
| "label": label, | |
| "score": score | |
| }) | |
| # Sort by confidence | |
| predictions = sorted(predictions, key=lambda x: x['score'], reverse=True) | |
| top_predictions = predictions[:top_k] | |
| processing_time = (time.time() - start_time) * 1000 # Convert to ms | |
| result = { | |
| "primary_classification": top_predictions[0]['label'], | |
| "confidence": top_predictions[0]['score'], | |
| "predictions": top_predictions, | |
| "model": MODEL_NAME, | |
| "processing_time_ms": round(processing_time, 2), | |
| "input_text": text[:100] + "..." if len(text) > 100 else text | |
| } | |
| logger.info(f"[PREDICTION] {top_predictions[0]['label']} ({top_predictions[0]['score']:.4f}) - {processing_time:.2f}ms") | |
| return result | |
| except Exception as e: | |
| logger.error(f"[PREDICTION ERROR] {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") | |
# =============================================================================
# API ENDPOINTS
# =============================================================================
# Without a route decorator this handler was never registered, so "/" returned 404.
@app.get("/")
async def root():
    """Root endpoint - API information and the paths of the main endpoints."""
    return {
        "message": "Nivra ClinicalBERT Text Classifier API",
        "version": "1.0.0",
        "status": "active",
        "model": MODEL_NAME,
        "endpoints": {
            "health": "/health",
            "docs": "/docs",
            "predict_single": "/api/v1/predict",
            "predict_batch": "/api/v1/predict/batch"
        }
    }
# Path as advertised by root(); previously unregistered because the decorator was missing.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint for monitoring.

    Reports whether the model is loaded. ``timestamp`` is the current UTC time as a
    timezone-aware ISO-8601 string.
    """
    from datetime import datetime, timezone

    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        model_name=MODEL_NAME,
        # datetime.utcnow() is deprecated (Python 3.12) and returns a naive value.
        timestamp=datetime.now(timezone.utc).isoformat()
    )
# Path as advertised by root(); previously unregistered because the decorator was missing.
# NOTE(review): predict_symptoms runs synchronous model inference inside this async
# handler, which blocks the event loop for each request. Offloading it to a thread
# would also need a lock around the shared tokenizer. TODO evaluate.
@app.post("/api/v1/predict", response_model=TextClassificationResponse)
async def predict_single(request: SymptomTextRequest):
    """
    Classify patient symptom descriptions to predict medical conditions
    **Example Request:**
    ```json
    {
    "text": "Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days",
    "top_k": 5
    }
    ```
    **Use Cases:**
    - Symptom-based diagnosis assistance
    - Preliminary medical screening
    - Healthcare chatbot integration
    - Medical triage systems
    """
    try:
        result = predict_symptoms(request.text, top_k=request.top_k)
        return TextClassificationResponse(**result, success=True)
    except HTTPException:
        # Already carries the right status (e.g. 503 model not loaded, 500 inference).
        raise
    except Exception as e:
        logger.exception(f"[PREDICT ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}") from e
# Path as advertised by root(); previously unregistered because the decorator was missing.
@app.post("/api/v1/predict/batch", response_model=BatchClassificationResponse)
async def predict_batch(request: BatchSymptomRequest):
    """
    Classify multiple symptom descriptions in a single request
    **Example Request:**
    ```json
    {
    "texts": [
    "fever and headache for 2 days",
    "persistent cough with chest pain",
    "stomach pain and nausea"
    ],
    "top_k": 3
    }
    ```
    **Limitation:** Maximum 10 texts per batch
    """
    try:
        start_time = time.perf_counter()
        results = []
        for text in request.texts:
            try:
                result = predict_symptoms(text, top_k=request.top_k)
                results.append(TextClassificationResponse(**result, success=True))
            except Exception as e:
                # A failing text must not fail the whole batch: log it and emit a
                # placeholder result with success=False in its slot.
                logger.error(f"[BATCH ERROR] Text: '{text[:50]}...' - Error: {str(e)}")
                results.append(TextClassificationResponse(
                    success=False,
                    primary_classification="error",
                    confidence=0.0,
                    predictions=[],
                    model=MODEL_NAME,
                    processing_time_ms=0.0,
                    # Same 100-character preview format as successful results.
                    input_text=text[:100] + "..." if len(text) > 100 else text
                ))
        total_time = (time.perf_counter() - start_time) * 1000
        return BatchClassificationResponse(
            success=True,
            batch_size=len(request.texts),
            results=results,
            total_processing_time_ms=round(total_time, 2)
        )
    except Exception as e:
        logger.exception(f"[BATCH ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}") from e
# Previously unregistered because the decorator was missing.
# NOTE(review): this path is assumed (it follows the /api/v1 prefix); root() does not
# advertise it, so confirm it against existing clients.
@app.get("/api/v1/labels")
async def get_labels():
    """
    Retrieve all possible disease/condition labels the model can predict
    **Returns:** Dictionary mapping label IDs to human-readable names
    """
    # Integer class indices become string keys once serialized to JSON.
    return {
        "total_labels": len(id2label),
        "labels": id2label
    }
# =============================================================================
# ERROR HANDLERS
# =============================================================================
# NOTE(review): this catches fastapi.HTTPException and its subclasses only. Routing
# errors such as 404s raise Starlette's base HTTPException and keep FastAPI's default
# response shape.
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as ``{"success": False, "error": <detail>}``.

    Headers attached to the exception (for example ``WWW-Authenticate``) are
    forwarded. The undecorated original was never registered and also dropped them.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"success": False, "error": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Catch-all handler: log the traceback and return a JSON 500 error payload."""
    # Pass exc_info explicitly so the traceback is logged however the handler is invoked.
    logger.error(f"[UNHANDLED ERROR] {str(exc)}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={"success": False, "error": "Internal server error", "detail": str(exc)}
    )
# =============================================================================
# MAIN ENTRY POINT
# =============================================================================
if __name__ == "__main__":
    import uvicorn

    # Pass the app object instead of the "api_main:app" import string. With reload
    # disabled uvicorn accepts either form, and the object works whatever this file
    # is named and avoids re-importing (re-executing) the module a second time.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        reload=False,
        log_level="info"
    )