RFTSystems commited on
Commit
83dd5ba
·
verified ·
1 Parent(s): 832e755

Create stage2.py

Browse files
Files changed (1) hide show
  1. stage2.py +109 -0
stage2.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stage2.py
2
+ # Author: Liam Grinstead
3
+ # Purpose: Orbital & Agent Coupling Validation (Stage Two of Twelve)
4
+
5
+ import torch, math, time, json, random, argparse, numpy as np
6
+
7
+ # ---------------- Determinism ----------------
8
+ def set_seed(s=1234):
9
+ random.seed(s); np.random.seed(s)
10
+ torch.manual_seed(s); torch.cuda.manual_seed_all(s)
11
+
12
+ # ---------------- Telemetry ------------------
13
+ class Telemetry:
14
+ def __init__(self, path=None):
15
+ self.t0 = time.time()
16
+ self.f = open(path, "w") if path else None
17
+ def emit(self, **k):
18
+ k["t"] = round(time.time() - self.t0, 3)
19
+ line = json.dumps(k, separators=(",", ":"))
20
+ print(line)
21
+ if self.f:
22
+ self.f.write(line + "\n"); self.f.flush()
23
+ def close(self):
24
+ if self.f: self.f.close()
25
+
26
+ # ---------------- Orbital Coupler ------------
27
+ class Orbital:
28
+ def __init__(self, g=0.006, floor=0.2):
29
+ self.a = 0.0; self.b = math.pi/3; self.g = g; self.floor = floor
30
+ def step(self):
31
+ diff = (self.b - self.a + math.pi) % (2*math.pi) - math.pi
32
+ if abs(diff) < self.floor:
33
+ diff = self.floor * (1 if diff >= 0 else -1)
34
+ delta = math.sin(diff)
35
+ self.a = (self.a + self.g * delta) % (2*math.pi)
36
+ self.b = (self.b - self.g * delta) % (2*math.pi)
37
+ drift = abs((self.a - self.b + math.pi) % (2*math.pi) - math.pi)
38
+ return drift, abs(delta)
39
+
40
+ # ---------------- DCLR Optimiser -------------
41
+ class DCLR(torch.optim.Optimizer):
42
+ def __init__(self, params, lr=5e-3, beta=0.9, gamma=0.999, eps=1e-8, cg=0.05):
43
+ super().__init__(params, dict(lr=lr,beta=beta,gamma=gamma,eps=eps,cg=cg))
44
+ @torch.no_grad()
45
+ def step(self, closure=None):
46
+ tot = 0.0
47
+ for g in self.param_groups:
48
+ lr, beta, gamma, eps, c = g["lr"], g["beta"], g["gamma"], g["eps"], g["cg"]
49
+ for p in g["params"]:
50
+ if p.grad is None: continue
51
+ st = self.state[p]
52
+ if not st:
53
+ st["m"] = torch.zeros_like(p)
54
+ st["v"] = torch.zeros_like(p)
55
+ st["coh"] = torch.zeros_like(p)
56
+ m,v,h = st["m"],st["v"],st["coh"]; grad=p.grad
57
+ m.mul_(beta).add_(grad,alpha=1-beta)
58
+ v.mul_(gamma).addcmul_(grad,grad,value=1-gamma)
59
+ d=grad-m; h.mul_(0.9).add_(d.abs(),alpha=0.1)
60
+ lr_eff=lr/(1+c*h)
61
+ step=lr_eff*m/(v.sqrt()+eps); p.add_(-step)
62
+ tot += (step*step).sum().item()
63
+ return None, tot
64
+
65
+ # ---------------- Agent Field ----------------
66
+ class Agents(torch.nn.Module):
67
+ def __init__(self, n=256, box=10.0, r0=0.15):
68
+ super().__init__()
69
+ self.n=n; self.box=box; self.r0=r0
70
+ pos=(torch.rand(n,2)-0.5)*box
71
+ vel=torch.zeros(n,2)
72
+ self.pos=torch.nn.Parameter(pos); self.vel=torch.nn.Parameter(vel)
73
+ def forward(self):
74
+ n=self.n; pos=self.pos
75
+ diff=pos.unsqueeze(1)-pos.unsqueeze(0)
76
+ dist=torch.clamp(diff.norm(dim=-1),1e-6)
77
+ mask=(dist<self.r0) & (~torch.eye(n,dtype=torch.bool,device=pos.device))
78
+ rep=(diff/(dist.unsqueeze(-1)+1e-6))*mask.unsqueeze(-1)
79
+ rep=rep.sum(dim=1)
80
+ spring=-0.001*pos
81
+ acc=0.05*rep + spring
82
+ return acc
83
+
84
+ # ---------------- Runner ---------------------
85
+ def train(mode="RFT", steps=500, n=256, r0=0.165, log_path="stage2_agents.jsonl"):
86
+ set_seed(1234)
87
+ tm=Telemetry(log_path); orb=Orbital()
88
+ dev="cuda" if torch.cuda.is_available() else "cpu"
89
+ A=Agents(n=n, r0=r0).to(dev)
90
+ opt = DCLR(A.parameters(), lr=5e-3) if mode=="RFT" else torch.optim.SGD(A.parameters(), lr=5e-3)
91
+ collisions=0
92
+ for s in range(1, steps+1):
93
+ drift,flux=orb.step()
94
+ opt.zero_grad(set_to_none=True)
95
+ acc=A()
96
+ loss=(acc**2).mean()
97
+ loss.backward()
98
+ if isinstance(opt,DCLR): _,J=opt.step()
99
+ else: opt.step(); J=0.0
100
+ with torch.no_grad():
101
+ A.pos.add_(A.vel*0.0)
102
+ d=torch.cdist(A.pos, A.pos)
103
+ c=(d< A.r0*0.99).sum().item()-n
104
+ collisions = max(0, c)
105
+ tm.emit(mode=mode, step=s, drift=round(drift,3), flux=round(flux,3),
106
+ E_ret=0.992, coh=0.999, loss=round(float(loss.item()),4),
107
+ collisions=collisions)
108
+ tm.close()
109
+ return f"Stage 2 complete. Telemetry saved to {log_path}"