| | """Training utils for VibeToken.""" |
| | import json |
| | import os |
| | import time |
| | import math |
| | from pathlib import Path |
| | import pprint |
| | import glob |
| | from collections import defaultdict |
| | import random |
| | import gc |
| |
|
| | from data import SimpleImageDataset, PretoeknizedDataSetJSONL, PretokenizedWebDataset |
| | import torch |
| | from torch.utils.data import DataLoader |
| | from omegaconf import OmegaConf |
| | from torch.optim import AdamW |
| | from utils.lr_schedulers import get_scheduler |
| | from modeling.modules import EMAModel, ReconstructionLoss_Single_Stage |
| | from modeling.vibetoken_model import VibeTokenModel, PretrainedTokenizer |
| | from evaluator import VQGANEvaluator |
| |
|
| | from utils.viz_utils import make_viz_from_samples |
| | from torchinfo import summary |
| | import accelerate |
| |
|
def get_config():
    """Reads the config from a YAML file and merges in CLI overrides."""
    cli_conf = OmegaConf.from_cli()

    yaml_conf = OmegaConf.load(cli_conf.config)
    conf = OmegaConf.merge(yaml_conf, cli_conf)

    return conf

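# Usage sketch (script name and config path are illustrative): the special CLI
# key `config` selects the YAML file, and any other dotted key overrides it,
# since OmegaConf.merge gives the later (CLI) conf precedence:
#
#   python train.py config=configs/vibetoken.yaml training.per_gpu_batch_size=16
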
class AverageMeter(object):
    """Computes and stores the average and current value.

    This class is borrowed from
    https://github.com/pytorch/examples/blob/main/imagenet/main.py#L423
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

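# Usage sketch: `meter.update(loss.item(), n=images.size(0))` adds a batch to
# the running sum, after which `meter.val` holds the latest value and
# `meter.avg` holds sum / count since the last `reset()`.
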
def create_pretrained_tokenizer(config, accelerator=None):
    """Creates the frozen pixel tokenizer, unless the decoder is being finetuned."""
    if config.model.vq_model.finetune_decoder:
        pretrained_tokenizer = None
    else:
        pretrained_tokenizer = PretrainedTokenizer(config.model.vq_model.pretrained_tokenizer_weight)
        if accelerator is not None:
            pretrained_tokenizer.to(accelerator.device)
    return pretrained_tokenizer

def create_model_and_loss_module(config, logger, accelerator,
                                 model_type="vibetoken"):
    """Creates model and loss module."""
    logger.info("Creating model and loss module.")
    if model_type == "vibetoken":
        if config.model.sub_model_type == "vibetoken":
            model_cls = VibeTokenModel
            loss_cls = ReconstructionLoss_Single_Stage
        else:
            raise ValueError(f"Unsupported sub_model_type {config.model.sub_model_type}")
    else:
        raise ValueError(f"Unsupported model_type {model_type}")
    model = model_cls(config)

    if config.experiment.get("init_weight", ""):
        model_weight = torch.load(config.experiment.init_weight, map_location="cpu")
        if config.model.vq_model.finetune_decoder:
            # Keep the pretrained pixel decoder weights; drop its encoder.
            pretrained_tokenizer_weight = torch.load(
                config.model.vq_model.pretrained_tokenizer_weight, map_location="cpu"
            )
            pretrained_tokenizer_weight = {
                "pixel_" + k: v for k, v in pretrained_tokenizer_weight.items()
                if "encoder." not in k
            }
            model_weight.update(pretrained_tokenizer_weight)

        msg = model.load_state_dict(model_weight, strict=False)
        logger.info(f"loading weight from {config.experiment.init_weight}, msg: {msg}")

    ema_model = None
    if config.training.use_ema:
        ema_model = EMAModel(model.parameters(), decay=0.999,
                             model_cls=model_cls, config=config)

        def load_model_hook(models, input_dir):
            load_model = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"),
                                                  model_cls=model_cls, config=config)
            ema_model.load_state_dict(load_model.state_dict())
            ema_model.to(accelerator.device)
            del load_model

        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                ema_model.save_pretrained(os.path.join(output_dir, "ema_model"))

        accelerator.register_load_state_pre_hook(load_model_hook)
        accelerator.register_save_state_pre_hook(save_model_hook)

    loss_module = loss_cls(config=config) if loss_cls is not None else None

    if accelerator.is_main_process:
        if model_type in ["vibetoken"]:
            logger.info("VibeToken model summary not implemented yet.")
        else:
            raise NotImplementedError

    return model, ema_model, loss_module

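# How the EMA hooks above are exercised (names as defined in this file):
# `accelerator.save_state(path)` fires `save_model_hook`, writing the EMA
# weights next to the regular training state, and `accelerator.load_state(path)`
# fires `load_model_hook`, restoring them into `ema_model` before the rest of
# the state is loaded.
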
def create_optimizer(config, logger, model, loss_module,
                     model_type="vibetoken", need_discriminator=True):
    """Creates optimizers for the model and, if needed, the discriminator."""
    logger.info("Creating optimizers.")
    optimizer_config = config.optimizer.params
    learning_rate = optimizer_config.learning_rate

    optimizer_type = config.optimizer.name
    if optimizer_type == "adamw":
        optimizer_cls = AdamW
    else:
        raise ValueError(f"Optimizer {optimizer_type} not supported")

    # Exclude gains/biases, norms, and embedding-like parameters from weight decay.
    exclude = (lambda n, p: p.ndim < 2 or "ln" in n or "bias" in n or 'latent_tokens' in n
               or 'mask_token' in n or 'embedding' in n or 'norm' in n or 'gamma' in n or 'embed' in n)
    include = lambda n, p: not exclude(n, p)
    named_parameters = list(model.named_parameters())
    gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
    optimizer = optimizer_cls(
        [
            {"params": gain_or_bias_params, "weight_decay": 0.},
            {"params": rest_params, "weight_decay": optimizer_config.weight_decay},
        ],
        lr=learning_rate,
        betas=(optimizer_config.beta1, optimizer_config.beta2)
    )

    if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") and need_discriminator:
        discriminator_learning_rate = optimizer_config.discriminator_learning_rate
        discriminator_named_parameters = list(loss_module.named_parameters())
        discriminator_gain_or_bias_params = [p for n, p in discriminator_named_parameters if exclude(n, p) and p.requires_grad]
        discriminator_rest_params = [p for n, p in discriminator_named_parameters if include(n, p) and p.requires_grad]

        discriminator_optimizer = optimizer_cls(
            [
                {"params": discriminator_gain_or_bias_params, "weight_decay": 0.},
                {"params": discriminator_rest_params, "weight_decay": optimizer_config.weight_decay},
            ],
            lr=discriminator_learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2)
        )
    else:
        discriminator_optimizer = None

    if need_discriminator:
        assert discriminator_optimizer is not None, (
            f"Discriminator optimizer is None with condition values: "
            f"{config.model.vq_model.finetune_decoder} {model_type} {need_discriminator}"
        )

    return optimizer, discriminator_optimizer

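# Minimal sketch of the weight-decay split above on a toy module (standalone,
# torch-only; the names here are illustrative):
#
#   toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
#   named = list(toy.named_parameters())
#   exclude = lambda n, p: p.ndim < 2 or "norm" in n or "bias" in n
#   no_decay = [p for n, p in named if exclude(n, p)]
#   decay = [p for n, p in named if not exclude(n, p)]
#   opt = AdamW([{"params": no_decay, "weight_decay": 0.0},
#                {"params": decay, "weight_decay": 1e-4}], lr=1e-4)
#
# Only the Linear weight (ndim == 2) receives weight decay; biases and the
# LayerNorm parameters (ndim == 1) do not.
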
def create_lr_scheduler(config, logger, accelerator, optimizer, discriminator_optimizer=None):
    """Creates learning rate schedulers for the model and discriminator."""
    logger.info("Creating lr_schedulers.")
    lr_scheduler = get_scheduler(
        config.lr_scheduler.scheduler,
        optimizer=optimizer,
        num_training_steps=config.training.max_train_steps * accelerator.num_processes,
        num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes,
        base_lr=config.lr_scheduler.params.learning_rate,
        end_lr=config.lr_scheduler.params.end_lr,
    )
    if discriminator_optimizer is not None:
        # The discriminator only starts training at `discriminator_start`,
        # so its schedule covers correspondingly fewer steps.
        discriminator_lr_scheduler = get_scheduler(
            config.lr_scheduler.scheduler,
            optimizer=discriminator_optimizer,
            num_training_steps=config.training.max_train_steps * accelerator.num_processes - config.losses.discriminator_start,
            num_warmup_steps=config.lr_scheduler.params.warmup_steps * accelerator.num_processes,
            base_lr=config.lr_scheduler.params.learning_rate,
            end_lr=config.lr_scheduler.params.end_lr,
        )
    else:
        discriminator_lr_scheduler = None
    return lr_scheduler, discriminator_lr_scheduler

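# The step counts are multiplied by `accelerator.num_processes` to match
# accelerate's convention of stepping a prepared scheduler once per process
# for every optimizer update (this assumes the schedulers are later passed
# through `accelerator.prepare` in the training script).
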
def create_dataloader(config, logger, accelerator):
    """Creates data loaders for training and testing."""
    logger.info("Creating dataloaders.")
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    total_batch_size = (
        config.training.per_gpu_batch_size * accelerator.num_processes * config.training.gradient_accumulation_steps
    )
    preproc_config = config.dataset.preprocessing
    dataset_config = config.dataset.params

    if dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is True:
        dataset = PretokenizedWebDataset(
            train_shards_path=dataset_config.train_shards_path_or_url,
            eval_shards_path=dataset_config.eval_shards_path_or_url,
            num_train_examples=config.experiment.max_train_examples,
            per_gpu_batch_size=config.training.per_gpu_batch_size,
            global_batch_size=total_batch_size_without_accum,
            num_workers_per_gpu=dataset_config.num_workers_per_gpu,
            resize_shorter_edge=preproc_config.resize_shorter_edge,
            crop_size=preproc_config.crop_size,
            random_crop=preproc_config.random_crop,
            random_flip=preproc_config.random_flip,
            normalize_mean=preproc_config.normalize_mean,
            normalize_std=preproc_config.normalize_std,
            process_recap=preproc_config.get("preproc_recap", True),
            use_recap_prob=preproc_config.get("use_recap_prob", 0.95)
        )
        train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
    elif dataset_config.get("pretokenization", "") and dataset_config.get("dataset_with_text_label", False) is False:
        dataset = SimpleImageDataset(
            train_shards_path=dataset_config.train_shards_path_or_url,
            eval_shards_path=dataset_config.eval_shards_path_or_url,
            num_train_examples=config.experiment.max_train_examples,
            per_gpu_batch_size=config.training.per_gpu_batch_size,
            global_batch_size=total_batch_size_without_accum,
            num_workers_per_gpu=dataset_config.num_workers_per_gpu,
            resize_shorter_edge=preproc_config.resize_shorter_edge,
            crop_size=preproc_config.crop_size,
            random_crop=preproc_config.random_crop,
            random_flip=preproc_config.random_flip,
            dataset_with_class_label=dataset_config.get("dataset_with_class_label", True),
            dataset_with_text_label=dataset_config.get("dataset_with_text_label", False),
            res_ratio_filtering=preproc_config.get("res_ratio_filtering", False),
            min_tokens=preproc_config.min_tokens,
            max_tokens=preproc_config.max_tokens,
        )
        train_dataloader, eval_dataloader = dataset.train_dataloader, dataset.eval_dataloader
    else:
        if dataset_config.get("pretokenization", ""):
            train_dataloader = DataLoader(
                PretoeknizedDataSetJSONL(dataset_config.pretokenization),
                batch_size=config.training.per_gpu_batch_size,
                shuffle=True, drop_last=True, pin_memory=True)
            train_dataloader.num_batches = math.ceil(
                config.experiment.max_train_examples / total_batch_size_without_accum)
            # This path has no evaluation split.
            eval_dataloader = None
        else:
            raise ValueError(
                "Unsupported dataset configuration: "
                "`dataset.params.pretokenization` must be set.")

    return train_dataloader, eval_dataloader

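# Branch selection sketch for create_dataloader (keys as referenced above,
# values illustrative):
#
#   dataset:
#     params:
#       pretokenization: /path/to/tokens     # must be truthy for every branch
#       dataset_with_text_label: true        # True  -> PretokenizedWebDataset
#                                            # False -> SimpleImageDataset
#                                            # other -> JSONL fallback (no eval split)
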
class LazyVQGANEvaluator:
    """A lazy-loading wrapper for VQGANEvaluator that delays inception model initialization."""

    def __init__(self, device, enable_rfid=True, enable_inception_score=True,
                 enable_codebook_usage_measure=False, enable_codebook_entropy_measure=False,
                 num_codebook_entries=1024, accelerator=None):
        self._device = device
        self._enable_rfid = enable_rfid
        self._enable_inception_score = enable_inception_score
        self._enable_codebook_usage_measure = enable_codebook_usage_measure
        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
        self._num_codebook_entries = num_codebook_entries
        self._accelerator = accelerator
        self._evaluator = None
        self._initialized = False

    def _ensure_initialized(self):
        """Initialize the real evaluator only when needed."""
        if not self._initialized:
            if self._accelerator and self._accelerator.num_processes > 1:
                # Let the main process download/pre-load the inception model
                # first so the other ranks hit a warm cache.
                if self._accelerator.is_main_process:
                    try:
                        from evaluator.inception import get_inception_model
                        _ = get_inception_model()
                    except Exception as e:
                        print(f"Warning: Failed to pre-load inception model: {e}")

            if self._accelerator:
                self._accelerator.wait_for_everyone()

            try:
                self._evaluator = VQGANEvaluator(
                    device=self._device,
                    enable_rfid=self._enable_rfid,
                    enable_inception_score=self._enable_inception_score,
                    enable_codebook_usage_measure=self._enable_codebook_usage_measure,
                    enable_codebook_entropy_measure=self._enable_codebook_entropy_measure,
                    num_codebook_entries=self._num_codebook_entries
                )
                self._initialized = True
            except Exception as e:
                print(f"Warning: Failed to create VQGANEvaluator, using dummy: {e}")

                class DummyEvaluator:
                    def reset_metrics(self): pass
                    def update(self, real_images, fake_images, codebook_indices=None): pass
                    def result(self):
                        return {"InceptionScore": 0.0, "rFID": 0.0, "CodebookUsage": 0.0, "CodebookEntropy": 0.0}

                self._evaluator = DummyEvaluator()
                self._initialized = True

    def reset_metrics(self):
        self._ensure_initialized()
        return self._evaluator.reset_metrics()

    def update(self, real_images, fake_images, codebook_indices=None):
        self._ensure_initialized()
        return self._evaluator.update(real_images, fake_images, codebook_indices)

    def result(self):
        self._ensure_initialized()
        return self._evaluator.result()

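# The rank-0 pre-load in _ensure_initialized follows the usual shared-cache
# pattern: only the main process fetches the inception weights, the other
# ranks wait at the barrier, and then every rank constructs its evaluator
# from the already-warm cache instead of racing on the same files.
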
def create_evaluator(config, logger, accelerator):
    """Creates evaluator."""
    logger.info("Creating evaluator.")

    if config.model.vq_model.get("quantize_mode", "vq") in ["vq", "softvq", "mvq"]:
        evaluator = LazyVQGANEvaluator(
            device=accelerator.device,
            enable_rfid=True,
            enable_inception_score=True,
            enable_codebook_usage_measure=True,
            enable_codebook_entropy_measure=True,
            num_codebook_entries=config.model.vq_model.codebook_size,
            accelerator=accelerator
        )
    elif config.model.vq_model.get("quantize_mode", "vq") == "vae":
        evaluator = LazyVQGANEvaluator(
            device=accelerator.device,
            enable_rfid=True,
            enable_inception_score=True,
            enable_codebook_usage_measure=False,
            enable_codebook_entropy_measure=False,
            accelerator=accelerator
        )
    else:
        raise NotImplementedError

    logger.info("Lazy evaluator creation completed.")
    return evaluator

def auto_resume(config, logger, accelerator, ema_model,
                num_update_steps_per_epoch, strict=True):
    """Auto-resumes training from the latest checkpoint, if one exists."""
    global_step = 0
    first_epoch = 0
    if config.experiment.resume:
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            local_ckpt_list = list(glob.glob(os.path.join(
                config.experiment.output_dir, "checkpoint*")))
            logger.info(f"All globbed checkpoints are: {local_ckpt_list}")
        else:
            local_ckpt_list = []

        if accelerator.num_processes > 1:
            # Broadcast the checkpoint count, then the chosen path (encoded as
            # a tensor of character codes) from the main process to all ranks.
            checkpoint_count = torch.tensor(len(local_ckpt_list), device=accelerator.device)
            accelerate.utils.broadcast(checkpoint_count, 0)

            if checkpoint_count > 0:
                if accelerator.is_main_process:
                    if len(local_ckpt_list) > 1:
                        fn = lambda x: int(x.split('/')[-1].split('-')[-1])
                        checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
                    else:
                        checkpoint_paths = local_ckpt_list
                    latest_checkpoint = checkpoint_paths[0]
                else:
                    latest_checkpoint = ""

                if accelerator.is_main_process:
                    checkpoint_path_tensor = torch.tensor([ord(c) for c in latest_checkpoint], device=accelerator.device, dtype=torch.long)
                    path_length = torch.tensor(len(latest_checkpoint), device=accelerator.device)
                else:
                    path_length = torch.tensor(0, device=accelerator.device)

                accelerate.utils.broadcast(path_length, 0)

                if not accelerator.is_main_process:
                    checkpoint_path_tensor = torch.zeros(path_length.item(), device=accelerator.device, dtype=torch.long)

                accelerate.utils.broadcast(checkpoint_path_tensor, 0)

                if not accelerator.is_main_process:
                    latest_checkpoint = ''.join([chr(c.item()) for c in checkpoint_path_tensor])

                global_step = load_checkpoint(
                    Path(latest_checkpoint),
                    accelerator,
                    logger=logger,
                    strict=strict
                )
                if config.training.use_ema:
                    ema_model.set_step(global_step)
                first_epoch = global_step // num_update_steps_per_epoch
            else:
                logger.info("Training from scratch.")
        else:
            # Single process: resume directly from the newest checkpoint.
            if len(local_ckpt_list) >= 1:
                if len(local_ckpt_list) > 1:
                    fn = lambda x: int(x.split('/')[-1].split('-')[-1])
                    checkpoint_paths = sorted(local_ckpt_list, key=fn, reverse=True)
                else:
                    checkpoint_paths = local_ckpt_list
                global_step = load_checkpoint(
                    Path(checkpoint_paths[0]),
                    accelerator,
                    logger=logger,
                    strict=strict
                )
                if config.training.use_ema:
                    ema_model.set_step(global_step)
                first_epoch = global_step // num_update_steps_per_epoch
            else:
                logger.info("Training from scratch.")

        accelerator.wait_for_everyone()
    return global_step, first_epoch

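# A simpler alternative for sharing the checkpoint path would be accelerate's
# object broadcast, which pickles arbitrary Python objects (a sketch, not what
# auto_resume above uses):
#
#   from accelerate.utils import broadcast_object_list
#
#   payload = [latest_checkpoint if accelerator.is_main_process else None]
#   broadcast_object_list(payload, from_process=0)
#   latest_checkpoint = payload[0]
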
def train_one_epoch(config, logger, accelerator,
                    model, ema_model, loss_module,
                    optimizer, discriminator_optimizer,
                    lr_scheduler, discriminator_lr_scheduler,
                    train_dataloader, eval_dataloader,
                    evaluator,
                    global_step,
                    model_type="vibetoken",
                    clip_tokenizer=None,
                    clip_encoder=None,
                    pretrained_tokenizer=None):
    """One epoch training."""
    batch_time_meter = AverageMeter()
    data_time_meter = AverageMeter()
    end = time.time()

    model.train()

    autoencoder_logs = defaultdict(float)
    discriminator_logs = defaultdict(float)
    for i, batch in enumerate(train_dataloader):
        model.train()
        if "image" in batch:
            images = batch["image"].to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
        if config.training.get("variable_resolution", False):
            any2any = config.training.variable_resolution.get("any2any", True)

            dims = config.training.variable_resolution.dim
            ratios = config.training.variable_resolution.ratio
            assert len(dims) == len(ratios), "dims and ratios must have the same length"
            input_res = tuple(random.choices(dims, weights=ratios, k=1)[0])

            if any2any:
                output_res = tuple(random.choices(dims, weights=ratios, k=1)[0])
            else:
                output_res = input_res

            images = torch.nn.functional.interpolate(images, size=output_res, mode="bilinear", align_corners=False)
            input_images = torch.nn.functional.interpolate(images, size=input_res, mode="bilinear", align_corners=False)
        else:
            input_images = images
            output_res = (None, None)

        fnames = batch["__key__"]
        data_time_meter.update(time.time() - end)

        if pretrained_tokenizer is not None:
            pretrained_tokenizer.eval()
            proxy_codes = pretrained_tokenizer.encode(images)
        else:
            proxy_codes = None

        with accelerator.accumulate([model, loss_module]):
            additional_args = {}
            if config.model.get("train_with_attention", False):
                additional_args["key_attention_mask"] = batch["attention_mask"].to(
                    accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
                )
            reconstructed_images, extra_results_dict = model(
                input_images, height=output_res[0], width=output_res[1], **additional_args)
            autoencoder_loss, loss_dict = loss_module(
                images,
                reconstructed_images,
                extra_results_dict,
                global_step,
                mode="generator",
            )

            autoencoder_logs = {}
            for k, v in loss_dict.items():
                if k in ["discriminator_factor", "d_weight"]:
                    if type(v) == torch.Tensor:
                        autoencoder_logs["train/" + k] = v.cpu().item()
                    else:
                        autoencoder_logs["train/" + k] = v
                else:
                    gathered_tensor = accelerator.gather(v)
                    autoencoder_logs["train/" + k] = gathered_tensor.mean().item()
                    del gathered_tensor

            torch.cuda.empty_cache()
            accelerator.backward(autoencoder_loss)

            if config.training.max_grad_norm is not None and accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

            # Log gradient norms before zeroing the gradients.
            if (
                accelerator.sync_gradients
                and (global_step + 1) % config.experiment.log_grad_norm_every == 0
                and accelerator.is_main_process
            ):
                log_grad_norm(model, accelerator, global_step + 1)

            optimizer.zero_grad(set_to_none=True)

            # Train the discriminator, once it is enabled.
            discriminator_logs = defaultdict(float)
            if (config.model.vq_model.finetune_decoder or model_type == "vibetoken") \
                    and accelerator.unwrap_model(loss_module).should_discriminator_be_trained(global_step):
                discriminator_loss, loss_dict_discriminator = loss_module(
                    images,
                    reconstructed_images,
                    extra_results_dict,
                    global_step=global_step,
                    mode="discriminator",
                )

                for k, v in loss_dict_discriminator.items():
                    if k in ["logits_real", "logits_fake"]:
                        if type(v) == torch.Tensor:
                            discriminator_logs["train/" + k] = v.cpu().item()
                        else:
                            discriminator_logs["train/" + k] = v
                    else:
                        gathered_tensor = accelerator.gather(v)
                        discriminator_logs["train/" + k] = gathered_tensor.mean().item()
                        del gathered_tensor

                torch.cuda.empty_cache()
                accelerator.backward(discriminator_loss)

                if config.training.max_grad_norm is not None and accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(loss_module.parameters(), config.training.max_grad_norm)

                discriminator_optimizer.step()
                discriminator_lr_scheduler.step()

                if (
                    accelerator.sync_gradients
                    and (global_step + 1) % config.experiment.log_grad_norm_every == 0
                    and accelerator.is_main_process
                ):
                    log_grad_norm(loss_module, accelerator, global_step + 1)

                discriminator_optimizer.zero_grad(set_to_none=True)

        if accelerator.sync_gradients:
            if config.training.use_ema:
                ema_model.step(model.parameters())
            batch_time_meter.update(time.time() - end)
            end = time.time()

            if (global_step + 1) % config.experiment.log_every == 0:
                samples_per_second_per_gpu = (
                    config.training.gradient_accumulation_steps * config.training.per_gpu_batch_size / batch_time_meter.val
                )

                lr = lr_scheduler.get_last_lr()[0]
                logger.info(
                    f"Data (t): {data_time_meter.val:0.4f}, {samples_per_second_per_gpu:0.2f}/s/gpu "
                    f"Batch (t): {batch_time_meter.val:0.4f} "
                    f"LR: {lr:0.6f} "
                    f"Step: {global_step + 1} "
                    f"Total Loss: {autoencoder_logs['train/total_loss']:0.4f} "
                    f"Recon Loss: {autoencoder_logs['train/reconstruction_loss']:0.4f} "
                )
                logs = {
                    "lr": lr,
                    "lr/generator": lr,
                    "samples/sec/gpu": samples_per_second_per_gpu,
                    "time/data_time": data_time_meter.val,
                    "time/batch_time": batch_time_meter.val,
                }
                logs.update(autoencoder_logs)
                logs.update(discriminator_logs)
                accelerator.log(logs, step=global_step + 1)

                del autoencoder_logs, discriminator_logs, logs
                gc.collect()

                batch_time_meter.reset()
                data_time_meter.reset()

            # Save the model checkpoint.
            if (global_step + 1) % config.experiment.save_every == 0:
                save_path = save_checkpoint(
                    model, config.experiment.output_dir, accelerator, global_step + 1, logger=logger)
                accelerator.wait_for_everyone()

            # Generate reconstructions for visualization.
            if (global_step + 1) % config.experiment.generate_every == 0:
                if accelerator.is_main_process:
                    if config.training.get("use_ema", False):
                        ema_model.store(model.parameters())
                        ema_model.copy_to(model.parameters())

                    reconstruct_images(
                        model,
                        images[:config.training.num_generated_images],
                        fnames[:config.training.num_generated_images],
                        accelerator,
                        global_step + 1,
                        config.experiment.output_dir,
                        logger=logger,
                        config=config,
                        pretrained_tokenizer=pretrained_tokenizer
                    )

                    if config.training.get("use_ema", False):
                        ema_model.restore(model.parameters())

                accelerator.wait_for_everyone()

            # Evaluate reconstruction metrics on the validation set.
            if eval_dataloader is not None and (global_step + 1) % config.experiment.eval_every == 0:
                logger.info("Computing metrics on the validation set.")
                if config.training.get("use_ema", False):
                    ema_model.store(model.parameters())
                    ema_model.copy_to(model.parameters())
                    eval_scores = eval_reconstruction(
                        config,
                        model,
                        eval_dataloader,
                        accelerator,
                        evaluator,
                        pretrained_tokenizer=pretrained_tokenizer
                    )
                    logger.info(
                        f"EMA EVALUATION "
                        f"Step: {global_step + 1} "
                    )
                    logger.info(pprint.pformat(eval_scores))
                    if accelerator.is_main_process:
                        eval_log = {'ema_eval/' + k: v for k, v in eval_scores.items()}
                        accelerator.log(eval_log, step=global_step + 1)
                    # Switch back to the original parameters for training.
                    ema_model.restore(model.parameters())
                else:
                    eval_scores = eval_reconstruction(
                        config,
                        model,
                        eval_dataloader,
                        accelerator,
                        evaluator,
                        pretrained_tokenizer=pretrained_tokenizer
                    )

                    logger.info(
                        f"Non-EMA EVALUATION "
                        f"Step: {global_step + 1} "
                    )
                    logger.info(pprint.pformat(eval_scores))
                    if accelerator.is_main_process:
                        eval_log = {'eval/' + k: v for k, v in eval_scores.items()}
                        accelerator.log(eval_log, step=global_step + 1)

                accelerator.wait_for_everyone()

            global_step += 1

            if global_step >= config.training.max_train_steps:
                accelerator.print(
                    f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
                )
                break

    return global_step

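# The loop above implements the standard two-phase GAN update; schematically
# (omitting accumulation, clipping, logging, and EMA):
#
#   recon, extra = model(x)
#   g_loss, _ = loss_module(x, recon, extra, step, mode="generator")
#   accelerator.backward(g_loss); optimizer.step(); optimizer.zero_grad()
#
#   if loss_module.should_discriminator_be_trained(step):
#       d_loss, _ = loss_module(x, recon, extra, global_step=step,
#                               mode="discriminator")
#       accelerator.backward(d_loss); discriminator_optimizer.step()
#       discriminator_optimizer.zero_grad()
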
@torch.no_grad()
def eval_reconstruction(
    config,
    model,
    eval_loader,
    accelerator,
    evaluator,
    pretrained_tokenizer=None
):
    model.eval()
    evaluator.reset_metrics()
    local_model = accelerator.unwrap_model(model)

    accelerator.wait_for_everyone()

    for batch in eval_loader:
        images = batch["image"].to(
            accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
        )

        original_images = torch.clone(images)
        additional_args = {}
        if config.model.get("eval_with_attention", False):
            additional_args["key_attention_mask"] = batch["attention_mask"].to(
                accelerator.device, memory_format=torch.contiguous_format, non_blocking=True
            )
        reconstructed_images, model_dict = local_model(images, **additional_args)

        if pretrained_tokenizer is not None:
            reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))
        # Quantize to 8-bit pixel values to match the real images.
        reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0)
        reconstructed_images = torch.round(reconstructed_images * 255.0) / 255.0
        original_images = torch.clamp(original_images, 0.0, 1.0)

        if isinstance(model_dict, dict):
            evaluator.update(original_images, reconstructed_images.squeeze(2), model_dict["min_encoding_indices"])
        else:
            evaluator.update(original_images, reconstructed_images.squeeze(2), None)

    accelerator.wait_for_everyone()

    local_results = evaluator.result()

    if accelerator.num_processes > 1:
        # Average each process-local scalar metric across ranks.
        gathered_results = {}
        for key, value in local_results.items():
            if isinstance(value, (int, float)):
                value_tensor = torch.tensor(value, device=accelerator.device)
                gathered_values = accelerator.gather(value_tensor)
                gathered_results[key] = gathered_values.mean().item()
            else:
                gathered_results[key] = value

        accelerator.wait_for_everyone()
        model.train()
        return gathered_results
    else:
        model.train()
        return local_results

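# Note: the multi-process branch above averages each rank's local metrics.
# For per-sample metrics this equals the global value only when every rank
# sees the same number of images; set-level metrics such as rFID are
# approximated rather than recomputed over the pooled samples.
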
@torch.no_grad()
def reconstruct_images(model, original_images, fnames, accelerator,
                       global_step, output_dir, logger, config=None,
                       pretrained_tokenizer=None):
    logger.info("Reconstructing images...")
    original_images = torch.clone(original_images)
    _, _, height, width = original_images.shape
    model.eval()
    dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        dtype = torch.bfloat16

    with torch.autocast("cuda", dtype=dtype, enabled=accelerator.mixed_precision != "no"):
        enc_tokens, encoder_dict = accelerator.unwrap_model(model).encode(original_images)
        reconstructed_images = accelerator.unwrap_model(model).decode(enc_tokens, height=height, width=width)
    if pretrained_tokenizer is not None:
        reconstructed_images = pretrained_tokenizer.decode(reconstructed_images.argmax(1))

    images_for_saving, images_for_logging = make_viz_from_samples(
        original_images,
        reconstructed_images
    )
    if config.training.enable_wandb:
        accelerator.get_tracker("wandb").log_images(
            {"Train Reconstruction": images_for_saving},
            step=global_step
        )
    else:
        accelerator.get_tracker("tensorboard").log_images(
            {"Train Reconstruction": images_for_logging}, step=global_step
        )
    root = Path(output_dir) / "train_images"
    os.makedirs(root, exist_ok=True)
    for i, img in enumerate(images_for_saving):
        filename = f"{global_step:08}_s-{i:03}-{fnames[i]}.png"
        path = os.path.join(root, filename)
        img.save(path)

    model.train()

def save_checkpoint(model, output_dir, accelerator, global_step, logger) -> Path:
    save_path = Path(output_dir) / f"checkpoint-{global_step}"

    state_dict = accelerator.get_state_dict(model)
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained_weight(
            save_path / "unwrapped_model",
            save_function=accelerator.save,
            state_dict=state_dict,
        )
        json.dump({"global_step": global_step}, (save_path / "metadata.json").open("w+"))
        logger.info(f"Saved state to {save_path}")

    accelerator.save_state(save_path)
    return save_path

def load_checkpoint(checkpoint_path: Path, accelerator, logger, strict=True):
    logger.info(f"Load checkpoint from {checkpoint_path}")

    accelerator.load_state(checkpoint_path, strict=strict)

    with open(checkpoint_path / "metadata.json", "r") as f:
        global_step = int(json.load(f)["global_step"])

    logger.info(f"Resuming at global_step {global_step}")
    return global_step

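# Checkpoint layout produced by save_checkpoint and consumed by load_checkpoint
# (directory names as used above):
#
#   checkpoint-<global_step>/
#     metadata.json        # {"global_step": ...}
#     unwrapped_model/     # plain model weights, loadable without accelerate
#     ...                  # optimizer/scheduler/RNG state from save_state
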
def log_grad_norm(model, accelerator, global_step):
    for name, param in model.named_parameters():
        if param.grad is not None:
            grads = param.grad.detach().data
            # Note: this logs the L2 norm normalized by the number of
            # elements, not the raw per-parameter gradient norm.
            grad_norm = (grads.norm(p=2) / grads.numel()).item()
            accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)