| import os |
| import time |
|
|
| import torch |
|
|
| from src.args import parse_arguments |
| from src.datasets.common import get_dataloader, maybe_dictionarize |
| from src.datasets.registry import get_dataset |
| from src.distributed import cleanup_ddp, distribute_loader, is_main_process, setup_ddp |
| from src.eval import eval_single_dataset |
| from src.heads import get_classification_head |
| from src.linearize import LinearizedImageEncoder |
| from src.modeling import ImageClassifier, ImageEncoder |
| from src.attention_only_finetune import AttentionOnlyFinetuneEncoder |
| from src.utils import LabelSmoothing, cosine_lr, accuracy |
|
|
|
|
def finetune(rank, args):
    """Fine-tune the image encoder on ``args.train_dataset`` inside a DDP worker.

    Intended to be launched once per process (see the ``spawn`` call in the
    script entry point). The distributed context created by ``setup_ddp`` is
    always released through ``cleanup_ddp``, including on the early
    "checkpoints already exist" return, on the main-process return and when
    an exception propagates.

    Args:
        rank: Index of this process; used as the DDP device id.
        args: Parsed arguments. Fields read here include ``world_size``,
            ``port``, ``train_dataset``, ``save``, ``finetuning_mode``,
            ``ortho_lambda``, ``load``, ``data_location``, ``batch_size``,
            ``ls``, ``lr``, ``wd``, ``warmup_length``, ``epochs``,
            ``num_grad_accumulation`` and ``checkpoint_every``.

    Returns:
        ``(zs_path, ft_path)`` when both checkpoints already exist, or on the
        main process after training when ``args.save`` is not None;
        ``None`` otherwise.

    Raises:
        AssertionError: If no training dataset is given or the fine-tuning
            mode is not one of the supported modes.
    """
    setup_ddp(rank, args.world_size, port=args.port)
    try:
        return _run_finetuning(rank, args)
    finally:
        # Every exit path must release the distributed context; previously
        # the returning paths skipped this call.
        cleanup_ddp()


def _build_image_encoder(args, is_linearized, is_linear2):
    """Load (from ``args.load``) or construct the encoder class for the mode."""
    if args.load is not None and args.load.endswith("pt"):
        if is_linearized:
            return LinearizedImageEncoder.load(args.load)
        if is_linear2:
            return AttentionOnlyFinetuneEncoder.load(args.load, args)
        return ImageEncoder.load(args.load)

    print("Building image encoder.")
    if is_linearized:
        return LinearizedImageEncoder(args, keep_lang=False)
    if is_linear2:
        return AttentionOnlyFinetuneEncoder(args, keep_lang=False)
    return ImageEncoder(args)


def _run_finetuning(rank, args):
    """Run the fine-tuning loop; assumes ``setup_ddp`` was already called."""
    train_dataset = args.train_dataset
    # Validate before the dataset name is used to build any path so the
    # failure carries this message rather than an error from os.path.join.
    assert train_dataset is not None, "Please provide a training dataset."
    ckpdir = os.path.join(args.save, train_dataset)

    valid_modes = (
        "standard", "standard_ortho",
        "linear", "linear_ortho",
        "linear-2", "linear-2_ortho",
    )
    mode = args.finetuning_mode
    assert mode in valid_modes, f"Mode {args.finetuning_mode} not supported."

    is_linearized = mode in ("linear", "linear_ortho")
    is_linear2 = mode in ("linear-2", "linear-2_ortho")
    # After validation, exactly the three "*_ortho" modes end in "_ortho".
    needs_ortho = mode.endswith("_ortho")
    # ortho_lambda is only inspected for ortho modes (short-circuit).
    use_ortho = needs_ortho and args.ortho_lambda > 0

    print(f"Using fine-tuning mode: {args.finetuning_mode}")
    if use_ortho:
        print(f"  -> With OrthoReg (lambda={args.ortho_lambda})")

    # "standard" checkpoints carry no prefix; every other mode is prefixed
    # with its own name, e.g. "linear_finetuned.pt".
    file_prefix = "" if mode == "standard" else f"{mode}_"
    ft_path = os.path.join(ckpdir, f"{file_prefix}finetuned.pt")
    zs_path = os.path.join(ckpdir, f"{file_prefix}zeroshot.pt")

    if os.path.exists(zs_path) and os.path.exists(ft_path):
        print(f"Skipping fine-tuning because {ft_path} exists.")
        return zs_path, ft_path

    image_encoder = _build_image_encoder(args, is_linearized, is_linear2)

    # Snapshot of the starting weights, handed to the model's forward pass
    # for the ortho loss. Only taken for standard_ortho and linear-2_ortho;
    # linear_ortho passes None.
    pretrained_state_dict_ref = None
    if mode in ("standard_ortho", "linear-2_ortho"):
        print("Saving pretrained state dict reference for ortho loss.")
        pretrained_state_dict_ref = {
            name: tensor.clone().detach()
            for name, tensor in image_encoder.model.state_dict().items()
        }

    classification_head = get_classification_head(args, train_dataset)
    model = ImageClassifier(image_encoder, classification_head)
    model.freeze_head()
    model = model.cuda()

    preprocess_fn = model.train_preprocess
    print_every = 100

    dataset = get_dataset(
        train_dataset,
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
    )
    data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
    num_batches = len(dataset.train_loader)

    ddp_loader = distribute_loader(data_loader)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[rank],
        find_unused_parameters=True,
        output_device=rank,
    )

    loss_fn = LabelSmoothing(args.ls) if args.ls > 0 else torch.nn.CrossEntropyLoss()

    # Only parameters that still require gradients are optimised.
    params = [p for p in ddp_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    scheduler = cosine_lr(
        optimizer,
        args.lr,
        args.warmup_length,
        args.epochs * num_batches // args.num_grad_accumulation,
    )

    # Store the starting encoder as the "zeroshot" checkpoint.
    if args.save is not None and is_main_process():
        os.makedirs(ckpdir, exist_ok=True)
        ddp_model.module.image_encoder.save(zs_path)

    for epoch in range(args.epochs):
        ddp_model.train()

        for i, batch in enumerate(ddp_loader):
            start_time = time.time()
            # Optimiser-step index: advances once per accumulation group.
            step = (
                i // args.num_grad_accumulation
                + epoch * num_batches // args.num_grad_accumulation
            )

            batch = maybe_dictionarize(batch)
            inputs = batch["images"].cuda()
            labels = batch["labels"].cuda()
            data_time = time.time() - start_time

            ortho_loss = 0.0
            if use_ortho:
                logits, ortho_loss = ddp_model(
                    inputs,
                    calculate_ortho_loss=True,
                    pretrained_state_dict=pretrained_state_dict_ref,
                )
            else:
                logits = ddp_model(inputs)

            classification_loss = loss_fn(logits, labels)
            loss = classification_loss
            if use_ortho:
                # Regulariser is added only when active, so ortho_lambda is
                # never touched for the non-ortho modes.
                loss = loss + args.ortho_lambda * ortho_loss

            (acc1,) = accuracy(logits, labels, topk=(1,))
            acc1 /= labels.size(0)

            loss.backward()

            at_accumulation_boundary = (i + 1) % args.num_grad_accumulation == 0
            if at_accumulation_boundary:
                scheduler(step)
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                optimizer.zero_grad()

            # NOTE: with num_grad_accumulation > 1 the same step value spans
            # several batches, so the same checkpoint file can be rewritten
            # more than once; the cadence is kept as-is.
            if (
                args.checkpoint_every > 0
                and step % args.checkpoint_every == 0
                and is_main_process()
            ):
                ckpt_name = f"{file_prefix}checkpoint_{step}.pt"
                ddp_model.module.image_encoder.save(os.path.join(ckpdir, ckpt_name))

            if (
                step % print_every == 0
                and at_accumulation_boundary
                and is_main_process()
            ):
                percent_complete = 100 * i / len(ddp_loader)
                log_msg = (
                    f"Train Epoch: {epoch} [{percent_complete:.0f}%]\t"
                    f"Total Loss: {loss.item():.6f}\t"
                    f"CE Loss: {classification_loss.item():.6f}\t"
                )
                if use_ortho:
                    log_msg += f"Ortho Loss: {ortho_loss.item():.6f}\t"
                log_msg += f"Acc@1: {100*acc1:.2f}%\tData (t) {data_time:.3f}"
                print(log_msg, flush=True)

    # Evaluation and the final save happen on the main process only.
    if is_main_process():
        trained_encoder = ddp_model.module.image_encoder
        eval_single_dataset(trained_encoder, train_dataset, args)
        if args.save is not None:
            trained_encoder.save(ft_path)
            return zs_path, ft_path

    return None
|
|
|
|
if __name__ == "__main__":
    # Epoch budget per task; tasks are fine-tuned in this (insertion) order.
    epochs_per_task = {
        "Cars": 35,
        "DTD": 76,
        "EuroSAT": 12,
        "GTSRB": 11,
        "MNIST": 5,
        "RESISC45": 15,
        "SUN397": 14,
        "SVHN": 4,
    }

    for task_name, task_epochs in epochs_per_task.items():
        # Arguments are parsed again for each task, so every run starts from
        # a fresh namespace.
        args = parse_arguments()

        args.epochs = task_epochs
        args.train_dataset = task_name + "Val"

        # ViT-L-14 uses a smaller batch with two accumulation steps.
        if args.model == "ViT-L-14":
            args.batch_size = 64
            args.num_grad_accumulation = 2
        else:
            args.batch_size = 128
            args.num_grad_accumulation = 1

        # Output directory encodes seed, mode, learning rate, the ortho
        # lambda (ortho modes only) and the model name.
        if 'ortho' in args.finetuning_mode:
            args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_lambda{args.ortho_lambda}_{args.model}"
        elif args.seed is not None:
            args.save = f"checkpoints_{args.seed}/{args.finetuning_mode}_{args.lr}_{args.model}"
        else:
            args.save = f"checkpoints/{args.finetuning_mode}_{args.lr}_{args.model}"

        banner = "=" * 100
        print(banner)
        print(f"Finetuning {args.model} on {task_name}")
        print(banner)

        # One worker process per rank; each one runs finetune(rank, args).
        torch.multiprocessing.spawn(finetune, args=(args,), nprocs=args.world_size)
|
|