Spaces:
Sleeping
Sleeping
"""
train.py — Fine-tune the SensiNet dual-stream model on a binary mammogram dataset.

Expected dataset layout
-----------------------
data/
    train/
        benign/      <- benign mammogram images (.jpg / .png; convert .dcm to .jpg first)
        malignant/   <- malignant mammogram images
    val/
        benign/
        malignant/

If you only have a flat folder + CSV (CBIS-DDSM style), run prepare_data.py first.

Training runs in two phases: first the two backbone streams are frozen and only
the remaining layers are trained, then every layer is unfrozen and fine-tuned at
a lower learning rate. Whenever validation accuracy improves, the weights are
written to the output path, so the file always holds the best checkpoint seen.

Usage
-----
    python train.py --data data --output models/advanced_model_best.pth

Without --output, weights are written to weights/advanced_model_best.pth.
The saved file is a raw state_dict compatible with MammogramModel._load_model().
"""
| import argparse | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from torch.optim import Adam | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torch.utils.data import DataLoader | |
| from torchvision import datasets, transforms | |
| from app.architecture import AdvancedBreastCancerModel | |
# -- Hyperparameters ---------------------------------------------------------
# Input resolution and batch size shared by the train and val loaders.
# (Original note: Xception / EfficientNet-B3 both happy at 299.)
IMG_SIZE: int = 299
BATCH_SIZE: int = 16

# Phase 1: frozen backbone, train classifier + projection layers only.
EPOCHS_HEAD: int = 20
LR_HEAD: float = 1e-3

# Phase 2: unfreeze all layers and fine-tune at a much lower learning rate.
EPOCHS_FINE: int = 50
LR_FINE: float = 1e-5

# Epochs without a val-accuracy improvement before a phase stops early.
PATIENCE_EARLY: int = 10
# Epochs without a val-loss improvement before ReduceLROnPlateau halves the LR.
PATIENCE_LR: int = 4
# ----------------------------------------------------------------------------
def make_loaders(data_dir: str, num_workers: int = 4):
    """Build the train/val DataLoaders from an ImageFolder-style directory tree.

    Args:
        data_dir: Root directory that must contain ``train/`` and ``val/``,
            each holding exactly the ``benign/`` and ``malignant/`` class folders.
        num_workers: Worker processes per DataLoader (default 4, as before).

    Returns:
        ``(train_loader, val_loader, class_to_idx)``, where ``class_to_idx`` is
        the training split's mapping, guaranteed to be
        ``{"benign": 0, "malignant": 1}``.

    Raises:
        ValueError: If either split does not contain exactly the ``benign``
            and ``malignant`` class folders.
    """
    # Standard ImageNet channel statistics -- presumably matching the
    # backbones' pretraining; confirm against app.architecture.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.15, contrast=0.15),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation images get no augmentation, only resize + normalize.
    val_tf = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])
    train_ds = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=train_tf)
    val_ds = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=val_tf)

    # Expect exactly two classes: benign=0, malignant=1. Check BOTH splits so a
    # malformed val/ cannot silently shift label indices. Raise instead of
    # assert so the check survives `python -O`.
    print(f"Class mapping: {train_ds.class_to_idx}")
    expected = {"benign": 0, "malignant": 1}
    for split, ds in (("train", train_ds), ("val", val_ds)):
        if ds.class_to_idx != expected:
            raise ValueError(
                "Dataset must have exactly 'benign' and 'malignant' subdirs "
                f"(split '{split}' has {sorted(ds.class_to_idx)})"
            )

    # Pinned host memory only speeds up host->GPU copies; on CPU-only
    # machines it just triggers a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return train_loader, val_loader, train_ds.class_to_idx
def _freeze_backbones(model: AdvancedBreastCancerModel) -> None:
    """Turn off gradients for every parameter of ``model.stream1`` and ``model.stream2``.

    Parameters outside those two sub-modules keep their current
    ``requires_grad`` setting, so only they are trained afterwards.

    NOTE(review): ``requires_grad = False`` does not stop BatchNorm layers
    inside the streams from updating running statistics when the model is in
    train mode -- confirm that is intended for Phase 1.
    """
    for backbone in (model.stream1, model.stream2):
        for weight in backbone.parameters():
            weight.requires_grad = False
def _unfreeze_all(model: AdvancedBreastCancerModel) -> None:
    """Re-enable gradients on every parameter of *model*, undoing any freezing."""
    # nn.Module.requires_grad_ walks model.parameters() and sets the flag on each.
    model.requires_grad_(True)
def run_epoch(model, loader, criterion, optimizer, device, training: bool, threshold: float = 0.40):
    """Run one pass over *loader*, optionally updating *model*.

    Args:
        model: Module whose forward output is squeezed on dim 1 to get one
            logit per sample (so presumably shape ``(batch, 1)``).
        loader: Iterable of ``(images, labels)`` batches; labels are integer
            class indices, 0=benign and 1=malignant.
        criterion: Loss called as ``criterion(logits, float_targets)``,
            e.g. ``BCEWithLogitsLoss``.
        optimizer: Stepped once per batch when *training* is true; unused
            (may be ``None``) otherwise.
        device: Device each batch is moved to.
        training: If true, put the model in train mode and backpropagate.
            Otherwise use eval mode with autograd disabled.
        threshold: Sigmoid probability at or above which a sample counts as
            malignant for the accuracy metric. The default 0.40 keeps the
            original behaviour. Being below 0.5, it presumably favours
            sensitivity.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over every sample seen.

    Raises:
        ValueError: If *loader* yields no samples (previously a bare
            ZeroDivisionError).
    """
    training = bool(training)
    model.train(training)  # train(False) is exactly model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images).squeeze(1)
            # labels: 0=benign, 1=malignant -> float targets for BCEWithLogitsLoss
            loss = criterion(logits, labels.float())
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = images.size(0)
            # criterion returns a per-batch mean; re-weight by batch size so the
            # epoch average is exact even when the last batch is short.
            total_loss += loss.item() * batch_size
            preds = (torch.sigmoid(logits) >= threshold).long()
            correct += (preds == labels).sum().item()
            total += batch_size
    if total == 0:
        raise ValueError("run_epoch: loader yielded no samples")
    return total_loss / total, correct / total
def _run_phase(
    phase: int,
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: ReduceLROnPlateau,
    device: torch.device,
    epochs: int,
    best_val_acc: float,
    output_path: Path,
):
    """Run one training phase with LR scheduling, checkpointing and early stopping.

    Each epoch trains on *train_loader*, evaluates on *val_loader* and steps
    *scheduler* on the validation loss. Whenever validation accuracy beats
    *best_val_acc* (which may carry over from an earlier phase), the
    state_dict is saved to *output_path*. The phase stops after
    PATIENCE_EARLY consecutive epochs without improvement.

    Returns:
        ``(best_val_acc, best_state)``: the best validation accuracy so far,
        and a CPU copy of the state_dict that reached it during THIS phase,
        or ``None`` if this phase never improved on the incoming best.
    """
    best_state = None
    no_improve = 0
    last_lr = optimizer.param_groups[0]["lr"]
    for epoch in range(1, epochs + 1):
        tr_loss, tr_acc = run_epoch(model, train_loader, criterion, optimizer, device, training=True)
        vl_loss, vl_acc = run_epoch(model, val_loader, criterion, optimizer, device, training=False)
        scheduler.step(vl_loss)
        # ReduceLROnPlateau's `verbose` flag is deprecated (warns since
        # PyTorch 2.2), so report LR reductions explicitly instead.
        lr = optimizer.param_groups[0]["lr"]
        if lr < last_lr:
            print(f" LR reduced to {lr:.2e}")
            last_lr = lr
        print(f"[P{phase} {epoch:02d}/{epochs}] loss={tr_loss:.4f} acc={tr_acc:.3f} | val_loss={vl_loss:.4f} val_acc={vl_acc:.3f}")
        if vl_acc > best_val_acc:
            best_val_acc = vl_acc
            torch.save(model.state_dict(), output_path)
            # Keep an in-memory copy so the caller can roll back to it.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print(f" -> Saved (val_acc={best_val_acc:.3f})")
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= PATIENCE_EARLY:
                print(f" Early stopping (Phase {phase})")
                break
    return best_val_acc, best_state


def train(data_dir: str, output_path: str) -> None:
    """Train AdvancedBreastCancerModel in two phases and save the best weights.

    Phase 1 freezes the two backbone streams and trains the remaining layers at
    LR_HEAD. Phase 2 unfreezes everything and fine-tunes at LR_FINE, starting
    from the best Phase-1 weights rather than the last Phase-1 epoch. The
    file at *output_path* (parent directories created as needed) is
    overwritten with the state_dict each time validation accuracy improves,
    across both phases.

    Args:
        data_dir: Dataset root passed to make_loaders (must contain train/ and val/).
        output_path: Destination for the best state_dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    train_loader, val_loader, _ = make_loaders(data_dir)
    model = AdvancedBreastCancerModel().to(device)
    criterion = nn.BCEWithLogitsLoss()
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # -- Phase 1: train head only ---------------------------------------------
    print("\n=== Phase 1: training classifier head (frozen backbones) ===")
    _freeze_backbones(model)
    optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=LR_HEAD)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-7)
    best_val_acc, best_state = _run_phase(
        1, model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, EPOCHS_HEAD, 0.0, output_path,
    )

    # Fine-tune from the best Phase-1 weights, not from whatever (possibly
    # degraded) state the last epoch before early stopping left behind.
    if best_state is not None:
        model.load_state_dict(best_state)

    # -- Phase 2: fine-tune all layers ----------------------------------------
    print("\n=== Phase 2: fine-tuning all layers ===")
    _unfreeze_all(model)
    optimizer = Adam(model.parameters(), lr=LR_FINE)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=PATIENCE_LR, min_lr=1e-8)
    best_val_acc, _ = _run_phase(
        2, model, train_loader, val_loader, criterion, optimizer, scheduler,
        device, EPOCHS_FINE, best_val_acc, output_path,
    )

    print(f"\nDone. Best val_acc={best_val_acc:.3f}")
    print(f"Weights -> {output_path}")
if __name__ == "__main__":
    # Command-line entry point: read the dataset root and weights destination,
    # then run the two-phase training.
    cli = argparse.ArgumentParser(description="Train SensiNet mammogram classifier")
    cli.add_argument(
        "--data",
        default="data",
        help="Root data dir (must contain train/ and val/)",
    )
    cli.add_argument(
        "--output",
        default="weights/advanced_model_best.pth",
        help="Output weights path",
    )
    opts = cli.parse_args()
    train(opts.data, opts.output)