# 43 / Meissonic /InfinityStar /infinity /schedules /infinity_star_interact.py
# BryanW's picture
# Upload folder using huggingface_hub
# 3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import os
import json
import numpy as np
import torch
import torch.nn.functional as F
semantic_scale_ind = 7
detail_frame_inds = [18,19]
def flatten_two_level_list(two_level_list):
flatten_list = []
for item in two_level_list:
flatten_list.extend(item)
return flatten_list
def interpolate(tensor, size, mode, quantizer, is_semantic_scale):
"""
arguments:
tensor: (B,C,T,H,W)
size: (C1,T,H1,W1)
mode: str
quantizer: quantizer
is_semantic_scale: bool
return:
tensor: (B,*size)
"""
B, C, T, H, W = tensor.shape
C1, T, H1, W1 = size
if quantizer.other_args.use_learnable_dim_proj:
if is_semantic_scale:
if C > C1:
proj = quantizer.semantic_proj_down
elif C < C1:
proj = quantizer.semantic_proj_up
else:
if C > C1:
proj = quantizer.detail_proj_down
elif C < C1:
proj = quantizer.detail_proj_up
if C != C1:
tensor = tensor.permute(0,2,3,4,1) # (B,C,T,H,W) -> (B,T,H,W,C)
tensor = proj(tensor) # (B,T,H,W,C1)
tensor = tensor.permute(0,4,1,2,3) # (B,T,H,W,C1) -> (B,C1,T,H,W)
tensor = F.interpolate(tensor, size=(T, H1, W1), mode=mode) # (B,C1,T,H,W) -> (B,C1,T,H1,W1)
return tensor
else:
tensor = tensor.permute(0,2,1,3,4) # (B,C,T,H,W) -> (B,T,C,H,W)
tensor = F.interpolate(tensor, size=(C1, H1, W1), mode=mode)
tensor = tensor.permute(0,2,1,3,4) # (B,T,C1,H1,W1) -> (B,C1,T,H1,W1)
return tensor
def get_scale_pack_info(scale_schedule, first_full_spatial_size_scale_index, args):
meta = {}
sid2clipid_innsid = {}
clipid_innsid2sid = {}
scales_per_clip = first_full_spatial_size_scale_index + 1
compress_frames_inner_clip = args.frames_inner_clip
total_clips = len(scale_schedule) // scales_per_clip
context_clips = args.context_frames // args.frames_inner_clip
for si in range(len(scale_schedule)):
clipid = si // scales_per_clip
if clipid == 0:
frame_ss, frame_ee = 0, scale_schedule[scales_per_clip*1-1][0]
else:
frame_ss = scale_schedule[scales_per_clip*1-1][0] + (clipid-1) * compress_frames_inner_clip
frame_ee = frame_ss + scale_schedule[scales_per_clip*(clipid+1)-1][0]
if context_clips < total_clips-1:
assert scale_schedule[si][0] == compress_frames_inner_clip
sid2clipid_innsid[si] = (clipid, si % scales_per_clip)
clipid_innsid2sid[(clipid, si % scales_per_clip)] = si
# add clip ind for ref
if si <= first_full_spatial_size_scale_index:
meta[si] = {
'clipid': clipid,
'frame_ss': frame_ss,
'frame_ee': frame_ee,
'left_ref': [-1],
'right_ref': [-1],
}
else:
meta[si] = {
'clipid': clipid,
'frame_ss': frame_ss,
'frame_ee': frame_ee,
'left_ref': [clipid-1],
'right_ref': [-1],
}
# append inner scale ind to clip ind, (frame pack)
if args.context_from_largest_no > 0:
meta[si]['left_ref'] = [(meta[si]['left_ref'][i], max(0, scales_per_clip - args.context_from_largest_no - args.context_interval*i)) for i in range(len(meta[si]['left_ref']))]
meta[si]['right_ref'] = [(meta[si]['right_ref'][i], max(0, scales_per_clip - args.context_from_largest_no - args.context_interval*i)) for i in range(len(meta[si]['right_ref']))]
for si in meta:
meta[si]['left_ref_sids'], meta[si]['right_ref_sids'] = [], []
for clipid, innsid in (meta[si]['left_ref']):
if clipid != -1:
meta[si]['left_ref_sids'].append(clipid_innsid2sid[(clipid, innsid)])
for fid, innsid in (meta[si]['right_ref']):
if fid != -1:
meta[si]['right_ref_sids'].append(clipid_innsid2sid[(clipid, innsid)])
meta[si]['ref_sids'] = meta[si]['left_ref_sids'] + meta[si]['right_ref_sids']
return meta
def video_encode(
vae,
inp_B3HW,
vae_features=None,
self_correction=None,
device='cuda',
args=None,
infer_mode=False,
rope2d_freqs_grid=None,
dynamic_resolution_h_w=None,
text_lens=[],
caption_nums=None,
rank=0,
vis_verbose=False,
np_generator=None,
skip_last=0,
train_max_token_len=0,
first_frame_features=[],
**kwargs,
):
if vae_features is None:
raw_features, _, _ = vae.encode_for_raw_features(inp_B3HW, scale_schedule=None, slice=True)
raw_features_list = [raw_features]
x_recon_raw = vae.decode(raw_features[0], slice=True)
x_recon_raw = torch.clamp(x_recon_raw, min=-1, max=1)
print(f'raw_features.shape: {raw_features[0].shape}')
else:
raw_features_list = vae_features
if np_generator is not None:
random_obj = np_generator
else:
random_obj = np.random.default_rng()
# raw_features_list: list of [1,d,t,h,w]:
gt_all_bit_indices = []
pred_all_bit_indices = []
var_input_list = []
sequece_packing_scales = [] # with trunk
flatten_packing_scales = []
h_div_w_template_list = np.array(list(dynamic_resolution_h_w.keys()))
visual_rope_cache_list = []
noise_list = []
scale_pack_info_list = []
image_scale_repetition = json.loads(args.image_scale_repetition)
video_scale_repetition = json.loads(args.video_scale_repetition)
scales_in_one_clip = dynamic_resolution_h_w[h_div_w_template_list[0]][args.pn]['scales_in_one_clip']
other_info_by_scale = []
select_repeat_idx_list = []
examples = len(raw_features_list)
assert len(image_scale_repetition) == len(video_scale_repetition), f'{len(image_scale_repetition)} != {len(video_scale_repetition)}'
assert examples == 1, f'currently only support examples==1, buf found {examples=}'
with torch.amp.autocast('cuda', enabled = False):
for example_ind, complete_raw_features in enumerate(raw_features_list):
complete_raw_features = complete_raw_features[0]
if first_frame_features[example_ind] is None:
first_frame_feature_ = complete_raw_features[:,:,0:1] # [B,d,1,h,w]
else:
first_frame_feature_ = first_frame_features[example_ind][0] # [B,d,1,h,w]
# assert complete_raw_features.shape[-3] > 21
# 前21帧,构成一个 I1V1 的 clip
# 后面的 t-21 帧,构成一个 V2 的 clip,它的 condition 是 V1 resize 后的结果和 V1 的最后一帧
new_raw_features_list = [complete_raw_features[:,:,:21], complete_raw_features[:,:,21:]]
t, h, w = new_raw_features_list[0].shape[-3:]
h_div_w = h / w
mapped_h_div_w_template = h_div_w_template_list[np.argmin(np.abs(h_div_w-h_div_w_template_list))]
min_t = min(dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'].keys())
image_scale_schedule = dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'][min_t]
scale_schedule = dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'][t]
for ind, raw_features in enumerate(new_raw_features_list):
if raw_features.numel() == 0:
break
mode = 'first_iv_clip'
global_si_base = 0
if ind == 1:
scale_schedule = scale_schedule[scales_in_one_clip:]
scale_schedule = [(raw_features.shape[-3], ph, pw) for pt, ph, pw in scale_schedule]
mode = 'second_v_clip'
global_si_base = sum(image_scale_repetition) + sum(video_scale_repetition)
if args.apply_spatial_patchify:
vae_scale_schedule = [(pt, ph*2, pw*2) for pt, ph, pw in scale_schedule]
else:
vae_scale_schedule = scale_schedule
first_full_spatial_size_scale_index = len(image_scale_schedule) - 1
scale_pack_info = get_scale_pack_info(vae_scale_schedule, first_full_spatial_size_scale_index, args)
scale_pack_info_list.append(scale_pack_info)
if raw_features.dim() == 4:
codes_out = raw_features.unsqueeze(2) # [B, d, t, h, w]
else:
codes_out = raw_features # [B, d, t, h, w]
# print(f'{raw_features.shape=}, {scale_schedule=}')
v_d = codes_out.shape[1]
B, C, T, H, W = codes_out.shape
if args.noise_input:
noise = torch.randn((B, v_d, *vae_scale_schedule[0]), device=device, dtype=raw_features.dtype)
else:
noise = torch.zeros((B, v_d, *vae_scale_schedule[0]), device=device, dtype=raw_features.dtype)
if infer_mode: noise_list.append(noise)
next_var_input = noise
valid_scales = len(vae_scale_schedule) - skip_last
assert len(image_scale_repetition) == len(image_scale_schedule), f'{len(image_scale_repetition)} != {len(image_scale_schedule)}'
real_si = 0
noise_apply_strength = self_correction.noise_apply_strength
for si in range(valid_scales):
pt, ph, pw = vae_scale_schedule[si]
rel_si_in_one_clip = si % len(image_scale_schedule)
if si < len(image_scale_schedule): # image
repeat_times = image_scale_repetition[rel_si_in_one_clip]
else:
repeat_times = video_scale_repetition[rel_si_in_one_clip]
select_repeat_idx = random_obj.integers(0, repeat_times)
select_repeat_idx_list.append(select_repeat_idx)
frame_ss, frame_ee = scale_pack_info[si]['frame_ss'], scale_pack_info[si]['frame_ee']
target = codes_out[:,:,frame_ss:frame_ee]
for repeat_idx in range(repeat_times):
if (not infer_mode) and (repeat_idx==select_repeat_idx):
visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, scale_schedule[-1], scale_schedule[si], list(range(frame_ss, frame_ee)), real_si, device))
if next_var_input.shape[-3:] != target.shape[-3:]:
next_var_input = F.interpolate(next_var_input, size=target.shape[-3:], mode=vae.quantizer.z_interplote_up).contiguous()
cum_var_input = next_var_input
this_scale_var_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si], mode=vae.quantizer.z_interplote_down).contiguous()
residual = target - cum_var_input
if args.use_two_stage_lfq:
if rel_si_in_one_clip >= args.semantic_scales:
is_semantic_scale = False
C1 = vae.quantizer.detail_scale_dim
lfq = vae.quantizer.lfq_detail
else:
is_semantic_scale = True
C1 = vae.quantizer.semantic_scale_dim
lfq = vae.quantizer.lfq_semantic
residual = interpolate(residual, size=(C1, *vae_scale_schedule[si]), mode=vae.quantizer.z_interplote_down, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
else:
residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=vae.quantizer.z_interplote_down).contiguous()
try:
lfq = vae.quantizer.lfq_detail
except:
lfq = vae.quantizer.lfq
quantized, _, bit_indices, loss = lfq(residual) # quantized shape: [B, d, t, h, w], bit_indices shape: [B,t,h,w,d]
if args.reduce_accumulate_error_method == 'bsc':
if si < min(len(vae_scale_schedule)-1, self_correction.noise_apply_layers):
pred_bit_indices, quantized = self_correction.apply_noise_requant(bit_indices, quantized, args, device, si, lfq, noise_apply_strength, num_lvl=2, np_generator=random_obj)
else:
pred_bit_indices = bit_indices
else:
raise NotImplementedError(args.reduce_accumulate_error_method)
if infer_mode or (repeat_idx==select_repeat_idx):
pred_all_bit_indices.append(pred_bit_indices)
var_input_list.append(this_scale_var_input)
gt_all_bit_indices.append(bit_indices)
other_info_by_scale.append({'largest_scale': scale_schedule[-1], 'real_si': si, 'mode': mode, 'global_si': real_si+global_si_base})
if args.use_two_stage_lfq:
quantized_scaled = interpolate(quantized, size=target.shape[-4:], mode=vae.quantizer.z_interplote_up, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
else:
quantized_scaled = F.interpolate(quantized, size=target.shape[-3:], mode=vae.quantizer.z_interplote_up).contiguous()
next_var_input = cum_var_input + quantized_scaled
real_si += 1
if si < len(vae_scale_schedule)-1: # since first scale is [sos], here we only need len(vae_scale_schedule)-1 cum_var_input and x_BLC_wo_prefix
if vae_scale_schedule[si][-2:] == vae_scale_schedule[-1][-2:]:
if args.noise_input:
next_var_input = torch.randn((B, v_d, *vae_scale_schedule[si+1]), device=device, dtype=raw_features.dtype)
else:
next_var_input = torch.zeros((B, v_d, *vae_scale_schedule[si+1]), device=device, dtype=raw_features.dtype)
if infer_mode: noise_list.append(next_var_input)
sequece_packing_scales.append(scale_schedule[:valid_scales])
if ind == 0:
former_clip_features = raw_features[:,:,-20:]
if infer_mode:
return noise_list, x_recon_raw, pred_all_bit_indices, None, None, scale_pack_info
if vis_verbose:
print(f'Rank={rank}, {sequece_packing_scales=} {select_repeat_idx_list=}', force=True)
if args.train_second_clip_only:
drop_scales = len(sequece_packing_scales[0])
sequece_packing_scales = sequece_packing_scales[1:]
scale_pack_info_list = scale_pack_info_list[1:]
gt_all_bit_indices = gt_all_bit_indices[drop_scales:]
pred_all_bit_indices = pred_all_bit_indices[drop_scales:]
other_info_by_scale = other_info_by_scale[drop_scales:]
var_input_list = var_input_list[drop_scales:]
visual_rope_cache_list = visual_rope_cache_list[drop_scales:]
flatten_packing_scales = flatten_two_level_list(sequece_packing_scales)
def add_noise(features, noise_choices=[0.00, 0.15, 0.30]):
feature_std = features.std()
rand_noise_strength = np.random.choice(noise_choices)
return features + rand_noise_strength * feature_std * torch.randn_like(features)
# add conditions
semantic_condition = F.interpolate(former_clip_features, size=(20, *scale_schedule[semantic_scale_ind][-2:]), mode=vae.quantizer.z_interplote_down)
semantic_condition = add_noise(semantic_condition)
assert former_clip_features.shape[2] == 20
detail_condition = torch.cat([first_frame_feature_, add_noise(former_clip_features[:,:,detail_frame_inds])], dim=2)
var_input_list.extend([semantic_condition, detail_condition])
visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, detail_condition.shape[-3:], semantic_condition.shape[-3:], list(range(1, 21)), 800, device))
visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, detail_condition.shape[-3:], detail_condition.shape[-3:], [0]+[item+1 for item in detail_frame_inds], 801, device))
# set scale_lengths and querysid_refsid
scale_lengths = [ pt * ph * pw for pt,ph,pw in flatten_packing_scales]
scale_lengths = scale_lengths + [torch.tensor(semantic_condition.shape[-3:]).prod().item(), torch.tensor(detail_condition.shape[-3:]).prod().item()]
scale_lengths = scale_lengths + text_lens
valid_scales = len(scale_lengths)
pad_seq_len = train_max_token_len - np.sum(scale_lengths)
assert pad_seq_len >= 0, f'pad_seq_len: {pad_seq_len} < 0, {scale_lengths=}'
if pad_seq_len:
scale_lengths = scale_lengths + [pad_seq_len]
max_sid_nums = 2000
querysid_refsid = torch.zeros((max_sid_nums, max_sid_nums), device=args.device, dtype=torch.bool) # Attention! this shape should be the same for different iterations !!!
for i in range(valid_scales):
querysid_refsid[i][i] = True
base = 0
for ind, scale_schedule in enumerate(sequece_packing_scales):
real_example_ind = ind // 2 # for each example, there are two scale_schedule
scale_pack_info = scale_pack_info_list[ind]
for local_querysid in range(len(scale_schedule)):
global_querysid = base + local_querysid
if other_info_by_scale[base+local_querysid]['mode'] == 'first_iv_clip':
global_text_sid = len(flatten_packing_scales) + 2 + sum(caption_nums[:real_example_ind]) + 0
querysid_refsid[global_querysid][global_text_sid] = True
elif other_info_by_scale[base+local_querysid]['mode'] == 'second_v_clip':
global_text_sid = len(flatten_packing_scales) + 2 + sum(caption_nums[:real_example_ind]) + 1
querysid_refsid[global_querysid][global_text_sid] = True
querysid_refsid[global_querysid][len(flatten_packing_scales)+0] = True # i can see semantic condition
querysid_refsid[global_querysid][len(flatten_packing_scales)+1] = True # i can see detail condition
else:
raise ValueError(f'Unknown mode: {other_info_by_scale[base+local_querysid]["mode"]}')
for local_refsid in (scale_pack_info[local_querysid]['ref_sids']):
global_refsid = base + local_refsid
querysid_refsid[global_querysid][global_refsid] = True
base += len(scale_schedule)
gt_ms_idx_Bl = []
for item in gt_all_bit_indices:
if args.apply_spatial_patchify:
# item shape: (B,t,H,W,d)
item = item.permute(0,1,4,2,3) # (B,t,d,H,W)
# (B,t,d,H,W) -> (B,t,4d,H/2,W/2)
item = torch.nn.functional.pixel_unshuffle(item, 2)
_, tt, dd, hh, ww = item.shape
# (B,t,4d,H/2,W/2) -> (B,t,H/2,W/2,4d) -> (B,t*H/2*w/2,4d)
item = item.permute(0,1,3,4,2).reshape(B, tt*hh*ww, dd)
else:
_, tt, hh, ww, dd = item.shape
item = item.reshape(B, tt*hh*ww, dd)
gt_ms_idx_Bl.append(item.type(torch.long))
gt_BLC = gt_ms_idx_Bl # torch.cat(gt_ms_idx_Bl, 1).contiguous().type(torch.long)
for i in range(len(var_input_list)):
if args.apply_spatial_patchify:
# (B,d,t,H,W) -> (B,t,d,H,W) -> (B,t,4d,H/2,W/2) -> (B,t,H/2,W/2,4d)
var_input_list[i] = torch.nn.functional.pixel_unshuffle(var_input_list[i].permute(0,2,1,3,4), 2).permute(0,1,3,4,2)
var_input_list[i] = var_input_list[i].reshape(B, -1, 4*vae.codebook_dim)
else:
# (B,d,t,H,W) -> (B,t,H,W,d)
var_input_list[i] = var_input_list[i].permute(0,2,3,4,1)
var_input_list[i] = var_input_list[i].reshape(B, -1, vae.codebook_dim)
x_BLC = torch.cat(var_input_list, 1)
visual_rope_cache = torch.cat(visual_rope_cache_list, dim=4)
x_BLC_mask = None
return x_BLC, x_BLC_mask, gt_BLC, pred_all_bit_indices, visual_rope_cache, sequece_packing_scales, scale_lengths, querysid_refsid, other_info_by_scale, pad_seq_len
def video_decode(
vae,
all_indices,
scale_schedule,
label_type,
args=None,
noise_list=None,
trunc_scales=-1,
**kwargs,
):
image_scale_repetition = json.loads(args.image_scale_repetition)
video_scale_repetition = json.loads(args.video_scale_repetition)
assert len(image_scale_repetition) == len(video_scale_repetition), f'{len(image_scale_repetition)} != {len(video_scale_repetition)}'
real_si = 0
noise_ptr = 0
summed_codes = []
scales_in_one_clip = args.first_full_spatial_size_scale_index+1
clips = len(noise_list) - 1
for clip_id in range(clips):
if clip_id == 1:
scale_schedule = scale_schedule[(args.first_full_spatial_size_scale_index+1):]
t = all_indices[-1].shape[1] # [B,t,h,w,d]
scale_schedule = [(t, ph, pw) for pt, ph, pw in scale_schedule]
summed_codes.append(noise_list[noise_ptr])
noise_ptr += 1
v_d = summed_codes[0].shape[1]
for si, (pt, ph, pw) in enumerate(scale_schedule):
if si < len(image_scale_repetition): # image
repeat_times = image_scale_repetition[si%len(image_scale_repetition)]
else:
repeat_times = video_scale_repetition[si%len(image_scale_repetition)]
for repeat_idx in range(repeat_times):
tgt_shape = (pt, scale_schedule[-1][-2], scale_schedule[-1][-1])
if args.use_two_stage_lfq:
if (si % scales_in_one_clip) >= args.semantic_scales:
is_semantic_scale = False
lfq = vae.quantizer.lfq_detail
else:
is_semantic_scale = True
lfq = vae.quantizer.lfq_semantic
codes = lfq.indices_to_codes(all_indices[real_si], label_type)
codes = interpolate(codes, size=(v_d, *tgt_shape), mode=vae.quantizer.z_interplote_up, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
else:
codes = vae.quantizer.lfq_detail.indices_to_codes(all_indices[real_si], label_type)
codes = F.interpolate(codes, size=tgt_shape, mode=vae.quantizer.z_interplote_up).contiguous()
summed_codes[-1] = F.interpolate(summed_codes[-1], size=tgt_shape, mode=vae.quantizer.z_interplote_up).contiguous()
summed_codes[-1] += codes
real_si += 1
if si < len(scale_schedule)-1 and scale_schedule[si][-2:] == tgt_shape[-2:]:
summed_codes.append(noise_list[noise_ptr])
noise_ptr += 1
summed_codes = torch.cat(summed_codes, dim=-3)
x_recon = vae.decode(summed_codes, slice=True)
x_recon = torch.clamp(x_recon, min=-1, max=1)
return x_recon
def get_visual_rope_embeds(rope2d_freqs_grid, largest_scale, current_scale, t_list, real_sid, device=None):
# freqs_scales: (2, max_scales, ceil(dim_div_2 / 4))
# freqs_frames: (2, max_frames, ceil(dim_div_2 / 4))
rope2d_freqs_grid['freqs_scales'] = rope2d_freqs_grid['freqs_scales'].to(device)
rope2d_freqs_grid['freqs_frames'] = rope2d_freqs_grid['freqs_frames'].to(device)
rope2d_freqs_grid['freqs_height'] = rope2d_freqs_grid['freqs_height'].to(device)
rope2d_freqs_grid['freqs_width'] = rope2d_freqs_grid['freqs_width'].to(device)
_, uph, upw = largest_scale
pt, ph, pw = current_scale
dim_div_2_div_4 = rope2d_freqs_grid['freqs_scales'].shape[2]
dim_div_2 = dim_div_2_div_4 * 4
f_scales = rope2d_freqs_grid['freqs_scales'][:, real_sid].reshape(2, 1, dim_div_2_div_4)
f_frames = rope2d_freqs_grid['freqs_frames'][:, t_list]
f_height = rope2d_freqs_grid['freqs_height'][:, (torch.arange(ph) * (uph / ph)).round().int()]
f_width = rope2d_freqs_grid['freqs_width'][:, (torch.arange(pw) * (upw / pw)).round().int()]
rope_embeds = torch.cat([
f_scales[ :, :, None, None, None, :].expand(-1, -1, pt, ph, pw, -1),
f_frames[ :, None, :, None, None, :].expand(-1, 1, -1, ph, pw, -1),
f_height[ :, None, None, :, None, :].expand(-1, 1, pt, -1, pw, -1),
f_width[ :, None, None, None, :, :].expand(-1, 1, pt, ph, -1, -1),
], dim=-1) # (2, 1, pt, ph, pw, dim_div_2)
rope_embeds = rope_embeds.reshape(2, 1, 1, 1, 1*pt*ph*pw, dim_div_2) # (2, 1, 1, 1, 1*pt*ph*pw, dim_div_2)
return rope_embeds