Upload 4 files
Browse files- packages/ltx-core/src/ltx_core/text_encoders/gemma/__init__.py +31 -0
- packages/ltx-core/src/ltx_core/text_encoders/gemma/embeddings_connector.py +210 -0
- packages/ltx-core/src/ltx_core/text_encoders/gemma/feature_extractor.py +36 -0
- packages/ltx-core/src/ltx_core/text_encoders/gemma/tokenizer.py +64 -0
packages/ltx-core/src/ltx_core/text_encoders/gemma/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gemma text encoder components."""
|
| 2 |
+
|
| 3 |
+
from ltx_core.text_encoders.gemma.encoders.av_encoder import (
|
| 4 |
+
AV_GEMMA_TEXT_ENCODER_KEY_OPS,
|
| 5 |
+
AVGemmaEncoderOutput,
|
| 6 |
+
AVGemmaTextEncoderModel,
|
| 7 |
+
AVGemmaTextEncoderModelConfigurator,
|
| 8 |
+
)
|
| 9 |
+
from ltx_core.text_encoders.gemma.encoders.base_encoder import (
|
| 10 |
+
GemmaTextEncoderModelBase,
|
| 11 |
+
encode_text,
|
| 12 |
+
module_ops_from_gemma_root,
|
| 13 |
+
)
|
| 14 |
+
from ltx_core.text_encoders.gemma.encoders.video_only_encoder import (
|
| 15 |
+
VideoGemmaEncoderOutput,
|
| 16 |
+
VideoGemmaTextEncoderModel,
|
| 17 |
+
VideoGemmaTextEncoderModelConfigurator,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"AV_GEMMA_TEXT_ENCODER_KEY_OPS",
|
| 22 |
+
"AVGemmaEncoderOutput",
|
| 23 |
+
"AVGemmaTextEncoderModel",
|
| 24 |
+
"AVGemmaTextEncoderModelConfigurator",
|
| 25 |
+
"GemmaTextEncoderModelBase",
|
| 26 |
+
"VideoGemmaEncoderOutput",
|
| 27 |
+
"VideoGemmaTextEncoderModel",
|
| 28 |
+
"VideoGemmaTextEncoderModelConfigurator",
|
| 29 |
+
"encode_text",
|
| 30 |
+
"module_ops_from_gemma_root",
|
| 31 |
+
]
|
packages/ltx-core/src/ltx_core/text_encoders/gemma/embeddings_connector.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 4 |
+
from ltx_core.model.transformer.attention import Attention
|
| 5 |
+
from ltx_core.model.transformer.feed_forward import FeedForward
|
| 6 |
+
from ltx_core.model.transformer.rope import (
|
| 7 |
+
LTXRopeType,
|
| 8 |
+
generate_freq_grid_np,
|
| 9 |
+
generate_freq_grid_pytorch,
|
| 10 |
+
precompute_freqs_cis,
|
| 11 |
+
)
|
| 12 |
+
from ltx_core.utils import rms_norm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class _BasicTransformerBlock1D(torch.nn.Module):
    """A single pre-norm 1D transformer block.

    Runs RMS-normalized self-attention followed by an RMS-normalized
    feed-forward, each wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    ):
        super().__init__()
        self.attn1 = Attention(
            query_dim=dim,
            heads=heads,
            dim_head=dim_head,
            rope_type=rope_type,
        )
        self.ff = FeedForward(
            dim,
            dim_out=dim,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Pre-norm self-attention with a residual connection. The squeeze(1)
        # drops a possible singleton dim before attention.
        normed = rms_norm(hidden_states).squeeze(1)
        hidden_states = self.attn1(normed, mask=attention_mask, pe=pe) + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # Pre-norm feed-forward with a residual connection.
        hidden_states = self.ff(rms_norm(hidden_states)) + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Embeddings1DConnector(torch.nn.Module):
    """
    Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
    other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
    substitute padded positions with learnable registers. The module is highly configurable for head size, number of
    layers, and register usage.
    Args:
        attention_head_dim (int): Dimension of each attention head (default=128).
        num_attention_heads (int): Number of attention heads (default=30).
        num_layers (int): Number of transformer layers (default=2).
        positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
        positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
        causal_temporal_positioning (bool): If True, uses causal attention (default=False).
        num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
            register replacement. (default=128)
        rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
        double_precision_rope (bool): Use double precision rope calculation (default=False).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        attention_head_dim: int = 128,
        num_attention_heads: int = 30,
        num_layers: int = 2,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = None,
        causal_temporal_positioning: bool = False,
        num_learnable_registers: int | None = 128,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        double_precision_rope: bool = False,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Model width is heads * head_dim, matching the attention layout below.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        # Avoid a mutable default argument; [1] is the documented fallback.
        self.positional_embedding_max_pos = (
            positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
        )
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = torch.nn.ModuleList(
            [
                _BasicTransformerBlock1D(
                    dim=self.inner_dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    rope_type=rope_type,
                )
                for _ in range(num_layers)
            ]
        )

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            # Registers are initialized uniformly in [-1, 1) in bfloat16.
            self.learnable_registers = torch.nn.Parameter(
                torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
            )

    def _replace_padded_with_learnable_registers(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Substitute padded positions with (tiled) learnable register vectors and
        # return a fully "unmasked" additive attention mask (all zeros), since every
        # position now carries meaningful content.
        #
        # NOTE(review): `forward` allows attention_mask=None but this helper
        # dereferences it unconditionally — confirm callers always pass a mask
        # when registers are enabled.
        assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
            f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
            f"{self.num_learnable_registers}."
        )

        # Tile the register bank along the sequence axis to cover the full length.
        num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
        learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
        # The mask appears to be additive (large negative = masked); values >= -9000
        # are treated as "valid". Presumably shaped (B, 1, 1, S) — TODO confirm.
        attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()

        # Gather the valid tokens, then re-pad on the right with zeros.
        # NOTE(review): the boolean index uses a single mask across the batch dim,
        # which looks batch-size-1 specific — verify for B > 1.
        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
        non_zero_nums = non_zero_hidden_states.shape[1]
        pad_length = hidden_states.shape[1] - non_zero_nums
        adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
        # Flipping the mask aligns the "valid" prefix with the compacted tokens,
        # so registers fill exactly the trailing padded slots. Assumes padding is
        # contiguous on one side — TODO confirm with the tokenizer's padding side.
        flipped_mask = torch.flip(attention_mask_binary, dims=[1])
        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers

        # All positions are now meaningful, so return an all-zero additive mask.
        attention_mask = torch.full_like(
            attention_mask,
            0.0,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )

        return hidden_states, attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of Embeddings1DConnector.
        Args:
            hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
            attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
        Returns:
            tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
        """
        if self.num_learnable_registers:
            hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

        # 1D positional grid over the sequence axis, shaped (1, 1, S) for RoPE.
        indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
        indices_grid = indices_grid[None, None, :]
        # The numpy generator computes frequencies in float64 when requested.
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        freqs_cis = precompute_freqs_cis(
            indices_grid=indices_grid,
            dim=self.inner_dim,
            out_dtype=hidden_states.dtype,
            theta=self.positional_embedding_theta,
            max_pos=self.positional_embedding_max_pos,
            num_attention_heads=self.num_attention_heads,
            rope_type=self.rope_type,
            freq_grid_generator=freq_grid_generator,
        )

        for block in self.transformer_1d_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)

        # Final normalization before handing features downstream.
        hidden_states = rms_norm(hidden_states)

        return hidden_states, attention_mask
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Embeddings1DConnectorConfigurator(ModelConfigurator[Embeddings1DConnector]):
    """Builds an :class:`Embeddings1DConnector` from a nested config dict.

    Reads the ``"transformer"`` section of the config; any missing key falls
    back to the connector's defaults.
    """

    @classmethod
    # Fix: `cls` is this configurator class, not the model type; quote the name
    # (forward reference) since the class is not yet bound while the body executes,
    # matching the style used by GemmaFeaturesExtractorProjLinear.from_config.
    def from_config(cls: type["Embeddings1DConnectorConfigurator"], config: dict) -> Embeddings1DConnector:
        """
        Create a connector from a configuration dictionary.

        Args:
            config (dict): Full model config; only the "transformer" section is used.

        Returns:
            Embeddings1DConnector: Connector configured with rope type, rope
            precision and positional-embedding max positions from the config.
        """
        config = config.get("transformer", {})
        rope_type = LTXRopeType(config.get("rope_type", "interleaved"))
        # Only an explicit "float64" enables double-precision rope frequencies.
        double_precision_rope = config.get("frequencies_precision") == "float64"
        pe_max_pos = config.get("connector_positional_embedding_max_pos", [1])

        connector = Embeddings1DConnector(
            positional_embedding_max_pos=pe_max_pos,
            rope_type=rope_type,
            double_precision_rope=double_precision_rope,
        )
        return connector
|
packages/ltx-core/src/ltx_core/text_encoders/gemma/feature_extractor.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from ltx_core.model.model_protocol import ModelConfigurator
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class GemmaFeaturesExtractorProjLinear(torch.nn.Module, ModelConfigurator["GemmaFeaturesExtractorProjLinear"]):
    """
    Feature extractor module for Gemma models.
    This module applies a single bias-free linear projection to the input tensor.
    It expects a flattened feature tensor of shape (batch_size, hidden_dim * num_tokens)
    and maps it to a (batch_size, hidden_dim) embedding. Defaults reproduce the
    original hard-coded 3840 * 49 -> 3840 projection.
    Attributes:
        aggregate_embed (torch.nn.Linear): Linear projection layer.
    """

    def __init__(self, hidden_dim: int = 3840, num_tokens: int = 49) -> None:
        """
        Initialize the GemmaFeaturesExtractorProjLinear module.
        Args:
            hidden_dim (int): Per-token feature width and output embedding size.
                Defaults to 3840.
            num_tokens (int): Number of concatenated token features in the
                flattened input. Defaults to 49.
        """
        super().__init__()
        # Generalized from the original fixed Linear(3840 * 49, 3840); defaults
        # keep the original behavior for existing callers and checkpoints.
        self.aggregate_embed = torch.nn.Linear(hidden_dim * num_tokens, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the feature extractor.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, hidden_dim * num_tokens).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, hidden_dim).
        """
        return self.aggregate_embed(x)

    @classmethod
    def from_config(cls: type["GemmaFeaturesExtractorProjLinear"], _config: dict) -> "GemmaFeaturesExtractorProjLinear":
        """Build the extractor with default dimensions; the config dict is currently unused."""
        return cls()
|
packages/ltx-core/src/ltx_core/text_encoders/gemma/tokenizer.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LTXVGemmaTokenizer:
|
| 5 |
+
"""
|
| 6 |
+
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
| 7 |
+
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
| 8 |
+
ensuring correct settings and output formatting for downstream consumption.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, tokenizer_path: str, max_length: int = 256, local_files_only: bool = True):
|
| 12 |
+
"""
|
| 13 |
+
Initialize the tokenizer.
|
| 14 |
+
Args:
|
| 15 |
+
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
| 16 |
+
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
| 17 |
+
"""
|
| 18 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 19 |
+
tokenizer_path, local_files_only=local_files_only, model_max_length=max_length
|
| 20 |
+
)
|
| 21 |
+
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
| 22 |
+
self.tokenizer.padding_side = "left"
|
| 23 |
+
if self.tokenizer.pad_token is None:
|
| 24 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 25 |
+
|
| 26 |
+
self.max_length = max_length
|
| 27 |
+
|
| 28 |
+
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
| 29 |
+
"""
|
| 30 |
+
Tokenize the given text and return token IDs and attention weights.
|
| 31 |
+
Args:
|
| 32 |
+
text (str): The input string to tokenize.
|
| 33 |
+
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
| 34 |
+
If False (default), omits the indices.
|
| 35 |
+
Returns:
|
| 36 |
+
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
| 37 |
+
A dictionary with a "gemma" key mapping to:
|
| 38 |
+
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
| 39 |
+
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
| 40 |
+
Example:
|
| 41 |
+
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
| 42 |
+
>>> tokenizer.tokenize_with_weights("hello world")
|
| 43 |
+
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
| 44 |
+
"""
|
| 45 |
+
text = text.strip()
|
| 46 |
+
encoded = self.tokenizer(
|
| 47 |
+
text,
|
| 48 |
+
padding="max_length",
|
| 49 |
+
max_length=self.max_length,
|
| 50 |
+
truncation=True,
|
| 51 |
+
return_tensors="pt",
|
| 52 |
+
)
|
| 53 |
+
input_ids = encoded.input_ids
|
| 54 |
+
attention_mask = encoded.attention_mask
|
| 55 |
+
tuples = [
|
| 56 |
+
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
| 57 |
+
]
|
| 58 |
+
out = {"gemma": tuples}
|
| 59 |
+
|
| 60 |
+
if not return_word_ids:
|
| 61 |
+
# Return only (token_id, attention_mask) pairs, omitting token position
|
| 62 |
+
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
| 63 |
+
|
| 64 |
+
return out
|