Deploy 2026-05-11
Browse filesCo-authored-by: OpenAI Codex <codex@openai.com>
- agent/config.py +2 -1
- agent/core/agent_loop.py +94 -1
- agent/core/session.py +2 -0
- agent/main.py +64 -12
- agent/prompts/system_prompt_v3.yaml +11 -2
- agent/tools/jobs_tool.py +7 -3
- agent/tools/plan_tool.py +8 -4
- agent/tools/sandbox_client.py +11 -8
- agent/tools/sandbox_tool.py +38 -18
- agent/utils/terminal_display.py +69 -14
- backend/dataset_uploads.py +305 -0
- backend/models.py +18 -1
- backend/routes/agent.py +147 -1
- configs/cli_agent_config.json +1 -0
- frontend/src/components/Chat/ChatInput.tsx +224 -11
- frontend/src/components/SessionChat.tsx +11 -1
- frontend/src/hooks/useAgentChat.ts +43 -0
- frontend/src/utils/api.ts +70 -18
- pyproject.toml +1 -0
- tests/unit/test_cli_rendering.py +247 -3
- tests/unit/test_config.py +35 -0
- tests/unit/test_dataset_uploads.py +465 -0
- tests/unit/test_hub_artifacts.py +5 -1
- tests/unit/test_no_tool_continuation_guard.py +147 -0
- tests/unit/test_sandbox_auto_start.py +107 -0
- tests/unit/test_sandbox_private_spaces.py +218 -18
- tests/unit/test_sandbox_script_resolution.py +70 -0
- tests/unit/test_session_manager_persistence.py +1 -1
- tests/unit/test_trackio_space_ids.py +16 -0
- uv.lock +2 -0
agent/config.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Any, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from fastmcp.mcp_config import (
|
|
@@ -46,6 +46,7 @@ class Config(BaseModel):
|
|
| 46 |
# Permission control parameters
|
| 47 |
confirm_cpu_jobs: bool = True
|
| 48 |
auto_file_upload: bool = False
|
|
|
|
| 49 |
|
| 50 |
# Reasoning effort *preference* — the ceiling the user wants. The probe
|
| 51 |
# on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
|
|
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import Any, Literal, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
from fastmcp.mcp_config import (
|
|
|
|
| 46 |
# Permission control parameters
|
| 47 |
confirm_cpu_jobs: bool = True
|
| 48 |
auto_file_upload: bool = False
|
| 49 |
+
tool_runtime: Literal["local", "sandbox"] = "local"
|
| 50 |
|
| 51 |
# Reasoning effort *preference* — the ceiling the user wants. The probe
|
| 52 |
# on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
|
agent/core/agent_loop.py
CHANGED
|
@@ -32,7 +32,11 @@ from agent.core.prompt_caching import with_prompt_caching
|
|
| 32 |
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
-
from agent.tools.sandbox_tool import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
logger = logging.getLogger(__name__)
|
| 38 |
|
|
@@ -40,6 +44,43 @@ ToolCall = ChatCompletionMessageToolCall
|
|
| 40 |
|
| 41 |
_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
|
| 42 |
_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def _malformed_tool_name(message: Message) -> str | None:
|
|
@@ -1153,6 +1194,7 @@ class Handlers:
|
|
| 1153 |
final_response = None
|
| 1154 |
errored = False
|
| 1155 |
max_iterations = session.config.max_iterations
|
|
|
|
| 1156 |
|
| 1157 |
while max_iterations == -1 or iteration < max_iterations:
|
| 1158 |
# ── Cancellation check: before LLM call ──
|
|
@@ -1301,6 +1343,51 @@ class Handlers:
|
|
| 1301 |
|
| 1302 |
# If no tool calls, add assistant message and we're done
|
| 1303 |
if not tool_calls:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
logger.debug(
|
| 1305 |
"Agent loop ending: no tool calls. "
|
| 1306 |
"finish_reason=%s, token_count=%d, "
|
|
@@ -1324,6 +1411,8 @@ class Handlers:
|
|
| 1324 |
final_response = content
|
| 1325 |
break
|
| 1326 |
|
|
|
|
|
|
|
| 1327 |
# Validate tool call args (one json.loads per call, once)
|
| 1328 |
# and split into good vs bad
|
| 1329 |
good_tools: list[tuple[ToolCall, str, dict]] = []
|
|
@@ -1940,6 +2029,8 @@ class Handlers:
|
|
| 1940 |
_ = session.save_and_upload_detached(repo_id)
|
| 1941 |
|
| 1942 |
session.is_running = False
|
|
|
|
|
|
|
| 1943 |
await session.send_event(Event(event_type="shutdown"))
|
| 1944 |
return True
|
| 1945 |
|
|
@@ -2023,6 +2114,8 @@ async def submission_loop(
|
|
| 2023 |
)
|
| 2024 |
if session_holder is not None:
|
| 2025 |
session_holder[0] = session
|
|
|
|
|
|
|
| 2026 |
logger.info("Agent loop started")
|
| 2027 |
|
| 2028 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
|
|
| 32 |
from agent.core.session import DEFAULT_SESSION_LOG_DIR, Event, OpType, Session
|
| 33 |
from agent.core.tools import ToolRouter
|
| 34 |
from agent.tools.jobs_tool import CPU_FLAVORS
|
| 35 |
+
from agent.tools.sandbox_tool import (
|
| 36 |
+
DEFAULT_CPU_SANDBOX_HARDWARE,
|
| 37 |
+
start_cpu_sandbox_preload,
|
| 38 |
+
teardown_session_sandbox,
|
| 39 |
+
)
|
| 40 |
|
| 41 |
logger = logging.getLogger(__name__)
|
| 42 |
|
|
|
|
| 44 |
|
| 45 |
_MALFORMED_TOOL_PREFIX = "ERROR: Tool call to '"
|
| 46 |
_MALFORMED_TOOL_SUFFIX = "' had malformed JSON arguments"
|
| 47 |
+
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT = 2
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _unfinished_plan_items(session: Session) -> list[dict[str, str]]:
|
| 51 |
+
plan = getattr(session, "current_plan", None) or []
|
| 52 |
+
unfinished: list[dict[str, str]] = []
|
| 53 |
+
for item in plan:
|
| 54 |
+
if not isinstance(item, dict):
|
| 55 |
+
continue
|
| 56 |
+
status = item.get("status")
|
| 57 |
+
if status in {"pending", "in_progress"}:
|
| 58 |
+
unfinished.append(item)
|
| 59 |
+
return unfinished
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _format_plan_items_for_guard(items: list[dict[str, str]], limit: int = 4) -> str:
|
| 63 |
+
formatted = []
|
| 64 |
+
for item in items[:limit]:
|
| 65 |
+
item_id = item.get("id") or "?"
|
| 66 |
+
content = item.get("content") or "(unnamed task)"
|
| 67 |
+
status = item.get("status") or "unknown"
|
| 68 |
+
formatted.append(f"{item_id}. {content} [{status}]")
|
| 69 |
+
if len(items) > limit:
|
| 70 |
+
formatted.append(f"... and {len(items) - limit} more")
|
| 71 |
+
return "; ".join(formatted)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _no_tool_incomplete_plan_prompt(items: list[dict[str, str]]) -> str:
|
| 75 |
+
summary = _format_plan_items_for_guard(items)
|
| 76 |
+
return (
|
| 77 |
+
"[SYSTEM: CONTINUATION GUARD] Your previous response ended without any "
|
| 78 |
+
"tool calls, but the task is not complete. The current plan still has "
|
| 79 |
+
f"unfinished items: {summary}. Do not return control to the user yet. "
|
| 80 |
+
"Continue from the next unfinished item and make at least one tool call "
|
| 81 |
+
"now. If you genuinely cannot continue, first use tools to inspect the "
|
| 82 |
+
"state or verify the blocker."
|
| 83 |
+
)
|
| 84 |
|
| 85 |
|
| 86 |
def _malformed_tool_name(message: Message) -> str | None:
|
|
|
|
| 1194 |
final_response = None
|
| 1195 |
errored = False
|
| 1196 |
max_iterations = session.config.max_iterations
|
| 1197 |
+
no_tool_incomplete_plan_retries = 0
|
| 1198 |
|
| 1199 |
while max_iterations == -1 or iteration < max_iterations:
|
| 1200 |
# ── Cancellation check: before LLM call ──
|
|
|
|
| 1343 |
|
| 1344 |
# If no tool calls, add assistant message and we're done
|
| 1345 |
if not tool_calls:
|
| 1346 |
+
unfinished_plan = _unfinished_plan_items(session)
|
| 1347 |
+
if (
|
| 1348 |
+
unfinished_plan
|
| 1349 |
+
and no_tool_incomplete_plan_retries
|
| 1350 |
+
< _NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT
|
| 1351 |
+
):
|
| 1352 |
+
logger.info(
|
| 1353 |
+
"No tool calls with unfinished plan; retrying agent turn "
|
| 1354 |
+
"(attempt %d/%d)",
|
| 1355 |
+
no_tool_incomplete_plan_retries + 1,
|
| 1356 |
+
_NO_TOOL_INCOMPLETE_PLAN_RETRY_LIMIT,
|
| 1357 |
+
)
|
| 1358 |
+
if content:
|
| 1359 |
+
assistant_msg = _assistant_message_from_result(
|
| 1360 |
+
llm_result,
|
| 1361 |
+
model_name=llm_params.get("model"),
|
| 1362 |
+
)
|
| 1363 |
+
session.context_manager.add_message(
|
| 1364 |
+
assistant_msg, token_count
|
| 1365 |
+
)
|
| 1366 |
+
session.context_manager.add_message(
|
| 1367 |
+
Message(
|
| 1368 |
+
role="user",
|
| 1369 |
+
content=_no_tool_incomplete_plan_prompt(
|
| 1370 |
+
unfinished_plan
|
| 1371 |
+
),
|
| 1372 |
+
)
|
| 1373 |
+
)
|
| 1374 |
+
no_tool_incomplete_plan_retries += 1
|
| 1375 |
+
await session.send_event(
|
| 1376 |
+
Event(
|
| 1377 |
+
event_type="tool_log",
|
| 1378 |
+
data={
|
| 1379 |
+
"tool": "system",
|
| 1380 |
+
"log": (
|
| 1381 |
+
"Plan still has unfinished items after a "
|
| 1382 |
+
"text-only response — retrying instead of "
|
| 1383 |
+
"returning to the prompt."
|
| 1384 |
+
),
|
| 1385 |
+
},
|
| 1386 |
+
)
|
| 1387 |
+
)
|
| 1388 |
+
iteration += 1
|
| 1389 |
+
continue
|
| 1390 |
+
|
| 1391 |
logger.debug(
|
| 1392 |
"Agent loop ending: no tool calls. "
|
| 1393 |
"finish_reason=%s, token_count=%d, "
|
|
|
|
| 1411 |
final_response = content
|
| 1412 |
break
|
| 1413 |
|
| 1414 |
+
no_tool_incomplete_plan_retries = 0
|
| 1415 |
+
|
| 1416 |
# Validate tool call args (one json.loads per call, once)
|
| 1417 |
# and split into good vs bad
|
| 1418 |
good_tools: list[tuple[ToolCall, str, dict]] = []
|
|
|
|
| 2029 |
_ = session.save_and_upload_detached(repo_id)
|
| 2030 |
|
| 2031 |
session.is_running = False
|
| 2032 |
+
if not getattr(session, "local_mode", False):
|
| 2033 |
+
await teardown_session_sandbox(session)
|
| 2034 |
await session.send_event(Event(event_type="shutdown"))
|
| 2035 |
return True
|
| 2036 |
|
|
|
|
| 2114 |
)
|
| 2115 |
if session_holder is not None:
|
| 2116 |
session_holder[0] = session
|
| 2117 |
+
if not local_mode:
|
| 2118 |
+
start_cpu_sandbox_preload(session)
|
| 2119 |
logger.info("Agent loop started")
|
| 2120 |
|
| 2121 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
agent/core/session.py
CHANGED
|
@@ -99,6 +99,7 @@ class Session:
|
|
| 99 |
self.hf_token: Optional[str] = hf_token
|
| 100 |
self.user_id: Optional[str] = user_id
|
| 101 |
self.hf_username: Optional[str] = hf_username
|
|
|
|
| 102 |
self.persistence_store = persistence_store
|
| 103 |
self.tool_router = tool_router
|
| 104 |
self.stream = stream
|
|
@@ -117,6 +118,7 @@ class Session:
|
|
| 117 |
self.session_id = session_id or str(uuid.uuid4())
|
| 118 |
self.config = config
|
| 119 |
self.is_running = True
|
|
|
|
| 120 |
self._cancelled = asyncio.Event()
|
| 121 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 122 |
self.sandbox = None
|
|
|
|
| 99 |
self.hf_token: Optional[str] = hf_token
|
| 100 |
self.user_id: Optional[str] = user_id
|
| 101 |
self.hf_username: Optional[str] = hf_username
|
| 102 |
+
self.local_mode = local_mode
|
| 103 |
self.persistence_store = persistence_store
|
| 104 |
self.tool_router = tool_router
|
| 105 |
self.stream = stream
|
|
|
|
| 118 |
self.session_id = session_id or str(uuid.uuid4())
|
| 119 |
self.config = config
|
| 120 |
self.is_running = True
|
| 121 |
+
self.current_plan: list[dict[str, str]] = []
|
| 122 |
self._cancelled = asyncio.Event()
|
| 123 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 124 |
self.sandbox = None
|
agent/main.py
CHANGED
|
@@ -59,6 +59,34 @@ CLI_CONFIG_PATH = Path(__file__).parent.parent / "configs" / "cli_agent_config.j
|
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
| 63 |
if tool_info.get("tool") != "hf_jobs":
|
| 64 |
return False
|
|
@@ -957,6 +985,7 @@ async def _handle_slash_command(
|
|
| 957 |
session = session_holder[0] if session_holder else None
|
| 958 |
print(f"Model: {config.model_name}")
|
| 959 |
print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
|
|
|
|
| 960 |
if session:
|
| 961 |
print(f"Turns: {session.turn_count}")
|
| 962 |
print(f"Context items: {len(session.context_manager.items)}")
|
|
@@ -1076,7 +1105,7 @@ async def _handle_share_traces_command(arg: str, config, session) -> None:
|
|
| 1076 |
console.print(f"[green]Dataset is now {label}.[/green] {url}")
|
| 1077 |
|
| 1078 |
|
| 1079 |
-
async def main(model: str | None = None):
|
| 1080 |
"""Interactive chat with the agent"""
|
| 1081 |
|
| 1082 |
# Clear screen
|
|
@@ -1088,16 +1117,23 @@ async def main(model: str | None = None):
|
|
| 1088 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1089 |
if model:
|
| 1090 |
config.model_name = model
|
|
|
|
|
|
|
| 1091 |
|
| 1092 |
-
# HF token — required for Hub-backed models/tools
|
|
|
|
| 1093 |
hf_token = resolve_hf_token()
|
| 1094 |
-
if not hf_token and not is_local_model_id(config.model_name):
|
| 1095 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 1096 |
|
| 1097 |
# Resolve username for banner
|
| 1098 |
hf_user = _get_hf_user(hf_token)
|
| 1099 |
|
| 1100 |
-
print_banner(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1101 |
|
| 1102 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 1103 |
# don't block on a network fetch.
|
|
@@ -1116,8 +1152,10 @@ async def main(model: str | None = None):
|
|
| 1116 |
|
| 1117 |
notification_gateway = NotificationGateway(config.messaging)
|
| 1118 |
await notification_gateway.start()
|
| 1119 |
-
# Create tool router with
|
| 1120 |
-
tool_router = ToolRouter(
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
# Session holder for interrupt/model/status access
|
| 1123 |
session_holder = [None]
|
|
@@ -1131,7 +1169,7 @@ async def main(model: str | None = None):
|
|
| 1131 |
session_holder=session_holder,
|
| 1132 |
hf_token=hf_token,
|
| 1133 |
user_id=hf_user,
|
| 1134 |
-
local_mode=
|
| 1135 |
stream=True,
|
| 1136 |
notification_gateway=notification_gateway,
|
| 1137 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
@@ -1153,6 +1191,8 @@ async def main(model: str | None = None):
|
|
| 1153 |
)
|
| 1154 |
|
| 1155 |
await ready_event.wait()
|
|
|
|
|
|
|
| 1156 |
|
| 1157 |
submission_id = [0]
|
| 1158 |
# Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
|
|
@@ -1310,6 +1350,7 @@ async def headless_main(
|
|
| 1310 |
model: str | None = None,
|
| 1311 |
max_iterations: int | None = None,
|
| 1312 |
stream: bool = True,
|
|
|
|
| 1313 |
) -> None:
|
| 1314 |
"""Run a single prompt headlessly and exit."""
|
| 1315 |
import logging
|
|
@@ -1322,11 +1363,13 @@ async def headless_main(
|
|
| 1322 |
|
| 1323 |
if model:
|
| 1324 |
config.model_name = model
|
|
|
|
|
|
|
| 1325 |
|
| 1326 |
hf_token = resolve_hf_token()
|
| 1327 |
-
if not hf_token and not is_local_model_id(config.model_name):
|
| 1328 |
print(
|
| 1329 |
-
"ERROR: No HF token found. Set HF_TOKEN or run `
|
| 1330 |
file=sys.stderr,
|
| 1331 |
)
|
| 1332 |
sys.exit(1)
|
|
@@ -1342,6 +1385,7 @@ async def headless_main(
|
|
| 1342 |
config.max_iterations = max_iterations
|
| 1343 |
|
| 1344 |
print(f"Model: {config.model_name}", file=sys.stderr)
|
|
|
|
| 1345 |
print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
|
| 1346 |
print(f"Prompt: {prompt}", file=sys.stderr)
|
| 1347 |
print("---", file=sys.stderr)
|
|
@@ -1349,7 +1393,9 @@ async def headless_main(
|
|
| 1349 |
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 1350 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 1351 |
|
| 1352 |
-
tool_router = ToolRouter(
|
|
|
|
|
|
|
| 1353 |
session_holder: list = [None]
|
| 1354 |
|
| 1355 |
agent_task = asyncio.create_task(
|
|
@@ -1361,7 +1407,7 @@ async def headless_main(
|
|
| 1361 |
session_holder=session_holder,
|
| 1362 |
hf_token=hf_token,
|
| 1363 |
user_id=hf_user,
|
| 1364 |
-
local_mode=
|
| 1365 |
stream=stream,
|
| 1366 |
notification_gateway=notification_gateway,
|
| 1367 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
@@ -1556,6 +1602,11 @@ def cli():
|
|
| 1556 |
action="store_true",
|
| 1557 |
help="Disable token streaming (use non-streaming LLM calls)",
|
| 1558 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1559 |
args = parser.parse_args()
|
| 1560 |
|
| 1561 |
try:
|
|
@@ -1569,10 +1620,11 @@ def cli():
|
|
| 1569 |
model=args.model,
|
| 1570 |
max_iterations=max_iter,
|
| 1571 |
stream=not args.no_stream,
|
|
|
|
| 1572 |
)
|
| 1573 |
)
|
| 1574 |
else:
|
| 1575 |
-
asyncio.run(main(model=args.model))
|
| 1576 |
except KeyboardInterrupt:
|
| 1577 |
print("\n\nGoodbye!")
|
| 1578 |
|
|
|
|
| 59 |
logger = logging.getLogger(__name__)
|
| 60 |
|
| 61 |
|
| 62 |
+
def _apply_tool_runtime_override(config: Any, *, sandbox_tools: bool) -> str:
|
| 63 |
+
if sandbox_tools:
|
| 64 |
+
config.tool_runtime = "sandbox"
|
| 65 |
+
return getattr(config, "tool_runtime", "local")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _is_local_tool_runtime(config: Any) -> bool:
|
| 69 |
+
return getattr(config, "tool_runtime", "local") == "local"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _tool_runtime_label(local_mode: bool) -> str:
|
| 73 |
+
return "local filesystem" if local_mode else "HF sandbox"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
async def _wait_for_initial_sandbox_preload(session_holder: list | None) -> None:
|
| 77 |
+
session = session_holder[0] if session_holder else None
|
| 78 |
+
task = getattr(session, "sandbox_preload_task", None)
|
| 79 |
+
if not task:
|
| 80 |
+
return
|
| 81 |
+
try:
|
| 82 |
+
await asyncio.shield(task)
|
| 83 |
+
except asyncio.CancelledError:
|
| 84 |
+
raise
|
| 85 |
+
except Exception:
|
| 86 |
+
# The sandbox tool will surface the stored preload error on first use.
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
|
| 90 |
def _is_scheduled_hf_job_tool(tool_info: dict[str, Any]) -> bool:
|
| 91 |
if tool_info.get("tool") != "hf_jobs":
|
| 92 |
return False
|
|
|
|
| 985 |
session = session_holder[0] if session_holder else None
|
| 986 |
print(f"Model: {config.model_name}")
|
| 987 |
print(f"Reasoning effort: {config.reasoning_effort or 'off'}")
|
| 988 |
+
print(f"Tool runtime: {_tool_runtime_label(_is_local_tool_runtime(config))}")
|
| 989 |
if session:
|
| 990 |
print(f"Turns: {session.turn_count}")
|
| 991 |
print(f"Context items: {len(session.context_manager.items)}")
|
|
|
|
| 1105 |
console.print(f"[green]Dataset is now {label}.[/green] {url}")
|
| 1106 |
|
| 1107 |
|
| 1108 |
+
async def main(model: str | None = None, sandbox_tools: bool = False):
|
| 1109 |
"""Interactive chat with the agent"""
|
| 1110 |
|
| 1111 |
# Clear screen
|
|
|
|
| 1117 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1118 |
if model:
|
| 1119 |
config.model_name = model
|
| 1120 |
+
_apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
|
| 1121 |
+
local_mode = _is_local_tool_runtime(config)
|
| 1122 |
|
| 1123 |
+
# HF token — required for Hub-backed models/tools and sandbox tools, but
|
| 1124 |
+
# not for local LLMs using only local filesystem tools.
|
| 1125 |
hf_token = resolve_hf_token()
|
| 1126 |
+
if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
|
| 1127 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 1128 |
|
| 1129 |
# Resolve username for banner
|
| 1130 |
hf_user = _get_hf_user(hf_token)
|
| 1131 |
|
| 1132 |
+
print_banner(
|
| 1133 |
+
model=config.model_name,
|
| 1134 |
+
hf_user=hf_user,
|
| 1135 |
+
tool_runtime=_tool_runtime_label(local_mode),
|
| 1136 |
+
)
|
| 1137 |
|
| 1138 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 1139 |
# don't block on a network fetch.
|
|
|
|
| 1152 |
|
| 1153 |
notification_gateway = NotificationGateway(config.messaging)
|
| 1154 |
await notification_gateway.start()
|
| 1155 |
+
# Create tool router with the selected CLI tool runtime.
|
| 1156 |
+
tool_router = ToolRouter(
|
| 1157 |
+
config.mcpServers, hf_token=hf_token, local_mode=local_mode
|
| 1158 |
+
)
|
| 1159 |
|
| 1160 |
# Session holder for interrupt/model/status access
|
| 1161 |
session_holder = [None]
|
|
|
|
| 1169 |
session_holder=session_holder,
|
| 1170 |
hf_token=hf_token,
|
| 1171 |
user_id=hf_user,
|
| 1172 |
+
local_mode=local_mode,
|
| 1173 |
stream=True,
|
| 1174 |
notification_gateway=notification_gateway,
|
| 1175 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
|
|
| 1191 |
)
|
| 1192 |
|
| 1193 |
await ready_event.wait()
|
| 1194 |
+
if not local_mode:
|
| 1195 |
+
await _wait_for_initial_sandbox_preload(session_holder)
|
| 1196 |
|
| 1197 |
submission_id = [0]
|
| 1198 |
# Mirrors codex-rs/tui/src/bottom_pane/mod.rs:137
|
|
|
|
| 1350 |
model: str | None = None,
|
| 1351 |
max_iterations: int | None = None,
|
| 1352 |
stream: bool = True,
|
| 1353 |
+
sandbox_tools: bool = False,
|
| 1354 |
) -> None:
|
| 1355 |
"""Run a single prompt headlessly and exit."""
|
| 1356 |
import logging
|
|
|
|
| 1363 |
|
| 1364 |
if model:
|
| 1365 |
config.model_name = model
|
| 1366 |
+
_apply_tool_runtime_override(config, sandbox_tools=sandbox_tools)
|
| 1367 |
+
local_mode = _is_local_tool_runtime(config)
|
| 1368 |
|
| 1369 |
hf_token = resolve_hf_token()
|
| 1370 |
+
if not hf_token and (not is_local_model_id(config.model_name) or not local_mode):
|
| 1371 |
print(
|
| 1372 |
+
"ERROR: No HF token found. Set HF_TOKEN or run `hf auth login`.",
|
| 1373 |
file=sys.stderr,
|
| 1374 |
)
|
| 1375 |
sys.exit(1)
|
|
|
|
| 1385 |
config.max_iterations = max_iterations
|
| 1386 |
|
| 1387 |
print(f"Model: {config.model_name}", file=sys.stderr)
|
| 1388 |
+
print(f"Tool runtime: {_tool_runtime_label(local_mode)}", file=sys.stderr)
|
| 1389 |
print(f"Max iterations: {config.max_iterations}", file=sys.stderr)
|
| 1390 |
print(f"Prompt: {prompt}", file=sys.stderr)
|
| 1391 |
print("---", file=sys.stderr)
|
|
|
|
| 1393 |
submission_queue: asyncio.Queue = asyncio.Queue()
|
| 1394 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 1395 |
|
| 1396 |
+
tool_router = ToolRouter(
|
| 1397 |
+
config.mcpServers, hf_token=hf_token, local_mode=local_mode
|
| 1398 |
+
)
|
| 1399 |
session_holder: list = [None]
|
| 1400 |
|
| 1401 |
agent_task = asyncio.create_task(
|
|
|
|
| 1407 |
session_holder=session_holder,
|
| 1408 |
hf_token=hf_token,
|
| 1409 |
user_id=hf_user,
|
| 1410 |
+
local_mode=local_mode,
|
| 1411 |
stream=stream,
|
| 1412 |
notification_gateway=notification_gateway,
|
| 1413 |
notification_destinations=config.messaging.default_auto_destinations(),
|
|
|
|
| 1602 |
action="store_true",
|
| 1603 |
help="Disable token streaming (use non-streaming LLM calls)",
|
| 1604 |
)
|
| 1605 |
+
parser.add_argument(
|
| 1606 |
+
"--sandbox-tools",
|
| 1607 |
+
action="store_true",
|
| 1608 |
+
help="Use HF Space sandbox tools instead of local filesystem tools",
|
| 1609 |
+
)
|
| 1610 |
args = parser.parse_args()
|
| 1611 |
|
| 1612 |
try:
|
|
|
|
| 1620 |
model=args.model,
|
| 1621 |
max_iterations=max_iter,
|
| 1622 |
stream=not args.no_stream,
|
| 1623 |
+
sandbox_tools=args.sandbox_tools,
|
| 1624 |
)
|
| 1625 |
)
|
| 1626 |
else:
|
| 1627 |
+
asyncio.run(main(model=args.model, sandbox_tools=args.sandbox_tools))
|
| 1628 |
except KeyboardInterrupt:
|
| 1629 |
print("\n\nGoodbye!")
|
| 1630 |
|
agent/prompts/system_prompt_v3.yaml
CHANGED
|
@@ -66,7 +66,7 @@ system_prompt: |
|
|
| 66 |
report_to="trackio"
|
| 67 |
run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
|
| 68 |
project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
|
| 69 |
-
trackio_space_id="<username>/
|
| 70 |
`project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
|
| 71 |
|
| 72 |
Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
|
|
@@ -102,9 +102,18 @@ system_prompt: |
|
|
| 102 |
|
| 103 |
# When submitting a training job
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
Before calling hf_jobs, output a pre-flight check:
|
| 106 |
- Reference implementation: [which example you based this on]
|
| 107 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
|
|
|
| 108 |
- push_to_hub=True and hub_model_id set
|
| 109 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 110 |
- Trackio monitoring included and deploying metrics to a public Space
|
|
@@ -127,7 +136,7 @@ system_prompt: |
|
|
| 127 |
|
| 128 |
Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
|
| 129 |
|
| 130 |
-
Use GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16, or model loading. CPU sandboxes cannot test GPU code paths.
|
| 131 |
|
| 132 |
|
| 133 |
# When a task has 3+ steps
|
|
|
|
| 66 |
report_to="trackio"
|
| 67 |
run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
|
| 68 |
project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
|
| 69 |
+
trackio_space_id="<username>/ml-intern-<8-char-id>" # creates a public dashboard Space
|
| 70 |
`project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
|
| 71 |
|
| 72 |
Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
|
|
|
|
| 102 |
|
| 103 |
# When submitting a training job
|
| 104 |
|
| 105 |
+
Never pass a local machine path to hf_jobs.script, such as /Users/..., /home/..., /fsx/..., or a repo checkout path. HF Jobs runs in a fresh cloud environment where local files do not exist. For hf_jobs.script, use exactly one of:
|
| 106 |
+
- inline Python source code
|
| 107 |
+
- a file already written in the session sandbox, e.g. /app/train.py, ./train.py, or train.py
|
| 108 |
+
- a public/raw URL
|
| 109 |
+
If you wrote or tested a script locally, read the file content and submit it inline, or write it into the sandbox first.
|
| 110 |
+
|
| 111 |
+
GPU preflight is mandatory before hf_jobs when the job will run on GPU, or when the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, or torch.compile. First create a GPU sandbox with sandbox_create (t4-small minimum; choose larger hardware when VRAM requires it), run a tiny smoke test there using the same imports, model-loading path, training entrypoint, and a tiny dataset/subset, then fix failures before submitting. If you skip GPU sandbox preflight, state why before calling hf_jobs.
|
| 112 |
+
|
| 113 |
Before calling hf_jobs, output a pre-flight check:
|
| 114 |
- Reference implementation: [which example you based this on]
|
| 115 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
| 116 |
+
- GPU sandbox smoke test: [hardware and result, or explicitly not applicable because ...]
|
| 117 |
- push_to_hub=True and hub_model_id set
|
| 118 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 119 |
- Trackio monitoring included and deploying metrics to a public Space
|
|
|
|
| 136 |
|
| 137 |
Do NOT call sandbox_create before normal CPU work. Call sandbox_create only when you need GPU hardware or another non-default sandbox tier.
|
| 138 |
|
| 139 |
+
Use a GPU sandbox (t4-small minimum) when testing code that uses CUDA, bf16/fp16, quantization, flash attention, torch.compile, or model loading. CPU sandboxes cannot test GPU code paths. If the available sandbox tiers cannot fit the full model path, test the largest useful smoke path, state what was not covered, and submit one HF job first.
|
| 140 |
|
| 141 |
|
| 142 |
# When a task has 3+ steps
|
agent/tools/jobs_tool.py
CHANGED
|
@@ -1112,11 +1112,14 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1112 |
"- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
|
| 1113 |
"Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
|
| 1114 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
|
|
|
|
|
|
|
|
|
| 1115 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 1116 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 1117 |
"- Include trackio monitoring and provide the dashboard URL to the user. "
|
| 1118 |
"When the script uses report_to='trackio', also pass `trackio_space_id` "
|
| 1119 |
-
"(e.g. '<username>/
|
| 1120 |
"they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
|
| 1121 |
"BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
|
| 1122 |
"Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
|
|
@@ -1157,8 +1160,9 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1157 |
"script": {
|
| 1158 |
"type": "string",
|
| 1159 |
"description": (
|
| 1160 |
-
"Python code
|
| 1161 |
"Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
|
|
|
|
| 1162 |
"Mutually exclusive with 'command'."
|
| 1163 |
),
|
| 1164 |
},
|
|
@@ -1204,7 +1208,7 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1204 |
"type": "string",
|
| 1205 |
"description": (
|
| 1206 |
"Optional. The HF Space hosting the trackio dashboard for this run "
|
| 1207 |
-
"(e.g. '<username>/
|
| 1208 |
"Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
|
| 1209 |
"the live dashboard. Set this whenever the script uses "
|
| 1210 |
"report_to='trackio'. The Space is auto-created and seeded with the "
|
|
|
|
| 1112 |
"- You MUST have called github_find_examples + github_read_file to find a working reference implementation. "
|
| 1113 |
"Scripts based on your internal knowledge WILL use outdated APIs and fail.\n"
|
| 1114 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
| 1115 |
+
"- If the job runs on GPU, or the script loads a model, uses CUDA, bf16/fp16, quantization, flash attention, "
|
| 1116 |
+
"or torch.compile, you MUST create a GPU sandbox with sandbox_create first, run a tiny smoke test there, "
|
| 1117 |
+
"and fix failures before submitting. If skipped, state why before calling hf_jobs.\n"
|
| 1118 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 1119 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 1120 |
"- Include trackio monitoring and provide the dashboard URL to the user. "
|
| 1121 |
"When the script uses report_to='trackio', also pass `trackio_space_id` "
|
| 1122 |
+
"(e.g. '<username>/ml-intern-<8char>') and `trackio_project` as tool args — "
|
| 1123 |
"they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
|
| 1124 |
"BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
|
| 1125 |
"Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
|
|
|
|
| 1160 |
"script": {
|
| 1161 |
"type": "string",
|
| 1162 |
"description": (
|
| 1163 |
+
"Python code, sandbox file path (e.g. '/app/train.py', './train.py', or bare 'train.py'), or URL. "
|
| 1164 |
"Triggers Python mode. For ML training: base this on a working example found via github_find_examples, not on internal knowledge. "
|
| 1165 |
+
"For GPU/model-loading training scripts, smoke-test in a GPU sandbox before submission. "
|
| 1166 |
"Mutually exclusive with 'command'."
|
| 1167 |
),
|
| 1168 |
},
|
|
|
|
| 1208 |
"type": "string",
|
| 1209 |
"description": (
|
| 1210 |
"Optional. The HF Space hosting the trackio dashboard for this run "
|
| 1211 |
+
"(e.g. '<username>/ml-intern-<8char>', under YOUR HF namespace). "
|
| 1212 |
"Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
|
| 1213 |
"the live dashboard. Set this whenever the script uses "
|
| 1214 |
"report_to='trackio'. The Space is auto-created and seeded with the "
|
agent/tools/plan_tool.py
CHANGED
|
@@ -54,20 +54,24 @@ class PlanTool:
|
|
| 54 |
"isError": True,
|
| 55 |
}
|
| 56 |
|
| 57 |
-
# Store the
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
# Emit plan update event if session is available
|
| 61 |
if self.session:
|
| 62 |
await self.session.send_event(
|
| 63 |
Event(
|
| 64 |
event_type="plan_update",
|
| 65 |
-
data={"plan":
|
| 66 |
)
|
| 67 |
)
|
| 68 |
|
| 69 |
# Format only for display using terminal_display utility
|
| 70 |
-
formatted_output = format_plan_tool_output(
|
| 71 |
|
| 72 |
return {
|
| 73 |
"formatted": formatted_output,
|
|
|
|
| 54 |
"isError": True,
|
| 55 |
}
|
| 56 |
|
| 57 |
+
# Store a session-scoped copy so the runtime can tell whether a
|
| 58 |
+
# text-only model response is trying to stop while work remains.
|
| 59 |
+
stored_todos = [dict(todo) for todo in todos]
|
| 60 |
+
_current_plan = stored_todos
|
| 61 |
+
if self.session is not None:
|
| 62 |
+
self.session.current_plan = stored_todos
|
| 63 |
|
| 64 |
# Emit plan update event if session is available
|
| 65 |
if self.session:
|
| 66 |
await self.session.send_event(
|
| 67 |
Event(
|
| 68 |
event_type="plan_update",
|
| 69 |
+
data={"plan": stored_todos},
|
| 70 |
)
|
| 71 |
)
|
| 72 |
|
| 73 |
# Format only for display using terminal_display utility
|
| 74 |
+
formatted_output = format_plan_tool_output(stored_todos)
|
| 75 |
|
| 76 |
return {
|
| 77 |
"formatted": formatted_output,
|
agent/tools/sandbox_client.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# /// script
|
| 3 |
# requires-python = ">=3.10"
|
| 4 |
-
# dependencies = ["huggingface_hub>=
|
| 5 |
# ///
|
| 6 |
"""
|
| 7 |
Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes.
|
|
@@ -615,18 +615,19 @@ class Sandbox:
|
|
| 615 |
kwargs = {
|
| 616 |
"from_id": template,
|
| 617 |
"to_id": space_id,
|
|
|
|
| 618 |
"private": private,
|
| 619 |
-
"
|
| 620 |
}
|
| 621 |
if sleep_time is not None:
|
| 622 |
-
kwargs["
|
| 623 |
|
| 624 |
-
api.
|
| 625 |
_log(f"Space created: https://huggingface.co/spaces/{space_id}")
|
| 626 |
|
| 627 |
_check_cancel()
|
| 628 |
|
| 629 |
-
# ``
|
| 630 |
# initial create request. Avoid a second /hardware call: deployed HF
|
| 631 |
# OAuth tokens can 401 on that endpoint for a just-created private
|
| 632 |
# Space even though duplication itself succeeded. We rely on the
|
|
@@ -775,21 +776,23 @@ class Sandbox:
|
|
| 775 |
f"Last status: {last_status}, last error: {last_err}"
|
| 776 |
)
|
| 777 |
|
| 778 |
-
def delete(self):
|
| 779 |
"""Delete the Space. Only works if this Sandbox created it."""
|
| 780 |
if not self._owns_space:
|
| 781 |
raise RuntimeError(
|
| 782 |
f"This Sandbox did not create {self.space_id}. "
|
| 783 |
f"Use self._hf_api.delete_repo() directly if you're sure."
|
| 784 |
)
|
| 785 |
-
|
|
|
|
| 786 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
| 787 |
# Clear ownership so a second cleanup call (e.g. delete_session +
|
| 788 |
# _run_session.finally both fire) early-returns instead of retrying
|
| 789 |
# a 404 delete and emitting a spurious ERROR log.
|
| 790 |
self._owns_space = False
|
| 791 |
self._client.close()
|
| 792 |
-
|
|
|
|
| 793 |
|
| 794 |
def pause(self):
|
| 795 |
"""Pause the Space (stops billing, preserves state)."""
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
# /// script
|
| 3 |
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = ["huggingface_hub>=1.12.0", "httpx>=0.27.0"]
|
| 5 |
# ///
|
| 6 |
"""
|
| 7 |
Sandbox Tools — Agent-native primitives for HF Space dev-mode sandboxes.
|
|
|
|
| 615 |
kwargs = {
|
| 616 |
"from_id": template,
|
| 617 |
"to_id": space_id,
|
| 618 |
+
"repo_type": "space",
|
| 619 |
"private": private,
|
| 620 |
+
"space_hardware": hardware,
|
| 621 |
}
|
| 622 |
if sleep_time is not None:
|
| 623 |
+
kwargs["space_sleep_time"] = sleep_time
|
| 624 |
|
| 625 |
+
api.duplicate_repo(**kwargs)
|
| 626 |
_log(f"Space created: https://huggingface.co/spaces/{space_id}")
|
| 627 |
|
| 628 |
_check_cancel()
|
| 629 |
|
| 630 |
+
# ``duplicate_repo`` sends hardware and sleepTimeSeconds in the
|
| 631 |
# initial create request. Avoid a second /hardware call: deployed HF
|
| 632 |
# OAuth tokens can 401 on that endpoint for a just-created private
|
| 633 |
# Space even though duplication itself succeeded. We rely on the
|
|
|
|
| 776 |
f"Last status: {last_status}, last error: {last_err}"
|
| 777 |
)
|
| 778 |
|
| 779 |
+
def delete(self, log: Callable[[str], object] | None = None):
|
| 780 |
"""Delete the Space. Only works if this Sandbox created it."""
|
| 781 |
if not self._owns_space:
|
| 782 |
raise RuntimeError(
|
| 783 |
f"This Sandbox did not create {self.space_id}. "
|
| 784 |
f"Use self._hf_api.delete_repo() directly if you're sure."
|
| 785 |
)
|
| 786 |
+
if log:
|
| 787 |
+
log(f"Deleting sandbox: {self.space_id}...")
|
| 788 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
| 789 |
# Clear ownership so a second cleanup call (e.g. delete_session +
|
| 790 |
# _run_session.finally both fire) early-returns instead of retrying
|
| 791 |
# a 404 delete and emitting a spurious ERROR log.
|
| 792 |
self._owns_space = False
|
| 793 |
self._client.close()
|
| 794 |
+
if log:
|
| 795 |
+
log("Deleted.")
|
| 796 |
|
| 797 |
def pause(self):
|
| 798 |
"""Pause the Space (stops billing, preserves state)."""
|
agent/tools/sandbox_tool.py
CHANGED
|
@@ -16,6 +16,7 @@ import logging
|
|
| 16 |
import re
|
| 17 |
import threading
|
| 18 |
import weakref
|
|
|
|
| 19 |
from datetime import datetime, timedelta, timezone
|
| 20 |
from typing import Any
|
| 21 |
|
|
@@ -58,17 +59,41 @@ def _get_sandbox_create_lock(owner: str) -> asyncio.Lock:
|
|
| 58 |
return lock
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def _looks_like_path(script: str) -> bool:
|
| 62 |
"""Return True if the script string looks like a file path (not inline code)."""
|
| 63 |
-
|
| 64 |
isinstance(script, str)
|
| 65 |
and script.strip() == script
|
| 66 |
and not any(c in script for c in "\r\n\0")
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
|
|
@@ -303,14 +328,8 @@ async def _create_sandbox_locked(
|
|
| 303 |
)
|
| 304 |
)
|
| 305 |
|
| 306 |
-
# Thread-safe log callback: posts tool_log events from
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
def _log(msg: str) -> None:
|
| 310 |
-
loop.call_soon_threadsafe(
|
| 311 |
-
session.event_queue.put_nowait,
|
| 312 |
-
Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
|
| 313 |
-
)
|
| 314 |
|
| 315 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 316 |
# We poll session._cancelled from the main loop in a background task and set
|
|
@@ -352,7 +371,7 @@ async def _create_sandbox_locked(
|
|
| 352 |
if cancel_flag.is_set():
|
| 353 |
if getattr(sb, "_owns_space", False):
|
| 354 |
try:
|
| 355 |
-
await asyncio.to_thread(sb.delete)
|
| 356 |
except Exception as e:
|
| 357 |
logger.warning(
|
| 358 |
"Failed to delete cancelled sandbox %s: %s", sb.space_id, e
|
|
@@ -497,6 +516,7 @@ async def teardown_session_sandbox(session: Any) -> None:
|
|
| 497 |
return
|
| 498 |
|
| 499 |
space_id = getattr(sandbox, "space_id", None)
|
|
|
|
| 500 |
last_err: Exception | None = None
|
| 501 |
for attempt in range(3):
|
| 502 |
try:
|
|
@@ -505,7 +525,7 @@ async def teardown_session_sandbox(session: Any) -> None:
|
|
| 505 |
space_id,
|
| 506 |
attempt + 1,
|
| 507 |
)
|
| 508 |
-
await asyncio.to_thread(sandbox.delete)
|
| 509 |
from agent.core import telemetry
|
| 510 |
|
| 511 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
|
@@ -542,7 +562,7 @@ SANDBOX_CREATE_TOOL_SPEC = {
|
|
| 542 |
"Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
|
| 543 |
"If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
|
| 544 |
"If you intend to run a training script in this sandbox that uses report_to='trackio', "
|
| 545 |
-
"pass `trackio_space_id` (e.g. '<username>/
|
| 546 |
"are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
|
| 547 |
"Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
|
| 548 |
),
|
|
@@ -563,7 +583,7 @@ SANDBOX_CREATE_TOOL_SPEC = {
|
|
| 563 |
"type": "string",
|
| 564 |
"description": (
|
| 565 |
"Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
|
| 566 |
-
"(e.g. '<username>/
|
| 567 |
"TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
|
| 568 |
"seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
|
| 569 |
"that produces an empty Space that breaks the embed."
|
|
|
|
| 16 |
import re
|
| 17 |
import threading
|
| 18 |
import weakref
|
| 19 |
+
from collections.abc import Callable
|
| 20 |
from datetime import datetime, timedelta, timezone
|
| 21 |
from typing import Any
|
| 22 |
|
|
|
|
| 59 |
return lock
|
| 60 |
|
| 61 |
|
| 62 |
+
def _session_tool_logger(
|
| 63 |
+
session: Any, *, tool: str = "sandbox"
|
| 64 |
+
) -> Callable[[str], object] | None:
|
| 65 |
+
event_queue = getattr(session, "event_queue", None)
|
| 66 |
+
if event_queue is None:
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
loop = asyncio.get_running_loop()
|
| 70 |
+
|
| 71 |
+
def _log(msg: str) -> None:
|
| 72 |
+
loop.call_soon_threadsafe(
|
| 73 |
+
event_queue.put_nowait,
|
| 74 |
+
Event(event_type="tool_log", data={"tool": tool, "log": msg}),
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return _log
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def _looks_like_path(script: str) -> bool:
|
| 81 |
"""Return True if the script string looks like a file path (not inline code)."""
|
| 82 |
+
if not (
|
| 83 |
isinstance(script, str)
|
| 84 |
and script.strip() == script
|
| 85 |
and not any(c in script for c in "\r\n\0")
|
| 86 |
+
):
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
if script.startswith("http://") or script.startswith("https://"):
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
return (
|
| 93 |
+
script.startswith("/")
|
| 94 |
+
or script.startswith("./")
|
| 95 |
+
or script.startswith("../")
|
| 96 |
+
or (script.endswith(".py") and not any(c.isspace() for c in script))
|
| 97 |
)
|
| 98 |
|
| 99 |
|
|
|
|
| 328 |
)
|
| 329 |
)
|
| 330 |
|
| 331 |
+
# Thread-safe log callback: posts tool_log events from worker threads.
|
| 332 |
+
_log = _session_tool_logger(session) or (lambda msg: None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 335 |
# We poll session._cancelled from the main loop in a background task and set
|
|
|
|
| 371 |
if cancel_flag.is_set():
|
| 372 |
if getattr(sb, "_owns_space", False):
|
| 373 |
try:
|
| 374 |
+
await asyncio.to_thread(sb.delete, log=_log)
|
| 375 |
except Exception as e:
|
| 376 |
logger.warning(
|
| 377 |
"Failed to delete cancelled sandbox %s: %s", sb.space_id, e
|
|
|
|
| 516 |
return
|
| 517 |
|
| 518 |
space_id = getattr(sandbox, "space_id", None)
|
| 519 |
+
delete_log = _session_tool_logger(session)
|
| 520 |
last_err: Exception | None = None
|
| 521 |
for attempt in range(3):
|
| 522 |
try:
|
|
|
|
| 525 |
space_id,
|
| 526 |
attempt + 1,
|
| 527 |
)
|
| 528 |
+
await asyncio.to_thread(sandbox.delete, log=delete_log)
|
| 529 |
from agent.core import telemetry
|
| 530 |
|
| 531 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
|
|
|
| 562 |
"Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
|
| 563 |
"If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
|
| 564 |
"If you intend to run a training script in this sandbox that uses report_to='trackio', "
|
| 565 |
+
"pass `trackio_space_id` (e.g. '<username>/ml-intern-<8char>') and `trackio_project` so they "
|
| 566 |
"are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
|
| 567 |
"Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
|
| 568 |
),
|
|
|
|
| 583 |
"type": "string",
|
| 584 |
"description": (
|
| 585 |
"Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
|
| 586 |
+
"(e.g. '<username>/ml-intern-<8char>', under YOUR HF namespace). Injected as "
|
| 587 |
"TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
|
| 588 |
"seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
|
| 589 |
"that produces an empty Space that breaks the embed."
|
agent/utils/terminal_display.py
CHANGED
|
@@ -6,6 +6,7 @@ import asyncio
|
|
| 6 |
import re
|
| 7 |
|
| 8 |
from rich.console import Console
|
|
|
|
| 9 |
from rich.markdown import Heading, Markdown
|
| 10 |
from rich.panel import Panel
|
| 11 |
from rich.theme import Theme
|
|
@@ -92,7 +93,11 @@ def get_console() -> Console:
|
|
| 92 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 93 |
|
| 94 |
|
| 95 |
-
def print_banner(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 97 |
from agent.utils.particle_logo import run_particle_logo
|
| 98 |
from agent.utils.crt_boot import run_boot_sequence
|
|
@@ -115,6 +120,7 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
|
|
| 115 |
(f"{_I}Initializing agent runtime...", gold),
|
| 116 |
(f"{_I} User: {user_label}", dim_gold),
|
| 117 |
(f"{_I} Model: {model_label}", dim_gold),
|
|
|
|
| 118 |
(f"{_I} Tools: loading...", dim_gold),
|
| 119 |
("", ""),
|
| 120 |
(f"{_I}/help for commands · /model to switch · /quit to exit", gold),
|
|
@@ -446,23 +452,72 @@ def print_yolo_approve(count: int) -> None:
|
|
| 446 |
|
| 447 |
# ── Help ───────────────────────────────────────────────────────────────
|
| 448 |
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
| 462 |
|
| 463 |
def print_help() -> None:
|
| 464 |
_console.print()
|
| 465 |
-
_console.print(
|
| 466 |
_console.print()
|
| 467 |
|
| 468 |
|
|
|
|
| 6 |
import re
|
| 7 |
|
| 8 |
from rich.console import Console
|
| 9 |
+
from rich.markup import escape
|
| 10 |
from rich.markdown import Heading, Markdown
|
| 11 |
from rich.panel import Panel
|
| 12 |
from rich.theme import Theme
|
|
|
|
| 93 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 94 |
|
| 95 |
|
| 96 |
+
def print_banner(
|
| 97 |
+
model: str | None = None,
|
| 98 |
+
hf_user: str | None = None,
|
| 99 |
+
tool_runtime: str | None = None,
|
| 100 |
+
) -> None:
|
| 101 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 102 |
from agent.utils.particle_logo import run_particle_logo
|
| 103 |
from agent.utils.crt_boot import run_boot_sequence
|
|
|
|
| 120 |
(f"{_I}Initializing agent runtime...", gold),
|
| 121 |
(f"{_I} User: {user_label}", dim_gold),
|
| 122 |
(f"{_I} Model: {model_label}", dim_gold),
|
| 123 |
+
(f"{_I} Tool runtime: {tool_runtime or 'local filesystem'}", dim_gold),
|
| 124 |
(f"{_I} Tools: loading...", dim_gold),
|
| 125 |
("", ""),
|
| 126 |
(f"{_I}/help for commands · /model to switch · /quit to exit", gold),
|
|
|
|
| 452 |
|
| 453 |
# ── Help ───────────────────────────────────────────────────────────────
|
| 454 |
|
| 455 |
+
HELP_ROWS: tuple[tuple[str, str, str], ...] = (
|
| 456 |
+
("/help", "", "Show this help"),
|
| 457 |
+
("/undo", "", "Undo last turn"),
|
| 458 |
+
("/compact", "", "Compact context window"),
|
| 459 |
+
("/resume", "[index|id|path]", "Pick up from ./session_logs"),
|
| 460 |
+
("/model", "[id]", "Show available models or switch"),
|
| 461 |
+
(
|
| 462 |
+
"/effort",
|
| 463 |
+
"[level]",
|
| 464 |
+
"Set reasoning effort preference",
|
| 465 |
+
),
|
| 466 |
+
("/yolo", "", "Toggle auto-approve mode"),
|
| 467 |
+
("/status", "", "Current model & turn count"),
|
| 468 |
+
(
|
| 469 |
+
"/share-traces",
|
| 470 |
+
"[public|private]",
|
| 471 |
+
"Show or change HF trace visibility",
|
| 472 |
+
),
|
| 473 |
+
("/quit", "", "Exit"),
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def _help_column_widths(
|
| 478 |
+
rows: tuple[tuple[str, str, str], ...],
|
| 479 |
+
) -> tuple[int, int]:
|
| 480 |
+
return (
|
| 481 |
+
max(len(command) for command, _, _ in rows),
|
| 482 |
+
max(len(args) for _, args, _ in rows),
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _format_help_row(
|
| 487 |
+
command: str,
|
| 488 |
+
args: str,
|
| 489 |
+
description: str,
|
| 490 |
+
command_width: int,
|
| 491 |
+
args_width: int,
|
| 492 |
+
) -> str:
|
| 493 |
+
command_gap = " " * (command_width - len(command) + 2)
|
| 494 |
+
args_gap = " " * (args_width - len(args) + 2)
|
| 495 |
+
command_markup = f"[cyan]{escape(command)}[/cyan]"
|
| 496 |
+
args_markup = f"[muted]{escape(args)}[/muted]" if args else ""
|
| 497 |
+
return f"{_I} {command_markup}{command_gap}{args_markup}{args_gap}{description}"
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def format_help_text(rows: tuple[tuple[str, str, str], ...] | None = None) -> str:
|
| 501 |
+
help_rows = HELP_ROWS if rows is None else rows
|
| 502 |
+
command_width, args_width = _help_column_widths(help_rows)
|
| 503 |
+
return "\n".join(
|
| 504 |
+
[f"{_I}[bold]Commands[/bold]"]
|
| 505 |
+
+ [
|
| 506 |
+
_format_help_row(
|
| 507 |
+
command,
|
| 508 |
+
args,
|
| 509 |
+
description,
|
| 510 |
+
command_width,
|
| 511 |
+
args_width,
|
| 512 |
+
)
|
| 513 |
+
for command, args, description in help_rows
|
| 514 |
+
]
|
| 515 |
+
)
|
| 516 |
|
| 517 |
|
| 518 |
def print_help() -> None:
|
| 519 |
_console.print()
|
| 520 |
+
_console.print(format_help_text())
|
| 521 |
_console.print()
|
| 522 |
|
| 523 |
|
backend/dataset_uploads.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for session-scoped dataset uploads to the Hugging Face Hub."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import uuid
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from urllib.parse import quote
|
| 9 |
+
|
| 10 |
+
from fastapi import HTTPException, UploadFile
|
| 11 |
+
from huggingface_hub import HfApi
|
| 12 |
+
|
| 13 |
+
MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024
|
| 14 |
+
ALLOWED_DATASET_EXTENSIONS = {"csv", "json", "jsonl"}
|
| 15 |
+
_SAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9._-]+")
|
| 16 |
+
_SAFE_NAMESPACE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,95}$")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass(frozen=True)
|
| 20 |
+
class DatasetUpload:
|
| 21 |
+
session_id: str
|
| 22 |
+
repo_id: str
|
| 23 |
+
repo_type: str
|
| 24 |
+
private: bool
|
| 25 |
+
upload_id: str
|
| 26 |
+
config_name: str
|
| 27 |
+
filename: str
|
| 28 |
+
original_filename: str
|
| 29 |
+
path_in_repo: str
|
| 30 |
+
size_bytes: int
|
| 31 |
+
format: str
|
| 32 |
+
hub_url: str
|
| 33 |
+
load_dataset_snippet: str
|
| 34 |
+
|
| 35 |
+
def response_payload(self) -> dict[str, str | int | bool]:
|
| 36 |
+
return {
|
| 37 |
+
"session_id": self.session_id,
|
| 38 |
+
"repo_id": self.repo_id,
|
| 39 |
+
"repo_type": self.repo_type,
|
| 40 |
+
"private": self.private,
|
| 41 |
+
"upload_id": self.upload_id,
|
| 42 |
+
"config_name": self.config_name,
|
| 43 |
+
"filename": self.filename,
|
| 44 |
+
"path_in_repo": self.path_in_repo,
|
| 45 |
+
"size_bytes": self.size_bytes,
|
| 46 |
+
"format": self.format,
|
| 47 |
+
"hub_url": self.hub_url,
|
| 48 |
+
"load_dataset_snippet": self.load_dataset_snippet,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def sanitize_dataset_filename(filename: str | None) -> str:
|
| 53 |
+
"""Return a Hub-safe basename while preserving the extension."""
|
| 54 |
+
raw = os.path.basename(filename or "").strip()
|
| 55 |
+
if not raw:
|
| 56 |
+
raw = "dataset.csv"
|
| 57 |
+
|
| 58 |
+
safe = _SAFE_FILENAME_RE.sub("-", raw).strip(".-_")
|
| 59 |
+
if not safe:
|
| 60 |
+
safe = "dataset.csv"
|
| 61 |
+
|
| 62 |
+
stem, ext = os.path.splitext(safe)
|
| 63 |
+
if not stem:
|
| 64 |
+
stem = "dataset"
|
| 65 |
+
if not ext:
|
| 66 |
+
ext = ".csv"
|
| 67 |
+
|
| 68 |
+
max_stem_len = 96 - len(ext)
|
| 69 |
+
stem = stem[:max_stem_len].strip(".-_") or "dataset"
|
| 70 |
+
return f"{stem}{ext.lower()}"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def display_filename(filename: str | None, fallback: str) -> str:
|
| 74 |
+
raw = os.path.basename(filename or "").strip()
|
| 75 |
+
if not raw:
|
| 76 |
+
return fallback
|
| 77 |
+
cleaned = "".join(char for char in raw if ord(char) >= 32)
|
| 78 |
+
return cleaned[:160] or fallback
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def dataset_format_from_filename(filename: str) -> str:
|
| 82 |
+
ext = os.path.splitext(filename)[1].lower().lstrip(".")
|
| 83 |
+
if ext not in ALLOWED_DATASET_EXTENSIONS:
|
| 84 |
+
raise HTTPException(
|
| 85 |
+
status_code=400,
|
| 86 |
+
detail="Only .csv, .json, and .jsonl dataset files are supported.",
|
| 87 |
+
)
|
| 88 |
+
return ext
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def session_dataset_repo_id(hf_username: str | None, session_id: str) -> str:
|
| 92 |
+
namespace = (hf_username or "").strip()
|
| 93 |
+
if not namespace or not _SAFE_NAMESPACE_RE.fullmatch(namespace):
|
| 94 |
+
raise HTTPException(
|
| 95 |
+
status_code=400,
|
| 96 |
+
detail="Could not determine a valid Hugging Face namespace.",
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
safe_session_id = re.sub(r"[^A-Za-z0-9]+", "-", session_id).strip("-")
|
| 100 |
+
if not safe_session_id:
|
| 101 |
+
safe_session_id = uuid.uuid4().hex[:8]
|
| 102 |
+
return f"{namespace}/ml-intern-{safe_session_id[:8]}-datasets"
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
async def upload_size_bytes(upload: UploadFile) -> int:
|
| 106 |
+
await asyncio.to_thread(upload.file.seek, 0, os.SEEK_END)
|
| 107 |
+
size = await asyncio.to_thread(upload.file.tell)
|
| 108 |
+
await asyncio.to_thread(upload.file.seek, 0)
|
| 109 |
+
return int(size)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
async def validate_dataset_upload(upload: UploadFile) -> tuple[str, str, int]:
|
| 113 |
+
dataset_format = dataset_format_from_filename(upload.filename or "")
|
| 114 |
+
safe_filename = sanitize_dataset_filename(upload.filename)
|
| 115 |
+
size = await upload_size_bytes(upload)
|
| 116 |
+
if size <= 0:
|
| 117 |
+
raise HTTPException(status_code=400, detail="Uploaded dataset file is empty.")
|
| 118 |
+
if size > MAX_DATASET_UPLOAD_BYTES:
|
| 119 |
+
raise HTTPException(
|
| 120 |
+
status_code=413,
|
| 121 |
+
detail="Dataset upload exceeds the 100 MB limit.",
|
| 122 |
+
)
|
| 123 |
+
return safe_filename, dataset_format, size
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def dataset_hub_url(repo_id: str, path_in_repo: str) -> str:
|
| 127 |
+
quoted_path = quote(path_in_repo, safe="/")
|
| 128 |
+
return f"https://huggingface.co/datasets/{repo_id}/blob/main/{quoted_path}"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def dataset_config_name(upload_id: str) -> str:
|
| 132 |
+
safe_upload_id = re.sub(r"[^A-Za-z0-9]+", "_", upload_id).strip("_").lower()
|
| 133 |
+
if not safe_upload_id:
|
| 134 |
+
safe_upload_id = "dataset"
|
| 135 |
+
return f"upload_{safe_upload_id[:32]}"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def dataset_config_name_from_path(path_in_repo: str) -> str:
|
| 139 |
+
parts = path_in_repo.split("/")
|
| 140 |
+
if len(parts) >= 3 and parts[0] == "uploads":
|
| 141 |
+
return dataset_config_name(parts[1])
|
| 142 |
+
stem = os.path.splitext(os.path.basename(path_in_repo))[0]
|
| 143 |
+
return dataset_config_name(stem)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def is_dataset_upload_path(path_in_repo: str) -> bool:
|
| 147 |
+
parts = path_in_repo.split("/")
|
| 148 |
+
if len(parts) != 3 or parts[0] != "uploads" or not parts[1] or not parts[2]:
|
| 149 |
+
return False
|
| 150 |
+
extension = os.path.splitext(path_in_repo)[1].lower().lstrip(".")
|
| 151 |
+
return extension in ALLOWED_DATASET_EXTENSIONS
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def unique_dataset_upload_paths(paths: list[str]) -> list[str]:
|
| 155 |
+
seen = set()
|
| 156 |
+
upload_paths = []
|
| 157 |
+
for path in paths:
|
| 158 |
+
if not is_dataset_upload_path(path) or path in seen:
|
| 159 |
+
continue
|
| 160 |
+
seen.add(path)
|
| 161 |
+
upload_paths.append(path)
|
| 162 |
+
return upload_paths
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def load_dataset_snippet(repo_id: str, config_name: str) -> str:
|
| 166 |
+
return (
|
| 167 |
+
"from datasets import load_dataset\n\n"
|
| 168 |
+
f'dataset = load_dataset("{repo_id}", "{config_name}", '
|
| 169 |
+
'split="train", token=True)'
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def dataset_repo_card(repo_id: str, upload_paths: list[str]) -> bytes:
|
| 174 |
+
config_lines = []
|
| 175 |
+
unique_upload_paths = unique_dataset_upload_paths(upload_paths)
|
| 176 |
+
if unique_upload_paths:
|
| 177 |
+
config_lines.append("configs:")
|
| 178 |
+
for path in unique_upload_paths:
|
| 179 |
+
config_lines.extend(
|
| 180 |
+
[
|
| 181 |
+
f"- config_name: {dataset_config_name_from_path(path)}",
|
| 182 |
+
" data_files:",
|
| 183 |
+
" - split: train",
|
| 184 |
+
f' path: "{path}"',
|
| 185 |
+
]
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
configs = "\n".join(config_lines)
|
| 189 |
+
if configs:
|
| 190 |
+
configs = f"{configs}\n"
|
| 191 |
+
|
| 192 |
+
content = f"""---
|
| 193 |
+
tags:
|
| 194 |
+
- ml-intern
|
| 195 |
+
- uploaded-dataset
|
| 196 |
+
{configs}---
|
| 197 |
+
|
| 198 |
+
# {repo_id}
|
| 199 |
+
|
| 200 |
+
Private dataset files uploaded through ML Intern.
|
| 201 |
+
|
| 202 |
+
Files are stored under `uploads/<upload_id>/` and are attached to the
|
| 203 |
+
corresponding ML Intern session context by Hub reference, not by copying file
|
| 204 |
+
contents into the chat.
|
| 205 |
+
|
| 206 |
+
Each uploaded file is exposed as its own dataset config so files with different
|
| 207 |
+
schemas can coexist in the same session repo.
|
| 208 |
+
"""
|
| 209 |
+
return content.encode("utf-8")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def dataset_context_note(upload: DatasetUpload) -> str:
|
| 213 |
+
return f"""[SYSTEM: The user uploaded a dataset file for this session.
|
| 214 |
+
|
| 215 |
+
Use this Hugging Face Hub dataset reference when the task needs the uploaded data.
|
| 216 |
+
Do not look for the uploaded file on local disk and do not ask the user to
|
| 217 |
+
upload it again unless this Hub reference fails.
|
| 218 |
+
|
| 219 |
+
- Repo ID: {upload.repo_id}
|
| 220 |
+
- Repo type: dataset
|
| 221 |
+
- Dataset config: {upload.config_name}
|
| 222 |
+
- File in repo: {upload.path_in_repo}
|
| 223 |
+
- Original filename: {upload.original_filename}
|
| 224 |
+
- Stored filename: {upload.filename}
|
| 225 |
+
- Format: {upload.format}
|
| 226 |
+
- Size: {upload.size_bytes} bytes
|
| 227 |
+
- Hub URL: {upload.hub_url}
|
| 228 |
+
|
| 229 |
+
Load it with:
|
| 230 |
+
```python
|
| 231 |
+
{upload.load_dataset_snippet}
|
| 232 |
+
```
|
| 233 |
+
]"""
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
async def push_dataset_upload_to_hub(
|
| 237 |
+
*,
|
| 238 |
+
upload: UploadFile,
|
| 239 |
+
session_id: str,
|
| 240 |
+
hf_username: str,
|
| 241 |
+
hf_token: str,
|
| 242 |
+
) -> DatasetUpload:
|
| 243 |
+
safe_filename, dataset_format, size = await validate_dataset_upload(upload)
|
| 244 |
+
original_filename = display_filename(upload.filename, safe_filename)
|
| 245 |
+
upload_id = uuid.uuid4().hex[:12]
|
| 246 |
+
config_name = dataset_config_name(upload_id)
|
| 247 |
+
repo_id = session_dataset_repo_id(hf_username, session_id)
|
| 248 |
+
path_in_repo = f"uploads/{upload_id}/{safe_filename}"
|
| 249 |
+
hub_url = dataset_hub_url(repo_id, path_in_repo)
|
| 250 |
+
snippet = load_dataset_snippet(repo_id, config_name)
|
| 251 |
+
api = HfApi(token=hf_token)
|
| 252 |
+
|
| 253 |
+
await asyncio.to_thread(
|
| 254 |
+
api.create_repo,
|
| 255 |
+
repo_id=repo_id,
|
| 256 |
+
repo_type="dataset",
|
| 257 |
+
private=True,
|
| 258 |
+
exist_ok=True,
|
| 259 |
+
)
|
| 260 |
+
await asyncio.to_thread(
|
| 261 |
+
api.update_repo_settings,
|
| 262 |
+
repo_id=repo_id,
|
| 263 |
+
repo_type="dataset",
|
| 264 |
+
private=True,
|
| 265 |
+
)
|
| 266 |
+
repo_files = await asyncio.to_thread(
|
| 267 |
+
api.list_repo_files,
|
| 268 |
+
repo_id=repo_id,
|
| 269 |
+
repo_type="dataset",
|
| 270 |
+
)
|
| 271 |
+
upload_paths = unique_dataset_upload_paths([*repo_files, path_in_repo])
|
| 272 |
+
await asyncio.to_thread(upload.file.seek, 0)
|
| 273 |
+
file_bytes = await asyncio.to_thread(upload.file.read)
|
| 274 |
+
await asyncio.to_thread(
|
| 275 |
+
api.upload_file,
|
| 276 |
+
path_or_fileobj=file_bytes,
|
| 277 |
+
path_in_repo=path_in_repo,
|
| 278 |
+
repo_id=repo_id,
|
| 279 |
+
repo_type="dataset",
|
| 280 |
+
commit_message=f"Upload dataset file {safe_filename}",
|
| 281 |
+
)
|
| 282 |
+
await asyncio.to_thread(
|
| 283 |
+
api.upload_file,
|
| 284 |
+
path_or_fileobj=dataset_repo_card(repo_id, upload_paths),
|
| 285 |
+
path_in_repo="README.md",
|
| 286 |
+
repo_id=repo_id,
|
| 287 |
+
repo_type="dataset",
|
| 288 |
+
commit_message="Update ML Intern dataset upload configs",
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
return DatasetUpload(
|
| 292 |
+
session_id=session_id,
|
| 293 |
+
repo_id=repo_id,
|
| 294 |
+
repo_type="dataset",
|
| 295 |
+
private=True,
|
| 296 |
+
upload_id=upload_id,
|
| 297 |
+
config_name=config_name,
|
| 298 |
+
filename=safe_filename,
|
| 299 |
+
original_filename=original_filename,
|
| 300 |
+
path_in_repo=path_in_repo,
|
| 301 |
+
size_bytes=size,
|
| 302 |
+
format=dataset_format,
|
| 303 |
+
hub_url=hub_url,
|
| 304 |
+
load_dataset_snippet=snippet,
|
| 305 |
+
)
|
backend/models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""Pydantic models for API requests and responses."""
|
| 2 |
|
| 3 |
from enum import Enum
|
| 4 |
-
from typing import Any
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
@@ -120,6 +120,23 @@ class SessionYoloRequest(BaseModel):
|
|
| 120 |
cost_cap_usd: float | None = Field(default=None, ge=0)
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
class HealthResponse(BaseModel):
|
| 124 |
"""Health check response."""
|
| 125 |
|
|
|
|
| 1 |
"""Pydantic models for API requests and responses."""
|
| 2 |
|
| 3 |
from enum import Enum
|
| 4 |
+
from typing import Any, Literal
|
| 5 |
|
| 6 |
from pydantic import BaseModel, Field
|
| 7 |
|
|
|
|
| 120 |
cost_cap_usd: float | None = Field(default=None, ge=0)
|
| 121 |
|
| 122 |
|
| 123 |
+
class DatasetUploadResponse(BaseModel):
|
| 124 |
+
"""Response for a dataset file uploaded to the Hub."""
|
| 125 |
+
|
| 126 |
+
session_id: str
|
| 127 |
+
repo_id: str
|
| 128 |
+
repo_type: Literal["dataset"] = "dataset"
|
| 129 |
+
private: bool = True
|
| 130 |
+
upload_id: str
|
| 131 |
+
config_name: str
|
| 132 |
+
filename: str
|
| 133 |
+
path_in_repo: str
|
| 134 |
+
size_bytes: int
|
| 135 |
+
format: Literal["csv", "json", "jsonl"]
|
| 136 |
+
hub_url: str
|
| 137 |
+
load_dataset_snippet: str
|
| 138 |
+
|
| 139 |
+
|
| 140 |
class HealthResponse(BaseModel):
|
| 141 |
"""Health check response."""
|
| 142 |
|
backend/routes/agent.py
CHANGED
|
@@ -21,10 +21,18 @@ from fastapi import (
|
|
| 21 |
)
|
| 22 |
from fastapi.exceptions import RequestValidationError
|
| 23 |
from fastapi.responses import StreamingResponse
|
| 24 |
-
from
|
|
|
|
| 25 |
from pydantic import ValidationError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from models import (
|
| 27 |
ApprovalRequest,
|
|
|
|
| 28 |
HealthResponse,
|
| 29 |
LLMHealthResponse,
|
| 30 |
SessionInfo,
|
|
@@ -58,6 +66,7 @@ PREMIUM_MODEL_IDS = {
|
|
| 58 |
DEFAULT_CLAUDE_MODEL_ID,
|
| 59 |
"openai/gpt-5.5",
|
| 60 |
}
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
def _claude_picker_model_id() -> str:
|
|
@@ -203,6 +212,63 @@ def _user_hf_token(user: dict[str, Any] | None) -> str | None:
|
|
| 203 |
return user.get(INTERNAL_HF_TOKEN_KEY)
|
| 204 |
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
async def _check_session_access(
|
| 207 |
session_id: str,
|
| 208 |
user: dict[str, Any],
|
|
@@ -542,6 +608,86 @@ async def set_session_notifications(
|
|
| 542 |
}
|
| 543 |
|
| 544 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
@router.patch("/session/{session_id}/yolo")
|
| 546 |
async def set_session_yolo(
|
| 547 |
session_id: str,
|
|
|
|
| 21 |
)
|
| 22 |
from fastapi.exceptions import RequestValidationError
|
| 23 |
from fastapi.responses import StreamingResponse
|
| 24 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 25 |
+
from litellm import Message, acompletion
|
| 26 |
from pydantic import ValidationError
|
| 27 |
+
from starlette.datastructures import FormData, UploadFile
|
| 28 |
+
from dataset_uploads import (
|
| 29 |
+
MAX_DATASET_UPLOAD_BYTES,
|
| 30 |
+
dataset_context_note,
|
| 31 |
+
push_dataset_upload_to_hub,
|
| 32 |
+
)
|
| 33 |
from models import (
|
| 34 |
ApprovalRequest,
|
| 35 |
+
DatasetUploadResponse,
|
| 36 |
HealthResponse,
|
| 37 |
LLMHealthResponse,
|
| 38 |
SessionInfo,
|
|
|
|
| 66 |
DEFAULT_CLAUDE_MODEL_ID,
|
| 67 |
"openai/gpt-5.5",
|
| 68 |
}
|
| 69 |
+
DATASET_UPLOAD_MULTIPART_SLACK_BYTES = 1024 * 1024
|
| 70 |
|
| 71 |
|
| 72 |
def _claude_picker_model_id() -> str:
|
|
|
|
| 212 |
return user.get(INTERNAL_HF_TOKEN_KEY)
|
| 213 |
|
| 214 |
|
| 215 |
+
def _reject_oversize_dataset_upload(request: Request) -> None:
|
| 216 |
+
raw_content_length = request.headers.get("content-length")
|
| 217 |
+
if raw_content_length is None:
|
| 218 |
+
return
|
| 219 |
+
try:
|
| 220 |
+
content_length = int(raw_content_length)
|
| 221 |
+
except (TypeError, ValueError):
|
| 222 |
+
return
|
| 223 |
+
if content_length > MAX_DATASET_UPLOAD_BYTES + DATASET_UPLOAD_MULTIPART_SLACK_BYTES:
|
| 224 |
+
raise HTTPException(
|
| 225 |
+
status_code=413,
|
| 226 |
+
detail="Dataset upload exceeds the 100 MB limit.",
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _dataset_upload_file_from_form(form: FormData) -> UploadFile:
|
| 231 |
+
uploaded_files = [
|
| 232 |
+
(key, value)
|
| 233 |
+
for key, value in form.multi_items()
|
| 234 |
+
if isinstance(value, UploadFile)
|
| 235 |
+
]
|
| 236 |
+
if len(uploaded_files) != 1:
|
| 237 |
+
raise HTTPException(
|
| 238 |
+
status_code=400,
|
| 239 |
+
detail="Upload exactly one dataset file.",
|
| 240 |
+
)
|
| 241 |
+
field_name, upload = uploaded_files[0]
|
| 242 |
+
if field_name != "file":
|
| 243 |
+
raise HTTPException(
|
| 244 |
+
status_code=400,
|
| 245 |
+
detail="Missing 'file' upload field.",
|
| 246 |
+
)
|
| 247 |
+
return upload
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _dataset_upload_hub_http_exception(error: HfHubHTTPError) -> HTTPException:
|
| 251 |
+
status_code = getattr(error.response, "status_code", None)
|
| 252 |
+
if status_code == 401:
|
| 253 |
+
detail = "Hugging Face rejected the token used for the dataset upload."
|
| 254 |
+
return HTTPException(status_code=401, detail=detail)
|
| 255 |
+
if status_code == 403:
|
| 256 |
+
detail = (
|
| 257 |
+
"Hugging Face denied permission to create or write to the dataset repo."
|
| 258 |
+
)
|
| 259 |
+
return HTTPException(status_code=403, detail=detail)
|
| 260 |
+
if status_code == 404:
|
| 261 |
+
detail = "Could not find the Hugging Face namespace or dataset repo."
|
| 262 |
+
return HTTPException(status_code=404, detail=detail)
|
| 263 |
+
if status_code == 429:
|
| 264 |
+
detail = "Hugging Face Hub rate limit reached while uploading the dataset."
|
| 265 |
+
return HTTPException(status_code=429, detail=detail)
|
| 266 |
+
return HTTPException(
|
| 267 |
+
status_code=502,
|
| 268 |
+
detail="Hugging Face Hub upload failed. Please try again.",
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
async def _check_session_access(
|
| 273 |
session_id: str,
|
| 274 |
user: dict[str, Any],
|
|
|
|
| 608 |
}
|
| 609 |
|
| 610 |
|
| 611 |
+
@router.post("/session/{session_id}/datasets", response_model=DatasetUploadResponse)
|
| 612 |
+
async def upload_session_dataset(
|
| 613 |
+
session_id: str,
|
| 614 |
+
request: Request,
|
| 615 |
+
user: dict = Depends(get_current_user),
|
| 616 |
+
) -> DatasetUploadResponse:
|
| 617 |
+
"""Upload a CSV/JSON dataset file to a private Hub dataset for this session."""
|
| 618 |
+
file: UploadFile | None = None
|
| 619 |
+
try:
|
| 620 |
+
_reject_oversize_dataset_upload(request)
|
| 621 |
+
agent_session = await _check_session_access(session_id, user, request)
|
| 622 |
+
if not agent_session or not agent_session.is_active:
|
| 623 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 624 |
+
if agent_session.is_processing:
|
| 625 |
+
raise HTTPException(
|
| 626 |
+
status_code=409,
|
| 627 |
+
detail="Cannot upload a dataset while the agent is processing.",
|
| 628 |
+
)
|
| 629 |
+
if agent_session.session.pending_approval:
|
| 630 |
+
raise HTTPException(
|
| 631 |
+
status_code=409,
|
| 632 |
+
detail="Approve or reject pending tools before uploading a dataset.",
|
| 633 |
+
)
|
| 634 |
+
|
| 635 |
+
hf_token = (
|
| 636 |
+
resolve_hf_request_token(request, include_env_fallback=False)
|
| 637 |
+
or _user_hf_token(user)
|
| 638 |
+
or resolve_hf_request_token(request)
|
| 639 |
+
)
|
| 640 |
+
if not hf_token:
|
| 641 |
+
raise HTTPException(
|
| 642 |
+
status_code=401,
|
| 643 |
+
detail="A Hugging Face token is required to upload datasets.",
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
form = await request.form(
|
| 647 |
+
max_files=1,
|
| 648 |
+
max_fields=1,
|
| 649 |
+
max_part_size=MAX_DATASET_UPLOAD_BYTES,
|
| 650 |
+
)
|
| 651 |
+
file = _dataset_upload_file_from_form(form)
|
| 652 |
+
hf_username = user.get("username") or agent_session.hf_username
|
| 653 |
+
uploaded = await push_dataset_upload_to_hub(
|
| 654 |
+
upload=file,
|
| 655 |
+
session_id=session_id,
|
| 656 |
+
hf_username=hf_username,
|
| 657 |
+
hf_token=hf_token,
|
| 658 |
+
)
|
| 659 |
+
agent_session.session.context_manager.add_message(
|
| 660 |
+
Message(role="user", content=dataset_context_note(uploaded))
|
| 661 |
+
)
|
| 662 |
+
await session_manager.persist_session_snapshot(agent_session)
|
| 663 |
+
logger.info(
|
| 664 |
+
"Uploaded dataset file %s to %s for session %s",
|
| 665 |
+
uploaded.filename,
|
| 666 |
+
uploaded.repo_id,
|
| 667 |
+
session_id,
|
| 668 |
+
)
|
| 669 |
+
return DatasetUploadResponse(**uploaded.response_payload())
|
| 670 |
+
except HTTPException:
|
| 671 |
+
raise
|
| 672 |
+
except HfHubHTTPError as e:
|
| 673 |
+
logger.warning(
|
| 674 |
+
"Hub rejected dataset upload for session %s: status=%s request_id=%s",
|
| 675 |
+
session_id,
|
| 676 |
+
getattr(e.response, "status_code", None),
|
| 677 |
+
getattr(e, "request_id", None),
|
| 678 |
+
)
|
| 679 |
+
raise _dataset_upload_hub_http_exception(e)
|
| 680 |
+
except Exception:
|
| 681 |
+
logger.exception("Dataset upload failed for session %s", session_id)
|
| 682 |
+
raise HTTPException(
|
| 683 |
+
status_code=502,
|
| 684 |
+
detail="Dataset upload failed. Please try again.",
|
| 685 |
+
)
|
| 686 |
+
finally:
|
| 687 |
+
if file is not None:
|
| 688 |
+
await file.close()
|
| 689 |
+
|
| 690 |
+
|
| 691 |
@router.patch("/session/{session_id}/yolo")
|
| 692 |
async def set_session_yolo(
|
| 693 |
session_id: str,
|
configs/cli_agent_config.json
CHANGED
|
@@ -7,6 +7,7 @@
|
|
| 7 |
"yolo_mode": false,
|
| 8 |
"confirm_cpu_jobs": true,
|
| 9 |
"auto_file_upload": true,
|
|
|
|
| 10 |
"messaging": {
|
| 11 |
"enabled": false,
|
| 12 |
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
|
|
|
| 7 |
"yolo_mode": false,
|
| 8 |
"confirm_cpu_jobs": true,
|
| 9 |
"auto_file_upload": true,
|
| 10 |
+
"tool_runtime": "local",
|
| 11 |
"messaging": {
|
| 12 |
"enabled": false,
|
| 13 |
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
frontend/src/components/Chat/ChatInput.tsx
CHANGED
|
@@ -11,12 +11,15 @@ import {
|
|
| 11 |
ListItemIcon,
|
| 12 |
ListItemText,
|
| 13 |
Chip,
|
|
|
|
| 14 |
Snackbar,
|
|
|
|
| 15 |
} from '@mui/material';
|
| 16 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 17 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
| 18 |
import StopIcon from '@mui/icons-material/Stop';
|
| 19 |
-
import
|
|
|
|
| 20 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 21 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
| 22 |
import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
|
|
@@ -118,18 +121,49 @@ interface ChatInputProps {
|
|
| 118 |
initialModelPath?: string | null;
|
| 119 |
onSend: (text: string) => void;
|
| 120 |
onStop?: () => void;
|
|
|
|
| 121 |
isProcessing?: boolean;
|
| 122 |
disabled?: boolean;
|
| 123 |
placeholder?: string;
|
| 124 |
}
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 127 |
const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
|
| 128 |
const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
|
| 129 |
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
const [input, setInput] = useState('');
|
| 132 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
|
|
|
| 133 |
const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 134 |
const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 135 |
const sessionIdRef = useRef<string | undefined>(sessionId);
|
|
@@ -150,6 +184,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 150 |
const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
|
| 151 |
const [awaitingTopUp, setAwaitingTopUp] = useState(false);
|
| 152 |
const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
const lastSentRef = useRef<string>('');
|
| 154 |
|
| 155 |
useEffect(() => {
|
|
@@ -216,12 +255,75 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 216 |
}, [disabled, isProcessing]);
|
| 217 |
|
| 218 |
const handleSend = useCallback(() => {
|
| 219 |
-
if (input.trim() && !disabled) {
|
| 220 |
lastSentRef.current = input;
|
| 221 |
onSend(input);
|
| 222 |
setInput('');
|
| 223 |
}
|
| 224 |
-
}, [input, disabled, onSend]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
// When the chat transport reports a premium-model quota 429, restore the typed
|
| 227 |
// text so the user doesn't lose their message.
|
|
@@ -231,6 +333,18 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 231 |
}
|
| 232 |
}, [claudeQuotaExhausted]);
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
// Refresh the quota display whenever the session changes (user might
|
| 235 |
// have started another tab that spent quota).
|
| 236 |
useEffect(() => {
|
|
@@ -382,9 +496,12 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 382 |
<Box
|
| 383 |
className="composer"
|
| 384 |
sx={{
|
| 385 |
-
display: '
|
| 386 |
-
|
| 387 |
-
|
|
|
|
|
|
|
|
|
|
| 388 |
bgcolor: 'var(--composer-bg)',
|
| 389 |
borderRadius: 'var(--radius-md)',
|
| 390 |
p: '12px',
|
|
@@ -420,7 +537,7 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 420 |
}
|
| 421 |
}}
|
| 422 |
sx={{
|
| 423 |
-
|
| 424 |
'& .MuiInputBase-root': {
|
| 425 |
p: 0,
|
| 426 |
backgroundColor: 'transparent',
|
|
@@ -431,11 +548,46 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 431 |
}
|
| 432 |
}}
|
| 433 |
/>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
{isProcessing ? (
|
| 435 |
<IconButton
|
| 436 |
onClick={onStop}
|
| 437 |
sx={{
|
| 438 |
-
|
|
|
|
|
|
|
| 439 |
p: 1.5,
|
| 440 |
borderRadius: '10px',
|
| 441 |
color: 'var(--muted-text)',
|
|
@@ -455,9 +607,11 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 455 |
) : (
|
| 456 |
<IconButton
|
| 457 |
onClick={handleSend}
|
| 458 |
-
disabled={disabled || !input.trim()}
|
| 459 |
sx={{
|
| 460 |
-
|
|
|
|
|
|
|
| 461 |
p: 1,
|
| 462 |
borderRadius: '10px',
|
| 463 |
color: 'var(--muted-text)',
|
|
@@ -475,6 +629,65 @@ export default function ChatInput({ sessionId, initialModelPath, onSend, onStop,
|
|
| 475 |
</IconButton>
|
| 476 |
)}
|
| 477 |
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
|
| 479 |
{/* Powered By Badge */}
|
| 480 |
<Box
|
|
|
|
| 11 |
ListItemIcon,
|
| 12 |
ListItemText,
|
| 13 |
Chip,
|
| 14 |
+
LinearProgress,
|
| 15 |
Snackbar,
|
| 16 |
+
Tooltip,
|
| 17 |
} from '@mui/material';
|
| 18 |
import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward';
|
| 19 |
import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown';
|
| 20 |
import StopIcon from '@mui/icons-material/Stop';
|
| 21 |
+
import AddIcon from '@mui/icons-material/Add';
|
| 22 |
+
import { apiFetch, apiUpload } from '@/utils/api';
|
| 23 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 24 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
| 25 |
import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
|
|
|
|
| 121 |
initialModelPath?: string | null;
|
| 122 |
onSend: (text: string) => void;
|
| 123 |
onStop?: () => void;
|
| 124 |
+
onDatasetUploaded?: () => Promise<boolean> | boolean;
|
| 125 |
isProcessing?: boolean;
|
| 126 |
disabled?: boolean;
|
| 127 |
placeholder?: string;
|
| 128 |
}
|
| 129 |
|
| 130 |
+
interface DatasetUploadResponse {
|
| 131 |
+
session_id: string;
|
| 132 |
+
repo_id: string;
|
| 133 |
+
repo_type: 'dataset';
|
| 134 |
+
private: true;
|
| 135 |
+
upload_id: string;
|
| 136 |
+
config_name: string;
|
| 137 |
+
filename: string;
|
| 138 |
+
path_in_repo: string;
|
| 139 |
+
size_bytes: number;
|
| 140 |
+
format: 'csv' | 'json' | 'jsonl';
|
| 141 |
+
hub_url: string;
|
| 142 |
+
load_dataset_snippet: string;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
const MAX_DATASET_UPLOAD_BYTES = 100 * 1024 * 1024;
|
| 146 |
+
const DATASET_UPLOAD_ACCEPT = '.csv,.json,.jsonl';
|
| 147 |
+
const DATASET_UPLOAD_EXTENSIONS = new Set(['csv', 'json', 'jsonl']);
|
| 148 |
+
|
| 149 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 150 |
const isPremiumModel = (m: ModelOption) => isPremiumPath(m.modelPath);
|
| 151 |
const firstFreeModel = (options: ModelOption[]) => options.find(m => !isPremiumModel(m)) ?? options[0];
|
| 152 |
|
| 153 |
+
const formatBytes = (bytes: number) => {
|
| 154 |
+
if (bytes < 1024) return `${bytes} B`;
|
| 155 |
+
if (bytes < 1024 * 1024) return `${(bytes / 1024).toFixed(1)} KB`;
|
| 156 |
+
return `${(bytes / (1024 * 1024)).toFixed(1)} MB`;
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
const datasetRepoUrl = (repoId: string) => (
|
| 160 |
+
`https://huggingface.co/datasets/${repoId.split('/').map(encodeURIComponent).join('/')}`
|
| 161 |
+
);
|
| 162 |
+
|
| 163 |
+
export default function ChatInput({ sessionId, initialModelPath, onSend, onStop, onDatasetUploaded, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 164 |
const [input, setInput] = useState('');
|
| 165 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 166 |
+
const fileInputRef = useRef<HTMLInputElement>(null);
|
| 167 |
const [modelOptions, setModelOptions] = useState<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 168 |
const modelOptionsRef = useRef<ModelOption[]>(DEFAULT_MODEL_OPTIONS);
|
| 169 |
const sessionIdRef = useRef<string | undefined>(sessionId);
|
|
|
|
| 184 |
const updateSessionModel = useSessionStore((s) => s.updateSessionModel);
|
| 185 |
const [awaitingTopUp, setAwaitingTopUp] = useState(false);
|
| 186 |
const [modelSwitchError, setModelSwitchError] = useState<string | null>(null);
|
| 187 |
+
const [datasetUploadError, setDatasetUploadError] = useState<string | null>(null);
|
| 188 |
+
const [datasetUploadSuccess, setDatasetUploadSuccess] = useState<string | null>(null);
|
| 189 |
+
const [uploadedDatasets, setUploadedDatasets] = useState<DatasetUploadResponse[]>([]);
|
| 190 |
+
const [isUploadingDataset, setIsUploadingDataset] = useState(false);
|
| 191 |
+
const [datasetUploadProgress, setDatasetUploadProgress] = useState<number | null>(null);
|
| 192 |
const lastSentRef = useRef<string>('');
|
| 193 |
|
| 194 |
useEffect(() => {
|
|
|
|
| 255 |
}, [disabled, isProcessing]);
|
| 256 |
|
| 257 |
const handleSend = useCallback(() => {
|
| 258 |
+
if (input.trim() && !disabled && !isUploadingDataset) {
|
| 259 |
lastSentRef.current = input;
|
| 260 |
onSend(input);
|
| 261 |
setInput('');
|
| 262 |
}
|
| 263 |
+
}, [input, disabled, isUploadingDataset, onSend]);
|
| 264 |
+
|
| 265 |
+
const handleDatasetUploadClick = useCallback(() => {
|
| 266 |
+
fileInputRef.current?.click();
|
| 267 |
+
}, []);
|
| 268 |
+
|
| 269 |
+
const handleDatasetFileChange = useCallback(
|
| 270 |
+
async (event: React.ChangeEvent<HTMLInputElement>) => {
|
| 271 |
+
const file = event.target.files?.[0];
|
| 272 |
+
event.target.value = '';
|
| 273 |
+
if (!file) return;
|
| 274 |
+
|
| 275 |
+
if (!sessionId) {
|
| 276 |
+
setDatasetUploadError('Start a session before uploading a dataset.');
|
| 277 |
+
return;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
const extension = file.name.split('.').pop()?.toLowerCase() || '';
|
| 281 |
+
if (!DATASET_UPLOAD_EXTENSIONS.has(extension)) {
|
| 282 |
+
setDatasetUploadError('Only CSV, JSON, and JSONL dataset files are supported.');
|
| 283 |
+
return;
|
| 284 |
+
}
|
| 285 |
+
if (file.size > MAX_DATASET_UPLOAD_BYTES) {
|
| 286 |
+
setDatasetUploadError(
|
| 287 |
+
`Dataset files must be 100 MB or smaller. ${file.name} is ${formatBytes(file.size)}.`
|
| 288 |
+
);
|
| 289 |
+
return;
|
| 290 |
+
}
|
| 291 |
+
if (file.size === 0) {
|
| 292 |
+
setDatasetUploadError('Uploaded dataset file is empty.');
|
| 293 |
+
return;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
const formData = new FormData();
|
| 297 |
+
formData.append('file', file);
|
| 298 |
+
setIsUploadingDataset(true);
|
| 299 |
+
setDatasetUploadProgress(0);
|
| 300 |
+
setDatasetUploadError(null);
|
| 301 |
+
setDatasetUploadSuccess(null);
|
| 302 |
+
try {
|
| 303 |
+
const res = await apiUpload(`/api/session/${sessionId}/datasets`, formData, {
|
| 304 |
+
onProgress: ({ percent }) => {
|
| 305 |
+
setDatasetUploadProgress(percent !== null && percent < 100 ? percent : null);
|
| 306 |
+
},
|
| 307 |
+
});
|
| 308 |
+
if (!res.ok) {
|
| 309 |
+
setDatasetUploadError(await readApiErrorMessage(res, 'Dataset upload failed.'));
|
| 310 |
+
return;
|
| 311 |
+
}
|
| 312 |
+
const payload = await res.json() as DatasetUploadResponse;
|
| 313 |
+
setUploadedDatasets((previous) => [payload, ...previous]);
|
| 314 |
+
setDatasetUploadSuccess(`Uploaded ${payload.filename} to ${payload.repo_id}`);
|
| 315 |
+
await onDatasetUploaded?.();
|
| 316 |
+
} catch (error) {
|
| 317 |
+
setDatasetUploadError(
|
| 318 |
+
error instanceof Error ? error.message : 'Dataset upload failed.'
|
| 319 |
+
);
|
| 320 |
+
} finally {
|
| 321 |
+
setIsUploadingDataset(false);
|
| 322 |
+
setDatasetUploadProgress(null);
|
| 323 |
+
}
|
| 324 |
+
},
|
| 325 |
+
[sessionId, onDatasetUploaded],
|
| 326 |
+
);
|
| 327 |
|
| 328 |
// When the chat transport reports a premium-model quota 429, restore the typed
|
| 329 |
// text so the user doesn't lose their message.
|
|
|
|
| 333 |
}
|
| 334 |
}, [claudeQuotaExhausted]);
|
| 335 |
|
| 336 |
+
useEffect(() => {
|
| 337 |
+
if (!datasetUploadError) return;
|
| 338 |
+
const timeout = window.setTimeout(() => setDatasetUploadError(null), 7000);
|
| 339 |
+
return () => window.clearTimeout(timeout);
|
| 340 |
+
}, [datasetUploadError]);
|
| 341 |
+
|
| 342 |
+
useEffect(() => {
|
| 343 |
+
if (!datasetUploadSuccess) return;
|
| 344 |
+
const timeout = window.setTimeout(() => setDatasetUploadSuccess(null), 5000);
|
| 345 |
+
return () => window.clearTimeout(timeout);
|
| 346 |
+
}, [datasetUploadSuccess]);
|
| 347 |
+
|
| 348 |
// Refresh the quota display whenever the session changes (user might
|
| 349 |
// have started another tab that spent quota).
|
| 350 |
useEffect(() => {
|
|
|
|
| 496 |
<Box
|
| 497 |
className="composer"
|
| 498 |
sx={{
|
| 499 |
+
display: 'grid',
|
| 500 |
+
gridTemplateColumns: 'auto 1fr auto',
|
| 501 |
+
gridTemplateRows: 'auto auto',
|
| 502 |
+
columnGap: '10px',
|
| 503 |
+
rowGap: '4px',
|
| 504 |
+
alignItems: 'end',
|
| 505 |
bgcolor: 'var(--composer-bg)',
|
| 506 |
borderRadius: 'var(--radius-md)',
|
| 507 |
p: '12px',
|
|
|
|
| 537 |
}
|
| 538 |
}}
|
| 539 |
sx={{
|
| 540 |
+
gridColumn: '1 / -1',
|
| 541 |
'& .MuiInputBase-root': {
|
| 542 |
p: 0,
|
| 543 |
backgroundColor: 'transparent',
|
|
|
|
| 548 |
}
|
| 549 |
}}
|
| 550 |
/>
|
| 551 |
+
<input
|
| 552 |
+
ref={fileInputRef}
|
| 553 |
+
type="file"
|
| 554 |
+
accept={DATASET_UPLOAD_ACCEPT}
|
| 555 |
+
onChange={handleDatasetFileChange}
|
| 556 |
+
style={{ display: 'none' }}
|
| 557 |
+
/>
|
| 558 |
+
<Box sx={{ gridColumn: '1', gridRow: '2', display: 'flex' }}>
|
| 559 |
+
<Tooltip title="Upload dataset">
|
| 560 |
+
<span>
|
| 561 |
+
<IconButton
|
| 562 |
+
onClick={handleDatasetUploadClick}
|
| 563 |
+
disabled={disabled || isProcessing || isUploadingDataset || !sessionId}
|
| 564 |
+
sx={{
|
| 565 |
+
p: 1,
|
| 566 |
+
borderRadius: '50%',
|
| 567 |
+
color: uploadedDatasets.length ? 'var(--accent-yellow)' : 'var(--muted-text)',
|
| 568 |
+
transition: 'all 0.2s',
|
| 569 |
+
'&:hover': {
|
| 570 |
+
color: 'var(--accent-yellow)',
|
| 571 |
+
bgcolor: 'var(--hover-bg)',
|
| 572 |
+
},
|
| 573 |
+
'&.Mui-disabled': {
|
| 574 |
+
opacity: 0.3,
|
| 575 |
+
},
|
| 576 |
+
}}
|
| 577 |
+
aria-label="Upload dataset"
|
| 578 |
+
>
|
| 579 |
+
<AddIcon fontSize="small" />
|
| 580 |
+
</IconButton>
|
| 581 |
+
</span>
|
| 582 |
+
</Tooltip>
|
| 583 |
+
</Box>
|
| 584 |
{isProcessing ? (
|
| 585 |
<IconButton
|
| 586 |
onClick={onStop}
|
| 587 |
sx={{
|
| 588 |
+
gridColumn: '3',
|
| 589 |
+
gridRow: '2',
|
| 590 |
+
justifySelf: 'end',
|
| 591 |
p: 1.5,
|
| 592 |
borderRadius: '10px',
|
| 593 |
color: 'var(--muted-text)',
|
|
|
|
| 607 |
) : (
|
| 608 |
<IconButton
|
| 609 |
onClick={handleSend}
|
| 610 |
+
disabled={disabled || isUploadingDataset || !input.trim()}
|
| 611 |
sx={{
|
| 612 |
+
gridColumn: '3',
|
| 613 |
+
gridRow: '2',
|
| 614 |
+
justifySelf: 'end',
|
| 615 |
p: 1,
|
| 616 |
borderRadius: '10px',
|
| 617 |
color: 'var(--muted-text)',
|
|
|
|
| 629 |
</IconButton>
|
| 630 |
)}
|
| 631 |
</Box>
|
| 632 |
+
{isUploadingDataset && (
|
| 633 |
+
<Box sx={{ mt: 1, px: 0.5 }}>
|
| 634 |
+
<LinearProgress
|
| 635 |
+
variant={datasetUploadProgress === null ? 'indeterminate' : 'determinate'}
|
| 636 |
+
value={datasetUploadProgress ?? 0}
|
| 637 |
+
aria-label="Dataset upload progress"
|
| 638 |
+
sx={{
|
| 639 |
+
height: 4,
|
| 640 |
+
borderRadius: 999,
|
| 641 |
+
bgcolor: 'rgba(255,255,255,0.08)',
|
| 642 |
+
'& .MuiLinearProgress-bar': {
|
| 643 |
+
borderRadius: 999,
|
| 644 |
+
bgcolor: 'var(--accent-yellow)',
|
| 645 |
+
},
|
| 646 |
+
}}
|
| 647 |
+
/>
|
| 648 |
+
</Box>
|
| 649 |
+
)}
|
| 650 |
+
{(datasetUploadError || datasetUploadSuccess) && (
|
| 651 |
+
<Box sx={{ display: 'flex', justifyContent: 'center', mt: 1 }}>
|
| 652 |
+
<Alert
|
| 653 |
+
severity={datasetUploadError ? 'error' : 'success'}
|
| 654 |
+
variant="filled"
|
| 655 |
+
onClose={() => {
|
| 656 |
+
setDatasetUploadError(null);
|
| 657 |
+
setDatasetUploadSuccess(null);
|
| 658 |
+
}}
|
| 659 |
+
sx={{ fontSize: '0.8rem', maxWidth: 520, width: '100%' }}
|
| 660 |
+
>
|
| 661 |
+
{datasetUploadError ?? datasetUploadSuccess}
|
| 662 |
+
</Alert>
|
| 663 |
+
</Box>
|
| 664 |
+
)}
|
| 665 |
+
{uploadedDatasets.length > 0 && (
|
| 666 |
+
<Box sx={{ display: 'flex', flexWrap: 'wrap', gap: 0.75, justifyContent: 'center', mt: 1 }}>
|
| 667 |
+
{uploadedDatasets.map((dataset) => (
|
| 668 |
+
<Chip
|
| 669 |
+
key={dataset.upload_id}
|
| 670 |
+
size="small"
|
| 671 |
+
label={`Dataset: ${dataset.filename}`}
|
| 672 |
+
component="a"
|
| 673 |
+
href={datasetRepoUrl(dataset.repo_id)}
|
| 674 |
+
target="_blank"
|
| 675 |
+
rel="noreferrer"
|
| 676 |
+
clickable
|
| 677 |
+
sx={{
|
| 678 |
+
maxWidth: '100%',
|
| 679 |
+
bgcolor: 'rgba(255,255,255,0.08)',
|
| 680 |
+
color: 'var(--text)',
|
| 681 |
+
border: '1px solid var(--divider)',
|
| 682 |
+
'& .MuiChip-label': {
|
| 683 |
+
overflow: 'hidden',
|
| 684 |
+
textOverflow: 'ellipsis',
|
| 685 |
+
},
|
| 686 |
+
}}
|
| 687 |
+
/>
|
| 688 |
+
))}
|
| 689 |
+
</Box>
|
| 690 |
+
)}
|
| 691 |
|
| 692 |
{/* Powered By Badge */}
|
| 693 |
<Box
|
frontend/src/components/SessionChat.tsx
CHANGED
|
@@ -27,7 +27,16 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 27 |
const sessionMeta = sessions.find((s) => s.id === sessionId);
|
| 28 |
const isExpired = sessionMeta?.expired === true;
|
| 29 |
|
| 30 |
-
const {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
sessionId,
|
| 32 |
isActive,
|
| 33 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
@@ -116,6 +125,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 116 |
initialModelPath={sessionMeta?.model}
|
| 117 |
onSend={handleSendMessage}
|
| 118 |
onStop={handleStop}
|
|
|
|
| 119 |
isProcessing={busy}
|
| 120 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 121 |
placeholder={
|
|
|
|
| 27 |
const sessionMeta = sessions.find((s) => s.id === sessionId);
|
| 28 |
const isExpired = sessionMeta?.expired === true;
|
| 29 |
|
| 30 |
+
const {
|
| 31 |
+
messages,
|
| 32 |
+
sendMessage,
|
| 33 |
+
stop,
|
| 34 |
+
status,
|
| 35 |
+
undoLastTurn,
|
| 36 |
+
editAndRegenerate,
|
| 37 |
+
approveTools,
|
| 38 |
+
refreshMessages,
|
| 39 |
+
} = useAgentChat({
|
| 40 |
sessionId,
|
| 41 |
isActive,
|
| 42 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
|
|
| 125 |
initialModelPath={sessionMeta?.model}
|
| 126 |
onSend={handleSendMessage}
|
| 127 |
onStop={handleStop}
|
| 128 |
+
onDatasetUploaded={refreshMessages}
|
| 129 |
isProcessing={busy}
|
| 130 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 131 |
placeholder={
|
frontend/src/hooks/useAgentChat.ts
CHANGED
|
@@ -804,6 +804,48 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 804 |
}
|
| 805 |
}, [sessionId, chat]);
|
| 806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
return {
|
| 808 |
messages: chat.messages,
|
| 809 |
sendMessage: chat.sendMessage,
|
|
@@ -812,5 +854,6 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 812 |
undoLastTurn,
|
| 813 |
editAndRegenerate,
|
| 814 |
approveTools,
|
|
|
|
| 815 |
};
|
| 816 |
}
|
|
|
|
| 804 |
}
|
| 805 |
}, [sessionId, chat]);
|
| 806 |
|
| 807 |
+
const refreshMessages = useCallback(async () => {
|
| 808 |
+
try {
|
| 809 |
+
const [msgsRes, infoRes] = await Promise.all([
|
| 810 |
+
apiFetch(`/api/session/${sessionId}/messages`),
|
| 811 |
+
apiFetch(`/api/session/${sessionId}`),
|
| 812 |
+
]);
|
| 813 |
+
if (!msgsRes.ok) return false;
|
| 814 |
+
|
| 815 |
+
const data = await msgsRes.json();
|
| 816 |
+
if (!Array.isArray(data) || data.length === 0) return false;
|
| 817 |
+
saveBackendMessages(sessionId, data);
|
| 818 |
+
|
| 819 |
+
let pendingIds: Set<string> | undefined;
|
| 820 |
+
if (infoRes.ok) {
|
| 821 |
+
const info = await infoRes.json();
|
| 822 |
+
if (info.pending_approval && Array.isArray(info.pending_approval)) {
|
| 823 |
+
pendingIds = new Set(
|
| 824 |
+
info.pending_approval.map((t: { tool_call_id: string }) => t.tool_call_id)
|
| 825 |
+
);
|
| 826 |
+
if (pendingIds.size > 0) setNeedsAttention(sessionId, true);
|
| 827 |
+
}
|
| 828 |
+
if (info.auto_approval) {
|
| 829 |
+
updateSessionYolo(sessionId, info.auto_approval);
|
| 830 |
+
}
|
| 831 |
+
}
|
| 832 |
+
|
| 833 |
+
const uiMsgs = llmMessagesToUIMessages(
|
| 834 |
+
data,
|
| 835 |
+
pendingIds,
|
| 836 |
+
chatActionsRef.current.messages,
|
| 837 |
+
);
|
| 838 |
+
const setMsgs = chatActionsRef.current.setMessages;
|
| 839 |
+
if (setMsgs && uiMsgs.length > 0) {
|
| 840 |
+
setMsgs(uiMsgs);
|
| 841 |
+
saveMessages(sessionId, uiMsgs);
|
| 842 |
+
}
|
| 843 |
+
return true;
|
| 844 |
+
} catch {
|
| 845 |
+
return false;
|
| 846 |
+
}
|
| 847 |
+
}, [sessionId, setNeedsAttention, updateSessionYolo]);
|
| 848 |
+
|
| 849 |
return {
|
| 850 |
messages: chat.messages,
|
| 851 |
sendMessage: chat.sendMessage,
|
|
|
|
| 854 |
undoLastTurn,
|
| 855 |
editAndRegenerate,
|
| 856 |
approveTools,
|
| 857 |
+
refreshMessages,
|
| 858 |
};
|
| 859 |
}
|
frontend/src/utils/api.ts
CHANGED
|
@@ -7,15 +7,36 @@
|
|
| 7 |
|
| 8 |
import { triggerLogin } from '@/hooks/useAuth';
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
/** Wrapper around fetch with credentials and common headers. */
|
| 11 |
export async function apiFetch(
|
| 12 |
path: string,
|
| 13 |
options: RequestInit = {}
|
| 14 |
): Promise<Response> {
|
| 15 |
-
const headers
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
const response = await fetch(path, {
|
| 21 |
...options,
|
|
@@ -23,19 +44,50 @@ export async function apiFetch(
|
|
| 23 |
credentials: 'include', // Send cookies with every request
|
| 24 |
});
|
| 25 |
|
| 26 |
-
|
| 27 |
-
if (response.status === 401) {
|
| 28 |
-
try {
|
| 29 |
-
const authStatus = await fetch('/auth/status', { credentials: 'include' });
|
| 30 |
-
const data = await authStatus.json();
|
| 31 |
-
if (data.auth_enabled) {
|
| 32 |
-
triggerLogin();
|
| 33 |
-
throw new Error('Authentication required — redirecting to login.');
|
| 34 |
-
}
|
| 35 |
-
} catch (e) {
|
| 36 |
-
if (e instanceof Error && e.message.includes('redirecting')) throw e;
|
| 37 |
-
}
|
| 38 |
-
}
|
| 39 |
|
| 40 |
return response;
|
| 41 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import { triggerLogin } from '@/hooks/useAuth';
|
| 9 |
|
| 10 |
+
export interface ApiUploadProgress {
|
| 11 |
+
loaded: number;
|
| 12 |
+
total: number | null;
|
| 13 |
+
percent: number | null;
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
async function handleUnauthorized(response: Response): Promise<void> {
|
| 17 |
+
if (response.status !== 401) return;
|
| 18 |
+
try {
|
| 19 |
+
const authStatus = await fetch('/auth/status', { credentials: 'include' });
|
| 20 |
+
const data = await authStatus.json();
|
| 21 |
+
if (data.auth_enabled) {
|
| 22 |
+
triggerLogin();
|
| 23 |
+
throw new Error('Authentication required — redirecting to login.');
|
| 24 |
+
}
|
| 25 |
+
} catch (e) {
|
| 26 |
+
if (e instanceof Error && e.message.includes('redirecting')) throw e;
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
/** Wrapper around fetch with credentials and common headers. */
|
| 31 |
export async function apiFetch(
|
| 32 |
path: string,
|
| 33 |
options: RequestInit = {}
|
| 34 |
): Promise<Response> {
|
| 35 |
+
const headers = new Headers(options.headers);
|
| 36 |
+
const isFormData = options.body instanceof FormData;
|
| 37 |
+
if (!isFormData && !headers.has('Content-Type')) {
|
| 38 |
+
headers.set('Content-Type', 'application/json');
|
| 39 |
+
}
|
| 40 |
|
| 41 |
const response = await fetch(path, {
|
| 42 |
...options,
|
|
|
|
| 44 |
credentials: 'include', // Send cookies with every request
|
| 45 |
});
|
| 46 |
|
| 47 |
+
await handleUnauthorized(response);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
return response;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
function headersFromXhr(rawHeaders: string): Headers {
|
| 53 |
+
const headers = new Headers();
|
| 54 |
+
rawHeaders.trim().split(/[\r\n]+/).forEach((line) => {
|
| 55 |
+
const separator = line.indexOf(':');
|
| 56 |
+
if (separator <= 0) return;
|
| 57 |
+
headers.append(
|
| 58 |
+
line.slice(0, separator).trim(),
|
| 59 |
+
line.slice(separator + 1).trim(),
|
| 60 |
+
);
|
| 61 |
+
});
|
| 62 |
+
return headers;
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
export async function apiUpload(
|
| 66 |
+
path: string,
|
| 67 |
+
formData: FormData,
|
| 68 |
+
options: { onProgress?: (progress: ApiUploadProgress) => void } = {},
|
| 69 |
+
): Promise<Response> {
|
| 70 |
+
return new Promise<Response>((resolve, reject) => {
|
| 71 |
+
const xhr = new XMLHttpRequest();
|
| 72 |
+
xhr.open('POST', path);
|
| 73 |
+
xhr.withCredentials = true;
|
| 74 |
+
xhr.upload.onprogress = (event) => {
|
| 75 |
+
const total = event.lengthComputable ? event.total : null;
|
| 76 |
+
const percent = total
|
| 77 |
+
? Math.min(100, Math.round((event.loaded / total) * 100))
|
| 78 |
+
: null;
|
| 79 |
+
options.onProgress?.({ loaded: event.loaded, total, percent });
|
| 80 |
+
};
|
| 81 |
+
xhr.onerror = () => reject(new Error('Network error while uploading.'));
|
| 82 |
+
xhr.onabort = () => reject(new Error('Dataset upload was canceled.'));
|
| 83 |
+
xhr.onload = () => {
|
| 84 |
+
const response = new Response(xhr.responseText, {
|
| 85 |
+
status: xhr.status,
|
| 86 |
+
statusText: xhr.statusText,
|
| 87 |
+
headers: headersFromXhr(xhr.getAllResponseHeaders()),
|
| 88 |
+
});
|
| 89 |
+
handleUnauthorized(response).then(() => resolve(response)).catch(reject);
|
| 90 |
+
};
|
| 91 |
+
xhr.send(formData);
|
| 92 |
+
});
|
| 93 |
+
}
|
pyproject.toml
CHANGED
|
@@ -28,6 +28,7 @@ dependencies = [
|
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
| 30 |
"pymongo>=4.17.0",
|
|
|
|
| 31 |
]
|
| 32 |
|
| 33 |
[project.optional-dependencies]
|
|
|
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
| 30 |
"pymongo>=4.17.0",
|
| 31 |
+
"python-multipart>=0.0.20",
|
| 32 |
]
|
| 33 |
|
| 34 |
[project.optional-dependencies]
|
tests/unit/test_cli_rendering.py
CHANGED
|
@@ -1,10 +1,12 @@
|
|
| 1 |
"""Regression tests for interactive CLI rendering and research model routing."""
|
| 2 |
|
|
|
|
| 3 |
import sys
|
| 4 |
from io import StringIO
|
| 5 |
from types import SimpleNamespace
|
| 6 |
|
| 7 |
import pytest
|
|
|
|
| 8 |
|
| 9 |
import agent.main as main_mod
|
| 10 |
from agent.tools.research_tool import _get_research_model
|
|
@@ -29,6 +31,50 @@ def test_non_anthropic_research_model_is_unchanged():
|
|
| 29 |
assert _get_research_model("openai/gpt-5.4") == "openai/gpt-5.4"
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
|
| 33 |
calls: list[object] = []
|
| 34 |
|
|
@@ -52,10 +98,11 @@ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
|
|
| 52 |
|
| 53 |
|
| 54 |
def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
| 55 |
-
seen: dict[str,
|
| 56 |
|
| 57 |
-
async def fake_main(*, model=None):
|
| 58 |
seen["model"] = model
|
|
|
|
| 59 |
|
| 60 |
monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
|
| 61 |
monkeypatch.setattr(main_mod, "main", fake_main)
|
|
@@ -63,6 +110,61 @@ def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
|
| 63 |
main_mod.cli()
|
| 64 |
|
| 65 |
assert seen["model"] == "openai/gpt-5.5"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
@pytest.mark.asyncio
|
|
@@ -70,9 +172,10 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
|
|
| 70 |
class StopAfterBanner(Exception):
|
| 71 |
pass
|
| 72 |
|
| 73 |
-
def fake_banner(*, model=None, hf_user=None):
|
| 74 |
assert model == "openai/gpt-5.5"
|
| 75 |
assert hf_user == "tester"
|
|
|
|
| 76 |
raise StopAfterBanner
|
| 77 |
|
| 78 |
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
|
@@ -85,9 +188,150 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
|
|
| 85 |
lambda _path, **_kwargs: SimpleNamespace(
|
| 86 |
model_name="moonshotai/Kimi-K2.6",
|
| 87 |
mcpServers={},
|
|
|
|
| 88 |
),
|
| 89 |
)
|
| 90 |
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 91 |
|
| 92 |
with pytest.raises(StopAfterBanner):
|
| 93 |
await main_mod.main(model="openai/gpt-5.5")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""Regression tests for interactive CLI rendering and research model routing."""
|
| 2 |
|
| 3 |
+
import asyncio
|
| 4 |
import sys
|
| 5 |
from io import StringIO
|
| 6 |
from types import SimpleNamespace
|
| 7 |
|
| 8 |
import pytest
|
| 9 |
+
from rich.console import Console
|
| 10 |
|
| 11 |
import agent.main as main_mod
|
| 12 |
from agent.tools.research_tool import _get_research_model
|
|
|
|
| 31 |
assert _get_research_model("openai/gpt-5.4") == "openai/gpt-5.4"
|
| 32 |
|
| 33 |
|
| 34 |
+
def test_help_output_keeps_descriptions_aligned(monkeypatch):
|
| 35 |
+
output = StringIO()
|
| 36 |
+
console = Console(
|
| 37 |
+
file=output,
|
| 38 |
+
color_system=None,
|
| 39 |
+
theme=terminal_display._THEME,
|
| 40 |
+
width=120,
|
| 41 |
+
)
|
| 42 |
+
monkeypatch.setattr(terminal_display, "_console", console)
|
| 43 |
+
|
| 44 |
+
terminal_display.print_help()
|
| 45 |
+
|
| 46 |
+
lines = [line.rstrip() for line in output.getvalue().splitlines() if line.strip()]
|
| 47 |
+
description_columns = []
|
| 48 |
+
for command, args, description in terminal_display.HELP_ROWS:
|
| 49 |
+
line = next(line for line in lines if command in line)
|
| 50 |
+
if args:
|
| 51 |
+
assert args in line
|
| 52 |
+
description_columns.append(line.index(description))
|
| 53 |
+
|
| 54 |
+
assert len(set(description_columns)) == 1
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_help_output_recomputes_widths_from_rows():
|
| 58 |
+
rows = terminal_display.HELP_ROWS + (
|
| 59 |
+
("/longer-command", "[longer-args]", "Synthetic help row"),
|
| 60 |
+
)
|
| 61 |
+
output = StringIO()
|
| 62 |
+
Console(
|
| 63 |
+
file=output,
|
| 64 |
+
color_system=None,
|
| 65 |
+
theme=terminal_display._THEME,
|
| 66 |
+
width=140,
|
| 67 |
+
).print(terminal_display.format_help_text(rows))
|
| 68 |
+
|
| 69 |
+
lines = [line.rstrip() for line in output.getvalue().splitlines() if line.strip()]
|
| 70 |
+
description_columns = [
|
| 71 |
+
next(line for line in lines if command in line).index(description)
|
| 72 |
+
for command, _args, description in rows
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
assert len(set(description_columns)) == 1
|
| 76 |
+
|
| 77 |
+
|
| 78 |
def test_subagent_display_does_not_spawn_background_redraw(monkeypatch):
|
| 79 |
calls: list[object] = []
|
| 80 |
|
|
|
|
| 98 |
|
| 99 |
|
| 100 |
def test_cli_forwards_model_flag_to_interactive_main(monkeypatch):
|
| 101 |
+
seen: dict[str, object] = {}
|
| 102 |
|
| 103 |
+
async def fake_main(*, model=None, sandbox_tools=False):
|
| 104 |
seen["model"] = model
|
| 105 |
+
seen["sandbox_tools"] = sandbox_tools
|
| 106 |
|
| 107 |
monkeypatch.setattr(sys, "argv", ["ml-intern", "--model", "openai/gpt-5.5"])
|
| 108 |
monkeypatch.setattr(main_mod, "main", fake_main)
|
|
|
|
| 110 |
main_mod.cli()
|
| 111 |
|
| 112 |
assert seen["model"] == "openai/gpt-5.5"
|
| 113 |
+
assert seen["sandbox_tools"] is False
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_cli_forwards_sandbox_flag_to_interactive_main(monkeypatch):
|
| 117 |
+
seen: dict[str, object] = {}
|
| 118 |
+
|
| 119 |
+
async def fake_main(*, model=None, sandbox_tools=False):
|
| 120 |
+
seen["model"] = model
|
| 121 |
+
seen["sandbox_tools"] = sandbox_tools
|
| 122 |
+
|
| 123 |
+
monkeypatch.setattr(sys, "argv", ["ml-intern", "--sandbox-tools"])
|
| 124 |
+
monkeypatch.setattr(main_mod, "main", fake_main)
|
| 125 |
+
|
| 126 |
+
main_mod.cli()
|
| 127 |
+
|
| 128 |
+
assert seen == {"model": None, "sandbox_tools": True}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_cli_forwards_sandbox_flag_to_headless_main(monkeypatch):
|
| 132 |
+
seen: dict[str, object] = {}
|
| 133 |
+
|
| 134 |
+
async def fake_headless_main(
|
| 135 |
+
prompt,
|
| 136 |
+
*,
|
| 137 |
+
model=None,
|
| 138 |
+
max_iterations=None,
|
| 139 |
+
stream=True,
|
| 140 |
+
sandbox_tools=False,
|
| 141 |
+
):
|
| 142 |
+
seen.update(
|
| 143 |
+
{
|
| 144 |
+
"prompt": prompt,
|
| 145 |
+
"model": model,
|
| 146 |
+
"max_iterations": max_iterations,
|
| 147 |
+
"stream": stream,
|
| 148 |
+
"sandbox_tools": sandbox_tools,
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
monkeypatch.setattr(
|
| 153 |
+
sys,
|
| 154 |
+
"argv",
|
| 155 |
+
["ml-intern", "--sandbox-tools", "--no-stream", "train a model"],
|
| 156 |
+
)
|
| 157 |
+
monkeypatch.setattr(main_mod, "headless_main", fake_headless_main)
|
| 158 |
+
|
| 159 |
+
main_mod.cli()
|
| 160 |
+
|
| 161 |
+
assert seen == {
|
| 162 |
+
"prompt": "train a model",
|
| 163 |
+
"model": None,
|
| 164 |
+
"max_iterations": None,
|
| 165 |
+
"stream": False,
|
| 166 |
+
"sandbox_tools": True,
|
| 167 |
+
}
|
| 168 |
|
| 169 |
|
| 170 |
@pytest.mark.asyncio
|
|
|
|
| 172 |
class StopAfterBanner(Exception):
|
| 173 |
pass
|
| 174 |
|
| 175 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 176 |
assert model == "openai/gpt-5.5"
|
| 177 |
assert hf_user == "tester"
|
| 178 |
+
assert tool_runtime == "local filesystem"
|
| 179 |
raise StopAfterBanner
|
| 180 |
|
| 181 |
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
|
|
|
| 188 |
lambda _path, **_kwargs: SimpleNamespace(
|
| 189 |
model_name="moonshotai/Kimi-K2.6",
|
| 190 |
mcpServers={},
|
| 191 |
+
tool_runtime="local",
|
| 192 |
),
|
| 193 |
)
|
| 194 |
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 195 |
|
| 196 |
with pytest.raises(StopAfterBanner):
|
| 197 |
await main_mod.main(model="openai/gpt-5.5")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@pytest.mark.asyncio
|
| 201 |
+
async def test_local_model_local_runtime_skips_hf_token_prompt(monkeypatch):
|
| 202 |
+
class StopAfterBanner(Exception):
|
| 203 |
+
pass
|
| 204 |
+
|
| 205 |
+
async def fail_prompt(_prompt_session):
|
| 206 |
+
raise AssertionError("local model with local tools should not prompt")
|
| 207 |
+
|
| 208 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 209 |
+
assert model == "llamacpp/model"
|
| 210 |
+
assert hf_user is None
|
| 211 |
+
assert tool_runtime == "local filesystem"
|
| 212 |
+
raise StopAfterBanner
|
| 213 |
+
|
| 214 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 215 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 216 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
|
| 217 |
+
monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fail_prompt)
|
| 218 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: None)
|
| 219 |
+
monkeypatch.setattr(
|
| 220 |
+
main_mod,
|
| 221 |
+
"load_config",
|
| 222 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 223 |
+
model_name="llamacpp/model",
|
| 224 |
+
mcpServers={},
|
| 225 |
+
tool_runtime="local",
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 229 |
+
|
| 230 |
+
with pytest.raises(StopAfterBanner):
|
| 231 |
+
await main_mod.main()
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
@pytest.mark.asyncio
|
| 235 |
+
async def test_local_model_sandbox_runtime_prompts_for_hf_token(monkeypatch):
|
| 236 |
+
class StopAfterBanner(Exception):
|
| 237 |
+
pass
|
| 238 |
+
|
| 239 |
+
prompted = False
|
| 240 |
+
|
| 241 |
+
async def fake_prompt(_prompt_session):
|
| 242 |
+
nonlocal prompted
|
| 243 |
+
prompted = True
|
| 244 |
+
return "hf-token"
|
| 245 |
+
|
| 246 |
+
def fake_banner(*, model=None, hf_user=None, tool_runtime=None):
|
| 247 |
+
assert model == "llamacpp/model"
|
| 248 |
+
assert hf_user == "tester"
|
| 249 |
+
assert tool_runtime == "HF sandbox"
|
| 250 |
+
raise StopAfterBanner
|
| 251 |
+
|
| 252 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 253 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 254 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: None)
|
| 255 |
+
monkeypatch.setattr(main_mod, "_prompt_and_save_hf_token", fake_prompt)
|
| 256 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
|
| 257 |
+
monkeypatch.setattr(
|
| 258 |
+
main_mod,
|
| 259 |
+
"load_config",
|
| 260 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 261 |
+
model_name="llamacpp/model",
|
| 262 |
+
mcpServers={},
|
| 263 |
+
tool_runtime="local",
|
| 264 |
+
),
|
| 265 |
+
)
|
| 266 |
+
monkeypatch.setattr(main_mod, "print_banner", fake_banner)
|
| 267 |
+
|
| 268 |
+
with pytest.raises(StopAfterBanner):
|
| 269 |
+
await main_mod.main(sandbox_tools=True)
|
| 270 |
+
|
| 271 |
+
assert prompted is True
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@pytest.mark.asyncio
|
| 275 |
+
async def test_interactive_main_passes_sandbox_runtime_to_tool_router(monkeypatch):
|
| 276 |
+
class StopAfterToolRouter(Exception):
|
| 277 |
+
pass
|
| 278 |
+
|
| 279 |
+
seen: dict[str, object] = {}
|
| 280 |
+
|
| 281 |
+
class FakeGateway:
|
| 282 |
+
def __init__(self, _config):
|
| 283 |
+
pass
|
| 284 |
+
|
| 285 |
+
async def start(self):
|
| 286 |
+
pass
|
| 287 |
+
|
| 288 |
+
class FakeToolRouter:
|
| 289 |
+
def __init__(self, mcp_servers, *, hf_token=None, local_mode=True):
|
| 290 |
+
seen["mcp_servers"] = mcp_servers
|
| 291 |
+
seen["hf_token"] = hf_token
|
| 292 |
+
seen["local_mode"] = local_mode
|
| 293 |
+
raise StopAfterToolRouter
|
| 294 |
+
|
| 295 |
+
from agent.core import hf_router_catalog
|
| 296 |
+
|
| 297 |
+
monkeypatch.setattr(main_mod.os, "system", lambda *_args, **_kwargs: 0)
|
| 298 |
+
monkeypatch.setattr(main_mod, "PromptSession", lambda: object())
|
| 299 |
+
monkeypatch.setattr(main_mod, "resolve_hf_token", lambda: "hf-token")
|
| 300 |
+
monkeypatch.setattr(main_mod, "_get_hf_user", lambda _token: "tester")
|
| 301 |
+
monkeypatch.setattr(main_mod, "print_banner", lambda **_kwargs: None)
|
| 302 |
+
monkeypatch.setattr(hf_router_catalog, "prewarm", lambda: None)
|
| 303 |
+
monkeypatch.setattr(
|
| 304 |
+
main_mod,
|
| 305 |
+
"load_config",
|
| 306 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 307 |
+
model_name="llamacpp/model",
|
| 308 |
+
mcpServers={"server": object()},
|
| 309 |
+
messaging=SimpleNamespace(default_auto_destinations=lambda: []),
|
| 310 |
+
tool_runtime="local",
|
| 311 |
+
),
|
| 312 |
+
)
|
| 313 |
+
monkeypatch.setattr(main_mod, "NotificationGateway", FakeGateway)
|
| 314 |
+
monkeypatch.setattr(main_mod, "ToolRouter", FakeToolRouter)
|
| 315 |
+
|
| 316 |
+
with pytest.raises(StopAfterToolRouter):
|
| 317 |
+
await main_mod.main(sandbox_tools=True)
|
| 318 |
+
|
| 319 |
+
assert seen["hf_token"] == "hf-token"
|
| 320 |
+
assert seen["local_mode"] is False
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@pytest.mark.asyncio
|
| 324 |
+
async def test_initial_sandbox_preload_waits_before_prompt():
|
| 325 |
+
waited = False
|
| 326 |
+
|
| 327 |
+
async def preload():
|
| 328 |
+
nonlocal waited
|
| 329 |
+
await asyncio.sleep(0)
|
| 330 |
+
waited = True
|
| 331 |
+
|
| 332 |
+
task = asyncio.create_task(preload())
|
| 333 |
+
await main_mod._wait_for_initial_sandbox_preload(
|
| 334 |
+
[SimpleNamespace(sandbox_preload_task=task)]
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
assert waited is True
|
tests/unit/test_config.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from agent import config as config_module
|
| 4 |
|
| 5 |
|
|
@@ -121,3 +124,35 @@ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
|
|
| 121 |
|
| 122 |
assert not config.messaging.enabled
|
| 123 |
assert config.messaging.destinations == {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
|
| 3 |
+
import pytest
|
| 4 |
+
from pydantic import ValidationError
|
| 5 |
+
|
| 6 |
from agent import config as config_module
|
| 7 |
|
| 8 |
|
|
|
|
| 124 |
|
| 125 |
assert not config.messaging.enabled
|
| 126 |
assert config.messaging.destinations == {}
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def test_tool_runtime_defaults_to_local(tmp_path):
|
| 130 |
+
config_path = tmp_path / "config.json"
|
| 131 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 132 |
+
|
| 133 |
+
config = config_module.load_config(str(config_path))
|
| 134 |
+
|
| 135 |
+
assert config.tool_runtime == "local"
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def test_user_config_can_set_sandbox_tool_runtime(tmp_path, monkeypatch):
|
| 139 |
+
config_path = tmp_path / "config.json"
|
| 140 |
+
user_config_path = tmp_path / "user-config.json"
|
| 141 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 142 |
+
_write_json(user_config_path, {"tool_runtime": "sandbox"})
|
| 143 |
+
monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
|
| 144 |
+
|
| 145 |
+
config = config_module.load_config(str(config_path), include_user_defaults=True)
|
| 146 |
+
|
| 147 |
+
assert config.tool_runtime == "sandbox"
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def test_invalid_tool_runtime_is_rejected(tmp_path):
|
| 151 |
+
config_path = tmp_path / "config.json"
|
| 152 |
+
_write_json(
|
| 153 |
+
config_path,
|
| 154 |
+
{"model_name": "moonshotai/Kimi-K2.6", "tool_runtime": "hybrid"},
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
with pytest.raises(ValidationError):
|
| 158 |
+
config_module.load_config(str(config_path))
|
tests/unit/test_dataset_uploads.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
|
| 6 |
+
import httpx
|
| 7 |
+
import pytest
|
| 8 |
+
from fastapi import HTTPException, UploadFile
|
| 9 |
+
from huggingface_hub.errors import HfHubHTTPError
|
| 10 |
+
from starlette.datastructures import FormData
|
| 11 |
+
|
| 12 |
+
_BACKEND_DIR = Path(__file__).resolve().parent.parent.parent / "backend"
|
| 13 |
+
if str(_BACKEND_DIR) not in sys.path:
|
| 14 |
+
sys.path.insert(0, str(_BACKEND_DIR))
|
| 15 |
+
|
| 16 |
+
import dataset_uploads # noqa: E402
|
| 17 |
+
from routes import agent # noqa: E402
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _upload(filename: str, content: bytes = b"a,b\n1,2\n") -> UploadFile:
|
| 21 |
+
return UploadFile(filename=filename, file=io.BytesIO(content))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _track_close(upload: UploadFile):
|
| 25 |
+
state = {"closed": False}
|
| 26 |
+
original_close = upload.close
|
| 27 |
+
|
| 28 |
+
async def close():
|
| 29 |
+
state["closed"] = True
|
| 30 |
+
await original_close()
|
| 31 |
+
|
| 32 |
+
upload.close = close
|
| 33 |
+
return state
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _request(
|
| 37 |
+
upload: UploadFile | None = None,
|
| 38 |
+
headers: dict[str, str] | None = None,
|
| 39 |
+
):
|
| 40 |
+
state = {"form_called": False}
|
| 41 |
+
|
| 42 |
+
class FakeRequest:
|
| 43 |
+
def __init__(self):
|
| 44 |
+
self.headers = headers or {}
|
| 45 |
+
self.cookies = {}
|
| 46 |
+
|
| 47 |
+
async def form(self, **_kwargs):
|
| 48 |
+
state["form_called"] = True
|
| 49 |
+
if upload is None:
|
| 50 |
+
raise AssertionError("request.form() should not be called")
|
| 51 |
+
return FormData([("file", upload)])
|
| 52 |
+
|
| 53 |
+
return FakeRequest(), state
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_sanitize_dataset_filename_strips_paths_and_unsafe_chars():
|
| 57 |
+
assert (
|
| 58 |
+
dataset_uploads.sanitize_dataset_filename("../../bad file (final).CSV")
|
| 59 |
+
== "bad-file-final.csv"
|
| 60 |
+
)
|
| 61 |
+
assert dataset_uploads.sanitize_dataset_filename("") == "dataset.csv"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_dataset_format_rejects_unsupported_extension():
|
| 65 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 66 |
+
dataset_uploads.dataset_format_from_filename("notes.txt")
|
| 67 |
+
|
| 68 |
+
assert exc_info.value.status_code == 400
|
| 69 |
+
|
| 70 |
+
with pytest.raises(HTTPException):
|
| 71 |
+
dataset_uploads.dataset_format_from_filename("notes")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_dataset_repo_card_exposes_each_upload_as_config():
|
| 75 |
+
card = dataset_uploads.dataset_repo_card(
|
| 76 |
+
"alice/ml-intern-s1-datasets",
|
| 77 |
+
[
|
| 78 |
+
"README.md",
|
| 79 |
+
"uploads/oldabc/rows.jsonl",
|
| 80 |
+
"uploads/oldabc/rows.jsonl",
|
| 81 |
+
"uploads/newdef/table.csv",
|
| 82 |
+
],
|
| 83 |
+
).decode("utf-8")
|
| 84 |
+
|
| 85 |
+
assert "configs:" in card
|
| 86 |
+
assert "- config_name: upload_oldabc" in card
|
| 87 |
+
assert ' path: "uploads/oldabc/rows.jsonl"' in card
|
| 88 |
+
assert "- config_name: upload_newdef" in card
|
| 89 |
+
assert ' path: "uploads/newdef/table.csv"' in card
|
| 90 |
+
assert card.count("- config_name: upload_oldabc") == 1
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@pytest.mark.asyncio
|
| 94 |
+
async def test_validate_dataset_upload_rejects_size_over_limit(monkeypatch):
|
| 95 |
+
monkeypatch.setattr(dataset_uploads, "MAX_DATASET_UPLOAD_BYTES", 3)
|
| 96 |
+
upload = _upload("rows.csv", b"abcd")
|
| 97 |
+
try:
|
| 98 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 99 |
+
await dataset_uploads.validate_dataset_upload(upload)
|
| 100 |
+
finally:
|
| 101 |
+
await upload.close()
|
| 102 |
+
|
| 103 |
+
assert exc_info.value.status_code == 413
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@pytest.mark.asyncio
|
| 107 |
+
async def test_push_dataset_upload_creates_private_repo_and_uploads_file(monkeypatch):
|
| 108 |
+
instances = []
|
| 109 |
+
|
| 110 |
+
class FakeApi:
|
| 111 |
+
def __init__(self, token):
|
| 112 |
+
self.token = token
|
| 113 |
+
self.create_calls = []
|
| 114 |
+
self.settings_calls = []
|
| 115 |
+
self.list_calls = []
|
| 116 |
+
self.upload_calls = []
|
| 117 |
+
instances.append(self)
|
| 118 |
+
|
| 119 |
+
def create_repo(self, **kwargs):
|
| 120 |
+
self.create_calls.append(kwargs)
|
| 121 |
+
|
| 122 |
+
def update_repo_settings(self, **kwargs):
|
| 123 |
+
self.settings_calls.append(kwargs)
|
| 124 |
+
|
| 125 |
+
def list_repo_files(self, **kwargs):
|
| 126 |
+
self.list_calls.append(kwargs)
|
| 127 |
+
return [
|
| 128 |
+
"README.md",
|
| 129 |
+
"uploads/oldupload/old.jsonl",
|
| 130 |
+
"uploads/notes.txt",
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
def upload_file(self, **kwargs):
|
| 134 |
+
if kwargs["path_in_repo"] != "README.md":
|
| 135 |
+
assert kwargs["path_or_fileobj"] == b"a,b\n1,2\n"
|
| 136 |
+
self.upload_calls.append(kwargs)
|
| 137 |
+
|
| 138 |
+
monkeypatch.setattr(dataset_uploads, "HfApi", FakeApi)
|
| 139 |
+
monkeypatch.setattr(
|
| 140 |
+
dataset_uploads.uuid,
|
| 141 |
+
"uuid4",
|
| 142 |
+
lambda: SimpleNamespace(hex="feedfacecafebeef"),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
upload = _upload("../Data Set.CSV")
|
| 146 |
+
try:
|
| 147 |
+
result = await dataset_uploads.push_dataset_upload_to_hub(
|
| 148 |
+
upload=upload,
|
| 149 |
+
session_id="12345678-90ab-cdef-1234-567890abcdef",
|
| 150 |
+
hf_username="alice",
|
| 151 |
+
hf_token="hf-token",
|
| 152 |
+
)
|
| 153 |
+
finally:
|
| 154 |
+
await upload.close()
|
| 155 |
+
|
| 156 |
+
api = instances[0]
|
| 157 |
+
assert api.token == "hf-token"
|
| 158 |
+
assert api.create_calls == [
|
| 159 |
+
{
|
| 160 |
+
"repo_id": "alice/ml-intern-12345678-datasets",
|
| 161 |
+
"repo_type": "dataset",
|
| 162 |
+
"private": True,
|
| 163 |
+
"exist_ok": True,
|
| 164 |
+
}
|
| 165 |
+
]
|
| 166 |
+
assert api.settings_calls == [
|
| 167 |
+
{
|
| 168 |
+
"repo_id": "alice/ml-intern-12345678-datasets",
|
| 169 |
+
"repo_type": "dataset",
|
| 170 |
+
"private": True,
|
| 171 |
+
}
|
| 172 |
+
]
|
| 173 |
+
assert api.list_calls == [
|
| 174 |
+
{
|
| 175 |
+
"repo_id": "alice/ml-intern-12345678-datasets",
|
| 176 |
+
"repo_type": "dataset",
|
| 177 |
+
}
|
| 178 |
+
]
|
| 179 |
+
assert [call["path_in_repo"] for call in api.upload_calls] == [
|
| 180 |
+
"uploads/feedfacecafe/Data-Set.csv",
|
| 181 |
+
"README.md",
|
| 182 |
+
]
|
| 183 |
+
readme = api.upload_calls[1]["path_or_fileobj"].decode("utf-8")
|
| 184 |
+
assert "- config_name: upload_oldupload" in readme
|
| 185 |
+
assert ' path: "uploads/oldupload/old.jsonl"' in readme
|
| 186 |
+
assert "- config_name: upload_feedfacecafe" in readme
|
| 187 |
+
assert ' path: "uploads/feedfacecafe/Data-Set.csv"' in readme
|
| 188 |
+
assert result.repo_id == "alice/ml-intern-12345678-datasets"
|
| 189 |
+
assert result.config_name == "upload_feedfacecafe"
|
| 190 |
+
assert result.format == "csv"
|
| 191 |
+
assert result.load_dataset_snippet == (
|
| 192 |
+
"from datasets import load_dataset\n\n"
|
| 193 |
+
'dataset = load_dataset("alice/ml-intern-12345678-datasets", '
|
| 194 |
+
'"upload_feedfacecafe", split="train", token=True)'
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@pytest.mark.asyncio
|
| 199 |
+
async def test_upload_route_requires_hf_token_without_parsing_upload(monkeypatch):
|
| 200 |
+
monkeypatch.delenv("HF_TOKEN", raising=False)
|
| 201 |
+
upload = _upload("rows.csv")
|
| 202 |
+
close_state = _track_close(upload)
|
| 203 |
+
request, request_state = _request(upload)
|
| 204 |
+
|
| 205 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 206 |
+
return SimpleNamespace(
|
| 207 |
+
is_active=True,
|
| 208 |
+
is_processing=False,
|
| 209 |
+
session=SimpleNamespace(pending_approval=None),
|
| 210 |
+
hf_username="alice",
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 214 |
+
|
| 215 |
+
try:
|
| 216 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 217 |
+
await agent.upload_session_dataset(
|
| 218 |
+
"s1",
|
| 219 |
+
request,
|
| 220 |
+
{"user_id": "u1", "username": "alice"},
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
assert exc_info.value.status_code == 401
|
| 224 |
+
assert request_state["form_called"] is False
|
| 225 |
+
assert close_state["closed"] is False
|
| 226 |
+
finally:
|
| 227 |
+
await upload.close()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@pytest.mark.asyncio
|
| 231 |
+
async def test_upload_route_rejects_content_length_before_parsing(monkeypatch):
|
| 232 |
+
upload = _upload("rows.csv")
|
| 233 |
+
close_state = _track_close(upload)
|
| 234 |
+
request, request_state = _request(
|
| 235 |
+
upload,
|
| 236 |
+
headers={
|
| 237 |
+
"content-length": str(
|
| 238 |
+
dataset_uploads.MAX_DATASET_UPLOAD_BYTES
|
| 239 |
+
+ agent.DATASET_UPLOAD_MULTIPART_SLACK_BYTES
|
| 240 |
+
+ 1
|
| 241 |
+
)
|
| 242 |
+
},
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 246 |
+
raise AssertionError("session access should not run for oversized uploads")
|
| 247 |
+
|
| 248 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 249 |
+
|
| 250 |
+
try:
|
| 251 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 252 |
+
await agent.upload_session_dataset(
|
| 253 |
+
"s1",
|
| 254 |
+
request,
|
| 255 |
+
{
|
| 256 |
+
"user_id": "u1",
|
| 257 |
+
"username": "alice",
|
| 258 |
+
agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
|
| 259 |
+
},
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
assert exc_info.value.status_code == 413
|
| 263 |
+
assert request_state["form_called"] is False
|
| 264 |
+
assert close_state["closed"] is False
|
| 265 |
+
finally:
|
| 266 |
+
await upload.close()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@pytest.mark.asyncio
|
| 270 |
+
async def test_upload_route_rejects_busy_session_without_parsing_upload(monkeypatch):
|
| 271 |
+
upload = _upload("rows.csv")
|
| 272 |
+
close_state = _track_close(upload)
|
| 273 |
+
request, request_state = _request(upload)
|
| 274 |
+
|
| 275 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 276 |
+
return SimpleNamespace(
|
| 277 |
+
is_active=True,
|
| 278 |
+
is_processing=True,
|
| 279 |
+
session=SimpleNamespace(pending_approval=None),
|
| 280 |
+
hf_username="alice",
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 284 |
+
|
| 285 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 286 |
+
await agent.upload_session_dataset(
|
| 287 |
+
"s1",
|
| 288 |
+
request,
|
| 289 |
+
{
|
| 290 |
+
"user_id": "u1",
|
| 291 |
+
"username": "alice",
|
| 292 |
+
agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
|
| 293 |
+
},
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
assert exc_info.value.status_code == 409
|
| 297 |
+
assert request_state["form_called"] is False
|
| 298 |
+
assert close_state["closed"] is False
|
| 299 |
+
await upload.close()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@pytest.mark.asyncio
|
| 303 |
+
async def test_upload_route_appends_context_note_and_persists(monkeypatch):
|
| 304 |
+
upload = _upload("rows.jsonl", b'{"text":"hi"}\n')
|
| 305 |
+
close_state = _track_close(upload)
|
| 306 |
+
request, request_state = _request(upload)
|
| 307 |
+
messages = []
|
| 308 |
+
persisted = []
|
| 309 |
+
agent_session = SimpleNamespace(
|
| 310 |
+
is_active=True,
|
| 311 |
+
is_processing=False,
|
| 312 |
+
session=SimpleNamespace(
|
| 313 |
+
pending_approval=None,
|
| 314 |
+
context_manager=SimpleNamespace(add_message=messages.append),
|
| 315 |
+
),
|
| 316 |
+
hf_username="alice",
|
| 317 |
+
)
|
| 318 |
+
uploaded = dataset_uploads.DatasetUpload(
|
| 319 |
+
session_id="s1",
|
| 320 |
+
repo_id="alice/ml-intern-s1-datasets",
|
| 321 |
+
repo_type="dataset",
|
| 322 |
+
private=True,
|
| 323 |
+
upload_id="abc123",
|
| 324 |
+
config_name="upload_abc123",
|
| 325 |
+
filename="rows.jsonl",
|
| 326 |
+
original_filename="rows.jsonl",
|
| 327 |
+
path_in_repo="uploads/abc123/rows.jsonl",
|
| 328 |
+
size_bytes=14,
|
| 329 |
+
format="jsonl",
|
| 330 |
+
hub_url="https://huggingface.co/datasets/alice/ml-intern-s1-datasets/blob/main/uploads/abc123/rows.jsonl",
|
| 331 |
+
load_dataset_snippet='dataset = load_dataset("json")',
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 335 |
+
return agent_session
|
| 336 |
+
|
| 337 |
+
async def fake_push_dataset_upload_to_hub(**kwargs):
|
| 338 |
+
assert kwargs["upload"] is upload
|
| 339 |
+
assert kwargs["hf_token"] == "hf-token"
|
| 340 |
+
return uploaded
|
| 341 |
+
|
| 342 |
+
async def fake_persist_session_snapshot(value):
|
| 343 |
+
persisted.append(value)
|
| 344 |
+
|
| 345 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 346 |
+
monkeypatch.setattr(
|
| 347 |
+
agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub
|
| 348 |
+
)
|
| 349 |
+
monkeypatch.setattr(
|
| 350 |
+
agent.session_manager,
|
| 351 |
+
"persist_session_snapshot",
|
| 352 |
+
fake_persist_session_snapshot,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
response = await agent.upload_session_dataset(
|
| 356 |
+
"s1",
|
| 357 |
+
request,
|
| 358 |
+
{
|
| 359 |
+
"user_id": "u1",
|
| 360 |
+
"username": "alice",
|
| 361 |
+
agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
|
| 362 |
+
},
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
assert response.repo_id == uploaded.repo_id
|
| 366 |
+
assert response.config_name == uploaded.config_name
|
| 367 |
+
assert response.path_in_repo == uploaded.path_in_repo
|
| 368 |
+
assert len(messages) == 1
|
| 369 |
+
assert messages[0].role == "user"
|
| 370 |
+
assert messages[0].content.startswith("[SYSTEM:")
|
| 371 |
+
assert uploaded.config_name in messages[0].content
|
| 372 |
+
assert uploaded.path_in_repo in messages[0].content
|
| 373 |
+
assert persisted == [agent_session]
|
| 374 |
+
assert request_state["form_called"] is True
|
| 375 |
+
assert close_state["closed"] is True
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
@pytest.mark.asyncio
|
| 379 |
+
async def test_upload_route_closes_upload_when_hub_upload_fails(monkeypatch):
|
| 380 |
+
upload = _upload("rows.csv")
|
| 381 |
+
close_state = _track_close(upload)
|
| 382 |
+
request, request_state = _request(upload)
|
| 383 |
+
|
| 384 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 385 |
+
return SimpleNamespace(
|
| 386 |
+
is_active=True,
|
| 387 |
+
is_processing=False,
|
| 388 |
+
session=SimpleNamespace(pending_approval=None),
|
| 389 |
+
hf_username="alice",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
async def fake_push_dataset_upload_to_hub(**_kwargs):
|
| 393 |
+
raise RuntimeError("hub unavailable")
|
| 394 |
+
|
| 395 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 396 |
+
monkeypatch.setattr(
|
| 397 |
+
agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 401 |
+
await agent.upload_session_dataset(
|
| 402 |
+
"s1",
|
| 403 |
+
request,
|
| 404 |
+
{
|
| 405 |
+
"user_id": "u1",
|
| 406 |
+
"username": "alice",
|
| 407 |
+
agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
|
| 408 |
+
},
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
assert exc_info.value.status_code == 502
|
| 412 |
+
assert exc_info.value.detail == "Dataset upload failed. Please try again."
|
| 413 |
+
assert request_state["form_called"] is True
|
| 414 |
+
assert close_state["closed"] is True
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@pytest.mark.asyncio
|
| 418 |
+
async def test_upload_route_maps_hub_permission_error_safely(monkeypatch):
|
| 419 |
+
upload = _upload("rows.csv")
|
| 420 |
+
close_state = _track_close(upload)
|
| 421 |
+
request, request_state = _request(upload)
|
| 422 |
+
|
| 423 |
+
async def fake_check_session_access(*_args, **_kwargs):
|
| 424 |
+
return SimpleNamespace(
|
| 425 |
+
is_active=True,
|
| 426 |
+
is_processing=False,
|
| 427 |
+
session=SimpleNamespace(pending_approval=None),
|
| 428 |
+
hf_username="alice",
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
async def fake_push_dataset_upload_to_hub(**_kwargs):
|
| 432 |
+
response = httpx.Response(
|
| 433 |
+
403,
|
| 434 |
+
request=httpx.Request("POST", "https://huggingface.co/api/datasets"),
|
| 435 |
+
headers={"x-request-id": "req-123"},
|
| 436 |
+
)
|
| 437 |
+
raise HfHubHTTPError(
|
| 438 |
+
"403 Forbidden: token hf_secret cannot write",
|
| 439 |
+
response=response,
|
| 440 |
+
server_message="token hf_secret cannot write",
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
monkeypatch.setattr(agent, "_check_session_access", fake_check_session_access)
|
| 444 |
+
monkeypatch.setattr(
|
| 445 |
+
agent, "push_dataset_upload_to_hub", fake_push_dataset_upload_to_hub
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
with pytest.raises(HTTPException) as exc_info:
|
| 449 |
+
await agent.upload_session_dataset(
|
| 450 |
+
"s1",
|
| 451 |
+
request,
|
| 452 |
+
{
|
| 453 |
+
"user_id": "u1",
|
| 454 |
+
"username": "alice",
|
| 455 |
+
agent.INTERNAL_HF_TOKEN_KEY: "hf-token",
|
| 456 |
+
},
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
assert exc_info.value.status_code == 403
|
| 460 |
+
assert exc_info.value.detail == (
|
| 461 |
+
"Hugging Face denied permission to create or write to the dataset repo."
|
| 462 |
+
)
|
| 463 |
+
assert "hf_secret" not in exc_info.value.detail
|
| 464 |
+
assert request_state["form_called"] is True
|
| 465 |
+
assert close_state["closed"] is True
|
tests/unit/test_hub_artifacts.py
CHANGED
|
@@ -549,7 +549,7 @@ def test_sitecustomize_caches_lazy_collection_slug_across_bootstraps(
|
|
| 549 |
]
|
| 550 |
|
| 551 |
|
| 552 |
-
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
| 553 |
import huggingface_hub as hub
|
| 554 |
from huggingface_hub import HfApi
|
| 555 |
|
|
@@ -579,6 +579,10 @@ def test_sitecustomize_skips_sandbox_space_registration(monkeypatch):
|
|
| 579 |
def fake_add_collection_item(self, **kwargs):
|
| 580 |
collection_items.append(kwargs)
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
|
| 583 |
monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
|
| 584 |
monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
|
|
|
|
| 549 |
]
|
| 550 |
|
| 551 |
|
| 552 |
+
def test_sitecustomize_skips_sandbox_space_registration(monkeypatch, tmp_path):
|
| 553 |
import huggingface_hub as hub
|
| 554 |
from huggingface_hub import HfApi
|
| 555 |
|
|
|
|
| 579 |
def fake_add_collection_item(self, **kwargs):
|
| 580 |
collection_items.append(kwargs)
|
| 581 |
|
| 582 |
+
monkeypatch.setenv(
|
| 583 |
+
"ML_INTERN_ARTIFACT_COLLECTION_CACHE",
|
| 584 |
+
str(tmp_path / "collection-slug.txt"),
|
| 585 |
+
)
|
| 586 |
monkeypatch.setattr(HfApi, "upload_file", fake_upload_file)
|
| 587 |
monkeypatch.setattr(HfApi, "create_collection", fake_create_collection)
|
| 588 |
monkeypatch.setattr(HfApi, "add_collection_item", fake_add_collection_item)
|
tests/unit/test_no_tool_continuation_guard.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from agent.config import Config
|
| 7 |
+
from agent.core import agent_loop
|
| 8 |
+
from agent.core.agent_loop import Handlers, LLMResult
|
| 9 |
+
from agent.core.session import Session
|
| 10 |
+
from agent.tools.plan_tool import PlanTool
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FakeToolRouter:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.calls = []
|
| 16 |
+
|
| 17 |
+
def get_tool_specs_for_llm(self):
|
| 18 |
+
return [
|
| 19 |
+
{
|
| 20 |
+
"type": "function",
|
| 21 |
+
"function": {
|
| 22 |
+
"name": "plan_tool",
|
| 23 |
+
"description": "Update plan",
|
| 24 |
+
"parameters": {"type": "object"},
|
| 25 |
+
},
|
| 26 |
+
}
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
async def call_tool(self, name, arguments, session=None, tool_call_id=None):
|
| 30 |
+
self.calls.append((name, arguments, tool_call_id))
|
| 31 |
+
if name == "plan_tool" and session is not None:
|
| 32 |
+
session.current_plan = [dict(todo) for todo in arguments["todos"]]
|
| 33 |
+
return "plan updated", True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.mark.asyncio
|
| 37 |
+
async def test_plan_tool_stores_session_scoped_plan():
|
| 38 |
+
events = []
|
| 39 |
+
|
| 40 |
+
class FakeSession:
|
| 41 |
+
current_plan = []
|
| 42 |
+
|
| 43 |
+
async def send_event(self, event):
|
| 44 |
+
events.append(event)
|
| 45 |
+
|
| 46 |
+
session = FakeSession()
|
| 47 |
+
todos = [{"id": "1", "content": "Smoke test", "status": "in_progress"}]
|
| 48 |
+
|
| 49 |
+
result = await PlanTool(session=session).execute({"todos": todos})
|
| 50 |
+
|
| 51 |
+
assert result["isError"] is False
|
| 52 |
+
assert session.current_plan == todos
|
| 53 |
+
assert events[0].event_type == "plan_update"
|
| 54 |
+
assert events[0].data == {"plan": todos}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@pytest.mark.asyncio
|
| 58 |
+
async def test_no_tool_response_retries_when_plan_is_incomplete(monkeypatch):
|
| 59 |
+
config = Config.model_validate(
|
| 60 |
+
{"model_name": "openai/test", "save_sessions": False}
|
| 61 |
+
)
|
| 62 |
+
event_queue = asyncio.Queue()
|
| 63 |
+
router = FakeToolRouter()
|
| 64 |
+
session = Session(
|
| 65 |
+
event_queue,
|
| 66 |
+
config,
|
| 67 |
+
tool_router=router,
|
| 68 |
+
stream=False,
|
| 69 |
+
)
|
| 70 |
+
session.current_plan = [
|
| 71 |
+
{
|
| 72 |
+
"id": "1",
|
| 73 |
+
"content": "Write and smoke-test training script",
|
| 74 |
+
"status": "in_progress",
|
| 75 |
+
},
|
| 76 |
+
{"id": "2", "content": "Launch full training job", "status": "pending"},
|
| 77 |
+
]
|
| 78 |
+
calls = []
|
| 79 |
+
|
| 80 |
+
async def fake_call_llm_non_streaming(session, messages, tools, llm_params):
|
| 81 |
+
calls.append(messages)
|
| 82 |
+
if len(calls) == 1:
|
| 83 |
+
return LLMResult(
|
| 84 |
+
content="I should keep going, but I forgot to call a tool.",
|
| 85 |
+
tool_calls_acc={},
|
| 86 |
+
token_count=10,
|
| 87 |
+
finish_reason="stop",
|
| 88 |
+
)
|
| 89 |
+
if len(calls) == 2:
|
| 90 |
+
assert "CONTINUATION GUARD" in messages[-1].content
|
| 91 |
+
return LLMResult(
|
| 92 |
+
content=None,
|
| 93 |
+
tool_calls_acc={
|
| 94 |
+
0: {
|
| 95 |
+
"id": "call_1",
|
| 96 |
+
"function": {
|
| 97 |
+
"name": "plan_tool",
|
| 98 |
+
"arguments": json.dumps(
|
| 99 |
+
{
|
| 100 |
+
"todos": [
|
| 101 |
+
{
|
| 102 |
+
"id": "1",
|
| 103 |
+
"content": "Write and smoke-test training script",
|
| 104 |
+
"status": "completed",
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"id": "2",
|
| 108 |
+
"content": "Launch full training job",
|
| 109 |
+
"status": "completed",
|
| 110 |
+
},
|
| 111 |
+
]
|
| 112 |
+
}
|
| 113 |
+
),
|
| 114 |
+
},
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
token_count=20,
|
| 118 |
+
finish_reason="tool_calls",
|
| 119 |
+
)
|
| 120 |
+
return LLMResult(
|
| 121 |
+
content="Done.",
|
| 122 |
+
tool_calls_acc={},
|
| 123 |
+
token_count=30,
|
| 124 |
+
finish_reason="stop",
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
monkeypatch.setattr(
|
| 128 |
+
agent_loop, "_resolve_llm_params", lambda *_, **__: {"model": "openai/test"}
|
| 129 |
+
)
|
| 130 |
+
monkeypatch.setattr(
|
| 131 |
+
agent_loop, "_call_llm_non_streaming", fake_call_llm_non_streaming
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
final = await Handlers.run_agent(session, "continue")
|
| 135 |
+
|
| 136 |
+
assert final == "Done."
|
| 137 |
+
assert len(calls) == 3
|
| 138 |
+
assert router.calls[0][0] == "plan_tool"
|
| 139 |
+
assert all(todo["status"] == "completed" for todo in session.current_plan)
|
| 140 |
+
events = []
|
| 141 |
+
while not event_queue.empty():
|
| 142 |
+
events.append(await event_queue.get())
|
| 143 |
+
assert any(
|
| 144 |
+
event.event_type == "tool_log"
|
| 145 |
+
and "text-only response" in (event.data or {}).get("log", "")
|
| 146 |
+
for event in events
|
| 147 |
+
)
|
tests/unit/test_sandbox_auto_start.py
CHANGED
|
@@ -1,7 +1,15 @@
|
|
|
|
|
| 1 |
from types import SimpleNamespace
|
| 2 |
from pathlib import Path
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from agent.core.agent_loop import _needs_approval
|
|
|
|
|
|
|
|
|
|
| 5 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
| 6 |
|
| 7 |
|
|
@@ -34,3 +42,102 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create():
|
|
| 34 |
in tool_specs["sandbox_create"]
|
| 35 |
)
|
| 36 |
assert "started automatically for normal CPU work" in tool_specs["bash"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
from types import SimpleNamespace
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from agent.config import Config
|
| 8 |
+
from agent.core import agent_loop
|
| 9 |
from agent.core.agent_loop import _needs_approval
|
| 10 |
+
from agent.core.session import OpType
|
| 11 |
+
from agent.core.tools import create_builtin_tools
|
| 12 |
+
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
|
| 13 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
| 14 |
|
| 15 |
|
|
|
|
| 42 |
in tool_specs["sandbox_create"]
|
| 43 |
)
|
| 44 |
assert "started automatically for normal CPU work" in tool_specs["bash"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_prompt_rejects_local_machine_paths_for_hf_jobs_scripts():
|
| 48 |
+
prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
|
| 49 |
+
|
| 50 |
+
assert "Never pass a local machine path to hf_jobs.script" in prompt
|
| 51 |
+
assert "/fsx/..." in prompt
|
| 52 |
+
assert "inline Python source code" in prompt
|
| 53 |
+
assert "a file already written in the session sandbox" in prompt
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_prompt_and_hf_jobs_spec_require_gpu_preflight_for_gpu_jobs():
|
| 57 |
+
prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
|
| 58 |
+
jobs_description = HF_JOBS_TOOL_SPEC["description"]
|
| 59 |
+
|
| 60 |
+
assert "GPU preflight is mandatory before hf_jobs" in prompt
|
| 61 |
+
assert "GPU sandbox smoke test" in prompt
|
| 62 |
+
assert "If you skip GPU sandbox preflight" in prompt
|
| 63 |
+
assert "you MUST create a GPU sandbox with sandbox_create first" in jobs_description
|
| 64 |
+
assert "If skipped, state why before calling hf_jobs" in jobs_description
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_local_tool_runtime_excludes_sandbox_create():
|
| 68 |
+
tool_names = {tool.name for tool in create_builtin_tools(local_mode=True)}
|
| 69 |
+
|
| 70 |
+
assert {"bash", "read", "write", "edit"} <= tool_names
|
| 71 |
+
assert "sandbox_create" not in tool_names
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_sandbox_tool_runtime_includes_sandbox_create():
|
| 75 |
+
tool_names = {tool.name for tool in create_builtin_tools(local_mode=False)}
|
| 76 |
+
|
| 77 |
+
assert {"sandbox_create", "bash", "read", "write", "edit"} <= tool_names
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@pytest.mark.asyncio
|
| 81 |
+
async def test_cli_sandbox_runtime_preloads_and_tears_down_sandbox(monkeypatch):
|
| 82 |
+
started = []
|
| 83 |
+
torn_down = []
|
| 84 |
+
|
| 85 |
+
class FakeToolRouter:
|
| 86 |
+
tools = {}
|
| 87 |
+
|
| 88 |
+
def get_tool_specs_for_llm(self):
|
| 89 |
+
return []
|
| 90 |
+
|
| 91 |
+
async def __aenter__(self):
|
| 92 |
+
return self
|
| 93 |
+
|
| 94 |
+
async def __aexit__(self, exc_type, exc, tb):
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
def fake_start_cpu_sandbox_preload(session):
|
| 98 |
+
started.append(session)
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
async def fake_teardown_session_sandbox(session):
|
| 102 |
+
torn_down.append(session)
|
| 103 |
+
|
| 104 |
+
monkeypatch.setattr(
|
| 105 |
+
agent_loop, "start_cpu_sandbox_preload", fake_start_cpu_sandbox_preload
|
| 106 |
+
)
|
| 107 |
+
monkeypatch.setattr(
|
| 108 |
+
agent_loop, "teardown_session_sandbox", fake_teardown_session_sandbox
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
submission_queue = asyncio.Queue()
|
| 112 |
+
event_queue = asyncio.Queue()
|
| 113 |
+
session_holder = [None]
|
| 114 |
+
config = Config.model_validate(
|
| 115 |
+
{"model_name": "openai/gpt-5.5", "save_sessions": False}
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
task = asyncio.create_task(
|
| 119 |
+
agent_loop.submission_loop(
|
| 120 |
+
submission_queue,
|
| 121 |
+
event_queue,
|
| 122 |
+
config=config,
|
| 123 |
+
tool_router=FakeToolRouter(),
|
| 124 |
+
session_holder=session_holder,
|
| 125 |
+
hf_token="hf-token",
|
| 126 |
+
user_id="tester",
|
| 127 |
+
local_mode=False,
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
ready = await asyncio.wait_for(event_queue.get(), timeout=1)
|
| 132 |
+
assert ready.event_type == "ready"
|
| 133 |
+
assert started == [session_holder[0]]
|
| 134 |
+
assert session_holder[0].local_mode is False
|
| 135 |
+
|
| 136 |
+
await submission_queue.put(
|
| 137 |
+
SimpleNamespace(
|
| 138 |
+
operation=SimpleNamespace(op_type=OpType.SHUTDOWN, data=None),
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
await asyncio.wait_for(task, timeout=1)
|
| 142 |
+
|
| 143 |
+
assert torn_down == [session_holder[0]]
|
tests/unit/test_sandbox_private_spaces.py
CHANGED
|
@@ -13,6 +13,28 @@ def _fail_metadata_update(*args, **kwargs):
|
|
| 13 |
raise AssertionError("sandbox creation should not update Space metadata")
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
| 17 |
duplicate_kwargs = {}
|
| 18 |
logs: list[str] = []
|
|
@@ -22,8 +44,25 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
|
| 22 |
def __init__(self, token=None):
|
| 23 |
self.token = token
|
| 24 |
|
| 25 |
-
def
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 29 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
@@ -45,12 +84,38 @@ def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
|
| 45 |
|
| 46 |
Sandbox.create(owner="alice", token="hf-token", log=logs.append)
|
| 47 |
|
|
|
|
| 48 |
assert duplicate_kwargs["private"] is True
|
| 49 |
-
assert duplicate_kwargs["
|
| 50 |
assert requested_hardware == []
|
| 51 |
assert not any("sleep time" in log for log in logs)
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
| 55 |
runtime_calls = 0
|
| 56 |
|
|
@@ -67,7 +132,16 @@ def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
|
| 67 |
def __init__(self, token=None):
|
| 68 |
self.token = token
|
| 69 |
|
| 70 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
pass
|
| 72 |
|
| 73 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
|
@@ -107,8 +181,25 @@ def test_sandbox_client_configures_gpu_at_duplication(monkeypatch):
|
|
| 107 |
def __init__(self, token=None):
|
| 108 |
self.token = token
|
| 109 |
|
| 110 |
-
def
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 114 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
@@ -137,8 +228,9 @@ def test_sandbox_client_configures_gpu_at_duplication(monkeypatch):
|
|
| 137 |
)
|
| 138 |
|
| 139 |
assert sandbox.space_id.startswith("alice/sandbox-")
|
| 140 |
-
assert duplicate_kwargs["
|
| 141 |
-
assert duplicate_kwargs["
|
|
|
|
| 142 |
assert requested_hardware == []
|
| 143 |
assert "Using duplicated Space hardware: t4-small" in logs
|
| 144 |
assert "Using duplicated Space sleep time: 2700s" in logs
|
|
@@ -153,8 +245,25 @@ def test_sandbox_client_logs_cpu_sleep_time_as_hub_fixed(monkeypatch):
|
|
| 153 |
def __init__(self, token=None):
|
| 154 |
self.token = token
|
| 155 |
|
| 156 |
-
def
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 160 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
@@ -180,8 +289,9 @@ def test_sandbox_client_logs_cpu_sleep_time_as_hub_fixed(monkeypatch):
|
|
| 180 |
log=logs.append,
|
| 181 |
)
|
| 182 |
|
| 183 |
-
assert duplicate_kwargs["
|
| 184 |
-
assert duplicate_kwargs["
|
|
|
|
| 185 |
assert requested_hardware == []
|
| 186 |
assert "Using duplicated Space hardware: cpu-basic" in logs
|
| 187 |
assert (
|
|
@@ -310,6 +420,71 @@ def test_ensure_sandbox_overrides_private_argument(monkeypatch):
|
|
| 310 |
assert persisted[-1]["sandbox_status"] == "active"
|
| 311 |
|
| 312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
|
| 314 |
active_creates = 0
|
| 315 |
max_active_creates = 0
|
|
@@ -429,7 +604,7 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
|
|
| 429 |
space_id="alice/sandbox-cpu",
|
| 430 |
url="https://huggingface.co/spaces/alice/sandbox-cpu",
|
| 431 |
_owns_space=True,
|
| 432 |
-
delete=lambda: deleted.append("alice/sandbox-cpu"),
|
| 433 |
)
|
| 434 |
self.sandbox_hardware = "cpu-basic"
|
| 435 |
self.sandbox_preload_task = None
|
|
@@ -474,10 +649,11 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch):
|
|
| 474 |
|
| 475 |
def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
| 476 |
deleted: list[str] = []
|
|
|
|
| 477 |
persisted: list[dict] = []
|
| 478 |
|
| 479 |
-
async def fake_record_sandbox_destroy(*args, **kwargs):
|
| 480 |
-
|
| 481 |
|
| 482 |
monkeypatch.setattr(
|
| 483 |
telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
|
|
@@ -485,20 +661,28 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 485 |
|
| 486 |
async def run():
|
| 487 |
cancel_event = threading.Event()
|
|
|
|
| 488 |
|
| 489 |
async def preload():
|
| 490 |
await asyncio.sleep(0)
|
| 491 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 492 |
session = SimpleNamespace(
|
| 493 |
session_id="s1",
|
| 494 |
sandbox=SimpleNamespace(
|
| 495 |
space_id="alice/sandbox-12345678",
|
| 496 |
_owns_space=True,
|
| 497 |
-
delete=
|
| 498 |
),
|
| 499 |
sandbox_hardware="cpu-basic",
|
| 500 |
sandbox_preload_task=asyncio.create_task(preload()),
|
| 501 |
sandbox_preload_cancel_event=cancel_event,
|
|
|
|
| 502 |
persistence_store=SimpleNamespace(
|
| 503 |
update_session_fields=lambda session_id, **fields: _record_metadata(
|
| 504 |
session_id, fields
|
|
@@ -507,17 +691,33 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 507 |
)
|
| 508 |
|
| 509 |
await sandbox_tool.teardown_session_sandbox(session)
|
| 510 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
|
| 512 |
async def _record_metadata(session_id, fields):
|
| 513 |
persisted.append({"session_id": session_id, **fields})
|
| 514 |
|
| 515 |
-
session, cancel_event = asyncio.run(run())
|
| 516 |
|
| 517 |
assert cancel_event.is_set()
|
| 518 |
assert deleted == ["alice/sandbox-12345678"]
|
|
|
|
| 519 |
assert session.sandbox is None
|
| 520 |
assert session.sandbox_hardware is None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
assert persisted[-1]["session_id"] == "s1"
|
| 522 |
assert persisted[-1]["sandbox_space_id"] is None
|
| 523 |
assert persisted[-1]["sandbox_status"] == "destroyed"
|
|
|
|
| 13 |
raise AssertionError("sandbox creation should not update Space metadata")
|
| 14 |
|
| 15 |
|
| 16 |
+
def _capture_duplicate_repo_call(
|
| 17 |
+
captured,
|
| 18 |
+
*,
|
| 19 |
+
from_id,
|
| 20 |
+
to_id,
|
| 21 |
+
repo_type,
|
| 22 |
+
private,
|
| 23 |
+
space_hardware,
|
| 24 |
+
space_sleep_time=None,
|
| 25 |
+
):
|
| 26 |
+
captured.update(
|
| 27 |
+
{
|
| 28 |
+
"from_id": from_id,
|
| 29 |
+
"to_id": to_id,
|
| 30 |
+
"repo_type": repo_type,
|
| 31 |
+
"private": private,
|
| 32 |
+
"space_hardware": space_hardware,
|
| 33 |
+
"space_sleep_time": space_sleep_time,
|
| 34 |
+
}
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
def test_sandbox_client_defaults_to_private_spaces(monkeypatch):
|
| 39 |
duplicate_kwargs = {}
|
| 40 |
logs: list[str] = []
|
|
|
|
| 44 |
def __init__(self, token=None):
|
| 45 |
self.token = token
|
| 46 |
|
| 47 |
+
def duplicate_repo(
|
| 48 |
+
self,
|
| 49 |
+
*,
|
| 50 |
+
from_id,
|
| 51 |
+
to_id,
|
| 52 |
+
repo_type,
|
| 53 |
+
private,
|
| 54 |
+
space_hardware,
|
| 55 |
+
space_sleep_time=None,
|
| 56 |
+
):
|
| 57 |
+
_capture_duplicate_repo_call(
|
| 58 |
+
duplicate_kwargs,
|
| 59 |
+
from_id=from_id,
|
| 60 |
+
to_id=to_id,
|
| 61 |
+
repo_type=repo_type,
|
| 62 |
+
private=private,
|
| 63 |
+
space_hardware=space_hardware,
|
| 64 |
+
space_sleep_time=space_sleep_time,
|
| 65 |
+
)
|
| 66 |
|
| 67 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 68 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
|
|
| 84 |
|
| 85 |
Sandbox.create(owner="alice", token="hf-token", log=logs.append)
|
| 86 |
|
| 87 |
+
assert duplicate_kwargs["repo_type"] == "space"
|
| 88 |
assert duplicate_kwargs["private"] is True
|
| 89 |
+
assert duplicate_kwargs["space_hardware"] == "cpu-basic"
|
| 90 |
assert requested_hardware == []
|
| 91 |
assert not any("sleep time" in log for log in logs)
|
| 92 |
|
| 93 |
|
| 94 |
+
def test_sandbox_delete_uses_log_callback_without_stdout(monkeypatch, capsys):
|
| 95 |
+
deleted: list[tuple[str, str]] = []
|
| 96 |
+
|
| 97 |
+
class FakeApi:
|
| 98 |
+
def __init__(self, token=None):
|
| 99 |
+
self.token = token
|
| 100 |
+
|
| 101 |
+
def delete_repo(self, repo_id, repo_type):
|
| 102 |
+
deleted.append((repo_id, repo_type))
|
| 103 |
+
|
| 104 |
+
monkeypatch.setattr(sandbox_client, "HfApi", FakeApi)
|
| 105 |
+
|
| 106 |
+
sandbox = Sandbox("alice/sandbox-12345678", token="hf-token", _owns_space=True)
|
| 107 |
+
logs: list[str] = []
|
| 108 |
+
|
| 109 |
+
sandbox.delete(log=logs.append)
|
| 110 |
+
|
| 111 |
+
captured = capsys.readouterr()
|
| 112 |
+
assert captured.out == ""
|
| 113 |
+
assert captured.err == ""
|
| 114 |
+
assert deleted == [("alice/sandbox-12345678", "space")]
|
| 115 |
+
assert logs == ["Deleting sandbox: alice/sandbox-12345678...", "Deleted."]
|
| 116 |
+
assert sandbox._owns_space is False
|
| 117 |
+
|
| 118 |
+
|
| 119 |
def test_sandbox_client_retries_transient_runtime_404(monkeypatch):
|
| 120 |
runtime_calls = 0
|
| 121 |
|
|
|
|
| 132 |
def __init__(self, token=None):
|
| 133 |
self.token = token
|
| 134 |
|
| 135 |
+
def duplicate_repo(
|
| 136 |
+
self,
|
| 137 |
+
*,
|
| 138 |
+
from_id,
|
| 139 |
+
to_id,
|
| 140 |
+
repo_type,
|
| 141 |
+
private,
|
| 142 |
+
space_hardware,
|
| 143 |
+
space_sleep_time=None,
|
| 144 |
+
):
|
| 145 |
pass
|
| 146 |
|
| 147 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
|
|
|
| 181 |
def __init__(self, token=None):
|
| 182 |
self.token = token
|
| 183 |
|
| 184 |
+
def duplicate_repo(
|
| 185 |
+
self,
|
| 186 |
+
*,
|
| 187 |
+
from_id,
|
| 188 |
+
to_id,
|
| 189 |
+
repo_type,
|
| 190 |
+
private,
|
| 191 |
+
space_hardware,
|
| 192 |
+
space_sleep_time=None,
|
| 193 |
+
):
|
| 194 |
+
_capture_duplicate_repo_call(
|
| 195 |
+
duplicate_kwargs,
|
| 196 |
+
from_id=from_id,
|
| 197 |
+
to_id=to_id,
|
| 198 |
+
repo_type=repo_type,
|
| 199 |
+
private=private,
|
| 200 |
+
space_hardware=space_hardware,
|
| 201 |
+
space_sleep_time=space_sleep_time,
|
| 202 |
+
)
|
| 203 |
|
| 204 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 205 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
|
|
| 228 |
)
|
| 229 |
|
| 230 |
assert sandbox.space_id.startswith("alice/sandbox-")
|
| 231 |
+
assert duplicate_kwargs["repo_type"] == "space"
|
| 232 |
+
assert duplicate_kwargs["space_hardware"] == "t4-small"
|
| 233 |
+
assert duplicate_kwargs["space_sleep_time"] == 2700
|
| 234 |
assert requested_hardware == []
|
| 235 |
assert "Using duplicated Space hardware: t4-small" in logs
|
| 236 |
assert "Using duplicated Space sleep time: 2700s" in logs
|
|
|
|
| 245 |
def __init__(self, token=None):
|
| 246 |
self.token = token
|
| 247 |
|
| 248 |
+
def duplicate_repo(
|
| 249 |
+
self,
|
| 250 |
+
*,
|
| 251 |
+
from_id,
|
| 252 |
+
to_id,
|
| 253 |
+
repo_type,
|
| 254 |
+
private,
|
| 255 |
+
space_hardware,
|
| 256 |
+
space_sleep_time=None,
|
| 257 |
+
):
|
| 258 |
+
_capture_duplicate_repo_call(
|
| 259 |
+
duplicate_kwargs,
|
| 260 |
+
from_id=from_id,
|
| 261 |
+
to_id=to_id,
|
| 262 |
+
repo_type=repo_type,
|
| 263 |
+
private=private,
|
| 264 |
+
space_hardware=space_hardware,
|
| 265 |
+
space_sleep_time=space_sleep_time,
|
| 266 |
+
)
|
| 267 |
|
| 268 |
def request_space_hardware(self, space_id, hardware, sleep_time=None):
|
| 269 |
requested_hardware.append((space_id, hardware, sleep_time))
|
|
|
|
| 289 |
log=logs.append,
|
| 290 |
)
|
| 291 |
|
| 292 |
+
assert duplicate_kwargs["repo_type"] == "space"
|
| 293 |
+
assert duplicate_kwargs["space_hardware"] == "cpu-basic"
|
| 294 |
+
assert duplicate_kwargs["space_sleep_time"] == 2700
|
| 295 |
assert requested_hardware == []
|
| 296 |
assert "Using duplicated Space hardware: cpu-basic" in logs
|
| 297 |
assert (
|
|
|
|
| 420 |
assert persisted[-1]["sandbox_status"] == "active"
|
| 421 |
|
| 422 |
|
| 423 |
+
def test_cancelled_sandbox_creation_logs_delete_through_tool_log(monkeypatch):
|
| 424 |
+
deleted: list[str] = []
|
| 425 |
+
|
| 426 |
+
class FakeSession:
|
| 427 |
+
def __init__(self):
|
| 428 |
+
self.hf_token = "hf-token"
|
| 429 |
+
self.sandbox = None
|
| 430 |
+
self.event_queue = asyncio.Queue()
|
| 431 |
+
self._cancelled = asyncio.Event()
|
| 432 |
+
|
| 433 |
+
async def send_event(self, event):
|
| 434 |
+
await self.event_queue.put(event)
|
| 435 |
+
|
| 436 |
+
def fake_create(**kwargs):
|
| 437 |
+
def delete(log=None):
|
| 438 |
+
deleted.append("alice/sandbox-12345678")
|
| 439 |
+
if log:
|
| 440 |
+
log("Deleting sandbox: alice/sandbox-12345678...")
|
| 441 |
+
log("Deleted.")
|
| 442 |
+
|
| 443 |
+
return SimpleNamespace(
|
| 444 |
+
space_id="alice/sandbox-12345678",
|
| 445 |
+
url="https://huggingface.co/spaces/alice/sandbox-12345678",
|
| 446 |
+
_owns_space=True,
|
| 447 |
+
delete=delete,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
monkeypatch.setattr(Sandbox, "create", staticmethod(fake_create))
|
| 451 |
+
|
| 452 |
+
async def run():
|
| 453 |
+
session = FakeSession()
|
| 454 |
+
cancel_event = threading.Event()
|
| 455 |
+
cancel_event.set()
|
| 456 |
+
|
| 457 |
+
sb, error = await sandbox_tool._create_sandbox_locked(
|
| 458 |
+
session,
|
| 459 |
+
api=SimpleNamespace(),
|
| 460 |
+
owner="alice",
|
| 461 |
+
hardware="cpu-basic",
|
| 462 |
+
cancel_event=cancel_event,
|
| 463 |
+
)
|
| 464 |
+
await asyncio.sleep(0)
|
| 465 |
+
events = []
|
| 466 |
+
while not session.event_queue.empty():
|
| 467 |
+
events.append(await session.event_queue.get())
|
| 468 |
+
return sb, error, events
|
| 469 |
+
|
| 470 |
+
sb, error, events = asyncio.run(run())
|
| 471 |
+
|
| 472 |
+
assert sb is None
|
| 473 |
+
assert error == "Sandbox creation cancelled by user."
|
| 474 |
+
assert deleted == ["alice/sandbox-12345678"]
|
| 475 |
+
assert [
|
| 476 |
+
event.data
|
| 477 |
+
for event in events
|
| 478 |
+
if event.event_type == "tool_log"
|
| 479 |
+
and event.data
|
| 480 |
+
and event.data.get("log")
|
| 481 |
+
in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
|
| 482 |
+
] == [
|
| 483 |
+
{"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
|
| 484 |
+
{"tool": "sandbox", "log": "Deleted."},
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
|
| 488 |
def test_sandbox_creation_is_serialized_per_owner(monkeypatch):
|
| 489 |
active_creates = 0
|
| 490 |
max_active_creates = 0
|
|
|
|
| 604 |
space_id="alice/sandbox-cpu",
|
| 605 |
url="https://huggingface.co/spaces/alice/sandbox-cpu",
|
| 606 |
_owns_space=True,
|
| 607 |
+
delete=lambda log=None: deleted.append("alice/sandbox-cpu"),
|
| 608 |
)
|
| 609 |
self.sandbox_hardware = "cpu-basic"
|
| 610 |
self.sandbox_preload_task = None
|
|
|
|
| 649 |
|
| 650 |
def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
| 651 |
deleted: list[str] = []
|
| 652 |
+
destroyed: list[str] = []
|
| 653 |
persisted: list[dict] = []
|
| 654 |
|
| 655 |
+
async def fake_record_sandbox_destroy(session, sandbox, *args, **kwargs):
|
| 656 |
+
destroyed.append(sandbox.space_id)
|
| 657 |
|
| 658 |
monkeypatch.setattr(
|
| 659 |
telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy
|
|
|
|
| 661 |
|
| 662 |
async def run():
|
| 663 |
cancel_event = threading.Event()
|
| 664 |
+
event_queue = asyncio.Queue()
|
| 665 |
|
| 666 |
async def preload():
|
| 667 |
await asyncio.sleep(0)
|
| 668 |
|
| 669 |
+
def delete(log=None):
|
| 670 |
+
deleted.append("alice/sandbox-12345678")
|
| 671 |
+
if log:
|
| 672 |
+
log("Deleting sandbox: alice/sandbox-12345678...")
|
| 673 |
+
log("Deleted.")
|
| 674 |
+
|
| 675 |
session = SimpleNamespace(
|
| 676 |
session_id="s1",
|
| 677 |
sandbox=SimpleNamespace(
|
| 678 |
space_id="alice/sandbox-12345678",
|
| 679 |
_owns_space=True,
|
| 680 |
+
delete=delete,
|
| 681 |
),
|
| 682 |
sandbox_hardware="cpu-basic",
|
| 683 |
sandbox_preload_task=asyncio.create_task(preload()),
|
| 684 |
sandbox_preload_cancel_event=cancel_event,
|
| 685 |
+
event_queue=event_queue,
|
| 686 |
persistence_store=SimpleNamespace(
|
| 687 |
update_session_fields=lambda session_id, **fields: _record_metadata(
|
| 688 |
session_id, fields
|
|
|
|
| 691 |
)
|
| 692 |
|
| 693 |
await sandbox_tool.teardown_session_sandbox(session)
|
| 694 |
+
await asyncio.sleep(0)
|
| 695 |
+
events = []
|
| 696 |
+
while not event_queue.empty():
|
| 697 |
+
events.append(await event_queue.get())
|
| 698 |
+
return session, cancel_event, events
|
| 699 |
|
| 700 |
async def _record_metadata(session_id, fields):
|
| 701 |
persisted.append({"session_id": session_id, **fields})
|
| 702 |
|
| 703 |
+
session, cancel_event, events = asyncio.run(run())
|
| 704 |
|
| 705 |
assert cancel_event.is_set()
|
| 706 |
assert deleted == ["alice/sandbox-12345678"]
|
| 707 |
+
assert destroyed == ["alice/sandbox-12345678"]
|
| 708 |
assert session.sandbox is None
|
| 709 |
assert session.sandbox_hardware is None
|
| 710 |
+
assert [
|
| 711 |
+
event.data
|
| 712 |
+
for event in events
|
| 713 |
+
if event.event_type == "tool_log"
|
| 714 |
+
and event.data
|
| 715 |
+
and event.data.get("log")
|
| 716 |
+
in {"Deleting sandbox: alice/sandbox-12345678...", "Deleted."}
|
| 717 |
+
] == [
|
| 718 |
+
{"tool": "sandbox", "log": "Deleting sandbox: alice/sandbox-12345678..."},
|
| 719 |
+
{"tool": "sandbox", "log": "Deleted."},
|
| 720 |
+
]
|
| 721 |
assert persisted[-1]["session_id"] == "s1"
|
| 722 |
assert persisted[-1]["sandbox_space_id"] is None
|
| 723 |
assert persisted[-1]["sandbox_status"] == "destroyed"
|
tests/unit/test_sandbox_script_resolution.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from types import SimpleNamespace
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
|
| 6 |
+
from agent.tools.sandbox_tool import resolve_sandbox_script
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class FakeSandbox:
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.read_paths = []
|
| 12 |
+
|
| 13 |
+
def read(self, path, *, limit):
|
| 14 |
+
self.read_paths.append((path, limit))
|
| 15 |
+
return SimpleNamespace(
|
| 16 |
+
success=True,
|
| 17 |
+
output="1\tprint('training')\n2\tprint('done')",
|
| 18 |
+
error="",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.mark.asyncio
|
| 23 |
+
async def test_resolve_sandbox_script_accepts_bare_python_filename():
|
| 24 |
+
sandbox = FakeSandbox()
|
| 25 |
+
|
| 26 |
+
content, error = await resolve_sandbox_script(sandbox, "train_smollm2.py")
|
| 27 |
+
|
| 28 |
+
assert error is None
|
| 29 |
+
assert content == "print('training')\nprint('done')"
|
| 30 |
+
assert sandbox.read_paths == [("train_smollm2.py", 100_000)]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.mark.asyncio
|
| 34 |
+
async def test_resolve_sandbox_script_accepts_relative_python_path():
|
| 35 |
+
sandbox = FakeSandbox()
|
| 36 |
+
|
| 37 |
+
content, error = await resolve_sandbox_script(sandbox, "scripts/train.py")
|
| 38 |
+
|
| 39 |
+
assert error is None
|
| 40 |
+
assert content == "print('training')\nprint('done')"
|
| 41 |
+
assert sandbox.read_paths == [("scripts/train.py", 100_000)]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.mark.asyncio
|
| 45 |
+
@pytest.mark.parametrize(
|
| 46 |
+
"script",
|
| 47 |
+
[
|
| 48 |
+
"https://example.com/train.py",
|
| 49 |
+
"http://example.com/train.py",
|
| 50 |
+
"train_smollm2.py --epochs 1",
|
| 51 |
+
"print('hello')",
|
| 52 |
+
],
|
| 53 |
+
)
|
| 54 |
+
async def test_resolve_sandbox_script_ignores_non_path_scripts(script):
|
| 55 |
+
sandbox = FakeSandbox()
|
| 56 |
+
|
| 57 |
+
content, error = await resolve_sandbox_script(sandbox, script)
|
| 58 |
+
|
| 59 |
+
assert content is None
|
| 60 |
+
assert error is None
|
| 61 |
+
assert sandbox.read_paths == []
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def test_hf_jobs_script_description_mentions_bare_python_filenames():
|
| 65 |
+
script_description = HF_JOBS_TOOL_SPEC["parameters"]["properties"]["script"][
|
| 66 |
+
"description"
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
assert "bare 'train.py'" in script_description
|
| 70 |
+
assert "smoke-test in a GPU sandbox before submission" in script_description
|
tests/unit/test_session_manager_persistence.py
CHANGED
|
@@ -207,7 +207,7 @@ async def test_close_cancels_preload_and_deletes_owned_sandbox(monkeypatch):
|
|
| 207 |
session.sandbox = SimpleNamespace(
|
| 208 |
space_id="owner/sandbox-12345678",
|
| 209 |
_owns_space=True,
|
| 210 |
-
delete=lambda: deleted.append("owner/sandbox-12345678"),
|
| 211 |
)
|
| 212 |
session.sandbox_hardware = "cpu-basic"
|
| 213 |
session.sandbox_preload_cancel_event = preload_cancel_event
|
|
|
|
| 207 |
session.sandbox = SimpleNamespace(
|
| 208 |
space_id="owner/sandbox-12345678",
|
| 209 |
_owns_space=True,
|
| 210 |
+
delete=lambda log=None: deleted.append("owner/sandbox-12345678"),
|
| 211 |
)
|
| 212 |
session.sandbox_hardware = "cpu-basic"
|
| 213 |
session.sandbox_preload_cancel_event = preload_cancel_event
|
tests/unit/test_trackio_space_ids.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC
|
| 5 |
+
from agent.tools.sandbox_tool import SANDBOX_CREATE_TOOL_SPEC
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def test_trackio_space_examples_use_hyphenated_ml_intern_prefix():
|
| 9 |
+
prompt = Path("agent/prompts/system_prompt_v3.yaml").read_text()
|
| 10 |
+
tool_specs = json.dumps([HF_JOBS_TOOL_SPEC, SANDBOX_CREATE_TOOL_SPEC])
|
| 11 |
+
legacy_prefix = "ml" + "intern"
|
| 12 |
+
|
| 13 |
+
assert "<username>/ml-intern-<8-char-id>" in prompt
|
| 14 |
+
assert "<username>/ml-intern-<8char>" in tool_specs
|
| 15 |
+
assert legacy_prefix not in prompt
|
| 16 |
+
assert legacy_prefix not in tool_specs
|
uv.lock
CHANGED
|
@@ -1788,6 +1788,7 @@ dependencies = [
|
|
| 1788 |
{ name = "pydantic" },
|
| 1789 |
{ name = "pymongo" },
|
| 1790 |
{ name = "python-dotenv" },
|
|
|
|
| 1791 |
{ name = "requests" },
|
| 1792 |
{ name = "rich" },
|
| 1793 |
{ name = "thefuzz" },
|
|
@@ -1840,6 +1841,7 @@ requires-dist = [
|
|
| 1840 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1841 |
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
| 1842 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
|
|
|
| 1843 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1844 |
{ name = "rich", specifier = ">=13.0.0" },
|
| 1845 |
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },
|
|
|
|
| 1788 |
{ name = "pydantic" },
|
| 1789 |
{ name = "pymongo" },
|
| 1790 |
{ name = "python-dotenv" },
|
| 1791 |
+
{ name = "python-multipart" },
|
| 1792 |
{ name = "requests" },
|
| 1793 |
{ name = "rich" },
|
| 1794 |
{ name = "thefuzz" },
|
|
|
|
| 1841 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1842 |
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
| 1843 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
| 1844 |
+
{ name = "python-multipart", specifier = ">=0.0.20" },
|
| 1845 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1846 |
{ name = "rich", specifier = ">=13.0.0" },
|
| 1847 |
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" },
|