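"""Fog-flow hybrid backbone.

Augments the base TransformerBlock stack with lightweight "fog" adapter
blocks on a task-dependent subset of layers (typically the later ones).
Each adapter pairs a narrow compare (Q/K) width with a wider memory (V)
width and a gated FFN, and contributes a learned, scaled residual update
to the trunk stream.
"""
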
from __future__ import annotations
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.backbone import TransformerBlock
from src.model.config import ModelConfig


@dataclass(frozen=True)
class FogTaskProfile:
name: str
compare_ratio: float
memory_ratio: float
expand_ratio: float
gate_ratio: float
hybrid_start_ratio: float
max_layers: int
adapter_scale: float


@dataclass(frozen=True)
class FogLayerGeometry:
layer_idx: int
stage: str
d_compare: int
d_memory: int
d_expand: int
d_gate: int
residual_scale: float


def _align_to_heads(value: int, n_heads: int) -> int:
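    """Round ``value`` up to the nearest positive multiple of ``n_heads``."""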
aligned = max(n_heads, (value // n_heads) * n_heads)
if aligned < value:
aligned += n_heads
return aligned


def resolve_fog_task_profile(cfg: ModelConfig) -> FogTaskProfile:
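    """Resolve the task preset named by ``cfg.fog_task_profile``.

    Known presets ("stories", "code", "math", "synthetic") trade off the
    compare/memory/expand/gate widths and adapter depth; any other value
    falls back to a "balanced" profile built from the fog_* config fields.
    """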
profile = cfg.fog_task_profile
if profile == "stories":
return FogTaskProfile(
name="stories",
compare_ratio=0.18,
memory_ratio=0.60,
expand_ratio=1.35,
gate_ratio=0.08,
hybrid_start_ratio=0.67,
max_layers=2,
adapter_scale=0.11,
)
if profile == "code":
return FogTaskProfile(
name="code",
compare_ratio=0.34,
memory_ratio=0.82,
expand_ratio=1.70,
gate_ratio=0.12,
hybrid_start_ratio=0.45,
max_layers=3,
adapter_scale=0.16,
)
if profile == "math":
return FogTaskProfile(
name="math",
compare_ratio=0.42,
memory_ratio=0.70,
expand_ratio=1.95,
gate_ratio=0.14,
hybrid_start_ratio=0.45,
max_layers=3,
adapter_scale=0.18,
)
if profile == "synthetic":
return FogTaskProfile(
name="synthetic",
compare_ratio=0.32,
memory_ratio=0.84,
expand_ratio=2.10,
gate_ratio=0.14,
hybrid_start_ratio=0.0,
max_layers=cfg.n_layers,
adapter_scale=0.20,
)
return FogTaskProfile(
name="balanced",
compare_ratio=cfg.fog_compare_ratio,
memory_ratio=cfg.fog_memory_ratio,
expand_ratio=cfg.fog_expand_ratio,
gate_ratio=cfg.fog_gate_ratio,
hybrid_start_ratio=0.55,
max_layers=min(2, cfg.n_layers),
adapter_scale=0.13,
)


def select_fog_adapter_layers(cfg: ModelConfig, profile: FogTaskProfile) -> list[int]:
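    """Choose which trailing layers receive fog adapters.

    Candidates start at ``hybrid_start_ratio`` of the stack depth; if more
    than ``profile.max_layers`` qualify, only the deepest ones are kept.
    """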
start_idx = min(cfg.n_layers - 1, max(0, int(cfg.n_layers * profile.hybrid_start_ratio)))
candidate_layers = list(range(start_idx, cfg.n_layers))
if len(candidate_layers) <= profile.max_layers:
return candidate_layers
return candidate_layers[-profile.max_layers:]


def build_fog_geometries(cfg: ModelConfig) -> list[FogLayerGeometry]:
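    """Derive per-layer adapter widths from the resolved task profile.

    Adapter positions are binned into early/middle/late thirds of the
    adapter span; later bins get slightly wider projections and larger
    residual scales, and the attention widths are aligned to the head count.
    """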
profile = resolve_fog_task_profile(cfg)
adapter_layers = select_fog_adapter_layers(cfg, profile)
geometries: list[FogLayerGeometry] = []
if not adapter_layers:
return geometries
for adapter_pos, layer_idx in enumerate(adapter_layers):
depth = adapter_pos / max(len(adapter_layers) - 1, 1)
if depth < 0.34:
stage = "early"
compare_ratio = profile.compare_ratio * 0.95
memory_ratio = profile.memory_ratio * 0.90
expand_ratio = profile.expand_ratio * 0.90
gate_ratio = profile.gate_ratio * 0.90
residual_scale = profile.adapter_scale * 0.85
elif depth < 0.67:
stage = "middle"
compare_ratio = profile.compare_ratio
memory_ratio = profile.memory_ratio
expand_ratio = profile.expand_ratio
gate_ratio = profile.gate_ratio
residual_scale = profile.adapter_scale
else:
stage = "late"
compare_ratio = profile.compare_ratio * 1.05
memory_ratio = profile.memory_ratio * 1.05
expand_ratio = profile.expand_ratio * 1.10
gate_ratio = profile.gate_ratio * 1.10
residual_scale = profile.adapter_scale * 1.10
geometries.append(
FogLayerGeometry(
layer_idx=layer_idx,
stage=stage,
d_compare=_align_to_heads(max(cfg.n_heads, int(cfg.d_model * compare_ratio)), cfg.n_heads),
d_memory=_align_to_heads(max(cfg.n_heads, int(cfg.d_model * memory_ratio)), cfg.n_heads),
d_expand=max(cfg.d_model, int(cfg.d_model * expand_ratio)),
d_gate=max(4, int(cfg.d_model * gate_ratio)),
residual_scale=residual_scale,
)
)
return geometries


class FogAttention(nn.Module):
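    """Causal self-attention with decoupled projection widths.

    Queries and keys use the narrow ``d_compare`` width, values the wider
    ``d_memory`` width; both must divide evenly across ``n_heads``.
    """
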
def __init__(self, d_model: int, d_compare: int, d_memory: int, n_heads: int, dropout: float) -> None:
super().__init__()
assert d_compare % n_heads == 0
assert d_memory % n_heads == 0
self.n_heads = n_heads
self.compare_head_dim = d_compare // n_heads
self.memory_head_dim = d_memory // n_heads
self.d_memory = d_memory
self.q_proj = nn.Linear(d_model, d_compare)
self.k_proj = nn.Linear(d_model, d_compare)
self.v_proj = nn.Linear(d_model, d_memory)
self.out_proj = nn.Linear(d_memory, d_model)
self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
b, t, _ = x.shape
q = self.q_proj(x).view(b, t, self.n_heads, self.compare_head_dim).transpose(1, 2)
k = self.k_proj(x).view(b, t, self.n_heads, self.compare_head_dim).transpose(1, 2)
v = self.v_proj(x).view(b, t, self.n_heads, self.memory_head_dim).transpose(1, 2)
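        # q, k: (b, heads, t, compare_head_dim); v: (b, heads, t, memory_head_dim)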
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.compare_head_dim)
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = self.attn_dropout(torch.softmax(scores, dim=-1))
y = torch.matmul(attn, v)
y = y.transpose(1, 2).contiguous().view(b, t, self.d_memory)
return self.out_proj(y)


class FogFFN(nn.Module):
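    """Gated feed-forward block: a SiLU expansion modulated by a sigmoid
    gate computed through a low-rank ``d_gate`` bottleneck, then compressed
    back to ``d_model``.
    """
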
def __init__(self, d_model: int, d_expand: int, d_gate: int, dropout: float) -> None:
super().__init__()
self.expand = nn.Linear(d_model, d_expand)
self.gate = nn.Linear(d_model, d_gate)
self.gate_up = nn.Linear(d_gate, d_expand)
self.compress = nn.Linear(d_expand, d_model)
self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
expanded = F.silu(self.expand(x))
gate = torch.sigmoid(self.gate_up(F.silu(self.gate(x))))
return self.compress(self.dropout(expanded * gate))


class FogAdapterBlock(nn.Module):
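    """Pre-norm adapter that returns a residual *update*, not x + update.

    Both branches carry learnable scales initialised to the geometry's
    ``residual_scale``; the caller adds the returned update to the trunk.
    """
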
def __init__(self, cfg: ModelConfig, geometry: FogLayerGeometry) -> None:
super().__init__()
self.geometry = geometry
self.ln1 = nn.LayerNorm(cfg.d_model)
self.ln2 = nn.LayerNorm(cfg.d_model)
self.attn = FogAttention(cfg.d_model, geometry.d_compare, geometry.d_memory, cfg.n_heads, cfg.dropout)
self.ffn = FogFFN(cfg.d_model, geometry.d_expand, geometry.d_gate, cfg.dropout)
self.drop = nn.Dropout(cfg.dropout)
self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
attn_update = self.attn_scale * self.drop(self.attn(self.ln1(x), mask))
ffn_update = self.ffn_scale * self.drop(self.ffn(self.ln2(x + attn_update)))
return attn_update + ffn_update


class FogFlowBackbone(nn.Module):
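    """Transformer trunk with fog adapter blocks spliced into selected layers."""
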
def __init__(self, cfg: ModelConfig) -> None:
super().__init__()
self.cfg = cfg
self.profile = resolve_fog_task_profile(cfg)
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList([TransformerBlock(cfg, i) for i in range(cfg.n_layers)])
self.layer_geometries = build_fog_geometries(cfg)
self.fog_blocks = nn.ModuleDict(
{str(geom.layer_idx): FogAdapterBlock(cfg, geom) for geom in self.layer_geometries}
)
self.fog_layers = [geom.layer_idx for geom in self.layer_geometries]
self.ln_final = nn.LayerNorm(cfg.d_model)
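        # Cached (1, 1, max_seq_len, max_seq_len) lower-triangular causal mask,
        # sliced to the current sequence length in forward().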
self.register_buffer(
"_causal_mask",
torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0),
persistent=False,
)

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor] | list[int] | str]:
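        """Embed, run the trunk, and splice fog adapter updates into selected layers."""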
_, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
layer_outputs = [x]
mask = self._causal_mask[:, :, :t, :t]
for idx, block in enumerate(self.blocks):
x = block(x, layer_outputs)
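            # Add the fog adapter's scaled residual update on adapter layers.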
if str(idx) in self.fog_blocks:
x = x + self.fog_blocks[str(idx)](x, mask)
layer_outputs.append(x)
hidden = self.ln_final(x)
return {
"hidden": hidden,
"layer_outputs": layer_outputs,
"fog_layers": self.fog_layers,
"fog_profile": self.profile.name,
"flow_type": "fog_hybrid",
}
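

# Minimal usage sketch (assumes ModelConfig exposes the fields referenced in
# this module -- vocab_size, d_model, n_heads, n_layers, max_seq_len, dropout,
# fog_task_profile, and the fog_* ratio fields; the exact constructor
# signature is not verified here):
#
#   cfg = ModelConfig(...)                      # e.g. fog_task_profile="code"
#   backbone = FogFlowBackbone(cfg)
#   tokens = torch.randint(0, cfg.vocab_size, (2, 128))
#   out = backbone(tokens)
#   out["hidden"].shape                         # (2, 128, cfg.d_model)
#   out["fog_layers"]                           # layers carrying fog adapters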