multimodalart HF Staff committed on
Commit
168566a
·
verified ·
1 Parent(s): e99a766

Bundle PR diffusers (yiyi-refactor-fused + native prompt upsampling)

Browse files
diffusers_src/src/diffusers/models/transformers/transformer_ideogram4.py CHANGED
@@ -12,6 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
 
15
  import math
16
 
17
  import torch
@@ -22,6 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
22
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23
  from ...utils import logging
24
  from ...utils.torch_utils import maybe_allow_in_graph
 
 
25
  from ..modeling_outputs import Transformer2DModelOutput
26
  from ..modeling_utils import ModelMixin
27
  from ..normalization import RMSNorm
@@ -44,19 +47,6 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
44
  return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
45
 
46
 
47
def _apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate the query and key tensors with the given cos/sin tables.

    `cos` and `sin` get an extra axis inserted at dim 1 so that they broadcast
    against `q` and `k` (presumably over the attention-head axis -- confirm
    against the caller's tensor layout).

    Returns:
        The rotated `(q, k)` pair.
    """
    cos_b, sin_b = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Standard rotary form: x * cos + rotate_half(x) * sin.
        return (x * cos_b) + (_rotate_half(x) * sin_b)

    return _rotate(q), _rotate(k)
58
-
59
-
60
  class Ideogram4MRoPE(nn.Module):
61
  """Multi-axis (t, h, w) interleaved rotary position embedding."""
62
 
@@ -74,7 +64,6 @@ class Ideogram4MRoPE(nn.Module):
74
  self.mrope_section = tuple(mrope_section)
75
  self.head_dim = head_dim
76
 
77
- @torch.no_grad()
78
  def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
79
  # position_ids: (B, L, 3) of int (axes are t, h, w).
80
  if position_ids.ndim != 3 or position_ids.shape[-1] != 3:
@@ -97,8 +86,49 @@ class Ideogram4MRoPE(nn.Module):
97
  return emb.cos(), emb.sin()
98
 
99
 
100
- class Ideogram4Attention(nn.Module):
101
- """Self-attention with merged QKV, q/k RMSNorm, MRoPE and block-diagonal mask."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
104
  super().__init__()
@@ -113,34 +143,23 @@ class Ideogram4Attention(nn.Module):
113
  self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
114
  self.o = nn.Linear(hidden_size, hidden_size, bias=False)
115
 
 
 
116
  def forward(
117
  self,
118
  hidden_states: torch.Tensor,
119
- segment_ids: torch.Tensor,
120
- cos: torch.Tensor,
121
- sin: torch.Tensor,
122
  ) -> torch.Tensor:
123
- batch_size, seq_len, _ = hidden_states.shape
124
-
125
- qkv = self.qkv(hidden_states).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
126
- q, k, v = qkv.unbind(dim=2)
127
-
128
- q = self.norm_q(q)
129
- k = self.norm_k(k)
130
-
131
- # SDPA expects (B, num_heads, L, head_dim).
132
- q = q.transpose(1, 2)
133
- k = k.transpose(1, 2)
134
- v = v.transpose(1, 2)
135
-
136
- q, k = _apply_rotary_pos_emb(q, k, cos, sin)
137
-
138
- # Block-diagonal mask from segment ids: tokens only attend within their segment.
139
- attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
140
-
141
- out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
142
- out = out.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
143
- return self.o(out)
144
 
145
 
146
  class Ideogram4MLP(nn.Module):
@@ -180,9 +199,8 @@ class Ideogram4TransformerBlock(nn.Module):
180
  def forward(
181
  self,
182
  hidden_states: torch.Tensor,
183
- segment_ids: torch.Tensor,
184
- cos: torch.Tensor,
185
- sin: torch.Tensor,
186
  adaln_input: torch.Tensor,
187
  ) -> torch.Tensor:
188
  mod = self.adaln_modulation(adaln_input)
@@ -194,9 +212,8 @@ class Ideogram4TransformerBlock(nn.Module):
194
 
195
  attn_out = self.attention(
196
  self.attention_norm1(hidden_states) * scale_msa,
197
- segment_ids=segment_ids,
198
- cos=cos,
199
- sin=sin,
200
  )
201
  hidden_states = hidden_states + gate_msa * self.attention_norm2(attn_out)
202
  hidden_states = hidden_states + gate_mlp * self.ffn_norm2(
@@ -251,7 +268,7 @@ class Ideogram4FinalLayer(nn.Module):
251
  return self.linear(self.norm_final(hidden_states) * scale)
252
 
253
 
254
- class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
255
  r"""
256
  The flow-matching transformer backbone used by the Ideogram 4 pipeline.
257
 
@@ -346,6 +363,19 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
346
  adaln_dim=adaln_dim,
347
  )
348
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  def forward(
350
  self,
351
  hidden_states: torch.Tensor,
@@ -377,19 +407,13 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
377
 
378
  Returns:
379
  [`~models.modeling_outputs.Transformer2DModelOutput`] or a `tuple` whose first element is a tensor of shape
380
- `(batch_size, sequence_length, in_channels)` in float32. Only positions tagged with
381
  `OUTPUT_IMAGE_INDICATOR` carry meaningful velocity predictions.
382
  """
383
  batch_size, seq_len, in_channels = hidden_states.shape
384
  if in_channels != self.in_channels:
385
  raise ValueError(f"Expected last dim {self.in_channels}, got {in_channels}.")
386
 
387
- param_dtype = self.dtype
388
- hidden_states = hidden_states.to(param_dtype)
389
- timestep = timestep.to(param_dtype)
390
- encoder_hidden_states = encoder_hidden_states.to(param_dtype)
391
-
392
- indicator = indicator.to(torch.long)
393
  llm_token_mask = (indicator == LLM_TOKEN_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
394
  output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
395
 
@@ -414,16 +438,20 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
414
  cos, sin = self.rotary_emb(position_ids)
415
  cos = cos.to(hidden_states.dtype)
416
  sin = sin.to(hidden_states.dtype)
 
 
 
 
417
 
418
  for block in self.layers:
419
  if torch.is_grad_enabled() and self.gradient_checkpointing:
420
  hidden_states = self._gradient_checkpointing_func(
421
- block, hidden_states, segment_ids, cos, sin, adaln_input
422
  )
423
  else:
424
- hidden_states = block(hidden_states, segment_ids, cos, sin, adaln_input)
425
 
426
- output = self.final_layer(hidden_states, conditioning=adaln_input).to(torch.float32)
427
 
428
  if not return_dict:
429
  return (output,)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ import inspect
16
  import math
17
 
18
  import torch
 
23
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
  from ...utils import logging
25
  from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import AttentionMixin, AttentionModuleMixin
27
+ from ..attention_dispatch import dispatch_attention_fn
28
  from ..modeling_outputs import Transformer2DModelOutput
29
  from ..modeling_utils import ModelMixin
30
  from ..normalization import RMSNorm
 
47
  return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  class Ideogram4MRoPE(nn.Module):
51
  """Multi-axis (t, h, w) interleaved rotary position embedding."""
52
 
 
64
  self.mrope_section = tuple(mrope_section)
65
  self.head_dim = head_dim
66
 
 
67
  def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
68
  # position_ids: (B, L, 3) of int (axes are t, h, w).
69
  if position_ids.ndim != 3 or position_ids.shape[-1] != 3:
 
86
  return emb.cos(), emb.sin()
87
 
88
 
89
class Ideogram4AttnProcessor:
    """Default attention processor for `Ideogram4Attention`.

    Projects the input through the module's merged `qkv` linear layer, applies
    the module's q/k norms, optionally rotates query and key with the supplied
    rotary tables, and hands the result to `dispatch_attention_fn`.
    """

    # Class-level defaults that are forwarded verbatim to `dispatch_attention_fn`.
    _attention_backend = None
    _parallel_config = None

    def __call__(
        self,
        attn: "Ideogram4Attention",
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Compute self-attention for one block.

        Args:
            attn: The attention module that owns `qkv`, `norm_q`, `norm_k` and `o`.
            hidden_states: Input of shape `(batch, seq_len, hidden)`.
            attention_mask: Optional mask forwarded unchanged as `attn_mask`.
            image_rotary_emb: Optional `(cos, sin)` pair. When `None`, the rotary
                step is skipped instead of failing on tuple unpacking, which keeps
                this processor usable with the `None` defaults of
                `Ideogram4Attention.forward`.

        Returns:
            The output of `attn.o` applied to the attention result.
        """
        batch_size, seq_len, _ = hidden_states.shape

        qkv = attn.qkv(hidden_states).view(batch_size, seq_len, 3, attn.num_heads, attn.head_dim)
        query, key, value = qkv.unbind(dim=2)

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            # MRoPE applied in (B, L, num_heads, head_dim) layout; cos/sin broadcast over the head axis.
            cos, sin = image_rotary_emb
            cos = cos.unsqueeze(2)
            sin = sin.unsqueeze(2)
            query = (query * cos) + (_rotate_half(query) * sin)
            key = (key * cos) + (_rotate_half(key) * sin)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge the head and head_dim axes back into a single feature axis.
        hidden_states = hidden_states.flatten(2, 3)
        return attn.o(hidden_states)
125
+
126
+
127
+ class Ideogram4Attention(nn.Module, AttentionModuleMixin):
128
+ """Self-attention with merged QKV, q/k RMSNorm, MRoPE and a block-diagonal segment mask."""
129
+
130
+ _default_processor_cls = Ideogram4AttnProcessor
131
+ _available_processors = [Ideogram4AttnProcessor]
132
 
133
  def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
134
  super().__init__()
 
143
  self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
144
  self.o = nn.Linear(hidden_size, hidden_size, bias=False)
145
 
146
+ self.set_processor(self._default_processor_cls())
147
+
148
  def forward(
149
  self,
150
  hidden_states: torch.Tensor,
151
+ attention_mask: torch.Tensor | None = None,
152
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
153
+ **kwargs,
154
  ) -> torch.Tensor:
155
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
156
+ unused_kwargs = [k for k in kwargs if k not in attn_parameters]
157
+ if len(unused_kwargs) > 0:
158
+ logger.warning(
159
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
160
+ )
161
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
162
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  class Ideogram4MLP(nn.Module):
 
199
  def forward(
200
  self,
201
  hidden_states: torch.Tensor,
202
+ attention_mask: torch.Tensor,
203
+ image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
 
204
  adaln_input: torch.Tensor,
205
  ) -> torch.Tensor:
206
  mod = self.adaln_modulation(adaln_input)
 
212
 
213
  attn_out = self.attention(
214
  self.attention_norm1(hidden_states) * scale_msa,
215
+ attention_mask=attention_mask,
216
+ image_rotary_emb=image_rotary_emb,
 
217
  )
218
  hidden_states = hidden_states + gate_msa * self.attention_norm2(attn_out)
219
  hidden_states = hidden_states + gate_mlp * self.ffn_norm2(
 
268
  return self.linear(self.norm_final(hidden_states) * scale)
269
 
270
 
271
+ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin):
272
  r"""
273
  The flow-matching transformer backbone used by the Ideogram 4 pipeline.
274
 
 
363
  adaln_dim=adaln_dim,
364
  )
365
 
366
+ def fuse_qkv_projections(self):
367
+ # The attention already uses a single fused `qkv` projection, so there is nothing to fuse.
368
+ raise NotImplementedError(
369
+ "Ideogram4Transformer2DModel already uses a fused QKV projection (`attention.qkv`), "
370
+ "so `fuse_qkv_projections()` is not applicable."
371
+ )
372
+
373
+ def unfuse_qkv_projections(self):
374
+ raise NotImplementedError(
375
+ "Ideogram4Transformer2DModel uses a fused QKV projection that cannot be split, "
376
+ "so `unfuse_qkv_projections()` is not applicable."
377
+ )
378
+
379
  def forward(
380
  self,
381
  hidden_states: torch.Tensor,
 
407
 
408
  Returns:
409
  [`~models.modeling_outputs.Transformer2DModelOutput`] or a `tuple` whose first element is a tensor of shape
410
+ `(batch_size, sequence_length, in_channels)` in the model's compute dtype. Only positions tagged with
411
  `OUTPUT_IMAGE_INDICATOR` carry meaningful velocity predictions.
412
  """
413
  batch_size, seq_len, in_channels = hidden_states.shape
414
  if in_channels != self.in_channels:
415
  raise ValueError(f"Expected last dim {self.in_channels}, got {in_channels}.")
416
 
 
 
 
 
 
 
417
  llm_token_mask = (indicator == LLM_TOKEN_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
418
  output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
419
 
 
438
  cos, sin = self.rotary_emb(position_ids)
439
  cos = cos.to(hidden_states.dtype)
440
  sin = sin.to(hidden_states.dtype)
441
+ image_rotary_emb = (cos, sin)
442
+
443
+ # Block-diagonal mask from segment ids: tokens only attend within their segment. Shared by every block.
444
+ attention_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
445
 
446
  for block in self.layers:
447
  if torch.is_grad_enabled() and self.gradient_checkpointing:
448
  hidden_states = self._gradient_checkpointing_func(
449
+ block, hidden_states, attention_mask, image_rotary_emb, adaln_input
450
  )
451
  else:
452
+ hidden_states = block(hidden_states, attention_mask, image_rotary_emb, adaln_input)
453
 
454
+ output = self.final_layer(hidden_states, conditioning=adaln_input)
455
 
456
  if not return_dict:
457
  return (output,)
diffusers_src/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py CHANGED
@@ -29,10 +29,11 @@ from ...models.transformers.transformer_ideogram4 import (
29
  Ideogram4Transformer2DModel,
30
  )
31
  from ...schedulers import FlowMatchEulerDiscreteScheduler
32
- from ...utils import logging, replace_example_docstring
33
  from ...utils.torch_utils import randn_tensor
34
  from ..pipeline_utils import DiffusionPipeline
35
  from .pipeline_output import Ideogram4PipelineOutput
 
36
 
37
 
38
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -42,10 +43,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
  # text conditioning consumed by the Ideogram4 transformer.
43
  QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35)
44
 
45
- DEFAULT_NUM_INFERENCE_STEPS = 48
46
- DEFAULT_GUIDANCE_SCHEDULE = (7.0,) * 45 + (3.0,) * 3
47
- DEFAULT_MU = 0.0
48
- DEFAULT_STD = 1.5
49
 
50
 
51
  EXAMPLE_DOC_STRING = """
@@ -109,6 +109,32 @@ def _resolution_aware_mu(
109
  return base_mu + 0.5 * math.log(num_pixels / base_pixels)
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  class Ideogram4Pipeline(DiffusionPipeline):
113
  r"""
114
  Text-to-image pipeline for Ideogram4.
@@ -165,38 +191,110 @@ class Ideogram4Pipeline(DiffusionPipeline):
165
  self.patch_size = 2
166
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * self.patch_size)
167
 
168
- @property
169
- def _patch_dim(self) -> int:
170
- return self.vae_scale_factor * self.patch_size
171
-
172
- def _tokenize(self, prompt: str, max_text_tokens: int) -> tuple[torch.Tensor, int]:
173
- """Build chat-formatted token ids for a single prompt."""
174
- messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
175
- text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
176
- encoded = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)
177
- token_ids = encoded["input_ids"][0]
178
- num_text_tokens = int(token_ids.shape[0])
179
- if num_text_tokens > max_text_tokens:
180
- raise ValueError(f"prompt has {num_text_tokens} tokens, exceeds max_sequence_length={max_text_tokens}")
181
- return token_ids, num_text_tokens
182
-
183
- def _build_inputs(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  self,
185
- prompts: list[str],
186
- height: int,
187
- width: int,
188
  max_text_tokens: int,
189
  device: torch.device,
190
- ) -> dict[str, torch.Tensor]:
191
- """Build the packed sequence (left-padded text tokens then image tokens) for one batch."""
192
- tokenized = [self._tokenize(p, max_text_tokens) for p in prompts]
193
- batch_size = len(prompts)
194
 
195
- patch = self._patch_dim
196
- if height % patch != 0 or width % patch != 0:
197
- raise ValueError(f"height/width must be divisible by vae_scale_factor*patch_size={patch}")
198
- grid_h = height // patch
199
- grid_w = width // patch
200
  num_image_tokens = grid_h * grid_w
201
  total_seq_len = max_text_tokens + num_image_tokens
202
 
@@ -206,21 +304,15 @@ class Ideogram4Pipeline(DiffusionPipeline):
206
  t_idx = torch.zeros_like(h_idx)
207
  image_pos = torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET
208
 
209
- token_ids = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
210
- text_position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
211
  position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
212
  segment_ids = torch.full((batch_size, total_seq_len), SEQUENCE_PADDING_INDICATOR, dtype=torch.long)
213
  indicator = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
214
 
215
- for b, (toks, num_text) in enumerate(tokenized):
216
- pad_len = max_text_tokens - num_text
217
- offset = pad_len
218
-
219
- token_ids[b, offset : offset + num_text] = toks
220
 
221
  text_pos = torch.arange(num_text)
222
  text_pos_3d = torch.stack([text_pos, text_pos, text_pos], dim=1)
223
- text_position_ids[b, offset : offset + num_text] = text_pos_3d
224
  position_ids[b, offset : offset + num_text] = text_pos_3d
225
  position_ids[b, offset + num_text :] = image_pos
226
 
@@ -229,16 +321,7 @@ class Ideogram4Pipeline(DiffusionPipeline):
229
 
230
  segment_ids[b, offset : offset + num_text + num_image_tokens] = 1
231
 
232
- return {
233
- "token_ids": token_ids.to(device),
234
- "text_position_ids": text_position_ids.to(device),
235
- "position_ids": position_ids.to(device),
236
- "segment_ids": segment_ids.to(device),
237
- "indicator": indicator.to(device),
238
- "num_image_tokens": num_image_tokens,
239
- "grid_h": grid_h,
240
- "grid_w": grid_w,
241
- }
242
 
243
  def _get_text_encoder_hidden_states(
244
  self,
@@ -283,28 +366,60 @@ class Ideogram4Pipeline(DiffusionPipeline):
283
 
284
  def encode_prompt(
285
  self,
286
- prompts: list[str],
287
- token_ids: torch.Tensor,
288
- text_position_ids: torch.Tensor,
289
- indicator: torch.Tensor,
290
  device: torch.device,
291
- ) -> torch.Tensor:
292
- """Encode prompts using the text encoder and stack hidden states from the activation layers."""
293
- batch_size, seq_len = token_ids.shape
294
 
295
- attention_mask = (indicator == LLM_TOKEN_INDICATOR).to(torch.long)
296
- pos_2d = text_position_ids[..., 0].contiguous()
 
 
 
 
297
 
298
- with torch.no_grad():
299
- selected = self._get_text_encoder_hidden_states(token_ids, attention_mask, pos_2d)
300
- stacked = torch.stack(selected, dim=0) # (num_taps, B, L, H)
301
- stacked = stacked.permute(1, 2, 3, 0)
302
- stacked = stacked.reshape(batch_size, seq_len, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
- # Zero out non-LLM positions so the transformer only sees real text features.
305
- text_mask = attention_mask.to(stacked.dtype).unsqueeze(-1)
306
- stacked = stacked * text_mask
307
- return stacked.to(torch.float32)
308
 
309
  def prepare_latents(
310
  self,
@@ -325,27 +440,6 @@ class Ideogram4Pipeline(DiffusionPipeline):
325
  latents = latents.to(device=device, dtype=dtype)
326
  return latents
327
 
328
- def _decode(self, z: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
329
- """Unpatch latents, denormalize with the VAE batch-norm stats, and decode through the VAE."""
330
- batch_size = z.shape[0]
331
- patch = self.patch_size
332
-
333
- # VAE bn stores per-channel statistics on the packed-channel latent space (ae_channels * patch ** 2).
334
- bn_mean = self.vae.bn.running_mean.view(1, 1, -1).to(device=z.device, dtype=z.dtype)
335
- bn_std = torch.sqrt(self.vae.bn.running_var + self.vae.config.batch_norm_eps).view(1, 1, -1)
336
- bn_std = bn_std.to(device=z.device, dtype=z.dtype)
337
-
338
- z = z * bn_std + bn_mean
339
-
340
- ae_channels = z.shape[-1] // (patch * patch)
341
- z = z.view(batch_size, grid_h, grid_w, patch, patch, ae_channels)
342
- z = z.permute(0, 5, 1, 3, 2, 4).contiguous()
343
- z = z.view(batch_size, ae_channels, grid_h * patch, grid_w * patch)
344
-
345
- z = z.to(self.vae.dtype)
346
- image = self.vae.decode(z, return_dict=False)[0]
347
- return image
348
-
349
  @property
350
  def guidance_scale(self) -> float | None:
351
  return self._guidance_scale
@@ -358,6 +452,50 @@ class Ideogram4Pipeline(DiffusionPipeline):
358
  def interrupt(self) -> bool:
359
  return self._interrupt
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  @torch.no_grad()
362
  @replace_example_docstring(EXAMPLE_DOC_STRING)
363
  def __call__(
@@ -365,11 +503,12 @@ class Ideogram4Pipeline(DiffusionPipeline):
365
  prompt: str | list[str] | None = None,
366
  height: int = 2048,
367
  width: int = 2048,
368
- num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
369
  guidance_scale: float | None = None,
370
- guidance_schedule: list[float] | torch.Tensor | None = DEFAULT_GUIDANCE_SCHEDULE,
371
- mu: float = DEFAULT_MU,
372
- std: float = DEFAULT_STD,
 
373
  max_sequence_length: int = 2048,
374
  num_images_per_prompt: int = 1,
375
  generator: torch.Generator | list[torch.Generator] | None = None,
@@ -377,7 +516,7 @@ class Ideogram4Pipeline(DiffusionPipeline):
377
  output_type: str = "pil",
378
  return_dict: bool = True,
379
  callback_on_step_end: Callable[["Ideogram4Pipeline", int, int, dict[str, Any]], dict[str, Any]] | None = None,
380
- callback_on_step_end_tensor_inputs: list[str] | None = None,
381
  ) -> Ideogram4PipelineOutput | tuple[Any]:
382
  r"""
383
  Run text-to-image generation.
@@ -396,16 +535,19 @@ class Ideogram4Pipeline(DiffusionPipeline):
396
  velocity predictions are blended as `v = guidance_scale * v_pos + (1 - guidance_scale) * v_neg`.
397
  Mutually exclusive with `guidance_schedule` (setting both raises). Defaults to `None`.
398
  guidance_schedule (`list[float]` or `torch.Tensor`, *optional*):
399
- Per-step guidance scale schedule; must have length `num_inference_steps`. The first entry corresponds to
400
- the first step (largest noise level). Mutually exclusive with `guidance_scale` (setting both raises).
401
- Exactly one of `guidance_scale` and `guidance_schedule` must be set; leaving both unset raises. The
402
- recommended schedule for best quality is `DEFAULT_GUIDANCE_SCHEDULE` (7.0 for the main steps, dropping
403
- to 3.0 for the final 3 "polish" steps).
404
  mu (`float`, *optional*, defaults to 0.0):
405
  Base mean of the logit-normal flow-matching schedule. The schedule mean is shifted by half the log of
406
  the resolution ratio relative to 512x512.
407
  std (`float`, *optional*, defaults to 1.5):
408
  Standard deviation of the logit-normal flow-matching schedule.
 
 
 
 
409
  max_sequence_length (`int`, *optional*, defaults to 2048):
410
  Maximum number of text tokens per prompt.
411
  num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -428,66 +570,63 @@ class Ideogram4Pipeline(DiffusionPipeline):
428
  Returns:
429
  [`~pipelines.ideogram4.Ideogram4PipelineOutput`] or `tuple`.
430
  """
431
- if prompt is None:
432
- raise ValueError("`prompt` must be provided.")
433
- if isinstance(prompt, str):
434
- prompts = [prompt]
435
- else:
436
- prompts = list(prompt)
437
-
438
- if num_inference_steps <= 0:
439
- raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
440
- if num_images_per_prompt <= 0:
441
- raise ValueError(f"`num_images_per_prompt` must be > 0, got {num_images_per_prompt}.")
442
- if guidance_scale is not None and guidance_schedule is not None:
443
- raise ValueError(
444
- "Only one of `guidance_scale` and `guidance_schedule` may be set."
445
- )
446
- if guidance_scale is None and guidance_schedule is None:
447
- raise ValueError(
448
- "One of `guidance_scale` and `guidance_schedule` must be set."
449
- )
450
-
451
- callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs or ["latents"]
452
 
453
- patch = self._patch_dim
454
- if height % patch != 0 or width % patch != 0:
455
- raise ValueError(
456
- f"`height` ({height}) and `width` ({width}) must both be divisible by {patch} "
457
- f"(vae_scale_factor * patch_size)."
458
- )
459
 
460
  device = self._execution_device
461
  self._guidance_scale = guidance_scale
462
  self._interrupt = False
463
 
464
- # 1. Build packed input layout shared by the conditional pass.
465
- inputs = self._build_inputs(
466
- prompts=prompts,
467
- height=height,
468
- width=width,
469
- max_text_tokens=max_sequence_length,
470
- device=device,
 
471
  )
472
- batch_size = len(prompts)
473
- num_image_tokens = inputs["num_image_tokens"]
474
- grid_h, grid_w = inputs["grid_h"], inputs["grid_w"]
475
-
476
- # 2. Encode prompts.
477
- llm_features = self.encode_prompt(
478
- prompts=prompts,
479
- token_ids=inputs["token_ids"],
480
- text_position_ids=inputs["text_position_ids"],
481
- indicator=inputs["indicator"],
482
  device=device,
483
  )
484
 
485
- # 3. Replicate per-prompt tensors for num_images_per_prompt.
486
- if num_images_per_prompt > 1:
487
- llm_features = llm_features.repeat_interleave(num_images_per_prompt, dim=0)
488
- for key in ("position_ids", "segment_ids", "indicator"):
489
- inputs[key] = inputs[key].repeat_interleave(num_images_per_prompt, dim=0)
490
- effective_batch_size = batch_size * num_images_per_prompt
 
 
 
 
 
 
 
 
 
 
 
 
491
 
492
  # 4. Set up the resolution-aware logit-normal schedule on the scheduler.
493
  schedule_mu = _resolution_aware_mu(height=height, width=width, base_mu=mu)
@@ -496,21 +635,16 @@ class Ideogram4Pipeline(DiffusionPipeline):
496
  timesteps = self.scheduler.timesteps
497
  self._num_timesteps = len(timesteps)
498
 
499
- # 5. Resolve per-step guidance weights. A constant `guidance_scale` takes one path; otherwise use the
500
- # `guidance_schedule`. Exactly one of the two is set (validated above).
501
  if guidance_scale is not None:
502
- gw = torch.full((num_inference_steps,), float(guidance_scale), dtype=torch.float32, device=device)
503
- else:
504
- gw = torch.as_tensor(guidance_schedule, dtype=torch.float32, device=device)
505
- if gw.shape != (num_inference_steps,):
506
- raise ValueError(
507
- f"`guidance_schedule` must have shape ({num_inference_steps},), got {tuple(gw.shape)}"
508
- )
509
 
510
  # 6. Prepare latents in the packed (B, num_image_tokens, latent_dim) layout.
511
  latent_dim = self.transformer.config.in_channels
512
  latents = self.prepare_latents(
513
- batch_size=effective_batch_size,
514
  num_image_tokens=num_image_tokens,
515
  latent_dim=latent_dim,
516
  dtype=torch.float32,
@@ -519,27 +653,21 @@ class Ideogram4Pipeline(DiffusionPipeline):
519
  latents=latents,
520
  )
521
 
522
- # 7. Pre-compute the inputs for the unconditional (image-only) branch.
523
  max_text_tokens = max_sequence_length
524
- neg_position_ids = inputs["position_ids"][:, max_text_tokens:]
525
- neg_segment_ids = inputs["segment_ids"][:, max_text_tokens:]
526
- neg_indicator = inputs["indicator"][:, max_text_tokens:]
527
- neg_llm_features = torch.zeros(
528
- effective_batch_size,
529
- num_image_tokens,
530
- llm_features.shape[-1],
531
- dtype=llm_features.dtype,
532
- device=device,
533
- )
534
-
535
  text_z_padding = torch.zeros(
536
- effective_batch_size,
537
  max_text_tokens,
538
  latent_dim,
539
  dtype=torch.float32,
540
  device=device,
541
  )
542
 
 
 
 
 
 
543
  # 8. Denoising loop. The scheduler stores `num_train_timesteps`-scaled timesteps; convert back to model time.
544
  num_train_timesteps = self.scheduler.config.num_train_timesteps
545
  with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -549,36 +677,40 @@ class Ideogram4Pipeline(DiffusionPipeline):
549
 
550
  # Map sigma-domain timestep to model time `t` in [0, 1] (0 = noise, 1 = clean data).
551
  t_model = 1.0 - (t.float() / num_train_timesteps)
552
- t_model = t_model.expand(effective_batch_size).to(self.transformer.dtype)
553
 
554
  # Conditional pass operates on the full packed sequence.
555
- pos_z = torch.cat([text_z_padding, latents], dim=1)
556
  pos_out = self.transformer(
557
  hidden_states=pos_z,
558
  timestep=t_model,
559
  encoder_hidden_states=llm_features,
560
- position_ids=inputs["position_ids"],
561
- segment_ids=inputs["segment_ids"],
562
- indicator=inputs["indicator"],
563
  return_dict=False,
564
  )[0]
565
- pos_v = pos_out[:, max_text_tokens:]
 
 
566
 
567
  # Unconditional pass uses image-only positions with zeroed text features.
568
  neg_v = self.unconditional_transformer(
569
- hidden_states=latents,
570
  timestep=t_model,
571
  encoder_hidden_states=neg_llm_features,
572
  position_ids=neg_position_ids,
573
  segment_ids=neg_segment_ids,
574
  indicator=neg_indicator,
575
  return_dict=False,
576
- )[0]
577
 
 
 
578
  gw_i = gw[i]
579
  v = gw_i * pos_v + (1.0 - gw_i) * neg_v
580
 
581
- latents = self.scheduler.step(-v.to(torch.float32), t, latents, return_dict=False)[0]
582
 
583
  if callback_on_step_end is not None:
584
  callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
@@ -587,11 +719,24 @@ class Ideogram4Pipeline(DiffusionPipeline):
587
 
588
  progress_bar.update()
589
 
590
- # 9. Decode.
591
  if output_type == "latent":
592
  image = latents
593
  else:
594
- decoded = self._decode(latents, grid_h=grid_h, grid_w=grid_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  image = self.image_processor.postprocess(decoded.float(), output_type=output_type)
596
 
597
  self.maybe_free_model_hooks()
 
29
  Ideogram4Transformer2DModel,
30
  )
31
  from ...schedulers import FlowMatchEulerDiscreteScheduler
32
+ from ...utils import is_outlines_available, logging, replace_example_docstring
33
  from ...utils.torch_utils import randn_tensor
34
  from ..pipeline_utils import DiffusionPipeline
35
  from .pipeline_output import Ideogram4PipelineOutput
36
+ from .prompt_enhancer import CAPTION_SYSTEM_MESSAGE, CAPTION_USER_TEMPLATE, build_caption_logits_processor
37
 
38
 
39
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
43
  # text conditioning consumed by the Ideogram4 transformer.
44
  QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35)
45
 
46
+ # LM head grafted onto the (head-less) text encoder for optional prompt upsampling.
47
+ DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO = "multimodalart/qwen3-vl-8b-instruct-lm-head"
48
+ PROMPT_UPSAMPLE_TEMPERATURE = 1.0
 
49
 
50
 
51
  EXAMPLE_DOC_STRING = """
 
109
  return base_mu + 0.5 * math.log(num_pixels / base_pixels)
110
 
111
 
112
+ def _expand_tensor_to_effective_batch(
113
+ tensor: torch.Tensor,
114
+ batch_size: int,
115
+ num_per_prompt: int,
116
+ tensor_name: str | None = None,
117
+ ) -> torch.Tensor:
118
+ """Replicate `tensor` along dim 0 from `batch_size` (or 1) to `batch_size * num_per_prompt`."""
119
+ target_batch_size = batch_size * num_per_prompt
120
+
121
+ if tensor.shape[0] == target_batch_size:
122
+ return tensor
123
+
124
+ if tensor.shape[0] == 1:
125
+ repeat_by = target_batch_size
126
+ elif tensor.shape[0] == batch_size:
127
+ repeat_by = num_per_prompt
128
+ else:
129
+ tensor_name = f"`{tensor_name}`" if tensor_name is not None else "Tensor"
130
+ raise ValueError(
131
+ f"{tensor_name} batch size must be 1, `batch_size` ({batch_size}), or "
132
+ f"`batch_size * num_*_per_prompt` ({target_batch_size}), but got {tensor.shape[0]}."
133
+ )
134
+
135
+ return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by)
136
+
137
+
138
  class Ideogram4Pipeline(DiffusionPipeline):
139
  r"""
140
  Text-to-image pipeline for Ideogram4.
 
191
  self.patch_size = 2
192
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * self.patch_size)
193
 
194
+ # Lazily built by `load_prompt_enhancer` for optional prompt upsampling.
195
+ self._caption_model = None
196
+ self._caption_logits_processor = None
197
+
198
def load_prompt_enhancer(
    self,
    lm_head_repo_id: str = DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO,
    lm_head_filename: str = "lm_head.safetensors",
    torch_dtype: torch.dtype | None = None,
) -> PreTrainedModel:
    """Turn the pipeline's text encoder into a generative captioner by attaching a downloaded LM head.

    A `Qwen3VLForConditionalGeneration` shell is created without allocating weights, its body is replaced by the
    already-loaded `self.text_encoder`, and a linear LM head is filled from the `lm_head.weight` tensor stored in
    `lm_head_filename` of `lm_head_repo_id`. When `outlines` is available a schema-constraining logits processor is
    built as well; otherwise a warning is logged and decoding stays unconstrained.

    Args:
        lm_head_repo_id: Hub repository that hosts the LM head weights.
        lm_head_filename: Safetensors file inside the repository that contains `lm_head.weight`.
        torch_dtype: dtype for the head; falls back to the text encoder's dtype.

    Returns:
        The generative caption model, which is also stored on `self._caption_model` (with the processor stored on
        `self._caption_logits_processor`).
    """
    from accelerate import init_empty_weights
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import Qwen3VLForConditionalGeneration

    encoder = self.text_encoder
    dtype = torch_dtype if torch_dtype else encoder.dtype

    # Fetch the hosted head and pull out the single weight matrix it carries.
    checkpoint_path = hf_hub_download(lm_head_repo_id, lm_head_filename)
    head_weight = load_file(checkpoint_path)["lm_head.weight"].to(dtype)

    # Build the wrapper without materialising parameters, then swap in the real encoder body.
    with init_empty_weights():
        caption_model = Qwen3VLForConditionalGeneration(encoder.config)
    caption_model.model = encoder

    # The head must be a real (non-meta) module, so it is created outside the empty-weights context.
    out_features, in_features = head_weight.shape[0], head_weight.shape[1]
    lm_head = torch.nn.Linear(in_features, out_features, bias=False)
    with torch.no_grad():
        lm_head.weight.copy_(head_weight)
    caption_model.lm_head = lm_head.to(device=encoder.device, dtype=dtype)
    caption_model.eval()

    logits_processor = None
    if is_outlines_available():
        logits_processor = build_caption_logits_processor(caption_model, self.tokenizer)
    else:
        logger.warning(
            "`outlines` is not installed; prompt upsampling will run unconstrained and may not return "
            "schema-valid JSON. Install with `pip install outlines` for structured captions."
        )

    self._caption_model = caption_model
    self._caption_logits_processor = logits_processor
    return caption_model
239
+
240
def upsample_prompt(
    self,
    prompt: str | list[str],
    height: int = 2048,
    width: int = 2048,
    max_new_tokens: int = 1024,
    lm_head_repo_id: str = DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO,
    device: torch.device | None = None,
) -> list[str]:
    """Expand each prompt into a structured JSON caption using the generative text encoder.

    The prompt enhancer is loaded on first use. Every prompt is generated independently with sampling enabled;
    when a caption logits processor is present it is reset and applied to each generation.

    Args:
        prompt: One prompt or a list of prompts.
        height: Target image height, used only to derive the aspect ratio given to the model.
        width: Target image width, used only to derive the aspect ratio given to the model.
        max_new_tokens: Generation budget per caption.
        lm_head_repo_id: Repository of the LM head; only consulted if the enhancer is not loaded yet.
        device: Device for the tokenized inputs; defaults to the caption model's device.

    Returns:
        One decoded, whitespace-stripped caption per input prompt, in order.
    """
    if self._caption_model is None:
        self.load_prompt_enhancer(lm_head_repo_id=lm_head_repo_id)

    caption_model = self._caption_model
    processor = self._caption_logits_processor
    target_device = device or caption_model.device

    if isinstance(prompt, str):
        ideas = [prompt]
    else:
        ideas = list(prompt)

    # Reduce width:height to lowest terms; `or 1` guards the gcd(0, 0) == 0 case.
    common = math.gcd(width, height) or 1
    aspect_ratio = f"{width // common}:{height // common}"

    # Sampling settings shared by every generation call.
    sampling_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": PROMPT_UPSAMPLE_TEMPERATURE,
        "use_cache": True,
    }

    captions = []
    for idea in ideas:
        user_turn = CAPTION_USER_TEMPLATE.format(aspect_ratio=aspect_ratio, original_prompt=idea)
        chat = [
            {"role": "system", "content": CAPTION_SYSTEM_MESSAGE},
            {"role": "user", "content": user_turn},
        ]
        encoded = self.tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
        ).to(target_device)

        call_kwargs = dict(sampling_kwargs)
        if processor is not None:
            # The constrained-decoding processor is stateful, so start it fresh for each prompt.
            processor.reset()
            call_kwargs["logits_processor"] = [processor]

        output_ids = caption_model.generate(**encoded, **call_kwargs)

        # Drop the echoed prompt tokens and keep only the newly generated continuation.
        prompt_length = encoded["input_ids"].shape[1]
        completion = output_ids[0, prompt_length:]
        captions.append(self.tokenizer.decode(completion, skip_special_tokens=True).strip())
    return captions
283
+
284
+ def _prepare_ids(
285
  self,
286
+ text_lengths: list[int],
287
+ grid_h: int,
288
+ grid_w: int,
289
  max_text_tokens: int,
290
  device: torch.device,
291
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
292
+ """Build the packed `[left-pad][text][image]` layout from the per-prompt text lengths and the image grid.
 
 
293
 
294
+ Returns `position_ids` (3-axis MRoPE), `segment_ids` (block-diagonal attention) and `indicator` (per-token
295
+ text/image/pad role).
296
+ """
297
+ batch_size = len(text_lengths)
 
298
  num_image_tokens = grid_h * grid_w
299
  total_seq_len = max_text_tokens + num_image_tokens
300
 
 
304
  t_idx = torch.zeros_like(h_idx)
305
  image_pos = torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET
306
 
 
 
307
  position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
308
  segment_ids = torch.full((batch_size, total_seq_len), SEQUENCE_PADDING_INDICATOR, dtype=torch.long)
309
  indicator = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
310
 
311
+ for b, num_text in enumerate(text_lengths):
312
+ offset = max_text_tokens - num_text
 
 
 
313
 
314
  text_pos = torch.arange(num_text)
315
  text_pos_3d = torch.stack([text_pos, text_pos, text_pos], dim=1)
 
316
  position_ids[b, offset : offset + num_text] = text_pos_3d
317
  position_ids[b, offset + num_text :] = image_pos
318
 
 
321
 
322
  segment_ids[b, offset : offset + num_text + num_image_tokens] = 1
323
 
324
+ return position_ids.to(device), segment_ids.to(device), indicator.to(device)
 
 
 
 
 
 
 
 
 
325
 
326
  def _get_text_encoder_hidden_states(
327
  self,
 
366
 
367
def encode_prompt(
    self,
    prompt: str | list[str],
    grid_h: int,
    grid_w: int,
    max_sequence_length: int,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the conditioning tensors for the packed text+image sequence, one row per prompt.

    Each prompt is wrapped in a single-turn chat template, tokenized, and left-padded to `max_sequence_length`.
    The selected text-encoder hidden states are concatenated per token, padding positions are zeroed, and zero
    features are appended for the `grid_h * grid_w` image positions.

    Args:
        prompt: One prompt or a list of prompts.
        grid_h: Image grid height in tokens.
        grid_w: Image grid width in tokens.
        max_sequence_length: Length of the text region; prompts longer than this are rejected.
        device: Device on which the returned tensors live.

    Returns:
        `(prompt_embeds, position_ids, segment_ids, indicator)`.

    Raises:
        ValueError: If a tokenized prompt exceeds `max_sequence_length`.
    """
    prompt_list = [prompt] if isinstance(prompt, str) else list(prompt)
    num_prompts = len(prompt_list)

    # Left-padded buffers for the text region only; image tokens follow the text and are not fed to the encoder.
    buffer_shape = (num_prompts, max_sequence_length)
    token_ids = torch.zeros(buffer_shape, dtype=torch.long)
    attention_mask = torch.zeros(buffer_shape, dtype=torch.long)
    text_position_ids = torch.zeros(buffer_shape, dtype=torch.long)

    text_lengths = []
    for row, idea in enumerate(prompt_list):
        chat = [{"role": "user", "content": [{"type": "text", "text": idea}]}]
        rendered = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        ids = self.tokenizer(rendered, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
        length = int(ids.shape[0])
        if length > max_sequence_length:
            raise ValueError(f"prompt has {length} tokens, exceeds max_sequence_length={max_sequence_length}")
        text_lengths.append(length)

        # Real tokens occupy the tail of the row; everything before `start` stays as padding.
        start = max_sequence_length - length
        token_ids[row, start:] = ids
        attention_mask[row, start:] = 1
        text_position_ids[row, start:] = torch.arange(length)

    token_ids = token_ids.to(device)
    attention_mask = attention_mask.to(device)
    text_position_ids = text_position_ids.to(device)

    # Merge the tapped layers into one feature vector per token, then blank out the padding rows.
    layer_states = self._get_text_encoder_hidden_states(token_ids, attention_mask, text_position_ids)
    stacked = torch.stack(layer_states, dim=0).permute(1, 2, 3, 0)
    text_features = stacked.reshape(num_prompts, max_sequence_length, -1)
    keep = attention_mask.to(text_features.dtype).unsqueeze(-1)
    text_features = (text_features * keep).to(torch.float32)

    position_ids, segment_ids, indicator = self._prepare_ids(
        text_lengths, grid_h, grid_w, max_sequence_length, device
    )

    # Image positions carry no text features, so they are filled with zeros.
    num_image_tokens = grid_h * grid_w
    image_feature_padding = torch.zeros(
        num_prompts, num_image_tokens, text_features.shape[-1], dtype=text_features.dtype, device=device
    )
    prompt_embeds = torch.cat([text_features, image_feature_padding], dim=1)

    return prompt_embeds, position_ids, segment_ids, indicator
 
 
 
423
 
424
  def prepare_latents(
425
  self,
 
440
  latents = latents.to(device=device, dtype=dtype)
441
  return latents
442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  @property
444
  def guidance_scale(self) -> float | None:
445
  return self._guidance_scale
 
452
  def interrupt(self) -> bool:
453
  return self._interrupt
454
 
455
def check_inputs(
    self,
    prompt,
    height,
    width,
    num_inference_steps,
    guidance_scale,
    guidance_schedule,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the call arguments and raise `ValueError` on the first problem found.

    Checks, in order: `prompt` is a `str` or `list`; `height` and `width` are multiples of
    `vae_scale_factor * patch_size`; exactly one of `guidance_scale` / `guidance_schedule` is set and a schedule
    has `num_inference_steps` entries; every requested callback tensor name is listed in
    `self._callback_tensor_inputs`.
    """
    if prompt is None:
        raise ValueError("`prompt` must be provided.")
    if not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # Pixel dimensions must map onto a whole number of patch tokens.
    multiple = self.vae_scale_factor * self.patch_size
    if height % multiple != 0 or width % multiple != 0:
        raise ValueError(
            f"`height` ({height}) and `width` ({width}) must both be divisible by {multiple} "
            f"(vae_scale_factor * patch_size)."
        )

    # Guidance comes from either a constant scale or a per-step schedule, never both and never neither.
    has_scale = guidance_scale is not None
    has_schedule = guidance_schedule is not None
    if has_scale and has_schedule:
        raise ValueError("Only one of `guidance_scale` and `guidance_schedule` may be set.")
    if not (has_scale or has_schedule):
        raise ValueError("One of `guidance_scale` and `guidance_schedule` must be set.")
    if has_schedule:
        schedule_length = len(guidance_schedule)
        if schedule_length != num_inference_steps:
            raise ValueError(
                f"`guidance_schedule` must have length `num_inference_steps` ({num_inference_steps}), "
                f"got {schedule_length}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
        if unknown:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
                f"{unknown}"
            )
498
+
499
  @torch.no_grad()
500
  @replace_example_docstring(EXAMPLE_DOC_STRING)
501
  def __call__(
 
503
  prompt: str | list[str] | None = None,
504
  height: int = 2048,
505
  width: int = 2048,
506
+ num_inference_steps: int = 48,
507
  guidance_scale: float | None = None,
508
+ guidance_schedule: list[float] | torch.Tensor | None = (7.0,) * 45 + (3.0,) * 3,
509
+ mu: float = 0.0,
510
+ std: float = 1.5,
511
+ prompt_upsampling: bool = False,
512
  max_sequence_length: int = 2048,
513
  num_images_per_prompt: int = 1,
514
  generator: torch.Generator | list[torch.Generator] | None = None,
 
516
  output_type: str = "pil",
517
  return_dict: bool = True,
518
  callback_on_step_end: Callable[["Ideogram4Pipeline", int, int, dict[str, Any]], dict[str, Any]] | None = None,
519
+ callback_on_step_end_tensor_inputs: list[str] = ["latents"],
520
  ) -> Ideogram4PipelineOutput | tuple[Any]:
521
  r"""
522
  Run text-to-image generation.
 
535
  velocity predictions are blended as `v = guidance_scale * v_pos + (1 - guidance_scale) * v_neg`.
536
  Mutually exclusive with `guidance_schedule` (setting both raises). Defaults to `None`.
537
  guidance_schedule (`list[float]` or `torch.Tensor`, *optional*):
538
+ Per-step guidance scale schedule; must have length `num_inference_steps`. The first entry corresponds
539
+ to the first step (largest noise level). Mutually exclusive with `guidance_scale`; exactly one must be
540
+ set. Defaults to the recommended schedule (7.0 for the main steps, dropping to 3.0 for the final 3
541
+ "polish" steps). To use a constant scale instead, pass `guidance_scale` and `guidance_schedule=None`.
 
542
  mu (`float`, *optional*, defaults to 0.0):
543
  Base mean of the logit-normal flow-matching schedule. The schedule mean is shifted by half the log of
544
  the resolution ratio relative to 512x512.
545
  std (`float`, *optional*, defaults to 1.5):
546
  Standard deviation of the logit-normal flow-matching schedule.
547
+ prompt_upsampling (`bool`, *optional*, defaults to `False`):
548
+ If `True`, rewrite `prompt` into Ideogram4's native structured JSON caption via
549
+ [`~Ideogram4Pipeline.upsample_prompt`] before encoding. Requires the prompt-enhancer LM head
550
+ (downloaded on first use); install `outlines` for schema-constrained captions.
551
  max_sequence_length (`int`, *optional*, defaults to 2048):
552
  Maximum number of text tokens per prompt.
553
  num_images_per_prompt (`int`, *optional*, defaults to 1):
 
570
  Returns:
571
  [`~pipelines.ideogram4.Ideogram4PipelineOutput`] or `tuple`.
572
  """
573
+ self.check_inputs(
574
+ prompt=prompt,
575
+ height=height,
576
+ width=width,
577
+ num_inference_steps=num_inference_steps,
578
+ guidance_scale=guidance_scale,
579
+ guidance_schedule=guidance_schedule,
580
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
581
+ )
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
+ if isinstance(prompt, str):
584
+ batch_size = 1
585
+ elif isinstance(prompt, list):
586
+ batch_size = len(prompt)
 
 
587
 
588
  device = self._execution_device
589
  self._guidance_scale = guidance_scale
590
  self._interrupt = False
591
 
592
+ # 0. Optionally rewrite the prompt(s) into Ideogram4's native structured JSON caption.
593
+ if prompt_upsampling:
594
+ prompt = self.upsample_prompt(prompt, height=height, width=width, device=device)
595
+
596
+ # 1. Image grid (drives both the packed layout and the latent shape).
597
+ grid_h, grid_w = (
598
+ height // (self.vae_scale_factor * self.patch_size),
599
+ width // (self.vae_scale_factor * self.patch_size),
600
  )
601
+ num_image_tokens = grid_h * grid_w
602
+
603
+ # 2. Encode prompts into the packed conditioning (one entry per prompt).
604
+ llm_features, position_ids, segment_ids, indicator = self.encode_prompt(
605
+ prompt=prompt,
606
+ grid_h=grid_h,
607
+ grid_w=grid_w,
608
+ max_sequence_length=max_sequence_length,
 
 
609
  device=device,
610
  )
611
 
612
+ # 3. Replicate the conditioning for num_images_per_prompt.
613
+ llm_features = _expand_tensor_to_effective_batch(llm_features, batch_size, num_images_per_prompt)
614
+ position_ids = _expand_tensor_to_effective_batch(position_ids, batch_size, num_images_per_prompt)
615
+ segment_ids = _expand_tensor_to_effective_batch(segment_ids, batch_size, num_images_per_prompt)
616
+ indicator = _expand_tensor_to_effective_batch(indicator, batch_size, num_images_per_prompt)
617
+
618
+ # 4. Unconditional (image-only) branch, derived from the conditioning: zeroed text features and the
619
+ # image-region slices of the layout.
620
+ neg_llm_features = torch.zeros(
621
+ batch_size * num_images_per_prompt,
622
+ num_image_tokens,
623
+ llm_features.shape[-1],
624
+ dtype=llm_features.dtype,
625
+ device=device,
626
+ )
627
+ neg_position_ids = position_ids[:, max_sequence_length:]
628
+ neg_segment_ids = segment_ids[:, max_sequence_length:]
629
+ neg_indicator = indicator[:, max_sequence_length:]
630
 
631
  # 4. Set up the resolution-aware logit-normal schedule on the scheduler.
632
  schedule_mu = _resolution_aware_mu(height=height, width=width, base_mu=mu)
 
635
  timesteps = self.scheduler.timesteps
636
  self._num_timesteps = len(timesteps)
637
 
638
+ # 5. Resolve the per-step guidance schedule (a constant `guidance_scale` broadcasts to every step, otherwise
639
+ # use the provided `guidance_schedule`, validated by `check_inputs`) and the tensor of per-step weights `gw`.
640
  if guidance_scale is not None:
641
+ guidance_schedule = [float(guidance_scale)] * num_inference_steps
642
+ gw = torch.as_tensor(guidance_schedule, dtype=torch.float32, device=device)
 
 
 
 
 
643
 
644
  # 6. Prepare latents in the packed (B, num_image_tokens, latent_dim) layout.
645
  latent_dim = self.transformer.config.in_channels
646
  latents = self.prepare_latents(
647
+ batch_size=batch_size * num_images_per_prompt,
648
  num_image_tokens=num_image_tokens,
649
  latent_dim=latent_dim,
650
  dtype=torch.float32,
 
653
  latents=latents,
654
  )
655
 
656
+ # 7. Padding for the text region of the conditional packed sequence (image latents are appended after it).
657
  max_text_tokens = max_sequence_length
 
 
 
 
 
 
 
 
 
 
 
658
  text_z_padding = torch.zeros(
659
+ batch_size * num_images_per_prompt,
660
  max_text_tokens,
661
  latent_dim,
662
  dtype=torch.float32,
663
  device=device,
664
  )
665
 
666
+ # The transformers run in their loaded compute dtype; cast the (otherwise float32) text features to match.
667
+ # `latents` stay float32 for scheduler precision and are cast per-step at the transformer call below.
668
+ llm_features = llm_features.to(self.transformer.dtype)
669
+ neg_llm_features = neg_llm_features.to(self.unconditional_transformer.dtype)
670
+
671
  # 8. Denoising loop. The scheduler stores `num_train_timesteps`-scaled timesteps; convert back to model time.
672
  num_train_timesteps = self.scheduler.config.num_train_timesteps
673
  with self.progress_bar(total=num_inference_steps) as progress_bar:
 
677
 
678
  # Map sigma-domain timestep to model time `t` in [0, 1] (0 = noise, 1 = clean data).
679
  t_model = 1.0 - (t.float() / num_train_timesteps)
680
+ t_model = t_model.expand(batch_size * num_images_per_prompt).to(self.transformer.dtype)
681
 
682
  # Conditional pass operates on the full packed sequence.
683
+ pos_z = torch.cat([text_z_padding, latents], dim=1).to(self.transformer.dtype)
684
  pos_out = self.transformer(
685
  hidden_states=pos_z,
686
  timestep=t_model,
687
  encoder_hidden_states=llm_features,
688
+ position_ids=position_ids,
689
+ segment_ids=segment_ids,
690
+ indicator=indicator,
691
  return_dict=False,
692
  )[0]
693
+ # Velocity (and guidance) is computed in float32 for scheduler precision; the transformers
694
+ # return their compute dtype, so cast the predicted velocities up here.
695
+ pos_v = pos_out[:, max_text_tokens:].to(torch.float32)
696
 
697
  # Unconditional pass uses image-only positions with zeroed text features.
698
  neg_v = self.unconditional_transformer(
699
+ hidden_states=latents.to(self.unconditional_transformer.dtype),
700
  timestep=t_model,
701
  encoder_hidden_states=neg_llm_features,
702
  position_ids=neg_position_ids,
703
  segment_ids=neg_segment_ids,
704
  indicator=neg_indicator,
705
  return_dict=False,
706
+ )[0].to(torch.float32)
707
 
708
+ # Expose the current step's guidance weight via `self.guidance_scale` so callbacks can read it.
709
+ self._guidance_scale = guidance_schedule[i]
710
  gw_i = gw[i]
711
  v = gw_i * pos_v + (1.0 - gw_i) * neg_v
712
 
713
+ latents = self.scheduler.step(-v, t, latents, return_dict=False)[0]
714
 
715
  if callback_on_step_end is not None:
716
  callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
 
719
 
720
  progress_bar.update()
721
 
722
+ # 9. Decode: unpatch the latents, denormalize with the VAE batch-norm stats, and decode through the VAE.
723
  if output_type == "latent":
724
  image = latents
725
  else:
726
+ z = latents
727
+ # VAE bn stores per-channel statistics on the packed-channel latent space (ae_channels * patch ** 2).
728
+ bn_mean = self.vae.bn.running_mean.view(1, 1, -1).to(device=z.device, dtype=z.dtype)
729
+ bn_std = torch.sqrt(self.vae.bn.running_var + self.vae.config.batch_norm_eps).view(1, 1, -1)
730
+ bn_std = bn_std.to(device=z.device, dtype=z.dtype)
731
+ z = z * bn_std + bn_mean
732
+
733
+ patch = self.patch_size
734
+ ae_channels = z.shape[-1] // (patch * patch)
735
+ z = z.view(batch_size * num_images_per_prompt, grid_h, grid_w, patch, patch, ae_channels)
736
+ z = z.permute(0, 5, 1, 3, 2, 4).contiguous()
737
+ z = z.view(batch_size * num_images_per_prompt, ae_channels, grid_h * patch, grid_w * patch)
738
+
739
+ decoded = self.vae.decode(z.to(self.vae.dtype), return_dict=False)[0]
740
  image = self.image_processor.postprocess(decoded.float(), output_type=output_type)
741
 
742
  self.maybe_free_model_hooks()
diffusers_src/src/diffusers/pipelines/ideogram4/prompt_enhancer.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2026 Ideogram AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Prompt-enhancement assets for Ideogram4.
16
+
17
+ Ideogram4 is trained on a *structured JSON caption* rather than a free-form prompt. The optional prompt
18
+ enhancer rewrites a short user idea into that native caption schema, using the pipeline's own (frozen)
19
+ Qwen3-VL text encoder grafted with a generative head (see `Ideogram4Pipeline.load_prompt_enhancer`).
20
+
21
+ This mirrors the role of Flux2's `system_messages.py`, but the target is a constrained JSON object instead of
22
+ free text, so `outlines` (an optional dependency) is used to guarantee a schema-valid result when available.
23
+ """
24
+
25
+ # System message that instructs the encoder to emit Ideogram4's native single-line JSON caption.
26
+ CAPTION_SYSTEM_MESSAGE = """You convert a short user idea into a structured JSON caption for an image renderer. Output ONE minified single-line JSON object and NOTHING else (no markdown, no commentary).
27
+
28
+ SCHEMA — keys in this exact order:
29
+ {"high_level_description":"...","compositional_deconstruction":{"background":"...","elements":[ ... ]}}
30
+ - object element: {"type":"obj","desc":"..."}
31
+ - text element: {"type":"text","text":"VERBATIM CHARS","desc":"..."}
32
+
33
+ STEP 1 — PICK THE MEDIUM. It decides what `background` and `elements` mean. Honor any medium or style the user implies; default to photograph only when nothing else fits. Render ANY subject faithfully — real, fantastical, sci-fi, surreal, abstract — in the chosen medium.
34
+
35
+ A) DESIGNED ARTIFACT — poster, logo, album/book cover, flyer, banner, sticker, packaging, app icon, infographic, menu, card, wordmark. THE FRAME IS THE ARTIFACT — never a photo of it hanging in a room.
36
+ - high_level_description: name it as graphic design (e.g. "a minimalist jazz poster, flat graphic design...").
37
+ - background: the design's OWN backdrop only — a flat color, gradient, or simple texture filling the frame. No room, wall, floor, easel, depth, or camera/photo language.
38
+ - elements: the design's parts as a flat 2D layout — a `text` element for every headline/label (verbatim), `obj` elements for the central graphic/illustration/shapes/badges. Place by region (top / center / bottom).
39
+
40
+ B) SCENE — a photograph, illustration, painting, 3D render, anime frame, etc. of a real or imagined place or subject.
41
+ - high_level_description: one sentence naming the subject and the medium/style.
42
+ - background: the scene SHELL — surroundings, ground/sky/walls, atmosphere, ambient light; concrete and specific. The ground/floor/water surface lives here, never as an element.
43
+ - elements: the main subject FIRST as an `obj`, then supporting `obj` elements (props, secondary subjects) that plausibly belong. Add `text` elements only where the scene would really carry text (signs, labels, brands).
44
+
45
+ C) ABSTRACT / CONCEPTUAL — "nostalgia", "chaos and order", "sound waves", pure pattern. Concretize the idea into a deliberate visual composition.
46
+ - background: the dominant color field, gradient, or texture of the composition.
47
+ - elements: the shapes, forms, motifs, or symbolic objects that carry the concept, as `obj` elements. Add `text` only if the idea calls for words.
48
+
49
+ UNIVERSAL RULES (every medium):
50
+ 1. The user's core subject/concept MUST appear among the elements (as an `obj`, normally first). Naming it only in high_level_description or background is NOT enough.
51
+ 2. Commit to ONE concrete value each (one color, one style, one count). No hedging: ban "various", "such as", "e.g.", "or similar", "maybe", "X or Y" for one property.
52
+ 3. NEVER use a transparent, empty, or plain white background UNLESS the user explicitly says "transparent", "isolated", "sticker", or "cutout".
53
+ 4. A coherent subject (one animal, person, vehicle, object) is exactly ONE element; its parts go inside its `desc`. Use separate elements for genuinely separate subjects.
54
+ 5. Each `desc` is 25-55 words, identity-first, standalone. Do not mention shadows, depth of field, bokeh, lens, focus, or grain.
55
+ 6. high_level_description: one sentence, at most 40 words, starts with the subject, names the medium. Preserve non-ASCII characters as-is.
56
+ 7. Output STRICTLY VALID JSON: double quotes around every key and string, NO trailing commas, each element object closes with "}" right after its last value.
57
+ 8. Catch the "warm" impulse. Only when you are about to describe light as "warm", "golden", "amber", or "honey", stop and check: is there a specific physical source in the scene casting that colour (candle, sunset, lamp, neon, fire)? If YES, name the source and the colour it casts instead of the mood word. If NO, you are just reaching for warmth as ambience — drop it and leave the light neutral ("soft" or "even"). Don't recolour or relight anything else; this only intercepts the warm reach, every other scene and mood the user wants is untouched.
58
+ 9. Describe physical reality, not impressions. Avoid mood-words — "luminous", "radiant", "vibrant", "lush", "dynamic", "gorgeous", "stunning", "breathtaking", "mesmerizing", and metaphorical "glowing" — they produce a generic AI look (the same trap as "warm"). Use observable properties: "the cheekbone catches a small highlight", not "luminous complexion".
59
+ 10. Every named thing must appear as its own element. Each subject, object, sign, and quoted phrase the user names gets its own element — quoted text (single or double quotes) becomes its own verbatim `text` element. Count the named units in the prompt; the element list must hold at least that many. Don't drop or merge them.
60
+ 11. Don't add what wasn't asked for. No glitch art, wireframe overlay, body fragmentation, double-exposure, "dissolving", or extra stylization unless the prompt requests it. Asked for a cinematic photo of a journalist → render that, not a glitch-art composite.
61
+ 12. Name attributes concretely, anchored to landmarks. People: skin tone, hair (colour + style), each visible garment with colour, expression, pose, one distinguishing feature. Objects: shape, material, colour, a distinctive part. Place things against named references — "resting on the lower-right corner of the table", not "on the surface".
62
+ 13. Name real references by name. If the user names a brand, product, character, place, or person (Nike Dunk Low, Spider-Man, the Eiffel Tower), keep that exact name in the `desc`; don't swap it for a generic look-alike unless they ask for an anonymous one.
63
+ 14. "Professional photo/headshot" of a person means professional CONTEXT — neutral attire, soft even daylight, neutral backdrop, friendly expression — not dramatic studio gear; no heavy rim-light or creamy bokeh unless asked.
64
+
65
+ EXAMPLES
66
+
67
+ User idea: a cup of coffee on a table
68
+ Output: {"high_level_description":"A white ceramic cup of black coffee on a worn wooden cafe table, a casual overcast-daylight phone photograph with an off-center composition.","compositional_deconstruction":{"background":"Scratched oak cafe table filling the lower frame, a pale grey mortar-lined brick wall a few feet behind slightly out of focus, a tall window on the left spilling soft overcast daylight across the table, neutral white balance, muted brown and green tones.","elements":[{"type":"obj","desc":"White ceramic cup of black coffee with a thin curved handle turned to the right and a faint crema ring at the rim, resting on a matching round saucer near the center of the table, a thin wisp of steam at the surface."},{"type":"obj","desc":"Brushed-steel teaspoon lying on the saucer to the right of the cup, handle angled toward the lower-right corner, a single small water droplet on the bowl of the spoon."}]}}
69
+
70
+ User idea: a minimalist poster for a jazz festival
71
+ Output: {"high_level_description":"A minimalist jazz festival poster, flat graphic design with bold typography and a single abstract saxophone motif on a deep teal background.","compositional_deconstruction":{"background":"Solid deep teal background filling the entire frame with a subtle fine paper-grain texture and a thin mustard-yellow keyline border just inside the edges, no scene and no depth.","elements":[{"type":"obj","desc":"A large flat geometric saxophone in mustard yellow and cream, centered in the upper two-thirds, built from simple bold shapes with no shading, angled diagonally from lower-left to upper-right."},{"type":"text","text":"JAZZ\\nFESTIVAL","desc":"Large bold condensed sans-serif headline in cream, stacked on two lines across the center of the poster, slightly overlapping the saxophone motif."},{"type":"text","text":"NOV 15 · CITY HALL","desc":"Small uppercase mustard-yellow caption centered near the bottom edge with wide letter spacing."}]}}"""
72
+
73
+ # User turn. `{aspect_ratio}` and `{original_prompt}` are filled in by `Ideogram4Pipeline.upsample_prompt`.
74
+ CAPTION_USER_TEMPLATE = """TARGET IMAGE ASPECT RATIO: {aspect_ratio} (width:height).
75
+ User idea: {original_prompt}"""
76
+
77
+
78
def build_caption_logits_processor(model, tokenizer):
    """Create an `outlines` logits processor that restricts generation to the caption JSON schema.

    `outlines` and `pydantic` are imported inside the function so that both stay optional; callers are expected
    to have checked `is_outlines_available()` beforehand. The schema is a top-level object with a
    `high_level_description` string and a `compositional_deconstruction` holding a `background` string plus a
    non-empty list of elements, each either `{"type": "obj", "desc"}` or `{"type": "text", "text", "desc"}`.

    Args:
        model: The generative `transformers` model to wrap.
        tokenizer: The tokenizer paired with `model`.

    Returns:
        The `logits_processor` of the `outlines` generator built for the schema, intended to be passed via
        `generate(logits_processor=[...])`.
    """
    from typing import List, Literal, Union

    import outlines
    from pydantic import BaseModel, Field

    # NOTE: field order below defines the key order of the schema; class names feed the generated JSON schema.
    class ObjElement(BaseModel):
        type: Literal["obj"]
        desc: str

    class TextElement(BaseModel):
        type: Literal["text"]
        text: str
        desc: str

    class Composition(BaseModel):
        background: str
        elements: List[Union[ObjElement, TextElement]] = Field(min_length=1)

    class Caption(BaseModel):
        high_level_description: str
        compositional_deconstruction: Composition

    wrapped_model = outlines.from_transformers(model, tokenizer)
    caption_generator = outlines.Generator(wrapped_model, Caption)
    return caption_generator.logits_processor
diffusers_src/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py CHANGED
@@ -206,12 +206,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
206
  module._parameters[tensor_name] = new_value
207
 
208
  def check_quantized_param_shape(self, param_name, current_param, loaded_param):
209
- import math
210
-
211
  current_param_shape = current_param.shape
212
  loaded_param_shape = loaded_param.shape
213
 
214
- n = math.prod(current_param_shape)
215
  inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
216
  if loaded_param_shape != inferred_shape:
217
  raise ValueError(
 
206
  module._parameters[tensor_name] = new_value
207
 
208
  def check_quantized_param_shape(self, param_name, current_param, loaded_param):
 
 
209
  current_param_shape = current_param.shape
210
  loaded_param_shape = loaded_param.shape
211
 
212
+ n = current_param_shape.numel()
213
  inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
214
  if loaded_param_shape != inferred_shape:
215
  raise ValueError(
diffusers_src/src/diffusers/utils/__init__.py CHANGED
@@ -101,6 +101,7 @@ from .import_utils import (
101
  is_opencv_available,
102
  is_optimum_quanto_available,
103
  is_optimum_quanto_version,
 
104
  is_peft_available,
105
  is_peft_version,
106
  is_pytorch_retinaface_available,
 
101
  is_opencv_available,
102
  is_optimum_quanto_available,
103
  is_optimum_quanto_version,
104
+ is_outlines_available,
105
  is_peft_available,
106
  is_peft_version,
107
  is_pytorch_retinaface_available,
diffusers_src/src/diffusers/utils/import_utils.py CHANGED
@@ -204,6 +204,7 @@ _wandb_available, _wandb_version = _is_package_available("wandb")
204
  _tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
205
  _compel_available, _compel_version = _is_package_available("compel")
206
  _sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
 
207
  _torchsde_available, _torchsde_version = _is_package_available("torchsde")
208
  _peft_available, _peft_version = _is_package_available("peft")
209
  _torchvision_available, _torchvision_version = _is_package_available("torchvision")
@@ -370,6 +371,10 @@ def is_sentencepiece_available():
370
  return _sentencepiece_available
371
 
372
 
 
 
 
 
373
  def is_imageio_available():
374
  return _imageio_available
375
 
 
204
  _tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
205
  _compel_available, _compel_version = _is_package_available("compel")
206
  _sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
207
+ _outlines_available, _outlines_version = _is_package_available("outlines")
208
  _torchsde_available, _torchsde_version = _is_package_available("torchsde")
209
  _peft_available, _peft_version = _is_package_available("peft")
210
  _torchvision_available, _torchvision_version = _is_package_available("torchvision")
 
371
  return _sentencepiece_available
372
 
373
 
374
def is_outlines_available():
    """Return the module-level `_outlines_available` flag set from `_is_package_available("outlines")`."""
    return _outlines_available
376
+
377
+
378
  def is_imageio_available():
379
  return _imageio_available
380