File size: 2,357 Bytes
9d5b280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import csv
from benchmarking import BenchmarkEvaluator
import logging
import glob


# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the functions defined in this file.
logger = logging.getLogger(__name__)

def read_models(models_file):
    """Load model names from a text file, one name per line.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are dropped.

    Args:
        models_file: Path of the text file to read.

    Returns:
        A list of the non-empty, stripped lines in file order. If anything
        goes wrong while opening or reading the file, the error is logged
        and an empty list is returned instead.
    """
    try:
        with open(models_file, 'r') as handle:
            # Materialize the list while the file is still open.
            stripped_lines = (raw.strip() for raw in handle)
            names = [name for name in stripped_lines if name]
    except Exception as err:
        logger.error(f"Error reading models file: {err}")
        return []
    return names

def evaluate_model(model_name, file_path):
    """Run the benchmark for one model and summarize the outcome.

    Builds a ``BenchmarkEvaluator`` for ``model_name``, calls its
    ``run_benchmark`` with ``file_path``, and copies the ``accuracy``,
    ``total_samples`` and ``processed_samples`` entries of the result.

    Args:
        model_name: Identifier passed to ``BenchmarkEvaluator``.
        file_path: Benchmark file passed to ``run_benchmark``.

    Returns:
        A dict with keys ``model_name``, ``accuracy``, ``total_samples`` and
        ``processed_samples``, or ``None`` if any exception is raised along
        the way (the error is logged).
    """
    copied_keys = ('accuracy', 'total_samples', 'processed_samples')
    try:
        benchmark = BenchmarkEvaluator(model_name).run_benchmark(file_path)
        summary = {'model_name': model_name}
        for key in copied_keys:
            summary[key] = benchmark[key]
        return summary
    except Exception as err:
        logger.error(f"Error evaluating model {model_name}: {err}")
        return None

def save_results_to_csv(results, output_file='model_results.csv'):
    """Write evaluation result rows to a CSV file.

    The file is opened for writing (replacing any existing content) and gets
    a header row followed by one row per entry in ``results``.

    Args:
        results: Iterable of dicts keyed by ``model_name``, ``accuracy``,
            ``total_samples`` and ``processed_samples``.
        output_file: Path of the CSV file to write.

    Returns:
        None. Success is logged at INFO level; any exception raised while
        writing is logged at ERROR level and not re-raised.
    """
    columns = ['model_name', 'accuracy', 'total_samples', 'processed_samples']
    try:
        # newline='' lets the csv module control line endings itself.
        with open(output_file, 'w', newline='') as out:
            table = csv.DictWriter(out, fieldnames=columns)
            table.writeheader()
            for row in results:
                table.writerow(row)
        logger.info(f"Results saved to {output_file}")
    except Exception as err:
        logger.error(f"Error saving results to CSV: {err}")

def main(models=None, file_path="_Benchmarking_DB.json",
         output_file='model_results.csv'):
    """Evaluate a list of models on a benchmark file and save a CSV summary.

    Each model is evaluated with ``evaluate_model``; successful results are
    printed and collected, then handed to ``save_results_to_csv``. Models for
    which ``evaluate_model`` returns nothing are skipped.

    Args:
        models: Iterable of model names/paths to evaluate. When omitted, the
            three default ``kista`` checkpoints are used.
        file_path: Benchmark file passed to ``evaluate_model`` for each model.
        output_file: CSV path passed to ``save_results_to_csv``.

    Returns:
        None. If the model list is empty, an error is logged and nothing is
        evaluated or written.
    """
    if models is None:
        models = [
            "../kista_checkpoint-1199",
            "../kista_checkpoint-2398",
            "../kista_checkpoint-3597",
        ]
    else:
        # Copy so a one-shot iterable can be checked for emptiness and looped.
        models = list(models)

    if not models:
        # The list is not read from a file, so the message no longer says so.
        logger.error("No models to evaluate")
        return

    all_results = []

    for model_name in models:
        logger.info(f"Evaluating model: {model_name}")
        result = evaluate_model(model_name, file_path)
        if not result:
            # Evaluation failed; evaluate_model has already logged the error.
            continue
        all_results.append(result)
        print(f"\nResults for {model_name}:")
        print(f"Accuracy: {result['accuracy']:.2%}")
        print(f"Total samples: {result['total_samples']}")
        print(f"Processed_samples: {result['processed_samples']}")

    save_results_to_csv(all_results, output_file)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()