Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn, Tensor | |
| from torch.optim import SGD, Adam, AdamW, RAdam | |
| from torch.amp import GradScaler | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from functools import partial | |
| from argparse import ArgumentParser | |
| import os, sys, math | |
| from typing import Union, Tuple, Dict, List, Optional | |
| from collections import OrderedDict | |
| parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| sys.path.append(parent_dir) | |
| import losses | |
| def _check_lr(lr: float, eta_min: float) -> None: | |
| assert lr > eta_min > 0, f"lr and eta_min must satisfy 0 < eta_min < lr, got lr={lr} and eta_min={eta_min}." | |
| def _check_warmup(warmup_epochs: int, warmup_lr: float) -> None: | |
| assert warmup_epochs >= 0, f"warmup_epochs must be non-negative, got {warmup_epochs}." | |
| assert warmup_lr > 0, f"warmup_lr must be positive, got {warmup_lr}." | |
def _warmup_lr(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
) -> float:
    """
    Log-linear warmup from ``warmup_lr`` towards ``base_lr``.

    The learning rate is interpolated linearly in log space, so it grows by a
    constant factor per epoch rather than by a constant increment.

    Args:
        epoch: Current epoch index (non-negative).
        base_lr: Learning rate reached once warmup is over.
        warmup_epochs: Number of warmup epochs.
        warmup_lr: Learning rate at epoch 0 of the warmup.

    Returns:
        The absolute learning rate for ``epoch`` (``base_lr`` once
        ``epoch >= warmup_epochs``). Unlike the schedule functions that call
        this helper, the value is NOT divided by ``base_lr``.

    Raises:
        AssertionError: If ``epoch`` is negative or the warmup settings are invalid.
    """
    base_lr, warmup_lr = float(base_lr), float(warmup_lr)
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    _check_warmup(warmup_epochs, warmup_lr)

    if epoch >= warmup_epochs:
        return base_lr

    # Interpolate between log(warmup_lr) and log(base_lr), then map back.
    log_start = math.log(warmup_lr)
    log_end = math.log(base_lr)
    return math.exp(log_start + epoch * (log_end - log_start) / warmup_epochs)
def step_decay(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    step_size: int,
    gamma: float,
    eta_min: float,
) -> float:
    """
    Warmup + Step Decay

    During the first ``warmup_epochs`` epochs the rate follows ``_warmup_lr``.
    Afterwards it is ``base_lr * gamma ** k`` where ``k`` is the number of
    completed ``step_size``-epoch steps since warmup ended, floored at
    ``eta_min``.

    Args:
        epoch: Current epoch index (non-negative).
        base_lr: Peak learning rate.
        warmup_epochs: Number of warmup epochs.
        warmup_lr: Learning rate at the start of warmup.
        step_size: Number of epochs between decays (>= 1).
        gamma: Multiplicative decay factor, strictly between 0 and 1.
        eta_min: Lower bound applied to the decayed learning rate.

    Returns:
        The learning rate divided by ``base_lr`` (a multiplicative factor).

    Raises:
        AssertionError: If any argument is out of range.
    """
    base_lr, warmup_lr, eta_min = float(base_lr), float(warmup_lr), float(eta_min)
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    assert step_size >= 1, f"step_size must be greater than or equal to 1, got {step_size}."
    assert 0 < gamma < 1, f"gamma must be in the range (0, 1), got {gamma}."
    _check_lr(base_lr, eta_min)
    _check_warmup(warmup_epochs, warmup_lr)

    if epoch < warmup_epochs:
        return _warmup_lr(epoch, base_lr, warmup_epochs, warmup_lr) / base_lr

    # NOTE(review): the eta_min floor is applied to the decay phase only; the
    # source's indentation was lost, so confirm it was not meant to clamp warmup too.
    steps_done = (epoch - warmup_epochs) // step_size
    decayed = base_lr * (gamma ** steps_done)
    return max(decayed, eta_min) / base_lr
def cosine_annealing(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    T_max: int,
    eta_min: float,
) -> float:
    """
    Warmup + Cosine Annealing

    During the first ``warmup_epochs`` epochs the rate follows ``_warmup_lr``.
    Afterwards it follows a half-cosine from ``base_lr`` down to ``eta_min``
    over ``T_max`` epochs. The cosine is not clamped, so for epochs beyond
    ``warmup_epochs + T_max`` the curve continues periodically.

    Args:
        epoch: Current epoch index (non-negative).
        base_lr: Peak learning rate.
        warmup_epochs: Number of warmup epochs.
        warmup_lr: Learning rate at the start of warmup.
        T_max: Number of epochs for one half-period of the cosine (>= 1).
        eta_min: Minimum learning rate of the cosine curve.

    Returns:
        The learning rate divided by ``base_lr`` (a multiplicative factor).

    Raises:
        AssertionError: If any argument is out of range.
    """
    base_lr, warmup_lr, eta_min = float(base_lr), float(warmup_lr), float(eta_min)
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    assert T_max >= 1, f"T_max must be greater than or equal to 1, got {T_max}."
    _check_lr(base_lr, eta_min)
    _check_warmup(warmup_epochs, warmup_lr)

    if epoch < warmup_epochs:
        return _warmup_lr(epoch, base_lr, warmup_epochs, warmup_lr) / base_lr

    # Epochs elapsed since the end of warmup drive the cosine phase.
    epoch -= warmup_epochs
    annealed = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
    return annealed / base_lr
def cosine_annealing_warm_restarts(
    epoch: int,
    base_lr: float,
    warmup_epochs: int,
    warmup_lr: float,
    T_0: int,
    T_mult: int,
    eta_min: float,
) -> float:
    """
    Warmup + Cosine Annealing with Warm Restarts

    During the first ``warmup_epochs`` epochs the rate follows ``_warmup_lr``.
    Afterwards the rate follows repeated half-cosines from ``base_lr`` down to
    ``eta_min``. The first cycle lasts ``T_0`` epochs and every following
    cycle is ``T_mult`` times longer than the previous one.

    Args:
        epoch: Current epoch index (non-negative).
        base_lr: Peak learning rate, restored at every restart.
        warmup_epochs: Number of warmup epochs.
        warmup_lr: Learning rate at the start of warmup.
        T_0: Length of the first cycle in epochs (int, >= 1).
        T_mult: Cycle-length multiplier (int, >= 1).
        eta_min: Minimum learning rate of each cycle.

    Returns:
        The learning rate divided by ``base_lr`` (a multiplicative factor).

    Raises:
        AssertionError: If any argument is out of range or of the wrong type.
    """
    base_lr, warmup_lr, eta_min = float(base_lr), float(warmup_lr), float(eta_min)
    assert epoch >= 0, f"epoch must be non-negative, got {epoch}."
    assert isinstance(T_0, int) and T_0 >= 1, f"T_0 must be greater than or equal to 1, got {T_0}."
    assert isinstance(T_mult, int) and T_mult >= 1, f"T_mult must be greater than or equal to 1, got {T_mult}."
    _check_lr(base_lr, eta_min)
    _check_warmup(warmup_epochs, warmup_lr)

    if epoch < warmup_epochs:
        return _warmup_lr(epoch, base_lr, warmup_epochs, warmup_lr) / base_lr

    epoch -= warmup_epochs
    if T_mult == 1:
        # All cycles have the same length, so the position is a plain modulo.
        T_cur = epoch % T_0
        T_i = T_0
    else:
        # Walk the cycles with exact arithmetic instead of using
        # int(math.log(..., T_mult)): the floating-point logarithm can land
        # just below an integer at an exact restart boundary (for example
        # math.log(243, 3) evaluates to 4.999...), which selects the previous
        # cycle and yields eta_min where the schedule should restart at base_lr.
        T_i = T_0
        T_cur = epoch
        while T_cur >= T_i:
            T_cur -= T_i
            T_i *= T_mult

    lr = eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
    return lr / base_lr
def get_loss_fn(args: ArgumentParser) -> nn.Module:
    """Build the ``losses.QuadLoss`` criterion from parsed options.

    Each attribute listed below is read from ``args`` and forwarded to
    ``losses.QuadLoss`` as the keyword argument of the same name.

    Args:
        args: Object exposing the loss hyper-parameters as attributes.

    Returns:
        The constructed ``losses.QuadLoss`` instance.
    """
    option_names = (
        "input_size",
        "block_size",
        "bins",
        "reg_loss",
        "aux_loss",
        "weight_cls",
        "weight_reg",
        "weight_aux",
        "numItermax",
        "regularization",
        "scales",
        "min_scale_weight",
        "max_scale_weight",
        "alpha",
    )
    loss_kwargs = {option: getattr(args, option) for option in option_names}
    return losses.QuadLoss(**loss_kwargs)
def get_optimizer(
    args: ArgumentParser,
    model: nn.Module
) -> Tuple[Union[SGD, Adam, AdamW, RAdam], LambdaLR]:
    """Create the optimizer and its ``LambdaLR`` scheduler for ``model``.

    Trainable parameters are split by name into four buckets:

    * names containing ``"vpt"``      -> VPT parameters
    * names containing ``"adapter"``  -> adapter parameters
    * names without ``"backbone"``, or containing ``"refiner"`` / ``"decoder"``
      -> newly added parameters
    * everything else                 -> backbone parameters

    The buckets are combined into two parameter groups whose learning rate and
    weight decay depend on the tuning mode (``args.num_vpt`` set, else
    ``args.adapter`` truthy, else plain backbone fine-tuning). The same
    schedule function is applied to every group.

    Args:
        args: Parsed options. Reads ``num_vpt``, ``adapter``, ``optimizer``,
            ``scheduler``, the per-group ``*lr`` / ``*weight_decay`` values and
            the hyper-parameters of the selected schedule.
        model: Model whose parameters with ``requires_grad=True`` are optimized.

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        AssertionError: If a bucket that must be empty for the selected tuning
            mode is not, or if ``args.optimizer`` is not supported.
        ValueError: If ``args.scheduler`` is not one of ``"step"``, ``"cos"``
            or ``"cos_restarts"``.
    """
    backbone_params = []
    new_params = []
    vpt_params = []
    adapter_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "vpt" in name:
            vpt_params.append(param)
        elif "adapter" in name:
            adapter_params.append(param)
        elif "backbone" not in name or ("refiner" in name or "decoder" in name):
            new_params.append(param)
        else:
            backbone_params.append(param)

    new_group = {"params": new_params, "lr": args.lr, "weight_decay": args.weight_decay}
    if args.num_vpt is not None:  # using VPT to tune ViT-based model
        assert len(backbone_params) == 0, f"Expected backbone_params to be empty when using VTP, got {len(backbone_params)}"
        assert len(adapter_params) == 0, f"Expected adpater_params to be empty when using VTP, got {len(adapter_params)}"
        param_groups = [
            {"params": vpt_params, "lr": args.vpt_lr, "weight_decay": args.vpt_weight_decay},
            new_group,
        ]
    elif args.adapter:  # using adapter to tune CLIP-based model
        assert len(backbone_params) == 0, f"Expected backbone_params to be empty when using adapter, got {len(backbone_params)}"
        assert len(vpt_params) == 0, f"Expected vpt_params to be empty when using adapter, got {len(vpt_params)}"
        param_groups = [
            {"params": adapter_params, "lr": args.adapter_lr, "weight_decay": args.adapter_weight_decay},
            new_group,
        ]
    else:
        param_groups = [
            new_group,
            {"params": backbone_params, "lr": args.backbone_lr, "weight_decay": args.backbone_weight_decay},
        ]

    if args.optimizer == "adam":
        optimizer = Adam(param_groups)
    elif args.optimizer == "adamw":
        optimizer = AdamW(param_groups)
    elif args.optimizer == "sgd":
        optimizer = SGD(param_groups, momentum=0.9)
    else:
        assert args.optimizer == "radam", f"Expected optimizer to be one of ['adam', 'adamw', 'sgd', 'radam'], got {args.optimizer}."
        optimizer = RAdam(param_groups, decoupled_weight_decay=True)

    # Reject unknown schedulers explicitly. Without this check ``lr_lambda``
    # would stay unbound and the LambdaLR call below would fail with an
    # UnboundLocalError that does not mention the offending option.
    supported_schedulers = ("step", "cos", "cos_restarts")
    if args.scheduler not in supported_schedulers:
        raise ValueError(
            f"Expected scheduler to be one of {list(supported_schedulers)}, got {args.scheduler}."
        )

    # Every schedule is expressed relative to args.lr and returns a factor,
    # so the same function can be shared by all parameter groups.
    shared_kwargs = {
        "base_lr": args.lr,
        "warmup_epochs": args.warmup_epochs,
        "warmup_lr": args.warmup_lr,
        "eta_min": args.eta_min,
    }
    if args.scheduler == "step":
        lr_lambda = partial(step_decay, step_size=args.step_size, gamma=args.gamma, **shared_kwargs)
    elif args.scheduler == "cos":
        lr_lambda = partial(cosine_annealing, T_max=args.T_max, **shared_kwargs)
    else:  # "cos_restarts"
        lr_lambda = partial(
            cosine_annealing_warm_restarts,
            T_0=args.T_0,
            T_mult=args.T_mult,
            **shared_kwargs,
        )

    scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=[lr_lambda] * len(param_groups),
    )
    return optimizer, scheduler
| def load_checkpoint( | |
| args: ArgumentParser, | |
| model: nn.Module, | |
| optimizer: Union[SGD, Adam, AdamW, RAdam], | |
| scheduler: LambdaLR, | |
| grad_scaler: GradScaler, | |
| ckpt_dir: Optional[str] = None, | |
| ) -> Tuple[nn.Module, Union[SGD, Adam, AdamW, RAdam], Union[LambdaLR, None], GradScaler, int, Union[Dict[str, float], None], Dict[str, List[float]], Dict[str, float]]: | |
| ckpt_path = os.path.join(args.ckpt_dir if ckpt_dir is None else ckpt_dir, "ckpt.pth") | |
| if not os.path.exists(ckpt_path): | |
| # Try to find other common checkpoint names if ckpt.pth is not found | |
| for alt_name in ["best_mae.pth", "best_rmse.pth", "best_nae.pth"]: | |
| alt_path = os.path.join(args.ckpt_dir if ckpt_dir is None else ckpt_dir, alt_name) | |
| if os.path.exists(alt_path): | |
| ckpt_path = alt_path | |
| break | |
| if os.path.exists(ckpt_path): | |
| ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) | |
| state_dict = ckpt["model_state_dict"] | |
| model_state_dict = model.state_dict() | |
| # Handle multi-class weight mapping | |
| new_state_dict = OrderedDict() | |
| for k, v in state_dict.items(): | |
| if k in model_state_dict: | |
| if v.shape == model_state_dict[k].shape: | |
| new_state_dict[k] = v | |
| else: | |
| print(f"Shape mismatch for {k}: checkpoint {v.shape} vs model {model_state_dict[k].shape}. Skipping.") | |
| else: | |
| # Logic for mapping single-class head to multi-class heads | |
| # e.g., mapping bin_heads.0.weight to bin_heads.1.weight, bin_heads.2.weight... | |
| if "bin_heads.0" in k or "pi_heads.0" in k: | |
| for i in range(1, args.num_classes): | |
| new_key = k.replace(".0.", f".{i}.") | |
| if new_key in model_state_dict: | |
| new_state_dict[new_key] = v | |
| print(f"Mapping {k} -> {new_key}") | |
| new_state_dict[k] = v | |
| msg = model.load_state_dict(new_state_dict, strict=False) | |
| print(f"Loaded checkpoint from {ckpt_path}.") | |
| if len(msg.missing_keys) > 0: | |
| print(f"Missing keys (initialized randomly): {len(msg.missing_keys)}") | |
| if len(msg.unexpected_keys) > 0: | |
| print(f"Unexpected keys (ignored): {len(msg.unexpected_keys)}") | |
| # Only load optimizer/scheduler if we are resuming the SAME experiment | |
| # If we changed num_classes, it's better to start with a fresh optimizer | |
| try: | |
| optimizer.load_state_dict(ckpt["optimizer_state_dict"]) | |
| if scheduler is not None: | |
| scheduler.load_state_dict(ckpt["scheduler_state_dict"]) | |
| if grad_scaler is not None: | |
| grad_scaler.load_state_dict(ckpt["grad_scaler_state_dict"]) | |
| start_epoch = ckpt["epoch"] | |
| print("Resuming optimizer and scheduler state.") | |
| except: | |
| print("Optimizer/Scheduler state mismatch. Starting with fresh optimizer.") | |
| start_epoch = 1 | |
| loss_info = ckpt.get("loss_info", None) | |
| hist_scores = ckpt.get("hist_scores", None) | |
| if hist_scores is None: | |
| metrics = ["mae", "rmse", "nae"] | |
| hist_scores = {m: [] for m in metrics} | |
| for i in range(args.num_classes): | |
| for m in metrics: | |
| hist_scores[f"{m}_class_{i}"] = [] | |
| best_scores = ckpt.get("best_scores", {k: [torch.inf] * args.save_best_k for k in hist_scores.keys()}) | |
| else: | |
| print("No checkpoint found. Starting from scratch.") | |
| start_epoch = 1 | |
| loss_info = None | |
| metrics = ["mae", "rmse", "nae"] | |
| hist_scores = {m: [] for m in metrics} | |
| for i in range(args.num_classes): | |
| for m in metrics: | |
| hist_scores[f"{m}_class_{i}"] = [] | |
| best_scores = {k: [torch.inf] * args.save_best_k for k in hist_scores.keys()} | |
| return model, optimizer, scheduler, grad_scaler, start_epoch, loss_info, hist_scores, best_scores | |
def save_checkpoint(
    epoch: int,
    model_state_dict: OrderedDict[str, Tensor],
    optimizer_state_dict: OrderedDict[str, Tensor],
    scheduler_state_dict: OrderedDict[str, Tensor],
    grad_scaler_state_dict: OrderedDict[str, Tensor],
    loss_info: Dict[str, List[float]],
    hist_scores: Dict[str, List[float]],
    best_scores: Dict[str, float],
    ckpt_dir: str,
) -> None:
    """Write the full training state to ``<ckpt_dir>/ckpt.pth``.

    The checkpoint is first written to a temporary file in the same directory
    and then moved into place with ``os.replace``, so an interrupted save does
    not leave a truncated ``ckpt.pth`` behind for the next resume.

    Args:
        epoch: Epoch number stored under ``"epoch"``.
        model_state_dict: Model weights.
        optimizer_state_dict: Optimizer state.
        scheduler_state_dict: LR scheduler state.
        grad_scaler_state_dict: Gradient scaler state.
        loss_info: Stored under ``"loss_info"``.
        hist_scores: Stored under ``"hist_scores"``.
        best_scores: Stored under ``"best_scores"``.
        ckpt_dir: Existing directory that receives ``ckpt.pth``.
    """
    ckpt = {
        "epoch": epoch,
        "model_state_dict": model_state_dict,
        "optimizer_state_dict": optimizer_state_dict,
        "scheduler_state_dict": scheduler_state_dict,
        "grad_scaler_state_dict": grad_scaler_state_dict,
        "loss_info": loss_info,
        "hist_scores": hist_scores,
        "best_scores": best_scores,
    }
    final_path = os.path.join(ckpt_dir, "ckpt.pth")
    tmp_path = final_path + ".tmp"
    try:
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, final_path)
    finally:
        # After a successful replace the temp file no longer exists; this only
        # cleans up the partial file left by a failed save.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)