File size: 10,534 Bytes
1419f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
"""
Comprehensive Evaluation Runner for GAP-CLIP
=============================================

This script runs all available evaluations on the GAP-CLIP model and generates
a comprehensive report with metrics, visualizations, and comparisons.

Usage:
    python run_all_evaluations.py [--repo-id REPO_ID] [--output OUTPUT_DIR]

Features:
    - Runs all evaluation scripts
    - Generates summary report
    - Creates visualizations
    - Compares with baseline models
    - Saves results to organized directory

Author: Lea Attia Sarfati
"""

import os
import sys
import json
import argparse
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

# Import evaluation modules
try:
    from evaluation.main_model_evaluation import (
        evaluate_fashion_mnist,
        evaluate_kaggle_marqo,
        evaluate_local_validation
    )
    from example_usage import load_models_from_hf
except ImportError as e:
    print(f"⚠️  Import error: {e}")
    print("Make sure you're running from the correct directory")
    sys.exit(1)


class EvaluationRunner:
    """
    Comprehensive evaluation runner for GAP-CLIP.

    Loads the model bundle from the Hugging Face Hub, runs every available
    evaluation (Fashion-MNIST, KAGL Marqo, local validation), and writes a
    JSON summary plus a comparison chart into a timestamped run directory.

    NOTE(review): the emoji in the console banners appear mojibake-encoded
    (UTF-8 bytes decoded with a legacy codepage, e.g. "πŸ“Š" for a chart
    emoji). They are runtime output, so they are preserved verbatim here;
    fix them at the source if the garbling is unintended.
    """

    def __init__(self, repo_id: str, output_dir: str = "evaluation_results"):
        """
        Initialize the evaluation runner.

        Args:
            repo_id: Hugging Face repository ID to load models from.
            output_dir: Directory under which a per-run subdirectory is created.
        """
        self.repo_id = repo_id
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)

        # Each invocation writes into its own timestamped subdirectory so
        # repeated runs never overwrite earlier results.
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = self.output_dir / f"run_{self.timestamp}"
        self.run_dir.mkdir(exist_ok=True)

        # Maps evaluation name -> result dict; only successful runs are stored.
        self.results: dict = {}
        # Populated by load_models(); presumably a dict with keys
        # 'main_model', 'processor', 'device' -- TODO confirm against
        # load_models_from_hf.
        self.models = None

    def load_models(self) -> bool:
        """
        Load models from Hugging Face.

        Returns:
            True if the models loaded successfully, False otherwise.
        """
        print("=" * 80)
        print("πŸ“₯ Loading Models")
        print("=" * 80)

        try:
            self.models = load_models_from_hf(self.repo_id)
            print("βœ… Models loaded successfully\n")
            return True
        except Exception as e:
            # Best-effort: report the failure and let the caller decide.
            print(f"❌ Failed to load models: {e}\n")
            return False

    def _run_evaluation(self, key: str, banner: str, label: str, eval_fn):
        """
        Run a single evaluation function and record its results.

        Shared driver for the three public run_* methods, which previously
        duplicated this logic verbatim.

        Args:
            key: Key under which results are stored in ``self.results``.
            banner: Section heading printed before the evaluation.
            label: Human-readable name used in status messages.
            eval_fn: Callable taking ``model``, ``processor`` and ``device``
                keyword arguments and returning a result dict.

        Returns:
            The evaluation's result dict, or None if it raised.
        """
        print("\n" + "=" * 80)
        print(banner)
        print("=" * 80)

        try:
            results = eval_fn(
                model=self.models['main_model'],
                processor=self.models['processor'],
                device=self.models['device']
            )
            # Only successful runs are recorded; a failure leaves no entry.
            self.results[key] = results
            print(f"βœ… {label} evaluation completed")
            return results
        except Exception as e:
            print(f"❌ {label} evaluation failed: {e}")
            return None

    def run_fashion_mnist_evaluation(self):
        """Run Fashion-MNIST evaluation."""
        return self._run_evaluation(
            'fashion_mnist',
            "πŸ‘• Fashion-MNIST Evaluation",
            "Fashion-MNIST",
            evaluate_fashion_mnist,
        )

    def run_kaggle_evaluation(self):
        """Run KAGL Marqo evaluation."""
        return self._run_evaluation(
            'kaggle_marqo',
            "πŸ›οΈ  KAGL Marqo Evaluation",
            "KAGL Marqo",
            evaluate_kaggle_marqo,
        )

    def run_local_evaluation(self):
        """Run local validation evaluation."""
        return self._run_evaluation(
            'local_validation',
            "πŸ“ Local Validation Evaluation",
            "Local validation",
            evaluate_local_validation,
        )

    def generate_summary(self) -> dict:
        """
        Generate the summary report: write summary.json and print it.

        Returns:
            The summary dict ({'timestamp', 'repo_id', 'evaluations'}).
        """
        print("\n" + "=" * 80)
        print("πŸ“Š Generating Summary Report")
        print("=" * 80)

        summary = {
            'timestamp': self.timestamp,
            'repo_id': self.repo_id,
            'evaluations': {}
        }

        # Keep only evaluations that actually produced results (failed runs
        # return None and are skipped).
        for eval_name, eval_results in self.results.items():
            if eval_results:
                summary['evaluations'][eval_name] = eval_results

        # default=str guards against non-JSON-native values (e.g. numpy
        # scalars) in metric dicts, which would otherwise raise TypeError.
        summary_path = self.run_dir / "summary.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2, default=str)

        print(f"βœ… Summary saved to: {summary_path}")

        self.print_summary(summary)

        return summary

    def print_summary(self, summary: dict) -> None:
        """Print a formatted, human-readable version of *summary*."""
        print("\n" + "=" * 80)
        print("πŸ“ˆ Evaluation Summary")
        print("=" * 80)
        print(f"\nRepository: {summary['repo_id']}")
        print(f"Timestamp: {summary['timestamp']}\n")

        for eval_name, eval_results in summary['evaluations'].items():
            print(f"\n{'─' * 40}")
            print(f"πŸ“Š {eval_name.upper()}")
            print(f"{'─' * 40}")

            if isinstance(eval_results, dict):
                for key, value in eval_results.items():
                    # Numeric metrics get fixed precision; everything else
                    # is printed as-is.
                    if isinstance(value, (int, float)):
                        print(f"  {key}: {value:.4f}")
                    else:
                        print(f"  {key}: {value}")

        print("\n" + "=" * 80)

    def create_visualizations(self) -> None:
        """Create and save the accuracy comparison bar charts (PNG)."""
        print("\n" + "=" * 80)
        print("πŸ“Š Creating Visualizations")
        print("=" * 80)

        fig, axes = plt.subplots(1, 2, figsize=(15, 6))

        # Collect per-dataset metrics; missing keys default to 0 so every
        # dataset keeps a bar even when a metric was not reported.
        datasets = []
        color_accuracies = []
        hierarchy_accuracies = []

        for eval_name, eval_results in self.results.items():
            if eval_results and isinstance(eval_results, dict):
                datasets.append(eval_name)
                color_accuracies.append(
                    eval_results.get('color_nn_accuracy', 0))
                hierarchy_accuracies.append(
                    eval_results.get('hierarchy_nn_accuracy', 0))

        # Left panel: color classification accuracy.
        if color_accuracies:
            axes[0].bar(datasets, color_accuracies, color='skyblue')
            axes[0].set_title('Color Classification Accuracy',
                              fontsize=14, fontweight='bold')
            axes[0].set_ylabel('Accuracy', fontsize=12)
            axes[0].set_ylim([0, 1])
            axes[0].grid(axis='y', alpha=0.3)
            # Value labels just above each bar.
            for i, v in enumerate(color_accuracies):
                axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

        # Right panel: hierarchy classification accuracy.
        if hierarchy_accuracies:
            axes[1].bar(datasets, hierarchy_accuracies, color='lightcoral')
            axes[1].set_title('Hierarchy Classification Accuracy',
                              fontsize=14, fontweight='bold')
            axes[1].set_ylabel('Accuracy', fontsize=12)
            axes[1].set_ylim([0, 1])
            axes[1].grid(axis='y', alpha=0.3)
            for i, v in enumerate(hierarchy_accuracies):
                axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

        plt.tight_layout()

        fig_path = self.run_dir / "summary_comparison.png"
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        # Close this specific figure to release its memory.
        plt.close(fig)

        print(f"βœ… Visualization saved to: {fig_path}")

    def run_all(self) -> bool:
        """
        Run all evaluations end to end.

        Returns:
            True on success; False if the models failed to load. Individual
            evaluation failures are reported but do not abort the run.
        """
        print("=" * 80)
        print("πŸš€ GAP-CLIP Comprehensive Evaluation")
        print("=" * 80)
        print(f"Repository: {self.repo_id}")
        print(f"Output directory: {self.run_dir}\n")

        # Models are required by every evaluation; bail out early on failure.
        if not self.load_models():
            print("❌ Failed to load models. Exiting.")
            return False

        # Each evaluation is independent and failure-tolerant.
        self.run_fashion_mnist_evaluation()
        self.run_kaggle_evaluation()
        self.run_local_evaluation()

        summary = self.generate_summary()
        self.create_visualizations()

        print("\n" + "=" * 80)
        print("πŸŽ‰ Evaluation Complete!")
        print("=" * 80)
        print(f"Results saved to: {self.run_dir}")
        print(f"  - summary.json: Detailed results")
        print(f"  - summary_comparison.png: Visual comparison")
        print("=" * 80)

        return True


def main():
    """Command-line entry point: parse arguments, run all evaluations, exit.

    Exits with status 0 when the full evaluation run succeeds, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Run comprehensive evaluation on GAP-CLIP",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="Leacb4/gap-clip",
        help="Hugging Face repository ID (default: Leacb4/gap-clip)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="evaluation_results",
        help="Output directory for results (default: evaluation_results)",
    )
    cli = parser.parse_args()

    # Build the runner from CLI options and translate its boolean outcome
    # into a process exit code.
    runner = EvaluationRunner(repo_id=cli.repo_id, output_dir=cli.output)
    sys.exit(0 if runner.run_all() else 1)


# Standard script guard: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()