from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import json
import os
import time

from tqdm import tqdm


@dataclass
class BenchmarkResult:
    """Container for benchmark results"""
    benchmark_name: str
    model_name: str
    total_questions: int
    correct: int
    accuracy: float
    avg_response_time: float
    raw_results: List[Dict[str, Any]]
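

# A minimal sketch (an addition, not part of the original file): because
# BenchmarkResult is a dataclass, it converts cleanly to a JSON-serializable
# dict via dataclasses.asdict. The helper name below is hypothetical.
def benchmark_result_to_dict(result: BenchmarkResult) -> Dict[str, Any]:
    """Convert a BenchmarkResult into a plain dict, e.g. for json.dump."""
    from dataclasses import asdict  # local import keeps the sketch self-contained
    return asdict(result)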


class BaseBenchmark(ABC):
    """Base class for all benchmark implementations"""

    def __init__(self, name: str, dataset_name: Optional[str] = None):
        self.name = name
        self.dataset_name = dataset_name or name
        self.dataset = None
        self.results = []

    @abstractmethod
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the benchmark dataset"""
        pass

    @abstractmethod
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single sample; return (is_correct, result_dict)"""
        pass

    @abstractmethod
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt for the model"""
        pass

    async def run_benchmark(self, api, sample_size: Optional[int] = None, **kwargs) -> BenchmarkResult:
        """Run the benchmark on the given API"""
        print(f"Running {self.name} benchmark on {api.model_name}...")

        # Load dataset
        await self.load_dataset(sample_size, **kwargs)
        if not self.dataset:
            raise ValueError(f"No dataset loaded for {self.name}")

        # Prepare samples
        samples = self.dataset if sample_size is None else self.dataset[:sample_size]
        total_samples = len(samples)

        # Run evaluation
        correct_count = 0
        response_times = []
        raw_results = []

        # Use an async semaphore to cap the number of concurrent requests
        concurrent_limit = kwargs.get('concurrent_requests', 5)
        semaphore = asyncio.Semaphore(concurrent_limit)

        async def evaluate_with_semaphore(sample, idx):
            async with semaphore:
                start_time = time.time()
                is_correct, result = await self.evaluate_sample(api, sample, **kwargs)
                end_time = time.time()
                result['response_time'] = end_time - start_time
                result['index'] = idx
                return is_correct, result

        # Create tasks for all samples
        tasks = [evaluate_with_semaphore(sample, idx) for idx, sample in enumerate(samples)]

        # Run with progress bar
        with tqdm(total=total_samples, desc=f"{self.name}") as pbar:
            for coro in asyncio.as_completed(tasks):
                is_correct, result = await coro
                if is_correct:
                    correct_count += 1
                response_times.append(result['response_time'])
                raw_results.append(result)
                pbar.update(1)

                # --- START: REAL-TIME PROGRESS SAVING ---
                # Every 10 completed samples, persist partial results to disk
                if pbar.n > 0 and pbar.n % 10 == 0:
                    # Ensure the results directory exists
                    results_dir = kwargs.get('output_dir', 'results')
                    os.makedirs(results_dir, exist_ok=True)
                    progress_path = os.path.join(results_dir, f'{self.name}_progress.json')
                    # Sort results by index before saving
                    sorted_progress = sorted(raw_results, key=lambda x: x['index'])
                    try:
                        with open(progress_path, 'w') as f:
                            json.dump(sorted_progress, f, indent=2)
                    except Exception as e:
                        print(f"Error saving progress: {e}")
                # --- END: REAL-TIME PROGRESS SAVING ---

        # Calculate metrics
        accuracy = correct_count / total_samples if total_samples > 0 else 0
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0

        return BenchmarkResult(
            benchmark_name=self.name,
            model_name=api.model_name,
            total_questions=total_samples,
            correct=correct_count,
            accuracy=accuracy,
            avg_response_time=avg_response_time,
            raw_results=sorted(raw_results, key=lambda x: x['index'])
        )
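

# --- Usage sketch ----------------------------------------------------------
# A minimal, hypothetical example of wiring a concrete benchmark to this base
# class. EchoBenchmark, _DummyAPI, and the two-question dataset below are
# illustrative assumptions, not part of this project; a real API client would
# call an actual model.
class EchoBenchmark(BaseBenchmark):
    """Toy benchmark: a response is 'correct' if it matches the expected answer."""

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        self.dataset = [
            {"question": "2 + 2 = ?", "answer": "4"},
            {"question": "Capital of France?", "answer": "Paris"},
        ]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        return sample["question"]

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        response = await api.generate(self.format_prompt(sample))
        is_correct = response.strip() == sample["answer"]
        return is_correct, {"question": sample["question"], "response": response}


class _DummyAPI:
    """Stand-in for a real model API client (hypothetical interface)."""
    model_name = "dummy-model"

    async def generate(self, prompt: str) -> str:
        # Always answers "4": right for the arithmetic question, wrong otherwise.
        return "4"


if __name__ == "__main__":
    result = asyncio.run(EchoBenchmark("echo").run_benchmark(_DummyAPI()))
    print(f"{result.benchmark_name}: {result.correct}/{result.total_questions} "
          f"({result.accuracy:.0%}) in {result.avg_response_time:.3f}s avg")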