File size: 4,434 Bytes
9d5b280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import ast
import json
import logging
import os
import uuid
from datetime import datetime
from typing import Dict, List

from tqdm import tqdm

from model_run import VLLMClient

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class BenchmarkEvaluator:
    """Evaluate a vLLM-served model on a JSON benchmark and report accuracy.

    Each benchmark entry supplies a system prompt, an input text, and a
    ground-truth ``output`` dict carrying an ``is_met`` verdict; the model's
    parsed reply exposes ``assessments[0]['is_met']``, which is compared
    against the ground truth case-insensitively.
    """

    def __init__(self, model_path: str):
        """Create the vLLM client and derive the per-model log directory name.

        Args:
            model_path: Model path; the second '/'-separated component
                (e.g. the model name in "org/model") names the log directory.
        """
        self.client = VLLMClient(model_path)
        # The original indexed [1] unconditionally and raised IndexError for a
        # bare model name with no '/'; fall back to the whole path instead.
        parts = model_path.split('/')
        self.nest_name = parts[1] if len(parts) > 1 else parts[0]

    def load_data(self, file_path: str) -> List[Dict]:
        """Load benchmark entries from a JSON file.

        Args:
            file_path: Path to a JSON file containing a list of entries.

        Returns:
            The parsed list of entry dicts.

        Raises:
            Exception: re-raised (after logging) if the file cannot be read
                or parsed.
        """
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise

    def log_api_call(self, input_data: Dict, api_response: Dict, ground_truth: str, error: str = None) -> None:
        """Persist one API call (input, response, ground truth, error) as JSON.

        A unique file is written under ``benchmark_logs/<nest_name>/`` named
        with a timestamp plus a short UUID so concurrent calls never collide.
        """
        # Per-model log directory, created on demand.
        log_dir = f"benchmark_logs/{self.nest_name}"
        os.makedirs(log_dir, exist_ok=True)

        # The timestamp is only second-granular; the UUID suffix keeps
        # filenames unique for calls landing within the same second.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]
        filename = f"{log_dir}/api_call_{timestamp}_{unique_id}.json"

        # Payloads are str()-ified so a non-JSON-serializable object can
        # never make the logging step itself raise.
        log_data = {
            "timestamp": datetime.now().isoformat(),
            "input": str(input_data),
            "ground_truth": str(ground_truth),
            "api_response": str(api_response),
            "error": error,
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(log_data, f, indent=2, ensure_ascii=False)

    def get_model_response(self, system_prompt: str, input_text: str, ground_truth: str) -> Dict:
        """Query the model and parse its reply into a dict.

        Args:
            system_prompt: System prompt sent to the model.
            input_text: User input to evaluate.
            ground_truth: Expected output, recorded in the call log.

        Returns:
            The parsed response dict, or None if the call or parsing failed
            (the failure is logged in both cases).
        """
        input_data = {
            "system_prompt": system_prompt,
            "input_text": input_text,
        }

        try:
            response = self.client.send_message(system_prompt, input_text)
            # SECURITY FIX: the original ran eval() on model output, which
            # executes arbitrary code. literal_eval safely accepts the same
            # Python literals; fall back to JSON for true/false/null replies.
            raw = response['result']
            try:
                parsed_response = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                parsed_response = json.loads(raw)

            # Log successful API call.
            self.log_api_call(input_data, parsed_response, ground_truth)
            return parsed_response
        except Exception as e:
            # BUG FIX: str(e) was previously passed positionally as the
            # ground_truth argument, so the log's "error" field was always
            # None and the actual ground truth was dropped from the record.
            self.log_api_call(input_data, None, ground_truth, error=str(e))
            logger.error(f"Error getting model response: {e}")
            return None

    def normalize_is_met(self, value: str) -> str:
        """Return ``value`` lower-cased, stringifying non-strings first."""
        if not isinstance(value, str):
            return str(value).lower()
        return value.lower()

    def calculate_accuracy(self, ground_truth: List[Dict], model_outputs: List[Dict]) -> float:
        """Fraction of paired entries whose normalized ``is_met`` verdicts agree.

        Args:
            ground_truth: Entries exposing ``entry['output']['is_met']``.
            model_outputs: Parsed responses exposing
                ``response['assessments'][0]['is_met']``.

        Returns:
            Accuracy in [0, 1]; 0 for empty input.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(ground_truth) != len(model_outputs):
            raise ValueError("Ground truth and model outputs must have the same length")

        total = len(ground_truth)
        correct = sum(
            1
            for gt, mo in zip(ground_truth, model_outputs)
            if self.normalize_is_met(gt['output']['is_met'])
            == self.normalize_is_met(mo['assessments'][0]['is_met'])
        )
        return correct / total if total > 0 else 0

    def run_benchmark(self, file_path: str) -> Dict:
        """Run the complete benchmark: query the model per entry and score it.

        Args:
            file_path: Path to the benchmark JSON file.

        Returns:
            Dict with 'accuracy', 'total_samples' (entries in the file), and
            'processed_samples' (entries that yielded a model response).
        """
        data = self.load_data(file_path)

        model_outputs = []
        ground_truth = []

        for entry in tqdm(data, desc="Processing entries"):
            model_response = self.get_model_response(
                entry['system_prompt'],
                entry['input'],
                entry['output'],
            )

            # `is not None` (not plain truthiness) so a legitimately-empty
            # dict response is still counted as processed.
            if model_response is not None:
                model_outputs.append(model_response)
                ground_truth.append(entry)

        accuracy = self.calculate_accuracy(ground_truth, model_outputs)

        return {
            'accuracy': accuracy,
            'total_samples': len(data),
            'processed_samples': len(model_outputs),
        }