import torch import torch.nn.functional as F import numpy as np import torch.nn as nn def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids return text_input_ids def _encode_prompt_with_t5( text_encoder, tokenizer, max_sequence_length=512, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] if hasattr(text_encoder, "module"): dtype = text_encoder.module.dtype else: dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, device=None, text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) if hasattr(text_encoder, "module"): dtype = text_encoder.module.dtype else: dtype = text_encoder.dtype # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( text_encoders, tokenizers, prompt: str, max_sequence_length, device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt if hasattr(text_encoders[0], "module"): dtype = text_encoders[0].module.dtype else: dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel class T5Embedder(torch.nn.Module): def __init__(self, device, max_length=300): super().__init__() self.device = device self.max_length = max_length dtype = torch.bfloat16 self.dtype = dtype t5_version = 'google/t5-v1_1-xxl' self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length) self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device) self.t5_encoder = self.t5_encoder.eval().requires_grad_(False) self.num_shared = max_length @torch.no_grad() def forward(self, text): if isinstance(text, str): text = [text] batch_encoding = self.t5_tokenizer( text, truncation=True, max_length=self.max_length, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) prompt_embeds = self.t5_encoder( input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=None, output_hidden_states=False, )['last_hidden_state'] prompt_attention_mask = batch_encoding['attention_mask'].to(self.device) new_text = [x.split('.')[0] for x in text] batch_encoding = self.t5_tokenizer( new_text, truncation=True, max_length=self.num_shared, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) shared_prompt_embeds = self.t5_encoder( input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=None, output_hidden_states=False, )['last_hidden_state'] return prompt_embeds, shared_prompt_embeds, prompt_attention_mask import random from torch.utils.checkpoint import checkpoint from peft import LoraConfig, set_peft_model_state_dict class LoraT5EmbedderNoGradientCheck(torch.nn.Module): def __init__(self, device, rank=64, max_length=300): super().__init__() self.device = device self.max_length = max_length dtype = torch.bfloat16 self.dtype = dtype t5_version = 'google/t5-v1_1-xxl' self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length) self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype) self.t5_encoder.gradient_checkpointing_enable() self.t5_encoder.config.gradient_checkpointing = True self.t5_encoder.requires_grad_(False) self.t5_encoder.eval() # Add LoRA adapters to the T5 model text_lora_config = LoraConfig( r=rank, lora_alpha=rank, lora_dropout=0.0, init_lora_weights="gaussian", target_modules=["SelfAttention.q", "SelfAttention.k", "SelfAttention.v", "SelfAttention.o", "DenseReluDense.wi", "DenseReluDense.wo"], ) self.t5_encoder.add_adapter(text_lora_config) #self.t5_encoder.encoder.embed_tokens.weight.requires_grad = True print(f"Gradient checkpointing enabled: {self.t5_encoder.is_gradient_checkpointing}") image_encoder_path = 'openai/clip-vit-large-patch14' self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device=device).to(torch.bfloat16) self.image_encoder = self.image_encoder.eval().requires_grad_(False) def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding): """ Compute group lasso for non-pad non-change tokens, L1 for change tokens, and group sparsity for pad non-change tokens. Args: prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim] perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim] replaced_ids: List of replaced token indices for each sample in batch batch_encoding: The tokenizer output containing input_ids Returns: l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor) l1_loss: L1 loss for change tokens (scalar tensor) pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor) """ batch_size = prompt_embeds.size(0) pad_token_id = self.t5_tokenizer.pad_token_id input_ids = batch_encoding["input_ids"] l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device) l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device) pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device) # Track valid samples for each loss type separately l1_valid_samples = 0 l2_valid_samples = 0 pad_valid_samples = 0 for i in range(batch_size): # Get the replaced index for this sample replaced_idx = replaced_ids[i] if replaced_idx is None: # No replacement happened (all padding), skip continue # Find padding and non-padding token indices pad_mask = input_ids[i] == pad_token_id non_pad_mask = ~pad_mask pad_indices = torch.where(pad_mask)[0] non_pad_indices = torch.where(non_pad_mask)[0] # Filter out the replaced index from non-padding indices (non-pad non-change) non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx] # Compute L1 loss on selected (replaced) index - CHANGE TOKEN selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx] l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean() l1_valid_samples += 1 # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens if len(non_selected_non_pad_indices) > 0: non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[ i, non_selected_non_pad_indices] l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1)) l2_loss_total = l2_loss_total + l2_per_token.mean() l2_valid_samples += 1 # Compute group sparsity loss on PAD NON-CHANGE tokens if len(pad_indices) > 0: pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices] # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero) pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1)) pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean() pad_valid_samples += 1 # Average over valid samples for each loss type l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) return l2_loss, l1_loss, pad_group_loss def forward(self, text, image=None): if isinstance(text, str): text = [text] batch_encoding = self.t5_tokenizer( text, truncation=True, max_length=self.max_length, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) prompt_embeds = self.t5_encoder( input_ids=batch_encoding["input_ids"].to(self.device), attention_mask=None, output_hidden_states=False, )['last_hidden_state'] # Get input_ids and create a copy to modify input_ids = batch_encoding["input_ids"].clone() batch_size = input_ids.size(0) # Get the padding token id pad_token_id = self.t5_tokenizer.pad_token_id replaced_ids = [] # For each sample in the batch for i in range(batch_size): # Find indices of non-padding tokens non_pad_mask = input_ids[i] != pad_token_id non_pad_indices = torch.where(non_pad_mask)[0] # If there are meaningful tokens, randomly select one to replace if len(non_pad_indices) > 0: # Randomly select an index from non-padding tokens random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)] # Replace with padding token input_ids[i, random_idx] = pad_token_id replaced_ids.append(random_idx.item()) else: replaced_ids.append(None) # No replacement if all tokens are padding perturbed_prompt_embeds = self.t5_encoder( input_ids=input_ids.to(self.device), attention_mask=None, output_hidden_states=False, )['last_hidden_state'] l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss( prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding ) with torch.no_grad(): if image is not None: clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds else: clip_image_embeds = None return prompt_embeds, l2_loss, l1_loss, pad_loss,clip_image_embeds from peft import LoraConfig, set_peft_model_state_dict import torch.utils.checkpoint as checkpoint from transformers import CLIPVisionModelWithProjection class LoraT5Embedder(torch.nn.Module): def __init__(self, device, rank=128, max_length=300, use_gradient_checkpointing=True): super().__init__() self.device = device self.max_length = max_length self.use_gradient_checkpointing = use_gradient_checkpointing dtype = torch.bfloat16 self.dtype = dtype t5_version = 'google/t5-v1_1-xxl' self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length) self.t5_encoder = T5EncoderModel.from_pretrained( t5_version, torch_dtype=dtype ).to(device=device).to(dtype) self.t5_encoder.requires_grad_(False) # Add LoRA adapters to the T5 model text_lora_config = LoraConfig( r=rank, lora_alpha=rank, lora_dropout=0.0, init_lora_weights="gaussian", target_modules=["q", "k", "v", "o", "wi", "wo"], ) self.t5_encoder.add_adapter(text_lora_config) self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True) # Manually implement gradient checkpointing for T5 encoder blocks if self.use_gradient_checkpointing: self._enable_gradient_checkpointing() print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}") image_encoder_path = 'openai/clip-vit-large-patch14' self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( image_encoder_path ).to(device=device).to(torch.bfloat16) self.image_encoder = self.image_encoder.eval().requires_grad_(False) def _enable_gradient_checkpointing(self): """ Manually wrap T5 encoder blocks with gradient checkpointing. """ def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward # Wrap each T5 block with checkpointing for block in self.t5_encoder.encoder.block: # Store original forward block._original_forward = block.forward # Create checkpointed forward def make_checkpointed_forward(blk): def checkpointed_forward(*args, **kwargs): # Checkpoint requires a function that takes tensors as input def forward_wrapper(*inputs): # Reconstruct kwargs from inputs hidden_states = inputs[0] attention_mask = inputs[1] if len(inputs) > 1 else None position_bias = inputs[2] if len(inputs) > 2 else None return blk._original_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_bias=position_bias, **{k: v for k, v in kwargs.items() if k not in ['hidden_states', 'attention_mask', 'position_bias']} ) # Prepare inputs for checkpointing hidden_states = kwargs.get('hidden_states', args[0] if args else None) attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None) position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None) # Use checkpoint checkpoint_inputs = [hidden_states] if attention_mask is not None: checkpoint_inputs.append(attention_mask) if position_bias is not None: checkpoint_inputs.append(position_bias) return checkpoint.checkpoint( forward_wrapper, *checkpoint_inputs, use_reentrant=False ) return checkpointed_forward block.forward = make_checkpointed_forward(block) def _encode_text(self, input_ids): """Helper function to encode text through T5.""" return self.t5_encoder( input_ids=input_ids.to(self.device), attention_mask=None, output_hidden_states=False, )['last_hidden_state'] def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding): """ Compute group lasso for non-pad non-change tokens, L1 for change tokens, and group sparsity for pad non-change tokens. Args: prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim] perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim] replaced_ids: List of replaced token indices for each sample in batch batch_encoding: The tokenizer output containing input_ids Returns: l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor) l1_loss: L1 loss for change tokens (scalar tensor) pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor) """ batch_size = prompt_embeds.size(0) pad_token_id = self.t5_tokenizer.pad_token_id input_ids = batch_encoding["input_ids"] l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device) l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device) pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device) # Track valid samples for each loss type separately l1_valid_samples = 0 l2_valid_samples = 0 pad_valid_samples = 0 for i in range(batch_size): # Get the replaced index for this sample replaced_idx = replaced_ids[i] if replaced_idx is None: # No replacement happened (all padding), skip continue # Find padding and non-padding token indices pad_mask = input_ids[i] == pad_token_id non_pad_mask = ~pad_mask pad_indices = torch.where(pad_mask)[0] non_pad_indices = torch.where(non_pad_mask)[0] # Filter out the replaced index from non-padding indices (non-pad non-change) non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx] # Compute L1 loss on selected (replaced) index - CHANGE TOKEN selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx] l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean() l1_valid_samples += 1 # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens if len(non_selected_non_pad_indices) > 0: non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[ i, non_selected_non_pad_indices] l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1)) l2_loss_total = l2_loss_total + l2_per_token.mean() l2_valid_samples += 1 # Compute group sparsity loss on PAD NON-CHANGE tokens if len(pad_indices) > 0: pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices] # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero) pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1)) pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean() pad_valid_samples += 1 # Average over valid samples for each loss type l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else torch.tensor(0.0, device=prompt_embeds.device) return l2_loss, l1_loss, pad_group_loss def forward(self, text, image=None): if isinstance(text, str): text = [text] batch_encoding = self.t5_tokenizer( text, truncation=True, max_length=self.max_length, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) attn_mask = batch_encoding["attention_mask"].to(self.device) # First encoding prompt_embeds = self._encode_text(batch_encoding["input_ids"]) # Get input_ids and create a copy to modify input_ids = batch_encoding["input_ids"].clone() batch_size = input_ids.size(0) # Get the padding token id # get the id for the first sentinel token mask_token = "" mask_token_id = self.t5_tokenizer.convert_tokens_to_ids(mask_token) pad_token_id = self.t5_tokenizer.pad_token_id replaced_ids = [] # For each sample in the batch for i in range(batch_size): # Find indices of non-padding tokens non_pad_mask = input_ids[i] != pad_token_id non_pad_indices = torch.where(non_pad_mask)[0] # If there are meaningful tokens, randomly select one to replace if len(non_pad_indices) > 0: # Randomly select an index from non-padding tokens random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)] random_idx2 = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)] # Replace with padding token input_ids[i, random_idx] = mask_token_id replaced_ids.append(random_idx.item()) else: replaced_ids.append(None) # No replacement if all tokens are padding # Second encoding with perturbed input perturbed_prompt_embeds = self._encode_text(input_ids) """ l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss( prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding ) """ with torch.no_grad(): if image is not None: clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds else: clip_image_embeds = None #return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask return prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, self.t5_tokenizer, batch_encoding from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer class QwenEmbedder(nn.Module): def __init__(self, device, max_length=512): super().__init__() self.device = device self.max_length = max_length dtype = torch.bfloat16 self.dtype = dtype self.tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True) self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=dtype, ).to(device=device) self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.prompt_template_encode_start_idx = 34 self.tokenizer_max_length = max_length def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): bool_mask = mask.bool() valid_lengths = bool_mask.sum(dim=1) selected = hidden_states[bool_mask] split_result = torch.split(selected, valid_lengths.tolist(), dim=0) return split_result def _get_qwen_prompt_embeds( self, prompt = None, device = None, dtype = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt template = self.prompt_template_encode drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] txt_tokens = self.tokenizer( txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" ).to(device) encoder_hidden_states = self.text_encoder( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True, ) hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] #max_seq_len = max([e.size(0) for e in split_hidden_states]) max_seq_len = self.max_length prompt_embeds = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] ) encoder_attention_mask = torch.stack( [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds, encoder_attention_mask @torch.no_grad() def forward(self, text): prompt_embeds, attention_mask = self._get_qwen_prompt_embeds( prompt=text, device=self.device, dtype=self.dtype, ) return prompt_embeds, attention_mask # pip install accelerate from transformers import AutoProcessor, Gemma3ForConditionalGeneration from PIL import Image import requests import torch import torch.nn as nn Qwen25VL_7b_PREFIX_edit = '''Given an user editing prompt and an source image, only describe the editing area and how they should change in a detailed way. Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: ''' Qwen25VL_7b_PREFIX_t2i = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt: - If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes. - If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n Here are examples of how to transform or refine prompts: - User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers. - User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations: User Prompt:''' Qwen25VL_7b_PREFIX_image = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate." model_id = "google/gemma-3-4b-it" from transformers import AutoTokenizer, TrainingArguments, Gemma3ForCausalLM, AutoModel, Gemma3Model from transformers import Dinov2Model, AutoImageProcessor import torch import torchvision.transforms as transforms import torchvision.models as models from PIL import Image import numpy as np class GemmaEmbedder(nn.Module): def __init__(self, max_sequence_length=300, model_id='google/gemma-3-4b-it'): super().__init__() device = torch.cuda.current_device() self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16) #self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16) self.processor = AutoProcessor.from_pretrained(model_id) self.device = device self.max_sequence_length = max_sequence_length #self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token # Use eos token as pad token self.processor.tokenizer.padding_side = "right" def get_features(self, hidden_states, input_ids): hidden_states = hidden_states[0] input_ids = input_ids[0].tolist() pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device) pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device) def find_last(lst, value): indices = [i for i, x in enumerate(lst) if x == value] return indices[-1] if 256000 in input_ids: text_start = input_ids.index(256000)+2 else: text_start = find_last(input_ids, 108)+1 text_end = len(input_ids)-6 bos_embed = hidden_states[:2] text_embeds = hidden_states[text_start:text_end + 1] text_embeds = torch.cat([bos_embed, text_embeds], dim=0) pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length] pad_text_mask[:len(text_embeds)] = 1.0 image_embeds = hidden_states[np.array(input_ids) == self.processor.tokenizer.image_token_id] """ print(input_ids) print(input_ids[text_start:text_end + 1]) decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False) print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2]) print("Text embeddings shape:", text_embeds.shape) norm = RMSNorm(2560, eps=1e-6).to(self.device).to(torch.bfloat16) print(text_embeds, ' >>> ext embeds') print(norm(text_embeds), ' >>> normed embeds') """ return image_embeds, pad_text_embeds, pad_text_mask @torch.no_grad() def forward(self, caps, images=None): text_embeds = [] text_masks = [] full_image_embeds = [] device = self.model.device if images is None: images = [None] * len(caps) for cap,img in zip(caps, images): if img is not None: messages = [ { "role": "system", "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}] }, { "role": "user", "content": [ {"type": "image", "image": img}, {"type": "text", "text": cap}, ] } ] else: messages = [ { "role": "system", "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}] }, { "role": "user", "content": [ {"type": "text", "text": cap}, ] } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", max_length = 640, truncation = True, ).to(self.model.device, dtype=torch.bfloat16) outputs = self.model(**inputs, output_hidden_states=True) #sample_image_embeds = outputs.image_hidden_states sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], [] for hidden in [outputs.hidden_states[-1]]: cur_image_embeds, cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"]) sample_text_embeds.append(cur_text_embeds) sample_text_mask.append(cur_text_mask) sample_image_embeds.append(cur_image_embeds) text_embeds.append(torch.cat(sample_text_embeds, dim=0)) text_masks.append(torch.cat(sample_text_mask, dim=0)) #full_image_embeds.append(sample_image_embeds) full_image_embeds.append(torch.cat(sample_image_embeds, dim=0)) """ input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False) generation = generation[0][input_len:] decoded = self.processor.decode(generation, skip_special_tokens=True) print(cap, ' <>>> gemma ',decoded) """ text_embeds = torch.stack(text_embeds, dim=0) text_masks = torch.stack(text_masks, dim=0) full_image_embeds = torch.stack(full_image_embeds, dim=0) return { 'text_embeds': text_embeds, 'text_masks': text_masks, 'image_embeds': full_image_embeds, } class GemmaTextEmbedder(nn.Module): def __init__(self, device, max_sequence_length=300, model_id='./gemma-3-4b-it'): super().__init__() self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16) #self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16) self.processor = AutoProcessor.from_pretrained(model_id) self.real_device = device self.max_sequence_length = max_sequence_length #self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token # Use eos token as pad token self.processor.tokenizer.padding_side = "right" @property def dtype(self): """Return the dtype of the model parameters.""" return next(self.parameters()).dtype @property def device(self): """Return the device of the model parameters.""" return next(self.parameters()).device def get_features(self, hidden_states, input_ids): hidden_states = hidden_states[0] input_ids = input_ids[0].tolist() pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device) pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device) def find_last(lst, value): indices = [i for i, x in enumerate(lst) if x == value] return indices[-1] if 256000 in input_ids: text_start = input_ids.index(256000)+2 else: text_start = find_last(input_ids, 108)+1 text_end = len(input_ids)-6 bos_embed = hidden_states[:2] text_embeds = hidden_states[text_start:text_end + 1] text_embeds = torch.cat([bos_embed, text_embeds], dim=0) pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length] pad_text_mask[:len(text_embeds)] = 1.0 pad_text_embeds[len(text_embeds):, :] = 0.0 pad_text_mask[len(text_embeds):] = 0.0 """ print(input_ids) print(input_ids[text_start:text_end + 1]) decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False) print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2]) print("Text embeddings shape:", text_embeds.shape) print(text_embeds, ' >>> ext embeds') """ return pad_text_embeds, pad_text_mask @torch.no_grad() def forward(self, caps, images=None): text_embeds = [] text_masks = [] full_image_embeds = [] device = self.model.device if isinstance(caps, str): caps = [caps] if images is None: images = [None] * len(caps) for cap,img in zip(caps, images): if img is not None: messages = [ { "role": "system", "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}] }, { "role": "user", "content": [ {"type": "image", "image": img}, {"type": "text", "text": cap}, ] } ] else: messages = [ { "role": "system", "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}] }, { "role": "user", "content": [ {"type": "text", "text": cap}, ] } ] inputs = self.processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", max_length = 640, truncation = True, ).to(self.model.device, dtype=torch.bfloat16) outputs = self.model(**inputs, output_hidden_states=True) #sample_image_embeds = outputs.image_hidden_states sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], [] for hidden in [outputs.hidden_states[-1]]: cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"]) sample_text_embeds.append(cur_text_embeds) sample_text_mask.append(cur_text_mask) text_embeds.append(torch.cat(sample_text_embeds, dim=0)) text_masks.append(torch.cat(sample_text_mask, dim=0)) """ input_len = inputs["input_ids"].shape[-1] with torch.inference_mode(): generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False) generation = generation[0][input_len:] decoded = self.processor.decode(generation, skip_special_tokens=True) print(cap, ' <>>> gemma ',decoded) """ text_embeds = torch.stack(text_embeds, dim=0) text_masks = torch.stack(text_masks, dim=0) return text_embeds, text_masks.to(text_embeds.dtype) from transformers import AutoModel, AutoTokenizer from transformers import SiglipVisionModel, AutoProcessor class Gemma2Embedder(nn.Module): def __init__(self, max_length=300): super().__init__() self.text_encoder = AutoModel.from_pretrained( "google/gemma-2-2b", torch_dtype=torch.bfloat16, ).to(torch.cuda.current_device()).to(torch.bfloat16).eval() self.tokenizer = AutoTokenizer.from_pretrained( "google/gemma-2-2b", ) self.tokenizer.padding_side = "right" self.max_length = max_length self.system_prompt = "You are an assistant designed to edit images faithfully based on user prompts. " system_ids = self.tokenizer( self.system_prompt, return_tensors="pt", add_special_tokens=True, max_length=self.max_length, padding="max_length", truncation=True, ).input_ids.flatten().view(-1).numpy().tolist() self.len_system_prompt = system_ids.index(self.tokenizer.pad_token_id)-1 self.weight_dtype = torch.bfloat16 @torch.no_grad() def forward(self, caption): if isinstance(caption, str): caption = [caption] caption = [self.system_prompt + c for c in caption] text_inputs = self.tokenizer( caption, return_tensors="pt", add_special_tokens=True, max_length=self.max_length+self.len_system_prompt, padding="max_length", truncation=True, ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask text_input_ids = text_input_ids.to(self.text_encoder.device) attention_mask = attention_mask.to(self.text_encoder.device) embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask, output_hidden_states=True ).hidden_states[-2] embeds = embeds[:, self.len_system_prompt:, :] attention_mask = attention_mask[:, self.len_system_prompt:] return { 'text_embeds': embeds, 'text_masks': attention_mask, } class T5TextEmbedder(nn.Module): def __init__(self, device, pretrained_path="google/flan-t5-xxl", max_length=300): super().__init__() self.model = T5EncoderModel.from_pretrained(pretrained_path).to(device=device).to(torch.bfloat16) self.tokenizer = T5Tokenizer.from_pretrained(pretrained_path) self.max_length = max_length self.model.eval() self.model.requires_grad_(False) @property def dtype(self): """Return the dtype of the model parameters.""" return next(self.parameters()).dtype @property def device(self): """Return the device of the model parameters.""" return next(self.parameters()).device def forward( self, caption ): max_length = self.max_length text_inputs = self.tokenizer( caption, return_tensors="pt", add_special_tokens=True, max_length=max_length, padding="max_length", truncation=True, ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask text_input_ids = text_input_ids.to(self.model.device) attention_mask = attention_mask.to(self.model.device) outputs = self.model(text_input_ids, attention_mask=attention_mask) embeddings = outputs.last_hidden_state return embeddings, attention_mask.to(embeddings.dtype) if __name__ == '__main__': from datasets import load_dataset dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]') item = dataset[0:4] another_item = dataset[0:4] from diffusers.models.normalization import RMSNorm image_encoder = CLIPImageEncoder(device="cuda:0") clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14") image_embeds = image_encoder(clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16)) print(image_embeds.shape, ' >>>> image embeds') #model = GemmaTextEmbedder(device="cuda:0") model = LoraT5Embedder(device="cuda:0") prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask = model( [ """A heartwarming 3D rendered scene of an elderly farmer and a tiny orange kitten. The farmer, with a gentle smile, walks alongside the kitten in a lush, green garden filled with thriving plants, showcasing a fruitful harvest. The intricate details of the overalls and the farmer's worn, weathered face tell a story of years spent tending to the land, the farmer is wearing a blue shirt""", ], image=clip_processor(images=item['image'], return_tensors="pt").pixel_values.to("cuda:0").to(torch.bfloat16 )) print(l2_loss, ' >>> l2 loss ', l1_loss, ' >>> l1 loss ', pad_loss, ' >>> pad loss ') print(clip_image_embeds.shape, ' >>> clip image embeds ') #print(gemma_dict['text_embeds'],) #print(gemma_dict['image_embeds'], ' >>> image embeds') """ from dataset import create_loader from PIL import Image as PILImage from PIL import Image as PILImage import PIL import numpy as np import torch.nn.functional as F loader = create_loader('edit', batch_size=16, shuffle=False) batch = next(iter(loader)) source = batch['source_images'] source_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in source] target = batch['target_images'] target_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in target] from torchvision.utils import save_image print(batch['captions']) images = [] for (x, y) in zip(batch['source_images'], batch['target_images']): images.append(x) images.append(y) save_image((torch.stack(images) + 1) / 2, 'example_pairs.jpg', nrow=8) gemma_dict = model(batch['captions'], source_pils, target_pils) image_embeds = gemma_dict['image_embeds'] target_image_embeds = gemma_dict['target_image_embeds'] print("Image embeds shape:", image_embeds.shape) print("Target image embeds shape:", target_image_embeds.shape) from qwen import compute_and_save_similarity_grid compute_and_save_similarity_grid(image_embeds, target_image_embeds, "gemma_similarity_grid.jpg") """