"""
Superposition Patch Classifier - Unfrozen Trainer
===================================================
Colab Cell 3 of 3 - depends on Cell 1 (generator.py) and Cell 2 (model.py).

End-to-end training: all parameters, all losses, no freezing.
Two-tier gate architecture trains jointly — local and structural
gates co-evolve with shape classification.
"""

import os
import time
import numpy as np
from dataclasses import dataclass, asdict
from typing import Dict
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Cell 1 provides: generate_dataset, analyze_patches_torch, ShapeDataset, collate_fn,
#                  MAX_WORKERS, NUM_CLASSES, CLASS_NAMES, MACRO_N,
#                  LOCAL_GATE_DIM, STRUCTURAL_GATE_DIM, TOTAL_GATE_DIM,
#                  NUM_LOCAL_DIMS, NUM_LOCAL_CURVS, NUM_LOCAL_BOUNDARY, NUM_LOCAL_AXES,
#                  NUM_STRUCT_TOPO, NUM_STRUCT_NEIGHBOR, NUM_STRUCT_ROLE, NUM_GATES
# Cell 2 provides: SuperpositionPatchClassifier, SuperpositionLoss


# === HuggingFace ==============================================================

HF_REPO = "AbstractPhil/grid-geometric-multishape"


def upload_checkpoint(model, epoch, metrics, config):
    """Save a checkpoint under /tmp and upload it to ``HF_REPO``.

    The checkpoint holds the model state dict, the epoch number, the metrics
    dict and the config converted with ``asdict``. Upload is best-effort: any
    exception is reported on stdout and swallowed so training keeps running.

    Args:
        model: Module whose ``state_dict()`` is saved.
        epoch: Epoch number, used in the local and remote file names.
        metrics: Metrics stored alongside the weights.
        config: Dataclass instance serialised via ``asdict``.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        path = f"/tmp/best_model_epoch{epoch}.pt"
        torch.save({
            "model_state_dict": model.state_dict(),
            "epoch": epoch,
            "metrics": metrics,
            "config": asdict(config),
        }, path)
        api.upload_file(path_or_fileobj=path,
                        path_in_repo=f"checkpoint_v10/best_model_epoch{epoch}.pt",
                        repo_id=HF_REPO, repo_type="model")
        print(f" ✓ Uploaded checkpoint epoch {epoch}")
    except Exception as e:
        # Deliberate best-effort: a failed upload must not abort training.
        print(f" ✗ Upload failed: {e}")


def upload_tensorboard(log_dir):
    """Upload the TensorBoard log folder to ``runs/`` in ``HF_REPO``.

    Best-effort like ``upload_checkpoint``: failures are printed, not raised.

    Args:
        log_dir: Local folder containing the event files.
    """
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_folder(folder_path=log_dir, path_in_repo="runs/",
                          repo_id=HF_REPO, repo_type="model")
        print(" ✓ Uploaded TensorBoard logs")
    except Exception as e:
        print(f" ✗ TB upload failed: {e}")


# === Metrics ==================================================================

def _masked_fraction(correct, mask, n_selected):
    """Return the share of ``mask``-selected entries where ``correct`` is true."""
    return (correct & mask).sum().item() / n_selected


def compute_metrics(outputs: Dict, targets: Dict) -> Dict[str, float]:
    """Compute accuracy-style metrics for one batch.

    Patch-level metrics are evaluated only on patches whose
    ``targets["patch_occupancy"]`` exceeds 0.01. When no patch qualifies, all
    patch-level metrics are reported as 0.0. Global metrics are added only
    when the corresponding keys exist in both ``outputs`` and ``targets``.

    Args:
        outputs: Model outputs keyed by head name (logits).
        targets: Batch dict holding the ground-truth tensors.

    Returns:
        Mapping from metric name to a Python float.
    """
    metrics = {}
    occ_mask = targets["patch_occupancy"] > 0.01
    n_occ = occ_mask.sum().item()

    if n_occ > 0:
        # Local gate metrics
        pred_dims = outputs["local_dim_logits"].argmax(dim=-1)
        true_dims = targets["patch_dims"].clamp(0, NUM_LOCAL_DIMS - 1)
        metrics["local_dim_acc"] = _masked_fraction(pred_dims == true_dims, occ_mask, n_occ)

        pred_curv = outputs["local_curv_logits"].argmax(dim=-1)
        true_curv = targets["patch_curvature"].clamp(0, NUM_LOCAL_CURVS - 1)
        metrics["local_curv_acc"] = _masked_fraction(pred_curv == true_curv, occ_mask, n_occ)

        pred_bound = (torch.sigmoid(outputs["local_bound_logits"].squeeze(-1)) > 0.5).float()
        true_bound = targets["patch_boundary"]
        metrics["local_bound_acc"] = _masked_fraction(pred_bound == true_bound, occ_mask, n_occ)

        # A patch counts as correct only if every axis flag matches.
        pred_axis = (torch.sigmoid(outputs["local_axis_logits"]) > 0.5).float()
        true_axis = targets["patch_axis_active"]
        metrics["local_axis_acc"] = _masked_fraction(
            (pred_axis == true_axis).all(dim=-1), occ_mask, n_occ)

        # Structural gate metrics
        pred_topo = outputs["struct_topo_logits"].argmax(dim=-1)
        true_topo = targets["patch_topology"].clamp(0, NUM_STRUCT_TOPO - 1)
        metrics["struct_topo_acc"] = _masked_fraction(pred_topo == true_topo, occ_mask, n_occ)

        pred_role = outputs["struct_role_logits"].argmax(dim=-1)
        true_role = targets["patch_surface_role"].clamp(0, NUM_STRUCT_ROLE - 1)
        metrics["struct_role_acc"] = _masked_fraction(pred_role == true_role, occ_mask, n_occ)

        # Shape metrics: per-patch mean agreement over the last dimension.
        if "patch_shape_logits" in outputs and "patch_shape_membership" in targets:
            pred_shapes = (torch.sigmoid(outputs["patch_shape_logits"]) > 0.5).float()
            true_shapes = targets["patch_shape_membership"]
            shape_match = (pred_shapes == true_shapes).float().mean(dim=-1)
            metrics["patch_shape_acc"] = (shape_match * occ_mask.float()).sum().item() / n_occ
    else:
        for k in ["local_dim_acc", "local_curv_acc", "local_bound_acc", "local_axis_acc",
                  "struct_topo_acc", "struct_role_acc", "patch_shape_acc"]:
            metrics[k] = 0.0

    # Global
    if "global_shapes" in outputs and "global_shapes" in targets:
        pred_shapes = (torch.sigmoid(outputs["global_shapes"]) > 0.5).float()
        true_shapes = targets["global_shapes"]
        metrics["global_shape_acc"] = (pred_shapes == true_shapes).float().mean().item()

        true_pos = (pred_shapes * true_shapes).sum()
        # clamp avoids 0/0 when the batch has no positive shape labels.
        total_true = true_shapes.sum().clamp(min=1)
        metrics["global_shape_recall"] = (true_pos / total_true).item()

        # Guarded like every other optional head so a missing key is skipped
        # instead of raising KeyError.
        if "global_gates" in outputs and "global_gates" in targets:
            pred_gates = (torch.sigmoid(outputs["global_gates"]) > 0.5).float()
            true_gates = (targets["global_gates"] > 0.5).float()
            metrics["global_gate_acc"] = (pred_gates == true_gates).float().mean().item()

    return metrics


# === Config ===================================================================

@dataclass
class Config:
    """Hyperparameters for data generation, the model and the training loop."""

    # Data
    n_samples: int = 500000   # training samples to generate
    n_val: int = 50000        # validation samples to generate
    seed: int = 420           # seed for the training set
    # Model (passed positionally to SuperpositionPatchClassifier)
    embed_dim: int = 256
    patch_dim: int = 64
    n_bootstrap: int = 2
    n_geometric: int = 2
    n_heads: int = 4
    dropout: float = 0.1
    # Training
    epochs: int = 200
    batch_size: int = 512     # validation uses twice this value
    lr: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 500   # optimizer steps of linear LR warmup
    upload_every: int = 20    # epochs between periodic checkpoint uploads


# === Data Loading =============================================================

def make_loader(n_samples, seed, device, batch_size, shuffle=True):
    """Generate a dataset, analyse its patches on ``device`` and wrap it in a DataLoader.

    Grids and memberships are moved to ``device`` for ``analyze_patches_torch``
    and then moved back to the CPU before being handed to ``ShapeDataset``.

    Args:
        n_samples: Number of samples to generate.
        seed: Seed forwarded to ``generate_dataset``.
        device: Device used for the patch analysis pass.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the loader shuffles each epoch.

    Returns:
        A ``DataLoader`` over the generated ``ShapeDataset``.
    """
    data = generate_dataset(n_samples, seed=seed, num_workers=MAX_WORKERS)
    grids = torch.from_numpy(data["grids"]).float().to(device)
    memberships = torch.from_numpy(data["memberships"]).float().to(device)

    with torch.no_grad():
        patch_data = analyze_patches_torch(grids)

    grids, memberships = grids.cpu(), memberships.cpu()
    patch_data = {k: v.cpu() for k, v in patch_data.items()}

    ds = ShapeDataset(grids, memberships, patch_data)
    # Pinned memory only helps host-to-GPU copies; on CPU it just emits a warning.
    pin = torch.device(device).type == "cuda"
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      collate_fn=collate_fn, num_workers=0, pin_memory=pin)


# === Training =================================================================

def _warmup_cosine_factor(step, warmup_steps, total_steps):
    """Return the LR multiplier: linear warmup followed by cosine decay to 0.

    Args:
        step: Current scheduler step (0-based).
        warmup_steps: Number of steps of linear ramp from 0 to 1.
        total_steps: Total number of optimizer steps in the run.

    Returns:
        A float in [0, 1] to multiply the base learning rate by.
    """
    if step < warmup_steps:
        return step / warmup_steps
    # max(1, ...) prevents division by zero when warmup covers the whole run.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1.0 + float(np.cos(np.pi * progress)))


def train():
    """Run the full training pipeline with the default ``Config``.

    Generates the training and validation sets once, builds the model, loss,
    AdamW optimizer and warmup+cosine schedule, then for each epoch trains,
    validates, logs to TensorBoard and periodically uploads a checkpoint.
    After the last epoch the final checkpoint and the TensorBoard logs are
    uploaded and a summary is printed.
    """
    config = Config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Config: {config}")

    from torch.utils.tensorboard import SummaryWriter
    log_dir = "/tmp/tb_logs"
    writer = SummaryWriter(log_dir)

    # Generate data once
    print(f"\nGenerating training set ({config.n_samples} samples)...")
    train_loader = make_loader(config.n_samples, seed=config.seed, device=device,
                               batch_size=config.batch_size, shuffle=True)
    print("✓ Train set ready")

    print(f"Generating val set ({config.n_val} samples)...")
    val_loader = make_loader(config.n_val, seed=0, device=device,
                             batch_size=config.batch_size * 2, shuffle=False)
    print("✓ Val set ready")

    # Model
    model = SuperpositionPatchClassifier(
        config.embed_dim, config.patch_dim, config.n_bootstrap,
        config.n_geometric, config.n_heads, config.dropout).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params:,}")

    # All losses active, all parameters trainable
    loss_fn = SuperpositionLoss(local_weight=1.0, struct_weight=1.0,
                                shape_weight=1.0, global_weight=0.5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr,
                                  weight_decay=config.weight_decay)

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * config.epochs

    def lr_lambda(step):
        return _warmup_cosine_factor(step, config.warmup_steps, total_steps)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    best_recall = 0.0
    # Pre-bound so the final summary still works if the loop body never runs.
    m = {}
    local_min = 0.0
    struct_min = 0.0

    print(f"\nTraining for {config.epochs} epochs (unfrozen, all losses)...\n")

    for epoch in range(1, config.epochs + 1):
        model.train()
        epoch_loss, n_batches = 0.0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{config.epochs}")

        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["grid"])
            losses = loss_fn(outputs, batch)

            optimizer.zero_grad()
            losses["total"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Read the scalar once; each .item() forces a device sync.
            loss_value = losses["total"].item()
            epoch_loss += loss_value
            n_batches += 1
            pbar.set_postfix(loss=f"{loss_value:.3f}",
                             lr=f"{scheduler.get_last_lr()[0]:.2e}")

        avg_train_loss = epoch_loss / max(n_batches, 1)

        # Validate
        model.eval()
        val_metrics_list = []
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch["grid"])
                val_metrics_list.append(compute_metrics(outputs, batch))

        # Unweighted mean of per-batch metrics; empty loader yields no metrics.
        m = {}
        if val_metrics_list:
            m = {k: np.mean([v[k] for v in val_metrics_list if k in v])
                 for k in val_metrics_list[0]}

        recall = m.get("global_shape_recall", 0)
        local_min = min(m.get("local_dim_acc", 0), m.get("local_curv_acc", 0),
                        m.get("local_bound_acc", 0), m.get("local_axis_acc", 0))
        struct_min = min(m.get("struct_topo_acc", 0), m.get("struct_role_acc", 0))

        print(f"Epoch {epoch} | Loss: {avg_train_loss:.4f} | Recall: {recall:.4f} | "
              f"Local≥{local_min:.4f} | Struct≥{struct_min:.4f}")

        # TensorBoard
        writer.add_scalar("loss/train", avg_train_loss, epoch)
        writer.add_scalar("recall", recall, epoch)
        for tag, key in (("local/dim", "local_dim_acc"),
                         ("local/curv", "local_curv_acc"),
                         ("local/bound", "local_bound_acc"),
                         ("local/axis", "local_axis_acc"),
                         ("struct/topo", "struct_topo_acc"),
                         ("struct/role", "struct_role_acc"),
                         ("shape/patch_acc", "patch_shape_acc"),
                         ("shape/global_acc", "global_shape_acc")):
            writer.add_scalar(tag, m.get(key, 0), epoch)
        writer.add_scalar("lr", scheduler.get_last_lr()[0], epoch)

        # Upload
        if recall > best_recall:
            best_recall = recall
        # Periodic checkpoint. The last epoch is skipped here because the
        # unconditional final upload below covers it; uploading in both
        # places sent the same checkpoint twice.
        if epoch % config.upload_every == 0 and epoch != config.epochs:
            upload_checkpoint(model, epoch, m, config)

    # Final
    writer.close()
    upload_checkpoint(model, config.epochs, m, config)
    upload_tensorboard(log_dir)

    print(f"\n{'='*70}")
    print("TRAINING COMPLETE")
    print(f" Local gates: ≥{local_min:.4f}")
    print(f" Struct gates: ≥{struct_min:.4f}")
    print(f" Best Recall: {best_recall:.4f}")
    print(f"{'='*70}")


# === Run ======================================================================

# Runs as soon as the cell executes; Cells 1 and 2 must have been run first.
train()
print("✓ Training complete")