# Source page: vit-beans-v3 / trainer.py
# Uploader avatar caption: AbstractPhil's picture
# Commit: "Create trainer.py" (356a611, verified)
# train_cantor_fusion_hf.py - PRODUCTION WITH HUGGINGFACE + TENSORBOARD + SAFETENSORS
"""
Cantor Fusion Classifier with HuggingFace Integration
------------------------------------------------------
Install (run in a notebook cell before importing this module):
    try:
        !pip uninstall -qy geometricvocab
    except:
        pass
    !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git

Features:
- HuggingFace Hub uploads (ONE shared repo, organized by run)
- TensorBoard logging (loss, accuracy, fusion metrics)
- Easy CIFAR-10/100 switching
- Automatic checkpoint management
- SafeTensors format (ClamAV safe)
- Smart upload intervals
Author: AbstractPhil
License: MIT
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torch.cuda.amp import autocast, GradScaler
from safetensors.torch import save_file, load_file
import math
import os
import json
from typing import Optional, Dict, List, Tuple, Union
from dataclasses import dataclass, asdict
import time
from pathlib import Path
from tqdm import tqdm
# HuggingFace
from huggingface_hub import HfApi, create_repo, upload_folder, upload_file
import yaml
# Import from your repo
from geovocab2.train.model.layers.attention.cantor_multiheaded_fusion import (
CantorMultiheadFusion,
CantorFusionConfig
)
from geovocab2.shapes.factory.cantor_route_factory import (
CantorRouteFactory,
RouteMode,
SimplexConfig
)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Configuration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@dataclass
class CantorTrainingConfig:
    """Complete configuration for Cantor fusion training.

    Besides the declared fields, ``__post_init__`` derives and attaches:

    * ``num_patches`` / ``patch_dim`` - computed from ``image_size`` and
      ``patch_size``.
    * ``output_dir`` / ``checkpoint_dir`` / ``tensorboard_dir`` - ``Path``
      objects under ``weights_dir/model_name/run_name``.

    These derived attributes are not dataclass fields, so ``asdict`` (and
    therefore ``save``) does not include them; they are recomputed on load.

    Side effect: constructing a config creates the three directories above.

    Raises:
        ValueError: if ``dataset`` is unknown, or ``image_size`` is not a
            positive multiple of ``patch_size``.
    """

    # Dataset
    dataset: str = "cifar10"  # "cifar10" or "cifar100"
    num_classes: int = 10  # always overwritten from `dataset` in __post_init__

    # Architecture
    image_size: int = 32
    patch_size: int = 4
    embed_dim: int = 384
    num_fusion_blocks: int = 6
    num_heads: int = 8
    fusion_window: int = 32
    fusion_mode: str = "weighted"  # "weighted" or "consciousness"
    k_simplex: int = 4
    use_beatrix: bool = False
    beatrix_tau: float = 0.25

    # Optimization
    precompute_geometric: bool = True
    use_torch_compile: bool = True
    use_mixed_precision: bool = False

    # Regularization
    dropout: float = 0.1
    drop_path_rate: float = 0.15

    # Training
    batch_size: int = 128
    num_epochs: int = 100
    learning_rate: float = 3e-4
    weight_decay: float = 0.05
    warmup_epochs: int = 5
    grad_clip: float = 1.0

    # Data augmentation
    use_augmentation: bool = True
    use_autoaugment: bool = True

    # System
    # NOTE: this default is evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 4
    seed: int = 42

    # Paths
    weights_dir: str = "weights"
    model_name: str = "vit-beans-v3"
    run_name: Optional[str] = None  # Auto-generated if None

    # HuggingFace - ONE SHARED REPO
    hf_username: str = "AbstractPhil"
    hf_repo_name: Optional[str] = None  # Defaults to model_name (shared repo)
    upload_to_hf: bool = True
    hf_token: Optional[str] = None  # Set via environment or pass directly

    # Logging
    log_interval: int = 50  # Log every N batches
    save_interval: int = 10  # Save checkpoint every N epochs
    checkpoint_upload_interval: int = 10  # Upload checkpoint every N epochs

    def __post_init__(self):
        """Validate the config, fill in defaults and create output folders."""
        # Auto-set num_classes based on dataset
        if self.dataset == "cifar10":
            self.num_classes = 10
        elif self.dataset == "cifar100":
            self.num_classes = 100
        else:
            raise ValueError(f"Unknown dataset: {self.dataset}")

        # Auto-generate run name
        if self.run_name is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            self.run_name = f"{self.dataset}_{self.fusion_mode}_{timestamp}"

        # ONE SHARED REPO for all runs, named after the model by default
        if self.hf_repo_name is None:
            self.hf_repo_name = self.model_name

        # Set HF token from environment if not provided
        if self.hf_token is None:
            self.hf_token = os.environ.get("HF_TOKEN")

        # Calculate derived values. An explicit raise (rather than assert)
        # keeps the check alive under `python -O`.
        if self.patch_size <= 0 or self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be a positive multiple "
                f"of patch_size ({self.patch_size})"
            )
        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.patch_dim = self.patch_size * self.patch_size * 3

        # Create paths
        self.output_dir = Path(self.weights_dir) / self.model_name / self.run_name
        self.checkpoint_dir = self.output_dir / "checkpoints"
        self.tensorboard_dir = self.output_dir / "tensorboard"

        # Create directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.tensorboard_dir.mkdir(parents=True, exist_ok=True)

    def save(self, path: Union[str, Path]):
        """Save config to YAML file.

        The ``hf_token`` field is written as ``null``: the saved file is
        uploaded to the Hub by the trainer, so the credential must never be
        persisted. ``load`` re-reads the token from ``HF_TOKEN``.
        """
        path = Path(path)
        serializable = asdict(self)
        serializable["hf_token"] = None  # never write credentials to disk
        with open(path, 'w', encoding="utf-8") as f:
            yaml.dump(serializable, f, default_flow_style=False)

    @classmethod
    def load(cls, path: Union[str, Path]):
        """Load config from YAML file.

        An empty file yields a config with all defaults. Because the stored
        ``run_name`` is reused, loading re-creates that run's directories.
        """
        path = Path(path)
        with open(path, 'r', encoding="utf-8") as f:
            config_dict = yaml.safe_load(f) or {}
        return cls(**config_dict)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Model Components
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class PatchEmbedding(nn.Module):
    """Project an image batch into patch tokens and add a learned position table.

    The projection is a 3-input-channel convolution whose kernel size and
    stride both equal ``config.patch_size``. ``pos_embed`` is a learnable
    ``(1, num_patches, embed_dim)`` parameter initialised from a normal
    distribution scaled by 0.02.
    """

    def __init__(self, config: CantorTrainingConfig):
        super().__init__()
        self.config = config
        patch = config.patch_size
        self.proj = nn.Conv2d(3, config.embed_dim, kernel_size=patch, stride=patch)
        initial_positions = 0.02 * torch.randn(1, config.num_patches, config.embed_dim)
        self.pos_embed = nn.Parameter(initial_positions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return position-aware tokens for the image batch ``x``.

        The conv output is flattened from dim 2 onward and dims 1 and 2 are
        swapped, so channels end up on the last axis before ``pos_embed`` is
        added by broadcasting over the batch.
        """
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return tokens + self.pos_embed
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample (index along dim 0) is kept with probability
    ``1 - drop_prob`` and kept samples are rescaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: probability of dropping a sample's branch output.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Everything is dropped. Dividing by keep_prob here would give
            # inf * 0 = NaN, so return zeros explicitly.
            return torch.zeros_like(x)
        # One random draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * random_tensor
class CantorFusionBlock(nn.Module):
    """One residual block: normalised Cantor fusion, then a normalised MLP.

    Both sub-layers are applied to a LayerNorm of the running tensor and added
    back onto it through the same ``drop_path`` module.

    Args:
        config: training configuration supplying dimensions and fusion options.
        drop_path: stochastic-depth probability; values ``<= 0`` disable it.
    """

    def __init__(self, config: CantorTrainingConfig, drop_path: float = 0.0):
        super().__init__()
        dim = config.embed_dim
        precompute = config.precompute_geometric

        self.norm1 = nn.LayerNorm(dim)
        self.fusion = CantorMultiheadFusion(
            CantorFusionConfig(
                dim=dim,
                num_heads=config.num_heads,
                fusion_window=config.fusion_window,
                fusion_mode=config.fusion_mode,
                k_simplex=config.k_simplex,
                use_beatrix_routing=config.use_beatrix,
                use_consciousness_weighting=(config.fusion_mode == "consciousness"),
                beatrix_tau=config.beatrix_tau,
                use_gating=True,
                dropout=config.dropout,
                residual=False,  # the residual add is done in forward() below
                precompute_staircase=precompute,
                precompute_routes=precompute,
                precompute_distances=precompute,
                use_optimized_gather=True,
                staircase_cache_sizes=[config.num_patches],
                use_torch_compile=config.use_torch_compile,
            )
        )

        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = dim * 4
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(config.dropout),
        )

        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Apply the block to ``x``.

        Returns:
            The updated tensor, or ``(tensor, info)`` when
            ``return_fusion_info`` is true. ``info`` carries the
            ``'consciousness'`` and ``'cantor_measure'`` entries of the fusion
            result (``None`` for entries the fusion layer did not return).
        """
        fused = self.fusion(self.norm1(x))
        x = x + self.drop_path(fused['output'])
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if not return_fusion_info:
            return x

        diagnostics = {
            'consciousness': fused.get('consciousness'),
            'cantor_measure': fused.get('cantor_measure'),
        }
        return x, diagnostics
class CantorClassifier(nn.Module):
    """Image classifier: patch embedding, a stack of fusion blocks, linear head.

    The stochastic-depth rate grows linearly from 0 to
    ``config.drop_path_rate`` across the blocks. After the blocks the tokens
    are layer-normalised, averaged over dim 1 and fed to the head.
    """

    def __init__(self, config: CantorTrainingConfig):
        super().__init__()
        self.config = config
        self.patch_embed = PatchEmbedding(config)

        drop_rates = torch.linspace(0, config.drop_path_rate, config.num_fusion_blocks).tolist()
        self.blocks = nn.ModuleList(
            [CantorFusionBlock(config, drop_path=rate) for rate in drop_rates]
        )

        self.norm = nn.LayerNorm(config.embed_dim)
        self.head = nn.Linear(config.embed_dim, config.num_classes)
        # Re-initialise every Linear / LayerNorm / Conv2d submodule.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule according to its layer type."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor, return_fusion_info: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
        """Classify the image batch ``x``.

        Returns:
            The logits, or ``(logits, fusion_infos)`` when
            ``return_fusion_info`` is true. Only the final block is asked for
            its fusion info, so the list holds at most one entry.
        """
        x = self.patch_embed(x)

        collected = []
        final_index = len(self.blocks) - 1
        for index, block in enumerate(self.blocks):
            wants_info = return_fusion_info and index == final_index
            if not wants_info:
                x = block(x)
                continue
            x, info = block(x, return_fusion_info=True)
            collected.append(info)

        pooled = self.norm(x).mean(dim=1)
        logits = self.head(pooled)
        return (logits, collected) if return_fusion_info else logits
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# HuggingFace Integration - ONE SHARED REPO
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class HuggingFaceUploader:
    """Manages HuggingFace Hub uploads to ONE shared repo.

    Every run uploads under ``runs/<run_name>/`` inside the repo
    ``<hf_username>/<hf_repo_name>``; only the main README is written to the
    repo root. All public methods are best-effort: Hub failures are printed
    and swallowed so that a network problem does not abort training.

    Attributes:
        api: ``HfApi`` client, or ``None`` when uploads are disabled.
        repo_id: ``"<hf_username>/<hf_repo_name>"``.
        run_prefix: ``"runs/<run_name>"``, the folder for this run's files.
    """

    def __init__(self, config: CantorTrainingConfig):
        self.config = config
        self.api = HfApi(token=config.hf_token) if config.upload_to_hf else None
        self.repo_id = f"{config.hf_username}/{config.hf_repo_name}"
        # Organize by run inside the shared repo
        self.run_prefix = f"runs/{config.run_name}"
        if config.upload_to_hf:
            self._create_repo()
            self._update_main_readme()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _push_file(self, local_path, path_in_repo: str) -> None:
        """Upload one local file to ``path_in_repo``; Hub errors propagate.

        ``upload_file`` here is the module-level ``huggingface_hub`` function,
        not the method of the same name on this class (method bodies do not
        see class-level names).
        """
        upload_file(
            path_or_fileobj=str(local_path),
            path_in_repo=path_in_repo,
            repo_id=self.repo_id,
            token=self.config.hf_token
        )

    def _resolve_repo_path(self, repo_path: str) -> str:
        """Prefix ``repo_path`` with this run's folder unless it is already under ``runs/``."""
        if repo_path.startswith((self.run_prefix, "runs/")):
            return repo_path
        return f"{self.run_prefix}/{repo_path}"

    def _checkpoint_repo_path(self, checkpoint_path: Path, is_best: bool) -> str:
        """Return the in-repo destination for a checkpoint file.

        Best checkpoints are stored as ``best_model<suffix>``, keeping the
        file's real extension so a ``.safetensors`` file is not mislabelled
        as ``.pt``.
        """
        if is_best:
            name = f"best_model{checkpoint_path.suffix}"
        else:
            name = checkpoint_path.name
        return f"{self.run_prefix}/checkpoints/{name}"

    def _create_repo(self):
        """Create HuggingFace repo if it doesn't exist."""
        try:
            create_repo(
                repo_id=self.repo_id,
                token=self.config.hf_token,
                exist_ok=True,
                private=False
            )
            print(f"[HF] Repository: https://huggingface.co/{self.repo_id}")
            print(f"[HF] Run folder: {self.run_prefix}")
        except Exception as e:
            print(f"[HF] Warning: Could not create repo: {e}")

    def _update_main_readme(self):
        """Create or update the main shared README at repo root.

        The README is also written locally to
        ``<weights_dir>/<model_name>/MAIN_README.md``.
        """
        if not self.config.upload_to_hf or self.api is None:
            return
        main_readme = f"""---
tags:
- image-classification
- cantor-fusion
- geometric-deep-learning
- safetensors
- vision-transformer
library_name: pytorch
datasets:
- cifar10
- cifar100
metrics:
- accuracy
---
# {self.config.hf_repo_name}
**Geometric Deep Learning with Cantor Multihead Fusion**
This repository contains multiple training runs using Cantor fusion architecture with pentachoron structures and geometric routing. All models use SafeTensors format for security.
## Repository Structure
```
{self.config.hf_repo_name}/
├── runs/
│   ├── cifar10_weighted_TIMESTAMP/
│   │   ├── checkpoints/
│   │   │   ├── best_model.safetensors
│   │   │   ├── best_training_state.pt
│   │   │   └── best_metadata.json
│   │   ├── tensorboard/
│   │   ├── config.yaml
│   │   └── README.md
│   ├── cifar100_consciousness_TIMESTAMP/
│   │   └── ...
│   └── ...
└── README.md (this file)
```
## Current Run
**Latest**: `{self.config.run_name}`
- **Dataset**: {self.config.dataset.upper()}
- **Fusion Mode**: {self.config.fusion_mode}
- **Architecture**: {self.config.num_fusion_blocks} blocks, {self.config.num_heads} heads
- **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
## Architecture
The Cantor Fusion architecture uses:
- **Geometric Routing**: Pentachoron (4-simplex, 5 vertices) structures for token routing
- **Cantor Multihead Fusion**: Multiple fusion heads with geometric attention
- **Beatrix Consciousness Routing**: Optional consciousness-aware token fusion using the Devil's Staircase
- **SafeTensors Format**: All model weights use SafeTensors (not pickle) for security
## Usage
### Download a Model
```python
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
# Download model weights
model_path = hf_hub_download(
repo_id="{self.repo_id}",
filename="runs/YOUR_RUN_NAME/checkpoints/best_model.safetensors"
)
# Load weights (SafeTensors - no pickle!)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
```
### Browse Runs
Each run directory contains:
- `checkpoints/` - Model weights (safetensors), training state, metadata
- `tensorboard/` - TensorBoard logs for visualization
- `config.yaml` - Complete training configuration
- `README.md` - Run-specific details and results
## Model Variants
- **Weighted Fusion**: Standard geometric fusion with learned weights
- **Consciousness Fusion**: Uses Beatrix routing with consciousness emergence
## Citation
```bibtex
@misc{{{self.config.hf_repo_name.replace('-', '_')},
author = {{AbstractPhil}},
title = {{{self.config.hf_repo_name}: Geometric Deep Learning with Cantor Fusion}},
year = {{2025}},
publisher = {{HuggingFace}},
url = {{https://huggingface.co/{self.repo_id}}}
}}
```
## Training Details
All models trained with:
- Optimizer: AdamW
- Mixed Precision: Available on A100
- Augmentation: AutoAugment (CIFAR10 policy)
- Format: SafeTensors (ClamAV safe)
Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.
---
**Repository maintained by**: [@{self.config.hf_username}](https://huggingface.co/{self.config.hf_username})
**Latest update**: {time.strftime("%Y-%m-%d %H:%M:%S")}
"""
        # Save main README locally. UTF-8 is explicit because the text
        # contains non-ASCII characters and the platform default may not
        # be able to encode them.
        main_readme_path = Path(self.config.weights_dir) / self.config.model_name / "MAIN_README.md"
        main_readme_path.parent.mkdir(parents=True, exist_ok=True)
        with open(main_readme_path, 'w', encoding="utf-8") as f:
            f.write(main_readme)
        try:
            # Upload to repo root (not inside runs/)
            self._push_file(main_readme_path, "README.md")
            print(f"[HF] Updated main README")
        except Exception as e:
            print(f"[HF] Main README upload failed: {e}")

    def upload_checkpoint(self, checkpoint_path: Path, is_best: bool = False):
        """Upload checkpoint to HuggingFace.

        Args:
            checkpoint_path: local checkpoint file.
            is_best: store it as ``best_model<suffix>`` instead of under its
                own file name.
        """
        if not self.config.upload_to_hf or self.api is None:
            return
        try:
            # Upload to run-specific folder
            path_in_repo = self._checkpoint_repo_path(checkpoint_path, is_best)
            self._push_file(checkpoint_path, path_in_repo)
            print(f"[HF] Uploaded: {path_in_repo}")
        except Exception as e:
            print(f"[HF] Upload failed: {e}")

    def upload_file(self, file_path: Path, repo_path: str):
        """Upload single file to HuggingFace.

        ``repo_path`` is placed under this run's folder unless it already
        starts with ``runs/``.
        """
        if not self.config.upload_to_hf or self.api is None:
            return
        # Resolved outside the try block so the failure message below can
        # always refer to it.
        full_path = self._resolve_repo_path(repo_path)
        try:
            self._push_file(file_path, full_path)
            print(f"[HF] ✓ Uploaded: {full_path}")
        except Exception as e:
            print(f"[HF] ✗ Upload failed ({full_path}): {e}")

    def upload_folder_contents(self, folder_path: Path, repo_folder: str):
        """Upload entire folder to HuggingFace under this run's folder."""
        if not self.config.upload_to_hf or self.api is None:
            return
        # Upload to run-specific folder
        full_path = f"{self.run_prefix}/{repo_folder}"
        try:
            upload_folder(
                folder_path=str(folder_path),
                repo_id=self.repo_id,
                path_in_repo=full_path,
                token=self.config.hf_token,
                ignore_patterns=["*.pyc", "__pycache__"]
            )
            print(f"[HF] Uploaded folder: {full_path}")
        except Exception as e:
            print(f"[HF] Folder upload failed: {e}")

    def create_model_card(self, trainer_stats: Dict):
        """Create and upload run-specific model card.

        Args:
            trainer_stats: must contain ``total_params``, ``best_acc``,
                ``training_time`` (formatted as hours) and ``final_epoch``;
                ``batch_size`` and ``mixed_precision`` are optional.

        The card is written locally to ``<output_dir>/RUN_README.md`` and
        uploaded as ``README.md`` inside this run's folder.
        """
        if not self.config.upload_to_hf:
            return
        run_card = f"""# Run: {self.config.run_name}
## Configuration
- **Dataset**: {self.config.dataset.upper()}
- **Fusion Mode**: {self.config.fusion_mode}
- **Parameters**: {trainer_stats['total_params']:,}
- **Simplex**: {self.config.k_simplex}-simplex ({self.config.k_simplex + 1} vertices)
## Performance
- **Best Validation Accuracy**: {trainer_stats['best_acc']:.2f}%
- **Training Time**: {trainer_stats['training_time']:.1f} hours
- **Batch Size**: {trainer_stats.get('batch_size', 'N/A')}
- **Mixed Precision**: {trainer_stats.get('mixed_precision', False)}
- **Final Epoch**: {trainer_stats['final_epoch']}
## Files
- `{self.run_prefix}/checkpoints/best_model.safetensors` - Model weights (SafeTensors)
- `{self.run_prefix}/checkpoints/best_training_state.pt` - Optimizer/scheduler state
- `{self.run_prefix}/checkpoints/best_metadata.json` - Training metadata
- `{self.run_prefix}/config.yaml` - Full configuration
- `{self.run_prefix}/tensorboard/` - TensorBoard logs
## Usage
```python
from safetensors.torch import load_file
import torch
# Download from HuggingFace Hub
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(
repo_id="{self.repo_id}",
filename="{self.run_prefix}/checkpoints/best_model.safetensors"
)
# Load model weights (SafeTensors - no pickle!)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
```
## Training Configuration
```yaml
embed_dim: {self.config.embed_dim}
num_fusion_blocks: {self.config.num_fusion_blocks}
num_heads: {self.config.num_heads}
fusion_mode: {self.config.fusion_mode}
k_simplex: {self.config.k_simplex}
learning_rate: {self.config.learning_rate}
batch_size: {self.config.batch_size}
epochs: {self.config.num_epochs}
weight_decay: {self.config.weight_decay}
```
## Details
Built with geometric consciousness-aware routing using the Devil's Staircase (Beatrix) and pentachoron parameterization.
**Training completed**: {time.strftime("%Y-%m-%d %H:%M:%S")}
**Safe Format**: All model weights use SafeTensors (not pickle) for maximum security.
---
[← Back to main repository](https://huggingface.co/{self.repo_id})
"""
        # Save run-specific README (UTF-8: the card contains non-ASCII text)
        readme_path = self.config.output_dir / "RUN_README.md"
        with open(readme_path, 'w', encoding="utf-8") as f:
            f.write(run_card)
        try:
            self._push_file(readme_path, f"{self.run_prefix}/README.md")
            print(f"[HF] Uploaded run README")
        except Exception as e:
            print(f"[HF] Run README upload failed: {e}")
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Trainer with TensorBoard + HuggingFace + SafeTensors
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class Trainer:
    """Training manager with TensorBoard, HuggingFace, and SafeTensors.

    Construction builds the model, optimizer (AdamW), LR scheduler, loss
    (cross-entropy with label smoothing 0.1), optional AMP scaler, TensorBoard
    writer and (if enabled) the HuggingFace uploader, and writes
    ``config.yaml`` into the run's output directory.

    Attributes:
        best_acc: best validation accuracy (percent) seen so far.
        global_step: number of training batches processed.
        upload_count: number of checkpoint uploads performed.
    """

    def __init__(self, config: CantorTrainingConfig):
        self.config = config
        self.device = torch.device(config.device)

        # Set seed
        torch.manual_seed(config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(config.seed)

        # Model
        print("\n" + "=" * 70)
        print(f"Initializing Cantor Classifier - {config.dataset.upper()}")
        print("=" * 70)
        init_start = time.time()
        self.model = CantorClassifier(config).to(self.device)
        init_time = time.time() - init_start
        print(f"\n[Model] Initialization time: {init_time:.2f}s")
        self.print_model_info()

        # Optimizer & Scheduler
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        self.scheduler = self.create_scheduler()
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        # Mixed precision. Compare the parsed device type so that indexed
        # devices such as "cuda:0" enable AMP too.
        self.use_amp = config.use_mixed_precision and self.device.type == "cuda"
        self.scaler = GradScaler() if self.use_amp else None
        if self.use_amp:
            print(f"[Training] Mixed precision enabled")

        # TensorBoard
        self.writer = SummaryWriter(log_dir=str(config.tensorboard_dir))
        print(f"[TensorBoard] Logging to: {config.tensorboard_dir}")
        print(f"[Checkpoints] Format: SafeTensors (ClamAV safe)")

        # HuggingFace
        self.hf_uploader = HuggingFaceUploader(config) if config.upload_to_hf else None

        # Save config
        config.save(config.output_dir / "config.yaml")

        # Metrics
        self.best_acc = 0.0
        self.global_step = 0
        self.start_time = time.time()
        self.upload_count = 0

    def print_model_info(self):
        """Print parameter count and the key run settings."""
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"\nParameters: {total_params:,}")
        print(f"Dataset: {self.config.dataset.upper()}")
        print(f"Classes: {self.config.num_classes}")
        print(f"Fusion mode: {self.config.fusion_mode}")
        print(f"Output: {self.config.output_dir}")

    def create_scheduler(self):
        """Create scheduler with warmup.

        The LR multiplier ramps linearly to 1 over ``warmup_epochs`` and then
        follows a half cosine down to 0 at ``num_epochs``. The schedule is
        stepped once per epoch.
        """
        def lr_lambda(epoch):
            warmup = self.config.warmup_epochs
            if epoch < warmup:
                return (epoch + 1) / warmup
            # max(1, ...) avoids dividing by zero when num_epochs <= warmup_epochs.
            # min(1.0, ...) stops the cosine from rising again past num_epochs.
            decay_span = max(1, self.config.num_epochs - warmup)
            progress = min(1.0, (epoch - warmup) / decay_span)
            return 0.5 * (1 + math.cos(math.pi * progress))
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Tuple[float, float]:
        """Train one epoch.

        Returns:
            ``(mean batch loss, running accuracy in percent)``.
        """
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        # The scheduler only steps between epochs (see train()), so the LR is
        # constant for the whole epoch and can be read once.
        current_lr = self.scheduler.get_last_lr()[0]
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Train]")
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            # Forward + backward
            if self.use_amp:
                with autocast():
                    logits = self.model(images)
                    loss = self.criterion(logits, labels)
                self.optimizer.zero_grad(set_to_none=True)
                self.scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norms.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                logits = self.model(images)
                loss = self.criterion(logits, labels)
                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
                self.optimizer.step()

            # Metrics
            loss_value = loss.item()
            total_loss += loss_value
            _, predicted = logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
            running_acc = 100. * correct / total if total else 0.0

            # TensorBoard logging
            if batch_idx % self.config.log_interval == 0:
                self.writer.add_scalar('train/loss', loss_value, self.global_step)
                self.writer.add_scalar('train/accuracy', running_acc, self.global_step)
                self.writer.add_scalar('train/learning_rate', current_lr, self.global_step)
            self.global_step += 1

            pbar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'acc': f'{running_acc:.2f}%',
                'lr': f'{current_lr:.6f}'
            })

        # Guard against an empty loader instead of dividing by zero.
        num_batches = len(train_loader)
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        accuracy = 100. * correct / total if total else 0.0
        return mean_loss, accuracy

    def _eval_forward(self, images, labels, want_fusion_info: bool):
        """Run one evaluation forward pass.

        Returns ``(logits, loss, consciousness)`` where ``consciousness`` is
        the mean of the first fusion-info entry's ``'consciousness'`` value,
        or ``None`` when it was not requested or not provided by the model.
        """
        consciousness = None
        if want_fusion_info:
            logits, fusion_infos = self.model(images, return_fusion_info=True)
            if fusion_infos and fusion_infos[0].get('consciousness') is not None:
                consciousness = fusion_infos[0]['consciousness'].mean().item()
        else:
            logits = self.model(images)
        loss = self.criterion(logits, labels)
        return logits, loss, consciousness

    @torch.no_grad()
    def evaluate(self, val_loader: DataLoader, epoch: int) -> Tuple[float, Dict]:
        """Evaluate on ``val_loader`` and log the results to TensorBoard.

        Fusion diagnostics are requested on the last batch only.

        Returns:
            ``(accuracy in percent, metrics)`` where ``metrics`` has the keys
            ``'loss'``, ``'accuracy'`` and ``'consciousness'`` (``None`` when
            the model reported no consciousness value).
        """
        self.model.eval()
        total_loss, correct, total = 0.0, 0, 0
        consciousness_values = []
        num_batches = len(val_loader)
        pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Val] ")
        for batch_idx, (images, labels) in enumerate(pbar):
            images, labels = images.to(self.device, non_blocking=True), labels.to(self.device, non_blocking=True)

            # Forward with fusion info on last batch
            return_info = (batch_idx == num_batches - 1)
            if self.use_amp:
                with autocast():
                    logits, loss, consciousness = self._eval_forward(images, labels, return_info)
            else:
                logits, loss, consciousness = self._eval_forward(images, labels, return_info)
            if consciousness is not None:
                consciousness_values.append(consciousness)

            total_loss += loss.item()
            _, predicted = logits.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix({
                'loss': f'{total_loss / (batch_idx + 1):.4f}',
                'acc': f'{100. * correct / total:.2f}%' if total else 'n/a'
            })

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        accuracy = 100. * correct / total if total else 0.0
        mean_consciousness = (
            sum(consciousness_values) / len(consciousness_values)
            if consciousness_values else None
        )

        # TensorBoard logging
        self.writer.add_scalar('val/loss', avg_loss, epoch)
        self.writer.add_scalar('val/accuracy', accuracy, epoch)
        if mean_consciousness is not None:
            self.writer.add_scalar('val/consciousness', mean_consciousness, epoch)

        metrics = {
            'loss': avg_loss,
            'accuracy': accuracy,
            'consciousness': mean_consciousness
        }
        return accuracy, metrics

    def train(self, train_loader: DataLoader, val_loader: DataLoader):
        """Full training loop.

        Per epoch: train, evaluate, step the scheduler, then save the "best"
        checkpoint on improvement and a numbered checkpoint every
        ``save_interval`` epochs. Checkpoints are uploaded only on epochs that
        match ``checkpoint_upload_interval``. After the last epoch the best
        artifacts, TensorBoard logs and a run README are uploaded.

        The TensorBoard writer is closed even if training is interrupted.
        """
        print("\n" + "=" * 70)
        print("Starting training...")
        print(f"Format: SafeTensors (model) + PT (training state)")
        print(f"Upload: Best + every {self.config.checkpoint_upload_interval} epochs")
        print("=" * 70 + "\n")

        try:
            for epoch in range(self.config.num_epochs):
                # Train
                train_loss, train_acc = self.train_epoch(train_loader, epoch)
                # Evaluate
                val_acc, val_metrics = self.evaluate(val_loader, epoch)
                # Update scheduler
                self.scheduler.step()

                # Print summary
                print(f"\n{'='*70}")
                print(f"Epoch [{epoch + 1}/{self.config.num_epochs}] Summary:")
                print(f" Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
                print(f" Val: Loss={val_metrics['loss']:.4f}, Acc={val_acc:.2f}%")
                # `is not None` so that a legitimate value of 0.0 is still shown.
                if val_metrics['consciousness'] is not None:
                    print(f" Consciousness: {val_metrics['consciousness']:.4f}")

                # Checkpoint logic
                is_best = val_acc > self.best_acc
                should_save_regular = ((epoch + 1) % self.config.save_interval == 0)
                should_upload_regular = ((epoch + 1) % self.config.checkpoint_upload_interval == 0)
                if is_best:
                    self.best_acc = val_acc
                    print(f" ✓ New best model! Accuracy: {val_acc:.2f}%")
                    # Save best locally, upload only on interval
                    self.save_checkpoint(epoch, val_acc, prefix="best", upload=should_upload_regular)
                if should_save_regular:
                    self.save_checkpoint(epoch, val_acc, prefix=f"epoch_{epoch+1}", upload=should_upload_regular)
                print(f" HF Uploads: {self.upload_count}")
                print(f"{'='*70}\n")

                # Flush TensorBoard
                if (epoch + 1) % 10 == 0:
                    self.writer.flush()

            # Training complete
            training_time = (time.time() - self.start_time) / 3600
            print("\n" + "=" * 70)
            print("Training Complete!")
            print(f"Best Validation Accuracy: {self.best_acc:.2f}%")
            print(f"Training Time: {training_time:.2f} hours")
            print(f"Total Uploads: {self.upload_count}")
            print("=" * 70)

            # Upload to HuggingFace
            if self.hf_uploader:
                self._upload_final_artifacts(training_time)
        finally:
            self.writer.close()

    def _upload_final_artifacts(self, training_time: float):
        """Upload the best checkpoint files, config, TensorBoard logs and run README."""
        # Always upload final best model
        print("\n[HF] Uploading final best model...")
        checkpoint_dir = self.config.checkpoint_dir
        final_files = [
            (checkpoint_dir / "best_model.safetensors", "checkpoints/best_model.safetensors"),
            (checkpoint_dir / "best_training_state.pt", "checkpoints/best_training_state.pt"),
            (checkpoint_dir / "best_metadata.json", "checkpoints/best_metadata.json"),
            (self.config.output_dir / "config.yaml", "config.yaml"),
        ]
        for local_path, repo_path in final_files:
            if local_path.exists():
                self.hf_uploader.upload_file(local_path, repo_path)

        print("[HF] Final upload: TensorBoard logs...")
        # Flush first so the uploaded event files contain the latest scalars.
        self.writer.flush()
        self.hf_uploader.upload_folder_contents(self.config.tensorboard_dir, "tensorboard")

        trainer_stats = {
            'total_params': sum(p.numel() for p in self.model.parameters()),
            'best_acc': self.best_acc,
            'training_time': training_time,
            'final_epoch': self.config.num_epochs,
            'batch_size': self.config.batch_size,
            'mixed_precision': self.use_amp
        }
        self.hf_uploader.create_model_card(trainer_stats)

    def save_checkpoint(self, epoch: int, accuracy: float, prefix: str = "checkpoint", upload: bool = False):
        """Save checkpoint as safetensors with selective upload.

        Three files are written to the checkpoint directory:
        ``<prefix>_model.safetensors`` (model weights),
        ``<prefix>_training_state.pt`` (optimizer, scheduler and, with AMP,
        scaler state) and ``<prefix>_metadata.json``.

        Args:
            epoch: zero-based epoch index stored in the metadata.
            accuracy: validation accuracy stored in the metadata.
            prefix: file-name prefix; ``"best"`` marks the best checkpoint.
            upload: also push the files to the Hub when an uploader exists.
        """
        checkpoint_dir = self.config.checkpoint_dir
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # 1. Save model weights as safetensors (SAFE!)
        model_path = checkpoint_dir / f"{prefix}_model.safetensors"
        save_file(self.model.state_dict(), str(model_path))

        # 2. Save optimizer/scheduler state separately (small .pt files)
        training_state = {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }
        if self.scaler is not None:
            training_state['scaler_state_dict'] = self.scaler.state_dict()
        training_state_path = checkpoint_dir / f"{prefix}_training_state.pt"
        torch.save(training_state, training_state_path)

        # 3. Save metadata as JSON
        metadata = {
            'epoch': epoch,
            'accuracy': accuracy,
            'best_accuracy': self.best_acc,
            'global_step': self.global_step,
            'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
        }
        metadata_path = checkpoint_dir / f"{prefix}_metadata.json"
        with open(metadata_path, 'w', encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        is_best = (prefix == "best")
        if is_best:
            print(f" 💾 Saved best: {prefix}_model.safetensors")
        else:
            # No newline: the upload status is appended to this line below.
            print(f" 💾 Saved: {prefix}_model.safetensors", end="")

        # Upload to HuggingFace
        if self.hf_uploader and upload:
            # Model weights (safetensors), training state (.pt), metadata (json)
            self.hf_uploader.upload_file(model_path, f"checkpoints/{prefix}_model.safetensors")
            self.hf_uploader.upload_file(training_state_path, f"checkpoints/{prefix}_training_state.pt")
            self.hf_uploader.upload_file(metadata_path, f"checkpoints/{prefix}_metadata.json")
            # Upload config (only for best)
            if is_best:
                config_path = self.config.output_dir / "config.yaml"
                if config_path.exists():
                    self.hf_uploader.upload_file(config_path, "config.yaml")
            self.upload_count += 1
            if not is_best:
                print(" → Uploaded to HF")
        else:
            if not is_best:
                print(" (local only)")
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Data Loading
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def get_data_loaders(config: CantorTrainingConfig) -> Tuple[DataLoader, DataLoader]:
    """Create the training and validation data loaders.

    Datasets are downloaded to ``./data`` if missing. The training split is
    shuffled; the validation split uses the dataset's ``train=False`` split
    and is never augmented.

    Args:
        config: supplies ``dataset``, augmentation flags, ``image_size``,
            ``batch_size``, ``num_workers`` and ``device``.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``config.dataset`` is neither "cifar10" nor "cifar100".
    """
    dataset_classes = {
        "cifar10": datasets.CIFAR10,
        "cifar100": datasets.CIFAR100,
    }
    if config.dataset not in dataset_classes:
        raise ValueError(f"Unknown dataset: {config.dataset}")
    dataset_cls = dataset_classes[config.dataset]

    # Normalization (same for both datasets)
    # NOTE(review): these look like the commonly quoted CIFAR-10 channel
    # statistics and are applied to CIFAR-100 as well - confirm that is intended.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    # Augmentation (training split only)
    if not config.use_augmentation:
        augmentations = []
    elif config.use_autoaugment:
        augmentations = [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10)]
    else:
        augmentations = [
            # Crop back to the configured input size rather than a fixed 32.
            transforms.RandomCrop(config.image_size, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    train_transform = transforms.Compose(
        augmentations + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = dataset_cls(root='./data', train=True, download=True, transform=train_transform)
    val_dataset = dataset_cls(root='./data', train=False, download=True, transform=val_transform)

    # Pin host memory for any CUDA device string ("cuda", "cuda:0", ...).
    pin_memory = torch.device(config.device).type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory
    )
    return train_loader, val_loader
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Main
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def main():
    """Main training function.

    Builds the run configuration, prints a summary, creates the data loaders
    and runs the trainer. Edit the keyword arguments below to switch dataset,
    architecture or upload behaviour.
    """
    config = CantorTrainingConfig(
        # Dataset: "cifar10" or "cifar100"
        dataset="cifar100",
        # Architecture
        embed_dim=512,
        num_fusion_blocks=6,
        num_heads=8,
        fusion_mode="consciousness",  # "weighted" or "consciousness"
        k_simplex=4,
        use_beatrix=False,
        # Training
        batch_size=128,
        num_epochs=100,
        learning_rate=3e-4,
        # Augmentation
        use_augmentation=True,
        use_autoaugment=True,
        # System: fall back to CPU instead of crashing on hosts without CUDA
        device="cuda" if torch.cuda.is_available() else "cpu",
        # HuggingFace - ONE SHARED REPO
        hf_username="AbstractPhil",
        upload_to_hf=True,
    )

    print("=" * 70)
    print(f"Cantor Fusion Classifier - {config.dataset.upper()}")
    print("=" * 70)
    print(f"\nConfiguration:")
    print(f" Dataset: {config.dataset}")
    print(f" Fusion mode: {config.fusion_mode}")
    print(f" Output: {config.output_dir}")
    print(f" HuggingFace: {'Enabled' if config.upload_to_hf else 'Disabled'}")
    if config.upload_to_hf:
        print(f" Repo: {config.hf_username}/{config.hf_repo_name}")
        print(f" Run: {config.run_name}")

    # Load data
    print("\nLoading data...")
    train_loader, val_loader = get_data_loaders(config)
    print(f" Train: {len(train_loader.dataset)} samples")
    print(f" Val: {len(val_loader.dataset)} samples")

    # Train
    trainer = Trainer(config)
    trainer.train(train_loader, val_loader)


if __name__ == "__main__":
    main()