Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torchvision import transforms | |
| from app.architecture import AdvancedBreastCancerModel | |
logger = logging.getLogger(__name__)

# Preprocessing applied to every input image, mirroring the SensiNet training
# pipeline: resize to 299x299, convert to a tensor, then normalise each channel
# with the ImageNet mean and std.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Checkpoint directory: a "weights" folder that is a sibling of the directory
# containing this module (two levels up from this file, then "weights").
WEIGHTS_DIR = Path(__file__).resolve().parent.parent / "weights"
DEFAULT_WEIGHTS = WEIGHTS_DIR.joinpath("advanced_model_best.pth")

# Malignancy probability at or above which a case is labelled "Malignant"
# (same default as SensiNet).
THRESHOLD = 0.40

# Number of stochastic forward passes averaged by Bayesian MC-Dropout.
MC_PASSES = 10
def _prob_to_birads(prob: float) -> int:
    """Map a malignancy probability to a BI-RADS category (1-5).

    Category boundaries (upper bounds are exclusive):
    [.., 0.10) -> 1 Negative, [0.10, 0.25) -> 2 Benign,
    [0.25, 0.50) -> 3 Probably benign, [0.50, 0.75) -> 4 Suspicious,
    everything else -> 5 Highly suggestive of malignancy.
    """
    upper_bounds = (0.10, 0.25, 0.50, 0.75)
    for category, upper in enumerate(upper_bounds, start=1):
        if prob < upper:
            return category
    # No bound matched (prob >= 0.75, or a value that compares False, e.g. NaN).
    return 5
def _birads_findings(birads: int, prob: float, prediction: str) -> str:
    """Build the human-readable findings sentence for one prediction.

    The sentence states the model's label and probability, followed by a fixed
    recommendation text for the BI-RADS category. Any ``birads`` value other
    than the keys 1-5 gets the generic text "Analysis complete.".
    """
    fallback = "Analysis complete."
    recommendation_by_category = {
        1: "No suspicious findings detected. Mammographic appearance is unremarkable.",
        2: "Benign-appearing pattern identified. Correlate with prior imaging if available.",
        3: "Probably benign appearance. Short-interval follow-up may be considered.",
        4: "Suspicious abnormality pattern detected. Tissue biopsy is recommended.",
        5: "Highly suggestive of malignancy. Urgent diagnostic workup is recommended.",
    }
    if birads in recommendation_by_category:
        base = recommendation_by_category[birads]
    else:
        base = fallback
    return f"Model prediction: {prediction} (probability {prob:.1%}). {base}"
class MammogramModel:
    """Loads the SensiNet dual-stream model and runs inference.

    On construction, the weights file named by ``$MODEL_WEIGHTS`` (default
    ``DEFAULT_WEIGHTS``) is loaded if it exists. Otherwise the instance switches
    to ``mode = "mock"`` and :meth:`predict` returns a deterministic result
    derived from the image pixels instead of a real model output.
    """

    def __init__(self) -> None:
        # "real" unless overridden; forced to "mock" below if weights are missing.
        self.mode = os.getenv("MODEL_MODE", "real")
        self.version = os.getenv("MODEL_VERSION", "sensinet-v1")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: AdvancedBreastCancerModel | None = None

        weights_path = Path(os.getenv("MODEL_WEIGHTS", str(DEFAULT_WEIGHTS)))
        if weights_path.exists():
            self._load_model(weights_path)
        else:
            logger.warning("Weights not found at %s — falling back to mock mode", weights_path)
            self.mode = "mock"

    def _load_model(self, weights_path: Path) -> None:
        """Build the network, load its state dict from *weights_path*, move it to
        ``self.device`` and put it in eval mode."""
        logger.info("Loading SensiNet model from %s onto %s …", weights_path, self.device)
        net = AdvancedBreastCancerModel()
        # SECURITY NOTE(review): weights_only=False makes torch.load use full
        # pickle deserialisation, which can execute arbitrary code embedded in the
        # file. The path is taken from $MODEL_WEIGHTS, so only point it at trusted
        # checkpoints. The result goes straight into load_state_dict, so a plain
        # state dict should also load with weights_only=True. Consider switching
        # once the checkpoint format is confirmed.
        state = torch.load(weights_path, map_location=self.device, weights_only=False)
        net.load_state_dict(state)
        net.to(self.device)
        net.eval()
        self._model = net
        logger.info("Model loaded successfully.")

    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> dict:
        """Classify *image* and return a result dict.

        Uses the real network when one is loaded and mode is not "mock";
        otherwise uses the deterministic mock. Both paths return a dict with
        keys ``birads``, ``confidence``, ``malignancy_probability``,
        ``findings_text`` and ``model_version``.
        """
        if self._model is None or self.mode == "mock":
            return self._mock_predict(image)
        return self._real_predict(image)

    # ------------------------------------------------------------------
    # Real inference with Bayesian MC-Dropout
    # ------------------------------------------------------------------
    def _real_predict(self, image: Image.Image) -> dict:
        """Score *image* with the loaded network using MC-Dropout.

        The image is converted to RGB, preprocessed with ``TRANSFORM`` and run
        through the network ``MC_PASSES`` times with dropout layers active. The
        sigmoid outputs are averaged into the malignancy probability, and their
        variance drives the reported confidence.
        """
        rgb = image.convert("RGB")
        tensor = TRANSFORM(rgb).unsqueeze(0).to(self.device)  # batch of one

        def enable_dropout(m: nn.Module) -> None:
            # Only dropout layers go to train mode, so each pass samples a fresh
            # mask; every other module keeps the eval mode set at load time.
            if isinstance(m, (nn.Dropout, nn.Dropout2d)):
                m.train()

        # NOTE(review): this toggles train/eval on the shared model instance; if
        # predict() can run concurrently on one MammogramModel, calls could see
        # each other's mode — confirm how the caller schedules requests.
        self._model.apply(enable_dropout)
        try:
            mc_predictions: list[float] = []
            with torch.no_grad():
                for _ in range(MC_PASSES):
                    logits = self._model(tensor)
                    # .item() requires a single-element output, presumably one
                    # malignancy logit per image — confirm against the architecture.
                    mc_predictions.append(torch.sigmoid(logits).item())
        finally:
            # Restore full eval mode even if a forward pass raises; otherwise
            # dropout would remain enabled for every subsequent call.
            self._model.eval()

        prob_malig = float(np.mean(mc_predictions))
        variance = float(np.var(mc_predictions))

        # Heuristic confidence: start at 0.99 and shrink as the MC samples spread
        # out, never below 0.50. Means near either extreme get a +0.10 bump
        # (capped at 0.99).
        decision_confidence = max(0.50, 0.99 - (variance * 2.0))
        if prob_malig < 0.10 or prob_malig > 0.90:
            decision_confidence = min(0.99, decision_confidence + 0.10)

        prediction = "Malignant" if prob_malig >= THRESHOLD else "Benign"
        # NOTE(review): the label uses THRESHOLD (currently 0.40), while BI-RADS 4
        # starts at 0.50. Probabilities in between are therefore labelled
        # "Malignant" with BI-RADS 3 ("Probably benign") — confirm this is intended.
        birads = _prob_to_birads(prob_malig)
        return {
            "birads": birads,
            "confidence": round(decision_confidence, 3),
            "malignancy_probability": round(prob_malig, 3),
            "findings_text": _birads_findings(birads, prob_malig, prediction),
            "model_version": self.version,
        }

    # ------------------------------------------------------------------
    # Deterministic mock fallback (no weights needed)
    # ------------------------------------------------------------------
    # Static: predict() calls self._mock_predict(image). Without @staticmethod the
    # instance would bind to `image` and the call would raise TypeError.
    @staticmethod
    def _mock_predict(image: Image.Image) -> dict:
        """Return a deterministic fake prediction derived from *image*'s pixels.

        The image is converted to grayscale and scaled by 1/255. Its mean
        intensity, jittered by up to ±0.04 using an RNG seeded from a SHA-256 of
        the pixel bytes and clamped to [0, 1], serves as the "malignancy
        probability". Identical pixel data always yields identical output.
        """
        import hashlib

        arr = np.array(image.convert("L"), dtype=np.float32) / 255.0
        digest = hashlib.sha256(arr.tobytes()).hexdigest()
        seed = int(digest[:8], 16)
        rng = np.random.default_rng(seed)
        raw = float(min(max(arr.mean() + rng.uniform(-0.04, 0.04), 0.0), 1.0))
        birads = _prob_to_birads(raw)
        return {
            "birads": birads,
            "confidence": round(max(0.55, min(0.98, 0.55 + abs(raw - 0.5))), 3),
            "malignancy_probability": round(raw, 3),
            "findings_text": _birads_findings(birads, raw, "Malignant" if raw >= THRESHOLD else "Benign"),
            "model_version": "mock-v1",
        }