"""
Advanced Mixture of Experts (MoE³) Architecture

Implements hierarchical, consultative experts with domain specialization.
"""

import logging
import math
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torch._dynamo import disable as torchdynamo_disable
except Exception:
    # Fallback: a no-op decorator when torch._dynamo is unavailable.
    def torchdynamo_disable(fn):
        return fn

logger = logging.getLogger(__name__)


class ExpertType(Enum):
    """Types of experts in the hierarchy."""
    KNOWLEDGE = "knowledge"
    SKILL = "skill"
    META = "meta"
    SAFETY = "safety"


@dataclass
class ExpertConfig:
    """Configuration for expert modules."""
    num_knowledge_experts: int = 64
    num_skill_experts: int = 32
    num_meta_experts: int = 16
    num_safety_experts: int = 8
    expert_capacity: float = 1.25
    expert_dropout: float = 0.1

    # Auxiliary-loss weights.
    load_balance_weight: float = 0.01
    z_loss_weight: float = 0.001
    importance_weight: float = 0.005
    entropy_reg_weight: float = 0.5
    aux_loss_weight: float = 0.01

    # Routing options.
    top_k: int = 2
    moe_dtype: torch.dtype = torch.float32
    expert_parallelism: bool = True
    consultative_attention: bool = True
    hierarchical_routing: bool = True
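

# A minimal configuration sketch (the values below are illustrative
# assumptions, not recommendations from this module):
#
#     config = ExpertConfig(num_knowledge_experts=16, top_k=2,
#                           load_balance_weight=0.01)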


class NoisyTopKRouter(nn.Module):
    """Noisy Top-K routing with load balancing."""

    def __init__(
        self,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        noise_std: float = 1.0,
        warmup_steps: int = 100
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.noise_std = noise_std
        self.warmup_steps = warmup_steps

        # Tracks training steps for the routing warmup schedule.
        self.register_buffer('_forward_count', torch.tensor(0, dtype=torch.long))

        # Gate kept in float32 for numerically stable routing; zero-init so
        # routing starts out uniform and differentiates during warmup.
        self.gate = nn.Linear(hidden_dim, num_experts, bias=True, dtype=torch.float32)
        with torch.no_grad():
            self.gate.weight.zero_()
            self.gate.bias.zero_()

        # Noise projection (currently unused by forward, which applies Gumbel
        # noise directly to the gate logits).
        self.noise_linear = nn.Linear(hidden_dim, num_experts, dtype=torch.float32)
        with torch.no_grad():
            self.noise_linear.weight.zero_()

    @torchdynamo_disable
    def forward(
        self,
        hidden_states: torch.Tensor,
        training: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Noisy top-k routing with a warmup schedule for balanced expert usage.

        Returns: (dispatch_mask, combine_weights, aux_losses)
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        original_dtype = hidden_states.dtype

        if training:
            self._forward_count += 1

        # Routing math runs in float32 regardless of any surrounding autocast.
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.amp.autocast(device_type=device_type, enabled=False):
            hs_fp32 = hidden_states_flat.to(dtype=torch.float32)

            # Decays linearly from 1.0 to 0.0 over the first `warmup_steps` steps.
            warmup_factor = max(0.0, 1.0 - self._forward_count.item() / max(self.warmup_steps, 1))

            if training and warmup_factor > 0.5:
                # Early warmup: route (almost) uniformly so every expert
                # receives gradient signal.
                uniform_scores = torch.full(
                    (batch_size * seq_len, self.num_experts),
                    1.0 / self.num_experts, device=hs_fp32.device
                )
                random_noise = torch.randn_like(uniform_scores) * 0.01
                scores = F.softmax((uniform_scores + random_noise) / 0.1, dim=-1)
            else:
                raw_logits = self.gate(hs_fp32)

                if training and warmup_factor > 0:
                    # Late warmup: blend learned logits with Gumbel noise while
                    # annealing the temperature from 5.0 toward 1.0.
                    gumbel_noise = -torch.log(-torch.log(torch.rand_like(raw_logits) + 1e-10) + 1e-10)
                    temperature = 5.0 * warmup_factor + 1.0 * (1 - warmup_factor)
                    logits = (raw_logits * (1 - warmup_factor) + gumbel_noise) / temperature
                elif training:
                    # Post-warmup: light Gumbel noise for continued exploration.
                    gumbel_noise = -torch.log(-torch.log(torch.rand_like(raw_logits) + 1e-10) + 1e-10)
                    temperature = max(1.0, 8.0 / self.num_experts)
                    logits = (raw_logits + gumbel_noise * 0.3) / temperature
                else:
                    logits = raw_logits

                scores = F.softmax(logits, dim=-1)

            # Floor every expert's probability and renormalize so that no
            # expert can be starved entirely.
            min_prob = max(0.001, 0.4 / self.num_experts)
            scores = scores * (1 - min_prob * self.num_experts) + min_prob
            scores = scores / scores.sum(dim=-1, keepdim=True)

        actual_top_k = min(self.top_k, self.num_experts)
        top_k_scores, top_k_indices = torch.topk(scores, actual_top_k, dim=-1)

        # Renormalize the selected scores so combine weights sum to 1 per token.
        top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)
        top_k_scores = top_k_scores.to(original_dtype)

        dispatch_mask = torch.zeros(
            batch_size * seq_len, self.num_experts, actual_top_k,
            dtype=torch.bool, device=hidden_states.device
        )
        combine_weights = torch.zeros(
            batch_size * seq_len, self.num_experts, actual_top_k,
            dtype=hidden_states.dtype, device=hidden_states.device
        )

        token_ids = torch.arange(batch_size * seq_len, device=hidden_states.device)
        for k in range(actual_top_k):
            expert_idx = top_k_indices[:, k]
            dispatch_mask[token_ids, expert_idx, k] = True
            combine_weights[token_ids, expert_idx, k] = top_k_scores[:, k]

        aux_losses = self._compute_aux_losses(scores, dispatch_mask)

        dispatch_mask = dispatch_mask.view(batch_size, seq_len, self.num_experts, actual_top_k)
        combine_weights = combine_weights.view(batch_size, seq_len, self.num_experts, actual_top_k)

        return dispatch_mask, combine_weights, aux_losses

    def _compute_aux_losses(
        self,
        scores: torch.Tensor,
        dispatch_mask: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Compute stable load-balancing auxiliary losses in float32.

        The terms below encourage uniform expert usage; with top_k=2 and
        balanced routing, no expert should receive more than its
        proportional share of tokens.
        """
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.amp.autocast(device_type=device_type, enabled=False):
            scores_fp32 = scores.to(torch.float32)
            dispatch_fp32 = dispatch_mask.float()

            # Fraction of dispatched tokens that each expert receives.
            expert_usage = dispatch_fp32.sum(dim=(0, 2))
            total_tokens = expert_usage.sum() + 1e-10
            expert_usage_normalized = expert_usage / total_tokens

            # Switch-Transformer-style load loss: product of mean gate
            # probability and routed fraction, scaled by the expert count.
            gate_probs = scores_fp32.mean(dim=0)
            fraction_routed = expert_usage_normalized
            load_loss = (gate_probs * fraction_routed).sum() * self.num_experts
            load_loss = load_loss.clamp(min=0.0, max=10.0)

            # Importance loss: squared coefficient of variation of the mean
            # gate probabilities across experts.
            importance = scores_fp32.mean(dim=0)
            importance_variance = torch.var(importance, unbiased=False)
            importance_mean = torch.mean(importance)
            importance_loss = importance_variance / (importance_mean ** 2 + 1e-10)
            importance_loss = importance_loss.clamp(min=0.0, max=1.0)

            # Entropy regularizer: push per-token routing entropy toward the
            # maximum of log(num_experts).
            routing_entropy = -torch.sum(scores_fp32 * torch.log(scores_fp32 + 1e-10), dim=-1).mean()
            max_entropy = torch.log(torch.tensor(float(self.num_experts), device=scores.device, dtype=torch.float32))
            normalized_entropy = routing_entropy / (max_entropy + 1e-10)

            entropy_strength = min(10.0, 20.0 / self.num_experts)
            entropy_reg_loss = (1.0 - normalized_entropy) * entropy_strength
            entropy_reg_loss = entropy_reg_loss.clamp(min=0.0, max=10.0)

            # Z-loss. Note: the canonical router z-loss penalizes logsumexp of
            # the raw logits; here it is applied to the softmax scores, which
            # makes it a much weaker, near-constant penalty.
            router_logits_squared = torch.logsumexp(scores_fp32, dim=-1) ** 2
            z_loss = router_logits_squared.mean()
            z_loss = z_loss.clamp(min=0.0, max=100.0)

            # Entropy of the aggregate expert-usage distribution (for logging).
            expert_entropy = -torch.sum(
                expert_usage_normalized * torch.log(expert_usage_normalized + 1e-10)
            )

        return {
            'load_loss': load_loss.to(scores.dtype),
            'importance_loss': importance_loss.to(scores.dtype),
            'z_loss': z_loss.to(scores.dtype),
            'entropy_reg_loss': entropy_reg_loss.to(scores.dtype),
            'expert_usage': expert_usage_normalized,
            'routing_entropy': expert_entropy
        }
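

# A minimal usage sketch for the router (illustrative only; the dummy
# dimensions below are assumptions, not part of the module above):
#
#     router = NoisyTopKRouter(hidden_dim=512, num_experts=8, top_k=2)
#     x = torch.randn(4, 16, 512)  # (batch, seq_len, hidden_dim)
#     dispatch, combine, aux = router(x, training=False)
#     # dispatch: (4, 16, 8, 2) bool — which expert/slot each token uses
#     # combine:  (4, 16, 8, 2) — per-token mixing weights, summing to 1
#     # aux['load_loss'], aux['z_loss'], ... — scalar balancing terms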


class Expert(nn.Module):
    """Individual expert module"""

    def __init__(
        self,
        hidden_dim: int,
        intermediate_dim: int,
        dropout: float = 0.1,
        activation: str = 'gelu',
        specialization: Optional[str] = None
    ):
        super().__init__()
        self.specialization = specialization

        self.fc1 = nn.Linear(hidden_dim, intermediate_dim)
        self.fc2 = nn.Linear(intermediate_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()

        if specialization:
            self.specialization_layer = nn.Linear(hidden_dim, hidden_dim)
            self.specialization_norm = nn.LayerNorm(hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expert forward pass"""
        hidden = self.fc1(hidden_states)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        output = self.fc2(hidden)

        if self.specialization:
            specialized = self.specialization_layer(output)
            specialized = self.specialization_norm(specialized)
            output = output + specialized

        return output


class CrossExpertAttention(nn.Module):
    """Attention mechanism for expert consultation"""

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        expert_outputs: List[torch.Tensor],
        weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Allow experts to attend to each other"""
        if len(expert_outputs) == 0:
            raise ValueError("expert_outputs must contain at least one tensor")

        # Stack expert outputs and fold them into one long sequence per batch
        # so tokens can attend across experts.
        stacked = torch.stack(expert_outputs, dim=0)
        num_experts, batch_size, seq_len, hidden_dim = stacked.shape
        reshaped = stacked.permute(1, 0, 2, 3).reshape(
            batch_size, num_experts * seq_len, hidden_dim
        )

        Q = self.q_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Optional pairwise weighting applied before the softmax.
        if weights is not None:
            weights_expanded = weights.unsqueeze(1).expand_as(scores)
            scores = scores * weights_expanded

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, hidden_dim)

        output = self.o_proj(attn_output)
        # Average the consulted representation back over experts.
        output = output.view(batch_size, num_experts, seq_len, hidden_dim)
        output = output.mean(dim=1)

        return output
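

# Shape sketch for the consultation step (illustrative; the sizes are
# assumptions): four expert outputs of shape (batch, seq, hidden) are stacked,
# flattened to a (batch, 4 * seq, hidden) sequence so tokens can attend across
# experts, then averaged back to a single (batch, seq, hidden) tensor:
#
#     attn = CrossExpertAttention(hidden_dim=512, num_heads=8)
#     outs = [torch.randn(2, 16, 512) for _ in range(4)]
#     consulted = attn(outs)  # -> (2, 16, 512)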


class HierarchicalMoE(nn.Module):
    """Hierarchical Mixture of Experts with multiple levels"""

    def __init__(self, config: ExpertConfig, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.config = config
        self.hidden_dim = hidden_dim

        # Level 1: knowledge experts.
        self.knowledge_experts = nn.ModuleList([
            Expert(
                hidden_dim, intermediate_dim,
                dropout=config.expert_dropout,
                specialization=f"knowledge_{i}"
            )
            for i in range(config.num_knowledge_experts)
        ])

        # Level 2: skill experts.
        self.skill_experts = nn.ModuleList([
            Expert(
                hidden_dim, intermediate_dim,
                dropout=config.expert_dropout,
                specialization=f"skill_{i}"
            )
            for i in range(config.num_skill_experts)
        ])

        # Level 3: meta experts.
        self.meta_experts = nn.ModuleList([
            Expert(
                hidden_dim, intermediate_dim,
                dropout=config.expert_dropout,
                specialization=f"meta_{i}"
            )
            for i in range(config.num_meta_experts)
        ])

        # Level 4: safety experts (shared specialization label).
        self.safety_experts = nn.ModuleList([
            Expert(
                hidden_dim, intermediate_dim,
                dropout=config.expert_dropout,
                specialization="safety"
            )
            for _ in range(config.num_safety_experts)
        ])

        # One router per expert level.
        self.knowledge_router = NoisyTopKRouter(
            hidden_dim, config.num_knowledge_experts, config.top_k, warmup_steps=100
        )
        self.skill_router = NoisyTopKRouter(
            hidden_dim, config.num_skill_experts, config.top_k, warmup_steps=100
        )
        self.meta_router = NoisyTopKRouter(
            hidden_dim, config.num_meta_experts, config.top_k, warmup_steps=100
        )
        self.safety_router = NoisyTopKRouter(
            hidden_dim, config.num_safety_experts, min(config.top_k, config.num_safety_experts), warmup_steps=100
        )

        if config.consultative_attention:
            self.cross_expert_attention = CrossExpertAttention(hidden_dim)

        # Learned combiner over the four level outputs.
        if config.hierarchical_routing:
            self.hierarchy_combiner = nn.Sequential(
                nn.Linear(hidden_dim * 4, hidden_dim * 2),
                nn.ReLU(),
                nn.Dropout(config.expert_dropout),
                nn.Linear(hidden_dim * 2, hidden_dim)
            )

    def _route_and_compute(
        self,
        hidden_states: torch.Tensor,
        experts: nn.ModuleList,
        router: NoisyTopKRouter,
        expert_type: str
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Route to experts and compute output"""
        batch_size, seq_len, hidden_dim = hidden_states.shape

        dispatch_mask, combine_weights, aux_losses = router(hidden_states, self.training)

        expert_output = torch.zeros_like(hidden_states)

        for expert_idx, expert in enumerate(experts):
            # Tokens routed to this expert in any top-k slot.
            expert_mask = dispatch_mask[:, :, expert_idx, :].any(dim=-1)

            if expert_mask.any():
                expert_input = hidden_states[expert_mask]
                expert_result = expert(expert_input.unsqueeze(0)).squeeze(0)

                # This expert's combine weight, summed over its top-k slots.
                expert_weights = combine_weights[:, :, expert_idx, :].sum(dim=-1)
                expert_weights = expert_weights[expert_mask]

                expert_output[expert_mask] += expert_result * expert_weights.unsqueeze(-1)

        prefixed_losses = {f"{expert_type}_{k}": v for k, v in aux_losses.items()}

        return expert_output, prefixed_losses

    def forward(
        self,
        hidden_states: torch.Tensor,
        return_all_levels: bool = False
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Forward pass through hierarchical MoE"""
        all_aux_losses = {}
        level_outputs = {}

        # Level 1: knowledge experts see the raw hidden states.
        knowledge_output, knowledge_losses = self._route_and_compute(
            hidden_states, self.knowledge_experts, self.knowledge_router, "knowledge"
        )
        level_outputs['knowledge'] = knowledge_output
        all_aux_losses.update(knowledge_losses)

        # Level 2: skill experts consume the knowledge-enriched states.
        skill_input = hidden_states + 0.5 * knowledge_output
        skill_output, skill_losses = self._route_and_compute(
            skill_input, self.skill_experts, self.skill_router, "skill"
        )
        level_outputs['skill'] = skill_output
        all_aux_losses.update(skill_losses)

        # Level 3: meta experts see both lower levels.
        meta_input = hidden_states + 0.3 * knowledge_output + 0.3 * skill_output
        meta_output, meta_losses = self._route_and_compute(
            meta_input, self.meta_experts, self.meta_router, "meta"
        )
        level_outputs['meta'] = meta_output
        all_aux_losses.update(meta_losses)

        # Level 4: safety experts review the combined signal.
        safety_input = hidden_states + 0.2 * (knowledge_output + skill_output + meta_output)
        safety_output, safety_losses = self._route_and_compute(
            safety_input, self.safety_experts, self.safety_router, "safety"
        )
        level_outputs['safety'] = safety_output
        all_aux_losses.update(safety_losses)

        # Optional cross-expert consultation.
        if self.config.consultative_attention and hasattr(self, 'cross_expert_attention'):
            expert_outputs_list = [knowledge_output, skill_output, meta_output, safety_output]
            consulted_output = self.cross_expert_attention(expert_outputs_list)
            level_outputs['consulted'] = consulted_output

        # Combine the levels: learned combiner or fixed weighted sum.
        if self.config.hierarchical_routing:
            combined = torch.cat([
                knowledge_output,
                skill_output,
                meta_output,
                safety_output
            ], dim=-1)
            final_output = self.hierarchy_combiner(combined)
        else:
            final_output = (
                0.3 * knowledge_output +
                0.3 * skill_output +
                0.2 * meta_output +
                0.2 * safety_output
            )

        # Residual connection.
        final_output = hidden_states + final_output

        # Collect per-level utilization statistics for monitoring.
        expert_utilization = {}
        total_routing_entropy = 0.0

        for expert_type in ['knowledge', 'skill', 'meta', 'safety']:
            if f"{expert_type}_expert_usage" in all_aux_losses:
                usage = all_aux_losses[f"{expert_type}_expert_usage"]
                entropy = all_aux_losses[f"{expert_type}_routing_entropy"]

                expert_utilization[f"{expert_type}_usage_pct"] = (usage * 100).tolist()

                if usage.numel() > 1:
                    expert_utilization[f"{expert_type}_load_variance"] = float(torch.var(usage, unbiased=False))
                else:
                    expert_utilization[f"{expert_type}_load_variance"] = 0.0

                expert_utilization[f"{expert_type}_top_expert_pct"] = float(torch.max(usage) * 100)

                expert_utilization[f"{expert_type}_entropy"] = float(entropy)
                total_routing_entropy += entropy

        expert_utilization['total_routing_entropy'] = float(total_routing_entropy)
        expert_utilization['avg_routing_entropy'] = float(total_routing_entropy / 4)

        info = {
            'aux_losses': all_aux_losses,
            'level_outputs': level_outputs if return_all_levels else None,
            'num_experts_used': {
                'knowledge': self.config.num_knowledge_experts,
                'skill': self.config.num_skill_experts,
                'meta': self.config.num_meta_experts,
                'safety': self.config.num_safety_experts
            },
            'expert_utilization': expert_utilization
        }

        return final_output, info
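

# End-to-end sketch of one hierarchical pass (illustrative; the dimensions
# are assumptions). With hierarchical_routing=True the four level outputs are
# concatenated and mixed by hierarchy_combiner; otherwise they are summed with
# fixed 0.3/0.3/0.2/0.2 weights. Either way the input is added back residually:
#
#     cfg = ExpertConfig(num_knowledge_experts=8, num_skill_experts=4,
#                        num_meta_experts=4, num_safety_experts=2)
#     moe = HierarchicalMoE(cfg, hidden_dim=256, intermediate_dim=1024)
#     y, info = moe(torch.randn(2, 16, 256), return_all_levels=True)
#     # y: (2, 16, 256); info['level_outputs'] holds each level's tensor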


class MoELayer(nn.Module):
    """Transformer-compatible MoE layer: norm, hierarchical MoE, dropout, norm."""

    def __init__(
        self,
        config: ExpertConfig,
        hidden_dim: int,
        intermediate_dim: Optional[int] = None
    ):
        super().__init__()
        self.config = config
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim or hidden_dim * 4

        self.input_norm = nn.LayerNorm(hidden_dim)
        self.post_moe_norm = nn.LayerNorm(hidden_dim)

        self.moe = HierarchicalMoE(config, hidden_dim, self.intermediate_dim)

        self.dropout = nn.Dropout(config.expert_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        return_aux_loss: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass through MoE layer"""
        normed_hidden_states = self.input_norm(hidden_states)

        moe_output, moe_info = self.moe(normed_hidden_states)

        # Stash routing statistics for external logging.
        self.last_moe_info = moe_info

        moe_output = self.dropout(moe_output)
        output = self.post_moe_norm(moe_output)

        # Aggregate the weighted auxiliary losses across all routers.
        aux_loss = None
        if return_aux_loss and moe_info.get('aux_losses'):
            aux_losses = moe_info['aux_losses']
            total_aux_loss = torch.tensor(0.0, device=hidden_states.device)

            # Fallback defaults mirror the ExpertConfig defaults.
            load_balance_weight = getattr(self.config, 'load_balance_weight', 0.01)
            z_loss_weight = getattr(self.config, 'z_loss_weight', 0.001)
            importance_weight = getattr(self.config, 'importance_weight', 0.005)

            for key, loss in aux_losses.items():
                if 'load_loss' in key:
                    total_aux_loss += load_balance_weight * loss
                elif 'z_loss' in key:
                    total_aux_loss += z_loss_weight * loss
                elif 'importance_loss' in key:
                    total_aux_loss += importance_weight * loss
                elif 'entropy_reg_loss' in key:
                    entropy_weight = getattr(self.config, 'entropy_reg_weight', 0.5)
                    total_aux_loss += entropy_weight * loss

            aux_loss = total_aux_loss

        return output, aux_loss
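

# Minimal training-loop sketch showing how the auxiliary loss is meant to be
# folded into the task loss (illustrative; `task_loss_fn`, the dimensions, and
# the expert counts are hypothetical stand-ins):
#
#     layer = MoELayer(ExpertConfig(num_knowledge_experts=8, num_skill_experts=4,
#                                   num_meta_experts=4, num_safety_experts=2),
#                      hidden_dim=256)
#     out, aux = layer(torch.randn(2, 16, 256))
#     loss = task_loss_fn(out) + (aux if aux is not None else 0.0)
#     loss.backward()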


def create_moe_layers(
    num_layers: int,
    config: ExpertConfig,
    hidden_dim: int,
    sparse_layers: Optional[List[int]] = None
) -> nn.ModuleList:
    """Create MoE layers for a transformer, interleaving sparse and dense FFNs."""
    if sparse_layers is None:
        # Default: MoE on every other layer, starting from layer 1.
        sparse_layers = list(range(1, num_layers, 2))

    layers = nn.ModuleList()

    for layer_idx in range(num_layers):
        if layer_idx in sparse_layers:
            layer = MoELayer(config, hidden_dim)
        else:
            # Dense FFN block for non-MoE layers.
            layer = nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.GELU(),
                nn.Linear(hidden_dim * 4, hidden_dim),
                nn.Dropout(config.expert_dropout)
            )

        layers.append(layer)

    return layers
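

if __name__ == "__main__":
    # Smoke test with deliberately small, assumed dimensions; this is an
    # illustrative sanity check, not part of the architecture itself.
    cfg = ExpertConfig(
        num_knowledge_experts=4,
        num_skill_experts=4,
        num_meta_experts=2,
        num_safety_experts=2,
    )
    layer = MoELayer(cfg, hidden_dim=64)
    layer.train()

    x = torch.randn(2, 8, 64)
    out, aux = layer(x)
    assert out.shape == x.shape

    print(f"output shape: {tuple(out.shape)}")
    print(f"aux loss: {None if aux is None else float(aux)}")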