|
|
|
|
|
|
|
|
|
|
|
import math |
|
|
import os |
|
|
from functools import partial |
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from timm.models.layers import DropPath, drop_path |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
from infinity.schedules.dynamic_resolution import get_first_full_spatial_size_scale_index |
|
|
|
|
|
|
|
|
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0, activated_h_div_w_templates=None):
    """Precompute padded 2D RoPE cos/sin caches for the activated aspect ratios.

    A (height, width) grid of rotary angles is built once. Height and width
    each contribute ``len(range(0, dim // 2, 2))`` frequencies, i.e.
    ``dim // 2`` features in total when ``dim`` is a multiple of 4. For every
    aspect ratio in ``activated_h_div_w_templates``, the per-scale caches of
    its ``'1M'`` image schedule are concatenated along the token axis. The
    result is zero-padded and shared by every ``pn`` entry of that ratio.

    Args:
        dim: Feature dimension the rotary embedding is built for.
        dynamic_resolution_h_w: Mapping ``h_div_w -> pn -> {'image_scales': [(t, h, w), ...]}``.
            A ``'1M'`` entry is required for each activated ratio.
        rope2d_normalized_by_hw: How each scale ``(h, w)`` reads the grid:
            0 takes the top-left ``h x w`` corner;
            1 bilinearly resizes a top-left sub-grid (matching the last
            scale's aspect ratio) to ``h x w``;
            2 samples the grid at integer positions rescaled to the last
            scale's ``(h, w)``.
        pad_to_multiplier: Token count of each cache is zero-padded up to a multiple of this.
        max_height, max_width: Grid size in positions.
        base: RoPE frequency base.
        device: Device the grid is built on.
        scaling_factor: Positions are divided by this before computing angles.
        activated_h_div_w_templates: Aspect ratios to build caches for (default: none).

    Returns:
        dict mapping ``str(tuple(schedule))``, with each scale's ``t`` replaced
        by 1, to a tensor of shape ``(2, 1, 1, 1, padded_len, features)``.
        Index 0 of the first axis holds cosines and index 1 sines.

    Raises:
        AssertionError: if an activated ratio is missing from ``dynamic_resolution_h_w``.
        KeyError: if an activated ratio has no ``'1M'`` entry.
        ValueError: for an unknown ``rope2d_normalized_by_hw``.
    """
    if activated_h_div_w_templates is None:
        activated_h_div_w_templates = []
    half_dim = dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    freqs_height = torch.outer(t_height, inv_freq)
    freqs_width = torch.outer(t_width, inv_freq)
    # (max_height, max_width, features): height angles first, then width angles.
    freqs_grid_map = torch.concat([
        freqs_height[:, None, :].expand(-1, max_width, -1),
        freqs_width[None, :, :].expand(max_height, -1, -1),
    ], dim=-1)
    # Leading axis: 0 -> cos, 1 -> sin.
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['image_scales']
        _, last_h, last_w = scale_schedule[-1]

        # Sub-grid used by mode 1: the largest top-left region with the last scale's aspect ratio.
        # NOTE(review): both edges are derived from the grid height, so this assumes
        # max_height == max_width (true for the defaults). Confirm before using a non-square grid.
        max_edge_length = freqs_grid_map.shape[1]
        if last_h >= last_w:
            uph, upw = max_edge_length, int(max_edge_length / last_h * last_w)
        else:
            uph, upw = int(max_edge_length / last_w * last_h), max_edge_length

        rope_cache_list = []
        for (_, ph, pw) in scale_schedule:
            if rope2d_normalized_by_hw == 1:
                rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0, 3, 1, 2]), size=(ph, pw), mode='bilinear', align_corners=True)
                rope_cache = rope_cache.permute([0, 2, 3, 1])
            elif rope2d_normalized_by_hw == 2:
                # Map each (row, col) of this scale onto the last scale's coordinate range.
                indices = torch.stack([
                    (torch.arange(ph) * (last_h / ph)).reshape(ph, 1).expand(ph, pw),
                    (torch.arange(pw) * (last_w / pw)).reshape(1, pw).expand(ph, pw),
                ], dim=-1).round().int().reshape(-1, 2)
                rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], :].reshape(2, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :ph, :pw, :]
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list.append(rope_cache.reshape(2, ph * pw, -1))

        cat_rope_cache = torch.cat(rope_cache_list, 1)
        remainder = cat_rope_cache.shape[1] % pad_to_multiplier
        if remainder:
            # new_zeros gives the padding the cache's device, dtype and feature width,
            # so torch.cat also works when the grid lives on a non-CPU device.
            pad = cat_rope_cache.new_zeros(2, pad_to_multiplier - remainder, cat_rope_cache.shape[-1])
            cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
        cat_rope_cache = cat_rope_cache[:, None, None, None]

        # Every pn schedule of this aspect ratio shares the '1M' cache.
        for pn in dynamic_resolution_h_w[h_div_w]:
            pn_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
            key_schedule = [(1, h, w) for _, h, w in pn_schedule]
            rope2d_freqs_grid[str(tuple(key_schedule))] = cat_rope_cache
    return rope2d_freqs_grid
|
|
|
|
|
|
|
|
def precompute_rope3d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_frames=128, max_height=2048 // 8, max_width=2048 // 8, base=10000.0, device=None, activated_h_div_w_templates=None, steps_per_frame=4, pn=None, args=None):
    """Precompute padded 3D (frame, height, width) RoPE cos/sin caches.

    The ``dim // 2`` rotary features are split across the frame, height and
    width axes. For every aspect ratio in ``activated_h_div_w_templates``, two
    caches are built from ``dynamic_resolution_h_w[h_div_w][pn]``: one for its
    ``'image_scales'`` and one for its ``'video_scales'``. Each concatenates
    the per-scale caches along the token axis and is then zero-padded.

    Args:
        dim: Even feature dimension the rotary embedding is built for.
        dynamic_resolution_h_w: Mapping ``h_div_w -> pn -> {'image_scales': [...], 'video_scales': [...]}``,
            where each schedule is a list of ``(t, h, w)`` tuples. Image scales must have ``t == 1``.
        rope2d_normalized_by_hw: 0 takes the top-left ``t x h x w`` corner of the grid.
            2 samples the grid at ``h``/``w`` positions rescaled to the schedule's last scale.
            In mode 2, image scales use temporal index 0. Video scales use either the
            per-frame index ``0..t-1`` or, when ``args.dynamic_scale_schedule ==
            'infinity_video_tower'``, a single index
            ``max(ceil((si - first_full) / steps_per_frame), 0)``. Here ``first_full`` is
            ``get_first_full_spatial_size_scale_index(video_scales)``.
        pad_to_multiplier: Token count of each cache is zero-padded up to a multiple of this.
        max_frames, max_height, max_width: Grid size in positions.
        base: RoPE frequency base.
        device: Device the grid is built on.
        activated_h_div_w_templates: Aspect ratios to build caches for (default: none).
        steps_per_frame: Divisor used by the 'infinity_video_tower' temporal index.
        pn: Key selecting the schedule entry inside ``dynamic_resolution_h_w[h_div_w]``.
        args: Optional config object. Only ``args.dynamic_scale_schedule`` is read. If ``args``
            is None or lacks that attribute, the per-frame temporal index is used.

    Returns:
        dict mapping ``str(tuple(schedule))`` (image and video schedules) to tensors of shape
        ``(2, 1, 1, 1, padded_len, dim // 2)``. Index 0 of the first axis holds cosines and
        index 1 sines.

    Raises:
        AssertionError: if ``dim`` is odd, an activated ratio is unknown, or an image scale has ``t != 1``.
        ValueError: for an unknown ``rope2d_normalized_by_hw``.
    """
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    if activated_h_div_w_templates is None:
        activated_h_div_w_templates = []
    dim_div_2 = dim // 2
    num_of_freqs = int(np.ceil(dim_div_2 / 3))
    inv_freq = 1.0 / (base ** (torch.arange(num_of_freqs, dtype=torch.int64).float().to(device) / num_of_freqs))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    t_frames = torch.arange(max_frames, device=device, dtype=torch.int64).type_as(inv_freq)
    freqs_height = torch.outer(t_height, inv_freq)
    freqs_width = torch.outer(t_width, inv_freq)
    freqs_frames = torch.outer(t_frames, inv_freq)

    # 3 * num_of_freqs overshoots dim // 2 by 0, 1 or 2. The surplus is trimmed from the
    # frame axis (overshoot 1) or from height and width (overshoot 2), so the axes sum to dim // 2.
    surplus = num_of_freqs * 3 - dim_div_2
    if surplus == 0:
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs, num_of_freqs
    elif surplus == 2:
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs - 1, num_of_freqs - 1
    else:
        offset_t, offset_h, offset_w = num_of_freqs - 1, num_of_freqs, num_of_freqs

    # (max_frames, max_height, max_width, dim // 2): frame, height, then width angles.
    freqs_grid_map = torch.concat([
        freqs_frames[:, None, None, :offset_t].expand(-1, max_height, max_width, -1),
        freqs_height[None, :, None, :offset_h].expand(max_frames, -1, max_width, -1),
        freqs_width[None, None, :, :offset_w].expand(max_frames, max_height, -1, -1),
    ], dim=-1)
    # Leading axis: 0 -> cos, 1 -> sin.
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)

    # getattr keeps the default args=None usable; the original code raised AttributeError here.
    use_video_tower = getattr(args, 'dynamic_scale_schedule', None) == 'infinity_video_tower'

    def _pad_and_expand(cache):
        # Zero-pad the token axis up to a multiple of pad_to_multiplier. new_zeros gives the
        # padding the cache's device/dtype/width, so torch.cat works off-CPU as well.
        remainder = cache.shape[1] % pad_to_multiplier
        if remainder:
            pad = cache.new_zeros(2, pad_to_multiplier - remainder, cache.shape[-1])
            cache = torch.cat([cache, pad], dim=1)
        return cache[:, None, None, None]

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        image_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
        video_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['video_scales']
        first_full_spatial_size_scale_index = get_first_full_spatial_size_scale_index(video_scale_schedule)
        rope_cache_list4image, rope_cache_list4video = [], []

        # Image scales: a single frame with temporal index 0.
        for pt, ph, pw in image_scale_schedule:
            assert pt == 1
            if rope2d_normalized_by_hw == 2:
                _, uph, upw = image_scale_schedule[-1]
                t_inds = torch.zeros(pt, ph, pw)
                indices = torch.stack([
                    t_inds,
                    (torch.arange(ph) * (uph / ph)).reshape(1, ph, 1).expand(pt, ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, 1, pw).expand(pt, ph, pw),
                ], dim=-1).round().int().reshape(-1, 3)
                rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], indices[:, 2], :].reshape(2, pt, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :]
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4image.append(rope_cache.reshape(2, ph * pw, -1))

        # Video scales.
        for si, (pt, ph, pw) in enumerate(video_scale_schedule):
            if rope2d_normalized_by_hw == 2:
                _, uph, upw = video_scale_schedule[-1]
                if use_video_tower:
                    t_ind = max(int(np.ceil((si - first_full_spatial_size_scale_index) / steps_per_frame)), 0)
                    t_inds = t_ind * torch.ones(pt, ph, pw)
                    print(f't_ind: {t_ind}, si: {si}, (pt, ph, pw): {(pt, ph, pw)}')
                else:
                    t_inds = (torch.arange(pt)).reshape(pt, 1, 1).expand(pt, ph, pw)
                indices = torch.stack([
                    t_inds,
                    (torch.arange(ph) * (uph / ph)).reshape(1, ph, 1).expand(pt, ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, 1, pw).expand(pt, ph, pw),
                ], dim=-1).round().int().reshape(-1, 3)
                rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], indices[:, 2], :].reshape(2, pt, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :]
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4video.append(rope_cache.reshape(2, pt * ph * pw, -1))

        cat_rope_cache4image = _pad_and_expand(torch.cat(rope_cache_list4image, 1))
        cat_rope_cache4video = _pad_and_expand(torch.cat(rope_cache_list4video, 1))
        # Video key is written second, so it wins if both schedules stringify identically.
        rope2d_freqs_grid[str(tuple(image_scale_schedule))] = cat_rope_cache4image
        rope2d_freqs_grid[str(tuple(video_scale_schedule))] = cat_rope_cache4video
    return rope2d_freqs_grid
|
|
|
|
|
|
|
|
def precompute_rope4d_freqs_grid(
    dim,
    rope2d_normalized_by_hw,
    pad_to_multiplier=1,
    max_scales=128,
    max_frames=128,
    max_height=2048 // 8,
    max_width=2048 // 8,
    base=10000.0,
    device=None,
    activated_h_div_w_templates=[],
    steps_per_frame=4,
    text_maxlen=0,
    pn=None,
    args=None,
    **kwargs,
):
    """Build per-axis 4D RoPE cos/sin tables (scale, frame, height, width) plus text embeddings.

    The ``dim // 2`` rotary features are split evenly over four axes, so
    ``dim // 2`` must be a multiple of 4. Each axis gets a table of shape
    ``(2, length, dim // 8)``; index 0 of the first axis holds cosines and
    index 1 sines.

    The first ``text_maxlen`` positions of the scale axis are reserved for
    text. Text token ``i`` uses scale position ``i`` and position 0 on the
    frame, height and width axes. The returned scale table starts right after
    those reserved positions.

    Only ``dim``, ``max_scales``, ``max_frames``, ``max_height``,
    ``max_width``, ``base``, ``device`` and ``text_maxlen`` affect the result.
    The remaining parameters are accepted but not read here.

    Returns:
        dict with
        ``'freqs_text'``   of shape ``(2, 1, 1, 1, text_maxlen, dim // 2)``,
        ``'freqs_scales'`` of shape ``(2, max_scales, dim // 8)``,
        ``'freqs_frames'`` of shape ``(2, max_frames, dim // 8)``,
        ``'freqs_height'`` of shape ``(2, max_height, dim // 8)``,
        ``'freqs_width'``  of shape ``(2, max_width, dim // 8)``.

    Raises:
        AssertionError: if ``dim`` is odd or ``dim // 2`` is not a multiple of 4.
    """
    print(f'[precompute_rope4d_freqs_grid: 4d]: start')
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    half = dim // 2
    n_freq = int(np.ceil(half / 4))
    # The four axes must split the rotary features exactly.
    assert n_freq * 4 == half

    exponents = torch.arange(n_freq, dtype=torch.int64).float().to(device) / n_freq
    inv_freq = 1.0 / (base ** exponents)

    def cos_sin_table(length):
        # (2, length, n_freq): cos / sin of (position * inv_freq) for positions 0..length-1.
        positions = torch.arange(length, device=device, dtype=torch.int64).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=0)

    n_text = text_maxlen
    scale_table = cos_sin_table(n_text + max_scales)
    frame_table = cos_sin_table(max_frames)
    height_table = cos_sin_table(max_height)
    width_table = cos_sin_table(max_width)

    # Text token i: scale position i, position 0 on every other axis.
    text_parts = [scale_table[:, :n_text]]
    text_parts.extend(
        table[:, :1].expand(-1, n_text, -1)
        for table in (frame_table, height_table, width_table)
    )
    text_embeds = torch.cat(text_parts, dim=-1).reshape(2, 1, 1, 1, n_text, half)

    return {
        'freqs_text': text_embeds,
        'freqs_scales': scale_table[:, n_text:],
        'freqs_frames': frame_table,
        'freqs_height': height_table,
        'freqs_width': width_table,
    }
|
|
|
|
|
def apply_rotary_emb(q, k, rope_cache):
    """Rotate consecutive feature pairs of ``q`` and ``k`` by the angles in ``rope_cache``.

    Each pair ``(x0, x1)`` along the last axis becomes
    ``(x0 * c - x1 * s, x1 * c + x0 * s)``. Here ``c = rope_cache[:, 0][0]``
    and ``s = rope_cache[:, 0][1]``: index 0 plays the cosine role and index 1
    the sine role, matching the ``torch.stack([cos, sin])`` layout produced
    by the precompute functions in this module.

    NOTE(review): the rotation uses in-place ops on the result of
    ``reshape``. When ``reshape`` returns a view (e.g. contiguous inputs),
    the caller's ``q`` / ``k`` tensors are overwritten as well as returned.

    Args:
        q, k: Tensors whose last dimension is even (it is regrouped into pairs).
        rope_cache: Cos/sin cache. Presumably shaped like the precompute outputs,
            ``(2, 1, 1, 1, seq_len, last_dim // 2)``, so that ``rope_cache[:, 0][i]``
            broadcasts against ``q[..., 0::2]``. TODO confirm against the caller.

    Returns:
        The rotated ``(q, k)``, reshaped back to their input shapes.
    """
    device_type = q.device.type
    # torch.autocast is handed "cpu" instead of "mps" (or a non-str type).
    # Presumably autocast on "mps" was unsupported in the targeted torch version; confirm.
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    qk = [q, k]
    # Drop axis 1 of the cache, leaving the cos/sin selector as the first axis.
    rope_cache = rope_cache[:,0]
    # Autocast is disabled so the elementwise rotation is not rerouted to an autocast dtype.
    with torch.autocast(device_type=device_type, enabled=False):
        for i in range(2):
            # Group the last axis into (x0, x1) pairs: (..., last_dim // 2, 2).
            qk[i] = qk[i].reshape(*qk[i].shape[:-1], -1, 2)
            # Both products must be taken BEFORE the in-place updates below overwrite x0 / x1.
            tmp1 = qk[i][..., 1] * rope_cache[1]
            tmp2 = qk[i][..., 0] * rope_cache[1]
            # x0 <- x0 * c - x1 * s
            qk[i][..., 0].mul_(rope_cache[0]).sub_(tmp1)
            # x1 <- x1 * c + x0 * s  (x0 taken from tmp2, i.e. its pre-update value)
            qk[i][..., 1].mul_(rope_cache[0]).add_(tmp2)
            # Flatten the pairs back into the original last dimension.
            qk[i] = qk[i].reshape(*qk[i].shape[:-2], -1)
        q, k = qk
    return q, k