Spaces:
Paused
Paused
Refactor app.py to improve UI layout and rename weight download function; update import path for AutoEncoder in vae.py
f7f1ca1
| import argparse | |
| import torch | |
| from models.bsq_vae.flux_vqgan import AutoEncoder | |
def _inflate_conv_weight(source, dest):
    """Return ``source`` shaped for ``dest``: as-is if shapes match, else tiled along a new dim 2."""
    if dest.shape == source.shape:
        # 2D cnn to 2D cnn
        return source
    # 2D cnn to 3D cnn: insert dim 2 and repeat the kernel to dest's size there.
    _expand_dim = dest.shape[2]
    return source.unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)


def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy checkpoint tensors from ``state_dict`` into ``model`` in place.

    Only keys starting with ``prefix`` are considered; the prefix is stripped
    before matching against ``model.state_dict()``. Each key is tried against
    three targets, and may load into more than one of them:

    1. the identically named entry;
    2. for conv-like names, ``<name>.conv.weight`` / ``<name>.conv.bias``
       (a plain conv checkpoint loaded into a wrapper that owns ``.conv``);
    3. for names containing ``"norm"``, ``<name>.norm.weight`` /
       ``<name>.norm.bias`` (same idea for a norm wrapper).

    Args:
        model: Object exposing ``state_dict()``; the returned tensors are
            written with ``copy_`` so the model itself is updated.
        state_dict: Mutable mapping of checkpoint key -> tensor. Keys that
            were loaded are DELETED from this mapping.
        prefix: Key prefix selecting the relevant checkpoint entries.
        expand: When True, identically named ``*.conv.weight`` entries whose
            shape differs from the model's are tiled along a new dim 2.
        use_linear: When True, ``.q/.k/.v/.proj_out`` weights are squeezed
            before copying, and only names containing ``"conv"`` are treated
            as conv-like for target (2).

    Returns:
        ``(model, state_dict, loaded_keys)`` where ``state_dict`` is the same
        mapping minus the consumed keys and ``loaded_keys`` lists
        ``prefix + <model key>`` for every tensor written, in load order.
    """
    # Build the model mapping once instead of on every lookup. The loop only
    # ever used it to find tensors and copy_ into them, so one snapshot is
    # equivalent and avoids rebuilding the whole dict per access.
    target = model.state_dict()

    attn_weight_tags = (".q.weight", ".k.weight", ".v.weight", ".proj_out.weight")
    conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]

    delete_keys = []
    loaded_keys = []

    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        source = state_dict[key]

        if _key in target:
            # load nn.Conv2d or nn.Linear to nn.Linear
            if use_linear and any(tag in key for tag in attn_weight_tags):
                load_weights = source.squeeze()
            elif _key.endswith(".conv.weight") and expand:
                load_weights = _inflate_conv_weight(source, target[_key])
            else:
                load_weights = source
            target[_key].copy_(load_weights)
            delete_keys.append(key)
            loaded_keys.append(prefix + _key)

        # load nn.Conv2d to Conv class
        if any(tag in _key for tag in conv_list):
            if _key.endswith(".weight"):
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key in target:
                    target[conv_key].copy_(_inflate_conv_weight(source, target[conv_key]))
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
            if _key.endswith(".bias"):
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key in target:
                    target[conv_key].copy_(source)
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)

        # load nn.GroupNorm to Normalize class
        if "norm" in _key:
            for suffix in (".weight", ".bias"):
                if _key.endswith(suffix):
                    norm_key = _key.replace(suffix, ".norm" + suffix)
                    if norm_key in target:
                        target[norm_key].copy_(source)
                        delete_keys.append(key)
                        loaded_keys.append(prefix + norm_key)

    # A key that matched several targets is recorded several times; dedupe
    # (keeping order) so the second deletion does not raise KeyError.
    for key in dict.fromkeys(delete_keys):
        del state_dict[key]

    return model, state_dict, loaded_keys
def vae_model(vqgan_ckpt, schedule_mode, codebook_dim, codebook_size, test_mode=True, patch_size=16, encoder_ch_mult=[1, 2, 4, 4, 4], decoder_ch_mult=[1, 2, 4, 4, 4],):
    """Build an ``AutoEncoder`` from a fixed config and optionally load weights.

    Args:
        vqgan_ckpt: Either a path (``str``) passed to ``torch.load`` or an
            already-loaded checkpoint mapping. With the config built here
            (``ema='no'``) the weights are read from its ``"vae"`` entry.
            A falsy value skips weight loading entirely.
        schedule_mode: Forwarded unchanged into the model config.
        codebook_dim: Forwarded unchanged into the model config.
        codebook_size: Forwarded unchanged into the model config.
        test_mode: When True, the model is put in eval mode and every
            parameter has ``requires_grad`` turned off.
        patch_size: Forwarded unchanged into the model config.
        encoder_ch_mult: Forwarded into the model config (lists are copied).
        decoder_ch_mult: Forwarded into the model config (lists are copied).

    Returns:
        The constructed ``AutoEncoder`` instance.
    """
    # The defaults are shared list objects; copy list arguments so nothing
    # downstream can mutate the defaults (or the caller's lists) across calls.
    if isinstance(encoder_ch_mult, list):
        encoder_ch_mult = list(encoder_ch_mult)
    if isinstance(decoder_ch_mult, list):
        decoder_ch_mult = list(decoder_ch_mult)

    # AutoEncoder takes an argparse-style namespace; everything except the
    # function's own parameters is pinned to a fixed value here.
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type='image',
        save='./imagenet_val_bsq',
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device='cuda',
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir='',
        compile='no',
        ema='no',
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type='Adam',
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type='MultiScaleBSQ',
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model='vgg16',
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode='stochastic',
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version='v1',
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc='leaky_relu',
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool='no',
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type='l1',
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type='group',
        disc_loss_type='hinge',
        use_checkpoint=False,
        precision='fp32',
        encoder_dtype='fp32',
        upcast_attention='',
        upcast_tf32=False,
        tokenizer='flux',
        pretrained=None,
        pretrained_mode='full',
        inflation_pe=False,
        init_vgen='no',
        no_init_idis=False,
        init_idis='keep',
        init_vdis='no',
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas='',
        username='',
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[''],
        data_type=[''],
        dataset_list=['imagenet'],
        fps=-1,
        dataaug='resizecrop',
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type='2d',
        cnn_version='v1',
        conv_in_out_2d='no',
        conv_inner_2d='no',
        res_conv_2d='no',
        cnn_attention='no',
        cnn_norm_axis='spatial',
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0,
    )

    vae = AutoEncoder(args)

    # A string is treated as a checkpoint path; anything else is assumed to be
    # an already-loaded checkpoint mapping (or a falsy "no weights" marker).
    if isinstance(vqgan_ckpt, str):
        state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
    else:
        state_dict = args.vqgan_ckpt

    if state_dict:
        # args.ema is pinned to 'no' above, so "vae" is what gets used here;
        # the "ema" choice is kept for parity with the config switch.
        weights_key = "ema" if args.ema == "yes" else "vae"
        # load_cnn deletes every key it consumes from the mapping it is given.
        # Hand it a shallow copy so a caller-supplied checkpoint is not emptied
        # (which would make a second call with the same dict load nothing).
        weights = dict(state_dict[weights_key])
        vae, _, _ = load_cnn(vae, weights, prefix="", expand=False)

    if test_mode:
        vae.eval()
        for param in vae.parameters():
            param.requires_grad_(False)

    return vae