Instructions to use BiliSakura/NiT-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/NiT-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("BiliSakura/NiT-diffusers", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- NiT-B/pipeline.py +54 -1
- NiT-B/transformer/nit_transformer_2d.py +195 -10
- NiT-L/pipeline.py +54 -1
- NiT-L/transformer/nit_transformer_2d.py +195 -10
- NiT-S/pipeline.py +54 -1
- NiT-S/transformer/nit_transformer_2d.py +195 -10
- NiT-XL/pipeline.py +54 -1
- NiT-XL/transformer/nit_transformer_2d.py +195 -10
NiT-B/pipeline.py
CHANGED
|
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
|
|
|
|
|
|
| 215 |
) -> None:
|
| 216 |
if num_inference_steps < 1:
|
| 217 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 218 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 219 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 222 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 261 |
)
|
| 262 |
return packed_latents, image_sizes
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _apply_classifier_free_guidance(
|
| 265 |
self,
|
| 266 |
model_output: torch.Tensor,
|
|
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 305 |
num_inference_steps: int = 50,
|
| 306 |
guidance_scale: float = 1.0,
|
| 307 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
|
|
|
|
|
|
|
|
|
| 308 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 309 |
output_type: str = "pil",
|
| 310 |
return_dict: bool = True,
|
|
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 325 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 326 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 327 |
Flow-time interval where CFG is applied.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
generator (`torch.Generator`, *optional*):
|
| 329 |
RNG for reproducibility.
|
| 330 |
output_type (`str`, defaults to `"pil"`):
|
|
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 335 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 336 |
height = int(height or default_size)
|
| 337 |
width = int(width or default_size)
|
| 338 |
-
self.check_inputs(height, width, num_inference_steps, output_type)
|
|
|
|
| 339 |
|
| 340 |
device = self._execution_device
|
| 341 |
model_dtype = next(self.transformer.parameters()).dtype
|
|
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
| 215 |
+
interpolation: Optional[str] = None,
|
| 216 |
+
ori_max_pe_len: Optional[int] = None,
|
| 217 |
) -> None:
|
| 218 |
if num_inference_steps < 1:
|
| 219 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 220 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 221 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 222 |
+
if interpolation is not None and interpolation not in {
|
| 223 |
+
"no",
|
| 224 |
+
"linear",
|
| 225 |
+
"ntk-aware",
|
| 226 |
+
"ntk-by-parts",
|
| 227 |
+
"yarn",
|
| 228 |
+
"ntk-aware-pro1",
|
| 229 |
+
"ntk-aware-pro2",
|
| 230 |
+
"scale1",
|
| 231 |
+
"scale2",
|
| 232 |
+
}:
|
| 233 |
+
raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
|
| 234 |
+
if interpolation not in {None, "no"} and ori_max_pe_len is None:
|
| 235 |
+
raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
|
| 236 |
|
| 237 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 238 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
|
|
| 277 |
)
|
| 278 |
return packed_latents, image_sizes
|
| 279 |
|
| 280 |
+
def _maybe_configure_rope_extrapolation(
|
| 281 |
+
self,
|
| 282 |
+
height: int,
|
| 283 |
+
width: int,
|
| 284 |
+
interpolation: Optional[str],
|
| 285 |
+
ori_max_pe_len: Optional[int],
|
| 286 |
+
decouple: bool,
|
| 287 |
+
) -> None:
|
| 288 |
+
if interpolation in {None, "no"}:
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
spatial_downsample = self._get_vae_spatial_downsample()
|
| 292 |
+
patch_size = int(self.transformer.config.patch_size)
|
| 293 |
+
latent_h = height // spatial_downsample // patch_size
|
| 294 |
+
latent_w = width // spatial_downsample // patch_size
|
| 295 |
+
self.transformer.configure_rope_extrapolation(
|
| 296 |
+
custom_freqs=interpolation,
|
| 297 |
+
max_pe_len_h=latent_h,
|
| 298 |
+
max_pe_len_w=latent_w,
|
| 299 |
+
ori_max_pe_len=int(ori_max_pe_len),
|
| 300 |
+
decouple=decouple,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
def _apply_classifier_free_guidance(
|
| 304 |
self,
|
| 305 |
model_output: torch.Tensor,
|
|
|
|
| 344 |
num_inference_steps: int = 50,
|
| 345 |
guidance_scale: float = 1.0,
|
| 346 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
| 347 |
+
interpolation: Optional[str] = None,
|
| 348 |
+
ori_max_pe_len: Optional[int] = None,
|
| 349 |
+
decouple: bool = False,
|
| 350 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 351 |
output_type: str = "pil",
|
| 352 |
return_dict: bool = True,
|
|
|
|
| 367 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 368 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 369 |
Flow-time interval where CFG is applied.
|
| 370 |
+
interpolation (`str`, *optional*):
|
| 371 |
+
VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
|
| 372 |
+
`"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
|
| 373 |
+
`"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
|
| 374 |
+
the transformer's configured RoPE.
|
| 375 |
+
ori_max_pe_len (`int`, *optional*):
|
| 376 |
+
Original maximum latent side length seen during training. Required when
|
| 377 |
+
`interpolation` is enabled.
|
| 378 |
+
decouple (`bool`, defaults to `False`):
|
| 379 |
+
Whether to decouple height and width when computing extrapolated RoPE frequencies.
|
| 380 |
generator (`torch.Generator`, *optional*):
|
| 381 |
RNG for reproducibility.
|
| 382 |
output_type (`str`, defaults to `"pil"`):
|
|
|
|
| 387 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 388 |
height = int(height or default_size)
|
| 389 |
width = int(width or default_size)
|
| 390 |
+
self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
|
| 391 |
+
self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
|
| 392 |
|
| 393 |
device = self._execution_device
|
| 394 |
model_dtype = next(self.transformer.parameters()).dtype
|
NiT-B/transformer/nit_transformer_2d.py
CHANGED
|
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class NiTPatchEmbed(nn.Module):
|
| 78 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 79 |
super().__init__()
|
|
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
|
|
| 125 |
|
| 126 |
|
| 127 |
class NiTRotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
| 128 |
def __init__(
|
| 129 |
self,
|
| 130 |
head_dim: int,
|
|
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 137 |
ori_max_pe_len: Optional[int] = None,
|
| 138 |
):
|
| 139 |
super().__init__()
|
| 140 |
-
|
| 141 |
-
if custom_freqs not in
|
| 142 |
raise ValueError(
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
"by changing the model config before loading."
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
dim = head_dim // 2
|
| 148 |
if dim % 2 != 0:
|
| 149 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
default_dtype = _get_float_dtype_or_default()
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 156 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 157 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 158 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
grids = []
|
| 162 |
for height, width in image_sizes.tolist():
|
|
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 166 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 167 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 168 |
grid = torch.cat(grids, dim=1)
|
|
|
|
| 169 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 170 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 171 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 172 |
-
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTAttention(nn.Module):
|
|
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
|
|
| 367 |
)
|
| 368 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 371 |
batch_size, channels, height, width = hidden_states.shape
|
| 372 |
if channels != self.in_channels:
|
|
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
| 77 |
+
# VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
|
| 78 |
+
def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
| 79 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
| 83 |
+
low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
|
| 84 |
+
high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
|
| 85 |
+
return max(low, 0), min(high, dim - 1)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _linear_ramp_mask(min_value, max_value, dim):
|
| 89 |
+
if min_value == max_value:
|
| 90 |
+
max_value += 0.001
|
| 91 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
|
| 92 |
+
return torch.clamp(linear_func, 0, 1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _find_newbase_ntk(dim, base=10000, scale=1):
|
| 96 |
+
return base * scale ** (dim / (dim - 2))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_proportion(length_test, length_train):
|
| 104 |
+
length_test = length_test * 2
|
| 105 |
+
ratio = length_test / length_train
|
| 106 |
+
return torch.where(
|
| 107 |
+
torch.tensor(ratio) <= 1.0,
|
| 108 |
+
torch.tensor(1.0),
|
| 109 |
+
torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
|
| 114 |
+
EXTRAPOLATION_ROPE_FREQS = {
|
| 115 |
+
"linear",
|
| 116 |
+
"ntk-aware",
|
| 117 |
+
"ntk-aware-pro1",
|
| 118 |
+
"ntk-aware-pro2",
|
| 119 |
+
"ntk-by-parts",
|
| 120 |
+
"yarn",
|
| 121 |
+
}
|
| 122 |
+
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
|
| 123 |
+
|
| 124 |
+
|
| 125 |
class NiTPatchEmbed(nn.Module):
|
| 126 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 127 |
super().__init__()
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTRotaryEmbedding(nn.Module):
|
| 176 |
+
"""2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
|
| 177 |
+
|
| 178 |
def __init__(
|
| 179 |
self,
|
| 180 |
head_dim: int,
|
|
|
|
| 187 |
ori_max_pe_len: Optional[int] = None,
|
| 188 |
):
|
| 189 |
super().__init__()
|
| 190 |
+
custom_freqs = custom_freqs.lower()
|
| 191 |
+
if custom_freqs not in SUPPORTED_ROPE_FREQS:
|
| 192 |
raise ValueError(
|
| 193 |
+
f"Unsupported RoPE frequency variant {custom_freqs!r}. "
|
| 194 |
+
f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
|
|
|
|
| 195 |
)
|
| 196 |
+
if custom_freqs not in TRAINED_ROPE_FREQS and (
|
| 197 |
+
max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
|
| 198 |
+
):
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
dim = head_dim // 2
|
| 204 |
if dim % 2 != 0:
|
| 205 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
| 206 |
+
|
| 207 |
+
self.dim = dim
|
| 208 |
+
self.custom_freqs = custom_freqs
|
| 209 |
+
self.theta = theta
|
| 210 |
+
self.decouple = decouple
|
| 211 |
+
self.ori_max_pe_len = ori_max_pe_len
|
| 212 |
default_dtype = _get_float_dtype_or_default()
|
| 213 |
+
|
| 214 |
+
if custom_freqs in TRAINED_ROPE_FREQS:
|
| 215 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
|
| 216 |
+
freqs_h = freqs
|
| 217 |
+
freqs_w = freqs.clone()
|
| 218 |
+
else:
|
| 219 |
+
if decouple:
|
| 220 |
+
freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
|
| 221 |
+
freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
|
| 222 |
+
else:
|
| 223 |
+
max_pe_len = max(max_pe_len_h, max_pe_len_w)
|
| 224 |
+
freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
|
| 225 |
+
freqs_h = freqs
|
| 226 |
+
freqs_w = freqs.clone()
|
| 227 |
+
|
| 228 |
+
if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
|
| 229 |
+
scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
|
| 230 |
+
proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
|
| 231 |
+
proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
|
| 232 |
+
self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
|
| 233 |
+
self.register_buffer(
|
| 234 |
+
"proportion1",
|
| 235 |
+
proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
|
| 236 |
+
persistent=False,
|
| 237 |
+
)
|
| 238 |
+
self.register_buffer(
|
| 239 |
+
"proportion2",
|
| 240 |
+
proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
|
| 241 |
+
persistent=False,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.register_buffer("freqs_h", freqs_h, persistent=False)
|
| 245 |
+
self.register_buffer("freqs_w", freqs_w, persistent=False)
|
| 246 |
+
|
| 247 |
+
cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
|
| 248 |
+
positions = torch.arange(cache_len, dtype=default_dtype)
|
| 249 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 250 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 251 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 252 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 253 |
|
| 254 |
+
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
|
| 255 |
+
assert isinstance(ori_max_pe_len, int)
|
| 256 |
+
if not isinstance(max_pe_len, torch.Tensor):
|
| 257 |
+
max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
|
| 258 |
+
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 259 |
+
freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
|
| 260 |
+
|
| 261 |
+
if self.custom_freqs == "linear":
|
| 262 |
+
freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
|
| 263 |
+
elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
|
| 264 |
+
freqs = 1.0 / torch.pow(
|
| 265 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 266 |
+
freq_indices.to(scale),
|
| 267 |
+
).squeeze()
|
| 268 |
+
elif self.custom_freqs == "ntk-by-parts":
|
| 269 |
+
beta_0, beta_1 = 1.25, 0.75
|
| 270 |
+
gamma_0, gamma_1 = 16, 2
|
| 271 |
+
freqs_base = 1.0 / (theta**freq_indices)
|
| 272 |
+
freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 273 |
+
freqs_ntk = 1.0 / torch.pow(
|
| 274 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 275 |
+
freq_indices.to(scale),
|
| 276 |
+
).squeeze()
|
| 277 |
+
low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 278 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 279 |
+
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 280 |
+
low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 281 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 282 |
+
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 283 |
+
elif self.custom_freqs == "yarn":
|
| 284 |
+
beta_fast, beta_slow = 32, 1
|
| 285 |
+
freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
|
| 286 |
+
freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 287 |
+
low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
|
| 288 |
+
freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
|
| 289 |
+
freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
|
| 290 |
+
else:
|
| 291 |
+
raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
|
| 292 |
+
|
| 293 |
+
if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
|
| 294 |
+
freqs = freqs.squeeze()
|
| 295 |
+
return freqs.to(default_dtype)
|
| 296 |
+
|
| 297 |
+
def _apply_magnitude_scaling(
|
| 298 |
+
self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
| 299 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 300 |
+
if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
|
| 301 |
+
freqs_cos = freqs_cos * self.mscale
|
| 302 |
+
freqs_sin = freqs_sin * self.mscale
|
| 303 |
+
elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
|
| 304 |
+
freqs_cos = freqs_cos * self.proportion1
|
| 305 |
+
freqs_sin = freqs_sin * self.proportion1
|
| 306 |
+
elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
|
| 307 |
+
freqs_cos = freqs_cos * self.proportion2
|
| 308 |
+
freqs_sin = freqs_sin * self.proportion2
|
| 309 |
+
return freqs_cos, freqs_sin
|
| 310 |
+
|
| 311 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 312 |
grids = []
|
| 313 |
for height, width in image_sizes.tolist():
|
|
|
|
| 317 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 318 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 319 |
grid = torch.cat(grids, dim=1)
|
| 320 |
+
|
| 321 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 322 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 323 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 324 |
+
freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
|
| 325 |
+
return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
|
| 326 |
|
| 327 |
|
| 328 |
class NiTAttention(nn.Module):
|
|
|
|
| 520 |
)
|
| 521 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 522 |
|
| 523 |
+
def configure_rope_extrapolation(
|
| 524 |
+
self,
|
| 525 |
+
custom_freqs: str,
|
| 526 |
+
max_pe_len_h: int,
|
| 527 |
+
max_pe_len_w: int,
|
| 528 |
+
ori_max_pe_len: int,
|
| 529 |
+
decouple: bool = False,
|
| 530 |
+
theta: Optional[int] = None,
|
| 531 |
+
) -> None:
|
| 532 |
+
"""Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
|
| 533 |
+
theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
|
| 534 |
+
head_dim = self.config.hidden_size // self.config.num_heads
|
| 535 |
+
self.rope = NiTRotaryEmbedding(
|
| 536 |
+
head_dim,
|
| 537 |
+
custom_freqs=custom_freqs,
|
| 538 |
+
theta=theta,
|
| 539 |
+
max_pe_len_h=max_pe_len_h,
|
| 540 |
+
max_pe_len_w=max_pe_len_w,
|
| 541 |
+
decouple=decouple,
|
| 542 |
+
ori_max_pe_len=ori_max_pe_len,
|
| 543 |
+
)
|
| 544 |
+
for key, value in {
|
| 545 |
+
"custom_freqs": custom_freqs.lower(),
|
| 546 |
+
"max_pe_len_h": max_pe_len_h,
|
| 547 |
+
"max_pe_len_w": max_pe_len_w,
|
| 548 |
+
"decouple": decouple,
|
| 549 |
+
"ori_max_pe_len": ori_max_pe_len,
|
| 550 |
+
"theta": theta,
|
| 551 |
+
}.items():
|
| 552 |
+
if hasattr(self.config, key):
|
| 553 |
+
setattr(self.config, key, value)
|
| 554 |
+
|
| 555 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 556 |
batch_size, channels, height, width = hidden_states.shape
|
| 557 |
if channels != self.in_channels:
|
NiT-L/pipeline.py
CHANGED
|
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
|
|
|
|
|
|
| 215 |
) -> None:
|
| 216 |
if num_inference_steps < 1:
|
| 217 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 218 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 219 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 222 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 261 |
)
|
| 262 |
return packed_latents, image_sizes
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _apply_classifier_free_guidance(
|
| 265 |
self,
|
| 266 |
model_output: torch.Tensor,
|
|
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 305 |
num_inference_steps: int = 50,
|
| 306 |
guidance_scale: float = 1.0,
|
| 307 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
|
|
|
|
|
|
|
|
|
| 308 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 309 |
output_type: str = "pil",
|
| 310 |
return_dict: bool = True,
|
|
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 325 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 326 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 327 |
Flow-time interval where CFG is applied.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
generator (`torch.Generator`, *optional*):
|
| 329 |
RNG for reproducibility.
|
| 330 |
output_type (`str`, defaults to `"pil"`):
|
|
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 335 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 336 |
height = int(height or default_size)
|
| 337 |
width = int(width or default_size)
|
| 338 |
-
self.check_inputs(height, width, num_inference_steps, output_type)
|
|
|
|
| 339 |
|
| 340 |
device = self._execution_device
|
| 341 |
model_dtype = next(self.transformer.parameters()).dtype
|
|
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
| 215 |
+
interpolation: Optional[str] = None,
|
| 216 |
+
ori_max_pe_len: Optional[int] = None,
|
| 217 |
) -> None:
|
| 218 |
if num_inference_steps < 1:
|
| 219 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 220 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 221 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 222 |
+
if interpolation is not None and interpolation not in {
|
| 223 |
+
"no",
|
| 224 |
+
"linear",
|
| 225 |
+
"ntk-aware",
|
| 226 |
+
"ntk-by-parts",
|
| 227 |
+
"yarn",
|
| 228 |
+
"ntk-aware-pro1",
|
| 229 |
+
"ntk-aware-pro2",
|
| 230 |
+
"scale1",
|
| 231 |
+
"scale2",
|
| 232 |
+
}:
|
| 233 |
+
raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
|
| 234 |
+
if interpolation not in {None, "no"} and ori_max_pe_len is None:
|
| 235 |
+
raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
|
| 236 |
|
| 237 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 238 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
|
|
| 277 |
)
|
| 278 |
return packed_latents, image_sizes
|
| 279 |
|
| 280 |
+
def _maybe_configure_rope_extrapolation(
|
| 281 |
+
self,
|
| 282 |
+
height: int,
|
| 283 |
+
width: int,
|
| 284 |
+
interpolation: Optional[str],
|
| 285 |
+
ori_max_pe_len: Optional[int],
|
| 286 |
+
decouple: bool,
|
| 287 |
+
) -> None:
|
| 288 |
+
if interpolation in {None, "no"}:
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
spatial_downsample = self._get_vae_spatial_downsample()
|
| 292 |
+
patch_size = int(self.transformer.config.patch_size)
|
| 293 |
+
latent_h = height // spatial_downsample // patch_size
|
| 294 |
+
latent_w = width // spatial_downsample // patch_size
|
| 295 |
+
self.transformer.configure_rope_extrapolation(
|
| 296 |
+
custom_freqs=interpolation,
|
| 297 |
+
max_pe_len_h=latent_h,
|
| 298 |
+
max_pe_len_w=latent_w,
|
| 299 |
+
ori_max_pe_len=int(ori_max_pe_len),
|
| 300 |
+
decouple=decouple,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
def _apply_classifier_free_guidance(
|
| 304 |
self,
|
| 305 |
model_output: torch.Tensor,
|
|
|
|
| 344 |
num_inference_steps: int = 50,
|
| 345 |
guidance_scale: float = 1.0,
|
| 346 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
| 347 |
+
interpolation: Optional[str] = None,
|
| 348 |
+
ori_max_pe_len: Optional[int] = None,
|
| 349 |
+
decouple: bool = False,
|
| 350 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 351 |
output_type: str = "pil",
|
| 352 |
return_dict: bool = True,
|
|
|
|
| 367 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 368 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 369 |
Flow-time interval where CFG is applied.
|
| 370 |
+
interpolation (`str`, *optional*):
|
| 371 |
+
VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
|
| 372 |
+
`"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
|
| 373 |
+
`"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
|
| 374 |
+
the transformer's configured RoPE.
|
| 375 |
+
ori_max_pe_len (`int`, *optional*):
|
| 376 |
+
Original maximum latent side length seen during training. Required when
|
| 377 |
+
`interpolation` is enabled.
|
| 378 |
+
decouple (`bool`, defaults to `False`):
|
| 379 |
+
Whether to decouple height and width when computing extrapolated RoPE frequencies.
|
| 380 |
generator (`torch.Generator`, *optional*):
|
| 381 |
RNG for reproducibility.
|
| 382 |
output_type (`str`, defaults to `"pil"`):
|
|
|
|
| 387 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 388 |
height = int(height or default_size)
|
| 389 |
width = int(width or default_size)
|
| 390 |
+
self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
|
| 391 |
+
self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
|
| 392 |
|
| 393 |
device = self._execution_device
|
| 394 |
model_dtype = next(self.transformer.parameters()).dtype
|
NiT-L/transformer/nit_transformer_2d.py
CHANGED
|
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class NiTPatchEmbed(nn.Module):
|
| 78 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 79 |
super().__init__()
|
|
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
|
|
| 125 |
|
| 126 |
|
| 127 |
class NiTRotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
| 128 |
def __init__(
|
| 129 |
self,
|
| 130 |
head_dim: int,
|
|
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 137 |
ori_max_pe_len: Optional[int] = None,
|
| 138 |
):
|
| 139 |
super().__init__()
|
| 140 |
-
|
| 141 |
-
if custom_freqs not in
|
| 142 |
raise ValueError(
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
"by changing the model config before loading."
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
dim = head_dim // 2
|
| 148 |
if dim % 2 != 0:
|
| 149 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
default_dtype = _get_float_dtype_or_default()
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 156 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 157 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 158 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
grids = []
|
| 162 |
for height, width in image_sizes.tolist():
|
|
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 166 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 167 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 168 |
grid = torch.cat(grids, dim=1)
|
|
|
|
| 169 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 170 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 171 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 172 |
-
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTAttention(nn.Module):
|
|
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
|
|
| 367 |
)
|
| 368 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 371 |
batch_size, channels, height, width = hidden_states.shape
|
| 372 |
if channels != self.in_channels:
|
|
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
| 77 |
+
# VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
|
| 78 |
+
def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
| 79 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
| 83 |
+
low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
|
| 84 |
+
high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
|
| 85 |
+
return max(low, 0), min(high, dim - 1)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _linear_ramp_mask(min_value, max_value, dim):
|
| 89 |
+
if min_value == max_value:
|
| 90 |
+
max_value += 0.001
|
| 91 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
|
| 92 |
+
return torch.clamp(linear_func, 0, 1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _find_newbase_ntk(dim, base=10000, scale=1):
|
| 96 |
+
return base * scale ** (dim / (dim - 2))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_proportion(length_test, length_train):
|
| 104 |
+
length_test = length_test * 2
|
| 105 |
+
ratio = length_test / length_train
|
| 106 |
+
return torch.where(
|
| 107 |
+
torch.tensor(ratio) <= 1.0,
|
| 108 |
+
torch.tensor(1.0),
|
| 109 |
+
torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
|
| 114 |
+
EXTRAPOLATION_ROPE_FREQS = {
|
| 115 |
+
"linear",
|
| 116 |
+
"ntk-aware",
|
| 117 |
+
"ntk-aware-pro1",
|
| 118 |
+
"ntk-aware-pro2",
|
| 119 |
+
"ntk-by-parts",
|
| 120 |
+
"yarn",
|
| 121 |
+
}
|
| 122 |
+
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
|
| 123 |
+
|
| 124 |
+
|
| 125 |
class NiTPatchEmbed(nn.Module):
|
| 126 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 127 |
super().__init__()
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTRotaryEmbedding(nn.Module):
|
| 176 |
+
"""2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
|
| 177 |
+
|
| 178 |
def __init__(
|
| 179 |
self,
|
| 180 |
head_dim: int,
|
|
|
|
| 187 |
ori_max_pe_len: Optional[int] = None,
|
| 188 |
):
|
| 189 |
super().__init__()
|
| 190 |
+
custom_freqs = custom_freqs.lower()
|
| 191 |
+
if custom_freqs not in SUPPORTED_ROPE_FREQS:
|
| 192 |
raise ValueError(
|
| 193 |
+
f"Unsupported RoPE frequency variant {custom_freqs!r}. "
|
| 194 |
+
f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
|
|
|
|
| 195 |
)
|
| 196 |
+
if custom_freqs not in TRAINED_ROPE_FREQS and (
|
| 197 |
+
max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
|
| 198 |
+
):
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
dim = head_dim // 2
|
| 204 |
if dim % 2 != 0:
|
| 205 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
| 206 |
+
|
| 207 |
+
self.dim = dim
|
| 208 |
+
self.custom_freqs = custom_freqs
|
| 209 |
+
self.theta = theta
|
| 210 |
+
self.decouple = decouple
|
| 211 |
+
self.ori_max_pe_len = ori_max_pe_len
|
| 212 |
default_dtype = _get_float_dtype_or_default()
|
| 213 |
+
|
| 214 |
+
if custom_freqs in TRAINED_ROPE_FREQS:
|
| 215 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
|
| 216 |
+
freqs_h = freqs
|
| 217 |
+
freqs_w = freqs.clone()
|
| 218 |
+
else:
|
| 219 |
+
if decouple:
|
| 220 |
+
freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
|
| 221 |
+
freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
|
| 222 |
+
else:
|
| 223 |
+
max_pe_len = max(max_pe_len_h, max_pe_len_w)
|
| 224 |
+
freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
|
| 225 |
+
freqs_h = freqs
|
| 226 |
+
freqs_w = freqs.clone()
|
| 227 |
+
|
| 228 |
+
if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
|
| 229 |
+
scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
|
| 230 |
+
proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
|
| 231 |
+
proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
|
| 232 |
+
self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
|
| 233 |
+
self.register_buffer(
|
| 234 |
+
"proportion1",
|
| 235 |
+
proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
|
| 236 |
+
persistent=False,
|
| 237 |
+
)
|
| 238 |
+
self.register_buffer(
|
| 239 |
+
"proportion2",
|
| 240 |
+
proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
|
| 241 |
+
persistent=False,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.register_buffer("freqs_h", freqs_h, persistent=False)
|
| 245 |
+
self.register_buffer("freqs_w", freqs_w, persistent=False)
|
| 246 |
+
|
| 247 |
+
cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
|
| 248 |
+
positions = torch.arange(cache_len, dtype=default_dtype)
|
| 249 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 250 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 251 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 252 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 253 |
|
| 254 |
+
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
|
| 255 |
+
assert isinstance(ori_max_pe_len, int)
|
| 256 |
+
if not isinstance(max_pe_len, torch.Tensor):
|
| 257 |
+
max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
|
| 258 |
+
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 259 |
+
freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
|
| 260 |
+
|
| 261 |
+
if self.custom_freqs == "linear":
|
| 262 |
+
freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
|
| 263 |
+
elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
|
| 264 |
+
freqs = 1.0 / torch.pow(
|
| 265 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 266 |
+
freq_indices.to(scale),
|
| 267 |
+
).squeeze()
|
| 268 |
+
elif self.custom_freqs == "ntk-by-parts":
|
| 269 |
+
beta_0, beta_1 = 1.25, 0.75
|
| 270 |
+
gamma_0, gamma_1 = 16, 2
|
| 271 |
+
freqs_base = 1.0 / (theta**freq_indices)
|
| 272 |
+
freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 273 |
+
freqs_ntk = 1.0 / torch.pow(
|
| 274 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 275 |
+
freq_indices.to(scale),
|
| 276 |
+
).squeeze()
|
| 277 |
+
low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 278 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 279 |
+
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 280 |
+
low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 281 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 282 |
+
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 283 |
+
elif self.custom_freqs == "yarn":
|
| 284 |
+
beta_fast, beta_slow = 32, 1
|
| 285 |
+
freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
|
| 286 |
+
freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 287 |
+
low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
|
| 288 |
+
freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
|
| 289 |
+
freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
|
| 290 |
+
else:
|
| 291 |
+
raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
|
| 292 |
+
|
| 293 |
+
if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
|
| 294 |
+
freqs = freqs.squeeze()
|
| 295 |
+
return freqs.to(default_dtype)
|
| 296 |
+
|
| 297 |
+
def _apply_magnitude_scaling(
|
| 298 |
+
self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
| 299 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 300 |
+
if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
|
| 301 |
+
freqs_cos = freqs_cos * self.mscale
|
| 302 |
+
freqs_sin = freqs_sin * self.mscale
|
| 303 |
+
elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
|
| 304 |
+
freqs_cos = freqs_cos * self.proportion1
|
| 305 |
+
freqs_sin = freqs_sin * self.proportion1
|
| 306 |
+
elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
|
| 307 |
+
freqs_cos = freqs_cos * self.proportion2
|
| 308 |
+
freqs_sin = freqs_sin * self.proportion2
|
| 309 |
+
return freqs_cos, freqs_sin
|
| 310 |
+
|
| 311 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 312 |
grids = []
|
| 313 |
for height, width in image_sizes.tolist():
|
|
|
|
| 317 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 318 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 319 |
grid = torch.cat(grids, dim=1)
|
| 320 |
+
|
| 321 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 322 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 323 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 324 |
+
freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
|
| 325 |
+
return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
|
| 326 |
|
| 327 |
|
| 328 |
class NiTAttention(nn.Module):
|
|
|
|
| 520 |
)
|
| 521 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 522 |
|
| 523 |
+
def configure_rope_extrapolation(
|
| 524 |
+
self,
|
| 525 |
+
custom_freqs: str,
|
| 526 |
+
max_pe_len_h: int,
|
| 527 |
+
max_pe_len_w: int,
|
| 528 |
+
ori_max_pe_len: int,
|
| 529 |
+
decouple: bool = False,
|
| 530 |
+
theta: Optional[int] = None,
|
| 531 |
+
) -> None:
|
| 532 |
+
"""Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
|
| 533 |
+
theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
|
| 534 |
+
head_dim = self.config.hidden_size // self.config.num_heads
|
| 535 |
+
self.rope = NiTRotaryEmbedding(
|
| 536 |
+
head_dim,
|
| 537 |
+
custom_freqs=custom_freqs,
|
| 538 |
+
theta=theta,
|
| 539 |
+
max_pe_len_h=max_pe_len_h,
|
| 540 |
+
max_pe_len_w=max_pe_len_w,
|
| 541 |
+
decouple=decouple,
|
| 542 |
+
ori_max_pe_len=ori_max_pe_len,
|
| 543 |
+
)
|
| 544 |
+
for key, value in {
|
| 545 |
+
"custom_freqs": custom_freqs.lower(),
|
| 546 |
+
"max_pe_len_h": max_pe_len_h,
|
| 547 |
+
"max_pe_len_w": max_pe_len_w,
|
| 548 |
+
"decouple": decouple,
|
| 549 |
+
"ori_max_pe_len": ori_max_pe_len,
|
| 550 |
+
"theta": theta,
|
| 551 |
+
}.items():
|
| 552 |
+
if hasattr(self.config, key):
|
| 553 |
+
setattr(self.config, key, value)
|
| 554 |
+
|
| 555 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 556 |
batch_size, channels, height, width = hidden_states.shape
|
| 557 |
if channels != self.in_channels:
|
NiT-S/pipeline.py
CHANGED
|
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
|
|
|
|
|
|
| 215 |
) -> None:
|
| 216 |
if num_inference_steps < 1:
|
| 217 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 218 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 219 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 222 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 261 |
)
|
| 262 |
return packed_latents, image_sizes
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _apply_classifier_free_guidance(
|
| 265 |
self,
|
| 266 |
model_output: torch.Tensor,
|
|
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 305 |
num_inference_steps: int = 50,
|
| 306 |
guidance_scale: float = 1.0,
|
| 307 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
|
|
|
|
|
|
|
|
|
| 308 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 309 |
output_type: str = "pil",
|
| 310 |
return_dict: bool = True,
|
|
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 325 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 326 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 327 |
Flow-time interval where CFG is applied.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
generator (`torch.Generator`, *optional*):
|
| 329 |
RNG for reproducibility.
|
| 330 |
output_type (`str`, defaults to `"pil"`):
|
|
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 335 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 336 |
height = int(height or default_size)
|
| 337 |
width = int(width or default_size)
|
| 338 |
-
self.check_inputs(height, width, num_inference_steps, output_type)
|
|
|
|
| 339 |
|
| 340 |
device = self._execution_device
|
| 341 |
model_dtype = next(self.transformer.parameters()).dtype
|
|
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
| 215 |
+
interpolation: Optional[str] = None,
|
| 216 |
+
ori_max_pe_len: Optional[int] = None,
|
| 217 |
) -> None:
|
| 218 |
if num_inference_steps < 1:
|
| 219 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 220 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 221 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 222 |
+
if interpolation is not None and interpolation not in {
|
| 223 |
+
"no",
|
| 224 |
+
"linear",
|
| 225 |
+
"ntk-aware",
|
| 226 |
+
"ntk-by-parts",
|
| 227 |
+
"yarn",
|
| 228 |
+
"ntk-aware-pro1",
|
| 229 |
+
"ntk-aware-pro2",
|
| 230 |
+
"scale1",
|
| 231 |
+
"scale2",
|
| 232 |
+
}:
|
| 233 |
+
raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
|
| 234 |
+
if interpolation not in {None, "no"} and ori_max_pe_len is None:
|
| 235 |
+
raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
|
| 236 |
|
| 237 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 238 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
|
|
| 277 |
)
|
| 278 |
return packed_latents, image_sizes
|
| 279 |
|
| 280 |
+
def _maybe_configure_rope_extrapolation(
|
| 281 |
+
self,
|
| 282 |
+
height: int,
|
| 283 |
+
width: int,
|
| 284 |
+
interpolation: Optional[str],
|
| 285 |
+
ori_max_pe_len: Optional[int],
|
| 286 |
+
decouple: bool,
|
| 287 |
+
) -> None:
|
| 288 |
+
if interpolation in {None, "no"}:
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
spatial_downsample = self._get_vae_spatial_downsample()
|
| 292 |
+
patch_size = int(self.transformer.config.patch_size)
|
| 293 |
+
latent_h = height // spatial_downsample // patch_size
|
| 294 |
+
latent_w = width // spatial_downsample // patch_size
|
| 295 |
+
self.transformer.configure_rope_extrapolation(
|
| 296 |
+
custom_freqs=interpolation,
|
| 297 |
+
max_pe_len_h=latent_h,
|
| 298 |
+
max_pe_len_w=latent_w,
|
| 299 |
+
ori_max_pe_len=int(ori_max_pe_len),
|
| 300 |
+
decouple=decouple,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
def _apply_classifier_free_guidance(
|
| 304 |
self,
|
| 305 |
model_output: torch.Tensor,
|
|
|
|
| 344 |
num_inference_steps: int = 50,
|
| 345 |
guidance_scale: float = 1.0,
|
| 346 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
| 347 |
+
interpolation: Optional[str] = None,
|
| 348 |
+
ori_max_pe_len: Optional[int] = None,
|
| 349 |
+
decouple: bool = False,
|
| 350 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 351 |
output_type: str = "pil",
|
| 352 |
return_dict: bool = True,
|
|
|
|
| 367 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 368 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 369 |
Flow-time interval where CFG is applied.
|
| 370 |
+
interpolation (`str`, *optional*):
|
| 371 |
+
VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
|
| 372 |
+
`"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
|
| 373 |
+
`"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
|
| 374 |
+
the transformer's configured RoPE.
|
| 375 |
+
ori_max_pe_len (`int`, *optional*):
|
| 376 |
+
Original maximum latent side length seen during training. Required when
|
| 377 |
+
`interpolation` is enabled.
|
| 378 |
+
decouple (`bool`, defaults to `False`):
|
| 379 |
+
Whether to decouple height and width when computing extrapolated RoPE frequencies.
|
| 380 |
generator (`torch.Generator`, *optional*):
|
| 381 |
RNG for reproducibility.
|
| 382 |
output_type (`str`, defaults to `"pil"`):
|
|
|
|
| 387 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 388 |
height = int(height or default_size)
|
| 389 |
width = int(width or default_size)
|
| 390 |
+
self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
|
| 391 |
+
self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
|
| 392 |
|
| 393 |
device = self._execution_device
|
| 394 |
model_dtype = next(self.transformer.parameters()).dtype
|
NiT-S/transformer/nit_transformer_2d.py
CHANGED
|
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class NiTPatchEmbed(nn.Module):
|
| 78 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 79 |
super().__init__()
|
|
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
|
|
| 125 |
|
| 126 |
|
| 127 |
class NiTRotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
| 128 |
def __init__(
|
| 129 |
self,
|
| 130 |
head_dim: int,
|
|
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 137 |
ori_max_pe_len: Optional[int] = None,
|
| 138 |
):
|
| 139 |
super().__init__()
|
| 140 |
-
|
| 141 |
-
if custom_freqs not in
|
| 142 |
raise ValueError(
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
"by changing the model config before loading."
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
dim = head_dim // 2
|
| 148 |
if dim % 2 != 0:
|
| 149 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
default_dtype = _get_float_dtype_or_default()
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 156 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 157 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 158 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
grids = []
|
| 162 |
for height, width in image_sizes.tolist():
|
|
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 166 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 167 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 168 |
grid = torch.cat(grids, dim=1)
|
|
|
|
| 169 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 170 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 171 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 172 |
-
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTAttention(nn.Module):
|
|
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
|
|
| 367 |
)
|
| 368 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 371 |
batch_size, channels, height, width = hidden_states.shape
|
| 372 |
if channels != self.in_channels:
|
|
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
| 77 |
+
# VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
|
| 78 |
+
def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
| 79 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
| 83 |
+
low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
|
| 84 |
+
high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
|
| 85 |
+
return max(low, 0), min(high, dim - 1)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _linear_ramp_mask(min_value, max_value, dim):
|
| 89 |
+
if min_value == max_value:
|
| 90 |
+
max_value += 0.001
|
| 91 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
|
| 92 |
+
return torch.clamp(linear_func, 0, 1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _find_newbase_ntk(dim, base=10000, scale=1):
|
| 96 |
+
return base * scale ** (dim / (dim - 2))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_proportion(length_test, length_train):
|
| 104 |
+
length_test = length_test * 2
|
| 105 |
+
ratio = length_test / length_train
|
| 106 |
+
return torch.where(
|
| 107 |
+
torch.tensor(ratio) <= 1.0,
|
| 108 |
+
torch.tensor(1.0),
|
| 109 |
+
torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
|
| 114 |
+
EXTRAPOLATION_ROPE_FREQS = {
|
| 115 |
+
"linear",
|
| 116 |
+
"ntk-aware",
|
| 117 |
+
"ntk-aware-pro1",
|
| 118 |
+
"ntk-aware-pro2",
|
| 119 |
+
"ntk-by-parts",
|
| 120 |
+
"yarn",
|
| 121 |
+
}
|
| 122 |
+
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
|
| 123 |
+
|
| 124 |
+
|
| 125 |
class NiTPatchEmbed(nn.Module):
|
| 126 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 127 |
super().__init__()
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTRotaryEmbedding(nn.Module):
|
| 176 |
+
"""2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
|
| 177 |
+
|
| 178 |
def __init__(
|
| 179 |
self,
|
| 180 |
head_dim: int,
|
|
|
|
| 187 |
ori_max_pe_len: Optional[int] = None,
|
| 188 |
):
|
| 189 |
super().__init__()
|
| 190 |
+
custom_freqs = custom_freqs.lower()
|
| 191 |
+
if custom_freqs not in SUPPORTED_ROPE_FREQS:
|
| 192 |
raise ValueError(
|
| 193 |
+
f"Unsupported RoPE frequency variant {custom_freqs!r}. "
|
| 194 |
+
f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
|
|
|
|
| 195 |
)
|
| 196 |
+
if custom_freqs not in TRAINED_ROPE_FREQS and (
|
| 197 |
+
max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
|
| 198 |
+
):
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
dim = head_dim // 2
|
| 204 |
if dim % 2 != 0:
|
| 205 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
| 206 |
+
|
| 207 |
+
self.dim = dim
|
| 208 |
+
self.custom_freqs = custom_freqs
|
| 209 |
+
self.theta = theta
|
| 210 |
+
self.decouple = decouple
|
| 211 |
+
self.ori_max_pe_len = ori_max_pe_len
|
| 212 |
default_dtype = _get_float_dtype_or_default()
|
| 213 |
+
|
| 214 |
+
if custom_freqs in TRAINED_ROPE_FREQS:
|
| 215 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
|
| 216 |
+
freqs_h = freqs
|
| 217 |
+
freqs_w = freqs.clone()
|
| 218 |
+
else:
|
| 219 |
+
if decouple:
|
| 220 |
+
freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
|
| 221 |
+
freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
|
| 222 |
+
else:
|
| 223 |
+
max_pe_len = max(max_pe_len_h, max_pe_len_w)
|
| 224 |
+
freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
|
| 225 |
+
freqs_h = freqs
|
| 226 |
+
freqs_w = freqs.clone()
|
| 227 |
+
|
| 228 |
+
if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
|
| 229 |
+
scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
|
| 230 |
+
proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
|
| 231 |
+
proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
|
| 232 |
+
self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
|
| 233 |
+
self.register_buffer(
|
| 234 |
+
"proportion1",
|
| 235 |
+
proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
|
| 236 |
+
persistent=False,
|
| 237 |
+
)
|
| 238 |
+
self.register_buffer(
|
| 239 |
+
"proportion2",
|
| 240 |
+
proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
|
| 241 |
+
persistent=False,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.register_buffer("freqs_h", freqs_h, persistent=False)
|
| 245 |
+
self.register_buffer("freqs_w", freqs_w, persistent=False)
|
| 246 |
+
|
| 247 |
+
cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
|
| 248 |
+
positions = torch.arange(cache_len, dtype=default_dtype)
|
| 249 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 250 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 251 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 252 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 253 |
|
| 254 |
+
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
|
| 255 |
+
assert isinstance(ori_max_pe_len, int)
|
| 256 |
+
if not isinstance(max_pe_len, torch.Tensor):
|
| 257 |
+
max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
|
| 258 |
+
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 259 |
+
freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
|
| 260 |
+
|
| 261 |
+
if self.custom_freqs == "linear":
|
| 262 |
+
freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
|
| 263 |
+
elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
|
| 264 |
+
freqs = 1.0 / torch.pow(
|
| 265 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 266 |
+
freq_indices.to(scale),
|
| 267 |
+
).squeeze()
|
| 268 |
+
elif self.custom_freqs == "ntk-by-parts":
|
| 269 |
+
beta_0, beta_1 = 1.25, 0.75
|
| 270 |
+
gamma_0, gamma_1 = 16, 2
|
| 271 |
+
freqs_base = 1.0 / (theta**freq_indices)
|
| 272 |
+
freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 273 |
+
freqs_ntk = 1.0 / torch.pow(
|
| 274 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 275 |
+
freq_indices.to(scale),
|
| 276 |
+
).squeeze()
|
| 277 |
+
low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 278 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 279 |
+
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 280 |
+
low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 281 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 282 |
+
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 283 |
+
elif self.custom_freqs == "yarn":
|
| 284 |
+
beta_fast, beta_slow = 32, 1
|
| 285 |
+
freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
|
| 286 |
+
freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 287 |
+
low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
|
| 288 |
+
freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
|
| 289 |
+
freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
|
| 290 |
+
else:
|
| 291 |
+
raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
|
| 292 |
+
|
| 293 |
+
if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
|
| 294 |
+
freqs = freqs.squeeze()
|
| 295 |
+
return freqs.to(default_dtype)
|
| 296 |
+
|
| 297 |
+
def _apply_magnitude_scaling(
|
| 298 |
+
self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
| 299 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 300 |
+
if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
|
| 301 |
+
freqs_cos = freqs_cos * self.mscale
|
| 302 |
+
freqs_sin = freqs_sin * self.mscale
|
| 303 |
+
elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
|
| 304 |
+
freqs_cos = freqs_cos * self.proportion1
|
| 305 |
+
freqs_sin = freqs_sin * self.proportion1
|
| 306 |
+
elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
|
| 307 |
+
freqs_cos = freqs_cos * self.proportion2
|
| 308 |
+
freqs_sin = freqs_sin * self.proportion2
|
| 309 |
+
return freqs_cos, freqs_sin
|
| 310 |
+
|
| 311 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 312 |
grids = []
|
| 313 |
for height, width in image_sizes.tolist():
|
|
|
|
| 317 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 318 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 319 |
grid = torch.cat(grids, dim=1)
|
| 320 |
+
|
| 321 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 322 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 323 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 324 |
+
freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
|
| 325 |
+
return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
|
| 326 |
|
| 327 |
|
| 328 |
class NiTAttention(nn.Module):
|
|
|
|
| 520 |
)
|
| 521 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 522 |
|
| 523 |
+
def configure_rope_extrapolation(
|
| 524 |
+
self,
|
| 525 |
+
custom_freqs: str,
|
| 526 |
+
max_pe_len_h: int,
|
| 527 |
+
max_pe_len_w: int,
|
| 528 |
+
ori_max_pe_len: int,
|
| 529 |
+
decouple: bool = False,
|
| 530 |
+
theta: Optional[int] = None,
|
| 531 |
+
) -> None:
|
| 532 |
+
"""Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
|
| 533 |
+
theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
|
| 534 |
+
head_dim = self.config.hidden_size // self.config.num_heads
|
| 535 |
+
self.rope = NiTRotaryEmbedding(
|
| 536 |
+
head_dim,
|
| 537 |
+
custom_freqs=custom_freqs,
|
| 538 |
+
theta=theta,
|
| 539 |
+
max_pe_len_h=max_pe_len_h,
|
| 540 |
+
max_pe_len_w=max_pe_len_w,
|
| 541 |
+
decouple=decouple,
|
| 542 |
+
ori_max_pe_len=ori_max_pe_len,
|
| 543 |
+
)
|
| 544 |
+
for key, value in {
|
| 545 |
+
"custom_freqs": custom_freqs.lower(),
|
| 546 |
+
"max_pe_len_h": max_pe_len_h,
|
| 547 |
+
"max_pe_len_w": max_pe_len_w,
|
| 548 |
+
"decouple": decouple,
|
| 549 |
+
"ori_max_pe_len": ori_max_pe_len,
|
| 550 |
+
"theta": theta,
|
| 551 |
+
}.items():
|
| 552 |
+
if hasattr(self.config, key):
|
| 553 |
+
setattr(self.config, key, value)
|
| 554 |
+
|
| 555 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 556 |
batch_size, channels, height, width = hidden_states.shape
|
| 557 |
if channels != self.in_channels:
|
NiT-XL/pipeline.py
CHANGED
|
@@ -212,11 +212,27 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
|
|
|
|
|
|
| 215 |
) -> None:
|
| 216 |
if num_inference_steps < 1:
|
| 217 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 218 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 219 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 222 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
@@ -261,6 +277,29 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 261 |
)
|
| 262 |
return packed_latents, image_sizes
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _apply_classifier_free_guidance(
|
| 265 |
self,
|
| 266 |
model_output: torch.Tensor,
|
|
@@ -305,6 +344,9 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 305 |
num_inference_steps: int = 50,
|
| 306 |
guidance_scale: float = 1.0,
|
| 307 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
|
|
|
|
|
|
|
|
|
| 308 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 309 |
output_type: str = "pil",
|
| 310 |
return_dict: bool = True,
|
|
@@ -325,6 +367,16 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 325 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 326 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 327 |
Flow-time interval where CFG is applied.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
generator (`torch.Generator`, *optional*):
|
| 329 |
RNG for reproducibility.
|
| 330 |
output_type (`str`, defaults to `"pil"`):
|
|
@@ -335,7 +387,8 @@ class NiTPipeline(DiffusionPipeline):
|
|
| 335 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 336 |
height = int(height or default_size)
|
| 337 |
width = int(width or default_size)
|
| 338 |
-
self.check_inputs(height, width, num_inference_steps, output_type)
|
|
|
|
| 339 |
|
| 340 |
device = self._execution_device
|
| 341 |
model_dtype = next(self.transformer.parameters()).dtype
|
|
|
|
| 212 |
width: int,
|
| 213 |
num_inference_steps: int,
|
| 214 |
output_type: str,
|
| 215 |
+
interpolation: Optional[str] = None,
|
| 216 |
+
ori_max_pe_len: Optional[int] = None,
|
| 217 |
) -> None:
|
| 218 |
if num_inference_steps < 1:
|
| 219 |
raise ValueError("num_inference_steps must be >= 1.")
|
| 220 |
if output_type not in {"pil", "np", "pt", "latent"}:
|
| 221 |
raise ValueError("output_type must be one of: 'pil', 'np', 'pt', 'latent'.")
|
| 222 |
+
if interpolation is not None and interpolation not in {
|
| 223 |
+
"no",
|
| 224 |
+
"linear",
|
| 225 |
+
"ntk-aware",
|
| 226 |
+
"ntk-by-parts",
|
| 227 |
+
"yarn",
|
| 228 |
+
"ntk-aware-pro1",
|
| 229 |
+
"ntk-aware-pro2",
|
| 230 |
+
"scale1",
|
| 231 |
+
"scale2",
|
| 232 |
+
}:
|
| 233 |
+
raise ValueError(f"Unsupported interpolation mode: {interpolation!r}.")
|
| 234 |
+
if interpolation not in {None, "no"} and ori_max_pe_len is None:
|
| 235 |
+
raise ValueError("ori_max_pe_len is required when interpolation is enabled.")
|
| 236 |
|
| 237 |
spatial_downsample = self._get_vae_spatial_downsample()
|
| 238 |
if height % spatial_downsample != 0 or width % spatial_downsample != 0:
|
|
|
|
| 277 |
)
|
| 278 |
return packed_latents, image_sizes
|
| 279 |
|
| 280 |
+
def _maybe_configure_rope_extrapolation(
|
| 281 |
+
self,
|
| 282 |
+
height: int,
|
| 283 |
+
width: int,
|
| 284 |
+
interpolation: Optional[str],
|
| 285 |
+
ori_max_pe_len: Optional[int],
|
| 286 |
+
decouple: bool,
|
| 287 |
+
) -> None:
|
| 288 |
+
if interpolation in {None, "no"}:
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
spatial_downsample = self._get_vae_spatial_downsample()
|
| 292 |
+
patch_size = int(self.transformer.config.patch_size)
|
| 293 |
+
latent_h = height // spatial_downsample // patch_size
|
| 294 |
+
latent_w = width // spatial_downsample // patch_size
|
| 295 |
+
self.transformer.configure_rope_extrapolation(
|
| 296 |
+
custom_freqs=interpolation,
|
| 297 |
+
max_pe_len_h=latent_h,
|
| 298 |
+
max_pe_len_w=latent_w,
|
| 299 |
+
ori_max_pe_len=int(ori_max_pe_len),
|
| 300 |
+
decouple=decouple,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
def _apply_classifier_free_guidance(
|
| 304 |
self,
|
| 305 |
model_output: torch.Tensor,
|
|
|
|
| 344 |
num_inference_steps: int = 50,
|
| 345 |
guidance_scale: float = 1.0,
|
| 346 |
guidance_interval: Tuple[float, float] = (0.0, 1.0),
|
| 347 |
+
interpolation: Optional[str] = None,
|
| 348 |
+
ori_max_pe_len: Optional[int] = None,
|
| 349 |
+
decouple: bool = False,
|
| 350 |
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 351 |
output_type: str = "pil",
|
| 352 |
return_dict: bool = True,
|
|
|
|
| 367 |
Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
|
| 368 |
guidance_interval (`tuple[float, float]`, defaults to `(0.0, 1.0)`):
|
| 369 |
Flow-time interval where CFG is applied.
|
| 370 |
+
interpolation (`str`, *optional*):
|
| 371 |
+
VisionYaRN / VisionNTK extrapolation mode. Use `"yarn"` for VisionYaRN or
|
| 372 |
+
`"ntk-aware"`, `"ntk-by-parts"`, `"ntk-aware-pro1"`, `"ntk-aware-pro2"`,
|
| 373 |
+
`"scale1"`, or `"scale2"` for VisionNTK variants. Pass `"no"` or omit to use
|
| 374 |
+
the transformer's configured RoPE.
|
| 375 |
+
ori_max_pe_len (`int`, *optional*):
|
| 376 |
+
Original maximum latent side length seen during training. Required when
|
| 377 |
+
`interpolation` is enabled.
|
| 378 |
+
decouple (`bool`, defaults to `False`):
|
| 379 |
+
Whether to decouple height and width when computing extrapolated RoPE frequencies.
|
| 380 |
generator (`torch.Generator`, *optional*):
|
| 381 |
RNG for reproducibility.
|
| 382 |
output_type (`str`, defaults to `"pil"`):
|
|
|
|
| 387 |
default_size = DEFAULT_NATIVE_RESOLUTION if self.vae is not None else 256
|
| 388 |
height = int(height or default_size)
|
| 389 |
width = int(width or default_size)
|
| 390 |
+
self.check_inputs(height, width, num_inference_steps, output_type, interpolation, ori_max_pe_len)
|
| 391 |
+
self._maybe_configure_rope_extrapolation(height, width, interpolation, ori_max_pe_len, decouple)
|
| 392 |
|
| 393 |
device = self._execution_device
|
| 394 |
model_dtype = next(self.transformer.parameters()).dtype
|
NiT-XL/transformer/nit_transformer_2d.py
CHANGED
|
@@ -74,6 +74,54 @@ def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class NiTPatchEmbed(nn.Module):
|
| 78 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 79 |
super().__init__()
|
|
@@ -125,6 +173,8 @@ class NiTLabelEmbedder(nn.Module):
|
|
| 125 |
|
| 126 |
|
| 127 |
class NiTRotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
| 128 |
def __init__(
|
| 129 |
self,
|
| 130 |
head_dim: int,
|
|
@@ -137,26 +187,127 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 137 |
ori_max_pe_len: Optional[int] = None,
|
| 138 |
):
|
| 139 |
super().__init__()
|
| 140 |
-
|
| 141 |
-
if custom_freqs not in
|
| 142 |
raise ValueError(
|
| 143 |
-
"
|
| 144 |
-
"
|
| 145 |
-
"by changing the model config before loading."
|
| 146 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
dim = head_dim // 2
|
| 148 |
if dim % 2 != 0:
|
| 149 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
default_dtype = _get_float_dtype_or_default()
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 156 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 157 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 158 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 161 |
grids = []
|
| 162 |
for height, width in image_sizes.tolist():
|
|
@@ -166,10 +317,12 @@ class NiTRotaryEmbedding(nn.Module):
|
|
| 166 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 167 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 168 |
grid = torch.cat(grids, dim=1)
|
|
|
|
| 169 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 170 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 171 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 172 |
-
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTAttention(nn.Module):
|
|
@@ -367,6 +520,38 @@ class NiTTransformer2DModel(ModelMixin, ConfigMixin):
|
|
| 367 |
)
|
| 368 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 371 |
batch_size, channels, height, width = hidden_states.shape
|
| 372 |
if channels != self.in_channels:
|
|
|
|
| 74 |
return torch.get_default_dtype()
|
| 75 |
|
| 76 |
|
| 77 |
+
# VisionYaRN / VisionNTK helpers (from native NiT / FiT VisionRotaryEmbedding).
|
| 78 |
+
def _find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
|
| 79 |
+
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
|
| 83 |
+
low = math.floor(_find_correction_factor(low_rot, dim, base, max_position_embeddings))
|
| 84 |
+
high = math.ceil(_find_correction_factor(high_rot, dim, base, max_position_embeddings))
|
| 85 |
+
return max(low, 0), min(high, dim - 1)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _linear_ramp_mask(min_value, max_value, dim):
|
| 89 |
+
if min_value == max_value:
|
| 90 |
+
max_value += 0.001
|
| 91 |
+
linear_func = (torch.arange(dim, dtype=torch.float32) - min_value) / (max_value - min_value)
|
| 92 |
+
return torch.clamp(linear_func, 0, 1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _find_newbase_ntk(dim, base=10000, scale=1):
|
| 96 |
+
return base * scale ** (dim / (dim - 2))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _get_mscale(scale: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
return torch.where(scale <= 1.0, torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_proportion(length_test, length_train):
|
| 104 |
+
length_test = length_test * 2
|
| 105 |
+
ratio = length_test / length_train
|
| 106 |
+
return torch.where(
|
| 107 |
+
torch.tensor(ratio) <= 1.0,
|
| 108 |
+
torch.tensor(1.0),
|
| 109 |
+
torch.sqrt(torch.log(torch.tensor(length_test)) / torch.log(torch.tensor(length_train))),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
TRAINED_ROPE_FREQS = {"normal", "scale1", "scale2"}
|
| 114 |
+
EXTRAPOLATION_ROPE_FREQS = {
|
| 115 |
+
"linear",
|
| 116 |
+
"ntk-aware",
|
| 117 |
+
"ntk-aware-pro1",
|
| 118 |
+
"ntk-aware-pro2",
|
| 119 |
+
"ntk-by-parts",
|
| 120 |
+
"yarn",
|
| 121 |
+
}
|
| 122 |
+
SUPPORTED_ROPE_FREQS = TRAINED_ROPE_FREQS | EXTRAPOLATION_ROPE_FREQS
|
| 123 |
+
|
| 124 |
+
|
| 125 |
class NiTPatchEmbed(nn.Module):
|
| 126 |
def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
|
| 127 |
super().__init__()
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
class NiTRotaryEmbedding(nn.Module):
|
| 176 |
+
"""2D axial RoPE with VisionYaRN (`yarn`) and VisionNTK extrapolation modes."""
|
| 177 |
+
|
| 178 |
def __init__(
|
| 179 |
self,
|
| 180 |
head_dim: int,
|
|
|
|
| 187 |
ori_max_pe_len: Optional[int] = None,
|
| 188 |
):
|
| 189 |
super().__init__()
|
| 190 |
+
custom_freqs = custom_freqs.lower()
|
| 191 |
+
if custom_freqs not in SUPPORTED_ROPE_FREQS:
|
| 192 |
raise ValueError(
|
| 193 |
+
f"Unsupported RoPE frequency variant {custom_freqs!r}. "
|
| 194 |
+
f"Supported values: {sorted(SUPPORTED_ROPE_FREQS)}."
|
|
|
|
| 195 |
)
|
| 196 |
+
if custom_freqs not in TRAINED_ROPE_FREQS and (
|
| 197 |
+
max_pe_len_h is None or max_pe_len_w is None or ori_max_pe_len is None
|
| 198 |
+
):
|
| 199 |
+
raise ValueError(
|
| 200 |
+
f"Extrapolation mode {custom_freqs!r} requires max_pe_len_h, max_pe_len_w, and ori_max_pe_len."
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
dim = head_dim // 2
|
| 204 |
if dim % 2 != 0:
|
| 205 |
raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
|
| 206 |
+
|
| 207 |
+
self.dim = dim
|
| 208 |
+
self.custom_freqs = custom_freqs
|
| 209 |
+
self.theta = theta
|
| 210 |
+
self.decouple = decouple
|
| 211 |
+
self.ori_max_pe_len = ori_max_pe_len
|
| 212 |
default_dtype = _get_float_dtype_or_default()
|
| 213 |
+
|
| 214 |
+
if custom_freqs in TRAINED_ROPE_FREQS:
|
| 215 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
|
| 216 |
+
freqs_h = freqs
|
| 217 |
+
freqs_w = freqs.clone()
|
| 218 |
+
else:
|
| 219 |
+
if decouple:
|
| 220 |
+
freqs_h = self._get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len, default_dtype)
|
| 221 |
+
freqs_w = self._get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len, default_dtype)
|
| 222 |
+
else:
|
| 223 |
+
max_pe_len = max(max_pe_len_h, max_pe_len_w)
|
| 224 |
+
freqs = self._get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len, default_dtype)
|
| 225 |
+
freqs_h = freqs
|
| 226 |
+
freqs_w = freqs.clone()
|
| 227 |
+
|
| 228 |
+
if max_pe_len_h is not None and max_pe_len_w is not None and ori_max_pe_len is not None:
|
| 229 |
+
scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)
|
| 230 |
+
proportion1 = _get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
|
| 231 |
+
proportion2 = _get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len**2)
|
| 232 |
+
self.register_buffer("mscale", _get_mscale(scale).to(default_dtype), persistent=False)
|
| 233 |
+
self.register_buffer(
|
| 234 |
+
"proportion1",
|
| 235 |
+
proportion1.to(dtype=default_dtype) if isinstance(proportion1, torch.Tensor) else torch.tensor(float(proportion1), dtype=default_dtype),
|
| 236 |
+
persistent=False,
|
| 237 |
+
)
|
| 238 |
+
self.register_buffer(
|
| 239 |
+
"proportion2",
|
| 240 |
+
proportion2.to(dtype=default_dtype) if isinstance(proportion2, torch.Tensor) else torch.tensor(float(proportion2), dtype=default_dtype),
|
| 241 |
+
persistent=False,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
self.register_buffer("freqs_h", freqs_h, persistent=False)
|
| 245 |
+
self.register_buffer("freqs_w", freqs_w, persistent=False)
|
| 246 |
+
|
| 247 |
+
cache_len = max(max_cached_len, max_pe_len_h or 0, max_pe_len_w or 0, 1)
|
| 248 |
+
positions = torch.arange(cache_len, dtype=default_dtype)
|
| 249 |
freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
|
| 250 |
freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
|
| 251 |
self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
|
| 252 |
self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
|
| 253 |
|
| 254 |
+
def _get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len, default_dtype):
|
| 255 |
+
assert isinstance(ori_max_pe_len, int)
|
| 256 |
+
if not isinstance(max_pe_len, torch.Tensor):
|
| 257 |
+
max_pe_len = torch.tensor(max_pe_len, dtype=default_dtype)
|
| 258 |
+
scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)
|
| 259 |
+
freq_indices = torch.arange(0, dim, 2, dtype=default_dtype) / dim
|
| 260 |
+
|
| 261 |
+
if self.custom_freqs == "linear":
|
| 262 |
+
freqs = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices)
|
| 263 |
+
elif self.custom_freqs in {"ntk-aware", "ntk-aware-pro1", "ntk-aware-pro2"}:
|
| 264 |
+
freqs = 1.0 / torch.pow(
|
| 265 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 266 |
+
freq_indices.to(scale),
|
| 267 |
+
).squeeze()
|
| 268 |
+
elif self.custom_freqs == "ntk-by-parts":
|
| 269 |
+
beta_0, beta_1 = 1.25, 0.75
|
| 270 |
+
gamma_0, gamma_1 = 16, 2
|
| 271 |
+
freqs_base = 1.0 / (theta**freq_indices)
|
| 272 |
+
freqs_linear = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 273 |
+
freqs_ntk = 1.0 / torch.pow(
|
| 274 |
+
_find_newbase_ntk(dim, theta, scale).view(-1, 1),
|
| 275 |
+
freq_indices.to(scale),
|
| 276 |
+
).squeeze()
|
| 277 |
+
low, high = _find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
|
| 278 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 279 |
+
freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
|
| 280 |
+
low, high = _find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
|
| 281 |
+
freqs_mask = 1 - _linear_ramp_mask(low, high, dim // 2).to(scale)
|
| 282 |
+
freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
|
| 283 |
+
elif self.custom_freqs == "yarn":
|
| 284 |
+
beta_fast, beta_slow = 32, 1
|
| 285 |
+
freqs_extrapolation = 1.0 / (theta**freq_indices.to(scale))
|
| 286 |
+
freqs_interpolation = 1.0 / torch.einsum("..., f -> ... f", scale, theta**freq_indices.to(scale))
|
| 287 |
+
low, high = _find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
|
| 288 |
+
freqs_mask = (1 - _linear_ramp_mask(low, high, dim // 2).to(scale).float())
|
| 289 |
+
freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
|
| 290 |
+
else:
|
| 291 |
+
raise ValueError(f"Unknown extrapolation mode {self.custom_freqs!r}.")
|
| 292 |
+
|
| 293 |
+
if isinstance(freqs, torch.Tensor) and freqs.ndim > 1:
|
| 294 |
+
freqs = freqs.squeeze()
|
| 295 |
+
return freqs.to(default_dtype)
|
| 296 |
+
|
| 297 |
+
def _apply_magnitude_scaling(
|
| 298 |
+
self, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
| 299 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 300 |
+
if self.custom_freqs == "yarn" and hasattr(self, "mscale"):
|
| 301 |
+
freqs_cos = freqs_cos * self.mscale
|
| 302 |
+
freqs_sin = freqs_sin * self.mscale
|
| 303 |
+
elif self.custom_freqs in {"ntk-aware-pro1", "scale1"} and hasattr(self, "proportion1"):
|
| 304 |
+
freqs_cos = freqs_cos * self.proportion1
|
| 305 |
+
freqs_sin = freqs_sin * self.proportion1
|
| 306 |
+
elif self.custom_freqs in {"ntk-aware-pro2", "scale2"} and hasattr(self, "proportion2"):
|
| 307 |
+
freqs_cos = freqs_cos * self.proportion2
|
| 308 |
+
freqs_sin = freqs_sin * self.proportion2
|
| 309 |
+
return freqs_cos, freqs_sin
|
| 310 |
+
|
| 311 |
def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 312 |
grids = []
|
| 313 |
for height, width in image_sizes.tolist():
|
|
|
|
| 317 |
grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
|
| 318 |
grids.append(torch.stack(grid, dim=0).reshape(2, -1))
|
| 319 |
grid = torch.cat(grids, dim=1)
|
| 320 |
+
|
| 321 |
freqs_h = self.freqs_h_cached.to(device)[grid[0]]
|
| 322 |
freqs_w = self.freqs_w_cached.to(device)[grid[1]]
|
| 323 |
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
| 324 |
+
freqs_cos, freqs_sin = self._apply_magnitude_scaling(freqs.cos(), freqs.sin())
|
| 325 |
+
return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
|
| 326 |
|
| 327 |
|
| 328 |
class NiTAttention(nn.Module):
|
|
|
|
| 520 |
)
|
| 521 |
self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
|
| 522 |
|
| 523 |
+
def configure_rope_extrapolation(
|
| 524 |
+
self,
|
| 525 |
+
custom_freqs: str,
|
| 526 |
+
max_pe_len_h: int,
|
| 527 |
+
max_pe_len_w: int,
|
| 528 |
+
ori_max_pe_len: int,
|
| 529 |
+
decouple: bool = False,
|
| 530 |
+
theta: Optional[int] = None,
|
| 531 |
+
) -> None:
|
| 532 |
+
"""Configure VisionYaRN / VisionNTK extrapolation before high-resolution inference."""
|
| 533 |
+
theta = int(theta if theta is not None else getattr(self.config, "theta", 10000))
|
| 534 |
+
head_dim = self.config.hidden_size // self.config.num_heads
|
| 535 |
+
self.rope = NiTRotaryEmbedding(
|
| 536 |
+
head_dim,
|
| 537 |
+
custom_freqs=custom_freqs,
|
| 538 |
+
theta=theta,
|
| 539 |
+
max_pe_len_h=max_pe_len_h,
|
| 540 |
+
max_pe_len_w=max_pe_len_w,
|
| 541 |
+
decouple=decouple,
|
| 542 |
+
ori_max_pe_len=ori_max_pe_len,
|
| 543 |
+
)
|
| 544 |
+
for key, value in {
|
| 545 |
+
"custom_freqs": custom_freqs.lower(),
|
| 546 |
+
"max_pe_len_h": max_pe_len_h,
|
| 547 |
+
"max_pe_len_w": max_pe_len_w,
|
| 548 |
+
"decouple": decouple,
|
| 549 |
+
"ori_max_pe_len": ori_max_pe_len,
|
| 550 |
+
"theta": theta,
|
| 551 |
+
}.items():
|
| 552 |
+
if hasattr(self.config, key):
|
| 553 |
+
setattr(self.config, key, value)
|
| 554 |
+
|
| 555 |
def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
|
| 556 |
batch_size, channels, height, width = hidden_states.shape
|
| 557 |
if channels != self.in_channels:
|