| """ |
| Superposition Patch Classifier - Unfrozen Trainer |
| =================================================== |
| Colab Cell 3 of 3 - depends on Cell 1 (generator.py) and Cell 2 (model.py). |
| |
| End-to-end training: all parameters, all losses, no freezing. |
| Two-tier gate architecture trains jointly — local and structural gates |
| co-evolve with shape classification. |
| """ |
|
|
| import os |
| import time |
| import numpy as np |
| from dataclasses import dataclass, asdict |
| from typing import Dict |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| |
|
|
# Hugging Face Hub repo (repo_type="model") that receives both checkpoint
# files and TensorBoard logs via the upload helpers below.
HF_REPO = "AbstractPhil/grid-geometric-multishape"
|
|
def upload_checkpoint(model, epoch, metrics, config):
    """Serialize the model plus run metadata and push it to the HF Hub.

    Best-effort: any failure (missing huggingface_hub, no credentials,
    network error) is printed and swallowed so training is never interrupted.
    """
    try:
        from huggingface_hub import HfApi

        local_path = f"/tmp/best_model_epoch{epoch}.pt"
        payload = {
            "model_state_dict": model.state_dict(),
            "epoch": epoch,
            "metrics": metrics,
            "config": asdict(config),
        }
        torch.save(payload, local_path)
        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"checkpoint_v10/best_model_epoch{epoch}.pt",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(f" ✓ Uploaded checkpoint epoch {epoch}")
    except Exception as e:
        print(f" ✗ Upload failed: {e}")
|
|
def upload_tensorboard(log_dir):
    """Push the TensorBoard event directory to the HF Hub (best-effort).

    Mirrors `upload_checkpoint`: failures are reported but never raised.
    """
    try:
        from huggingface_hub import HfApi

        client = HfApi()
        client.upload_folder(
            folder_path=log_dir,
            path_in_repo="runs/",
            repo_id=HF_REPO,
            repo_type="model",
        )
        print(" ✓ Uploaded TensorBoard logs")
    except Exception as e:
        print(f" ✗ TB upload failed: {e}")
|
|
|
|
| |
|
|
def compute_metrics(outputs: Dict, targets: Dict) -> Dict[str, float]:
    """Compute per-head validation accuracies for one batch.

    Patch-level metrics (local_*, struct_*, patch_shape_acc) are averaged
    over patches whose occupancy exceeds 0.01; when no patch is occupied
    all seven are reported as 0.0 so every batch yields the same key set
    (the caller averages over the first batch's keys).  Global shape/gate
    metrics are batch-level and computed whenever their keys are present.

    Fix: the zero-fill branch previously hung off the inner
    `patch_shape_logits` check, so a missing shape head overwrote the six
    local/struct accuracies just computed, and an empty occupancy mask left
    the patch-level keys unset entirely.

    Args:
        outputs: model head logits (keys like "local_dim_logits", ...).
        targets: ground-truth tensors (keys like "patch_dims", ...).

    Returns:
        Dict mapping metric name -> float accuracy/recall.
    """
    metrics: Dict[str, float] = {}
    occ_mask = targets["patch_occupancy"] > 0.01
    n_occ = occ_mask.sum().item()

    patch_keys = ["local_dim_acc", "local_curv_acc", "local_bound_acc",
                  "local_axis_acc", "struct_topo_acc", "struct_role_acc",
                  "patch_shape_acc"]

    if n_occ > 0:
        # ---- local gate heads (argmax / thresholded sigmoid per patch) ----
        pred_dims = outputs["local_dim_logits"].argmax(dim=-1)
        true_dims = targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1)
        metrics["local_dim_acc"] = ((pred_dims == true_dims) & occ_mask).sum().item() / n_occ

        pred_curv = outputs["local_curv_logits"].argmax(dim=-1)
        true_curv = targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1)
        metrics["local_curv_acc"] = ((pred_curv == true_curv) & occ_mask).sum().item() / n_occ

        pred_bound = (torch.sigmoid(outputs["local_bound_logits"].squeeze(-1)) > 0.5).float()
        true_bound = targets["patch_boundary"]
        metrics["local_bound_acc"] = ((pred_bound == true_bound) & occ_mask).sum().item() / n_occ

        # Axis prediction counts as correct only when every axis flag matches.
        pred_axis = (torch.sigmoid(outputs["local_axis_logits"]) > 0.5).float()
        true_axis = targets["patch_axis_active"]
        metrics["local_axis_acc"] = ((pred_axis == true_axis).all(dim=-1) & occ_mask).sum().item() / n_occ

        # ---- structural gate heads ----------------------------------------
        pred_topo = outputs["struct_topo_logits"].argmax(dim=-1)
        true_topo = targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1)
        metrics["struct_topo_acc"] = ((pred_topo == true_topo) & occ_mask).sum().item() / n_occ

        pred_role = outputs["struct_role_logits"].argmax(dim=-1)
        true_role = targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1)
        metrics["struct_role_acc"] = ((pred_role == true_role) & occ_mask).sum().item() / n_occ

        # ---- per-patch shape membership (optional head) -------------------
        if "patch_shape_logits" in outputs and "patch_shape_membership" in targets:
            pred_shapes = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).float()
            true_shapes = targets["patch_shape_membership"]
            shape_match = (pred_shapes == true_shapes).float().mean(dim=-1)
            metrics["patch_shape_acc"] = (shape_match * occ_mask.float()).sum().item() / n_occ
        else:
            # Missing shape head: zero only this metric, never the six above.
            metrics["patch_shape_acc"] = 0.0
    else:
        # No occupied patches: emit zeros for a consistent key set.
        for k in patch_keys:
            metrics[k] = 0.0

    # ---- global heads (batch-level, independent of occupancy) -------------
    if "global_shapes" in outputs and "global_shapes" in targets:
        pred_shapes = (torch.sigmoid(outputs["global_shapes"]) > 0.5).float()
        true_shapes = targets["global_shapes"]
        metrics["global_shape_acc"] = (pred_shapes == true_shapes).float().mean().item()
        true_pos = (pred_shapes * true_shapes).sum()
        total_true = true_shapes.sum().clamp(min=1)  # avoid divide-by-zero recall
        metrics["global_shape_recall"] = (true_pos / total_true).item()

    # Guarded so a batch without gate targets doesn't KeyError.
    if "global_gates" in outputs and "global_gates" in targets:
        pred_gates = (torch.sigmoid(outputs["global_gates"]) > 0.5).float()
        true_gates = (targets["global_gates"] > 0.5).float()
        metrics["global_gate_acc"] = (pred_gates == true_gates).float().mean().item()

    return metrics
|
|
|
|
| |
|
|
@dataclass
class Config:
    """Hyperparameters for the unfrozen end-to-end training run.

    Grouped as: data generation, model architecture (forwarded verbatim to
    SuperpositionPatchClassifier), and optimization/schedule settings.
    """

    # --- data generation ---
    n_samples: int = 500000   # training-set size fed to make_loader
    n_val: int = 50000        # validation-set size (generated with seed=0)
    seed: int = 420           # RNG seed for the training generator

    # --- model architecture (see model.py / Cell 2 for exact semantics) ---
    embed_dim: int = 256      # embedding width passed to the classifier
    patch_dim: int = 64       # per-patch feature width — TODO confirm in model.py
    n_bootstrap: int = 2      # layer-count knob for the bootstrap stage — verify in model.py
    n_geometric: int = 2      # layer-count knob for the geometric stage — verify in model.py
    n_heads: int = 4          # attention heads — presumably; confirm in model.py
    dropout: float = 0.1      # dropout probability

    # --- optimization ---
    epochs: int = 200         # total training epochs
    batch_size: int = 512     # train batch size (val uses 2x)
    lr: float = 3e-4          # peak AdamW learning rate (after warmup)
    weight_decay: float = 0.01
    warmup_steps: int = 500   # linear warmup steps before cosine decay
    upload_every: int = 20    # checkpoint-upload cadence in epochs
|
|
|
|
| |
|
|
def make_loader(n_samples, seed, device, batch_size, shuffle=True):
    """Generate a synthetic dataset, derive patch targets on `device`,
    then return a CPU-backed DataLoader over the result.

    Patch analysis runs on `device` (fast on GPU) under no_grad; everything
    is moved back to CPU afterwards so pin_memory can re-stage batches.
    """
    raw = generate_dataset(n_samples, seed=seed, num_workers=MAX_WORKERS)
    grids = torch.from_numpy(raw["grids"]).float().to(device)
    memberships = torch.from_numpy(raw["memberships"]).float().to(device)

    with torch.no_grad():
        patch_targets = analyze_patches_torch(grids)

    dataset = ShapeDataset(
        grids.cpu(),
        memberships.cpu(),
        {name: tensor.cpu() for name, tensor in patch_targets.items()},
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=True,
    )
|
|
|
|
| |
|
|
def train():
    """Run end-to-end (unfrozen) training of the SuperpositionPatchClassifier.

    Builds train/val loaders, trains all parameters jointly with AdamW under
    a linear-warmup + cosine-decay schedule, validates every epoch, logs to
    TensorBoard, and uploads checkpoints to the HF Hub every
    `config.upload_every` epochs plus once after training completes.

    Fix: the previous checkpoint logic uploaded the final-epoch checkpoint
    twice (once inside the loop, once after it) whenever the last epoch hit
    the upload cadence or improved recall; interval uploads now skip the
    final epoch, which is handled once after the loop.
    """
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Config: {config}")

    from torch.utils.tensorboard import SummaryWriter
    log_dir = "/tmp/tb_logs"
    writer = SummaryWriter(log_dir)

    # ---- data ----------------------------------------------------------
    print(f"\nGenerating training set ({config.n_samples} samples)...")
    train_loader = make_loader(config.n_samples, seed=config.seed, device=device,
                               batch_size=config.batch_size, shuffle=True)
    print(f"✓ Train set ready")

    print(f"Generating val set ({config.n_val} samples)...")
    val_loader = make_loader(config.n_val, seed=0, device=device,
                             batch_size=config.batch_size * 2, shuffle=False)
    print(f"✓ Val set ready")

    # ---- model / loss / optimizer --------------------------------------
    model = SuperpositionPatchClassifier(
        config.embed_dim, config.patch_dim, config.n_bootstrap, config.n_geometric,
        config.n_heads, config.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params:,}")

    loss_fn = SuperpositionLoss(local_weight=1.0, struct_weight=1.0,
                                shape_weight=1.0, global_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr,
                                  weight_decay=config.weight_decay)

    # Linear warmup to peak lr over `warmup_steps`, then cosine decay to 0.
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * config.epochs

    def lr_lambda(step):
        if step < config.warmup_steps:
            return step / config.warmup_steps
        progress = (step - config.warmup_steps) / (total_steps - config.warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_recall = 0.0
    global_step = 0

    print(f"\nTraining for {config.epochs} epochs (unfrozen, all losses)...\n")
    for epoch in range(1, config.epochs + 1):
        # ---- one training epoch ----------------------------------------
        model.train()
        epoch_loss, n_batches = 0.0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{config.epochs}")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["grid"])
            losses = loss_fn(outputs, batch)
            optimizer.zero_grad()
            losses["total"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()  # per-step schedule (not per-epoch)
            global_step += 1
            epoch_loss += losses["total"].item()
            n_batches += 1
            pbar.set_postfix(loss=f"{losses['total'].item():.3f}",
                             lr=f"{scheduler.get_last_lr()[0]:.2e}")

        avg_train_loss = epoch_loss / n_batches

        # ---- validation ------------------------------------------------
        model.eval()
        val_metrics_list = []
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch["grid"])
                val_metrics_list.append(compute_metrics(outputs, batch))

        # Average each metric over batches (key set taken from first batch).
        m = {k: np.mean([v[k] for v in val_metrics_list]) for k in val_metrics_list[0]}

        recall = m.get("global_shape_recall", 0)
        # Worst-case accuracies summarize the gate heads in one number each.
        local_min = min(m.get("local_dim_acc", 0), m.get("local_curv_acc", 0),
                        m.get("local_bound_acc", 0), m.get("local_axis_acc", 0))
        struct_min = min(m.get("struct_topo_acc", 0), m.get("struct_role_acc", 0))

        print(f"Epoch {epoch} | Loss: {avg_train_loss:.4f} | Recall: {recall:.4f} | "
              f"Local≥{local_min:.4f} | Struct≥{struct_min:.4f}")

        # ---- TensorBoard logging ---------------------------------------
        writer.add_scalar("loss/train", avg_train_loss, epoch)
        writer.add_scalar("recall", recall, epoch)
        writer.add_scalar("local/dim", m.get("local_dim_acc", 0), epoch)
        writer.add_scalar("local/curv", m.get("local_curv_acc", 0), epoch)
        writer.add_scalar("local/bound", m.get("local_bound_acc", 0), epoch)
        writer.add_scalar("local/axis", m.get("local_axis_acc", 0), epoch)
        writer.add_scalar("struct/topo", m.get("struct_topo_acc", 0), epoch)
        writer.add_scalar("struct/role", m.get("struct_role_acc", 0), epoch)
        writer.add_scalar("shape/patch_acc", m.get("patch_shape_acc", 0), epoch)
        writer.add_scalar("shape/global_acc", m.get("global_shape_acc", 0), epoch)
        writer.add_scalar("lr", scheduler.get_last_lr()[0], epoch)

        if recall > best_recall:
            best_recall = recall

        # Interval uploads only; the final epoch is uploaded exactly once
        # after the loop, so it is skipped here.
        if epoch % config.upload_every == 0 and epoch != config.epochs:
            upload_checkpoint(model, epoch, m, config)

    # ---- wrap-up --------------------------------------------------------
    writer.close()
    upload_checkpoint(model, config.epochs, m, config)
    upload_tensorboard(log_dir)
    print(f"\n{'='*70}")
    print(f"TRAINING COMPLETE")
    print(f" Local gates: ≥{local_min:.4f}")
    print(f" Struct gates: ≥{struct_min:.4f}")
    print(f" Best Recall: {best_recall:.4f}")
    print(f"{'='*70}")
|
|
|
|
| |
# Entry point.  Guarded so importing this module (e.g. for its helpers)
# does not kick off a full training run; Colab cells execute as __main__,
# so running the cell still trains exactly as before.
if __name__ == "__main__":
    train()
    print("✓ Training complete")