wire-speed-transformer / stream_trainer.py
OpenTransformer's picture
Upload stream_trainer.py with huggingface_hub
44d9388 verified
#!/usr/bin/env python3
"""
WIRE-SPEED TRANSFORMER - Learns directly from network stream
No batching. No epochs. Just continuous absorption.
Receives tokenized data via stdin from Rust feeder.
Updates weights after every micro-batch (configurable, default 32 tokens).
"""
import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
# ─────────────────── Config ───────────────────
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Allow CUDA matmuls to use TF32.
torch.backends.cuda.matmul.allow_tf32 = True

# Tiny model for wire-speed updates.
CONFIG = dict(
    d=256,         # embedding dim
    layers=4,      # transformer layers
    heads=8,       # attention heads
    rank=32,       # attention rank (from n.py's tuneable attention)
    # Vocabulary size; the original note says "DeepSeek V3.2 vocab".
    # NOTE(review): confirm this matches the tokenizer the feeder actually
    # uses -- the stdin loop discards any id outside [0, vocab).
    vocab=128256,
    ctx=512,       # context window
)

LR = 1e-4
UPDATE_EVERY = 32    # tokens between weight updates (micro-batch)
PRINT_EVERY = 10000  # tokens between stats
# ─────────────────── Model (simplified from n.py) ───────────────────
class TuneableAttention(nn.Module):
    """Multi-head self-attention whose scores are computed in a rank-``r`` space.

    Queries and keys of every head are multiplied by one shared learnable
    matrix ``U`` of shape ``(d // h, r)`` before the dot product; values keep
    their full per-head width ``d // h``.
    """

    def __init__(self, d, h, r):
        """Create the layer.

        Args:
            d: Model width (size of the last input dimension).
            h: Number of attention heads; the per-head width is ``d // h``.
            r: Width of the space in which query/key scores are computed.
        """
        super().__init__()
        self.h = h
        self.dk = d // h
        self.r = r
        # Single fused projection producing queries, keys and values.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        # Shared (dk, r) projection applied to both queries and keys.
        self.U = nn.Parameter(torch.randn(self.dk, r) * 0.02)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply attention to ``x`` of shape ``(B, N, D)``.

        Args:
            x: Input tensor, unpacked as ``(batch, sequence, width)``.
            mask: Optional tensor added to the raw attention scores before
                the softmax (it must broadcast against ``(B, h, N, N)``).

        Returns:
            Tensor of shape ``(B, N, D)``.
        """
        batch, seq_len, width = x.shape
        fused = self.qkv(x).view(batch, seq_len, 3, self.h, self.dk)
        queries, keys, values = fused.unbind(dim=2)  # each (B, N, h, dk)

        # (B, N, h, dk) @ (dk, r) -> (B, N, h, r); then put heads before sequence.
        queries = torch.matmul(queries, self.U).transpose(1, 2)
        keys = torch.matmul(keys, self.U).transpose(1, 2)
        values = values.transpose(1, 2)

        # Scores are scaled by sqrt(r), the width they were computed in.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.r)
        if mask is not None:
            scores = scores + mask
        weights = F.softmax(scores, dim=-1)

        # (B, h, N, dk) -> (B, N, h, dk) -> (B, N, D)
        mixed = torch.matmul(weights, values).transpose(1, 2)
        return self.proj(mixed.reshape(batch, seq_len, width))
class Block(nn.Module):
    """Pre-norm transformer layer: attention and feed-forward, each with a residual."""

    def __init__(self, d, h, r):
        """Create the layer.

        Args:
            d: Model width.
            h: Number of attention heads, passed to ``TuneableAttention``.
            r: Attention rank, passed to ``TuneableAttention``.
        """
        super().__init__()
        hidden = 4 * d  # feed-forward expansion width
        self.ln1 = nn.LayerNorm(d)
        self.attn = TuneableAttention(d, h, r)
        self.ln2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))

    def forward(self, x, mask):
        """Return ``x`` after the attention and feed-forward sub-layers.

        Each sub-layer sees a layer-normalised copy of its input and its
        output is added back onto the un-normalised stream.
        """
        attended = self.attn(self.ln1(x), mask)
        x = x + attended
        transformed = self.ff(self.ln2(x))
        return x + transformed
class StreamingTransformer(nn.Module):
    """Causal language model: token embedding, ``Block`` stack, tied output head.

    No positional embedding is added to the token embeddings; the causal mask
    built in ``forward`` is the only position-dependent input to the blocks.
    """

    def __init__(self, cfg):
        """Build the model from a config mapping.

        Args:
            cfg: Mapping with keys ``"d"`` (width), ``"layers"``, ``"heads"``,
                ``"rank"`` and ``"vocab"``.
        """
        super().__init__()
        width, depth, n_heads, rank, vocab_size = (
            cfg[key] for key in ("d", "layers", "heads", "rank", "vocab")
        )
        self.emb = nn.Embedding(vocab_size, width)
        layers = [Block(width, n_heads, rank) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.ln = nn.LayerNorm(width)
        self.head = nn.Linear(width, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x):
        """Map token ids of shape ``(B, N)`` to logits of shape ``(B, N, vocab)``."""
        _, seq_len = x.shape
        # Strictly-upper-triangular additive mask: a large negative score on
        # every "future" position so the softmax gives it zero weight.
        causal_mask = torch.ones(seq_len, seq_len, device=x.device).triu(1) * -1e9
        hidden = self.emb(x)
        for layer in self.blocks:
            hidden = layer(hidden, causal_mask)
        return self.head(self.ln(hidden))

    def count_params(self):
        """Return the total element count over ``self.parameters()``."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
# ─────────────────── Online Trainer ───────────────────
class WireSpeedTrainer:
    """Online next-token trainer that consumes one token id at a time.

    The most recent ``CONFIG["ctx"] + 1`` token ids are kept in a rolling
    buffer. Whenever the running token count is a multiple of
    ``UPDATE_EVERY`` and the buffer holds at least ``UPDATE_EVERY + 1`` ids,
    one AdamW step is taken on the buffer, with the loss restricted to the
    newest ``UPDATE_EVERY`` positions.
    """

    def __init__(self, model, lr=LR):
        """Move ``model`` to ``DEVICE`` and set up optimizer, buffer and counters.

        Args:
            model: Module mapping a ``(1, N)`` tensor of token ids to
                ``(1, N, vocab)`` logits.
            lr: Learning rate handed to AdamW.
        """
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(self.model.parameters(), lr=lr, betas=(0.9, 0.95))
        self.ctx_size = CONFIG["ctx"]
        # Rolling buffer for context; one extra slot so that inputs and
        # shifted-by-one targets can both span ``ctx_size`` positions.
        self.buffer = deque(maxlen=self.ctx_size + 1)
        # Stats
        self.tokens_seen = 0
        self.total_loss = 0.0  # sum of per-update losses since construction
        self.updates = 0
        self.start_time = time.time()

    def ingest_token(self, token_id):
        """Absorb a single token; train and/or print stats when due.

        A weight update runs when ``tokens_seen`` is a multiple of
        ``UPDATE_EVERY`` and the buffer holds at least ``UPDATE_EVERY + 1``
        ids. Stats are printed when ``tokens_seen`` is a multiple of
        ``PRINT_EVERY``.
        """
        self.buffer.append(token_id)
        self.tokens_seen += 1

        # Need UPDATE_EVERY targets plus at least one token of left context.
        enough_context = len(self.buffer) >= UPDATE_EVERY + 1
        if enough_context and self.tokens_seen % UPDATE_EVERY == 0:
            self._update()

        if self.tokens_seen % PRINT_EVERY == 0:
            self._print_stats()

    def _update(self):
        """Take a single gradient step on the current buffer contents."""
        tokens = list(self.buffer)
        # Shift by one: position i of ``x`` is trained to predict ``y[:, i]``.
        x = torch.tensor(tokens[:-1], device=DEVICE).unsqueeze(0)  # input
        y = torch.tensor(tokens[1:], device=DEVICE).unsqueeze(0)   # target

        self.model.train()
        logits = self.model(x)

        # Older positions only serve as context; score the newest ones.
        # The class count is read from the logits themselves rather than from
        # the global CONFIG, so a model built with a different vocabulary
        # size is handled correctly.
        recent_logits = logits[:, -UPDATE_EVERY:]
        recent_targets = y[:, -UPDATE_EVERY:]
        loss = F.cross_entropy(
            recent_logits.reshape(-1, recent_logits.size(-1)),
            recent_targets.reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before stepping.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print elapsed time, throughput, mean loss and update count.

        The reported loss is the mean over every update since construction,
        not a recent-window average.
        """
        elapsed = time.time() - self.start_time
        tok_per_sec = self.tokens_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)
        print(f"[{elapsed:.0f}s] {self.tokens_seen:,} tok | {tok_per_sec:.0f} tok/s | "
              f"loss={avg_loss:.4f} | updates={self.updates}", flush=True)
# ─────────────────── Main ───────────────────
def main():
    """Build the model and train it on token ids read from stdin.

    One integer token id is expected per line. Lines that do not parse as an
    integer, and ids outside ``[0, CONFIG["vocab"])``, are skipped and
    counted; the count is reported once the stream ends.
    """
    print("Wire-Speed Transformer", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = StreamingTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)

    trainer = WireSpeedTrainer(model)
    print("Listening for tokens on stdin...", flush=True)
    print(f"Update every {UPDATE_EVERY} tokens, print every {PRINT_EVERY}", flush=True)

    vocab = CONFIG["vocab"]
    skipped = 0
    # Read token IDs from stdin (one per line from Rust feeder)
    for line in sys.stdin:
        # Only the parse is guarded: a ValueError raised inside the training
        # step must propagate instead of being mistaken for a malformed line.
        try:
            token_id = int(line.strip())
        except ValueError:
            skipped += 1
            continue
        if not 0 <= token_id < vocab:
            skipped += 1
            continue
        trainer.ingest_token(token_id)

    print(f"Stream ended. Total tokens: {trainer.tokens_seen:,}", flush=True)
    if skipped:
        print(f"Skipped lines (malformed or out-of-range): {skipped:,}", flush=True)


if __name__ == "__main__":
    main()