|
|
""" |
|
|
Feed-Forward Network for SLM. |
|
|
Uses GELU activation (not SwiGLU) for better INT8 quantization. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .config import SLMConfig |
|
|
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    Shapes (with the default config of hidden_size=768, intermediate_size=3072):
    - input:        [batch, seq, 768]
    - intermediate: [batch, seq, 3072]
    - output:       [batch, seq, 768]

    GELU is used instead of SwiGLU deliberately:
    - only 2 matmuls instead of 3
    - quantizes better to INT8
    - supported natively by QNN (no op decomposition)
    - SwiGLU's gains mostly show up above ~1B parameters
    """

    def __init__(self, config: SLMConfig):
        """Build the two projection layers from the model configuration.

        Args:
            config: Model configuration (provides hidden_size,
                intermediate_size, and dropout probability).
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Bias-free expansion to the intermediate width...
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # ...and bias-free contraction back to the model width.
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        # Dropout probability is stored as a plain float; applied functionally
        # in forward() only while training.
        self.dropout = config.dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FFN to a batch of hidden states.

        Args:
            x: Input tensor of shape [batch, seq, hidden_size]

        Returns:
            Tensor of shape [batch, seq, hidden_size]
        """
        # Expand and activate in one step (tanh-approximate GELU, which is the
        # variant with stable INT8 behavior on the target hardware).
        activated = F.gelu(self.up_proj(x), approximate="tanh")

        # Contract back down to the residual-stream width.
        out = self.down_proj(activated)

        # Functional dropout defaults to training=True, so only call it when
        # the module is actually in training mode and dropout is enabled.
        if self.dropout > 0 and self.training:
            out = F.dropout(out, p=self.dropout)

        return out
|
|
|