"""ML model loading and inference for Valeon diagnostic endpoints.""" from __future__ import annotations import io import logging import threading from typing import Any import numpy as np from PIL import Image logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Custom Keras layers # Registered BEFORE any model is loaded. # Dual registration: # package="custom" -> key "custom>MBConvBlock" (Keras standard) # package="" -> key "MBConvBlock" (what .keras bundle stores) # --------------------------------------------------------------------------- import keras class _SEBlock(keras.layers.Layer): """Squeeze-and-Excitation block using Dense layers. Matches the original saved model structure: se_block/global_pool (GlobalAveragePooling2D, keepdims=True) se_block/squeeze (Dense, swish) se_block/excite (Dense, sigmoid) """ def __init__(self, se_filters: int, expanded_filters: int, **kwargs): super().__init__(**kwargs) self.se_filters = se_filters self.expanded_filters = expanded_filters def build(self, input_shape): self.global_pool = keras.layers.GlobalAveragePooling2D( keepdims=True, name="global_pool" ) self.squeeze = keras.layers.Dense( self.se_filters, activation="swish", name="squeeze" ) self.excite = keras.layers.Dense( self.expanded_filters, activation="sigmoid", name="excite" ) super().build(input_shape) def call(self, x): se = self.global_pool(x) se = self.squeeze(se) se = self.excite(se) return x * se def get_config(self): base = super().get_config() base.update( dict(se_filters=self.se_filters, expanded_filters=self.expanded_filters) ) return base @keras.saving.register_keras_serializable(package="custom") class MBConvBlock(keras.layers.Layer): """Mobile Inverted Bottleneck Conv block with Squeeze-and-Excitation. 
mixed_bfloat16 dtype fix: the model was trained with keras.mixed_precision.Policy('mixed_bfloat16'), which means BatchNormalization outputs bfloat16 while the original `inputs` tensor is float32. We cast `inputs` to match `x` before the residual add so the dtypes always agree regardless of the active policy. Sub-layers use explicit names that match the original saved model's HDF5 weight structure (depthwise_conv, depthwise_bn, expand_conv, expand_bn, project_conv, project_bn, se_block). """ def __init__( self, filters: int, kernel_size: int = 3, strides: int = 1, expand_ratio: int = 1, se_ratio: float = 0.25, drop_connect_rate: float = 0.0, input_filters: int = 0, **kwargs, ): super().__init__(**kwargs) self.filters = filters self.kernel_size = kernel_size self.strides = strides self.expand_ratio = expand_ratio self.se_ratio = se_ratio self.drop_connect_rate = drop_connect_rate self.input_filters = input_filters self._expanded_filters = max(1, int(input_filters * expand_ratio)) self._se_filters = max(1, int(self._expanded_filters * se_ratio)) self._use_residual = (strides == 1 and input_filters == filters) def build(self, input_shape): if self.expand_ratio != 1: self._expand_conv = keras.layers.Conv2D( self._expanded_filters, 1, padding="same", use_bias=False, name="expand_conv", ) self._expand_bn = keras.layers.BatchNormalization(name="expand_bn") self._dw_conv = keras.layers.DepthwiseConv2D( self.kernel_size, strides=self.strides, padding="same", use_bias=False, name="depthwise_conv", ) self._dw_bn = keras.layers.BatchNormalization(name="depthwise_bn") if self.se_ratio > 0: self._se_block = _SEBlock( self._se_filters, self._expanded_filters, name="se_block" ) self._project_conv = keras.layers.Conv2D( self.filters, 1, padding="same", use_bias=False, name="project_conv", ) self._project_bn = keras.layers.BatchNormalization(name="project_bn") if self.drop_connect_rate > 0 and self._use_residual: self._drop = keras.layers.Dropout( self.drop_connect_rate, 
noise_shape=(None, 1, 1, 1) ) else: self._drop = None super().build(input_shape) def call(self, inputs, training=None): import tensorflow as tf x = inputs if self.expand_ratio != 1: x = keras.activations.swish( self._expand_bn(self._expand_conv(x), training=training) ) x = keras.activations.swish( self._dw_bn(self._dw_conv(x), training=training) ) if self.se_ratio > 0: x = self._se_block(x) x = self._project_bn(self._project_conv(x), training=training) if self._use_residual: if self._drop is not None: x = self._drop(x, training=training) # Cast shortcut to match x dtype (handles mixed_bfloat16 training) shortcut = tf.cast(inputs, x.dtype) x = x + shortcut return x def get_config(self): base = super().get_config() base.update( dict( filters=self.filters, kernel_size=self.kernel_size, strides=self.strides, expand_ratio=self.expand_ratio, se_ratio=self.se_ratio, drop_connect_rate=self.drop_connect_rate, input_filters=self.input_filters, ) ) return base # Second registration: bare keys (what the .keras bundle stores) try: keras.saving.register_keras_serializable(package="")(MBConvBlock) except Exception: pass try: keras.saving.register_keras_serializable(package="custom")(_SEBlock) except Exception: pass try: keras.saving.register_keras_serializable(package="")(_SEBlock) except Exception: pass _SKIN_CUSTOM_OBJECTS: dict[str, Any] = { "MBConvBlock": MBConvBlock, "_SEBlock": _SEBlock, } # --------------------------------------------------------------------------- # TFSMLayer shim # Wraps keras.layers.TFSMLayer so it exposes a .predict() interface. 
# ---------------------------------------------------------------------------


class _TFSMShim:
    """Thin wrapper around TFSMLayer that mimics model.predict()."""

    def __init__(self, layer: Any):
        self._layer = layer

    def predict(self, x, verbose=0):
        """Run the wrapped layer on ``x`` and return a NumPy array.

        Args:
            x: Input batch; converted to a float32 tensor.
            verbose: Accepted only for call compatibility with
                ``model.predict``; it is not used.

        Returns:
            The layer output as a NumPy array. When the layer returns a dict,
            the first value (in the dict's iteration order) is used.
        """
        import tensorflow as tf

        tensor = tf.constant(x, dtype=tf.float32)
        out = self._layer(tensor, training=False)
        if isinstance(out, dict):
            out = next(iter(out.values()))
        return out.numpy()


# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------

# Per-model configuration: Hugging Face repo, file name(s), loader framework,
# input size and the ordered class labels the outputs are mapped to.
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "cataract": {
        "repo_id": "Arko007/Cataract-Detection-CNN",
        "arch_file": "model_architecture.json",
        "weights_file": "model_weights.weights.h5",
        "framework": "keras_json_weights",
        "input_size": (224, 224),
        "classes": ["Cataract", "Normal"],
    },
    "diabetic_retinopathy": {
        "repo_id": "Arko007/diabetic-retinopathy-v1",
        "filename": "best_model.h5",
        "framework": "tf",
        "input_size": (384, 384),
        "classes": [
            "Grade 0 - No DR",
            "Grade 1 - Mild DR",
            "Grade 2 - Moderate DR",
            "Grade 3 - Severe DR",
            "Grade 4 - Proliferative DR",
        ],
    },
    "kidney": {
        "repo_id": "Arko007/kidney-ct-classifier-efficientnet",
        "filename": "best_model.pth",
        "framework": "pytorch_efficientnet",
        "input_size": (224, 224),
        "classes": ["Cyst", "Normal", "Stone", "Tumor"],
    },
    "skin": {
        "repo_id": "Arko007/skin-disease-detector-ai",
        "filename": "model.keras",
        "framework": "keras3",
        "input_size": (512, 512),
        "classes": [
            "Actinic Keratosis",
            "Basal Cell Carcinoma",
            "Dermatofibroma",
            "Nevus",
            "Pigmented Benign Keratosis",
            "Seborrheic Keratosis",
            "Squamous Cell Carcinoma",
            "Vascular Lesion",
        ],
    },
    "cardiac": {
        "repo_id": "Arko007/cardiac-mri-cnn",
        "filename": "best_model_epoch20_auc0.8129.pt",
        "framework": "pytorch_cardiac",
        # Input size matches original app: 896×896
        "input_size": (896, 896),
        # Index 0 = Normal, Index 1 = Sick (matches original model_service.py)
        "classes": ["Normal", "Sick"],
    },
}

# Cache of already-loaded models, plus one lock per registry entry.
_loaded_models: dict[str, Any] = {}
_load_locks: dict[str, threading.Lock] = {k: threading.Lock() for k in MODEL_REGISTRY}


# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------


def _image_bytes_to_batch(
    image_bytes: bytes,
    size: tuple[int, int],
    resample: int | None = None,
) -> np.ndarray:
    """Decode an image and return it as a ``(1, H, W, 3)`` float32 batch.

    The image is converted to RGB, resized to ``size`` and scaled to [0, 1].

    Args:
        image_bytes: Encoded image file contents.
        size: Target size passed to ``PIL.Image.Image.resize``.
        resample: PIL resampling filter; ``None`` leaves the choice to PIL's
            own default by not passing the argument at all.
    """
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img = img.resize(size) if resample is None else img.resize(size, resample)
    arr = np.array(img, dtype=np.float32) / 255.0
    return np.expand_dims(arr, axis=0)


def _preprocess_image_tf(image_bytes: bytes, target_size: tuple[int, int]) -> np.ndarray:
    """Generic Keras/TF preprocessing: RGB, LANCZOS resize, [0, 1] scaling."""
    return _image_bytes_to_batch(image_bytes, target_size, Image.LANCZOS)


def _preprocess_dr(image_bytes: bytes) -> np.ndarray:
    """Diabetic-retinopathy preprocessing: 384x384, BILINEAR resize."""
    return _image_bytes_to_batch(image_bytes, (384, 384), Image.BILINEAR)


def _preprocess_skin(image_bytes: bytes) -> np.ndarray:
    """Skin-model preprocessing: 512x512 using PIL's default resize filter."""
    return _image_bytes_to_batch(image_bytes, (512, 512))


def _preprocess_image_torch(image_bytes: bytes, target_size: tuple[int, int]):
    """Generic PyTorch preprocessing with ImageNet mean/std normalisation.

    Returns the transformed image with a leading batch dimension added.
    """
    from torchvision import transforms

    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img).unsqueeze(0)


def _preprocess_cardiac(image_bytes: bytes, target_size: tuple[int, int]):
    """Cardiac-specific preprocessing: Grayscale → 3ch, normalize with 0.5/0.5.

    Matches the original ModelService transform exactly:
        Resize → Grayscale(num_output_channels=3) → ToTensor → Normalize(0.5, 0.5)

    Returns the transformed image with a leading batch dimension added.
    """
    from torchvision import transforms

    img = Image.open(io.BytesIO(image_bytes))
    # Convert to grayscale first (as the model was trained on grayscale MRI)
    if img.mode != "L":
        img = img.convert("L")
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return transform(img).unsqueeze(0)


# ---------------------------------------------------------------------------
# Cardiac model: DenseNet-169 (matches original training architecture)
# ---------------------------------------------------------------------------


def _build_cardiac_model(num_classes: int = 2):
    """Build DenseNet-169 with a replaced classifier head.

    The checkpoint was trained with torchvision DenseNet-169 where only the
    final Linear layer was replaced — identical to the original app's
    ModelService._load_model().

    Args:
        num_classes: Output size of the replacement ``nn.Linear`` head.

    Returns:
        The model with no pretrained weights loaded (``weights=None``).
    """
    import torch.nn as nn
    from torchvision import models

    model = models.densenet169(weights=None)
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    return model


# ---------------------------------------------------------------------------
# h5py path-based weight helpers
# ---------------------------------------------------------------------------

# Keras auto-naming: class_name → snake_case base used as HDF5 key.
_CLASS_TO_H5_BASE: dict[str, str] = {
    "InputLayer": "input_layer",
    "Conv2D": "conv2d",
    "BatchNormalization": "batch_normalization",
    "Activation": "activation",
    "MBConvBlock": "mb_conv_block",
    "GlobalAveragePooling2D": "global_average_pooling2d",
    "Dropout": "dropout",
    "Dense": "dense",
}

# Variable name → positional index inside an HDF5 ``vars/`` group.
_VAR_NAME_TO_INDEX: dict[str, int] = {
    "kernel": 0,
    "bias": 1,
    "gamma": 0,
    "beta": 1,
    "moving_mean": 2,
    "moving_variance": 3,
}

# Input shape (H, W, C) used to build the skin model when it is reconstructed
# manually from the .keras bundle.
_SKIN_INPUT_SHAPE = (512, 512, 3)

# Min confidence for any single class on the expected self-test image.
# A correctly-loaded 8-class model should exceed random chance
# (1/8 = 12.5%) by a wide margin; 35% is a conservative floor.
_SKIN_SELFTEST_MIN_CONFIDENCE = 0.35


def _build_outer_name_map(config: dict) -> dict[str, str]:
    """Map config layer names → HDF5 layer keys.

    The .keras bundle may store HDF5 layer keys using Keras auto-generated
    names (e.g. ``conv2d``, ``mb_conv_block_1``) while config.json uses
    user-specified names (e.g. ``stem_conv``, ``block_1a``). This function
    re-derives the auto-name by counting class occurrences in config order.

    Args:
        config: Parsed config.json; ``config["config"]["layers"]`` must be a
            list of dicts with ``class_name`` and ``name`` entries.

    Returns:
        Dict from each layer's config name to its derived HDF5 key. The first
        layer of a class gets the bare base name, later ones ``base_<n>``.
    """
    seen_per_base: dict[str, int] = {}
    outer_map: dict[str, str] = {}
    for layer_cfg in config["config"]["layers"]:
        class_name = layer_cfg["class_name"]
        # Unknown classes fall back to the lower-cased class name.
        base = _CLASS_TO_H5_BASE.get(class_name, class_name.lower())
        count = seen_per_base.get(base, 0)
        outer_map[layer_cfg["name"]] = base if count == 0 else f"{base}_{count}"
        seen_per_base[base] = count + 1
    return outer_map


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------


def _download(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from the Hugging Face repo and return its local path."""
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename=filename)


def _read_h5_layer_datasets(weights_path: str) -> dict[str, np.ndarray]:
    """Read every dataset under ``layers/`` of an HDF5 file into memory.

    Returns:
        Dict from the dataset's full HDF5 path to a NumPy copy of its data.
    """
    import h5py

    datasets: dict[str, np.ndarray] = {}

    def _collect(name, obj):
        # Must return None: any other value stops h5py's traversal early.
        if isinstance(obj, h5py.Dataset) and name.startswith("layers/"):
            datasets[name] = np.array(obj)

    with h5py.File(weights_path, "r") as hf:
        hf.visititems(_collect)
    return datasets


def _assign_weights_by_path(
    model: Any,
    h5_data: dict[str, np.ndarray],
    outer_map: dict[str, str],
) -> tuple[int, int]:
    """Assign HDF5 tensors to model variables by translating variable paths.

    A variable path ``outer/.../var_name`` is translated to
    ``layers/<h5_outer>/.../vars/<index>``, where ``h5_outer`` comes from
    ``outer_map`` and ``index`` from ``_VAR_NAME_TO_INDEX`` (unknown variable
    names map to index 0). Paths with 2 to 4 components are supported. A
    variable is only assigned when the dataset exists and the shapes match.

    Returns:
        ``(assigned, skipped)`` counts.
    """
    assigned = 0
    skipped = 0
    for var in model.weights:
        parts = var.path.split("/")
        h5_outer = outer_map.get(parts[0])
        if h5_outer is None:
            logger.debug("No outer mapping for %s", var.path)
            skipped += 1
            continue
        if not 2 <= len(parts) <= 4:
            logger.debug("Unexpected path depth for %s", var.path)
            skipped += 1
            continue

        var_idx = _VAR_NAME_TO_INDEX.get(parts[-1], 0)
        # parts[1:-1] holds the (possibly empty) chain of sub-layer names.
        h5_path = "/".join(["layers", h5_outer, *parts[1:-1], "vars", str(var_idx)])

        arr = h5_data.get(h5_path)
        if arr is None:
            logger.debug("H5 path not found for %s → %s", var.path, h5_path)
            skipped += 1
            continue
        if arr.shape != tuple(var.shape):
            logger.debug(
                "Shape mismatch for %s: model=%s h5=%s (h5_path=%s)",
                var.path, var.shape, arr.shape, h5_path,
            )
            skipped += 1
            continue

        target_dtype = (
            var.dtype.as_numpy_dtype
            if hasattr(var.dtype, "as_numpy_dtype")
            else np.float32
        )
        var.assign(arr.astype(target_dtype))
        assigned += 1
    return assigned, skipped


def _run_skin_selftest(model: Any) -> None:
    """Optionally sanity-check the skin model on a known image.

    Runs only when the ``SKIN_TEST_IMAGE`` environment variable names an
    existing file. The outcome is logged; this function never raises.
    """
    import os

    test_image_path = os.environ.get("SKIN_TEST_IMAGE", "")
    if not (test_image_path and os.path.exists(test_image_path)):
        return
    try:
        with open(test_image_path, "rb") as f:
            test_bytes = f.read()
        test_raw = model.predict(_preprocess_skin(test_bytes), verbose=0)
        test_probs = test_raw[0].tolist()
        max_conf = max(test_probs)
        max_class = test_probs.index(max_conf)
        logger.info(
            "Skin model self-test: max_confidence=%.4f class_index=%d",
            max_conf, max_class,
        )
        if max_conf < _SKIN_SELFTEST_MIN_CONFIDENCE:
            logger.error(
                "SKIN MODEL WEIGHT LOADING FAILURE: max confidence "
                "%.4f < %.2f. Weights are not loaded correctly. "
                "All predictions will be unreliable.",
                max_conf, _SKIN_SELFTEST_MIN_CONFIDENCE,
            )
    except Exception as selftest_err:
        logger.warning("Skin model self-test failed: %s", selftest_err)


def _load_skin_from_bundle(path: str) -> Any:
    """Rebuild the skin model from a .keras ZIP and assign weights manually.

    The .keras bundle contains config.json (architecture) and
    model.weights.h5 (weights keyed by auto-generated layer names).
    The config uses user-specified names (e.g. "block_1a") while the
    HDF5 uses Keras auto-names (e.g. "mb_conv_block"). We bridge
    the two by building an explicit outer-name mapping and then
    translating each model variable path to its HDF5 dataset path.

    Returns:
        The rebuilt model. If the bundle has no weights file the model is
        returned with untrained weights and a warning is logged.

    Raises:
        FileNotFoundError: If config.json is missing from the bundle.
    """
    import json
    import os
    import tempfile
    import zipfile

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(path, "r") as zf:
            zf.extractall(tmpdir)

        config_path = os.path.join(tmpdir, "config.json")
        weights_candidates = [
            os.path.join(tmpdir, "model.weights.h5"),
            os.path.join(tmpdir, "weights.h5"),
        ]
        weights_path = next(
            (p for p in weights_candidates if os.path.exists(p)), None
        )

        if not os.path.exists(config_path):
            raise FileNotFoundError("config.json not found inside .keras bundle")

        with open(config_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)

        # Rebuild architecture from config
        model = keras.models.model_from_json(
            json.dumps(config_data),
            custom_objects=_SKIN_CUSTOM_OBJECTS,
        )

        # Force all layers to create their variables
        model.build((None, *_SKIN_INPUT_SHAPE))

        # Run one dummy forward pass so all sub-layers build. Best-effort:
        # a failure here is logged and the weight assignment still proceeds.
        dummy = np.zeros((1, *_SKIN_INPUT_SHAPE), dtype=np.float32)
        try:
            model(dummy, training=False)
        except Exception as exc:
            logger.debug("Skin model dummy forward pass failed: %s", exc)

        if not weights_path:
            logger.warning(
                "Skin model architecture rebuilt but no weights file found "
                "— predictions will be random."
            )
            return model

        # Datasets are copied into memory, so nothing below needs tmpdir.
        h5_data = _read_h5_layer_datasets(weights_path)

    # Build outer-name mapping (config name → HDF5 key)
    outer_map = _build_outer_name_map(config_data)

    # Debug: log first 5 H5 keys and first 5 model var paths
    logger.info(
        "H5 keys sample: %s | var paths sample: %s",
        sorted(h5_data.keys())[:5],
        [v.path for v in model.weights[:5]],
    )

    assigned, skipped = _assign_weights_by_path(model, h5_data, outer_map)
    logger.info(
        "Skin model loaded via h5py path-based assignment: "
        "%d assigned, %d skipped.",
        assigned, skipped,
    )
    if skipped > 0:
        logger.warning(
            "%d weights could not be assigned. "
            "Predictions may be partially degraded.",
            skipped,
        )

    _run_skin_selftest(model)
    return model


def _load_skin_model(repo_id: str, filename: str) -> Any:
    """Four-strategy loader for the skin model.keras file.

    Strategy 1: keras.saving.load_model with safe_mode=False.
    Strategy 2: TFSMLayer on the HF snapshot (SavedModel path).
    Strategy 3: tf.keras.models.load_model with safe_mode=False.
    Strategy 4: Manual unzip + h5py path-based weight assignment.
                Reads the weights HDF5 from inside the .keras ZIP and assigns
                each dataset to the model variable whose path translates to
                it (see ``_load_skin_from_bundle``).

    Each strategy is tried only if the previous ones raised. The failure of
    every attempted strategy is recorded so the final error reports all four.

    Returns:
        A Keras model, or a ``_TFSMShim`` when Strategy 2 succeeds; both
        expose ``predict``.

    Raises:
        RuntimeError: If all four strategies fail.
    """
    import tensorflow as tf

    path = _download(repo_id, filename)

    # Failure messages are collected here rather than read back from the
    # ``except ... as`` names: Python unbinds those names when the except
    # clause ends, so they cannot be referenced later.
    failures: list[str] = []

    # -----------------------------------------------------------------------
    # Strategy 1 — keras.saving.load_model with safe_mode=False
    # -----------------------------------------------------------------------
    try:
        model = keras.saving.load_model(
            path,
            custom_objects=_SKIN_CUSTOM_OBJECTS,
            compile=False,
            safe_mode=False,
        )
        logger.info("Skin model loaded via keras.saving.load_model (safe_mode=False).")
        return model
    except Exception as exc:
        failures.append(f"[S1] {exc}")
        logger.warning("keras.saving.load_model failed for skin: %s", exc)

    # -----------------------------------------------------------------------
    # Strategy 2 — TFSMLayer on the HF snapshot directory
    # -----------------------------------------------------------------------
    try:
        from huggingface_hub import snapshot_download

        snapshot_dir = snapshot_download(repo_id=repo_id)
        layer = keras.layers.TFSMLayer(
            snapshot_dir,
            call_endpoint="serving_default",
        )
        logger.info("Skin model loaded via TFSMLayer (SavedModel snapshot).")
        return _TFSMShim(layer)
    except Exception as exc:
        failures.append(f"[S2] {exc}")
        logger.warning("TFSMLayer failed for skin: %s", exc)

    # -----------------------------------------------------------------------
    # Strategy 3 — tf.keras legacy loader with safe_mode=False
    # -----------------------------------------------------------------------
    try:
        import inspect

        load_kwargs: dict[str, Any] = {"compile": False}
        # Only pass safe_mode when this loader's signature accepts it.
        if "safe_mode" in inspect.signature(tf.keras.models.load_model).parameters:
            load_kwargs["safe_mode"] = False
        model = tf.keras.models.load_model(
            path,
            custom_objects=_SKIN_CUSTOM_OBJECTS,
            **load_kwargs,
        )
        logger.info("Skin model loaded via tf.keras.models.load_model.")
        return model
    except Exception as exc:
        failures.append(f"[S3] {exc}")
        logger.warning("tf.keras.models.load_model failed for skin: %s", exc)

    # -----------------------------------------------------------------------
    # Strategy 4 — manual unzip + h5py path-based weight assignment
    # -----------------------------------------------------------------------
    # Imported outside the try so a missing h5py surfaces as ImportError
    # instead of being reported as a model-loading failure.
    import h5py  # noqa: F401

    try:
        return _load_skin_from_bundle(path)
    except Exception as exc:
        failures.append(f"[S4] {exc}")
        logger.error("All skin model loading strategies failed. Last error: %s", exc)
        raise RuntimeError(
            "Could not load skin model after 4 strategies. "
            + " | ".join(failures)
        ) from exc


def _load_torch_state_dict(path: str) -> Any:
    """Load a PyTorch checkpoint on CPU and return its state dict.

    A dict checkpoint with a ``"model_state_dict"`` entry is unwrapped to
    that entry; anything else is returned unchanged.
    """
    import torch

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only acceptable while the checkpoint repos in
    # MODEL_REGISTRY are trusted.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(checkpoint, dict):
        return checkpoint.get("model_state_dict", checkpoint)
    return checkpoint


def _apply_state_dict(model: Any, state_dict: Any, name: str) -> Any:
    """Load ``state_dict`` non-strictly, report mismatches, switch to eval.

    ``strict=False`` tolerates missing/unexpected keys; they are logged so a
    partially-initialised model does not go unnoticed.
    """
    result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing or unexpected:
        logger.warning(
            "Model %s: %d missing and %d unexpected state-dict keys "
            "(missing sample: %s | unexpected sample: %s)",
            name, len(missing), len(unexpected), missing[:5], unexpected[:5],
        )
    model.eval()
    return model


def _load_model(name: str) -> Any:
    """Download and instantiate the model registered under ``name``.

    The loader is chosen by the registry entry's ``framework`` value.

    Raises:
        KeyError: If ``name`` is not in ``MODEL_REGISTRY``.
        ValueError: If the entry's framework is not recognised.
    """
    cfg = MODEL_REGISTRY[name]
    fw = cfg["framework"]

    if fw == "keras_json_weights":
        import json

        arch_path = _download(cfg["repo_id"], cfg["arch_file"])
        weights_path = _download(cfg["repo_id"], cfg["weights_file"])
        with open(arch_path, "r", encoding="utf-8") as f:
            arch_json = json.load(f)
        # The file may hold either the architecture object itself or a JSON
        # string wrapping it; model_from_json always needs a string.
        model = keras.models.model_from_json(
            arch_json if isinstance(arch_json, str) else json.dumps(arch_json)
        )
        model.load_weights(weights_path)
        return model

    if fw == "keras3":
        return _load_skin_model(cfg["repo_id"], cfg["filename"])

    if fw == "tf":
        import tensorflow as tf

        path = _download(cfg["repo_id"], cfg["filename"])
        return tf.keras.models.load_model(path, compile=False)

    if fw == "pytorch_efficientnet":
        from efficientnet_pytorch import EfficientNet

        path = _download(cfg["repo_id"], cfg["filename"])
        model = EfficientNet.from_name(
            "efficientnet-b0", num_classes=len(cfg["classes"])
        )
        return _apply_state_dict(model, _load_torch_state_dict(path), name)

    if fw == "pytorch_cardiac":
        path = _download(cfg["repo_id"], cfg["filename"])
        state_dict = _load_torch_state_dict(path)
        model = _build_cardiac_model(len(cfg["classes"]))
        return _apply_state_dict(model, state_dict, name)

    raise ValueError(f"Unknown framework: {fw}")


def get_model(name: str) -> Any:
    """Return the cached model for ``name``, loading it on first use.

    The cache is checked before and again after taking the per-model lock,
    so concurrent first calls for the same model load it only once.

    Raises:
        KeyError: If ``name`` is not in ``MODEL_REGISTRY``.
    """
    if name not in MODEL_REGISTRY:
        raise KeyError(f"Unknown model: {name}")
    if name not in _loaded_models:
        with _load_locks[name]:
            if name not in _loaded_models:
                logger.info("Loading model %s \u2026", name)
                _loaded_models[name] = _load_model(name)
                logger.info("Model %s loaded and cached.", name)
    return _loaded_models[name]
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------


def _scores_by_class(classes: list[str], probs: list[float]) -> dict[str, float]:
    """Pair class labels with probabilities, rounded to 6 decimals.

    ``probs`` is truncated or zero-padded to ``len(classes)`` first, so every
    class always appears exactly once in the result.
    """
    fitted = list(probs[: len(classes)])
    fitted.extend([0.0] * (len(classes) - len(fitted)))
    return {label: round(p, 6) for label, p in zip(classes, fitted)}


def predict(name: str, image_bytes: bytes) -> dict[str, float]:
    """Run the named diagnostic model on an encoded image.

    Args:
        name: Key of the model in ``MODEL_REGISTRY``.
        image_bytes: Encoded image file contents.

    Returns:
        Dict from each of the model's class labels to its score.

        * Keras/TF models: a single-unit output ``p`` is expanded to
          ``[1 - p, p]``; the scores are then divided by their sum when that
          sum is positive.
        * PyTorch models: a single logit goes through a sigmoid and is
          expanded to ``[1 - p, p]``; otherwise a softmax is applied.

    Raises:
        KeyError: If ``name`` is not a registered model.
    """
    # get_model validates the name, so an unknown model fails with the same
    # descriptive KeyError everywhere.
    model = get_model(name)
    cfg = MODEL_REGISTRY[name]
    fw = cfg["framework"]
    classes = cfg["classes"]
    size = cfg["input_size"]

    if fw in ("tf", "keras3", "keras_json_weights"):
        if name == "diabetic_retinopathy":
            inp = _preprocess_dr(image_bytes)
        elif name == "skin":
            inp = _preprocess_skin(image_bytes)
        else:
            inp = _preprocess_image_tf(image_bytes, size)

        probs = model.predict(inp, verbose=0)[0]
        if probs.shape[-1] == 1:
            p = float(probs[0])
            scores = [1.0 - p, p]
        else:
            scores = probs.tolist()

        total = sum(scores)
        if total > 0:
            scores = [x / total for x in scores]
        return _scores_by_class(classes, scores)

    # PyTorch path
    import torch

    # Use cardiac-specific preprocessing for the cardiac model
    if name == "cardiac":
        inp = _preprocess_cardiac(image_bytes, size)
    else:
        inp = _preprocess_image_torch(image_bytes, size)

    with torch.no_grad():
        logits = model(inp)

    if logits.shape[-1] == 1:
        p = torch.sigmoid(logits).item()
        scores = [1.0 - p, p]
    else:
        scores = torch.softmax(logits, dim=-1).squeeze().tolist()
        # squeeze() on a single-element tensor makes tolist() return a float.
        if isinstance(scores, float):
            scores = [scores]
    return _scores_by_class(classes, scores)