# HuggingFace Spaces page residue (Space status: Sleeping) — not part of the app code.
| """ | |
| Nivra Medical Image Classifier - HuggingFace Space | |
| Medical Image Classification for Indian Healthcare | |
| """ | |
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| import torch | |
| from PIL import Image | |
| import json | |
| from typing import Dict, Any | |
| import logging | |
| import numpy as np | |
| import requests | |
| from io import BytesIO | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================= | |
| # MODEL CONFIGURATION | |
| # ============================================================================= | |
| MODEL_NAME = "datdevsteve/dinov2-nivra-finetuned" | |
| logger.info(f"[i] Loading model: {MODEL_NAME}") | |
| try: | |
| image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME) | |
| model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) | |
| model.eval() | |
| logger.info("[i] Model loaded successfully") | |
| except Exception as e: | |
| logger.error(f"[!] Error loading model: {e}") | |
| raise | |
| id2label = model.config.id2label if hasattr(model.config, "id2label") else {} | |
| # ============================================================================= | |
| # CORE PREDICTION LOGIC | |
| # ============================================================================= | |
| def predict_medical_image(image: Image.Image, top_k: int = 5) -> Dict[str, Any]: | |
| try: | |
| inputs = image_processor(images=image, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| probabilities = torch.softmax(outputs.logits, dim=-1)[0] | |
| predictions = [] | |
| for idx, prob in enumerate(probabilities): | |
| predictions.append({ | |
| "label": id2label.get(idx, f"LABEL_{idx}"), | |
| "score": float(prob) | |
| }) | |
| predictions = sorted(predictions, key=lambda x: x["score"], reverse=True) | |
| predictions = predictions[:top_k] | |
| result = { | |
| "predictions": predictions, | |
| "primary_classification": predictions[0]["label"], | |
| "confidence": predictions[0]["score"], | |
| "model": MODEL_NAME | |
| } | |
| logger.info(f"[i] Prediction: {result['primary_classification']} ({result['confidence']:.4f})") | |
| return result | |
| except Exception as e: | |
| logger.error(f"[!] Prediction error: {e}") | |
| return { | |
| "error": str(e), | |
| "predictions": [], | |
| "primary_classification": "error", | |
| "confidence": 0.0 | |
| } | |
| # ============================================================================= | |
| # INPUT HANDLING (URL + Upload) | |
| # ============================================================================= | |
| def load_image_from_input(image_upload, image_url): | |
| """ | |
| Handles both: | |
| - Uploaded image (manual UI testing) | |
| - Supabase image URL (agent usage) | |
| """ | |
| if image_upload is not None: | |
| if isinstance(image_upload, np.ndarray): | |
| return Image.fromarray(image_upload).convert("RGB") | |
| return image_upload.convert("RGB") | |
| if image_url: | |
| try: | |
| response = requests.get(image_url, timeout=10) | |
| response.raise_for_status() | |
| return Image.open(BytesIO(response.content)).convert("RGB") | |
| except Exception as e: | |
| raise ValueError(f"Failed to load image from URL: {e}") | |
| raise ValueError("No image provided.") | |
| # ============================================================================= | |
| # GRADIO PREDICTION WRAPPER | |
| # ============================================================================= | |
| def predict_gradio(image_upload, image_url, top_k): | |
| try: | |
| image = load_image_from_input(image_upload, image_url) | |
| except Exception as e: | |
| return f"[!] {str(e)}", "" | |
| result = predict_medical_image(image, top_k=top_k) | |
| if "error" in result: | |
| return f"[!] {result['error']}", "" | |
| primary = f""" | |
| ## π― Primary Classification | |
| **Condition:** {result['primary_classification']} | |
| **Confidence:** {result['confidence']:.2%} | |
| --- | |
| """ | |
| predictions_text = "## π Top Predictions\n\n" | |
| for i, pred in enumerate(result["predictions"], 1): | |
| predictions_text += f"{i}. **{pred['label']}** β {pred['score']:.2%}\n\n" | |
| disclaimer = """ | |
| --- | |
| β οΈ **Medical Disclaimer:** This AI tool is for preliminary screening only. | |
| Consult a licensed healthcare professional for confirmed diagnosis. | |
| """ | |
| json_output = json.dumps(result, indent=2) | |
| return primary + predictions_text + disclaimer, json_output | |
| # ============================================================================= | |
| # GRADIO INTERFACE | |
| # ============================================================================= | |
| def create_demo(): | |
| with gr.Blocks(title="Nivra Medical Image Classifier") as demo: | |
| gr.Markdown("# π₯ Nivra Medical Image Classifier") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| image_upload = gr.Image( | |
| label="Upload Medical Image (Manual Testing)", | |
| type="pil", | |
| sources=["upload", "clipboard"], | |
| height=350 | |
| ) | |
| image_url = gr.Textbox( | |
| label="Or Paste Medical Image URL (Supabase Public URL)", | |
| placeholder="https://your-supabase-url/image.jpg" | |
| ) | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=10, | |
| value=5, | |
| step=1, | |
| label="Number of predictions" | |
| ) | |
| predict_btn = gr.Button("π Analyze Image", variant="primary") | |
| with gr.Column(scale=3): | |
| output_text = gr.Markdown() | |
| json_output = gr.Code(language="json") | |
| predict_btn.click( | |
| fn=predict_gradio, | |
| inputs=[image_upload, image_url, top_k_slider], | |
| outputs=[output_text, json_output], | |
| api_name="predict" # IMPORTANT for gradio_client | |
| ) | |
| return demo | |
| # ============================================================================= | |
| # LAUNCH | |
| # ============================================================================= | |
| if __name__ == "__main__": | |
| demo = create_demo() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False, | |
| show_error=True, | |
| ssr_mode=False | |
| ) | |