Spaces:
Sleeping
Sleeping
Create stage1.py
Browse files
stage1.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Stage One of Twelve — CIFAR-10 Baseline Validation
|
| 4 |
+
# Rendered Frame Theory (RFT): DCLR Governor + Ψ–Ω (Orbital) Coupler
|
| 5 |
+
# Modes: RFT (DCLR) or BASE (Adam)
|
| 6 |
+
|
| 7 |
+
import os, math, time, json, argparse, random
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torchvision
|
| 11 |
+
import torchvision.transforms as T
|
| 12 |
+
|
| 13 |
+
# ---------------- Determinism ----------------
def set_seed(seed: int = 1234):
    """Seed all RNG sources and enable deterministic kernels.

    BUG FIX: the original passed/assigned ``False`` to
    ``use_deterministic_algorithms`` and ``cudnn.deterministic``, which
    *disabled* determinism despite this section's stated purpose.
    ``warn_only=True`` avoids hard failures on ops that have no
    deterministic implementation.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
| 21 |
+
|
| 22 |
+
# ---------------- Telemetry ------------------
|
| 23 |
+
class Telemetry:
|
| 24 |
+
def __init__(self, log_path: str | None = None):
|
| 25 |
+
self.t0 = time.time()
|
| 26 |
+
self.fh = open(log_path, "w") if log_path else None
|
| 27 |
+
|
| 28 |
+
def emit(self, **k):
|
| 29 |
+
k["t"] = round(time.time() - self.t0, 3)
|
| 30 |
+
line = json.dumps(k, separators=(",", ":"))
|
| 31 |
+
print(line)
|
| 32 |
+
if self.fh:
|
| 33 |
+
self.fh.write(line + "\n")
|
| 34 |
+
self.fh.flush()
|
| 35 |
+
|
| 36 |
+
def close(self):
|
| 37 |
+
if self.fh:
|
| 38 |
+
self.fh.close()
|
| 39 |
+
|
| 40 |
+
# ---------------- Optional NVML --------------
# Best-effort GPU power/temperature telemetry. If pynvml is not installed
# or the driver refuses to initialise, _NVML_OK stays False and EnergyMeter
# falls back to reporting (None, None) per step.
try:
    import pynvml
    pynvml.nvmlInit()
    _NVML_OK = True
except Exception:
    # Broad catch is deliberate: any NVML failure simply disables hardware
    # metering rather than aborting the run.
    _NVML_OK = False
|
| 47 |
+
|
| 48 |
+
class EnergyMeter:
    """Per-step GPU energy/temperature sampler backed by NVML (best effort).

    Call ``begin_step()`` before the work and ``end_step()`` after; the
    latter returns ``(joules, temp_C)`` estimated as instantaneous power
    times the elapsed window, or ``(None, None)`` whenever NVML is
    unavailable or errors out.
    """

    def __init__(self, device_index: int = 0):
        self.dev_index = device_index
        self.last_t = None  # wall-clock time of the most recent begin_step()

    def begin_step(self):
        """Mark the start of the timing window for the current step."""
        self.last_t = time.time()

    def end_step(self):
        """Return (energy_joules, temperature_C) for the window, or (None, None)."""
        # Exit before doing any timing work when NVML never initialised
        # (the original computed dt first — dead work on the common CPU path).
        if not _NVML_OK:
            return None, None
        now = time.time()
        dt = now - (self.last_t or now)  # 0 if begin_step() was never called
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(self.dev_index)
            watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # mW -> W
            # Renamed from `T`, which shadowed the module-level
            # torchvision.transforms alias.
            temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            return watts * dt, float(temp)
        except Exception:
            return None, None
|
| 68 |
+
|
| 69 |
+
# ---------------- Orbital Coupler (Ψ–Ω) ------
class Orbital:
    """Two-phase (Ψ–Ω) coupler that nudges phases ``a`` and ``b`` toward lock.

    ``step()`` returns ``(drift, flux)``: the absolute wrapped phase gap
    after the update, and ``|sin|`` of the (saturation-clamped) gap that
    drove it.
    """

    def __init__(self, sync_gain: float = 0.006, sat_floor: float = 0.2):
        self.a = 0.0
        self.b = math.pi / 3
        self.g = sync_gain
        self.floor = sat_floor

    def step(self):
        two_pi = 2 * math.pi
        # Signed phase gap wrapped into [-pi, pi).
        gap = (self.b - self.a + math.pi) % two_pi - math.pi
        # Saturate small gaps away from zero so the coupling never stalls.
        if abs(gap) < self.floor:
            gap = self.floor if gap >= 0 else -self.floor
        delta = math.sin(gap)
        nudge = self.g * delta
        # Move the phases symmetrically toward each other, mod 2*pi.
        self.a = (self.a + nudge) % two_pi
        self.b = (self.b - nudge) % two_pi
        drift = abs((self.a - self.b + math.pi) % two_pi - math.pi)
        return drift, abs(delta)
|
| 87 |
+
|
| 88 |
+
# ---------------- DCLR Optimiser -------------
class DCLR(torch.optim.Optimizer):
    """Damped Coherence-Limited Rate optimiser (Adam-like with coherence damping).

    Keeps Adam-style first/second moment EMAs plus a "coherence" EMA of the
    gradient's deviation from its running mean; larger incoherence shrinks
    the effective learning rate element-wise.
    """

    def __init__(self, params, lr=5e-4, beta=0.9, gamma=0.999, eps=1e-8, coherence_gain=0.05):
        defaults = dict(lr=lr, beta=beta, gamma=gamma, eps=eps, coherence_gain=coherence_gain)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update.

        Returns ``(loss, J_proxy)``: ``loss`` is the closure's result
        (``None`` when no closure is given, matching the original contract);
        ``J_proxy`` is the summed squared update magnitude, used as an
        energy proxy by the training loop.
        """
        # BUG FIX: the original accepted `closure` but ignored it; the
        # torch.optim contract is to re-evaluate the model (with grad
        # enabled) when a closure is supplied.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        total_J_proxy = 0.0
        for group in self.param_groups:
            lr = group["lr"]
            beta = group["beta"]
            gamma = group["gamma"]
            eps = group["eps"]
            cg = group["coherence_gain"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                    state["coh"] = torch.zeros_like(p)
                m, v, coh = state["m"], state["v"], state["coh"]
                grad = p.grad
                m.mul_(beta).add_(grad, alpha=1 - beta)              # 1st-moment EMA
                v.mul_(gamma).addcmul_(grad, grad, value=1 - gamma)  # 2nd-moment EMA
                # Coherence: EMA of |grad - m| (m already updated above);
                # large values indicate noisy gradients.
                delta = grad - m
                coh.mul_(0.9).add_(delta.abs(), alpha=0.1)
                lr_eff = lr / (1.0 + cg * coh)  # element-wise damped LR
                update = lr_eff * m / (v.sqrt() + eps)
                p.sub_(update)
                total_J_proxy += (update * update).sum().item()
        return loss, total_J_proxy
|
| 118 |
+
|
| 119 |
+
# ---------------- Model ----------------------
def get_model():
    """Return an ImageNet-style ResNet-18 with a 10-way head for CIFAR-10.

    NOTE(review): this keeps the stock 7x7/stride-2 stem and max-pool,
    which is aggressive for 32x32 CIFAR inputs — presumably intentional as
    a plain baseline; confirm before comparing against CIFAR-tuned results.
    """
    return torchvision.models.resnet18(num_classes=10)
|
| 122 |
+
|
| 123 |
+
# ---------------- Data -----------------------
def get_loaders(batch: int = 256, workers: int = 4):
    """Return (train_loader, test_loader) for CIFAR-10 stored under ./data.

    Training data gets random crop (padding 4) + horizontal flip; both
    splits are normalised with the standard CIFAR-10 channel statistics.
    """
    norm = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    augment = [T.RandomCrop(32, padding=4), T.RandomHorizontalFlip()]
    train_tf = T.Compose(augment + [T.ToTensor(), norm])
    test_tf = T.Compose([T.ToTensor(), norm])
    train_set = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=test_tf)
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_set, batch_size=batch, shuffle=True,
                               num_workers=workers, pin_memory=True)
    test_loader = make_loader(test_set, batch_size=batch, shuffle=False,
                              num_workers=workers, pin_memory=True)
    return train_loader, test_loader
|
| 133 |
+
|
| 134 |
+
# ---------------- Train / Eval ---------------
def evaluate(model, loader, device):
    """Compute (mean_loss, accuracy) of `model` over `loader` on `device`.

    Robustness fix: an empty loader now returns (0.0, 0.0) instead of
    raising ZeroDivisionError.
    """
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            # Weight by batch size so the final division yields a true mean.
            total_loss += loss_fn(out, y).item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total
|
| 147 |
+
|
| 148 |
+
def train(mode="RFT", epochs=5, batch=256, lr=5e-4, coherence_gain=0.05,
          sync_gain=0.006, device_index=0, log_path="stage1_cifar10_log.jsonl"):
    """Train ResNet-18 on CIFAR-10 with either DCLR ("RFT") or Adam ("BASE").

    Emits one JSONL telemetry record per step plus one per-epoch eval
    record via Telemetry, then returns the trained model.
    """
    set_seed(1234)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, test_loader = get_loaders(batch=batch)
    model = get_model().to(device)
    # Mode selects only the optimiser; everything else is shared between arms.
    optimiser = DCLR(model.parameters(), lr=lr, coherence_gain=coherence_gain) if mode.upper()=="RFT" else torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    orb = Orbital(sync_gain=sync_gain, sat_floor=0.2)
    tm = Telemetry(log_path)
    em = EnergyMeter(device_index=device_index)
    # bf16 autocast only where hardware supports it; bf16 needs no GradScaler.
    autocast_enabled = (device=="cuda" and torch.cuda.is_bf16_supported())

    for ep in range(1, epochs+1):
        model.train()
        for step, (x, y) in enumerate(train_loader, start=1):
            x, y = x.to(device), y.to(device)
            # Advance the Psi-Omega coupler once per step; drift/flux are
            # logged only — they do not feed back into the optimisation.
            drift, flux = orb.step()
            optimiser.zero_grad(set_to_none=True)
            em.begin_step()
            if autocast_enabled:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    out = model(x); loss = loss_fn(out, y)
            else:
                out = model(x); loss = loss_fn(out, y)
            loss.backward()
            # DCLR.step() returns (None, J_proxy); torch.optim.Adam returns None.
            if isinstance(optimiser, DCLR): _, J_proxy = optimiser.step()
            else: optimiser.step(); J_proxy = 0.0
            J_step, tempC = em.end_step()
            # When NVML metering is unavailable, fall back to a scaled
            # update-magnitude proxy for step energy.
            if J_step is None: J_step = J_proxy * 1e-6
            with torch.no_grad():
                acc = (out.argmax(1) == y).float().mean().item()
            # NOTE(review): E_ret and coh are hard-coded placeholder
            # constants, not measured quantities — confirm before reporting
            # them as experimental results.
            E_ret, coh = 0.99, 0.999
            tm.emit(mode=mode.upper(), ep=ep, step=step,
                    drift=round(drift,3), flux=round(flux,3),
                    E_ret=E_ret, coh=coh,
                    loss=round(loss.item(),4), acc=round(acc,3),
                    J_step=round(J_step,6),
                    tempC=(None if tempC is None else round(tempC,2)))
        # Per-epoch held-out evaluation, logged with a distinct "eval" tag.
        val_loss, val_acc = evaluate(model, test_loader, device)
        tm.emit(tag="eval", ep=ep, mode=mode.upper(),
                val_loss=round(float(val_loss), 4),
                val_acc=round(float(val_acc), 3))

    tm.close()
    return model
|
| 194 |
+
|
| 195 |
+
# ---------------- CLI ------------------------
def main():
    """Parse command-line flags and dispatch to train()."""
    parser = argparse.ArgumentParser(description="Stage 1 of 12 — CIFAR-10 RFT vs Adam")
    parser.add_argument("--mode", choices=["RFT", "BASE"], default="RFT")
    # Typed flags share one shape, so register them from a table.
    for flag, kind, default in (
        ("--epochs", int, 5),
        ("--batch", int, 256),
        ("--lr", float, 5e-4),
        ("--coherence_gain", float, 0.05),
        ("--sync_gain", float, 0.006),
        ("--device_index", int, 0),
        ("--log_path", str, "stage1_cifar10_log.jsonl"),
    ):
        parser.add_argument(flag, type=kind, default=default)
    opts = parser.parse_args()

    train(mode=opts.mode, epochs=opts.epochs, batch=opts.batch, lr=opts.lr,
          coherence_gain=opts.coherence_gain, sync_gain=opts.sync_gain,
          device_index=opts.device_index, log_path=opts.log_path)

if __name__ == "__main__":
    main()
|