# Upload metadata (Hugging Face Hub): author DigitalDaimyo,
# commit message "Upload analysis.py with huggingface_hub", commit 9aaf11e (verified).
"""
Addressed State Attention (ASA) - Analysis Harness
Research implementation with mechanistic intervention capabilities.
For efficient training without interventions, use asm_training.py instead.
Features:
- Slot-mask causal interventions (slot_mask, slot_mask_where, slot_mask_scope)
- Refinement decomposition (orthogonal/parallel gating)
- Per-head geometry logging
- Configurable information storage (info_level, info_cfg)
Checkpoint Compatibility:
All parameter/buffer names match asm_training.py for weight sharing.
Do NOT rename: slot_keys, Wk_write, Wv_write, Wq_read, out_proj,
_alibi_slopes, _alibi_strength_param, _content_read_gamma_raw,
slot_in/slot_q/slot_k/slot_v/slot_out, _slotspace_gate_raw,
rope/rope_slotspace buffers.
Repository: https://github.com/DigitalDaimyo/AddressedStateAttention
Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/tree/main/paper_drafts
"""
import math
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
# Names exported by ``from <module> import *``.
__all__ = [
    "AddressedStateAttention",
    "ASMBlock",
    "ASMLanguageModel",
    "ASMTrainConfig",
    "build_model_from_cfg",
]
# ------------------------------------------------------------------ helpers ---
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
class RotaryEmbedding(nn.Module):
    """Rotary-position cos/sin tables over ``dim`` channels.

    ``inv_freq[j] = base ** (-2j / dim)`` is kept as a non-persistent buffer,
    so it is not written to ``state_dict``. The angle table is laid out as
    ``cat([freqs, freqs], dim=-1)``. The most recent table is memoised for a
    single ``(T, device)`` pair and re-cast to the requested dtype on every
    call.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)
        # Single-entry memo keyed on (sequence length, device).
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return ``(cos, sin)``, each shaped [1, 1, T, dim] and cast to ``dtype``."""
        hit = (
            self._t_cached == T
            and self._cos_cached is not None
            and self._device_cached == device
        )
        if not hit:
            positions = torch.arange(T, device=device, dtype=self.inv_freq.dtype)
            angles = torch.einsum("t,f->tf", positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1)[None, None]
            self._cos_cached = table.cos()
            self._sin_cached = table.sin()
            self._t_cached = T
            self._device_cached = device
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position mixing: ``x * cos + _rotate_half(x) * sin``.

    ``cos``/``sin`` must broadcast against ``x``. ``RotaryEmbedding.get_cos_sin``
    returns them shaped [1, 1, T, dim].
    """
    rotated = _rotate_half(x)
    return x * cos + rotated * sin
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Per-head ALiBi slopes, returned as a 1-D tensor of length ``num_heads``.

    For a power-of-two head count n, slope i is ``2 ** (-8 * (i + 1) / n)``.
    Otherwise the slopes for the largest power of two below n are used. They
    are followed by every other slope of the next power-of-two schedule, until
    n values exist.
    """
    def _geometric(count):
        first = 2.0 ** (-(2.0 ** -(math.log2(count) - 3)))
        return [first * (first ** i) for i in range(count)]

    def _slopes(count):
        if math.log2(count).is_integer():
            return _geometric(count)
        lower = 2 ** math.floor(math.log2(count))
        interleaved = _slopes(2 * lower)[0::2]
        return _geometric(lower) + interleaved[: count - lower]

    return torch.tensor(_slopes(num_heads), device=device, dtype=dtype)
def _inv_softplus(y: torch.Tensor) -> torch.Tensor:
return torch.log(torch.expm1(y))
def phi(x: torch.Tensor) -> torch.Tensor:
    """Non-negative feature map ``elu(x) + 1`` for causal linear attention.

    This is the elu+1 kernel from linear-attention work (Katharopoulos et al.,
    2020). It is used here on the slot-space queries and keys.
    """
    shifted = F.elu(x)
    return shifted + 1.0
# --------------------------------------------------------- main module ---
class AddressedStateAttention(nn.Module):
"""
Addressed State Attention (ASA) — unified research harness.
Core mechanism
--------------
* Prefix-softmax WRITE into K learned slots (streaming, O(T))
* READ routing from tokens → slots (softmax / top-k / external)
* Content-conditioned READ term (gamma-weighted)
* RoPE on write keys (geometry)
* ALiBi bias on write logits (prefix-friendly)
Slot-space refinement
---------------------
* Causal linear attention in a low-dim slot-address coordinate space
* Produces per-token signed weights over slots
* Decoded through the same streaming slot-state basis
* Gated by learnable ``slotspace_gate`` (softplus)
Causal intervention (slot mask)
-------------------------------
* ``slot_mask`` [K] float/bool, 1=keep 0=mask
* ``slot_mask_where`` "read" | "content_read_only" | "slotspace_only"
* ``slot_mask_scope`` "all" | "last_pos_only"
Refine-delta intervention (instance attrs, NO-OP by default)
----------------------------------------------------------------
* ``_intv_mode`` "off" | "delta_par" | "delta_orth" | "orth_gate" | …
* Decomposes refine delta into parallel / orthogonal vs base output
* See User Guide for configuration details.
Refine-geometry logging (NO output change)
------------------------------------------------
* ``_log_refine_geom = True`` enables per-head geometry vectors in info dict.
Info storage
------------
* ``info_level`` "basic" | "logits" | "full"
* ``info_cfg`` dict controlling which tensors to store, downsampling, CPU offload.
"""
# ---------------------------------------------------------------- init ---
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 8,
    num_slots: int = 8,
    dropout: float = 0.1,
    # temperatures / numerics
    read_temperature: float = 1.0,
    write_temperature: float = 1.0,
    state_fp32: bool = True,
    slot_dropout: float = 0.0,
    normalize_k: bool = False,
    # positions (write geometry)
    use_rope_keys: bool = True,
    rope_base: float = 10000.0,
    # write bias (ALiBi)
    use_alibi_write: bool = True,
    alibi_strength_init: float = 0.1,
    learn_alibi_strength: bool = True,
    min_strength: float = 0.0,
    # content-conditioned read term
    use_content_read: bool = True,
    content_read_init: float = -4.0,
    content_read_max_gamma: float = 3.0,
    # slot-space refinement
    use_slotspace_refine: bool = True,
    slotspace_dim: int = 32,
    slotspace_gate_init: float = -4.0,
    slotspace_dropout: float = 0.05,
    slotspace_signed_weights: bool = True,
    # RoPE in slot-space matcher
    use_rope_slotspace: bool = True,
    rope_base_slotspace: float = 100000.0,
    # perf knobs
    write_chunk_size: int = 128,
    slotspace_chunk_size: int = 128,
):
    """Build the ASA layer: slot keys, projections, positional terms and gates.

    Parameters
    ----------
    embed_dim : model width C. Must be divisible by ``num_heads``, and
        ``head_dim = embed_dim // num_heads``.
    num_heads : number of heads H.
    num_slots : number of learned slots K per head. ``slot_keys`` is [H, K, head_dim].
    dropout : dropout applied after ``out_proj`` in ``forward``.
    read_temperature, write_temperature : stored as floats. ``forward``
        divides the read/write logits by them (floored at 1e-6).
    state_fp32 : keep streaming slot state in float32 when the input is not float32.
    slot_dropout : in training, probability of zeroing each (head, slot) key.
    normalize_k : L2-normalise write keys before RoPE.
    use_rope_keys, rope_base : RoPE on write keys (over ``head_dim`` channels).
    use_alibi_write, alibi_strength_init, learn_alibi_strength, min_strength :
        ALiBi bias on write logits. When learned, the strength is
        ``softplus(_alibi_strength_param) + min_strength``, initialised so it
        equals ``alibi_strength_init`` (if that exceeds ``min_strength``).
    use_content_read, content_read_init, content_read_max_gamma : gamma weight
        of the content read term, ``softplus(raw)`` clamped to the max when the
        max is > 0.
    use_slotspace_refine, slotspace_dim, slotspace_gate_init, slotspace_dropout,
    slotspace_signed_weights : slot-space refinement path. Its output gate is
        ``softplus(_slotspace_gate_raw)``.
    use_rope_slotspace, rope_base_slotspace : RoPE inside the slot-space
        matcher. Only active when refinement is enabled, and requires an even
        ``slotspace_dim``.
    write_chunk_size, slotspace_chunk_size : chunk lengths for the streaming
        loops in ``forward``.

    Notes
    -----
    Parameter/buffer names are shared with asm_training.py (see the module
    docstring). ``nn.Module`` enumerates parameters in registration order, so
    do not reorder the registrations below either. Doing so would change
    ``parameters()``/``state_dict()`` ordering, e.g. for optimizer state.
    Validation uses ``assert`` and is skipped under ``python -O``.
    """
    super().__init__()
    assert embed_dim % num_heads == 0
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.num_slots = num_slots
    self.head_dim = embed_dim // num_heads
    self.dropout = nn.Dropout(dropout)
    self.read_temperature = float(read_temperature)
    self.write_temperature = float(write_temperature)
    self.state_fp32 = bool(state_fp32)
    self.slot_dropout = float(slot_dropout)
    self.normalize_k = bool(normalize_k)
    # External routing hook. A callable, or a tensor sliced as [:, :, t0:t1, :]
    # (time on dim 2). When set, it replaces routing_mode in _compute_read_weights.
    self.routing_override = None
    self.use_rope_keys = bool(use_rope_keys)
    self.use_alibi_write = bool(use_alibi_write)
    self.learn_alibi_strength = bool(learn_alibi_strength)
    self.min_strength = float(min_strength)
    self.use_content_read = bool(use_content_read)
    self.content_read_max_gamma = float(content_read_max_gamma)
    self.use_slotspace_refine = bool(use_slotspace_refine)
    self.slotspace_dim = int(slotspace_dim)
    self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
    self.slotspace_signed_weights = bool(slotspace_signed_weights)
    self.write_chunk_size = int(write_chunk_size)
    self.slotspace_chunk_size = int(slotspace_chunk_size)
    # Learned slot keys: [H, K, d]
    # Scaled by 1/sqrt(d) at init.
    self.slot_keys = nn.Parameter(
        torch.randn(num_heads, num_slots, self.head_dim) / math.sqrt(self.head_dim)
    )
    # Projections
    self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
    self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
    self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    # RoPE (write geometry)
    self.rope = RotaryEmbedding(self.head_dim, base=rope_base) if self.use_rope_keys else None
    # ALiBi
    # The slopes buffer is always registered (zeros when ALiBi is off), so the
    # attribute exists either way. It is non-persistent, so it is not in state_dict.
    if self.use_alibi_write:
        self.register_buffer("_alibi_slopes", alibi_slopes(num_heads), persistent=False)
    else:
        self.register_buffer("_alibi_slopes", torch.zeros(num_heads), persistent=False)
    if self.use_alibi_write and self.learn_alibi_strength:
        # Store the raw value so softplus(raw) + min_strength == alibi_strength_init.
        # The clamp keeps the inverse-softplus argument strictly positive.
        init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
        self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
    else:
        self._alibi_strength_param = None
    # Fixed strength used by _alibi_strength when the strength is not learned.
    self.alibi_strength = float(alibi_strength_init)
    # Content read gamma
    if self.use_content_read:
        self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
    else:
        self._content_read_gamma_raw = None
    # Slot-space refinement
    self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)
    if self.use_slotspace_refine:
        # K -> M -> K projections around a causal linear-attention pass in slot space.
        self.slot_in = nn.Linear(num_slots, self.slotspace_dim, bias=False)
        self.slot_q = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_k = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_v = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
        self.slot_out = nn.Linear(self.slotspace_dim, num_slots, bias=False)
        self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
        if self.use_rope_slotspace:
            assert (self.slotspace_dim % 2) == 0, "use_rope_slotspace requires even slotspace_dim"
            self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
        else:
            self.rope_slotspace = None
    else:
        self.slot_in = None
        self.slot_q = self.slot_k = self.slot_v = None
        self.slot_out = None
        self._slotspace_gate_raw = None
        self.rope_slotspace = None
    # ----- intervention defaults (NO-OP) -----
    # Read via getattr in _apply_refine_intervention. "off" returns the refine
    # delta unchanged.
    self._intv_mode: str = "off"
    self._intv_beta: float = 1.0
    self._intv_score_kind: str = "orth_frac"
    self._intv_tau_kind: str = "pctl"
    self._intv_tau: float = 0.15
    self._intv_tau_pctl: float = 75.0
    self._intv_mask_mode: str = "soft"
    self._intv_soft_temp: float = 0.05
    self._intv_par_beta: float = 1.0
    self._intv_head_mask: Optional[torch.Tensor] = None
    self._intv_score_clip_pctl: float = 99.0
    # ----- refine-geometry logging (no compute change) -----
    self._log_refine_geom: bool = False
# -------------------------------------------------------- scalar params ---
def _alibi_strength(self, dtype, device) -> torch.Tensor:
if not (self.use_alibi_write and self.learn_alibi_strength):
return torch.tensor(self.alibi_strength, dtype=dtype, device=device)
return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device)
def _content_read_gamma(self, dtype, device) -> torch.Tensor:
if not self.use_content_read:
return torch.tensor(0.0, dtype=dtype, device=device)
g = F.softplus(self._content_read_gamma_raw)
if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0:
g = g.clamp(max=self.content_read_max_gamma)
return g.to(dtype=dtype, device=device)
def _slotspace_gate(self, dtype, device) -> torch.Tensor:
if not self.use_slotspace_refine:
return torch.tensor(0.0, dtype=dtype, device=device)
return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device)
# --------------------------------------------------------- numerics ---
@staticmethod
def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
diff = s - m
diff = diff.masked_fill(~torch.isfinite(m), float("-inf"))
return torch.exp(diff)
# ------------------------------------------------------ slot mask ---
def _resolve_slot_mask(
self,
slot_mask: Optional[torch.Tensor],
*,
B: int, H: int, L: int, K: int,
device, dtype, scope: str,
) -> Optional[torch.Tensor]:
"""Expand [K] mask → [B,H,L,K]. Falls back to self.slot_mask attr."""
if slot_mask is None:
slot_mask = getattr(self, "slot_mask", None)
if slot_mask is None:
return None
sm = slot_mask.to(device=device, dtype=dtype)
if sm.ndim != 1 or sm.numel() != K:
raise ValueError(f"slot_mask must be shape [K]={K}, got {tuple(sm.shape)}")
sm = sm.view(1, 1, 1, K)
if scope == "all":
return sm.expand(B, H, L, K)
if scope == "last_pos_only":
out = torch.ones((B, H, L, K), device=device, dtype=dtype)
out[:, :, -1:, :] = sm.expand(B, H, 1, K)
return out
raise ValueError(f"Unknown slot_mask_scope={scope!r}")
@staticmethod
def _apply_hard_mask_and_renorm(w: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
w = w * keep.to(w.dtype)
return w / w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
# --------------------------------------------------- info helpers ---
@staticmethod
def default_info_cfg() -> Dict:
"""Return default info_cfg dict. Copy and modify before passing to forward()."""
return dict(
store_read_weights=True,
store_read_logits=True,
store_write_logits=True,
store_slot_state_norm=True,
store_out1=False,
store_delta=False,
store_slot_w=False,
detach_to_cpu=False,
time_stride=1,
batch_stride=1,
)
@staticmethod
def _store_tensor(
t: Optional[torch.Tensor], *, cfg: Dict, kind: str,
) -> Optional[torch.Tensor]:
"""Downsample + detach (+ optional CPU offload)."""
if t is None:
return None
bstride = int(cfg.get("batch_stride", 1))
tstride = int(cfg.get("time_stride", 1))
to_cpu = bool(cfg.get("detach_to_cpu", False))
x = t
if x.dim() >= 1 and bstride > 1:
x = x[::bstride]
if x.dim() == 4 and tstride > 1:
if kind == "bhtk":
x = x[:, :, ::tstride, :]
elif kind == "bhkt":
x = x[:, :, :, ::tstride]
x = x.detach()
if to_cpu:
x = x.to("cpu", non_blocking=True)
return x
# ------------------------------------------------ read-weight routing ---
def _compute_read_weights(
self,
*,
read_logits: torch.Tensor,
read_logits_key: torch.Tensor,
read_logits_content: Optional[torch.Tensor],
routing_mode: str,
routing_topk: int,
read_weights_override: Optional[torch.Tensor],
routing_noise: Optional[str],
routing_noise_scale: float,
rtemp: float,
sm: Optional[torch.Tensor],
slot_mask_where: str,
B: int, H: int, L: int, K: int,
T_total: int,
t0: int, t1: int,
q_read_c: torch.Tensor,
slot_keys: torch.Tensor,
slot_state_t: torch.Tensor,
valid: Optional[torch.Tensor],
state_dtype,
) -> torch.Tensor:
"""Compute read weights for one write-chunk. Handles noise, overrides, masks."""
# routing noise
if routing_noise is not None:
if routing_noise == "gumbel":
u = torch.rand_like(read_logits)
g = -torch.log(-torch.log(u.clamp_min(1e-8)).clamp_min(1e-8))
read_logits = read_logits + routing_noise_scale * g
elif routing_noise == "gaussian":
read_logits = read_logits + routing_noise_scale * torch.randn_like(read_logits)
else:
raise ValueError(f"Unknown routing_noise={routing_noise}")
# routing override (external callable or tensor)
if self.routing_override is not None:
if callable(self.routing_override):
ctx = dict(
t0=t0, t1=t1, B=B, H=H, T=T_total, K=K, d=self.head_dim,
rtemp=rtemp, state_dtype=state_dtype,
q_read_c=q_read_c, slot_keys=slot_keys,
slot_state_t=slot_state_t, valid=valid,
)
read_w = self.routing_override(
t0, t1, read_logits, read_logits_key, read_logits_content, ctx,
)
else:
read_w = self.routing_override[:, :, t0:t1, :].to(read_logits.dtype)
read_w = torch.nan_to_num(read_w, nan=0.0, posinf=0.0, neginf=0.0)
read_w = read_w.clamp_min(0.0)
read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
else:
if routing_mode == "softmax":
read_w = torch.softmax(read_logits / rtemp, dim=-1)
elif routing_mode == "top1":
top = read_logits.argmax(dim=-1)
read_w = F.one_hot(top, num_classes=K).to(read_logits.dtype)
elif routing_mode == "topk":
kk = max(1, min(K, int(routing_topk)))
vals, idx = torch.topk(read_logits, k=kk, dim=-1)
masked = torch.full_like(read_logits, float("-inf"))
masked.scatter_(-1, idx, vals)
read_w = torch.softmax(masked / rtemp, dim=-1)
elif routing_mode == "external":
if read_weights_override is None:
raise ValueError("routing_mode='external' requires read_weights_override")
if read_weights_override.shape[-2] == T_total:
read_w = read_weights_override[:, :, t0:t1, :]
else:
read_w = read_weights_override
read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
else:
raise ValueError(f"Unknown routing_mode={routing_mode}")
# slot mask at read stage
if slot_mask_where == "read" and sm is not None:
read_w = self._apply_hard_mask_and_renorm(read_w, (sm > 0.0))
return read_w
# ------------------------------------------- refine-delta intervention ---
def _apply_refine_intervention(
self,
out1: torch.Tensor,
delta: torch.Tensor,
slot_w: Optional[torch.Tensor],
):
"""Decompose refine delta into par/orth vs base output, optionally gate."""
eps = 1e-8
B, H, L, d = out1.shape
# head mask
hm = getattr(self, "_intv_head_mask", None)
if hm is not None:
hm = hm.to(device=out1.device).view(1, H, 1, 1).to(dtype=out1.dtype)
out1_norm2 = (out1 * out1).sum(dim=-1, keepdim=True).clamp_min(eps)
alpha = (delta * out1).sum(dim=-1, keepdim=True) / out1_norm2
delta_par = alpha * out1
delta_orth = delta - delta_par
logs = None
# geometry logging (no output change)
if getattr(self, "_log_refine_geom", False):
out1n = out1.norm(dim=-1).clamp_min(eps)
dn = delta.norm(dim=-1).clamp_min(eps)
dparn = delta_par.norm(dim=-1)
dorthn = delta_orth.norm(dim=-1)
a = alpha.squeeze(-1)
logs = dict(
geom_alpha_mean=a.mean(dim=(0, 2)),
geom_alpha_abs=a.abs().mean(dim=(0, 2)),
geom_sign_pos=(a > 0).float().mean(dim=(0, 2)),
geom_orth_frac=(dorthn / dn).mean(dim=(0, 2)),
geom_d_ratio=(dn / out1n).mean(dim=(0, 2)),
geom_dpar_ratio=(dparn / dn).mean(dim=(0, 2)),
)
mode = getattr(self, "_intv_mode", "off")
if mode is None or mode == "off":
return delta, logs
# --- intervention modes ---
if mode == "delta_par":
delta_mod = delta_par
logs = logs or {}
logs["alpha"] = alpha.squeeze(-1)
elif mode == "delta_orth":
delta_mod = delta_orth
logs = logs or {}
logs["alpha"] = alpha.squeeze(-1)
elif mode == "delta_par_plus_orth":
delta_mod = delta_par + delta_orth
logs = logs or {}
logs["alpha"] = alpha.squeeze(-1)
elif mode == "orth_gate":
beta = float(getattr(self, "_intv_beta", 1.0))
sk = getattr(self, "_intv_score_kind", "orth_frac")
out1n = out1.norm(dim=-1).clamp_min(eps)
dorthn = delta_orth.norm(dim=-1)
dn = delta.norm(dim=-1).clamp_min(eps)
if sk == "orth_ratio":
score = dorthn / out1n
elif sk == "orth_frac":
score = dorthn / dn
elif sk == "alpha_abs":
score = alpha.abs().squeeze(-1)
elif sk == "slot_peaked":
if slot_w is None:
raise ValueError("score_kind='slot_peaked' requires slot_w")
p = torch.softmax(slot_w.float(), dim=-1).clamp_min(1e-8)
Hrw = -(p * p.log()).sum(dim=-1)
K = p.shape[-1]
score = (1.0 - Hrw / max(1e-8, math.log(K))).to(dtype=out1.dtype)
else:
raise ValueError(f"Unknown _intv_score_kind={sk}")
# score clipping
clip_p = getattr(self, "_intv_score_clip_pctl", None)
if clip_p is not None:
clip_p = float(clip_p)
if 0.0 < clip_p < 100.0:
smax = torch.quantile(score.detach().flatten(), clip_p / 100.0).to(score.dtype)
score = torch.clamp(score, max=smax)
# tau
tk = getattr(self, "_intv_tau_kind", "pctl")
if tk == "abs":
tau = torch.tensor(float(getattr(self, "_intv_tau", 0.15)),
device=score.device, dtype=score.dtype)
elif tk == "pctl":
tau = torch.quantile(
score.detach().flatten(),
float(getattr(self, "_intv_tau_pctl", 75.0)) / 100.0,
).to(score.dtype)
else:
raise ValueError(f"Unknown _intv_tau_kind={tk}")
# mask
mm = getattr(self, "_intv_mask_mode", "soft")
if mm == "hard":
mask = (score > tau).to(out1.dtype)
elif mm == "soft":
temp = max(1e-6, float(getattr(self, "_intv_soft_temp", 0.05)))
mask = torch.sigmoid((score - tau) / temp).to(out1.dtype)
else:
raise ValueError(f"Unknown _intv_mask_mode={mm}")
par_beta = float(getattr(self, "_intv_par_beta", 1.0))
delta_mod = par_beta * delta_par + beta * mask.unsqueeze(-1) * delta_orth
logs = logs or {}
logs.update(dict(
score=score, tau=tau, mask=mask,
alpha=alpha.squeeze(-1),
out1_norm=out1n,
dpar_norm=delta_par.norm(dim=-1),
dorth_norm=dorthn,
))
else:
raise ValueError(f"Unknown _intv_mode={mode}")
# head targeting
if hm is not None:
delta_mod = hm * delta_mod + (1.0 - hm) * delta
logs = logs or {}
logs["head_mask"] = hm.squeeze(0).squeeze(-1).squeeze(-1).detach()
return delta_mod, logs
# ============================================================ forward ===
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_info: bool = False,
# routing
routing_mode: str = "softmax",
routing_topk: int = 2,
read_weights_override: Optional[torch.Tensor] = None,
routing_noise: Optional[str] = None,
routing_noise_scale: float = 1.0,
# slot mask (causal intervention)
slot_mask: Optional[torch.Tensor] = None,
slot_mask_where: str = "read",
slot_mask_scope: str = "all",
# info controls
info_level: str = "full",
info_cfg: Optional[Dict] = None,
) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
"""
Parameters
----------
x : [B, T, C]
attention_mask : [B, T] optional padding mask (1=valid, 0=pad)
return_info : if True, return diagnostics dict as second element
routing_mode : "softmax" | "top1" | "topk" | "external"
routing_topk : k for topk mode
read_weights_override : [B,H,T,K] or [B,H,L,K] for external routing
routing_noise : None | "gumbel" | "gaussian"
routing_noise_scale : scale for routing noise
slot_mask : [K] where 1=keep, 0=mask
slot_mask_where : "read" | "content_read_only" | "slotspace_only"
slot_mask_scope : "all" | "last_pos_only"
info_level : "basic" | "logits" | "full"
info_cfg : dict (see default_info_cfg())
Returns
-------
(output, info) where info is None if return_info=False.
"""
B, T, C = x.shape
H, K, d = self.num_heads, self.num_slots, self.head_dim
# ---- resolve info config ----
if info_cfg is None:
info_cfg = self.default_info_cfg()
store_read_weights = bool(info_cfg.get("store_read_weights", True))
store_read_logits = bool(info_cfg.get("store_read_logits", True)) and info_level in ("logits", "full")
store_write_logits = bool(info_cfg.get("store_write_logits", True)) and info_level == "full"
store_slot_norm = bool(info_cfg.get("store_slot_state_norm", True)) and info_level == "full"
store_out1 = bool(info_cfg.get("store_out1", False)) and return_info
store_delta = bool(info_cfg.get("store_delta", False)) and return_info
store_slot_w = bool(info_cfg.get("store_slot_w", False)) and return_info
# ---- projections ----
k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2)
v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2)
q_read = self.Wq_read(x).view(B, T, H, d).transpose(1, 2)
if self.normalize_k:
k_write = F.normalize(k_write, dim=-1, eps=1e-8)
if self.use_rope_keys:
cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype)
k_write = apply_rope(k_write, cos, sin)
# slot dropout
slot_keys = self.slot_keys
if self.training and self.slot_dropout > 0.0:
drop = (torch.rand((H, K), device=x.device) < self.slot_dropout)
slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1)
# ---- WRITE logits ----
write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d)
state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype
write_logits = write_logits_raw.to(state_dtype) / max(1e-6, self.write_temperature)
# ALiBi
alibi_bias_applied = None
if self.use_alibi_write:
strength = self._alibi_strength(dtype=state_dtype, device=x.device)
slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength
pos_i = torch.arange(T, device=x.device, dtype=state_dtype)
alibi_bias = slopes.view(1, H, 1, 1) * pos_i.view(1, 1, 1, T)
write_logits = write_logits + alibi_bias
alibi_bias_applied = alibi_bias
# padding mask
if attention_mask is not None:
valid = attention_mask.to(dtype=torch.bool)
write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf"))
else:
valid = None
# ================================================================
# STREAMING WRITE + READ
# ================================================================
content_read_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device)
rtemp = max(1e-6, self.read_temperature)
out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype)
out1_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_out1 else None
delta_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_delta else None
slot_w_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if store_slot_w else None
need_rw = bool(self.use_slotspace_refine) or (return_info and store_read_weights)
read_weights = torch.empty((B, H, T, K), device=x.device, dtype=q_read.dtype) if need_rw else None
slot_state_norm_t = (
torch.empty((B, H, T, K), device=x.device, dtype=torch.float32)
if (return_info and store_slot_norm) else None
)
if return_info and store_read_logits:
read_logits_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype)
read_logits_key_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype)
read_logits_content_full = (
torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if self.use_content_read else None
)
else:
read_logits_full = read_logits_key_full = read_logits_content_full = None
# streaming state
denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
WRITE_CHUNK = self.write_chunk_size
for t0 in range(0, T, WRITE_CHUNK):
t1 = min(T, t0 + WRITE_CHUNK)
L = t1 - t0
wlog_c = write_logits[:, :, :, t0:t1]
m_c, _ = torch.cummax(wlog_c, dim=-1)
m_new = torch.maximum(m_state.unsqueeze(-1), m_c)
scale = torch.exp(m_state.unsqueeze(-1) - m_new)
denom_c = denom_state.unsqueeze(-1) * scale
numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1)
w_new = self._safe_exp_sub_max(wlog_c, m_new)
denom_c = denom_c + torch.cumsum(w_new, dim=-1)
v_c = v_write[:, :, t0:t1, :].to(state_dtype)
add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
numer_c = numer_c + add
slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)
slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous()
# READ logits
q_read_c = q_read[:, :, t0:t1, :]
read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d)
read_logits_content = None
if self.use_content_read:
read_logits_content = torch.einsum(
"bhld,bhlkd->bhlk", q_read_c, slot_state_t.to(q_read_c.dtype),
) / math.sqrt(d)
# slot mask for this chunk
sm = self._resolve_slot_mask(
slot_mask, B=B, H=H, L=L, K=K,
device=x.device, dtype=read_logits_key.dtype, scope=slot_mask_scope,
)
# apply mask to logits according to slot_mask_where
if slot_mask_where == "read":
if sm is not None:
read_logits_key = read_logits_key.masked_fill(sm <= 0.0, float("-inf"))
if self.use_content_read and read_logits_content is not None:
read_logits_content = read_logits_content.masked_fill(sm <= 0.0, float("-inf"))
elif slot_mask_where == "content_read_only":
if sm is not None and self.use_content_read and read_logits_content is not None:
read_logits_content = read_logits_content.masked_fill(sm <= 0.0, 0.0)
elif slot_mask_where == "slotspace_only":
pass # applied later on slot_w
else:
raise ValueError(f"Unknown slot_mask_where={slot_mask_where!r}")
# combine
rl = read_logits_key
if self.use_content_read and read_logits_content is not None:
rl = rl + content_read_gamma.to(rl.dtype) * read_logits_content
if return_info and store_read_logits:
read_logits_full[:, :, t0:t1, :] = rl.to(state_dtype)
read_logits_key_full[:, :, t0:t1, :] = read_logits_key.to(state_dtype)
if self.use_content_read and read_logits_content_full is not None:
read_logits_content_full[:, :, t0:t1, :] = read_logits_content.to(state_dtype)
# read weights
read_w_c = self._compute_read_weights(
read_logits=rl, read_logits_key=read_logits_key,
read_logits_content=read_logits_content,
routing_mode=routing_mode, routing_topk=routing_topk,
read_weights_override=read_weights_override,
routing_noise=routing_noise, routing_noise_scale=routing_noise_scale,
rtemp=rtemp, sm=sm, slot_mask_where=slot_mask_where,
B=B, H=H, L=L, K=K, T_total=T, t0=t0, t1=t1,
q_read_c=q_read_c, slot_keys=slot_keys,
slot_state_t=slot_state_t, valid=valid,
state_dtype=state_dtype,
)
if read_weights is not None:
read_weights[:, :, t0:t1, :] = read_w_c
# base output
out_h[:, :, t0:t1, :] = torch.einsum(
"bhlk,bhlkd->bhld", read_w_c.to(state_dtype), slot_state_t.to(state_dtype),
)
if out1_full is not None:
out1_full[:, :, t0:t1, :] = out_h[:, :, t0:t1, :]
if slot_state_norm_t is not None:
slot_state_norm_t[:, :, t0:t1, :] = slot_state_t.to(torch.float32).norm(dim=-1)
m_state = m_new[:, :, :, -1]
denom_state = denom_c[:, :, :, -1]
numer_state = numer_c[:, :, :, -1, :]
# ================================================================
# SLOT-SPACE REFINEMENT
# ================================================================
slotspace_delta_norm_mean = None
intv_logs_acc: Optional[Dict] = None
intv_logs_count = 0
if self.use_slotspace_refine:
slotspace_dtype = state_dtype
M = self.slotspace_dim
assert read_weights is not None
u = self.slot_in(read_weights.to(slotspace_dtype))
q_s = self.slot_q(u)
k_s = self.slot_k(u)
v_s = self.slot_v(u)
if self.use_rope_slotspace:
cos_s, sin_s = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=q_s.dtype)
q_s = apply_rope(q_s, cos_s, sin_s)
k_s = apply_rope(k_s, cos_s, sin_s)
qf = phi(q_s)
kf = phi(k_s)
if valid is not None:
vmask = valid.view(B, 1, T, 1).to(slotspace_dtype)
qf = qf * vmask
kf = kf * vmask
v_s = v_s * vmask
u2 = torch.empty((B, H, T, M), device=x.device, dtype=slotspace_dtype)
S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype)
Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype)
SS_CHUNK = self.slotspace_chunk_size
for t0 in range(0, T, SS_CHUNK):
t1 = min(T, t0 + SS_CHUNK)
qf_c = qf[:, :, t0:t1, :]
kf_c = kf[:, :, t0:t1, :]
v_c = v_s[:, :, t0:t1, :]
kv = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_c)
S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2)
Z_c = (torch.cumsum(kf_c, dim=2) + Z_state.unsqueeze(2)).clamp_min(1e-8)
num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c)
den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1).clamp_min(1e-8)
u2[:, :, t0:t1, :] = num / den
S_state = S_c[:, :, -1, :, :]
Z_state = Z_c[:, :, -1, :]
u2 = self.slotspace_dropout(u2)
slot_w = self.slot_out(u2)
if slot_w_full is not None:
slot_w_full[:] = slot_w.to(state_dtype)
if self.slotspace_signed_weights:
slot_w_eff = torch.tanh(slot_w)
else:
slot_w_eff = torch.softmax(slot_w, dim=-1)
# slotspace-only mask
if slot_mask_where == "slotspace_only":
sm_full = self._resolve_slot_mask(
slot_mask, B=B, H=H, L=T, K=K,
device=x.device, dtype=slot_w_eff.dtype, scope=slot_mask_scope,
)
if sm_full is not None:
slot_w_eff = slot_w_eff * (sm_full > 0.0).to(slot_w_eff.dtype)
if not self.slotspace_signed_weights:
slot_w_eff = slot_w_eff / slot_w_eff.sum(dim=-1, keepdim=True).clamp_min(1e-8)
gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype)
# second streaming pass: decode delta through slot states
denom_state2 = torch.zeros((B, H, K), device=x.device, dtype=state_dtype)
numer_state2 = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype)
m_state2 = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype)
delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32)
delta_norm_count = 0
for t0 in range(0, T, WRITE_CHUNK):
t1 = min(T, t0 + WRITE_CHUNK)
Lc = t1 - t0
wlog_c = write_logits[:, :, :, t0:t1]
m_c, _ = torch.cummax(wlog_c, dim=-1)
m_new = torch.maximum(m_state2.unsqueeze(-1), m_c)
scale = torch.exp(m_state2.unsqueeze(-1) - m_new)
denom_c = denom_state2.unsqueeze(-1) * scale
numer_c = numer_state2.unsqueeze(-2) * scale.unsqueeze(-1)
w_new = self._safe_exp_sub_max(wlog_c, m_new)
denom_c = denom_c + torch.cumsum(w_new, dim=-1)
v_c = v_write[:, :, t0:t1, :].to(state_dtype)
add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2)
numer_c = numer_c + add
slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1)
slot_state_t2 = slot_state_c.permute(0, 1, 3, 2, 4).contiguous()
slot_w_c = slot_w_eff[:, :, t0:t1, :].to(state_dtype)
delta_c = torch.einsum("bhlk,bhlkd->bhld", slot_w_c, slot_state_t2.to(state_dtype))
delta = gate * delta_c
if delta_full is not None:
delta_full[:, :, t0:t1, :] = delta
# intervention
slot_w_for_score = slot_w[:, :, t0:t1, :] if store_slot_w else None
delta_mod, logs = self._apply_refine_intervention(
out1=out_h[:, :, t0:t1, :], delta=delta, slot_w=slot_w_for_score,
)
out_h[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] + delta_mod
# accumulate logs
if logs is not None and return_info:
if intv_logs_acc is None:
intv_logs_acc = {}
for klog, v in logs.items():
if torch.is_tensor(v):
vv = v.detach().to(torch.float32)
intv_logs_acc[klog] = vv if vv.ndim == 1 else vv.mean()
intv_logs_count = 1
else:
for klog, v in logs.items():
if torch.is_tensor(v) and klog in intv_logs_acc:
vv = v.detach().to(torch.float32)
intv_logs_acc[klog] = intv_logs_acc[klog] + (vv if vv.ndim == 1 else vv.mean())
intv_logs_count += 1
delta_norm_sum = delta_norm_sum + delta.detach().to(torch.float32).norm(dim=-1).sum()
delta_norm_count += B * H * Lc
m_state2 = m_new[:, :, :, -1]
denom_state2 = denom_c[:, :, :, -1]
numer_state2 = numer_c[:, :, :, -1, :]
slotspace_delta_norm_mean = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu()
# ================================================================
# OUTPUT
# ================================================================
out = out_h.transpose(1, 2).contiguous().view(B, T, C)
out = self.out_proj(out)
out = self.dropout(out)
# ---- info dict ----
info = None
if return_info:
info = {
"content_read_gamma": content_read_gamma.detach().to(torch.float32).cpu(),
"routing_mode": routing_mode,
"slot_mask_where": slot_mask_where,
"slot_mask_scope": slot_mask_scope,
"intv_mode": getattr(self, "_intv_mode", "off"),
}
if alibi_bias_applied is not None and info_level == "full":
info["alibi_bias_applied"] = self._store_tensor(alibi_bias_applied.to(torch.float32), cfg=info_cfg, kind="other")
if self.use_alibi_write and self.learn_alibi_strength:
info["alibi_strength"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu()
if self.use_slotspace_refine:
info["slotspace_gate"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu()
info["use_rope_slotspace"] = torch.tensor(bool(self.use_rope_slotspace))
if slotspace_delta_norm_mean is not None:
info["slotspace_delta_norm"] = slotspace_delta_norm_mean
# read weights
if store_read_weights and read_weights is not None:
info["read_weights"] = self._store_tensor(read_weights, cfg=info_cfg, kind="bhtk")
else:
info["read_weights"] = None
# slot state norm
if store_slot_norm and slot_state_norm_t is not None:
s = slot_state_norm_t.permute(0, 1, 3, 2).contiguous()
info["slot_state_norm"] = self._store_tensor(s, cfg=info_cfg, kind="bhkt")
else:
info["slot_state_norm"] = None
# read logits
if store_read_logits and read_logits_full is not None:
info["read_logits"] = self._store_tensor(read_logits_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
info["read_logits_key"] = self._store_tensor(read_logits_key_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
info["read_logits_content"] = (
self._store_tensor(read_logits_content_full.to(torch.float32), cfg=info_cfg, kind="bhtk")
if read_logits_content_full is not None else None
)
else:
info["read_logits"] = info["read_logits_key"] = info["read_logits_content"] = None
# write logits
if store_write_logits and info_level == "full":
info["write_logits_raw"] = self._store_tensor(write_logits_raw, cfg=info_cfg, kind="bhkt")
info["write_logits"] = self._store_tensor(write_logits.to(torch.float32), cfg=info_cfg, kind="bhkt")
else:
info["write_logits_raw"] = info["write_logits"] = None
# out1 / delta / slot_w
info["out1"] = self._store_tensor(out1_full.to(torch.float32), cfg=info_cfg, kind="other") if out1_full is not None else None
info["delta"] = self._store_tensor(delta_full.to(torch.float32), cfg=info_cfg, kind="other") if delta_full is not None else None
info["slot_w"] = self._store_tensor(slot_w_full.to(torch.float32), cfg=info_cfg, kind="bhtk") if slot_w_full is not None else None
# averaged intervention / geometry logs
if intv_logs_acc is not None and intv_logs_count > 0:
for klog, v in intv_logs_acc.items():
info[klog] = (v / float(intv_logs_count)).detach().cpu()
# backward-compatible scalar aliases
for alias_from, alias_to in [
("score", "intv_score_mean"), ("mask", "intv_mask_mean"),
("tau", "intv_tau"), ("alpha", "intv_alpha_mean"),
("out1_norm", "intv_out1_norm_mean"),
("dpar_norm", "intv_dpar_norm_mean"),
("dorth_norm", "intv_dorth_norm_mean"),
]:
if alias_from in intv_logs_acc:
val = info.get(alias_from)
if torch.is_tensor(val) and val.ndim != 1:
info[alias_to] = val
return out, info
# Addressed State Models (ASM): Config + Block + LM
#
# Unified companion for the consolidated AddressedStateAttention harness.
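# Block.forward() and LM.forward() pass through the full ASA forward() surface:
# routing controls (routing_mode, routing_topk, read_weights_override,
# routing_noise, routing_noise_scale), slot-mask controls (slot_mask,
# slot_mask_where, slot_mask_scope), and info controls (info_level, info_cfg).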
#
# ============================================================================
# Config
# ============================================================================
@dataclass
class ASMTrainConfig:
    """Flat run configuration covering data, optimization, model, numerics and IO.

    Within this file only ``build_model_from_cfg`` reads it. It uses the Model,
    ASA / numerics, Positions, Content-conditioned read, Slot-space refinement
    and slot-space RoPE fields, the two chunk sizes, and ``max_seq_len``. The
    remaining fields (data, training, ``enable_compiled``, analytics, IO) are
    not read here. Presumably a separate training script consumes them;
    confirm there.
    """
    # Data
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"
    # Also passed to ASMLanguageModel, where it only sizes the absolute-position
    # table (used when use_abs_pos=True).
    max_seq_len: int = 256
    # Presumably the validation window stride as a fraction of max_seq_len — TODO confirm.
    stride_frac_val: float = 0.50
    seed: int = 1337
    # NOTE(review): micro_batch_size * grad_accum_steps = 16, which differs from
    # batch_size = 64 below — confirm which one the trainer actually uses.
    micro_batch_size: int = 2
    grad_accum_steps: int = 8
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000
    # Training
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)  # tuple default is immutable, so safe as a dataclass default
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100
    # Model
    vocab_size: int = 50257  # presumably sized for the gpt2 tokenizer above
    embed_dim: int = 384
    # NOTE(review): 23 is unusual and differs from ASMLanguageModel's default
    # of 6 — confirm this is intended.
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True  # LM head shares the token-embedding weight
    # ASA / numerics
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True
    normalize_k: bool = False
    # Positions
    use_abs_pos: bool = False  # learned absolute position embedding at the LM input
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0  # presumably a floor on the ALiBi strength — confirm in ASA
    # Content-conditioned read (gamma)
    use_content_read: bool = True
    content_read_init: float = -4.0
    content_read_max_gamma: float = 3.0
    # Slot-space refinement
    # These three defaults differ from ASMBlock / ASMLanguageModel defaults
    # (32, -10.0, 0.0). build_model_from_cfg passes them explicitly, so these win.
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True
    # RoPE inside slot-space matcher
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0
    # Perf knobs
    write_chunk_size: int = 128
    slotspace_chunk_size: int = 128
    enable_compiled: bool = False  # not read by build_model_from_cfg
    # Analytics
    eval_max_batches: int = 150
    analytics_last_k: int = 32
    # IO / caches
    # Paths look like a Google Drive mount at ./drive (e.g. Colab); adjust for
    # other environments.
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches"
    # NOTE(review): the filename mentions 1024 while max_seq_len is 256 —
    # confirm the cached windows match the configured sequence length.
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
# ============================================================================
# Block
# ============================================================================
class ASMBlock(nn.Module):
    """Pre-norm residual layer: ``h = x + ASA(LN1(x))``, then ``h + MLP(LN2(h))``.

    Every constructor argument except ``mlp_ratio`` is handed unchanged to the
    wrapped ``AddressedStateAttention``. ``mlp_ratio`` and ``dropout`` also
    shape the feed-forward sublayer:
      - hidden width ``int(embed_dim * mlp_ratio)``;
      - bias-free linears with GELU;
      - dropout after the activation and after the output projection.

    The submodule attribute names (``norm1``, ``asa``, ``norm2``, ``mlp``) and
    the layer layout inside ``mlp`` determine the state_dict keys.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,
        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        # ALiBi
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # chunk sizes
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()
        # Everything the attention sublayer needs, gathered in one place.
        attn_config = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            slotspace_chunk_size=slotspace_chunk_size,
        )
        # Registration order (norm1, asa, norm2, mlp) is deliberate. It fixes
        # the state_dict ordering and the order in which construction-time and
        # parent-level (.apply) initializers consume the RNG.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.asa = AddressedStateAttention(**attn_config)
        self.norm2 = nn.LayerNorm(embed_dim)
        ffn_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, ffn_width, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, embed_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        # routing
        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,
        # slot mask
        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",
        # info controls
        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Run the block and return ``(output, info)``.

        Every keyword argument is forwarded verbatim to ``self.asa``. ``info``
        is whatever the attention sublayer returns as its second value.
        """
        attn_out, info = self.asa(
            self.norm1(x),
            attention_mask=attention_mask,
            return_info=return_info,
            routing_mode=routing_mode,
            routing_topk=routing_topk,
            read_weights_override=read_weights_override,
            routing_noise=routing_noise,
            routing_noise_scale=routing_noise_scale,
            slot_mask=slot_mask,
            slot_mask_where=slot_mask_where,
            slot_mask_scope=slot_mask_scope,
            info_level=info_level,
            info_cfg=info_cfg,
        )
        hidden = x + attn_out                      # attention residual
        ffn_out = self.mlp(self.norm2(hidden))     # feed-forward on the updated stream
        return hidden + ffn_out, info
# ============================================================================
# LM
# ============================================================================
class ASMLanguageModel(nn.Module):
    """Language model built from a stack of ``ASMBlock`` layers.

    Pipeline:
      1. Token embedding, plus an optional learned absolute position embedding
         when ``use_abs_pos=True``.
      2. Dropout.
      3. ``num_layers`` x ``ASMBlock``.
      4. Final LayerNorm.
      5. Bias-free linear LM head.

    With ``tie_weights=True`` the LM head shares its weight tensor with the
    token embedding.

    ``vocab_size``, ``num_layers``, ``max_seq_len``, ``tie_weights`` and
    ``use_abs_pos`` are consumed here. Every other constructor argument is
    forwarded unchanged to each ``ASMBlock``; ``dropout`` is also used for the
    embedding dropout. ``max_seq_len`` only sizes the absolute-position table.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        # temperatures / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,
        tie_weights: bool = True,
        # LM-level abs pos
        use_abs_pos: bool = False,
        # positions
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,
        # ALiBi
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content-conditioned read (gamma)
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,
        # RoPE inside slot-space matcher
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # chunk sizes
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)
        # Submodule registration order (tok, pos, drop, blocks, norm, lm_head)
        # is kept fixed. It determines state_dict ordering and the order in
        # which self.apply(self._init) consumes the RNG.
        self.tok = nn.Embedding(vocab_size, embed_dim)
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            ASMBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_slots=num_slots,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                read_temperature=read_temperature,
                write_temperature=write_temperature,
                state_fp32=state_fp32,
                slot_dropout=slot_dropout,
                normalize_k=normalize_k,
                use_rope_keys=use_rope_keys,
                rope_base=rope_base,
                use_alibi_write=use_alibi_write,
                alibi_strength_init=alibi_strength_init,
                learn_alibi_strength=learn_alibi_strength,
                min_strength=min_strength,
                use_content_read=use_content_read,
                content_read_init=content_read_init,
                content_read_max_gamma=content_read_max_gamma,
                use_slotspace_refine=use_slotspace_refine,
                slotspace_dim=slotspace_dim,
                slotspace_gate_init=slotspace_gate_init,
                slotspace_dropout=slotspace_dropout,
                slotspace_signed_weights=slotspace_signed_weights,
                use_rope_slotspace=use_rope_slotspace,
                rope_base_slotspace=rope_base_slotspace,
                write_chunk_size=write_chunk_size,
                slotspace_chunk_size=slotspace_chunk_size,
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            # Share the Parameter object itself, not a copy.
            self.lm_head.weight = self.tok.weight
        self.apply(self._init)

    def _init(self, m: nn.Module) -> None:
        """Initializer applied recursively to every submodule via ``self.apply``.

        - Every ``nn.Linear`` / ``nn.Embedding`` weight in the model is drawn
          from N(0, 0.02^2), including any inside the ASMBlocks / ASA modules.
        - LayerNorm is reset to weight=1, bias=0.
        - Linear biases, if any, keep their construction-time values.

        With tied weights the shared tensor is simply initialized twice.
        """
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,
        # routing
        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,
        # slot mask
        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",
        # info controls
        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Compute next-token logits.

        Args:
            input_ids: integer token ids of shape ``(B, T)``.
            attention_mask, routing_*, read_weights_override, slot_mask*,
            info_level, info_cfg: forwarded verbatim to every block (and
            from there to its ASA module).
            return_info: also collect the per-block info objects.

        Returns:
            ``logits`` of shape ``(B, T, vocab_size)``. When ``return_info`` is
            true, returns ``(logits, infos)``, where ``infos`` holds one entry
            per block, in layer order.

        Raises:
            ValueError: if ``use_abs_pos`` is enabled and ``T > max_seq_len``.
                Without this check the lookup would index past the position
                table: an opaque IndexError on CPU, and a device-side assert
                on CUDA that poisons the context.
        """
        B, T = input_ids.shape
        if self.use_abs_pos and T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len} "
                "of the absolute position embedding"
            )
        x = self.tok(input_ids)
        if self.use_abs_pos:
            positions = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            x = x + self.pos(positions)
        x = self.drop(x)
        infos: List[Optional[Dict[str, torch.Tensor]]] = []
        for blk in self.blocks:
            x, info = blk(
                x,
                attention_mask=attention_mask,
                return_info=return_info,
                routing_mode=routing_mode,
                routing_topk=routing_topk,
                read_weights_override=read_weights_override,
                routing_noise=routing_noise,
                routing_noise_scale=routing_noise_scale,
                slot_mask=slot_mask,
                slot_mask_where=slot_mask_where,
                slot_mask_scope=slot_mask_scope,
                info_level=info_level,
                info_cfg=info_cfg,
            )
            if return_info:
                infos.append(info)
        x = self.norm(x)
        logits = self.lm_head(x)
        return (logits, infos) if return_info else logits
# ============================================================================
# Convenience: build model from config
# ============================================================================
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Instantiate an ``ASMLanguageModel`` from the model-related fields of ``cfg``.

    Only architecture / numerics fields are read. Data, optimizer, analytics
    and IO fields of the config are ignored.
    """
    # Each name is both an ASMTrainConfig field and an ASMLanguageModel
    # keyword argument, so values are copied across one-to-one.
    model_field_names = (
        "vocab_size", "embed_dim", "num_layers", "num_heads", "num_slots",
        "max_seq_len", "mlp_ratio", "dropout",
        # temperatures / numerics
        "read_temperature", "write_temperature", "state_fp32",
        "slot_dropout", "normalize_k", "tie_weights",
        # positions / ALiBi
        "use_abs_pos", "use_rope_keys", "rope_base", "use_alibi_write",
        "alibi_strength_init", "learn_alibi_strength", "min_strength",
        # content-conditioned read (gamma)
        "use_content_read", "content_read_init", "content_read_max_gamma",
        # slot-space refinement (+ its RoPE)
        "use_slotspace_refine", "slotspace_dim", "slotspace_gate_init",
        "slotspace_dropout", "slotspace_signed_weights",
        "use_rope_slotspace", "rope_base_slotspace",
        # chunk sizes
        "write_chunk_size", "slotspace_chunk_size",
    )
    model_kwargs = {name: getattr(cfg, name) for name in model_field_names}
    return ASMLanguageModel(**model_kwargs)