Spaces:
Sleeping
Sleeping
| """Model inference for MelanoScope AI.""" | |
| import json | |
| import logging | |
| import time | |
| from typing import Dict, Any, List, Optional, Tuple | |
| import numpy as np | |
| import onnxruntime as ort | |
| from ..config.settings import ModelConfig, DATA_FILE, MODEL_FILE | |
| from .preprocessing import ImagePreprocessor | |
| from .utils import probabilities_to_ints, format_confidence | |
| logger = logging.getLogger(__name__) | |
class MelanoScopeModel:
    """MelanoScope AI model for skin lesion classification.

    Wraps an ONNX Runtime session plus the per-condition medical data loaded
    from ``DATA_FILE``. ``self.classes`` is taken from the key order of that
    JSON object and is used to map model output indices to condition names.
    NOTE(review): this assumes the JSON key order matches the model's output
    order — confirm against the training label mapping.
    """

    def __init__(self):
        """Build the preprocessor, then load the ONNX model and medical data.

        Raises:
            RuntimeError: if either the model or the medical data cannot be loaded.
        """
        self.preprocessor = ImagePreprocessor()
        self.session: Optional[ort.InferenceSession] = None
        self.classes: List[str] = []
        self.medical_data: Dict[str, Any] = {}
        self._load_model()
        self._load_medical_data()
        logger.info("Model initialized with %d classes", len(self.classes))

    def _load_model(self) -> None:
        """Create the ONNX Runtime session from ``MODEL_FILE``.

        Raises:
            RuntimeError: if the model file is missing or the session cannot
                be created (the original error is chained as the cause).
        """
        try:
            if not MODEL_FILE.exists():
                raise FileNotFoundError(f"Model file not found: {MODEL_FILE}")
            self.session = ort.InferenceSession(
                str(MODEL_FILE), providers=ModelConfig.ORT_PROVIDERS
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Failed to load model: %s", e)
            raise RuntimeError(f"Model loading failed: {e}") from e

    def _load_medical_data(self) -> None:
        """Load condition data from ``DATA_FILE`` and derive ``self.classes``.

        The file is read as UTF-8 JSON. Its top-level keys become the class
        names, in file order.

        Raises:
            RuntimeError: if the file cannot be read or parsed, or if its
                top level has no ``keys()`` (the original error is chained).
        """
        try:
            with open(DATA_FILE, "r", encoding="utf-8") as f:
                self.medical_data = json.load(f)
            self.classes = list(self.medical_data.keys())
            logger.info("Loaded data for %d conditions", len(self.classes))
        except Exception as e:
            logger.error("Failed to load medical data: %s", e)
            raise RuntimeError(f"Medical data loading failed: {e}") from e

    def predict(self, image_input: Any) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Classify one image and gather the matching medical information.

        Args:
            image_input: Raw image handed to ``self.preprocessor.preprocess``.
                ``None`` means nothing was uploaded.

        Returns:
            An 8-tuple ``(prediction, confidence, description, symptoms, causes,
            treatment, probability_table, latency)``. On failure the first item
            is a user-facing message, the next five are empty strings, the
            table is all-zero (or ``None``), and the latency is empty.
        """
        if image_input is None:
            return self._create_empty_result("Please upload an image and click Analyze.")
        try:
            # perf_counter is monotonic, so wall-clock adjustments cannot
            # produce negative or inflated latencies.
            start_time = time.perf_counter()
            input_tensor = self.preprocessor.preprocess(image_input)
            if input_tensor is None:
                return self._create_empty_result("Invalid image")
            prediction_result = self._run_inference(input_tensor)
            if prediction_result is None:
                return self._create_empty_result("Inference error")
            pred_name, confidence, prob_df = prediction_result
            medical_info = self._get_medical_info(pred_name)
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            return (
                pred_name, confidence,
                medical_info["description"], medical_info["symptoms"],
                medical_info["causes"], medical_info["treatment"],
                prob_df, f"{latency_ms} ms",
            )
        except Exception as e:
            logger.exception("Prediction failed: %s", e)
            return self._create_empty_result(f"Error: {str(e)}")

    def _run_inference(self, input_tensor: np.ndarray) -> Optional[Tuple[str, str, Any]]:
        """Run the ONNX session on a preprocessed tensor.

        Feeds the tensor to the session's first input and reads its first
        output.

        Returns:
            ``(predicted_class, formatted_confidence, probability_table)``, or
            ``None`` if anything fails. Failures include the model producing a
            number of scores different from ``len(self.classes)``.
        """
        try:
            if self.session is None:
                raise RuntimeError("Model session is not initialized")
            input_name = self.session.get_inputs()[0].name
            output = self.session.run(None, {input_name: input_tensor})
            # Flatten to one score per class. This assumes a single-image batch,
            # e.g. (1, num_classes). Any other total size is rejected below,
            # rather than producing a table with silently mismatched labels.
            logits = np.asarray(output[0]).reshape(-1)
            if logits.size != len(self.classes):
                raise ValueError(
                    f"Model produced {logits.size} scores but "
                    f"{len(self.classes)} classes are configured"
                )
            pred_idx = int(np.argmax(logits))
            pred_name = self.classes[pred_idx]
            probabilities = self._softmax(logits)
            confidence = format_confidence(probabilities[pred_idx])
            prob_ints = probabilities_to_ints(probabilities * 100.0)
            prob_df = self._create_probability_dataframe(prob_ints)
            return pred_name, confidence, prob_df
        except Exception as e:
            logger.exception("Inference failed: %s", e)
            return None

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Return the softmax of a 1-D score vector.

        The maximum is subtracted first so ``np.exp`` cannot overflow on
        large logits.
        """
        exp_logits = np.exp(logits - np.max(logits))
        return exp_logits / exp_logits.sum()

    def _create_probability_dataframe(self, probabilities: np.ndarray) -> Any:
        """Build a per-class probability table, sorted ascending by probability.

        Args:
            probabilities: Values aligned with ``self.classes``. Any array-like
                is accepted; it is converted with ``np.asarray`` and cast to int.

        Returns:
            A pandas DataFrame with ``item`` and ``probability`` columns. If it
            cannot be built, an all-zero table (or ``None``) is returned instead.
        """
        try:
            import pandas as pd
            return pd.DataFrame({
                "item": self.classes,
                "probability": np.asarray(probabilities).astype(int),
            }).sort_values("probability", ascending=True)
        except Exception as e:
            logger.error("Error creating dataframe: %s", e)
            return self._zero_probability_dataframe()

    def _zero_probability_dataframe(self) -> Any:
        """Return an all-zero probability table, or ``None`` if it cannot be built.

        The table has one row per class in ``self.classes``.
        """
        try:
            import pandas as pd
            return pd.DataFrame({"item": self.classes, "probability": [0] * len(self.classes)})
        except Exception as e:
            logger.warning("Could not build empty probability dataframe: %s", e)
            return None

    def _get_medical_info(self, condition_name: str) -> Dict[str, str]:
        """Look up the descriptive fields for ``condition_name``.

        Unknown conditions and missing fields yield empty strings. The
        ``treatment`` value is read from the data file's ``"treatment-1"`` key.
        """
        condition_data = self.medical_data.get(condition_name, {})
        return {
            "description": condition_data.get("description", ""),
            "symptoms": condition_data.get("symptoms", ""),
            "causes": condition_data.get("causes", ""),
            "treatment": condition_data.get("treatment-1", ""),
        }

    def _create_empty_result(self, message: str) -> Tuple[str, str, str, str, str, str, Any, str]:
        """Build a ``predict``-shaped result that carries only ``message``.

        The probability table is all-zero, or ``None`` if it cannot be built.
        """
        return (message, "", "", "", "", "", self._zero_probability_dataframe(), "")

    def get_model_info(self) -> Dict[str, Any]:
        """Return model metadata.

        Always includes the class list, class count, model path and configured
        providers. When a session exists, the first input's shape, type and
        name are added as well.
        """
        info = {
            # Copy so callers cannot mutate the model's class list.
            "classes": list(self.classes),
            "num_classes": len(self.classes),
            "model_file": str(MODEL_FILE),
            "providers": ModelConfig.ORT_PROVIDERS,
        }
        if self.session is not None:
            try:
                input_info = self.session.get_inputs()[0]
                info.update({
                    "input_shape": input_info.shape,
                    "input_type": input_info.type,
                    "input_name": input_info.name,
                })
            except Exception as e:
                logger.warning("Could not get model input info: %s", e)
        return info