|
|
|
|
|
""" |
|
|
Comprehensive Evaluation Runner for GAP-CLIP |
|
|
============================================= |
|
|
|
|
|
This script runs all available evaluations on the GAP-CLIP model and generates |
|
|
a comprehensive report with metrics, visualizations, and comparisons. |
|
|
|
|
|
Usage: |
|
|
python run_all_evaluations.py [--repo-id REPO_ID] [--output OUTPUT_DIR] |
|
|
|
|
|
Features: |
|
|
- Runs all evaluation scripts |
|
|
- Generates summary report |
|
|
- Creates visualizations |
|
|
- Compares with baseline models |
|
|
- Saves results to organized directory |
|
|
|
|
|
Author: Lea Attia Sarfati |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import matplotlib.pyplot as plt |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
|
|
|
try: |
|
|
from evaluation.main_model_evaluation import ( |
|
|
evaluate_fashion_mnist, |
|
|
evaluate_kaggle_marqo, |
|
|
evaluate_local_validation |
|
|
) |
|
|
from example_usage import load_models_from_hf |
|
|
except ImportError as e: |
|
|
print(f"⚠️ Import error: {e}") |
|
|
print("Make sure you're running from the correct directory") |
|
|
sys.exit(1) |
|
|
|
|
|
|
|
|
class EvaluationRunner:
    """
    Comprehensive evaluation runner for GAP-CLIP.

    Runs all available evaluations and generates a summary report
    (summary.json plus comparison charts) inside a timestamped run
    directory so repeated runs never overwrite each other.
    """

    def __init__(self, repo_id: str, output_dir: str = "evaluation_results"):
        """
        Initialize the evaluation runner.

        Args:
            repo_id: Hugging Face repository ID
            output_dir: Directory to save results
        """
        self.repo_id = repo_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Every invocation writes into its own run_<timestamp> subdirectory.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = self.output_dir / f"run_{self.timestamp}"
        self.run_dir.mkdir(exist_ok=True)

        # Maps evaluation name -> results dict (or None if that run failed).
        self.results: dict = {}
        # Populated by load_models(); expected keys include
        # 'main_model', 'processor', 'device' (see load_models_from_hf).
        self.models = None

    def load_models(self) -> bool:
        """Load models from Hugging Face.

        Returns:
            True if loading succeeded, False otherwise (error is printed,
            never raised, so the caller can abort cleanly).
        """
        print("=" * 80)
        print("📥 Loading Models")
        print("=" * 80)

        try:
            self.models = load_models_from_hf(self.repo_id)
            print("✅ Models loaded successfully\n")
            return True
        except Exception as e:
            print(f"❌ Failed to load models: {e}\n")
            return False

    def _run_evaluation(self, key: str, title: str, label: str, eval_fn):
        """Shared driver for a single dataset evaluation.

        All three public run_* methods follow the same banner/run/store
        pattern; this helper removes the triplication.

        Args:
            key: Key under which results are stored in self.results.
            title: Banner line printed before the evaluation starts.
            label: Human-readable name used in the completed/failed messages.
            eval_fn: Evaluation callable accepting model/processor/device.

        Returns:
            The evaluation's results, or None if the evaluation raised.
        """
        print("\n" + "=" * 80)
        print(title)
        print("=" * 80)

        try:
            results = eval_fn(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            self.results[key] = results
            print(f"✅ {label} completed")
            return results
        except Exception as e:
            # Best-effort: one failing evaluation must not abort the others.
            print(f"❌ {label} failed: {e}")
            return None

    def run_fashion_mnist_evaluation(self):
        """Run Fashion-MNIST evaluation."""
        return self._run_evaluation(
            'fashion_mnist',
            "👕 Fashion-MNIST Evaluation",
            "Fashion-MNIST evaluation",
            evaluate_fashion_mnist,
        )

    def run_kaggle_evaluation(self):
        """Run KAGL Marqo evaluation."""
        return self._run_evaluation(
            'kaggle_marqo',
            "🛍️ KAGL Marqo Evaluation",
            "KAGL Marqo evaluation",
            evaluate_kaggle_marqo,
        )

    def run_local_evaluation(self):
        """Run local validation evaluation."""
        return self._run_evaluation(
            'local_validation',
            "📁 Local Validation Evaluation",
            "Local validation evaluation",
            evaluate_local_validation,
        )

    def generate_summary(self) -> dict:
        """Generate the summary report, save it as JSON, and print it.

        Returns:
            The summary dict (timestamp, repo_id, and all non-empty
            evaluation results).
        """
        print("\n" + "=" * 80)
        print("📊 Generating Summary Report")
        print("=" * 80)

        summary = {
            'timestamp': self.timestamp,
            'repo_id': self.repo_id,
            # Keep only evaluations that actually produced results
            # (failed runs store None and are dropped here).
            'evaluations': {
                name: res for name, res in self.results.items() if res
            },
        }

        summary_path = self.run_dir / "summary.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"✅ Summary saved to: {summary_path}")

        self.print_summary(summary)

        return summary

    def print_summary(self, summary: dict) -> None:
        """Print a formatted, human-readable version of *summary*."""
        print("\n" + "=" * 80)
        print("📈 Evaluation Summary")
        print("=" * 80)
        print(f"\nRepository: {summary['repo_id']}")
        print(f"Timestamp: {summary['timestamp']}\n")

        for eval_name, eval_results in summary['evaluations'].items():
            print(f"\n{'─' * 40}")
            print(f"📊 {eval_name.upper()}")
            print(f"{'─' * 40}")

            if isinstance(eval_results, dict):
                for key, value in eval_results.items():
                    # Numeric metrics get fixed-precision formatting;
                    # anything else is printed verbatim.
                    if isinstance(value, (int, float)):
                        print(f"  {key}: {value:.4f}")
                    else:
                        print(f"  {key}: {value}")

        print("\n" + "=" * 80)

    def create_visualizations(self) -> None:
        """Create and save the per-dataset accuracy comparison chart."""
        print("\n" + "=" * 80)
        print("📊 Creating Visualizations")
        print("=" * 80)

        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # Collect the two headline metrics per dataset; missing metrics
        # default to 0 so the bar lists stay aligned with `datasets`.
        datasets = []
        color_accuracies = []
        hierarchy_accuracies = []

        for eval_name, eval_results in self.results.items():
            if eval_results and isinstance(eval_results, dict):
                datasets.append(eval_name)
                color_accuracies.append(
                    eval_results.get('color_nn_accuracy', 0))
                hierarchy_accuracies.append(
                    eval_results.get('hierarchy_nn_accuracy', 0))

        if color_accuracies:
            axes[0].bar(datasets, color_accuracies, color='skyblue')
            axes[0].set_title('Color Classification Accuracy',
                              fontsize=14, fontweight='bold')
            axes[0].set_ylabel('Accuracy', fontsize=12)
            axes[0].set_ylim([0, 1])
            axes[0].grid(axis='y', alpha=0.3)

            # Annotate each bar with its exact value.
            for i, v in enumerate(color_accuracies):
                axes[0].text(i, v + 0.02, f'{v:.3f}',
                             ha='center', fontsize=10)

        if hierarchy_accuracies:
            axes[1].bar(datasets, hierarchy_accuracies, color='lightcoral')
            axes[1].set_title('Hierarchy Classification Accuracy',
                              fontsize=14, fontweight='bold')
            axes[1].set_ylabel('Accuracy', fontsize=12)
            axes[1].set_ylim([0, 1])
            axes[1].grid(axis='y', alpha=0.3)

            for i, v in enumerate(hierarchy_accuracies):
                axes[1].text(i, v + 0.02, f'{v:.3f}',
                             ha='center', fontsize=10)

        plt.tight_layout()

        fig_path = self.run_dir / "summary_comparison.png"
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"✅ Visualization saved to: {fig_path}")

    def run_all(self) -> bool:
        """Run every evaluation end-to-end.

        Returns:
            True on success, False if the models could not be loaded.
        """
        print("=" * 80)
        print("🚀 GAP-CLIP Comprehensive Evaluation")
        print("=" * 80)
        print(f"Repository: {self.repo_id}")
        print(f"Output directory: {self.run_dir}\n")

        # Without models there is nothing to evaluate.
        if not self.load_models():
            print("❌ Failed to load models. Exiting.")
            return False

        # Individual failures are tolerated (each run_* method catches
        # its own exceptions), so all three always get attempted.
        self.run_fashion_mnist_evaluation()
        self.run_kaggle_evaluation()
        self.run_local_evaluation()

        summary = self.generate_summary()
        self.create_visualizations()

        print("\n" + "=" * 80)
        print("🎉 Evaluation Complete!")
        print("=" * 80)
        print(f"Results saved to: {self.run_dir}")
        print(f"  - summary.json: Detailed results")
        print(f"  - summary_comparison.png: Visual comparison")
        print("=" * 80)

        return True
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: configure arguments, then run all evaluations.

    Exits with status 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="Run comprehensive evaluation on GAP-CLIP",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # (flag, default, help) triples for every supported string option.
    options = [
        ("--repo-id", "Leacb4/gap-clip",
         "Hugging Face repository ID (default: Leacb4/gap-clip)"),
        ("--output", "evaluation_results",
         "Output directory for results (default: evaluation_results)"),
    ]
    for flag, default, help_text in options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    args = parser.parse_args()

    runner = EvaluationRunner(repo_id=args.repo_id, output_dir=args.output)
    sys.exit(0 if runner.run_all() else 1)


if __name__ == "__main__":
    main()
|
|
|