import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Dict, Optional, Tuple, Union

from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from safetensors.torch import load_file

logger = logging.get_logger(__name__)
def decode_latents(pipe, latents):
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="np")
    return video
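
# Hypothetical usage sketch (assumes `pipe` is a loaded CogVideoX pipeline and
# `latents` are denoised video latents; the names are illustrative):
#   frames = decode_latents(pipe, latents)  # numpy array of decoded video frames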


def create_attention_mask(text_length: int, seq_length: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Create an attention mask that blocks the text tokens from attending to the alpha tokens.

    Args:
        text_length: Length of the text sequence.
        seq_length: Combined length of the visual sequence (RGB tokens followed by alpha tokens).
        device: The device where the mask will be stored.
        dtype: The data type of the mask tensor.

    Returns:
        An attention mask tensor of shape (text_length + seq_length, text_length + seq_length).
    """
    total_length = text_length + seq_length
    dense_mask = torch.ones((total_length, total_length), dtype=torch.bool)
    # The alpha tokens occupy the second half of the visual sequence; text rows must not attend to them.
    dense_mask[:text_length, text_length + seq_length // 2:] = False
    if dtype == torch.bool:
        return dense_mask.to(device=device)
    # `F.scaled_dot_product_attention` treats floating-point masks as additive, so
    # blocked positions must be -inf (a 0/1 float mask would not block anything).
    attn_mask = torch.zeros((total_length, total_length), device=device, dtype=dtype)
    attn_mask.masked_fill_(~dense_mask.to(device), float("-inf"))
    return attn_mask
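
# Illustrative sanity check with toy sizes (not the real CogVideoX lengths):
#   mask = create_attention_mask(text_length=2, seq_length=4, device=torch.device("cpu"), dtype=torch.bool)
#   # mask[:2, 4:] is all False -> the two text rows cannot attend to the alpha half
#   # (the last 2 of the 4 visual tokens); every other position remains attendable.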


class RGBALoRACogVideoXAttnProcessor:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model.
    It applies a rotary embedding on the query and key vectors, but does not include spatial normalization.
    """

    def __init__(self, device, dtype, attention_mask, lora_rank=128, lora_alpha=1.0, latent_dim=3072):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("RGBALoRACogVideoXAttnProcessor requires PyTorch 2.0 or later.")

        # LoRA scaling hyperparameters
        self.lora_alpha = lora_alpha
        self.lora_rank = lora_rank

        # Helper to create a LoRA layer as a down-projection followed by an up-projection
        def create_lora_layer(in_dim, mid_dim, out_dim):
            return nn.Sequential(
                nn.Linear(in_dim, mid_dim, bias=False, device=device, dtype=dtype),
                nn.Linear(mid_dim, out_dim, bias=False, device=device, dtype=dtype),
            )

        self.to_q_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_k_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_v_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)
        self.to_out_lora = create_lora_layer(latent_dim, lora_rank, latent_dim)

        # Attention mask that blocks text tokens from attending to alpha tokens
        self.attention_mask = attention_mask

    def _apply_lora(self, hidden_states, seq_len, query, key, value, scaling):
        """Apply the scaled LoRA deltas to the alpha half of the query, key, and value tensors."""
        query_delta = self.to_q_lora(hidden_states).to(query.device)
        query[:, -seq_len // 2:, :] += query_delta[:, -seq_len // 2:, :] * scaling
        key_delta = self.to_k_lora(hidden_states).to(key.device)
        key[:, -seq_len // 2:, :] += key_delta[:, -seq_len // 2:, :] * scaling
        value_delta = self.to_v_lora(hidden_states).to(value.device)
        value[:, -seq_len // 2:, :] += value_delta[:, -seq_len // 2:, :] * scaling
        return query, key, value
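
    # Each LoRA layer computes delta(x) = B(A(x)) with A: latent_dim -> lora_rank and
    # B: lora_rank -> latent_dim, so the effective update is (lora_alpha / lora_rank) * B(A(x)),
    # the standard LoRA parameterization. Only the alpha tokens (the last seq_len // 2
    # positions of the visual sequence) receive the update; the RGB tokens are untouched.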

    def _apply_rotary_embedding(self, query, key, image_rotary_emb, seq_len, text_seq_length, attn):
        """Apply rotary embeddings to the query and key tensors."""
        from diffusers.models.embeddings import apply_rotary_emb

        # The RGB half and the alpha half receive the same rotary positions, so
        # corresponding RGB and alpha tokens share positional information.
        query[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
            query[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
        query[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
            query[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
        if not attn.is_cross_attention:
            key[:, :, text_seq_length:text_seq_length + seq_len // 2] = apply_rotary_emb(
                key[:, :, text_seq_length:text_seq_length + seq_len // 2], image_rotary_emb)
            key[:, :, text_seq_length + seq_len // 2:] = apply_rotary_emb(
                key[:, :, text_seq_length + seq_len // 2:], image_rotary_emb)
        return query, key

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Concatenate text (encoder) and visual hidden states into one sequence
        text_seq_length = encoder_hidden_states.size(1)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        batch_size, sequence_length, _ = hidden_states.shape
        seq_len = hidden_states.shape[1] - text_seq_length
        scaling = self.lora_alpha / self.lora_rank

        # Project to query, key, value and apply the LoRA deltas to the alpha tokens
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query, key, value = self._apply_lora(hidden_states, seq_len, query, key, value, scaling)

        # Reshape query, key, value for multi-head attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Normalize query and key if the layer defines norms
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply rotary embeddings if provided
        if image_rotary_emb is not None:
            query, key = self._apply_rotary_embedding(query, key, image_rotary_emb, seq_len, text_seq_length, attn)

        # Scaled dot-product attention; the stored mask (which blocks text-to-alpha
        # attention) supersedes any `attention_mask` passed in by the caller
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=self.attention_mask, dropout_p=0.0, is_causal=False
        )

        # Reshape the output tensor back to (batch, sequence, inner_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection, with the LoRA delta applied to the alpha tokens only
        original_hidden_states = attn.to_out[0](hidden_states)
        hidden_states_delta = self.to_out_lora(hidden_states).to(hidden_states.device)
        original_hidden_states[:, -seq_len // 2:, :] += hidden_states_delta[:, -seq_len // 2:, :] * scaling

        # Dropout
        hidden_states = attn.to_out[1](original_hidden_states)

        # Split back into text (encoder) and visual hidden states
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
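
# `prepare_for_rgba_inference` below wires everything together: it loads the RGBA
# LoRA weights, installs an RGBALoRACogVideoXAttnProcessor on every attention
# layer, and swaps in a custom forward pass that adds the learned domain
# embedding to the alpha half of the latent sequence.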

def prepare_for_rgba_inference(
    model, rgba_weights_path: str, device: torch.device, dtype: torch.dtype,
    lora_rank: int = 128, lora_alpha: float = 1.0, text_length: int = 226, seq_length: int = 35100,
):
    def load_lora_sequential_weights(lora_layer, lora_layers, prefix):
        lora_layer[0].load_state_dict({'weight': lora_layers[f"{prefix}.lora_A.weight"]})
        lora_layer[1].load_state_dict({'weight': lora_layers[f"{prefix}.lora_B.weight"]})

    rgba_weights = load_file(rgba_weights_path)
    aux_emb = rgba_weights['domain_emb']

    attention_mask = create_attention_mask(text_length, seq_length, device, dtype)
    attn_procs = {}

    for name in model.attn_processors.keys():
        attn_processor = RGBALoRACogVideoXAttnProcessor(
            device=device, dtype=dtype, attention_mask=attention_mask,
            lora_rank=lora_rank, lora_alpha=lora_alpha,
        )
        # Processor names look like "transformer_blocks.{i}.attn1.processor"
        index = name.split('.')[1]
        base_prefix = f'transformer.transformer_blocks.{index}.attn1'
        for lora_layer, prefix in [
            (attn_processor.to_q_lora, f'{base_prefix}.to_q'),
            (attn_processor.to_k_lora, f'{base_prefix}.to_k'),
            (attn_processor.to_v_lora, f'{base_prefix}.to_v'),
            (attn_processor.to_out_lora, f'{base_prefix}.to_out.0'),
        ]:
            load_lora_sequential_weights(lora_layer, rgba_weights, prefix)
        attn_procs[name] = attn_processor

    model.set_attn_processor(attn_procs)

    def custom_forward(self):
        def forward(
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            timestep: Union[int, float, torch.LongTensor],
            timestep_cond: Optional[torch.Tensor] = None,
            ofs: Optional[Union[int, float, torch.LongTensor]] = None,
            image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            attention_kwargs: Optional[Dict[str, Any]] = None,
            return_dict: bool = True,
        ):
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # Weight the LoRA layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            else:
                if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                    logger.warning(
                        "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                    )

            batch_size, num_frames, channels, height, width = hidden_states.shape

            # 1. Time embedding
            timesteps = timestep
            t_emb = self.time_proj(timesteps)
            # `time_proj` always returns f32 tensors, but `time_embedding` may run in
            # fp16, so cast to the model dtype here.
            t_emb = t_emb.to(dtype=hidden_states.dtype)
            emb = self.time_embedding(t_emb, timestep_cond)

            if self.ofs_embedding is not None:
                ofs_emb = self.ofs_proj(ofs)
                ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
                ofs_emb = self.ofs_embedding(ofs_emb)
                emb = emb + ofs_emb

            # 2. Patch embedding
            hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
            hidden_states = self.embedding_dropout(hidden_states)

            text_seq_length = encoder_hidden_states.shape[1]
            encoder_hidden_states = hidden_states[:, :text_seq_length]
            hidden_states = hidden_states[:, text_seq_length:]

            # Add the learned domain embedding to the alpha half of the latent sequence
            hidden_states[:, hidden_states.size(1) // 2:, :] += aux_emb.expand(batch_size, -1, -1).to(
                hidden_states.device, dtype=hidden_states.dtype
            )

            # 3. Transformer blocks
            for i, block in enumerate(self.transformer_blocks):
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        emb,
                        image_rotary_emb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states, encoder_hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=emb,
                        image_rotary_emb=image_rotary_emb,
                    )

            if not self.config.use_rotary_positional_embeddings:
                # CogVideoX-2B
                hidden_states = self.norm_final(hidden_states)
            else:
                # CogVideoX-5B
                hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
                hidden_states = self.norm_final(hidden_states)
                hidden_states = hidden_states[:, text_seq_length:]

            # 4. Final block
            hidden_states = self.norm_out(hidden_states, temb=emb)
            hidden_states = self.proj_out(hidden_states)

            # 5. Unpatchify
            p = self.config.patch_size
            p_t = self.config.patch_size_t
            if p_t is None:
                output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
                output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
            else:
                output = hidden_states.reshape(
                    batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
                )
                output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

            if USE_PEFT_BACKEND:
                # Remove `lora_scale` from each PEFT layer
                unscale_lora_layers(self, lora_scale)

            if not return_dict:
                return (output,)
            return Transformer2DModelOutput(sample=output)

        return forward

    model.forward = custom_forward(model)
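
# Minimal end-to-end usage sketch (illustrative only; the checkpoint name and
# weights path are assumptions, not part of this module):
#
#   from diffusers import CogVideoXPipeline
#
#   pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
#   prepare_for_rgba_inference(
#       pipe.transformer,
#       rgba_weights_path="path/to/rgba_weights.safetensors",  # hypothetical path
#       device=torch.device("cuda"),
#       dtype=torch.bfloat16,
#   )
#   latents = pipe(prompt="...", output_type="latent").frames
#   frames = decode_latents(pipe, latents)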