|
|
import ast
import json
import logging
from typing import Dict, List

from tqdm import tqdm

from model_run import VLLMClient
|
|
|
|
|
|
|
|
# Configure root logging once at import time; all methods in this file
# report through the module-level logger below.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)
|
|
|
|
|
class BenchmarkEvaluator:
    """Evaluate a model served through VLLMClient against a JSON benchmark.

    Loads benchmark entries, queries the model for each one, logs every API
    call to disk under ``benchmark_logs/<nest_name>/``, and scores accuracy
    on the ``is_met`` field.
    """

    def __init__(self, model_path: str):
        """Create an evaluator backed by the model at *model_path*.

        Args:
            model_path: Model identifier, typically "org/name"; the second
                path component becomes the log-directory name.
        """
        self.client = VLLMClient(model_path)
        # "org/name" -> "name"; fall back to the whole path when there is
        # no '/' so a bare model name does not raise IndexError.
        parts = model_path.split('/')
        self.nest_name = parts[1] if len(parts) > 1 else parts[0]

    def load_data(self, file_path: str) -> List[Dict]:
        """Load benchmark entries from a JSON file.

        Args:
            file_path: Path to a JSON file holding a list of entry dicts.

        Returns:
            The parsed list of entries.

        Raises:
            Exception: Anything raised while reading/parsing is logged and
                re-raised unchanged.
        """
        try:
            # Explicit encoding for consistency with the UTF-8 log writer.
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise

    def log_api_call(self, input_data: Dict, api_response: Dict, ground_truth: str, error: str = None) -> None:
        """Write one API call record (input, response, ground truth, error) to JSON.

        Files land under benchmark_logs/<nest_name>/ named with a timestamp
        plus a short UUID so calls in the same second cannot collide.
        """
        import os
        from datetime import datetime
        import uuid

        log_dir = f"benchmark_logs/{self.nest_name}"
        os.makedirs(log_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_id = str(uuid.uuid4())[:8]
        filename = f"{log_dir}/api_call_{timestamp}_{unique_id}.json"

        log_data = {
            "timestamp": datetime.now().isoformat(),
            # str() keeps the record JSON-serializable for arbitrary payloads.
            "input": str(input_data),
            "ground_truth": str(ground_truth),
            "api_response": str(api_response),
            "error": error
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(log_data, f, indent=2, ensure_ascii=False)

    def get_model_response(self, system_prompt: str, input_text: str, ground_truth: str) -> Dict:
        """Query the model once and parse its result.

        Args:
            system_prompt: System prompt sent to the model.
            input_text: User input to evaluate.
            ground_truth: Expected output, recorded in the call log.

        Returns:
            The parsed response dict, or None when the call or parsing fails.
        """
        input_data = {
            "system_prompt": system_prompt,
            "input_text": input_text
        }

        try:
            response = self.client.send_message(system_prompt, input_text)

            # ast.literal_eval replaces eval(): the result is model-generated
            # text, and eval() on it would execute arbitrary code.
            parsed_response = ast.literal_eval(response['result'])

            self.log_api_call(input_data, parsed_response, ground_truth)

            return parsed_response
        except Exception as e:
            # Fix: the error text was previously passed positionally into the
            # ground_truth parameter, leaving the error field empty in logs.
            self.log_api_call(input_data, None, ground_truth, error=str(e))
            logger.error(f"Error getting model response: {e}")
            return None

    def normalize_is_met(self, value: str) -> str:
        """Return *value* lower-cased, coercing non-strings via str() first."""
        if not isinstance(value, str):
            return str(value).lower()
        return value.lower()

    def calculate_accuracy(self, ground_truth: List[Dict], model_outputs: List[Dict]) -> float:
        """Fraction of entries where the model's is_met matches ground truth.

        Args:
            ground_truth: Entries carrying ``output.is_met``.
            model_outputs: Parsed responses carrying ``assessments[0].is_met``.

        Returns:
            Accuracy in [0, 1]; 0 when both lists are empty.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(ground_truth) != len(model_outputs):
            raise ValueError("Ground truth and model outputs must have the same length")

        total = len(ground_truth)
        correct = sum(
            1
            for gt, mo in zip(ground_truth, model_outputs)
            if self.normalize_is_met(gt['output']['is_met'])
            == self.normalize_is_met(mo['assessments'][0]['is_met'])
        )

        return correct / total if total > 0 else 0

    def run_benchmark(self, file_path: str) -> Dict:
        """Run the full benchmark: load data, query the model, score accuracy.

        Args:
            file_path: Path to the benchmark JSON file.

        Returns:
            Dict with 'accuracy', 'total_samples', and 'processed_samples'
            (the number of entries whose model call succeeded).
        """
        data = self.load_data(file_path)

        model_outputs = []
        ground_truth = []

        for entry in tqdm(data, desc="Processing entries"):
            model_response = self.get_model_response(
                entry['system_prompt'],
                entry['input'],
                entry['output']
            )

            # Truthiness check deliberately drops both None (failed call) and
            # empty responses, which could not be scored below anyway.
            if model_response:
                model_outputs.append(model_response)
                ground_truth.append(entry)

        accuracy = self.calculate_accuracy(ground_truth, model_outputs)

        return {
            'accuracy': accuracy,
            'total_samples': len(data),
            'processed_samples': len(model_outputs)
        }
|
|
|