# NOTE: the three lines below are stray VCS metadata (author, commit message,
# commit hash) accidentally pasted above the module docstring; kept as comments
# so the file remains valid Python.
# rehaan
# Refactor model loading logic and improve error handling in app.py; add new .codex file
# b8132d8
"""
Nivra Medical Image Classifier - HuggingFace Space
AI-powered dermatology image classification for rural Indian healthcare.
"""
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
import json
from typing import Dict, Any
import logging
import numpy as np
import requests
from io import BytesIO
from threading import Lock
# =============================================================================
# LOGGING
# =============================================================================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
# HuggingFace Hub repo id of the fine-tuned image-classification model.
MODEL_NAME = "datdevsteve/dinov2-nivra-finetuned"
# Lazily-populated globals: filled in by ensure_model_loaded() on first use
# so the Space can start serving before the model download completes.
image_processor = None  # transformers AutoImageProcessor, or None until loaded
model = None  # transformers AutoModelForImageClassification, or None until loaded
id2label = {}  # class-index -> label-name mapping taken from the model config
# Serialises the one-time model load across concurrent requests.
model_load_lock = Lock()
def ensure_model_loaded() -> None:
    """Lazily load the HF image processor, model, and label mapping.

    Uses double-checked locking so that concurrent requests trigger at most
    one download/initialisation. FIX: on failure, any partially-assigned
    globals (e.g. processor loaded but model download failed) are reset to
    their unloaded state, so a later call retries a clean load instead of
    leaving the module half-initialised.

    Raises:
        RuntimeError: if the processor or model cannot be loaded; the
            original exception is chained as the cause.
    """
    global image_processor, model, id2label
    # Fast path: both pieces already loaded — skip the lock entirely.
    if image_processor is not None and model is not None:
        return
    with model_load_lock:
        # Re-check under the lock: another thread may have just finished.
        if image_processor is not None and model is not None:
            return
        logger.info("[i] Loading model: %s", MODEL_NAME)
        try:
            image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
            model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
            model.eval()  # inference mode (disables dropout etc.)
            # getattr with a default is the idiomatic hasattr-then-read.
            id2label = getattr(model.config, "id2label", {})
            logger.info("[i] Model loaded successfully")
        except Exception as e:
            # Roll back partial state so the next call retries from scratch.
            image_processor = None
            model = None
            id2label = {}
            logger.error("[!] Error loading model: %s", e)
            raise RuntimeError(
                f"Unable to load model '{MODEL_NAME}'. Check Space network access "
                f"and repository visibility. Original error: {e}"
            ) from e
# =============================================================================
# CONFIDENCE CLASSIFICATION
# =============================================================================
def classify_confidence(score: float) -> str:
    """Map a probability to a coarse confidence bucket.

    Args:
        score: predicted probability, expected in [0, 1].

    Returns:
        "high" for score >= 0.80, "moderate" for score >= 0.50,
        otherwise "low".
    """
    # Thresholds checked high-to-low; the first one met wins.
    for threshold, level in ((0.80, "high"), (0.50, "moderate")):
        if score >= threshold:
            return level
    return "low"
# =============================================================================
# CORE PREDICTION LOGIC
# =============================================================================
def predict_medical_image(image: Image.Image, top_k: int = 5) -> Dict[str, Any]:
    """Classify a medical image and return a structured result dict.

    Args:
        image: input PIL image.
        top_k: number of highest-probability classes to report.

    Returns:
        Dict with the primary condition, its confidence score/level, the
        top-k predictions, a medical-attention flag (confidence < 0.60),
        and the model name. On any failure a fallback error dict is
        returned with ``requires_medical_attention`` set to True.
    """
    try:
        ensure_model_loaded()
        batch = image_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = model(**batch).logits
        probs = torch.softmax(logits, dim=-1)[0]
        # Rank every class by probability, then keep the top_k entries.
        ranked = sorted(
            (
                {"label": id2label.get(i, f"LABEL_{i}"), "probability": float(p)}
                for i, p in enumerate(probs)
            ),
            key=lambda entry: entry["probability"],
            reverse=True,
        )[:top_k]
        best = ranked[0]
        structured_result = {
            "primary_condition": best["label"],
            "confidence_score": best["probability"],
            "confidence_level": classify_confidence(best["probability"]),
            "top_predictions": ranked,
            "requires_medical_attention": best["probability"] < 0.60,
            "model": MODEL_NAME
        }
        logger.info(
            f"[i] Prediction: {structured_result['primary_condition']} "
            f"({structured_result['confidence_score']:.4f})"
        )
        return structured_result
    except Exception as e:
        logger.error(f"[!] Prediction error: {e}")
        return {
            "error": str(e),
            "primary_condition": "error",
            "confidence_score": 0.0,
            "confidence_level": "low",
            "top_predictions": [],
            "requires_medical_attention": True
        }
# =============================================================================
# INPUT HANDLING (UPLOAD + URL)
# =============================================================================
def load_image_from_input(image_upload, image_url):
    """Resolve user input into an RGB PIL image.

    Supports:
      - Uploaded image (manual UI testing)
      - Supabase public image URL (agent usage)

    Raises:
        ValueError: if the URL fetch fails or neither input is provided.
    """
    # The manual upload takes precedence over the URL field.
    if image_upload is not None:
        pil_img = (
            Image.fromarray(image_upload)
            if isinstance(image_upload, np.ndarray)
            else image_upload
        )
        return pil_img.convert("RGB")
    # Fall back to fetching the image over HTTP.
    if image_url:
        try:
            resp = requests.get(image_url, timeout=10)
            resp.raise_for_status()
            return Image.open(BytesIO(resp.content)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to load image from URL: {e}")
    raise ValueError("No image provided.")
# =============================================================================
# GRADIO WRAPPER
# =============================================================================
def predict_gradio(image_upload, image_url, top_k):
    """Gradio callback: classify the input and format the two output panes.

    Returns:
        (markdown_summary, json_string); on failure the markdown carries
        the error text and the JSON pane is left empty.
    """
    try:
        image = load_image_from_input(image_upload, image_url)
    except Exception as e:
        return f"[!] {str(e)}", ""
    result = predict_medical_image(image, top_k=top_k)
    if "error" in result:
        return f"[!] {result['error']}", ""
    # Assemble the markdown report as a list of sections, joined once.
    sections = [
        f"""
## 🎯 Primary Classification
**Condition:** {result['primary_condition']}
**Confidence Score:** {result['confidence_score']:.2%}
**Confidence Level:** {result['confidence_level'].upper()}
---
""",
        "## 📊 Top Predictions\n\n",
    ]
    for rank, pred in enumerate(result["top_predictions"], 1):
        sections.append(
            f"{rank}. **{pred['label']}** — {pred['probability']:.2%}\n\n"
        )
    sections.append(
        """
---
⚠️ **Medical Disclaimer:** This AI tool is for preliminary screening only.
Consult a licensed healthcare professional for confirmed diagnosis.
"""
    )
    return "".join(sections), json.dumps(result, indent=2)
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
def create_demo():
    """Build the Gradio Blocks UI and wire it to ``predict_gradio``."""
    with gr.Blocks(title="Nivra Medical Image Classifier") as demo:
        gr.Markdown("# 🏥 Nivra Medical Image Classifier")
        gr.Markdown(
            "Upload an image or provide a public URL. "
            "The model is loaded on the first prediction so the Space can start quickly."
        )
        with gr.Row():
            # Left column: inputs + trigger button.
            with gr.Column(scale=2):
                upload_input = gr.Image(
                    label="Upload Medical Image (Manual Testing)",
                    type="pil",
                    sources=["upload", "clipboard"],
                    height=350
                )
                url_input = gr.Textbox(
                    label="Or Paste Medical Image URL (Supabase Public URL)",
                    placeholder="https://your-supabase-url/image.jpg"
                )
                k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of predictions"
                )
                analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
            # Right column: human-readable summary + raw JSON.
            with gr.Column(scale=3):
                summary_md = gr.Markdown()
                result_json = gr.Code(language="json")
        analyze_btn.click(
            fn=predict_gradio,
            inputs=[upload_input, url_input, k_slider],
            outputs=[summary_md, result_json],
            api_name="predict"  # CRITICAL for gradio_client
        )
    return demo
# =============================================================================
# LAUNCH
# =============================================================================
if __name__ == "__main__":
    demo = create_demo()
    # Bind to all interfaces on 7860, the standard HuggingFace Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,  # surface tracebacks in the UI for easier debugging
        ssr_mode=False
    )