| import logging |
| from glob import glob |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import comfy.model_management |
| import comfy.sd |
| import comfy.supported_models_base |
| import folder_paths |
| import torch |
| from PIL import Image |
| from transformers import ( |
| AutoImageProcessor, |
| AutoTokenizer, |
| Gemma3Config, |
| Gemma3ForConditionalGeneration, |
| Gemma3Processor, |
| ) |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS |
| from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES |
|
|
| from .nodes_registry import comfy_node |
| from .text_embeddings_connectors import load_text_embeddings_pipeline |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
| def _load_system_prompt(filename: str) -> str: |
| """Load system prompt from file at module level.""" |
| try: |
| prompt_path = Path(__file__).parent / "system_prompts" / filename |
| if prompt_path.exists(): |
| return prompt_path.read_text(encoding="utf-8").strip() |
| except Exception as e: |
| logger.warning(f"Could not load {filename}: {e}") |
| return "" |
|
|
|
|
# Default system prompts, read once at import time. Each is an empty string
# when its file is missing or unreadable (see _load_system_prompt).
DEFAULT_T2V_SYSTEM_PROMPT = _load_system_prompt("gemma_t2v_system_prompt.txt")
DEFAULT_I2V_SYSTEM_PROMPT = _load_system_prompt("gemma_i2v_system_prompt.txt")
|
|
|
|
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI image tensor to a PIL image.

    Args:
        tensor: Image tensor. When it is 4-dimensional only the first entry
            along dim 0 is converted. Values are scaled by 255, so they are
            expected to lie in [0, 1].

    Returns:
        The image built from the uint8 array via ``Image.fromarray``.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    # Cast to float32 first: bfloat16 tensors cannot be converted to numpy.
    # Clamp before scaling so out-of-range values saturate instead of
    # overflowing in the uint8 cast.
    scaled = tensor.detach().cpu().float().clamp(0.0, 1.0) * 255
    return Image.fromarray(scaled.numpy().astype("uint8"))
|
|
|
|
class LTXVGemmaTokenizer:
    """Wrap a Hugging Face tokenizer and emit token tuples under the ``"gemma"`` key."""

    def __init__(self, tokenizer_path: str, max_length: int = 1024):
        """Load the tokenizer from local files only and configure padding.

        Args:
            tokenizer_path: Directory passed to ``AutoTokenizer.from_pretrained``.
            max_length: Length every prompt is padded or truncated to.
        """
        hf_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, local_files_only=True, model_max_length=max_length
        )
        hf_tokenizer.padding_side = "left"
        # Fall back to EOS when the tokenizer defines no pad token.
        if hf_tokenizer.pad_token is None:
            hf_tokenizer.pad_token = hf_tokenizer.eos_token

        self.tokenizer = hf_tokenizer
        self.max_length = max_length

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False):
        """Tokenize ``text`` padded/truncated to ``self.max_length``.

        Args:
            text: Prompt; surrounding whitespace is stripped before tokenizing.
            return_word_ids: When true, each tuple also carries the token position.

        Returns:
            ``{"gemma": [...]}`` where each entry is ``(token_id, attention_mask)``
            or ``(token_id, attention_mask, position)``. The attention-mask value
            occupies the slot named "weight" in the method name.
        """
        batch = self.tokenizer(
            text.strip(),
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        ids_row = batch.input_ids[0]
        mask_row = batch.attention_mask[0]

        if return_word_ids:
            entries = [
                (token_id, attn, position)
                for position, (token_id, attn) in enumerate(zip(ids_row, mask_row))
            ]
        else:
            entries = list(zip(ids_row, mask_row))
        return {"gemma": entries}
|
|
|
|
class LTXVGemmaTextEncoderModel(torch.nn.Module):
    """Text encoder combining a Gemma 3 model with a feature extractor and an embeddings processor."""

    def __init__(
        self,
        model: Gemma3ForConditionalGeneration,
        feature_extractor,
        embeddings_processor,
        processor: Gemma3Processor | None = None,
        dtype=torch.bfloat16,
        device="cpu",
    ):
        """Store the sub-modules and pre-compute the memory estimate.

        Args:
            model: Gemma 3 model whose hidden states are used as features.
            feature_extractor: Module called on the stacked hidden states; cast to ``dtype``.
            embeddings_processor: Object providing ``create_embeddings``; cast to ``dtype``.
            processor: Optional Gemma 3 processor, stored as ``self.processor``.
            dtype: Dtype for the feature extractor and embeddings processor.
            device: Accepted but not used by this constructor.
        """
        super().__init__()
        self.model = model
        self.processor = processor
        self.feature_extractor = feature_extractor.to(dtype=dtype)
        self.embeddings_processor = embeddings_processor.to(dtype=dtype)
        self.dtypes = {dtype}

        # Size of the Gemma model plus a fixed 256 MiB.
        extra_bytes = 256 * 1024 * 1024
        self._model_memory_required = (
            comfy.model_management.module_size(self.model) + extra_bytes
        )

    def set_clip_options(self, options):
        """Accept clip options; intentionally a no-op."""
        pass

    def reset_clip_options(self):
        """Reset clip options; intentionally a no-op."""
        pass

    def forward(self, input_ids, attention_mask, padding_side="right"):
        """Run the Gemma model and feed all hidden states to the feature extractor.

        The per-layer hidden states are stacked along a new last dimension
        before being passed on together with the mask and padding side.
        """
        model_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        stacked_hiddens = torch.stack(model_out.hidden_states, dim=-1)
        return self.feature_extractor(stacked_hiddens, attention_mask, padding_side)

    def encode_token_weights(self, token_weight_pairs):
        """Encode the tuples produced by the tokenizer under the ``"gemma"`` key.

        Returns:
            ``(encoded, None, {"attention_mask": mask})`` where ``encoded`` and
            ``mask`` come from ``embeddings_processor.create_embeddings``.
        """
        pairs = token_weight_pairs["gemma"]
        device = self.model.device
        # Index 0 of each tuple is the token id, index 1 the attention-mask value.
        input_ids = torch.tensor([[pair[0] for pair in pairs]], device=device)
        attention_mask = torch.tensor([[pair[1] for pair in pairs]], device=device)

        # Move the remaining sub-modules to the device the Gemma model lives on.
        self.to(device)

        features = self(input_ids, attention_mask, padding_side="left")

        # Additive mask: 0 where the mask is 1, -finfo.max where the mask is 0.
        feature_dtype = next(iter(features.values())).dtype
        batch, seq_len = attention_mask.shape[0], attention_mask.shape[-1]
        additive = (attention_mask - 1).to(feature_dtype)
        additive = additive.reshape((batch, 1, -1, seq_len))
        connector_attention_mask = additive * torch.finfo(feature_dtype).max

        encoded, mask = self.embeddings_processor.create_embeddings(
            features, connector_attention_mask
        )
        return encoded, None, {"attention_mask": mask}

    def load_sd(self, sd):
        """Load ``sd`` into the Gemma model non-strictly and return the result."""
        return self.model.load_state_dict(sd, strict=False)

    def memory_required(self, input_shape):
        """Return the estimate computed in ``__init__``; ``input_shape`` is ignored."""
        return self._model_memory_required
|
|
|
|
def ltxv_gemma_tokenizer(tokenizer_path, max_length=256):
    """Return an ``LTXVGemmaTokenizer`` subclass bound to the given path and length.

    The returned class takes ``embedding_directory`` and ``tokenizer_data`` in
    its constructor, ignores both, and always builds the tokenizer from
    ``tokenizer_path`` with ``max_length``.
    """

    class _LTXVGemmaTokenizer(LTXVGemmaTokenizer):
        def __init__(self, embedding_directory=None, tokenizer_data={}):
            # Both arguments are accepted for the caller's sake and unused here.
            LTXVGemmaTokenizer.__init__(
                self, tokenizer_path, max_length=max_length
            )

    return _LTXVGemmaTokenizer
|
|
|
|
def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None):
    """Return an ``LTXVGemmaTextEncoderModel`` subclass bound to the given paths.

    Args:
        encoder_path: Directory holding the Gemma weights. It is combined with
            ``/``, so it must be a ``Path``-like object.
        ltxv_path: Checkpoint path handed to ``load_text_embeddings_pipeline``.
        processor: Optional Gemma 3 processor forwarded to the encoder.
        dtype: Only used as the default of the generated constructor's
            ``dtype`` parameter; the constructor always loads in bfloat16.
    """

    class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel):
        def __init__(self, device="cpu", dtype=dtype, model_options={}):
            # The requested dtype is ignored: everything is loaded in bfloat16.
            forced_dtype = torch.bfloat16

            gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
                encoder_path,
                local_files_only=True,
                torch_dtype=forced_dtype,
            )

            pipeline = load_text_embeddings_pipeline(
                ltxv_path,
                dtype=forced_dtype,
                fallback_proj_path=encoder_path / "proj_linear.safetensors",
            )
            feature_extractor, embeddings_processor = pipeline

            super().__init__(
                model=gemma_model,
                feature_extractor=feature_extractor,
                embeddings_processor=embeddings_processor,
                processor=processor,
                dtype=forced_dtype,
                device=device,
            )

    return _LTXVGemmaTextEncoderModel
|
|
|
|
def find_matching_dir(root_path: str, pattern: str) -> str:
    """Return the parent directory of the first file under ``root_path`` matching ``pattern``.

    Args:
        root_path: Directory searched recursively (a ``str`` or ``Path``).
        pattern: Glob pattern checked with ``Path.match``.

    Returns:
        The parent directory of the first matching file, as a string.

    Raises:
        FileNotFoundError: If no file under ``root_path`` matches ``pattern``.
    """
    # Only real files count: a directory whose name happens to match must not
    # be returned. Sorting makes "first" deterministic instead of depending on
    # the order in which the filesystem enumerates entries.
    candidates = sorted(
        path
        for path in map(Path, glob(f"{root_path}/**", recursive=True))
        if path.match(pattern) and path.is_file()
    )
    if not candidates:
        raise FileNotFoundError(
            f"No files matching pattern '{pattern}' found under {root_path}"
        )
    return str(candidates[0].parent)
|
|
|
|
@comfy_node(name="LTXVGemmaCLIPModelLoader", description="Gemma 3 Model Loader")
class LTXVGemmaCLIPModelLoader:
    """Node that builds a CLIP object backed by a Gemma 3 text encoder."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: encoder file, LTXV checkpoint and max token length."""
        return {
            "required": {
                "gemma_path": (
                    folder_paths.get_filename_list("text_encoders"),
                    {"tooltip": "The name of the text encoder model to load."},
                ),
                "ltxv_path": (
                    folder_paths.get_filename_list("checkpoints"),
                    {"tooltip": "The name of the ltxv model to load."},
                ),
                "max_length": (
                    "INT",
                    {"default": 1024, "min": 16, "max": 131072, "step": 8},
                ),
            }
        }

    RETURN_TYPES = ("CLIP",)
    RETURN_NAMES = ("clip",)
    FUNCTION = "load_model"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Gemma CLIP Loader"
    OUTPUT_NODE = False

    def load_model(self, gemma_path: str, ltxv_path: str, max_length: int):
        """Locate the model files and wrap tokenizer and encoder in a CLIP object.

        The model root is two directory levels above the selected encoder
        file. The tokenizer and weight directories are required; the image
        processor is optional and only enables prompt enhancement.

        Args:
            gemma_path: File name within the "text_encoders" folder.
            ltxv_path: File name within the "checkpoints" folder.
            max_length: Token length used by the tokenizer.

        Returns:
            A one-element tuple containing the ``comfy.sd.CLIP`` instance.

        Raises:
            FileNotFoundError: If ``tokenizer.model`` or ``model*.safetensors``
                cannot be found under the model root.
        """
        path = Path(folder_paths.get_full_path("text_encoders", gemma_path))
        model_root = path.parents[1]
        tokenizer_path = Path(find_matching_dir(model_root, "tokenizer.model"))
        gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors"))
        tokenizer_class = ltxv_gemma_tokenizer(tokenizer_path, max_length=max_length)

        processor = None
        try:
            # The lookup lives inside the best-effort block: a missing
            # preprocessor_config.json must only disable enhancement, not
            # abort loading the encoder.
            processor_path = Path(
                find_matching_dir(model_root, "preprocessor_config.json")
            )
            image_processor = AutoImageProcessor.from_pretrained(
                str(processor_path),
                local_files_only=True,
            )
            processor = Gemma3Processor(
                image_processor=image_processor,
                tokenizer=tokenizer_class().tokenizer,
            )
            logger.info(f"Loaded processor from {model_root} - enhancement enabled")
        except Exception as e:
            logger.warning(f"Could not load processor from {model_root}: {e}")

        clip_dtype = torch.bfloat16
        ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path)
        clip_target = comfy.supported_models_base.ClipTarget(
            tokenizer=tokenizer_class,
            clip=ltxv_gemma_clip(
                gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype
            ),
        )

        return (comfy.sd.CLIP(clip_target),)
|
|
|
|
# Translation table mapping typographic characters to ASCII, position by
# position: curly single/double quotes -> straight quotes, em/en dash -> "-",
# no-break space -> " ", prime -> "'", minus sign -> "-".
_UNICODE_REPLACEMENTS = str.maketrans(
    "\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212",
    "''\"\"-- '-",
)


def clean_response(text):
    """Normalize typographic characters and drop everything before the first letter.

    Args:
        text: Raw model output.

    Returns:
        The translated text starting at its first alphabetic character, or the
        whole translated text when it contains no alphabetic character.
    """
    normalized = text.translate(_UNICODE_REPLACEMENTS)
    first_alpha = next(
        (pos for pos, ch in enumerate(normalized) if ch.isalpha()),
        None,
    )
    if first_alpha is None:
        return normalized
    return normalized[first_alpha:]
|
|
|
|
@comfy_node(name="LTXVGemmaEnhancePrompt", description="Gemma 3 Prompt Enhancer")
class LTXVGemmaEnhancePrompt:
    """Enhance prompts using Gemma 3 model. Supports T2V and I2V modes."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: required prompt settings plus optional image and seed."""
        return {
            "required": {
                "clip": ("CLIP",),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "system_prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": DEFAULT_T2V_SYSTEM_PROMPT,
                    },
                ),
                "max_tokens": (
                    "INT",
                    {"default": 512, "min": 32, "max": 1024, "step": 16},
                ),
                "bypass_i2v": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "image": ("IMAGE",),
                "seed": (
                    "INT",
                    {"default": 42, "min": 0, "max": 0xFFFFFFFF},
                ),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("enhanced_prompt",)
    FUNCTION = "enhance"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Gemma Enhance Prompt"
    OUTPUT_NODE = True
    DESCRIPTION = (
        "Enhance text prompts using Gemma 3 VLLM for improved video generation."
    )

    def enhance(
        self,
        clip,
        prompt: str,
        system_prompt: str,
        max_tokens: int,
        bypass_i2v: bool,
        image: Optional[torch.Tensor] = None,
        seed: int = 42,
    ):
        """Generate an enhanced prompt with the Gemma model behind ``clip``.

        Args:
            clip: CLIP object whose ``cond_stage_model`` provides the Gemma model.
            prompt: Raw user prompt.
            system_prompt: System prompt. When I2V mode is active and this
                equals the default T2V prompt, the default I2V prompt is used.
            max_tokens: Maximum number of newly generated tokens.
            bypass_i2v: When true, ``image`` is ignored and T2V mode is used.
            image: Optional reference image tensor enabling I2V mode.
            seed: Sampling seed; a non-int value falls back to 42.

        Returns:
            A one-element tuple containing the cleaned, enhanced prompt.

        Raises:
            ValueError: If no processor/model can be obtained from the encoder,
                or if the system prompt is empty or whitespace-only.
        """
        if not isinstance(seed, int):
            seed = 42

        clip.load_model()
        encoder = clip.cond_stage_model

        if getattr(encoder, "processor", None) is None:
            if not hasattr(encoder, "gemma3_12b"):
                raise ValueError(
                    "Processor not loaded - enhancement not available. "
                    "Ensure your model directory has chat_template.json, processor_config.json, "
                    "and preprocessor_config.json files."
                )
            # Encoder exposing a `gemma3_12b` attribute: rebuild a
            # transformers model and processor from it.
            model, processor = transformers_gemma3_from_encoder(encoder)
        elif isinstance(encoder, LTXVGemmaTextEncoderModel):
            model = encoder.model
            processor = encoder.processor
        else:
            # Previously this case fell through and failed later with an
            # UnboundLocalError on `model` / `processor`.
            raise ValueError(
                "Unsupported text encoder for prompt enhancement: "
                f"{type(encoder).__name__}"
            )

        use_i2v = image is not None and not bypass_i2v

        if use_i2v and system_prompt.strip() == DEFAULT_T2V_SYSTEM_PROMPT.strip():
            system_prompt = DEFAULT_I2V_SYSTEM_PROMPT
            logger.info("Auto-selected I2V system prompt for image-to-video mode")

        if not system_prompt or not system_prompt.strip():
            raise ValueError(
                "system_prompt is required and cannot be empty or whitespace-only"
            )

        if use_i2v:
            pil_image = tensor_to_pil(image)
            enhanced_prompt = enhance_i2v(
                processor=processor,
                model=model,
                prompt=prompt,
                image=pil_image,
                system_prompt=system_prompt,
                max_new_tokens=max_tokens,
                seed=seed,
            )
        else:
            enhanced_prompt = enhance_t2v(
                processor=processor,
                model=model,
                prompt=prompt,
                system_prompt=system_prompt,
                max_new_tokens=max_tokens,
                seed=seed,
            )

        enhanced_prompt = clean_response(enhanced_prompt)

        return (enhanced_prompt,)
|
|
|
|
def _enhance(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    messages: list,
    image: Optional[Image.Image] = None,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Common enhancement logic for both T2V and I2V modes.

    Args:
        processor: Gemma 3 processor providing the tokenizer and chat template.
        model: Gemma 3 model used for generation.
        messages: Chat messages rendered with the tokenizer's chat template.
        image: Optional image passed to the processor alongside the text.
        max_new_tokens: Maximum number of tokens to generate.
        seed: Seed applied inside a forked RNG scope, so the caller's RNG
            state is restored afterwards.

    Returns:
        The decoded newly generated tokens, with special tokens skipped.

    Raises:
        ValueError: If ``processor`` is ``None``.
    """
    if processor is None:
        raise ValueError("Processor not loaded - enhancement not available")

    text = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    model_inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
    ).to(model.device)

    pad_token_id = (
        processor.tokenizer.pad_token_id
        if processor.tokenizer.pad_token_id is not None
        else 0
    )
    model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id)

    device = model.device
    # fork_rng saves per-device state through the CUDA RNG API by default, so
    # passing a non-CUDA device fails on machines without CUDA. The CPU RNG
    # state is always forked regardless of this list.
    rng_devices = [device] if device.type == "cuda" else []

    with (
        torch.inference_mode(),
        torch.random.fork_rng(devices=rng_devices),
        torch.autocast(device_type=device.type, dtype=model.dtype),
    ):
        torch.manual_seed(seed)
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
        # Drop the (padded) prompt tokens and keep only the continuation.
        generated_ids = outputs[0][len(model_inputs.input_ids[0]) :]
        enhanced_prompt = processor.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )

    return enhanced_prompt
|
|
|
|
def enhance_t2v(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Enhance a text-only prompt (T2V mode).

    Builds a system turn and a user turn wrapping ``prompt``, then delegates
    to ``_enhance`` and returns its result.
    """
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {"role": "user", "content": f"User Raw Input Prompt: {prompt}."}
    return _enhance(
        processor,
        model,
        [system_turn, user_turn],
        max_new_tokens=max_new_tokens,
        seed=seed,
    )
|
|
|
|
def enhance_i2v(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    prompt: str,
    image: Image.Image,
    system_prompt: str,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Enhance a prompt using a reference image (I2V mode).

    The user turn holds an image placeholder followed by the wrapped prompt;
    the image itself is passed to ``_enhance``, whose result is returned.
    """
    user_content = [
        {"type": "image"},
        {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
    ]
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
    return _enhance(
        processor,
        model,
        conversation,
        image=image,
        max_new_tokens=max_new_tokens,
        seed=seed,
    )
|
|
|
|
| def _cat_with_padding( |
| tensor: torch.Tensor, |
| padding_length: int, |
| value: int | float, |
| ) -> torch.Tensor: |
| """Concatenate a tensor with a padding tensor of the given value.""" |
| return torch.cat( |
| [ |
| tensor, |
| torch.full( |
| (1, padding_length), |
| value, |
| dtype=tensor.dtype, |
| device=tensor.device, |
| ), |
| ], |
| dim=1, |
| ) |
|
|
|
|
def _pad_inputs_for_attention_alignment(model_inputs, pad_token_id, alignment: int = 8):
    """Pad the sequence length up to the next multiple of ``alignment``.

    Intended to avoid 'p.attn_bias_ptr is not correctly aligned' errors from
    Flash Attention within SDPA. Padding is appended at the end of
    ``input_ids`` (filled with ``pad_token_id``), ``attention_mask`` (filled
    with 0) and, when present and not ``None``, ``token_type_ids`` (filled
    with 0).

    Returns:
        The same ``model_inputs`` object, updated in place when padding was needed.
    """
    current_len = model_inputs.input_ids.shape[1]
    aligned_len = ((current_len + alignment - 1) // alignment) * alignment
    extra = aligned_len - current_len

    # Already aligned: nothing to change.
    if extra <= 0:
        return model_inputs

    model_inputs["input_ids"] = _cat_with_padding(
        model_inputs.input_ids, extra, pad_token_id
    )
    model_inputs["attention_mask"] = _cat_with_padding(
        model_inputs.attention_mask, extra, 0
    )

    has_token_types = (
        "token_type_ids" in model_inputs
        and model_inputs["token_type_ids"] is not None
    )
    if has_token_types:
        model_inputs["token_type_ids"] = _cat_with_padding(
            model_inputs["token_type_ids"], extra, 0
        )

    return model_inputs
|
|
|
|
def _locate_model_within_model(super_model, model_name):
    """Return the first sub-module of ``super_model`` matching ``model_name``.

    ``model_name`` is looked up in ``MODEL_MAPPING_NAMES`` to obtain a class
    name; modules are compared by class name. Returns ``None`` when the name
    is not in the mapping or no module matches.
    """
    target_class = MODEL_MAPPING_NAMES.get(model_name, None)
    if target_class is None:
        return None
    return next(
        (
            candidate
            for candidate in super_model.modules()
            if candidate.__class__.__name__ == target_class
        ),
        None,
    )
|
|
|
|
| def _locate_unique_parameter_owner_by_leaf( |
| root: torch.nn.Module, |
| leaf_param_name: str, |
| must_have_in_path: Optional[str] = None, |
| ) -> Optional[Tuple[torch.nn.Module, str, torch.nn.Parameter, str]]: |
|
|
| modules = dict(root.named_modules()) |
|
|
| candidates: List[Tuple[torch.nn.Module, str, torch.nn.Parameter, str]] = [] |
| for full_name, p in root.named_parameters(recurse=True): |
| parts = full_name.split(".") |
| leaf = parts[-1] |
| if leaf != leaf_param_name: |
| continue |
| if must_have_in_path is not None and must_have_in_path not in parts: |
| continue |
|
|
| owner_path = ".".join(parts[:-1]) |
| owner = modules.get(owner_path, root if owner_path == "" else None) |
| if owner is None: |
| continue |
| candidates.append((owner, leaf, p, full_name)) |
|
|
| if not candidates: |
| return None |
| return candidates[0] |
|
|
|
|
def transformers_gemma3_from_encoder(encoder):
    """Build a ``Gemma3ForConditionalGeneration`` and processor from an encoder.

    The model is created from ``gemma_configs/gemma3cfg.json`` (next to this
    module) on the meta device, then populated with the state dicts of
    ``encoder.gemma3_12b.transformer`` (text model, vision model, multi-modal
    projector). Several buffers and the LM head weight are then set manually.

    Args:
        encoder: Object exposing ``gemma3_12b.transformer`` and ``device``.

    Returns:
        ``(model, processor)``, the assembled model moved to ``encoder.device``
        and a ``Gemma3Processor`` built from the files in ``gemma_configs``.

    Raises:
        ValueError: If the text model, the vision model or the LM head weight
            cannot be located inside the freshly built model.
    """
    jsons_path = Path(__file__).parent / "gemma_configs"
    config = Gemma3Config.from_json_file(jsons_path / "gemma3cfg.json")
    # Instantiate on the meta device; weights are supplied below.
    with torch.device("meta"):
        metamodel = Gemma3ForConditionalGeneration(config)
    # Text model: located by class name, weights assigned from the encoder.
    t_model_name = config.text_config.model_type
    t_model = _locate_model_within_model(metamodel, t_model_name)
    if t_model is None:
        raise ValueError(
            "Can't locate text model while converting text encoder to Gemma3ForConditionalGeneration"
        )
    t_model.load_state_dict(
        encoder.gemma3_12b.transformer.model.state_dict(), assign=True, strict=False
    )
    # Vision tower: same procedure, applied to its `vision_model` sub-module.
    v_tower_name = config.vision_config.model_type
    v_tower = _locate_model_within_model(metamodel, v_tower_name)
    if v_tower is None:
        raise ValueError(
            "Can't locate vision model while converting text encoder to Gemma3ForConditionalGeneration"
        )
    v_model = v_tower.vision_model
    v_model.load_state_dict(
        encoder.gemma3_12b.transformer.vision_model.state_dict(),
        assign=True,
        strict=False,
    )
    metamodel.multi_modal_projector.load_state_dict(
        encoder.gemma3_12b.transformer.multi_modal_projector.state_dict(),
        assign=True,
        strict=False,
    )
    # From here on `config` refers to the text sub-config only.
    config = config.text_config
    dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    base = config.rope_local_base_freq

    # The buffers below are re-registered with concrete tensors.
    # NOTE(review): presumably they are absent from the encoder state dicts
    # and would otherwise stay on the meta device; confirm.
    device = encoder.device
    positions_length = len(v_model.embeddings.position_ids[0])
    position_ids = torch.arange(
        positions_length, dtype=torch.long, device="cpu"
    ).unsqueeze(0)
    v_model.embeddings.register_buffer("position_ids", position_ids)
    embed_scale = torch.tensor(config.hidden_size**0.5, device=device)
    t_model.embed_tokens.register_buffer("embed_scale", embed_scale)
    # Inverse rotary frequencies for the local rotary embedding:
    # 1 / base ** (even_index / dim).
    local_rope_freqs = 1.0 / (
        base
        ** (
            torch.arange(0, dim, 2, dtype=torch.int64).to(
                device=device, dtype=torch.float
            )
            / dim
        )
    )
    t_model.rotary_emb_local.register_buffer("inv_freq", local_rope_freqs)
    # Frequencies for the other rotary embedding come from the transformers
    # init function selected by the configured rope type.
    rope_freqs, _ = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]](
        config, device
    )
    t_model.rotary_emb.register_buffer("inv_freq", rope_freqs)
    lm_head_requires_grad = False
    loc_result = _locate_unique_parameter_owner_by_leaf(
        metamodel, leaf_param_name="weight", must_have_in_path="lm_head"
    )
    if loc_result is None:
        raise ValueError(
            "Can't locate lm_head while converting text encoder to Gemma3ForConditionalGeneration"
        )
    lm_head_owner, lm_head_attr, _, _ = loc_result
    # Replace the LM head weight with a parameter wrapping the token
    # embedding weight.
    real_w = t_model.embed_tokens.weight
    setattr(
        lm_head_owner,
        lm_head_attr,
        torch.nn.Parameter(real_w, requires_grad=lm_head_requires_grad),
    )
    metamodel.to(device)

    # Tokenizer and image processor are both read from the bundled config directory.
    tokenizer_class = ltxv_gemma_tokenizer(jsons_path, max_length=1024)
    image_processor = AutoImageProcessor.from_pretrained(
        str(jsons_path),
        local_files_only=True,
    )
    processor = Gemma3Processor(
        image_processor=image_processor,
        tokenizer=tokenizer_class().tokenizer,
    )
    return metamodel, processor
|
|