"""Model inference for MelanoScope AI."""
import json
import logging
import time
from typing import Dict, Any, List, Optional, Tuple
import numpy as np
import onnxruntime as ort
from ..config.settings import ModelConfig, DATA_FILE, MODEL_FILE
from .preprocessing import ImagePreprocessor
from .utils import probabilities_to_ints, format_confidence
logger = logging.getLogger(__name__)
class MelanoScopeModel:
    """MelanoScope AI model for skin lesion classification.

    Wraps an ONNX Runtime inference session together with a JSON knowledge
    base of medical-condition text. Construction eagerly loads both; a
    failure in either raises ``RuntimeError`` (with the original cause
    chained) so the application fails fast rather than serving a
    half-initialized model.
    """

    def __init__(self) -> None:
        self.preprocessor = ImagePreprocessor()
        # ONNX Runtime session; populated by _load_model().
        self.session: Optional[ort.InferenceSession] = None
        # Class labels, derived from the keys of the medical-data JSON.
        self.classes: List[str] = []
        # condition name -> dict of descriptive text fields.
        self.medical_data: Dict[str, Any] = {}
        self._load_model()
        self._load_medical_data()
        logger.info("Model initialized with %d classes", len(self.classes))

    def _load_model(self) -> None:
        """Load the ONNX model from ``MODEL_FILE`` into ``self.session``.

        Raises:
            RuntimeError: if the model file is missing or the session
                cannot be created.
        """
        try:
            if not MODEL_FILE.exists():
                raise FileNotFoundError(f"Model file not found: {MODEL_FILE}")
            self.session = ort.InferenceSession(
                str(MODEL_FILE), providers=ModelConfig.ORT_PROVIDERS
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            # Chain the original exception so the root cause survives.
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_medical_data(self) -> None:
        """Load condition metadata from ``DATA_FILE`` and derive class labels.

        Raises:
            RuntimeError: if the JSON file is missing or malformed.
        """
        try:
            with open(DATA_FILE, "r", encoding="utf-8") as f:
                self.medical_data = json.load(f)
            # The JSON keys double as the model's class labels; their order
            # must match the model's output indices.
            self.classes = list(self.medical_data.keys())
            logger.info("Loaded data for %d conditions", len(self.classes))
        except Exception as e:
            logger.error("Failed to load medical data: %s", e)
            raise RuntimeError(f"Medical data loading failed: {e}") from e

    def predict(self, image_input: Any) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Run end-to-end inference on an input image.

        Args:
            image_input: Image accepted by ``ImagePreprocessor.preprocess``
                (presumably a PIL image or numpy array from the UI —
                TODO confirm against the preprocessor), or ``None``.

        Returns:
            8-tuple of (predicted name, confidence string, description,
            symptoms, causes, treatment, probability dataframe, latency
            string). Never raises: any failure is converted into an
            "empty" result whose first element carries the message.
        """
        if image_input is None:
            return self._create_empty_result("Please upload an image and click Analyze.")
        try:
            start_time = time.time()
            input_tensor = self.preprocessor.preprocess(image_input)
            if input_tensor is None:
                return self._create_empty_result("Invalid image")
            prediction_result = self._run_inference(input_tensor)
            if prediction_result is None:
                return self._create_empty_result("Inference error")
            pred_name, confidence, prob_df = prediction_result
            medical_info = self._get_medical_info(pred_name)
            latency_ms = int((time.time() - start_time) * 1000)
            return (
                pred_name, confidence,
                medical_info["description"], medical_info["symptoms"],
                medical_info["causes"], medical_info["treatment"],
                prob_df, f"{latency_ms} ms",
            )
        except Exception as e:
            logger.error("Prediction failed: %s", e)
            return self._create_empty_result(f"Error: {str(e)}")

    def _run_inference(self, input_tensor: np.ndarray) -> Optional[Tuple[str, str, Any]]:
        """Run the ONNX session and post-process logits.

        Returns:
            (predicted class name, formatted confidence, probability
            dataframe), or ``None`` on any failure.
        """
        # Guard explicitly instead of letting the AttributeError from a
        # missing session be swallowed by the except below.
        if self.session is None:
            logger.error("Inference requested before model was loaded")
            return None
        try:
            input_name = self.session.get_inputs()[0].name
            output = self.session.run(None, {input_name: input_tensor})
            logits = output[0].squeeze()
            pred_idx = int(np.argmax(logits))
            pred_name = self.classes[pred_idx]
            # Numerically stable softmax (shift by max before exp).
            exp_logits = np.exp(logits - np.max(logits))
            probabilities = exp_logits / exp_logits.sum()
            confidence = format_confidence(probabilities[pred_idx])
            prob_ints = probabilities_to_ints(probabilities * 100.0)
            prob_df = self._create_probability_dataframe(prob_ints)
            return pred_name, confidence, prob_df
        except Exception as e:
            logger.error("Inference failed: %s", e)
            return None

    def _create_probability_dataframe(self, probabilities: np.ndarray) -> Any:
        """Build a per-class probability dataframe, sorted ascending.

        Falls back to an all-zero dataframe if construction fails, so the
        UI always has something to render.
        """
        import pandas as pd

        try:
            return pd.DataFrame({
                "item": self.classes,
                "probability": probabilities.astype(int),
            }).sort_values("probability", ascending=True)
        except Exception as e:
            logger.error("Error creating dataframe: %s", e)
            return pd.DataFrame({"item": self.classes, "probability": [0] * len(self.classes)})

    def _get_medical_info(self, condition_name: str) -> Dict[str, str]:
        """Return the descriptive text fields for a condition.

        Unknown conditions and missing fields degrade to empty strings.
        """
        condition_data = self.medical_data.get(condition_name, {})
        return {
            "description": condition_data.get("description", ""),
            "symptoms": condition_data.get("symptoms", ""),
            "causes": condition_data.get("causes", ""),
            # NOTE(review): the JSON schema apparently names this field
            # "treatment-1" — confirm against the data file.
            "treatment": condition_data.get("treatment-1", ""),
        }

    def _create_empty_result(self, message: str) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Build the 8-tuple ``predict`` returns on failure.

        The message occupies the prediction slot; all text fields are
        empty and the dataframe is all zeros (or ``None`` if pandas is
        unavailable).
        """
        try:
            import pandas as pd
            empty_df = pd.DataFrame({"item": self.classes, "probability": [0] * len(self.classes)})
        except Exception as e:  # was a bare except: never mask SystemExit/KeyboardInterrupt
            logger.warning("Could not build empty dataframe: %s", e)
            empty_df = None
        return (message, "", "", "", "", "", empty_df, "")

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata for diagnostics/UI display.

        Input shape/type/name are included only when the session is
        loaded and introspection succeeds.
        """
        info = {
            "classes": self.classes,
            "num_classes": len(self.classes),
            "model_file": str(MODEL_FILE),
            "providers": ModelConfig.ORT_PROVIDERS,
        }
        if self.session:
            try:
                input_info = self.session.get_inputs()[0]
                info.update({
                    "input_shape": input_info.shape,
                    "input_type": input_info.type,
                    "input_name": input_info.name,
                })
            except Exception as e:
                logger.warning("Could not get model input info: %s", e)
        return info