AbstractPhil committed on
Commit
9d278c5
·
verified ·
1 Parent(s): 7e37f31

Create cell3_trainer.py

Browse files
Files changed (1) hide show
  1. cell3_trainer.py +287 -0
cell3_trainer.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Superposition Patch Classifier - Unfrozen Trainer
3
+ ===================================================
4
+ Colab Cell 3 of 3 - depends on Cell 1 (generator.py) and Cell 2 (model.py).
5
+
6
+ End-to-end training: all parameters, all losses, no freezing.
7
+ Two-tier gate architecture trains jointly — local and structural gates
8
+ co-evolve with shape classification.
9
+ """
10
+
11
+ import os
12
+ import time
13
+ import numpy as np
14
+ from dataclasses import dataclass, asdict
15
+ from typing import Dict
16
+ from tqdm import tqdm
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import DataLoader
21
+
22
+ # Cell 1 provides: generate_dataset, analyze_patches_torch, ShapeDataset, collate_fn,
23
+ # MAX_WORKERS, NUM_CLASSES, CLASS_NAMES, MACRO_N,
24
+ # LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM,
25
+ # NUM_LOCAL_DIMS, NUM_LOCAL_CURVS, NUM_LOCAL_BOUNDARY, NUM_LOCAL_AXES,
26
+ # NUM_STRUCT_TOPO, NUM_STRUCT_NEIGHBOR, NUM_STRUCT_ROLE, NUM_GATES
27
+
28
+ # Cell 2 provides: SuperpositionPatchClassifier, SuperpositionLoss
29
+
30
+
31
+ # === HuggingFace ==============================================================
32
+
33
+ HF_REPO = "AbstractPhil/grid-geometric-multishape"
34
+
35
def upload_checkpoint(model, epoch, metrics, config):
    """Serialize the model plus run metadata to /tmp and push it to the Hub.

    Best-effort: any failure (missing huggingface_hub, auth, network) is
    caught and reported so training is never interrupted by an upload error.
    """
    try:
        from huggingface_hub import HfApi

        ckpt_name = f"best_model_epoch{epoch}.pt"
        local_path = f"/tmp/{ckpt_name}"
        payload = {
            "model_state_dict": model.state_dict(),
            "epoch": epoch,
            "metrics": metrics,
            "config": asdict(config),  # Config is a dataclass
        }
        torch.save(payload, local_path)
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"checkpoint_v10/{ckpt_name}",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(f" ✓ Uploaded checkpoint epoch {epoch}")
    except Exception as e:
        print(f" ✗ Upload failed: {e}")
51
+
52
def upload_tensorboard(log_dir):
    """Push the TensorBoard event directory to the Hub repo (best-effort)."""
    try:
        from huggingface_hub import HfApi

        HfApi().upload_folder(
            folder_path=log_dir,
            path_in_repo="runs/",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(" ✓ Uploaded TensorBoard logs")
    except Exception as e:
        print(f" ✗ TB upload failed: {e}")
61
+
62
+
63
+ # === Metrics ==================================================================
64
+
65
def compute_metrics(outputs: Dict, targets: Dict) -> Dict[str, float]:
    """Compute validation accuracies for the gate heads and shape heads.

    Patch-level accuracies are averaged only over occupied patches
    (occupancy > 0.01); global-head accuracies are plain element means.
    """
    metrics: Dict[str, float] = {}
    occ_mask = targets["patch_occupancy"] > 0.01
    n_occ = occ_mask.sum().item()

    if n_occ > 0:
        def masked_acc(hit):
            # Fraction of occupied patches whose prediction matches the target.
            return (hit & occ_mask).sum().item() / n_occ

        # --- local gate heads ---
        dim_hit = (outputs["local_dim_logits"].argmax(dim=-1)
                   == targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1))
        metrics["local_dim_acc"] = masked_acc(dim_hit)

        curv_hit = (outputs["local_curv_logits"].argmax(dim=-1)
                    == targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1))
        metrics["local_curv_acc"] = masked_acc(curv_hit)

        bound_pred = (torch.sigmoid(outputs["local_bound_logits"].squeeze(-1)) > 0.5).float()
        metrics["local_bound_acc"] = masked_acc(bound_pred == targets["patch_boundary"])

        axis_pred = (torch.sigmoid(outputs["local_axis_logits"]) > 0.5).float()
        # A patch only counts as correct when every axis flag matches.
        metrics["local_axis_acc"] = masked_acc(
            (axis_pred == targets["patch_axis_active"]).all(dim=-1))

        # --- structural gate heads ---
        topo_hit = (outputs["struct_topo_logits"].argmax(dim=-1)
                    == targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1))
        metrics["struct_topo_acc"] = masked_acc(topo_hit)

        role_hit = (outputs["struct_role_logits"].argmax(dim=-1)
                    == targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1))
        metrics["struct_role_acc"] = masked_acc(role_hit)

        # --- per-patch shape membership (multi-label) ---
        if "patch_shape_logits" in outputs and "patch_shape_membership" in targets:
            shape_pred = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).float()
            per_patch_match = (shape_pred == targets["patch_shape_membership"]).float().mean(dim=-1)
            metrics["patch_shape_acc"] = (per_patch_match * occ_mask.float()).sum().item() / n_occ
    else:
        # No occupied patches in the batch: report zeros so downstream
        # aggregation (np.mean over batches) stays well-defined.
        zero_keys = ["local_dim_acc", "local_curv_acc", "local_bound_acc", "local_axis_acc",
                     "struct_topo_acc", "struct_role_acc", "patch_shape_acc"]
        metrics.update(dict.fromkeys(zero_keys, 0.0))

    # --- global heads ---
    if "global_shapes" in outputs and "global_shapes" in targets:
        shape_pred = (torch.sigmoid(outputs["global_shapes"]) > 0.5).float()
        shape_true = targets["global_shapes"]
        metrics["global_shape_acc"] = (shape_pred == shape_true).float().mean().item()
        true_pos = (shape_pred * shape_true).sum()
        total_true = shape_true.sum().clamp(min=1)  # avoid div-by-zero when no positives
        metrics["global_shape_recall"] = (true_pos / total_true).item()

        gate_pred = (torch.sigmoid(outputs["global_gates"]) > 0.5).float()
        gate_true = (targets["global_gates"] > 0.5).float()
        metrics["global_gate_acc"] = (gate_pred == gate_true).float().mean().item()

    return metrics
122
+
123
+
124
+ # === Config ===================================================================
125
+
126
@dataclass
class Config:
    """Hyperparameters for the unfrozen end-to-end training run."""
    # Data
    n_samples: int = 500000   # training-set size (generated once up front)
    n_val: int = 50000        # validation-set size
    seed: int = 420           # RNG seed for the training generator (val uses seed=0)

    # Model (passed positionally to SuperpositionPatchClassifier)
    embed_dim: int = 256
    patch_dim: int = 64
    n_bootstrap: int = 2
    n_geometric: int = 2
    n_heads: int = 4
    dropout: float = 0.1

    # Training
    epochs: int = 200
    batch_size: int = 512     # val loader uses batch_size * 2
    lr: float = 3e-4          # peak LR; cosine-decayed after warmup
    weight_decay: float = 0.01
    warmup_steps: int = 500   # linear LR warmup duration in optimizer steps
    upload_every: int = 20    # checkpoint upload period, in epochs
148
+
149
+
150
+ # === Data Loading =============================================================
151
+
152
def make_loader(n_samples, seed, device, batch_size, shuffle=True):
    """Generate a dataset, run patch analysis on `device`, return a DataLoader.

    The heavy per-patch analysis runs on the (typically GPU) device; all
    tensors are then moved back to CPU so the DataLoader can pin memory.
    """
    raw = generate_dataset(n_samples, seed=seed, num_workers=MAX_WORKERS)
    grids = torch.from_numpy(raw["grids"]).float().to(device)
    memberships = torch.from_numpy(raw["memberships"]).float().to(device)
    with torch.no_grad():
        patches = analyze_patches_torch(grids)
    grids = grids.cpu()
    memberships = memberships.cpu()
    patches = {key: tensor.cpu() for key, tensor in patches.items()}
    return DataLoader(
        ShapeDataset(grids, memberships, patches),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True,
    )
163
+
164
+
165
+ # === Training =================================================================
166
+
167
def train():
    """Run end-to-end (unfrozen) training with all losses active.

    Depends on Cell 1 (data generation) and Cell 2 (model + loss) having been
    executed in the same session. Logs per-epoch metrics to TensorBoard and
    periodically uploads checkpoints to the HF Hub.

    Fixes over the original:
    - the final checkpoint was uploaded twice (the in-loop schedule always
      fires on the last epoch when epochs % upload_every == 0, and the
      post-loop call repeated it); the post-loop upload now only runs if the
      final epoch was not already uploaded;
    - `m`, `local_min`, `struct_min` are initialized before the loop so the
      summary cannot NameError when epochs < 1;
    - the two redundant upload branches are collapsed into one condition.
    """
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Config: {config}")

    from torch.utils.tensorboard import SummaryWriter
    log_dir = "/tmp/tb_logs"
    writer = SummaryWriter(log_dir)

    # Generate data once
    print(f"\nGenerating training set ({config.n_samples} samples)...")
    train_loader = make_loader(config.n_samples, seed=config.seed, device=device,
                               batch_size=config.batch_size, shuffle=True)
    print(f"✓ Train set ready")

    print(f"Generating val set ({config.n_val} samples)...")
    val_loader = make_loader(config.n_val, seed=0, device=device,
                             batch_size=config.batch_size * 2, shuffle=False)
    print(f"✓ Val set ready")

    # Model — every parameter trainable, no freezing.
    model = SuperpositionPatchClassifier(
        config.embed_dim, config.patch_dim, config.n_bootstrap, config.n_geometric,
        config.n_heads, config.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params:,}")

    # All losses active simultaneously.
    loss_fn = SuperpositionLoss(local_weight=1.0, struct_weight=1.0, shape_weight=1.0, global_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    # Linear warmup, then cosine decay to zero over the remaining steps.
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * config.epochs

    def lr_lambda(step):
        if step < config.warmup_steps:
            return step / config.warmup_steps
        return 0.5 * (1 + np.cos(np.pi * (step - config.warmup_steps) / (total_steps - config.warmup_steps)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_recall = 0.0
    global_step = 0
    # Initialized up front so the final summary is well-defined even if the
    # epoch loop never runs (epochs < 1).
    m: Dict[str, float] = {}
    local_min = struct_min = 0.0
    last_uploaded_epoch = -1

    print(f"\nTraining for {config.epochs} epochs (unfrozen, all losses)...\n")
    for epoch in range(1, config.epochs + 1):
        # --- train one epoch ---
        model.train()
        epoch_loss, n_batches = 0.0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{config.epochs}")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["grid"])
            losses = loss_fn(outputs, batch)
            optimizer.zero_grad()
            losses["total"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            global_step += 1
            epoch_loss += losses["total"].item()
            n_batches += 1
            pbar.set_postfix(loss=f"{losses['total'].item():.3f}", lr=f"{scheduler.get_last_lr()[0]:.2e}")

        avg_train_loss = epoch_loss / max(n_batches, 1)  # guard empty loader

        # --- validate ---
        model.eval()
        val_metrics_list = []
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch["grid"])
                val_metrics_list.append(compute_metrics(outputs, batch))

        m = {k: np.mean([v[k] for v in val_metrics_list]) for k in val_metrics_list[0]}

        recall = m.get("global_shape_recall", 0)
        local_min = min(m.get("local_dim_acc", 0), m.get("local_curv_acc", 0),
                        m.get("local_bound_acc", 0), m.get("local_axis_acc", 0))
        struct_min = min(m.get("struct_topo_acc", 0), m.get("struct_role_acc", 0))

        print(f"Epoch {epoch} | Loss: {avg_train_loss:.4f} | Recall: {recall:.4f} | "
              f"Local≥{local_min:.4f} | Struct≥{struct_min:.4f}")

        # --- TensorBoard ---
        writer.add_scalar("loss/train", avg_train_loss, epoch)
        writer.add_scalar("recall", recall, epoch)
        writer.add_scalar("local/dim", m.get("local_dim_acc", 0), epoch)
        writer.add_scalar("local/curv", m.get("local_curv_acc", 0), epoch)
        writer.add_scalar("local/bound", m.get("local_bound_acc", 0), epoch)
        writer.add_scalar("local/axis", m.get("local_axis_acc", 0), epoch)
        writer.add_scalar("struct/topo", m.get("struct_topo_acc", 0), epoch)
        writer.add_scalar("struct/role", m.get("struct_role_acc", 0), epoch)
        writer.add_scalar("shape/patch_acc", m.get("patch_shape_acc", 0), epoch)
        writer.add_scalar("shape/global_acc", m.get("global_shape_acc", 0), epoch)
        writer.add_scalar("lr", scheduler.get_last_lr()[0], epoch)

        # --- checkpointing ---
        best_recall = max(best_recall, recall)
        if epoch % config.upload_every == 0 or epoch == config.epochs:
            upload_checkpoint(model, epoch, m, config)
            last_uploaded_epoch = epoch

    # Final: only re-upload if the last epoch wasn't already pushed above.
    writer.close()
    if last_uploaded_epoch != config.epochs:
        upload_checkpoint(model, config.epochs, m, config)
    upload_tensorboard(log_dir)
    print(f"\n{'='*70}")
    print(f"TRAINING COMPLETE")
    print(f" Local gates: ≥{local_min:.4f}")
    print(f" Struct gates: ≥{struct_min:.4f}")
    print(f" Best Recall: {best_recall:.4f}")
    print(f"{'='*70}")
282
+
283
+
284
# === Run ======================================================================

# Guarded so importing this module does not kick off a multi-hour training
# run; a Colab cell still executes as __main__, so behavior there is unchanged.
if __name__ == "__main__":
    train()
    print("✓ Training complete")