# Hosting-page header captured with the file (not part of the module):
# pandelis -- "Add Zerolang editing environment" -- commit bb1b296 (verified)
"""Synthetic training rows for the Zerolang editing environment."""
from __future__ import annotations
from typing import Any
from .task_builders import (
_branch_literal_task,
_call_task,
_condition_task,
_diagnostic_task,
_helper_task,
_literal_task,
_two_helper_task,
)
# Hand-written train-split tasks, one per builder family: helper, call,
# condition, and diagnostic. Arguments are positional and passed through
# unchanged to the builders imported from .task_builders.
LEGACY_TRAIN_TASKS: list[dict[str, Any]] = [
    # Helper "answer": expression goes from "40 + 1" to "40 + 2" (value 42).
    _helper_task(
        "helper-return-update", "answer", "40 + 1", "40 + 2", 42, "math works", split="train"
    ),
    # Call arguments go from "2, 2" to "2, 3" (value 5).
    _call_task(
        "callee-argument-update",
        "2, 2",
        "2, 3",
        5,
        split="train",
    ),
    # Helper "score" with the integer pair 7 / 8 and the text "ready".
    _condition_task(
        "comparison-target-update",
        "score",
        7,
        8,
        "ready",
        split="train",
    ),
    # Diagnostic task carrying the message "needs raises".
    _diagnostic_task(
        "fallible-main-repair",
        "needs raises",
        split="train",
    ),
]
def _literal_train_tasks() -> list[dict[str, Any]]:
    """Build one literal task for each (old, new) string pair in the table.

    Task ids have the form ``train-literal-NNN``, numbered from 001 in table
    order, and every task is created with ``split="train"``.

    Returns:
        The tasks produced by ``_literal_task``, in table order.
    """
    transitions = (
        ("queue pending", "queue ready"),
        ("job queued", "job running"),
        ("job running", "job complete"),
        ("build red", "build green"),
        ("node cold", "node warm"),
        ("cache miss", "cache hit"),
        ("retry later", "retry now"),
        ("draft note", "final note"),
        ("plan open", "plan closed"),
        ("graph stale", "graph fresh"),
        ("route /v1/run", "route /v2/run"),
        ("status [100]", "status [200]"),
        ("phase: alpha", "phase: beta"),
        ("phase: beta", "phase: gamma"),
        ("step 1/4", "step 2/4"),
        ("step 2/4", "step 3/4"),
        ("score 10/20", "score 18/20"),
        ("level: low", "level: high"),
        ("mode manual", "mode auto"),
        ("window closed", "window open"),
        ("target west", "target east"),
        ("port 3000", "port 8080"),
        ("run id a1", "run id b2"),
        ("batch small", "batch large"),
        ("token old", "token new"),
        ("edge loose", "edge locked"),
        ("module local", "module remote"),
        ("worker idle", "worker busy"),
        ("agent paused", "agent active"),
        ("output empty", "output full"),
        ("index 0", "index 1"),
        ("flag off", "flag on"),
        ("signal weak", "signal strong"),
        ("health warn", "health pass"),
        ("check skipped", "check passed"),
        ("ticket open", "ticket merged"),
        ("snapshot old", "snapshot new"),
        ("profile dev", "profile prod"),
        ("version 0.1", "version 0.2"),
        ("result unknown", "result known"),
    )

    rows: list[dict[str, Any]] = []
    for number, (before, after) in enumerate(transitions, start=1):
        task_id = f"train-literal-{number:03d}"
        rows.append(_literal_task(task_id, before, after, split="train"))
    return rows
def _branch_literal_train_tasks() -> list[dict[str, Any]]:
    """Build one branch-literal task for each (helper, old, new) triple.

    Task ids have the form ``train-branch-literal-NNN``, numbered from 001 in
    table order, and every task is created with ``split="train"``.

    Returns:
        The tasks produced by ``_branch_literal_task``, in table order.
    """
    gate_table = (
        ("ready_gate", "gate draft", "gate ready"),
        ("emit_gate", "emit old", "emit new"),
        ("mode_gate", "mode test", "mode live"),
        ("route_gate", "route blue", "route green"),
        ("status_gate", "status low", "status high"),
        ("phase_gate", "phase one", "phase two"),
        ("counter_gate", "count fail", "count pass"),
        ("worker_gate", "worker wait", "worker run"),
        ("deploy_gate", "deploy hold", "deploy ship"),
        ("review_gate", "review open", "review done"),
        ("graph_gate", "graph dirty", "graph clean"),
        ("patch_gate", "patch text", "patch graph"),
        ("score_gate", "score bad", "score good"),
        ("plan_gate", "plan rough", "plan exact"),
        ("test_gate", "test flaky", "test stable"),
        ("queue_gate", "queue blocked", "queue clear"),
        ("cache_gate", "cache cold", "cache hot"),
        ("trace_gate", "trace off", "trace on"),
        ("run_gate", "run dry", "run real"),
        ("sync_gate", "sync stale", "sync current"),
    )

    rows: list[dict[str, Any]] = []
    for number, (gate_name, before, after) in enumerate(gate_table, start=1):
        task_id = f"train-branch-literal-{number:03d}"
        rows.append(_branch_literal_task(task_id, gate_name, before, after, split="train"))
    return rows
def _helper_train_tasks() -> list[dict[str, Any]]:
    """Build 25 addition and 25 subtraction helper tasks for the train split.

    For each number 1..25 an addition task (``train-helper-add-NNN``) is
    built, followed by 25 subtraction tasks (``train-helper-sub-NNN``). Each
    task passes a source expression, a target expression whose right operand
    differs from the source by one, and the integer value of the target
    expression. Helper names and output labels cycle through fixed lists.

    Returns:
        All addition tasks first, then all subtraction tasks.
    """
    names = ["answer", "score", "total", "count", "value", "limit", "level", "points"]
    labels = ["ok", "ready", "matched", "accepted", "passed", "open", "done", "green"]

    def addition_row(n: int) -> dict[str, Any]:
        # Source right operand is one less than the target's.
        base = 10 + n
        target_operand = 3 + (n % 9)
        source_operand = target_operand - 1
        return _helper_task(
            f"train-helper-add-{n:03d}",
            names[n % len(names)],
            f"{base} + {source_operand}",
            f"{base} + {target_operand}",
            base + target_operand,
            labels[n % len(labels)],
            split="train",
        )

    def subtraction_row(n: int) -> dict[str, Any]:
        # Source right operand is one more than the target's; the name and
        # label lists are read at shifted offsets (+3 and +2).
        base = 60 + n
        target_operand = 5 + (n % 11)
        source_operand = target_operand + 1
        return _helper_task(
            f"train-helper-sub-{n:03d}",
            names[(n + 3) % len(names)],
            f"{base} - {source_operand}",
            f"{base} - {target_operand}",
            base - target_operand,
            labels[(n + 2) % len(labels)],
            split="train",
        )

    additions = [addition_row(n) for n in range(1, 26)]
    subtractions = [subtraction_row(n) for n in range(1, 26)]
    return additions + subtractions
def _two_helper_train_tasks() -> list[dict[str, Any]]:
    """Build 20 two-helper addition tasks for the train split.

    Each task (``train-two-helper-NNN``, numbers 1..20) names a primary and a
    secondary helper, a source and a target addition for the primary whose
    right operands differ by one, a separate addition expression for the
    secondary helper, and the integer value of the target expression.

    Returns:
        The tasks produced by ``_two_helper_task``, in numeric order.
    """
    main_names = ["score", "total", "count", "value", "answer", "level", "points", "result"]
    side_names = ["spare", "backup", "idle", "other", "side", "helper", "extra", "unused"]

    def build_row(n: int) -> dict[str, Any]:
        base = 20 + n
        target_operand = 2 + (n % 7)
        source_operand = target_operand - 1
        # Operands of the second helper's expression vary independently.
        side_left = 4 + n % 6
        side_right = 8 + n % 5
        return _two_helper_task(
            f"train-two-helper-{n:03d}",
            main_names[n % len(main_names)],
            side_names[n % len(side_names)],
            f"{base} + {source_operand}",
            f"{base} + {target_operand}",
            f"{side_left} + {side_right}",
            base + target_operand,
            split="train",
        )

    return [build_row(n) for n in range(1, 21)]
def _call_train_tasks() -> list[dict[str, Any]]:
    """Build 30 call-argument tasks for the train split.

    Each task (``train-call-update-NNN``, numbers 1..30) passes a source and
    a target argument string of the form ``"left, right"`` whose second
    values differ by one, plus the sum of the target pair.

    Returns:
        The tasks produced by ``_call_task``, in numeric order.
    """

    def build_row(n: int) -> dict[str, Any]:
        first_arg = 1 + (n % 17)
        target_second = 2 + (n % 13)
        source_second = target_second - 1
        return _call_task(
            f"train-call-update-{n:03d}",
            f"{first_arg}, {source_second}",
            f"{first_arg}, {target_second}",
            first_arg + target_second,
            split="train",
        )

    return [build_row(n) for n in range(1, 31)]
def _condition_train_tasks() -> list[dict[str, Any]]:
    """Build 25 condition tasks for the train split.

    Each task (``train-condition-update-NNN``, numbers 1..25) passes a helper
    name cycled from a fixed list, an integer ``5 + 3 * number``, that integer
    plus one, and the text ``"matched"``.

    Returns:
        The tasks produced by ``_condition_task``, in numeric order.
    """
    names = ["score", "count", "level", "token", "value", "flag", "marker", "limit"]

    def build_row(n: int) -> dict[str, Any]:
        helper_value = 5 + (n * 3)
        # The comparison value passed is one more than the helper value.
        compared_to = helper_value + 1
        return _condition_task(
            f"train-condition-update-{n:03d}",
            names[n % len(names)],
            helper_value,
            compared_to,
            "matched",
            split="train",
        )

    return [build_row(n) for n in range(1, 26)]
def _diagnostic_train_tasks() -> list[dict[str, Any]]:
    """Build one diagnostic task for each message string in the table.

    Task ids have the form ``train-diagnostic-NNN``, numbered from 001 in
    table order, and every task is created with ``split="train"``.

    Returns:
        The tasks produced by ``_diagnostic_task``, in table order.
    """
    texts = (
        "train starting",
        "train ready",
        "diagnostic pass",
        "writer needs raises",
        "output accepted",
        "payload saved",
        "attempt complete",
        "retry complete",
        "batch emitted",
        "sample logged",
        "graph checked",
        "patch validated",
        "route verified",
        "state stored",
        "run complete",
        "score written",
        "marker emitted",
        "world write",
        "tool output",
        "final line",
    )

    rows: list[dict[str, Any]] = []
    for number, text in enumerate(texts, start=1):
        task_id = f"train-diagnostic-{number:03d}"
        rows.append(_diagnostic_task(task_id, text, split="train"))
    return rows
# Complete train-split task list. The concatenation order fixes the row
# order: literal, branch-literal, legacy, helper, two-helper, call,
# condition, diagnostic.
TRAIN_TASKS: list[dict[str, Any]] = (
    _literal_train_tasks()
    + _branch_literal_train_tasks()
    + LEGACY_TRAIN_TASKS
    + _helper_train_tasks()
    + _two_helper_train_tasks()
    + _call_train_tasks()
    + _condition_train_tasks()
    + _diagnostic_train_tasks()
)