# MultiModal / moe.py
# (Hugging Face page header: uploaded by szxllm, "Update moe.py",
#  commit 958b4f3, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, List
import math
class Expert(nn.Module):
    """A single SwiGLU feed-forward expert: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input/output feature size.
        hidden_dim: inner projection size.
        dropout: dropout probability applied to the expert output (0 disables it).
        bias: whether the linear projections carry bias terms.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float = 0.0,
        bias: bool = False
    ):
        super().__init__()
        # Gate (w1), down (w2) and up (w3) projections of the SwiGLU block.
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)
        self.dropout = nn.Identity() if dropout <= 0 else nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Re-initialize all projections with a small normal distribution."""
        for proj in (self.w1, self.w2, self.w3):
            nn.init.normal_(proj.weight, mean=0.0, std=0.02)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to ``x``."""
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))
class TopKRouter(nn.Module):
    """Noisy top-k gating network with auxiliary regularization losses.

    For each input token the router produces the indices of its ``top_k``
    experts and softmax-normalized mixture weights over those experts.
    During training, Gaussian noise is added to the gate logits and two
    regularizers are computed: a load-balance loss and a router z-loss.

    Args:
        dim: token feature size.
        num_experts: number of experts to route among.
        top_k: experts selected per token.
        capacity_factor / use_expert_capacity: stored for callers; routing
            itself does not enforce capacity (the MoE layer does).
        noise_std: std of the training-time exploration noise.
        router_z_loss_coef / router_aux_loss_coef: loss coefficients.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        noise_std: float = 1.0,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.noise_std = noise_std
        self.use_expert_capacity = use_expert_capacity
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, mean=0.0, std=0.02)

    def _compute_routing_weights(
        self,
        logits: torch.Tensor,
        use_noise: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the top-k experts per token and normalize their weights."""
        if use_noise and self.training:
            logits = logits + torch.randn_like(logits) * self.noise_std
        values, indices = torch.topk(logits, self.top_k, dim=-1)
        return F.softmax(values, dim=-1), indices

    def _compute_auxiliary_loss(
        self,
        logits: torch.Tensor,
        top_k_indices: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (load-balance loss, z-loss) from the raw gate logits."""
        tokens = logits.shape[0]
        probs = F.softmax(logits, dim=-1)
        mean_prob = probs.mean(dim=0)
        # Fraction of routing slots assigned to each expert.
        assignment = F.one_hot(top_k_indices, self.num_experts).float()
        frequency = assignment.sum(dim=[0, 1]) / (tokens * self.top_k)
        # Switch-Transformer-style balance term; z-loss keeps logits small.
        balance = self.num_experts * (mean_prob * frequency).sum()
        z = (logits ** 2).mean()
        return balance, z

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route flat tokens ``x`` -> (gates, expert indices, auxiliary loss)."""
        logits = self.gate(x)
        weights, indices = self._compute_routing_weights(
            logits, use_noise=self.training
        )
        if not self.training:
            # No regularization at inference time.
            return weights, indices, torch.tensor(0.0, device=x.device)
        balance, z = self._compute_auxiliary_loss(logits, indices)
        aux = (
            self.router_aux_loss_coef * balance +
            self.router_z_loss_coef * z
        )
        return weights, indices, aux
class MixtureOfExperts(nn.Module):
    """Sparse MoE feed-forward layer with top-k routing and optional capacity.

    Each token is dispatched to its ``top_k`` experts (chosen by a
    ``TopKRouter``); expert outputs are combined with the router's mixture
    weights.  When ``use_expert_capacity`` is set, each expert processes at
    most ``capacity_factor * top_k * tokens / num_experts`` tokens and the
    overflow is dropped.

    Args:
        dim: token feature size.
        num_experts: number of ``Expert`` FFNs.
        expert_hidden_dim: expert inner size; derived from ``dim`` (LLaMA-style
            2/3 * 4d, rounded up to a multiple of 256) when None.
        top_k: experts per token.
        dropout: dropout inside each expert.
        capacity_factor / use_expert_capacity: capacity controls (see above).
        router_z_loss_coef / router_aux_loss_coef / noise_std: router settings.
        ffn_dim_multiplier: optional override multiplier for the hidden size.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25,
        use_expert_capacity: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.01,
        noise_std: float = 1.0,
        ffn_dim_multiplier: Optional[float] = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.use_expert_capacity = use_expert_capacity
        if expert_hidden_dim is None:
            if ffn_dim_multiplier is not None:
                expert_hidden_dim = int(dim * ffn_dim_multiplier)
            else:
                # LLaMA-style 2/3 * 4d hidden size.
                expert_hidden_dim = int(2 * dim * 4 / 3)
            # Round up to a multiple of 256 for hardware-friendly shapes.
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout, bias=False)
            for _ in range(num_experts)
        ])
        self.router = TopKRouter(
            dim=dim,
            num_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor,
            noise_std=noise_std,
            use_expert_capacity=use_expert_capacity,
            router_z_loss_coef=router_z_loss_coef,
            router_aux_loss_coef=router_aux_loss_coef
        )

    def _compute_expert_capacity(self, num_tokens: int) -> int:
        """Max tokens per expert; ``num_tokens`` (i.e. unlimited) if disabled."""
        if not self.use_expert_capacity:
            return num_tokens
        capacity = int(
            (num_tokens / self.num_experts) * self.capacity_factor * self.top_k
        )
        return max(capacity, 1)

    def forward(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` of shape (B, T, D) through the experts.

        Returns:
            Tuple of (output of shape (B, T, D), scalar auxiliary loss —
            zero in eval mode).
        """
        B, T, D = x.shape
        num_tokens = B * T
        # BUGFIX: reshape (not view) so non-contiguous inputs (e.g. transposed
        # tensors) are handled instead of raising a RuntimeError.
        x_flat = x.reshape(-1, D)
        top_k_gates, top_k_indices, auxiliary_loss = self.router(x_flat)
        output = torch.zeros_like(x_flat)
        expert_capacity = self._compute_expert_capacity(num_tokens)
        for expert_idx, expert in enumerate(self.experts):
            # Locate every (token, top-k slot) pair routed to this expert.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if len(token_indices) == 0:
                continue
            if self.use_expert_capacity and len(token_indices) > expert_capacity:
                # Over capacity: keep a random subset, drop the rest.
                # NOTE(review): the random drop also fires in eval mode, making
                # inference nondeterministic when capacity binds; consider
                # keeping the highest-gate tokens instead — confirm intent.
                perm = torch.randperm(len(token_indices), device=x.device)[:expert_capacity]
                token_indices = token_indices[perm]
                topk_positions = topk_positions[perm]
            expert_input = x_flat[token_indices]
            expert_gates = top_k_gates[token_indices, topk_positions]
            expert_output = expert(expert_input) * expert_gates.unsqueeze(-1)
            # index_add_ accumulates contributions from multiple experts.
            output.index_add_(0, token_indices, expert_output)
        return output.view(B, T, D), auxiliary_loss
class SparseDispatcher:
    """Routes token rows to experts and recombines their gated outputs.

    Args:
        num_experts: number of experts.
        gates: dense gate matrix of shape (num_tokens, num_experts).
        expert_indices: expert assignment tensor; every entry equal to ``i``
            selects a token row for expert ``i``.  For multi-dimensional
            assignments (e.g. top-k routing) a token may appear more than once.
    """

    def __init__(
        self,
        num_experts: int,
        gates: torch.Tensor,
        expert_indices: torch.Tensor
    ):
        self.num_experts = num_experts
        self._gates = gates
        self._expert_indices = expert_indices
        # Per-expert token-row index lists (first dim of the match positions).
        self._expert_masks = []
        for i in range(num_experts):
            self._expert_masks.append((expert_indices == i).nonzero(as_tuple=True)[0])

    def dispatch(self, inp: torch.Tensor) -> List[torch.Tensor]:
        """Split ``inp`` (num_tokens, D) into one input batch per expert."""
        expert_inputs = []
        for mask in self._expert_masks:
            if len(mask) > 0:
                expert_inputs.append(inp[mask])
            else:
                # Keep a correctly-typed empty batch so indices stay aligned.
                expert_inputs.append(
                    torch.empty(0, inp.size(-1), device=inp.device, dtype=inp.dtype)
                )
        return expert_inputs

    def combine(self, expert_outputs: List[torch.Tensor]) -> torch.Tensor:
        """Sum the gate-weighted expert outputs back into token order."""
        output_shape = (self._gates.size(0), expert_outputs[0].size(-1))
        output = torch.zeros(
            output_shape,
            device=self._gates.device,
            dtype=expert_outputs[0].dtype
        )
        for expert_idx, expert_out in enumerate(expert_outputs):
            mask = self._expert_masks[expert_idx]
            if len(mask) > 0:
                weighted_output = expert_out * self._gates[mask, expert_idx].unsqueeze(-1)
                # BUGFIX: index_add_ accumulates correctly when ``mask``
                # contains duplicate token indices (top-k routing); the previous
                # ``output[mask] += ...`` silently dropped repeated entries.
                output.index_add_(0, mask, weighted_output)
        return output

    def expert_to_gates(self) -> List[torch.Tensor]:
        """Return, per expert, the gate values of the rows dispatched to it."""
        gates_per_expert = []
        for expert_idx in range(self.num_experts):
            mask = self._expert_masks[expert_idx]
            if len(mask) > 0:
                gates_per_expert.append(self._gates[mask, expert_idx])
            else:
                gates_per_expert.append(torch.empty(0, device=self._gates.device))
        return gates_per_expert
class MoELayer(nn.Module):
    """Compact top-k MoE layer with softmax gating and a load-balance loss.

    A lighter alternative to ``MixtureOfExperts``: no routing noise, no
    capacity limit, and the gate/aux-loss logic lives inline in ``forward``.

    Args:
        dim: token feature size.
        num_experts: number of ``Expert`` FFNs.
        expert_hidden_dim: expert inner size; derived from ``dim`` (LLaMA-style
            2/3 * 4d, rounded up to a multiple of 256) when None.
        top_k: experts per token.
        dropout: dropout inside each expert.
        capacity_factor: stored for API parity; not enforced by this layer.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int = 8,
        expert_hidden_dim: Optional[int] = None,
        top_k: int = 2,
        dropout: float = 0.0,
        capacity_factor: float = 1.25
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        if expert_hidden_dim is None:
            # LLaMA-style 2/3 * 4d, rounded up to a multiple of 256.
            expert_hidden_dim = int(2 * dim * 4 / 3)
            expert_hidden_dim = 256 * ((expert_hidden_dim + 255) // 256)
        self.experts = nn.ModuleList([
            Expert(dim, expert_hidden_dim, dropout)
            for _ in range(num_experts)
        ])
        self.gate = nn.Linear(dim, num_experts, bias=False)
        nn.init.normal_(self.gate.weight, std=0.02)
        self.capacity_factor = capacity_factor

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` (B, T, D) through the experts.

        Returns:
            Tuple of (output of shape (B, T, D), scalar load-balance loss).
        """
        B, T, D = x.shape
        # BUGFIX: reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, D)
        gates = F.softmax(self.gate(x_flat), dim=-1)
        top_k_gates, top_k_indices = torch.topk(gates, self.top_k, dim=-1)
        # BUGFIX: the selected values are already probabilities, so renormalize
        # by their sum.  The previous second softmax distorted the mixture
        # weights (softmax of probabilities is not a renormalization).
        top_k_gates = top_k_gates / top_k_gates.sum(dim=-1, keepdim=True)
        # Load-balance loss: mean router probability times actual assignment
        # frequency, summed over experts (Switch-Transformer style).
        expert_probs = gates.mean(dim=0)
        expert_counts = F.one_hot(top_k_indices, self.num_experts).float().sum(dim=[0, 1])
        expert_counts = expert_counts / (B * T * self.top_k)
        aux_loss = self.num_experts * torch.sum(expert_probs * expert_counts)
        output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            # Locate every (token, top-k slot) pair routed to this expert.
            token_indices, topk_positions = torch.where(top_k_indices == expert_idx)
            if len(token_indices) == 0:
                continue
            expert_gates = top_k_gates[token_indices, topk_positions]
            expert_output = expert(x_flat[token_indices]) * expert_gates.unsqueeze(-1)
            # index_add_ accumulates contributions from multiple experts.
            output.index_add_(0, token_indices, expert_output)
        return output.view(B, T, D), aux_loss