cp524 committed on
Commit 9a96e6d · 1 Parent(s): 78f8c32

Add SMC stuff

src/smc/lora_pipeline.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ from huggingface_hub.utils import validate_hf_hub_args
6
+
7
+ from diffusers.utils import (
8
+ USE_PEFT_BACKEND,
9
+ deprecate,
10
+ get_submodule_by_name,
11
+ is_bitsandbytes_available,
12
+ is_gguf_available,
13
+ is_peft_available,
14
+ is_peft_version,
15
+ is_torch_version,
16
+ is_transformers_available,
17
+ is_transformers_version,
18
+ logging,
19
+ )
20
+
21
+ from diffusers.loaders.lora_base import (
22
+ LoraBaseMixin,
23
+ _fetch_state_dict,
24
+ _pack_dict_with_prefix
25
+ )
26
+
27
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
28
+ if is_torch_version(">=", "1.9.0"):
29
+ if (
30
+ is_peft_available()
31
+ and is_peft_version(">=", "0.13.1")
32
+ and is_transformers_available()
33
+ and is_transformers_version(">", "4.45.2")
34
+ ):
35
+ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
+ TRANSFORMER_NAME = "transformer"
42
+
43
+ class MeissonicLoraLoaderMixin(LoraBaseMixin):
44
+ r"""
45
+ Load LoRA layers into [`Transformer2DModel`]. Specific to [`MeissonicPipeline`].
46
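+
+ Example (a minimal usage sketch; the local LoRA directory and weight file name below are
+ illustrative placeholders, not artifacts of this commit):
+
+ ```py
+ # `pipe` is an instance of the SMC `Pipeline` defined in `src/smc/pipeline.py`
+ pipe.load_lora_weights("path/to/lora_dir", weight_name="pytorch_lora_weights.safetensors")
+ ```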
+ """
47
+
48
+ _lora_loadable_modules = ["transformer"]
49
+ transformer_name = TRANSFORMER_NAME
50
+
51
+ @classmethod
52
+ @validate_hf_hub_args
53
+ def lora_state_dict(
54
+ cls,
55
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
56
+ return_alphas: bool = False,
57
+ **kwargs,
58
+ ):
59
+ r"""
60
+ Return state dict for lora weights and the network alphas.
61
+
62
+ <Tip warning={true}>
63
+
64
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
65
+
66
+ This function is experimental and might change in the future.
67
+
68
+ </Tip>
69
+
70
+ Parameters:
71
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
72
+ Can be either:
73
+
74
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
75
+ the Hub.
76
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
77
+ with [`ModelMixin.save_pretrained`].
78
+ - A [torch state
79
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
80
+
81
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
82
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
83
+ is not used.
84
+ force_download (`bool`, *optional*, defaults to `False`):
85
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
86
+ cached versions if they exist.
87
+
88
+ proxies (`Dict[str, str]`, *optional*):
89
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
90
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
91
+ local_files_only (`bool`, *optional*, defaults to `False`):
92
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
93
+ won't be downloaded from the Hub.
94
+ token (`str` or *bool*, *optional*):
95
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
96
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
97
+ revision (`str`, *optional*, defaults to `"main"`):
98
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
99
+ allowed by Git.
100
+ subfolder (`str`, *optional*, defaults to `""`):
101
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
102
+ return_lora_metadata (`bool`, *optional*, defaults to False):
103
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
104
+ """
105
+ # Load the main state dict first, which has the LoRA layers for the
106
+ # transformer, the text encoder, or both.
107
+ cache_dir = kwargs.pop("cache_dir", None)
108
+ force_download = kwargs.pop("force_download", False)
109
+ proxies = kwargs.pop("proxies", None)
110
+ local_files_only = kwargs.pop("local_files_only", None)
111
+ token = kwargs.pop("token", None)
112
+ revision = kwargs.pop("revision", None)
113
+ subfolder = kwargs.pop("subfolder", None)
114
+ weight_name = kwargs.pop("weight_name", None)
115
+ use_safetensors = kwargs.pop("use_safetensors", None)
116
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
117
+
118
+ allow_pickle = False
119
+ if use_safetensors is None:
120
+ use_safetensors = True
121
+ allow_pickle = True
122
+
123
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
124
+
125
+ state_dict, metadata = _fetch_state_dict(
126
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
127
+ weight_name=weight_name,
128
+ use_safetensors=use_safetensors,
129
+ local_files_only=local_files_only,
130
+ cache_dir=cache_dir,
131
+ force_download=force_download,
132
+ proxies=proxies,
133
+ token=token,
134
+ revision=revision,
135
+ subfolder=subfolder,
136
+ user_agent=user_agent,
137
+ allow_pickle=allow_pickle,
138
+ )
139
+
140
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
141
+ if is_dora_scale_present:
142
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
143
+ logger.warning(warn_msg)
144
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
145
+
146
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
147
+ return out
148
+
149
+ def load_lora_weights(
150
+ self,
151
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
152
+ adapter_name: Optional[str] = None,
153
+ hotswap: bool = False,
154
+ **kwargs,
155
+ ):
156
+ """
157
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
158
+ All kwargs are forwarded to `self.lora_state_dict`. See
159
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
160
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
161
+ dict is loaded into `self.transformer`.
162
+
163
+ Parameters:
164
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
165
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
166
+ adapter_name (`str`, *optional*):
167
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
168
+ `default_{i}` where i is the total number of adapters being loaded.
169
+ low_cpu_mem_usage (`bool`, *optional*):
170
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
171
+ weights.
172
+ hotswap (`bool`, *optional*):
173
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
174
+ kwargs (`dict`, *optional*):
175
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
176
+ """
177
+ if not USE_PEFT_BACKEND:
178
+ raise ValueError("PEFT backend is required for this method.")
179
+
180
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
181
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
182
+ raise ValueError(
183
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
184
+ )
185
+
186
+ # if a dict is passed, copy it instead of modifying it inplace
187
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
188
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
189
+
190
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
191
+ kwargs["return_lora_metadata"] = True
192
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
193
+
194
+ is_correct_format = all("lora" in key for key in state_dict.keys())
195
+ if not is_correct_format:
196
+ raise ValueError("Invalid LoRA checkpoint.")
197
+
198
+ self.load_lora_into_transformer(
199
+ state_dict,
200
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
201
+ adapter_name=adapter_name,
202
+ metadata=metadata,
203
+ _pipeline=self,
204
+ low_cpu_mem_usage=low_cpu_mem_usage,
205
+ hotswap=hotswap,
206
+ )
207
+
208
+ @classmethod
209
+ def load_lora_into_transformer(
210
+ cls,
211
+ state_dict,
212
+ transformer,
213
+ adapter_name=None,
214
+ _pipeline=None,
215
+ low_cpu_mem_usage=False,
216
+ hotswap: bool = False,
217
+ metadata=None,
218
+ ):
219
+ """
220
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
221
+
222
+ Parameters:
223
+ state_dict (`dict`):
224
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
225
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
226
+ them from text encoder lora layers.
227
+ transformer (`Transformer2DModel`):
228
+ The Transformer model to load the LoRA layers into.
229
+ adapter_name (`str`, *optional*):
230
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
231
+ `default_{i}` where i is the total number of adapters being loaded.
232
+ low_cpu_mem_usage (`bool`, *optional*):
233
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
234
+ weights.
235
+ hotswap (`bool`, *optional*):
236
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
237
+ metadata (`dict`):
238
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
239
+ from the state dict.
240
+ """
241
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
242
+ raise ValueError(
243
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
244
+ )
245
+
246
+ # Load the layers corresponding to transformer.
247
+ logger.info(f"Loading {cls.transformer_name}.")
248
+ transformer.load_lora_adapter(
249
+ state_dict,
250
+ network_alphas=None,
251
+ adapter_name=adapter_name,
252
+ metadata=metadata,
253
+ _pipeline=_pipeline,
254
+ low_cpu_mem_usage=low_cpu_mem_usage,
255
+ hotswap=hotswap,
256
+ )
257
+
258
+ @classmethod
259
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
260
+ def save_lora_weights(
261
+ cls,
262
+ save_directory: Union[str, os.PathLike],
263
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
264
+ is_main_process: bool = True,
265
+ weight_name: str = None,
266
+ save_function: Callable = None,
267
+ safe_serialization: bool = True,
268
+ transformer_lora_adapter_metadata: Optional[dict] = None,
269
+ ):
270
+ r"""
271
+ Save the LoRA parameters corresponding to the transformer.
272
+
273
+ Arguments:
274
+ save_directory (`str` or `os.PathLike`):
275
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
276
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
277
+ State dict of the LoRA layers corresponding to the `transformer`.
278
+ is_main_process (`bool`, *optional*, defaults to `True`):
279
+ Whether the process calling this is the main process or not. Useful during distributed training when you
280
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
281
+ process to avoid race conditions.
282
+ save_function (`Callable`):
283
+ The function to use to save the state dictionary. Useful during distributed training when you need to
284
+ replace `torch.save` with another method. Can be configured with the environment variable
285
+ `DIFFUSERS_SAVE_MODE`.
286
+ safe_serialization (`bool`, *optional*, defaults to `True`):
287
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
288
+ transformer_lora_adapter_metadata:
289
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
290
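+
+ Example (a sketch only; pulling `transformer_lora_layers` from `peft.get_peft_model_state_dict`
+ is an assumption for illustration, not something this commit shows):
+
+ ```py
+ from peft import get_peft_model_state_dict
+
+ lora_layers = get_peft_model_state_dict(pipe.transformer)
+ pipe.save_lora_weights("lora_out", transformer_lora_layers=lora_layers)
+ ```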
+ """
291
+ state_dict = {}
292
+ lora_adapter_metadata = {}
293
+
294
+ if not transformer_lora_layers:
295
+ raise ValueError("You must pass `transformer_lora_layers`.")
296
+
297
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
298
+
299
+ if transformer_lora_adapter_metadata is not None:
300
+ lora_adapter_metadata.update(
301
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
302
+ )
303
+
304
+ # Save the model
305
+ cls.write_lora_layers(
306
+ state_dict=state_dict,
307
+ save_directory=save_directory,
308
+ is_main_process=is_main_process,
309
+ weight_name=weight_name,
310
+ save_function=save_function,
311
+ safe_serialization=safe_serialization,
312
+ lora_adapter_metadata=lora_adapter_metadata,
313
+ )
src/smc/pipeline.py ADDED
@@ -0,0 +1,675 @@
1
+ from typing import Optional, Tuple, Callable, List
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+ from diffusers.image_processor import VaeImageProcessor
8
+ from diffusers.models.autoencoders.vq_model import VQModel
9
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+
13
+ from src.smc.transformer import Transformer2DModel
14
+ from src.smc.scheduler import BaseScheduler
15
+ from src.smc.resampling import compute_ess_from_log_w, normalize_weights
16
+ from src.smc.lora_pipeline import MeissonicLoraLoaderMixin
17
+
18
+
19
+ def logmeanexp(x, dim=None, keepdim=False):
20
+ """Numerically stable log-mean-exp using torch.logsumexp."""
21
+ if dim is None:
22
+ x = x.view(-1)
23
+ dim = 0
24
+ # log-sum-exp with or without keeping the reduced dim
25
+ lse = torch.logsumexp(x, dim=dim, keepdim=keepdim)
26
+ # subtract log(N) to convert sum into mean (broadcasts correctly)
27
+ return lse - math.log(x.size(dim))
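+
+ # Illustrative check (comment added for readability): logmeanexp(torch.tensor([0.0, math.log(3.0)]))
+ # is log(mean(exp([0, log 3]))) = log(mean([1, 3])) = log(2), up to floating point error.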
28
+
29
+
30
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
31
+ """
32
+ Build positional IDs for latent-image tokens.
33
+
34
+ Each latent token corresponds to a downsampled image “pixel” in a 2D grid.
35
+ This function creates a (H//2, W//2, 3) grid where:
36
+ - channel 0 is reserved (all zeros)
37
+ - channel 1 stores the row index (0 .. H//2-1)
38
+ - channel 2 stores the column index (0 .. W//2-1)
39
+
40
+ Args:
41
+ batch_size (int): Number of images in the batch (unused here, but kept for API consistency).
42
+ height (int): Input image height (pre-VAE) or latent height depending on call site.
43
+ width (int): Input image width (pre-VAE) or latent width depending on call site.
44
+ device (torch.device): Device on which to place the returned tensor.
45
+ dtype (torch.dtype): Desired data type of the returned tensor.
46
+
47
+ Returns:
48
+ torch.Tensor of shape ((H//2 * W//2), 3) with dtype and device as specified.
49
+ Each row is [0, row_index, col_index], flattened in row-major order.
50
+ """
51
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
52
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
53
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
54
+
55
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
56
+
57
+ latent_image_ids = latent_image_ids.reshape(
58
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
59
+ )
60
+
61
+ return latent_image_ids.to(device=device, dtype=dtype)
62
+
63
+
64
+ class Pipeline(
65
+ DiffusionPipeline,
66
+ MeissonicLoraLoaderMixin,
67
+ ):
68
+ image_processor: VaeImageProcessor
69
+ vqvae: VQModel
70
+ tokenizer: CLIPTokenizer
71
+ text_encoder: CLIPTextModelWithProjection
72
+ transformer: Transformer2DModel
73
+ scheduler: BaseScheduler
74
+
75
+ model_cpu_offload_seq = "text_encoder->transformer->vqvae"
76
+
77
+ def __init__(
78
+ self,
79
+ vqvae: VQModel,
80
+ tokenizer: CLIPTokenizer,
81
+ text_encoder: CLIPTextModelWithProjection,
82
+ transformer: Transformer2DModel,
83
+ scheduler: BaseScheduler,
84
+ ):
85
+ super().__init__()
86
+
87
+ self.register_modules(
88
+ vqvae=vqvae,
89
+ tokenizer=tokenizer,
90
+ text_encoder=text_encoder,
91
+ transformer=transformer,
92
+ scheduler=scheduler,
93
+ )
94
+ self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) # type: ignore
95
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
96
+ self.model_dtype = torch.bfloat16
97
+
98
+ self.mask_index = self.scheduler.mask_token_id # type: ignore
99
+ self.vocab_size = self.transformer.config.vocab_size # type:ignore
100
+ self.codebook_size = self.transformer.config.codebook_size # type: ignore
101
+
102
+ @torch.no_grad()
103
+ def __call__(
104
+ self,
105
+ prompt: str|List[str],
106
+ reward_fn: Callable,
107
+ resample_fn: Callable,
108
+ resample_frequency: int = 1,
109
+ kl_weight: float = 1.0,
110
+ lambdas: Optional[torch.Tensor] = None,
111
+ height: Optional[int] = 1024,
112
+ width: Optional[int] = 1024,
113
+ num_inference_steps: int = 48,
114
+ guidance_scale: float = 9.0,
115
+ negative_prompt = None,
116
+ batches: int = 1, # Number of independent SMCs
117
+ num_particles: int = 1, # Number of particles per SMC
118
+ batch_p: int = 1, # Number of parallel particles
119
+ phi: int = 1, # number of samples for reward approximation
120
+ tau: float = 1.0, # temperature for taking x0 samples
121
+ output_type="pil",
122
+ micro_conditioning_aesthetic_score: int = 6,
123
+ micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
124
+ proposal_type:str = "locally_optimal",
125
+ ft_model_pipe = None, # needs to be supplied if proposal_type is ft_model
126
+ use_ft_model_for_expected_reward: bool = False, # Whether to use the forward model for expected reward
127
+ use_continuous_formulation: bool = False, # Whether to use a continuous formulation of carry over unmasking
128
+ disable_progress_bar: bool = False,
129
+ final_strategy="argmax_rewards",
130
+ verbose=True,
131
+ ):
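+ """
+ Reward-guided SMC sampling loop (summary added for readability; it restates what the code
+ below does rather than quoting an official docstring). `batches` independent SMC runs with
+ `num_particles` particles each are started from fully masked latents. At every step the
+ chosen proposal (`proposal_type`) is evaluated in chunks of `batch_p` particles, twists are
+ computed from `reward_fn` scaled by `lambdas / kl_weight`, incremental log-weights are
+ accumulated, and particles are resampled with `resample_fn` every `resample_frequency`
+ steps. After the last step one particle per batch is selected via `final_strategy` and
+ decoded to `output_type`.
+ """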
132
+ # 0. Set default lambdas
133
+ if lambdas is None:
134
+ lambdas = torch.ones(num_inference_steps + 1)
135
+ assert len(lambdas) == num_inference_steps + 1, f"lambdas must be of length {num_inference_steps + 1}"
136
+ lambdas = lambdas.clamp_min(0.001).to(self._execution_device)
137
+
138
+ # 1. n_particles, batch_size etc
139
+ total_particles = batches * num_particles
140
+ batch_p = min(batch_p, total_particles)
141
+ H, W = height // self.vae_scale_factor, width // self.vae_scale_factor
142
+
143
+ # 2.1. Calculate prompt (and negative prompt) embeddings
144
+ if isinstance(prompt, str):
145
+ prompt = [prompt]
146
+ input_ids = self.tokenizer(
147
+ prompt,
148
+ return_tensors="pt",
149
+ padding="max_length",
150
+ truncation=True,
151
+ max_length=77,
152
+ ).input_ids.to(self._execution_device)
153
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
154
+ prompt_embeds = outputs.text_embeds
155
+ encoder_hidden_states = outputs.hidden_states[-2]
156
+ prompt_embeds = prompt_embeds.repeat(batch_p, 1)
157
+ encoder_hidden_states = encoder_hidden_states.repeat(batch_p, 1, 1)
158
+ if guidance_scale > 1.0:
159
+ if negative_prompt is None:
160
+ negative_prompt = [""]
161
+ else:
162
+ negative_prompt = [negative_prompt]
163
+ input_ids = self.tokenizer(
164
+ negative_prompt,
165
+ return_tensors="pt",
166
+ padding="max_length",
167
+ truncation=True,
168
+ max_length=77,
169
+ ).input_ids.to(self._execution_device)
170
+ outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
171
+ negative_prompt_embeds = outputs.text_embeds
172
+ negative_encoder_hidden_states = outputs.hidden_states[-2]
173
+ negative_prompt_embeds = negative_prompt_embeds.repeat(batch_p, 1)
174
+ negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(batch_p, 1, 1)
175
+ prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
176
+ encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])
177
+
178
+ # 2.2. Prepare micro-conditions
179
+ micro_conds = torch.tensor(
180
+ [
181
+ width,
182
+ height,
183
+ micro_conditioning_crop_coord[0],
184
+ micro_conditioning_crop_coord[1],
185
+ micro_conditioning_aesthetic_score,
186
+ ],
187
+ device=self._execution_device,
188
+ dtype=encoder_hidden_states.dtype,
189
+ )
190
+ micro_conds = micro_conds.unsqueeze(0)
191
+ micro_conds = micro_conds.expand(2 * batch_p if guidance_scale > 1.0 else batch_p, -1)
192
+
193
+
194
+ # 3. Initialize latents
195
+ latents = torch.full(
196
+ (total_particles, H, W), self.mask_index, dtype=torch.long, device=self._execution_device # type: ignore
197
+ )
198
+
199
+ # Set some constant vectors
200
+ ONE = torch.ones(self.vocab_size, device=self._execution_device).float()
201
+ MASK = F.one_hot(torch.tensor(self.mask_index), num_classes=self.vocab_size).float().to(self._execution_device) # type: ignore
202
+
203
+ # 4. Set scheduler timesteps
204
+ self.scheduler.set_timesteps(num_inference_steps)
205
+
206
+ # 5. Set SMC variables
207
+ logits = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
208
+ logits_ft_model = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
209
+ rewards = torch.zeros((total_particles,), device=self._execution_device)
210
+ rewards_grad = torch.zeros((*latents.shape, self.vocab_size), device=self._execution_device)
211
+ log_twist = torch.zeros((total_particles, ), device=self._execution_device)
212
+ log_prob_proposal = torch.zeros((total_particles, ), device=self._execution_device)
213
+ log_prob_diffusion = torch.zeros((total_particles, ), device=self._execution_device)
214
+ log_w = torch.zeros((total_particles, ), device=self._execution_device)
215
+
216
+ def propagate():
217
+ if proposal_type == "locally_optimal":
218
+ propagate_locally_optimal()
219
+ # elif proposal_type == "straight_through_gradients":
220
+ # propagate_straight_through_gradients()
221
+ elif proposal_type == "reverse":
222
+ propagate_reverse()
223
+ elif proposal_type == "without_SMC":
224
+ propagate_without_SMC()
225
+ elif proposal_type == "ft_model":
226
+ propagate_ft_model()
227
+ else:
228
+ raise NotImplementedError(f"Proposal type {proposal_type} is not implemented.")
229
+
230
+ def propagate_locally_optimal():
231
+ nonlocal log_w, latents, log_prob_proposal, log_prob_diffusion, logits, rewards, rewards_grad, log_twist
232
+ log_twist_prev = log_twist.clone()
233
+ for j in range(0, total_particles, batch_p):
234
+ latents_batch = latents[j:j+batch_p]
235
+ with torch.enable_grad():
236
+ latents_one_hot = F.one_hot(latents_batch, num_classes=self.vocab_size).to(dtype=self.model_dtype).requires_grad_(True)
237
+ tmp_logits = self.get_logits(latents_one_hot, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
238
+
239
+ tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
240
+ gamma = 1 - ((ONE - MASK) * latents_one_hot).sum(dim=-1, keepdim=True)
241
+ for phi_i in range(phi):
242
+ sample = F.gumbel_softmax(tmp_logits, tau=tau, hard=True)
243
+ if use_continuous_formulation:
244
+ sample = gamma * sample + (ONE - MASK) * latents_one_hot
245
+ sample = self._decode_one_hot_latents(sample, batch_p, height, width, "pt")
246
+ tmp_rewards[:, phi_i] = reward_fn(sample)
247
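+ # Aggregate the phi reward samples as (1/scale) * log E[exp(scale * r)]: a soft maximum
+ # that tends to max_i r_i for large scale_cur and to the plain mean for small scale_cur.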
+ tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
248
+
249
+ tmp_rewards_grad = torch.autograd.grad(
250
+ outputs=tmp_rewards,
251
+ inputs=latents_one_hot,
252
+ grad_outputs=torch.ones_like(tmp_rewards)
253
+ )[0].detach()
254
+
255
+ logits[j:j+batch_p] = tmp_logits.detach()
256
+ rewards[j:j+batch_p] = tmp_rewards.detach()
257
+ rewards_grad[j:j+batch_p] = tmp_rewards_grad.detach()
258
+ log_twist[j:j+batch_p] = rewards[j:j+batch_p] * scale_cur
259
+
260
+ if verbose:
261
+ print("Rewards: ", rewards)
262
+
263
+ # Calculate weights
264
+ incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
265
+ log_w += incremental_log_w
266
+
267
+ # Now reshape log_w to (batches, num_particles)
268
+ log_w = log_w.reshape(batches, num_particles)
269
+
270
+ if verbose:
271
+ print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
272
+ print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
273
+ print("Incremental log weights: ", incremental_log_w)
274
+ print("Log weights: ", log_w)
275
+ print("Normalized weights: ", normalize_weights(log_w, dim=-1))
276
+
277
+ # Resample particles
278
+ if verbose:
279
+ print(f"ESS: ", compute_ess_from_log_w(log_w, dim=-1))
280
+
281
+ if resample_condition:
282
+ resample_indices = []
283
+ log_w_new = []
284
+ is_resampled = False
285
+ for batch in range(batches):
286
+ resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
287
+ resample_indices.append(resample_indices_batch + batch * num_particles)
288
+ log_w_new.append(log_w_batch)
289
+ is_resampled = is_resampled or is_resampled_batch
290
+
291
+ resample_indices = torch.cat(resample_indices, dim=0)
292
+ log_w = torch.cat(log_w_new, dim=0)
293
+
294
+ if is_resampled:
295
+ latents = latents[resample_indices]
296
+ logits = logits[resample_indices]
297
+ rewards = rewards[resample_indices]
298
+ rewards_grad = rewards_grad[resample_indices]
299
+ log_twist = log_twist[resample_indices]
300
+
301
+ if verbose:
302
+ print("Resample indices: ", resample_indices)
303
+
304
+ if log_w.ndim == 2:
305
+ log_w = log_w.reshape(total_particles)
306
+
307
+
308
+ # Propose new particles
309
+ sched_out = self.scheduler.step_with_approx_guidance(
310
+ latents=latents,
311
+ logits=logits,
312
+ approx_guidance=rewards_grad * scale_next,
313
+ step=i,
314
+ )
315
+ if verbose:
316
+ print("Approx guidance norm: ", ((rewards_grad * scale_next) ** 2).sum(dim=(1, 2)).sqrt())
317
+ latents, log_prob_proposal, log_prob_diffusion = (
318
+ sched_out.new_latents,
319
+ sched_out.log_prob_proposal,
320
+ sched_out.log_prob_diffusion,
321
+ )
322
+
323
+ def propagate_reverse():
324
+ nonlocal log_w, latents, logits, rewards, log_twist
325
+ log_twist_prev = log_twist.clone()
326
+ for j in range(0, total_particles, batch_p):
327
+ latents_batch = latents[j:j+batch_p]
328
+ with torch.no_grad():
329
+ tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
330
+
331
+ tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
332
+ tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
333
+ for phi_i in range(phi):
334
+ sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
335
+ sample = self._decode_latents(sample, batch_p, height, width, "pt")
336
+ tmp_rewards[:, phi_i] = reward_fn(sample)
337
+ tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
338
+
339
+ logits[j:j+batch_p] = tmp_logits.detach()
340
+ rewards[j:j+batch_p] = tmp_rewards.detach()
341
+ log_twist[j:j+batch_p] = rewards[j:j+batch_p] * scale_cur
342
+
343
+ if verbose:
344
+ print("Rewards: ", rewards)
345
+
346
+ # Calculate weights
347
+ incremental_log_w = (log_twist - log_twist_prev)
348
+ log_w += incremental_log_w
349
+
350
+ # Now reshape log_w to (batches, num_particles)
351
+ log_w = log_w.reshape(batches, num_particles)
352
+
353
+ if verbose:
354
+ print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
355
+ print("Incremental log weights: ", incremental_log_w)
356
+ print("Log weights: ", log_w)
357
+ print("Normalized weights: ", normalize_weights(log_w, dim=-1))
358
+
359
+ # Resample particles
360
+ if verbose:
361
+ print(f"ESS: ", compute_ess_from_log_w(log_w, dim=-1))
362
+
363
+ if resample_condition:
364
+ resample_indices = []
365
+ log_w_new = []
366
+ is_resampled = False
367
+ for batch in range(batches):
368
+ resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
369
+ resample_indices.append(resample_indices_batch + batch * num_particles)
370
+ log_w_new.append(log_w_batch)
371
+ is_resampled = is_resampled or is_resampled_batch
372
+
373
+ resample_indices = torch.cat(resample_indices, dim=0)
374
+ log_w = torch.cat(log_w_new, dim=0)
375
+
376
+ if is_resampled:
377
+ latents = latents[resample_indices]
378
+ logits = logits[resample_indices]
379
+ rewards = rewards[resample_indices]
380
+ log_twist = log_twist[resample_indices]
381
+
382
+ if verbose:
383
+ print("Resample indices: ", resample_indices)
384
+
385
+ if log_w.ndim == 2:
386
+ log_w = log_w.reshape(total_particles)
387
+
388
+
389
+ # Propose new particles
390
+ sched_out = self.scheduler.step(
391
+ latents=latents,
392
+ logits=logits,
393
+ step=i,
394
+ )
395
+ latents = sched_out.new_latents
396
+
397
+ def propagate_without_SMC():
398
+ nonlocal latents, logits
399
+ for j in range(0, total_particles, batch_p):
400
+ latents_batch = latents[j:j+batch_p]
401
+ with torch.no_grad():
402
+ tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
403
+ logits[j:j+batch_p] = tmp_logits.detach()
404
+
405
+ # Propose new particles
406
+ sched_out = self.scheduler.step(
407
+ latents=latents,
408
+ logits=logits,
409
+ step=i,
410
+ )
411
+ latents = sched_out.new_latents
412
+
413
+ def propagate_ft_model():
414
+ assert ft_model_pipe is not None, f"ft_model must be provided for proposal_type={proposal_type}."
415
+ nonlocal log_w, latents, log_prob_proposal, log_prob_diffusion, logits, logits_ft_model, rewards, log_twist
416
+ log_twist_prev = log_twist.clone()
417
+ for j in range(0, total_particles, batch_p):
418
+ latents_batch = latents[j:j+batch_p]
419
+ with torch.no_grad():
420
+ tmp_logits = self.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
421
+ tmp_logits_ft_model = ft_model_pipe.get_logits(latents_batch, guidance_scale, height, encoder_hidden_states, micro_conds, prompt_embeds, timestep)
422
+
423
+ tmp_rewards = torch.zeros(latents_batch.size(0), phi, device=self._execution_device)
424
+ if use_ft_model_for_expected_reward:
425
+ tmp_logp_x0 = ft_model_pipe._subs_parameterization(tmp_logits_ft_model, latents_batch)
426
+ else:
427
+ tmp_logp_x0 = self._subs_parameterization(tmp_logits, latents_batch)
428
+ for phi_i in range(phi):
429
+ sample = F.gumbel_softmax(tmp_logp_x0, tau=tau, hard=True).argmax(dim=-1)
430
+ sample = self._decode_latents(sample, batch_p, height, width, "pt")
431
+ tmp_rewards[:, phi_i] = reward_fn(sample)
432
+ tmp_rewards = logmeanexp(tmp_rewards * scale_cur, dim=-1) / scale_cur
433
+
434
+ logits[j:j+batch_p] = tmp_logits.detach()
435
+ logits_ft_model[j:j+batch_p] = tmp_logits_ft_model.detach()
436
+ rewards[j:j+batch_p] = tmp_rewards.detach()
437
+ log_twist[j:j+batch_p] = rewards[j:j+batch_p] * scale_cur
438
+
439
+ if verbose:
440
+ print("Rewards: ", rewards)
441
+
442
+ # Calculate weights
443
+ incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
444
+ log_w += incremental_log_w
445
+
446
+ # Now reshape log_w to (batches, num_particles)
447
+ log_w = log_w.reshape(batches, num_particles)
448
+
449
+ if verbose:
450
+ print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
451
+ print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
452
+ print("Incremental log weights: ", incremental_log_w)
453
+ print("Log weights: ", log_w)
454
+ print("Normalized weights: ", normalize_weights(log_w, dim=-1))
455
+
456
+ # Resample particles
457
+ if verbose:
458
+ print(f"ESS: ", compute_ess_from_log_w(log_w, dim=-1))
459
+
460
+ if resample_condition:
461
+ resample_indices = []
462
+ log_w_new = []
463
+ is_resampled = False
464
+ for batch in range(batches):
465
+ resample_indices_batch, is_resampled_batch, log_w_batch = resample_fn(log_w[batch])
466
+ resample_indices.append(resample_indices_batch + batch * num_particles)
467
+ log_w_new.append(log_w_batch)
468
+ is_resampled = is_resampled or is_resampled_batch
469
+
470
+ resample_indices = torch.cat(resample_indices, dim=0)
471
+ log_w = torch.cat(log_w_new, dim=0)
472
+
473
+ if is_resampled:
474
+ latents = latents[resample_indices]
475
+ logits = logits[resample_indices]
476
+ logits_ft_model = logits_ft_model[resample_indices]
477
+ rewards = rewards[resample_indices]
478
+ log_twist = log_twist[resample_indices]
479
+
480
+ if verbose:
481
+ print("Resample indices: ", resample_indices)
482
+
483
+ if log_w.ndim == 2:
484
+ log_w = log_w.reshape(total_particles)
485
+
486
+
487
+ # Propose new particles
488
+ approx_guidance = logits_ft_model - logits # this effectively makes logits_ft_model the proposal distribution
489
+ approx_guidance[..., self.codebook_size:] = 0.0 # avoid nan due to (inf - inf)
490
+ sched_out = self.scheduler.step_with_approx_guidance(
491
+ latents=latents,
492
+ logits=logits,
493
+ approx_guidance=approx_guidance,
494
+ step=i,
495
+ )
496
+ latents, log_prob_proposal, log_prob_diffusion = (
497
+ sched_out.new_latents,
498
+ sched_out.log_prob_proposal,
499
+ sched_out.log_prob_diffusion,
500
+ )
501
+
502
+ bar = enumerate(reversed(range(num_inference_steps)))
503
+ if not disable_progress_bar:
504
+ bar = tqdm(bar, leave=False)
505
+ for i, timestep in bar:
506
+ resample_condition = (i + 1) % resample_frequency == 0
507
+ scale_cur = lambdas[i] / kl_weight
508
+ scale_next = lambdas[i + 1] / kl_weight
509
+ if verbose:
510
+ print(f"scale_cur: {scale_cur}, scale_next: {scale_next}")
511
+ propagate()
512
+ print('\n\n')
513
+
514
+ # Final SMC weights
515
+ scale_cur = lambdas[-1] / kl_weight
516
+ log_twist_prev = log_twist.clone()
517
+ for j in range(0, total_particles, batch_p):
518
+ latents_batch = latents[j:j+batch_p]
519
+ with torch.no_grad():
520
+ sample = self._decode_latents(latents_batch, batch_p, height, width, "pt")
521
+ tmp_rewards = reward_fn(sample)
522
+ rewards[j:j+batch_p] = tmp_rewards
523
+ log_twist[j:j+batch_p] = tmp_rewards * scale_cur
524
+
525
+ if verbose:
526
+ print("Rewards: ", rewards)
527
+
528
+ # Calculate weights
529
+ incremental_log_w = (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev)
530
+ log_w += incremental_log_w
531
+
532
+ # Now reshape everything to (batches, num_particles) for final strategy
533
+ log_w = log_w.reshape(batches, num_particles)
534
+ latents = latents.reshape(batches, num_particles, H, W)
535
+ rewards = rewards.reshape(batches, num_particles)
536
+
537
+ if verbose:
538
+ print("log_prob_diffusion - log_prob_proposal: ", log_prob_diffusion - log_prob_proposal)
539
+ print("log_twist - log_twist_prev:", log_twist - log_twist_prev)
540
+ print("Incremental log weights: ", incremental_log_w)
541
+ print("Log weights: ", log_w)
542
+ print("Normalized weights: ", normalize_weights(log_w, dim=-1))
543
+
544
+ if final_strategy == "multinomial":
545
+ final_indices = torch.multinomial(normalize_weights(log_w, dim=-1), num_samples=1).squeeze(-1)
546
+ elif final_strategy == "argmax_rewards":
547
+ final_indices = rewards.argmax(dim=-1)
548
+ elif final_strategy == "argmax_weights":
549
+ final_indices = log_w.argmax(dim=-1)
550
+ else:
551
+ raise NotImplementedError(f"Final strategy {final_strategy} is not implemented.")
552
+
553
+ if verbose:
554
+ print("Final selected indices: ", final_indices)
555
+
556
+ latents = latents[
557
+ torch.arange(batches, device=latents.device),
558
+ final_indices
559
+ ]
560
+
561
+ # Decode latents
562
+ outputs = []
563
+ for j in range(0, batches, batch_p):
564
+ latents_batch = latents[j:j+batch_p]
565
+ outputs.extend(
566
+ self._decode_latents(latents_batch, batch_p, height, width, output_type) # type: ignore
567
+ )
568
+ if output_type == "pt":
569
+ outputs = torch.stack(outputs, dim=0)
570
+ return outputs
571
+
572
+ def get_logits(self, latents, guidance_scale, resolution, encoder_hidden_states, micro_conds, prompt_embeds, timestep):
573
+ if guidance_scale > 1.0:
574
+ # Latents are duplicated to get both unconditional and conditional logits
575
+ model_input = torch.cat([latents] * 2) # type: ignore
576
+ else:
577
+ model_input = latents
578
+ # img_ids, text_ids are used for positional embeddings
579
+ if resolution == 1024:
580
+ img_ids = _prepare_latent_image_ids(model_input.shape[0], model_input.shape[1], model_input.shape[2], model_input.device, model_input.dtype)
581
+ else:
582
+ img_ids = _prepare_latent_image_ids(model_input.shape[0], 2 * model_input.shape[1], 2 * model_input.shape[2], model_input.device, model_input.dtype)
583
+ txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype)
584
+ model_output = self.transformer(
585
+ hidden_states = model_input,
586
+ micro_conds=micro_conds,
587
+ pooled_projections=prompt_embeds,
588
+ encoder_hidden_states=encoder_hidden_states,
589
+ img_ids = img_ids,
590
+ txt_ids = txt_ids,
591
+ timestep = torch.tensor([timestep], device=model_input.device, dtype=torch.long),
592
+ )
593
+ if guidance_scale > 1.0:
594
+ uncond_logits, cond_logits = model_output.chunk(2)
595
+ model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
596
+ tmp_logits = torch.permute(model_output, (0, 2, 3, 1)).float()
597
+ pad_logits = torch.full(
598
+ (*tmp_logits.shape[:3], self.vocab_size - self.codebook_size),
599
+ -torch.inf,
600
+ device=tmp_logits.device, dtype=tmp_logits.dtype
601
+ )
602
+ tmp_logits = torch.cat([tmp_logits, pad_logits], dim=-1)
603
+ return tmp_logits
604
+
605
+ def _decode_latents(self, latents, batch_size, height, width, output_type):
606
+ if output_type == "latent":
607
+ output = latents
608
+ else:
609
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast # type: ignore
610
+ if needs_upcasting:
611
+ self.vqvae.float()
612
+ output = self.vqvae.decode(
613
+ latents,
614
+ force_not_quantize=True,
615
+ shape=(
616
+ batch_size,
617
+ height // self.vae_scale_factor,
618
+ width // self.vae_scale_factor,
619
+ self.vqvae.config.latent_channels, # type: ignore
620
+ ),
621
+ ).sample.clip(0, 1) # type: ignore
622
+ output = self.image_processor.postprocess(output, output_type)
623
+ if needs_upcasting:
624
+ self.vqvae.half()
625
+ return output
626
+
627
+ def _decode_one_hot_latents(self, latents_one_hot, batch_size, height, width, output_type):
628
+ shape = (
629
+ batch_size,
630
+ height // self.vae_scale_factor,
631
+ width // self.vae_scale_factor,
632
+ self.vqvae.config.latent_channels, # type: ignore
633
+ )
634
+ codebook_size = self.transformer.config.codebook_size #type: ignore
635
+
636
+ needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast # type: ignore
637
+ if needs_upcasting:
638
+ self.vqvae.float()
639
+
640
+ # get quantized latent vectors
641
+ embedding = self.vqvae.quantize.embedding.weight
642
+ h: torch.Tensor = latents_one_hot[..., :codebook_size].to(embedding.dtype) @ embedding
643
+ h = h.view(shape)
644
+ # reshape back to match original input shape
645
+ h = h.permute(0, 3, 1, 2).contiguous()
646
+
647
+ # Setting lookup_from_codebook to False, as we already have the codebook embeddings in h
648
+ self.vqvae.config.lookup_from_codebook = False # type: ignore
649
+ output = self.vqvae.decode(
650
+ h, # type: ignore
651
+ force_not_quantize=True,
652
+ ).sample.clip(0, 1) # type: ignore
653
+ self.vqvae.config.lookup_from_codebook = True # type: ignore
654
+
655
+ output = self.image_processor.postprocess(output, output_type)
656
+
657
+ if needs_upcasting:
658
+ self.vqvae.half()
659
+
660
+ return output
661
+
662
+ def _subs_parameterization(self, logits, latents):
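+ """
+ SUBS-style parameterization of log p(x0 | x_t) (comment added for readability): normalize the
+ logits, then force already-unmasked positions to place probability one on their current token
+ (carry-over unmasking), so only masked positions keep a non-trivial distribution.
+ """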
663
+ B, H, W, C = logits.shape
664
+ logits = logits.view(B, H * W, C)
665
+ assert latents.shape == (B, H, W)
666
+ latents = latents.view(B, H * W)
667
+
668
+ logits = logits - torch.logsumexp(logits, dim=-1,
669
+ keepdim=True)
670
+ unmasked_indices = (latents != self.mask_index)
671
+ logits[unmasked_indices] = -torch.inf
672
+ logits[unmasked_indices, latents[unmasked_indices]] = 0
673
+
674
+ logits = logits.view(B, H, W, C)
675
+ return logits
src/smc/resampling.py ADDED
@@ -0,0 +1,149 @@
1
+ from typing import Callable, Tuple
2
+
3
+ import torch
4
+
5
+
6
+ def compute_ess(w, dim=-1):
7
+ ess = (w.sum(dim=dim))**2 / torch.sum(w**2, dim=dim)
8
+ return ess
9
+
10
+ def compute_ess_from_log_w(log_w, dim=-1):
11
+ return compute_ess(normalize_weights(log_w, dim=dim), dim=dim)
12
+
13
+ def normalize_weights(log_weights, dim=-1):
14
+ return torch.exp(normalize_log_weights(log_weights, dim=dim))
15
+
16
+ def normalize_log_weights(log_weights, dim=-1):
17
+ log_weights = log_weights - log_weights.max(dim=dim, keepdims=True)[0]
18
+ log_weights = log_weights - torch.logsumexp(log_weights, dim=dim, keepdims=True) # type: ignore
19
+ return log_weights
20
+
21
+ def stratified_resample(log_weights: torch.Tensor):
22
+ N = log_weights.shape[0]
23
+ weights = normalize_weights(log_weights)
24
+ cdf = torch.cumsum(weights, dim=0)
25
+
26
+ # Stratified uniform samples
27
+ u = (torch.arange(N, dtype=torch.float32, device=log_weights.device) + torch.rand(N, device=log_weights.device)) / N
28
+
29
+ indices = torch.searchsorted(cdf, u, right=True)
30
+ return indices
31
+
32
+ def systematic_resample(log_weights: torch.Tensor, normalized=True):
33
+ N = log_weights.shape[0]
34
+ weights = normalize_weights(log_weights)
35
+ cdf = torch.cumsum(weights, dim=0)
36
+
37
+ # Systematic uniform samples
38
+ u0 = torch.rand(1, device=log_weights.device) / N
39
+ u = u0 + torch.arange(N, dtype=torch.float32, device=log_weights.device) / N
40
+
41
+ indices = torch.searchsorted(cdf, u, right=True)
42
+ return indices
43
+
44
+ def multinomial_resample(log_weights: torch.Tensor, normalized=True):
45
+ N = log_weights.shape[0]
46
+ weights = normalize_weights(log_weights)
47
+ resampled_indices = torch.multinomial(weights, N, replacement=True)
48
+ return resampled_indices
49
+
50
+ def partial_resample(log_weights: torch.Tensor,
51
+ resample_fn: Callable[[torch.Tensor], torch.Tensor],
52
+ M: int) -> Tuple[torch.Tensor, torch.Tensor]:
53
+ """
54
+ Perform partial resampling on a set of particles using PyTorch.
55
+
56
+ Args:
57
+ log_weights (torch.Tensor): 1D tensor of shape (K,) containing log-weights.
58
+ resample_fn (callable): function that takes a 1D tensor of log-weights and
59
+ returns a tensor of sampled indices with the same length as its input.
60
+ M (int): total number of particles to resample.
61
+
62
+ Returns:
63
+ new_indices (torch.Tensor): 1D tensor of shape (K,) mapping each output slot to
64
+ an original particle index.
65
+ new_log_weights (torch.Tensor): 1D tensor of shape (K,) of updated log-weights.
66
+ """
67
+ K = log_weights.numel()
68
+
69
+ # Convert log-weights to normalized weights
70
+ log_weights = normalize_log_weights(log_weights)
71
+ weights = torch.exp(log_weights)
72
+
73
+ # Determine how many high and low weights to resample
74
+ M_hi = 1 # M // 2
75
+ M_lo = M - M_hi
76
+
77
+ # Get indices of highest and lowest weights
78
+ _, hi_idx = torch.topk(weights, M_hi, largest=True)
79
+ _, lo_idx = torch.topk(weights, M_lo, largest=False)
80
+ I = torch.cat([hi_idx, lo_idx]) # indices selected for resampling
81
+
82
+ # Perform multinomial resampling only on selected subset
83
+ # resample_fn expects log-weights of the subset
84
+ subset_logw = log_weights[I]
85
+ local_sampled = resample_fn(subset_logw) # indices in [0, len(I))
86
+ # Map back to original indices
87
+ sampled = I[local_sampled]
88
+
89
+ # Build new index mapping: default to identity (retain original)
90
+ new_indices = torch.arange(K, device=log_weights.device)
91
+ new_indices[I] = sampled
92
+
93
+ # Compute new uniform weight for resampled particles
94
+ total_I_weight = weights[I].sum()
95
+ uniform_weight = total_I_weight / M
96
+
97
+ # Prepare new log-weights
98
+ new_log_weight = torch.empty_like(log_weights)
99
+ # For non-resampled, keep original log-weights
100
+ mask = torch.ones(K, dtype=torch.bool, device=log_weights.device)
101
+ mask[I] = False
102
+ new_log_weight[mask] = log_weights[mask]
103
+ # For resampled, assign uniform log-weight
104
+ new_log_weight[I] = torch.log(uniform_weight)
105
+
106
+ return new_indices, new_log_weight
107
+
108
+
109
+ def resample(log_w, ess_threshold=None, partial=False):
110
+ """
111
+ Resample the log weights and return the indices of the resampled particles.
112
+
113
+ Parameters
114
+ ----------
115
+ log_w : array_like
116
+ The log weights of the particles.
117
+ ess_threshold : float, optional
118
+ The effective sample size (ESS) threshold. If the ESS is below this
119
+ threshold, resampling is performed. If None, resampling is always
120
+ performed.
121
+ partial : bool, optional
122
+ If True, only a subset of particles (the highest- and lowest-weight ones,
123
+ via `partial_resample`) is resampled. If False, all particles are resampled.
124
+
125
+ Returns
126
+ -------
127
+ tuple
128
+ `(resample_indices, is_resampled, log_w)`: the indices of the resampled
+ particles, a bool indicating whether resampling was actually performed,
+ and the updated log weights.
129
+ """
130
+ base_sampling_fn = systematic_resample
131
+ N = log_w.size(0)
132
+ ess = compute_ess_from_log_w(log_w)
133
+ if ess_threshold is not None and ess >= ess_threshold * N:
134
+ # Skip resampling as ess is not below the threshold
135
+ return (
136
+ torch.arange(N, device=log_w.device),
137
+ False,
138
+ log_w
139
+ )
140
+ if partial:
141
+ resample_indices, log_w = partial_resample(log_w, base_sampling_fn, N // 2)
142
+ else:
143
+ resample_indices = base_sampling_fn(log_w)
144
+ log_w = torch.zeros_like(log_w)
145
+ return (
146
+ resample_indices,
147
+ True,
148
+ log_w
149
+ )
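+
+
+ # Usage sketch (an assumption about how callers bind the options; not shown in this commit):
+ # the pipeline expects `resample_fn(log_w)` to return (indices, is_resampled, new_log_w),
+ # which `resample` provides once its keyword options are fixed, e.g.
+ #
+ # from functools import partial
+ # resample_fn = partial(resample, ess_threshold=0.5)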
src/smc/scheduler.py ADDED
@@ -0,0 +1,368 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple, Union, List
4
+
5
+ import math
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from src.meissonic.scheduler import mask_by_random_topk
11
+
12
+
13
+ @dataclass
14
+ class SchedulerStepOutput:
15
+ new_latents: torch.Tensor
16
+
17
+
18
+ @dataclass
19
+ class SchedulerApproxGuidanceOutput:
20
+ new_latents: torch.Tensor
21
+ log_prob_proposal: torch.Tensor
22
+ log_prob_diffusion: torch.Tensor
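+ # Note (added for readability): in the SMC pipeline these two terms enter the incremental
+ # particle weight as (log_prob_diffusion - log_prob_proposal) + (log_twist - log_twist_prev).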
23
+
24
+
25
+ class BaseScheduler(ABC):
26
+ @abstractmethod
27
+ def step(
28
+ self,
29
+ latents: torch.Tensor,
30
+ step: int,
31
+ logits: torch.Tensor,
32
+ ) -> SchedulerStepOutput:
33
+ pass
34
+
35
+ @abstractmethod
36
+ def set_timesteps(self, num_inference_steps: int):
37
+ pass
38
+
39
+ @abstractmethod
40
+ def step_with_approx_guidance(
41
+ self,
42
+ latents: torch.Tensor,
43
+ step: int,
44
+ logits: torch.Tensor,
45
+ approx_guidance: torch.Tensor,
46
+ ) -> SchedulerApproxGuidanceOutput:
47
+ pass
48
+
49
+
50
+ def sum_masked_logits(
51
+ logits: torch.Tensor,
52
+ preds: torch.Tensor,
53
+ mask: torch.Tensor
54
+ ) -> torch.Tensor:
55
+ """
56
+ Sum logits at `preds` indices, masked by `mask`, handling invalid `preds`.
57
+
58
+ Args:
59
+ logits: Tensor of shape (B, H, W, C) - logits over C classes.
60
+ preds: Tensor of shape (B, H, W) - predicted class indices.
61
+ mask: Tensor of shape (B, H, W) - binary mask to include positions.
62
+
63
+ Returns:
64
+ Tensor of shape (B,) - sum of selected logits per batch item.
65
+ """
66
+ B, H, W, C = logits.shape
67
+ # Mark preds outside [0, max(preds[mask])] (e.g. mask/pad indices whose logits may be -inf) as invalid
68
+ valid = (preds >= 0) & (preds <= preds[mask].max())
69
+ # Replace invalid preds with a dummy index (0), which we will mask later
70
+ safe_preds = preds.masked_fill(~valid, 0)
71
+ # Gather logits at predicted indices
72
+ selected = torch.gather(logits, dim=3, index=safe_preds.unsqueeze(-1)).squeeze(-1)
73
+ # Zero out contributions from invalid preds and masked positions
74
+ selected = selected * valid * mask
75
+ # Sum over H, W dimension
76
+ return selected.sum(dim=(1, 2))
77
+
78
+ def log1mexp(x: torch.Tensor) -> torch.Tensor:
79
+ """
80
+ Numerically stable computation of log(1 - exp(x)) for x < 0.
81
+ """
82
+ return torch.where(
83
+ x > -1,
84
+ torch.log(-torch.expm1(x)),
85
+ torch.log1p(-torch.exp(x)),
86
+ )
87
+
88
+
89
+ class MeissonicScheduler(BaseScheduler):
90
+ def __init__(self,
91
+ mask_token_id: int,
92
+ masking_schedule: str = "cosine",
93
+ ):
94
+ self.mask_token_id = mask_token_id
95
+ self.masking_schedule = masking_schedule
96
+
97
+ def set_timesteps(self, num_inference_steps: int, temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), device='cuda'):
98
+ self.num_inference_steps = num_inference_steps
99
+ self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
100
+ if isinstance(temperature, (tuple, list)):
101
+ self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
102
+ else:
103
+ self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)
104
+
105
+ def step(
106
+ self,
107
+ latents: torch.Tensor,
108
+ step: int,
109
+ logits: torch.Tensor,
110
+ ) -> SchedulerStepOutput:
111
+ batch_size, height, width, vocab_size = logits.shape
112
+ sample = latents.reshape(batch_size, height * width)
113
+ model_output = logits.reshape(batch_size, height * width, vocab_size)
114
+
115
+ unknown_map = sample == self.mask_token_id
116
+
117
+ probs = model_output.softmax(dim=-1)
118
+
119
+ device = probs.device
120
+ probs_ = probs
121
+ if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
122
+ probs_ = probs_.float() # multinomial is not implemented for cpu half precision
123
+ probs_ = probs_.reshape(-1, probs.size(-1))
124
+ pred_original_sample = torch.multinomial(probs_, 1).to(device=device)
125
+ pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
126
+ pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)
127
+
128
+ timestep = self.num_inference_steps - 1 - step
129
+ if timestep == 0:
130
+ prev_sample = pred_original_sample
131
+ else:
132
+ seq_len = sample.shape[1]
133
+ step_idx = (self.timesteps == timestep).nonzero()
134
+ ratio = (step_idx + 1) / len(self.timesteps)
135
+
136
+ if self.masking_schedule == "cosine":
137
+ mask_ratio = torch.cos(ratio * math.pi / 2)
138
+ elif self.masking_schedule == "linear":
139
+ mask_ratio = 1 - ratio
140
+ else:
141
+ raise ValueError(f"unknown masking schedule {self.masking_schedule}")
142
+
143
+ mask_len = (seq_len * mask_ratio).floor()
144
+ # do not mask more than amount previously masked
145
+ mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
146
+ # mask at least one
147
+ mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)
148
+
149
+ selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
150
+ # Ignores the tokens given in the input by overwriting their confidence.
151
+ selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
152
+
153
+ masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx].item())
154
+
155
+ # Masks tokens with lower confidence.
156
+ prev_sample = torch.where(masking, self.mask_token_id, pred_original_sample)
157
+
158
+ print("Unmasked:", (prev_sample != self.mask_token_id).sum(dim=1))
159
+ prev_sample = prev_sample.reshape(batch_size, height, width)
160
+ pred_original_sample = pred_original_sample.reshape(batch_size, height, width)
161
+
162
+ return SchedulerStepOutput(new_latents=prev_sample)
163
+
164
+
165
+ def step_with_approx_guidance(
166
+ self,
167
+ latents: torch.Tensor,
168
+ step: int,
169
+ logits: torch.Tensor,
170
+ approx_guidance: torch.Tensor,
171
+ ) -> SchedulerApproxGuidanceOutput:
172
+ proposal_logits = logits + approx_guidance
173
+ sched_out = self.step(latents, step, proposal_logits)
174
+ new_latents = sched_out.new_latents
175
+
176
+ newly_filled_positions = (latents != new_latents)
177
+ print("Newly filled positions:", newly_filled_positions.sum(dim=(1, 2)))
178
+
179
+ log_prob_proposal = sum_masked_logits(
180
+ logits=proposal_logits.log_softmax(dim=-1),
181
+ preds=new_latents,
182
+ mask=newly_filled_positions,
183
+ )
184
+ log_prob_diffusion = sum_masked_logits(
185
+ logits=logits.log_softmax(dim=-1),
186
+ preds=new_latents,
187
+ mask=newly_filled_positions,
188
+ )
189
+ print("log prob proposal:", log_prob_proposal)
190
+ print("log prob diffusion:", log_prob_diffusion)
191
+ return SchedulerApproxGuidanceOutput(
192
+ new_latents,
193
+ log_prob_proposal,
194
+ log_prob_diffusion,
195
+ )
196
+
197
+
198
+ class ReMDMScheduler(BaseScheduler):
199
+ def __init__(
200
+ self,
201
+ schedule,
202
+ remask_strategy,
203
+ eta,
204
+ mask_token_id,
205
+ temperature=1.0,
206
+ ):
207
+ self.schedule = schedule
208
+ self.remask_strategy = remask_strategy
209
+ self.eta = eta
210
+ self.temperature = temperature
211
+ self.mask_token_id = mask_token_id
212
+
213
+ def set_timesteps(self, num_inference_steps: int):
214
+ self.num_inference_steps = num_inference_steps
215
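+ # alphas[t] is (approximately) the probability that a token is unmasked at noise level t:
+ # alphas[0] = 1 (fully decoded) and alphas[num_inference_steps] = 0 (fully masked).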
+ if self.schedule == "linear":
216
+ self.alphas = 1 - torch.linspace(0, 1, num_inference_steps + 1)
217
+ elif self.schedule == "cosine":
218
+ self.alphas = 1 - torch.cos((math.pi/2) * (1 - torch.linspace(0, 1, num_inference_steps + 1)))
219
+ else:
220
+ raise ValueError(f"unknown masking schedule {self.schedule}")
221
+
222
+ def step(
223
+ self,
224
+ latents: torch.Tensor,
225
+ step: int,
226
+ logits: torch.Tensor,
227
+ ) -> SchedulerStepOutput:
228
+ B, H, W, C = logits.shape
229
+ assert latents.shape == (B, H, W)
230
+
231
+ latents = latents.reshape(B, H*W)
232
+ logits = logits.reshape(B, H*W, C)
233
+
234
+ t = self.num_inference_steps - step
235
+ s = t - 1
236
+
237
+ alpha_t = self.alphas[t]
238
+ alpha_s = self.alphas[s]
239
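+ # sigma_t is the per-step probability of re-masking an already-decoded token; "max_cap" clamps the
+ # maximum admissible value at eta, "rescale" multiplies it by eta. For example, alpha_s = 0.6 and
+ # alpha_t = 0.5 give sigma_t_max = min(1, 0.4 / 0.5) = 0.8, so eta = 0.1 yields sigma_t = 0.1 ("max_cap")
+ # or 0.08 ("rescale").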
+ sigma_t_max = torch.clamp_max((1 - alpha_s) / alpha_t, 1.0)
240
+ if self.remask_strategy == "max_cap":
241
+ sigma_t = torch.clamp_max(sigma_t_max, self.eta)
242
+ elif self.remask_strategy == "rescale":
243
+ sigma_t = sigma_t_max * self.eta
244
+ else:
245
+ raise ValueError(f"unknown remask strategy {self.remask_strategy}")
246
+
247
+ # z_t != m
248
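+ # decoded positions keep their token with probability (1 - sigma_t) and are re-masked with probability sigma_t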
+ x_theta = F.one_hot(latents, num_classes=C).float()
249
+ logits_z_t_neq_m = (
250
+ torch.log(x_theta) +
251
+ torch.log(1 - sigma_t)
252
+ )
253
+ logits_z_t_neq_m[..., self.mask_token_id] = (
254
+ torch.log(sigma_t)
255
+ )
256
+
257
+ # z_t = m
258
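+ # masked positions are decoded from the (temperature-scaled) model distribution with probability
+ # (alpha_s - (1 - sigma_t) * alpha_t) / (1 - alpha_t) and otherwise remain masked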
+ log_x_theta = (logits / self.temperature).log_softmax(dim=-1)
259
+ logits_z_t_eq_m = (
260
+ log_x_theta +
261
+ torch.log((alpha_s - (1 - sigma_t) * alpha_t) / (1 - alpha_t))
262
+ )
263
+ logits_z_t_eq_m[..., self.mask_token_id] = (
264
+ torch.log((1 - alpha_s - sigma_t * alpha_t) / (1 - alpha_t))
265
+ )
266
+
267
+ z_t_neq_m = (latents != self.mask_token_id)
268
+ p_theta_logits = torch.where(
269
+ z_t_neq_m.unsqueeze(-1).expand(-1, -1, C),
270
+ logits_z_t_neq_m,
271
+ logits_z_t_eq_m,
272
+ )
273
+ assert torch.allclose(torch.exp(p_theta_logits).sum(dim=-1), torch.ones(B, H*W, device=logits.device)), (torch.exp(p_theta_logits).sum(dim=-1) - torch.ones(B, H*W, device=logits.device)).abs().max()
274
+ diffusion_dist = torch.distributions.Categorical(logits=p_theta_logits) # type: ignore
275
+ new_latents = diffusion_dist.sample()
276
+ print("Unmasked:", (new_latents != self.mask_token_id).sum(dim=1))
277
+ return SchedulerStepOutput(new_latents.reshape(B, H, W))
278
+
279
+ def step_with_approx_guidance(
280
+ self,
281
+ latents: torch.Tensor,
282
+ step: int,
283
+ logits: torch.Tensor,
284
+ approx_guidance: torch.Tensor,
285
+ ) -> SchedulerApproxGuidanceOutput:
286
+ B, H, W, C = logits.shape
287
+ assert latents.shape == (B, H, W)
288
+ assert approx_guidance.shape == (B, H, W, C)
289
+
290
+ latents = latents.reshape(B, H*W)
291
+ logits = logits.reshape(B, H*W, C)
292
+ approx_guidance = approx_guidance.reshape(B, H*W, C)
293
+
294
+ t = self.num_inference_steps - step
295
+ s = t - 1
296
+
297
+ alpha_t = self.alphas[t]
298
+ alpha_s = self.alphas[s]
299
+ sigma_t_max = torch.clamp_max((1 - alpha_s) / alpha_t, 1.0)
300
+ if self.remask_strategy == "max_cap":
301
+ sigma_t = torch.clamp_max(sigma_t_max, self.eta)
302
+ elif self.remask_strategy == "rescale":
303
+ sigma_t = sigma_t_max * self.eta
304
+ else:
305
+ raise ValueError(f"unknown remask strategy {self.remask_strategy}")
306
+
307
+ # z_t != m
308
+ x_theta = F.one_hot(latents, num_classes=C).float()
309
+ logits_z_t_neq_m = (
310
+ torch.log(x_theta) +
311
+ torch.log(1 - sigma_t)
312
+ )
313
+ logits_z_t_neq_m[..., self.mask_token_id] = (
314
+ torch.log(sigma_t)
315
+ )
316
+
317
+ # z_t = m
318
+ log_x_theta = (logits / self.temperature).log_softmax(dim=-1)
319
+ logits_z_t_eq_m = (
320
+ log_x_theta +
321
+ torch.log((alpha_s - (1 - sigma_t) * alpha_t) / (1 - alpha_t))
322
+ )
323
+ logits_z_t_eq_m[..., self.mask_token_id] = (
324
+ torch.log((1 - alpha_s - sigma_t * alpha_t) / (1 - alpha_t))
325
+ )
326
+
327
+ z_t_neq_m = (latents != self.mask_token_id)
328
+ p_theta_logits = torch.where(
329
+ z_t_neq_m.unsqueeze(-1).expand(-1, -1, C),
330
+ logits_z_t_neq_m,
331
+ logits_z_t_eq_m,
332
+ )
333
+ assert torch.allclose(torch.exp(p_theta_logits).sum(dim=-1), torch.ones(B, H*W, device=logits.device))
334
+
335
+ proposal_logits = (p_theta_logits + approx_guidance).log_softmax(dim=-1)
336
+ assert torch.allclose(torch.exp(proposal_logits).sum(dim=-1), torch.ones(B, H*W, device=logits.device))
337
+
338
+ # modify proposal logits to have the same mask schedule as the original logits
339
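+ # Shift the non-mask logits so their total unmasking mass matches that of the unguided distribution,
+ # renormalize any position whose non-mask mass would exceed 1, and recompute the mask logit with
+ # log1mexp so every position remains a valid categorical distribution.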
+ proposal_logits[..., :self.mask_token_id] += (
340
+ torch.logsumexp(p_theta_logits[..., :self.mask_token_id], dim=(1, 2), keepdim=True) -
341
+ torch.logsumexp(proposal_logits[..., :self.mask_token_id], dim=(1, 2), keepdim=True)
342
+ )
343
+ proposal_logits[..., :self.mask_token_id] = torch.where(
344
+ proposal_logits[..., :self.mask_token_id].logsumexp(dim=-1, keepdim=True) >= 0,
345
+ proposal_logits[..., :self.mask_token_id].log_softmax(dim=-1),
346
+ proposal_logits[..., :self.mask_token_id]
347
+ )
348
+ assert not (proposal_logits[..., :self.mask_token_id].logsumexp(dim=-1) > 1e-6).any(), proposal_logits[..., :self.mask_token_id].logsumexp(dim=-1).max()
349
+ proposal_logits[..., self.mask_token_id] = (
350
+ log1mexp(proposal_logits[..., :self.mask_token_id].logsumexp(dim=-1).clamp_max(0))
351
+ )
352
+ assert torch.allclose(torch.exp(proposal_logits).sum(dim=-1), torch.ones(B, H*W, device=logits.device)), (torch.exp(proposal_logits).sum(dim=-1) - torch.ones(B, H*W, device=logits.device)).abs().max()
353
+ # end of mask-schedule matching
354
+
355
+ proposal_dist = torch.distributions.Categorical(logits=proposal_logits) # type: ignore
356
+ diffusion_dist = torch.distributions.Categorical(logits=p_theta_logits) # type: ignore
357
+
358
+ new_latents = proposal_dist.sample()
359
+
360
+ log_prob_proposal = proposal_dist.log_prob(new_latents).sum(dim=1)
361
+ log_prob_diffusion = diffusion_dist.log_prob(new_latents).sum(dim=1)
362
+
363
+ print("Unmasked:", (new_latents != self.mask_token_id).sum(dim=1))
364
+ return SchedulerApproxGuidanceOutput(
365
+ new_latents.reshape(B, H, W),
366
+ log_prob_proposal,
367
+ log_prob_diffusion,
368
+ )
src/smc/transformer.py ADDED
@@ -0,0 +1,1119 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team, The InstantX Team and The MeissonFlow Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
+ from diffusers.models.attention import FeedForward, BasicTransformerBlock, SkipFFTransformerBlock
26
+ from diffusers.models.attention_processor import (
27
+ Attention,
28
+ AttentionProcessor,
29
+ FluxAttnProcessor2_0,
30
+ # FusedFluxAttnProcessor2_0,
31
+ )
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, GlobalResponseNorm, RMSNorm
34
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
35
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
36
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings,TimestepEmbedding, get_timestep_embedding #,FluxPosEmbed
37
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
38
+ from diffusers.models.resnet import Downsample2D, Upsample2D
39
+
40
+ from typing import List
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+
46
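+ # The rotary positional embedding helpers below appear to be inlined from diffusers.models.embeddings
+ # so that FluxPosEmbed can be defined locally in this file.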
+ def get_3d_rotary_pos_embed(
47
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
48
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
49
+ """
50
+ RoPE for video tokens with 3D structure.
51
+
52
+ Args:
53
+ embed_dim: (`int`):
54
+ The embedding dimension size, corresponding to hidden_size_head.
55
+ crops_coords (`Tuple[int]`):
56
+ The top-left and bottom-right coordinates of the crop.
57
+ grid_size (`Tuple[int]`):
58
+ The grid size of the spatial positional embedding (height, width).
59
+ temporal_size (`int`):
60
+ The size of the temporal dimension.
61
+ theta (`float`):
62
+ Scaling factor for frequency computation.
63
+ use_real (`bool`):
64
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
65
+
66
+ Returns:
67
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
68
+ """
69
+ start, stop = crops_coords
70
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
71
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
72
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
73
+
74
+ # Compute dimensions for each axis
75
+ dim_t = embed_dim // 4
76
+ dim_h = embed_dim // 8 * 3
77
+ dim_w = embed_dim // 8 * 3
78
+
79
+ # Temporal frequencies
80
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
81
+ grid_t = torch.from_numpy(grid_t).float()
82
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
83
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
84
+
85
+ # Spatial frequencies for height and width
86
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
87
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
88
+ grid_h = torch.from_numpy(grid_h).float()
89
+ grid_w = torch.from_numpy(grid_w).float()
90
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
91
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
92
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
93
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
94
+
95
+ # Broadcast and concatenate tensors along specified dimension
96
+ def broadcast(tensors, dim=-1):
97
+ num_tensors = len(tensors)
98
+ shape_lens = {len(t.shape) for t in tensors}
99
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
100
+ shape_len = list(shape_lens)[0]
101
+ dim = (dim + shape_len) if dim < 0 else dim
102
+ dims = list(zip(*(list(t.shape) for t in tensors)))
103
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
104
+ assert all(
105
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
106
+ ), "invalid dimensions for broadcastable concatenation"
107
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
108
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
109
+ expanded_dims.insert(dim, (dim, dims[dim]))
110
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
111
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
112
+ return torch.cat(tensors, dim=dim)
113
+
114
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
115
+
116
+ t, h, w, d = freqs.shape
117
+ freqs = freqs.view(t * h * w, d)
118
+
119
+ # Generate sine and cosine components
120
+ sin = freqs.sin()
121
+ cos = freqs.cos()
122
+
123
+ if use_real:
124
+ return cos, sin
125
+ else:
126
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
127
+ return freqs_cis
128
+
129
+
130
+ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
131
+ """
132
+ RoPE for image tokens with 2d structure.
133
+
134
+ Args:
135
+ embed_dim: (`int`):
136
+ The embedding dimension size
137
+ crops_coords (`Tuple[int]`)
138
+ The top-left and bottom-right coordinates of the crop.
139
+ grid_size (`Tuple[int]`):
140
+ The grid size of the positional embedding.
141
+ use_real (`bool`):
142
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
143
+
144
+ Returns:
145
+ `torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)`.
146
+ """
147
+ start, stop = crops_coords
148
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
149
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
150
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
151
+ grid = np.stack(grid, axis=0) # [2, W, H]
152
+
153
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
154
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
155
+ return pos_embed
156
+
157
+
158
+ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
159
+ assert embed_dim % 4 == 0
160
+
161
+ # use half of dimensions to encode grid_h
162
+ emb_h = get_1d_rotary_pos_embed(
163
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
164
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
165
+ emb_w = get_1d_rotary_pos_embed(
166
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
167
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
168
+
169
+ if use_real:
170
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
171
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
172
+ return cos, sin
173
+ else:
174
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
175
+ return emb
176
+
177
+
178
+ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
179
+ assert embed_dim % 4 == 0
180
+
181
+ emb_h = get_1d_rotary_pos_embed(
182
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
183
+ ) # (H, D/4)
184
+ emb_w = get_1d_rotary_pos_embed(
185
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
186
+ ) # (W, D/4)
187
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
188
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
189
+
190
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
191
+ return emb
192
+
193
+
194
+ def get_1d_rotary_pos_embed(
195
+ dim: int,
196
+ pos: Union[np.ndarray, int],
197
+ theta: float = 10000.0,
198
+ use_real=False,
199
+ linear_factor=1.0,
200
+ ntk_factor=1.0,
201
+ repeat_interleave_real=True,
202
+ freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
203
+ ):
204
+ """
205
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
206
+
207
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
208
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
209
+ data type.
210
+
211
+ Args:
212
+ dim (`int`): Dimension of the frequency tensor.
213
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
214
+ theta (`float`, *optional*, defaults to 10000.0):
215
+ Scaling factor for frequency computation. Defaults to 10000.0.
216
+ use_real (`bool`, *optional*):
217
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
218
+ linear_factor (`float`, *optional*, defaults to 1.0):
219
+ Scaling factor for the context extrapolation. Defaults to 1.0.
220
+ ntk_factor (`float`, *optional*, defaults to 1.0):
221
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
222
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
223
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
224
+ Otherwise, they are concatenated with themselves.
225
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
226
+ the dtype of the frequency tensor.
227
+ Returns:
228
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
229
+ """
230
+ assert dim % 2 == 0
231
+
232
+ if isinstance(pos, int):
233
+ pos = np.arange(pos)
234
+ theta = theta * ntk_factor
235
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
236
+ t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
237
+ freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
238
+ if use_real and repeat_interleave_real:
239
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
240
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
241
+ return freqs_cos, freqs_sin
242
+ elif use_real:
243
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
244
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
245
+ return freqs_cos, freqs_sin
246
+ else:
247
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
248
+ return freqs_cis
249
+
250
+
251
+ class FluxPosEmbed(nn.Module):
252
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
253
+ def __init__(self, theta: int, axes_dim: List[int]):
254
+ super().__init__()
255
+ self.theta = theta
256
+ self.axes_dim = axes_dim
257
+
258
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
259
+ n_axes = ids.shape[-1]
260
+ cos_out = []
261
+ sin_out = []
262
+ pos = ids.squeeze().float().cpu().numpy()
263
+ is_mps = ids.device.type == "mps"
264
+ freqs_dtype = torch.float32 if is_mps else torch.float64
265
+ for i in range(n_axes):
266
+ cos, sin = get_1d_rotary_pos_embed(
267
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
268
+ )
269
+ cos_out.append(cos)
270
+ sin_out.append(sin)
271
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
272
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
273
+ return freqs_cos, freqs_sin
274
+
275
+
276
+
277
+ class FusedFluxAttnProcessor2_0:
278
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
279
+
280
+ def __init__(self):
281
+ if not hasattr(F, "scaled_dot_product_attention"):
282
+ raise ImportError(
283
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
284
+ )
285
+
286
+ def __call__(
287
+ self,
288
+ attn: Attention,
289
+ hidden_states: torch.FloatTensor,
290
+ encoder_hidden_states: torch.FloatTensor = None,
291
+ attention_mask: Optional[torch.FloatTensor] = None,
292
+ image_rotary_emb: Optional[torch.Tensor] = None,
293
+ ) -> torch.FloatTensor:
294
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
295
+
296
+ # `sample` projections.
297
+ qkv = attn.to_qkv(hidden_states)
298
+ split_size = qkv.shape[-1] // 3
299
+ query, key, value = torch.split(qkv, split_size, dim=-1)
300
+
301
+ inner_dim = key.shape[-1]
302
+ head_dim = inner_dim // attn.heads
303
+
304
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
305
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
306
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
307
+
308
+ if attn.norm_q is not None:
309
+ query = attn.norm_q(query)
310
+ if attn.norm_k is not None:
311
+ key = attn.norm_k(key)
312
+
313
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
314
+ # `context` projections.
315
+ if encoder_hidden_states is not None:
316
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
317
+ split_size = encoder_qkv.shape[-1] // 3
318
+ (
319
+ encoder_hidden_states_query_proj,
320
+ encoder_hidden_states_key_proj,
321
+ encoder_hidden_states_value_proj,
322
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
323
+
324
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
325
+ batch_size, -1, attn.heads, head_dim
326
+ ).transpose(1, 2)
327
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
328
+ batch_size, -1, attn.heads, head_dim
329
+ ).transpose(1, 2)
330
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
331
+ batch_size, -1, attn.heads, head_dim
332
+ ).transpose(1, 2)
333
+
334
+ if attn.norm_added_q is not None:
335
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
336
+ if attn.norm_added_k is not None:
337
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
338
+
339
+ # attention
340
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
341
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
342
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
343
+
344
+ if image_rotary_emb is not None:
345
+ from diffusers.models.embeddings import apply_rotary_emb
346
+
347
+ query = apply_rotary_emb(query, image_rotary_emb)
348
+ key = apply_rotary_emb(key, image_rotary_emb)
349
+
350
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
351
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
352
+ hidden_states = hidden_states.to(query.dtype)
353
+
354
+ if encoder_hidden_states is not None:
355
+ encoder_hidden_states, hidden_states = (
356
+ hidden_states[:, : encoder_hidden_states.shape[1]],
357
+ hidden_states[:, encoder_hidden_states.shape[1] :],
358
+ )
359
+
360
+ # linear proj
361
+ hidden_states = attn.to_out[0](hidden_states)
362
+ # dropout
363
+ hidden_states = attn.to_out[1](hidden_states)
364
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
365
+
366
+ return hidden_states, encoder_hidden_states
367
+ else:
368
+ return hidden_states
369
+
370
+
371
+
372
+ @maybe_allow_in_graph
373
+ class SingleTransformerBlock(nn.Module):
374
+ r"""
375
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
376
+
377
+ Reference: https://arxiv.org/abs/2403.03206
378
+
379
+ Parameters:
380
+ dim (`int`): The number of channels in the input and output.
381
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
382
+ attention_head_dim (`int`): The number of channels in each head.
383
+ mlp_ratio (`float`, defaults to 4.0): Ratio of the hidden dimension of the MLP in this block
384
+ to `dim`.
385
+ """
386
+
387
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
388
+ super().__init__()
389
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
390
+
391
+ self.norm = AdaLayerNormZeroSingle(dim)
392
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
393
+ self.act_mlp = nn.GELU(approximate="tanh")
394
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
395
+
396
+ processor = FluxAttnProcessor2_0()
397
+ self.attn = Attention(
398
+ query_dim=dim,
399
+ cross_attention_dim=None,
400
+ dim_head=attention_head_dim,
401
+ heads=num_attention_heads,
402
+ out_dim=dim,
403
+ bias=True,
404
+ processor=processor,
405
+ qk_norm="rms_norm",
406
+ eps=1e-6,
407
+ pre_only=True,
408
+ )
409
+
410
+ def forward(
411
+ self,
412
+ hidden_states: torch.FloatTensor,
413
+ temb: torch.FloatTensor,
414
+ image_rotary_emb=None,
415
+ ):
416
+ residual = hidden_states
417
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
418
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
419
+
420
+ attn_output = self.attn(
421
+ hidden_states=norm_hidden_states,
422
+ image_rotary_emb=image_rotary_emb,
423
+ )
424
+
425
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
426
+ gate = gate.unsqueeze(1)
427
+ hidden_states = gate * self.proj_out(hidden_states)
428
+ hidden_states = residual + hidden_states
429
+ if hidden_states.dtype == torch.float16:
430
+ hidden_states = hidden_states.clip(-65504, 65504)
431
+
432
+ return hidden_states
433
+
434
+ @maybe_allow_in_graph
435
+ class TransformerBlock(nn.Module):
436
+ r"""
437
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
438
+
439
+ Reference: https://arxiv.org/abs/2403.03206
440
+
441
+ Parameters:
442
+ dim (`int`): The number of channels in the input and output.
443
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
444
+ attention_head_dim (`int`): The number of channels in each head.
445
+ qk_norm (`str`, defaults to `"rms_norm"`): Normalization applied to the query and key projections.
446
+ eps (`float`, defaults to 1e-6): Epsilon used by the query/key normalization layers.
447
+ """
448
+
449
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
450
+ super().__init__()
451
+
452
+ self.norm1 = AdaLayerNormZero(dim)
453
+
454
+ self.norm1_context = AdaLayerNormZero(dim)
455
+
456
+ if hasattr(F, "scaled_dot_product_attention"):
457
+ processor = FluxAttnProcessor2_0()
458
+ else:
459
+ raise ValueError(
460
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
461
+ )
462
+ self.attn = Attention(
463
+ query_dim=dim,
464
+ cross_attention_dim=None,
465
+ added_kv_proj_dim=dim,
466
+ dim_head=attention_head_dim,
467
+ heads=num_attention_heads,
468
+ out_dim=dim,
469
+ context_pre_only=False,
470
+ bias=True,
471
+ processor=processor,
472
+ qk_norm=qk_norm,
473
+ eps=eps,
474
+ )
475
+
476
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
477
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
478
+ # self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
479
+
480
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
481
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
482
+ # self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
483
+
484
+ # let chunk size default to None
485
+ self._chunk_size = None
486
+ self._chunk_dim = 0
487
+
488
+ def forward(
489
+ self,
490
+ hidden_states: torch.FloatTensor,
491
+ encoder_hidden_states: torch.FloatTensor,
492
+ temb: torch.FloatTensor,
493
+ image_rotary_emb=None,
494
+ ):
495
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
496
+
497
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
498
+ encoder_hidden_states, emb=temb
499
+ )
500
+ # Attention.
501
+ attn_output, context_attn_output = self.attn(
502
+ hidden_states=norm_hidden_states,
503
+ encoder_hidden_states=norm_encoder_hidden_states,
504
+ image_rotary_emb=image_rotary_emb,
505
+ )
506
+
507
+ # Process attention outputs for the `hidden_states`.
508
+ attn_output = gate_msa.unsqueeze(1) * attn_output
509
+ hidden_states = hidden_states + attn_output
510
+
511
+ norm_hidden_states = self.norm2(hidden_states)
512
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
513
+
514
+ ff_output = self.ff(norm_hidden_states)
515
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
516
+
517
+ hidden_states = hidden_states + ff_output
518
+
519
+ # Process attention outputs for the `encoder_hidden_states`.
520
+
521
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
522
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
523
+
524
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
525
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
526
+
527
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
528
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
529
+ if encoder_hidden_states.dtype == torch.float16:
530
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
531
+
532
+ return encoder_hidden_states, hidden_states
533
+
534
+
535
+ class UVit2DConvEmbed(nn.Module):
536
+ def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
537
+ super().__init__()
538
+ self.embeddings = nn.Embedding(vocab_size, in_channels)
539
+ self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
540
+ self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)
541
+
542
+ def forward(self, input_ids):
543
+ if input_ids.is_floating_point():
544
+ embeddings = input_ids @ self.embeddings.weight
545
+ else:
546
+ embeddings = self.embeddings(input_ids)
547
+ embeddings = self.layer_norm(embeddings)
548
+ embeddings = embeddings.permute(0, 3, 1, 2)
549
+ embeddings = self.conv(embeddings)
550
+ return embeddings
551
+
552
+ class ConvMlmLayer(nn.Module):
553
+ def __init__(
554
+ self,
555
+ block_out_channels: int,
556
+ in_channels: int,
557
+ use_bias: bool,
558
+ ln_elementwise_affine: bool,
559
+ layer_norm_eps: float,
560
+ codebook_size: int,
561
+ ):
562
+ super().__init__()
563
+ self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
564
+ self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
565
+ self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)
566
+
567
+ def forward(self, hidden_states):
568
+ hidden_states = self.conv1(hidden_states)
569
+ hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
570
+ logits = self.conv2(hidden_states)
571
+ return logits
572
+
573
+ class SwiGLU(nn.Module):
574
+ r"""
575
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
576
+ but uses SiLU / Swish instead of GeLU.
577
+
578
+ Parameters:
579
+ dim_in (`int`): The number of channels in the input.
580
+ dim_out (`int`): The number of channels in the output.
581
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
582
+ """
583
+
584
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
585
+ super().__init__()
586
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
587
+ self.activation = nn.SiLU()
588
+
589
+ def forward(self, hidden_states):
590
+ hidden_states = self.proj(hidden_states)
591
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
592
+ return hidden_states * self.activation(gate)
593
+
594
+ class ConvNextBlock(nn.Module):
595
+ def __init__(
596
+ self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
597
+ ):
598
+ super().__init__()
599
+ self.depthwise = nn.Conv2d(
600
+ channels,
601
+ channels,
602
+ kernel_size=3,
603
+ padding=1,
604
+ groups=channels,
605
+ bias=use_bias,
606
+ )
607
+ self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
608
+ self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
609
+ self.channelwise_act = nn.GELU()
610
+ self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
611
+ self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
612
+ self.channelwise_dropout = nn.Dropout(hidden_dropout)
613
+ self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)
614
+
615
+ def forward(self, x, cond_embeds):
616
+ x_res = x
617
+
618
+ x = self.depthwise(x)
619
+
620
+ x = x.permute(0, 2, 3, 1)
621
+ x = self.norm(x)
622
+
623
+ x = self.channelwise_linear_1(x)
624
+ x = self.channelwise_act(x)
625
+ x = self.channelwise_norm(x)
626
+ x = self.channelwise_linear_2(x)
627
+ x = self.channelwise_dropout(x)
628
+
629
+ x = x.permute(0, 3, 1, 2)
630
+
631
+ x = x + x_res
632
+
633
+ scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
634
+ x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
635
+
636
+ return x
637
+
638
+ class Simple_UVitBlock(nn.Module):
639
+ def __init__(
640
+ self,
641
+ channels,
642
+ ln_elementwise_affine,
643
+ layer_norm_eps,
644
+ use_bias,
645
+ downsample: bool,
646
+ upsample: bool,
647
+ ):
648
+ super().__init__()
649
+
650
+ if downsample:
651
+ self.downsample = Downsample2D(
652
+ channels,
653
+ use_conv=True,
654
+ padding=0,
655
+ name="Conv2d_0",
656
+ kernel_size=2,
657
+ norm_type="rms_norm",
658
+ eps=layer_norm_eps,
659
+ elementwise_affine=ln_elementwise_affine,
660
+ bias=use_bias,
661
+ )
662
+ else:
663
+ self.downsample = None
664
+
665
+ if upsample:
666
+ self.upsample = Upsample2D(
667
+ channels,
668
+ use_conv_transpose=True,
669
+ kernel_size=2,
670
+ padding=0,
671
+ name="conv",
672
+ norm_type="rms_norm",
673
+ eps=layer_norm_eps,
674
+ elementwise_affine=ln_elementwise_affine,
675
+ bias=use_bias,
676
+ interpolate=False,
677
+ )
678
+ else:
679
+ self.upsample = None
680
+
681
+ def forward(self, x):
682
+ # print("before,", x.shape)
683
+ if self.downsample is not None:
684
+ # print('downsample')
685
+ x = self.downsample(x)
686
+
687
+ if self.upsample is not None:
688
+ # print('upsample')
689
+ x = self.upsample(x)
690
+ # print("after,", x.shape)
691
+ return x
692
+
693
+ class Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
694
+ """
695
+ The Transformer model introduced in Flux.
696
+
697
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
698
+
699
+ Parameters:
700
+ patch_size (`int`): Patch size to turn the input data into small patches.
701
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
702
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
703
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
704
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
705
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
706
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
707
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
708
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
709
+ """
710
+
711
+ _supports_gradient_checkpointing = False #True
712
+ # Due to NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph.
713
+ # Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
714
+ _no_split_modules = ["TransformerBlock", "SingleTransformerBlock"]
715
+
716
+ @register_to_config
717
+ def __init__(
718
+ self,
719
+ patch_size: int = 1,
720
+ in_channels: int = 64,
721
+ num_layers: int = 19,
722
+ num_single_layers: int = 38,
723
+ attention_head_dim: int = 128,
724
+ num_attention_heads: int = 24,
725
+ joint_attention_dim: int = 4096,
726
+ pooled_projection_dim: int = 768,
727
+ guidance_embeds: bool = False, # unused in our implementation
728
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
729
+ vocab_size: int = 8256,
730
+ codebook_size: int = 8192,
731
+ downsample: bool = False,
732
+ upsample: bool = False,
733
+ ):
734
+ super().__init__()
735
+ self.out_channels = in_channels
736
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
737
+
738
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
739
+ text_time_guidance_cls = (
740
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
741
+ )
742
+ self.time_text_embed = text_time_guidance_cls(
743
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
744
+ )
745
+
746
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
747
+
748
+ self.transformer_blocks = nn.ModuleList(
749
+ [
750
+ TransformerBlock(
751
+ dim=self.inner_dim,
752
+ num_attention_heads=self.config.num_attention_heads,
753
+ attention_head_dim=self.config.attention_head_dim,
754
+ )
755
+ for i in range(self.config.num_layers)
756
+ ]
757
+ )
758
+
759
+ self.single_transformer_blocks = nn.ModuleList(
760
+ [
761
+ SingleTransformerBlock(
762
+ dim=self.inner_dim,
763
+ num_attention_heads=self.config.num_attention_heads,
764
+ attention_head_dim=self.config.attention_head_dim,
765
+ )
766
+ for i in range(self.config.num_single_layers)
767
+ ]
768
+ )
769
+
770
+
771
+ self.gradient_checkpointing = False
772
+
773
+ in_channels_embed = self.inner_dim
774
+ ln_elementwise_affine = True
775
+ layer_norm_eps = 1e-06
776
+ use_bias = False
777
+ micro_cond_embed_dim = 1280
778
+ self.embed = UVit2DConvEmbed(
779
+ in_channels_embed, self.inner_dim, self.config.vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
780
+ )
781
+ self.mlm_layer = ConvMlmLayer(
782
+ self.inner_dim, in_channels_embed, use_bias, ln_elementwise_affine, layer_norm_eps, self.config.codebook_size
783
+ )
784
+ self.cond_embed = TimestepEmbedding(
785
+ micro_cond_embed_dim + self.config.pooled_projection_dim, self.inner_dim, sample_proj_bias=use_bias
786
+ )
787
+ self.encoder_proj_layer_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
788
+ self.project_to_hidden_norm = RMSNorm(in_channels_embed, layer_norm_eps, ln_elementwise_affine)
789
+ self.project_to_hidden = nn.Linear(in_channels_embed, self.inner_dim, bias=use_bias)
790
+ self.project_from_hidden_norm = RMSNorm(self.inner_dim, layer_norm_eps, ln_elementwise_affine)
791
+ self.project_from_hidden = nn.Linear(self.inner_dim, in_channels_embed, bias=use_bias)
792
+
793
+ self.down_block = Simple_UVitBlock(
794
+ self.inner_dim,
795
+ ln_elementwise_affine,
796
+ layer_norm_eps,
797
+ use_bias,
798
+ downsample,
799
+ False,
800
+ )
801
+ self.up_block = Simple_UVitBlock(
802
+ self.inner_dim, #block_out_channels,
803
+ ln_elementwise_affine,
804
+ layer_norm_eps,
805
+ use_bias,
806
+ False,
807
+ upsample=upsample,
808
+ )
809
+
810
+ # self.fuse_qkv_projections()
811
+
812
+ @property
813
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
814
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
815
+ r"""
816
+ Returns:
817
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
818
+ indexed by its weight name.
819
+ """
820
+ # set recursively
821
+ processors = {}
822
+
823
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
824
+ if hasattr(module, "get_processor"):
825
+ processors[f"{name}.processor"] = module.get_processor()
826
+
827
+ for sub_name, child in module.named_children():
828
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
829
+
830
+ return processors
831
+
832
+ for name, module in self.named_children():
833
+ fn_recursive_add_processors(name, module, processors)
834
+
835
+ return processors
836
+
837
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
838
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
839
+ r"""
840
+ Sets the attention processor to use to compute attention.
841
+
842
+ Parameters:
843
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
844
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
845
+ for **all** `Attention` layers.
846
+
847
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
848
+ processor. This is strongly recommended when setting trainable attention processors.
849
+
850
+ """
851
+ count = len(self.attn_processors.keys())
852
+
853
+ if isinstance(processor, dict) and len(processor) != count:
854
+ raise ValueError(
855
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
856
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
857
+ )
858
+
859
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
860
+ if hasattr(module, "set_processor"):
861
+ if not isinstance(processor, dict):
862
+ module.set_processor(processor)
863
+ else:
864
+ module.set_processor(processor.pop(f"{name}.processor"))
865
+
866
+ for sub_name, child in module.named_children():
867
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
868
+
869
+ for name, module in self.named_children():
870
+ fn_recursive_attn_processor(name, module, processor)
871
+
872
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
873
+ def fuse_qkv_projections(self):
874
+ """
875
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
876
+ are fused. For cross-attention modules, key and value projection matrices are fused.
877
+
878
+ <Tip warning={true}>
879
+
880
+ This API is 🧪 experimental.
881
+
882
+ </Tip>
883
+ """
884
+ self.original_attn_processors = None
885
+
886
+ for _, attn_processor in self.attn_processors.items():
887
+ if "Added" in str(attn_processor.__class__.__name__):
888
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
889
+
890
+ self.original_attn_processors = self.attn_processors
891
+
892
+ for module in self.modules():
893
+ if isinstance(module, Attention):
894
+ module.fuse_projections(fuse=True)
895
+
896
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
897
+
898
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
899
+ def unfuse_qkv_projections(self):
900
+ """Disables the fused QKV projection if enabled.
901
+
902
+ <Tip warning={true}>
903
+
904
+ This API is 🧪 experimental.
905
+
906
+ </Tip>
907
+
908
+ """
909
+ if self.original_attn_processors is not None:
910
+ self.set_attn_processor(self.original_attn_processors)
911
+
912
+ def _set_gradient_checkpointing(self, module, value=False):
913
+ if hasattr(module, "gradient_checkpointing"):
914
+ module.gradient_checkpointing = value
915
+
916
+ def forward(
917
+ self,
918
+ hidden_states: torch.Tensor,
919
+ encoder_hidden_states: torch.Tensor = None,
920
+ pooled_projections: torch.Tensor = None,
921
+ timestep: torch.LongTensor = None,
922
+ img_ids: torch.Tensor = None,
923
+ txt_ids: torch.Tensor = None,
924
+ guidance: torch.Tensor = None,
925
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
926
+ controlnet_block_samples= None,
927
+ controlnet_single_block_samples=None,
928
+ return_dict: bool = True,
929
+ micro_conds: torch.Tensor = None,
930
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
931
+ """
932
+ The [`Transformer2DModel`] forward method.
933
+
934
+ Args:
935
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
936
+ Input `hidden_states`.
937
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
938
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
939
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
940
+ from the embeddings of input conditions.
941
+ timestep ( `torch.LongTensor`):
942
+ Used to indicate denoising step.
943
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
944
+ A list of tensors that if specified are added to the residuals of transformer blocks.
945
+ joint_attention_kwargs (`dict`, *optional*):
946
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
947
+ `self.processor` in
948
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
949
+ return_dict (`bool`, *optional*, defaults to `True`):
950
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
951
+ tuple.
952
+
953
+ Returns:
954
+ If `return_dict` is `False`, a `tuple` containing the logits tensor is returned; otherwise the
955
+ logits tensor is returned directly.
956
+ """
957
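+ # micro_conds carry aMUSEd-style micro-conditioning (e.g. original resolution and crop coordinates);
+ # each value is encoded with a sinusoidal embedding and concatenated to the pooled text projection.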
+ micro_cond_encode_dim = 256 # same as self.config.micro_cond_encode_dim = 256 from amused
958
+ micro_cond_embeds = get_timestep_embedding(
959
+ micro_conds.flatten(), micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
960
+ )
961
+ micro_cond_embeds = micro_cond_embeds.reshape((hidden_states.shape[0], -1))
962
+
963
+ pooled_projections = torch.cat([pooled_projections, micro_cond_embeds], dim=1)
964
+ pooled_projections = pooled_projections.to(dtype=self.dtype)
965
+ pooled_projections = self.cond_embed(pooled_projections).to(encoder_hidden_states.dtype)
966
+
967
+
968
+ hidden_states = self.embed(hidden_states)
969
+
970
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
971
+ encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
972
+ hidden_states = self.down_block(hidden_states)
973
+
974
+ batch_size, channels, height, width = hidden_states.shape
975
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
976
+ hidden_states = self.project_to_hidden_norm(hidden_states)
977
+ hidden_states = self.project_to_hidden(hidden_states)
978
+
979
+
980
+ if joint_attention_kwargs is not None:
981
+ joint_attention_kwargs = joint_attention_kwargs.copy()
982
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
983
+ else:
984
+ lora_scale = 1.0
985
+
986
+ if USE_PEFT_BACKEND:
987
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
988
+ scale_lora_layers(self, lora_scale)
989
+ else:
990
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
991
+ logger.warning(
992
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
993
+ )
994
+
995
+ timestep = timestep.to(hidden_states.dtype) * 1000
996
+ if guidance is not None:
997
+ guidance = guidance.to(hidden_states.dtype) * 1000
998
+ else:
999
+ guidance = None
1000
+ temb = (
1001
+ self.time_text_embed(timestep, pooled_projections)
1002
+ if guidance is None
1003
+ else self.time_text_embed(timestep, guidance, pooled_projections)
1004
+ )
1005
+
1006
+ if txt_ids.ndim == 3:
1007
+ logger.warning(
1008
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
1009
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1010
+ )
1011
+ txt_ids = txt_ids[0]
1012
+ if img_ids.ndim == 3:
1013
+ logger.warning(
1014
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
1015
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
1016
+ )
1017
+ img_ids = img_ids[0]
1018
+ ids = torch.cat((txt_ids, img_ids), dim=0)
1019
+
1020
+ image_rotary_emb = self.pos_embed(ids)
1021
+
1022
+ for index_block, block in enumerate(self.transformer_blocks):
1023
+ if self.training and self.gradient_checkpointing:
1024
+
1025
+ def create_custom_forward(module, return_dict=None):
1026
+ def custom_forward(*inputs):
1027
+ if return_dict is not None:
1028
+ return module(*inputs, return_dict=return_dict)
1029
+ else:
1030
+ return module(*inputs)
1031
+
1032
+ return custom_forward
1033
+
1034
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1035
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1036
+ create_custom_forward(block),
1037
+ hidden_states,
1038
+ encoder_hidden_states,
1039
+ temb,
1040
+ image_rotary_emb,
1041
+ **ckpt_kwargs,
1042
+ )
1043
+
1044
+ else:
1045
+ encoder_hidden_states, hidden_states = block(
1046
+ hidden_states=hidden_states,
1047
+ encoder_hidden_states=encoder_hidden_states,
1048
+ temb=temb,
1049
+ image_rotary_emb=image_rotary_emb,
1050
+ )
1051
+
1052
+
1053
+ # controlnet residual
1054
+ if controlnet_block_samples is not None:
1055
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
1056
+ interval_control = int(np.ceil(interval_control))
1057
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
1058
+
1059
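+ # After the dual-stream blocks, text and image tokens are concatenated into one sequence for the
+ # single-stream blocks; the text tokens are sliced off again before projecting back to image space.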
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1060
+
1061
+ for index_block, block in enumerate(self.single_transformer_blocks):
1062
+ if self.training and self.gradient_checkpointing:
1063
+
1064
+ def create_custom_forward(module, return_dict=None):
1065
+ def custom_forward(*inputs):
1066
+ if return_dict is not None:
1067
+ return module(*inputs, return_dict=return_dict)
1068
+ else:
1069
+ return module(*inputs)
1070
+
1071
+ return custom_forward
1072
+
1073
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1074
+ hidden_states = torch.utils.checkpoint.checkpoint(
1075
+ create_custom_forward(block),
1076
+ hidden_states,
1077
+ temb,
1078
+ image_rotary_emb,
1079
+ **ckpt_kwargs,
1080
+ )
1081
+
1082
+ else:
1083
+ hidden_states = block(
1084
+ hidden_states=hidden_states,
1085
+ temb=temb,
1086
+ image_rotary_emb=image_rotary_emb,
1087
+ )
1088
+
1089
+ # controlnet residual
1090
+ if controlnet_single_block_samples is not None:
1091
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
1092
+ interval_control = int(np.ceil(interval_control))
1093
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
1094
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1095
+ + controlnet_single_block_samples[index_block // interval_control]
1096
+ )
1097
+
1098
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
1099
+
1100
+
1101
+ hidden_states = self.project_from_hidden_norm(hidden_states)
1102
+ hidden_states = self.project_from_hidden(hidden_states)
1103
+
1104
+
1105
+ hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
1106
+
1107
+ hidden_states = self.up_block(hidden_states)
1108
+
1109
+ if USE_PEFT_BACKEND:
1110
+ # remove `lora_scale` from each PEFT layer
1111
+ unscale_lora_layers(self, lora_scale)
1112
+
1113
+ output = self.mlm_layer(hidden_states)
1114
+ # self.unfuse_qkv_projections()
1115
+ if not return_dict:
1116
+ return (output,)
1117
+
1118
+
1119
+ return output