sparse-transformer-experiments / experiments /sparse_transformer_v16_sensor_scheduler.py
theapemachine's picture
Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb
"""
Sparse Transformer v16: Sensor-Based Mask Scheduling.
v15 showed that directly hallucinating inactive gradient vectors was harmful.
v16 tests the safer next idea:
Use active chunks as sensors to choose which chunks receive real gradients next.
No inactive gradient is invented. In sparse modes, inactive chunks get zero gradient.
The only question is whether active chunk observations improve future mask selection.
Schedulers:
    dense
        Dense baseline.
    ema_topk
        Select top chunks by each chunk's own EMA gradient mass.
    knn_scheduler
        Use active chunks as sensors. Predict next-step inactive chunk mass from
        historically correlated active chunks. Select next mask from that score.
    graph_scheduler
        Boundary-value style magnitude diffusion over a chunk similarity graph.
        Active chunks are clamped to observed magnitudes. Inactive magnitudes are
        interpolated and used to choose the next mask.
    random
        Random sparse-support control.
This is still a diagnostic/simulation script: it computes dense gradients so we can
measure oracle Jaccard/cosine, then installs only the selected active chunk gradients
for sparse training.
Run:
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --benchmark_sync
Useful:
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 512
    python3 sparse_transformer_v16_sensor_scheduler.py --device mps --steps 500 --n_embd 1024
"""
from __future__ import annotations
import argparse
import math
import random
import time
from typing import Dict, List, Literal, Optional, Tuple
import torch
torch.set_num_threads(1)
import torch.nn as nn
import torch.nn.functional as F
Scheduler = Literal["dense", "ema_topk", "knn_scheduler", "graph_scheduler", "random"]
def sync_device(device: str) -> None:
    """Block until work queued on ``device`` has finished.

    Only CUDA and MPS devices are synchronized; every other device string is a
    no-op. Indexed device strings such as ``"cuda:0"`` are recognised by their
    type prefix, and a backend is only synchronized when it is actually
    available, so calling this on a machine without that backend is harmless.

    Args:
        device: Device string, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"`` or ``"mps"``.
    """
    # Compare on the device *type* so "cuda:0" is treated like "cuda".
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda":
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)
    elif device_type == "mps":
        mps_backend = getattr(torch.backends, "mps", None)
        if hasattr(torch, "mps") and mps_backend is not None and mps_backend.is_available():
            torch.mps.synchronize()
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's global RNG with ``seed``."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
def make_cpu_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with ``seed``."""
    # ``Generator.manual_seed`` returns the generator itself, so it can be chained.
    return torch.Generator(device="cpu").manual_seed(seed)
# -----------------------------
# Data
# -----------------------------
def make_synthetic_corpus(n_sentences: int = 12000, seed: int = 7) -> str:
    """Build a deterministic toy corpus of random-word sentences.

    Each sentence is 4 to 10 words drawn with replacement from a fixed
    vocabulary, joined by spaces and terminated with a period. Sentences are
    separated by newlines.

    Args:
        n_sentences: Number of sentences to generate.
        seed: Seed for the private ``random.Random`` instance used for sampling.

    Returns:
        The corpus as a single string.
    """
    vocabulary = [
        "ada", "turing", "grace", "lovelace", "gradients",
        "tokens", "circuits", "features", "boldly", "strangely",
        "matrix", "attention", "kernel", "entropy", "signal",
    ]
    picker = random.Random(seed)
    sentences = []
    for _ in range(n_sentences):
        # Draw the length before the words so the RNG stream order is unchanged.
        length = picker.randint(4, 10)
        chosen = picker.choices(vocabulary, k=length)
        sentences.append(" ".join(chosen) + ".")
    return "\n".join(sentences)
class CharCorpus:
    """Character-level dataset with a 90/10 train/validation split.

    The vocabulary is the sorted set of characters in ``text``. The encoded
    text is kept as a ``torch.long`` tensor; batches are moved to ``device``
    only when they are returned from ``get_batch``.
    """

    def __init__(self, text: str, block_size: int, device: str):
        alphabet = sorted(set(text))
        self.stoi = {symbol: index for index, symbol in enumerate(alphabet)}
        self.itos = dict(enumerate(alphabet))
        self.vocab_size = len(alphabet)
        self.block_size = block_size
        self.device = device

        encoded = torch.tensor([self.stoi[symbol] for symbol in text], dtype=torch.long)
        split_at = int(0.9 * len(encoded))
        self.train_data = encoded[:split_at]
        self.val_data = encoded[split_at:]

    def get_batch(
        self,
        split: str,
        batch_size: int,
        generator: Optional[torch.Generator] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample ``batch_size`` windows of ``block_size`` characters.

        Args:
            split: ``"train"`` selects the training data; any other value
                selects the validation data.
            batch_size: Number of windows to draw.
            generator: Optional RNG used to draw the window start offsets.

        Returns:
            ``(x, y)`` on ``self.device``, where ``y`` is ``x`` shifted one
            character to the right in the source data.
        """
        source = self.val_data if split != "train" else self.train_data
        span = self.block_size
        starts = torch.randint(len(source) - span - 1, (batch_size,), generator=generator).tolist()
        inputs = torch.stack([source[begin : begin + span] for begin in starts])
        labels = torch.stack([source[begin + 1 : begin + 1 + span] for begin in starts])
        return inputs.to(self.device), labels.to(self.device)
# -----------------------------
# Model
# -----------------------------
class SparseLinear(nn.Linear):
    """Marker subclass of ``nn.Linear`` with no behavioural changes.

    Its only role is to tag the layers that ``get_sparse_linears`` collects via
    an ``isinstance`` check.
    """
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with ``SparseLinear`` projections.

    A fused ``c_attn`` projection produces queries, keys and values; a lower
    triangular ``mask`` buffer blocks attention to later positions; ``c_proj``
    maps the merged heads back to ``n_embd`` features.
    """

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.c_attn = SparseLinear(n_embd, 3 * n_embd)
        self.c_proj = SparseLinear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", lower_triangle.view(1, 1, block_size, block_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x`` of shape ``(B, T, C)`` and return the same shape."""
        batch, seq_len, width = x.shape
        per_head_shape = (batch, seq_len, self.n_head, self.head_dim)

        # Split the fused projection into q/k/v and move the head axis forward.
        query, key, value = (
            part.view(per_head_shape).transpose(1, 2)
            for part in self.c_attn(x).split(width, dim=2)
        )

        scores = (query @ key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        blocked = self.mask[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1)
        weights = self.dropout(weights)

        mixed = weights @ value
        merged = mixed.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.c_proj(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embd``, GELU, project back, dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden_width = 4 * n_embd
        self.c_fc = SparseLinear(n_embd, hidden_width)
        self.c_proj = SparseLinear(hidden_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the MLP output for ``x``; the last dimension is preserved."""
        expanded = self.c_fc(x)
        activated = F.gelu(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each with a residual add."""

    def __init__(self, n_embd: int, n_head: int, block_size: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, block_size, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))
class MiniGPT(nn.Module):
    """Small GPT-style language model: token + position embeddings, blocks, LM head."""

    def __init__(
        self,
        vocab_size: int,
        block_size: int,
        n_layer: int,
        n_head: int,
        n_embd: int,
        dropout: float,
    ):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        layers = [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute logits for ``idx`` and, when ``targets`` is given, the loss.

        Args:
            idx: Token ids of shape ``(B, T)``.
            targets: Optional target ids; flattened and compared against the
                flattened logits with cross-entropy.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if no targets were passed.
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.tok_emb(idx) + self.pos_emb(positions)[None, :, :]
        logits = self.lm_head(self.ln_f(self.blocks(hidden)))

        if targets is None:
            return logits, None
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss
def get_sparse_linears(model: nn.Module) -> List[SparseLinear]:
    """Collect every ``SparseLinear`` submodule of ``model`` in ``model.modules()`` order."""
    found = []
    for module in model.modules():
        if isinstance(module, SparseLinear):
            found.append(module)
    return found
# -----------------------------
# Chunk map and scheduler
# -----------------------------
class ChunkScheduler:
    """Choose which output-row chunks of the model's ``SparseLinear`` layers are active.

    Every layer returned by ``get_sparse_linears`` is split into
    ``out_features // chunk_size`` chunks of consecutive output rows, and each
    chunk gets a global id. The scheduler keeps an EMA of observed chunk
    gradient norms (``predicted_mass``) and, depending on ``scheduler``,
    derives ``next_scores`` that ``choose_mask`` ranks to pick the active set.

    Args:
        model: Model whose ``SparseLinear`` layers are chunked.
        chunk_size: Number of output rows per chunk; must divide every layer's
            ``out_features``.
        active_fraction: Fraction of chunks kept active in sparse modes.
        device: Device on which the scheduler's tensors are allocated.
        scheduler: Scheduling strategy name.
        mass_beta: EMA decay applied to ``predicted_mass``.

    Raises:
        ValueError: If a layer's ``out_features`` is not divisible by ``chunk_size``.
    """

    def __init__(
        self,
        model: nn.Module,
        chunk_size: int,
        active_fraction: float,
        device: str,
        scheduler: Scheduler,
        mass_beta: float = 0.95,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.active_fraction = active_fraction
        self.device = device
        self.scheduler = scheduler
        self.mass_beta = mass_beta
        self.linears = get_sparse_linears(model)
        self.module_to_chunk_ids: Dict[nn.Module, torch.Tensor] = {}
        self.chunk_to_module_local: List[Tuple[nn.Module, int]] = []

        offset = 0
        for m in self.linears:
            # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
            if m.out_features % chunk_size != 0:
                raise ValueError(
                    f"out_features {m.out_features} not divisible by chunk_size {chunk_size}"
                )
            n_chunks = m.out_features // chunk_size
            self.module_to_chunk_ids[m] = torch.arange(offset, offset + n_chunks, device=device)
            self.chunk_to_module_local.extend((m, local_c) for local_c in range(n_chunks))
            offset += n_chunks

        self.n_chunks = offset
        self.predicted_mass = torch.zeros(self.n_chunks, device=device)
        self.mass_history: List[torch.Tensor] = []
        self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=device)
        self.next_scores = torch.zeros(self.n_chunks, device=device)
        self.prev_mask: Optional[torch.Tensor] = None
        self.similarity: Optional[torch.Tensor] = None

    def k_active(self) -> int:
        """Return how many chunks are active in sparse modes.

        The count is at least 1 and never exceeds ``n_chunks``, so it is always
        a valid ``k`` for ``torch.topk`` even when ``active_fraction > 1``.
        """
        return min(self.n_chunks, max(1, int(self.active_fraction * self.n_chunks)))

    def choose_mask(self, step: int, warmup_steps: int) -> torch.Tensor:
        """Select, store and return the boolean active-chunk mask for ``step``.

        The mask is all-ones for the dense scheduler and during warmup.
        Otherwise exactly ``k_active()`` chunks are chosen: uniformly at random,
        or by top-k over scores with a tiny random tie-break.

        Raises:
            ValueError: If ``self.scheduler`` is not a known scheduler name.
        """
        if self.scheduler == "dense" or step < warmup_steps:
            self.current_mask = torch.ones(self.n_chunks, dtype=torch.bool, device=self.device)
            return self.current_mask

        k = self.k_active()
        if self.scheduler == "random":
            idx = torch.randperm(self.n_chunks, device=self.device)[:k]
        elif self.scheduler == "ema_topk":
            scores = self.predicted_mass + 1e-9 * torch.rand_like(self.predicted_mass)
            idx = torch.topk(scores, k=k).indices
        elif self.scheduler in ("knn_scheduler", "graph_scheduler"):
            # next_scores are computed from the previous step's active sensors.
            # If unavailable (all zero), fall back to the EMA.
            base = self.next_scores
            if torch.count_nonzero(base).item() == 0:
                base = self.predicted_mass
            scores = base + 1e-9 * torch.rand_like(base)
            idx = torch.topk(scores, k=k).indices
        else:
            raise ValueError(f"Unknown scheduler: {self.scheduler}")

        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[idx] = True
        self.current_mask = mask
        return mask

    @staticmethod
    def _grad_rows(param: torch.Tensor, start: int, end: int) -> torch.Tensor:
        """Return rows ``start:end`` of ``param.grad`` flattened, or zeros if there is no grad."""
        if param.grad is None:
            return torch.zeros_like(param[start:end]).flatten()
        return param.grad[start:end].detach().flatten()

    @torch.no_grad()
    def chunk_gradient_vectors(self) -> List[torch.Tensor]:
        """Return one flat gradient vector per chunk, in global chunk-id order.

        Each vector is the chunk's weight-gradient rows followed by its
        bias-gradient entries when the layer has a bias. Missing gradients are
        represented by zeros.
        """
        vecs: List[torch.Tensor] = []
        for m, local_c in self.chunk_to_module_local:
            start = local_c * self.chunk_size
            end = start + self.chunk_size
            parts = [self._grad_rows(m.weight, start, end)]
            if m.bias is not None:
                parts.append(self._grad_rows(m.bias, start, end))
            vecs.append(torch.cat(parts))
        return vecs

    @torch.no_grad()
    def chunk_masses_from_vecs(self, vecs: List[torch.Tensor]) -> torch.Tensor:
        """Return the norm of every chunk vector as a 1-D tensor on ``self.device``."""
        if not vecs:
            # torch.stack rejects an empty list; a model without chunks has no masses.
            return torch.zeros(0, device=self.device)
        return torch.stack([v.norm() for v in vecs]).to(self.device)

    @torch.no_grad()
    def update_from_observed(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        step: int,
        warmup_steps: int,
    ) -> None:
        """Fold the observed chunk masses into the scheduler state.

        Only entries selected by ``active_mask`` update ``predicted_mass``:
        chunks whose estimate is still zero take the observed value directly,
        the others are blended with the ``mass_beta`` EMA. During warmup the
        full mass vector is appended to a history capped at 128 entries, and
        the similarity matrix is rebuilt once at least 8 entries exist. Finally
        ``next_scores`` is recomputed for the configured scheduler.
        """
        observed = active_mask
        never_seen = observed & (self.predicted_mass == 0)
        already_seen = observed & ~never_seen
        self.predicted_mass[never_seen] = true_masses[never_seen]
        self.predicted_mass[already_seen] = (
            self.mass_beta * self.predicted_mass[already_seen]
            + (1.0 - self.mass_beta) * true_masses[already_seen]
        )

        # During warmup we store dense mass histories to learn the similarity graph.
        if step < warmup_steps:
            self.mass_history.append(true_masses.detach().clone())
            max_hist = 128
            if len(self.mass_history) > max_hist:
                self.mass_history = self.mass_history[-max_hist:]
            if len(self.mass_history) >= 8:
                self.similarity = self.build_similarity()

        # Compute next_scores from current active observations.
        if self.scheduler == "knn_scheduler":
            self.next_scores = self.knn_scores(active_mask, true_masses)
        elif self.scheduler == "graph_scheduler":
            self.next_scores = self.diffusion_scores(active_mask, true_masses)
        else:
            self.next_scores = self.predicted_mass.clone()

    def _within_layer_mask(self, device) -> torch.Tensor:
        """Return an ``[n_chunks, n_chunks]`` bool matrix, True where both chunks share a layer."""
        allowed = torch.zeros((self.n_chunks, self.n_chunks), dtype=torch.bool, device=device)
        for ids in self.module_to_chunk_ids.values():
            allowed[ids[:, None], ids[None, :]] = True
        return allowed

    def layer_allowed_mask(self) -> torch.Tensor:
        """Return the same-layer chunk-pair mask on ``self.device``."""
        return self._within_layer_mask(self.device)

    def build_similarity(self) -> torch.Tensor:
        """Build a non-negative, within-layer chunk similarity matrix from ``mass_history``.

        Histories are standardised per chunk and correlated; negative entries
        and the diagonal are zeroed, and every cross-layer entry is removed.
        """
        H = torch.stack(self.mass_history, dim=0)  # [history, chunks]
        H = H - H.mean(dim=0, keepdim=True)
        H = H / (H.std(dim=0, keepdim=True) + 1e-6)
        S = (H.T @ H) / max(1, H.shape[0] - 1)
        S = torch.clamp(S, min=0.0)
        S.fill_diagonal_(0.0)
        # Keep only within-layer similarities. Cross-layer correlation is too easy
        # to overfit in this tiny diagnostic.
        allowed = self._within_layer_mask(S.device)
        return torch.where(allowed, S, torch.zeros_like(S))

    def knn_scores(self, active_mask: torch.Tensor, true_masses: torch.Tensor, k_neighbors: int = 3) -> torch.Tensor:
        """Score every chunk, estimating inactive chunks from similar active ones.

        Active chunks score their observed mass. Each inactive chunk with some
        positive similarity to an active chunk scores the similarity-weighted
        mean mass of its ``k_neighbors`` most similar active chunks; the rest
        keep their ``predicted_mass``. Without a similarity matrix the result
        is a copy of ``predicted_mass``.

        All inactive chunks are processed in one batched ``topk`` rather than a
        per-chunk Python loop.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()

        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]
        active_idx = torch.nonzero(active_mask, as_tuple=False).flatten()
        inactive_idx = torch.nonzero(~active_mask, as_tuple=False).flatten()
        if active_idx.numel() == 0 or inactive_idx.numel() == 0:
            return scores

        # Rows: inactive chunks. Columns: active sensor chunks.
        weights = self.similarity[inactive_idx][:, active_idx]
        kk = min(k_neighbors, active_idx.numel())
        top = torch.topk(weights, k=kk, dim=1)
        neighbor_mass = true_masses[active_idx[top.indices]]
        estimate = (top.values * neighbor_mass).sum(dim=1) / (top.values.sum(dim=1) + 1e-12)

        # Chunks with no usable sensor keep their EMA prediction.
        has_sensor = ~(weights.sum(dim=1) <= 1e-12)
        scores[inactive_idx[has_sensor]] = estimate[has_sensor]
        return scores

    def diffusion_scores(
        self,
        active_mask: torch.Tensor,
        true_masses: torch.Tensor,
        diffusion_steps: int = 8,
        alpha: float = 0.7,
    ) -> torch.Tensor:
        """Score chunks by diffusing masses over the row-normalised similarity graph.

        Scores start from ``predicted_mass`` with active chunks set to their
        observed mass. Each step blends the current scores with their
        neighbour-weighted average using ``alpha`` and then re-clamps the
        active chunks to the observed values. The result is clamped at zero.
        """
        if self.similarity is None:
            return self.predicted_mass.clone()

        S = self.similarity
        W = S / (S.sum(dim=1, keepdim=True) + 1e-12)
        scores = self.predicted_mass.clone()
        scores[active_mask] = true_masses[active_mask]
        for _ in range(diffusion_steps):
            proposal = W @ scores
            scores = alpha * proposal + (1.0 - alpha) * scores
            scores[active_mask] = true_masses[active_mask]
        return torch.clamp(scores, min=0.0)

    def oracle_topk_mask(self, true_masses: torch.Tensor) -> torch.Tensor:
        """Return the mask selecting the ``k_active()`` chunks with the largest true mass."""
        k = self.k_active()
        mask = torch.zeros(self.n_chunks, dtype=torch.bool, device=self.device)
        mask[torch.topk(true_masses, k=k).indices] = True
        return mask
# -----------------------------
# Gradient installation and metrics
# -----------------------------
@torch.no_grad()
def install_active_only_grads(sched: ChunkScheduler, active_mask: torch.Tensor) -> None:
    """Zero the gradients of every inactive chunk in place.

    For each layer in ``sched.module_to_chunk_ids`` the weight-gradient rows
    and bias-gradient entries belonging to chunks that are not selected in
    ``active_mask`` are set to zero. Nothing happens for the ``"dense"``
    scheduler or for parameters without a gradient.

    Args:
        sched: Scheduler providing ``scheduler``, ``chunk_size`` and the
            layer-to-chunk-id mapping.
        active_mask: Boolean mask over global chunk ids; True keeps the chunk.
    """
    if sched.scheduler == "dense":
        return

    for module, ids in sched.module_to_chunk_ids.items():
        weight_grad = module.weight.grad
        bias_grad = module.bias.grad if module.bias is not None else None
        if weight_grad is None and bias_grad is None:
            continue

        # Expand the per-chunk flags to one flag per output row so a single
        # masked fill replaces the per-chunk Python loop.
        inactive_chunks = active_mask[ids].to(torch.bool).logical_not()
        inactive_rows = inactive_chunks.repeat_interleave(sched.chunk_size)

        if weight_grad is not None:
            weight_grad.masked_fill_(inactive_rows.unsqueeze(1), 0.0)
        if bias_grad is not None:
            bias_grad.masked_fill_(inactive_rows, 0.0)
def dense_cosine_active_only(vecs: List[torch.Tensor], active_mask: torch.Tensor) -> float:
    """Cosine similarity between the full gradient and its active-only version.

    Args:
        vecs: One gradient vector per chunk.
        active_mask: Per-chunk flags; chunks whose flag is falsy are replaced
            by zeros in the approximation.

    Returns:
        The cosine similarity of the two concatenated vectors as a float.
    """
    flat_chunks = [chunk.flatten() for chunk in vecs]
    kept_chunks = [
        piece if bool(active_mask[position]) else torch.zeros_like(piece)
        for position, piece in enumerate(flat_chunks)
    ]
    similarity = F.cosine_similarity(torch.cat(flat_chunks), torch.cat(kept_chunks), dim=0)
    return float(similarity.item())
def jaccard(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return ``|a & b| / |a | b|`` for two masks, with the denominator clamped to at least 1."""
    shared = (a & b).sum().float()
    combined = (a | b).sum().float().clamp(min=1.0)
    ratio = shared / combined
    return float(ratio.item())
class SimpleAdam:
    """Minimal Adam-style optimizer over all parameters of ``model``.

    It keeps first and second moment estimates per parameter and applies
    ``p -= lr * m / (sqrt(v) + eps)``. There is no bias correction and no
    weight decay.

    Args:
        model: Module whose parameters are updated.
        lr: Step size.
        betas: Decay rates ``(beta1, beta2)`` for the first and second moments.
        eps: Constant added to ``sqrt(v)`` to avoid division by zero.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 3e-4,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        self.model = model
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.state: Dict[torch.nn.Parameter, Dict[str, torch.Tensor]] = {}

    def zero_grad(self):
        """Drop every parameter's gradient by setting it to ``None``."""
        for p in self.model.parameters():
            p.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        beta1, beta2 = self.betas
        for p in self.model.parameters():
            grad = p.grad
            if grad is None:
                continue
            # Moment buffers are created lazily the first time a parameter has a gradient.
            slot = self.state.get(p)
            if slot is None:
                slot = {"m": torch.zeros_like(p), "v": torch.zeros_like(p)}
                self.state[p] = slot
            m = slot["m"]
            v = slot["v"]
            m.mul_(beta1).add_(grad, alpha=1.0 - beta1)
            v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
            p.sub_(m / (torch.sqrt(v) + self.eps), alpha=self.lr)
def evaluate(model: nn.Module, corpus: CharCorpus, batch_size: int, seed: int) -> float:
    """Return the loss of ``model`` on one seeded validation batch.

    The model is switched to eval mode for the forward pass and restored to
    whatever mode it was in before the call, even if the forward pass raises.

    Args:
        model: Model called as ``model(x, y)`` and returning ``(logits, loss)``.
        corpus: Source of the validation batch.
        batch_size: Number of sequences in the batch.
        seed: Seed for the CPU generator that picks the batch.

    Returns:
        The scalar loss as a Python float.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            x, y = corpus.get_batch("val", batch_size, generator=make_cpu_generator(seed))
            _, loss = model(x, y)
    finally:
        # Restore the caller's mode rather than unconditionally enabling training.
        model.train(was_training)
    return float(loss.item())
def run_experiment(
    scheduler_name: Scheduler,
    device: str,
    steps: int,
    batch_size: int,
    block_size: int,
    n_layer: int,
    n_head: int,
    n_embd: int,
    chunk_size: int,
    active_fraction: float,
    warmup_steps: int,
    benchmark_sync: bool,
) -> Dict[str, float]:
    """Train a fresh ``MiniGPT`` with one scheduler and report summary metrics.

    Every step computes dense gradients, records diagnostics against the
    oracle top-k mask (after warmup, for non-dense schedulers), zeroes the
    gradients of inactive chunks, updates the scheduler from the observed
    chunks and applies the optimizer step.

    Returns:
        A dict with:
        ``"val"``: final validation loss.
        ``"ms"``: mean wall-clock milliseconds per step (NaN when ``steps`` is 0).
        ``"cos"``: mean cosine between the dense and active-only gradient.
        ``"jacc"``: mean Jaccard overlap between the chosen and oracle masks.
        ``"stable"``: mean Jaccard overlap between consecutive masks.
        The last three are NaN when no diagnostic rows were recorded.
    """
    set_seed(42)
    corpus = CharCorpus(make_synthetic_corpus(), block_size, device)
    model = MiniGPT(corpus.vocab_size, block_size, n_layer, n_head, n_embd, 0.0).to(device)
    opt = SimpleAdam(model, lr=3e-4)
    sched = ChunkScheduler(
        model=model,
        chunk_size=chunk_size,
        active_fraction=active_fraction,
        device=device,
        scheduler=scheduler_name,
    )
    metric_rows: List[Dict[str, float]] = []

    if benchmark_sync:
        sync_device(device)
    t0 = time.perf_counter()

    for step in range(steps):
        x, y = corpus.get_batch("train", batch_size, generator=make_cpu_generator(step))
        active_mask = sched.choose_mask(step=step, warmup_steps=warmup_steps)
        opt.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        vecs = sched.chunk_gradient_vectors()
        masses = sched.chunk_masses_from_vecs(vecs)

        if step >= warmup_steps and scheduler_name != "dense":
            oracle = sched.oracle_topk_mask(masses)
            # No periodic validation pass here: its value was never read, and it
            # ran inside the timed loop for sparse schedulers only, which
            # skewed the ms/step comparison against the dense baseline.
            metric_rows.append(
                {
                    "cos": dense_cosine_active_only(vecs, active_mask),
                    "jacc": jaccard(active_mask, oracle),
                    "stable": jaccard(active_mask, sched.prev_mask) if sched.prev_mask is not None else 0.0,
                }
            )
            install_active_only_grads(sched, active_mask)

        # Important: update scheduler from the active observations only.
        # Dense gradients exist for diagnostics, but unselected chunks should not
        # teach the sparse scheduler after warmup.
        observed_for_scheduler = active_mask if step >= warmup_steps else torch.ones_like(active_mask)
        sched.update_from_observed(
            active_mask=observed_for_scheduler,
            true_masses=masses,
            step=step,
            warmup_steps=warmup_steps,
        )
        sched.prev_mask = active_mask.clone()
        opt.step()

    if benchmark_sync:
        sync_device(device)
    elapsed = time.perf_counter() - t0
    val_loss = evaluate(model, corpus, batch_size, seed=12345)

    nan = float("nan")
    averages = {
        key: (sum(r[key] for r in metric_rows) / len(metric_rows)) if metric_rows else nan
        for key in ("cos", "jacc", "stable")
    }
    return {
        "val": val_loss,
        # Guard the division so a zero-step run reports NaN instead of crashing.
        "ms": 1000.0 * elapsed / steps if steps > 0 else nan,
        "cos": averages["cos"],
        "jacc": averages["jacc"],
        "stable": averages["stable"],
    }
def main() -> None:
    """Parse the command line, run every scheduler and print one result row each."""
    parser = argparse.ArgumentParser()
    leading_int_options = (
        ("--steps", 500),
        ("--batch_size", 8),
        ("--block_size", 128),
        ("--n_layer", 4),
        ("--n_head", 8),
        ("--n_embd", 512),
        ("--chunk_size", 64),
    )
    for flag, default in leading_int_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--active_fraction", type=float, default=0.10)
    parser.add_argument("--warmup_steps", type=int, default=25)
    parser.add_argument("--device", type=str, default="mps")
    parser.add_argument("--benchmark_sync", action="store_true")
    args = parser.parse_args()

    schedulers: List[Scheduler] = [
        "dense",
        "ema_topk",
        "knn_scheduler",
        "graph_scheduler",
        "random",
    ]

    print("\nSensor-based mask scheduling diagnostic")
    print(f"device={args.device} steps={args.steps} d={args.n_embd} chunks={args.chunk_size}")
    print(f"active_fraction={args.active_fraction} warmup={args.warmup_steps}\n")

    column_titles = ("val", "ms/step", "grad_cos", "jacc", "stable")
    header_cells = [f"{'scheduler':>18s}"] + [f"{title:>8s}" for title in column_titles]
    print(" | ".join(header_cells))
    print("-" * 78)

    # Settings shared by every run; only the scheduler name varies.
    shared_settings = dict(
        device=args.device,
        steps=args.steps,
        batch_size=args.batch_size,
        block_size=args.block_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        chunk_size=args.chunk_size,
        active_fraction=args.active_fraction,
        warmup_steps=args.warmup_steps,
        benchmark_sync=args.benchmark_sync,
    )
    for sched_name in schedulers:
        result = run_experiment(scheduler_name=sched_name, **shared_settings)
        row_cells = [
            f"{sched_name:>18s}",
            f"{result['val']:8.4f}",
            f"{result['ms']:8.2f}",
            f"{result['cos']:8.3f}",
            f"{result['jacc']:8.3f}",
            f"{result['stable']:8.3f}",
        ]
        print(" | ".join(row_cells))


if __name__ == "__main__":
    main()