File size: 2,662 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
#!/usr/bin/python3
import torch

from infinity.models import Infinity
from infinity.utils import arg_util

def load_visual_tokenizer(args, device=None):
    """Build and return the visual tokenizer (video VAE) selected by ``args``.

    Args:
        args: run configuration; reads ``vae_type`` (codebook dim),
            ``vae_path`` (checkpoint path) and ``videovae`` (variant id),
            and forwards the whole object as ``global_args``.
        device: torch device the VAE is moved to. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        The constructed VAE module, in test mode, on ``device``.

    Raises:
        ValueError: if ``args.vae_type`` or ``args.videovae`` selects an
            unsupported tokenizer variant.
    """
    if device is None:  # identity test, not truthiness — a device object is never "falsy" by intent
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.vae_type in (8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128):
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # e.g. 18
        print(f'Load VAE from {args.vae_path}')

        if args.videovae == 10:  # absorb patchify variant
            # Imported lazily: only this variant needs the heavy module.
            from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model
            vae_local = video_vae_model(args.vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True).to(device)
        else:
            # vae_type was already accepted above; what is unsupported here
            # is the videovae variant (original message wrongly blamed vae_type).
            raise ValueError(f"videovae {args.videovae} not supported")
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    return vae_local

def build_vae_gpt(args: arg_util.Args, force_flash=False, device='cuda'):
    """Construct the frozen visual tokenizer (VAE) and the trainable GPT.

    Args:
        args: full run configuration (``arg_util.Args``); many fields are
            forwarded to the GPT constructor below.
        force_flash: when True, force-enable flash attention by setting
            ``args.flash``.
        device: device the VAE is built on and moved to.

    Returns:
        (vae_local, gpt_wo_ddp): the visual tokenizer with all parameters
        frozen, and the GPT model (no DDP wrapper) with all parameters
        trainable.
    """
    vae_local = load_visual_tokenizer(args, device)

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        text_channels=args.Ct5, 
        text_maxlen=args.tlen,
        norm_eps=args.norm_eps, 
        rms_norm=args.rms_norm,
        cond_drop_rate=args.cfg, 
        rand_uncond=args.rand_uncond,
        raw_scale_schedule=args.scale_schedule,
        top_p=args.topp,
        top_k=args.topk,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        add_lvl_embeding_on_first_block=args.add_lvl_embeding_on_first_block,
        num_of_label_value=args.num_of_label_value,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=None,
        apply_spatial_patchify=args.apply_spatial_patchify,
        video_frames=args.video_frames,
        other_args=args,
    )
    
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    # Added after printing so the (huge) VAE module repr is not logged.
    gpt_kw['vae_local'] = vae_local
    
    model_str = args.model.replace('vgpt', 'infinity')   # legacy name mapping
    print(f"{model_str=}")
    # Strip a trailing "c<digits>" suffix (e.g. "...c16") from the model name.
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, _ = model_str.rsplit('c', maxsplit=1)    
    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
    # BUGFIX: honor the requested device instead of hard-coding 'cuda'
    # (load_visual_tokenizer already placed the VAE on `device`; moving it
    # to a hard-coded 'cuda' broke CPU-only runs).
    vae_local = vae_local.to(device)
    # Sanity: the VAE is frozen, the GPT is fully trainable.
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp