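# Hugging Face file-viewer header (page-scrape residue, not part of the program):
#   uploader avatar alt text: AbstractPhil's picture
#   commit c69b879 (verified): "labeled mit and marked python"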
"""
Pentachoron Constellation — Multi-Channel, HF Push, and Dataset Sweep
MIT
Author: AbstractPhil
Quartermaster: Mirel (GPT-5 Thinking)
"""
from __future__ import annotations
import os, sys, json, math, time, random, shutil, zipfile, platform
from pathlib import Path
from datetime import datetime
from typing import List, Tuple, Dict, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
## ---------------------------------------------------------------------
## Fast settings / safety
## ---------------------------------------------------------------------
#torch.autograd.set_detect_anomaly(False)
#if torch.cuda.is_available():
# torch.backends.cudnn.benchmark = True
# torch.cuda.empty_cache()
# ---------------------------------------------------------------------
# Configuration (edit these)
# ---------------------------------------------------------------------
# Central run configuration. Entries commented "set by loader" are placeholders
# that the dataset loader is expected to overwrite.
config: Dict = {
    # Model dims
    "input_dim": 28*28,          # will be set by loader
    "input_channels": "auto",    # "auto" | 1 | 3 ; loader enforces
    "base_dim": 56,
    "proj_dim": None,

    # Constellation
    "num_classes": 10,           # set by loader
    "num_pentachoron_pairs": 2,
    "lambda_separation": 0.391,

    # Attention / extractor
    "num_heads": 2,
    "channels": 24,

    # Training
    "batch_size": 1024,
    "epochs": 20,
    "lr": 1e-2,
    "weight_decay": 1e-5,
    "temp": 0.7,

    # Loss weights
    "w_ce": 1.0,
    "w_dual": 1.0,
    "w_rose": 1.0,
    "w_diag": 0.1,
    "w_reg": 0.1,                # default geom reg

    # Legacy compat
    "loss_weight_scalar": 0.1,

    # Dataset override knobs
    "img_size": 28,              # unified target size
    "img_channels": "auto",      # "auto" | 1 | 3 ; coerces all sets
    "normalize": True,
    "per_dataset_norm": True,
    "augment": False,            # safe light aug

    # Sweep control
    "sweep_all": False,          # set True to run all datasets
    "seed": 420,

    # Hugging Face
    "hf_repo_id": "AbstractPhil/pentachora-multi-channel-frequency-encoded",
    "dataset": "QMNIST",
}

# Secondary defaults: only filled in for keys the literal above does not
# already define, so a value placed in the literal always wins.
config.update({
    key: value
    for key, value in {
        # --- HF pathing / naming ---
        "hf_subdir_root": "",
        "hf_dataset_dir_template": "{dataset}",    # folder under root
        "hf_run_dir_template": "{ts}_{dataset}",   # or "{ts}"
        "hf_weight_suffix_dataset": True,          # encoder_{dataset}.safetensors etc.
        "hf_preserve_case": True,                  # keep DatasetName casing in paths
        # --- Reproducibility / determinism ---
        "deterministic": True,                     # set cudnn deterministic + disable benchmark
        "strict_determinism": False,               # torch.use_deterministic_algorithms(True)
        "deterministic_cublas": False,             # set CUBLAS_WORKSPACE_CONFIG
        "seed_per_dataset": False,                 # re-seed using dataset name in sweep
    }.items()
    if key not in config
})
# ---------------------------------------------------------------------
# Fast settings / safety
# ---------------------------------------------------------------------
# Autograd anomaly detection is explicitly switched off.
torch.autograd.set_detect_anomaly(False)

# Determinism knobs (must be set before layers allocate kernels)
if not bool(config.get("deterministic", True)):
    # Non-deterministic run: let cuDNN autotune its kernels.
    torch.backends.cudnn.benchmark = True
else:
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# TF32 off → numerically stable & repeatable on Ampere+
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# cuBLAS deterministic workspace (opt-in; can slow some kernels).
# setdefault keeps any value already exported in the environment.
if bool(config.get("deterministic_cublas", False)):
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Prefer CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Dump the effective configuration, one "key: value" line per entry.
print("\n" + "="*60)
print("PENTACHORON CONSTELLATION CONFIGURATION")
print("="*60)
for k, v in config.items():
    print(f"{k:24}: {v}")
# ---------------------------------------------------------------------
# Reproducibility
# ---------------------------------------------------------------------
def seed_everything(seed: int = 42,
                    deterministic: bool | None = None,
                    strict: bool | None = None):
    """Seed Python, NumPy and Torch (CPU+CUDA) and apply the determinism flags.

    Args:
        seed: value fed to every RNG and exported as ``PYTHONHASHSEED``.
        deterministic: when True, request deterministic cuDNN kernels and turn
            the cuDNN autotuner off; when False, do the opposite. ``None``
            falls back to ``config["deterministic"]`` (default True).
        strict: forwarded to ``torch.use_deterministic_algorithms``. ``None``
            falls back to ``config["strict_determinism"]`` (default False).
    """
    if deterministic is None:
        deterministic = bool(config.get("deterministic", True))
    if strict is None:
        strict = bool(config.get("strict_determinism", False))

    # OS / interpreter.
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this export only
    # reaches processes spawned from here on, not the running interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The flag was previously resolved but never applied; mirror the
    # module-level cuDNN setup so re-seeding also restores these switches.
    torch.backends.cudnn.deterministic = bool(deterministic)
    torch.backends.cudnn.benchmark = not deterministic

    # Best effort: report instead of silently swallowing (e.g. a torch build
    # that predates use_deterministic_algorithms).
    try:
        torch.use_deterministic_algorithms(bool(strict))  # raises on nondet ops if True
    except Exception as exc:
        print(f"[seed] could not set torch.use_deterministic_algorithms({bool(strict)}): {exc}")

    # RNGs
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def make_torch_generator(seed: int) -> torch.Generator:
    """Return a fresh ``torch.Generator`` whose state is seeded with *seed*."""
    # Generator.manual_seed() hands back the generator itself, so the
    # construction and seeding can be chained.
    return torch.Generator().manual_seed(seed)
def seed_worker(worker_id: int):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker.

    The seed is derived from ``torch.initial_seed()`` folded into 32 bits;
    *worker_id* itself is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)
# Initial global seed: seed every RNG once at import time from config["seed"]
# (42 is used when the key is missing).
seed_everything(int(config.get("seed", 42)))
# ---------------------------------------------------------------------
# Setup & deps
# ---------------------------------------------------------------------
def _ensure(pkg, pip_name=None):
    """Import *pkg*; if that fails, pip-install it into the running interpreter.

    Args:
        pkg: importable module name.
        pip_name: distribution name handed to pip when it differs from *pkg*;
            defaults to *pkg*.

    Installation is best effort: a failing pip run is reported on stdout
    rather than raised, so the subsequent real ``import`` surfaces the error.
    """
    import importlib
    import subprocess

    pip_name = pip_name or pkg
    try:
        importlib.import_module(pkg)
        return
    except Exception:
        # Deliberately broad: a package that is present but broken at import
        # time also gets a reinstall attempt.
        print(f"[setup] Installing {pip_name} ...")

    # Argument list instead of a shell string: survives interpreter paths that
    # contain spaces and keeps pip_name from being interpreted by a shell.
    cmd = [sys.executable, "-m", "pip", "install", "-q", pip_name]
    try:
        result = subprocess.run(cmd, check=False)
    except OSError as exc:
        print(f"[setup] Could not launch pip for {pip_name}: {exc}")
        return
    if result.returncode != 0:
        print(f"[setup] pip install {pip_name} failed (exit code {result.returncode}).")
        return
    # Make the freshly installed distribution visible to this process.
    importlib.invalidate_caches()
# Optional runtime dependencies: each one is imported and, if the import
# fails, installed on demand via pip.
_ensure("safetensors")
_ensure("huggingface_hub")
_ensure("pandas")
_ensure("psutil")
_ensure("medmnist")
from safetensors.torch import save_file as save_safetensors
from huggingface_hub import HfApi, create_repo, whoami, login
import pandas as pd
import psutil
from torch.utils.tensorboard import SummaryWriter
# ---------------------------------------------------------------------
# Small utils
# ---------------------------------------------------------------------
def _timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d-%H%M%S}"
def _param_count(m: nn.Module) -> int:
    """Total number of scalar elements across all parameters of *m*."""
    total = 0
    for param in m.parameters():
        total += param.numel()
    return total
def _resolve_repo_id(cfg: Dict) -> str:
    """Pick the Hugging Face repo id.

    The ``PENTACHORA_HF_REPO`` environment variable takes precedence over
    ``cfg["hf_repo_id"]``.

    Raises:
        RuntimeError: if neither source yields a non-empty value.
    """
    env_override = os.getenv("PENTACHORA_HF_REPO")
    repo_id = env_override if env_override else cfg.get("hf_repo_id")
    if repo_id:
        return repo_id
    raise RuntimeError("Set config['hf_repo_id'] or export PENTACHORA_HF_REPO.")
def _hf_login_if_needed():
    """Log in to the Hugging Face Hub when no usable credentials are found.

    ``whoami()`` is used as the probe. If it fails, the ``HF_TOKEN``
    environment variable is used to log in; without a token only a hint is
    printed.
    """
    try:
        whoami()
    except Exception:
        already_logged_in = False
    else:
        already_logged_in = True
    if already_logged_in:
        return

    token = os.getenv("HF_TOKEN")
    if not token:
        print("[huggingface] No login found and HF_TOKEN not set. Push may fail; run `huggingface-cli login`.")
        return
    login(token=token, add_to_git_credential=True)
def _ensure_repo(repo_id: str) -> HfApi:
    """Create the public model repo *repo_id* if it is missing and return an ``HfApi`` client."""
    create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
    return HfApi()
def _zip_dir(src: Path, dst_zip: Path):
    """Write every entry found under *src* into the deflate-compressed archive *dst_zip*.

    Archive names are relative to *src*. Entries are added in sorted order so
    the archive layout does not depend on filesystem enumeration order.

    Args:
        src: directory whose contents are archived (recursively).
        dst_zip: path of the zip file to create; may live inside *src*.
    """
    src = Path(src)
    dst_zip = Path(dst_zip)
    archive_path = dst_zip.resolve()
    with zipfile.ZipFile(dst_zip, "w", zipfile.ZIP_DEFLATED) as archive:
        for entry in sorted(src.rglob("*")):
            # The archive already exists on disk at this point; when it lives
            # under `src` it must not be packed into itself.
            if entry.resolve() == archive_path:
                continue
            archive.write(entry, arcname=entry.relative_to(src))
def _dataset_slug(name_or_names) -> str:
    """Lower-case, whitespace-trimmed slug for one dataset name or a list/tuple of names.

    Several names are joined with ``+``; anything that is not a list or tuple
    is converted with ``str()`` first.
    """
    if not isinstance(name_or_names, (list, tuple)):
        return str(name_or_names).strip().lower()
    parts = [entry.strip().lower() for entry in name_or_names]
    return "+".join(parts)
# ---------------------------------------------------------------------
# Dataset loader (TorchVision + MedMNIST), config-aware
# ---------------------------------------------------------------------
# MedMNIST is optional: when it cannot be imported both names are set to None
# and the loader raises ImportError only if a MedMNIST dataset is requested.
try:
    import medmnist
    MED_INFO = medmnist.INFO
except Exception:
    medmnist = None
    MED_INFO = None
# Lower-cased lookup key -> canonical TorchVision dataset name.
_TORCHVISION_KEYS = {
    canonical.lower(): canonical
    for canonical in (
        "MNIST",
        "FashionMNIST",
        "KMNIST",
        "EMNIST",  # balanced
        "QMNIST",
        "USPS",
    )
}

# Lower-cased lookup key -> MedMNIST flag (the two are identical strings).
_MEDMNIST_MAP = {
    flag: flag
    for flag in (
        "bloodmnist", "pathmnist", "octmnist",
        "pneumoniamnist", "dermamnist", "retinamnist",
        "breastmnist", "organamnist", "organcmnist",
        "organsmnist", "tissuemnist",
    )
}

# (mean, std) pairs used for normalising single-channel inputs, keyed by the
# canonical dataset name.
_DATASET_STATS_1CH = {
    "MNIST":        ([0.1307], [0.3081]),
    "FashionMNIST": ([0.2860], [0.3530]),
    "KMNIST":       ([0.1918], [0.3483]),
    "EMNIST":       ([0.1307], [0.3081]),
    "QMNIST":       ([0.1307], [0.3081]),
    "USPS":         ([0.5000], [0.5000]),
}

# Generic fallbacks for 1-channel and 3-channel inputs.
_MEAN1 = [0.5]
_STD1 = [0.5]
_MEAN3 = [0.5, 0.5, 0.5]
_STD3 = [0.5, 0.5, 0.5]
def _norm_stats(name: str, channels: int) -> Tuple[List[float], List[float]]:
    """Return the ``(mean, std)`` lists used to normalise dataset *name*.

    Single-channel requests use the per-dataset table and fall back to the
    generic 1-channel pair for unknown names; every other channel count gets
    the generic 3-channel pair.
    """
    if channels != 1:
        return _MEAN3, _STD3
    generic = (_MEAN1, _STD1)
    return _DATASET_STATS_1CH.get(name, generic)
def _to_channels(target_c: int):
    """Build a transform that coerces a channel-first tensor to *target_c* channels.

    Rules applied by the returned ``transforms.Lambda``:
      * already *target_c* channels -> returned untouched;
      * 3 -> 1: weighted sum of the three channels (0.2989 / 0.5870 / 0.1140);
      * 1 -> 3: the single channel is repeated three times;
      * anything else: the first *target_c* channels are kept.
    """
    def _convert(img: torch.Tensor) -> torch.Tensor:
        have = img.shape[0]
        if have == target_c:
            return img
        if target_c == 1 and have == 3:
            red, green, blue = img.unbind(0)
            luma = 0.2989*red + 0.5870*green + 0.1140*blue
            return luma.unsqueeze(0)
        if target_c == 3 and have == 1:
            return img.repeat(3, 1, 1)
        # Remaining combinations simply keep the leading channels.
        return img[:target_c]
    return transforms.Lambda(_convert)
def _augmentations_for(name: str, size: int, channels: int) -> List[transforms.Transform]:
    """Return the light train-time augmentations for dataset *name*.

    Nothing is returned unless ``config["augment"]`` is truthy. Digit-style
    sets get a mild affine jitter (optionally followed by a padded random
    crop when ``size >= 32``); every other set gets the crop first, a slightly
    stronger affine jitter, and colour jitter when it is a 3-channel
    ``*mnist`` dataset.
    """
    if not bool(config.get("augment", False)):
        return []

    # Padded random crop, only used for images of side 32 or more.
    crop: List[transforms.Transform] = []
    if size >= 32:
        crop.append(transforms.RandomCrop(size, padding=max(1, int(0.03 * size))))

    if name.upper() in {"MNIST", "KMNIST", "EMNIST", "QMNIST", "USPS"}:
        digit_affine = transforms.RandomAffine(degrees=8, translate=(0.05, 0.05), scale=(0.95, 1.05))
        return [digit_affine] + crop

    pipeline = crop + [transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05))]
    if channels == 3 and name.lower().endswith("mnist"):
        pipeline.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.02))
    return pipeline
def _build_transforms(dataset_name: str, split: str, native_c: int, target_c: int|str, size: int, normalize: bool, per_dataset_norm: bool) -> transforms.Compose:
    """Compose the preprocessing pipeline for a TorchVision dataset.

    Pipeline order: resize (if needed) -> train augmentations -> ToTensor ->
    channel coercion -> normalisation -> flatten to 1-D.

    Args:
        dataset_name: canonical dataset name (e.g. ``"MNIST"``, ``"USPS"``).
        split: ``"train"`` enables augmentations; any other value disables them.
        native_c: channel count the dataset delivers.
        target_c: ``"auto"`` keeps *native_c*, otherwise the channel count to coerce to.
        size: target side length in pixels.
        normalize: append a ``Normalize`` step when True.
        per_dataset_norm: use dataset-specific stats instead of the generic 0.5/0.5.
    """
    # USPS digits are 16x16; the other TorchVision sets handled by this file
    # are 28x28. Comparing against a fixed 28 left USPS un-resized at the
    # default size, so its flattened length disagreed with C*size*size.
    native_side = 16 if dataset_name == "USPS" else 28

    out_c = native_c if target_c == "auto" else int(target_c)

    steps: List[transforms.Transform] = []
    if size != native_side:
        steps.append(transforms.Resize((size, size)))
    if split == "train":
        # Augment after resizing: RandomCrop(size) needs an image that is
        # already at least `size` pixels wide, which is not guaranteed for the
        # native-resolution image.
        steps.extend(_augmentations_for(dataset_name, size, out_c))
    steps.append(transforms.ToTensor())
    if target_c != "auto":
        steps.append(_to_channels(out_c))
    if normalize:
        if per_dataset_norm:
            mean, std = _norm_stats(dataset_name, out_c)
        else:
            mean, std = (_MEAN1, _STD1) if out_c == 1 else (_MEAN3, _STD3)
        steps.append(transforms.Normalize(mean=mean, std=std))
    # Flatten [C, H, W] into a single vector.
    steps.append(transforms.Lambda(lambda x: x.reshape(-1)))
    return transforms.Compose(steps)
def collate_as_int(batch):
    """Collate ``(sample, label)`` pairs into ``(stacked_samples, int64_labels)``.

    Labels may be Python/NumPy integers, tensors, or anything ``np.asarray``
    accepts. A scalar or single-element label is used as-is; a label with
    several elements is reduced with ``argmax``.
    """
    def _as_index(label) -> int:
        if isinstance(label, (int, np.integer)):
            return int(label)
        if torch.is_tensor(label):
            single = label.ndim == 0 or (label.ndim == 1 and label.numel() == 1)
            return int(label.item()) if single else int(label.argmax().item())
        arr = np.asarray(label)
        single = arr.ndim == 0 or (arr.ndim == 1 and arr.size == 1)
        return int(arr.item()) if single else int(arr.argmax())

    samples, labels = zip(*batch)
    stacked = torch.stack(samples, dim=0)
    targets = torch.tensor([_as_index(label) for label in labels], dtype=torch.long)
    return stacked, targets
def _get_med_info(flag: str) -> dict:
    """Look up the MedMNIST ``INFO`` entry for *flag*.

    Raises:
        ImportError: medmnist could not be imported at module load.
        KeyError: *flag* is not a known MedMNIST dataset.
    """
    if MED_INFO is not None:
        if flag in MED_INFO:
            return MED_INFO[flag]
        raise KeyError(f"Unknown MedMNIST flag: {flag}")
    raise ImportError("medmnist is not installed. `pip install medmnist`")
def _med_class_names(info: dict) -> List[str]:
    """Return the class names stored in ``info["label"]``, ordered by numeric index.

    The label table is keyed by the stringified indices ``"0"``, ``"1"``, ...
    """
    labels = info["label"]
    return [labels[key] for key in map(str, range(len(labels)))]
def load_single_dataset(name: str, split: str,
                        cfg: Optional[Dict] = None,
                        resolved_target_channels: Optional[int|str] = None
                        ) -> Tuple[torch.utils.data.Dataset, int, List[str], int, int]:
    """
    Load one TorchVision or MedMNIST dataset split with config-driven preprocessing.

    Args:
        name: dataset name, matched case-insensitively after trimming.
        split: "train" selects the training split; anything else the test split.
        cfg: configuration dict; the module-level `config` is used when falsy.
        resolved_target_channels: overrides cfg["img_channels"] when not None.

    Return: dataset, num_classes, class_names, input_dim (C*H*W), output_channels

    Raises:
        ImportError: a MedMNIST dataset was requested but medmnist is unavailable.
        ValueError: the name is not a known dataset.
    """
    cfg = cfg or config
    name_key = name.strip()
    name_lower = name_key.lower()
    is_train = split == "train"

    size = int(cfg.get("img_size", 28))
    if resolved_target_channels is None:
        want_c = cfg.get("img_channels", "auto")
    else:
        want_c = resolved_target_channels
    normalize = bool(cfg.get("normalize", True))
    per_dataset_norm = bool(cfg.get("per_dataset_norm", True))

    # ---- TorchVision ----
    if name_lower in _TORCHVISION_KEYS:
        canonical = _TORCHVISION_KEYS[name_lower]
        native_c = 1
        transform = _build_transforms(canonical, split, native_c, want_c, size, normalize, per_dataset_norm)
        shared = dict(transform=transform, download=True)

        if canonical == "EMNIST":
            ds = datasets.EMNIST("./data", split='balanced', train=is_train, **shared)
        elif canonical == "QMNIST":
            ds = datasets.QMNIST("./data", what=('train' if is_train else 'test'), **shared)
        elif canonical in ("MNIST", "FashionMNIST", "KMNIST", "USPS"):
            ds = getattr(datasets, canonical)("./data", train=is_train, **shared)
        else:
            raise ValueError(f"Unhandled TorchVision dataset: {canonical}")

        # Class-name prefix and the per-class symbols for every dataset.
        digits = [str(i) for i in range(10)]
        label_spec = {
            "MNIST": ("digit", digits),
            "FashionMNIST": ("fashion", ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                                         'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']),
            "KMNIST": ("kmnist", ['お', 'き', 'す', 'つ', 'な', 'は', 'ま', 'や', 'れ', 'を']),
            # EMNIST "balanced": 10 digits + 26 upper-case + 11 lower-case symbols.
            "EMNIST": ("emnist", digits + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + list("abdefghnqrt")),
            "QMNIST": ("qmnist", digits),
            "USPS": ("usps", digits),
        }
        prefix, symbols = label_spec[canonical]
        cls_names = [f"{prefix}-{symbol}" for symbol in symbols]
        ncls = len(cls_names)

        out_c = native_c if want_c == "auto" else int(want_c)
        return ds, ncls, cls_names, out_c * size * size, out_c

    if name_lower not in _MEDMNIST_MAP:
        raise ValueError(f"Unknown dataset name: {name}")

    # ---- MedMNIST ----
    if medmnist is None:
        raise ImportError("medmnist not available. `pip install medmnist`")
    info = _get_med_info(_MEDMNIST_MAP[name_lower])
    DataClass = getattr(medmnist, info["python_class"])
    native_c = int(info.get("n_channels", 1))
    out_c = native_c if want_c == "auto" else int(want_c)

    # Same step order as before; the pass-through steps are simply omitted.
    steps = [transforms.ToTensor()]
    if want_c != "auto":
        steps.append(_to_channels(out_c))
    if is_train:
        steps.extend(_augmentations_for(name_key, size, out_c))
    if size != 28:
        steps.append(transforms.Resize((size, size)))
    if normalize:
        if per_dataset_norm:
            mean, std = _norm_stats(name_key, out_c)
        elif out_c == 1:
            mean, std = _MEAN1, _STD1
        else:
            mean, std = _MEAN3, _STD3
        steps.append(transforms.Normalize(mean, std))
    steps.append(transforms.Lambda(lambda t: t.view(-1)))
    transform = transforms.Compose(steps)

    ds = DataClass(split=('train' if is_train else 'test'), transform=transform, download=True, size=size)
    cls_names = _med_class_names(info)
    ncls = len(info["label"])
    return ds, ncls, cls_names, out_c * size * size, out_c
def get_dataset_single(name: str, batch_size: int, num_workers: int = 2):
    """
    Build the train/test DataLoaders for one dataset, honoring config overrides.

    Both loaders share one torch.Generator seeded from config["seed"]; only
    the train loader shuffles.

    Returns: train_loader, test_loader, num_classes, class_names, input_dim, channels
    """
    train_set, n_train, class_names, dim_train, ch_train = load_single_dataset(name, "train", config)
    test_set, n_test, _test_names, dim_test, ch_test = load_single_dataset(name, "test", config)
    # The two splits must agree on class count, flattened size and channels.
    assert n_train == n_test and dim_train == dim_test and ch_train == ch_test

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_as_int,
        worker_init_fn=seed_worker,
        generator=make_torch_generator(int(config.get("seed", 42))),
        persistent_workers=False,
    )
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, test_loader, n_train, class_names, dim_train, ch_train
# Dataset catalogs
# Display names of the TorchVision datasets; their lower-cased forms are the
# keys of _TORCHVISION_KEYS.
TORCHVISION_DATASETS = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST", "USPS"]
# Display names of the MedMNIST datasets; their lower-cased forms are the keys
# of _MEDMNIST_MAP.
MEDMNIST_DATASETS = ["BloodMNIST","PathMNIST","OCTMNIST","PneumoniaMNIST","DermaMNIST",
"RetinaMNIST","BreastMNIST","OrganAMNIST","OrganCMNIST","OrganSMNIST","TissueMNIST"]
# ---------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------
class PentaFreqExtractor(nn.Module):
    """
    Multi-channel spectral extractor:
    - Input: [B, C*H*W], unflatten -> [B, C, H, W] (square images only)
    - 5 convolutional branches ("frequency bands", ultra-high .. low), each
      pooled to 7x7 and linearly encoded to base_dim
    - Output: vertices [B, 5, base_dim] and a fixed 5x5 adjacency matrix
      (ones everywhere except the diagonal)

    Args:
        input_dim: flattened sample length, must equal input_ch * side * side.
        input_ch: number of image channels.
        base_dim: embedding width of every vertex.
        channels: convolution width shared by all branches.

    Raises:
        ValueError: input_dim is not input_ch times a perfect square.
    """
    def __init__(self, input_dim: int = 784, input_ch: int = 1, base_dim: int = 64, channels: int = 12):
        super().__init__()
        self.input_dim = input_dim
        self.input_ch = int(input_ch)

        # Exact integer square root: a float `** 0.5` followed by int() can
        # truncate a perfect square to side-1, and an assert-based check
        # disappears under `python -O`.
        divisor = max(1, self.input_ch)
        side = math.isqrt(max(0, int(input_dim) // divisor))
        if side * side * self.input_ch != input_dim:
            side_f = (max(0, input_dim) / divisor) ** 0.5
            raise ValueError(
                f"input_dim ({input_dim}) != C*H*W with H=W; C={self.input_ch}, side≈{side_f:.3f}"
            )
        self.unflatten = nn.Unflatten(1, (self.input_ch, side, side))
        self.base_dim = base_dim

        # Vertex 0 (ultra-high)
        self.v0_ultrahigh = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.BatchNorm2d(channels), nn.ReLU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten()
        )
        self.v0_encode = nn.Linear(channels * 49, base_dim)

        # Vertex 1 (high)
        self.v1_high = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.Tanh(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten()
        )
        self.v1_encode = nn.Linear(channels * 49, base_dim)

        # Vertex 2 (mid)
        self.v2_mid = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 5, padding=2, stride=2),
            nn.BatchNorm2d(channels), nn.GELU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.GELU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten()
        )
        self.v2_encode = nn.Linear(channels * 49, base_dim)

        # Vertex 3 (low-mid)
        self.v3_lowmid = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(self.input_ch, channels, 7, padding=3),
            nn.BatchNorm2d(channels), nn.SiLU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten()
        )
        self.v3_encode = nn.Linear(channels * 49, base_dim)

        # Vertex 4 (low)
        self.v4_low = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(self.input_ch, channels, 7, padding=3),
            nn.BatchNorm2d(channels), nn.Sigmoid(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten()
        )
        self.v4_encode = nn.Linear(channels * 49, base_dim)

        # Fully connected vertex graph without self-loops.
        self.register_buffer("adjacency_matrix", torch.ones(5, 5) - torch.eye(5))
        self._init_edge_kernels(channels)

    @torch.no_grad()
    def _init_edge_kernels(self, channels: int):
        """Overwrite the first five filters (input channel 0) of the ultra-high
        branch's first convolution with fixed 3x3 difference stencils.

        Does nothing when fewer than five output channels are available.
        """
        if channels < 5:
            return
        conv0 = self.v0_ultrahigh[0]
        if not isinstance(conv0, nn.Conv2d):
            return
        if conv0.weight.shape[1] < 1:
            return
        k = conv0.weight
        stencils = (
            ([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], 4),    # horizontal gradient
            ([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], 4),    # vertical gradient
            ([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], 4),   # 4-neighbour Laplacian
            ([[1, 0, 0], [0, -1, 0], [0, 0, 0]], 2),      # diagonal difference
            ([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]], 3),    # unweighted horizontal gradient
        )
        for out_idx, (values, scale) in enumerate(stencils):
            k[out_idx, 0] = torch.tensor(values, dtype=k.dtype) / scale

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map flattened images [B, C*H*W] to (vertices [B, 5, D], adjacency [5, 5])."""
        img = self.unflatten(x)
        v0 = self.v0_encode(self.v0_ultrahigh(img))
        v1 = self.v1_encode(self.v1_high(img))
        v2 = self.v2_encode(self.v2_mid(img))
        v3 = self.v3_encode(self.v3_lowmid(img))
        v4 = self.v4_encode(self.v4_low(img))
        vertices = torch.stack([v0, v1, v2, v3, v4], dim=1)  # [B,5,D]
        return vertices, self.adjacency_matrix
class PentachoronCrossAttention(nn.Module):
    """Adjacency-masked multi-head attention over the vertex set.

    Every vertex queries all vertices; positions whose adjacency entry is 0
    are masked out with ``-inf``.

    Args:
        dim: embedding width of each vertex (must be divisible by num_heads).
        num_heads: number of attention heads.
        dropout: attention dropout probability.
    """
    def __init__(self, dim: int, num_heads: int = 14, dropout: float = 0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)

    def _row_to_attn_mask(self, row: torch.Tensor) -> torch.Tensor:
        """Additive float mask [1, len(row)] for a single query: -inf where row == 0."""
        mask = torch.zeros(1, row.numel(), device=row.device, dtype=torch.float32)
        mask[(row == 0).unsqueeze(0)] = float("-inf")
        return mask

    def _adjacency_to_attn_mask(self, adjacency: torch.Tensor, num_vertices: int) -> torch.Tensor:
        """Additive float mask [V, V] for all queries at once: -inf where adjacency == 0."""
        rows = adjacency[:num_vertices]
        mask = torch.zeros(rows.shape, device=rows.device, dtype=torch.float32)
        return mask.masked_fill_(rows == 0, float("-inf"))

    def forward(self, vertices: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        """Attend each vertex to its neighbours.

        Args:
            vertices: [B, V, D] vertex embeddings.
            adjacency: [V, V] matrix; zeros mark forbidden (query, key) pairs.

        Returns:
            [B, V, D] attended vertices.
        """
        B, V, D = vertices.shape
        # One attention call with a [V, V] mask replaces V single-query calls.
        # Each query row is still softmaxed independently, so the result is
        # the same up to floating-point rounding (with dropout > 0 the random
        # draws are laid out differently).
        mask = self._adjacency_to_attn_mask(adjacency.to(vertices.device), V)
        out, _ = self.attn(vertices, vertices, vertices, attn_mask=mask, need_weights=False)
        return out
class PentachoronOpinionFusion(nn.Module):
    """Blend a similarity-weighted view and an attention view of the five
    vertices, then fuse them into one L2-normalised latent vector.

    Args:
        base_dim: width of every vertex embedding.
        proj_dim: optional output width; None keeps base_dim.
        num_heads: heads of the cross-attention block.
        p_dropout: dropout probability inside the fusion MLP.
    """
    def __init__(self, base_dim: int = 64, proj_dim: Optional[int] = None, num_heads: int = 14, p_dropout: float = 0.2):
        super().__init__()
        self.cross = PentachoronCrossAttention(dim=base_dim, num_heads=num_heads)
        wide, mid = base_dim * 3, base_dim * 2
        self.fusion = nn.Sequential(
            nn.Linear(base_dim * 5, wide),
            nn.BatchNorm1d(wide), nn.ReLU(), nn.Dropout(p_dropout),
            nn.Linear(wide, mid),
            nn.BatchNorm1d(mid), nn.ReLU(),
            nn.Linear(mid, base_dim),
        )
        if proj_dim is None:
            self.projection = None
        else:
            self.projection = nn.Linear(base_dim, proj_dim, bias=False)
        # Raw blend parameter; sigmoid(0) = 0.5 weights both views equally at start.
        self._lambda_raw = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softmax_geometry(vertices: torch.Tensor, adjacency: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Weight every vertex by the softmax of its summed cosine similarity to its neighbours.

        Returns:
            (weighted vertices [B,5,D], weights [B,5]).
        """
        unit = F.normalize(vertices, dim=2, eps=1e-8)
        cosine = torch.bmm(unit, unit.transpose(1, 2))
        adj = adjacency.to(vertices.dtype).unsqueeze(0)
        per_vertex = (cosine * adj).sum(dim=2)
        weights = F.softmax(per_vertex, dim=1)  # [B,5]
        return vertices * weights.unsqueeze(2), weights

    def forward(self, vertices: torch.Tensor, adjacency: torch.Tensor, return_diag: bool = False):
        """Return ``(z, diag)``; ``diag`` is None unless *return_diag* is set."""
        geometric, weights = self._softmax_geometry(vertices, adjacency)
        attended = self.cross(vertices, adjacency)
        lam = torch.sigmoid(self._lambda_raw)
        blended = lam * geometric + (1.0 - lam) * attended
        fused = self.fusion(blended.flatten(1))
        if self.projection is not None:
            fused = self.projection(fused)
        z = F.normalize(fused, dim=1)
        diag = {"lambda": lam.detach(), "softmax_weights": weights.detach()} if return_diag else None
        return z, diag
class PentaFreqEncoderV2(nn.Module):
    """Encoder = five-branch extractor followed by the opinion-fusion head."""
    def __init__(self, input_dim: int = 784, input_ch: int = 1, base_dim: int = 64, proj_dim: Optional[int] = None, num_heads: int = 14, channels: int = 12):
        super().__init__()
        self.extractor = PentaFreqExtractor(
            input_dim=input_dim, input_ch=input_ch, base_dim=base_dim, channels=channels
        )
        self.opinion = PentachoronOpinionFusion(
            base_dim=base_dim, proj_dim=proj_dim, num_heads=num_heads
        )

    @torch.no_grad()
    def get_frequency_contributions(self, x: torch.Tensor) -> torch.Tensor:
        """Return only the per-vertex softmax weights for *x* (no gradients)."""
        vertices, adjacency = self.extractor(x)
        return self.opinion._softmax_geometry(vertices, adjacency)[1]

    def forward(self, x: torch.Tensor, return_diag: bool = False):
        """Return the latent ``z``, or ``(z, diag)`` when *return_diag* is True."""
        vertices, adjacency = self.extractor(x)
        z, diag = self.opinion(vertices, adjacency, return_diag)
        if return_diag:
            return z, diag
        return z
class BatchedPentachoronConstellation(nn.Module):
    """Classifier head built from `num_pairs` (dispatcher, specialist) pentachora.

    Each pentachoron is a learnable set of 5 vertices in `dim`-space. Classes
    are assigned to vertices round-robin (class i -> vertex i % 5); a latent
    is scored by its distance-based affinity to every vertex, and that
    affinity gates per-vertex linear class heads.

    Args:
        num_classes: number of output classes.
        dim: latent width.
        num_pairs: number of dispatcher/specialist pairs.
        device: device used while initialising the pentachora (CPU when None).
        lambda_sep: weight of the centre-separation terms of the regulariser.
    """

    _MAX_HEADS = 14  # head count used whenever `dim` is divisible by it

    def __init__(self, num_classes: int, dim: int, num_pairs: int = 5, device: Optional[torch.device] = None, lambda_sep: float = 0.5):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.num_pairs = num_pairs
        self.device = device if device is not None else torch.device("cpu")
        self.lambda_separation = lambda_sep

        self.dispatchers = nn.Parameter(self._init_batched_pentachora())
        self.specialists = nn.Parameter(self._init_batched_pentachora())
        self.dispatcher_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)
        self.specialist_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)
        self.temps = nn.Parameter(0.3 * torch.ones(num_pairs))
        self.register_buffer("vertex_map", self._create_vertex_mapping())

        # One linear head per vertex, sized to the classes mapped onto it;
        # vertices without classes keep a None placeholder.
        self.group_heads = nn.ModuleList([
            nn.Linear(dim, int((self.vertex_map == i).sum().item())) if int((self.vertex_map == i).sum().item()) > 0 else None
            for i in range(5)
        ])

        # 14 heads as before whenever dim allows it; otherwise the largest
        # divisor of dim not above 14, instead of failing in MultiheadAttention.
        # The projection weight shapes do not depend on the head count.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=dim, num_heads=self._pick_num_heads(dim), dropout=0.1, batch_first=True
        )
        self.aggregation_weights = nn.Parameter(torch.ones(num_pairs) / num_pairs)
        self.fusion = nn.Sequential(
            nn.Linear(num_classes * num_pairs, num_classes * 2),
            nn.BatchNorm1d(num_classes * 2), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(num_classes * 2, num_classes)
        )
        self.coherence_head = nn.Sequential(nn.Linear(dim, dim // 2), nn.GELU(), nn.Linear(dim // 2, 1))

    @classmethod
    def _pick_num_heads(cls, dim: int) -> int:
        """Largest head count <= _MAX_HEADS that divides *dim* (1 as a last resort)."""
        return next((h for h in range(min(cls._MAX_HEADS, dim), 0, -1) if dim % h == 0), 1)

    def _init_batched_pentachora(self) -> torch.Tensor:
        """Initial vertices [P, 5, dim]: a row-normalised 4-D simplex in the
        first four coordinates (grown by 10% per pair index), small Gaussian
        noise in the remaining coordinates, everything scaled by 2."""
        sqrt15, sqrt10, sqrt5 = np.sqrt(15), np.sqrt(10), np.sqrt(5)
        base_simplex = torch.tensor([
            [ 1.0,   0.0,        0.0,       0.0],
            [-0.25,  sqrt15/4,   0.0,       0.0],
            [-0.25, -sqrt15/12,  sqrt10/3,  0.0],
            [-0.25, -sqrt15/12, -sqrt10/6,  sqrt5/2],
            [-0.25, -sqrt15/12, -sqrt10/6, -sqrt5/2]
        ], device=self.device, dtype=torch.float32)
        base_simplex = F.normalize(base_simplex, dim=1)
        pentachora = torch.zeros(self.num_pairs, 5, self.dim, device=self.device, dtype=torch.float32)
        for i in range(self.num_pairs):
            pentachora[i, :, :4] = base_simplex * (1 + 0.1 * i)
            if self.dim > 4:
                pentachora[i, :, 4:] = torch.randn(5, self.dim - 4, device=self.device) * (random.random() * 0.25)
        return pentachora * 2.0

    def _create_vertex_mapping(self) -> torch.Tensor:
        """Round-robin class -> vertex assignment (class i lands on vertex i % 5)."""
        return torch.arange(self.num_classes, dtype=torch.long) % 5

    def forward(self, x: torch.Tensor):
        """Score latents against the constellation.

        Args:
            x: latents [B, dim].

        Returns:
            (logits [B, C], {"disp_d": [B,P,5], "spec_d": [B,P,5], "probs": [B,P,5]}).
        """
        B = x.size(0)
        coherence_gate = torch.sigmoid(self.coherence_head(x))     # [B,1]

        x_exp = x.unsqueeze(1).unsqueeze(2)                        # [B,1,1,D]
        disp_d = torch.norm(x_exp - self.dispatchers.unsqueeze(0), dim=3)  # [B,P,5]
        spec_d = torch.norm(x_exp - self.specialists.unsqueeze(0), dim=3)  # [B,P,5]

        dw = F.softmax(self.dispatcher_weights, dim=1).unsqueeze(0)
        sw = F.softmax(self.specialist_weights, dim=1).unsqueeze(0)
        temps = torch.clamp(self.temps, 0.1, 2.0).view(1, -1, 1)
        disp_logits = -(disp_d * dw) / temps
        spec_logits = -(spec_d * sw) / temps

        c = coherence_gate.unsqueeze(-1)
        disp_probs = F.softmax(disp_logits * c, dim=2)
        spec_probs = F.softmax(spec_logits * c, dim=2)
        probs = 0.5 * disp_probs + 0.5 * spec_probs                # [B,P,5]

        # The group heads do not depend on the pair index, so they are
        # evaluated once and scattered into class order ...
        class_logits = torch.zeros(B, self.num_classes, device=x.device)
        for v_idx, head in enumerate(self.group_heads):
            if head is None:
                continue
            idxs = (self.vertex_map == v_idx).nonzero(as_tuple=True)[0]
            class_logits[:, idxs] = head(x)
        # ... then every class is gated by the probability of its own vertex
        # for every pair: scores[b, p, c] = logit[b, c] * probs[b, p, vertex(c)].
        scores_tensor = class_logits.unsqueeze(1) * probs[:, :, self.vertex_map]  # [B,P,C]

        # NOTE(review): the attention output is not used below. The call is
        # kept so parameters, gradients and dropout RNG consumption stay as
        # they were; confirm whether it was meant to feed the aggregation.
        centers = self.dispatchers.mean(dim=1).unsqueeze(0).expand(B, -1, -1)
        _attn, _ = self.cross_attention(centers, centers, centers)

        agg = F.softmax(self.aggregation_weights, dim=0).view(1, -1, 1)
        weighted = (scores_tensor * agg).sum(dim=1)                # [B,C]
        fused = self.fusion(scores_tensor.flatten(1))              # [B,C]
        final = 0.6 * weighted + 0.4 * fused
        return final, {"disp_d": disp_d, "spec_d": spec_d, "probs": probs}

    def _batched_cayley_menger(self, pentachora: torch.Tensor) -> torch.Tensor:
        """
        Stable Cayley-Menger proxy: |det(M)| from the eigenvalues of (M + eps*I).
        Returns one non-negative value per pentachoron; larger => more
        'volumetric' (less degenerate).

        The bordered squared-distance matrix M is symmetric but indefinite, so
        the magnitudes of its eigenvalues are used. Clamping the signed
        eigenvalues from below collapses every negative eigenvalue to the
        floor and drives the result (and its gradient) to zero.
        """
        P = pentachora.shape[0]
        dev, dtype = pentachora.device, pentachora.dtype   # follow the tensor, not the init-time device
        d2 = torch.cdist(pentachora, pentachora).pow(2)    # [P,5,5]
        M = torch.zeros(P, 6, 6, device=dev, dtype=dtype)
        M[:, 0, 1:] = 1.0
        M[:, 1:, 0] = 1.0
        M[:, 1:, 1:] = d2

        eps = 1e-6
        M_eps = M + eps * torch.eye(6, device=dev, dtype=dtype).unsqueeze(0)

        evals = torch.linalg.eigvalsh(M_eps)                           # [P,6] real, sorted
        log_abs_det = evals.abs().clamp_min(1e-12).log().sum(dim=1)    # log|det|
        det = torch.exp(log_abs_det)                                   # |det|
        # keep it finite
        return torch.nan_to_num(det, nan=0.0, posinf=1e6, neginf=0.0)

    def _batched_edge_variance(self, pentachora: torch.Tensor) -> torch.Tensor:
        """Penalty = per-pentachoron variance of the 10 edge lengths, plus a
        hinge on any shortest edge below 0.5 (summed over pentachora)."""
        d = torch.cdist(pentachora, pentachora)
        upper = torch.triu(torch.ones(5, 5, device=pentachora.device), diagonal=1).bool()
        edges = d[:, upper]                                            # [P,10]
        return edges.var(dim=1).sum() + torch.relu(0.5 - edges.min(dim=1)[0]).sum()

    def regularization_loss(self, vertex_weights=None) -> torch.Tensor:
        """Geometric regulariser over both pentachoron sets.

        Combines the Cayley-Menger hinge, the edge-length penalty, the
        |cosine| between paired centres and a hinge keeping paired centres at
        least 2.0 apart (both scaled by lambda_separation), averaged over the
        pairs. When *vertex_weights* ([5]) is given, a weighted vertex-norm
        term scaled by 0.1 is added.
        """
        disp_cm = self._batched_cayley_menger(self.dispatchers)
        spec_cm = self._batched_cayley_menger(self.specialists)
        cm_loss = torch.relu(1.0 - torch.abs(disp_cm)).sum() + torch.relu(1.0 - torch.abs(spec_cm)).sum()
        edge_loss = self._batched_edge_variance(self.dispatchers) + self._batched_edge_variance(self.specialists)

        disp_centers = self.dispatchers.mean(dim=1)
        spec_centers = self.specialists.mean(dim=1)
        cos_sims = F.cosine_similarity(disp_centers, spec_centers, dim=1, eps=1e-8)
        ortho = torch.abs(cos_sims).sum() * self.lambda_separation
        separations = torch.norm(disp_centers - spec_centers, dim=1)
        sep = torch.relu(2.0 - separations).sum() * self.lambda_separation

        dyn = 0.0
        if vertex_weights is not None:
            vw = vertex_weights.to(self.dispatchers.device)
            disp_norms = torch.norm(self.dispatchers, p=2, dim=2)
            spec_norms = torch.norm(self.specialists, p=2, dim=2)
            dyn = 0.1 * ((disp_norms * vw.unsqueeze(0)).mean() + (spec_norms * vw.unsqueeze(0)).mean())

        return (cm_loss + edge_loss + ortho + sep) / self.num_pairs + dyn
# ---------------------------------------------------------------------
# Losses
# ---------------------------------------------------------------------
def dual_contrastive_loss(latents, targets, constellation, temp: float):
    """Cross-entropy of each latent against the vertices of every pentachoron.

    For both the dispatcher and the specialist set, cosine similarities
    between the latent and the five vertices of each pair are divided by
    *temp*; the positive is the vertex that ``constellation.vertex_map``
    assigns to the target class. The two mean losses are added.

    Args:
        latents: [B, D] embeddings.
        targets: [B] integer class ids.
        constellation: object exposing ``dispatchers``, ``specialists``
            ([P, 5, D]) and ``vertex_map``.
        temp: softmax temperature.
    """
    batch = latents.size(0)
    z = F.normalize(latents, dim=1, eps=1e-8)
    target_vertex = constellation.vertex_map[targets]

    def _branch_loss(vertices):
        unit = F.normalize(vertices, dim=2, eps=1e-8)
        logits = torch.einsum('bd,pvd->bpv', z, unit) / temp
        index = target_vertex.view(batch, 1, 1).expand(batch, logits.size(1), 1)
        positive = logits.gather(2, index).squeeze(2)
        return (torch.logsumexp(logits, dim=2) - positive).mean()

    return _branch_loss(constellation.dispatchers) + _branch_loss(constellation.specialists)
class RoseDiagnosticHead(nn.Module):
    """Two-layer MLP (Linear -> GELU -> LayerNorm -> Linear) mapping a latent to one scalar."""
    def __init__(self, latent_dim: int, hidden_dim: int = 128):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scalar prediction with a trailing dimension of size 1."""
        return self.net(x)
def rose_score_magnitude(x, need, relation, purpose, eps: float = 1e-8):
    """Mean cosine alignment of *x* with the three anchors, scaled by ``||x||``.

    All inputs share the last dimension; the norm of *x* is floored at *eps*.
    """
    unit_x = F.normalize(x, dim=-1, eps=eps)
    cos_need, cos_relation, cos_purpose = (
        (unit_x * F.normalize(anchor, dim=-1, eps=eps)).sum(-1)
        for anchor in (need, relation, purpose)
    )
    alignment = (cos_need + cos_relation + cos_purpose) / 3.0
    magnitude = x.norm(dim=-1).clamp_min(eps)
    return alignment * magnitude
def rose_contrastive_loss(latents, targets, constellation, temp: float = 0.5):
    """Vertex-level contrastive loss, re-weighted per sample by the rose score.

    The rose score compares each latent with (a) the pair-averaged specialist
    vertex of its target, (b) the pair-averaged dispatcher vertex of its
    target and (c) the global specialist mean. Samples are weighted by
    ``1 - tanh(rose)``; the weights are detached.

    Returns:
        (weighted mean loss, detached rose scores [B]).
    """
    B, D = latents.shape
    specialists = constellation.specialists
    dispatchers = constellation.dispatchers
    tvert = constellation.vertex_map[targets]

    # Anchors for the rose score.
    need = specialists[:, tvert, :].mean(dim=0)
    relation = dispatchers[:, tvert, :].mean(dim=0)
    purpose = specialists.mean(dim=(0, 1)).unsqueeze(0).expand(B, D)
    rose = rose_score_magnitude(latents, need, relation, purpose)
    weights = (1.0 - torch.tanh(rose)).detach()

    z = F.normalize(latents, dim=1, eps=1e-8)
    positive_index = tvert.view(-1, 1)

    def _nll(vertex_bank):
        # Cross-entropy against the pair-averaged, unit-length vertices.
        prototypes = F.normalize(vertex_bank.mean(dim=0), dim=1, eps=1e-8)
        logits = (z @ prototypes.T) / temp
        positive = logits.gather(1, positive_index).squeeze(1)
        return torch.logsumexp(logits, dim=1) - positive

    per_sample = 0.5 * (_nll(specialists) + _nll(dispatchers))
    return (per_sample * weights).mean(), rose.detach()
# ---------------------------------------------------------------------
# Regularization helpers
# ---------------------------------------------------------------------
def get_class_similarity(constellation_model: BatchedPentachoronConstellation, num_classes: int) -> torch.Tensor:
    """Cosine similarity between the rows of the last fusion layer's weight,
    clamped to [0, 1].

    *num_classes* is accepted for interface compatibility and is not read.
    """
    final_layer = constellation_model.fusion[-1]
    weights = final_layer.weight.data.detach()
    unit_rows = F.normalize(weights, p=2, dim=1)
    return (unit_rows @ unit_rows.T).clamp(0.0, 1.0)
def vertex_weights_from_confusion(cm: np.ndarray, class_similarity: torch.Tensor, vertex_map: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Turn a confusion matrix into five per-vertex regularisation weights.

    Steps:
      1. Per-class error = 1 - (diagonal / row total); rows with no samples
         get accuracy 0 and therefore error 1.
      2. Errors are spread across classes with a Gaussian kernel
         (sigma = 0.5) on ``1 - class_similarity``.
      3. The spread errors are averaged per vertex via ``vertex_map``;
         vertices with no class keep an average of 0.
      4. The weight is ``1 - tanh(average)``, L1-normalised and scaled by 5.

    Args:
        cm: Square confusion matrix (rows are summed as the per-class totals).
        class_similarity: Class-by-class similarity matrix.
        vertex_map: One vertex index per class; values index a length-5 buffer.
        device: Device on which the intermediate tensors are created.

    Returns:
        Length-5 tensor of vertex weights.
    """
    row_totals = cm.sum(axis=1)
    hits = cm.diagonal()
    per_class_acc = np.divide(
        hits, row_totals, out=np.zeros_like(hits, dtype=float), where=row_totals != 0
    )
    error_rate = 1.0 - torch.tensor(per_class_acc, device=device, dtype=torch.float32)

    sigma = 0.5
    kernel = torch.exp(-((1 - class_similarity) ** 2) / (2 * sigma ** 2))
    spread = kernel @ error_rate

    vertex_total = torch.zeros(5, device=device)
    vertex_count = torch.zeros(5, device=device)
    for cls_idx, vertex in enumerate(vertex_map.tolist()):
        vertex_total[vertex] += spread[cls_idx]
        vertex_count[vertex] += 1

    # Unoccupied vertices have total 0, so dividing by max(count, 1) leaves
    # them at 0 while occupied vertices get the plain mean.
    vertex_mean = vertex_total / vertex_count.clamp_min(1.0)

    weights = 1.0 - torch.tanh(vertex_mean)
    return F.normalize(weights, p=1, dim=0) * 5.0
# ---------------------------------------------------------------------
# Evaluate / Train
# ---------------------------------------------------------------------
def evaluate(encoder: nn.Module, constellation: nn.Module, loader, num_classes: int, device: torch.device, collect_diag: bool = False):
    """Run inference over ``loader`` and summarise classification quality.

    Both modules are switched to eval mode and are NOT switched back.

    Args:
        encoder: Maps an input batch to latents; must accept
            ``return_diag=True`` when ``collect_diag`` is set.
        constellation: Maps latents to ``(logits, extra)``.
        loader: Iterable of ``(x, y)`` batches.
        num_classes: Number of labels used to size the confusion matrix.
        device: Device batches are moved to.
        collect_diag: When True, accumulates ``diag["softmax_weights"]``
            (summed into a length-5 buffer).

    Returns:
        ``(acc, per_class_acc, cm, avg_soft_w, lam_eval)`` where
        ``per_class_acc`` is a list, ``cm`` the confusion matrix,
        ``avg_soft_w`` the mean softmax weights (or None), and ``lam_eval``
        the mean of ``sigmoid(encoder.opinion._lambda_raw)`` sampled once per
        batch (or None when that attribute chain is absent).
    """
    encoder.eval()
    constellation.eval()

    pred_chunks, target_chunks = [], []
    lambda_trace = []
    weight_total = torch.zeros(5, device=device)
    weight_rows = 0

    with torch.no_grad():
        for x, y in tqdm(loader, desc="Evaluating"):
            x = x.to(device)
            y = y.to(device)

            if not collect_diag:
                z = encoder(x)
            else:
                z, diag = encoder(x, return_diag=True)
                w = diag["softmax_weights"]
                weight_total += w.sum(dim=0)
                weight_rows += w.size(0)

            logits, _ = constellation(z)
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            target_chunks.append(y.cpu().numpy())

            if hasattr(encoder, "opinion") and hasattr(encoder.opinion, "_lambda_raw"):
                gate = torch.sigmoid(encoder.opinion._lambda_raw)
                lambda_trace.append(float(gate.item()))

    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)
    acc = float((all_preds == all_targets).mean())

    cm = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes))
    row_totals = cm.sum(axis=1)
    # Classes with no samples keep a per-class accuracy of 0.
    per_class = np.divide(cm.diagonal(), row_totals, out=np.zeros(num_classes), where=row_totals != 0)

    if collect_diag and weight_rows > 0:
        avg_soft_w = (weight_total / weight_rows).detach().cpu().numpy()
    else:
        avg_soft_w = None

    lam_eval = float(np.mean(lambda_trace)) if lambda_trace else None
    return acc, per_class.tolist(), cm, avg_soft_w, lam_eval
def _adapt_pairs_by_classes(cfg: Dict, num_classes: int) -> int:
# Keep ~<=10 classes per vertex group across pairs
pairs = cfg.get("num_pentachoron_pairs", 1)
target = max(1, int(math.ceil(num_classes / 10)))
return max(pairs, target)
def train_one(
    train_loader,
    test_loader,
    num_classes: int,
    cfg: dict,
    device: torch.device,
    writer: SummaryWriter,
    class_names: Optional[list] = None,
):
    """Train encoder, constellation and diagnostic head; evaluate after every epoch.

    The total loss is a weighted sum of cross-entropy, dual contrastive,
    ROSE contrastive, diagnostic-head MSE and constellation regularisation.
    After each epoch the test confusion matrix is turned into vertex weights
    that are fed to the regulariser during the next epoch.

    Args:
        train_loader: Iterable of ``(x, y)`` training batches.
        test_loader: Iterable of ``(x, y)`` batches passed to ``evaluate``.
        num_classes: Number of target classes.
        cfg: Configuration dict; copied, never mutated.
        device: Device for models and batches.
        writer: TensorBoard writer receiving ``step/*`` and ``epoch/*`` scalars.
            It is not closed here; the caller owns it.
        class_names: Tick labels for the per-epoch confusion heatmap.

    Returns:
        ``(encoder, constellation, diag_head, history, best)`` where
        ``history`` maps metric names to per-epoch lists and ``best`` holds
        the best test accuracy, its confusion matrix and its 1-based epoch.
    """
    pairs = _adapt_pairs_by_classes(cfg, num_classes)
    if pairs != cfg.get("num_pentachoron_pairs"):
        print(f"[auto] Adjusting num_pentachoron_pairs -> {pairs} for {num_classes} classes.")
    cfg_local = dict(cfg)
    cfg_local["num_pentachoron_pairs"] = pairs

    encoder = PentaFreqEncoderV2(
        input_dim=cfg_local["input_dim"],
        input_ch=cfg_local.get("input_channels", 1),
        base_dim=cfg_local["base_dim"],
        proj_dim=None,
        num_heads=cfg_local.get("num_heads", 14),
        channels=cfg_local.get("channels", 12),
    ).to(device)
    constellation = BatchedPentachoronConstellation(
        num_classes=num_classes,
        dim=cfg_local["base_dim"],
        num_pairs=cfg_local["num_pentachoron_pairs"],
        device=device,
        lambda_sep=cfg_local["lambda_separation"],
    ).to(device)
    diag_head = RoseDiagnosticHead(cfg_local["base_dim"]).to(device)

    params = list(encoder.parameters()) + list(constellation.parameters()) + list(diag_head.parameters())
    optim = torch.optim.AdamW(params, lr=cfg_local["lr"], weight_decay=cfg_local["weight_decay"])
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cfg_local["epochs"])

    w_ce = float(cfg_local.get("w_ce", 1.0))
    w_dual = float(cfg_local.get("w_dual", 1.0))
    w_rose = float(cfg_local.get("w_rose", 1.0))
    w_diag = float(cfg_local.get("w_diag", 0.1))
    w_reg = float(cfg_local.get("w_reg", cfg_local["loss_weight_scalar"]))

    history = {"train_loss": [], "train_acc": [], "test_acc": [], "ce": [], "dual": [], "rose": [], "diag": [], "reg": [], "lambda": []}
    best = {"acc": 0.0, "cm": None, "epoch": -1}
    vertex_weights = None
    global_step = 0
    component_keys = ("ce", "dual", "rose", "diag", "reg")

    for epoch in range(cfg_local["epochs"]):
        encoder.train(); constellation.train(); diag_head.train()
        sums = {"loss": 0.0, "ce": 0.0, "dual": 0.0, "rose": 0.0, "diag": 0.0, "reg": 0.0}
        correct = total = 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg_local['epochs']} [Train]")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optim.zero_grad()
            z = encoder(x)
            logits, _ = constellation(z)
            l_ce = F.cross_entropy(logits, y)
            l_dual = dual_contrastive_loss(z, y, constellation, temp=cfg_local["temp"])
            l_rose, rose_scores = rose_contrastive_loss(z, y, constellation, temp=cfg_local["temp"])
            # The diagnostic head learns to predict ROSE from detached latents,
            # so its loss never back-propagates into the encoder.
            pred_rose = diag_head(z.detach()).squeeze(-1)
            l_diag = F.mse_loss(pred_rose, rose_scores)
            l_reg = constellation.regularization_loss(vertex_weights=vertex_weights)
            loss = (w_ce*l_ce) + (w_dual*l_dual) + (w_rose*l_rose) + (w_diag*l_diag) + (w_reg*l_reg)

            # Read every scalar back to the host exactly once; the values are
            # reused by the NaN guard, the running sums, TensorBoard and tqdm.
            step = {
                "loss": loss.item(),
                "ce": l_ce.item(),
                "dual": l_dual.item(),
                "rose": l_rose.item(),
                "diag": l_diag.item(),
                "reg": l_reg.item(),
            }

            if not all(math.isfinite(v) for v in step.values()):
                # Report every component (including diag) with its raw value so
                # nan and inf can be told apart.
                print("[NaN-guard] non-finite detected. Skipping step. "
                      f"ce={step['ce']}, dual={step['dual']}, rose={step['rose']}, "
                      f"diag={step['diag']}, reg={step['reg']}")
                # Soft defuse: drop LR a bit for stability
                for g in optim.param_groups:
                    g["lr"] = max(g["lr"] * 0.5, 1e-6)
                optim.zero_grad(set_to_none=True)
                # Clamp still-finite parameters right now to kill accidental blow-ups
                with torch.no_grad():
                    for p in list(encoder.parameters()) + list(constellation.parameters()):
                        if torch.isfinite(p).all():
                            p.clamp_(-1e3, 1e3)
                continue

            loss.backward()
            # Only encoder and constellation gradients are clipped; the
            # diagnostic head is left unclipped.
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0)
            optim.step()

            bs = x.size(0)
            for key, value in step.items():
                sums[key] += value * bs
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += bs

            # TB (step)
            for key, value in step.items():
                writer.add_scalar(f"step/{key}", value, global_step)
            global_step += 1

            pbar.set_postfix({
                "loss": f"{step['loss']:.4f}",
                "acc": f"{correct/max(1,total):.4f}",
                "ce": f"{step['ce']:.3f}",
                "dual": f"{step['dual']:.3f}",
                "rose": f"{step['rose']:.3f}",
                "reg": f"{step['reg']:.3f}",
            })

        # max(1, total) keeps the averages defined when every step was skipped.
        denom = max(1, total)
        train_loss = sums["loss"] / denom
        train_acc = correct / denom
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        for key in component_keys:
            history[key].append(sums[key] / denom)

        # Eval
        test_acc, per_class_acc, cm, avg_soft_w, lam_eval = evaluate(
            encoder, constellation, test_loader, num_classes, device, collect_diag=True
        )
        history["test_acc"].append(test_acc)

        if lam_eval is None:
            # Same attribute chain that evaluate() checks; fall back to 0.5
            # when the encoder exposes no gate parameter.
            lam_param = getattr(getattr(encoder, "opinion", None), "_lambda_raw", None)
            lam_eval = float(torch.sigmoid(lam_param).item()) if lam_param is not None else 0.5
        history["lambda"].append(lam_eval)

        # Capture the LR in effect at the end of this epoch before the
        # scheduler moves on to the next epoch's value.
        epoch_lr = optim.param_groups[0]["lr"]
        lr_sched.step()

        # TB (epoch)
        writer.add_scalar("epoch/train_loss", train_loss, epoch+1)
        writer.add_scalar("epoch/train_acc", train_acc, epoch+1)
        writer.add_scalar("epoch/test_acc", test_acc, epoch+1)
        writer.add_scalar("epoch/lr", epoch_lr, epoch+1)
        writer.add_scalar("epoch/lambda", history["lambda"][-1], epoch+1)

        print(f"\n[Epoch {epoch+1}/{cfg_local['epochs']}] "
              f"TrainLoss {train_loss:.4f} | TrainAcc {train_acc:.4f} | TestAcc {test_acc:.4f} | "
              f"CE {history['ce'][-1]:.3f} Dual {history['dual'][-1]:.3f} "
              f"ROSE {history['rose'][-1]:.3f} Reg {history['reg'][-1]:.3f} λ {history['lambda'][-1]:.3f}")

        # Update reg weights for the next epoch from this epoch's confusion.
        with torch.no_grad():
            class_sim = get_class_similarity(constellation, num_classes).to(device)
            vertex_weights = vertex_weights_from_confusion(cm, class_sim, constellation.vertex_map, device)

        if test_acc > best["acc"]:
            best["acc"], best["cm"], best["epoch"] = test_acc, cm, epoch+1
            print(f" 🎯 New Best Acc: {best['acc']:.4f} at epoch {best['epoch']}")

        # Optional confusion per epoch (best effort: never aborts training).
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            os.makedirs("plots", exist_ok=True)
            fig = plt.figure(figsize=(9, 7))
            try:
                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                            xticklabels=class_names, yticklabels=class_names)
                plt.title(f'Confusion (Epoch {epoch+1}) | Acc: {test_acc:.4f}')
                plt.xlabel('Predicted'); plt.ylabel('True'); plt.tight_layout()
                plt.savefig(f'plots/confusion_epoch_{epoch+1}.png', dpi=150)
            finally:
                # Close even when plotting fails so figures do not pile up
                # over the epochs.
                plt.close(fig)
        except Exception as e:
            print(f"(Confusion heatmap skipped this epoch: {e})")

    return encoder, constellation, diag_head, history, best
# ---------------------------------------------------------------------
# Plots (local convenience; artifacts already keep TB)
# ---------------------------------------------------------------------
def plot_history(history: dict, outdir: str = "plots"):
    """Write accuracy, loss-component and λ curves as PNG files into ``outdir``.

    Produces ``accuracy.png``, ``loss_components.png`` and ``lambda.png``.
    ``history`` must contain the keys ``train_acc``, ``test_acc``,
    ``train_loss``, ``ce``, ``dual``, ``rose``, ``reg`` and ``lambda``.
    """
    os.makedirs(outdir, exist_ok=True)
    import matplotlib.pyplot as plt

    def _finish(title, ylabel, filename, with_legend=True):
        # Shared decoration + save + close for the current figure.
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        if with_legend:
            plt.legend()
        plt.grid(True, ls='--', alpha=0.4)
        plt.tight_layout()
        plt.savefig(f"{outdir}/{filename}", dpi=150)
        plt.close()

    plt.figure(figsize=(10,5))
    for key, label in (('train_acc', 'Train Acc'), ('test_acc', 'Test Acc')):
        plt.plot(history[key], label=label)
    _finish('Accuracy over Epochs', 'Accuracy', 'accuracy.png')

    plt.figure(figsize=(10,5))
    loss_series = (
        ('train_loss', 'Total'),
        ('ce', 'CE'),
        ('dual', 'DualNCE'),
        ('rose', 'ROSE'),
        ('reg', 'Reg'),
    )
    for key, label in loss_series:
        plt.plot(history[key], label=label)
    _finish('Loss Components', 'Loss', 'loss_components.png')

    plt.figure(figsize=(8,4))
    plt.plot(history['lambda'])
    _finish('λ (Geometry ↔ Attention Gate)', 'λ', 'lambda.png', with_legend=False)
def plot_confusion(cm: np.ndarray, class_names: list, outpath: str):
    """Save a confusion-matrix figure to ``outpath``.

    A seaborn annotated heatmap is tried first. If seaborn is missing or the
    heatmap path fails for any reason, a plain ``imshow`` rendering with a
    colour bar is drawn instead.

    Args:
        cm: Confusion matrix to draw.
        class_names: Tick labels for both axes (seaborn rendering only).
        outpath: Destination image path.

    Raises:
        Whatever the fallback rendering or save raises (e.g. an unwritable
        ``outpath``).
    """
    import matplotlib.pyplot as plt

    # One figure serves both rendering paths and is always closed, so a failed
    # seaborn attempt no longer leaves an orphaned open figure behind.
    fig = plt.figure(figsize=(10,8))
    try:
        try:
            import seaborn as sns
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names)
            plt.title('Best Confusion Matrix'); plt.xlabel('Predicted'); plt.ylabel('True')
            plt.tight_layout(); plt.savefig(outpath, dpi=150)
        except Exception:
            # Deliberate best-effort fallback: wipe anything half-drawn and
            # render without seaborn.
            fig.clear()
            plt.imshow(cm, aspect='auto'); plt.title('Best Confusion Matrix')
            plt.xlabel('Predicted'); plt.ylabel('True'); plt.colorbar()
            plt.tight_layout(); plt.savefig(outpath, dpi=150)
    finally:
        plt.close(fig)
def _sanitize_for_path(s: str, preserve_case: bool = True) -> str:
"""Keep letters/digits/._- ; replace others with '_'."""
out = []
for ch in (s if preserve_case else s.lower()):
if ch.isalnum() or ch in "._-":
out.append(ch)
else:
out.append("_")
return "".join(out)
def _build_hf_paths(dataset_name: str, cfg: Dict, ts: str) -> Dict[str, str]:
    """Derive the repo sub-path and the local artifact directory for one run.

    Config keys read (with defaults): ``hf_preserve_case`` (True),
    ``hf_subdir_root`` ("pentachora-adaptive-encoded"),
    ``hf_dataset_dir_template`` ("{dataset}") and ``hf_run_dir_template``
    ("{ts}_{dataset}"). Templates may use ``{dataset}``, ``{slug}``, ``{ts}``.

    Returns:
        Dict with ``dataset_token`` (sanitised dataset name),
        ``path_in_repo`` and ``local_root`` (under ``artifacts/``).
    """
    keep_case = bool(cfg.get("hf_preserve_case", True))
    display_name = dataset_name if keep_case else dataset_name.lower()
    dataset_token = _sanitize_for_path(display_name, preserve_case=keep_case)

    # Values available to both directory templates; the slug is kept as an
    # optional convenience placeholder.
    fields = {
        "dataset": dataset_token,
        "slug": _dataset_slug(dataset_name),
        "ts": ts,
    }

    root = cfg.get("hf_subdir_root", "pentachora-adaptive-encoded").strip("/")
    dataset_dir = cfg.get("hf_dataset_dir_template", "{dataset}").format(**fields)
    run_dir = cfg.get("hf_run_dir_template", "{ts}_{dataset}").format(**fields)

    return {
        "dataset_token": dataset_token,
        "path_in_repo": f"{root}/{dataset_dir}/{run_dir}".strip("/"),
        "local_root": str(Path("artifacts") / root / dataset_dir / run_dir),
    }
def save_and_push_artifacts(
    *,
    encoder: nn.Module,
    constellation: nn.Module,
    diag_head: nn.Module,
    config: Dict,
    class_names: List[str],
    history: Dict,
    best: Dict,
    tb_log_dir: Path,
    dataset_names: List[str], # pass a single dataset name here
):
    """Write all run artifacts to a local folder and upload it to the Hub.

    Artifacts: safetensors weights for the three modules, ``config.json``,
    ``history.json`` / ``history.csv``, a copy of the ``plots`` directory (if
    present), the TensorBoard logs plus a zip of them, ``manifest.json`` and
    a ``README.md``.

    Args:
        encoder, constellation, diag_head: Modules whose state dicts are saved.
        config: Run configuration; must be JSON-serialisable.
        class_names: Class labels recorded in the manifest.
        history: Per-epoch metric lists as produced by ``train_one``.
        best: Dict with ``acc`` and ``epoch`` of the best evaluation.
        tb_log_dir: TensorBoard log directory to copy (skipped if missing).
        dataset_names: List containing exactly one dataset name.

    Returns:
        ``Path`` of the local artifact directory.
    """
    assert len(dataset_names) == 1, "Pass a single dataset name to save_and_push_artifacts"
    dataset_name = dataset_names[0]
    ts = _timestamp()
    repo_id = _resolve_repo_id(config)
    _hf_login_if_needed()
    api = _ensure_repo(repo_id)
    paths = _build_hf_paths(dataset_name, config, ts)
    base_out = Path(paths["local_root"])
    base_out.mkdir(parents=True, exist_ok=True)

    # Weight file naming
    ds_token = paths["dataset_token"]
    suffix = f"_{ds_token}" if bool(config.get("hf_weight_suffix_dataset", True)) else ""
    wdir = base_out / "weights"; wdir.mkdir(parents=True, exist_ok=True)
    save_safetensors({k: v.cpu() for k, v in encoder.state_dict().items()}, str(wdir / f"encoder{suffix}.safetensors"))
    save_safetensors({k: v.cpu() for k, v in constellation.state_dict().items()}, str(wdir / f"constellation{suffix}.safetensors"))
    save_safetensors({k: v.cpu() for k, v in diag_head.state_dict().items()}, str(wdir / f"diagnostic_head{suffix}.safetensors"))

    # Config + history
    (base_out / "config.json").write_text(json.dumps(config, indent=2, sort_keys=True), encoding="utf-8")
    (base_out / "history.json").write_text(json.dumps(history, indent=2, sort_keys=True), encoding="utf-8")

    # CSV history. Every column is padded with NaN to the longest series so a
    # missing or shorter metric cannot make the DataFrame constructor fail on
    # mismatched lengths.
    csv_columns = ("train_loss", "train_acc", "test_acc", "ce", "dual", "rose", "diag", "reg", "lambda")
    max_len = max(len(history.get(col, [])) for col in csv_columns)

    def _padded(col: str) -> list:
        values = list(history.get(col, []))
        return values + [np.nan] * (max_len - len(values))

    df = pd.DataFrame({
        "epoch": list(range(1, max_len + 1)),
        **{col: _padded(col) for col in csv_columns},
    })
    df.to_csv(base_out / "history.csv", index=False)

    # Plots
    if Path("plots").exists():
        shutil.copytree("plots", base_out / "plots", dirs_exist_ok=True)

    # TensorBoard. copytree recurses, so sub-directories inside the log dir
    # are copied instead of crashing a per-file copy.
    tb_dst = base_out / "tensorboard"; tb_dst.mkdir(parents=True, exist_ok=True)
    if tb_log_dir and Path(tb_log_dir).exists():
        shutil.copytree(tb_log_dir, tb_dst, dirs_exist_ok=True)
    _zip_dir(tb_dst, base_out / "tensorboard_events.zip")

    # Manifest + README
    manifest = {
        "timestamp": ts,
        "repo_id": repo_id,
        "subdirectory": paths["path_in_repo"],
        "dataset_name": dataset_name,
        "class_names": class_names,
        "num_classes": len(class_names),
        "models": {
            "encoder": {"params": _param_count(encoder)},
            "constellation": {"params": _param_count(constellation)},
            "diagnostic_head": {"params": _param_count(diag_head)},
        },
        "results": {
            "best_test_accuracy": float(best.get("acc", 0.0)),
            "best_epoch": int(best.get("epoch", -1)),
        },
        "environment": {
            "python": sys.version,
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
            "cpu_count": psutil.cpu_count(logical=True),
            "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        },
    }
    (base_out / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")
    (base_out / "README.md").write_text(
        f"""# Pentachora Adaptive Encoded — {ts}
**Dataset:** {dataset_name}
**Contents**
- `weights/*.safetensors` — encoder, constellation, diagnostic head
- `config.json`, `manifest.json`
- `history.json` / `history.csv`
- `tensorboard/` (and `tensorboard_events.zip`)
- `plots/` — accuracy, loss, λ, confusion
""",
        encoding="utf-8"
    )

    # Push
    print(f"[push] Uploading to hf://{repo_id}/{paths['path_in_repo']}")
    api.upload_folder(
        repo_id=repo_id,
        folder_path=str(base_out),
        path_in_repo=paths["path_in_repo"],
        repo_type="model",
        commit_message=f"[{dataset_name}] {ts} | best_acc={manifest['results']['best_test_accuracy']:.4f}",
    )
    print("[push] ✅ Upload complete.")
    return base_out
# ---------------------------------------------------------------------
# Dataset sweep
# ---------------------------------------------------------------------
def run_one_dataset(name: str) -> Dict:
    """Train on one dataset, plot, push artifacts and return a result summary.

    Uses the module-level ``config`` and ``device``. A per-run copy of the
    config is filled with the dataset's class count, input dimension and
    channel count before training.

    Args:
        name: Dataset name understood by ``get_dataset_single``.

    Returns:
        Dict with dataset name, class/channel counts, image size, best
        accuracy and epoch, parameter counts and elapsed minutes.
    """
    print("\n" + "="*60)
    print(f"RUN: {name}")
    print("="*60)

    # Load
    train_loader, test_loader, ncls, class_names, in_dim, out_c = get_dataset_single(
        name, batch_size=config["batch_size"], num_workers=2
    )
    cfg_local = dict(config)
    cfg_local["num_classes"] = ncls
    cfg_local["input_dim"] = in_dim
    cfg_local["input_channels"] = out_c

    # TB writer per dataset
    ts = _timestamp()
    tb_dir = Path("tb_logs") / f"{_dataset_slug(name)}" / ts
    tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tb_dir))

    start = time.time()
    try:
        encoder, constellation, diag_head, history, best = train_one(
            train_loader, test_loader, ncls, cfg_local, device, writer, class_names
        )
    finally:
        # Flush and release the event file before it is copied and uploaded;
        # also prevents one open writer per dataset during a sweep.
        writer.close()
    elapsed_min = (time.time() - start) / 60.0

    # Plots
    plot_history(history, outdir="plots")
    if best["cm"] is not None:
        plot_confusion(best["cm"], class_names, outpath=f"plots/best_confusion_{_dataset_slug(name)}_epoch_{best['epoch']}.png")

    # Push artifacts
    save_and_push_artifacts(
        encoder=encoder,
        constellation=constellation,
        diag_head=diag_head,
        config=cfg_local,
        class_names=class_names,
        history=history,
        best=best,
        tb_log_dir=tb_dir,
        dataset_names=[name],
    )

    result = {
        "dataset": name,
        "classes": ncls,
        "channels": out_c,
        "img_size": cfg_local.get("img_size", 28),
        "best_acc": float(best["acc"]),
        "best_epoch": int(best["epoch"]),
        "params_encoder": _param_count(encoder),
        "params_constellation": _param_count(constellation),
        "elapsed_min": round(elapsed_min, 3),
    }
    print(f"[done] {name} -> best_acc={result['best_acc']:.4f} @ epoch {result['best_epoch']} time={result['elapsed_min']:.2f}m")
    return result
def run_sweep(datasets: List[str]) -> Dict:
    """Run ``run_one_dataset`` for every name, then save and upload a summary.

    A failing dataset does not stop the sweep: the error is printed and
    recorded, and the loop moves on. The summary (``results.csv``,
    ``results.json``, ``failures.json``) is written under
    ``sweeps/<timestamp>`` and uploaded to ``<hf_subdir_root>/_sweep/<timestamp>``.

    Args:
        datasets: Dataset names to run, in order.

    Returns:
        Dict with ``timestamp``, ``results``, ``failures`` and ``path_in_repo``.
    """
    os.makedirs("sweeps", exist_ok=True)
    results = []
    failures = []
    for name in datasets:
        try:
            results.append(run_one_dataset(name))
        except Exception as e:
            # Deliberate best effort: record the failure and keep sweeping.
            # The exception type is stored too, because str(e) can be empty.
            print(f"[fail] {name}: {e}")
            failures.append({"dataset": name, "error": str(e), "error_type": type(e).__name__})

    # Save local sweep summary
    ts = _timestamp()
    sweep_dir = Path("sweeps") / ts
    sweep_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(results)
    df.to_csv(sweep_dir / "results.csv", index=False)
    (sweep_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    (sweep_dir / "failures.json").write_text(json.dumps(failures, indent=2), encoding="utf-8")

    # Push sweep summary
    repo_id = _resolve_repo_id(config)
    _hf_login_if_needed()
    api = _ensure_repo(repo_id)
    # Use the same configurable root as the per-dataset uploads so the summary
    # lands next to the runs it describes.
    root = config.get("hf_subdir_root", "pentachora-adaptive-encoded").strip("/")
    path_in_repo = f"{root}/_sweep/{ts}".strip("/")
    print(f"[push] Uploading sweep summary to hf://{repo_id}/{path_in_repo}")
    api.upload_folder(repo_id=repo_id, folder_path=str(sweep_dir), path_in_repo=path_in_repo, repo_type="model")
    print("[push] ✅ Sweep summary uploaded.")
    return {"timestamp": ts, "results": results, "failures": failures, "path_in_repo": path_in_repo}
# ---------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------
def main():
    """Entry point: print the configuration, then run a sweep or a single dataset.

    Sweep mode is selected by ``config["sweep_all"]`` or the environment
    variable ``RUN_SWEEP=1``. Otherwise the dataset named by
    ``config["dataset"]`` (default "FashionMNIST") is trained, plotted and
    pushed. In single-dataset mode the module-level ``config`` is updated in
    place with the dataset's class count, input dimension and channel count.
    """
    print("\n" + "="*60)
    print("PENTACHORON CONSTELLATION FINAL CONFIGURATION")
    print("="*60)
    for k, v in config.items():
        print(f"{k:24}: {v}")
    if config["lr"] > 1e-1:
        print(f"⚠️ High LR detected ({config['lr']}). If unstable, try 5e-3 to 5e-2.")

    # Sweep mode?
    if bool(config.get("sweep_all", False)) or os.getenv("RUN_SWEEP", "0") == "1":
        # Try all TorchVision + MedMNIST (skip those not available)
        datasets_all = list(TORCHVISION_DATASETS)
        if medmnist is not None:
            datasets_all += MEDMNIST_DATASETS
        out = run_sweep(datasets_all)
        print(f"\nSweep complete. Summary path: {out['path_in_repo']}")
        return

    # Single dataset default (edit here as desired)
    DATASET = config.get("dataset", "FashionMNIST")
    train_loader, test_loader, ncls, class_names, in_dim, out_c = get_dataset_single(
        DATASET, batch_size=config["batch_size"], num_workers=2
    )
    config["num_classes"] = ncls
    config["input_dim"] = in_dim
    config["input_channels"] = out_c

    tb_dir = Path("tb_logs") / f"{_dataset_slug(DATASET)}" / _timestamp()
    tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tb_dir))

    start = time.time()
    try:
        encoder, constellation, diag_head, history, best = train_one(
            train_loader, test_loader, ncls, config, device, writer, class_names
        )
    finally:
        # Flush and release the event file before the log directory is copied
        # into the artifacts and uploaded.
        writer.close()
    elapsed = (time.time() - start) / 60.0

    # Plots
    plot_history(history, outdir="plots")
    if best["cm"] is not None:
        plot_confusion(best["cm"], class_names, outpath=f"plots/best_confusion_epoch_{best['epoch']}.png")

    print("\n" + "="*60)
    print("TRAINING COMPLETE")
    print("="*60)
    print(f"Best Test Accuracy : {best['acc']*100:.2f}% @ epoch {best['epoch']}")
    print(f"Total Training Time: {elapsed:.2f} minutes")

    save_and_push_artifacts(
        encoder=encoder,
        constellation=constellation,
        diag_head=diag_head,
        config=config,
        class_names=class_names,
        history=history,
        best=best,
        tb_log_dir=tb_dir,
        dataset_names=[DATASET],
    )
    print("[done] Artifacts uploaded and saved locally.")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()