File size: 4,270 Bytes
16b9846
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import json
import pandas as pd
from typing import Dict, List
import argparse

# Command-line interface. Parsed at import time so that `args` is available
# to the __main__ block at the bottom of the script.
_parser = argparse.ArgumentParser(description='Process model results')
_parser.add_argument('--modelname', type=str, help='Model name to use as prefix')
args = _parser.parse_args()

def find_result_files(root_dir: str) -> List[tuple]:
    """
    Locate results JSON files under *root_dir* together with their checkpoint label.

    Walks the directory tree, considering only directories whose path contains
    'result' (case-insensitive) and files named 'result*.json'. The checkpoint
    label is the first path component starting with 'checkpoint-'; files whose
    path has no such component are skipped.

    Returns:
        List[tuple]: (checkpoint_label, file_path) pairs, in os.walk order.
    """
    found = []

    for current_dir, _subdirs, filenames in os.walk(root_dir):
        if 'result' not in current_dir.lower():
            continue

        # The checkpoint label is shared by every file in this directory,
        # so resolve it once per directory.
        checkpoint = next(
            (part for part in current_dir.split(os.sep)
             if part.startswith('checkpoint-')),
            None,
        )
        if checkpoint is None:
            continue

        for name in filenames:
            if name.startswith('result') and name.endswith('.json'):
                found.append((checkpoint, os.path.join(current_dir, name)))

    return found

def extract_metrics(json_path: str) -> Dict[str, float]:
    """
    Extract specific accuracy metrics from a results JSON file and convert to percentages.

    Parameters:
        json_path (str): Path to a JSON file with a top-level "results" mapping
            of task name -> metrics dict containing an "acc,none" fraction
            (presumably lm-evaluation-harness output — verify with producer).

    Returns:
        Dict[str, float]: one key per target task with the accuracy as a
        percentage rounded to 3 decimal places (None when the task or its
        "acc,none" value is absent), plus an 'average' key over the non-None
        task values (None if no task had a value).
    """
    with open(json_path, 'r') as f:
        data = json.load(f)

    target_tasks = [
        'medmcqa',
        'medqa_4options',
        'mmlu_anatomy',
        'mmlu_clinical_knowledge',
        'mmlu_college_biology',
        'mmlu_college_medicine',
        'mmlu_medical_genetics',
        'mmlu_professional_medicine',
        'pubmedqa'
    ]

    results = data.get('results', {})
    metrics: Dict[str, float] = {}
    for task in target_tasks:
        # Fix: always emit the key — previously tasks absent from `results`
        # were dropped entirely, giving different checkpoints different
        # column sets in the final CSV.
        value = results.get(task, {}).get('acc,none')
        metrics[task] = round(value * 100, 3) if value is not None else None

    # Average only over tasks that actually reported a value.
    valid_metrics = [v for v in metrics.values() if v is not None]
    if valid_metrics:
        metrics['average'] = round(sum(valid_metrics) / len(valid_metrics), 3)
    else:
        metrics['average'] = None

    return metrics

def process_all_results(root_dir: str, output_file: str = 'model_metrics.csv', checkpoint_prefix: str = None):
    """
    Process all result files and create a CSV with metrics as percentages.

    Improvement: rows are sorted by the numeric suffix of the checkpoint name
    ('checkpoint-<N>'), so the CSV is deterministic and in training order
    instead of following arbitrary filesystem walk order.

    Parameters:
        root_dir (str): Root directory to search for result files
        output_file (str): Output CSV filename
        checkpoint_prefix (str): Optional prefix to add before checkpoint numbers (e.g., "model_name_")
    """
    result_files = find_result_files(root_dir)

    def _checkpoint_key(item):
        # Sort numerically on 'checkpoint-<N>'; unexpected names fall back
        # to string order after all numeric ones.
        suffix = item[0].split('-')[-1]
        return (0, int(suffix)) if suffix.isdigit() else (1, suffix)

    result_files.sort(key=_checkpoint_key)

    all_metrics = []
    for checkpoint, file_path in result_files:
        metrics = extract_metrics(file_path)
        # Add prefix to checkpoint if provided
        if checkpoint_prefix:
            metrics['checkpoint'] = f"{checkpoint_prefix}{checkpoint}"
        else:
            metrics['checkpoint'] = checkpoint
        all_metrics.append(metrics)

    if not all_metrics:
        print("No result files found.")
        return

    df = pd.DataFrame(all_metrics)
    # Reorder columns to put checkpoint and average first
    cols = ['checkpoint', 'average'] + [col for col in df.columns if col not in ('checkpoint', 'average')]
    df = df[cols]

    # Format float columns to 3 decimal places; missing values pass through.
    for col in df.columns:
        if col != 'checkpoint':
            df[col] = df[col].apply(lambda x: f"{x:.3f}" if pd.notnull(x) else x)

    df.to_csv(output_file, index=False)
    print(f"Results saved to {output_file}")

# Script entry point
if __name__ == "__main__":
    # Search the current working directory by default; adjust as needed.
    root_directory = "."

    # Without a prefix (original behavior): process_all_results(root_directory)

    process_all_results(
        root_directory,
        output_file='model_metrics.csv',
        checkpoint_prefix=args.modelname,
    )