| """ |
| Train CNN (ResNet-50 + GAP + 3 heads) on selected index. |
| Stratified 70/15/15 train/val/test by artist; checkpoints to checkpoints/. |
| Usage: python scripts/train_cnn.py [--arch cnn|cnnrnn] [--epochs N] [--resume] [--batch-size N] [--cpu] |
| --resume: load last.pt and train --epochs more (default 10). |
| --batch-size N: default 64; use 16 or 32 if MPS OOM on Mac. |
| --cpu: force CPU (avoids MPS out-of-memory on Apple Silicon). |
| |
| Persistence (under checkpoints/<run>/; see config.checkpoint_dir_for_arch): |
| - last.pt: latest epoch (model + optimizer) — used by --resume. |
| - best.pt: epoch with lowest val_loss — use for eval. |
| - train_log.csv: one row per epoch (loss + acc); append-only. |
| - results_summary.csv: best checkpoint metrics; written at end of run. |
| """ |
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import pandas as pd |
| import os |
| from datetime import datetime |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Subset |
| from torchvision import transforms as T |
| from sklearn.model_selection import train_test_split |
|
|
| ROOT = Path(__file__).resolve().parent.parent |
|
|
|
|
| def _atomic_torch_save(obj: object, path: Path) -> None: |
| """Write `path` via a temp file + `os.replace` so a kill mid-write does not truncate `last.pt` / `best.pt`.""" |
| path = Path(path) |
| path.parent.mkdir(parents=True, exist_ok=True) |
| tmp = path.with_suffix(path.suffix + ".tmp") |
| torch.save(obj, tmp) |
| os.replace(tmp, path) |
# Make the project's src/ directory importable before the `from config import ...` below.
sys.path.insert(0, str(ROOT / "src"))
|
|
| from config import ( |
| INDEX_SELECTED, |
| WIKIART_ROOT, |
| checkpoint_dir_for_arch, |
| N_STYLE, |
| N_ARTIST, |
| N_GENRE, |
| BATCH_SIZE, |
| LR_BACKBONE, |
| LR_HEADS, |
| MOMENTUM, |
| WEIGHT_DECAY, |
| EPOCHS, |
| COOLDOWN_EPOCHS, |
| GRAD_CLIP, |
| LOSS_WEIGHT_ARTIST, |
| IMAGENET_MEAN, |
| IMAGENET_STD, |
| ) |
| from dataset import WikiArtDataset |
| from model import ResNet50BiLSTMThreeHeads, ResNet50ThreeHeads |
|
|
|
|
def get_transforms(train: bool):
    """Build the torchvision transform pipeline.

    Training applies RandomResizedCrop + horizontal flip for augmentation;
    evaluation uses a deterministic Resize(256) / CenterCrop(224). Both end
    with ToTensor and ImageNet-statistics normalization.
    """
    normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    if train:
        steps = [
            T.RandomResizedCrop(224, scale=(0.08, 1.0)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
        ]
    else:
        steps = [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            normalize,
        ]
    return T.Compose(steps)
|
|
|
|
def stratified_split(df: pd.DataFrame, seed: int = 42):
    """Split `df`'s row indices 70/15/15 into train/val/test, stratified by artist_id."""
    all_idx = df.index.tolist()
    labels = df["artist_id"].values
    # First carve off a 30% holdout, then halve it into val and test.
    train_idx, holdout_idx = train_test_split(
        all_idx, test_size=0.3, stratify=labels, random_state=seed
    )
    holdout_labels = df.loc[holdout_idx, "artist_id"].values
    val_idx, test_idx = train_test_split(
        holdout_idx, test_size=0.5, stratify=holdout_labels, random_state=seed
    )
    return train_idx, val_idx, test_idx
|
|
|
|
def now_ts() -> str:
    """Current local time as 'YYYY-MM-DD HH:MM:SS', used as a log-line prefix."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
|
|
|
|
| def _scheduler_for_resume_legacy( |
| optimizer: torch.optim.Optimizer, |
| ckpt: dict, |
| cosine_t_max: int, |
| batches_per_epoch: int, |
| ) -> torch.optim.lr_scheduler.CosineAnnealingLR: |
| """Approximate cosine step count for checkpoints saved before `scheduler_state_dict` existed.""" |
| for g in optimizer.param_groups: |
| g.setdefault("initial_lr", g["lr"]) |
| e = int(ckpt.get("epoch", -1)) |
| if e < 0: |
| return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max) |
| if ckpt.get("interrupted"): |
| b = int(ckpt.get("batch_in_epoch") or 0) |
| done = e * batches_per_epoch + max(0, b) |
| else: |
| done = (e + 1) * batches_per_epoch |
| last = min(max(done - 1, -1), cosine_t_max - 1) |
| return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max, last_epoch=last) |
|
|
|
|
def main() -> None:
    """Train the ResNet-50 three-head model (optionally + BiLSTM) on the selected index.

    Pipeline: validate inputs -> parse CLI -> pick device -> load index and
    stratified-split it -> build loaders -> build model/optimizer/scheduler
    (optionally resuming from last.pt) -> train/val loop with atomic
    checkpointing -> write results_summary.csv from best.pt.

    Fixes vs. previous revision:
    - Resuming from an *interrupted* checkpoint (which stores val_loss=None)
      no longer poisons `best_val_loss` with None, which would make the later
      `val_loss < best_val_loss` comparison raise TypeError.
    - `ckpt.get("epoch", -1)` instead of `ckpt["epoch"]`, matching the
      tolerance already used by `_scheduler_for_resume_legacy`.
    """
    # --- Preconditions: index CSV and image root must exist ----------------
    if not INDEX_SELECTED.exists():
        print(f"[{now_ts()}] ERROR: {INDEX_SELECTED} not found. Run scripts/build_artgan_index.py first.")
        sys.exit(1)
    if not WIKIART_ROOT.exists():
        print(f"[{now_ts()}] ERROR: {WIKIART_ROOT} not found.")
        sys.exit(1)

    # --- CLI ----------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--arch", type=str, default="cnn", choices=["cnn", "cnnrnn"], help="Model architecture")
    parser.add_argument("--epochs", type=int, default=None, help="Total epochs (or extra if --resume)")
    parser.add_argument("--resume", action="store_true", help="Load last.pt and train --epochs more (default 10)")
    parser.add_argument("--batch-size", type=int, default=None, help="Batch size (default from config). Use 16 or 32 if MPS OOM.")
    parser.add_argument("--cpu", action="store_true", help="Force CPU (avoids MPS out-of-memory on Mac)")
    args = parser.parse_args()

    # --- Device selection: explicit --cpu > CUDA > MPS > CPU ----------------
    if args.cpu:
        # Disable the MPS fallback so nothing silently lands on MPS.
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "0"
        device = torch.device("cpu")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    batch_size = args.batch_size if args.batch_size is not None else BATCH_SIZE
    print(f"[{now_ts()}] [device] Using device={device} (args.cpu={args.cpu})")
    ckpt_dir = checkpoint_dir_for_arch(args.arch)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    print(f"[{now_ts()}] [checkpoint dir] {ckpt_dir}")

    # --- Data: load index, stratified 70/15/15 split, loaders ---------------
    df = pd.read_csv(INDEX_SELECTED)
    print(f"[{now_ts()}] [1/5] Loaded index: {len(df):,} rows from {INDEX_SELECTED.name}")
    idx_train, idx_val, idx_test = stratified_split(df)
    print(f"[{now_ts()}] [2/5] Split: train {len(idx_train):,}, val {len(idx_val):,}, test {len(idx_test):,}")

    # Two dataset instances so train/val get different transforms.
    train_ds = WikiArtDataset(INDEX_SELECTED, WIKIART_ROOT, transform=get_transforms(train=True))
    val_ds = WikiArtDataset(INDEX_SELECTED, WIKIART_ROOT, transform=get_transforms(train=False))
    train_subset = Subset(train_ds, idx_train)
    val_subset = Subset(val_ds, idx_val)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=(device.type == "cuda"))
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=0)
    print(
        f"[{now_ts()}] [3/5] Data loaders ready: {len(train_loader)} train batches, {len(val_loader)} val batches (batch_size={batch_size})"
    )

    # --- Resume bookkeeping --------------------------------------------------
    start_epoch = 0
    best_val_loss = float("inf")
    resume_path = ckpt_dir / "last.pt"

    ckpt: dict | None = None
    if args.resume and resume_path.exists():
        try:
            # weights_only=False: checkpoint also carries optimizer/scheduler
            # state and plain-Python metadata.
            ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        except Exception as e:
            print(
                f"[{now_ts()}] ERROR: Cannot load {resume_path} (often a truncated file if the process was killed "
                f"during `torch.save`).\n"
                f"  {e}\n"
                f"  Fix: if best.pt is intact, copy it over last.pt and resume, e.g.\n"
                f"    cp {ckpt_dir / 'best.pt'} {resume_path}",
                file=sys.stderr,
            )
            sys.exit(1)
        # Tolerate a missing epoch key (same fallback as the legacy scheduler
        # helper): treat it as "nothing finished yet".
        start_epoch = int(ckpt.get("epoch", -1)) + 1
        # BUGFIX: interrupted checkpoints store val_loss=None; using it as the
        # running best would make `val_loss < best_val_loss` raise TypeError.
        resumed_val_loss = ckpt.get("val_loss")
        best_val_loss = resumed_val_loss if resumed_val_loss is not None else float("inf")
        extra = args.epochs if args.epochs is not None else 10
        total_epochs_this_run = extra
        print(f"[{now_ts()}] Resuming from epoch {start_epoch - 1}, training {total_epochs_this_run} more epochs.")
    else:
        if args.resume:
            print(f"[{now_ts()}] WARNING: --resume but no last.pt found; starting from scratch.")
        extra = None
        total_epochs_this_run = args.epochs if args.epochs is not None else EPOCHS + COOLDOWN_EPOCHS

    # --- Model + two-group optimizer (small LR on backbone, larger on heads) -
    if args.arch == "cnnrnn":
        model = ResNet50BiLSTMThreeHeads(n_genre=N_GENRE, n_style=N_STYLE, n_artist=N_ARTIST).to(device)
        backbone_params = list(model.backbone.parameters())
        # The LSTM is trained at the head learning rate, like the classifiers.
        head_params = list(model.lstm.parameters())
        head_params += list(model.genre_head.parameters())
        head_params += list(model.style_head.parameters())
        head_params += list(model.artist_head.parameters())
    else:
        model = ResNet50ThreeHeads(n_genre=N_GENRE, n_style=N_STYLE, n_artist=N_ARTIST).to(device)
        backbone_params = list(model.backbone.parameters()) + list(model.pool.parameters())
        head_params = (
            list(model.genre_head.parameters())
            + list(model.style_head.parameters())
            + list(model.artist_head.parameters())
        )
    optimizer = torch.optim.SGD(
        [
            {"params": backbone_params, "lr": LR_BACKBONE},
            {"params": head_params, "lr": LR_HEADS},
        ],
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
    )
    if args.resume and resume_path.exists() and ckpt is not None:
        model.load_state_dict(ckpt["model_state_dict"])
        if "optimizer_state_dict" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    # --- LR schedule: per-batch cosine over the full (EPOCHS+cooldown) budget -
    cosine_t_max = (EPOCHS + COOLDOWN_EPOCHS) * len(train_loader)
    if args.resume and ckpt is not None and "scheduler_state_dict" in ckpt:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max)
        try:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        except Exception as e:
            print(f"[{now_ts()}] WARNING: Could not load scheduler state ({e}); reinitializing LR schedule from epoch.")
            scheduler = _scheduler_for_resume_legacy(optimizer, ckpt, cosine_t_max, len(train_loader))
    elif args.resume and ckpt is not None:
        # Old checkpoints have no scheduler state: approximate from epoch/batch.
        scheduler = _scheduler_for_resume_legacy(optimizer, ckpt, cosine_t_max, len(train_loader))
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_t_max)

    sched_step = getattr(scheduler, "last_epoch", -1) + 1
    print(
        f"[{now_ts()}] [4/5] Model and optimizer ready. "
        f"Scheduler: CosineAnnealingLR (T_max={cosine_t_max} steps ≈ {EPOCHS}+{COOLDOWN_EPOCHS} epochs × batches; "
        f"next step index {sched_step})"
    )

    def _log_next_steps() -> None:
        # Printed on every exit path (see `finally` below).
        print("\n--- Persistence (checkpoint dir) ---")
        print(f"  arch dir = {ckpt_dir}")
        print("  last.pt = latest epoch (for --resume)")
        print("  best.pt = best val_loss (for eval)")
        print("  train_log.csv = per-epoch metrics (append)")
        print("  results_summary.csv = best metrics (written at end)")
        print(f"To resume training: python scripts/train_cnn.py --arch {args.arch} --resume --epochs N")
        print(f"To evaluate best: python scripts/eval_cnn.py --arch {args.arch}")

    print(f"[{now_ts()}] [5/5] Device: {device} Arch: {args.arch} Saving to: {ckpt_dir.resolve()}")
    if start_epoch > 0:
        print(f"[{now_ts()}] Resumed from epoch {start_epoch - 1}; running {total_epochs_this_run} more epochs.")
    else:
        print(f"[{now_ts()}] Starting from scratch; running {total_epochs_this_run} epochs.")
    print()

    # Progress trackers so the Ctrl+C handler can save a resumable checkpoint.
    current_epoch: int | None = None
    current_batch_in_epoch: int | None = None
    current_num_batches_in_epoch: int | None = None

    try:
        for i in range(total_epochs_this_run):
            epoch = start_epoch + i
            current_epoch = epoch
            current_batch_in_epoch = 0
            current_num_batches_in_epoch = len(train_loader)
            print(f"[{now_ts()}] --- Epoch {epoch} (step {i + 1}/{total_epochs_this_run}) ---")

            # --- Train phase ---------------------------------------------
            model.train()
            train_loss = 0.0
            for b, (images, style_id, artist_id, genre_id) in enumerate(train_loader, start=1):
                current_batch_in_epoch = b
                images = images.to(device)
                style_id = style_id.to(device)
                artist_id = artist_id.to(device)
                genre_id = genre_id.to(device)
                optimizer.zero_grad()
                logits_g, logits_s, logits_a = model(images)
                loss_g = F.cross_entropy(logits_g, genre_id)
                loss_s = F.cross_entropy(logits_s, style_id)
                loss_a = F.cross_entropy(logits_a, artist_id)
                # Multi-task loss; artist term gets a configurable weight.
                loss = loss_g + loss_s + LOSS_WEIGHT_ARTIST * loss_a
                loss.backward()
                if GRAD_CLIP > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()
                scheduler.step()  # cosine schedule advances per batch
                train_loss += loss.item()
            train_loss /= len(train_loader)

            # --- Validation phase ----------------------------------------
            model.eval()
            val_loss = 0.0
            correct_g = correct_s = correct_a = correct_a5 = total = 0
            with torch.no_grad():
                for images, style_id, artist_id, genre_id in val_loader:
                    images = images.to(device)
                    style_id = style_id.to(device)
                    artist_id = artist_id.to(device)
                    genre_id = genre_id.to(device)
                    logits_g, logits_s, logits_a = model(images)
                    loss = (
                        F.cross_entropy(logits_g, genre_id)
                        + F.cross_entropy(logits_s, style_id)
                        + LOSS_WEIGHT_ARTIST * F.cross_entropy(logits_a, artist_id)
                    )
                    val_loss += loss.item()
                    n = images.size(0)
                    total += n
                    correct_g += (logits_g.argmax(1) == genre_id).sum().item()
                    correct_s += (logits_s.argmax(1) == style_id).sum().item()
                    correct_a += (logits_a.argmax(1) == artist_id).sum().item()
                    # Top-5 artist accuracy: hit if the true id is anywhere in top-5.
                    _, top5 = logits_a.topk(5, dim=1)
                    correct_a5 += (top5 == artist_id.unsqueeze(1)).any(1).sum().item()
            val_loss /= len(val_loader)
            acc_g = correct_g / total
            acc_s = correct_s / total
            acc_a = correct_a / total
            acc_a5 = correct_a5 / total

            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            # --- Checkpoint (atomic writes; best.pt only on improvement) --
            ckpt = {
                "epoch": epoch,
                "arch": args.arch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "val_loss": val_loss,
                "val_genre_acc": acc_g,
                "val_style_acc": acc_s,
                "val_artist_acc": acc_a,
                "val_artist_top5_acc": acc_a5,
                "n_genre": N_GENRE,
                "n_style": N_STYLE,
                "n_artist": N_ARTIST,
            }
            _atomic_torch_save(ckpt, ckpt_dir / "last.pt")
            if is_best:
                _atomic_torch_save(ckpt, ckpt_dir / "best.pt")

            # --- Append per-epoch metrics to train_log.csv ----------------
            log_row = {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "val_loss": round(val_loss, 4),
                "genre_acc": round(acc_g, 4),
                "style_acc": round(acc_s, 4),
                "artist_acc": round(acc_a, 4),
                "artist_top5_acc": round(acc_a5, 4),
            }
            log_path = ckpt_dir / "train_log.csv"
            # Write the header only on the first epoch of a fresh log.
            pd.DataFrame([log_row]).to_csv(log_path, mode="a", header=not log_path.exists(), index=False)

            save_msg = " [saved last.pt" + (" + best.pt" if is_best else "") + "]"
            print(
                f"[{now_ts()}] train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
                f"genre={acc_g:.2%} style={acc_s:.2%} artist={acc_a:.2%} artist_top5={acc_a5:.2%} best={is_best}{save_msg}"
            )
    except KeyboardInterrupt:
        # Save a resumable snapshot with the mid-epoch position recorded, so
        # the legacy scheduler helper can approximate the LR step on resume.
        interrupted_ckpt = {
            "epoch": current_epoch if current_epoch is not None else -1,
            "arch": args.arch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "val_loss": None,
            "val_genre_acc": None,
            "val_style_acc": None,
            "val_artist_acc": None,
            "val_artist_top5_acc": None,
            "n_genre": N_GENRE,
            "n_style": N_STYLE,
            "n_artist": N_ARTIST,
            "interrupted": True,
            "batch_in_epoch": current_batch_in_epoch,
            "num_batches_in_epoch": current_num_batches_in_epoch,
        }
        _atomic_torch_save(interrupted_ckpt, ckpt_dir / "last.pt")
        print(
            "\n"
            f"[{now_ts()}] Stopped by user (Ctrl+C). Saved resumable checkpoint to "
            f"{(ckpt_dir / 'last.pt').resolve()}"
        )
    finally:
        _log_next_steps()

    # --- Summarize the best checkpoint into results_summary.csv --------------
    best_ckpt_path = ckpt_dir / "best.pt"
    if best_ckpt_path.exists():
        best_ckpt = torch.load(best_ckpt_path, map_location="cpu", weights_only=False)
        summary = {
            "best_epoch": best_ckpt.get("epoch"),
            "val_loss": best_ckpt.get("val_loss"),
            "val_genre_acc": best_ckpt.get("val_genre_acc"),
            "val_style_acc": best_ckpt.get("val_style_acc"),
            "val_artist_acc": best_ckpt.get("val_artist_acc"),
            "val_artist_top5_acc": best_ckpt.get("val_artist_top5_acc"),
        }
        pd.DataFrame([summary]).to_csv(ckpt_dir / "results_summary.csv", index=False)
        print(f"[{now_ts()}] Results summary:", ckpt_dir / "results_summary.csv")
    print(f"[{now_ts()}] Done. Best checkpoint:", best_ckpt_path)
|
|