| |
| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| from collections import defaultdict |
| import numpy as np |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
|
|
def _concat(parts):
    """Concatenate a list of 1-D numpy arrays; return an empty array when there are none."""
    return np.concatenate(parts) if parts else np.empty(0)


def _safe_corrcoef(a, b):
    """Return the Pearson correlation of ``a`` and ``b``, or NaN when it is undefined.

    ``np.corrcoef`` emits a RuntimeWarning and yields NaN when there are fewer
    than two points or either input is constant (e.g. every geometric-alignment
    score is the 0 fallback). This returns NaN explicitly in those cases.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # max == min is an exact constancy test (std can be a tiny nonzero from rounding).
    if a.size < 2 or a.max() == a.min() or b.max() == b.min():
        return float('nan')
    return float(np.corrcoef(a, b)[0, 1])


def _collect_outputs(model, test_loader, device):
    """Run ``model`` over ``test_loader`` without gradients and gather per-sample results.

    Returns a dict with numpy arrays ``predictions``, ``targets``, ``confidences``,
    ``aux_predictions`` and ``aux_targets``, plus ``alignments_by_class``
    (a defaultdict mapping class index -> list of float scores).
    """
    predictions, all_targets, confidences = [], [], []
    aux_predictions, aux_targets = [], []
    alignments_by_class = defaultdict(list)

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)

            probs = F.softmax(outputs['logits'], dim=1)
            confidence, predicted = torch.max(probs, 1)

            predictions.append(predicted.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
            confidences.append(confidence.cpu().numpy())

            if 'aux_logits' in outputs:
                # Store the matching targets with the aux predictions so the
                # auxiliary accuracy stays aligned even if only some batches
                # produce auxiliary logits.
                aux_pred = outputs['aux_logits'].argmax(dim=1)
                aux_predictions.append(aux_pred.cpu().numpy())
                aux_targets.append(targets.cpu().numpy())

            if 'geometric_alignments' in outputs:
                # NOTE(review): the indexing below implies the tensor is
                # (batch, ?, num_classes); dim 1 is averaged away — confirm
                # what that axis represents.
                geo_mean = outputs['geometric_alignments'].mean(dim=1)
                rows = torch.arange(targets.shape[0], device=geo_mean.device)
                # One gather per batch instead of a device sync per sample.
                picked = geo_mean[rows, targets.to(geo_mean.device)]
                for class_idx, score in zip(targets.tolist(), picked.tolist()):
                    alignments_by_class[class_idx].append(score)

    return {
        'predictions': _concat(predictions),
        'targets': _concat(all_targets),
        'confidences': _concat(confidences),
        'aux_predictions': _concat(aux_predictions),
        'aux_targets': _concat(aux_targets),
        'alignments_by_class': alignments_by_class,
    }


def _build_class_results(class_names, predictions, targets, confidences,
                         alignments_by_class, crystals):
    """Compute per-class accuracy, confidence, alignment and crystal statistics.

    Classes with no test samples are skipped. Returns a list of dicts in
    class-index order.
    """
    crystals_cpu = crystals.detach().cpu()  # detach/copy once, not per class
    class_results = []
    for class_idx in range(len(class_names)):
        mask = targets == class_idx
        total = int(mask.sum())
        if total == 0:
            continue

        correct = int((predictions[mask] == class_idx).sum())
        scores = alignments_by_class.get(class_idx)
        class_crystal = crystals_cpu[class_idx]

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_names[class_idx],
            'accuracy': 100.0 * correct / total,
            'correct': correct,
            'total': total,
            'avg_confidence': confidences[mask].mean(),
            # 0 when the model never emitted geometric alignments for this class.
            'geometric_alignment': np.mean(scores) if scores else 0,
            # Variance across dim 0 of this class's crystal, averaged over the rest.
            'vertex_variance': class_crystal.var(dim=0).mean().item(),
            # Mean L2 norm taken along the last dim.
            'crystal_norm': class_crystal.norm(dim=-1).mean().item(),
        })
    return class_results


def _print_ranked_table(title, rows):
    """Print a titled accuracy/confidence/alignment/norm table for ``rows``."""
    print(f"\n{title}:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in rows:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")


def _plot_class_results(class_results, overall_acc):
    """Show a 2x2 figure summarising per-class accuracy against crystal metrics."""
    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    fig.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    ax = axes[1, 1]
    top_acc = [r['accuracy'] for r in class_results[:10]]
    bottom_acc = [r['accuracy'] for r in class_results[-10:]]
    # Size the bars to the actual group length; a fixed arange(10) raises a
    # shape mismatch when fewer than 10 classes have test samples.
    x = np.arange(len(top_acc))
    width = 0.35
    ax.bar(x - width / 2, top_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(x + width / 2, bottom_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()


def _print_full_spectrum(class_results):
    """Print every evaluated class in class-index order with all its metrics."""
    print(f"\n{'='*90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'='*90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)
    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.8f}")


def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Evaluate a PentachoraViT model and report per-class statistics.

    Prints the model configuration, the overall accuracy, and the auxiliary-head
    accuracy when the model produces ``aux_logits``. It also prints the top- and
    bottom-10 class tables, the correlations of per-class accuracy with crystal
    metrics, and a full per-class table, and it shows a 2x2 matplotlib summary
    figure.

    Args:
        model: Module exposing ``dim``, ``vocab_dim`` and ``num_classes``.
            Evaluation proceeds only if it also exposes
            ``cls_tokens.class_pentachora``. Its forward pass must return a dict
            with ``'logits'`` and, optionally, ``'aux_logits'`` and
            ``'geometric_alignments'``.
        test_loader: Iterable of ``(images, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(class_results, overall_acc)``. ``class_results`` is a list of
        per-class dicts sorted by accuracy (descending), and ``overall_acc`` is
        a percentage. Returns ``None`` when the model has no class crystals.
    """
    model.eval()

    # Defined elsewhere in this file/notebook.
    class_names = get_cifar100_class_names()

    print("Model Configuration:")
    print(f" Internal dim: {model.dim}")
    print(f" Vocab dim: {model.vocab_dim}")
    print(f" Num classes: {model.num_classes}")

    if not (hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora')):
        print(" No crystals found!")
        return None
    crystals = model.cls_tokens.class_pentachora
    print(f" Crystal shape: {crystals.shape}")

    collected = _collect_outputs(model, test_loader, device)
    predictions = collected['predictions']
    targets = collected['targets']

    class_results = _build_class_results(
        class_names, predictions, targets, collected['confidences'],
        collected['alignments_by_class'], crystals,
    )
    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    overall_acc = 100.0 * (predictions == targets).mean() if targets.size else float('nan')

    aux_acc = None
    aux_targets = collected['aux_targets']
    if aux_targets.size:
        aux_acc = 100.0 * (collected['aux_predictions'] == aux_targets).mean()

    print("\n" + "=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    # `is not None` so a genuine 0% auxiliary accuracy is still reported.
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    _print_ranked_table("Top 10 Classes", class_results[:10])
    _print_ranked_table("Bottom 10 Classes", class_results[-10:])

    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]

    print("\nCorrelations with Accuracy:")
    print(f" Geometric Alignment: {_safe_corrcoef(accuracies, geo_aligns):.3f}")
    print(f" Crystal Norm: {_safe_corrcoef(accuracies, crystal_norms):.3f}")
    print(f" Vertex Variance: {_safe_corrcoef(accuracies, vertex_vars):.3f}")

    _plot_class_results(class_results, overall_acc)
    _print_full_spectrum(class_results)

    return class_results, overall_acc
|
|
| |
if 'model' in globals():
    # `model`, `get_cifar100_dataloaders` and `evaluate_pentachora_vit` are
    # expected to be defined earlier in this file/notebook session.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)

    evaluation = evaluate_pentachora_vit(model, test_loader, device)

    # evaluate_pentachora_vit returns None when the model has no class crystals;
    # unpacking that directly would raise TypeError.
    if evaluation is not None:
        results, overall = evaluation

        print("\nCrystal Geometry Analysis:")
        print("-"*50)

        crystals = model.cls_tokens.class_pentachora.detach().cpu()

        # Average over dim 1 (presumably the per-class vertex axis — confirm) to
        # get one vector per class, then build the class-to-class cosine
        # similarity matrix.
        crystals_flat = crystals.mean(dim=1)
        crystals_norm = F.normalize(crystals_flat, dim=1)
        similarities = torch.matmul(crystals_norm, crystals_norm.T)

        print("\nMost similar classes with poor performance:")
        # `results` holds only the classes that appeared in the test set, sorted
        # by accuracy, so it may have fewer than 100 entries. Filter first and
        # walk the pairs in the same i<j order as before.
        poor = [r for r in results if r['accuracy'] < 20]
        for pos, first in enumerate(poor):
            for second in poor[pos + 1:]:
                sim = similarities[first['class_idx'], second['class_idx']].item()
                if sim > 0.5:
                    print(f" {first['class_name']:<15} ({first['accuracy']:.1f}%) ↔ "
                          f"{second['class_name']:<15} ({second['accuracy']:.1f}%) : {sim:.3f}")

else:
    print("No model found in memory!")