bbkdevops's picture
download
raw
9.32 kB
"""TinyMind PureField core.
This module implements the original TinyMind bounded-memory path: recent
tokens are mixed exactly through a causal local attention window, while older
context is compressed into a contractive, multi-timescale recurrent memory.
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import OmegaConfig
from .layers import RMSNorm
from .regen_kv import ReGenesisKVBlock
class LowRankAdapter(nn.Module):
    """Layer-local low-rank correction computing ``up(down(x)) * scale``.

    The bottleneck width is ``rank`` clamped to ``[1, min(in_features,
    out_features)]``.  ``down`` is drawn from ``N(0, 0.02**2)`` and ``up`` starts
    at zero, so a freshly constructed adapter outputs exact zeros.  The stored
    ``scale`` is the constructor's ``scale`` divided by ``sqrt(width)``.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, scale: float = 0.02):
        super().__init__()
        width = max(1, min(rank, in_features, out_features))
        # Module creation precedes the explicit inits: this keeps the RNG draws
        # (default Linear init, then the normal_ below) in a fixed order.
        self.down = nn.Linear(in_features, width, bias=False)
        self.up = nn.Linear(width, out_features, bias=False)
        self.scale = scale / math.sqrt(width)
        nn.init.normal_(self.down.weight, std=0.02)
        # Zero ``up`` => the adapter is a no-op until it is trained.
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bottleneck = self.down(x)
        return self.up(bottleneck) * self.scale
class SharedLowRankProjection(nn.Module):
    """Projection whose output is ``shared(x) + adapter(x)``.

    ``shared`` is a dense layer that callers may pass to several projections
    (PureFieldBlock hands in layers from a possibly shared ``PureFieldShared``).
    ``adapter`` is the per-layer correction. Both receive the same input.
    """

    def __init__(self, shared: nn.Linear, adapter: LowRankAdapter):
        super().__init__()
        self.shared = shared
        self.adapter = adapter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dense_part = self.shared(x)
        correction = self.adapter(x)
        return dense_part + correction
class PureFieldShared(nn.Module):
    """Dense PureField projections that several blocks may reuse.

    Per the original note, these are the shared projections exported by the
    INT4 2:4 sparse toolchain, so attribute names and creation order should
    presumably be treated as part of the checkpoint layout.

    With ``d = cfg.dim``, ``r = cfg.memory_ranks``, ``j = cfg.timescale_count``:

    * ``purefield_shared_q``, ``purefield_shared_k``: ``d -> r``, no bias
    * ``purefield_shared_v``: ``d -> d``, no bias
    * ``purefield_shared_contract``, ``purefield_shared_write``: ``d -> j``
    * ``purefield_shared_read``: ``r -> j``
    * ``purefield_shared_out``: ``3 * d -> d``, no bias
    """

    def __init__(self, cfg: OmegaConfig):
        super().__init__()
        scales = int(cfg.timescale_count)
        ranks = int(cfg.memory_ranks)
        width = int(cfg.dim)
        # Keep this order: it fixes parameter registration order and the RNG
        # draws consumed by the default Linear initialisation.
        self.purefield_shared_q = nn.Linear(width, ranks, bias=False)
        self.purefield_shared_k = nn.Linear(width, ranks, bias=False)
        self.purefield_shared_v = nn.Linear(width, width, bias=False)
        self.purefield_shared_contract = nn.Linear(width, scales, bias=True)
        self.purefield_shared_write = nn.Linear(width, scales, bias=True)
        self.purefield_shared_read = nn.Linear(ranks, scales, bias=True)
        self.purefield_shared_out = nn.Linear(3 * width, width, bias=False)
class PureFieldBlock(nn.Module):
    """Bounded recurrent memory block with an exact causal local window.

    Tokens are processed strictly in order. For each token:

    1. The normalised features ``u`` are projected to a unit-norm query ``q``
       and key ``k`` (``memory_ranks`` wide), a ``tanh``-bounded value ``v``
       (``dim`` wide), a per-timescale decay gate and a per-timescale write gate.
    2. The memory, shaped ``(batch, timescale_count, memory_ranks, dim)``, is
       updated as ``decay * memory + write_gate * (outer(k, v) * gain)``. The
       decay gate is kept inside ``[eps, 1 - eps]`` so the update is
       contractive. ``gain`` is ``2 ** (-j / 2)`` for timescale ``j``.
    3. The memory is read with ``q``, and the timescales are blended by a
       softmax over ``read_proj(q)``.
    4. ``q`` also attends, via softmax, over the keys and values of the newest
       ``local_window`` tokens, including the current one.

    The memory read, the window read and ``u`` are concatenated, projected back
    to ``dim`` and added to the input as ``x + residual_scale * tanh(y)``.
    Every projection is a shared dense layer plus a per-layer low-rank adapter.
    """

    # Created in ``__init__`` with ``register_buffer``; declared here so type
    # checkers know the attribute exists.
    timescale_gain: torch.Tensor

    def __init__(
        self,
        cfg: OmegaConfig,
        layer_index: int,
        shared: PureFieldShared | None = None,
    ):
        """Build projections, gates and per-timescale write gains.

        Args:
            cfg: Model config. Reads ``dim``, ``memory_ranks``,
                ``timescale_count``, ``local_window``, ``contractive_eps``,
                ``residual_alpha``, ``n_layers``, ``low_rank`` and, if present,
                ``regen_kv_enabled``.
            layer_index: Position of this block. It is stored but not used by
                the code in this class.
            shared: Dense projections to reuse. A private ``PureFieldShared`` is
                built when ``None``.
        """
        super().__init__()
        self.cfg = cfg
        self.layer_index = layer_index
        self.dim = int(cfg.dim)
        self.memory_ranks = int(cfg.memory_ranks)
        self.timescale_count = int(cfg.timescale_count)
        self.local_window = max(1, int(cfg.local_window))
        self.contractive_eps = float(cfg.contractive_eps)
        # The residual step is capped at 1/sqrt(n_layers).
        self.residual_scale = min(float(cfg.residual_alpha), 1.0 / math.sqrt(max(int(cfg.n_layers), 1)))
        self.norm = RMSNorm(self.dim)
        self.shared = shared if shared is not None else PureFieldShared(cfg)
        rank = int(cfg.low_rank)
        # Creation order below fixes state-dict keys, parameter order and RNG
        # draws; keep it unchanged.
        self.q_proj = SharedLowRankProjection(
            self.shared.purefield_shared_q, LowRankAdapter(self.dim, self.memory_ranks, rank)
        )
        self.k_proj = SharedLowRankProjection(
            self.shared.purefield_shared_k, LowRankAdapter(self.dim, self.memory_ranks, rank)
        )
        self.v_proj = SharedLowRankProjection(
            self.shared.purefield_shared_v, LowRankAdapter(self.dim, self.dim, rank)
        )
        self.contract_proj = SharedLowRankProjection(
            self.shared.purefield_shared_contract, LowRankAdapter(self.dim, self.timescale_count, rank)
        )
        self.write_proj = SharedLowRankProjection(
            self.shared.purefield_shared_write, LowRankAdapter(self.dim, self.timescale_count, rank)
        )
        self.read_proj = SharedLowRankProjection(
            self.shared.purefield_shared_read, LowRankAdapter(self.memory_ranks, self.timescale_count, rank)
        )
        self.out_proj = SharedLowRankProjection(
            self.shared.purefield_shared_out, LowRankAdapter(self.dim * 3, self.dim, rank)
        )
        self.regen_kv = ReGenesisKVBlock(cfg) if bool(getattr(cfg, "regen_kv_enabled", False)) else None
        # Write gain per timescale j: 1 / sqrt(2 ** j), shaped to broadcast over
        # (batch, timescale, rank, dim).
        taus = torch.pow(2.0, torch.arange(self.timescale_count, dtype=torch.float32))
        self.register_buffer("timescale_gain", torch.rsqrt(taus).view(1, self.timescale_count, 1, 1))

    def _empty_memory(self, batch: int, x: torch.Tensor) -> torch.Tensor:
        """Zero memory of shape ``(batch, timescale_count, memory_ranks, dim)`` on ``x``'s device/dtype."""
        return torch.zeros(
            batch,
            self.timescale_count,
            self.memory_ranks,
            self.dim,
            device=x.device,
            dtype=x.dtype,
        )

    def _initial_state(self, batch: int, x: torch.Tensor, kv_cache: dict | None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(memory, local_k, local_v)`` to start the scan from.

        Entries missing from ``kv_cache`` (or the whole cache, when ``None``)
        start empty. Cached tensors are moved to ``x``'s device and dtype, and
        the window buffers are trimmed to the newest ``local_window`` positions.
        """
        cache = {} if kv_cache is None else kv_cache

        def restore(key: str, make_empty) -> torch.Tensor:
            value = cache.get(key)
            if value is None:
                return make_empty()
            return value.to(device=x.device, dtype=x.dtype)

        memory = restore("memory", lambda: self._empty_memory(batch, x))
        local_k = restore(
            "local_k",
            lambda: torch.empty(batch, 0, self.memory_ranks, device=x.device, dtype=x.dtype),
        )
        local_v = restore(
            "local_v",
            lambda: torch.empty(batch, 0, self.dim, device=x.device, dtype=x.dtype),
        )
        return memory, local_k[:, -self.local_window :], local_v[:, -self.local_window :]

    def _scan(
        self,
        x: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        c_gate: torch.Tensor,
        i_gate: torch.Tensor,
        memory: torch.Tensor,
        local_k: torch.Tensor,
        local_v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance memory and local window token by token.

        Returns ``(y_mem, y_win, memory, local_k, local_v)``. ``y_mem`` and
        ``y_win`` are the per-token memory and window reads stacked along
        dim 1. The other three are the state after the last token.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        if seq_len == 0:
            # torch.stack rejects an empty list. Return empty reads and leave
            # the incoming state untouched.
            empty = v.new_zeros((batch, 0, self.dim))
            return empty, empty, memory, local_k, local_v

        # Loop invariants: convert the gain once per call, not once per token.
        gain = self.timescale_gain.to(device=x.device, dtype=x.dtype)
        score_divisor = math.sqrt(max(self.memory_ranks, 1))
        y_mem_parts: list[torch.Tensor] = []
        y_win_parts: list[torch.Tensor] = []
        for t in range(seq_len):
            q_t = q[:, t]
            k_t = k[:, t]
            v_t = v[:, t]
            # Rank-1 write: outer product of key (memory_ranks) and value (dim).
            write = torch.einsum("br,bd->brd", k_t, v_t)
            decay = c_gate[:, t].view(batch, self.timescale_count, 1, 1)
            insert = i_gate[:, t].view(batch, self.timescale_count, 1, 1)
            memory = decay * memory + insert * (write.unsqueeze(1) * gain)
            # Query each timescale slot, then blend slots with a softmax read gate.
            read = torch.softmax(self.read_proj(q_t), dim=-1)
            slot_values = torch.einsum("br,bjrd->bjd", q_t, memory)
            y_mem_parts.append((read.unsqueeze(-1) * slot_values).sum(dim=1))
            # Sliding window: append the current token, keep the newest local_window.
            local_k = torch.cat([local_k, k_t.unsqueeze(1)], dim=1)[:, -self.local_window :]
            local_v = torch.cat([local_v, v_t.unsqueeze(1)], dim=1)[:, -self.local_window :]
            scores = torch.einsum("br,blr->bl", q_t, local_k) / score_divisor
            weights = torch.softmax(scores, dim=-1)
            y_win_parts.append(torch.einsum("bl,bld->bd", weights, local_v))
        return (
            torch.stack(y_mem_parts, dim=1),
            torch.stack(y_win_parts, dim=1),
            memory,
            local_k,
            local_v,
        )

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
        return_stats: bool = False,
    ) -> tuple[torch.Tensor, dict] | tuple[torch.Tensor, dict, dict[str, torch.Tensor]]:
        """Run the block over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``. ``seq_len`` may be 0.
            kv_cache: Optional state from a previous call. It may contain
                ``"memory"``, ``"local_k"`` and ``"local_v"``, and is read for
                ``"retrieved_tokens"`` when ReGenesis-KV is enabled. Missing
                entries start empty.
            mask: Ignored; the token-by-token scan is causal by construction.
            return_stats: Also return detached diagnostics.

        Returns:
            ``(out, new_cache)`` or ``(out, new_cache, stats)``. ``new_cache``
            holds the final ``memory``, ``local_k`` and ``local_v``. When
            ReGenesis-KV is enabled it also holds ``regen_kv`` and
            ``regen_aux``, and ``out`` is the ReGenesis-KV block's output.
        """
        del mask
        batch, _, _ = x.shape
        u = self.norm(x)
        q = F.normalize(self.q_proj(u), dim=-1, eps=1e-6)
        k = F.normalize(self.k_proj(u), dim=-1, eps=1e-6)
        v = torch.tanh(self.v_proj(u))
        # exp(-softplus(c)) lies in (0, 1); the affine map then keeps the
        # decay within [eps, 1 - eps].
        c_raw = self.contract_proj(u)
        c_gate = torch.exp(-F.softplus(c_raw))
        eps = min(max(self.contractive_eps, 1e-6), 0.49)
        c_gate = c_gate * (1.0 - 2.0 * eps) + eps
        i_gate = torch.sigmoid(self.write_proj(u))

        memory, local_k, local_v = self._initial_state(batch, x, kv_cache)
        y_mem, y_win, memory, local_k, local_v = self._scan(
            x, q, k, v, c_gate, i_gate, memory, local_k, local_v
        )

        y = self.out_proj(torch.cat([y_mem, y_win, u], dim=-1))
        out = x + self.residual_scale * torch.tanh(y)
        new_cache = {"memory": memory, "local_k": local_k, "local_v": local_v}

        if self.regen_kv is not None:
            retrieved_tokens = None if kv_cache is None else kv_cache.get("retrieved_tokens")
            out, regen_cache, regen_aux = self.regen_kv(
                out,
                memory_state=memory,
                retrieved_tokens=retrieved_tokens,
                return_aux=True,
            )
            new_cache["regen_kv"] = regen_cache
            new_cache["regen_aux"] = regen_aux

        if not return_stats:
            return out, new_cache

        stats = {
            "contractive_gate": c_gate.detach(),
            "write_gate": i_gate.detach(),
            "memory_norm": memory.detach().norm(dim=(-1, -2)),
            "residual_scale": torch.tensor(self.residual_scale, device=x.device, dtype=x.dtype),
        }
        if "regen_aux" in new_cache:
            stats["hash_consistency_loss"] = new_cache["regen_aux"]["hash_consistency_loss"].detach()
        return out, new_cache, stats

Xet Storage Details

Size:
9.32 kB
·
Xet hash:
e17e5ba32caf025eeaec82c08ece69f04babb95cf12896ce6fd8a55926dc00f7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.