datdevsteve's picture
catastrophic refactor of "/predict" to "predict" in api_name
2e79116 verified
"""
Nivra Medical Image Classifier - HuggingFace Space
Medical Image Classification for Indian Healthcare
"""
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
import json
from typing import Dict, Any
import logging
import numpy as np
import requests
from io import BytesIO
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# =============================================================================
# MODEL CONFIGURATION
# =============================================================================
# Repository id passed to ``from_pretrained`` for both the image processor and
# the classification model. Loading happens once, when this module is imported.
MODEL_NAME = "datdevsteve/dinov2-nivra-finetuned"
logger.info("[i] Loading model: %s", MODEL_NAME)
try:
    image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode
except Exception as exc:
    # Record the failure, then propagate it: nothing below can work without
    # a loaded model.
    logger.error("[!] Error loading model: %s", exc)
    raise
else:
    logger.info("[i] Model loaded successfully")
# Class index -> label name from the model config; an empty dict when the
# config has no such attribute, so lookups fall back to their defaults.
id2label = getattr(model.config, "id2label", {})
# =============================================================================
# CORE PREDICTION LOGIC
# =============================================================================
def predict_medical_image(image: Image.Image, top_k: int = 5) -> Dict[str, Any]:
    """Classify ``image`` with the module-level model and return the top labels.

    Args:
        image: Image to classify; handed directly to ``image_processor``.
        top_k: How many of the highest-scoring labels to return. Coerced to
            ``int`` (sliders and API clients may send e.g. ``5.0``) and
            clamped to at least 1 so a primary classification always exists.

    Returns:
        On success, a dict with ``predictions`` (list of ``{"label", "score"}``
        dicts sorted by descending score), ``primary_classification`` and
        ``confidence`` (label and score of the best prediction) and ``model``.
        On any failure, a dict with ``error`` (the exception message), an empty
        ``predictions`` list, ``primary_classification="error"``,
        ``confidence=0.0`` and ``model``. Exceptions are logged, never raised.
    """
    try:
        # Slice bounds must be ints; clamp so predictions[0] below exists.
        k = max(1, int(top_k))

        inputs = image_processor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = model(**inputs)
        # Softmax over the last axis; [0] selects the single image in the batch.
        probabilities = torch.softmax(outputs.logits, dim=-1)[0]

        # sorted() is stable even with reverse=True, so tied scores keep
        # class-index order.
        predictions = sorted(
            (
                {"label": id2label.get(idx, f"LABEL_{idx}"), "score": float(score)}
                for idx, score in enumerate(probabilities.tolist())
            ),
            key=lambda pred: pred["score"],
            reverse=True,
        )[:k]

        best = predictions[0]
        result = {
            "predictions": predictions,
            "primary_classification": best["label"],
            "confidence": best["score"],
            "model": MODEL_NAME,
        }
        logger.info(
            "[i] Prediction: %s (%.4f)",
            result["primary_classification"],
            result["confidence"],
        )
        return result
    except Exception as e:
        # Boundary for the UI/API: report the failure in the result instead of
        # raising; logger.exception keeps the traceback in the logs.
        logger.exception("[!] Prediction error: %s", e)
        return {
            "error": str(e),
            "predictions": [],
            "primary_classification": "error",
            "confidence": 0.0,
            "model": MODEL_NAME,
        }
# =============================================================================
# INPUT HANDLING (URL + Upload)
# =============================================================================
def load_image_from_input(image_upload, image_url):
    """Return an RGB ``PIL.Image`` from whichever input the caller supplied.

    Handles both:
    - Uploaded image (manual UI testing): a PIL image, or a NumPy array that
      is converted with ``Image.fromarray``. An upload takes priority over
      the URL.
    - Supabase image URL (agent usage): downloaded with ``requests.get``
      (10 second timeout) and decoded with PIL.

    Args:
        image_upload: PIL image, NumPy array, or ``None``.
        image_url: URL string, or empty/``None``. Surrounding whitespace
            (e.g. a trailing newline from a pasted link) is ignored.

    Returns:
        The image converted to RGB mode.

    Raises:
        ValueError: If the URL cannot be fetched or decoded (the underlying
            exception is chained as ``__cause__``), or if no input was given.
    """
    if image_upload is not None:
        if isinstance(image_upload, np.ndarray):
            return Image.fromarray(image_upload).convert("RGB")
        return image_upload.convert("RGB")

    # Treat a whitespace-only value as "no URL" instead of firing a request
    # that can only fail with a confusing error.
    if isinstance(image_url, str):
        image_url = image_url.strip()

    if image_url:
        try:
            response = requests.get(image_url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to load image from URL: {e}") from e

    raise ValueError("No image provided.")
# =============================================================================
# GRADIO PREDICTION WRAPPER
# =============================================================================
def predict_gradio(image_upload, image_url, top_k):
    """Gradio click handler: classify the supplied image and format the result.

    Args:
        image_upload: Image from the upload component, or ``None``.
        image_url: Image URL text; used when nothing was uploaded.
        top_k: Number of predictions to show.

    Returns:
        A ``(markdown, json_text)`` pair. On failure the markdown is a single
        ``"[!] ..."`` error line and the JSON text is empty.
    """
    try:
        image = load_image_from_input(image_upload, image_url)
    except Exception as e:
        return f"[!] {e}", ""

    result = predict_medical_image(image, top_k=top_k)
    if "error" in result:
        return f"[!] {result['error']}", ""

    # Blank lines matter in Markdown. Without one, "---" directly under a
    # text line becomes a setext heading underline instead of a horizontal
    # rule. Adjacent lines would also merge into one paragraph.
    primary = (
        "\n## 🎯 Primary Classification\n\n"
        f"**Condition:** {result['primary_classification']}\n\n"
        f"**Confidence:** {result['confidence']:.2%}\n\n"
        "---\n"
    )

    ranked = "".join(
        f"{rank}. **{pred['label']}** — {pred['score']:.2%}\n\n"
        for rank, pred in enumerate(result["predictions"], 1)
    )
    predictions_text = "## 📊 Top Predictions\n\n" + ranked

    # predictions_text ends with a blank line, so this "---" is a rule.
    disclaimer = (
        "\n---\n"
        "⚠️ **Medical Disclaimer:** This AI tool is for preliminary screening only.\n"
        "Consult a licensed healthcare professional for confirmed diagnosis.\n"
    )

    json_output = json.dumps(result, indent=2)
    return primary + predictions_text + disclaimer, json_output
# =============================================================================
# GRADIO INTERFACE
# =============================================================================
def create_demo():
    """Build the Gradio Blocks UI for the classifier.

    The left column collects an uploaded image or an image URL plus the number
    of predictions to show. The right column displays the Markdown report and
    the raw JSON result produced by ``predict_gradio``. The click handler is
    registered with ``api_name="predict"`` so ``gradio_client`` callers can
    address the endpoint by name.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Nivra Medical Image Classifier") as demo:
        gr.Markdown("# 🏥 Nivra Medical Image Classifier")
        with gr.Row():
            # Inputs.
            with gr.Column(scale=2):
                image_upload = gr.Image(
                    label="Upload Medical Image (Manual Testing)",
                    type="pil",
                    sources=["upload", "clipboard"],
                    height=350,
                )
                image_url = gr.Textbox(
                    label="Or Paste Medical Image URL (Supabase Public URL)",
                    placeholder="https://your-supabase-url/image.jpg",
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Number of predictions",
                )
                predict_btn = gr.Button("🔍 Analyze Image", variant="primary")
            # Outputs: human-readable report and raw JSON.
            with gr.Column(scale=3):
                output_text = gr.Markdown()
                json_output = gr.Code(language="json")
        predict_btn.click(
            fn=predict_gradio,
            inputs=[image_upload, image_url, top_k_slider],
            outputs=[output_text, json_output],
            api_name="predict",  # IMPORTANT for gradio_client
        )
    return demo
# =============================================================================
# LAUNCH
# =============================================================================
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. No public share link, errors
    # shown in the UI, server-side rendering disabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "ssr_mode": False,
    }
    demo = create_demo()
    demo.launch(**launch_options)