# Last commit: 9bdbbe5 — "update model info" (dcavadia)
"""Model inference for MelanoScope AI."""
import json
import logging
import time
from typing import Dict, Any, List, Optional, Tuple
import numpy as np
import onnxruntime as ort
from ..config.settings import ModelConfig, DATA_FILE, MODEL_FILE
from .preprocessing import ImagePreprocessor
from .utils import probabilities_to_ints, format_confidence
logger = logging.getLogger(__name__)
class MelanoScopeModel:
    """MelanoScope AI model for skin lesion classification.

    Wraps an ONNX Runtime inference session plus a JSON knowledge base of
    medical-condition text, exposing a single ``predict`` entry point that
    returns display-ready strings (and a probability dataframe) for a UI.
    """

    def __init__(self):
        """Load the ONNX model and medical data.

        Raises:
            RuntimeError: if either the model or the medical data fails to load.
        """
        self.preprocessor = ImagePreprocessor()
        # Annotation kept as a string so it is never evaluated at runtime.
        self.session: "Optional[ort.InferenceSession]" = None
        # Class names, in model-output index order (derived from the JSON keys).
        self.classes: List[str] = []
        self.medical_data: Dict[str, Any] = {}
        self._load_model()
        self._load_medical_data()
        logger.info("Model initialized with %d classes", len(self.classes))

    def _load_model(self) -> None:
        """Load the ONNX model from MODEL_FILE into an inference session.

        Raises:
            RuntimeError: if the file is missing or the session cannot be built.
        """
        try:
            if not MODEL_FILE.exists():
                raise FileNotFoundError(f"Model file not found: {MODEL_FILE}")
            self.session = ort.InferenceSession(
                str(MODEL_FILE), providers=ModelConfig.ORT_PROVIDERS
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_medical_data(self) -> None:
        """Load per-condition medical info from DATA_FILE (JSON).

        Raises:
            RuntimeError: if the file is missing or is not valid JSON.
        """
        try:
            with open(DATA_FILE, "r", encoding="utf-8") as f:
                self.medical_data = json.load(f)
            # NOTE(review): the model's output index order is assumed to match
            # the JSON key order — confirm against the training pipeline.
            self.classes = list(self.medical_data.keys())
            logger.info("Loaded data for %d conditions", len(self.classes))
        except Exception as e:
            logger.error("Failed to load medical data: %s", e)
            raise RuntimeError(f"Medical data loading failed: {e}") from e

    def predict(self, image_input: Any) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Perform inference on an input image.

        Args:
            image_input: Image in whatever form ImagePreprocessor accepts;
                None yields a "please upload" placeholder result.

        Returns:
            Tuple of (prediction name, confidence, description, symptoms,
            causes, treatment, probability dataframe, latency string).
            Errors are reported via the first element — this never raises.
        """
        if image_input is None:
            return self._create_empty_result("Please upload an image and click Analyze.")
        try:
            start_time = time.time()
            input_tensor = self.preprocessor.preprocess(image_input)
            if input_tensor is None:
                return self._create_empty_result("Invalid image")
            prediction_result = self._run_inference(input_tensor)
            if prediction_result is None:
                return self._create_empty_result("Inference error")
            pred_name, confidence, prob_df = prediction_result
            medical_info = self._get_medical_info(pred_name)
            latency_ms = int((time.time() - start_time) * 1000)
            return (
                pred_name, confidence,
                medical_info["description"], medical_info["symptoms"],
                medical_info["causes"], medical_info["treatment"],
                prob_df, f"{latency_ms} ms",
            )
        except Exception as e:
            # UI boundary: log the full traceback, always return a displayable tuple.
            logger.exception("Prediction failed")
            return self._create_empty_result(f"Error: {str(e)}")

    def _run_inference(self, input_tensor: np.ndarray) -> Optional[Tuple[str, str, Any]]:
        """Run the ONNX session on a preprocessed tensor.

        Returns:
            (predicted class name, formatted confidence, probability dataframe),
            or None if inference fails for any reason.
        """
        try:
            if self.session is None:
                # Guard the Optional session explicitly instead of letting an
                # opaque AttributeError surface from get_inputs().
                raise RuntimeError("Model session is not initialized")
            input_name = self.session.get_inputs()[0].name
            output = self.session.run(None, {input_name: input_tensor})
            logits = output[0].squeeze()
            pred_idx = int(np.argmax(logits))
            pred_name = self.classes[pred_idx]
            # Numerically stable softmax: shift by the max before exponentiating.
            exp_logits = np.exp(logits - np.max(logits))
            probabilities = exp_logits / exp_logits.sum()
            confidence = format_confidence(probabilities[pred_idx])
            prob_ints = probabilities_to_ints(probabilities * 100.0)
            prob_df = self._create_probability_dataframe(prob_ints)
            return pred_name, confidence, prob_df
        except Exception as e:
            logger.error("Inference failed: %s", e)
            return None

    def _create_probability_dataframe(self, probabilities: np.ndarray) -> Any:
        """Build a (item, probability) dataframe sorted ascending by probability.

        Falls back to an all-zero table on error so the UI still renders.
        """
        import pandas as pd  # single local import; keeps pandas out of module load

        try:
            return pd.DataFrame({
                "item": self.classes,
                "probability": probabilities.astype(int),
            }).sort_values("probability", ascending=True)
        except Exception as e:
            logger.error("Error creating dataframe: %s", e)
            return pd.DataFrame({"item": self.classes, "probability": [0] * len(self.classes)})

    def _get_medical_info(self, condition_name: str) -> Dict[str, str]:
        """Return description/symptoms/causes/treatment for a condition.

        Unknown conditions (or missing fields) yield empty strings.
        """
        condition_data = self.medical_data.get(condition_name, {})
        return {
            "description": condition_data.get("description", ""),
            "symptoms": condition_data.get("symptoms", ""),
            "causes": condition_data.get("causes", ""),
            # The JSON stores treatment under "treatment-1" — NOTE(review):
            # confirm whether further "treatment-N" entries should be surfaced.
            "treatment": condition_data.get("treatment-1", ""),
        }

    def _create_empty_result(self, message: str) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Build a placeholder result tuple carrying *message* in the name slot."""
        try:
            import pandas as pd
            empty_df = pd.DataFrame({"item": self.classes, "probability": [0] * len(self.classes)})
        except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
            empty_df = None
        return (message, "", "", "", "", "", empty_df, "")

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata: classes, file path, providers, and ONNX input info."""
        info = {
            "classes": self.classes,
            "num_classes": len(self.classes),
            "model_file": str(MODEL_FILE),
            "providers": ModelConfig.ORT_PROVIDERS,
        }
        if self.session:
            try:
                input_info = self.session.get_inputs()[0]
                info.update({
                    "input_shape": input_info.shape,
                    "input_type": input_info.type,
                    "input_name": input_info.name,
                })
            except Exception as e:
                logger.warning("Could not get model input info: %s", e)
        return info