kzopp commited on
Commit
4a2e18d
·
verified ·
1 Parent(s): f3df932

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ flan-t5-xl-encoder-Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
37
+ infinity_2b_reg_Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
Infinity/infinity/models/basic.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Definitions of blocks of VAR transformer model.
3
+ """
4
+
5
+ import math
6
+ import os
7
+ from functools import partial
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from timm.models.layers import DropPath, drop_path
15
+ from torch.utils.checkpoint import checkpoint
16
+
17
+ # Attention backend selection with fallback hierarchy:
18
+ # 1. SageAttention (optional, 2-5x faster than FlashAttention)
19
+ # 2. FlashAttention (optional, still faster than PyTorch)
20
+ # 3. PyTorch scaled_dot_product_attention (always available)
21
+
22
+ SAGE_ATTN_AVAILABLE = False
23
+ FLASH_ATTN_AVAILABLE = False
24
+ sageattn = None
25
+ sageattn_varlen = None
26
+ flash_attn_func = None
27
+ flash_attn_varlen_kvpacked_func = None
28
+
29
+ # Try to import SageAttention (optional, fastest option)
30
+ try:
31
+ from sageattention import sageattn, sageattn_varlen
32
+ SAGE_ATTN_AVAILABLE = True
33
+ print("[INFO] SageAttention detected - will use for 2-5x speedup over FlashAttention")
34
+ except ImportError:
35
+ pass
36
+
37
+ # Try to import FlashAttention (optional, fallback if SageAttention not available)
38
+ try:
39
+ from flash_attn import flash_attn_func # q, k, or v: BLHc, ret: BLHc
40
+ from flash_attn import flash_attn_varlen_kvpacked_func # qkv: N3Hc, ret: NHc
41
+ FLASH_ATTN_AVAILABLE = True
42
+ if not SAGE_ATTN_AVAILABLE:
43
+ print("[INFO] FlashAttention detected - will use for optimized attention")
44
+ except ImportError:
45
+ pass
46
+
47
+ # Print final status
48
+ if not SAGE_ATTN_AVAILABLE and not FLASH_ATTN_AVAILABLE:
49
+ print("[INFO] Using PyTorch scaled_dot_product_attention (no SageAttention or FlashAttention detected)")
50
+ print(" Install SageAttention for 2-5x speedup: pip install sageattention>=2.2.0 --no-build-isolation")
51
+
52
+ from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
53
+
54
+ # Import GGUF utilities for on-the-fly dequantization
55
+ try:
56
+ import sys
57
+ import os
58
+ # Add parent directory to path to find infinity_gguf_utils
59
+ current_dir = os.path.dirname(os.path.abspath(__file__))
60
+ parent_dirs = [
61
+ os.path.join(current_dir, '../../..'), # From Infinity/infinity/models to root
62
+ os.path.join(current_dir, '../../../..'), # One more level up if needed
63
+ ]
64
+ for parent_dir in parent_dirs:
65
+ if parent_dir not in sys.path:
66
+ sys.path.insert(0, parent_dir)
67
+ from infinity_gguf_utils import dequantize_gguf_tensor, GGUFParameter
68
+ GGUF_AVAILABLE = True
69
+ except ImportError:
70
+ GGUF_AVAILABLE = False
71
+ GGUFParameter = None
72
+
73
def get_weight_for_linear(linear_layer, target_dtype=None):
    """
    Fetch the weight tensor of a linear layer, ready for use with ``F.linear``.

    If the weight is a quantized GGUF parameter it is dequantized on the fly;
    otherwise it is cast to ``target_dtype`` when one is requested.

    Args:
        linear_layer: nn.Linear (or GGUF-backed linear) whose weight to fetch.
        target_dtype: optional dtype the returned tensor should have.

    Returns:
        A dense weight tensor usable directly in ``F.linear``.
    """
    w = linear_layer.weight
    # Quantized path: GGUF weights must be dequantized before any matmul.
    if GGUF_AVAILABLE and isinstance(w, GGUFParameter):
        return dequantize_gguf_tensor(w, target_dtype=target_dtype)
    # Plain (e.g. fp16/fp32) weights only need an optional dtype cast.
    if target_dtype is not None and w.dtype != target_dtype:
        return w.to(dtype=target_dtype)
    return w
92
+
93
+
94
+ # Import flash_attn's fused ops
95
+ try:
96
+ from flash_attn.ops.layer_norm import dropout_add_layer_norm
97
+ from flash_attn.ops.rms_norm import dropout_add_rms_norm
98
+ from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
99
+ from flash_attn.ops.fused_dense import fused_mlp_func
100
+ flash_fused_op_installed = True
101
+ except ImportError:
102
+ dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
103
+ flash_fused_op_installed = False
104
+
105
+ def rms_norm_impl(x, weight, epsilon):
106
+ return (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True).add_(epsilon))) * weight
107
+
108
+
109
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0):
    """
    Precompute 2D RoPE cos/sin tables for every (aspect-ratio, resolution) scale schedule.

    The head dimension is split in half: one half encodes the y (height)
    position, the other the x (width) position.

    Args:
        dim: per-head channel count the rotation is applied over.
        dynamic_resolution_h_w: nested dict; dynamic_resolution_h_w[h_div_w][pn]['scales']
            is a list of (t, h, w) scale tuples.
        rope2d_normalized_by_hw: 0 = top-left crop of the grid, 1 = bilinear
            down-sampling, 2 = nearest-index ("star style") sampling.
        pad_to_multiplier: pad the concatenated sequence length up to a multiple of this.
        max_height, max_width: extent of the precomputed base grid.
        base: RoPE frequency base.
        device: device for the base grid tensors.
        scaling_factor: divides the position indices before the outer product.

    Returns:
        Dict mapping str(tuple(scale_schedule)) -> tensor of shape
        (2, 1, 1, 1, seq_len, dim//2), where index 0 holds cos and 1 holds sin.
    """
    half_dim = dim // 2  # half the channels for y, half for x
    # theta_i = 1 / base^(i / half_dim), i = 0, 2, ..., half_dim - 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim))
    t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq) / scaling_factor
    freqs_height = torch.outer(t_height, inv_freq)  # (max_height, half_dim/2): y * theta
    freqs_width = torch.outer(t_width, inv_freq)    # (max_width, half_dim/2): x * theta
    # Broadcast both axes over the full grid and concatenate channel-wise.
    freqs_grid_map = torch.concat([
        freqs_height[:, None, :].expand(-1, max_width, -1),
        freqs_width[None, :, :].expand(max_height, -1, -1),
    ], dim=-1)  # (max_height, max_width, half_dim)
    freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0)
    # (2, max_height, max_width, half_dim): [cos, sin]

    rope2d_freqs_grid = {}
    for h_div_w in dynamic_resolution_h_w:
        scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales']
        _, ph, pw = scale_schedule[-1]
        max_edge_length = freqs_grid_map.shape[1]
        # Fit the final (ph, pw) aspect ratio into the precomputed square grid.
        if ph >= pw:
            uph, upw = max_edge_length, int(max_edge_length / ph * pw)
        else:
            uph, upw = int(max_edge_length / pw * ph), max_edge_length
        rope_cache_list = []
        for (_, ph, pw) in scale_schedule:
            ph_mul_pw = ph * pw
            if rope2d_normalized_by_hw == 1:  # bilinear down-sample of the cropped grid
                rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0, 3, 1, 2]), size=(ph, pw), mode='bilinear', align_corners=True)
                rope_cache = rope_cache.permute([0, 2, 3, 1])  # (2, ph, pw, half_dim)
            elif rope2d_normalized_by_hw == 2:  # nearest-index ("star style") sampling
                _, uph, upw = scale_schedule[-1]
                indices = torch.stack([
                    (torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw),
                    (torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw),
                ], dim=-1).round().int()  # (ph, pw, 2)
                indices = indices.reshape(-1, 2)  # (ph*pw, 2)
                rope_cache = freqs_grid_map[:, indices[:, 0], indices[:, 1], :].reshape(2, ph, pw, -1)
            elif rope2d_normalized_by_hw == 0:  # raw top-left crop
                rope_cache = freqs_grid_map[:, :ph, :pw, :]  # (2, ph, pw, half_dim)
            else:
                raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}')
            rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1))
        cat_rope_cache = torch.cat(rope_cache_list, 1)  # (2, seq_len, half_dim)
        if cat_rope_cache.shape[1] % pad_to_multiplier:
            pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim)
            cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1)
        cat_rope_cache = cat_rope_cache[:, None, None, None]  # (2, 1, 1, 1, seq_len, half_dim)
        # Every resolution bucket with the same aspect ratio shares this cache.
        for pn in dynamic_resolution_h_w[h_div_w]:
            scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales']
            tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule]
            rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache
    return rope2d_freqs_grid
165
+
166
+
167
def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind):
    """
    Rotate q and k with the precomputed 2D RoPE cache for this scale schedule.

    Args:
        q, k: (batch, heads, seq_len, head_dim) tensors.
        scale_schedule: list of (t, h, w) tuples; str(tuple(...)) keys the cache.
        rope2d_freqs_grid: dict from precompute_rope2d_freqs_grid; mutated in
            place to move the cache onto q/k's device.
        pad_to_multiplier, rope2d_normalized_by_hw: unused here; kept for API parity.
        scale_ind: index of the current scale; positions before it are skipped.

    Returns:
        (q, k) with the rotary embedding applied, same shapes as the inputs.
    """
    qk = torch.stack((q, k), dim=0)  # (2, B, H, L, head_dim): rotate both in one shot
    device_type = qk.device.type
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):  # RoPE math in full precision
        seq_len = qk.shape[3]
        start = 0
        if scale_ind >= 1:
            assert len(scale_schedule[0]) == 3
            # Token offset of the current scale = total tokens of all earlier scales.
            start = np.sum([item[0] * item[1] * item[2] for item in scale_schedule[:scale_ind]])
        key = str(tuple(scale_schedule))
        rope2d_freqs_grid[key] = rope2d_freqs_grid[key].to(qk.device)
        assert start + seq_len <= rope2d_freqs_grid[key].shape[4]
        rope_cache = rope2d_freqs_grid[key][:, :, :, :, start:start + seq_len]  # (2, 1, 1, 1, L, half_head_dim)
        qk = qk.reshape(*qk.shape[:-1], -1, 2)  # pair up (even, odd) channels
        # Complex rotation with (cos, sin); stack + reshape preserves the channel pairing.
        qk = torch.stack([
            rope_cache[0] * qk[..., 0] - rope_cache[1] * qk[..., 1],
            rope_cache[1] * qk[..., 0] + rope_cache[0] * qk[..., 1],
        ], dim=-1)
        qk = qk.reshape(*qk.shape[:-2], -1)  # (2, B, H, L, head_dim)
        q, k = qk.unbind(dim=0)
    return q, k
188
+
189
+
190
class FastRMSNorm(nn.Module):
    """RMS normalization computed in fp32, cast back to the input dtype.

    Delegates the actual math to the module-level ``rms_norm_impl`` (a fused
    flash-attn kernel when installed, plain PyTorch fallback otherwise).
    """

    def __init__(self, C, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.C = C
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        # With affine the gain is learnable; otherwise a constant-ones buffer
        # keeps the forward path identical in both modes.
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(C))
        else:
            self.register_buffer('weight', torch.ones(C))

    def forward(self, x):
        # Normalize in fp32 for numerical stability, then restore the dtype.
        return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(x.dtype)

    def extra_repr(self) -> str:
        return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}'
207
+
208
+
209
def get_dropout_layer(p):
    """Return an in-place Dropout(p), or a no-op Identity when p == 0."""
    if p > 0:
        return nn.Dropout(p, inplace=True)
    return nn.Identity()
211
+
212
+
213
class FFN(nn.Module):
    """Transformer MLP: Linear -> tanh-approximated GELU -> Linear, with an
    optional flash-attn fused-MLP kernel path."""

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False):
        super().__init__()
        # Only bind the fused kernel when requested; stays None otherwise.
        self.fused_mlp_func = fused_mlp_func if fused_mlp else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = get_dropout_layer(drop)
        self.heuristic = -1  # -1: let the fused kernel pick its own heuristic

    def forward(self, x):
        if self.fused_mlp_func is None:
            # Eager path.
            return self.drop(self.fc2(self.act(self.fc1(x))))
        # Fused path: one kernel covering fc1 + GELU + fc2.
        return self.drop(self.fused_mlp_func(
            x=x,
            weight1=self.fc1.weight,
            weight2=self.fc2.weight,
            bias1=self.fc1.bias,
            bias2=self.fc2.bias,
            activation='gelu_approx',
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=self.heuristic,
            process_group=None,
        ))

    def extra_repr(self) -> str:
        return f'fused_mlp={self.fused_mlp_func is not None}'
245
+
246
+
247
class FFNSwiGLU(nn.Module):
    """SwiGLU MLP: out = fc2(silu(fcg(x)) * fc1(x)).

    The hidden width is snapped to 2/3 of the requested value, rounded to a
    multiple of 256, to keep the parameter count comparable to a plain MLP.
    """

    def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False):
        super().__init__()
        self.fused_mlp_func = None  # fused kernel not supported for SwiGLU
        hidden_features = round(2 * hidden_features / 3 / 256) * 256
        out_features = out_features or in_features
        self.fcg = nn.Linear(in_features, hidden_features, bias=False)  # gate branch
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)  # value branch
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        self.drop = get_dropout_layer(drop)

    def forward(self, x):
        gate = F.silu(self.fcg(x), inplace=True)
        return self.drop(self.fc2(gate.mul_(self.fc1(x))))

    def extra_repr(self) -> str:
        return f'fused_mlp={self.fused_mlp_func is not None}'
264
+
265
+
266
class SelfAttention(nn.Module):
    """Multi-head self-attention with optional cosine attention, 2D RoPE, a kv
    cache for inference, and a SageAttention -> FlashAttention -> PyTorch-SDPA
    backend fallback chain."""

    def __init__(
        self, embed_dim=768, num_heads=12,
        proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False,
        batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0,
    ):
        """
        :param embed_dim: model's width
        :param num_heads: num heads of multi-head attention
        :param proj_drop: always 0 for testing
        :param tau: always 1
        :param cos_attn: when True, q and k are L2-normalized and scaled by the head-wise learnable self.scale_mul_1H11
        :param customized_flash_attn: use (B, L, H, c) layout and the optimized backends
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.using_flash = customized_flash_attn

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.tau = tau
        self.cos_attn = cos_attn
        if self.cos_attn:
            self.scale = 1
            # Flash layout keeps heads at dim 2 -> (1,1,H,1); SDPA layout -> (1,H,1,1).
            size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 1 / math.sqrt(self.head_dim) / self.tau

        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        # q and v get learnable biases; k's bias is a fixed zero buffer.
        self.q_bias = nn.Parameter(torch.zeros(embed_dim))
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = get_dropout_layer(proj_drop)

        # kv-cache state: only used during inference.
        self.caching = False
        self.cached_k = None
        self.cached_v = None

        self.batch_size = batch_size
        self.use_flex_attn = use_flex_attn
        self.pad_to_multiplier = pad_to_multiplier
        self.rope2d_normalized_by_hw = rope2d_normalized_by_hw

    def kv_caching(self, enable: bool):
        """Toggle the inference-only kv cache; always clears any cached k/v."""
        self.caching = enable
        self.cached_k = None
        self.cached_v = None

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0):
        """
        :param x: (B, L, C) input, fp32
        :param attn_bias_or_two_vector:
            if not using_flash: a block-wise lower-triangular additive mask
            (0 = visible, -inf = invisible);
            else: a tuple of two 1-dim int vectors (VAR_visible_kvlen, VAR_invisible_qlen).
            None during inference.
        :param attn_fn: optional flex-attention callable (non-flash path only)
        :param scale_schedule, rope2d_freqs_grid, scale_ind: 2D RoPE inputs; RoPE
            is applied only when rope2d_freqs_grid is not None.
        :return: (B, L, C)
        """
        B, L, C = x.shape

        # Fused qkv projection; weight may be GGUF-quantized, hence the helper.
        qkv = F.linear(
            input=x,
            weight=get_weight_for_linear(self.mat_qkv, target_dtype=x.dtype),
            bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)),
        ).view(B, L, 3, self.num_heads, self.head_dim)  # (B, L, 3, H, c)
        if self.using_flash:
            q, k, v = qkv.unbind(dim=2)  # each (B, L, H, c)
            L_dim = 1
        else:
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)  # each (B, H, L, c)
            L_dim = 2

        if self.cos_attn:
            # L2-normalize q/k and apply the clamped, exponentiated head-wise gain.
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous()
            k = F.normalize(k, dim=-1, eps=1e-12).contiguous()
            v = v.contiguous()
        else:
            # Contiguity keeps the attention kernels happy.
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        if rope2d_freqs_grid is not None:
            q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind)
        if self.caching:  # inference-only kv cache: append along the L dim
            if self.cached_k is None:
                self.cached_k = k
                self.cached_v = v
            else:
                k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim)
                v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim)

        if self.using_flash:
            if SAGE_ATTN_AVAILABLE and attn_bias_or_two_vector is None:
                # Fastest backend, inference only (no mask support here).
                try:
                    # SageAttention HND layout wants (B, H, L, c); ours are (B, L, H, c).
                    q_sage = q.transpose(1, 2)
                    k_sage = k.transpose(1, 2)
                    v_sage = v.transpose(1, 2)
                    # SageAttention requires fp16/bf16 inputs.
                    target_dtype = torch.bfloat16 if v.dtype == torch.float32 else v.dtype
                    q_sage = q_sage.to(target_dtype)
                    k_sage = k_sage.to(target_dtype)
                    v_sage = v_sage.to(target_dtype)
                    oup = sageattn(q_sage, k_sage, v_sage, tensor_layout="HND", is_causal=False)
                    oup = oup.transpose(1, 2).reshape(B, L, C)  # back to (B, L, C)
                    if target_dtype != v.dtype:
                        oup = oup.to(v.dtype)
                except Exception as e:
                    print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch")
                    if FLASH_ATTN_AVAILABLE:
                        kw = dict() if attn_bias_or_two_vector is None else dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
                        oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
                    else:
                        oup = slow_attn(query=q.transpose(1, 2), key=k.transpose(1, 2), value=v.transpose(1, 2), scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C)
            elif FLASH_ATTN_AVAILABLE:
                if attn_bias_or_two_vector is not None:  # training
                    kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1])
                else:  # autoregressive inference
                    kw = dict()
                oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C)
            else:
                # PyTorch SDPA fallback: transpose to (B, H, L, c) and back.
                oup = slow_attn(query=q.transpose(1, 2), key=k.transpose(1, 2), value=v.transpose(1, 2), scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C)
        else:
            # Non-flash layout: q, k, v already (B, H, L, c).
            if self.use_flex_attn and attn_fn is not None:
                oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C)
            else:
                oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C)

        return self.proj_drop(self.proj(oup))

    def extra_repr(self) -> str:
        tail = ''
        return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}'
428
+
429
+
430
class CrossAttention(nn.Module):
    """Cross-attention from (image) queries to packed variable-length text
    keys/values, with SageAttention -> FlashAttention -> PyTorch-SDPA fallback."""

    def __init__(
        self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12,
        proj_drop=0., cos_attn=False, use_flash_attn=True,
    ):
        """
        :param for_attn_pool: only used in VAR.text_proj_for_sos (single learned query)
        :param embed_dim: Q's dim
        :param kv_dim: K's and V's dim
        :param num_heads: num heads of multi-head attention
        :param proj_drop: proj dropout rate
        :param cos_attn: requested cosine attention; forced off below
        """
        cos_attn = False  # TODO: never use cos attn in cross attention with T5 kv
        super().__init__()
        self.for_attn_pool = for_attn_pool
        self.embed_dim = embed_dim
        self.kv_dim = kv_dim
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # =64
        self.cos_attn = cos_attn
        self.use_flash_attn = use_flash_attn
        if self.cos_attn:
            self.scale = 1
            self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 1 / math.sqrt(self.head_dim)

        if for_attn_pool:
            # One learned query per head, repeated across the batch in forward().
            q = torch.empty(1, self.num_heads, self.head_dim)
            nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3))
            self.mat_q = nn.Parameter(q)
        else:
            self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.mat_kv = nn.Linear(kv_dim, embed_dim * 2, bias=False)
        # v gets a learnable bias; k's bias is a fixed zero buffer.
        self.v_bias = nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = get_dropout_layer(proj_drop)

    def forward(self, q, ca_kv):
        """
        :param q: (batch, seq_len, Q_dim) queries; unused when for_attn_pool
        :param ca_kv: tuple of
            - kv_compact: (sum(lens), KV_dim) packed keys/values
            - cu_seqlens_k: cumulative sum of lens (int vector)
            - max_seqlen_k: int, max(lens)
            NOTE: seq_len (num of Qs) can reach 10k; each len_i (num of KVs) must be <= 256
        :return: (batch, seq_len, Q_dim)
        """
        kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv
        N = kv_compact.shape[0]

        # Fused kv projection; weight may be GGUF-quantized, hence the helper.
        kv_compact = F.linear(kv_compact, weight=get_weight_for_linear(self.mat_kv, target_dtype=kv_compact.dtype), bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim)  # NC -> N2Hc

        if not self.for_attn_pool:
            B, Lq = q.shape[:2]
            q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim)
        else:
            B = cu_seqlens_k.shape[0] - 1
            Lq = 1
            # The learned query may itself be a quantized GGUF parameter.
            mat_q_data = self.mat_q
            if GGUF_AVAILABLE and isinstance(mat_q_data, GGUFParameter):
                mat_q_data = dequantize_gguf_tensor(mat_q_data, target_dtype=kv_compact.dtype)
            q_compact = mat_q_data.repeat(B, 1, 1).to(dtype=kv_compact.dtype)

        if self.cos_attn:  # always False (forced off in __init__)
            scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp()
            k, v = kv_compact.unbind(dim=1)
            q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)
            kv_compact = torch.stack((k, v), dim=1)

        q_compact = q_compact.contiguous()
        kv_compact = kv_compact.contiguous()

        # Optimized backends with graceful per-call fallback.
        if self.use_flash_attn:
            cu_seqlens_q = torch.arange(0, Lq * (B + 1), Lq, dtype=torch.int32, device=q_compact.device)
            oup = None

            if SAGE_ATTN_AVAILABLE:
                try:
                    # SageAttention varlen wants separate k/v tensors.
                    k_compact, v_compact = kv_compact.unbind(dim=1)  # each (N, H, c)
                    # SageAttention requires fp16/bf16 inputs.
                    target_dtype = torch.bfloat16 if q_compact.dtype == torch.float32 else q_compact.dtype
                    q_sage = q_compact.to(target_dtype)
                    k_sage = k_compact.to(target_dtype)
                    v_sage = v_compact.to(target_dtype)
                    oup = sageattn_varlen(
                        q=q_sage,
                        k=k_sage,
                        v=v_sage,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_k=cu_seqlens_k,
                        max_seqlen_q=Lq,
                        max_seqlen_k=max_seqlen_k,
                        is_causal=False,
                        sm_scale=self.scale,
                        smooth_k=True
                    ).reshape(B, Lq, -1)
                    if target_dtype != q_compact.dtype:
                        oup = oup.float()
                except Exception as e:
                    print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch")
                    oup = None

            if oup is None and FLASH_ATTN_AVAILABLE:
                try:
                    if q_compact.dtype == torch.float32:
                        # Flash kernels need bf16; cast in and restore fp32 on the way out.
                        oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
                        oup = oup.float()
                    else:
                        oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1)
                except Exception as e:
                    print(f"[WARNING] FlashAttention failed ({str(e)[:100]}), falling back to PyTorch attention")
                    oup = None

            if oup is None:
                # Both optimized backends failed: permanently switch to SDPA.
                self.use_flash_attn = False

        if not self.use_flash_attn:
            # Pure-PyTorch path: unpack the packed kv and re-batch it.
            k, v = kv_compact.unbind(dim=1)  # each (N, H, c)

            k_batched = []
            v_batched = []
            for i in range(B):
                start = cu_seqlens_k[i].item()
                end = cu_seqlens_k[i + 1].item()
                k_batched.append(k[start:end])  # (seq_len_i, H, c)
                v_batched.append(v[start:end])

            # Pad every sequence up to max_seqlen_k so they can be stacked.
            k_padded = torch.stack([
                F.pad(k_i, (0, 0, 0, 0, 0, max_seqlen_k - k_i.shape[0])) if k_i.shape[0] < max_seqlen_k else k_i
                for k_i in k_batched
            ])  # (B, max_seqlen_k, H, c)
            v_padded = torch.stack([
                F.pad(v_i, (0, 0, 0, 0, 0, max_seqlen_k - v_i.shape[0])) if v_i.shape[0] < max_seqlen_k else v_i
                for v_i in v_batched
            ])  # (B, max_seqlen_k, H, c)

            q_batched = q_compact.view(B, Lq, self.num_heads, self.head_dim)

            # (B, H, seq_len, c) layout for SDPA.
            q_attn = q_batched.transpose(1, 2)
            k_attn = k_padded.transpose(1, 2)
            v_attn = v_padded.transpose(1, 2)

            # Boolean mask with True marking padded kv positions to exclude.
            attn_mask = torch.zeros(B, 1, Lq, max_seqlen_k, dtype=torch.bool, device=q_compact.device)
            for i in range(B):
                seq_len = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
                if seq_len < max_seqlen_k:
                    attn_mask[i, :, :, seq_len:] = True

            oup = slow_attn(
                query=q_attn,
                key=k_attn,
                value=v_attn,
                attn_mask=~attn_mask,  # SDPA bool mask: True = attend, so invert
                scale=self.scale,
                dropout_p=0.0
            )  # (B, H, Lq, c)

            oup = oup.transpose(1, 2).reshape(B, Lq, -1)

        return self.proj_drop(self.proj(oup))

    def extra_repr(self) -> str:
        return f'Cq={self.embed_dim}, Ckv={self.kv_dim}, cos_attn={self.cos_attn}'
623
+
624
+
625
+ class SelfAttnBlock(nn.Module):
626
def __init__(
    self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
    num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
    swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
):
    """
    Pre-norm transformer block: AdaLN-modulated self-attention followed by an FFN.

    :param embed_dim: block width (C)
    :param kv_dim, cross_attn_layer_scale, checkpointing_sa_only: accepted for API
        compatibility; not used by this self-attention-only block's __init__
    :param cond_dim: conditioning dim (D) for the AdaLN parameters
    :param act: whether ada_lin applies SiLU before the linear (non-shared AdaLN only)
    :param shared_aln: use a shared learnable (1, 1, 6, C) AdaLN table instead of ada_lin
    :param norm_layer: partial constructing the (non-affine) norm layer; its 'eps'
        keyword is also reused for the fused norm path
    :param num_heads, tau, cos_attn, customized_flash_attn: forwarded to SelfAttention
    :param mlp_ratio: FFN hidden width = round(embed_dim * mlp_ratio / 256) * 256
    :param drop, drop_path: dropout and stochastic-depth rates
    :param swiglu: use FFNSwiGLU instead of FFN
    :param fused_mlp: forwarded to the FFN
    :param fused_norm_func: optional fused AdaLN-norm callable
    """
    super(SelfAttnBlock, self).__init__()
    self.C, self.D = embed_dim, cond_dim
    self.drop_path_rate = drop_path
    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    # BUGFIX: the original passed `attn_fn = attn_fn` here, but `attn_fn` is not
    # defined in this scope and is not a SelfAttention.__init__ parameter (it is
    # a forward() argument), so constructing the block raised NameError.
    self.attn = SelfAttention(
        embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
    )
    self.using_swiglu = swiglu
    self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

    self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
    self.fused_norm_func = fused_norm_func
    # BUGFIX: forward() reads `self.fused_ada_norm`, which was never set; keep
    # both names bound to the same callable so either spelling works.
    self.fused_ada_norm = fused_norm_func
    self.norm_eps = norm_layer.keywords.get('eps', 1e-6)

    self.shared_aln = shared_aln
    if self.shared_aln:
        # Shared AdaLN: one learnable (1, 1, 6, C) table added to cond_BD.
        self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
    else:
        lin = nn.Linear(cond_dim, 6*embed_dim)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)
651
+
652
+ # NOTE: attn_bias_or_two_vector is None during inference
653
+ def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
654
+ with torch.cuda.amp.autocast(enabled=False):
655
+ if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
656
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
657
+ else:
658
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
659
+
660
+ if self.fused_ada_norm is None:
661
+ x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
662
+ x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
663
+ else:
664
+ x = x + self.drop_path(self.attn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
665
+ x = x + self.drop_path(self.ffn(self.fused_ada_norm(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
666
+ return x
667
+
668
+ def extra_repr(self) -> str:
669
+ return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}'
670
+
671
+
672
class CrossAttnBlock(nn.Module):
    """
    Text-to-image transformer block: self-attention -> cross-attention (to packed text
    tokens in ca_kv) -> FFN, each sub-layer modulated by adaptive layer norm (AdaLN).
    """
    def __init__(
        self,
        embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial,
        num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False,
        swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False,
        use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False,
    ):
        super(CrossAttnBlock, self).__init__()
        self.C, self.D = embed_dim, cond_dim
        self.drop_path_rate = drop_path
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.sa = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
            use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
        )
        self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn)
        self.using_swiglu = swiglu
        # FFN hidden width rounded to a multiple of 256 for kernel efficiency
        self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        self.ca_norm = norm_layer(embed_dim, elementwise_affine=True)

        self.shared_aln = shared_aln
        if self.shared_aln: # always True
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin)

        # optional LayerScale on the cross-attention branch; plain scalar 1 when disabled
        if cross_attn_layer_scale >= 0:
            self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True)
        else:
            self.ca_gamma = 1

        self.checkpointing_sa_only = checkpointing_sa_only

    # NOTE: attn_bias_or_two_vector is None during inference
    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0): # todo: minGPT and vqgan also uses pre-norm, just like this, while MaskGiT uses post-norm
        with torch.cuda.amp.autocast(enabled=False): # disable half precision
            if self.shared_aln: # always True; (1, 1, 6, C) + (B, 1, 6, C)
                gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
            else:
                gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)

        if self.fused_norm_func is None:
            x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1)
            if self.checkpointing_sa_only and self.training:
                # NOTE(review): the checkpointed path does not forward scale_ind — confirm
                # whether training with checkpointing_sa_only needs it.
                x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
            else:
                # CONSISTENCY FIX: the original non-fused path dropped scale_ind while the
                # fused path passed it; forward it here too so both paths behave alike.
                x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
            x = x + self.drop_path(x_sa.mul_(gamma1))
            x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
        else:
            x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1)
            if self.checkpointing_sa_only and self.training:
                x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
            else:
                x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind)
            x = x + self.drop_path(x_sa.mul_(gamma1))
            x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma)
            x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) # this mul(gamma2) cannot be in-placed cuz we possibly use FusedMLP
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}'
741
+
742
+
743
class AdaLNBeforeHead(nn.Module):
    """Adaptive layer norm applied to hidden states right before the output head.

    The condition vector is mapped to per-channel (scale, shift) pairs that modulate
    an affine-free norm of the input.
    """
    def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None):
        """C: embed_dim, D: cond_dim."""
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.fused_norm_func = fused_norm_func
        self.norm_eps = norm_layer.keywords.get('eps', 1e-6)
        projection = nn.Linear(D, 2*C)
        if act:
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), projection)
        else:
            self.ada_lin = nn.Sequential(projection)

    def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]):
        # (B, D) -> (B, 1, 2, C) -> two tensors of shape (B, 1, C)
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        if self.fused_norm_func is not None:
            return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift)
        normed = self.ln_wo_grad(x_BLC)
        return normed.mul(scale.add(1)).add_(shift)
759
+
760
+
761
def main():
    """Smoke test for CrossAttention: batched queries vs. varlen-packed key/values."""
    dev = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
    rng = torch.Generator(device=dev)
    rng.manual_seed(0)
    B, H, cq, ckv = 4, 8, 64, 96
    Cq = H*cq
    Ckv = H*ckv

    Li = [5, 4, 7, 6]  # per-sample kv sequence lengths
    Lq = 10
    L = max(Li)
    # dense-mask equivalent of the packed layout (kept for reference; unused below)
    attn_bias = torch.zeros(B, 1, Lq, L, device=dev)
    for i, x in enumerate(Li):
        attn_bias[i, 0, :, x:] = -torch.inf

    q = torch.randn(B, Lq, H, cq, generator=rng, device=dev)
    k = torch.randn(B, L, H, ckv, generator=rng, device=dev)
    v = torch.randn(B, L, H, ckv, generator=rng, device=dev)
    tq, tk, tv = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) # BHLc (reference layout; unused)

    seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
    # FIX: was `torch.torch.int32` — worked only because `torch.torch is torch`
    cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.int32), (1, 0))
    kv = torch.stack([k, v], dim=2)
    # pack per-sample kv (dropping padding) into one contiguous tensor
    kv_compact = torch.cat([kv[i, :Li[i]] for i in range(B)], dim=0)

    ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
    CrossAttention.forward  # no-op reference, handy for IDE jump-to-definition
    ca(q, (kv_compact, cu_seqlens_k, max(Li))).mean().backward()


if __name__ == '__main__':
    main()
Infinity/infinity/models/infinity.py ADDED
@@ -0,0 +1,817 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Definition of Infinity transformer model.
3
+ """
4
+
5
+ import math
6
+ import random
7
+ import time
8
+ from contextlib import nullcontext
9
+ from functools import partial
10
+ from typing import List, Optional, Tuple, Union, Dict, Any
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from timm.models import register_model
16
+ from torch.utils.checkpoint import checkpoint
17
+ from PIL import Image
18
+ import numpy as np
19
+
20
+ import infinity.utils.dist as dist
21
+ from infinity.utils.dist import for_visualize
22
+ from infinity.models.basic import flash_attn_func, flash_fused_op_installed, AdaLNBeforeHead, CrossAttnBlock, SelfAttnBlock, CrossAttention, FastRMSNorm, precompute_rope2d_freqs_grid
23
+ from infinity.utils import misc
24
+ from infinity.models.flex_attn import FlexAttn
25
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
26
+
27
+ try:
28
+ from infinity.models.fused_op import fused_ada_layer_norm, fused_ada_rms_norm
29
+ except:
30
+ fused_ada_layer_norm, fused_ada_rms_norm = None, None
31
+
32
+
33
class MultiInpIdentity(nn.Module):
    """Identity module that tolerates (and ignores) extra positional/keyword arguments,
    so it can stand in for modules with richer forward signatures."""

    def forward(self, x, *args, **kwargs):
        # Only the first input is returned; everything else is discarded.
        return x
36
+
37
+
38
class TextAttentivePool(nn.Module):
    """Pools packed text-token features into one vector per sample via cross-attention
    (CrossAttention in attn-pool mode supplies the learned query)."""

    def __init__(self, Ct5: int, D: int):
        super().__init__()
        self.Ct5, self.D = Ct5, D
        # Wider models use a smaller head_dim so the head count (Ct5 // head_dim) scales.
        self.head_dim = 64 if D > 4096 else 128
        self.num_heads = Ct5 // self.head_dim
        self.ca = CrossAttention(for_attn_pool=True, embed_dim=self.D, kv_dim=Ct5, num_heads=self.num_heads)

    def forward(self, ca_kv):
        # ca_kv: packed kv tuple (kv_compact, cu_seqlens_k, max_seqlen_k); no external query.
        pooled = self.ca(None, ca_kv)
        return pooled.squeeze(1)
51
+
52
class SharedAdaLin(nn.Linear):
    """Linear producing the 6 shared AdaLN components (gamma1/2, scale1/2, shift1/2),
    reshaped to (B, 1, 6, C) with dtype-aware weight handling."""

    def forward(self, cond_BD):
        num_components = 6
        C = self.weight.shape[0] // num_components
        # Imported lazily to avoid an import cycle with basic.py.
        from infinity.models.basic import get_weight_for_linear
        weight = get_weight_for_linear(self, target_dtype=cond_BD.dtype)
        out = F.linear(cond_BD, weight, self.bias)
        return out.reshape(-1, 1, num_components, C) # B16C
59
+
60
+
61
class MultipleLayers(nn.Module):
    """Registers the contiguous slice ls[index : index + num_blocks_in_a_chunk] as one
    module, so a chunk of transformer blocks can be run (and optionally gradient
    checkpointed) as a unit."""

    def __init__(self, ls, num_blocks_in_a_chunk, index):
        super().__init__()
        self.module = nn.ModuleList()
        for offset in range(num_blocks_in_a_chunk):
            self.module.append(ls[index + offset])

    def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, checkpointing_full_block=False, rope2d_freqs_grid=None):
        h = x
        for block in self.module:
            if checkpointing_full_block:
                # Trade compute for memory: recompute this block's activations in backward.
                h = torch.utils.checkpoint.checkpoint(block, h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False)
            else:
                h = block(h, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid)
        return h
76
+
77
+ class Infinity(nn.Module):
78
    def __init__(
        self, vae_local,
        text_channels=0, text_maxlen=0, # text-cond generation
        selecting_idx=None, # class-cond generation
        embed_dim=1024, depth=16, num_heads=16, mlp_ratio=4., # model's architecture
        drop_rate=0., drop_path_rate=0., # drop out and drop path
        norm_eps=1e-6, rms_norm=False, # norm layer
        shared_aln=False, head_aln=True, # adaptive norm
        cond_drop_rate=0.1, # for classifier-free guidance
        rand_uncond=False,
        cross_attn_layer_scale=-1., nm0=False, tau=1, cos_attn=True, swiglu=False,
        raw_scale_schedule=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        head_depth=1,
        top_p=0.0, top_k=0.0,
        customized_flash_attn=False, fused_mlp=False, fused_norm=False,
        block_chunks=1,
        checkpointing=None,
        pad_to_multiplier=0,
        use_flex_attn=False,
        batch_size=2,
        add_lvl_embeding_only_first_block=1,
        use_bit_label=1,
        rope2d_each_sa_layer=0,
        rope2d_normalized_by_hw=0,
        pn=None,
        train_h_div_w_list=None,
        video_frames=1,
        always_training_scales=20,
        apply_spatial_patchify = 0,
        inference_mode=False,
    ):
        """
        Build the Infinity transformer: text/class conditioning, level + position
        embeddings, a stack of (Cross|Self)AttnBlock's (optionally chunked), and the
        AdaLN output head. `vae_local` supplies the token space (embed_dim, vocab,
        bit-label mask).
        """
        # set hyperparameters
        # NOTE(review): plain attributes are assigned before super().__init__() below;
        # nn.Module tolerates non-module attributes here, but parameters/submodules are
        # only created after super().__init__() is called.
        self.C = embed_dim
        self.inference_mode = inference_mode
        self.apply_spatial_patchify = apply_spatial_patchify
        if self.apply_spatial_patchify:
            # 2x2 spatial patchify packs 4 VAE channels into one token
            self.d_vae = vae_local.embed_dim * 4
        else:
            self.d_vae = vae_local.embed_dim
        self.use_bit_label = use_bit_label
        self.codebook_dim = self.d_vae
        # binary (bit) labels predict 2 logits per codebook dim; else full VQ vocab
        self.V = (self.codebook_dim * 2) if self.use_bit_label else vae_local.vocab_size
        self.bit_mask = vae_local.quantizer.lfq.mask if self.use_bit_label else None
        self.Ct5 = text_channels
        self.depth = depth
        self.num_heads = num_heads
        self.batch_size = batch_size
        self.mlp_ratio = mlp_ratio
        self.cond_drop_rate = cond_drop_rate
        self.norm_eps = norm_eps
        self.prog_si = -1
        self.pn = pn
        self.train_h_div_w_list = train_h_div_w_list if train_h_div_w_list else h_div_w_templates
        self.video_frames = video_frames
        self.always_training_scales = always_training_scales

        assert add_lvl_embeding_only_first_block in [0,1]
        self.add_lvl_embeding_only_first_block = add_lvl_embeding_only_first_block
        assert rope2d_each_sa_layer in [0,1]
        self.rope2d_each_sa_layer = rope2d_each_sa_layer
        self.rope2d_normalized_by_hw = rope2d_normalized_by_hw
        print(f'self.codebook_dim: {self.codebook_dim}, self.add_lvl_embeding_only_first_block: {self.add_lvl_embeding_only_first_block}, \
self.use_bit_label: {self.use_bit_label}, self.rope2d_each_sa_layer: {rope2d_each_sa_layer}, self.rope2d_normalized_by_hw: {self.rope2d_normalized_by_hw}')
        # head_up_method is hard-coded off; word_patch_size > 1 is currently unreachable
        head_up_method = ''
        word_patch_size = 1 if head_up_method in {'', 'no'} else 2
        if word_patch_size > 1:
            assert all(raw_pn % word_patch_size == 0 for raw_pn in raw_scale_schedule), f'raw_scale_schedule={raw_scale_schedule}, not compatible with word_patch_size={word_patch_size}'

        self.checkpointing = checkpointing
        self.pad_to_multiplier = max(1, pad_to_multiplier)

        # detect whether the custom FlashAttention kernel build is present by inspecting
        # flash_attn_func's argument names
        customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
        self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
        if customized_flash_attn and not customized_kernel_installed:
            import inspect, warnings
            file_path = inspect.getsourcefile(flash_attn_func)
            line_number = inspect.getsourcelines(flash_attn_func)[1]
            info = (
                f'>>>>>> Customized FlashAttention2 is not installed or compiled, but specified in args by --flash=1. Set customized_flash_attn = False. <<<<<<\n'
                f'>>>>>> `flash_attn_func` is in [line {line_number}] [file {file_path}] <<<<<<\n'
                f'>>>>>> {flash_attn_func.__code__.co_varnames=} <<<<<<\n'
            )
            warnings.warn(info, ImportWarning)
            print(info, flush=True)

        self.raw_scale_schedule = raw_scale_schedule # 'raw' means before any patchifying
        self.first_l = 1
        # solve top-p top-k sampling hyperparameters
        self.top_p, self.top_k = max(min(top_p, 1), 0), (round(top_k * self.V) if 0 < top_k < 1 else round(top_k))
        if self.top_p < 1e-5: self.top_p = 0
        if self.top_k >= self.V or self.top_k <= 0: self.top_k = 0

        # sanity check: the fused op must be installed on all ranks or on none
        t = torch.zeros(dist.get_world_size(), device=dist.get_device())
        t[dist.get_rank()] = float(flash_fused_op_installed)
        dist.barrier()
        dist.allreduce(t)
        assert round(t.sum().item()) in {0, dist.get_world_size()}, f'flash_fused_op_installed: {t}'

        super().__init__()
        self.rng = torch.Generator(device=dist.get_device())
        self.maybe_record_function = nullcontext
        self.text_maxlen = text_maxlen
        self.t2i = text_channels != 0  # text-to-image mode iff text channels provided

        # [inp & position embedding]
        init_std = math.sqrt(1 / self.C / 3)
        self.norm0_cond = nn.Identity()
        if self.t2i:
            self.selecting_idx = None
            self.num_classes = 0
            self.D = self.C

            # fixed-seed unconditional text embedding for classifier-free guidance
            cfg_uncond = torch.empty(self.text_maxlen, self.Ct5)
            rng = torch.Generator(device='cpu')
            rng.manual_seed(0)
            torch.nn.init.trunc_normal_(cfg_uncond, std=1.2, generator=rng)
            cfg_uncond /= self.Ct5 ** 0.5
            if rand_uncond:
                # frozen buffer (not trained) when rand_uncond is set
                self.register_buffer('cfg_uncond', cfg_uncond)
            else:
                self.cfg_uncond = nn.Parameter(cfg_uncond)

            self.text_norm = FastRMSNorm(self.Ct5, elementwise_affine=True, eps=norm_eps)
            self.text_proj_for_sos = TextAttentivePool(self.Ct5, self.D)
            self.text_proj_for_ca = nn.Sequential(
                nn.Linear(self.Ct5, self.D),
                nn.GELU(approximate='tanh'),
                nn.Linear(self.D, self.D),
            )
        else: # class-label cond
            if selecting_idx is None:
                num_classes = 1000
                print(f'======= WARNING: selecting_idx not specified, set to 1/{num_classes} @ {dist.get_device()} =======')
                selecting_idx = torch.full((1, num_classes), fill_value=1/num_classes, dtype=torch.float32, device=dist.get_device())
            self.selecting_idx = selecting_idx
            self.num_classes = selecting_idx.shape[-1]
            self.D = self.C
            # +1 class slot for the unconditional (CFG) label
            self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
            nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)

        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
        if self.rope2d_each_sa_layer:
            # precompute 2D RoPE frequency grid shared by all self-attention layers
            rope2d_freqs_grid = precompute_rope2d_freqs_grid(dim=self.C//self.num_heads, dynamic_resolution_h_w=dynamic_resolution_h_w, pad_to_multiplier=self.pad_to_multiplier, rope2d_normalized_by_hw=self.rope2d_normalized_by_hw)
            self.rope2d_freqs_grid = rope2d_freqs_grid
        else:
            raise ValueError(f'self.rope2d_each_sa_layer={self.rope2d_each_sa_layer} not implemented')
        # level (scale-index) embedding; 15 slots — assumes <= 15 scales per schedule
        self.lvl_embed = nn.Embedding(15, self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # [input layers] input norm && input embedding
        norm_layer = partial(FastRMSNorm if rms_norm else nn.LayerNorm, eps=norm_eps)
        self.norm0_ve = norm_layer(self.d_vae) if nm0 else nn.Identity()
        self.word_embed = nn.Linear(self.d_vae, self.C)

        # [shared adaptive layernorm mapping network]
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6*self.C)) if shared_aln else nn.Identity()

        # fused norm
        if fused_norm:
            fused_norm_func = fused_ada_rms_norm if rms_norm else fused_ada_layer_norm
            if fused_norm_func is not None: # pre-compile
                B = 2
                x = torch.randn(B, 1, self.C).requires_grad_(True)
                scale = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
                shift = torch.randn(B, 1, self.C).mul_(0.01).requires_grad_(True)
                # fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale, shift=shift).mean().backward()
                del B, x, scale, shift
        else:
            fused_norm_func = None

        # [backbone and head]
        self.use_flex_attn = use_flex_attn
        self.attn_fn_compile_dict = {}
        self.batch_size = batch_size
        if self.use_flex_attn:
            self.attn_fn_compile_dict = self.compile_flex_attn()

        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # dpr means drop path rate (linearly increasing)
        # blocks are built into a plain list first; registration happens below
        self.unregistered_blocks = []
        for block_idx in range(depth):
            block = (CrossAttnBlock if self.t2i else SelfAttnBlock)(
                embed_dim=self.C, kv_dim=self.D, cross_attn_layer_scale=cross_attn_layer_scale, cond_dim=self.D, act=True, shared_aln=shared_aln, norm_layer=norm_layer,
                num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[block_idx], tau=tau, cos_attn=cos_attn,
                swiglu=swiglu, customized_flash_attn=self.customized_flash_attn, fused_mlp=fused_mlp, fused_norm_func=fused_norm_func,
                checkpointing_sa_only=self.checkpointing == 'self-attn',
                use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw,
            )
            self.unregistered_blocks.append(block)

        # [head]
        V = self.V
        if head_aln:
            self.head_nm = AdaLNBeforeHead(self.C, self.D, act=True, norm_layer=norm_layer, fused_norm_func=fused_norm_func)
            self.head = nn.Linear(self.C, V) if head_depth == 1 else nn.Sequential(nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))
        else:
            self.head_nm = MultiInpIdentity()
            self.head = nn.Sequential(norm_layer(self.C), nn.Linear(self.C, V)) if head_depth == 1 else nn.Sequential(norm_layer(self.C), nn.Linear(self.C, self.C, bias=True), nn.GELU(approximate='tanh'), nn.Linear(self.C, V))

        # register blocks either flat (ModuleList) or grouped into chunks for
        # coarser-grained gradient checkpointing
        self.num_block_chunks = block_chunks or 1
        self.num_blocks_in_a_chunk = depth // block_chunks
        print(f"{self.num_blocks_in_a_chunk=}, {depth=}, {block_chunks=}")
        assert self.num_blocks_in_a_chunk * block_chunks == depth
        if self.num_block_chunks == 1:
            self.blocks = nn.ModuleList(self.unregistered_blocks)
        else:
            self.block_chunks = nn.ModuleList()
            for i in range(self.num_block_chunks):
                self.block_chunks.append(MultipleLayers(self.unregistered_blocks, self.num_blocks_in_a_chunk, i*self.num_blocks_in_a_chunk))
        print(
            f'\n[constructor] ==== customized_flash_attn={self.customized_flash_attn} (using_flash={sum((b.sa.using_flash if self.t2i else b.attn.using_flash) for b in self.unregistered_blocks)}/{self.depth}), fused_mlp={fused_mlp} (fused_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.unregistered_blocks)}/{self.depth}) ==== \n'
            f'    [Infinity config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}, swiglu={swiglu} num_blocks_in_a_chunk={self.num_blocks_in_a_chunk}\n'
            f'    [drop ratios] drop_rate={drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )
294
+
295
+
296
    def compile_flex_attn(self):
        """
        Pre-build one FlexAttn kernel per scale-schedule tuple and return the cache
        dict keyed by that tuple. Inference builds a kernel per prefix length (for the
        KV-cache mask); training builds only the full-length schedule.
        """
        attn_fn_compile_dict = {}
        for h_div_w in self.train_h_div_w_list:
            # snap the requested aspect ratio to the nearest canonical template
            h_div_w_template = h_div_w_templates[np.argmin(np.abs(float(h_div_w) - h_div_w_templates))]
            full_scale_schedule = dynamic_resolution_h_w[h_div_w_template][self.pn]['scales']
            if self.inference_mode:
                apply_flex_attn_scales = list(range(1, 1+len(full_scale_schedule)))
                mask_type = "infinity_infer_mask_with_kv_cache"
                auto_padding = True
            else:
                mask_type = 'var'
                auto_padding = False
                apply_flex_attn_scales = [min(self.always_training_scales, len(full_scale_schedule))]
            for scales_num in apply_flex_attn_scales:
                print(f'====== apply flex attn hdivw: {h_div_w} scales: {scales_num} ======')
                scale_schedule = full_scale_schedule[:scales_num]
                # clamp temporal extent to the video frame budget (t stays 1 for images)
                scale_schedule = [ (min(t, self.video_frames//4+1), h, w) for (t,h, w) in scale_schedule]
                patchs_nums_tuple = tuple(scale_schedule)
                SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
                # round the total sequence length up to pad_to_multiplier
                aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
                attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
                                    mask_type = mask_type,
                                    B = self.batch_size,
                                    H = self.num_heads,
                                    L = aligned_L,
                                    auto_padding=auto_padding)
                attn_fn_compile_dict[patchs_nums_tuple] = attn_fn

                if self.video_frames > 1: # append image attn_fn when self.video_frames > 1 (namely videos)
                    # same schedule with t forced to 1, so single images can be attended too
                    scale_schedule = [ (1, h, w) for (t,h, w) in scale_schedule]
                    patchs_nums_tuple = tuple(scale_schedule)
                    SEQ_L = sum( pt * ph * pw for pt, ph, pw in patchs_nums_tuple)
                    aligned_L = SEQ_L+ (self.pad_to_multiplier - SEQ_L % self.pad_to_multiplier) if SEQ_L % self.pad_to_multiplier != 0 else SEQ_L
                    # NOTE(review): unlike the kernel above, auto_padding is not forwarded
                    # here — confirm whether the image kernel should also auto-pad.
                    attn_fn = FlexAttn(block_scales = patchs_nums_tuple,
                                        mask_type = mask_type,
                                        B = self.batch_size,
                                        H = self.num_heads,
                                        L = aligned_L)
                    attn_fn_compile_dict[patchs_nums_tuple] = attn_fn
        return attn_fn_compile_dict
336
+
337
+ def _apply_module_with_dtype_handling(self, module, x):
338
+ """
339
+ Apply a module (Linear, Sequential, etc.) with F16 weight dtype handling.
340
+ """
341
+ from infinity.models.basic import get_weight_for_linear
342
+
343
+ if isinstance(module, nn.Linear):
344
+ # Handle Linear layer with dtype conversion
345
+ weight = get_weight_for_linear(module, target_dtype=x.dtype)
346
+ return F.linear(x, weight, module.bias)
347
+ elif isinstance(module, nn.Sequential):
348
+ # Recursively apply each layer in the sequential
349
+ for layer in module:
350
+ x = self._apply_module_with_dtype_handling(layer, x)
351
+ return x
352
+ else:
353
+ # For other modules (GELU, LayerNorm, etc.), apply directly
354
+ return module(x)
355
+
356
    def get_logits(self, h: torch.Tensor, cond_BD: Optional[torch.Tensor]):
        """
        Project hidden states to vocabulary logits through the AdaLN head-norm and head.

        :param h: hidden_state, shaped (B or batch_size, L or seq_len, C or hidden_dim)
        :param cond_BD: shaped (B or batch_size, D or cond_dim)
        :return: logits, shaped (B, L, V or vocabulary_size)
        """
        # head-norm runs in full precision (autocast disabled, inputs upcast to float)
        with torch.amp.autocast('cuda', enabled=False):
            x = self.head_nm(h.float(), cond_BD.float())
        # the head is applied with dtype-aware Linear handling (F16 weights cast to x's dtype)
        return self._apply_module_with_dtype_handling(self.head, x)
366
+
367
    def add_lvl_embeding(self, feature, scale_ind, scale_schedule, need_to_pad=0):
        """
        Add the level (scale-index) embedding, in place, to the first t*h*w tokens of
        `feature`; trailing `need_to_pad` tokens are left untouched. Returns `feature`.
        """
        bs, seq_len, c = feature.shape
        patch_t, patch_h, patch_w = scale_schedule[scale_ind]
        t_mul_h_mul_w = patch_t * patch_h * patch_w
        # everything beyond t*h*w must be exactly the declared padding
        assert t_mul_h_mul_w + need_to_pad == seq_len
        # same embedding row (index = scale_ind) broadcast over all tokens of this scale
        feature[:, :t_mul_h_mul_w] += self.lvl_embed(scale_ind*torch.ones((bs, t_mul_h_mul_w),dtype=torch.int).to(feature.device))
        return feature
374
+
375
    def add_lvl_embeding_for_x_BLC(self, x_BLC, scale_schedule, need_to_pad=0):
        """
        Split x_BLC (B, L, C) along the sequence dim according to scale_schedule, add each
        scale's level embedding, and concatenate back (the trailing padding tokens, if
        any, are carried over unchanged).
        """
        ptr = 0
        x_BLC_list = []
        for scale_ind, patch_t_h_w in enumerate(scale_schedule):
            scale_seq_len = np.array(patch_t_h_w).prod()
            x_BLC_this_scale = x_BLC[:,ptr:ptr+scale_seq_len] # shape: [bs, patch_h*patch_w, c]
            ptr += scale_seq_len
            x_BLC_this_scale = self.add_lvl_embeding(x_BLC_this_scale, scale_ind, scale_schedule)
            x_BLC_list.append(x_BLC_this_scale)
        # all scales plus padding must account for the full sequence length
        assert x_BLC.shape[1] == (ptr + need_to_pad), f'{x_BLC.shape[1]} != {ptr} + {need_to_pad}'
        # keep the padding tail as-is
        x_BLC_list.append(x_BLC[:,ptr:])
        x_BLC = torch.cat(x_BLC_list, dim=1)
        return x_BLC
388
+
389
+ def forward(self, label_B_or_BLT: Union[torch.LongTensor, Tuple[torch.FloatTensor, torch.IntTensor, int]], x_BLC_wo_prefix: torch.Tensor, scale_schedule: List[Tuple[int]],
390
+ cfg_infer=False,
391
+ **kwargs,
392
+ ) -> Union[torch.Tensor, List[torch.Tensor]]: # returns logits_BLV
393
+ """
394
+ label_B_or_BLT: label_B or (kv_compact, cu_seqlens_k, max_seqlen_k)
395
+ :return: logits BLV, V is vocab_size
396
+ """
397
+ if cfg_infer:
398
+ return self.autoregressive_infer_cfg(label_B_or_BLT=label_B_or_BLT, scale_schedule=scale_schedule, **kwargs)
399
+
400
+ x_BLC_wo_prefix = x_BLC_wo_prefix.float() # input should be float32
401
+ B = x_BLC_wo_prefix.shape[0]
402
+
403
+ # [1. get input sequence x_BLC]
404
+ with torch.amp.autocast('cuda', enabled=False):
405
+ kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
406
+ # drop cond
407
+ total = 0
408
+ for le in lens:
409
+ if random.random() < self.cond_drop_rate:
410
+ kv_compact[total:total+le] = self.cfg_uncond[:le]
411
+ total += le
412
+ must_on_graph = self.cfg_uncond[0, 0] * 0
413
+ kv_compact = self.text_norm(kv_compact).contiguous()
414
+ sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)).float().contiguous() # cond_BD should be float32
415
+ kv_compact = self.text_proj_for_ca(kv_compact).contiguous()
416
+ kv_compact[0, 0] += must_on_graph
417
+ ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
418
+
419
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD).contiguous() # gss: gamma, scale, shift; cond_BD_or_gss should be float32
420
+
421
+ sos = sos.unsqueeze(1).expand(B, 1, -1) + self.pos_start.expand(B, 1, -1)
422
+ x_BLC = torch.cat((sos, self.word_embed(self.norm0_ve(x_BLC_wo_prefix))), dim=1)
423
+
424
+ # [1.1. pad the seqlen dim]
425
+ l_end = x_BLC.shape[1]
426
+ need_to_pad = (l_end + self.pad_to_multiplier - 1) // self.pad_to_multiplier * self.pad_to_multiplier - l_end # 0
427
+
428
+ if self.customized_flash_attn:
429
+ Infinity_visible_kvlen = self.Infinity_visible_kvlen[:l_end]
430
+ Infinity_invisible_qlen = self.Infinity_invisible_qlen[:l_end]
431
+ attn_bias_or_two_vector = (Infinity_visible_kvlen, Infinity_invisible_qlen)
432
+ # todo: solve need_to_pad here
433
+ elif self.use_flex_attn:
434
+ if need_to_pad:
435
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
436
+ assert x_BLC.shape[-1] % 128 == 0, 'x_BLC.shape[-1] % 128 != 0'
437
+ attn_bias_or_two_vector = None
438
+ else:
439
+ d: torch.Tensor = torch.cat([torch.full((pn[0]*pn[1]*pn[2],), i) for i, pn in enumerate(scale_schedule)]).view(1, l_end, 1)
440
+ dT = d.transpose(1, 2) # dT: 11L
441
+ attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, l_end, l_end)
442
+ attn_bias = attn_bias_for_masking[:, :, :l_end, :l_end].contiguous() # attn_bias: 11LL
443
+ if need_to_pad:
444
+ attn_bias = F.pad(attn_bias, (0, need_to_pad, 0, need_to_pad), value=-torch.inf)
445
+ attn_bias[0, 0, l_end:, 0] = 0
446
+ x_BLC = F.pad(x_BLC, (0, 0, 0, need_to_pad))
447
+ attn_bias_or_two_vector = attn_bias.type_as(x_BLC).to(x_BLC.device)
448
+
449
+ if self.use_flex_attn:
450
+ attn_fn = self.attn_fn_compile_dict[tuple(scale_schedule)]
451
+ else:
452
+ attn_fn = None
453
+
454
+ # [2. block loop]
455
+ SelfAttnBlock.forward, CrossAttnBlock.forward
456
+ checkpointing_full_block = self.checkpointing == 'full-block' and self.training
457
+ if self.num_block_chunks == 1:
458
+ for i, b in enumerate(self.blocks):
459
+ if self.add_lvl_embeding_only_first_block and i == 0:
460
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
461
+ if not self.add_lvl_embeding_only_first_block:
462
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
463
+ if checkpointing_full_block:
464
+ x_BLC = torch.utils.checkpoint.checkpoint(b, x_BLC, cond_BD_or_gss, ca_kv, attn_bias_or_two_vector, attn_fn, scale_schedule, self.rope2d_freqs_grid, use_reentrant=False)
465
+ else:
466
+ x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid)
467
+ else:
468
+ for i, chunk in enumerate(self.block_chunks): # this path
469
+ if self.add_lvl_embeding_only_first_block and i == 0:
470
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
471
+ if not self.add_lvl_embeding_only_first_block:
472
+ x_BLC = self.add_lvl_embeding_for_x_BLC(x_BLC, scale_schedule, need_to_pad)
473
+ x_BLC = chunk(x=x_BLC, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=attn_bias_or_two_vector, attn_fn=attn_fn, scale_schedule=scale_schedule, checkpointing_full_block=checkpointing_full_block, rope2d_freqs_grid=self.rope2d_freqs_grid)
474
+
475
+ # [3. unpad the seqlen dim, and then get logits]
476
+ return self.get_logits(x_BLC[:, :l_end], cond_BD) # return logits BLV, V is vocab_size
477
+
478
    @torch.no_grad()
    def autoregressive_infer_cfg(
        self,
        vae=None,
        scale_schedule=None,
        label_B_or_BLT=None,
        B=1, negative_label_B_or_BLT=None, force_gt_Bhw=None,
        g_seed=None, cfg_list=[], tau_list=[], cfg_sc=3, top_k=0, top_p=0.0,
        returns_vemb=0, ratio_Bl1=None, gumbel=0, norm_cfg=False,
        cfg_exp_k: float=0.0, cfg_insertion_layer=[-5],
        vae_type=0, softmax_merge_topk=-1, ret_img=False,
        trunk_scale=1000,
        gt_leak=0, gt_ls_Bl=None,
        inference_mode=False,
        save_img_path=None,
        sampling_per_bits=1,
    ): # returns List[idx_Bl]
        """Run multi-scale autoregressive inference with classifier-free guidance.

        One token map is sampled per entry of ``scale_schedule``; each sampled
        map is decoded to codes, accumulated, and fed back as the next stage's
        input. ``label_B_or_BLT`` is the packed text condition
        (kv_compact, lens, cu_seqlens_k, max_seqlen_k); when any cfg != 1 the
        batch is doubled with an unconditional (or explicit negative) copy.

        Returns (ret, idx_Bl_list, img_or_empty_list); ``img`` is a uint8 BGR
        batch only when ``ret_img`` is set.
        NOTE(review): several knobs (force_gt_Bhw, cfg_sc, ratio_Bl1, gumbel,
        norm_cfg, cfg_exp_k, softmax_merge_topk, save_img_path,
        sampling_per_bits) are accepted but unused in this body — presumably
        kept for interface compatibility; confirm against callers.
        """
        # Deterministic sampling only when a seed is given.
        if g_seed is None: rng = None
        else: self.rng.manual_seed(g_seed); rng = self.rng
        assert len(cfg_list) >= len(scale_schedule)
        assert len(tau_list) >= len(scale_schedule)

        # scale_schedule is used by infinity, vae_scale_schedule is used by vae if there exists a spatial patchify,
        # we need to convert scale_schedule to vae_scale_schedule by multiply 2 to h and w
        if self.apply_spatial_patchify:
            vae_scale_schedule = [(pt, 2*ph, 2*pw) for pt, ph, pw in scale_schedule]
        else:
            vae_scale_schedule = scale_schedule

        kv_compact, lens, cu_seqlens_k, max_seqlen_k = label_B_or_BLT
        if any(np.array(cfg_list) != 1):
            # CFG active: duplicate the batch — first half conditional, second half
            # unconditional (cfg_uncond) or the supplied negative prompt.
            bs = 2*B
            if not negative_label_B_or_BLT:
                kv_compact_un = kv_compact.clone()
                total = 0
                for le in lens:
                    kv_compact_un[total:total+le] = (self.cfg_uncond)[:le]
                    total += le
                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k[1:]+cu_seqlens_k[-1]), dim=0)
            else:
                kv_compact_un, lens_un, cu_seqlens_k_un, max_seqlen_k_un = negative_label_B_or_BLT
                kv_compact = torch.cat((kv_compact, kv_compact_un), dim=0)
                cu_seqlens_k = torch.cat((cu_seqlens_k, cu_seqlens_k_un[1:]+cu_seqlens_k[-1]), dim=0)
                max_seqlen_k = max(max_seqlen_k, max_seqlen_k_un)
        else:
            bs = B

        kv_compact = self.text_norm(kv_compact)
        sos = cond_BD = self.text_proj_for_sos((kv_compact, cu_seqlens_k, max_seqlen_k)) # sos shape: [2, 4096]
        kv_compact = self.text_proj_for_ca(kv_compact) # kv_compact shape: [304, 4096]
        ca_kv = kv_compact, cu_seqlens_k, max_seqlen_k
        last_stage = sos.unsqueeze(1).expand(bs, 1, -1) + self.pos_start.expand(bs, 1, -1)

        # shared_ada_lin is evaluated in full precision regardless of autocast.
        with torch.amp.autocast('cuda', enabled=False):
            cond_BD_or_gss = self.shared_ada_lin(cond_BD.float()).float().contiguous()
        accu_BChw, cur_L, ret = None, 0, [] # current length, list of reconstructed images
        idx_Bl_list, idx_Bld_list = [], []

        # Enable KV caching for the AR loop (two layouts: flat blocks vs chunked).
        if inference_mode:
            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(True)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(True)

        # Decode where CFG mixing happens: 0 -> on logits, 1 -> on probs,
        # negative -> after that (from-the-end) transformer layer.
        abs_cfg_insertion_layers = []
        add_cfg_on_logits, add_cfg_on_probs = False, False
        leng = len(self.unregistered_blocks)
        for item in cfg_insertion_layer:
            if item == 0: # add cfg on logits
                add_cfg_on_logits = True
            elif item == 1: # add cfg on probs
                add_cfg_on_probs = True # todo in the future, we may want to add cfg on logits and probs
            elif item < 0: # determine to add cfg at item-th layer's output
                assert leng+item > 0, f'cfg_insertion_layer: {item} is not valid since len(unregistered_blocks)={self.num_block_chunks}'
                abs_cfg_insertion_layers.append(leng+item)
            else:
                raise ValueError(f'cfg_insertion_layer: {item} is not valid')

        num_stages_minus_1 = len(scale_schedule)-1
        summed_codes = 0
        for si, pn in enumerate(scale_schedule): # si: i-th segment
            cfg = cfg_list[si]
            if si >= trunk_scale:
                break
            cur_L += np.array(pn).prod()

            need_to_pad = 0
            attn_fn = None
            if self.use_flex_attn:
                # need_to_pad = (self.pad_to_multiplier - cur_L % self.pad_to_multiplier) % self.pad_to_multiplier
                # if need_to_pad:
                #     last_stage = F.pad(last_stage, (0, 0, 0, need_to_pad))
                attn_fn = self.attn_fn_compile_dict.get(tuple(scale_schedule[:(si+1)]), None)

            # assert self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].sum() == 0, f'AR with {(self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L] != 0).sum()} / {self.attn_bias_for_masking[:, :, last_L:cur_L, :cur_L].numel()} mask item'
            layer_idx = 0
            for block_idx, b in enumerate(self.block_chunks):
                # last_stage shape: [4, 1, 2048], cond_BD_or_gss.shape: [4, 1, 6, 2048], ca_kv[0].shape: [64, 2048], ca_kv[1].shape [5], ca_kv[2]: int
                if self.add_lvl_embeding_only_first_block and block_idx == 0:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)
                if not self.add_lvl_embeding_only_first_block:
                    last_stage = self.add_lvl_embeding(last_stage, si, scale_schedule, need_to_pad=need_to_pad)

                for m in b.module:
                    last_stage = m(x=last_stage, cond_BD=cond_BD_or_gss, ca_kv=ca_kv, attn_bias_or_two_vector=None, attn_fn=attn_fn, scale_schedule=scale_schedule, rope2d_freqs_grid=self.rope2d_freqs_grid, scale_ind=si)
                    if (cfg != 1) and (layer_idx in abs_cfg_insertion_layers):
                        # print(f'add cfg={cfg} on {layer_idx}-th layer output')
                        # Mix cond/uncond halves, then re-duplicate so later
                        # layers (with cached KV for both halves) stay batched.
                        last_stage = cfg * last_stage[:B] + (1-cfg) * last_stage[B:]
                        last_stage = torch.cat((last_stage, last_stage), 0)
                    layer_idx += 1

            if (cfg != 1) and add_cfg_on_logits:
                # print(f'add cfg on add_cfg_on_logits')
                logits_BlV = self.get_logits(last_stage, cond_BD).mul(1/tau_list[si])
                logits_BlV = cfg * logits_BlV[:B] + (1-cfg) * logits_BlV[B:]
            else:
                logits_BlV = self.get_logits(last_stage[:B], cond_BD[:B]).mul(1/tau_list[si])

            if self.use_bit_label:
                # Binary-label head: each position predicts d independent bits.
                tmp_bs, tmp_seq_len = logits_BlV.shape[:2]
                logits_BlV = logits_BlV.reshape(tmp_bs, -1, 2)
                idx_Bld = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
                idx_Bld = idx_Bld.reshape(tmp_bs, tmp_seq_len, -1)
            else:
                idx_Bl = sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV, rng=rng, top_k=top_k or self.top_k, top_p=top_p or self.top_p, num_samples=1)[:, :, 0]
            if vae_type != 0:
                assert returns_vemb
                if si < gt_leak:
                    # Teacher forcing: use ground-truth labels for the first scales.
                    idx_Bld = gt_ls_Bl[si]
                else:
                    assert pn[0] == 1
                    idx_Bld = idx_Bld.reshape(B, pn[1], pn[2], -1) # shape: [B, h, w, d] or [B, h, w, 4d]
                    if self.apply_spatial_patchify: # unpatchify operation
                        idx_Bld = idx_Bld.permute(0,3,1,2) # [B, 4d, h, w]
                        idx_Bld = torch.nn.functional.pixel_shuffle(idx_Bld, 2) # [B, d, 2h, 2w]
                        idx_Bld = idx_Bld.permute(0,2,3,1) # [B, 2h, 2w, d]
                    idx_Bld = idx_Bld.unsqueeze(1) # [B, 1, h, w, d] or [B, 1, 2h, 2w, d]

                idx_Bld_list.append(idx_Bld)
                codes = vae.quantizer.lfq.indices_to_codes(idx_Bld, label_type='bit_label') # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
                if si != num_stages_minus_1:
                    # Accumulate codes at full resolution; next-stage input is the
                    # running sum resampled to the next scale.
                    summed_codes += F.interpolate(codes, size=vae_scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
                    last_stage = F.interpolate(summed_codes, size=vae_scale_schedule[si+1], mode=vae.quantizer.z_interplote_up) # [B, d, 1, h, w] or [B, d, 1, 2h, 2w]
                    last_stage = last_stage.squeeze(-3) # [B, d, h, w] or [B, d, 2h, 2w]
                    if self.apply_spatial_patchify: # patchify operation
                        last_stage = torch.nn.functional.pixel_unshuffle(last_stage, 2) # [B, 4d, h, w]
                    last_stage = last_stage.reshape(*last_stage.shape[:2], -1) # [B, d, h*w] or [B, 4d, h*w]
                    last_stage = torch.permute(last_stage, [0,2,1]) # [B, h*w, d] or [B, h*w, 4d]
                else:
                    summed_codes += codes
            else:
                if si < gt_leak:
                    idx_Bl = gt_ls_Bl[si]
                h_BChw = self.quant_only_used_in_inference[0].embedding(idx_Bl).float() # BlC

                # h_BChw = h_BChw.float().transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1])
                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.d_vae, scale_schedule[si][0], scale_schedule[si][1], scale_schedule[si][2])
                ret.append(h_BChw if returns_vemb != 0 else idx_Bl)
                idx_Bl_list.append(idx_Bl)
                if si != num_stages_minus_1:
                    accu_BChw, last_stage = self.quant_only_used_in_inference[0].one_step_fuse(si, num_stages_minus_1+1, accu_BChw, h_BChw, scale_schedule)

            if si != num_stages_minus_1:
                # Embed the next-stage input and re-duplicate for the CFG batch.
                last_stage = self.word_embed(self.norm0_ve(last_stage))
                last_stage = last_stage.repeat(bs//B, 1, 1)

        # Disable KV caching again so the module is reusable.
        if inference_mode:
            for b in self.unregistered_blocks: (b.sa if isinstance(b, CrossAttnBlock) else b.attn).kv_caching(False)
        else:
            assert self.num_block_chunks > 1
            for block_chunk_ in self.block_chunks:
                for module in block_chunk_.module.module:
                    (module.sa if isinstance(module, CrossAttnBlock) else module.attn).kv_caching(False)

        if not ret_img:
            return ret, idx_Bl_list, []

        if vae_type != 0:
            img = vae.decode(summed_codes.squeeze(-3))
        else:
            img = vae.viz_from_ms_h_BChw(ret, scale_schedule=scale_schedule, same_shape=True, last_one=True)

        # Map [-1, 1] -> uint8 [0, 255]; channel flip converts RGB -> BGR.
        img = (img + 1) / 2
        img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,))
        return ret, idx_Bl_list, img
666
+
667
+ @for_visualize
668
+ def vis_key_params(self, ep):
669
+ return
670
+
671
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=False, assign=False):
672
+ for k in state_dict:
673
+ if 'cfg_uncond' in k:
674
+ old, new = state_dict[k], self.cfg_uncond.data
675
+ min_tlen = min(old.shape[0], new.shape[0])
676
+ if min_tlen == old.shape[0]:
677
+ state_dict[k] = torch.cat((old.to(device=new.device, dtype=new.dtype), new[min_tlen:]))
678
+ else:
679
+ state_dict[k] = old[:min_tlen]
680
+
681
+ for buf_name in ('lvl_1L', 'attn_bias_for_masking', 'Infinity_visible_kvlen', 'Infinity_invisible_qlen'):
682
+ state_dict.pop(buf_name, None)
683
+ if hasattr(self, buf_name):
684
+ state_dict[buf_name] = getattr(self, buf_name)
685
+
686
+ return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
687
+
688
    def special_init(
        self,
        aln_init: float,       # multiplier for AdaLN scale/shift rows (and the head's AdaLN)
        aln_gamma_init: float,  # multiplier for AdaLN gamma rows
        scale_head: float,     # multiplier for the output head weights; skipped when negative
        scale_proj: int,       # 1 -> depth-uniform projection scaling; otherwise no proj rescale here
    ):
        """Rescale selected weights after default init (heads, projections, AdaLN)."""
        # init head's norm
        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(aln_init) # there's no gamma for head
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        # init head's proj
        if scale_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(scale_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                # Only the final linear of a sequential head is rescaled.
                self.head[-1].weight.data.mul_(scale_head)
                self.head[-1].bias.data.zero_()

        depth = len(self.unregistered_blocks)
        for block_idx, sab in enumerate(self.unregistered_blocks):
            sab: Union[SelfAttnBlock, CrossAttnBlock]
            # init proj
            # NOTE(review): `scale` is only applied on the scale_proj == 1 path;
            # the depth-dependent alternative (2*(1 + block_idx)) is computed but
            # currently unused — confirm whether other scale_proj modes are expected.
            scale = 1 / math.sqrt(2*depth if scale_proj == 1 else 2*(1 + block_idx))
            if scale_proj == 1:
                if self.t2i:
                    sab.sa.proj.weight.data.mul_(scale)
                    sab.ca.proj.weight.data.mul_(scale)
                else:
                    sab.attn.proj.weight.data.mul_(scale)
                sab.ffn.fc2.weight.data.mul_(scale)
            # if sab.using_swiglu:
            #     nn.init.ones_(sab.ffn.fcg.bias)
            #     nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)

            # init ada_lin
            if hasattr(sab, 'ada_lin'):
                lin = sab.ada_lin[-1]
                # Rows [0, 2C) hold the two gammas; rows [2C, 6C) hold scale/shift.
                lin.weight.data[:2*self.C].mul_(aln_gamma_init) # init gamma
                lin.weight.data[2*self.C:].mul_(aln_init) # init scale and shift
                if hasattr(lin, 'bias') and lin.bias is not None:
                    lin.bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, :2, :].mul_(aln_gamma_init) # init gamma
                sab.ada_gss.data[:, :, 2:, :].mul_(aln_init) # init scale and shift
736
+
737
+ def extra_repr(self):
738
+ return f'drop_path_rate={self.drop_path_rate}'
739
+
740
+ def get_layer_id_and_scale_exp(self, para_name: str):
741
+ raise NotImplementedError
742
+
743
+
744
def sample_with_top_k_top_p_also_inplace_modifying_logits_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
    """Sample token ids from logits with optional top-k / top-p (nucleus) filtering.

    Filtering is applied IN PLACE, so `logits_BlV` is modified by this call.
    A negative `num_samples` requests abs(num_samples) draws without
    replacement. Returns indices shaped (B, l, num_samples).
    """
    B, l, V = logits_BlV.shape
    if top_k > 0:
        # Everything strictly below the k-th best logit is masked out.
        k = min(top_k, V)
        kth_best = logits_BlV.topk(k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(logits_BlV < kth_best, -torch.inf)
    if top_p > 0:
        # Drop the low-probability tail whose cumulative mass is <= 1 - top_p.
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        drop_sorted = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        drop_sorted[..., -1:] = False  # always keep the single most likely token
        drop = drop_sorted.scatter(sorted_idx.ndim - 1, sorted_idx, drop_sorted)
        logits_BlV.masked_fill_(drop, -torch.inf)
    # multinomial only accepts 2D input, so flatten (B, l) together.
    probs_2d = logits_BlV.softmax(dim=-1).view(-1, V)
    with_replacement = num_samples >= 0
    draws = abs(num_samples)
    return torch.multinomial(probs_2d, num_samples=draws, replacement=with_replacement, generator=rng).view(B, l, draws)
759
+
760
def sampling_with_top_k_top_p_also_inplace_modifying_probs_(probs_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l)
    """Sample token ids from a probability tensor with top-k / top-p filtering.

    Mirror of sample_with_top_k_top_p_also_inplace_modifying_logits_ that takes
    probabilities instead of logits. Filtering zeroes entries IN PLACE; a
    negative `num_samples` requests abs(num_samples) draws without replacement.
    Returns indices shaped (B, l, num_samples).
    """
    B, l, V = probs_BlV.shape
    if top_k > 0:
        # Zero out everything strictly below the k-th largest probability.
        top_k = min(top_k, V)
        idx_to_remove = probs_BlV < probs_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        probs_BlV.masked_fill_(idx_to_remove, 0)
    if top_p > 0:
        sorted_probs, sorted_idx = probs_BlV.sort(dim=-1, descending=False)
        # BUGFIX: the inputs are already probabilities, so the cumulative mass is
        # built directly from the (re-normalized) sorted probs. The previous code
        # applied softmax to probabilities, which flattens the distribution and
        # makes the nucleus cutoff far too permissive.
        cum_mass = (sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)).cumsum_(dim=-1)
        sorted_idx_to_remove = cum_mass <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False  # always keep the most likely token
        probs_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), 0)
    # Re-normalize after filtering; multinomial only accepts 2D input.
    probs_BlV = probs_BlV / probs_BlV.sum(-1, keepdims=True)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(probs_BlV.view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
776
+
777
+
778
def get_params_num(d, w, mlp):
    """Estimate the transformer's parameter count as a '<x.xx>B' string.

    d: number of layers; w: model width; mlp: MLP expansion ratio (the hidden
    width is rounded to a multiple of 256).
    """
    hidden = round(mlp * w / 256) * 256
    total = d * (w * w * 8 + w * hidden * 2)  # per-layer sa+ca attention + mlp
    total += w * w * 6   # shared adaln
    total += 4096 * w    # prediction head
    total += 32 * w      # word embedding

    t5_width = 4096
    total += t5_width * w * 4        # T5 attention pooling
    total += t5_width * w + w * w    # T5 mlp
    return f'{total/1e9:.2f}B'
789
+
790
+
791
# kwargs injected by timm's create_model() that Infinity's constructor does not accept;
# the factory functions below strip these before forwarding **kwargs.
TIMM_KEYS = {'img_size', 'pretrained', 'pretrained_cfg', 'pretrained_cfg_overlay', 'global_pool'}
792
+
793
@register_model
def infinity_2b(depth=32, embed_dim=2048, num_heads=2048//128, drop_path_rate=0.1, **kwargs):
    """Infinity-2B preset: 32 layers, width 2048, 16 heads."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
795
+
796
@register_model
def infinity_20b(depth=58, embed_dim=4608, num_heads=4608//128, drop_path_rate=0.25, **kwargs):
    """Infinity-20B preset: 58 layers, width 4608, 36 heads."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
798
+
799
+ # model configuration for scaling Infinity transformer
800
@register_model
def infinity_layer12(depth=12, embed_dim=768, num_heads=8, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 12 layers, width 768."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
803
@register_model
def infinity_layer16(depth=16, embed_dim=1152, num_heads=12, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 16 layers, width 1152."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
806
@register_model
def infinity_layer24(depth=24, embed_dim=1536, num_heads=16, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 24 layers, width 1536."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
809
@register_model
def infinity_layer32(depth=32, embed_dim=2080, num_heads=20, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 32 layers, width 2080."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
812
@register_model
def infinity_layer40(depth=40, embed_dim=2688, num_heads=24, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 40 layers, width 2688."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
815
@register_model
def infinity_layer48(depth=48, embed_dim=3360, num_heads=28, drop_path_rate=0.1, **kwargs):
    """Scaling-study preset: 48 layers, width 3360."""
    extra = {k: v for k, v in kwargs.items() if k not in TIMM_KEYS}
    return Infinity(depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=4, drop_path_rate=drop_path_rate, **extra)
Infinity/infinity_vae_d32_reg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a37fa3ea1b2a1ebd23de61d91a5e68202825e5a67edaef4b7c55f5fd5b9cf26
3
+ size 1557324701
README.md CHANGED
@@ -1,3 +1,162 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Infinity-2B GGUF with SageAttention
2
+
3
+ Unofficial Q8_0 GGUF quantization of Infinity-2B with **SageAttention** support for even faster generation.
4
+
5
+ ## Features
6
+
7
+ ✨ **SageAttention Integration** - 2-5x faster than FlashAttention with automatic fallback
8
+ 🎨 **Gradio Web UI** - Easy-to-use interface for image generation
9
+ 💾 **Q8_0 Quantization** - ~75% memory reduction with minimal quality loss
10
+ 🚀 **Optimized Inference** - T5 encoder on CPU, efficient VRAM usage
11
+ 🔧 **GGUF Support** - On-the-fly dequantization with flexible deployment
12
+
13
+ ## Quick Start
14
+
15
+ ### Web UI (Recommended)
16
+
17
+ ```bash
18
+ python gradio_webui.py --autoload
19
+ ```
20
+
21
+ Then open `http://127.0.0.1:7860` in your browser.
22
+
23
+ ### Command Line
24
+
25
+ ```bash
26
+ python generate_image_2b_q8_gguf.py \
27
+ --prompt "an astronaut riding a horse on the moon" \
28
+ --output output.png
29
+ ```
30
+
31
+ ## Installation
32
+
33
+ ### 1. Basic Requirements
34
+
35
+ ```bash
36
+ pip install -r Infinity/requirements.txt
37
+ pip install gradio gguf
38
+ ```
39
+
40
+ ### 2. Install SageAttention (Optional, Recommended)
41
+
42
+ For faster generation:
43
+
44
+ ```bash
45
+ pip install sageattention>=2.2.0 --no-build-isolation
46
+ ```
47
+
48
+ **Requirements**: CUDA ≥12.0 (CUDA 12.8+ for Blackwell GPUs like RTX 50-series)
49
+
50
+ **Note**: SageAttention is optional. The code automatically falls back to:
51
+ 1. SageAttention (if installed) - 2-5x faster ✨
52
+ 2. FlashAttention (if available) - faster than PyTorch
53
+ 3. PyTorch SDPA (always works) - built-in fallback
54
+
55
+ ### 3. Download Models
56
+
57
+ You'll need:
58
+ - `infinity_2b_reg_Q8_0.gguf` - Infinity-2B model (~2.1 GB)
59
+ - `flan-t5-xl-encoder-Q8_0.gguf` - T5 text encoder (~1.0 GB)
60
+ - `Infinity/infinity_vae_d32_reg.pth` - VAE decoder (~0.5 GB)
61
+
62
+ ## Memory Requirements
63
+
64
+ | Component | VRAM Usage |
65
+ |-----------|-----------|
66
+ | Infinity-2B (Q8_0) | ~2.5 GB |
67
+ | VAE | ~0.5 GB |
68
+ | Working Memory | ~1-2 GB |
69
+ | **Total (1M res)** | **~4-5 GB** |
70
+
71
+ **T5 encoder runs on CPU** to save VRAM!
72
+
73
+ Recommended: **8GB+ VRAM** for comfortable 1M (1024×1024) generation
74
+
75
+ ## Web UI Features
76
+
77
+ The Gradio web interface provides:
78
+
79
+ - **Model Management**: Load models once, reuse for all generations
80
+ - **Full Parameter Control**: CFG scale, tau, resolution, aspect ratio, seed
81
+ - **Real-time Preview**: See your images as they generate
82
+ - **Progress Tracking**: Visual feedback during loading and generation
83
+ - **Clean Layout**: Model paths banner, settings on left, output on right
84
+
85
+ ### Web UI Options
86
+
87
+ ```bash
88
+ # Basic usage
89
+ python gradio_webui.py
90
+
91
+ # Auto-load models on startup (faster)
92
+ python gradio_webui.py --autoload
93
+
94
+ # Create public share link
95
+ python gradio_webui.py --share
96
+
97
+ # Custom port
98
+ python gradio_webui.py --server-port 8080
99
+
100
+ # Full options
101
+ python gradio_webui.py \
102
+ --autoload \
103
+ --server-port 7860 \
104
+ --infinity-gguf path/to/infinity.gguf \
105
+ --t5-gguf path/to/t5.gguf \
106
+ --vae-path path/to/vae.pth
107
+ ```
108
+
109
+ ## Command-Line Options
110
+
111
+ ```bash
112
+ python generate_image_2b_q8_gguf.py [OPTIONS]
113
+ ```
114
+
115
+ | Option | Description | Default |
116
+ |--------|-------------|---------|
117
+ | `--prompt TEXT` | Text prompt for image generation | "an astronaut..." |
118
+ | `--infinity-gguf PATH` | Path to Infinity GGUF file | infinity_2b_reg_Q8_0.gguf |
119
+ | `--t5-gguf PATH` | Path to T5 encoder GGUF | flan-t5-xl-encoder-Q8_0.gguf |
120
+ | `--vae-path PATH` | Path to VAE checkpoint | Infinity/infinity_vae_d32_reg.pth |
121
+ | `--output PATH` | Output image path | output.png |
122
+ | `--cfg-scale FLOAT` | CFG scale (1.0-10.0) | 3.0 |
123
+ | `--tau FLOAT` | Temperature (0.1-1.0) | 0.5 |
124
+ | `--seed INT` | Random seed for reproducibility | 42 |
125
+ | `--pn {0.06M,0.25M,1M}` | Resolution preset | 1M |
126
+ | `--aspect-ratio FLOAT` | Aspect ratio (height/width) | 1.0 |
127
+
128
+
129
+ ## Technical Details
130
+
131
+ ### Quantization
132
+
133
+ - **Q8_0 format**: 8-bit quantization with minimal quality loss
134
+ - **On-the-fly dequantization**: Using custom GGUFLinear layers
135
+ - **Memory savings**: ~75% reduction vs FP16
136
+ - **Quality**: Nearly identical to FP16
137
+
138
+ ### Architecture
139
+
140
+ - **Infinity-2B**: 2.0B parameters, embed_dim=2048, depth=32
141
+ - **T5-XL Encoder**: 2048-dim text embeddings
142
+ - **VAE**: d32 with dynamic resolution support
143
+
144
+ ### GGUF Support
145
+
146
+ The implementation includes:
147
+ - Import utilities for GGUF tensors
148
+ - Custom `GGUFLinear` layers for on-the-fly dequantization
149
+ - Patched attention mechanisms for compatibility
150
+ - F16 dtype handling for head layers
151
+
152
+ See [patch_infinity_for_gguf.sh](patch_infinity_for_gguf.sh) for implementation details.
153
+
154
+ ## Credits
155
+
156
+ - **Original Model**: [Infinity by FoundationVision](https://github.com/FoundationVision/Infinity)
157
+ - **SageAttention**: [thu-ml/SageAttention](https://github.com/thu-ml/SageAttention)
158
+ - **GGUF Format**: [ggerganov/ggml](https://github.com/ggerganov/ggml)
159
+
160
+ ## License
161
+
162
+ MIT
flan-t5-xl-encoder-Q8_0.gguf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d212c960e07faf2323e2136cb03e62578a8f6862f13709f480684da9f5d9a2e6
3
+ size 1563507296
generate_image_2b_q8_gguf.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Generate images using quantized Infinity-2B model (GGUF format)
4
+ Loads T5 text encoder from GGUF on CPU, Infinity model from GGUF on GPU
5
+ """
6
+
7
+ import os
8
+ import sys
9
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warning; must be set before tokenizers import

# Add Infinity to Python path (assumes Infinity repo is in same directory as this script)
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
INFINITY_PATH = os.path.join(SCRIPT_DIR, 'Infinity')
if os.path.exists(INFINITY_PATH):
    # Prepend so the local (patched) Infinity checkout shadows any installed copy.
    sys.path.insert(0, INFINITY_PATH)
else:
    print(f"Warning: Infinity repo not found at {INFINITY_PATH}")
    print("Please clone the Infinity repo and run patch_infinity_for_gguf.sh")
20
+ import time
21
+ import argparse
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+ import cv2
26
+ from typing import List
27
+ import gguf
28
+
29
+ # Import existing utilities
30
+ from infinity_gguf_utils import (
31
+ load_gguf_state_dict,
32
+ load_gguf_state_dict_with_params,
33
+ _replace_with_gguf_linear,
34
+ GGUFParameter,
35
+ dequantize_gguf_tensor,
36
+ GGUFLinear
37
+ )
38
+
39
+ # Import Infinity model and utilities
40
+ from infinity.models.infinity import Infinity
41
+ from infinity.models.bsq_vae.vae import vae_model
42
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
43
+
44
+ # Import transformers for tokenizer
45
+ from transformers import AutoTokenizer
46
+
47
+
48
def load_t5_tokenizer_from_gguf(gguf_path):
    """
    Return a T5 tokenizer compatible with the GGUF-encoded T5 encoder.

    The T5 vocabulary is standard, so the stock "google/t5-v1_1-xxl"
    tokenizer is used instead of parsing tokenizer metadata out of the GGUF
    file; `gguf_path` is accepted for interface symmetry but currently unused.
    Falls back to AutoTokenizer (local cache) when the fast tokenizer cannot
    be loaded. max length is capped at 512 tokens either way.
    """
    print("[Loading T5 Tokenizer]")
    try:
        from transformers import T5TokenizerFast
        tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl", legacy=True)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being swallowed by the fallback.
        print("Warning: Could not load T5 tokenizer from HuggingFace, trying local cache...")
        tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl", legacy=True)
    tokenizer.model_max_length = 512
    return tokenizer
67
+
68
+
69
def load_t5_encoder_from_gguf(gguf_path, device='cpu'):
    """
    Load the T5 encoder from a GGUF file and keep it on `device` (CPU by default).

    Based on the ComfyUI-GGUF loader implementation. Tensor names are remapped
    from llama.cpp conventions to HuggingFace T5 names, quantized tensors are
    dequantized to float16, and the weights are loaded into a freshly
    constructed `transformers.T5EncoderModel` (config for flan-t5-xl, width
    2048). Returns the model in eval mode with gradients disabled.
    """
    print(f"[Loading T5 Encoder from GGUF: {gguf_path}]")
    print(f"[T5 will be kept on {device}]")

    # Apply NumPy 2.0 compatibility patch if needed
    # NOTE(review): assigning an attribute on np.ndarray (a C extension type)
    # normally raises TypeError — verify this patch actually works on NumPy 2.x.
    import numpy as np
    if not hasattr(np.ndarray, 'newbyteorder'):
        def newbyteorder(self, new_order):
            return self.view(self.dtype.newbyteorder(new_order))
        np.ndarray.newbyteorder = newbyteorder

    # Load GGUF state dict
    from gguf import GGUFReader
    reader = GGUFReader(gguf_path)

    # Map llama.cpp T5 keys to HuggingFace T5 keys
    T5_SD_MAP = {
        "enc.": "encoder.",
        ".blk.": ".block.",
        "token_embd": "shared",
        "output_norm": "final_layer_norm",
        "attn_q": "layer.0.SelfAttention.q",
        "attn_k": "layer.0.SelfAttention.k",
        "attn_v": "layer.0.SelfAttention.v",
        "attn_o": "layer.0.SelfAttention.o",
        "attn_norm": "layer.0.layer_norm",
        "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
        "ffn_up": "layer.1.DenseReluDense.wi_1",
        "ffn_down": "layer.1.DenseReluDense.wo",
        "ffn_gate": "layer.1.DenseReluDense.wi_0",
        "ffn_norm": "layer.1.layer_norm",
    }

    # Load and convert tensors
    state_dict = {}
    print("Loading T5 tensors from GGUF...")
    for tensor in reader.tensors:
        tensor_name = tensor.name

        # Apply key mapping (plain substring replacement over each map entry)
        for old_key, new_key in T5_SD_MAP.items():
            tensor_name = tensor_name.replace(old_key, new_key)

        # Load tensor data
        torch_tensor = torch.from_numpy(np.array(tensor.data))

        # Determine shape (GGUF stores dimensions in reverse order)
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))

        # Check if quantized (anything other than plain F32/F16 payloads)
        is_quantized = tensor.tensor_type not in {
            gguf.GGMLQuantizationType.F32,
            gguf.GGMLQuantizationType.F16
        }

        if is_quantized:
            # Dequantize to float16 for CPU inference
            # print(f"  Dequantizing {tensor_name} ({tensor.tensor_type})...")
            param = GGUFParameter(torch_tensor, quant_type=tensor.tensor_type)
            dequant_tensor = dequantize_gguf_tensor(param, target_dtype=torch.float16)
            state_dict[tensor_name] = dequant_tensor.to(device)
        else:
            # Already F32 or F16; reshape and downcast F32 -> F16 for memory
            torch_tensor = torch_tensor.view(*shape)
            if tensor.tensor_type == gguf.GGMLQuantizationType.F32:
                state_dict[tensor_name] = torch_tensor.to(torch.float16).to(device)
            else:
                state_dict[tensor_name] = torch_tensor.to(device)

    print(f"Loaded {len(state_dict)} tensors for T5 encoder")

    # Load T5 model architecture from transformers
    from transformers import T5EncoderModel, T5Config

    # Create T5 config - for T5-XL (2048 dims, not XXL which is 4096)
    # Try to load from local directory first, fall back to download if needed
    try:
        config = T5Config.from_pretrained("./flan-t5-xl-official")
        print("Loaded T5 config from local directory")
    except Exception as e:
        print(f"Could not load config from local directory: {e}")
        print("Falling back to download T5 config...")
        config = T5Config.from_pretrained("google/flan-t5-xl")
        print("Downloaded T5 config from HuggingFace")

    # Create model
    model = T5EncoderModel(config)

    # Load state dict (strict=False: GGUF may omit buffers the HF model defines)
    print("Loading state dict into T5 model...")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"  Missing keys: {missing[:5]}..." if len(missing) > 5 else f"  Missing keys: {missing}")
    if unexpected:
        print(f"  Unexpected keys: {unexpected[:5]}..." if len(unexpected) > 5 else f"  Unexpected keys: {unexpected}")

    model.to(device)
    model.eval()
    model.requires_grad_(False)

    print(f"[T5 Encoder loaded successfully on {device}]")
    return model
175
+
176
+
177
def load_infinity_from_gguf(gguf_path, vae, device='cuda', model_type='infinity_2b',
                            text_channels=2048, pn='1M'):
    """
    Load the Infinity transformer from a GGUF file.

    Builds the Infinity-2B architecture, swaps its Linear layers for GGUF-aware
    ones, then splices the (possibly still-quantized) GGUF tensors directly into
    the module tree, bypassing ``load_state_dict``.

    Args:
        gguf_path: Path to the Infinity GGUF checkpoint.
        vae: Already-loaded VAE module (passed to the Infinity constructor).
        device: Target device for the model and weights.
        model_type: Only ``'infinity_2b'`` is supported; anything else raises
            ``ValueError``.
        text_channels: Width of the text-conditioning features (T5-XL = 2048).
        pn: Resolution preset name forwarded to the Infinity constructor.

    Returns:
        The Infinity model in eval mode with gradients disabled and a
        device-local ``rng`` generator attached.
    """
    print(f"[Loading Infinity-2B from GGUF: {gguf_path}]")

    # Model configuration for Infinity-2B (hard-coded architecture hyperparams).
    if model_type == 'infinity_2b':
        kwargs_model = dict(
            depth=32,
            embed_dim=2048,
            num_heads=2048//128,  # 16 heads of dim 128
            drop_path_rate=0.1,
            mlp_ratio=4,
            block_chunks=8
        )
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    # Create Infinity model under bf16 autocast / no_grad (inference only).
    text_maxlen = 512
    print("[Creating Infinity model architecture]")

    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True), torch.no_grad():
        infinity_model = Infinity(
            vae_local=vae,
            text_channels=text_channels,
            text_maxlen=text_maxlen,
            shared_aln=True,
            raw_scale_schedule=None,
            checkpointing='full-block',
            customized_flash_attn=False,
            fused_norm=True,
            pad_to_multiplier=128,
            use_flex_attn=False,
            add_lvl_embeding_only_first_block=1,
            use_bit_label=1,
            rope2d_each_sa_layer=1,
            rope2d_normalized_by_hw=2,
            pn=pn,
            apply_spatial_patchify=0,
            inference_mode=True,
            train_h_div_w_list=[1.0],
            **kwargs_model,
        ).to(device=device)

    print(f"[Infinity model size: {sum(p.numel() for p in infinity_model.parameters())/1e9:.2f}B parameters]")

    # Convert the transformer blocks to bfloat16 for inference.
    for block in infinity_model.unregistered_blocks:
        block.bfloat16()

    infinity_model.eval()
    infinity_model.requires_grad_(False)

    # Load GGUF weights; quantized tensors arrive as GGUFParameter objects.
    print("[Loading Infinity weights from GGUF]")
    state_dict = load_gguf_state_dict_with_params(gguf_path, device=device)

    # Replace nn.Linear with GGUFLinear so quantized weights are dequantized
    # on-the-fly during the forward pass instead of upfront.
    print("[Replacing Linear layers with GGUFLinear layers]")
    infinity_model = _replace_with_gguf_linear(infinity_model, torch.bfloat16, state_dict, prefix="")

    # Load weights directly into the model (not using load_state_dict) because
    # GGUFParameter byte shapes would fail strict shape checks.
    # NOTE(review): assumes GGUF tensor names already match the module
    # attribute paths of this architecture — confirm against the converter.
    print("[Loading weights into model]")
    skipped_keys = []
    for key, tensor in state_dict.items():
        # Split "path.to.module.param" into module path and parameter name.
        parts = key.rsplit('.', 1)
        if len(parts) != 2:
            continue

        module_name, param_name = parts

        # Walk the attribute path to the owning module; None if any hop is missing.
        module = infinity_model
        for attr in module_name.split('.'):
            if hasattr(module, attr):
                module = getattr(module, attr)
            else:
                module = None
                break

        # Set the parameter on the resolved module.
        if module is not None and hasattr(module, param_name):
            existing_param = getattr(module, param_name)

            # For quantized tensors compare against the logical (dequantized)
            # shape, not the raw byte shape.
            tensor_shape = tensor.shape
            if hasattr(tensor, 'quant_shape'):
                tensor_shape = tensor.quant_shape

            # Skip (and record) any tensor whose shape does not match.
            if existing_param.shape != tensor_shape:
                print(f"[WARNING] Shape mismatch for {key}: expected {existing_param.shape}, got {tensor_shape}. Skipping.")
                skipped_keys.append(key)
                continue

            # Preserve GGUFParameter subclasses; wrap plain tensors.
            if isinstance(tensor, torch.nn.Parameter):
                setattr(module, param_name, tensor)
            else:
                setattr(module, param_name, torch.nn.Parameter(tensor, requires_grad=False))

    if skipped_keys:
        print(f"[INFO] Skipped {len(skipped_keys)} parameters due to shape mismatches")

    # Device-local generator used for seeded sampling.
    infinity_model.rng = torch.Generator(device=device)

    print("[Infinity model loaded successfully]")
    return infinity_model
289
+
290
+
291
def load_vae(vae_path, vae_type=32, device='cuda'):
    """Build the Infinity VAE from a checkpoint and move it to *device*.

    Args:
        vae_path: Path to the VAE checkpoint (.pth).
        vae_type: Codebook bit-depth; the codebook holds ``2 ** vae_type`` codes.
        device: Target device for the loaded module.

    Returns:
        The VAE module in test mode on *device*.
    """
    print(f"[Loading VAE from {vae_path}]")

    # Fixed d32 architecture: 16px patches, five-stage channel multipliers.
    model = vae_model(
        vae_path,
        "dynamic",            # schedule_mode
        vae_type,             # codebook_dim
        2 ** vae_type,        # codebook_size
        patch_size=16,
        encoder_ch_mult=[1, 2, 4, 4, 4],
        decoder_ch_mult=[1, 2, 4, 4, 4],
        test_mode=True,
    ).to(device)

    print("[VAE loaded successfully]")
    return model
317
+
318
+
319
def encode_prompt(text_tokenizer, text_encoder, prompt, device='cuda'):
    """Tokenize and T5-encode *prompt* into Infinity's compact conditioning tuple.

    Returns ``(kv_compact, lens, cu_seqlens_k, Ltext)``: the unpadded token
    features concatenated across captions (float32, on *device*), per-caption
    token counts, the int32 cumulative-length prefix (leading zero), and the
    longest caption length.
    """
    print(f"Encoding prompt: {prompt}")

    batch = text_tokenizer(
        text=[prompt],
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    # Run the tokens wherever the encoder lives (CPU in this pipeline).
    enc_device = text_encoder.device
    input_ids = batch.input_ids.to(enc_device)
    mask = batch.attention_mask.to(enc_device)

    with torch.no_grad():
        text_features = text_encoder(
            input_ids=input_ids,
            attention_mask=mask
        )['last_hidden_state'].float()

    # Hand the results to the device running the Infinity model.
    text_features = text_features.to(device)
    mask = mask.to(device)

    lens: List[int] = mask.sum(dim=-1).tolist()
    cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
    Ltext = max(lens)

    # Drop padding: keep only each caption's real tokens, then concatenate.
    # float32 avoids dtype mismatches downstream.
    kv_compact = torch.cat(
        [feat[:n_tok] for n_tok, feat in zip(lens, text_features.unbind(0))],
        dim=0,
    ).to(torch.float32)

    return kv_compact, lens, cu_seqlens_k, Ltext
364
+
365
+
366
def generate_image(infinity_model, vae, text_tokenizer, text_encoder, prompt,
                   cfg_scale=3.0, tau=0.5, seed=None, scale_schedule=None,
                   vae_type=32, device='cuda'):
    """
    Run the full text-to-image pipeline for a single prompt.

    Encodes the prompt with T5, then performs Infinity's multi-scale
    autoregressive sampling under bf16 autocast.

    Args:
        infinity_model: Loaded Infinity transformer.
        vae: Loaded VAE used to decode tokens to pixels.
        text_tokenizer / text_encoder: T5 tokenizer and encoder.
        prompt: Text prompt.
        cfg_scale: Classifier-free guidance scale (applied at every stage).
        tau: Sampling temperature (applied at every stage).
        seed: Optional generation seed, forwarded as ``g_seed``.
        scale_schedule: List of (1, h, w) stages; must not be None.
        vae_type: Codebook bit-depth, forwarded to the sampler.
        device: Device on which conditioning tensors are placed.

    Returns:
        The first decoded image from the sampler's image list.
        NOTE(review): appears to be a HWC uint8 tensor in BGR order —
        callers pass it straight to cv2.imwrite; confirm in the VAE decoder.
    """
    print("[Starting image generation]")
    start_time = time.time()

    # Note: Deterministic mode is set early in main() if seed is provided
    if seed is not None:
        print(f"Using seed: {seed}")

    # Encode prompt into the compact conditioning tuple.
    text_cond_tuple = encode_prompt(text_tokenizer, text_encoder, prompt, device=device)

    # One guidance scale / temperature entry per resolution stage.
    cfg_list = [cfg_scale] * len(scale_schedule)
    tau_list = [tau] * len(scale_schedule)

    print(f"CFG scale: {cfg_scale}, Tau: {tau}")
    print(f"Scale schedule: {scale_schedule}")

    # Generate with bf16 autocast; no gradients needed at inference.
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
        with torch.no_grad():
            gen_start = time.time()

            _, _, img_list = infinity_model.autoregressive_infer_cfg(
                vae=vae,
                scale_schedule=scale_schedule,
                label_B_or_BLT=text_cond_tuple,
                g_seed=seed,
                B=1,                          # single-image batch
                negative_label_B_or_BLT=None,
                force_gt_Bhw=None,
                cfg_sc=cfg_scale,
                cfg_list=cfg_list,
                tau_list=tau_list,
                top_k=900,
                top_p=0.97,
                returns_vemb=1,
                ratio_Bl1=None,
                gumbel=0,
                norm_cfg=False,
                cfg_exp_k=0.0,
                cfg_insertion_layer=[0],      # Must be a list
                vae_type=vae_type,
                softmax_merge_topk=-1,
                ret_img=True,
                trunk_scale=1000,
                gt_leak=0,
                gt_ls_Bl=None,
                inference_mode=True,
                sampling_per_bits=1,
            )

            gen_time = time.time() - gen_start

    # B=1, so the single decoded image is the first list entry.
    img = img_list[0]

    total_time = time.time() - start_time
    print(f"[Generation complete! Total time: {total_time:.2f}s, Inference time: {gen_time:.2f}s]")

    return img
431
+
432
+
433
def main():
    """CLI entry point: load VAE, T5, and Infinity from GGUF checkpoints, then
    generate one image from ``--prompt`` and write it to ``--output``.

    Determinism setup (seeding, cuDNN flags, SDPA backend selection) happens
    before model loading so that weight-splicing and sampling are reproducible.
    """
    parser = argparse.ArgumentParser(description='Generate images with Infinity-2B GGUF')
    parser.add_argument('--prompt', type=str,
                        default='an astronaut riding a horse on the moon',
                        help='Text prompt for image generation')
    parser.add_argument('--infinity-gguf', type=str,
                        default='infinity_2b_reg_Q8_0.gguf',
                        help='Path to Infinity-2B GGUF file')
    parser.add_argument('--t5-gguf', type=str,
                        default='flan-t5-xl-encoder-Q8_0.gguf',
                        help='Path to T5 encoder GGUF file')
    parser.add_argument('--vae-path', type=str,
                        default='Infinity/infinity_vae_d32_reg.pth',
                        help='Path to VAE checkpoint')
    parser.add_argument('--output', type=str,
                        default='output.png',
                        help='Output image path')
    parser.add_argument('--cfg-scale', type=float, default=3.0,
                        help='Classifier-free guidance scale')
    parser.add_argument('--tau', type=float, default=0.5,
                        help='Temperature for self-attention')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--pn', type=str, default='1M',
                        choices=['0.06M', '0.25M', '1M'],
                        help='Resolution preset')
    parser.add_argument('--aspect-ratio', type=float, default=1.0,
                        help='Aspect ratio (height/width)')

    args = parser.parse_args()

    # Set deterministic mode early (before model loading) if seed is provided.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # Enable deterministic mode for cuDNN.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        # Try to enable full deterministic mode (warn-only: some kernels
        # have no deterministic implementation).
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception as e:
            print(f"Warning: Could not enable full deterministic mode: {e}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Set CUDA seed after the device is determined.
    if args.seed is not None and device == 'cuda':
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

        # BUG FIX: torch.backends.cuda.sdp_kernel(...) is a *context manager*;
        # the previous bare call returned the manager without entering it, so
        # the SDPA backend selection was never actually applied. Use the
        # persistent per-backend toggles instead, keeping only the
        # deterministic math backend enabled.
        try:
            torch.backends.cuda.enable_flash_sdp(False)
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_math_sdp(True)
            print(f"Deterministic mode enabled (seed={args.seed})")
        except Exception as e:
            print(f"Warning: Could not set SDPA backend: {e}")

    if device == 'cpu':
        print("WARNING: No GPU detected! This will be extremely slow.")

    # Pick the closest supported aspect ratio, then its per-preset schedule.
    h_div_w_template = h_div_w_templates[
        np.argmin(np.abs(h_div_w_templates - args.aspect_ratio))
    ]
    scale_schedule = dynamic_resolution_h_w[h_div_w_template][args.pn]['scales']
    scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

    print("\n" + "="*70)
    print("Infinity-2B GGUF Image Generation")
    print("="*70)

    # Load models in dependency order: VAE first (Infinity needs it).
    print("\n[1/4] Loading VAE...")
    vae = load_vae(args.vae_path, vae_type=32, device=device)

    print("\n[2/4] Loading T5 Tokenizer...")
    text_tokenizer = load_t5_tokenizer_from_gguf(args.t5_gguf)

    print("\n[3/4] Loading T5 Encoder from GGUF (on CPU)...")
    # T5 stays on CPU to leave GPU memory for the Infinity transformer.
    text_encoder = load_t5_encoder_from_gguf(args.t5_gguf, device='cpu')

    print("\n[4/4] Loading Infinity-2B from GGUF...")
    infinity_model = load_infinity_from_gguf(
        args.infinity_gguf,
        vae=vae,
        device=device,
        model_type='infinity_2b',
        text_channels=2048,
        pn=args.pn
    )

    print("\n" + "="*70)
    print("All models loaded successfully!")
    print("="*70)

    # Generate image.
    print(f"\nGenerating image with prompt: '{args.prompt}'")
    generated_image = generate_image(
        infinity_model,
        vae,
        text_tokenizer,
        text_encoder,
        args.prompt,
        cfg_scale=args.cfg_scale,
        tau=args.tau,
        seed=args.seed,
        scale_schedule=scale_schedule,
        vae_type=32,
        device=device
    )

    # Save image. NOTE(review): cv2.imwrite expects BGR channel order —
    # assumed to match what the VAE decoder emits; confirm if colors look off.
    print(f"\nSaving image to {args.output}...")
    image_np = generated_image.cpu().numpy()
    cv2.imwrite(args.output, image_np)

    print(f"\n{'='*70}")
    print(f"✓ Image saved successfully to: {args.output}")
    print(f"{'='*70}\n")


if __name__ == '__main__':
    main()
gradio_webui.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio Web UI for Infinity-2B GGUF Image Generation
4
+ Provides an easy-to-use interface for generating images with the quantized model
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
10
+
11
+ # Add Infinity to Python path
12
+ SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
13
+ INFINITY_PATH = os.path.join(SCRIPT_DIR, 'Infinity')
14
+ if os.path.exists(INFINITY_PATH):
15
+ sys.path.insert(0, INFINITY_PATH)
16
+
17
+ import time
18
+ import argparse
19
+ import torch
20
+ import numpy as np
21
+ import gradio as gr
22
+ from PIL import Image
23
+ from datetime import datetime
24
+
25
+ # Import the generation functions from our existing script
26
+ from generate_image_2b_q8_gguf import (
27
+ load_t5_tokenizer_from_gguf,
28
+ load_t5_encoder_from_gguf,
29
+ load_infinity_from_gguf,
30
+ load_vae,
31
+ generate_image
32
+ )
33
+
34
+ from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
35
+
36
+
37
+ # Global model storage
38
class ModelCache:
    """Process-wide holder for lazily-loaded model handles shared by the
    Gradio callbacks. ``loaded`` flips to True once load_models() has
    populated every slot."""

    def __init__(self):
        # Nothing is loaded until load_models() fills these in.
        self.vae = None
        self.text_tokenizer = None
        self.text_encoder = None
        self.infinity_model = None
        self.loaded = False
        # Prefer the GPU when one is visible to torch.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Single shared instance used by all UI callbacks.
model_cache = ModelCache()
49
+
50
+
51
def load_models(infinity_gguf_path, t5_gguf_path, vae_path, pn='1M', progress=gr.Progress()):
    """
    Load all models with progress tracking into the global ``model_cache``.

    Idempotent: returns immediately if models are already loaded. NOTE: a
    different ``pn`` passed on a later call is silently ignored in that case —
    restart the app to switch resolution presets.

    ``progress=gr.Progress()`` as a default is the Gradio convention for
    progress-bar injection, not an accidental shared mutable default.

    Returns:
        A human-readable status string shown in the UI.
    """
    global model_cache

    if model_cache.loaded:
        return "✓ Models already loaded!"

    progress(0, desc="Loading VAE...")
    model_cache.vae = load_vae(vae_path, vae_type=32, device=model_cache.device)

    progress(0.25, desc="Loading T5 Tokenizer...")
    model_cache.text_tokenizer = load_t5_tokenizer_from_gguf(t5_gguf_path)

    progress(0.5, desc="Loading T5 Encoder (on CPU)...")
    # T5 stays on CPU to leave GPU memory for the Infinity transformer.
    model_cache.text_encoder = load_t5_encoder_from_gguf(t5_gguf_path, device='cpu')

    progress(0.75, desc="Loading Infinity-2B from GGUF...")
    model_cache.infinity_model = load_infinity_from_gguf(
        infinity_gguf_path,
        vae=model_cache.vae,
        device=model_cache.device,
        model_type='infinity_2b',
        text_channels=2048,
        pn=pn
    )

    model_cache.loaded = True
    progress(1.0, desc="Complete!")

    return "✓ All models loaded successfully!"
83
+
84
+
85
def generate_image_gradio(
    prompt,
    cfg_scale,
    tau,
    seed,
    aspect_ratio,
    pn,
    use_random_seed,
    progress=gr.Progress()
):
    """
    Gradio callback: generate one image and return ``(PIL.Image | None, info)``.

    Seeds torch/numpy for reproducibility, resolves the scale schedule from the
    requested aspect ratio and preset, runs the shared ``generate_image``
    pipeline, and converts the result to a PIL image. On any failure returns
    ``(None, error_message_with_traceback)`` so the UI shows the error instead
    of crashing.
    """
    global model_cache

    if not model_cache.loaded:
        return None, "❌ Please load models first!"

    try:
        # Use random seed if requested (overrides the numeric field).
        if use_random_seed:
            seed = np.random.randint(0, 2**31 - 1)

        # Set seed for reproducibility.
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)
            if model_cache.device == 'cuda':
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        # Snap the requested ratio to the nearest supported template, then
        # fetch that template's multi-scale schedule for the chosen preset.
        h_div_w_template = h_div_w_templates[
            np.argmin(np.abs(h_div_w_templates - aspect_ratio))
        ]
        scale_schedule = dynamic_resolution_h_w[h_div_w_template][pn]['scales']
        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

        progress(0.1, desc="Encoding prompt...")
        start_time = time.time()

        progress(0.3, desc="Generating image (this may take a while)...")

        # Generate image using the cached models.
        img_np = generate_image(
            model_cache.infinity_model,
            model_cache.vae,
            model_cache.text_tokenizer,
            model_cache.text_encoder,
            prompt,
            cfg_scale=cfg_scale,
            tau=tau,
            seed=seed,
            scale_schedule=scale_schedule,
            vae_type=32,
            device=model_cache.device
        )

        progress(0.9, desc="Converting to PIL Image...")

        # Convert to PIL Image (RGB).
        img_np = img_np.cpu().numpy()
        # The pipeline emits OpenCV-style BGR — flip the channel axis to RGB.
        # NOTE(review): assumes HWC uint8-compatible output; confirm in the VAE.
        img_rgb = img_np[:, :, ::-1]
        pil_image = Image.fromarray(img_rgb.astype(np.uint8))

        elapsed_time = time.time() - start_time

        # Get resolution (height, width) for the info panel.
        h, w = img_np.shape[:2]

        info = f"""✓ Generation complete!

**Time**: {elapsed_time:.2f}s
**Resolution**: {w}x{h}
**Seed**: {seed}
**CFG Scale**: {cfg_scale}
**Tau**: {tau}
**Aspect Ratio**: {aspect_ratio:.2f}
**PN**: {pn}"""

        progress(1.0, desc="Done!")

        return pil_image, info

    except Exception as e:
        import traceback
        error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
174
+
175
+
176
def create_ui():
    """
    Build and return the Gradio Blocks UI.

    Layout: a top row of model-path inputs with a Load button, then a two-column
    body (settings on the left, image + run info on the right). Events are
    wired to ``load_models`` and ``generate_image_gradio``.
    """
    # Create Blocks without theme for compatibility with older Gradio versions.
    with gr.Blocks(title="Infinity-2B GGUF Generator") as demo:
        gr.Markdown("# 🎨 Infinity-2B GGUF Image Generator")

        # Model paths banner at the top.
        with gr.Row():
            infinity_gguf = gr.Textbox(
                label="Infinity-2B GGUF",
                value="infinity_2b_reg_Q8_0.gguf",
                scale=2
            )

            t5_gguf = gr.Textbox(
                label="T5 GGUF",
                value="flan-t5-xl-encoder-Q8_0.gguf",
                scale=2
            )

            vae_path = gr.Textbox(
                label="VAE Checkpoint",
                value="Infinity/infinity_vae_d32_reg.pth",
                scale=2
            )

            # Preset used at LOAD time; the per-generation dropdown below is
            # separate and only affects the scale schedule.
            pn_load = gr.Dropdown(
                label="Resolution Preset",
                choices=['0.06M', '0.25M', '1M'],
                value='1M',
                scale=1
            )

            load_btn = gr.Button("🚀 Load Models", variant="primary", scale=1)

        load_status = gr.Textbox(label="Status", interactive=False, show_label=False)

        # Main content area.
        with gr.Row():
            # Left column: Generation settings.
            with gr.Column(scale=1):
                gr.Markdown("### Generation Settings")

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    value="an astronaut riding a horse on the moon",
                    lines=3
                )

                with gr.Row():
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=3.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Higher = stronger prompt adherence"
                    )

                    tau = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        label="Tau (Temperature)",
                        info="Lower = more deterministic"
                    )

                with gr.Row():
                    aspect_ratio = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Aspect Ratio (H/W)",
                        info="1.0 = square, >1.0 = portrait, <1.0 = landscape"
                    )

                    pn = gr.Dropdown(
                        label="Resolution Preset",
                        choices=['0.06M', '0.25M', '1M'],
                        value='1M',
                        info="Higher = better quality but slower"
                    )

                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        precision=0,
                        info="For reproducible results"
                    )

                    use_random_seed = gr.Checkbox(
                        label="Random Seed",
                        value=False,
                        info="Generate random seed each time"
                    )

                generate_btn = gr.Button("✨ Generate Image", variant="primary", size="lg")

            # Right column: Output.
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    height=600
                )
                output_info = gr.Markdown("Generate an image to see details here.")

        # Wire up events.
        load_btn.click(
            fn=load_models,
            inputs=[infinity_gguf, t5_gguf, vae_path, pn_load],
            outputs=[load_status]
        )

        generate_btn.click(
            fn=generate_image_gradio,
            inputs=[prompt, cfg_scale, tau, seed, aspect_ratio, pn, use_random_seed],
            outputs=[output_image, output_info]
        )

    return demo
303
+
304
+
305
def main():
    """CLI entry point: parse server options, optionally preload models, and
    launch the Gradio app."""
    parser = argparse.ArgumentParser(description='Infinity-2B GGUF Gradio Web UI')
    parser.add_argument('--share', action='store_true', help='Create a public share link')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Server name')
    parser.add_argument('--server-port', type=int, default=7860, help='Server port')
    parser.add_argument('--autoload', action='store_true', help='Auto-load models on startup')
    parser.add_argument('--infinity-gguf', type=str, default='infinity_2b_reg_Q8_0.gguf')
    parser.add_argument('--t5-gguf', type=str, default='flan-t5-xl-encoder-Q8_0.gguf')
    parser.add_argument('--vae-path', type=str, default='Infinity/infinity_vae_d32_reg.pth')
    args = parser.parse_args()

    # Eagerly load models when requested; otherwise the user clicks "Load Models".
    if args.autoload:
        print("Auto-loading models...")
        load_models(args.infinity_gguf, args.t5_gguf, args.vae_path)

    demo = create_ui()

    banner = "=" * 70
    print("\n" + banner)
    print("Starting Infinity-2B GGUF Web UI")
    print(banner)
    print(f"Server: http://{args.server_name}:{args.server_port}")
    if args.share:
        print("Creating public share link...")
    print(banner + "\n")

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
        inbrowser=True,
    )


if __name__ == '__main__':
    main()
infinity_2b_reg_Q8_0.gguf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:747220c5030d342f0195f34eb5c21ebb75b2bb855df96a848544be29f00326bc
3
+ size 2374494496
infinity_gguf_utils.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GGUF utilities for Infinity model inference
4
+ Includes GGUFParameter, dequantization functions, and GGUFLinear layer
5
+ """
6
+
7
+ import numpy as np
8
+
9
+ # Monkey patch for NumPy 2.0 compatibility (must be done before importing gguf)
10
+ if not hasattr(np.ndarray, 'newbyteorder'):
11
+ def newbyteorder(self, new_order):
12
+ return self.view(self.dtype.newbyteorder(new_order))
13
+ np.ndarray.newbyteorder = newbyteorder
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import gguf
18
+ from typing import Optional
19
+
20
+
21
# Dequantization constants (match llama.cpp's K-quant block geometry)
QK_K = 256          # elements per K-quant super-block
K_SCALE_SIZE = 12   # bytes of packed 6-bit scales/mins per K-quant super-block
24
+
25
+
26
def to_uint32(x):
    """Assemble little-endian uint32 words from the first four byte columns.

    *x* is viewed as uint8 with shape (n, >=4); returns an int32 tensor of
    shape (n, 1) where each row is byte0 | byte1<<8 | byte2<<16 | byte3<<24.
    """
    b = x.view(torch.uint8).to(torch.int32)
    word = b[:, 0]
    for i in (1, 2, 3):
        word = word | (b[:, i] << (8 * i))
    return word.unsqueeze(1)
30
+
31
+
32
def split_block_dims(blocks, *args):
    """Split *blocks* along dim 1 into chunks of the given widths.

    The final chunk implicitly receives whatever columns remain after the
    explicit widths in *args* are consumed.
    """
    total = blocks.shape[1]
    sizes = [*args, total - sum(args)]
    return torch.split(blocks, sizes, dim=1)
37
+
38
+
39
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q8_0 blocks: a float16 scale followed by int8 quants.

    *blocks* is (n_blocks, type_size) raw bytes; the result is scale * quant
    per element, in *dtype*.
    """
    scale_bytes, quants = split_block_dims(blocks, 2)
    scale = scale_bytes.view(torch.float16).to(dtype)
    return scale * quants.view(torch.int8)
45
+
46
+
47
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q6_K super-blocks (6-bit quants, 16 int8 sub-scales, fp16 d).

    Layout per block: QK_K/2 bytes of low nibbles (ql), QK_K/4 bytes of high
    2-bit pairs (qh), QK_K/16 int8 sub-scales, then the fp16 super-scale d.
    """
    n_blocks = blocks.shape[0]
    ql, qh, scales, d = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    # Effective per-group scale: super-scale times each int8 sub-scale.
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    # Unpack the two 4-bit halves of each ql byte (shift by 0 then 4).
    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    # Unpack the four 2-bit fields of each qh byte (shifts 0/2/4/6).
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    # Recombine to 6-bit values, then shift to signed range [-32, 31].
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))
64
+
65
+
66
def get_scale_min(scales):
    """Unpack the 12-byte K-quant scale block into 8 scales and 8 mins.

    *scales* is (n_blocks, 12) bytes. The first four values of each output come
    straight from the low 6 bits; the last four are reassembled from the third
    4-byte group plus the spilled high bits of the first two groups. Returns
    (scales, mins), each (n_blocks, 8) uint8.
    """
    nb = scales.shape[0]
    packed = scales.view(torch.uint8).reshape(nb, 3, 4)

    hi = packed[:, 0:1, :]   # scale bytes (low 6 bits) + spilled high bits
    lo = packed[:, 1:2, :]   # min bytes (low 6 bits) + spilled high bits
    mix = packed[:, 2:3, :]  # nibbles shared between the upper scales/mins

    sc = torch.cat((hi & 0x3F, (mix & 0x0F) | ((hi >> 2) & 0x30)), dim=-1)
    mn = torch.cat((lo & 0x3F, (mix >> 4) | ((lo >> 2) & 0x30)), dim=-1)

    return sc.reshape(nb, 8), mn.reshape(nb, 8)
78
+
79
+
80
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q5_K super-blocks (5-bit quants with packed scales/mins).

    Layout per block: fp16 d, fp16 dmin, K_SCALE_SIZE bytes of packed 6-bit
    scales/mins, QK_K/8 high-bit bytes (qh), then the 4-bit quants (qs).
    Each value is d*scale*q - dmin*min.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # Unpack the 8 per-group scales and mins.
    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    # Low 4 bits from qs (two nibbles per byte), 5th bit from qh (one bit of
    # eight per byte, shifts 0..7).
    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = ql | (qh << 4)

    return (d * q - dm).reshape((n_blocks, QK_K))
100
+
101
+
102
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    """Dequantize Q4_K super-blocks (4-bit quants with packed scales/mins).

    Layout per block: fp16 d, fp16 dmin, K_SCALE_SIZE bytes of packed 6-bit
    scales/mins, then the 4-bit quants. Each value is d*scale*q - dmin*min.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # Unpack the 8 per-group scales and mins.
    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    # Two 4-bit quants per byte: low nibble first, then high nibble.
    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))
118
+
119
+
120
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    """Expand raw bfloat16 bytes to float32.

    bf16 is the upper half of an IEEE-754 float32, so widening each 16-bit
    pattern to int32 and shifting it into the high half reproduces the exact
    float32 value. The *dtype* argument is ignored (output is always float32),
    matching the other dequantizers' signature.
    """
    raw = blocks.view(torch.int16)
    widened = raw.to(torch.int32) << 16
    return widened.view(torch.float32)
123
+
124
+
125
+ # Mapping of quantization types to dequantization functions
126
+ GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
127
+ DEQUANTIZE_FUNCTIONS = {
128
+ gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
129
+ gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
130
+ gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
131
+ gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
132
+ gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
133
+ }
134
+
135
+
136
+ def _quant_shape_from_byte_shape(shape, type_size, block_size):
137
+ """Calculate dequantized shape from quantized byte shape"""
138
+ return (*shape[:-1], shape[-1] // type_size * block_size)
139
+
140
+
141
def dequantize_gguf_tensor(tensor, target_dtype=None):
    """
    Dequantize a GGUF tensor to a regular torch tensor.

    Args:
        tensor: GGUFParameter (raw quantized bytes + quant metadata) or a
            regular tensor, which is passed through (optionally cast).
        target_dtype: Target dtype for the output (default: float32).

    Returns:
        A dense torch tensor with the logical (dequantized) shape.

    Raises:
        ValueError: If the quantization type has no entry in
            DEQUANTIZE_FUNCTIONS.
    """
    # Regular tensors carry no quant metadata — pass through (with cast).
    if not hasattr(tensor, "quant_type"):
        return tensor.to(target_dtype) if target_dtype else tensor

    quant_type = tensor.quant_type

    # F32/F16 payloads are already plain floats; a cast suffices.
    if quant_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
        return tensor.to(target_dtype) if target_dtype else tensor

    # Look up the block-level dequantizer for this format.
    if quant_type not in DEQUANTIZE_FUNCTIONS:
        raise ValueError(f"Unsupported quantization type: {quant_type}")

    dequant_fn = DEQUANTIZE_FUNCTIONS[quant_type]
    block_size, type_size = GGML_QUANT_SIZES[quant_type]

    # Reinterpret the storage as raw bytes and derive the logical shape.
    tensor_bytes = tensor.view(torch.uint8)
    shape = _quant_shape_from_byte_shape(tensor_bytes.shape, type_size, block_size)

    # Flatten into (n_blocks, type_size) rows — one quantization block each.
    n_blocks = tensor_bytes.numel() // type_size
    blocks = tensor_bytes.reshape((n_blocks, type_size))

    # Dequantize block-wise, then restore the logical shape.
    dtype = target_dtype if target_dtype else torch.float32
    dequant = dequant_fn(blocks, block_size, type_size, dtype=dtype)
    dequant = dequant.reshape(shape)

    return dequant
182
+
183
+
184
class GGUFParameter(torch.nn.Parameter):
    """
    Custom Parameter class for GGUF quantized tensors.

    Stores the GGML quantization type alongside the raw quantized bytes and
    reports the logical (dequantized) shape via the ``shape`` property.
    """
    def __new__(cls, data, requires_grad=False, quant_type=None):
        """
        Args:
            data: Raw quantized bytes as a tensor (or None for an empty param).
            requires_grad: Gradient flag; quantized weights are inference-only.
            quant_type: gguf.GGMLQuantizationType of `data`, or None when the
                data is not quantized.
        """
        data = data if data is not None else torch.empty(0)
        # Store byte shape before creating the parameter.
        byte_shape = data.shape
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_type = quant_type
        # Fix: only consult the quant-size table for actual quant types.
        # The original indexed GGML_QUANT_SIZES unconditionally, which raised
        # KeyError for the documented default quant_type=None.
        if quant_type is not None:
            block_size, type_size = GGML_QUANT_SIZES[quant_type]
            self.quant_shape = _quant_shape_from_byte_shape(byte_shape, type_size, block_size)
        return self

    @property
    def shape(self):
        """Return the dequantized (logical) shape instead of the byte shape."""
        if hasattr(self, 'quant_shape'):
            return self.quant_shape
        # Fallback: get shape from parent class without causing recursion.
        return object.__getattribute__(self, 'data').shape if hasattr(self, 'data') else torch.Size()
206
+
207
+
208
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix=""):
    """
    Recursively replace nn.Linear layers with GGUFLinear layers for on-the-fly
    dequantization (based on the ComfyUI-WanVideoWrapper implementation).

    Args:
        model: Module whose children are rewritten in place.
        compute_dtype: Fallback dtype used when the input dtype is not a float type.
        state_dict: State dict whose GGUFParameter weights mark layers to convert.
        prefix: Dotted key prefix of `model` within `state_dict`.

    Returns:
        The (mutated) model, or None for leaf modules.
    """
    from types import MethodType  # hoisted: the original imported this inside the loop

    def _should_convert_to_gguf(state_dict, prefix):
        # Convert only when the corresponding weight is GGUF-quantized.
        weight_key = prefix + "weight"
        return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)

    has_children = list(model.children())
    if not has_children:
        return

    try:
        from accelerate import init_empty_weights
        use_accelerate = True
    except ImportError:
        use_accelerate = False

    for name, module in model.named_children():
        module_prefix = prefix + name + "."
        _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix)

        if (
            isinstance(module, nn.Linear)
            and not isinstance(module, GGUFLinear)
            and _should_convert_to_gguf(state_dict, module_prefix)
        ):
            # Get correct dimensions from the GGUF parameter shape.
            weight_param = state_dict[module_prefix + "weight"]
            if hasattr(weight_param, 'quant_shape'):
                out_features, in_features = weight_param.quant_shape
            else:
                out_features, in_features = weight_param.shape

            # Detect custom Linear subclasses (e.g. SharedAdaLin) overriding forward.
            module_type = type(module)
            has_custom_forward = (
                module_type != nn.Linear and
                hasattr(module_type, 'forward') and
                module_type.forward is not nn.Linear.forward
            )

            if has_custom_forward:
                # Fix: bind module_type as a default argument. The original
                # closed over the loop variable, so (late binding) every
                # wrapper saw the *last* iteration's module type.
                def wrapped_forward(self, *args, _module_type=module_type, **kwargs):
                    input_tensor = args[0] if args else None
                    # Match the input's float dtype; otherwise fall back to compute_dtype.
                    if input_tensor is not None and hasattr(input_tensor, 'dtype'):
                        target_dtype = input_tensor.dtype if input_tensor.dtype in [torch.float16, torch.bfloat16, torch.float32] else compute_dtype
                    else:
                        target_dtype = compute_dtype

                    # Dequantize weights (and bias, when it is also quantized).
                    dequant_weight = dequantize_gguf_tensor(self.weight, target_dtype=target_dtype)
                    dequant_bias = None
                    if self.bias is not None:
                        if isinstance(self.bias, GGUFParameter):
                            dequant_bias = dequantize_gguf_tensor(self.bias, target_dtype=target_dtype)
                        else:
                            dequant_bias = self.bias

                    # Perform the linear operation.
                    import torch.nn.functional as F
                    linear_output = F.linear(input_tensor, dequant_weight, dequant_bias)

                    # Apply custom reshaping for SharedAdaLin (6 AdaLN chunks).
                    if _module_type.__name__ == 'SharedAdaLin':
                        C = dequant_weight.shape[0] // 6
                        return linear_output.reshape(-1, 1, 6, C)

                    return linear_output

                new_module = GGUFLinear(
                    in_features,
                    out_features,
                    module.bias is not None,
                    compute_dtype=compute_dtype,
                )
                new_module.forward = MethodType(wrapped_forward, new_module)
            else:
                # Standard GGUFLinear replacement; build on the meta device
                # when accelerate is available to avoid allocating throwaway
                # weight storage.
                if use_accelerate:
                    with init_empty_weights():
                        new_module = GGUFLinear(
                            in_features,
                            out_features,
                            module.bias is not None,
                            compute_dtype=compute_dtype,
                        )
                else:
                    new_module = GGUFLinear(
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=compute_dtype,
                    )

            model._modules[name] = new_module
            model._modules[name].source_cls = type(module)
            model._modules[name].requires_grad_(False)

    return model
312
+
313
+
314
class GGUFLinear(nn.Linear):
    """
    Linear layer whose weights stay GGUF-quantized in memory and are
    dequantized on the fly at every forward pass.

    Compatible with the Infinity model architecture.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        compute_dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        # Fallback dtype used when the input is not a standard float type.
        self.compute_dtype = compute_dtype if compute_dtype else torch.float32

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with on-the-fly dequantization.

        Args:
            input: Input tensor.

        Returns:
            Output tensor after the linear transformation.
        """
        # Prefer the input's own float dtype; otherwise use compute_dtype.
        float_dtypes = [torch.float16, torch.bfloat16, torch.float32]
        target_dtype = input.dtype if input.dtype in float_dtypes else self.compute_dtype

        # GGUF stores linear weights as (out, in), so no transpose is needed.
        dequant_weight = dequantize_gguf_tensor(self.weight, target_dtype=target_dtype)
        dequant_bias = (
            dequantize_gguf_tensor(self.bias, target_dtype=target_dtype)
            if self.bias is not None
            else None
        )

        return torch.nn.functional.linear(input, dequant_weight, dequant_bias)
356
+
357
+
358
def load_gguf_state_dict_with_params(gguf_path, device='cuda'):
    """
    Load a GGUF file into a state dict, keeping quantized tensors as
    GGUFParameters for on-the-fly dequantization.

    For use with _replace_with_gguf_linear.

    Args:
        gguf_path: Path to the .gguf file.
        device: Device the loaded tensors are moved to (default 'cuda').

    Returns:
        Dict mapping tensor names to GGUFParameter (quantized) or
        nn.Parameter (F32/F16) entries.
    """
    from gguf import GGUFReader

    state_dict = {}
    unquantized_types = {
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
    }

    for tensor in GGUFReader(gguf_path).tensors:
        torch_tensor = torch.from_numpy(np.array(tensor.data)).to(device)

        if tensor.tensor_type not in unquantized_types:
            # Quantized payload: keep raw bytes wrapped in a GGUFParameter so
            # linear layers can dequantize lazily.
            state_dict[tensor.name] = GGUFParameter(torch_tensor, quant_type=tensor.tensor_type)
            continue

        # F32/F16: restore the row-major shape (GGUF records dims reversed)
        # and store as a regular parameter.
        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
        torch_tensor = torch_tensor.view(*shape)
        if tensor.tensor_type == gguf.GGMLQuantizationType.F32:
            state_dict[tensor.name] = nn.Parameter(torch_tensor.float())
        else:
            state_dict[tensor.name] = nn.Parameter(torch_tensor.half())

    return state_dict
391
+
392
+
393
def load_gguf_state_dict(gguf_path):
    """
    Load a GGUF file and create a CPU state dict with GGUFParameters.

    Args:
        gguf_path: Path to GGUF file.

    Returns:
        state_dict: Dict mapping tensor names to GGUFParameters (quantized)
        or regular tensors (F32/F16).

    NOTE(review): unlike load_gguf_state_dict_with_params, F32/F16 tensors
    are NOT reshaped here — confirm callers expect the reader's native layout.
    """
    from gguf import GGUFReader

    reader = GGUFReader(gguf_path)
    state_dict = {}
    plain_types = {
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
    }

    for tensor in reader.tensors:
        cpu_tensor = torch.from_numpy(np.array(tensor.data)).to('cpu')
        if tensor.tensor_type in plain_types:
            # F32/F16 data is stored as a plain tensor.
            state_dict[tensor.name] = cpu_tensor
        else:
            # Quantized data keeps its metadata for later dequantization.
            state_dict[tensor.name] = GGUFParameter(cpu_tensor, quant_type=tensor.tensor_type)

    return state_dict
426
+
427
+
428
def replace_linear_with_gguf(model, state_dict, compute_dtype=torch.float32, prefix=""):
    """
    Recursively replace nn.Linear layers with GGUFLinear layers where the
    corresponding weight in state_dict is a GGUFParameter.

    Args:
        model: PyTorch model (mutated in place).
        state_dict: State dict with GGUFParameters.
        compute_dtype: Dtype to use for computation.
        prefix: Dotted prefix of `model` within `state_dict`; used internally
            by the recursion, leave as "" for the root model.

    Returns:
        Modified model with GGUFLinear layers.
    """
    for name, module in model.named_children():
        # Fix: build the full dotted state-dict key. The previous
        # get_module_prefix helper returned only the immediate child name,
        # so Linear layers nested deeper than one level never matched their
        # fully-qualified state-dict entries.
        full_name = f"{prefix}{name}"

        # Recursively process children with the extended prefix.
        replace_linear_with_gguf(module, state_dict, compute_dtype, prefix=f"{full_name}.")

        # Check if this is a Linear layer with quantized weights.
        if isinstance(module, nn.Linear):
            weight_key = f"{full_name}.weight"

            if weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter):
                # Replace with GGUFLinear. Weights are expected to be loaded
                # afterwards (e.g. via load_state_dict) — this function only
                # swaps the module class. TODO confirm against caller.
                gguf_linear = GGUFLinear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    compute_dtype=compute_dtype,
                )
                setattr(model, name, gguf_linear)

    return model
466
+
467
+
468
def get_module_prefix(model, module_name):
    """Helper to get the full state-dict prefix for a named child module.

    Simplified placeholder: the prefix is assumed to equal the immediate
    child name. Adjust this if your model nests modules more deeply.
    """
    return module_name
472
+
473
+
474
if __name__ == "__main__":
    # Smoke check: confirm the module imports and report supported quant types.
    supported = list(DEQUANTIZE_FUNCTIONS.keys())
    print("GGUF utilities loaded successfully!")
    print(f"Supported quantization types: {supported}")