lewtun HF Staff OpenAI Codex commited on
Commit
b05b6f5
·
unverified ·
1 Parent(s): 11112c6

Add CLI sandbox runtime and fix HF Jobs script paths (#237)

Browse files

* Resolve bare Python job scripts

Co-authored-by: OpenAI Codex <codex@openai.com>

* Update HF Jobs script description

Co-authored-by: OpenAI Codex <codex@openai.com>

* Clarify HF Jobs script path prompt

Co-authored-by: OpenAI Codex <codex@openai.com>

* Add CLI sandbox tool runtime

Co-authored-by: OpenAI Codex <codex@openai.com>

* Document CLI sandbox tool runtime

Co-authored-by: OpenAI Codex <codex@openai.com>

* Wait for initial CLI sandbox preload

Co-authored-by: OpenAI Codex <codex@openai.com>

* Strengthen GPU sandbox preflight guidance

Co-authored-by: OpenAI Codex <codex@openai.com>

* Guard against text-only stops with unfinished plans

Co-authored-by: OpenAI Codex <codex@openai.com>

* Route sandbox deletion logs through tool events

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

README.md CHANGED
@@ -63,6 +63,7 @@ ml-intern --model anthropic/claude-opus-4-7 "your prompt" # requires ANTHROPIC
63
  ml-intern --model openai/gpt-5.5 "your prompt" # requires OPENAI_API_KEY
64
  ml-intern --model ollama/llama3.1:8b "your prompt"
65
  ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
 
66
  ml-intern --max-iterations 100 "your prompt"
67
  ml-intern --no-stream "your prompt"
68
  ```
@@ -97,6 +98,30 @@ one shared local endpoint, or override a specific provider with its matching
97
  `VLLM_API_KEY`. Provider-specific variables take precedence over the shared
98
  local variables. Base URLs may include or omit `/v1`.
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ## Sharing Traces
101
 
102
  Every session is auto-uploaded to your **own private Hugging Face dataset**
 
63
  ml-intern --model openai/gpt-5.5 "your prompt" # requires OPENAI_API_KEY
64
  ml-intern --model ollama/llama3.1:8b "your prompt"
65
  ml-intern --model vllm/meta-llama/Llama-3.1-8B-Instruct "your prompt"
66
+ ml-intern --sandbox-tools "your prompt" # use HF Space sandbox tools
67
  ml-intern --max-iterations 100 "your prompt"
68
  ml-intern --no-stream "your prompt"
69
  ```
 
98
  `VLLM_API_KEY`. Provider-specific variables take precedence over the shared
99
  local variables. Base URLs may include or omit `/v1`.
100
 
101
+ **CLI tool runtime:**
102
+
103
+ By default, the CLI runs `bash`, `read`, `write`, and `edit` on your local
104
+ filesystem. To use HF Space sandbox tools instead, including `sandbox_create`,
105
+ opt in with `--sandbox-tools`:
106
+
107
+ ```bash
108
+ ml-intern --sandbox-tools "test this training script in a GPU sandbox"
109
+ ml-intern --model llamacpp/ggml-org/gemma-3-1b-it-GGUF --sandbox-tools
110
+ ```
111
+
112
+ Sandbox tool runtime requires `HF_TOKEN`, even when the selected model is local,
113
+ because it creates private HF Spaces. You can also make sandbox tools your CLI
114
+ default in `~/.config/ml-intern/cli_agent_config.json`:
115
+
116
+ ```json
117
+ { "tool_runtime": "sandbox" }
118
+ ```
119
+
120
+ Use the default local runtime when you want tools to inspect or edit files in
121
+ your checkout. Use sandbox runtime when you want the agent to create or replace
122
+ an HF Space sandbox, test code remotely, or request GPU sandbox hardware before
123
+ launching larger HF Jobs.
124
+
125
  ## Sharing Traces
126
 
127
  Every session is auto-uploaded to your **own private Hugging Face dataset**
agent/config.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  import os
3
  import re
4
  from pathlib import Path
5
- from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
8
  from fastmcp.mcp_config import (
@@ -46,6 +46,7 @@ class Config(BaseModel):
46
  # Permission control parameters
47
  confirm_cpu_jobs: bool = True
48
  auto_file_upload: bool = False
 
49
 
50
  # Reasoning effort *preference* — the ceiling the user wants. The probe
51
  # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
 
2
  import os
3
  import re
4
  from pathlib import Path
5
+ from typing import Any, Literal, Union
6
 
7
  from dotenv import load_dotenv
8
  from fastmcp.mcp_config import (
 
46
  # Permission control parameters
47
  confirm_cpu_jobs: bool = True
48
  auto_file_upload: bool = False
49
+ tool_runtime: Literal["local", "sandbox"] = "local"
50
 
51
  # Reasoning effort *preference* — the ceiling the user wants. The probe
52
  # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
agent/core/agent_loop.py CHANGED
@@ -32,7 +32,11 @@ from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
33
  from agent.core.tools import ToolRouter
34
  from agent.tools.jobs_tool import CPU_FLAVORS
35
- from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
 
 
 
 
36
 
37
  logger = logging.getLogger(__name__)
38
 
@@ -40,6 +44,43 @@ ToolCall = ChatCompletionMessageToolCall
40
 
41
  _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
42
  _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  def _malformed_tool_name(message: Message) -> str | None:
@@ -1153,6 +1194,7 @@ class Handlers:
1153
  final_response = None
1154
  errored = False
1155
  max_iterations = session.config.max_iterations
 
1156
 
1157
  while max_iterations == -1 or iteration < max_iterations:
1158
  # ── Cancellation check: before LLM call ──
@@ -1301,6 +1343,51 @@ class Handlers:
1301
 
1302
  # If no tool calls, add assistant message and we're done
1303
  if not tool_calls:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1304
  logger.debug(
1305
  "Agent loop ending: no tool calls. "
1306
  "finish_reason=%s, token_count=%d, "
@@ -1324,6 +1411,8 @@ class Handlers:
1324
  final_response = content
1325
  break
1326
 
 
 
1327
  # Validate tool call args (one json.loads per call, once)
1328
  # and split into good vs bad
1329
  good_tools: list[tuple[ToolCall, str, dict]] = []
@@ -1940,6 +2029,8 @@ class Handlers:
1940
  _ = session.save_and_upload_detached(repo_id)
1941
 
1942
  session.is_running = False
 
 
1943
  await session.send_event(Event(event_type="shutdown"))
1944
  return True
1945
 
@@ -2023,6 +2114,8 @@ async def submission_loop(
2023
  )
2024
  if session_holder is not None:
2025
  session_holder[0] = session
 
 
2026
  logger.info("Agent loop started")
2027
 
2028
  # Retry any failed uploads from previous sessions (fire-and-forget).
 
32
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
33
  from agent.core.tools import ToolRouter
34
  from agent.tools.jobs_tool import CPU_FLAVORS
35
+ from agent.tools.sandbox_tool import (
36
+ DEFAULT_CPU_SANDBOX_HARDWARE,
37
+ start_cpu_sandbox_preload,
38
+ teardown_session_sandbox,
39
+ )
40
 
41
  logger = logging.getLogger(__name__)
42
 
 
44
 
45
  _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
46
  _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
47
+ _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2
48
+
49
+
50
+ def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
51
+ plan = getattr(session, "current_plan", None) or []
52
+ unfinished: list[dict[str, str]] = []
53
+ for item in plan:
54
+ if not isinstance(item, dict):
55
+ continue
56
+ status = item.get("status")
57
+ if status in {"pending", "in_progress"}:
58
+ unfinished.append(item)
59
+ return unfinished
60
+
61
+
62
+ def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
63
+ formatted = []
64
+ for item in items[:limit]:
65
+ item_id = item.get("id") or "?"
66
+ content = item.get("content") or "(unnamed task)"
67
+ status = item.get("status") or "unknown"
68
+ formatted.append(f"{item_id}. {content} [{status}]")
69
+ if len(items) > limit:
70
+ formatted.append(f"... and {len(items) - limit} more")
71
+ return "; ".join(formatted)
72
+
73
+
74
+ def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
75
+ summary = _format_plan_items_for_guard(items)
76
+ return (
77
+ "[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
78
+ "tool calls, but the task is not complete. The current plan still has "
79
+ f"unfinished items: {summary}. Do not return control to the user yet. "
80
+ "Continue from the next unfinished item and make at least one tool call "
81
+ "now. If you genuinely cannot continue, first use tools to inspect the "
82
+ "state or verify the blocker."
83
+ )
84
 
85
 
86
  def _malformed_tool_name(message: Message) -> str | None:
 
1194
  final_response = None
1195
  errored = False
1196
  max_iterations = session.config.max_iterations
1197
+ no_tool_incomplete_plan_retries = 0
1198
 
1199
  while max_iterations == -1 or iteration < max_iterations:
1200
  # ── Cancellation check: before LLM call ──
 
1343
 
1344
  # If no tool calls, add assistant message and we're done
1345
  if not tool_calls:
1346
+ unfinished_plan = _unfinished_plan_items(session)
1347
+ if (
1348
+ unfinished_plan
1349
+ and no_tool_incomplete_plan_retries
1350
+ < _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
1351
+ ):
1352
+ logger.info(
1353
+ "No tool calls with unfinished plan; retrying agent turn "
1354
+ "(attempt %d/%d)",
1355
+ no_tool_incomplete_plan_retries + 1,
1356
+ _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
1357
+ )
1358
+ if content:
1359
+ assistant_msg = _assistant_message_from_result(
1360
+ llm_result,
1361
+ model_name=llm_params.get("model"),
1362
+ )
1363
+ session.context_manager.add_message(
1364
+ assistant_msg, token_count
1365
+ )
1366
+ session.context_manager.add_message(
1367
+ Message(
1368
+ role="user",
1369
+ content=_no_tool_incomplete_plan_prompt(
1370
+ unfinished_plan
1371
+ ),
1372
+ )
1373
+ )
1374
+ no_tool_incomplete_plan_retries += 1
1375
+ await session.send_event(
1376
+ Event(
1377
+ event_type="tool_log",
1378
+ data={
1379
+ "tool": "system",
1380
+ "log": (
1381
+ "Plan still has unfinished items after a "
1382
+ "text-only response — retrying instead of "
1383
+ "returning to the prompt."
1384
+ ),
1385
+ },
1386
+ )
1387
+ )
1388
+ iteration += 1
1389
+ continue
1390
+
1391
  logger.debug(
1392
  "Agent loop ending: no tool calls. "
1393
  "finish_reason=%s, token_count=%d, "
 
1411
  final_response = content
1412
  break
1413
 
1414
+ no_tool_incomplete_plan_retries = 0
1415
+
1416
  # Validate tool call args (one json.loads per call, once)
1417
  # and split into good vs bad
1418
  good_tools: list[tuple[ToolCall, str, dict]] = []
 
2029
  _ = session.save_and_upload_detached(repo_id)
2030
 
2031
  session.is_running = False
2032
+ if not getattr(session, "local_mode", False):
2033
+ await teardown_session_sandbox(session)
2034
  await session.send_event(Event(event_type="shutdown"))
2035
  return True
2036
 
 
2114
  )
2115
  if session_holder is not None:
2116
  session_holder[0] = session
2117
+ if not local_mode:
2118
+ start_cpu_sandbox_preload(session)
2119
  logger.info("Agent loop started")
2120
 
2121
  # Retry any failed uploads from previous sessions (fire-and-forget).
agent/core/session.py CHANGED
@@ -99,6 +99,7 @@ class Session:
99
  self.hf_token: Optional[str] = hf_token
100
  self.user_id: Optional[str] = user_id
101
  self.hf_username: Optional[str] = hf_username
 
102
  self.persistence_store = persistence_store
103
  self.tool_router = tool_router
104
  self.stream = stream
@@ -117,6 +118,7 @@ class Session:
117
  self.session_id = session_id or str(uuid.uuid4())
118
  self.config = config
119
  self.is_running = True
 
120
  self._cancelled = asyncio.Event()
121
  self.pending_approval: Optional[dict[str, Any]] = None
122
  self.sandbox = None
 
99
  self.hf_token: Optional[str] = hf_token
100
  self.user_id: Optional[str] = user_id
101
  self.hf_username: Optional[str] = hf_username
102
+ self.local_mode = local_mode
103
  self.persistence_store = persistence_store
104
  self.tool_router = tool_router
105
  self.stream = stream
 
118
  self.session_id = session_id or str(uuid.uuid4())
119
  self.config = config
120
  self.is_running = True
121
+ self.current_plan: list[dict[str, str]] = []
122
  self._cancelled = asyncio.Event()
123
  self.pending_approval: Optional[dict[str, Any]] = None
124
  self.sandbox = None
agent/main.py CHANGED
@@ -59,6 +59,34 @@ CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.j
59
  logger = logging.getLogger(__name__)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
63
  if tool_info.get("tool") != "hf_jobs":
64
  return False
@@ -957,6 +985,7 @@ async def _handle_slash_command(
957
  session = session_holder[0] if session_holder else None
958
  print(f"Model: {config.model_name}")
959
  print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
 
960
  if session:
961
  print(f"Turns: {session.turn_count}")
962
  print(f"Context items: {len(session.context_manager.items)}")
@@ -1076,7 +1105,7 @@ async def _handle_share_traces_command(arg: str, config, session) -> None:
1076
  console.print(f"[green]Dataset is now {label}.[/green] {url}")
1077
 
1078
 
1079
- async def main(model: str | None = None):
1080
  """Interactive chat with the agent"""
1081
 
1082
  # Clear screen
@@ -1088,16 +1117,23 @@ async def main(model: str | None = None):
1088
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1089
  if model:
1090
  config.model_name = model
 
 
1091
 
1092
- # HF token — required for Hub-backed models/tools, but not for local LLMs.
 
1093
  hf_token = resolve_hf_token()
1094
- if not hf_token and not is_local_model_id(config.model_name):
1095
  hf_token = await _prompt_and_save_hf_token(prompt_session)
1096
 
1097
  # Resolve username for banner
1098
  hf_user = _get_hf_user(hf_token)
1099
 
1100
- print_banner(model=config.model_name, hf_user=hf_user)
 
 
 
 
1101
 
1102
  # Pre-warm the HF router catalog in the background so /model switches
1103
  # don't block on a network fetch.
@@ -1116,8 +1152,10 @@ async def main(model: str | None = None):
1116
 
1117
  notification_gateway = NotificationGateway(config.messaging)
1118
  await notification_gateway.start()
1119
- # Create tool router with local mode
1120
- tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
 
 
1121
 
1122
  # Session holder for interrupt/model/status access
1123
  session_holder = [None]
@@ -1131,7 +1169,7 @@ async def main(model: str | None = None):
1131
  session_holder=session_holder,
1132
  hf_token=hf_token,
1133
  user_id=hf_user,
1134
- local_mode=True,
1135
  stream=True,
1136
  notification_gateway=notification_gateway,
1137
  notification_destinations=config.messaging.default_auto_destinations(),
@@ -1153,6 +1191,8 @@ async def main(model: str | None = None):
1153
  )
1154
 
1155
  await ready_event.wait()
 
 
1156
 
1157
  submission_id = [0]
1158
  # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
@@ -1310,6 +1350,7 @@ async def headless_main(
1310
  model: str | None = None,
1311
  max_iterations: int | None = None,
1312
  stream: bool = True,
 
1313
  ) -> None:
1314
  """Run a single prompt headlessly and exit."""
1315
  import logging
@@ -1322,11 +1363,13 @@ async def headless_main(
1322
 
1323
  if model:
1324
  config.model_name = model
 
 
1325
 
1326
  hf_token = resolve_hf_token()
1327
- if not hf_token and not is_local_model_id(config.model_name):
1328
  print(
1329
- "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.",
1330
  file=sys.stderr,
1331
  )
1332
  sys.exit(1)
@@ -1342,6 +1385,7 @@ async def headless_main(
1342
  config.max_iterations = max_iterations
1343
 
1344
  print(f"Model: {config.model_name}", file=sys.stderr)
 
1345
  print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
1346
  print(f"Prompt: {prompt}", file=sys.stderr)
1347
  print("---", file=sys.stderr)
@@ -1349,7 +1393,9 @@ async def headless_main(
1349
  submission_queue: asyncio.Queue = asyncio.Queue()
1350
  event_queue: asyncio.Queue = asyncio.Queue()
1351
 
1352
- tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
 
 
1353
  session_holder: list = [None]
1354
 
1355
  agent_task = asyncio.create_task(
@@ -1361,7 +1407,7 @@ async def headless_main(
1361
  session_holder=session_holder,
1362
  hf_token=hf_token,
1363
  user_id=hf_user,
1364
- local_mode=True,
1365
  stream=stream,
1366
  notification_gateway=notification_gateway,
1367
  notification_destinations=config.messaging.default_auto_destinations(),
@@ -1556,6 +1602,11 @@ def cli():
1556
  action="store_true",
1557
  help="Disable token streaming (use non-streaming LLM calls)",
1558
  )
 
 
 
 
 
1559
  args = parser.parse_args()
1560
 
1561
  try:
@@ -1569,10 +1620,11 @@ def cli():
1569
  model=args.model,
1570
  max_iterations=max_iter,
1571
  stream=not args.no_stream,
 
1572
  )
1573
  )
1574
  else:
1575
- asyncio.run(main(model=args.model))
1576
  except KeyboardInterrupt:
1577
  print("\n\nGoodbye!")
1578
 
 
59
  logger = logging.getLogger(__name__)
60
 
61
 
62
+ def _apply_tool_runtime_override(config: Any, *, sandbox_tools: bool) -> str:
63
+ if sandbox_tools:
64
+ config.tool_runtime = "sandbox"
65
+ return getattr(config, "tool_runtime", "local")
66
+
67
+
68
+ def _is_local_tool_runtime(config: Any) -> bool:
69
+ return getattr(config, "tool_runtime", "local") == "local"
70
+
71
+
72
+ def _tool_runtime_label(local_mode: bool) -> str:
73
+ return "local filesystem" if local_mode else "HF sandbox"
74
+
75
+
76
+ async def _wait_for_initial_sandbox_preload(session_holder: list | None) -> None:
77
+ session = session_holder[0] if session_holder else None
78
+ task = getattr(session, "sandbox_preload_task", None)
79
+ if not task:
80
+ return
81
+ try:
82
+ await asyncio.shield(task)
83
+ except asyncio.CancelledError:
84
+ raise
85
+ except Exception:
86
+ # The sandbox tool will surface the stored preload error on first use.
87
+ return
88
+
89
+
90
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
91
  if tool_info.get("tool") != "hf_jobs":
92
  return False
 
985
  session = session_holder[0] if session_holder else None
986
  print(f"Model: {config.model_name}")
987
  print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
988
+ print(f"Tool runtime: {_tool_runtime_label(_is_local_tool_runtime(config))}")
989
  if session:
990
  print(f"Turns: {session.turn_count}")
991
  print(f"Context items: {len(session.context_manager.items)}")
 
1105
  console.print(f"[green]Dataset is now {label}.[/green] {url}")
1106
 
1107
 
1108
+ async def main(model: str | None = None, sandbox_tools: bool = False):
1109
  """Interactive chat with the agent"""
1110
 
1111
  # Clear screen
 
1117
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1118
  if model:
1119
  config.model_name = model
1120
+ _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
1121
+ local_mode = _is_local_tool_runtime(config)
1122
 
1123
+ # HF token — required for Hub-backed models/tools and sandbox tools, but
1124
+ # not for local LLMs using only local filesystem tools.
1125
  hf_token = resolve_hf_token()
1126
+ if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
1127
  hf_token = await _prompt_and_save_hf_token(prompt_session)
1128
 
1129
  # Resolve username for banner
1130
  hf_user = _get_hf_user(hf_token)
1131
 
1132
+ print_banner(
1133
+ model=config.model_name,
1134
+ hf_user=hf_user,
1135
+ tool_runtime=_tool_runtime_label(local_mode),
1136
+ )
1137
 
1138
  # Pre-warm the HF router catalog in the background so /model switches
1139
  # don't block on a network fetch.
 
1152
 
1153
  notification_gateway = NotificationGateway(config.messaging)
1154
  await notification_gateway.start()
1155
+ # Create tool router with the selected CLI tool runtime.
1156
+ tool_router = ToolRouter(
1157
+ config.mcpServers, hf_token=hf_token, local_mode=local_mode
1158
+ )
1159
 
1160
  # Session holder for interrupt/model/status access
1161
  session_holder = [None]
 
1169
  session_holder=session_holder,
1170
  hf_token=hf_token,
1171
  user_id=hf_user,
1172
+ local_mode=local_mode,
1173
  stream=True,
1174
  notification_gateway=notification_gateway,
1175
  notification_destinations=config.messaging.default_auto_destinations(),
 
1191
  )
1192
 
1193
  await ready_event.wait()
1194
+ if not local_mode:
1195
+ await _wait_for_initial_sandbox_preload(session_holder)
1196
 
1197
  submission_id = [0]
1198
  # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
 
1350
  model: str | None = None,
1351
  max_iterations: int | None = None,
1352
  stream: bool = True,
1353
+ sandbox_tools: bool = False,
1354
  ) -> None:
1355
  """Run a single prompt headlessly and exit."""
1356
  import logging
 
1363
 
1364
  if model:
1365
  config.model_name = model
1366
+ _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
1367
+ local_mode = _is_local_tool_runtime(config)
1368
 
1369
  hf_token = resolve_hf_token()
1370
+ if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
1371
  print(
1372
+ "ERROR: No HF token found. Set HF_TOKEN or run `hf auth login`.",
1373
  file=sys.stderr,
1374
  )
1375
  sys.exit(1)
 
1385
  config.max_iterations = max_iterations
1386
 
1387
  print(f"Model: {config.model_name}", file=sys.stderr)
1388
+ print(f"Tool runtime: {_tool_runtime_label(local_mode)}", file=sys.stderr)
1389
  print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
1390
  print(f"Prompt: {prompt}", file=sys.stderr)
1391
  print("---", file=sys.stderr)
 
1393
  submission_queue: asyncio.Queue = asyncio.Queue()
1394
  event_queue: asyncio.Queue = asyncio.Queue()
1395
 
1396
+ tool_router = ToolRouter(
1397
+ config.mcpServers, hf_token=hf_token, local_mode=local_mode
1398
+ )
1399
  session_holder: list = [None]
1400
 
1401
  agent_task = asyncio.create_task(
 
1407
  session_holder=session_holder,
1408
  hf_token=hf_token,
1409
  user_id=hf_user,
1410
+ local_mode=local_mode,
1411
  stream=stream,
1412
  notification_gateway=notification_gateway,
1413
  notification_destinations=config.messaging.default_auto_destinations(),
 
1602
  action="store_true",
1603
  help="Disable token streaming (use non-streaming LLM calls)",
1604
  )
1605
+ parser.add_argument(
1606
+ "--sandbox-tools",
1607
+ action="store_true",
1608
+ help="Use HF Space sandbox tools instead of local filesystem tools",
1609
+ )
1610
  args = parser.parse_args()
1611
 
1612
  try:
 
1620
  model=args.model,
1621
  max_iterations=max_iter,
1622
  stream=not args.no_stream,
1623
+ sandbox_tools=args.sandbox_tools,
1624
  )
1625
  )
1626
  else:
1627
+ asyncio.run(main(model=args.model, sandbox_tools=args.sandbox_tools))
1628
  except KeyboardInterrupt:
1629
  print("\n\nGoodbye!")
1630
 
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -102,9 +102,18 @@ system_prompt: |
102
 
103
  # When submitting a training job
104
 
 
 
 
 
 
 
 
 
105
  Before calling hf_jobs, output a pre-flight check:
106
  - Reference implementation: [which example you based this on]
107
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
 
108
  - push_to_hub=True and hub_model_id set
109
  - timeout: [value] (based on: [model size] on [hardware])
110
  - Trackio monitoring included and deploying metrics to a public Space
@@ -127,7 +136,7 @@ system_prompt: |
127
 
128
  Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
129
 
130
- Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
131
 
132
 
133
  # When a task has 3+ steps
 
102
 
103
  # When submitting a training job
104
 
105
+ Never pass a local machine path to hf_jobs.script, such as /Users/..., /home/..., /fsx/..., or a repo checkout path. HF Jobs runs in a fresh cloud environment where local files do not exist. For hf_jobs.script, use exactly one of:
106
+ - inline Python source code
107
+ - a file already written in the session sandbox, e.g. /app/train.py, ./train.py, or train.py
108
+ - a public/raw URL
109
+ If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
110
+
111
+ GPU preflight is mandatory before hf_jobs when the job will run on GPU, or when the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, or torch.compile. First create a GPU sandbox with sandbox_create (t4-small minimum; choose larger hardware when VRAM requires it), run a tiny smoke test there using the same imports, model-loading path, training entrypoint, and a tiny dataset/subset, then fix failures before submitting. If you skip GPU sandbox preflight, state why before calling hf_jobs.
112
+
113
  Before calling hf_jobs, output a pre-flight check:
114
  - Reference implementation: [which example you based this on]
115
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
116
+ - GPU sandbox smoke test: [hardware and result, or explicitly not applicable because ...]
117
  - push_to_hub=True and hub_model_id set
118
  - timeout: [value] (based on: [model size] on [hardware])
119
  - Trackio monitoring included and deploying metrics to a public Space
 
136
 
137
  Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
138
 
139
+ Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16/fp16, quantization, flash attention, torch.compile, or model loading. CPU sandboxes cannot test GPU code paths. If the available sandbox tiers cannot fit the full model path, test the largest useful smoke path, state what was not covered, and submit one HF job first.
140
 
141
 
142
  # When a task has 3+ steps
agent/tools/jobs_tool.py CHANGED
@@ -1112,6 +1112,9 @@ HF_JOBS_TOOL_SPEC = {
1112
  "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
1113
  "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
1114
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
 
 
 
1115
  "- Training config MUST include push_to_hub=True and hub_model_id. "
1116
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
1117
  "- Include trackio monitoring and provide the dashboard URL to the user. "
@@ -1157,8 +1160,9 @@ HF_JOBS_TOOL_SPEC = {
1157
  "script": {
1158
  "type": "string",
1159
  "description": (
1160
- "Python code or sandbox file path (e.g. '/app/train.py') or URL. "
1161
  "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
 
1162
  "Mutually exclusive with 'command'."
1163
  ),
1164
  },
 
1112
  "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
1113
  "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
1114
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
1115
+ "- If the job runs on GPU, or the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, "
1116
+ "or torch.compile, you MUST create a GPU sandbox with sandbox_create first, run a tiny smoke test there, "
1117
+ "and fix failures before submitting. If skipped, state why before calling hf_jobs.\n"
1118
  "- Training config MUST include push_to_hub=True and hub_model_id. "
1119
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
1120
  "- Include trackio monitoring and provide the dashboard URL to the user. "
 
1160
  "script": {
1161
  "type": "string",
1162
  "description": (
1163
+ "Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
1164
  "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
1165
+ "For GPU/model-loading training scripts, smoke-test in a GPU sandbox before submission. "
1166
  "Mutually exclusive with 'command'."
1167
  ),
1168
  },
agent/tools/plan_tool.py CHANGED
@@ -54,20 +54,24 @@ class PlanTool:
54
  "isError": True,
55
  }
56
 
57
- # Store the raw todos structure in memory
58
- _current_plan = todos
 
 
 
 
59
 
60
  # Emit plan update event if session is available
61
  if self.session:
62
  await self.session.send_event(
63
  Event(
64
  event_type="plan_update",
65
- data={"plan": todos},
66
  )
67
  )
68
 
69
  # Format only for display using terminal_display utility
70
- formatted_output = format_plan_tool_output(todos)
71
 
72
  return {
73
  "formatted": formatted_output,
 
54
  "isError": True,
55
  }
56
 
57
+ # Store a session-scoped copy so the runtime can tell whether a
58
+ # text-only model response is trying to stop while work remains.
59
+ stored_todos = [dict(todo) for todo in todos]
60
+ _current_plan = stored_todos
61
+ if self.session is not None:
62
+ self.session.current_plan = stored_todos
63
 
64
  # Emit plan update event if session is available
65
  if self.session:
66
  await self.session.send_event(
67
  Event(
68
  event_type="plan_update",
69
+ data={"plan": stored_todos},
70
  )
71
  )
72
 
73
  # Format only for display using terminal_display utility
74
+ formatted_output = format_plan_tool_output(stored_todos)
75
 
76
  return {
77
  "formatted": formatted_output,
agent/tools/sandbox_client.py CHANGED
@@ -776,21 +776,23 @@ class Sandbox:
776
  f"Last status: {last_status}, last error: {last_err}"
777
  )
778
 
779
- def delete(self):
780
  """Delete the Space. Only works if this Sandbox created it."""
781
  if not self._owns_space:
782
  raise RuntimeError(
783
  f"This Sandbox did not create {self.space_id}. "
784
  f"Use self._hf_api.delete_repo() directly if you're sure."
785
  )
786
- print(f"Deleting sandbox: {self.space_id}...")
 
787
  self._hf_api.delete_repo(self.space_id, repo_type="space")
788
  # Clear ownership so a second cleanup call (e.g. delete_session +
789
  # _run_session.finally both fire) early-returns instead of retrying
790
  # a 404 delete and emitting a spurious ERROR log.
791
  self._owns_space = False
792
  self._client.close()
793
- print("Deleted.")
 
794
 
795
  def pause(self):
796
  """Pause the Space (stops billing, preserves state)."""
 
776
  f"Last status: {last_status}, last error: {last_err}"
777
  )
778
 
779
+ def delete(self, log: Callable[[str], object] | None = None):
780
  """Delete the Space. Only works if this Sandbox created it."""
781
  if not self._owns_space:
782
  raise RuntimeError(
783
  f"This Sandbox did not create {self.space_id}. "
784
  f"Use self._hf_api.delete_repo() directly if you're sure."
785
  )
786
+ if log:
787
+ log(f"Deleting sandbox: {self.space_id}...")
788
  self._hf_api.delete_repo(self.space_id, repo_type="space")
789
  # Clear ownership so a second cleanup call (e.g. delete_session +
790
  # _run_session.finally both fire) early-returns instead of retrying
791
  # a 404 delete and emitting a spurious ERROR log.
792
  self._owns_space = False
793
  self._client.close()
794
+ if log:
795
+ log("Deleted.")
796
 
797
  def pause(self):
798
  """Pause the Space (stops billing, preserves state)."""
agent/tools/sandbox_tool.py CHANGED
@@ -16,6 +16,7 @@ import logging
16
  import re
17
  import threading
18
  import weakref
 
19
  from datetime import datetime, timedelta, timezone
20
  from typing import Any
21
 
@@ -58,17 +59,41 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
58
  return lock
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def _looks_like_path(script: str) -> bool:
62
  """Return True if the script string looks like a file path (not inline code)."""
63
- return (
64
  isinstance(script, str)
65
  and script.strip() == script
66
  and not any(c in script for c in "\r\n\0")
67
- and (
68
- script.startswith("/")
69
- or script.startswith("./")
70
- or script.startswith("../")
71
- )
 
 
 
 
 
 
72
  )
73
 
74
 
@@ -303,14 +328,8 @@ async def _create_sandbox_locked(
303
  )
304
  )
305
 
306
- # Thread-safe log callback: posts tool_log events from the worker thread
307
- loop = asyncio.get_running_loop()
308
-
309
- def _log(msg: str) -> None:
310
- loop.call_soon_threadsafe(
311
- session.event_queue.put_nowait,
312
- Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
313
- )
314
 
315
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
316
  # We poll session._cancelled from the main loop in a background task and set
@@ -352,7 +371,7 @@ async def _create_sandbox_locked(
352
  if cancel_flag.is_set():
353
  if getattr(sb, "_owns_space", False):
354
  try:
355
- await asyncio.to_thread(sb.delete)
356
  except Exception as e:
357
  logger.warning(
358
  "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
@@ -497,6 +516,7 @@ async def teardown_session_sandbox(session: Any) -> None:
497
  return
498
 
499
  space_id = getattr(sandbox, "space_id", None)
 
500
  last_err: Exception | None = None
501
  for attempt in range(3):
502
  try:
@@ -505,7 +525,7 @@ async def teardown_session_sandbox(session: Any) -> None:
505
  space_id,
506
  attempt + 1,
507
  )
508
- await asyncio.to_thread(sandbox.delete)
509
  from agent.core import telemetry
510
 
511
  await telemetry.record_sandbox_destroy(session, sandbox)
 
16
  import re
17
  import threading
18
  import weakref
19
+ from collections.abc import Callable
20
  from datetime import datetime, timedelta, timezone
21
  from typing import Any
22
 
 
59
  return lock
60
 
61
 
62
+ def _session_tool_logger(
63
+ session: Any, *, tool: str = "sandbox"
64
+ ) -> Callable[[str], object] | None:
65
+ event_queue = getattr(session, "event_queue", None)
66
+ if event_queue is None:
67
+ return None
68
+
69
+ loop = asyncio.get_running_loop()
70
+
71
+ def _log(msg: str) -> None:
72
+ loop.call_soon_threadsafe(
73
+ event_queue.put_nowait,
74
+ Event(event_type="tool_log", data={"tool": tool, "log": msg}),
75
+ )
76
+
77
+ return _log
78
+
79
+
80
  def _looks_like_path(script: str) -> bool:
81
  """Return True if the script string looks like a file path (not inline code)."""
82
+ if not (
83
  isinstance(script, str)
84
  and script.strip() == script
85
  and not any(c in script for c in "\r\n\0")
86
+ ):
87
+ return False
88
+
89
+ if script.startswith("http://") or script.startswith("https://"):
90
+ return False
91
+
92
+ return (
93
+ script.startswith("/")
94
+ or script.startswith("./")
95
+ or script.startswith("../")
96
+ or (script.endswith(".py") and not any(c.isspace() for c in script))
97
  )
98
 
99
 
 
328
  )
329
  )
330
 
331
+ # Thread-safe log callback: posts tool_log events from worker threads.
332
+ _log = _session_tool_logger(session) or (lambda msg: None)
 
 
 
 
 
 
333
 
334
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
335
  # We poll session._cancelled from the main loop in a background task and set
 
371
  if cancel_flag.is_set():
372
  if getattr(sb, "_owns_space", False):
373
  try:
374
+ await asyncio.to_thread(sb.delete, log=_log)
375
  except Exception as e:
376
  logger.warning(
377
  "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
 
516
  return
517
 
518
  space_id = getattr(sandbox, "space_id", None)
519
+ delete_log = _session_tool_logger(session)
520
  last_err: Exception | None = None
521
  for attempt in range(3):
522
  try:
 
525
  space_id,
526
  attempt + 1,
527
  )
528
+ await asyncio.to_thread(sandbox.delete, log=delete_log)
529
  from agent.core import telemetry
530
 
531
  await telemetry.record_sandbox_destroy(session, sandbox)
agent/utils/terminal_display.py CHANGED
@@ -93,7 +93,11 @@ def get_console() -> Console:
93
  # ── Banner ─────────────────────────────────────────────────────────────
94
 
95
 
96
- def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
 
 
 
 
97
  """Print particle logo then CRT boot sequence with system info."""
98
  from agent.utils.particle_logo import run_particle_logo
99
  from agent.utils.crt_boot import run_boot_sequence
@@ -116,6 +120,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
116
  (f"{_I}Initializing agent runtime...", gold),
117
  (f"{_I} User: {user_label}", dim_gold),
118
  (f"{_I} Model: {model_label}", dim_gold),
 
119
  (f"{_I} Tools: loading...", dim_gold),
120
  ("", ""),
121
  (f"{_I}/help for commands · /model to switch · /quit to exit", gold),
 
93
  # ── Banner ─────────────────────────────────────────────────────────────
94
 
95
 
96
+ def print_banner(
97
+ model: str | None = None,
98
+ hf_user: str | None = None,
99
+ tool_runtime: str | None = None,
100
+ ) -> None:
101
  """Print particle logo then CRT boot sequence with system info."""
102
  from agent.utils.particle_logo import run_particle_logo
103
  from agent.utils.crt_boot import run_boot_sequence
 
120
  (f"{_I}Initializing agent runtime...", gold),
121
  (f"{_I} User: {user_label}", dim_gold),
122
  (f"{_I} Model: {model_label}", dim_gold),
123
+ (f"{_I} Tool runtime: {tool_runtime or 'local filesystem'}", dim_gold),
124
  (f"{_I} Tools: loading...", dim_gold),
125
  ("", ""),
126
  (f"{_I}/help for commands · /model to switch · /quit to exit", gold),
configs/cli_agent_config.json CHANGED
@@ -7,6 +7,7 @@
7
  "yolo_mode": false,
8
  "confirm_cpu_jobs": true,
9
  "auto_file_upload": true,
 
10
  "messaging": {
11
  "enabled": false,
12
  "auto_event_types": ["approval_required", "error", "turn_complete"],
 
7
  "yolo_mode": false,
8
  "confirm_cpu_jobs": true,
9
  "auto_file_upload": true,
10
+ "tool_runtime": "local",
11
  "messaging": {
12
  "enabled": false,
13
  "auto_event_types": ["approval_required", "error", "turn_complete"],
tests/unit/test_cli_rendering.py CHANGED
@@ -1,5 +1,6 @@
1
  """Regression tests for interactive CLI rendering and research model routing."""
2
 
 
3
  import sys
4
  from io import StringIO
5
  from types import SimpleNamespace
@@ -97,10 +98,11 @@ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
97
 
98
 
99
  def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
100
- seen: dict[str, str | None] = {}
101
 
102
- async def fake_main(*, model=None):
103
  seen["model"] = model
 
104
 
105
  monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
106
  monkeypatch.setattr(main_mod, "main", fake_main)
@@ -108,6 +110,61 @@ def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
108
  main_mod.cli()
109
 
110
  assert seen["model"] == "openai/gpt-5.5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  @pytest.mark.asyncio
@@ -115,9 +172,10 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
115
  class StopAfterBanner(Exception):
116
  pass
117
 
118
- def fake_banner(*, model=None, hf_user=None):
119
  assert model == "openai/gpt-5.5"
120
  assert hf_user == "tester"
 
121
  raise StopAfterBanner
122
 
123
  monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
@@ -130,9 +188,150 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
130
  lambda _path, **_kwargs: SimpleNamespace(
131
  model_name="moonshotai/Kimi-K2.6",
132
  mcpServers={},
 
133
  ),
134
  )
135
  monkeypatch.setattr(main_mod, "print_banner", fake_banner)
136
 
137
  with pytest.raises(StopAfterBanner):
138
  await main_mod.main(model="openai/gpt-5.5")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Regression tests for interactive CLI rendering and research model routing."""
2
 
3
+ import asyncio
4
  import sys
5
  from io import StringIO
6
  from types import SimpleNamespace
 
98
 
99
 
100
  def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
101
+ seen: dict[str, object] = {}
102
 
103
+ async def fake_main(*, model=None, sandbox_tools=False):
104
  seen["model"] = model
105
+ seen["sandbox_tools"] = sandbox_tools
106
 
107
  monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
108
  monkeypatch.setattr(main_mod, "main", fake_main)
 
110
  main_mod.cli()
111
 
112
  assert seen["model"] == "openai/gpt-5.5"
113
+ assert seen["sandbox_tools"] is False
114
+
115
+
116
+ def test_cli_forwards_sandbox_flag_to_interactive_main(monkeypatch):
117
+ seen: dict[str, object] = {}
118
+
119
+ async def fake_main(*, model=None, sandbox_tools=False):
120
+ seen["model"] = model
121
+ seen["sandbox_tools"] = sandbox_tools
122
+
123
+ monkeypatch.setattr(sys, "argv", ["ml-intern", "--sandbox-tools"])
124
+ monkeypatch.setattr(main_mod, "main", fake_main)
125
+
126
+ main_mod.cli()
127
+
128
+ assert seen == {"model": None, "sandbox_tools": True}
129
+
130
+
131
+ def test_cli_forwards_sandbox_flag_to_headless_main(monkeypatch):
132
+ seen: dict[str, object] = {}
133
+
134
+ async def fake_headless_main(
135
+ prompt,
136
+ *,
137
+ model=None,
138
+ max_iterations=None,
139
+ stream=True,
140
+ sandbox_tools=False,
141
+ ):
142
+ seen.update(
143
+ {
144
+ "prompt": prompt,
145
+ "model": model,
146
+ "max_iterations": max_iterations,
147
+ "stream": stream,
148
+ "sandbox_tools": sandbox_tools,
149
+ }
150
+ )
151
+
152
+ monkeypatch.setattr(
153
+ sys,
154
+ "argv",
155
+ ["ml-intern", "--sandbox-tools", "--no-stream", "train a model"],
156
+ )
157
+ monkeypatch.setattr(main_mod, "headless_main", fake_headless_main)
158
+
159
+ main_mod.cli()
160
+
161
+ assert seen == {
162
+ "prompt": "train a model",
163
+ "model": None,
164
+ "max_iterations": None,
165
+ "stream": False,
166
+ "sandbox_tools": True,
167
+ }
168
 
169
 
170
  @pytest.mark.asyncio
 
172
  class StopAfterBanner(Exception):
173
  pass
174
 
175
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
176
  assert model == "openai/gpt-5.5"
177
  assert hf_user == "tester"
178
+ assert tool_runtime == "local filesystem"
179
  raise StopAfterBanner
180
 
181
  monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
 
188
  lambda _path, **_kwargs: SimpleNamespace(
189
  model_name="moonshotai/Kimi-K2.6",
190
  mcpServers={},
191
+ tool_runtime="local",
192
  ),
193
  )
194
  monkeypatch.setattr(main_mod, "print_banner", fake_banner)
195
 
196
  with pytest.raises(StopAfterBanner):
197
  await main_mod.main(model="openai/gpt-5.5")
198
+
199
+
200
+ @pytest.mark.asyncio
201
+ async def test_local_model_local_runtime_skips_hf_token_prompt(monkeypatch):
202
+ class StopAfterBanner(Exception):
203
+ pass
204
+
205
+ async def fail_prompt(_prompt_session):
206
+ raise AssertionError("local model with local tools should not prompt")
207
+
208
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
209
+ assert model == "llamacpp/model"
210
+ assert hf_user is None
211
+ assert tool_runtime == "local filesystem"
212
+ raise StopAfterBanner
213
+
214
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
215
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
216
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
217
+ monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fail_prompt)
218
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: None)
219
+ monkeypatch.setattr(
220
+ main_mod,
221
+ "load_config",
222
+ lambda _path, **_kwargs: SimpleNamespace(
223
+ model_name="llamacpp/model",
224
+ mcpServers={},
225
+ tool_runtime="local",
226
+ ),
227
+ )
228
+ monkeypatch.setattr(main_mod, "print_banner", fake_banner)
229
+
230
+ with pytest.raises(StopAfterBanner):
231
+ await main_mod.main()
232
+
233
+
234
+ @pytest.mark.asyncio
235
+ async def test_local_model_sandbox_runtime_prompts_for_hf_token(monkeypatch):
236
+ class StopAfterBanner(Exception):
237
+ pass
238
+
239
+ prompted = False
240
+
241
+ async def fake_prompt(_prompt_session):
242
+ nonlocal prompted
243
+ prompted = True
244
+ return "hf-token"
245
+
246
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
247
+ assert model == "llamacpp/model"
248
+ assert hf_user == "tester"
249
+ assert tool_runtime == "HF sandbox"
250
+ raise StopAfterBanner
251
+
252
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
253
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
254
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
255
+ monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fake_prompt)
256
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
257
+ monkeypatch.setattr(
258
+ main_mod,
259
+ "load_config",
260
+ lambda _path, **_kwargs: SimpleNamespace(
261
+ model_name="llamacpp/model",
262
+ mcpServers={},
263
+ tool_runtime="local",
264
+ ),
265
+ )
266
+ monkeypatch.setattr(main_mod, "print_banner", fake_banner)
267
+
268
+ with pytest.raises(StopAfterBanner):
269
+ await main_mod.main(sandbox_tools=True)
270
+
271
+ assert prompted is True
272
+
273
+
274
+ @pytest.mark.asyncio
275
+ async def test_interactive_main_passes_sandbox_runtime_to_tool_router(monkeypatch):
276
+ class StopAfterToolRouter(Exception):
277
+ pass
278
+
279
+ seen: dict[str, object] = {}
280
+
281
+ class FakeGateway:
282
+ def __init__(self, _config):
283
+ pass
284
+
285
+ async def start(self):
286
+ pass
287
+
288
+ class FakeToolRouter:
289
+ def __init__(self, mcp_servers, *, hf_token=None, local_mode=True):
290
+ seen["mcp_servers"] = mcp_servers
291
+ seen["hf_token"] = hf_token
292
+ seen["local_mode"] = local_mode
293
+ raise StopAfterToolRouter
294
+
295
+ from agent.core import hf_router_catalog
296
+
297
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
298
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
299
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: "hf-token")
300
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
301
+ monkeypatch.setattr(main_mod, "print_banner", lambda **_kwargs: None)
302
+ monkeypatch.setattr(hf_router_catalog, "prewarm", lambda: None)
303
+ monkeypatch.setattr(
304
+ main_mod,
305
+ "load_config",
306
+ lambda _path, **_kwargs: SimpleNamespace(
307
+ model_name="llamacpp/model",
308
+ mcpServers={"server": object()},
309
+ messaging=SimpleNamespace(default_auto_destinations=lambda: []),
310
+ tool_runtime="local",
311
+ ),
312
+ )
313
+ monkeypatch.setattr(main_mod, "NotificationGateway", FakeGateway)
314
+ monkeypatch.setattr(main_mod, "ToolRouter", FakeToolRouter)
315
+
316
+ with pytest.raises(StopAfterToolRouter):
317
+ await main_mod.main(sandbox_tools=True)
318
+
319
+ assert seen["hf_token"] == "hf-token"
320
+ assert seen["local_mode"] is False
321
+
322
+
323
+ @pytest.mark.asyncio
324
+ async def test_initial_sandbox_preload_waits_before_prompt():
325
+ waited = False
326
+
327
+ async def preload():
328
+ nonlocal waited
329
+ await asyncio.sleep(0)
330
+ waited = True
331
+
332
+ task = asyncio.create_task(preload())
333
+ await main_mod._wait_for_initial_sandbox_preload(
334
+ [SimpleNamespace(sandbox_preload_task=task)]
335
+ )
336
+
337
+ assert waited is True
tests/unit/test_config.py CHANGED
@@ -1,5 +1,8 @@
1
  import json
2
 
 
 
 
3
  from agent import config as config_module
4
 
5
 
@@ -121,3 +124,35 @@ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
121
 
122
  assert not config.messaging.enabled
123
  assert config.messaging.destinations == {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
+ import pytest
4
+ from pydantic import ValidationError
5
+
6
  from agent import config as config_module
7
 
8
 
 
124
 
125
  assert not config.messaging.enabled
126
  assert config.messaging.destinations == {}
127
+
128
+
129
+ def test_tool_runtime_defaults_to_local(tmp_path):
130
+ config_path = tmp_path / "config.json"
131
+ _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
132
+
133
+ config = config_module.load_config(str(config_path))
134
+
135
+ assert config.tool_runtime == "local"
136
+
137
+
138
+ def test_user_config_can_set_sandbox_tool_runtime(tmp_path, monkeypatch):
139
+ config_path = tmp_path / "config.json"
140
+ user_config_path = tmp_path / "user-config.json"
141
+ _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
142
+ _write_json(user_config_path, {"tool_runtime": "sandbox"})
143
+ monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
144
+
145
+ config = config_module.load_config(str(config_path), include_user_defaults=True)
146
+
147
+ assert config.tool_runtime == "sandbox"
148
+
149
+
150
+ def test_invalid_tool_runtime_is_rejected(tmp_path):
151
+ config_path = tmp_path / "config.json"
152
+ _write_json(
153
+ config_path,
154
+ {"model_name": "moonshotai/Kimi-K2.6", "tool_runtime": "hybrid"},
155
+ )
156
+
157
+ with pytest.raises(ValidationError):
158
+ config_module.load_config(str(config_path))
tests/unit/test_hub_artifacts.py CHANGED
@@ -549,7 +549,7 @@ def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
549
  ]
550
 
551
 
552
- def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
553
  import huggingface_hub as hub
554
  from huggingface_hub import HfApi
555
 
@@ -579,6 +579,10 @@ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
579
  def fake_add_collection_item(self, **kwargs):
580
  collection_items.append(kwargs)
581
 
 
 
 
 
582
  monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
583
  monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
584
  monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
 
549
  ]
550
 
551
 
552
+ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch, tmp_path):
553
  import huggingface_hub as hub
554
  from huggingface_hub import HfApi
555
 
 
579
  def fake_add_collection_item(self, **kwargs):
580
  collection_items.append(kwargs)
581
 
582
+ monkeypatch.setenv(
583
+ "ML_INTERN_ARTIFACT_COLLECTION_CACHE",
584
+ str(tmp_path / "collection-slug.txt"),
585
+ )
586
  monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
587
  monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
588
  monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
tests/unit/test_no_tool_continuation_guard.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+
4
+ import pytest
5
+
6
+ from agent.config import Config
7
+ from agent.core import agent_loop
8
+ from agent.core.agent_loop import Handlers, LLMResult
9
+ from agent.core.session import Session
10
+ from agent.tools.plan_tool import PlanTool
11
+
12
+
13
+ class FakeToolRouter:
14
+ def __init__(self):
15
+ self.calls = []
16
+
17
+ def get_tool_specs_for_llm(self):
18
+ return [
19
+ {
20
+ "type": "function",
21
+ "function": {
22
+ "name": "plan_tool",
23
+ "description": "Update plan",
24
+ "parameters": {"type": "object"},
25
+ },
26
+ }
27
+ ]
28
+
29
+ async def call_tool(self, name, arguments, session=None, tool_call_id=None):
30
+ self.calls.append((name, arguments, tool_call_id))
31
+ if name == "plan_tool" and session is not None:
32
+ session.current_plan = [dict(todo) for todo in arguments["todos"]]
33
+ return "plan updated", True
34
+
35
+
36
+ @pytest.mark.asyncio
37
+ async def test_plan_tool_stores_session_scoped_plan():
38
+ events = []
39
+
40
+ class FakeSession:
41
+ current_plan = []
42
+
43
+ async def send_event(self, event):
44
+ events.append(event)
45
+
46
+ session = FakeSession()
47
+ todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
48
+
49
+ result = await PlanTool(session=session).execute({"todos": todos})
50
+
51
+ assert result["isError"] is False
52
+ assert session.current_plan == todos
53
+ assert events[0].event_type == "plan_update"
54
+ assert events[0].data == {"plan": todos}
55
+
56
+
57
+ @pytest.mark.asyncio
58
+ async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
59
+ config = Config.model_validate(
60
+ {"model_name": "openai/test", "save_sessions": False}
61
+ )
62
+ event_queue = asyncio.Queue()
63
+ router = FakeToolRouter()
64
+ session = Session(
65
+ event_queue,
66
+ config,
67
+ tool_router=router,
68
+ stream=False,
69
+ )
70
+ session.current_plan = [
71
+ {
72
+ "id": "1",
73
+ "content": "Write and smoke-test training script",
74
+ "status": "in_progress",
75
+ },
76
+ {"id": "2", "content": "Launch full training job", "status": "pending"},
77
+ ]
78
+ calls = []
79
+
80
+ async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
81
+ calls.append(messages)
82
+ if len(calls) == 1:
83
+ return LLMResult(
84
+ content="I should keep going, but I forgot to call a tool.",
85
+ tool_calls_acc={},
86
+ token_count=10,
87
+ finish_reason="stop",
88
+ )
89
+ if len(calls) == 2:
90
+ assert "CONTINUATION GUARD" in messages[-1].content
91
+ return LLMResult(
92
+ content=None,
93
+ tool_calls_acc={
94
+ 0: {
95
+ "id": "call_1",
96
+ "function": {
97
+ "name": "plan_tool",
98
+ "arguments": json.dumps(
99
+ {
100
+ "todos": [
101
+ {
102
+ "id": "1",
103
+ "content": "Write and smoke-test training script",
104
+ "status": "completed",
105
+ },
106
+ {
107
+ "id": "2",
108
+ "content": "Launch full training job",
109
+ "status": "completed",
110
+ },
111
+ ]
112
+ }
113
+ ),
114
+ },
115
+ }
116
+ },
117
+ token_count=20,
118
+ finish_reason="tool_calls",
119
+ )
120
+ return LLMResult(
121
+ content="Done.",
122
+ tool_calls_acc={},
123
+ token_count=30,
124
+ finish_reason="stop",
125
+ )
126
+
127
+ monkeypatch.setattr(
128
+ agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
129
+ )
130
+ monkeypatch.setattr(
131
+ agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
132
+ )
133
+
134
+ final = await Handlers.run_agent(session, "continue")
135
+
136
+ assert final == "Done."
137
+ assert len(calls) == 3
138
+ assert router.calls[0][0] == "plan_tool"
139
+ assert all(todo["status"] == "completed" for todo in session.current_plan)
140
+ events = []
141
+ while not event_queue.empty():
142
+ events.append(await event_queue.get())
143
+ assert any(
144
+ event.event_type == "tool_log"
145
+ and "text-only response" in (event.data or {}).get("log", "")
146
+ for event in events
147
+ )
tests/unit/test_sandbox_auto_start.py CHANGED
@@ -1,7 +1,15 @@
 
1
  from types import SimpleNamespace
2
  from pathlib import Path
3
 
 
 
 
 
4
  from agent.core.agent_loop import _needs_approval
 
 
 
5
  from agent.tools.sandbox_tool import get_sandbox_tools
6
 
7
 
@@ -34,3 +42,102 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create():
34
  in tool_specs["sandbox_create"]
35
  )
36
  assert "started automatically for normal CPU work" in tool_specs["bash"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  from types import SimpleNamespace
3
  from pathlib import Path
4
 
5
+ import pytest
6
+
7
+ from agent.config import Config
8
+ from agent.core import agent_loop
9
  from agent.core.agent_loop import _needs_approval
10
+ from agent.core.session import OpType
11
+ from agent.core.tools import create_builtin_tools
12
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
13
  from agent.tools.sandbox_tool import get_sandbox_tools
14
 
15
 
 
42
  in tool_specs["sandbox_create"]
43
  )
44
  assert "started automatically for normal CPU work" in tool_specs["bash"]
45
+
46
+
47
+ def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
48
+ prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
49
+
50
+ assert "Never pass a local machine path to hf_jobs.script" in prompt
51
+ assert "/fsx/..." in prompt
52
+ assert "inline Python source code" in prompt
53
+ assert "a file already written in the session sandbox" in prompt
54
+
55
+
56
+ def test_prompt_and_hf_jobs_spec_require_gpu_preflight_for_gpu_jobs():
57
+ prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
58
+ jobs_description = HF_JOBS_TOOL_SPEC["description"]
59
+
60
+ assert "GPU preflight is mandatory before hf_jobs" in prompt
61
+ assert "GPU sandbox smoke test" in prompt
62
+ assert "If you skip GPU sandbox preflight" in prompt
63
+ assert "you MUST create a GPU sandbox with sandbox_create first" in jobs_description
64
+ assert "If skipped, state why before calling hf_jobs" in jobs_description
65
+
66
+
67
+ def test_local_tool_runtime_excludes_sandbox_create():
68
+ tool_names = {tool.name for tool in create_builtin_tools(local_mode=True)}
69
+
70
+ assert {"bash", "read", "write", "edit"} <= tool_names
71
+ assert "sandbox_create" not in tool_names
72
+
73
+
74
+ def test_sandbox_tool_runtime_includes_sandbox_create():
75
+ tool_names = {tool.name for tool in create_builtin_tools(local_mode=False)}
76
+
77
+ assert {"sandbox_create", "bash", "read", "write", "edit"} <= tool_names
78
+
79
+
80
+ @pytest.mark.asyncio
81
+ async def test_cli_sandbox_runtime_preloads_and_tears_down_sandbox(monkeypatch):
82
+ started = []
83
+ torn_down = []
84
+
85
+ class FakeToolRouter:
86
+ tools = {}
87
+
88
+ def get_tool_specs_for_llm(self):
89
+ return []
90
+
91
+ async def __aenter__(self):
92
+ return self
93
+
94
+ async def __aexit__(self, exc_type, exc, tb):
95
+ return None
96
+
97
+ def fake_start_cpu_sandbox_preload(session):
98
+ started.append(session)
99
+ return None
100
+
101
+ async def fake_teardown_session_sandbox(session):
102
+ torn_down.append(session)
103
+
104
+ monkeypatch.setattr(
105
+ agent_loop, "start_cpu_sandbox_preload", fake_start_cpu_sandbox_preload
106
+ )
107
+ monkeypatch.setattr(
108
+ agent_loop, "teardown_session_sandbox", fake_teardown_session_sandbox
109
+ )
110
+
111
+ submission_queue = asyncio.Queue()
112
+ event_queue = asyncio.Queue()
113
+ session_holder = [None]
114
+ config = Config.model_validate(
115
+ {"model_name": "openai/gpt-5.5", "save_sessions": False}
116
+ )
117
+
118
+ task = asyncio.create_task(
119
+ agent_loop.submission_loop(
120
+ submission_queue,
121
+ event_queue,
122
+ config=config,
123
+ tool_router=FakeToolRouter(),
124
+ session_holder=session_holder,
125
+ hf_token="hf-token",
126
+ user_id="tester",
127
+ local_mode=False,
128
+ )
129
+ )
130
+
131
+ ready = await asyncio.wait_for(event_queue.get(), timeout=1)
132
+ assert ready.event_type == "ready"
133
+ assert started == [session_holder[0]]
134
+ assert session_holder[0].local_mode is False
135
+
136
+ await submission_queue.put(
137
+ SimpleNamespace(
138
+ operation=SimpleNamespace(op_type=OpType.SHUTDOWN, data=None),
139
+ )
140
+ )
141
+ await asyncio.wait_for(task, timeout=1)
142
+
143
+ assert torn_down == [session_holder[0]]
tests/unit/test_sandbox_private_spaces.py CHANGED
@@ -91,6 +91,31 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
91
  assert not any("sleep time" in log for log in logs)
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
95
  runtime_calls = 0
96
 
@@ -395,6 +420,71 @@ def test_ensure_sandbox_overrides_private_argument(monkeypatch):
395
  assert persisted[-1]["sandbox_status"] == "active"
396
 
397
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
  def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
399
  active_creates = 0
400
  max_active_creates = 0
@@ -514,7 +604,7 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
514
  space_id="alice/sandbox-cpu",
515
  url="https://huggingface.co/spaces/alice/sandbox-cpu",
516
  _owns_space=True,
517
- delete=lambda: deleted.append("alice/sandbox-cpu"),
518
  )
519
  self.sandbox_hardware = "cpu-basic"
520
  self.sandbox_preload_task = None
@@ -559,10 +649,11 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
559
 
560
  def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
561
  deleted: list[str] = []
 
562
  persisted: list[dict] = []
563
 
564
- async def fake_record_sandbox_destroy(*args, **kwargs):
565
- pass
566
 
567
  monkeypatch.setattr(
568
  telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
@@ -570,20 +661,28 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
570
 
571
  async def run():
572
  cancel_event = threading.Event()
 
573
 
574
  async def preload():
575
  await asyncio.sleep(0)
576
 
 
 
 
 
 
 
577
  session = SimpleNamespace(
578
  session_id="s1",
579
  sandbox=SimpleNamespace(
580
  space_id="alice/sandbox-12345678",
581
  _owns_space=True,
582
- delete=lambda: deleted.append("alice/sandbox-12345678"),
583
  ),
584
  sandbox_hardware="cpu-basic",
585
  sandbox_preload_task=asyncio.create_task(preload()),
586
  sandbox_preload_cancel_event=cancel_event,
 
587
  persistence_store=SimpleNamespace(
588
  update_session_fields=lambda session_id, **fields: _record_metadata(
589
  session_id, fields
@@ -592,17 +691,33 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
592
  )
593
 
594
  await sandbox_tool.teardown_session_sandbox(session)
595
- return session, cancel_event
 
 
 
 
596
 
597
  async def _record_metadata(session_id, fields):
598
  persisted.append({"session_id": session_id, **fields})
599
 
600
- session, cancel_event = asyncio.run(run())
601
 
602
  assert cancel_event.is_set()
603
  assert deleted == ["alice/sandbox-12345678"]
 
604
  assert session.sandbox is None
605
  assert session.sandbox_hardware is None
 
 
 
 
 
 
 
 
 
 
 
606
  assert persisted[-1]["session_id"] == "s1"
607
  assert persisted[-1]["sandbox_space_id"] is None
608
  assert persisted[-1]["sandbox_status"] == "destroyed"
 
91
  assert not any("sleep time" in log for log in logs)
92
 
93
 
94
+ def test_sandbox_delete_uses_log_callback_without_stdout(monkeypatch, capsys):
95
+ deleted: list[tuple[str, str]] = []
96
+
97
+ class FakeApi:
98
+ def __init__(self, token=None):
99
+ self.token = token
100
+
101
+ def delete_repo(self, repo_id, repo_type):
102
+ deleted.append((repo_id, repo_type))
103
+
104
+ monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
105
+
106
+ sandbox = Sandbox("alice/sandbox-12345678", token="hf-token", _owns_space=True)
107
+ logs: list[str] = []
108
+
109
+ sandbox.delete(log=logs.append)
110
+
111
+ captured = capsys.readouterr()
112
+ assert captured.out == ""
113
+ assert captured.err == ""
114
+ assert deleted == [("alice/sandbox-12345678", "space")]
115
+ assert logs == ["Deleting sandbox: alice/sandbox-12345678...", "Deleted."]
116
+ assert sandbox._owns_space is False
117
+
118
+
119
  def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
120
  runtime_calls = 0
121
 
 
420
  assert persisted[-1]["sandbox_status"] == "active"
421
 
422
 
423
+ def test_cancelled_sandbox_creation_logs_delete_through_tool_log(monkeypatch):
424
+ deleted: list[str] = []
425
+
426
+ class FakeSession:
427
+ def __init__(self):
428
+ self.hf_token = "hf-token"
429
+ self.sandbox = None
430
+ self.event_queue = asyncio.Queue()
431
+ self._cancelled = asyncio.Event()
432
+
433
+ async def send_event(self, event):
434
+ await self.event_queue.put(event)
435
+
436
+ def fake_create(**kwargs):
437
+ def delete(log=None):
438
+ deleted.append("alice/sandbox-12345678")
439
+ if log:
440
+ log("Deleting sandbox: alice/sandbox-12345678...")
441
+ log("Deleted.")
442
+
443
+ return SimpleNamespace(
444
+ space_id="alice/sandbox-12345678",
445
+ url="https://huggingface.co/spaces/alice/sandbox-12345678",
446
+ _owns_space=True,
447
+ delete=delete,
448
+ )
449
+
450
+ monkeypatch.setattr(Sandbox, "create", staticmethod(fake_create))
451
+
452
+ async def run():
453
+ session = FakeSession()
454
+ cancel_event = threading.Event()
455
+ cancel_event.set()
456
+
457
+ sb, error = await sandbox_tool._create_sandbox_locked(
458
+ session,
459
+ api=SimpleNamespace(),
460
+ owner="alice",
461
+ hardware="cpu-basic",
462
+ cancel_event=cancel_event,
463
+ )
464
+ await asyncio.sleep(0)
465
+ events = []
466
+ while not session.event_queue.empty():
467
+ events.append(await session.event_queue.get())
468
+ return sb, error, events
469
+
470
+ sb, error, events = asyncio.run(run())
471
+
472
+ assert sb is None
473
+ assert error == "Sandbox creation cancelled by user."
474
+ assert deleted == ["alice/sandbox-12345678"]
475
+ assert [
476
+ event.data
477
+ for event in events
478
+ if event.event_type == "tool_log"
479
+ and event.data
480
+ and event.data.get("log")
481
+ in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
482
+ ] == [
483
+ {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
484
+ {"tool": "sandbox", "log": "Deleted."},
485
+ ]
486
+
487
+
488
  def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
489
  active_creates = 0
490
  max_active_creates = 0
 
604
  space_id="alice/sandbox-cpu",
605
  url="https://huggingface.co/spaces/alice/sandbox-cpu",
606
  _owns_space=True,
607
+ delete=lambda log=None: deleted.append("alice/sandbox-cpu"),
608
  )
609
  self.sandbox_hardware = "cpu-basic"
610
  self.sandbox_preload_task = None
 
649
 
650
  def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
651
  deleted: list[str] = []
652
+ destroyed: list[str] = []
653
  persisted: list[dict] = []
654
 
655
+ async def fake_record_sandbox_destroy(session, sandbox, *args, **kwargs):
656
+ destroyed.append(sandbox.space_id)
657
 
658
  monkeypatch.setattr(
659
  telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
 
661
 
662
  async def run():
663
  cancel_event = threading.Event()
664
+ event_queue = asyncio.Queue()
665
 
666
  async def preload():
667
  await asyncio.sleep(0)
668
 
669
+ def delete(log=None):
670
+ deleted.append("alice/sandbox-12345678")
671
+ if log:
672
+ log("Deleting sandbox: alice/sandbox-12345678...")
673
+ log("Deleted.")
674
+
675
  session = SimpleNamespace(
676
  session_id="s1",
677
  sandbox=SimpleNamespace(
678
  space_id="alice/sandbox-12345678",
679
  _owns_space=True,
680
+ delete=delete,
681
  ),
682
  sandbox_hardware="cpu-basic",
683
  sandbox_preload_task=asyncio.create_task(preload()),
684
  sandbox_preload_cancel_event=cancel_event,
685
+ event_queue=event_queue,
686
  persistence_store=SimpleNamespace(
687
  update_session_fields=lambda session_id, **fields: _record_metadata(
688
  session_id, fields
 
691
  )
692
 
693
  await sandbox_tool.teardown_session_sandbox(session)
694
+ await asyncio.sleep(0)
695
+ events = []
696
+ while not event_queue.empty():
697
+ events.append(await event_queue.get())
698
+ return session, cancel_event, events
699
 
700
  async def _record_metadata(session_id, fields):
701
  persisted.append({"session_id": session_id, **fields})
702
 
703
+ session, cancel_event, events = asyncio.run(run())
704
 
705
  assert cancel_event.is_set()
706
  assert deleted == ["alice/sandbox-12345678"]
707
+ assert destroyed == ["alice/sandbox-12345678"]
708
  assert session.sandbox is None
709
  assert session.sandbox_hardware is None
710
+ assert [
711
+ event.data
712
+ for event in events
713
+ if event.event_type == "tool_log"
714
+ and event.data
715
+ and event.data.get("log")
716
+ in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
717
+ ] == [
718
+ {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
719
+ {"tool": "sandbox", "log": "Deleted."},
720
+ ]
721
  assert persisted[-1]["session_id"] == "s1"
722
  assert persisted[-1]["sandbox_space_id"] is None
723
  assert persisted[-1]["sandbox_status"] == "destroyed"
tests/unit/test_sandbox_script_resolution.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pytest
4
+
5
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
6
+ from agent.tools.sandbox_tool import resolve_sandbox_script
7
+
8
+
9
+ class FakeSandbox:
10
+ def __init__(self):
11
+ self.read_paths = []
12
+
13
+ def read(self, path, *, limit):
14
+ self.read_paths.append((path, limit))
15
+ return SimpleNamespace(
16
+ success=True,
17
+ output="1\tprint('training')\n2\tprint('done')",
18
+ error="",
19
+ )
20
+
21
+
22
+ @pytest.mark.asyncio
23
+ async def test_resolve_sandbox_script_accepts_bare_python_filename():
24
+ sandbox = FakeSandbox()
25
+
26
+ content, error = await resolve_sandbox_script(sandbox, "train_smollm2.py")
27
+
28
+ assert error is None
29
+ assert content == "print('training')\nprint('done')"
30
+ assert sandbox.read_paths == [("train_smollm2.py", 100_000)]
31
+
32
+
33
+ @pytest.mark.asyncio
34
+ async def test_resolve_sandbox_script_accepts_relative_python_path():
35
+ sandbox = FakeSandbox()
36
+
37
+ content, error = await resolve_sandbox_script(sandbox, "scripts/train.py")
38
+
39
+ assert error is None
40
+ assert content == "print('training')\nprint('done')"
41
+ assert sandbox.read_paths == [("scripts/train.py", 100_000)]
42
+
43
+
44
+ @pytest.mark.asyncio
45
+ @pytest.mark.parametrize(
46
+ "script",
47
+ [
48
+ "https://example.com/train.py",
49
+ "http://example.com/train.py",
50
+ "train_smollm2.py --epochs 1",
51
+ "print('hello')",
52
+ ],
53
+ )
54
+ async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
55
+ sandbox = FakeSandbox()
56
+
57
+ content, error = await resolve_sandbox_script(sandbox, script)
58
+
59
+ assert content is None
60
+ assert error is None
61
+ assert sandbox.read_paths == []
62
+
63
+
64
+ def test_hf_jobs_script_description_mentions_bare_python_filenames():
65
+ script_description = HF_JOBS_TOOL_SPEC["parameters"]["properties"]["script"][
66
+ "description"
67
+ ]
68
+
69
+ assert "bare 'train.py'" in script_description
70
+ assert "smoke-test in a GPU sandbox before submission" in script_description
tests/unit/test_session_manager_persistence.py CHANGED
@@ -207,7 +207,7 @@ async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
207
  session.sandbox = SimpleNamespace(
208
  space_id="owner/sandbox-12345678",
209
  _owns_space=True,
210
- delete=lambda: deleted.append("owner/sandbox-12345678"),
211
  )
212
  session.sandbox_hardware = "cpu-basic"
213
  session.sandbox_preload_cancel_event = preload_cancel_event
 
207
  session.sandbox = SimpleNamespace(
208
  space_id="owner/sandbox-12345678",
209
  _owns_space=True,
210
+ delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
211
  )
212
  session.sandbox_hardware = "cpu-basic"
213
  session.sandbox_preload_cancel_event = preload_cancel_event