import logging
from glob import glob
from pathlib import Path
from typing import List, Optional, Tuple

import comfy.model_management
import comfy.sd
import comfy.supported_models_base
import folder_paths
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoTokenizer,
    Gemma3Config,
    Gemma3ForConditionalGeneration,
    Gemma3Processor,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES

from .nodes_registry import comfy_node
from .text_embeddings_connectors import load_text_embeddings_pipeline

logger = logging.getLogger(__name__)


def _load_system_prompt(filename: str) -> str:
    """Load a system prompt from the ``system_prompts`` directory next to this file.

    Args:
        filename: Name of the prompt file inside ``system_prompts``.

    Returns:
        The stripped file contents, or ``""`` when the file does not exist or
        cannot be read (a warning is logged in the latter case).
    """
    # Best-effort on purpose: this runs at import time and a missing or
    # unreadable prompt file must not prevent the module from loading.
    try:
        prompt_path = Path(__file__).parent / "system_prompts" / filename
        if prompt_path.exists():
            return prompt_path.read_text(encoding="utf-8").strip()
    except Exception as e:
        logger.warning(f"Could not load {filename}: {e}")
    return ""


DEFAULT_T2V_SYSTEM_PROMPT = _load_system_prompt("gemma_t2v_system_prompt.txt")
DEFAULT_I2V_SYSTEM_PROMPT = _load_system_prompt("gemma_i2v_system_prompt.txt")


def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a ComfyUI image tensor to a PIL Image.

    Args:
        tensor: Image tensor; when it is 4-dimensional only the first entry
            along dim 0 is converted. Values are presumably floats in [0, 1]
            (the usual ComfyUI convention - confirm against callers).

    Returns:
        The image as an 8-bit ``PIL.Image.Image``.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    # Detach and cast to float32 so tensors that require grad or use a
    # half/bfloat16 dtype can be converted to NumPy, and clamp so out-of-range
    # values saturate instead of wrapping around in the uint8 cast.
    scaled = tensor.detach().cpu().float().clamp(0.0, 1.0).numpy() * 255
    return Image.fromarray(scaled.astype("uint8"))


class LTXVGemmaTokenizer:
    """Thin wrapper around a Hugging Face tokenizer for the Gemma text encoder.

    Every prompt is padded/truncated to exactly ``max_length`` tokens.
    """

    def __init__(self, tokenizer_path: str, max_length: int = 1024):
        """Load the tokenizer from local files only.

        Args:
            tokenizer_path: Directory containing the tokenizer files.
            max_length: Fixed sequence length used for padding and truncation.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, local_files_only=True, model_max_length=max_length
        )
        # Gemma expects left padding for chat-style prompts; for plain text it
        # doesn't matter much.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_length = max_length

    def tokenize_with_weights(self, text: str, return_word_ids: bool = False):
        """Tokenize ``text`` into the ``{"gemma": [...]}`` structure.

        Args:
            text: Prompt to tokenize; surrounding whitespace is stripped.
            return_word_ids: When true, each entry is
                ``(token_id, attention_mask, position)``; otherwise the
                position is dropped and entries are
                ``(token_id, attention_mask)``.

        Returns:
            A dict with the single key ``"gemma"`` mapping to a list with one
            entry per token of the padded sequence.
        """
        text = text.strip()
        encoded = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoded.input_ids
        attention_mask = encoded.attention_mask
        # The "weight" slot carries the attention-mask value of the token.
        tuples = [
            (token_id, attn, i)
            for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0]))
        ]
        out = {"gemma": tuples}

        if not return_word_ids:
            out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}

        return out


class LTXVGemmaTextEncoderModel(torch.nn.Module):
    """Gemma 3 text encoder followed by a feature extractor and an embeddings processor."""

    def __init__(
        self,
        model: Gemma3ForConditionalGeneration,
        feature_extractor,  # FeatureExtractorV1/V2
        embeddings_processor,  # VideoEmbeddingsProcessor or AVEmbeddingsProcessor
        processor: Gemma3Processor | None = None,
        dtype=torch.bfloat16,
        device="cpu",
    ):
        """Store the sub-modules and cache a memory estimate.

        Args:
            model: The Gemma 3 model whose hidden states are extracted.
            feature_extractor: Module turning stacked hidden states into
                features; cast to ``dtype``.
            embeddings_processor: Module exposing ``create_embeddings``; cast
                to ``dtype``.
            processor: Optional Gemma 3 processor used for prompt enhancement.
            dtype: Dtype for the feature extractor and embeddings processor.
            device: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.model = model
        self.processor = processor
        self.feature_extractor = feature_extractor.to(dtype=dtype)
        self.embeddings_processor = embeddings_processor.to(dtype=dtype)
        self.dtypes = {dtype}

        # Cache an estimate of memory required to load/keep the model on device
        # weights size + small overhead
        self._model_memory_required = (
            comfy.model_management.module_size(self.model) + 256 * 1024 * 1024
        )

    def set_clip_options(self, options):
        """No-op; present to satisfy the caller's expected interface."""
        pass

    def reset_clip_options(self):
        """No-op; present to satisfy the caller's expected interface."""
        pass

    def forward(self, input_ids, attention_mask, padding_side="right"):
        """Run Gemma and extract features from all of its hidden states.

        Args:
            input_ids: Token id tensor passed to the Gemma model.
            attention_mask: Attention mask passed to the model and extractor.
            padding_side: Forwarded to the feature extractor.

        Returns:
            Whatever the feature extractor returns (a dict with "video" and
            optionally "audio").
        """
        # Block 1: Run Gemma model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        all_layer_hiddens = torch.stack(outputs.hidden_states, dim=-1)  # [B, T, D, L]

        # Block 2: Feature extraction
        features = self.feature_extractor(
            all_layer_hiddens, attention_mask, padding_side
        )
        return features  # dict with "video" and optionally "audio"

    def encode_token_weights(self, token_weight_pairs):
        """Encode tokenizer output into conditioning embeddings.

        Args:
            token_weight_pairs: Dict with a ``"gemma"`` entry as produced by
                ``LTXVGemmaTokenizer.tokenize_with_weights``.

        Returns:
            ``(encoded, None, {"attention_mask": mask})`` where ``encoded`` and
            ``mask`` come from the embeddings processor.
        """
        token_pairs = token_weight_pairs["gemma"]
        input_ids = torch.tensor(
            [[t[0] for t in token_pairs]], device=self.model.device
        )
        attention_mask = torch.tensor(
            [[w[1] for w in token_pairs]], device=self.model.device
        )

        # Make sure the auxiliary modules live on the same device as Gemma.
        self.to(self.model.device)

        features = self(input_ids, attention_mask, padding_side="left")

        # Convert binary mask -> additive mask for processor
        # (1 -> 0, 0 -> -finfo.max), reshaped to (batch, 1, -1, seq_len).
        encoded_input_dtype = next(iter(features.values())).dtype
        connector_attention_mask = (attention_mask - 1).to(encoded_input_dtype).reshape(
            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
        ) * torch.finfo(encoded_input_dtype).max

        # Block 3: Embeddings processor
        encoded, mask = self.embeddings_processor.create_embeddings(
            features, connector_attention_mask
        )

        return encoded, None, {"attention_mask": mask}

    def load_sd(self, sd):
        """Load ``sd`` into the Gemma model non-strictly and return the result."""
        return self.model.load_state_dict(sd, strict=False)

    def memory_required(self, input_shape):
        """Return the cached memory estimate in bytes.

        ``input_shape`` is ignored: the estimate is the model's weight size
        plus a fixed 256 MiB overhead computed in ``__init__``.
        """
        return self._model_memory_required


def ltxv_gemma_tokenizer(tokenizer_path, max_length=256):
    """Build a tokenizer class bound to ``tokenizer_path`` and ``max_length``.

    Returns:
        A subclass of ``LTXVGemmaTokenizer`` whose constructor accepts (and
        ignores) ``embedding_directory`` and ``tokenizer_data``.
    """

    class _LTXVGemmaTokenizer(LTXVGemmaTokenizer):
        # ``tokenizer_data`` defaults to None rather than a shared mutable dict.
        def __init__(self, embedding_directory=None, tokenizer_data=None):
            super().__init__(tokenizer_path, max_length=max_length)

    return _LTXVGemmaTokenizer


def ltxv_gemma_clip(encoder_path, ltxv_path, processor=None, dtype=None):
    """Build a text-encoder class bound to the given model locations.

    Args:
        encoder_path: Directory of the Gemma 3 weights (``str`` or ``Path``).
        ltxv_path: Path handed to ``load_text_embeddings_pipeline``.
        processor: Optional Gemma 3 processor stored on the encoder.
        dtype: Default for the constructor's ``dtype`` parameter; currently
            overridden with ``torch.bfloat16`` inside the constructor.

    Returns:
        A subclass of ``LTXVGemmaTextEncoderModel`` that loads its components
        when instantiated.
    """

    class _LTXVGemmaTextEncoderModel(LTXVGemmaTextEncoderModel):
        # ``model_options`` defaults to None rather than a shared mutable dict.
        def __init__(self, device="cpu", dtype=dtype, model_options=None):
            dtype = torch.bfloat16  # TODO: make this configurable

            gemma_model = Gemma3ForConditionalGeneration.from_pretrained(
                encoder_path,
                local_files_only=True,
                torch_dtype=dtype,
            )

            feature_extractor, embeddings_processor = load_text_embeddings_pipeline(
                ltxv_path,
                dtype=dtype,
                # Path(...) so that a plain string encoder_path works as well.
                fallback_proj_path=Path(encoder_path) / "proj_linear.safetensors",
            )

            super().__init__(
                model=gemma_model,
                feature_extractor=feature_extractor,
                embeddings_processor=embeddings_processor,
                processor=processor,
                dtype=dtype,
                device=device,
            )

    return _LTXVGemmaTextEncoderModel


def find_matching_dir(root_path: str, pattern: str) -> str:
    """
    Recursively search for files matching a glob pattern and return the parent
    directory of the first match.

    Args:
        root_path: Directory to search beneath.
        pattern: Pattern tested with ``Path.match`` against every entry.

    Returns:
        The parent directory of the first matching entry, as a string.

    Raises:
        FileNotFoundError: If nothing under ``root_path`` matches.
    """
    for candidate in glob(f"{root_path}/**", recursive=True):
        candidate_path = Path(candidate)
        if candidate_path.match(pattern):
            return str(candidate_path.parent)
    raise FileNotFoundError(
        f"No files matching pattern '{pattern}' found under {root_path}"
    )


@comfy_node(name="LTXVGemmaCLIPModelLoader", description="Gemma 3 Model Loader")
class LTXVGemmaCLIPModelLoader:
    """ComfyUI node that loads the Gemma 3 text encoder as a CLIP object."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: encoder file, LTXV checkpoint, max length."""
        return {
            "required": {
                "gemma_path": (
                    folder_paths.get_filename_list("text_encoders"),
                    {"tooltip": "The name of the text encoder model to load."},
                ),
                "ltxv_path": (
                    folder_paths.get_filename_list("checkpoints"),
                    {"tooltip": "The name of the ltxv model to load."},
                ),
                "max_length": (
                    "INT",
                    {"default": 1024, "min": 16, "max": 131072, "step": 8},
                ),
            }
        }

    RETURN_TYPES = ("CLIP",)
    RETURN_NAMES = ("clip",)
    FUNCTION = "load_model"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Gemma CLIP Loader"
    OUTPUT_NODE = False

    def load_model(self, gemma_path: str, ltxv_path: str, max_length: int):
        """Locate the model files and build the CLIP wrapper.

        Args:
            gemma_path: File name within the ``text_encoders`` folder.
            ltxv_path: File name within the ``checkpoints`` folder.
            max_length: Token length used by the tokenizer.

        Returns:
            A one-element tuple holding the ``comfy.sd.CLIP`` instance.

        Raises:
            FileNotFoundError: If the tokenizer or model weights cannot be
                found under the model root.
        """
        path = Path(folder_paths.get_full_path("text_encoders", gemma_path))
        # The model root is two levels above the selected file.
        model_root = path.parents[1]
        tokenizer_path = Path(find_matching_dir(model_root, "tokenizer.model"))
        gemma_model_path = Path(find_matching_dir(model_root, "model*.safetensors"))

        tokenizer_class = ltxv_gemma_tokenizer(tokenizer_path, max_length=max_length)

        # The processor is optional (only prompt enhancement needs it), so the
        # whole lookup, including locating preprocessor_config.json, is
        # best-effort: a missing file downgrades to a warning instead of
        # aborting the model load.
        processor = None
        try:
            processor_path = Path(
                find_matching_dir(model_root, "preprocessor_config.json")
            )
            image_processor = AutoImageProcessor.from_pretrained(
                str(processor_path),
                local_files_only=True,
            )
            processor = Gemma3Processor(
                image_processor=image_processor,
                tokenizer=tokenizer_class().tokenizer,
            )
            logger.info(f"Loaded processor from {model_root} - enhancement enabled")
        except Exception as e:
            logger.warning(f"Could not load processor from {model_root}: {e}")

        clip_dtype = torch.bfloat16

        ltxv_full_path = folder_paths.get_full_path("checkpoints", ltxv_path)

        clip_target = comfy.supported_models_base.ClipTarget(
            tokenizer=tokenizer_class,
            clip=ltxv_gemma_clip(
                gemma_model_path, ltxv_full_path, processor=processor, dtype=clip_dtype
            ),
        )

        return (comfy.sd.CLIP(clip_target),)


# Typographic quotes, dashes, non-breaking space, prime and minus sign mapped
# to their plain ASCII counterparts.
_UNICODE_REPLACEMENTS = str.maketrans(
    "\u2018\u2019\u201c\u201d\u2014\u2013\u00a0\u2032\u2212", "''\"\"-- '-"
)


def clean_response(text):
    """Normalize typographic characters and drop leading non-letter characters.

    Returns the translated text starting at its first alphabetic character, or
    the whole translated text when it contains no alphabetic character.
    """
    text = text.translate(_UNICODE_REPLACEMENTS)

    # Remove leading non-letter characters
    for i, char in enumerate(text):
        if char.isalpha():
            return text[i:]
    return text


@comfy_node(name="LTXVGemmaEnhancePrompt", description="Gemma 3 Prompt Enhancer")
class LTXVGemmaEnhancePrompt:
    """Enhance prompts using Gemma 3 model. Supports T2V and I2V modes."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs (required and optional)."""
        return {
            "required": {
                "clip": ("CLIP",),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "system_prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": DEFAULT_T2V_SYSTEM_PROMPT,
                    },
                ),
                "max_tokens": (
                    "INT",
                    {"default": 512, "min": 32, "max": 1024, "step": 16},
                ),
                "bypass_i2v": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "image": ("IMAGE",),
                "seed": (
                    "INT",
                    {"default": 42, "min": 0, "max": 0xFFFFFFFF},
                ),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("enhanced_prompt",)
    FUNCTION = "enhance"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV Gemma Enhance Prompt"
    OUTPUT_NODE = True
    DESCRIPTION = (
        "Enhance text prompts using Gemma 3 VLLM for improved video generation."
    )

    def enhance(
        self,
        clip,
        prompt: str,
        system_prompt: str,
        max_tokens: int,
        bypass_i2v: bool,
        image: Optional[torch.Tensor] = None,
        seed: int = 42,
    ):
        """Generate an enhanced prompt with the Gemma model behind ``clip``.

        Args:
            clip: CLIP object whose ``cond_stage_model`` provides the model.
            prompt: Raw user prompt.
            system_prompt: System prompt; must not be empty or whitespace.
            max_tokens: Maximum number of newly generated tokens.
            bypass_i2v: When true, ignore ``image`` and use text-only mode.
            image: Optional reference image tensor for image-to-video mode.
            seed: Sampling seed; non-integer values fall back to 42.

        Returns:
            A one-element tuple with the cleaned, enhanced prompt.

        Raises:
            ValueError: If no processor/model is available on the encoder or
                the effective system prompt is empty.
        """
        if not isinstance(seed, int):
            seed = 42

        clip.load_model()
        encoder = clip.cond_stage_model

        processor = getattr(encoder, "processor", None)
        if processor is None:
            if hasattr(encoder, "gemma3_12b"):
                model, processor = transformers_gemma3_from_encoder(encoder)
            else:
                raise ValueError(
                    "Processor not loaded - enhancement not available. "
                    "Ensure your model directory has chat_template.json, processor_config.json, "
                    "and preprocessor_config.json files."
                )
        else:
            # Duck-typed instead of isinstance-gated: previously an encoder of
            # another class that carried a processor left ``model`` unbound
            # and failed later with UnboundLocalError.
            model = getattr(encoder, "model", None)
            if model is None:
                raise ValueError(
                    "Encoder has a processor but no 'model' attribute - "
                    "enhancement not available."
                )

        # Determine mode: use I2V if image is provided and not bypassed
        use_i2v = image is not None and not bypass_i2v

        # Auto-select the appropriate system prompt if user is using default T2V prompt
        if use_i2v and system_prompt.strip() == DEFAULT_T2V_SYSTEM_PROMPT.strip():
            system_prompt = DEFAULT_I2V_SYSTEM_PROMPT
            logger.info("Auto-selected I2V system prompt for image-to-video mode")

        if not system_prompt or not system_prompt.strip():
            raise ValueError(
                "system_prompt is required and cannot be empty or whitespace-only"
            )

        if use_i2v:
            pil_image = tensor_to_pil(image)
            enhanced_prompt = enhance_i2v(
                processor=processor,
                model=model,
                prompt=prompt,
                image=pil_image,
                system_prompt=system_prompt,
                max_new_tokens=max_tokens,
                seed=seed,
            )
        else:
            enhanced_prompt = enhance_t2v(
                processor=processor,
                model=model,
                prompt=prompt,
                system_prompt=system_prompt,
                max_new_tokens=max_tokens,
                seed=seed,
            )
        enhanced_prompt = clean_response(enhanced_prompt)

        return (enhanced_prompt,)


def _enhance(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    messages: list,
    image: Optional[Image.Image] = None,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Common enhancement logic for both T2V and I2V modes.

    Args:
        processor: Gemma 3 processor providing tokenizer and chat template.
        model: Model used for generation.
        messages: Chat messages passed to the chat template.
        image: Optional image passed to the processor.
        max_new_tokens: Maximum number of tokens to generate.
        seed: Seed applied before sampling.

    Returns:
        The decoded generated text with special tokens skipped.

    Raises:
        ValueError: If ``processor`` is ``None``.
    """
    if processor is None:
        raise ValueError("Processor not loaded - enhancement not available")

    text = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    model_inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
    ).to(model.device)
    pad_token_id = (
        processor.tokenizer.pad_token_id
        if processor.tokenizer.pad_token_id is not None
        else 0
    )
    model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id)

    # fork_rng's ``devices`` are handled by the CUDA RNG API by default, so
    # only hand it the model device when that device is CUDA; on other
    # backends just the CPU generator state is forked and restored.
    rng_devices = [model.device] if model.device.type == "cuda" else []

    with (
        torch.inference_mode(),
        torch.random.fork_rng(devices=rng_devices),
        torch.autocast(device_type=model.device.type, dtype=model.dtype),
    ):
        torch.manual_seed(seed)
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
        # Drop the (padded) prompt tokens, keeping only the newly generated ones.
        generated_ids = outputs[0][len(model_inputs.input_ids[0]) :]
        enhanced_prompt = processor.tokenizer.decode(
            generated_ids, skip_special_tokens=True
        )

    return enhanced_prompt


def enhance_t2v(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Enhance a text prompt for T2V generation."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"User Raw Input Prompt: {prompt}."},
    ]

    return _enhance(
        processor, model, messages, max_new_tokens=max_new_tokens, seed=seed
    )


def enhance_i2v(
    processor: Gemma3Processor,
    model: Gemma3ForConditionalGeneration,
    prompt: str,
    image: Image.Image,
    system_prompt: str,
    max_new_tokens: int = 512,
    seed: int = 42,
) -> str:
    """Enhance a text prompt for I2V generation using a reference image."""
    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
            ],
        },
    ]
    return _enhance(
        processor,
        model,
        messages,
        image=image,
        max_new_tokens=max_new_tokens,
        seed=seed,
    )


def _cat_with_padding(
    tensor: torch.Tensor,
    padding_length: int,
    value: int | float,
) -> torch.Tensor:
    """Append ``padding_length`` columns filled with ``value`` along dim 1.

    The padding matches the batch size, dtype and device of ``tensor``.
    """
    return torch.cat(
        [
            tensor,
            torch.full(
                # Use the real batch size rather than assuming a batch of one.
                (tensor.shape[0], padding_length),
                value,
                dtype=tensor.dtype,
                device=tensor.device,
            ),
        ],
        dim=1,
    )


def _pad_inputs_for_attention_alignment(model_inputs, pad_token_id, alignment: int = 8):
    """Pad sequence length to multiple of alignment for Flash Attention compatibility.

    Flash Attention within SDPA requires sequence lengths that are a multiple
    of ``alignment``. This pads input_ids, attention_mask, and token_type_ids
    (if present) to prevent 'p.attn_bias_ptr is not correctly aligned' errors.

    Padding is appended at the end of the sequence; ``model_inputs`` is
    modified in place and also returned.
    """
    seq_len = model_inputs.input_ids.shape[1]
    padded_len = ((seq_len + alignment - 1) // alignment) * alignment
    padding_length = padded_len - seq_len

    if padding_length > 0:
        model_inputs["input_ids"] = _cat_with_padding(
            model_inputs.input_ids, padding_length, pad_token_id
        )

        model_inputs["attention_mask"] = _cat_with_padding(
            model_inputs.attention_mask, padding_length, 0
        )

        if (
            "token_type_ids" in model_inputs
            and model_inputs["token_type_ids"] is not None
        ):
            model_inputs["token_type_ids"] = _cat_with_padding(
                model_inputs["token_type_ids"], padding_length, 0
            )

    return model_inputs


def _locate_model_within_model(super_model, model_name):
    """Find the first sub-module whose class is registered for ``model_name``.

    The class name is looked up in ``MODEL_MAPPING_NAMES``; ``None`` is
    returned when the name is unknown or no sub-module has that class.
    """
    class_name = MODEL_MAPPING_NAMES.get(model_name, None)
    if class_name is None:
        return None

    for module in super_model.modules():
        if module.__class__.__name__ == class_name:
            return module
    return None


def _locate_unique_parameter_owner_by_leaf(
    root: torch.nn.Module,
    leaf_param_name: str,
    must_have_in_path: Optional[str] = None,
) -> Optional[Tuple[torch.nn.Module, str, torch.nn.Parameter, str]]:
    """Locate the module owning a parameter with the given leaf name.

    Args:
        root: Module whose parameters are searched.
        leaf_param_name: Last component of the dotted parameter name.
        must_have_in_path: If given, one component of the dotted name must
            equal this string.

    Returns:
        ``(owner_module, leaf_name, parameter, full_name)`` for the first
        match in ``named_parameters`` order, or ``None`` if nothing matches.
        Despite the function name, uniqueness is not enforced.
    """
    modules = dict(root.named_modules())

    candidates: List[Tuple[torch.nn.Module, str, torch.nn.Parameter, str]] = []

    for full_name, p in root.named_parameters(recurse=True):
        parts = full_name.split(".")
        leaf = parts[-1]
        if leaf != leaf_param_name:
            continue
        if must_have_in_path is not None and must_have_in_path not in parts:
            continue

        owner_path = ".".join(parts[:-1])
        owner = modules.get(owner_path, root if owner_path == "" else None)
        if owner is None:
            continue

        candidates.append((owner, leaf, p, full_name))

    if not candidates:
        return None

    return candidates[0]


def transformers_gemma3_from_encoder(encoder):
    """Build a ``Gemma3ForConditionalGeneration`` and processor from ``encoder``.

    The model skeleton is created on the meta device from the bundled
    ``gemma_configs/gemma3cfg.json`` and then populated with the state dicts
    found under ``encoder.gemma3_12b.transformer``.

    Returns:
        ``(model, processor)``.

    Raises:
        ValueError: If the text model, vision model or ``lm_head`` cannot be
            located in the freshly built model.
    """
    jsons_path = Path(__file__).parent / "gemma_configs"
    config = Gemma3Config.from_json_file(jsons_path / "gemma3cfg.json")
    # Build the skeleton on the meta device so no weights are allocated.
    with torch.device("meta"):
        metamodel = Gemma3ForConditionalGeneration(config)
    t_model_name = config.text_config.model_type
    t_model = _locate_model_within_model(metamodel, t_model_name)
    if t_model is None:
        raise ValueError(
            "Can't locate text model while converting text encoder to Gemma3ForConditionalGeneration"
        )
    # assign=True adopts the encoder's tensors in place of the meta ones.
    t_model.load_state_dict(
        encoder.gemma3_12b.transformer.model.state_dict(), assign=True, strict=False
    )

    v_tower_name = config.vision_config.model_type
    v_tower = _locate_model_within_model(metamodel, v_tower_name)
    if v_tower is None:
        raise ValueError(
            "Can't locate vision model while converting text encoder to Gemma3ForConditionalGeneration"
        )
    v_model = v_tower.vision_model
    v_model.load_state_dict(
        encoder.gemma3_12b.transformer.vision_model.state_dict(),
        assign=True,
        strict=False,
    )

    metamodel.multi_modal_projector.load_state_dict(
        encoder.gemma3_12b.transformer.multi_modal_projector.state_dict(),
        assign=True,
        strict=False,
    )

    # From here on ``config`` refers to the text sub-config only.
    config = config.text_config
    dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    base = config.rope_local_base_freq
    device = encoder.device

    # The buffers below are re-registered with concrete tensors - presumably
    # because they are not covered by the loaded state dicts and would
    # otherwise stay on the meta device (TODO confirm).
    positions_length = len(v_model.embeddings.position_ids[0])
    position_ids = torch.arange(
        positions_length, dtype=torch.long, device="cpu"
    ).unsqueeze(0)
    v_model.embeddings.register_buffer("position_ids", position_ids)
    embed_scale = torch.tensor(config.hidden_size**0.5, device=device)
    t_model.embed_tokens.register_buffer("embed_scale", embed_scale)
    local_rope_freqs = 1.0 / (
        base
        ** (
            torch.arange(0, dim, 2, dtype=torch.int64).to(
                device=device, dtype=torch.float
            )
            / dim
        )
    )
    t_model.rotary_emb_local.register_buffer("inv_freq", local_rope_freqs)
    rope_freqs, _ = ROPE_INIT_FUNCTIONS[config.rope_scaling["rope_type"]](
        config, device
    )
    t_model.rotary_emb.register_buffer("inv_freq", rope_freqs)

    lm_head_requires_grad = False
    loc_result = _locate_unique_parameter_owner_by_leaf(
        metamodel, leaf_param_name="weight", must_have_in_path="lm_head"
    )
    if loc_result is None:
        raise ValueError(
            "Can't locate lm_head while converting text encoder to Gemma3ForConditionalGeneration"
        )
    lm_head_owner, lm_head_attr, _, _ = loc_result

    # Reuse the token-embedding weight tensor as the lm_head weight.
    real_w = t_model.embed_tokens.weight
    setattr(
        lm_head_owner,
        lm_head_attr,
        torch.nn.Parameter(real_w, requires_grad=lm_head_requires_grad),
    )
    metamodel.to(device)

    tokenizer_class = ltxv_gemma_tokenizer(jsons_path, max_length=1024)
    image_processor = AutoImageProcessor.from_pretrained(
        str(jsons_path),
        local_files_only=True,
    )
    processor = Gemma3Processor(
        image_processor=image_processor,
        tokenizer=tokenizer_class().tokenizer,
    )
    return metamodel, processor