import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from huggingface_hub import HfApi, create_repo, upload_folder
from safetensors.torch import save_file, load_file
import os
import json
import hashlib
from datetime import datetime
from google.colab import userdata

# HuggingFace configuration (token is read from Colab secrets)
HF_TOKEN = userdata.get('HF_TOKEN')
REPO_ID = "AbstractPhil/penta-classifier-prototype"

# Unique run identifier derived from the config string and timestamp
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
config_str = f"emnist_byclass_b1024_lr1e-3_{run_timestamp}"
run_hash = hashlib.md5(config_str.encode()).hexdigest()[:8]

# Local output directories
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("tensorboard_logs", exist_ok=True)

# TensorBoard writer scoped to this run
writer = SummaryWriter(f'tensorboard_logs/{run_hash}')

# Create (or reuse) the HuggingFace model repo
api = HfApi()
try:
    create_repo(REPO_ID, repo_type="model", token=HF_TOKEN, exist_ok=True)
    print(f"Using HuggingFace repo: {REPO_ID}")
except Exception as e:
    print(f"Repo setup: {e}")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

# Training configuration
config = {
    "input_dim": 28 * 28,
    "base_dim": 64,
    "batch_size": 1024,
    "epochs": 5,
    "initial_lr": 1e-3,
    "temp_contrastive": 0.1,
    "lambda_contrastive": 0.5,
    "lambda_cayley": 0.01,
    "dataset": "EMNIST_byclass",
    "run_hash": run_hash,
    "timestamp": run_timestamp
}

# Persist the config next to the checkpoints and log it to TensorBoard
config_path = f"checkpoints/config_{run_hash}.json"
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)

writer.add_text('Config', json.dumps(config, indent=2), 0)

# Data: EMNIST byclass (62 classes), flattened to 784-dim vectors
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_dataset = datasets.EMNIST(root="./data", split='byclass', train=True, transform=transform, download=True)
test_dataset = datasets.EMNIST(root="./data", split='byclass', train=False, transform=transform, download=True)

num_classes = len(train_dataset.classes)
config["num_classes"] = num_classes

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], pin_memory=True,
                          shuffle=True, num_workers=4, prefetch_factor=8)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], pin_memory=True,
                         shuffle=False, num_workers=4, prefetch_factor=8)

print(f"Train: {len(train_dataset)} samples, Test: {len(test_dataset)} samples")
print(f"Classes: {num_classes}")
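
# Optional sanity check (illustrative, safe to remove): pull one batch and
# confirm the flatten transform yields [batch, 784] inputs.
_xb, _yb = next(iter(train_loader))
print(f"Sample batch: inputs {tuple(_xb.shape)}, targets {tuple(_yb.shape)}")
del _xb, _yb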

class AdaptiveEncoder(nn.Module):
    """Multi-layer encoder with normalization and multi-scale outputs"""
    def __init__(self, input_dim, base_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)

        # Multi-scale heads: the coarse head taps the mid-level features (h2),
        # while the medium and fine heads tap the deepest features (h3)
        self.fc_coarse = nn.Linear(256, base_dim // 4)
        self.fc_medium = nn.Linear(128, base_dim // 2)
        self.fc_fine = nn.Linear(128, base_dim)

        self.norm_coarse = nn.LayerNorm(base_dim // 4)
        self.norm_medium = nn.LayerNorm(base_dim // 2)
        self.norm_fine = nn.LayerNorm(base_dim)

    def forward(self, x):
        h1 = F.relu(self.bn1(self.fc1(x)))
        h1 = self.dropout1(h1)
        h2 = F.relu(self.bn2(self.fc2(h1)))
        h2 = self.dropout2(h2)
        h3 = F.relu(self.bn3(self.fc3(h2)))

        coarse = self.norm_coarse(self.fc_coarse(h2))
        medium = self.norm_medium(self.fc_medium(h3))
        fine = self.norm_fine(self.fc_fine(h3))

        return coarse, medium, fine
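
# Optional shape check (illustrative, safe to remove): for base_dim=64 the
# encoder should emit 16-, 32-, and 64-dim embeddings.
_enc = AdaptiveEncoder(28 * 28, base_dim=64).eval()
with torch.no_grad():
    _c, _m, _f = _enc(torch.randn(2, 28 * 28))
assert _c.shape == (2, 16) and _m.shape == (2, 32) and _f.shape == (2, 64)
del _enc, _c, _m, _f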

def init_perfect_pentachora(num_classes, latent_dim, device='cuda'):
    """Initialize class pentachora as regular 4-simplices, embedded in
    per-class subspaces when latent_dim allows, else via random projection."""
    pentachora = torch.zeros(num_classes, 5, latent_dim, device=device)

    sqrt15 = np.sqrt(15)
    sqrt30 = np.sqrt(30)
    sqrt10 = np.sqrt(10)

    # Vertices of a regular 4-simplex with unit circumradius: every vertex has
    # unit norm and every pair has dot product -1/4, so all edges are equal.
    simplex = torch.tensor([
        [1.0, 0.0, 0.0, 0.0],
        [-0.25, sqrt15 / 4, 0.0, 0.0],
        [-0.25, -sqrt15 / 12, sqrt30 / 6, 0.0],
        [-0.25, -sqrt15 / 12, -sqrt30 / 12, sqrt10 / 4],
        [-0.25, -sqrt15 / 12, -sqrt30 / 12, -sqrt10 / 4]
    ], dtype=torch.float32, device=device)

    dims_per_class = latent_dim // num_classes
    for c in range(num_classes):
        if dims_per_class >= 4:
            # Enough room: embed each class simplex in its own 4-dim subspace
            start = c * dims_per_class
            pentachora[c, :, start:start + 4] = simplex
        else:
            # Not enough dimensions for disjoint subspaces: project the simplex
            # through a random row-normalized 4 x latent_dim map
            rotation = F.normalize(torch.randn(4, latent_dim, device=device), dim=1)
            pentachora[c] = torch.mm(simplex, rotation)

    return nn.Parameter(pentachora * 2.0)
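
# Optional geometry check (illustrative, assumes the regular-simplex
# coordinates above): with one class and latent_dim=4 the raw simplex is
# embedded directly, so all 10 pairwise vertex distances should be equal.
_p = init_perfect_pentachora(1, 4, device='cpu').detach()[0]
_d = torch.cdist(_p, _p)
_edges = _d[torch.triu(torch.ones(5, 5), diagonal=1).bool()]
assert torch.allclose(_edges, _edges.mean().expand_as(_edges), atol=1e-4)
del _p, _d, _edges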

class PerfectPentachoron(nn.Module):
    """Multi-scale pentachoron classifier with learnable metric and vertex weights"""
    def __init__(self, num_classes, base_dim, device='cuda'):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.base_dim = base_dim

        # One pentachoron (5 vertices) per class at each scale
        self.penta_coarse = init_perfect_pentachora(num_classes, base_dim // 4, device)
        self.penta_medium = init_perfect_pentachora(num_classes, base_dim // 2, device)
        self.penta_fine = init_perfect_pentachora(num_classes, base_dim, device)

        # Learnable per-class weighting over the 5 vertices
        self.vertex_weights = nn.Parameter(torch.ones(num_classes, 5, device=device) / 5)

        # Learnable linear metrics (identity init = plain Euclidean distance)
        self.metric_coarse = nn.Parameter(torch.eye(base_dim // 4, device=device))
        self.metric_medium = nn.Parameter(torch.eye(base_dim // 2, device=device))
        self.metric_fine = nn.Parameter(torch.eye(base_dim, device=device))

        self.scale_weights = nn.Parameter(torch.tensor([0.2, 0.3, 0.5], device=device))

    def mahalanobis_distance(self, x, pentachora, metric):
        # Euclidean distance after a learned linear map, i.e. a Mahalanobis-style
        # distance under M^T M. Returns [batch, classes, 5].
        x_trans = torch.matmul(x, metric)
        p_trans = torch.einsum('cpd,de->cpe', pentachora, metric)
        diffs = p_trans.unsqueeze(0) - x_trans.unsqueeze(1).unsqueeze(2)
        return torch.norm(diffs, dim=-1)

    def forward(self, x_coarse, x_medium, x_fine):
        dists_c = self.mahalanobis_distance(x_coarse, self.penta_coarse, self.metric_coarse)
        dists_m = self.mahalanobis_distance(x_medium, self.penta_medium, self.metric_medium)
        dists_f = self.mahalanobis_distance(x_fine, self.penta_fine, self.metric_fine)

        # Weight vertex distances per class; score = negative weighted sum
        weights = F.softmax(self.vertex_weights, dim=1).unsqueeze(0)
        dists_c = dists_c * weights
        dists_m = dists_m * weights
        dists_f = dists_f * weights

        scores_c = -dists_c.sum(dim=-1)
        scores_m = -dists_m.sum(dim=-1)
        scores_f = -dists_f.sum(dim=-1)

        # Blend the three scales with learned softmax weights
        w = F.softmax(self.scale_weights, dim=0)
        scores = w[0] * scores_c + w[1] * scores_m + w[2] * scores_f

        return scores, (dists_c, dists_m, dists_f)

    def regularization_loss(self):
        # Penalize unequal edge lengths (keeps simplices regular) and edges
        # shorter than 0.5 (keeps vertices from collapsing), at all three scales
        mask = torch.triu(torch.ones(5, 5, device=self.device), diagonal=1).bool()

        diffs_c = self.penta_coarse.unsqueeze(2) - self.penta_coarse.unsqueeze(1)
        edges_c = torch.norm(diffs_c, dim=-1)[:, mask]

        diffs_m = self.penta_medium.unsqueeze(2) - self.penta_medium.unsqueeze(1)
        edges_m = torch.norm(diffs_m, dim=-1)[:, mask]

        diffs_f = self.penta_fine.unsqueeze(2) - self.penta_fine.unsqueeze(1)
        edges_f = torch.norm(diffs_f, dim=-1)[:, mask]

        all_edges = torch.stack([edges_c, edges_m, edges_f], dim=0)

        edge_var = torch.var(all_edges, dim=2).mean()
        min_edges = torch.min(all_edges, dim=2)[0]
        collapse_penalty = torch.relu(0.5 - min_edges).mean()

        return edge_var + collapse_penalty
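
# Optional forward-pass check (illustrative, safe to remove): scores should be
# [batch, num_classes] and the regularizer a finite scalar.
_cls = PerfectPentachoron(num_classes=3, base_dim=64, device='cpu')
_s, _ = _cls(torch.randn(2, 16), torch.randn(2, 32), torch.randn(2, 64))
assert _s.shape == (2, 3)
assert torch.isfinite(_cls.regularization_loss())
del _cls, _s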

def contrastive_pentachoron_loss_batched(latents, targets, pentachora, temp=0.1):
    """InfoNCE-style loss: pull each latent toward the nearest vertex of its
    class pentachoron and push it from all other classes."""
    num_classes = pentachora.size(0)

    # Distance from each latent to every vertex of every class pentachoron,
    # reduced to the nearest vertex per class: [batch, classes]
    diffs = latents.unsqueeze(1).unsqueeze(2) - pentachora.unsqueeze(0)
    dists = torch.norm(diffs, dim=-1)
    min_dists, _ = torch.min(dists, dim=2)

    # Softmax over classes with numerical stabilization
    sims = -min_dists / temp
    targets_one_hot = F.one_hot(targets, num_classes).float()

    max_sims, _ = torch.max(sims, dim=1, keepdim=True)
    exp_sims = torch.exp(sims - max_sims)

    pos_sims = torch.sum(exp_sims * targets_one_hot, dim=1)
    all_sims = torch.sum(exp_sims, dim=1)

    return -torch.log(pos_sims / all_sims).mean()
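
# Optional smoke test (illustrative, safe to remove): the loss should be a
# finite scalar on random inputs.
_l = contrastive_pentachoron_loss_batched(
    torch.randn(8, 16), torch.randint(0, 3, (8,)), torch.randn(3, 5, 16), temp=0.1)
assert _l.ndim == 0 and torch.isfinite(_l)
del _l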

# Instantiate models
encoder = AdaptiveEncoder(config["input_dim"], config["base_dim"]).to(device)
classifier = PerfectPentachoron(num_classes, config["base_dim"], device).to(device)

# Compile if available (PyTorch 2.x); fall back to eager mode otherwise
try:
    encoder = torch.compile(encoder)
    classifier = torch.compile(classifier)
    print("Models compiled successfully")
except Exception:
    print("Torch compile not available, using eager mode")

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': config["initial_lr"]},
    {'params': classifier.parameters(), 'lr': config["initial_lr"] * 0.5}
], weight_decay=1e-5)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["epochs"])

def save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False):
    """Save model weights as safetensors plus a .pt sidecar with training state."""
    # Unwrap torch.compile wrappers (if any) so checkpoint keys stay stable
    enc = getattr(encoder, '_orig_mod', encoder)
    cls = getattr(classifier, '_orig_mod', classifier)
    encoder_state = {f"encoder.{k}": v.cpu() for k, v in enc.state_dict().items()}
    classifier_state = {f"classifier.{k}": v.cpu() for k, v in cls.state_dict().items()}

    model_state = {**encoder_state, **classifier_state}

    checkpoint_name = f"checkpoint_{run_hash}_epoch_{epoch:03d}.safetensors"
    if is_best:
        checkpoint_name = f"best_{run_hash}.safetensors"

    checkpoint_path = os.path.join("checkpoints", checkpoint_name)
    save_file(model_state, checkpoint_path)

    # Optimizer/scheduler state can't go in safetensors; keep it in a .pt sidecar
    training_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'config': config
    }

    state_path = checkpoint_path.replace('.safetensors', '_state.pt')
    torch.save(training_state, state_path)

    print(f"Saved checkpoint: {checkpoint_name}")

    # Mirror checkpoints and TensorBoard logs to the HuggingFace repo
    try:
        upload_folder(
            folder_path="checkpoints",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"weights/{run_hash}",
            commit_message=f"Epoch {epoch} - Test Acc: {metrics['test_acc']:.4f}"
        )

        upload_folder(
            folder_path=f"tensorboard_logs/{run_hash}",
            repo_id=REPO_ID,
            repo_type="model",
            token=HF_TOKEN,
            path_in_repo=f"runs/{run_hash}",
            commit_message=f"TensorBoard logs - Epoch {epoch}"
        )
    except Exception as e:
        print(f"HF upload error: {e}")
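
# Hypothetical loader sketch (not part of the original flow): assumes the
# "encoder."/"classifier." key prefixes written by save_checkpoint above.
# Call it on unwrapped (non-compiled) modules.
def load_checkpoint(path, encoder, classifier):
    state = load_file(path)
    encoder.load_state_dict(
        {k[len("encoder."):]: v for k, v in state.items() if k.startswith("encoder.")})
    classifier.load_state_dict(
        {k[len("classifier."):]: v for k, v in state.items() if k.startswith("classifier.")})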

def train_epoch(epoch):
    encoder.train()
    classifier.train()

    total_loss = 0.0
    total_ce = 0.0
    total_contr = 0.0
    total_reg = 0.0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, all_dists = classifier(x_coarse, x_medium, x_fine)

        ce_loss = F.cross_entropy(scores, targets)

        contr_c = contrastive_pentachoron_loss_batched(x_coarse, targets, classifier.penta_coarse, config["temp_contrastive"])
        contr_m = contrastive_pentachoron_loss_batched(x_medium, targets, classifier.penta_medium, config["temp_contrastive"])
        contr_f = contrastive_pentachoron_loss_batched(x_fine, targets, classifier.penta_fine, config["temp_contrastive"])
        contr_loss = (contr_c + contr_m + contr_f) / 3

        reg_loss = classifier.regularization_loss()

        loss = ce_loss + config["lambda_contrastive"] * contr_loss + config["lambda_cayley"] * reg_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item() * inputs.size(0)
        total_ce += ce_loss.item() * inputs.size(0)
        total_contr += contr_loss.item() * inputs.size(0)
        total_reg += reg_loss.item() * inputs.size(0)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        if batch_idx % 50 == 0:
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/BatchLoss', loss.item(), global_step)
            writer.add_scalar('Train/BatchAcc', correct / total, global_step)

        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'acc': f"{correct/total:.4f}",
            'lr': f"{optimizer.param_groups[0]['lr']:.1e}"
        })

    return (total_loss / total, total_ce / total, total_contr / total,
            total_reg / total, correct / total)

@torch.no_grad()
def evaluate():
    encoder.eval()
    classifier.eval()

    correct = 0
    total = 0
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    pbar = tqdm(test_loader, desc="Evaluating")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        x_coarse, x_medium, x_fine = encoder(inputs)
        scores, _ = classifier(x_coarse, x_medium, x_fine)

        preds = scores.argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += inputs.size(0)

        for i in range(targets.size(0)):
            label = targets[i].item()
            class_total[label] += 1
            if preds[i] == targets[i]:
                class_correct[label] += 1

        pbar.set_postfix({'acc': f"{correct/total:.4f}"})

    class_accs = [class_correct[i] / max(1, class_total[i]) for i in range(num_classes)]
    return correct / total, class_accs

print("\n" + "="*60)
print(f"PERFECT PENTACHORON TRAINING - Run {run_hash}")
print("="*60 + "\n")

best_acc = 0.0
train_history = []
test_history = []
patience = 7  # early stopping only matters for runs longer than config["epochs"]
no_improve = 0

for epoch in range(config["epochs"]):
    train_loss, train_ce, train_contr, train_reg, train_acc = train_epoch(epoch)
    train_history.append(train_acc)

    test_acc, class_accs = evaluate()
    test_history.append(test_acc)

    # Epoch-level TensorBoard logging
    writer.add_scalar('Loss/Total', train_loss, epoch)
    writer.add_scalar('Loss/CE', train_ce, epoch)
    writer.add_scalar('Loss/Contrastive', train_contr, epoch)
    writer.add_scalar('Loss/Regularization', train_reg, epoch)
    writer.add_scalar('Accuracy/Train', train_acc, epoch)
    writer.add_scalar('Accuracy/Test', test_acc, epoch)
    writer.add_scalar('Learning/LR', optimizer.param_groups[0]['lr'], epoch)
    writer.add_scalar('Learning/Generalization_Gap', train_acc - test_acc, epoch)

    # Per-class accuracy for the first 10 classes
    for i, acc in enumerate(class_accs[:10]):
        writer.add_scalar(f'ClassAcc/Class_{i}', acc, epoch)

    # Learned scale-blend weights
    scale_weights = F.softmax(classifier.scale_weights, dim=0)
    writer.add_scalar('Scales/Coarse', scale_weights[0], epoch)
    writer.add_scalar('Scales/Medium', scale_weights[1], epoch)
    writer.add_scalar('Scales/Fine', scale_weights[2], epoch)

    scheduler.step()

    print(f"\n[Epoch {epoch+1}/{config['epochs']}]")
    print(f"Train | Loss: {train_loss:.4f} | CE: {train_ce:.4f} | "
          f"Contr: {train_contr:.4f} | Reg: {train_reg:.4f} | Acc: {train_acc:.4f}")
    print(f"Test  | Acc: {test_acc:.4f} | Best: {best_acc:.4f}")

    metrics = {
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_loss': train_loss,
        'class_accs': class_accs
    }

    # Checkpoint on improvement, plus a periodic checkpoint every 5 epochs
    if test_acc > best_acc:
        best_acc = test_acc
        no_improve = 0
        print("NEW BEST! Saving checkpoint...")
        save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=True)
    else:
        no_improve += 1
        if (epoch + 1) % 5 == 0:
            save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics)

    if no_improve >= patience:
        print(f"Early stopping triggered (no improvement for {patience} epochs)")
        break

print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"Best Test Accuracy: {best_acc:.4f}")
print(f"Final Train Accuracy: {train_history[-1]:.4f}")
print(f"Generalization Gap: {train_history[-1] - test_history[-1]:.4f}")

save_checkpoint(epoch, encoder, classifier, optimizer, scheduler, metrics, is_best=False)

# Inspect the learned geometry: scale importance and dominant vertices
with torch.no_grad():
    vertex_importance = F.softmax(classifier.vertex_weights, dim=1)
    scale_weights = F.softmax(classifier.scale_weights, dim=0).cpu().numpy()

    geometry_info = {
        'scale_importance': {
            'coarse': float(scale_weights[0]),
            'medium': float(scale_weights[1]),
            'fine': float(scale_weights[2])
        },
        'dominant_vertices': {}
    }

    for c in range(min(10, num_classes)):
        weights = vertex_importance[c].cpu().numpy()
        dominant = np.argmax(weights)
        geometry_info['dominant_vertices'][f'class_{c}'] = {
            'vertex': int(dominant),
            'weight': float(weights[dominant])
        }

    writer.add_text('Final_Geometry', json.dumps(geometry_info, indent=2), epoch)

writer.close()
print(f"\n✨ Training Complete! Run hash: {run_hash}")
print(f"Results uploaded to: https://huggingface.co/{REPO_ID}")
print(f"TensorBoard: tensorboard --logdir tensorboard_logs/{run_hash}")