| import torch |
| from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer |
| from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention, |
| FeedForward) |
| from .ltx2_common import rms_norm |
|
|
|
|
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
    """Gemma3 conditional-generation model built from a configuration hard-coded in this class.

    The constructor takes no arguments: the complete `Gemma3Config` (top-level, text and vision
    sections) is assembled inline and handed to the parent constructor.
    """

    def __init__(self):
        """Assemble the fixed `Gemma3Config` and initialise the parent model with it."""
        # 48 decoder layers: each group of six is five sliding-window layers followed by one
        # full-attention layer.
        layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 8

        text_config = {
            "_sliding_window_pattern": 6,
            "attention_bias": False,
            "attention_dropout": 0.0,
            "attn_logit_softcapping": None,
            "cache_implementation": "hybrid",
            "dtype": "bfloat16",
            "final_logit_softcapping": None,
            "head_dim": 256,
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 3840,
            "initializer_range": 0.02,
            "intermediate_size": 15360,
            "layer_types": layer_types,
            "max_position_embeddings": 131072,
            "model_type": "gemma3_text",
            "num_attention_heads": 16,
            "num_hidden_layers": 48,
            "num_key_value_heads": 8,
            "query_pre_attn_scalar": 256,
            "rms_norm_eps": 1e-06,
            "rope_local_base_freq": 10000,
            "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
            "rope_theta": 1000000,
            "sliding_window": 1024,
            "sliding_window_pattern": 6,
            "use_bidirectional_attention": False,
            "use_cache": True,
            "vocab_size": 262208,
        }

        vision_config = {
            "attention_dropout": 0.0,
            "dtype": "bfloat16",
            "hidden_act": "gelu_pytorch_tanh",
            "hidden_size": 1152,
            "image_size": 896,
            "intermediate_size": 4304,
            "layer_norm_eps": 1e-06,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_channels": 3,
            "num_hidden_layers": 27,
            "patch_size": 14,
            "vision_use_head": False,
        }

        config = Gemma3Config(
            architectures=["Gemma3ForConditionalGeneration"],
            boi_token_index=255999,
            dtype="bfloat16",
            eoi_token_index=256000,
            eos_token_id=[1, 106],
            image_token_index=262144,
            initializer_range=0.02,
            mm_tokens_per_image=256,
            model_type="gemma3",
            text_config=text_config,
            transformers_version="4.57.3",
            vision_config=vision_config,
        )
        super().__init__(config)
|
|
|
|
| class LTXVGemmaTokenizer: |
| """ |
| Tokenizer wrapper for Gemma models compatible with LTXV processes. |
| This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders, |
| ensuring correct settings and output formatting for downstream consumption. |
| """ |
|
|
| def __init__(self, tokenizer_path: str, max_length: int = 1024): |
| """ |
| Initialize the tokenizer. |
| Args: |
| tokenizer_path (str): Path to the pretrained tokenizer files or model directory. |
| max_length (int, optional): Max sequence length for encoding. Defaults to 256. |
| """ |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| tokenizer_path, local_files_only=True, model_max_length=max_length |
| ) |
| |
| self.tokenizer.padding_side = "left" |
| if self.tokenizer.pad_token is None: |
| self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
| self.max_length = max_length |
|
|
| def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]: |
| """ |
| Tokenize the given text and return token IDs and attention weights. |
| Args: |
| text (str): The input string to tokenize. |
| return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples. |
| If False (default), omits the indices. |
| Returns: |
| dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]: |
| A dictionary with a "gemma" key mapping to: |
| - a list of (token_id, attention_mask) tuples if return_word_ids is False; |
| - a list of (token_id, attention_mask, index) tuples if return_word_ids is True. |
| Example: |
| >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8) |
| >>> tokenizer.tokenize_with_weights("hello world") |
| {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]} |
| """ |
| text = text.strip() |
| encoded = self.tokenizer( |
| text, |
| padding="max_length", |
| max_length=self.max_length, |
| truncation=True, |
| return_tensors="pt", |
| ) |
| input_ids = encoded.input_ids |
| attention_mask = encoded.attention_mask |
| tuples = [ |
| (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True)) |
| ] |
| out = {"gemma": tuples} |
|
|
| if not return_word_ids: |
| |
| out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()} |
|
|
| return out |
|
|
|
|
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
    """
    Bias-free linear projection for flattened Gemma features.

    The wrapped layer maps an input whose last dimension is 3840 * 49 to an output whose last
    dimension is 3840.

    Attributes:
        aggregate_embed (torch.nn.Linear): The projection layer (created with `bias=False`).
    """

    def __init__(self) -> None:
        """Create the 3840 * 49 -> 3840 projection layer."""
        super().__init__()
        out_features = 3840
        in_features = out_features * 49
        self.aggregate_embed = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project the flattened features.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is 3840 * 49.

        Returns:
            torch.Tensor: Tensor with the same leading dimensions and a last dimension of 3840.
        """
        projected = self.aggregate_embed(x)
        return projected
|
|
|
|
class _BasicTransformerBlock1D(torch.nn.Module):
    """One transformer block: a self-attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer receives an `rms_norm`-normalised input and its output is added back onto the
    un-normalised hidden states (residual connection).
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    ):
        """
        Args:
            dim (int): Feature width; used as the attention `query_dim` and as both the input and
                output width of the feed-forward layer.
            heads (int): Number of attention heads.
            dim_head (int): Width of each attention head.
            rope_type (LTXRopeType): Rotary-embedding variant forwarded to `Attention`.
        """
        super().__init__()
        self.attn1 = Attention(query_dim=dim, heads=heads, dim_head=dim_head, rope_type=rope_type)
        self.ff = FeedForward(dim, dim_out=dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Run the attention and feed-forward sub-layers.

        Args:
            hidden_states (torch.Tensor): Input features.
            attention_mask (torch.Tensor | None): Passed to the attention layer as `mask`.
            pe (torch.Tensor | None): Positional-embedding data passed to the attention layer.

        Returns:
            torch.Tensor: The transformed hidden states.
        """
        # --- self-attention sub-layer ---
        # `squeeze(1)` removes dim 1 only when it has size 1; otherwise it is a no-op.
        attn_input = rms_norm(hidden_states).squeeze(1)
        hidden_states = self.attn1(attn_input, mask=attention_mask, pe=pe) + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # --- feed-forward sub-layer ---
        hidden_states = self.ff(rms_norm(hidden_states)) + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
|
|
|
|
class Embeddings1DConnector(torch.nn.Module):
    """
    Embeddings1DConnector runs a stack of 1D transformer blocks over a sequence of embeddings.

    Positions are encoded with rotary embeddings computed from the token index, and padded
    positions can be replaced by learnable register vectors before the blocks run. The output is
    normalised with `rms_norm`.

    Args:
        attention_head_dim (int): Dimension of each attention head (default=128).
        num_attention_heads (int): Number of attention heads (default=30).
        num_layers (int): Number of transformer blocks (default=2).
        positional_embedding_theta (float): `theta` passed to `precompute_freqs_cis`
            (default=10000.0).
        positional_embedding_max_pos (list[int] | None): `max_pos` passed to
            `precompute_freqs_cis` (default=[4096]). `None` is replaced by `[1]`. The value is
            copied into a new list, so instances never share the default list object.
        causal_temporal_positioning (bool): Stored as an attribute; the code in this class does
            not read it (default=False).
        num_learnable_registers (int | None): Number of learnable registers used to replace padded
            tokens. `None` or `0` disables register replacement (default=128).
        rope_type (LTXRopeType): The RoPE variant to use (default=LTXRopeType.SPLIT).
        double_precision_rope (bool): If True, the frequency grid is produced by
            `generate_freq_grid_np`, otherwise by `generate_freq_grid_pytorch` (default=True).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        attention_head_dim: int = 128,
        num_attention_heads: int = 30,
        num_layers: int = 2,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = [4096],
        causal_temporal_positioning: bool = False,
        num_learnable_registers: int | None = 128,
        rope_type: LTXRopeType = LTXRopeType.SPLIT,
        double_precision_rope: bool = True,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        # Copy the argument: the default is a single list object shared by every call, so storing
        # it directly would let an in-place edit on one instance leak into all later instances.
        self.positional_embedding_max_pos = (
            list(positional_embedding_max_pos) if positional_embedding_max_pos is not None else [1]
        )
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = torch.nn.ModuleList(
            [
                _BasicTransformerBlock1D(
                    dim=self.inner_dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    rope_type=rope_type,
                )
                for _ in range(num_layers)
            ]
        )

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            # Registers start out uniformly distributed in [-1, 1).
            self.learnable_registers = torch.nn.Parameter(
                torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
            )

    def _replace_padded_with_learnable_registers(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Move the valid tokens to the front of the sequence and fill the rest with registers.

        Args:
            hidden_states (torch.Tensor): Embeddings; dim 1 is the sequence dimension and its
                length must be a multiple of `num_learnable_registers`.
            attention_mask (torch.Tensor): Additive mask; dims 1 and 2 are squeezed away, and a
                position counts as valid when its value is >= -9000.0.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The rearranged hidden states and an all-zero mask
            with the same shape, dtype and device as the input mask.

        Raises:
            ValueError: If the sequence length is not divisible by `num_learnable_registers`.
        """
        seq_len = hidden_states.shape[1]
        # Explicit raise rather than `assert`, so the check survives `python -O`.
        if seq_len % self.num_learnable_registers != 0:
            raise ValueError(
                f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
                f"{self.num_learnable_registers}."
            )

        # Repeat the register bank until it covers the whole sequence.
        num_registers_duplications = seq_len // self.num_learnable_registers
        learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
        attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()

        # NOTE(review): `.squeeze()` drops the batch dim only when it is 1, so this boolean
        # indexing appears to assume batch size 1 -- confirm against callers.
        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
        non_zero_nums = non_zero_hidden_states.shape[1]
        pad_length = seq_len - non_zero_nums
        adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
        # NOTE(review): flipping the mask marks the leading slots as "valid", which lines up with
        # the compacted tokens only if the padding was contiguous on the left -- confirm.
        flipped_mask = torch.flip(attention_mask_binary, dims=[1])
        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers

        # Every position now holds either a real token or a register, so nothing is masked out.
        attention_mask = torch.zeros_like(attention_mask)

        return hidden_states, attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Forward pass of Embeddings1DConnector.

        Args:
            hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
            attention_mask (torch.Tensor | None): Mask for valid tokens. Required when
                `num_learnable_registers` is set, because it decides which positions are padding.

        Returns:
            tuple[torch.Tensor, torch.Tensor | None]: Processed features and the corresponding
            mask (all zeros when registers were substituted, otherwise the mask that was passed in).

        Raises:
            ValueError: If registers are enabled and `attention_mask` is None, or if the sequence
                length is not divisible by `num_learnable_registers`.
        """
        if self.num_learnable_registers:
            if attention_mask is None:
                raise ValueError("attention_mask must be provided when num_learnable_registers is set.")
            hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

        # Rotary positions are the token indices 0..seq_len-1, shaped [1, 1, seq_len].
        indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
        indices_grid = indices_grid[None, None, :]
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        freqs_cis = precompute_freqs_cis(
            indices_grid=indices_grid,
            dim=self.inner_dim,
            out_dtype=hidden_states.dtype,
            theta=self.positional_embedding_theta,
            max_pos=self.positional_embedding_max_pos,
            num_attention_heads=self.num_attention_heads,
            rope_type=self.rope_type,
            freq_grid_generator=freq_grid_generator,
        )

        for block in self.transformer_1d_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)

        hidden_states = rms_norm(hidden_states)

        return hidden_states, attention_mask
|
|
|
|
class LTX2TextEncoderPostModules(torch.nn.Module):
    """Container for the modules applied after the text encoder.

    Attributes:
        feature_extractor_linear (GemmaFeaturesExtractorProjLinear): Linear feature projection.
        embeddings_connector (Embeddings1DConnector): Connector built with default settings.
        audio_embeddings_connector (Embeddings1DConnector): A second, independently
            parameterised connector built with the same default settings.
    """

    def __init__(self):
        """Instantiate the three sub-modules with their default arguments."""
        super().__init__()
        self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
        self.embeddings_connector = Embeddings1DConnector()
        self.audio_embeddings_connector = Embeddings1DConnector()
|
|