|
|
"""# ────────────────────
|
|
|
|
|
|
# `feedforward.py`
|
|
|
|
|
|
Regarding dropout:
|
|
|
|
|
|
- I don't see it applied to the MoE in DeepSeek-V3, [here](https://huggingface.co/deepseek-ai/DeepSeek-R1/blob/main/modeling_deepseek.py).
|
|
|
|
|
|
- I don't see it applied in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L140)
|
|
|
|
|
|
Norms:
|
|
|
|
|
|
* nn.RMSNorm [here](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
|
|
|
|
|
|
## FFN
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from ..models.shared_space_config import SharedSpaceDecoderConfig
|
|
|
|
|
|
|
|
|
def create_norm_layer(hidden_size: int, config: "SharedSpaceDecoderConfig") -> nn.Module:
    """
    Build the normalization layer selected by ``config.norm_type``.

    Args:
        hidden_size: The dimension to normalize over.
        config: Configuration carrying ``norm_type`` plus the matching
            epsilon value (``layer_norm_eps`` or ``rms_norm_eps``).

    Returns:
        An ``nn.LayerNorm`` or ``DeepseekV3RMSNorm`` instance.

    Raises:
        ValueError: If ``config.norm_type`` is not a recognized kind.
    """
    kind = config.norm_type
    if kind == "layernorm":
        return nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
    if kind == "rmsnorm":
        # Defined later in this module; resolved lazily at call time.
        return DeepseekV3RMSNorm(hidden_size, eps=config.rms_norm_eps)
    raise ValueError(f"Unknown norm_type: {config.norm_type}")
|
|
|
|
|
|
|
|
|
|
|
|
class DeepseekV3RMSNorm(nn.Module):
    """RMS normalization with a learned per-feature scale.

    Equivalent to T5LayerNorm: normalizes by the root-mean-square over the
    last dimension (computed in float32 for numerical stability), then
    rescales by a learned weight. No mean subtraction, no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learned scale, initialized to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Added to the mean-square before the inverse sqrt to avoid div-by-zero.
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Upcast so the reduction runs in float32 regardless of input dtype.
        upcast = hidden_states.float()
        mean_square = upcast.pow(2).mean(dim=-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Downcast back before applying the learned scale.
        return self.weight * normalized.to(orig_dtype)
|
|
|
|
|
|
class SubspaceFeedForward(nn.Module):
    """
    Feed-forward block for SharedSpaceDecoder.

    Implements SwiGLU:
        FFN(x) = W_out( SiLU(W_in(x)) * W_gate(x) )

    Supports both dense and decomposed (low-rank) MLP variants.

    Dense:
      - W_in:   Linear(hidden_dim -> intermediate_dim)
      - W_gate: Linear(hidden_dim -> intermediate_dim)
      - W_out:  Linear(intermediate_dim -> hidden_dim)

    Decomposed:
      - W_in_shared:        Linear(hidden_dim -> rank, bias=False)
      - W_in_shared_norm:   norm over the rank dimension
      - W_in:               Linear(rank -> intermediate_dim)
      - W_gate_shared:      Linear(hidden_dim -> rank, bias=False)
      - W_gate_shared_norm: norm over the rank dimension
      - W_gate:             Linear(rank -> intermediate_dim)
      - W_out:              Linear(intermediate_dim -> rank, bias=False)
      - W_out_shared:       Linear(rank -> hidden_dim)

    NOTE(review): forward() returns only the FFN output -- the residual
    connection (and any dropout or post-norm) must be applied by the
    caller. The previous docstring claimed otherwise.
    """

    def __init__(self, config, layer_idx):
        """
        Args:
            config: Configuration providing hidden_size, intermediate_size,
                ffn_decompose, num_dense_layers, and (when decomposed)
                ffn_rank plus the norm settings used by create_norm_layer.
            layer_idx: Index of this layer. The first
                ``config.num_dense_layers`` layers stay dense even when
                decomposition is enabled globally.
        """
        super().__init__()

        # Dense when decomposition is disabled globally, or this layer is
        # one of the leading layers kept dense by config.
        self.is_dense = (not config.ffn_decompose) or (layer_idx < config.num_dense_layers)

        hidden_dim = config.hidden_size
        intermediate_dim = config.intermediate_size

        if self.is_dense:
            self.W_in = nn.Linear(hidden_dim, intermediate_dim)
            self.W_gate = nn.Linear(hidden_dim, intermediate_dim)
            self.W_out = nn.Linear(intermediate_dim, hidden_dim)
        else:
            rank = config.ffn_rank

            # Input path: project into the shared low-rank space, normalize,
            # then expand to the intermediate dimension.
            self.W_in_shared = nn.Linear(hidden_dim, rank, bias=False)
            self.W_in_shared_norm = create_norm_layer(rank, config)
            self.W_in = nn.Linear(rank, intermediate_dim, bias=True)

            # Gate path mirrors the input path.
            self.W_gate_shared = nn.Linear(hidden_dim, rank, bias=False)
            self.W_gate_shared_norm = create_norm_layer(rank, config)
            self.W_gate = nn.Linear(rank, intermediate_dim, bias=True)

            # Output path: contract to rank, then project back to hidden.
            self.W_out = nn.Linear(intermediate_dim, rank, bias=False)
            self.W_out_shared = nn.Linear(rank, hidden_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transform.

        Args:
            x: Hidden states of shape (..., hidden_size).

        Returns:
            Tensor of the same shape as ``x``. The residual connection is
            NOT applied here.
        """
        if self.is_dense:
            x_proj = self.W_in(x)
            gate = self.W_gate(x)
            x = F.silu(x_proj) * gate
            x = self.W_out(x)
        else:
            # Each projection routes through the shared low-rank subspace.
            x_proj = self.W_in(self.W_in_shared_norm(self.W_in_shared(x)))
            gate = self.W_gate(self.W_gate_shared_norm(self.W_gate_shared(x)))
            x = F.silu(x_proj) * gate
            x = self.W_out_shared(self.W_out(x))

        return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|