Buckets:
| """TinyMind PureField core. | |
| This block is the original TinyMind bounded-memory path: recent tokens are | |
| mixed exactly through a causal local window while older context is compressed | |
| into contractive multi-timescale recurrent state. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .config import OmegaConfig | |
| from .layers import RMSNorm | |
| from .regen_kv import ReGenesisKVBlock | |
class LowRankAdapter(nn.Module):
    """Layer-specific low-rank correction meant to be added to a shared projection.

    The correction is ``up(down(x)) * scale``, where the bottleneck width is
    ``rank`` clamped to ``[1, min(in_features, out_features)]`` and the stored
    ``scale`` is the requested one divided by ``sqrt(bottleneck)``.  Because
    ``up`` is zero-initialised, a freshly built adapter outputs zeros.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, scale: float = 0.02):
        super().__init__()
        bottleneck = min(rank, in_features, out_features)
        if bottleneck < 1:
            bottleneck = 1
        self.down = nn.Linear(in_features, bottleneck, bias=False)
        self.up = nn.Linear(bottleneck, out_features, bias=False)
        # Dividing by sqrt(width) presumably keeps the correction's magnitude
        # roughly independent of the chosen rank.
        self.scale = scale / math.sqrt(bottleneck)
        nn.init.normal_(self.down.weight, std=0.02)
        # Zero "up" weights: the adapter starts as an exact no-op.
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank correction for ``x``."""
        hidden = self.down(x)
        correction = self.up(hidden)
        return correction * self.scale
class SharedLowRankProjection(nn.Module):
    """Projection computed as ``shared(x) + adapter(x)``.

    ``shared`` is a dense ``nn.Linear`` that may be the same module object in
    several instances (weight sharing across layers); ``adapter`` supplies the
    instance-specific low-rank correction.
    """

    def __init__(self, shared: nn.Linear, adapter: LowRankAdapter):
        super().__init__()
        self.shared = shared
        self.adapter = adapter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the shared projection and add the adapter's correction."""
        base = self.shared(x)
        correction = self.adapter(x)
        return base + correction
class PureFieldShared(nn.Module):
    """Dense PureField projections that every ``PureFieldBlock`` may share.

    Per the original note, these are the projections exported by the INT4 2:4
    sparse toolchain.  With ``D = cfg.dim``, ``R = cfg.memory_ranks`` and
    ``J = cfg.timescale_count`` the layers are:

    * ``purefield_shared_q`` / ``purefield_shared_k``: ``D -> R``, no bias
    * ``purefield_shared_v``: ``D -> D``, no bias
    * ``purefield_shared_contract`` / ``purefield_shared_write``: ``D -> J``, with bias
    * ``purefield_shared_read``: ``R -> J``, with bias
    * ``purefield_shared_out``: ``3 * D -> D``, no bias
    """

    def __init__(self, cfg: OmegaConfig):
        super().__init__()
        n_scales = int(cfg.timescale_count)
        n_ranks = int(cfg.memory_ranks)
        width = int(cfg.dim)
        # Creation order is kept stable: it fixes RNG consumption and state_dict order.
        self.purefield_shared_q = nn.Linear(width, n_ranks, bias=False)
        self.purefield_shared_k = nn.Linear(width, n_ranks, bias=False)
        self.purefield_shared_v = nn.Linear(width, width, bias=False)
        self.purefield_shared_contract = nn.Linear(width, n_scales, bias=True)
        self.purefield_shared_write = nn.Linear(width, n_scales, bias=True)
        self.purefield_shared_read = nn.Linear(n_ranks, n_scales, bias=True)
        self.purefield_shared_out = nn.Linear(3 * width, width, bias=False)
class PureFieldBlock(nn.Module):
    """Bounded recurrent memory block with an exact causal local window.

    For every position ``t`` the block builds three signals, concatenates them
    and feeds them through ``out_proj``:

    * ``y_mem``: a read from a multi-timescale outer-product memory of shape
      ``(batch, timescale_count, memory_ranks, dim)``.  At each step the memory
      is multiplied by a contractive gate in ``[eps, 1 - eps]`` and receives
      ``outer(k_t, v_t)`` scaled by a sigmoid write gate and a per-timescale
      gain ``2 ** (-j / 2)``.  The read mixes timescales with a softmax over
      ``read_proj(q_t)``.
    * ``y_win``: softmax attention of ``q_t`` over the most recent
      ``local_window`` keys/values (the current token included).
    * ``u``: the RMS-normalised input.

    The projected mix is squashed with ``tanh`` and added residually with
    scale ``min(cfg.residual_alpha, 1 / sqrt(cfg.n_layers))``.  When
    ``cfg.regen_kv_enabled`` is truthy the output is additionally passed
    through a ``ReGenesisKVBlock`` that receives the final memory state.
    """

    def __init__(
        self,
        cfg: OmegaConfig,
        layer_index: int,
        shared: PureFieldShared | None = None,
    ):
        """Build the block.

        Args:
            cfg: configuration; reads ``dim``, ``memory_ranks``,
                ``timescale_count``, ``local_window``, ``contractive_eps``,
                ``residual_alpha``, ``n_layers``, ``low_rank`` and, if present,
                ``regen_kv_enabled``.
            layer_index: index of this layer, stored as ``self.layer_index``.
            shared: dense projections to reuse across layers; when ``None`` a
                private ``PureFieldShared`` is created for this block.
        """
        super().__init__()
        self.cfg = cfg
        self.layer_index = layer_index
        self.dim = int(cfg.dim)
        self.memory_ranks = int(cfg.memory_ranks)
        self.timescale_count = int(cfg.timescale_count)
        self.local_window = max(1, int(cfg.local_window))
        self.contractive_eps = float(cfg.contractive_eps)
        # Residual contribution is capped at 1/sqrt(n_layers).
        self.residual_scale = min(float(cfg.residual_alpha), 1.0 / math.sqrt(max(int(cfg.n_layers), 1)))
        self.norm = RMSNorm(self.dim)
        self.shared = shared if shared is not None else PureFieldShared(cfg)
        rank = int(cfg.low_rank)
        # Each projection = shared dense weight + this layer's low-rank adapter.
        self.q_proj = SharedLowRankProjection(
            self.shared.purefield_shared_q, LowRankAdapter(self.dim, self.memory_ranks, rank)
        )
        self.k_proj = SharedLowRankProjection(
            self.shared.purefield_shared_k, LowRankAdapter(self.dim, self.memory_ranks, rank)
        )
        self.v_proj = SharedLowRankProjection(
            self.shared.purefield_shared_v, LowRankAdapter(self.dim, self.dim, rank)
        )
        self.contract_proj = SharedLowRankProjection(
            self.shared.purefield_shared_contract, LowRankAdapter(self.dim, self.timescale_count, rank)
        )
        self.write_proj = SharedLowRankProjection(
            self.shared.purefield_shared_write, LowRankAdapter(self.dim, self.timescale_count, rank)
        )
        self.read_proj = SharedLowRankProjection(
            self.shared.purefield_shared_read, LowRankAdapter(self.memory_ranks, self.timescale_count, rank)
        )
        self.out_proj = SharedLowRankProjection(
            self.shared.purefield_shared_out, LowRankAdapter(self.dim * 3, self.dim, rank)
        )
        self.regen_kv = ReGenesisKVBlock(cfg) if bool(getattr(cfg, "regen_kv_enabled", False)) else None
        self.timescale_gain: torch.Tensor
        taus = torch.pow(2.0, torch.arange(self.timescale_count, dtype=torch.float32))
        # Write gain 1/sqrt(2**j) per timescale j, shaped (1, J, 1, 1) to broadcast over memory.
        self.register_buffer("timescale_gain", torch.rsqrt(taus).view(1, self.timescale_count, 1, 1))

    def _empty_memory(self, batch: int, x: torch.Tensor) -> torch.Tensor:
        """Return a zero memory of shape (batch, timescale_count, memory_ranks, dim) on x's device/dtype."""
        return torch.zeros(
            batch,
            self.timescale_count,
            self.memory_ranks,
            self.dim,
            device=x.device,
            dtype=x.dtype,
        )

    def _initial_state(
        self, batch: int, x: torch.Tensor, kv_cache: dict | None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(memory, local_k, local_v)`` to start the scan from.

        Missing cache (or missing keys) yields a zero memory and empty
        ``(batch, 0, memory_ranks)`` / ``(batch, 0, dim)`` windows.  Cached
        tensors are moved to ``x``'s device and dtype, and the windows are
        trimmed to the last ``local_window`` entries.
        """
        if kv_cache is None:
            return (
                self._empty_memory(batch, x),
                torch.empty(batch, 0, self.memory_ranks, device=x.device, dtype=x.dtype),
                torch.empty(batch, 0, self.dim, device=x.device, dtype=x.dtype),
            )
        memory = kv_cache.get("memory")
        if memory is None:
            memory = self._empty_memory(batch, x)
        else:
            memory = memory.to(device=x.device, dtype=x.dtype)
        local_k = kv_cache.get("local_k")
        local_v = kv_cache.get("local_v")
        if local_k is None:
            local_k = torch.empty(batch, 0, self.memory_ranks, device=x.device, dtype=x.dtype)
        else:
            local_k = local_k.to(device=x.device, dtype=x.dtype)
        if local_v is None:
            local_v = torch.empty(batch, 0, self.dim, device=x.device, dtype=x.dtype)
        else:
            local_v = local_v.to(device=x.device, dtype=x.dtype)
        return memory, local_k[:, -self.local_window :], local_v[:, -self.local_window :]

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
        return_stats: bool = False,
    ) -> tuple[torch.Tensor, dict] | tuple[torch.Tensor, dict, dict[str, torch.Tensor]]:
        """Scan ``x`` left to right, updating memory and the local window.

        Args:
            x: input of shape ``(batch, seq_len, dim)``; ``seq_len`` may be 0.
            kv_cache: optional state from a previous call with keys
                ``"memory"``, ``"local_k"``, ``"local_v"`` (any may be
                missing); ``"retrieved_tokens"`` is forwarded to the
                ReGenesis-KV sub-block when it is enabled.
            mask: accepted for interface compatibility and ignored; the scan
                only ever sees past and current tokens.
            return_stats: when true, also return detached gate/memory stats.

        Returns:
            ``(out, new_cache)``, or ``(out, new_cache, stats)`` when
            ``return_stats`` is set.  ``new_cache`` holds ``"memory"``,
            ``"local_k"``, ``"local_v"`` and, with ReGenesis-KV enabled,
            ``"regen_kv"`` and ``"regen_aux"``.
        """
        del mask
        batch, seq_len, _ = x.shape
        u = self.norm(x)
        q = F.normalize(self.q_proj(u), dim=-1, eps=1e-6)
        k = F.normalize(self.k_proj(u), dim=-1, eps=1e-6)
        v = torch.tanh(self.v_proj(u))

        # exp(-softplus(z)) == sigmoid(-z) in (0, 1); then remap into
        # [eps, 1 - eps] so decay never fully erases nor fully keeps memory.
        c_gate = torch.exp(-F.softplus(self.contract_proj(u)))
        eps = min(max(self.contractive_eps, 1e-6), 0.49)
        c_gate = c_gate * (1.0 - 2.0 * eps) + eps
        i_gate = torch.sigmoid(self.write_proj(u))
        # Timescale read weights depend only on q, so compute all steps at once.
        read_weights = torch.softmax(self.read_proj(q), dim=-1)

        # Loop invariants, hoisted out of the per-token scan.
        gain = self.timescale_gain.to(device=x.device, dtype=x.dtype)
        window_norm = math.sqrt(max(self.memory_ranks, 1))

        memory, local_k, local_v = self._initial_state(batch, x, kv_cache)
        y_mem_parts: list[torch.Tensor] = []
        y_win_parts: list[torch.Tensor] = []
        for t in range(seq_len):
            q_t = q[:, t]
            k_t = k[:, t]
            v_t = v[:, t]
            # Rank-1 write outer(k_t, v_t): (batch, memory_ranks, dim).
            write = torch.einsum("br,bd->brd", k_t, v_t)
            decay = c_gate[:, t].view(batch, self.timescale_count, 1, 1)
            insert = i_gate[:, t].view(batch, self.timescale_count, 1, 1)
            memory = decay * memory + insert * (write.unsqueeze(1) * gain)
            slot_values = torch.einsum("br,bjrd->bjd", q_t, memory)
            y_mem_parts.append((read_weights[:, t].unsqueeze(-1) * slot_values).sum(dim=1))
            # Append the current token, then keep only the last local_window entries.
            local_k = torch.cat([local_k, k_t.unsqueeze(1)], dim=1)[:, -self.local_window :]
            local_v = torch.cat([local_v, v_t.unsqueeze(1)], dim=1)[:, -self.local_window :]
            scores = torch.einsum("br,blr->bl", q_t, local_k) / window_norm
            weights = torch.softmax(scores, dim=-1)
            y_win_parts.append(torch.einsum("bl,bld->bd", weights, local_v))

        if seq_len:
            y_mem = torch.stack(y_mem_parts, dim=1)
            y_win = torch.stack(y_win_parts, dim=1)
        else:
            # torch.stack rejects an empty list; an empty chunk contributes nothing.
            y_mem = u.new_zeros(batch, 0, self.dim)
            y_win = u.new_zeros(batch, 0, self.dim)
        y = self.out_proj(torch.cat([y_mem, y_win, u], dim=-1))
        out = x + self.residual_scale * torch.tanh(y)

        new_cache = {"memory": memory, "local_k": local_k, "local_v": local_v}
        if self.regen_kv is not None:
            retrieved_tokens = kv_cache.get("retrieved_tokens") if kv_cache is not None else None
            out, regen_cache, regen_aux = self.regen_kv(
                out,
                memory_state=memory,
                retrieved_tokens=retrieved_tokens,
                return_aux=True,
            )
            new_cache["regen_kv"] = regen_cache
            new_cache["regen_aux"] = regen_aux

        if return_stats:
            stats = {
                "contractive_gate": c_gate.detach(),
                "write_gate": i_gate.detach(),
                "memory_norm": memory.detach().norm(dim=(-1, -2)),
                "residual_scale": torch.tensor(self.residual_scale, device=x.device, dtype=x.dtype),
            }
            if "regen_aux" in new_cache:
                stats["hash_consistency_loss"] = new_cache["regen_aux"]["hash_consistency_loss"].detach()
            return out, new_cache, stats
        return out, new_cache
Xet Storage Details
- Size:
- 9.32 kB
- Xet hash:
- e17e5ba32caf025eeaec82c08ece69f04babb95cf12896ce6fd8a55926dc00f7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.