# (upload metadata, not code) BryanW's picture
# Upload folder using huggingface_hub
# 3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
#!/usr/bin/python3
import torch
from infinity.models import Infinity
from infinity.utils import arg_util
def load_visual_tokenizer(args, device=None):
    """Build and return the visual tokenizer (VAE) described by ``args``.

    Args:
        args: parsed run configuration; this reads ``args.vae_type``,
            ``args.vae_path`` and ``args.videovae``.
        device: target ``torch.device``; defaults to CUDA when available,
            else CPU.

    Returns:
        The VAE module, constructed in test mode and moved to ``device``.

    Raises:
        ValueError: if ``args.vae_type`` or ``args.videovae`` is unsupported.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # vae_type doubles as the codebook dimension (e.g. 18).
    if args.vae_type in (8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128):
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type
        print(f'Load VAE from {args.vae_path}')
        if args.videovae == 10:  # absorb patchify
            from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model
            vae_local = video_vae_model(args.vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True).to(device)
        else:
            # vae_type is valid here; the unsupported knob is videovae
            # (original message wrongly blamed vae_type).
            raise ValueError(f"videovae {args.videovae} not supported")
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    return vae_local
def build_vae_gpt(args: arg_util.Args, force_flash=False, device='cuda'):
    """Construct the frozen VAE tokenizer and the trainable Infinity GPT.

    Args:
        args: full run configuration (see ``arg_util.Args``); many fields are
            forwarded verbatim into the GPT constructor below.
        force_flash: when True, force-enable flash attention via ``args.flash``.
        device: device the VAE is built on and moved to (default ``'cuda'``).

    Returns:
        Tuple ``(vae_local, gpt_wo_ddp)``: the frozen visual tokenizer and
        the freshly created GPT model (not yet DDP-wrapped).
    """
    vae_local = load_visual_tokenizer(args, device)
    if force_flash:
        args.flash = True
    # Keyword arguments forwarded to the timm model factory.
    gpt_kw = dict(
        text_channels=args.Ct5,
        text_maxlen=args.tlen,
        norm_eps=args.norm_eps,
        rms_norm=args.rms_norm,
        cond_drop_rate=args.cfg,
        rand_uncond=args.rand_uncond,
        raw_scale_schedule=args.scale_schedule,
        top_p=args.topp,
        top_k=args.topk,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        add_lvl_embeding_on_first_block=args.add_lvl_embeding_on_first_block,
        num_of_label_value=args.num_of_label_value,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=None,
        apply_spatial_patchify=args.apply_spatial_patchify,
        video_frames=args.video_frames,
        other_args=args,
    )
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local
    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    # Strip a trailing 'c<digits>' suffix from legacy model names (e.g. '...c16').
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, _ = model_str.rsplit('c', maxsplit=1)
    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
    # BUGFIX: honor the requested device instead of hard-coding 'cuda',
    # which silently broke CPU runs despite the device parameter.
    vae_local = vae_local.to(device)
    # Sanity checks: the VAE is fully frozen, the GPT fully trainable.
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp