from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.fog.config import FOGConfig
from src.fog.model_runtime import RMSNorm, RuntimeStructuredAttention, RuntimeStructuredBlock, RuntimeStructuredFFN
from src.fog.model_structured_v2 import (
    LayerGeometryV2,
    StructuredMotifBlockV2,
    build_layer_geometries_v2,
)


class LocalSyntaxBranch(nn.Module):
    """Causal depthwise-separable convolution branch that models local token syntax."""

    def __init__(self, d_model: int, d_local: int, kernel_size: int, dropout: float) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.in_proj = nn.Linear(d_model, d_local)
        self.depthwise = nn.Conv1d(
            in_channels=d_local,
            out_channels=d_local,
            kernel_size=kernel_size,
            groups=d_local,
            bias=True,
        )
        self.pointwise = nn.Conv1d(
            in_channels=d_local,
            out_channels=d_local,
            kernel_size=1,
            bias=True,
        )
        self.out_proj = nn.Linear(d_local, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.silu(self.in_proj(x))
        # (B, T, C) -> (B, C, T) for Conv1d, then left-pad so the convolution stays causal.
        h = h.transpose(1, 2)
        h = F.pad(h, (self.kernel_size - 1, 0))
        h = self.depthwise(h)
        h = F.silu(self.pointwise(h))
        h = h.transpose(1, 2)
        h = self.drop(h)
        return self.out_proj(h)


class CodeAwareStructuredBlock(nn.Module):
    """Pre-norm block: attention, local-syntax branch, and FFN, each behind a learned residual scale."""

    def __init__(self, d_model: int, n_heads: int, geometry: LayerGeometryV2, dropout: float) -> None:
        super().__init__()
        self.geometry = geometry
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.norm_local = RMSNorm(d_model)
        self.attn = RuntimeStructuredAttention(d_model, geometry.d_compare, geometry.d_memory, n_heads)
        self.ffn = RuntimeStructuredFFN(d_model, geometry, dropout)
        d_local = max(32, min(d_model, geometry.d_compare * 2))
        kernel_size = 5 if geometry.stage != "late" else 3
        self.local_branch = LocalSyntaxBranch(
            d_model=d_model,
            d_local=d_local,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        self.drop = nn.Dropout(dropout)
        self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
        self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
        self.local_scale = nn.Parameter(torch.tensor(0.07 if geometry.stage == "early" else 0.05))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn_scale * self.drop(self.attn(self.norm1(x)))
        x = x + self.local_scale * self.drop(self.local_branch(self.norm_local(x)))
        x = x + self.ffn_scale * self.drop(self.ffn(self.norm2(x)))
        return x


class CodeAwareStructuredTransformer(nn.Module):
    """Decoder-style LM built from CodeAwareStructuredBlock layers with tied embedding/head weights."""

    def __init__(self, cfg: FOGConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.layer_geometries = build_layer_geometries_v2(cfg)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [
                CodeAwareStructuredBlock(
                    d_model=cfg.d_model,
                    n_heads=cfg.n_heads,
                    geometry=geometry,
                    dropout=cfg.dropout,
                )
                for geometry in self.layer_geometries
            ]
        )
        self.norm_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Tie the input embedding to the output projection.
        self.tok_emb.weight = self.head.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        loss_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
        _, t = input_ids.shape
        pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.norm_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            if loss_mask is not None:
                # Only positions where loss_mask is set contribute to the loss.
                flat_logits = logits.view(-1, logits.size(-1))
                flat_targets = targets.view(-1)
                flat_mask = loss_mask.view(-1).bool()
                if flat_mask.any():
                    loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
                else:
                    loss = torch.tensor(0.0, device=logits.device)
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        geometry_summary = [
            {
                "stage": g.stage,
                "d_compare": g.d_compare,
                "d_memory": g.d_memory,
                "d_expand": g.d_expand,
                "d_gate": g.d_gate,
            }
            for g in self.layer_geometries
        ]
        return {"logits": logits, "loss": loss, "geometry": geometry_summary}


class LocalSyntaxBranchLight(nn.Module):
    """Lighter variant of LocalSyntaxBranch without the pointwise convolution."""

    def __init__(self, d_model: int, d_local: int, kernel_size: int, dropout: float) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.in_proj = nn.Linear(d_model, d_local)
        self.depthwise = nn.Conv1d(
            in_channels=d_local,
            out_channels=d_local,
            kernel_size=kernel_size,
            groups=d_local,
            bias=True,
        )
        self.out_proj = nn.Linear(d_local, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.silu(self.in_proj(x))
        h = h.transpose(1, 2)
        h = F.pad(h, (self.kernel_size - 1, 0))  # causal left padding
        h = self.depthwise(h)
        h = h.transpose(1, 2)
        h = self.drop(F.silu(h))
        return self.out_proj(h)


class CodeAwareStructuredLightBlock(nn.Module):
    """Like CodeAwareStructuredBlock, but the local branch is enabled only for early-stage layers."""

    def __init__(self, d_model: int, n_heads: int, geometry: LayerGeometryV2, dropout: float) -> None:
        super().__init__()
        self.geometry = geometry
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.norm_local = RMSNorm(d_model)
        self.attn = RuntimeStructuredAttention(d_model, geometry.d_compare, geometry.d_memory, n_heads)
        self.ffn = RuntimeStructuredFFN(d_model, geometry, dropout)
        self.use_local = geometry.stage == "early"
        if self.use_local:
            d_local = max(16, min(d_model // 4, geometry.d_compare))
            self.local_branch = LocalSyntaxBranchLight(
                d_model=d_model,
                d_local=d_local,
                kernel_size=3,
                dropout=dropout,
            )
            self.local_scale = nn.Parameter(torch.tensor(0.035))
        else:
            self.local_branch = None
            self.local_scale = None
        self.drop = nn.Dropout(dropout)
        self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
        self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn_scale * self.drop(self.attn(self.norm1(x)))
        if self.local_branch is not None and self.local_scale is not None:
            x = x + self.local_scale * self.drop(self.local_branch(self.norm_local(x)))
        x = x + self.ffn_scale * self.drop(self.ffn(self.norm2(x)))
        return x


class CodeAwareStructuredLightTransformer(nn.Module):
    """Transformer variant built from CodeAwareStructuredLightBlock layers."""

    def __init__(self, cfg: FOGConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.layer_geometries = build_layer_geometries_v2(cfg)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [
                CodeAwareStructuredLightBlock(
                    d_model=cfg.d_model,
                    n_heads=cfg.n_heads,
                    geometry=geometry,
                    dropout=cfg.dropout,
                )
                for geometry in self.layer_geometries
            ]
        )
        self.norm_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        loss_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
        _, t = input_ids.shape
        pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.norm_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            if loss_mask is not None:
                flat_logits = logits.view(-1, logits.size(-1))
                flat_targets = targets.view(-1)
                flat_mask = loss_mask.view(-1).bool()
                if flat_mask.any():
                    loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
                else:
                    loss = torch.tensor(0.0, device=logits.device)
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        geometry_summary = [
            {
                "stage": g.stage,
                "d_compare": g.d_compare,
                "d_memory": g.d_memory,
                "d_expand": g.d_expand,
                "d_gate": g.d_gate,
            }
            for g in self.layer_geometries
        ]
        return {"logits": logits, "loss": loss, "geometry": geometry_summary}


class RecentCopyBias(nn.Module):
    """Adds a logit bias toward tokens that appeared within the last `window` positions."""

    def __init__(self, d_model: int, vocab_size: int, window: int = 8) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.window = window
        self.copy_gate = nn.Linear(d_model, 1)
        # Learned per-lag preference, initialised to favour more recent tokens.
        self.lag_logits = nn.Parameter(torch.linspace(0.0, -1.5, steps=window))
        self.copy_scale = nn.Parameter(torch.tensor(1.25))

    def forward(self, hidden: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        b, t, _ = hidden.shape
        gate = torch.sigmoid(self.copy_gate(hidden)).squeeze(-1)
        lag_weights = F.softmax(self.lag_logits, dim=0)
        bias = hidden.new_zeros((b, t, self.vocab_size))
        for lag in range(self.window):
            weight = lag_weights[lag]
            if lag == 0:
                # lag 0: bias each position toward its own input token.
                tokens = input_ids
                contrib = gate * weight
                bias.scatter_add_(2, tokens.unsqueeze(-1), contrib.unsqueeze(-1))
            elif lag < t:
                # lag > 0: bias position i toward the token seen at position i - lag.
                tokens = input_ids[:, :-lag]
                contrib = gate[:, lag:] * weight
                bias[:, lag:, :].scatter_add_(2, tokens.unsqueeze(-1), contrib.unsqueeze(-1))
        return self.copy_scale * bias
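

def _recent_copy_bias_example() -> torch.Tensor:
    """Illustrative sketch, not used by the models below: run RecentCopyBias standalone.

    The tensor sizes here are arbitrary assumptions chosen only to show the expected shapes.
    """
    bias_mod = RecentCopyBias(d_model=64, vocab_size=100, window=4)
    hidden = torch.randn(2, 16, 64)
    ids = torch.randint(0, 100, (2, 16))
    # Output has shape (2, 16, 100): each position gets positive bias toward the tokens at
    # the current and the previous window-1 positions, weighted by the per-lag softmax and
    # the per-position sigmoid gate.
    return bias_mod(hidden, ids)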


class CodeAwareCopyTransformer(nn.Module):
    """RuntimeStructuredBlock transformer with a RecentCopyBias term added to the logits."""

    def __init__(self, cfg: FOGConfig, copy_window: int = 8) -> None:
        super().__init__()
        self.cfg = cfg
        self.layer_geometries = build_layer_geometries_v2(cfg)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [
                RuntimeStructuredBlock(
                    d_model=cfg.d_model,
                    n_heads=cfg.n_heads,
                    geometry=geometry,
                    dropout=cfg.dropout,
                )
                for geometry in self.layer_geometries
            ]
        )
        self.norm_f = RMSNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight
        self.copy_bias = RecentCopyBias(cfg.d_model, cfg.vocab_size, window=copy_window)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        loss_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
        _, t = input_ids.shape
        pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.norm_f(x)
        logits = self.head(x)
        logits = logits + self.copy_bias(x, input_ids)
        loss = None
        if targets is not None:
            if loss_mask is not None:
                flat_logits = logits.view(-1, logits.size(-1))
                flat_targets = targets.view(-1)
                flat_mask = loss_mask.view(-1).bool()
                if flat_mask.any():
                    loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
                else:
                    loss = torch.tensor(0.0, device=logits.device)
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        geometry_summary = [
            {
                "stage": g.stage,
                "d_compare": g.d_compare,
                "d_memory": g.d_memory,
                "d_expand": g.d_expand,
                "d_gate": g.d_gate,
            }
            for g in self.layer_geometries
        ]
        return {"logits": logits, "loss": loss, "geometry": geometry_summary}


class StructuredV2CopyTransformer(nn.Module):
    """StructuredMotifBlockV2 transformer with an explicit causal mask and a blended RecentCopyBias."""

    def __init__(self, cfg: FOGConfig, copy_window: int = 8) -> None:
        super().__init__()
        self.cfg = cfg
        self.layer_geometries = build_layer_geometries_v2(cfg)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList(
            [
                StructuredMotifBlockV2(
                    d_model=cfg.d_model,
                    n_heads=cfg.n_heads,
                    geometry=geometry,
                    dropout=cfg.dropout,
                )
                for geometry in self.layer_geometries
            ]
        )
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.head.weight
        self.copy_bias = RecentCopyBias(cfg.d_model, cfg.vocab_size, window=copy_window)
        self.copy_blend = nn.Parameter(torch.tensor(0.85))
        # Lower-triangular causal mask, shaped (1, 1, L, L) for broadcasting over batch and heads.
        self.register_buffer(
            "_causal_mask",
            torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        loss_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor | list[dict[str, int | str]]]:
        _, t = input_ids.shape
        pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        mask = self._causal_mask[:, :, :t, :t]
        for block in self.blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        logits = logits + self.copy_blend * self.copy_bias(x, input_ids)
        loss = None
        if targets is not None:
            if loss_mask is not None:
                flat_logits = logits.view(-1, logits.size(-1))
                flat_targets = targets.view(-1)
                flat_mask = loss_mask.view(-1).bool()
                if flat_mask.any():
                    loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
                else:
                    loss = torch.tensor(0.0, device=logits.device)
            else:
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        geometry_summary = [
            {
                "stage": g.stage,
                "d_compare": g.d_compare,
                "d_memory": g.d_memory,
                "d_expand": g.d_expand,
                "d_gate": g.d_gate,
            }
            for g in self.layer_geometries
        ]
        return {"logits": logits, "loss": loss, "geometry": geometry_summary}