# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
#
# NOTE(review): this file arrived with its line structure destroyed (everything
# collapsed onto a few mega-lines) and two spans of text missing. The visible
# code has been re-formatted faithfully; the two lost spans are reconstructed
# minimally and are marked with "NOTE(review): ... lost" comments below.
# Confirm both against version control before relying on them.
import json
import os
import random
import sys
import time
from collections import OrderedDict
from typing import Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist
from infinity.utils.sequence_parallel import SequenceParallelManager as sp_manager


class Args(Tap):
    """Typed command-line arguments for InfinityStar training.

    All fields are parsed from the CLI by `tap.Tap`; fields marked
    "[Automatically set]" are filled in at runtime and should not be passed.
    """
    # ==================================================================================================================
    # ============================================= Paths and Directories ==============================================
    # ==================================================================================================================
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # Directory to save checkpoints
    data_path: str = ''                 # Path to the image dataset
    video_data_path: str = ''           # Path to the video dataset
    bed: str = ''                       # Directory to copy checkpoints apart from local_out_path
    vae_path: str = ''                  # Path to the VAE checkpoint
    log_txt_path: str = ''              # Path to the log file
    t5_path: str = ''                   # Path to the T5 model; if not specified, it will be automatically found
    token_cache_dir: str = ''           # Directory for token cache

    # ==================================================================================================================
    # =============================================== General Training =================================================
    # ==================================================================================================================
    exp_name: str = ''                  # Experiment name
    project_name: str = 'infinitystar'  # Name of the wandb project
    tf32: bool = True                   # Whether to use TensorFloat32
    auto_resume: bool = True            # Whether to automatically resume from the last checkpoint
    rush_resume: str = ''               # Path to a pretrained infinity checkpoint for rushing resume
    rush_omnistore_resume: str = ''     # Path to an omnistore pretrained checkpoint for rushing resume
    torchshard_resume: str = ''         # Path to a torch shard checkpoint resume
    log_every_iter: bool = False        # Whether to log every iteration
    checkpoint_type: str = 'torch'      # Type of checkpoint: 'torch' or 'omnistore'
    device: str = 'cpu'                 # Device to use for training ('cpu' or 'cuda')
    is_master_node: bool = None         # Whether the current node is the master node (None until set at runtime)
    epoch: int = 300                    # Number of training epochs
    log_freq: int = 1                   # Logging frequency in stdout
    save_model_iters_freq: int = 1000   # Frequency of saving the model in iterations
    short_cap_prob: float = 0.2         # Probability of training with short captions
    label_smooth: float = 0.0           # Label smoothing factor
    cfg: float = 0.1                    # Classifier-free guidance dropout probability
    rand_uncond: bool = False           # Whether to use random, unlearnable unconditional embedding
    twoclip_alternatingtraining: int = 0  # Whether to use two-clip alternating training
    wp_it: int = 100                    # Warm-up iterations

    # ==================================================================================================================
    # ===================================================== Model ======================================================
    # ==================================================================================================================
    model: str = ''                     # Model type: 'b' for VAE training, or any other for GPT training
    sdpa_mem: bool = True               # Whether to use memory-efficient SDPA
    rms_norm: bool = False              # Whether to use RMS normalization
    tau: float = 1                      # Tau of self-attention in GPT
    tini: float = -1                    # Initialization parameters
    topp: float = 0.0                   # top-p
    topk: float = 0.0                   # top-k
    fused_norm: bool = False            # Whether to use fused normalization
    flash: bool = False                 # Whether to use customized flash-attention kernel
    use_flex_attn: bool = False         # Whether to use flex_attn to speed up training
    norm_eps: float = 1e-6              # Epsilon for normalization layers
    Ct5: int = 2048                     # Feature dimension of the text encoder
    simple_text_proj: int = 1           # Whether to use a simple text projection
    mask_type: str = 'infinity_elegant_clip20frames_v2'  # Self-attention mask type ('var' or 'video_tower')
    mask_video_first_frame: int = 0     # Whether to mask the first frame of the video when calculating loss
    use_fsdp_model_ema: int = 0         # Whether to use FSDP model EMA
    model_ema_decay: float = 0.9999     # Model EMA decay rate
    rope_type: str = '4d'               # RoPE type ('2d', '3d', or '4d')
    rope2d_each_sa_layer: int = 1       # Apply RoPE2D to each self-attention layer
    rope2d_normalized_by_hw: int = 2    # Apply normalized RoPE2D
    add_lvl_embeding_on_first_block: int = 0  # Apply level PE embedding only to the first block

    # ==================================================================================================================
    # ================================================== Scale Schedule ================================================
    # ==================================================================================================================
    semantic_scales: int = 8            # Number of semantic scales
    semantic_scale_dim: int = 16        # Dimension of semantic scales
    detail_scale_dim: int = 64          # Dimension of detail scales
    use_learnable_dim_proj: int = 0     # Whether to use a learnable dimension projection
    detail_scale_min_tokens: int = 80   # Minimum number of tokens for detail scale
    pn: str = ''                        # Pixel numbers, choose from '0.06M', '0.25M', '1M'
    scale_schedule: tuple = None        # [Automatically set] Scale schedule based on pn
    patch_size: int = None              # [Automatically set] Patch size based on scale_schedule
    dynamic_scale_schedule: str = ''    # Dynamic scale schedule for video
    min_scale_ind: int = 3              # Minimum scale index for infinity frame pack
    max_reweight_value: int = 40        # Clipping value for reweighting
    image_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]'  # Repetition for image scales
    video_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]'  # Repetition for video scales
    inner_scale_boost: int = 0          # Whether to boost inner scales
    drop_720p_last_scale: int = 1       # Whether to drop the last scale for 720p
    reweight_loss_by_scale: int = 0     # Reweight loss by scale

    # ==================================================================================================================
    # ================================================== Optimization ==================================================
    # ==================================================================================================================
    tlr: float = 2e-5                   # Learning rate
    grad_clip: float = 5                # Gradient clipping threshold
    cdec: bool = False                  # Whether to decay the grad clip thresholds
    opt: str = 'adamw'                  # Optimizer type ('adamw' or 'lion')
    ada: str = '0.9_0.97'               # Adam's beta parameters (e.g., '0.9_0.999')
    adam_eps: float = 0.0               # Adam's epsilon
    fused_adam: bool = True             # Whether to use fused Adam optimizer
    disable_weight_decay: int = 1       # Whether to disable weight decay on sparse params
    fp16: int = 2                       # Floating point precision: 1 for fp16, 2 for bf16

    # ==================================================================================================================
    # ====================================================== Data ======================================================
    # ==================================================================================================================
    video_fps: int = 16                 # Frames per second for video
    video_frames: int = 81              # Number of frames per video
    video_batch_size: int = 1           # Batch size for video data
    workers: int = 16                   # Number of dataloader workers
    image_batch_size: int = 0           # [Automatically set] Batch size per GPU for image data
    ac: int = 1                         # Gradient accumulation steps
    r_accu: float = 1.0                 # [Automatically set] Reciprocal of gradient accumulation
    tlen: int = 512                     # Truncate text embedding to this length
    num_of_label_value: int = 2         # Number of label values (2 for bitwise, 0 for index-wise)
    dynamic_resolution_across_gpus: int = 1  # Allow dynamic resolution across GPUs
    enable_dynamic_length_prompt: int = 0    # Enable dynamic length prompt during training
    use_streaming_dataset: int = 0      # Whether to use a streaming dataset
    iterable_data_buffersize: int = 90000    # Buffer size for streaming dataset
    image_batches_multiply: float = 1.0      # Multiplier for the number of image batches per epoch
    down_size_limit: int = 10000        # Download size limit for videos in MB
    addition_pn_list: str = '[]'        # Additional pixel number list
    video_caption_type: str = 'tarsier2_caption'  # Type of video caption to use
    only_images4extract_feats: int = 0  # Whether to only extract features for images
    train_max_token_len: int = -1       # Maximum token length for training
    train_with_var_seq_len: int = 0     # Whether to train with variable sequence length
    video_var_len_prob: str = '[30, 30, 30, 5, 3, 2]'  # Probability distribution for variable video length
    duration_resolution: int = 1        # Resolution for duration
    seq_pack_bucket: int = 1000         # Bucket size for sequence packing
    drop_long_video: int = 0            # Whether to drop long videos
    min_video_frames: int = -1          # Minimum number of video frames
    restrict_data_size: int = -1        # Restrict the size of the dataset
    allow_less_one_elem_in_seq: int = 0 # Allow sequences with less than one element
    train_192pshort: int = 0            # Whether to train with 192p short videos
    steps_per_frame: int = 3            # Steps per frame for the video tower
    add_motion_score2caption: int = 0   # Whether to prepend motion score to the caption
    context_frames: int = 10000         # Context frames for the video tower
    cached_video_frames: int = 81       # Number of cached video frames
    frames_inner_clip: int = 20         # Number of frames in a clip for infinity frame pack
    context_interval: int = 2           # Context interval
    context_from_largest_no: int = 1    # Context from the largest number
    append_duration2caption: int = 0    # Whether to append duration to the caption
    cache_check_mode: int = 0           # Cache check mode
    online_t5: bool = True              # Whether to use online T5 or load local features

    # ==================================================================================================================
    # ============================================= Distributed Training ===============================================
    # ==================================================================================================================
    enable_hybrid_shard: bool = False   # Whether to use hybrid FSDP
    inner_shard_degree: int = 8         # Inner degree for FSDP
    zero: int = 0                       # DeepSpeed ZeRO stage
    buck: str = 'chunk'                 # Module-wise bucketing for FSDP
    fsdp_orig: bool = True              # Whether to use original FSDP
    enable_checkpointing: str = None    # Checkpointing strategy: 'full-block', 'self-attn'
    pad_to_multiplier: int = 128        # Pad sequence length to a multiplier of this value
    sp_size: int = 0                    # Sequence parallelism size
    fsdp_save_flatten_model: int = 1    # Whether to save the flattened model in FSDP
    inject_sync: int = 0                # Whether to inject synchronization
    model_init_device: str = 'cuda'     # Device for model initialization
    fsdp_init_device: str = 'cuda'      # Device for FSDP initialization

    # ==================================================================================================================
    # ======================================================= VAE ======================================================
    # ==================================================================================================================
    vae_type: int = 64                  # VAE type (e.g., 16/32/64 for bsq vae quant bits)
    fake_vae_input: bool = False        # Whether to use fake VAE input for debugging
    use_slice: int = 1                  # Whether to use slicing for VAE encoding
    use_vae_token_cache: int = 1        # Whether to use token cache for VAE
    save_vae_token_cache: int = 0       # Whether to save the VAE token cache
    allow_online_vae_feature_extraction: int = 1  # Allow online VAE feature extraction
    use_text_token_cache: int = 0       # Whether to use text token cache
    videovae: int = 10                  # Whether to use a video VAE
    use_feat_proj: int = 2              # Whether to use feature projection
    use_two_stage_lfq: int = 0          # Whether to use two-stage LFQ
    casual_multi_scale: int = 0         # Whether to use casual multi-scale (sic: name likely means "causal")
    temporal_compress_rate: int = 4     # Temporal compression rate
    apply_spatial_patchify: int = 0     # Whether to apply spatial patchify

    # ==================================================================================================================
    # ============================================ Bitwise Self-Correction =============================================
    # ==================================================================================================================
    noise_apply_layers: int = 1000      # Apply noise to layers
    noise_apply_strength: str = '-1'    # Noise strength (comma-separated floats; converted to a list at runtime)
    noise_apply_requant: int = 1        # Requant after applying noise
    noise_apply_random_one: int = 0     # Requant only one scale randomly
    debug_bsc: int = 0                  # Save figures and set breakpoints for debugging BSC
    noise_input: int = 0                # Whether to add noise to the input
    reduce_accumulate_error_method: str = 'bsc'  # Method to reduce accumulation error

    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    # would be automatically set in runtime
    branch: str = ''      # subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_id: str = ''   # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_msg: str = ''  # (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:])  # [automatically set; don't specify this]
    tag: str = 'UK'       # [automatically set; don't specify this]
    cur_it: str = ''      # [automatically set; don't specify this]
    MFU: float = None     # [automatically set; don't specify this]
    HFU: float = None     # [automatically set; don't specify this]

    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use ==============================
    # ==================================================================================================================
    dbg: bool = 'KEVIN_LOCAL' in os.environ  # only used when debug about unused param in DDP
    prof: int = 0         # profile
    prof_freq: int = 50   # profile
    profall: int = 0

    @property
    def gpt_training(self):
        """True when a GPT model type was requested (empty `model` means VAE training)."""
        return len(self.model) > 0

    def set_initial_seed(self, benchmark: bool):
        """Seed every RNG (Python, NumPy, torch CPU/CUDA) and configure cuDNN.

        `benchmark` toggles cuDNN autotuning; determinism is always requested.
        NOTE(review): `self.seed` is referenced here but its attribute declaration
        was lost in the source mangling -- confirm it exists upstream.
        """
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        assert self.seed
        seed = self.seed
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def dump_log(self):
        """Write a small JSON status summary to `log_txt_path` (local master only)."""
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        for k, v in {
            'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id,
            'branch': self.branch, 'cur_it': self.cur_it,
            'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'opt': self.opt, 'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'):  # unwrap 0-dim tensors / numpy scalars
                v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0):
                continue  # skip unset fields
            nd[k] = v
        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        """Return a (optionally ordered) mapping of all serializable arg values."""
        d = (OrderedDict if key_ordered else dict)()
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        """Restore arg values from a dict, or from a legacy repr string.

        SECURITY NOTE(review): the legacy string path uses eval(), which executes
        arbitrary code from the checkpoint text; only load trusted checkpoints
        (consider ast.literal_eval if the payload is a plain literal).
        """
        if isinstance(d, str):  # for compatibility with old version
            # NOTE(review): the line-filter condition inside this eval, and the
            # remainder of this method, were lost in the source mangling. The
            # reconstruction below (filtering non-literal lines, then setattr per
            # key) follows the visible fragment -- confirm against VCS.
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e


def init_dist_and_get_args():
    """Parse CLI args, initialize distributed/sequence-parallel state, and
    post-process derived fields (betas, checkpointing mode, exp name/tag, SDPA
    backends, noise strengths). Returns the fully initialized `Args`.

    NOTE(review): the beginning of this function (argument parsing and the
    distributed-init call) was lost in the source mangling; the first two lines
    below are a minimal reconstruction -- confirm against VCS.
    """
    args = Args(explicit_bool=True).parse_args(known_only=True)
    # ... lost: distributed initialization (e.g. dist.init_distributed_mode) ...
    if args.sp_size > 1:
        print(f"INFO: sp_size={args.sp_size}")
        sp_manager.init_sp(args.sp_size)
    args.r_accu = 1 / args.ac  # gradient accumulation
    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.opt = args.opt.lower().strip()

    # gpt args
    if args.gpt_training:
        assert args.vae_path, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict
        if args.model in alias_dict:
            args.model = alias_dict[args.model]

    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')

    # Normalize enable_checkpointing: falsy -> None, truthy shorthand -> 'full-block'.
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    # BUGFIX(review): message previously omitted 'self-attn' although it is accepted.
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no-checkpointing or full-block/full-attn/self-attn checkpointing, but got {args.enable_checkpointing}."

    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'
    if '-' in args.exp_name:  # 'TAG-name' convention: split off the tag prefix
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'

    if dist.is_master():
        os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')

    if args.sdpa_mem:  # prefer flash / memory-efficient SDPA kernels over math fallback
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)

    print(args)

    # Convert noise strength to a list of floats regardless of input form.
    if isinstance(args.noise_apply_strength, str):
        args.noise_apply_strength = list(map(float, args.noise_apply_strength.split(',')))
    elif isinstance(args.noise_apply_strength, float):
        args.noise_apply_strength = [args.noise_apply_strength]
    return args