"""Helpers for running an experiment function several times and scoring it.

The entry point is :func:`run_shinka_eval`, which loads a program from a file
path, calls one of its functions ``num_runs`` times (optionally in a child
process with a timeout), validates and aggregates the results, and writes
``metrics.json`` / ``correct.json`` (plus ``extra.pkl`` when the metrics carry
an ``"extra_data"`` entry) into a results directory.
"""

import copy
import importlib.util
import json
import multiprocessing
import os
import pickle
import queue
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

DEFAULT_METRICS_ON_ERROR = {
    "combined_score": 0.0,
    "execution_time_mean": 0.0,
    "execution_time_std": 0.0,
    "num_successful_runs": 0,
    "num_valid_runs": 0,
    "num_invalid_runs": 0,
    "all_validation_errors": [],
}

# Default timeout for each run (in seconds)
DEFAULT_RUN_TIMEOUT = 120  # 2 minutes per run

# Longest single blocking wait on the result queue while polling the worker.
_QUEUE_POLL_INTERVAL = 0.05
# How long a worker that already delivered its outcome may take to exit.
_WORKER_EXIT_GRACE = 1.0
# How long to wait after terminate()/kill() before escalating or giving up.
_WORKER_TERMINATE_GRACE = 5


class TimeoutError(Exception):
    """Raised when execution exceeds timeout limit.

    NOTE: this intentionally shadows the builtin ``TimeoutError`` inside this
    module; the name is part of the module's public interface and is kept.
    """


def _run_function_with_timeout(func, kwargs, result_queue, error_queue):
    """
    Worker function to run experiment_fn in a separate process.

    Args:
        func: The function to execute
        kwargs: Keyword arguments for the function
        result_queue: Queue to put the result
        error_queue: Queue to put any errors, as ``(exception type name,
            message)`` tuples
    """
    try:
        result = func(**kwargs)
        result_queue.put(result)
    except Exception as e:
        # Only the type name and message are sent back, so the parent never
        # has to unpickle an arbitrary exception object.
        error_queue.put((type(e).__name__, str(e)))


def _fetch_outcome(result_queue, error_queue, wait_seconds):
    """Return ``("error", payload)``, ``("result", payload)`` or ``None``.

    The error queue is checked without blocking; the result queue is waited on
    for at most ``wait_seconds``. ``None`` means nothing has arrived yet (a
    tuple is used so that a legitimate ``None`` result is distinguishable).
    """
    try:
        return "error", error_queue.get_nowait()
    except queue.Empty:
        pass
    try:
        return "result", result_queue.get(timeout=wait_seconds)
    except queue.Empty:
        return None


def _reap_worker(process, grace_seconds):
    """Ensure the worker process is gone, escalating join -> terminate -> kill."""
    process.join(timeout=grace_seconds)
    if process.is_alive():
        process.terminate()
        # Give it 5 seconds to terminate gracefully
        process.join(timeout=_WORKER_TERMINATE_GRACE)
        if process.is_alive():
            process.kill()  # Force kill if still alive
            process.join(timeout=_WORKER_TERMINATE_GRACE)


def run_with_timeout(func, kwargs, timeout_seconds):
    """
    Execute a function with a timeout using multiprocessing.

    The worker's queues are drained *before* the process is joined. A child
    that has put a large object on a ``multiprocessing.Queue`` cannot exit
    until the parent reads it, so joining first would stall for the whole
    timeout and then misreport a finished run as a timeout.

    Args:
        func: The function to execute. With a start method that pickles the
            target's arguments (e.g. ``spawn``) it must be picklable.
        kwargs: Keyword arguments for the function
        timeout_seconds: Maximum time allowed (seconds) for the worker to
            deliver a result or an error

    Returns:
        The result of the function call

    Raises:
        TimeoutError: If execution exceeds timeout
        RuntimeError: If the function raised (message is
            ``"<ExceptionType>: <message>"``) or the worker exited without
            delivering anything
    """
    result_queue = multiprocessing.Queue()
    error_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=_run_function_with_timeout,
        args=(func, kwargs, result_queue, error_queue),
    )
    process.start()

    deadline = time.monotonic() + timeout_seconds
    timed_out = False
    try:
        while True:
            # Sample liveness BEFORE polling: anything the worker managed to
            # send before exiting is already readable once it is seen dead,
            # so an empty poll after this point really means "no outcome".
            worker_exited = not process.is_alive()

            remaining = deadline - time.monotonic()
            wait_seconds = min(_QUEUE_POLL_INTERVAL, max(remaining, 0.0))
            outcome = _fetch_outcome(result_queue, error_queue, wait_seconds)

            if outcome is not None:
                kind, payload = outcome
                if kind == "error":
                    error_type, error_msg = payload
                    raise RuntimeError(f"{error_type}: {error_msg}")
                return payload

            if worker_exited:
                raise RuntimeError("Function completed but no result was returned")

            if time.monotonic() >= deadline:
                timed_out = True
                raise TimeoutError(
                    f"Execution exceeded timeout of {timeout_seconds} seconds"
                )
    finally:
        # A timed-out worker is stopped immediately; one that already
        # delivered its outcome gets a short grace period to exit by itself.
        _reap_worker(process, 0.0 if timed_out else _WORKER_EXIT_GRACE)
        result_queue.close()
        error_queue.close()


def load_program(program_path: str) -> Any:
    """Loads a Python module dynamically from a given file path.

    The module is created under the name ``"program"`` and executed; it is
    not inserted into ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
    """
    spec = importlib.util.spec_from_file_location("program", program_path)
    if spec is None:
        raise ImportError(f"Could not load spec for module at {program_path}")
    if spec.loader is None:
        raise ImportError(f"Spec loader is None for module at {program_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _json_default(value: Any) -> Any:
    """Convert numpy scalars/arrays to plain Python for ``json.dump``.

    Anything else is rejected with the same kind of ``TypeError`` the json
    module raises on its own, so unrelated types are never silently coerced.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f"Object of type {type(value).__name__} is not JSON serializable"
    )


def save_json_results(
    results_dir: str,
    metrics: Dict[str, Any],
    correct: bool,
    error: Optional[str] = None,
) -> None:
    """Saves metrics and correctness status to JSON files.

    Writes ``correct.json`` (``{"correct": ..., "error": ...}``) and
    ``metrics.json`` into ``results_dir``, creating the directory if needed.
    numpy scalars and arrays inside the payloads are converted to plain
    Python values instead of aborting the dump.
    """
    os.makedirs(results_dir, exist_ok=True)

    correct_payload = {"correct": correct, "error": error}
    correct_file = os.path.join(results_dir, "correct.json")
    with open(correct_file, "w", encoding="utf-8") as f:
        json.dump(correct_payload, f, indent=4, default=_json_default)
    print(f"Correctness and error status saved to {correct_file}")

    metrics_file = os.path.join(results_dir, "metrics.json")
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, default=_json_default)
    print(f"Metrics saved to {metrics_file}")


def run_shinka_eval(
    program_path: str,
    results_dir: str,
    experiment_fn_name: str,
    num_runs: int,
    get_experiment_kwargs: Optional[Callable[[int], Dict[str, Any]]] = None,
    aggregate_metrics_fn: Optional[Callable[[List[Any]], Dict[str, Any]]] = None,
    validate_fn: Optional[Callable[[Any], Tuple[bool, Optional[str]]]] = None,
    default_metrics_on_error: Optional[Dict[str, Any]] = None,
    timeout_seconds: Optional[float] = None,
) -> Tuple[Dict[str, Any], bool, Optional[str]]:
    """
    Runs an experiment multiple times, collects results, optionally validates,
    computes metrics, and saves them.

    A run that times out is counted as invalid and the loop moves on to the
    next run. Any other exception (loading the program, a run raising, the
    validation or aggregation callbacks raising) aborts the evaluation and the
    fallback metrics are returned instead.

    Args:
        program_path: Path to the Python script/module to evaluate.
        results_dir: Directory to save `metrics.json` and `correct.json`.
        experiment_fn_name: Name of function to call in the loaded module.
        num_runs: Number of times to run the experiment function.
        get_experiment_kwargs: Opt. fn (run_idx_0_based -> kwargs_dict)
                               for experiment args. If None, each run
                               receives ``{"seed": run_idx + 1}``.
        aggregate_metrics_fn: Opt. fn (raw_results_list -> metrics_dict)
                              for aggregation. If None, basic run stats
                              (count, time) are recorded.
        validate_fn: Opt. fn (result -> (is_valid, error_msg)) to validate
                     each run. Affects overall correctness.
        default_metrics_on_error: Metrics for eval failure. Uses predefined
                                  default if None. Only keys that also exist
                                  in DEFAULT_METRICS_ON_ERROR are used.
        timeout_seconds: Maximum time allowed for each run (seconds).
                        If None, uses DEFAULT_RUN_TIMEOUT (120s).
                        Set to 0 or negative to disable timeout.

    Returns:
        A tuple: (metrics, overall_correct_flag, first_error_message)
    """
    fallback_source = (
        default_metrics_on_error
        if default_metrics_on_error
        else DEFAULT_METRICS_ON_ERROR
    )

    # Determine effective timeout
    effective_timeout: Optional[float]
    if timeout_seconds is None:
        effective_timeout = DEFAULT_RUN_TIMEOUT
    elif timeout_seconds <= 0:
        effective_timeout = None  # Disable timeout
    else:
        effective_timeout = timeout_seconds

    overall_correct_flag = True
    first_error_message: Optional[str] = None
    all_validation_errors_list: List[str] = []
    num_valid_runs = 0
    num_invalid_runs = 0
    all_run_results: List[Any] = []
    execution_times: List[float] = []

    try:
        module = load_program(program_path)
        if not hasattr(module, experiment_fn_name):
            raise AttributeError(
                f"Experiment function '{experiment_fn_name}' not found in "
                f"{program_path}"
            )
        experiment_fn = getattr(module, experiment_fn_name)

        for i in range(num_runs):
            kwargs: Dict[str, Any]
            if get_experiment_kwargs:
                kwargs = get_experiment_kwargs(i)
            else:
                kwargs = {"seed": i + 1}

            start_time = time.perf_counter()

            # Execute with timeout if enabled
            try:
                if effective_timeout is not None:
                    print(f"Running with timeout: {effective_timeout}s")
                    run_result = run_with_timeout(
                        experiment_fn, kwargs, effective_timeout
                    )
                else:
                    run_result = experiment_fn(**kwargs)
                end_time = time.perf_counter()
            except TimeoutError as e:
                end_time = time.perf_counter()
                error_msg = f"Execution timeout after {effective_timeout}s: {str(e)}"
                print(f"⏱️ TIMEOUT: Run {i + 1}/{num_runs} - {error_msg}")

                # Treat timeout as validation failure
                num_invalid_runs += 1
                overall_correct_flag = False
                if not first_error_message:
                    first_error_message = error_msg
                all_validation_errors_list.append(error_msg)

                # Record execution time (up to timeout)
                execution_times.append(end_time - start_time)

                # Skip this run, continue to next
                continue

            all_run_results.append(run_result)
            execution_times.append(end_time - start_time)

            if validate_fn:
                is_valid, validation_err_msg = validate_fn(run_result)
                if not is_valid:
                    num_invalid_runs += 1
                    overall_correct_flag = False
                    if validation_err_msg:
                        if not first_error_message:
                            first_error_message = (
                                f"Validation failed: {validation_err_msg}"
                            )
                        # Each distinct validation message is listed once.
                        if validation_err_msg not in all_validation_errors_list:
                            all_validation_errors_list.append(validation_err_msg)
                else:
                    num_valid_runs += 1
            print(
                f"Run {i + 1}/{num_runs} completed in "
                f"{end_time - start_time:.2f} seconds"
            )

        metrics: Dict[str, Any]
        if aggregate_metrics_fn:
            metrics = aggregate_metrics_fn(all_run_results)
        else:
            metrics = {"num_successful_runs": len(all_run_results)}
            if all_run_results:
                metrics["first_run_result_type"] = str(type(all_run_results[0]))
                metrics["raw_results_preview"] = str(all_run_results[:2])
            else:
                metrics["first_run_result_type"] = "N/A"
                metrics["raw_results_preview"] = "N/A"

        metrics["execution_time_mean"] = (
            float(np.mean(execution_times)) if execution_times else 0.0
        )
        metrics["execution_time_std"] = (
            float(np.std(execution_times)) if execution_times else 0.0
        )
        if validate_fn:
            metrics["num_valid_runs"] = num_valid_runs
            metrics["num_invalid_runs"] = num_invalid_runs
            metrics["all_validation_errors"] = all_validation_errors_list

    except Exception as e:
        print(f"Evaluation error: {e}")
        # Deep-copy so the returned dict never shares mutable values (such as
        # the "all_validation_errors" list) with the module-level defaults or
        # with the caller-supplied dict.
        metrics = {
            key: copy.deepcopy(fallback_source.get(key, fallback_value))
            for key, fallback_value in DEFAULT_METRICS_ON_ERROR.items()
        }
        first_error_message = str(e)
        overall_correct_flag = False

    if "extra_data" in metrics:
        # "extra_data" is stored as a pickle next to the JSON files and
        # removed from the metrics before they are serialised.
        os.makedirs(results_dir, exist_ok=True)
        extra_data = metrics.pop("extra_data")
        extra_file = os.path.join(results_dir, "extra.pkl")
        with open(extra_file, "wb") as f:
            pickle.dump(extra_data, f)
        print(f"Extra data saved to {extra_file}")

    save_json_results(results_dir, metrics, overall_correct_flag, first_error_message)
    return metrics, overall_correct_flag, first_error_message