Buckets:
| """AxiomFlow HyperWeave reference block. | |
| AxiomFlow is a TinyMind-native long-context reference design. It is built to | |
| separate three jobs that are often conflated: | |
| - exact recent-token work in a bounded local lane, | |
| - exact long-context recall through external retrieved evidence, | |
| - constant-size recurrent compression for model state. | |
| This module is intentionally a PyTorch reference, not a CUDA superiority claim. | |
| Performance claims are gated by evaluation reports. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| from .layers import RMSNorm | |
@dataclass
class AxiomFlowConfig:
    """Hyperparameters consumed by :class:`AxiomFlowBlock`.

    Declared as a dataclass so that fields can be overridden by keyword at
    construction time, e.g. ``AxiomFlowConfig(dim=256, local_window=64)``.
    """

    # Width of every hidden vector; in/out feature size of the lane projections.
    dim: int = 1024
    # Maximum number of recent (normalised) token vectors kept in the local lane.
    local_window: int = 128
    # Number of slots in the compressed recurrent memory.
    memory_slots: int = 8
    # Feature width of each compressed memory slot.
    memory_rank: int = 32
    # Number of leading retrieved chunks the ledger lane attends over.
    retrieval_top_k: int = 8
    # NOTE(review): not read by the block code in this file; presumably the
    # expected token count per retrieved chunk - confirm with the retriever.
    retrieved_chunk_tokens: int = 8192
    # The memory decay factor is clamped to at most ``1 - contractive_eps``.
    contractive_eps: float = 1e-3
    # Scale applied to the residual update; the block caps it at 1.0.
    residual_alpha: float = 0.2
@dataclass
class AxiomFlowState:
    """Recurrent state carried between successive block calls.

    Declared as a dataclass so it can be built by keyword, which is how the
    block itself constructs it
    (``AxiomFlowState(local_exact=..., compressed_memory=...)``).
    """

    # Recent normalised hidden vectors, laid out as [batch, tokens, dim].
    local_exact: torch.Tensor
    # Compressed recurrent memory, laid out as [batch, memory_slots, memory_rank].
    compressed_memory: torch.Tensor

    def cached_token_capacity(self) -> int:
        """Return the stored local token count plus the number of memory slots."""
        return int(self.local_exact.shape[1] + self.compressed_memory.shape[1])
class HyperWeaveScheduler(nn.Module):
    """Per-token soft router over the local, ledger, and compressed lanes."""

    def __init__(self, dim: int):
        """Build the linear layer that maps ``dim`` features to three lane logits."""
        super().__init__()
        self.router = nn.Linear(in_features=dim, out_features=3, bias=True)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Return lane weights that sum to one along the last axis.

        Args:
            hidden: Tensor whose last axis has ``dim`` features.

        Returns:
            Tensor with the same leading axes as ``hidden`` and a last axis of
            size 3 holding the softmax of the router logits.
        """
        lane_logits = self.router(hidden)
        return lane_logits.softmax(dim=-1)
class AxiomFlowBlock(nn.Module):
    """Reference HyperWeave block with bounded memory growth.

    Each call mixes three lanes under a learned per-token router:

    - a local lane that attends over the most recent ``cfg.local_window``
      normalised hidden vectors,
    - a ledger lane that attends over mean-pooled, externally retrieved chunks,
    - a compressed lane that reads a fixed ``[memory_slots, memory_rank]``
      recurrent memory.

    The carried state never holds more than ``local_window`` token vectors plus
    ``memory_slots`` memory rows per batch element.
    """

    def __init__(self, cfg: AxiomFlowConfig):
        """Validate ``cfg`` and build the projections used by the three lanes.

        Raises:
            ValueError: If ``dim``, ``local_window``, ``memory_slots`` or
                ``memory_rank`` is not positive, or if ``retrieval_top_k`` is
                negative.
        """
        super().__init__()
        if cfg.dim <= 0:
            raise ValueError("dim must be positive")
        if cfg.local_window <= 0 or cfg.memory_slots <= 0 or cfg.memory_rank <= 0:
            raise ValueError("local_window, memory_slots, and memory_rank must be positive")
        if cfg.retrieval_top_k < 0:
            # A negative value would turn the ``[:top_k]`` slice in the ledger
            # lane into "silently drop the trailing chunks".
            raise ValueError("retrieval_top_k must be non-negative")
        self.cfg = cfg
        # Submodules are created in a fixed order so seeded initialisation is
        # reproducible.
        self.norm = RMSNorm(cfg.dim)
        self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.ledger_q = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.ledger_k = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.ledger_v = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.write_proj = nn.Linear(cfg.dim, cfg.memory_rank, bias=False)
        self.memory_to_dim = nn.Linear(cfg.memory_rank, cfg.dim, bias=False)
        self.memory_gate = nn.Linear(cfg.dim, cfg.memory_slots, bias=True)
        self.slot_query = nn.Linear(cfg.dim, cfg.memory_slots, bias=False)
        self.scheduler = HyperWeaveScheduler(cfg.dim)
        self.out_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.slot_bias = nn.Parameter(torch.zeros(cfg.memory_slots, cfg.memory_rank))

    def _empty_state(self, batch: int, device: torch.device, dtype: torch.dtype) -> AxiomFlowState:
        """Return a state with no local tokens and an all-zero compressed memory."""
        local = torch.zeros(batch, 0, self.cfg.dim, device=device, dtype=dtype)
        memory = torch.zeros(
            batch, self.cfg.memory_slots, self.cfg.memory_rank, device=device, dtype=dtype
        )
        return AxiomFlowState(local_exact=local, compressed_memory=memory)

    def _update_local_exact(self, hidden: torch.Tensor, state: AxiomFlowState) -> torch.Tensor:
        """Append ``hidden`` to the local buffer and keep the last ``local_window`` tokens.

        The appended tokens are detached, so the stored buffer carries no
        autograd history.
        """
        previous = state.local_exact.to(hidden.device, hidden.dtype)
        local = torch.cat([previous, hidden.detach()], dim=1)
        return local[:, -self.cfg.local_window :].contiguous()

    def _local_lane(self, hidden: torch.Tensor, local_exact: torch.Tensor) -> torch.Tensor:
        """Attend from every position of ``hidden`` over all of ``local_exact``.

        No causal mask is applied: each query sees every buffered token.
        Returns zeros shaped like ``hidden`` when the buffer is empty.
        """
        if local_exact.shape[1] == 0:
            return torch.zeros_like(hidden)
        q = self.q_proj(hidden)
        k = self.k_proj(local_exact)
        v = self.v_proj(local_exact)
        scale = 1.0 / math.sqrt(max(self.cfg.dim, 1))
        weights = torch.softmax(torch.einsum("bsd,bwd->bsw", q, k) * scale, dim=-1)
        return torch.einsum("bsw,bwd->bsd", weights, v)

    def _ledger_lane(self, hidden: torch.Tensor, retrieved_chunks: torch.Tensor | None) -> torch.Tensor:
        """Attend from ``hidden`` over mean-pooled retrieved chunks.

        Args:
            hidden: Query tensor shaped [batch, seq, dim].
            retrieved_chunks: Optional evidence shaped
                [batch, top_k, chunk_tokens, dim]. Only the first
                ``cfg.retrieval_top_k`` chunks are used, and each chunk is
                averaged over its token axis before attention.

        Returns:
            Tensor shaped like ``hidden``; all zeros when no evidence is given
            or when the evidence has no chunks or no tokens.

        Raises:
            ValueError: If ``retrieved_chunks`` is not 4-D or its batch or
                feature size does not match.
        """
        if retrieved_chunks is None:
            return torch.zeros_like(hidden)
        if retrieved_chunks.dim() != 4:
            raise ValueError("retrieved_chunks must have shape [batch, top_k, chunk_tokens, dim]")
        chunks = retrieved_chunks.to(hidden.device, hidden.dtype)
        if chunks.shape[0] != hidden.shape[0] or chunks.shape[-1] != self.cfg.dim:
            raise ValueError("retrieved_chunks batch and dim must match hidden")
        chunks = chunks[:, : self.cfg.retrieval_top_k]
        if chunks.shape[1] == 0 or chunks.shape[2] == 0:
            # Averaging over an empty token axis yields NaN; treat "no
            # evidence" the same as ``retrieved_chunks is None``.
            return torch.zeros_like(hidden)
        pooled = chunks.mean(dim=2)
        q = self.ledger_q(hidden)
        k = self.ledger_k(pooled)
        v = self.ledger_v(pooled)
        scale = 1.0 / math.sqrt(max(self.cfg.dim, 1))
        weights = torch.softmax(torch.einsum("bsd,bkd->bsk", q, k) * scale, dim=-1)
        return torch.einsum("bsk,bkd->bsd", weights, v)

    def _update_memory(self, hidden: torch.Tensor, state: AxiomFlowState) -> torch.Tensor:
        """Blend the previous compressed memory with a write from ``hidden``.

        The sequence is averaged into one summary vector per batch element.
        A per-slot decay factor ``exp(-softplus(gate))`` (clamped to at most
        ``1 - cfg.contractive_eps``) weights the old memory; its complement
        weights ``tanh(write_proj(summary)) + slot_bias``.

        An empty sequence leaves the memory unchanged. Averaging over zero
        tokens would otherwise produce NaN and poison the recurrent state.
        """
        memory = state.compressed_memory.to(hidden.device, hidden.dtype)
        if hidden.shape[1] == 0:
            return memory
        summary = hidden.mean(dim=1)
        gate = self.memory_gate(summary)
        # [batch, slots, 1] so it broadcasts across the rank axis.
        contract = torch.exp(-torch.nn.functional.softplus(gate)).unsqueeze(-1)
        contract = contract.clamp(max=1.0 - self.cfg.contractive_eps)
        # [batch, 1, rank]: the same write is offered to every slot.
        write = torch.tanh(self.write_proj(summary)).unsqueeze(1)
        return contract * memory + (1.0 - contract) * (write + self.slot_bias.unsqueeze(0))

    def _compressed_lane(self, hidden: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Read the compressed memory with per-token softmax weights over slots."""
        slot_scores = torch.softmax(self.slot_query(hidden), dim=-1)
        memory_values = self.memory_to_dim(memory)
        return torch.einsum("bsk,bkd->bsd", slot_scores, memory_values)

    def _coherence_gate(
        self,
        local: torch.Tensor,
        ledger: torch.Tensor,
        compressed: torch.Tensor,
        route: torch.Tensor,
    ) -> dict:
        """Report whether all three lanes contribute non-trivially.

        A lane's contribution is its mean vector norm times its mean route
        weight. A lane counts as "zero work" when its contribution is at most
        1e-8 or its mean route weight is at most 1e-4. The gate passes when no
        lane is zero-work and the smallest contribution exceeds 1% of the
        largest.

        Returns:
            A dict of plain Python values (bools, ints, floats, lists, dicts).
        """
        lane_names = ["local_exact", "ledger_recall", "compressed_field"]
        lane_norms = [
            float(t.detach().norm(dim=-1).mean().cpu()) for t in (local, ledger, compressed)
        ]
        # One host transfer for all three route means.
        route_values = route.detach().mean(dim=(0, 1)).tolist()
        contributions = [norm * weight for norm, weight in zip(lane_norms, route_values)]
        zero_work_lanes = [
            name
            for name, value, weight in zip(lane_names, contributions, route_values)
            if value <= 1e-8 or weight <= 1e-4
        ]
        productive_lane_count = len(lane_names) - len(zero_work_lanes)
        support_floor = min(contributions)
        support_ceiling = max(contributions)
        mutual_support_score = support_floor / (support_ceiling + 1e-8)
        return {
            "passed": bool(productive_lane_count == 3 and mutual_support_score > 0.01),
            "productive_lane_count": productive_lane_count,
            "zero_work_lanes": zero_work_lanes,
            "lane_norms": dict(zip(lane_names, lane_norms)),
            "route_weights": dict(zip(lane_names, route_values)),
            "lane_contributions": dict(zip(lane_names, contributions)),
            "mutual_support_score": float(mutual_support_score),
        }

    def _intensity_pressure(
        self,
        local: torch.Tensor,
        ledger: torch.Tensor,
        compressed: torch.Tensor,
        route: torch.Tensor,
    ) -> dict:
        """Compute auxiliary losses and gates that discourage lane collapse.

        ``aux_loss`` (still attached to the autograd graph) is the sum of:

        - the squared deviation of the mean route weights from 1/3,
        - 0.25 times a penalty on pairwise lane cosine similarity above 0.60,
        - a penalty on any mean route weight below 0.01.

        The individual loss terms are also returned detached, together with
        two pass/fail gates expressed as plain Python values.
        """
        # [batch, seq, 3, dim] -> [batch * seq, 3, dim]
        lane_stack = torch.stack([local, ledger, compressed], dim=2)
        lane_flat = lane_stack.flatten(0, 1)
        lane_normed = torch.nn.functional.normalize(lane_flat, dim=-1, eps=1e-6)
        cosine = torch.einsum("bld,bmd->blm", lane_normed, lane_normed)
        # Upper-triangle indices select the three distinct lane pairs.
        pair_idx = torch.triu_indices(3, 3, offset=1, device=cosine.device)
        off_diag = cosine[:, pair_idx[0], pair_idx[1]]
        max_pairwise_abs_cosine = float(off_diag.detach().abs().max().cpu())

        route_mean = route.mean(dim=(0, 1))
        target = torch.full_like(route_mean, 1.0 / 3.0)
        route_balance_loss = (route_mean - target).pow(2).mean()
        lane_diversity_loss = torch.relu(off_diag.abs() - 0.60).pow(2).mean()
        floor = torch.tensor(0.01, device=route.device, dtype=route.dtype)
        contribution_loss = torch.relu(floor - route_mean).pow(2).mean()
        aux_loss = route_balance_loss + 0.25 * lane_diversity_loss + contribution_loss
        max_route_weight = float(route_mean.detach().max().cpu())
        return {
            "aux_loss": aux_loss,
            "route_balance_loss": route_balance_loss.detach(),
            "lane_diversity_loss": lane_diversity_loss.detach(),
            "contribution_loss": contribution_loss.detach(),
            "anti_collapse_gate": {
                "passed": bool(max_route_weight < 0.90),
                "max_route_weight": max_route_weight,
                "target_route_weight": 1.0 / 3.0,
            },
            "lane_diversity_gate": {
                "passed": bool(max_pairwise_abs_cosine < 0.98),
                "max_pairwise_abs_cosine": max_pairwise_abs_cosine,
            },
        }

    def forward(
        self,
        hidden: torch.Tensor,
        state: AxiomFlowState | None = None,
        retrieved_chunks: torch.Tensor | None = None,
        return_metrics: bool = False,
    ) -> tuple[torch.Tensor, AxiomFlowState] | tuple[torch.Tensor, AxiomFlowState, dict]:
        """Run one block step and return the output with the next state.

        Args:
            hidden: Input shaped [batch, seq, dim].
            state: State from the previous call; a fresh empty state is used
                when omitted.
            retrieved_chunks: Optional evidence shaped
                [batch, top_k, chunk_tokens, dim] for the ledger lane.
            return_metrics: When true, also return a diagnostics dict.

        Returns:
            ``(out, next_state)`` or ``(out, next_state, metrics)``. ``out`` is
            ``hidden + alpha * tanh(out_proj(mixed))`` with ``alpha`` capped at
            1.0.

        Raises:
            ValueError: If ``hidden`` is not 3-D with last axis ``cfg.dim``,
                or if ``retrieved_chunks`` is malformed.
        """
        if hidden.dim() != 3 or hidden.shape[-1] != self.cfg.dim:
            raise ValueError("hidden must have shape [batch, seq, dim]")
        if state is None:
            state = self._empty_state(hidden.shape[0], hidden.device, hidden.dtype)
        u = self.norm(hidden)
        # State is advanced first; the lanes then read the updated state, so
        # the current tokens are visible to the local and compressed lanes.
        next_local = self._update_local_exact(u, state)
        next_memory = self._update_memory(u, state)
        local = self._local_lane(u, next_local)
        ledger = self._ledger_lane(u, retrieved_chunks)
        compressed = self._compressed_lane(u, next_memory)
        route = self.scheduler(u)
        mixed = (
            route[..., 0:1] * local
            + route[..., 1:2] * ledger
            + route[..., 2:3] * compressed
        )
        alpha = min(float(self.cfg.residual_alpha), 1.0)
        out = hidden + alpha * torch.tanh(self.out_proj(mixed))
        next_state = AxiomFlowState(local_exact=next_local, compressed_memory=next_memory)
        if not return_metrics:
            return out, next_state
        route_entropy = -(route * torch.log(route.clamp_min(1e-8))).sum(dim=-1).mean()
        metrics = {
            "route_weights_mean": route.detach().mean(dim=(0, 1)),
            "route_entropy": route_entropy.detach(),
            "local_exact_tokens_stored": int(next_local.shape[1]),
            "long_context_kv_tokens_stored": 0,
            "bounded_state_tokens_equivalent": int(self.cfg.local_window + self.cfg.memory_slots),
            "compressed_memory_norm": next_memory.detach().norm(dim=-1).mean(),
            "coherence_gate": self._coherence_gate(local, ledger, compressed, route),
            "intensity_pressure": self._intensity_pressure(local, ledger, compressed, route),
        }
        return out, next_state, metrics
Xet Storage Details
- Size:
- 11 kB
- Xet hash:
- dda34636445dc2909fae7a6cb85fc3a1e7c363205537f0703767d5ef28cf9a80
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.