BiliSakura commited on
Commit
cc27ca3
·
verified ·
1 Parent(s): dc71583

Delete custom_pipeline

Browse files
custom_pipeline/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- from .pipeline_nit import NiTPipeline, NiTPipelineOutput
2
- from .transformer_nit import NiTTransformer2DModel, NiTTransformer2DModelOutput
3
- from .scheduling_flow_match_nit import NiTFlowMatchScheduler, NiTFlowMatchSchedulerOutput
4
-
5
-
6
- def _register_with_diffusers():
7
- """
8
- Expose NiT classes on the `diffusers` namespace so pipeline/component loading
9
- via `from_pretrained()` can resolve entries declared in model_index.json.
10
- """
11
- try:
12
- import diffusers
13
- except Exception:
14
- return
15
-
16
- setattr(diffusers, "NiTPipeline", NiTPipeline)
17
- setattr(diffusers, "NiTTransformer2DModel", NiTTransformer2DModel)
18
- setattr(diffusers, "NiTFlowMatchScheduler", NiTFlowMatchScheduler)
19
-
20
-
21
- _register_with_diffusers()
22
-
23
- __all__ = [
24
- "NiTPipeline",
25
- "NiTPipelineOutput",
26
- "NiTTransformer2DModel",
27
- "NiTTransformer2DModelOutput",
28
- "NiTFlowMatchScheduler",
29
- "NiTFlowMatchSchedulerOutput",
30
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
custom_pipeline/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (1.11 kB)
 
custom_pipeline/__pycache__/pipeline_nit.cpython-312.pyc DELETED
Binary file (12.3 kB)
 
custom_pipeline/__pycache__/scheduling_flow_match_nit.cpython-312.pyc DELETED
Binary file (11.3 kB)
 
custom_pipeline/__pycache__/transformer_nit.cpython-312.pyc DELETED
Binary file (31.4 kB)
 
custom_pipeline/pipeline_nit.py DELETED
@@ -1,237 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
-
6
- from dataclasses import dataclass
7
- from typing import List, Optional, Tuple, Union
8
-
9
- import torch
10
-
11
- try:
12
- from diffusers.image_processor import VaeImageProcessor
13
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
14
- from diffusers.utils import BaseOutput
15
- except Exception: # pragma: no cover - importable without a full diffusers install.
16
- class BaseOutput(dict):
17
- def __post_init__(self):
18
- self.update(self.__dict__)
19
-
20
- class DiffusionPipeline:
21
- def register_modules(self, **kwargs):
22
- for name, module in kwargs.items():
23
- setattr(self, name, module)
24
-
25
- @property
26
- def _execution_device(self):
27
- return torch.device("cpu")
28
-
29
- def maybe_free_model_hooks(self):
30
- pass
31
-
32
- class VaeImageProcessor:
33
- def postprocess(self, image, output_type="pil"):
34
- return image
35
-
36
-
37
- @dataclass
38
- class NiTPipelineOutput(BaseOutput):
39
- images: Union[torch.FloatTensor, List]
40
-
41
-
42
- class NiTPipeline(DiffusionPipeline):
43
- r"""
44
- Native-resolution Image Synthesis pipeline using a class-conditional NiT transformer.
45
-
46
- This pipeline follows Diffusers conventions: transformer, scheduler, and VAE are
47
- saved as separate subfolders and restored with `DiffusionPipeline.from_pretrained`.
48
- The transformer predicts flow-matching velocity in latent space.
49
- """
50
-
51
- model_cpu_offload_seq = "transformer->vae"
52
- _optional_components = ["vae"]
53
-
54
- def __init__(self, transformer, scheduler, vae=None):
55
- super().__init__()
56
- self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
57
- self.image_processor = VaeImageProcessor()
58
-
59
- def _prepare_latents(
60
- self,
61
- batch_size: int,
62
- height: int,
63
- width: int,
64
- dtype: torch.dtype,
65
- device: torch.device,
66
- generator: Optional[Union[torch.Generator, List[torch.Generator]]],
67
- ) -> Tuple[torch.Tensor, torch.LongTensor]:
68
- if self.vae is None:
69
- spatial_downsample = 1
70
- elif self.vae.__class__.__name__ == "AutoencoderDC" or "dc-ae" in getattr(self.vae.config, "_name_or_path", ""):
71
- spatial_downsample = 32
72
- else:
73
- spatial_downsample = getattr(self.vae.config, "block_out_channels", [0, 0, 0, 0])
74
- spatial_downsample = 2 ** (len(spatial_downsample) - 1)
75
-
76
- if height % spatial_downsample != 0 or width % spatial_downsample != 0:
77
- raise ValueError(f"height and width must be divisible by the VAE downsample factor {spatial_downsample}.")
78
-
79
- latent_height = height // spatial_downsample
80
- latent_width = width // spatial_downsample
81
- patch_size = int(self.transformer.config.patch_size)
82
- if latent_height % patch_size != 0 or latent_width % patch_size != 0:
83
- raise ValueError("Latent height and width must be divisible by transformer's patch_size.")
84
-
85
- token_height = latent_height // patch_size
86
- token_width = latent_width // patch_size
87
- image_sizes = torch.tensor([[token_height, token_width]] * batch_size, device=device, dtype=torch.long)
88
-
89
- # Match native NiT sampler initialization exactly: sample directly in packed-token space.
90
- packed_shape = (
91
- batch_size * token_height * token_width,
92
- self.transformer.config.in_channels,
93
- patch_size,
94
- patch_size,
95
- )
96
- packed_latents = torch.randn(packed_shape, generator=generator, device=device, dtype=dtype)
97
- return packed_latents, image_sizes
98
-
99
- def _apply_classifier_free_guidance(
100
- self,
101
- model_output: torch.Tensor,
102
- guidance_scale: float,
103
- guidance_active: bool,
104
- ) -> torch.Tensor:
105
- if guidance_scale <= 1.0 or not guidance_active:
106
- return model_output
107
- model_output_cond, model_output_uncond = model_output.chunk(2)
108
- return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
109
-
110
- def _get_vae_dtype(self, latents: torch.Tensor) -> torch.dtype:
111
- vae_dtype = getattr(self.vae, "dtype", None)
112
- if vae_dtype is not None:
113
- return vae_dtype
114
- vae_params = next(self.vae.parameters(), None)
115
- return vae_params.dtype if vae_params is not None else latents.dtype
116
-
117
- def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
118
- if self.vae is None:
119
- return latents
120
- vae_dtype = self._get_vae_dtype(latents)
121
- latents = latents.to(dtype=vae_dtype)
122
- scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
123
- latents = latents / scaling_factor
124
- if self.vae.__class__.__name__ == "AutoencoderDC":
125
- image = self.vae._decode(latents)
126
- else:
127
- image = self.vae.decode(latents)
128
- image = image.sample if hasattr(image, "sample") else image
129
- return image
130
-
131
- @torch.no_grad()
132
- def __call__(
133
- self,
134
- class_labels: Union[int, List[int], torch.LongTensor],
135
- height: int = 256,
136
- width: int = 256,
137
- num_inference_steps: int = 50,
138
- guidance_scale: float = 1.0,
139
- guidance_interval: Tuple[float, float] = (0.0, 1.0),
140
- mode: str = "ode",
141
- heun: bool = False,
142
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
143
- output_type: str = "pil",
144
- return_dict: bool = True,
145
- ) -> Union[NiTPipelineOutput, Tuple]:
146
- device = self._execution_device
147
- model_dtype = next(self.transformer.parameters()).dtype
148
-
149
- if isinstance(class_labels, int):
150
- class_labels = [class_labels]
151
- if not torch.is_tensor(class_labels):
152
- class_labels = torch.tensor(class_labels, device=device, dtype=torch.long)
153
- else:
154
- class_labels = class_labels.to(device=device, dtype=torch.long)
155
- batch_size = class_labels.numel()
156
-
157
- packed_latents, image_sizes = self._prepare_latents(batch_size, height, width, model_dtype, device, generator)
158
- packed_latents = packed_latents.to(dtype=torch.float64)
159
- timesteps = self.scheduler.set_timesteps(num_inference_steps, device=device, mode=mode)
160
-
161
- null_labels = torch.full_like(class_labels, self.transformer.config.num_classes)
162
- for index, timestep in enumerate(timesteps[:-1]):
163
- next_timestep = timesteps[index + 1]
164
- guidance_active = guidance_interval[0] <= float(timestep) <= guidance_interval[1]
165
- if guidance_scale > 1.0 and guidance_active:
166
- model_input = torch.cat([packed_latents, packed_latents], dim=0)
167
- labels = torch.cat([class_labels, null_labels], dim=0)
168
- model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
169
- else:
170
- model_input = packed_latents
171
- labels = class_labels
172
- model_image_sizes = image_sizes
173
-
174
- timestep_batch = torch.full((labels.numel(),), float(timestep), device=device, dtype=model_dtype)
175
- model_output = self.transformer(
176
- model_input.to(dtype=model_dtype),
177
- timestep_batch,
178
- labels,
179
- image_sizes=model_image_sizes,
180
- return_dict=True,
181
- ).sample
182
- model_output = self._apply_classifier_free_guidance(model_output, guidance_scale, guidance_active)
183
-
184
- if heun and mode == "ode" and index < len(timesteps) - 2:
185
- provisional = self.scheduler.step(
186
- model_output,
187
- timestep[None],
188
- packed_latents,
189
- next_timestep[None],
190
- image_sizes=image_sizes,
191
- ).prev_sample
192
- if guidance_scale > 1.0 and guidance_active:
193
- prime_input = torch.cat([provisional, provisional], dim=0)
194
- labels = torch.cat([class_labels, null_labels], dim=0)
195
- model_image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
196
- else:
197
- prime_input = provisional
198
- labels = class_labels
199
- model_image_sizes = image_sizes
200
- next_timestep_batch = torch.full((labels.numel(),), float(next_timestep), device=device, dtype=model_dtype)
201
- next_model_output = self.transformer(
202
- prime_input.to(dtype=model_dtype),
203
- next_timestep_batch,
204
- labels,
205
- image_sizes=model_image_sizes,
206
- return_dict=True,
207
- ).sample
208
- next_model_output = self._apply_classifier_free_guidance(
209
- next_model_output, guidance_scale, guidance_active
210
- )
211
- packed_latents = self.scheduler.step_heun(
212
- model_output, next_model_output, timestep[None], packed_latents, next_timestep[None]
213
- ).prev_sample
214
- else:
215
- packed_latents = self.scheduler.step(
216
- model_output,
217
- timestep[None],
218
- packed_latents,
219
- next_timestep[None],
220
- image_sizes=image_sizes,
221
- generator=generator,
222
- ).prev_sample
223
-
224
- latents = self.transformer._unpack_latents(packed_latents, image_sizes)
225
- image = self._decode_latents(latents)
226
- if self.vae is not None:
227
- image = (image / 2 + 0.5).clamp(0, 1)
228
- image = self.image_processor.postprocess(
229
- image,
230
- output_type=output_type,
231
- do_denormalize=[False] * image.shape[0],
232
- )
233
-
234
- self.maybe_free_model_hooks()
235
- if not return_dict:
236
- return (image,)
237
- return NiTPipelineOutput(images=image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
custom_pipeline/scheduling_flow_match_nit.py DELETED
@@ -1,187 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
-
6
- from dataclasses import dataclass
7
- from typing import Optional, Tuple
8
-
9
- import numpy as np
10
- import torch
11
-
12
- try:
13
- from diffusers.configuration_utils import ConfigMixin, register_to_config
14
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
15
- from diffusers.utils import BaseOutput
16
- except Exception: # pragma: no cover - importable without an installed diffusers checkout.
17
- class BaseOutput(dict):
18
- def __post_init__(self):
19
- self.update(self.__dict__)
20
-
21
- class ConfigMixin:
22
- config_name = "scheduler_config.json"
23
-
24
- class SchedulerMixin:
25
- pass
26
-
27
- def register_to_config(init):
28
- return init
29
-
30
-
31
- @dataclass
32
- class NiTFlowMatchSchedulerOutput(BaseOutput):
33
- prev_sample: torch.FloatTensor
34
-
35
-
36
- class NiTFlowMatchScheduler(SchedulerMixin, ConfigMixin):
37
- """
38
- Flow-matching ODE/SDE scheduler used by Native-resolution Image Synthesis (NiT).
39
-
40
- The model predicts velocity with a linear path by default. Timesteps run from 1 to 0,
41
- matching the original sampler while exposing the standard Diffusers `set_timesteps`
42
- and `step` API.
43
- """
44
-
45
- config_name = "scheduler_config.json"
46
- order = 1
47
-
48
- @register_to_config
49
- def __init__(
50
- self,
51
- mode: str = "ode",
52
- path_type: str = "linear",
53
- num_train_timesteps: int = 1000,
54
- ):
55
- if mode not in {"ode", "sde"}:
56
- raise ValueError("mode must be either 'ode' or 'sde'.")
57
- if path_type not in {"linear", "cosine"}:
58
- raise ValueError("path_type must be either 'linear' or 'cosine'.")
59
- self.mode = mode
60
- self.path_type = path_type
61
- self.num_train_timesteps = num_train_timesteps
62
- # Native NiT integrates in float64 for better numerical stability.
63
- self.timesteps = torch.from_numpy(np.linspace(1.0, 0.0, num_train_timesteps + 1)).to(dtype=torch.float64)
64
-
65
- def set_timesteps(
66
- self,
67
- num_inference_steps: int,
68
- device: Optional[torch.device] = None,
69
- mode: Optional[str] = None,
70
- ):
71
- mode = mode or self.mode
72
- dtype = self.timesteps.dtype
73
- if mode == "sde":
74
- timesteps = torch.linspace(1.0, 0.04, num_inference_steps, dtype=dtype)
75
- timesteps = torch.cat([timesteps, torch.zeros(1, dtype=dtype)])
76
- elif mode == "ode":
77
- timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, dtype=dtype)
78
- else:
79
- raise ValueError("mode must be either 'ode' or 'sde'.")
80
- self.mode = mode
81
- self.timesteps = timesteps.to(device=device)
82
- return self.timesteps
83
-
84
- @staticmethod
85
- def _expand_t_like_sample(timestep: torch.Tensor, sample: torch.Tensor, image_sizes: torch.LongTensor):
86
- dims = [1] * (sample.ndim - 1)
87
- seqlens = image_sizes[:, 0] * image_sizes[:, 1]
88
- if timestep.numel() == 1:
89
- timestep = timestep.repeat(image_sizes.shape[0])
90
- return torch.cat(
91
- [timestep[i].reshape(1, *dims).repeat(int(seqlens[i]), *dims) for i in range(image_sizes.shape[0])]
92
- )
93
-
94
- def _get_score_from_velocity(
95
- self,
96
- model_output: torch.Tensor,
97
- sample: torch.Tensor,
98
- timestep: torch.Tensor,
99
- image_sizes: torch.LongTensor,
100
- ):
101
- timestep = self._expand_t_like_sample(timestep, sample, image_sizes)
102
- if self.path_type == "linear":
103
- alpha_t, d_alpha_t = 1 - timestep, torch.ones_like(timestep) * -1
104
- sigma_t, d_sigma_t = timestep, torch.ones_like(timestep)
105
- elif self.path_type == "cosine":
106
- alpha_t = torch.cos(timestep * np.pi / 2)
107
- sigma_t = torch.sin(timestep * np.pi / 2)
108
- d_alpha_t = -np.pi / 2 * torch.sin(timestep * np.pi / 2)
109
- d_sigma_t = np.pi / 2 * torch.cos(timestep * np.pi / 2)
110
- else:
111
- raise ValueError(f"Unsupported path_type: {self.path_type}")
112
- reverse_alpha_ratio = alpha_t / d_alpha_t
113
- variance = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
114
- return (reverse_alpha_ratio * model_output - sample) / variance
115
-
116
- @staticmethod
117
- def _compute_diffusion(timestep: torch.Tensor):
118
- return 2 * timestep
119
-
120
- @staticmethod
121
- def _promote_dtypes(*tensors: torch.Tensor) -> torch.dtype:
122
- dtype = None
123
- for tensor in tensors:
124
- if tensor.is_floating_point() or tensor.is_complex():
125
- dtype = tensor.dtype if dtype is None else torch.promote_types(dtype, tensor.dtype)
126
- return dtype if dtype is not None else torch.get_default_dtype()
127
-
128
- def step(
129
- self,
130
- model_output: torch.Tensor,
131
- timestep: torch.Tensor,
132
- sample: torch.Tensor,
133
- next_timestep: torch.Tensor,
134
- image_sizes: Optional[torch.LongTensor] = None,
135
- generator: Optional[torch.Generator] = None,
136
- return_dict: bool = True,
137
- ) -> NiTFlowMatchSchedulerOutput:
138
- compute_dtype = torch.float64
139
- sample = sample.to(dtype=compute_dtype)
140
- model_output = model_output.to(dtype=compute_dtype)
141
- timestep = timestep.to(device=sample.device, dtype=compute_dtype).flatten()
142
- next_timestep = next_timestep.to(device=sample.device, dtype=compute_dtype).flatten()
143
-
144
- if self.mode == "ode":
145
- prev_sample = sample + (next_timestep[0] - timestep[0]) * model_output
146
- else:
147
- if image_sizes is None:
148
- raise ValueError("image_sizes are required for SDE sampling.")
149
- image_sizes = image_sizes.to(device=sample.device, dtype=torch.long)
150
- diffusion = self._compute_diffusion(timestep[0])
151
- score = self._get_score_from_velocity(model_output, sample, timestep, image_sizes)
152
- drift = model_output - 0.5 * diffusion * score
153
- dt = next_timestep[0] - timestep[0]
154
- if torch.allclose(next_timestep[0], torch.zeros_like(next_timestep[0])):
155
- prev_sample = sample + drift * dt
156
- else:
157
- if generator is not None:
158
- noise = torch.randn(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
159
- else:
160
- noise = torch.randn_like(sample)
161
- prev_sample = sample + drift * dt + torch.sqrt(diffusion) * noise * torch.sqrt(torch.abs(dt))
162
-
163
- if not return_dict:
164
- return (prev_sample,)
165
- return NiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
166
-
167
- def step_heun(
168
- self,
169
- model_output: torch.Tensor,
170
- next_model_output: torch.Tensor,
171
- timestep: torch.Tensor,
172
- sample: torch.Tensor,
173
- next_timestep: torch.Tensor,
174
- return_dict: bool = True,
175
- ) -> NiTFlowMatchSchedulerOutput:
176
- if self.mode != "ode":
177
- raise ValueError("Heun correction is only defined for ODE sampling.")
178
- compute_dtype = torch.float64
179
- sample = sample.to(dtype=compute_dtype)
180
- model_output = model_output.to(dtype=compute_dtype)
181
- next_model_output = next_model_output.to(dtype=compute_dtype)
182
- timestep = timestep.to(device=sample.device, dtype=compute_dtype).flatten()
183
- next_timestep = next_timestep.to(device=sample.device, dtype=compute_dtype).flatten()
184
- prev_sample = sample + (next_timestep[0] - timestep[0]) * (0.5 * model_output + 0.5 * next_model_output)
185
- if not return_dict:
186
- return (prev_sample,)
187
- return NiTFlowMatchSchedulerOutput(prev_sample=prev_sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
custom_pipeline/transformer_nit.py DELETED
@@ -1,471 +0,0 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
-
6
- from dataclasses import dataclass
7
- import math
8
- from typing import List, Optional, Tuple, Union
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
- try:
15
- from diffusers.configuration_utils import ConfigMixin, register_to_config
16
- from diffusers.models.modeling_utils import ModelMixin
17
- from diffusers.utils import BaseOutput
18
- except Exception: # pragma: no cover - lets this subtree be tested outside diffusers.
19
- class BaseOutput(dict):
20
- def __post_init__(self):
21
- self.update(self.__dict__)
22
-
23
- class _Config(dict):
24
- def __getattr__(self, key):
25
- try:
26
- return self[key]
27
- except KeyError as error:
28
- raise AttributeError(key) from error
29
-
30
- class ConfigMixin:
31
- config_name = "config.json"
32
-
33
- class ModelMixin(nn.Module):
34
- pass
35
-
36
- def register_to_config(init):
37
- def wrapper(self, *args, **kwargs):
38
- import inspect
39
-
40
- signature = inspect.signature(init)
41
- bound = signature.bind(self, *args, **kwargs)
42
- bound.apply_defaults()
43
- self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
44
- init(self, *args, **kwargs)
45
-
46
- return wrapper
47
-
48
-
49
- try:
50
- from flash_attn import flash_attn_varlen_func
51
- except Exception: # pragma: no cover - optional acceleration.
52
- flash_attn_varlen_func = None
53
-
54
-
55
- @dataclass
56
- class NiTTransformer2DModelOutput(BaseOutput):
57
- sample: torch.FloatTensor
58
- projection_states: Optional[Tuple[torch.FloatTensor, ...]] = None
59
-
60
-
61
- def _modulate(hidden_states: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
62
- return hidden_states * (1 + scale) + shift
63
-
64
-
65
- def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
66
- hidden_states = hidden_states.reshape(*hidden_states.shape[:-1], -1, 2)
67
- hidden_states_1, hidden_states_2 = hidden_states.unbind(dim=-1)
68
- return torch.stack((-hidden_states_2, hidden_states_1), dim=-1).flatten(-2)
69
-
70
-
71
- def _get_float_dtype_or_default(tensor: Optional[torch.Tensor] = None) -> torch.dtype:
72
- if tensor is not None and tensor.is_floating_point():
73
- return tensor.dtype
74
- return torch.get_default_dtype()
75
-
76
-
77
- class NiTPatchEmbed(nn.Module):
78
- def __init__(self, patch_size: int, in_channels: int, hidden_size: int):
79
- super().__init__()
80
- self.patch_size = (patch_size, patch_size)
81
- self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=True)
82
-
83
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
84
- hidden_states = self.proj(hidden_states)
85
- return hidden_states.flatten(2).transpose(1, 2)
86
-
87
-
88
- class NiTTimestepEmbedder(nn.Module):
89
- def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
90
- super().__init__()
91
- self.frequency_embedding_size = frequency_embedding_size
92
- self.mlp = nn.Sequential(
93
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
94
- nn.SiLU(),
95
- nn.Linear(hidden_size, hidden_size, bias=True),
96
- )
97
-
98
- @staticmethod
99
- def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000):
100
- half = embedding_dim // 2
101
- # Keep sinusoid construction in fp32 to mirror the native NiT implementation.
102
- exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / half
103
- freqs = torch.exp(exponent)
104
- args = timesteps.float()[:, None] * freqs[None]
105
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
106
- if embedding_dim % 2:
107
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
108
- return embedding
109
-
110
- def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
111
- timestep_freq = self.get_timestep_embedding(timesteps, self.frequency_embedding_size).to(timesteps.dtype)
112
- return self.mlp(timestep_freq)
113
-
114
-
115
- class NiTLabelEmbedder(nn.Module):
116
- def __init__(self, num_classes: int, hidden_size: int, dropout_prob: float):
117
- super().__init__()
118
- use_cfg_embedding = dropout_prob > 0
119
- self.embedding_table = nn.Embedding(num_classes + int(use_cfg_embedding), hidden_size)
120
- self.num_classes = num_classes
121
- self.dropout_prob = dropout_prob
122
-
123
- def forward(self, class_labels: torch.LongTensor) -> torch.Tensor:
124
- return self.embedding_table(class_labels)
125
-
126
-
127
- class NiTRotaryEmbedding(nn.Module):
128
- def __init__(
129
- self,
130
- head_dim: int,
131
- custom_freqs: str = "normal",
132
- theta: int = 10000,
133
- max_cached_len: int = 1024,
134
- max_pe_len_h: Optional[int] = None,
135
- max_pe_len_w: Optional[int] = None,
136
- decouple: bool = False,
137
- ori_max_pe_len: Optional[int] = None,
138
- ):
139
- super().__init__()
140
- del max_pe_len_h, max_pe_len_w, decouple, ori_max_pe_len
141
- if custom_freqs not in {"normal", "scale1", "scale2"}:
142
- raise ValueError(
143
- "This Diffusers implementation supports the trained RoPE frequencies directly. "
144
- "Checkpoint conversion preserves weights; extrapolation variants should be handled "
145
- "by changing the model config before loading."
146
- )
147
- dim = head_dim // 2
148
- if dim % 2 != 0:
149
- raise ValueError("NiT rotary embedding requires head_dim // 2 to be even.")
150
- default_dtype = _get_float_dtype_or_default()
151
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=default_dtype) / dim))
152
- self.register_buffer("freqs_h", freqs, persistent=False)
153
- self.register_buffer("freqs_w", freqs.clone(), persistent=False)
154
- positions = torch.arange(max_cached_len, dtype=default_dtype)
155
- freqs_h_cached = torch.einsum("n,f->nf", positions, self.freqs_h).repeat_interleave(2, dim=-1)
156
- freqs_w_cached = torch.einsum("n,f->nf", positions, self.freqs_w).repeat_interleave(2, dim=-1)
157
- self.register_buffer("freqs_h_cached", freqs_h_cached, persistent=False)
158
- self.register_buffer("freqs_w_cached", freqs_w_cached, persistent=False)
159
-
160
- def forward(self, image_sizes: torch.LongTensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
161
- grids = []
162
- for height, width in image_sizes.tolist():
163
- # Use the same meshgrid ordering as native NiT to preserve RoPE-token alignment.
164
- grid_h = torch.arange(height, device=device)
165
- grid_w = torch.arange(width, device=device)
166
- grid = torch.meshgrid(grid_h, grid_w, indexing="xy")
167
- grids.append(torch.stack(grid, dim=0).reshape(2, -1))
168
- grid = torch.cat(grids, dim=1)
169
- freqs_h = self.freqs_h_cached.to(device)[grid[0]]
170
- freqs_w = self.freqs_w_cached.to(device)[grid[1]]
171
- freqs = torch.cat([freqs_h, freqs_w], dim=-1)
172
- return freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)
173
-
174
-
175
- class NiTAttention(nn.Module):
176
- def __init__(self, hidden_size: int, num_heads: int, qk_norm: bool = False):
177
- super().__init__()
178
- if hidden_size % num_heads != 0:
179
- raise ValueError("hidden_size must be divisible by num_heads")
180
- self.num_heads = num_heads
181
- self.head_dim = hidden_size // num_heads
182
- self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
183
- self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
184
- self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
185
- self.proj = nn.Linear(hidden_size, hidden_size)
186
- self.proj_drop = nn.Dropout(0.0)
187
-
188
- def forward(
189
- self,
190
- hidden_states: torch.Tensor,
191
- cu_seqlens: torch.IntTensor,
192
- freqs_cos: torch.Tensor,
193
- freqs_sin: torch.Tensor,
194
- ) -> torch.Tensor:
195
- qkv = self.qkv(hidden_states).reshape(hidden_states.shape[0], 3, self.num_heads, self.head_dim)
196
- query, key, value = qkv.unbind(dim=1)
197
- original_dtype = qkv.dtype
198
- query = self.q_norm(query)
199
- key = self.k_norm(key)
200
- query = query * freqs_cos + _rotate_half(query) * freqs_sin
201
- key = key * freqs_cos + _rotate_half(key) * freqs_sin
202
- query = query.to(dtype=original_dtype)
203
- key = key.to(dtype=original_dtype)
204
-
205
- if flash_attn_varlen_func is not None and query.is_cuda:
206
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
207
- hidden_states = flash_attn_varlen_func(
208
- query, key, value, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
209
- ).reshape(hidden_states.shape[0], -1)
210
- else:
211
- segments = []
212
- for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
213
- q = query[start:end].transpose(0, 1).unsqueeze(0)
214
- k = key[start:end].transpose(0, 1).unsqueeze(0)
215
- v = value[start:end].transpose(0, 1).unsqueeze(0)
216
- segments.append(F.scaled_dot_product_attention(q, k, v).squeeze(0).transpose(0, 1))
217
- hidden_states = torch.cat(segments, dim=0).reshape(hidden_states.shape[0], -1)
218
-
219
- hidden_states = self.proj(hidden_states)
220
- return self.proj_drop(hidden_states)
221
-
222
-
223
- class NiTMLP(nn.Module):
224
- def __init__(self, hidden_size: int, mlp_hidden_dim: int):
225
- super().__init__()
226
- self.fc1 = nn.Linear(hidden_size, mlp_hidden_dim)
227
- self.act = nn.GELU(approximate="tanh")
228
- self.drop1 = nn.Dropout(0.0)
229
- self.norm = nn.Identity()
230
- self.fc2 = nn.Linear(mlp_hidden_dim, hidden_size)
231
- self.drop2 = nn.Dropout(0.0)
232
-
233
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
234
- hidden_states = self.fc1(hidden_states)
235
- hidden_states = self.act(hidden_states)
236
- hidden_states = self.drop1(hidden_states)
237
- hidden_states = self.norm(hidden_states)
238
- hidden_states = self.fc2(hidden_states)
239
- return self.drop2(hidden_states)
240
-
241
-
242
- class NiTBlock(nn.Module):
243
- def __init__(
244
- self,
245
- hidden_size: int,
246
- num_heads: int,
247
- mlp_ratio: float = 4.0,
248
- qk_norm: bool = False,
249
- use_adaln_lora: bool = False,
250
- adaln_lora_dim: int = 512,
251
- ):
252
- super().__init__()
253
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
254
- self.attn = NiTAttention(hidden_size, num_heads=num_heads, qk_norm=qk_norm)
255
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
256
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
257
- self.mlp = NiTMLP(hidden_size, mlp_hidden_dim)
258
- if use_adaln_lora:
259
- self.adaLN_modulation = nn.Sequential(
260
- nn.SiLU(),
261
- nn.Linear(hidden_size, adaln_lora_dim, bias=True),
262
- nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True),
263
- )
264
- else:
265
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
266
-
267
- def forward(self, hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin):
268
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(conditioning).chunk(
269
- 6, dim=-1
270
- )
271
- hidden_states = hidden_states + gate_msa * self.attn(
272
- _modulate(self.norm1(hidden_states), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin
273
- )
274
- hidden_states = hidden_states + gate_mlp * self.mlp(
275
- _modulate(self.norm2(hidden_states), shift_mlp, scale_mlp)
276
- )
277
- return hidden_states
278
-
279
-
280
- class NiTFinalLayer(nn.Module):
281
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
282
- super().__init__()
283
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
284
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
285
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
286
-
287
- def forward(self, hidden_states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
288
- shift, scale = self.adaLN_modulation(conditioning).chunk(2, dim=-1)
289
- hidden_states = _modulate(self.norm_final(hidden_states), shift, scale)
290
- return self.linear(hidden_states)
291
-
292
-
293
- def _build_mlp(hidden_size: int, projector_dim: int, z_dim: int) -> nn.Sequential:
294
- return nn.Sequential(
295
- nn.Linear(hidden_size, projector_dim),
296
- nn.SiLU(),
297
- nn.Linear(projector_dim, projector_dim),
298
- nn.SiLU(),
299
- nn.Linear(projector_dim, z_dim),
300
- )
301
-
302
-
303
- class NiTTransformer2DModel(ModelMixin, ConfigMixin):
304
- config_name = "config.json"
305
-
306
- @register_to_config
307
- def __init__(
308
- self,
309
- input_size: int = 32,
310
- patch_size: int = 1,
311
- in_channels: int = 32,
312
- hidden_size: int = 1152,
313
- depth: int = 28,
314
- num_heads: int = 16,
315
- mlp_ratio: float = 4.0,
316
- class_dropout_prob: float = 0.1,
317
- num_classes: int = 1000,
318
- encoder_depth: int = 8,
319
- projector_dim: int = 2048,
320
- z_dim: int = 1280,
321
- use_checkpoint: bool = False,
322
- custom_freqs: str = "normal",
323
- theta: int = 10000,
324
- max_pe_len_h: Optional[int] = None,
325
- max_pe_len_w: Optional[int] = None,
326
- decouple: bool = False,
327
- ori_max_pe_len: Optional[int] = None,
328
- qk_norm: bool = True,
329
- use_adaln_lora: bool = False,
330
- adaln_lora_dim: int = 512,
331
- ):
332
- super().__init__()
333
- del input_size
334
- self.in_channels = in_channels
335
- self.out_channels = in_channels
336
- self.patch_size = patch_size
337
- self.num_heads = num_heads
338
- self.num_classes = num_classes
339
- self.encoder_depth = encoder_depth
340
- self.use_checkpoint = use_checkpoint
341
-
342
- self.x_embedder = NiTPatchEmbed(patch_size, in_channels, hidden_size)
343
- self.t_embedder = NiTTimestepEmbedder(hidden_size)
344
- self.y_embedder = NiTLabelEmbedder(num_classes, hidden_size, class_dropout_prob)
345
- self.rope = NiTRotaryEmbedding(
346
- hidden_size // num_heads,
347
- custom_freqs=custom_freqs,
348
- theta=theta,
349
- max_pe_len_h=max_pe_len_h,
350
- max_pe_len_w=max_pe_len_w,
351
- decouple=decouple,
352
- ori_max_pe_len=ori_max_pe_len,
353
- )
354
- self.projector = _build_mlp(hidden_size, projector_dim, z_dim)
355
- self.blocks = nn.ModuleList(
356
- [
357
- NiTBlock(
358
- hidden_size,
359
- num_heads,
360
- mlp_ratio=mlp_ratio,
361
- qk_norm=qk_norm,
362
- use_adaln_lora=use_adaln_lora,
363
- adaln_lora_dim=adaln_lora_dim,
364
- )
365
- for _ in range(depth)
366
- ]
367
- )
368
- self.final_layer = NiTFinalLayer(hidden_size, patch_size, self.out_channels)
369
-
370
- def _pack_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.LongTensor, Tuple[int, int]]:
371
- batch_size, channels, height, width = hidden_states.shape
372
- if channels != self.in_channels:
373
- raise ValueError(f"Expected {self.in_channels} latent channels, got {channels}.")
374
- if height % self.patch_size != 0 or width % self.patch_size != 0:
375
- raise ValueError("Latent height and width must be divisible by patch_size.")
376
- latent_h = height // self.patch_size
377
- latent_w = width // self.patch_size
378
- hidden_states = hidden_states.reshape(batch_size, channels, latent_h, self.patch_size, latent_w, self.patch_size)
379
- hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).reshape(
380
- batch_size * latent_h * latent_w, channels, self.patch_size, self.patch_size
381
- )
382
- image_sizes = torch.tensor([[latent_h, latent_w]] * batch_size, device=hidden_states.device, dtype=torch.long)
383
- return hidden_states, image_sizes, (height, width)
384
-
385
- def _unpack_latents(self, hidden_states: torch.Tensor, image_sizes: torch.LongTensor) -> torch.Tensor:
386
- if image_sizes.shape[0] == 1:
387
- height, width = image_sizes[0].tolist()
388
- hidden_states = hidden_states.reshape(height, width, self.out_channels, self.patch_size, self.patch_size)
389
- return hidden_states.permute(2, 0, 3, 1, 4).reshape(
390
- 1, self.out_channels, height * self.patch_size, width * self.patch_size
391
- )
392
-
393
- samples = []
394
- cursor = 0
395
- for height, width in image_sizes.tolist():
396
- length = height * width
397
- sample = hidden_states[cursor : cursor + length].reshape(
398
- height, width, self.out_channels, self.patch_size, self.patch_size
399
- )
400
- samples.append(
401
- sample.permute(2, 0, 3, 1, 4).reshape(
402
- self.out_channels, height * self.patch_size, width * self.patch_size
403
- )
404
- )
405
- cursor += length
406
- if len({tuple(sample.shape) for sample in samples}) != 1:
407
- return hidden_states
408
- return torch.stack(samples, dim=0)
409
-
410
- def forward(
411
- self,
412
- hidden_states: torch.Tensor,
413
- timestep: Union[torch.Tensor, float],
414
- class_labels: torch.LongTensor,
415
- image_sizes: Optional[Union[torch.LongTensor, List[Tuple[int, int]]]] = None,
416
- return_dict: bool = True,
417
- output_projection_states: bool = False,
418
- ) -> Union[NiTTransformer2DModelOutput, Tuple[torch.Tensor, ...]]:
419
- input_was_image = hidden_states.dim() == 4 and image_sizes is None
420
- if input_was_image:
421
- hidden_states, image_sizes, _ = self._pack_latents(hidden_states)
422
- elif image_sizes is None:
423
- raise ValueError("image_sizes must be provided when hidden_states are already packed.")
424
- elif not torch.is_tensor(image_sizes):
425
- image_sizes = torch.tensor(image_sizes, device=hidden_states.device, dtype=torch.long)
426
- else:
427
- image_sizes = image_sizes.to(device=hidden_states.device, dtype=torch.long)
428
-
429
- if not torch.is_tensor(timestep):
430
- timestep = torch.tensor([timestep], device=hidden_states.device, dtype=hidden_states.dtype)
431
- timestep = timestep.to(device=hidden_states.device, dtype=hidden_states.dtype).flatten()
432
- if timestep.numel() == 1:
433
- timestep = timestep.repeat(image_sizes.shape[0])
434
- class_labels = class_labels.to(device=hidden_states.device, dtype=torch.long).flatten()
435
-
436
- hidden_states = self.x_embedder(hidden_states).squeeze(1)
437
- freqs_cos, freqs_sin = self.rope(image_sizes, hidden_states.device)
438
-
439
- seqlens = image_sizes[:, 0] * image_sizes[:, 1]
440
- cu_seqlens = torch.cat(
441
- [torch.zeros(1, device=hidden_states.device, dtype=torch.int32), torch.cumsum(seqlens, dim=0).int()]
442
- )
443
-
444
- conditioning = self.t_embedder(timestep) + self.y_embedder(class_labels)
445
- conditioning = torch.cat([conditioning[i].repeat(int(seqlens[i]), 1) for i in range(image_sizes.shape[0])], dim=0)
446
-
447
- projection_states = []
448
- for index, block in enumerate(self.blocks):
449
- if self.use_checkpoint and self.training:
450
- hidden_states = torch.utils.checkpoint.checkpoint(
451
- block, hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin, use_reentrant=False
452
- )
453
- else:
454
- hidden_states = block(hidden_states, conditioning, cu_seqlens, freqs_cos, freqs_sin)
455
- if output_projection_states and (index + 1) == self.encoder_depth:
456
- projection_states.append(self.projector(hidden_states))
457
-
458
- hidden_states = self.final_layer(hidden_states, conditioning)
459
- hidden_states = hidden_states.reshape(hidden_states.shape[0], self.out_channels, self.patch_size, self.patch_size)
460
- if input_was_image:
461
- hidden_states = self._unpack_latents(hidden_states, image_sizes)
462
-
463
- if not return_dict:
464
- output = (hidden_states,)
465
- if output_projection_states:
466
- output = output + (tuple(projection_states),)
467
- return output
468
- return NiTTransformer2DModelOutput(
469
- sample=hidden_states,
470
- projection_states=tuple(projection_states) if output_projection_states else None,
471
- )