#!/usr/bin/env python3
"""
WIRE-SPEED TRANSFORMER - Learns directly from network stream

No batching. No epochs. Just continuous absorption.
Receives tokenized data via stdin from Rust feeder.
Updates weights after every micro-batch (configurable, default 32 tokens).
"""

import sys
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque

# ─────────────────── Config ───────────────────

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Allow TF32 for float32 matmuls on CUDA (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True

# Tiny model for wire-speed updates
CONFIG = {
    "d": 256,         # embedding dim
    "layers": 4,      # transformer layers
    "heads": 8,       # attention heads
    "rank": 32,       # attention rank (from n.py's tuneable attention)
    # NOTE(review): 128256 is the Llama-3 vocab size; DeepSeek-V3's published
    # config lists 129280. main() silently drops IDs >= this value, so confirm
    # it matches the tokenizer the Rust feeder uses.
    "vocab": 128256,  # DeepSeek V3.2 vocab
    "ctx": 512,       # context window
}

LR = 1e-4
UPDATE_EVERY = 32      # tokens between weight updates (micro-batch)
PRINT_EVERY = 10000    # tokens between stats

# ─────────────────── Model (simplified from n.py) ───────────────────


class TuneableAttention(nn.Module):
    """Multi-head attention whose query/key scores are computed in rank ``r``.

    Per-head queries and keys (width ``dk = d // h``) are both multiplied by a
    shared ``(dk, r)`` matrix ``U`` before the dot product, so ``r`` sets the
    dimension of the score computation independently of the head width.
    ``d`` must be divisible by ``h`` for the ``view`` in ``forward``.
    """

    def __init__(self, d, h, r):
        super().__init__()
        self.h, self.dk, self.r = h, d // h, r
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r) * 0.02)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, N, d); ``mask`` is added to the scores."""
        B, N, D = x.shape
        qkv = self.qkv(x).view(B, N, 3, self.h, self.dk)
        q, k, v = qkv.unbind(2)  # each (B, N, h, dk)

        # Project Q and K through U for tuneable rank
        q = q @ self.U  # (B, N, h, r)
        k = k @ self.U  # (B, N, h, r)

        # Put heads before the sequence axis: (B, h, N, ·)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.r)
        if mask is not None:
            att = att + mask
        att = F.softmax(att, dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, N, D)
        return self.proj(out)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual MLP."""

    def __init__(self, d, h, r):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.attn = TuneableAttention(d, h, r)
        self.ln2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d)
        )

    def forward(self, x, mask):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x


class StreamingTransformer(nn.Module):
    """Decoder-only causal LM with tied input/output embeddings.

    ``cfg`` must provide the keys "d", "layers", "heads", "rank" and "vocab".
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h, r, V = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"], cfg["vocab"]
        self.emb = nn.Embedding(V, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, V, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x, last_n=None):
        """Return next-token logits for token IDs ``x`` of shape (B, N).

        Args:
            x: integer token-ID tensor of shape (B, N).
            last_n: optional. If given, only the final ``last_n`` positions go
                through the output LayerNorm and vocab projection, and the
                result is (B, min(last_n, N), V) instead of (B, N, V). Both of
                those layers act per position, so the returned logits equal the
                matching slice of the full output. With a large vocab this
                avoids materialising (and back-propagating) most of the
                largest activation in the model.

        Raises:
            ValueError: if ``last_n`` is given and is less than 1 (a slice of
                ``-0:`` would otherwise silently return every position).
        """
        N = x.shape[1]
        # Additive causal mask: -1e9 above the diagonal blocks attention to
        # future positions (those weights underflow to 0 after softmax).
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9
        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        if last_n is not None:
            if last_n < 1:
                raise ValueError(f"last_n must be >= 1, got {last_n}")
            h = h[:, -last_n:]
        return self.head(self.ln(h))

    def count_params(self):
        """Return the number of unique parameters (the tied matrix counts once)."""
        return sum(p.numel() for p in self.parameters())


# ─────────────────── Online Trainer ───────────────────


class WireSpeedTrainer:
    """Online trainer: consumes one token at a time, steps AdamW every UPDATE_EVERY tokens.

    ``model`` is moved to DEVICE in place. Its ``forward`` must accept the
    ``last_n`` keyword, as StreamingTransformer does.
    """

    def __init__(self, model, lr=LR):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95))
        self.ctx_size = CONFIG["ctx"]

        # Rolling buffer: ctx_size inputs plus one extra token, so the last
        # input position still has a next-token target.
        self.buffer = deque(maxlen=self.ctx_size + 1)

        # Stats
        self.tokens_seen = 0
        self.total_loss = 0.0
        self.updates = 0
        # Monotonic clock so elapsed time / throughput can't jump with wall-clock changes.
        self.start_time = time.monotonic()

    def ingest_token(self, token_id):
        """Absorb a single token; update weights and print stats on schedule.

        An update runs when ``tokens_seen`` is a multiple of UPDATE_EVERY and
        the buffer holds at least UPDATE_EVERY + 1 tokens, so every loss covers
        a full UPDATE_EVERY targets. As a result, the first update happens at
        token 2 * UPDATE_EVERY (64 with the defaults), not at UPDATE_EVERY.
        """
        self.buffer.append(token_id)
        self.tokens_seen += 1

        # Update every N tokens when we have enough context
        if len(self.buffer) >= UPDATE_EVERY + 1 and self.tokens_seen % UPDATE_EVERY == 0:
            self._update()

        # Print stats
        if self.tokens_seen % PRINT_EVERY == 0:
            self._print_stats()

    def _update(self):
        """Take one AdamW step on next-token loss for the newest UPDATE_EVERY targets."""
        tokens = list(self.buffer)
        x = torch.tensor(tokens[:-1], device=DEVICE).unsqueeze(0)  # input  (1, N)
        y = torch.tensor(tokens[1:], device=DEVICE).unsqueeze(0)   # target (1, N)

        self.model.train()
        # The whole buffer is context, but only the most recent UPDATE_EVERY
        # positions are scored, so only those need vocab logits.
        logits = self.model(x, last_n=UPDATE_EVERY)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y[:, -UPDATE_EVERY:].reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print elapsed time, throughput, and the mean loss over all updates so far."""
        elapsed = time.monotonic() - self.start_time
        tok_per_sec = self.tokens_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)
        print(f"[{elapsed:.0f}s] {self.tokens_seen:,} tok | {tok_per_sec:.0f} tok/s | "
              f"loss={avg_loss:.4f} | updates={self.updates}", flush=True)


# ─────────────────── Main ───────────────────


def main():
    """Build the model, then train on token IDs read from stdin (one per line) until EOF."""
    print("Wire-Speed Transformer", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = StreamingTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)

    trainer = WireSpeedTrainer(model)
    print("Listening for tokens on stdin...", flush=True)
    print(f"Update every {UPDATE_EVERY} tokens, print every {PRINT_EVERY}", flush=True)

    # Read token IDs from stdin (one per line from Rust feeder)
    for line in sys.stdin:
        # Only the parse is guarded: a ValueError raised while training must
        # surface, not be mistaken for a malformed input line.
        try:
            token_id = int(line)  # int() itself ignores surrounding whitespace
        except ValueError:
            continue  # Skip malformed lines
        # Out-of-range IDs would index past the embedding table; drop them.
        if 0 <= token_id < CONFIG["vocab"]:
            trainer.ingest_token(token_id)

    print(f"Stream ended. Total tokens: {trainer.tokens_seen:,}", flush=True)


if __name__ == "__main__":
    main()