|
|
""" |
|
|
Nivra ClinicalBERT Fine-tuned Model - HuggingFace Space |
|
|
Symptom Classification for Indian Healthcare |
|
|
""" |
|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
import json |
|
|
from typing import Dict, List, Any |
|
|
import logging |
|
|
|
|
|
|
|
|
# Configure root logging once at import time so model-load progress and
# per-request predictions show up in the Space's container logs.
logging.basicConfig(level=logging.INFO)

# Module-level logger, standard `__name__` convention.
logger = logging.getLogger(__name__)


# HuggingFace Hub repo id of the fine-tuned sequence-classification checkpoint.
MODEL_NAME = "datdevsteve/nivra-clinicalbert-finetuned"
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"[i] Loading model: {MODEL_NAME}")

# Load the tokenizer and classification model once at import time so every
# request reuses the same in-memory instance. A load failure is fatal — the
# Space cannot serve without the model — so log and re-raise.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode: disables dropout / training-only behavior
    logger.info("[i] Model loaded successfully")
except Exception as e:
    logger.error(f"[!] Error loading model: {e}")
    raise
|
|
|
|
|
|
|
|
|
|
|
# Class-index -> human-readable label map used by predict_symptoms.
# transformers normally normalizes id2label to int keys, but configs
# round-tripped through JSON can carry string keys ("0", "1", ...), which
# would make integer lookups miss and silently fall back to "LABEL_{idx}".
# Normalize every key to int so lookups always hit.
id2label = {int(k): v for k, v in (getattr(model.config, 'id2label', None) or {}).items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_symptoms(text: str, return_all_scores: bool = True, top_k: int = 5) -> Dict[str, Any]:
    """
    Classify a symptom description and return disease predictions.

    Args:
        text: Patient's symptom description
        return_all_scores: If True, return all class probabilities
        top_k: Number of top predictions to return (applied only when
            return_all_scores is False)

    Returns:
        Dictionary with predictions and metadata; on failure, a dictionary
        with an "error" key, empty predictions, and zero confidence.
    """
    try:
        # Tokenize to PyTorch tensors, truncated to the 512-token BERT limit.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )

        # Forward pass without gradient tracking; softmax over the logits of
        # the single batch row yields per-class probabilities.
        with torch.no_grad():
            class_probs = torch.softmax(model(**encoded).logits, dim=-1)[0]

        # Pair every class with its probability, highest score first.
        predictions = sorted(
            (
                {"label": id2label.get(i, f"LABEL_{i}"), "score": float(p)}
                for i, p in enumerate(class_probs)
            ),
            key=lambda entry: entry["score"],
            reverse=True,
        )

        if not return_all_scores:
            predictions = predictions[:top_k]

        best = predictions[0]
        result = {
            "predictions": predictions,
            "primary_classification": best["label"],
            "confidence": best["score"],
            "model": MODEL_NAME,
            "input_text": text,
        }

        logger.info(f"[i] Prediction: {predictions[0]['label']} ({predictions[0]['score']:.4f})")
        return result

    except Exception as e:
        # Boundary handler: the UI/API callers expect an error payload, not a raise.
        logger.error(f"[!] Prediction error: {e}")
        return {
            "error": str(e),
            "predictions": [],
            "primary_classification": "error",
            "confidence": 0.0,
        }
|
|
|
|
|
|
|
|
def predict_api(text: str) -> str:
    """
    API endpoint for symptom text prediction (returns JSON string)
    """
    # Always return the full score distribution so API consumers can
    # post-process the ranking themselves.
    return json.dumps(
        predict_symptoms(text, return_all_scores=True, top_k=5),
        indent=2,
    )
|
|
|
|
|
|
|
|
def predict_gradio(text: str, top_k: int = 5) -> tuple:
    """
    Gradio interface function (returns formatted output).

    Args:
        text: Patient symptom description from the textbox.
        top_k: Number of predictions to render. Gradio sliders can deliver
            this as a float even with step=1, so it is coerced to int
            before being used for list slicing.

    Returns:
        Tuple of (markdown summary, JSON string) for the two output widgets.
        On empty input or prediction error, the JSON string is empty.
    """
    if not text or text.strip() == "":
        return "[i] Please enter symptom description", ""

    # FIX: slicing with a float top_k raises TypeError; clamp to at least 1.
    top_k = max(1, int(top_k))

    result = predict_symptoms(text, return_all_scores=True, top_k=top_k)

    if "error" in result:
        return f"[!] Error: {result['error']}", ""

    # Headline card: best label and its confidence.
    primary = f"""
## 🎯 Primary Classification
**Condition:** {result['primary_classification']}
**Confidence:** {result['confidence']:.2%}
---
"""

    # Ranked list with a simple text progress bar (20 chars == 100%).
    predictions_text = "## 📊 Top Predictions\n\n"
    for i, pred in enumerate(result['predictions'][:top_k], 1):
        bar = "█" * int(pred['score'] * 20)
        predictions_text += f"{i}. **{pred['label']}** \n"
        predictions_text += f" {bar} {pred['score']:.2%}\n\n"

    # Raw payload for API-integration users.
    json_output = json.dumps(result, indent=2)

    return primary + predictions_text, json_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Clickable example inputs for the Gradio UI; each inner list is one value
# for the single textbox input of gr.Examples.
EXAMPLES = [
    ["Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days"],
    ["Patient presents persistent cough with yellow phlegm, chest congestion, difficulty breathing"],
    ["Patient presents stomach pain, nausea, vomiting and diarrhea since yesterday"],
    ["Patient presents severe headache on one side, sensitivity to light and sound, nausea"],
    ["Patient presents skin rash with red bumps, itching all over body"],
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_demo():
    """Create the Gradio Blocks interface.

    Returns:
        The configured gr.Blocks app; the caller is responsible for launch().
    """
    # FIX: the theme must be configured on gr.Blocks — Blocks.launch() has no
    # `theme` parameter, so setting it at launch time raises TypeError.
    with gr.Blocks(title="Nivra ClinicalBERT Inference Endpoint", theme=gr.themes.Ocean()) as demo:

        gr.Markdown("""
        # 🏥 Nivra ClinicalBERT - Symptom Text Classifier

        Fine-tuned on Indian healthcare data for accurate symptom classification.

        ### How to use:
        1. Enter patient's symptom description in plain language
        2. Get AI-powered classification with confidence scores
        3. Use the JSON output for API integration
        """)

        with gr.Row():
            # Left column: free-text input, top-k slider, examples.
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Patient Symptom Description",
                    placeholder="Enter symptoms here... (e.g., 'I have fever, headache and body pain')",
                    lines=5
                )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of predictions to show"
                    )

                predict_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

                gr.Markdown("### 📋 Example Cases:")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=text_input,
                    label="Click to try"
                )

            # Right column: rendered results plus collapsible raw JSON.
            with gr.Column(scale=3):
                output_text = gr.Markdown(label="Analysis Results")

                with gr.Accordion("📄 JSON Output (for API integration)", open=False):
                    json_output = gr.Code(
                        label="JSON Response",
                        language="json",
                        lines=15
                    )

        # Run the classifier both on button click and on Enter in the textbox.
        predict_btn.click(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output]
        )

        text_input.submit(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output]
        )

        gr.Markdown("""
        ---
        ## 📖 API Usage

        ### Python Example:
        ```python
        import requests

        url = "https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/api/predict"
        response = requests.post(url, json={"data": ["Your symptom text here"]})
        result = response.json()
        print(result)
        ```

        ### cURL Example:
        ```bash
        curl -X POST https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/api/predict \\
          -H "Content-Type: application/json" \\
          -d '{"data": ["fever and headache for 2 days"]}'
        ```
        """)

    return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    demo = create_demo()

    # Bind to all interfaces on the standard HF Spaces port.
    # FIX: removed `theme=gr.themes.Ocean()` — Blocks.launch() does not accept
    # a `theme` keyword (it belongs to the gr.Blocks constructor), so passing
    # it here crashed the app with TypeError before it could serve.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|
|