BiliSakura committed on
Commit
4f51e55
·
verified ·
1 Parent(s): 5423eac

Upload folder using huggingface_hub

Browse files
NiT-B/pipeline.py CHANGED
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
 
 
215
  ) -> None:
216
  if num_inference_steps < 1:
217
  raise ValueError("num_inference_steps must be >= 1.")
218
  if output_type not in {"pil", "np", "pt", "latent"}:
219
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  spatial_downsample = self._get_vae_spatial_downsample()
222
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
261
  )
262
  return packed_latents, image_sizes
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def _apply_classifier_free_guidance(
265
  self,
266
  model_output: torch.Tensor,
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
305
  num_inference_steps: int = 50,
306
  guidance_scale: float = 1.0,
307
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
 
 
 
308
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
309
  output_type: str = "pil",
310
  return_dict: bool = True,
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
325
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
326
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
327
  Flow-time interval where CFG is applied.
 
 
 
 
 
 
 
 
 
 
328
  generator (`torch.Generator`, *optional*):
329
  RNG for reproducibility.
330
  output_type (`str`, defaults to `"pil"`):
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
335
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
336
  height = int(height or default_size)
337
  width = int(width or default_size)
338
- self.check_inputs(height, width, num_inference_steps, output_type)
 
339
 
340
  device = self._execution_device
341
  model_dtype = next(self.transformer.parameters()).dtype
 
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
215
+ interpolation: Optional[str] = None,
216
+ ori_max_pe_len: Optional[int] = None,
217
  ) -> None:
218
  if num_inference_steps < 1:
219
  raise ValueError("num_inference_steps must be >= 1.")
220
  if output_type not in {"pil", "np", "pt", "latent"}:
221
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
222
+ if interpolation is not None and interpolation not in {
223
+ "no",
224
+ "linear",
225
+ "ntk-aware",
226
+ "ntk-by-parts",
227
+ "yarn",
228
+ "ntk-aware-pro1",
229
+ "ntk-aware-pro2",
230
+ "scale1",
231
+ "scale2",
232
+ }:
233
+ raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
234
+ if interpolation not in {None, "no"} and ori_max_pe_len is None:
235
+ raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
236
 
237
  spatial_downsample = self._get_vae_spatial_downsample()
238
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
 
277
  )
278
  return packed_latents, image_sizes
279
 
280
+ def _maybe_configure_rope_extrapolation(
281
+ self,
282
+ height: int,
283
+ width: int,
284
+ interpolation: Optional[str],
285
+ ori_max_pe_len: Optional[int],
286
+ decouple: bool,
287
+ ) -> None:
288
+ if interpolation in {None, "no"}:
289
+ return
290
+
291
+ spatial_downsample = self._get_vae_spatial_downsample()
292
+ patch_size = int(self.transformer.config.patch_size)
293
+ latent_h = height // spatial_downsample // patch_size
294
+ latent_w = width // spatial_downsample // patch_size
295
+ self.transformer.configure_rope_extrapolation(
296
+ custom_freqs=interpolation,
297
+ max_pe_len_h=latent_h,
298
+ max_pe_len_w=latent_w,
299
+ ori_max_pe_len=int(ori_max_pe_len),
300
+ decouple=decouple,
301
+ )
302
+
303
  def _apply_classifier_free_guidance(
304
  self,
305
  model_output: torch.Tensor,
 
344
  num_inference_steps: int = 50,
345
  guidance_scale: float = 1.0,
346
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
347
+ interpolation: Optional[str] = None,
348
+ ori_max_pe_len: Optional[int] = None,
349
+ decouple: bool = False,
350
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
351
  output_type: str = "pil",
352
  return_dict: bool = True,
 
367
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
368
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
369
  Flow-time interval where CFG is applied.
370
+ interpolation (`str`, *optional*):
371
+ VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
372
+ `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
373
+ `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
374
+ the transformer's configured RoPE.
375
+ ori_max_pe_len (`int`, *optional*):
376
+ Original maximum latent side length seen during training. Required when
377
+ `interpolation` is enabled.
378
+ decouple (`bool`, defaults to `False`):
379
+ Whether to decouple height and width when computing extrapolated RoPE frequencies.
380
  generator (`torch.Generator`, *optional*):
381
  RNG for reproducibility.
382
  output_type (`str`, defaults to `"pil"`):
 
387
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
388
  height = int(height or default_size)
389
  width = int(width or default_size)
390
+ self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
391
+ self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
392
 
393
  device = self._execution_device
394
  model_dtype = next(self.transformer.parameters()).dtype
NiT-B/transformer/nit_transformer_2d.py CHANGED
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
74
  return torch.get_default_dtype()
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class NiTPatchEmbed(nn.Module):
78
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
  super().__init__()
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
125
 
126
 
127
  class NiTRotaryEmbedding(nn.Module):
 
 
128
  def __init__(
129
  self,
130
  head_dim: int,
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
137
  ori_max_pe_len: Optional[int] = None,
138
  ):
139
  super().__init__()
140
- del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
- if custom_freqs not in {"normal", "scale1", "scale2"}:
142
  raise ValueError(
143
- "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
- "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
- "by changing the model config before loading."
146
  )
 
 
 
 
 
 
 
147
  dim = head_dim // 2
148
  if dim % 2 != 0:
149
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
 
 
 
 
 
 
150
  default_dtype = _get_float_dtype_or_default()
151
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
- self.register_buffer("freqs_h", freqs, persistent=False)
153
- self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
- positions = torch.arange(max_cached_len, dtype=default_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
  grids = []
162
  for height, width in image_sizes.tolist():
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
166
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
  grid = torch.cat(grids, dim=1)
 
169
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
- return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
 
173
 
174
 
175
  class NiTAttention(nn.Module):
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
367
  )
368
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
  batch_size, channels, height, width = hidden_states.shape
372
  if channels != self.in_channels:
 
74
  return torch.get_default_dtype()
75
 
76
 
77
+ # VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
78
+ def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
79
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
80
+
81
+
82
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the ``(low, high)`` integer index range bracketing two rotation counts.

    The correction factor for ``low_rot`` is floored and clamped to at least 0;
    the one for ``high_rot`` is ceiled and clamped to at most ``dim - 1``.
    """
    lower_factor = _find_correction_factor(low_rot, dim, base, max_position_embeddings)
    upper_factor = _find_correction_factor(high_rot, dim, base, max_position_embeddings)
    first_index = max(math.floor(lower_factor), 0)
    last_index = min(math.ceil(upper_factor), dim - 1)
    return first_index, last_index
86
+
87
+
88
+ def _linear_ramp_mask(min_value, max_value, dim):
89
+ if min_value == max_value:
90
+ max_value += 0.001
91
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
92
+ return torch.clamp(linear_func, 0, 1)
93
+
94
+
95
+ def _find_newbase_ntk(dim, base=10000, scale=1):
96
+ return base * scale ** (dim / (dim - 2))
97
+
98
+
99
+ def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
100
+ return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
101
+
102
+
103
+ def _get_proportion(length_test, length_train):
104
+ length_test = length_test * 2
105
+ ratio = length_test / length_train
106
+ return torch.where(
107
+ torch.tensor(ratio) <= 1.0,
108
+ torch.tensor(1.0),
109
+ torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
110
+ )
111
+
112
+
113
# Modes for which NiTRotaryEmbedding uses the plain 1 / theta**(i / dim) frequencies.
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
# Modes that rescale frequencies and need max_pe_len_h/w and ori_max_pe_len.
EXTRAPOLATION_ROPE_FREQS = {
    "yarn",
    "ntk-by-parts",
    "ntk-aware-pro2",
    "ntk-aware-pro1",
    "ntk-aware",
    "linear",
}
# Every value accepted for `custom_freqs`.
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS.union(EXTRAPOLATION_ROPE_FREQS)
123
+
124
+
125
  class NiTPatchEmbed(nn.Module):
126
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
127
  super().__init__()
 
173
 
174
 
175
  class NiTRotaryEmbedding(nn.Module):
176
+ """2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
177
+
178
  def __init__(
179
  self,
180
  head_dim: int,
 
187
  ori_max_pe_len: Optional[int] = None,
188
  ):
189
  super().__init__()
190
+ custom_freqs = custom_freqs.lower()
191
+ if custom_freqs not in SUPPORTED_ROPE_FREQS:
192
  raise ValueError(
193
+ f"Unsupported RoPE frequency variant {custom_freqs!r}. "
194
+ f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
 
195
  )
196
+ if custom_freqs not in TRAINED_ROPE_FREQS and (
197
+ max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
198
+ ):
199
+ raise ValueError(
200
+ f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
201
+ )
202
+
203
  dim = head_dim // 2
204
  if dim % 2 != 0:
205
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
206
+
207
+ self.dim = dim
208
+ self.custom_freqs = custom_freqs
209
+ self.theta = theta
210
+ self.decouple = decouple
211
+ self.ori_max_pe_len = ori_max_pe_len
212
  default_dtype = _get_float_dtype_or_default()
213
+
214
+ if custom_freqs in TRAINED_ROPE_FREQS:
215
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
216
+ freqs_h = freqs
217
+ freqs_w = freqs.clone()
218
+ else:
219
+ if decouple:
220
+ freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
221
+ freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
222
+ else:
223
+ max_pe_len = max(max_pe_len_h, max_pe_len_w)
224
+ freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
225
+ freqs_h = freqs
226
+ freqs_w = freqs.clone()
227
+
228
+ if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
229
+ scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
230
+ proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
231
+ proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
232
+ self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
233
+ self.register_buffer(
234
+ "proportion1",
235
+ proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
236
+ persistent=False,
237
+ )
238
+ self.register_buffer(
239
+ "proportion2",
240
+ proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
241
+ persistent=False,
242
+ )
243
+
244
+ self.register_buffer("freqs_h", freqs_h, persistent=False)
245
+ self.register_buffer("freqs_w", freqs_w, persistent=False)
246
+
247
+ cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
248
+ positions = torch.arange(cache_len, dtype=default_dtype)
249
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
250
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
251
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
252
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
253
 
254
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
    """Compute one axis' rotary base frequencies for ``self.custom_freqs``.

    Args:
        theta: RoPE base.
        dim: Rotary dimension of one axis; one frequency is produced per even
            index in ``range(0, dim, 2)``.
        max_pe_len: Target side length, as a number or a tensor.
        ori_max_pe_len: Reference side length the target is compared against;
            must be an ``int``.
        default_dtype: Floating dtype used for intermediates and the result.

    Returns:
        The frequency tensor cast to ``default_dtype``.

    Raises:
        TypeError: If ``ori_max_pe_len`` is not an ``int``.
        ValueError: If ``self.custom_freqs`` is not one of the extrapolation
            modes handled here.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not isinstance(ori_max_pe_len, int):
        raise TypeError(
            f"ori_max_pe_len must be an int, got {type(ori_max_pe_len).__name__}."
        )
    if not isinstance(max_pe_len, torch.Tensor):
        max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
    # Never shrink: targets at or below the reference length use scale 1.
    scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
    freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim

    if self.custom_freqs == "linear":
        # Divide every frequency by the scale.
        freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
    elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
        # Keep the formula but enlarge the base.
        freqs = 1.0 / torch.pow(
            _find_newbase_ntk(dim, theta, scale).view(-1, 1),
            freq_indices.to(scale),
        ).squeeze()
    elif self.custom_freqs == "ntk-by-parts":
        beta_0, beta_1 = 1.25, 0.75
        gamma_0, gamma_1 = 16, 2
        freqs_base = 1.0 / (theta**freq_indices)
        freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
        freqs_ntk = 1.0 / torch.pow(
            _find_newbase_ntk(dim, theta, scale).view(-1, 1),
            freq_indices.to(scale),
        ).squeeze()
        # First blend: linear vs. NTK frequencies over the beta ramp.
        low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
        freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
        freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
        # Second blend: the result vs. the unscaled frequencies over the gamma ramp.
        low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
        freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
        freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
    elif self.custom_freqs == "yarn":
        beta_fast, beta_slow = 32, 1
        freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
        freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
        low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
        freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
        # Mask 1 keeps the unscaled frequency, mask 0 the scale-divided one.
        freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
    else:
        raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")

    if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
        freqs = freqs.squeeze()
    return freqs.to(default_dtype)
296
+
297
+ def _apply_magnitude_scaling(
298
+ self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
299
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
301
+ freqs_cos = freqs_cos * self.mscale
302
+ freqs_sin = freqs_sin * self.mscale
303
+ elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
304
+ freqs_cos = freqs_cos * self.proportion1
305
+ freqs_sin = freqs_sin * self.proportion1
306
+ elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
307
+ freqs_cos = freqs_cos * self.proportion2
308
+ freqs_sin = freqs_sin * self.proportion2
309
+ return freqs_cos, freqs_sin
310
+
311
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
312
  grids = []
313
  for height, width in image_sizes.tolist():
 
317
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
318
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
319
  grid = torch.cat(grids, dim=1)
320
+
321
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
322
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
323
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
324
+ freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
325
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
326
 
327
 
328
  class NiTAttention(nn.Module):
 
520
  )
521
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
522
 
523
def configure_rope_extrapolation(
    self,
    custom_freqs: str,
    max_pe_len_h: int,
    max_pe_len_w: int,
    ori_max_pe_len: int,
    decouple: bool = False,
    theta: Optional[int] = None,
) -> None:
    """Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference.

    Replaces ``self.rope`` with a new ``NiTRotaryEmbedding`` built from the
    given settings and records those settings in the model config, limited to
    keys the config already exposes.

    Args:
        custom_freqs: RoPE frequency variant; recorded lower-cased.
        max_pe_len_h: Target length along height, forwarded to the embedding.
        max_pe_len_w: Target length along width, forwarded to the embedding.
        ori_max_pe_len: Reference length, forwarded to the embedding.
        decouple: Forwarded to the embedding.
        theta: RoPE base; defaults to ``self.config.theta`` when present,
            otherwise 10000.
    """
    theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
    head_dim = self.config.hidden_size // self.config.num_heads
    self.rope = NiTRotaryEmbedding(
        head_dim,
        custom_freqs=custom_freqs,
        theta=theta,
        max_pe_len_h=max_pe_len_h,
        max_pe_len_w=max_pe_len_w,
        decouple=decouple,
        ori_max_pe_len=ori_max_pe_len,
    )
    rope_settings = {
        "custom_freqs": custom_freqs.lower(),
        "max_pe_len_h": max_pe_len_h,
        "max_pe_len_w": max_pe_len_w,
        "decouple": decouple,
        "ori_max_pe_len": ori_max_pe_len,
        "theta": theta,
    }
    # Only touch keys the config already exposes, so none are invented here.
    updates = {key: value for key, value in rope_settings.items() if hasattr(self.config, key)}
    if updates:
        # NOTE(review): register_to_config is used instead of setattr so the
        # config's dict entries are updated too, not just its attributes;
        # confirm that persisting these values on save is intended.
        self.register_to_config(**updates)
554
+
555
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
556
  batch_size, channels, height, width = hidden_states.shape
557
  if channels != self.in_channels:
NiT-L/pipeline.py CHANGED
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
 
 
215
  ) -> None:
216
  if num_inference_steps < 1:
217
  raise ValueError("num_inference_steps must be >= 1.")
218
  if output_type not in {"pil", "np", "pt", "latent"}:
219
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  spatial_downsample = self._get_vae_spatial_downsample()
222
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
261
  )
262
  return packed_latents, image_sizes
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def _apply_classifier_free_guidance(
265
  self,
266
  model_output: torch.Tensor,
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
305
  num_inference_steps: int = 50,
306
  guidance_scale: float = 1.0,
307
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
 
 
 
308
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
309
  output_type: str = "pil",
310
  return_dict: bool = True,
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
325
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
326
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
327
  Flow-time interval where CFG is applied.
 
 
 
 
 
 
 
 
 
 
328
  generator (`torch.Generator`, *optional*):
329
  RNG for reproducibility.
330
  output_type (`str`, defaults to `"pil"`):
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
335
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
336
  height = int(height or default_size)
337
  width = int(width or default_size)
338
- self.check_inputs(height, width, num_inference_steps, output_type)
 
339
 
340
  device = self._execution_device
341
  model_dtype = next(self.transformer.parameters()).dtype
 
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
215
+ interpolation: Optional[str] = None,
216
+ ori_max_pe_len: Optional[int] = None,
217
  ) -> None:
218
  if num_inference_steps < 1:
219
  raise ValueError("num_inference_steps must be >= 1.")
220
  if output_type not in {"pil", "np", "pt", "latent"}:
221
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
222
+ if interpolation is not None and interpolation not in {
223
+ "no",
224
+ "linear",
225
+ "ntk-aware",
226
+ "ntk-by-parts",
227
+ "yarn",
228
+ "ntk-aware-pro1",
229
+ "ntk-aware-pro2",
230
+ "scale1",
231
+ "scale2",
232
+ }:
233
+ raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
234
+ if interpolation not in {None, "no"} and ori_max_pe_len is None:
235
+ raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
236
 
237
  spatial_downsample = self._get_vae_spatial_downsample()
238
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
 
277
  )
278
  return packed_latents, image_sizes
279
 
280
+ def _maybe_configure_rope_extrapolation(
281
+ self,
282
+ height: int,
283
+ width: int,
284
+ interpolation: Optional[str],
285
+ ori_max_pe_len: Optional[int],
286
+ decouple: bool,
287
+ ) -> None:
288
+ if interpolation in {None, "no"}:
289
+ return
290
+
291
+ spatial_downsample = self._get_vae_spatial_downsample()
292
+ patch_size = int(self.transformer.config.patch_size)
293
+ latent_h = height // spatial_downsample // patch_size
294
+ latent_w = width // spatial_downsample // patch_size
295
+ self.transformer.configure_rope_extrapolation(
296
+ custom_freqs=interpolation,
297
+ max_pe_len_h=latent_h,
298
+ max_pe_len_w=latent_w,
299
+ ori_max_pe_len=int(ori_max_pe_len),
300
+ decouple=decouple,
301
+ )
302
+
303
  def _apply_classifier_free_guidance(
304
  self,
305
  model_output: torch.Tensor,
 
344
  num_inference_steps: int = 50,
345
  guidance_scale: float = 1.0,
346
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
347
+ interpolation: Optional[str] = None,
348
+ ori_max_pe_len: Optional[int] = None,
349
+ decouple: bool = False,
350
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
351
  output_type: str = "pil",
352
  return_dict: bool = True,
 
367
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
368
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
369
  Flow-time interval where CFG is applied.
370
+ interpolation (`str`, *optional*):
371
+ VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
372
+ `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
373
+ `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
374
+ the transformer's configured RoPE.
375
+ ori_max_pe_len (`int`, *optional*):
376
+ Original maximum latent side length seen during training. Required when
377
+ `interpolation` is enabled.
378
+ decouple (`bool`, defaults to `False`):
379
+ Whether to decouple height and width when computing extrapolated RoPE frequencies.
380
  generator (`torch.Generator`, *optional*):
381
  RNG for reproducibility.
382
  output_type (`str`, defaults to `"pil"`):
 
387
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
388
  height = int(height or default_size)
389
  width = int(width or default_size)
390
+ self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
391
+ self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
392
 
393
  device = self._execution_device
394
  model_dtype = next(self.transformer.parameters()).dtype
NiT-L/transformer/nit_transformer_2d.py CHANGED
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
74
  return torch.get_default_dtype()
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class NiTPatchEmbed(nn.Module):
78
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
  super().__init__()
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
125
 
126
 
127
  class NiTRotaryEmbedding(nn.Module):
 
 
128
  def __init__(
129
  self,
130
  head_dim: int,
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
137
  ori_max_pe_len: Optional[int] = None,
138
  ):
139
  super().__init__()
140
- del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
- if custom_freqs not in {"normal", "scale1", "scale2"}:
142
  raise ValueError(
143
- "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
- "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
- "by changing the model config before loading."
146
  )
 
 
 
 
 
 
 
147
  dim = head_dim // 2
148
  if dim % 2 != 0:
149
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
 
 
 
 
 
 
150
  default_dtype = _get_float_dtype_or_default()
151
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
- self.register_buffer("freqs_h", freqs, persistent=False)
153
- self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
- positions = torch.arange(max_cached_len, dtype=default_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
  grids = []
162
  for height, width in image_sizes.tolist():
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
166
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
  grid = torch.cat(grids, dim=1)
 
169
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
- return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
 
173
 
174
 
175
  class NiTAttention(nn.Module):
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
367
  )
368
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
  batch_size, channels, height, width = hidden_states.shape
372
  if channels != self.in_channels:
 
74
  return torch.get_default_dtype()
75
 
76
 
77
+ # VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
78
+ def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
79
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
80
+
81
+
82
+ def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
83
+ low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
84
+ high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
85
+ return max(low, 0), min(high, dim - 1)
86
+
87
+
88
+ def _linear_ramp_mask(min_value, max_value, dim):
89
+ if min_value == max_value:
90
+ max_value += 0.001
91
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
92
+ return torch.clamp(linear_func, 0, 1)
93
+
94
+
95
+ def _find_newbase_ntk(dim, base=10000, scale=1):
96
+ return base * scale ** (dim / (dim - 2))
97
+
98
+
99
+ def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
100
+ return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
101
+
102
+
103
+ def _get_proportion(length_test, length_train):
104
+ length_test = length_test * 2
105
+ ratio = length_test / length_train
106
+ return torch.where(
107
+ torch.tensor(ratio) <= 1.0,
108
+ torch.tensor(1.0),
109
+ torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
110
+ )
111
+
112
+
113
+ TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
114
+ EXTRAPOLATION_ROPE_FREQS = {
115
+ "linear",
116
+ "ntk-aware",
117
+ "ntk-aware-pro1",
118
+ "ntk-aware-pro2",
119
+ "ntk-by-parts",
120
+ "yarn",
121
+ }
122
+ SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
123
+
124
+
125
  class NiTPatchEmbed(nn.Module):
126
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
127
  super().__init__()
 
173
 
174
 
175
  class NiTRotaryEmbedding(nn.Module):
176
+ """2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
177
+
178
  def __init__(
179
  self,
180
  head_dim: int,
 
187
  ori_max_pe_len: Optional[int] = None,
188
  ):
189
  super().__init__()
190
+ custom_freqs = custom_freqs.lower()
191
+ if custom_freqs not in SUPPORTED_ROPE_FREQS:
192
  raise ValueError(
193
+ f"Unsupported RoPE frequency variant {custom_freqs!r}. "
194
+ f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
 
195
  )
196
+ if custom_freqs not in TRAINED_ROPE_FREQS and (
197
+ max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
198
+ ):
199
+ raise ValueError(
200
+ f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
201
+ )
202
+
203
  dim = head_dim // 2
204
  if dim % 2 != 0:
205
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
206
+
207
+ self.dim = dim
208
+ self.custom_freqs = custom_freqs
209
+ self.theta = theta
210
+ self.decouple = decouple
211
+ self.ori_max_pe_len = ori_max_pe_len
212
  default_dtype = _get_float_dtype_or_default()
213
+
214
+ if custom_freqs in TRAINED_ROPE_FREQS:
215
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
216
+ freqs_h = freqs
217
+ freqs_w = freqs.clone()
218
+ else:
219
+ if decouple:
220
+ freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
221
+ freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
222
+ else:
223
+ max_pe_len = max(max_pe_len_h, max_pe_len_w)
224
+ freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
225
+ freqs_h = freqs
226
+ freqs_w = freqs.clone()
227
+
228
+ if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
229
+ scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
230
+ proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
231
+ proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
232
+ self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
233
+ self.register_buffer(
234
+ "proportion1",
235
+ proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
236
+ persistent=False,
237
+ )
238
+ self.register_buffer(
239
+ "proportion2",
240
+ proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
241
+ persistent=False,
242
+ )
243
+
244
+ self.register_buffer("freqs_h", freqs_h, persistent=False)
245
+ self.register_buffer("freqs_w", freqs_w, persistent=False)
246
+
247
+ cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
248
+ positions = torch.arange(cache_len, dtype=default_dtype)
249
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
250
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
251
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
252
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
253
 
254
+ def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
255
+ assert isinstance(ori_max_pe_len, int)
256
+ if not isinstance(max_pe_len, torch.Tensor):
257
+ max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
258
+ scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
259
+ freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
260
+
261
+ if self.custom_freqs == "linear":
262
+ freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
263
+ elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
264
+ freqs = 1.0 / torch.pow(
265
+ _find_newbase_ntk(dim, theta, scale).view(-1, 1),
266
+ freq_indices.to(scale),
267
+ ).squeeze()
268
+ elif self.custom_freqs == "ntk-by-parts":
269
+ beta_0, beta_1 = 1.25, 0.75
270
+ gamma_0, gamma_1 = 16, 2
271
+ freqs_base = 1.0 / (theta**freq_indices)
272
+ freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
273
+ freqs_ntk = 1.0 / torch.pow(
274
+ _find_newbase_ntk(dim, theta, scale).view(-1, 1),
275
+ freq_indices.to(scale),
276
+ ).squeeze()
277
+ low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
278
+ freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
279
+ freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
280
+ low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
281
+ freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
282
+ freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
283
+ elif self.custom_freqs == "yarn":
284
+ beta_fast, beta_slow = 32, 1
285
+ freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
286
+ freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
287
+ low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
288
+ freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
289
+ freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
290
+ else:
291
+ raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
292
+
293
+ if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
294
+ freqs = freqs.squeeze()
295
+ return freqs.to(default_dtype)
296
+
297
+ def _apply_magnitude_scaling(
298
+ self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
299
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
301
+ freqs_cos = freqs_cos * self.mscale
302
+ freqs_sin = freqs_sin * self.mscale
303
+ elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
304
+ freqs_cos = freqs_cos * self.proportion1
305
+ freqs_sin = freqs_sin * self.proportion1
306
+ elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
307
+ freqs_cos = freqs_cos * self.proportion2
308
+ freqs_sin = freqs_sin * self.proportion2
309
+ return freqs_cos, freqs_sin
310
+
311
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
312
  grids = []
313
  for height, width in image_sizes.tolist():
 
317
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
318
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
319
  grid = torch.cat(grids, dim=1)
320
+
321
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
322
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
323
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
324
+ freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
325
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
326
 
327
 
328
  class NiTAttention(nn.Module):
 
520
  )
521
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
522
 
523
+ def configure_rope_extrapolation(
524
+ self,
525
+ custom_freqs: str,
526
+ max_pe_len_h: int,
527
+ max_pe_len_w: int,
528
+ ori_max_pe_len: int,
529
+ decouple: bool = False,
530
+ theta: Optional[int] = None,
531
+ ) -> None:
532
+ """Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
533
+ theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
534
+ head_dim = self.config.hidden_size // self.config.num_heads
535
+ self.rope = NiTRotaryEmbedding(
536
+ head_dim,
537
+ custom_freqs=custom_freqs,
538
+ theta=theta,
539
+ max_pe_len_h=max_pe_len_h,
540
+ max_pe_len_w=max_pe_len_w,
541
+ decouple=decouple,
542
+ ori_max_pe_len=ori_max_pe_len,
543
+ )
544
+ for key, value in {
545
+ "custom_freqs": custom_freqs.lower(),
546
+ "max_pe_len_h": max_pe_len_h,
547
+ "max_pe_len_w": max_pe_len_w,
548
+ "decouple": decouple,
549
+ "ori_max_pe_len": ori_max_pe_len,
550
+ "theta": theta,
551
+ }.items():
552
+ if hasattr(self.config, key):
553
+ setattr(self.config, key, value)
554
+
555
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
556
  batch_size, channels, height, width = hidden_states.shape
557
  if channels != self.in_channels:
NiT-S/pipeline.py CHANGED
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
 
 
215
  ) -> None:
216
  if num_inference_steps < 1:
217
  raise ValueError("num_inference_steps must be >= 1.")
218
  if output_type not in {"pil", "np", "pt", "latent"}:
219
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  spatial_downsample = self._get_vae_spatial_downsample()
222
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
261
  )
262
  return packed_latents, image_sizes
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def _apply_classifier_free_guidance(
265
  self,
266
  model_output: torch.Tensor,
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
305
  num_inference_steps: int = 50,
306
  guidance_scale: float = 1.0,
307
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
 
 
 
308
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
309
  output_type: str = "pil",
310
  return_dict: bool = True,
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
325
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
326
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
327
  Flow-time interval where CFG is applied.
 
 
 
 
 
 
 
 
 
 
328
  generator (`torch.Generator`, *optional*):
329
  RNG for reproducibility.
330
  output_type (`str`, defaults to `"pil"`):
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
335
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
336
  height = int(height or default_size)
337
  width = int(width or default_size)
338
- self.check_inputs(height, width, num_inference_steps, output_type)
 
339
 
340
  device = self._execution_device
341
  model_dtype = next(self.transformer.parameters()).dtype
 
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
215
+ interpolation: Optional[str] = None,
216
+ ori_max_pe_len: Optional[int] = None,
217
  ) -> None:
218
  if num_inference_steps < 1:
219
  raise ValueError("num_inference_steps must be >= 1.")
220
  if output_type not in {"pil", "np", "pt", "latent"}:
221
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
222
+ if interpolation is not None and interpolation not in {
223
+ "no",
224
+ "linear",
225
+ "ntk-aware",
226
+ "ntk-by-parts",
227
+ "yarn",
228
+ "ntk-aware-pro1",
229
+ "ntk-aware-pro2",
230
+ "scale1",
231
+ "scale2",
232
+ }:
233
+ raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
234
+ if interpolation not in {None, "no"} and ori_max_pe_len is None:
235
+ raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
236
 
237
  spatial_downsample = self._get_vae_spatial_downsample()
238
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
 
277
  )
278
  return packed_latents, image_sizes
279
 
280
+ def _maybe_configure_rope_extrapolation(
281
+ self,
282
+ height: int,
283
+ width: int,
284
+ interpolation: Optional[str],
285
+ ori_max_pe_len: Optional[int],
286
+ decouple: bool,
287
+ ) -> None:
288
+ if interpolation in {None, "no"}:
289
+ return
290
+
291
+ spatial_downsample = self._get_vae_spatial_downsample()
292
+ patch_size = int(self.transformer.config.patch_size)
293
+ latent_h = height // spatial_downsample // patch_size
294
+ latent_w = width // spatial_downsample // patch_size
295
+ self.transformer.configure_rope_extrapolation(
296
+ custom_freqs=interpolation,
297
+ max_pe_len_h=latent_h,
298
+ max_pe_len_w=latent_w,
299
+ ori_max_pe_len=int(ori_max_pe_len),
300
+ decouple=decouple,
301
+ )
302
+
303
  def _apply_classifier_free_guidance(
304
  self,
305
  model_output: torch.Tensor,
 
344
  num_inference_steps: int = 50,
345
  guidance_scale: float = 1.0,
346
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
347
+ interpolation: Optional[str] = None,
348
+ ori_max_pe_len: Optional[int] = None,
349
+ decouple: bool = False,
350
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
351
  output_type: str = "pil",
352
  return_dict: bool = True,
 
367
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
368
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
369
  Flow-time interval where CFG is applied.
370
+ interpolation (`str`, *optional*):
371
+ VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
372
+ `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
373
+ `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
374
+ the transformer's configured RoPE.
375
+ ori_max_pe_len (`int`, *optional*):
376
+ Original maximum latent side length seen during training. Required when
377
+ `interpolation` is enabled.
378
+ decouple (`bool`, defaults to `False`):
379
+ Whether to decouple height and width when computing extrapolated RoPE frequencies.
380
  generator (`torch.Generator`, *optional*):
381
  RNG for reproducibility.
382
  output_type (`str`, defaults to `"pil"`):
 
387
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
388
  height = int(height or default_size)
389
  width = int(width or default_size)
390
+ self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
391
+ self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
392
 
393
  device = self._execution_device
394
  model_dtype = next(self.transformer.parameters()).dtype
NiT-S/transformer/nit_transformer_2d.py CHANGED
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
74
  return torch.get_default_dtype()
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class NiTPatchEmbed(nn.Module):
78
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
  super().__init__()
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
125
 
126
 
127
  class NiTRotaryEmbedding(nn.Module):
 
 
128
  def __init__(
129
  self,
130
  head_dim: int,
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
137
  ori_max_pe_len: Optional[int] = None,
138
  ):
139
  super().__init__()
140
- del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
- if custom_freqs not in {"normal", "scale1", "scale2"}:
142
  raise ValueError(
143
- "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
- "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
- "by changing the model config before loading."
146
  )
 
 
 
 
 
 
 
147
  dim = head_dim // 2
148
  if dim % 2 != 0:
149
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
 
 
 
 
 
 
150
  default_dtype = _get_float_dtype_or_default()
151
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
- self.register_buffer("freqs_h", freqs, persistent=False)
153
- self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
- positions = torch.arange(max_cached_len, dtype=default_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
  grids = []
162
  for height, width in image_sizes.tolist():
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
166
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
  grid = torch.cat(grids, dim=1)
 
169
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
- return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
 
173
 
174
 
175
  class NiTAttention(nn.Module):
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
367
  )
368
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
  batch_size, channels, height, width = hidden_states.shape
372
  if channels != self.in_channels:
 
74
  return torch.get_default_dtype()
75
 
76
 
77
+ # VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
78
+ def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
79
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
80
+
81
+
82
+ def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
83
+ low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
84
+ high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
85
+ return max(low, 0), min(high, dim - 1)
86
+
87
+
88
+ def _linear_ramp_mask(min_value, max_value, dim):
89
+ if min_value == max_value:
90
+ max_value += 0.001
91
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
92
+ return torch.clamp(linear_func, 0, 1)
93
+
94
+
95
+ def _find_newbase_ntk(dim, base=10000, scale=1):
96
+ return base * scale ** (dim / (dim - 2))
97
+
98
+
99
+ def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
100
+ return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
101
+
102
+
103
+ def _get_proportion(length_test, length_train):
104
+ length_test = length_test * 2
105
+ ratio = length_test / length_train
106
+ return torch.where(
107
+ torch.tensor(ratio) <= 1.0,
108
+ torch.tensor(1.0),
109
+ torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
110
+ )
111
+
112
+
113
+ TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
114
+ EXTRAPOLATION_ROPE_FREQS = {
115
+ "linear",
116
+ "ntk-aware",
117
+ "ntk-aware-pro1",
118
+ "ntk-aware-pro2",
119
+ "ntk-by-parts",
120
+ "yarn",
121
+ }
122
+ SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
123
+
124
+
125
  class NiTPatchEmbed(nn.Module):
126
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
127
  super().__init__()
 
173
 
174
 
175
  class NiTRotaryEmbedding(nn.Module):
176
+ """2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
177
+
178
  def __init__(
179
  self,
180
  head_dim: int,
 
187
  ori_max_pe_len: Optional[int] = None,
188
  ):
189
  super().__init__()
190
+ custom_freqs = custom_freqs.lower()
191
+ if custom_freqs not in SUPPORTED_ROPE_FREQS:
192
  raise ValueError(
193
+ f"Unsupported RoPE frequency variant {custom_freqs!r}. "
194
+ f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
 
195
  )
196
+ if custom_freqs not in TRAINED_ROPE_FREQS and (
197
+ max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
198
+ ):
199
+ raise ValueError(
200
+ f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
201
+ )
202
+
203
  dim = head_dim // 2
204
  if dim % 2 != 0:
205
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
206
+
207
+ self.dim = dim
208
+ self.custom_freqs = custom_freqs
209
+ self.theta = theta
210
+ self.decouple = decouple
211
+ self.ori_max_pe_len = ori_max_pe_len
212
  default_dtype = _get_float_dtype_or_default()
213
+
214
+ if custom_freqs in TRAINED_ROPE_FREQS:
215
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
216
+ freqs_h = freqs
217
+ freqs_w = freqs.clone()
218
+ else:
219
+ if decouple:
220
+ freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
221
+ freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
222
+ else:
223
+ max_pe_len = max(max_pe_len_h, max_pe_len_w)
224
+ freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
225
+ freqs_h = freqs
226
+ freqs_w = freqs.clone()
227
+
228
+ if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
229
+ scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
230
+ proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
231
+ proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
232
+ self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
233
+ self.register_buffer(
234
+ "proportion1",
235
+ proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
236
+ persistent=False,
237
+ )
238
+ self.register_buffer(
239
+ "proportion2",
240
+ proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
241
+ persistent=False,
242
+ )
243
+
244
+ self.register_buffer("freqs_h", freqs_h, persistent=False)
245
+ self.register_buffer("freqs_w", freqs_w, persistent=False)
246
+
247
+ cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
248
+ positions = torch.arange(cache_len, dtype=default_dtype)
249
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
250
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
251
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
252
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
253
 
254
+ def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
255
+ assert isinstance(ori_max_pe_len, int)
256
+ if not isinstance(max_pe_len, torch.Tensor):
257
+ max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
258
+ scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
259
+ freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
260
+
261
+ if self.custom_freqs == "linear":
262
+ freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
263
+ elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
264
+ freqs = 1.0 / torch.pow(
265
+ _find_newbase_ntk(dim, theta, scale).view(-1, 1),
266
+ freq_indices.to(scale),
267
+ ).squeeze()
268
+ elif self.custom_freqs == "ntk-by-parts":
269
+ beta_0, beta_1 = 1.25, 0.75
270
+ gamma_0, gamma_1 = 16, 2
271
+ freqs_base = 1.0 / (theta**freq_indices)
272
+ freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
273
+ freqs_ntk = 1.0 / torch.pow(
274
+ _find_newbase_ntk(dim, theta, scale).view(-1, 1),
275
+ freq_indices.to(scale),
276
+ ).squeeze()
277
+ low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
278
+ freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
279
+ freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
280
+ low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
281
+ freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
282
+ freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
283
+ elif self.custom_freqs == "yarn":
284
+ beta_fast, beta_slow = 32, 1
285
+ freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
286
+ freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
287
+ low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
288
+ freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
289
+ freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
290
+ else:
291
+ raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
292
+
293
+ if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
294
+ freqs = freqs.squeeze()
295
+ return freqs.to(default_dtype)
296
+
297
+ def _apply_magnitude_scaling(
298
+ self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
299
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
301
+ freqs_cos = freqs_cos * self.mscale
302
+ freqs_sin = freqs_sin * self.mscale
303
+ elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
304
+ freqs_cos = freqs_cos * self.proportion1
305
+ freqs_sin = freqs_sin * self.proportion1
306
+ elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
307
+ freqs_cos = freqs_cos * self.proportion2
308
+ freqs_sin = freqs_sin * self.proportion2
309
+ return freqs_cos, freqs_sin
310
+
311
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
312
  grids = []
313
  for height, width in image_sizes.tolist():
 
317
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
318
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
319
  grid = torch.cat(grids, dim=1)
320
+
321
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
322
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
323
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
324
+ freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
325
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
326
 
327
 
328
  class NiTAttention(nn.Module):
 
520
  )
521
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
522
 
523
+ def configure_rope_extrapolation(
524
+ self,
525
+ custom_freqs: str,
526
+ max_pe_len_h: int,
527
+ max_pe_len_w: int,
528
+ ori_max_pe_len: int,
529
+ decouple: bool = False,
530
+ theta: Optional[int] = None,
531
+ ) -> None:
532
+ """Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
533
+ theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
534
+ head_dim = self.config.hidden_size // self.config.num_heads
535
+ self.rope = NiTRotaryEmbedding(
536
+ head_dim,
537
+ custom_freqs=custom_freqs,
538
+ theta=theta,
539
+ max_pe_len_h=max_pe_len_h,
540
+ max_pe_len_w=max_pe_len_w,
541
+ decouple=decouple,
542
+ ori_max_pe_len=ori_max_pe_len,
543
+ )
544
+ for key, value in {
545
+ "custom_freqs": custom_freqs.lower(),
546
+ "max_pe_len_h": max_pe_len_h,
547
+ "max_pe_len_w": max_pe_len_w,
548
+ "decouple": decouple,
549
+ "ori_max_pe_len": ori_max_pe_len,
550
+ "theta": theta,
551
+ }.items():
552
+ if hasattr(self.config, key):
553
+ setattr(self.config, key, value)
554
+
555
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
556
  batch_size, channels, height, width = hidden_states.shape
557
  if channels != self.in_channels:
NiT-XL/pipeline.py CHANGED
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
 
 
215
  ) -> None:
216
  if num_inference_steps < 1:
217
  raise ValueError("num_inference_steps must be >= 1.")
218
  if output_type not in {"pil", "np", "pt", "latent"}:
219
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  spatial_downsample = self._get_vae_spatial_downsample()
222
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
261
  )
262
  return packed_latents, image_sizes
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def _apply_classifier_free_guidance(
265
  self,
266
  model_output: torch.Tensor,
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
305
  num_inference_steps: int = 50,
306
  guidance_scale: float = 1.0,
307
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
 
 
 
308
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
309
  output_type: str = "pil",
310
  return_dict: bool = True,
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
325
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
326
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
327
  Flow-time interval where CFG is applied.
 
 
 
 
 
 
 
 
 
 
328
  generator (`torch.Generator`, *optional*):
329
  RNG for reproducibility.
330
  output_type (`str`, defaults to `"pil"`):
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
335
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
336
  height = int(height or default_size)
337
  width = int(width or default_size)
338
- self.check_inputs(height, width, num_inference_steps, output_type)
 
339
 
340
  device = self._execution_device
341
  model_dtype = next(self.transformer.parameters()).dtype
 
212
  width: int,
213
  num_inference_steps: int,
214
  output_type: str,
215
+ interpolation: Optional[str] = None,
216
+ ori_max_pe_len: Optional[int] = None,
217
  ) -> None:
218
  if num_inference_steps < 1:
219
  raise ValueError("num_inference_steps must be >= 1.")
220
  if output_type not in {"pil", "np", "pt", "latent"}:
221
  raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
222
+ if interpolation is not None and interpolation not in {
223
+ "no",
224
+ "linear",
225
+ "ntk-aware",
226
+ "ntk-by-parts",
227
+ "yarn",
228
+ "ntk-aware-pro1",
229
+ "ntk-aware-pro2",
230
+ "scale1",
231
+ "scale2",
232
+ }:
233
+ raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
234
+ if interpolation not in {None, "no"} and ori_max_pe_len is None:
235
+ raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
236
 
237
  spatial_downsample = self._get_vae_spatial_downsample()
238
  if height % spatial_downsample != 0 or width % spatial_downsample != 0:
 
277
  )
278
  return packed_latents, image_sizes
279
 
280
def _maybe_configure_rope_extrapolation(
    self,
    height: int,
    width: int,
    interpolation: Optional[str],
    ori_max_pe_len: Optional[int],
    decouple: bool,
) -> None:
    """Reconfigure the transformer's RoPE for the requested output size, if asked to.

    Does nothing when ``interpolation`` is ``None`` or ``"no"``. Otherwise the
    pixel size is converted to a per-side token count (VAE spatial downsample,
    then transformer patch size, both by floor division) and forwarded to
    ``self.transformer.configure_rope_extrapolation``.

    Args:
        height: Requested image height in pixels.
        width: Requested image width in pixels.
        interpolation: Extrapolation mode name, or ``None`` / ``"no"`` to skip.
        ori_max_pe_len: Original maximum side length; converted with ``int()``
            before being forwarded, so it must not be ``None`` when a mode is set.
        decouple: Forwarded unchanged to the transformer.
    """
    disabled_modes = {None, "no"}
    if interpolation in disabled_modes:
        return

    downsample = self._get_vae_spatial_downsample()
    patch = int(self.transformer.config.patch_size)

    def tokens_along(pixels: int) -> int:
        # Two successive floor divisions: pixels -> latent cells -> patches.
        return pixels // downsample // patch

    rope_kwargs = {
        "custom_freqs": interpolation,
        "max_pe_len_h": tokens_along(height),
        "max_pe_len_w": tokens_along(width),
        "ori_max_pe_len": int(ori_max_pe_len),
        "decouple": decouple,
    }
    self.transformer.configure_rope_extrapolation(**rope_kwargs)
302
+
303
  def _apply_classifier_free_guidance(
304
  self,
305
  model_output: torch.Tensor,
 
344
  num_inference_steps: int = 50,
345
  guidance_scale: float = 1.0,
346
  guidance_interval: Tuple[float, float] = (0.0, 1.0),
347
+ interpolation: Optional[str] = None,
348
+ ori_max_pe_len: Optional[int] = None,
349
+ decouple: bool = False,
350
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
351
  output_type: str = "pil",
352
  return_dict: bool = True,
 
367
  Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
368
  guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
369
  Flow-time interval where CFG is applied.
370
+ interpolation (`str`, *optional*):
371
+ VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
372
+ `"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
373
+ `"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
374
+ the transformer's configured RoPE.
375
+ ori_max_pe_len (`int`, *optional*):
376
+ Original maximum latent side length seen during training. Required when
377
+ `interpolation` is enabled.
378
+ decouple (`bool`, defaults to `False`):
379
+ Whether to decouple height and width when computing extrapolated RoPE frequencies.
380
  generator (`torch.Generator`, *optional*):
381
  RNG for reproducibility.
382
  output_type (`str`, defaults to `"pil"`):
 
387
  default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
388
  height = int(height or default_size)
389
  width = int(width or default_size)
390
+ self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
391
+ self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
392
 
393
  device = self._execution_device
394
  model_dtype = next(self.transformer.parameters()).dtype
NiT-XL/transformer/nit_transformer_2d.py CHANGED
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
74
  return torch.get_default_dtype()
75
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class NiTPatchEmbed(nn.Module):
78
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
  super().__init__()
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
125
 
126
 
127
  class NiTRotaryEmbedding(nn.Module):
 
 
128
  def __init__(
129
  self,
130
  head_dim: int,
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
137
  ori_max_pe_len: Optional[int] = None,
138
  ):
139
  super().__init__()
140
- del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
- if custom_freqs not in {"normal", "scale1", "scale2"}:
142
  raise ValueError(
143
- "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
- "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
- "by changing the model config before loading."
146
  )
 
 
 
 
 
 
 
147
  dim = head_dim // 2
148
  if dim % 2 != 0:
149
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
 
 
 
 
 
 
150
  default_dtype = _get_float_dtype_or_default()
151
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
- self.register_buffer("freqs_h", freqs, persistent=False)
153
- self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
- positions = torch.arange(max_cached_len, dtype=default_dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
  grids = []
162
  for height, width in image_sizes.tolist():
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
166
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
  grid = torch.cat(grids, dim=1)
 
169
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
- return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
 
173
 
174
 
175
  class NiTAttention(nn.Module):
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
367
  )
368
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
  batch_size, channels, height, width = hidden_states.shape
372
  if channels != self.in_channels:
 
74
  return torch.get_default_dtype()
75
 
76
 
77
+ # VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
78
def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return ``dim * ln(L / (2*pi*r)) / (2 * ln(base))`` as a float.

    Here ``L`` is ``max_position_embeddings`` and ``r`` is ``num_rotations``.
    NOTE(review): in YaRN-style scaling this is presumably the (fractional)
    rotary dimension index whose frequency completes ``r`` rotations over
    ``L`` positions -- confirm against the reference implementation.
    """
    positions_per_rotation = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(positions_per_rotation)
    denominator = 2 * math.log(base)
    return numerator / denominator
80
+
81
+
82
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the integer ``(low, high)`` index range spanned by two rotation counts.

    The correction factor of ``low_rot`` is floored and that of ``high_rot`` is
    ceiled; the pair is then clamped so that ``low >= 0`` and ``high <= dim - 1``.
    """
    raw_low, raw_high = (
        _find_correction_factor(rotations, dim, base, max_position_embeddings)
        for rotations in (low_rot, high_rot)
    )
    clamped_low = max(math.floor(raw_low), 0)
    clamped_high = min(math.ceil(raw_high), dim - 1)
    return clamped_low, clamped_high
86
+
87
+
88
def _linear_ramp_mask(min_value, max_value, dim):
    """Return a float32 ramp of length ``dim`` rising linearly from 0 to 1.

    Entry ``i`` equals ``(i - min_value) / (max_value - min_value)`` clamped to
    ``[0, 1]``. When both bounds are equal the upper one is widened by 0.001 so
    the division is well defined.
    """
    upper = max_value + 0.001 if min_value == max_value else max_value
    indices = torch.arange(dim, dtype=torch.float32)
    ramp = (indices - min_value) / (upper - min_value)
    return ramp.clamp(0, 1)
93
+
94
+
95
def _find_newbase_ntk(dim, base=10000, scale=1):
    """Return the rescaled rotary base ``base * scale ** (dim / (dim - 2))``.

    ``scale`` may be a number or a tensor; the result has the matching kind.
    """
    exponent = dim / (dim - 2)
    return base * scale**exponent
97
+
98
+
99
def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
    """Return the elementwise magnitude factor for ``scale``.

    The factor is ``1.0`` wherever ``scale <= 1.0`` and
    ``0.1 * ln(scale) + 1.0`` everywhere else.
    """
    not_stretched = scale <= 1.0
    log_boost = 0.1 * torch.log(scale) + 1.0
    return torch.where(not_stretched, torch.tensor(1.0), log_boost)
101
+
102
+
103
def _get_proportion(length_test, length_train):
    """Return a length-dependent scaling factor as a 0-dim tensor.

    ``length_test`` is doubled first. If the doubled value divided by
    ``length_train`` is at most 1 the factor is ``1.0``; otherwise it is
    ``sqrt(ln(2 * length_test) / ln(length_train))``.
    """
    doubled_test = length_test * 2
    ratio = doubled_test / length_train
    within_training_length = torch.tensor(ratio) <= 1.0
    log_test = torch.log(torch.tensor(doubled_test))
    log_train = torch.log(torch.tensor(length_train))
    return torch.where(
        within_training_length,
        torch.tensor(1.0),
        torch.sqrt(log_test / log_train),
    )
111
+
112
+
113
# Variants whose base frequencies are computed directly from `theta`.
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
# Variants that need max_pe_len_h, max_pe_len_w and ori_max_pe_len to rescale frequencies.
EXTRAPOLATION_ROPE_FREQS = {
    "yarn",
    "ntk-by-parts",
    "ntk-aware-pro2",
    "ntk-aware-pro1",
    "ntk-aware",
    "linear",
}
# Every value accepted for `custom_freqs`.
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS.union(EXTRAPOLATION_ROPE_FREQS)
123
+
124
+
125
  class NiTPatchEmbed(nn.Module):
126
  def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
127
  super().__init__()
 
173
 
174
 
175
  class NiTRotaryEmbedding(nn.Module):
176
+ """2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
177
+
178
  def __init__(
179
  self,
180
  head_dim: int,
 
187
  ori_max_pe_len: Optional[int] = None,
188
  ):
189
  super().__init__()
190
+ custom_freqs = custom_freqs.lower()
191
+ if custom_freqs not in SUPPORTED_ROPE_FREQS:
192
  raise ValueError(
193
+ f"Unsupported RoPE frequency variant {custom_freqs!r}. "
194
+ f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
 
195
  )
196
+ if custom_freqs not in TRAINED_ROPE_FREQS and (
197
+ max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
198
+ ):
199
+ raise ValueError(
200
+ f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
201
+ )
202
+
203
  dim = head_dim // 2
204
  if dim % 2 != 0:
205
  raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
206
+
207
+ self.dim = dim
208
+ self.custom_freqs = custom_freqs
209
+ self.theta = theta
210
+ self.decouple = decouple
211
+ self.ori_max_pe_len = ori_max_pe_len
212
  default_dtype = _get_float_dtype_or_default()
213
+
214
+ if custom_freqs in TRAINED_ROPE_FREQS:
215
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
216
+ freqs_h = freqs
217
+ freqs_w = freqs.clone()
218
+ else:
219
+ if decouple:
220
+ freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
221
+ freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
222
+ else:
223
+ max_pe_len = max(max_pe_len_h, max_pe_len_w)
224
+ freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
225
+ freqs_h = freqs
226
+ freqs_w = freqs.clone()
227
+
228
+ if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
229
+ scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
230
+ proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
231
+ proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
232
+ self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
233
+ self.register_buffer(
234
+ "proportion1",
235
+ proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
236
+ persistent=False,
237
+ )
238
+ self.register_buffer(
239
+ "proportion2",
240
+ proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
241
+ persistent=False,
242
+ )
243
+
244
+ self.register_buffer("freqs_h", freqs_h, persistent=False)
245
+ self.register_buffer("freqs_w", freqs_w, persistent=False)
246
+
247
+ cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
248
+ positions = torch.arange(cache_len, dtype=default_dtype)
249
  freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
250
  freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
251
  self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
252
  self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
253
 
254
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
    """Build the rotary frequencies for one axis under ``self.custom_freqs``.

    The stretch factor is ``max(max_pe_len / ori_max_pe_len, 1)``. Modes:

    * ``"linear"``: base frequencies divided by the stretch factor.
    * ``"ntk-aware"``, ``"ntk-aware-pro1"``, ``"ntk-aware-pro2"``: frequencies
      recomputed with the enlarged base from ``_find_newbase_ntk``.
    * ``"ntk-by-parts"``: linear and NTK frequencies blended by one ramp mask,
      then blended with the unscaled base frequencies by a second ramp mask.
    * ``"yarn"``: interpolated (divided by the stretch factor) and unscaled
      frequencies blended by a ramp mask built with ``beta_fast=32`` and
      ``beta_slow=1``.

    Args:
        theta: Rotary base.
        dim: Rotary dimension for this axis; frequencies use indices ``0, 2, ...``.
        max_pe_len: Target side length (number or tensor).
        ori_max_pe_len: Original side length; must be an ``int``.
        default_dtype: dtype used for intermediate tensors and the result.

    Returns:
        The frequency tensor converted to ``default_dtype``.

    Raises:
        TypeError: If ``ori_max_pe_len`` is not an ``int``.
        ValueError: If ``self.custom_freqs`` is not a known extrapolation mode.
    """
    # Explicit check rather than `assert`, which disappears under `python -O`.
    if not isinstance(ori_max_pe_len, int):
        raise TypeError(
            f"ori_max_pe_len must be an int, got {type(ori_max_pe_len).__name__}."
        )
    if not isinstance(max_pe_len, torch.Tensor):
        max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
    # Clamped so that targets shorter than the original length are left unscaled.
    scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
    freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim

    if self.custom_freqs == "linear":
        freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
    elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
        freqs = 1.0 / torch.pow(
            _find_newbase_ntk(dim, theta, scale).view(-1, 1),
            freq_indices.to(scale),
        ).squeeze()
    elif self.custom_freqs == "ntk-by-parts":
        beta_0, beta_1 = 1.25, 0.75
        gamma_0, gamma_1 = 16, 2
        freqs_base = 1.0 / (theta**freq_indices)
        freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
        freqs_ntk = 1.0 / torch.pow(
            _find_newbase_ntk(dim, theta, scale).view(-1, 1),
            freq_indices.to(scale),
        ).squeeze()
        # First blend: linear vs. NTK frequencies.
        low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
        freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
        freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
        # Second blend: result of the first blend vs. unscaled base frequencies.
        low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
        freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
        freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
    elif self.custom_freqs == "yarn":
        beta_fast, beta_slow = 32, 1
        freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
        freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
        low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
        freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float()
        freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
    else:
        raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")

    # Drop any leftover singleton dimensions from the broadcasted computations.
    if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
        freqs = freqs.squeeze()
    return freqs.to(default_dtype)
296
+
297
def _apply_magnitude_scaling(
    self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Multiply the cos/sin tables by the factor that belongs to the active mode.

    ``"yarn"`` uses ``self.mscale``; ``"ntk-aware-pro1"`` and ``"scale1"`` use
    ``self.proportion1``; ``"ntk-aware-pro2"`` and ``"scale2"`` use
    ``self.proportion2``. If the mode has no factor, or the corresponding
    attribute is absent, both inputs are returned untouched.
    """
    factor_attr_by_mode = {
        "yarn": "mscale",
        "ntk-aware-pro1": "proportion1",
        "scale1": "proportion1",
        "ntk-aware-pro2": "proportion2",
        "scale2": "proportion2",
    }
    factor_attr = factor_attr_by_mode.get(self.custom_freqs)
    if factor_attr is None or not hasattr(self, factor_attr):
        return freqs_cos, freqs_sin

    factor = getattr(self, factor_attr)
    return freqs_cos * factor, freqs_sin * factor
310
+
311
  def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
312
  grids = []
313
  for height, width in image_sizes.tolist():
 
317
  grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
318
  grids.append(torch.stack(grid, dim=0).reshape(2, -1))
319
  grid = torch.cat(grids, dim=1)
320
+
321
  freqs_h = self.freqs_h_cached.to(device)[grid[0]]
322
  freqs_w = self.freqs_w_cached.to(device)[grid[1]]
323
  freqs = torch.cat([freqs_h, freqs_w], dim=-1)
324
+ freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
325
+ return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
326
 
327
 
328
  class NiTAttention(nn.Module):
 
520
  )
521
  self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
522
 
523
def configure_rope_extrapolation(
    self,
    custom_freqs: str,
    max_pe_len_h: int,
    max_pe_len_w: int,
    ori_max_pe_len: int,
    decouple: bool = False,
    theta: Optional[int] = None,
) -> None:
    """Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference.

    Replaces ``self.rope`` with a freshly built ``NiTRotaryEmbedding`` and then
    copies the settings onto ``self.config`` for every key the config already has.

    Args:
        custom_freqs: RoPE frequency variant; stored lower-cased on the config.
        max_pe_len_h: Target side length along the height axis.
        max_pe_len_w: Target side length along the width axis.
        ori_max_pe_len: Original maximum side length.
        decouple: Forwarded to ``NiTRotaryEmbedding``.
        theta: Rotary base. When ``None``, ``self.config.theta`` is used,
            falling back to ``10000`` if the config has no such attribute.
    """
    if theta is None:
        theta = getattr(self.config, "theta", 10000)
    theta = int(theta)

    config = self.config
    head_dim = config.hidden_size // config.num_heads
    self.rope = NiTRotaryEmbedding(
        head_dim,
        custom_freqs=custom_freqs,
        theta=theta,
        max_pe_len_h=max_pe_len_h,
        max_pe_len_w=max_pe_len_w,
        decouple=decouple,
        ori_max_pe_len=ori_max_pe_len,
    )

    # Mirror the new settings on the config, but only for keys it already exposes.
    config_updates = (
        ("custom_freqs", custom_freqs.lower()),
        ("max_pe_len_h", max_pe_len_h),
        ("max_pe_len_w", max_pe_len_w),
        ("decouple", decouple),
        ("ori_max_pe_len", ori_max_pe_len),
        ("theta", theta),
    )
    for key, value in config_updates:
        if hasattr(config, key):
            setattr(config, key, value)
554
+
555
  def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
556
  batch_size, channels, height, width = hidden_states.shape
557
  if channels != self.in_channels: