BryanW's picture
Upload folder using huggingface_hub
3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from infinity.models.videovae.utils.misc import is_torch_optim_sch
def inflate_gen(state_dict, temporal_patch_size, spatial_patch_size, strategy="average", inflation_pe=False):
    """Adapt a generator (VAE) checkpoint's patch layers to the requested patch sizes.

    The function works on a shallow copy of ``state_dict``; the input mapping
    itself is never modified.

    The checkpoint's current patch sizes are recovered from the length of
    ``encoder.to_patch_emb_first_frame.1.weight`` (spatial) and
    ``encoder.to_patch_emb.1.weight`` (temporal); the arithmetic assumes
    3 input channels.

    * Patch sizes differ and ``inflation_pe`` is false: the patch-size dependent
      entries are removed from the copy and it is returned, so the corresponding
      layers keep whatever initialisation the model already has.
    * Patch sizes differ and ``inflation_pe`` is true: those entries are resized
      with ``F.interpolate`` (default mode) along their patch dimension.
      ``strategy`` is not consulted on this path.
    * Patch sizes match: the video patch-embedding (``encoder.to_patch_emb.*``)
      and ``decoder.to_pixels.0.*`` entries are rebuilt from the first-frame
      (image) entries according to ``strategy``:

      - ``"average"``: the image tensor divided by ``temporal_patch_size`` and
        repeated ``temporal_patch_size`` times;
      - ``"first"``: the image tensor followed by ``temporal_patch_size - 1``
        zero tensors of the same shape.

    Args:
        state_dict: mapping from parameter names to tensors.
        temporal_patch_size: target temporal patch size.
        spatial_patch_size: target spatial patch size.
        strategy: ``"average"`` or ``"first"`` (only used when sizes match).
        inflation_pe: whether to interpolate instead of dropping mismatched entries.

    Returns:
        A new dict with the adapted entries.

    Raises:
        KeyError: if an expected checkpoint entry is absent.
        NotImplementedError: if sizes match and ``strategy`` is unknown.
    """
    new_state_dict = state_dict.copy()

    pe_image0_w = state_dict["encoder.to_patch_emb_first_frame.1.weight"]  # image_channel * patch_width * patch_height
    pe_image0_b = state_dict["encoder.to_patch_emb_first_frame.1.bias"]  # image_channel * patch_width * patch_height
    pe_image1_w = state_dict["encoder.to_patch_emb_first_frame.2.weight"]  # image_channel * patch_width * patch_height, dim
    pe_image1_b = state_dict["encoder.to_patch_emb_first_frame.2.bias"]
    pe_image2_w = state_dict["encoder.to_patch_emb_first_frame.3.weight"]
    pe_image2_b = state_dict["encoder.to_patch_emb_first_frame.3.bias"]
    pd_image0_w = state_dict["decoder.to_pixels_first_frame.0.weight"]  # dim, image_channel * patch_width * patch_height
    pd_image0_b = state_dict["decoder.to_pixels_first_frame.0.bias"]  # image_channel * patch_width * patch_height
    pe_video0_w = state_dict["encoder.to_patch_emb.1.weight"]

    old_patch_size = int(math.sqrt(pe_image0_w.shape[0] // 3))
    old_patch_size_temporal = pe_video0_w.shape[0] // (3 * old_patch_size * old_patch_size)

    if old_patch_size != spatial_patch_size or old_patch_size_temporal != temporal_patch_size:
        if not inflation_pe:
            # Drop every entry whose shape depends on the patch size.
            for key in (
                "encoder.to_patch_emb_first_frame.1.weight",
                "encoder.to_patch_emb_first_frame.1.bias",
                "encoder.to_patch_emb_first_frame.2.weight",
                "decoder.to_pixels_first_frame.0.weight",
                "decoder.to_pixels_first_frame.0.bias",
                "encoder.to_patch_emb.1.weight",
                "encoder.to_patch_emb.1.bias",
                "encoder.to_patch_emb.2.weight",
                "decoder.to_pixels.0.weight",
                "decoder.to_pixels.0.bias",
            ):
                del new_state_dict[key]
            return new_state_dict

        print(f"Inflate the patch embedding size from {old_patch_size_temporal}x{old_patch_size}x{old_patch_size} to {temporal_patch_size}x{spatial_patch_size}x{spatial_patch_size}.")
        image_dim = 3 * spatial_patch_size * spatial_patch_size
        video_dim = 3 * temporal_patch_size * spatial_patch_size * spatial_patch_size

        # First-frame (image) layers.
        new_state_dict["encoder.to_patch_emb_first_frame.1.weight"] = _interp_vector(pe_image0_w, image_dim)
        new_state_dict["encoder.to_patch_emb_first_frame.1.bias"] = _interp_vector(pe_image0_b, image_dim)
        new_state_dict["encoder.to_patch_emb_first_frame.2.weight"] = _interp_last_dim(pe_image1_w, image_dim)
        new_state_dict["decoder.to_pixels_first_frame.0.weight"] = _interp_first_dim(pd_image0_w, image_dim)
        new_state_dict["decoder.to_pixels_first_frame.0.bias"] = _interp_vector(pd_image0_b, image_dim)

        # Video layers.
        new_state_dict["encoder.to_patch_emb.1.weight"] = _interp_vector(pe_video0_w, video_dim)
        new_state_dict["encoder.to_patch_emb.1.bias"] = _interp_vector(state_dict["encoder.to_patch_emb.1.bias"], video_dim)
        new_state_dict["encoder.to_patch_emb.2.weight"] = _interp_last_dim(state_dict["encoder.to_patch_emb.2.weight"], video_dim)
        # NOTE: the video decoder entries are resized from the *video* tensors.
        # The previous code loaded them and then interpolated the image tensors
        # instead, discarding the video decoder weights of the checkpoint.
        new_state_dict["decoder.to_pixels.0.weight"] = _interp_first_dim(state_dict["decoder.to_pixels.0.weight"], video_dim)
        new_state_dict["decoder.to_pixels.0.bias"] = _interp_vector(state_dict["decoder.to_pixels.0.bias"], video_dim)
        return new_state_dict

    if strategy == "average":
        def inflate(tensor, dim=0):
            return torch.cat([tensor / temporal_patch_size] * temporal_patch_size, dim=dim)
    elif strategy == "first":
        def inflate(tensor, dim=0):
            padding = [torch.zeros_like(tensor)] * (temporal_patch_size - 1)
            return torch.cat([tensor] + padding, dim=dim)
    else:
        raise NotImplementedError(f"Unknown inflation strategy: {strategy!r}")

    new_state_dict["encoder.to_patch_emb.1.weight"] = inflate(pe_image0_w)
    new_state_dict["encoder.to_patch_emb.1.bias"] = inflate(pe_image0_b)
    new_state_dict["encoder.to_patch_emb.2.weight"] = inflate(pe_image1_w, dim=-1)
    # The remaining patch-embedding entries are reused from the image layers as-is.
    new_state_dict["encoder.to_patch_emb.2.bias"] = pe_image1_b
    new_state_dict["encoder.to_patch_emb.3.weight"] = pe_image2_w
    new_state_dict["encoder.to_patch_emb.3.bias"] = pe_image2_b

    new_state_dict["decoder.to_pixels.0.weight"] = inflate(pd_image0_w)
    new_state_dict["decoder.to_pixels.0.bias"] = inflate(pd_image0_b)
    return new_state_dict


def _interp_vector(tensor, size):
    """Resize a 1-D tensor to ``size`` elements with ``F.interpolate``."""
    return F.interpolate(tensor.unsqueeze(0).unsqueeze(0), size=size).squeeze(0).squeeze(0)


def _interp_last_dim(tensor, size):
    """Resize the last dimension of a 2-D tensor to ``size`` with ``F.interpolate``."""
    return F.interpolate(tensor.unsqueeze(0), size=size).squeeze(0)


def _interp_first_dim(tensor, size):
    """Resize the first dimension of a 2-D tensor to ``size`` with ``F.interpolate``."""
    return F.interpolate(tensor.permute(1, 0).unsqueeze(0), size=size).squeeze(0).permute(1, 0)
def inflate_dis(state_dict, strategy="center"):
    """Create video-discriminator entries from image-discriminator entries.

    Every key containing ``"video_discriminator"`` is dropped from a copy of
    ``state_dict``. Every key containing ``"image_discriminator"`` is then
    mirrored under a name in which the first ``"image_discriminator"`` is
    replaced by ``"video_discriminator"`` (the image entries are kept too).

    4-D tensors whose key contains ``"weight"`` get a new axis of length 4
    inserted at position 2:

    * ``"average"``: the kernel is repeated 4 times and divided by 4;
    * ``"center"`` / ``"first"`` / ``"last"``: the kernel is written to slice
      1 / 0 / -1 respectively and the other slices are zero.

    All other entries are copied unchanged. The input mapping is not modified.

    Args:
        state_dict: mapping from parameter names to tensors.
        strategy: one of ``"average"``, ``"center"``, ``"first"``, ``"last"``.

    Returns:
        A new dict containing the image entries and the derived video entries.

    Raises:
        NotImplementedError: if a 4-D weight is encountered and ``strategy``
            is not one of the supported values.
    """
    print("#" * 50)
    print(f"Initialize the video discriminator with {strategy}.")
    print("#" * 50)

    depth = 4  # length of the inserted axis, as hard-coded by the original implementation
    slice_for_strategy = {"center": 1, "first": 0, "last": -1}

    new_state_dict = {k: v for k, v in state_dict.items() if "video_discriminator" not in k}

    for k in [name for name in state_dict if "image_discriminator" in name]:
        # Keys are selected with ``in``, so rename by substitution rather than by
        # slicing a fixed-length prefix; this also handles e.g. "module."-prefixed keys.
        new_k = k.replace("image_discriminator", "video_discriminator", 1)
        value = state_dict[k]

        if "weight" in k and value.ndim == 4:
            if strategy == "average":
                new_weight = value.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
            elif strategy in slice_for_strategy:
                out_ch, in_ch, k_h, k_w = value.shape
                # new_zeros keeps both dtype and device of the source weight.
                new_weight = value.new_zeros((out_ch, in_ch, depth, k_h, k_w))
                new_weight[:, :, slice_for_strategy[strategy]] = value
            else:
                raise NotImplementedError(f"Unknown inflation strategy: {strategy!r}")
            new_state_dict[new_k] = new_weight
        else:
            new_state_dict[new_k] = value
    return new_state_dict
def load_unstrictly(state_dict, model, loaded_keys=None):
    """Copy matching entries of ``state_dict`` into ``model``'s parameters, best effort.

    Only the names yielded by ``model.named_parameters()`` are visited.
    A parameter is reported as missing when

    * it is present in ``state_dict`` but copying fails (e.g. shape mismatch), or
    * it is absent from ``state_dict`` and not listed in ``loaded_keys``.

    Args:
        state_dict: mapping from parameter names to tensors.
        model: module whose parameters are updated in place.
        loaded_keys: optional iterable of parameter names that were already
            loaded by other means and must not be reported as missing.

    Returns:
        ``(model, missing_keys)`` where ``missing_keys`` is a list of names in
        ``named_parameters()`` order.
    """
    already_loaded = set(loaded_keys) if loaded_keys else set()
    missing_keys = []
    for name, param in model.named_parameters():
        if name in state_dict:
            try:
                param.data.copy_(state_dict[name])
            except Exception:
                # Deliberately best effort: an incompatible entry is reported as
                # missing instead of aborting. KeyboardInterrupt/SystemExit are
                # no longer swallowed as they were by the former bare ``except``.
                missing_keys.append(name)
        elif name not in already_loaded:
            missing_keys.append(name)
    return model, missing_keys
def init_vae_only(state_dict, vae):
    """Load ``state_dict`` into ``vae`` via ``load_unstrictly`` and report what was not loaded.

    Names starting with ``'flux'`` are left out of the printed report.
    Returns the (in-place updated) ``vae``.
    """
    vae, unloaded = load_unstrictly(state_dict, vae)
    reported = [name for name in unloaded if not name.startswith('flux')]
    print(f"missing keys in loading vae: {reported}")
    return vae
def init_image_disc(state_dict, image_disc, args):
    """Initialise the image discriminator from ``state_dict["image_disc"]``.

    When ``args.no_init_idis`` is truthy or ``args.init_idis == "no"`` nothing
    is loaded. Otherwise each checkpoint key ending in ``.weight`` / ``.bias``
    is first tried under the remapped name ``*.norm.weight`` / ``*.norm.bias``
    (the original comment describes this as loading ``nn.GroupNorm`` weights
    into a ``Normalize`` wrapper). Entries consumed that way are excluded from
    the subsequent non-strict ``load_state_dict`` call.

    The caller's checkpoint dict is left untouched; the filtering is done on
    a copy.

    Args:
        state_dict: checkpoint mapping containing an ``"image_disc"`` entry
            (only read when initialisation is enabled).
        image_disc: module to initialise in place.
        args: namespace providing ``no_init_idis`` and ``init_idis``.

    Returns:
        ``image_disc``.
    """
    if args.no_init_idis or args.init_idis == "no":
        source = {}
    else:
        source = state_dict["image_disc"]

    # state_dict() tensors share storage with the module, so one snapshot is
    # enough; rebuilding it for every lookup is needlessly quadratic.
    model_state = image_disc.state_dict()

    # load nn.GroupNorm to Normalize class
    consumed = set()
    loaded_keys = []
    for key, value in source.items():
        for suffix in (".weight", ".bias"):
            if not key.endswith(suffix):
                continue
            norm_key = key.replace(suffix, ".norm" + suffix)
            if norm_key in model_state:
                model_state[norm_key].copy_(value)
                consumed.add(key)
                loaded_keys.append(norm_key)

    remaining = {key: value for key, value in source.items() if key not in consumed}
    msg = image_disc.load_state_dict(remaining, strict=False)
    print(f"image disc missing: {[key for key in msg.missing_keys if key not in loaded_keys]}")
    print(f"image disc unexpected: {msg.unexpected_keys}")
    return image_disc
def init_video_disc(state_dict, video_disc, args):
    """Initialise the video discriminator according to ``args.init_vdis``.

    * ``"no"``: load nothing.
    * ``"keep"``: load ``state_dict["video_disc"]`` as-is.
    * anything else: pass ``state_dict["video_disc"]`` through ``inflate_dis``
      using the value as the inflation strategy.

    Loading is non-strict; missing and unexpected keys are printed.
    Returns ``video_disc``.
    """
    mode = args.init_vdis
    if mode == "no":
        weights = {}
    elif mode == "keep":
        weights = state_dict["video_disc"]
    else:
        weights = inflate_dis(state_dict["video_disc"], strategy=mode)

    report = video_disc.load_state_dict(weights, strict=False)
    print(f"video disc missing: {report.missing_keys}")
    print(f"video disc unexpected: {report.unexpected_keys}")
    return video_disc
def init_vit_from_image(state_dict, vae, image_disc, video_disc, args):
    """Initialise a ViT-style video VAE and the image discriminator from a checkpoint.

    The VAE weights are taken from ``state_dict["vae"]`` according to
    ``args.init_vgen``:

    * ``"no"``: the video patch-embedding and ``decoder.to_pixels.0`` entries
      are removed so those layers keep their current initialisation;
    * ``"keep"``: the checkpoint is used unchanged;
    * anything else: the value is passed as ``strategy`` to ``inflate_gen``.

    If ``args.vq_to_vae`` is truthy the ``pre_vq_conv.1`` entries are removed
    as well. Entries to be removed that are already absent are ignored.
    Loading is non-strict and the missing/unexpected keys are printed.

    NOTE: in the ``"no"`` and ``"keep"`` modes the removals are applied to
    ``state_dict["vae"]`` itself, i.e. the caller's dict is modified.

    The video discriminator is returned untouched (it stays randomly
    initialised).

    Returns:
        ``(vae, image_disc, video_disc)``.
    """
    if args.init_vgen == "no":
        vae_state_dict = state_dict["vae"]
        for key in (
            "encoder.to_patch_emb.1.weight",
            "encoder.to_patch_emb.1.bias",
            "encoder.to_patch_emb.2.weight",
            "encoder.to_patch_emb.2.bias",
            "encoder.to_patch_emb.3.weight",
            "encoder.to_patch_emb.3.bias",
            "decoder.to_pixels.0.weight",
            "decoder.to_pixels.0.bias",
        ):
            # pop() instead of del: an absent entry already satisfies the goal.
            vae_state_dict.pop(key, None)
    elif args.init_vgen == "keep":
        vae_state_dict = state_dict["vae"]
    else:
        vae_state_dict = inflate_gen(
            state_dict["vae"],
            temporal_patch_size=args.temporal_patch_size,
            spatial_patch_size=args.patch_size,
            strategy=args.init_vgen,
            inflation_pe=args.inflation_pe,
        )

    if args.vq_to_vae:
        vae_state_dict.pop("pre_vq_conv.1.weight", None)
        vae_state_dict.pop("pre_vq_conv.1.bias", None)

    msg = vae.load_state_dict(vae_state_dict, strict=False)
    print(f"vae missing: {msg.missing_keys}")
    print(f"vae unexpected: {msg.unexpected_keys}")

    image_disc = init_image_disc(state_dict, image_disc, args)
    # video_disc = init_video_disc(state_dict, image_disc, args) # random init video discriminator
    return vae, image_disc, video_disc
def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Load the ``prefix``-ed entries of ``state_dict`` into ``model``, remapping names.

    For every key starting with ``prefix`` the remainder of the key is tried,
    independently, against three naming schemes of ``model.state_dict()``:

    1. the identical name;
    2. for names containing one of the conv markers, ``*.weight`` ->
       ``*.conv.weight`` and ``*.bias`` -> ``*.conv.bias`` (the original
       comment describes this as loading ``nn.Conv2d`` into a ``Conv`` wrapper);
    3. for names containing ``"norm"``, ``*.weight`` -> ``*.norm.weight`` and
       ``*.bias`` -> ``*.norm.bias`` (``nn.GroupNorm`` into ``Normalize``).

    Convolution weights whose shape differs from the target are inflated by
    inserting an axis at position 2, repeating the kernel along it and dividing
    by the repeat count (for scheme 1 only when ``expand`` is true and the name
    ends with ``.conv.weight``). With ``use_linear`` the ``q``/``k``/``v``/
    ``proj_out`` weights are ``squeeze()``-d before the identical-name copy and
    only ``"conv"`` is used as conv marker.

    Every consumed key is deleted from ``state_dict`` (the caller's dict is
    modified in place).

    Args:
        model: module whose parameters/buffers are updated in place.
        state_dict: checkpoint mapping; consumed entries are removed.
        prefix: key prefix selecting the entries meant for ``model``.
        expand: enable 2D -> 3D inflation for identically named conv weights.
        use_linear: treat attention projections as linear layers.

    Returns:
        ``(model, state_dict, loaded_keys)`` where ``loaded_keys`` lists the
        ``prefix``-ed target names that were written.
    """
    # state_dict() tensors alias the module's storage, so a single snapshot can
    # be used for both lookups and in-place copies instead of rebuilding it on
    # every access.
    model_state = model.state_dict()
    conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]

    delete_keys = []
    loaded_keys = []

    def _store(target_key, value, source_key):
        model_state[target_key].copy_(value)
        delete_keys.append(source_key)
        loaded_keys.append(prefix + target_key)

    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        source = state_dict[key]

        if _key in model_state:
            # load nn.Conv2d or nn.Linear to nn.Linear
            if use_linear and (".q.weight" in key or ".k.weight" in key or ".v.weight" in key or ".proj_out.weight" in key):
                load_weights = source.squeeze()
            elif _key.endswith(".conv.weight") and expand:
                load_weights = _inflate_conv_weight(source, model_state[_key])
            else:
                load_weights = source
            _store(_key, load_weights, key)

        # load nn.Conv2d to Conv class
        if any(marker in _key for marker in conv_list):
            if _key.endswith(".weight"):
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key in model_state:
                    _store(conv_key, _inflate_conv_weight(source, model_state[conv_key]), key)
            if _key.endswith(".bias"):
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key in model_state:
                    _store(conv_key, source, key)

        # load nn.GroupNorm to Normalize class
        if "norm" in _key:
            if _key.endswith(".weight"):
                norm_key = _key.replace(".weight", ".norm.weight")
                if norm_key in model_state:
                    _store(norm_key, source, key)
            if _key.endswith(".bias"):
                norm_key = _key.replace(".bias", ".norm.bias")
                if norm_key in model_state:
                    _store(norm_key, source, key)

    # A key can be consumed by more than one scheme above; delete it only once.
    for key in dict.fromkeys(delete_keys):
        del state_dict[key]

    return model, state_dict, loaded_keys


def _inflate_conv_weight(weight, target):
    """Return ``weight`` unchanged if it matches ``target``'s shape, else inflate 2D -> 3D.

    Inflation repeats the kernel ``target.shape[2]`` times along a new axis 2
    and divides by that count (normalising across the expanded dimension).
    """
    if target.shape == weight.shape:
        return weight
    depth = target.shape[2]
    return weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
def init_cnn_from_image(state_dict, vae, image_disc, video_disc, args, expand=False):
    """Initialise a CNN video VAE (and optionally the image discriminator) from a checkpoint.

    The encoder and decoder are loaded through ``load_cnn`` (which removes the
    entries it consumed from ``state_dict["vae"]``); whatever is left is then
    copied with ``load_unstrictly``. The image discriminator is initialised via
    ``init_image_disc`` when one is given. The video discriminator is returned
    as received (random init).

    Returns:
        ``(vae, image_disc, video_disc)``.
    """
    encoder, remaining, encoder_loaded = load_cnn(vae.encoder, state_dict["vae"], prefix="encoder.", expand=expand)
    vae.encoder = encoder
    state_dict["vae"] = remaining

    decoder, remaining, decoder_loaded = load_cnn(vae.decoder, state_dict["vae"], prefix="decoder.", expand=expand)
    vae.decoder = decoder
    state_dict["vae"] = remaining

    # Names already written by load_cnn must not be reported as missing.
    vae, _unloaded = load_unstrictly(state_dict["vae"], vae, encoder_loaded + decoder_loaded)

    if image_disc:
        image_disc = init_image_disc(state_dict, image_disc, args)

    # The video discriminator is intentionally left with its random initialisation.
    return vae, image_disc, video_disc
def resume_from_ckpt(state_dict, model_optims, load_optims=True):
    """Restore models, and optionally optimizers/schedulers, from a checkpoint.

    ``model_optims`` maps names to objects. Entries for which
    ``is_torch_optim_sch`` is false are treated as models and loaded first with
    ``load_unstrictly``. Entries for which it is true (presumably optimizers and
    schedulers - see ``is_torch_optim_sch``) are restored with
    ``load_state_dict`` only if no model weight was reported missing and
    ``load_optims`` is true.

    Entries that are falsy in ``model_optims``, or absent/falsy in
    ``state_dict``, are skipped.

    Args:
        state_dict: checkpoint mapping; must contain ``"step"``.
        model_optims: dict of objects to restore; model entries are re-assigned
            in place.
        load_optims: whether optimizer/scheduler states may be restored.

    Returns:
        ``(model_optims, state_dict["step"])``.

    Raises:
        KeyError: if ``state_dict`` has no ``"step"`` entry.
    """
    all_missing_keys = []
    # load weights first
    for k in model_optims:
        target = model_optims[k]
        # .get() rather than indexing: the former condition indexed
        # state_dict[k] before testing ``k in state_dict`` and therefore raised
        # KeyError for entries that are simply not part of the checkpoint.
        if target and state_dict.get(k) and not is_torch_optim_sch(target):
            model_optims[k], missing_keys = load_unstrictly(state_dict[k], target)
            all_missing_keys += missing_keys

    if not all_missing_keys and load_optims:
        print("Loading optimizer states")
        for k in model_optims:
            target = model_optims[k]
            if target and state_dict.get(k) and is_torch_optim_sch(target):
                target.load_state_dict(state_dict[k])
    else:
        print(f"missing weights: {all_missing_keys}, load_optims={load_optims}, do not load optimzer states")

    return model_optims, state_dict["step"]
### old version
# def get_last_ckpt(root_dir):
# if not os.path.exists(root_dir): return None, None
# ckpt_files = {}
# for dirpath, dirnames, filenames in os.walk(root_dir):
# for filename in filenames:
# if filename.endswith('.ckpt'):
# num_iter = int(filename.split('-')[1].split('=')[1])
# ckpt_files[num_iter]=os.path.join(dirpath, filename)
# iter_list = list(ckpt_files.keys())
# if len(iter_list) == 0: return None, None
# max_iter = max(iter_list)
# return ckpt_files[max_iter], max_iter