# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import json
import os
import random
import sys
import time
from collections import OrderedDict
from typing import Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist
from infinity.utils.sequence_parallel import SequenceParallelManager as sp_manager


class Args(Tap):
    # ==================================================================================================================
    # ============================================= Paths and Directories =============================================
    # ==================================================================================================================
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # Directory to save checkpoints
    data_path: str = '' # Path to the image dataset
    video_data_path: str = '' # Path to the video dataset
    bed: str = '' # Directory to copy checkpoints to, in addition to local_out_path
    vae_path: str = '' # Path to the VAE checkpoint
    log_txt_path: str = '' # Path to the log file
    t5_path: str = '' # Path to the T5 model; if not specified, it is located automatically
    token_cache_dir: str = '' # Directory for the token cache
    # ==================================================================================================================
    # ================================================ General Training ===============================================
    # ==================================================================================================================
    exp_name: str = '' # Experiment name
    project_name: str = 'infinitystar' # Name of the wandb project
    tf32: bool = True # Whether to use TensorFloat32
    auto_resume: bool = True # Whether to automatically resume from the last checkpoint
    rush_resume: str = '' # Path to a pretrained Infinity checkpoint for rush resume
    rush_omnistore_resume: str = '' # Path to an omnistore pretrained checkpoint for rush resume
    torchshard_resume: str = '' # Path to a torch-sharded checkpoint to resume from
    log_every_iter: bool = False # Whether to log every iteration
    checkpoint_type: str = 'torch' # Type of checkpoint: 'torch' or 'omnistore'
    device: str = 'cpu' # Device to use for training ('cpu' or 'cuda')
    is_master_node: bool = None # Whether the current node is the master node
    epoch: int = 300 # Number of training epochs
    log_freq: int = 1 # Logging frequency in stdout
    save_model_iters_freq: int = 1000 # Frequency of saving the model, in iterations
    short_cap_prob: float = 0.2 # Probability of training with short captions
    label_smooth: float = 0.0 # Label smoothing factor
    cfg: float = 0.1 # Classifier-free guidance dropout probability
    rand_uncond: bool = False # Whether to use a random, unlearnable unconditional embedding
    twoclip_alternatingtraining: int = 0 # Whether to use two-clip alternating training
    wp_it: int = 100 # Warm-up iterations
    # ==================================================================================================================
    # ===================================================== Model =====================================================
    # ==================================================================================================================
    model: str = '' # Model name; empty means VAE training, non-empty (e.g., an alias in infinity.models.alias_dict) means GPT training
    sdpa_mem: bool = True # Whether to use memory-efficient SDPA
    rms_norm: bool = False # Whether to use RMS normalization
    tau: float = 1 # Tau of self-attention in GPT
    tini: float = -1 # Initialization parameter
    topp: float = 0.0 # Top-p (nucleus) sampling threshold
    topk: float = 0.0 # Top-k sampling threshold
    fused_norm: bool = False # Whether to use fused normalization
    flash: bool = False # Whether to use a customized flash-attention kernel
    use_flex_attn: bool = False # Whether to use flex_attn to speed up training
    norm_eps: float = 1e-6 # Epsilon for normalization layers
    Ct5: int = 2048 # Feature dimension of the text encoder
    simple_text_proj: int = 1 # Whether to use a simple text projection
    mask_type: str = 'infinity_elegant_clip20frames_v2' # Self-attention mask type (e.g., 'var', 'video_tower')
    mask_video_first_frame: int = 0 # Whether to mask the first frame of the video when calculating the loss
    use_fsdp_model_ema: int = 0 # Whether to use FSDP model EMA
    model_ema_decay: float = 0.9999 # Model EMA decay rate
    rope_type: str = '4d' # RoPE type ('2d', '3d', or '4d')
    rope2d_each_sa_layer: int = 1 # Apply RoPE2D to each self-attention layer
    rope2d_normalized_by_hw: int = 2 # Apply normalized RoPE2D
    add_lvl_embeding_on_first_block: int = 0 # Add the level (scale) positional embedding only at the first block
    # ==================================================================================================================
    # ================================================= Scale Schedule ================================================
    # ==================================================================================================================
    semantic_scales: int = 8 # Number of semantic scales
    semantic_scale_dim: int = 16 # Dimension of semantic scales
    detail_scale_dim: int = 64 # Dimension of detail scales
    use_learnable_dim_proj: int = 0 # Whether to use a learnable dimension projection
    detail_scale_min_tokens: int = 80 # Minimum number of tokens for a detail scale
    pn: str = '' # Pixel number, chosen from '0.06M', '0.25M', '1M'
    scale_schedule: tuple = None # [Automatically set] Scale schedule based on pn
    patch_size: int = None # [Automatically set] Patch size based on scale_schedule
    dynamic_scale_schedule: str = '' # Dynamic scale schedule for video
    min_scale_ind: int = 3 # Minimum scale index for infinity frame pack
    max_reweight_value: int = 40 # Clipping value for reweighting
    image_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]' # Repetition for image scales
    video_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]' # Repetition for video scales
    inner_scale_boost: int = 0 # Whether to boost inner scales
    drop_720p_last_scale: int = 1 # Whether to drop the last scale for 720p
    reweight_loss_by_scale: int = 0 # Whether to reweight the loss by scale
    # ==================================================================================================================
    # ================================================== Optimization =================================================
    # ==================================================================================================================
    tlr: float = 2e-5 # Learning rate
    grad_clip: float = 5 # Gradient clipping threshold
    cdec: bool = False # Whether to decay the gradient clipping threshold
    opt: str = 'adamw' # Optimizer type ('adamw' or 'lion')
    ada: str = '0.9_0.97' # Adam's beta parameters (e.g., '0.9_0.999')
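    # For reference (assumption: the exact parsing lives in the trainer/optimizer setup, not in this file), `ada` is
    # expected to encode the two Adam betas joined by an underscore, roughly:
    #   beta1, beta2 = map(float, ada.split('_'))   # '0.9_0.97' -> (0.9, 0.97)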
    adam_eps: float = 0.0 # Adam's epsilon
    fused_adam: bool = True # Whether to use the fused Adam optimizer
    disable_weight_decay: int = 1 # Whether to disable weight decay on sparse params
    fp16: int = 2 # Floating-point precision: 1 for fp16, 2 for bf16
    # ==================================================================================================================
    # ====================================================== Data =====================================================
    # ==================================================================================================================
    video_fps: int = 16 # Frames per second for video
    video_frames: int = 81 # Number of frames per video
    video_batch_size: int = 1 # Batch size for video data
    workers: int = 16 # Number of dataloader workers
    image_batch_size: int = 0 # [Automatically set] Batch size per GPU for image data
    ac: int = 1 # Gradient accumulation steps
    r_accu: float = 1.0 # [Automatically set] Reciprocal of the gradient accumulation steps
    tlen: int = 512 # Truncate text embeddings to this length
    num_of_label_value: int = 2 # Number of label values (2 for bitwise, 0 for index-wise)
    dynamic_resolution_across_gpus: int = 1 # Allow dynamic resolution across GPUs
    enable_dynamic_length_prompt: int = 0 # Enable dynamic-length prompts during training
    use_streaming_dataset: int = 0 # Whether to use a streaming dataset
    iterable_data_buffersize: int = 90000 # Buffer size for the streaming dataset
    image_batches_multiply: float = 1.0 # Multiplier for the number of image batches per epoch
    down_size_limit: int = 10000 # Download size limit for videos, in MB
    addition_pn_list: str = '[]' # Additional pixel-number list
    video_caption_type: str = 'tarsier2_caption' # Type of video caption to use
    only_images4extract_feats: int = 0 # Whether to extract features for images only
    train_max_token_len: int = -1 # Maximum token length for training
    train_with_var_seq_len: int = 0 # Whether to train with variable sequence lengths
    video_var_len_prob: str = '[30, 30, 30, 5, 3, 2]' # Probability distribution over variable video lengths
    duration_resolution: int = 1 # Resolution for duration
    seq_pack_bucket: int = 1000 # Bucket size for sequence packing
    drop_long_video: int = 0 # Whether to drop long videos
    min_video_frames: int = -1 # Minimum number of video frames
    restrict_data_size: int = -1 # Restrict the size of the dataset
    allow_less_one_elem_in_seq: int = 0 # Allow sequences with fewer than one element
    train_192pshort: int = 0 # Whether to train with 192p short videos
    steps_per_frame: int = 3 # Steps per frame for the video tower
    add_motion_score2caption: int = 0 # Whether to prepend the motion score to the caption
    context_frames: int = 10000 # Context frames for the video tower
    cached_video_frames: int = 81 # Number of cached video frames
    frames_inner_clip: int = 20 # Number of frames in a clip for infinity frame pack
    context_interval: int = 2 # Context interval
    context_from_largest_no: int = 1 # Context from the largest number
    append_duration2caption: int = 0 # Whether to append the duration to the caption
    cache_check_mode: int = 0 # Cache check mode
    online_t5: bool = True # Whether to run T5 online or load local features
    # ==================================================================================================================
    # ============================================== Distributed Training =============================================
    # ==================================================================================================================
    enable_hybrid_shard: bool = False # Whether to use hybrid-sharded FSDP
    inner_shard_degree: int = 8 # Inner shard degree for FSDP
    zero: int = 0 # DeepSpeed ZeRO stage
    buck: str = 'chunk' # Module-wise bucketing for FSDP
    fsdp_orig: bool = True # Whether to use original FSDP
    enable_checkpointing: str = None # Activation checkpointing strategy: 'full-block', 'full-attn', or 'self-attn'
    pad_to_multiplier: int = 128 # Pad the sequence length to a multiple of this value
    sp_size: int = 0 # Sequence parallelism size
    fsdp_save_flatten_model: int = 1 # Whether to save the flattened model in FSDP
    inject_sync: int = 0 # Whether to inject synchronization
    model_init_device: str = 'cuda' # Device for model initialization
    fsdp_init_device: str = 'cuda' # Device for FSDP initialization
    # ==================================================================================================================
    # ======================================================= VAE =====================================================
    # ==================================================================================================================
    vae_type: int = 64 # VAE type (e.g., 16/32/64 BSQ-VAE quantization bits)
    fake_vae_input: bool = False # Whether to use fake VAE input for debugging
    use_slice: int = 1 # Whether to use slicing for VAE encoding
    use_vae_token_cache: int = 1 # Whether to use the VAE token cache
    save_vae_token_cache: int = 0 # Whether to save the VAE token cache
    allow_online_vae_feature_extraction: int = 1 # Allow online VAE feature extraction
    use_text_token_cache: int = 0 # Whether to use the text token cache
    videovae: int = 10 # Whether to use a video VAE
    use_feat_proj: int = 2 # Whether to use feature projection
    use_two_stage_lfq: int = 0 # Whether to use two-stage LFQ
    casual_multi_scale: int = 0 # Whether to use causal multi-scale
    temporal_compress_rate: int = 4 # Temporal compression rate
    apply_spatial_patchify: int = 0 # Whether to apply spatial patchify
    # ==================================================================================================================
    # ============================================ Bitwise Self-Correction ============================================
    # ==================================================================================================================
    noise_apply_layers: int = 1000 # Number of scales/layers to which noise is applied
    noise_apply_strength: str = '-1' # Noise strength (a single value or a comma-separated list)
    noise_apply_requant: int = 1 # Requantize after applying noise
    noise_apply_random_one: int = 0 # Requantize only one randomly chosen scale
    debug_bsc: int = 0 # Save figures and set breakpoints for debugging BSC
    noise_input: int = 0 # Whether to add noise to the input
    reduce_accumulate_error_method: str = 'bsc' # Method to reduce accumulated error
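    # Note on `noise_apply_strength`: init_dist_and_get_args() below converts the comma-separated string into a list
    # of floats, so for example '0.1,0.2' becomes [0.1, 0.2] and the default '-1' becomes [-1.0].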
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    # Automatically set at runtime
    branch: str = '' # subprocess.check_output('git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_id: str = '' # subprocess.check_output('git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_msg: str = '' # (subprocess.check_output('git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:]) # [automatically set; don't specify this]
    tag: str = 'UK' # [automatically set; don't specify this]
    cur_it: str = '' # [automatically set; don't specify this]
    MFU: float = None # [automatically set; don't specify this]
    HFU: float = None # [automatically set; don't specify this]
    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use =============================
    # ==================================================================================================================
    dbg: bool = 'KEVIN_LOCAL' in os.environ # only used when debugging unused params in DDP
    prof: int = 0 # profile
    prof_freq: int = 50 # profiling frequency
    profall: int = 0
    # ==================================================================================================================
    # ======================== ignore these parts above since they are only for debug use =============================
    # ==================================================================================================================

    @property
    def gpt_training(self):
        return len(self.model) > 0

    def set_initial_seed(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        assert self.seed
        seed = self.seed
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
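    # Note: `seed` is not a CLI argument here; init_dist_and_get_args() below assigns it at runtime
    # (int(time.time()), synced to the minimum value across ranks), so call that first before set_initial_seed().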

    def dump_log(self):
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        for k, v in {
            'name': self.exp_name,
            'tag': self.tag,
            'cmd': self.cmd,
            'commit': self.commit_id,
            'branch': self.branch,
            'cur_it': self.cur_it,
            'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'opt': self.opt,
            'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'):
                v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0):
                continue
            nd[k] = v
        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)
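    # The resulting log.txt is a small JSON record; a representative (illustrative) dump looks like:
    # {
    #   "is_master": true,
    #   "name": "my_exp",
    #   "tag": "UK",
    #   "cmd": "--model=... --pn=1M",
    #   "last_upd": "2025-01-01 12:00",
    #   "opt": "adamw"
    # }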

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str): # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            if k in {'is_large_model', 'gpt_training'}:
                continue
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
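    # Typical round trip (sketch; the exact checkpoint layout lives in the trainer): `args.state_dict()` is saved
    # alongside the model checkpoint and later restored via `args.load_state_dict(saved_args)`; `device` is
    # intentionally excluded and re-derived at startup.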

    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
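    # Usage note: this is invoked once from init_dist_and_get_args() as `args.set_tf32(args.tf32)`; TF32 only takes
    # effect on Ampere-or-newer GPUs, and passing False restores full fp32 matmul precision ('highest').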

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
                s.append(f' {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if len(args.extra_args) > 0 and args.is_master_node == 0:
        print('======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print('=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print('======================================================================================\n\n')
    args.set_tf32(args.tf32)
    try:
        os.makedirs(args.bed, exist_ok=True)
    except:
        pass
    try:
        os.makedirs(args.local_out_path, exist_ok=True)
    except:
        pass
    dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=30)
    args.device = dist.get_device()

    # sync seed
    args.seed = int(time.time())
    seed = torch.tensor([args.seed], device=args.device)
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(seed, op=torch.distributed.ReduceOp.MIN)
    args.seed = seed.item()

    if args.sp_size > 1:
        print(f"INFO: sp_size={args.sp_size}")
        sp_manager.init_sp(args.sp_size)
    args.r_accu = 1 / args.ac # gradient accumulation
    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.opt = args.opt.lower().strip()

    # gpt args
    if args.gpt_training:
        assert args.vae_path, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict
        if args.model in alias_dict:
            args.model = alias_dict[args.model]

    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no checkpointing or full-block/full-attn/self-attn checkpointing, but got {args.enable_checkpointing}."

    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'
    if '-' in args.exp_name:
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'

    if dist.is_master():
        os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')

    if args.sdpa_mem:
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)

    print(args)
    if isinstance(args.noise_apply_strength, str):
        args.noise_apply_strength = list(map(float, args.noise_apply_strength.split(',')))
    elif isinstance(args.noise_apply_strength, float):
        args.noise_apply_strength = [args.noise_apply_strength]
    return args
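

# ----------------------------------------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the actual training entry point lives elsewhere in the repo). Running this
# module directly parses the CLI flags defined above, initializes distributed state, and seeds all RNGs.
if __name__ == '__main__':
    run_args = init_dist_and_get_args()
    run_args.set_initial_seed(benchmark=True)
    print(f'[args] exp_name={run_args.exp_name!r}, opt={run_args.opt}, seed={run_args.seed}')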