# Last commit by datdevsteve: 8de7fbd "Update app.py" (verified)
"""
Nivra ClinicalBERT Fine-tuned Model - HuggingFace Space
Symptom Classification for Indian Healthcare
"""
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
from typing import Dict, List, Any
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
# Hugging Face Hub repo id of the fine-tuned sequence-classification model.
MODEL_NAME = "datdevsteve/nivra-clinicalbert-finetuned"

# Load model and tokenizer once at import time; every request reuses them.
logger.info("[i] Loading model: %s", MODEL_NAME)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # Inference only: put the model in eval mode (disables dropout).
    model.eval()
    logger.info("[i] Model loaded successfully")
except Exception as exc:
    # logger.exception also records the traceback. Re-raise so the app fails
    # loudly instead of starting a UI with no model behind it.
    logger.exception("[!] Error loading model: %s", exc)
    raise

# Class index -> human-readable label. An empty dict (missing or None mapping)
# makes callers fall back to generic "LABEL_{idx}" names.
id2label = getattr(model.config, "id2label", None) or {}
# =============================================================================
# PREDICTION FUNCTIONS
# =============================================================================
def predict_symptoms(text: str, return_all_scores: bool = True, top_k: int = 5) -> Dict[str, Any]:
    """
    Classify a free-text symptom description with the fine-tuned model.

    Args:
        text: Patient's symptom description.
        return_all_scores: If True, return the probability of every class.
            If False, keep only the ``top_k`` most likely classes.
        top_k: Number of top predictions to keep when ``return_all_scores``
            is False. Values below 1 are treated as 1.

    Returns:
        On success, a dict with these keys:

        - ``predictions``: list of ``{"label", "score"}`` dicts, sorted by
          descending score.
        - ``primary_classification`` and ``confidence``: label and score of
          the best class.
        - ``model`` and ``input_text``.

        On any failure, a dict with ``error`` (the message), an empty
        ``predictions`` list, ``primary_classification="error"`` and
        ``confidence=0.0``. No exception is raised to the caller.
    """
    try:
        # One input string -> a batch of one, truncated to at most 512 tokens.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )

        with torch.no_grad():
            logits = model(**inputs).logits
        # Softmax over the class dimension; [0] picks the single batch item.
        probabilities = torch.softmax(logits, dim=-1)[0]

        predictions = [
            {"label": id2label.get(idx, f"LABEL_{idx}"), "score": float(prob)}
            for idx, prob in enumerate(probabilities)
        ]
        predictions.sort(key=lambda p: p["score"], reverse=True)

        # Take the primary class from the full ranking *before* truncating.
        # A non-positive top_k can then never leave an empty list to index.
        best = predictions[0]
        if not return_all_scores:
            predictions = predictions[:max(1, int(top_k))]

        logger.info("[i] Prediction: %s (%.4f)", best["label"], best["score"])
        return {
            "predictions": predictions,
            "primary_classification": best["label"],
            "confidence": best["score"],
            "model": MODEL_NAME,
            "input_text": text,
        }
    except Exception as e:
        # Boundary handler: the UI/API callers expect a result dict, so log
        # the full traceback and report the failure in-band.
        logger.exception("[!] Prediction error: %s", e)
        return {
            "error": str(e),
            "predictions": [],
            "primary_classification": "error",
            "confidence": 0.0,
        }
def predict_api(text: str) -> str:
    """
    Classify ``text`` and return the complete result as indented JSON text.

    Every class score is included (``return_all_scores=True``). Error results
    from ``predict_symptoms`` are serialised the same way.

    NOTE(review): this function is not wired into the Blocks app built in
    ``create_demo``. Confirm whether anything still calls it.
    """
    payload = predict_symptoms(text, top_k=5, return_all_scores=True)
    serialized = json.dumps(payload, indent=2)
    return serialized
def predict_gradio(text: str, top_k: int = 5) -> tuple:
    """
    Gradio callback: classify ``text`` and format the result for display.

    Args:
        text: Symptom description typed by the user.
        top_k: How many ranked predictions to list. It is coerced to ``int``
            because a slider value may arrive as a float, and values below 1
            are treated as 1.

    Returns:
        ``(markdown, json_text)``: a Markdown summary for the results panel
        and the raw result dict as pretty-printed JSON. ``json_text`` is ""
        when the input is blank or the prediction failed.
    """
    if not text or not text.strip():
        return "[i] Please enter symptom description", ""

    # Slicing with a float raises TypeError, so normalise first.
    top_k = max(1, int(top_k))
    result = predict_symptoms(text, return_all_scores=True, top_k=top_k)
    if "error" in result:
        return f"[!] Error: {result['error']}", ""

    primary = (
        "## 🎯 Primary Classification\n\n"
        f"**Condition:** {result['primary_classification']}\n\n"
        f"**Confidence:** {result['confidence']:.2%}\n\n"
        # Blank line before the rule: "---" directly under a text line would
        # turn that line into a setext heading instead of a separator.
        "---\n\n"
    )

    ranked = []
    for rank, pred in enumerate(result["predictions"][:top_k], 1):
        bar = "█" * int(pred["score"] * 20)  # 20 blocks == 100%
        # Two trailing spaces force a Markdown line break. The 3-space indent
        # keeps the bar inside the numbered list item.
        ranked.append(f"{rank}. **{pred['label']}**  \n   {bar} {pred['score']:.2%}\n\n")
    predictions_text = "## 📊 Top Predictions\n\n" + "".join(ranked)

    # Raw result for API testing / integration.
    json_output = json.dumps(result, indent=2)
    return primary + predictions_text, json_output
# =============================================================================
# EXAMPLE CASES
# =============================================================================
# Sample symptom descriptions shown under the input box. Each inner list holds
# one value per gr.Examples input component (here only the symptom text box).
EXAMPLES = [
    ["Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days"],
    ["Patient presents persistent cough with yellow phlegm, chest congestion, difficulty breathing"],
    ["Patient presents stomach pain, nausea, vomiting and diarrhea since yesterday"],
    ["Patient presents severe headache on one side, sensitivity to light and sound, nausea"],
    ["Patient presents skin rash with red bumps, itching all over body"],
]
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
def create_demo():
    """
    Build the Gradio Blocks UI for the symptom classifier.

    The left column holds the symptom text box, a top-k slider, the analyse
    button and clickable example cases. The right column shows the Markdown
    analysis and a collapsible JSON view. Clicking the button or submitting
    the text box runs ``predict_gradio``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # No theme here: for Gradio 6.0 compatibility the theme is passed to launch().
    with gr.Blocks(title="Nivra ClinicalBERT Inference Endpoint") as demo:
        gr.Markdown("""
# 🏥 Nivra ClinicalBERT - Symptom Text Classifier
Fine-tuned on Indian healthcare data for accurate symptom classification.
### How to use:
1. Enter patient's symptom description in plain language
2. Get AI-powered classification with confidence scores
3. Use the JSON output for API integration
""")
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Patient Symptom Description",
                    placeholder="Enter symptoms here... (e.g., 'I have fever, headache and body pain')",
                    lines=5,
                )
                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of predictions to show",
                    )
                predict_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")
                gr.Markdown("### 📝 Example Cases:")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=text_input,
                    label="Click to try",
                )
            with gr.Column(scale=3):
                output_text = gr.Markdown(label="Analysis Results")
                with gr.Accordion("📄 JSON Output (for API integration)", open=False):
                    json_output = gr.Code(
                        label="JSON Response",
                        language="json",
                        lines=15,
                    )

        # Event handlers. api_name pins the public endpoint name to Gradio's
        # default (the function name), so the API docs below stay accurate.
        predict_btn.click(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output],
            api_name="predict_gradio",
        )
        # Also trigger on Enter key
        text_input.submit(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output],
        )

        # API documentation. The endpoint takes two inputs (text, top_k) and
        # returns two outputs (Markdown summary, JSON string).
        gr.Markdown("""
---
## 🔌 API Usage
### Python Example:
```python
from gradio_client import Client
client = Client("datdevsteve/nivra-clinicalbert-finetuned")
markdown, json_text = client.predict(
    "Your symptom text here",  # symptom description
    5,                         # number of predictions to show
    api_name="/predict_gradio",
)
print(json_text)
```
### cURL Example:
```bash
# 1) Submit the request; the response contains an EVENT_ID
curl -X POST https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/gradio_api/call/predict_gradio \\
  -H "Content-Type: application/json" \\
  -d '{"data": ["fever and headache for 2 days", 5]}'
# 2) Fetch the result stream for that EVENT_ID
curl -N https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/gradio_api/call/predict_gradio/$EVENT_ID
```
""")
    return demo
# =============================================================================
# LAUNCH
# =============================================================================
if __name__ == "__main__":
    demo = create_demo()
    # Gradio 6.0 takes the theme at launch() rather than in gr.Blocks().
    # API endpoints are enabled automatically, so no extra flag is needed.
    launch_kwargs = dict(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,  # do not create a public share link
        theme=gr.themes.Ocean(),
    )
    demo.launch(**launch_kwargs)