File size: 7,638 Bytes
596aaa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
train.py — Fine-tune the SensiNet dual-stream model on a binary mammogram dataset.

Expected dataset layout
-----------------------
data/
  train/
    benign/        <- benign mammogram images (.jpg / .png; convert .dcm files to .jpg first)
    malignant/     <- malignant mammogram images
  val/
    benign/
    malignant/

If you only have a flat folder + CSV (CBIS-DDSM style), run prepare_data.py first.

Usage
-----
python train.py --data data --output models/advanced_model_best.pth

The saved file is a raw state_dict compatible with MammogramModel._load_model().
"""

import argparse
import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from app.architecture import AdvancedBreastCancerModel

# ── Hyperparameters ────────────────────────────────────────────────────────────
# Square input resolution fed to the model. The original note says both backbones
# (Xception / EfficientNet-B3) work well at 299 — confirm against app.architecture.
IMG_SIZE: int = 299
BATCH_SIZE: int = 16

# Phase 1: backbone streams frozen; only the remaining parameters are trained.
EPOCHS_HEAD: int = 20
LR_HEAD: float = 1e-3

# Phase 2: everything unfrozen and fine-tuned at a much smaller learning rate.
EPOCHS_FINE: int = 50
LR_FINE: float = 1e-5

# Patience values count epochs without improvement. The LR scheduler's patience is
# shorter than early stopping's, so LR reductions normally get a chance first
# (note the scheduler watches val loss while early stopping watches val accuracy).
PATIENCE_EARLY: int = 10
PATIENCE_LR: int = 4
# ──────────────────────────────────────────────────────────────────────────────


def make_loaders(data_dir: str, num_workers: int = 4):
    """Build train/val ``DataLoader``s from an ImageFolder-style directory tree.

    Args:
        data_dir: Root directory that must contain ``train/`` and ``val/``, each
            holding exactly the class folders ``benign/`` and ``malignant/``.
        num_workers: Worker processes per ``DataLoader`` (default 4, the value
            that used to be hard-coded).

    Returns:
        ``(train_loader, val_loader, class_to_idx)``. ImageFolder assigns label
        indices in sorted folder-name order, so ``class_to_idx`` is
        ``{"benign": 0, "malignant": 1}``.

    Raises:
        ValueError: if either split does not contain exactly the two expected
            class folders.
    """
    expected_classes = {"benign", "malignant"}
    # Standard ImageNet per-channel mean/std, shared by both pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # Light augmentation, applied to the training split only.
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    val_ds = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf)

    print(f"Class mapping: {train_ds.class_to_idx}")
    # Raise rather than `assert` so the check survives `python -O`. Validate val/
    # as well: a val split with different class folders would get a different
    # label mapping and could silently mislabel samples.
    for split, ds in (("train", train_ds), ("val", val_ds)):
        found = set(ds.class_to_idx)
        if found != expected_classes:
            raise ValueError(
                "Dataset must have exactly 'benign' and 'malignant' subdirs; "
                f"found {sorted(found)} in {split}/"
            )

    # Page-locked host memory only speeds up host-to-GPU copies; skip it on CPU.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers, pin_memory=pin
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=pin
    )
    return train_loader, val_loader, train_ds.class_to_idx


def _freeze_backbones(model: AdvancedBreastCancerModel) -> None:
    """Turn off gradients for every parameter of ``model.stream1`` and ``model.stream2``.

    Parameters outside those two streams are left untouched.
    """
    for backbone in (model.stream1, model.stream2):
        for weight in backbone.parameters():
            weight.requires_grad = False


def _unfreeze_all(model: AdvancedBreastCancerModel) -> None:
    """Mark every parameter of *model* as trainable (``requires_grad=True``)."""
    # nn.Module.requires_grad_ applies the flag to all of self.parameters() in place.
    model.requires_grad_(True)


def run_epoch(model, loader, criterion, optimizer, device, training: bool, *, threshold: float = 0.40):
    """Run one pass over *loader*, optionally updating the model.

    Args:
        model: Network whose output is squeezed along dim 1 to give one logit per
            sample (presumably shape ``(batch, 1)`` — TODO confirm against the
            architecture).
        loader: Iterable of ``(images, labels)`` batches; labels 0=benign, 1=malignant.
        criterion: Loss called as ``criterion(logits, float_targets)``.
        optimizer: Stepped once per batch when *training* is true; unused otherwise.
        device: Device the batch tensors are moved to.
        training: If true, use train mode and backpropagate; otherwise use eval
            mode under ``torch.no_grad()``.
        threshold: Sigmoid probability at or above which a sample is predicted
            malignant for the accuracy figure. The default of 0.40 (the previously
            hard-coded value) is below 0.5, so borderline samples count as malignant.

    Returns:
        ``(mean_loss_per_sample, accuracy)`` over all samples seen.

    Raises:
        ValueError: if *loader* yields no samples.
    """
    model.train(bool(training))
    total_loss = 0.0
    correct = 0
    total = 0

    ctx = torch.enable_grad() if training else torch.no_grad()
    with ctx:
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            # BCEWithLogitsLoss-style criteria need float targets; the integer
            # labels are kept for the accuracy comparison below.
            targets = labels.float()

            logits = model(images).squeeze(1)
            loss = criterion(logits, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            # criterion returns a batch mean; re-weight so the epoch mean is per sample.
            total_loss += loss.item() * batch_size
            # detach(): predictions are metrics only and must not extend the graph.
            preds = (torch.sigmoid(logits.detach()) >= threshold).long()
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("run_epoch: loader yielded no samples")
    return total_loss / total, correct / total


def _run_phase(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    *,
    phase: int,
    epochs: int,
    best_val_acc: float,
    output_path: Path,
):
    """Run one training phase with LR scheduling, checkpointing and early stopping.

    Each epoch trains on *train_loader*, evaluates on *val_loader*, steps
    *scheduler* on the validation loss, and saves ``model.state_dict()`` to
    *output_path* whenever validation accuracy beats *best_val_acc*. The phase
    stops after ``PATIENCE_EARLY`` consecutive epochs without improvement.

    Returns:
        ``(best_val_acc, best_state)``. ``best_val_acc`` is the best validation
        accuracy seen so far (the incoming value if this phase never beat it).
        ``best_state`` is a CPU copy of the state_dict last saved during this
        phase, or ``None`` if this phase saved nothing.
    """
    best_state = None
    no_improve = 0

    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True)
        vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False)

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(vl_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        print(
            f"[P{phase} {epoch:02d}/{epochs}] loss={tr_loss:.4f} acc={tr_acc:.3f} | "
            f"val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}"
        )
        # ReduceLROnPlateau's `verbose` flag is deprecated; report reductions here instead.
        if lr_after < lr_before:
            print(f"  LR reduced {lr_before:.2e} -> {lr_after:.2e}")

        if vl_acc > best_val_acc:
            best_val_acc = vl_acc
            torch.save(model.state_dict(), output_path)
            # In-memory copy so the caller can roll the model back to this checkpoint.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print(f"  ✓ Saved (val_acc={best_val_acc:.3f})")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE_EARLY:
                print(f"  Early stopping (Phase {phase})")
                break

    return best_val_acc, best_state


def train(data_dir: str, output_path: str) -> None:
    """Fine-tune AdvancedBreastCancerModel in two phases, saving the best weights.

    Phase 1 freezes both backbone streams and trains the remaining parameters
    at ``LR_HEAD``. Phase 2 starts from the best Phase-1 checkpoint, unfreezes
    everything, and continues at ``LR_FINE``. Across both phases, the raw
    state_dict with the highest validation accuracy is written to
    *output_path*; its parent directories are created if needed.

    Args:
        data_dir: Root directory with ``train/`` and ``val/`` ImageFolder splits.
        output_path: Destination ``.pth`` file for the best state_dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    train_loader, val_loader, _ = make_loaders(data_dir)

    model = AdvancedBreastCancerModel().to(device)
    criterion = nn.BCEWithLogitsLoss()

    best_val_acc = 0.0
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: train head only ───────────────────────────────────────────────
    print("\n=== Phase 1: training classifier head (frozen backbones) ===")
    _freeze_backbones(model)
    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR_HEAD)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-7)
    best_val_acc, head_state = _run_phase(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device,
        phase=1, epochs=EPOCHS_HEAD, best_val_acc=best_val_acc, output_path=output_path,
    )

    # The weights at the end of Phase 1 can be up to PATIENCE_EARLY epochs past
    # the best point; fine-tune from the best checkpoint instead.
    if head_state is not None:
        model.load_state_dict(head_state)

    # ── Phase 2: fine-tune all layers ─────────────────────────────────────────
    print("\n=== Phase 2: fine-tuning all layers ===")
    _unfreeze_all(model)
    optimizer = Adam(model.parameters(), lr=LR_FINE)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-8)
    best_val_acc, fine_state = _run_phase(
        model, train_loader, val_loader, criterion, optimizer, scheduler, device,
        phase=2, epochs=EPOCHS_FINE, best_val_acc=best_val_acc, output_path=output_path,
    )

    print(f"\nDone. Best val_acc={best_val_acc:.3f}")
    if head_state is None and fine_state is None:
        print("No checkpoint was saved: validation accuracy never rose above 0.000.")
    else:
        print(f"Weights → {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the data/output paths and run training.
    cli = argparse.ArgumentParser(description="Train SensiNet mammogram classifier")
    cli.add_argument(
        "--data",
        default="data",
        help="Root data dir (must contain train/ and val/)",
    )
    cli.add_argument(
        "--output",
        default="weights/advanced_model_best.pth",
        help="Output weights path",
    )
    opts = cli.parse_args()
    train(data_dir=opts.data, output_path=opts.output)