|
|
|
|
|
""" |
|
|
WIRE-SPEED TRANSFORMER - Learns directly from network stream |
|
|
No batching. No epochs. Just continuous absorption. |
|
|
|
|
|
Receives token ids on stdin from a Rust feeder, one integer id per line.
|
|
Takes one optimizer step every UPDATE_EVERY tokens (default 32), using up to the last CONFIG["ctx"] tokens as context.
|
|
""" |
|
|
|
|
|
import sys |
|
|
import math |
|
|
import time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from collections import deque |
|
|
|
|
|
|
|
|
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Allow float32 matmuls on CUDA to use TF32 (faster, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
|
|
|
|
# Model hyper-parameters, read by StreamingTransformer, WireSpeedTrainer and main().
CONFIG = dict(
    d=256,          # model (embedding) width
    layers=4,       # number of Block layers
    heads=8,        # attention heads per layer
    rank=32,        # columns of each attention layer's shared q/k projection U
    vocab=128_256,  # embedding rows; main() accepts token ids in [0, vocab)
    ctx=512,        # tokens of context fed to the model per update
)

# Online-training knobs.
LR = 1e-4             # AdamW learning rate
UPDATE_EVERY = 32     # optimizer step after this many new tokens
PRINT_EVERY = 10_000  # stats line after this many tokens
|
|
|
|
|
|
|
|
class TuneableAttention(nn.Module):
    """Multi-head self-attention whose query/key scores use a rank-``r`` subspace.

    Every head's queries and keys are multiplied by one shared ``(d // h, r)``
    matrix ``U`` before the dot product, so attention scores are computed in
    ``r`` dimensions. Values keep their full per-head width ``d // h``.

    Args:
        d: model width; must be divisible by ``h`` for the per-head view.
        h: number of heads.
        r: rank of the shared query/key projection.
    """

    def __init__(self, d, h, r):
        super().__init__()
        self.h = h
        self.dk = d // h
        self.r = r
        # Creation order (qkv, U, proj) is deliberate: it fixes the RNG draw
        # order for seeded initialisation and the state_dict layout.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r) * 0.02)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, N, d)`` and return ``(B, N, d)``.

        ``mask``, when given, is added to the ``(B, h, N, N)`` score tensor
        (via broadcasting) before the softmax.
        """
        batch, seq, width = x.shape
        packed = self.qkv(x).view(batch, seq, 3, self.h, self.dk)
        query, key, value = packed.unbind(2)

        # Shared low-rank projection: (B, N, h, dk) @ (dk, r) -> (B, N, h, r).
        query = query @ self.U
        key = key @ self.U

        # Put the head axis in front of the sequence axis: (B, h, N, *).
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        scores = (query @ key.transpose(-1, -2)) / math.sqrt(self.r)
        scores = scores if mask is None else scores + mask
        weights = F.softmax(scores, dim=-1)

        mixed = weights @ value
        merged = mixed.transpose(1, 2).reshape(batch, seq, width)
        return self.proj(merged)
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer: ``x + attn(LN(x))`` then ``x + MLP(LN(x))``.

    Args:
        d: model width.
        h: number of attention heads.
        r: rank forwarded to ``TuneableAttention``.
    """

    def __init__(self, d, h, r):
        super().__init__()
        hidden = 4 * d
        # Construction order (ln1, attn, ln2, ff) is kept so seeded
        # initialisation and state_dict keys stay the same.
        self.ln1 = nn.LayerNorm(d)
        self.attn = TuneableAttention(d, h, r)
        self.ln2 = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))

    def forward(self, x, mask):
        """Apply the attention and feed-forward sublayers as residual updates."""
        attended = x + self.attn(self.ln1(x), mask)
        return attended + self.ff(self.ln2(attended))
|
|
|
|
|
class StreamingTransformer(nn.Module):
    """Decoder-only language model.

    Pipeline: token embedding -> stack of ``Block`` layers -> final LayerNorm
    -> output head whose weight is tied to the embedding matrix.

    ``cfg`` keys read: ``d``, ``layers``, ``heads``, ``rank``, ``vocab``, and
    optionally ``emb_init_std`` (std of the tied embedding init, default 0.02).

    No positional embedding is added; the causal mask built in ``forward`` is
    the only source of order information.
    """

    def __init__(self, cfg):
        super().__init__()
        d, L, h, r, V = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"], cfg["vocab"]
        self.emb = nn.Embedding(V, d)
        self.blocks = nn.ModuleList([Block(d, h, r) for _ in range(L)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, V, bias=False)

        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

        # nn.Embedding initialises to N(0, 1). With the head tied to it, the
        # unit-variance LayerNorm output dotted with those rows gives initial
        # logits of std ~sqrt(d) (~16 at d=256) and a starting loss far above
        # ln(V). A small std keeps the initial distribution near uniform.
        # Done last so the Blocks' seeded initialisation is unaffected.
        nn.init.normal_(self.emb.weight, mean=0.0, std=cfg.get("emb_init_std", 0.02))

    def forward(self, x):
        """Map token ids ``x`` of shape ``(B, N)`` to logits of shape ``(B, N, V)``."""
        B, N = x.shape

        # Additive causal mask: 0 on and below the diagonal, -1e9 above it,
        # so position i only attends to positions <= i.
        mask = torch.triu(torch.ones(N, N, device=x.device), 1) * -1e9

        h = self.emb(x)
        for block in self.blocks:
            h = block(h, mask)
        return self.head(self.ln(h))

    def count_params(self):
        """Return the parameter count; the tied embedding/head matrix counts once."""
        # Module.parameters() de-duplicates shared tensors by default.
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
class WireSpeedTrainer:
    """Online trainer: ingests one token at a time, steps every ``update_every``.

    Incoming ids go into a sliding window of the last ``ctx_size + 1`` tokens.
    Each step runs the model on ``window[:-1]`` and scores only the targets
    that arrived since the previous step. As long as
    ``update_every <= ctx_size``, every token except the very first one is
    used as a training target exactly once.

    Args:
        model: next-token model; expected to map a ``(1, N)`` LongTensor of ids
            to ``(1, N, V)`` logits.
        lr: AdamW learning rate.
        ctx_size: input positions per step; defaults to ``CONFIG["ctx"]``.
        update_every: tokens between optimizer steps; defaults to ``UPDATE_EVERY``.
        print_every: tokens between stats lines; defaults to ``PRINT_EVERY``.
    """

    def __init__(self, model, lr=LR, ctx_size=None, update_every=UPDATE_EVERY,
                 print_every=PRINT_EVERY):
        self.model = model.to(DEVICE)
        self.opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95))
        self.ctx_size = CONFIG["ctx"] if ctx_size is None else ctx_size
        self.update_every = update_every
        self.print_every = print_every

        # One extra slot so inputs (buffer[:-1]) and targets (buffer[1:]) can
        # each span ctx_size positions.
        self.buffer = deque(maxlen=self.ctx_size + 1)

        self.tokens_seen = 0
        self.total_loss = 0.0  # sum of per-step losses since start
        self.updates = 0       # optimizer steps taken
        # Monotonic clock: elapsed time is unaffected by system clock changes
        # (e.g. NTP adjustments), unlike time.time().
        self.start_time = time.monotonic()

    def ingest_token(self, token_id):
        """Absorb a single token; step every ``update_every`` tokens."""
        self.buffer.append(token_id)
        self.tokens_seen += 1

        # Two buffered tokens (one input, one target) suffice for a step. The
        # previous ">= UPDATE_EVERY + 1" guard skipped the first step entirely,
        # so the targets of tokens 1..UPDATE_EVERY-1 were never trained on (and
        # no step ever ran if UPDATE_EVERY exceeded the context size).
        if self.tokens_seen % self.update_every == 0 and len(self.buffer) >= 2:
            self._update()

        if self.tokens_seen % self.print_every == 0:
            self._print_stats()

    def _update(self):
        """Take one gradient step on the targets that arrived since the last step."""
        tokens = list(self.buffer)
        x = torch.tensor(tokens[:-1], device=DEVICE).unsqueeze(0)
        y = torch.tensor(tokens[1:], device=DEVICE).unsqueeze(0)

        # Fewer than update_every targets exist during warm-up; if
        # update_every > ctx_size, targets older than the window are lost.
        n_new = min(self.update_every, y.size(1))

        self.model.train()
        logits = self.model(x)

        # Earlier positions still serve as context but were already scored by
        # a previous step, so only the newest n_new positions contribute.
        loss = F.cross_entropy(
            logits[:, -n_new:].reshape(-1, logits.size(-1)),
            y[:, -n_new:].reshape(-1),
        )

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.opt.step()

        self.total_loss += loss.item()
        self.updates += 1

    def _print_stats(self):
        """Print throughput and the mean per-step loss accumulated since start."""
        elapsed = time.monotonic() - self.start_time
        tok_per_sec = self.tokens_seen / elapsed if elapsed > 0 else 0
        avg_loss = self.total_loss / max(1, self.updates)

        print(f"[{elapsed:.0f}s] {self.tokens_seen:,} tok | {tok_per_sec:.0f} tok/s | "
              f"loss={avg_loss:.4f} | updates={self.updates}", flush=True)
|
|
|
|
|
|
|
|
def main():
    """Build the model and trainer, then train on token ids read from stdin.

    Each non-blank stdin line is parsed as one integer token id. The following
    are skipped and counted rather than silently dropped:
      - lines that are not integers;
      - ids outside ``[0, CONFIG["vocab"])``.
    Ctrl-C stops ingestion but still prints the closing summary.
    """
    print("Wire-Speed Transformer", flush=True)
    print(f"Config: {CONFIG}", flush=True)
    print(f"Device: {DEVICE}", flush=True)

    model = StreamingTransformer(CONFIG)
    params = model.count_params()
    print(f"Parameters: {params:,} ({params/1e6:.1f}M)", flush=True)

    trainer = WireSpeedTrainer(model)

    print("Listening for tokens on stdin...", flush=True)
    print(f"Update every {UPDATE_EVERY} tokens, print every {PRINT_EVERY}", flush=True)

    vocab = CONFIG["vocab"]
    skipped = 0
    try:
        for line in sys.stdin:
            text = line.strip()
            if not text:
                continue
            # Only the parse is guarded. Previously ingest_token() sat inside
            # this try, so a ValueError raised during training was swallowed
            # as if it were a malformed input line.
            try:
                token_id = int(text)
            except ValueError:
                skipped += 1
                continue
            if 0 <= token_id < vocab:
                trainer.ingest_token(token_id)
            else:
                # Out-of-vocab ids usually mean a feeder/tokenizer mismatch.
                skipped += 1
    except KeyboardInterrupt:
        print("Interrupted.", flush=True)

    print(f"Stream ended. Total tokens: {trainer.tokens_seen:,}", flush=True)
    if skipped:
        print(f"Skipped {skipped:,} non-integer or out-of-vocab lines", flush=True)
|
|
|
|
|
# Start the stdin training loop only when run as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|