# Med_AI_backend / models.py
# github-actions[bot]
# Auto-sync backend from GitHub
# 2279de6
"""ML model loading and inference for Valeon diagnostic endpoints."""
from __future__ import annotations
import io
import logging
import threading
from typing import Any
import numpy as np
from PIL import Image
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Custom Keras layers
# Registered BEFORE any model is loaded.
# Dual registration:
# package="custom" -> key "custom>MBConvBlock" (Keras standard)
# package="" -> key "MBConvBlock" (what .keras bundle stores)
# ---------------------------------------------------------------------------
import keras
class _SEBlock(keras.layers.Layer):
    """Squeeze-and-Excitation gate implemented with two Dense layers.

    The sub-layers are created in ``build`` under fixed names so they line
    up with the original saved model structure:

        se_block/global_pool   GlobalAveragePooling2D(keepdims=True)
        se_block/squeeze       Dense(se_filters, swish)
        se_block/excite        Dense(expanded_filters, sigmoid)

    Args:
        se_filters: Unit count of the ``squeeze`` Dense layer.
        expanded_filters: Unit count of the ``excite`` Dense layer.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, se_filters: int, expanded_filters: int, **kwargs):
        super().__init__(**kwargs)
        self.se_filters = se_filters
        self.expanded_filters = expanded_filters

    def build(self, input_shape):
        """Create the pooling and the two Dense sub-layers."""
        # Attribute names and creation order are kept as-is: they are part
        # of how the layer's weights are addressed.
        self.global_pool = keras.layers.GlobalAveragePooling2D(
            keepdims=True,
            name="global_pool",
        )
        self.squeeze = keras.layers.Dense(
            self.se_filters,
            activation="swish",
            name="squeeze",
        )
        self.excite = keras.layers.Dense(
            self.expanded_filters,
            activation="sigmoid",
            name="excite",
        )
        super().build(input_shape)

    def call(self, x):
        """Return ``x`` multiplied by the gate computed from ``x``."""
        gate = self.excite(self.squeeze(self.global_pool(x)))
        return x * gate

    def get_config(self):
        """Return the base layer config extended with both filter counts."""
        config = super().get_config()
        config["se_filters"] = self.se_filters
        config["expanded_filters"] = self.expanded_filters
        return config
@keras.saving.register_keras_serializable(package="custom")
class MBConvBlock(keras.layers.Layer):
    """Mobile Inverted Bottleneck Conv block with Squeeze-and-Excitation.

    mixed_bfloat16 dtype fix: the model was trained with
    keras.mixed_precision.Policy('mixed_bfloat16'), which means
    BatchNormalization outputs bfloat16 while the original `inputs`
    tensor is float32. We cast `inputs` to match `x` before the residual
    add so the dtypes always agree regardless of the active policy.

    Sub-layers use explicit names that match the original saved model's
    HDF5 weight structure (depthwise_conv, depthwise_bn, expand_conv,
    expand_bn, project_conv, project_bn, se_block).

    Args:
        filters: Output channel count of the 1x1 projection convolution.
        kernel_size: Kernel size of the depthwise convolution.
        strides: Stride of the depthwise convolution.
        expand_ratio: Multiplier applied to ``input_filters`` to obtain the
            expanded width. The expansion conv/BN pair is only created and
            applied when this is not 1.
        se_ratio: Fraction of the expanded width used as the SE squeeze
            width. The SE block is only created and applied when > 0.
        drop_connect_rate: Rate of the Dropout applied to the block output
            before the residual add (only when the residual path is used).
        input_filters: Channel count the block is told to expect on its
            input; drives the expanded width and the residual condition.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        filters: int,
        kernel_size: int = 3,
        strides: int = 1,
        expand_ratio: int = 1,
        se_ratio: float = 0.25,
        drop_connect_rate: float = 0.0,
        input_filters: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Constructor arguments are stored verbatim so get_config() can
        # return them unchanged.
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.expand_ratio = expand_ratio
        self.se_ratio = se_ratio
        self.drop_connect_rate = drop_connect_rate
        self.input_filters = input_filters
        # Derived widths; both are clamped to at least 1.
        self._expanded_filters = max(1, int(input_filters * expand_ratio))
        self._se_filters = max(1, int(self._expanded_filters * se_ratio))
        # The skip connection is only used when stride and channel count
        # leave the tensor shape unchanged.
        self._use_residual = (strides == 1 and input_filters == filters)

    def build(self, input_shape):
        """Create the sub-layers (expansion and SE are conditional)."""
        if self.expand_ratio != 1:
            self._expand_conv = keras.layers.Conv2D(
                self._expanded_filters, 1, padding="same", use_bias=False,
                name="expand_conv",
            )
            self._expand_bn = keras.layers.BatchNormalization(name="expand_bn")
        self._dw_conv = keras.layers.DepthwiseConv2D(
            self.kernel_size,
            strides=self.strides,
            padding="same",
            use_bias=False,
            name="depthwise_conv",
        )
        self._dw_bn = keras.layers.BatchNormalization(name="depthwise_bn")
        if self.se_ratio > 0:
            self._se_block = _SEBlock(
                self._se_filters, self._expanded_filters, name="se_block"
            )
        self._project_conv = keras.layers.Conv2D(
            self.filters, 1, padding="same", use_bias=False,
            name="project_conv",
        )
        self._project_bn = keras.layers.BatchNormalization(name="project_bn")
        # Dropout is only needed on the residual branch; noise_shape
        # (None, 1, 1, 1) shares one mask value across all non-batch axes.
        if self.drop_connect_rate > 0 and self._use_residual:
            self._drop = keras.layers.Dropout(
                self.drop_connect_rate, noise_shape=(None, 1, 1, 1)
            )
        else:
            self._drop = None
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Run expand -> depthwise -> SE -> project, plus optional residual."""
        # Imported lazily, at call time rather than at module import.
        import tensorflow as tf
        x = inputs
        if self.expand_ratio != 1:
            x = keras.activations.swish(
                self._expand_bn(self._expand_conv(x), training=training)
            )
        x = keras.activations.swish(
            self._dw_bn(self._dw_conv(x), training=training)
        )
        if self.se_ratio > 0:
            x = self._se_block(x)
        # No activation after the projection batch norm.
        x = self._project_bn(self._project_conv(x), training=training)
        if self._use_residual:
            if self._drop is not None:
                x = self._drop(x, training=training)
            # Cast shortcut to match x dtype (handles mixed_bfloat16 training)
            shortcut = tf.cast(inputs, x.dtype)
            x = x + shortcut
        return x

    def get_config(self):
        """Return the base layer config plus every constructor argument."""
        base = super().get_config()
        base.update(
            dict(
                filters=self.filters,
                kernel_size=self.kernel_size,
                strides=self.strides,
                expand_ratio=self.expand_ratio,
                se_ratio=self.se_ratio,
                drop_connect_rate=self.drop_connect_rate,
                input_filters=self.input_filters,
            )
        )
        return base
# Additional registrations (MBConvBlock is already registered under
# package="custom" by its decorator):
#   MBConvBlock -> package=""
#   _SEBlock    -> package="custom", then package=""
# The order is preserved from the original code. Registration stays
# best-effort so a failure never blocks module import, but it is now logged
# instead of being silently discarded: a missing registration is the most
# likely reason for a later "unknown layer" error while loading the model.
for _layer_cls, _package in (
    (MBConvBlock, ""),
    (_SEBlock, "custom"),
    (_SEBlock, ""),
):
    try:
        keras.saving.register_keras_serializable(package=_package)(_layer_cls)
    except Exception as _reg_err:
        logger.warning(
            "Could not register %s under package %r: %s",
            _layer_cls.__name__, _package, _reg_err,
        )
# Do not leave the loop variables behind as module attributes.
del _layer_cls, _package
# ``custom_objects`` mapping passed to the Keras loaders for the skin model:
# each custom layer class keyed by its bare class name.
_SKIN_CUSTOM_OBJECTS: dict[str, Any] = {
    layer_cls.__name__: layer_cls for layer_cls in (MBConvBlock, _SEBlock)
}
# ---------------------------------------------------------------------------
# TFSMLayer shim
# Wraps keras.layers.TFSMLayer so it exposes a .predict() interface.
# ---------------------------------------------------------------------------
class _TFSMShim:
    """Adapter that gives a TFSMLayer-style callable a ``predict`` method."""

    def __init__(self, layer: Any):
        # The wrapped callable; invoked as layer(tensor, training=False).
        self._layer = layer

    def predict(self, x, verbose=0):
        """Run the wrapped layer on ``x`` and return the result's ``.numpy()``.

        ``x`` is converted to a float32 constant tensor first. ``verbose`` is
        accepted for call-site compatibility and is not used.
        """
        import tensorflow as tf

        result = self._layer(tf.constant(x, dtype=tf.float32), training=False)
        if isinstance(result, dict):
            # Dict output: take the first value in iteration order.
            result = tuple(result.values())[0]
        return result.numpy()
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# One entry per diagnostic model. Fields used elsewhere in this module:
#   repo_id      - repository the files are downloaded from
#   filename     - single model file (or arch_file + weights_file)
#   framework    - selects the loading branch and the inference path
#   input_size   - size handed to the generic preprocessing functions
#   classes      - class labels, in model output order
MODEL_REGISTRY: dict[str, dict[str, Any]] = dict(
    cataract=dict(
        repo_id="Arko007/Cataract-Detection-CNN",
        arch_file="model_architecture.json",
        weights_file="model_weights.weights.h5",
        framework="keras_json_weights",
        input_size=(224, 224),
        classes=["Cataract", "Normal"],
    ),
    diabetic_retinopathy=dict(
        repo_id="Arko007/diabetic-retinopathy-v1",
        filename="best_model.h5",
        framework="tf",
        input_size=(384, 384),
        classes=[
            "Grade 0 - No DR",
            "Grade 1 - Mild DR",
            "Grade 2 - Moderate DR",
            "Grade 3 - Severe DR",
            "Grade 4 - Proliferative DR",
        ],
    ),
    kidney=dict(
        repo_id="Arko007/kidney-ct-classifier-efficientnet",
        filename="best_model.pth",
        framework="pytorch_efficientnet",
        input_size=(224, 224),
        classes=["Cyst", "Normal", "Stone", "Tumor"],
    ),
    skin=dict(
        repo_id="Arko007/skin-disease-detector-ai",
        filename="model.keras",
        framework="keras3",
        input_size=(512, 512),
        classes=[
            "Actinic Keratosis",
            "Basal Cell Carcinoma",
            "Dermatofibroma",
            "Nevus",
            "Pigmented Benign Keratosis",
            "Seborrheic Keratosis",
            "Squamous Cell Carcinoma",
            "Vascular Lesion",
        ],
    ),
    cardiac=dict(
        repo_id="Arko007/cardiac-mri-cnn",
        filename="best_model_epoch20_auc0.8129.pt",
        framework="pytorch_cardiac",
        # Input size matches original app: 896×896
        input_size=(896, 896),
        # Index 0 = Normal, Index 1 = Sick (matches original model_service.py)
        classes=["Normal", "Sick"],
    ),
)

# Cache of already-loaded models, keyed by registry name.
_loaded_models: dict[str, Any] = {}

# One independent lock per registry entry (a distinct Lock object for each).
_load_locks: dict[str, threading.Lock] = {
    model_name: threading.Lock() for model_name in MODEL_REGISTRY
}
# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------
def _preprocess_image_tf(image_bytes: bytes, target_size: tuple[int, int]) -> np.ndarray:
    """Decode image bytes into a float32 batch of one, scaled by 1/255.

    The decoded image is converted to RGB and resized to ``target_size``
    with LANCZOS resampling; a leading axis of length 1 is then added.
    """
    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    resized = rgb.resize(target_size, Image.LANCZOS)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    return pixels[np.newaxis, ...]
def _preprocess_dr(image_bytes: bytes) -> np.ndarray:
    """Preprocess an image for the diabetic-retinopathy model.

    RGB conversion, BILINEAR resize to 384x384, float32 scaling by 1/255,
    then a leading axis of length 1 is added.
    """
    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    resized = rgb.resize((384, 384), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32) / 255.0
    return pixels[np.newaxis, ...]
def _preprocess_skin(image_bytes: bytes) -> np.ndarray:
    """Preprocess an image for the skin model.

    RGB conversion, resize to 512x512 (no resampling filter is passed, so
    the library default applies), float32 scaling by 1/255, then a leading
    axis of length 1 is added.
    """
    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    resized = rgb.resize((512, 512))
    pixels = np.array(resized, dtype=np.float32) / 255.0
    return pixels[np.newaxis, ...]
def _preprocess_image_torch(image_bytes: bytes, target_size: tuple[int, int]):
    """Decode image bytes and apply the generic torchvision pipeline.

    Pipeline: Resize -> CenterCrop -> ToTensor -> Normalize (per-channel
    mean/std below). The result gets a leading dimension via ``unsqueeze(0)``.
    """
    import torch  # not referenced below; retained from the original import list
    from torchvision import transforms

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(target_size),
        transforms.CenterCrop(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)

    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return pipeline(rgb).unsqueeze(0)
def _preprocess_cardiac(image_bytes: bytes, target_size: tuple[int, int]):
    """Cardiac-specific preprocessing: grayscale replicated to 3 channels.

    Mirrors the original ModelService transform:
    Resize → Grayscale(num_output_channels=3) → ToTensor → Normalize(0.5, 0.5)
    The result gets a leading dimension via ``unsqueeze(0)``.
    """
    import torch  # not referenced below; retained from the original import list
    from torchvision import transforms

    decoded = Image.open(io.BytesIO(image_bytes))
    # The model was trained on grayscale MRI, so force single-channel "L"
    # mode before the transform pipeline runs.
    gray = decoded if decoded.mode == "L" else decoded.convert("L")

    half = [0.5, 0.5, 0.5]
    pipeline = transforms.Compose(
        [
            transforms.Resize(target_size),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=half, std=[0.5, 0.5, 0.5]),
        ]
    )
    return pipeline(gray).unsqueeze(0)
# ---------------------------------------------------------------------------
# Cardiac model: DenseNet-169 (matches original training architecture)
# ---------------------------------------------------------------------------
def _build_cardiac_model(num_classes: int = 2):
    """Return an untrained DenseNet-169 whose classifier head is replaced.

    Only the final ``classifier`` layer is swapped for a fresh Linear layer
    with ``num_classes`` outputs, which is how the checkpoint's network was
    set up in the original app's ModelService._load_model().
    """
    from torch import nn
    from torchvision.models import densenet169

    net = densenet169(weights=None)
    head_inputs = net.classifier.in_features
    net.classifier = nn.Linear(head_inputs, num_classes)
    return net
# ---------------------------------------------------------------------------
# h5py path-based weight helpers
# ---------------------------------------------------------------------------
# Keras auto-naming: class_name → snake_case base used as HDF5 key.
# Explicit entries for the layer classes this module expects; any other
# class name is converted by _to_snake_case().
_CLASS_TO_H5_BASE: dict[str, str] = {
    "InputLayer": "input_layer",
    "Conv2D": "conv2d",
    "BatchNormalization": "batch_normalization",
    "Activation": "activation",
    "MBConvBlock": "mb_conv_block",
    "GlobalAveragePooling2D": "global_average_pooling2d",
    "Dropout": "dropout",
    "Dense": "dense",
}

# Variable name → positional index inside an HDF5 ``vars/`` group.
_VAR_NAME_TO_INDEX: dict[str, int] = {
    "kernel": 0,
    "bias": 1,
    "gamma": 0,
    "beta": 1,
    "moving_mean": 2,
    "moving_variance": 3,
}


def _to_snake_case(class_name: str) -> str:
    """Convert a CamelCase class name to snake_case.

    Produces the same result as every explicit entry in
    ``_CLASS_TO_H5_BASE`` (e.g. ``MBConvBlock`` -> ``mb_conv_block``) and,
    unlike a plain ``str.lower()``, keeps word boundaries for classes that
    are not listed there (``DepthwiseConv2D`` -> ``depthwise_conv2d``).
    """
    import re

    cleaned = re.sub(r"\W+", "", class_name)
    # Split before a capitalised word ("ConvBlock" -> "Conv_Block") ...
    partial = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", cleaned)
    # ... and between a lowercase letter and a following capital.
    return re.sub(r"([a-z])([A-Z])", r"\1_\2", partial).lower()


def _build_outer_name_map(config: dict) -> dict[str, str]:
    """Map config layer names → HDF5 layer keys.

    The .keras bundle may store HDF5 layer keys using Keras auto-generated
    names (e.g. ``conv2d``, ``mb_conv_block_1``) while config.json uses
    user-specified names (e.g. ``stem_conv``, ``block_1a``). This function
    re-derives the auto-name by counting class occurrences in config order:
    the first layer of a class gets the bare base name, later ones get
    ``_1``, ``_2``, ...

    Args:
        config: Parsed model config; ``config["config"]["layers"]`` must be
            a list of dicts carrying ``class_name`` and ``name``.

    Returns:
        Dict from each layer's config name to its derived HDF5 key.

    Raises:
        KeyError: If the expected keys are missing from ``config``.
    """
    occurrences: dict[str, int] = {}
    outer_map: dict[str, str] = {}
    for layer_cfg in config["config"]["layers"]:
        class_name = layer_cfg["class_name"]
        base = _CLASS_TO_H5_BASE.get(class_name)
        if base is None:
            # Unlisted class: derive the base name instead of lower-casing,
            # which would drop the underscores between words.
            base = _to_snake_case(class_name)
        seen = occurrences.get(base, 0)
        outer_map[layer_cfg["name"]] = f"{base}_{seen}" if seen else base
        occurrences[base] = seen + 1
    return outer_map
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _download(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from the Hub repo ``repo_id``.

    Returns whatever ``hf_hub_download`` returns; callers in this module
    use it as a local file path.
    """
    import huggingface_hub

    return huggingface_hub.hf_hub_download(repo_id=repo_id, filename=filename)
def _assign_skin_weights(
    model: Any,
    h5_data: dict[str, np.ndarray],
    outer_map: dict[str, str],
) -> tuple[int, int]:
    """Copy HDF5 datasets into ``model.weights`` by translated variable path.

    Each variable path ``outer[/inner[/sub_inner]]/var_name`` is translated
    to ``layers/<h5 outer key>[/inner[/sub_inner]]/vars/<index>`` and the
    dataset found there is assigned when its shape matches the variable.

    Returns:
        ``(assigned, skipped)`` counts.
    """
    assigned = 0
    skipped = 0
    for var in model.weights:
        parts = var.path.split("/")
        h5_outer = outer_map.get(parts[0])
        if h5_outer is None:
            logger.debug("No outer mapping for %s", var.path)
            skipped += 1
            continue
        # Supported depths: outer/var, outer/inner/var, outer/inner/sub/var.
        if not 2 <= len(parts) <= 4:
            logger.debug("Unexpected path depth for %s", var.path)
            skipped += 1
            continue
        var_idx = _VAR_NAME_TO_INDEX.get(parts[-1], 0)
        h5_path = "/".join(["layers", h5_outer, *parts[1:-1], "vars", str(var_idx)])
        arr = h5_data.get(h5_path)
        if arr is None:
            logger.debug("H5 path not found for %s → %s", var.path, h5_path)
            skipped += 1
            continue
        if arr.shape != tuple(var.shape):
            logger.debug(
                "Shape mismatch for %s: model=%s h5=%s (h5_path=%s)",
                var.path, var.shape, arr.shape, h5_path,
            )
            skipped += 1
            continue
        target_dtype = (
            var.dtype.as_numpy_dtype
            if hasattr(var.dtype, "as_numpy_dtype")
            else np.float32
        )
        var.assign(arr.astype(target_dtype))
        assigned += 1
    return assigned, skipped


def _run_skin_self_test(model: Any) -> None:
    """Optionally sanity-check the loaded skin model on a known image.

    Runs only when the ``SKIN_TEST_IMAGE`` environment variable names an
    existing file. The outcome is logged; nothing is raised or returned.
    """
    import os

    # Min confidence for any single class on the expected test image.
    # A correctly-loaded 8-class model should exceed random chance
    # (1/8 = 12.5%) by a wide margin; 35% is a conservative floor.
    min_confidence = 0.35
    test_image_path = os.environ.get("SKIN_TEST_IMAGE", "")
    if not (test_image_path and os.path.exists(test_image_path)):
        return
    try:
        with open(test_image_path, "rb") as f:
            test_bytes = f.read()
        test_input = _preprocess_skin(test_bytes)
        test_raw = model.predict(test_input, verbose=0)
        test_probs = test_raw[0].tolist()
        max_conf = max(test_probs)
        max_class = test_probs.index(max_conf)
        logger.info(
            "Skin model self-test: max_confidence=%.4f class_index=%d",
            max_conf, max_class,
        )
        if max_conf < min_confidence:
            logger.error(
                "SKIN MODEL WEIGHT LOADING FAILURE: max confidence "
                "%.4f < %.2f. Weights are not loaded correctly. "
                "All predictions will be unreliable.",
                max_conf, min_confidence,
            )
    except Exception as selftest_err:
        logger.warning("Skin model self-test failed: %s", selftest_err)


def _load_skin_model(repo_id: str, filename: str) -> Any:
    """Four-strategy loader for the skin model.keras file.

    Strategy 1: keras.saving.load_model with safe_mode=False.
    Strategy 2: TFSMLayer on the HF snapshot (SavedModel path).
    Strategy 3: tf.keras.models.load_model with safe_mode=False.
    Strategy 4: Manual unzip + h5py path-based weight assignment.
                Rebuilds the architecture from config.json inside the
                .keras ZIP, reads model.weights.h5 and assigns each
                dataset to the model variable whose translated path
                matches (see ``_assign_skin_weights``).

    Args:
        repo_id: Hub repository holding the model.
        filename: Name of the .keras file inside that repository.

    Returns:
        A loaded model exposing ``predict`` (a Keras model, or a
        ``_TFSMShim`` when strategy 2 succeeds).

    Raises:
        RuntimeError: If all four strategies fail; the message lists the
            error of every strategy.
    """
    import tensorflow as tf

    path = _download(repo_id, filename)

    # One "[S<n>] <error>" entry per failed strategy. Collected here because
    # the name bound by ``except ... as`` is deleted when its clause ends,
    # so the individual exceptions cannot be referenced later on.
    failures: list[str] = []

    # NOTE(security): safe_mode=False (strategies 1 and 3) turns off Keras'
    # guard against unsafe deserialization for a file fetched from the
    # network. Only acceptable while the repository is trusted.

    # -----------------------------------------------------------------------
    # Strategy 1 — keras.saving.load_model with safe_mode=False
    # -----------------------------------------------------------------------
    try:
        model = keras.saving.load_model(
            path,
            custom_objects=_SKIN_CUSTOM_OBJECTS,
            compile=False,
            safe_mode=False,
        )
    except Exception as err:
        logger.warning("keras.saving.load_model failed for skin: %s", err)
        failures.append(f"[S1] {err}")
    else:
        logger.info("Skin model loaded via keras.saving.load_model (safe_mode=False).")
        return model

    # -----------------------------------------------------------------------
    # Strategy 2 — TFSMLayer on the HF snapshot directory
    # -----------------------------------------------------------------------
    try:
        from huggingface_hub import snapshot_download

        snapshot_dir = snapshot_download(repo_id=repo_id)
        layer = keras.layers.TFSMLayer(
            snapshot_dir,
            call_endpoint="serving_default",
        )
    except Exception as err:
        logger.warning("TFSMLayer failed for skin: %s", err)
        failures.append(f"[S2] {err}")
    else:
        logger.info("Skin model loaded via TFSMLayer (SavedModel snapshot).")
        return _TFSMShim(layer)

    # -----------------------------------------------------------------------
    # Strategy 3 — tf.keras legacy loader with safe_mode=False
    # -----------------------------------------------------------------------
    try:
        import inspect

        load_kwargs: dict[str, Any] = {"compile": False}
        # Only pass safe_mode when this loader's signature accepts it.
        if "safe_mode" in inspect.signature(tf.keras.models.load_model).parameters:
            load_kwargs["safe_mode"] = False
        model = tf.keras.models.load_model(
            path,
            custom_objects=_SKIN_CUSTOM_OBJECTS,
            **load_kwargs,
        )
    except Exception as err:
        logger.warning("tf.keras.models.load_model failed for skin: %s", err)
        failures.append(f"[S3] {err}")
    else:
        logger.info("Skin model loaded via tf.keras.models.load_model.")
        return model

    # -----------------------------------------------------------------------
    # Strategy 4 — manual unzip + h5py path-based weight assignment
    #
    # The .keras bundle contains config.json (architecture) and
    # model.weights.h5 (weights keyed by auto-generated layer names).
    # The config uses user-specified names (e.g. "block_1a") while the
    # HDF5 uses Keras auto-names (e.g. "mb_conv_block"). We bridge
    # the two by building an explicit outer-name mapping and then
    # translating each model variable path to its HDF5 dataset path.
    # -----------------------------------------------------------------------
    import json
    import os
    import tempfile
    import zipfile

    import h5py

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(path, "r") as zf:
                zf.extractall(tmpdir)

            config_path = os.path.join(tmpdir, "config.json")
            weights_candidates = [
                os.path.join(tmpdir, "model.weights.h5"),
                os.path.join(tmpdir, "weights.h5"),
            ]
            weights_path = next(
                (p for p in weights_candidates if os.path.exists(p)), None
            )
            if not os.path.exists(config_path):
                raise FileNotFoundError("config.json not found inside .keras bundle")
            with open(config_path, "r", encoding="utf-8") as f:
                config_data = json.load(f)

            # Rebuild architecture from config
            model = keras.models.model_from_json(
                json.dumps(config_data),
                custom_objects=_SKIN_CUSTOM_OBJECTS,
            )
            # Force all layers to create their variables
            model.build((None, 512, 512, 3))
            # Run one dummy forward pass so all sub-layers build. Failure is
            # tolerated (weight assignment below reports what is missing),
            # but it is logged rather than silently dropped.
            dummy = np.zeros((1, 512, 512, 3), dtype=np.float32)
            try:
                model(dummy, training=False)
            except Exception as dummy_err:
                logger.debug("Skin model dummy forward pass failed: %s", dummy_err)

            if not weights_path:
                logger.warning(
                    "Skin model architecture rebuilt but no weights file found "
                    "— predictions will be random."
                )
                return model

            # Build flat dict of all HDF5 layer datasets. This must happen
            # while the temporary directory still exists.
            h5_data: dict[str, np.ndarray] = {}
            with h5py.File(weights_path, "r") as hf:

                def _visit(name, obj):
                    if isinstance(obj, h5py.Dataset) and name.startswith("layers/"):
                        h5_data[name] = np.array(obj)

                hf.visititems(_visit)

            # Build outer-name mapping (config name → HDF5 key)
            outer_map = _build_outer_name_map(config_data)

            # Debug: log first 5 H5 keys and first 5 model var paths
            h5_keys_sample = sorted(h5_data.keys())[:5]
            var_paths_sample = [v.path for v in model.weights[:5]]
            logger.info(
                "H5 keys sample: %s | var paths sample: %s",
                h5_keys_sample, var_paths_sample,
            )

            # Assign weights by translating each model var path → HDF5 path
            assigned, skipped = _assign_skin_weights(model, h5_data, outer_map)
            logger.info(
                "Skin model loaded via h5py path-based assignment: "
                "%d assigned, %d skipped.",
                assigned, skipped,
            )
            if skipped > 0:
                logger.warning(
                    "%d weights could not be assigned. "
                    "Predictions may be partially degraded.",
                    skipped,
                )

            # ---- Self-test at load time ----
            _run_skin_self_test(model)
            return model
    except Exception as err:
        logger.error("All skin model loading strategies failed. Last error: %s", err)
        failures.append(f"[S4] {err}")
        raise RuntimeError(
            "Could not load skin model after 4 strategies. "
            + " | ".join(failures)
        ) from err
def _extract_state_dict(checkpoint: Any) -> Any:
    """Return the ``model_state_dict`` entry of a checkpoint dict, if any.

    Anything else (a bare state dict, or a non-dict object) is returned
    unchanged.
    """
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    return checkpoint


def _load_torch_weights(model: Any, state_dict: Any, model_name: str) -> Any:
    """Load ``state_dict`` non-strictly, report mismatches, switch to eval.

    ``strict=False`` lets loading continue when keys do not line up, which
    would otherwise leave layers at their initial values without any trace.
    Missing/unexpected keys are therefore logged as a warning.
    """
    result = model.load_state_dict(state_dict, strict=False)
    missing = list(getattr(result, "missing_keys", ()))
    unexpected = list(getattr(result, "unexpected_keys", ()))
    if missing or unexpected:
        logger.warning(
            "Model %s: state dict mismatch: %d missing, %d unexpected "
            "keys (missing sample: %s | unexpected sample: %s)",
            model_name, len(missing), len(unexpected),
            missing[:5], unexpected[:5],
        )
    model.eval()
    return model


def _load_model(name: str) -> Any:
    """Download and instantiate the model registered under ``name``.

    The registry entry's ``framework`` field selects the loading branch.

    Args:
        name: Key into ``MODEL_REGISTRY``.

    Returns:
        The loaded model object (Keras model, ``_TFSMShim`` or torch module).

    Raises:
        KeyError: If ``name`` is not a registry key.
        ValueError: If the entry's framework is not handled here.
    """
    cfg = MODEL_REGISTRY[name]
    fw = cfg["framework"]

    if fw == "keras_json_weights":
        import json

        arch_path = _download(cfg["repo_id"], cfg["arch_file"])
        weights_path = _download(cfg["repo_id"], cfg["weights_file"])
        with open(arch_path, "r", encoding="utf-8") as f:
            arch_json = json.load(f)
        # The file may hold either a JSON object or a JSON-encoded string
        # that itself contains the architecture JSON.
        model = keras.models.model_from_json(
            arch_json if isinstance(arch_json, str) else json.dumps(arch_json)
        )
        model.load_weights(weights_path)
        return model

    if fw == "keras3":
        return _load_skin_model(cfg["repo_id"], cfg["filename"])

    if fw == "tf":
        import tensorflow as tf

        path = _download(cfg["repo_id"], cfg["filename"])
        return tf.keras.models.load_model(path, compile=False)

    if fw == "pytorch_efficientnet":
        import torch
        from efficientnet_pytorch import EfficientNet

        path = _download(cfg["repo_id"], cfg["filename"])
        model = EfficientNet.from_name("efficientnet-b0", num_classes=len(cfg["classes"]))
        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # a downloaded file; acceptable only while the repository is trusted.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        return _load_torch_weights(model, _extract_state_dict(checkpoint), name)

    if fw == "pytorch_cardiac":
        import torch

        path = _download(cfg["repo_id"], cfg["filename"])
        # NOTE(security): see the weights_only=False remark above.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        model = _build_cardiac_model(len(cfg["classes"]))
        return _load_torch_weights(model, _extract_state_dict(checkpoint), name)

    raise ValueError(f"Unknown framework: {fw}")
def get_model(name: str) -> Any:
    """Return the cached model for ``name``, loading it on first use.

    The cache is checked once without the lock and once more after taking
    the per-model lock, so a model is loaded at most once even when several
    threads request it at the same time.

    Raises:
        KeyError: If ``name`` is not a key of ``MODEL_REGISTRY``.
    """
    if name not in MODEL_REGISTRY:
        raise KeyError(f"Unknown model: {name}")

    # Fast path: already cached, no locking needed.
    if name in _loaded_models:
        return _loaded_models[name]

    with _load_locks[name]:
        # Another thread may have finished loading while we waited.
        if name not in _loaded_models:
            logger.info("Loading model %s \u2026", name)
            _loaded_models[name] = _load_model(name)
            logger.info("Model %s loaded and cached.", name)
        return _loaded_models[name]
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def predict(name: str, image_bytes: bytes) -> dict[str, float]:
    """Run the model registered as ``name`` on one encoded image.

    Args:
        name: Key into ``MODEL_REGISTRY``.
        image_bytes: Encoded image data as accepted by ``PIL.Image.open``.

    Returns:
        Dict from each class label of the registry entry to its score,
        rounded to 6 decimals. Surplus model outputs are dropped and
        missing ones are filled with 0.0.
    """
    cfg = MODEL_REGISTRY[name]
    model = get_model(name)
    fw = cfg["framework"]
    classes = cfg["classes"]
    size = cfg["input_size"]

    def _label_scores(scores: list) -> dict[str, float]:
        # Fit the score list to the number of classes, then pair with labels.
        wanted = len(classes)
        fitted = scores[:wanted]
        fitted = fitted + [0.0] * (wanted - len(fitted))
        return {label: round(score, 6) for label, score in zip(classes, fitted)}

    if fw in ("tf", "keras3", "keras_json_weights"):
        # Two models have dedicated preprocessing; the rest share the
        # generic one driven by the registry's input size.
        dedicated = {
            "diabetic_retinopathy": _preprocess_dr,
            "skin": _preprocess_skin,
        }.get(name)
        if dedicated is not None:
            batch = dedicated(image_bytes)
        else:
            batch = _preprocess_image_tf(image_bytes, size)

        first_row = model.predict(batch, verbose=0)[0]
        if first_row.shape[-1] == 1:
            # Single output: treat it as the score of the second class.
            positive = float(first_row[0])
            scores = [1.0 - positive, positive]
        else:
            scores = first_row.tolist()

        # Rescale so the scores sum to 1 (skipped for a non-positive sum).
        total = sum(scores)
        if total > 0:
            scores = [value / total for value in scores]
        return _label_scores(scores)

    # PyTorch path
    import torch

    # Use cardiac-specific preprocessing for the cardiac model
    if name == "cardiac":
        batch = _preprocess_cardiac(image_bytes, size)
    else:
        batch = _preprocess_image_torch(image_bytes, size)

    with torch.no_grad():
        logits = model(batch)

    if logits.shape[-1] == 1:
        positive = torch.sigmoid(logits).item()
        scores = [1.0 - positive, positive]
    else:
        scores = torch.softmax(logits, dim=-1).squeeze().tolist()
    if isinstance(scores, float):
        scores = [scores]
    return _label_scores(scores)