File size: 1,962 Bytes
b05b6f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from types import SimpleNamespace

import pytest

from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
from agent.tools.sandbox_tool import resolve_sandbox_script


class FakeSandbox:
    """Test double for a sandbox that records every ``read`` request.

    Each call to :meth:`read` appends a ``(path, limit)`` tuple to
    ``read_paths`` and returns a successful result. The result's ``output``
    is a fixed two-line script where each line is prefixed with its line
    number and a tab.
    """

    def __init__(self):
        # (path, limit) pairs, in the order read() was called.
        self.read_paths = []

    def read(self, path, *, limit):
        """Record the requested ``path``/``limit`` and return a canned success."""
        self.read_paths.append((path, limit))
        numbered_lines = ["1\tprint('training')", "2\tprint('done')"]
        return SimpleNamespace(
            success=True,
            output="\n".join(numbered_lines),
            error="",
        )


@pytest.mark.asyncio
async def test_resolve_sandbox_script_accepts_bare_python_filename():
    """A bare ``*.py`` filename is read from the sandbox with numbering stripped."""
    fake = FakeSandbox()
    script_name = "train_smollm2.py"

    resolved, err = await resolve_sandbox_script(fake, script_name)

    assert err is None
    assert resolved == "print('training')\nprint('done')"
    # The whole file is requested with a 100_000 read limit.
    assert fake.read_paths == [(script_name, 100_000)]


@pytest.mark.asyncio
async def test_resolve_sandbox_script_accepts_relative_python_path():
    """A relative ``dir/*.py`` path is read from the sandbox with numbering stripped."""
    fake = FakeSandbox()
    script_path = "scripts/train.py"

    resolved, err = await resolve_sandbox_script(fake, script_path)

    assert err is None
    assert resolved == "print('training')\nprint('done')"
    # The path is forwarded unchanged, together with the 100_000 read limit.
    assert fake.read_paths == [(script_path, 100_000)]


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "script",
    (
        "https://example.com/train.py",
        "http://example.com/train.py",
        "train_smollm2.py --epochs 1",
        "print('hello')",
    ),
)
async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
    """URLs, command lines and inline code are not treated as sandbox paths.

    For these inputs the resolver returns ``(None, None)`` and never reads
    from the sandbox.
    """
    fake = FakeSandbox()

    resolved, err = await resolve_sandbox_script(fake, script)

    assert resolved is None and err is None
    assert fake.read_paths == []


def test_hf_jobs_script_description_mentions_bare_python_filenames():
    """The ``script`` parameter description covers bare filenames and sandbox smoke-testing."""
    properties = HF_JOBS_TOOL_SPEC["parameters"]["properties"]
    description = properties["script"]["description"]

    expected_phrases = (
        "bare 'train.py'",
        "smoke-test in a GPU sandbox before submission",
    )
    for phrase in expected_phrases:
        assert phrase in description