"""SensiNet mammogram inference: model loading, MC-Dropout prediction and a mock fallback."""

import hashlib
import logging
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from app.architecture import AdvancedBreastCancerModel

logger = logging.getLogger(__name__)

# ImageNet normalisation (same as SensiNet training pipeline)
TRANSFORM = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

WEIGHTS_DIR = Path(__file__).resolve().parent.parent / "weights"
DEFAULT_WEIGHTS = WEIGHTS_DIR / "advanced_model_best.pth"

# Malignancy probability threshold (same as SensiNet default)
THRESHOLD = 0.40

# Number of Bayesian MC-Dropout forward passes
MC_PASSES = 10

# Every dropout variant that must stay stochastic during MC-Dropout inference.
# Listing only Dropout/Dropout2d would silently leave other variants deterministic
# and under-report the predictive variance.
_DROPOUT_TYPES = (
    nn.Dropout,
    nn.Dropout1d,
    nn.Dropout2d,
    nn.Dropout3d,
    nn.AlphaDropout,
    nn.FeatureAlphaDropout,
)


def _prob_to_birads(prob: float) -> int:
    """Map a malignancy probability to a BI-RADS category (1-5).

    Cut-offs: below 0.10 -> 1, below 0.25 -> 2, below 0.50 -> 3,
    below 0.75 -> 4, anything else -> 5.
    """
    # NOTE(review): a NaN probability fails every comparison below and is
    # therefore reported as category 5 — consider rejecting non-finite input.
    if prob < 0.10:
        return 1  # Negative
    if prob < 0.25:
        return 2  # Benign
    if prob < 0.50:
        return 3  # Probably benign
    if prob < 0.75:
        return 4  # Suspicious
    return 5  # Highly suggestive of malignancy


def _birads_findings(birads: int, prob: float, prediction: str) -> str:
    """Build the human-readable findings sentence for a prediction.

    Args:
        birads: Category from ``_prob_to_birads``; values outside 1-5 fall back
            to a generic "Analysis complete." sentence.
        prob: Malignancy probability, rendered as a percentage with one decimal.
        prediction: Binary label ("Malignant"/"Benign") chosen via ``THRESHOLD``.

    Returns:
        ``"Model prediction: <label> (probability <p>%). <category text>"``.
    """
    templates = {
        1: "No suspicious findings detected. Mammographic appearance is unremarkable.",
        2: "Benign-appearing pattern identified. Correlate with prior imaging if available.",
        3: "Probably benign appearance. Short-interval follow-up may be considered.",
        4: "Suspicious abnormality pattern detected. Tissue biopsy is recommended.",
        5: "Highly suggestive of malignancy. Urgent diagnostic workup is recommended.",
    }
    base = templates.get(birads, "Analysis complete.")
    # NOTE(review): with THRESHOLD = 0.40, probabilities in [0.40, 0.50) are
    # labelled "Malignant" while their BI-RADS text says "Probably benign".
    return f"Model prediction: {prediction} (probability {prob:.1%}). {base}"


class MammogramModel:
    """Loads the SensiNet dual-stream model and runs inference.

    Configuration is read from environment variables at construction time:

    * ``MODEL_MODE`` — ``"mock"`` forces the deterministic mock predictor; any
      other value (default ``"real"``) uses the network when weights exist.
    * ``MODEL_VERSION`` — version string reported with real predictions
      (default ``"sensinet-v1"``).
    * ``MODEL_WEIGHTS`` — state-dict path (default ``DEFAULT_WEIGHTS``).

    If the weights file does not exist, the instance switches to mock mode.
    """

    def __init__(self) -> None:
        self.mode = os.getenv("MODEL_MODE", "real")
        self.version = os.getenv("MODEL_VERSION", "sensinet-v1")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: AdvancedBreastCancerModel | None = None

        if self.mode == "mock":
            # predict() never touches the network in mock mode, so don't spend
            # time and (GPU) memory loading it.
            logger.info("MODEL_MODE=mock — skipping weight loading")
            return

        weights_path = Path(os.getenv("MODEL_WEIGHTS", str(DEFAULT_WEIGHTS)))
        if weights_path.exists():
            self._load_model(weights_path)
        else:
            logger.warning("Weights not found at %s — falling back to mock mode", weights_path)
            self.mode = "mock"

    def _load_model(self, weights_path: Path) -> None:
        """Build the network, load its state dict from *weights_path*, move it to
        ``self.device`` and put it in eval mode."""
        logger.info("Loading SensiNet model from %s onto %s …", weights_path, self.device)
        net = AdvancedBreastCancerModel()
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary objects, i.e. run code embedded in the file. Only point
        # MODEL_WEIGHTS at trusted files; if the checkpoint is a plain tensor
        # state dict, weights_only=True would be the safer choice.
        state = torch.load(weights_path, map_location=self.device, weights_only=False)
        net.load_state_dict(state)
        net.to(self.device)
        net.eval()
        self._model = net
        logger.info("Model loaded successfully.")

    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> dict:
        """Return a prediction dict for *image*.

        The dict has keys ``birads``, ``confidence``, ``malignancy_probability``,
        ``findings_text`` and ``model_version``. The mock predictor is used
        when no network is loaded or the mode is ``"mock"``.
        """
        if self._model is None or self.mode == "mock":
            return self._mock_predict(image)
        return self._real_predict(image)

    # ------------------------------------------------------------------
    # Real inference with Bayesian MC-Dropout
    # ------------------------------------------------------------------
    def _real_predict(self, image: Image.Image) -> dict:
        """Average ``MC_PASSES`` stochastic forward passes (MC-Dropout).

        The mean sigmoid output is the malignancy probability. The variance
        across passes lowers the reported confidence.
        """
        # NOTE(review): convert("RGB") on high-bit-depth inputs (e.g. mode
        # "I;16") may clip instead of rescale intensities — confirm inputs are
        # 8-bit or pre-scaled like the training data.
        rgb = image.convert("RGB")
        tensor = TRANSFORM(rgb).unsqueeze(0).to(self.device)

        model = self._model
        mc_predictions: list[float] = []
        try:
            # Switch only the dropout layers to train mode so each pass samples
            # a fresh mask; every other layer keeps its eval-mode behaviour.
            for module in model.modules():
                if isinstance(module, _DROPOUT_TYPES):
                    module.train()
            with torch.no_grad():
                for _ in range(MC_PASSES):
                    logits = model(tensor)
                    # .item() requires the network to emit exactly one logit
                    # for this single-image batch.
                    mc_predictions.append(torch.sigmoid(logits).item())
        finally:
            # Restore deterministic eval mode even if a forward pass raised, so
            # the shared model is never left with dropout enabled.
            model.eval()

        prob_malig = float(np.mean(mc_predictions))
        variance = float(np.var(mc_predictions))

        # More disagreement between passes -> lower confidence, floored at 0.50.
        decision_confidence = max(0.50, 0.99 - (variance * 2.0))
        # Probabilities near either extreme get a +0.10 boost, capped at 0.99.
        if prob_malig < 0.10 or prob_malig > 0.90:
            decision_confidence = min(0.99, decision_confidence + 0.10)

        prediction = "Malignant" if prob_malig >= THRESHOLD else "Benign"
        birads = _prob_to_birads(prob_malig)

        return {
            "birads": birads,
            "confidence": round(decision_confidence, 3),
            "malignancy_probability": round(prob_malig, 3),
            "findings_text": _birads_findings(birads, prob_malig, prediction),
            "model_version": self.version,
        }

    # ------------------------------------------------------------------
    # Deterministic mock fallback (no weights needed)
    # ------------------------------------------------------------------
    @staticmethod
    def _mock_predict(image: Image.Image) -> dict:
        """Weight-free stand-in prediction, deterministic per image.

        The pseudo-probability is the image's mean grey level (scaled to [0, 1]).
        A jitter in [-0.04, 0.04) is added, seeded from a SHA-256 of the
        greyscale pixels, and the result is clamped to [0, 1]. Identical
        pixels therefore always give the same result.
        """
        arr = np.array(image.convert("L"), dtype=np.float32) / 255.0
        digest = hashlib.sha256(arr.tobytes()).hexdigest()
        seed = int(digest[:8], 16)
        rng = np.random.default_rng(seed)
        raw = float(min(max(arr.mean() + rng.uniform(-0.04, 0.04), 0.0), 1.0))
        birads = _prob_to_birads(raw)
        return {
            "birads": birads,
            "confidence": round(max(0.55, min(0.98, 0.55 + abs(raw - 0.5))), 3),
            "malignancy_probability": round(raw, 3),
            "findings_text": _birads_findings(birads, raw, "Malignant" if raw >= THRESHOLD else "Benign"),
            "model_version": "mock-v1",
        }