Spaces:
Sleeping
Sleeping
File size: 7,327 Bytes
56d35ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from diffusers.models.attention_processor import Attention
class LoRACompatibleLinear(nn.Linear):
    """
    Frozen ``nn.Linear`` whose output can be augmented by an optional LoRA branch.

    On construction the base ``weight`` (and ``bias``, when present) are
    excluded from gradient computation. If a ``lora_layer`` is attached, its
    output is scaled and added to the base projection in ``forward``.
    """

    def __init__(self, *args, lora_layer=None, **kwargs):
        """Build the linear layer, freeze its own parameters, store ``lora_layer``."""
        super().__init__(*args, **kwargs)
        # Freeze the base projection; ``bias`` is None when bias=False was requested.
        for base_param in (self.weight, self.bias):
            if base_param is not None:
                base_param.requires_grad_(False)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer):
        """Replace the attached LoRA branch (pass ``None`` to remove it)."""
        self.lora_layer = lora_layer

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        """
        Apply the frozen projection, plus ``scale * lora_layer(hidden_states)``
        when a LoRA branch is attached.
        """
        base_out = super().forward(hidden_states)
        adapter = self.lora_layer
        if adapter is None:
            return base_out
        return base_out + (scale * adapter(hidden_states))
class param_CondLoRAMoELayer(nn.Module):
    """
    Condition-gated mixture of low-rank (LoRA) experts.

    Each routed expert ``e`` is a pair of factors ``loraA[e]`` (``[rank,
    in_features]``) and ``loraB[e]`` (``[out_features, rank]``). A bias-free
    linear gate maps the conditioning tensor (supplied beforehand through
    ``set_latents``) to one logit per expert; the ``top_k`` largest logits are
    soft-maxed and used to mix the selected experts' outputs. An optional
    shared expert (``shared_A`` / ``shared_B``) is applied to every sample
    regardless of the gate.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        cond_dim: Size of the last dimension of the conditioning tensor.
        num_experts: Number of routed experts.
        rank: Rank of every routed expert.
        network_alpha: Accepted for interface compatibility; not used here.
        top_k: Number of routed experts mixed per sample.
        device: Device on which the parameters are created.
        dtype: Dtype with which the parameters are created.
        use_shared_expert: Whether to add the always-on shared expert.
        shared_expert_rank: Rank of the shared expert; falls back to ``rank``
            when ``None`` (or 0).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        cond_dim: int,
        num_experts: int = 4,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        top_k: int = 1,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        use_shared_expert: bool = True,
        shared_expert_rank: Optional[int] = None,
    ):
        super().__init__()
        self.rank = rank
        self.num_experts = num_experts
        self.top_k = top_k
        # Integer scale applied to every expert's contribution.
        # NOTE: floor division, so this is 0 for rank > 16.
        self.norm_lora_scale = 16 // rank
        self.device = device
        self.dtype = dtype
        self.use_shared_expert = use_shared_expert
        # Conditioning tensor consumed by forward(); populated via set_latents().
        self.cond_hidden_states = None

        # Routed experts, stored as stacked A and B factors.
        self.loraA = nn.Parameter(
            torch.zeros(num_experts, rank, in_features, device=device, dtype=dtype)
        )
        self.loraB = nn.Parameter(
            torch.zeros(num_experts, out_features, rank, device=device, dtype=dtype)
        )

        # Shared expert (if enabled). The leading singleton dimension is kept
        # so the parameter shapes stay compatible with existing state dicts.
        if self.use_shared_expert:
            shared_rank = shared_expert_rank if shared_expert_rank else rank
            self.shared_A = nn.Parameter(
                torch.zeros(1, shared_rank, in_features, device=device, dtype=dtype)
            )
            self.shared_B = nn.Parameter(
                torch.zeros(1, out_features, shared_rank, device=device, dtype=dtype)
            )

        # Gating network: conditioning -> one logit per routed expert.
        self.cond_gate = nn.Linear(cond_dim, num_experts, device=device, dtype=dtype, bias=False)
        # Index of the next expert slot to be filled by
        # set_pretrained_expert_weights().
        self.uninit_expert_idx = 0
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the A factors from N(0, (1/rank)^2) and zero the B factors."""
        with torch.no_grad():
            self.loraA.normal_(mean=0.0, std=1.0 / float(self.rank))
            self.loraB.zero_()
            if self.use_shared_expert:
                # The std is derived from the routed-expert rank here as well.
                self.shared_A.normal_(mean=0.0, std=1.0 / float(self.rank))
                self.shared_B.zero_()

    def set_latents(self, cond_hidden_states: torch.Tensor = None):
        """Store the conditioning tensor that the next forward() call will gate on."""
        self.cond_hidden_states = cond_hidden_states

    def clear_latents(self):
        """Drop the stored conditioning tensor."""
        self.cond_hidden_states = None

    def set_pretrained_expert_weights(self, svd_lora_weights, keep_rank=None):
        """
        Copy pretrained LoRA factors into the next free expert slot(s).

        ``svd_lora_weights`` is an ``(A, B)`` pair. The first ``keep_rank``
        rows of ``A`` and columns of ``B`` are kept and cut into
        ``keep_rank // self.rank`` chunks of ``self.rank``; each chunk fills
        one expert slot, starting at ``self.uninit_expert_idx``. When every
        slot is already filled, a message is printed and nothing is changed.

        No bounds check is made beyond that "already full" test, so supplying
        more chunks than there are free slots indexes past the last expert.

        Args:
            svd_lora_weights: Tuple ``(A, B)`` of 2-D tensors.
            keep_rank: Number of leading rank components to use; defaults to
                ``self.rank``.
        """
        if keep_rank is None:
            keep_rank = self.rank
        if self.uninit_expert_idx == self.num_experts:
            print("attn processor 已经初始化满了")
            return
        A_new, B_new = svd_lora_weights
        A_new, B_new = A_new[:keep_rank, :], B_new[:, :keep_rank]
        # keep_rank may be a multiple of self.rank: spread it over several experts.
        num_splits = keep_rank // self.rank
        for i in range(num_splits):
            chunk = slice(i * self.rank, (i + 1) * self.rank)
            slot = self.uninit_expert_idx + i
            # Target the parameters' current device/dtype rather than the
            # constructor arguments, which may be None or stale after .to().
            self.loraA.data[slot] = A_new[chunk, :].to(
                device=self.loraA.device, dtype=self.loraA.dtype
            )
            self.loraB.data[slot] = B_new[:, chunk].to(
                device=self.loraB.device, dtype=self.loraB.dtype
            )
        self.uninit_expert_idx += num_splits

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute the gated low-rank update for ``hidden_states``.

        Args:
            hidden_states: ``[B, T, in_features]`` or ``[B, in_features]``.

        Returns:
            Tensor shaped like the input with the last dimension replaced by
            ``out_features``.

        Raises:
            RuntimeError: If no conditioning tensor has been set through
                ``set_latents``.
        """
        cond = self.cond_hidden_states
        if cond is None:
            raise RuntimeError(
                "cond_hidden_states is not set; call set_latents() before forward()."
            )

        # Gating
        gate_logits = self.cond_gate(cond)  # [B, num_experts]

        # Top-k selection happens before the softmax, so the scores are
        # normalised over the selected experts only.
        if self.top_k == self.num_experts:
            topk_logits = gate_logits
            # Build the index on the logits' device so it follows the module
            # after .to()/.cuda().
            topk_idx = torch.arange(
                self.num_experts, device=gate_logits.device
            ).expand(gate_logits.size(0), -1)
        else:
            topk_logits, topk_idx = torch.topk(gate_logits, self.top_k, dim=-1)  # [B, k]
        self.top_k_idx = topk_idx

        # Out-of-place scaling: softmax's backward needs its unmodified output,
        # so an in-place multiply here would make autograd raise.
        topk_scores = F.softmax(topk_logits, dim=-1) * self.norm_lora_scale  # [B, k]

        # Ensure input shape is [B, T, D]
        squeeze_back = hidden_states.dim() == 2
        if squeeze_back:
            hidden_states = hidden_states.unsqueeze(1)  # [B, 1, D]

        # Routed experts: (x @ A^T) @ B^T for each selected expert, then a
        # score-weighted sum over the k experts.
        A_selected = self.loraA[topk_idx]  # [B, k, r, D]
        B_selected = self.loraB[topk_idx]  # [B, k, out_features, r]
        inter = torch.einsum("btd,bkrd->bktr", hidden_states, A_selected)  # [B, k, T, r]
        expert_out = torch.einsum("bktr,bkor->bkto", inter, B_selected)  # [B, k, T, out]
        outputs = torch.einsum("bkto,bk->bto", expert_out, topk_scores)  # [B, T, out]

        # Shared expert: an always-on branch kept separate from the routed
        # experts, so its weight cannot be confused with a gate score and its
        # rank is free to differ from the routed rank.
        if self.use_shared_expert:
            shared_inter = F.linear(hidden_states, self.shared_A[0])  # [B, T, r_shared]
            shared_out = F.linear(shared_inter, self.shared_B[0])  # [B, T, out]
            outputs = outputs + self.norm_lora_scale * shared_out

        if squeeze_back:
            outputs = outputs.squeeze(1)
        return outputs
# ---- Smoke test ----
if __name__ == "__main__":
    torch.manual_seed(42)

    # batch=3, time=5, in_features=8
    batch, steps, in_dim = 3, 5, 8
    layer_kwargs = dict(
        in_features=in_dim,
        out_features=6,
        cond_dim=4,
        num_experts=4,
        rank=2,
        top_k=2,
        use_shared_expert=True,
    )
    moe = param_CondLoRAMoELayer(**layer_kwargs)

    x = torch.randn(batch, steps, in_dim)  # [3, 5, 8]
    cond = torch.randn(batch, layer_kwargs["cond_dim"])  # [3, 4]
    moe.set_latents(cond_hidden_states=cond)

    result = moe(x)
    print("Output shape:", result.shape)  # expected: [3, 5, 6]
    print(result)
|