"""Multi-turn repo-editing tasks.
Each Task specifies:
- A target repo to work on (points to a sample_repos/ subdir)
- A natural-language description of the change to make
- A set of test functions (Python code strings) that verify the change
- The maximum number of turns allowed
Training tasks are deliberately structured to require multi-step navigation:
1. The agent must QUERY the graph to find relevant nodes
2. INSPECT nodes to understand the existing code
3. ADD or UPDATE nodes to implement the change
4. SUBMIT to trigger compilation + test execution
This sparse reward structure forces the agent to develop structured planning
and state tracking across long trajectories — the core theme of this project.
"""
from __future__ import annotations
import importlib.util
import sys
import textwrap
import traceback
from dataclasses import dataclass, field
from pathlib import Path
SAMPLE_REPOS_DIR = Path(__file__).resolve().parent.parent / "graphforge" / "sample_repos"
@dataclass
class RepoTask:
    """A single repo-editing task together with the checks that verify it.

    Attributes:
        task_id: Identifier of the task.
        repo_name: Package name (used as tempdir subdir).
        description: Natural-language task for the agent.
        test_code: Python assertions using short imports.
        max_turns: Maximum number of turns allowed.
        difficulty: 0 = easy, 1 = medium, 2 = hard.
        hints: Hint strings attached to the task.
        repo_path: If set, full path to the repo source dir.
    """

    task_id: str
    repo_name: str
    description: str
    test_code: str
    max_turns: int = 15
    difficulty: int = 0
    hints: list[str] = field(default_factory=list)
    repo_path: str | None = None
# Module-wide registry of tasks, keyed by each task's ``task_id``.
TASK_BANK: dict[str, RepoTask] = {}


def _reg(task: RepoTask) -> RepoTask:
    """Store *task* in ``TASK_BANK`` under its ``task_id`` and hand it back.

    Registering a second task with the same id replaces the earlier entry.
    """
    TASK_BANK.update({task.task_id: task})
    return task
# ── Task 0: add validate_due_date ────────────────────────────────────────────
# t0: add a standalone validator function (difficulty left at its default).
_reg(
    RepoTask(
        task_id="t0.validate_due_date",
        repo_name="task_manager",
        max_turns=12,
        hints=[
            "Look in validators.py to see the style of existing validators.",
            "The function signature should be: def validate_due_date(due_date) -> bool",
            "Import datetime.date inside the function or at the top of validators.py.",
        ],
        # What the agent is asked to implement.
        description=textwrap.dedent(
            """\
            Add a function `validate_due_date(due_date) -> bool` to `validators.py`.
            The function should return True if:
            - due_date is None (no deadline), OR
            - due_date is a datetime.date instance
            It should return False for any other type (strings, integers, etc.).
            """
        ).strip(),
        # Assertions executed against the edited package on SUBMIT.
        test_code=textwrap.dedent(
            """\
            from datetime import date
            from task_manager.validators import validate_due_date
            assert validate_due_date(None) is True, "None is valid (no deadline)"
            assert validate_due_date(date(2025, 1, 1)) is True, "date object is valid"
            assert validate_due_date("2025-01-01") is False, "string is not valid"
            assert validate_due_date(20250101) is False, "int is not valid"
            assert validate_due_date([]) is False, "list is not valid"
            """
        ).strip(),
    )
)
# ── Task 1: add Task.is_overdue ──────────────────────────────────────────────
# t1: add a method to an existing class. The assertions cover every branch
# listed in the description, including the "due today" boundary.
_reg(RepoTask(
    task_id="t1.is_overdue",
    repo_name="task_manager",
    description=textwrap.dedent("""\
        Add a method `is_overdue(self, today: date) -> bool` to the `Task`
        class in `models.py`.
        The method should return True if:
        - the task has a due_date AND
        - today is strictly after the due_date AND
        - the task is not yet done
        It should return False if there is no due_date, or if the task is done,
        or if today <= due_date.
        """).strip(),
    test_code=textwrap.dedent("""\
        from datetime import date
        from task_manager.models import Task
        t_past = Task("x", "low", [], due_date=date(2020, 1, 1))
        t_future = Task("y", "low", [], due_date=date(2099, 1, 1))
        t_none = Task("z", "low", [], due_date=None)
        t_done = Task("d", "low", [], due_date=date(2020, 1, 1))
        t_done.complete()
        today = date.today()
        t_today = Task("e", "low", [], due_date=today)
        assert t_past.is_overdue(today) is True, "past due date → overdue"
        assert t_future.is_overdue(today) is False, "future due date → not overdue"
        assert t_none.is_overdue(today) is False, "no due date → not overdue"
        assert t_done.is_overdue(today) is False, "done task → not overdue"
        assert t_today.is_overdue(today) is False, "due today → not overdue"
        """).strip(),
    max_turns=15,
    difficulty=1,
    hints=[
        "The Task class is in models.py.",
        "The method should check self.due_date, today, and self.done.",
    ],
))
# ── Task 2: add TaskStore.find_by_tag ────────────────────────────────────────
# t2: add a lookup method to the task store (medium difficulty, no hints).
_reg(
    RepoTask(
        task_id="t2.find_by_tag",
        repo_name="task_manager",
        difficulty=1,
        max_turns=15,
        # What the agent is asked to implement.
        description=textwrap.dedent(
            """\
            Add a method `find_by_tag(self, tag: str) -> list[Task]` to the
            `TaskStore` class in `storage.py`.
            The method should return a list of all tasks that have `tag` in
            their `tags` list. Return an empty list if no tasks match.
            """
        ).strip(),
        # Assertions executed against the edited package on SUBMIT.
        test_code=textwrap.dedent(
            """\
            from task_manager.models import Task
            from task_manager.storage import TaskStore
            store = TaskStore()
            store.add(Task("t1", "high", ["python", "backend"], None))
            store.add(Task("t2", "low", ["frontend"], None))
            store.add(Task("t3", "medium", ["python"], None))
            result = store.find_by_tag("python")
            assert len(result) == 2, f"Expected 2, got {len(result)}"
            titles = {t.title for t in result}
            assert titles == {"t1", "t3"}, f"Wrong titles: {titles}"
            empty = store.find_by_tag("devops")
            assert empty == [], f"Expected [], got {empty}"
            """
        ).strip(),
    )
)
# ── Task 3 (hard): enforce priority validation in api.create_task ────────────
# t3: modify an existing function so that it calls an existing validator.
# The try/except bodies inside test_code must stay indented relative to their
# headers, otherwise compile() rejects the test code before any assertion runs.
_reg(RepoTask(
    task_id="t3.enforce_priority",
    repo_name="task_manager",
    description=textwrap.dedent("""\
        Update the `create_task` function in `api.py` so that it validates
        the `priority` argument using `validate_priority` from `validators.py`.
        If the priority is invalid, raise `ValueError` with a clear message.
        The existing validations for title and tags must still work.
        Note: `validate_priority` already exists in validators.py.
        You must import and call it inside `create_task`.
        """).strip(),
    test_code=textwrap.dedent("""\
        from task_manager import api as _api
        _api.reset_store() # clean state between runs
        # valid priority passes through
        t = _api.create_task("Buy milk", priority="high")
        assert t.priority == "high"
        # invalid priority raises ValueError
        raised = False
        try:
            _api.create_task("Bad task", priority="urgent")
        except ValueError:
            raised = True
        assert raised, "create_task should raise ValueError for invalid priority"
        # title validation still works
        raised2 = False
        try:
            _api.create_task("", priority="low")
        except ValueError:
            raised2 = True
        assert raised2, "create_task should still reject empty title"
        """).strip(),
    max_turns=18,
    difficulty=2,
    hints=[
        "api.py already imports validate_title and validate_tags from validators.",
        "You need to also import validate_priority and call it in create_task.",
    ],
))
# ── Humanize tasks (real-world library) ──────────────────────────────────────
# t4: add a fraction-to-percentage formatter to humanize's number module.
# The description is shown to the agent verbatim, so the example arrows must
# be real "→" characters rather than mis-decoded bytes.
_reg(RepoTask(
    task_id="t4.intpercent",
    repo_name="humanize",
    description=textwrap.dedent("""\
        Add a function `intpercent(value: float, decimal_places: int = 1) -> str`
        to `number.py`.
        The function should convert a fraction to a percentage string:
        0.0 → "0.0%"
        0.5 → "50.0%"
        0.753 → "75.3%"
        1.0 → "100.0%"
        Use `decimal_places` to control how many digits appear after the decimal.
        If decimal_places=0, return an integer percentage with no decimal point.
        """).strip(),
    test_code=textwrap.dedent("""\
        from humanize.number import intpercent
        assert intpercent(0.0) == "0.0%", f"got {intpercent(0.0)!r}"
        assert intpercent(0.5) == "50.0%", f"got {intpercent(0.5)!r}"
        assert intpercent(0.753) == "75.3%", f"got {intpercent(0.753)!r}"
        assert intpercent(1.0) == "100.0%", f"got {intpercent(1.0)!r}"
        assert intpercent(0.5, decimal_places=0) == "50%", f"got {intpercent(0.5, decimal_places=0)!r}"
        """).strip(),
    max_turns=12,
    difficulty=0,
    hints=[
        "Look at number.py — the existing functions show the style to follow.",
        "Use f-string formatting: f'{value * 100:.{decimal_places}f}%'",
    ],
))
# t5: add a pluralising file-count formatter to humanize's filesize module.
# Arrows and the dash in the hint are written as real Unicode characters so
# the agent-facing text is readable.
_reg(RepoTask(
    task_id="t5.naturalfilecount",
    repo_name="humanize",
    description=textwrap.dedent("""\
        Add a function `naturalfilecount(n: int) -> str` to `filesize.py`.
        The function should return a human-readable file count:
        0 → "no files"
        1 → "1 file"
        2 → "2 files"
        99 → "99 files"
        """).strip(),
    test_code=textwrap.dedent("""\
        from humanize.filesize import naturalfilecount
        assert naturalfilecount(0) == "no files", f"got {naturalfilecount(0)!r}"
        assert naturalfilecount(1) == "1 file", f"got {naturalfilecount(1)!r}"
        assert naturalfilecount(2) == "2 files", f"got {naturalfilecount(2)!r}"
        assert naturalfilecount(99) == "99 files", f"got {naturalfilecount(99)!r}"
        """).strip(),
    max_turns=12,
    difficulty=0,
    hints=[
        "Look at filesize.py — naturalsize is the only function there.",
        "This is a short function: handle n==0, n==1, and n>1 as three cases.",
    ],
))
# t6: add an SI-prefix number formatter to humanize's number module.
# The description carries arrows and superscript exponents; they are written
# as real Unicode characters so the agent sees 10¹² rather than garbled bytes.
_reg(RepoTask(
    task_id="t6.metric",
    repo_name="humanize",
    description=textwrap.dedent("""\
        Add a function `metric(value: float, unit: str = "") -> str` to `number.py`.
        The function should format a number using SI metric prefixes:
        1_500_000 → "1.5 M"
        2_000 → "2.0 k"
        500 → "500" (no prefix below 1000)
        Supported prefixes (largest to smallest): T (10¹²), G (10⁹), M (10⁶), k (10³).
        If a unit is provided, append it after the prefix: metric(1500, "Hz") → "1.5 kHz".
        Always format the scaled number to 1 decimal place.
        """).strip(),
    test_code=textwrap.dedent("""\
        from humanize.number import metric
        assert metric(1_500_000) == "1.5 M", f"got {metric(1_500_000)!r}"
        assert metric(2_000) == "2.0 k", f"got {metric(2_000)!r}"
        assert metric(500) == "500", f"got {metric(500)!r}"
        assert metric(1_500, "Hz") == "1.5 kHz", f"got {metric(1_500, 'Hz')!r}"
        assert metric(2e9, "W") == "2.0 GW", f"got {metric(2e9, 'W')!r}"
        """).strip(),
    max_turns=15,
    difficulty=1,
    hints=[
        "Loop through prefixes from largest to smallest: (1e12,'T'), (1e9,'G'), (1e6,'M'), (1e3,'k').",
        "If abs(value) >= threshold, scale and format; otherwise return str(int(value)).",
    ],
))
# t7: add a human-readable age formatter to humanize's time module.
# The birthdays are computed relative to today. date.replace(year=...) raises
# ValueError on Feb 29 when the target year is not a leap year, so the test
# code falls back to Feb 28 instead of failing for reasons unrelated to the
# agent's change.
_reg(RepoTask(
    task_id="t7.age",
    repo_name="humanize",
    description=textwrap.dedent("""\
        Add a function `age(birth_date) -> str` to `time.py`.
        The function receives a `datetime.date` and returns a human-readable age:
        - If the person is under 1 year old, return "X months old" (use 30-day months).
        - If exactly 1 year, return "1 year old".
        - Otherwise return "X years old".
        Use `datetime.date.today()` as the reference point.
        Assume birth_date is always a valid date in the past.
        """).strip(),
    test_code=textwrap.dedent("""\
        import datetime as dt
        from humanize.time import age
        def _years_ago(day, years):
            try:
                return day.replace(year=day.year - years)
            except ValueError:
                # Feb 29 with a non-leap target year: use Feb 28 instead.
                return day.replace(year=day.year - years, day=28)
        today = dt.date.today()
        dob_25y = _years_ago(today, 25)
        dob_1y = _years_ago(today, 1)
        dob_6m = today - dt.timedelta(days=182)
        dob_2m = today - dt.timedelta(days=61)
        assert age(dob_25y) == "25 years old", f"got {age(dob_25y)!r}"
        assert age(dob_1y) == "1 year old", f"got {age(dob_1y)!r}"
        assert age(dob_6m) == "6 months old", f"got {age(dob_6m)!r}"
        assert age(dob_2m) == "2 months old", f"got {age(dob_2m)!r}"
        """).strip(),
    max_turns=15,
    difficulty=1,
    hints=[
        "import datetime as dt is already at the top of time.py.",
        "days = (dt.date.today() - birth_date).days; years = days // 365; months = days // 30",
    ],
))
# ── test runner ──────────────────────────────────────────────────────────────
def run_tests(task: RepoTask) -> tuple[bool, str]:
"""Execute task.test_code and return (passed, message)."""
# Reload all task_manager modules to pick up any source-level changes
_reload_task_manager()
try:
exec(compile(task.test_code, "<test>", "exec"), {}) # noqa: S102
return True, "All assertions passed."
except AssertionError as exc:
return False, f"AssertionError: {exc}"
except Exception:
return False, traceback.format_exc(limit=5)
def _reload_task_manager() -> None:
"""Force-reload all task_manager submodules so edits take effect."""
prefix = "graphforge.sample_repos.task_manager"
to_reload = [k for k in sys.modules if k.startswith(prefix)]
for mod_name in to_reload:
del sys.modules[mod_name]
def all_task_ids() -> list[str]:
    """Return the ids of all registered tasks as a new list."""
    # Unpacking a dict yields its keys in insertion order.
    return [*TASK_BANK]
def get_task(task_id: str) -> RepoTask | None:
    """Look up a registered task by id, giving ``None`` for an unknown id."""
    try:
        return TASK_BANK[task_id]
    except KeyError:
        return None
|