File size: 22,996 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT

import json
import os
import random
import sys
import time
from collections import OrderedDict
from typing import Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist
from infinity.utils.sequence_parallel import SequenceParallelManager as sp_manager


class Args(Tap):
    """Training configuration parsed from the command line with Tap.

    Each annotated class attribute is an argument. Fields marked "[Automatically set]"
    (and everything in the "set automatically" section) are overwritten at runtime by
    ``init_dist_and_get_args`` and should not normally be passed by hand.
    """
    # ==================================================================================================================
    # ============================================= Paths and Directories ============================================
    # ==================================================================================================================
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # Directory to save checkpoints
    data_path: str = ''  # Path to the image dataset
    video_data_path: str = ''  # Path to the video dataset
    bed: str = ''  # Directory to copy checkpoints apart from local_out_path
    vae_path: str = ''  # Path to the VAE checkpoint (required when GPT training)
    log_txt_path: str = ''  # [Automatically set] Path to the JSON log file: <local_out_path>/log.txt
    t5_path: str = ''  # Path to the T5 model; if not specified, it will be automatically found
    token_cache_dir: str = ''  # Directory for token cache

    # ==================================================================================================================
    # =============================================== General Training =================================================
    # ==================================================================================================================
    exp_name: str = ''  # Experiment name; '<tag>-<name>' is split into tag and name at startup
    project_name: str = 'infinitystar'  # Name of the wandb project
    tf32: bool = True  # Whether to use TensorFloat32
    auto_resume: bool = True  # Whether to automatically resume from the last checkpoint
    rush_resume: str = ''  # Path to a pretrained infinity checkpoint for rushing resume
    rush_omnistore_resume: str = ''  # Path to an omnistore pretrained checkpoint for rushing resume
    torchshard_resume: str = ''      # Path to a torch-shard checkpoint to resume from
    log_every_iter: bool = False  # Whether to log every iteration
    checkpoint_type: str = 'torch'  # Type of checkpoint: 'torch' or 'omnistore'
    device: str = 'cpu'  # [Automatically set] Device used for training, taken from dist.get_device()
    is_master_node: bool = None  # Whether the current node is the master node
    epoch: int = 300  # Number of training epochs
    log_freq: int = 1  # Logging frequency in stdout
    save_model_iters_freq: int = 1000  # Frequency of saving the model in iterations
    short_cap_prob: float = 0.2  # Probability of training with short captions
    label_smooth: float = 0.0  # Label smoothing factor
    cfg: float = 0.1  # Classifier-free guidance dropout probability
    rand_uncond: bool = False  # Whether to use random, unlearnable unconditional embedding
    twoclip_alternatingtraining: int = 0  # Whether to use two-clip alternating training
    wp_it: int = 100  # Warm-up iterations

    # ==================================================================================================================
    # ===================================================== Model ======================================================
    # ==================================================================================================================
    # NOTE(review): gpt_training treats every non-empty `model` value as GPT training; confirm whether any value is
    # still meant to select VAE training.
    model: str = ''  # Model type/alias (resolved through infinity.models.alias_dict); non-empty => GPT training
    sdpa_mem: bool = True  # Enable flash/memory-efficient SDPA backends and disable the math backend
    rms_norm: bool = False  # Whether to use RMS normalization
    tau: float = 1  # Tau of self-attention in GPT
    tini: float = -1  # Initialization parameters
    topp: float = 0.0                     # top-p
    topk: float = 0.0                     # top-k
    fused_norm: bool = False  # Whether to use fused normalization
    flash: bool = False  # Whether to use customized flash-attention kernel
    use_flex_attn: bool = False  # Whether to use flex_attn to speed up training
    norm_eps: float = 1e-6  # Epsilon for normalization layers
    Ct5: int = 2048  # Feature dimension of the text encoder
    simple_text_proj: int = 1  # Whether to use a simple text projection
    # NOTE(review): confirm which other mask types (e.g. 'var', 'video_tower') are still accepted.
    mask_type: str = 'infinity_elegant_clip20frames_v2'  # Self-attention mask type
    mask_video_first_frame: int = 0  # Whether to mask the first frame of the video when calculating loss

    use_fsdp_model_ema: int = 0  # Whether to use FSDP model EMA
    model_ema_decay: float = 0.9999  # Model EMA decay rate

    rope_type: str = '4d'  # RoPE type ('2d', '3d', or '4d')
    rope2d_each_sa_layer: int = 1  # Apply RoPE2D to each self-attention layer
    rope2d_normalized_by_hw: int = 2  # Apply normalized RoPE2D
    add_lvl_embeding_on_first_block: int = 0  # Apply level PE embedding only to the first block

    # ==================================================================================================================
    # ================================================== Scale Schedule =============================================
    # ==================================================================================================================
    semantic_scales: int = 8  # Number of semantic scales
    semantic_scale_dim: int = 16  # Dimension of semantic scales
    detail_scale_dim: int = 64  # Dimension of detail scales
    use_learnable_dim_proj: int = 0  # Whether to use a learnable dimension projection
    detail_scale_min_tokens: int = 80  # Minimum number of tokens for detail scale
    pn: str = ''  # Pixel numbers, choose from '0.06M', '0.25M', '1M'
    scale_schedule: tuple = None  # [Automatically set] Scale schedule based on pn
    patch_size: int = None  # [Automatically set] Patch size based on scale_schedule
    dynamic_scale_schedule: str = ''  # Dynamic scale schedule for video
    min_scale_ind: int = 3  # Minimum scale index for infinity frame pack
    max_reweight_value: int = 40  # Clipping value for reweighting
    image_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]'  # Repetition for image scales
    video_scale_repetition: str = '[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]'  # Repetition for video scales
    inner_scale_boost: int = 0  # Whether to boost inner scales
    drop_720p_last_scale: int = 1  # Whether to drop the last scale for 720p
    reweight_loss_by_scale: int = 0  # Reweight loss by scale

    # ==================================================================================================================
    # ================================================== Optimization ==================================================
    # ==================================================================================================================
    tlr: float = 2e-5  # Learning rate
    grad_clip: float = 5  # Gradient clipping threshold
    cdec: bool = False  # Whether to decay the grad clip thresholds
    opt: str = 'adamw'  # Optimizer type ('adamw' or 'lion'); lower-cased and stripped at startup
    ada: str = '0.9_0.97'  # Adam's beta parameters (e.g., '0.9_0.999'); an empty value falls back to a default at startup
    adam_eps: float = 0.0  # Adam's epsilon
    fused_adam: bool = True  # Whether to use fused Adam optimizer
    disable_weight_decay: int = 1  # Whether to disable weight decay on sparse params
    fp16: int = 2  # Floating point precision: 1 for fp16, 2 for bf16

    # ==================================================================================================================
    # ====================================================== Data ======================================================
    # ==================================================================================================================
    video_fps: int = 16  # Frames per second for video
    video_frames: int = 81  # Number of frames per video
    video_batch_size: int = 1  # Batch size for video data
    workers: int = 16  # Number of dataloader workers
    image_batch_size: int = 0  # [Automatically set] Batch size per GPU for image data
    ac: int = 1  # Gradient accumulation steps
    r_accu: float = 1.0  # [Automatically set] Reciprocal of gradient accumulation (1 / ac)
    tlen: int = 512  # Truncate text embedding to this length
    num_of_label_value: int = 2  # Number of label values (2 for bitwise, 0 for index-wise)
    dynamic_resolution_across_gpus: int = 1  # Allow dynamic resolution across GPUs
    enable_dynamic_length_prompt: int = 0  # Enable dynamic length prompt during training
    use_streaming_dataset: int = 0  # Whether to use a streaming dataset
    iterable_data_buffersize: int = 90000  # Buffer size for streaming dataset
    image_batches_multiply: float = 1.0  # Multiplier for the number of image batches per epoch
    down_size_limit: int = 10000  # Download size limit for videos in MB
    addition_pn_list: str = '[]'  # Additional pixel number list
    video_caption_type: str = 'tarsier2_caption'  # Type of video caption to use
    only_images4extract_feats: int = 0  # Whether to only extract features for images
    train_max_token_len: int = -1  # Maximum token length for training
    train_with_var_seq_len: int = 0  # Whether to train with variable sequence length
    video_var_len_prob: str = '[30, 30, 30, 5, 3, 2]'  # Probability distribution for variable video length
    duration_resolution: int = 1  # Resolution for duration
    seq_pack_bucket: int = 1000  # Bucket size for sequence packing
    drop_long_video: int = 0  # Whether to drop long videos
    min_video_frames: int = -1  # Minimum number of video frames
    restrict_data_size: int = -1  # Restrict the size of the dataset
    allow_less_one_elem_in_seq: int = 0  # Allow sequences with less than one element
    train_192pshort: int = 0  # Whether to train with 192p short videos
    steps_per_frame: int = 3  # Steps per frame for the video tower
    add_motion_score2caption: int = 0  # Whether to prepend motion score to the caption
    context_frames: int = 10000  # Context frames for the video tower
    cached_video_frames: int = 81  # Number of cached video frames
    frames_inner_clip: int = 20  # Number of frames in a clip for infinity frame pack
    context_interval: int = 2  # Context interval
    context_from_largest_no: int = 1  # Context from the largest number
    append_duration2caption: int = 0  # Whether to append duration to the caption
    cache_check_mode: int = 0  # Cache check mode
    online_t5: bool = True  # Whether to use online T5 or load local features

    # ==================================================================================================================
    # ============================================= Distributed Training ===============================================
    # ==================================================================================================================
    enable_hybrid_shard: bool = False  # Whether to use hybrid FSDP
    inner_shard_degree: int = 8  # Inner degree for FSDP
    zero: int = 0  # DeepSpeed ZeRO stage
    buck: str = 'chunk'  # Module-wise bucketing for FSDP
    fsdp_orig: bool = True  # Whether to use original FSDP
    enable_checkpointing: str = None  # Checkpointing strategy: None, 'full-block', 'full-attn' or 'self-attn' (0/1 alias None/'full-block')
    pad_to_multiplier: int = 128  # Pad sequence length to a multiplier of this value
    sp_size: int = 0  # Sequence parallelism size (enabled when > 1)
    fsdp_save_flatten_model: int = 1  # Whether to save the flattened model in FSDP
    inject_sync: int = 0  # Whether to inject synchronization
    model_init_device: str = 'cuda'  # Device for model initialization
    fsdp_init_device: str = 'cuda'  # Device for FSDP initialization

    # ==================================================================================================================
    # ======================================================= VAE ======================================================
    # ==================================================================================================================
    vae_type: int = 64  # VAE type (e.g., 16/32/64 for bsq vae quant bits)
    fake_vae_input: bool = False  # Whether to use fake VAE input for debugging
    use_slice: int = 1  # Whether to use slicing for VAE encoding
    use_vae_token_cache: int = 1  # Whether to use token cache for VAE
    save_vae_token_cache: int = 0  # Whether to save the VAE token cache
    allow_online_vae_feature_extraction: int = 1  # Allow online VAE feature extraction
    use_text_token_cache: int = 0  # Whether to use text token cache
    videovae: int = 10  # Whether to use a video VAE
    use_feat_proj: int = 2  # Whether to use feature projection
    use_two_stage_lfq: int = 0  # Whether to use two-stage LFQ
    casual_multi_scale: int = 0  # Whether to use casual multi-scale
    temporal_compress_rate: int = 4  # Temporal compression rate
    apply_spatial_patchify: int = 0  # Whether to apply spatial patchify

    # ==================================================================================================================
    # ============================================ Bitwise Self-Correction =============================================
    # ==================================================================================================================
    noise_apply_layers: int = 1000  # Apply noise to layers
    noise_apply_strength: str = '-1'  # Comma-separated noise strength(s); parsed into a list of floats at startup
    noise_apply_requant: int = 1  # Requant after applying noise
    noise_apply_random_one: int = 0  # Requant only one scale randomly
    debug_bsc: int = 0  # Save figures and set breakpoints for debugging BSC
    noise_input: int = 0  # Whether to add noise to the input
    reduce_accumulate_error_method: str = 'bsc'  # Method to reduce accumulation error

    ############################  Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################

    # would be automatically set in runtime
    branch: str = '' # subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
    commit_id: str = '' # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = ''# (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()    # [automatically set; don't specify this]
    # NOTE(review): the first 7 argv entries are skipped, presumably a launcher-specific prefix -- confirm.
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:])  # [automatically set; don't specify this]
    tag: str = 'UK'                     # [automatically set; don't specify this]
    cur_it: str = ''                    # [automatically set; don't specify this]
    MFU: float = None                   # [automatically set; don't specify this]
    HFU: float = None                   # [automatically set; don't specify this]
    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use ==============================
    # ==================================================================================================================

    dbg: bool = 'KEVIN_LOCAL' in os.environ       # only used when debugging unused params in DDP
    prof: int = 0           # profile
    prof_freq: int = 50     # profile
    profall: int = 0
    # ==================================================================================================================
    # ======================== ignore these parts above since they are only for debug use ==============================
    # ==================================================================================================================

    @property
    def gpt_training(self):
        """True when a model name is given (``model`` is non-empty); read-only."""
        return len(self.model) > 0

    def set_initial_seed(self, benchmark: bool):
        """Seed Python, NumPy and torch (CPU and, if available, CUDA) RNGs from ``self.seed``.

        Args:
            benchmark: value assigned to ``torch.backends.cudnn.benchmark``.

        Raises:
            AssertionError: if ``self.seed`` is falsy (``seed`` is set by ``init_dist_and_get_args``).
        """
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        assert self.seed
        seed = self.seed
        # Deterministic cuDNN kernels are requested unconditionally, independent of `benchmark`.
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def dump_log(self):
        """Write a small JSON summary of the run to ``self.log_txt_path``.

        Only the local-master process writes. Values exposing ``.item()`` (e.g. 0-d
        tensors) are converted to Python scalars; ``None`` and empty strings are omitted.
        The file is overwritten on every call.
        """
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        for k, v in {
            'name': self.exp_name,
            'tag': self.tag,
            'cmd': self.cmd,
            'commit': self.commit_id,
            'branch': self.branch,
            'cur_it': self.cur_it,
            'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'opt': self.opt,
            'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'):
                v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0):
                continue
            nd[k] = v

        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)

    def _serializable_keys(self):
        """Return the declared argument names, excluding ones that are not serializable.

        Shared by :meth:`state_dict` and :meth:`__str__` so both skip the same keys.
        """
        return [k for k in self.class_variables.keys() if k not in {'device', 'dbg_ks_fp'}]

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        """Return a mapping of argument name -> current value.

        Args:
            key_ordered: return an ``OrderedDict`` (declaration order) if True, else a plain ``dict``.
        """
        d = (OrderedDict if key_ordered else dict)()
        for k in self._serializable_keys():
            d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        """Restore argument values from a mapping produced by :meth:`state_dict`.

        Args:
            d: mapping of argument name -> value, or (legacy format) its multi-line string repr.
               Keys ``'is_large_model'`` and ``'gpt_training'`` are skipped
               (``gpt_training`` is a read-only property).

        Raises:
            Whatever ``setattr`` raises; the offending key/value is printed first.
        """
        if isinstance(d, str):  # for compatibility with old version
            # SECURITY NOTE(review): eval() executes arbitrary code contained in the string, which comes
            # from a checkpoint. Only load trusted checkpoints; consider ast.literal_eval once the legacy
            # format is confirmed to hold only Python literals.
            # Lines with bound-method or torch.device reprs are dropped since they cannot be evaluated.
            d: dict = eval('\n'.join([line for line in d.splitlines() if '<bound' not in line and 'device(' not in line]))
        for k, v in d.items():
            if k in {'is_large_model', 'gpt_training'}:
                continue
            try:
                setattr(self, k, v)
            except Exception:
                print(f'k={k}, v={v}')
                raise

    @staticmethod
    def set_tf32(tf32: bool):
        """Enable or disable TF32 for cuDNN convolutions and CUDA matmuls (no-op without CUDA)."""
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            # Newer torch versions also expose a global float32 matmul precision knob.
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def __str__(self):
        """Render every serializable argument as ``  <name padded to 20>: <value>`` inside braces."""
        s = [f'  {k:20s}: {getattr(self, k)}' for k in self._serializable_keys()]
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def _try_makedirs(path: str) -> None:
    """Best-effort ``os.makedirs(path, exist_ok=True)``; empty paths are skipped.

    Failures are reported instead of raised, preserving the original best-effort intent.
    """
    if not path:
        return
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        print(f'[init_dist_and_get_args] WARNING: could not create directory {path!r}: {e}')


def _clear_ready_markers(*dirs: str) -> None:
    """Remove every ``ready-node*`` entry (file or directory) directly inside each of ``dirs``.

    Replaces ``os.system('rm -rf <dir>/ready-node*')``. No shell is involved, so a directory
    containing spaces or shell metacharacters can no longer expand into unrelated deletion
    targets. An empty directory argument is skipped rather than matching ``ready-node*`` in
    the current working directory. Removal is best-effort, like ``rm -rf``.
    """
    import glob
    import shutil

    for d in dirs:
        if not d:
            continue
        # glob.escape keeps '[', '*', '?' in the directory part from acting as wildcards.
        for p in glob.glob(os.path.join(glob.escape(d), 'ready-node*')):
            try:
                if os.path.isdir(p) and not os.path.islink(p):
                    shutil.rmtree(p)
                else:
                    os.remove(p)
            except OSError as e:
                print(f'[init_dist_and_get_args] WARNING: could not remove {p!r}: {e}')


def init_dist_and_get_args():
    """Parse CLI arguments, initialize distributed mode, and derive runtime fields of :class:`Args`.

    Steps:
      * drop the first ``--local-rank=`` / ``--local_rank=`` argument from ``sys.argv``, then
        parse the remaining arguments with Tap (unknown ones end up in ``args.extra_args``);
      * configure TF32 and create ``bed`` / ``local_out_path`` (best effort);
      * initialize distributed mode and set ``args.device``;
      * agree on a shared seed across ranks and optionally initialize sequence parallelism;
      * normalize ``r_accu``, ``ada``, ``opt``, ``model``, ``log_txt_path``,
        ``enable_checkpointing``, ``exp_name``/``tag`` and ``noise_apply_strength``;
      * on the master process, remove stale ``ready-node*`` markers.

    Returns:
        The populated :class:`Args` instance.

    Raises:
        AssertionError: if GPT training is requested without ``vae_path``, or
            ``enable_checkpointing`` is not a supported value.
    """
    # Presumably injected by the distributed launcher; not a declared Args field.
    for i, arg in enumerate(sys.argv):
        if arg.startswith(('--local-rank=', '--local_rank=')):
            del sys.argv[i]
            break  # mutating sys.argv mid-iteration is safe only because we stop right away
    args = Args(explicit_bool=True).parse_args(known_only=True)

    # NOTE(review): is_master_node defaults to None, so this warning only prints when it is
    # explicitly set to a falsy value -- confirm this is the intended condition.
    if len(args.extra_args) > 0 and args.is_master_node == 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    args.set_tf32(args.tf32)

    _try_makedirs(args.bed)
    _try_makedirs(args.local_out_path)

    dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=30)
    args.device = dist.get_device()

    # sync seed: every rank proposes its wall-clock time; the MIN all-reduce makes all ranks agree.
    args.seed = int(time.time())
    seed = torch.tensor([args.seed], device=args.device)
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(seed, op=torch.distributed.ReduceOp.MIN)
    args.seed = seed.item()

    if args.sp_size > 1:
        print(f"INFO: sp_size={args.sp_size}")
        sp_manager.init_sp(args.sp_size)

    args.r_accu = 1 / args.ac   # gradient accumulation
    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.opt = args.opt.lower().strip()

    # gpt args
    if args.gpt_training:
        assert args.vae_path, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict
        if args.model in alias_dict:
            args.model = alias_dict[args.model]

    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')

    # Accept boolean-like values: falsy -> no checkpointing, truthy -> 'full-block'.
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no-checkpointing or full-block/full-attn/self-attn checkpointing, but got {args.enable_checkpointing}."

    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'

    # An exp_name of the form '<tag>-<name>' is split at the first '-'.
    if '-' in args.exp_name:
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'

    if dist.is_master():
        _clear_ready_markers(args.bed, args.local_out_path)

    if args.sdpa_mem:
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)
    print(args)
    # noise_apply_strength is given as a comma-separated string (default '-1'); normalize to a list of floats.
    if isinstance(args.noise_apply_strength, str):
        args.noise_apply_strength = list(map(float, args.noise_apply_strength.split(',')))
    elif isinstance(args.noise_apply_strength, float):
        args.noise_apply_strength = [args.noise_apply_strength]
    return args