|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
from infinity.models import Infinity |
|
|
from infinity.utils import arg_util |
|
|
|
|
|
def load_visual_tokenizer(args, device=None):
    """Build the video VAE (visual tokenizer) described by ``args`` on ``device``.

    Args:
        args: Config namespace. Reads ``vae_type`` (also used as the codebook
            dimension), ``vae_path`` (checkpoint path passed to the VAE
            constructor) and ``videovae`` (VAE variant id; only ``10`` is
            supported). The whole namespace is forwarded as ``global_args``.
        device: Target device. When ``None`` (the default), CUDA is used if
            available, otherwise CPU.

    Returns:
        The module built by ``video_vae_model(..., test_mode=True)``, moved to
        ``device``.

    Raises:
        ValueError: If ``args.vae_type`` is not one of the supported codebook
            dimensions, or ``args.videovae`` is not a supported variant.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    supported_vae_types = (8, 12, 14, 16, 18, 20, 24, 32, 48, 64, 128)
    if args.vae_type not in supported_vae_types:
        raise ValueError(f"vae_type {args.vae_type} not supported")
    if args.videovae != 10:
        # Report the field that is actually unsupported, not vae_type.
        raise ValueError(f"videovae {args.videovae} not supported")

    schedule_mode = "dynamic"
    # vae_type doubles as the codebook dimension of the tokenizer.
    codebook_dim = args.vae_type
    print(f'Load VAE from {args.vae_path}')

    # Imported lazily so the (heavy) VAE module is only loaded when needed.
    from infinity.models.videovae.models.load_vae_bsq_wan_absorb_patchify import video_vae_model

    vae_local = video_vae_model(
        args.vae_path, schedule_mode, codebook_dim, global_args=args, test_mode=True
    ).to(device)
    return vae_local
|
|
|
|
|
def build_vae_gpt(args: arg_util.Args, force_flash=False, device='cuda'):
    """Build the visual tokenizer (VAE) and the Infinity transformer from ``args``.

    Args:
        args: Parsed arguments; many fields are read to build the transformer's
            constructor kwargs. ``args.flash`` is set to ``True`` in place when
            ``force_flash`` is true.
        force_flash: If true, force ``args.flash = True`` before building.
        device: Device the VAE is loaded onto (forwarded to
            ``load_visual_tokenizer``).

    Returns:
        Tuple ``(vae_local, gpt_wo_ddp)``: the VAE and the transformer created
        via ``timm.models.create_model``.
    """
    # load_visual_tokenizer places the VAE on `device`; no further move is
    # needed (a hard-coded `.to('cuda')` would override the caller's choice).
    vae_local = load_visual_tokenizer(args, device)

    if force_flash:
        args.flash = True

    gpt_kw = dict(
        text_channels=args.Ct5,
        text_maxlen=args.tlen,
        norm_eps=args.norm_eps,
        rms_norm=args.rms_norm,
        cond_drop_rate=args.cfg,
        rand_uncond=args.rand_uncond,
        raw_scale_schedule=args.scale_schedule,
        top_p=args.topp,
        top_k=args.topk,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        add_lvl_embeding_on_first_block=args.add_lvl_embeding_on_first_block,
        num_of_label_value=args.num_of_label_value,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=None,
        apply_spatial_patchify=args.apply_spatial_patchify,
        video_frames=args.video_frames,
        other_args=args,
    )

    # Log the kwargs before attaching the VAE so the whole module isn't printed.
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    # Translate the 'vgpt' spelling to 'infinity' in the model name.
    model_str = args.model.replace('vgpt', 'infinity')
    print(f"{model_str=}")
    # Drop a trailing 'c<digits>' suffix (e.g. 'xxxc8' -> 'xxx') before lookup.
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, _ = model_str.rsplit('c', maxsplit=1)

    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)

    # Sanity checks: the VAE must be frozen and every transformer parameter trainable.
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for p in gpt_wo_ddp.parameters())

    return vae_local, gpt_wo_ddp
|
|
|