# grid-geometric-multishape / cell3_trainer_v10.py
# Renamed from cell3_trainer.py (HF Hub commit 655ab8b, AbstractPhil).
"""
Superposition Patch Classifier - Unfrozen Trainer
===================================================
Colab Cell 3 of 3 - depends on Cell 1 (generator.py) and Cell 2 (model.py).
End-to-end training: all parameters, all losses, no freezing.
Two-tier gate architecture trains jointly — local and structural gates
co-evolve with shape classification.
"""
import os
import time
import numpy as np
from dataclasses import dataclass, asdict
from typing import Dict
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Cell 1 provides: generate_dataset, analyze_patches_torch, ShapeDataset, collate_fn,
# MAX_WORKERS, NUM_CLASSES, CLASS_NAMES, MACRO_N,
# LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM,
# NUM_LOCAL_DIMS, NUM_LOCAL_CURVS, NUM_LOCAL_BOUNDARY, NUM_LOCAL_AXES,
# NUM_STRUCT_TOPO, NUM_STRUCT_NEIGHBOR, NUM_STRUCT_ROLE, NUM_GATES
# Cell 2 provides: SuperpositionPatchClassifier, SuperpositionLoss
# === HuggingFace ==============================================================
# Target Hub repo for checkpoint and TensorBoard uploads (model repo type).
HF_REPO = "AbstractPhil/grid-geometric-multishape"
def upload_checkpoint(model, epoch, metrics, config):
    """Serialize the model plus run metadata and push it to the HF Hub.

    Best-effort: any failure (missing huggingface_hub, no auth, network)
    is caught and reported without interrupting training.
    """
    try:
        from huggingface_hub import HfApi

        ckpt_path = f"/tmp/best_model_epoch{epoch}.pt"
        payload = {
            "model_state_dict": model.state_dict(),
            "epoch": epoch,
            "metrics": metrics,
            "config": asdict(config),
        }
        torch.save(payload, ckpt_path)
        HfApi().upload_file(
            path_or_fileobj=ckpt_path,
            path_in_repo=f"checkpoint_v10/best_model_epoch{epoch}.pt",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(f" ✓ Uploaded checkpoint epoch {epoch}")
    except Exception as e:
        # Deliberate broad catch: uploads are optional side effects.
        print(f" ✗ Upload failed: {e}")
def upload_tensorboard(log_dir):
    """Push the TensorBoard event files under ``log_dir`` to the HF Hub.

    Best-effort: failures are printed, never raised.
    """
    try:
        from huggingface_hub import HfApi

        HfApi().upload_folder(
            folder_path=log_dir,
            path_in_repo="runs/",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(" ✓ Uploaded TensorBoard logs")
    except Exception as e:
        # Deliberate broad catch: uploads are optional side effects.
        print(f" ✗ TB upload failed: {e}")
# === Metrics ==================================================================
def compute_metrics(outputs: Dict, targets: Dict) -> Dict[str, float]:
    """Compute per-gate and global accuracy metrics for one batch.

    Patch-level accuracies are averaged only over occupied patches
    (occupancy > 0.01); global-head accuracies average over every element.
    Relies on Cell-1 globals: NUM_LOCAL_DIMS, NUM_LOCAL_CURVS,
    NUM_STRUCT_TOPO, NUM_STRUCT_ROLE.
    """
    m: Dict[str, float] = {}
    occupied = targets["patch_occupancy"] > 0.01
    occ_total = occupied.sum().item()

    if occ_total > 0:
        def frac(hit_mask):
            # Fraction of occupied patches where the prediction matched.
            return (hit_mask & occupied).sum().item() / occ_total

        # --- local gate heads ------------------------------------------
        m["local_dim_acc"] = frac(
            outputs["local_dim_logits"].argmax(dim=-1)
            == targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1))
        m["local_curv_acc"] = frac(
            outputs["local_curv_logits"].argmax(dim=-1)
            == targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1))
        bound_pred = (torch.sigmoid(outputs["local_bound_logits"].squeeze(-1)) > 0.5).float()
        m["local_bound_acc"] = frac(bound_pred == targets["patch_boundary"])
        axis_pred = (torch.sigmoid(outputs["local_axis_logits"]) > 0.5).float()
        # A patch counts as correct only when every axis flag matches.
        m["local_axis_acc"] = frac((axis_pred == targets["patch_axis_active"]).all(dim=-1))

        # --- structural gate heads -------------------------------------
        m["struct_topo_acc"] = frac(
            outputs["struct_topo_logits"].argmax(dim=-1)
            == targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1))
        m["struct_role_acc"] = frac(
            outputs["struct_role_logits"].argmax(dim=-1)
            == targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1))

        # --- per-patch shape membership (multi-label) ------------------
        if "patch_shape_logits" in outputs and "patch_shape_membership" in targets:
            shape_pred = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).float()
            per_patch = (shape_pred == targets["patch_shape_membership"]).float().mean(dim=-1)
            m["patch_shape_acc"] = (per_patch * occupied.float()).sum().item() / occ_total
    else:
        # No occupied patches: emit zeros so downstream key aggregation is stable.
        for key in ("local_dim_acc", "local_curv_acc", "local_bound_acc",
                    "local_axis_acc", "struct_topo_acc", "struct_role_acc",
                    "patch_shape_acc"):
            m[key] = 0.0

    # --- global heads ---------------------------------------------------
    # NOTE(review): source indentation was lost; the global-gate lines are
    # grouped under the same key guard as global_shapes here — confirm against
    # the original notebook if outputs can carry gates without shapes.
    if "global_shapes" in outputs and "global_shapes" in targets:
        gshape_pred = (torch.sigmoid(outputs["global_shapes"]) > 0.5).float()
        gshape_true = targets["global_shapes"]
        m["global_shape_acc"] = (gshape_pred == gshape_true).float().mean().item()
        true_pos = (gshape_pred * gshape_true).sum()
        # clamp(min=1) avoids divide-by-zero when no shapes are present.
        m["global_shape_recall"] = (true_pos / gshape_true.sum().clamp(min=1)).item()
        ggate_pred = (torch.sigmoid(outputs["global_gates"]) > 0.5).float()
        ggate_true = (targets["global_gates"] > 0.5).float()
        m["global_gate_acc"] = (ggate_pred == ggate_true).float().mean().item()
    return m
# === Config ===================================================================
@dataclass
class Config:
    """Hyperparameters for data generation, model size, and optimization."""
    # Data
    n_samples: int = 500000       # training-set size
    n_val: int = 50000            # validation-set size
    seed: int = 420               # RNG seed for the training set (val uses seed=0)
    # Model
    embed_dim: int = 256          # transformer embedding width
    patch_dim: int = 64           # per-patch feature width
    n_bootstrap: int = 2          # bootstrap block count (see Cell 2 model)
    n_geometric: int = 2          # geometric block count (see Cell 2 model)
    n_heads: int = 4              # attention heads
    dropout: float = 0.1
    # Training
    epochs: int = 200
    batch_size: int = 512         # val loader doubles this
    lr: float = 3e-4              # peak LR; warmup then cosine decay to 0
    weight_decay: float = 0.01
    warmup_steps: int = 500       # linear LR warmup steps
    upload_every: int = 20        # checkpoint-upload cadence (epochs)
# === Data Loading =============================================================
def make_loader(n_samples, seed, device, batch_size, shuffle=True):
    """Generate a dataset, run patch analysis on ``device``, return a DataLoader.

    Grids/memberships are moved to ``device`` only for the (batched) patch
    analysis, then brought back to CPU so the pinned-memory DataLoader owns
    plain CPU tensors.
    """
    raw = generate_dataset(n_samples, seed=seed, num_workers=MAX_WORKERS)
    grid_t = torch.from_numpy(raw["grids"]).float().to(device)
    member_t = torch.from_numpy(raw["memberships"]).float().to(device)
    with torch.no_grad():
        patches = analyze_patches_torch(grid_t)
    grid_t, member_t = grid_t.cpu(), member_t.cpu()
    patches = {key: val.cpu() for key, val in patches.items()}
    dataset = ShapeDataset(grid_t, member_t, patches)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True,
    )
# === Training =================================================================
def train():
    """End-to-end training: generate data, train all heads jointly, log, upload.

    All parameters are trainable and all loss terms are active (no freezing).
    Checkpoints go to the HF Hub on the ``upload_every`` cadence, plus a new
    best at the final epoch, plus once unconditionally after training.
    """
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Config: {config}")

    from torch.utils.tensorboard import SummaryWriter
    log_dir = "/tmp/tb_logs"
    writer = SummaryWriter(log_dir)

    # Generate data once; patch analysis runs on `device` inside make_loader.
    print(f"\nGenerating training set ({config.n_samples} samples)...")
    train_loader = make_loader(config.n_samples, seed=config.seed, device=device,
                               batch_size=config.batch_size, shuffle=True)
    print(f"✓ Train set ready")
    print(f"Generating val set ({config.n_val} samples)...")
    val_loader = make_loader(config.n_val, seed=0, device=device,
                             batch_size=config.batch_size * 2, shuffle=False)
    print(f"✓ Val set ready")

    # Model — all parameters trainable.
    model = SuperpositionPatchClassifier(
        config.embed_dim, config.patch_dim, config.n_bootstrap, config.n_geometric,
        config.n_heads, config.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params:,}")

    # All losses active.
    loss_fn = SuperpositionLoss(local_weight=1.0, struct_weight=1.0,
                                shape_weight=1.0, global_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr,
                                  weight_decay=config.weight_decay)

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * config.epochs

    def lr_lambda(step):
        # Linear warmup, then cosine decay to zero over the remaining steps.
        if step < config.warmup_steps:
            return step / max(1, config.warmup_steps)  # max(1,·): warmup_steps == 0 safe
        decay_span = max(1, total_steps - config.warmup_steps)  # avoid div-by-zero
        return 0.5 * (1 + np.cos(np.pi * (step - config.warmup_steps) / decay_span))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_recall = 0.0
    global_step = 0
    print(f"\nTraining for {config.epochs} epochs (unfrozen, all losses)...\n")
    for epoch in range(1, config.epochs + 1):
        # --- train one epoch -------------------------------------------
        model.train()
        epoch_loss, n_batches = 0.0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{config.epochs}")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["grid"])
            losses = loss_fn(outputs, batch)
            optimizer.zero_grad()
            losses["total"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # per-step schedule (not per-epoch)
            global_step += 1
            epoch_loss += losses["total"].item()
            n_batches += 1
            pbar.set_postfix(loss=f"{losses['total'].item():.3f}",
                             lr=f"{scheduler.get_last_lr()[0]:.2e}")
        avg_train_loss = epoch_loss / n_batches

        # --- validate ---------------------------------------------------
        model.eval()
        val_metrics_list = []
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch["grid"])
                val_metrics_list.append(compute_metrics(outputs, batch))
        # Mean of each metric across val batches (keys taken from first batch).
        m = {k: np.mean([v[k] for v in val_metrics_list]) for k in val_metrics_list[0]}
        recall = m.get("global_shape_recall", 0)
        local_min = min(m.get("local_dim_acc", 0), m.get("local_curv_acc", 0),
                        m.get("local_bound_acc", 0), m.get("local_axis_acc", 0))
        struct_min = min(m.get("struct_topo_acc", 0), m.get("struct_role_acc", 0))
        print(f"Epoch {epoch} | Loss: {avg_train_loss:.4f} | Recall: {recall:.4f} | "
              f"Local≥{local_min:.4f} | Struct≥{struct_min:.4f}")

        # --- TensorBoard -----------------------------------------------
        writer.add_scalar("loss/train", avg_train_loss, epoch)
        writer.add_scalar("recall", recall, epoch)
        writer.add_scalar("local/dim", m.get("local_dim_acc", 0), epoch)
        writer.add_scalar("local/curv", m.get("local_curv_acc", 0), epoch)
        writer.add_scalar("local/bound", m.get("local_bound_acc", 0), epoch)
        writer.add_scalar("local/axis", m.get("local_axis_acc", 0), epoch)
        writer.add_scalar("struct/topo", m.get("struct_topo_acc", 0), epoch)
        writer.add_scalar("struct/role", m.get("struct_role_acc", 0), epoch)
        writer.add_scalar("shape/patch_acc", m.get("patch_shape_acc", 0), epoch)
        writer.add_scalar("shape/global_acc", m.get("global_shape_acc", 0), epoch)
        writer.add_scalar("lr", scheduler.get_last_lr()[0], epoch)

        # --- checkpointing ---------------------------------------------
        is_new_best = recall > best_recall
        if is_new_best:
            best_recall = recall
        # Upload on the regular cadence, and also when a new best lands on the
        # final epoch. (The original if/elif duplicated the cadence branch;
        # collapsed here into one equivalent condition.)
        if epoch % config.upload_every == 0 or (is_new_best and epoch == config.epochs):
            upload_checkpoint(model, epoch, m, config)

    # --- final ----------------------------------------------------------
    writer.close()
    upload_checkpoint(model, config.epochs, m, config)
    upload_tensorboard(log_dir)
    print(f"\n{'='*70}")
    print(f"TRAINING COMPLETE")
    print(f" Local gates: ≥{local_min:.4f}")
    print(f" Struct gates: ≥{struct_min:.4f}")
    print(f" Best Recall: {best_recall:.4f}")
    print(f"{'='*70}")
# === Run ======================================================================
# Entry point. The guard prevents training from re-running on import; in a
# Colab/Jupyter cell the namespace is __main__, so it still executes there.
if __name__ == "__main__":
    train()
    print("✓ Training complete")