import os
import json

import numpy as np
import torch
import torch.nn.functional as F

semantic_scale_ind = 7
detail_frame_inds = [18, 19]


def flatten_two_level_list(two_level_list):
    flatten_list = []
    for item in two_level_list:
        flatten_list.extend(item)
    return flatten_list
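
# Example: flatten_two_level_list([[1, 2], [3]]) -> [1, 2, 3]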


def interpolate(tensor, size, mode, quantizer, is_semantic_scale):
    """
    Arguments:
        tensor: (B, C, T, H, W)
        size: (C1, T, H1, W1)
        mode: str, interpolation mode for F.interpolate
        quantizer: quantizer that owns the learnable channel projections
        is_semantic_scale: bool, selects the semantic vs. detail projection
    Returns:
        tensor: (B, *size)
    """
    B, C, T, H, W = tensor.shape
    C1, T, H1, W1 = size
    if quantizer.other_args.use_learnable_dim_proj:
        # Change the channel dim with a learned projection, then resize space.
        if is_semantic_scale:
            if C > C1:
                proj = quantizer.semantic_proj_down
            elif C < C1:
                proj = quantizer.semantic_proj_up
        else:
            if C > C1:
                proj = quantizer.detail_proj_down
            elif C < C1:
                proj = quantizer.detail_proj_up
        if C != C1:
            tensor = tensor.permute(0, 2, 3, 4, 1)
            tensor = proj(tensor)
            tensor = tensor.permute(0, 4, 1, 2, 3)
        tensor = F.interpolate(tensor, size=(T, H1, W1), mode=mode)
        return tensor
    else:
        # No learned projection: treat the channel dim as a spatial dim and
        # resize it together with H and W.
        tensor = tensor.permute(0, 2, 1, 3, 4)
        tensor = F.interpolate(tensor, size=(C1, H1, W1), mode=mode)
        tensor = tensor.permute(0, 2, 1, 3, 4)
        return tensor
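
# Example (sketch): without use_learnable_dim_proj, a (1, 32, 5, 8, 8) tensor
# with size=(16, 5, 4, 4) comes back as (1, 16, 5, 4, 4): the channel dim is
# resized by interpolation instead of a learned projection.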


def get_scale_pack_info(scale_schedule, first_full_spatial_size_scale_index, args):
    meta = {}
    sid2clipid_innsid = {}
    clipid_innsid2sid = {}
    scales_per_clip = first_full_spatial_size_scale_index + 1
    compress_frames_inner_clip = args.frames_inner_clip
    total_clips = len(scale_schedule) // scales_per_clip
    context_clips = args.context_frames // args.frames_inner_clip
    for si in range(len(scale_schedule)):
        clipid = si // scales_per_clip
        if clipid == 0:
            frame_ss, frame_ee = 0, scale_schedule[scales_per_clip - 1][0]
        else:
            frame_ss = scale_schedule[scales_per_clip - 1][0] + (clipid - 1) * compress_frames_inner_clip
            frame_ee = frame_ss + scale_schedule[scales_per_clip * (clipid + 1) - 1][0]
            if context_clips < total_clips - 1:
                assert scale_schedule[si][0] == compress_frames_inner_clip
        sid2clipid_innsid[si] = (clipid, si % scales_per_clip)
        clipid_innsid2sid[(clipid, si % scales_per_clip)] = si

        if si <= first_full_spatial_size_scale_index:
            meta[si] = {
                'clipid': clipid,
                'frame_ss': frame_ss,
                'frame_ee': frame_ee,
                'left_ref': [-1],
                'right_ref': [-1],
            }
        else:
            meta[si] = {
                'clipid': clipid,
                'frame_ss': frame_ss,
                'frame_ee': frame_ee,
                'left_ref': [clipid - 1],
                'right_ref': [-1],
            }

        if args.context_from_largest_no > 0:
            meta[si]['left_ref'] = [(meta[si]['left_ref'][i], max(0, scales_per_clip - args.context_from_largest_no - args.context_interval * i)) for i in range(len(meta[si]['left_ref']))]
            meta[si]['right_ref'] = [(meta[si]['right_ref'][i], max(0, scales_per_clip - args.context_from_largest_no - args.context_interval * i)) for i in range(len(meta[si]['right_ref']))]
    for si in meta:
        meta[si]['left_ref_sids'], meta[si]['right_ref_sids'] = [], []
        for clipid, innsid in meta[si]['left_ref']:
            if clipid != -1:
                meta[si]['left_ref_sids'].append(clipid_innsid2sid[(clipid, innsid)])
        for fid, innsid in meta[si]['right_ref']:
            if fid != -1:
                meta[si]['right_ref_sids'].append(clipid_innsid2sid[(fid, innsid)])
        meta[si]['ref_sids'] = meta[si]['left_ref_sids'] + meta[si]['right_ref_sids']
    return meta
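
# Sketch of one meta entry (values illustrative; fields as built above):
#   meta[si] = {
#       'clipid': 1,                     # clip this scale belongs to
#       'frame_ss': 21, 'frame_ee': 25,  # latent-frame span covered
#       'left_ref': [(0, innsid)],       # (clipid, inner scale id) context
#       'right_ref': [(-1, innsid)],     # -1 means "no reference"
#       'left_ref_sids': [...],
#       'right_ref_sids': [],
#       'ref_sids': [...],               # flattened scale ids to attend to
#   }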


def video_encode(
    vae,
    inp_B3HW,
    vae_features=None,
    self_correction=None,
    device='cuda',
    args=None,
    infer_mode=False,
    rope2d_freqs_grid=None,
    dynamic_resolution_h_w=None,
    text_lens=[],
    caption_nums=None,
    rank=0,
    vis_verbose=False,
    np_generator=None,
    skip_last=0,
    train_max_token_len=0,
    first_frame_features=[],
    **kwargs,
):
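    """
    Pack one video into a multi-scale VAR training sequence.

    Encodes inp_B3HW with the VAE (unless vae_features is given), splits the
    latents into a first image+video clip of 21 latent frames and a second
    video clip, quantizes per-scale residuals with (optionally two-stage) LFQ,
    applies self-correction to the predicted bits, and returns the packed
    inputs, targets, rope cache, per-scale lengths and block attention mask.
    """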
    if vae_features is None:
        raw_features, _, _ = vae.encode_for_raw_features(inp_B3HW, scale_schedule=None, slice=True)
        raw_features_list = [raw_features]
        x_recon_raw = vae.decode(raw_features[0], slice=True)
        x_recon_raw = torch.clamp(x_recon_raw, min=-1, max=1)
        print(f'raw_features.shape: {raw_features[0].shape}')
    else:
        raw_features_list = vae_features
        x_recon_raw = None  # no raw reconstruction when features are precomputed

    if np_generator is not None:
        random_obj = np_generator
    else:
        random_obj = np.random.default_rng()

    gt_all_bit_indices = []
    pred_all_bit_indices = []
    var_input_list = []
    sequence_packing_scales = []
    flatten_packing_scales = []
    h_div_w_template_list = np.array(list(dynamic_resolution_h_w.keys()))
    visual_rope_cache_list = []
    noise_list = []
    scale_pack_info_list = []
    image_scale_repetition = json.loads(args.image_scale_repetition)
    video_scale_repetition = json.loads(args.video_scale_repetition)
    scales_in_one_clip = dynamic_resolution_h_w[h_div_w_template_list[0]][args.pn]['scales_in_one_clip']
    other_info_by_scale = []
    select_repeat_idx_list = []
    examples = len(raw_features_list)
    assert len(image_scale_repetition) == len(video_scale_repetition), f'{len(image_scale_repetition)} != {len(video_scale_repetition)}'
    assert examples == 1, f'currently only support examples == 1, but found {examples=}'
    with torch.amp.autocast('cuda', enabled=False):
        for example_ind, complete_raw_features in enumerate(raw_features_list):
            complete_raw_features = complete_raw_features[0]
            if first_frame_features[example_ind] is None:
                first_frame_feature_ = complete_raw_features[:, :, 0:1]
            else:
                first_frame_feature_ = first_frame_features[example_ind][0]

            # Split the latents into the first image+video clip (21 latent
            # frames) and the remaining video clip.
            new_raw_features_list = [complete_raw_features[:, :, :21], complete_raw_features[:, :, 21:]]
            t, h, w = new_raw_features_list[0].shape[-3:]
            h_div_w = h / w
            mapped_h_div_w_template = h_div_w_template_list[np.argmin(np.abs(h_div_w - h_div_w_template_list))]
            min_t = min(dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'].keys())
            image_scale_schedule = dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'][min_t]
            scale_schedule = dynamic_resolution_h_w[mapped_h_div_w_template][args.pn]['pt2scale_schedule'][t]

            for ind, raw_features in enumerate(new_raw_features_list):
                if raw_features.numel() == 0:
                    break
                mode = 'first_iv_clip'
                global_si_base = 0
                if ind == 1:
                    # Second clip: keep only the video tail of the schedule
                    # and rewrite its temporal extent.
                    scale_schedule = scale_schedule[scales_in_one_clip:]
                    scale_schedule = [(raw_features.shape[-3], ph, pw) for pt, ph, pw in scale_schedule]
                    mode = 'second_v_clip'
                    global_si_base = sum(image_scale_repetition) + sum(video_scale_repetition)

                if args.apply_spatial_patchify:
                    vae_scale_schedule = [(pt, ph * 2, pw * 2) for pt, ph, pw in scale_schedule]
                else:
                    vae_scale_schedule = scale_schedule
                first_full_spatial_size_scale_index = len(image_scale_schedule) - 1
                scale_pack_info = get_scale_pack_info(vae_scale_schedule, first_full_spatial_size_scale_index, args)
                scale_pack_info_list.append(scale_pack_info)

                if raw_features.dim() == 4:
                    codes_out = raw_features.unsqueeze(2)
                else:
                    codes_out = raw_features

                v_d = codes_out.shape[1]
                B, C, T, H, W = codes_out.shape
                if args.noise_input:
                    noise = torch.randn((B, v_d, *vae_scale_schedule[0]), device=device, dtype=raw_features.dtype)
                else:
                    noise = torch.zeros((B, v_d, *vae_scale_schedule[0]), device=device, dtype=raw_features.dtype)
                if infer_mode:
                    noise_list.append(noise)
                next_var_input = noise
                valid_scales = len(vae_scale_schedule) - skip_last
                assert len(image_scale_repetition) == len(image_scale_schedule), f'{len(image_scale_repetition)} != {len(image_scale_schedule)}'
                real_si = 0
                noise_apply_strength = self_correction.noise_apply_strength
                for si in range(valid_scales):
                    pt, ph, pw = vae_scale_schedule[si]
                    rel_si_in_one_clip = si % len(image_scale_schedule)
                    if si < len(image_scale_schedule):
                        repeat_times = image_scale_repetition[rel_si_in_one_clip]
                    else:
                        repeat_times = video_scale_repetition[rel_si_in_one_clip]
                    # Each scale may be repeated; one repetition is randomly
                    # selected to contribute to the training sequence.
                    select_repeat_idx = random_obj.integers(0, repeat_times)
                    select_repeat_idx_list.append(select_repeat_idx)
                    frame_ss, frame_ee = scale_pack_info[si]['frame_ss'], scale_pack_info[si]['frame_ee']
                    target = codes_out[:, :, frame_ss:frame_ee]
                    for repeat_idx in range(repeat_times):
                        if (not infer_mode) and (repeat_idx == select_repeat_idx):
                            visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, scale_schedule[-1], scale_schedule[si], list(range(frame_ss, frame_ee)), real_si, device))
                        if next_var_input.shape[-3:] != target.shape[-3:]:
                            next_var_input = F.interpolate(next_var_input, size=target.shape[-3:], mode=vae.quantizer.z_interplote_up).contiguous()
                        cum_var_input = next_var_input
                        this_scale_var_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si], mode=vae.quantizer.z_interplote_down).contiguous()
                        # Quantize the residual between the target latents and
                        # the running reconstruction, downsampled to this scale.
                        residual = target - cum_var_input
                        if args.use_two_stage_lfq:
                            # Early (semantic) scales and later (detail) scales
                            # use separate quantizers and channel widths.
                            if rel_si_in_one_clip >= args.semantic_scales:
                                is_semantic_scale = False
                                C1 = vae.quantizer.detail_scale_dim
                                lfq = vae.quantizer.lfq_detail
                            else:
                                is_semantic_scale = True
                                C1 = vae.quantizer.semantic_scale_dim
                                lfq = vae.quantizer.lfq_semantic
                            residual = interpolate(residual, size=(C1, *vae_scale_schedule[si]), mode=vae.quantizer.z_interplote_down, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
                        else:
                            residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=vae.quantizer.z_interplote_down).contiguous()
                            try:
                                lfq = vae.quantizer.lfq_detail
                            except AttributeError:
                                lfq = vae.quantizer.lfq
                        quantized, _, bit_indices, loss = lfq(residual)

                        # Self-correction: perturb the quantized bits on early
                        # scales so training sees its own accumulated errors.
                        if args.reduce_accumulate_error_method == 'bsc':
                            if si < min(len(vae_scale_schedule) - 1, self_correction.noise_apply_layers):
                                pred_bit_indices, quantized = self_correction.apply_noise_requant(bit_indices, quantized, args, device, si, lfq, noise_apply_strength, num_lvl=2, np_generator=random_obj)
                            else:
                                pred_bit_indices = bit_indices
                        else:
                            raise NotImplementedError(args.reduce_accumulate_error_method)

                        if infer_mode or (repeat_idx == select_repeat_idx):
                            pred_all_bit_indices.append(pred_bit_indices)
                            var_input_list.append(this_scale_var_input)
                            gt_all_bit_indices.append(bit_indices)
                            other_info_by_scale.append({'largest_scale': scale_schedule[-1], 'real_si': si, 'mode': mode, 'global_si': real_si + global_si_base})
                        if args.use_two_stage_lfq:
                            quantized_scaled = interpolate(quantized, size=target.shape[-4:], mode=vae.quantizer.z_interplote_up, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
                        else:
                            quantized_scaled = F.interpolate(quantized, size=target.shape[-3:], mode=vae.quantizer.z_interplote_up).contiguous()
                        next_var_input = cum_var_input + quantized_scaled
                        real_si += 1

                    if si < len(vae_scale_schedule) - 1:
                        if vae_scale_schedule[si][-2:] == vae_scale_schedule[-1][-2:]:
                            # Full spatial size reached: the next scale restarts
                            # from fresh noise (or zeros).
                            if args.noise_input:
                                next_var_input = torch.randn((B, v_d, *vae_scale_schedule[si + 1]), device=device, dtype=raw_features.dtype)
                            else:
                                next_var_input = torch.zeros((B, v_d, *vae_scale_schedule[si + 1]), device=device, dtype=raw_features.dtype)
                            if infer_mode:
                                noise_list.append(next_var_input)

                sequence_packing_scales.append(scale_schedule[:valid_scales])
                if ind == 0:
                    former_clip_features = raw_features[:, :, -20:]

    if infer_mode:
        return noise_list, x_recon_raw, pred_all_bit_indices, None, None, scale_pack_info

    if vis_verbose:
        print(f'Rank={rank}, {sequence_packing_scales=} {select_repeat_idx_list=}', force=True)

    if args.train_second_clip_only:
        # Drop everything belonging to the first clip from the packed lists.
        drop_scales = len(sequence_packing_scales[0])
        sequence_packing_scales = sequence_packing_scales[1:]
        scale_pack_info_list = scale_pack_info_list[1:]
        gt_all_bit_indices = gt_all_bit_indices[drop_scales:]
        pred_all_bit_indices = pred_all_bit_indices[drop_scales:]
        other_info_by_scale = other_info_by_scale[drop_scales:]
        var_input_list = var_input_list[drop_scales:]
        visual_rope_cache_list = visual_rope_cache_list[drop_scales:]

    flatten_packing_scales = flatten_two_level_list(sequence_packing_scales)

    def add_noise(features, noise_choices=[0.00, 0.15, 0.30]):
        # Perturb condition features with Gaussian noise scaled by a randomly
        # chosen fraction of the feature std.
        feature_std = features.std()
        rand_noise_strength = np.random.choice(noise_choices)
        return features + rand_noise_strength * feature_std * torch.randn_like(features)
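    # Sketch: add_noise(f) returns f + s * f.std() * eps with s drawn uniformly
    # from noise_choices and eps ~ N(0, I); s == 0.0 leaves f untouched.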

    # Conditions for the next clip: a low-resolution semantic condition over
    # all 20 former latent frames, and a detail condition built from the first
    # frame plus the last two former frames.
    semantic_condition = F.interpolate(former_clip_features, size=(20, *scale_schedule[semantic_scale_ind][-2:]), mode=vae.quantizer.z_interplote_down)
    semantic_condition = add_noise(semantic_condition)
    assert former_clip_features.shape[2] == 20
    detail_condition = torch.cat([first_frame_feature_, add_noise(former_clip_features[:, :, detail_frame_inds])], dim=2)
    var_input_list.extend([semantic_condition, detail_condition])

    visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, detail_condition.shape[-3:], semantic_condition.shape[-3:], list(range(1, 21)), 800, device))
    visual_rope_cache_list.append(get_visual_rope_embeds(rope2d_freqs_grid, detail_condition.shape[-3:], detail_condition.shape[-3:], [0] + [item + 1 for item in detail_frame_inds], 801, device))

    scale_lengths = [pt * ph * pw for pt, ph, pw in flatten_packing_scales]
    scale_lengths = scale_lengths + [torch.tensor(semantic_condition.shape[-3:]).prod().item(), torch.tensor(detail_condition.shape[-3:]).prod().item()]
    scale_lengths = scale_lengths + text_lens

    valid_scales = len(scale_lengths)
    pad_seq_len = train_max_token_len - np.sum(scale_lengths)
    assert pad_seq_len >= 0, f'pad_seq_len: {pad_seq_len} < 0, {scale_lengths=}'
    if pad_seq_len:
        scale_lengths = scale_lengths + [pad_seq_len]
    # Block-wise attention mask over packed scales: every scale attends to
    # itself, to its text condition, to its context scales from
    # scale_pack_info, and (second clip only) to the two condition tokens.
    max_sid_nums = 2000
    querysid_refsid = torch.zeros((max_sid_nums, max_sid_nums), device=args.device, dtype=torch.bool)
    for i in range(valid_scales):
        querysid_refsid[i][i] = True
    base = 0
    for ind, scale_schedule in enumerate(sequence_packing_scales):
        real_example_ind = ind // 2
        scale_pack_info = scale_pack_info_list[ind]
        for local_querysid in range(len(scale_schedule)):
            global_querysid = base + local_querysid
            if other_info_by_scale[base + local_querysid]['mode'] == 'first_iv_clip':
                global_text_sid = len(flatten_packing_scales) + 2 + sum(caption_nums[:real_example_ind]) + 0
                querysid_refsid[global_querysid][global_text_sid] = True
            elif other_info_by_scale[base + local_querysid]['mode'] == 'second_v_clip':
                global_text_sid = len(flatten_packing_scales) + 2 + sum(caption_nums[:real_example_ind]) + 1
                querysid_refsid[global_querysid][global_text_sid] = True
                querysid_refsid[global_querysid][len(flatten_packing_scales) + 0] = True
                querysid_refsid[global_querysid][len(flatten_packing_scales) + 1] = True
            else:
                raise ValueError(f'Unknown mode: {other_info_by_scale[base + local_querysid]["mode"]}')
            for local_refsid in scale_pack_info[local_querysid]['ref_sids']:
                global_refsid = base + local_refsid
                querysid_refsid[global_querysid][global_refsid] = True
        base += len(scale_schedule)

    gt_ms_idx_Bl = []
    for item in gt_all_bit_indices:
        if args.apply_spatial_patchify:
            # (B, t, h, w, d) -> (B, t, d, h, w)
            item = item.permute(0, 1, 4, 2, 3)
            # (B, t, d, h, w) -> (B, t, 4d, h/2, w/2)
            item = torch.nn.functional.pixel_unshuffle(item, 2)
            _, tt, dd, hh, ww = item.shape
            # (B, t, 4d, h/2, w/2) -> (B, t*h/2*w/2, 4d)
            item = item.permute(0, 1, 3, 4, 2).reshape(B, tt * hh * ww, dd)
        else:
            _, tt, hh, ww, dd = item.shape
            item = item.reshape(B, tt * hh * ww, dd)
        gt_ms_idx_Bl.append(item.type(torch.long))
    gt_BLC = gt_ms_idx_Bl
    for i in range(len(var_input_list)):
        if args.apply_spatial_patchify:
            # (B, d, t, h, w) -> (B, t, 4d, h/2, w/2) -> (B, t, h/2, w/2, 4d)
            var_input_list[i] = torch.nn.functional.pixel_unshuffle(var_input_list[i].permute(0, 2, 1, 3, 4), 2).permute(0, 1, 3, 4, 2)
            var_input_list[i] = var_input_list[i].reshape(B, -1, 4 * vae.codebook_dim)
        else:
            # (B, d, t, h, w) -> (B, t, h, w, d)
            var_input_list[i] = var_input_list[i].permute(0, 2, 3, 4, 1)
            var_input_list[i] = var_input_list[i].reshape(B, -1, vae.codebook_dim)
    x_BLC = torch.cat(var_input_list, 1)
    visual_rope_cache = torch.cat(visual_rope_cache_list, dim=4)
    x_BLC_mask = None
    return x_BLC, x_BLC_mask, gt_BLC, pred_all_bit_indices, visual_rope_cache, sequence_packing_scales, scale_lengths, querysid_refsid, other_info_by_scale, pad_seq_len
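
# Hedged usage sketch for video_encode (every `my_*` name is a placeholder,
# not part of this module):
#
#   outputs = video_encode(
#       my_vae, my_inp_B3HW,
#       self_correction=my_self_correction, args=my_args,
#       rope2d_freqs_grid=my_rope_grid, dynamic_resolution_h_w=my_dyn_hw,
#       text_lens=[77], caption_nums=[2],
#       np_generator=np.random.default_rng(0),
#       train_max_token_len=my_max_len, first_frame_features=[None])
#   (x_BLC, x_BLC_mask, gt_BLC, pred_bits, rope_cache, packing_scales,
#    scale_lengths, attn_mask, scale_info, pad_len) = outputs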


def video_decode(
    vae,
    all_indices,
    scale_schedule,
    label_type,
    args=None,
    noise_list=None,
    trunc_scales=-1,
    **kwargs,
):
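    """
    Invert video_encode: turn per-scale indices back into codes, upsample and
    accumulate them across scales and clips, then decode pixels with the VAE.
    """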
    image_scale_repetition = json.loads(args.image_scale_repetition)
    video_scale_repetition = json.loads(args.video_scale_repetition)
    assert len(image_scale_repetition) == len(video_scale_repetition), f'{len(image_scale_repetition)} != {len(video_scale_repetition)}'
    real_si = 0
    noise_ptr = 0
    summed_codes = []
    scales_in_one_clip = args.first_full_spatial_size_scale_index + 1
    clips = len(noise_list) - 1
    for clip_id in range(clips):
        if clip_id == 1:
            # Second clip: keep only the video tail of the schedule and rewrite
            # its temporal extent from the decoded indices.
            scale_schedule = scale_schedule[(args.first_full_spatial_size_scale_index + 1):]
            t = all_indices[-1].shape[1]
            scale_schedule = [(t, ph, pw) for pt, ph, pw in scale_schedule]
        summed_codes.append(noise_list[noise_ptr])
        noise_ptr += 1
        v_d = summed_codes[0].shape[1]
        for si, (pt, ph, pw) in enumerate(scale_schedule):
            if si < len(image_scale_repetition):
                repeat_times = image_scale_repetition[si % len(image_scale_repetition)]
            else:
                repeat_times = video_scale_repetition[si % len(image_scale_repetition)]
            for repeat_idx in range(repeat_times):
                tgt_shape = (pt, scale_schedule[-1][-2], scale_schedule[-1][-1])
                if args.use_two_stage_lfq:
                    if (si % scales_in_one_clip) >= args.semantic_scales:
                        is_semantic_scale = False
                        lfq = vae.quantizer.lfq_detail
                    else:
                        is_semantic_scale = True
                        lfq = vae.quantizer.lfq_semantic
                    codes = lfq.indices_to_codes(all_indices[real_si], label_type)
                    codes = interpolate(codes, size=(v_d, *tgt_shape), mode=vae.quantizer.z_interplote_up, quantizer=vae.quantizer, is_semantic_scale=is_semantic_scale).contiguous()
                else:
                    codes = vae.quantizer.lfq_detail.indices_to_codes(all_indices[real_si], label_type)
                    codes = F.interpolate(codes, size=tgt_shape, mode=vae.quantizer.z_interplote_up).contiguous()

                # Upsample the running sum to the target shape, then add the
                # codes of this scale.
                summed_codes[-1] = F.interpolate(summed_codes[-1], size=tgt_shape, mode=vae.quantizer.z_interplote_up).contiguous()
                summed_codes[-1] += codes
                real_si += 1

            # Full spatial size reached before the last scale: start a fresh
            # accumulator from the next cached noise.
            if si < len(scale_schedule) - 1 and scale_schedule[si][-2:] == tgt_shape[-2:]:
                summed_codes.append(noise_list[noise_ptr])
                noise_ptr += 1

    summed_codes = torch.cat(summed_codes, dim=-3)
    x_recon = vae.decode(summed_codes, slice=True)
    x_recon = torch.clamp(x_recon, min=-1, max=1)
    return x_recon
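
# Note: noise_list is expected to come from video_encode(infer_mode=True), so
# the start-of-clip and spatial-full-reset noises consumed above line up with
# the points where the encode loop appended them.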


def get_visual_rope_embeds(rope2d_freqs_grid, largest_scale, current_scale, t_list, real_sid, device=None):
    # Move the cached frequency tables to the target device (no-op if already there).
    rope2d_freqs_grid['freqs_scales'] = rope2d_freqs_grid['freqs_scales'].to(device)
    rope2d_freqs_grid['freqs_frames'] = rope2d_freqs_grid['freqs_frames'].to(device)
    rope2d_freqs_grid['freqs_height'] = rope2d_freqs_grid['freqs_height'].to(device)
    rope2d_freqs_grid['freqs_width'] = rope2d_freqs_grid['freqs_width'].to(device)
    _, uph, upw = largest_scale
    pt, ph, pw = current_scale
    dim_div_2_div_4 = rope2d_freqs_grid['freqs_scales'].shape[2]
    dim_div_2 = dim_div_2_div_4 * 4
    f_scales = rope2d_freqs_grid['freqs_scales'][:, real_sid].reshape(2, 1, dim_div_2_div_4)
    f_frames = rope2d_freqs_grid['freqs_frames'][:, t_list]
    # Sample spatial frequencies at positions aligned with the largest scale.
    f_height = rope2d_freqs_grid['freqs_height'][:, (torch.arange(ph) * (uph / ph)).round().int()]
    f_width = rope2d_freqs_grid['freqs_width'][:, (torch.arange(pw) * (upw / pw)).round().int()]
    rope_embeds = torch.cat([
        f_scales[:, :, None, None, None, :].expand(-1, -1, pt, ph, pw, -1),
        f_frames[:, None, :, None, None, :].expand(-1, 1, -1, ph, pw, -1),
        f_height[:, None, None, :, None, :].expand(-1, 1, pt, -1, pw, -1),
        f_width[:, None, None, None, :, :].expand(-1, 1, pt, ph, -1, -1),
    ], dim=-1)
    rope_embeds = rope_embeds.reshape(2, 1, 1, 1, pt * ph * pw, dim_div_2)
    return rope_embeds
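

if __name__ == '__main__':
    # Minimal smoke test for get_visual_rope_embeds. The grid layout below
    # (flat (2, N, dim/8) tables per axis) is an assumption for illustration;
    # the real cached grids may differ.
    d4 = 8
    grid = {
        'freqs_scales': torch.randn(2, 16, d4),
        'freqs_frames': torch.randn(2, 32, d4),
        'freqs_height': torch.randn(2, 64, d4),
        'freqs_width': torch.randn(2, 64, d4),
    }
    emb = get_visual_rope_embeds(grid, (21, 16, 16), (5, 4, 4), list(range(5)), 3, device='cpu')
    # One rotary frequency vector per (scale, t, h, w) position, concatenated
    # over the four axes.
    assert emb.shape == (2, 1, 1, 1, 5 * 4 * 4, 4 * d4)
    print('rope embeds:', tuple(emb.shape))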