#!/usr/bin/env python3
"""
Comprehensive Evaluation Runner for GAP-CLIP
============================================

This script runs all available evaluations on the GAP-CLIP model and generates
a comprehensive report with metrics, visualizations, and comparisons.

Usage:
    python run_all_evaluations.py [--repo-id REPO_ID] [--output OUTPUT_DIR]

Features:
- Runs all evaluation scripts
- Generates summary report
- Creates visualizations
- Compares with baseline models
- Saves results to organized directory

Author: Lea Attia Sarfati
"""
import os
import sys
import json
import argparse
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Import evaluation modules
try:
    from evaluation.main_model_evaluation import (
        evaluate_fashion_mnist,
        evaluate_kaggle_marqo,
        evaluate_local_validation
    )
    from example_usage import load_models_from_hf
except ImportError as e:
    print(f"⚠️ Import error: {e}")
    print("Make sure you're running from the correct directory")
    sys.exit(1)
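
# Note: the imports above assume this file lives in evaluation/ and that
# example_usage.py sits in the repository root (the parent directory added
# to sys.path above); adjust the path handling if your layout differs.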

class EvaluationRunner:
    """
    Comprehensive evaluation runner for GAP-CLIP.
    Runs all available evaluations and generates a summary report.
    """

    def __init__(self, repo_id: str, output_dir: str = "evaluation_results"):
        """
        Initialize the evaluation runner.

        Args:
            repo_id: Hugging Face repository ID
            output_dir: Directory to save results
        """
        self.repo_id = repo_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Create timestamp for this run
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = self.output_dir / f"run_{self.timestamp}"
        self.run_dir.mkdir(exist_ok=True)

        self.results = {}
        self.models = None
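
    # Each run writes its artifacts under <output_dir>/run_<timestamp>/, e.g.
    # (illustrative layout, assuming the default output directory):
    #   evaluation_results/
    #       run_20250101_120000/
    #           summary.json
    #           summary_comparison.png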

    def load_models(self):
        """Load models from Hugging Face."""
        print("=" * 80)
        print("📥 Loading Models")
        print("=" * 80)

        try:
            self.models = load_models_from_hf(self.repo_id)
            print("✅ Models loaded successfully\n")
            return True
        except Exception as e:
            print(f"❌ Failed to load models: {e}\n")
            return False
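
    # Note: the dict returned by load_models_from_hf is expected to provide at
    # least 'main_model', 'processor', and 'device', since the evaluation
    # methods below read those keys.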

    def run_fashion_mnist_evaluation(self):
        """Run Fashion-MNIST evaluation."""
        print("\n" + "=" * 80)
        print("👕 Fashion-MNIST Evaluation")
        print("=" * 80)

        try:
            results = evaluate_fashion_mnist(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['fashion_mnist'] = results
            print("✅ Fashion-MNIST evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Fashion-MNIST evaluation failed: {e}")
            return None

    def run_kaggle_evaluation(self):
        """Run Kaggle Marqo evaluation."""
        print("\n" + "=" * 80)
        print("🛍️ Kaggle Marqo Evaluation")
        print("=" * 80)

        try:
            results = evaluate_kaggle_marqo(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['kaggle_marqo'] = results
            print("✅ Kaggle Marqo evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Kaggle Marqo evaluation failed: {e}")
            return None

    def run_local_evaluation(self):
        """Run local validation evaluation."""
        print("\n" + "=" * 80)
        print("📁 Local Validation Evaluation")
        print("=" * 80)

        try:
            results = evaluate_local_validation(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['local_validation'] = results
            print("✅ Local validation evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Local validation evaluation failed: {e}")
            return None
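
    # Note: each evaluation catches its own exceptions and returns None on
    # failure, so run_all() continues with the remaining evaluations and the
    # failed one is simply omitted from the summary.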

    def generate_summary(self):
        """Generate summary report."""
        print("\n" + "=" * 80)
        print("📊 Generating Summary Report")
        print("=" * 80)

        summary = {
            'timestamp': self.timestamp,
            'repo_id': self.repo_id,
            'evaluations': {}
        }

        # Collect all results
        for eval_name, eval_results in self.results.items():
            if eval_results:
                summary['evaluations'][eval_name] = eval_results

        # Save to JSON
        summary_path = self.run_dir / "summary.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)
        print(f"✅ Summary saved to: {summary_path}")

        # Print summary
        self.print_summary(summary)
        return summary
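
    # Illustrative summary.json shape (the per-evaluation keys depend on what
    # each evaluate_* function returns; create_visualizations() looks for
    # 'color_nn_accuracy' and 'hierarchy_nn_accuracy'):
    # {
    #   "timestamp": "20250101_120000",
    #   "repo_id": "Leacb4/gap-clip",
    #   "evaluations": {
    #     "fashion_mnist": {"color_nn_accuracy": <float>, "hierarchy_nn_accuracy": <float>},
    #     "kaggle_marqo": {...},
    #     "local_validation": {...}
    #   }
    # }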

    def print_summary(self, summary):
        """Print formatted summary."""
        print("\n" + "=" * 80)
        print("📈 Evaluation Summary")
        print("=" * 80)
        print(f"\nRepository: {summary['repo_id']}")
        print(f"Timestamp: {summary['timestamp']}\n")

        for eval_name, eval_results in summary['evaluations'].items():
            print(f"\n{'─' * 40}")
            print(f"📊 {eval_name.upper()}")
            print(f"{'─' * 40}")
            if isinstance(eval_results, dict):
                for key, value in eval_results.items():
                    if isinstance(value, (int, float)):
                        print(f" {key}: {value:.4f}")
                    else:
                        print(f" {key}: {value}")

        print("\n" + "=" * 80)

    def create_visualizations(self):
        """Create summary visualizations."""
        print("\n" + "=" * 80)
        print("📊 Creating Visualizations")
        print("=" * 80)

        # Create comparison chart
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # Collect metrics
        datasets = []
        color_accuracies = []
        hierarchy_accuracies = []
        for eval_name, eval_results in self.results.items():
            if eval_results and isinstance(eval_results, dict):
                datasets.append(eval_name)
                # Try to get color accuracy
                color_acc = eval_results.get('color_nn_accuracy', 0)
                color_accuracies.append(color_acc)
                # Try to get hierarchy accuracy
                hier_acc = eval_results.get('hierarchy_nn_accuracy', 0)
                hierarchy_accuracies.append(hier_acc)

        # Plot color accuracies
        if color_accuracies:
            axes[0].bar(datasets, color_accuracies, color='skyblue')
            axes[0].set_title('Color Classification Accuracy', fontsize=14, fontweight='bold')
            axes[0].set_ylabel('Accuracy', fontsize=12)
            axes[0].set_ylim([0, 1])
            axes[0].grid(axis='y', alpha=0.3)
            # Add value labels
            for i, v in enumerate(color_accuracies):
                axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

        # Plot hierarchy accuracies
        if hierarchy_accuracies:
            axes[1].bar(datasets, hierarchy_accuracies, color='lightcoral')
            axes[1].set_title('Hierarchy Classification Accuracy', fontsize=14, fontweight='bold')
            axes[1].set_ylabel('Accuracy', fontsize=12)
            axes[1].set_ylim([0, 1])
            axes[1].grid(axis='y', alpha=0.3)
            # Add value labels
            for i, v in enumerate(hierarchy_accuracies):
                axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

        plt.tight_layout()

        # Save figure
        fig_path = self.run_dir / "summary_comparison.png"
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"✅ Visualization saved to: {fig_path}")
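
    # Note: evaluations that do not report 'color_nn_accuracy' or
    # 'hierarchy_nn_accuracy' fall back to 0 above and therefore appear as
    # zero-height bars in the chart.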

    def run_all(self):
        """Run all evaluations."""
        print("=" * 80)
        print("🚀 GAP-CLIP Comprehensive Evaluation")
        print("=" * 80)
        print(f"Repository: {self.repo_id}")
        print(f"Output directory: {self.run_dir}\n")

        # Load models
        if not self.load_models():
            print("❌ Failed to load models. Exiting.")
            return False

        # Run evaluations
        self.run_fashion_mnist_evaluation()
        self.run_kaggle_evaluation()
        self.run_local_evaluation()

        # Generate summary and visualizations
        summary = self.generate_summary()
        self.create_visualizations()

        print("\n" + "=" * 80)
        print("🎉 Evaluation Complete!")
        print("=" * 80)
        print(f"Results saved to: {self.run_dir}")
        print(" - summary.json: Detailed results")
        print(" - summary_comparison.png: Visual comparison")
        print("=" * 80)
        return True

def main():
    """Main function for command-line usage."""
    parser = argparse.ArgumentParser(
        description="Run comprehensive evaluation on GAP-CLIP",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="Leacb4/gap-clip",
        help="Hugging Face repository ID (default: Leacb4/gap-clip)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="evaluation_results",
        help="Output directory for results (default: evaluation_results)"
    )
    args = parser.parse_args()

    # Create runner and execute
    runner = EvaluationRunner(
        repo_id=args.repo_id,
        output_dir=args.output
    )
    success = runner.run_all()
    sys.exit(0 if success else 1)

if __name__ == "__main__":
    main()
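
# The runner can also be used programmatically instead of via the CLI, e.g.:
#   runner = EvaluationRunner(repo_id="Leacb4/gap-clip", output_dir="my_results")
#   runner.run_all()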