Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| # Stage One of Twelve — CIFAR-10 Baseline Validation | |
| # Rendered Frame Theory (RFT): DCLR Governor + Ψ–Ω (Orbital) Coupler | |
| # Modes: RFT (DCLR) or BASE (Adam) | |
| import os, math, time, json, argparse, random | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| import torchvision.transforms as T | |
# ---------------- Seeding ----------------
def set_seed(seed: int = 1234, deterministic: bool = False):
    """Seed Python's and PyTorch's RNGs and configure kernel determinism.

    Args:
        seed: Seed applied to ``random``, the PyTorch CPU generator and every
            CUDA device generator.
        deterministic: When False (the default, matching the previous
            behaviour) PyTorch may pick non-deterministic kernels, so runs are
            seeded but not bit-for-bit reproducible. When True, deterministic
            algorithms are enforced and cuDNN is pinned to deterministic
            kernels; an op without a deterministic implementation then raises.
    """
    if deterministic:
        # cuBLAS needs a fixed workspace configuration for reproducible GEMMs
        # on CUDA >= 10.2; keep any value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(deterministic)
    # Benchmark mode picks kernels by timing them, which varies run to run.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = deterministic
| # ---------------- Telemetry ------------------ | |
| class Telemetry: | |
| def __init__(self, log_path: str | None = None): | |
| self.t0 = time.time() | |
| self.fh = open(log_path, "w") if log_path else None | |
| def emit(self, **k): | |
| k["t"] = round(time.time() - self.t0, 3) | |
| line = json.dumps(k, separators=(",", ":")) | |
| print(line) | |
| if self.fh: | |
| self.fh.write(line + "\n") | |
| self.fh.flush() | |
| def close(self): | |
| if self.fh: | |
| self.fh.close() | |
| # ---------------- Optional NVML -------------- | |
try:
    import pynvml
    pynvml.nvmlInit()
except Exception:
    # pynvml not installed, no NVIDIA driver, or NVML init failed:
    # GPU power/temperature readings are disabled.
    _NVML_OK = False
else:
    _NVML_OK = True
class EnergyMeter:
    """Rough per-interval GPU energy estimate from NVML readings.

    ``end_step`` returns ``(joules, temperature_c)``. ``joules`` is the power
    read at the end of the interval multiplied by the interval's duration.
    Both values are ``None`` when NVML is unavailable or a query fails.
    """

    def __init__(self, device_index: int = 0):
        self.dev_index = device_index
        self.last_t = None
        self._handle = None  # NVML device handle, looked up lazily and reused

    def begin_step(self):
        """Mark the start of a measured interval."""
        # perf_counter is monotonic; wall-clock adjustments to time.time()
        # could otherwise produce negative or inflated durations.
        self.last_t = time.perf_counter()

    def end_step(self):
        """Return ``(joules, temp_c)`` for the interval since ``begin_step``."""
        now = time.perf_counter()
        # Without a prior begin_step the interval counts as zero-length.
        dt = now - self.last_t if self.last_t is not None else 0.0
        if not _NVML_OK:
            return None, None
        try:
            if self._handle is None:
                self._handle = pynvml.nvmlDeviceGetHandleByIndex(self.dev_index)
            power_w = pynvml.nvmlDeviceGetPowerUsage(self._handle) / 1000.0  # mW -> W
            temp_c = pynvml.nvmlDeviceGetTemperature(self._handle, pynvml.NVML_TEMPERATURE_GPU)
            return power_w * dt, float(temp_c)
        except Exception:
            # Best effort: a telemetry failure must not abort training.
            return None, None
| # ---------------- Orbital Coupler (Ψ–Ω) ------ | |
class Orbital:
    """Two phase angles (Ψ in ``a``, Ω in ``b``) pulled together by a sine coupling.

    Every ``step`` moves the angles symmetrically by ``sync_gain * sin(gap)``,
    where ``gap`` is their difference wrapped to [-pi, pi). Before the sine
    is taken, ``gap`` is clamped to at least ``sat_floor`` in magnitude, so
    the update does not vanish as the phases converge.
    """

    def __init__(self, sync_gain: float = 0.006, sat_floor: float = 0.2):
        self.a = 0.0            # Ψ phase, radians
        self.b = math.pi / 3    # Ω phase, radians (starts 60° ahead of Ψ)
        self.g = sync_gain
        self.floor = sat_floor

    @staticmethod
    def _wrap(angle):
        """Map an angle (radians) into [-pi, pi)."""
        return (angle + math.pi) % (2 * math.pi) - math.pi

    def step(self):
        """Advance one tick and return ``(drift, flux)``.

        ``drift`` is the absolute wrapped phase gap after the update and
        ``flux`` is the magnitude of the sine coupling applied this tick.
        """
        gap = self._wrap(self.b - self.a)
        if abs(gap) < self.floor:
            gap = self.floor if gap >= 0 else -self.floor
        pull = math.sin(gap)
        full_turn = 2 * math.pi
        self.a = (self.a + self.g * pull) % full_turn
        self.b = (self.b - self.g * pull) % full_turn
        return abs(self._wrap(self.a - self.b)), abs(pull)
| # ---------------- DCLR Optimiser ------------- | |
class DCLR(torch.optim.Optimizer):
    """Adam-like optimiser whose step shrinks where gradients are incoherent.

    Per-parameter state:
      * ``m``   - EMA of the gradient (decay ``beta``),
      * ``v``   - EMA of the squared gradient (decay ``gamma``),
      * ``coh`` - EMA (fixed decay 0.9) of ``|grad - m|``, i.e. how far the
        current gradient departs from its running mean.

    Update: ``p -= lr / (1 + coherence_gain * coh) * m / (sqrt(v) + eps)``.
    Unlike Adam, ``m`` and ``v`` are not bias-corrected.

    ``step`` returns ``(loss, J_proxy)`` instead of just the loss, where
    ``J_proxy`` is the sum of squared update elements over all parameters.
    """

    def __init__(self, params, lr=5e-4, beta=0.9, gamma=0.999, eps=1e-8, coherence_gain=0.05):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"Invalid beta value: {beta}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma value: {gamma}")
        if coherence_gain < 0.0:
            # A negative gain could drive 1 + cg * coh to zero (division by zero).
            raise ValueError(f"Invalid coherence_gain value: {coherence_gain}")
        defaults = dict(lr=lr, beta=beta, gamma=gamma, eps=eps, coherence_gain=coherence_gain)
        super().__init__(params, defaults)

    # In-place updates of leaf Parameters that require grad must happen with
    # autograd disabled; otherwise PyTorch raises a RuntimeError.
    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it runs with gradients enabled.

        Returns:
            ``(loss, J_proxy)``: ``loss`` is the closure's result (``None``
            without a closure); ``J_proxy`` is a Python float.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        total_J_proxy = 0.0
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["beta"]
            gamma = group["gamma"]
            eps = group["eps"]
            cg = group["coherence_gain"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("DCLR does not support sparse gradients")
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                    state["coh"] = torch.zeros_like(p)
                m, v, coh = state["m"], state["v"], state["coh"]
                m.mul_(beta).add_(grad, alpha=1 - beta)
                v.mul_(gamma).addcmul_(grad, grad, value=1 - gamma)
                # Deviation of the raw gradient from its (just updated) mean.
                deviation = grad - m
                coh.mul_(0.9).add_(deviation.abs(), alpha=0.1)
                # Element-wise learning rate, damped where gradients disagree.
                lr_eff = lr / (1.0 + cg * coh)
                update = lr_eff * m / (v.sqrt() + eps)
                p.sub_(update)
                total_J_proxy += (update * update).sum().item()
        return loss, total_J_proxy
| # ---------------- Model ---------------------- | |
def get_model():
    """Return a torchvision ResNet-18 with a 10-class output layer.

    No pretrained weights are requested.
    NOTE(review): this is the ImageNet-style stem (large stride-2 first conv
    plus max-pool) applied to 32x32 CIFAR images; confirm that is intended.
    """
    n_classes = 10
    net = torchvision.models.resnet18(num_classes=n_classes)
    return net
| # ---------------- Data ----------------------- | |
def get_loaders(batch: int = 256, workers: int = 4, root: str = "./data"):
    """Build CIFAR-10 train/test DataLoaders, downloading the dataset if needed.

    Args:
        batch: Batch size for both loaders.
        workers: Number of worker processes per loader.
        root: Directory where torchvision stores / looks for CIFAR-10.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles and
        augments (random 32x32 crop with 4-px padding, horizontal flip).
    """
    # NOTE(review): these std values are the widely copied CIFAR-10 constants;
    # the per-pixel dataset std is closer to (0.247, 0.243, 0.261). They are
    # kept unchanged so results stay comparable with earlier runs.
    norm = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_tf = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor(), norm])
    test_tf = T.Compose([T.ToTensor(), norm])
    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_tf)
    # Pinned host memory only speeds up host-to-GPU copies; without CUDA,
    # PyTorch ignores it and emits a warning.
    pin = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=pin)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch, shuffle=False, num_workers=workers, pin_memory=pin)
    return train_loader, test_loader
| # ---------------- Train / Eval --------------- | |
def evaluate(model, loader, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` over every batch of ``loader``.

    The model is put in eval mode (and left there) and gradients are
    disabled. ``loader`` must yield ``(inputs, labels)`` pairs.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Sum per-sample losses directly rather than re-scaling each batch mean.
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    n_seen, n_correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            out = model(x)
            loss_sum += loss_fn(out, y).item()
            n_correct += (out.argmax(1) == y).sum().item()
            n_seen += x.size(0)
    if n_seen == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return loss_sum / n_seen, n_correct / n_seen
def train(mode="RFT", epochs=5, batch=256, lr=5e-4, coherence_gain=0.05,
          sync_gain=0.006, device_index=0, log_path="stage1_cifar10_log.jsonl",
          seed=1234):
    """Train ResNet-18 on CIFAR-10 with DCLR ("RFT") or Adam ("BASE") and log telemetry.

    One JSON record is emitted per training step (loss, batch accuracy,
    Orbital drift/flux, energy estimate, GPU temperature). One ``eval``
    record with test-set loss/accuracy is emitted per epoch.

    Args:
        mode: "RFT" for the DCLR optimiser or "BASE" for Adam (case-insensitive).
        epochs: Number of passes over the training set.
        batch: Batch size for the train and test loaders.
        lr: Learning rate of the selected optimiser.
        coherence_gain: DCLR coherence gain (unused in BASE mode).
        sync_gain: Coupling gain of the Orbital Ψ–Ω coupler.
        device_index: GPU to train on and to query for power/temperature.
        log_path: JSONL file receiving a copy of every record (falsy: stdout only).
        seed: Seed passed to ``set_seed``.

    Returns:
        The trained model.

    Raises:
        ValueError: If ``mode`` is neither "RFT" nor "BASE".
    """
    from contextlib import nullcontext

    mode = mode.upper()
    if mode not in ("RFT", "BASE"):
        raise ValueError(f"mode must be 'RFT' or 'BASE', got {mode!r}")
    set_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        # Make "cuda" mean device_index, so training, synchronisation and the
        # bf16 check all target the GPU that EnergyMeter reads.
        # NOTE: NVML enumerates GPUs in PCI-bus order while CUDA defaults to
        # fastest-first; set CUDA_DEVICE_ORDER=PCI_BUS_ID on multi-GPU hosts.
        torch.cuda.set_device(device_index)
    train_loader, test_loader = get_loaders(batch=batch)
    model = get_model().to(device)
    if mode == "RFT":
        optimiser = DCLR(model.parameters(), lr=lr, coherence_gain=coherence_gain)
    else:
        optimiser = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    orb = Orbital(sync_gain=sync_gain, sat_floor=0.2)
    em = EnergyMeter(device_index=device_index)
    autocast_enabled = device == "cuda" and torch.cuda.is_bf16_supported()
    tm = Telemetry(log_path)
    try:
        for ep in range(1, epochs + 1):
            model.train()
            for step, (x, y) in enumerate(train_loader, start=1):
                x, y = x.to(device), y.to(device)
                # The coupler advances once per step; its outputs are only logged.
                drift, flux = orb.step()
                optimiser.zero_grad(set_to_none=True)
                em.begin_step()
                amp = (torch.autocast(device_type="cuda", dtype=torch.bfloat16)
                       if autocast_enabled else nullcontext())
                with amp:
                    out = model(x)
                    loss = loss_fn(out, y)
                loss.backward()
                if mode == "RFT":
                    _, J_proxy = optimiser.step()
                else:
                    optimiser.step()
                    J_proxy = 0.0
                if device == "cuda":
                    # Kernels launch asynchronously. Wait for them so the
                    # interval timed by EnergyMeter covers the GPU work for
                    # both optimisers, not just the CPU-side launches.
                    torch.cuda.synchronize()
                J_step, tempC = em.end_step()
                if J_step is None:
                    # No NVML reading: fall back to a scaled update-size proxy
                    # (always 0.0 in BASE mode).
                    J_step = J_proxy * 1e-6
                with torch.no_grad():
                    acc = (out.argmax(1) == y).float().mean().item()
                # NOTE(review): fixed placeholder values, not measured from the run.
                E_ret, coh = 0.99, 0.999
                tm.emit(mode=mode, ep=ep, step=step,
                        drift=round(drift, 3), flux=round(flux, 3),
                        E_ret=E_ret, coh=coh,
                        loss=round(loss.item(), 4), acc=round(acc, 3),
                        J_step=round(J_step, 6),
                        tempC=(None if tempC is None else round(tempC, 2)))
            val_loss, val_acc = evaluate(model, test_loader, device)
            tm.emit(tag="eval", ep=ep, mode=mode,
                    val_loss=round(float(val_loss), 4),
                    val_acc=round(float(val_acc), 3))
    finally:
        # Close the log even if training is interrupted or raises.
        tm.close()
    return model
| # ---------------- CLI ------------------------ | |
def main():
    """Parse command-line flags and launch one training run."""
    parser = argparse.ArgumentParser(description="Stage 1 of 12 — CIFAR-10 RFT vs Adam")
    flag_specs = (
        ("--mode", dict(choices=["RFT", "BASE"], default="RFT")),
        ("--epochs", dict(type=int, default=5)),
        ("--batch", dict(type=int, default=256)),
        ("--lr", dict(type=float, default=5e-4)),
        ("--coherence_gain", dict(type=float, default=0.05)),
        ("--sync_gain", dict(type=float, default=0.006)),
        ("--device_index", dict(type=int, default=0)),
        ("--log_path", dict(type=str, default="stage1_cifar10_log.jsonl")),
    )
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    # Every flag's destination name matches a train() keyword argument.
    train(**vars(parser.parse_args()))
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()