import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    More stable and computationally efficient than LayerNorm.
    Used in LLaMA, PaLM, and Gopher.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit: SwiGLU(x) = W3(SiLU(W1 x) * W2 x).
    A strong activation for transformer-style FFNs, typically outperforming
    GELU/ReLU MLPs at matched parameter counts.
    """
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)   # value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)   # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Gating: SiLU(w1(x)) multiplied elementwise with w2(x)
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = F.silu(x1) * x2
        return self.w3(self.dropout(hidden))


class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation block for vector inputs.
    Lets the model dynamically reweight the dimensions of the embedding
    based on its global content (a channel-attention mechanism).
    """
    def __init__(self, dim: int, reduction: int = 4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(dim // reduction, dim, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Input is already a per-sample vector [B, D], so no spatial pooling
        # ("squeeze") is required; the excitation MLP produces per-channel gates.
        y = self.fc(x)  # [B, D] gates in (0, 1)
        return x * y


class DropPath(nn.Module):
    """Stochastic depth: randomly drops the residual branch per sample during training."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli(keep_prob) draw per sample, broadcast over the remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 0 or 1
        # Rescale by 1/keep_prob so the expected output is unchanged.
        return x.div(keep_prob) * random_tensor


class ModernBlock(nn.Module):
    """
    A Pre-Norm block combining RMSNorm, SwiGLU, and channel attention (SE),
    with optional LayerScale and stochastic depth.
    """
    def __init__(self, dim: int, expand: int = 4, dropout: float = 0.1,
                 layer_scale_init: float = 1e-6, drop_path: float = 0.0):
        super().__init__()
        # 1. Normalization
        self.norm = RMSNorm(dim)
        # 2. Feed-forward (SwiGLU). The hidden size is scaled by 2/3 so the
        #    three-projection SwiGLU matches the parameter count of a standard
        #    two-projection MLP with the same expansion factor.
        self.ffn = SwiGLU(dim, int(dim * expand * 2 / 3), dropout=dropout)
        # 3. Channel attention (global context awareness)
        self.se = SEBlock(dim, reduction=4)
        # 4. Regularization: LayerScale + stochastic depth
        self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale_init) if layer_scale_init > 0 else None
        self.drop_path = DropPath(drop_path)

    def forward(self, x):
        residual = x
        # Pre-Norm architecture: normalize before the transformation
        out = self.norm(x)
        out = self.ffn(out)
        out = self.se(out)  # channel attention
        if self.layer_scale is not None:
            out = out * self.layer_scale
        out = self.drop_path(out)
        return residual + out
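

# Shape sketch (illustrative note, not tied to any particular config): for an
# input x of shape [B, D], a ModernBlock computes
#     x + DropPath(LayerScale(SE(SwiGLU(RMSNorm(x)))))
# so every block preserves the [B, D] shape, which is what lets the network
# below stack its per-block outputs into a trajectory tensor.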


class ModernTrajectoryNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_layers = config.n_layers
        # Config defaults
        dropout = getattr(config, "dropout", 0.1)
        expand = getattr(config, "expand", 4)
        drop_path_rate = getattr(config, "drop_path_rate", 0.1)
        # Input projection into the latent space
        self.input_proj = nn.Sequential(
            RMSNorm(self.d_model),
            nn.Linear(self.d_model, self.d_model)
        )
        # Backbone: the stochastic-depth rate increases linearly with depth,
        # from 0 at the first block up to drop_path_rate at the last.
        self.blocks = nn.ModuleList([
            ModernBlock(
                dim=self.d_model,
                expand=expand,
                dropout=dropout,
                drop_path=drop_path_rate * i / max(self.n_layers - 1, 1)
            ) for i in range(self.n_layers)
        ])
        self.final_norm = RMSNorm(self.d_model)
        # Projector head (SimCLR / CLIP style).
        # Important: keep the full dimension for the final linear probe.
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Linear(self.d_model, self.d_model)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, return_trajectory=False):
        # If a sequence dimension is present ([B, T, D]), mean-pool over time.
        if x.dim() == 3:
            x = x.mean(dim=1)
        x = self.input_proj(x)
        trajectory = []
        for block in self.blocks:
            x = block(x)
            trajectory.append(x)
        x = self.final_norm(x)
        # The residual path to the input runs implicitly through the blocks;
        # the projector head determines the final shift used for trajectory learning.
        output = self.head(x)
        # OPTIONAL: add a denoising / residual connection back to the input
        # output = output + input_tensor_if_saved
        if return_trajectory:
            return output, torch.stack(trajectory, dim=1)  # [B, n_layers, D]
        return output


# Backwards compatibility
HybridMambaAttentionModel = ModernTrajectoryNet
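

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline).
    # `SimpleNamespace` here stands in for whatever config object the caller
    # uses; the model only requires `d_model` and `n_layers`, and the other
    # attributes fall back to the getattr defaults above.
    from types import SimpleNamespace

    cfg = SimpleNamespace(d_model=256, n_layers=4)
    model = ModernTrajectoryNet(cfg)
    dummy = torch.randn(8, 256)  # [B, D]
    out, traj = model(dummy, return_trajectory=True)
    print(out.shape)   # torch.Size([8, 256])
    print(traj.shape)  # torch.Size([8, 4, 256])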