""" train.py — Fine-tune the SensiNet dual-stream model on a binary mammogram dataset. Expected dataset layout ----------------------- data/ train/ benign/ <- benign mammogram images (.jpg / .png / .dcm converted to jpg) malignant/ <- malignant mammogram images val/ benign/ malignant/ If you only have a flat folder + CSV (CBIS-DDSM style), run prepare_data.py first. Usage ----- python train.py --data data --output models/advanced_model_best.pth The saved file is a raw state_dict compatible with MammogramModel._load_model(). """ import argparse import os from pathlib import Path import torch import torch.nn as nn from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader from torchvision import datasets, transforms from app.architecture import AdvancedBreastCancerModel # ── Hyperparameters ──────────────────────────────────────────────────────────── IMG_SIZE = 299 # Xception / EfficientNet-B3 both happy at 299 BATCH_SIZE = 16 EPOCHS_HEAD = 20 # frozen backbone, train classifier + projection layers only EPOCHS_FINE = 50 # unfreeze all, lower LR LR_HEAD = 1e-3 LR_FINE = 1e-5 PATIENCE_EARLY = 10 PATIENCE_LR = 4 # ────────────────────────────────────────────────────────────────────────────── def make_loaders(data_dir: str): train_tf = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.15, contrast=0.15), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) val_tf = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf) val_ds = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf) # Expect exactly two classes: benign=0, malignant=1 print(f"Class mapping: {train_ds.class_to_idx}") assert set(train_ds.class_to_idx.keys()) == {"benign", "malignant"}, ( "Dataset must have exactly 'benign' and 'malignant' subdirs" ) train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True) val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True) return train_loader, val_loader, train_ds.class_to_idx def _freeze_backbones(model: AdvancedBreastCancerModel) -> None: for param in model.stream1.parameters(): param.requires_grad = False for param in model.stream2.parameters(): param.requires_grad = False def _unfreeze_all(model: AdvancedBreastCancerModel) -> None: for param in model.parameters(): param.requires_grad = True def run_epoch(model, loader, criterion, optimizer, device, training: bool): model.train() if training else model.eval() total_loss = 0.0 correct = 0 total = 0 ctx = torch.enable_grad() if training else torch.no_grad() with ctx: for images, labels in loader: images = images.to(device) # labels: 0=benign, 1=malignant → float for BCEWithLogitsLoss targets = labels.float().to(device) logits = model(images).squeeze(1) loss = criterion(logits, targets) if training: optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() * images.size(0) preds = (torch.sigmoid(logits) >= 0.40).long() correct += (preds == labels.to(device)).sum().item() total += images.size(0) return total_loss / total, correct / total def train(data_dir: str, output_path: str) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") train_loader, val_loader, _ = make_loaders(data_dir) model = AdvancedBreastCancerModel().to(device) criterion = nn.BCEWithLogitsLoss() best_val_acc = 0.0 output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) # ── Phase 1: train head only ─────────────────────────────────────────────── print("\n=== Phase 1: training classifier head (frozen backbones) ===") _freeze_backbones(model) optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR_HEAD) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-7, verbose=True) no_improve = 0 for epoch in range(1, EPOCHS_HEAD + 1): tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True) vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False) scheduler.step(vl_loss) print(f"[P1 {epoch:02d}/{EPOCHS_HEAD}] loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}") if vl_acc > best_val_acc: best_val_acc = vl_acc torch.save(model.state_dict(), output_path) print(f" ✓ Saved (val_acc={best_val_acc:.3f})") no_improve = 0 else: no_improve += 1 if no_improve >= PATIENCE_EARLY: print(" Early stopping (Phase 1)") break # ── Phase 2: fine-tune all layers ───────────────────────────────────────── print("\n=== Phase 2: fine-tuning all layers ===") _unfreeze_all(model) optimizer = Adam(model.parameters(), lr=LR_FINE) scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-8, verbose=True) no_improve = 0 for epoch in range(1, EPOCHS_FINE + 1): tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True) vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False) scheduler.step(vl_loss) print(f"[P2 {epoch:02d}/{EPOCHS_FINE}] loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}") if vl_acc > best_val_acc: best_val_acc = vl_acc torch.save(model.state_dict(), output_path) print(f" ✓ Saved (val_acc={best_val_acc:.3f})") no_improve = 0 else: no_improve += 1 if no_improve >= PATIENCE_EARLY: print(" Early stopping (Phase 2)") break print(f"\nDone. Best val_acc={best_val_acc:.3f}") print(f"Weights → {output_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train SensiNet mammogram classifier") parser.add_argument("--data", default="data", help="Root data dir (must contain train/ and val/)") parser.add_argument("--output", default="weights/advanced_model_best.pth", help="Output weights path") args = parser.parse_args() train(args.data, args.output)