lewtun HF Staff OpenAI Codex committed on
Commit
479aaea
·
2 Parent(s): ade0b7e 2b4c539

Deploy 2026-05-11

Browse files

Co-authored-by: OpenAI Codex <codex@openai.com>

agent/config.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  import os
3
  import re
4
  from pathlib import Path
5
- from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
8
  from fastmcp.mcp_config import (
@@ -46,6 +46,7 @@ class Config(BaseModel):
46
  # Permission control parameters
47
  confirm_cpu_jobs: bool = True
48
  auto_file_upload: bool = False
 
49
 
50
  # Reasoning effort *preference* — the ceiling the user wants. The probe
51
  # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
 
2
  import os
3
  import re
4
  from pathlib import Path
5
+ from typing import Any, Literal, Union
6
 
7
  from dotenv import load_dotenv
8
  from fastmcp.mcp_config import (
 
46
  # Permission control parameters
47
  confirm_cpu_jobs: bool = True
48
  auto_file_upload: bool = False
49
+ tool_runtime: Literal["local", "sandbox"] = "local"
50
 
51
  # Reasoning effort *preference* — the ceiling the user wants. The probe
52
  # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
agent/core/agent_loop.py CHANGED
@@ -32,7 +32,11 @@ from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
33
  from agent.core.tools import ToolRouter
34
  from agent.tools.jobs_tool import CPU_FLAVORS
35
- from agent.tools.sandbox_tool import DEFAULT_CPU_SANDBOX_HARDWARE
 
 
 
 
36
 
37
  logger = logging.getLogger(__name__)
38
 
@@ -40,6 +44,43 @@ ToolCall = ChatCompletionMessageToolCall
40
 
41
  _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
42
  _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  def _malformed_tool_name(message: Message) -> str | None:
@@ -1153,6 +1194,7 @@ class Handlers:
1153
  final_response = None
1154
  errored = False
1155
  max_iterations = session.config.max_iterations
 
1156
 
1157
  while max_iterations == -1 or iteration < max_iterations:
1158
  # ── Cancellation check: before LLM call ──
@@ -1301,6 +1343,51 @@ class Handlers:
1301
 
1302
  # If no tool calls, add assistant message and we're done
1303
  if not tool_calls:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1304
  logger.debug(
1305
  "Agent loop ending: no tool calls. "
1306
  "finish_reason=%s, token_count=%d, "
@@ -1324,6 +1411,8 @@ class Handlers:
1324
  final_response = content
1325
  break
1326
 
 
 
1327
  # Validate tool call args (one json.loads per call, once)
1328
  # and split into good vs bad
1329
  good_tools: list[tuple[ToolCall, str, dict]] = []
@@ -1940,6 +2029,8 @@ class Handlers:
1940
  _ = session.save_and_upload_detached(repo_id)
1941
 
1942
  session.is_running = False
 
 
1943
  await session.send_event(Event(event_type="shutdown"))
1944
  return True
1945
 
@@ -2023,6 +2114,8 @@ async def submission_loop(
2023
  )
2024
  if session_holder is not None:
2025
  session_holder[0] = session
 
 
2026
  logger.info("Agent loop started")
2027
 
2028
  # Retry any failed uploads from previous sessions (fire-and-forget).
 
32
  from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
33
  from agent.core.tools import ToolRouter
34
  from agent.tools.jobs_tool import CPU_FLAVORS
35
+ from agent.tools.sandbox_tool import (
36
+ DEFAULT_CPU_SANDBOX_HARDWARE,
37
+ start_cpu_sandbox_preload,
38
+ teardown_session_sandbox,
39
+ )
40
 
41
  logger = logging.getLogger(__name__)
42
 
 
44
 
45
  _MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
46
  _MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
47
+ _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2
48
+
49
+
50
+ def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
51
+ plan = getattr(session, "current_plan", None) or []
52
+ unfinished: list[dict[str, str]] = []
53
+ for item in plan:
54
+ if not isinstance(item, dict):
55
+ continue
56
+ status = item.get("status")
57
+ if status in {"pending", "in_progress"}:
58
+ unfinished.append(item)
59
+ return unfinished
60
+
61
+
62
+ def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
63
+ formatted = []
64
+ for item in items[:limit]:
65
+ item_id = item.get("id") or "?"
66
+ content = item.get("content") or "(unnamed task)"
67
+ status = item.get("status") or "unknown"
68
+ formatted.append(f"{item_id}. {content} [{status}]")
69
+ if len(items) > limit:
70
+ formatted.append(f"... and {len(items) - limit} more")
71
+ return "; ".join(formatted)
72
+
73
+
74
+ def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
75
+ summary = _format_plan_items_for_guard(items)
76
+ return (
77
+ "[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
78
+ "tool calls, but the task is not complete. The current plan still has "
79
+ f"unfinished items: {summary}. Do not return control to the user yet. "
80
+ "Continue from the next unfinished item and make at least one tool call "
81
+ "now. If you genuinely cannot continue, first use tools to inspect the "
82
+ "state or verify the blocker."
83
+ )
84
 
85
 
86
  def _malformed_tool_name(message: Message) -> str | None:
 
1194
  final_response = None
1195
  errored = False
1196
  max_iterations = session.config.max_iterations
1197
+ no_tool_incomplete_plan_retries = 0
1198
 
1199
  while max_iterations == -1 or iteration < max_iterations:
1200
  # ── Cancellation check: before LLM call ──
 
1343
 
1344
  # If no tool calls, add assistant message and we're done
1345
  if not tool_calls:
1346
+ unfinished_plan = _unfinished_plan_items(session)
1347
+ if (
1348
+ unfinished_plan
1349
+ and no_tool_incomplete_plan_retries
1350
+ < _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
1351
+ ):
1352
+ logger.info(
1353
+ "No tool calls with unfinished plan; retrying agent turn "
1354
+ "(attempt %d/%d)",
1355
+ no_tool_incomplete_plan_retries + 1,
1356
+ _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
1357
+ )
1358
+ if content:
1359
+ assistant_msg = _assistant_message_from_result(
1360
+ llm_result,
1361
+ model_name=llm_params.get("model"),
1362
+ )
1363
+ session.context_manager.add_message(
1364
+ assistant_msg, token_count
1365
+ )
1366
+ session.context_manager.add_message(
1367
+ Message(
1368
+ role="user",
1369
+ content=_no_tool_incomplete_plan_prompt(
1370
+ unfinished_plan
1371
+ ),
1372
+ )
1373
+ )
1374
+ no_tool_incomplete_plan_retries += 1
1375
+ await session.send_event(
1376
+ Event(
1377
+ event_type="tool_log",
1378
+ data={
1379
+ "tool": "system",
1380
+ "log": (
1381
+ "Plan still has unfinished items after a "
1382
+ "text-only response — retrying instead of "
1383
+ "returning to the prompt."
1384
+ ),
1385
+ },
1386
+ )
1387
+ )
1388
+ iteration += 1
1389
+ continue
1390
+
1391
  logger.debug(
1392
  "Agent loop ending: no tool calls. "
1393
  "finish_reason=%s, token_count=%d, "
 
1411
  final_response = content
1412
  break
1413
 
1414
+ no_tool_incomplete_plan_retries = 0
1415
+
1416
  # Validate tool call args (one json.loads per call, once)
1417
  # and split into good vs bad
1418
  good_tools: list[tuple[ToolCall, str, dict]] = []
 
2029
  _ = session.save_and_upload_detached(repo_id)
2030
 
2031
  session.is_running = False
2032
+ if not getattr(session, "local_mode", False):
2033
+ await teardown_session_sandbox(session)
2034
  await session.send_event(Event(event_type="shutdown"))
2035
  return True
2036
 
 
2114
  )
2115
  if session_holder is not None:
2116
  session_holder[0] = session
2117
+ if not local_mode:
2118
+ start_cpu_sandbox_preload(session)
2119
  logger.info("Agent loop started")
2120
 
2121
  # Retry any failed uploads from previous sessions (fire-and-forget).
agent/core/session.py CHANGED
@@ -99,6 +99,7 @@ class Session:
99
  self.hf_token: Optional[str] = hf_token
100
  self.user_id: Optional[str] = user_id
101
  self.hf_username: Optional[str] = hf_username
 
102
  self.persistence_store = persistence_store
103
  self.tool_router = tool_router
104
  self.stream = stream
@@ -117,6 +118,7 @@ class Session:
117
  self.session_id = session_id or str(uuid.uuid4())
118
  self.config = config
119
  self.is_running = True
 
120
  self._cancelled = asyncio.Event()
121
  self.pending_approval: Optional[dict[str, Any]] = None
122
  self.sandbox = None
 
99
  self.hf_token: Optional[str] = hf_token
100
  self.user_id: Optional[str] = user_id
101
  self.hf_username: Optional[str] = hf_username
102
+ self.local_mode = local_mode
103
  self.persistence_store = persistence_store
104
  self.tool_router = tool_router
105
  self.stream = stream
 
118
  self.session_id = session_id or str(uuid.uuid4())
119
  self.config = config
120
  self.is_running = True
121
+ self.current_plan: list[dict[str, str]] = []
122
  self._cancelled = asyncio.Event()
123
  self.pending_approval: Optional[dict[str, Any]] = None
124
  self.sandbox = None
agent/main.py CHANGED
@@ -59,6 +59,34 @@ CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.j
59
  logger = logging.getLogger(__name__)
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
63
  if tool_info.get("tool") != "hf_jobs":
64
  return False
@@ -957,6 +985,7 @@ async def _handle_slash_command(
957
  session = session_holder[0] if session_holder else None
958
  print(f"Model: {config.model_name}")
959
  print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
 
960
  if session:
961
  print(f"Turns: {session.turn_count}")
962
  print(f"Context items: {len(session.context_manager.items)}")
@@ -1076,7 +1105,7 @@ async def _handle_share_traces_command(arg: str, config, session) -> None:
1076
  console.print(f"[green]Dataset is now {label}.[/green] {url}")
1077
 
1078
 
1079
- async def main(model: str | None = None):
1080
  """Interactive chat with the agent"""
1081
 
1082
  # Clear screen
@@ -1088,16 +1117,23 @@ async def main(model: str | None = None):
1088
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1089
  if model:
1090
  config.model_name = model
 
 
1091
 
1092
- # HF token — required for Hub-backed models/tools, but not for local LLMs.
 
1093
  hf_token = resolve_hf_token()
1094
- if not hf_token and not is_local_model_id(config.model_name):
1095
  hf_token = await _prompt_and_save_hf_token(prompt_session)
1096
 
1097
  # Resolve username for banner
1098
  hf_user = _get_hf_user(hf_token)
1099
 
1100
- print_banner(model=config.model_name, hf_user=hf_user)
 
 
 
 
1101
 
1102
  # Pre-warm the HF router catalog in the background so /model switches
1103
  # don't block on a network fetch.
@@ -1116,8 +1152,10 @@ async def main(model: str | None = None):
1116
 
1117
  notification_gateway = NotificationGateway(config.messaging)
1118
  await notification_gateway.start()
1119
- # Create tool router with local mode
1120
- tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
 
 
1121
 
1122
  # Session holder for interrupt/model/status access
1123
  session_holder = [None]
@@ -1131,7 +1169,7 @@ async def main(model: str | None = None):
1131
  session_holder=session_holder,
1132
  hf_token=hf_token,
1133
  user_id=hf_user,
1134
- local_mode=True,
1135
  stream=True,
1136
  notification_gateway=notification_gateway,
1137
  notification_destinations=config.messaging.default_auto_destinations(),
@@ -1153,6 +1191,8 @@ async def main(model: str | None = None):
1153
  )
1154
 
1155
  await ready_event.wait()
 
 
1156
 
1157
  submission_id = [0]
1158
  # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
@@ -1310,6 +1350,7 @@ async def headless_main(
1310
  model: str | None = None,
1311
  max_iterations: int | None = None,
1312
  stream: bool = True,
 
1313
  ) -> None:
1314
  """Run a single prompt headlessly and exit."""
1315
  import logging
@@ -1322,11 +1363,13 @@ async def headless_main(
1322
 
1323
  if model:
1324
  config.model_name = model
 
 
1325
 
1326
  hf_token = resolve_hf_token()
1327
- if not hf_token and not is_local_model_id(config.model_name):
1328
  print(
1329
- "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.",
1330
  file=sys.stderr,
1331
  )
1332
  sys.exit(1)
@@ -1342,6 +1385,7 @@ async def headless_main(
1342
  config.max_iterations = max_iterations
1343
 
1344
  print(f"Model: {config.model_name}", file=sys.stderr)
 
1345
  print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
1346
  print(f"Prompt: {prompt}", file=sys.stderr)
1347
  print("---", file=sys.stderr)
@@ -1349,7 +1393,9 @@ async def headless_main(
1349
  submission_queue: asyncio.Queue = asyncio.Queue()
1350
  event_queue: asyncio.Queue = asyncio.Queue()
1351
 
1352
- tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
 
 
1353
  session_holder: list = [None]
1354
 
1355
  agent_task = asyncio.create_task(
@@ -1361,7 +1407,7 @@ async def headless_main(
1361
  session_holder=session_holder,
1362
  hf_token=hf_token,
1363
  user_id=hf_user,
1364
- local_mode=True,
1365
  stream=stream,
1366
  notification_gateway=notification_gateway,
1367
  notification_destinations=config.messaging.default_auto_destinations(),
@@ -1556,6 +1602,11 @@ def cli():
1556
  action="store_true",
1557
  help="Disable token streaming (use non-streaming LLM calls)",
1558
  )
 
 
 
 
 
1559
  args = parser.parse_args()
1560
 
1561
  try:
@@ -1569,10 +1620,11 @@ def cli():
1569
  model=args.model,
1570
  max_iterations=max_iter,
1571
  stream=not args.no_stream,
 
1572
  )
1573
  )
1574
  else:
1575
- asyncio.run(main(model=args.model))
1576
  except KeyboardInterrupt:
1577
  print("\n\nGoodbye!")
1578
 
 
59
  logger = logging.getLogger(__name__)
60
 
61
 
62
+ def _apply_tool_runtime_override(config: Any, *, sandbox_tools: bool) -> str:
63
+ if sandbox_tools:
64
+ config.tool_runtime = "sandbox"
65
+ return getattr(config, "tool_runtime", "local")
66
+
67
+
68
+ def _is_local_tool_runtime(config: Any) -> bool:
69
+ return getattr(config, "tool_runtime", "local") == "local"
70
+
71
+
72
+ def _tool_runtime_label(local_mode: bool) -> str:
73
+ return "local filesystem" if local_mode else "HF sandbox"
74
+
75
+
76
+ async def _wait_for_initial_sandbox_preload(session_holder: list | None) -> None:
77
+ session = session_holder[0] if session_holder else None
78
+ task = getattr(session, "sandbox_preload_task", None)
79
+ if not task:
80
+ return
81
+ try:
82
+ await asyncio.shield(task)
83
+ except asyncio.CancelledError:
84
+ raise
85
+ except Exception:
86
+ # The sandbox tool will surface the stored preload error on first use.
87
+ return
88
+
89
+
90
  def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
91
  if tool_info.get("tool") != "hf_jobs":
92
  return False
 
985
  session = session_holder[0] if session_holder else None
986
  print(f"Model: {config.model_name}")
987
  print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
988
+ print(f"Tool runtime: {_tool_runtime_label(_is_local_tool_runtime(config))}")
989
  if session:
990
  print(f"Turns: {session.turn_count}")
991
  print(f"Context items: {len(session.context_manager.items)}")
 
1105
  console.print(f"[green]Dataset is now {label}.[/green] {url}")
1106
 
1107
 
1108
+ async def main(model: str | None = None, sandbox_tools: bool = False):
1109
  """Interactive chat with the agent"""
1110
 
1111
  # Clear screen
 
1117
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1118
  if model:
1119
  config.model_name = model
1120
+ _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
1121
+ local_mode = _is_local_tool_runtime(config)
1122
 
1123
+ # HF token — required for Hub-backed models/tools and sandbox tools, but
1124
+ # not for local LLMs using only local filesystem tools.
1125
  hf_token = resolve_hf_token()
1126
+ if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
1127
  hf_token = await _prompt_and_save_hf_token(prompt_session)
1128
 
1129
  # Resolve username for banner
1130
  hf_user = _get_hf_user(hf_token)
1131
 
1132
+ print_banner(
1133
+ model=config.model_name,
1134
+ hf_user=hf_user,
1135
+ tool_runtime=_tool_runtime_label(local_mode),
1136
+ )
1137
 
1138
  # Pre-warm the HF router catalog in the background so /model switches
1139
  # don't block on a network fetch.
 
1152
 
1153
  notification_gateway = NotificationGateway(config.messaging)
1154
  await notification_gateway.start()
1155
+ # Create tool router with the selected CLI tool runtime.
1156
+ tool_router = ToolRouter(
1157
+ config.mcpServers, hf_token=hf_token, local_mode=local_mode
1158
+ )
1159
 
1160
  # Session holder for interrupt/model/status access
1161
  session_holder = [None]
 
1169
  session_holder=session_holder,
1170
  hf_token=hf_token,
1171
  user_id=hf_user,
1172
+ local_mode=local_mode,
1173
  stream=True,
1174
  notification_gateway=notification_gateway,
1175
  notification_destinations=config.messaging.default_auto_destinations(),
 
1191
  )
1192
 
1193
  await ready_event.wait()
1194
+ if not local_mode:
1195
+ await _wait_for_initial_sandbox_preload(session_holder)
1196
 
1197
  submission_id = [0]
1198
  # Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
 
1350
  model: str | None = None,
1351
  max_iterations: int | None = None,
1352
  stream: bool = True,
1353
+ sandbox_tools: bool = False,
1354
  ) -> None:
1355
  """Run a single prompt headlessly and exit."""
1356
  import logging
 
1363
 
1364
  if model:
1365
  config.model_name = model
1366
+ _apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
1367
+ local_mode = _is_local_tool_runtime(config)
1368
 
1369
  hf_token = resolve_hf_token()
1370
+ if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
1371
  print(
1372
+ "ERROR: No HF token found. Set HF_TOKEN or run `hf auth login`.",
1373
  file=sys.stderr,
1374
  )
1375
  sys.exit(1)
 
1385
  config.max_iterations = max_iterations
1386
 
1387
  print(f"Model: {config.model_name}", file=sys.stderr)
1388
+ print(f"Tool runtime: {_tool_runtime_label(local_mode)}", file=sys.stderr)
1389
  print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
1390
  print(f"Prompt: {prompt}", file=sys.stderr)
1391
  print("---", file=sys.stderr)
 
1393
  submission_queue: asyncio.Queue = asyncio.Queue()
1394
  event_queue: asyncio.Queue = asyncio.Queue()
1395
 
1396
+ tool_router = ToolRouter(
1397
+ config.mcpServers, hf_token=hf_token, local_mode=local_mode
1398
+ )
1399
  session_holder: list = [None]
1400
 
1401
  agent_task = asyncio.create_task(
 
1407
  session_holder=session_holder,
1408
  hf_token=hf_token,
1409
  user_id=hf_user,
1410
+ local_mode=local_mode,
1411
  stream=stream,
1412
  notification_gateway=notification_gateway,
1413
  notification_destinations=config.messaging.default_auto_destinations(),
 
1602
  action="store_true",
1603
  help="Disable token streaming (use non-streaming LLM calls)",
1604
  )
1605
+ parser.add_argument(
1606
+ "--sandbox-tools",
1607
+ action="store_true",
1608
+ help="Use HF Space sandbox tools instead of local filesystem tools",
1609
+ )
1610
  args = parser.parse_args()
1611
 
1612
  try:
 
1620
  model=args.model,
1621
  max_iterations=max_iter,
1622
  stream=not args.no_stream,
1623
+ sandbox_tools=args.sandbox_tools,
1624
  )
1625
  )
1626
  else:
1627
+ asyncio.run(main(model=args.model, sandbox_tools=args.sandbox_tools))
1628
  except KeyboardInterrupt:
1629
  print("\n\nGoodbye!")
1630
 
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -66,7 +66,7 @@ system_prompt: |
66
  report_to="trackio"
67
  run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
68
  project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
69
- trackio_space_id="<username>/mlintern-<8-char-id>" # creates a public dashboard Space
70
  `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
71
 
72
  Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
@@ -102,9 +102,18 @@ system_prompt: |
102
 
103
  # When submitting a training job
104
 
 
 
 
 
 
 
 
 
105
  Before calling hf_jobs, output a pre-flight check:
106
  - Reference implementation: [which example you based this on]
107
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
 
108
  - push_to_hub=True and hub_model_id set
109
  - timeout: [value] (based on: [model size] on [hardware])
110
  - Trackio monitoring included and deploying metrics to a public Space
@@ -127,7 +136,7 @@ system_prompt: |
127
 
128
  Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
129
 
130
- Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
131
 
132
 
133
  # When a task has 3+ steps
 
66
  report_to="trackio"
67
  run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
68
  project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
69
+ trackio_space_id="<username>/ml-intern-<8-char-id>" # creates a public dashboard Space
70
  `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
71
 
72
  Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
 
102
 
103
  # When submitting a training job
104
 
105
+ Never pass a local machine path to hf_jobs.script, such as /Users/..., /home/..., /fsx/..., or a repo checkout path. HF Jobs runs in a fresh cloud environment where local files do not exist. For hf_jobs.script, use exactly one of:
106
+ - inline Python source code
107
+ - a file already written in the session sandbox, e.g. /app/train.py, ./train.py, or train.py
108
+ - a public/raw URL
109
+ If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
110
+
111
+ GPU preflight is mandatory before hf_jobs when the job will run on GPU, or when the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, or torch.compile. First create a GPU sandbox with sandbox_create (t4-small minimum; choose larger hardware when VRAM requires it), run a tiny smoke test there using the same imports, model-loading path, training entrypoint, and a tiny dataset/subset, then fix failures before submitting. If you skip GPU sandbox preflight, state why before calling hf_jobs.
112
+
113
  Before calling hf_jobs, output a pre-flight check:
114
  - Reference implementation: [which example you based this on]
115
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
116
+ - GPU sandbox smoke test: [hardware and result, or explicitly not applicable because ...]
117
  - push_to_hub=True and hub_model_id set
118
  - timeout: [value] (based on: [model size] on [hardware])
119
  - Trackio monitoring included and deploying metrics to a public Space
 
136
 
137
  Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
138
 
139
+ Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16/fp16, quantization, flash attention, torch.compile, or model loading. CPU sandboxes cannot test GPU code paths. If the available sandbox tiers cannot fit the full model path, test the largest useful smoke path, state what was not covered, and submit one HF job first.
140
 
141
 
142
  # When a task has 3+ steps
agent/tools/jobs_tool.py CHANGED
@@ -1112,11 +1112,14 @@ HF_JOBS_TOOL_SPEC = {
1112
  "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
1113
  "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
1114
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
 
 
 
1115
  "- Training config MUST include push_to_hub=True and hub_model_id. "
1116
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
1117
  "- Include trackio monitoring and provide the dashboard URL to the user. "
1118
  "When the script uses report_to='trackio', also pass `trackio_space_id` "
1119
- "(e.g. '<username>/mlintern-<8char>') and `trackio_project` as tool args — "
1120
  "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
1121
  "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
1122
  "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
@@ -1157,8 +1160,9 @@ HF_JOBS_TOOL_SPEC = {
1157
  "script": {
1158
  "type": "string",
1159
  "description": (
1160
- "Python code or sandbox file path (e.g. '/app/train.py') or URL. "
1161
  "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
 
1162
  "Mutually exclusive with 'command'."
1163
  ),
1164
  },
@@ -1204,7 +1208,7 @@ HF_JOBS_TOOL_SPEC = {
1204
  "type": "string",
1205
  "description": (
1206
  "Optional. The HF Space hosting the trackio dashboard for this run "
1207
- "(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). "
1208
  "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
1209
  "the live dashboard. Set this whenever the script uses "
1210
  "report_to='trackio'. The Space is auto-created and seeded with the "
 
1112
  "- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
1113
  "Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
1114
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
1115
+ "- If the job runs on GPU, or the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, "
1116
+ "or torch.compile, you MUST create a GPU sandbox with sandbox_create first, run a tiny smoke test there, "
1117
+ "and fix failures before submitting. If skipped, state why before calling hf_jobs.\n"
1118
  "- Training config MUST include push_to_hub=True and hub_model_id. "
1119
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
1120
  "- Include trackio monitoring and provide the dashboard URL to the user. "
1121
  "When the script uses report_to='trackio', also pass `trackio_space_id` "
1122
+ "(e.g. '<username>/ml-intern-<8char>') and `trackio_project` as tool args — "
1123
  "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
1124
  "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
1125
  "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
 
1160
  "script": {
1161
  "type": "string",
1162
  "description": (
1163
+ "Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
1164
  "Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
1165
+ "For GPU/model-loading training scripts, smoke-test in a GPU sandbox before submission. "
1166
  "Mutually exclusive with 'command'."
1167
  ),
1168
  },
 
1208
  "type": "string",
1209
  "description": (
1210
  "Optional. The HF Space hosting the trackio dashboard for this run "
1211
+ "(e.g. '<username>/ml-intern-<8char>', under YOUR HF namespace). "
1212
  "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
1213
  "the live dashboard. Set this whenever the script uses "
1214
  "report_to='trackio'. The Space is auto-created and seeded with the "
agent/tools/plan_tool.py CHANGED
@@ -54,20 +54,24 @@ class PlanTool:
54
  "isError": True,
55
  }
56
 
57
- # Store the raw todos structure in memory
58
- _current_plan = todos
 
 
 
 
59
 
60
  # Emit plan update event if session is available
61
  if self.session:
62
  await self.session.send_event(
63
  Event(
64
  event_type="plan_update",
65
- data={"plan": todos},
66
  )
67
  )
68
 
69
  # Format only for display using terminal_display utility
70
- formatted_output = format_plan_tool_output(todos)
71
 
72
  return {
73
  "formatted": formatted_output,
 
54
  "isError": True,
55
  }
56
 
57
+ # Store a session-scoped copy so the runtime can tell whether a
58
+ # text-only model response is trying to stop while work remains.
59
+ stored_todos = [dict(todo) for todo in todos]
60
+ _current_plan = stored_todos
61
+ if self.session is not None:
62
+ self.session.current_plan = stored_todos
63
 
64
  # Emit plan update event if session is available
65
  if self.session:
66
  await self.session.send_event(
67
  Event(
68
  event_type="plan_update",
69
+ data={"plan": stored_todos},
70
  )
71
  )
72
 
73
  # Format only for display using terminal_display utility
74
+ formatted_output = format_plan_tool_output(stored_todos)
75
 
76
  return {
77
  "formatted": formatted_output,
agent/tools/sandbox_client.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  # /// script
3
  # requires-python = ">=3.10"
4
- # dependencies = ["huggingface_hub>=0.20.0", "httpx>=0.27.0"]
5
  # ///
6
  """
7
  Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes.
@@ -615,18 +615,19 @@ class Sandbox:
615
  kwargs = {
616
  "from_id": template,
617
  "to_id": space_id,
 
618
  "private": private,
619
- "hardware": hardware,
620
  }
621
  if sleep_time is not None:
622
- kwargs["sleep_time"] = sleep_time
623
 
624
- api.duplicate_space(**kwargs)
625
  _log(f"Space created: https://huggingface.co/spaces/{space_id}")
626
 
627
  _check_cancel()
628
 
629
- # ``duplicate_space`` sends hardware and sleepTimeSeconds in the
630
  # initial create request. Avoid a second /hardware call: deployed HF
631
  # OAuth tokens can 401 on that endpoint for a just-created private
632
  # Space even though duplication itself succeeded. We rely on the
@@ -775,21 +776,23 @@ class Sandbox:
775
  f"Last status: {last_status}, last error: {last_err}"
776
  )
777
 
778
- def delete(self):
779
  """Delete the Space. Only works if this Sandbox created it."""
780
  if not self._owns_space:
781
  raise RuntimeError(
782
  f"This Sandbox did not create {self.space_id}. "
783
  f"Use self._hf_api.delete_repo() directly if you're sure."
784
  )
785
- print(f"Deleting sandbox: {self.space_id}...")
 
786
  self._hf_api.delete_repo(self.space_id, repo_type="space")
787
  # Clear ownership so a second cleanup call (e.g. delete_session +
788
  # _run_session.finally both fire) early-returns instead of retrying
789
  # a 404 delete and emitting a spurious ERROR log.
790
  self._owns_space = False
791
  self._client.close()
792
- print("Deleted.")
 
793
 
794
  def pause(self):
795
  """Pause the Space (stops billing, preserves state)."""
 
1
  #!/usr/bin/env python3
2
  # /// script
3
  # requires-python = ">=3.10"
4
+ # dependencies = ["huggingface_hub>=1.12.0", "httpx>=0.27.0"]
5
  # ///
6
  """
7
  Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes.
 
615
  kwargs = {
616
  "from_id": template,
617
  "to_id": space_id,
618
+ "repo_type": "space",
619
  "private": private,
620
+ "space_hardware": hardware,
621
  }
622
  if sleep_time is not None:
623
+ kwargs["space_sleep_time"] = sleep_time
624
 
625
+ api.duplicate_repo(**kwargs)
626
  _log(f"Space created: https://huggingface.co/spaces/{space_id}")
627
 
628
  _check_cancel()
629
 
630
+ # ``duplicate_repo`` sends hardware and sleepTimeSeconds in the
631
  # initial create request. Avoid a second /hardware call: deployed HF
632
  # OAuth tokens can 401 on that endpoint for a just-created private
633
  # Space even though duplication itself succeeded. We rely on the
 
776
  f"Last status: {last_status}, last error: {last_err}"
777
  )
778
 
779
+ def delete(self, log: Callable[[str], object] | None = None):
780
  """Delete the Space. Only works if this Sandbox created it."""
781
  if not self._owns_space:
782
  raise RuntimeError(
783
  f"This Sandbox did not create {self.space_id}. "
784
  f"Use self._hf_api.delete_repo() directly if you're sure."
785
  )
786
+ if log:
787
+ log(f"Deleting sandbox: {self.space_id}...")
788
  self._hf_api.delete_repo(self.space_id, repo_type="space")
789
  # Clear ownership so a second cleanup call (e.g. delete_session +
790
  # _run_session.finally both fire) early-returns instead of retrying
791
  # a 404 delete and emitting a spurious ERROR log.
792
  self._owns_space = False
793
  self._client.close()
794
+ if log:
795
+ log("Deleted.")
796
 
797
  def pause(self):
798
  """Pause the Space (stops billing, preserves state)."""
agent/tools/sandbox_tool.py CHANGED
@@ -16,6 +16,7 @@ import logging
16
  import re
17
  import threading
18
  import weakref
 
19
  from datetime import datetime, timedelta, timezone
20
  from typing import Any
21
 
@@ -58,17 +59,41 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
58
  return lock
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def _looks_like_path(script: str) -> bool:
62
  """Return True if the script string looks like a file path (not inline code)."""
63
- return (
64
  isinstance(script, str)
65
  and script.strip() == script
66
  and not any(c in script for c in "\r\n\0")
67
- and (
68
- script.startswith("/")
69
- or script.startswith("./")
70
- or script.startswith("../")
71
- )
 
 
 
 
 
 
72
  )
73
 
74
 
@@ -303,14 +328,8 @@ async def _create_sandbox_locked(
303
  )
304
  )
305
 
306
- # Thread-safe log callback: posts tool_log events from the worker thread
307
- loop = asyncio.get_running_loop()
308
-
309
- def _log(msg: str) -> None:
310
- loop.call_soon_threadsafe(
311
- session.event_queue.put_nowait,
312
- Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
313
- )
314
 
315
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
316
  # We poll session._cancelled from the main loop in a background task and set
@@ -352,7 +371,7 @@ async def _create_sandbox_locked(
352
  if cancel_flag.is_set():
353
  if getattr(sb, "_owns_space", False):
354
  try:
355
- await asyncio.to_thread(sb.delete)
356
  except Exception as e:
357
  logger.warning(
358
  "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
@@ -497,6 +516,7 @@ async def teardown_session_sandbox(session: Any) -> None:
497
  return
498
 
499
  space_id = getattr(sandbox, "space_id", None)
 
500
  last_err: Exception | None = None
501
  for attempt in range(3):
502
  try:
@@ -505,7 +525,7 @@ async def teardown_session_sandbox(session: Any) -> None:
505
  space_id,
506
  attempt + 1,
507
  )
508
- await asyncio.to_thread(sandbox.delete)
509
  from agent.core import telemetry
510
 
511
  await telemetry.record_sandbox_destroy(session, sandbox)
@@ -542,7 +562,7 @@ SANDBOX_CREATE_TOOL_SPEC = {
542
  "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
543
  "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
544
  "If you intend to run a training script in this sandbox that uses report_to='trackio', "
545
- "pass `trackio_space_id` (e.g. '<username>/mlintern-<8char>') and `trackio_project` so they "
546
  "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
547
  "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
548
  ),
@@ -563,7 +583,7 @@ SANDBOX_CREATE_TOOL_SPEC = {
563
  "type": "string",
564
  "description": (
565
  "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
566
- "(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). Injected as "
567
  "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
568
  "seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
569
  "that produces an empty Space that breaks the embed."
 
16
  import re
17
  import threading
18
  import weakref
19
+ from collections.abc import Callable
20
  from datetime import datetime, timedelta, timezone
21
  from typing import Any
22
 
 
59
  return lock
60
 
61
 
62
+ def _session_tool_logger(
63
+ session: Any, *, tool: str = "sandbox"
64
+ ) -> Callable[[str], object] | None:
65
+ event_queue = getattr(session, "event_queue", None)
66
+ if event_queue is None:
67
+ return None
68
+
69
+ loop = asyncio.get_running_loop()
70
+
71
+ def _log(msg: str) -> None:
72
+ loop.call_soon_threadsafe(
73
+ event_queue.put_nowait,
74
+ Event(event_type="tool_log", data={"tool": tool, "log": msg}),
75
+ )
76
+
77
+ return _log
78
+
79
+
80
  def _looks_like_path(script: str) -> bool:
81
  """Return True if the script string looks like a file path (not inline code)."""
82
+ if not (
83
  isinstance(script, str)
84
  and script.strip() == script
85
  and not any(c in script for c in "\r\n\0")
86
+ ):
87
+ return False
88
+
89
+ if script.startswith("http://") or script.startswith("https://"):
90
+ return False
91
+
92
+ return (
93
+ script.startswith("/")
94
+ or script.startswith("./")
95
+ or script.startswith("../")
96
+ or (script.endswith(".py") and not any(c.isspace() for c in script))
97
  )
98
 
99
 
 
328
  )
329
  )
330
 
331
+ # Thread-safe log callback: posts tool_log events from worker threads.
332
+ _log = _session_tool_logger(session) or (lambda msg: None)
 
 
 
 
 
 
333
 
334
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
335
  # We poll session._cancelled from the main loop in a background task and set
 
371
  if cancel_flag.is_set():
372
  if getattr(sb, "_owns_space", False):
373
  try:
374
+ await asyncio.to_thread(sb.delete, log=_log)
375
  except Exception as e:
376
  logger.warning(
377
  "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
 
516
  return
517
 
518
  space_id = getattr(sandbox, "space_id", None)
519
+ delete_log = _session_tool_logger(session)
520
  last_err: Exception | None = None
521
  for attempt in range(3):
522
  try:
 
525
  space_id,
526
  attempt + 1,
527
  )
528
+ await asyncio.to_thread(sandbox.delete, log=delete_log)
529
  from agent.core import telemetry
530
 
531
  await telemetry.record_sandbox_destroy(session, sandbox)
 
562
  "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
563
  "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
564
  "If you intend to run a training script in this sandbox that uses report_to='trackio', "
565
+ "pass `trackio_space_id` (e.g. '<username>/ml-intern-<8char>') and `trackio_project` so they "
566
  "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
567
  "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
568
  ),
 
583
  "type": "string",
584
  "description": (
585
  "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
586
+ "(e.g. '<username>/ml-intern-<8char>', under YOUR HF namespace). Injected as "
587
  "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
588
  "seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
589
  "that produces an empty Space that breaks the embed."
agent/utils/terminal_display.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
6
  import re
7
 
8
  from rich.console import Console
 
9
  from rich.markdown import Heading, Markdown
10
  from rich.panel import Panel
11
  from rich.theme import Theme
@@ -92,7 +93,11 @@ def get_console() -> Console:
92
  # ── Banner ─────────────────────────────────────────────────────────────
93
 
94
 
95
- def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
 
 
 
 
96
  """Print particle logo then CRT boot sequence with system info."""
97
  from agent.utils.particle_logo import run_particle_logo
98
  from agent.utils.crt_boot import run_boot_sequence
@@ -115,6 +120,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
115
  (f"{_I}Initializing agent runtime...", gold),
116
  (f"{_I} User: {user_label}", dim_gold),
117
  (f"{_I} Model: {model_label}", dim_gold),
 
118
  (f"{_I} Tools: loading...", dim_gold),
119
  ("", ""),
120
  (f"{_I}/help for commands · /model to switch · /quit to exit", gold),
@@ -446,23 +452,72 @@ def print_yolo_approve(count: int) -> None:
446
 
447
  # ── Help ───────────────────────────────────────────────────────────────
448
 
449
- HELP_TEXT = f"""\
450
- {_I}[bold]Commands[/bold]
451
- {_I} [cyan]/help[/cyan] Show this help
452
- {_I} [cyan]/undo[/cyan] Undo last turn
453
- {_I} [cyan]/compact[/cyan] Compact context window
454
- {_I} [cyan]/resume[/cyan] [index|id|path] Pick up from a log in ./session_logs
455
- {_I} [cyan]/model[/cyan] [id] Show available models or switch
456
- {_I} [cyan]/effort[/cyan] [level] Reasoning effort (minimal|low|medium|high|xhigh|max|off)
457
- {_I} [cyan]/yolo[/cyan] Toggle auto-approve mode
458
- {_I} [cyan]/status[/cyan] Current model & turn count
459
- {_I} [cyan]/share-traces[/cyan] [public|private] Show/flip visibility of your HF trace dataset
460
- {_I} [cyan]/quit[/cyan] Exit"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
 
463
  def print_help() -> None:
464
  _console.print()
465
- _console.print(HELP_TEXT)
466
  _console.print()
467
 
468
 
 
6
  import re
7
 
8
  from rich.console import Console
9
+ from rich.markup import escape
10
  from rich.markdown import Heading, Markdown
11
  from rich.panel import Panel
12
  from rich.theme import Theme
 
93
  # ── Banner ─────────────────────────────────────────────────────────────
94
 
95
 
96
+ def print_banner(
97
+ model: str | None = None,
98
+ hf_user: str | None = None,
99
+ tool_runtime: str | None = None,
100
+ ) -> None:
101
  """Print particle logo then CRT boot sequence with system info."""
102
  from agent.utils.particle_logo import run_particle_logo
103
  from agent.utils.crt_boot import run_boot_sequence
 
120
  (f"{_I}Initializing agent runtime...", gold),
121
  (f"{_I} User: {user_label}", dim_gold),
122
  (f"{_I} Model: {model_label}", dim_gold),
123
+ (f"{_I} Tool runtime: {tool_runtime or 'local filesystem'}", dim_gold),
124
  (f"{_I} Tools: loading...", dim_gold),
125
  ("", ""),
126
  (f"{_I}/help for commands · /model to switch · /quit to exit", gold),
 
452
 
453
  # ── Help ───────────────────────────────────────────────────────────────
454
 
455
# Rows of the /help table, in display order. Each entry is
# (command, argument hint, description); an empty argument hint means the
# command takes no arguments.
HELP_ROWS: tuple[tuple[str, str, str], ...] = (
    ("/help", "", "Show this help"),
    ("/undo", "", "Undo last turn"),
    ("/compact", "", "Compact context window"),
    ("/resume", "[index|id|path]", "Pick up from ./session_logs"),
    ("/model", "[id]", "Show available models or switch"),
    (
        "/effort",
        "[level]",
        "Set reasoning effort preference",
    ),
    ("/yolo", "", "Toggle auto-approve mode"),
    ("/status", "", "Current model & turn count"),
    (
        "/share-traces",
        "[public|private]",
        "Show or change HF trace visibility",
    ),
    ("/quit", "", "Exit"),
)
475
+
476
+
477
+ def _help_column_widths(
478
+ rows: tuple[tuple[str, str, str], ...],
479
+ ) -> tuple[int, int]:
480
+ return (
481
+ max(len(command) for command, _, _ in rows),
482
+ max(len(args) for _, args, _ in rows),
483
+ )
484
+
485
+
486
def _format_help_row(
    command: str,
    args: str,
    description: str,
    command_width: int,
    args_width: int,
) -> str:
    """Render one help row as Rich markup with aligned columns.

    The command and the argument hint are markup-escaped and padded (two
    spaces past the column width) based on their *unescaped* length, so the
    markup tags do not disturb alignment. The description is appended as-is.
    """
    command_cell = f"[cyan]{escape(command)}[/cyan]" + " " * (
        command_width - len(command) + 2
    )
    # An empty hint contributes padding only, no empty style tags.
    hint = f"[muted]{escape(args)}[/muted]" if args else ""
    args_cell = hint + " " * (args_width - len(args) + 2)
    return f"{_I} {command_cell}{args_cell}{description}"
498
+
499
+
500
def format_help_text(rows: tuple[tuple[str, str, str], ...] | None = None) -> str:
    """Return the full help block as Rich markup.

    Args:
        rows: ``(command, args, description)`` triples; defaults to
            ``HELP_ROWS`` when ``None``.

    Returns:
        A "Commands" header followed by one aligned line per row, joined
        with newlines.
    """
    help_rows = HELP_ROWS if rows is None else rows
    command_width, args_width = _help_column_widths(help_rows)

    lines = [f"{_I}[bold]Commands[/bold]"]
    lines.extend(
        _format_help_row(command, args, description, command_width, args_width)
        for command, args, description in help_rows
    )
    return "\n".join(lines)
516
 
517
 
518
def print_help() -> None:
    """Print the command help table, framed by one blank line above and below."""
    help_markup = format_help_text()
    _console.print()
    _console.print(help_markup)
    _console.print()
522
 
523
 
backend/dataset_uploads.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers for session-scoped dataset uploads to the Hugging Face Hub."""
2
+
3
+ import asyncio
4
+ import os
5
+ import re
6
+ import uuid
7
+ from dataclasses import dataclass
8
+ from urllib.parse import quote
9
+
10
+ from fastapi import HTTPException, UploadFile
11
+ from huggingface_hub import HfApi
12
+
13
# Largest dataset file accepted per upload: 100 MiB.
MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024
# Accepted file extensions, lower-case and without the leading dot.
ALLOWED_DATASET_EXTENSIONS = {"csv", "json", "jsonl"}
# Matches every run of characters that is NOT allowed in a stored filename
# (anything other than ASCII letters, digits, ".", "_" and "-").
_SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+")
# A namespace is 1-96 characters, starts with an ASCII letter or digit and
# continues with letters, digits, ".", "_" or "-".
_SAFE_NAMESPACE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,95}$")
17
+
18
+
19
+ @dataclass(frozen=True)
20
+ class DatasetUpload:
21
+ session_id: str
22
+ repo_id: str
23
+ repo_type: str
24
+ private: bool
25
+ upload_id: str
26
+ config_name: str
27
+ filename: str
28
+ original_filename: str
29
+ path_in_repo: str
30
+ size_bytes: int
31
+ format: str
32
+ hub_url: str
33
+ load_dataset_snippet: str
34
+
35
+ def response_payload(self) -> dict[str, str | int | bool]:
36
+ return {
37
+ "session_id": self.session_id,
38
+ "repo_id": self.repo_id,
39
+ "repo_type": self.repo_type,
40
+ "private": self.private,
41
+ "upload_id": self.upload_id,
42
+ "config_name": self.config_name,
43
+ "filename": self.filename,
44
+ "path_in_repo": self.path_in_repo,
45
+ "size_bytes": self.size_bytes,
46
+ "format": self.format,
47
+ "hub_url": self.hub_url,
48
+ "load_dataset_snippet": self.load_dataset_snippet,
49
+ }
50
+
51
+
52
def sanitize_dataset_filename(filename: str | None) -> str:
    """Return a Hub-safe basename while preserving the extension.

    Directory components are dropped, every run of characters outside
    ``[A-Za-z0-9._-]`` becomes a single ``-``, and leading/trailing ``.``,
    ``-`` and ``_`` are removed. A missing name falls back to
    ``dataset.csv``, a missing extension to ``.csv``. The extension is
    lower-cased.

    The result is at most 96 characters. The stem is truncated first; an
    extension too long to leave room for one stem character is truncated too.

    Args:
        filename: Client-supplied filename, possibly ``None`` or a full path.

    Returns:
        The sanitized ``stem + extension`` string.
    """
    max_len = 96

    raw = os.path.basename(filename or "").strip()
    if not raw:
        raw = "dataset.csv"

    safe = _SAFE_FILENAME_RE.sub("-", raw).strip(".-_")
    if not safe:
        safe = "dataset.csv"

    stem, ext = os.path.splitext(safe)
    if not stem:
        stem = "dataset"
    if not ext:
        ext = ".csv"

    # Always leave room for at least one stem character. Without this cap a
    # very long extension made the stem budget zero or negative, so the slice
    # below cut from the wrong end and the result exceeded max_len.
    ext = ext[: max_len - 1]
    max_stem_len = max_len - len(ext)
    stem = stem[:max_stem_len].strip(".-_") or "dataset"
    return f"{stem}{ext.lower()}"
71
+
72
+
73
+ def display_filename(filename: str | None, fallback: str) -> str:
74
+ raw = os.path.basename(filename or "").strip()
75
+ if not raw:
76
+ return fallback
77
+ cleaned = "".join(char for char in raw if ord(char) >= 32)
78
+ return cleaned[:160] or fallback
79
+
80
+
81
def dataset_format_from_filename(filename: str) -> str:
    """Return the dataset format implied by *filename*'s extension.

    The extension is compared case-insensitively against
    ``ALLOWED_DATASET_EXTENSIONS``.

    Raises:
        HTTPException: 400 when the extension is not an allowed format.
    """
    _, suffix = os.path.splitext(filename)
    ext = suffix.lower().lstrip(".")
    if ext in ALLOWED_DATASET_EXTENSIONS:
        return ext
    raise HTTPException(
        status_code=400,
        detail="Only .csv, .json, and .jsonl dataset files are supported.",
    )
89
+
90
+
91
def session_dataset_repo_id(hf_username: str | None, session_id: str) -> str:
    """Build the per-session dataset repo id ``<namespace>/ml-intern-<id>-datasets``.

    ``<id>`` is up to eight characters derived from *session_id*: runs of
    non-alphanumeric characters are folded into single hyphens, the result is
    cut to eight characters, and hyphens at either end are removed. If nothing
    usable remains, a random eight-character hex id is used instead.

    Args:
        hf_username: Hub namespace that will own the repo.
        session_id: Session identifier the repo is scoped to.

    Raises:
        HTTPException: 400 when the namespace is missing or does not match
            ``_SAFE_NAMESPACE_RE``.
    """
    namespace = (hf_username or "").strip()
    if not namespace or not _SAFE_NAMESPACE_RE.fullmatch(namespace):
        raise HTTPException(
            status_code=400,
            detail="Could not determine a valid Hugging Face namespace.",
        )

    slug = re.sub(r"[^A-Za-z0-9]+", "-", session_id)
    # Strip AFTER truncating: cutting e.g. "abcdefg-hij" to eight characters
    # leaves "abcdefg-", and the resulting "...--datasets" contains "--",
    # which the Hub does not allow in repo names.
    short_id = slug.strip("-")[:8].strip("-")
    if not short_id:
        short_id = uuid.uuid4().hex[:8]
    return f"{namespace}/ml-intern-{short_id}-datasets"
103
+
104
+
105
async def upload_size_bytes(upload: UploadFile) -> int:
    """Return the size of *upload* in bytes and rewind it to the start.

    The blocking seek/tell calls on the underlying file object run in a
    worker thread so the event loop is not blocked.
    """

    def _measure() -> int:
        stream = upload.file
        stream.seek(0, os.SEEK_END)
        end = stream.tell()
        stream.seek(0)
        return int(end)

    return await asyncio.to_thread(_measure)
110
+
111
+
112
async def validate_dataset_upload(upload: UploadFile) -> tuple[str, str, int]:
    """Validate an uploaded dataset file.

    Returns:
        ``(safe_filename, dataset_format, size_in_bytes)``.

    Raises:
        HTTPException: 400 for an unsupported extension or an empty file,
            413 when the file is larger than ``MAX_DATASET_UPLOAD_BYTES``.
    """
    # The extension check runs on the raw client filename, before any
    # sanitizing.
    dataset_format = dataset_format_from_filename(upload.filename or "")
    safe_filename = sanitize_dataset_filename(upload.filename)

    size = await upload_size_bytes(upload)
    if size <= 0:
        raise HTTPException(status_code=400, detail="Uploaded dataset file is empty.")
    elif size > MAX_DATASET_UPLOAD_BYTES:
        raise HTTPException(
            status_code=413,
            detail="Dataset upload exceeds the 100 MB limit.",
        )
    return safe_filename, dataset_format, size
124
+
125
+
126
def dataset_hub_url(repo_id: str, path_in_repo: str) -> str:
    """Return the Hub web URL of *path_in_repo* on the ``main`` branch.

    The path is percent-encoded; ``/`` separators are kept as-is.
    """
    base = f"https://huggingface.co/datasets/{repo_id}/blob/main/"
    return base + quote(path_in_repo, safe="/")
129
+
130
+
131
def dataset_config_name(upload_id: str) -> str:
    """Return the dataset config name for *upload_id*.

    Non-alphanumeric runs become ``_``, underscores at the ends are removed,
    and the result is lower-cased. The name is ``upload_`` followed by at
    most 32 characters of that slug, or ``upload_dataset`` when the slug is
    empty.
    """
    slug = re.sub(r"[^A-Za-z0-9]+", "_", upload_id).strip("_").lower() or "dataset"
    return "upload_" + slug[:32]
136
+
137
+
138
def dataset_config_name_from_path(path_in_repo: str) -> str:
    """Return the config name for a file stored in the dataset repo.

    For ``uploads/<upload_id>/...`` paths the name comes from ``<upload_id>``;
    for any other path it comes from the file's basename without extension.
    """
    segments = path_in_repo.split("/")
    in_upload_dir = len(segments) >= 3 and segments[0] == "uploads"
    if in_upload_dir:
        source = segments[1]
    else:
        source = os.path.splitext(os.path.basename(path_in_repo))[0]
    return dataset_config_name(source)
144
+
145
+
146
def is_dataset_upload_path(path_in_repo: str) -> bool:
    """Return True for ``uploads/<upload_id>/<file>`` paths with an allowed extension.

    The path must have exactly three non-empty ``/``-separated segments, the
    first being ``uploads``, and its extension (case-insensitive) must be in
    ``ALLOWED_DATASET_EXTENSIONS``.
    """
    segments = path_in_repo.split("/")
    if len(segments) != 3:
        return False
    root, upload_id, filename = segments
    if root != "uploads" or not upload_id or not filename:
        return False
    _, suffix = os.path.splitext(path_in_repo)
    return suffix.lower().lstrip(".") in ALLOWED_DATASET_EXTENSIONS
152
+
153
+
154
def unique_dataset_upload_paths(paths: list[str]) -> list[str]:
    """Return the dataset upload paths from *paths*, de-duplicated.

    Entries that are not valid upload paths are dropped; first-seen order is
    preserved.
    """
    # dict.fromkeys keeps insertion order and discards later duplicates.
    ordered = dict.fromkeys(path for path in paths if is_dataset_upload_path(path))
    return list(ordered)
163
+
164
+
165
def load_dataset_snippet(repo_id: str, config_name: str) -> str:
    """Return a two-statement Python snippet that loads one uploaded file.

    The snippet imports ``load_dataset`` and loads the ``train`` split of
    *config_name* from *repo_id* with ``token=True``.
    """
    import_line = "from datasets import load_dataset"
    load_line = (
        f'dataset = load_dataset("{repo_id}", "{config_name}", '
        'split="train", token=True)'
    )
    return f"{import_line}\n\n{load_line}"
171
+
172
+
173
def dataset_repo_card(repo_id: str, upload_paths: list[str]) -> bytes:
    """Render the dataset repo's ``README.md`` as UTF-8 bytes.

    The YAML front matter carries the ``ml-intern`` / ``uploaded-dataset``
    tags and one ``configs`` entry per valid upload path. Each entry maps a
    config name to that single file as its ``train`` split.

    Args:
        repo_id: Repo id shown as the card title.
        upload_paths: Candidate file paths; entries that are not valid upload
            paths and duplicates are ignored.
    """
    import json  # local import: used only to quote YAML scalars below

    config_lines: list[str] = []
    unique_upload_paths = unique_dataset_upload_paths(upload_paths)
    if unique_upload_paths:
        config_lines.append("configs:")
    for path in unique_upload_paths:
        config_lines.extend(
            [
                f"- config_name: {dataset_config_name_from_path(path)}",
                "  data_files:",
                "  - split: train",
                # A JSON string is a valid YAML double-quoted scalar, so
                # quotes and backslashes in a Hub-listed path cannot break
                # the front matter. Plain paths render exactly as before.
                f"    path: {json.dumps(path, ensure_ascii=False)}",
            ]
        )

    configs = "\n".join(config_lines)
    if configs:
        configs = f"{configs}\n"

    content = f"""---
tags:
- ml-intern
- uploaded-dataset
{configs}---

# {repo_id}

Private dataset files uploaded through ML Intern.

Files are stored under `uploads/<upload_id>/` and are attached to the
corresponding ML Intern session context by Hub reference, not by copying file
contents into the chat.

Each uploaded file is exposed as its own dataset config so files with different
schemas can coexist in the same session repo.
"""
    return content.encode("utf-8")
210
+
211
+
212
def dataset_context_note(upload: DatasetUpload) -> str:
    """Return the bracketed ``[SYSTEM: ...]`` note describing *upload*.

    The note lists the Hub reference for the uploaded file (repo, config,
    path, filenames, format, size, URL) followed by a fenced Python snippet
    for loading it.
    """
    lines = [
        "[SYSTEM: The user uploaded a dataset file for this session.",
        "",
        "Use this Hugging Face Hub dataset reference when the task needs the uploaded data.",
        "Do not look for the uploaded file on local disk and do not ask the user to",
        "upload it again unless this Hub reference fails.",
        "",
        f"- Repo ID: {upload.repo_id}",
        "- Repo type: dataset",
        f"- Dataset config: {upload.config_name}",
        f"- File in repo: {upload.path_in_repo}",
        f"- Original filename: {upload.original_filename}",
        f"- Stored filename: {upload.filename}",
        f"- Format: {upload.format}",
        f"- Size: {upload.size_bytes} bytes",
        f"- Hub URL: {upload.hub_url}",
        "",
        "Load it with:",
        "```python",
        f"{upload.load_dataset_snippet}",
        "```",
        "]",
    ]
    return "\n".join(lines)
234
+
235
+
236
async def push_dataset_upload_to_hub(
    *,
    upload: UploadFile,
    session_id: str,
    hf_username: str,
    hf_token: str,
) -> DatasetUpload:
    """Validate *upload* and push it to the session's private dataset repo.

    Steps, in order: validate the file, create the repo if needed and force
    it private, list the existing files, upload the new file under
    ``uploads/<upload_id>/``, then rewrite ``README.md`` so every upload is
    exposed as its own dataset config. All blocking Hub and file calls run in
    worker threads.

    Returns:
        A ``DatasetUpload`` describing where the file now lives.

    Raises:
        HTTPException: From validation or namespace checks.
        Hub client errors propagate unchanged.
    """
    safe_filename, dataset_format, size = await validate_dataset_upload(upload)
    original_filename = display_filename(upload.filename, safe_filename)

    upload_id = uuid.uuid4().hex[:12]
    config_name = dataset_config_name(upload_id)
    repo_id = session_dataset_repo_id(hf_username, session_id)
    path_in_repo = f"uploads/{upload_id}/{safe_filename}"

    api = HfApi(token=hf_token)
    repo = {"repo_id": repo_id, "repo_type": "dataset"}

    await asyncio.to_thread(api.create_repo, private=True, exist_ok=True, **repo)
    # Re-assert privacy in case the repo already existed.
    await asyncio.to_thread(api.update_repo_settings, private=True, **repo)

    existing_files = await asyncio.to_thread(api.list_repo_files, **repo)
    card_paths = unique_dataset_upload_paths([*existing_files, path_in_repo])

    await asyncio.to_thread(upload.file.seek, 0)
    payload = await asyncio.to_thread(upload.file.read)
    await asyncio.to_thread(
        api.upload_file,
        path_or_fileobj=payload,
        path_in_repo=path_in_repo,
        commit_message=f"Upload dataset file {safe_filename}",
        **repo,
    )
    await asyncio.to_thread(
        api.upload_file,
        path_or_fileobj=dataset_repo_card(repo_id, card_paths),
        path_in_repo="README.md",
        commit_message="Update ML Intern dataset upload configs",
        **repo,
    )

    return DatasetUpload(
        session_id=session_id,
        repo_id=repo_id,
        repo_type="dataset",
        private=True,
        upload_id=upload_id,
        config_name=config_name,
        filename=safe_filename,
        original_filename=original_filename,
        path_in_repo=path_in_repo,
        size_bytes=size,
        format=dataset_format,
        hub_url=dataset_hub_url(repo_id, path_in_repo),
        load_dataset_snippet=load_dataset_snippet(repo_id, config_name),
    )
backend/models.py CHANGED
@@ -1,7 +1,7 @@
1
  """Pydantic models for API requests and responses."""
2
 
3
  from enum import Enum
4
- from typing import Any
5
 
6
  from pydantic import BaseModel, Field
7
 
@@ -120,6 +120,23 @@ class SessionYoloRequest(BaseModel):
120
  cost_cap_usd: float | None = Field(default=None, ge=0)
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  class HealthResponse(BaseModel):
124
  """Health check response."""
125
 
 
1
  """Pydantic models for API requests and responses."""
2
 
3
  from enum import Enum
4
+ from typing import Any, Literal
5
 
6
  from pydantic import BaseModel, Field
7
 
 
120
  cost_cap_usd: float | None = Field(default=None, ge=0)
121
 
122
 
123
class DatasetUploadResponse(BaseModel):
    """Response for a dataset file uploaded to the Hub."""

    # Session the upload belongs to.
    session_id: str
    # Hub repo that received the file.
    repo_id: str
    # Always "dataset".
    repo_type: Literal["dataset"] = "dataset"
    # Whether the repo is private; defaults to True.
    private: bool = True
    # Identifier of this upload.
    upload_id: str
    # Dataset config name for this file.
    config_name: str
    # Stored filename.
    filename: str
    # Path of the stored file inside the repo.
    path_in_repo: str
    # File size in bytes.
    size_bytes: int
    # File format; one of the three supported extensions.
    format: Literal["csv", "json", "jsonl"]
    # Hub URL of the uploaded file.
    hub_url: str
    # Python snippet for loading the uploaded file.
    load_dataset_snippet: str
138
+
139
+
140
  class HealthResponse(BaseModel):
141
  """Health check response."""
142
 
backend/routes/agent.py CHANGED
@@ -21,10 +21,18 @@ from fastapi import (
21
  )
22
  from fastapi.exceptions import RequestValidationError
23
  from fastapi.responses import StreamingResponse
24
- from litellm import acompletion
 
25
  from pydantic import ValidationError
 
 
 
 
 
 
26
  from models import (
27
  ApprovalRequest,
 
28
  HealthResponse,
29
  LLMHealthResponse,
30
  SessionInfo,
@@ -58,6 +66,7 @@ PREMIUM_MODEL_IDS = {
58
  DEFAULT_CLAUDE_MODEL_ID,
59
  "openai/gpt-5.5",
60
  }
 
61
 
62
 
63
  def _claude_picker_model_id() -> str:
@@ -203,6 +212,63 @@ def _user_hf_token(user: dict[str, Any] | None) -> str | None:
203
  return user.get(INTERNAL_HF_TOKEN_KEY)
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  async def _check_session_access(
207
  session_id: str,
208
  user: dict[str, Any],
@@ -542,6 +608,86 @@ async def set_session_notifications(
542
  }
543
 
544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
  @router.patch("/session/{session_id}/yolo")
546
  async def set_session_yolo(
547
  session_id: str,
 
21
  )
22
  from fastapi.exceptions import RequestValidationError
23
  from fastapi.responses import StreamingResponse
24
+ from huggingface_hub.errors import HfHubHTTPError
25
+ from litellm import Message, acompletion
26
  from pydantic import ValidationError
27
+ from starlette.datastructures import FormData, UploadFile
28
+ from dataset_uploads import (
29
+ MAX_DATASET_UPLOAD_BYTES,
30
+ dataset_context_note,
31
+ push_dataset_upload_to_hub,
32
+ )
33
  from models import (
34
  ApprovalRequest,
35
+ DatasetUploadResponse,
36
  HealthResponse,
37
  LLMHealthResponse,
38
  SessionInfo,
 
66
  DEFAULT_CLAUDE_MODEL_ID,
67
  "openai/gpt-5.5",
68
  }
69
+ DATASET_UPLOAD_MULTIPART_SLACK_BYTES = 1024 * 1024
70
 
71
 
72
  def _claude_picker_model_id() -> str:
 
212
  return user.get(INTERNAL_HF_TOKEN_KEY)
213
 
214
 
215
def _reject_oversize_dataset_upload(request: Request) -> None:
    """Reject a request whose declared body size is clearly too large.

    Only the ``Content-Length`` header is consulted. A missing or
    non-integer header is let through. The limit is the dataset size cap plus
    a fixed allowance for multipart framing.

    Raises:
        HTTPException: 413 when the declared length exceeds the limit.
    """
    declared = request.headers.get("content-length")
    if declared is None:
        return
    try:
        body_bytes = int(declared)
    except (TypeError, ValueError):
        return

    limit = MAX_DATASET_UPLOAD_BYTES + DATASET_UPLOAD_MULTIPART_SLACK_BYTES
    if body_bytes <= limit:
        return
    raise HTTPException(
        status_code=413,
        detail="Dataset upload exceeds the 100 MB limit.",
    )
228
+
229
+
230
def _dataset_upload_file_from_form(form: FormData) -> UploadFile:
    """Return the single uploaded file from *form*.

    The form must contain exactly one file part, and it must be sent under
    the field name ``file``.

    Raises:
        HTTPException: 400 when there is not exactly one file part, or when
            it uses a different field name.
    """
    file_parts: list[tuple[str, UploadFile]] = []
    for key, value in form.multi_items():
        if isinstance(value, UploadFile):
            file_parts.append((key, value))

    if len(file_parts) != 1:
        raise HTTPException(
            status_code=400,
            detail="Upload exactly one dataset file.",
        )

    ((field_name, upload),) = file_parts
    if field_name != "file":
        raise HTTPException(
            status_code=400,
            detail="Missing 'file' upload field.",
        )
    return upload
248
+
249
+
250
def _dataset_upload_hub_http_exception(error: HfHubHTTPError) -> HTTPException:
    """Translate a Hub HTTP error into the ``HTTPException`` sent to the client.

    Status codes 401, 403, 404 and 429 are passed through with a specific
    message. Every other status, or a missing one, becomes a generic 502.
    """
    status_code = getattr(error.response, "status_code", None)

    passthrough = {
        401: "Hugging Face rejected the token used for the dataset upload.",
        403: "Hugging Face denied permission to create or write to the dataset repo.",
        404: "Could not find the Hugging Face namespace or dataset repo.",
        429: "Hugging Face Hub rate limit reached while uploading the dataset.",
    }
    for code, detail in passthrough.items():
        if status_code == code:
            return HTTPException(status_code=code, detail=detail)

    return HTTPException(
        status_code=502,
        detail="Hugging Face Hub upload failed. Please try again.",
    )
270
+
271
+
272
  async def _check_session_access(
273
  session_id: str,
274
  user: dict[str, Any],
 
608
  }
609
 
610
 
611
@router.post("/session/{session_id}/datasets", response_model=DatasetUploadResponse)
async def upload_session_dataset(
    session_id: str,
    request: Request,
    user: dict = Depends(get_current_user),
) -> DatasetUploadResponse:
    """Upload a CSV/JSON dataset file to a private Hub dataset for this session.

    Flow: reject obviously oversize requests, check session access and state,
    resolve a Hugging Face token, parse the multipart form, push the file to
    the Hub, append a context note to the session, and persist a snapshot.

    Raises:
        HTTPException: 404 for an unknown or inactive session, 409 while the
            agent is processing or tools await approval, 401 without a token,
            400/413 for invalid uploads, a mapped status for Hub errors, and
            502 for any other failure.
    """
    file: UploadFile | None = None
    try:
        # Cheap header check first, before any session lookup or body parsing.
        _reject_oversize_dataset_upload(request)
        agent_session = await _check_session_access(session_id, user, request)
        if not agent_session or not agent_session.is_active:
            raise HTTPException(status_code=404, detail="Session not found")
        if agent_session.is_processing:
            raise HTTPException(
                status_code=409,
                detail="Cannot upload a dataset while the agent is processing.",
            )
        if agent_session.session.pending_approval:
            raise HTTPException(
                status_code=409,
                detail="Approve or reject pending tools before uploading a dataset.",
            )

        # Token preference: request token without env fallback, then the
        # user's stored token, then the request resolver's default behaviour.
        hf_token = (
            resolve_hf_request_token(request, include_env_fallback=False)
            or _user_hf_token(user)
            or resolve_hf_request_token(request)
        )
        if not hf_token:
            raise HTTPException(
                status_code=401,
                detail="A Hugging Face token is required to upload datasets.",
            )

        # NOTE(review): in Starlette `max_part_size` appears to limit
        # non-file form fields only, not uploaded files — confirm this is the
        # intended limit. File size is enforced in validate_dataset_upload.
        form = await request.form(
            max_files=1,
            max_fields=1,
            max_part_size=MAX_DATASET_UPLOAD_BYTES,
        )
        file = _dataset_upload_file_from_form(form)
        hf_username = user.get("username") or agent_session.hf_username
        uploaded = await push_dataset_upload_to_hub(
            upload=file,
            session_id=session_id,
            hf_username=hf_username,
            hf_token=hf_token,
        )
        # Tell the agent about the file by Hub reference, as a user message.
        agent_session.session.context_manager.add_message(
            Message(role="user", content=dataset_context_note(uploaded))
        )
        await session_manager.persist_session_snapshot(agent_session)
        logger.info(
            "Uploaded dataset file %s to %s for session %s",
            uploaded.filename,
            uploaded.repo_id,
            session_id,
        )
        return DatasetUploadResponse(**uploaded.response_payload())
    except HTTPException:
        # Already a client-facing error; pass it through untouched.
        raise
    except HfHubHTTPError as e:
        logger.warning(
            "Hub rejected dataset upload for session %s: status=%s request_id=%s",
            session_id,
            getattr(e.response, "status_code", None),
            getattr(e, "request_id", None),
        )
        # Chain the cause so the Hub error stays attached to the exception.
        raise _dataset_upload_hub_http_exception(e) from e
    except Exception as e:
        logger.exception("Dataset upload failed for session %s", session_id)
        raise HTTPException(
            status_code=502,
            detail="Dataset upload failed. Please try again.",
        ) from e
    finally:
        # Release the upload's temporary file on every path.
        if file is not None:
            await file.close()
689
+
690
+
691
  @router.patch("/session/{session_id}/yolo")
692
  async def set_session_yolo(
693
  session_id: str,
configs/cli_agent_config.json CHANGED
@@ -7,6 +7,7 @@
7
  "yolo_mode": false,
8
  "confirm_cpu_jobs": true,
9
  "auto_file_upload": true,
 
10
  "messaging": {
11
  "enabled": false,
12
  "auto_event_types": ["approval_required", "error", "turn_complete"],
 
7
  "yolo_mode": false,
8
  "confirm_cpu_jobs": true,
9
  "auto_file_upload": true,
10
+ "tool_runtime": "local",
11
  "messaging": {
12
  "enabled": false,
13
  "auto_event_types": ["approval_required", "error", "turn_complete"],
frontend/src/components/Chat/ChatInput.tsx CHANGED
@@ -11,12 +11,15 @@ import {
11
  ListItemIcon,
12
  ListItemText,
13
  Chip,
 
14
  Snackbar,
 
15
  } from '@mui/material';
16
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
17
  import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
18
  import StopIcon from '@mui/icons-material/Stop';
19
- import { apiFetch } from '@/utils/api';
 
20
  import { useUserQuota } from '@/hooks/useUserQuota';
21
  import ClaudeCapDialog from '@/components/ClaudeCapDialog';
22
  import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
@@ -118,18 +121,49 @@ interface ChatInputProps {
118
  initialModelPath?: string | null;
119
  onSend: (text: string) => void;
120
  onStop?: () => void;
 
121
  isProcessing?: boolean;
122
  disabled?: boolean;
123
  placeholder?: string;
124
  }
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
127
  const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
128
  const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
129
 
130
- export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
 
 
 
 
 
 
 
 
 
 
131
  const [input, setInput] = useState('');
132
  const inputRef = useRef<HTMLTextAreaElement>(null);
 
133
  const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
134
  const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
135
  const sessionIdRef = useRef<string | undefined>(sessionId);
@@ -150,6 +184,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
150
  const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
151
  const [awaitingTopUp, setAwaitingTopUp] = useState(false);
152
  const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
 
 
 
 
 
153
  const lastSentRef = useRef<string>('');
154
 
155
  useEffect(() => {
@@ -216,12 +255,75 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
216
  }, [disabled, isProcessing]);
217
 
218
  const handleSend = useCallback(() => {
219
- if (input.trim() && !disabled) {
220
  lastSentRef.current = input;
221
  onSend(input);
222
  setInput('');
223
  }
224
- }, [input, disabled, onSend]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  // When the chat transport reports a premium-model quota 429, restore the typed
227
  // text so the user doesn't lose their message.
@@ -231,6 +333,18 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
231
  }
232
  }, [claudeQuotaExhausted]);
233
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  // Refresh the quota display whenever the session changes (user might
235
  // have started another tab that spent quota).
236
  useEffect(() => {
@@ -382,9 +496,12 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
382
  <Box
383
  className="composer"
384
  sx={{
385
- display: 'flex',
386
- gap: '10px',
387
- alignItems: 'flex-start',
 
 
 
388
  bgcolor: 'var(--composer-bg)',
389
  borderRadius: 'var(--radius-md)',
390
  p: '12px',
@@ -420,7 +537,7 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
420
  }
421
  }}
422
  sx={{
423
- flex: 1,
424
  '& .MuiInputBase-root': {
425
  p: 0,
426
  backgroundColor: 'transparent',
@@ -431,11 +548,46 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
431
  }
432
  }}
433
  />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  {isProcessing ? (
435
  <IconButton
436
  onClick={onStop}
437
  sx={{
438
- mt: 1,
 
 
439
  p: 1.5,
440
  borderRadius: '10px',
441
  color: 'var(--muted-text)',
@@ -455,9 +607,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
455
  ) : (
456
  <IconButton
457
  onClick={handleSend}
458
- disabled={disabled || !input.trim()}
459
  sx={{
460
- mt: 1,
 
 
461
  p: 1,
462
  borderRadius: '10px',
463
  color: 'var(--muted-text)',
@@ -475,6 +629,65 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
475
  </IconButton>
476
  )}
477
  </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
 
479
  {/* Powered By Badge */}
480
  <Box
 
11
  ListItemIcon,
12
  ListItemText,
13
  Chip,
14
+ LinearProgress,
15
  Snackbar,
16
+ Tooltip,
17
  } from '@mui/material';
18
  import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
19
  import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
20
  import StopIcon from '@mui/icons-material/Stop';
21
+ import AddIcon from '@mui/icons-material/Add';
22
+ import { apiFetch, apiUpload } from '@/utils/api';
23
  import { useUserQuota } from '@/hooks/useUserQuota';
24
  import ClaudeCapDialog from '@/components/ClaudeCapDialog';
25
  import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
 
121
  initialModelPath?: string | null;
122
  onSend: (text: string) => void;
123
  onStop?: () => void;
124
+ onDatasetUploaded?: () => Promise<boolean> | boolean;
125
  isProcessing?: boolean;
126
  disabled?: boolean;
127
  placeholder?: string;
128
  }
129
 
130
+ interface DatasetUploadResponse {
131
+ session_id: string;
132
+ repo_id: string;
133
+ repo_type: 'dataset';
134
+ private: true;
135
+ upload_id: string;
136
+ config_name: string;
137
+ filename: string;
138
+ path_in_repo: string;
139
+ size_bytes: number;
140
+ format: 'csv' | 'json' | 'jsonl';
141
+ hub_url: string;
142
+ load_dataset_snippet: string;
143
+ }
144
+
145
+ const MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024;
146
+ const DATASET_UPLOAD_ACCEPT = '.csv,.json,.jsonl';
147
+ const DATASET_UPLOAD_EXTENSIONS = new Set(['csv', 'json', 'jsonl']);
148
+
149
  const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
150
  const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
151
  const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
152
 
153
+ const formatBytes = (bytes: number) => {
154
+ if (bytes < 1024) return `${bytes} B`;
155
+ if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`;
156
+ return `${(bytes / (1024 * 1024)).toFixed(1)} MB`;
157
+ };
158
+
159
+ const datasetRepoUrl = (repoId: string) => (
160
+ `https://huggingface.co/datasets/${repoId.split('/').map(encodeURIComponent).join('/')}`
161
+ );
162
+
163
+ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, onDatasetUploaded, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
164
  const [input, setInput] = useState('');
165
  const inputRef = useRef<HTMLTextAreaElement>(null);
166
+ const fileInputRef = useRef<HTMLInputElement>(null);
167
  const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
168
  const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
169
  const sessionIdRef = useRef<string | undefined>(sessionId);
 
184
  const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
185
  const [awaitingTopUp, setAwaitingTopUp] = useState(false);
186
  const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
187
+ const [datasetUploadError, setDatasetUploadError] = useState<string | null>(null);
188
+ const [datasetUploadSuccess, setDatasetUploadSuccess] = useState<string | null>(null);
189
+ const [uploadedDatasets, setUploadedDatasets] = useState<DatasetUploadResponse[]>([]);
190
+ const [isUploadingDataset, setIsUploadingDataset] = useState(false);
191
+ const [datasetUploadProgress, setDatasetUploadProgress] = useState<number | null>(null);
192
  const lastSentRef = useRef<string>('');
193
 
194
  useEffect(() => {
 
255
  }, [disabled, isProcessing]);
256
 
257
  const handleSend = useCallback(() => {
258
+ if (input.trim() && !disabled && !isUploadingDataset) {
259
  lastSentRef.current = input;
260
  onSend(input);
261
  setInput('');
262
  }
263
+ }, [input, disabled, isUploadingDataset, onSend]);
264
+
265
+ const handleDatasetUploadClick = useCallback(() => {
266
+ fileInputRef.current?.click();
267
+ }, []);
268
+
269
+ const handleDatasetFileChange = useCallback(
270
+ async (event: React.ChangeEvent<HTMLInputElement>) => {
271
+ const file = event.target.files?.[0];
272
+ event.target.value = '';
273
+ if (!file) return;
274
+
275
+ if (!sessionId) {
276
+ setDatasetUploadError('Start a session before uploading a dataset.');
277
+ return;
278
+ }
279
+
280
+ const extension = file.name.split('.').pop()?.toLowerCase() || '';
281
+ if (!DATASET_UPLOAD_EXTENSIONS.has(extension)) {
282
+ setDatasetUploadError('Only CSV, JSON, and JSONL dataset files are supported.');
283
+ return;
284
+ }
285
+ if (file.size > MAX_DATASET_UPLOAD_BYTES) {
286
+ setDatasetUploadError(
287
+ `Dataset files must be 100 MB or smaller. ${file.name} is ${formatBytes(file.size)}.`
288
+ );
289
+ return;
290
+ }
291
+ if (file.size === 0) {
292
+ setDatasetUploadError('Uploaded dataset file is empty.');
293
+ return;
294
+ }
295
+
296
+ const formData = new FormData();
297
+ formData.append('file', file);
298
+ setIsUploadingDataset(true);
299
+ setDatasetUploadProgress(0);
300
+ setDatasetUploadError(null);
301
+ setDatasetUploadSuccess(null);
302
+ try {
303
+ const res = await apiUpload(`/api/session/${sessionId}/datasets`, formData, {
304
+ onProgress: ({ percent }) => {
305
+ setDatasetUploadProgress(percent !== null && percent < 100 ? percent : null);
306
+ },
307
+ });
308
+ if (!res.ok) {
309
+ setDatasetUploadError(await readApiErrorMessage(res, 'Dataset upload failed.'));
310
+ return;
311
+ }
312
+ const payload = await res.json() as DatasetUploadResponse;
313
+ setUploadedDatasets((previous) => [payload, ...previous]);
314
+ setDatasetUploadSuccess(`Uploaded ${payload.filename} to ${payload.repo_id}`);
315
+ await onDatasetUploaded?.();
316
+ } catch (error) {
317
+ setDatasetUploadError(
318
+ error instanceof Error ? error.message : 'Dataset upload failed.'
319
+ );
320
+ } finally {
321
+ setIsUploadingDataset(false);
322
+ setDatasetUploadProgress(null);
323
+ }
324
+ },
325
+ [sessionId, onDatasetUploaded],
326
+ );
327
 
328
  // When the chat transport reports a premium-model quota 429, restore the typed
329
  // text so the user doesn't lose their message.
 
333
  }
334
  }, [claudeQuotaExhausted]);
335
 
336
+ useEffect(() => {
337
+ if (!datasetUploadError) return;
338
+ const timeout = window.setTimeout(() => setDatasetUploadError(null), 7000);
339
+ return () => window.clearTimeout(timeout);
340
+ }, [datasetUploadError]);
341
+
342
+ useEffect(() => {
343
+ if (!datasetUploadSuccess) return;
344
+ const timeout = window.setTimeout(() => setDatasetUploadSuccess(null), 5000);
345
+ return () => window.clearTimeout(timeout);
346
+ }, [datasetUploadSuccess]);
347
+
348
  // Refresh the quota display whenever the session changes (user might
349
  // have started another tab that spent quota).
350
  useEffect(() => {
 
496
  <Box
497
  className="composer"
498
  sx={{
499
+ display: 'grid',
500
+ gridTemplateColumns: 'auto 1fr auto',
501
+ gridTemplateRows: 'auto auto',
502
+ columnGap: '10px',
503
+ rowGap: '4px',
504
+ alignItems: 'end',
505
  bgcolor: 'var(--composer-bg)',
506
  borderRadius: 'var(--radius-md)',
507
  p: '12px',
 
537
  }
538
  }}
539
  sx={{
540
+ gridColumn: '1 / -1',
541
  '& .MuiInputBase-root': {
542
  p: 0,
543
  backgroundColor: 'transparent',
 
548
  }
549
  }}
550
  />
551
+ <input
552
+ ref={fileInputRef}
553
+ type="file"
554
+ accept={DATASET_UPLOAD_ACCEPT}
555
+ onChange={handleDatasetFileChange}
556
+ style={{ display: 'none' }}
557
+ />
558
+ <Box sx={{ gridColumn: '1', gridRow: '2', display: 'flex' }}>
559
+ <Tooltip title="Upload dataset">
560
+ <span>
561
+ <IconButton
562
+ onClick={handleDatasetUploadClick}
563
+ disabled={disabled || isProcessing || isUploadingDataset || !sessionId}
564
+ sx={{
565
+ p: 1,
566
+ borderRadius: '50%',
567
+ color: uploadedDatasets.length ? 'var(--accent-yellow)' : 'var(--muted-text)',
568
+ transition: 'all 0.2s',
569
+ '&:hover': {
570
+ color: 'var(--accent-yellow)',
571
+ bgcolor: 'var(--hover-bg)',
572
+ },
573
+ '&.Mui-disabled': {
574
+ opacity: 0.3,
575
+ },
576
+ }}
577
+ aria-label="Upload dataset"
578
+ >
579
+ <AddIcon fontSize="small" />
580
+ </IconButton>
581
+ </span>
582
+ </Tooltip>
583
+ </Box>
584
  {isProcessing ? (
585
  <IconButton
586
  onClick={onStop}
587
  sx={{
588
+ gridColumn: '3',
589
+ gridRow: '2',
590
+ justifySelf: 'end',
591
  p: 1.5,
592
  borderRadius: '10px',
593
  color: 'var(--muted-text)',
 
607
  ) : (
608
  <IconButton
609
  onClick={handleSend}
610
+ disabled={disabled || isUploadingDataset || !input.trim()}
611
  sx={{
612
+ gridColumn: '3',
613
+ gridRow: '2',
614
+ justifySelf: 'end',
615
  p: 1,
616
  borderRadius: '10px',
617
  color: 'var(--muted-text)',
 
629
  </IconButton>
630
  )}
631
  </Box>
632
+ {isUploadingDataset && (
633
+ <Box sx={{ mt: 1, px: 0.5 }}>
634
+ <LinearProgress
635
+ variant={datasetUploadProgress === null ? 'indeterminate' : 'determinate'}
636
+ value={datasetUploadProgress ?? 0}
637
+ aria-label="Dataset upload progress"
638
+ sx={{
639
+ height: 4,
640
+ borderRadius: 999,
641
+ bgcolor: 'rgba(255,255,255,0.08)',
642
+ '& .MuiLinearProgress-bar': {
643
+ borderRadius: 999,
644
+ bgcolor: 'var(--accent-yellow)',
645
+ },
646
+ }}
647
+ />
648
+ </Box>
649
+ )}
650
+ {(datasetUploadError || datasetUploadSuccess) && (
651
+ <Box sx={{ display: 'flex', justifyContent: 'center', mt: 1 }}>
652
+ <Alert
653
+ severity={datasetUploadError ? 'error' : 'success'}
654
+ variant="filled"
655
+ onClose={() => {
656
+ setDatasetUploadError(null);
657
+ setDatasetUploadSuccess(null);
658
+ }}
659
+ sx={{ fontSize: '0.8rem', maxWidth: 520, width: '100%' }}
660
+ >
661
+ {datasetUploadError ?? datasetUploadSuccess}
662
+ </Alert>
663
+ </Box>
664
+ )}
665
+ {uploadedDatasets.length > 0 && (
666
+ <Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.75, justifyContent: 'center', mt: 1 }}>
667
+ {uploadedDatasets.map((dataset) => (
668
+ <Chip
669
+ key={dataset.upload_id}
670
+ size="small"
671
+ label={`Dataset: ${dataset.filename}`}
672
+ component="a"
673
+ href={datasetRepoUrl(dataset.repo_id)}
674
+ target="_blank"
675
+ rel="noreferrer"
676
+ clickable
677
+ sx={{
678
+ maxWidth: '100%',
679
+ bgcolor: 'rgba(255,255,255,0.08)',
680
+ color: 'var(--text)',
681
+ border: '1px solid var(--divider)',
682
+ '& .MuiChip-label': {
683
+ overflow: 'hidden',
684
+ textOverflow: 'ellipsis',
685
+ },
686
+ }}
687
+ />
688
+ ))}
689
+ </Box>
690
+ )}
691
 
692
  {/* Powered By Badge */}
693
  <Box
frontend/src/components/SessionChat.tsx CHANGED
@@ -27,7 +27,16 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
27
  const sessionMeta = sessions.find((s) => s.id === sessionId);
28
  const isExpired = sessionMeta?.expired === true;
29
 
30
- const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools } = useAgentChat({
 
 
 
 
 
 
 
 
 
31
  sessionId,
32
  isActive,
33
  onReady: () => logger.log(`Session ${sessionId} ready`),
@@ -116,6 +125,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
116
  initialModelPath={sessionMeta?.model}
117
  onSend={handleSendMessage}
118
  onStop={handleStop}
 
119
  isProcessing={busy}
120
  disabled={!isConnected || activityStatus.type === 'waiting-approval'}
121
  placeholder={
 
27
  const sessionMeta = sessions.find((s) => s.id === sessionId);
28
  const isExpired = sessionMeta?.expired === true;
29
 
30
+ const {
31
+ messages,
32
+ sendMessage,
33
+ stop,
34
+ status,
35
+ undoLastTurn,
36
+ editAndRegenerate,
37
+ approveTools,
38
+ refreshMessages,
39
+ } = useAgentChat({
40
  sessionId,
41
  isActive,
42
  onReady: () => logger.log(`Session ${sessionId} ready`),
 
125
  initialModelPath={sessionMeta?.model}
126
  onSend={handleSendMessage}
127
  onStop={handleStop}
128
+ onDatasetUploaded={refreshMessages}
129
  isProcessing={busy}
130
  disabled={!isConnected || activityStatus.type === 'waiting-approval'}
131
  placeholder={
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -804,6 +804,48 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
804
  }
805
  }, [sessionId, chat]);
806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
  return {
808
  messages: chat.messages,
809
  sendMessage: chat.sendMessage,
@@ -812,5 +854,6 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
812
  undoLastTurn,
813
  editAndRegenerate,
814
  approveTools,
 
815
  };
816
  }
 
804
  }
805
  }, [sessionId, chat]);
806
 
807
+ const refreshMessages = useCallback(async () => {
808
+ try {
809
+ const [msgsRes, infoRes] = await Promise.all([
810
+ apiFetch(`/api/session/${sessionId}/messages`),
811
+ apiFetch(`/api/session/${sessionId}`),
812
+ ]);
813
+ if (!msgsRes.ok) return false;
814
+
815
+ const data = await msgsRes.json();
816
+ if (!Array.isArray(data) || data.length === 0) return false;
817
+ saveBackendMessages(sessionId, data);
818
+
819
+ let pendingIds: Set<string> | undefined;
820
+ if (infoRes.ok) {
821
+ const info = await infoRes.json();
822
+ if (info.pending_approval && Array.isArray(info.pending_approval)) {
823
+ pendingIds = new Set(
824
+ info.pending_approval.map((t: { tool_call_id: string }) => t.tool_call_id)
825
+ );
826
+ if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
827
+ }
828
+ if (info.auto_approval) {
829
+ updateSessionYolo(sessionId, info.auto_approval);
830
+ }
831
+ }
832
+
833
+ const uiMsgs = llmMessagesToUIMessages(
834
+ data,
835
+ pendingIds,
836
+ chatActionsRef.current.messages,
837
+ );
838
+ const setMsgs = chatActionsRef.current.setMessages;
839
+ if (setMsgs && uiMsgs.length > 0) {
840
+ setMsgs(uiMsgs);
841
+ saveMessages(sessionId, uiMsgs);
842
+ }
843
+ return true;
844
+ } catch {
845
+ return false;
846
+ }
847
+ }, [sessionId, setNeedsAttention, updateSessionYolo]);
848
+
849
  return {
850
  messages: chat.messages,
851
  sendMessage: chat.sendMessage,
 
854
  undoLastTurn,
855
  editAndRegenerate,
856
  approveTools,
857
+ refreshMessages,
858
  };
859
  }
frontend/src/utils/api.ts CHANGED
@@ -7,15 +7,36 @@
7
 
8
  import { triggerLogin } from '@/hooks/useAuth';
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  /** Wrapper around fetch with credentials and common headers. */
11
  export async function apiFetch(
12
  path: string,
13
  options: RequestInit = {}
14
  ): Promise<Response> {
15
- const headers: Record<string, string> = {
16
- 'Content-Type': 'application/json',
17
- ...(options.headers as Record<string, string>),
18
- };
 
19
 
20
  const response = await fetch(path, {
21
  ...options,
@@ -23,19 +44,50 @@ export async function apiFetch(
23
  credentials: 'include', // Send cookies with every request
24
  });
25
 
26
- // Handle 401 — redirect to login
27
- if (response.status === 401) {
28
- try {
29
- const authStatus = await fetch('/auth/status', { credentials: 'include' });
30
- const data = await authStatus.json();
31
- if (data.auth_enabled) {
32
- triggerLogin();
33
- throw new Error('Authentication required — redirecting to login.');
34
- }
35
- } catch (e) {
36
- if (e instanceof Error && e.message.includes('redirecting')) throw e;
37
- }
38
- }
39
 
40
  return response;
41
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import { triggerLogin } from '@/hooks/useAuth';
9
 
10
+ export interface ApiUploadProgress {
11
+ loaded: number;
12
+ total: number | null;
13
+ percent: number | null;
14
+ }
15
+
16
+ async function handleUnauthorized(response: Response): Promise<void> {
17
+ if (response.status !== 401) return;
18
+ try {
19
+ const authStatus = await fetch('/auth/status', { credentials: 'include' });
20
+ const data = await authStatus.json();
21
+ if (data.auth_enabled) {
22
+ triggerLogin();
23
+ throw new Error('Authentication required — redirecting to login.');
24
+ }
25
+ } catch (e) {
26
+ if (e instanceof Error && e.message.includes('redirecting')) throw e;
27
+ }
28
+ }
29
+
30
  /** Wrapper around fetch with credentials and common headers. */
31
  export async function apiFetch(
32
  path: string,
33
  options: RequestInit = {}
34
  ): Promise<Response> {
35
+ const headers = new Headers(options.headers);
36
+ const isFormData = options.body instanceof FormData;
37
+ if (!isFormData && !headers.has('Content-Type')) {
38
+ headers.set('Content-Type', 'application/json');
39
+ }
40
 
41
  const response = await fetch(path, {
42
  ...options,
 
44
  credentials: 'include', // Send cookies with every request
45
  });
46
 
47
+ await handleUnauthorized(response);
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  return response;
50
+ }
51
+
52
+ function headersFromXhr(rawHeaders: string): Headers {
53
+ const headers = new Headers();
54
+ rawHeaders.trim().split(/[\r\n]+/).forEach((line) => {
55
+ const separator = line.indexOf(':');
56
+ if (separator <= 0) return;
57
+ headers.append(
58
+ line.slice(0, separator).trim(),
59
+ line.slice(separator + 1).trim(),
60
+ );
61
+ });
62
+ return headers;
63
+ }
64
+
65
+ export async function apiUpload(
66
+ path: string,
67
+ formData: FormData,
68
+ options: { onProgress?: (progress: ApiUploadProgress) => void } = {},
69
+ ): Promise<Response> {
70
+ return new Promise<Response>((resolve, reject) => {
71
+ const xhr = new XMLHttpRequest();
72
+ xhr.open('POST', path);
73
+ xhr.withCredentials = true;
74
+ xhr.upload.onprogress = (event) => {
75
+ const total = event.lengthComputable ? event.total : null;
76
+ const percent = total
77
+ ? Math.min(100, Math.round((event.loaded / total) * 100))
78
+ : null;
79
+ options.onProgress?.({ loaded: event.loaded, total, percent });
80
+ };
81
+ xhr.onerror = () => reject(new Error('Network error while uploading.'));
82
+ xhr.onabort = () => reject(new Error('Dataset upload was canceled.'));
83
+ xhr.onload = () => {
84
+ const response = new Response(xhr.responseText, {
85
+ status: xhr.status,
86
+ statusText: xhr.statusText,
87
+ headers: headersFromXhr(xhr.getAllResponseHeaders()),
88
+ });
89
+ handleUnauthorized(response).then(() => resolve(response)).catch(reject);
90
+ };
91
+ xhr.send(formData);
92
+ });
93
+ }
pyproject.toml CHANGED
@@ -28,6 +28,7 @@ dependencies = [
28
  "websockets>=13.0",
29
  "apscheduler>=3.10,<4",
30
  "pymongo>=4.17.0",
 
31
  ]
32
 
33
  [project.optional-dependencies]
 
28
  "websockets>=13.0",
29
  "apscheduler>=3.10,<4",
30
  "pymongo>=4.17.0",
31
+ "python-multipart>=0.0.20",
32
  ]
33
 
34
  [project.optional-dependencies]
tests/unit/test_cli_rendering.py CHANGED
@@ -1,10 +1,12 @@
1
  """Regression tests for interactive CLI rendering and research model routing."""
2
 
 
3
  import sys
4
  from io import StringIO
5
  from types import SimpleNamespace
6
 
7
  import pytest
 
8
 
9
  import agent.main as main_mod
10
  from agent.tools.research_tool import _get_research_model
@@ -29,6 +31,50 @@ def test_non_anthropic_research_model_is_unchanged():
29
  assert _get_research_model("openai/gpt-5.4") == "openai/gpt-5.4"
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
33
  calls: list[object] = []
34
 
@@ -52,10 +98,11 @@ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
52
 
53
 
54
  def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
55
- seen: dict[str, str | None] = {}
56
 
57
- async def fake_main(*, model=None):
58
  seen["model"] = model
 
59
 
60
  monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
61
  monkeypatch.setattr(main_mod, "main", fake_main)
@@ -63,6 +110,61 @@ def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
63
  main_mod.cli()
64
 
65
  assert seen["model"] == "openai/gpt-5.5"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  @pytest.mark.asyncio
@@ -70,9 +172,10 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
70
  class StopAfterBanner(Exception):
71
  pass
72
 
73
- def fake_banner(*, model=None, hf_user=None):
74
  assert model == "openai/gpt-5.5"
75
  assert hf_user == "tester"
 
76
  raise StopAfterBanner
77
 
78
  monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
@@ -85,9 +188,150 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
85
  lambda _path, **_kwargs: SimpleNamespace(
86
  model_name="moonshotai/Kimi-K2.6",
87
  mcpServers={},
 
88
  ),
89
  )
90
  monkeypatch.setattr(main_mod, "print_banner", fake_banner)
91
 
92
  with pytest.raises(StopAfterBanner):
93
  await main_mod.main(model="openai/gpt-5.5")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Regression tests for interactive CLI rendering and research model routing."""
2
 
3
+ import asyncio
4
  import sys
5
  from io import StringIO
6
  from types import SimpleNamespace
7
 
8
  import pytest
9
+ from rich.console import Console
10
 
11
  import agent.main as main_mod
12
  from agent.tools.research_tool import _get_research_model
 
31
  assert _get_research_model("openai/gpt-5.4") == "openai/gpt-5.4"
32
 
33
 
34
+ def test_help_output_keeps_descriptions_aligned(monkeypatch):
35
+ output = StringIO()
36
+ console = Console(
37
+ file=output,
38
+ color_system=None,
39
+ theme=terminal_display._THEME,
40
+ width=120,
41
+ )
42
+ monkeypatch.setattr(terminal_display, "_console", console)
43
+
44
+ terminal_display.print_help()
45
+
46
+ lines = [line.rstrip() for line in output.getvalue().splitlines() if line.strip()]
47
+ description_columns = []
48
+ for command, args, description in terminal_display.HELP_ROWS:
49
+ line = next(line for line in lines if command in line)
50
+ if args:
51
+ assert args in line
52
+ description_columns.append(line.index(description))
53
+
54
+ assert len(set(description_columns)) == 1
55
+
56
+
57
+ def test_help_output_recomputes_widths_from_rows():
58
+ rows = terminal_display.HELP_ROWS + (
59
+ ("/longer-command", "[longer-args]", "Synthetic help row"),
60
+ )
61
+ output = StringIO()
62
+ Console(
63
+ file=output,
64
+ color_system=None,
65
+ theme=terminal_display._THEME,
66
+ width=140,
67
+ ).print(terminal_display.format_help_text(rows))
68
+
69
+ lines = [line.rstrip() for line in output.getvalue().splitlines() if line.strip()]
70
+ description_columns = [
71
+ next(line for line in lines if command in line).index(description)
72
+ for command, _args, description in rows
73
+ ]
74
+
75
+ assert len(set(description_columns)) == 1
76
+
77
+
78
  def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
79
  calls: list[object] = []
80
 
 
98
 
99
 
100
  def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
101
+ seen: dict[str, object] = {}
102
 
103
+ async def fake_main(*, model=None, sandbox_tools=False):
104
  seen["model"] = model
105
+ seen["sandbox_tools"] = sandbox_tools
106
 
107
  monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
108
  monkeypatch.setattr(main_mod, "main", fake_main)
 
110
  main_mod.cli()
111
 
112
  assert seen["model"] == "openai/gpt-5.5"
113
+ assert seen["sandbox_tools"] is False
114
+
115
+
116
+ def test_cli_forwards_sandbox_flag_to_interactive_main(monkeypatch):
117
+ seen: dict[str, object] = {}
118
+
119
+ async def fake_main(*, model=None, sandbox_tools=False):
120
+ seen["model"] = model
121
+ seen["sandbox_tools"] = sandbox_tools
122
+
123
+ monkeypatch.setattr(sys, "argv", ["ml-intern", "--sandbox-tools"])
124
+ monkeypatch.setattr(main_mod, "main", fake_main)
125
+
126
+ main_mod.cli()
127
+
128
+ assert seen == {"model": None, "sandbox_tools": True}
129
+
130
+
131
+ def test_cli_forwards_sandbox_flag_to_headless_main(monkeypatch):
132
+ seen: dict[str, object] = {}
133
+
134
+ async def fake_headless_main(
135
+ prompt,
136
+ *,
137
+ model=None,
138
+ max_iterations=None,
139
+ stream=True,
140
+ sandbox_tools=False,
141
+ ):
142
+ seen.update(
143
+ {
144
+ "prompt": prompt,
145
+ "model": model,
146
+ "max_iterations": max_iterations,
147
+ "stream": stream,
148
+ "sandbox_tools": sandbox_tools,
149
+ }
150
+ )
151
+
152
+ monkeypatch.setattr(
153
+ sys,
154
+ "argv",
155
+ ["ml-intern", "--sandbox-tools", "--no-stream", "train a model"],
156
+ )
157
+ monkeypatch.setattr(main_mod, "headless_main", fake_headless_main)
158
+
159
+ main_mod.cli()
160
+
161
+ assert seen == {
162
+ "prompt": "train a model",
163
+ "model": None,
164
+ "max_iterations": None,
165
+ "stream": False,
166
+ "sandbox_tools": True,
167
+ }
168
 
169
 
170
  @pytest.mark.asyncio
 
172
  class StopAfterBanner(Exception):
173
  pass
174
 
175
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
176
  assert model == "openai/gpt-5.5"
177
  assert hf_user == "tester"
178
+ assert tool_runtime == "local filesystem"
179
  raise StopAfterBanner
180
 
181
  monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
 
188
  lambda _path, **_kwargs: SimpleNamespace(
189
  model_name="moonshotai/Kimi-K2.6",
190
  mcpServers={},
191
+ tool_runtime="local",
192
  ),
193
  )
194
  monkeypatch.setattr(main_mod, "print_banner", fake_banner)
195
 
196
  with pytest.raises(StopAfterBanner):
197
  await main_mod.main(model="openai/gpt-5.5")
198
+
199
+
200
+ @pytest.mark.asyncio
201
+ async def test_local_model_local_runtime_skips_hf_token_prompt(monkeypatch):
202
+ class StopAfterBanner(Exception):
203
+ pass
204
+
205
+ async def fail_prompt(_prompt_session):
206
+ raise AssertionError("local model with local tools should not prompt")
207
+
208
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
209
+ assert model == "llamacpp/model"
210
+ assert hf_user is None
211
+ assert tool_runtime == "local filesystem"
212
+ raise StopAfterBanner
213
+
214
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
215
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
216
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
217
+ monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fail_prompt)
218
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: None)
219
+ monkeypatch.setattr(
220
+ main_mod,
221
+ "load_config",
222
+ lambda _path, **_kwargs: SimpleNamespace(
223
+ model_name="llamacpp/model",
224
+ mcpServers={},
225
+ tool_runtime="local",
226
+ ),
227
+ )
228
+ monkeypatch.setattr(main_mod, "print_banner", fake_banner)
229
+
230
+ with pytest.raises(StopAfterBanner):
231
+ await main_mod.main()
232
+
233
+
234
+ @pytest.mark.asyncio
235
+ async def test_local_model_sandbox_runtime_prompts_for_hf_token(monkeypatch):
236
+ class StopAfterBanner(Exception):
237
+ pass
238
+
239
+ prompted = False
240
+
241
+ async def fake_prompt(_prompt_session):
242
+ nonlocal prompted
243
+ prompted = True
244
+ return "hf-token"
245
+
246
+ def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
247
+ assert model == "llamacpp/model"
248
+ assert hf_user == "tester"
249
+ assert tool_runtime == "HF sandbox"
250
+ raise StopAfterBanner
251
+
252
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
253
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
254
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
255
+ monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fake_prompt)
256
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
257
+ monkeypatch.setattr(
258
+ main_mod,
259
+ "load_config",
260
+ lambda _path, **_kwargs: SimpleNamespace(
261
+ model_name="llamacpp/model",
262
+ mcpServers={},
263
+ tool_runtime="local",
264
+ ),
265
+ )
266
+ monkeypatch.setattr(main_mod, "print_banner", fake_banner)
267
+
268
+ with pytest.raises(StopAfterBanner):
269
+ await main_mod.main(sandbox_tools=True)
270
+
271
+ assert prompted is True
272
+
273
+
274
+ @pytest.mark.asyncio
275
+ async def test_interactive_main_passes_sandbox_runtime_to_tool_router(monkeypatch):
276
+ class StopAfterToolRouter(Exception):
277
+ pass
278
+
279
+ seen: dict[str, object] = {}
280
+
281
+ class FakeGateway:
282
+ def __init__(self, _config):
283
+ pass
284
+
285
+ async def start(self):
286
+ pass
287
+
288
+ class FakeToolRouter:
289
+ def __init__(self, mcp_servers, *, hf_token=None, local_mode=True):
290
+ seen["mcp_servers"] = mcp_servers
291
+ seen["hf_token"] = hf_token
292
+ seen["local_mode"] = local_mode
293
+ raise StopAfterToolRouter
294
+
295
+ from agent.core import hf_router_catalog
296
+
297
+ monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
298
+ monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
299
+ monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: "hf-token")
300
+ monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
301
+ monkeypatch.setattr(main_mod, "print_banner", lambda **_kwargs: None)
302
+ monkeypatch.setattr(hf_router_catalog, "prewarm", lambda: None)
303
+ monkeypatch.setattr(
304
+ main_mod,
305
+ "load_config",
306
+ lambda _path, **_kwargs: SimpleNamespace(
307
+ model_name="llamacpp/model",
308
+ mcpServers={"server": object()},
309
+ messaging=SimpleNamespace(default_auto_destinations=lambda: []),
310
+ tool_runtime="local",
311
+ ),
312
+ )
313
+ monkeypatch.setattr(main_mod, "NotificationGateway", FakeGateway)
314
+ monkeypatch.setattr(main_mod, "ToolRouter", FakeToolRouter)
315
+
316
+ with pytest.raises(StopAfterToolRouter):
317
+ await main_mod.main(sandbox_tools=True)
318
+
319
+ assert seen["hf_token"] == "hf-token"
320
+ assert seen["local_mode"] is False
321
+
322
+
323
+ @pytest.mark.asyncio
324
+ async def test_initial_sandbox_preload_waits_before_prompt():
325
+ waited = False
326
+
327
+ async def preload():
328
+ nonlocal waited
329
+ await asyncio.sleep(0)
330
+ waited = True
331
+
332
+ task = asyncio.create_task(preload())
333
+ await main_mod._wait_for_initial_sandbox_preload(
334
+ [SimpleNamespace(sandbox_preload_task=task)]
335
+ )
336
+
337
+ assert waited is True
tests/unit/test_config.py CHANGED
@@ -1,5 +1,8 @@
1
  import json
2
 
 
 
 
3
  from agent import config as config_module
4
 
5
 
@@ -121,3 +124,35 @@ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
121
 
122
  assert not config.messaging.enabled
123
  assert config.messaging.destinations == {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
 
3
+ import pytest
4
+ from pydantic import ValidationError
5
+
6
  from agent import config as config_module
7
 
8
 
 
124
 
125
  assert not config.messaging.enabled
126
  assert config.messaging.destinations == {}
127
+
128
+
129
def test_tool_runtime_defaults_to_local(tmp_path):
    """A config file that omits ``tool_runtime`` loads with the ``"local"`` runtime."""
    base_config = tmp_path / "config.json"
    _write_json(base_config, {"model_name": "moonshotai/Kimi-K2.6"})

    loaded = config_module.load_config(str(base_config))

    assert loaded.tool_runtime == "local"
136
+
137
+
138
def test_user_config_can_set_sandbox_tool_runtime(tmp_path, monkeypatch):
    """``tool_runtime`` from the ``ML_INTERN_CLI_CONFIG`` file is applied when user defaults are included."""
    base_config = tmp_path / "config.json"
    user_defaults = tmp_path / "user-config.json"
    # Write the base config first, then the user-level override file.
    for target, payload in (
        (base_config, {"model_name": "moonshotai/Kimi-K2.6"}),
        (user_defaults, {"tool_runtime": "sandbox"}),
    ):
        _write_json(target, payload)
    monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_defaults))

    loaded = config_module.load_config(
        str(base_config), include_user_defaults=True
    )

    assert loaded.tool_runtime == "sandbox"
148
+
149
+
150
def test_invalid_tool_runtime_is_rejected(tmp_path):
    """Loading a config whose ``tool_runtime`` is ``"hybrid"`` raises ``ValidationError``."""
    bad_config = tmp_path / "config.json"
    payload = {"model_name": "moonshotai/Kimi-K2.6", "tool_runtime": "hybrid"}
    _write_json(bad_config, payload)

    with pytest.raises(ValidationError):
        config_module.load_config(str(bad_config))
tests/unit/test_dataset_uploads.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import sys
3
+ from pathlib import Path
4
+ from types import SimpleNamespace
5
+
6
+ import httpx
7
+ import pytest
8
+ from fastapi import HTTPException, UploadFile
9
+ from huggingface_hub.errors import HfHubHTTPError
10
+ from starlette.datastructures import FormData
11
+
12
+ _BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
13
+ if str(_BACKEND_DIR) not in sys.path:
14
+ sys.path.insert(0, str(_BACKEND_DIR))
15
+
16
+ import dataset_uploads # noqa: E402
17
+ from routes import agent # noqa: E402
18
+
19
+
20
def _upload(filename: str, content: bytes = b"a,b\n1,2\n") -> UploadFile:
    """Build an in-memory ``UploadFile`` named ``filename`` that holds ``content``."""
    buffer = io.BytesIO(content)
    return UploadFile(filename=filename, file=buffer)
22
+
23
+
24
+ def _track_close(upload: UploadFile):
25
+ state = {"closed": False}
26
+ original_close = upload.close
27
+
28
+ async def close():
29
+ state["closed"] = True
30
+ await original_close()
31
+
32
+ upload.close = close
33
+ return state
34
+
35
+
36
+ def _request(
37
+ upload: UploadFile | None = None,
38
+ headers: dict[str, str] | None = None,
39
+ ):
40
+ state = {"form_called": False}
41
+
42
+ class FakeRequest:
43
+ def __init__(self):
44
+ self.headers = headers or {}
45
+ self.cookies = {}
46
+
47
+ async def form(self, **_kwargs):
48
+ state["form_called"] = True
49
+ if upload is None:
50
+ raise AssertionError("request.form() should not be called")
51
+ return FormData([("file", upload)])
52
+
53
+ return FakeRequest(), state
54
+
55
+
56
def test_sanitize_dataset_filename_strips_paths_and_unsafe_chars():
    """Path parts and unsafe characters are removed; an empty name falls back to ``dataset.csv``."""
    expectations = {
        "../../bad file (final).CSV": "bad-file-final.csv",
        "": "dataset.csv",
    }
    for raw_name, safe_name in expectations.items():
        assert dataset_uploads.sanitize_dataset_filename(raw_name) == safe_name
62
+
63
+
64
def test_dataset_format_rejects_unsupported_extension():
    """A ``.txt`` name is rejected with HTTP 400; a name with no extension is rejected too."""
    detect_format = dataset_uploads.dataset_format_from_filename

    with pytest.raises(HTTPException) as rejected:
        detect_format("notes.txt")

    assert rejected.value.status_code == 400

    with pytest.raises(HTTPException):
        detect_format("notes")
72
+
73
+
74
def test_dataset_repo_card_exposes_each_upload_as_config():
    """Each ``uploads/<id>/<file>`` path yields one config entry, even when listed twice."""
    repo_files = [
        "README.md",
        "uploads/oldabc/rows.jsonl",
        "uploads/oldabc/rows.jsonl",  # duplicate on purpose
        "uploads/newdef/table.csv",
    ]
    card_bytes = dataset_uploads.dataset_repo_card(
        "alice/ml-intern-s1-datasets", repo_files
    )
    card = card_bytes.decode("utf-8")

    expected_fragments = (
        "configs:",
        "- config_name: upload_oldabc",
        ' path: "uploads/oldabc/rows.jsonl"',
        "- config_name: upload_newdef",
        ' path: "uploads/newdef/table.csv"',
    )
    for fragment in expected_fragments:
        assert fragment in card
    assert card.count("- config_name: upload_oldabc") == 1
91
+
92
+
93
@pytest.mark.asyncio
async def test_validate_dataset_upload_rejects_size_over_limit(monkeypatch):
    """A 4-byte upload checked against a 3-byte limit is rejected with HTTP 413."""
    monkeypatch.setattr(dataset_uploads, "MAX_DATASET_UPLOAD_BYTES", 3)
    oversized = _upload("rows.csv", b"abcd")
    try:
        with pytest.raises(HTTPException) as rejected:
            await dataset_uploads.validate_dataset_upload(oversized)
    finally:
        await oversized.close()

    assert rejected.value.status_code == 413
104
+
105
+
106
@pytest.mark.asyncio
async def test_push_dataset_upload_creates_private_repo_and_uploads_file(monkeypatch):
    """Pushing an upload creates a private dataset repo, uploads the file, then the README.

    The fake Hub API records every call so the test can check the repo id,
    the privacy flags, the sanitized upload path and the regenerated README
    configs, which cover both the pre-existing and the new upload.
    """
    expected_repo = "alice/ml-intern-12345678-datasets"
    created_apis = []

    class FakeApi:
        def __init__(self, token):
            self.token = token
            self.create_calls = []
            self.settings_calls = []
            self.list_calls = []
            self.upload_calls = []
            created_apis.append(self)

        def create_repo(self, **kwargs):
            self.create_calls.append(kwargs)

        def update_repo_settings(self, **kwargs):
            self.settings_calls.append(kwargs)

        def list_repo_files(self, **kwargs):
            self.list_calls.append(kwargs)
            return [
                "README.md",
                "uploads/oldupload/old.jsonl",
                "uploads/notes.txt",
            ]

        def upload_file(self, **kwargs):
            # Only the data file carries the raw CSV bytes; the README is generated.
            if kwargs["path_in_repo"] != "README.md":
                assert kwargs["path_or_fileobj"] == b"a,b\n1,2\n"
            self.upload_calls.append(kwargs)

    monkeypatch.setattr(dataset_uploads, "HfApi", FakeApi)
    monkeypatch.setattr(
        dataset_uploads.uuid,
        "uuid4",
        lambda: SimpleNamespace(hex="feedfacecafebeef"),
    )

    upload = _upload("../Data Set.CSV")
    try:
        result = await dataset_uploads.push_dataset_upload_to_hub(
            upload=upload,
            session_id="12345678-90ab-cdef-1234-567890abcdef",
            hf_username="alice",
            hf_token="hf-token",
        )
    finally:
        await upload.close()

    api = created_apis[0]
    assert api.token == "hf-token"
    assert api.create_calls == [
        {
            "repo_id": expected_repo,
            "repo_type": "dataset",
            "private": True,
            "exist_ok": True,
        }
    ]
    assert api.settings_calls == [
        {"repo_id": expected_repo, "repo_type": "dataset", "private": True}
    ]
    assert api.list_calls == [{"repo_id": expected_repo, "repo_type": "dataset"}]

    uploaded_paths = [call["path_in_repo"] for call in api.upload_calls]
    assert uploaded_paths == ["uploads/feedfacecafe/Data-Set.csv", "README.md"]

    readme = api.upload_calls[1]["path_or_fileobj"].decode("utf-8")
    assert "- config_name: upload_oldupload" in readme
    assert ' path: "uploads/oldupload/old.jsonl"' in readme
    assert "- config_name: upload_feedfacecafe" in readme
    assert ' path: "uploads/feedfacecafe/Data-Set.csv"' in readme

    assert result.repo_id == "alice/ml-intern-12345678-datasets"
    assert result.config_name == "upload_feedfacecafe"
    assert result.format == "csv"
    assert result.load_dataset_snippet == (
        "from datasets import load_dataset\n\n"
        'dataset = load_dataset("alice/ml-intern-12345678-datasets", '
        '"upload_feedfacecafe", split="train", token=True)'
    )
196
+
197
+
198
@pytest.mark.asyncio
async def test_upload_route_requires_hf_token_without_parsing_upload(monkeypatch):
    """With no ``HF_TOKEN`` and no token in the user dict, the route answers 401 before reading the form."""
    monkeypatch.delenv("HF_TOKEN", raising=False)
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    request, request_state = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        idle = SimpleNamespace(pending_approval=None)
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=idle,
            hf_username="alice",
        )

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    user_without_token = {"user_id": "u1", "username": "alice"}

    try:
        with pytest.raises(HTTPException) as rejected:
            await agent.upload_session_dataset("s1", request, user_without_token)

        assert rejected.value.status_code == 401
        assert request_state["form_called"] is False
        assert close_state["closed"] is False
    finally:
        await upload.close()
228
+
229
+
230
@pytest.mark.asyncio
async def test_upload_route_rejects_content_length_before_parsing(monkeypatch):
    """A ``content-length`` one byte over limit-plus-slack yields 413 with no session check or form read."""
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    # One byte past the largest request the route is expected to accept.
    too_large = (
        dataset_uploads.MAX_DATASET_UPLOAD_BYTES
        + agent.DATASET_UPLOAD_MULTIPART_SLACK_BYTES
        + 1
    )
    request, request_state = _request(
        upload,
        headers={"content-length": str(too_large)},
    )

    async def fake_check_session_access(*_args, **_kwargs):
        raise AssertionError("session access should not run for oversized uploads")

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    try:
        with pytest.raises(HTTPException) as rejected:
            await agent.upload_session_dataset("s1", request, user)

        assert rejected.value.status_code == 413
        assert request_state["form_called"] is False
        assert close_state["closed"] is False
    finally:
        await upload.close()
267
+
268
+
269
@pytest.mark.asyncio
async def test_upload_route_rejects_busy_session_without_parsing_upload(monkeypatch):
    """A session that is still processing is rejected with 409 before the form is read.

    The upload is closed in a ``finally`` block so that a failing assertion
    does not leak the underlying file object.
    """
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    request, request_state = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        # ``is_processing=True`` is the "busy" condition under test.
        return SimpleNamespace(
            is_active=True,
            is_processing=True,
            session=SimpleNamespace(pending_approval=None),
            hf_username="alice",
        )

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)

    try:
        with pytest.raises(HTTPException) as exc_info:
            await agent.upload_session_dataset(
                "s1",
                request,
                {
                    "user_id": "u1",
                    "username": "alice",
                    agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
                },
            )

        assert exc_info.value.status_code == 409
        assert request_state["form_called"] is False
        assert close_state["closed"] is False
    finally:
        await upload.close()
300
+
301
+
302
@pytest.mark.asyncio
async def test_upload_route_appends_context_note_and_persists(monkeypatch):
    """A successful upload adds one ``[SYSTEM:`` user message, persists the session and closes the file."""
    upload = _upload("rows.jsonl", b'{"text":"hi"}\n')
    close_state = _track_close(upload)
    request, request_state = _request(upload)
    messages = []
    persisted = []
    context_manager = SimpleNamespace(add_message=messages.append)
    agent_session = SimpleNamespace(
        is_active=True,
        is_processing=False,
        session=SimpleNamespace(
            pending_approval=None,
            context_manager=context_manager,
        ),
        hf_username="alice",
    )
    uploaded = dataset_uploads.DatasetUpload(
        session_id="s1",
        repo_id="alice/ml-intern-s1-datasets",
        repo_type="dataset",
        private=True,
        upload_id="abc123",
        config_name="upload_abc123",
        filename="rows.jsonl",
        original_filename="rows.jsonl",
        path_in_repo="uploads/abc123/rows.jsonl",
        size_bytes=14,
        format="jsonl",
        hub_url="https://huggingface.co/datasets/alice/ml-intern-s1-datasets/blob/main/uploads/abc123/rows.jsonl",
        load_dataset_snippet='dataset = load_dataset("json")',
    )

    async def fake_check_session_access(*_args, **_kwargs):
        return agent_session

    async def fake_push_dataset_upload_to_hub(**kwargs):
        # The route must forward the parsed upload and the caller's token.
        assert kwargs["upload"] is upload
        assert kwargs["hf_token"] == "hf-token"
        return uploaded

    async def fake_persist_session_snapshot(value):
        persisted.append(value)

    for owner, attr, replacement in (
        (agent, "_check_session_access", fake_check_session_access),
        (agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub),
        (
            agent.session_manager,
            "persist_session_snapshot",
            fake_persist_session_snapshot,
        ),
    ):
        monkeypatch.setattr(owner, attr, replacement)

    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }
    response = await agent.upload_session_dataset("s1", request, user)

    assert response.repo_id == uploaded.repo_id
    assert response.config_name == uploaded.config_name
    assert response.path_in_repo == uploaded.path_in_repo
    assert len(messages) == 1
    note = messages[0]
    assert note.role == "user"
    assert note.content.startswith("[SYSTEM:")
    assert uploaded.config_name in note.content
    assert uploaded.path_in_repo in note.content
    assert persisted == [agent_session]
    assert request_state["form_called"] is True
    assert close_state["closed"] is True
376
+
377
+
378
@pytest.mark.asyncio
async def test_upload_route_closes_upload_when_hub_upload_fails(monkeypatch):
    """A generic Hub failure becomes a 502 with a fixed message, and the upload is still closed."""
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    request, request_state = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        idle = SimpleNamespace(pending_approval=None)
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=idle,
            hf_username="alice",
        )

    async def fake_push_dataset_upload_to_hub(**_kwargs):
        raise RuntimeError("hub unavailable")

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    monkeypatch.setattr(
        agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub
    )
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    with pytest.raises(HTTPException) as failure:
        await agent.upload_session_dataset("s1", request, user)

    error = failure.value
    assert error.status_code == 502
    assert error.detail == "Dataset upload failed. Please try again."
    assert request_state["form_called"] is True
    assert close_state["closed"] is True
415
+
416
+
417
@pytest.mark.asyncio
async def test_upload_route_maps_hub_permission_error_safely(monkeypatch):
    """A Hub 403 surfaces as a 403 whose detail is a fixed message, without the Hub's token text."""
    upload = _upload("rows.csv")
    close_state = _track_close(upload)
    request, request_state = _request(upload)

    async def fake_check_session_access(*_args, **_kwargs):
        idle = SimpleNamespace(pending_approval=None)
        return SimpleNamespace(
            is_active=True,
            is_processing=False,
            session=idle,
            hf_username="alice",
        )

    async def fake_push_dataset_upload_to_hub(**_kwargs):
        hub_request = httpx.Request("POST", "https://huggingface.co/api/datasets")
        forbidden = httpx.Response(
            403,
            request=hub_request,
            headers={"x-request-id": "req-123"},
        )
        # The message deliberately embeds a secret-looking token.
        raise HfHubHTTPError(
            "403 Forbidden: token hf_secret cannot write",
            response=forbidden,
            server_message="token hf_secret cannot write",
        )

    monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
    monkeypatch.setattr(
        agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub
    )
    user = {
        "user_id": "u1",
        "username": "alice",
        agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
    }

    with pytest.raises(HTTPException) as failure:
        await agent.upload_session_dataset("s1", request, user)

    error = failure.value
    assert error.status_code == 403
    assert error.detail == (
        "Hugging Face denied permission to create or write to the dataset repo."
    )
    assert "hf_secret" not in error.detail
    assert request_state["form_called"] is True
    assert close_state["closed"] is True
465
+ assert close_state["closed"] is True
tests/unit/test_hub_artifacts.py CHANGED
@@ -549,7 +549,7 @@ def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
549
  ]
550
 
551
 
552
- def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
553
  import huggingface_hub as hub
554
  from huggingface_hub import HfApi
555
 
@@ -579,6 +579,10 @@ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
579
  def fake_add_collection_item(self, **kwargs):
580
  collection_items.append(kwargs)
581
 
 
 
 
 
582
  monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
583
  monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
584
  monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
 
549
  ]
550
 
551
 
552
+ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch, tmp_path):
553
  import huggingface_hub as hub
554
  from huggingface_hub import HfApi
555
 
 
579
  def fake_add_collection_item(self, **kwargs):
580
  collection_items.append(kwargs)
581
 
582
+ monkeypatch.setenv(
583
+ "ML_INTERN_ARTIFACT_COLLECTION_CACHE",
584
+ str(tmp_path / "collection-slug.txt"),
585
+ )
586
  monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
587
  monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
588
  monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
tests/unit/test_no_tool_continuation_guard.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+
4
+ import pytest
5
+
6
+ from agent.config import Config
7
+ from agent.core import agent_loop
8
+ from agent.core.agent_loop import Handlers, LLMResult
9
+ from agent.core.session import Session
10
+ from agent.tools.plan_tool import PlanTool
11
+
12
+
13
class FakeToolRouter:
    """Minimal tool-router stand-in that advertises ``plan_tool`` and records every call."""

    def __init__(self):
        # Each entry is a ``(name, arguments, tool_call_id)`` tuple.
        self.calls = []

    def get_tool_specs_for_llm(self):
        """Return the single ``plan_tool`` function spec."""
        plan_function = {
            "name": "plan_tool",
            "description": "Update plan",
            "parameters": {"type": "object"},
        }
        return [{"type": "function", "function": plan_function}]

    async def call_tool(self, name, arguments, session=None, tool_call_id=None):
        """Record the call and report success.

        For ``plan_tool`` with a session, shallow copies of the todos are
        stored on ``session.current_plan``.
        """
        self.calls.append((name, arguments, tool_call_id))
        updates_plan = name == "plan_tool" and session is not None
        if updates_plan:
            session.current_plan = list(map(dict, arguments["todos"]))
        return "plan updated", True
34
+
35
+
36
@pytest.mark.asyncio
async def test_plan_tool_stores_session_scoped_plan():
    """Executing ``PlanTool`` stores the todos on its session and emits a ``plan_update`` event."""
    sent = []

    class FakeSession:
        current_plan = []

        async def send_event(self, event):
            sent.append(event)

    session = FakeSession()
    todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
    tool = PlanTool(session=session)

    result = await tool.execute({"todos": todos})

    assert result["isError"] is False
    assert session.current_plan == todos
    first_event = sent[0]
    assert first_event.event_type == "plan_update"
    assert first_event.data == {"plan": todos}
55
+
56
+
57
@pytest.mark.asyncio
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
    """A text-only LLM reply while plan items are unfinished triggers a continuation-guard retry.

    The scripted LLM answers three times: text only, then a ``plan_tool`` call
    that completes every todo, then a final ``"Done."``.
    """
    config = Config.model_validate(
        {"model_name": "openai/test", "save_sessions": False}
    )
    event_queue = asyncio.Queue()
    router = FakeToolRouter()
    session = Session(event_queue, config, tool_router=router, stream=False)

    script_step = "Write and smoke-test training script"
    launch_step = "Launch full training job"
    session.current_plan = [
        {"id": "1", "content": script_step, "status": "in_progress"},
        {"id": "2", "content": launch_step, "status": "pending"},
    ]
    # Arguments of the plan_tool call that marks both todos as completed.
    completed_plan_args = json.dumps(
        {
            "todos": [
                {"id": "1", "content": script_step, "status": "completed"},
                {"id": "2", "content": launch_step, "status": "completed"},
            ]
        }
    )
    calls = []

    async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
        calls.append(messages)
        turn = len(calls)
        if turn == 1:
            return LLMResult(
                content="I should keep going, but I forgot to call a tool.",
                tool_calls_acc={},
                token_count=10,
                finish_reason="stop",
            )
        if turn == 2:
            # The retry prompt must carry the guard marker.
            assert "CONTINUATION GUARD" in messages[-1].content
            plan_call = {
                "id": "call_1",
                "function": {"name": "plan_tool", "arguments": completed_plan_args},
            }
            return LLMResult(
                content=None,
                tool_calls_acc={0: plan_call},
                token_count=20,
                finish_reason="tool_calls",
            )
        return LLMResult(
            content="Done.",
            tool_calls_acc={},
            token_count=30,
            finish_reason="stop",
        )

    monkeypatch.setattr(
        agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
    )
    monkeypatch.setattr(
        agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
    )

    final = await Handlers.run_agent(session, "continue")

    assert final == "Done."
    assert len(calls) == 3
    assert router.calls[0][0] == "plan_tool"
    assert all(todo["status"] == "completed" for todo in session.current_plan)

    events = []
    while not event_queue.empty():
        events.append(await event_queue.get())
    guard_logged = any(
        event.event_type == "tool_log"
        and "text-only response" in (event.data or {}).get("log", "")
        for event in events
    )
    assert guard_logged
tests/unit/test_sandbox_auto_start.py CHANGED
@@ -1,7 +1,15 @@
 
1
  from types import SimpleNamespace
2
  from pathlib import Path
3
 
 
 
 
 
4
  from agent.core.agent_loop import _needs_approval
 
 
 
5
  from agent.tools.sandbox_tool import get_sandbox_tools
6
 
7
 
@@ -34,3 +42,102 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create():
34
  in tool_specs["sandbox_create"]
35
  )
36
  assert "started automatically for normal CPU work" in tool_specs["bash"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
  from types import SimpleNamespace
3
  from pathlib import Path
4
 
5
+ import pytest
6
+
7
+ from agent.config import Config
8
+ from agent.core import agent_loop
9
  from agent.core.agent_loop import _needs_approval
10
+ from agent.core.session import OpType
11
+ from agent.core.tools import create_builtin_tools
12
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
13
  from agent.tools.sandbox_tool import get_sandbox_tools
14
 
15
 
 
42
  in tool_specs["sandbox_create"]
43
  )
44
  assert "started automatically for normal CPU work" in tool_specs["bash"]
45
+
46
+
47
def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
    """The system prompt forbids local paths for ``hf_jobs.script`` and names the allowed sources."""
    # Read as UTF-8 explicitly so the result does not depend on the platform locale.
    prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text(encoding="utf-8")

    assert "Never pass a local machine path to hf_jobs.script" in prompt
    assert "/fsx/..." in prompt
    assert "inline Python source code" in prompt
    assert "a file already written in the session sandbox" in prompt
54
+
55
+
56
def test_prompt_and_hf_jobs_spec_require_gpu_preflight_for_gpu_jobs():
    """Both the system prompt and the ``hf_jobs`` tool description demand a GPU sandbox preflight."""
    # Read as UTF-8 explicitly so the result does not depend on the platform locale.
    prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text(encoding="utf-8")
    jobs_description = HF_JOBS_TOOL_SPEC["description"]

    assert "GPU preflight is mandatory before hf_jobs" in prompt
    assert "GPU sandbox smoke test" in prompt
    assert "If you skip GPU sandbox preflight" in prompt
    assert "you MUST create a GPU sandbox with sandbox_create first" in jobs_description
    assert "If skipped, state why before calling hf_jobs" in jobs_description
65
+
66
+
67
def test_local_tool_runtime_excludes_sandbox_create():
    """Local mode offers the file and shell tools but not ``sandbox_create``."""
    builtin_tools = create_builtin_tools(local_mode=True)
    tool_names = {tool.name for tool in builtin_tools}

    assert tool_names.issuperset({"bash", "read", "write", "edit"})
    assert "sandbox_create" not in tool_names
72
+
73
+
74
def test_sandbox_tool_runtime_includes_sandbox_create():
    """Sandbox mode offers ``sandbox_create`` alongside the file and shell tools."""
    builtin_tools = create_builtin_tools(local_mode=False)
    tool_names = {tool.name for tool in builtin_tools}

    required = {"sandbox_create", "bash", "read", "write", "edit"}
    assert tool_names.issuperset(required)
78
+
79
+
80
@pytest.mark.asyncio
async def test_cli_sandbox_runtime_preloads_and_tears_down_sandbox(monkeypatch):
    """With ``local_mode=False`` the loop preloads a sandbox on start and tears it down on shutdown.

    The submission loop runs as a background task. A ``finally`` block cancels
    and awaits it so that a failed assertion or timeout does not leave a
    pending task behind.
    """
    started = []
    torn_down = []

    class FakeToolRouter:
        tools = {}

        def get_tool_specs_for_llm(self):
            return []

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

    def fake_start_cpu_sandbox_preload(session):
        started.append(session)
        return None

    async def fake_teardown_session_sandbox(session):
        torn_down.append(session)

    monkeypatch.setattr(
        agent_loop, "start_cpu_sandbox_preload", fake_start_cpu_sandbox_preload
    )
    monkeypatch.setattr(
        agent_loop, "teardown_session_sandbox", fake_teardown_session_sandbox
    )

    submission_queue = asyncio.Queue()
    event_queue = asyncio.Queue()
    session_holder = [None]
    config = Config.model_validate(
        {"model_name": "openai/gpt-5.5", "save_sessions": False}
    )

    task = asyncio.create_task(
        agent_loop.submission_loop(
            submission_queue,
            event_queue,
            config=config,
            tool_router=FakeToolRouter(),
            session_holder=session_holder,
            hf_token="hf-token",
            user_id="tester",
            local_mode=False,
        )
    )

    try:
        ready = await asyncio.wait_for(event_queue.get(), timeout=1)
        assert ready.event_type == "ready"
        assert started == [session_holder[0]]
        assert session_holder[0].local_mode is False

        await submission_queue.put(
            SimpleNamespace(
                operation=SimpleNamespace(op_type=OpType.SHUTDOWN, data=None),
            )
        )
        await asyncio.wait_for(task, timeout=1)
    finally:
        # Only reached with a live task when something above failed.
        if not task.done():
            task.cancel()
        # Collect the task's outcome without masking the original failure.
        await asyncio.gather(task, return_exceptions=True)

    assert torn_down == [session_holder[0]]
tests/unit/test_sandbox_private_spaces.py CHANGED
@@ -13,6 +13,28 @@ def _fail_metadata_update(*args, **kwargs):
13
  raise AssertionError("sandbox creation should not update Space metadata")
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
17
  duplicate_kwargs = {}
18
  logs: list[str] = []
@@ -22,8 +44,25 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
22
  def __init__(self, token=None):
23
  self.token = token
24
 
25
- def duplicate_space(self, **kwargs):
26
- duplicate_kwargs.update(kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
29
  requested_hardware.append((space_id, hardware, sleep_time))
@@ -45,12 +84,38 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
45
 
46
  Sandbox.create(owner="alice", token="hf-token", log=logs.append)
47
 
 
48
  assert duplicate_kwargs["private"] is True
49
- assert duplicate_kwargs["hardware"] == "cpu-basic"
50
  assert requested_hardware == []
51
  assert not any("sleep time" in log for log in logs)
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
55
  runtime_calls = 0
56
 
@@ -67,7 +132,16 @@ def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
67
  def __init__(self, token=None):
68
  self.token = token
69
 
70
- def duplicate_space(self, **kwargs):
 
 
 
 
 
 
 
 
 
71
  pass
72
 
73
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
@@ -107,8 +181,25 @@ def test_sandbox_client_configures_gpu_at_duplication(monkeypatch):
107
  def __init__(self, token=None):
108
  self.token = token
109
 
110
- def duplicate_space(self, **kwargs):
111
- duplicate_kwargs.update(kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
114
  requested_hardware.append((space_id, hardware, sleep_time))
@@ -137,8 +228,9 @@ def test_sandbox_client_configures_gpu_at_duplication(monkeypatch):
137
  )
138
 
139
  assert sandbox.space_id.startswith("alice/sandbox-")
140
- assert duplicate_kwargs["hardware"] == "t4-small"
141
- assert duplicate_kwargs["sleep_time"] == 2700
 
142
  assert requested_hardware == []
143
  assert "Using duplicated Space hardware: t4-small" in logs
144
  assert "Using duplicated Space sleep time: 2700s" in logs
@@ -153,8 +245,25 @@ def test_sandbox_client_logs_cpu_sleep_time_as_hub_fixed(monkeypatch):
153
  def __init__(self, token=None):
154
  self.token = token
155
 
156
- def duplicate_space(self, **kwargs):
157
- duplicate_kwargs.update(kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
160
  requested_hardware.append((space_id, hardware, sleep_time))
@@ -180,8 +289,9 @@ def test_sandbox_client_logs_cpu_sleep_time_as_hub_fixed(monkeypatch):
180
  log=logs.append,
181
  )
182
 
183
- assert duplicate_kwargs["hardware"] == "cpu-basic"
184
- assert duplicate_kwargs["sleep_time"] == 2700
 
185
  assert requested_hardware == []
186
  assert "Using duplicated Space hardware: cpu-basic" in logs
187
  assert (
@@ -310,6 +420,71 @@ def test_ensure_sandbox_overrides_private_argument(monkeypatch):
310
  assert persisted[-1]["sandbox_status"] == "active"
311
 
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
314
  active_creates = 0
315
  max_active_creates = 0
@@ -429,7 +604,7 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
429
  space_id="alice/sandbox-cpu",
430
  url="https://huggingface.co/spaces/alice/sandbox-cpu",
431
  _owns_space=True,
432
- delete=lambda: deleted.append("alice/sandbox-cpu"),
433
  )
434
  self.sandbox_hardware = "cpu-basic"
435
  self.sandbox_preload_task = None
@@ -474,10 +649,11 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
474
 
475
  def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
476
  deleted: list[str] = []
 
477
  persisted: list[dict] = []
478
 
479
- async def fake_record_sandbox_destroy(*args, **kwargs):
480
- pass
481
 
482
  monkeypatch.setattr(
483
  telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
@@ -485,20 +661,28 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
485
 
486
  async def run():
487
  cancel_event = threading.Event()
 
488
 
489
  async def preload():
490
  await asyncio.sleep(0)
491
 
 
 
 
 
 
 
492
  session = SimpleNamespace(
493
  session_id="s1",
494
  sandbox=SimpleNamespace(
495
  space_id="alice/sandbox-12345678",
496
  _owns_space=True,
497
- delete=lambda: deleted.append("alice/sandbox-12345678"),
498
  ),
499
  sandbox_hardware="cpu-basic",
500
  sandbox_preload_task=asyncio.create_task(preload()),
501
  sandbox_preload_cancel_event=cancel_event,
 
502
  persistence_store=SimpleNamespace(
503
  update_session_fields=lambda session_id, **fields: _record_metadata(
504
  session_id, fields
@@ -507,17 +691,33 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
507
  )
508
 
509
  await sandbox_tool.teardown_session_sandbox(session)
510
- return session, cancel_event
 
 
 
 
511
 
512
  async def _record_metadata(session_id, fields):
513
  persisted.append({"session_id": session_id, **fields})
514
 
515
- session, cancel_event = asyncio.run(run())
516
 
517
  assert cancel_event.is_set()
518
  assert deleted == ["alice/sandbox-12345678"]
 
519
  assert session.sandbox is None
520
  assert session.sandbox_hardware is None
 
 
 
 
 
 
 
 
 
 
 
521
  assert persisted[-1]["session_id"] == "s1"
522
  assert persisted[-1]["sandbox_space_id"] is None
523
  assert persisted[-1]["sandbox_status"] == "destroyed"
 
13
  raise AssertionError("sandbox creation should not update Space metadata")
14
 
15
 
16
+ def _capture_duplicate_repo_call(
17
+ captured,
18
+ *,
19
+ from_id,
20
+ to_id,
21
+ repo_type,
22
+ private,
23
+ space_hardware,
24
+ space_sleep_time=None,
25
+ ):
26
+ captured.update(
27
+ {
28
+ "from_id": from_id,
29
+ "to_id": to_id,
30
+ "repo_type": repo_type,
31
+ "private": private,
32
+ "space_hardware": space_hardware,
33
+ "space_sleep_time": space_sleep_time,
34
+ }
35
+ )
36
+
37
+
38
  def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
39
  duplicate_kwargs = {}
40
  logs: list[str] = []
 
44
  def __init__(self, token=None):
45
  self.token = token
46
 
47
+ def duplicate_repo(
48
+ self,
49
+ *,
50
+ from_id,
51
+ to_id,
52
+ repo_type,
53
+ private,
54
+ space_hardware,
55
+ space_sleep_time=None,
56
+ ):
57
+ _capture_duplicate_repo_call(
58
+ duplicate_kwargs,
59
+ from_id=from_id,
60
+ to_id=to_id,
61
+ repo_type=repo_type,
62
+ private=private,
63
+ space_hardware=space_hardware,
64
+ space_sleep_time=space_sleep_time,
65
+ )
66
 
67
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
68
  requested_hardware.append((space_id, hardware, sleep_time))
 
84
 
85
  Sandbox.create(owner="alice", token="hf-token", log=logs.append)
86
 
87
+ assert duplicate_kwargs["repo_type"] == "space"
88
  assert duplicate_kwargs["private"] is True
89
+ assert duplicate_kwargs["space_hardware"] == "cpu-basic"
90
  assert requested_hardware == []
91
  assert not any("sleep time" in log for log in logs)
92
 
93
 
94
+ def test_sandbox_delete_uses_log_callback_without_stdout(monkeypatch, capsys):
95
+ deleted: list[tuple[str, str]] = []
96
+
97
+ class FakeApi:
98
+ def __init__(self, token=None):
99
+ self.token = token
100
+
101
+ def delete_repo(self, repo_id, repo_type):
102
+ deleted.append((repo_id, repo_type))
103
+
104
+ monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
105
+
106
+ sandbox = Sandbox("alice/sandbox-12345678", token="hf-token", _owns_space=True)
107
+ logs: list[str] = []
108
+
109
+ sandbox.delete(log=logs.append)
110
+
111
+ captured = capsys.readouterr()
112
+ assert captured.out == ""
113
+ assert captured.err == ""
114
+ assert deleted == [("alice/sandbox-12345678", "space")]
115
+ assert logs == ["Deleting sandbox: alice/sandbox-12345678...", "Deleted."]
116
+ assert sandbox._owns_space is False
117
+
118
+
119
  def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
120
  runtime_calls = 0
121
 
 
132
  def __init__(self, token=None):
133
  self.token = token
134
 
135
+ def duplicate_repo(
136
+ self,
137
+ *,
138
+ from_id,
139
+ to_id,
140
+ repo_type,
141
+ private,
142
+ space_hardware,
143
+ space_sleep_time=None,
144
+ ):
145
  pass
146
 
147
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
 
181
  def __init__(self, token=None):
182
  self.token = token
183
 
184
+ def duplicate_repo(
185
+ self,
186
+ *,
187
+ from_id,
188
+ to_id,
189
+ repo_type,
190
+ private,
191
+ space_hardware,
192
+ space_sleep_time=None,
193
+ ):
194
+ _capture_duplicate_repo_call(
195
+ duplicate_kwargs,
196
+ from_id=from_id,
197
+ to_id=to_id,
198
+ repo_type=repo_type,
199
+ private=private,
200
+ space_hardware=space_hardware,
201
+ space_sleep_time=space_sleep_time,
202
+ )
203
 
204
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
205
  requested_hardware.append((space_id, hardware, sleep_time))
 
228
  )
229
 
230
  assert sandbox.space_id.startswith("alice/sandbox-")
231
+ assert duplicate_kwargs["repo_type"] == "space"
232
+ assert duplicate_kwargs["space_hardware"] == "t4-small"
233
+ assert duplicate_kwargs["space_sleep_time"] == 2700
234
  assert requested_hardware == []
235
  assert "Using duplicated Space hardware: t4-small" in logs
236
  assert "Using duplicated Space sleep time: 2700s" in logs
 
245
  def __init__(self, token=None):
246
  self.token = token
247
 
248
+ def duplicate_repo(
249
+ self,
250
+ *,
251
+ from_id,
252
+ to_id,
253
+ repo_type,
254
+ private,
255
+ space_hardware,
256
+ space_sleep_time=None,
257
+ ):
258
+ _capture_duplicate_repo_call(
259
+ duplicate_kwargs,
260
+ from_id=from_id,
261
+ to_id=to_id,
262
+ repo_type=repo_type,
263
+ private=private,
264
+ space_hardware=space_hardware,
265
+ space_sleep_time=space_sleep_time,
266
+ )
267
 
268
  def request_space_hardware(self, space_id, hardware, sleep_time=None):
269
  requested_hardware.append((space_id, hardware, sleep_time))
 
289
  log=logs.append,
290
  )
291
 
292
+ assert duplicate_kwargs["repo_type"] == "space"
293
+ assert duplicate_kwargs["space_hardware"] == "cpu-basic"
294
+ assert duplicate_kwargs["space_sleep_time"] == 2700
295
  assert requested_hardware == []
296
  assert "Using duplicated Space hardware: cpu-basic" in logs
297
  assert (
 
420
  assert persisted[-1]["sandbox_status"] == "active"
421
 
422
 
423
+ def test_cancelled_sandbox_creation_logs_delete_through_tool_log(monkeypatch):
424
+ deleted: list[str] = []
425
+
426
+ class FakeSession:
427
+ def __init__(self):
428
+ self.hf_token = "hf-token"
429
+ self.sandbox = None
430
+ self.event_queue = asyncio.Queue()
431
+ self._cancelled = asyncio.Event()
432
+
433
+ async def send_event(self, event):
434
+ await self.event_queue.put(event)
435
+
436
+ def fake_create(**kwargs):
437
+ def delete(log=None):
438
+ deleted.append("alice/sandbox-12345678")
439
+ if log:
440
+ log("Deleting sandbox: alice/sandbox-12345678...")
441
+ log("Deleted.")
442
+
443
+ return SimpleNamespace(
444
+ space_id="alice/sandbox-12345678",
445
+ url="https://huggingface.co/spaces/alice/sandbox-12345678",
446
+ _owns_space=True,
447
+ delete=delete,
448
+ )
449
+
450
+ monkeypatch.setattr(Sandbox, "create", staticmethod(fake_create))
451
+
452
+ async def run():
453
+ session = FakeSession()
454
+ cancel_event = threading.Event()
455
+ cancel_event.set()
456
+
457
+ sb, error = await sandbox_tool._create_sandbox_locked(
458
+ session,
459
+ api=SimpleNamespace(),
460
+ owner="alice",
461
+ hardware="cpu-basic",
462
+ cancel_event=cancel_event,
463
+ )
464
+ await asyncio.sleep(0)
465
+ events = []
466
+ while not session.event_queue.empty():
467
+ events.append(await session.event_queue.get())
468
+ return sb, error, events
469
+
470
+ sb, error, events = asyncio.run(run())
471
+
472
+ assert sb is None
473
+ assert error == "Sandbox creation cancelled by user."
474
+ assert deleted == ["alice/sandbox-12345678"]
475
+ assert [
476
+ event.data
477
+ for event in events
478
+ if event.event_type == "tool_log"
479
+ and event.data
480
+ and event.data.get("log")
481
+ in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
482
+ ] == [
483
+ {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
484
+ {"tool": "sandbox", "log": "Deleted."},
485
+ ]
486
+
487
+
488
  def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
489
  active_creates = 0
490
  max_active_creates = 0
 
604
  space_id="alice/sandbox-cpu",
605
  url="https://huggingface.co/spaces/alice/sandbox-cpu",
606
  _owns_space=True,
607
+ delete=lambda log=None: deleted.append("alice/sandbox-cpu"),
608
  )
609
  self.sandbox_hardware = "cpu-basic"
610
  self.sandbox_preload_task = None
 
649
 
650
  def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
651
  deleted: list[str] = []
652
+ destroyed: list[str] = []
653
  persisted: list[dict] = []
654
 
655
+ async def fake_record_sandbox_destroy(session, sandbox, *args, **kwargs):
656
+ destroyed.append(sandbox.space_id)
657
 
658
  monkeypatch.setattr(
659
  telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
 
661
 
662
  async def run():
663
  cancel_event = threading.Event()
664
+ event_queue = asyncio.Queue()
665
 
666
  async def preload():
667
  await asyncio.sleep(0)
668
 
669
+ def delete(log=None):
670
+ deleted.append("alice/sandbox-12345678")
671
+ if log:
672
+ log("Deleting sandbox: alice/sandbox-12345678...")
673
+ log("Deleted.")
674
+
675
  session = SimpleNamespace(
676
  session_id="s1",
677
  sandbox=SimpleNamespace(
678
  space_id="alice/sandbox-12345678",
679
  _owns_space=True,
680
+ delete=delete,
681
  ),
682
  sandbox_hardware="cpu-basic",
683
  sandbox_preload_task=asyncio.create_task(preload()),
684
  sandbox_preload_cancel_event=cancel_event,
685
+ event_queue=event_queue,
686
  persistence_store=SimpleNamespace(
687
  update_session_fields=lambda session_id, **fields: _record_metadata(
688
  session_id, fields
 
691
  )
692
 
693
  await sandbox_tool.teardown_session_sandbox(session)
694
+ await asyncio.sleep(0)
695
+ events = []
696
+ while not event_queue.empty():
697
+ events.append(await event_queue.get())
698
+ return session, cancel_event, events
699
 
700
  async def _record_metadata(session_id, fields):
701
  persisted.append({"session_id": session_id, **fields})
702
 
703
+ session, cancel_event, events = asyncio.run(run())
704
 
705
  assert cancel_event.is_set()
706
  assert deleted == ["alice/sandbox-12345678"]
707
+ assert destroyed == ["alice/sandbox-12345678"]
708
  assert session.sandbox is None
709
  assert session.sandbox_hardware is None
710
+ assert [
711
+ event.data
712
+ for event in events
713
+ if event.event_type == "tool_log"
714
+ and event.data
715
+ and event.data.get("log")
716
+ in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
717
+ ] == [
718
+ {"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
719
+ {"tool": "sandbox", "log": "Deleted."},
720
+ ]
721
  assert persisted[-1]["session_id"] == "s1"
722
  assert persisted[-1]["sandbox_space_id"] is None
723
  assert persisted[-1]["sandbox_status"] == "destroyed"
tests/unit/test_sandbox_script_resolution.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from types import SimpleNamespace
2
+
3
+ import pytest
4
+
5
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
6
+ from agent.tools.sandbox_tool import resolve_sandbox_script
7
+
8
+
9
+ class FakeSandbox:
10
+ def __init__(self):
11
+ self.read_paths = []
12
+
13
+ def read(self, path, *, limit):
14
+ self.read_paths.append((path, limit))
15
+ return SimpleNamespace(
16
+ success=True,
17
+ output="1\tprint('training')\n2\tprint('done')",
18
+ error="",
19
+ )
20
+
21
+
22
+ @pytest.mark.asyncio
23
+ async def test_resolve_sandbox_script_accepts_bare_python_filename():
24
+ sandbox = FakeSandbox()
25
+
26
+ content, error = await resolve_sandbox_script(sandbox, "train_smollm2.py")
27
+
28
+ assert error is None
29
+ assert content == "print('training')\nprint('done')"
30
+ assert sandbox.read_paths == [("train_smollm2.py", 100_000)]
31
+
32
+
33
+ @pytest.mark.asyncio
34
+ async def test_resolve_sandbox_script_accepts_relative_python_path():
35
+ sandbox = FakeSandbox()
36
+
37
+ content, error = await resolve_sandbox_script(sandbox, "scripts/train.py")
38
+
39
+ assert error is None
40
+ assert content == "print('training')\nprint('done')"
41
+ assert sandbox.read_paths == [("scripts/train.py", 100_000)]
42
+
43
+
44
+ @pytest.mark.asyncio
45
+ @pytest.mark.parametrize(
46
+ "script",
47
+ [
48
+ "https://example.com/train.py",
49
+ "http://example.com/train.py",
50
+ "train_smollm2.py --epochs 1",
51
+ "print('hello')",
52
+ ],
53
+ )
54
+ async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
55
+ sandbox = FakeSandbox()
56
+
57
+ content, error = await resolve_sandbox_script(sandbox, script)
58
+
59
+ assert content is None
60
+ assert error is None
61
+ assert sandbox.read_paths == []
62
+
63
+
64
+ def test_hf_jobs_script_description_mentions_bare_python_filenames():
65
+ script_description = HF_JOBS_TOOL_SPEC["parameters"]["properties"]["script"][
66
+ "description"
67
+ ]
68
+
69
+ assert "bare 'train.py'" in script_description
70
+ assert "smoke-test in a GPU sandbox before submission" in script_description
tests/unit/test_session_manager_persistence.py CHANGED
@@ -207,7 +207,7 @@ async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
207
  session.sandbox = SimpleNamespace(
208
  space_id="owner/sandbox-12345678",
209
  _owns_space=True,
210
- delete=lambda: deleted.append("owner/sandbox-12345678"),
211
  )
212
  session.sandbox_hardware = "cpu-basic"
213
  session.sandbox_preload_cancel_event = preload_cancel_event
 
207
  session.sandbox = SimpleNamespace(
208
  space_id="owner/sandbox-12345678",
209
  _owns_space=True,
210
+ delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
211
  )
212
  session.sandbox_hardware = "cpu-basic"
213
  session.sandbox_preload_cancel_event = preload_cancel_event
tests/unit/test_trackio_space_ids.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
5
+ from agent.tools.sandbox_tool import SANDBOX_CREATE_TOOL_SPEC
6
+
7
+
8
+ def test_trackio_space_examples_use_hyphenated_ml_intern_prefix():
9
+ prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
10
+ tool_specs = json.dumps([HF_JOBS_TOOL_SPEC, SANDBOX_CREATE_TOOL_SPEC])
11
+ legacy_prefix = "ml" + "intern"
12
+
13
+ assert "<username>/ml-intern-<8-char-id>" in prompt
14
+ assert "<username>/ml-intern-<8char>" in tool_specs
15
+ assert legacy_prefix not in prompt
16
+ assert legacy_prefix not in tool_specs
uv.lock CHANGED
@@ -1788,6 +1788,7 @@ dependencies = [
1788
  { name = "pydantic" },
1789
  { name = "pymongo" },
1790
  { name = "python-dotenv" },
 
1791
  { name = "requests" },
1792
  { name = "rich" },
1793
  { name = "thefuzz" },
@@ -1840,6 +1841,7 @@ requires-dist = [
1840
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1841
  { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
1842
  { name = "python-dotenv", specifier = ">=1.2.1" },
 
1843
  { name = "requests", specifier = ">=2.33.0" },
1844
  { name = "rich", specifier = ">=13.0.0" },
1845
  { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },
 
1788
  { name = "pydantic" },
1789
  { name = "pymongo" },
1790
  { name = "python-dotenv" },
1791
+ { name = "python-multipart" },
1792
  { name = "requests" },
1793
  { name = "rich" },
1794
  { name = "thefuzz" },
 
1841
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1842
  { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
1843
  { name = "python-dotenv", specifier = ">=1.2.1" },
1844
+ { name = "python-multipart", specifier = ">=0.0.20" },
1845
  { name = "requests", specifier = ">=2.33.0" },
1846
  { name = "rich", specifier = ">=13.0.0" },
1847
  { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },