Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| Rangoli Classification Training Pipeline | |
| ============================================================ | |
| Full training loop with: | |
| - Mixed precision training (AMP) | |
| - Cosine annealing with warm restarts | |
| - Learning rate warmup | |
| - Gradient clipping | |
| - MixUp / CutMix augmentation | |
| - Early stopping | |
| - TensorBoard logging | |
| - Checkpoint management | |
| - Progressive unfreezing | |
| Usage: | |
| python scripts/train.py --config configs/config.yaml --model resnet50 | |
| python scripts/train.py --config configs/config.yaml --model efficientnet_b3 --gpu 0 | |
| python scripts/train.py --config configs/config.yaml --model all # Train all models | |
| ============================================================ | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import yaml | |
| import time | |
| import argparse | |
| import numpy as np | |
| from datetime import datetime | |
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.cuda.amp import GradScaler, autocast | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
| # Add project root to path | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from dataset.rangoli_dataset import create_dataloaders, MixUpCutMix | |
| from models.classifier import build_model, build_loss_function | |
class AverageMeter:
    """Keep a running, count-weighted mean of a scalar metric.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of ``val * n`` over all updates.
        count: Total weight (sum of ``n``) seen so far.
        avg: ``sum / count``; 0 before the first update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic so the meter can be reused."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``should_stop`` is latched to True.
        min_delta: Margin by which a new score must beat the best score to
            count as an improvement.
        mode: ``"max"`` when larger scores are better; any other value is
            treated as "smaller is better".
    """

    def __init__(self, patience=10, min_delta=0.001, mode="max"):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # best score so far; None until the first call
        self.should_stop = False  # latched True once patience is exhausted

    def _beats_best(self, score):
        """Return True if ``score`` improves on ``best_score`` by more than ``min_delta``."""
        if self.mode != "max":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def __call__(self, score):
        """Feed a new score and return whether training should stop.

        The first call only records a baseline and always returns False.
        """
        if self.best_score is None:
            self.best_score = score
            return False

        if not self._beats_best(score):
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
            return self.should_stop

        self.best_score = score
        self.counter = 0
        return self.should_stop
def get_optimizer(model, config):
    """Build the optimizer with discriminative (layer-wise) learning rates.

    Each group returned by ``model.get_layer_groups()`` must provide
    ``"params"`` and ``"lr_scale"``. Its learning rate is
    ``training.learning_rate * lr_scale``. The scale is also stored on the
    optimizer param group under ``"lr_scale"``, so code that later rewrites
    learning rates (e.g. linear warmup) can keep the per-group ratios.

    Args:
        model: Object exposing ``get_layer_groups()``.
        config: Parsed config. Reads ``training.learning_rate``,
            ``training.optimizer`` (``"adamw"`` or ``"sgd"``) and
            ``training.weight_decay``.

    Returns:
        A configured ``torch.optim.AdamW`` or ``torch.optim.SGD``.

    Raises:
        ValueError: If ``training.optimizer`` is not a supported name.
    """
    training_cfg = config["training"]
    base_lr = training_cfg["learning_rate"]
    optimizer_name = training_cfg["optimizer"]

    param_groups = [
        {
            "params": group["params"],
            "lr": base_lr * group["lr_scale"],
            # torch optimizers keep extra keys on param groups; warmup reads it back.
            "lr_scale": group["lr_scale"],
        }
        for group in model.get_layer_groups()
    ]

    if optimizer_name == "adamw":
        return optim.AdamW(
            param_groups,
            lr=base_lr,
            weight_decay=training_cfg["weight_decay"],
        )
    if optimizer_name == "sgd":
        return optim.SGD(
            param_groups,
            lr=base_lr,
            momentum=0.9,
            weight_decay=training_cfg["weight_decay"],
            nesterov=True,
        )
    # Previously an unknown name fell through to an UnboundLocalError.
    raise ValueError(
        f"Unsupported optimizer {optimizer_name!r}; expected 'adamw' or 'sgd'"
    )
def get_scheduler(optimizer, config, steps_per_epoch=None):
    """Create the learning-rate scheduler named by ``training.scheduler``.

    Supported names:
      * ``"cosine_annealing_warm_restarts"``: reads ``T_0``, ``T_mult`` and ``eta_min``.
      * ``"cosine_annealing"``: ``T_max = num_epochs``; reads ``eta_min``.
      * ``"one_cycle"``: each param group's current ``lr`` becomes its
        ``max_lr``, so the discriminative per-group ratios set by the
        optimizer are preserved. This matches how the cosine schedulers scale
        each group's own base LR.

    Args:
        optimizer: The optimizer to schedule.
        config: Parsed config; reads keys under ``training``.
        steps_per_epoch: OneCycleLR only. Number of ``scheduler.step()``
            calls per epoch. Defaults to ``training.steps_per_epoch`` if
            present, otherwise 100 (the previous hard-coded value).
            NOTE(review): ``train_one_epoch`` appears to step the scheduler
            once per epoch, in which case 1 would be the matching value.
            Confirm before relying on one_cycle.

    Returns:
        A ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If ``training.scheduler`` is not a supported name.
    """
    training_cfg = config["training"]
    scheduler_name = training_cfg["scheduler"]

    if scheduler_name == "cosine_annealing_warm_restarts":
        return optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=training_cfg["T_0"],
            T_mult=training_cfg["T_mult"],
            eta_min=training_cfg["eta_min"],
        )
    if scheduler_name == "cosine_annealing":
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=training_cfg["num_epochs"],
            eta_min=training_cfg["eta_min"],
        )
    if scheduler_name == "one_cycle":
        if steps_per_epoch is None:
            steps_per_epoch = training_cfg.get("steps_per_epoch", 100)
        return optim.lr_scheduler.OneCycleLR(
            optimizer,
            # A scalar max_lr would force every group to the same peak LR.
            max_lr=[group["lr"] for group in optimizer.param_groups],
            epochs=training_cfg["num_epochs"],
            steps_per_epoch=steps_per_epoch,
        )
    # Previously an unknown name fell through to an UnboundLocalError.
    raise ValueError(f"Unsupported scheduler {scheduler_name!r}")
def warmup_lr(optimizer, epoch, warmup_epochs, warmup_lr_val, base_lr):
    """Linearly ramp the learning rate from ``warmup_lr_val`` toward ``base_lr``.

    For ``epoch < warmup_epochs`` every param group's ``lr`` is set to
    ``lr * group.get("lr_scale", 1.0)``, where
    ``lr = warmup_lr_val + (base_lr - warmup_lr_val) * epoch / warmup_epochs``.
    Groups that carry an ``"lr_scale"`` key therefore keep their
    discriminative ratio. Outside the warmup window this is a no-op.

    Args:
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        epoch: Zero-based epoch index.
        warmup_epochs: Length of the warmup window; ``<= 0`` disables warmup.
        warmup_lr_val: Learning rate at epoch 0 (before scaling).
        base_lr: Target learning rate reached at ``epoch == warmup_epochs``.
    """
    if warmup_epochs <= 0 or epoch >= warmup_epochs:
        return
    lr = warmup_lr_val + (base_lr - warmup_lr_val) * epoch / warmup_epochs
    for param_group in optimizer.param_groups:
        # Read the scale directly. The previous `"lr_scale" in str(param_group)`
        # stringified every parameter tensor in the group on each call.
        param_group["lr"] = lr * param_group.get("lr_scale", 1.0)
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler,
                    scaler, mixup_cutmix, device, epoch, config):
    """Run one optimization pass over ``train_loader``.

    When ``mixup_cutmix`` is supplied, each batch is passed through it with
    probability 0.5 (``np.random.random() < 0.5``) and the loss uses the
    targets it returns. The forward pass runs under ``autocast`` only when
    ``training.use_amp`` (default True) is set AND the device is CUDA. In that
    case the backward/step goes through ``scaler``. Gradients are clipped to
    ``training.max_grad_norm`` when that key is truthy. ``scheduler`` (if not
    None) is stepped once after the loop.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)``. Accuracy only counts batches
        that were not mixed, and stays 0 if every batch was mixed.
    """
    model.train()
    losses = AverageMeter()
    accuracies = AverageMeter()
    training_cfg = config["training"]
    amp_on = training_cfg.get("use_amp", True) and device.type == "cuda"
    clip_norm = training_cfg.get("max_grad_norm")

    progress = tqdm(train_loader, desc=f" Train Epoch {epoch+1}", leave=False)
    for images, targets in progress:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        mixed = mixup_cutmix is not None and np.random.random() < 0.5
        loss_targets = targets
        if mixed:
            images, loss_targets = mixup_cutmix(images, targets)

        with autocast(enabled=amp_on):
            outputs = model(images)
            loss = criterion(outputs, loss_targets)

        optimizer.zero_grad()
        if not amp_on:
            loss.backward()
            if clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            if clip_norm:
                # Clipping must see real (unscaled) gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            scaler.step(optimizer)
            scaler.update()

        # Hard-label accuracy is only meaningful for batches that were not mixed.
        if not mixed:
            batch_size = targets.size(0)
            predicted = outputs.max(1)[1]
            n_correct = predicted.eq(targets).sum().item()
            accuracies.update(n_correct / batch_size, batch_size)

        losses.update(loss.item(), images.size(0))
        acc_text = f"{accuracies.avg:.4f}" if accuracies.count > 0 else "N/A"
        progress.set_postfix({
            "loss": f"{losses.avg:.4f}",
            "acc": acc_text,
            "lr": f"{optimizer.param_groups[-1]['lr']:.6f}",
        })

    if scheduler is not None:
        scheduler.step()

    return losses.avg, accuracies.avg
def validate(model, val_loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: Model to evaluate; switched to ``eval()`` mode.
        val_loader: Iterable of ``(images, targets)`` batches.
        criterion: Loss function applied as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        use_amp: Run the forward pass under ``autocast`` (CUDA only).

    Returns:
        Tuple ``(mean_loss, accuracy, preds, targets)``. ``preds`` and
        ``targets`` are NumPy arrays of the argmax predictions and the true
        labels, in loader order.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    all_preds = []
    all_targets = []
    amp_enabled = use_amp and device.type == "cuda"

    # Without no_grad every validation batch also built an autograd graph,
    # costing extra memory and compute for gradients that are never used.
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc=" Validate", leave=False):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            with autocast(enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, targets)

            _, predicted = outputs.max(1)
            correct = predicted.eq(targets).sum().item()

            loss_meter.update(loss.item(), images.size(0))
            acc_meter.update(correct / targets.size(0), targets.size(0))
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    return loss_meter.avg, acc_meter.avg, np.array(all_preds), np.array(all_targets)
def _atomic_torch_save(obj, path):
    """Serialize ``obj`` to ``path`` via a temp file and rename, so a crash never leaves a truncated file."""
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace is atomic when source and destination share a filesystem.
        os.replace(tmp_path, path)
    finally:
        # Only still present if torch.save or os.replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(model, optimizer, scheduler, epoch, val_acc, val_loss,
                    config, model_name, save_dir, is_best=False):
    """Write ``<model_name>_latest.pth`` and, if ``is_best``, ``<model_name>_best.pth``.

    The checkpoint dict holds the epoch, model name, architecture (from
    ``config["models"][model_name]["architecture"]``), ``num_classes``, the
    model/optimizer/scheduler state dicts (scheduler may be None), the
    validation metrics and the full config. Files are written atomically, so
    an interrupted save keeps the previous checkpoint intact.

    Raises:
        KeyError: If ``config`` lacks the model entry or ``num_classes``.
    """
    os.makedirs(save_dir, exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "model_name": model_name,
        "architecture": config["models"][model_name]["architecture"],
        "num_classes": config["num_classes"],
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler else None,
        "val_acc": val_acc,
        "val_loss": val_loss,
        "config": config,
    }
    _atomic_torch_save(checkpoint, os.path.join(save_dir, f"{model_name}_latest.pth"))
    if is_best:
        _atomic_torch_save(checkpoint, os.path.join(save_dir, f"{model_name}_best.pth"))
        print(f" >> Saved new best model: val_acc={val_acc:.4f}")
def train_model(model_name, config, device):
    """Train one model end to end and return its best validation accuracy.

    Phase 1 runs ``min(5, num_epochs // 5)`` epochs after
    ``model.freeze_backbone()``, without MixUp/CutMix. It is logged to
    TensorBoard only, not to ``history`` or checkpoints.

    Phase 2 calls ``model.unfreeze_backbone(unfreeze_from=0.5)``, rebuilds the
    optimizer/scheduler and runs up to ``num_epochs`` epochs. These use linear
    LR warmup, further unfreezing (``unfreeze_from=0.25`` at
    ``num_epochs // 4``, ``0.0`` at ``num_epochs // 2``), MixUp/CutMix, a
    checkpoint every epoch and early stopping on validation accuracy.

    Outputs go to ``<paths.checkpoints>/<model>_<timestamp>`` (checkpoints and
    ``training_history.json``) and ``<paths.logs>/<model>_<timestamp>``
    (TensorBoard).

    Returns:
        Tuple ``(best_val_acc, history)``. ``history`` holds per-epoch
        Phase 2 ``train_loss``/``train_acc``/``val_loss``/``val_acc`` lists.
    """
    print(f"\n{'#'*60}")
    print(f" TRAINING: {model_name.upper()}")
    print(f"{'#'*60}")
    training_cfg = config["training"]

    # Per-run output directories, keyed by model name and start time.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{model_name}_{timestamp}"
    checkpoint_dir = os.path.join(config["paths"]["checkpoints"], run_name)
    log_dir = os.path.join(config["paths"]["logs"], run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    writer = SummaryWriter(log_dir)
    # try/finally so the writer is closed (and events flushed) even if training raises.
    try:
        # ---- Data ----
        manifest_path = os.path.join(config["paths"]["processed_data"], "dataset_manifest.json")
        train_loader, val_loader, _test_loader, _class_to_idx = create_dataloaders(config, manifest_path)

        class_weights = None
        if os.path.exists(manifest_path):
            with open(manifest_path, encoding="utf-8") as f:
                manifest = json.load(f)
            class_weights = manifest.get("class_weights")

        # ---- Model, loss, optimization ----
        model = build_model(model_name, config).to(device)
        criterion = build_loss_function(config, class_weights, device)
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        scaler = GradScaler(enabled=training_cfg.get("use_amp", True) and device.type == "cuda")
        mixup_cutmix = MixUpCutMix(
            mixup_alpha=config["augmentation"].get("mixup_alpha", 0.2),
            cutmix_alpha=config["augmentation"].get("cutmix_alpha", 1.0),
            num_classes=config["num_classes"],
        )
        early_stopping = EarlyStopping(
            patience=training_cfg["early_stopping_patience"], mode="max"
        )

        # ========== Phase 1: Frozen Backbone ==========
        print("\n --- Phase 1: Training classifier head (backbone frozen) ---")
        model.freeze_backbone()
        frozen_epochs = min(5, training_cfg["num_epochs"] // 5)
        best_val_acc = 0.0
        history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

        for epoch in range(frozen_epochs):
            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, scheduler,
                scaler, None, device, epoch, config  # No mixup for frozen phase
            )
            val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)
            print(f" Epoch {epoch+1}/{frozen_epochs} | "
                  f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")
            writer.add_scalars("Phase1/Loss", {"train": train_loss, "val": val_loss}, epoch)
            writer.add_scalars("Phase1/Accuracy", {"train": train_acc, "val": val_acc}, epoch)

        # ========== Phase 2: Gradual Unfreezing ==========
        print("\n --- Phase 2: Fine-tuning (progressive unfreezing) ---")
        model.unfreeze_backbone(unfreeze_from=0.5)
        # Fresh optimizer/scheduler with discriminative LRs for the fine-tuning phase.
        # NOTE(review): they are not rebuilt at the later unfreeze milestones.
        # If get_layer_groups() only returns trainable params, parameters
        # unfrozen there would never be optimized. Confirm in models.classifier.
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)

        total_epochs = training_cfg["num_epochs"]
        partial_unfreeze_epoch = total_epochs // 4
        full_unfreeze_epoch = total_epochs // 2

        for epoch in range(total_epochs):
            warmup_lr(optimizer, epoch,
                      training_cfg.get("warmup_epochs", 5),
                      training_cfg.get("warmup_lr", 1e-5),
                      training_cfg["learning_rate"])

            # Check the full unfreeze first: with num_epochs == 1 both
            # milestones are epoch 0, and the old order never fully unfroze.
            if epoch == full_unfreeze_epoch:
                model.unfreeze_backbone(unfreeze_from=0.0)  # Fully unfreeze
            elif epoch == partial_unfreeze_epoch:
                model.unfreeze_backbone(unfreeze_from=0.25)

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, scheduler,
                scaler, mixup_cutmix, device, epoch, config
            )
            val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

            history["train_loss"].append(train_loss)
            history["train_acc"].append(train_acc)
            history["val_loss"].append(val_loss)
            history["val_acc"].append(val_acc)

            writer.add_scalars("Phase2/Loss", {"train": train_loss, "val": val_loss}, epoch)
            writer.add_scalars("Phase2/Accuracy", {"train": train_acc, "val": val_acc}, epoch)
            writer.add_scalar("LR", optimizer.param_groups[-1]["lr"], epoch)

            is_best = val_acc > best_val_acc
            if is_best:
                best_val_acc = val_acc
            save_checkpoint(
                model, optimizer, scheduler, epoch, val_acc, val_loss,
                config, model_name, checkpoint_dir, is_best
            )

            print(f" Epoch {epoch+1}/{total_epochs} | "
                  f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
                  f"Best: {best_val_acc:.4f} {'*' if is_best else ''}")

            if early_stopping(val_acc):
                print(f"\n >> Early stopping at epoch {epoch+1}")
                break

        history_path = os.path.join(checkpoint_dir, "training_history.json")
        with open(history_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
    finally:
        writer.close()

    print(f"\n {'='*50}")
    print(f" TRAINING COMPLETE: {model_name}")
    print(f" Best Validation Accuracy: {best_val_acc:.4f}")
    print(f" Checkpoints: {checkpoint_dir}")
    print(f" TensorBoard: {log_dir}")
    print(f" {'='*50}")
    return best_val_acc, history
def main():
    """Command-line entry point: parse args, load the YAML config, pick a device and train.

    ``--model all`` trains every entry under ``config["models"]`` and writes a
    ranking to ``<paths.reports>/comparative_results.json``. Any other
    ``--model`` must have an entry under ``config["models"]``.
    """
    parser = argparse.ArgumentParser(description="Train Rangoli Classifier")
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--model", type=str, default="resnet50",
                        choices=["resnet50", "efficientnet_b3", "vit_base",
                                 "convnext_small", "mobilenet_v3", "swin_transformer", "all"])
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # save_checkpoint() indexes config["models"][model_name]. Without this
    # check, a missing entry only crashed after the first fine-tuning epoch.
    if args.model != "all" and args.model not in (config.get("models") or {}):
        parser.error(f"--model {args.model!r} has no entry under 'models' in {args.config}")

    # --resume is accepted for CLI compatibility, but nothing consumes it yet.
    if args.resume:
        print(f" WARNING: --resume is not implemented; ignoring {args.resume} and training from scratch.")

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
        print(f" Using GPU: {torch.cuda.get_device_name(args.gpu)}")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
        print(" Using Apple MPS")
    else:
        device = torch.device("cpu")
        print(" Using CPU (training will be slow)")

    if args.model != "all":
        train_model(args.model, config, device)
        return

    results = {}
    for model_name in config["models"]:
        best_acc, history = train_model(model_name, config, device)
        results[model_name] = {"best_val_acc": best_acc, "epochs": len(history["val_acc"])}

    print("\n" + "="*60)
    print(" COMPARATIVE RESULTS")
    print("="*60)
    for name, res in sorted(results.items(), key=lambda x: x[1]["best_val_acc"], reverse=True):
        print(f" {name:25s} : {res['best_val_acc']:.4f} ({res['epochs']} epochs)")

    results_path = os.path.join(config["paths"]["reports"], "comparative_results.json")
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    main()