import torch
import torch.nn.functional as F
import numpy as np
import torch.nn as nn


def tokenize_prompt(tokenizer, prompt, max_sequence_length):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
def _encode_prompt_with_t5(
    text_encoder,
    tokenizer,
    max_sequence_length=512,
    prompt=None,
    num_images_per_prompt=1,
    device=None,
    text_input_ids=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Unwrap DDP/accelerate-wrapped encoders to read the dtype.
    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
def _encode_prompt_with_clip(
    text_encoder,
    tokenizer,
    prompt: str,
    device=None,
    text_input_ids=None,
    num_images_per_prompt: int = 1,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if tokenizer is not None:
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
    else:
        if text_input_ids is None:
            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")

    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    if hasattr(text_encoder, "module"):
        dtype = text_encoder.module.dtype
    else:
        dtype = text_encoder.dtype

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds
def encode_prompt(
    text_encoders,
    tokenizers,
    prompt: str,
    max_sequence_length,
    device=None,
    num_images_per_prompt: int = 1,
    text_input_ids_list=None,
):
    prompt = [prompt] if isinstance(prompt, str) else prompt

    if hasattr(text_encoders[0], "module"):
        dtype = text_encoders[0].module.dtype
    else:
        dtype = text_encoders[0].dtype

    pooled_prompt_embeds = _encode_prompt_with_clip(
        text_encoder=text_encoders[0],
        tokenizer=tokenizers[0],
        prompt=prompt,
        device=device if device is not None else text_encoders[0].device,
        num_images_per_prompt=num_images_per_prompt,
        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
    )

    prompt_embeds = _encode_prompt_with_t5(
        text_encoder=text_encoders[1],
        tokenizer=tokenizers[1],
        max_sequence_length=max_sequence_length,
        prompt=prompt,
        num_images_per_prompt=num_images_per_prompt,
        device=device if device is not None else text_encoders[1].device,
        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
    )

    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
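# Usage sketch for the Flux-style dual-encoder path above (untested; the checkpoint
# name and subfolder layout are assumptions about the standard FLUX.1-dev repo and
# are not part of this file):
#
#   from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
#
#   repo = "black-forest-labs/FLUX.1-dev"
#   clip_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
#   clip_tok = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
#   t5_encoder = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2")
#   t5_tok = T5TokenizerFast.from_pretrained(repo, subfolder="tokenizer_2")
#
#   prompt_embeds, pooled_embeds, text_ids = encode_prompt(
#       [clip_encoder, t5_encoder], [clip_tok, t5_tok],
#       prompt="a red bicycle leaning on a brick wall",
#       max_sequence_length=512, device="cuda",
#   )
#   # prompt_embeds: [1, 512, d_t5]; pooled_embeds: [1, d_clip]; text_ids: [512, 3]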
from transformers import T5EncoderModel, T5Tokenizer, CLIPTokenizer, CLIPTextModel


class T5Embedder(torch.nn.Module):
    def __init__(self, device, max_length=300):
        super().__init__()
        self.device = device
        self.max_length = max_length
        dtype = torch.bfloat16
        self.dtype = dtype
        t5_version = 'google/t5-v1_1-xxl'
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
        self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device)
        self.t5_encoder = self.t5_encoder.eval().requires_grad_(False)
        # Length budget for the "shared" (first-sentence) encoding below.
        self.num_shared = max_length

    def forward(self, text):
        if isinstance(text, str):
            text = [text]

        batch_encoding = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        prompt_embeds = self.t5_encoder(
            input_ids=batch_encoding["input_ids"].to(self.device),
            attention_mask=None,
            output_hidden_states=False,
        )['last_hidden_state']
        prompt_attention_mask = batch_encoding['attention_mask'].to(self.device)

        # Encode only the first sentence of each prompt as the "shared" embedding.
        new_text = [x.split('.')[0] for x in text]
        batch_encoding = self.t5_tokenizer(
            new_text,
            truncation=True,
            max_length=self.num_shared,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        shared_prompt_embeds = self.t5_encoder(
            input_ids=batch_encoding["input_ids"].to(self.device),
            attention_mask=None,
            output_hidden_states=False,
        )['last_hidden_state']

        return prompt_embeds, shared_prompt_embeds, prompt_attention_mask
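# Usage sketch (untested; downloads google/t5-v1_1-xxl, a large checkpoint):
#
#   embedder = T5Embedder(device="cuda:0", max_length=300)
#   full, shared, mask = embedder("A dog runs on the beach. The sky is overcast.")
#   # full:   [1, 300, 4096] - embedding of the whole prompt (4096 = t5-v1_1-xxl hidden size)
#   # shared: [1, 300, 4096] - embedding of the first sentence only ("A dog runs on the beach")
#   # mask:   [1, 300]       - attention mask of the full prompt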
import random
from torch.utils.checkpoint import checkpoint
from peft import LoraConfig, set_peft_model_state_dict
from transformers import CLIPVisionModelWithProjection


class LoraT5EmbedderNoGradientCheck(torch.nn.Module):
    def __init__(self, device, rank=64, max_length=300):
        super().__init__()
        self.device = device
        self.max_length = max_length
        dtype = torch.bfloat16
        self.dtype = dtype
        t5_version = 'google/t5-v1_1-xxl'
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
        self.t5_encoder = T5EncoderModel.from_pretrained(t5_version, torch_dtype=dtype).to(device=device).to(dtype)
        self.t5_encoder.gradient_checkpointing_enable()
        self.t5_encoder.config.gradient_checkpointing = True
        self.t5_encoder.requires_grad_(False)
        self.t5_encoder.eval()

        # Add LoRA adapters to the T5 model. Note: t5-v1_1 uses a gated FFN, so the
        # feed-forward projections are named wi_0/wi_1 rather than wi.
        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=[
                "SelfAttention.q", "SelfAttention.k", "SelfAttention.v", "SelfAttention.o",
                "DenseReluDense.wi_0", "DenseReluDense.wi_1", "DenseReluDense.wo",
            ],
        )
        self.t5_encoder.add_adapter(text_lora_config)
        # self.t5_encoder.encoder.embed_tokens.weight.requires_grad = True
        print(f"Gradient checkpointing enabled: {self.t5_encoder.is_gradient_checkpointing}")

        image_encoder_path = 'openai/clip-vit-large-patch14'
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path).to(device=device).to(torch.bfloat16)
        self.image_encoder = self.image_encoder.eval().requires_grad_(False)
    def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
        """
        Compute group lasso for non-pad non-change tokens, L1 for change tokens,
        and group sparsity for pad non-change tokens.

        Args:
            prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
            perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
            replaced_ids: List of replaced token indices for each sample in batch
            batch_encoding: The tokenizer output containing input_ids

        Returns:
            l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
            l1_loss: L1 loss for change tokens (scalar tensor)
            pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
        """
        batch_size = prompt_embeds.size(0)
        pad_token_id = self.t5_tokenizer.pad_token_id
        input_ids = batch_encoding["input_ids"]

        l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
        l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
        pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)

        # Track valid samples for each loss type separately
        l1_valid_samples = 0
        l2_valid_samples = 0
        pad_valid_samples = 0

        for i in range(batch_size):
            # Get the replaced index for this sample
            replaced_idx = replaced_ids[i]
            if replaced_idx is None:
                # No replacement happened (all padding), skip
                continue

            # Find padding and non-padding token indices
            pad_mask = input_ids[i] == pad_token_id
            non_pad_mask = ~pad_mask
            pad_indices = torch.where(pad_mask)[0]
            non_pad_indices = torch.where(non_pad_mask)[0]

            # Filter out the replaced index from non-padding indices (non-pad non-change)
            non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]

            # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
            selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
            l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
            l1_valid_samples += 1

            # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
            if len(non_selected_non_pad_indices) > 0:
                non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[i, non_selected_non_pad_indices]
                l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
                l2_loss_total = l2_loss_total + l2_per_token.mean()
                l2_valid_samples += 1

            # Compute group sparsity loss on PAD NON-CHANGE tokens
            if len(pad_indices) > 0:
                pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
                # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
                pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
                pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
                pad_valid_samples += 1

        # Average over valid samples for each loss type
        zero = torch.tensor(0.0, device=prompt_embeds.device)
        l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else zero
        l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else zero
        pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else zero

        return l2_loss, l1_loss, pad_group_loss
    def forward(self, text, image=None):
        if isinstance(text, str):
            text = [text]

        batch_encoding = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        prompt_embeds = self.t5_encoder(
            input_ids=batch_encoding["input_ids"].to(self.device),
            attention_mask=None,
            output_hidden_states=False,
        )['last_hidden_state']

        # Get input_ids and create a copy to modify
        input_ids = batch_encoding["input_ids"].clone()
        batch_size = input_ids.size(0)

        # Get the padding token id
        pad_token_id = self.t5_tokenizer.pad_token_id
        replaced_ids = []

        # For each sample in the batch
        for i in range(batch_size):
            # Find indices of non-padding tokens
            non_pad_mask = input_ids[i] != pad_token_id
            non_pad_indices = torch.where(non_pad_mask)[0]

            # If there are meaningful tokens, randomly select one to replace
            if len(non_pad_indices) > 0:
                # Randomly select an index from non-padding tokens
                random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
                # Replace with padding token
                input_ids[i, random_idx] = pad_token_id
                replaced_ids.append(random_idx.item())
            else:
                replaced_ids.append(None)  # No replacement if all tokens are padding

        perturbed_prompt_embeds = self.t5_encoder(
            input_ids=input_ids.to(self.device),
            attention_mask=None,
            output_hidden_states=False,
        )['last_hidden_state']

        l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
            prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
        )

        with torch.no_grad():
            if image is not None:
                clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
            else:
                clip_image_embeds = None

        return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds
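# A minimal training-step sketch showing how the three perturbation losses above could
# be combined into a regularizer. The weights (0.1 / 1.0 / 0.01) and main_loss are
# hypothetical placeholders, not values taken from this file:
#
#   embedder = LoraT5EmbedderNoGradientCheck(device="cuda:0", rank=64)
#   optimizer = torch.optim.AdamW(
#       [p for p in embedder.parameters() if p.requires_grad], lr=1e-4
#   )
#   prompt_embeds, l2_loss, l1_loss, pad_loss, _ = embedder(["a cat on a mat"])
#   reg = 0.1 * l2_loss + 1.0 * l1_loss + 0.01 * pad_loss
#   loss = main_loss(prompt_embeds) + reg   # main_loss: your downstream objective
#   loss.backward()
#   optimizer.step()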
from peft import LoraConfig, set_peft_model_state_dict
import torch.utils.checkpoint as checkpoint
from transformers import CLIPVisionModelWithProjection


class LoraT5Embedder(torch.nn.Module):
    def __init__(self, device, rank=128, max_length=300, use_gradient_checkpointing=True):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.use_gradient_checkpointing = use_gradient_checkpointing
        dtype = torch.bfloat16
        self.dtype = dtype
        t5_version = 'google/t5-v1_1-xxl'
        self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_version, max_length=max_length)
        self.t5_encoder = T5EncoderModel.from_pretrained(
            t5_version,
            torch_dtype=dtype
        ).to(device=device).to(dtype)
        self.t5_encoder.requires_grad_(False)

        # Add LoRA adapters to the T5 model. t5-v1_1 uses a gated FFN, so the
        # feed-forward projections are wi_0/wi_1 (a plain "wi" would match nothing).
        text_lora_config = LoraConfig(
            r=rank,
            lora_alpha=rank,
            lora_dropout=0.0,
            init_lora_weights="gaussian",
            target_modules=["q", "k", "v", "o", "wi_0", "wi_1", "wo"],
        )
        self.t5_encoder.add_adapter(text_lora_config)
        self.t5_encoder.encoder.embed_tokens.weight.requires_grad_(True)

        # Manually implement gradient checkpointing for T5 encoder blocks
        if self.use_gradient_checkpointing:
            self._enable_gradient_checkpointing()
        print(f"Gradient checkpointing enabled: {self.use_gradient_checkpointing}")

        image_encoder_path = 'openai/clip-vit-large-patch14'
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            image_encoder_path
        ).to(device=device).to(torch.bfloat16)
        self.image_encoder = self.image_encoder.eval().requires_grad_(False)
    def _enable_gradient_checkpointing(self):
        """
        Manually wrap T5 encoder blocks with gradient checkpointing.
        """
        # Wrap each T5 block with checkpointing
        for block in self.t5_encoder.encoder.block:
            # Store original forward
            block._original_forward = block.forward

            # Create checkpointed forward
            def make_checkpointed_forward(blk):
                def checkpointed_forward(*args, **kwargs):
                    # Checkpoint requires a function that takes tensors as input
                    def forward_wrapper(*inputs):
                        # Reconstruct kwargs from inputs
                        hidden_states = inputs[0]
                        attention_mask = inputs[1] if len(inputs) > 1 else None
                        position_bias = inputs[2] if len(inputs) > 2 else None
                        return blk._original_forward(
                            hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            position_bias=position_bias,
                            **{k: v for k, v in kwargs.items()
                               if k not in ['hidden_states', 'attention_mask', 'position_bias']}
                        )

                    # Prepare inputs for checkpointing
                    hidden_states = kwargs.get('hidden_states', args[0] if args else None)
                    attention_mask = kwargs.get('attention_mask', args[1] if len(args) > 1 else None)
                    position_bias = kwargs.get('position_bias', args[2] if len(args) > 2 else None)

                    # Pass all three positionally (None is allowed with use_reentrant=False)
                    # so forward_wrapper never misreads position_bias as attention_mask.
                    return checkpoint.checkpoint(
                        forward_wrapper,
                        hidden_states,
                        attention_mask,
                        position_bias,
                        use_reentrant=False
                    )

                return checkpointed_forward

            block.forward = make_checkpointed_forward(block)

    def _encode_text(self, input_ids):
        """Helper function to encode text through T5."""
        return self.t5_encoder(
            input_ids=input_ids.to(self.device),
            attention_mask=None,
            output_hidden_states=False,
        )['last_hidden_state']
    def compute_perturbation_loss(self, prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding):
        """
        Compute group lasso for non-pad non-change tokens, L1 for change tokens,
        and group sparsity for pad non-change tokens.

        Args:
            prompt_embeds: Original embeddings [batch_size, seq_len, hidden_dim]
            perturbed_prompt_embeds: Perturbed embeddings [batch_size, seq_len, hidden_dim]
            replaced_ids: List of replaced token indices for each sample in batch
            batch_encoding: The tokenizer output containing input_ids

        Returns:
            l2_loss: Group lasso loss for non-pad non-change tokens (scalar tensor)
            l1_loss: L1 loss for change tokens (scalar tensor)
            pad_group_loss: Group sparsity loss for pad non-change tokens (scalar tensor)
        """
        batch_size = prompt_embeds.size(0)
        pad_token_id = self.t5_tokenizer.pad_token_id
        input_ids = batch_encoding["input_ids"]

        l2_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
        l1_loss_total = torch.tensor(0.0, device=prompt_embeds.device)
        pad_group_loss_total = torch.tensor(0.0, device=prompt_embeds.device)

        # Track valid samples for each loss type separately
        l1_valid_samples = 0
        l2_valid_samples = 0
        pad_valid_samples = 0

        for i in range(batch_size):
            # Get the replaced index for this sample
            replaced_idx = replaced_ids[i]
            if replaced_idx is None:
                # No replacement happened (all padding), skip
                continue

            # Find padding and non-padding token indices
            pad_mask = input_ids[i] == pad_token_id
            non_pad_mask = ~pad_mask
            pad_indices = torch.where(pad_mask)[0]
            non_pad_indices = torch.where(non_pad_mask)[0]

            # Filter out the replaced index from non-padding indices (non-pad non-change)
            non_selected_non_pad_indices = non_pad_indices[non_pad_indices != replaced_idx]

            # Compute L1 loss on selected (replaced) index - CHANGE TOKEN
            selected_diff = prompt_embeds[i, replaced_idx] - perturbed_prompt_embeds[i, replaced_idx]
            l1_loss_total = l1_loss_total + torch.abs(selected_diff).mean()
            l1_valid_samples += 1

            # Compute group lasso (L2) loss on NON-PAD NON-CHANGE tokens
            if len(non_selected_non_pad_indices) > 0:
                non_selected_diff = prompt_embeds[i, non_selected_non_pad_indices] - perturbed_prompt_embeds[i, non_selected_non_pad_indices]
                l2_per_token = torch.sqrt((non_selected_diff ** 2).sum(dim=1))
                l2_loss_total = l2_loss_total + l2_per_token.mean()
                l2_valid_samples += 1

            # Compute group sparsity loss on PAD NON-CHANGE tokens
            if len(pad_indices) > 0:
                pad_diff = prompt_embeds[i, pad_indices] - perturbed_prompt_embeds[i, pad_indices]
                # Group sparsity: L2 norm per token (encourages entire token embeddings to be zero)
                pad_group_per_token = torch.sqrt((pad_diff ** 2).sum(dim=1))
                pad_group_loss_total = pad_group_loss_total + pad_group_per_token.mean()
                pad_valid_samples += 1

        # Average over valid samples for each loss type
        zero = torch.tensor(0.0, device=prompt_embeds.device)
        l2_loss = l2_loss_total / l2_valid_samples if l2_valid_samples > 0 else zero
        l1_loss = l1_loss_total / l1_valid_samples if l1_valid_samples > 0 else zero
        pad_group_loss = pad_group_loss_total / pad_valid_samples if pad_valid_samples > 0 else zero

        return l2_loss, l1_loss, pad_group_loss
    def forward(self, text, image=None):
        if isinstance(text, str):
            text = [text]

        batch_encoding = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        attn_mask = batch_encoding["attention_mask"].to(self.device)

        # First encoding
        prompt_embeds = self._encode_text(batch_encoding["input_ids"])

        # Get input_ids and create a copy to modify
        input_ids = batch_encoding["input_ids"].clone()
        batch_size = input_ids.size(0)

        # Use the id of the first T5 sentinel token as the replacement (mask) token
        mask_token = "<extra_id_0>"
        mask_token_id = self.t5_tokenizer.convert_tokens_to_ids(mask_token)
        pad_token_id = self.t5_tokenizer.pad_token_id
        replaced_ids = []

        # For each sample in the batch
        for i in range(batch_size):
            # Find indices of non-padding tokens
            non_pad_mask = input_ids[i] != pad_token_id
            non_pad_indices = torch.where(non_pad_mask)[0]

            # If there are meaningful tokens, randomly select one to replace
            if len(non_pad_indices) > 0:
                # Randomly select an index from non-padding tokens
                random_idx = non_pad_indices[random.randint(0, len(non_pad_indices) - 1)]
                # Replace it with the sentinel (mask) token
                input_ids[i, random_idx] = mask_token_id
                replaced_ids.append(random_idx.item())
            else:
                replaced_ids.append(None)  # No replacement if all tokens are padding

        # Second encoding with perturbed input
        perturbed_prompt_embeds = self._encode_text(input_ids)

        """
        l2_loss, l1_loss, pad_loss = self.compute_perturbation_loss(
            prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
        )
        """

        with torch.no_grad():
            if image is not None:
                clip_image_embeds = self.image_encoder(image.to(self.device)).image_embeds
            else:
                clip_image_embeds = None

        # return prompt_embeds, l2_loss, l1_loss, pad_loss, clip_image_embeds, attn_mask
        return prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, self.t5_tokenizer, batch_encoding
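# Sketch of consuming LoraT5Embedder's return tuple: the perturbation losses are no
# longer computed inside forward(), so the caller applies compute_perturbation_loss
# itself. Untested; assumes a CUDA device and omits the optional CLIP image input.
#
#   embedder = LoraT5Embedder(device="cuda:0", rank=128)
#   (prompt_embeds, clip_image_embeds, perturbed_prompt_embeds,
#    replaced_ids, tokenizer, batch_encoding) = embedder(["a cat on a mat"])
#   l2_loss, l1_loss, pad_loss = embedder.compute_perturbation_loss(
#       prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
#   )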
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer


class QwenEmbedder(nn.Module):
    def __init__(self, device, max_length=512):
        super().__init__()
        self.device = device
        self.max_length = max_length
        dtype = torch.bfloat16
        self.dtype = dtype
        self.tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
        self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=dtype,
        ).to(device=device)
        self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        # Number of tokenized template tokens preceding the user prompt; dropped below.
        self.prompt_template_encode_start_idx = 34
        self.tokenizer_max_length = max_length

    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
        bool_mask = mask.bool()
        valid_lengths = bool_mask.sum(dim=1)
        selected = hidden_states[bool_mask]
        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
        return split_result

    def _get_qwen_prompt_embeds(
        self,
        prompt=None,
        device=None,
        dtype=None,
    ):
        device = device or self.device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        template = self.prompt_template_encode
        drop_idx = self.prompt_template_encode_start_idx
        txt = [template.format(e) for e in prompt]
        txt_tokens = self.tokenizer(
            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        encoder_hidden_states = self.text_encoder(
            input_ids=txt_tokens.input_ids,
            attention_mask=txt_tokens.attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_hidden_states.hidden_states[-1]
        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
        # max_seq_len = max([e.size(0) for e in split_hidden_states])
        max_seq_len = self.max_length
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
        )
        encoder_attention_mask = torch.stack(
            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
        )

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        return prompt_embeds, encoder_attention_mask

    def forward(self, text):
        prompt_embeds, attention_mask = self._get_qwen_prompt_embeds(
            prompt=text,
            device=self.device,
            dtype=self.dtype,
        )
        return prompt_embeds, attention_mask
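# Usage sketch (untested; downloads Qwen/Qwen2.5-VL-7B-Instruct):
#
#   embedder = QwenEmbedder(device="cuda:0", max_length=512)
#   embeds, mask = embedder(["a watercolor painting of a lighthouse at dawn"])
#   # embeds: [1, 512, hidden_size] last-layer states with the 34 template tokens dropped
#   # mask:   [1, 512] - 1 for real prompt tokens, 0 for the zero-padding added above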
# pip install accelerate
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from PIL import Image
import requests
import torch
import torch.nn as nn

Qwen25VL_7b_PREFIX_edit = '''Given a user editing prompt and a source image, describe only the editing area and how it should change, in a detailed way.
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
'''

Qwen25VL_7b_PREFIX_t2i = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''

Qwen25VL_7b_PREFIX_image = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."

model_id = "google/gemma-3-4b-it"

from transformers import AutoTokenizer, TrainingArguments, Gemma3ForCausalLM, AutoModel, Gemma3Model
from transformers import Dinov2Model, AutoImageProcessor
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import numpy as np
class GemmaEmbedder(nn.Module):
    def __init__(self, max_sequence_length=300, model_id='google/gemma-3-4b-it'):
        super().__init__()
        device = torch.cuda.current_device()
        self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
        # self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.device = device
        self.max_sequence_length = max_sequence_length
        # self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token  # Use eos token as pad token
        self.processor.tokenizer.padding_side = "right"

    def get_features(self, hidden_states, input_ids):
        hidden_states = hidden_states[0]
        input_ids = input_ids[0].tolist()
        # 2560 is the hidden size of gemma-3-4b-it.
        pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
        pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)

        def find_last(lst, value):
            indices = [i for i, x in enumerate(lst) if x == value]
            return indices[-1]

        # Locate where the user text begins: right after token id 256000 if present
        # (image case), otherwise after the last occurrence of token id 108.
        if 256000 in input_ids:
            text_start = input_ids.index(256000) + 2
        else:
            text_start = find_last(input_ids, 108) + 1
        # Drop the trailing chat-template tokens appended by apply_chat_template.
        text_end = len(input_ids) - 6

        bos_embed = hidden_states[:2]
        text_embeds = hidden_states[text_start:text_end + 1]
        text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
        # Slices are clamped, so sequences longer than max_sequence_length are truncated.
        pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
        pad_text_mask[:len(text_embeds)] = 1.0

        image_embeds = hidden_states[np.array(input_ids) == self.processor.tokenizer.image_token_id]

        """
        print(input_ids)
        print(input_ids[text_start:text_end + 1])
        decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
        print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
        print("Text embeddings shape:", text_embeds.shape)
        norm = RMSNorm(2560, eps=1e-6).to(self.device).to(torch.bfloat16)
        print(text_embeds, ' >>> ext embeds')
        print(norm(text_embeds), ' >>> normed embeds')
        """

        return image_embeds, pad_text_embeds, pad_text_mask

    def forward(self, caps, images=None):
        text_embeds = []
        text_masks = []
        full_image_embeds = []
        device = self.model.device
        if images is None:
            images = [None] * len(caps)
        for cap, img in zip(caps, images):
            if img is not None:
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": img},
                            {"type": "text", "text": cap},
                        ]
                    }
                ]
            else:
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": cap},
                        ]
                    }
                ]
            inputs = self.processor.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True,
                return_dict=True, return_tensors="pt",
                max_length=640,
                truncation=True,
            ).to(self.model.device, dtype=torch.bfloat16)

            outputs = self.model(**inputs, output_hidden_states=True)
            # sample_image_embeds = outputs.image_hidden_states
            sample_text_embeds, sample_text_mask, sample_image_embeds = [], [], []
            for hidden in [outputs.hidden_states[-1]]:
                cur_image_embeds, cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
                sample_text_embeds.append(cur_text_embeds)
                sample_text_mask.append(cur_text_mask)
                sample_image_embeds.append(cur_image_embeds)
            text_embeds.append(torch.cat(sample_text_embeds, dim=0))
            text_masks.append(torch.cat(sample_text_mask, dim=0))
            # full_image_embeds.append(sample_image_embeds)
            full_image_embeds.append(torch.cat(sample_image_embeds, dim=0))

            """
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]
            decoded = self.processor.decode(generation, skip_special_tokens=True)
            print(cap, ' <>>> gemma ', decoded)
            """

        text_embeds = torch.stack(text_embeds, dim=0)
        text_masks = torch.stack(text_masks, dim=0)
        full_image_embeds = torch.stack(full_image_embeds, dim=0)
        return {
            'text_embeds': text_embeds,
            'text_masks': text_masks,
            'image_embeds': full_image_embeds,
        }
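# Usage sketch (untested): with captions only and no source images, the system prompt
# falls back to Qwen25VL_7b_PREFIX_t2i and no image-token positions are selected.
#
#   embedder = GemmaEmbedder(max_sequence_length=300)
#   out = embedder(["a foggy mountain village at sunrise"])
#   # out['text_embeds']: [1, 300, 2560]; out['text_masks']: [1, 300]
#   # out['image_embeds']: hidden states at image-token positions (empty in this text-only case)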
class GemmaTextEmbedder(nn.Module):
    def __init__(self, device, max_sequence_length=300, model_id='./gemma-3-4b-it'):
        super().__init__()
        self.model = Gemma3Model.from_pretrained(model_id).to(device).to(torch.bfloat16)
        # self.model = Gemma3ForConditionalGeneration.from_pretrained(model_id).to(device).to(torch.bfloat16)
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.real_device = device
        self.max_sequence_length = max_sequence_length
        # self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token  # Use eos token as pad token
        self.processor.tokenizer.padding_side = "right"

    @property
    def dtype(self):
        """Return the dtype of the model parameters."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Return the device of the model parameters."""
        return next(self.parameters()).device

    def get_features(self, hidden_states, input_ids):
        hidden_states = hidden_states[0]
        input_ids = input_ids[0].tolist()
        pad_text_embeds = torch.zeros([self.max_sequence_length, 2560], dtype=torch.bfloat16, device=self.device)
        pad_text_mask = torch.zeros([self.max_sequence_length], device=self.device)

        def find_last(lst, value):
            indices = [i for i, x in enumerate(lst) if x == value]
            return indices[-1]

        # Locate where the user text begins (see GemmaEmbedder.get_features).
        if 256000 in input_ids:
            text_start = input_ids.index(256000) + 2
        else:
            text_start = find_last(input_ids, 108) + 1
        text_end = len(input_ids) - 6

        bos_embed = hidden_states[:2]
        text_embeds = hidden_states[text_start:text_end + 1]
        text_embeds = torch.cat([bos_embed, text_embeds], dim=0)
        pad_text_embeds[:len(text_embeds), :] = text_embeds[:self.max_sequence_length]
        pad_text_mask[:len(text_embeds)] = 1.0
        # The tail is already zero; these resets are explicit no-ops kept for clarity.
        pad_text_embeds[len(text_embeds):, :] = 0.0
        pad_text_mask[len(text_embeds):] = 0.0

        """
        print(input_ids)
        print(input_ids[text_start:text_end + 1])
        decoded = self.processor.decode(input_ids[text_start:text_end+1], skip_special_tokens=False)
        print("Decoded text:", decoded, text_start, text_end, input_ids[text_start:text_end + 1], input_ids[1:2])
        print("Text embeddings shape:", text_embeds.shape)
        print(text_embeds, ' >>> ext embeds')
        """

        return pad_text_embeds, pad_text_mask

    def forward(self, caps, images=None):
        text_embeds = []
        text_masks = []
        device = self.model.device
        if isinstance(caps, str):
            caps = [caps]
        if images is None:
            images = [None] * len(caps)
        for cap, img in zip(caps, images):
            if img is not None:
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_edit}]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": img},
                            {"type": "text", "text": cap},
                        ]
                    }
                ]
            else:
                messages = [
                    {
                        "role": "system",
                        "content": [{"type": "text", "text": Qwen25VL_7b_PREFIX_t2i}]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": cap},
                        ]
                    }
                ]
            inputs = self.processor.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True,
                return_dict=True, return_tensors="pt",
                max_length=640,
                truncation=True,
            ).to(self.model.device, dtype=torch.bfloat16)

            outputs = self.model(**inputs, output_hidden_states=True)
            sample_text_embeds, sample_text_mask = [], []
            for hidden in [outputs.hidden_states[-1]]:
                cur_text_embeds, cur_text_mask = self.get_features(hidden, inputs["input_ids"])
                sample_text_embeds.append(cur_text_embeds)
                sample_text_mask.append(cur_text_mask)
            text_embeds.append(torch.cat(sample_text_embeds, dim=0))
            text_masks.append(torch.cat(sample_text_mask, dim=0))

            """
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
            generation = generation[0][input_len:]
            decoded = self.processor.decode(generation, skip_special_tokens=True)
            print(cap, ' <>>> gemma ', decoded)
            """

        text_embeds = torch.stack(text_embeds, dim=0)
        text_masks = torch.stack(text_masks, dim=0)
        return text_embeds, text_masks.to(text_embeds.dtype)
from transformers import AutoModel, AutoTokenizer
from transformers import SiglipVisionModel, AutoProcessor


class Gemma2Embedder(nn.Module):
    def __init__(self, max_length=300):
        super().__init__()
        self.text_encoder = AutoModel.from_pretrained(
            "google/gemma-2-2b",
            torch_dtype=torch.bfloat16,
        ).to(torch.cuda.current_device()).to(torch.bfloat16).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2-2b",
        )
        self.tokenizer.padding_side = "right"
        self.max_length = max_length
        self.system_prompt = "You are an assistant designed to edit images faithfully based on user prompts. <Prompt Start> "
        # Tokenize the system prompt alone to measure its length: the index of the
        # first pad token, minus one for BOS, gives the number of system-prompt tokens.
        system_ids = self.tokenizer(
            self.system_prompt,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
        ).input_ids.view(-1).tolist()
        self.len_system_prompt = system_ids.index(self.tokenizer.pad_token_id) - 1
        self.weight_dtype = torch.bfloat16

    def forward(self, caption):
        if isinstance(caption, str):
            caption = [caption]
        caption = [self.system_prompt + c for c in caption]
        text_inputs = self.tokenizer(
            caption,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=self.max_length + self.len_system_prompt,
            padding="max_length",
            truncation=True,
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        text_input_ids = text_input_ids.to(self.text_encoder.device)
        attention_mask = attention_mask.to(self.text_encoder.device)
        # Take the second-to-last hidden layer and strip the system-prompt positions.
        embeds = self.text_encoder(text_input_ids, attention_mask=attention_mask,
                                   output_hidden_states=True
                                   ).hidden_states[-2]
        embeds = embeds[:, self.len_system_prompt:, :]
        attention_mask = attention_mask[:, self.len_system_prompt:]
        return {
            'text_embeds': embeds,
            'text_masks': attention_mask,
        }
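# Usage sketch (untested; downloads google/gemma-2-2b):
#
#   embedder = Gemma2Embedder(max_length=300)
#   out = embedder("replace the red car with a blue bicycle")
#   # out['text_embeds']: [1, 300, hidden_size] - system-prompt positions stripped
#   # out['text_masks']:  [1, 300]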
class T5TextEmbedder(nn.Module):
    def __init__(self, device, pretrained_path="google/flan-t5-xxl", max_length=300):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(pretrained_path).to(device=device).to(torch.bfloat16)
        self.tokenizer = T5Tokenizer.from_pretrained(pretrained_path)
        self.max_length = max_length
        self.model.eval()
        self.model.requires_grad_(False)

    @property
    def dtype(self):
        """Return the dtype of the model parameters."""
        return next(self.parameters()).dtype

    @property
    def device(self):
        """Return the device of the model parameters."""
        return next(self.parameters()).device

    def forward(self, caption):
        max_length = self.max_length
        text_inputs = self.tokenizer(
            caption,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        text_input_ids = text_inputs.input_ids
        attention_mask = text_inputs.attention_mask
        text_input_ids = text_input_ids.to(self.model.device)
        attention_mask = attention_mask.to(self.model.device)
        outputs = self.model(text_input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state
        return embeddings, attention_mask.to(embeddings.dtype)
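# Usage sketch (untested; downloads google/flan-t5-xxl):
#
#   embedder = T5TextEmbedder(device="cuda:0", max_length=300)
#   embeds, mask = embedder(["a minimalist poster of a sailboat"])
#   # embeds: [1, 300, 4096] bf16 last-hidden states; mask: [1, 300] cast to the same dtype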
if __name__ == '__main__':
    from datasets import load_dataset

    dataset = load_dataset("facebook/emu_edit_test_set", split='validation[:200]')
    item = dataset[0:4]

    clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    pixel_values = clip_processor(
        images=item['image'], return_tensors="pt"
    ).pixel_values.to("cuda:0").to(torch.bfloat16)

    # model = GemmaTextEmbedder(device="cuda:0")
    model = LoraT5Embedder(device="cuda:0")
    prompt_embeds, clip_image_embeds, perturbed_prompt_embeds, replaced_ids, tokenizer, batch_encoding = model(
        [
            """A heartwarming 3D rendered scene of
an elderly farmer and a tiny orange
kitten. The farmer, with a gentle smile,
walks alongside the kitten in a lush,
green garden filled with thriving plants,
showcasing a fruitful harvest. The
intricate details of the overalls and the
farmer's worn, weathered face tell a
story of years spent tending to the land, the farmer is wearing a blue shirt""",
        ],
        image=pixel_values,
    )
    l2_loss, l1_loss, pad_loss = model.compute_perturbation_loss(
        prompt_embeds, perturbed_prompt_embeds, replaced_ids, batch_encoding
    )
    print(l2_loss, ' >>> l2 loss ', l1_loss, ' >>> l1 loss ', pad_loss, ' >>> pad loss ')
    print(clip_image_embeds.shape, ' >>> clip image embeds ')
    # print(gemma_dict['text_embeds'],)
    # print(gemma_dict['image_embeds'], ' >>> image embeds')

    """
    from dataset import create_loader
    from PIL import Image as PILImage
    import PIL
    import numpy as np
    import torch.nn.functional as F

    loader = create_loader('edit', batch_size=16, shuffle=False)
    batch = next(iter(loader))
    source = batch['source_images']
    source_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in source]
    target = batch['target_images']
    target_pils = [PIL.Image.fromarray(((x.permute(1, 2, 0).cpu().numpy() + 1) * 127.5).astype(np.uint8)) for x in target]

    from torchvision.utils import save_image
    print(batch['captions'])
    images = []
    for (x, y) in zip(batch['source_images'], batch['target_images']):
        images.append(x)
        images.append(y)
    save_image((torch.stack(images) + 1) / 2, 'example_pairs.jpg', nrow=8)

    gemma_dict = model(batch['captions'], source_pils, target_pils)
    image_embeds = gemma_dict['image_embeds']
    target_image_embeds = gemma_dict['target_image_embeds']
    print("Image embeds shape:", image_embeds.shape)
    print("Target image embeds shape:", target_image_embeds.shape)
    from qwen import compute_and_save_similarity_grid
    compute_and_save_similarity_grid(image_embeds, target_image_embeds, "gemma_similarity_grid.jpg")
    """