# Source: sparse-transformer-experiments / experiments / sparse_transformer_v15_inactive_prediction.py
# Author: theapemachine
# Commit message: Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
# Commit: bc1b8eb
"""
Sparse Transformer v15: Inactive-Update Prediction Diagnostics.
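Tests two simple ideas for predicting the gradient updates of inactive (unselected) chunks, alongside dense, active-only and EMA-inactive baselines: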
1. Correlated-neighbor prediction:
Use active chunks as sensors. For each inactive chunk, find historically
correlated active chunks and predict its update magnitude from them.
2. Graph / boundary interpolation:
Treat chunks as nodes in a learned similarity graph. Active chunks are
boundary values. Inactive chunk magnitudes are filled in by diffusion.
This is intentionally a diagnostic script, not a speed benchmark.
It computes dense gradients every step so we can measure whether inactive
updates are predictable.
Run:
python3 sparse_transformer_v15_inactive_prediction.py --device mps --benchmark_sync
Good first runs:
python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 512
python3 sparse_transformer_v15_inactive_prediction.py --device mps --steps 300 --n_embd 1024
"""
import argparse
import math
import random
import time
from typing import Dict, List, Literal, Optional, Tuple
import torch
# Restrict torch's intra-op CPU thread pool to a single thread.
# NOTE(review): presumably for comparable CPU timings across modes — confirm intent.
torch.set_num_threads(1)
import torch.nn as nn
import torch.nn.functional as F
# Chunk-selection strategies accepted by ChunkMap.choose_active (and the --policy CLI flag).
Policy = Literal["predicted_magnitude", "random"]
def sync_device(device: str) -> None:
    """Block until all queued work on ``device`` has finished.

    Only ``"cuda"`` (when CUDA is available) and ``"mps"`` (when ``torch.mps``
    exists) are synchronized; every other device string is a no-op.
    """
    if device == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return
    if device == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module, then torch's global RNG, with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``.

    ``Generator.manual_seed`` returns the generator itself, so seeding and
    returning happen in one expression.
    """
    return torch.Generator(device="cpu").manual_seed(seed)
# -----------------------------
# Data
# -----------------------------
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a toy newline-separated corpus from a fixed 15-word vocabulary.

    Each of the ``n_sentences`` lines holds 4-10 words drawn with replacement
    and ends with a period. The output is fully determined by ``seed``.
    """
    vocab = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    sampler = random.Random(seed)
    lines: List[str] = []
    for _ in range(n_sentences):
        # The sentence length is drawn before its words; keep that draw order.
        length = sampler.randint(4, 10)
        lines.append(" ".join(sampler.choices(vocab, k=length)) + ".")
    return "\n".join(lines)
class CharCorpus:
    """Character-level corpus with a 90/10 train/validation split.

    Args:
        text: Raw text; its distinct characters (sorted) form the vocabulary.
        block_size: Context length of each sampled sequence.
        device: Device that sampled batches are moved to.
    """

    def __init__(self, text: str, block_size: int, device: str):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.device = device
        data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        split_at = int(0.9 * len(data))
        self.train_data = data[:split_at]
        self.val_data = data[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` random (input, next-token target) windows.

        Args:
            split: ``"train"`` selects the training split; any other value
                selects the validation split.
            batch_size: Number of windows to sample.
            generator: Optional RNG for the start offsets (global RNG if None).

        Returns:
            ``(x, y)``, both ``[batch_size, block_size]`` long tensors on
            ``self.device``, where ``y`` is ``x`` shifted one token ahead.

        Raises:
            ValueError: If the split holds fewer than ``block_size + 1`` tokens.
        """
        data = self.train_data if split == "train" else self.val_data
        # A window needs block_size + 1 tokens (inputs plus shifted target), so
        # valid start offsets are 0 .. len(data) - block_size - 1 inclusive and
        # randint's exclusive upper bound is len(data) - block_size.
        n_starts = len(data) - self.block_size
        if n_starts <= 0:
            raise ValueError(
                f"{split!r} split has {len(data)} tokens; need at least "
                f"block_size + 1 = {self.block_size + 1}"
            )
        ix = torch.randint(n_starts, (batch_size,), generator=generator)
        # Gather all windows in one indexing op instead of stacking Python slices.
        idx = ix[:, None] + torch.arange(self.block_size)
        x = data[idx]
        y = data[idx + 1]
        return x.to(self.device), y.to(self.device)
# -----------------------------
# Model
# -----------------------------
class SparseLinear(nn.Linear):
    """Plain ``nn.Linear``; the name is retained for compatibility with earlier experiments.

    In this diagnostic script the backward pass is fully dense. Chunk masks are
    applied only analytically, after gradients have been computed. The subclass
    exists so that ``get_sparse_linears`` can select the layers to chunk via an
    ``isinstance`` check.
    """
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an explicit masked softmax.

    Args:
        n_embd: Model width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        block_size: Maximum sequence length covered by the causal mask buffer.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        # Submodules are registered in the same order as before (c_attn, then
        # c_proj) so seeded initialization and parameter order are unchanged.
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        lower = torch.ones(block_size, block_size).tril()
        self.register_buffer("mask", lower.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, width = x.shape
        heads, hd = self.n_head, self.head_dim

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, n_head, T, head_dim)
            return t.view(batch, seq_len, heads, hd).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(width, dim=2))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(hd)
        # Positions above the diagonal (future tokens) are masked out.
        visible = self.mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden)
        self.c_proj = SparseLinear(hidden, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.gelu(self.c_fc(x))
        h = self.c_proj(h)
        return self.dropout(h)
class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln1(x))`` then ``x + mlp(ln2(x))``."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        # Registration order (ln1, attn, ln2, mlp) is kept for identical seeded init.
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out
class MiniGPT(nn.Module):
    """Small GPT-style language model.

    Token embeddings plus learned positional embeddings, ``n_layer`` pre-norm
    ``Block``s, a final LayerNorm and a linear head over the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute next-token logits and, if ``targets`` is given, mean cross-entropy.

        Args:
            idx: ``[B, T]`` token ids with ``T <= block_size``.
            targets: Optional ``[B, T]`` target ids (need not be contiguous).

        Returns:
            ``(logits, loss)`` with logits ``[B, T, vocab_size]``; ``loss`` is
            ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``. Without this check the
                positional-embedding lookup fails with an opaque index error
                (or a device-side assert on GPU).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)[None, :, :]
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Return every ``SparseLinear`` submodule of ``model``, in ``modules()`` order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
# -----------------------------
# Chunk geometry and diagnostics
# -----------------------------
class ChunkMap:
    """Split every ``SparseLinear``'s output rows into fixed-size chunks and track them.

    Chunk ids are global. They are assigned module by module in
    ``get_sparse_linears(model)`` order, ``out_features // chunk_size`` ids per
    module. Besides this geometry, the map holds the predictor state used by
    the selection policies and the prediction modes:

    * ``predicted_mass`` - EMA of each chunk's gradient norm (0 until observed);
    * ``direction_ema`` - EMA of each chunk's unit gradient direction
      (``None`` until observed with a non-negligible norm);
    * ``mass_history`` - the last ``history_size`` per-step mass vectors.

    Args:
        model: Model whose ``SparseLinear`` layers are chunked.
        chunk_size: Output rows per chunk; must divide every layer's ``out_features``.
        device: Device for chunk ids, masks and ``predicted_mass``.
        history_size: Maximum number of entries kept in ``mass_history``
            (default 128, the previously hard-coded limit).

    Raises:
        ValueError: If ``chunk_size`` does not divide some layer's ``out_features``.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        device: str,
        history_size: int = 128,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.device = device
        self.history_size = history_size
        self.linears = get_sparse_linears(model)
        # SparseLinear module -> tensor of its global chunk ids.
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        # Global chunk id -> (module, chunk index within that module).
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []
        offset = 0
        for m in self.linears:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if m.out_features % chunk_size != 0:
                raise ValueError(
                    f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
                )
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            self.chunk_to_module_local.extend((m, local_c) for local_c in range(n_chunks))
            offset += n_chunks
        self.n_chunks = offset
        # EMA of observed chunk gradient norms; an exact 0 means "not observed yet".
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        # EMA of unit gradient directions; None until a chunk is observed.
        self.direction_ema: List[Optional[torch.Tensor]] = [None] * self.n_chunks
        # Per-step mass vectors used for correlation / graph similarities.
        # NOTE: update_predictor appends the masses of ALL chunks (active and
        # inactive), so similarities are estimated from dense observations.
        self.mass_history: List[torch.Tensor] = []

    def choose_active(
        self,
        step: int,
        warmup_steps: int,
        active_fraction: float,
        policy: "Policy",
    ) -> torch.Tensor:
        """Return a boolean ``[n_chunks]`` mask of the chunks that are active this step.

        During warmup (``step < warmup_steps``) every chunk is active. Afterwards
        ``k = clamp(int(active_fraction * n_chunks), 1, n_chunks)`` chunks are
        chosen: uniformly at random for ``policy == "random"``, otherwise the
        top-k by ``predicted_mass`` (tiny random noise breaks ties).
        """
        if step < warmup_steps:
            return torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
        # Clamp to n_chunks: torch.topk raises when k exceeds the input size
        # (e.g. active_fraction > 1, or a model without SparseLinear layers).
        k = min(self.n_chunks, max(1, int(active_fraction * self.n_chunks)))
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        if policy == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        else:
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        mask[idx] = True
        return mask

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Flatten each chunk's gradient into one vector, indexed by global chunk id.

        A vector holds the chunk's weight-grad rows (flattened), followed by its
        bias-grad entries when the layer has a bias. A missing ``.grad`` is
        treated as zeros. Vectors are fresh tensors (``torch.cat`` copies).
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = (local_c + 1) * self.chunk_size
            parts = []
            if m.weight.grad is None:
                parts.append(torch.zeros_like(m.weight[start:end]).flatten())
            else:
                parts.append(m.weight.grad[start:end].detach().flatten())
            if m.bias is not None:
                if m.bias.grad is None:
                    parts.append(torch.zeros_like(m.bias[start:end]).flatten())
                else:
                    parts.append(m.bias.grad[start:end].detach().flatten())
            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """Return the L2 norm of every chunk vector as a ``[len(vecs)]`` tensor."""
        if not vecs:
            # torch.stack rejects an empty list.
            return torch.zeros(0, device=self.device)
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_predictor(
        self,
        active_mask: torch.Tensor,
        vecs: List[torch.Tensor],
        mass_beta: float = 0.95,
        dir_beta: float = 0.95,
        store_history: bool = True,
    ) -> torch.Tensor:
        """Fold this step's observations of the *active* chunks into the EMAs.

        Active chunks update ``predicted_mass`` (EMA with ``mass_beta``) and
        ``direction_ema`` (EMA of unit vectors with ``dir_beta``, re-normalized).
        A chunk whose predicted mass is still exactly 0 is initialized directly.
        When ``store_history`` is set, the masses of *all* chunks are appended
        to ``mass_history``, trimmed to the newest ``history_size`` entries.

        Returns:
            The ``[n_chunks]`` masses of ``vecs``.
        """
        masses = self.chunk_masses_from_vecs(vecs)
        observed = active_mask
        # First observation should initialize directly, not get shrunk by beta.
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = masses[never_seen]
        self.predicted_mass[already_seen] = (
            mass_beta * self.predicted_mass[already_seen]
            + (1.0 - mass_beta) * masses[already_seen]
        )
        for i, is_active in enumerate(observed.tolist()):
            if not is_active:
                continue
            v = vecs[i]
            n = v.norm()
            if n <= 1e-12:
                # A (near-)zero gradient carries no direction information.
                continue
            unit = v / n
            if self.direction_ema[i] is None:
                self.direction_ema[i] = unit.detach().clone()
            else:
                self.direction_ema[i] = (
                    dir_beta * self.direction_ema[i] + (1.0 - dir_beta) * unit
                )
                self.direction_ema[i] = self.direction_ema[i] / (self.direction_ema[i].norm() + 1e-12)
        if store_history:
            self.mass_history.append(masses.detach().clone())
            overflow = len(self.mass_history) - max(0, self.history_size)
            if overflow > 0:
                del self.mass_history[:overflow]
        return masses

    def layer_aware_masks(self) -> List[torch.Tensor]:
        """Return one boolean ``[n_chunks]`` mask per SparseLinear module.

        Despite the name, the grouping is per module (e.g. ``c_attn`` and
        ``c_proj`` of the same transformer layer get separate masks).
        """
        masks = []
        for ids in self.module_to_chunk_ids.values():
            mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
            mask[ids] = True
            masks.append(mask)
        return masks
def dense_cosine_from_vecs(a: List[torch.Tensor], b: List[torch.Tensor]) -> float:
    """Cosine similarity between the concatenations of two lists of chunk vectors."""
    flat_a, flat_b = (torch.cat([t.flatten() for t in group]) for group in (a, b))
    similarity = F.cosine_similarity(flat_a, flat_b, dim=0)
    return float(similarity.item())
def mse_reduction_vs_zero(true_vecs: List[torch.Tensor], pred_vecs: List[torch.Tensor], mask: torch.Tensor) -> float:
    """Fractional MSE improvement of ``pred_vecs`` over an all-zero prediction.

    Only chunks where ``mask`` is True are scored. The result is
    ``1 - MSE(true, pred) / MSE(true, 0)``: 1.0 is perfect, 0.0 is no better
    than predicting zeros, and negative is worse. Returns NaN when ``mask``
    selects nothing.
    """
    selected = torch.nonzero(mask, as_tuple=False).flatten().tolist()
    if len(selected) == 0:
        return float("nan")
    target = torch.cat([true_vecs[j].flatten() for j in selected])
    guess = torch.cat([pred_vecs[j].flatten() for j in selected])
    baseline_mse = target.square().mean()
    residual_mse = (target - guess).square().mean()
    return float((1.0 - residual_mse / (baseline_mse + 1e-12)).item())
def active_only_prediction(true_vecs: List[torch.Tensor], active_mask: torch.Tensor) -> List[torch.Tensor]:
    """Return copies of the active chunks' true vectors and zeros for inactive chunks."""
    return [
        vec.clone() if bool(active_mask[pos]) else torch.zeros_like(vec)
        for pos, vec in enumerate(true_vecs)
    ]
def ema_direction_prediction(
    cmap: ChunkMap,
    true_vecs: List[torch.Tensor],
    active_mask: torch.Tensor,
    inactive_magnitudes: torch.Tensor,
) -> List[torch.Tensor]:
    """Estimate chunk gradient vectors from EMA directions and predicted magnitudes.

    Active chunks get a copy of their true vector. An inactive chunk ``i`` gets
    ``cmap.direction_ema[i] * inactive_magnitudes[i]`` (cast to the true vector's
    device and dtype), or zeros when no direction has been recorded yet.
    """
    predictions: List[torch.Tensor] = []
    for idx, vec in enumerate(true_vecs):
        if bool(active_mask[idx]):
            predictions.append(vec.clone())
            continue
        unit = cmap.direction_ema[idx]
        if unit is None:
            predictions.append(torch.zeros_like(vec))
        else:
            predictions.append(unit.to(vec.device, vec.dtype) * inactive_magnitudes[idx])
    return predictions
def build_mass_similarity(
    cmap: "ChunkMap",
    min_history: int = 8,
    same_module_only: bool = True,
) -> Optional[torch.Tensor]:
    """Positive Pearson correlation between chunk gradient-mass time series.

    Each chunk's series in ``cmap.mass_history`` (one mass vector per step) is
    z-scored. ``S[i, j]`` is the correlation of chunks ``i`` and ``j``, with
    negative correlations clamped to 0 and the diagonal zeroed.

    Args:
        cmap: Object providing ``mass_history`` and ``layer_aware_masks()``.
        min_history: Minimum number of stored steps; with fewer, returns ``None``.
        same_module_only: If True (the default and previous behaviour), keep only
            pairs of chunks that belong to the same mask from
            ``cmap.layer_aware_masks()``, i.e. the same SparseLinear module.
            Set False to allow cross-module neighbours.

    Returns:
        A non-negative ``[n_chunks, n_chunks]`` similarity matrix, or ``None``.
    """
    if len(cmap.mass_history) < min_history:
        return None
    H = torch.stack(cmap.mass_history, dim=0)  # [history, chunks]
    H = H - H.mean(dim=0, keepdim=True)
    # Unbiased std pairs with the (n - 1) divisor below, so S is a correlation.
    H = H / (H.std(dim=0, keepdim=True) + 1e-6)
    S = (H.T @ H) / max(1, H.shape[0] - 1)
    S = torch.clamp(S, min=0.0)
    # Remove self similarity.
    S.fill_diagonal_(0.0)
    if same_module_only:
        # Block-diagonal restriction. layer_aware_masks() yields one mask per
        # module, so this is finer-grained than "per transformer layer".
        allowed = torch.zeros_like(S, dtype=torch.bool)
        for mask in cmap.layer_aware_masks():
            allowed |= mask[:, None] & mask[None, :]
        S = torch.where(allowed, S, torch.zeros_like(S))
    return S
def knn_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    k_neighbors: int = 3,
) -> torch.Tensor:
    """Predict inactive magnitudes as weighted average of correlated active magnitudes.

    Active chunks keep their observed ``true_masses``. Each inactive chunk gets
    the similarity-weighted mean of the masses of its top-``k_neighbors`` active
    neighbours under ``build_mass_similarity(cmap)``. An inactive chunk with no
    positive similarity to any active chunk falls back to
    ``cmap.predicted_mass``. If there is not yet enough history for a
    similarity matrix, every inactive chunk uses that fallback. If no chunk is
    active, inactive predictions are left at 0.

    Returns:
        Predicted magnitudes, shaped like ``true_masses``.
    """
    S = build_mass_similarity(cmap)
    if S is None:
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred
    pred = torch.zeros_like(true_masses)
    pred[active_mask] = true_masses[active_mask]
    active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
    inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
    if active_idx.numel() == 0:
        return pred
    # Vectorized over all inactive chunks: the former per-chunk Python loop
    # forced a host/device synchronization for every inactive chunk.
    weights = S[inactive_idx][:, active_idx]  # [n_inactive, n_active]
    kk = min(k_neighbors, active_idx.numel())
    top_w, top_pos = torch.topk(weights, k=kk, dim=1)
    neighbour_mass = true_masses[active_idx][top_pos]
    estimate = (top_w * neighbour_mass).sum(dim=1) / (top_w.sum(dim=1) + 1e-12)
    has_neighbours = weights.sum(dim=1) > 1e-12
    pred[inactive_idx] = torch.where(
        has_neighbours, estimate, cmap.predicted_mass[inactive_idx]
    )
    return pred
def graph_diffusion_magnitude_prediction(
    cmap: ChunkMap,
    active_mask: torch.Tensor,
    true_masses: torch.Tensor,
    diffusion_steps: int = 8,
    alpha: float = 0.7,
) -> torch.Tensor:
    """Boundary-value style magnitude interpolation over a learned similarity graph.

    Starting from ``cmap.predicted_mass``, active nodes are clamped to their
    observed ``true_masses``. Each of ``diffusion_steps`` iterations moves every
    inactive node a fraction ``alpha`` toward the row-normalized weighted mean
    of its graph neighbours. Nodes with no neighbours (all-zero similarity row)
    keep their current value. Falls back to the EMA prior when there is not yet
    enough history for ``build_mass_similarity``.

    Returns:
        Non-negative predicted magnitudes, shaped like ``true_masses``.
    """
    S = build_mass_similarity(cmap)
    if S is None:
        pred = cmap.predicted_mass.clone()
        pred[active_mask] = true_masses[active_mask]
        return pred
    degree = S.sum(dim=1, keepdim=True)
    W = S / (degree + 1e-12)
    # For an isolated node, W @ pred is 0, so every step would shrink its value
    # by (1 - alpha) toward zero and erase its EMA prior. Treat it as a
    # self-loop instead (same "no neighbours" threshold as the kNN predictor).
    isolated = degree.squeeze(1) <= 1e-12
    pred = cmap.predicted_mass.clone()
    pred[active_mask] = true_masses[active_mask]
    for _ in range(diffusion_steps):
        proposal = torch.where(isolated, pred, W @ pred)
        pred = alpha * proposal + (1.0 - alpha) * pred
        pred[active_mask] = true_masses[active_mask]
    return torch.clamp(pred, min=0.0)
# -----------------------------
# Optimizer
# -----------------------------
class SimpleAdam:
    """Small Adam-like optimizer for diagnostics.

    This is intentionally simple and consistent across runs. It is not trying
    to be production AdamW: there is no weight decay and, by default, no bias
    correction. Early updates are therefore larger than true Adam's; with the
    default betas the first step moves each element by roughly ``3.16 * lr``.

    Args:
        model: Module whose ``parameters()`` are updated.
        lr: Step size.
        betas: EMA coefficients for the first and second gradient moments.
        eps: Added to ``sqrt(v)`` for numerical stability.
        bias_correction: If True, divide the moments by ``1 - beta**t``
            (per-parameter step count ``t``) as in standard Adam. The default
            False keeps the original behaviour.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 3e-4,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        bias_correction: bool = False,
    ):
        self.model = model
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.bias_correction = bias_correction
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}
        # Number of updates applied to each parameter (used for bias correction).
        self.step_counts: Dict[torch.nn.Parameter, int] = {}

    def zero_grad(self):
        """Drop all gradients by setting ``p.grad = None``."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that has a gradient."""
        beta1, beta2 = self.betas
        for p in self.model.parameters():
            if p.grad is None:
                continue
            if p not in self.state:
                self.state[p] = {
                    "m": torch.zeros_like(p),
                    "v": torch.zeros_like(p),
                }
            m = self.state[p]["m"]
            v = self.state[p]["v"]
            m.mul_(beta1).add_(p.grad, alpha=1.0 - beta1)
            v.mul_(beta2).addcmul_(p.grad, p.grad, value=1.0 - beta2)
            t = self.step_counts.get(p, 0) + 1
            self.step_counts[p] = t
            if self.bias_correction:
                m_hat = m / (1.0 - beta1 ** t)
                v_hat = v / (1.0 - beta2 ** t)
            else:
                m_hat, v_hat = m, v
            p.sub_(m_hat / (torch.sqrt(v_hat) + self.eps), alpha=self.lr)
# -----------------------------
# Apply chunk-gradient predictions
# -----------------------------
@torch.no_grad()
def install_chunk_prediction_as_grads(
    cmap: ChunkMap,
    pred_vecs: List[torch.Tensor],
):
    """Overwrite each SparseLinear's weight/bias ``.grad`` with predicted chunk vectors.

    Each vector ``pred_vecs[g]`` holds ``chunk_size * in_features`` weight-grad
    values, optionally followed by ``chunk_size`` bias-grad values. Modules
    whose weight grad is ``None`` are skipped. Bias grads are written only when
    ``bias.grad`` exists and the vector carries bias entries. Parameters that
    are not SparseLinear keep their dense gradients.
    """
    size = cmap.chunk_size
    for module, chunk_ids in cmap.module_to_chunk_ids.items():
        w_grad = module.weight.grad
        if w_grad is None:
            continue
        b_grad = module.bias.grad if module.bias is not None else None
        w_grad.zero_()
        if b_grad is not None:
            b_grad.zero_()
        in_features = module.weight.shape[1]
        n_weight = size * in_features
        for local_c, global_id in enumerate(chunk_ids.tolist()):
            rows = slice(local_c * size, (local_c + 1) * size)
            vec = pred_vecs[global_id]
            w_grad[rows] = vec[:n_weight].view(size, in_features)
            tail = vec[n_weight:]
            if b_grad is not None and tail.numel() > 0:
                b_grad[rows] = tail.view(size)
# -----------------------------
# Training / diagnostics
# -----------------------------
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss of ``model`` on one validation batch sampled with ``seed``.

    The forward pass runs in eval mode without autograd. The model's previous
    train/eval mode is restored afterwards, even if the forward pass raises;
    previously the model was always left in train mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
    finally:
        model.train(was_training)
    return float(loss.item())
def _predict_chunk_vectors(
    mode: str,
    cmap: "ChunkMap",
    true_vecs: List[torch.Tensor],
    true_masses: torch.Tensor,
    active_mask: torch.Tensor,
) -> List[torch.Tensor]:
    """Return the post-warmup chunk-gradient estimate that ``mode`` installs.

    Active chunks always keep their true vectors. Modes differ in how inactive
    chunks are filled: with true vectors ("dense"), zeros ("active_only"), or
    the EMA direction scaled by a predicted magnitude (the other modes).

    Raises:
        ValueError: For an unknown ``mode``.
    """
    if mode == "dense":
        return [v.clone() for v in true_vecs]
    if mode == "active_only":
        return active_only_prediction(true_vecs, active_mask)
    if mode == "knn_magnitude":
        pred_masses = knn_magnitude_prediction(cmap, active_mask, true_masses)
    elif mode == "graph_diffusion":
        pred_masses = graph_diffusion_magnitude_prediction(cmap, active_mask, true_masses)
    elif mode == "ema_inactive":
        pred_masses = cmap.predicted_mass.clone()
        pred_masses[active_mask] = true_masses[active_mask]
    else:
        raise ValueError(f"Unknown mode: {mode}")
    return ema_direction_prediction(cmap, true_vecs, active_mask, pred_masses)


def _nanmean(values: List[float]) -> float:
    """Mean of the non-NaN entries of ``values``; NaN if none remain."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float("nan")


def run_experiment(
    mode: str,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    policy: Policy,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train a fresh MiniGPT for ``steps`` steps with ``mode``'s chunk-gradient estimate.

    Each step computes dense gradients and picks the active chunks. After
    warmup, it replaces the SparseLinear gradients with the mode's prediction,
    updates the chunk predictor from the true vectors, and takes an optimizer
    step. Every 25 post-warmup steps it records diagnostics.

    Returns:
        Dict with keys:
        ``"val"`` – final validation loss (fixed batch, seed 12345);
        ``"ms"`` – wall-clock ms per training step, including in-loop
        diagnostics and periodic validation (NaN when ``steps == 0``);
        ``"cos"`` – mean cosine between installed and true chunk gradients;
        ``"mse_red"`` – mean inactive-chunk MSE reduction versus zeros.
        Averages skip NaN samples and are NaN when no samples exist.

    Raises:
        ValueError: For an unknown ``mode`` (checked before any training).
    """
    known_modes = ("dense", "active_only", "ema_inactive", "knn_magnitude", "graph_diffusion")
    if mode not in known_modes:
        raise ValueError(f"Unknown mode: {mode}")
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    cmap = ChunkMap(model, chunk_size=chunk_size, device=device)
    metric_rows = []
    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()
    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        true_vecs = cmap.chunk_gradient_vectors()
        true_masses = cmap.chunk_masses_from_vecs(true_vecs)
        active_mask = cmap.choose_active(
            step=step,
            warmup_steps=warmup_steps,
            active_fraction=active_fraction,
            policy=policy,
        )
        in_warmup = step < warmup_steps
        if in_warmup:
            pred_vecs = [v.clone() for v in true_vecs]
        else:
            pred_vecs = _predict_chunk_vectors(mode, cmap, true_vecs, true_masses, active_mask)
        install_chunk_prediction_as_grads(cmap, pred_vecs)
        # Warmup steps have every chunk active and pred == true. Sampling them
        # gave a trivial cosine of 1 and a NaN inactive-MSE (empty inactive
        # mask) that turned every mode's averaged "mse_red" into NaN.
        if not in_warmup and step % 25 == 0:
            inactive_mask = ~active_mask
            row = {
                "cosine_full": dense_cosine_from_vecs(true_vecs, pred_vecs),
                "inactive_mse_reduction": mse_reduction_vs_zero(true_vecs, pred_vecs, inactive_mask),
                "active_frac": float(active_mask.float().mean().item()),
                "val": evaluate(model, corpus, batch_size, seed=999 + step),
            }
            metric_rows.append(row)
        # Update predictor after measuring and installing predicted grads.
        # Direction/mass EMAs use only active chunks' true vectors (mimicking
        # sparse observation); mass_history, however, stores all chunks' masses.
        cmap.update_predictor(active_mask, true_vecs, store_history=True)
        opt.step()
    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0
    val_loss = evaluate(model, corpus, batch_size, seed=12345)
    return {
        "val": val_loss,
        "ms": 1000.0 * elapsed / steps if steps > 0 else float("nan"),
        "cos": _nanmean([r["cosine_full"] for r in metric_rows]),
        "mse_red": _nanmean([r["inactive_mse_reduction"] for r in metric_rows]),
    }
def main():
    """Command-line entry point: run the selected modes and print a summary table.

    All modes share the same configuration and seed (``run_experiment`` reseeds
    at the start of each run), so the rows are directly comparable.
    """
    all_modes = [
        "dense",
        "active_only",
        "ema_inactive",
        "knn_magnitude",
        "graph_diffusion",
    ]
    parser = argparse.ArgumentParser(
        description="Diagnose how predictable inactive chunk-gradient updates are."
    )
    parser.add_argument("--steps", type=int, default=300)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--block_size", type=int, default=128)
    parser.add_argument("--n_layer", type=int, default=4)
    parser.add_argument("--n_head", type=int, default=8)
    parser.add_argument("--n_embd", type=int, default=512)
    parser.add_argument("--chunk_size", type=int, default=64)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--policy", type=str, default="predicted_magnitude", choices=["predicted_magnitude", "random"])
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    parser.add_argument(
        "--modes",
        nargs="+",
        default=list(all_modes),
        choices=all_modes,
        help="Subset of prediction modes to run (default: all, in the listed order).",
    )
    args = parser.parse_args()
    print("\nInactive-update prediction diagnostic")
    # The flag is a chunk *size*; the old "chunks=" label suggested a chunk count.
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunk_size={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps} policy={args.policy}\n")
    print(f"{'mode':>18s} | {'val':>8s} | {'ms/step':>8s} | {'grad_cos':>8s} | {'inactive_mse+':>13s}")
    print("-" * 70)
    for mode in args.modes:
        result = run_experiment(
            mode=mode,
            device=args.device,
            steps=args.steps,
            batch_size=args.batch_size,
            block_size=args.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            chunk_size=args.chunk_size,
            active_fraction=args.active_fraction,
            warmup_steps=args.warmup_steps,
            policy=args.policy,
            benchmark_sync=args.benchmark_sync,
        )
        print(
            f"{mode:>18s} | "
            f"{result['val']:8.4f} | "
            f"{result['ms']:8.2f} | "
            f"{result['cos']:8.3f} | "
            f"{result['mse_red']:13.3f}"
        )
if __name__ == "__main__":
    # Run the diagnostic only when executed as a script, not when imported.
    main()