from __future__ import annotations

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.model.backbone import TransformerBlock
from src.model.config import ModelConfig

@dataclass
class FogTaskProfile:
    name: str
    compare_ratio: float
    memory_ratio: float
    expand_ratio: float
    gate_ratio: float
    hybrid_start_ratio: float
    max_layers: int
    adapter_scale: float

@dataclass
class FogLayerGeometry:
    layer_idx: int
    stage: str
    d_compare: int
    d_memory: int
    d_expand: int
    d_gate: int
    residual_scale: float

def _align_to_heads(value: int, n_heads: int) -> int:
    """Round ``value`` up to the nearest multiple of ``n_heads`` (at least ``n_heads``)."""
    aligned = max(n_heads, (value // n_heads) * n_heads)
    if aligned < value:
        aligned += n_heads
    return aligned
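
# Example: _align_to_heads(100, 8) -> 104, _align_to_heads(96, 8) -> 96, and
# _align_to_heads(3, 8) -> 8, so per-head dimensions always divide evenly.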

def resolve_fog_task_profile(cfg: ModelConfig) -> FogTaskProfile:
    profile = cfg.fog_task_profile
    if profile == "stories":
        return FogTaskProfile(
            name="stories",
            compare_ratio=0.18,
            memory_ratio=0.60,
            expand_ratio=1.35,
            gate_ratio=0.08,
            hybrid_start_ratio=0.67,
            max_layers=2,
            adapter_scale=0.11,
        )
    if profile == "code":
        return FogTaskProfile(
            name="code",
            compare_ratio=0.34,
            memory_ratio=0.82,
            expand_ratio=1.70,
            gate_ratio=0.12,
            hybrid_start_ratio=0.45,
            max_layers=3,
            adapter_scale=0.16,
        )
    if profile == "math":
        return FogTaskProfile(
            name="math",
            compare_ratio=0.42,
            memory_ratio=0.70,
            expand_ratio=1.95,
            gate_ratio=0.14,
            hybrid_start_ratio=0.45,
            max_layers=3,
            adapter_scale=0.18,
        )
    if profile == "synthetic":
        return FogTaskProfile(
            name="synthetic",
            compare_ratio=0.32,
            memory_ratio=0.84,
            expand_ratio=2.10,
            gate_ratio=0.14,
            hybrid_start_ratio=0.0,
            max_layers=cfg.n_layers,
            adapter_scale=0.20,
        )
    return FogTaskProfile(
        name="balanced",
        compare_ratio=cfg.fog_compare_ratio,
        memory_ratio=cfg.fog_memory_ratio,
        expand_ratio=cfg.fog_expand_ratio,
        gate_ratio=cfg.fog_gate_ratio,
        hybrid_start_ratio=0.55,
        max_layers=min(2, cfg.n_layers),
        adapter_scale=0.13,
    )

def select_fog_adapter_layers(cfg: ModelConfig, profile: FogTaskProfile) -> list[int]:
    """Pick the trailing block indices that receive fog adapters."""
    start_idx = min(cfg.n_layers - 1, max(0, int(cfg.n_layers * profile.hybrid_start_ratio)))
    candidate_layers = list(range(start_idx, cfg.n_layers))
    if len(candidate_layers) <= profile.max_layers:
        return candidate_layers
    return candidate_layers[-profile.max_layers:]
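
# Example: with n_layers=12 and the "code" profile (hybrid_start_ratio=0.45,
# max_layers=3), start_idx = int(12 * 0.45) = 5, candidates are layers 5..11,
# and the last three, [9, 10, 11], receive adapters.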

def build_fog_geometries(cfg: ModelConfig) -> list[FogLayerGeometry]:
    profile = resolve_fog_task_profile(cfg)
    adapter_layers = select_fog_adapter_layers(cfg, profile)
    geometries: list[FogLayerGeometry] = []
    if not adapter_layers:
        return geometries
    for adapter_pos, layer_idx in enumerate(adapter_layers):
        depth = adapter_pos / max(len(adapter_layers) - 1, 1)
        if depth < 0.34:
            stage = "early"
            compare_ratio = profile.compare_ratio * 0.95
            memory_ratio = profile.memory_ratio * 0.90
            expand_ratio = profile.expand_ratio * 0.90
            gate_ratio = profile.gate_ratio * 0.90
            residual_scale = profile.adapter_scale * 0.85
        elif depth < 0.67:
            stage = "middle"
            compare_ratio = profile.compare_ratio
            memory_ratio = profile.memory_ratio
            expand_ratio = profile.expand_ratio
            gate_ratio = profile.gate_ratio
            residual_scale = profile.adapter_scale
        else:
            stage = "late"
            compare_ratio = profile.compare_ratio * 1.05
            memory_ratio = profile.memory_ratio * 1.05
            expand_ratio = profile.expand_ratio * 1.10
            gate_ratio = profile.gate_ratio * 1.10
            residual_scale = profile.adapter_scale * 1.10
        geometries.append(
            FogLayerGeometry(
                layer_idx=layer_idx,
                stage=stage,
                d_compare=_align_to_heads(max(cfg.n_heads, int(cfg.d_model * compare_ratio)), cfg.n_heads),
                d_memory=_align_to_heads(max(cfg.n_heads, int(cfg.d_model * memory_ratio)), cfg.n_heads),
                d_expand=max(cfg.d_model, int(cfg.d_model * expand_ratio)),
                d_gate=max(4, int(cfg.d_model * gate_ratio)),
                residual_scale=residual_scale,
            )
        )
    return geometries
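
# Example: with d_model=512, n_heads=8, n_layers=12 and the "code" profile,
# the middle adapter (layer 10, depth 0.5) resolves to d_compare=176,
# d_memory=424, d_expand=870, d_gate=61, residual_scale=0.16.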

class FogAttention(nn.Module):
    """Causal attention with decoupled widths: queries/keys compare in a
    ``d_compare``-dimensional space, while values and the output path use ``d_memory``."""

    def __init__(self, d_model: int, d_compare: int, d_memory: int, n_heads: int, dropout: float) -> None:
        super().__init__()
        assert d_compare % n_heads == 0
        assert d_memory % n_heads == 0
        self.n_heads = n_heads
        self.compare_head_dim = d_compare // n_heads
        self.memory_head_dim = d_memory // n_heads
        self.d_memory = d_memory
        self.q_proj = nn.Linear(d_model, d_compare)
        self.k_proj = nn.Linear(d_model, d_compare)
        self.v_proj = nn.Linear(d_model, d_memory)
        self.out_proj = nn.Linear(d_memory, d_model)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        b, t, _ = x.shape
        # (b, t, d) -> (b, n_heads, t, head_dim)
        q = self.q_proj(x).view(b, t, self.n_heads, self.compare_head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.n_heads, self.compare_head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.n_heads, self.memory_head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.compare_head_dim)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.attn_dropout(torch.softmax(scores, dim=-1))
        y = torch.matmul(attn, v)
        y = y.transpose(1, 2).contiguous().view(b, t, self.d_memory)
        return self.out_proj(y)
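
# Minimal shape check (a sketch; the dimensions are arbitrary but must divide
# evenly by n_heads):
#   attn = FogAttention(d_model=512, d_compare=176, d_memory=424, n_heads=8, dropout=0.0)
#   x = torch.randn(2, 16, 512)
#   mask = torch.tril(torch.ones(16, 16, dtype=torch.bool)).view(1, 1, 16, 16)
#   attn(x, mask).shape  # torch.Size([2, 16, 512])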

class FogFFN(nn.Module):
    """Feed-forward with a wide SiLU expansion modulated by a low-rank sigmoid gate."""

    def __init__(self, d_model: int, d_expand: int, d_gate: int, dropout: float) -> None:
        super().__init__()
        self.expand = nn.Linear(d_model, d_expand)
        self.gate = nn.Linear(d_model, d_gate)
        self.gate_up = nn.Linear(d_gate, d_expand)
        self.compress = nn.Linear(d_expand, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = F.silu(self.expand(x))
        # d_model -> d_gate -> d_expand: the gate passes through a narrow bottleneck.
        gate = torch.sigmoid(self.gate_up(F.silu(self.gate(x))))
        return self.compress(self.dropout(expanded * gate))

class FogAdapterBlock(nn.Module):
    def __init__(self, cfg: ModelConfig, geometry: FogLayerGeometry) -> None:
        super().__init__()
        self.geometry = geometry
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.attn = FogAttention(cfg.d_model, geometry.d_compare, geometry.d_memory, cfg.n_heads, cfg.dropout)
        self.ffn = FogFFN(cfg.d_model, geometry.d_expand, geometry.d_gate, cfg.dropout)
        self.drop = nn.Dropout(cfg.dropout)
        self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))
        self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale)))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Returns only the residual update; the caller adds it back onto x.
        attn_update = self.attn_scale * self.drop(self.attn(self.ln1(x), mask))
        ffn_update = self.ffn_scale * self.drop(self.ffn(self.ln2(x + attn_update)))
        return attn_update + ffn_update

class FogFlowBackbone(nn.Module):
    def __init__(self, cfg: ModelConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.profile = resolve_fog_task_profile(cfg)
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg, i) for i in range(cfg.n_layers)])
        self.layer_geometries = build_fog_geometries(cfg)
        self.fog_blocks = nn.ModuleDict(
            {str(geom.layer_idx): FogAdapterBlock(cfg, geom) for geom in self.layer_geometries}
        )
        self.fog_layers = [geom.layer_idx for geom in self.layer_geometries]
        self.ln_final = nn.LayerNorm(cfg.d_model)
        self.register_buffer(
            "_causal_mask",
            torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0),
            persistent=False,
        )

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor] | list[int] | str]:
        _, t = input_ids.shape
        pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))
        layer_outputs = [x]
        mask = self._causal_mask[:, :, :t, :t]
        for idx, block in enumerate(self.blocks):
            x = block(x, layer_outputs)
            if str(idx) in self.fog_blocks:
                # Fog adapters contribute a scaled residual update on selected layers.
                x = x + self.fog_blocks[str(idx)](x, mask)
            layer_outputs.append(x)
        hidden = self.ln_final(x)
        return {
            "hidden": hidden,
            "layer_outputs": layer_outputs,
            "fog_layers": self.fog_layers,
            "fog_profile": self.profile.name,
            "flow_type": "fog_hybrid",
        }
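
# Usage sketch (hypothetical): ModelConfig lives in src.model.config; the
# keyword arguments below are inferred from the fields this module reads and
# may not match the real constructor signature.
#
#   cfg = ModelConfig(
#       vocab_size=32_000, d_model=512, n_heads=8, n_layers=12,
#       max_seq_len=256, dropout=0.1, fog_task_profile="code",
#   )
#   model = FogFlowBackbone(cfg)
#   out = model(torch.randint(0, cfg.vocab_size, (2, 64)))
#   out["hidden"].shape   # (2, 64, 512)
#   out["fog_layers"]     # [9, 10, 11]
#   out["fog_profile"]    # "code"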