rehaan
Refactor model loading logic and improve error handling in app.py; add new .codex file
"""
Nivra Medical Image Classifier - HuggingFace Space

AI-powered dermatology image classification for rural Indian healthcare.
"""
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| import torch | |
| from PIL import Image | |
| import json | |
| from typing import Dict, Any | |
| import logging | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| from threading import Lock | |
| # ============================================================================= | |
| # LOGGING | |
| # ============================================================================= | |
# Configure the root logger at INFO level, then create the logger used by
# every function in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| # ============================================================================= | |
| # MODEL CONFIGURATION | |
| # ============================================================================= | |
# Identifier handed to ``from_pretrained`` for both the processor and the model.
MODEL_NAME = "datdevsteve/dinov2-nivra-finetuned"
# Populated lazily by ensure_model_loaded(); both stay None until a load succeeds.
image_processor = None
model = None
# Class-index -> label mapping, copied from ``model.config.id2label`` when present.
id2label = {}
# Held by ensure_model_loaded() while the model is being loaded.
model_load_lock = Lock()
def ensure_model_loaded() -> None:
    """Load the image processor and classification model on first use.

    The module-level ``image_processor``, ``model`` and ``id2label`` globals
    are filled in the first time this is called; later calls return
    immediately. Loading happens under ``model_load_lock`` so that concurrent
    first requests trigger a single load.

    Everything is loaded into local variables first and only published to the
    globals once the whole load succeeded. ``model`` is assigned last, so the
    lock-free fast path can never observe a model without its processor and
    label map, and a failed load leaves all globals untouched for a clean
    retry on the next call.

    Raises:
        RuntimeError: If the processor or model cannot be loaded. The
            original exception is attached as ``__cause__``.
    """
    global image_processor, model, id2label

    # Fast path: no lock needed once everything has been published.
    if image_processor is not None and model is not None:
        return

    with model_load_lock:
        # Another thread may have completed the load while we were waiting.
        if image_processor is not None and model is not None:
            return

        logger.info(f"[i] Loading model: {MODEL_NAME}")
        try:
            loaded_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
            loaded_model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
            loaded_model.eval()
            # A missing or None mapping falls back to an empty dict, in which
            # case callers use their own "LABEL_<idx>" placeholder names.
            loaded_labels = getattr(loaded_model.config, "id2label", None) or {}
        except Exception as e:
            logger.error(f"[!] Error loading model: {e}")
            raise RuntimeError(
                f"Unable to load model '{MODEL_NAME}'. Check Space network access "
                f"and repository visibility. Original error: {e}"
            ) from e

        # Publish in dependency order; ``model`` last (see docstring).
        id2label = loaded_labels
        image_processor = loaded_processor
        model = loaded_model
        logger.info("[i] Model loaded successfully")
| # ============================================================================= | |
| # CONFIDENCE CLASSIFICATION | |
| # ============================================================================= | |
def classify_confidence(score: float) -> str:
    """Map a probability to a coarse confidence bucket.

    Returns ``"high"`` for scores of at least 0.80, ``"moderate"`` for scores
    of at least 0.50, and ``"low"`` for everything else.
    """
    # Ordered from the strictest threshold down; the first match wins.
    tiers = ((0.80, "high"), (0.50, "moderate"))
    return next((name for floor, name in tiers if score >= floor), "low")
| # ============================================================================= | |
| # CORE PREDICTION LOGIC | |
| # ============================================================================= | |
def predict_medical_image(image: Image.Image, top_k: int = 5) -> Dict[str, Any]:
    """Classify one image and return a structured prediction.

    Args:
        image: The image to classify; it is passed to the image processor
            unchanged.
        top_k: How many of the highest-probability classes to include in
            ``top_predictions``. Integer-valued floats (e.g. ``5.0``) are
            accepted, values below 1 are raised to 1 so that a primary
            prediction always exists, and ``None`` returns every class.

    Returns:
        On success, a dict with ``primary_condition``, ``confidence_score``,
        ``confidence_level``, ``top_predictions`` (list of
        ``{"label", "probability"}`` dicts, best first),
        ``requires_medical_attention`` and ``model``.
        On any failure nothing is raised; the dict additionally carries an
        ``error`` message, ``primary_condition`` is ``"error"`` and
        ``top_predictions`` is empty.
    """
    try:
        ensure_model_loaded()

        inputs = image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1)[0]

        predictions = [
            {
                "label": id2label.get(idx, f"LABEL_{idx}"),
                "probability": float(prob),
            }
            for idx, prob in enumerate(probabilities)
        ]
        predictions.sort(key=lambda p: p["probability"], reverse=True)

        # UI sliders and remote clients may send 5.0 or 0; a float breaks
        # slicing and an empty slice would leave no primary prediction.
        limit = None if top_k is None else max(1, int(top_k))
        top_predictions = predictions[:limit]
        primary = top_predictions[0]

        structured_result = {
            "primary_condition": primary["label"],
            "confidence_score": primary["probability"],
            "confidence_level": classify_confidence(primary["probability"]),
            "top_predictions": top_predictions,
            # Flagged when the top class scores below 0.60.
            "requires_medical_attention": primary["probability"] < 0.60,
            "model": MODEL_NAME,
        }
        logger.info(
            f"[i] Prediction: {structured_result['primary_condition']} "
            f"({structured_result['confidence_score']:.4f})"
        )
        return structured_result
    except Exception as e:
        # The exception is converted into a result payload, so log the
        # traceback here or it is lost.
        logger.exception(f"[!] Prediction error: {e}")
        return {
            "error": str(e),
            "primary_condition": "error",
            "confidence_score": 0.0,
            "confidence_level": "low",
            "top_predictions": [],
            "requires_medical_attention": True,
            "model": MODEL_NAME,
        }
| # ============================================================================= | |
| # INPUT HANDLING (UPLOAD + URL) | |
| # ============================================================================= | |
def load_image_from_input(image_upload, image_url):
    """Return an RGB PIL image from an upload or, failing that, a URL.

    Supports:
    - Uploaded image (manual UI testing)
    - Supabase public image URL (agent usage)

    Args:
        image_upload: A PIL image or a NumPy array, or ``None`` when nothing
            was uploaded. Takes precedence over ``image_url``.
        image_url: URL of an image to download, or an empty value. Leading
            and trailing whitespace is ignored.

    Returns:
        The image converted to RGB mode.

    Raises:
        ValueError: If neither input is provided, or the URL cannot be
            fetched or decoded (the underlying error is chained).
    """
    # Manual upload
    if image_upload is not None:
        if isinstance(image_upload, np.ndarray):
            return Image.fromarray(image_upload).convert("RGB")
        return image_upload.convert("RGB")

    # URL-based input. Pasted URLs often carry stray whitespace or a newline.
    url = image_url.strip() if isinstance(image_url, str) else image_url
    if url:
        # NOTE(review): the URL is caller-supplied and fetched server-side
        # with no host allow-list or response-size cap; consider restricting
        # it to the expected storage host.
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to load image from URL: {e}") from e

    raise ValueError("No image provided.")
| # ============================================================================= | |
| # GRADIO WRAPPER | |
| # ============================================================================= | |
def _render_result_markdown(result):
    """Format a successful prediction result as a Markdown report."""
    ranked = "\n\n".join(
        f"{rank}. **{pred['label']}** — {pred['probability']:.2%}"
        for rank, pred in enumerate(result["top_predictions"], 1)
    )
    sections = [
        "## 🎯 Primary Classification",
        f"**Condition:** {result['primary_condition']}",
        f"**Confidence Score:** {result['confidence_score']:.2%}",
        f"**Confidence Level:** {result['confidence_level'].upper()}",
        "---",
        "## 📊 Top Predictions",
        ranked,
        "---",
        "⚠️ **Medical Disclaimer:** This AI tool is for preliminary screening only.\n"
        "Consult a licensed healthcare professional for confirmed diagnosis.",
    ]
    # Blank lines between sections are required: without them Markdown merges
    # the bold fields into one paragraph and reads "---" directly under text
    # as a heading underline instead of a horizontal rule.
    return "\n\n".join(sections) + "\n"


def predict_gradio(image_upload, image_url, top_k):
    """Gradio callback: classify an uploaded or URL-referenced image.

    Args:
        image_upload: Uploaded image, or ``None``.
        image_url: Image URL, used when no upload is given.
        top_k: Number of predictions to list.

    Returns:
        A ``(markdown, json_text)`` tuple. On failure the Markdown is a
        ``"[!] <message>"`` line and the JSON text is empty.
    """
    try:
        image = load_image_from_input(image_upload, image_url)
    except Exception as e:
        return f"[!] {str(e)}", ""

    result = predict_medical_image(image, top_k=top_k)
    if "error" in result:
        return f"[!] {result['error']}", ""

    return _render_result_markdown(result), json.dumps(result, indent=2)
| # ============================================================================= | |
| # GRADIO INTERFACE | |
| # ============================================================================= | |
def create_demo():
    """Build the Gradio Blocks UI and wire the analyze button to predict_gradio.

    The layout is a single row with an input column (image upload, URL
    textbox, prediction-count slider, analyze button) and an output column
    (Markdown report plus raw JSON).

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    with gr.Blocks(title="Nivra Medical Image Classifier") as demo:
        gr.Markdown("# 🏥 Nivra Medical Image Classifier")
        gr.Markdown(
            "Upload an image or provide a public URL. "
            "The model is loaded on the first prediction so the Space can start quickly."
        )
        with gr.Row():
            # Input column.
            with gr.Column(scale=2):
                image_upload = gr.Image(
                    label="Upload Medical Image (Manual Testing)",
                    type="pil",
                    sources=["upload", "clipboard"],
                    height=350
                )
                image_url = gr.Textbox(
                    label="Or Paste Medical Image URL (Supabase Public URL)",
                    placeholder="https://your-supabase-url/image.jpg"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of predictions"
                )
                predict_btn = gr.Button("🔍 Analyze Image", variant="primary")
            # Output column: rendered report and the raw JSON result.
            with gr.Column(scale=3):
                output_text = gr.Markdown()
                json_output = gr.Code(language="json")
        # Input order must match predict_gradio's parameter order, and the
        # two outputs match its (markdown, json_text) return tuple.
        predict_btn.click(
            fn=predict_gradio,
            inputs=[image_upload, image_url, top_k_slider],
            outputs=[output_text, json_output],
            api_name="predict"  # CRITICAL for gradio_client
        )
    return demo
| # ============================================================================= | |
| # LAUNCH | |
| # ============================================================================= | |
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False
    )