"""Model inference for MelanoScope AI."""
import json
import logging
import time
from typing import Dict, Any, List, Optional, Tuple

import numpy as np
import onnxruntime as ort

from ..config.settings import ModelConfig, DATA_FILE, MODEL_FILE
from .preprocessing import ImagePreprocessor
from .utils import probabilities_to_ints, format_confidence

logger = logging.getLogger(__name__)

# Shape of every value returned by ``MelanoScopeModel.predict``:
# (prediction or message, confidence, description, symptoms, causes,
#  treatment, probability table, latency string).
PredictionResult = Tuple[str, str, str, str, str, str, Any, str]


class MelanoScopeModel:
    """MelanoScope AI model for skin lesion classification.

    Combines an ONNX Runtime session loaded from ``MODEL_FILE`` with the
    per-condition reference data loaded from ``DATA_FILE``. The i-th class
    name is the i-th top-level key of the JSON data file and is used as the
    label for the i-th model output logit.
    NOTE(review): this assumes the JSON key order matches the model's output
    order — confirm against the training label map.
    """

    def __init__(self) -> None:
        """Load the ONNX model and the medical reference data.

        Raises:
            RuntimeError: If either the model or the medical data cannot be
                loaded.
        """
        self.preprocessor = ImagePreprocessor()
        self.session: Optional[ort.InferenceSession] = None
        self.classes: List[str] = []
        self.medical_data: Dict[str, Any] = {}
        self._load_model()
        self._load_medical_data()
        logger.info("Model initialized with %d classes", len(self.classes))

    def _load_model(self) -> None:
        """Create the ONNX Runtime session for ``MODEL_FILE``.

        The session uses the execution providers in
        ``ModelConfig.ORT_PROVIDERS``.

        Raises:
            RuntimeError: If the model file is missing or the session cannot
                be created. The underlying exception is chained as the cause.
        """
        try:
            if not MODEL_FILE.exists():
                raise FileNotFoundError(f"Model file not found: {MODEL_FILE}")
            self.session = ort.InferenceSession(
                str(MODEL_FILE), providers=ModelConfig.ORT_PROVIDERS
            )
        except Exception as e:
            logger.exception("Failed to load model: %s", e)
            raise RuntimeError(f"Model loading failed: {e}") from e
        logger.info("Model loaded successfully")

    def _load_medical_data(self) -> None:
        """Load the condition -> information mapping from ``DATA_FILE``.

        The file must contain a JSON object. Its keys, in file order, become
        ``self.classes``.

        Raises:
            RuntimeError: If the file cannot be read or parsed, or if it does
                not contain a JSON object. The cause is chained.
        """
        try:
            with open(DATA_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
            # Validate before mutating state, so a bad file never leaves the
            # instance half-updated.
            if not isinstance(data, dict):
                raise ValueError(
                    f"Expected a JSON object in {DATA_FILE}, "
                    f"got {type(data).__name__}"
                )
        except Exception as e:
            logger.exception("Failed to load medical data: %s", e)
            raise RuntimeError(f"Medical data loading failed: {e}") from e
        self.medical_data = data
        self.classes = list(data.keys())
        logger.info("Loaded data for %d conditions", len(self.classes))

    def predict(self, image_input: Any) -> PredictionResult:
        """Classify *image_input* and attach the matching medical information.

        Args:
            image_input: Raw image accepted by ``ImagePreprocessor.preprocess``.
                ``None`` means no image was provided.

        Returns:
            An 8-tuple: (predicted class, confidence, description, symptoms,
            causes, treatment, probability DataFrame, latency such as
            ``"12 ms"``). On any failure the first element is a user-facing
            message and the remaining fields are empty (see
            ``_create_empty_result``). This method does not raise for
            ordinary ``Exception`` subclasses.
        """
        if image_input is None:
            return self._create_empty_result("Please upload an image and click Analyze.")
        try:
            # perf_counter is monotonic; time.time() can jump if the wall
            # clock is adjusted, which would produce bogus latencies.
            start = time.perf_counter()

            input_tensor = self.preprocessor.preprocess(image_input)
            if input_tensor is None:
                return self._create_empty_result("Invalid image")

            prediction_result = self._run_inference(input_tensor)
            if prediction_result is None:
                return self._create_empty_result("Inference error")

            pred_name, confidence, prob_df = prediction_result
            medical_info = self._get_medical_info(pred_name)
            latency_ms = int((time.perf_counter() - start) * 1000)

            return (
                pred_name,
                confidence,
                medical_info["description"],
                medical_info["symptoms"],
                medical_info["causes"],
                medical_info["treatment"],
                prob_df,
                f"{latency_ms} ms",
            )
        except Exception as e:
            logger.exception("Prediction failed: %s", e)
            return self._create_empty_result(f"Error: {e}")

    def _run_inference(self, input_tensor: np.ndarray) -> Optional[Tuple[str, str, Any]]:
        """Run the ONNX session on a preprocessed tensor.

        The tensor is fed to the model's first input. The first output is
        treated as logits, and a softmax turns them into probabilities.

        Returns:
            ``(predicted class name, formatted confidence, probability
            DataFrame)``, or ``None`` when no session is loaded, the output
            does not contain exactly one logit per known class, or inference
            raises.
        """
        if self.session is None:
            logger.error("Inference requested but no model session is loaded")
            return None
        try:
            input_name = self.session.get_inputs()[0].name
            output = self.session.run(None, {input_name: input_tensor})

            # Drop singleton (batch) dimensions. atleast_1d keeps a
            # single-class output indexable instead of collapsing it to a
            # 0-d scalar.
            logits = np.atleast_1d(np.asarray(output[0]).squeeze())
            if logits.ndim != 1 or logits.shape[0] != len(self.classes):
                # A mismatch means the model and DATA_FILE disagree; indexing
                # self.classes would mislabel or crash.
                logger.error(
                    "Model output shape %s does not match %d known classes",
                    logits.shape,
                    len(self.classes),
                )
                return None

            pred_idx = int(np.argmax(logits))
            pred_name = self.classes[pred_idx]

            # Numerically stable softmax: subtracting the max before exp()
            # prevents overflow without changing the result.
            shifted = np.exp(logits - np.max(logits))
            probabilities = shifted / shifted.sum()

            confidence = format_confidence(probabilities[pred_idx])
            prob_ints = probabilities_to_ints(probabilities * 100.0)
            prob_df = self._create_probability_dataframe(prob_ints)
            return pred_name, confidence, prob_df
        except Exception as e:
            logger.exception("Inference failed: %s", e)
            return None

    def _zero_probability_dataframe(self) -> Any:
        """Return a DataFrame listing every class with probability 0.

        Raises:
            ImportError: If pandas is not installed.
        """
        import pandas as pd  # lazy import, as elsewhere in this class

        return pd.DataFrame(
            {"item": self.classes, "probability": [0] * len(self.classes)}
        )

    def _create_probability_dataframe(self, probabilities: np.ndarray) -> Any:
        """Build an ``item``/``probability`` DataFrame sorted ascending by probability.

        Args:
            probabilities: One integer-convertible value per class, in
                ``self.classes`` order. Any array-like is accepted.

        Returns:
            The sorted DataFrame. If it cannot be built (e.g. a length
            mismatch), an unsorted all-zero table is returned instead.

        Raises:
            ImportError: If pandas is not installed.
        """
        import pandas as pd

        try:
            # np.asarray lets list/tuple inputs work; calling .astype on them
            # directly would fail and silently zero the table.
            values = np.asarray(probabilities).astype(int)
            return pd.DataFrame(
                {"item": self.classes, "probability": values}
            ).sort_values("probability", ascending=True)
        except Exception as e:
            logger.exception("Error creating dataframe: %s", e)
            return self._zero_probability_dataframe()

    def _get_medical_info(self, condition_name: str) -> Dict[str, str]:
        """Return the description, symptoms, causes and treatment for a condition.

        Missing conditions or fields map to ``""``. The treatment text is read
        from the data file's ``"treatment-1"`` key.
        """
        condition_data = self.medical_data.get(condition_name, {})
        return {
            "description": condition_data.get("description", ""),
            "symptoms": condition_data.get("symptoms", ""),
            "causes": condition_data.get("causes", ""),
            "treatment": condition_data.get("treatment-1", ""),
        }

    def _create_empty_result(self, message: str) -> PredictionResult:
        """Build a result tuple that carries only *message*.

        The message occupies the prediction slot and the text and latency
        fields are empty. The probability table lists every class at 0, or is
        ``None`` if that table cannot be built (e.g. pandas is unavailable).
        """
        try:
            empty_df = self._zero_probability_dataframe()
        except Exception as e:
            # Best-effort: this runs on error paths and must not itself raise.
            logger.warning("Could not build empty probability table: %s", e)
            empty_df = None
        return (message, "", "", "", "", "", empty_df, "")

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata.

        The result contains the class list (as a copy), the class count, the
        model path and the configured providers. When a session is loaded it
        also includes the first input's shape, type and name, plus the
        providers the session actually uses.
        """
        info: Dict[str, Any] = {
            "classes": list(self.classes),  # copy so callers cannot mutate our state
            "num_classes": len(self.classes),
            "model_file": str(MODEL_FILE),
            "providers": ModelConfig.ORT_PROVIDERS,
        }
        if self.session is not None:
            try:
                input_info = self.session.get_inputs()[0]
                info.update({
                    "input_shape": input_info.shape,
                    "input_type": input_info.type,
                    "input_name": input_info.name,
                    # May differ from the configured list if a provider was unavailable.
                    "active_providers": self.session.get_providers(),
                })
            except Exception as e:
                logger.warning("Could not get model input info: %s", e)
        return info