#!/usr/bin/env python3
"""
Comprehensive Evaluation Runner for GAP-CLIP
=============================================

This script runs all available evaluations on the GAP-CLIP model and generates
a comprehensive report with metrics, visualizations, and comparisons.

Usage:
    python run_all_evaluations.py [--repo-id REPO_ID] [--output OUTPUT_DIR]

Features:
- Runs all evaluation scripts
- Generates summary report
- Creates visualizations
- Compares with baseline models
- Saves results to organized directory

Author: Lea Attia Sarfati
"""

import os
import sys
import json
import argparse
from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

# Import evaluation modules
try:
    from evaluation.main_model_evaluation import (
        evaluate_fashion_mnist,
        evaluate_kaggle_marqo,
        evaluate_local_validation
    )
    from example_usage import load_models_from_hf
except ImportError as e:
    print(f"⚠️ Import error: {e}")
    print("Make sure you're running from the correct directory")
    sys.exit(1)


class EvaluationRunner:
    """
    Comprehensive evaluation runner for GAP-CLIP.

    Runs all available evaluations and generates a summary report.
    """

    def __init__(self, repo_id: str, output_dir: str = "evaluation_results"):
        """
        Initialize the evaluation runner.

        Args:
            repo_id: Hugging Face repository ID
            output_dir: Directory to save results
        """
        self.repo_id = repo_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Create timestamp for this run
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = self.output_dir / f"run_{self.timestamp}"
        self.run_dir.mkdir(exist_ok=True)

        self.results = {}
        self.models = None

    def load_models(self):
        """Load models from Hugging Face."""
        print("=" * 80)
        print("📥 Loading Models")
        print("=" * 80)

        try:
            self.models = load_models_from_hf(self.repo_id)
            print("✅ Models loaded successfully\n")
            return True
        except Exception as e:
            print(f"❌ Failed to load models: {e}\n")
            return False

    def run_fashion_mnist_evaluation(self):
        """Run Fashion-MNIST evaluation."""
        print("\n" + "=" * 80)
        print("👕 Fashion-MNIST Evaluation")
        print("=" * 80)

        try:
            results = evaluate_fashion_mnist(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['fashion_mnist'] = results
            print("✅ Fashion-MNIST evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Fashion-MNIST evaluation failed: {e}")
            return None

    def run_kaggle_evaluation(self):
        """Run Kaggle Marqo evaluation."""
        print("\n" + "=" * 80)
        print("🛍️ Kaggle Marqo Evaluation")
        print("=" * 80)

        try:
            results = evaluate_kaggle_marqo(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['kaggle_marqo'] = results
            print("✅ Kaggle Marqo evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Kaggle Marqo evaluation failed: {e}")
            return None

    def run_local_evaluation(self):
        """Run local validation evaluation."""
        print("\n" + "=" * 80)
        print("📁 Local Validation Evaluation")
        print("=" * 80)

        try:
            results = evaluate_local_validation(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results['local_validation'] = results
            print("✅ Local validation evaluation completed")
            return results
        except Exception as e:
            print(f"❌ Local validation evaluation failed: {e}")
            return None

    def generate_summary(self):
        """Generate summary report."""
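        # NOTE: the summary and plotting code below assumes each evaluate_*
        # helper returns a flat dict mapping metric names (e.g.
        # 'color_nn_accuracy', 'hierarchy_nn_accuracy') to numeric values;
        # those dicts are stored verbatim under summary['evaluations'].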
print("\n" + "=" * 80) print("📊 Generating Summary Report") print("=" * 80) summary = { 'timestamp': self.timestamp, 'repo_id': self.repo_id, 'evaluations': {} } # Collect all results for eval_name, eval_results in self.results.items(): if eval_results: summary['evaluations'][eval_name] = eval_results # Save to JSON summary_path = self.run_dir / "summary.json" with open(summary_path, 'w') as f: json.dump(summary, f, indent=2) print(f"✅ Summary saved to: {summary_path}") # Print summary self.print_summary(summary) return summary def print_summary(self, summary): """Print formatted summary.""" print("\n" + "=" * 80) print("📈 Evaluation Summary") print("=" * 80) print(f"\nRepository: {summary['repo_id']}") print(f"Timestamp: {summary['timestamp']}\n") for eval_name, eval_results in summary['evaluations'].items(): print(f"\n{'─' * 40}") print(f"📊 {eval_name.upper()}") print(f"{'─' * 40}") if isinstance(eval_results, dict): for key, value in eval_results.items(): if isinstance(value, (int, float)): print(f" {key}: {value:.4f}") else: print(f" {key}: {value}") print("\n" + "=" * 80) def create_visualizations(self): """Create summary visualizations.""" print("\n" + "=" * 80) print("📊 Creating Visualizations") print("=" * 80) # Create comparison chart fig, axes = plt.subplots(1, 2, figsize=(15, 6)) # Collect metrics datasets = [] color_accuracies = [] hierarchy_accuracies = [] for eval_name, eval_results in self.results.items(): if eval_results and isinstance(eval_results, dict): datasets.append(eval_name) # Try to get color accuracy color_acc = eval_results.get('color_nn_accuracy', 0) color_accuracies.append(color_acc) # Try to get hierarchy accuracy hier_acc = eval_results.get('hierarchy_nn_accuracy', 0) hierarchy_accuracies.append(hier_acc) # Plot color accuracies if color_accuracies: axes[0].bar(datasets, color_accuracies, color='skyblue') axes[0].set_title('Color Classification Accuracy', fontsize=14, fontweight='bold') axes[0].set_ylabel('Accuracy', fontsize=12) axes[0].set_ylim([0, 1]) axes[0].grid(axis='y', alpha=0.3) # Add value labels for i, v in enumerate(color_accuracies): axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10) # Plot hierarchy accuracies if hierarchy_accuracies: axes[1].bar(datasets, hierarchy_accuracies, color='lightcoral') axes[1].set_title('Hierarchy Classification Accuracy', fontsize=14, fontweight='bold') axes[1].set_ylabel('Accuracy', fontsize=12) axes[1].set_ylim([0, 1]) axes[1].grid(axis='y', alpha=0.3) # Add value labels for i, v in enumerate(hierarchy_accuracies): axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10) plt.tight_layout() # Save figure fig_path = self.run_dir / "summary_comparison.png" plt.savefig(fig_path, dpi=300, bbox_inches='tight') plt.close() print(f"✅ Visualization saved to: {fig_path}") def run_all(self): """Run all evaluations.""" print("=" * 80) print("🚀 GAP-CLIP Comprehensive Evaluation") print("=" * 80) print(f"Repository: {self.repo_id}") print(f"Output directory: {self.run_dir}\n") # Load models if not self.load_models(): print("❌ Failed to load models. 
Exiting.") return False # Run evaluations self.run_fashion_mnist_evaluation() self.run_kaggle_evaluation() self.run_local_evaluation() # Generate summary and visualizations summary = self.generate_summary() self.create_visualizations() print("\n" + "=" * 80) print("🎉 Evaluation Complete!") print("=" * 80) print(f"Results saved to: {self.run_dir}") print(f" - summary.json: Detailed results") print(f" - summary_comparison.png: Visual comparison") print("=" * 80) return True def main(): """Main function for command-line usage.""" parser = argparse.ArgumentParser( description="Run comprehensive evaluation on GAP-CLIP", formatter_class=argparse.RawDescriptionHelpFormatter ) parser.add_argument( "--repo-id", type=str, default="Leacb4/gap-clip", help="Hugging Face repository ID (default: Leacb4/gap-clip)" ) parser.add_argument( "--output", type=str, default="evaluation_results", help="Output directory for results (default: evaluation_results)" ) args = parser.parse_args() # Create runner and execute runner = EvaluationRunner( repo_id=args.repo_id, output_dir=args.output ) success = runner.run_all() sys.exit(0 if success else 1) if __name__ == "__main__": main()