Fabrice-TIERCELIN committed on
Commit
0130909
·
verified ·
1 Parent(s): da4932a

Upload 4 files

Browse files
packages/ltx-core/src/ltx_core/text_encoders/gemma/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gemma text encoder components."""
2
+
3
+ from ltx_core.text_encoders.gemma.encoders.av_encoder import (
4
+ AV_GEMMA_TEXT_ENCODER_KEY_OPS,
5
+ AVGemmaEncoderOutput,
6
+ AVGemmaTextEncoderModel,
7
+ AVGemmaTextEncoderModelConfigurator,
8
+ )
9
+ from ltx_core.text_encoders.gemma.encoders.base_encoder import (
10
+ GemmaTextEncoderModelBase,
11
+ encode_text,
12
+ module_ops_from_gemma_root,
13
+ )
14
+ from ltx_core.text_encoders.gemma.encoders.video_only_encoder import (
15
+ VideoGemmaEncoderOutput,
16
+ VideoGemmaTextEncoderModel,
17
+ VideoGemmaTextEncoderModelConfigurator,
18
+ )
19
+
20
+ __all__ = [
21
+ "AV_GEMMA_TEXT_ENCODER_KEY_OPS",
22
+ "AVGemmaEncoderOutput",
23
+ "AVGemmaTextEncoderModel",
24
+ "AVGemmaTextEncoderModelConfigurator",
25
+ "GemmaTextEncoderModelBase",
26
+ "VideoGemmaEncoderOutput",
27
+ "VideoGemmaTextEncoderModel",
28
+ "VideoGemmaTextEncoderModelConfigurator",
29
+ "encode_text",
30
+ "module_ops_from_gemma_root",
31
+ ]
packages/ltx-core/src/ltx_core/text_encoders/gemma/embeddings_connector.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator
4
+ from ltx_core.model.transformer.attention import Attention
5
+ from ltx_core.model.transformer.feed_forward import FeedForward
6
+ from ltx_core.model.transformer.rope import (
7
+ LTXRopeType,
8
+ generate_freq_grid_np,
9
+ generate_freq_grid_pytorch,
10
+ precompute_freqs_cis,
11
+ )
12
+ from ltx_core.utils import rms_norm
13
+
14
+
15
class _BasicTransformerBlock1D(torch.nn.Module):
    """A minimal pre-norm 1D transformer block: self-attention followed by a feed-forward layer.

    Each sub-layer is wrapped as ``x = sublayer(rms_norm(x)) + x`` (residual with
    normalization applied *before* the computation).

    Args:
        dim: Token embedding dimension.
        heads: Number of attention heads.
        dim_head: Dimension of each attention head.
        rope_type: Rotary position embedding variant to use (default: interleaved).
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    ):
        super().__init__()

        self.attn1 = Attention(
            query_dim=dim,
            heads=heads,
            dim_head=dim_head,
            rope_type=rope_type,
        )

        self.ff = FeedForward(
            dim,
            dim_out=dim,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        pe: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run pre-norm self-attention and feed-forward sub-layers with residual connections.

        Args:
            hidden_states: Input token embeddings.
            attention_mask: Optional attention mask forwarded to self-attention.
            pe: Optional precomputed rotary position embeddings.

        Returns:
            The transformed hidden states.
        """
        # --- Self-attention sub-layer (normalize, attend, add residual) ---
        normed = rms_norm(hidden_states).squeeze(1)
        hidden_states = self.attn1(normed, mask=attention_mask, pe=pe) + hidden_states
        if hidden_states.ndim == 4:
            # Drop the singleton dim reintroduced by broadcasting against the residual.
            hidden_states = hidden_states.squeeze(1)

        # --- Feed-forward sub-layer (normalize, project, add residual) ---
        hidden_states = self.ff(rms_norm(hidden_states)) + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
68
+
69
+
70
class Embeddings1DConnector(torch.nn.Module):
    """
    Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
    other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
    substitute padded positions with learnable registers. The module is highly configurable for head size, number of
    layers, and register usage.
    Args:
        attention_head_dim (int): Dimension of each attention head (default=128).
        num_attention_heads (int): Number of attention heads (default=30).
        num_layers (int): Number of transformer layers (default=2).
        positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
        positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
        causal_temporal_positioning (bool): If True, uses causal attention (default=False).
        num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
            register replacement. (default=128)
        rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
        double_precision_rope (bool): Use double precision rope calculation (default=False).
    """

    # Flag consumed by gradient-checkpointing-aware wrappers.
    _supports_gradient_checkpointing = True

    def __init__(
        self,
        attention_head_dim: int = 128,
        num_attention_heads: int = 30,
        num_layers: int = 2,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list[int] | None = None,
        causal_temporal_positioning: bool = False,
        num_learnable_registers: int | None = 128,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        double_precision_rope: bool = False,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        # Model width is heads * head_dim; all blocks operate at this dimension.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.causal_temporal_positioning = causal_temporal_positioning
        self.positional_embedding_theta = positional_embedding_theta
        # Avoid a mutable default argument; [1] is the documented default.
        self.positional_embedding_max_pos = (
            positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
        )
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope
        self.transformer_1d_blocks = torch.nn.ModuleList(
            [
                _BasicTransformerBlock1D(
                    dim=self.inner_dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    rope_type=rope_type,
                )
                for _ in range(num_layers)
            ]
        )

        self.num_learnable_registers = num_learnable_registers
        if self.num_learnable_registers:
            # Registers initialized uniformly in [-1, 1); stored in bfloat16.
            self.learnable_registers = torch.nn.Parameter(
                torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
            )

    def _replace_padded_with_learnable_registers(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Replaces padded (masked-out) positions with tiled learnable register
        # embeddings and returns an all-zeros (fully open) additive mask, since
        # after substitution every position carries meaningful content.
        assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
            f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
            f"{self.num_learnable_registers}."
        )

        # Tile registers along the sequence axis to match seq_len exactly.
        num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
        learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
        # Additive-mask convention: large negative values mark padding, so
        # >= -9000.0 identifies valid positions. TODO(review): confirm this
        # threshold matches the attention implementation's mask convention.
        attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()

        # NOTE(review): boolean indexing with a fully squeezed mask appears to
        # assume batch size 1 (one shared mask row) — verify against callers.
        non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
        non_zero_nums = non_zero_hidden_states.shape[1]
        pad_length = hidden_states.shape[1] - non_zero_nums
        # Left-compact the valid tokens, then zero-pad back to full length.
        adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
        # Flipping the mask aligns it with the compacted layout (valid tokens
        # first, padding last), selecting registers for the padded tail.
        flipped_mask = torch.flip(attention_mask_binary, dims=[1])
        hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers

        # All positions are now meaningful: emit a fully open additive mask (all zeros).
        attention_mask = torch.full_like(
            attention_mask,
            0.0,
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )

        return hidden_states, attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of Embeddings1DConnector.
        Args:
            hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
            attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
                NOTE(review): when ``num_learnable_registers`` is set, register replacement dereferences the mask,
                so it must not be None in that configuration — confirm with callers.
        Returns:
            tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
        """
        if self.num_learnable_registers:
            hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)

        # 1D positions 0..seq_len-1, shaped for the RoPE precomputation helper.
        indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
        indices_grid = indices_grid[None, None, :]
        # Optionally compute frequencies in float64 (NumPy path) for precision.
        freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
        freqs_cis = precompute_freqs_cis(
            indices_grid=indices_grid,
            dim=self.inner_dim,
            out_dtype=hidden_states.dtype,
            theta=self.positional_embedding_theta,
            max_pos=self.positional_embedding_max_pos,
            num_attention_heads=self.num_attention_heads,
            rope_type=self.rope_type,
            freq_grid_generator=freq_grid_generator,
        )

        for block in self.transformer_1d_blocks:
            hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)

        # Final normalization of the processed sequence.
        hidden_states = rms_norm(hidden_states)

        return hidden_states, attention_mask
195
+
196
+
197
class Embeddings1DConnectorConfigurator(ModelConfigurator[Embeddings1DConnector]):
    """Builds an :class:`Embeddings1DConnector` from a nested model config dict."""

    @classmethod
    def from_config(cls, config: dict) -> Embeddings1DConnector:
        """Construct a connector from the ``transformer`` section of *config*.

        Recognized keys (all optional, under ``config["transformer"]``):
            rope_type: RoPE variant name (default ``"interleaved"``).
            frequencies_precision: ``"float64"`` enables double-precision RoPE.
            connector_positional_embedding_max_pos: Max positions list (default ``[1]``).

        Args:
            config: Full model configuration dictionary.

        Returns:
            A configured Embeddings1DConnector instance.
        """
        # Fix: the original annotated ``cls`` as type[Embeddings1DConnector],
        # but ``cls`` is the configurator class itself; also avoid shadowing
        # the ``config`` parameter with the nested sub-dict.
        transformer_config = config.get("transformer", {})
        rope_type = LTXRopeType(transformer_config.get("rope_type", "interleaved"))
        # Any value other than the string "float64" (including the absent-key
        # default) leaves double-precision RoPE disabled.
        double_precision_rope = transformer_config.get("frequencies_precision", False) == "float64"
        pe_max_pos = transformer_config.get("connector_positional_embedding_max_pos", [1])

        return Embeddings1DConnector(
            positional_embedding_max_pos=pe_max_pos,
            rope_type=rope_type,
            double_precision_rope=double_precision_rope,
        )
packages/ltx-core/src/ltx_core/text_encoders/gemma/feature_extractor.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ltx_core.model.model_protocol import ModelConfigurator
4
+
5
+
6
class GemmaFeaturesExtractorProjLinear(torch.nn.Module, ModelConfigurator["GemmaFeaturesExtractorProjLinear"]):
    """
    Feature extractor module for Gemma models.
    This module applies a single linear projection to the input tensor.
    By default it expects a flattened feature tensor of shape (batch_size, 3840*49)
    and maps it to a (batch_size, 3840) embedding.
    Attributes:
        aggregate_embed (torch.nn.Linear): Linear projection layer (no bias).
    """

    def __init__(self, in_features: int = 3840 * 49, out_features: int = 3840) -> None:
        """
        Initialize the GemmaFeaturesExtractorProjLinear module.

        The dimensions are parameterized (previously hard-coded) so the module
        can be reused with other feature layouts; the defaults preserve the
        original 3840*49 -> 3840 behavior.

        Args:
            in_features (int): Flattened input feature dimension. Defaults to 3840 * 49.
            out_features (int): Output embedding dimension. Defaults to 3840.
        """
        super().__init__()
        self.aggregate_embed = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the feature extractor.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        return self.aggregate_embed(x)

    @classmethod
    def from_config(cls: type["GemmaFeaturesExtractorProjLinear"], _config: dict) -> "GemmaFeaturesExtractorProjLinear":
        # Config is intentionally ignored; the default projection sizes are used.
        return cls()
packages/ltx-core/src/ltx_core/text_encoders/gemma/tokenizer.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer
2
+
3
+
4
class LTXVGemmaTokenizer:
    """
    Tokenizer wrapper for Gemma models compatible with LTXV processes.
    This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
    ensuring correct settings and output formatting for downstream consumption.
    """

    def __init__(self, tokenizer_path: str, max_length: int = 256, local_files_only: bool = True):
        """
        Initialize the tokenizer.
        Args:
            tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
            max_length (int, optional): Max sequence length for encoding. Defaults to 256.
            local_files_only (bool, optional): If True, load the tokenizer only from local files
                without attempting any network download. Defaults to True.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, local_files_only=local_files_only, model_max_length=max_length
        )
        # Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.max_length = max_length

    def tokenize_with_weights(
        self, text: str, return_word_ids: bool = False
    ) -> dict[str, list[tuple[int, int]]] | dict[str, list[tuple[int, int, int]]]:
        """
        Tokenize the given text and return token IDs and attention weights.
        Args:
            text (str): The input string to tokenize.
            return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
                If False (default), omits the indices.
        Returns:
            dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
                A dictionary with a "gemma" key mapping to:
                - a list of (token_id, attention_mask) tuples if return_word_ids is False;
                - a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
        Example:
            >>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
            >>> tokenizer.tokenize_with_weights("hello world")
            {'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
        """
        text = text.strip()
        encoded = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # Fix: with return_tensors="pt", iterating the tensors yields 0-dim
        # tensor scalars, not the plain ints this method documents (see the
        # Example above). Convert to Python ints to honor the contract.
        input_ids = encoded.input_ids[0].tolist()
        attention_mask = encoded.attention_mask[0].tolist()
        tuples = [
            (token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids, attention_mask, strict=True))
        ]
        out = {"gemma": tuples}

        if not return_word_ids:
            # Return only (token_id, attention_mask) pairs, omitting token position
            out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}

        return out