akrao9 commited on
Commit
fe4bf5e
Β·
verified Β·
1 Parent(s): e900c99

Add Boomer FLA fine-tuned checkpoint (step 055000, ema weights)

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ transformer/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: other
3
+ tags:
4
+ - text-to-image
5
+ - diffusion
6
+ - linear-attention
7
+ - pytorch
8
+ - safetensors
9
+ language:
10
+ - en
11
+ pipeline_tag: text-to-image
12
+ ---
13
+
14
+ # Boomer FLA
15
+
16
+ Boomer FLA is a 657M parameter text-to-image diffusion model that generates **1024Γ—1024px** images from text prompts.
17
+
18
+ Instead of standard quadratic self-attention, it uses **GatedDeltaNet** β€” a bidirectional Flash Linear Attention mixer β€” as the backbone of its transformer blocks. This keeps memory flat regardless of sequence length. Every 6th block adds a full SDPA layer for global spatial coherence.
19
+
20
+ Text conditioning uses **Gemma 4 2B** (1536-dim embeddings, up to 300 tokens). Decoding uses the **DC-AE f32c32** VAE with 32Γ— spatial compression, producing 32Γ—32 latents from 1024px images. Sampling uses **STORK-2**, a high-order Runge–Kutta flow matching solver that converges in 32 steps.
21
+
22
+ Fine-tuned from a JourneyDB-pretrained base on 600k high-resolution images at 1024px.
23
+
24
+ ---
25
+
26
+ ## Sample outputs
27
+
28
+ ![Model grid β€” portraits and landscapes](boomer_model_grid.png)
29
+
30
+ *Portraits (top) and landscapes (bottom) generated at 1024Γ—1024px, 32 STORK-2 steps, CFG 4.5.*
31
+
32
+ ---
33
+
34
+ ## Architecture
35
+
36
+ | Property | Value |
37
+ |---|---|
38
+ | Parameters | 657M |
39
+ | Backbone | Bidirectional GatedDeltaNet (Flash Linear Attention) |
40
+ | Depth | 24 layers |
41
+ | Hidden dim | 896 |
42
+ | Heads | 14 |
43
+ | Image attention | Every 6th layer (full SDPA + 2D RoPE) |
44
+ | Patch size | 1 β€” one token per latent pixel (256 tokens at 512px, 1024 tokens at 1024px) |
45
+ | Text encoder | Gemma 4 2B (`google/gemma-4-E2B-it`) |
46
+ | VAE | DC-AE f32c32 (`mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers`) |
47
+ | Sampler | STORK-2, 32 steps |
48
+ | Dtype | bfloat16 |
49
+
50
+ ---
51
+
52
+ ## Training details
53
+
54
+ | Setting | Value |
55
+ |---|---|
56
+ | Pre-train dataset | JourneyDB (~3.8M images, 512px, patch size 1) |
57
+ | Fine-tune dataset | FineT2I (~600k images, 1024px, patch size 1) |
58
+ | Fine-tune steps | 55,000 |
59
+ | Batch size | 24 |
60
+ | Learning rate | 1e-4, linear warmup (1000 steps) β†’ cosine decay |
61
+ | Flow shift | 1.5 |
62
+ | Timestep sampler | Plateau logit-normal (ΞΌ=0, Οƒ=1) |
63
+ | Min-SNR Ξ³ | 5.0 |
64
+ | CFG dropout | 0.1 |
65
+ | EMA decay | 0.999 |
66
+ | Gradient clip | 0.3 |
67
+ | Optimizer | Fused AdamW |
68
+ | Hardware | NVIDIA A100 (Google Colab) |
69
+ | Precision | bfloat16 |
70
+
71
+ ---
72
+
73
+ ## VRAM and RAM requirements
74
+
75
+ Measured at 1024Γ—1024px, bfloat16, STORK-2, CFG batch=2.
76
+
77
+ | Component | VRAM |
78
+ |---|---|
79
+ | DiT weights (EMA, bf16) | 1.25 GB |
80
+ | Gemma 4 2B text encoder | 8.62 GB |
81
+ | Denoising peak (CFG on) | 1.36 GB |
82
+ | VAE decode peak | 3.51 GB |
83
+
84
+ | Mode | Peak VRAM | Minimum GPU |
85
+ |---|---|---|
86
+ | **Condition-cache** β€” pre-encoded embeddings, no text encoder in VRAM | **4.76 GB** | RTX 3060 8GB, T4 |
87
+ | **Fresh-prompt** β€” text encoder + DiT + VAE together | **13.38 GB** | RTX 3090, A100 |
88
+
89
+ **System RAM**: loading the text encoder (Gemma 4 2B) requires ~9 GB of system RAM even when using GPU. For condition-cache mode, encoding can be done on CPU with ~9 GB RAM β€” the generation step then needs only 5 GB VRAM.
90
+
91
+ ---
92
+
93
+ ## Usage
94
+
95
+ ### Install
96
+
97
+ ```bash
98
+ pip install torch diffusers transformers accelerate safetensors
99
+ pip install git+https://github.com/fla-org/flash-linear-attention.git
100
+ # STORK is bundled with the model β€” no separate install needed
101
+ ```
102
+
103
+ ### Generate
104
+
105
+ ```python
106
+ from diffusers import DiffusionPipeline
107
+
108
+ pipe = DiffusionPipeline.from_pretrained(
109
+ "akrao9/Boomer-T2I",
110
+ trust_remote_code=True,
111
+ torch_dtype="auto",
112
+ ).to("cuda")
113
+
114
+ image = pipe("a photorealistic portrait of a woman with dark hair")[0]
115
+ image.save("output.png")
116
+ ```
117
+
118
+ ### Parameters
119
+
120
+ ```python
121
+ image = pipe(
122
+ "a rocky coastline at sunset with crashing waves",
123
+ steps=32, # denoising steps β€” 32 is recommended with STORK-2
124
+ cfg_scale=4.5, # classifier-free guidance scale (4.0–5.0)
125
+ cfg_rescale=0.5, # reduces over-saturation at high CFG
126
+ seed=42,
127
+ )[0]
128
+ ```
129
+
130
+ ### Low VRAM β€” condition cache mode
131
+
132
+ Encode prompts once on any machine (including CPU), save the embedding, then generate with only the 1.25 GB DiT loaded. Peak VRAM drops from 13.38 GB β†’ 4.76 GB.
133
+
134
+ ```python
135
+ # Step 1 β€” encode on any machine (even CPU with 9GB RAM)
136
+ import torch
137
+ from transformers import AutoModelForCausalLM, AutoProcessor
138
+
139
+ TE_REPO = "google/gemma-4-E2B-it"
140
+ tokenizer = AutoProcessor.from_pretrained(TE_REPO)
141
+ text_encoder = AutoModelForCausalLM.from_pretrained(
142
+ TE_REPO, torch_dtype=torch.bfloat16
143
+ ).get_decoder()
144
+
145
+ tokens = tokenizer(
146
+ "a mountain lake surrounded by alpine peaks",
147
+ max_length=300, padding="max_length",
148
+ truncation=True, return_tensors="pt",
149
+ )
150
+ with torch.inference_mode():
151
+ hidden = text_encoder(
152
+ tokens["input_ids"], attention_mask=tokens["attention_mask"]
153
+ )[0]
154
+ idx = [0] + list(range(-299, 0))
155
+ emb = hidden[:, idx]
156
+ mask = tokens["attention_mask"][:, idx]
157
+
158
+ torch.save({"emb": emb.cpu(), "mask": mask.cpu()}, "condition.pt")
159
+ ```
160
+
161
+ ```python
162
+ # Step 2 β€” generate on low-VRAM GPU (no text encoder needed in VRAM)
163
+ from diffusers import DiffusionPipeline
164
+ import torch
165
+
166
+ pipe = DiffusionPipeline.from_pretrained(
167
+ "akrao9/Boomer-T2I",
168
+ trust_remote_code=True,
169
+ torch_dtype="auto",
170
+ ).to("cuda")
171
+
172
+ saved = torch.load("condition.pt")
173
+ image = pipe(
174
+ prompt="",
175
+ _preencoded_emb=saved["emb"].cuda(),
176
+ _preencoded_mask=saved["mask"].cuda(),
177
+ )[0]
178
+ image.save("output.png")
179
+ ```
180
+
181
+ ---
182
+
183
+ ## Limitations
184
+
185
+ - **Strong at** β€” photorealistic human portraits, dramatic landscapes, architectural scenes
186
+ - **Weak at** β€” animals, text rendering, small detailed objects (limited training data coverage)
187
+ - Landscapes have a painterly/HDR bias inherited from heavily post-processed stock images in the training set
188
+ - Not safety filtered β€” outputs may reflect biases in the training data
189
+ - Maximum tested resolution: **1024Γ—1024px**
190
+
191
+ ---
192
+
193
+ ## License
194
+
195
+ The Boomer FLA model weights are released for research and personal use. Commercial use is not permitted without explicit permission.
196
+
197
+ Upstream component licenses:
198
+ - DC-AE VAE: [mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers)
199
+ - Gemma 4 text encoder: [Gemma Terms of Use](https://ai.google.dev/gemma/terms)
STORKScheduler.py ADDED
@@ -0,0 +1,1641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import math
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+ import numpy as np
17
+ import torch
18
+ from scipy.io import loadmat
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
21
+ from diffusers.utils import BaseOutput, is_scipy_available, logging
22
+ from pathlib import Path
23
+
24
+
25
+
26
+ @dataclass
27
+ class STORKSchedulerOutput(BaseOutput):
28
+ """
29
+ Output class for the scheduler's `step` function output.
30
+
31
+ Args:
32
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
33
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
34
+ denoising loop.
35
+ """
36
+
37
+ prev_sample: torch.FloatTensor
38
+
39
+
40
+ current_file = Path(__file__)
41
+ CONSTANTSFOLDER = f"{current_file.parent}/STORK_constants"
42
+
43
+
44
+
45
+
46
+
47
+ class STORKScheduler(SchedulerMixin, ConfigMixin):
48
+ """
49
+ `STORKScheduler` uses modified stabilized Runge-Kutta method for the backward ODE in the diffusion or flow matching models.
50
+ This include the original STORK method and the modified STORK++ methods.
51
+
52
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
53
+ methods the library implements for all schedulers such as loading and saving.
54
+
55
+ Args:
56
+ num_train_timesteps (`int`, defaults to 1000):
57
+ The number of diffusion steps to train the model.
58
+ shift (`float`, defaults to 1.0):
59
+ The shift value for the timestep schedule.
60
+ use_dynamic_shifting (`bool`, defaults to False):
61
+ Whether to apply timestep shifting on-the-fly based on the image resolution.
62
+ base_shift (`float`, defaults to 0.5):
63
+ Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
64
+ with desired output.
65
+ max_shift (`float`, defaults to 1.15):
66
+ Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
67
+ more exaggerated or stylized.
68
+ base_image_seq_len (`int`, defaults to 256):
69
+ The base image sequence length.
70
+ max_image_seq_len (`int`, defaults to 4096):
71
+ The maximum image sequence length.
72
+ invert_sigmas (`bool`, defaults to False):
73
+ Whether to invert the sigmas.
74
+ shift_terminal (`float`, defaults to None):
75
+ The end value of the shifted timestep schedule.
76
+ use_karras_sigmas (`bool`, defaults to False):
77
+ Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
78
+ use_exponential_sigmas (`bool`, defaults to False):
79
+ Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
80
+ use_beta_sigmas (`bool`, defaults to False):
81
+ Whether to use beta sigmas for step sizes in the noise schedule during sampling.
82
+ solver_order (`int`, defaults to 2):
83
+ The STORK order which can be `2` or `4`. It is recommended to use `solver_order=2` uniformly.
84
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
85
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) or `flow_prediction`.
86
+ time_shift_type (`str`, defaults to "exponential"):
87
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
88
+ derivative_order (`int`, defaults to 1):
89
+ The order of the Taylor expansion derivative to use for the sub-step velocity approximation. Only supports 1, 2 or 3.
90
+ s (`int`, defaults to 50):
91
+ The number of sub-steps to use in the STORK.
92
+ precision (`str`, defaults to "float32"):
93
+ The precision to use for the scheduler; supports "float32", "bfloat16", or "float16".
94
+ """
95
+
96
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
97
+ order = 1
98
+
99
+ @register_to_config
100
+ def __init__(
101
+ self,
102
+ num_train_timesteps: int = 1000,
103
+ shift: float = 1.0,
104
+ use_dynamic_shifting: bool = False,
105
+ beta_start: float = 0.0001,
106
+ beta_end: float = 0.02,
107
+ beta_schedule: str = "linear",
108
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
109
+ stopping_eps: float = 1e-2,
110
+ solver_order: int = 4,
111
+ prediction_type: str = "epsilon",
112
+ time_shift_type: str = "exponential",
113
+ derivative_order: int = 1,
114
+ s: int = 50,
115
+ base_shift: Optional[float] = 0.5,
116
+ max_shift: Optional[float] = 1.15,
117
+ base_image_seq_len: Optional[int] = 256,
118
+ max_image_seq_len: Optional[int] = 4096,
119
+ invert_sigmas: bool = False,
120
+ shift_terminal: Optional[float] = None,
121
+ use_karras_sigmas: Optional[bool] = False,
122
+ use_exponential_sigmas: Optional[bool] = False,
123
+ use_beta_sigmas: Optional[bool] = False,
124
+ ):
125
+
126
+ super().__init__()
127
+ # if prediction_type == "flow_prediction" and sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
128
+ # raise ValueError(
129
+ # "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
130
+ # )
131
+ if time_shift_type not in {"exponential", "linear"}:
132
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
133
+
134
+ # We manually enforce precision to float32 for numerical issues.Add commentMore actions
135
+ self.np_dtype = np.float32
136
+ self.dtype = torch.float32
137
+
138
+
139
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=self.np_dtype)[::-1].copy()
140
+ timesteps = torch.from_numpy(timesteps).to(dtype=self.dtype)
141
+ sigmas = timesteps / num_train_timesteps
142
+
143
+
144
+
145
+ if not use_dynamic_shifting:
146
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
147
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
148
+
149
+ self.timesteps = None #sigmas * num_train_timesteps
150
+ self._step_index = None
151
+ self._begin_index = None
152
+ self._shift = shift
153
+ self.sigmas = sigmas #.to("cpu") # to avoid too much CPU/GPU communication
154
+ self.sigma_min = self.sigmas[-1].item()
155
+ self.sigma_max = self.sigmas[0].item()
156
+ # Store the predictions for the velocity/noise for higher order derivative approximations
157
+ self.velocity_predictions = []
158
+ self.noise_predictions = []
159
+ self.s = s
160
+ self.derivative_order = derivative_order
161
+
162
+ self.solver_order = solver_order
163
+ self.prediction_type = prediction_type
164
+
165
+
166
+ # Set the betas for noise-based models
167
+ if trained_betas is not None:
168
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
169
+ elif beta_schedule == "linear":
170
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
171
+ elif beta_schedule == "scaled_linear":
172
+ # this schedule is very specific to the latent diffusion model.
173
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
174
+ else:
175
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
176
+
177
+
178
+ self.alphas = 1.0 - self.betas
179
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
180
+
181
+ # standard deviation of the initial noise distribution
182
+ self.init_noise_sigma = 1.0
183
+
184
+ # Noise-based models epsilon to avoid numerical issues
185
+ self.stopping_eps = stopping_eps
186
+
187
+
188
+
189
+
190
+ def set_timesteps(
191
+ self,
192
+ num_inference_steps: Optional[int] = None,
193
+ device: Union[str, torch.device] = None,
194
+ sigmas: Optional[List[float]] = None,
195
+ mu: Optional[float] = None,
196
+ timesteps: Optional[List[float]] = None,
197
+ ):
198
+ """
199
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
200
+
201
+ Args:
202
+ num_inference_steps (`int`, *optional*):
203
+ The number of diffusion steps used when generating samples with a pre-trained model.
204
+ device (`str` or `torch.device`, *optional*):
205
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
206
+ sigmas (`List[float]`, *optional*):
207
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
208
+ automatically.
209
+ mu (`float`, *optional*):
210
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
211
+ shifting.
212
+ timesteps (`List[float]`, *optional*):
213
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
214
+ automatically.
215
+ """
216
+
217
+ if self.config.use_dynamic_shifting and mu is None:
218
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
219
+
220
+ if sigmas is not None and timesteps is not None:
221
+ if len(sigmas) != len(timesteps):
222
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
223
+
224
+ if num_inference_steps is not None:
225
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
226
+ timesteps is not None and len(timesteps) != num_inference_steps
227
+ ):
228
+ raise ValueError(
229
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
230
+ )
231
+ else:
232
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
233
+
234
+ self.num_inference_steps = num_inference_steps
235
+
236
+ if self.prediction_type == "epsilon":
237
+ self.set_timesteps_noise(num_inference_steps, device)
238
+ elif self.prediction_type == "flow_prediction":
239
+ self.set_timesteps_flow_matching(num_inference_steps, device, sigmas, mu, timesteps)
240
+ else:
241
+ raise ValueError(f"Prediction type {self.prediction_type} is not yet supported")
242
+
243
+ # Reset the step index and begin index
244
+ self._step_index = None
245
+ self._begin_index = None
246
+
247
+
248
+
249
+ def set_timesteps_noise(self,
250
+ num_inference_steps: Optional[int] = None,
251
+ device: Union[str, torch.device] = None,
252
+ ):
253
+ """
254
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference), for noise-based models.
255
+
256
+ Args:
257
+ num_inference_steps (`int`, *optional*):
258
+ The number of diffusion steps used when generating samples with a pre-trained model.
259
+ device (`str` or `torch.device`, *optional*):
260
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
261
+ """
262
+ seq = np.linspace(0, 1, self.num_inference_steps+1)
263
+ seq[0] = self.stopping_eps
264
+ seq = seq[:-1]
265
+ seq = seq[::-1]
266
+
267
+
268
+ # The following lines are for the uniform timestepping case
269
+ self.dt = seq[0] - seq[1]
270
+ seq = seq * self.config.num_train_timesteps
271
+ seq[-1] = self.stopping_eps * self.config.num_train_timesteps
272
+ self._timesteps = seq
273
+ self.timesteps = torch.from_numpy(seq.copy()).to(device)
274
+
275
+
276
+ self._step_index = None
277
+ self._begin_index = None
278
+
279
+ self.noise_predictions = []
280
+
281
+
282
+ def set_timesteps_flow_matching(self,
283
+ num_inference_steps: Optional[int] = None,
284
+ device: Union[str, torch.device] = None,
285
+ sigmas: Optional[List[float]] = None,
286
+ mu: Optional[float] = None,
287
+ timesteps: Optional[List[float]] = None,
288
+ ):
289
+ """
290
+ Sets the discrete timesteps used for the flow matching based models (to be run before inference).
291
+
292
+ Args:
293
+ num_inference_steps (`int`, *optional*):
294
+ The number of diffusion steps used when generating samples with a pre-trained model.
295
+ device (`str` or `torch.device`, *optional*):
296
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
297
+ sigmas (`List[float]`, *optional*):
298
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
299
+ automatically.
300
+ mu (`float`, *optional*):
301
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
302
+ shifting.
303
+ timesteps (`List[float]`, *optional*):
304
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
305
+ automatically.
306
+ """
307
+ self.num_inference_steps = num_inference_steps
308
+
309
+ # 1. Prepare default sigmas
310
+ is_timesteps_provided = timesteps is not None
311
+
312
+ if is_timesteps_provided:
313
+ timesteps = np.array(timesteps).astype(np.float32)
314
+
315
+ if sigmas is None:
316
+ if timesteps is None:
317
+ timesteps = np.linspace(
318
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
319
+ )
320
+ sigmas = timesteps / self.config.num_train_timesteps
321
+ else:
322
+ sigmas = np.array(sigmas).astype(np.float32)
323
+ num_inference_steps = len(sigmas)
324
+
325
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
326
+ # "exponential" or "linear" type is applied
327
+ if self.config.use_dynamic_shifting:
328
+ sigmas = self.time_shift(mu, 1.0, sigmas)
329
+ else:
330
+ sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
331
+
332
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
333
+ if self.config.shift_terminal:
334
+ sigmas = self.stretch_shift_to_terminal(sigmas)
335
+
336
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
337
+ if self.config.use_karras_sigmas:
338
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
339
+ elif self.config.use_exponential_sigmas:
340
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
341
+ elif self.config.use_beta_sigmas:
342
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
343
+
344
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
345
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
346
+ if not is_timesteps_provided:
347
+ timesteps = sigmas * self.config.num_train_timesteps
348
+ else:
349
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
350
+
351
+ # 6. Append the terminal sigma value.
352
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
353
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
354
+ if self.config.invert_sigmas:
355
+ sigmas = 1.0 - sigmas
356
+ timesteps = sigmas * self.config.num_train_timesteps
357
+ sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
358
+ else:
359
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
360
+
361
+ self.timesteps = timesteps
362
+ self.sigmas = sigmas
363
+
364
+
365
+ # Create the dt list
366
+ self.dt_list = self.sigmas[:-1] - self.sigmas[1:]
367
+ self.dt_list = self.dt_list.reshape(-1)
368
+
369
+ self.dt_list = self.dt_list.tolist()
370
+ self.dt_list = torch.tensor(self.dt_list).to(self.dtype)
371
+
372
+ self.velocity_predictions = []
373
+
374
+
375
+ @property
376
+ def shift(self):
377
+ """
378
+ The value used for shifting.
379
+ """
380
+ return self._shift
381
+
382
+ @property
383
+ def step_index(self):
384
+ """
385
+ The index counter for current timestep. It will increase 1 after each scheduler step.
386
+ """
387
+ return self._step_index
388
+
389
+ @property
390
+ def begin_index(self):
391
+ """
392
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
393
+ """
394
+ return self._begin_index
395
+
396
+
397
+
398
+ def set_shift(self, shift: float):
399
+ self._shift = shift
400
+
401
+ def set_begin_index(self, begin_index: int):
402
+ """
403
+ Set the begin index for the scheduler.
404
+
405
+ Args:
406
+ begin_index (`int`):
407
+ The begin index to set.
408
+ """
409
+ self._begin_index = begin_index
410
+
411
+ def scale_noise(
412
+ self,
413
+ sample: torch.FloatTensor,
414
+ timestep: Union[float, torch.FloatTensor],
415
+ noise: Optional[torch.FloatTensor] = None,
416
+ ) -> torch.FloatTensor:
417
+ """
418
+ Forward process in flow-matching
419
+
420
+ Args:
421
+ sample (`torch.FloatTensor`):
422
+ The input sample.
423
+ timestep (`int`, *optional*):
424
+ The current timestep in the diffusion chain.
425
+
426
+ Returns:
427
+ `torch.FloatTensor`:
428
+ A scaled input sample.
429
+ """
430
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
431
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
432
+
433
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
434
+ # mps does not support float64
435
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=self.dtype)
436
+ timestep = timestep.to(sample.device, dtype=self.dtype)
437
+ else:
438
+ schedule_timesteps = self.timesteps.to(sample.device)
439
+ timestep = timestep.to(sample.device)
440
+
441
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
442
+ if self.begin_index is None:
443
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
444
+ elif self.step_index is not None:
445
+ # add_noise is called after first denoising step (for inpainting)
446
+ step_indices = [self.step_index] * timestep.shape[0]
447
+ else:
448
+ # add noise is called before first denoising step to create initial latent(img2img)
449
+ step_indices = [self.begin_index] * timestep.shape[0]
450
+
451
+ sigma = sigmas[step_indices].flatten()
452
+ while len(sigma.shape) < len(sample.shape):
453
+ sigma = sigma.unsqueeze(-1)
454
+
455
+ sample = sigma * noise + (1.0 - sigma) * sample
456
+
457
+ return sample
458
+
459
+ def _sigma_to_t(self, sigma):
460
+ return sigma * self.config.num_train_timesteps
461
+
462
+ def index_for_timestep(self, timestep, schedule_timesteps):
463
+ """
464
+ Get the index for a given timestep in the schedule.
465
+
466
+ Args:
467
+ timestep (`torch.Tensor`):
468
+ The timestep to find the index for.
469
+ schedule_timesteps (`torch.Tensor`):
470
+ The schedule timesteps.
471
+
472
+ Returns:
473
+ `int`:
474
+ The index for the timestep.
475
+ """
476
+ # Find the closest timestep in the schedule
477
+ indices = torch.searchsorted(schedule_timesteps, timestep, right=True)
478
+ indices = torch.clamp(indices, 0, len(schedule_timesteps) - 1)
479
+ return indices.item()
480
+
481
+
482
+
483
+ def step(
484
+ self,
485
+ model_output: torch.Tensor,
486
+ timestep: Union[int, torch.Tensor],
487
+ sample: torch.Tensor = None,
488
+ return_dict: bool = True,
489
+ **kwargs
490
+ ) -> torch.Tensor:
491
+ '''
492
+ One step of the STORK update for flow matching or noise-based diffusion models.
493
+
494
+ Args:
495
+ model_output (`torch.FloatTensor`):
496
+ The direct output from learned diffusion model.
497
+ timestep (`float`):
498
+ The current discrete timestep in the diffusion chain.
499
+ sample (`torch.FloatTensor`):
500
+ A current instance of a sample created by the diffusion process.
501
+ return_dict (`bool`, defaults to `True`):
502
+ Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
503
+
504
+ Returns:
505
+ result (Union[Tuple, STORKSchedulerOutput]):
506
+ The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues.
507
+ '''
508
+ original_model_output_dtype = model_output.dtype
509
+ # Cast model_output and sample to "torch.float32" to avoid numerical issues
510
+ model_output = model_output.to(self.dtype)
511
+ sample = sample.to(self.dtype)
512
+ # Move sample to model_output's device
513
+ sample = sample.to(model_output.device)
514
+
515
+ """
516
+ self.velocity_predictions always contain upcasted model_output in torch.float32 dtype.
517
+ """
518
+
519
+ if self.prediction_type == "epsilon":
520
+ if self.solver_order == 2:
521
+ result = self.step_noise_2(model_output, timestep, sample, return_dict)
522
+ elif self.solver_order == 4:
523
+ result = self.step_noise_4(model_output, timestep, sample, return_dict)
524
+ else:
525
+ raise ValueError(f"Solver order {self.solver_order} is not yet supported for noise-based models")
526
+ elif self.prediction_type == "flow_prediction":
527
+ if self.solver_order == 1:
528
+ result = self.step_flow_matching_1(model_output, timestep, sample, return_dict)
529
+ elif self.solver_order == 2:
530
+ result = self.step_flow_matching_2(model_output, timestep, sample, return_dict)
531
+ elif self.solver_order == 4:
532
+ result = self.step_flow_matching_4(model_output, timestep, sample, return_dict)
533
+ else:
534
+ raise ValueError(f"Solver order {self.solver_order} is not yet supported for flow matching models")
535
+ else:
536
+ raise ValueError(f"Prediction type {self.prediction_type} is not yet supported")
537
+
538
+ # Convert the result back to the original dtype of model_output, as this result will be used as the next input to the model
539
+ if return_dict:
540
+ result.prev_sample = result.prev_sample.to(original_model_output_dtype)
541
+ else:
542
+ result = (result[0].to(original_model_output_dtype),)
543
+ return result
544
+
545
+
546
+ def step_flow_matching_1(
547
+ self,
548
+ model_output: torch.Tensor,
549
+ timestep: Union[int, torch.Tensor],
550
+ sample: torch.Tensor = None,
551
+ return_dict: bool = False
552
+ ) -> torch.Tensor:
553
+ # Initialize the step index if it's the first step
554
+ if self._step_index is None:
555
+ self._step_index = 0
556
+
557
+
558
+ # Compute the startup phase or the derivative approximation for the main step
559
+ if self._step_index == 0:
560
+ img_next = sample - model_output * self.dt_list[self._step_index]
561
+ self._step_index += 1
562
+ self.velocity_predictions.append(model_output)
563
+
564
+ if not return_dict:
565
+ return (img_next,)
566
+ return STORKSchedulerOutput(prev_sample=img_next)
567
+ else:
568
+ t = self.sigmas[self._step_index]
569
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
570
+ t_next = self.sigmas[self._step_index + 1]
571
+
572
+ h1 = self.dt_list[self._step_index-1]
573
+
574
+ if self.derivative_order == 1:
575
+ # Ensure h1 is a tensor for proper broadcasting
576
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
577
+ velocity_derivative = (self.velocity_predictions[-1] - model_output) / h1_tensor
578
+ velocity_second_derivative = None
579
+ velocity_third_derivative = None
580
+ elif self.derivative_order == 2:
581
+ # Ensure h1 and h2 are tensors for proper broadcasting
582
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
583
+ if self._step_index == 1:
584
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
585
+ self._step_index += 1
586
+ self.velocity_predictions.append(model_output)
587
+
588
+ if not return_dict:
589
+ return (img_next,)
590
+ return STORKSchedulerOutput(prev_sample=img_next)
591
+ else:
592
+ h2 = self.dt_list[self._step_index-2]
593
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
594
+ velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1_tensor)
595
+ velocity_second_derivative = 2 / (h1_tensor * h2_tensor * (h1_tensor + h2_tensor)) * (self.velocity_predictions[-2] * h1_tensor - self.velocity_predictions[-1] * (h1_tensor + h2_tensor) + model_output * h2_tensor)
596
+ velocity_third_derivative = None
597
+ elif self.derivative_order == 3:
598
+
599
+ if self._step_index == 1 or self._step_index == 2:
600
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
601
+ self._step_index += 1
602
+ self.velocity_predictions.append(model_output)
603
+
604
+ if not return_dict:
605
+ return (img_next,)
606
+ return STORKSchedulerOutput(prev_sample=img_next)
607
+ else:
608
+ h2 = h1 + self.dt_list[self._step_index-2]
609
+ h3 = h2 + self.dt_list[self._step_index-3]
610
+ # Ensure h1, h2, and h3 are tensors for proper broadcasting
611
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
612
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
613
+ h3_tensor = torch.tensor(h3, device=model_output.device, dtype=model_output.dtype)
614
+ velocity_derivative = ((h2_tensor * h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor * h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor * h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
615
+ velocity_second_derivative = 2 * ((h2_tensor + h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor + h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor + h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
616
+ velocity_third_derivative = 6 * ((h2_tensor - h3_tensor) * (self.velocity_predictions[-1] - model_output) + (h3_tensor - h1_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor - h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
617
+ else:
618
+ print("The noise approximation order is not supported!")
619
+ exit()
620
+
621
+ self.velocity_predictions.append(model_output)
622
+ self._step_index += 1
623
+
624
+
625
+
626
+ Y_j_2 = sample
627
+ Y_j_1 = sample
628
+ Y_j = sample
629
+
630
+
631
+ # Implementation of our Runge-Kutta-Gegenbauer second order method
632
+ for j in range(1, self.s + 1):
633
+ # Calculate the corresponding \bar{alpha}_t and beta_t that aligns with the correct timestep
634
+ fraction = (j - 1) * (j + 2) / (self.s * (self.s + 3))
635
+
636
+ if j == 1:
637
+ mu_tilde = 4 / (self.s * (self.s + 1))
638
+ dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device)
639
+ Y_j = Y_j_1 - dt * mu_tilde * model_output
640
+ else:
641
+ mu = (2 * j + 1) * self.coeff_rock1(j) / (j * self.coeff_rock1(j - 1))
642
+ nu = -(j + 1) * self.coeff_rock1(j) / (j * self.coeff_rock1(j - 2))
643
+ mu_tilde = mu * 4 / (self.s * (self.s + 1))
644
+
645
+
646
+ # Probability flow ODE update
647
+ diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device)
648
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
649
+ Y_j = mu * Y_j_1 + nu * Y_j_2 - dt * mu_tilde * velocity
650
+
651
+ Y_j_2 = Y_j_1
652
+ Y_j_1 = Y_j
653
+
654
+
655
+
656
+ img_next = Y_j
657
+ img_next = img_next.to(model_output.dtype)
658
+
659
+ return SchedulerOutput(prev_sample=img_next)
660
+
661
+
662
+
663
+
664
+ def step_flow_matching_2(
665
+ self,
666
+ model_output: torch.Tensor,
667
+ timestep: Union[int, torch.Tensor],
668
+ sample: torch.Tensor = None,
669
+ return_dict: bool = False,
670
+ ) -> torch.Tensor:
671
+ '''
672
+ One step of the STORK2 update for flow matching based models.
673
+
674
+ Args:
675
+ model_output (`torch.FloatTensor`):
676
+ The direct output from learned diffusion model.
677
+ timestep (`float`):
678
+ The current discrete timestep in the diffusion chain.
679
+ sample (`torch.FloatTensor`):
680
+ A current instance of a sample created by the diffusion process.
681
+ return_dict (`bool`, defaults to `True`):
682
+ Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
683
+
684
+ Returns:
685
+ result (Union[Tuple, STORKSchedulerOutput]):
686
+ The next sample in the diffusion chain, either as a tuple or as a [`~schedulers.STORKSchedulerOutput`]. The value is converted back to the original dtype of `model_output` to avoid numerical issues.
687
+ '''
688
+ # Initialize the step index if it's the first step
689
+ if self._step_index is None:
690
+ self._step_index = 0
691
+
692
+
693
+ # Compute the startup phase or the derivative approximation for the main step
694
+ if self._step_index == 0:
695
+ img_next = sample - model_output * self.dt_list[self._step_index]
696
+ self._step_index += 1
697
+ self.velocity_predictions.append(model_output)
698
+
699
+ if not return_dict:
700
+ return (img_next,)
701
+ return STORKSchedulerOutput(prev_sample=img_next)
702
+ else:
703
+ t = self.sigmas[self._step_index]
704
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
705
+ t_next = self.sigmas[self._step_index + 1]
706
+
707
+ h1 = self.dt_list[self._step_index-1]
708
+
709
+ if self.derivative_order == 1:
710
+ # Ensure h1 is a tensor for proper broadcasting
711
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
712
+ velocity_derivative = (self.velocity_predictions[-1] - model_output) / h1_tensor
713
+ velocity_second_derivative = None
714
+ velocity_third_derivative = None
715
+ elif self.derivative_order == 2:
716
+ # Ensure h1 and h2 are tensors for proper broadcasting
717
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
718
+ if self._step_index == 1:
719
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
720
+ self._step_index += 1
721
+ self.velocity_predictions.append(model_output)
722
+
723
+ if not return_dict:
724
+ return (img_next,)
725
+ return STORKSchedulerOutput(prev_sample=img_next)
726
+ else:
727
+ h2 = self.dt_list[self._step_index-2]
728
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
729
+ velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1_tensor)
730
+ velocity_second_derivative = 2 / (h1_tensor * h2_tensor * (h1_tensor + h2_tensor)) * (self.velocity_predictions[-2] * h1_tensor - self.velocity_predictions[-1] * (h1_tensor + h2_tensor) + model_output * h2_tensor)
731
+ velocity_third_derivative = None
732
+ elif self.derivative_order == 3:
733
+
734
+ if self._step_index == 1 or self._step_index == 2:
735
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
736
+ self._step_index += 1
737
+ self.velocity_predictions.append(model_output)
738
+
739
+ if not return_dict:
740
+ return (img_next,)
741
+ return STORKSchedulerOutput(prev_sample=img_next)
742
+ else:
743
+ h2 = h1 + self.dt_list[self._step_index-2]
744
+ h3 = h2 + self.dt_list[self._step_index-3]
745
+ # Ensure h1, h2, and h3 are tensors for proper broadcasting
746
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
747
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
748
+ h3_tensor = torch.tensor(h3, device=model_output.device, dtype=model_output.dtype)
749
+ velocity_derivative = ((h2_tensor * h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor * h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor * h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
750
+ velocity_second_derivative = 2 * ((h2_tensor + h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor + h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor + h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
751
+ velocity_third_derivative = 6 * ((h2_tensor - h3_tensor) * (self.velocity_predictions[-1] - model_output) + (h3_tensor - h1_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor - h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
752
+ else:
753
+ print("The noise approximation order is not supported!")
754
+ exit()
755
+
756
+ self.velocity_predictions.append(model_output)
757
+ self._step_index += 1
758
+
759
+
760
+ Y_j_2 = sample
761
+ Y_j_1 = sample
762
+ Y_j = sample
763
+
764
+
765
+ # Implementation of our Runge-Kutta-Gegenbauer second order method
766
+ for j in range(1, self.s + 1):
767
+ # Calculate the corresponding \bar{alpha}_t and beta_t that aligns with the correct timestep
768
+ if j > 1:
769
+ if j == 2:
770
+ fraction = 4 / (3 * (self.s**2 + self.s - 2))
771
+ else:
772
+ fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2)
773
+
774
+ if j == 1:
775
+ mu_tilde = 6 / ((self.s + 4) * (self.s - 1))
776
+ dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device)
777
+ Y_j = Y_j_1 - dt * mu_tilde * model_output
778
+ else:
779
+ mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1))
780
+ nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2))
781
+ mu_tilde = mu * 6 / ((self.s + 4) * (self.s - 1))
782
+ gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j-1)/ 2)
783
+
784
+
785
+ # Probability flow ODE update
786
+ diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device)
787
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
788
+ Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * velocity - dt * gamma_tilde * model_output
789
+
790
+ Y_j_2 = Y_j_1
791
+ Y_j_1 = Y_j
792
+
793
+
794
+
795
+ img_next = Y_j
796
+ img_next = img_next.to(model_output.dtype)
797
+
798
+ if not return_dict:
799
+ return (img_next,)
800
+ return STORKSchedulerOutput(prev_sample=img_next)
801
+
802
+
803
+ def step_flow_matching_4(
804
+ self,
805
+ model_output: torch.Tensor,
806
+ timestep: Union[int, torch.Tensor],
807
+ sample: torch.Tensor = None,
808
+ return_dict: bool = False,
809
+ ) -> torch.Tensor:
810
+ '''
811
+ One step of the STORK4 update for flow matching models
812
+
813
+ Args:
814
+ model_output (`torch.FloatTensor`):
815
+ The direct output from learned diffusion model.
816
+ timestep (`float`):
817
+ The current discrete timestep in the diffusion chain.
818
+ sample (`torch.FloatTensor`):
819
+ A current instance of a sample created by the diffusion process.
820
+
821
+ Returns:
822
+ `torch.FloatTensor`: The next sample in the diffusion chain.
823
+ '''
824
+ # Initialize the step index if it's the first step
825
+ if self._step_index is None:
826
+ self._step_index = 0
827
+
828
+
829
+ # Compute the startup phase or the derivative approximation for the main step
830
+ if self._step_index == 0:
831
+ img_next = sample - model_output * self.dt_list[self._step_index]
832
+ self._step_index += 1
833
+ self.velocity_predictions.append(model_output)
834
+
835
+ if not return_dict:
836
+ return (img_next,)
837
+ return STORKSchedulerOutput(prev_sample=img_next)
838
+ else:
839
+ t = self.sigmas[self._step_index]
840
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
841
+ t_next = self.sigmas[self._step_index + 1]
842
+
843
+ h1 = self.dt_list[self._step_index-1]
844
+
845
+ if self.derivative_order == 1:
846
+ # Ensure h1 is a tensor for proper broadcasting
847
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
848
+ velocity_derivative = (self.velocity_predictions[-1] - model_output) / h1_tensor
849
+ velocity_second_derivative = None
850
+ velocity_third_derivative = None
851
+ elif self.derivative_order == 2:
852
+ # Ensure h1 and h2 are tensors for proper broadcasting
853
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
854
+ if self._step_index == 1:
855
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
856
+ self._step_index += 1
857
+ self.velocity_predictions.append(model_output)
858
+
859
+ if not return_dict:
860
+ return (img_next,)
861
+ return STORKSchedulerOutput(prev_sample=img_next)
862
+ else:
863
+ h2 = self.dt_list[self._step_index-2]
864
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
865
+ velocity_derivative = (-self.velocity_predictions[-2] + 4 * self.velocity_predictions[-1] - 3 * model_output) / (2 * h1_tensor)
866
+ velocity_second_derivative = 2 / (h1_tensor * h2_tensor * (h1_tensor + h2_tensor)) * (self.velocity_predictions[-2] * h1_tensor - self.velocity_predictions[-1] * (h1_tensor + h2_tensor) + model_output * h2_tensor)
867
+ velocity_third_derivative = None
868
+ elif self.derivative_order == 3:
869
+
870
+ if self._step_index == 1 or self._step_index == 2:
871
+ img_next = sample - 1.5 * model_output * self.dt_list[self._step_index] + 0.5 * self.velocity_predictions[-1] * self.dt_list[self._step_index-1]
872
+ self._step_index += 1
873
+ self.velocity_predictions.append(model_output)
874
+
875
+ if not return_dict:
876
+ return (img_next,)
877
+ return STORKSchedulerOutput(prev_sample=img_next)
878
+ else:
879
+ h2 = h1 + self.dt_list[self._step_index-2]
880
+ h3 = h2 + self.dt_list[self._step_index-3]
881
+ # Ensure h1, h2, and h3 are tensors for proper broadcasting
882
+ h1_tensor = torch.tensor(h1, device=model_output.device, dtype=model_output.dtype)
883
+ h2_tensor = torch.tensor(h2, device=model_output.device, dtype=model_output.dtype)
884
+ h3_tensor = torch.tensor(h3, device=model_output.device, dtype=model_output.dtype)
885
+ velocity_derivative = ((h2_tensor * h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor * h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor * h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
886
+ velocity_second_derivative = 2 * ((h2_tensor + h3_tensor) * (self.velocity_predictions[-1] - model_output) - (h1_tensor + h3_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor + h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
887
+ velocity_third_derivative = 6 * ((h2_tensor - h3_tensor) * (self.velocity_predictions[-1] - model_output) + (h3_tensor - h1_tensor) * (self.velocity_predictions[-2] - model_output) + (h1_tensor - h2_tensor) * (self.velocity_predictions[-3] - model_output)) / (h1_tensor * h2_tensor * h3_tensor)
888
+ else:
889
+ print("The noise approximation order is not supported!")
890
+ exit()
891
+
892
+ self.velocity_predictions.append(model_output)
893
+ self._step_index += 1
894
+
895
+ Y_j_2 = sample
896
+ Y_j_1 = sample
897
+ Y_j = sample
898
+
899
+ ci1 = t_start
900
+ ci2 = t_start
901
+ ci3 = t_start
902
+
903
+ # Coefficients of ROCK4
904
+ ms, fpa, fpb, fpbe, recf = self.coeff_rock4()
905
+ # Choose the degree that's in the precomputed table
906
+ mdeg, mp = self.mdegr(self.s, ms)
907
+ mz = int(mp[0])
908
+ mr = int(mp[1])
909
+
910
+ '''
911
+ The first part of the STORK4 update
912
+ '''
913
+ for j in range(1, mdeg + 1):
914
+
915
+ # First sub-step in the first part of the STORK4 update
916
+ if j == 1:
917
+ temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device)
918
+ ci1 = t_start + temp1
919
+ ci2 = ci1
920
+ Y_j_1 = sample + temp1 * model_output
921
+ # Y_j = sample + temp1 * model_output
922
+ # Second and the following sub-steps in the first part of the STORK4 update
923
+ else:
924
+ diff = ci1 - t_start
925
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
926
+
927
+ temp1 = -(t - t_next) * recf[mr + 2 * (j-2) + 1] * torch.ones(model_output.shape, device=sample.device)
928
+ temp3 = -recf[mr + 2 * (j-2) + 2] * torch.ones(model_output.shape, device=sample.device)
929
+ temp2 = torch.ones(model_output.shape, device=sample.device) - temp3
930
+
931
+ ci1 = temp1 + temp2 * ci2 + temp3 * ci3
932
+ Y_j = temp1 * velocity + temp2 * Y_j_1 + temp3 * Y_j_2
933
+
934
+ # Update the intermediate variables
935
+ Y_j_2 = Y_j_1
936
+ Y_j_1 = Y_j
937
+
938
+ ci3 = ci2
939
+ ci2 = ci1
940
+
941
+ '''
942
+ The finishing four-step procedure as a composition method
943
+ '''
944
+ # First finishing step
945
+ temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device)
946
+ diff = ci1 - t_start
947
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
948
+ Y_j_1 = velocity
949
+ Y_j_3 = Y_j + temp1 * Y_j_1
950
+
951
+ # Second finishing step
952
+ ci2 = ci1 + temp1
953
+ temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device)
954
+ temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device)
955
+ diff = ci2 - t_start
956
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
957
+ Y_j_2 = velocity
958
+ Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2
959
+
960
+ # Third finishing step
961
+ ci2 = ci1 + temp1 + temp2
962
+ temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device)
963
+ temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device)
964
+ temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device)
965
+ diff = ci2 - t_start
966
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
967
+ Y_j_3 = velocity
968
+ # This is the counterpart of the final step in the noise-based diffusion models STORK4
969
+ # fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3
970
+
971
+ # Fourth finishing step
972
+ ci2 = ci1 + temp1 + temp2 + temp3
973
+ temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device)
974
+ temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device)
975
+ temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device)
976
+ temp4 = -(t - t_next) * fpb[mz,3] * torch.ones(model_output.shape, device=sample.device)
977
+ diff = ci2 - t_start
978
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, velocity_derivative, velocity_second_derivative, velocity_third_derivative)
979
+ Y_j_4 = velocity
980
+ Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4
981
+ img_next = Y_j
982
+
983
+ if not return_dict:
984
+ return (img_next,)
985
+ return STORKSchedulerOutput(prev_sample=img_next)
986
+
987
+
988
+ def step_noise_2(
989
+ self,
990
+ model_output: torch.Tensor,
991
+ timestep: Union[int, torch.Tensor],
992
+ sample: torch.Tensor = None,
993
+ return_dict: bool = False,
994
+ ) -> torch.Tensor:
995
+ '''
996
+ One step of the STORK2 update for noise-based diffusion models.
997
+
998
+ Args:
999
+ model_output (`torch.FloatTensor`):
1000
+ The direct output from learned diffusion model.
1001
+ timestep (`float`):
1002
+ The current discrete timestep in the diffusion chain.
1003
+ sample (`torch.FloatTensor`):
1004
+ A current instance of a sample created by the diffusion process.
1005
+ return_dict (`bool`, defaults to `True`):
1006
+ Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
1007
+
1008
+ Returns:
1009
+ `torch.FloatTensor`: The next sample in the diffusion chain.
1010
+ '''
1011
+ # Initialize the step index if it's the first step
1012
+ if self._step_index is None:
1013
+ self._step_index = 0
1014
+ self.initial_noise = model_output
1015
+
1016
+
1017
+ total_step = self.config.num_train_timesteps
1018
+ t = self.timesteps[self._step_index] / total_step
1019
+
1020
+ beta_0, beta_1 = self.betas[0], self.betas[-1]
1021
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
1022
+ beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step
1023
+ log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step
1024
+ std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
1025
+
1026
+ # Tweedie's trick
1027
+ if self._step_index == len(self.timesteps) - 1:
1028
+ noise_last = model_output
1029
+ img_next = sample - std * noise_last
1030
+ if not return_dict:
1031
+ return (img_next,)
1032
+ return STORKSchedulerOutput(prev_sample=img_next)
1033
+
1034
+ t_next = self.timesteps[self._step_index + 1] / total_step
1035
+
1036
+ # drift, diffusion -> f(x,t), g(t)
1037
+ drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device)
1038
+ noise_initial = model_output
1039
+ score = -noise_initial / std # score -> noise
1040
+ drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt
1041
+
1042
+
1043
+ dt = torch.ones(model_output.shape, device=sample.device) * self.dt
1044
+
1045
+ if self._step_index == 0:
1046
+ # FIRST RUN
1047
+ self.initial_sample = sample
1048
+ img_next = sample - 0.5 * dt * drift_initial
1049
+
1050
+ self.noise_predictions.append(noise_initial)
1051
+ self._step_index += 1
1052
+
1053
+ self.initial_sample = sample
1054
+ self.initial_drift = drift_initial
1055
+ self.initial_noise = model_output
1056
+
1057
+ return SchedulerOutput(prev_sample=img_next)
1058
+ elif self._step_index == 1:
1059
+ # SECOND RUN
1060
+ t_previous = torch.ones(model_output.shape, device=sample.device) * self.timesteps[0] / 1000
1061
+ drift_previous = self.drift_function(self.betas, self.config.num_train_timesteps, t_previous, self.initial_sample, self.noise_predictions[-1])
1062
+
1063
+ img_next = sample - 0.75 * dt * drift_initial + 0.25 * dt * drift_previous
1064
+
1065
+ self.noise_predictions.append(noise_initial)
1066
+ self._step_index += 1
1067
+
1068
+ return SchedulerOutput(prev_sample=img_next)
1069
+ elif self._step_index == 2:
1070
+ h = 0.5 * dt
1071
+
1072
+ noise_derivative = (3 * self.noise_predictions[0] - 4 * self.noise_predictions[1] + model_output) / (2 * h)
1073
+ noise_second_derivative = (self.noise_predictions[0] - 2 * self.noise_predictions[1] + model_output) / (h ** 2)
1074
+ noise_third_derivative = None
1075
+
1076
+ model_output = self.initial_noise
1077
+ drift_initial = self.initial_drift
1078
+ sample = self.initial_sample
1079
+
1080
+ t = self.timesteps[0] / total_step
1081
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
1082
+ t_next = self.timesteps[2] / total_step
1083
+ elif self._step_index == 3:
1084
+ h = 0.5 * dt
1085
+
1086
+ noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h)
1087
+ noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2)
1088
+ noise_third_derivative = None
1089
+
1090
+ self.noise_predictions.append(noise_initial)
1091
+ elif self._step_index == 4:
1092
+ h = dt
1093
+
1094
+ noise_derivative = (-3 * noise_initial + 4 * self.noise_predictions[-1] - self.noise_predictions[-2]) / (2 * h)
1095
+ noise_second_derivative = (noise_initial - 2 * self.noise_predictions[-1] + self.noise_predictions[-2]) / (h ** 2)
1096
+ noise_third_derivative = None
1097
+
1098
+ self.noise_predictions.append(noise_initial)
1099
+ else:
1100
+ # ALL ELSE
1101
+ h = dt
1102
+
1103
+ noise_derivative = (2 * self.noise_predictions[-3] - 9 * self.noise_predictions[-2] + 18 * self.noise_predictions[-1] - 11 * noise_initial) / (6 * h)
1104
+ noise_second_derivative = (-self.noise_predictions[-3] + 4 * self.noise_predictions[-2] -5 * self.noise_predictions[-1] + 2 * noise_initial) / (h**2)
1105
+ noise_third_derivative = (self.noise_predictions[-3] - 3 * self.noise_predictions[-2] + 3 * self.noise_predictions[-1] - noise_initial) / (h**3)
1106
+
1107
+ self.noise_predictions.append(noise_initial)
1108
+
1109
+
1110
+ Y_j_2 = sample
1111
+ Y_j_1 = sample
1112
+ Y_j = sample
1113
+
1114
+ # Implementation of our Runge-Kutta-Gegenbauer second order method
1115
+ for j in range(1, self.s + 1):
1116
+ # Calculate the corresponding \bar{alpha}_t and beta_t that aligns with the correct timestep
1117
+ if j > 1:
1118
+ if j == 2:
1119
+ fraction = 4 / (3 * (self.s**2 + self.s - 2))
1120
+ else:
1121
+ fraction = ((j - 1)**2 + (j - 1) - 2) / (self.s**2 + self.s - 2)
1122
+
1123
+ if j == 1:
1124
+ mu_tilde = 6 / ((self.s + 4) * (self.s - 1))
1125
+ dt = (t - t_next) * torch.ones(model_output.shape, device=sample.device)
1126
+ Y_j = Y_j_1 - dt * mu_tilde * model_output
1127
+ else:
1128
+ mu = (2 * j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 1))
1129
+ nu = -(j + 1) * self.b_coeff(j) / (j * self.b_coeff(j - 2))
1130
+ mu_tilde = mu * 6 / ((self.s + 4) * (self.s - 1))
1131
+ gamma_tilde = -mu_tilde * (1 - j * (j + 1) * self.b_coeff(j-1)/ 2)
1132
+
1133
+
1134
+ # Probability flow ODE update
1135
+ diff = -fraction * (t - t_next) * torch.ones(model_output.shape, device=sample.device)
1136
+ velocity = self.taylor_approximation(self.derivative_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1137
+ Y_j = mu * Y_j_1 + nu * Y_j_2 + (1 - mu - nu) * sample - dt * mu_tilde * velocity - dt * gamma_tilde * model_output
1138
+
1139
+ Y_j_2 = Y_j_1
1140
+ Y_j_1 = Y_j
1141
+
1142
+
1143
+
1144
+ img_next = Y_j
1145
+ img_next = img_next.to(model_output.dtype)
1146
+ self._step_index += 1
1147
+
1148
+ if not return_dict:
1149
+ return (img_next,)
1150
+ return STORKSchedulerOutput(prev_sample=img_next)
1151
+
1152
+
1153
+ def step_noise_4(
1154
+ self,
1155
+ model_output: torch.Tensor,
1156
+ timestep: Union[int, torch.Tensor],
1157
+ sample: torch.Tensor = None,
1158
+ return_dict: bool = False,
1159
+ ) -> torch.Tensor:
1160
+ '''
1161
+ One step of the STORK4 update for noise-based diffusion models.
1162
+
1163
+ Args:
1164
+ model_output (`torch.FloatTensor`):
1165
+ The direct output from learned diffusion model.
1166
+ timestep (`float`):
1167
+ The current discrete timestep in the diffusion chain.
1168
+ sample (`torch.FloatTensor`):
1169
+ A current instance of a sample created by the diffusion process.
1170
+ return_dict (`bool`, defaults to `True`):
1171
+ Whether or not to return a [`~schedulers.STORKSchedulerOutput`] instead of a plain tuple.
1172
+
1173
+ Returns:
1174
+ `torch.FloatTensor`: The next sample in the diffusion chain.
1175
+ '''
1176
+
1177
+
1178
+
1179
+ # Initialize the step index if it's the first step
1180
+ if self._step_index is None:
1181
+ self._step_index = 0
1182
+ self.initial_noise = model_output
1183
+
1184
+
1185
+ total_step = self.config.num_train_timesteps
1186
+ t = self.timesteps[self._step_index] / total_step
1187
+
1188
+ beta_0, beta_1 = self.betas[0], self.betas[-1]
1189
+ t_start = torch.ones(model_output.shape, device=sample.device) * t
1190
+ beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step
1191
+ log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step
1192
+ std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
1193
+
1194
+ # Tweedie's trick
1195
+ if self._step_index == len(self.timesteps) - 1:
1196
+ noise_last = model_output
1197
+ img_next = sample - std * noise_last
1198
+ if not return_dict:
1199
+ return (img_next,)
1200
+ return STORKSchedulerOutput(prev_sample=img_next)
1201
+
1202
+ t_next = self.timesteps[self._step_index + 1] / total_step
1203
+
1204
+ # drift, diffusion -> f(x,t), g(t)
1205
+ drift_initial, diffusion_initial = -0.5 * beta_t * sample, torch.sqrt(beta_t) * torch.ones(sample.shape, device=sample.device)
1206
+ noise_initial = model_output
1207
+ score = -noise_initial / std # score -> noise
1208
+ drift_initial = drift_initial - diffusion_initial ** 2 * score * 0.5 # drift -> dx/dt
1209
+
1210
+
1211
+ dt = torch.ones(model_output.shape, device=sample.device) * self.dt
1212
+
1213
+
1214
+ if self.derivative_order == 2:
1215
+ if self._step_index == 0:
1216
+ # Initial Euler update
1217
+ self.initial_sample = sample
1218
+ img_next = sample - dt * drift_initial
1219
+
1220
+ self.noise_predictions.append(noise_initial)
1221
+ self._step_index += 1
1222
+
1223
+ self.initial_drift = drift_initial
1224
+
1225
+ if not return_dict:
1226
+ return (img_next,)
1227
+ return SchedulerOutput(prev_sample=img_next)
1228
+ elif self._step_index == 1:
1229
+ # Initial 2-step Adams-Bashforth update
1230
+ drift_previous = self.initial_drift
1231
+
1232
+ img_next = sample - 1.5 * dt * drift_initial + 0.5 * dt * drift_previous
1233
+
1234
+ self.noise_predictions.append(noise_initial)
1235
+ self._step_index += 1
1236
+
1237
+ if not return_dict:
1238
+ return (img_next,)
1239
+ return SchedulerOutput(prev_sample=img_next)
1240
+ else:
1241
+ # STORK4 update
1242
+ h = dt
1243
+
1244
+ # The first derivative is calculated using the three point approximation,
1245
+ # and the second derivative is calculated using the standardtwo point approximation.
1246
+ noise_derivative = (-self.noise_predictions[-2] + 4 * self.noise_predictions[-1] - 3 * noise_initial) / (2 * h)
1247
+ noise_second_derivative = (self.noise_predictions[-2] - 2 * self.noise_predictions[-1] + noise_initial) / h**2
1248
+ noise_third_derivative = None
1249
+
1250
+ self.noise_predictions.append(noise_initial)
1251
+ noise_approx_order = 2
1252
+ elif self.derivative_order == 1:
1253
+ if self._step_index == 0:
1254
+ # Initial Euler update
1255
+ self.initial_sample = sample
1256
+ img_next = sample - dt * drift_initial
1257
+
1258
+ self.noise_predictions.append(noise_initial)
1259
+ self._step_index += 1
1260
+
1261
+ self.initial_drift = drift_initial
1262
+
1263
+ if not return_dict:
1264
+ return (img_next,)
1265
+ return SchedulerOutput(prev_sample=img_next)
1266
+ else:
1267
+ # STORK4 update
1268
+ h = dt
1269
+
1270
+ noise_derivative = (self.noise_predictions[-1] - noise_initial) / h
1271
+ noise_second_derivative = None
1272
+ noise_third_derivative = None
1273
+
1274
+ self.noise_predictions.append(noise_initial)
1275
+ noise_approx_order = 1
1276
+ else:
1277
+ raise ValueError(f"Unknown derivative order: {self.derivative_order}")
1278
+
1279
+
1280
+ Y_j_2 = sample
1281
+ Y_j_1 = sample
1282
+ Y_j = sample
1283
+
1284
+ ci1 = t_start
1285
+ ci2 = t_start
1286
+ ci3 = t_start
1287
+
1288
+ # Coefficients of ROCK4
1289
+ ms, fpa, fpb, fpbe, recf = self.coeff_rock4()
1290
+ # Choose the degree that's in the precomputed table
1291
+ mdeg, mp = self.mdegr(self.s, ms)
1292
+ mz = int(mp[0])
1293
+ mr = int(mp[1])
1294
+
1295
+ '''
1296
+ The first part of the STORK4 update
1297
+ '''
1298
+ for j in range(1, mdeg + 1):
1299
+
1300
+ # First sub-step in the first part of the STORK4 update
1301
+ if j == 1:
1302
+ temp1 = -(t - t_next) * recf[mr] * torch.ones(model_output.shape, device=sample.device)
1303
+ ci1 = t_start + temp1
1304
+ ci2 = ci1
1305
+ Y_j_1 = sample + temp1 * model_output #subver
1306
+
1307
+ # drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, t_start, Y_j, model_output)
1308
+ # Y_j_1 = sample + temp1 * drift_approx
1309
+
1310
+ # Second and the following sub-steps in the first part of the STORK4 update
1311
+ else:
1312
+ diff = ci1 - t_start
1313
+ noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1314
+ drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j_1, noise_approx)
1315
+
1316
+ temp1 = -(t - t_next) * recf[mr + 2 * (j-2) + 1] * torch.ones(model_output.shape, device=sample.device)
1317
+ temp3 = -recf[mr + 2 * (j-2) + 2] * torch.ones(model_output.shape, device=sample.device)
1318
+ temp2 = torch.ones(model_output.shape, device=sample.device) - temp3
1319
+
1320
+ ci1 = temp1 + temp2 * ci2 + temp3 * ci3
1321
+ Y_j = temp1 * drift_approx + temp2 * Y_j_1 + temp3 * Y_j_2
1322
+
1323
+ # Update the intermediate variables
1324
+ Y_j_2 = Y_j_1
1325
+ Y_j_1 = Y_j
1326
+
1327
+ ci3 = ci2
1328
+ ci2 = ci1
1329
+
1330
+ '''
1331
+ The finishing four-step procedure as a composition method
1332
+ '''
1333
+ # First finishing step
1334
+ temp1 = -(t - t_next) * fpa[mz,0] * torch.ones(model_output.shape, device=sample.device)
1335
+ diff = ci1 - t_start
1336
+ noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1337
+ drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci1, Y_j, noise_approx)
1338
+ Y_j_1 = drift_approx
1339
+ Y_j_3 = Y_j + temp1 * Y_j_1
1340
+
1341
+ # Second finishing step
1342
+ ci2 = ci1 + temp1
1343
+ temp1 = -(t - t_next) * fpa[mz,1] * torch.ones(model_output.shape, device=sample.device)
1344
+ temp2 = -(t - t_next) * fpa[mz,2] * torch.ones(model_output.shape, device=sample.device)
1345
+ diff = ci2 - t_start
1346
+ noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1347
+ drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_3, noise_approx)
1348
+ Y_j_2 = drift_approx
1349
+ Y_j_4 = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2
1350
+
1351
+ # Third finishing step
1352
+ ci2 = ci1 + temp1 + temp2
1353
+ temp1 = -(t - t_next) * fpa[mz,3] * torch.ones(model_output.shape, device=sample.device)
1354
+ temp2 = -(t - t_next) * fpa[mz,4] * torch.ones(model_output.shape, device=sample.device)
1355
+ temp3 = -(t - t_next) * fpa[mz,5] * torch.ones(model_output.shape, device=sample.device)
1356
+ diff = ci2 - t_start
1357
+ noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1358
+ drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, Y_j_4, noise_approx)
1359
+ Y_j_3 = drift_approx
1360
+ fnt = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3
1361
+
1362
+ # Fourth finishing step
1363
+ ci2 = ci1 + temp1 + temp2 + temp3
1364
+ temp1 = -(t - t_next) * fpb[mz,0] * torch.ones(model_output.shape, device=sample.device)
1365
+ temp2 = -(t - t_next) * fpb[mz,1] * torch.ones(model_output.shape, device=sample.device)
1366
+ temp3 = -(t - t_next) * fpb[mz,2] * torch.ones(model_output.shape, device=sample.device)
1367
+ temp4 = -(t - t_next) * fpb[mz,3] * torch.ones(model_output.shape, device=sample.device)
1368
+ diff = ci2 - t_start
1369
+ noise_approx = self.taylor_approximation(noise_approx_order, diff, model_output, noise_derivative, noise_second_derivative, noise_third_derivative)
1370
+ drift_approx = self.drift_function(self.betas, self.config.num_train_timesteps, ci2, fnt, noise_approx)
1371
+ Y_j_4 = drift_approx
1372
+ Y_j = Y_j + temp1 * Y_j_1 + temp2 * Y_j_2 + temp3 * Y_j_3 + temp4 * Y_j_4
1373
+
1374
+
1375
+
1376
+ img_next = Y_j
1377
+ self._step_index += 1
1378
+
1379
+ if not return_dict:
1380
+ return (img_next,)
1381
+ return STORKSchedulerOutput(prev_sample=img_next)
1382
+
1383
+
1384
+
1385
+
1386
+ def __len__(self):
1387
+ return self.config.num_train_timesteps
1388
+
1389
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1390
+ """
1391
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
1392
+ current timestep.
1393
+
1394
+ Args:
1395
+ sample (`torch.Tensor`):
1396
+ The input sample.
1397
+
1398
+ Returns:
1399
+ `torch.Tensor`:
1400
+ A scaled input sample.
1401
+ """
1402
+ return sample
1403
+
1404
+ def add_noise(
1405
+ self,
1406
+ original_samples: torch.FloatTensor,
1407
+ noise: torch.FloatTensor,
1408
+ timesteps: torch.IntTensor,
1409
+ ) -> torch.FloatTensor:
1410
+ """
1411
+ Add noise to the original samples according to the noise magnitude at the given timestep.
1412
+
1413
+ Args:
1414
+ original_samples (`torch.FloatTensor`):
1415
+ The original samples.
1416
+ noise (`torch.FloatTensor`):
1417
+ The noise to add.
1418
+ timesteps (`torch.IntTensor`):
1419
+ The timesteps for which to add noise.
1420
+
1421
+ Returns:
1422
+ `torch.FloatTensor`:
1423
+ The noisy samples.
1424
+ """
1425
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1426
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
1427
+ timesteps = timesteps.to(original_samples.device)
1428
+
1429
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1430
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1431
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1432
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1433
+
1434
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1435
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1436
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1437
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1438
+
1439
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1440
+ return noisy_samples
1441
+
1442
+ def get_velocity(
1443
+ self,
1444
+ sample: torch.FloatTensor,
1445
+ noise: torch.FloatTensor,
1446
+ timesteps: torch.IntTensor,
1447
+ ) -> torch.FloatTensor:
1448
+ """
1449
+ Get the velocity (score) for the given sample, noise, and timesteps.
1450
+
1451
+ Args:
1452
+ sample (`torch.FloatTensor`):
1453
+ The sample.
1454
+ noise (`torch.FloatTensor`):
1455
+ The noise.
1456
+ timesteps (`torch.IntTensor`):
1457
+ The timesteps.
1458
+
1459
+ Returns:
1460
+ `torch.FloatTensor`:
1461
+ The velocity.
1462
+ """
1463
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
1464
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
1465
+ timesteps = timesteps.to(sample.device)
1466
+
1467
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1468
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1469
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
1470
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1471
+
1472
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1473
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1474
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
1475
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1476
+
1477
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
1478
+ return velocity
1479
+
1480
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
1481
+ if self.config.time_shift_type == "exponential":
1482
+ return self._time_shift_exponential(mu, sigma, t)
1483
+ elif self.config.time_shift_type == "linear":
1484
+ return self._time_shift_linear(mu, sigma, t)
1485
+
1486
+ def _time_shift_exponential(self, mu, sigma, t):
1487
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
1488
+
1489
+ def _time_shift_linear(self, mu, sigma, t):
1490
+ return mu / (mu + (1 / t - 1) ** sigma)
1491
+
1492
+ def taylor_approximation(self, taylor_approx_order, diff, model_output, derivative, second_derivative, third_derivative=None):
1493
+ if taylor_approx_order == 1:
1494
+ approx_value = model_output + diff * derivative
1495
+ elif taylor_approx_order == 2:
1496
+ if third_derivative is not None:
1497
+ raise ValueError("The third derivative is computed but not used!")
1498
+ approx_value = model_output + diff * derivative + 0.5 * diff**2 * second_derivative
1499
+ elif taylor_approx_order == 3:
1500
+ if third_derivative is None:
1501
+ raise ValueError("The third derivative is not computed!")
1502
+ approx_value = model_output + diff * derivative + 0.5 * diff**2 * second_derivative \
1503
+ + diff**3 * third_derivative / 6
1504
+ else:
1505
+ print("The noise approximation order is not supported!")
1506
+ exit()
1507
+
1508
+ return approx_value
1509
+
1510
+
1511
+ def drift_function(self, betas, total_step, t_eval, y_eval, noise):
1512
+ '''
1513
+ Drift function for the probability flow ODE in the noise-based diffusion model.
1514
+
1515
+ Args:
1516
+ betas (`torch.FloatTensor`):
1517
+ The betas of the diffusion model.
1518
+ total_step (`int`):
1519
+ The total number of steps in the diffusion chain.
1520
+ t_eval (`torch.FloatTensor`):
1521
+ The timestep to be evaluated at in the diffusion chain.
1522
+ y_eval (`torch.FloatTensor`):
1523
+ The sample to be evaluated at in the diffusion chain.
1524
+ noise (`torch.FloatTensor`):
1525
+ The noise used at the current timestep in the diffusion chain.
1526
+
1527
+ Returns:
1528
+ `torch.FloatTensor`:
1529
+ The drift term for the probability flow ODE in the diffusion model.
1530
+ '''
1531
+ beta_0, beta_1 = betas[0], betas[-1]
1532
+ beta_t = (beta_0 + t_eval * (beta_1 - beta_0)) * total_step
1533
+ beta_t = beta_t * torch.ones(y_eval.shape, device=y_eval.device)
1534
+
1535
+ log_mean_coeff = (-0.25 * t_eval ** 2 * (beta_1 - beta_0) - 0.5 * t_eval * beta_0) * total_step
1536
+ std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
1537
+
1538
+ # drift, diffusion -> f(x,t), g(t)
1539
+ drift, diffusion = -0.5 * beta_t * y_eval, torch.sqrt(beta_t) * torch.ones(y_eval.shape, device=y_eval.device)
1540
+ score = -noise / std # score -> noise
1541
+ drift = drift - diffusion ** 2 * score * 0.5 # drift -> dx/dt
1542
+
1543
+ return drift
1544
+
1545
+ def b_coeff(self, j):
1546
+ '''
1547
+ Coefficients of STORK2. The are based on the second order Runge-Kutta-Gegenbauer method.
1548
+ Details of the coefficients can be found in https://www.sciencedirect.com/science/article/pii/S0021999120306537
1549
+
1550
+ Args:
1551
+ j (`int`):
1552
+ The sub-step index of the coefficient.
1553
+
1554
+ Returns:
1555
+ `float`:
1556
+ The coefficient of the STORK2.
1557
+ '''
1558
+ if j < 0:
1559
+ print("The b_j coefficient in the RKG method can't have j negative")
1560
+ return
1561
+ if j == 0:
1562
+ return 1
1563
+ if j == 1:
1564
+ return 1 / 3
1565
+
1566
+ return 4 * (j - 1) * (j + 4) / (3 * j * (j + 1) * (j + 2) * (j + 3))
1567
+
1568
+ def coeff_rock1(self, j):
1569
+ if j < 0:
1570
+ print("The b_j coefficient in the RKG method can't have j negative")
1571
+ return 2 / ((j + 1) * (j + 2))
1572
+
1573
+ def coeff_rock4(self):
1574
+ '''
1575
+ Load pre-computed coefficients of STORK4. The are based on the fourth order orthogonal Runge-Kutta-Chebyshev (ROCK4) method.
1576
+ Details of the coefficients can be found in https://epubs.siam.org/doi/abs/10.1137/S1064827500379549.
1577
+ The pre-computed coefficients are based on the implementation https://www.mathworks.com/matlabcentral/fileexchange/12129-rock4.
1578
+
1579
+ Args:
1580
+ j (`int`):
1581
+ The sub-step index of the coefficient.
1582
+
1583
+ Returns:
1584
+ ms (`torch.FloatTensor`):
1585
+ The degrees that coefficients were pre-computed for STORK4.
1586
+ fpa, fpb, fpbe, recf (`torch.FloatTensor`):
1587
+ The parameters for the finishing procedure.
1588
+ '''
1589
+ # Degrees
1590
+ data = loadmat(f'{CONSTANTSFOLDER}/ms.mat')
1591
+ ms = data['ms'][0]
1592
+
1593
+ # Parameters for the finishing procedure
1594
+ data = loadmat(f'{CONSTANTSFOLDER}/fpa.mat')
1595
+ fpa = data['fpa']
1596
+
1597
+ data = loadmat(f'{CONSTANTSFOLDER}/fpb.mat')
1598
+ fpb = data['fpb']
1599
+
1600
+ data = loadmat(f'{CONSTANTSFOLDER}/fpbe.mat')
1601
+ fpbe = data['fpbe']
1602
+
1603
+ # Parameters for the recurrence procedure
1604
+ data = loadmat(f'{CONSTANTSFOLDER}/recf.mat')
1605
+ recf = data['recf'][0]
1606
+
1607
+ return ms, fpa, fpb, fpbe, recf
1608
+
1609
+
1610
+
1611
+ def mdegr(self, mdeg1, ms):
1612
+ '''
1613
+ Find the optimal degree in the pre-computed degree coefficients table for the STORK4 method.
1614
+
1615
+ Args:
1616
+ mdeg1 (`int`):
1617
+ The degree to be evaluated.
1618
+ ms (`torch.FloatTensor`):
1619
+ The degrees that coefficients were pre-computed for STORK4.
1620
+
1621
+ Returns:
1622
+ mdeg (`int`):
1623
+ The optimal degree in the pre-computed degree coefficients table for the STORK4 method.
1624
+ mp (`torch.FloatTensor`):
1625
+ The pointer which select the degree in ms[i], such that mdeg<=ms[i].
1626
+ mp[0] (`int`): The pointer which select the degree in ms[i], such that mdeg<=ms[i].
1627
+ mp[1] (`int`): The pointer which gives the corresponding position of a_1 in the data recf for the selected degree.
1628
+ '''
1629
+ mp = torch.zeros(2)
1630
+ mp[1] = 1
1631
+ mdeg = mdeg1
1632
+ for i in range(len(ms)):
1633
+ if (ms[i]/mdeg) >= 1:
1634
+ mdeg = ms[i]
1635
+ mp[0] = i
1636
+ mp[1] = mp[1] - 1
1637
+ break
1638
+ else:
1639
+ mp[1] = mp[1] + ms[i] * 2 - 1
1640
+
1641
+ return mdeg, mp
model_index.json ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "BoomerPipeline",
3
+ "_diffusers_version": "0.33.0",
4
+ "_boomer_version": "1.0.0",
5
+ "transformer": {
6
+ "type": "BoomerFLADiT",
7
+ "config": "transformer/config.json",
8
+ "weights": "transformer/diffusion_pytorch_model.safetensors"
9
+ },
10
+ "vae": {
11
+ "repo_id": "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
12
+ "scaling_factor": 0.41407
13
+ },
14
+ "text_encoder": {
15
+ "repo_id": "google/gemma-4-E2B-it",
16
+ "max_length": 300
17
+ },
18
+ "scheduler": "scheduler/scheduler_config.json",
19
+ "latent_normalization": {
20
+ "mean": [
21
+ -0.0492647976627326,
22
+ 0.1468768895561574,
23
+ 0.04348348892886716,
24
+ 0.024480677381432905,
25
+ -0.36225658200088656,
26
+ -0.20211585707241028,
27
+ 0.14117315317920875,
28
+ -0.10090931608978718,
29
+ 0.03270434805644763,
30
+ 0.025280784072185633,
31
+ 0.42561436599740116,
32
+ 0.07644308823255123,
33
+ -0.12361726209519984,
34
+ 0.2738135117924219,
35
+ 0.027258959344701725,
36
+ -0.15685215533194713,
37
+ 0.1988568681778619,
38
+ 0.07728443773748003,
39
+ 0.21031734156716703,
40
+ 0.10236920059442055,
41
+ -0.26953907125577387,
42
+ -0.1039037490181443,
43
+ -0.14040348514520445,
44
+ -0.050237464451169944,
45
+ 0.21928026529320632,
46
+ 0.05541749411464261,
47
+ 0.15868418162302406,
48
+ 0.09498460353222035,
49
+ 0.07154154194705771,
50
+ -0.0980392861411281,
51
+ 0.3445218162967998,
52
+ 0.14452621160838316
53
+ ],
54
+ "std": [
55
+ 0.7239755227946175,
56
+ 0.7084603356016493,
57
+ 0.7371127244335353,
58
+ 0.7148677155404667,
59
+ 0.7610612675902568,
60
+ 0.7831300251777134,
61
+ 1.241222644947736,
62
+ 1.1914623118386434,
63
+ 0.7064647426283694,
64
+ 1.0233179582088132,
65
+ 0.7671679694251226,
66
+ 0.6818639786525276,
67
+ 0.7394871026815577,
68
+ 0.6749445490371844,
69
+ 0.7961588737844489,
70
+ 0.7955142324161893,
71
+ 0.7545916153429181,
72
+ 0.7799818111961734,
73
+ 0.706798939521899,
74
+ 0.7014546090493033,
75
+ 0.9678039884252744,
76
+ 0.7504288798344418,
77
+ 0.7296232257036755,
78
+ 0.7257654508983634,
79
+ 2.2632974219950786,
80
+ 0.8916002210501063,
81
+ 0.8534945911823539,
82
+ 0.7403039593986197,
83
+ 0.7264856936752643,
84
+ 0.6879722344092367,
85
+ 0.7331094494058464,
86
+ 0.6896992616885751
87
+ ]
88
+ },
89
+ "training_info": {
90
+ "dataset": "FineT2IcacheBF16_1024",
91
+ "base_model": "journeydb-pretrained-boomer-fla",
92
+ "image_size_px": 1024,
93
+ "patch_size": 1,
94
+ "latent_size": 32,
95
+ "latent_tokens": 1024,
96
+ "steps_finetune": 55000,
97
+ "steps_configured": 75000,
98
+ "batch_size": 24,
99
+ "lr": 0.0001,
100
+ "lr_scheduler": "linear-warmup-cosine",
101
+ "warmup_steps": 1000,
102
+ "min_lr_ratio": 0.1,
103
+ "flow_shift": 1.5,
104
+ "t_sampler": "plateau-logit-normal",
105
+ "logit_mean": 0.0,
106
+ "logit_std": 1.0,
107
+ "min_snr_gamma": 5.0,
108
+ "cond_dropout": 0.1,
109
+ "ema_decay": 0.999,
110
+ "ema_update_every": 8,
111
+ "grad_clip": 0.3,
112
+ "latent_stats_mode": "channel"
113
+ }
114
+ }
modeling_boomer_fla.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BoomerFLADiT model β€” self-contained for HuggingFace trust_remote_code distribution.
2
+
3
+ All dependencies inlined: no boomer package import needed.
4
+ External pip requirements: torch, flash-linear-attention (fla).
5
+ """
6
+ # ── inlined from boomer/models/latent_dit.py ──────────────────────────────────
7
+ from __future__ import annotations
8
+ import math
9
+ import sys
10
+ import types
11
+ from dataclasses import dataclass
12
+ from pathlib import Path
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.utils.checkpoint import checkpoint as _ckpt
18
+
19
+
20
+ class AttentionRMSNorm(nn.Module):
21
+ def __init__(self, dim: int, scale_factor: float = 0.01, eps: float = 1e-6) -> None:
22
+ super().__init__()
23
+ self.eps = eps
24
+ self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
25
+
26
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
27
+ normed = x.float() * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
28
+ weight = self.weight.view(*([1] * (x.ndim - 2)), -1)
29
+ return (weight * normed).type_as(x)
30
+
31
+
32
+ class CaptionEmbedder(nn.Module):
33
+ def __init__(self, in_channels: int, hidden_size: int, token_num: int) -> None:
34
+ super().__init__()
35
+ self.y_proj = nn.Sequential(
36
+ nn.Linear(in_channels, hidden_size),
37
+ nn.GELU(approximate="tanh"),
38
+ nn.Linear(hidden_size, hidden_size),
39
+ )
40
+ null_init = torch.randn(token_num, in_channels) / math.sqrt(in_channels)
41
+ self.null_text_embedding = nn.Parameter(null_init.unsqueeze(0))
42
+
43
+ def forward(self, caption: torch.Tensor) -> torch.Tensor:
44
+ return self.y_proj(caption)
45
+
46
+ def null_condition(self, batch_size, *, device, dtype, mask_dtype=None, token_num=None):
47
+ text = self.null_text_embedding
48
+ if token_num is not None and token_num != text.shape[1]:
49
+ if token_num < text.shape[1]:
50
+ text = text[:, :token_num]
51
+ else:
52
+ pad = text.new_zeros(text.shape[0], token_num - text.shape[1], text.shape[2])
53
+ text = torch.cat([text, pad], dim=1)
54
+ text = text.expand(batch_size, -1, -1).to(device=device, dtype=dtype)
55
+ mask = torch.ones(batch_size, text.shape[1], device=device, dtype=mask_dtype or torch.long)
56
+ if token_num is not None and token_num > self.null_text_embedding.shape[1]:
57
+ mask[:, self.null_text_embedding.shape[1]:] = 0
58
+ return text, mask
59
+
60
+
61
+ class TimestepEmbedder(nn.Module):
62
+ def __init__(self, hidden_dim: int) -> None:
63
+ super().__init__()
64
+ self.net = nn.Sequential(nn.Linear(1, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, hidden_dim))
65
+
66
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
67
+ dtype = self.net[0].weight.dtype
68
+ return self.net(timesteps.to(dtype=dtype).view(-1, 1))
69
+
70
+
71
+ # ── rest of boomer_fla_dit.py below (unchanged except no boomer imports) ──────
72
+
73
+
74
+
75
+ @dataclass(frozen=True)
76
+ class BoomerFLADiTConfig:
77
+ model_type: str = "boomer_fla"
78
+ latent_channels: int = 32
79
+ latent_size: int = 16
80
+ text_dim: int = 1536
81
+ text_seq_len: int = 300
82
+ hidden_dim: int = 1152
83
+ depth: int = 28
84
+ num_heads: int = 16
85
+ mlp_ratio: float = 2.5
86
+ y_norm: bool = True
87
+ y_norm_scale_factor: float = 0.01
88
+ mixer_type: str = "fla_linear"
89
+ fla_mode: str = "chunk"
90
+ fla_feature_map: str = "relu"
91
+ fla_bidirectional: bool = False
92
+ use_short_conv: bool = False
93
+ conv_size: int = 4
94
+ image_attention_every: int = 0
95
+ image_attention_backend: str = "sdpa"
96
+ image_attention_rope: bool = False
97
+ image_rope_theta: float = 10000.0
98
+ cross_attention_backend: str = "sdpa"
99
+ cross_attention_qk_norm: bool = True
100
+ parallel_block: bool = False
101
+ dual_stream_depth: int = 0
102
+ multimodal_coord_ids: bool = False
103
+ use_abs_pos_embed: bool = True
104
+ patch_size: int = 1
105
+ gradient_checkpointing: bool = False
106
+
107
+
108
+ def maybe_add_sibling_fla_repo() -> None:
109
+ candidates = [
110
+ Path(__file__).resolve().parents[3] / "flash-linear-attention",
111
+ Path("/content/flash-linear-attention"),
112
+ Path("/content/flame"),
113
+ ]
114
+ for path in candidates:
115
+ if (path / "fla").is_dir() and str(path) not in sys.path:
116
+ sys.path.insert(0, str(path))
117
+
118
+
119
+ def maybe_add_sibling_flash_attention_repo() -> None:
120
+ candidates = [
121
+ Path(__file__).resolve().parents[3] / "flash-attention" / "hopper",
122
+ Path(__file__).resolve().parents[3] / "flash-attention",
123
+ Path("/work/flash-attention/hopper"),
124
+ Path("/work/flash-attention"),
125
+ Path("/home/jovyan/work/flash-attention"),
126
+ Path("/content/flash-attention/hopper"),
127
+ Path("/content/flash-attention"),
128
+ ]
129
+ for path in candidates:
130
+ if path.exists() and str(path) not in sys.path:
131
+ sys.path.insert(0, str(path))
132
+
133
+
134
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
135
+ return x * (1.0 + scale) + shift
136
+
137
+
138
+ class ConvLayer(nn.Module):
139
+ def __init__(
140
+ self,
141
+ in_dim: int,
142
+ out_dim: int,
143
+ kernel_size: int,
144
+ *,
145
+ groups: int = 1,
146
+ bias: bool = False,
147
+ act: str | None = None,
148
+ ) -> None:
149
+ super().__init__()
150
+ self.conv = nn.Conv2d(
151
+ in_dim,
152
+ out_dim,
153
+ kernel_size=kernel_size,
154
+ padding=kernel_size // 2,
155
+ groups=groups,
156
+ bias=bias,
157
+ )
158
+ self.act = nn.SiLU() if act == "silu" else nn.Identity()
159
+
160
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
161
+ return self.act(self.conv(x))
162
+
163
+
164
+ class GLUMBConv(nn.Module):
165
+ """Sana GLUMBConv FFN: 1x1 expand, depthwise spatial conv, GLU, 1x1 project."""
166
+
167
+ def __init__(self, hidden_dim: int, mlp_ratio: float) -> None:
168
+ super().__init__()
169
+ inner_dim = int(hidden_dim * mlp_ratio)
170
+ self.inner_dim = inner_dim
171
+ self.inverted_conv = ConvLayer(hidden_dim, inner_dim * 2, 1, bias=True, act="silu")
172
+ self.depth_conv = ConvLayer(inner_dim * 2, inner_dim * 2, 3, groups=inner_dim * 2, bias=True)
173
+ self.point_conv = ConvLayer(inner_dim, hidden_dim, 1, bias=False)
174
+ nn.init.zeros_(self.point_conv.conv.weight)
175
+ self.glu_act = nn.SiLU()
176
+
177
+ def forward(self, x: torch.Tensor, *, height: int, width: int) -> torch.Tensor:
178
+ batch, tokens, channels = x.shape
179
+ if tokens != height * width:
180
+ raise ValueError(f"Expected {height * width} image tokens, got {tokens}")
181
+ x = x.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
182
+ x = self.inverted_conv(x)
183
+ x = self.depth_conv(x)
184
+ x, gate = x.chunk(2, dim=1)
185
+ x = x * self.glu_act(gate)
186
+ x = self.point_conv(x)
187
+ return x.reshape(batch, channels, tokens).transpose(1, 2).contiguous()
188
+
189
+
190
+ class TorchSelfAttention(nn.Module):
191
+ def __init__(self, hidden_dim: int, num_heads: int) -> None:
192
+ super().__init__()
193
+ self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
194
+
195
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
196
+ return self.attn(x, x, x, need_weights=False)[0]
197
+
198
+
199
+ class TokenMLP(nn.Module):
200
+ def __init__(self, hidden_dim: int, mlp_ratio: float) -> None:
201
+ super().__init__()
202
+ inner_dim = int(hidden_dim * mlp_ratio)
203
+ self.net = nn.Sequential(
204
+ nn.Linear(hidden_dim, inner_dim),
205
+ nn.GELU(approximate="tanh"),
206
+ nn.Linear(inner_dim, hidden_dim),
207
+ )
208
+
209
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
210
+ return self.net(x)
211
+
212
+
213
+ class MultimodalCoordinateRoPE(nn.Module):
214
+ """FLUX-style coordinate-ID RoPE for joint text/image attention."""
215
+
216
+ def __init__(self, head_dim: int, *, image_size: int, text_seq_len: int, theta: float = 10000.0) -> None:
217
+ super().__init__()
218
+ if head_dim < 6 or head_dim % 2 != 0:
219
+ raise ValueError(f"head_dim={head_dim} must be even and at least 6 for multimodal RoPE")
220
+ if theta <= 0.0:
221
+ raise ValueError(f"theta must be positive, got {theta}")
222
+ type_dim = max(2, (head_dim // 4) // 2 * 2)
223
+ while type_dim > 2 and (head_dim - type_dim) % 4 != 0:
224
+ type_dim -= 2
225
+ remaining = head_dim - type_dim
226
+ row_dim = max(2, (remaining // 2) // 2 * 2)
227
+ col_dim = remaining - row_dim
228
+ if col_dim < 2 or col_dim % 2 != 0:
229
+ raise ValueError(f"could not split head_dim={head_dim} into even multimodal RoPE axes")
230
+ self.axes_dim = (type_dim, row_dim, col_dim)
231
+ self.head_dim = head_dim
232
+ self.image_size = image_size
233
+ self.text_seq_len = text_seq_len
234
+
235
+ for index, dim in enumerate(self.axes_dim):
236
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
237
+ self.register_buffer(f"inv_freq_{index}", inv_freq, persistent=False)
238
+
239
+ @staticmethod
240
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
241
+ x1, x2 = x.chunk(2, dim=-1)
242
+ return torch.cat((-x2, x1), dim=-1)
243
+
244
+ def image_ids(self, batch_size: int, *, height: int, width: int, device: torch.device | str) -> torch.Tensor:
245
+ token_idx = torch.arange(height * width, device=device)
246
+ rows = token_idx // width
247
+ cols = token_idx % width
248
+ token_type = torch.ones_like(rows)
249
+ ids = torch.stack([token_type, rows, cols], dim=-1)
250
+ return ids.unsqueeze(0).expand(batch_size, -1, -1)
251
+
252
+ def text_ids(self, batch_size: int, token_count: int, *, device: torch.device | str) -> torch.Tensor:
253
+ positions = torch.arange(token_count, device=device)
254
+ token_type = torch.zeros_like(positions)
255
+ zeros = torch.zeros_like(positions)
256
+ ids = torch.stack([token_type, positions, zeros], dim=-1)
257
+ return ids.unsqueeze(0).expand(batch_size, -1, -1)
258
+
259
+ def _axis_apply(self, x: torch.Tensor, axis_ids: torch.Tensor, axis_index: int) -> torch.Tensor:
260
+ inv_freq = getattr(self, f"inv_freq_{axis_index}")
261
+ angles = axis_ids.float().unsqueeze(-1) * inv_freq.to(device=x.device).view(1, 1, -1)
262
+ cos = torch.cat([angles.cos(), angles.cos()], dim=-1).unsqueeze(2).to(dtype=x.dtype)
263
+ sin = torch.cat([angles.sin(), angles.sin()], dim=-1).unsqueeze(2).to(dtype=x.dtype)
264
+ return x * cos + self._rotate_half(x) * sin
265
+
266
+ def apply(
267
+ self,
268
+ q: torch.Tensor,
269
+ k: torch.Tensor,
270
+ ids: torch.Tensor,
271
+ ) -> tuple[torch.Tensor, torch.Tensor]:
272
+ if q.shape[-1] != self.head_dim or k.shape[-1] != self.head_dim:
273
+ raise ValueError(f"expected head_dim={self.head_dim}, got q={q.shape[-1]} k={k.shape[-1]}")
274
+ if ids.shape[:2] != q.shape[:2] or ids.shape[-1] != len(self.axes_dim):
275
+ raise ValueError(f"expected ids shape (B, T, {len(self.axes_dim)}), got {tuple(ids.shape)}")
276
+ q_chunks = q.split(self.axes_dim, dim=-1)
277
+ k_chunks = k.split(self.axes_dim, dim=-1)
278
+ q_out = []
279
+ k_out = []
280
+ for index, (q_axis, k_axis) in enumerate(zip(q_chunks, k_chunks, strict=True)):
281
+ q_out.append(self._axis_apply(q_axis, ids[..., index], index))
282
+ k_out.append(self._axis_apply(k_axis, ids[..., index], index))
283
+ return torch.cat(q_out, dim=-1), torch.cat(k_out, dim=-1)
284
+
285
+
286
+ class RoPE2D(nn.Module):
287
+ """2D RoPE for image tokens on a fixed HΓ—W grid (row-major flattening).
288
+
289
+ Splits head_dim in half: the first half encodes height, the second width.
290
+ Each half uses standard 1D RoPE with shared cos/sin tables per axis.
291
+ """
292
+
293
+ def __init__(self, head_dim: int, grid_size: int, *, theta: float = 10000.0) -> None:
294
+ super().__init__()
295
+ if head_dim % 4 != 0:
296
+ raise ValueError(
297
+ f"head_dim={head_dim} must be divisible by 4 for 2D RoPE "
298
+ f"(half for H, half for W, each needing pairs)"
299
+ )
300
+ if grid_size <= 0:
301
+ raise ValueError(f"grid_size must be positive, got {grid_size}")
302
+ if theta <= 0.0:
303
+ raise ValueError(f"theta must be positive, got {theta}")
304
+ self.head_dim = head_dim
305
+ self.grid_size = grid_size
306
+ self.half_dim = head_dim // 2
307
+
308
+ freqs = 1.0 / (theta ** (torch.arange(0, self.half_dim, 2).float() / self.half_dim))
309
+ token_idx = torch.arange(grid_size * grid_size)
310
+ h_idx = token_idx // grid_size
311
+ w_idx = token_idx % grid_size
312
+
313
+ def axis_tables(pos_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
314
+ angles = torch.outer(pos_idx.float(), freqs)
315
+ cos = torch.cat([angles.cos(), angles.cos()], dim=-1)[None, :, None, :]
316
+ sin = torch.cat([angles.sin(), angles.sin()], dim=-1)[None, :, None, :]
317
+ return cos, sin
318
+
319
+ cos_h, sin_h = axis_tables(h_idx)
320
+ cos_w, sin_w = axis_tables(w_idx)
321
+ self.register_buffer("cos_h", cos_h, persistent=False)
322
+ self.register_buffer("sin_h", sin_h, persistent=False)
323
+ self.register_buffer("cos_w", cos_w, persistent=False)
324
+ self.register_buffer("sin_w", sin_w, persistent=False)
325
+
326
+ @staticmethod
327
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
328
+ x1, x2 = x.chunk(2, dim=-1)
329
+ return torch.cat((-x2, x1), dim=-1)
330
+
331
+ def _apply_axis_rope(
332
+ self,
333
+ x: torch.Tensor,
334
+ cos: torch.Tensor,
335
+ sin: torch.Tensor,
336
+ ) -> torch.Tensor:
337
+ return x * cos.to(dtype=x.dtype) + self._rotate_half(x) * sin.to(dtype=x.dtype)
338
+
339
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
340
+ batch, tokens, num_heads, head_dim = q.shape
341
+ if head_dim != self.head_dim:
342
+ raise ValueError(f"expected head_dim={self.head_dim}, got {head_dim}")
343
+ expected_tokens = self.grid_size * self.grid_size
344
+ if tokens != expected_tokens:
345
+ raise ValueError(f"expected {expected_tokens} image tokens, got {tokens}")
346
+
347
+ q_h, q_w = q.chunk(2, dim=-1)
348
+ k_h, k_w = k.chunk(2, dim=-1)
349
+ q_h = self._apply_axis_rope(q_h, self.cos_h, self.sin_h)
350
+ q_w = self._apply_axis_rope(q_w, self.cos_w, self.sin_w)
351
+ k_h = self._apply_axis_rope(k_h, self.cos_h, self.sin_h)
352
+ k_w = self._apply_axis_rope(k_w, self.cos_w, self.sin_w)
353
+ return torch.cat([q_h, q_w], dim=-1), torch.cat([k_h, k_w], dim=-1)
354
+
355
+
356
+ class FullImageSelfAttention(nn.Module):
357
+ """Full image-token attention for the small DC-AE latent grid."""
358
+
359
+ def __init__(
360
+ self,
361
+ hidden_dim: int,
362
+ num_heads: int,
363
+ *,
364
+ backend: str = "sdpa",
365
+ grid_size: int | None = None,
366
+ rope: bool = False,
367
+ rope_theta: float = 10000.0,
368
+ ) -> None:
369
+ super().__init__()
370
+ if hidden_dim % num_heads != 0:
371
+ raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}")
372
+ if backend not in {"sdpa", "flash3", "flash4", "auto"}:
373
+ raise ValueError(f"Unsupported image_attention_backend: {backend}")
374
+ if rope and grid_size is None:
375
+ raise ValueError("grid_size is required when rope=True")
376
+ self.hidden_dim = hidden_dim
377
+ self.num_heads = num_heads
378
+ self.head_dim = hidden_dim // num_heads
379
+ self.backend = backend
380
+ self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
381
+ self.out_proj = nn.Linear(hidden_dim, hidden_dim)
382
+ nn.init.zeros_(self.out_proj.weight)
383
+ nn.init.zeros_(self.out_proj.bias)
384
+ self.rope = (
385
+ RoPE2D(self.head_dim, grid_size, theta=rope_theta)
386
+ if rope and grid_size is not None
387
+ else None
388
+ )
389
+ self._flash3_attn_func = None
390
+ self._flash3_import_attempted = False
391
+ self._flash4_attn_func = None
392
+ self._flash4_import_attempted = False
393
+
394
+ def _get_flash3_attn_func(self):
395
+ if self._flash3_import_attempted:
396
+ return self._flash3_attn_func
397
+ self._flash3_import_attempted = True
398
+ maybe_add_sibling_flash_attention_repo()
399
+ try:
400
+ from flash_attn_interface import flash_attn_func
401
+ except Exception:
402
+ try:
403
+ from flash_attn.flash_attn_interface import flash_attn_func
404
+ except Exception:
405
+ flash_attn_func = None
406
+ self._flash3_attn_func = flash_attn_func
407
+ return self._flash3_attn_func
408
+
409
+ def _get_flash4_attn_func(self):
410
+ if self._flash4_import_attempted:
411
+ return self._flash4_attn_func
412
+ self._flash4_import_attempted = True
413
+ maybe_add_sibling_flash_attention_repo()
414
+ try:
415
+ from flash_attn.cute.interface import flash_attn_func
416
+ except Exception:
417
+ flash4_paths = [
418
+ Path(__file__).resolve().parents[3] / "flash-attention" / "flash_attn",
419
+ Path("/work/flash-attention/flash_attn"),
420
+ Path("/home/jovyan/work/flash-attention/flash_attn"),
421
+ Path("/content/flash-attention/flash_attn"),
422
+ ]
423
+ existing_paths = [str(path) for path in flash4_paths if (path / "cute").is_dir()]
424
+ if existing_paths:
425
+ for name in list(sys.modules):
426
+ if name == "flash_attn" or name.startswith("flash_attn."):
427
+ del sys.modules[name]
428
+ flash_attn_pkg = types.ModuleType("flash_attn")
429
+ flash_attn_pkg.__path__ = existing_paths
430
+ sys.modules["flash_attn"] = flash_attn_pkg
431
+ try:
432
+ from flash_attn.cute.interface import flash_attn_func
433
+ except Exception:
434
+ flash_attn_func = None
435
+ else:
436
+ flash_attn_func = None
437
+ self._flash4_attn_func = flash_attn_func
438
+ return self._flash4_attn_func
439
+
440
+ def _flash3_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
441
+ flash_attn_func = self._get_flash3_attn_func()
442
+ if flash_attn_func is None:
443
+ raise ImportError(
444
+ "image_attention_backend='flash3' requires FlashAttention-3. "
445
+ "Install it or use --image-attn-backend sdpa."
446
+ )
447
+ out = flash_attn_func(q, k, v, causal=False)
448
+ if isinstance(out, tuple):
449
+ out = out[0]
450
+ return out
451
+
452
+ def _flash4_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
453
+ flash_attn_func = self._get_flash4_attn_func()
454
+ if flash_attn_func is None:
455
+ raise ImportError(
456
+ "image_attention_backend='flash4' requires FlashAttention-4/CuTe. "
457
+ "Install flash-attn-4 or use --image-attn-backend sdpa."
458
+ )
459
+ out = flash_attn_func(q, k, v, causal=False)
460
+ if isinstance(out, tuple):
461
+ out = out[0]
462
+ return out
463
+
464
+ @staticmethod
465
+ def _flash_compute_dtype(x: torch.Tensor) -> torch.dtype | None:
466
+ """FA kernels need fp16/bf16; fp32 master weights + compile may still pass fp32 activations."""
467
+ if not x.is_cuda:
468
+ return None
469
+ if x.dtype in {torch.float16, torch.bfloat16}:
470
+ return x.dtype
471
+ if torch.is_autocast_enabled():
472
+ return torch.get_autocast_dtype("cuda")
473
+ return torch.bfloat16
474
+
475
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
476
+ batch, tokens, channels = x.shape
477
+ qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
478
+ q, k, v = qkv.unbind(dim=2)
479
+ if self.rope is not None:
480
+ q, k = self.rope(q, k)
481
+
482
+ flash_dtype = self._flash_compute_dtype(x)
483
+ use_flash = self.backend in {"flash3", "flash4", "auto"} and flash_dtype is not None
484
+ if use_flash and (q.dtype != flash_dtype or k.dtype != flash_dtype or v.dtype != flash_dtype):
485
+ q, k, v = q.to(flash_dtype), k.to(flash_dtype), v.to(flash_dtype)
486
+
487
+ if self.backend == "flash4" and use_flash:
488
+ out = self._flash4_attention(q, k, v)
489
+ elif self.backend == "flash3" and use_flash:
490
+ out = self._flash3_attention(q, k, v)
491
+ elif self.backend == "auto" and use_flash:
492
+ try:
493
+ out = self._flash4_attention(q, k, v)
494
+ except Exception:
495
+ try:
496
+ out = self._flash3_attention(q, k, v)
497
+ except Exception:
498
+ use_flash = False
499
+ if self.backend in {"flash3", "flash4"} and not use_flash:
500
+ raise RuntimeError(
501
+ f"image_attention_backend='{self.backend}' requires CUDA fp16/bf16 compute; got {x.device} {x.dtype}"
502
+ )
503
+ if use_flash and out.dtype != x.dtype:
504
+ out = out.to(dtype=x.dtype)
505
+ if not use_flash:
506
+ q = q.transpose(1, 2)
507
+ k = k.transpose(1, 2)
508
+ v = v.transpose(1, 2)
509
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
510
+ out = out.transpose(1, 2)
511
+
512
+ out = out.reshape(batch, tokens, channels)
513
+ return self.out_proj(out)
514
+
515
+
516
+ class SanaMultiHeadCrossAttention(nn.Module):
517
+ """Sana-style cross-attention with optional q/k norm and SDPA/xformers kernels."""
518
+
519
+ def __init__(
520
+ self,
521
+ hidden_dim: int,
522
+ num_heads: int,
523
+ *,
524
+ backend: str = "sdpa",
525
+ qk_norm: bool = True,
526
+ ) -> None:
527
+ super().__init__()
528
+ if hidden_dim % num_heads != 0:
529
+ raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={num_heads}")
530
+ if backend not in {"sdpa", "xformers", "auto"}:
531
+ raise ValueError(f"Unsupported cross_attention_backend: {backend}")
532
+ self.hidden_dim = hidden_dim
533
+ self.num_heads = num_heads
534
+ self.head_dim = hidden_dim // num_heads
535
+ self.backend = backend
536
+ self.q_linear = nn.Linear(hidden_dim, hidden_dim)
537
+ self.kv_linear = nn.Linear(hidden_dim, hidden_dim * 2)
538
+ self.q_norm = AttentionRMSNorm(hidden_dim, scale_factor=1.0, eps=1e-6) if qk_norm else nn.Identity()
539
+ self.k_norm = AttentionRMSNorm(hidden_dim, scale_factor=1.0, eps=1e-6) if qk_norm else nn.Identity()
540
+ self.proj = nn.Linear(hidden_dim, hidden_dim)
541
+ # adaLN-Zero style: cross-attn starts as a no-op so Gemma text cannot spike GDN states early.
542
+ nn.init.zeros_(self.proj.weight)
543
+ nn.init.zeros_(self.proj.bias)
544
+ self._xformers_ops = None
545
+ self._xformers_import_attempted = False
546
+
547
+ def _get_xformers_ops(self):
548
+ if self._xformers_import_attempted:
549
+ return self._xformers_ops
550
+ self._xformers_import_attempted = True
551
+ try:
552
+ import xformers.ops as xops
553
+ except Exception:
554
+ xops = None
555
+ self._xformers_ops = xops
556
+ return self._xformers_ops
557
+
558
+ def _xformers_attention(
559
+ self,
560
+ q: torch.Tensor,
561
+ k: torch.Tensor,
562
+ v: torch.Tensor,
563
+ key_padding_mask: torch.Tensor | None,
564
+ ) -> torch.Tensor:
565
+ xops = self._get_xformers_ops()
566
+ if xops is None:
567
+ raise ImportError(
568
+ "cross_attention_backend='xformers' requires xformers. "
569
+ "Install it or use --cross-attn-backend sdpa."
570
+ )
571
+
572
+ batch, image_tokens = q.shape[:2]
573
+ text_tokens = k.shape[1]
574
+ q_lens = [image_tokens] * batch
575
+ q_compact = q.reshape(1, batch * image_tokens, self.num_heads, self.head_dim)
576
+ if key_padding_mask is None:
577
+ kv_lens = [text_tokens] * batch
578
+ k_compact = k.reshape(1, batch * text_tokens, self.num_heads, self.head_dim)
579
+ v_compact = v.reshape(1, batch * text_tokens, self.num_heads, self.head_dim)
580
+ else:
581
+ valid_mask = ~key_padding_mask.bool()
582
+ kv_lens = valid_mask.sum(dim=1).tolist()
583
+ if any(length <= 0 for length in kv_lens):
584
+ raise ValueError("xformers cross-attention received a sample with zero valid text tokens")
585
+ k_compact = torch.cat([k[index, valid_mask[index]] for index in range(batch)], dim=0).unsqueeze(0)
586
+ v_compact = torch.cat([v[index, valid_mask[index]] for index in range(batch)], dim=0).unsqueeze(0)
587
+
588
+ attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(q_lens, kv_lens)
589
+ out = xops.memory_efficient_attention(q_compact, k_compact, v_compact, attn_bias=attn_bias, p=0.0)
590
+ return out.reshape(batch, image_tokens, self.num_heads, self.head_dim)
591
+
592
+ def _sdpa_attention(
593
+ self,
594
+ q: torch.Tensor,
595
+ k: torch.Tensor,
596
+ v: torch.Tensor,
597
+ key_padding_mask: torch.Tensor | None,
598
+ attn_bias: torch.Tensor | None = None,
599
+ ) -> torch.Tensor:
600
+ q = q.transpose(1, 2)
601
+ k = k.transpose(1, 2)
602
+ v = v.transpose(1, 2)
603
+ attn_mask = attn_bias
604
+ if attn_mask is None and key_padding_mask is not None:
605
+ attn_mask = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
606
+ attn_mask = attn_mask.masked_fill(attn_mask > 0, -10000.0)
607
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
608
+ return out.transpose(1, 2)
609
+
610
+ def forward(
611
+ self,
612
+ x: torch.Tensor,
613
+ cond: torch.Tensor,
614
+ key_padding_mask: torch.Tensor | None = None,
615
+ attn_bias: torch.Tensor | None = None,
616
+ ) -> torch.Tensor:
617
+ batch, image_tokens, channels = x.shape
618
+ # Sana order: linear projection first, then per-token q/k RMSNorm before head split.
619
+ # This caps dot-product growth when cond carries high-magnitude Gemma caption states.
620
+ q = self.q_linear(x)
621
+ q = self.q_norm(q).reshape(batch, image_tokens, self.num_heads, self.head_dim)
622
+ k, v = self.kv_linear(cond).chunk(2, dim=-1)
623
+ k = self.k_norm(k).reshape(batch, cond.shape[1], self.num_heads, self.head_dim)
624
+ v = v.reshape(batch, cond.shape[1], self.num_heads, self.head_dim)
625
+
626
+ use_xformers = self.backend in {"xformers", "auto"} and x.is_cuda and x.dtype in {
627
+ torch.float16,
628
+ torch.bfloat16,
629
+ }
630
+ if use_xformers:
631
+ try:
632
+ out = self._xformers_attention(q, k, v, key_padding_mask)
633
+ except Exception:
634
+ if self.backend == "xformers":
635
+ raise
636
+ use_xformers = False
637
+ if self.backend == "xformers" and not use_xformers:
638
+ raise RuntimeError(
639
+ f"cross_attention_backend='xformers' requires CUDA fp16/bf16 tensors; got {x.device} {x.dtype}"
640
+ )
641
+ if not use_xformers:
642
+ out = self._sdpa_attention(q, k, v, key_padding_mask, attn_bias)
643
+
644
+ return self.proj(out.reshape(batch, image_tokens, channels))
645
+
646
+
647
+ class FLASelfMixer(nn.Module):
648
+ def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
649
+ super().__init__()
650
+ try:
651
+ import fla.layers as fla_layers
652
+ except Exception:
653
+ maybe_add_sibling_fla_repo()
654
+ import fla.layers as fla_layers
655
+
656
+ hidden_dim = config.hidden_dim
657
+ self.bidirectional = config.fla_bidirectional
658
+
659
+ def make_mixer() -> nn.Module:
660
+ if config.mixer_type == "fla_linear":
661
+ return fla_layers.LinearAttention(
662
+ hidden_size=hidden_dim,
663
+ num_heads=config.num_heads,
664
+ mode=config.fla_mode,
665
+ feature_map=config.fla_feature_map,
666
+ output_norm="rmsnorm",
667
+ layer_idx=layer_idx,
668
+ )
669
+ if config.mixer_type == "fla_gated_deltanet":
670
+ return fla_layers.GatedDeltaNet(
671
+ hidden_size=hidden_dim,
672
+ num_heads=config.num_heads,
673
+ head_dim=hidden_dim // config.num_heads,
674
+ expand_v=1,
675
+ mode=config.fla_mode,
676
+ use_short_conv=config.use_short_conv,
677
+ conv_size=config.conv_size,
678
+ layer_idx=layer_idx,
679
+ )
680
+ if config.mixer_type == "fla_gla":
681
+ return fla_layers.GatedLinearAttention(
682
+ hidden_size=hidden_dim,
683
+ num_heads=config.num_heads,
684
+ mode=config.fla_mode,
685
+ feature_map=config.fla_feature_map,
686
+ use_short_conv=config.use_short_conv,
687
+ conv_size=config.conv_size,
688
+ layer_idx=layer_idx,
689
+ )
690
+ raise ValueError(f"Unsupported FLA mixer_type: {config.mixer_type}")
691
+
692
+ self.mixer_fwd = make_mixer()
693
+ self.mixer_bwd = make_mixer() if self.bidirectional else None
694
+ if self.bidirectional:
695
+ self.out_proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)
696
+ nn.init.zeros_(self.out_proj.weight)
697
+
698
+ @staticmethod
699
+ def _run_mixer(mixer: nn.Module, x: torch.Tensor) -> torch.Tensor:
700
+ y = mixer(x)
701
+ if isinstance(y, tuple):
702
+ y = y[0]
703
+ return y
704
+
705
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
706
+ y = self._run_mixer(self.mixer_fwd, x)
707
+ if not self.bidirectional:
708
+ return y
709
+ if self.mixer_bwd is None:
710
+ raise RuntimeError("bidirectional FLASelfMixer is missing the backward mixer")
711
+ y_rev = self._run_mixer(self.mixer_bwd, x.flip(1)).flip(1)
712
+ return self.out_proj(torch.cat([y, y_rev], dim=-1))
713
+
714
+
715
+ class BoomerFLABlock(nn.Module):
716
+ def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
717
+ super().__init__()
718
+ hidden_dim = config.hidden_dim
719
+ self.parallel_block = config.parallel_block
720
+ self.use_image_attention = (
721
+ config.image_attention_every > 0 and (layer_idx + 1) % config.image_attention_every == 0
722
+ )
723
+ self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
724
+ if config.mixer_type in {"torch", "fallback"}:
725
+ self.self_attn = TorchSelfAttention(hidden_dim, config.num_heads)
726
+ else:
727
+ self.self_attn = FLASelfMixer(config, layer_idx=layer_idx)
728
+ if self.use_image_attention:
729
+ self.image_attn_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
730
+ self.image_attn_mod = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 3))
731
+ self.image_attn = FullImageSelfAttention(
732
+ hidden_dim,
733
+ config.num_heads,
734
+ backend=config.image_attention_backend,
735
+ grid_size=config.latent_size // config.patch_size,
736
+ rope=config.image_attention_rope,
737
+ rope_theta=config.image_rope_theta,
738
+ )
739
+ self.image_attn_scale_shift_table = nn.Parameter(torch.zeros(3, hidden_dim))
740
+ cross_backend = config.cross_attention_backend
741
+ if config.cross_attention_qk_norm and cross_backend == "mha":
742
+ raise ValueError(
743
+ "cross_attention_qk_norm requires SanaMultiHeadCrossAttention "
744
+ "(cross_attention_backend sdpa/xformers/auto), not mha"
745
+ )
746
+ if cross_backend == "mha":
747
+ self.cross_attn = nn.MultiheadAttention(hidden_dim, config.num_heads, batch_first=True)
748
+ else:
749
+ self.cross_attn = SanaMultiHeadCrossAttention(
750
+ hidden_dim,
751
+ config.num_heads,
752
+ backend=cross_backend,
753
+ qk_norm=config.cross_attention_qk_norm,
754
+ )
755
+ self.mod = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 9))
756
+ self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
757
+ self.norm3 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
758
+ self.mlp = GLUMBConv(hidden_dim, config.mlp_ratio)
759
+ self.scale_shift_table = nn.Parameter(torch.zeros(9, hidden_dim))
760
+
761
+ def _cross_attention(
762
+ self,
763
+ x: torch.Tensor,
764
+ text_tokens: torch.Tensor,
765
+ text_key_padding_mask: torch.Tensor,
766
+ text_attn_bias: torch.Tensor | None,
767
+ ) -> torch.Tensor:
768
+ if isinstance(self.cross_attn, nn.MultiheadAttention):
769
+ return self.cross_attn(
770
+ x,
771
+ text_tokens,
772
+ text_tokens,
773
+ key_padding_mask=text_key_padding_mask,
774
+ need_weights=False,
775
+ )[0]
776
+ return self.cross_attn(x, text_tokens, text_key_padding_mask, text_attn_bias)
777
+
778
+ def forward(
779
+ self,
780
+ x: torch.Tensor,
781
+ text_tokens: torch.Tensor,
782
+ t_emb: torch.Tensor,
783
+ text_key_padding_mask: torch.Tensor,
784
+ text_attn_bias: torch.Tensor | None,
785
+ *,
786
+ height: int,
787
+ width: int,
788
+ ) -> torch.Tensor:
789
+ timestep_mod = self.mod(t_emb)
790
+ (
791
+ shift_msa,
792
+ scale_msa,
793
+ gate_msa,
794
+ shift_cross,
795
+ scale_cross,
796
+ gate_cross,
797
+ shift_mlp,
798
+ scale_mlp,
799
+ gate_mlp,
800
+ ) = (self.scale_shift_table[None] + timestep_mod.reshape(x.shape[0], 9, -1)).chunk(9, dim=1)
801
+ if self.parallel_block:
802
+ base = x
803
+ branches = [
804
+ gate_msa * self.self_attn(modulate(self.norm1(base), shift_msa, scale_msa)),
805
+ gate_cross
806
+ * self._cross_attention(
807
+ modulate(self.norm3(base), shift_cross, scale_cross),
808
+ text_tokens,
809
+ text_key_padding_mask,
810
+ text_attn_bias,
811
+ ),
812
+ gate_mlp * self.mlp(modulate(self.norm2(base), shift_mlp, scale_mlp), height=height, width=width),
813
+ ]
814
+ if self.use_image_attention:
815
+ image_attn_mod = self.image_attn_mod(t_emb)
816
+ shift_img, scale_img, gate_img = (
817
+ self.image_attn_scale_shift_table[None] + image_attn_mod.reshape(x.shape[0], 3, -1)
818
+ ).chunk(3, dim=1)
819
+ branches.append(
820
+ gate_img * self.image_attn(modulate(self.image_attn_norm(base), shift_img, scale_img))
821
+ )
822
+ return base + sum(branches)
823
+
824
+ x = x + gate_msa * self.self_attn(modulate(self.norm1(x), shift_msa, scale_msa))
825
+ if self.use_image_attention:
826
+ image_attn_mod = self.image_attn_mod(t_emb)
827
+ shift_img, scale_img, gate_img = (
828
+ self.image_attn_scale_shift_table[None] + image_attn_mod.reshape(x.shape[0], 3, -1)
829
+ ).chunk(3, dim=1)
830
+ x = x + gate_img * self.image_attn(modulate(self.image_attn_norm(x), shift_img, scale_img))
831
+ x = x + gate_cross * self._cross_attention(
832
+ modulate(self.norm3(x), shift_cross, scale_cross),
833
+ text_tokens,
834
+ text_key_padding_mask,
835
+ text_attn_bias,
836
+ )
837
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), height=height, width=width)
838
+ return x
839
+
840
+
841
+ class BoomerFLADualStreamBlock(nn.Module):
842
+ """FLUX-style early block with one joint text+image attention operation."""
843
+
844
+ updates_text = True
845
+
846
+ def __init__(self, config: BoomerFLADiTConfig, *, layer_idx: int) -> None:
847
+ super().__init__()
848
+ hidden_dim = config.hidden_dim
849
+ if hidden_dim % config.num_heads != 0:
850
+ raise ValueError(f"hidden_dim={hidden_dim} must be divisible by num_heads={config.num_heads}")
851
+ self.num_heads = config.num_heads
852
+ self.head_dim = hidden_dim // config.num_heads
853
+ self.hidden_dim = hidden_dim
854
+ self.parallel_block = config.parallel_block
855
+
856
+ self.image_mod = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 6))
857
+ self.image_norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
858
+ self.image_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
859
+ self.image_q_norm = AttentionRMSNorm(self.head_dim, scale_factor=1.0, eps=1e-6)
860
+ self.image_k_norm = AttentionRMSNorm(self.head_dim, scale_factor=1.0, eps=1e-6)
861
+ self.image_out_proj = nn.Linear(hidden_dim, hidden_dim)
862
+ self.image_norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
863
+ self.image_mlp = GLUMBConv(hidden_dim, config.mlp_ratio)
864
+ self.image_scale_shift_table = nn.Parameter(torch.zeros(6, hidden_dim))
865
+
866
+ self.text_mod = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 6))
867
+ self.text_norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
868
+ self.text_qkv = nn.Linear(hidden_dim, hidden_dim * 3)
869
+ self.text_q_norm = AttentionRMSNorm(self.head_dim, scale_factor=1.0, eps=1e-6)
870
+ self.text_k_norm = AttentionRMSNorm(self.head_dim, scale_factor=1.0, eps=1e-6)
871
+ self.text_out_proj = nn.Linear(hidden_dim, hidden_dim)
872
+ self.text_norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
873
+ self.text_mlp = TokenMLP(hidden_dim, config.mlp_ratio)
874
+ self.text_scale_shift_table = nn.Parameter(torch.zeros(6, hidden_dim))
875
+
876
+ def _qkv(
877
+ self,
878
+ x: torch.Tensor,
879
+ qkv: nn.Linear,
880
+ q_norm: AttentionRMSNorm,
881
+ k_norm: AttentionRMSNorm,
882
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
883
+ batch, tokens, _ = x.shape
884
+ q, k, v = qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim).unbind(dim=2)
885
+ q = q_norm(q)
886
+ k = k_norm(k)
887
+ return q, k, v
888
+
889
+ def _joint_attention(
890
+ self,
891
+ image_tokens: torch.Tensor,
892
+ text_tokens: torch.Tensor,
893
+ text_key_padding_mask: torch.Tensor,
894
+ coord_rope: MultimodalCoordinateRoPE | None,
895
+ image_coord_ids: torch.Tensor | None,
896
+ text_coord_ids: torch.Tensor | None,
897
+ ) -> tuple[torch.Tensor, torch.Tensor]:
898
+ image_q, image_k, image_v = self._qkv(
899
+ image_tokens,
900
+ self.image_qkv,
901
+ self.image_q_norm,
902
+ self.image_k_norm,
903
+ )
904
+ text_q, text_k, text_v = self._qkv(
905
+ text_tokens,
906
+ self.text_qkv,
907
+ self.text_q_norm,
908
+ self.text_k_norm,
909
+ )
910
+ q = torch.cat([text_q, image_q], dim=1)
911
+ k = torch.cat([text_k, image_k], dim=1)
912
+ v = torch.cat([text_v, image_v], dim=1)
913
+ if coord_rope is not None:
914
+ if image_coord_ids is None or text_coord_ids is None:
915
+ raise ValueError("coordinate ids are required when multimodal coord RoPE is enabled")
916
+ coord_ids = torch.cat([text_coord_ids, image_coord_ids], dim=1)
917
+ q, k = coord_rope.apply(q, k, coord_ids)
918
+
919
+ image_mask = torch.zeros(
920
+ image_tokens.shape[0],
921
+ image_tokens.shape[1],
922
+ device=image_tokens.device,
923
+ dtype=text_key_padding_mask.dtype,
924
+ )
925
+ key_padding_mask = torch.cat([text_key_padding_mask, image_mask], dim=1)
926
+ attn_bias = key_padding_mask[:, None, None, :].to(dtype=q.dtype)
927
+ attn_bias = attn_bias.masked_fill(attn_bias > 0, -10000.0)
928
+ out = F.scaled_dot_product_attention(
929
+ q.transpose(1, 2),
930
+ k.transpose(1, 2),
931
+ v.transpose(1, 2),
932
+ attn_mask=attn_bias,
933
+ dropout_p=0.0,
934
+ is_causal=False,
935
+ )
936
+ out = out.transpose(1, 2).reshape(image_tokens.shape[0], text_tokens.shape[1] + image_tokens.shape[1], -1)
937
+ text_out, image_out = out.split([text_tokens.shape[1], image_tokens.shape[1]], dim=1)
938
+ return self.image_out_proj(image_out), self.text_out_proj(text_out)
939
+
940
+ def forward(
941
+ self,
942
+ x: torch.Tensor,
943
+ text_tokens: torch.Tensor,
944
+ t_emb: torch.Tensor,
945
+ text_key_padding_mask: torch.Tensor,
946
+ text_attn_bias: torch.Tensor | None,
947
+ *,
948
+ height: int,
949
+ width: int,
950
+ coord_rope: MultimodalCoordinateRoPE | None = None,
951
+ image_coord_ids: torch.Tensor | None = None,
952
+ text_coord_ids: torch.Tensor | None = None,
953
+ ) -> tuple[torch.Tensor, torch.Tensor]:
954
+ del text_attn_bias
955
+ image_timestep_mod = self.image_mod(t_emb)
956
+ text_timestep_mod = self.text_mod(t_emb)
957
+ image_shift_attn, image_scale_attn, image_gate_attn, image_shift_mlp, image_scale_mlp, image_gate_mlp = (
958
+ self.image_scale_shift_table[None] + image_timestep_mod.reshape(x.shape[0], 6, -1)
959
+ ).chunk(6, dim=1)
960
+ text_shift_attn, text_scale_attn, text_gate_attn, text_shift_mlp, text_scale_mlp, text_gate_mlp = (
961
+ self.text_scale_shift_table[None] + text_timestep_mod.reshape(text_tokens.shape[0], 6, -1)
962
+ ).chunk(6, dim=1)
963
+
964
+ image_base = x
965
+ text_base = text_tokens
966
+ image_attn_in = modulate(self.image_norm1(image_base), image_shift_attn, image_scale_attn)
967
+ text_attn_in = modulate(self.text_norm1(text_base), text_shift_attn, text_scale_attn)
968
+ image_attn, text_attn = self._joint_attention(
969
+ image_attn_in,
970
+ text_attn_in,
971
+ text_key_padding_mask,
972
+ coord_rope,
973
+ image_coord_ids,
974
+ text_coord_ids,
975
+ )
976
+ if self.parallel_block:
977
+ x = image_base + image_gate_attn * image_attn + image_gate_mlp * self.image_mlp(
978
+ modulate(self.image_norm2(image_base), image_shift_mlp, image_scale_mlp),
979
+ height=height,
980
+ width=width,
981
+ )
982
+ text_tokens = text_base + text_gate_attn * text_attn + text_gate_mlp * self.text_mlp(
983
+ modulate(self.text_norm2(text_base), text_shift_mlp, text_scale_mlp)
984
+ )
985
+ return x, text_tokens
986
+
987
+ x = image_base + image_gate_attn * image_attn
988
+ text_tokens = text_base + text_gate_attn * text_attn
989
+ x = x + image_gate_mlp * self.image_mlp(
990
+ modulate(self.image_norm2(x), image_shift_mlp, image_scale_mlp),
991
+ height=height,
992
+ width=width,
993
+ )
994
+ text_tokens = text_tokens + text_gate_mlp * self.text_mlp(
995
+ modulate(self.text_norm2(text_tokens), text_shift_mlp, text_scale_mlp)
996
+ )
997
+ return x, text_tokens
998
+
999
+
1000
+ class BoomerFLADiT(nn.Module):
1001
+ """Boomer DiT with FLA mixers, optional full image attention, and GLUMBConv FFNs."""
1002
+
1003
+ def __init__(self, config: BoomerFLADiTConfig = BoomerFLADiTConfig()) -> None:
1004
+ super().__init__()
1005
+ if config.patch_size <= 0:
1006
+ raise ValueError(f"patch_size must be positive, got {config.patch_size}")
1007
+ if config.latent_size % config.patch_size != 0:
1008
+ raise ValueError(
1009
+ f"latent_size={config.latent_size} must be divisible by patch_size={config.patch_size}"
1010
+ )
1011
+ if config.dual_stream_depth < 0:
1012
+ raise ValueError(f"dual_stream_depth must be non-negative, got {config.dual_stream_depth}")
1013
+ if config.dual_stream_depth > config.depth:
1014
+ raise ValueError(f"dual_stream_depth={config.dual_stream_depth} exceeds depth={config.depth}")
1015
+ self.config = config
1016
+ hidden_dim = config.hidden_dim
1017
+ self.patch_size = config.patch_size
1018
+ self.token_grid_size = config.latent_size // config.patch_size
1019
+ token_count = self.token_grid_size * self.token_grid_size
1020
+ self.x_embedder = (
1021
+ nn.Linear(config.latent_channels, hidden_dim)
1022
+ if config.patch_size == 1
1023
+ else nn.Conv2d(
1024
+ config.latent_channels,
1025
+ hidden_dim,
1026
+ kernel_size=config.patch_size,
1027
+ stride=config.patch_size,
1028
+ )
1029
+ )
1030
+ self.pos_embed = nn.Parameter(torch.zeros(1, token_count, hidden_dim)) if config.use_abs_pos_embed else None
1031
+ self.t_embedder = TimestepEmbedder(hidden_dim)
1032
+ self.caption_embedder = CaptionEmbedder(config.text_dim, hidden_dim, config.text_seq_len)
1033
+ self.attention_y_norm = (
1034
+ AttentionRMSNorm(hidden_dim, scale_factor=config.y_norm_scale_factor) if config.y_norm else None
1035
+ )
1036
+ self.coord_embedder = (
1037
+ MultimodalCoordinateRoPE(
1038
+ hidden_dim // config.num_heads,
1039
+ image_size=self.token_grid_size,
1040
+ text_seq_len=config.text_seq_len,
1041
+ theta=config.image_rope_theta,
1042
+ )
1043
+ if config.multimodal_coord_ids
1044
+ else None
1045
+ )
1046
+ self.blocks = nn.ModuleList(
1047
+ [
1048
+ (
1049
+ BoomerFLADualStreamBlock(config, layer_idx=i)
1050
+ if i < config.dual_stream_depth
1051
+ else BoomerFLABlock(config, layer_idx=i)
1052
+ )
1053
+ for i in range(config.depth)
1054
+ ]
1055
+ )
1056
+ self.final_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
1057
+ self.final_t_block = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 2))
1058
+ self.out_proj = nn.Linear(hidden_dim, config.latent_channels * config.patch_size * config.patch_size)
1059
+ self.initialize_weights()
1060
+
1061
+ def initialize_weights(self) -> None:
1062
+ if self.pos_embed is not None:
1063
+ nn.init.normal_(self.pos_embed, std=0.02)
1064
+
1065
+ for block in self.blocks:
1066
+ if isinstance(block, BoomerFLADualStreamBlock):
1067
+ nn.init.zeros_(block.image_mod[1].weight)
1068
+ nn.init.zeros_(block.image_mod[1].bias)
1069
+ nn.init.zeros_(block.text_mod[1].weight)
1070
+ nn.init.zeros_(block.text_mod[1].bias)
1071
+ nn.init.normal_(block.image_scale_shift_table, std=0.02)
1072
+ nn.init.normal_(block.text_scale_shift_table, std=0.02)
1073
+ else:
1074
+ nn.init.zeros_(block.mod[1].weight)
1075
+ nn.init.zeros_(block.mod[1].bias)
1076
+ nn.init.normal_(block.scale_shift_table, std=0.02)
1077
+ if block.use_image_attention:
1078
+ nn.init.zeros_(block.image_attn_mod[1].weight)
1079
+ nn.init.zeros_(block.image_attn_mod[1].bias)
1080
+ nn.init.normal_(block.image_attn_scale_shift_table, std=0.02)
1081
+
1082
+ nn.init.zeros_(self.final_t_block[1].weight)
1083
+ nn.init.zeros_(self.final_t_block[1].bias)
1084
+ nn.init.zeros_(self.out_proj.weight)
1085
+ nn.init.zeros_(self.out_proj.bias)
1086
+
1087
+ def apply_y_norm(self, caption_tokens: torch.Tensor) -> torch.Tensor:
1088
+ if self.attention_y_norm is None:
1089
+ return caption_tokens
1090
+ return self.attention_y_norm(caption_tokens)
1091
+
1092
+ def null_condition(
1093
+ self,
1094
+ batch_size: int,
1095
+ *,
1096
+ device: torch.device | str,
1097
+ dtype: torch.dtype,
1098
+ mask_dtype: torch.dtype | None = None,
1099
+ token_num: int | None = None,
1100
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1101
+ return self.caption_embedder.null_condition(
1102
+ batch_size,
1103
+ device=device,
1104
+ dtype=dtype,
1105
+ mask_dtype=mask_dtype,
1106
+ token_num=token_num,
1107
+ )
1108
+
1109
+ def apply_condition_dropout(
1110
+ self,
1111
+ text_embedding: torch.Tensor,
1112
+ attention_mask: torch.Tensor,
1113
+ probability: float,
1114
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1115
+ if probability <= 0.0:
1116
+ return text_embedding, attention_mask
1117
+ batch_size = text_embedding.shape[0]
1118
+ null_text, null_mask = self.null_condition(
1119
+ batch_size,
1120
+ device=text_embedding.device,
1121
+ dtype=text_embedding.dtype,
1122
+ mask_dtype=attention_mask.dtype,
1123
+ token_num=text_embedding.shape[-2],
1124
+ )
1125
+ # torch.where over a per-sample bool. Avoids the bool(drop.any()) CUDA
1126
+ # sync (which would defeat the training-loop sync removal) and skips
1127
+ # the full-tensor .clone() that the previous in-place path required.
1128
+ drop = torch.rand(batch_size, device=text_embedding.device) < probability
1129
+ drop_text = drop.view(batch_size, *([1] * (text_embedding.dim() - 1)))
1130
+ drop_mask = drop.view(batch_size, *([1] * (attention_mask.dim() - 1)))
1131
+ text_embedding = torch.where(drop_text, null_text, text_embedding)
1132
+ attention_mask = torch.where(drop_mask, null_mask, attention_mask)
1133
+ return text_embedding, attention_mask
1134
+
1135
+ def forward(
1136
+ self,
1137
+ noisy_latent: torch.Tensor,
1138
+ timesteps: torch.Tensor,
1139
+ text_embedding: torch.Tensor,
1140
+ attention_mask: torch.Tensor,
1141
+ ) -> torch.Tensor:
1142
+ batch, channels, height, width = noisy_latent.shape
1143
+ if channels != self.config.latent_channels:
1144
+ raise ValueError(
1145
+ f"Expected latent_channels={self.config.latent_channels}, got shape {tuple(noisy_latent.shape)}"
1146
+ )
1147
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
1148
+ raise ValueError(
1149
+ f"latent height/width must be divisible by patch_size={self.patch_size}, got {(height, width)}"
1150
+ )
1151
+ token_height = height // self.patch_size
1152
+ token_width = width // self.patch_size
1153
+ token_count = token_height * token_width
1154
+ if self.pos_embed is not None and token_count != self.pos_embed.shape[1]:
1155
+ raise ValueError(
1156
+ f"absolute pos_embed expects {self.pos_embed.shape[1]} latent tokens, got {token_count}. "
1157
+ "Disable it with --no-abs-pos-embed for variable latent sizes."
1158
+ )
1159
+ if text_embedding.shape[-1] != self.config.text_dim:
1160
+ raise ValueError(f"text_embedding last dim must be {self.config.text_dim}, got {text_embedding.shape[-1]}")
1161
+
1162
+ text_tokens = self.caption_embedder(text_embedding)
1163
+ text_tokens = self.apply_y_norm(text_tokens)
1164
+ text_key_padding_mask = attention_mask == 0
1165
+
1166
+ if self.patch_size == 1:
1167
+ x = noisy_latent.flatten(2).transpose(1, 2)
1168
+ x = self.x_embedder(x)
1169
+ else:
1170
+ x = self.x_embedder(noisy_latent).flatten(2).transpose(1, 2)
1171
+ if self.pos_embed is not None:
1172
+ x = x + self.pos_embed
1173
+ image_coord_ids = None
1174
+ text_coord_ids = None
1175
+ if self.coord_embedder is not None:
1176
+ image_coord_ids = self.coord_embedder.image_ids(
1177
+ batch,
1178
+ height=token_height,
1179
+ width=token_width,
1180
+ device=x.device,
1181
+ )
1182
+ text_coord_ids = self.coord_embedder.text_ids(batch, text_tokens.shape[1], device=text_tokens.device)
1183
+ text_attn_bias = text_key_padding_mask[:, None, None, :].to(dtype=x.dtype)
1184
+ text_attn_bias = text_attn_bias.masked_fill(text_attn_bias > 0, -10000.0)
1185
+ t_emb = self.t_embedder(timesteps)
1186
+ use_ckpt = self.config.gradient_checkpointing and self.training
1187
+ for block in self.blocks:
1188
+ if getattr(block, "updates_text", False):
1189
+ # Dual-stream block: returns (x, text_tokens).
1190
+ # Non-tensor args (height, width, coord_rope, coord IDs) captured via closure.
1191
+ _h, _w = token_height, token_width
1192
+ _cr, _ii, _ti = self.coord_embedder, image_coord_ids, text_coord_ids
1193
+ if use_ckpt:
1194
+ def _dual_fn(x, tt, te, mk, bi,
1195
+ _blk=block, h=_h, w=_w, cr=_cr, ii=_ii, ti=_ti):
1196
+ return _blk(x, tt, te, mk, bi,
1197
+ height=h, width=w, coord_rope=cr,
1198
+ image_coord_ids=ii, text_coord_ids=ti)
1199
+ x, text_tokens = _ckpt(_dual_fn, x, text_tokens, t_emb,
1200
+ text_key_padding_mask, text_attn_bias,
1201
+ use_reentrant=False,
1202
+ preserve_rng_state=False)
1203
+ else:
1204
+ x, text_tokens = block(
1205
+ x, text_tokens, t_emb, text_key_padding_mask, text_attn_bias,
1206
+ height=token_height, width=token_width,
1207
+ coord_rope=self.coord_embedder,
1208
+ image_coord_ids=image_coord_ids, text_coord_ids=text_coord_ids,
1209
+ )
1210
+ else:
1211
+ # Single-stream block: returns x only.
1212
+ _h, _w = token_height, token_width
1213
+ if use_ckpt:
1214
+ def _single_fn(x, tt, te, mk, bi,
1215
+ _blk=block, h=_h, w=_w):
1216
+ return _blk(x, tt, te, mk, bi, height=h, width=w)
1217
+ x = _ckpt(_single_fn, x, text_tokens, t_emb,
1218
+ text_key_padding_mask, text_attn_bias,
1219
+ use_reentrant=False,
1220
+ preserve_rng_state=False)
1221
+ else:
1222
+ x = block(
1223
+ x, text_tokens, t_emb, text_key_padding_mask, text_attn_bias,
1224
+ height=token_height, width=token_width,
1225
+ )
1226
+ final_mod = self.final_t_block(t_emb)
1227
+ shift, scale = final_mod.reshape(batch, 2, -1).chunk(2, dim=1)
1228
+ x = modulate(self.final_norm(x), shift, scale)
1229
+ x = self.out_proj(x)
1230
+ if self.patch_size == 1:
1231
+ return x.transpose(1, 2).reshape(batch, channels, height, width)
1232
+ patch = self.patch_size
1233
+ x = x.reshape(batch, token_height, token_width, channels, patch, patch)
1234
+ x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
1235
+ return x.reshape(batch, channels, height, width)
pipeline_boomer.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BoomerPipeline β€” HuggingFace DiffusionPipeline wrapper for Boomer FLA.
2
+
3
+ Load with:
4
+ from diffusers import DiffusionPipeline
5
+ pipe = DiffusionPipeline.from_pretrained("akrao9/Boomer-T2I", trust_remote_code=True).to("cuda")
6
+ image = pipe("a photorealistic portrait of a woman with dark hair")[0]
7
+
8
+ Requires:
9
+ pip install torch diffusers transformers accelerate safetensors
10
+ pip install git+https://github.com/Algomancer/STORK.git
11
+ pip install git+https://github.com/sustcsonglin/flash-linear-attention.git
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import gc
17
+ import json
18
+ import sys
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+ from typing import Any, List, Optional, Union
22
+
23
+ import torch
24
+ from diffusers import DiffusionPipeline
25
+ from diffusers.utils import logging
26
+ from PIL import Image
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ # ── pipeline output ────────────────────────────────────────────────────────────
32
+ @dataclass
33
+ class BoomerOutput:
34
+ """Return type of BoomerPipeline.__call__."""
35
+ images: List[Image.Image]
36
+
37
+ def __iter__(self):
38
+ return iter(self.images)
39
+
40
+ def __getitem__(self, idx):
41
+ return self.images[idx]
42
+
43
+ def __len__(self):
44
+ return len(self.images)
45
+
46
+
47
+ # ── text encoding helpers ──────────────────────────────────────────────────────
48
+ def _gemma_select_index(max_length: int) -> list[int]:
49
+ """Sana-style token selection: BOS + last (max_length-1) tokens."""
50
+ return [0] + list(range(-max_length + 1, 0))
51
+
52
+
53
+ @torch.inference_mode()
54
+ def _encode_prompts(
55
+ tokenizer: Any,
56
+ text_encoder: Any,
57
+ prompts: list[str],
58
+ max_length: int,
59
+ device: str,
60
+ dtype: torch.dtype,
61
+ ) -> tuple[torch.Tensor, torch.Tensor]:
62
+ """Encode prompts with Gemma 4. Returns (embeddings, attention_mask)."""
63
+ token_kwargs = dict(
64
+ max_length=max_length,
65
+ padding="max_length",
66
+ truncation=True,
67
+ return_tensors="pt",
68
+ )
69
+ try:
70
+ tokens = tokenizer(text=prompts, **token_kwargs)
71
+ except TypeError:
72
+ tokens = tokenizer(prompts, **token_kwargs)
73
+
74
+ input_ids = tokens["input_ids"].to(device)
75
+ attention_mask = tokens["attention_mask"].to(device)
76
+
77
+ output = text_encoder(input_ids, attention_mask=attention_mask)
78
+ hidden = output[0] if isinstance(output, tuple) else output.last_hidden_state
79
+
80
+ select_idx = _gemma_select_index(max_length)
81
+ return (
82
+ hidden[:, select_idx, :].to(dtype=dtype),
83
+ attention_mask[:, select_idx],
84
+ )
85
+
86
+
87
+ # ── latent helpers (inlined from latent_norm.py) ──────────────────────────────
88
+ def _stat_tensor(value: Any, latent: torch.Tensor) -> torch.Tensor:
89
+ tensor = torch.as_tensor(value, device=latent.device, dtype=latent.dtype)
90
+ if tensor.ndim == 0:
91
+ return tensor
92
+ if tensor.ndim == 1:
93
+ if tensor.numel() != latent.shape[1]:
94
+ raise ValueError(f"latent stat has {tensor.numel()} channels, expected {latent.shape[1]}")
95
+ return tensor.view(1, tensor.numel(), 1, 1)
96
+ raise ValueError("latent stat must be scalar or 1-D channel list")
97
+
98
+
99
+ def _denormalize(latent: torch.Tensor, mean: Any, std: Any) -> torch.Tensor:
100
+ return latent * _stat_tensor(std, latent) + _stat_tensor(mean, latent)
101
+
102
+
103
+ # ── pipeline ───────────────────────────────────────────────────────────────────
104
+ class BoomerPipeline(DiffusionPipeline):
105
+ """
106
+ Text-to-image generation with Boomer FLA (Flash Linear Attention DiT).
107
+
108
+ Components
109
+ ----------
110
+ transformer : BoomerFLADiT 657 M param FLA denoiser
111
+ vae : AutoencoderDC DC-AE f32c32 decoder
112
+ text_encoder: Gemma 4 2B decoder 1536-dim text embeddings
113
+ tokenizer : AutoProcessor Gemma tokenizer
114
+ """
115
+
116
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
117
+
118
+ def __init__(
119
+ self,
120
+ transformer,
121
+ vae,
122
+ text_encoder,
123
+ tokenizer,
124
+ *,
125
+ model_config,
126
+ latent_mean: Any = 0.0,
127
+ latent_std: Any = 1.0,
128
+ scaling_factor: float = 0.41407,
129
+ flow_shift: float = 1.5,
130
+ max_text_length: int = 300,
131
+ ) -> None:
132
+ super().__init__()
133
+ self.register_modules(
134
+ transformer=transformer,
135
+ vae=vae,
136
+ text_encoder=text_encoder,
137
+ tokenizer=tokenizer,
138
+ )
139
+ self.model_config = model_config
140
+ self.latent_mean = latent_mean
141
+ self.latent_std = latent_std
142
+ self.scaling_factor = scaling_factor
143
+ self.flow_shift = flow_shift
144
+ self.max_text_length = max_text_length
145
+
146
+ # ── loading ────────────────────────────────────────────────────────────────
147
+ @classmethod
148
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "BoomerPipeline": # type: ignore[override]
149
+ """Load Boomer from a HuggingFace repo or local directory."""
150
+ from huggingface_hub import snapshot_download
151
+
152
+ token = kwargs.pop("token", None) or kwargs.pop("use_auth_token", None)
153
+ dtype = kwargs.pop("torch_dtype", torch.bfloat16)
154
+ cache_dir = kwargs.pop("cache_dir", None)
155
+
156
+ # ── 1. Resolve local snapshot path ────────────────────────────────────
157
+ if Path(pretrained_model_name_or_path).is_dir():
158
+ local = Path(pretrained_model_name_or_path)
159
+ else:
160
+ logger.info(f"Downloading Boomer snapshot from {pretrained_model_name_or_path} ...")
161
+ local = Path(snapshot_download(
162
+ pretrained_model_name_or_path,
163
+ token=token,
164
+ cache_dir=cache_dir,
165
+ ignore_patterns=["*.png", "*.jpg", "*.jpeg", "*.gif"],
166
+ ))
167
+
168
+ # Add snapshot dir to sys.path so local .py files import each other
169
+ if str(local) not in sys.path:
170
+ sys.path.insert(0, str(local))
171
+
172
+ # ── 2. Import local model/scheduler modules ────────────────────────────
173
+ from modeling_boomer_fla import BoomerFLADiT, BoomerFLADiTConfig # noqa: PLC0415
174
+ from scheduling_boomer_stork import make_stork_scheduler # noqa: PLC0415
175
+ from safetensors.torch import load_file # noqa: PLC0415
176
+
177
+ # ── 3. Load transformer config + weights ───────────────────────────────
178
+ transformer_dir = local / "transformer"
179
+ cfg_raw = json.loads((transformer_dir / "config.json").read_text())
180
+ cfg_clean = {k: v for k, v in cfg_raw.items() if not k.startswith("_")}
181
+ model_config = BoomerFLADiTConfig(**cfg_clean)
182
+
183
+ logger.info("Loading Boomer FLA DiT weights ...")
184
+ state_dict = load_file(str(transformer_dir / "diffusion_pytorch_model.safetensors"))
185
+ transformer = BoomerFLADiT(model_config)
186
+ missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
187
+ if missing:
188
+ logger.warning(f"Missing keys in transformer state dict: {len(missing)}")
189
+ if unexpected:
190
+ logger.warning(f"Unexpected keys in transformer state dict: {len(unexpected)}")
191
+ transformer = transformer.to(dtype=dtype)
192
+
193
+ # ── 4. Read metadata from index / scheduler config ─────────────────────
194
+ model_index = json.loads((local / "model_index.json").read_text())
195
+ sched_cfg = json.loads((local / "scheduler" / "scheduler_config.json").read_text())
196
+ flow_shift = float(sched_cfg.get("flow_shift", 1.5))
197
+ scaling_factor = float(sched_cfg.get("scaling_factor", 0.41407))
198
+ latent_norm = model_index.get("latent_normalization", {})
199
+ latent_mean = latent_norm.get("mean", 0.0)
200
+ latent_std = latent_norm.get("std", 1.0)
201
+ te_info = model_index.get("text_encoder", {})
202
+ te_repo = te_info.get("repo_id", "google/gemma-4-E2B-it")
203
+ max_text_len = int(te_info.get("max_length", 300))
204
+ vae_info = model_index.get("vae", {})
205
+ vae_repo = vae_info.get("repo_id", "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
206
+
207
+ # ── 5. Load VAE ─────────────────────────────────────────────────────────
208
+ from diffusers import AutoencoderDC # noqa: PLC0415
209
+ logger.info(f"Loading VAE from {vae_repo} ...")
210
+ vae = AutoencoderDC.from_pretrained(vae_repo, torch_dtype=dtype, token=token)
211
+
212
+ # ── 6. Load text encoder ────────────────────────────────────────────────
213
+ from transformers import AutoModelForCausalLM, AutoProcessor # noqa: PLC0415
214
+ logger.info(f"Loading text encoder from {te_repo} ...")
215
+ tokenizer = AutoProcessor.from_pretrained(te_repo, token=token)
216
+ text_encoder = AutoModelForCausalLM.from_pretrained(
217
+ te_repo, torch_dtype=dtype, token=token,
218
+ )
219
+ if hasattr(text_encoder, "get_decoder"):
220
+ text_encoder = text_encoder.get_decoder()
221
+
222
+ return cls(
223
+ transformer=transformer,
224
+ vae=vae,
225
+ text_encoder=text_encoder,
226
+ tokenizer=tokenizer,
227
+ model_config=model_config,
228
+ latent_mean=latent_mean,
229
+ latent_std=latent_std,
230
+ scaling_factor=scaling_factor,
231
+ flow_shift=flow_shift,
232
+ max_text_length=max_text_len,
233
+ )
234
+
235
+ # ── generation ─────────────────────────────────────────────────────────────
236
+ @torch.inference_mode()
237
+ def __call__(
238
+ self,
239
+ prompt: Union[str, List[str]],
240
+ steps: int = 32,
241
+ seed: int = 42,
242
+ cfg_scale: float = 4.5,
243
+ cfg_rescale: float = 0.5,
244
+ substeps: int = 5,
245
+ offload_text_encoder: bool = True,
246
+ output_type: str = "pil",
247
+ **kwargs,
248
+ ) -> BoomerOutput:
249
+ """
250
+ Generate images from text prompts.
251
+
252
+ Parameters
253
+ ----------
254
+ prompt : str or list[str]
255
+ steps : denoising steps (default 32 with STORK-2)
256
+ seed : random seed
257
+ cfg_scale : classifier-free guidance scale (4.0–5.0 recommended)
258
+ cfg_rescale : CFG variance rescale (0.5 recommended)
259
+ substeps : STORK-2 internal RK micro-steps (5 recommended)
260
+ offload_text_encoder : unload text encoder after encoding to free VRAM
261
+ output_type : "pil" (default) or "latent"
262
+ """
263
+ from scheduling_boomer_stork import make_stork_scheduler # noqa: PLC0415
264
+
265
+ prompts = [prompt] if isinstance(prompt, str) else prompt
266
+ batch = len(prompts)
267
+ device = self._execution_device
268
+ dtype = next(self.transformer.parameters()).dtype
269
+
270
+ # ── Phase 1: encode text ───────────────────────────────────────────────
271
+ self.text_encoder.to(device)
272
+ self.text_encoder.eval()
273
+ text_emb, attn_mask = _encode_prompts(
274
+ self.tokenizer, self.text_encoder, prompts,
275
+ max_length=self.max_text_length,
276
+ device=device, dtype=dtype,
277
+ )
278
+ if offload_text_encoder:
279
+ self.text_encoder.to("cpu")
280
+ gc.collect()
281
+ if device != "cpu":
282
+ torch.cuda.empty_cache()
283
+
284
+ # ── Phase 2: denoise ───────────────────────────────────────────────────
285
+ self.transformer.to(device)
286
+ self.transformer.eval()
287
+
288
+ uncond_emb, uncond_mask = self.transformer.null_condition(
289
+ batch, device=device, dtype=dtype,
290
+ mask_dtype=attn_mask.dtype, token_num=text_emb.shape[-2],
291
+ )
292
+
293
+ sched = make_stork_scheduler(
294
+ steps=steps, device=device,
295
+ flow_shift=self.flow_shift,
296
+ solver_order=2, derivative_order=1, substeps=substeps,
297
+ )
298
+
299
+ gen = torch.Generator(device=device).manual_seed(seed)
300
+ latent = torch.randn(
301
+ batch,
302
+ self.model_config.latent_channels,
303
+ self.model_config.latent_size,
304
+ self.model_config.latent_size,
305
+ generator=gen, device=device, dtype=dtype,
306
+ )
307
+
308
+ for step in range(steps):
309
+ sigma = sched.sigmas[step].to(device=device, dtype=dtype)
310
+ lb = latent.repeat(2, 1, 1, 1)
311
+ tb = sigma.expand(batch * 2)
312
+ txt_b = torch.cat([uncond_emb, text_emb], dim=0)
313
+ msk_b = torch.cat([uncond_mask, attn_mask], dim=0)
314
+ uv, cv = self.transformer(lb, tb, txt_b, msk_b).chunk(2)
315
+ guided = uv + cfg_scale * (cv - uv)
316
+ if cfg_rescale > 0.0:
317
+ sc = cv.std(dim=(1, 2, 3), keepdim=True)
318
+ sg = guided.std(dim=(1, 2, 3), keepdim=True)
319
+ guided = cfg_rescale * guided * (sc / sg.clamp_min(1e-5)) + (1.0 - cfg_rescale) * guided
320
+ latent = sched.step(guided, sigma, latent, return_dict=True).prev_sample
321
+
322
+ if output_type == "latent":
323
+ return BoomerOutput(images=[latent])
324
+
325
+ # ── Phase 3: decode ────────────────────────────────────────────────────
326
+ self.vae.to(device)
327
+ self.vae.eval()
328
+
329
+ latent_dec = _denormalize(latent, self.latent_mean, self.latent_std)
330
+ # VAE expects latent / scaling_factor
331
+ latent_dec = latent_dec / self.scaling_factor
332
+ decoded = self.vae.decode(latent_dec, return_dict=False)[0]
333
+
334
+ # [-1, 1] β†’ [0, 1] β†’ uint8 PIL
335
+ pixels = (decoded.float() / 2.0 + 0.5).clamp(0, 1)
336
+ images = []
337
+ for i in range(batch):
338
+ img_np = (pixels[i].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
339
+ images.append(Image.fromarray(img_np))
340
+
341
+ return BoomerOutput(images=images)
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchScheduler",
3
+ "flow_shift": 1.5,
4
+ "num_train_steps": 75000,
5
+ "sampler": "stork2",
6
+ "stork_substeps": 5,
7
+ "num_inference_steps": 32,
8
+ "scaling_factor": 0.41407
9
+ }
scheduling_boomer_stork.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """STORK flow-matching scheduler loader β€” self-contained for HuggingFace distribution.
2
+
3
+ unshift_sigma inlined from boomer/sana_flow.py. No boomer package import needed.
4
+ Requires: torch. STORK repo must be pip-installed or available on sys.path.
5
+ pip install git+https://github.com/Algomancer/STORK.git
6
+ or clone alongside Boomer:
7
+ git clone https://github.com/Algomancer/STORK.git
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import importlib.util
13
+ import sys
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import torch
18
+
19
+
20
+ # ── inlined from boomer/sana_flow.py ──────────────────────────────────────────
21
+ def unshift_sigma(sigma: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
22
+ """Invert Sana/FlowMatch sigma shift."""
23
+ if shift <= 0.0:
24
+ raise ValueError(f"flow shift must be positive, got {shift}")
25
+ if shift == 1.0:
26
+ return sigma
27
+ return sigma / (shift - (shift - 1.0) * sigma).clamp_min(1e-12)
28
+
29
+
30
+ # ── STORK loader ───────────────────────────────────────────────────────────────
31
+ def _stork_candidates() -> list[Path]:
32
+ """Common locations where STORKScheduler.py might live."""
33
+ here = Path(__file__).resolve().parent
34
+ return [
35
+ here, # bundled alongside this file (HF snapshot dir)
36
+ here.parent / "STORK",
37
+ here / "STORK",
38
+ Path("/content/STORK"),
39
+ Path("/tmp/STORK"),
40
+ ]
41
+
42
+
43
+ def _find_stork_path() -> Path:
44
+ # 1. Already importable?
45
+ try:
46
+ import STORKScheduler # noqa: F401
47
+ return Path(STORKScheduler.__file__)
48
+ except ImportError:
49
+ pass
50
+
51
+ # 2. Search known candidate directories
52
+ for candidate_dir in _stork_candidates():
53
+ p = candidate_dir / "STORKScheduler.py"
54
+ if p.is_file():
55
+ if str(candidate_dir) not in sys.path:
56
+ sys.path.insert(0, str(candidate_dir))
57
+ return p
58
+
59
+ # 3. Try pip-installed stork package
60
+ try:
61
+ import stork
62
+ p = Path(stork.__file__).parent / "STORKScheduler.py"
63
+ if p.is_file():
64
+ return p
65
+ except ImportError:
66
+ pass
67
+
68
+ searched = "\n ".join(str(d / "STORKScheduler.py") for d in _stork_candidates())
69
+ raise FileNotFoundError(
70
+ "Could not find STORKScheduler.py. Options:\n"
71
+ " pip install git+https://github.com/ZT220501/STORK.git\n"
72
+ "or clone it next to Boomer:\n"
73
+ " git clone https://github.com/ZT220501/STORK.git\n"
74
+ f"Searched:\n {searched}"
75
+ )
76
+
77
+
78
+ def make_stork_scheduler(
79
+ *,
80
+ steps: int,
81
+ device: str | torch.device,
82
+ flow_shift: float,
83
+ solver_order: int,
84
+ derivative_order: int,
85
+ substeps: int,
86
+ start_sigma: float | None = None,
87
+ ) -> Any:
88
+ stork_path = _find_stork_path()
89
+ spec = importlib.util.spec_from_file_location("STORKScheduler", stork_path)
90
+ if spec is None or spec.loader is None:
91
+ raise ImportError(f"Could not load STORK scheduler from {stork_path}")
92
+ module = importlib.util.module_from_spec(spec)
93
+ spec.loader.exec_module(module)
94
+ STORKScheduler = module.STORKScheduler
95
+
96
+ scheduler = STORKScheduler(
97
+ shift=flow_shift,
98
+ solver_order=solver_order,
99
+ prediction_type="flow_prediction",
100
+ derivative_order=derivative_order,
101
+ s=substeps,
102
+ )
103
+ if start_sigma is not None:
104
+ start = torch.tensor(float(start_sigma), dtype=torch.float32)
105
+ base_start = unshift_sigma(start, flow_shift)
106
+ base_sigmas = torch.linspace(float(base_start.item()), 0.0, steps + 1, dtype=torch.float32)[:-1].tolist()
107
+ else:
108
+ base_sigmas = torch.linspace(1.0, 0.0, steps + 1, dtype=torch.float32)[:-1].tolist()
109
+
110
+ scheduler.set_timesteps(num_inference_steps=steps, device=device, sigmas=base_sigmas)
111
+ scheduler.dt_list = scheduler.dt_list.to(device=device)
112
+ scheduler.sigmas = scheduler.sigmas.to(device=device)
113
+ return scheduler
transformer/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "boomer_fla",
3
+ "latent_channels": 32,
4
+ "latent_size": 32,
5
+ "text_dim": 1536,
6
+ "text_seq_len": 384,
7
+ "hidden_dim": 896,
8
+ "depth": 24,
9
+ "num_heads": 14,
10
+ "mlp_ratio": 2.5,
11
+ "y_norm": true,
12
+ "y_norm_scale_factor": 0.01,
13
+ "mixer_type": "fla_gated_deltanet",
14
+ "fla_mode": "chunk",
15
+ "fla_feature_map": "relu",
16
+ "fla_bidirectional": true,
17
+ "use_short_conv": true,
18
+ "conv_size": 4,
19
+ "image_attention_every": 6,
20
+ "image_attention_backend": "sdpa",
21
+ "image_attention_rope": true,
22
+ "image_rope_theta": 10000.0,
23
+ "cross_attention_backend": "sdpa",
24
+ "cross_attention_qk_norm": true,
25
+ "parallel_block": true,
26
+ "dual_stream_depth": 2,
27
+ "multimodal_coord_ids": true,
28
+ "use_abs_pos_embed": false,
29
+ "patch_size": 1,
30
+ "gradient_checkpointing": false,
31
+ "_architecture_class": "BoomerFLADiT",
32
+ "_boomer_version": "1.0.0",
33
+ "_image_size_px": 1024,
34
+ "_latent_tokens": 1024
35
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c76fe83bd56e2485b0a2b5b4c03040aaef24da21218e461393b0bf07e40b5d8
3
+ size 2629135976