Add CLI sandbox runtime and fix HF Jobs script paths (#237)
Browse files* Resolve bare Python job scripts
Co-authored-by: OpenAI Codex <codex@openai.com>
* Update HF Jobs script description
Co-authored-by: OpenAI Codex <codex@openai.com>
* Clarify HF Jobs script path prompt
Co-authored-by: OpenAI Codex <codex@openai.com>
* Add CLI sandbox tool runtime
Co-authored-by: OpenAI Codex <codex@openai.com>
* Document CLI sandbox tool runtime
Co-authored-by: OpenAI Codex <codex@openai.com>
* Wait for initial CLI sandbox preload
Co-authored-by: OpenAI Codex <codex@openai.com>
* Strengthen GPU sandbox preflight guidance
Co-authored-by: OpenAI Codex <codex@openai.com>
* Guard against text-only stops with unfinished plans
Co-authored-by: OpenAI Codex <codex@openai.com>
* Route sandbox deletion logs through tool events
Co-authored-by: OpenAI Codex <codex@openai.com>
---------
Co-authored-by: OpenAI Codex <codex@openai.com>
- README.md +25 -0
- agent/config.py +2 -1
- agent/core/agent_loop.py +94 -1
- agent/core/session.py +2 -0
- agent/main.py +64 -12
- agent/prompts/system_prompt_v3.yaml +10 -1
- agent/tools/jobs_tool.py +5 -1
- agent/tools/plan_tool.py +8 -4
- agent/tools/sandbox_client.py +5 -3
- agent/tools/sandbox_tool.py +36 -16
- agent/utils/terminal_display.py +6 -1
- configs/cli_agent_config.json +1 -0
- tests/unit/test_cli_rendering.py +202 -3
- tests/unit/test_config.py +35 -0
- tests/unit/test_hub_artifacts.py +5 -1
- tests/unit/test_no_tool_continuation_guard.py +147 -0
- tests/unit/test_sandbox_auto_start.py +107 -0
- tests/unit/test_sandbox_private_spaces.py +121 -6
- tests/unit/test_sandbox_script_resolution.py +70 -0
- tests/unit/test_session_manager_persistence.py +1 -1
|
@@ -63,6 +63,7 @@ ml-intern --model anthropic/claude-opus-4-7 "your prompt" # requires ANTHROPIC
|
|
| 63 |
ml-intern --model openai/gpt-5.5 "your prompt" # requires OPENAI_API_KEY
|
| 64 |
ml-intern --model ollama/llama3.1:8b "your prompt"
|
| 65 |
ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
|
|
|
|
| 66 |
ml-intern --max-iterations 100 "your prompt"
|
| 67 |
ml-intern --no-stream "your prompt"
|
| 68 |
```
|
|
@@ -97,6 +98,30 @@ one shared local endpoint, or override a specific provider with its matching
|
|
| 97 |
`VLLM_API_KEY`. Provider-specific variables take precedence over the shared
|
| 98 |
local variables. Base URLs may include or omit `/v1`.
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
## Sharing Traces
|
| 101 |
|
| 102 |
Every session is auto-uploaded to your **own private Hugging Face dataset**
|
|
|
|
| 63 |
ml-intern --model openai/gpt-5.5 "your prompt" # requires OPENAI_API_KEY
|
| 64 |
ml-intern --model ollama/llama3.1:8b "your prompt"
|
| 65 |
ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
|
| 66 |
+
ml-intern --sandbox-tools "your prompt" # use HF Space sandbox tools
|
| 67 |
ml-intern --max-iterations 100 "your prompt"
|
| 68 |
ml-intern --no-stream "your prompt"
|
| 69 |
```
|
|
|
|
| 98 |
`VLLM_API_KEY`. Provider-specific variables take precedence over the shared
|
| 99 |
local variables. Base URLs may include or omit `/v1`.
|
| 100 |
|
| 101 |
+
**CLI tool runtime:**
|
| 102 |
+
|
| 103 |
+
By default, the CLI runs `bash`, `read`, `write`, and `edit` on your local
|
| 104 |
+
filesystem. To use HF Space sandbox tools instead, including `sandbox_create`,
|
| 105 |
+
opt in with `--sandbox-tools`:
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
ml-intern --sandbox-tools "test this training script in a GPU sandbox"
|
| 109 |
+
ml-intern --model llamacpp/ggml-org/gemma-3-1b-it-GGUF --sandbox-tools
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
Sandbox tool runtime requires `HF_TOKEN`, even when the selected model is local,
|
| 113 |
+
because it creates private HF Spaces. You can also make sandbox tools your CLI
|
| 114 |
+
default in `~/.config/ml-intern/cli_agent_config.json`:
|
| 115 |
+
|
| 116 |
+
```json
|
| 117 |
+
{ "tool_runtime": "sandbox" }
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
Use the default local runtime when you want tools to inspect or edit files in
|
| 121 |
+
your checkout. Use sandbox runtime when you want the agent to create or replace
|
| 122 |
+
an HF Space sandbox, test code remotely, or request GPU sandbox hardware before
|
| 123 |
+
launching larger HF Jobs.
|
| 124 |
+
|
| 125 |
## Sharing Traces
|
| 126 |
|
| 127 |
Every session is auto-uploaded to your **own private Hugging Face dataset**
|
|
@@ -2,7 +2,7 @@ import json
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Any, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from fastmcp.mcp_config import (
|
|
@@ -46,6 +46,7 @@ class Config(BaseModel):
|
|
| 46 |
# Permission control parameters
|
| 47 |
confirm_cpu_jobs: bool = True
|
| 48 |
auto_file_upload: bool = False
|
|
|
|
| 49 |
|
| 50 |
# Reasoning effort *preference* — the ceiling the user wants. The probe
|
| 51 |
# on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
|
|
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import Any, Literal, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from fastmcp.mcp_config import (
|
|
|
|
| 46 |
# Permission control parameters
|
| 47 |
confirm_cpu_jobs: bool = True
|
| 48 |
auto_file_upload: bool = False
|
| 49 |
+
tool_runtime: Literal["local", "sandbox"] = "local"
|
| 50 |
|
| 51 |
# Reasoning effort *preference* — the ceiling the user wants. The probe
|
| 52 |
# on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
|
|
@@ -32,7 +32,11 @@ from agent.core.prompt_caching import with_prompt_caching
|
|
| 32 |
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
-
from agent.tools.sandbox_tool import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
logger = logging.getLogger(__name__)
|
| 38 |
|
|
@@ -40,6 +44,43 @@ ToolCall = ChatCompletionMessageToolCall
|
|
| 40 |
|
| 41 |
_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
|
| 42 |
_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def _malformed_tool_name(message: Message) -> str | None:
|
|
@@ -1153,6 +1194,7 @@ class Handlers:
|
|
| 1153 |
final_response = None
|
| 1154 |
errored = False
|
| 1155 |
max_iterations = session.config.max_iterations
|
|
|
|
| 1156 |
|
| 1157 |
while max_iterations == -1 or iteration < max_iterations:
|
| 1158 |
# ── Cancellation check: before LLM call ──
|
|
@@ -1301,6 +1343,51 @@ class Handlers:
|
|
| 1301 |
|
| 1302 |
# If no tool calls, add assistant message and we're done
|
| 1303 |
if not tool_calls:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
logger.debug(
|
| 1305 |
"Agent loop ending: no tool calls. "
|
| 1306 |
"finish_reason=%s, token_count=%d, "
|
|
@@ -1324,6 +1411,8 @@ class Handlers:
|
|
| 1324 |
final_response = content
|
| 1325 |
break
|
| 1326 |
|
|
|
|
|
|
|
| 1327 |
# Validate tool call args (one json.loads per call, once)
|
| 1328 |
# and split into good vs bad
|
| 1329 |
good_tools: list[tuple[ToolCall, str, dict]] = []
|
|
@@ -1940,6 +2029,8 @@ class Handlers:
|
|
| 1940 |
_ = session.save_and_upload_detached(repo_id)
|
| 1941 |
|
| 1942 |
session.is_running = False
|
|
|
|
|
|
|
| 1943 |
await session.send_event(Event(event_type="shutdown"))
|
| 1944 |
return True
|
| 1945 |
|
|
@@ -2023,6 +2114,8 @@ async def submission_loop(
|
|
| 2023 |
)
|
| 2024 |
if session_holder is not None:
|
| 2025 |
session_holder[0] = session
|
|
|
|
|
|
|
| 2026 |
logger.info("Agent loop started")
|
| 2027 |
|
| 2028 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
|
|
| 32 |
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
+
from agent.tools.sandbox_tool import (
|
| 36 |
+
DEFAULT_CPU_SANDBOX_HARDWARE,
|
| 37 |
+
start_cpu_sandbox_preload,
|
| 38 |
+
teardown_session_sandbox,
|
| 39 |
+
)
|
| 40 |
|
| 41 |
logger = logging.getLogger(__name__)
|
| 42 |
|
|
|
|
| 44 |
|
| 45 |
_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
|
| 46 |
_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
|
| 47 |
+
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
|
| 51 |
+
plan = getattr(session, "current_plan", None) or []
|
| 52 |
+
unfinished: list[dict[str, str]] = []
|
| 53 |
+
for item in plan:
|
| 54 |
+
if not isinstance(item, dict):
|
| 55 |
+
continue
|
| 56 |
+
status = item.get("status")
|
| 57 |
+
if status in {"pending", "in_progress"}:
|
| 58 |
+
unfinished.append(item)
|
| 59 |
+
return unfinished
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
|
| 63 |
+
formatted = []
|
| 64 |
+
for item in items[:limit]:
|
| 65 |
+
item_id = item.get("id") or "?"
|
| 66 |
+
content = item.get("content") or "(unnamed task)"
|
| 67 |
+
status = item.get("status") or "unknown"
|
| 68 |
+
formatted.append(f"{item_id}. {content} [{status}]")
|
| 69 |
+
if len(items) > limit:
|
| 70 |
+
formatted.append(f"... and {len(items) - limit} more")
|
| 71 |
+
return "; ".join(formatted)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
|
| 75 |
+
summary = _format_plan_items_for_guard(items)
|
| 76 |
+
return (
|
| 77 |
+
"[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
|
| 78 |
+
"tool calls, but the task is not complete. The current plan still has "
|
| 79 |
+
f"unfinished items: {summary}. Do not return control to the user yet. "
|
| 80 |
+
"Continue from the next unfinished item and make at least one tool call "
|
| 81 |
+
"now. If you genuinely cannot continue, first use tools to inspect the "
|
| 82 |
+
"state or verify the blocker."
|
| 83 |
+
)
|
| 84 |
|
| 85 |
|
| 86 |
def _malformed_tool_name(message: Message) -> str | None:
|
|
|
|
| 1194 |
final_response = None
|
| 1195 |
errored = False
|
| 1196 |
max_iterations = session.config.max_iterations
|
| 1197 |
+
no_tool_incomplete_plan_retries = 0
|
| 1198 |
|
| 1199 |
while max_iterations == -1 or iteration < max_iterations:
|
| 1200 |
# ── Cancellation check: before LLM call ──
|
|
|
|
| 1343 |
|
| 1344 |
# If no tool calls, add assistant message and we're done
|
| 1345 |
if not tool_calls:
|
| 1346 |
+
unfinished_plan = _unfinished_plan_items(session)
|
| 1347 |
+
if (
|
| 1348 |
+
unfinished_plan
|
| 1349 |
+
and no_tool_incomplete_plan_retries
|
| 1350 |
+
< _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
|
| 1351 |
+
):
|
| 1352 |
+
logger.info(
|
| 1353 |
+
"No tool calls with unfinished plan; retrying agent turn "
|
| 1354 |
+
"(attempt %d/%d)",
|
| 1355 |
+
no_tool_incomplete_plan_retries + 1,
|
| 1356 |
+
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
|
| 1357 |
+
)
|
| 1358 |
+
if content:
|
| 1359 |
+
assistant_msg = _assistant_message_from_result(
|
| 1360 |
+
llm_result,
|
| 1361 |
+
model_name=llm_params.get("model"),
|
| 1362 |
+
)
|
| 1363 |
+
session.context_manager.add_message(
|
| 1364 |
+
assistant_msg, token_count
|
| 1365 |
+
)
|
| 1366 |
+
session.context_manager.add_message(
|
| 1367 |
+
Message(
|
| 1368 |
+
role="user",
|
| 1369 |
+
content=_no_tool_incomplete_plan_prompt(
|
| 1370 |
+
unfinished_plan
|
| 1371 |
+
),
|
| 1372 |
+
)
|
| 1373 |
+
)
|
| 1374 |
+
no_tool_incomplete_plan_retries += 1
|
| 1375 |
+
await session.send_event(
|
| 1376 |
+
Event(
|
| 1377 |
+
event_type="tool_log",
|
| 1378 |
+
data={
|
| 1379 |
+
"tool": "system",
|
| 1380 |
+
"log": (
|
| 1381 |
+
"Plan still has unfinished items after a "
|
| 1382 |
+
"text-only response — retrying instead of "
|
| 1383 |
+
"returning to the prompt."
|
| 1384 |
+
),
|
| 1385 |
+
},
|
| 1386 |
+
)
|
| 1387 |
+
)
|
| 1388 |
+
iteration += 1
|
| 1389 |
+
continue
|
| 1390 |
+
|
| 1391 |
logger.debug(
|
| 1392 |
"Agent loop ending: no tool calls. "
|
| 1393 |
"finish_reason=%s, token_count=%d, "
|
|
|
|
| 1411 |
final_response = content
|
| 1412 |
break
|
| 1413 |
|
| 1414 |
+
no_tool_incomplete_plan_retries = 0
|
| 1415 |
+
|
| 1416 |
# Validate tool call args (one json.loads per call, once)
|
| 1417 |
# and split into good vs bad
|
| 1418 |
good_tools: list[tuple[ToolCall, str, dict]] = []
|
|
|
|
| 2029 |
_ = session.save_and_upload_detached(repo_id)
|
| 2030 |
|
| 2031 |
session.is_running = False
|
| 2032 |
+
if not getattr(session, "local_mode", False):
|
| 2033 |
+
await teardown_session_sandbox(session)
|
| 2034 |
await session.send_event(Event(event_type="shutdown"))
|
| 2035 |
return True
|
| 2036 |
|
|
|
|
| 2114 |
)
|
| 2115 |
if session_holder is not None:
|
| 2116 |
session_holder[0] = session
|
| 2117 |
+
if not local_mode:
|
| 2118 |
+
start_cpu_sandbox_preload(session)
|
| 2119 |
logger.info("Agent loop started")
|
| 2120 |
|
| 2121 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
@@ -99,6 +99,7 @@ class Session:
|
|
| 99 |
self.hf_token: Optional[str] = hf_token
|
| 100 |
self.user_id: Optional[str] = user_id
|
| 101 |
self.hf_username: Optional[str] = hf_username
|
|
|
|
| 102 |
self.persistence_store = persistence_store
|
| 103 |
self.tool_router = tool_router
|
| 104 |
self.stream = stream
|
|
@@ -117,6 +118,7 @@ class Session:
|
|
| 117 |
self.session_id = session_id or str(uuid.uuid4())
|
| 118 |
self.config = config
|
| 119 |
self.is_running = True
|
|
|
|
| 120 |
self._cancelled = asyncio.Event()
|
| 121 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 122 |
self.sandbox = None
|
|
|
|
| 99 |
self.hf_token: Optional[str] = hf_token
|
| 100 |
self.user_id: Optional[str] = user_id
|
| 101 |
self.hf_username: Optional[str] = hf_username
|
| 102 |
+
self.local_mode = local_mode
|
| 103 |
self.persistence_store = persistence_store
|
| 104 |
self.tool_router = tool_router
|
| 105 |
self.stream = stream
|
|
|
|
| 118 |
self.session_id = session_id or str(uuid.uuid4())
|
| 119 |
self.config = config
|
| 120 |
self.is_running = True
|
| 121 |
+
self.current_plan: list[dict[str, str]] = []
|
| 122 |
self._cancelled = asyncio.Event()
|
| 123 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 124 |
self.sandbox = None
|
|
@@ -59,6 +59,34 @@ CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.j
|
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
| 63 |
if tool_info.get("tool") != "hf_jobs":
|
| 64 |
return False
|
|
@@ -957,6 +985,7 @@ async def _handle_slash_command(
|
|
| 957 |
session = session_holder[0] if session_holder else None
|
| 958 |
print(f"Model: {config.model_name}")
|
| 959 |
print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
|
|
|
|
| 960 |
if session:
|
| 961 |
print(f"Turns: {session.turn_count}")
|
| 962 |
print(f"Context items: {len(session.context_manager.items)}")
|
|
@@ -1076,7 +1105,7 @@ async def _handle_share_traces_command(arg: str, config, session) -> None:
|
|
| 1076 |
console.print(f"[green]Dataset is now {label}.[/green] {url}")
|
| 1077 |
|
| 1078 |
|
| 1079 |
-
async def main(model: str | None = None):
|
| 1080 |
"""Interactive chat with the agent"""
|
| 1081 |
|
| 1082 |
# Clear screen
|
|
@@ -1088,16 +1117,23 @@ async def main(model: str | None = None):
|
|
| 1088 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1089 |
if model:
|
| 1090 |
config.model_name = model
|
|
|
|
|
|
|
| 1091 |
|
| 1092 |
-
# HF token — required for Hub-backed models/tools
|
|
|
|
| 1093 |
hf_token = resolve_hf_token()
|
| 1094 |
-
if not hf_token and not is_local_model_id(config.model_name):
|
| 1095 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 1096 |
|
| 1097 |
# Resolve username for banner
|
| 1098 |
hf_user = _get_hf_user(hf_token)
|
| 1099 |
|
| 1100 |
-
print_banner(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1101 |
|
| 1102 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 1103 |
# don't block on a network fetch.
|
|
@@ -1116,8 +1152,10 @@ async def main(model: str | None = None):
|
|
| 1116 |
|
| 1117 |
notification_gateway = NotificationGateway(config.messaging)
|
| 1118 |
await notification_gateway.start()
|
| 1119 |
-
# Create tool router with
|
| 1120 |
-
tool_router = ToolRouter(
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
# Session holder for interrupt/model/status access
|
| 1123 |
session_holder = [None]
|
|
@@ -1131,7 +1169,7 @@ async def main(model: str | None = None):
|
|
| 1131 |
session_holder=session_holder,
|
| 1132 |
hf_token=hf_token,
|
| 1133 |
user_id=hf_user,
|
| 1134 |
-
local_mode=
|
| 1135 |
stream=True,
|
| 1136 |
notification_gateway=notification_gateway,
|
| 1137 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
@@ -1153,6 +1191,8 @@ async def main(model: str | None = None):
|
|
| 1153 |
)
|
| 1154 |
|
| 1155 |
await ready_event.wait()
|
|
|
|
|
|
|
| 1156 |
|
| 1157 |
submission_id = [0]
|
| 1158 |
# Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
|
|
@@ -1310,6 +1350,7 @@ async def headless_main(
|
|
| 1310 |
model: str | None = None,
|
| 1311 |
max_iterations: int | None = None,
|
| 1312 |
stream: bool = True,
|
|
|
|
| 1313 |
) -> None:
|
| 1314 |
"""Run a single prompt headlessly and exit."""
|
| 1315 |
import logging
|
|
@@ -1322,11 +1363,13 @@ async def headless_main(
|
|
| 1322 |
|
| 1323 |
if model:
|
| 1324 |
config.model_name = model
|
|
|
|
|
|
|
| 1325 |
|
| 1326 |
hf_token = resolve_hf_token()
|
| 1327 |
-
if not hf_token and not is_local_model_id(config.model_name):
|
| 1328 |
print(
|
| 1329 |
-
"ERROR: No HF token found. Set HF_TOKEN or run `
|
| 1330 |
file=sys.stderr,
|
| 1331 |
)
|
| 1332 |
sys.exit(1)
|
|
@@ -1342,6 +1385,7 @@ async def headless_main(
|
|
| 1342 |
config.max_iterations = max_iterations
|
| 1343 |
|
| 1344 |
print(f"Model: {config.model_name}", file=sys.stderr)
|
|
|
|
| 1345 |
print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
|
| 1346 |
print(f"Prompt: {prompt}", file=sys.stderr)
|
| 1347 |
print("---", file=sys.stderr)
|
|
@@ -1349,7 +1393,9 @@ async def headless_main(
|
|
| 1349 |
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 1350 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 1351 |
|
| 1352 |
-
tool_router = ToolRouter(
|
|
|
|
|
|
|
| 1353 |
session_holder: list = [None]
|
| 1354 |
|
| 1355 |
agent_task = asyncio.create_task(
|
|
@@ -1361,7 +1407,7 @@ async def headless_main(
|
|
| 1361 |
session_holder=session_holder,
|
| 1362 |
hf_token=hf_token,
|
| 1363 |
user_id=hf_user,
|
| 1364 |
-
local_mode=
|
| 1365 |
stream=stream,
|
| 1366 |
notification_gateway=notification_gateway,
|
| 1367 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
@@ -1556,6 +1602,11 @@ def cli():
|
|
| 1556 |
action="store_true",
|
| 1557 |
help="Disable token streaming (use non-streaming LLM calls)",
|
| 1558 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1559 |
args = parser.parse_args()
|
| 1560 |
|
| 1561 |
try:
|
|
@@ -1569,10 +1620,11 @@ def cli():
|
|
| 1569 |
model=args.model,
|
| 1570 |
max_iterations=max_iter,
|
| 1571 |
stream=not args.no_stream,
|
|
|
|
| 1572 |
)
|
| 1573 |
)
|
| 1574 |
else:
|
| 1575 |
-
asyncio.run(main(model=args.model))
|
| 1576 |
except KeyboardInterrupt:
|
| 1577 |
print("\n\nGoodbye!")
|
| 1578 |
|
|
|
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
|
| 62 |
+
def _apply_tool_runtime_override(config: Any, *, sandbox_tools: bool) -> str:
|
| 63 |
+
if sandbox_tools:
|
| 64 |
+
config.tool_runtime = "sandbox"
|
| 65 |
+
return getattr(config, "tool_runtime", "local")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _is_local_tool_runtime(config: Any) -> bool:
|
| 69 |
+
return getattr(config, "tool_runtime", "local") == "local"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _tool_runtime_label(local_mode: bool) -> str:
|
| 73 |
+
return "local filesystem" if local_mode else "HF sandbox"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
async def _wait_for_initial_sandbox_preload(session_holder: list | None) -> None:
|
| 77 |
+
session = session_holder[0] if session_holder else None
|
| 78 |
+
task = getattr(session, "sandbox_preload_task", None)
|
| 79 |
+
if not task:
|
| 80 |
+
return
|
| 81 |
+
try:
|
| 82 |
+
await asyncio.shield(task)
|
| 83 |
+
except asyncio.CancelledError:
|
| 84 |
+
raise
|
| 85 |
+
except Exception:
|
| 86 |
+
# The sandbox tool will surface the stored preload error on first use.
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
|
| 90 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
| 91 |
if tool_info.get("tool") != "hf_jobs":
|
| 92 |
return False
|
|
|
|
| 985 |
session = session_holder[0] if session_holder else None
|
| 986 |
print(f"Model: {config.model_name}")
|
| 987 |
print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
|
| 988 |
+
print(f"Tool runtime: {_tool_runtime_label(_is_local_tool_runtime(config))}")
|
| 989 |
if session:
|
| 990 |
print(f"Turns: {session.turn_count}")
|
| 991 |
print(f"Context items: {len(session.context_manager.items)}")
|
|
|
|
| 1105 |
console.print(f"[green]Dataset is now {label}.[/green] {url}")
|
| 1106 |
|
| 1107 |
|
| 1108 |
+
async def main(model: str | None = None, sandbox_tools: bool = False):
|
| 1109 |
"""Interactive chat with the agent"""
|
| 1110 |
|
| 1111 |
# Clear screen
|
|
|
|
| 1117 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1118 |
if model:
|
| 1119 |
config.model_name = model
|
| 1120 |
+
_apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
|
| 1121 |
+
local_mode = _is_local_tool_runtime(config)
|
| 1122 |
|
| 1123 |
+
# HF token — required for Hub-backed models/tools and sandbox tools, but
|
| 1124 |
+
# not for local LLMs using only local filesystem tools.
|
| 1125 |
hf_token = resolve_hf_token()
|
| 1126 |
+
if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
|
| 1127 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 1128 |
|
| 1129 |
# Resolve username for banner
|
| 1130 |
hf_user = _get_hf_user(hf_token)
|
| 1131 |
|
| 1132 |
+
print_banner(
|
| 1133 |
+
model=config.model_name,
|
| 1134 |
+
hf_user=hf_user,
|
| 1135 |
+
tool_runtime=_tool_runtime_label(local_mode),
|
| 1136 |
+
)
|
| 1137 |
|
| 1138 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 1139 |
# don't block on a network fetch.
|
|
|
|
| 1152 |
|
| 1153 |
notification_gateway = NotificationGateway(config.messaging)
|
| 1154 |
await notification_gateway.start()
|
| 1155 |
+
# Create tool router with the selected CLI tool runtime.
|
| 1156 |
+
tool_router = ToolRouter(
|
| 1157 |
+
config.mcpServers, hf_token=hf_token, local_mode=local_mode
|
| 1158 |
+
)
|
| 1159 |
|
| 1160 |
# Session holder for interrupt/model/status access
|
| 1161 |
session_holder = [None]
|
|
|
|
| 1169 |
session_holder=session_holder,
|
| 1170 |
hf_token=hf_token,
|
| 1171 |
user_id=hf_user,
|
| 1172 |
+
local_mode=local_mode,
|
| 1173 |
stream=True,
|
| 1174 |
notification_gateway=notification_gateway,
|
| 1175 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
|
|
| 1191 |
)
|
| 1192 |
|
| 1193 |
await ready_event.wait()
|
| 1194 |
+
if not local_mode:
|
| 1195 |
+
await _wait_for_initial_sandbox_preload(session_holder)
|
| 1196 |
|
| 1197 |
submission_id = [0]
|
| 1198 |
# Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
|
|
|
|
| 1350 |
model: str | None = None,
|
| 1351 |
max_iterations: int | None = None,
|
| 1352 |
stream: bool = True,
|
| 1353 |
+
sandbox_tools: bool = False,
|
| 1354 |
) -> None:
|
| 1355 |
"""Run a single prompt headlessly and exit."""
|
| 1356 |
import logging
|
|
|
|
| 1363 |
|
| 1364 |
if model:
|
| 1365 |
config.model_name = model
|
| 1366 |
+
_apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
|
| 1367 |
+
local_mode = _is_local_tool_runtime(config)
|
| 1368 |
|
| 1369 |
hf_token = resolve_hf_token()
|
| 1370 |
+
if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
|
| 1371 |
print(
|
| 1372 |
+
"ERROR: No HF token found. Set HF_TOKEN or run `hf auth login`.",
|
| 1373 |
file=sys.stderr,
|
| 1374 |
)
|
| 1375 |
sys.exit(1)
|
|
|
|
| 1385 |
config.max_iterations = max_iterations
|
| 1386 |
|
| 1387 |
print(f"Model: {config.model_name}", file=sys.stderr)
|
| 1388 |
+
print(f"Tool runtime: {_tool_runtime_label(local_mode)}", file=sys.stderr)
|
| 1389 |
print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
|
| 1390 |
print(f"Prompt: {prompt}", file=sys.stderr)
|
| 1391 |
print("---", file=sys.stderr)
|
|
|
|
| 1393 |
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 1394 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 1395 |
|
| 1396 |
+
tool_router = ToolRouter(
|
| 1397 |
+
config.mcpServers, hf_token=hf_token, local_mode=local_mode
|
| 1398 |
+
)
|
| 1399 |
session_holder: list = [None]
|
| 1400 |
|
| 1401 |
agent_task = asyncio.create_task(
|
|
|
|
| 1407 |
session_holder=session_holder,
|
| 1408 |
hf_token=hf_token,
|
| 1409 |
user_id=hf_user,
|
| 1410 |
+
local_mode=local_mode,
|
| 1411 |
stream=stream,
|
| 1412 |
notification_gateway=notification_gateway,
|
| 1413 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
|
|
| 1602 |
action="store_true",
|
| 1603 |
help="Disable token streaming (use non-streaming LLM calls)",
|
| 1604 |
)
|
| 1605 |
+
parser.add_argument(
|
| 1606 |
+
"--sandbox-tools",
|
| 1607 |
+
action="store_true",
|
| 1608 |
+
help="Use HF Space sandbox tools instead of local filesystem tools",
|
| 1609 |
+
)
|
| 1610 |
args = parser.parse_args()
|
| 1611 |
|
| 1612 |
try:
|
|
|
|
| 1620 |
model=args.model,
|
| 1621 |
max_iterations=max_iter,
|
| 1622 |
stream=not args.no_stream,
|
| 1623 |
+
sandbox_tools=args.sandbox_tools,
|
| 1624 |
)
|
| 1625 |
)
|
| 1626 |
else:
|
| 1627 |
+
asyncio.run(main(model=args.model, sandbox_tools=args.sandbox_tools))
|
| 1628 |
except KeyboardInterrupt:
|
| 1629 |
print("\n\nGoodbye!")
|
| 1630 |
|
|
@@ -102,9 +102,18 @@ system_prompt: |
|
|
| 102 |
|
| 103 |
# When submitting a training job
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
Before calling hf_jobs, output a pre-flight check:
|
| 106 |
- Reference implementation: [which example you based this on]
|
| 107 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
|
|
|
| 108 |
- push_to_hub=True and hub_model_id set
|
| 109 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 110 |
- Trackio monitoring included and deploying metrics to a public Space
|
|
@@ -127,7 +136,7 @@ system_prompt: |
|
|
| 127 |
|
| 128 |
Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
|
| 129 |
|
| 130 |
-
Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
|
| 131 |
|
| 132 |
|
| 133 |
# When a task has 3+ steps
|
|
|
|
| 102 |
|
| 103 |
# When submitting a training job
|
| 104 |
|
| 105 |
+
Never pass a local machine path to hf_jobs.script, such as /Users/..., /home/..., /fsx/..., or a repo checkout path. HF Jobs runs in a fresh cloud environment where local files do not exist. For hf_jobs.script, use exactly one of:
|
| 106 |
+
- inline Python source code
|
| 107 |
+
- a file already written in the session sandbox, e.g. /app/train.py, ./train.py, or train.py
|
| 108 |
+
- a public/raw URL
|
| 109 |
+
If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
|
| 110 |
+
|
| 111 |
+
GPU preflight is mandatory before hf_jobs when the job will run on GPU, or when the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, or torch.compile. First create a GPU sandbox with sandbox_create (t4-small minimum; choose larger hardware when VRAM requires it), run a tiny smoke test there using the same imports, model-loading path, training entrypoint, and a tiny dataset/subset, then fix failures before submitting. If you skip GPU sandbox preflight, state why before calling hf_jobs.
|
| 112 |
+
|
| 113 |
Before calling hf_jobs, output a pre-flight check:
|
| 114 |
- Reference implementation: [which example you based this on]
|
| 115 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
| 116 |
+
- GPU sandbox smoke test: [hardware and result, or explicitly not applicable because ...]
|
| 117 |
- push_to_hub=True and hub_model_id set
|
| 118 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 119 |
- Trackio monitoring included and deploying metrics to a public Space
|
|
|
|
| 136 |
|
| 137 |
Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
|
| 138 |
|
| 139 |
+
Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16/fp16, quantization, flash attention, torch.compile, or model loading. CPU sandboxes cannot test GPU code paths. If the available sandbox tiers cannot fit the full model path, test the largest useful smoke path, state what was not covered, and submit one HF job first.
|
| 140 |
|
| 141 |
|
| 142 |
# When a task has 3+ steps
|
|
@@ -1112,6 +1112,9 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1112 |
"- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
|
| 1113 |
"Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
|
| 1114 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
|
|
|
|
|
|
|
|
|
| 1115 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 1116 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 1117 |
"- Include trackio monitoring and provide the dashboard URL to the user. "
|
|
@@ -1157,8 +1160,9 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1157 |
"script": {
|
| 1158 |
"type": "string",
|
| 1159 |
"description": (
|
| 1160 |
-
"Python code
|
| 1161 |
"Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
|
|
|
|
| 1162 |
"Mutually exclusive with 'command'."
|
| 1163 |
),
|
| 1164 |
},
|
|
|
|
| 1112 |
"- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
|
| 1113 |
"Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
|
| 1114 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
| 1115 |
+
"- If the job runs on GPU, or the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, "
|
| 1116 |
+
"or torch.compile, you MUST create a GPU sandbox with sandbox_create first, run a tiny smoke test there, "
|
| 1117 |
+
"and fix failures before submitting. If skipped, state why before calling hf_jobs.\n"
|
| 1118 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 1119 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 1120 |
"- Include trackio monitoring and provide the dashboard URL to the user. "
|
|
|
|
| 1160 |
"script": {
|
| 1161 |
"type": "string",
|
| 1162 |
"description": (
|
| 1163 |
+
"Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
|
| 1164 |
"Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
|
| 1165 |
+
"For GPU/model-loading training scripts, smoke-test in a GPU sandbox before submission. "
|
| 1166 |
"Mutually exclusive with 'command'."
|
| 1167 |
),
|
| 1168 |
},
|
|
@@ -54,20 +54,24 @@ class PlanTool:
|
|
| 54 |
"isError": True,
|
| 55 |
}
|
| 56 |
|
| 57 |
-
# Store the
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# Emit plan update event if session is available
|
| 61 |
if self.session:
|
| 62 |
await self.session.send_event(
|
| 63 |
Event(
|
| 64 |
event_type="plan_update",
|
| 65 |
-
data={"plan":
|
| 66 |
)
|
| 67 |
)
|
| 68 |
|
| 69 |
# Format only for display using terminal_display utility
|
| 70 |
-
formatted_output = format_plan_tool_output(
|
| 71 |
|
| 72 |
return {
|
| 73 |
"formatted": formatted_output,
|
|
|
|
| 54 |
"isError": True,
|
| 55 |
}
|
| 56 |
|
| 57 |
+
# Store a session-scoped copy so the runtime can tell whether a
|
| 58 |
+
# text-only model response is trying to stop while work remains.
|
| 59 |
+
stored_todos = [dict(todo) for todo in todos]
|
| 60 |
+
_current_plan = stored_todos
|
| 61 |
+
if self.session is not None:
|
| 62 |
+
self.session.current_plan = stored_todos
|
| 63 |
|
| 64 |
# Emit plan update event if session is available
|
| 65 |
if self.session:
|
| 66 |
await self.session.send_event(
|
| 67 |
Event(
|
| 68 |
event_type="plan_update",
|
| 69 |
+
data={"plan": stored_todos},
|
| 70 |
)
|
| 71 |
)
|
| 72 |
|
| 73 |
# Format only for display using terminal_display utility
|
| 74 |
+
formatted_output = format_plan_tool_output(stored_todos)
|
| 75 |
|
| 76 |
return {
|
| 77 |
"formatted": formatted_output,
|
|
@@ -776,21 +776,23 @@ class Sandbox:
|
|
| 776 |
f"Last status: {last_status}, last error: {last_err}"
|
| 777 |
)
|
| 778 |
|
| 779 |
-
def delete(self):
|
| 780 |
"""Delete the Space. Only works if this Sandbox created it."""
|
| 781 |
if not self._owns_space:
|
| 782 |
raise RuntimeError(
|
| 783 |
f"This Sandbox did not create {self.space_id}. "
|
| 784 |
f"Use self._hf_api.delete_repo() directly if you're sure."
|
| 785 |
)
|
| 786 |
-
|
|
|
|
| 787 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
| 788 |
# Clear ownership so a second cleanup call (e.g. delete_session +
|
| 789 |
# _run_session.finally both fire) early-returns instead of retrying
|
| 790 |
# a 404 delete and emitting a spurious ERROR log.
|
| 791 |
self._owns_space = False
|
| 792 |
self._client.close()
|
| 793 |
-
|
|
|
|
| 794 |
|
| 795 |
def pause(self):
|
| 796 |
"""Pause the Space (stops billing, preserves state)."""
|
|
|
|
| 776 |
f"Last status: {last_status}, last error: {last_err}"
|
| 777 |
)
|
| 778 |
|
| 779 |
+
def delete(self, log: Callable[[str], object] | None = None):
|
| 780 |
"""Delete the Space. Only works if this Sandbox created it."""
|
| 781 |
if not self._owns_space:
|
| 782 |
raise RuntimeError(
|
| 783 |
f"This Sandbox did not create {self.space_id}. "
|
| 784 |
f"Use self._hf_api.delete_repo() directly if you're sure."
|
| 785 |
)
|
| 786 |
+
if log:
|
| 787 |
+
log(f"Deleting sandbox: {self.space_id}...")
|
| 788 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
| 789 |
# Clear ownership so a second cleanup call (e.g. delete_session +
|
| 790 |
# _run_session.finally both fire) early-returns instead of retrying
|
| 791 |
# a 404 delete and emitting a spurious ERROR log.
|
| 792 |
self._owns_space = False
|
| 793 |
self._client.close()
|
| 794 |
+
if log:
|
| 795 |
+
log("Deleted.")
|
| 796 |
|
| 797 |
def pause(self):
|
| 798 |
"""Pause the Space (stops billing, preserves state)."""
|
|
@@ -16,6 +16,7 @@ import logging
|
|
| 16 |
import re
|
| 17 |
import threading
|
| 18 |
import weakref
|
|
|
|
| 19 |
from datetime import datetime, timedelta, timezone
|
| 20 |
from typing import Any
|
| 21 |
|
|
@@ -58,17 +59,41 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
|
|
| 58 |
return lock
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def _looks_like_path(script: str) -> bool:
|
| 62 |
"""Return True if the script string looks like a file path (not inline code)."""
|
| 63 |
-
|
| 64 |
isinstance(script, str)
|
| 65 |
and script.strip() == script
|
| 66 |
and not any(c in script for c in "\r\n\0")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
|
|
@@ -303,14 +328,8 @@ async def _create_sandbox_locked(
|
|
| 303 |
)
|
| 304 |
)
|
| 305 |
|
| 306 |
-
# Thread-safe log callback: posts tool_log events from
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def _log(msg: str) -> None:
|
| 310 |
-
loop.call_soon_threadsafe(
|
| 311 |
-
session.event_queue.put_nowait,
|
| 312 |
-
Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
|
| 313 |
-
)
|
| 314 |
|
| 315 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 316 |
# We poll session._cancelled from the main loop in a background task and set
|
|
@@ -352,7 +371,7 @@ async def _create_sandbox_locked(
|
|
| 352 |
if cancel_flag.is_set():
|
| 353 |
if getattr(sb, "_owns_space", False):
|
| 354 |
try:
|
| 355 |
-
await asyncio.to_thread(sb.delete)
|
| 356 |
except Exception as e:
|
| 357 |
logger.warning(
|
| 358 |
"Failed to delete cancelled sandbox %s: %s", sb.space_id, e
|
|
@@ -497,6 +516,7 @@ async def teardown_session_sandbox(session: Any) -> None:
|
|
| 497 |
return
|
| 498 |
|
| 499 |
space_id = getattr(sandbox, "space_id", None)
|
|
|
|
| 500 |
last_err: Exception | None = None
|
| 501 |
for attempt in range(3):
|
| 502 |
try:
|
|
@@ -505,7 +525,7 @@ async def teardown_session_sandbox(session: Any) -> None:
|
|
| 505 |
space_id,
|
| 506 |
attempt + 1,
|
| 507 |
)
|
| 508 |
-
await asyncio.to_thread(sandbox.delete)
|
| 509 |
from agent.core import telemetry
|
| 510 |
|
| 511 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
|
|
|
| 16 |
import re
|
| 17 |
import threading
|
| 18 |
import weakref
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
from datetime import datetime, timedelta, timezone
|
| 21 |
from typing import Any
|
| 22 |
|
|
|
|
| 59 |
return lock
|
| 60 |
|
| 61 |
|
| 62 |
+
def _session_tool_logger(
|
| 63 |
+
session: Any, *, tool: str = "sandbox"
|
| 64 |
+
) -> Callable[[str], object] | None:
|
| 65 |
+
event_queue = getattr(session, "event_queue", None)
|
| 66 |
+
if event_queue is None:
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
loop = asyncio.get_running_loop()
|
| 70 |
+
|
| 71 |
+
def _log(msg: str) -> None:
|
| 72 |
+
loop.call_soon_threadsafe(
|
| 73 |
+
event_queue.put_nowait,
|
| 74 |
+
Event(event_type="tool_log", data={"tool": tool, "log": msg}),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return _log
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def _looks_like_path(script: str) -> bool:
|
| 81 |
"""Return True if the script string looks like a file path (not inline code)."""
|
| 82 |
+
if not (
|
| 83 |
isinstance(script, str)
|
| 84 |
and script.strip() == script
|
| 85 |
and not any(c in script for c in "\r\n\0")
|
| 86 |
+
):
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
if script.startswith("http://") or script.startswith("https://"):
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
return (
|
| 93 |
+
script.startswith("/")
|
| 94 |
+
or script.startswith("./")
|
| 95 |
+
or script.startswith("../")
|
| 96 |
+
or (script.endswith(".py") and not any(c.isspace() for c in script))
|
| 97 |
)
|
| 98 |
|
| 99 |
|
|
|
|
| 328 |
)
|
| 329 |
)
|
| 330 |
|
| 331 |
+
# Thread-safe log callback: posts tool_log events from worker threads.
|
| 332 |
+
_log = _session_tool_logger(session) or (lambda msg: None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 335 |
# We poll session._cancelled from the main loop in a background task and set
|
|
|
|
| 371 |
if cancel_flag.is_set():
|
| 372 |
if getattr(sb, "_owns_space", False):
|
| 373 |
try:
|
| 374 |
+
await asyncio.to_thread(sb.delete, log=_log)
|
| 375 |
except Exception as e:
|
| 376 |
logger.warning(
|
| 377 |
"Failed to delete cancelled sandbox %s: %s", sb.space_id, e
|
|
|
|
| 516 |
return
|
| 517 |
|
| 518 |
space_id = getattr(sandbox, "space_id", None)
|
| 519 |
+
delete_log = _session_tool_logger(session)
|
| 520 |
last_err: Exception | None = None
|
| 521 |
for attempt in range(3):
|
| 522 |
try:
|
|
|
|
| 525 |
space_id,
|
| 526 |
attempt + 1,
|
| 527 |
)
|
| 528 |
+
await asyncio.to_thread(sandbox.delete, log=delete_log)
|
| 529 |
from agent.core import telemetry
|
| 530 |
|
| 531 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
|
@@ -93,7 +93,11 @@ def get_console() -> Console:
|
|
| 93 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 94 |
|
| 95 |
|
| 96 |
-
def print_banner(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 98 |
from agent.utils.particle_logo import run_particle_logo
|
| 99 |
from agent.utils.crt_boot import run_boot_sequence
|
|
@@ -116,6 +120,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
|
|
| 116 |
(f"{_I}Initializing agent runtime...", gold),
|
| 117 |
(f"{_I} User: {user_label}", dim_gold),
|
| 118 |
(f"{_I} Model: {model_label}", dim_gold),
|
|
|
|
| 119 |
(f"{_I} Tools: loading...", dim_gold),
|
| 120 |
("", ""),
|
| 121 |
(f"{_I}/help for commands · /model to switch · /quit to exit", gold),
|
|
|
|
| 93 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 94 |
|
| 95 |
|
| 96 |
+
def print_banner(
|
| 97 |
+
model: str | None = None,
|
| 98 |
+
hf_user: str | None = None,
|
| 99 |
+
tool_runtime: str | None = None,
|
| 100 |
+
) -> None:
|
| 101 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 102 |
from agent.utils.particle_logo import run_particle_logo
|
| 103 |
from agent.utils.crt_boot import run_boot_sequence
|
|
|
|
| 120 |
(f"{_I}Initializing agent runtime...", gold),
|
| 121 |
(f"{_I} User: {user_label}", dim_gold),
|
| 122 |
(f"{_I} Model: {model_label}", dim_gold),
|
| 123 |
+
(f"{_I} Tool runtime: {tool_runtime or 'local filesystem'}", dim_gold),
|
| 124 |
(f"{_I} Tools: loading...", dim_gold),
|
| 125 |
("", ""),
|
| 126 |
(f"{_I}/help for commands · /model to switch · /quit to exit", gold),
|
|
@@ -7,6 +7,7 @@
|
|
| 7 |
"yolo_mode": false,
|
| 8 |
"confirm_cpu_jobs": true,
|
| 9 |
"auto_file_upload": true,
|
|
|
|
| 10 |
"messaging": {
|
| 11 |
"enabled": false,
|
| 12 |
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
|
|
|
| 7 |
"yolo_mode": false,
|
| 8 |
"confirm_cpu_jobs": true,
|
| 9 |
"auto_file_upload": true,
|
| 10 |
+
"tool_runtime": "local",
|
| 11 |
"messaging": {
|
| 12 |
"enabled": false,
|
| 13 |
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
|
@@ -1,5 +1,6 @@
|
|
| 1 |
"""Regression tests for interactive CLI rendering and research model routing."""
|
| 2 |
|
|
|
|
| 3 |
import sys
|
| 4 |
from io import StringIO
|
| 5 |
from types import SimpleNamespace
|
|
@@ -97,10 +98,11 @@ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
|
|
| 97 |
|
| 98 |
|
| 99 |
def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
| 100 |
-
seen: dict[str,
|
| 101 |
|
| 102 |
-
async def fake_main(*, model=None):
|
| 103 |
seen["model"] = model
|
|
|
|
| 104 |
|
| 105 |
monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
|
| 106 |
monkeypatch.setattr(main_mod, "main", fake_main)
|
|
@@ -108,6 +110,61 @@ def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
|
| 108 |
main_mod.cli()
|
| 109 |
|
| 110 |
assert seen["model"] == "openai/gpt-5.5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
|
| 113 |
@pytest.mark.asyncio
|
|
@@ -115,9 +172,10 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
|
|
| 115 |
class StopAfterBanner(Exception):
|
| 116 |
pass
|
| 117 |
|
| 118 |
-
def fake_banner(*, model=None, hf_user=None):
|
| 119 |
assert model == "openai/gpt-5.5"
|
| 120 |
assert hf_user == "tester"
|
|
|
|
| 121 |
raise StopAfterBanner
|
| 122 |
|
| 123 |
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
|
@@ -130,9 +188,150 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
|
|
| 130 |
lambda _path, **_kwargs: SimpleNamespace(
|
| 131 |
model_name="moonshotai/Kimi-K2.6",
|
| 132 |
mcpServers={},
|
|
|
|
| 133 |
),
|
| 134 |
)
|
| 135 |
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 136 |
|
| 137 |
with pytest.raises(StopAfterBanner):
|
| 138 |
await main_mod.main(model="openai/gpt-5.5")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Regression tests for interactive CLI rendering and research model routing."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import sys
|
| 5 |
from io import StringIO
|
| 6 |
from types import SimpleNamespace
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
| 101 |
+
seen: dict[str, object] = {}
|
| 102 |
|
| 103 |
+
async def fake_main(*, model=None, sandbox_tools=False):
|
| 104 |
seen["model"] = model
|
| 105 |
+
seen["sandbox_tools"] = sandbox_tools
|
| 106 |
|
| 107 |
monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
|
| 108 |
monkeypatch.setattr(main_mod, "main", fake_main)
|
|
|
|
| 110 |
main_mod.cli()
|
| 111 |
|
| 112 |
assert seen["model"] == "openai/gpt-5.5"
|
| 113 |
+
assert seen["sandbox_tools"] is False
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_cli_forwards_sandbox_flag_to_interactive_main(monkeypatch):
|
| 117 |
+
seen: dict[str, object] = {}
|
| 118 |
+
|
| 119 |
+
async def fake_main(*, model=None, sandbox_tools=False):
|
| 120 |
+
seen["model"] = model
|
| 121 |
+
seen["sandbox_tools"] = sandbox_tools
|
| 122 |
+
|
| 123 |
+
monkeypatch.setattr(sys, "argv", ["ml-intern", "--sandbox-tools"])
|
| 124 |
+
monkeypatch.setattr(main_mod, "main", fake_main)
|
| 125 |
+
|
| 126 |
+
main_mod.cli()
|
| 127 |
+
|
| 128 |
+
assert seen == {"model": None, "sandbox_tools": True}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_cli_forwards_sandbox_flag_to_headless_main(monkeypatch):
|
| 132 |
+
seen: dict[str, object] = {}
|
| 133 |
+
|
| 134 |
+
async def fake_headless_main(
|
| 135 |
+
prompt,
|
| 136 |
+
*,
|
| 137 |
+
model=None,
|
| 138 |
+
max_iterations=None,
|
| 139 |
+
stream=True,
|
| 140 |
+
sandbox_tools=False,
|
| 141 |
+
):
|
| 142 |
+
seen.update(
|
| 143 |
+
{
|
| 144 |
+
"prompt": prompt,
|
| 145 |
+
"model": model,
|
| 146 |
+
"max_iterations": max_iterations,
|
| 147 |
+
"stream": stream,
|
| 148 |
+
"sandbox_tools": sandbox_tools,
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
monkeypatch.setattr(
|
| 153 |
+
sys,
|
| 154 |
+
"argv",
|
| 155 |
+
["ml-intern", "--sandbox-tools", "--no-stream", "train a model"],
|
| 156 |
+
)
|
| 157 |
+
monkeypatch.setattr(main_mod, "headless_main", fake_headless_main)
|
| 158 |
+
|
| 159 |
+
main_mod.cli()
|
| 160 |
+
|
| 161 |
+
assert seen == {
|
| 162 |
+
"prompt": "train a model",
|
| 163 |
+
"model": None,
|
| 164 |
+
"max_iterations": None,
|
| 165 |
+
"stream": False,
|
| 166 |
+
"sandbox_tools": True,
|
| 167 |
+
}
|
| 168 |
|
| 169 |
|
| 170 |
@pytest.mark.asyncio
|
|
|
|
| 172 |
class StopAfterBanner(Exception):
|
| 173 |
pass
|
| 174 |
|
| 175 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 176 |
assert model == "openai/gpt-5.5"
|
| 177 |
assert hf_user == "tester"
|
| 178 |
+
assert tool_runtime == "local filesystem"
|
| 179 |
raise StopAfterBanner
|
| 180 |
|
| 181 |
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
|
|
|
| 188 |
lambda _path, **_kwargs: SimpleNamespace(
|
| 189 |
model_name="moonshotai/Kimi-K2.6",
|
| 190 |
mcpServers={},
|
| 191 |
+
tool_runtime="local",
|
| 192 |
),
|
| 193 |
)
|
| 194 |
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 195 |
|
| 196 |
with pytest.raises(StopAfterBanner):
|
| 197 |
await main_mod.main(model="openai/gpt-5.5")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@pytest.mark.asyncio
|
| 201 |
+
async def test_local_model_local_runtime_skips_hf_token_prompt(monkeypatch):
|
| 202 |
+
class StopAfterBanner(Exception):
|
| 203 |
+
pass
|
| 204 |
+
|
| 205 |
+
async def fail_prompt(_prompt_session):
|
| 206 |
+
raise AssertionError("local model with local tools should not prompt")
|
| 207 |
+
|
| 208 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 209 |
+
assert model == "llamacpp/model"
|
| 210 |
+
assert hf_user is None
|
| 211 |
+
assert tool_runtime == "local filesystem"
|
| 212 |
+
raise StopAfterBanner
|
| 213 |
+
|
| 214 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 215 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 216 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
|
| 217 |
+
monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fail_prompt)
|
| 218 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: None)
|
| 219 |
+
monkeypatch.setattr(
|
| 220 |
+
main_mod,
|
| 221 |
+
"load_config",
|
| 222 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 223 |
+
model_name="llamacpp/model",
|
| 224 |
+
mcpServers={},
|
| 225 |
+
tool_runtime="local",
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 229 |
+
|
| 230 |
+
with pytest.raises(StopAfterBanner):
|
| 231 |
+
await main_mod.main()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@pytest.mark.asyncio
|
| 235 |
+
async def test_local_model_sandbox_runtime_prompts_for_hf_token(monkeypatch):
|
| 236 |
+
class StopAfterBanner(Exception):
|
| 237 |
+
pass
|
| 238 |
+
|
| 239 |
+
prompted = False
|
| 240 |
+
|
| 241 |
+
async def fake_prompt(_prompt_session):
|
| 242 |
+
nonlocal prompted
|
| 243 |
+
prompted = True
|
| 244 |
+
return "hf-token"
|
| 245 |
+
|
| 246 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 247 |
+
assert model == "llamacpp/model"
|
| 248 |
+
assert hf_user == "tester"
|
| 249 |
+
assert tool_runtime == "HF sandbox"
|
| 250 |
+
raise StopAfterBanner
|
| 251 |
+
|
| 252 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 253 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 254 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
|
| 255 |
+
monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fake_prompt)
|
| 256 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
|
| 257 |
+
monkeypatch.setattr(
|
| 258 |
+
main_mod,
|
| 259 |
+
"load_config",
|
| 260 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 261 |
+
model_name="llamacpp/model",
|
| 262 |
+
mcpServers={},
|
| 263 |
+
tool_runtime="local",
|
| 264 |
+
),
|
| 265 |
+
)
|
| 266 |
+
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 267 |
+
|
| 268 |
+
with pytest.raises(StopAfterBanner):
|
| 269 |
+
await main_mod.main(sandbox_tools=True)
|
| 270 |
+
|
| 271 |
+
assert prompted is True
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@pytest.mark.asyncio
|
| 275 |
+
async def test_interactive_main_passes_sandbox_runtime_to_tool_router(monkeypatch):
|
| 276 |
+
class StopAfterToolRouter(Exception):
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
seen: dict[str, object] = {}
|
| 280 |
+
|
| 281 |
+
class FakeGateway:
|
| 282 |
+
def __init__(self, _config):
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
async def start(self):
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
class FakeToolRouter:
|
| 289 |
+
def __init__(self, mcp_servers, *, hf_token=None, local_mode=True):
|
| 290 |
+
seen["mcp_servers"] = mcp_servers
|
| 291 |
+
seen["hf_token"] = hf_token
|
| 292 |
+
seen["local_mode"] = local_mode
|
| 293 |
+
raise StopAfterToolRouter
|
| 294 |
+
|
| 295 |
+
from agent.core import hf_router_catalog
|
| 296 |
+
|
| 297 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 298 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 299 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: "hf-token")
|
| 300 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
|
| 301 |
+
monkeypatch.setattr(main_mod, "print_banner", lambda **_kwargs: None)
|
| 302 |
+
monkeypatch.setattr(hf_router_catalog, "prewarm", lambda: None)
|
| 303 |
+
monkeypatch.setattr(
|
| 304 |
+
main_mod,
|
| 305 |
+
"load_config",
|
| 306 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 307 |
+
model_name="llamacpp/model",
|
| 308 |
+
mcpServers={"server": object()},
|
| 309 |
+
messaging=SimpleNamespace(default_auto_destinations=lambda: []),
|
| 310 |
+
tool_runtime="local",
|
| 311 |
+
),
|
| 312 |
+
)
|
| 313 |
+
monkeypatch.setattr(main_mod, "NotificationGateway", FakeGateway)
|
| 314 |
+
monkeypatch.setattr(main_mod, "ToolRouter", FakeToolRouter)
|
| 315 |
+
|
| 316 |
+
with pytest.raises(StopAfterToolRouter):
|
| 317 |
+
await main_mod.main(sandbox_tools=True)
|
| 318 |
+
|
| 319 |
+
assert seen["hf_token"] == "hf-token"
|
| 320 |
+
assert seen["local_mode"] is False
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@pytest.mark.asyncio
|
| 324 |
+
async def test_initial_sandbox_preload_waits_before_prompt():
|
| 325 |
+
waited = False
|
| 326 |
+
|
| 327 |
+
async def preload():
|
| 328 |
+
nonlocal waited
|
| 329 |
+
await asyncio.sleep(0)
|
| 330 |
+
waited = True
|
| 331 |
+
|
| 332 |
+
task = asyncio.create_task(preload())
|
| 333 |
+
await main_mod._wait_for_initial_sandbox_preload(
|
| 334 |
+
[SimpleNamespace(sandbox_preload_task=task)]
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
assert waited is True
|
|
@@ -1,5 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from agent import config as config_module
|
| 4 |
|
| 5 |
|
|
@@ -121,3 +124,35 @@ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
|
|
| 121 |
|
| 122 |
assert not config.messaging.enabled
|
| 123 |
assert config.messaging.destinations == {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
|
| 3 |
+
import pytest
|
| 4 |
+
from pydantic import ValidationError
|
| 5 |
+
|
| 6 |
from agent import config as config_module
|
| 7 |
|
| 8 |
|
|
|
|
| 124 |
|
| 125 |
assert not config.messaging.enabled
|
| 126 |
assert config.messaging.destinations == {}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_tool_runtime_defaults_to_local(tmp_path):
|
| 130 |
+
config_path = tmp_path / "config.json"
|
| 131 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 132 |
+
|
| 133 |
+
config = config_module.load_config(str(config_path))
|
| 134 |
+
|
| 135 |
+
assert config.tool_runtime == "local"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def test_user_config_can_set_sandbox_tool_runtime(tmp_path, monkeypatch):
|
| 139 |
+
config_path = tmp_path / "config.json"
|
| 140 |
+
user_config_path = tmp_path / "user-config.json"
|
| 141 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 142 |
+
_write_json(user_config_path, {"tool_runtime": "sandbox"})
|
| 143 |
+
monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
|
| 144 |
+
|
| 145 |
+
config = config_module.load_config(str(config_path), include_user_defaults=True)
|
| 146 |
+
|
| 147 |
+
assert config.tool_runtime == "sandbox"
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def test_invalid_tool_runtime_is_rejected(tmp_path):
|
| 151 |
+
config_path = tmp_path / "config.json"
|
| 152 |
+
_write_json(
|
| 153 |
+
config_path,
|
| 154 |
+
{"model_name": "moonshotai/Kimi-K2.6", "tool_runtime": "hybrid"},
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
with pytest.raises(ValidationError):
|
| 158 |
+
config_module.load_config(str(config_path))
|
|
@@ -549,7 +549,7 @@ def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
|
|
| 549 |
]
|
| 550 |
|
| 551 |
|
| 552 |
-
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
| 553 |
import huggingface_hub as hub
|
| 554 |
from huggingface_hub import HfApi
|
| 555 |
|
|
@@ -579,6 +579,10 @@ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
|
| 579 |
def fake_add_collection_item(self, **kwargs):
|
| 580 |
collection_items.append(kwargs)
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
|
| 583 |
monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
|
| 584 |
monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
|
|
|
|
| 549 |
]
|
| 550 |
|
| 551 |
|
| 552 |
+
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch, tmp_path):
|
| 553 |
import huggingface_hub as hub
|
| 554 |
from huggingface_hub import HfApi
|
| 555 |
|
|
|
|
| 579 |
def fake_add_collection_item(self, **kwargs):
|
| 580 |
collection_items.append(kwargs)
|
| 581 |
|
| 582 |
+
monkeypatch.setenv(
|
| 583 |
+
"ML_INTERN_ARTIFACT_COLLECTION_CACHE",
|
| 584 |
+
str(tmp_path / "collection-slug.txt"),
|
| 585 |
+
)
|
| 586 |
monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
|
| 587 |
monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
|
| 588 |
monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
|
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from agent.config import Config
|
| 7 |
+
from agent.core import agent_loop
|
| 8 |
+
from agent.core.agent_loop import Handlers, LLMResult
|
| 9 |
+
from agent.core.session import Session
|
| 10 |
+
from agent.tools.plan_tool import PlanTool
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FakeToolRouter:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.calls = []
|
| 16 |
+
|
| 17 |
+
def get_tool_specs_for_llm(self):
|
| 18 |
+
return [
|
| 19 |
+
{
|
| 20 |
+
"type": "function",
|
| 21 |
+
"function": {
|
| 22 |
+
"name": "plan_tool",
|
| 23 |
+
"description": "Update plan",
|
| 24 |
+
"parameters": {"type": "object"},
|
| 25 |
+
},
|
| 26 |
+
}
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
async def call_tool(self, name, arguments, session=None, tool_call_id=None):
|
| 30 |
+
self.calls.append((name, arguments, tool_call_id))
|
| 31 |
+
if name == "plan_tool" and session is not None:
|
| 32 |
+
session.current_plan = [dict(todo) for todo in arguments["todos"]]
|
| 33 |
+
return "plan updated", True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.mark.asyncio
|
| 37 |
+
async def test_plan_tool_stores_session_scoped_plan():
|
| 38 |
+
events = []
|
| 39 |
+
|
| 40 |
+
class FakeSession:
|
| 41 |
+
current_plan = []
|
| 42 |
+
|
| 43 |
+
async def send_event(self, event):
|
| 44 |
+
events.append(event)
|
| 45 |
+
|
| 46 |
+
session = FakeSession()
|
| 47 |
+
todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
|
| 48 |
+
|
| 49 |
+
result = await PlanTool(session=session).execute({"todos": todos})
|
| 50 |
+
|
| 51 |
+
assert result["isError"] is False
|
| 52 |
+
assert session.current_plan == todos
|
| 53 |
+
assert events[0].event_type == "plan_update"
|
| 54 |
+
assert events[0].data == {"plan": todos}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@pytest.mark.asyncio
|
| 58 |
+
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
|
| 59 |
+
config = Config.model_validate(
|
| 60 |
+
{"model_name": "openai/test", "save_sessions": False}
|
| 61 |
+
)
|
| 62 |
+
event_queue = asyncio.Queue()
|
| 63 |
+
router = FakeToolRouter()
|
| 64 |
+
session = Session(
|
| 65 |
+
event_queue,
|
| 66 |
+
config,
|
| 67 |
+
tool_router=router,
|
| 68 |
+
stream=False,
|
| 69 |
+
)
|
| 70 |
+
session.current_plan = [
|
| 71 |
+
{
|
| 72 |
+
"id": "1",
|
| 73 |
+
"content": "Write and smoke-test training script",
|
| 74 |
+
"status": "in_progress",
|
| 75 |
+
},
|
| 76 |
+
{"id": "2", "content": "Launch full training job", "status": "pending"},
|
| 77 |
+
]
|
| 78 |
+
calls = []
|
| 79 |
+
|
| 80 |
+
async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
|
| 81 |
+
calls.append(messages)
|
| 82 |
+
if len(calls) == 1:
|
| 83 |
+
return LLMResult(
|
| 84 |
+
content="I should keep going, but I forgot to call a tool.",
|
| 85 |
+
tool_calls_acc={},
|
| 86 |
+
token_count=10,
|
| 87 |
+
finish_reason="stop",
|
| 88 |
+
)
|
| 89 |
+
if len(calls) == 2:
|
| 90 |
+
assert "CONTINUATION GUARD" in messages[-1].content
|
| 91 |
+
return LLMResult(
|
| 92 |
+
content=None,
|
| 93 |
+
tool_calls_acc={
|
| 94 |
+
0: {
|
| 95 |
+
"id": "call_1",
|
| 96 |
+
"function": {
|
| 97 |
+
"name": "plan_tool",
|
| 98 |
+
"arguments": json.dumps(
|
| 99 |
+
{
|
| 100 |
+
"todos": [
|
| 101 |
+
{
|
| 102 |
+
"id": "1",
|
| 103 |
+
"content": "Write and smoke-test training script",
|
| 104 |
+
"status": "completed",
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"id": "2",
|
| 108 |
+
"content": "Launch full training job",
|
| 109 |
+
"status": "completed",
|
| 110 |
+
},
|
| 111 |
+
]
|
| 112 |
+
}
|
| 113 |
+
),
|
| 114 |
+
},
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
token_count=20,
|
| 118 |
+
finish_reason="tool_calls",
|
| 119 |
+
)
|
| 120 |
+
return LLMResult(
|
| 121 |
+
content="Done.",
|
| 122 |
+
tool_calls_acc={},
|
| 123 |
+
token_count=30,
|
| 124 |
+
finish_reason="stop",
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
monkeypatch.setattr(
|
| 128 |
+
agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
|
| 129 |
+
)
|
| 130 |
+
monkeypatch.setattr(
|
| 131 |
+
agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
final = await Handlers.run_agent(session, "continue")
|
| 135 |
+
|
| 136 |
+
assert final == "Done."
|
| 137 |
+
assert len(calls) == 3
|
| 138 |
+
assert router.calls[0][0] == "plan_tool"
|
| 139 |
+
assert all(todo["status"] == "completed" for todo in session.current_plan)
|
| 140 |
+
events = []
|
| 141 |
+
while not event_queue.empty():
|
| 142 |
+
events.append(await event_queue.get())
|
| 143 |
+
assert any(
|
| 144 |
+
event.event_type == "tool_log"
|
| 145 |
+
and "text-only response" in (event.data or {}).get("log", "")
|
| 146 |
+
for event in events
|
| 147 |
+
)
|
|
@@ -1,7 +1,15 @@
|
|
|
|
|
| 1 |
from types import SimpleNamespace
|
| 2 |
from pathlib import Path
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from agent.core.agent_loop import _needs_approval
|
|
|
|
|
|
|
|
|
|
| 5 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
| 6 |
|
| 7 |
|
|
@@ -34,3 +42,102 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create():
|
|
| 34 |
in tool_specs["sandbox_create"]
|
| 35 |
)
|
| 36 |
assert "started automatically for normal CPU work" in tool_specs["bash"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
from types import SimpleNamespace
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from agent.config import Config
|
| 8 |
+
from agent.core import agent_loop
|
| 9 |
from agent.core.agent_loop import _needs_approval
|
| 10 |
+
from agent.core.session import OpType
|
| 11 |
+
from agent.core.tools import create_builtin_tools
|
| 12 |
+
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
|
| 13 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
| 14 |
|
| 15 |
|
|
|
|
| 42 |
in tool_specs["sandbox_create"]
|
| 43 |
)
|
| 44 |
assert "started automatically for normal CPU work" in tool_specs["bash"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
|
| 48 |
+
prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
|
| 49 |
+
|
| 50 |
+
assert "Never pass a local machine path to hf_jobs.script" in prompt
|
| 51 |
+
assert "/fsx/..." in prompt
|
| 52 |
+
assert "inline Python source code" in prompt
|
| 53 |
+
assert "a file already written in the session sandbox" in prompt
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_prompt_and_hf_jobs_spec_require_gpu_preflight_for_gpu_jobs():
|
| 57 |
+
prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
|
| 58 |
+
jobs_description = HF_JOBS_TOOL_SPEC["description"]
|
| 59 |
+
|
| 60 |
+
assert "GPU preflight is mandatory before hf_jobs" in prompt
|
| 61 |
+
assert "GPU sandbox smoke test" in prompt
|
| 62 |
+
assert "If you skip GPU sandbox preflight" in prompt
|
| 63 |
+
assert "you MUST create a GPU sandbox with sandbox_create first" in jobs_description
|
| 64 |
+
assert "If skipped, state why before calling hf_jobs" in jobs_description
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_local_tool_runtime_excludes_sandbox_create():
|
| 68 |
+
tool_names = {tool.name for tool in create_builtin_tools(local_mode=True)}
|
| 69 |
+
|
| 70 |
+
assert {"bash", "read", "write", "edit"} <= tool_names
|
| 71 |
+
assert "sandbox_create" not in tool_names
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_sandbox_tool_runtime_includes_sandbox_create():
|
| 75 |
+
tool_names = {tool.name for tool in create_builtin_tools(local_mode=False)}
|
| 76 |
+
|
| 77 |
+
assert {"sandbox_create", "bash", "read", "write", "edit"} <= tool_names
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@pytest.mark.asyncio
|
| 81 |
+
async def test_cli_sandbox_runtime_preloads_and_tears_down_sandbox(monkeypatch):
|
| 82 |
+
started = []
|
| 83 |
+
torn_down = []
|
| 84 |
+
|
| 85 |
+
class FakeToolRouter:
|
| 86 |
+
tools = {}
|
| 87 |
+
|
| 88 |
+
def get_tool_specs_for_llm(self):
|
| 89 |
+
return []
|
| 90 |
+
|
| 91 |
+
async def __aenter__(self):
|
| 92 |
+
return self
|
| 93 |
+
|
| 94 |
+
async def __aexit__(self, exc_type, exc, tb):
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
def fake_start_cpu_sandbox_preload(session):
|
| 98 |
+
started.append(session)
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
async def fake_teardown_session_sandbox(session):
|
| 102 |
+
torn_down.append(session)
|
| 103 |
+
|
| 104 |
+
monkeypatch.setattr(
|
| 105 |
+
agent_loop, "start_cpu_sandbox_preload", fake_start_cpu_sandbox_preload
|
| 106 |
+
)
|
| 107 |
+
monkeypatch.setattr(
|
| 108 |
+
agent_loop, "teardown_session_sandbox", fake_teardown_session_sandbox
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
submission_queue = asyncio.Queue()
|
| 112 |
+
event_queue = asyncio.Queue()
|
| 113 |
+
session_holder = [None]
|
| 114 |
+
config = Config.model_validate(
|
| 115 |
+
{"model_name": "openai/gpt-5.5", "save_sessions": False}
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
task = asyncio.create_task(
|
| 119 |
+
agent_loop.submission_loop(
|
| 120 |
+
submission_queue,
|
| 121 |
+
event_queue,
|
| 122 |
+
config=config,
|
| 123 |
+
tool_router=FakeToolRouter(),
|
| 124 |
+
session_holder=session_holder,
|
| 125 |
+
hf_token="hf-token",
|
| 126 |
+
user_id="tester",
|
| 127 |
+
local_mode=False,
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
ready = await asyncio.wait_for(event_queue.get(), timeout=1)
|
| 132 |
+
assert ready.event_type == "ready"
|
| 133 |
+
assert started == [session_holder[0]]
|
| 134 |
+
assert session_holder[0].local_mode is False
|
| 135 |
+
|
| 136 |
+
await submission_queue.put(
|
| 137 |
+
SimpleNamespace(
|
| 138 |
+
operation=SimpleNamespace(op_type=OpType.SHUTDOWN, data=None),
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
await asyncio.wait_for(task, timeout=1)
|
| 142 |
+
|
| 143 |
+
assert torn_down == [session_holder[0]]
|
|
@@ -91,6 +91,31 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
|
| 91 |
assert not any("sleep time" in log for log in logs)
|
| 92 |
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
| 95 |
runtime_calls = 0
|
| 96 |
|
|
@@ -395,6 +420,71 @@ def test_ensure_sandbox_overrides_private_argument(monkeypatch):
|
|
| 395 |
assert persisted[-1]["sandbox_status"] == "active"
|
| 396 |
|
| 397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
|
| 399 |
active_creates = 0
|
| 400 |
max_active_creates = 0
|
|
@@ -514,7 +604,7 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
|
|
| 514 |
space_id="alice/sandbox-cpu",
|
| 515 |
url="https://huggingface.co/spaces/alice/sandbox-cpu",
|
| 516 |
_owns_space=True,
|
| 517 |
-
delete=lambda: deleted.append("alice/sandbox-cpu"),
|
| 518 |
)
|
| 519 |
self.sandbox_hardware = "cpu-basic"
|
| 520 |
self.sandbox_preload_task = None
|
|
@@ -559,10 +649,11 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
|
|
| 559 |
|
| 560 |
def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
| 561 |
deleted: list[str] = []
|
|
|
|
| 562 |
persisted: list[dict] = []
|
| 563 |
|
| 564 |
-
async def fake_record_sandbox_destroy(*args, **kwargs):
|
| 565 |
-
|
| 566 |
|
| 567 |
monkeypatch.setattr(
|
| 568 |
telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
|
|
@@ -570,20 +661,28 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 570 |
|
| 571 |
async def run():
|
| 572 |
cancel_event = threading.Event()
|
|
|
|
| 573 |
|
| 574 |
async def preload():
|
| 575 |
await asyncio.sleep(0)
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
session = SimpleNamespace(
|
| 578 |
session_id="s1",
|
| 579 |
sandbox=SimpleNamespace(
|
| 580 |
space_id="alice/sandbox-12345678",
|
| 581 |
_owns_space=True,
|
| 582 |
-
delete=
|
| 583 |
),
|
| 584 |
sandbox_hardware="cpu-basic",
|
| 585 |
sandbox_preload_task=asyncio.create_task(preload()),
|
| 586 |
sandbox_preload_cancel_event=cancel_event,
|
|
|
|
| 587 |
persistence_store=SimpleNamespace(
|
| 588 |
update_session_fields=lambda session_id, **fields: _record_metadata(
|
| 589 |
session_id, fields
|
|
@@ -592,17 +691,33 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 592 |
)
|
| 593 |
|
| 594 |
await sandbox_tool.teardown_session_sandbox(session)
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
async def _record_metadata(session_id, fields):
|
| 598 |
persisted.append({"session_id": session_id, **fields})
|
| 599 |
|
| 600 |
-
session, cancel_event = asyncio.run(run())
|
| 601 |
|
| 602 |
assert cancel_event.is_set()
|
| 603 |
assert deleted == ["alice/sandbox-12345678"]
|
|
|
|
| 604 |
assert session.sandbox is None
|
| 605 |
assert session.sandbox_hardware is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
assert persisted[-1]["session_id"] == "s1"
|
| 607 |
assert persisted[-1]["sandbox_space_id"] is None
|
| 608 |
assert persisted[-1]["sandbox_status"] == "destroyed"
|
|
|
|
| 91 |
assert not any("sleep time" in log for log in logs)
|
| 92 |
|
| 93 |
|
| 94 |
+
def test_sandbox_delete_uses_log_callback_without_stdout(monkeypatch, capsys):
|
| 95 |
+
deleted: list[tuple[str, str]] = []
|
| 96 |
+
|
| 97 |
+
class FakeApi:
|
| 98 |
+
def __init__(self, token=None):
|
| 99 |
+
self.token = token
|
| 100 |
+
|
| 101 |
+
def delete_repo(self, repo_id, repo_type):
|
| 102 |
+
deleted.append((repo_id, repo_type))
|
| 103 |
+
|
| 104 |
+
monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
|
| 105 |
+
|
| 106 |
+
sandbox = Sandbox("alice/sandbox-12345678", token="hf-token", _owns_space=True)
|
| 107 |
+
logs: list[str] = []
|
| 108 |
+
|
| 109 |
+
sandbox.delete(log=logs.append)
|
| 110 |
+
|
| 111 |
+
captured = capsys.readouterr()
|
| 112 |
+
assert captured.out == ""
|
| 113 |
+
assert captured.err == ""
|
| 114 |
+
assert deleted == [("alice/sandbox-12345678", "space")]
|
| 115 |
+
assert logs == ["Deleting sandbox: alice/sandbox-12345678...", "Deleted."]
|
| 116 |
+
assert sandbox._owns_space is False
|
| 117 |
+
|
| 118 |
+
|
| 119 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
| 120 |
runtime_calls = 0
|
| 121 |
|
|
|
|
| 420 |
assert persisted[-1]["sandbox_status"] == "active"
|
| 421 |
|
| 422 |
|
| 423 |
+
def test_cancelled_sandbox_creation_logs_delete_through_tool_log(monkeypatch):
|
| 424 |
+
deleted: list[str] = []
|
| 425 |
+
|
| 426 |
+
class FakeSession:
|
| 427 |
+
def __init__(self):
|
| 428 |
+
self.hf_token = "hf-token"
|
| 429 |
+
self.sandbox = None
|
| 430 |
+
self.event_queue = asyncio.Queue()
|
| 431 |
+
self._cancelled = asyncio.Event()
|
| 432 |
+
|
| 433 |
+
async def send_event(self, event):
|
| 434 |
+
await self.event_queue.put(event)
|
| 435 |
+
|
| 436 |
+
def fake_create(**kwargs):
|
| 437 |
+
def delete(log=None):
|
| 438 |
+
deleted.append("alice/sandbox-12345678")
|
| 439 |
+
if log:
|
| 440 |
+
log("Deleting sandbox: alice/sandbox-12345678...")
|
| 441 |
+
log("Deleted.")
|
| 442 |
+
|
| 443 |
+
return SimpleNamespace(
|
| 444 |
+
space_id="alice/sandbox-12345678",
|
| 445 |
+
url="https://huggingface.co/spaces/alice/sandbox-12345678",
|
| 446 |
+
_owns_space=True,
|
| 447 |
+
delete=delete,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
monkeypatch.setattr(Sandbox, "create", staticmethod(fake_create))
|
| 451 |
+
|
| 452 |
+
async def run():
|
| 453 |
+
session = FakeSession()
|
| 454 |
+
cancel_event = threading.Event()
|
| 455 |
+
cancel_event.set()
|
| 456 |
+
|
| 457 |
+
sb, error = await sandbox_tool._create_sandbox_locked(
|
| 458 |
+
session,
|
| 459 |
+
api=SimpleNamespace(),
|
| 460 |
+
owner="alice",
|
| 461 |
+
hardware="cpu-basic",
|
| 462 |
+
cancel_event=cancel_event,
|
| 463 |
+
)
|
| 464 |
+
await asyncio.sleep(0)
|
| 465 |
+
events = []
|
| 466 |
+
while not session.event_queue.empty():
|
| 467 |
+
events.append(await session.event_queue.get())
|
| 468 |
+
return sb, error, events
|
| 469 |
+
|
| 470 |
+
sb, error, events = asyncio.run(run())
|
| 471 |
+
|
| 472 |
+
assert sb is None
|
| 473 |
+
assert error == "Sandbox creation cancelled by user."
|
| 474 |
+
assert deleted == ["alice/sandbox-12345678"]
|
| 475 |
+
assert [
|
| 476 |
+
event.data
|
| 477 |
+
for event in events
|
| 478 |
+
if event.event_type == "tool_log"
|
| 479 |
+
and event.data
|
| 480 |
+
and event.data.get("log")
|
| 481 |
+
in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
|
| 482 |
+
] == [
|
| 483 |
+
{"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
|
| 484 |
+
{"tool": "sandbox", "log": "Deleted."},
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
|
| 488 |
def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
|
| 489 |
active_creates = 0
|
| 490 |
max_active_creates = 0
|
|
|
|
| 604 |
space_id="alice/sandbox-cpu",
|
| 605 |
url="https://huggingface.co/spaces/alice/sandbox-cpu",
|
| 606 |
_owns_space=True,
|
| 607 |
+
delete=lambda log=None: deleted.append("alice/sandbox-cpu"),
|
| 608 |
)
|
| 609 |
self.sandbox_hardware = "cpu-basic"
|
| 610 |
self.sandbox_preload_task = None
|
|
|
|
| 649 |
|
| 650 |
def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
| 651 |
deleted: list[str] = []
|
| 652 |
+
destroyed: list[str] = []
|
| 653 |
persisted: list[dict] = []
|
| 654 |
|
| 655 |
+
async def fake_record_sandbox_destroy(session, sandbox, *args, **kwargs):
|
| 656 |
+
destroyed.append(sandbox.space_id)
|
| 657 |
|
| 658 |
monkeypatch.setattr(
|
| 659 |
telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
|
|
|
|
| 661 |
|
| 662 |
async def run():
|
| 663 |
cancel_event = threading.Event()
|
| 664 |
+
event_queue = asyncio.Queue()
|
| 665 |
|
| 666 |
async def preload():
|
| 667 |
await asyncio.sleep(0)
|
| 668 |
|
| 669 |
+
def delete(log=None):
|
| 670 |
+
deleted.append("alice/sandbox-12345678")
|
| 671 |
+
if log:
|
| 672 |
+
log("Deleting sandbox: alice/sandbox-12345678...")
|
| 673 |
+
log("Deleted.")
|
| 674 |
+
|
| 675 |
session = SimpleNamespace(
|
| 676 |
session_id="s1",
|
| 677 |
sandbox=SimpleNamespace(
|
| 678 |
space_id="alice/sandbox-12345678",
|
| 679 |
_owns_space=True,
|
| 680 |
+
delete=delete,
|
| 681 |
),
|
| 682 |
sandbox_hardware="cpu-basic",
|
| 683 |
sandbox_preload_task=asyncio.create_task(preload()),
|
| 684 |
sandbox_preload_cancel_event=cancel_event,
|
| 685 |
+
event_queue=event_queue,
|
| 686 |
persistence_store=SimpleNamespace(
|
| 687 |
update_session_fields=lambda session_id, **fields: _record_metadata(
|
| 688 |
session_id, fields
|
|
|
|
| 691 |
)
|
| 692 |
|
| 693 |
await sandbox_tool.teardown_session_sandbox(session)
|
| 694 |
+
await asyncio.sleep(0)
|
| 695 |
+
events = []
|
| 696 |
+
while not event_queue.empty():
|
| 697 |
+
events.append(await event_queue.get())
|
| 698 |
+
return session, cancel_event, events
|
| 699 |
|
| 700 |
async def _record_metadata(session_id, fields):
|
| 701 |
persisted.append({"session_id": session_id, **fields})
|
| 702 |
|
| 703 |
+
session, cancel_event, events = asyncio.run(run())
|
| 704 |
|
| 705 |
assert cancel_event.is_set()
|
| 706 |
assert deleted == ["alice/sandbox-12345678"]
|
| 707 |
+
assert destroyed == ["alice/sandbox-12345678"]
|
| 708 |
assert session.sandbox is None
|
| 709 |
assert session.sandbox_hardware is None
|
| 710 |
+
assert [
|
| 711 |
+
event.data
|
| 712 |
+
for event in events
|
| 713 |
+
if event.event_type == "tool_log"
|
| 714 |
+
and event.data
|
| 715 |
+
and event.data.get("log")
|
| 716 |
+
in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
|
| 717 |
+
] == [
|
| 718 |
+
{"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
|
| 719 |
+
{"tool": "sandbox", "log": "Deleted."},
|
| 720 |
+
]
|
| 721 |
assert persisted[-1]["session_id"] == "s1"
|
| 722 |
assert persisted[-1]["sandbox_space_id"] is None
|
| 723 |
assert persisted[-1]["sandbox_status"] == "destroyed"
|
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
|
| 6 |
+
from agent.tools.sandbox_tool import resolve_sandbox_script
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FakeSandbox:
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.read_paths = []
|
| 12 |
+
|
| 13 |
+
def read(self, path, *, limit):
|
| 14 |
+
self.read_paths.append((path, limit))
|
| 15 |
+
return SimpleNamespace(
|
| 16 |
+
success=True,
|
| 17 |
+
output="1\tprint('training')\n2\tprint('done')",
|
| 18 |
+
error="",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.mark.asyncio
|
| 23 |
+
async def test_resolve_sandbox_script_accepts_bare_python_filename():
|
| 24 |
+
sandbox = FakeSandbox()
|
| 25 |
+
|
| 26 |
+
content, error = await resolve_sandbox_script(sandbox, "train_smollm2.py")
|
| 27 |
+
|
| 28 |
+
assert error is None
|
| 29 |
+
assert content == "print('training')\nprint('done')"
|
| 30 |
+
assert sandbox.read_paths == [("train_smollm2.py", 100_000)]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.mark.asyncio
|
| 34 |
+
async def test_resolve_sandbox_script_accepts_relative_python_path():
|
| 35 |
+
sandbox = FakeSandbox()
|
| 36 |
+
|
| 37 |
+
content, error = await resolve_sandbox_script(sandbox, "scripts/train.py")
|
| 38 |
+
|
| 39 |
+
assert error is None
|
| 40 |
+
assert content == "print('training')\nprint('done')"
|
| 41 |
+
assert sandbox.read_paths == [("scripts/train.py", 100_000)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.mark.asyncio
|
| 45 |
+
@pytest.mark.parametrize(
|
| 46 |
+
"script",
|
| 47 |
+
[
|
| 48 |
+
"https://example.com/train.py",
|
| 49 |
+
"http://example.com/train.py",
|
| 50 |
+
"train_smollm2.py --epochs 1",
|
| 51 |
+
"print('hello')",
|
| 52 |
+
],
|
| 53 |
+
)
|
| 54 |
+
async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
|
| 55 |
+
sandbox = FakeSandbox()
|
| 56 |
+
|
| 57 |
+
content, error = await resolve_sandbox_script(sandbox, script)
|
| 58 |
+
|
| 59 |
+
assert content is None
|
| 60 |
+
assert error is None
|
| 61 |
+
assert sandbox.read_paths == []
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_hf_jobs_script_description_mentions_bare_python_filenames():
|
| 65 |
+
script_description = HF_JOBS_TOOL_SPEC["parameters"]["properties"]["script"][
|
| 66 |
+
"description"
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
assert "bare 'train.py'" in script_description
|
| 70 |
+
assert "smoke-test in a GPU sandbox before submission" in script_description
|
|
@@ -207,7 +207,7 @@ async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 207 |
session.sandbox = SimpleNamespace(
|
| 208 |
space_id="owner/sandbox-12345678",
|
| 209 |
_owns_space=True,
|
| 210 |
-
delete=lambda: deleted.append("owner/sandbox-12345678"),
|
| 211 |
)
|
| 212 |
session.sandbox_hardware = "cpu-basic"
|
| 213 |
session.sandbox_preload_cancel_event = preload_cancel_event
|
|
|
|
| 207 |
session.sandbox = SimpleNamespace(
|
| 208 |
space_id="owner/sandbox-12345678",
|
| 209 |
_owns_space=True,
|
| 210 |
+
delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
|
| 211 |
)
|
| 212 |
session.sandbox_hardware = "cpu-basic"
|
| 213 |
session.sandbox_preload_cancel_event = preload_cancel_event
|