# Source: shinka-backup / tasks/alphaevolve_ac2/run_experiment.py
# Uploaded by JustinTX: "Add files using upload-large-folder tool" (commit 40607c3, verified)
#!/usr/bin/env python3
"""Task-specific experiment runner for AlphaEvolve AC2."""
import argparse
import sys
from datetime import datetime
from pathlib import Path
import requests
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from shinka.core import EvolutionConfig, EvolutionRunner
from shinka.database import DatabaseConfig
from shinka.launch import LocalJobConfig
from tasks.alphaevolve_ac2.prompt import TASK_SYS_MSG
def parse_args(argv=None):
    """Build the command-line interface for this runner and parse it.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, so the zero-argument call behaves as before.

    Returns:
        An ``argparse.Namespace`` with one attribute per option below.

    Exits via ``parser.error`` (status 2) when ``--patch-types`` and
    ``--patch-probs`` have different lengths or a probability is negative.
    """
    parser = argparse.ArgumentParser(
        description="Run ShinkaEvolve on AlphaEvolve AC2",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--experiment-name", type=str, required=True)
    parser.add_argument("--num-generations", type=int, default=200)
    parser.add_argument("--max-parallel-jobs", type=int, default=5)
    parser.add_argument("--meta-interval", type=int, default=10)
    # --use-text-feedback / --no-text-feedback share one destination so the
    # default (on) can be switched off from the command line.
    parser.add_argument(
        "--use-text-feedback",
        dest="use_text_feedback",
        action="store_true",
        default=True,
        help="Include evaluator text_feedback (including auxiliary metric descriptions) in mutation prompts",
    )
    parser.add_argument(
        "--no-text-feedback",
        dest="use_text_feedback",
        action="store_false",
        help="Disable text_feedback injection into mutation prompts",
    )
    parser.add_argument("--num-islands", type=int, default=2)
    parser.add_argument("--archive-size", type=int, default=40)
    parser.add_argument(
        "--llm-models",
        nargs="+",
        type=str,
        default=["native-gemini-3-flash-preview"],
    )
    parser.add_argument(
        "--llm-selection",
        type=str,
        default="ucb1",
        choices=["ucb1", "thompson", "epsilon_greedy", "random"],
    )
    parser.add_argument(
        "--llm-temperatures",
        nargs="+",
        type=float,
        default=[0.0, 0.5, 1.0],
    )
    parser.add_argument("--llm-max-tokens", type=int, default=65536)
    parser.add_argument(
        "--trajectory-log",
        action="store_true",
        default=False,
        help="Enable per-LLM-call trajectory logging for Shinka mutation loop",
    )
    parser.add_argument(
        "--trajectory-log-dir",
        type=str,
        default="llm_trajectories",
        help="Directory (relative to gen dir or absolute) for trajectory JSON files",
    )
    # --patch-types and --patch-probs are paired positionally (validated below).
    parser.add_argument(
        "--patch-types",
        nargs="+",
        type=str,
        default=["diff", "full", "cross"],
    )
    parser.add_argument(
        "--patch-probs",
        nargs="+",
        type=float,
        default=[0.6, 0.3, 0.1],
    )
    parser.add_argument("--use-eval-service", action="store_true", default=False)
    parser.add_argument("--eval-service-url", type=str, default="http://localhost:8765")
    parser.add_argument(
        "--eval-trigger-mode",
        type=str,
        default=None,
        choices=["always", "periodic", "plateau", "mixed"],
    )
    parser.add_argument("--eval-trigger-interval", type=int, default=None)
    parser.add_argument("--use-wandb", action="store_true", default=False)
    parser.add_argument("--wandb-project", type=str, default="ev2")
    parser.add_argument("--wandb-entity", type=str, default="tengxiao")
    parser.add_argument("--wandb-run-name", type=str, default=None)
    parser.add_argument("--wandb-tags", nargs="*", type=str, default=None)
    parser.add_argument("--results-dir", type=str, default=None)
    parser.add_argument(
        "--initial-code",
        type=str,
        default="tasks/alphaevolve_ac2/initial.py",
    )
    parser.add_argument(
        "--evaluator",
        type=str,
        default="tasks/alphaevolve_ac2/evaluate_ori.py",
    )
    parser.add_argument(
        "--evaluator-module",
        type=str,
        default="tasks.alphaevolve_ac2.evaluate_ori",
    )
    parser.add_argument("--evaluator-function", type=str, default="main")
    # Verbose output is on by default. A lone store_true flag with
    # default=True could never be disabled, so pair it with --no-verbose.
    parser.add_argument(
        "--verbose",
        dest="verbose",
        action="store_true",
        default=True,
        help="Enable verbose runner output",
    )
    parser.add_argument(
        "--no-verbose",
        dest="verbose",
        action="store_false",
        help="Disable verbose runner output",
    )

    args = parser.parse_args(argv)

    # Each patch type is matched to the probability at the same position, so
    # the two lists must line up; reject a mismatch at the CLI boundary.
    if len(args.patch_types) != len(args.patch_probs):
        parser.error(
            f"--patch-types ({len(args.patch_types)} values) and "
            f"--patch-probs ({len(args.patch_probs)} values) must have the same length"
        )
    if any(prob < 0 for prob in args.patch_probs):
        parser.error("--patch-probs values must be non-negative")
    return args
def resolve_defaults(args, now=None):
    """Fill in CLI options whose defaults depend on the current time.

    Args:
        args: Namespace from ``parse_args``. ``results_dir`` is set when it is
            ``None``. ``wandb_run_name`` is set when W&B is enabled and no name
            was given. Both are updated in place.
        now: Timestamp used for the generated names; defaults to
            ``datetime.now()``. One timestamp is shared by every derived name
            so the results directory and the W&B run name carry the same
            suffix. The original code called ``now()`` twice, so the two
            suffixes could differ across a second boundary.

    Returns:
        The same ``args`` object, to allow chaining.
    """
    stamp = (now if now is not None else datetime.now()).strftime("%Y%m%d_%H%M%S")
    if args.results_dir is None:
        args.results_dir = (
            "tasks/alphaevolve_ac2/results/"
            f"results_{args.experiment_name}_{stamp}"
        )
    if args.use_wandb and args.wandb_run_name is None:
        args.wandb_run_name = f"{args.experiment_name}_{stamp}"
    return args
def check_eval_service(url: str, timeout: float = 2.0):
    """Probe the evaluation service's ``/api/v1/status`` endpoint.

    Args:
        url: Base URL of the service, e.g. ``http://localhost:8765``. A
            trailing slash is tolerated.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        ``(True, payload)`` with the decoded JSON body when the endpoint
        answers HTTP 200. Otherwise ``(False, reason)``, where *reason* is a
        human-readable string: the connection error, the non-200 status code,
        or the JSON decoding error.
    """
    status_url = f"{url.rstrip('/')}/api/v1/status"
    try:
        response = requests.get(status_url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection refused, timeout, malformed URL, etc.
        return False, str(exc)
    if response.status_code != 200:
        # Report the actual status instead of an opaque "Unknown error".
        return False, f"HTTP {response.status_code} from {status_url}"
    try:
        return True, response.json()
    except ValueError as exc:
        # requests' JSON decode errors subclass ValueError.
        return False, f"invalid JSON from {status_url}: {exc}"
def _print_banner(args, results_dir):
    """Print the run-summary header shown before evolution starts."""
    rule = "=" * 80
    print(rule)
    print("ShinkaEvolve: AlphaEvolve AC2")
    print(rule)
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Experiment: {args.experiment_name}")
    print(f"Generations: {args.num_generations}")
    print(f"Parallel: {args.max_parallel_jobs}")
    print(f"Models: {', '.join(args.llm_models)}")
    print(f"Results Dir: {results_dir}")
    print(rule)


def _build_db_config(args):
    """Return the DatabaseConfig; only island count and archive size come from the CLI."""
    # The remaining selection/migration settings are fixed for this task.
    return DatabaseConfig(
        num_islands=args.num_islands,
        archive_size=args.archive_size,
        elite_selection_ratio=0.3,
        num_archive_inspirations=4,
        num_top_k_inspirations=2,
        migration_interval=10,
        migration_rate=0.1,
        island_elitism=True,
        parent_selection_strategy="weighted",
        parent_selection_lambda=10.0,
    )


def _build_evo_config(args, results_dir):
    """Return the EvolutionConfig assembled from parsed CLI arguments."""
    use_service = args.use_eval_service
    # The first configured model also serves the meta and novelty LLM roles.
    primary_model = args.llm_models[0]
    return EvolutionConfig(
        task_sys_msg=TASK_SYS_MSG,
        patch_types=args.patch_types,
        patch_type_probs=args.patch_probs,
        num_generations=args.num_generations,
        max_parallel_jobs=args.max_parallel_jobs,
        max_patch_resamples=3,
        max_patch_attempts=3,
        job_type="local",
        language="python",
        llm_models=args.llm_models,
        llm_kwargs=dict(
            temperatures=args.llm_temperatures,
            max_tokens=args.llm_max_tokens,
            reasoning_efforts=["auto", "low", "medium", "high"],
        ),
        llm_dynamic_selection=args.llm_selection,
        llm_dynamic_selection_kwargs=dict(exploration_coef=1.0),
        meta_rec_interval=args.meta_interval,
        meta_llm_models=[primary_model],
        meta_llm_kwargs=dict(temperatures=[0.0], max_tokens=32768),
        novelty_llm_models=[primary_model],
        novelty_llm_kwargs=dict(temperatures=[0.0], max_tokens=32768),
        embedding_model="text-embedding-3-small",
        code_embed_sim_threshold=0.995,
        init_program_path=args.initial_code,
        results_dir=str(results_dir),
        # Eval-service settings are forwarded only when the service is enabled.
        eval_service_url=args.eval_service_url if use_service else None,
        use_eval_service=use_service,
        evaluator_module=args.evaluator_module if use_service else None,
        evaluator_function=args.evaluator_function,
        eval_service_trigger_mode=args.eval_trigger_mode if use_service else None,
        eval_service_trigger_interval=(
            args.eval_trigger_interval if use_service else None
        ),
        enable_wandb=args.use_wandb,
        wandb_project=args.wandb_project,
        wandb_entity=args.wandb_entity,
        wandb_run_name=args.wandb_run_name,
        wandb_tags=args.wandb_tags,
        use_text_feedback=args.use_text_feedback,
        trajectory_log=args.trajectory_log,
        trajectory_log_dir=args.trajectory_log_dir,
    )


def main():
    """Parse the CLI, verify prerequisites, and run ShinkaEvolve on AC2.

    Exits with status 1 if ``--use-eval-service`` is set and the service's
    status endpoint is unreachable.
    """
    args = resolve_defaults(parse_args())
    results_dir = Path(args.results_dir)
    _print_banner(args, results_dir)

    if args.use_eval_service:
        ok, info = check_eval_service(args.eval_service_url)
        if not ok:
            # Abort before creating the results directory, so a failed launch
            # does not leave an empty timestamped folder behind.
            print(
                f"Eval service not available at {args.eval_service_url}: {info}",
                file=sys.stderr,
            )
            sys.exit(1)
        print(f"Eval service ready: {args.eval_service_url}")

    results_dir.mkdir(parents=True, exist_ok=True)

    job_config = LocalJobConfig(eval_program_path=args.evaluator)
    db_config = _build_db_config(args)
    evo_config = _build_evo_config(args, results_dir)
    runner = EvolutionRunner(
        evo_config=evo_config,
        job_config=job_config,
        db_config=db_config,
        verbose=args.verbose,
    )
    runner.run()
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()