|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from .config import RippleConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RippleHead(nn.Module):
    """
    Attention head using Decay-Biased (ALiBi-style) attention.

    The "Ripple Field" applies a learnable distance decay bias to the attention
    weights, allowing the model to generalize to sequence lengths beyond training.

    Memory Optimization (RFC-001):
    - Phase 1: SDPA (Scaled Dot Product Attention) which fuses softmax/dropout
    - Phase 2: Sliding Window Attention - limits attention to last `w` tokens

    Memory Complexity:
    - Full attention (window=None): O(T²)
    - Sliding window (window=w): O(T × w) - LINEAR in sequence length!

    Expected savings with window=512: ~90% memory reduction for T>2048
    """

    def __init__(self, config: RippleConfig, head_idx: int = 0):
        super().__init__()
        self.head_size = config.n_embd // config.n_head
        self.key = nn.Linear(config.n_embd, self.head_size, bias=config.bias)
        self.query = nn.Linear(config.n_embd, self.head_size, bias=config.bias)
        self.value = nn.Linear(config.n_embd, self.head_size, bias=config.bias)
        self.dropout_p = config.dropout

        # Optional sliding-window size (RFC-001 Phase 2); None => full attention.
        self.attention_window = getattr(config, 'attention_window', None)

        num_heads = config.n_head

        def get_slopes(n):
            # ALiBi slope schedule: geometric sequence 0.5, 0.5*r, 0.5*r², ...
            # with ratio r = 2^(-8/n), as in Press et al. (2022).
            def get_slopes_power_of_2(n):
                start = 0.5
                ratio = 0.5 ** (8 / n)
                return [start * (ratio ** i) for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            # Non power-of-two head count: take the first n slopes of the
            # next power-of-two schedule (standard ALiBi fallback).
            return get_slopes_power_of_2(2 ** math.ceil(math.log2(n)))[:n]

        initial_decay = get_slopes(num_heads)[head_idx]

        # Learnable per-head decay; torch.abs() is applied at use sites so the
        # sign of the raw parameter is irrelevant.
        self.decay_factor = nn.Parameter(torch.tensor([initial_decay]))

        # BUGFIX: the original cached the *scaled* bias built from
        # `self.decay_factor.item()`. `.item()` detaches the value from the
        # autograd graph, so in the full-attention path no gradient ever
        # reached decay_factor (the chunked path used the tensor and did
        # learn). We now cache only the decay-independent pieces — the clamped
        # distance template and the future-token mask — and rescale with the
        # live parameter on every forward pass.
        self._cached_dist = None       # clamped (<= 0) distance template, attn dtype
        self._cached_future = None     # bool mask of future (disallowed) positions
        self._cached_bias_size = 0     # side length of the cached square tensors
        self._cached_window = None     # window the cache was built for

    def _get_ripple_bias(self, T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """
        Return the additive attention bias with the causal mask fused in.

        RFC-001 Phase 1 & 2 Optimization:
        - Phase 1: the decay-independent distance template is cached and only
          rebuilt when size/device/dtype/window change.
        - Phase 2: when a window is set, the template is [window, window]
          instead of [T, T].

        The returned bias is differentiable w.r.t. `decay_factor` (the live
        parameter scales the cached template each call). Future tokens are
        masked with the dtype's most negative finite value.
        """
        window = self.attention_window
        effective_size = min(T, window) if window else T

        needs_rebuild = (
            self._cached_dist is None or
            self._cached_bias_size < effective_size or
            self._cached_dist.device != device or
            self._cached_dist.dtype != dtype or   # FIX: dtype was not checked before
            self._cached_window != window
        )

        if needs_rebuild:
            # Positions are computed in float32 so large indices stay exact
            # even when the attention dtype is half precision.
            indices = torch.arange(effective_size, device=device, dtype=torch.float32)
            # dist[i, j] = j - i: positive for future keys, negative for past.
            dist = indices.unsqueeze(0) - indices.unsqueeze(1)

            self._cached_dist = dist.clamp(max=0).to(dtype)
            self._cached_future = dist > 0
            self._cached_bias_size = effective_size
            self._cached_window = window

        s = effective_size
        # Rescale with the live parameter so gradients reach decay_factor.
        bias = (self._cached_dist[:s, :s] * torch.abs(self.decay_factor)).to(dtype)
        mask_value = torch.finfo(dtype).min
        return bias.masked_fill(self._cached_future[:s, :s], mask_value)

    def forward(self, x):
        """
        Compute the attention output for input x of shape (B, T, C).

        Returns a tensor of shape (B, T, head_size). Uses chunked
        sliding-window attention when `attention_window` is set and T exceeds
        it; otherwise a single fused SDPA call over the full sequence.
        """
        B, T, C = x.shape
        window = self.attention_window

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if window and T > window:
            # Sliding-window path: process queries in window-sized chunks so
            # each score matrix is at most [window, 2*window - 1] => O(T × w)
            # memory overall.
            outputs = []
            chunk_size = window
            mask_value = torch.finfo(q.dtype).min
            current_decay = torch.abs(self.decay_factor)  # tensor: keeps gradient

            for start in range(0, T, chunk_size):
                end = min(start + chunk_size, T)

                # Keys/values needed so the FIRST query in the chunk can see
                # its full window; later queries' extra history is masked below.
                kv_start = max(0, start - window + 1)

                q_chunk = q[:, start:end, :]
                k_chunk = k[:, kv_start:end, :]
                v_chunk = v[:, kv_start:end, :]

                # Signed distance key_pos - query_pos (positive => future).
                # float32 keeps positions exact under half-precision training.
                q_positions = torch.arange(start, end, device=x.device, dtype=torch.float32)
                k_positions = torch.arange(kv_start, end, device=x.device, dtype=torch.float32)
                dist = k_positions.unsqueeze(0) - q_positions.unsqueeze(1)

                # Linear decay bias on past tokens.
                ripple_bias = (dist.clamp(max=0) * current_decay).to(q.dtype)

                # Causal mask (future tokens) fused into the bias.
                ripple_bias = ripple_bias.masked_fill(dist > 0, mask_value)

                # BUGFIX: enforce the window strictly. Previously queries late
                # in a chunk could attend up to 2*window-1 tokens back because
                # kv_start is shared by the whole chunk, violating the
                # documented "last w tokens" contract.
                ripple_bias = ripple_bias.masked_fill(dist < -(window - 1), mask_value)

                # SDPA expects a head dimension: (B, 1, T, hs).
                y_chunk = F.scaled_dot_product_attention(
                    q_chunk.unsqueeze(1), k_chunk.unsqueeze(1), v_chunk.unsqueeze(1),
                    attn_mask=ripple_bias,
                    dropout_p=self.dropout_p if self.training else 0.0,
                    is_causal=False
                )
                outputs.append(y_chunk.squeeze(1))

            y = torch.cat(outputs, dim=1)
        else:
            # Full attention: one fused SDPA call with the cached (and now
            # differentiable) ripple bias.
            ripple_bias = self._get_ripple_bias(T, x.device, q.dtype)

            y = F.scaled_dot_product_attention(
                q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1),
                attn_mask=ripple_bias,
                dropout_p=self.dropout_p if self.training else 0.0,
                is_causal=False
            )
            y = y.squeeze(1)

        return y
|
|
|
|
|
class RippleMLP(nn.Module):
    """SwiGLU feed-forward block: fc1 projects to a double-width hidden
    layer that is split into value/gate halves; the silu-gated product is
    projected back to the embedding size and dropout is applied."""

    def __init__(self, config: RippleConfig):
        super().__init__()

        # Hidden width ~ (8/3) * n_embd, bumped to the next even number so it
        # splits cleanly into the two gate halves.
        hidden_dim = int(config.n_embd * 8 / 3)
        hidden_dim += hidden_dim % 2

        self.fc1 = nn.Linear(config.n_embd, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim // 2, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Split the wide projection into the value path and the gate path.
        value, gate = self.fc1(x).chunk(2, dim=-1)
        gated = value * F.silu(gate)
        return self.dropout(self.fc2(gated))
|
|
|
|
|
class Block(nn.Module):
    """Transformer block: multi-head ripple attention followed by a SwiGLU
    MLP, each wrapped in a pre-norm residual connection.

    Note: head outputs are concatenated directly (no output projection), so
    the concatenation width must equal n_embd (i.e. n_embd % n_head == 0).
    """

    def __init__(self, config: RippleConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        # Each head gets its index so it receives a distinct ALiBi slope.
        self.heads = nn.ModuleList([RippleHead(config, i) for i in range(config.n_head)])
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.ffwd = RippleMLP(config)

    def forward(self, x):
        # FIX: hoist the loop-invariant LayerNorm out of the comprehension.
        # The original recomputed self.ln1(x) once per head (n_head redundant
        # passes over the identical input); the result is numerically the same.
        normed = self.ln1(x)
        heads_out = torch.cat([h(normed) for h in self.heads], dim=-1)
        x = x + heads_out
        x = x + self.ffwd(self.ln2(x))
        return x
|
|
|
|
|
class RippleGPT(nn.Module):
    """Decoder-only language model built from ripple-attention Blocks.

    Token embeddings (optionally plus learned absolute position embeddings)
    are passed through a stack of Blocks, a final LayerNorm, and a bias-free
    LM head producing per-token vocabulary logits.
    """

    def __init__(self, config: RippleConfig):
        super().__init__()
        self.config = config
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)

        # Absolute position embeddings are optional — the ripple decay bias
        # already encodes relative position.
        if config.use_absolute_pos_emb:
            self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)

        self.blocks = nn.Sequential(*(Block(config) for _ in range(config.n_layer)))
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Gaussian init (std=0.02) for Linear/Embedding weights; zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for token indices idx of shape (B, T).

        loss is cross-entropy against `targets` (same shape as idx) when
        given, otherwise None.
        """
        _, T = idx.shape
        device = idx.device

        x = self.token_embedding_table(idx)

        if self.config.use_absolute_pos_emb:
            positions = torch.arange(T, device=device)
            x = x + self.position_embedding_table(positions)

        x = self.ln_f(self.blocks(x))
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # Flatten batch and time so cross_entropy sees (B*T, vocab) vs (B*T,).
        vocab = logits.size(-1)
        loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))
        return logits, loss

    def get_decay_stats(self):
        """Returns statistics about the learned decay factors across all heads."""
        values = torch.tensor([
            torch.abs(head.decay_factor).item()
            for block in self.blocks
            for head in block.heads
        ])
        return {
            'min': values.min().item(),
            'max': values.max().item(),
            'mean': values.mean().item(),
            'std': values.std().item()
        }

    def get_num_params(self):
        """Total number of parameters in the model."""
        return sum(p.numel() for p in self.parameters())

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        """
        for _ in range(max_new_tokens):
            if self.config.use_absolute_pos_emb:
                # Absolute positions only exist up to block_size — crop context.
                if idx.size(1) > self.config.block_size:
                    idx_cond = idx[:, -self.config.block_size:]
                else:
                    idx_cond = idx
            else:
                # Relative decay bias generalizes to any length; keep it all.
                idx_cond = idx

            logits, _ = self(idx_cond)

            # Focus on the final time step and apply temperature scaling.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Suppress everything below the k-th largest logit per row.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
|
|
|