Spaces:
Sleeping
Sleeping
File size: 6,004 Bytes
978fed5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 | """
MLE-Bench Workflow
Simple wrapper for running SciDER FullWorkflow on MLE-Bench competition tasks.
MLE-Bench provides:
- instructions.md: Specific task instructions (used as user_query)
- description.md: Overall task background description
This wrapper registers models, reads these files, builds user_query, and invokes FullWorkflow.
"""
import sys
from pathlib import Path
from loguru import logger
# Add the parent directory to sys.path so the scider and bench_workflows packages can be imported
sys.path.insert(0, str(Path(__file__).parent.parent))
from bench_workflows.register_models.gemini import (
register_gemini3_medium_high_models,
register_gemini_low_medium_models,
register_gemini_medium_high_models,
)
from bench_workflows.register_models.gpt import (
register_gpt_low_medium_models,
register_gpt_medium_high_models,
)
from scider.workflows.full_workflow import run_full_workflow
def build_mlebench_user_query(
instructions_path: Path,
description_path: Path,
) -> tuple[str, str]:
"""
Build user query and data description from MLE-Bench task files.
Args:
instructions_path: Path to instructions.md
description_path: Path to description.md
Returns:
Tuple of (user_query, data_desc)
- user_query: Task instructions for the experiment
- data_desc: Task description for data analysis context
"""
# Load instructions
if not instructions_path.exists():
raise FileNotFoundError(f"Instructions file not found: {instructions_path}")
instructions = instructions_path.read_text(encoding="utf-8")
# Load description
if not description_path.exists():
raise FileNotFoundError(f"Description file not found: {description_path}")
description = description_path.read_text(encoding="utf-8")
# Use instructions as user_query, description as data_desc
user_query = instructions
data_desc = description
return user_query, data_desc
if __name__ == "__main__":
    import argparse

    # Command-line name -> function that registers that model configuration.
    # Insertion order is also the order in which --models choices are listed.
    model_registrars = {
        "gpt-low-medium": register_gpt_low_medium_models,
        "gpt-medium-high": register_gpt_medium_high_models,
        "gemini-low-medium": register_gemini_low_medium_models,
        "gemini-medium-high": register_gemini_medium_high_models,
        "gemini3-medium-high": register_gemini3_medium_high_models,
    }

    # Shown verbatim after the option list (RawDescriptionHelpFormatter keeps
    # the line breaks as written).
    usage_examples = """
Examples:
# Basic usage
python -m bench.mlebench_workflow \\
-i competition/instructions.md \\
-d competition/description.md \\
--data competition/data \\
-w workspace
# With custom settings
python -m bench.mlebench_workflow \\
-i competition/instructions.md \\
-d competition/description.md \\
--data competition/data \\
-w workspace \\
--max-revisions 10 \\
--session-name my_experiment
"""

    cli = argparse.ArgumentParser(
        prog="python -m bench.mlebench_workflow",
        description="MLE-Bench Workflow - Run SciDER on MLE-Bench competition tasks",
        epilog=usage_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Required arguments: (flags, help text)
    required_options = (
        (("--instructions", "-i"), "Path to instructions.md (task instructions)"),
        (("--description", "-d"), "Path to description.md (task background)"),
        (("--data",), "Path to the data directory or file"),
        (("--workspace", "-w"), "Workspace directory for the experiment"),
    )
    for flags, help_text in required_options:
        cli.add_argument(*flags, required=True, help=help_text)

    # Optional arguments
    cli.add_argument(
        "--repo-source",
        default=None,
        help="Optional repository source (local path or git URL)",
    )
    cli.add_argument(
        "--max-revisions",
        type=int,
        default=3,
        help="Maximum revision loops (default: 3)",
    )
    cli.add_argument(
        "--data-recursion-limit",
        type=int,
        default=512,
        help="Recursion limit for DataAgent (default: 512)",
    )
    cli.add_argument(
        "--experiment-recursion-limit",
        type=int,
        default=512,
        help="Recursion limit for ExperimentAgent (default: 512)",
    )
    cli.add_argument(
        "--session-name",
        default=None,
        help="Custom session name (otherwise uses timestamp)",
    )
    cli.add_argument(
        "--models",
        choices=list(model_registrars),
        default="gemini-low-medium",
        help="Model configuration to use (default: gemini-low-medium)",
    )
    args = cli.parse_args()

    # Register the selected model configuration; argparse has already
    # restricted args.models to the keys of model_registrars.
    logger.info(f"Registering models: {args.models}")
    register_selected_models = model_registrars[args.models]
    register_selected_models()

    # Turn the two MLE-Bench markdown files into the workflow's text inputs.
    logger.info("Building user query from MLE-Bench task files...")
    instructions_file = Path(args.instructions)
    description_file = Path(args.description)
    user_query, data_desc = build_mlebench_user_query(
        instructions_path=instructions_file,
        description_path=description_file,
    )
    logger.info(f"User query built: {len(user_query)} chars")
    logger.info(f"Data description built: {len(data_desc)} chars")

    # Run FullWorkflow with the parsed settings.
    workflow_settings = {
        "data_path": args.data,
        "workspace_path": args.workspace,
        "user_query": user_query,
        "data_desc": data_desc,
        "repo_source": args.repo_source,
        "max_revisions": args.max_revisions,
        "data_agent_recursion_limit": args.data_recursion_limit,
        "experiment_agent_recursion_limit": args.experiment_recursion_limit,
        "session_name": args.session_name,
    }
    result = run_full_workflow(**workflow_settings)

    # Save the summary, then report the final status on stdout.
    result.save_summary()
    print(f"\nStatus: {result.final_status}")
|