# ============================================
# PentachoraViT CIFAR-100 Evaluation
# ============================================
import itertools
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Number of best / worst classes shown in the ranked tables and the comparison bar chart.
_TOP_K = 10


def _concat(chunks):
    """Concatenate a list of 1-D numpy arrays; return an empty array when there are none."""
    return np.concatenate(chunks) if chunks else np.array([])


def _collect_outputs(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and gather predictions and alignment scores.

    Returns:
        A tuple ``(predictions, targets, confidences, aux_predictions, alignments_by_class)``.
        The first four are 1-D numpy arrays; ``aux_predictions`` has one entry per sample of
        every batch whose outputs contained ``'aux_logits'`` (empty if none did).
        ``alignments_by_class`` maps a true class index to the list of patch-averaged
        alignment scores of its samples for that class.
    """
    pred_chunks, target_chunks, conf_chunks, aux_chunks = [], [], [], []
    alignments_by_class = defaultdict(list)

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            # The model returns a dictionary of outputs.
            outputs = model(images)

            # Main predictions from the primary head.
            probs = F.softmax(outputs['logits'], dim=1)  # [batch, num_classes]
            confidence, predicted = torch.max(probs, 1)
            pred_chunks.append(predicted.cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
            conf_chunks.append(confidence.cpu().numpy())

            # Auxiliary head: softmax is monotonic, so the argmax of the raw logits is the
            # predicted class (no need to materialize probabilities).
            if 'aux_logits' in outputs:
                aux_chunks.append(outputs['aux_logits'].argmax(dim=1).cpu().numpy())

            # Geometric alignments show how patches align with the class crystals.
            if 'geometric_alignments' in outputs:
                # assumed shape [batch, num_patches, num_classes] — average over patches
                per_sample = outputs['geometric_alignments'].mean(dim=1)  # [batch, num_classes]
                # Pick every sample's true-class score with one gather instead of a
                # per-sample .item() (which forces a device sync each time).
                true_scores = per_sample.gather(1, targets.long().unsqueeze(1)).squeeze(1)
                for class_idx, score in zip(targets.tolist(), true_scores.tolist()):
                    alignments_by_class[class_idx].append(score)

    return (
        _concat(pred_chunks),
        _concat(target_chunks),
        _concat(conf_chunks),
        _concat(aux_chunks),
        alignments_by_class,
    )


def _per_class_metrics(class_names, crystals, predictions, targets, confidences,
                       alignments_by_class):
    """Return one metrics dict per class that has test samples, sorted by accuracy (desc)."""
    class_results = []
    for class_idx, class_name in enumerate(class_names):
        mask = targets == class_idx
        total = int(mask.sum())
        if total == 0:
            continue
        correct = int((predictions[mask] == class_idx).sum())
        scores = alignments_by_class.get(class_idx)

        # Crystal statistics; indexing below assumes crystals is [num_classes, 5, vocab_dim].
        class_crystal = crystals[class_idx].detach().cpu()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_name,
            'accuracy': 100.0 * correct / total,
            'correct': correct,
            'total': total,
            'avg_confidence': confidences[mask].mean(),
            'geometric_alignment': np.mean(scores) if scores else 0,
            # Per-dimension variance across the vertices, averaged over dimensions.
            'vertex_variance': class_crystal.var(dim=0).mean().item(),
            # Average vertex magnitude.
            'crystal_norm': class_crystal.norm(dim=-1).mean().item(),
        })

    class_results.sort(key=lambda r: r['accuracy'], reverse=True)
    return class_results


def _print_rank_table(title, rows):
    """Print a titled table of class accuracy / confidence / alignment / crystal norm."""
    print(f"\n{title}:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in rows:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")


def _safe_corrcoef(a, b):
    """Pearson correlation of ``a`` and ``b``; NaN (without a RuntimeWarning) when undefined."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # corrcoef divides by the standard deviations, so constant or too-short inputs are undefined.
    if a.size < 2 or a.std() == 0 or b.std() == 0:
        return float('nan')
    return float(np.corrcoef(a, b)[0, 1])


def _plot_results(class_results, overall_acc):
    """Draw the 2x2 diagnostic figure: accuracy histogram, two scatters, top vs bottom bars."""
    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]

    _, axes = plt.subplots(2, 2, figsize=(12, 10))

    # 1. Accuracy distribution
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Accuracy vs Geometric Alignment
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # 3. Crystal Analysis
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # 4. Top/Bottom comparison. Use at most as many bars as there are classes so the
    #    chart still works when fewer than _TOP_K classes were evaluated. The explicit
    #    start index avoids the `[-0:]` pitfall (which would select the whole list).
    ax = axes[1, 1]
    k = min(_TOP_K, len(class_results))
    top_acc = [r['accuracy'] for r in class_results[:k]]
    bottom_acc = [r['accuracy'] for r in class_results[len(class_results) - k:]]
    x = np.arange(k)
    width = 0.35
    ax.bar(x - width / 2, top_acc, width, label=f'Top {_TOP_K} Accuracy', color='green', alpha=0.7)
    ax.bar(x + width / 2, bottom_acc, width, label=f'Bottom {_TOP_K} Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title(f'Top {_TOP_K} vs Bottom {_TOP_K} Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


def _print_full_spectrum(class_results):
    """Print every evaluated class, ordered by class index for run-to-run comparability."""
    print(f"\n{'='*90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'='*90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} "
          f"{'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)
    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.8f}")


def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Evaluate a PentachoraViT model and report per-class accuracy and crystal diagnostics.

    Prints the model configuration, overall and auxiliary-head accuracy, the best/worst
    classes, correlations of accuracy with alignment / crystal statistics, shows a 2x2
    matplotlib figure, and prints the full per-class spectrum.

    Args:
        model: Module exposing ``dim``, ``vocab_dim``, ``num_classes`` and
            ``cls_tokens.class_pentachora`` (the class crystals). Calling ``model(images)``
            must return a dict with ``'logits'`` and optionally ``'aux_logits'`` and
            ``'geometric_alignments'``.
        test_loader: Iterable of ``(images, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(class_results, overall_acc)``. ``class_results`` is a list of per-class dicts
        sorted by accuracy (descending); ``overall_acc`` is a percentage. Returns ``None``
        if the model has no class crystals or the loader yielded no samples.
    """
    model.eval()

    class_names = get_cifar100_class_names()

    print("Model Configuration:")
    print(f" Internal dim: {model.dim}")
    print(f" Vocab dim: {model.vocab_dim}")
    print(f" Num classes: {model.num_classes}")

    crystals = getattr(getattr(model, 'cls_tokens', None), 'class_pentachora', None)
    if crystals is None:
        print(" No crystals found!")
        return None
    print(f" Crystal shape: {crystals.shape}")

    predictions, targets, confidences, aux_predictions, alignments_by_class = (
        _collect_outputs(model, test_loader, device)
    )
    if targets.size == 0:
        print("No test samples were evaluated!")
        return None

    class_results = _per_class_metrics(
        class_names, crystals, predictions, targets, confidences, alignments_by_class
    )
    overall_acc = 100.0 * (predictions == targets).mean()

    # Only score the auxiliary head when it produced a prediction for every sample;
    # otherwise the arrays cannot be compared element-wise.
    aux_acc = None
    if aux_predictions.size and aux_predictions.shape == targets.shape:
        aux_acc = 100.0 * (aux_predictions == targets).mean()

    print("\n" + "=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    # `is not None` so that a genuine 0.0% auxiliary accuracy is still reported.
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    _print_rank_table(f"Top {_TOP_K} Classes", class_results[:_TOP_K])
    _print_rank_table(f"Bottom {_TOP_K} Classes", class_results[-_TOP_K:])

    accuracies = [r['accuracy'] for r in class_results]
    print("\nCorrelations with Accuracy:")
    for label, key in (("Geometric Alignment", 'geometric_alignment'),
                       ("Crystal Norm", 'crystal_norm'),
                       ("Vertex Variance", 'vertex_variance')):
        values = [r[key] for r in class_results]
        print(f" {label}: {_safe_corrcoef(accuracies, values):.3f}")

    _plot_results(class_results, overall_acc)
    _print_full_spectrum(class_results)

    return class_results, overall_acc


# Run evaluation
if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)
    evaluation = evaluate_pentachora_vit(model, test_loader, device)

    if evaluation is None:
        # The evaluator already printed why (no crystals / no samples).
        print("Evaluation produced no results; skipping crystal analysis.")
    else:
        results, overall = evaluation

        # Additional crystal analysis
        print("\nCrystal Geometry Analysis:")
        print("-" * 50)

        crystals = model.cls_tokens.class_pentachora.detach().cpu()

        # Pairwise cosine similarities between class crystals (vertex-averaged).
        crystals_flat = crystals.mean(dim=1)  # Average over 5 vertices
        crystals_norm = F.normalize(crystals_flat, dim=1)
        similarities = torch.matmul(crystals_norm, crystals_norm.T)

        # Find confused pairs (high similarity, both low accuracy). Iterate over the
        # evaluated results themselves rather than range(100): classes with no samples
        # are absent from `results`, so fixed indexing could run past its end.
        print("\nMost similar classes with poor performance:")
        poor = [r for r in results if r['accuracy'] < 20]
        for a, b in itertools.combinations(poor, 2):
            sim = similarities[a['class_idx'], b['class_idx']].item()
            if sim > 0.5:
                print(f" {a['class_name']:<15} ({a['accuracy']:.1f}%) ↔ "
                      f"{b['class_name']:<15} ({b['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")