Spaces:
Build error
Build error
| import argparse | |
| import os | |
| from loguru import logger | |
| import torch | |
| from torch import nn | |
| from torch.cuda import amp | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from torch.utils.tensorboard import SummaryWriter | |
| from virtex.config import Config | |
| from virtex.factories import ( | |
| DownstreamDatasetFactory, | |
| PretrainingModelFactory, | |
| OptimizerFactory, | |
| LRSchedulerFactory, | |
| ) | |
| from virtex.utils.checkpointing import CheckpointManager | |
| from virtex.utils.common import common_parser, common_setup, cycle | |
| import virtex.utils.distributed as dist | |
| from virtex.utils.metrics import TopkAccuracy | |
| from virtex.utils.timer import Timer | |
| # fmt: off | |
| parser = common_parser( | |
| description="""Do image classification with linear models and frozen | |
| feature extractor, or fine-tune the feature extractor end-to-end.""" | |
| ) | |
| group = parser.add_argument_group("Downstream config arguments.") | |
| group.add_argument( | |
| "--down-config", metavar="FILE", help="Path to a downstream config file." | |
| ) | |
| group.add_argument( | |
| "--down-config-override", nargs="*", default=[], | |
| help="A list of key-value pairs to modify downstream config params.", | |
| ) | |
| parser.add_argument_group("Checkpointing and Logging") | |
| parser.add_argument( | |
| "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], | |
| default="virtex", help="""How to initialize weights: | |
| 1. 'random' initializes all weights randomly | |
| 2. 'imagenet' initializes backbone weights from torchvision model zoo | |
| 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path | |
| - with 'torchvision', state dict would be from PyTorch's training | |
| script. | |
| - with 'virtex' it should be for our full pretrained model.""" | |
| ) | |
| parser.add_argument( | |
| "--log-every", type=int, default=50, | |
| help="""Log training curves to tensorboard after every these many iterations | |
| only master process logs averaged loss values across processes.""", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint-path", | |
| help="""Path to load checkpoint and run downstream task evaluation. The | |
| name of checkpoint file is required to be `model_*.pth`, where * is | |
| iteration number from which the checkpoint was serialized.""" | |
| ) | |
| parser.add_argument( | |
| "--checkpoint-every", type=int, default=5000, | |
| help="""Serialize model to a checkpoint after every these many iterations. | |
| For ImageNet, (5005 iterations = 1 epoch); for iNaturalist (1710 iterations | |
| = 1 epoch).""", | |
| ) | |
| # fmt: on | |
| def main(_A: argparse.Namespace): | |
| if _A.num_gpus_per_machine == 0: | |
| # Set device as CPU if num_gpus_per_machine = 0. | |
| device = torch.device("cpu") | |
| else: | |
| # Get the current device as set for current distributed process. | |
| # Check `launch` function in `virtex.utils.distributed` module. | |
| device = torch.cuda.current_device() | |
| # Create a downstream config object (this will be immutable) and perform | |
| # common setup such as logging and setting up serialization directory. | |
| _DOWNC = Config(_A.down_config, _A.down_config_override) | |
| common_setup(_DOWNC, _A, job_type="downstream") | |
| # Create a (pretraining) config object and backup in serializaion directory. | |
| _C = Config(_A.config, _A.config_override) | |
| _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) | |
| # Get dataset name for tensorboard logging. | |
| DATASET = _DOWNC.DATA.ROOT.split("/")[-1] | |
| # Set number of output classes according to dataset: | |
| NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142} | |
| NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET] | |
| # ------------------------------------------------------------------------- | |
| # INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER | |
| # ------------------------------------------------------------------------- | |
| train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train") | |
| train_dataloader = DataLoader( | |
| train_dataset, | |
| batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), | |
| num_workers=_A.cpu_workers, | |
| sampler=DistributedSampler( | |
| train_dataset, | |
| num_replicas=dist.get_world_size(), | |
| rank=dist.get_rank(), | |
| shuffle=True, | |
| ), | |
| drop_last=False, | |
| pin_memory=True, | |
| collate_fn=train_dataset.collate_fn, | |
| ) | |
| val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val") | |
| val_dataloader = DataLoader( | |
| val_dataset, | |
| batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(), | |
| num_workers=_A.cpu_workers, | |
| sampler=DistributedSampler( | |
| val_dataset, | |
| num_replicas=dist.get_world_size(), | |
| rank=dist.get_rank(), | |
| shuffle=False, | |
| ), | |
| pin_memory=True, | |
| drop_last=False, | |
| collate_fn=val_dataset.collate_fn, | |
| ) | |
| # Initialize model using pretraining config. | |
| pretrained_model = PretrainingModelFactory.from_config(_C) | |
| # Load weights according to the init method, do nothing for `random`, and | |
| # `imagenet` is already taken care of. | |
| if _A.weight_init == "virtex": | |
| CheckpointManager(model=pretrained_model).load(_A.checkpoint_path) | |
| elif _A.weight_init == "torchvision": | |
| # Keep strict=False because this state dict may have weights for | |
| # last fc layer. | |
| pretrained_model.visual.cnn.load_state_dict( | |
| torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], | |
| strict=False, | |
| ) | |
| # Pull out the CNN (torchvision-like) from our pretrained model and add | |
| # back the FC layer - this is exists in torchvision models, and is set to | |
| # `nn.Identity()` during pretraining. | |
| model = pretrained_model.visual.cnn # type: ignore | |
| model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device) | |
| model = model.to(device) | |
| # Re-initialize the FC layer. | |
| torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01) | |
| torch.nn.init.constant_(model.fc.bias.data, 0.0) | |
| # Freeze all layers except FC as per config param. | |
| if _DOWNC.MODEL.VISUAL.FROZEN: | |
| # Set model to eval mode to prevent BatchNorm from updating running | |
| # mean and std. With only a linear layer, being in eval mode when | |
| # training will not matter anyway. | |
| model.eval() | |
| for name, param in model.named_parameters(): | |
| if "fc" not in name: | |
| param.requires_grad = False | |
| # Cross entropy loss and accuracy meter. | |
| criterion = nn.CrossEntropyLoss() | |
| top1 = TopkAccuracy(top_k=1) | |
| optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters()) | |
| scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer) | |
| del pretrained_model | |
| # ------------------------------------------------------------------------- | |
| # BEFORE TRAINING STARTS | |
| # ------------------------------------------------------------------------- | |
| # Create a gradient scaler for automatic mixed precision. | |
| scaler = amp.GradScaler(enabled=_DOWNC.AMP) | |
| # Create an iterator from dataloader to sample batches perpetually. | |
| train_dataloader_iter = cycle(train_dataloader, device) | |
| if dist.get_world_size() > 1: | |
| dist.synchronize() | |
| model = nn.parallel.DistributedDataParallel( | |
| model, device_ids=[device], find_unused_parameters=True | |
| ) | |
| if dist.is_master_process(): | |
| checkpoint_manager = CheckpointManager( | |
| _A.serialization_dir, | |
| model=model, | |
| optimizer=optimizer, | |
| scheduler=scheduler, | |
| ) | |
| tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) | |
| # Keep track of time per iteration and ETA. | |
| timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS) | |
| # ------------------------------------------------------------------------- | |
| # TRAINING LOOP | |
| # ------------------------------------------------------------------------- | |
| for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1): | |
| timer.tic() | |
| optimizer.zero_grad() | |
| batch = next(train_dataloader_iter) | |
| with amp.autocast(enabled=_DOWNC.AMP): | |
| logits = model(batch["image"]) | |
| loss = criterion(logits, batch["label"]) | |
| scaler.scale(loss).backward() | |
| scaler.step(optimizer) | |
| scaler.update() | |
| scheduler.step() | |
| timer.toc() | |
| if iteration % _A.log_every == 0 and dist.is_master_process(): | |
| logger.info( | |
| f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB" | |
| ) | |
| tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration) | |
| tensorboard_writer.add_scalar( | |
| f"{DATASET}/learning_rate", | |
| optimizer.param_groups[0]["lr"], | |
| iteration, | |
| ) | |
| # --------------------------------------------------------------------- | |
| # VALIDATION | |
| # --------------------------------------------------------------------- | |
| if iteration % _A.checkpoint_every == 0: | |
| torch.set_grad_enabled(False) | |
| model.eval() | |
| total_val_loss = torch.tensor(0.0).to(device) | |
| for val_iteration, batch in enumerate(val_dataloader, start=1): | |
| for key in batch: | |
| batch[key] = batch[key].to(device) | |
| logits = model(batch["image"]) | |
| loss = criterion(logits, batch["label"]) | |
| top1(logits, batch["label"]) | |
| total_val_loss += loss | |
| # Divide each loss component by number of val batches per GPU. | |
| total_val_loss = total_val_loss / val_iteration | |
| dist.average_across_processes(total_val_loss) | |
| # Get accumulated Top-1 accuracy for logging across GPUs. | |
| acc = top1.get_metric(reset=True) | |
| dist.average_across_processes(acc) | |
| torch.set_grad_enabled(True) | |
| # Set model back to train mode only when fine-tuning end-to-end. | |
| if not _DOWNC.MODEL.VISUAL.FROZEN: | |
| model.train() | |
| # Save recent checkpoint and best checkpoint based on accuracy. | |
| if dist.is_master_process(): | |
| checkpoint_manager.step(iteration) | |
| if iteration % _A.checkpoint_every == 0 and dist.is_master_process(): | |
| logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc})") | |
| tensorboard_writer.add_scalar( | |
| f"{DATASET}/val_loss", total_val_loss, iteration | |
| ) | |
| # This name scoping will result in Tensorboard displaying all metrics | |
| # (VOC07, caption, etc.) together. | |
| tensorboard_writer.add_scalars( | |
| f"metrics/{DATASET}", {"top1": acc}, iteration | |
| ) | |
| # All processes will wait till master process is done logging. | |
| dist.synchronize() | |
| if __name__ == "__main__": | |
| _A = parser.parse_args() | |
| # Add an arg in config override if `--weight-init` is imagenet. | |
| if _A.weight_init == "imagenet": | |
| _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) | |
| if _A.num_gpus_per_machine == 0: | |
| main(_A) | |
| else: | |
| # This will launch `main` and set appropriate CUDA device (GPU ID) as | |
| # per process (accessed in the beginning of `main`). | |
| dist.launch( | |
| main, | |
| num_machines=_A.num_machines, | |
| num_gpus_per_machine=_A.num_gpus_per_machine, | |
| machine_rank=_A.machine_rank, | |
| dist_url=_A.dist_url, | |
| args=(_A,), | |
| ) | |