| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from collections import defaultdict |
| | import numpy as np |
| | from tqdm import tqdm |
| | import matplotlib.pyplot as plt |
| |
|
def _safe_corr(xs, ys):
    """Return the Pearson correlation between *xs* and *ys*, or NaN if undefined.

    ``np.corrcoef`` emits a RuntimeWarning (and yields NaN) when an input has
    fewer than two points or zero variance -- e.g. every geometric alignment is
    0 because the model never returned ``'geometric_alignments'``. This helper
    returns NaN quietly in those cases so the printed report stays readable.
    """
    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    if xs.size < 2 or xs.std() == 0 or ys.std() == 0:
        return float('nan')
    return float(np.corrcoef(xs, ys)[0, 1])


def _print_class_table(title, rows):
    """Print *title* and a Class / Acc% / Conf / GeoAlign / CrystalNorm table of *rows*."""
    print(f"\n{title}:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-"*70)
    for r in rows:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")


def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Properly evaluate PentachoraViT model.

    Runs ``model`` over ``test_loader`` in eval mode without gradients, then
    prints:

    - overall accuracy, plus auxiliary-head accuracy when the model returns
      ``'aux_logits'``;
    - the 10 best and 10 worst classes;
    - correlations between per-class accuracy and crystal statistics;
    - the full per-class table.

    It also shows a 2x2 diagnostic figure.

    Args:
        model: Module whose forward returns a dict containing ``'logits'``
            (and optionally ``'aux_logits'`` / ``'geometric_alignments'``). It
            must expose ``dim``, ``vocab_dim`` and ``num_classes``, and
            ``cls_tokens.class_pentachora`` for the crystal statistics.
        test_loader: Iterable of ``(images, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(class_results, overall_acc)``. ``class_results`` is a list of
        per-class dicts sorted by accuracy (descending); ``overall_acc`` is a
        percentage. Returns ``None`` if the model exposes no
        ``cls_tokens.class_pentachora``.
    """
    model.eval()

    # Index i of this list is assumed to be the label id i.
    class_names = get_cifar100_class_names()

    print(f"Model Configuration:")
    print(f" Internal dim: {model.dim}")
    print(f" Vocab dim: {model.vocab_dim}")
    print(f" Num classes: {model.num_classes}")

    if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
        crystals = model.cls_tokens.class_pentachora
        print(f" Crystal shape: {crystals.shape}")
    else:
        print(" No crystals found!")
        return None

    all_predictions = []
    all_targets = []
    all_confidences = []
    geometric_alignments_by_class = defaultdict(list)
    aux_predictions = []

    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)

            probs = F.softmax(outputs['logits'], dim=1)
            confidence, predicted = torch.max(probs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(confidence.cpu().numpy())

            if 'aux_logits' in outputs:
                aux_probs = F.softmax(outputs['aux_logits'], dim=1)
                _, aux_pred = torch.max(aux_probs, 1)
                aux_predictions.extend(aux_pred.cpu().numpy())

            if 'geometric_alignments' in outputs:
                # After averaging over dim 1 the tensor is indexed as
                # [sample, class]. Keep each sample's score for its own true
                # class. A single gather replaces a per-sample .item() loop,
                # which forced one device sync per sample.
                geo_align_mean = outputs['geometric_alignments'].mean(dim=1)
                index = targets.to(geo_align_mean.device).unsqueeze(1)
                true_class_align = geo_align_mean.gather(1, index).squeeze(1)
                for class_idx, value in zip(targets.tolist(), true_class_align.tolist()):
                    geometric_alignments_by_class[class_idx].append(value)

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)

    # Detach and move the crystals once instead of once per class.
    crystals_cpu = crystals.detach().cpu()

    class_results = []
    for class_idx, class_name in enumerate(class_names):
        mask = all_targets == class_idx
        if mask.sum() == 0:
            continue  # class absent from this test set

        correct = (all_predictions[mask] == class_idx).sum()
        total = mask.sum()
        accuracy = 100.0 * correct / total

        class_conf = all_confidences[mask].mean()

        alignments = geometric_alignments_by_class[class_idx]
        geo_align = np.mean(alignments) if alignments else 0

        class_crystal = crystals_cpu[class_idx]
        # Spread of the crystal's vertices along dim 0, averaged over features.
        vertex_variance = class_crystal.var(dim=0).mean().item()
        # Mean L2 norm of the crystal's rows (last-dim vectors).
        crystal_norm = class_crystal.norm(dim=-1).mean().item()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_name,
            'accuracy': accuracy,
            'correct': int(correct),
            'total': int(total),
            'avg_confidence': class_conf,
            'geometric_alignment': geo_align,
            'vertex_variance': vertex_variance,
            'crystal_norm': crystal_norm
        })

    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    overall_acc = 100.0 * (all_predictions == all_targets).mean()

    # Aux accuracy is only meaningful if every batch produced aux logits;
    # otherwise the arrays are misaligned and the comparison is invalid.
    aux_acc = None
    if aux_predictions:
        aux_predictions = np.array(aux_predictions)
        if len(aux_predictions) == len(all_targets):
            aux_acc = 100.0 * (aux_predictions == all_targets).mean()
        else:
            print("Warning: aux_logits missing for some batches; "
                  "skipping auxiliary accuracy.")

    print(f"\n" + "="*80)
    print(f"EVALUATION RESULTS")
    print(f"="*80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    # `is not None` so that a genuine 0.00% aux accuracy is still reported.
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    top10 = class_results[:10]
    bottom10 = class_results[-10:]
    _print_class_table("Top 10 Classes", top10)
    _print_class_table("Bottom 10 Classes", bottom10)

    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]

    print(f"\nCorrelations with Accuracy:")
    print(f" Geometric Alignment: {_safe_corr(accuracies, geo_aligns):.3f}")
    print(f" Crystal Norm: {_safe_corr(accuracies, crystal_norms):.3f}")
    print(f" Vertex Variance: {_safe_corr(accuracies, vertex_vars):.3f}")

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Panel 1: distribution of per-class accuracy.
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 2: accuracy vs geometric alignment, colored by crystal norm.
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # Panel 3: accuracy vs crystal norm.
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # Panel 4: side-by-side bars. Positions are sized from the data so this
    # also works when fewer than 10 classes have test samples.
    ax = axes[1, 1]
    top10_acc = [r['accuracy'] for r in top10]
    bottom10_acc = [r['accuracy'] for r in bottom10]
    width = 0.35
    ax.bar(np.arange(len(top10_acc)) - width/2, top10_acc, width,
           label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(np.arange(len(bottom10_acc)) + width/2, bottom10_acc, width,
           label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"\n{'='*90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'='*90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)

    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.8f}")

    return class_results, overall_acc
| |
|
| | |
# Notebook-style driver: when a trained `model` exists in this namespace,
# evaluate it on the CIFAR-100 test split, then list pairs of poorly
# performing classes whose (vertex-averaged) crystals point in similar
# directions.
if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)

    evaluation = evaluate_pentachora_vit(model, test_loader, device)
    if evaluation is None:
        # evaluate_pentachora_vit returns None when the model has no class
        # crystals; unpacking that directly would raise TypeError.
        print("Skipping crystal geometry analysis: model has no class crystals.")
    else:
        results, overall = evaluation

        print("\nCrystal Geometry Analysis:")
        print("-"*50)

        crystals = model.cls_tokens.class_pentachora.detach().cpu()

        # Collapse each class's crystal to one vector by averaging over dim 1,
        # L2-normalize, and take pairwise cosine similarities (rows/cols are
        # class indices).
        crystals_flat = crystals.mean(dim=1)
        crystals_norm = F.normalize(crystals_flat, dim=1)
        similarities = torch.matmul(crystals_norm, crystals_norm.T)

        print("\nMost similar classes with poor performance:")
        # Pre-filter to classes under 20% accuracy. This keeps the pair order
        # of the accuracy-sorted `results` and no longer assumes exactly 100
        # entries: classes absent from the test set are omitted from
        # `results`, so a hard-coded range(100) could raise IndexError.
        poor = [r for r in results if r['accuracy'] < 20]
        for i, a in enumerate(poor):
            for b in poor[i + 1:]:
                sim = similarities[a['class_idx'], b['class_idx']].item()
                if sim > 0.5:
                    print(f" {a['class_name']:<15} ({a['accuracy']:.1f}%) ↔ "
                          f"{b['class_name']:<15} ({b['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")