File size: 15,865 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT

import math
import os
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import DropPath, drop_path
from torch.utils.checkpoint import checkpoint
from infinity.schedules.dynamic_resolution import get_first_full_spatial_size_scale_index


def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0, activated_h_div_w_templates=None):
    """Precompute 2D rotary-embedding (cos, sin) caches, one per scale schedule.

    The rotary dimension is split in half: the first half encodes the row (y)
    position, the second half the column (x) position.

    Args:
        dim: per-head rotary dimension; each cache entry stores dim // 2 angles.
        dynamic_resolution_h_w: nested dict, h_div_w -> pn -> {'image_scales': [(t, h, w), ...]}.
        rope2d_normalized_by_hw: 0 = crop the raw grid to (ph, pw);
            1 = bilinear-downsample the full-resolution grid to (ph, pw);
            2 = "star style" nearest-index sampling against the final scale.
        pad_to_multiplier: right-pad the sequence dimension with zeros so its
            length is a multiple of this value.
        max_height / max_width: extent (in tokens) of the precomputed base grid.
        base: RoPE frequency base (10000 as in the original RoPE formulation).
        device: device on which the base grid is computed.
        scaling_factor: positions are divided by this before computing angles.
        activated_h_div_w_templates: aspect-ratio keys to build caches for
            (defaults to an empty list, i.e. an empty result dict).

    Returns:
        dict mapping str(tuple(scale_schedule)) (with t forced to 1) to a tensor
        of shape (2, 1, 1, 1, seq_len, dim // 2) holding stacked (cos, sin).
    """
    if activated_h_div_w_templates is None:  # avoid mutable default argument
        activated_h_div_w_templates = []
    # split the dimension into half, one for x and one for y
    half_dim = dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) # namely theta, 1 / (10000^(i/half_dim)), i=0,2,..., half_dim-2
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    t_height = t_height / scaling_factor
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, half_dim / 2), namely y*theta
    t_width = t_width / scaling_factor
    freqs_width = torch.outer(t_width, inv_freq)  # (max_width, half_dim / 2), namely x*theta
    freqs_grid_map = torch.concat([
        freqs_height[:, None, :].expand(-1, max_width, -1), # (max_height, max_width, half_dim / 2)
        freqs_width[None, :, :].expand(max_height, -1, -1), # (max_height, max_width, half_dim / 2)
    ], dim=-1)  # (max_height, max_width, half_dim)
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    # (2, max_height, max_width, half_dim)

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['image_scales']
        _, ph, pw = scale_schedule[-1]
        # Largest usable crop of the base grid with this aspect ratio.
        # NOTE(review): shape[1] is max_height; assumes max_height == max_width — confirm.
        max_edge_length = freqs_grid_map.shape[1]
        if ph >= pw:
            uph, upw = max_edge_length, int(max_edge_length / ph * pw)
        else:
            uph, upw = int(max_edge_length / pw * ph), max_edge_length
        rope_cache_list = []
        for (_, ph, pw) in scale_schedule:
            ph_mul_pw = ph * pw
            if rope2d_normalized_by_hw == 1: # downsample
                rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0,3,1,2]), size=(ph, pw), mode='bilinear', align_corners=True)
                rope_cache = rope_cache.permute([0,2,3,1]) # (2, ph, pw, half_head_dim)
            elif rope2d_normalized_by_hw == 2: # star style
                _, uph, upw = scale_schedule[-1]
                indices = torch.stack([
                    (torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
                ], dim=-1).round().int() # (ph, pw, 2)
                indices = indices.reshape(-1, 2) # (ph*pw, 2)
                rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], :] # (2, ph*pw, half_head_dim)
                rope_cache = rope_cache.reshape(2, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :ph, :pw, :] # (2, ph, pw, half_head_dim)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1))
        cat_rope_cache = torch.cat(rope_cache_list, 1) # (2, seq_len, half_head_dim)
        if cat_rope_cache.shape[1] % pad_to_multiplier:
            # Pad on the same device/dtype as the cache (a bare torch.zeros would
            # land on CPU and make torch.cat fail when device is a GPU).
            pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim,
                              device=cat_rope_cache.device, dtype=cat_rope_cache.dtype)
            cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
        cat_rope_cache = cat_rope_cache[:,None,None,None] # (2, 1, 1, 1, seq_len, half_dim)
        # Every pixel-number bucket of this aspect ratio shares the same cache,
        # keyed by its schedule with the temporal extent normalized to 1.
        for pn in dynamic_resolution_h_w[h_div_w]:
            scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
            tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
            rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
    return rope2d_freqs_grid


def precompute_rope3d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_frames=128, max_height=2048 // 8, max_width=2048 // 8, base=10000.0, device=None, activated_h_div_w_templates=None, steps_per_frame=4, pn=None, args=None):
    """Precompute 3D rotary-embedding (cos, sin) caches for image and video scale schedules.

    The rotary dimension (dim // 2 angles) is split into three nearly equal
    parts for t (frame), y (row) and x (column); when dim // 2 is not divisible
    by 3, the surplus frequency slots are dropped per-axis (see the offset logic).

    Args:
        dim: per-head rotary dimension; must be even.
        dynamic_resolution_h_w: nested dict, h_div_w -> pn -> {'image_scales', 'video_scales'}.
        rope2d_normalized_by_hw: 0 = crop the raw grid; 2 = "star style"
            nearest-index sampling against the final scale. (1 is not supported here.)
        pad_to_multiplier: right-pad each sequence with zeros to this multiple.
        max_frames / max_height / max_width: extent of the precomputed base grid.
        base: RoPE frequency base.
        device: device on which the base grid is computed.
        activated_h_div_w_templates: aspect-ratio keys to build caches for.
        steps_per_frame: scales per new frame in the tower schedule (temporal index stride).
        pn: pixel-number bucket whose schedules are used.
        args: config namespace; only args.dynamic_scale_schedule is read.

    Returns:
        dict mapping str(tuple(schedule)) to a (2, 1, 1, 1, seq_len, dim // 2)
        tensor of stacked (cos, sin), with one entry per image and per video schedule.
    """
    if activated_h_div_w_templates is None:  # avoid mutable default argument
        activated_h_div_w_templates = []
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    dim_div_2 = dim // 2
    num_of_freqs = int(np.ceil(dim_div_2 / 3))
    inv_freq = 1.0 / (base ** (torch.arange(num_of_freqs, dtype=torch.int64).float().to(device) / num_of_freqs)) # namely theta, 1 / (10000^(i/num_of_freqs))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    t_frames = torch.arange(max_frames, device=device, dtype=torch.int64).type_as(inv_freq)
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, ceil(dim_div_2 / 3)), namely y*theta
    freqs_width = torch.outer(t_width, inv_freq)  # (max_width, ceil(dim_div_2 / 3)), namely x*theta
    freqs_frames = torch.outer(t_frames, inv_freq)  # (max_frames, ceil(dim_div_2 / 3)), namely t*theta
    # Drop surplus frequency slots so the three parts concatenate to exactly dim_div_2.
    if (num_of_freqs*3) - dim_div_2 == 0:
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs, num_of_freqs
    elif (num_of_freqs*3) - dim_div_2 == 2: # 2 elems that should be dropped
        offset_t, offset_h, offset_w = num_of_freqs, num_of_freqs-1, num_of_freqs-1
    else: # 1 elem that should be dropped
        offset_t, offset_h, offset_w = num_of_freqs-1, num_of_freqs, num_of_freqs
    freqs_grid_map = torch.concat([
        freqs_frames[:, None, None, :offset_t].expand(-1, max_height, max_width, -1),
        freqs_height[None, :, None, :offset_h].expand(max_frames, -1, max_width, -1),
        freqs_width[None, None, :, :offset_w].expand(max_frames, max_height, -1, -1),
    ], dim=-1)  # (max_frames, max_height, max_width, dim / 2)
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    # (2, max_frames, max_height, max_width, dim / 2)

    rope2d_freqs_grid = {}
    for h_div_w in activated_h_div_w_templates:
        assert h_div_w in dynamic_resolution_h_w, f'Unknown h_div_w: {h_div_w}'
        image_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['image_scales']
        video_scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['video_scales']
        first_full_spatial_size_scale_index = get_first_full_spatial_size_scale_index(video_scale_schedule)
        rope_cache_list4image, rope_cache_list4video = [], []

        # image: temporal index is always 0 (pt == 1 for every image scale)
        for si, (pt, ph, pw) in enumerate(image_scale_schedule):
            assert pt == 1
            mul_ph_pw = ph * pw
            if rope2d_normalized_by_hw == 2: # star style
                upt, uph, upw = image_scale_schedule[-1]
                t_inds = 0 * torch.ones(pt, ph, pw)
                indices = torch.stack([
                    t_inds,
                    (torch.arange(ph) * (uph / ph)).reshape(1, ph, 1).expand(pt, ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, 1, pw).expand(pt, ph, pw),
                ], dim=-1).round().int() # (pt, ph, pw, 3)
                indices = indices.reshape(-1, 3) # (pt*ph*pw, 3)
                rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], indices[:,2], :] # (2, pt*ph*pw, dim / 2)
                rope_cache = rope_cache.reshape(2, pt, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :] # (2, pt, ph, pw, dim / 2)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4image.append(rope_cache.reshape(2, mul_ph_pw, -1)) # (2, 1*ph*pw, dim / 2)

        # video
        for si, (pt, ph, pw) in enumerate(video_scale_schedule):
            mul_pt_ph_pw = pt * ph * pw
            if rope2d_normalized_by_hw == 2: # star style
                upt, uph, upw = video_scale_schedule[-1]
                if args.dynamic_scale_schedule == 'infinity_video_tower':
                    # Tower schedule: every steps_per_frame scales past the first
                    # full-spatial-size scale advance the temporal index by one.
                    t_ind = int(np.ceil((si - first_full_spatial_size_scale_index) / steps_per_frame))
                    t_ind = max(t_ind, 0)
                    t_inds = t_ind * torch.ones(pt, ph, pw)
                else:
                    # Float dtype so torch.stack with the float h/w index tensors
                    # succeeds; rounded back to int right after.
                    t_inds = torch.arange(pt, dtype=torch.float32).reshape(pt, 1, 1).expand(pt, ph, pw)
                indices = torch.stack([
                    t_inds,
                    (torch.arange(ph) * (uph / ph)).reshape(1, ph, 1).expand(pt, ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, 1, pw).expand(pt, ph, pw),
                ], dim=-1).round().int() # (pt, ph, pw, 3)
                indices = indices.reshape(-1, 3) # (pt*ph*pw, 3)
                rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], indices[:,2], :] # (2, pt*ph*pw, dim / 2)
                rope_cache = rope_cache.reshape(2, pt, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:
                rope_cache = freqs_grid_map[:, :pt, :ph, :pw, :] # (2, pt, ph, pw, dim / 2)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list4video.append(rope_cache.reshape(2, mul_pt_ph_pw, -1)) # (2, pt*ph*pw, dim / 2)
        cat_rope_cache4image = torch.cat(rope_cache_list4image, 1) # (2, seq_len, dim / 2)
        cat_rope_cache4video = torch.cat(rope_cache_list4video, 1) # (2, seq_len, dim / 2)
        # Pads are allocated on the cache's device/dtype (a bare torch.zeros would
        # land on CPU and make torch.cat fail when device is a GPU).
        if cat_rope_cache4image.shape[1] % pad_to_multiplier:
            pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache4image.shape[1] % pad_to_multiplier, dim//2,
                              device=cat_rope_cache4image.device, dtype=cat_rope_cache4image.dtype)
            cat_rope_cache4image = torch.cat([cat_rope_cache4image, pad], dim=1)
        if cat_rope_cache4video.shape[1] % pad_to_multiplier:
            pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache4video.shape[1] % pad_to_multiplier, dim//2,
                              device=cat_rope_cache4video.device, dtype=cat_rope_cache4video.dtype)
            cat_rope_cache4video = torch.cat([cat_rope_cache4video, pad], dim=1)
        cat_rope_cache4image = cat_rope_cache4image[:,None,None,None] # (2, 1, 1, 1, seq_len, dim / 2)
        cat_rope_cache4video = cat_rope_cache4video[:,None,None,None] # (2, 1, 1, 1, seq_len, dim / 2)
        rope2d_freqs_grid[str(tuple(image_scale_schedule))] = cat_rope_cache4image
        rope2d_freqs_grid[str(tuple(video_scale_schedule))] = cat_rope_cache4video
    return rope2d_freqs_grid


def precompute_rope4d_freqs_grid(
        dim, 
        rope2d_normalized_by_hw, 
        pad_to_multiplier=1, 
        max_scales=128, 
        max_frames=128, 
        max_height=2048 // 8, 
        max_width=2048 // 8, 
        base=10000.0, 
        device=None, 
        activated_h_div_w_templates=None, 
        steps_per_frame=4, 
        text_maxlen=0, 
        pn=None, 
        args=None,
        **kwargs,
):
    """Precompute per-axis 4D rotary-embedding (cos, sin) tables.

    The rotary dimension (dim // 2 angles) is split into FOUR equal parts —
    scale, frame (t), row (y) and column (x). Unlike the 2D/3D variants, no
    per-schedule grid is assembled here; the per-axis tables are returned for
    the caller to combine. Text tokens get a dedicated embedding that advances
    only along the scale axis (frame/height/width fixed at position 0).

    Args:
        dim: per-head rotary dimension; must be even and dim // 2 divisible by 4.
        rope2d_normalized_by_hw, pad_to_multiplier, activated_h_div_w_templates,
            steps_per_frame, pn, args, **kwargs: accepted for signature
            compatibility with the 2D/3D variants; not used here.
        max_scales / max_frames / max_height / max_width: table extents per axis.
        base: RoPE frequency base.
        device: device on which the tables are computed.
        text_maxlen: number of leading text positions on the scale axis.

    Returns:
        dict with:
          'freqs_text'   (2, 1, 1, 1, text_maxlen, dim // 2) stacked (cos, sin);
          'freqs_scales' (2, max_scales, dim // 8) — scale positions after the text span;
          'freqs_frames' / 'freqs_height' / 'freqs_width' (2, extent, dim // 8).
    """
    if activated_h_div_w_templates is None:  # avoid mutable default argument
        activated_h_div_w_templates = []
    assert dim % 2 == 0, f'Only support dim % 2 == 0, but got dim={dim}'
    dim_div_2 = dim // 2
    num_of_freqs = int(np.ceil(dim_div_2 / 4))
    inv_freq = 1.0 / (base ** (torch.arange(num_of_freqs, dtype=torch.int64).float().to(device) / num_of_freqs)) # namely theta, 1 / (10000^(i/num_of_freqs))
    # Text positions share the scale axis, so the scale table covers both.
    t_scales = torch.arange(text_maxlen+max_scales, device=device, dtype=torch.int64).type_as(inv_freq)
    t_frames = torch.arange(max_frames, device=device, dtype=torch.int64).type_as(inv_freq)
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq)
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq)
    freqs_scales = torch.outer(t_scales, inv_freq)  # (text_maxlen+max_scales, ceil(dim_div_2 / 4)), namely s*theta
    freqs_frames = torch.outer(t_frames, inv_freq)  # (max_frames, ceil(dim_div_2 / 4)), namely t*theta
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, ceil(dim_div_2 / 4)), namely y*theta
    freqs_width = torch.outer(t_width, inv_freq)  # (max_width, ceil(dim_div_2 / 4)), namely x*theta
    assert num_of_freqs*4==dim_div_2
    freqs_scales = torch.stack([torch.cos(freqs_scales), torch.sin(freqs_scales)], dim=0)
    freqs_frames = torch.stack([torch.cos(freqs_frames), torch.sin(freqs_frames)], dim=0)
    freqs_height = torch.stack([torch.cos(freqs_height), torch.sin(freqs_height)], dim=0)
    freqs_width = torch.stack([torch.cos(freqs_width), torch.sin(freqs_width)], dim=0)
    tm = text_maxlen
    # Text token i sits at (scale=i, frame=0, height=0, width=0).
    rope_text_embeds = torch.cat([
        freqs_scales[   :,   :tm,  None,   None,   None,   :].expand(-1, -1, -1, -1, -1, -1),
        freqs_frames[   :,  None,    :1,   None,   None,   :].expand(-1, tm, -1, -1, -1, -1),
        freqs_height[   :,  None,  None,     :1,   None,   :].expand(-1, tm, -1, -1, -1, -1),
        freqs_width[    :,  None,  None,   None,     :1,   :].expand(-1, tm, -1, -1, -1, -1),
    ], dim=-1)  # (2, tm, 1, 1, 1, dim_div_2)
    rope_text_embeds = rope_text_embeds.reshape(2, 1, 1, 1, tm, dim_div_2)
    rope2d_freqs_grid = {}
    rope2d_freqs_grid['freqs_text'] = rope_text_embeds # (2, 1, 1, 1, text_maxlen, dim / 2)
    rope2d_freqs_grid['freqs_scales'] = freqs_scales[:, tm:] # (2, max_scales, ceil(dim_div_2 / 4))
    rope2d_freqs_grid['freqs_frames'] = freqs_frames # (2, max_frames, ceil(dim_div_2 / 4))
    rope2d_freqs_grid['freqs_height'] = freqs_height # (2, max_height, ceil(dim_div_2 / 4))
    rope2d_freqs_grid['freqs_width'] = freqs_width # (2, max_width, ceil(dim_div_2 / 4))
    return rope2d_freqs_grid

def apply_rotary_emb(q, k, rope_cache):
    """Apply rotary position embedding to query and key tensors.

    Treats the last dim of q/k as interleaved (x, y) pairs and rotates each
    pair by the cached angle: (x, y) -> (x*cos - y*sin, y*cos + x*sin).

    Args:
        q, k: attention tensors whose last dim is the head dim; assumed
            (batch, heads, seq_len, head_dim) — TODO confirm against callers.
            NOTE: rotated IN PLACE via mul_/sub_/add_; the inputs are mutated.
        rope_cache: (2, 1, 1, 1, seq_len, head_dim // 2) stacked (cos, sin),
            as produced by the precompute_rope*_freqs_grid functions.

    Returns:
        The rotated (q, k).
    """
    device_type = q.device.type
    # torch.autocast rejects "mps" as a device type; fall back to "cpu".
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    qk = [q, k]
    rope_cache = rope_cache[:,0]  # drop the broadcast dim -> (2, 1, 1, seq_len, half_head_dim)
    # Rotation is done in full precision: disable autocast for the in-place math.
    with torch.autocast(device_type=device_type, enabled=False):
        for i in range(2):
            qk[i] = qk[i].reshape(*qk[i].shape[:-1], -1, 2)  # (..., half_head_dim, 2)
            # tmp2 must capture the ORIGINAL x before the in-place update below overwrites it.
            tmp1 = qk[i][..., 1] * rope_cache[1]  # y * sin
            tmp2 = qk[i][..., 0] * rope_cache[1]  # x * sin
            qk[i][..., 0].mul_(rope_cache[0]).sub_(tmp1)  # x' = x*cos - y*sin
            qk[i][..., 1].mul_(rope_cache[0]).add_(tmp2)  # y' = y*cos + x*sin
            qk[i] = qk[i].reshape(*qk[i].shape[:-2], -1)
        q, k = qk
    return q, k