"""
Nivra ClinicalBERT Fine-tuned Model - HuggingFace Space
Symptom Classification for Indian Healthcare

Loads a fine-tuned sequence-classification model from the Hugging Face Hub at
import time and serves it through a Gradio UI (Gradio also exposes the wired
event handlers as API endpoints).
"""

import json
import logging
from typing import Any, Dict, List

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# MODEL CONFIGURATION
# =============================================================================

MODEL_NAME = "datdevsteve/nivra-clinicalbert-finetuned"

# Token budget per request; longer symptom descriptions are truncated.
MAX_SEQ_LENGTH = 512

# Load model and tokenizer once at import so every request reuses them.
logger.info("[i] Loading model: %s", MODEL_NAME)
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode (disables dropout)
    logger.info("[i] Model loaded successfully")
except Exception as e:
    # Log with traceback, then fail hard: the app is useless without a model.
    logger.exception("[!] Error loading model: %s", e)
    raise

# Get label names from model config; missing ids fall back to "LABEL_<idx>".
id2label = model.config.id2label if hasattr(model.config, "id2label") else {}


# =============================================================================
# PREDICTION FUNCTIONS
# =============================================================================

def predict_symptoms(text: str, return_all_scores: bool = True, top_k: int = 5) -> Dict[str, Any]:
    """
    Predict diseases by classifying symptom text.

    Args:
        text: Patient's symptom description.
        return_all_scores: If True, return every class probability; otherwise
            only the ``top_k`` highest-scoring classes.
        top_k: Number of top predictions to keep when ``return_all_scores`` is
            False. Coerced to an int and clamped to at least 1.

    Returns:
        On success, a dict with ``predictions`` (list of ``{"label", "score"}``
        sorted by descending score), ``primary_classification``,
        ``confidence``, ``model`` and ``input_text``. On failure, a dict with
        ``error`` (the exception message), empty ``predictions``,
        ``primary_classification="error"`` and ``confidence=0.0``; the
        exception is logged, not raised.
    """
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
            padding=True,
        )

        with torch.no_grad():
            logits = model(**inputs).logits

        # Batch of one: take the first row of class probabilities.
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()

        predictions: List[Dict[str, Any]] = [
            {"label": id2label.get(idx, f"LABEL_{idx}"), "score": float(score)}
            for idx, score in enumerate(probabilities)
        ]
        # Stable sort: ties keep class-index order.
        predictions.sort(key=lambda p: p["score"], reverse=True)

        if not return_all_scores:
            # Clamp so a 0/negative/float top_k can't empty or break the slice.
            predictions = predictions[: max(1, int(top_k))]

        top = predictions[0]
        result = {
            "predictions": predictions,
            "primary_classification": top["label"],
            "confidence": top["score"],
            "model": MODEL_NAME,
            "input_text": text,
        }

        logger.info("[i] Prediction: %s (%.4f)", top["label"], top["score"])
        return result

    except Exception as e:
        # Request boundary: report the failure to the caller instead of crashing the app.
        logger.exception("[!] Prediction error: %s", e)
        return {
            "error": str(e),
            "predictions": [],
            "primary_classification": "error",
            "confidence": 0.0,
        }


def predict_api(text: str) -> str:
    """
    Classify ``text`` and return the full result (all class scores) as a
    pretty-printed JSON string.
    """
    result = predict_symptoms(text, return_all_scores=True, top_k=5)
    return json.dumps(result, indent=2, ensure_ascii=False)


def predict_gradio(text: str, top_k: int = 5) -> tuple:
    """
    Gradio handler: classify ``text`` and format the result for display.

    Args:
        text: Symptom description from the textbox.
        top_k: How many ranked predictions to render. Slider/API clients may
            send a float (e.g. ``5.0``), so it is coerced to an int >= 1.

    Returns:
        ``(markdown, json_string)``. On empty input or a prediction error the
        markdown carries the message and the JSON string is empty.
    """
    if not text or not text.strip():
        return "[i] Please enter symptom description", ""

    # Slicing requires an int; clamp so 0/negative values still show one row.
    top_k = max(1, int(top_k)) if top_k is not None else 5

    result = predict_symptoms(text, return_all_scores=True, top_k=top_k)

    if "error" in result:
        return f"[!] Error: {result['error']}", ""

    # Primary result
    primary = "\n".join(
        [
            "## 🎯 Primary Classification",
            "",
            f"**Condition:** {result['primary_classification']}",
            "",
            f"**Confidence:** {result['confidence']:.2%}",
            "",
            "---",
            "",
            "",
        ]
    )

    # Ranked predictions, each with a 20-character proportional bar.
    parts = ["## 📊 Top Predictions\n\n"]
    for rank, pred in enumerate(result["predictions"][:top_k], start=1):
        bar = "█" * int(pred["score"] * 20)
        # Two trailing spaces = markdown hard line break inside the list item.
        parts.append(f"{rank}. **{pred['label']}**  \n")
        parts.append(f"   {bar} {pred['score']:.2%}\n\n")
    predictions_text = "".join(parts)

    # JSON output for API testing (all class scores, not just top_k).
    json_output = json.dumps(result, indent=2, ensure_ascii=False)

    return primary + predictions_text, json_output


# =============================================================================
# EXAMPLE CASES
# =============================================================================

EXAMPLES = [
    ["Patient presents fever of 102°F, severe headache, body pain and weakness for 3 days"],
    ["Patient presents persistent cough with yellow phlegm, chest congestion, difficulty breathing"],
    ["Patient presents stomach pain, nausea, vomiting and diarrhea since yesterday"],
    ["Patient presents severe headache on one side, sensitivity to light and sound, nausea"],
    ["Patient presents skin rash with red bumps, itching all over body"],
]


# =============================================================================
# GRADIO INTERFACE
# =============================================================================

# Name of the API endpoint registered for the "Analyze" button; the usage
# docs below must stay in sync with it.
API_NAME = "predict"

_INTRO_MD = """# 🏥 Nivra ClinicalBERT - Symptom Text Classifier

Fine-tuned on Indian healthcare data for accurate symptom classification.

### How to use:
1. Enter patient's symptom description in plain language
2. Get AI-powered classification with confidence scores
3. Use the JSON output for API integration
"""

# The endpoint takes two inputs (text, top_k) and returns (markdown, json).
_API_USAGE_MD = f"""---

## 🔌 API Usage

### Python Example (`gradio_client`):
```python
from gradio_client import Client

client = Client("datdevsteve/nivra-clinicalbert-finetuned")
markdown, json_result = client.predict(
    "Your symptom text here",  # symptom description
    5,                         # number of predictions to show
    api_name="/{API_NAME}",
)
print(json_result)
```

### cURL Example:
```bash
curl -s -X POST https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/gradio_api/call/{API_NAME} \\
    -H "Content-Type: application/json" \\
    -d '{{"data": ["fever and headache for 2 days", 5]}}' \\
  | awk -F'"' '{{ print $4 }}' \\
  | read EVENT_ID; curl -N https://datdevsteve-nivra-clinicalbert-finetuned.hf.space/gradio_api/call/{API_NAME}/$EVENT_ID
```
"""


def create_demo():
    """Build and return the Gradio Blocks UI (not launched)."""
    # Theme is passed to launch(), not gr.Blocks, for Gradio 6.0 compatibility.
    with gr.Blocks(title="Nivra ClinicalBERT Inference Endpoint") as demo:
        gr.Markdown(_INTRO_MD)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Patient Symptom Description",
                    placeholder="Enter symptoms here... (e.g., 'I have fever, headache and body pain')",
                    lines=5,
                )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of predictions to show",
                    )

                predict_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

                gr.Markdown("### 📝 Example Cases:")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=text_input,
                    label="Click to try",
                )

            with gr.Column(scale=3):
                output_text = gr.Markdown(label="Analysis Results")

                with gr.Accordion("📄 JSON Output (for API integration)", open=False):
                    json_output = gr.Code(
                        label="JSON Response",
                        language="json",
                        lines=15,
                    )

        # Button click is the documented API endpoint (see _API_USAGE_MD).
        predict_btn.click(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output],
            api_name=API_NAME,
        )

        # Also trigger on Enter key
        text_input.submit(
            fn=predict_gradio,
            inputs=[text_input, top_k_slider],
            outputs=[output_text, json_output],
        )

        # API documentation
        gr.Markdown(_API_USAGE_MD)

    return demo


# =============================================================================
# LAUNCH
# =============================================================================

if __name__ == "__main__":
    demo = create_demo()

    # Launch with Gradio 6.0 compatible parameters (theme lives on launch()).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        theme=gr.themes.Ocean(),
        # API is automatically enabled in Gradio 6.0+
    )