# BryanW's picture
# Upload folder using huggingface_hub
# 3d1c0e1 verified
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import math
import os
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import DropPath, drop_path
from torch.utils.checkpoint import checkpoint
from infinity.schedules.dynamic_resolution import get_first_full_spatial_size_scale_index
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0, activated_h_div_w_templates=[]):
    """Precompute 2D rotary cos/sin tables for every activated aspect-ratio template.

    The first half of the output channels encodes the row (y) position and the
    second half the column (x) position. For each template, the per-scale
    tables of its ``'1M'`` schedule are flattened to ``ph * pw`` tokens and
    concatenated along the sequence axis, then zero-padded to a multiple of
    ``pad_to_multiplier``.

    Args:
        dim: rotary dimension; ``dim // 2`` channels are produced (when
            ``dim // 2`` is even).
        dynamic_resolution_h_w: nested mapping
            ``{h_div_w: {pn: {'image_scales': [(t, h, w), ...]}}}``; the
            ``'1M'`` entry of each template must exist.
        rope2d_normalized_by_hw: how each scale picks positions from the dense
            grid: ``0`` takes the top-left ``(ph, pw)`` crop, ``1`` bilinearly
            resamples a crop with the aspect ratio of the last scale down to
            ``(ph, pw)``, ``2`` samples at ``round(i * last_size / size)``.
        pad_to_multiplier: the sequence length is padded up to a multiple of it.
        max_height, max_width: extent of the dense position grid.
        base: rotary base.
        device: device on which the tables are built.
        scaling_factor: positions are divided by this before taking angles.
        activated_h_div_w_templates: templates to build (only iterated, never
            mutated).

    Returns:
        dict mapping ``str(tuple((1, h, w), ...))`` of every schedule of the
        template to a tensor of shape ``(2, 1, 1, 1, seq_len, channels)``
        (index 0 is cos, index 1 is sin). All schedules of one template share
        the tensor built from the ``'1M'`` schedule.

    Raises:
        AssertionError: a template is missing from ``dynamic_resolution_h_w``.
        ValueError: ``rope2d_normalized_by_hw`` is not 0, 1 or 2.
    """
    # Split the dimension into half, one for y and one for x.
    half_dim = dim // 2
    # theta_i = 1 / base^(i / half_dim), i = 0, 2, ..., < half_dim
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, n_freqs), namely y * theta
    freqs_width = torch.outer(t_width, inv_freq)  # (max_width, n_freqs), namely x * theta
    freqs_grid_map = torch.concat([
        freqs_height[:, None, :].expand(-1, max_width, -1),
        freqs_width[None, :, :].expand(max_height, -1, -1),
    ], dim=-1)  # (max_height, max_width, 2 * n_freqs)
    # (2, max_height, max_width, 2 * n_freqs): cos table stacked on sin table
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    grid_device = freqs_grid_map.device

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['image_scales']
        _, last_ph, last_pw = scale_schedule[-1]

        # Mode 1 resamples from a crop whose longer edge spans the whole grid
        # and whose aspect ratio matches the last scale.
        max_edge_length = freqs_grid_map.shape[1]
        if last_ph >= last_pw:
            crop_h, crop_w = max_edge_length, int(max_edge_length / last_ph * last_pw)
        else:
            crop_h, crop_w = int(max_edge_length / last_pw * last_ph), max_edge_length

        rope_cache_list = []
        for (_, ph, pw) in scale_schedule:
            if rope2d_normalized_by_hw == 1:  # downsample
                rope_cache = F.interpolate(
                    freqs_grid_map[:, :crop_h, :crop_w, :].permute([0, 3, 1, 2]),
                    size=(ph, pw), mode='bilinear', align_corners=True,
                )
                rope_cache = rope_cache.permute([0, 2, 3, 1])  # (2, ph, pw, channels)
            elif rope2d_normalized_by_hw == 2:  # star style
                # Stretch this scale's coordinates onto the last scale's grid.
                rows = torch.arange(ph, device=grid_device) * (last_ph / ph)
                cols = torch.arange(pw, device=grid_device) * (last_pw / pw)
                indices = torch.stack([
                    rows.reshape(ph, 1).expand(ph, pw),
                    cols.reshape(1, pw).expand(ph, pw),
                ], dim=-1).round().long()  # (ph, pw, 2)
                indices = indices.reshape(-1, 2)  # (ph*pw, 2)
                rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], :]  # (2, ph*pw, channels)
                rope_cache = rope_cache.reshape(2, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :ph, :pw, :]  # (2, ph, pw, channels)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list.append(rope_cache.reshape(2, ph * pw, -1))

        cat_rope_cache = torch.cat(rope_cache_list, 1)  # (2, seq_len, channels)
        remainder = cat_rope_cache.shape[1] % pad_to_multiplier
        if remainder:
            # new_zeros keeps the padding on the same device/dtype (and channel
            # width) as the cache, so the concatenation cannot mismatch.
            pad = cat_rope_cache.new_zeros(2, pad_to_multiplier - remainder, cat_rope_cache.shape[-1])
            cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
        cat_rope_cache = cat_rope_cache[:, None, None, None]  # (2, 1, 1, 1, seq_len, channels)

        # Every pixel-number schedule of this template is registered with the
        # same tensor (the one built from the '1M' schedule above).
        for pn in dynamic_resolution_h_w[h_div_w]:
            pn_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
            key = str(tuple((1, h, w) for _, h, w in pn_schedule))
            rope2d_freqs_grid[key] = cat_rope_cache
    return rope2d_freqs_grid
def _gather_star_style_rope3d(freqs_grid_map, t_inds, pt, ph, pw, uph, upw):
    """Gather a (2, pt, ph, pw, channels) table from the dense grid, stretching h/w onto (uph, upw)."""
    grid_device = freqs_grid_map.device
    rows = torch.arange(ph, device=grid_device) * (uph / ph)
    cols = torch.arange(pw, device=grid_device) * (upw / pw)
    indices = torch.stack([
        t_inds,
        rows.reshape(1, ph, 1).expand(pt, ph, pw),
        cols.reshape(1, 1, pw).expand(pt, ph, pw),
    ], dim=-1).round().long()  # (pt, ph, pw, 3)
    indices = indices.reshape(-1, 3)  # (pt*ph*pw, 3)
    rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], indices[:, 2], :]  # (2, pt*ph*pw, channels)
    return rope_cache.reshape(2, pt, ph, pw, -1)


def precompute_rope3d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_frames=128, max_height=2048 // 8, max_width=2048 // 8, base=10000.0, device=None, activated_h_div_w_templates=[], steps_per_frame=4, pn=None, args=None):
    """Precompute 3D (frame, row, column) rotary cos/sin tables for image and video schedules.

    The ``dim // 2`` rotary channels are split into three groups, in the order
    frame, height, width. When ``dim // 2`` is not divisible by 3 the groups
    are trimmed so that their sizes still sum to ``dim // 2``.

    Args:
        dim: rotary dimension, must be even.
        dynamic_resolution_h_w: nested mapping
            ``{h_div_w: {pn: {'image_scales': [...], 'video_scales': [...]}}}``
            whose schedules are sequences of ``(pt, ph, pw)``.
        rope2d_normalized_by_hw: ``0`` takes the leading ``(pt, ph, pw)`` crop
            of the dense grid; ``2`` samples rows/columns at
            ``round(i * last_size / size)`` of the schedule's last scale.
        pad_to_multiplier: each sequence is zero-padded to a multiple of it.
        max_frames, max_height, max_width: extent of the dense position grid.
        base: rotary base.
        device: device on which the tables are built.
        activated_h_div_w_templates: templates to build (only iterated, never
            mutated).
        steps_per_frame: scales per frame index, used only when
            ``args.dynamic_scale_schedule == 'infinity_video_tower'``.
        pn: key selecting the schedule inside each template.
        args: optional object; only its ``dynamic_scale_schedule`` attribute is
            read. ``None`` (or a missing attribute) selects the per-frame
            ``arange(pt)`` time indices.

    Returns:
        dict mapping ``str(tuple(image_scale_schedule))`` and
        ``str(tuple(video_scale_schedule))`` to tensors of shape
        ``(2, 1, 1, 1, seq_len, dim // 2)`` (index 0 is cos, index 1 is sin).

    Raises:
        AssertionError: ``dim`` is odd, a template is unknown, or an image
            scale has ``pt != 1``.
        ValueError: ``rope2d_normalized_by_hw`` is neither 0 nor 2.
    """
    # Split the dimension into three parts: one for t, one for y, one for x.
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    dim_div_2 = dim // 2
    num_of_freqs = int(np.ceil(dim_div_2 / 3))
    # theta_i = 1 / base^(i / num_of_freqs), i = 0, 1, ..., num_of_freqs - 1
    inv_freq = 1.0 / (base ** (torch.arange(num_of_freqs, dtype=torch.int64).float().to(device) / num_of_freqs))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    t_frames = torch.arange(max_frames, device=device, dtype=torch.int64).type_as(inv_freq)
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, num_of_freqs), namely y * theta
    freqs_width = torch.outer(t_width, inv_freq)  # (max_width, num_of_freqs), namely x * theta
    freqs_frames = torch.outer(t_frames, inv_freq)  # (max_frames, num_of_freqs), namely t * theta

    # ceil() may over-allocate by 1 or 2 channels; trim so the groups sum to dim_div_2.
    surplus = num_of_freqs * 3 - dim_div_2
    if surplus == 0:
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs, num_of_freqs
    elif surplus == 2:  # drop one channel from each spatial axis
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs - 1, num_of_freqs - 1
    else:  # drop one channel from the temporal axis
        offset_t, offset_h, offset_w = num_of_freqs - 1, num_of_freqs, num_of_freqs

    freqs_grid_map = torch.concat([
        freqs_frames[:, None, None, :offset_t].expand(-1, max_height, max_width, -1),
        freqs_height[None, :, None, :offset_h].expand(max_frames, -1, max_width, -1),
        freqs_width[None, None, :, :offset_w].expand(max_frames, max_height, -1, -1),
    ], dim=-1)  # (max_frames, max_height, max_width, dim / 2)
    # (2, max_frames, max_height, max_width, dim / 2): cos table stacked on sin table
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    grid_device = freqs_grid_map.device

    # `args` defaults to None, so read the attribute defensively.
    use_tower_time_index = getattr(args, 'dynamic_scale_schedule', None) == 'infinity_video_tower'

    def _pad_and_expand(cache):
        # Zero-pad the sequence axis on the cache's own device/dtype, then add
        # three singleton axes: (2, 1, 1, 1, seq_len, dim / 2).
        remainder = cache.shape[1] % pad_to_multiplier
        if remainder:
            pad = cache.new_zeros(2, pad_to_multiplier - remainder, dim // 2)
            cache = torch.cat([cache, pad], dim=1)
        return cache[:, None, None, None]

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        image_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
        video_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['video_scales']
        first_full_spatial_size_scale_index = get_first_full_spatial_size_scale_index(video_scale_schedule)
        rope_cache_list4image, rope_cache_list4video = [], []

        # image: every scale lives at frame index 0
        for (pt, ph, pw) in image_scale_schedule:
            assert pt == 1
            if rope2d_normalized_by_hw == 2:  # star style
                _, uph, upw = image_scale_schedule[-1]
                t_inds = torch.zeros(pt, ph, pw, device=grid_device)
                rope_cache = _gather_star_style_rope3d(freqs_grid_map, t_inds, pt, ph, pw, uph, upw)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :]  # (2, pt, ph, pw, dim / 2)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4image.append(rope_cache.reshape(2, ph * pw, -1))  # (2, 1*ph*pw, dim / 2)

        # video
        for si, (pt, ph, pw) in enumerate(video_scale_schedule):
            if rope2d_normalized_by_hw == 2:  # star style
                _, uph, upw = video_scale_schedule[-1]
                if use_tower_time_index:
                    # One shared frame index per scale, advancing every
                    # `steps_per_frame` scales past the first full-size scale.
                    t_ind = int(np.ceil((si - first_full_spatial_size_scale_index) / steps_per_frame))
                    t_ind = max(t_ind, 0)
                    t_inds = torch.full((pt, ph, pw), float(t_ind), device=grid_device)
                    print(f't_ind: {t_ind}, si: {si}, (pt, ph, pw): {(pt, ph, pw)}')
                else:
                    t_inds = torch.arange(pt, dtype=torch.float32, device=grid_device).reshape(pt, 1, 1).expand(pt, ph, pw)
                rope_cache = _gather_star_style_rope3d(freqs_grid_map, t_inds, pt, ph, pw, uph, upw)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :]  # (2, pt, ph, pw, dim / 2)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4video.append(rope_cache.reshape(2, pt * ph * pw, -1))  # (2, pt*ph*pw, dim / 2)

        cat_rope_cache4image = _pad_and_expand(torch.cat(rope_cache_list4image, 1))
        cat_rope_cache4video = _pad_and_expand(torch.cat(rope_cache_list4video, 1))
        rope2d_freqs_grid[str(tuple(image_scale_schedule))] = cat_rope_cache4image
        rope2d_freqs_grid[str(tuple(video_scale_schedule))] = cat_rope_cache4video
    return rope2d_freqs_grid
def precompute_rope4d_freqs_grid(
    dim,
    rope2d_normalized_by_hw,
    pad_to_multiplier=1,
    max_scales=128,
    max_frames=128,
    max_height=2048 // 8,
    max_width=2048 // 8,
    base=10000.0,
    device=None,
    activated_h_div_w_templates=[],
    steps_per_frame=4,
    text_maxlen=0,
    pn=None,
    args=None,
    **kwargs,
):
    """Precompute per-axis rotary cos/sin tables for a 4D (scale, frame, row, column) RoPE.

    The ``dim // 2`` rotary channels are split into four equal groups, in the
    order scale, frame, height, width. Unlike a dense grid, one small table is
    returned per axis. The first ``text_maxlen`` positions of the scale axis
    are used for the text embedding (with frame/row/column position 0).

    ``rope2d_normalized_by_hw``, ``pad_to_multiplier``,
    ``activated_h_div_w_templates``, ``steps_per_frame``, ``pn``, ``args`` and
    ``**kwargs`` are accepted but not used by this function.

    Args:
        dim: rotary dimension; must be a multiple of 8.
        max_scales, max_frames, max_height, max_width: table length per axis.
        base: rotary base.
        device: device on which the tables are built.
        text_maxlen: number of leading scale-axis positions used for text.

    Returns:
        dict with keys (index 0 of each tensor is cos, index 1 is sin):
        ``'freqs_text'``   (2, 1, 1, 1, text_maxlen, dim / 2),
        ``'freqs_scales'`` (2, max_scales, dim / 8), positions after the text,
        ``'freqs_frames'`` (2, max_frames, dim / 8),
        ``'freqs_height'`` (2, max_height, dim / 8),
        ``'freqs_width'``  (2, max_width, dim / 8).

    Raises:
        AssertionError: ``dim`` is not a multiple of 8.
    """
    # Split the dimension into four parts: scale, t, y and x.
    print('[precompute_rope4d_freqs_grid: 4d]: start')
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    dim_div_2 = dim // 2
    num_of_freqs = int(np.ceil(dim_div_2 / 4))
    # Fail fast, before any table is allocated.
    assert num_of_freqs * 4 == dim_div_2, f'Only support dim % 8 == 0, but got dim={dim}'
    # theta_i = 1 / base^(i / num_of_freqs), i = 0, 1, ..., num_of_freqs - 1
    inv_freq = 1.0 / (base ** (torch.arange(num_of_freqs, dtype=torch.int64).float().to(device) / num_of_freqs))

    def _cos_sin_table(length):
        # (2, length, num_of_freqs): cos and sin of position * theta
        positions = torch.arange(length, device=device, dtype=torch.int64).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=0)

    freqs_scales = _cos_sin_table(text_maxlen + max_scales)
    freqs_frames = _cos_sin_table(max_frames)
    freqs_height = _cos_sin_table(max_height)
    freqs_width = _cos_sin_table(max_width)

    tm = text_maxlen
    # Text token j sits at scale position j and at frame/row/column position 0.
    rope_text_embeds = torch.cat([
        freqs_scales[:, :tm, None, None, None, :],
        freqs_frames[:, None, :1, None, None, :].expand(-1, tm, -1, -1, -1, -1),
        freqs_height[:, None, None, :1, None, :].expand(-1, tm, -1, -1, -1, -1),
        freqs_width[:, None, None, None, :1, :].expand(-1, tm, -1, -1, -1, -1),
    ], dim=-1)  # (2, tm, 1, 1, 1, dim_div_2)
    rope_text_embeds = rope_text_embeds.reshape(2, 1, 1, 1, tm, dim_div_2)

    rope2d_freqs_grid = {}
    rope2d_freqs_grid['freqs_text'] = rope_text_embeds  # (2, 1, 1, 1, text_maxlen, dim / 2)
    rope2d_freqs_grid['freqs_scales'] = freqs_scales[:, tm:]  # (2, max_scales, num_of_freqs)
    rope2d_freqs_grid['freqs_frames'] = freqs_frames  # (2, max_frames, num_of_freqs)
    rope2d_freqs_grid['freqs_height'] = freqs_height  # (2, max_height, num_of_freqs)
    rope2d_freqs_grid['freqs_width'] = freqs_width  # (2, max_width, num_of_freqs)
    return rope2d_freqs_grid
def apply_rotary_emb(q, k, rope_cache):
    """Rotate query and key tensors by the precomputed rotary cos/sin table.

    The last dimension of ``q`` / ``k`` is read as interleaved pairs
    ``(x0, x1)``; each pair is rotated to
    ``(x0 * cos - x1 * sin, x0 * sin + x1 * cos)``.

    Args:
        q, k: tensors with an even last dimension.
        rope_cache: cos/sin table; index 0 of its first axis is cos and index 1
            is sin. Its second axis is indexed at 0 before use, and the result
            must broadcast against ``q`` / ``k`` reshaped to ``(..., -1, 2)``
            without the trailing pair axis.

    Returns:
        ``(q, k)`` after rotation, with the same shapes as the inputs.

    Note:
        The rotation is done with in-place ops on a ``reshape`` of each input.
        When that reshape is a view, the tensors passed in are overwritten;
        pass clones if the originals must be kept.
    """
    device_type = q.device.type
    if not isinstance(device_type, str) or device_type == "mps":
        # autocast is given "cpu" for mps (and any non-string device type)
        device_type = "cpu"
    rope_cache = rope_cache[:, 0]
    cos, sin = rope_cache[0], rope_cache[1]
    rotated = []
    # Autocast is disabled so the rotation is not run under autocast's dtype policy.
    with torch.autocast(device_type=device_type, enabled=False):
        for x in (q, k):
            pairs = x.reshape(*x.shape[:-1], -1, 2)  # (..., half_head_dim, 2)
            x0, x1 = pairs[..., 0], pairs[..., 1]
            # Both cross terms must be taken before either half is overwritten.
            x1_sin = x1 * sin
            x0_sin = x0 * sin
            x0.mul_(cos).sub_(x1_sin)
            x1.mul_(cos).add_(x0_sin)
            rotated.append(pairs.reshape(*pairs.shape[:-2], -1))
    q, k = rotated
    return q, k