import os
import time

import torch

from src.args import parse_arguments
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp
from src.eval import eval_single_dataset
from src.heads import get_classification_head
from src.linearize import LinearizedImageEncoder
from src.modeling import ImageClassifier, ImageEncoder
from src.attention_only_finetune import AttentionOnlyFinetuneEncoder
from src.utils import LabelSmoothing, cosine_lr, accuracy


def finetune(rank, args):
    """Fine-tune an image encoder on ``args.train_dataset`` under DDP.

    Sets up the distributed process group for this rank, runs the training
    job, and always tears the process group down again -- including on the
    early "checkpoints already exist" return, on the main-process return
    after saving, and when training raises.

    Args:
        rank: Index of this worker process; also used as the CUDA device id
            for ``DistributedDataParallel``.
        args: Parsed argument namespace. Fields read here and in the helper
            include ``world_size``, ``port``, ``save``, ``train_dataset``,
            ``finetuning_mode``, ``ortho_lambda``, ``load``, ``data_location``,
            ``batch_size``, ``ls``, ``lr``, ``wd``, ``warmup_length``,
            ``epochs``, ``num_grad_accumulation`` and ``checkpoint_every``.

    Returns:
        ``(zs_path, ft_path)`` when the checkpoints already exist, or on the
        main process after saving the fine-tuned encoder; ``None`` otherwise.
    """
    setup_ddp(rank, args.world_size, port=args.port)
    try:
        return _run_finetuning(rank, args)
    finally:
        # Previously cleanup_ddp() was only reached by ranks that did not
        # return early, so the process group was never destroyed on the
        # main process or on the skip path.
        cleanup_ddp()


def _run_finetuning(rank, args):
    """Run the training loop; assumes the process group is already set up."""
    train_dataset = args.train_dataset
    ckpdir = os.path.join(args.save, train_dataset)

    valid_modes = [
        "standard",
        "standard_ortho",
        "linear",
        "linear_ortho",
        "linear-2",
        "linear-2_ortho",
    ]
    assert args.finetuning_mode in valid_modes, f"Mode {args.finetuning_mode} not supported."

    is_linearized = args.finetuning_mode in ("linear", "linear_ortho")
    is_linear2 = args.finetuning_mode in ("linear-2", "linear-2_ortho")
    is_standard_ortho = args.finetuning_mode == "standard_ortho"
    is_linear_ortho = args.finetuning_mode == "linear_ortho"
    is_linear2_ortho = args.finetuning_mode == "linear-2_ortho"
    needs_ortho = is_standard_ortho or is_linear_ortho or is_linear2_ortho

    print(f"Using fine-tuning mode: {args.finetuning_mode}")
    if needs_ortho and args.ortho_lambda > 0:
        print(f" -> With OrthoReg (lambda={args.ortho_lambda})")

    # Checkpoint file names are prefixed with the mode, except for plain
    # "standard" fine-tuning which uses the bare names.
    mode_prefix_map = {
        "standard": "",
        "standard_ortho": "standard_ortho",
        "linear": "linear",
        "linear_ortho": "linear_ortho",
        "linear-2": "linear-2",
        "linear-2_ortho": "linear-2_ortho",
    }
    mode_prefix = mode_prefix_map[args.finetuning_mode]
    ft_path = os.path.join(
        ckpdir, f"{mode_prefix}_finetuned.pt" if mode_prefix else "finetuned.pt"
    )
    zs_path = os.path.join(
        ckpdir, f"{mode_prefix}_zeroshot.pt" if mode_prefix else "zeroshot.pt"
    )

    # Nothing to do if both the zero-shot and fine-tuned checkpoints exist.
    if os.path.exists(zs_path) and os.path.exists(ft_path):
        print(f"Skipping fine-tuning because {ft_path} exists.")
        return zs_path, ft_path

    assert train_dataset is not None, "Please provide a training dataset."

    # Either load an encoder checkpoint or build a fresh one; the encoder
    # class depends on the fine-tuning mode.
    if args.load is not None and args.load.endswith("pt"):
        if is_linearized:
            image_encoder = LinearizedImageEncoder.load(args.load)
        elif is_linear2:
            image_encoder = AttentionOnlyFinetuneEncoder.load(args.load, args)
        else:
            image_encoder = ImageEncoder.load(args.load)
    else:
        print("Building image encoder.")
        if is_linearized:
            image_encoder = LinearizedImageEncoder(args, keep_lang=False)
        elif is_linear2:
            image_encoder = AttentionOnlyFinetuneEncoder(args, keep_lang=False)
        else:
            image_encoder = ImageEncoder(args)

    # Save a frozen copy of pretrained weights for ortho loss
    # (standard_ortho / linear-2_ortho). For linear_ortho the reference stays
    # None -- presumably the linearized encoder tracks its own initial
    # parameters; TODO confirm.
    pretrained_state_dict_ref = None
    if is_standard_ortho or is_linear2_ortho:
        print("Saving pretrained state dict reference for ortho loss.")
        pretrained_state_dict_ref = {
            k: v.clone().detach() for k, v in image_encoder.model.state_dict().items()
        }

    classification_head = get_classification_head(args, train_dataset)
    model = ImageClassifier(image_encoder, classification_head)
    # Only the encoder is trained; the classification head is frozen.
    model.freeze_head()
    model = model.cuda()

    preprocess_fn = model.train_preprocess
    print_every = 100

    dataset = get_dataset(
        train_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(dataset.train_loader)

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[rank],
        find_unused_parameters=True,
        output_device=rank,
    )

    loss_fn = LabelSmoothing(args.ls) if args.ls > 0 else torch.nn.CrossEntropyLoss()

    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    # The schedule length is counted in optimizer steps, not in batches.
    scheduler = cosine_lr(
        optimizer,
        args.lr,
        args.warmup_length,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    # Save the encoder before any update as the zero-shot checkpoint.
    if args.save is not None and is_main_process():
        os.makedirs(ckpdir, exist_ok=True)
        ddp_model.module.image_encoder.save(zs_path)

    for epoch in range(args.epochs):
        ddp_model.train()

        for i, batch in enumerate(ddp_loader):
            start_time = time.time()

            # Global optimizer-step index across epochs.
            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            labels = batch["labels"].cuda()
            data_time = time.time() - start_time

            ortho_loss = 0.0
            if needs_ortho and args.ortho_lambda > 0:
                logits, ortho_loss = ddp_model(
                    inputs,
                    calculate_ortho_loss=True,
                    pretrained_state_dict=pretrained_state_dict_ref,
                )
            else:
                logits = ddp_model(inputs)

            classification_loss = loss_fn(logits, labels)
            loss = classification_loss + args.ortho_lambda * ortho_loss

            (acc1,) = accuracy(logits, labels, topk=(1,))
            acc1 /= labels.size(0)

            loss.backward()

            # Gradients accumulate across batches; step only on the last
            # batch of each accumulation window.
            if (i + 1) % args.num_grad_accumulation == 0:
                scheduler(step)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            if (
                args.checkpoint_every > 0
                and step % args.checkpoint_every == 0
                and is_main_process()
            ):
                ckpt_name = (
                    f"{mode_prefix}_checkpoint_{step}.pt"
                    if mode_prefix
                    else f"checkpoint_{step}.pt"
                )
                ddp_model.module.image_encoder.save(os.path.join(ckpdir, ckpt_name))

            if (
                step % print_every == 0
                and ((i + 1) % args.num_grad_accumulation == 0)
                and is_main_process()
            ):
                percent_complete = 100 * i / len(ddp_loader)
                log_msg = (
                    f"Train Epoch: {epoch} [{percent_complete:.0f}%]\t"
                    f"Total Loss: {loss.item():.6f}\t"
                    f"CE Loss: {classification_loss.item():.6f}\t"
                )
                if needs_ortho and args.ortho_lambda > 0:
                    log_msg += f"Ortho Loss: {ortho_loss.item():.6f}\t"
                log_msg += f"Acc@1: {100*acc1:.2f}%\tData (t) {data_time:.3f}"
                print(log_msg, flush=True)

    # Evaluation and the final save run on the main process only.
    if is_main_process():
        image_encoder = ddp_model.module.image_encoder
        eval_single_dataset(image_encoder, train_dataset, args)

    if args.save is not None and is_main_process():
        image_encoder.save(ft_path)
        return zs_path, ft_path

    return None


if __name__ == "__main__":
    train_datasets = [
        "Cars",
        "DTD",
        "EuroSAT",
        "GTSRB",
        "MNIST",
        "RESISC45",
        "SUN397",
        "SVHN",
    ]
    # Number of fine-tuning epochs per dataset.
    epochs = {
        "Cars": 35,
        "DTD": 76,
        "EuroSAT": 12,
        "GTSRB": 11,
        "MNIST": 5,
        "RESISC45": 15,
        "SUN397": 14,
        "SVHN": 4,
    }

    for dataset in train_datasets:
        args = parse_arguments()

        args.epochs = epochs[dataset]
        # Train on the "<dataset>Val" variant of each dataset.
        args.train_dataset = dataset + "Val"

        # ViT-L-14 uses half the batch size with 2-step gradient accumulation.
        args.batch_size = 64 if args.model == "ViT-L-14" else 128
        args.num_grad_accumulation = 2 if args.model == "ViT-L-14" else 1

        if 'ortho' in args.finetuning_mode:
            # NOTE(review): unlike the branch below, this does not special-case
            # a missing seed, so the directory becomes "checkpoints_None" when
            # args.seed is None -- confirm this is intended.
            args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
        else:
            if args.seed is not None:
                args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
            else:
                args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"

        print("=" * 100)
        print(f"Finetuning {args.model} on {dataset}")
        print("=" * 100)

        # One worker process per rank; each runs finetune(rank, args).
        torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size)