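"""Code-aware variants of the FOG structured transformer.

These models add a local depthwise-convolution "syntax" branch and a
recent-token copy bias on top of the runtime/structured blocks defined in
``model_runtime`` and ``model_structured_v2``.
"""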
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.fog.config import FOGConfig
from src.fog.model_runtime import RMSNorm, RuntimeStructuredAttention, RuntimeStructuredBlock, RuntimeStructuredFFN
from src.fog.model_structured_v2 import (
LayerGeometryV2,
StructuredMotifBlockV2,
build_layer_geometries_v2,
)
class LocalSyntaxBranch(nn.Module):
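    """Causal depthwise-separable convolution branch.

    Projects the stream to ``d_local`` channels, applies a left-padded
    (causal) depthwise convolution followed by a pointwise convolution, and
    projects back to ``d_model``.
    """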
def __init__(self, d_model: int, d_local: int, kernel_size: int, dropout: float) -> None:
super().__init__()
self.kernel_size = kernel_size
self.in_proj = nn.Linear(d_model, d_local)
self.depthwise = nn.Conv1d(
in_channels=d_local,
out_channels=d_local,
kernel_size=kernel_size,
groups=d_local,
bias=True,
)
self.pointwise = nn.Conv1d(
in_channels=d_local,
out_channels=d_local,
kernel_size=1,
bias=True,
)
self.out_proj = nn.Linear(d_local, d_model)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = F.silu(self.in_proj(x))
h = h.transpose(1, 2)
h = F.pad(h, (self.kernel_size - 1, 0))
h = self.depthwise(h)
h = F.silu(self.pointwise(h))
h = h.transpose(1, 2)
h = self.drop(h)
return self.out_proj(h)
class CodeAwareStructuredBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, geometry: LayerGeometryV2, dropout: float) -> None:
super().__init__()
self.geometry = geometry
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
self.norm_local = RMSNorm(d_model)
self.attn = RuntimeStructuredAttention(d_model, geometry.d_compare, geometry.d_memory, n_heads)
self.ffn = RuntimeStructuredFFN(d_model, geometry, dropout)
d_local = max(32, min(d_model, geometry.d_compare * 2))
kernel_size = 5 if geometry.stage != "late" else 3
self.local_branch = LocalSyntaxBranch(
d_model=d_model,
d_local=d_local,
kernel_size=kernel_size,
dropout=dropout,
)
self.drop = nn.Dropout(dropout)
self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
self.local_scale = nn.Parameter(torch.tensor(0.07 if geometry.stage == "early" else 0.05))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn_scale * self.drop(self.attn(self.norm1(x)))
x = x + self.local_scale * self.drop(self.local_branch(self.norm_local(x)))
x = x + self.ffn_scale * self.drop(self.ffn(self.norm2(x)))
return x
class CodeAwareStructuredTransformer(nn.Module):
def __init__(self, cfg: FOGConfig) -> None:
super().__init__()
self.cfg = cfg
self.layer_geometries = build_layer_geometries_v2(cfg)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList(
[
CodeAwareStructuredBlock(
d_model=cfg.d_model,
n_heads=cfg.n_heads,
geometry=geometry,
dropout=cfg.dropout,
)
for geometry in self.layer_geometries
]
)
self.norm_f = RMSNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.tok_emb.weight = self.head.weight
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
_, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
for block in self.blocks:
x = block(x)
x = self.norm_f(x)
logits = self.head(x)
loss = None
if targets is not None:
if loss_mask is not None:
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)
flat_mask = loss_mask.view(-1).bool()
if flat_mask.any():
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
else:
loss = torch.tensor(0.0, device=logits.device)
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
geometry_summary = [
{
"stage": g.stage,
"d_compare": g.d_compare,
"d_memory": g.d_memory,
"d_expand": g.d_expand,
"d_gate": g.d_gate,
}
for g in self.layer_geometries
]
return {"logits": logits, "loss": loss, "geometry": geometry_summary}
class LocalSyntaxBranchLight(nn.Module):
def __init__(self, d_model: int, d_local: int, kernel_size: int, dropout: float) -> None:
super().__init__()
self.kernel_size = kernel_size
self.in_proj = nn.Linear(d_model, d_local)
self.depthwise = nn.Conv1d(
in_channels=d_local,
out_channels=d_local,
kernel_size=kernel_size,
groups=d_local,
bias=True,
)
self.out_proj = nn.Linear(d_local, d_model)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = F.silu(self.in_proj(x))
h = h.transpose(1, 2)
h = F.pad(h, (self.kernel_size - 1, 0))
h = self.depthwise(h)
h = h.transpose(1, 2)
h = self.drop(F.silu(h))
return self.out_proj(h)
class CodeAwareStructuredLightBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, geometry: LayerGeometryV2, dropout: float) -> None:
super().__init__()
self.geometry = geometry
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
self.norm_local = RMSNorm(d_model)
self.attn = RuntimeStructuredAttention(d_model, geometry.d_compare, geometry.d_memory, n_heads)
self.ffn = RuntimeStructuredFFN(d_model, geometry, dropout)
self.use_local = geometry.stage == "early"
if self.use_local:
d_local = max(16, min(d_model // 4, geometry.d_compare))
self.local_branch = LocalSyntaxBranchLight(
d_model=d_model,
d_local=d_local,
kernel_size=3,
dropout=dropout,
)
self.local_scale = nn.Parameter(torch.tensor(0.035))
else:
self.local_branch = None
self.local_scale = None
self.drop = nn.Dropout(dropout)
self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn_scale * self.drop(self.attn(self.norm1(x)))
if self.local_branch is not None and self.local_scale is not None:
x = x + self.local_scale * self.drop(self.local_branch(self.norm_local(x)))
x = x + self.ffn_scale * self.drop(self.ffn(self.norm2(x)))
return x
class CodeAwareStructuredLightTransformer(nn.Module):
def __init__(self, cfg: FOGConfig) -> None:
super().__init__()
self.cfg = cfg
self.layer_geometries = build_layer_geometries_v2(cfg)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList(
[
CodeAwareStructuredLightBlock(
d_model=cfg.d_model,
n_heads=cfg.n_heads,
geometry=geometry,
dropout=cfg.dropout,
)
for geometry in self.layer_geometries
]
)
self.norm_f = RMSNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.tok_emb.weight = self.head.weight
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
_, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
for block in self.blocks:
x = block(x)
x = self.norm_f(x)
logits = self.head(x)
loss = None
if targets is not None:
if loss_mask is not None:
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)
flat_mask = loss_mask.view(-1).bool()
if flat_mask.any():
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
else:
loss = torch.tensor(0.0, device=logits.device)
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
geometry_summary = [
{
"stage": g.stage,
"d_compare": g.d_compare,
"d_memory": g.d_memory,
"d_expand": g.d_expand,
"d_gate": g.d_gate,
}
for g in self.layer_geometries
]
return {"logits": logits, "loss": loss, "geometry": geometry_summary}
class RecentCopyBias(nn.Module):
def __init__(self, d_model: int, vocab_size: int, window: int = 8) -> None:
super().__init__()
self.vocab_size = vocab_size
self.window = window
self.copy_gate = nn.Linear(d_model, 1)
self.lag_logits = nn.Parameter(torch.linspace(0.0, -1.5, steps=window))
self.copy_scale = nn.Parameter(torch.tensor(1.25))
def forward(self, hidden: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
b, t, _ = hidden.shape
gate = torch.sigmoid(self.copy_gate(hidden)).squeeze(-1)
lag_weights = F.softmax(self.lag_logits, dim=0)
bias = hidden.new_zeros((b, t, self.vocab_size))
for lag in range(self.window):
weight = lag_weights[lag]
if lag == 0:
tokens = input_ids
contrib = gate * weight
bias.scatter_add_(2, tokens.unsqueeze(-1), contrib.unsqueeze(-1))
elif lag < t:
tokens = input_ids[:, :-lag]
contrib = gate[:, lag:] * weight
bias[:, lag:, :].scatter_add_(2, tokens.unsqueeze(-1), contrib.unsqueeze(-1))
return self.copy_scale * bias
class CodeAwareCopyTransformer(nn.Module):
def __init__(self, cfg: FOGConfig, copy_window: int = 8) -> None:
super().__init__()
self.cfg = cfg
self.layer_geometries = build_layer_geometries_v2(cfg)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList(
[
RuntimeStructuredBlock(
d_model=cfg.d_model,
n_heads=cfg.n_heads,
geometry=geometry,
dropout=cfg.dropout,
)
for geometry in self.layer_geometries
]
)
self.norm_f = RMSNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.tok_emb.weight = self.head.weight
self.copy_bias = RecentCopyBias(cfg.d_model, cfg.vocab_size, window=copy_window)
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
_, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
for block in self.blocks:
x = block(x)
x = self.norm_f(x)
logits = self.head(x)
logits = logits + self.copy_bias(x, input_ids)
loss = None
if targets is not None:
if loss_mask is not None:
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)
flat_mask = loss_mask.view(-1).bool()
if flat_mask.any():
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
else:
loss = torch.tensor(0.0, device=logits.device)
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
geometry_summary = [
{
"stage": g.stage,
"d_compare": g.d_compare,
"d_memory": g.d_memory,
"d_expand": g.d_expand,
"d_gate": g.d_gate,
}
for g in self.layer_geometries
]
return {"logits": logits, "loss": loss, "geometry": geometry_summary}
class StructuredV2CopyTransformer(nn.Module):
def __init__(self, cfg: FOGConfig, copy_window: int = 8) -> None:
super().__init__()
self.cfg = cfg
self.layer_geometries = build_layer_geometries_v2(cfg)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList(
[
StructuredMotifBlockV2(
d_model=cfg.d_model,
n_heads=cfg.n_heads,
geometry=geometry,
dropout=cfg.dropout,
)
for geometry in self.layer_geometries
]
)
self.ln_f = nn.LayerNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.tok_emb.weight = self.head.weight
self.copy_bias = RecentCopyBias(cfg.d_model, cfg.vocab_size, window=copy_window)
self.copy_blend = nn.Parameter(torch.tensor(0.85))
self.register_buffer(
"_causal_mask",
torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0),
persistent=False,
)
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
_, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
mask = self._causal_mask[:, :, :t, :t]
for block in self.blocks:
x = block(x, mask)
x = self.ln_f(x)
logits = self.head(x)
logits = logits + self.copy_blend * self.copy_bias(x, input_ids)
loss = None
if targets is not None:
if loss_mask is not None:
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)
flat_mask = loss_mask.view(-1).bool()
if flat_mask.any():
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
else:
loss = torch.tensor(0.0, device=logits.device)
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
geometry_summary = [
{
"stage": g.stage,
"d_compare": g.d_compare,
"d_memory": g.d_memory,
"d_expand": g.d_expand,
"d_gate": g.d_gate,
}
for g in self.layer_geometries
]
return {"logits": logits, "loss": loss, "geometry": geometry_summary}