File size: 25,571 Bytes
9a0283b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
#!/usr/bin/env python3
"""
Final Comprehensive Analysis Script
Analyzes the comprehensive test results: 256 models × 20 tests
Generates all figures and tables for the paper
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
import json

# Set style for publication-quality figures.
# NOTE(review): 'seaborn-v0_8-whitegrid' is the name matplotlib >= 3.6 uses for
# the former seaborn styles; older releases call it 'seaborn-whitegrid' —
# confirm the pinned matplotlib version if this line raises.
plt.style.use('seaborn-v0_8-whitegrid')
# Set seaborn's default color palette (used by plots without explicit colors).
sns.set_palette("husl")

def load_results(excel_file='comprehensive_20_tests_results_20251014_153008.xlsx',
                 json_file='comprehensive_20_tests_results_20251014_153008.json'):
    """Load the comprehensive test results.

    Reads four sheets from the results workbook plus the companion JSON file
    that carries the run metadata and precomputed summary.

    Args:
        excel_file: Path to the results workbook. It must contain the sheets
            'All Results', 'Model Rankings', 'Test Difficulty' and
            'Category Performance'.
        json_file: Path to the JSON results file; the reporting functions in
            this script read its 'metadata' and 'summary' sections.

    Returns:
        Tuple ``(all_results, model_rankings, test_difficulty,
        category_performance, json_data)``. ``model_rankings`` and
        ``category_performance`` use their first sheet column as the index.
    """
    print("="*80)
    print("LOADING COMPREHENSIVE TEST RESULTS")
    print("="*80)
    print(f"\nLoading from: {excel_file}")

    # The context manager releases the workbook handle once every sheet is read.
    with pd.ExcelFile(excel_file) as xls:
        all_results = pd.read_excel(xls, 'All Results')
        model_rankings = pd.read_excel(xls, 'Model Rankings', index_col=0)
        test_difficulty = pd.read_excel(xls, 'Test Difficulty')
        category_performance = pd.read_excel(xls, 'Category Performance', index_col=0)

    print(f"  Total results: {len(all_results)}")
    print(f"  Models tested: {all_results['model'].nunique()}")
    print(f"  Tests conducted: {all_results['test_id'].nunique()}")

    # Load JSON for additional metadata. JSON is UTF-8 by specification, so
    # decode it explicitly instead of relying on the platform default.
    with open(json_file, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    return all_results, model_rankings, test_difficulty, category_performance, json_data

def print_summary_statistics(all_results, json_data):
    """Print headline statistics for the evaluation run.

    Most figures come straight from the ``metadata`` and ``summary`` sections
    of *json_data*. Two are computed from *all_results*:

    * the share of rows whose ``status`` equals ``'success'``;
    * the mean of the strictly positive ``response_time`` values.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("SUMMARY STATISTICS")
    print(rule)

    meta = json_data['metadata']
    summ = json_data['summary']

    print("\nDataset Overview:")
    print(f"  Total Models: {meta['total_models']}")
    print(f"  Total Tests: {meta['total_tests']}")
    print(f"  Total Evaluations: {meta['total_results']}")
    print(f"  Timestamp: {meta['timestamp']}")

    print("\nOverall Performance:")
    print(f"  Overall Pass Rate: {summ['overall_pass_rate']:.1f}%")
    print(f"  Best Model: {summ['best_model']} ({summ['best_model_score']:.1f}%)")
    print(f"  Hardest Test: Test {summ['hardest_test']} ({summ['hardest_test_pass_rate']:.1f}% pass rate)")

    # Share of API calls that reported success.
    ok_pct = all_results['status'].eq('success').mean() * 100
    print(f"  API Success Rate: {ok_pct:.1f}%")

    # Zero/negative times are excluded from the average.
    times = all_results['response_time']
    mean_time = times[times > 0].mean()
    print(f"  Average Response Time: {mean_time:.2f}s")

def print_top_models(model_rankings):
    """Print the first 20 rows of *model_rankings* as a fixed-width table.

    *model_rankings* is indexed by model name and has a ``'Pass Rate (%)'``
    column. It is printed in its existing order (this function does not
    sort), so it is assumed to already be ranked best-first.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TOP 20 PERFORMING MODELS")
    print(rule)

    print(f"\n{'Rank':<6} {'Pass Rate':<12} {'Model'}")
    print("-" * 80)

    leaders = model_rankings.head(20)
    ranked = zip(leaders.index, leaders['Pass Rate (%)'])
    for rank, (model_name, rate) in enumerate(ranked, start=1):
        print(f"{rank:<6} {rate:>6.1f}%      {model_name}")

def print_test_difficulty(test_difficulty):
    """Print every test, ordered from hardest (lowest pass rate) to easiest.

    Args:
        test_difficulty: DataFrame with 'Test ID', 'name', 'category' and
            'Pass Rate (%)' columns. Empty (NaN) name or category cells are
            printed as blanks instead of raising.
    """
    print("\n" + "="*80)
    print("TEST DIFFICULTY ANALYSIS (HARDEST TO EASIEST)")
    print("="*80)

    print(f"\n{'ID':<4} {'Pass Rate':<12} {'Category':<25} {'Test Name'}")
    print("-" * 95)

    # Sort by pass rate (ascending = hardest first)
    sorted_tests = test_difficulty.sort_values('Pass Rate (%)')

    for _, row in sorted_tests.iterrows():
        test_id = int(row['Test ID'])
        pass_rate = row['Pass Rate (%)']
        # Both text columns may hold empty cells (NaN). Truncate only real
        # values so the fixed-width columns stay aligned.
        category = str(row['category'])[:23] if pd.notna(row['category']) else ''
        name = str(row['name'])[:45] if pd.notna(row['name']) else ''
        print(f"{test_id:<4} {pass_rate:>6.1f}%      {category:<25} {name}")

def print_category_analysis(category_performance):
    """Print the pass rate of each test category, lowest first.

    *category_performance* is indexed by category name and must provide a
    ``'Pass Rate (%)'`` column.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("PERFORMANCE BY CATEGORY")
    print(rule)

    ordered = category_performance.sort_values('Pass Rate (%)')

    print(f"\n{'Category':<30} {'Pass Rate'}")
    print("-" * 45)

    for cat_name, rate in ordered['Pass Rate (%)'].items():
        print(f"{cat_name:<30} {rate:>6.1f}%")

def analyze_by_provider(all_results, min_models=3, top_n=15):
    """Aggregate pass rates per model provider and print a ranking.

    The provider is the part of the model name before the first '/'
    (e.g. ``'openai/gpt-4o'`` -> ``'openai'``). Names without a '/' and
    non-string names are grouped under ``'other'``.

    Args:
        all_results: Per-evaluation DataFrame with 'model', 'test_id' and
            'passed' columns. It is not modified.
        min_models: Minimum number of distinct models a provider needs in
            order to appear in the returned table.
        top_n: Number of providers printed to the console.

    Returns:
        DataFrame indexed by provider and sorted by descending 'pass_rate'
        (percent). It also has 'num_models' (distinct models) and
        'total_tests' (number of result rows) columns, and only includes
        providers with at least *min_models* models.
    """
    print("\n" + "="*80)
    print("PROVIDER ANALYSIS")
    print("="*80)

    # Derive the provider as a local Series instead of adding a column to the
    # caller's DataFrame.
    provider = all_results['model'].map(
        lambda name: name.split('/', 1)[0] if isinstance(name, str) and '/' in name else 'other'
    )

    # Calculate provider statistics
    provider_stats = (
        all_results.assign(provider=provider)
        .groupby('provider')
        .agg(
            pass_rate=('passed', 'mean'),
            num_models=('model', 'nunique'),
            total_tests=('test_id', 'count'),
        )
    )
    # Convert to percent without rounding first, so the one-decimal values
    # printed and plotted later are rounded only once.
    provider_stats['pass_rate'] = provider_stats['pass_rate'] * 100
    provider_stats = provider_stats.sort_values('pass_rate', ascending=False)

    # Drop providers with too few models for a meaningful average.
    provider_stats_filtered = provider_stats[provider_stats['num_models'] >= min_models]

    print(f"\n{'Provider':<20} {'Pass Rate':<12} {'Models':<10} {'Tests'}")
    print("-" * 55)

    for provider_name, row in provider_stats_filtered.head(top_n).iterrows():
        print(f"{provider_name:<20} {row['pass_rate']:>6.1f}%      {int(row['num_models']):<10} {int(row['total_tests'])}")

    return provider_stats_filtered

def create_heatmap(all_results, output_file='fig1_heatmap.pdf', top_n=50):
    """Create a pass/fail heatmap of the best models on every test.

    Models are ranked by their mean pass rate across tests, and the *top_n*
    best are plotted: one row per model, one column per test ID. Model/test
    pairs without a result are left blank.

    Args:
        all_results: Per-evaluation DataFrame with 'model', 'test_id' and
            'passed' columns.
        output_file: Destination path of the figure.
        top_n: Number of highest-scoring models to include.
    """
    print("\n" + "="*80)
    print(f"CREATING FIGURE 1: PERFORMANCE HEATMAP (Top {top_n} Models)")
    print("="*80)

    # Cast to float so the cells are numeric 0/1 whether the sheet stored
    # booleans or integers.
    numeric = all_results.assign(passed=all_results['passed'].astype(float))
    pivot = numeric.pivot_table(
        index='model',
        columns='test_id',
        values='passed',
        aggfunc='first'
    )

    # Select the top models by overall performance
    model_scores = pivot.mean(axis=1).sort_values(ascending=False)
    pivot_top = pivot.loc[model_scores.head(top_n).index]
    if pivot_top.empty:
        print("⚠️  No results to plot; skipping heatmap.")
        return
    n_models, n_tests = pivot_top.shape

    fig, ax = plt.subplots(figsize=(14, 12))

    sns.heatmap(
        pivot_top,
        cmap=['#d73027', '#91cf60'],  # Red for fail, green for pass
        cbar_kws={'label': 'Pass (1) / Fail (0)', 'ticks': [0, 1]},
        xticklabels=True,
        yticklabels=True,
        vmin=0,
        vmax=1,
        linewidths=0.3,
        linecolor='gray',
        ax=ax
    )

    # Title counts come from the plotted data rather than fixed numbers.
    ax.set_title(f'Model Performance on {n_tests} Diagnostic Tests (Top {n_models} Models)',
                 fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel('Test ID', fontsize=12)
    ax.set_ylabel('Model', fontsize=12)
    ax.tick_params(axis='y', labelsize=7)
    ax.tick_params(axis='x', labelsize=10)
    fig.tight_layout()

    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {output_file}")
    plt.close(fig)

def create_provider_chart(provider_stats, output_file='fig2_provider.pdf', top_n=12):
    """Create a bar chart comparing average pass rates of model providers.

    Args:
        provider_stats: DataFrame indexed by provider name with a
            'pass_rate' column in percent. Rows are plotted in their existing
            order (no sorting here), so pass it ranked best-first.
        output_file: Destination path of the figure.
        top_n: Maximum number of providers (bars) to plot.
    """
    print("\n" + "="*80)
    print("CREATING FIGURE 2: PROVIDER COMPARISON")
    print("="*80)

    top_providers = provider_stats.head(top_n)
    if top_providers.empty:
        # Nothing qualifies (e.g. no provider met the model-count filter);
        # max() of an empty column would raise below.
        print("⚠️  No providers to plot; skipping provider chart.")
        return
    rates = top_providers['pass_rate']

    fig, ax = plt.subplots(figsize=(12, 7))

    x_pos = np.arange(len(top_providers))
    colors = sns.color_palette('husl', len(top_providers))
    bars = ax.bar(x_pos, rates, color=colors, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Model Provider', fontsize=13, fontweight='bold')
    ax.set_ylabel('Average Pass Rate (%)', fontsize=13, fontweight='bold')
    ax.set_title('Performance by Model Provider (≥3 models)',
                 fontsize=16, fontweight='bold', pad=20)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(top_providers.index, rotation=45, ha='right', fontsize=11)
    # 15% headroom above the tallest bar for the value labels. Fall back to a
    # fixed bound when every rate is 0 so the y-limits are not singular.
    y_max = rates.max()
    ax.set_ylim([0, y_max * 1.15 if y_max > 0 else 10])
    ax.grid(axis='y', alpha=0.3)

    # Value labels just above each bar
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 1,
                f'{height:.1f}%', ha='center', va='bottom', fontsize=10, fontweight='bold')

    # Unweighted mean of the plotted providers' rates
    avg_rate = rates.mean()
    ax.axhline(y=avg_rate, color='red', linestyle='--', alpha=0.7, linewidth=2,
               label=f'Average: {avg_rate:.1f}%')
    ax.legend(fontsize=11)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {output_file}")
    plt.close(fig)

# Difficulty bands, hardest first: (exclusive upper bound on the pass rate in %,
# bar color, legend label). One table drives both the bar colors and the
# legend, so the two cannot drift apart.
_DIFFICULTY_BANDS = (
    (10, '#8B0000', 'Extremely Hard (<10%)'),
    (20, '#d73027', 'Very Hard (10-20%)'),
    (40, '#fdae61', 'Hard (20-40%)'),
    (60, '#fee08b', 'Medium (40-60%)'),
    (80, '#a6d96a', 'Easy (60-80%)'),
    (float('inf'), '#1a9850', 'Very Easy (>80%)'),
)


def _difficulty_color(rate):
    """Return the color of the first band whose upper bound exceeds *rate* (NaN gets the last band)."""
    for upper, color, _label in _DIFFICULTY_BANDS:
        if rate < upper:
            return color
    return _DIFFICULTY_BANDS[-1][1]


def create_difficulty_chart(test_difficulty, output_file='fig3_difficulty.pdf'):
    """Create a horizontal bar chart ranking tests from hardest to easiest.

    Args:
        test_difficulty: DataFrame with 'Test ID', 'name' and 'Pass Rate (%)'
            columns.
        output_file: Destination path of the figure.
    """
    from matplotlib.patches import Patch

    print("\n" + "="*80)
    print("CREATING FIGURE 3: TEST DIFFICULTY RANKING")
    print("="*80)

    # Sort by difficulty (ascending pass rate = harder)
    sorted_tests = test_difficulty.sort_values('Pass Rate (%)')

    labels = []
    for _, row in sorted_tests.iterrows():
        # Empty name cells come back as NaN; show just the test ID then.
        name = str(row['name']) if pd.notna(row['name']) else ''
        if len(name) > 35:
            name = name[:32] + '...'
        labels.append(f"T{int(row['Test ID'])}: {name}")

    pass_rates = sorted_tests['Pass Rate (%)'].values
    colors = [_difficulty_color(rate) for rate in pass_rates]

    fig, ax = plt.subplots(figsize=(12, 10))

    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, pass_rates, color=colors, edgecolor='black', linewidth=0.5)

    ax.set_xlabel('Pass Rate (%)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Test', fontsize=13, fontweight='bold')
    ax.set_title('Test Difficulty Ranking (Hardest to Easiest)',
                 fontsize=16, fontweight='bold', pad=20)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontsize=9)
    ax.set_xlim([0, 105])
    ax.grid(axis='x', alpha=0.3)

    # Value labels to the right of each bar
    for bar, rate in zip(bars, pass_rates):
        width = bar.get_width()
        ax.text(width + 1.5, bar.get_y() + bar.get_height()/2.,
                f'{rate:.1f}%', ha='left', va='center', fontsize=9, fontweight='bold')

    # Vertical reference lines at 25/50/75%
    ax.axvline(x=50, color='black', linestyle='--', alpha=0.3, linewidth=1)
    ax.axvline(x=25, color='red', linestyle=':', alpha=0.3, linewidth=1)
    ax.axvline(x=75, color='green', linestyle=':', alpha=0.3, linewidth=1)

    legend_elements = [Patch(facecolor=color, label=label)
                       for _upper, color, label in _DIFFICULTY_BANDS]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=9)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {output_file}")
    plt.close(fig)

def create_category_chart(category_performance, output_file='fig4_category.pdf'):
    """Save a bar chart of pass rate per test category, lowest first.

    Each bar is colored by the RdYlGn colormap evaluated at rate / 100, so
    low-scoring categories lean red and high-scoring ones green.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("CREATING FIGURE 4: CATEGORY PERFORMANCE")
    print(rule)

    ordered = category_performance.sort_values('Pass Rate (%)')
    rates = ordered['Pass Rate (%)']

    fig, ax = plt.subplots(figsize=(10, 6))

    positions = np.arange(len(ordered))
    bar_colors = plt.cm.RdYlGn(rates / 100)
    rects = ax.bar(positions, rates, color=bar_colors, edgecolor='black', linewidth=0.8)

    ax.set_xlabel('Category', fontsize=13, fontweight='bold')
    ax.set_ylabel('Pass Rate (%)', fontsize=13, fontweight='bold')
    ax.set_title('Performance by Test Category',
                 fontsize=16, fontweight='bold', pad=20)
    ax.set_xticks(positions)
    ax.set_xticklabels(ordered.index, rotation=45, ha='right', fontsize=11)
    ax.set_ylim([0, 100])
    ax.grid(axis='y', alpha=0.3)

    # Percentage label centered 2 points above each bar
    for rect in rects:
        h = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2.0
        ax.text(centre_x, h + 2, f'{h:.1f}%',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"✅ Saved: {output_file}")
    plt.close(fig)

# Replacement text for every character LaTeX treats specially in text mode.
_LATEX_ESCAPES = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def _latex_escape(text):
    """Return *text* (converted to ``str``) with LaTeX special characters escaped."""
    return str(text).translate(_LATEX_ESCAPES)


def _latex_table(comment, caption, label, colspec, header, rows):
    """Render one booktabs ``table`` environment (ending in a newline) as a string."""
    lines = [
        f"% {comment}",
        "\\begin{table}[htbp]",
        "\\centering",
        f"\\caption{{{caption}}}",
        f"\\label{{{label}}}",
        f"\\begin{{tabular}}{{{colspec}}}",
        "\\toprule",
        f"{header} \\\\",
        "\\midrule",
        *(f"{row} \\\\" for row in rows),
        "\\bottomrule",
        "\\end{tabular}",
        "\\end{table}",
    ]
    return "\n".join(lines) + "\n"


def generate_latex_tables(model_rankings, test_difficulty, category_performance, provider_stats,
                          output_file='paper_tables.tex'):
    """Generate LaTeX tables for the paper.

    Writes four ``table`` environments: the top 15 models, the 10 hardest
    tests, all categories, and the top 10 providers. They use
    ``\\toprule``/``\\midrule``/``\\bottomrule``, so the including document
    needs the booktabs package. All text taken from the data is
    LaTeX-escaped.

    Args:
        model_rankings: Indexed by model name, with a 'Pass Rate (%)' column.
        test_difficulty: 'Test ID', 'name', 'category', 'Pass Rate (%)' columns.
        category_performance: Indexed by category, 'Pass Rate (%)' column.
        provider_stats: Indexed by provider, with 'num_models', 'pass_rate'
            and 'total_tests' columns.
        output_file: Destination path of the ``.tex`` file.
    """
    print("\n" + "="*80)
    print("GENERATING LATEX TABLES")
    print("="*80)

    # Table 1: top models, in the existing ranking order
    model_rows = []
    for rank, (model, row) in enumerate(model_rankings.head(15).iterrows(), 1):
        model_rows.append(
            f"{rank} & \\texttt{{{_latex_escape(model)}}} & {row['Pass Rate (%)']:.1f}\\%"
        )

    # Table 2: hardest tests. Truncate before escaping so an escape sequence
    # is never cut in half.
    test_rows = []
    for _, row in test_difficulty.sort_values('Pass Rate (%)').head(10).iterrows():
        name = str(row['name'])[:40] if pd.notna(row['name']) else ''
        category = str(row['category'])[:20] if pd.notna(row['category']) else ''
        test_rows.append(
            f"{int(row['Test ID'])} & {_latex_escape(name)} & "
            f"{row['Pass Rate (%)']:.1f}\\% & {_latex_escape(category)}"
        )

    # Table 3: all categories, hardest first
    category_rows = [
        f"{_latex_escape(category)} & {row['Pass Rate (%)']:.1f}\\%"
        for category, row in category_performance.sort_values('Pass Rate (%)').iterrows()
    ]

    # Table 4: first 10 providers in the given order
    provider_rows = [
        f"{_latex_escape(provider)} & {int(row['num_models'])} & "
        f"{row['pass_rate']:.1f}\\% & {int(row['total_tests'])}"
        for provider, row in provider_stats.head(10).iterrows()
    ]

    tables = [
        _latex_table("Table 1: Top 15 Performing Models",
                     "Top 15 Performing Models on 20 Diagnostic Tests",
                     "tab:top_models", "rlr",
                     "Rank & Model & Pass Rate", model_rows),
        _latex_table("Table 2: Test Difficulty (Hardest 10)",
                     "Test Difficulty Analysis (10 Hardest Tests)",
                     "tab:test_difficulty", "clrl",
                     "ID & Test Name & Pass Rate & Category", test_rows),
        _latex_table("Table 3: Category Performance",
                     "Performance by Test Category",
                     "tab:category_performance", "lr",
                     "Category & Pass Rate", category_rows),
        _latex_table("Table 4: Provider Comparison (Top 10)",
                     "Performance by Model Provider",
                     "tab:provider_comparison", "lrrr",
                     "Provider & Models & Pass Rate & Tests", provider_rows),
    ]

    # Explicit UTF-8: model and test names may contain non-ASCII characters.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("\n".join(tables))

    print(f"✅ Saved: {output_file}")

def _md_cell(value):
    """Return *value* as Markdown table-cell text, escaping '|' so it cannot split the row."""
    return str(value).replace('|', '\\|')


def _write_md_test_table(f, tests):
    """Write rows of the test-difficulty sheet to *f* as a Markdown table (NaN text -> 'N/A')."""
    f.write("| ID | Test Name | Category | Pass Rate |\n")
    f.write("|----|-----------|----------|----------:|\n")
    for _, row in tests.iterrows():
        name = row['name'] if pd.notna(row['name']) else 'N/A'
        category = row['category'] if pd.notna(row['category']) else 'N/A'
        f.write(f"| {int(row['Test ID'])} | {_md_cell(name)} | {_md_cell(category)} | "
                f"{row['Pass Rate (%)']:.1f}% |\n")


def generate_summary_report(all_results, model_rankings, test_difficulty, category_performance,
                            provider_stats, json_data, output_file='FINAL_RESULTS_SUMMARY.md'):
    """Generate a comprehensive Markdown summary report.

    The report combines the precomputed ``json_data['metadata']`` and
    ``json_data['summary']`` values with statistics recomputed from the
    per-evaluation rows. The "Key Insight" and "Conclusions" sections are
    derived from the data passed in rather than fixed text, so they stay
    accurate when the script is re-run on another result set.

    Args:
        all_results: Per-evaluation rows with 'model', 'test_id', 'passed'
            and 'status' columns.
        model_rankings: Indexed by model, with a 'Pass Rate (%)' column. The
            first 20 rows are reported in their existing order.
        test_difficulty: 'Test ID', 'name', 'category', 'Pass Rate (%)' columns.
        category_performance: Indexed by category, 'Pass Rate (%)' column.
        provider_stats: Indexed by provider, with 'num_models', 'pass_rate'
            and 'total_tests' columns. The first 15 rows are reported.
        json_data: Parsed results JSON with 'metadata' and 'summary' sections.
        output_file: Destination path of the Markdown report.
    """
    print("\n" + "="*80)
    print("GENERATING FINAL SUMMARY REPORT")
    print("="*80)

    metadata = json_data['metadata']
    summary = json_data['summary']

    # Figures recomputed from the raw rows; several sections below reuse them.
    success_rate = (all_results['status'] == 'success').mean() * 100
    model_pass_rates = all_results.groupby('model')['passed'].mean() * 100
    test_pass_rates = all_results.groupby('test_id')['passed'].mean() * 100
    sorted_tests = test_difficulty.sort_values('Pass Rate (%)')
    cat_rates = category_performance.sort_values('Pass Rate (%)')['Pass Rate (%)']
    provider_rates = provider_stats['pass_rate'].dropna()

    # Explicit UTF-8: the report contains non-ASCII text such as "≥", which
    # the platform default encoding may not be able to represent.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("# Comprehensive LLM Instruction Following Evaluation\n")
        f.write("## Final Results Summary\n\n")
        f.write(f"**Analysis Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n")
        f.write(f"**Data Source:** {metadata['timestamp']}\n\n")

        f.write("---\n\n")
        f.write("## Executive Summary\n\n")
        f.write(f"This evaluation tested **{metadata['total_models']} large language models** ")
        f.write(f"on **{metadata['total_tests']} diagnostic prompts**, ")
        f.write(f"resulting in **{metadata['total_results']} individual test evaluations**.\n\n")

        f.write("### Key Findings\n\n")
        f.write(f"- **Overall Pass Rate:** {summary['overall_pass_rate']:.1f}%\n")
        f.write(f"- **Best Performing Model:** {summary['best_model']} ({summary['best_model_score']:.1f}%)\n")
        f.write(f"- **Hardest Test:** Test {summary['hardest_test']} ({summary['hardest_test_pass_rate']:.1f}% pass rate)\n")
        f.write(f"- **API Success Rate:** {success_rate:.1f}%\n\n")

        f.write("---\n\n")
        f.write("## Top 20 Performing Models\n\n")
        f.write("| Rank | Model | Pass Rate |\n")
        f.write("|------|-------|----------:|\n")
        for idx, (model, row) in enumerate(model_rankings.head(20).iterrows(), 1):
            f.write(f"| {idx} | {_md_cell(model)} | {row['Pass Rate (%)']:.1f}% |\n")

        f.write("\n---\n\n")
        f.write("## Test Difficulty Analysis\n\n")
        f.write("### Hardest Tests (Lowest Pass Rates)\n\n")
        _write_md_test_table(f, sorted_tests.head(10))

        f.write("\n### Easiest Tests (Highest Pass Rates)\n\n")
        _write_md_test_table(f, sorted_tests.tail(10))

        f.write("\n---\n\n")
        f.write("## Performance by Category\n\n")
        f.write("| Category | Pass Rate |\n")
        f.write("|----------|----------:|\n")
        for category, pass_rate in cat_rates.items():
            f.write(f"| {_md_cell(category)} | {pass_rate:.1f}% |\n")

        # The insight is derived from the category table instead of being fixed text.
        if cat_rates.empty:
            f.write("\n")
        else:
            f.write(f"\n**Key Insight:** {cat_rates.index[0]} tests are the hardest category "
                    f"({cat_rates.iloc[0]:.1f}% pass rate), while {cat_rates.index[-1]} tests "
                    f"are the easiest ({cat_rates.iloc[-1]:.1f}% pass rate).\n\n")

        f.write("---\n\n")
        f.write("## Performance by Provider\n\n")
        f.write("Top providers (with ≥3 models):\n\n")
        f.write("| Provider | Models | Pass Rate | Total Tests |\n")
        f.write("|----------|--------|-----------|------------:|\n")
        for provider, row in provider_stats.head(15).iterrows():
            f.write(f"| {_md_cell(provider)} | {int(row['num_models'])} | "
                    f"{row['pass_rate']:.1f}% | {int(row['total_tests'])} |\n")

        f.write("\n---\n\n")
        f.write("## Statistical Insights\n\n")

        f.write("### Model Performance Distribution\n\n")
        f.write(f"- **Mean Pass Rate:** {model_pass_rates.mean():.1f}%\n")
        f.write(f"- **Median Pass Rate:** {model_pass_rates.median():.1f}%\n")
        f.write(f"- **Standard Deviation:** {model_pass_rates.std():.1f}%\n")
        f.write(f"- **Models with 100% Pass Rate:** {(model_pass_rates == 100).sum()}\n")
        f.write(f"- **Models with 0% Pass Rate:** {(model_pass_rates == 0).sum()}\n\n")

        f.write("### Test Difficulty Distribution\n\n")
        f.write(f"- **Mean Test Pass Rate:** {test_pass_rates.mean():.1f}%\n")
        f.write(f"- **Median Test Pass Rate:** {test_pass_rates.median():.1f}%\n")
        f.write(f"- **Standard Deviation:** {test_pass_rates.std():.1f}%\n")
        f.write(f"- **Tests with >80% Pass Rate:** {(test_pass_rates > 80).sum()}\n")
        f.write(f"- **Tests with <20% Pass Rate:** {(test_pass_rates < 20).sum()}\n\n")

        f.write("---\n\n")
        f.write("## Files Generated\n\n")
        f.write("- `fig1_heatmap.pdf` - Performance heatmap (top 50 models)\n")
        f.write("- `fig2_provider.pdf` - Provider comparison chart\n")
        f.write("- `fig3_difficulty.pdf` - Test difficulty ranking\n")
        f.write("- `fig4_category.pdf` - Category performance chart\n")
        f.write("- `paper_tables.tex` - LaTeX tables for paper\n")
        f.write(f"- `{output_file}` - This summary document\n\n")

        f.write("---\n\n")
        f.write("## Conclusions\n\n")

        # Every conclusion is computed from the data; optional ones are
        # skipped when their source table is empty.
        conclusions = [
            "**Model Performance Spread:** Per-model pass rates range from "
            f"{model_pass_rates.min():.1f}% to {model_pass_rates.max():.1f}% "
            f"(median {model_pass_rates.median():.1f}%).",
        ]
        if not cat_rates.empty:
            conclusions.append(
                f"**Hardest Category:** {cat_rates.index[0]} has the lowest pass rate "
                f"({cat_rates.iloc[0]:.1f}%); {cat_rates.index[-1]} has the highest "
                f"({cat_rates.iloc[-1]:.1f}%)."
            )
        if not provider_rates.empty:
            conclusions.append(
                "**Provider Differences:** Provider average pass rates range from "
                f"{provider_rates.min():.1f}% ({provider_rates.idxmin()}) to "
                f"{provider_rates.max():.1f}% ({provider_rates.idxmax()})."
            )
        conclusions.append(
            f"**Perfect Scores:** {(model_pass_rates == 100).sum()} of {len(model_pass_rates)} "
            "models achieve a 100% pass rate."
        )
        conclusions.append(
            f"**API Reliability:** {success_rate:.1f}% of evaluations have API status 'success'."
        )
        for number, text in enumerate(conclusions, 1):
            f.write(f"{number}. {text}\n\n")

    print(f"✅ Saved: {output_file}")

def main():
    """Run the whole pipeline: load results, print reports, plot, and export.

    Each generator is called with its default output file name, so all
    artifacts land in the current working directory.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("COMPREHENSIVE FINAL ANALYSIS")
    print(f"Analysis Date: {datetime.now().strftime('%Y-%m-%d %H:%M')}")
    print(banner + "\n")

    results, rankings, difficulty, categories, raw_json = load_results()

    # Console reports
    print_summary_statistics(results, raw_json)
    print_top_models(rankings)
    print_test_difficulty(difficulty)
    print_category_analysis(categories)

    providers = analyze_by_provider(results)

    # Figures
    create_heatmap(results)
    create_provider_chart(providers)
    create_difficulty_chart(difficulty)
    create_category_chart(categories)

    # Paper artifacts
    generate_latex_tables(rankings, difficulty, categories, providers)
    generate_summary_report(results, rankings, difficulty, categories, providers, raw_json)

    print("\n" + banner)
    print("✅ COMPREHENSIVE ANALYSIS COMPLETE")
    print(banner)
    print("\nAll figures, tables, and reports have been generated successfully!")


if __name__ == "__main__":
    main()