datdevsteve commited on
Commit
dd4ebea
·
verified ·
1 Parent(s): 8de7fbd

Create api_main.py

Browse files
Files changed (1) hide show
  1. api_main.py +392 -0
api_main.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Nivra ClinicalBERT Text Classifier - FastAPI Backend
3
+ HuggingFace Space Inference API for Symptom Text Classification
4
+ """
5
+ from fastapi import FastAPI, HTTPException, status
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from fastapi.responses import JSONResponse
8
+ from pydantic import BaseModel, Field, validator
9
+ from typing import List, Optional, Dict, Any
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ import logging
13
+ import time
14
+ from contextlib import asynccontextmanager
15
+
16
+ # =============================================================================
17
+ # LOGGING CONFIGURATION
18
+ # =============================================================================
19
+
20
# Configure root logging once at import time; all module loggers inherit this.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
25
+
26
+ # =============================================================================
27
+ # GLOBAL MODEL VARIABLES
28
+ # =============================================================================
29
+
30
# HuggingFace Hub identifier of the fine-tuned ClinicalBERT checkpoint.
MODEL_NAME = "datdevsteve/clinicalbert-nivra-finetuned"
# Populated by the lifespan handler at startup; None until loading succeeds.
model = None
tokenizer = None
# Maps class index -> human-readable label, taken from the model config.
id2label = {}
34
+
35
+ # =============================================================================
36
+ # LIFESPAN CONTEXT MANAGER (Model Loading)
37
+ # =============================================================================
38
+
39
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: load the classifier once, log cleanup at exit."""
    global model, tokenizer, id2label

    logger.info(f"[STARTUP] Loading model: {MODEL_NAME}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
        model.eval()
        # label map from the checkpoint config; empty dict if absent
        id2label = getattr(model.config, 'id2label', {})
        logger.info("[STARTUP] Model loaded successfully!")
    except Exception as e:
        logger.error(f"[STARTUP ERROR] Failed to load model: {e}")
        raise

    # The application serves requests while suspended at this yield.
    yield

    logger.info("[SHUTDOWN] Cleaning up resources...")
59
+
60
+ # =============================================================================
61
+ # FASTAPI APP INITIALIZATION
62
+ # =============================================================================
63
+
64
# FastAPI application; the lifespan handler above loads the model at startup.
app = FastAPI(
    title="Nivra ClinicalBERT Text Classifier API",
    description="AI-powered symptom text classification for Indian Healthcare using ClinicalBERT",
    version="1.0.0",
    docs_url="/docs",      # Swagger UI
    redoc_url="/redoc",    # ReDoc UI
    lifespan=lifespan
)
72
+
73
+ # =============================================================================
74
+ # CORS MIDDLEWARE
75
+ # =============================================================================
76
+
77
# Allow cross-origin browser clients to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers per the CORS spec — confirm which is intended.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
84
+
85
+ # =============================================================================
86
+ # PYDANTIC MODELS
87
+ # =============================================================================
88
+
89
class SymptomTextRequest(BaseModel):
    # Free-text symptom description; length bounds enforced by pydantic.
    text: str = Field(
        ...,
        min_length=5,
        max_length=1000,
        description="Patient symptom description",
        example="Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days"
    )
    # How many ranked predictions to return (1..20).
    top_k: Optional[int] = Field(
        default=5,
        ge=1,
        le=20,
        description="Number of top predictions to return"
    )

    @validator('text')
    def validate_text(cls, v):
        """Reject blank/whitespace-only text and trim surrounding whitespace."""
        cleaned = v.strip()
        if not cleaned:
            raise ValueError("Text cannot be empty")
        return cleaned
110
+
111
class BatchSymptomRequest(BaseModel):
    # Batch of symptom texts; pydantic enforces 1..10 items before the handler runs.
    texts: List[str] = Field(
        ...,
        min_items=1,
        max_items=10,
        description="List of symptom descriptions to classify"
    )
    # How many ranked predictions to return per text (1..10).
    top_k: Optional[int] = Field(
        default=3,
        ge=1,
        le=10,
        description="Number of top predictions per text"
    )
124
+
125
class PredictionResult(BaseModel):
    # One (label, softmax probability) pair from the classifier.
    label: str = Field(..., description="Predicted disease/condition")
    score: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
128
+
129
class TextClassificationResponse(BaseModel):
    # NOTE(review): a field named "model" collides with Pydantic v2's protected
    # "model_" namespace (emits a warning); harmless on v1, which the
    # `validator` import suggests — confirm the installed pydantic version.
    success: bool = Field(default=True, description="Request success status")
    primary_classification: str = Field(..., description="Top predicted condition")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    # Top-k predictions only (empty when a batch item failed — see predict_batch).
    predictions: List[PredictionResult] = Field(..., description="All predictions")
    model: str = Field(..., description="Model identifier")
    processing_time_ms: float = Field(..., description="Inference time in milliseconds")
    # Echo of the input, truncated to 100 characters by predict_symptoms.
    input_text: str = Field(..., description="Original input text")
137
+
138
class BatchClassificationResponse(BaseModel):
    # Envelope for /api/v1/predict/batch; `results` is ordered like the input.
    success: bool = Field(default=True)
    batch_size: int = Field(..., description="Number of texts processed")
    results: List[TextClassificationResponse] = Field(..., description="Individual results")
    total_processing_time_ms: float = Field(..., description="Total processing time")
143
+
144
class HealthResponse(BaseModel):
    # Payload for /health: "healthy"/"unhealthy" plus model-load state.
    status: str
    model_loaded: bool
    model_name: str
    # ISO-8601 timestamp of the check.
    timestamp: str
149
+
150
class ErrorResponse(BaseModel):
    # Uniform error envelope mirrored by the exception handlers below.
    success: bool = False
    error: str
    detail: Optional[str] = None
154
+
155
+ # =============================================================================
156
+ # HELPER FUNCTIONS
157
+ # =============================================================================
158
+
159
def predict_symptoms(text: str, top_k: int = 5) -> Dict[str, Any]:
    """
    Classify symptom text to predict diseases.

    Args:
        text: Patient's symptom description.
        top_k: Number of top predictions to return.

    Returns:
        Dictionary with primary label, confidence, the top-k prediction list,
        model name, inference latency in ms, and a (truncated) input echo.

    Raises:
        HTTPException: 500 when tokenization or inference fails.
    """
    try:
        start_time = time.time()

        # Tokenize input; 512 tokens is the BERT-family context limit.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        # Inference without gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)[0]

        # Build (label, score) pairs. Depending on how the checkpoint config was
        # deserialized, id2label may be keyed by int or by str — try both before
        # falling back to a synthetic name (the original int-only lookup silently
        # produced "LABEL_i" for string-keyed configs).
        predictions = []
        for idx, prob in enumerate(probabilities):
            if idx in id2label:
                label = id2label[idx]
            elif str(idx) in id2label:
                label = id2label[str(idx)]
            else:
                label = f"LABEL_{idx}"
            predictions.append({"label": label, "score": float(prob)})

        # Highest confidence first; keep only the requested top_k.
        predictions.sort(key=lambda x: x['score'], reverse=True)
        top_predictions = predictions[:top_k]

        processing_time = (time.time() - start_time) * 1000  # seconds -> ms

        result = {
            "primary_classification": top_predictions[0]['label'],
            "confidence": top_predictions[0]['score'],
            "predictions": top_predictions,
            "model": MODEL_NAME,
            "processing_time_ms": round(processing_time, 2),
            # Echo at most 100 characters of the input back to the caller.
            "input_text": text[:100] + "..." if len(text) > 100 else text
        }

        logger.info(f"[PREDICTION] {top_predictions[0]['label']} ({top_predictions[0]['score']:.4f}) - {processing_time:.2f}ms")
        return result

    except Exception as e:
        logger.error(f"[PREDICTION ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
219
+
220
+ # =============================================================================
221
+ # API ENDPOINTS
222
+ # =============================================================================
223
+
224
@app.get("/", tags=["Root"])
async def root():
    """Root endpoint - API information"""
    available_endpoints = {
        "health": "/health",
        "docs": "/docs",
        "predict_single": "/api/v1/predict",
        "predict_batch": "/api/v1/predict/batch"
    }
    return {
        "message": "Nivra ClinicalBERT Text Classifier API",
        "version": "1.0.0",
        "status": "active",
        "model": MODEL_NAME,
        "endpoints": available_endpoints
    }
239
+
240
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
    """
    Health check endpoint for monitoring.

    Reports whether the model finished loading plus a UTC timestamp.
    """
    # Local import keeps the module top free of names used only here.
    from datetime import datetime, timezone

    loaded = model is not None
    return HealthResponse(
        status="healthy" if loaded else "unhealthy",
        model_loaded=loaded,
        model_name=MODEL_NAME,
        # Timezone-aware UTC; datetime.utcnow() is deprecated and returns a
        # naive datetime, so its ISO string carried no offset.
        timestamp=datetime.now(timezone.utc).isoformat()
    )
251
+
252
@app.post(
    "/api/v1/predict",
    response_model=TextClassificationResponse,
    tags=["Prediction"],
    summary="Classify symptom text to predict disease/condition"
)
async def predict_single(request: SymptomTextRequest):
    """
    Classify patient symptom descriptions to predict medical conditions

    **Example Request:**
    ```json
    {
        "text": "Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days",
        "top_k": 5
    }
    ```

    **Use Cases:**
    - Symptom-based diagnosis assistance
    - Preliminary medical screening
    - Healthcare chatbot integration
    - Medical triage systems
    """
    try:
        payload = predict_symptoms(request.text, top_k=request.top_k)
        return TextClassificationResponse(success=True, **payload)
    except HTTPException:
        # Already carries the proper status code; let it propagate untouched.
        raise
    except Exception as e:
        logger.error(f"[PREDICT ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")
285
+
286
@app.post(
    "/api/v1/predict/batch",
    response_model=BatchClassificationResponse,
    tags=["Prediction"],
    summary="Batch classification for multiple symptom texts"
)
async def predict_batch(request: BatchSymptomRequest):
    """
    Classify multiple symptom descriptions in a single request

    **Example Request:**
    ```json
    {
        "texts": [
            "fever and headache for 2 days",
            "persistent cough with chest pain",
            "stomach pain and nausea"
        ],
        "top_k": 3
    }
    ```

    **Limitation:** Maximum 10 texts per batch
    """
    try:
        batch_start = time.time()
        outcomes = []

        for symptom_text in request.texts:
            try:
                payload = predict_symptoms(symptom_text, top_k=request.top_k)
                outcomes.append(TextClassificationResponse(success=True, **payload))
            except Exception as e:
                logger.error(f"[BATCH ERROR] Text: '{symptom_text[:50]}...' - Error: {str(e)}")
                # Placeholder entry keeps results aligned with the input order.
                outcomes.append(TextClassificationResponse(
                    success=False,
                    primary_classification="error",
                    confidence=0.0,
                    predictions=[],
                    model=MODEL_NAME,
                    processing_time_ms=0.0,
                    input_text=symptom_text[:100]
                ))

        elapsed_ms = (time.time() - batch_start) * 1000
        return BatchClassificationResponse(
            success=True,
            batch_size=len(request.texts),
            results=outcomes,
            total_processing_time_ms=round(elapsed_ms, 2)
        )

    except Exception as e:
        logger.error(f"[BATCH ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail=f"Batch processing failed: {str(e)}")
343
+
344
@app.get(
    "/api/v1/labels",
    tags=["Model Info"],
    summary="Get all possible classification labels"
)
async def get_labels():
    """
    Retrieve all possible disease/condition labels the model can predict

    **Returns:** Dictionary mapping label IDs to human-readable names
    """
    return {"total_labels": len(id2label), "labels": id2label}
359
+
360
+ # =============================================================================
361
+ # ERROR HANDLERS
362
+ # =============================================================================
363
+
364
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render HTTPExceptions in the API's uniform {success, error} envelope."""
    body = {"success": False, "error": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=body)
370
+
371
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Last-resort handler: log the traceback and return a generic 500 envelope."""
    # exc_info=exc records the full traceback; the original f-string log kept
    # only the message, discarding where the error came from.
    logger.error("[UNHANDLED ERROR] %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        # NOTE(review): echoing str(exc) in "detail" can leak internals to
        # clients — consider dropping it in production.
        content={"success": False, "error": "Internal server error", "detail": str(exc)}
    )
378
+
379
+ # =============================================================================
380
+ # MAIN ENTRY POINT
381
+ # =============================================================================
382
+
383
if __name__ == "__main__":
    import uvicorn

    # Run the ASGI app directly (local/dev or container entrypoint).
    uvicorn.run(
        "api_main:app",
        host="0.0.0.0",
        port=7860,  # 7860 is the conventional HuggingFace Spaces port
        reload=False,
        log_level="info"
    )