Spaces:
Running on Zero
Running on Zero
Bundle PR diffusers (yiyi-refactor-fused + native prompt upsampling)
Browse files- diffusers_src/src/diffusers/models/transformers/transformer_ideogram4.py +85 -57
- diffusers_src/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py +329 -184
- diffusers_src/src/diffusers/pipelines/ideogram4/prompt_enhancer.py +109 -0
- diffusers_src/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py +1 -3
- diffusers_src/src/diffusers/utils/__init__.py +1 -0
- diffusers_src/src/diffusers/utils/import_utils.py +5 -0
diffusers_src/src/diffusers/models/transformers/transformer_ideogram4.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
|
|
|
| 15 |
import math
|
| 16 |
|
| 17 |
import torch
|
|
@@ -22,6 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
|
| 22 |
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 23 |
from ...utils import logging
|
| 24 |
from ...utils.torch_utils import maybe_allow_in_graph
|
|
|
|
|
|
|
| 25 |
from ..modeling_outputs import Transformer2DModelOutput
|
| 26 |
from ..modeling_utils import ModelMixin
|
| 27 |
from ..normalization import RMSNorm
|
|
@@ -44,19 +47,6 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
|
| 44 |
return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
|
| 45 |
|
| 46 |
|
| 47 |
-
def _apply_rotary_pos_emb(
|
| 48 |
-
q: torch.Tensor,
|
| 49 |
-
k: torch.Tensor,
|
| 50 |
-
cos: torch.Tensor,
|
| 51 |
-
sin: torch.Tensor,
|
| 52 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 53 |
-
cos = cos.unsqueeze(1)
|
| 54 |
-
sin = sin.unsqueeze(1)
|
| 55 |
-
q_embed = (q * cos) + (_rotate_half(q) * sin)
|
| 56 |
-
k_embed = (k * cos) + (_rotate_half(k) * sin)
|
| 57 |
-
return q_embed, k_embed
|
| 58 |
-
|
| 59 |
-
|
| 60 |
class Ideogram4MRoPE(nn.Module):
|
| 61 |
"""Multi-axis (t, h, w) interleaved rotary position embedding."""
|
| 62 |
|
|
@@ -74,7 +64,6 @@ class Ideogram4MRoPE(nn.Module):
|
|
| 74 |
self.mrope_section = tuple(mrope_section)
|
| 75 |
self.head_dim = head_dim
|
| 76 |
|
| 77 |
-
@torch.no_grad()
|
| 78 |
def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 79 |
# position_ids: (B, L, 3) of int (axes are t, h, w).
|
| 80 |
if position_ids.ndim != 3 or position_ids.shape[-1] != 3:
|
|
@@ -97,8 +86,49 @@ class Ideogram4MRoPE(nn.Module):
|
|
| 97 |
return emb.cos(), emb.sin()
|
| 98 |
|
| 99 |
|
| 100 |
-
class
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
|
| 104 |
super().__init__()
|
|
@@ -113,34 +143,23 @@ class Ideogram4Attention(nn.Module):
|
|
| 113 |
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
| 114 |
self.o = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 115 |
|
|
|
|
|
|
|
| 116 |
def forward(
|
| 117 |
self,
|
| 118 |
hidden_states: torch.Tensor,
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
) -> torch.Tensor:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
# SDPA expects (B, num_heads, L, head_dim).
|
| 132 |
-
q = q.transpose(1, 2)
|
| 133 |
-
k = k.transpose(1, 2)
|
| 134 |
-
v = v.transpose(1, 2)
|
| 135 |
-
|
| 136 |
-
q, k = _apply_rotary_pos_emb(q, k, cos, sin)
|
| 137 |
-
|
| 138 |
-
# Block-diagonal mask from segment ids: tokens only attend within their segment.
|
| 139 |
-
attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
|
| 140 |
-
|
| 141 |
-
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
| 142 |
-
out = out.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)
|
| 143 |
-
return self.o(out)
|
| 144 |
|
| 145 |
|
| 146 |
class Ideogram4MLP(nn.Module):
|
|
@@ -180,9 +199,8 @@ class Ideogram4TransformerBlock(nn.Module):
|
|
| 180 |
def forward(
|
| 181 |
self,
|
| 182 |
hidden_states: torch.Tensor,
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
sin: torch.Tensor,
|
| 186 |
adaln_input: torch.Tensor,
|
| 187 |
) -> torch.Tensor:
|
| 188 |
mod = self.adaln_modulation(adaln_input)
|
|
@@ -194,9 +212,8 @@ class Ideogram4TransformerBlock(nn.Module):
|
|
| 194 |
|
| 195 |
attn_out = self.attention(
|
| 196 |
self.attention_norm1(hidden_states) * scale_msa,
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
sin=sin,
|
| 200 |
)
|
| 201 |
hidden_states = hidden_states + gate_msa * self.attention_norm2(attn_out)
|
| 202 |
hidden_states = hidden_states + gate_mlp * self.ffn_norm2(
|
|
@@ -251,7 +268,7 @@ class Ideogram4FinalLayer(nn.Module):
|
|
| 251 |
return self.linear(self.norm_final(hidden_states) * scale)
|
| 252 |
|
| 253 |
|
| 254 |
-
class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 255 |
r"""
|
| 256 |
The flow-matching transformer backbone used by the Ideogram 4 pipeline.
|
| 257 |
|
|
@@ -346,6 +363,19 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
|
| 346 |
adaln_dim=adaln_dim,
|
| 347 |
)
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
def forward(
|
| 350 |
self,
|
| 351 |
hidden_states: torch.Tensor,
|
|
@@ -377,19 +407,13 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
|
| 377 |
|
| 378 |
Returns:
|
| 379 |
[`~models.modeling_outputs.Transformer2DModelOutput`] or a `tuple` whose first element is a tensor of shape
|
| 380 |
-
`(batch_size, sequence_length, in_channels)` in
|
| 381 |
`OUTPUT_IMAGE_INDICATOR` carry meaningful velocity predictions.
|
| 382 |
"""
|
| 383 |
batch_size, seq_len, in_channels = hidden_states.shape
|
| 384 |
if in_channels != self.in_channels:
|
| 385 |
raise ValueError(f"Expected last dim {self.in_channels}, got {in_channels}.")
|
| 386 |
|
| 387 |
-
param_dtype = self.dtype
|
| 388 |
-
hidden_states = hidden_states.to(param_dtype)
|
| 389 |
-
timestep = timestep.to(param_dtype)
|
| 390 |
-
encoder_hidden_states = encoder_hidden_states.to(param_dtype)
|
| 391 |
-
|
| 392 |
-
indicator = indicator.to(torch.long)
|
| 393 |
llm_token_mask = (indicator == LLM_TOKEN_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
|
| 394 |
output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
|
| 395 |
|
|
@@ -414,16 +438,20 @@ class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
|
| 414 |
cos, sin = self.rotary_emb(position_ids)
|
| 415 |
cos = cos.to(hidden_states.dtype)
|
| 416 |
sin = sin.to(hidden_states.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
for block in self.layers:
|
| 419 |
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 420 |
hidden_states = self._gradient_checkpointing_func(
|
| 421 |
-
block, hidden_states,
|
| 422 |
)
|
| 423 |
else:
|
| 424 |
-
hidden_states = block(hidden_states,
|
| 425 |
|
| 426 |
-
output = self.final_layer(hidden_states, conditioning=adaln_input)
|
| 427 |
|
| 428 |
if not return_dict:
|
| 429 |
return (output,)
|
|
|
|
| 12 |
# See the License for the specific language governing permissions and
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
+
import inspect
|
| 16 |
import math
|
| 17 |
|
| 18 |
import torch
|
|
|
|
| 23 |
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 24 |
from ...utils import logging
|
| 25 |
from ...utils.torch_utils import maybe_allow_in_graph
|
| 26 |
+
from ..attention import AttentionMixin, AttentionModuleMixin
|
| 27 |
+
from ..attention_dispatch import dispatch_attention_fn
|
| 28 |
from ..modeling_outputs import Transformer2DModelOutput
|
| 29 |
from ..modeling_utils import ModelMixin
|
| 30 |
from ..normalization import RMSNorm
|
|
|
|
| 47 |
return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
class Ideogram4MRoPE(nn.Module):
|
| 51 |
"""Multi-axis (t, h, w) interleaved rotary position embedding."""
|
| 52 |
|
|
|
|
| 64 |
self.mrope_section = tuple(mrope_section)
|
| 65 |
self.head_dim = head_dim
|
| 66 |
|
|
|
|
| 67 |
def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 68 |
# position_ids: (B, L, 3) of int (axes are t, h, w).
|
| 69 |
if position_ids.ndim != 3 or position_ids.shape[-1] != 3:
|
|
|
|
| 86 |
return emb.cos(), emb.sin()
|
| 87 |
|
| 88 |
|
| 89 |
+
class Ideogram4AttnProcessor:
|
| 90 |
+
_attention_backend = None
|
| 91 |
+
_parallel_config = None
|
| 92 |
+
|
| 93 |
+
def __call__(
|
| 94 |
+
self,
|
| 95 |
+
attn: "Ideogram4Attention",
|
| 96 |
+
hidden_states: torch.Tensor,
|
| 97 |
+
attention_mask: torch.Tensor,
|
| 98 |
+
image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
|
| 99 |
+
) -> torch.Tensor:
|
| 100 |
+
batch_size, seq_len, _ = hidden_states.shape
|
| 101 |
+
|
| 102 |
+
qkv = attn.qkv(hidden_states).view(batch_size, seq_len, 3, attn.num_heads, attn.head_dim)
|
| 103 |
+
query, key, value = qkv.unbind(dim=2)
|
| 104 |
+
|
| 105 |
+
query = attn.norm_q(query)
|
| 106 |
+
key = attn.norm_k(key)
|
| 107 |
+
|
| 108 |
+
# MRoPE applied in (B, L, num_heads, head_dim) layout; cos/sin broadcast over the head axis.
|
| 109 |
+
cos, sin = image_rotary_emb
|
| 110 |
+
cos = cos.unsqueeze(2)
|
| 111 |
+
sin = sin.unsqueeze(2)
|
| 112 |
+
query = (query * cos) + (_rotate_half(query) * sin)
|
| 113 |
+
key = (key * cos) + (_rotate_half(key) * sin)
|
| 114 |
+
|
| 115 |
+
hidden_states = dispatch_attention_fn(
|
| 116 |
+
query,
|
| 117 |
+
key,
|
| 118 |
+
value,
|
| 119 |
+
attn_mask=attention_mask,
|
| 120 |
+
backend=self._attention_backend,
|
| 121 |
+
parallel_config=self._parallel_config,
|
| 122 |
+
)
|
| 123 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 124 |
+
return attn.o(hidden_states)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class Ideogram4Attention(nn.Module, AttentionModuleMixin):
|
| 128 |
+
"""Self-attention with merged QKV, q/k RMSNorm, MRoPE and a block-diagonal segment mask."""
|
| 129 |
+
|
| 130 |
+
_default_processor_cls = Ideogram4AttnProcessor
|
| 131 |
+
_available_processors = [Ideogram4AttnProcessor]
|
| 132 |
|
| 133 |
def __init__(self, hidden_size: int, num_heads: int, eps: float = 1e-5) -> None:
|
| 134 |
super().__init__()
|
|
|
|
| 143 |
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
| 144 |
self.o = nn.Linear(hidden_size, hidden_size, bias=False)
|
| 145 |
|
| 146 |
+
self.set_processor(self._default_processor_cls())
|
| 147 |
+
|
| 148 |
def forward(
|
| 149 |
self,
|
| 150 |
hidden_states: torch.Tensor,
|
| 151 |
+
attention_mask: torch.Tensor | None = None,
|
| 152 |
+
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
| 153 |
+
**kwargs,
|
| 154 |
) -> torch.Tensor:
|
| 155 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 156 |
+
unused_kwargs = [k for k in kwargs if k not in attn_parameters]
|
| 157 |
+
if len(unused_kwargs) > 0:
|
| 158 |
+
logger.warning(
|
| 159 |
+
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
| 160 |
+
)
|
| 161 |
+
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
| 162 |
+
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
class Ideogram4MLP(nn.Module):
|
|
|
|
| 199 |
def forward(
|
| 200 |
self,
|
| 201 |
hidden_states: torch.Tensor,
|
| 202 |
+
attention_mask: torch.Tensor,
|
| 203 |
+
image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
|
|
|
|
| 204 |
adaln_input: torch.Tensor,
|
| 205 |
) -> torch.Tensor:
|
| 206 |
mod = self.adaln_modulation(adaln_input)
|
|
|
|
| 212 |
|
| 213 |
attn_out = self.attention(
|
| 214 |
self.attention_norm1(hidden_states) * scale_msa,
|
| 215 |
+
attention_mask=attention_mask,
|
| 216 |
+
image_rotary_emb=image_rotary_emb,
|
|
|
|
| 217 |
)
|
| 218 |
hidden_states = hidden_states + gate_msa * self.attention_norm2(attn_out)
|
| 219 |
hidden_states = hidden_states + gate_mlp * self.ffn_norm2(
|
|
|
|
| 268 |
return self.linear(self.norm_final(hidden_states) * scale)
|
| 269 |
|
| 270 |
|
| 271 |
+
class Ideogram4Transformer2DModel(ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 272 |
r"""
|
| 273 |
The flow-matching transformer backbone used by the Ideogram 4 pipeline.
|
| 274 |
|
|
|
|
| 363 |
adaln_dim=adaln_dim,
|
| 364 |
)
|
| 365 |
|
| 366 |
+
def fuse_qkv_projections(self):
|
| 367 |
+
# The attention already uses a single fused `qkv` projection, so there is nothing to fuse.
|
| 368 |
+
raise NotImplementedError(
|
| 369 |
+
"Ideogram4Transformer2DModel already uses a fused QKV projection (`attention.qkv`), "
|
| 370 |
+
"so `fuse_qkv_projections()` is not applicable."
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def unfuse_qkv_projections(self):
|
| 374 |
+
raise NotImplementedError(
|
| 375 |
+
"Ideogram4Transformer2DModel uses a fused QKV projection that cannot be split, "
|
| 376 |
+
"so `unfuse_qkv_projections()` is not applicable."
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
def forward(
|
| 380 |
self,
|
| 381 |
hidden_states: torch.Tensor,
|
|
|
|
| 407 |
|
| 408 |
Returns:
|
| 409 |
[`~models.modeling_outputs.Transformer2DModelOutput`] or a `tuple` whose first element is a tensor of shape
|
| 410 |
+
`(batch_size, sequence_length, in_channels)` in the model's compute dtype. Only positions tagged with
|
| 411 |
`OUTPUT_IMAGE_INDICATOR` carry meaningful velocity predictions.
|
| 412 |
"""
|
| 413 |
batch_size, seq_len, in_channels = hidden_states.shape
|
| 414 |
if in_channels != self.in_channels:
|
| 415 |
raise ValueError(f"Expected last dim {self.in_channels}, got {in_channels}.")
|
| 416 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
llm_token_mask = (indicator == LLM_TOKEN_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
|
| 418 |
output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(hidden_states.dtype).unsqueeze(-1)
|
| 419 |
|
|
|
|
| 438 |
cos, sin = self.rotary_emb(position_ids)
|
| 439 |
cos = cos.to(hidden_states.dtype)
|
| 440 |
sin = sin.to(hidden_states.dtype)
|
| 441 |
+
image_rotary_emb = (cos, sin)
|
| 442 |
+
|
| 443 |
+
# Block-diagonal mask from segment ids: tokens only attend within their segment. Shared by every block.
|
| 444 |
+
attention_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
|
| 445 |
|
| 446 |
for block in self.layers:
|
| 447 |
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 448 |
hidden_states = self._gradient_checkpointing_func(
|
| 449 |
+
block, hidden_states, attention_mask, image_rotary_emb, adaln_input
|
| 450 |
)
|
| 451 |
else:
|
| 452 |
+
hidden_states = block(hidden_states, attention_mask, image_rotary_emb, adaln_input)
|
| 453 |
|
| 454 |
+
output = self.final_layer(hidden_states, conditioning=adaln_input)
|
| 455 |
|
| 456 |
if not return_dict:
|
| 457 |
return (output,)
|
diffusers_src/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py
CHANGED
|
@@ -29,10 +29,11 @@ from ...models.transformers.transformer_ideogram4 import (
|
|
| 29 |
Ideogram4Transformer2DModel,
|
| 30 |
)
|
| 31 |
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 32 |
-
from ...utils import logging, replace_example_docstring
|
| 33 |
from ...utils.torch_utils import randn_tensor
|
| 34 |
from ..pipeline_utils import DiffusionPipeline
|
| 35 |
from .pipeline_output import Ideogram4PipelineOutput
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
@@ -42,10 +43,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
| 42 |
# text conditioning consumed by the Ideogram4 transformer.
|
| 43 |
QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
DEFAULT_STD = 1.5
|
| 49 |
|
| 50 |
|
| 51 |
EXAMPLE_DOC_STRING = """
|
|
@@ -109,6 +109,32 @@ def _resolution_aware_mu(
|
|
| 109 |
return base_mu + 0.5 * math.log(num_pixels / base_pixels)
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
class Ideogram4Pipeline(DiffusionPipeline):
|
| 113 |
r"""
|
| 114 |
Text-to-image pipeline for Ideogram4.
|
|
@@ -165,38 +191,110 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 165 |
self.patch_size = 2
|
| 166 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * self.patch_size)
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
def
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
self,
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
max_text_tokens: int,
|
| 189 |
device: torch.device,
|
| 190 |
-
) ->
|
| 191 |
-
"""Build the packed
|
| 192 |
-
tokenized = [self._tokenize(p, max_text_tokens) for p in prompts]
|
| 193 |
-
batch_size = len(prompts)
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
grid_w = width // patch
|
| 200 |
num_image_tokens = grid_h * grid_w
|
| 201 |
total_seq_len = max_text_tokens + num_image_tokens
|
| 202 |
|
|
@@ -206,21 +304,15 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 206 |
t_idx = torch.zeros_like(h_idx)
|
| 207 |
image_pos = torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET
|
| 208 |
|
| 209 |
-
token_ids = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
|
| 210 |
-
text_position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
|
| 211 |
position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
|
| 212 |
segment_ids = torch.full((batch_size, total_seq_len), SEQUENCE_PADDING_INDICATOR, dtype=torch.long)
|
| 213 |
indicator = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
|
| 214 |
|
| 215 |
-
for b,
|
| 216 |
-
|
| 217 |
-
offset = pad_len
|
| 218 |
-
|
| 219 |
-
token_ids[b, offset : offset + num_text] = toks
|
| 220 |
|
| 221 |
text_pos = torch.arange(num_text)
|
| 222 |
text_pos_3d = torch.stack([text_pos, text_pos, text_pos], dim=1)
|
| 223 |
-
text_position_ids[b, offset : offset + num_text] = text_pos_3d
|
| 224 |
position_ids[b, offset : offset + num_text] = text_pos_3d
|
| 225 |
position_ids[b, offset + num_text :] = image_pos
|
| 226 |
|
|
@@ -229,16 +321,7 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 229 |
|
| 230 |
segment_ids[b, offset : offset + num_text + num_image_tokens] = 1
|
| 231 |
|
| 232 |
-
return
|
| 233 |
-
"token_ids": token_ids.to(device),
|
| 234 |
-
"text_position_ids": text_position_ids.to(device),
|
| 235 |
-
"position_ids": position_ids.to(device),
|
| 236 |
-
"segment_ids": segment_ids.to(device),
|
| 237 |
-
"indicator": indicator.to(device),
|
| 238 |
-
"num_image_tokens": num_image_tokens,
|
| 239 |
-
"grid_h": grid_h,
|
| 240 |
-
"grid_w": grid_w,
|
| 241 |
-
}
|
| 242 |
|
| 243 |
def _get_text_encoder_hidden_states(
|
| 244 |
self,
|
|
@@ -283,28 +366,60 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 283 |
|
| 284 |
def encode_prompt(
|
| 285 |
self,
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
device: torch.device,
|
| 291 |
-
) -> torch.Tensor:
|
| 292 |
-
"""
|
| 293 |
-
batch_size, seq_len = token_ids.shape
|
| 294 |
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
-
|
| 305 |
-
text_mask = attention_mask.to(stacked.dtype).unsqueeze(-1)
|
| 306 |
-
stacked = stacked * text_mask
|
| 307 |
-
return stacked.to(torch.float32)
|
| 308 |
|
| 309 |
def prepare_latents(
|
| 310 |
self,
|
|
@@ -325,27 +440,6 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 325 |
latents = latents.to(device=device, dtype=dtype)
|
| 326 |
return latents
|
| 327 |
|
| 328 |
-
def _decode(self, z: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
|
| 329 |
-
"""Unpatch latents, denormalize with the VAE batch-norm stats, and decode through the VAE."""
|
| 330 |
-
batch_size = z.shape[0]
|
| 331 |
-
patch = self.patch_size
|
| 332 |
-
|
| 333 |
-
# VAE bn stores per-channel statistics on the packed-channel latent space (ae_channels * patch ** 2).
|
| 334 |
-
bn_mean = self.vae.bn.running_mean.view(1, 1, -1).to(device=z.device, dtype=z.dtype)
|
| 335 |
-
bn_std = torch.sqrt(self.vae.bn.running_var + self.vae.config.batch_norm_eps).view(1, 1, -1)
|
| 336 |
-
bn_std = bn_std.to(device=z.device, dtype=z.dtype)
|
| 337 |
-
|
| 338 |
-
z = z * bn_std + bn_mean
|
| 339 |
-
|
| 340 |
-
ae_channels = z.shape[-1] // (patch * patch)
|
| 341 |
-
z = z.view(batch_size, grid_h, grid_w, patch, patch, ae_channels)
|
| 342 |
-
z = z.permute(0, 5, 1, 3, 2, 4).contiguous()
|
| 343 |
-
z = z.view(batch_size, ae_channels, grid_h * patch, grid_w * patch)
|
| 344 |
-
|
| 345 |
-
z = z.to(self.vae.dtype)
|
| 346 |
-
image = self.vae.decode(z, return_dict=False)[0]
|
| 347 |
-
return image
|
| 348 |
-
|
| 349 |
@property
|
| 350 |
def guidance_scale(self) -> float | None:
|
| 351 |
return self._guidance_scale
|
|
@@ -358,6 +452,50 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 358 |
def interrupt(self) -> bool:
|
| 359 |
return self._interrupt
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
@torch.no_grad()
|
| 362 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 363 |
def __call__(
|
|
@@ -365,11 +503,12 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 365 |
prompt: str | list[str] | None = None,
|
| 366 |
height: int = 2048,
|
| 367 |
width: int = 2048,
|
| 368 |
-
num_inference_steps: int =
|
| 369 |
guidance_scale: float | None = None,
|
| 370 |
-
guidance_schedule: list[float] | torch.Tensor | None =
|
| 371 |
-
mu: float =
|
| 372 |
-
std: float =
|
|
|
|
| 373 |
max_sequence_length: int = 2048,
|
| 374 |
num_images_per_prompt: int = 1,
|
| 375 |
generator: torch.Generator | list[torch.Generator] | None = None,
|
|
@@ -377,7 +516,7 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 377 |
output_type: str = "pil",
|
| 378 |
return_dict: bool = True,
|
| 379 |
callback_on_step_end: Callable[["Ideogram4Pipeline", int, int, dict[str, Any]], dict[str, Any]] | None = None,
|
| 380 |
-
callback_on_step_end_tensor_inputs: list[str]
|
| 381 |
) -> Ideogram4PipelineOutput | tuple[Any]:
|
| 382 |
r"""
|
| 383 |
Run text-to-image generation.
|
|
@@ -396,16 +535,19 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 396 |
velocity predictions are blended as `v = guidance_scale * v_pos + (1 - guidance_scale) * v_neg`.
|
| 397 |
Mutually exclusive with `guidance_schedule` (setting both raises). Defaults to `None`.
|
| 398 |
guidance_schedule (`list[float]` or `torch.Tensor`, *optional*):
|
| 399 |
-
Per-step guidance scale schedule; must have length `num_inference_steps`. The first entry corresponds
|
| 400 |
-
the first step (largest noise level). Mutually exclusive with `guidance_scale`
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
to 3.0 for the final 3 "polish" steps).
|
| 404 |
mu (`float`, *optional*, defaults to 0.0):
|
| 405 |
Base mean of the logit-normal flow-matching schedule. The schedule mean is shifted by half the log of
|
| 406 |
the resolution ratio relative to 512x512.
|
| 407 |
std (`float`, *optional*, defaults to 1.5):
|
| 408 |
Standard deviation of the logit-normal flow-matching schedule.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
max_sequence_length (`int`, *optional*, defaults to 2048):
|
| 410 |
Maximum number of text tokens per prompt.
|
| 411 |
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
|
@@ -428,66 +570,63 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 428 |
Returns:
|
| 429 |
[`~pipelines.ideogram4.Ideogram4PipelineOutput`] or `tuple`.
|
| 430 |
"""
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
if num_images_per_prompt <= 0:
|
| 441 |
-
raise ValueError(f"`num_images_per_prompt` must be > 0, got {num_images_per_prompt}.")
|
| 442 |
-
if guidance_scale is not None and guidance_schedule is not None:
|
| 443 |
-
raise ValueError(
|
| 444 |
-
"Only one of `guidance_scale` and `guidance_schedule` may be set."
|
| 445 |
-
)
|
| 446 |
-
if guidance_scale is None and guidance_schedule is None:
|
| 447 |
-
raise ValueError(
|
| 448 |
-
"One of `guidance_scale` and `guidance_schedule` must be set."
|
| 449 |
-
)
|
| 450 |
-
|
| 451 |
-
callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs or ["latents"]
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
f"(vae_scale_factor * patch_size)."
|
| 458 |
-
)
|
| 459 |
|
| 460 |
device = self._execution_device
|
| 461 |
self._guidance_scale = guidance_scale
|
| 462 |
self._interrupt = False
|
| 463 |
|
| 464 |
-
#
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
|
|
|
| 471 |
)
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
text_position_ids=inputs["text_position_ids"],
|
| 481 |
-
indicator=inputs["indicator"],
|
| 482 |
device=device,
|
| 483 |
)
|
| 484 |
|
| 485 |
-
# 3. Replicate
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
|
| 492 |
# 4. Set up the resolution-aware logit-normal schedule on the scheduler.
|
| 493 |
schedule_mu = _resolution_aware_mu(height=height, width=width, base_mu=mu)
|
|
@@ -496,21 +635,16 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 496 |
timesteps = self.scheduler.timesteps
|
| 497 |
self._num_timesteps = len(timesteps)
|
| 498 |
|
| 499 |
-
# 5. Resolve per-step guidance
|
| 500 |
-
# `guidance_schedule`
|
| 501 |
if guidance_scale is not None:
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
gw = torch.as_tensor(guidance_schedule, dtype=torch.float32, device=device)
|
| 505 |
-
if gw.shape != (num_inference_steps,):
|
| 506 |
-
raise ValueError(
|
| 507 |
-
f"`guidance_schedule` must have shape ({num_inference_steps},), got {tuple(gw.shape)}"
|
| 508 |
-
)
|
| 509 |
|
| 510 |
# 6. Prepare latents in the packed (B, num_image_tokens, latent_dim) layout.
|
| 511 |
latent_dim = self.transformer.config.in_channels
|
| 512 |
latents = self.prepare_latents(
|
| 513 |
-
batch_size=
|
| 514 |
num_image_tokens=num_image_tokens,
|
| 515 |
latent_dim=latent_dim,
|
| 516 |
dtype=torch.float32,
|
|
@@ -519,27 +653,21 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 519 |
latents=latents,
|
| 520 |
)
|
| 521 |
|
| 522 |
-
# 7.
|
| 523 |
max_text_tokens = max_sequence_length
|
| 524 |
-
neg_position_ids = inputs["position_ids"][:, max_text_tokens:]
|
| 525 |
-
neg_segment_ids = inputs["segment_ids"][:, max_text_tokens:]
|
| 526 |
-
neg_indicator = inputs["indicator"][:, max_text_tokens:]
|
| 527 |
-
neg_llm_features = torch.zeros(
|
| 528 |
-
effective_batch_size,
|
| 529 |
-
num_image_tokens,
|
| 530 |
-
llm_features.shape[-1],
|
| 531 |
-
dtype=llm_features.dtype,
|
| 532 |
-
device=device,
|
| 533 |
-
)
|
| 534 |
-
|
| 535 |
text_z_padding = torch.zeros(
|
| 536 |
-
|
| 537 |
max_text_tokens,
|
| 538 |
latent_dim,
|
| 539 |
dtype=torch.float32,
|
| 540 |
device=device,
|
| 541 |
)
|
| 542 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
# 8. Denoising loop. The scheduler stores `num_train_timesteps`-scaled timesteps; convert back to model time.
|
| 544 |
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
| 545 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
|
@@ -549,36 +677,40 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 549 |
|
| 550 |
# Map sigma-domain timestep to model time `t` in [0, 1] (0 = noise, 1 = clean data).
|
| 551 |
t_model = 1.0 - (t.float() / num_train_timesteps)
|
| 552 |
-
t_model = t_model.expand(
|
| 553 |
|
| 554 |
# Conditional pass operates on the full packed sequence.
|
| 555 |
-
pos_z = torch.cat([text_z_padding, latents], dim=1)
|
| 556 |
pos_out = self.transformer(
|
| 557 |
hidden_states=pos_z,
|
| 558 |
timestep=t_model,
|
| 559 |
encoder_hidden_states=llm_features,
|
| 560 |
-
position_ids=
|
| 561 |
-
segment_ids=
|
| 562 |
-
indicator=
|
| 563 |
return_dict=False,
|
| 564 |
)[0]
|
| 565 |
-
|
|
|
|
|
|
|
| 566 |
|
| 567 |
# Unconditional pass uses image-only positions with zeroed text features.
|
| 568 |
neg_v = self.unconditional_transformer(
|
| 569 |
-
hidden_states=latents,
|
| 570 |
timestep=t_model,
|
| 571 |
encoder_hidden_states=neg_llm_features,
|
| 572 |
position_ids=neg_position_ids,
|
| 573 |
segment_ids=neg_segment_ids,
|
| 574 |
indicator=neg_indicator,
|
| 575 |
return_dict=False,
|
| 576 |
-
)[0]
|
| 577 |
|
|
|
|
|
|
|
| 578 |
gw_i = gw[i]
|
| 579 |
v = gw_i * pos_v + (1.0 - gw_i) * neg_v
|
| 580 |
|
| 581 |
-
latents = self.scheduler.step(-v
|
| 582 |
|
| 583 |
if callback_on_step_end is not None:
|
| 584 |
callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
|
|
@@ -587,11 +719,24 @@ class Ideogram4Pipeline(DiffusionPipeline):
|
|
| 587 |
|
| 588 |
progress_bar.update()
|
| 589 |
|
| 590 |
-
# 9. Decode.
|
| 591 |
if output_type == "latent":
|
| 592 |
image = latents
|
| 593 |
else:
|
| 594 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
image = self.image_processor.postprocess(decoded.float(), output_type=output_type)
|
| 596 |
|
| 597 |
self.maybe_free_model_hooks()
|
|
|
|
| 29 |
Ideogram4Transformer2DModel,
|
| 30 |
)
|
| 31 |
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
| 32 |
+
from ...utils import is_outlines_available, logging, replace_example_docstring
|
| 33 |
from ...utils.torch_utils import randn_tensor
|
| 34 |
from ..pipeline_utils import DiffusionPipeline
|
| 35 |
from .pipeline_output import Ideogram4PipelineOutput
|
| 36 |
+
from .prompt_enhancer import CAPTION_SYSTEM_MESSAGE, CAPTION_USER_TEMPLATE, build_caption_logits_processor
|
| 37 |
|
| 38 |
|
| 39 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
| 43 |
# text conditioning consumed by the Ideogram4 transformer.
|
| 44 |
QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35)
|
| 45 |
|
| 46 |
+
# LM head grafted onto the (head-less) text encoder for optional prompt upsampling.
|
| 47 |
+
DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO = "multimodalart/qwen3-vl-8b-instruct-lm-head"
|
| 48 |
+
PROMPT_UPSAMPLE_TEMPERATURE = 1.0
|
|
|
|
| 49 |
|
| 50 |
|
| 51 |
EXAMPLE_DOC_STRING = """
|
|
|
|
| 109 |
return base_mu + 0.5 * math.log(num_pixels / base_pixels)
|
| 110 |
|
| 111 |
|
| 112 |
+
def _expand_tensor_to_effective_batch(
|
| 113 |
+
tensor: torch.Tensor,
|
| 114 |
+
batch_size: int,
|
| 115 |
+
num_per_prompt: int,
|
| 116 |
+
tensor_name: str | None = None,
|
| 117 |
+
) -> torch.Tensor:
|
| 118 |
+
"""Replicate `tensor` along dim 0 from `batch_size` (or 1) to `batch_size * num_per_prompt`."""
|
| 119 |
+
target_batch_size = batch_size * num_per_prompt
|
| 120 |
+
|
| 121 |
+
if tensor.shape[0] == target_batch_size:
|
| 122 |
+
return tensor
|
| 123 |
+
|
| 124 |
+
if tensor.shape[0] == 1:
|
| 125 |
+
repeat_by = target_batch_size
|
| 126 |
+
elif tensor.shape[0] == batch_size:
|
| 127 |
+
repeat_by = num_per_prompt
|
| 128 |
+
else:
|
| 129 |
+
tensor_name = f"`{tensor_name}`" if tensor_name is not None else "Tensor"
|
| 130 |
+
raise ValueError(
|
| 131 |
+
f"{tensor_name} batch size must be 1, `batch_size` ({batch_size}), or "
|
| 132 |
+
f"`batch_size * num_*_per_prompt` ({target_batch_size}), but got {tensor.shape[0]}."
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
class Ideogram4Pipeline(DiffusionPipeline):
|
| 139 |
r"""
|
| 140 |
Text-to-image pipeline for Ideogram4.
|
|
|
|
| 191 |
self.patch_size = 2
|
| 192 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * self.patch_size)
|
| 193 |
|
| 194 |
+
# Lazily built by `load_prompt_enhancer` for optional prompt upsampling.
|
| 195 |
+
self._caption_model = None
|
| 196 |
+
self._caption_logits_processor = None
|
| 197 |
+
|
| 198 |
+
def load_prompt_enhancer(
|
| 199 |
+
self,
|
| 200 |
+
lm_head_repo_id: str = DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO,
|
| 201 |
+
lm_head_filename: str = "lm_head.safetensors",
|
| 202 |
+
torch_dtype: torch.dtype | None = None,
|
| 203 |
+
) -> PreTrainedModel:
|
| 204 |
+
"""Make the frozen text encoder generative for prompt upsampling by grafting a hosted LM head.
|
| 205 |
+
|
| 206 |
+
The head is the only extra weight loaded; the encoder body is shared (no second model in memory).
|
| 207 |
+
Called automatically by `upsample_prompt` on first use. Generation is constrained to the caption JSON
|
| 208 |
+
schema when `outlines` is installed; otherwise it falls back to unconstrained decoding with a warning.
|
| 209 |
+
"""
|
| 210 |
+
from accelerate import init_empty_weights
|
| 211 |
+
from huggingface_hub import hf_hub_download
|
| 212 |
+
from safetensors.torch import load_file
|
| 213 |
+
from transformers import Qwen3VLForConditionalGeneration
|
| 214 |
+
|
| 215 |
+
dtype = torch_dtype or self.text_encoder.dtype
|
| 216 |
+
head_weight = load_file(hf_hub_download(lm_head_repo_id, lm_head_filename))["lm_head.weight"].to(dtype)
|
| 217 |
+
|
| 218 |
+
with init_empty_weights():
|
| 219 |
+
caption_model = Qwen3VLForConditionalGeneration(self.text_encoder.config)
|
| 220 |
+
caption_model.model = self.text_encoder # reuse the loaded encoder body
|
| 221 |
+
lm_head = torch.nn.Linear(head_weight.shape[1], head_weight.shape[0], bias=False)
|
| 222 |
+
with torch.no_grad():
|
| 223 |
+
lm_head.weight.copy_(head_weight)
|
| 224 |
+
caption_model.lm_head = lm_head.to(device=self.text_encoder.device, dtype=dtype)
|
| 225 |
+
caption_model.eval()
|
| 226 |
+
|
| 227 |
+
if is_outlines_available():
|
| 228 |
+
logits_processor = build_caption_logits_processor(caption_model, self.tokenizer)
|
| 229 |
+
else:
|
| 230 |
+
logits_processor = None
|
| 231 |
+
logger.warning(
|
| 232 |
+
"`outlines` is not installed; prompt upsampling will run unconstrained and may not return "
|
| 233 |
+
"schema-valid JSON. Install with `pip install outlines` for structured captions."
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
self._caption_model = caption_model
|
| 237 |
+
self._caption_logits_processor = logits_processor
|
| 238 |
+
return caption_model
|
| 239 |
+
|
| 240 |
+
def upsample_prompt(
|
| 241 |
+
self,
|
| 242 |
+
prompt: str | list[str],
|
| 243 |
+
height: int = 2048,
|
| 244 |
+
width: int = 2048,
|
| 245 |
+
max_new_tokens: int = 1024,
|
| 246 |
+
lm_head_repo_id: str = DEFAULT_PROMPT_ENHANCER_LM_HEAD_REPO,
|
| 247 |
+
device: torch.device | None = None,
|
| 248 |
+
) -> list[str]:
|
| 249 |
+
"""Rewrite each prompt into Ideogram4's native structured JSON caption via the grafted text encoder."""
|
| 250 |
+
if self._caption_model is None:
|
| 251 |
+
self.load_prompt_enhancer(lm_head_repo_id=lm_head_repo_id)
|
| 252 |
+
|
| 253 |
+
device = device or self._caption_model.device
|
| 254 |
+
prompts = [prompt] if isinstance(prompt, str) else list(prompt)
|
| 255 |
+
divisor = math.gcd(width, height) or 1
|
| 256 |
+
aspect_ratio = f"{width // divisor}:{height // divisor}"
|
| 257 |
+
|
| 258 |
+
captions = []
|
| 259 |
+
for text_prompt in prompts:
|
| 260 |
+
messages = [
|
| 261 |
+
{"role": "system", "content": CAPTION_SYSTEM_MESSAGE},
|
| 262 |
+
{
|
| 263 |
+
"role": "user",
|
| 264 |
+
"content": CAPTION_USER_TEMPLATE.format(aspect_ratio=aspect_ratio, original_prompt=text_prompt),
|
| 265 |
+
},
|
| 266 |
+
]
|
| 267 |
+
inputs = self.tokenizer.apply_chat_template(
|
| 268 |
+
messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
|
| 269 |
+
).to(device)
|
| 270 |
+
generate_kwargs = {
|
| 271 |
+
"max_new_tokens": max_new_tokens,
|
| 272 |
+
"do_sample": True,
|
| 273 |
+
"temperature": PROMPT_UPSAMPLE_TEMPERATURE,
|
| 274 |
+
"use_cache": True,
|
| 275 |
+
}
|
| 276 |
+
if self._caption_logits_processor is not None:
|
| 277 |
+
self._caption_logits_processor.reset()
|
| 278 |
+
generate_kwargs["logits_processor"] = [self._caption_logits_processor]
|
| 279 |
+
generated = self._caption_model.generate(**inputs, **generate_kwargs)
|
| 280 |
+
new_tokens = generated[:, inputs["input_ids"].shape[1] :]
|
| 281 |
+
captions.append(self.tokenizer.decode(new_tokens[0], skip_special_tokens=True).strip())
|
| 282 |
+
return captions
|
| 283 |
+
|
| 284 |
+
def _prepare_ids(
|
| 285 |
self,
|
| 286 |
+
text_lengths: list[int],
|
| 287 |
+
grid_h: int,
|
| 288 |
+
grid_w: int,
|
| 289 |
max_text_tokens: int,
|
| 290 |
device: torch.device,
|
| 291 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 292 |
+
"""Build the packed `[left-pad][text][image]` layout from the per-prompt text lengths and the image grid.
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
Returns `position_ids` (3-axis MRoPE), `segment_ids` (block-diagonal attention) and `indicator` (per-token
|
| 295 |
+
text/image/pad role).
|
| 296 |
+
"""
|
| 297 |
+
batch_size = len(text_lengths)
|
|
|
|
| 298 |
num_image_tokens = grid_h * grid_w
|
| 299 |
total_seq_len = max_text_tokens + num_image_tokens
|
| 300 |
|
|
|
|
| 304 |
t_idx = torch.zeros_like(h_idx)
|
| 305 |
image_pos = torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET
|
| 306 |
|
|
|
|
|
|
|
| 307 |
position_ids = torch.zeros(batch_size, total_seq_len, 3, dtype=torch.long)
|
| 308 |
segment_ids = torch.full((batch_size, total_seq_len), SEQUENCE_PADDING_INDICATOR, dtype=torch.long)
|
| 309 |
indicator = torch.zeros(batch_size, total_seq_len, dtype=torch.long)
|
| 310 |
|
| 311 |
+
for b, num_text in enumerate(text_lengths):
|
| 312 |
+
offset = max_text_tokens - num_text
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
text_pos = torch.arange(num_text)
|
| 315 |
text_pos_3d = torch.stack([text_pos, text_pos, text_pos], dim=1)
|
|
|
|
| 316 |
position_ids[b, offset : offset + num_text] = text_pos_3d
|
| 317 |
position_ids[b, offset + num_text :] = image_pos
|
| 318 |
|
|
|
|
| 321 |
|
| 322 |
segment_ids[b, offset : offset + num_text + num_image_tokens] = 1
|
| 323 |
|
| 324 |
+
return position_ids.to(device), segment_ids.to(device), indicator.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
def _get_text_encoder_hidden_states(
|
| 327 |
self,
|
|
|
|
| 366 |
|
| 367 |
def encode_prompt(
|
| 368 |
self,
|
| 369 |
+
prompt: str | list[str],
|
| 370 |
+
grid_h: int,
|
| 371 |
+
grid_w: int,
|
| 372 |
+
max_sequence_length: int,
|
| 373 |
device: torch.device,
|
| 374 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 375 |
+
"""Prepare the conditioning for the packed text+image sequence (one entry per prompt).
|
|
|
|
| 376 |
|
| 377 |
+
Returns a flat tuple `(prompt_embeds, position_ids, segment_ids, indicator)`. The unconditional branch carries
|
| 378 |
+
no text, so the pipeline builds its (zeroed) inputs directly rather than encoding a negative prompt.
|
| 379 |
+
"""
|
| 380 |
+
prompts = [prompt] if isinstance(prompt, str) else list(prompt)
|
| 381 |
+
batch_size = len(prompts)
|
| 382 |
+
num_image_tokens = grid_h * grid_w
|
| 383 |
|
| 384 |
+
# Tokenize each chat-formatted prompt and left-pad to `max_sequence_length`. Only the text region is fed to
|
| 385 |
+
# the encoder: the packed image tokens come after the text and the encoder is causal, so they never affect it.
|
| 386 |
+
token_ids = torch.zeros(batch_size, max_sequence_length, dtype=torch.long)
|
| 387 |
+
attention_mask = torch.zeros(batch_size, max_sequence_length, dtype=torch.long)
|
| 388 |
+
text_position_ids = torch.zeros(batch_size, max_sequence_length, dtype=torch.long)
|
| 389 |
+
text_lengths = []
|
| 390 |
+
for b, text_prompt in enumerate(prompts):
|
| 391 |
+
messages = [{"role": "user", "content": [{"type": "text", "text": text_prompt}]}]
|
| 392 |
+
text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
| 393 |
+
toks = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
|
| 394 |
+
n = int(toks.shape[0])
|
| 395 |
+
if n > max_sequence_length:
|
| 396 |
+
raise ValueError(f"prompt has {n} tokens, exceeds max_sequence_length={max_sequence_length}")
|
| 397 |
+
text_lengths.append(n)
|
| 398 |
+
offset = max_sequence_length - n
|
| 399 |
+
token_ids[b, offset:] = toks
|
| 400 |
+
attention_mask[b, offset:] = 1
|
| 401 |
+
text_position_ids[b, offset:] = torch.arange(n)
|
| 402 |
+
|
| 403 |
+
token_ids = token_ids.to(device)
|
| 404 |
+
attention_mask = attention_mask.to(device)
|
| 405 |
+
text_position_ids = text_position_ids.to(device)
|
| 406 |
+
|
| 407 |
+
# Concatenate the tapped activation-layer hidden states into per-token text features, zeroing padding.
|
| 408 |
+
selected = self._get_text_encoder_hidden_states(token_ids, attention_mask, text_position_ids)
|
| 409 |
+
text_features = torch.stack(selected, dim=0).permute(1, 2, 3, 0).reshape(batch_size, max_sequence_length, -1)
|
| 410 |
+
text_features = (text_features * attention_mask.to(text_features.dtype).unsqueeze(-1)).to(torch.float32)
|
| 411 |
+
|
| 412 |
+
position_ids, segment_ids, indicator = self._prepare_ids(
|
| 413 |
+
text_lengths, grid_h, grid_w, max_sequence_length, device
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Pack the text features into the full sequence; image positions carry no text features.
|
| 417 |
+
image_feature_padding = torch.zeros(
|
| 418 |
+
batch_size, num_image_tokens, text_features.shape[-1], dtype=text_features.dtype, device=device
|
| 419 |
+
)
|
| 420 |
+
prompt_embeds = torch.cat([text_features, image_feature_padding], dim=1)
|
| 421 |
|
| 422 |
+
return prompt_embeds, position_ids, segment_ids, indicator
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
def prepare_latents(
|
| 425 |
self,
|
|
|
|
| 440 |
latents = latents.to(device=device, dtype=dtype)
|
| 441 |
return latents
|
| 442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
@property
|
| 444 |
def guidance_scale(self) -> float | None:
|
| 445 |
return self._guidance_scale
|
|
|
|
| 452 |
def interrupt(self) -> bool:
|
| 453 |
return self._interrupt
|
| 454 |
|
| 455 |
+
def check_inputs(
|
| 456 |
+
self,
|
| 457 |
+
prompt,
|
| 458 |
+
height,
|
| 459 |
+
width,
|
| 460 |
+
num_inference_steps,
|
| 461 |
+
guidance_scale,
|
| 462 |
+
guidance_schedule,
|
| 463 |
+
callback_on_step_end_tensor_inputs=None,
|
| 464 |
+
):
|
| 465 |
+
if prompt is None:
|
| 466 |
+
raise ValueError("`prompt` must be provided.")
|
| 467 |
+
if not isinstance(prompt, (str, list)):
|
| 468 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 469 |
+
|
| 470 |
+
if (
|
| 471 |
+
height % (self.vae_scale_factor * self.patch_size) != 0
|
| 472 |
+
or width % (self.vae_scale_factor * self.patch_size) != 0
|
| 473 |
+
):
|
| 474 |
+
raise ValueError(
|
| 475 |
+
f"`height` ({height}) and `width` ({width}) must both be divisible by {self.vae_scale_factor * self.patch_size} "
|
| 476 |
+
f"(vae_scale_factor * patch_size)."
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Guidance is controlled by either a constant `guidance_scale` or a per-step `guidance_schedule`; exactly
|
| 480 |
+
# one must be set (the `guidance_schedule` default makes the no-arg call use the recommended schedule).
|
| 481 |
+
if guidance_scale is not None and guidance_schedule is not None:
|
| 482 |
+
raise ValueError("Only one of `guidance_scale` and `guidance_schedule` may be set.")
|
| 483 |
+
if guidance_scale is None and guidance_schedule is None:
|
| 484 |
+
raise ValueError("One of `guidance_scale` and `guidance_schedule` must be set.")
|
| 485 |
+
if guidance_schedule is not None and len(guidance_schedule) != num_inference_steps:
|
| 486 |
+
raise ValueError(
|
| 487 |
+
f"`guidance_schedule` must have length `num_inference_steps` ({num_inference_steps}), "
|
| 488 |
+
f"got {len(guidance_schedule)}."
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 492 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 493 |
+
):
|
| 494 |
+
raise ValueError(
|
| 495 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
| 496 |
+
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
@torch.no_grad()
|
| 500 |
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 501 |
def __call__(
|
|
|
|
| 503 |
prompt: str | list[str] | None = None,
|
| 504 |
height: int = 2048,
|
| 505 |
width: int = 2048,
|
| 506 |
+
num_inference_steps: int = 48,
|
| 507 |
guidance_scale: float | None = None,
|
| 508 |
+
guidance_schedule: list[float] | torch.Tensor | None = (7.0,) * 45 + (3.0,) * 3,
|
| 509 |
+
mu: float = 0.0,
|
| 510 |
+
std: float = 1.5,
|
| 511 |
+
prompt_upsampling: bool = False,
|
| 512 |
max_sequence_length: int = 2048,
|
| 513 |
num_images_per_prompt: int = 1,
|
| 514 |
generator: torch.Generator | list[torch.Generator] | None = None,
|
|
|
|
| 516 |
output_type: str = "pil",
|
| 517 |
return_dict: bool = True,
|
| 518 |
callback_on_step_end: Callable[["Ideogram4Pipeline", int, int, dict[str, Any]], dict[str, Any]] | None = None,
|
| 519 |
+
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
| 520 |
) -> Ideogram4PipelineOutput | tuple[Any]:
|
| 521 |
r"""
|
| 522 |
Run text-to-image generation.
|
|
|
|
| 535 |
velocity predictions are blended as `v = guidance_scale * v_pos + (1 - guidance_scale) * v_neg`.
|
| 536 |
Mutually exclusive with `guidance_schedule` (setting both raises). Defaults to `None`.
|
| 537 |
guidance_schedule (`list[float]` or `torch.Tensor`, *optional*):
|
| 538 |
+
Per-step guidance scale schedule; must have length `num_inference_steps`. The first entry corresponds
|
| 539 |
+
to the first step (largest noise level). Mutually exclusive with `guidance_scale`; exactly one must be
|
| 540 |
+
set. Defaults to the recommended schedule (7.0 for the main steps, dropping to 3.0 for the final 3
|
| 541 |
+
"polish" steps). To use a constant scale instead, pass `guidance_scale` and `guidance_schedule=None`.
|
|
|
|
| 542 |
mu (`float`, *optional*, defaults to 0.0):
|
| 543 |
Base mean of the logit-normal flow-matching schedule. The schedule mean is shifted by half the log of
|
| 544 |
the resolution ratio relative to 512x512.
|
| 545 |
std (`float`, *optional*, defaults to 1.5):
|
| 546 |
Standard deviation of the logit-normal flow-matching schedule.
|
| 547 |
+
prompt_upsampling (`bool`, *optional*, defaults to `False`):
|
| 548 |
+
If `True`, rewrite `prompt` into Ideogram4's native structured JSON caption via
|
| 549 |
+
[`~Ideogram4Pipeline.upsample_prompt`] before encoding. Requires the prompt-enhancer LM head
|
| 550 |
+
(downloaded on first use); install `outlines` for schema-constrained captions.
|
| 551 |
max_sequence_length (`int`, *optional*, defaults to 2048):
|
| 552 |
Maximum number of text tokens per prompt.
|
| 553 |
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
|
|
|
| 570 |
Returns:
|
| 571 |
[`~pipelines.ideogram4.Ideogram4PipelineOutput`] or `tuple`.
|
| 572 |
"""
|
| 573 |
+
self.check_inputs(
|
| 574 |
+
prompt=prompt,
|
| 575 |
+
height=height,
|
| 576 |
+
width=width,
|
| 577 |
+
num_inference_steps=num_inference_steps,
|
| 578 |
+
guidance_scale=guidance_scale,
|
| 579 |
+
guidance_schedule=guidance_schedule,
|
| 580 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 581 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
|
| 583 |
+
if isinstance(prompt, str):
|
| 584 |
+
batch_size = 1
|
| 585 |
+
elif isinstance(prompt, list):
|
| 586 |
+
batch_size = len(prompt)
|
|
|
|
|
|
|
| 587 |
|
| 588 |
device = self._execution_device
|
| 589 |
self._guidance_scale = guidance_scale
|
| 590 |
self._interrupt = False
|
| 591 |
|
| 592 |
+
# 0. Optionally rewrite the prompt(s) into Ideogram4's native structured JSON caption.
|
| 593 |
+
if prompt_upsampling:
|
| 594 |
+
prompt = self.upsample_prompt(prompt, height=height, width=width, device=device)
|
| 595 |
+
|
| 596 |
+
# 1. Image grid (drives both the packed layout and the latent shape).
|
| 597 |
+
grid_h, grid_w = (
|
| 598 |
+
height // (self.vae_scale_factor * self.patch_size),
|
| 599 |
+
width // (self.vae_scale_factor * self.patch_size),
|
| 600 |
)
|
| 601 |
+
num_image_tokens = grid_h * grid_w
|
| 602 |
+
|
| 603 |
+
# 2. Encode prompts into the packed conditioning (one entry per prompt).
|
| 604 |
+
llm_features, position_ids, segment_ids, indicator = self.encode_prompt(
|
| 605 |
+
prompt=prompt,
|
| 606 |
+
grid_h=grid_h,
|
| 607 |
+
grid_w=grid_w,
|
| 608 |
+
max_sequence_length=max_sequence_length,
|
|
|
|
|
|
|
| 609 |
device=device,
|
| 610 |
)
|
| 611 |
|
| 612 |
+
# 3. Replicate the conditioning for num_images_per_prompt.
|
| 613 |
+
llm_features = _expand_tensor_to_effective_batch(llm_features, batch_size, num_images_per_prompt)
|
| 614 |
+
position_ids = _expand_tensor_to_effective_batch(position_ids, batch_size, num_images_per_prompt)
|
| 615 |
+
segment_ids = _expand_tensor_to_effective_batch(segment_ids, batch_size, num_images_per_prompt)
|
| 616 |
+
indicator = _expand_tensor_to_effective_batch(indicator, batch_size, num_images_per_prompt)
|
| 617 |
+
|
| 618 |
+
# 4. Unconditional (image-only) branch, derived from the conditioning: zeroed text features and the
|
| 619 |
+
# image-region slices of the layout.
|
| 620 |
+
neg_llm_features = torch.zeros(
|
| 621 |
+
batch_size * num_images_per_prompt,
|
| 622 |
+
num_image_tokens,
|
| 623 |
+
llm_features.shape[-1],
|
| 624 |
+
dtype=llm_features.dtype,
|
| 625 |
+
device=device,
|
| 626 |
+
)
|
| 627 |
+
neg_position_ids = position_ids[:, max_sequence_length:]
|
| 628 |
+
neg_segment_ids = segment_ids[:, max_sequence_length:]
|
| 629 |
+
neg_indicator = indicator[:, max_sequence_length:]
|
| 630 |
|
| 631 |
# 4. Set up the resolution-aware logit-normal schedule on the scheduler.
|
| 632 |
schedule_mu = _resolution_aware_mu(height=height, width=width, base_mu=mu)
|
|
|
|
| 635 |
timesteps = self.scheduler.timesteps
|
| 636 |
self._num_timesteps = len(timesteps)
|
| 637 |
|
| 638 |
+
# 5. Resolve the per-step guidance schedule (a constant `guidance_scale` broadcasts to every step, otherwise
|
| 639 |
+
# use the provided `guidance_schedule`, validated by `check_inputs`) and the tensor of per-step weights `gw`.
|
| 640 |
if guidance_scale is not None:
|
| 641 |
+
guidance_schedule = [float(guidance_scale)] * num_inference_steps
|
| 642 |
+
gw = torch.as_tensor(guidance_schedule, dtype=torch.float32, device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
|
| 644 |
# 6. Prepare latents in the packed (B, num_image_tokens, latent_dim) layout.
|
| 645 |
latent_dim = self.transformer.config.in_channels
|
| 646 |
latents = self.prepare_latents(
|
| 647 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 648 |
num_image_tokens=num_image_tokens,
|
| 649 |
latent_dim=latent_dim,
|
| 650 |
dtype=torch.float32,
|
|
|
|
| 653 |
latents=latents,
|
| 654 |
)
|
| 655 |
|
| 656 |
+
# 7. Padding for the text region of the conditional packed sequence (image latents are appended after it).
|
| 657 |
max_text_tokens = max_sequence_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
text_z_padding = torch.zeros(
|
| 659 |
+
batch_size * num_images_per_prompt,
|
| 660 |
max_text_tokens,
|
| 661 |
latent_dim,
|
| 662 |
dtype=torch.float32,
|
| 663 |
device=device,
|
| 664 |
)
|
| 665 |
|
| 666 |
+
# The transformers run in their loaded compute dtype; cast the (otherwise float32) text features to match.
|
| 667 |
+
# `latents` stay float32 for scheduler precision and are cast per-step at the transformer call below.
|
| 668 |
+
llm_features = llm_features.to(self.transformer.dtype)
|
| 669 |
+
neg_llm_features = neg_llm_features.to(self.unconditional_transformer.dtype)
|
| 670 |
+
|
| 671 |
# 8. Denoising loop. The scheduler stores `num_train_timesteps`-scaled timesteps; convert back to model time.
|
| 672 |
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
| 673 |
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
|
|
|
| 677 |
|
| 678 |
# Map sigma-domain timestep to model time `t` in [0, 1] (0 = noise, 1 = clean data).
|
| 679 |
t_model = 1.0 - (t.float() / num_train_timesteps)
|
| 680 |
+
t_model = t_model.expand(batch_size * num_images_per_prompt).to(self.transformer.dtype)
|
| 681 |
|
| 682 |
# Conditional pass operates on the full packed sequence.
|
| 683 |
+
pos_z = torch.cat([text_z_padding, latents], dim=1).to(self.transformer.dtype)
|
| 684 |
pos_out = self.transformer(
|
| 685 |
hidden_states=pos_z,
|
| 686 |
timestep=t_model,
|
| 687 |
encoder_hidden_states=llm_features,
|
| 688 |
+
position_ids=position_ids,
|
| 689 |
+
segment_ids=segment_ids,
|
| 690 |
+
indicator=indicator,
|
| 691 |
return_dict=False,
|
| 692 |
)[0]
|
| 693 |
+
# Velocity (and guidance) is computed in float32 for scheduler precision; the transformers
|
| 694 |
+
# return their compute dtype, so cast the predicted velocities up here.
|
| 695 |
+
pos_v = pos_out[:, max_text_tokens:].to(torch.float32)
|
| 696 |
|
| 697 |
# Unconditional pass uses image-only positions with zeroed text features.
|
| 698 |
neg_v = self.unconditional_transformer(
|
| 699 |
+
hidden_states=latents.to(self.unconditional_transformer.dtype),
|
| 700 |
timestep=t_model,
|
| 701 |
encoder_hidden_states=neg_llm_features,
|
| 702 |
position_ids=neg_position_ids,
|
| 703 |
segment_ids=neg_segment_ids,
|
| 704 |
indicator=neg_indicator,
|
| 705 |
return_dict=False,
|
| 706 |
+
)[0].to(torch.float32)
|
| 707 |
|
| 708 |
+
# Expose the current step's guidance weight via `self.guidance_scale` so callbacks can read it.
|
| 709 |
+
self._guidance_scale = guidance_schedule[i]
|
| 710 |
gw_i = gw[i]
|
| 711 |
v = gw_i * pos_v + (1.0 - gw_i) * neg_v
|
| 712 |
|
| 713 |
+
latents = self.scheduler.step(-v, t, latents, return_dict=False)[0]
|
| 714 |
|
| 715 |
if callback_on_step_end is not None:
|
| 716 |
callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
|
|
|
|
| 719 |
|
| 720 |
progress_bar.update()
|
| 721 |
|
| 722 |
+
# 9. Decode: unpatch the latents, denormalize with the VAE batch-norm stats, and decode through the VAE.
|
| 723 |
if output_type == "latent":
|
| 724 |
image = latents
|
| 725 |
else:
|
| 726 |
+
z = latents
|
| 727 |
+
# VAE bn stores per-channel statistics on the packed-channel latent space (ae_channels * patch ** 2).
|
| 728 |
+
bn_mean = self.vae.bn.running_mean.view(1, 1, -1).to(device=z.device, dtype=z.dtype)
|
| 729 |
+
bn_std = torch.sqrt(self.vae.bn.running_var + self.vae.config.batch_norm_eps).view(1, 1, -1)
|
| 730 |
+
bn_std = bn_std.to(device=z.device, dtype=z.dtype)
|
| 731 |
+
z = z * bn_std + bn_mean
|
| 732 |
+
|
| 733 |
+
patch = self.patch_size
|
| 734 |
+
ae_channels = z.shape[-1] // (patch * patch)
|
| 735 |
+
z = z.view(batch_size * num_images_per_prompt, grid_h, grid_w, patch, patch, ae_channels)
|
| 736 |
+
z = z.permute(0, 5, 1, 3, 2, 4).contiguous()
|
| 737 |
+
z = z.view(batch_size * num_images_per_prompt, ae_channels, grid_h * patch, grid_w * patch)
|
| 738 |
+
|
| 739 |
+
decoded = self.vae.decode(z.to(self.vae.dtype), return_dict=False)[0]
|
| 740 |
image = self.image_processor.postprocess(decoded.float(), output_type=output_type)
|
| 741 |
|
| 742 |
self.maybe_free_model_hooks()
|
diffusers_src/src/diffusers/pipelines/ideogram4/prompt_enhancer.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Ideogram AI and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Prompt-enhancement assets for Ideogram4.
|
| 16 |
+
|
| 17 |
+
Ideogram4 is trained on a *structured JSON caption* rather than a free-form prompt. The optional prompt
|
| 18 |
+
enhancer rewrites a short user idea into that native caption schema, using the pipeline's own (frozen)
|
| 19 |
+
Qwen3-VL text encoder grafted with a generative head (see `Ideogram4Pipeline.load_prompt_enhancer`).
|
| 20 |
+
|
| 21 |
+
This mirrors the role of Flux2's `system_messages.py`, but the target is a constrained JSON object instead of
|
| 22 |
+
free text, so `outlines` (an optional dependency) is used to guarantee a schema-valid result when available.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
# System message that instructs the encoder to emit Ideogram4's native single-line JSON caption.
|
| 26 |
+
CAPTION_SYSTEM_MESSAGE = """You convert a short user idea into a structured JSON caption for an image renderer. Output ONE minified single-line JSON object and NOTHING else (no markdown, no commentary).
|
| 27 |
+
|
| 28 |
+
SCHEMA — keys in this exact order:
|
| 29 |
+
{"high_level_description":"...","compositional_deconstruction":{"background":"...","elements":[ ... ]}}
|
| 30 |
+
- object element: {"type":"obj","desc":"..."}
|
| 31 |
+
- text element: {"type":"text","text":"VERBATIM CHARS","desc":"..."}
|
| 32 |
+
|
| 33 |
+
STEP 1 — PICK THE MEDIUM. It decides what `background` and `elements` mean. Honor any medium or style the user implies; default to photograph only when nothing else fits. Render ANY subject faithfully — real, fantastical, sci-fi, surreal, abstract — in the chosen medium.
|
| 34 |
+
|
| 35 |
+
A) DESIGNED ARTIFACT — poster, logo, album/book cover, flyer, banner, sticker, packaging, app icon, infographic, menu, card, wordmark. THE FRAME IS THE ARTIFACT — never a photo of it hanging in a room.
|
| 36 |
+
- high_level_description: name it as graphic design (e.g. "a minimalist jazz poster, flat graphic design...").
|
| 37 |
+
- background: the design's OWN backdrop only — a flat color, gradient, or simple texture filling the frame. No room, wall, floor, easel, depth, or camera/photo language.
|
| 38 |
+
- elements: the design's parts as a flat 2D layout — a `text` element for every headline/label (verbatim), `obj` elements for the central graphic/illustration/shapes/badges. Place by region (top / center / bottom).
|
| 39 |
+
|
| 40 |
+
B) SCENE — a photograph, illustration, painting, 3D render, anime frame, etc. of a real or imagined place or subject.
|
| 41 |
+
- high_level_description: one sentence naming the subject and the medium/style.
|
| 42 |
+
- background: the scene SHELL — surroundings, ground/sky/walls, atmosphere, ambient light; concrete and specific. The ground/floor/water surface lives here, never as an element.
|
| 43 |
+
- elements: the main subject FIRST as an `obj`, then supporting `obj` elements (props, secondary subjects) that plausibly belong. Add `text` elements only where the scene would really carry text (signs, labels, brands).
|
| 44 |
+
|
| 45 |
+
C) ABSTRACT / CONCEPTUAL — "nostalgia", "chaos and order", "sound waves", pure pattern. Concretize the idea into a deliberate visual composition.
|
| 46 |
+
- background: the dominant color field, gradient, or texture of the composition.
|
| 47 |
+
- elements: the shapes, forms, motifs, or symbolic objects that carry the concept, as `obj` elements. Add `text` only if the idea calls for words.
|
| 48 |
+
|
| 49 |
+
UNIVERSAL RULES (every medium):
|
| 50 |
+
1. The user's core subject/concept MUST appear among the elements (as an `obj`, normally first). Naming it only in high_level_description or background is NOT enough.
|
| 51 |
+
2. Commit to ONE concrete value each (one color, one style, one count). No hedging: ban "various", "such as", "e.g.", "or similar", "maybe", "X or Y" for one property.
|
| 52 |
+
3. NEVER use a transparent, empty, or plain white background UNLESS the user explicitly says "transparent", "isolated", "sticker", or "cutout".
|
| 53 |
+
4. A coherent subject (one animal, person, vehicle, object) is exactly ONE element; its parts go inside its `desc`. Use separate elements for genuinely separate subjects.
|
| 54 |
+
5. Each `desc` is 25-55 words, identity-first, standalone. Do not mention shadows, depth of field, bokeh, lens, focus, or grain.
|
| 55 |
+
6. high_level_description: one sentence, at most 40 words, starts with the subject, names the medium. Preserve non-ASCII characters as-is.
|
| 56 |
+
7. Output STRICTLY VALID JSON: double quotes around every key and string, NO trailing commas, each element object closes with "}" right after its last value.
|
| 57 |
+
8. Catch the "warm" impulse. Only when you are about to describe light as "warm", "golden", "amber", or "honey", stop and check: is there a specific physical source in the scene casting that colour (candle, sunset, lamp, neon, fire)? If YES, name the source and the colour it casts instead of the mood word. If NO, you are just reaching for warmth as ambience — drop it and leave the light neutral ("soft" or "even"). Don't recolour or relight anything else; this only intercepts the warm reach, every other scene and mood the user wants is untouched.
|
| 58 |
+
9. Describe physical reality, not impressions. Avoid mood-words — "luminous", "radiant", "vibrant", "lush", "dynamic", "gorgeous", "stunning", "breathtaking", "mesmerizing", and metaphorical "glowing" — they produce a generic AI look (the same trap as "warm"). Use observable properties: "the cheekbone catches a small highlight", not "luminous complexion".
|
| 59 |
+
10. Every named thing must appear as its own element. Each subject, object, sign, and quoted phrase the user names gets its own element — quoted text (single or double quotes) becomes its own verbatim `text` element. Count the named units in the prompt; the element list must hold at least that many. Don't drop or merge them.
|
| 60 |
+
11. Don't add what wasn't asked for. No glitch art, wireframe overlay, body fragmentation, double-exposure, "dissolving", or extra stylization unless the prompt requests it. Asked for a cinematic photo of a journalist → render that, not a glitch-art composite.
|
| 61 |
+
12. Name attributes concretely, anchored to landmarks. People: skin tone, hair (colour + style), each visible garment with colour, expression, pose, one distinguishing feature. Objects: shape, material, colour, a distinctive part. Place things against named references — "resting on the lower-right corner of the table", not "on the surface".
|
| 62 |
+
13. Name real references by name. If the user names a brand, product, character, place, or person (Nike Dunk Low, Spider-Man, the Eiffel Tower), keep that exact name in the `desc`; don't swap it for a generic look-alike unless they ask for an anonymous one.
|
| 63 |
+
14. "Professional photo/headshot" of a person means professional CONTEXT — neutral attire, soft even daylight, neutral backdrop, friendly expression — not dramatic studio gear; no heavy rim-light or creamy bokeh unless asked.
|
| 64 |
+
|
| 65 |
+
EXAMPLES
|
| 66 |
+
|
| 67 |
+
User idea: a cup of coffee on a table
|
| 68 |
+
Output: {"high_level_description":"A white ceramic cup of black coffee on a worn wooden cafe table, a casual overcast-daylight phone photograph with an off-center composition.","compositional_deconstruction":{"background":"Scratched oak cafe table filling the lower frame, a pale grey mortar-lined brick wall a few feet behind slightly out of focus, a tall window on the left spilling soft overcast daylight across the table, neutral white balance, muted brown and green tones.","elements":[{"type":"obj","desc":"White ceramic cup of black coffee with a thin curved handle turned to the right and a faint crema ring at the rim, resting on a matching round saucer near the center of the table, a thin wisp of steam at the surface."},{"type":"obj","desc":"Brushed-steel teaspoon lying on the saucer to the right of the cup, handle angled toward the lower-right corner, a single small water droplet on the bowl of the spoon."}]}}
|
| 69 |
+
|
| 70 |
+
User idea: a minimalist poster for a jazz festival
|
| 71 |
+
Output: {"high_level_description":"A minimalist jazz festival poster, flat graphic design with bold typography and a single abstract saxophone motif on a deep teal background.","compositional_deconstruction":{"background":"Solid deep teal background filling the entire frame with a subtle fine paper-grain texture and a thin mustard-yellow keyline border just inside the edges, no scene and no depth.","elements":[{"type":"obj","desc":"A large flat geometric saxophone in mustard yellow and cream, centered in the upper two-thirds, built from simple bold shapes with no shading, angled diagonally from lower-left to upper-right."},{"type":"text","text":"JAZZ\\nFESTIVAL","desc":"Large bold condensed sans-serif headline in cream, stacked on two lines across the center of the poster, slightly overlapping the saxophone motif."},{"type":"text","text":"NOV 15 · CITY HALL","desc":"Small uppercase mustard-yellow caption centered near the bottom edge with wide letter spacing."}]}}"""
|
| 72 |
+
|
| 73 |
+
# User turn. `{aspect_ratio}` and `{original_prompt}` are filled in by `Ideogram4Pipeline.upsample_prompt`.
|
| 74 |
+
CAPTION_USER_TEMPLATE = """TARGET IMAGE ASPECT RATIO: {aspect_ratio} (width:height).
|
| 75 |
+
User idea: {original_prompt}"""
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def build_caption_logits_processor(model, tokenizer):
|
| 79 |
+
"""Build an `outlines` logits processor that constrains generation to the Ideogram4 caption schema.
|
| 80 |
+
|
| 81 |
+
Returns a logits processor compatible with `transformers` `generate(logits_processor=[...])`. The caller is
|
| 82 |
+
responsible for checking `is_outlines_available()` first; `outlines` (and its `pydantic` dependency) are
|
| 83 |
+
imported lazily here so they remain optional. The schema mirrors Ideogram's native caption /
|
| 84 |
+
caption_verifier: a high-level description plus a compositional deconstruction of background + typed elements.
|
| 85 |
+
"""
|
| 86 |
+
from typing import List, Literal, Union
|
| 87 |
+
|
| 88 |
+
import outlines
|
| 89 |
+
from pydantic import BaseModel, Field
|
| 90 |
+
|
| 91 |
+
class ObjElement(BaseModel):
|
| 92 |
+
type: Literal["obj"]
|
| 93 |
+
desc: str
|
| 94 |
+
|
| 95 |
+
class TextElement(BaseModel):
|
| 96 |
+
type: Literal["text"]
|
| 97 |
+
text: str
|
| 98 |
+
desc: str
|
| 99 |
+
|
| 100 |
+
class Composition(BaseModel):
|
| 101 |
+
background: str
|
| 102 |
+
elements: List[Union[ObjElement, TextElement]] = Field(min_length=1)
|
| 103 |
+
|
| 104 |
+
class Caption(BaseModel):
|
| 105 |
+
high_level_description: str
|
| 106 |
+
compositional_deconstruction: Composition
|
| 107 |
+
|
| 108 |
+
outlines_model = outlines.from_transformers(model, tokenizer)
|
| 109 |
+
return outlines.Generator(outlines_model, Caption).logits_processor
|
diffusers_src/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
CHANGED
|
@@ -206,12 +206,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
| 206 |
module._parameters[tensor_name] = new_value
|
| 207 |
|
| 208 |
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
| 209 |
-
import math
|
| 210 |
-
|
| 211 |
current_param_shape = current_param.shape
|
| 212 |
loaded_param_shape = loaded_param.shape
|
| 213 |
|
| 214 |
-
n =
|
| 215 |
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
|
| 216 |
if loaded_param_shape != inferred_shape:
|
| 217 |
raise ValueError(
|
|
|
|
| 206 |
module._parameters[tensor_name] = new_value
|
| 207 |
|
| 208 |
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
|
|
|
|
|
|
| 209 |
current_param_shape = current_param.shape
|
| 210 |
loaded_param_shape = loaded_param.shape
|
| 211 |
|
| 212 |
+
n = current_param_shape.numel()
|
| 213 |
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
|
| 214 |
if loaded_param_shape != inferred_shape:
|
| 215 |
raise ValueError(
|
diffusers_src/src/diffusers/utils/__init__.py
CHANGED
|
@@ -101,6 +101,7 @@ from .import_utils import (
|
|
| 101 |
is_opencv_available,
|
| 102 |
is_optimum_quanto_available,
|
| 103 |
is_optimum_quanto_version,
|
|
|
|
| 104 |
is_peft_available,
|
| 105 |
is_peft_version,
|
| 106 |
is_pytorch_retinaface_available,
|
|
|
|
| 101 |
is_opencv_available,
|
| 102 |
is_optimum_quanto_available,
|
| 103 |
is_optimum_quanto_version,
|
| 104 |
+
is_outlines_available,
|
| 105 |
is_peft_available,
|
| 106 |
is_peft_version,
|
| 107 |
is_pytorch_retinaface_available,
|
diffusers_src/src/diffusers/utils/import_utils.py
CHANGED
|
@@ -204,6 +204,7 @@ _wandb_available, _wandb_version = _is_package_available("wandb")
|
|
| 204 |
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
|
| 205 |
_compel_available, _compel_version = _is_package_available("compel")
|
| 206 |
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
|
|
|
|
| 207 |
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
|
| 208 |
_peft_available, _peft_version = _is_package_available("peft")
|
| 209 |
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
|
|
@@ -370,6 +371,10 @@ def is_sentencepiece_available():
|
|
| 370 |
return _sentencepiece_available
|
| 371 |
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
def is_imageio_available():
|
| 374 |
return _imageio_available
|
| 375 |
|
|
|
|
| 204 |
_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard")
|
| 205 |
_compel_available, _compel_version = _is_package_available("compel")
|
| 206 |
_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece")
|
| 207 |
+
_outlines_available, _outlines_version = _is_package_available("outlines")
|
| 208 |
_torchsde_available, _torchsde_version = _is_package_available("torchsde")
|
| 209 |
_peft_available, _peft_version = _is_package_available("peft")
|
| 210 |
_torchvision_available, _torchvision_version = _is_package_available("torchvision")
|
|
|
|
| 371 |
return _sentencepiece_available
|
| 372 |
|
| 373 |
|
| 374 |
+
def is_outlines_available():
|
| 375 |
+
return _outlines_available
|
| 376 |
+
|
| 377 |
+
|
| 378 |
def is_imageio_available():
|
| 379 |
return _imageio_available
|
| 380 |
|