# inference/predict.py """ Image preprocessing and MC-Dropout inference pipeline. """ import logging from typing import Any import numpy as np import torch from PIL import Image log = logging.getLogger(__name__) CLASS_NAMES = ["normal", "mild", "moderate", "severe"] _IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) _IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) def preprocess_image(img: Image.Image, image_size: int = 380) -> torch.Tensor: """Convert PIL Image to normalised (1, 3, H, W) float tensor.""" img = img.convert("RGB").resize((image_size, image_size), Image.BICUBIC) arr = np.array(img, dtype=np.float32) / 255.0 tensor = torch.from_numpy(arr).permute(2, 0, 1) # (3, H, W) tensor = (tensor - _IMAGENET_MEAN) / _IMAGENET_STD return tensor.unsqueeze(0) # (1, 3, H, W) def mc_dropout_predict( model: torch.nn.Module, image_tensor: torch.Tensor, n_samples: int = 30, ) -> dict[str, Any]: """ Run MC Dropout inference. Activates dropout at inference time for n_samples forward passes. Returns Hb estimate, 95% CI, class probabilities, and classification label. """ model.train() # activate dropout hb_samples = [] cls_samples = [] try: with torch.no_grad(): for _ in range(n_samples): hb_pred, cls_logits = model(image_tensor) hb_samples.append(hb_pred.item()) cls_samples.append(torch.softmax(cls_logits, dim=1).squeeze().numpy()) finally: model.eval() # always restore eval mode, even on exception hb_arr = np.array(hb_samples) cls_arr = np.array(cls_samples).mean(axis=0) # (4,) hb_mean = float(np.mean(hb_arr)) hb_lo = float(np.percentile(hb_arr, 2.5)) hb_hi = float(np.percentile(hb_arr, 97.5)) pred_class_idx = int(np.argmax(cls_arr)) return { "hb_estimate": round(hb_mean, 2), "hb_ci_95": [round(hb_lo, 2), round(hb_hi, 2)], "classification": CLASS_NAMES[pred_class_idx], "class_probabilities": { name: round(float(cls_arr[i]), 4) for i, name in enumerate(CLASS_NAMES) }, "_hb_samples": hb_arr.tolist(), # kept for ensemble CI computation; stripped before API response } def run_full_prediction( conj_img: Image.Image | None, nail_img: Image.Image | None, conj_model: torch.nn.Module | None, nail_model: torch.nn.Module | None, w_conj: float = 0.5, w_nail: float = 0.5, image_size: int = 380, n_mc_samples: int = 30, ) -> dict[str, Any]: """ Run prediction on available images, ensemble if both present. Fills 'per_model' field with individual model results. """ results = {} if conj_img is not None and conj_model is not None: t = preprocess_image(conj_img, image_size) results["conjunctiva"] = mc_dropout_predict(conj_model, t, n_mc_samples) if nail_img is not None and nail_model is not None: t = preprocess_image(nail_img, image_size) results["nailbed"] = mc_dropout_predict(nail_model, t, n_mc_samples) if not results: raise ValueError("No model results — ensure at least one image and model are provided.") # Ensemble if "conjunctiva" in results and "nailbed" in results: cls_probs = { k: w_conj * results["conjunctiva"]["class_probabilities"][k] + w_nail * results["nailbed"]["class_probabilities"][k] for k in CLASS_NAMES } best_cls = max(cls_probs, key=cls_probs.get) # Compute CI from combined weighted MC samples (statistically valid) samples_c = np.array(results["conjunctiva"]["_hb_samples"]) samples_n = np.array(results["nailbed"]["_hb_samples"]) ensemble_samples = w_conj * samples_c + w_nail * samples_n hb_mean = float(np.mean(ensemble_samples)) ci_lo = float(np.percentile(ensemble_samples, 2.5)) ci_hi = float(np.percentile(ensemble_samples, 97.5)) ensemble = { "hb_estimate": round(hb_mean, 2), "hb_ci_95": [round(ci_lo, 2), round(ci_hi, 2)], "classification": best_cls, "class_probabilities": {k: round(v, 4) for k, v in cls_probs.items()}, } elif "conjunctiva" in results: ensemble = results["conjunctiva"] else: ensemble = results["nailbed"] # Strip internal MC samples before returning — not part of the public API contract for r in results.values(): r.pop("_hb_samples", None) return { **ensemble, "per_model": results, "model_version": "v1.0.0", "disclaimer": "Research tool only. Not a certified diagnostic device. Clinical confirmation required.", }