tampee committed on
Commit
596aaa1
·
1 Parent(s): 1d745c0

feat: integrate real SensiNet mammography model

Browse files

- Add SensiNet dual-stream architecture (Xception + EfficientNet-B3 with CBAM)
- Replace mock inference with real Bayesian MC-Dropout prediction (10 passes)
- Add /analyze endpoint with SSRF protection (allowlist + private IP blocking)
- Add malignancy_probability to response schema
- Add training and data preparation utility scripts
- Model weights (131MB .pth) gitignored — must be downloaded separately

Files changed (9) hide show
  1. .env.example +6 -2
  2. .gitignore +1 -0
  3. app/architecture.py +100 -0
  4. app/main.py +53 -1
  5. app/model.py +134 -31
  6. app/schemas.py +5 -0
  7. prepare_data.py +75 -0
  8. requirements.txt +5 -0
  9. train.py +188 -0
.env.example CHANGED
@@ -1,2 +1,6 @@
1
- MODEL_MODE=mock
2
- MODEL_VERSION=baseline-mock-v1
 
 
 
 
 
1
+ MODEL_MODE=real
2
+ MODEL_VERSION=sensinet-v1
3
+ # Path to .pth weights (defaults to weights/advanced_model_best.pth)
4
+ # MODEL_WEIGHTS=weights/advanced_model_best.pth
5
+ # Comma-separated list of allowed hostnames for the /analyze endpoint (SSRF protection)
6
+ # ALLOWED_IMAGE_HOSTS=your-project.supabase.co
.gitignore CHANGED
@@ -3,3 +3,4 @@ __pycache__/
3
  *.pyc
4
  .env
5
  .DS_Store
 
 
3
  *.pyc
4
  .env
5
  .DS_Store
6
+ weights/*.pth
app/architecture.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SensiNet Dual-Stream Architecture for Mammographic Classification.
3
+
4
+ Architecture:
5
+ - Stream 1: Xception (feature-rich legacy backbone)
6
+ - Stream 2: EfficientNet-B3 (modern efficient backbone)
7
+ - Fusion: Projected feature maps concatenated and refined via CBAM attention
8
+ - Output: Single logit (sigmoid → malignancy probability)
9
+
10
+ Source: Aredeksu/SensiNet-Mammography on Hugging Face (Apache-2.0 license)
11
+ Trained on: CBIS-DDSM dataset
12
+ """
13
+
14
+ import timm
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
class ChannelAttention(nn.Module):
    """Channel-attention gate: one sigmoid weight per input channel.

    The input is squeezed to ``[B, C, 1, 1]`` twice (average pool and max
    pool); both descriptors go through the same two-layer 1x1-conv
    bottleneck, are summed, and passed through a sigmoid.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck (``in_planes // ratio``
            hidden channels, but never fewer than one).

    Raises:
        ValueError: If ``in_planes`` or ``ratio`` is not a positive integer.
    """

    def __init__(self, in_planes: int, ratio: int = 16):
        super().__init__()
        if in_planes < 1 or ratio < 1:
            raise ValueError("in_planes and ratio must be positive integers")
        # Integer division yields 0 hidden channels when in_planes < ratio,
        # which would make the gate degenerate; keep at least one channel.
        # For in_planes >= ratio the layer shapes are unchanged, so existing
        # checkpoints still load.
        hidden = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled: torch.Tensor) -> torch.Tensor:
        """Apply the shared fc1 -> ReLU -> fc2 stack to a pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-channel weights in (0, 1) with shape ``[B, C, 1, 1]``."""
        avg_out = self._bottleneck(self.avg_pool(x))
        max_out = self._bottleneck(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
34
+
35
+
36
class SpatialAttention(nn.Module):
    """Spatial-attention gate: one sigmoid weight per spatial location.

    The channel-wise mean and max of the input are stacked into a 2-channel
    map and convolved down to a single-channel mask of the same height and
    width as the input.

    Args:
        kernel_size: Side of the square convolution kernel. Must be odd so
            that symmetric padding keeps the spatial size unchanged.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError("kernel_size must be a positive odd integer")
        # "Same" padding for any odd kernel; equals the previous hard-coded
        # values for the two kernels that were supported (7 -> 3, 3 -> 1).
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a mask in (0, 1) with shape ``[B, 1, H, W]``."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(stacked))
48
+
49
+
50
class CBAM(nn.Module):
    """Sequential channel-then-spatial attention over a feature map.

    Args:
        planes: Number of channels of the feature map being refined.
    """

    def __init__(self, planes: int):
        super().__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the channel gate, then by the spatial gate."""
        channel_gated = x * self.ca(x)
        # The spatial mask is computed from the already channel-gated map.
        return channel_gated * self.sa(channel_gated)
60
+
61
+
62
class AdvancedBreastCancerModel(nn.Module):
    """Dual-backbone classifier that emits a single logit per image.

    Two timm backbones (created without pretrained weights and without a
    classification head) process the same input. Their feature maps are
    projected to 512 channels each, concatenated to 1024 channels, refined
    by CBAM and reduced to one logit by the classifier head.

    Attribute names and the layout of ``classifier`` define the state_dict
    keys, so they must stay as they are for saved checkpoints to load.
    """

    def __init__(self) -> None:
        super().__init__()

        # Stream 1: Xception features (proj1 below is sized for 2048 channels).
        self.stream1 = timm.create_model("xception", pretrained=False, num_classes=0)
        # Stream 2: EfficientNet-B3 features (proj2 is sized for 1536 channels).
        self.stream2 = timm.create_model("efficientnet_b3", pretrained=False, num_classes=0)

        # 1x1 convolutions bringing both streams to a common width of 512.
        self.proj1 = nn.Conv2d(2048, 512, 1)
        self.proj2 = nn.Conv2d(1536, 512, 1)

        # Attention over the concatenated streams (512 + 512 = 1024 channels).
        self.fusion_attention = CBAM(1024)

        # Global pooling followed by a small MLP ending in a single logit.
        head_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw logits of shape ``[B, 1]`` for an image batch ``x``."""
        xception_maps = self.stream1.forward_features(x)
        effnet_maps = self.stream2.forward_features(x)

        # Resample stream 2 onto stream 1's grid when the two disagree, so
        # the maps can be concatenated along the channel axis.
        target_hw = xception_maps.shape[2:]
        if effnet_maps.shape[2:] != target_hw:
            effnet_maps = F.interpolate(
                effnet_maps, size=target_hw, mode="bilinear", align_corners=False
            )

        fused = torch.cat([self.proj1(xception_maps), self.proj2(effnet_maps)], dim=1)
        return self.classifier(self.fusion_attention(fused))
app/main.py CHANGED
@@ -1,10 +1,21 @@
1
  import io
 
 
 
 
2
 
 
3
  from fastapi import FastAPI, File, HTTPException, UploadFile
4
  from PIL import Image, UnidentifiedImageError
5
 
6
  from app.model import MammogramModel
7
- from app.schemas import PredictResponse
 
 
 
 
 
 
8
 
9
  app = FastAPI(title="Mammogram Inference API", version="0.1.0")
10
  model = MammogramModel()
@@ -31,3 +42,44 @@ async def predict(file: UploadFile = File(...)) -> PredictResponse:
31
 
32
  result = model.predict(image)
33
  return PredictResponse(**result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import io
2
+ import ipaddress
3
+ import os
4
+ import socket
5
+ from urllib.parse import urlparse
6
 
7
+ import requests as http_requests
8
  from fastapi import FastAPI, File, HTTPException, UploadFile
9
  from PIL import Image, UnidentifiedImageError
10
 
11
  from app.model import MammogramModel
12
+ from app.schemas import AnalyzeRequest, PredictResponse
13
+
14
# Hostnames the /analyze endpoint may download from, read once at import time
# from the comma-separated ALLOWED_IMAGE_HOSTS variable (e.g. your Supabase
# storage host). Entries are trimmed and lower-cased; blank entries are dropped.
_ALLOWED_HOSTS_ENV = os.getenv("ALLOWED_IMAGE_HOSTS", "")
ALLOWED_HOSTS: set[str] = set(
    filter(None, (entry.strip().lower() for entry in _ALLOWED_HOSTS_ENV.split(",")))
)
19
 
20
  app = FastAPI(title="Mammogram Inference API", version="0.1.0")
21
  model = MammogramModel()
 
42
 
43
  result = model.predict(image)
44
  return PredictResponse(**result)
45
+
46
+
47
def _validate_url(url: str) -> str:
    """Validate an image URL before it is fetched, to limit SSRF exposure.

    Checks, in order: the URL parses, the scheme is HTTPS, a hostname is
    present, the hostname is in ``ALLOWED_HOSTS`` (only when that set is
    non-empty), and every address the hostname resolves to is a public
    unicast address.

    Args:
        url: The user-supplied URL.

    Returns:
        The unchanged ``url`` when all checks pass.

    Raises:
        HTTPException: Status 400 with a short reason when any check fails.

    NOTE(review): an empty ``ALLOWED_HOSTS`` disables the allowlist check
    entirely (fail-open). NOTE(review): the hostname is resolved here and
    again by the HTTP client, so a DNS answer that changes in between is not
    covered by this check.
    """
    try:
        parsed = urlparse(url)
        hostname = (parsed.hostname or "").lower()
    except ValueError as exc:
        # urlparse rejects some malformed inputs (e.g. unbalanced IPv6
        # brackets); report them as a client error instead of a 500.
        raise HTTPException(status_code=400, detail="Invalid URL") from exc

    if parsed.scheme != "https":
        raise HTTPException(status_code=400, detail="Only HTTPS URLs are allowed")
    if not hostname:
        raise HTTPException(status_code=400, detail="Invalid URL")
    if ALLOWED_HOSTS and hostname not in ALLOWED_HOSTS:
        raise HTTPException(status_code=400, detail="Image host not in allowlist")

    # Block anything that is not a public unicast address to prevent SSRF.
    try:
        resolved = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError) as exc:
        # UnicodeError comes from the IDNA codec for over-long/empty labels.
        raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc

    for info in resolved:
        # Drop an IPv6 zone id ("fe80::1%eth0") before parsing.
        addr = str(info[4][0]).split("%", 1)[0]
        try:
            ip = ipaddress.ip_address(addr)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail="Cannot resolve hostname") from exc
        # Judge IPv4-mapped IPv6 addresses by the IPv4 address they embed.
        mapped = getattr(ip, "ipv4_mapped", None)
        if mapped is not None:
            ip = mapped
        blocked = (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_reserved
            or ip.is_unspecified
            or not ip.is_global  # also covers shared space such as 100.64.0.0/10
        )
        if blocked:
            raise HTTPException(status_code=400, detail="URL resolves to a private address")
    return url
67
+
68
+
69
# Upper bound on the number of bytes /analyze will download for one image.
_MAX_IMAGE_BYTES = 50 * 1024 * 1024


@app.post("/analyze", response_model=PredictResponse)
def analyze(body: AnalyzeRequest) -> PredictResponse:
    """Accept a public image URL, download it, and run inference.

    The URL is validated with ``_validate_url`` and then fetched without
    following redirects, because a redirect target would not have gone
    through that validation. The body is streamed and capped at
    ``_MAX_IMAGE_BYTES``.

    Raises:
        HTTPException: 400 for an invalid URL, a redirect, a failed download
            or a body that is not a decodable image; 413 when the body
            exceeds the size cap.
    """
    _validate_url(body.image_url)

    payload = bytearray()
    try:
        with http_requests.get(
            body.image_url, timeout=30, allow_redirects=False, stream=True
        ) as resp:
            if 300 <= resp.status_code < 400:
                raise HTTPException(status_code=400, detail="Redirects are not allowed")
            resp.raise_for_status()
            for chunk in resp.iter_content(chunk_size=64 * 1024):
                payload.extend(chunk)
                if len(payload) > _MAX_IMAGE_BYTES:
                    raise HTTPException(status_code=413, detail="Image is too large")
    except http_requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=f"Failed to fetch image: {exc}") from exc

    try:
        image = Image.open(io.BytesIO(bytes(payload)))
        # Image.open is lazy; decode now so truncated or oversized files are
        # reported as a client error here rather than failing inside the model.
        image.load()
    except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
        raise HTTPException(status_code=400, detail="URL did not return a valid image") from exc

    result = model.predict(image)
    return PredictResponse(**result)
app/model.py CHANGED
@@ -1,50 +1,153 @@
1
- import hashlib
2
  import os
 
3
 
4
  import numpy as np
 
 
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class MammogramModel:
 
 
9
  def __init__(self) -> None:
10
- self.mode = os.getenv("MODEL_MODE", "mock")
11
- self.version = os.getenv("MODEL_VERSION", "baseline-mock-v1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def predict(self, image: Image.Image) -> dict:
14
- # Placeholder implementation: deterministic mock score from image bytes.
15
- # Replace with your real PyTorch model loading and inference later.
16
- rgb = image.convert("L")
17
- arr = np.array(rgb, dtype=np.float32) / 255.0
18
- mean_intensity = float(arr.mean())
 
 
 
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  digest = hashlib.sha256(arr.tobytes()).hexdigest()
21
  seed = int(digest[:8], 16)
22
  rng = np.random.default_rng(seed)
23
- jitter = float(rng.uniform(-0.04, 0.04))
24
-
25
- raw = min(max(mean_intensity + jitter, 0.0), 1.0)
26
-
27
- if raw < 0.20:
28
- birads = 1
29
- findings = "No suspicious findings detected in this preliminary model pass."
30
- elif raw < 0.35:
31
- birads = 2
32
- findings = "Likely benign pattern; correlate with prior imaging."
33
- elif raw < 0.55:
34
- birads = 3
35
- findings = "Probably benign appearance; short-interval follow-up may be considered."
36
- elif raw < 0.75:
37
- birads = 4
38
- findings = "Suspicious abnormality pattern; biopsy correlation recommended."
39
- else:
40
- birads = 5
41
- findings = "Highly suggestive of malignancy pattern; urgent diagnostic follow-up recommended."
42
 
43
- confidence = max(0.55, min(0.98, 0.55 + abs(raw - 0.5)))
44
 
45
  return {
46
  "birads": birads,
47
- "confidence": round(confidence, 3),
48
- "findings_text": findings,
49
- "model_version": self.version,
 
50
  }
 
1
+ import logging
2
  import os
3
+ from pathlib import Path
4
 
5
  import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
  from PIL import Image
9
+ from torchvision import transforms
10
+
11
+ from app.architecture import AdvancedBreastCancerModel
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # ImageNet normalisation (same as SensiNet training pipeline)
16
+ TRANSFORM = transforms.Compose([
17
+ transforms.Resize((299, 299)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
20
+ ])
21
+
22
+ WEIGHTS_DIR = Path(__file__).resolve().parent.parent / "weights"
23
+ DEFAULT_WEIGHTS = WEIGHTS_DIR / "advanced_model_best.pth"
24
+
25
+ # Malignancy probability threshold (same as SensiNet default)
26
+ THRESHOLD = 0.40
27
+
28
+ # Number of Bayesian MC-Dropout forward passes
29
+ MC_PASSES = 10
30
+
31
+
32
+ def _prob_to_birads(prob: float) -> int:
33
+ """Map malignancy probability to BI-RADS category."""
34
+ if prob < 0.10:
35
+ return 1 # Negative
36
+ if prob < 0.25:
37
+ return 2 # Benign
38
+ if prob < 0.50:
39
+ return 3 # Probably benign
40
+ if prob < 0.75:
41
+ return 4 # Suspicious
42
+ return 5 # Highly suggestive of malignancy
43
+
44
+
45
+ def _birads_findings(birads: int, prob: float, prediction: str) -> str:
46
+ templates = {
47
+ 1: "No suspicious findings detected. Mammographic appearance is unremarkable.",
48
+ 2: "Benign-appearing pattern identified. Correlate with prior imaging if available.",
49
+ 3: "Probably benign appearance. Short-interval follow-up may be considered.",
50
+ 4: "Suspicious abnormality pattern detected. Tissue biopsy is recommended.",
51
+ 5: "Highly suggestive of malignancy. Urgent diagnostic workup is recommended.",
52
+ }
53
+ base = templates.get(birads, "Analysis complete.")
54
+ return f"Model prediction: {prediction} (probability {prob:.1%}). {base}"
55
 
56
 
57
class MammogramModel:
    """Loads the SensiNet dual-stream model and runs inference.

    Configuration is read from the environment when the object is built:
    ``MODEL_MODE`` (default ``"real"``), ``MODEL_VERSION`` (default
    ``"sensinet-v1"``) and ``MODEL_WEIGHTS`` (default ``DEFAULT_WEIGHTS``).
    When the mode is ``"mock"`` or the weights file is missing, ``predict``
    returns a deterministic mock result instead of running the network.
    """

    def __init__(self) -> None:
        self.mode = os.getenv("MODEL_MODE", "real")
        self.version = os.getenv("MODEL_VERSION", "sensinet-v1")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: AdvancedBreastCancerModel | None = None

        if self.mode == "mock":
            # predict() never touches the network in mock mode, so do not
            # spend time and memory loading the weights.
            return

        weights_path = Path(os.getenv("MODEL_WEIGHTS", str(DEFAULT_WEIGHTS)))
        if weights_path.is_file():
            self._load_model(weights_path)
        else:
            logger.warning("Weights not found at %s — falling back to mock mode", weights_path)
            self.mode = "mock"

    def _load_model(self, weights_path: Path) -> None:
        """Build the network, load a raw state_dict into it and set eval mode."""
        logger.info("Loading SensiNet model from %s onto %s …", weights_path, self.device)
        net = AdvancedBreastCancerModel()
        # weights_only=True restricts unpickling to tensors and plain
        # containers; a downloaded .pth file must not be able to run code.
        state = torch.load(weights_path, map_location=self.device, weights_only=True)
        net.load_state_dict(state)
        net.to(self.device)
        net.eval()
        self._model = net
        logger.info("Model loaded successfully.")

    # ------------------------------------------------------------------

    def predict(self, image: Image.Image) -> dict:
        """Return the prediction dict for ``image`` (real model or mock)."""
        if self._model is None or self.mode == "mock":
            return self._mock_predict(image)
        return self._real_predict(image)

    # ------------------------------------------------------------------
    # Real inference with Bayesian MC-Dropout
    # ------------------------------------------------------------------

    def _real_predict(self, image: Image.Image) -> dict:
        """Run ``MC_PASSES`` stochastic forward passes and summarise them.

        Only dropout layers are switched to training mode, so every pass
        samples a different dropout mask while the remaining layers stay in
        eval mode. The mean of the sigmoid outputs is the reported
        malignancy probability; their variance lowers the reported
        confidence.

        NOTE(review): the dropout toggle mutates the shared network; whether
        concurrent calls can interleave depends on how the server runs this
        method - confirm.
        """
        rgb = image.convert("RGB")
        tensor = TRANSFORM(rgb).unsqueeze(0).to(self.device)

        def enable_dropout(m: nn.Module) -> None:
            if isinstance(m, (nn.Dropout, nn.Dropout2d)):
                m.train()

        mc_predictions: list[float] = []
        self._model.apply(enable_dropout)
        try:
            with torch.no_grad():
                for _ in range(MC_PASSES):
                    logits = self._model(tensor)
                    mc_predictions.append(torch.sigmoid(logits).item())
        finally:
            # Always restore eval mode, even if a forward pass raised.
            self._model.eval()

        prob_malig = float(np.mean(mc_predictions))
        variance = float(np.var(mc_predictions))

        # Heuristic confidence: start at 0.99, subtract twice the variance
        # (floor 0.50), add 0.10 for very low/high probabilities (cap 0.99).
        decision_confidence = max(0.50, 0.99 - (variance * 2.0))
        if prob_malig < 0.10 or prob_malig > 0.90:
            decision_confidence = min(0.99, decision_confidence + 0.10)

        prediction = "Malignant" if prob_malig >= THRESHOLD else "Benign"
        birads = _prob_to_birads(prob_malig)

        return {
            "birads": birads,
            "confidence": round(decision_confidence, 3),
            "malignancy_probability": round(prob_malig, 3),
            "findings_text": _birads_findings(birads, prob_malig, prediction),
            "model_version": self.version,
        }

    # ------------------------------------------------------------------
    # Deterministic mock fallback (no weights needed)
    # ------------------------------------------------------------------

    @staticmethod
    def _mock_predict(image: Image.Image) -> dict:
        """Return a deterministic pseudo-prediction derived from the pixels.

        The score is the mean grey level plus a jitter in [-0.04, 0.04]
        seeded from a SHA-256 digest of the pixel data, clipped to [0, 1].
        The reported model version is always ``"mock-v1"``.
        """
        import hashlib

        arr = np.array(image.convert("L"), dtype=np.float32) / 255.0
        digest = hashlib.sha256(arr.tobytes()).hexdigest()
        seed = int(digest[:8], 16)
        rng = np.random.default_rng(seed)
        raw = float(min(max(arr.mean() + rng.uniform(-0.04, 0.04), 0.0), 1.0))

        birads = _prob_to_birads(raw)

        return {
            "birads": birads,
            "confidence": round(max(0.55, min(0.98, 0.55 + abs(raw - 0.5))), 3),
            "malignancy_probability": round(raw, 3),
            "findings_text": _birads_findings(birads, raw, "Malignant" if raw >= THRESHOLD else "Benign"),
            "model_version": "mock-v1",
        }
app/schemas.py CHANGED
@@ -1,8 +1,13 @@
1
  from pydantic import BaseModel, Field
2
 
3
 
 
 
 
 
4
  class PredictResponse(BaseModel):
5
  birads: int = Field(ge=0, le=6)
6
  confidence: float = Field(ge=0.0, le=1.0)
 
7
  findings_text: str
8
  model_version: str
 
1
  from pydantic import BaseModel, Field
2
 
3
 
4
class AnalyzeRequest(BaseModel):
    """Request body of the /analyze endpoint."""

    # URL of the image to download and analyse (required).
    image_url: str = Field(...)
6
+
7
+
8
class PredictResponse(BaseModel):
    """Response body shared by the inference endpoints."""

    # BI-RADS category; validated to lie in 0..6.
    birads: int = Field(le=6, ge=0)
    # Confidence score in [0, 1].
    confidence: float = Field(le=1.0, ge=0.0)
    # Malignancy probability in [0, 1]; falls back to 0.0 when a producer
    # does not supply it.
    malignancy_probability: float = Field(default=0.0, le=1.0, ge=0.0)
    # Human-readable findings sentence.
    findings_text: str
    # Identifier of the model that produced the result.
    model_version: str
prepare_data.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ prepare_data.py — organise raw CBIS-DDSM images into train/val folder structure.
3
+
4
+ If your downloaded images are already in data/train/benign etc., skip this.
5
+
6
+ Usage
7
+ -----
8
+ python prepare_data.py --images /path/to/raw/images --csv /path/to/labels.csv
9
+
10
+ CSV must have columns: file_path, pathology
11
+ pathology values: BENIGN, MALIGNANT (or benign, malignant)
12
+
13
+ Output
14
+ ------
15
+ data/
16
+ train/benign/ train/malignant/
17
+ val/benign/ val/malignant/
18
+ """
19
+
20
+ import argparse
21
+ import os
22
+ import shutil
23
+ import random
24
+
25
TRAIN_RATIO = 0.85


def prepare(
    images_dir: str,
    csv_path: str,
    output_dir: str,
    seed: int = 42,
    train_ratio: float = TRAIN_RATIO,
) -> None:
    """Copy labelled images into ``<output_dir>/{train,val}/{benign,malignant}``.

    Rows are read from ``csv_path``. The label comes from the ``pathology``
    column (or ``label`` when that is absent or empty) and is compared
    case-insensitively: ``benign`` and ``benign_without_callback`` become
    ``benign``, ``malignant`` stays ``malignant``, everything else is
    skipped. Rows whose ``file_path`` (relative to ``images_dir``) is not an
    existing file are skipped as well. The surviving records are shuffled
    with ``seed`` and split at ``train_ratio``.

    Args:
        images_dir: Directory the CSV's ``file_path`` values are relative to.
        csv_path: CSV file with ``file_path`` and ``pathology`` columns.
        output_dir: Root of the generated folder structure.
        seed: Seed for the shuffle; the same seed gives the same split.
        train_ratio: Fraction of records placed in ``train``.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    import csv

    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

    # A private generator gives the same shuffle as seeding the module-level
    # one, without disturbing other users of ``random``.
    rng = random.Random(seed)
    canonical = {
        "benign": "benign",
        "benign_without_callback": "benign",
        "malignant": "malignant",
    }

    records: list[tuple[str, str]] = []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            # DictReader fills missing cells with None, hence the "or" chain.
            raw_label = row.get("pathology") or row.get("label") or ""
            label = canonical.get(raw_label.strip().lower())
            if label is None:
                continue  # skip unknown labels

            img_path = os.path.join(images_dir, (row.get("file_path") or "").strip())
            if os.path.isfile(img_path):
                records.append((img_path, label))

    print(f"Found {len(records)} labelled images")
    rng.shuffle(records)

    split = int(len(records) * train_ratio)
    splits = {"train": records[:split], "val": records[split:]}

    for split_name, items in splits.items():
        for label in ("benign", "malignant"):
            os.makedirs(os.path.join(output_dir, split_name, label), exist_ok=True)

        # Source files in different folders may share a basename; without a
        # distinct destination name the later copy would silently replace
        # the earlier one and the dataset would shrink.
        taken: set[tuple[str, str]] = set()
        for src, label in items:
            stem, ext = os.path.splitext(os.path.basename(src))
            fname = f"{stem}{ext}"
            suffix = 1
            while (label, fname) in taken:
                fname = f"{stem}_{suffix}{ext}"
                suffix += 1
            taken.add((label, fname))
            shutil.copy2(src, os.path.join(output_dir, split_name, label, fname))

        counts = {lbl: sum(1 for _, l in items if l == lbl) for lbl in ("benign", "malignant")}
        print(f"{split_name}: {counts}")

    print(f"Data prepared in {output_dir}/")
67
+
68
+
69
if __name__ == "__main__":
    # Command-line entry point: parse the three paths and run the split.
    cli = argparse.ArgumentParser()
    cli.add_argument("--images", required=True, help="Directory containing raw image files")
    cli.add_argument("--csv", required=True, help="CSV file with file_path and pathology columns")
    cli.add_argument("--output", default="data", help="Output directory")
    options = cli.parse_args()
    prepare(options.images, options.csv, options.output)
requirements.txt CHANGED
@@ -4,3 +4,8 @@ python-multipart==0.0.9
4
  pydantic==2.11.3
5
  numpy==2.2.4
6
  pillow==11.1.0
 
 
 
 
 
 
4
  pydantic==2.11.3
5
  numpy==2.2.4
6
  pillow==11.1.0
7
+ torch>=2.0.0
8
+ torchvision>=0.15.0
9
+ timm>=0.9.0
10
+ opencv-python-headless>=4.8.0
11
+ requests>=2.28.0
train.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py — Fine-tune the SensiNet dual-stream model on a binary mammogram dataset.
3
+
4
+ Expected dataset layout
5
+ -----------------------
6
+ data/
7
+ train/
8
+ benign/ <- benign mammogram images (.jpg / .png / .dcm converted to jpg)
9
+ malignant/ <- malignant mammogram images
10
+ val/
11
+ benign/
12
+ malignant/
13
+
14
+ If you only have a flat folder + CSV (CBIS-DDSM style), run prepare_data.py first.
15
+
16
+ Usage
17
+ -----
18
+ python train.py --data data --output weights/advanced_model_best.pth
19
+
20
+ The saved file is a raw state_dict compatible with MammogramModel._load_model().
21
+ """
22
+
23
+ import argparse
24
+ import os
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ from torch.optim import Adam
30
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
31
+ from torch.utils.data import DataLoader
32
+ from torchvision import datasets, transforms
33
+
34
+ from app.architecture import AdvancedBreastCancerModel
35
+
36
+ # ── Hyperparameters ──────────────────────────────────────────────────────────
37
+ IMG_SIZE = 299 # Xception / EfficientNet-B3 both happy at 299
38
+ BATCH_SIZE = 16
39
+ EPOCHS_HEAD = 20 # frozen backbone, train classifier + projection layers only
40
+ EPOCHS_FINE = 50 # unfreeze all, lower LR
41
+ LR_HEAD = 1e-3
42
+ LR_FINE = 1e-5
43
+ PATIENCE_EARLY = 10
44
+ PATIENCE_LR = 4
45
+ # ─────────────────────────────────────────────────────────────────────────────
46
+
47
+
48
def make_loaders(data_dir: str):
    """Build the training and validation data loaders.

    Images are read with ``ImageFolder`` from ``<data_dir>/train`` and
    ``<data_dir>/val``, resized to ``IMG_SIZE`` and normalised with the
    ImageNet mean/std. Training images are additionally flipped, rotated
    and colour-jittered.

    Args:
        data_dir: Root directory containing ``train/`` and ``val/``.

    Returns:
        ``(train_loader, val_loader, class_to_idx)``.

    Raises:
        ValueError: If either split does not map exactly ``benign`` to 0 and
            ``malignant`` to 1.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    val_ds = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf)

    # Expect exactly two classes: benign=0, malignant=1. Raised explicitly
    # (not asserted) so the check also runs under ``python -O``.
    print(f"Class mapping: {train_ds.class_to_idx}")
    expected = {"benign": 0, "malignant": 1}
    for split_name, ds in (("train", train_ds), ("val", val_ds)):
        if ds.class_to_idx != expected:
            raise ValueError(
                "Dataset must have exactly 'benign' and 'malignant' subdirs "
                f"({split_name} has {sorted(ds.class_to_idx)})"
            )

    use_cuda = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=use_cuda,
        # A trailing batch of one sample makes BatchNorm1d fail in training
        # mode; drop it only in that exact case so no other data is lost.
        drop_last=len(train_ds) % BATCH_SIZE == 1,
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=use_cuda
    )
    return train_loader, val_loader, train_ds.class_to_idx
75
+
76
+
77
def _freeze_backbones(model: AdvancedBreastCancerModel) -> None:
    """Turn off gradients for every parameter of both backbone streams."""
    for backbone in (model.stream1, model.stream2):
        backbone.requires_grad_(False)
82
+
83
+
84
def _unfreeze_all(model: AdvancedBreastCancerModel) -> None:
    """Turn gradients back on for every parameter of the model."""
    model.requires_grad_(True)
87
+
88
+
89
def run_epoch(model, loader, criterion, optimizer, device, training: bool, threshold: float = 0.40):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Network returning logits of shape ``[B, 1]``.
        loader: Iterable of ``(images, labels)`` batches, labels 0=benign,
            1=malignant.
        criterion: Loss taking ``(logits, float_targets)``.
        optimizer: Optimizer stepped after every batch when ``training`` is
            true; unused otherwise.
        device: Device the batches are moved to.
        training: Train mode with gradient updates when true, eval mode
            under ``no_grad`` when false.
        threshold: Sigmoid probability at or above which a sample counts as
            malignant for the accuracy figure.

    Raises:
        ValueError: If ``loader`` yields no samples.

    NOTE(review): ``model.train()`` also puts the normalisation layers of
    frozen backbones into training mode, so their running statistics keep
    changing while the backbones are "frozen" - confirm this is intended.
    """
    model.train(training)
    total_loss = 0.0
    correct = 0
    total = 0

    ctx = torch.enable_grad() if training else torch.no_grad()
    with ctx:
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # BCEWithLogitsLoss needs float targets.
            targets = labels.float()

            logits = model(images).squeeze(1)
            loss = criterion(logits, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            preds = (torch.sigmoid(logits) >= threshold).long()
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch metrics")
    return total_loss / total, correct / total
116
+
117
+
118
def _fit_phase(
    *,
    tag: str,
    phase_no: int,
    epochs: int,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    output_path,
    best_val_acc: float,
) -> float:
    """Run one training phase with checkpointing and early stopping.

    Saves ``model.state_dict()`` to ``output_path`` whenever validation
    accuracy beats ``best_val_acc`` and stops after ``PATIENCE_EARLY``
    epochs without improvement. Returns the updated best accuracy.
    """
    no_improve = 0
    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True)
        vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False)

        # Report LR reductions here instead of via the scheduler's
        # deprecated ``verbose`` flag.
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(vl_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f"  LR reduced: {lr_before:.2e} -> {lr_after:.2e}")

        print(f"[{tag} {epoch:02d}/{epochs}] loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}")

        if vl_acc > best_val_acc:
            best_val_acc = vl_acc
            torch.save(model.state_dict(), output_path)
            print(f"  ✓ Saved (val_acc={best_val_acc:.3f})")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE_EARLY:
                print(f"  Early stopping (Phase {phase_no})")
                break
    return best_val_acc


def train(data_dir: str, output_path: str) -> None:
    """Train the dual-stream model in two phases and save the best weights.

    Phase 1 freezes both backbones and trains the remaining layers at
    ``LR_HEAD``; phase 2 unfreezes everything and continues at ``LR_FINE``.
    The best validation accuracy is tracked across both phases, and the
    file at ``output_path`` always holds the state_dict of the best epoch
    so far.

    Args:
        data_dir: Root data directory (must contain ``train/`` and ``val/``).
        output_path: Where the best state_dict is written; parent
            directories are created.

    NOTE(review): the model is constructed without loading any weights, so
    phase 1 trains on top of whatever initialisation
    ``AdvancedBreastCancerModel()`` provides - confirm this is intended for
    a script described as fine-tuning.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    train_loader, val_loader, _ = make_loaders(data_dir)

    model = AdvancedBreastCancerModel().to(device)
    criterion = nn.BCEWithLogitsLoss()

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: train head only ────────────────────────────────────────────
    print("\n=== Phase 1: training classifier head (frozen backbones) ===")
    _freeze_backbones(model)
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR_HEAD)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-7)
    best_val_acc = _fit_phase(
        tag="P1",
        phase_no=1,
        epochs=EPOCHS_HEAD,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        output_path=output_path,
        best_val_acc=0.0,
    )

    # ── Phase 2: fine-tune all layers ───────────────────────────────────────
    print("\n=== Phase 2: fine-tuning all layers ===")
    _unfreeze_all(model)
    optimizer = Adam(model.parameters(), lr=LR_FINE)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-8)
    best_val_acc = _fit_phase(
        tag="P2",
        phase_no=2,
        epochs=EPOCHS_FINE,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        output_path=output_path,
        best_val_acc=best_val_acc,
    )

    print(f"\nDone. Best val_acc={best_val_acc:.3f}")
    print(f"Weights → {output_path}")
181
+
182
+
183
if __name__ == "__main__":
    # Command-line entry point: parse the data and output paths, then train.
    cli = argparse.ArgumentParser(description="Train SensiNet mammogram classifier")
    cli.add_argument("--data", default="data", help="Root data dir (must contain train/ and val/)")
    cli.add_argument("--output", default="weights/advanced_model_best.pth", help="Output weights path")
    options = cli.parse_args()
    train(options.data, options.output)