bbkdevops's picture
download
raw
11 kB
"""AxiomFlow HyperWeave reference block.
AxiomFlow is a TinyMind-native long-context reference design. It is built to
separate three jobs that are often conflated:
- exact recent-token work in a bounded local lane,
- exact long-context recall through external retrieved evidence,
- constant-size recurrent compression for model state.
This module is intentionally a PyTorch reference, not a CUDA superiority claim.
Performance claims are gated by evaluation reports.
"""
from __future__ import annotations
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
from .layers import RMSNorm
@dataclass
class AxiomFlowConfig:
    """Hyperparameters for the AxiomFlow HyperWeave block.

    Field order defines the positional constructor signature, so keep it stable.
    """

    # Model width; hidden states are validated against this on their last axis.
    dim: int = 1_024
    # Number of most recent tokens kept by the local exact lane.
    local_window: int = 128
    # Number of slots in the compressed recurrent memory.
    memory_slots: int = 8
    # Width of each compressed memory slot.
    memory_rank: int = 32
    # Upper bound on how many retrieved chunks the ledger lane attends over.
    retrieval_top_k: int = 8
    # Nominal tokens per retrieved chunk.
    # NOTE(review): not read by the block code in this module; confirm who uses it.
    retrieved_chunk_tokens: int = 8_192
    # Margin capping the memory decay factor at ``1 - contractive_eps``.
    contractive_eps: float = 1e-3
    # Scale on the residual update; the block caps it at 1.0.
    residual_alpha: float = 0.2
@dataclass
class AxiomFlowState:
    """Recurrent state carried between consecutive AxiomFlow block calls.

    As built by ``AxiomFlowBlock``, ``local_exact`` is ``[batch, tokens, dim]``
    and ``compressed_memory`` is ``[batch, memory_slots, memory_rank]``.
    """

    local_exact: torch.Tensor
    compressed_memory: torch.Tensor

    def cached_token_capacity(self) -> int:
        """Return the cached local token count plus the number of memory slots.

        Both counts are read from axis 1 of the respective tensors.
        """
        exact_tokens = self.local_exact.shape[1]
        slot_count = self.compressed_memory.shape[1]
        return int(exact_tokens + slot_count)
class HyperWeaveScheduler(nn.Module):
    """Per-token soft router over the three HyperWeave lanes.

    A single biased linear layer maps each token to three logits. A softmax over
    the last axis turns them into weights that sum to one. ``AxiomFlowBlock``
    reads index 0 as the local lane, 1 as the ledger lane and 2 as the
    compressed lane.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.router = nn.Linear(dim, 3, bias=True)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        lane_logits = self.router(hidden)
        lane_weights = torch.softmax(lane_logits, dim=-1)
        return lane_weights
class AxiomFlowBlock(nn.Module):
    """Reference HyperWeave block with bounded memory growth.

    Every call mixes three lanes with per-token weights from
    ``HyperWeaveScheduler`` (index 0 = local, 1 = ledger, 2 = compressed):

    * local: single-head attention from the current tokens over the newest
      ``cfg.local_window`` normalized tokens (previous state plus this call);
    * ledger: attention over caller-supplied retrieved chunks, each chunk
      mean-pooled over its tokens first;
    * compressed: a soft read from a ``[memory_slots, memory_rank]`` memory per
      batch row, refreshed every call by a convex decay/write blend.

    The mixture goes through ``out_proj`` and ``tanh``, then is added to the
    input scaled by ``residual_alpha`` (capped at 1.0).

    NOTE(review): no causal mask is applied. Each query attends to every token
    of the same call, because those tokens are appended to the local window
    first. It also reads a memory already updated with the mean of the whole
    call. Confirm this is intended before using the block for autoregressive
    training.
    """

    def __init__(self, cfg: AxiomFlowConfig):
        super().__init__()
        if cfg.dim <= 0:
            raise ValueError("dim must be positive")
        if cfg.local_window <= 0 or cfg.memory_slots <= 0 or cfg.memory_rank <= 0:
            raise ValueError("local_window, memory_slots, and memory_rank must be positive")
        # A negative value would make ``chunks[:, :top_k]`` silently drop the
        # LAST |top_k| chunks instead of keeping the first top_k.
        if cfg.retrieval_top_k < 0:
            raise ValueError("retrieval_top_k must be non-negative")
        # The memory decay factor is clamped to at most ``1 - contractive_eps``.
        # Above 1 that ceiling turns negative: memory flips sign every step and,
        # for eps >= 2, grows without bound.
        if not 0.0 <= cfg.contractive_eps <= 1.0:
            raise ValueError("contractive_eps must be in [0, 1]")
        self.cfg = cfg
        self.norm = RMSNorm(cfg.dim)
        # Local-lane attention projections.
        self.q_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.k_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.v_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        # Ledger-lane (retrieved evidence) attention projections.
        self.ledger_q = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.ledger_k = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.ledger_v = nn.Linear(cfg.dim, cfg.dim, bias=False)
        # Compressed-memory write, read and gating.
        self.write_proj = nn.Linear(cfg.dim, cfg.memory_rank, bias=False)
        self.memory_to_dim = nn.Linear(cfg.memory_rank, cfg.dim, bias=False)
        self.memory_gate = nn.Linear(cfg.dim, cfg.memory_slots, bias=True)
        self.slot_query = nn.Linear(cfg.dim, cfg.memory_slots, bias=False)
        self.scheduler = HyperWeaveScheduler(cfg.dim)
        self.out_proj = nn.Linear(cfg.dim, cfg.dim, bias=False)
        self.slot_bias = nn.Parameter(torch.zeros(cfg.memory_slots, cfg.memory_rank))

    def _empty_state(self, batch: int, device: torch.device, dtype: torch.dtype) -> AxiomFlowState:
        """Return a state with no cached local tokens and an all-zero memory."""
        local = torch.zeros(batch, 0, self.cfg.dim, device=device, dtype=dtype)
        memory = torch.zeros(batch, self.cfg.memory_slots, self.cfg.memory_rank, device=device, dtype=dtype)
        return AxiomFlowState(local_exact=local, compressed_memory=memory)

    def _check_state(self, state: AxiomFlowState, batch: int) -> None:
        """Raise ``ValueError`` unless *state* matches this block's config and *batch*.

        Without this check, a mismatched ``local_exact`` fails inside ``torch.cat``
        with an opaque error. A ``compressed_memory`` with batch 1 would silently
        broadcast across the batch.
        """
        local = state.local_exact
        if local.dim() != 3 or local.shape[0] != batch or local.shape[-1] != self.cfg.dim:
            raise ValueError("state.local_exact must have shape [batch, tokens, dim]")
        expected = (batch, self.cfg.memory_slots, self.cfg.memory_rank)
        if tuple(state.compressed_memory.shape) != expected:
            raise ValueError(
                "state.compressed_memory must have shape [batch, memory_slots, memory_rank]"
            )

    def _update_local_exact(self, hidden: torch.Tensor, state: AxiomFlowState) -> torch.Tensor:
        """Append *hidden* to the cached tokens and keep the newest ``local_window``."""
        # Detached: the cached window carries no autograd graph across calls, and
        # local-lane keys/values do not backpropagate into the incoming tokens.
        local = torch.cat([state.local_exact.to(hidden.device, hidden.dtype), hidden.detach()], dim=1)
        return local[:, -self.cfg.local_window :].contiguous()

    def _local_lane(self, hidden: torch.Tensor, local_exact: torch.Tensor) -> torch.Tensor:
        """Single-head scaled dot-product attention of *hidden* over *local_exact*."""
        if local_exact.shape[1] == 0:
            return torch.zeros_like(hidden)
        q = self.q_proj(hidden)
        k = self.k_proj(local_exact)
        v = self.v_proj(local_exact)
        scale = 1.0 / math.sqrt(max(self.cfg.dim, 1))
        weights = torch.softmax(torch.einsum("bsd,bwd->bsw", q, k) * scale, dim=-1)
        return torch.einsum("bsw,bwd->bsd", weights, v)

    def _ledger_lane(self, hidden: torch.Tensor, retrieved_chunks: torch.Tensor | None) -> torch.Tensor:
        """Attend from *hidden* over mean-pooled retrieved chunks.

        Returns zeros when no chunks are supplied. Only the first
        ``cfg.retrieval_top_k`` chunks are used.

        Raises:
            ValueError: if the chunks are not ``[batch, top_k, chunk_tokens, dim]``
            matching *hidden*, or if a chunk has no tokens.
        """
        if retrieved_chunks is None:
            return torch.zeros_like(hidden)
        if retrieved_chunks.dim() != 4:
            raise ValueError("retrieved_chunks must have shape [batch, top_k, chunk_tokens, dim]")
        if retrieved_chunks.shape[0] != hidden.shape[0] or retrieved_chunks.shape[-1] != self.cfg.dim:
            raise ValueError("retrieved_chunks batch and dim must match hidden")
        # Mean-pooling an empty token axis yields NaN, which would poison the output.
        if retrieved_chunks.shape[2] == 0:
            raise ValueError("retrieved_chunks must contain at least one token per chunk")
        chunks = retrieved_chunks.to(hidden.device, hidden.dtype)
        pooled = chunks[:, : self.cfg.retrieval_top_k].mean(dim=2)
        q = self.ledger_q(hidden)
        k = self.ledger_k(pooled)
        v = self.ledger_v(pooled)
        scale = 1.0 / math.sqrt(max(self.cfg.dim, 1))
        weights = torch.softmax(torch.einsum("bsd,bkd->bsk", q, k) * scale, dim=-1)
        return torch.einsum("bsk,bkd->bsd", weights, v)

    def _update_memory(self, hidden: torch.Tensor, state: AxiomFlowState) -> torch.Tensor:
        """Blend the previous compressed memory with a write derived from *hidden*.

        ``new = c * old + (1 - c) * (tanh(write_proj(s)) + slot_bias)``, where
        ``s`` is the mean of *hidden* over the sequence axis. The per-slot factor
        is ``c = exp(-softplus(memory_gate(s)))``, clamped to at most
        ``1 - contractive_eps``. An empty sequence leaves the memory unchanged.
        """
        memory = state.compressed_memory.to(hidden.device, hidden.dtype)
        if hidden.shape[1] == 0:
            # Nothing to summarize: a mean over zero tokens is NaN and would
            # permanently corrupt the recurrent memory.
            return memory
        summary = hidden.mean(dim=1)
        # exp(-softplus(x)) == sigmoid(-x), so the decay lies in (0, 1).
        contract = torch.exp(-torch.nn.functional.softplus(self.memory_gate(summary))).unsqueeze(-1)
        contract = contract.clamp(max=1.0 - self.cfg.contractive_eps)
        write = torch.tanh(self.write_proj(summary)).unsqueeze(1)
        return contract * memory + (1.0 - contract) * (write + self.slot_bias.unsqueeze(0))

    def _compressed_lane(self, hidden: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Read the memory with per-token softmax weights over the slots."""
        slot_scores = torch.softmax(self.slot_query(hidden), dim=-1)
        memory_values = self.memory_to_dim(memory)
        return torch.einsum("bsk,bkd->bsd", slot_scores, memory_values)

    def _coherence_gate(
        self,
        local: torch.Tensor,
        ledger: torch.Tensor,
        compressed: torch.Tensor,
        route: torch.Tensor,
    ) -> dict:
        """Diagnose whether all three lanes contribute non-trivially.

        A lane's contribution is its mean output norm times its mean route
        weight. The gate passes when every lane contributes and the weakest
        contribution is more than 1% of the strongest.
        """
        lane_names = ["local_exact", "ledger_recall", "compressed_field"]
        lane_norms = [float(t.detach().norm(dim=-1).mean().cpu()) for t in (local, ledger, compressed)]
        route_values = route.detach().mean(dim=(0, 1)).cpu().tolist()
        contributions = [norm * weight for norm, weight in zip(lane_norms, route_values)]
        zero_work_lanes = [
            name
            for name, value, weight in zip(lane_names, contributions, route_values)
            if value <= 1e-8 or weight <= 1e-4
        ]
        productive_lane_count = len(lane_names) - len(zero_work_lanes)
        mutual_support_score = min(contributions) / (max(contributions) + 1e-8)
        return {
            "passed": bool(productive_lane_count == 3 and mutual_support_score > 0.01),
            "productive_lane_count": productive_lane_count,
            "zero_work_lanes": zero_work_lanes,
            "lane_norms": dict(zip(lane_names, lane_norms)),
            "route_weights": dict(zip(lane_names, route_values)),
            "lane_contributions": dict(zip(lane_names, contributions)),
            "mutual_support_score": float(mutual_support_score),
        }

    def _intensity_pressure(
        self,
        local: torch.Tensor,
        ledger: torch.Tensor,
        compressed: torch.Tensor,
        route: torch.Tensor,
    ) -> dict:
        """Compute the anti-collapse auxiliary loss and its diagnostic gates.

        ``aux_loss`` (differentiable) combines three terms:

        * a route-balance loss pulling mean route weights toward 1/3;
        * a diversity loss penalizing pairwise lane |cosine| above 0.60;
        * a contribution loss penalizing any mean route weight below 0.01.
        """
        lane_stack = torch.stack([local, ledger, compressed], dim=2)  # [batch, seq, 3, dim]
        lane_normed = torch.nn.functional.normalize(lane_stack.flatten(0, 1), dim=-1, eps=1e-6)
        cosine = torch.einsum("nld,nmd->nlm", lane_normed, lane_normed)
        # Upper-triangle pairs: (local, ledger), (local, compressed), (ledger, compressed).
        pair_idx = torch.triu_indices(3, 3, offset=1, device=cosine.device)
        off_diag = cosine[:, pair_idx[0], pair_idx[1]]
        max_pairwise_abs_cosine = float(off_diag.detach().abs().max())
        route_mean = route.mean(dim=(0, 1))
        target = torch.full_like(route_mean, 1.0 / 3.0)
        route_balance_loss = (route_mean - target).pow(2).mean()
        lane_diversity_loss = torch.relu(off_diag.abs() - 0.60).pow(2).mean()
        contribution_loss = torch.relu(0.01 - route_mean).pow(2).mean()
        aux_loss = route_balance_loss + 0.25 * lane_diversity_loss + contribution_loss
        max_route_weight = float(route_mean.detach().max())
        return {
            "aux_loss": aux_loss,
            "route_balance_loss": route_balance_loss.detach(),
            "lane_diversity_loss": lane_diversity_loss.detach(),
            "contribution_loss": contribution_loss.detach(),
            "anti_collapse_gate": {
                "passed": bool(max_route_weight < 0.90),
                "max_route_weight": max_route_weight,
                "target_route_weight": 1.0 / 3.0,
            },
            "lane_diversity_gate": {
                "passed": bool(max_pairwise_abs_cosine < 0.98),
                "max_pairwise_abs_cosine": max_pairwise_abs_cosine,
            },
        }

    def forward(
        self,
        hidden: torch.Tensor,
        state: AxiomFlowState | None = None,
        retrieved_chunks: torch.Tensor | None = None,
        return_metrics: bool = False,
    ) -> tuple[torch.Tensor, AxiomFlowState] | tuple[torch.Tensor, AxiomFlowState, dict]:
        """Run one block step.

        Args:
            hidden: input of shape ``[batch, seq, dim]``.
            state: state returned by a previous call, or ``None`` to start empty.
            retrieved_chunks: optional ``[batch, top_k, chunk_tokens, dim]``
                evidence for the ledger lane.
            return_metrics: when true, also return a diagnostics dict that
                includes the differentiable ``aux_loss``.

        Returns:
            ``(out, next_state)``, or ``(out, next_state, metrics)`` when
            *return_metrics* is true. ``out`` has the same shape as *hidden*.

        Raises:
            ValueError: if *hidden*, *state* or *retrieved_chunks* has an
            incompatible shape.
        """
        if hidden.dim() != 3 or hidden.shape[-1] != self.cfg.dim:
            raise ValueError("hidden must have shape [batch, seq, dim]")
        if state is None:
            state = self._empty_state(hidden.shape[0], hidden.device, hidden.dtype)
        else:
            self._check_state(state, hidden.shape[0])
        u = self.norm(hidden)
        next_local = self._update_local_exact(u, state)
        next_memory = self._update_memory(u, state)
        local = self._local_lane(u, next_local)
        ledger = self._ledger_lane(u, retrieved_chunks)
        compressed = self._compressed_lane(u, next_memory)
        route = self.scheduler(u)
        mixed = (
            route[..., 0:1] * local
            + route[..., 1:2] * ledger
            + route[..., 2:3] * compressed
        )
        alpha = min(float(self.cfg.residual_alpha), 1.0)
        out = hidden + alpha * torch.tanh(self.out_proj(mixed))
        next_state = AxiomFlowState(local_exact=next_local, compressed_memory=next_memory)
        if not return_metrics:
            return out, next_state
        route_entropy = -(route * torch.log(route.clamp_min(1e-8))).sum(dim=-1).mean()
        metrics = {
            "route_weights_mean": route.detach().mean(dim=(0, 1)),
            "route_entropy": route_entropy.detach(),
            "local_exact_tokens_stored": int(next_local.shape[1]),
            "long_context_kv_tokens_stored": 0,
            "bounded_state_tokens_equivalent": int(self.cfg.local_window + self.cfg.memory_slots),
            "compressed_memory_norm": next_memory.detach().norm(dim=-1).mean(),
            "coherence_gate": self._coherence_gate(local, ledger, compressed, route),
            "intensity_pressure": self._intensity_pressure(local, ledger, compressed, route),
        }
        return out, next_state, metrics

Xet Storage Details

Size:
11 kB
·
Xet hash:
dda34636445dc2909fae7a6cb85fc3a1e7c363205537f0703767d5ef28cf9a80

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.