mammogram-analyzer / train.py
tampee's picture
feat: integrate real SensiNet mammography model
596aaa1
"""
train.py — Fine-tune the SensiNet dual-stream model on a binary mammogram dataset.
Expected dataset layout
-----------------------
data/
    train/
        benign/      <- benign mammogram images (.jpg / .png / .dcm converted to jpg)
        malignant/   <- malignant mammogram images
    val/
        benign/
        malignant/
If you only have a flat folder + CSV (CBIS-DDSM style), run prepare_data.py first.
Usage
-----
python train.py --data data --output models/advanced_model_best.pth
The saved file is a raw state_dict compatible with MammogramModel._load_model().
"""
import argparse
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from app.architecture import AdvancedBreastCancerModel
# ── Hyperparameters ────────────────────────────────────────────────────────────
# Side length in pixels; every image is resized to IMG_SIZE x IMG_SIZE before batching.
IMG_SIZE = 299 # Xception / EfficientNet-B3 both happy at 299
# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 16
# Upper bound on epochs per phase; early stopping may end a phase sooner.
EPOCHS_HEAD = 20 # frozen backbone, train classifier + projection layers only
EPOCHS_FINE = 50 # unfreeze all, lower LR
# Adam learning rates for phase 1 (head only) and phase 2 (all layers).
LR_HEAD = 1e-3
LR_FINE = 1e-5
# Consecutive epochs without a new best validation accuracy before a phase stops early.
PATIENCE_EARLY = 10
# Epochs without validation-loss improvement before ReduceLROnPlateau halves the LR.
PATIENCE_LR = 4
# ──────────────────────────────────────────────────────────────────────────────
def make_loaders(data_dir: str):
    """Build the training and validation DataLoaders for a benign/malignant dataset.

    Args:
        data_dir: Root directory containing ``train/`` and ``val/``, each with
            ``benign/`` and ``malignant/`` class sub-folders.

    Returns:
        A ``(train_loader, val_loader, class_to_idx)`` tuple, where
        ``class_to_idx`` is the class-name -> label-index mapping of the
        training split.

    Raises:
        ValueError: If the training split does not contain exactly the
            ``benign`` and ``malignant`` classes, or if the validation split
            maps class names to different indices than the training split.
    """
    # Stages shared by both pipelines; only the training pipeline augments.
    resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_tf = transforms.Compose([
        resize,
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        resize,
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    val_ds = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf)

    # Expect exactly two classes: benign=0, malignant=1
    print(f"Class mapping: {train_ds.class_to_idx}")
    # Raise instead of assert: assertions are stripped under `python -O`,
    # which would let a mislabeled dataset through silently.
    if set(train_ds.class_to_idx) != {"benign", "malignant"}:
        raise ValueError("Dataset must have exactly 'benign' and 'malignant' subdirs")
    # A differing validation mapping would make every validation metric meaningless.
    if val_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"Validation class mapping {val_ds.class_to_idx} does not match "
            f"training class mapping {train_ds.class_to_idx}"
        )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=pin)
    return train_loader, val_loader, train_ds.class_to_idx
def _freeze_backbones(model: AdvancedBreastCancerModel) -> None:
    """Disable gradients for every parameter of ``model.stream1`` and ``model.stream2``.

    Parameters outside those two sub-modules are left untouched.
    """
    for stream_name in ("stream1", "stream2"):
        backbone = getattr(model, stream_name)
        for weight in backbone.parameters():
            weight.requires_grad = False
def _unfreeze_all(model: AdvancedBreastCancerModel) -> None:
    """Re-enable gradients on every parameter of ``model``."""
    trainable = True
    for weight in model.parameters():
        weight.requires_grad = trainable
def run_epoch(model, loader, criterion, optimizer, device, training: bool, threshold: float = 0.40):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Network called as ``model(images)``; its output is squeezed on
            dimension 1 to give one logit per sample.
        loader: Iterable yielding ``(images, labels)`` batches, with integer
            labels 0 (benign) / 1 (malignant).
        criterion: Loss called as ``criterion(logits, float_targets)``.
        optimizer: Optimizer stepped once per batch; only used when
            ``training`` is true.
        device: Device that each batch is moved to.
        training: If true, put the model in train mode and update weights;
            otherwise put it in eval mode and run without gradients.
        threshold: Sigmoid-probability cut-off at or above which a sample is
            predicted malignant (label 1). Defaults to 0.40.

    Returns:
        Tuple of the sample-weighted mean loss and the fraction of correctly
        classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples, so no metric can be computed.
    """
    # train(True) / train(False) is the same switch as .train() / .eval().
    model.train(bool(training))

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(bool(training)):
        for images, labels in loader:
            images = images.to(device)
            # Move labels once; reused for both the loss target and the accuracy check.
            labels = labels.to(device)
            # labels: 0=benign, 1=malignant -> float for BCEWithLogitsLoss
            targets = labels.float()

            logits = model(images).squeeze(1)
            loss = criterion(logits, targets)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = images.size(0)
            # The loss is multiplied by the batch size so the final division
            # by n_seen gives a per-sample average even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            preds = (torch.sigmoid(logits) >= threshold).long()
            n_correct += (preds == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("run_epoch received an empty loader; cannot compute loss/accuracy")
    return loss_sum / n_seen, n_correct / n_seen
def _run_phase(
    phase: int,
    epochs: int,
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    device,
    output_path: Path,
    best_val_acc: float,
) -> float:
    """Run one training phase and return the updated best validation accuracy.

    Each epoch trains, validates, steps the LR scheduler on validation loss,
    saves ``model.state_dict()`` to ``output_path`` whenever validation
    accuracy beats ``best_val_acc``, and stops early after ``PATIENCE_EARLY``
    consecutive epochs without such an improvement.
    """
    stale_epochs = 0
    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True)
        vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False)

        # Report LR reductions explicitly; the scheduler's `verbose` flag is
        # deprecated in recent PyTorch releases.
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(vl_loss)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            print(f" LR reduced: {lr_before:.2e} -> {lr_after:.2e}")

        print(f"[P{phase} {epoch:02d}/{epochs}] loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}")

        if vl_acc > best_val_acc:
            best_val_acc = vl_acc
            torch.save(model.state_dict(), output_path)
            print(f" ✓ Saved (val_acc={best_val_acc:.3f})")
            stale_epochs = 0
            continue

        stale_epochs += 1
        if stale_epochs >= PATIENCE_EARLY:
            print(f" Early stopping (Phase {phase})")
            break
    return best_val_acc


def train(data_dir: str, output_path: str) -> None:
    """Fine-tune the model in two phases and save the best weights.

    Phase 1 freezes both backbone streams and trains the remaining parameters
    at ``LR_HEAD``; phase 2 unfreezes everything and continues at ``LR_FINE``.
    The checkpoint with the highest validation accuracy seen across both
    phases is written to ``output_path`` as a raw ``state_dict``.

    Args:
        data_dir: Root data directory containing ``train/`` and ``val/``.
        output_path: Destination file for the best weights; parent
            directories are created if missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    train_loader, val_loader, _ = make_loaders(data_dir)
    model = AdvancedBreastCancerModel().to(device)
    criterion = nn.BCEWithLogitsLoss()

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if validation accuracy is 0.0.
    best_val_acc = float("-inf")
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: train head only ───────────────────────────────────────────────
    print("\n=== Phase 1: training classifier head (frozen backbones) ===")
    _freeze_backbones(model)
    # Only hand the still-trainable parameters to the optimizer.
    head_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = Adam(head_params, lr=LR_HEAD)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-7)
    best_val_acc = _run_phase(
        1, EPOCHS_HEAD, model, train_loader, val_loader,
        criterion, optimizer, scheduler, device, output_path, best_val_acc,
    )

    # ── Phase 2: fine-tune all layers ─────────────────────────────────────────
    # NOTE: phase 2 continues from the weights of the last phase-1 epoch, not
    # from the best saved checkpoint.
    print("\n=== Phase 2: fine-tuning all layers ===")
    _unfreeze_all(model)
    optimizer = Adam(model.parameters(), lr=LR_FINE)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-8)
    best_val_acc = _run_phase(
        2, EPOCHS_FINE, model, train_loader, val_loader,
        criterion, optimizer, scheduler, device, output_path, best_val_acc,
    )

    print(f"\nDone. Best val_acc={best_val_acc:.3f}")
    print(f"Weights → {output_path}")
if __name__ == "__main__":
    # Command-line entry point: read the data root and output path, then train.
    cli = argparse.ArgumentParser(description="Train SensiNet mammogram classifier")
    option_specs = (
        ("--data", "data", "Root data dir (must contain train/ and val/)"),
        ("--output", "weights/advanced_model_best.pth", "Output weights path"),
    )
    for flag, fallback, blurb in option_specs:
        cli.add_argument(flag, default=fallback, help=blurb)
    opts = cli.parse_args()
    train(opts.data, opts.output)