Vedisasi's picture
Upload folder using huggingface_hub
54c5666 verified
"""
Advanced Mixture of Experts (MoE³) Architecture
Implements hierarchical, consultative experts with domain specialization
"""
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Safe import for torch.compile disable
try:
from torch._dynamo import disable as torchdynamo_disable
except Exception: # pragma: no cover
def torchdynamo_disable(fn):
return fn
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from enum import Enum
import math
import logging
logger = logging.getLogger(__name__)
class ExpertType(Enum):
"""Types of experts in the hierarchy"""
KNOWLEDGE = "knowledge" # Domain knowledge experts
SKILL = "skill" # Task-specific skill experts
META = "meta" # Meta-reasoning experts
SAFETY = "safety" # Safety and alignment experts
@dataclass
class ExpertConfig:
"""Configuration for expert modules"""
num_knowledge_experts: int = 64
num_skill_experts: int = 32
num_meta_experts: int = 16
num_safety_experts: int = 8
expert_capacity: float = 1.25 # Capacity factor for load balancing
expert_dropout: float = 0.1
# OPTIMIZED: Load balancing loss weights for stable 50% max expert usage
# These weights are carefully tuned to maintain balance without causing instability
load_balance_weight: float = 0.01 # Switch Transformers load balancing
z_loss_weight: float = 0.001 # Router logit regularization (prevent extremes)
importance_weight: float = 0.005 # Routing diversity (reduced for stability)
entropy_reg_weight: float = 0.5 # Entropy regularization (gentler than before)
aux_loss_weight: float = 0.01 # Legacy - kept for compatibility
top_k: int = 2 # Number of experts to route to (50% max expert for k=2)
moe_dtype: torch.dtype = torch.float32
expert_parallelism: bool = True
consultative_attention: bool = True
hierarchical_routing: bool = True
class NoisyTopKRouter(nn.Module):
"""Noisy Top-K routing with load balancing"""
def __init__(
self,
hidden_dim: int,
num_experts: int,
top_k: int = 2,
noise_std: float = 1.0,
warmup_steps: int = 100
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_experts = num_experts
self.top_k = top_k
self.noise_std = noise_std
self.warmup_steps = warmup_steps
# Track forward passes for warmup
self.register_buffer('_forward_count', torch.tensor(0, dtype=torch.long))
# Router network - use float32
# Router gate with BALANCED initialization for immediate load balancing
self.gate = nn.Linear(hidden_dim, num_experts, bias=True, dtype=torch.float32)
# CRITICAL: Initialize router to start nearly uniform from step 0
with torch.no_grad():
# ZERO weights initially - purely noise-driven routing at start
self.gate.weight.zero_()
# Zero bias - all experts equally likely before learning
self.gate.bias.zero_()
# Learnable noise parameters
self.noise_linear = nn.Linear(hidden_dim, num_experts, dtype=torch.float32)
# Initialize noise layer with small weights
with torch.no_grad():
self.noise_linear.weight.zero_() # Start with pure random routing
@torchdynamo_disable
def forward(
self,
hidden_states: torch.Tensor,
training: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
"""
COMPLETELY REWRITTEN ROUTER with proven balanced routing
Returns: (dispatch_mask, combine_weights, aux_losses)
"""
batch_size, seq_len, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim) # [B*S, H]
original_dtype = hidden_states.dtype
# Increment forward counter
if training:
self._forward_count += 1
# CRITICAL FIX: Force balanced routing with warmup period
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.amp.autocast(device_type=device_type, enabled=False):
hs_fp32 = hidden_states_flat.to(dtype=torch.float32)
# Calculate warmup factor (1.0 at start, 0.0 after warmup)
warmup_factor = max(0.0, 1.0 - self._forward_count.item() / max(self.warmup_steps, 1))
# AGGRESSIVE WARMUP FIX: Force truly uniform distribution during early steps
if training and warmup_factor > 0.5:
# During first 50 steps: PURE UNIFORM + small random noise
# This guarantees balanced routing from step 0
uniform_scores = torch.ones_like(
torch.zeros(batch_size * seq_len, self.num_experts, device=hs_fp32.device)
) / self.num_experts
# Add tiny random perturbations to break ties
random_noise = torch.randn_like(uniform_scores) * 0.01
scores = uniform_scores + random_noise
scores = F.softmax(scores / 0.1, dim=-1) # Low temp for nearly uniform
else:
# After warmup threshold: use learned routing
raw_logits = self.gate(hs_fp32) # [B*S, E]
if training and warmup_factor > 0:
# Gradual transition phase (steps 50-100)
gumbel_noise = -torch.log(-torch.log(torch.rand_like(raw_logits) + 1e-10) + 1e-10)
temperature = 5.0 * warmup_factor + 1.0 * (1 - warmup_factor)
logits = (raw_logits * (1 - warmup_factor) + gumbel_noise) / temperature
elif training:
# After warmup: normal routing with exploration
gumbel_noise = -torch.log(-torch.log(torch.rand_like(raw_logits) + 1e-10) + 1e-10)
temperature = max(1.0, 8.0 / self.num_experts)
logits = (raw_logits + gumbel_noise * 0.3) / temperature
else:
logits = raw_logits
# Compute softmax probabilities
scores = F.softmax(logits, dim=-1)
# ADAPTIVE: Minimum probability scales with expert count
# For 8 experts: min_prob = 0.05 (5%), for 64 experts: min_prob = 0.006 (0.6%)
min_prob = max(0.001, 0.4 / self.num_experts)
scores = scores * (1 - min_prob * self.num_experts) + min_prob
scores = scores / scores.sum(dim=-1, keepdim=True) # Renormalize
# Top-k selection with actual_top_k
actual_top_k = min(self.top_k, self.num_experts)
top_k_scores, top_k_indices = torch.topk(scores, actual_top_k, dim=-1)
# Renormalize and cast back
top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)
top_k_scores = top_k_scores.to(original_dtype)
# Create dispatch mask [B*S, E, K] - use actual_top_k
dispatch_mask = torch.zeros(
batch_size * seq_len, self.num_experts, actual_top_k,
dtype=torch.bool, device=hidden_states.device
)
# Create combine weights [B*S, E, K] - use actual_top_k
combine_weights = torch.zeros(
batch_size * seq_len, self.num_experts, actual_top_k,
dtype=hidden_states.dtype, device=hidden_states.device
)
# Fill dispatch mask and weights
for k in range(actual_top_k):
expert_idx = top_k_indices[:, k] # [B*S]
dispatch_mask[torch.arange(batch_size * seq_len), expert_idx, k] = True
combine_weights[torch.arange(batch_size * seq_len), expert_idx, k] = top_k_scores[:, k]
# Compute auxiliary losses for load balancing
aux_losses = self._compute_aux_losses(scores, dispatch_mask)
# Reshape back
dispatch_mask = dispatch_mask.view(batch_size, seq_len, self.num_experts, actual_top_k)
combine_weights = combine_weights.view(batch_size, seq_len, self.num_experts, actual_top_k)
return dispatch_mask, combine_weights, aux_losses
def _compute_aux_losses(
self,
scores: torch.Tensor,
dispatch_mask: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
OPTIMIZED: Compute stable load balancing losses
Ensures max_exp stays at 50% for top_k=2
"""
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.amp.autocast(device_type=device_type, enabled=False):
scores_fp32 = scores.to(torch.float32)
dispatch_fp32 = dispatch_mask.float()
# Expert usage based on actual routing decisions
expert_usage = dispatch_fp32.sum(dim=(0, 2)) # [E]
total_tokens = expert_usage.sum() + 1e-10
expert_usage_normalized = expert_usage / total_tokens
# Target: For top_k=2, ideal is 2/num_experts per expert (50% max when k=2)
target_usage = (self.top_k / self.num_experts)
uniform_target = torch.full_like(expert_usage_normalized, target_usage)
# Switch Transformer style load balancing loss (more stable)
# P(expert) * f(expert) where P is gate prob, f is fraction routed
gate_probs = scores_fp32.mean(dim=0) # [E]
fraction_routed = expert_usage_normalized # [E]
# Load loss: minimize variance in expert*fraction product
load_loss = (gate_probs * fraction_routed).sum() * self.num_experts
load_loss = load_loss.clamp(min=0.0, max=10.0) # Prevent explosion
# Importance loss: encourage routing diversity
importance = scores_fp32.mean(dim=0) # [E]
importance_variance = torch.var(importance, unbiased=False)
importance_mean = torch.mean(importance)
importance_loss = importance_variance / (importance_mean ** 2 + 1e-10)
importance_loss = importance_loss.clamp(min=0.0, max=1.0)
# Entropy regularization - CRITICAL for small expert counts
routing_entropy = -torch.sum(scores_fp32 * torch.log(scores_fp32 + 1e-10), dim=-1).mean()
max_entropy = torch.log(torch.tensor(float(self.num_experts), device=scores.device, dtype=torch.float32))
# Normalized entropy (0 to 1, where 1 is perfect balance)
normalized_entropy = routing_entropy / (max_entropy + 1e-10)
# ADAPTIVE: Stronger penalty for smaller expert counts
# Small pools collapse more easily, need stronger regularization
entropy_strength = min(10.0, 20.0 / self.num_experts)
# Loss increases as entropy decreases (encourage high entropy)
entropy_reg_loss = (1.0 - normalized_entropy) * entropy_strength
entropy_reg_loss = entropy_reg_loss.clamp(min=0.0, max=10.0)
# Z-loss: router logit regularization (prevent extreme values)
router_logits_squared = torch.logsumexp(scores_fp32, dim=-1) ** 2
z_loss = router_logits_squared.mean()
z_loss = z_loss.clamp(min=0.0, max=100.0)
# Expert usage entropy (measure balance across experts)
expert_entropy = -torch.sum(
expert_usage_normalized * torch.log(expert_usage_normalized + 1e-10)
)
return {
'load_loss': load_loss.to(scores.dtype),
'importance_loss': importance_loss.to(scores.dtype),
'z_loss': z_loss.to(scores.dtype),
'entropy_reg_loss': entropy_reg_loss.to(scores.dtype),
'expert_usage': expert_usage_normalized,
'routing_entropy': expert_entropy
}
class Expert(nn.Module):
"""Individual expert module"""
def __init__(
self,
hidden_dim: int,
intermediate_dim: int,
dropout: float = 0.1,
activation: str = 'gelu',
specialization: Optional[str] = None
):
super().__init__()
self.specialization = specialization
# Expert MLP
self.fc1 = nn.Linear(hidden_dim, intermediate_dim)
self.fc2 = nn.Linear(intermediate_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
# Activation
self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
# Specialization-specific components
if specialization:
self.specialization_layer = nn.Linear(hidden_dim, hidden_dim)
self.specialization_norm = nn.LayerNorm(hidden_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Expert forward pass"""
# Main expert computation
hidden = self.fc1(hidden_states)
hidden = self.activation(hidden)
hidden = self.dropout(hidden)
output = self.fc2(hidden)
# Apply specialization if present
if self.specialization:
specialized = self.specialization_layer(output)
specialized = self.specialization_norm(specialized)
output = output + specialized
return output
class CrossExpertAttention(nn.Module):
"""Attention mechanism for expert consultation"""
def __init__(
self,
hidden_dim: int,
num_heads: int = 8,
dropout: float = 0.1
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(
self,
expert_outputs: List[torch.Tensor],
weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Allow experts to attend to each other"""
if len(expert_outputs) == 0:
return torch.zeros_like(expert_outputs[0])
# Stack expert outputs [num_experts, batch, seq, hidden]
stacked = torch.stack(expert_outputs, dim=0)
num_experts, batch_size, seq_len, hidden_dim = stacked.shape
# Reshape for attention [batch, num_experts * seq, hidden]
reshaped = stacked.permute(1, 0, 2, 3).reshape(
batch_size, num_experts * seq_len, hidden_dim
)
# Compute attention
Q = self.q_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = self.k_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = self.v_proj(reshaped).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# Attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Apply weights if provided
if weights is not None:
# Expand weights to match attention shape
weights_expanded = weights.unsqueeze(1).expand_as(scores)
scores = scores * weights_expanded
# Attention weights
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention
attn_output = torch.matmul(attn_weights, V)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, -1, hidden_dim)
# Output projection and reshape back
output = self.o_proj(attn_output)
output = output.view(batch_size, num_experts, seq_len, hidden_dim)
output = output.mean(dim=1) # Average over experts
return output
class HierarchicalMoE(nn.Module):
"""Hierarchical Mixture of Experts with multiple levels"""
def __init__(self, config: ExpertConfig, hidden_dim: int, intermediate_dim: int):
super().__init__()
self.config = config
self.hidden_dim = hidden_dim
# Knowledge experts (domain-specific)
self.knowledge_experts = nn.ModuleList([
Expert(
hidden_dim, intermediate_dim,
dropout=config.expert_dropout,
specialization=f"knowledge_{i}"
)
for i in range(config.num_knowledge_experts)
])
# Skill experts (task-specific)
self.skill_experts = nn.ModuleList([
Expert(
hidden_dim, intermediate_dim,
dropout=config.expert_dropout,
specialization=f"skill_{i}"
)
for i in range(config.num_skill_experts)
])
# Meta experts (reasoning and planning)
self.meta_experts = nn.ModuleList([
Expert(
hidden_dim, intermediate_dim,
dropout=config.expert_dropout,
specialization=f"meta_{i}"
)
for i in range(config.num_meta_experts)
])
# Safety experts (alignment and safety)
self.safety_experts = nn.ModuleList([
Expert(
hidden_dim, intermediate_dim,
dropout=config.expert_dropout,
specialization="safety"
)
for i in range(config.num_safety_experts)
])
# Routers for each level with warmup for balanced initialization
self.knowledge_router = NoisyTopKRouter(
hidden_dim, config.num_knowledge_experts, config.top_k, warmup_steps=100
)
self.skill_router = NoisyTopKRouter(
hidden_dim, config.num_skill_experts, config.top_k, warmup_steps=100
)
self.meta_router = NoisyTopKRouter(
hidden_dim, config.num_meta_experts, config.top_k, warmup_steps=100
)
self.safety_router = NoisyTopKRouter(
hidden_dim, config.num_safety_experts, min(config.top_k, config.num_safety_experts), warmup_steps=100
)
# Cross-expert attention for consultation
if config.consultative_attention:
self.cross_expert_attention = CrossExpertAttention(hidden_dim)
# Hierarchical combination
if config.hierarchical_routing:
self.hierarchy_combiner = nn.Sequential(
nn.Linear(hidden_dim * 4, hidden_dim * 2),
nn.ReLU(),
nn.Dropout(config.expert_dropout),
nn.Linear(hidden_dim * 2, hidden_dim)
)
def _route_and_compute(
self,
hidden_states: torch.Tensor,
experts: nn.ModuleList,
router: NoisyTopKRouter,
expert_type: str
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Route to experts and compute output"""
batch_size, seq_len, hidden_dim = hidden_states.shape
# Get routing decisions
dispatch_mask, combine_weights, aux_losses = router(hidden_states, self.training)
# Initialize output
expert_output = torch.zeros_like(hidden_states)
# Process each expert
for expert_idx, expert in enumerate(experts):
# Get tokens routed to this expert
expert_mask = dispatch_mask[:, :, expert_idx, :].any(dim=-1) # [B, S]
if expert_mask.any():
# Get input for this expert
expert_input = hidden_states[expert_mask]
# Compute expert output
expert_result = expert(expert_input.unsqueeze(0)).squeeze(0)
# Get weights for this expert
expert_weights = combine_weights[:, :, expert_idx, :].sum(dim=-1) # [B, S]
expert_weights = expert_weights[expert_mask]
# Weighted accumulation
expert_output[expert_mask] += expert_result * expert_weights.unsqueeze(-1)
# Add auxiliary losses with type prefix
prefixed_losses = {f"{expert_type}_{k}": v for k, v in aux_losses.items()}
return expert_output, prefixed_losses
def forward(
self,
hidden_states: torch.Tensor,
return_all_levels: bool = False
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Forward pass through hierarchical MoE"""
all_aux_losses = {}
level_outputs = {}
# Level 1: Knowledge experts
knowledge_output, knowledge_losses = self._route_and_compute(
hidden_states, self.knowledge_experts, self.knowledge_router, "knowledge"
)
level_outputs['knowledge'] = knowledge_output
all_aux_losses.update(knowledge_losses)
# Level 2: Skill experts (can see knowledge expert output)
skill_input = hidden_states + 0.5 * knowledge_output # Residual connection
skill_output, skill_losses = self._route_and_compute(
skill_input, self.skill_experts, self.skill_router, "skill"
)
level_outputs['skill'] = skill_output
all_aux_losses.update(skill_losses)
# Level 3: Meta experts (can see both knowledge and skill)
meta_input = hidden_states + 0.3 * knowledge_output + 0.3 * skill_output
meta_output, meta_losses = self._route_and_compute(
meta_input, self.meta_experts, self.meta_router, "meta"
)
level_outputs['meta'] = meta_output
all_aux_losses.update(meta_losses)
# Level 4: Safety experts (see all previous levels)
safety_input = hidden_states + 0.2 * (knowledge_output + skill_output + meta_output)
safety_output, safety_losses = self._route_and_compute(
safety_input, self.safety_experts, self.safety_router, "safety"
)
level_outputs['safety'] = safety_output
all_aux_losses.update(safety_losses)
# Consultative attention between expert outputs
if self.config.consultative_attention and hasattr(self, 'cross_expert_attention'):
expert_outputs_list = [knowledge_output, skill_output, meta_output, safety_output]
consulted_output = self.cross_expert_attention(expert_outputs_list)
level_outputs['consulted'] = consulted_output
# Hierarchical combination
if self.config.hierarchical_routing:
combined = torch.cat([
knowledge_output,
skill_output,
meta_output,
safety_output
], dim=-1)
final_output = self.hierarchy_combiner(combined)
else:
# Simple weighted combination
final_output = (
0.3 * knowledge_output +
0.3 * skill_output +
0.2 * meta_output +
0.2 * safety_output
)
# Add residual connection
final_output = hidden_states + final_output
# Calculate expert utilization metrics
expert_utilization = {}
total_routing_entropy = 0.0
for expert_type in ['knowledge', 'skill', 'meta', 'safety']:
if f"{expert_type}_expert_usage" in all_aux_losses:
usage = all_aux_losses[f"{expert_type}_expert_usage"]
entropy = all_aux_losses[f"{expert_type}_routing_entropy"]
# Per-expert utilization percentages
expert_utilization[f"{expert_type}_usage_pct"] = (usage * 100).tolist()
# Load variance (how balanced the routing is)
if usage.numel() > 1:
expert_utilization[f"{expert_type}_load_variance"] = float(torch.var(usage, unbiased=False))
else:
expert_utilization[f"{expert_type}_load_variance"] = 0.0
# Top expert concentration (what % goes to most used expert)
expert_utilization[f"{expert_type}_top_expert_pct"] = float(torch.max(usage) * 100)
# Routing entropy (higher = more balanced)
expert_utilization[f"{expert_type}_entropy"] = float(entropy)
total_routing_entropy += entropy
# Overall MoE health metrics
expert_utilization['total_routing_entropy'] = float(total_routing_entropy)
expert_utilization['avg_routing_entropy'] = float(total_routing_entropy / 4) # 4 expert types
# Prepare info dictionary
info = {
'aux_losses': all_aux_losses,
'level_outputs': level_outputs if return_all_levels else None,
'num_experts_used': {
'knowledge': self.config.num_knowledge_experts,
'skill': self.config.num_skill_experts,
'meta': self.config.num_meta_experts,
'safety': self.config.num_safety_experts
},
'expert_utilization': expert_utilization
}
return final_output, info
class MoELayer(nn.Module):
def __init__(
self,
config: ExpertConfig,
hidden_dim: int,
intermediate_dim: Optional[int] = None
):
super().__init__()
self.config = config
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim or hidden_dim * 4
# Layer normalization
self.input_norm = nn.LayerNorm(hidden_dim)
self.post_moe_norm = nn.LayerNorm(hidden_dim)
# Hierarchical MoE
self.moe = HierarchicalMoE(config, hidden_dim, self.intermediate_dim)
# Dropout
self.dropout = nn.Dropout(config.expert_dropout)
def forward(
self,
hidden_states: torch.Tensor,
return_aux_loss: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward pass through MoE layer"""
# Normalize input
normed_hidden_states = self.input_norm(hidden_states)
# MoE forward pass
moe_output, moe_info = self.moe(normed_hidden_states)
# Store MoE info for later access
self.last_moe_info = moe_info
# Apply dropout
moe_output = self.dropout(moe_output)
# Post normalization
output = self.post_moe_norm(moe_output)
# Calculate auxiliary loss
aux_loss = None
if return_aux_loss and moe_info.get('aux_losses'):
aux_losses = moe_info['aux_losses']
# Combine all auxiliary losses with proper weighting
total_aux_loss = torch.tensor(0.0, device=hidden_states.device)
# Load balancing losses (Switch Transformers approach)
load_balance_weight = getattr(self.config, 'load_balance_weight', 0.01)
z_loss_weight = getattr(self.config, 'z_loss_weight', 0.001)
importance_weight = getattr(self.config, 'importance_weight', 0.01)
for key, loss in aux_losses.items():
if 'load_loss' in key:
# Primary load balancing loss - encourages uniform expert usage
total_aux_loss += load_balance_weight * loss
elif 'z_loss' in key:
# Z-loss - prevents router logits from becoming too large
total_aux_loss += z_loss_weight * loss
elif 'importance_loss' in key:
# Importance loss - encourages diversity in routing decisions
total_aux_loss += importance_weight * loss
elif 'entropy_reg_loss' in key:
# NUCLEAR FIX: Direct entropy regularization
entropy_weight = getattr(self.config, 'entropy_reg_weight', 1.0)
total_aux_loss += entropy_weight * loss
aux_loss = total_aux_loss
return output, aux_loss
def create_moe_layers(
num_layers: int,
config: ExpertConfig,
hidden_dim: int,
sparse_layers: Optional[List[int]] = None
) -> nn.ModuleList:
"""Create MoE layers for transformer"""
if sparse_layers is None:
# Default: every other layer is MoE
sparse_layers = list(range(1, num_layers, 2))
layers = nn.ModuleList()
for layer_idx in range(num_layers):
if layer_idx in sparse_layers:
# MoE layer
layer = MoELayer(config, hidden_dim)
else:
# Regular FFN layer (placeholder - would use actual FFN)
layer = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim),
nn.Dropout(config.expert_dropout)
)
layers.append(layer)
return layers