Fabrice-TIERCELIN committed on
Commit
da4932a
·
verified ·
1 Parent(s): d53609a

Upload 3 files

Browse files
packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/av_encoder.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+
3
+ import torch
4
+ from transformers.models.gemma3 import Gemma3ForConditionalGeneration
5
+
6
+ from ltx_core.loader.sd_ops import SDOps
7
+ from ltx_core.model.model_protocol import ModelConfigurator
8
+ from ltx_core.text_encoders.gemma.embeddings_connector import (
9
+ Embeddings1DConnector,
10
+ Embeddings1DConnectorConfigurator,
11
+ )
12
+ from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoderModelBase
13
+ from ltx_core.text_encoders.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear
14
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
15
+
16
+
17
# Result bundle produced by AVGemmaTextEncoderModel.forward: the video and
# audio conditioning tensors plus the token attention mask.
AVGemmaEncoderOutput = NamedTuple(
    "AVGemmaEncoderOutput",
    [
        ("video_encoding", torch.Tensor),
        ("audio_encoding", torch.Tensor),
        ("attention_mask", torch.Tensor),
    ],
)
21
+
22
+
23
class AVGemmaTextEncoderModel(GemmaTextEncoderModelBase):
    """
    AVGemma Text Encoder Model.
    This class combines the tokenizer, Gemma model, feature extractor from base class and a
    video and audio embeddings connectors to provide a preprocessing for audio-visual pipeline.

    Args:
        feature_extractor_linear (GemmaFeaturesExtractorProjLinear): Linear projection for hidden states.
        embeddings_connector (Embeddings1DConnector): Connector producing the video conditioning.
        audio_embeddings_connector (Embeddings1DConnector): Connector producing the audio conditioning.
        tokenizer (LTXVGemmaTokenizer | None): Optional tokenizer; may be loaded later by loader ops.
        model (Gemma3ForConditionalGeneration | None): Optional Gemma LLM; may be loaded later.
        dtype (torch.dtype): Data type the connectors and projection are cast to.
    """

    def __init__(
        self,
        feature_extractor_linear: GemmaFeaturesExtractorProjLinear,
        embeddings_connector: Embeddings1DConnector,
        audio_embeddings_connector: Embeddings1DConnector,
        tokenizer: LTXVGemmaTokenizer | None = None,
        model: Gemma3ForConditionalGeneration | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__(
            feature_extractor_linear=feature_extractor_linear,
            tokenizer=tokenizer,
            model=model,
            dtype=dtype,
        )
        # Both connectors run in the same dtype as the rest of the encoder stack.
        self.embeddings_connector = embeddings_connector.to(dtype=dtype)
        self.audio_embeddings_connector = audio_embeddings_connector.to(dtype=dtype)

    def _run_connectors(
        self, encoded_input: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the video and audio connectors over the projected text features.

        Args:
            encoded_input: Projected text features from `_preprocess_text`.
            attention_mask: {0, 1} token mask for `encoded_input`.

        Returns:
            (video_encoding, audio_encoding, mask) where mask is int64 of shape [B, T].
        """
        connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)

        encoded, encoded_connector_attention_mask = self.embeddings_connector(
            encoded_input,
            connector_attention_mask,
        )

        # restore the mask values to int64
        # NOTE(review): the additive mask holds 0 for valid tokens and -finfo.max for
        # masked ones, so BOTH compare "< 0.000001" and this yields all-ones unless the
        # connector remaps masked positions to positive values — confirm the connector's
        # returned-mask convention (a ">" comparison may have been intended).
        attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
        attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
        # Zero out features at masked positions.
        encoded = encoded * attention_mask

        # The audio connector consumes the same inputs; its returned mask is discarded.
        encoded_for_audio, _ = self.audio_embeddings_connector(encoded_input, connector_attention_mask)

        return encoded, encoded_for_audio, attention_mask.squeeze(-1)

    def forward(self, text: str, padding_side: str = "left") -> AVGemmaEncoderOutput:
        """Encode `text` into video/audio conditioning tensors plus the token mask."""
        encoded_inputs, attention_mask = self._preprocess_text(text, padding_side)
        video_encoding, audio_encoding, attention_mask = self._run_connectors(encoded_inputs, attention_mask)
        return AVGemmaEncoderOutput(video_encoding, audio_encoding, attention_mask)
71
+
72
+
73
class AVGemmaTextEncoderModelConfigurator(ModelConfigurator[AVGemmaTextEncoderModel]):
    """Builds an AVGemmaTextEncoderModel (projection + two connectors) from a config dict.

    The tokenizer and Gemma model are left unset here; they are expected to be
    loaded separately.
    """

    @classmethod
    def from_config(cls: type["AVGemmaTextEncoderModelConfigurator"], config: dict) -> "AVGemmaTextEncoderModel":
        """Construct the encoder's trainable submodules from `config`."""
        feature_extractor_linear = GemmaFeaturesExtractorProjLinear.from_config(config)
        # Video and audio connectors are built from the same config but are
        # independent modules with their own weights.
        embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config)
        audio_embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config)
        return AVGemmaTextEncoderModel(
            feature_extractor_linear=feature_extractor_linear,
            embeddings_connector=embeddings_connector,
            audio_embeddings_connector=audio_embeddings_connector,
        )
84
+
85
+
86
# State-dict key ops for loading the AV text-encoder weights out of a combined
# checkpoint: keys matching the projection and the two connector prefixes are
# selected (per SDOps matching semantics) and their prefixes renamed to this
# module's attribute names (feature_extractor_linear / embeddings_connector /
# audio_embeddings_connector).
AV_GEMMA_TEXT_ENCODER_KEY_OPS = (
    SDOps("AV_GEMMA_TEXT_ENCODER_KEY_OPS")
    .with_matching(prefix="text_embedding_projection.")
    .with_matching(prefix="model.diffusion_model.audio_embeddings_connector.")
    .with_matching(prefix="model.diffusion_model.video_embeddings_connector.")
    .with_replacement("text_embedding_projection.", "feature_extractor_linear.")
    .with_replacement("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
    .with_replacement("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
)
packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from transformers import AutoImageProcessor, Gemma3ForConditionalGeneration, Gemma3Processor
7
+
8
+ from ltx_core.loader.module_ops import ModuleOps
9
+ from ltx_core.text_encoders.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear
10
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
11
+
12
+
13
class GemmaTextEncoderModelBase(torch.nn.Module):
    """
    Gemma Text Encoder Model.
    This base class combines the tokenizer, Gemma model and feature extractor to provide a preprocessing
    for implementation classes for multimodal pipelines. It processes input text through tokenization,
    obtains hidden states from the base language model, applies a linear feature extractor.
    Args:
        tokenizer (LTXVGemmaTokenizer): The tokenizer used for text preprocessing.
        model (Gemma3ForConditionalGeneration): The base Gemma LLM.
        feature_extractor_linear (GemmaFeaturesExtractorProjLinear): Linear projection for hidden state aggregation.
        img_processor (Gemma3Processor, optional): Processor for image inputs; built lazily if omitted.
        dtype (torch.dtype, optional): The data type for model parameters (default: torch.bfloat16).
    """

    def __init__(
        self,
        feature_extractor_linear: GemmaFeaturesExtractorProjLinear,
        tokenizer: LTXVGemmaTokenizer | None = None,
        model: Gemma3ForConditionalGeneration | None = None,
        img_processor: Gemma3Processor | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        # Checkpoint root / hub id the Gemma weights came from; set lazily by the
        # loader ops (see module_ops_from_gemma_root) and read when building the
        # image processor.
        self._gemma_root: str | None = None
        self.tokenizer = tokenizer
        self.model = model
        self.processor = img_processor
        # The projection always runs in the requested dtype.
        self.feature_extractor_linear = feature_extractor_linear.to(dtype=dtype)

    def _run_feature_extractor(
        self, hidden_states: tuple[torch.Tensor, ...], attention_mask: torch.Tensor, padding_side: str = "right"
    ) -> torch.Tensor:
        """Aggregate the model's per-layer hidden states and project them.

        Args:
            hidden_states: Per-layer hidden states (as produced with
                output_hidden_states=True); stacked along a trailing layer dim.
            attention_mask: Binary mask whose row sums give valid-token counts.
            padding_side: Side where padding tokens sit ("left" or "right").

        Returns:
            Output of the linear feature extractor over the normalized,
            layer-concatenated features.
        """
        encoded_text_features = torch.stack(hidden_states, dim=-1)
        encoded_text_features_dtype = encoded_text_features.dtype

        sequence_lengths = attention_mask.sum(dim=-1)
        normed_concated_encoded_text_features = _norm_and_concat_padded_batch(
            encoded_text_features, sequence_lengths, padding_side=padding_side
        )

        # Cast back to the original hidden-state dtype before the projection.
        return self.feature_extractor_linear(normed_concated_encoded_text_features.to(encoded_text_features_dtype))

    def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Turn a {0, 1} attention mask into an additive mask of shape
        [B, 1, -1, T]: 0 at valid positions, -finfo(dtype).max at masked ones."""
        return (attention_mask - 1).to(dtype).reshape(
            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
        ) * torch.finfo(dtype).max

    def _preprocess_text(self, text: str, padding_side: str = "left") -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a given string into feature tensors suitable for downstream tasks.
        Args:
            text (str): Input string to encode.
            padding_side (str): Side the tokenizer padded on ("left" or "right").
        Returns:
            tuple[torch.Tensor, torch.Tensor]: Projected features and the [1, T] attention-mask tensor.
        """
        # NOTE(review): the second element of each tokenizer pair is used directly
        # as the attention-mask value — assumes it is 0/1 per token; confirm
        # against LTXVGemmaTokenizer.tokenize_with_weights.
        token_pairs = self.tokenizer.tokenize_with_weights(text)["gemma"]
        input_ids = torch.tensor([[t[0] for t in token_pairs]], device=self.model.device)
        attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=self.model.device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        projected = self._run_feature_extractor(
            hidden_states=outputs.hidden_states, attention_mask=attention_mask, padding_side=padding_side
        )
        return projected, attention_mask

    def _init_image_processor(self) -> None:
        """Lazily build the Gemma3 processor from the stored checkpoint root.

        Raises:
            ValueError: If the tokenizer has not been loaded yet.
        """
        img_processor = AutoImageProcessor.from_pretrained(self._gemma_root, local_files_only=True)
        if not self.tokenizer:
            raise ValueError("Tokenizer is not loaded, cannot load image processor")
        self.processor = Gemma3Processor(image_processor=img_processor, tokenizer=self.tokenizer.tokenizer)

    def _enhance(
        self,
        messages: list[dict[str, str]],
        image: torch.Tensor | None = None,
        max_new_tokens: int = 512,
        seed: int = 42,
    ) -> str:
        """Run chat-style generation with the Gemma model and return the new text.

        Args:
            messages: Chat messages in the HF chat-template format.
            image: Optional image input for multimodal prompts.
            max_new_tokens: Generation budget.
            seed: Seed applied inside a forked RNG scope so sampling is
                reproducible without disturbing global RNG state.

        Returns:
            The decoded continuation, with special tokens stripped.
        """
        if self.processor is None:
            self._init_image_processor()
        text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        model_inputs = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
        ).to(self.model.device)
        pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0
        # Align the sequence length for SDPA/Flash-Attention (see _pad_inputs_for_attention_alignment).
        model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id)

        # NOTE(review): fork_rng(devices=[...]) is documented for CUDA-like devices;
        # confirm this behaves as intended when self.model.device is "cpu".
        with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]):
            torch.manual_seed(seed)
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
            )
        # Drop the prompt tokens; keep only the newly generated suffix.
        generated_ids = outputs[0][len(model_inputs.input_ids[0]) :]
        enhanced_prompt = self.processor.tokenizer.decode(generated_ids, skip_special_tokens=True)

        return enhanced_prompt

    def enhance_t2v(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        system_prompt: str | None = None,
        seed: int = 42,
    ) -> str:
        """Enhance a text prompt for T2V generation.

        Args:
            prompt: Raw user prompt to enhance.
            max_new_tokens: Generation budget.
            system_prompt: Override for the bundled T2V system prompt.
            seed: RNG seed for reproducible sampling.
        """

        system_prompt = system_prompt or self.default_gemma_t2v_system_prompt

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user prompt: {prompt}"},
        ]

        return self._enhance(messages, max_new_tokens=max_new_tokens, seed=seed)

    def enhance_i2v(
        self,
        prompt: str,
        image: torch.Tensor,
        max_new_tokens: int = 512,
        system_prompt: str | None = None,
        seed: int = 42,
    ) -> str:
        """Enhance a text prompt for I2V generation using a reference image."""
        system_prompt = system_prompt or self.default_gemma_i2v_system_prompt
        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
                ],
            },
        ]
        return self._enhance(messages, image=image, max_new_tokens=max_new_tokens, seed=seed)

    @functools.cached_property
    def default_gemma_i2v_system_prompt(self) -> str:
        """Bundled system prompt used for I2V prompt enhancement."""
        return _load_system_prompt("gemma_i2v_system_prompt.txt")

    @functools.cached_property
    def default_gemma_t2v_system_prompt(self) -> str:
        """Bundled system prompt used for T2V prompt enhancement."""
        return _load_system_prompt("gemma_t2v_system_prompt.txt")

    def forward(self, text: str, padding_side: str = "left") -> tuple[torch.Tensor, torch.Tensor]:
        # Subclasses wire their connector(s) into forward; the base class only preprocesses.
        raise NotImplementedError("This method is not implemented for the base class")
164
+
165
+
166
+ def _norm_and_concat_padded_batch(
167
+ encoded_text: torch.Tensor,
168
+ sequence_lengths: torch.Tensor,
169
+ padding_side: str = "right",
170
+ ) -> torch.Tensor:
171
+ """Normalize and flatten multi-layer hidden states, respecting padding.
172
+ Performs per-batch, per-layer normalization using masked mean and range,
173
+ then concatenates across the layer dimension.
174
+ Args:
175
+ encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
176
+ sequence_lengths: Number of valid (non-padded) tokens per batch item.
177
+ padding_side: Whether padding is on "left" or "right".
178
+ Returns:
179
+ Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
180
+ with padded positions zeroed out.
181
+ """
182
+ b, t, d, l = encoded_text.shape # noqa: E741
183
+ device = encoded_text.device
184
+
185
+ # Build mask: [B, T, 1, 1]
186
+ token_indices = torch.arange(t, device=device)[None, :] # [1, T]
187
+
188
+ if padding_side == "right":
189
+ # For right padding, valid tokens are from 0 to sequence_length-1
190
+ mask = token_indices < sequence_lengths[:, None] # [B, T]
191
+ elif padding_side == "left":
192
+ # For left padding, valid tokens are from (T - sequence_length) to T-1
193
+ start_indices = t - sequence_lengths[:, None] # [B, 1]
194
+ mask = token_indices >= start_indices # [B, T]
195
+ else:
196
+ raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
197
+
198
+ mask = rearrange(mask, "b t -> b t 1 1")
199
+
200
+ eps = 1e-6
201
+
202
+ # Compute masked mean: [B, 1, 1, L]
203
+ masked = encoded_text.masked_fill(~mask, 0.0)
204
+ denom = (sequence_lengths * d).view(b, 1, 1, 1)
205
+ mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
206
+
207
+ # Compute masked min/max: [B, 1, 1, L]
208
+ x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
209
+ x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
210
+ range_ = x_max - x_min
211
+
212
+ # Normalize only the valid tokens
213
+ normed = 8 * (encoded_text - mean) / (range_ + eps)
214
+
215
+ # concat to be [Batch, T, D * L] - this preserves the original structure
216
+ normed = normed.reshape(b, t, -1) # [B, T, D * L]
217
+
218
+ # Apply mask to preserve original padding (set padded positions to 0)
219
+ mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
220
+ normed = normed.masked_fill(~mask_flattened, 0.0)
221
+
222
+ return normed
223
+
224
+
225
+ @functools.lru_cache(maxsize=2)
226
+ def _load_system_prompt(prompt_name: str) -> str:
227
+ with open(Path(__file__).parent / "prompts" / f"{prompt_name}", "r") as f:
228
+ return f.read()
229
+
230
+
231
+ def _find_matching_dir(root_path: str, pattern: str) -> str:
232
+ """
233
+ Recursively search for files matching a glob pattern and return the parent directory of the first match.
234
+
235
+ LT_INTERNAL_BEGIN
236
+ Handles both LT internal storage and HuggingFace directory structures for Gemma model files.
237
+ See: https://huggingface.co/google/gemma-3-12b-it-qat-q4_0-unquantized
238
+ LT_INTERNAL_END
239
+ """
240
+
241
+ matches = list(Path(root_path).rglob(pattern))
242
+ if not matches:
243
+ raise FileNotFoundError(f"No files matching pattern '{pattern}' found under {root_path}")
244
+ return str(matches[0].parent)
245
+
246
+
247
def module_ops_from_gemma_root(gemma_root: str, local_files_only: bool = True) -> tuple[ModuleOps, ...]:
    """Build lazy loader ops that fill in the Gemma model and tokenizer.

    Args:
        gemma_root: Either a local directory containing the Gemma checkpoint
            (searched recursively for weights and tokenizer files) or a
            two-segment HuggingFace hub id ("org/name").
        local_files_only: Passed through to the HF loaders to forbid downloads.

    Returns:
        (gemma_load_ops, tokenizer_load_ops): ModuleOps that, applied to a
        GemmaTextEncoderModelBase missing its model/tokenizer, load them in place.
    """
    # A hub id has exactly two path segments; anything else is treated as a
    # local directory tree to search.
    if len(gemma_root.split("/")) != 2:
        gemma_path = _find_matching_dir(gemma_root, "model*.safetensors")
        tokenizer_path = _find_matching_dir(gemma_root, "tokenizer.model")
    else:
        # Hub ID: google/gemma-3-12b-it-qat-q4_0-unquantized
        gemma_path = tokenizer_path = gemma_root

    # LT_INTERNAL_BEGIN
    # Note: We pass torch_dtype to from_pretrained here to maintain backward compatibility with older versions of
    # Transformers. This is necessary to compare results with ComfyUI, which uses an older version that raises an error
    # when dtype is passed. Current solution only logs a warning.
    # LT_INTERNAL_END
    def load_gemma(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase:
        # Mutator: load the LLM weights and remember where they came from.
        module.model = Gemma3ForConditionalGeneration.from_pretrained(
            gemma_path, local_files_only=local_files_only, torch_dtype=torch.bfloat16
        )
        module._gemma_root = module._gemma_root or gemma_root
        return module

    def load_tokenizer(module: GemmaTextEncoderModelBase) -> GemmaTextEncoderModelBase:
        # Mutator: load the tokenizer (1024 is presumably the max token length — confirm
        # against LTXVGemmaTokenizer's signature).
        module.tokenizer = LTXVGemmaTokenizer(tokenizer_path, 1024, local_files_only)
        module._gemma_root = module._gemma_root or gemma_root
        return module

    # Each op only fires on encoder modules whose corresponding slot is still empty.
    gemma_load_ops = ModuleOps(
        "GemmaLoad",
        matcher=lambda module: isinstance(module, GemmaTextEncoderModelBase) and module.model is None,
        mutator=load_gemma,
    )
    tokenizer_load_ops = ModuleOps(
        "TokenizerLoad",
        matcher=lambda module: isinstance(module, GemmaTextEncoderModelBase) and module.tokenizer is None,
        mutator=load_tokenizer,
    )
    return (gemma_load_ops, tokenizer_load_ops)
283
+
284
+
285
def encode_text(text_encoder: GemmaTextEncoderModelBase, prompts: list[str]) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """
    Encode a list of prompts using the provided Gemma text encoder.

    Args:
        text_encoder: Encoder whose call yields a 3-tuple; the third element
            (the attention mask) is discarded.
        prompts: List of prompt strings to encode, one encoder call each.

    Returns:
        List of (v_context, a_context) tuples, one per prompt, in input order.
    """
    return [(v_context, a_context) for v_context, a_context, _ in map(text_encoder, prompts)]
299
+
300
+
301
+ def _cat_with_padding(
302
+ tensor: torch.Tensor,
303
+ padding_length: int,
304
+ value: int | float,
305
+ ) -> torch.Tensor:
306
+ """Concatenate a tensor with a padding tensor of the given value."""
307
+ return torch.cat(
308
+ [
309
+ tensor,
310
+ torch.full(
311
+ (1, padding_length),
312
+ value,
313
+ dtype=tensor.dtype,
314
+ device=tensor.device,
315
+ ),
316
+ ],
317
+ dim=1,
318
+ )
319
+
320
+
321
def _pad_inputs_for_attention_alignment(
    model_inputs: dict[str, torch.Tensor],
    pad_token_id: int = 0,
    alignment: int = 8,
) -> dict[str, torch.Tensor]:
    """Pad sequence length to multiple of alignment for Flash Attention compatibility.

    Flash Attention within SDPA requires sequence lengths aligned to 8 bytes.
    This pads input_ids, attention_mask, and token_type_ids (if present) to prevent
    'p.attn_bias_ptr is not correctly aligned' errors.

    Args:
        model_inputs: Mapping with at least "input_ids" and "attention_mask"
            tensors of shape [batch, seq_len]; mutated in place and returned.
        pad_token_id: Token id used to pad "input_ids".
        alignment: Target multiple for the padded sequence length.

    Returns:
        The same mapping, right-padded when the length was unaligned.
    """
    # Use subscript access throughout so plain dicts work as the annotation
    # promises (attribute access only worked on transformers' BatchEncoding,
    # which supports both styles).
    seq_len = model_inputs["input_ids"].shape[1]
    # Round up to the next multiple of `alignment`.
    padded_len = ((seq_len + alignment - 1) // alignment) * alignment
    padding_length = padded_len - seq_len

    if padding_length > 0:
        model_inputs["input_ids"] = _cat_with_padding(model_inputs["input_ids"], padding_length, pad_token_id)

        # Padded positions must be ignored by attention.
        model_inputs["attention_mask"] = _cat_with_padding(model_inputs["attention_mask"], padding_length, 0)

        if "token_type_ids" in model_inputs and model_inputs["token_type_ids"] is not None:
            model_inputs["token_type_ids"] = _cat_with_padding(model_inputs["token_type_ids"], padding_length, 0)

    return model_inputs
packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/video_only_encoder.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+
3
+ import torch
4
+ from transformers import Gemma3ForConditionalGeneration
5
+
6
+ from ltx_core.loader.sd_ops import SDOps
7
+ from ltx_core.model.model_protocol import ModelConfigurator
8
+ from ltx_core.text_encoders.gemma.embeddings_connector import (
9
+ Embeddings1DConnector,
10
+ Embeddings1DConnectorConfigurator,
11
+ )
12
+ from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoderModelBase
13
+ from ltx_core.text_encoders.gemma.feature_extractor import GemmaFeaturesExtractorProjLinear
14
+ from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer
15
+
16
+
17
# Result bundle produced by VideoGemmaTextEncoderModel.forward: the video
# conditioning tensor plus the token attention mask.
VideoGemmaEncoderOutput = NamedTuple(
    "VideoGemmaEncoderOutput",
    [
        ("video_encoding", torch.Tensor),
        ("attention_mask", torch.Tensor),
    ],
)
20
+
21
+
22
class VideoGemmaTextEncoderModel(GemmaTextEncoderModelBase):
    """
    Video Gemma Text Encoder Model.
    This class combines the tokenizer, Gemma model, feature extractor from base class and a
    video embeddings connector to provide a preprocessing for video only pipeline.

    Args:
        feature_extractor_linear (GemmaFeaturesExtractorProjLinear): Linear projection for hidden states.
        embeddings_connector (Embeddings1DConnector): Connector producing the video conditioning.
        tokenizer (LTXVGemmaTokenizer | None): Optional tokenizer; may be loaded later by loader ops.
        model (Gemma3ForConditionalGeneration | None): Optional Gemma LLM; may be loaded later.
        dtype (torch.dtype): Data type the connector and projection are cast to.
    """

    def __init__(
        self,
        feature_extractor_linear: GemmaFeaturesExtractorProjLinear,
        embeddings_connector: Embeddings1DConnector,
        tokenizer: LTXVGemmaTokenizer | None = None,
        model: Gemma3ForConditionalGeneration | None = None,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__(
            feature_extractor_linear=feature_extractor_linear,
            tokenizer=tokenizer,
            model=model,
            dtype=dtype,
        )
        # The connector runs in the same dtype as the rest of the encoder stack.
        self.embeddings_connector = embeddings_connector.to(dtype=dtype)

    def _run_connector(
        self, encoded_input: torch.Tensor, attention_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the video connector over the projected text features.

        Args:
            encoded_input: Projected text features from `_preprocess_text`.
            attention_mask: {0, 1} token mask for `encoded_input`.

        Returns:
            (video_encoding, mask) where mask is int64 of shape [B, T].
        """
        connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)

        encoded, encoded_connector_attention_mask = self.embeddings_connector(
            encoded_input,
            connector_attention_mask,
        )

        # restore the mask values to int64
        # NOTE(review): the additive mask holds 0 for valid tokens and -finfo.max for
        # masked ones, so BOTH compare "< 0.000001" and this yields all-ones unless the
        # connector remaps masked positions to positive values — confirm the connector's
        # returned-mask convention (a ">" comparison may have been intended).
        attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
        attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
        # Zero out features at masked positions.
        encoded = encoded * attention_mask

        return encoded, attention_mask.squeeze(-1)

    def forward(self, text: str, padding_side: str = "left") -> VideoGemmaEncoderOutput:
        """Encode `text` into a video conditioning tensor plus the token mask."""
        encoded_inputs, attention_mask = self._preprocess_text(text, padding_side)
        video_encoding, attention_mask = self._run_connector(encoded_inputs, attention_mask)
        return VideoGemmaEncoderOutput(video_encoding, attention_mask)
66
+
67
+
68
class VideoGemmaTextEncoderModelConfigurator(ModelConfigurator[VideoGemmaTextEncoderModel]):
    """Builds a VideoGemmaTextEncoderModel (projection + video connector) from a config dict.

    The tokenizer and Gemma model are left unset here; they are expected to be
    loaded separately.
    """

    @classmethod
    def from_config(cls: type["VideoGemmaTextEncoderModelConfigurator"], config: dict) -> "VideoGemmaTextEncoderModel":
        """Construct the encoder's trainable submodules from `config`."""
        feature_extractor_linear = GemmaFeaturesExtractorProjLinear.from_config(config)
        embeddings_connector = Embeddings1DConnectorConfigurator.from_config(config)
        return VideoGemmaTextEncoderModel(
            feature_extractor_linear=feature_extractor_linear,
            embeddings_connector=embeddings_connector,
        )
77
+
78
+
79
# State-dict key ops for loading the video-only text-encoder weights out of a
# combined checkpoint: keys matching the projection and connector prefixes are
# selected (per SDOps matching semantics) and their prefixes renamed to this
# module's attribute names (feature_extractor_linear / embeddings_connector).
VIDEO_ONLY_GEMMA_TEXT_ENCODER_KEY_OPS = (
    SDOps("VIDEO_ONLY_GEMMA_TEXT_ENCODER_KEY_OPS")
    .with_matching(prefix="text_embedding_projection.")
    .with_matching(prefix="model.diffusion_model.embeddings_connector.")
    .with_replacement("text_embedding_projection.", "feature_extractor_linear.")
    .with_replacement("model.diffusion_model.embeddings_connector.", "embeddings_connector.")
)