| | """Training script for VibeToken. |
| | |
| | Reference: |
| | https://github.com/huggingface/open-muse |
| | """ |
| | import math |
| | import os |
| | import sys |
| | from pathlib import Path |
| | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) |
| | sys.path.append(parent_dir) |
| |
|
| | from accelerate.utils import set_seed |
| | from accelerate import Accelerator |
| |
|
| | import torch |
| | import wandb |
| | from omegaconf import OmegaConf |
| | from utils.logger import setup_logger |
| |
|
| | from utils.train_utils import ( |
| | get_config, create_pretrained_tokenizer, |
| | create_model_and_loss_module, |
| | create_optimizer, create_lr_scheduler, create_dataloader, |
| | create_evaluator, auto_resume, save_checkpoint, |
| | train_one_epoch) |
| |
|
| |
|
| | def main(): |
| | workspace = os.environ.get('WORKSPACE', '') |
| | if workspace: |
| | torch.hub.set_dir(workspace + "/models/hub") |
| |
|
| | config = get_config() |
| | |
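    # TF32 trades a few mantissa bits of matmul precision for a large speedup
    # on Ampere and newer GPUs; cudnn.benchmark additionally autotunes conv
    # kernels for fixed input shapes, at the cost of run-to-run determinism.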
    if config.training.enable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    output_dir = config.experiment.output_dir
    os.makedirs(output_dir, exist_ok=True)
    config.experiment.logging_dir = os.path.join(output_dir, "logs")

    tracker = "tensorboard"
    if config.training.enable_wandb:
        tracker = "wandb"

    accelerator = Accelerator(
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        mixed_precision=config.training.mixed_precision,
        log_with=tracker,
        project_dir=config.experiment.logging_dir,
        split_batches=False,
    )

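    # One log file per rank so multi-GPU output does not interleave.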
    logger = setup_logger(name="VibeToken", log_level="INFO",
                          output_file=f"{output_dir}/log{accelerator.process_index}.txt")

    if accelerator.is_main_process:
        if config.training.enable_wandb:
            wandb_config = config.training.get("wandb", {})
            wandb_project = wandb_config.get("project", config.experiment.project)
            wandb_entity = wandb_config.get("entity", None)
            wandb_name = wandb_config.get("name", config.experiment.name)
            wandb_tags = list(wandb_config.get("tags", []))
            wandb_notes = wandb_config.get("notes", None)
            wandb_resume_id = wandb_config.get("resume_id", None)

            wandb_init_kwargs = {
                "wandb": {
                    "name": wandb_name,
                    "dir": output_dir,
                    "resume": "allow",
                }
            }
            if wandb_entity:
                wandb_init_kwargs["wandb"]["entity"] = wandb_entity
            if wandb_tags:
                wandb_init_kwargs["wandb"]["tags"] = wandb_tags
            if wandb_notes:
                wandb_init_kwargs["wandb"]["notes"] = wandb_notes
            if wandb_resume_id:
                wandb_init_kwargs["wandb"]["id"] = wandb_resume_id

            accelerator.init_trackers(
                project_name=wandb_project,
                config=OmegaConf.to_container(config, resolve=True),
                init_kwargs=wandb_init_kwargs,
            )
            logger.info(f"WandB initialized - Project: {wandb_project}, Name: {wandb_name}")
        else:
            accelerator.init_trackers(config.experiment.name)

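        # Persist the fully resolved config next to the checkpoints so the
        # run can be reproduced from the output directory alone.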
        config_path = Path(output_dir) / "config.yaml"
        logger.info(f"Saving config to {config_path}")
        OmegaConf.save(config, config_path)
        logger.info(f"Config:\n{OmegaConf.to_yaml(config)}")

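    # device_specific=True offsets the seed by process rank, so each worker
    # draws a different data/augmentation stream while staying reproducible.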
    if config.training.seed is not None:
        set_seed(config.training.seed, device_specific=True)

    accelerator.wait_for_everyone()

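    # Only the legacy VQ model path relies on a separate pretrained tokenizer.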
    if config.model.vq_model.is_legacy:
        if accelerator.is_main_process:
            logger.info("Creating pretrained tokenizer on main process...")
        accelerator.wait_for_everyone()
        pretrained_tokenizer = create_pretrained_tokenizer(config, accelerator)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            logger.info("Pretrained tokenizer creation completed.")
    else:
        pretrained_tokenizer = None

    if accelerator.is_main_process:
        logger.info("Creating model and loss module...")
    accelerator.wait_for_everyone()

    model, ema_model, loss_module = create_model_and_loss_module(
        config, logger, accelerator, model_type="vibetoken")

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Model creation completed.")

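    # The loss module carries its own trainable parameters (e.g. a GAN
    # discriminator), hence the second optimizer and LR scheduler below.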
    optimizer, discriminator_optimizer = create_optimizer(config, logger, model, loss_module, model_type="vibetoken")

    lr_scheduler, discriminator_lr_scheduler = create_lr_scheduler(
        config, logger, accelerator, optimizer, discriminator_optimizer)

    if accelerator.is_main_process:
        logger.info("Creating dataloaders...")
    train_dataloader, eval_dataloader = create_dataloader(config, logger, accelerator)
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        logger.info("Setting up evaluator...")
    evaluator = create_evaluator(config, logger, accelerator)

    logger.info("Preparing model, optimizer and dataloaders")

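    # accelerator.prepare() wraps modules for distributed training and moves
    # them to device. The discriminator objects are only prepared on paths
    # that actually train them; legacy encoder-only finetuning skips them.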
    if config.model.vq_model.is_legacy:
        if config.model.vq_model.finetune_decoder:
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
                model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
            )
        else:
            model, optimizer, lr_scheduler = accelerator.prepare(
                model, optimizer, lr_scheduler
            )
    else:
        model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler = accelerator.prepare(
            model, loss_module, optimizer, discriminator_optimizer, lr_scheduler, discriminator_lr_scheduler
        )

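    # The EMA copy is updated manually, so it stays outside prepare() and
    # only needs to live on the right device.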
    if config.training.use_ema:
        ema_model.to(accelerator.device)

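    # Epoch bookkeeping derived from the step budget:
    #   batches/epoch = ceil(max_train_examples / (per_gpu_batch_size * num_processes))
    #   updates/epoch = ceil(batches / gradient_accumulation_steps)
    #   epochs        = ceil(max_train_steps / updates_per_epoch)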
    total_batch_size_without_accum = config.training.per_gpu_batch_size * accelerator.num_processes
    num_batches = math.ceil(
        config.experiment.max_train_examples / total_batch_size_without_accum)
    num_update_steps_per_epoch = math.ceil(num_batches / config.training.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.training.max_train_steps / num_update_steps_per_epoch)

| | logger.info("***** Running training *****") |
| | logger.info(f" Num training steps = {config.training.max_train_steps}") |
| | logger.info(f" Gradient Accumulation steps = {config.training.gradient_accumulation_steps}") |
| | logger.info(f" Instantaneous batch size per gpu = { config.training.per_gpu_batch_size}") |
| | logger.info(f""" Total train batch size (w. parallel, distributed & accumulation) = {( |
| | config.training.per_gpu_batch_size * |
| | accelerator.num_processes * |
| | config.training.gradient_accumulation_steps)}""") |
| | global_step = 0 |
| | first_epoch = 0 |
| |
|
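    # auto_resume restores training state from an existing checkpoint, if one
    # is found, and returns the step/epoch to continue from.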
    global_step, first_epoch = auto_resume(
        config, logger, accelerator, ema_model, num_update_steps_per_epoch,
        strict=True)

    for current_epoch in range(first_epoch, num_train_epochs):
        accelerator.print(f"Epoch {current_epoch}/{num_train_epochs-1} started.")
        global_step = train_one_epoch(config, logger, accelerator,
                                      model, ema_model, loss_module,
                                      optimizer, discriminator_optimizer,
                                      lr_scheduler, discriminator_lr_scheduler,
                                      train_dataloader, eval_dataloader,
                                      evaluator,
                                      global_step,
                                      pretrained_tokenizer=pretrained_tokenizer,
                                      model_type="vibetoken")

        if global_step >= config.training.max_train_steps:
            accelerator.print(
                f"Finishing training: Global step is >= Max train steps: {global_step} >= {config.training.max_train_steps}"
            )
            break

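    # Final sync and checkpoint. When EMA is enabled, fold the averaged
    # weights into the unwrapped model before exporting it.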
    accelerator.wait_for_everyone()

    save_checkpoint(model, output_dir, accelerator, global_step, logger=logger)

    if accelerator.is_main_process:
        model = accelerator.unwrap_model(model)
        if config.training.use_ema:
            ema_model.copy_to(model.parameters())
        model.save_pretrained_weight(output_dir)

    if accelerator.is_main_process and config.training.enable_wandb:
        wandb.finish()
        logger.info("WandB run finished")
    accelerator.end_training()

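# Illustrative launch command (exact flags depend on how get_config() in
# utils/train_utils.py parses the CLI; the config path here is hypothetical):
#   accelerate launch --multi_gpu train_vibetoken.py config=configs/vibetoken.yaml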
if __name__ == "__main__":
    main()