Spaces:
Sleeping
Sleeping
| import inspect | |
| import math | |
| from typing import Callable, List, Optional, Tuple, Union | |
| from einops import rearrange | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch import Tensor | |
| from diffusers.models.attention_processor import Attention | |
class LoRACompatibleLinear(nn.Linear):
    """``nn.Linear`` whose own weight/bias are frozen, plus an optional LoRA branch.

    The base parameters have ``requires_grad`` switched off at construction.
    If a ``lora_layer`` is attached, ``forward`` returns
    ``linear(x) + scale * lora_layer(x)``; otherwise it returns ``linear(x)``.
    """

    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Freeze the base projection; only the LoRA branch (if any) is meant to train.
        frozen_params = [self.weight] if self.bias is None else [self.weight, self.bias]
        for param in frozen_params:
            param.requires_grad_(False)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer):
        """Attach a LoRA module (or ``None`` to detach it)."""
        self.lora_layer = lora_layer

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """Apply the frozen linear map and, when present, the scaled LoRA update."""
        result = super().forward(hidden_states)
        if self.lora_layer is not None:
            result = result + (scale * self.lora_layer(hidden_states))
        return result
class param_CondLoRAMoELayer(nn.Module):
    """Condition-gated mixture of LoRA experts with an optional shared expert.

    Routed expert ``e`` applies the low-rank map
    ``x -> (x @ loraA[e].T) @ loraB[e].T``. A bias-free linear gate over the
    conditioning tensor registered with :meth:`set_latents` produces one logit
    per expert. For every batch row, the ``top_k`` largest logits are selected,
    softmax-normalised among themselves, and multiplied by ``norm_lora_scale``.
    The routed outputs are summed with those weights. When
    ``use_shared_expert`` is true, an always-active shared low-rank expert is
    added with gate weight 1 (also multiplied by ``norm_lora_scale``).

    Args:
        in_features: Last-dimension size of the input to :meth:`forward`.
        out_features: Last-dimension size of the output.
        cond_dim: Last-dimension size of the conditioning tensor.
        num_experts: Number of routed experts.
        rank: Rank of each routed expert.
        network_alpha: Accepted for interface compatibility; currently unused.
        top_k: Number of routed experts chosen per batch row.
        device: Device for the parameters created here.
        dtype: Dtype for the parameters created here.
        use_shared_expert: Whether to add the always-active shared expert.
        shared_expert_rank: Rank of the shared expert; falls back to ``rank``
            when falsy.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        cond_dim: int,
        num_experts: int = 4,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        top_k: int = 1,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        use_shared_expert: bool = True,
        shared_expert_rank: Optional[int] = None,
    ):
        super().__init__()
        self.rank = rank
        self.num_experts = num_experts
        self.top_k = top_k
        # NOTE(review): integer division gives 0 for rank > 16, which zeroes the
        # whole layer's output; network_alpha is also ignored. Kept unchanged so
        # existing checkpoints keep their scaling — confirm intent.
        self.norm_lora_scale = 16 // rank
        self.device = device
        self.dtype = dtype
        self.use_shared_expert = use_shared_expert
        # Rank of the shared expert may differ from the routed experts' rank.
        self.shared_expert_rank = shared_expert_rank if shared_expert_rank else rank

        # Routed experts, stored as stacked A (down) and B (up) factors.
        self.loraA = nn.Parameter(
            torch.zeros(num_experts, rank, in_features, device=device, dtype=dtype)
        )
        self.loraB = nn.Parameter(
            torch.zeros(num_experts, out_features, rank, device=device, dtype=dtype)
        )
        if self.use_shared_expert:
            self.shared_A = nn.Parameter(
                torch.zeros(1, self.shared_expert_rank, in_features, device=device, dtype=dtype)
            )
            self.shared_B = nn.Parameter(
                torch.zeros(1, out_features, self.shared_expert_rank, device=device, dtype=dtype)
            )

        # Gate: conditioning features -> one logit per routed expert.
        self.cond_gate = nn.Linear(cond_dim, num_experts, device=device, dtype=dtype, bias=False)
        # Next expert slot to be filled by set_pretrained_expert_weights.
        self.uninit_expert_idx = 0
        # Conditioning tensor consumed by forward(); set via set_latents().
        self.cond_hidden_states = None
        self.reset_parameters()

    def reset_parameters(self):
        """Gaussian-init the A factors (std 1/rank) and zero the B factors."""
        with torch.no_grad():
            self.loraA.normal_(mean=0.0, std=1.0 / float(self.rank))
            self.loraB.zero_()
            if self.use_shared_expert:
                self.shared_A.normal_(mean=0.0, std=1.0 / float(self.rank))
                self.shared_B.zero_()

    def set_latents(self, cond_hidden_states: torch.Tensor = None):
        """Register the conditioning tensor the gate reads on the next forward."""
        self.cond_hidden_states = cond_hidden_states

    def clear_latents(self):
        """Drop the registered conditioning tensor."""
        self.cond_hidden_states = None

    def set_pretrained_expert_weights(self, svd_lora_weights, keep_rank=None):
        """Fill the next uninitialised routed expert slots from an (A, B) factor pair.

        ``A`` is sliced to its first ``keep_rank`` rows and ``B`` to its first
        ``keep_rank`` columns. Each consecutive block of ``self.rank`` rows/columns
        becomes one expert. Only whole blocks are used, and writes stop at the
        last expert slot.

        Args:
            svd_lora_weights: Tuple ``(A, B)``; presumably ``A`` is
                ``[>=keep_rank, in_features]`` and ``B`` is
                ``[out_features, >=keep_rank]`` — confirm against callers.
            keep_rank: Number of leading rank components to use (default ``self.rank``).
        """
        if keep_rank is None:
            keep_rank = self.rank
        if self.uninit_expert_idx == self.num_experts:
            print("attn processor 已经初始化满了")
            return

        A_new, B_new = svd_lora_weights
        A_new, B_new = A_new[:keep_rank, :], B_new[:, :keep_rank]

        # keep_rank may span several experts; never write past the last slot.
        num_splits = min(keep_rank // self.rank, self.num_experts - self.uninit_expert_idx)
        with torch.no_grad():
            for i in range(num_splits):
                block = slice(i * self.rank, (i + 1) * self.rank)
                # copy_ casts to the parameter's current device/dtype.
                self.loraA[self.uninit_expert_idx + i].copy_(A_new[block, :])
                self.loraB[self.uninit_expert_idx + i].copy_(B_new[:, block])
        self.uninit_expert_idx += num_splits

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Mix the top-k routed experts (plus the shared expert) for each batch row.

        Args:
            hidden_states: ``[B, T, in_features]`` or ``[B, in_features]``.

        Returns:
            ``[B, T, out_features]``, or ``[B, out_features]`` for 2-D input.

        Raises:
            RuntimeError: if no conditioning tensor was registered via ``set_latents``.
        """
        cond = getattr(self, "cond_hidden_states", None)
        if cond is None:
            raise RuntimeError("set_latents() must be called before forward()")

        # The einsums below require 2-D logits: [B, num_experts].
        gate_logits = self.cond_gate(cond)
        if self.top_k == self.num_experts:
            topk_logits = gate_logits
            topk_idx = torch.arange(self.num_experts, device=gate_logits.device).expand(
                gate_logits.size(0), -1
            )
        else:
            topk_logits, topk_idx = torch.topk(gate_logits, self.top_k, dim=-1)  # [B, k]
        self.top_k_idx = topk_idx
        # Out-of-place scaling: softmax keeps its output for backward, so an
        # in-place multiply here would break gradient computation.
        topk_scores = F.softmax(topk_logits, dim=-1) * self.norm_lora_scale  # [B, k]

        squeeze_back = hidden_states.dim() == 2
        if squeeze_back:
            hidden_states = hidden_states.unsqueeze(1)  # [B, 1, D]

        # Routed experts: (x @ A^T) @ B^T per selected expert, then weighted sum.
        A_selected = self.loraA[topk_idx]  # [B, k, r, D]
        B_selected = self.loraB[topk_idx]  # [B, k, out_features, r]
        inter = torch.einsum("btd,bkrd->bktr", hidden_states, A_selected)
        expert_out = torch.einsum("bktr,bkor->bkto", inter, B_selected)
        outputs = torch.einsum("bkto,bk->bto", expert_out, topk_scores)

        # Shared expert: always active with gate weight 1. Computed separately so
        # its rank may differ from the routed experts'.
        if self.use_shared_expert:
            shared_out = F.linear(F.linear(hidden_states, self.shared_A[0]), self.shared_B[0])
            outputs = outputs + self.norm_lora_scale * shared_out

        if squeeze_back:
            outputs = outputs.squeeze(1)
        return outputs
# ---- Smoke test: build a small layer and run one forward pass ----
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size, seq_len, in_dim = 3, 5, 8
    layer_kwargs = dict(
        in_features=in_dim,
        out_features=6,
        cond_dim=4,
        num_experts=4,
        rank=2,
        top_k=2,
        use_shared_expert=True,
    )
    # Build the layer first so the RNG is consumed in the same order as before
    # (parameter init, then input tensors).
    layer = param_CondLoRAMoELayer(**layer_kwargs)
    hidden_states = torch.randn(batch_size, seq_len, in_dim)  # [3, 5, 8]
    cond_hidden_states = torch.randn(batch_size, layer_kwargs["cond_dim"])  # [3, 4]
    layer.set_latents(cond_hidden_states=cond_hidden_states)
    out = layer(hidden_states)
    print("Output shape:", out.shape)  # expected [3, 5, 6]
    print(out)