Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import torch | |
| from einops import rearrange | |
| from torch.utils.checkpoint import checkpoint | |
| from transformer_engine.pytorch.attention import apply_rotary_pos_emb | |
| from cosmos_predict1.diffusion.module.attention import Attention | |
| from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer | |
| from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType | |
| try: | |
| from megatron.core import parallel_state | |
| USE_MEGATRON = True | |
| except ImportError: | |
| USE_MEGATRON = False | |
def enable_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Enable LoRA for the attention block based on the peft_control dictionary.

    Sets ``attn.peft_lora_enabled`` to True only when ``peft_control`` is
    non-empty and requests LoRA customization; otherwise it stays False.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): PEFT configuration; when non-empty it must
            contain the key "customization_type".

    Raises:
        KeyError: If a non-empty peft_control lacks "customization_type".
        ValueError: If the customization type is not LoRA.
    """
    attn.peft_lora_enabled = False
    if not peft_control:
        return
    try:
        customization_type = peft_control["customization_type"]
    except KeyError as e:
        raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") from e
    if customization_type == CustomizationType.LORA:
        attn.peft_lora_enabled = True
    else:
        # ValueError is a subclass of the Exception previously raised here,
        # so existing broad handlers still catch it.
        raise ValueError(f"Unsupported Customization type {customization_type}")
def configure_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure LoRA for the attention block based on the peft_control dictionary.

    For each projection ("to_q", "to_k", "to_v", "to_out") this sets
    ``attn.<p>_lora_enabled`` (p in {q, k, v, out}) and, when that layer is
    activated, ``attn.<p>_lora_rank`` and ``attn.<p>_lora_scale``.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): PEFT configuration keyed by layer name
            ("to_q", "to_k", "to_v", "to_out"); each value is a dict with
            "activate" and, if active, "lora_rank" and "lora_scale".

    Raises:
        KeyError: If an activated layer is missing "lora_rank" or "lora_scale".
        ValueError: If "lora_scale" cannot be converted to float.
    """
    # peft_control key -> attribute prefix on the attention block.
    layers = {"to_q": "q", "to_k": "k", "to_v": "v", "to_out": "out"}
    try:
        for key, prefix in layers.items():
            cfg = peft_control.get(key, {})
            enabled = cfg.get("activate", False)
            setattr(attn, f"{prefix}_lora_enabled", enabled)
            if enabled:
                setattr(attn, f"{prefix}_lora_rank", cfg["lora_rank"])
                setattr(attn, f"{prefix}_lora_scale", float(cfg["lora_scale"]))
    except KeyError as e:
        raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") from e
    except ValueError as e:
        raise ValueError(f"Could not convert string to float: {e}") from e
def cal_qkv_lora(
    self,
    x: torch.Tensor,
    context: torch.Tensor = None,
    mask: torch.Tensor = None,
    rope_emb: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv.

    Args:
        x (torch.Tensor): Input tensor.
        context (torch.Tensor, optional): Context tensor; defaults to ``x``
            (self-attention) when None.
        mask (torch.Tensor, optional): Mask tensor (unused here; kept for
            signature compatibility with the base cal_qkv).
        rope_emb (torch.Tensor, optional): Rotary positional embedding;
            applied to q/k only for self-attention.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices.

    Raises:
        AttributeError: If LoRA is enabled but a required ``*_lora`` layer or
            scale attribute is missing on the attention block.
    """
    # NOTE: the docstring above was previously placed after `del kwargs`,
    # which made it a no-op string literal (function __doc__ was None).
    del kwargs  # unused; accepted for interface compatibility
    # Index 0 of to_q/to_k/to_v holds the linear projection; index 1 holds the
    # per-head normalization applied later inside the checkpointed closure.
    q = self.to_q[0](x)
    context = x if context is None else context
    k = self.to_k[0](context)
    v = self.to_v[0](context)
    if self.peft_lora_enabled:
        try:
            # Add the scaled low-rank updates on top of the base projections.
            if self.q_lora_enabled:
                q = q + self.q_lora_scale * self.to_q_lora(x)
            if self.k_lora_enabled:
                k = k + self.k_lora_scale * self.to_k_lora(context)
            if self.v_lora_enabled:
                v = v + self.v_lora_scale * self.to_v_lora(context)
        except AttributeError as e:
            raise AttributeError(f"lora enabled, but missing class attribute {e.args[0]} of Attention block") from e
    # Split channels into (heads-per-TP-rank, head_dim).
    q, k, v = map(
        lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head),
        (q, k, v),
    )

    def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb):
        # Per-head norm, then rotary embedding for self-attention.
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    # Checkpointed to trade recomputation for activation memory in training.
    q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False)
    return q, k, v
def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        mask (torch.Tensor, optional): Mask tensor (used by the torch backend only).

    Returns:
        torch.Tensor: The attention output.

    Raises:
        ValueError: If ``self.backend`` is neither "transformer_engine" nor "torch".
        AttributeError: If out-LoRA is enabled but the layer/scale attributes are missing.
    """
    if self.backend == "transformer_engine":
        seq_dim = self.qkv_format.index("s")
        assert (
            q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
        ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
        attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)  # [B, Mq, H, V]
    elif self.backend == "torch":
        attn_out = self.attn_op(q, k, v, mask=mask)  # [B, Mq, H, V]
        attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)")
    else:
        raise ValueError(f"Backend {self.backend} not found")
    # Output projection plus optional low-rank update; shared by both backends
    # (previously duplicated per branch).
    out = self.to_out(attn_out)
    if self.peft_lora_enabled and self.out_lora_enabled:
        try:
            out = out + self.out_lora_scale * self.to_out_lora(attn_out)
        except AttributeError as e:
            # Message previously said "l1 lora ... of FeedForward block" — a
            # copy-paste from the FF variant; corrected for this Attention method.
            raise AttributeError(f"out lora enabled, but missing class attribute {e.args[0]} of Attention block") from e
    return out
def build_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure, build and add LoRA layers to the attention block.

    Enables/configures LoRA from ``peft_control``, constructs the per-projection
    LoRA layers (tensor-parallel aware when ``attn.tp_size > 1``), and rebinds
    ``cal_qkv``/``cal_attn`` to the LoRA-aware implementations.

    Args:
        attn (Attention): The attention block to add LoRA layers to.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    enable_attn_lora(attn, peft_control)
    configure_attn_lora(attn, peft_control)
    if attn.peft_lora_enabled:
        query_dim = attn.query_dim
        inner_dim = attn.inner_dim
        context_dim = attn.context_dim
        tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None
        # (prefix, in_features, out_features, TE parallel mode). q/k/v adapters
        # are column-parallel like the projections they augment; the output
        # projection adapter is row-parallel.
        specs = [
            ("q", query_dim, inner_dim, "column"),
            ("k", context_dim, inner_dim, "column"),
            ("v", context_dim, inner_dim, "column"),
            ("out", inner_dim, query_dim, "row"),
        ]
        if attn.tp_size == 1:
            for prefix, in_dim, out_dim, _ in specs:
                if getattr(attn, f"{prefix}_lora_enabled"):
                    setattr(
                        attn,
                        f"to_{prefix}_lora",
                        LoRALinearLayer(in_dim, out_dim, rank=getattr(attn, f"{prefix}_lora_rank"), linear=True),
                    )
        else:
            sequence_parallel = getattr(parallel_state, "sequence_parallel", False)
            for prefix, in_dim, out_dim, parallel_mode in specs:
                if getattr(attn, f"{prefix}_lora_enabled"):
                    setattr(
                        attn,
                        f"to_{prefix}_lora",
                        TELoRALinearLayer(
                            in_dim,
                            out_dim,
                            rank=getattr(attn, f"{prefix}_lora_rank"),
                            linear=True,
                            tp_size=attn.tp_size,
                            tp_group=tp_group,
                            sequence_parallel=sequence_parallel,
                            parallel_mode=parallel_mode,
                        ),
                    )
    # Bind the LoRA-aware helpers as instance methods; they check
    # peft_lora_enabled themselves, so binding is safe either way.
    attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__)
    attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__)