|
|
""" |
|
|
Pentachoron Constellation — Multi-Channel, HF Push, and Dataset Sweep |
|
|
MIT |
|
|
Author: AbstractPhil |
|
|
Quartermaster: Mirel (GPT-5 Thinking) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os, sys, json, math, time, random, shutil, zipfile, platform |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm import tqdm |
|
|
from sklearn.metrics import confusion_matrix |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global run configuration. Key order is preserved from the original so any
# code iterating config.items() sees the same sequence.
config: Dict = {
    # --- encoder geometry ---
    "input_dim": 28*28,          # flattened pixels per image (C*H*W)
    "input_channels": "auto",    # infer channel count from the dataset
    "base_dim": 56,              # per-vertex embedding width
    "proj_dim": None,            # optional projection head width (None = off)

    # --- constellation ---
    "num_classes": 10,
    "num_pentachoron_pairs": 2,
    "lambda_separation": 0.391,

    # --- attention / extractor ---
    "num_heads": 2,
    "channels": 24,

    # --- optimization ---
    "batch_size": 1024,
    "epochs": 20,
    "lr": 1e-2,
    "weight_decay": 1e-5,
    "temp": 0.7,

    # --- loss term weights ---
    "w_ce": 1.0,
    "w_dual": 1.0,
    "w_rose": 1.0,
    "w_diag": 0.1,
    "w_reg": 0.1,
    "loss_weight_scalar": 0.1,

    # --- data pipeline ---
    "img_size": 28,
    "img_channels": "auto",
    "normalize": True,
    "per_dataset_norm": True,
    "augment": False,

    # --- run control ---
    "sweep_all": False,
    "seed": 420,

    # --- Hugging Face / dataset selection ---
    "hf_repo_id": "AbstractPhil/pentachora-multi-channel-frequency-encoded",
    "dataset": "QMNIST",
}
|
|
|
|
|
|
|
|
config.setdefault("hf_subdir_root", "") |
|
|
config.setdefault("hf_dataset_dir_template", "{dataset}") |
|
|
config.setdefault("hf_run_dir_template", "{ts}_{dataset}") |
|
|
config.setdefault("hf_weight_suffix_dataset", True) |
|
|
config.setdefault("hf_preserve_case", True) |
|
|
|
|
|
|
|
|
config.setdefault("deterministic", True) |
|
|
config.setdefault("strict_determinism", False) |
|
|
config.setdefault("deterministic_cublas", False) |
|
|
config.setdefault("seed_per_dataset", False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Anomaly detection off for speed; enable manually when debugging autograd.
torch.autograd.set_detect_anomaly(False)

# cuDNN: reproducible kernels when deterministic, autotuned kernels otherwise.
if bool(config.get("deterministic", True)):
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
else:
    torch.backends.cudnn.benchmark = True

# Disable TF32 so float32 matmuls are bit-stable regardless of GPU generation.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

# cuBLAS workspace pinning for strict determinism; must be set before the
# first CUDA context is created for it to take effect.
if bool(config.get("deterministic_cublas", False)):
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
print("\n" + "="*60) |
|
|
print("PENTACHORON CONSTELLATION CONFIGURATION") |
|
|
print("="*60) |
|
|
for k, v in config.items(): |
|
|
print(f"{k:24}: {v}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def seed_everything(seed: int = 42,
                    deterministic: bool | None = None,
                    strict: bool | None = None):
    """Seed Python, NumPy, Torch (CPU+CUDA), and set hash seed/env flags."""
    # Fall back to the global config for flags not passed explicitly.
    if deterministic is None:
        deterministic = bool(config.get("deterministic", True))
    if strict is None:
        strict = bool(config.get("strict_determinism", False))

    # NOTE(review): `deterministic` is resolved above but never consumed here —
    # the cuDNN determinism flags are set at module import instead. Confirm
    # this split is intentional.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch
        # Strict mode makes torch raise on ops without deterministic kernels.
        torch.use_deterministic_algorithms(strict)
    except Exception:
        pass  # best-effort: older torch builds lack this API

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def make_torch_generator(seed: int) -> torch.Generator:
    """Return a fresh CPU ``torch.Generator`` seeded with *seed*."""
    # manual_seed returns the generator itself, so this can be a one-liner.
    return torch.Generator().manual_seed(seed)
|
|
|
|
|
def seed_worker(worker_id: int):
    """Seed DataLoader worker; uses PyTorch's initial_seed to derive unique stream."""
    # Each worker gets a distinct torch initial_seed; fold it into 32 bits so
    # NumPy and random accept it, giving the worker its own reproducible stream.
    derived = torch.initial_seed() % (2 ** 32)
    np.random.seed(derived)
    random.seed(derived)
|
|
|
|
|
|
|
|
seed_everything(int(config.get("seed", 42))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure(pkg, pip_name=None): |
|
|
pip_name = pip_name or pkg |
|
|
try: |
|
|
__import__(pkg) |
|
|
except Exception: |
|
|
print(f"[setup] Installing {pip_name} ...") |
|
|
os.system(f"{sys.executable} -m pip install -q {pip_name}") |
|
|
|
|
|
_ensure("safetensors") |
|
|
_ensure("huggingface_hub") |
|
|
_ensure("pandas") |
|
|
_ensure("psutil") |
|
|
_ensure("medmnist") |
|
|
|
|
|
from safetensors.torch import save_file as save_safetensors |
|
|
from huggingface_hub import HfApi, create_repo, whoami, login |
|
|
import pandas as pd |
|
|
import psutil |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _timestamp() -> str: |
|
|
return datetime.now().strftime("%Y%m%d-%H%M%S") |
|
|
|
|
|
def _param_count(m: nn.Module) -> int: |
|
|
return sum(p.numel() for p in m.parameters()) |
|
|
|
|
|
def _resolve_repo_id(cfg: Dict) -> str: |
|
|
rid = os.getenv("PENTACHORA_HF_REPO") or cfg.get("hf_repo_id") |
|
|
if not rid: |
|
|
raise RuntimeError("Set config['hf_repo_id'] or export PENTACHORA_HF_REPO.") |
|
|
return rid |
|
|
|
|
|
def _hf_login_if_needed():
    """Log in to the HF Hub via HF_TOKEN unless a session already exists."""
    try:
        whoami()          # raises when there is no cached credential
        return
    except Exception:
        pass
    token = os.getenv("HF_TOKEN")
    if token:
        login(token=token, add_to_git_credential=True)
    else:
        print("[huggingface] No login found and HF_TOKEN not set. Push may fail; run `huggingface-cli login`.")
|
|
|
|
|
def _ensure_repo(repo_id: str) -> HfApi:
    """Create the model repo if absent and return an ``HfApi`` client."""
    client = HfApi()
    # exist_ok makes this idempotent across repeated runs.
    create_repo(repo_id=repo_id, private=False, exist_ok=True, repo_type="model")
    return client
|
|
|
|
|
def _zip_dir(src: Path, dst_zip: Path): |
|
|
with zipfile.ZipFile(dst_zip, "w", zipfile.ZIP_DEFLATED) as z: |
|
|
for p in src.rglob("*"): |
|
|
z.write(p, arcname=p.relative_to(src)) |
|
|
|
|
|
def _dataset_slug(name_or_names) -> str: |
|
|
if isinstance(name_or_names, (list, tuple)): |
|
|
return "+".join(n.strip().lower() for n in name_or_names) |
|
|
return str(name_or_names).strip().lower() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MedMNIST is optional: fall back to None sentinels so torchvision-only runs
# still work; MedMNIST code paths check these before use.
try:
    import medmnist
    from medmnist import INFO as MED_INFO
except Exception:
    medmnist = None
    MED_INFO = None
|
|
|
|
|
_TORCHVISION_KEYS = { |
|
|
"mnist": "MNIST", |
|
|
"fashionmnist": "FashionMNIST", |
|
|
"kmnist": "KMNIST", |
|
|
"emnist": "EMNIST", |
|
|
"qmnist": "QMNIST", |
|
|
"usps": "USPS", |
|
|
} |
|
|
_MEDMNIST_MAP = { |
|
|
"bloodmnist": "bloodmnist", "pathmnist": "pathmnist", "octmnist": "octmnist", |
|
|
"pneumoniamnist": "pneumoniamnist", "dermamnist": "dermamnist", "retinamnist": "retinamnist", |
|
|
"breastmnist": "breastmnist", "organamnist": "organamnist", "organcmnist": "organcmnist", |
|
|
"organsmnist": "organsmnist", "tissuemnist": "tissuemnist", |
|
|
} |
|
|
|
|
|
_DATASET_STATS_1CH = { |
|
|
"MNIST": ([0.1307], [0.3081]), |
|
|
"FashionMNIST": ([0.2860], [0.3530]), |
|
|
"KMNIST": ([0.1918], [0.3483]), |
|
|
"EMNIST": ([0.1307], [0.3081]), |
|
|
"QMNIST": ([0.1307], [0.3081]), |
|
|
"USPS": ([0.5000], [0.5000]), |
|
|
} |
|
|
_MEAN1, _STD1 = [0.5], [0.5] |
|
|
_MEAN3, _STD3 = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] |
|
|
|
|
|
def _norm_stats(name: str, channels: int) -> Tuple[List[float], List[float]]:
    """(mean, std) for Normalize: dataset-specific for 1ch, generic otherwise."""
    if channels != 1:
        return _MEAN3, _STD3
    return _DATASET_STATS_1CH.get(name, (_MEAN1, _STD1))
|
|
|
|
|
def _to_channels(target_c: int):
    """Build a transform forcing a CHW tensor to *target_c* channels.

    3->1 uses the ITU-R BT.601 luma weights; 1->3 repeats the plane; any
    other mismatch is resolved by truncating to the first *target_c* planes.
    """
    def _convert(img: torch.Tensor) -> torch.Tensor:
        have = img.shape[0]
        if have == target_c:
            return img
        if target_c == 1:
            if have == 3:
                luma = 0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]
                return luma.unsqueeze(0)
            return img[:1]
        if target_c == 3 and have == 1:
            return img.repeat(3, 1, 1)
        return img[:target_c]
    return transforms.Lambda(_convert)
|
|
|
|
|
def _augmentations_for(name: str, size: int, channels: int) -> List:
    """Train-time augmentation list for dataset *name* at *size* px, *channels* ch.

    Returns an empty list unless ``config['augment']`` is truthy. Handwriting
    style sets (MNIST/KMNIST/EMNIST/QMNIST/USPS) get a mild affine jitter;
    everything else gets crop+affine, plus color jitter for 3-channel *MNIST
    variants. Crops are only added at sizes >= 32 px.

    Fix: the return annotation previously referenced ``transforms.Transform``,
    which does not exist in torchvision.transforms (v1); any runtime
    evaluation of this signature's hints (``typing.get_type_hints``, doc
    tooling) would raise AttributeError.
    """
    aug: List = []
    if not bool(config.get("augment", False)):
        return aug
    if name.upper() in {"MNIST","KMNIST","EMNIST","QMNIST","USPS"}:
        # Digits/letters: gentle geometry only — no color or heavy warps.
        aug += [transforms.RandomAffine(degrees=8, translate=(0.05, 0.05), scale=(0.95, 1.05))]
        if size >= 32:
            pad = max(1, int(0.03 * size))
            aug += [transforms.RandomCrop(size, padding=pad)]
        return aug
    if size >= 32:
        pad = max(1, int(0.03 * size))
        aug += [transforms.RandomCrop(size, padding=pad)]
    aug += [transforms.RandomAffine(degrees=10, translate=(0.05, 0.05), scale=(0.95, 1.05))]
    if channels == 3 and name.lower().endswith("mnist"):
        # Mild photometric jitter for RGB MedMNIST-style sets.
        aug += [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.02)]
    return aug
|
|
|
|
|
def _build_transforms(dataset_name: str, split: str, native_c: int, target_c: int|str, size: int, normalize: bool, per_dataset_norm: bool) -> transforms.Compose:
    """Assemble the torchvision pipeline for one dataset/split.

    Order: (train-only PIL augmentations) -> resize -> ToTensor -> channel
    coercion -> Normalize -> flatten to 1-D.
    """
    steps: List = []
    if size != 28:
        steps.append(transforms.Resize((size, size)))
    steps.append(transforms.ToTensor())
    out_c = native_c if target_c == "auto" else int(target_c)
    if target_c != "auto":
        steps.append(_to_channels(out_c))
    if split == "train":
        # Augmentations are prepended: they operate on the raw (PIL) image.
        steps = _augmentations_for(dataset_name, size, out_c) + steps
    if normalize:
        if per_dataset_norm:
            mean, std = _norm_stats(dataset_name, out_c)
        else:
            mean, std = (_MEAN1, _STD1) if out_c == 1 else (_MEAN3, _STD3)
        steps.append(transforms.Normalize(mean=mean, std=std))
    steps.append(transforms.Lambda(lambda x: x.view(-1)))
    return transforms.Compose(steps)
|
|
|
|
|
def collate_as_int(batch):
    """Collate (x, y) pairs, coercing heterogeneous labels to a LongTensor.

    Accepts plain ints, scalar / single-element tensors, and one-hot style
    arrays (reduced via argmax), so every dataset yields integer class ids.
    """
    def _as_int(label) -> int:
        # One coercion rule per label flavor, in original precedence order.
        if isinstance(label, (int, np.integer)):
            return int(label)
        if torch.is_tensor(label):
            if label.ndim == 0 or (label.ndim == 1 and label.numel() == 1):
                return int(label.item())
            return int(label.argmax().item())
        arr = np.asarray(label)
        if arr.ndim == 0 or (arr.ndim == 1 and arr.size == 1):
            return int(arr.item())
        return int(arr.argmax())

    inputs, labels = zip(*batch)
    stacked = torch.stack(inputs, dim=0)
    return stacked, torch.tensor([_as_int(y) for y in labels], dtype=torch.long)
|
|
|
|
|
def _get_med_info(flag: str) -> dict:
    """Look up the MedMNIST INFO record for *flag*; raise if lib/key missing."""
    if MED_INFO is None:
        raise ImportError("medmnist is not installed. `pip install medmnist`")
    try:
        return MED_INFO[flag]
    except KeyError:
        raise KeyError(f"Unknown MedMNIST flag: {flag}") from None
|
|
|
|
|
def _med_class_names(info: dict) -> List[str]: |
|
|
lab = info["label"] |
|
|
return [lab[str(i)] for i in range(len(lab))] |
|
|
|
|
|
def load_single_dataset(name: str, split: str,
                        cfg: Optional[Dict]=None,
                        resolved_target_channels: Optional[int|str]=None
                        ) -> Tuple[torch.utils.data.Dataset, int, List[str], int, int]:
    """
    Return: dataset, num_classes, class_names, input_dim (C*H*W), output_channels
    """
    cfg = cfg or config
    name_key = name.strip()
    name_lower = name_key.lower()

    # Pipeline knobs; an explicit channel override beats the config value.
    size = int(cfg.get("img_size", 28))
    want_c = cfg.get("img_channels", "auto") if resolved_target_channels is None else resolved_target_channels
    normalize = bool(cfg.get("normalize", True))
    per_dataset_norm = bool(cfg.get("per_dataset_norm", True))

    # --- TorchVision family (all natively 1-channel) ---
    if name_lower in _TORCHVISION_KEYS:
        canonical = _TORCHVISION_KEYS[name_lower]
        native_c = 1
        transform = _build_transforms(canonical, split, native_c, want_c, size, normalize, per_dataset_norm)

        if canonical == "MNIST":
            ds = datasets.MNIST("./data", train=(split=="train"), transform=transform, download=True)
            ncls = 10; cls_names = [f"digit-{i}" for i in range(10)]
        elif canonical == "FashionMNIST":
            base = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
            ds = datasets.FashionMNIST("./data", train=(split=="train"), transform=transform, download=True)
            ncls = 10; cls_names = [f"fashion-{n}" for n in base]
        elif canonical == "KMNIST":
            ds = datasets.KMNIST("./data", train=(split=="train"), transform=transform, download=True)
            ncls = 10; cls_names = [f"kmnist-{c}" for c in ['お','き','す','つ','な','は','ま','や','れ','を']]
        elif canonical == "EMNIST":
            # 'balanced' split: 47 merged classes (digits + upper + some lower).
            ds = datasets.EMNIST("./data", split='balanced', train=(split=="train"), transform=transform, download=True)
            letters = ['0','1','2','3','4','5','6','7','8','9',
                       'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z',
                       'a','b','d','e','f','g','h','n','q','r','t']
            ncls = 47; cls_names = [f"emnist-{c}" for c in letters]
        elif canonical == "QMNIST":
            # QMNIST uses `what=` instead of `train=` for split selection.
            ds = datasets.QMNIST("./data", what=('train' if split=="train" else 'test'), transform=transform, download=True)
            ncls = 10; cls_names = [f"qmnist-{i}" for i in range(10)]
        elif canonical == "USPS":
            ds = datasets.USPS("./data", train=(split=="train"), transform=transform, download=True)
            ncls = 10; cls_names = [f"usps-{i}" for i in range(10)]
        else:
            raise ValueError(f"Unhandled TorchVision dataset: {canonical}")

        out_c = native_c if want_c == "auto" else int(want_c)
        input_dim = out_c * size * size
        return ds, ncls, cls_names, input_dim, out_c

    # --- MedMNIST family (native channel count comes from its INFO table) ---
    if name_lower in _MEDMNIST_MAP:
        if medmnist is None:
            raise ImportError("medmnist not available. `pip install medmnist`")
        flag = _MEDMNIST_MAP[name_lower]
        info = _get_med_info(flag)
        DataClass = getattr(medmnist, info["python_class"])
        native_c = int(info.get("n_channels", 1))
        out_c = native_c if want_c == "auto" else int(want_c)

        # NOTE(review): unlike the torchvision branch, augmentations here run
        # AFTER ToTensor (i.e. on tensors); this requires a torchvision
        # version whose affine/crop transforms accept tensors — confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            _to_channels(out_c) if want_c != "auto" else transforms.Lambda(lambda t: t),
            *(_augmentations_for(name_key, size, out_c) if (split=="train") else []),
            transforms.Resize((size, size)) if size != 28 else transforms.Lambda(lambda t: t),
            transforms.Normalize(*(_norm_stats(name_key, out_c) if (normalize and per_dataset_norm) else ((_MEAN1,_STD1) if out_c==1 else (_MEAN3,_STD3)))) if normalize else transforms.Lambda(lambda t: t),
            transforms.Lambda(lambda t: t.view(-1)),
        ])
        ds = DataClass(split=('train' if split=="train" else 'test'), transform=transform, download=True, size=size)
        ncls = len(info["label"]); cls_names = _med_class_names(info)
        input_dim = out_c * size * size
        return ds, ncls, cls_names, input_dim, out_c

    raise ValueError(f"Unknown dataset name: {name}")
|
|
|
|
|
def get_dataset_single(name: str, batch_size: int, num_workers: int = 2):
    """
    Load a single dataset honoring config overrides.
    Returns: train_loader, test_loader, num_classes, class_names, input_dim, channels
    """
    tr, ntr, names_tr, in_tr, out_c = load_single_dataset(name, "train", config)
    te, nte, names_te, in_te, out_c2 = load_single_dataset(name, "test", config)
    # Train and test must agree on label space and input geometry.
    assert ntr == nte and in_tr == in_te and out_c == out_c2
    # One seeded generator (shared by both loaders) keeps shuffling
    # reproducible per run; the test loader does not shuffle.
    g = make_torch_generator(int(config.get("seed", 42)))
    train_loader = DataLoader(
        tr, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        pin_memory=torch.cuda.is_available(), collate_fn=collate_as_int,
        worker_init_fn=seed_worker, generator=g, persistent_workers=False
    )
    test_loader = DataLoader(
        te, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=torch.cuda.is_available(), collate_fn=collate_as_int,
        worker_init_fn=seed_worker, generator=g, persistent_workers=False
    )

    return train_loader, test_loader, ntr, names_tr, in_tr, out_c
|
|
|
|
|
|
|
|
TORCHVISION_DATASETS = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST", "USPS"] |
|
|
MEDMNIST_DATASETS = ["BloodMNIST","PathMNIST","OCTMNIST","PneumoniaMNIST","DermaMNIST", |
|
|
"RetinaMNIST","BreastMNIST","OrganAMNIST","OrganCMNIST","OrganSMNIST","TissueMNIST"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PentaFreqExtractor(nn.Module):
    """
    Multi-channel spectral extractor:
      - Input: [B, C*H*W], unflattened to [B, C, H, W]
      - Five parallel conv stacks ("frequency bands"), each pooled to 7x7 and
        linearly encoded to base_dim
      - Output: vertex stack [B, 5, base_dim] plus a fixed K5 adjacency
    """
    def __init__(self, input_dim: int = 784, input_ch: int = 1, base_dim: int = 64, channels: int = 12):
        super().__init__()
        self.input_dim = input_dim
        self.input_ch = int(input_ch)
        # Recover the square spatial side from C*H*W; reject non-square input.
        side_f = (input_dim / max(1, self.input_ch)) ** 0.5
        side = int(side_f)
        assert side * side * self.input_ch == input_dim, f"input_dim ({input_dim}) != C*H*W with H=W; C={self.input_ch}, side≈{side_f:.3f}"
        self.unflatten = nn.Unflatten(1, (self.input_ch, side, side))
        self.base_dim = base_dim

        # Band 0 — ultra-high frequency: full-res 3x3 convs (second depthwise).
        self.v0_ultrahigh = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels),
            nn.BatchNorm2d(channels), nn.ReLU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten(),
        )
        self.v0_encode = nn.Linear(channels * 49, base_dim)

        # Band 1 — high frequency: one 2x max-pool, tanh non-linearity.
        self.v1_high = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.Tanh(),
            nn.MaxPool2d(2),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.Tanh(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten(),
        )
        self.v1_encode = nn.Linear(channels * 49, base_dim)

        # Band 2 — mid frequency: strided 5x5 entry conv, GELU.
        self.v2_mid = nn.Sequential(
            nn.Conv2d(self.input_ch, channels, 5, padding=2, stride=2),
            nn.BatchNorm2d(channels), nn.GELU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels), nn.GELU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten(),
        )
        self.v2_encode = nn.Linear(channels * 49, base_dim)

        # Band 3 — low-mid frequency: 2x avg-pool, then a wide 7x7 conv.
        self.v3_lowmid = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(self.input_ch, channels, 7, padding=3),
            nn.BatchNorm2d(channels), nn.SiLU(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten(),
        )
        self.v3_encode = nn.Linear(channels * 49, base_dim)

        # Band 4 — low frequency: aggressive 4x avg-pool, sigmoid activation.
        self.v4_low = nn.Sequential(
            nn.AvgPool2d(4),
            nn.Conv2d(self.input_ch, channels, 7, padding=3),
            nn.BatchNorm2d(channels), nn.Sigmoid(),
            nn.AdaptiveAvgPool2d(7), nn.Flatten(),
        )
        self.v4_encode = nn.Linear(channels * 49, base_dim)

        # Complete graph on 5 vertices (K5) without self-loops.
        self.register_buffer("adjacency_matrix", torch.ones(5, 5) - torch.eye(5))
        self._init_edge_kernels(channels)

    @torch.no_grad()
    def _init_edge_kernels(self, channels: int):
        """Seed the first five band-0 filters (input plane 0) with classic
        edge-detector kernels; remaining filters keep their default init."""
        if channels < 5: return
        conv0 = self.v0_ultrahigh[0]
        if not isinstance(conv0, nn.Conv2d): return
        if conv0.weight.shape[1] >= 1:
            k = conv0.weight
            k[0,0] = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=k.dtype)/4   # Sobel x
            k[1,0] = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=k.dtype)/4   # Sobel y
            k[2,0] = torch.tensor([[0,-1,0],[-1,4,-1],[0,-1,0]], dtype=k.dtype)/4  # Laplacian
            k[3,0] = torch.tensor([[1,0,0],[0,-1,0],[0,0,0]], dtype=k.dtype)/2     # diagonal difference
            k[4,0] = torch.tensor([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=k.dtype)/3   # Prewitt x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map flat input [B, C*H*W] to ([B, 5, base_dim], [5, 5] adjacency)."""
        grid = self.unflatten(x)
        bands = [
            self.v0_encode(self.v0_ultrahigh(grid)),
            self.v1_encode(self.v1_high(grid)),
            self.v2_encode(self.v2_mid(grid)),
            self.v3_encode(self.v3_lowmid(grid)),
            self.v4_encode(self.v4_low(grid)),
        ]
        return torch.stack(bands, dim=1), self.adjacency_matrix
|
|
|
|
|
class PentachoronCrossAttention(nn.Module):
    """Per-vertex cross-attention over the pentachoron, masked by adjacency.

    Each vertex queries all five vertices but may only attend where its
    adjacency row is non-zero (with K5-minus-diagonal: not to itself).
    """
    def __init__(self, dim: int, num_heads: int = 14, dropout: float = 0.0):
        super().__init__()
        # NOTE: dim must be divisible by num_heads (nn.MultiheadAttention).
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, dropout=dropout, batch_first=True)

    def _row_to_attn_mask(self, row: torch.Tensor) -> torch.Tensor:
        """Turn one adjacency row into a float mask: 0.0 keeps, -inf blocks."""
        blocked = (row == 0).unsqueeze(0)
        mask = torch.zeros(1, row.numel(), device=row.device, dtype=torch.float32)
        mask[blocked] = float("-inf")
        return mask

    def forward(self, vertices: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        B, V, D = vertices.shape
        refined = []
        for vid in range(V):
            query = vertices[:, vid:vid+1, :]
            row_mask = self._row_to_attn_mask(adjacency[vid].to(vertices.device))
            attended, _ = self.attn(query, vertices, vertices, attn_mask=row_mask, need_weights=False)
            refined.append(attended)
        return torch.cat(refined, dim=1)
|
|
|
|
|
class PentachoronOpinionFusion(nn.Module):
    """Blend a geometric (softmax-weighted) view with a cross-attention view
    of the five vertices, then fuse them into one L2-normalized latent."""

    def __init__(self, base_dim: int = 64, proj_dim: Optional[int] = None, num_heads: int = 14, p_dropout: float = 0.2):
        super().__init__()
        self.cross = PentachoronCrossAttention(dim=base_dim, num_heads=num_heads)
        # 5*base_dim -> base_dim funnel with BN/ReLU and one dropout stage.
        self.fusion = nn.Sequential(
            nn.Linear(base_dim * 5, base_dim * 3),
            nn.BatchNorm1d(base_dim * 3), nn.ReLU(), nn.Dropout(p_dropout),
            nn.Linear(base_dim * 3, base_dim * 2),
            nn.BatchNorm1d(base_dim * 2), nn.ReLU(),
            nn.Linear(base_dim * 2, base_dim),
        )
        self.projection = None if proj_dim is None else nn.Linear(base_dim, proj_dim, bias=False)
        # Learned blend factor between the two views; sigmoid(0) = 0.5 at init.
        self._lambda_raw = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _softmax_geometry(vertices: torch.Tensor, adjacency: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Weight each vertex by the softmax of its summed adjacent cosine sims."""
        unit = F.normalize(vertices, dim=2, eps=1e-8)
        cos = torch.bmm(unit, unit.transpose(1, 2))
        edge_strengths = cos * adjacency.to(vertices.dtype).unsqueeze(0)
        weights = F.softmax(edge_strengths.sum(dim=2), dim=1)
        return vertices * weights.unsqueeze(2), weights

    def forward(self, vertices: torch.Tensor, adjacency: torch.Tensor, return_diag: bool = False):
        geo_view, weights = self._softmax_geometry(vertices, adjacency)
        attn_view = self.cross(vertices, adjacency)
        lam = torch.sigmoid(self._lambda_raw)
        blended = lam * geo_view + (1.0 - lam) * attn_view
        fused = self.fusion(blended.flatten(1))
        if self.projection is not None:
            fused = self.projection(fused)
        z = F.normalize(fused, dim=1)
        if return_diag:
            return z, {"lambda": lam.detach(), "softmax_weights": weights.detach()}
        return z, None
|
|
|
|
|
class PentaFreqEncoderV2(nn.Module):
    """Extractor + opinion fusion: flat image in, normalized latent out."""

    def __init__(self, input_dim: int = 784, input_ch: int = 1, base_dim: int = 64, proj_dim: Optional[int] = None, num_heads: int = 14, channels: int = 12):
        super().__init__()
        self.extractor = PentaFreqExtractor(input_dim=input_dim, input_ch=input_ch, base_dim=base_dim, channels=channels)
        self.opinion = PentachoronOpinionFusion(base_dim=base_dim, proj_dim=proj_dim, num_heads=num_heads)

    @torch.no_grad()
    def get_frequency_contributions(self, x: torch.Tensor) -> torch.Tensor:
        """Per-band softmax weights [B, 5] for inspection (no grad)."""
        verts, adj = self.extractor(x)
        _, weights = self.opinion._softmax_geometry(verts, adj)
        return weights

    def forward(self, x: torch.Tensor, return_diag: bool = False):
        verts, adj = self.extractor(x)
        z, diag = self.opinion(verts, adj, return_diag)
        if return_diag:
            return z, diag
        return z
|
|
|
|
|
class BatchedPentachoronConstellation(nn.Module):
    """Paired dispatcher/specialist pentachora routing latents to class scores.

    Each of `num_pairs` pairs holds two learnable 5-vertex simplices in R^dim.
    Classes are bucketed onto vertices (class i -> vertex i % 5); per-vertex
    linear heads produce class logits gated by distance-derived vertex
    probabilities, then aggregated across pairs.
    """
    def __init__(self, num_classes: int, dim: int, num_pairs: int = 5, device: Optional[torch.device] = None, lambda_sep: float = 0.5):
        super().__init__()
        self.num_classes = num_classes
        self.dim = dim
        self.num_pairs = num_pairs
        self.device = device if device is not None else torch.device("cpu")
        self.lambda_separation = lambda_sep

        # Two banks of learnable pentachora: routing vs. specialization.
        self.dispatchers = nn.Parameter(self._init_batched_pentachora())
        self.specialists = nn.Parameter(self._init_batched_pentachora())

        # Per-pair, per-vertex importance logits and distance temperatures.
        self.dispatcher_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)
        self.specialist_weights = nn.Parameter(torch.randn(num_pairs, 5) * 0.1)
        self.temps = nn.Parameter(0.3 * torch.ones(num_pairs))

        # Static class -> vertex assignment (round-robin mod 5).
        self.register_buffer("vertex_map", self._create_vertex_mapping())

        # One linear head per vertex, sized to the classes mapped onto it.
        self.group_heads = nn.ModuleList([
            nn.Linear(dim, int((self.vertex_map == i).sum().item())) if int((self.vertex_map == i).sum().item()) > 0 else None
            for i in range(5)
        ])

        # NOTE(review): embed_dim=dim with num_heads=14 requires dim % 14 == 0.
        self.cross_attention = nn.MultiheadAttention(embed_dim=dim, num_heads=14, dropout=0.1, batch_first=True)
        self.aggregation_weights = nn.Parameter(torch.ones(num_pairs) / num_pairs)

        # Fuses the concatenated per-pair score vectors into class logits.
        self.fusion = nn.Sequential(
            nn.Linear(num_classes * num_pairs, num_classes * 2),
            nn.BatchNorm1d(num_classes * 2), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(num_classes * 2, num_classes)
        )

        # Scalar gate in (0, 1) modulating the sharpness of vertex routing.
        self.coherence_head = nn.Sequential(nn.Linear(dim, dim // 2), nn.GELU(), nn.Linear(dim // 2, 1))

    def _init_batched_pentachora(self) -> torch.Tensor:
        """[num_pairs, 5, dim] stack of scaled regular 4-simplices + noise."""
        sqrt15, sqrt10, sqrt5 = np.sqrt(15), np.sqrt(10), np.sqrt(5)
        # Vertices of a regular 4-simplex (pentachoron) embedded in R^4.
        base_simplex = torch.tensor([
            [ 1.0, 0.0, 0.0, 0.0],
            [-0.25, sqrt15/4, 0.0, 0.0],
            [-0.25,-sqrt15/12, sqrt10/3, 0.0],
            [-0.25,-sqrt15/12,-sqrt10/6, sqrt5/2],
            [-0.25,-sqrt15/12,-sqrt10/6,-sqrt5/2]
        ], device=self.device, dtype=torch.float32)
        base_simplex = F.normalize(base_simplex, dim=1)
        pentachora = torch.zeros(self.num_pairs, 5, self.dim, device=self.device, dtype=torch.float32)
        for i in range(self.num_pairs):
            # Pair i gets a slightly larger copy; extra dims get small noise.
            pentachora[i, :, :4] = base_simplex * (1 + 0.1 * i)
            if self.dim > 4:
                pentachora[i, :, 4:] = torch.randn(5, self.dim - 4, device=self.device) * (random.random() * 0.25)
        return pentachora * 2.0

    def _create_vertex_mapping(self) -> torch.Tensor:
        """LongTensor[num_classes]: class index -> owning vertex (i % 5)."""
        mapping = torch.zeros(self.num_classes, dtype=torch.long)
        for i in range(self.num_classes):
            mapping[i] = i % 5
        return mapping

    def forward(self, x: torch.Tensor):
        """x: [B, dim] latent -> (class logits [B, num_classes], diagnostics)."""
        B = x.size(0)
        coherence_gate = torch.sigmoid(self.coherence_head(x))

        # Euclidean distance from each latent to every vertex of every pair:
        # shapes [B, num_pairs, 5].
        x_exp = x.unsqueeze(1).unsqueeze(2)
        disp_exp = self.dispatchers.unsqueeze(0)
        spec_exp = self.specialists.unsqueeze(0)
        disp_d = torch.norm(x_exp - disp_exp, dim=3)
        spec_d = torch.norm(x_exp - spec_exp, dim=3)

        dw = F.softmax(self.dispatcher_weights, dim=1).unsqueeze(0)
        sw = F.softmax(self.specialist_weights, dim=1).unsqueeze(0)
        temps = torch.clamp(self.temps, 0.1, 2.0).view(1, -1, 1)

        # Closer vertices -> larger logits; per-pair learned temperature.
        disp_logits = -(disp_d * dw) / temps
        spec_logits = -(spec_d * sw) / temps

        # The coherence gate sharpens or flattens routing before the softmax.
        c = coherence_gate.unsqueeze(-1)
        disp_probs = F.softmax(disp_logits * c, dim=2)
        spec_probs = F.softmax(spec_logits * c, dim=2)
        probs = 0.5 * disp_probs + 0.5 * spec_probs

        # Scatter per-vertex head outputs into class slots, one pair at a time.
        scores_by_pair = []
        for p in range(self.num_pairs):
            pair_scores = torch.zeros(B, self.num_classes, device=x.device)
            for v_idx in range(5):
                idxs = (self.vertex_map == v_idx).nonzero(as_tuple=True)[0]
                if len(idxs) == 0: continue
                v_prob = probs[:, p, v_idx:v_idx+1]
                if self.group_heads[v_idx] is not None:
                    g_logits = self.group_heads[v_idx](x)
                    gated = g_logits * v_prob
                    for i, cls in enumerate(idxs.tolist()):
                        if i < gated.size(1):
                            pair_scores[:, cls] = gated[:, i]
            scores_by_pair.append(pair_scores)

        scores_tensor = torch.stack(scores_by_pair, dim=1)

        # NOTE(review): this attention output is computed and then discarded
        # (dead compute; it only advances dropout RNG state) — confirm intent.
        centers = self.dispatchers.mean(dim=1).unsqueeze(0).expand(B, -1, -1)
        _attn, _ = self.cross_attention(centers, centers, centers)

        # Blend a learned per-pair weighting with an MLP fusion of all pairs.
        agg = F.softmax(self.aggregation_weights, dim=0).view(1, -1, 1)
        weighted = (scores_tensor * agg).sum(dim=1)
        fused = self.fusion(scores_tensor.flatten(1))
        final = 0.6 * weighted + 0.4 * fused
        return final, {"disp_d": disp_d, "spec_d": spec_d, "probs": probs}

    def _batched_cayley_menger(self, pentachora: torch.Tensor) -> torch.Tensor:
        """
        Stable CM proxy: det(M) via eigvals of (M + eps*I).
        Returns a positive scalar per cube; larger => more 'volumetric' (less degenerate).
        """
        P = pentachora.shape[0]
        d2 = torch.cdist(pentachora, pentachora).pow(2)
        # Bordered squared-distance (Cayley–Menger) matrix, one per pair.
        M = torch.zeros(P, 6, 6, device=self.device, dtype=pentachora.dtype)
        M[:, 0, 1:] = 1.0
        M[:, 1:, 0] = 1.0
        M[:, 1:, 1:] = d2

        eps = 1e-6
        I = torch.eye(6, device=self.device, dtype=pentachora.dtype).unsqueeze(0)
        M_eps = M + eps * I

        # det via symmetric eigendecomposition, clamped for numerical safety.
        evals = torch.linalg.eigvalsh(M_eps)
        evals = evals.clamp_min(1e-12)
        logdet = evals.log().sum(dim=1)
        det = torch.exp(logdet)

        det = torch.nan_to_num(det, nan=0.0, posinf=1e6, neginf=0.0)
        return det

    def _batched_edge_variance(self, pentachora: torch.Tensor) -> torch.Tensor:
        """Penalize unequal edge lengths and edges shorter than 0.5."""
        d = torch.cdist(pentachora, pentachora)
        mask = torch.triu(torch.ones(5, 5, device=pentachora.device), diagonal=1).bool()
        edges = torch.stack([d[p][mask] for p in range(self.num_pairs)])
        return edges.var(dim=1).sum() + torch.relu(0.5 - edges.min(dim=1)[0]).sum()

    def regularization_loss(self, vertex_weights=None) -> torch.Tensor:
        """Geometry regularizer: volume, edge uniformity, bank separation,
        plus an optional externally-weighted vertex-norm penalty."""
        disp_cm = self._batched_cayley_menger(self.dispatchers)
        spec_cm = self._batched_cayley_menger(self.specialists)
        # Encourage non-degenerate (volumetric) simplices in both banks.
        cm_loss = torch.relu(1.0 - torch.abs(disp_cm)).sum() + torch.relu(1.0 - torch.abs(spec_cm)).sum()
        edge_loss = self._batched_edge_variance(self.dispatchers) + self._batched_edge_variance(self.specialists)
        disp_centers = self.dispatchers.mean(dim=1)
        spec_centers = self.specialists.mean(dim=1)
        # Push dispatcher/specialist centers apart in angle and in distance.
        cos_sims = F.cosine_similarity(disp_centers, spec_centers, dim=1, eps=1e-8)
        ortho = torch.abs(cos_sims).sum() * self.lambda_separation
        separations = torch.norm(disp_centers - spec_centers, dim=1)
        sep = torch.relu(2.0 - separations).sum() * self.lambda_separation

        # Optional weighted norm penalty scaled by caller-supplied per-vertex
        # weights (shape [5]; semantics defined by the caller).
        dyn = 0.0
        if vertex_weights is not None:
            vw = vertex_weights.to(self.dispatchers.device)
            disp_norms = torch.norm(self.dispatchers, p=2, dim=2)
            spec_norms = torch.norm(self.specialists, p=2, dim=2)
            dyn = 0.1 * ((disp_norms * vw.unsqueeze(0)).mean() + (spec_norms * vw.unsqueeze(0)).mean())

        return (cm_loss + edge_loss + ortho + sep) / self.num_pairs + dyn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dual_contrastive_loss(latents, targets, constellation, temp: float):
    """InfoNCE-style loss pulling each latent toward its target vertex on both
    the dispatcher and specialist banks (the two terms are summed).

    latents: [B, D]; targets: LongTensor[B] of class ids; temp: softmax
    temperature. Returns a scalar tensor.
    """
    B = latents.size(0)
    z = F.normalize(latents, dim=1, eps=1e-8)
    disp = F.normalize(constellation.dispatchers, dim=2, eps=1e-8)
    spec = F.normalize(constellation.specialists, dim=2, eps=1e-8)
    # Cosine logits per (batch, pair, vertex), temperature-scaled.
    disp_logits = torch.einsum('bd,pvd->bpv', z, disp) / temp
    spec_logits = torch.einsum('bd,pvd->bpv', z, spec) / temp
    # Map class targets to their owning vertex, then gather the positive logit
    # for every pair.
    tvert = constellation.vertex_map[targets]
    idx = tvert.view(B, 1, 1).expand(B, disp_logits.size(1), 1)
    disp_pos = disp_logits.gather(2, idx).squeeze(2)
    spec_pos = spec_logits.gather(2, idx).squeeze(2)
    disp_lse = torch.logsumexp(disp_logits, dim=2)
    spec_lse = torch.logsumexp(spec_logits, dim=2)
    # -log softmax of the positive vertex, averaged over batch and pairs.
    return (disp_lse - disp_pos).mean() + (spec_lse - spec_pos).mean()
|
|
|
|
|
class RoseDiagnosticHead(nn.Module):
    """Tiny MLP mapping a latent to a single raw diagnostic scalar."""

    def __init__(self, latent_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """[B, latent_dim] -> [B, 1] unsquashed score."""
        return self.net(x)
|
|
|
|
|
def rose_score_magnitude(x, need, relation, purpose, eps: float = 1e-8):
    """ROSE score: mean cosine alignment of x with three anchor directions,
    scaled by the (clamped) magnitude of x.

    All inputs broadcast over leading dims; cosine is taken along the last axis.
    """
    unit_x = F.normalize(x, dim=-1, eps=eps)
    anchors = (need, relation, purpose)
    # Average of the three cosine similarities.
    cos_sum = sum((unit_x * F.normalize(a, dim=-1, eps=eps)).sum(-1) for a in anchors)
    alignment = cos_sum / 3.0
    # Magnitude term keeps the score sensitive to |x|; eps avoids exact zero.
    magnitude = x.norm(dim=-1).clamp_min(eps)
    return alignment * magnitude
|
|
|
|
|
def rose_contrastive_loss(latents, targets, constellation, temp: float = 0.5):
    """ROSE-weighted dual InfoNCE over pair-averaged vertex centroids.

    Samples that are already well aligned (high ROSE score) are down-weighted
    via ``1 - tanh(rose)`` so the loss focuses on poorly aligned samples.

    Returns:
        (loss, rose) — scalar weighted loss and the detached ROSE scores.
    """
    batch, dim = latents.shape
    target_vertex = constellation.vertex_map[targets]

    # Per-sample anchors for the ROSE score: target-vertex centroids from each
    # bank, plus the global specialist mean as "purpose".
    need = constellation.specialists[:, target_vertex, :].mean(dim=0)
    relation = constellation.dispatchers[:, target_vertex, :].mean(dim=0)
    purpose = constellation.specialists.mean(dim=(0, 1)).unsqueeze(0).expand(batch, dim)

    rose = rose_score_magnitude(latents, need, relation, purpose)
    sample_w = (1.0 - torch.tanh(rose)).detach()

    # Pair-averaged, L2-normalized vertex centroids for both banks.
    spec_cent = F.normalize(constellation.specialists.mean(dim=0), dim=1, eps=1e-8)
    disp_cent = F.normalize(constellation.dispatchers.mean(dim=0), dim=1, eps=1e-8)
    feats = F.normalize(latents, dim=1, eps=1e-8)

    spec_logits = (feats @ spec_cent.T) / temp
    disp_logits = (feats @ disp_cent.T) / temp

    col = target_vertex.view(-1, 1)
    spec_nll = torch.logsumexp(spec_logits, dim=1) - spec_logits.gather(1, col).squeeze(1)
    disp_nll = torch.logsumexp(disp_logits, dim=1) - disp_logits.gather(1, col).squeeze(1)

    per_sample = 0.5 * (spec_nll + disp_nll)
    return (per_sample * sample_w).mean(), rose.detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_class_similarity(constellation_model: BatchedPentachoronConstellation, num_classes: int) -> torch.Tensor:
    """Cosine similarity between class rows of the final fusion layer weight,
    clamped into [0, 1].

    `num_classes` is accepted for interface compatibility; the output size is
    determined by the fusion weight's row count.
    """
    weight = constellation_model.fusion[-1].weight.data.detach()
    rows = F.normalize(weight, p=2, dim=1)
    sim = rows @ rows.T
    return torch.clamp(sim, 0.0, 1.0)
|
|
|
|
|
def vertex_weights_from_confusion(cm: np.ndarray, class_similarity: torch.Tensor, vertex_map: torch.Tensor, device: torch.device, sigma: float = 0.5) -> torch.Tensor:
    """Derive per-vertex emphasis weights from a confusion matrix.

    Per-class error rates are propagated between similar classes with a
    Gaussian kernel on (1 - similarity), averaged per pentachoron vertex,
    and converted to weights that sum to 5 (one per vertex): vertices whose
    classes confuse more get LOWER weight via ``1 - tanh``.

    Args:
        cm: (C, C) confusion matrix (rows = true labels).
        class_similarity: (C, C) similarity in [0, 1] on `device`.
        vertex_map: (C,) class -> vertex index in [0, 5).
        device: target device for the returned tensor.
        sigma: Gaussian kernel width for similarity propagation
            (default 0.5, matching previous hard-coded behavior).

    Returns:
        (5,) tensor of non-negative weights summing to 5.
    """
    totals = cm.sum(axis=1)
    correct = cm.diagonal()
    # Per-class accuracy; classes with zero samples count as accuracy 0.
    acc = np.divide(correct, totals, out=np.zeros_like(correct, dtype=float), where=totals != 0)
    confusion_scores = 1.0 - torch.tensor(acc, device=device, dtype=torch.float32)

    # Spread each class's confusion onto similar classes.
    gaussian = torch.exp(-((1 - class_similarity) ** 2) / (2 * sigma ** 2))
    propagated = gaussian @ confusion_scores

    # Average propagated confusion per vertex (pentachoron => 5 vertices).
    v_sum = torch.zeros(5, device=device)
    v_cnt = torch.zeros(5, device=device)
    for cls, v in enumerate(vertex_map.tolist()):
        v_sum[v] += propagated[cls]
        v_cnt[v] += 1
    v_avg = torch.zeros_like(v_sum)
    mask = v_cnt > 0
    v_avg[mask] = v_sum[mask] / v_cnt[mask]

    # Low confusion -> weight near 1; L1-normalize so weights sum to 5.
    vw = 1.0 - torch.tanh(v_avg)
    return F.normalize(vw, p=1, dim=0) * 5.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(encoder: nn.Module, constellation: nn.Module, loader, num_classes: int, device: torch.device, collect_diag: bool = False):
    """Run one full evaluation pass over `loader`.

    Returns a 5-tuple (acc, per_class_acc, cm, avg_soft_w, lam_eval):
      - acc: overall accuracy in [0, 1].
      - per_class_acc: list of per-class accuracies (0 for classes with no samples).
      - cm: confusion matrix over `num_classes` labels.
      - avg_soft_w: mean encoder softmax channel weights when collect_diag=True, else None.
      - lam_eval: mean sigmoid of the encoder's opinion-gate parameter, or None
        if the encoder exposes no `opinion._lambda_raw`.
    """
    encoder.eval(); constellation.eval()
    all_preds, all_targets = [], []
    lambda_vals = []
    # Running sum of per-sample diagnostic weights.
    # NOTE(review): assumes the diagnostic weight vector has exactly 5 entries
    # (one per frequency channel/vertex) — confirm against the encoder's return_diag output.
    soft_w_sums = torch.zeros(5, device=device)
    soft_w_count = 0
    with torch.no_grad():
        for x, y in tqdm(loader, desc="Evaluating"):
            x, y = x.to(device), y.to(device)
            if collect_diag:
                # Encoder also returns a diagnostics dict when asked.
                z, diag = encoder(x, return_diag=True)
                w = diag["softmax_weights"]
                soft_w_sums += w.sum(dim=0); soft_w_count += w.size(0)
            else:
                z = encoder(x)
            logits, _ = constellation(z)
            preds = logits.argmax(dim=1)
            all_preds.append(preds.cpu().numpy())
            all_targets.append(y.cpu().numpy())
            # Track the geometry/attention gate value if the encoder exposes one.
            if hasattr(encoder, "opinion") and hasattr(encoder.opinion, "_lambda_raw"):
                lambda_vals.append(float(torch.sigmoid(encoder.opinion._lambda_raw).item()))
    all_preds = np.concatenate(all_preds); all_targets = np.concatenate(all_targets)
    acc = float((all_preds == all_targets).mean())
    cm = confusion_matrix(all_targets, all_preds, labels=np.arange(num_classes))
    # Per-class accuracy; rows with no samples yield 0 rather than dividing by zero.
    per_class = np.divide(cm.diagonal(), cm.sum(axis=1), out=np.zeros(num_classes), where=cm.sum(axis=1)!=0)
    avg_soft_w = (soft_w_sums / soft_w_count).detach().cpu().numpy() if (collect_diag and soft_w_count > 0) else None
    lam_eval = float(np.mean(lambda_vals)) if lambda_vals else None
    return acc, per_class.tolist(), cm, avg_soft_w, lam_eval
|
|
|
|
|
def _adapt_pairs_by_classes(cfg: Dict, num_classes: int) -> int: |
|
|
|
|
|
pairs = cfg.get("num_pentachoron_pairs", 1) |
|
|
target = max(1, int(math.ceil(num_classes / 10))) |
|
|
return max(pairs, target) |
|
|
|
|
|
def train_one(
    train_loader,
    test_loader,
    num_classes: int,
    cfg: dict,
    device: torch.device,
    writer: SummaryWriter,
    class_names: Optional[list] = None,
):
    """Train encoder + pentachoron constellation (+ ROSE diagnostic head) on one dataset.

    Runs `cfg["epochs"]` epochs with a five-part weighted loss (cross-entropy,
    dual InfoNCE, ROSE-weighted InfoNCE, diagnostic MSE, geometry regularizer),
    evaluates on the test set after every epoch, and tracks the best accuracy.

    Returns:
        (encoder, constellation, diag_head, history, best) where `history`
        maps metric names to per-epoch lists and `best` holds the best test
        accuracy, its confusion matrix, and the (1-based) epoch it occurred at.
    """
    # Scale pentachoron pairs with the class count (>= 1 pair per 10 classes).
    pairs = _adapt_pairs_by_classes(cfg, num_classes)
    if pairs != cfg.get("num_pentachoron_pairs"):
        print(f"[auto] Adjusting num_pentachoron_pairs -> {pairs} for {num_classes} classes.")
    cfg_local = dict(cfg); cfg_local["num_pentachoron_pairs"] = pairs

    # Frequency-channel encoder producing `base_dim` latents.
    # NOTE(review): PentaFreqEncoderV2 is defined outside this chunk; the fallback
    # defaults here (num_heads=14, channels=12) differ from the module-level
    # config (2 / 24) — the config values win whenever present.
    encoder = PentaFreqEncoderV2(
        input_dim=cfg_local["input_dim"],
        input_ch=cfg_local.get("input_channels", 1),
        base_dim=cfg_local["base_dim"],
        proj_dim=None,
        num_heads=cfg_local.get("num_heads", 14),
        channels=cfg_local.get("channels", 12),
    ).to(device)

    constellation = BatchedPentachoronConstellation(
        num_classes=num_classes,
        dim=cfg_local["base_dim"],
        num_pairs=cfg_local["num_pentachoron_pairs"],
        device=device,
        lambda_sep=cfg_local["lambda_separation"],
    ).to(device)

    # Probe that regresses ROSE scores from detached latents.
    diag_head = RoseDiagnosticHead(cfg_local["base_dim"]).to(device)

    # One optimizer over all three modules; cosine-annealed per epoch.
    params = list(encoder.parameters()) + list(constellation.parameters()) + list(diag_head.parameters())
    optim = torch.optim.AdamW(params, lr=cfg_local["lr"], weight_decay=cfg_local["weight_decay"])
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=cfg_local["epochs"])

    # Loss-term weights (fall back to defaults / the legacy loss_weight_scalar).
    w_ce = float(cfg_local.get("w_ce", 1.0))
    w_dual = float(cfg_local.get("w_dual", 1.0))
    w_rose = float(cfg_local.get("w_rose", 1.0))
    w_diag = float(cfg_local.get("w_diag", 0.1))
    w_reg = float(cfg_local.get("w_reg", cfg_local["loss_weight_scalar"]))

    history = {"train_loss": [], "train_acc": [], "test_acc": [], "ce": [], "dual": [], "rose": [], "diag": [], "reg": [], "lambda": []}
    best = {"acc": 0.0, "cm": None, "epoch": -1}
    # Confusion-driven per-vertex weights for the regularizer; None on the
    # first epoch, then refreshed from each epoch's confusion matrix.
    vertex_weights = None

    global_step = 0
    for epoch in range(cfg_local["epochs"]):
        encoder.train(); constellation.train(); diag_head.train()
        sum_loss = sum_ce = sum_dual = sum_rose = sum_diag = sum_reg = 0.0
        correct = total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg_local['epochs']} [Train]")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optim.zero_grad()

            z = encoder(x)
            logits, _ = constellation(z)

            l_ce = F.cross_entropy(logits, y)
            l_dual = dual_contrastive_loss(z, y, constellation, temp=cfg_local["temp"])
            l_rose, rose_scores = rose_contrastive_loss(z, y, constellation, temp=cfg_local["temp"])
            # Diagnostic head trains on detached latents: no gradient into the encoder.
            pred_rose = diag_head(z.detach()).squeeze(-1)
            l_diag = F.mse_loss(pred_rose, rose_scores)
            l_reg = constellation.regularization_loss(vertex_weights=vertex_weights)

            loss = (w_ce*l_ce) + (w_dual*l_dual) + (w_rose*l_rose) + (w_diag*l_diag) + (w_reg*l_reg)

            # NaN guard: if any term is non-finite, halve the LR (floored at 1e-6),
            # drop pending grads, clamp still-finite params, and skip this batch.
            if not torch.isfinite(l_ce) or not torch.isfinite(l_dual) \
               or not torch.isfinite(l_rose) or not torch.isfinite(l_diag) \
               or not torch.isfinite(l_reg) or not torch.isfinite(loss):
                print("[NaN-guard] non-finite detected. Skipping step. "
                      f"ce={l_ce.item() if torch.isfinite(l_ce) else 'nan'}, "
                      f"dual={l_dual.item() if torch.isfinite(l_dual) else 'nan'}, "
                      f"rose={l_rose.item() if torch.isfinite(l_rose) else 'nan'}, "
                      f"reg={l_reg.item() if torch.isfinite(l_reg) else 'nan'}")

                for g in optim.param_groups:
                    g["lr"] = max(g["lr"] * 0.5, 1e-6)

                optim.zero_grad(set_to_none=True)

                # Clamp runaway (but still finite) parameters back into a safe range.
                with torch.no_grad():
                    for p in list(encoder.parameters()) + list(constellation.parameters()):
                        if torch.isfinite(p).all():
                            p.clamp_(-1e3, 1e3)
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
            torch.nn.utils.clip_grad_norm_(constellation.parameters(), 1.0)
            optim.step()

            # Batch-size-weighted running sums for epoch averages.
            bs = x.size(0)
            sum_loss += loss.item() * bs
            sum_ce += l_ce.item() * bs
            sum_dual += l_dual.item() * bs
            sum_rose += l_rose.item() * bs
            sum_diag += l_diag.item() * bs
            sum_reg += l_reg.item() * bs

            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += bs

            # Per-step TensorBoard scalars.
            writer.add_scalar("step/loss", loss.item(), global_step)
            writer.add_scalar("step/ce", l_ce.item(), global_step)
            writer.add_scalar("step/dual", l_dual.item(), global_step)
            writer.add_scalar("step/rose", l_rose.item(), global_step)
            writer.add_scalar("step/diag", l_diag.item(), global_step)
            writer.add_scalar("step/reg", l_reg.item(), global_step)
            global_step += 1

            pbar.set_postfix({
                "loss": f"{loss.item():.4f}",
                "acc": f"{correct/max(1,total):.4f}",
                "ce": f"{l_ce.item():.3f}",
                "dual": f"{l_dual.item():.3f}",
                "rose": f"{l_rose.item():.3f}",
                "reg": f"{l_reg.item():.3f}",
            })

        # Epoch averages (max(1, total) guards against an empty loader).
        train_loss = sum_loss / max(1, total)
        train_acc = correct / max(1, total)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["ce"].append(sum_ce / max(1,total))
        history["dual"].append(sum_dual / max(1,total))
        history["rose"].append(sum_rose / max(1,total))
        history["diag"].append(sum_diag / max(1,total))
        history["reg"].append(sum_reg / max(1,total))

        # Full test-set pass; also collects encoder channel diagnostics.
        test_acc, per_class_acc, cm, avg_soft_w, lam_eval = evaluate(
            encoder, constellation, test_loader, num_classes, device, collect_diag=True
        )
        history["test_acc"].append(test_acc)
        if lam_eval is not None:
            history["lambda"].append(lam_eval)
        else:
            # Fall back to reading the gate directly; 0.5 if the encoder has none.
            history["lambda"].append(float(torch.sigmoid(encoder.opinion._lambda_raw).item())
                                     if hasattr(encoder, "opinion") else 0.5)

        lr_sched.step()

        # Epoch-level TensorBoard scalars.
        writer.add_scalar("epoch/train_loss", train_loss, epoch+1)
        writer.add_scalar("epoch/train_acc", train_acc, epoch+1)
        writer.add_scalar("epoch/test_acc", test_acc, epoch+1)
        writer.add_scalar("epoch/lr", optim.param_groups[0]["lr"], epoch+1)
        writer.add_scalar("epoch/lambda", history["lambda"][-1], epoch+1)

        print(f"\n[Epoch {epoch+1}/{cfg_local['epochs']}] "
              f"TrainLoss {train_loss:.4f} | TrainAcc {train_acc:.4f} | TestAcc {test_acc:.4f} | "
              f"CE {history['ce'][-1]:.3f} Dual {history['dual'][-1]:.3f} "
              f"ROSE {history['rose'][-1]:.3f} Reg {history['reg'][-1]:.3f} λ {history['lambda'][-1]:.3f}")

        # Refresh confusion-driven vertex weights for next epoch's regularizer.
        with torch.no_grad():
            class_sim = get_class_similarity(constellation, num_classes).to(device)
            vertex_weights = vertex_weights_from_confusion(cm, class_sim, constellation.vertex_map, device)

        if test_acc > best["acc"]:
            best["acc"], best["cm"], best["epoch"] = test_acc, cm, epoch+1
            print(f" 🎯 New Best Acc: {best['acc']:.4f} at epoch {best['epoch']}")

        # Best-effort per-epoch confusion heatmap (requires matplotlib + seaborn).
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
            os.makedirs("plots", exist_ok=True)
            plt.figure(figsize=(9, 7))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                        xticklabels=class_names, yticklabels=class_names)
            plt.title(f'Confusion (Epoch {epoch+1}) | Acc: {test_acc:.4f}')
            plt.xlabel('Predicted'); plt.ylabel('True'); plt.tight_layout()
            plt.savefig(f'plots/confusion_epoch_{epoch+1}.png', dpi=150)
            plt.close()
        except Exception as e:
            print(f"(Confusion heatmap skipped this epoch: {e})")

    return encoder, constellation, diag_head, history, best
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_history(history: dict, outdir: str = "plots"):
    """Render accuracy, loss-component, and λ curves as PNGs under *outdir*."""
    os.makedirs(outdir, exist_ok=True)
    import matplotlib.pyplot as plt

    # Accuracy curves (train vs test).
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['test_acc'], label='Test Acc')
    plt.title('Accuracy over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True, ls='--', alpha=0.4)
    plt.tight_layout()
    plt.savefig(f"{outdir}/accuracy.png", dpi=150)
    plt.close()

    # Individual loss components plus the weighted total.
    plt.figure(figsize=(10, 5))
    for key, label in (("train_loss", "Total"), ("ce", "CE"),
                       ("dual", "DualNCE"), ("rose", "ROSE"), ("reg", "Reg")):
        plt.plot(history[key], label=label)
    plt.title('Loss Components')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, ls='--', alpha=0.4)
    plt.tight_layout()
    plt.savefig(f"{outdir}/loss_components.png", dpi=150)
    plt.close()

    # Geometry/attention gate value per epoch.
    plt.figure(figsize=(8, 4))
    plt.plot(history['lambda'])
    plt.title('λ (Geometry ↔ Attention Gate)')
    plt.xlabel('Epoch')
    plt.ylabel('λ')
    plt.grid(True, ls='--', alpha=0.4)
    plt.tight_layout()
    plt.savefig(f"{outdir}/lambda.png", dpi=150)
    plt.close()
|
|
|
|
|
def plot_confusion(cm: np.ndarray, class_names: list, outpath: str):
    """Save a confusion-matrix image; uses seaborn when available, with a
    plain-matplotlib imshow fallback if seaborn is missing or fails."""
    import matplotlib.pyplot as plt
    try:
        import seaborn as sns
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Best Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.tight_layout()
        plt.savefig(outpath, dpi=150)
        plt.close()
    except Exception:
        # Fallback: raw matrix via imshow with a colorbar.
        plt.figure(figsize=(10, 8))
        plt.imshow(cm, aspect='auto')
        plt.title('Best Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(outpath, dpi=150)
        plt.close()
|
|
|
|
|
|
|
|
def _sanitize_for_path(s: str, preserve_case: bool = True) -> str: |
|
|
"""Keep letters/digits/._- ; replace others with '_'.""" |
|
|
out = [] |
|
|
for ch in (s if preserve_case else s.lower()): |
|
|
if ch.isalnum() or ch in "._-": |
|
|
out.append(ch) |
|
|
else: |
|
|
out.append("_") |
|
|
return "".join(out) |
|
|
|
|
|
def _build_hf_paths(dataset_name: str, cfg: Dict, ts: str) -> Dict[str, str]:
    """Compute the HF repo subdirectory and the mirrored local artifact
    directory for one training run.

    Templates may reference {dataset}, {slug}, and {ts}.
    """
    keep_case = bool(cfg.get("hf_preserve_case", True))
    display = dataset_name if keep_case else dataset_name.lower()
    dataset_token = _sanitize_for_path(display, preserve_case=keep_case)
    slug = _dataset_slug(dataset_name)

    root = cfg.get("hf_subdir_root", "pentachora-adaptive-encoded").strip("/")

    # Fill the directory templates from config.
    dataset_dir = cfg.get("hf_dataset_dir_template", "{dataset}").format(
        dataset=dataset_token, slug=slug, ts=ts)
    run_dir = cfg.get("hf_run_dir_template", "{ts}_{dataset}").format(
        dataset=dataset_token, slug=slug, ts=ts)

    # Remote path inside the repo, and its local mirror under artifacts/.
    path_in_repo = f"{root}/{dataset_dir}/{run_dir}".strip("/")
    local_root = Path("artifacts") / root / dataset_dir / run_dir

    return {
        "dataset_token": dataset_token,
        "path_in_repo": path_in_repo,
        "local_root": str(local_root),
    }
|
|
|
|
|
|
|
|
|
|
|
def save_and_push_artifacts(
    *,
    encoder: nn.Module,
    constellation: nn.Module,
    diag_head: nn.Module,
    config: Dict,
    class_names: List[str],
    history: Dict,
    best: Dict,
    tb_log_dir: Path,
    dataset_names: List[str],
):
    """Save weights, config, history, plots, and TensorBoard logs locally,
    then upload the whole folder to the Hugging Face Hub.

    `dataset_names` must contain exactly one dataset name. Returns the local
    artifact directory (Path).
    """
    assert len(dataset_names) == 1, "Pass a single dataset name to save_and_push_artifacts"
    dataset_name = dataset_names[0]

    # NOTE(review): _timestamp/_resolve_repo_id/_hf_login_if_needed/_ensure_repo/
    # save_safetensors/_zip_dir/_param_count/pd/psutil are defined or imported
    # outside this chunk.
    ts = _timestamp()
    repo_id = _resolve_repo_id(config)
    _hf_login_if_needed()
    api = _ensure_repo(repo_id)

    # Layout: artifacts/<root>/<dataset>/<run>/ mirrored on the Hub at <root>/<dataset>/<run>.
    paths = _build_hf_paths(dataset_name, config, ts)
    base_out = Path(paths["local_root"])
    base_out.mkdir(parents=True, exist_ok=True)

    # Optional per-dataset suffix on weight filenames (e.g. encoder_QMNIST.safetensors).
    ds_token = paths["dataset_token"]
    suffix = f"_{ds_token}" if bool(config.get("hf_weight_suffix_dataset", True)) else ""

    # Model weights saved as CPU tensors in safetensors format.
    wdir = base_out / "weights"; wdir.mkdir(parents=True, exist_ok=True)
    save_safetensors({k: v.cpu() for k, v in encoder.state_dict().items()}, str(wdir / f"encoder{suffix}.safetensors"))
    save_safetensors({k: v.cpu() for k, v in constellation.state_dict().items()}, str(wdir / f"constellation{suffix}.safetensors"))
    save_safetensors({k: v.cpu() for k, v in diag_head.state_dict().items()}, str(wdir / f"diagnostic_head{suffix}.safetensors"))

    # Run configuration and raw metric history as JSON.
    (base_out / "config.json").write_text(json.dumps(config, indent=2, sort_keys=True), encoding="utf-8")
    (base_out / "history.json").write_text(json.dumps(history, indent=2, sort_keys=True), encoding="utf-8")

    # Tabular history, one row per epoch.
    # NOTE(review): assumes every history list has max_len entries; a shorter
    # list would make the DataFrame constructor raise — confirm train_one
    # always appends to every key each epoch.
    max_len = max(len(history.get("train_loss", [])), len(history.get("train_acc", [])), len(history.get("test_acc", [])))
    df = pd.DataFrame({
        "epoch": list(range(1, max_len + 1)),
        "train_loss": history.get("train_loss", [np.nan]*max_len),
        "train_acc": history.get("train_acc", [np.nan]*max_len),
        "test_acc": history.get("test_acc", [np.nan]*max_len),
        "ce": history.get("ce", [np.nan]*max_len),
        "dual": history.get("dual", [np.nan]*max_len),
        "rose": history.get("rose", [np.nan]*max_len),
        "diag": history.get("diag", [np.nan]*max_len),
        "reg": history.get("reg", [np.nan]*max_len),
        "lambda": history.get("lambda", [np.nan]*max_len),
    })
    df.to_csv(base_out / "history.csv", index=False)

    # Copy any plots produced during training.
    if Path("plots").exists():
        shutil.copytree("plots", base_out / "plots", dirs_exist_ok=True)

    # TensorBoard event files, plus a zipped copy for convenience.
    tb_dst = base_out / "tensorboard"; tb_dst.mkdir(parents=True, exist_ok=True)
    if tb_log_dir and Path(tb_log_dir).exists():
        for p in Path(tb_log_dir).glob("*"):
            shutil.copy2(p, tb_dst / p.name)
    _zip_dir(tb_dst, base_out / "tensorboard_events.zip")

    # Machine-readable run manifest (models, results, environment).
    manifest = {
        "timestamp": ts,
        "repo_id": repo_id,
        "subdirectory": paths["path_in_repo"],
        "dataset_name": dataset_name,
        "class_names": class_names,
        "num_classes": len(class_names),
        "models": {
            "encoder": {"params": _param_count(encoder)},
            "constellation": {"params": _param_count(constellation)},
            "diagnostic_head": {"params": _param_count(diag_head)},
        },
        "results": {
            "best_test_accuracy": float(best.get("acc", 0.0)),
            "best_epoch": int(best.get("epoch", -1)),
        },
        "environment": {
            "python": sys.version,
            "platform": platform.platform(),
            "torch": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_device": (torch.cuda.get_device_name(0) if torch.cuda.is_available() else None),
            "cpu_count": psutil.cpu_count(logical=True),
            "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2),
        },
    }
    (base_out / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True), encoding="utf-8")

    # Human-readable run card.
    (base_out / "README.md").write_text(
        f"""# Pentachora Adaptive Encoded — {ts}

**Dataset:** {dataset_name}

**Contents**
- `weights/*.safetensors` — encoder, constellation, diagnostic head
- `config.json`, `manifest.json`
- `history.json` / `history.csv`
- `tensorboard/` (and `tensorboard_events.zip`)
- `plots/` — accuracy, loss, λ, confusion
""",
        encoding="utf-8"
    )

    # Push the full artifact folder to the Hub.
    print(f"[push] Uploading to hf://{repo_id}/{paths['path_in_repo']}")
    api.upload_folder(
        repo_id=repo_id,
        folder_path=str(base_out),
        path_in_repo=paths["path_in_repo"],
        repo_type="model",
        commit_message=f"[{dataset_name}] {ts} | best_acc={manifest['results']['best_test_accuracy']:.4f}",
    )
    print("[push] ✅ Upload complete.")
    return base_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_one_dataset(name: str) -> Dict:
    """Train/evaluate on a single dataset by name, push artifacts, and return
    a compact summary row for the sweep table."""
    print("\n" + "="*60)
    print(f"RUN: {name}")
    print("="*60)

    # Resolve loaders and dataset geometry (class count, flattened dim, channels).
    # NOTE(review): get_dataset_single and `device` are defined outside this chunk.
    train_loader, test_loader, ncls, class_names, in_dim, out_c = get_dataset_single(
        name, batch_size=config["batch_size"], num_workers=2
    )
    cfg_local = dict(config)
    cfg_local["num_classes"] = ncls
    cfg_local["input_dim"] = in_dim
    cfg_local["input_channels"] = out_c

    # Per-run TensorBoard directory: tb_logs/<slug>/<timestamp>.
    ts = _timestamp()
    tb_dir = Path("tb_logs") / f"{_dataset_slug(name)}" / ts
    tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tb_dir))

    start = time.time()
    encoder, constellation, diag_head, history, best = train_one(
        train_loader, test_loader, ncls, cfg_local, device, writer, class_names
    )
    elapsed_min = (time.time() - start) / 60.0

    # Summary plots written under plots/.
    plot_history(history, outdir="plots")
    if best["cm"] is not None:
        plot_confusion(best["cm"], class_names, outpath=f"plots/best_confusion_{_dataset_slug(name)}_epoch_{best['epoch']}.png")

    # Persist weights/config/plots locally and mirror them to the HF Hub.
    save_and_push_artifacts(
        encoder=encoder,
        constellation=constellation,
        diag_head=diag_head,
        config=cfg_local,
        class_names=class_names,
        history=history,
        best=best,
        tb_log_dir=tb_dir,
        dataset_names=[name],
    )

    # Compact summary row consumed by run_sweep's results table.
    result = {
        "dataset": name,
        "classes": ncls,
        "channels": out_c,
        "img_size": config.get("img_size", 28),
        "best_acc": float(best["acc"]),
        "best_epoch": int(best["epoch"]),
        "params_encoder": _param_count(encoder),
        "params_constellation": _param_count(constellation),
        "elapsed_min": round(elapsed_min, 3),
    }
    print(f"[done] {name} -> best_acc={result['best_acc']:.4f} @ epoch {result['best_epoch']} time={result['elapsed_min']:.2f}m")
    return result
|
|
|
|
|
def run_sweep(datasets: List[str]) -> Dict:
    """Train on each dataset in turn (failures are recorded, not fatal),
    then upload a CSV/JSON summary of the sweep to the HF Hub.

    Returns a dict with the timestamp, per-dataset results, failures, and
    the remote summary path.
    """
    os.makedirs("sweeps", exist_ok=True)
    results, failures = [], []
    for name in datasets:
        try:
            results.append(run_one_dataset(name))
        except Exception as e:
            # Keep sweeping even if one dataset blows up.
            print(f"[fail] {name}: {e}")
            failures.append({"dataset": name, "error": str(e)})

    # Local sweep summary under sweeps/<timestamp>/.
    ts = _timestamp()
    sweep_dir = Path("sweeps") / ts
    sweep_dir.mkdir(parents=True, exist_ok=True)

    pd.DataFrame(results).to_csv(sweep_dir / "results.csv", index=False)
    (sweep_dir / "results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")
    (sweep_dir / "failures.json").write_text(json.dumps(failures, indent=2), encoding="utf-8")

    # Mirror the summary folder to the Hub.
    repo_id = _resolve_repo_id(config)
    _hf_login_if_needed()
    api = _ensure_repo(repo_id)
    path_in_repo = f"pentachora-adaptive-encoded/_sweep/{ts}"
    print(f"[push] Uploading sweep summary to hf://{repo_id}/{path_in_repo}")
    api.upload_folder(repo_id=repo_id, folder_path=str(sweep_dir), path_in_repo=path_in_repo, repo_type="model")
    print("[push] ✅ Sweep summary uploaded.")

    return {"timestamp": ts, "results": results, "failures": failures, "path_in_repo": path_in_repo}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: dump the config, then run either a full dataset sweep or
    a single-dataset training run depending on config/env flags."""
    print("\n" + "="*60)
    print("PENTACHORON CONSTELLATION FINAL CONFIGURATION")
    print("="*60)
    for k, v in config.items():
        print(f"{k:24}: {v}")
    if config["lr"] > 1e-1:
        print(f"⚠️ High LR detected ({config['lr']}). If unstable, try 5e-3 to 5e-2.")

    # Sweep mode: enabled via config["sweep_all"] or env RUN_SWEEP=1.
    if bool(config.get("sweep_all", False)) or os.getenv("RUN_SWEEP", "0") == "1":
        # NOTE(review): TORCHVISION_DATASETS / MEDMNIST_DATASETS / `medmnist`
        # come from outside this chunk; MedMNIST sets are added only when the
        # package imported successfully.
        datasets_all = list(TORCHVISION_DATASETS)
        if medmnist is not None:
            datasets_all += MEDMNIST_DATASETS
        out = run_sweep(datasets_all)
        print(f"\nSweep complete. Summary path: {out['path_in_repo']}")
        return

    # Single-dataset mode.
    DATASET = config.get("dataset", "FashionMNIST")

    train_loader, test_loader, ncls, class_names, in_dim, out_c = get_dataset_single(
        DATASET, batch_size=config["batch_size"], num_workers=2
    )
    # Mutates the module-level config in place with the resolved dataset geometry.
    config["num_classes"] = ncls
    config["input_dim"] = in_dim
    config["input_channels"] = out_c

    # Per-run TensorBoard directory.
    tb_dir = Path("tb_logs") / f"{_dataset_slug(DATASET)}" / _timestamp()
    tb_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(tb_dir))

    start = time.time()
    encoder, constellation, diag_head, history, best = train_one(
        train_loader, test_loader, ncls, config, device, writer, class_names
    )
    elapsed = (time.time() - start) / 60.0

    # Summary plots.
    plot_history(history, outdir="plots")
    if best["cm"] is not None:
        plot_confusion(best["cm"], class_names, outpath=f"plots/best_confusion_epoch_{best['epoch']}.png")

    print("\n" + "="*60)
    print("TRAINING COMPLETE")
    print("="*60)
    print(f"Best Test Accuracy : {best['acc']*100:.2f}% @ epoch {best['epoch']}")
    print(f"Total Training Time: {elapsed:.2f} minutes")

    # Persist and upload all artifacts for this run.
    save_and_push_artifacts(
        encoder=encoder,
        constellation=constellation,
        diag_head=diag_head,
        config=config,
        class_names=class_names,
        history=history,
        best=best,
        tb_log_dir=tb_dir,
        dataset_names=[DATASET],
    )
    print("[done] Artifacts uploaded and saved locally.")
|
|
|
|
|
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|