Deploy 2026-05-05
Browse files
Co-authored-by: OpenAI Codex <codex@openai.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- .github/workflows/ci.yml +63 -0
- AGENTS.md +5 -0
- agent/config.py +11 -7
- agent/context_manager/manager.py +22 -10
- agent/core/agent_loop.py +285 -129
- agent/core/cost_estimation.py +6 -2
- agent/core/doom_loop.py +7 -3
- agent/core/effort_probe.py +42 -16
- agent/core/hf_router_catalog.py +3 -1
- agent/core/hub_artifacts.py +765 -0
- agent/core/llm_params.py +10 -3
- agent/core/model_switcher.py +13 -7
- agent/core/prompt_caching.py +12 -6
- agent/core/session.py +4 -3
- agent/core/session_persistence.py +12 -4
- agent/core/session_uploader.py +2 -3
- agent/core/telemetry.py +113 -77
- agent/core/tools.py +15 -5
- agent/main.py +85 -30
- agent/messaging/base.py +5 -1
- agent/messaging/gateway.py +9 -3
- agent/messaging/models.py +2 -8
- agent/messaging/slack.py +1 -3
- agent/sft/tagger.py +47 -18
- agent/tools/dataset_tools.py +3 -1
- agent/tools/edit_utils.py +26 -21
- agent/tools/hf_repo_files_tool.py +56 -17
- agent/tools/hf_repo_git_tool.py +140 -37
- agent/tools/jobs_tool.py +66 -18
- agent/tools/local_tools.py +22 -7
- agent/tools/papers_tool.py +65 -20
- agent/tools/research_tool.py +61 -38
- agent/tools/sandbox_client.py +13 -6
- agent/tools/sandbox_tool.py +23 -6
- agent/tools/web_search_tool.py +4 -1
- agent/utils/braille.py +5 -4
- agent/utils/crt_boot.py +5 -2
- agent/utils/particle_logo.py +3 -1
- agent/utils/terminal_display.py +61 -28
- backend/dependencies.py +3 -1
- backend/kpis_scheduler.py +4 -2
- backend/main.py +8 -6
- backend/models.py +3 -1
- backend/routes/agent.py +54 -15
- backend/routes/auth.py +0 -1
- backend/session_manager.py +73 -25
- backend/user_quotas.py +5 -1
- pyproject.toml +1 -0
- scripts/build_kpis.py +148 -47
- scripts/build_sft.py +21 -5
.github/workflows/ci.yml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CI
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
push:
|
| 6 |
+
branches: [main]
|
| 7 |
+
|
| 8 |
+
permissions:
|
| 9 |
+
contents: read
|
| 10 |
+
|
| 11 |
+
concurrency:
|
| 12 |
+
group: ci-${{ github.workflow }}-${{ github.ref }}
|
| 13 |
+
cancel-in-progress: true
|
| 14 |
+
|
| 15 |
+
jobs:
|
| 16 |
+
ruff:
|
| 17 |
+
name: Ruff
|
| 18 |
+
runs-on: ubuntu-latest
|
| 19 |
+
steps:
|
| 20 |
+
- uses: actions/checkout@v4
|
| 21 |
+
|
| 22 |
+
- name: Install uv
|
| 23 |
+
uses: astral-sh/setup-uv@v5
|
| 24 |
+
with:
|
| 25 |
+
enable-cache: true
|
| 26 |
+
cache-dependency-glob: uv.lock
|
| 27 |
+
|
| 28 |
+
- name: Set up Python
|
| 29 |
+
uses: actions/setup-python@v5
|
| 30 |
+
with:
|
| 31 |
+
python-version: "3.12"
|
| 32 |
+
|
| 33 |
+
- name: Install dependencies
|
| 34 |
+
run: uv sync --locked --extra dev
|
| 35 |
+
|
| 36 |
+
- name: Run Ruff
|
| 37 |
+
run: uv run ruff check .
|
| 38 |
+
|
| 39 |
+
- name: Check formatting
|
| 40 |
+
run: uv run ruff format --check .
|
| 41 |
+
|
| 42 |
+
tests:
|
| 43 |
+
name: Tests
|
| 44 |
+
runs-on: ubuntu-latest
|
| 45 |
+
steps:
|
| 46 |
+
- uses: actions/checkout@v4
|
| 47 |
+
|
| 48 |
+
- name: Install uv
|
| 49 |
+
uses: astral-sh/setup-uv@v5
|
| 50 |
+
with:
|
| 51 |
+
enable-cache: true
|
| 52 |
+
cache-dependency-glob: uv.lock
|
| 53 |
+
|
| 54 |
+
- name: Set up Python
|
| 55 |
+
uses: actions/setup-python@v5
|
| 56 |
+
with:
|
| 57 |
+
python-version: "3.12"
|
| 58 |
+
|
| 59 |
+
- name: Install dependencies
|
| 60 |
+
run: uv sync --locked --extra dev
|
| 61 |
+
|
| 62 |
+
- name: Run tests
|
| 63 |
+
run: uv run pytest
|
AGENTS.md
CHANGED
|
@@ -15,6 +15,11 @@ Notes:
|
|
| 15 |
- Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version.
|
| 16 |
- Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher.
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
## GitHub CLI
|
| 19 |
|
| 20 |
- For multiline PR descriptions, prefer `gh pr edit <number> --body-file <file>` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly.
|
|
|
|
| 15 |
- Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version.
|
| 16 |
- Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher.
|
| 17 |
|
| 18 |
+
## Development Checks
|
| 19 |
+
|
| 20 |
+
- Before every commit, run `uv run ruff check .` and `uv run ruff format --check .`.
|
| 21 |
+
- If formatting fails, run `uv run ruff format .`, then re-run the Ruff checks before committing.
|
| 22 |
+
|
| 23 |
## GitHub CLI
|
| 24 |
|
| 25 |
- For multiline PR descriptions, prefer `gh pr edit <number> --body-file <file>` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly.
|
agent/config.py
CHANGED
|
@@ -5,20 +5,20 @@ from pathlib import Path
|
|
| 5 |
from typing import Any, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
-
|
| 9 |
-
from agent.messaging.models import MessagingConfig
|
| 10 |
-
|
| 11 |
-
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 12 |
-
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 13 |
from fastmcp.mcp_config import (
|
| 14 |
RemoteMCPServer,
|
| 15 |
StdioMCPServer,
|
| 16 |
)
|
| 17 |
from pydantic import BaseModel
|
| 18 |
|
|
|
|
|
|
|
| 19 |
# These two are the canonical server config types for MCP servers.
|
| 20 |
MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class Config(BaseModel):
|
| 24 |
"""Configuration manager"""
|
|
@@ -60,12 +60,16 @@ class Config(BaseModel):
|
|
| 60 |
|
| 61 |
|
| 62 |
USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
|
| 63 |
-
DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
|
|
|
|
|
|
|
| 64 |
SLACK_DEFAULT_DESTINATION = "slack.default"
|
| 65 |
SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
|
| 66 |
|
| 67 |
|
| 68 |
-
def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
|
|
|
|
|
|
| 69 |
merged = dict(base)
|
| 70 |
for key, value in override.items():
|
| 71 |
current = merged.get(key)
|
|
|
|
| 5 |
from typing import Any, Union
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from fastmcp.mcp_config import (
|
| 9 |
RemoteMCPServer,
|
| 10 |
StdioMCPServer,
|
| 11 |
)
|
| 12 |
from pydantic import BaseModel
|
| 13 |
|
| 14 |
+
from agent.messaging.models import MessagingConfig
|
| 15 |
+
|
| 16 |
# These two are the canonical server config types for MCP servers.
|
| 17 |
MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]
|
| 18 |
|
| 19 |
+
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 20 |
+
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 21 |
+
|
| 22 |
|
| 23 |
class Config(BaseModel):
|
| 24 |
"""Configuration manager"""
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
|
| 63 |
+
DEFAULT_USER_CONFIG_PATH = (
|
| 64 |
+
Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
|
| 65 |
+
)
|
| 66 |
SLACK_DEFAULT_DESTINATION = "slack.default"
|
| 67 |
SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
|
| 68 |
|
| 69 |
|
| 70 |
+
def _deep_merge_config(
|
| 71 |
+
base: dict[str, Any], override: dict[str, Any]
|
| 72 |
+
) -> dict[str, Any]:
|
| 73 |
merged = dict(base)
|
| 74 |
for key, value in override.items():
|
| 75 |
current = merged.get(key)
|
agent/context_manager/manager.py
CHANGED
|
@@ -3,7 +3,6 @@ Context management for conversation history
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import logging
|
| 6 |
-
import os
|
| 7 |
import time
|
| 8 |
import zoneinfo
|
| 9 |
from datetime import datetime
|
|
@@ -96,6 +95,7 @@ class CompactionFailedError(Exception):
|
|
| 96 |
burns Bedrock budget for free (~$3 per re-attempt on Opus).
|
| 97 |
"""
|
| 98 |
|
|
|
|
| 99 |
# Used when seeding a brand-new session from prior browser-cached messages.
|
| 100 |
# Here we're writing a note to *ourselves* — so preserve the tool-call trail,
|
| 101 |
# files produced, and planned next steps in first person. Optimized for
|
|
@@ -155,12 +155,15 @@ async def summarize_messages(
|
|
| 155 |
)
|
| 156 |
if session is not None:
|
| 157 |
from agent.core import telemetry
|
|
|
|
| 158 |
await telemetry.record_llm_call(
|
| 159 |
session,
|
| 160 |
model=model_name,
|
| 161 |
response=response,
|
| 162 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 163 |
-
finish_reason=response.choices[0].finish_reason
|
|
|
|
|
|
|
| 164 |
kind=kind,
|
| 165 |
)
|
| 166 |
summary = response.choices[0].message.content or ""
|
|
@@ -233,6 +236,7 @@ class ContextManager:
|
|
| 233 |
# CLI-specific context for local mode
|
| 234 |
if local_mode:
|
| 235 |
import os
|
|
|
|
| 236 |
cwd = os.getcwd()
|
| 237 |
local_context = (
|
| 238 |
f"\n\n# CLI / Local mode\n\n"
|
|
@@ -305,7 +309,9 @@ class ContextManager:
|
|
| 305 |
i = 0
|
| 306 |
while i < len(self.items):
|
| 307 |
msg = self.items[i]
|
| 308 |
-
if getattr(msg, "role", None) != "assistant" or not getattr(msg, "tool_calls", None):
|
|
|
|
|
|
|
| 309 |
i += 1
|
| 310 |
continue
|
| 311 |
|
|
@@ -316,7 +322,9 @@ class ContextManager:
|
|
| 316 |
# before the next non-tool message to satisfy provider ordering.
|
| 317 |
j = i + 1
|
| 318 |
immediate_ids: set[str | None] = set()
|
| 319 |
-
while j < len(self.items) and getattr(self.items[j], "role", None) == "tool":
|
|
|
|
|
|
|
| 320 |
immediate_ids.add(getattr(self.items[j], "tool_call_id", None))
|
| 321 |
j += 1
|
| 322 |
|
|
@@ -386,7 +394,9 @@ class ContextManager:
|
|
| 386 |
|
| 387 |
@property
|
| 388 |
def needs_compaction(self) -> bool:
|
| 389 |
-
return self.running_context_usage > self.compaction_threshold and bool(self.items)
|
|
|
|
|
|
|
| 390 |
|
| 391 |
def _truncate_oversized(
|
| 392 |
self, messages: list[Message], model_name: str
|
|
@@ -425,7 +435,9 @@ class ContextManager:
|
|
| 425 |
)
|
| 426 |
logger.warning(
|
| 427 |
"Truncating %s message: %d -> %d tokens for compaction",
|
| 428 |
-
msg.role,
|
|
|
|
|
|
|
| 429 |
)
|
| 430 |
# Preserve all known assistant-side fields (tool_calls, thinking_blocks,
|
| 431 |
# reasoning_content, provider_specific_fields) even when content is
|
|
@@ -459,9 +471,9 @@ class ContextManager:
|
|
| 459 |
except Exception as e:
|
| 460 |
logger.warning("token_counter failed (%s); rough estimate", e)
|
| 461 |
# Rough fallback: 4 chars per token.
|
| 462 |
-
self.running_context_usage = sum(
|
| 463 |
-
len(getattr(m, "content", "") or "") for m in self.items
|
| 464 |
-
) // 4
|
| 465 |
|
| 466 |
async def compact(
|
| 467 |
self,
|
|
@@ -516,7 +528,7 @@ class ContextManager:
|
|
| 516 |
idx = first_user_idx + 1
|
| 517 |
|
| 518 |
recent_messages = self.items[idx:]
|
| 519 |
-
messages_to_summarize = self.items[first_user_idx + 1:idx]
|
| 520 |
|
| 521 |
# Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in
|
| 522 |
# the parts we PRESERVE through compaction (first_user + recent_tail).
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import logging
|
|
|
|
| 6 |
import time
|
| 7 |
import zoneinfo
|
| 8 |
from datetime import datetime
|
|
|
|
| 95 |
burns Bedrock budget for free (~$3 per re-attempt on Opus).
|
| 96 |
"""
|
| 97 |
|
| 98 |
+
|
| 99 |
# Used when seeding a brand-new session from prior browser-cached messages.
|
| 100 |
# Here we're writing a note to *ourselves* — so preserve the tool-call trail,
|
| 101 |
# files produced, and planned next steps in first person. Optimized for
|
|
|
|
| 155 |
)
|
| 156 |
if session is not None:
|
| 157 |
from agent.core import telemetry
|
| 158 |
+
|
| 159 |
await telemetry.record_llm_call(
|
| 160 |
session,
|
| 161 |
model=model_name,
|
| 162 |
response=response,
|
| 163 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 164 |
+
finish_reason=response.choices[0].finish_reason
|
| 165 |
+
if response.choices
|
| 166 |
+
else None,
|
| 167 |
kind=kind,
|
| 168 |
)
|
| 169 |
summary = response.choices[0].message.content or ""
|
|
|
|
| 236 |
# CLI-specific context for local mode
|
| 237 |
if local_mode:
|
| 238 |
import os
|
| 239 |
+
|
| 240 |
cwd = os.getcwd()
|
| 241 |
local_context = (
|
| 242 |
f"\n\n# CLI / Local mode\n\n"
|
|
|
|
| 309 |
i = 0
|
| 310 |
while i < len(self.items):
|
| 311 |
msg = self.items[i]
|
| 312 |
+
if getattr(msg, "role", None) != "assistant" or not getattr(
|
| 313 |
+
msg, "tool_calls", None
|
| 314 |
+
):
|
| 315 |
i += 1
|
| 316 |
continue
|
| 317 |
|
|
|
|
| 322 |
# before the next non-tool message to satisfy provider ordering.
|
| 323 |
j = i + 1
|
| 324 |
immediate_ids: set[str | None] = set()
|
| 325 |
+
while (
|
| 326 |
+
j < len(self.items) and getattr(self.items[j], "role", None) == "tool"
|
| 327 |
+
):
|
| 328 |
immediate_ids.add(getattr(self.items[j], "tool_call_id", None))
|
| 329 |
j += 1
|
| 330 |
|
|
|
|
| 394 |
|
| 395 |
@property
|
| 396 |
def needs_compaction(self) -> bool:
|
| 397 |
+
return self.running_context_usage > self.compaction_threshold and bool(
|
| 398 |
+
self.items
|
| 399 |
+
)
|
| 400 |
|
| 401 |
def _truncate_oversized(
|
| 402 |
self, messages: list[Message], model_name: str
|
|
|
|
| 435 |
)
|
| 436 |
logger.warning(
|
| 437 |
"Truncating %s message: %d -> %d tokens for compaction",
|
| 438 |
+
msg.role,
|
| 439 |
+
n,
|
| 440 |
+
len(placeholder) // 4,
|
| 441 |
)
|
| 442 |
# Preserve all known assistant-side fields (tool_calls, thinking_blocks,
|
| 443 |
# reasoning_content, provider_specific_fields) even when content is
|
|
|
|
| 471 |
except Exception as e:
|
| 472 |
logger.warning("token_counter failed (%s); rough estimate", e)
|
| 473 |
# Rough fallback: 4 chars per token.
|
| 474 |
+
self.running_context_usage = (
|
| 475 |
+
sum(len(getattr(m, "content", "") or "") for m in self.items) // 4
|
| 476 |
+
)
|
| 477 |
|
| 478 |
async def compact(
|
| 479 |
self,
|
|
|
|
| 528 |
idx = first_user_idx + 1
|
| 529 |
|
| 530 |
recent_messages = self.items[idx:]
|
| 531 |
+
messages_to_summarize = self.items[first_user_idx + 1 : idx]
|
| 532 |
|
| 533 |
# Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in
|
| 534 |
# the parts we PRESERVE through compaction (first_user + recent_tail).
|
agent/core/agent_loop.py
CHANGED
|
@@ -5,7 +5,6 @@ Main agent implementation with integrated tool system and MCP support
|
|
| 5 |
import asyncio
|
| 6 |
import json
|
| 7 |
import logging
|
| 8 |
-
import os
|
| 9 |
import time
|
| 10 |
from dataclasses import dataclass, field
|
| 11 |
from typing import Any
|
|
@@ -27,6 +26,7 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
|
|
| 27 |
from agent.messaging.gateway import NotificationGateway
|
| 28 |
from agent.core import telemetry
|
| 29 |
from agent.core.doom_loop import check_for_doom_loop
|
|
|
|
| 30 |
from agent.core.llm_params import _resolve_llm_params
|
| 31 |
from agent.core.prompt_caching import with_prompt_caching
|
| 32 |
from agent.core.session import Event, OpType, Session
|
|
@@ -54,11 +54,12 @@ def _malformed_tool_name(message: Message) -> str | None:
|
|
| 54 |
end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
|
| 55 |
if end == -1:
|
| 56 |
return None
|
| 57 |
-
return content[len(_MALFORMED_TOOL_PREFIX):end]
|
| 58 |
|
| 59 |
|
| 60 |
def _detect_repeated_malformed(
|
| 61 |
-
items: list[Message],
|
|
|
|
| 62 |
) -> str | None:
|
| 63 |
"""Return the repeated malformed tool name if the tail contains a streak.
|
| 64 |
|
|
@@ -118,6 +119,7 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
|
|
| 118 |
|
| 119 |
_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
|
| 120 |
|
|
|
|
| 121 |
@dataclass(frozen=True)
|
| 122 |
class ApprovalDecision:
|
| 123 |
requires_approval: bool
|
|
@@ -142,7 +144,9 @@ def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool:
|
|
| 142 |
|
| 143 |
|
| 144 |
def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
|
| 145 |
-
return tool_name == "sandbox_create" or _is_immediate_hf_job_run(
|
|
|
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
def _base_needs_approval(
|
|
@@ -231,7 +235,9 @@ def _session_auto_approval_enabled(session: Session | None) -> bool:
|
|
| 231 |
|
| 232 |
|
| 233 |
def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
|
| 234 |
-
return bool(
|
|
|
|
|
|
|
| 235 |
|
| 236 |
|
| 237 |
def _remaining_budget_after_reservations(
|
|
@@ -251,7 +257,10 @@ def _budget_block_reason(
|
|
| 251 |
) -> str | None:
|
| 252 |
if estimate.estimated_cost_usd is None:
|
| 253 |
return estimate.block_reason or "Could not estimate the cost safely."
|
| 254 |
-
if
|
|
|
|
|
|
|
|
|
|
| 255 |
return (
|
| 256 |
f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds "
|
| 257 |
f"remaining YOLO cap ${remaining_cap_usd:.2f}."
|
|
@@ -409,15 +418,25 @@ def _is_transient_error(error: Exception) -> bool:
|
|
| 409 |
"""Return True for errors that are likely transient and worth retrying."""
|
| 410 |
err_str = str(error).lower()
|
| 411 |
transient_patterns = [
|
| 412 |
-
"timeout",
|
| 413 |
-
"
|
| 414 |
-
"
|
| 415 |
-
"
|
| 416 |
-
"
|
| 417 |
-
"
|
| 418 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
]
|
| 420 |
-
return _is_rate_limit_error(error) or any(
|
|
|
|
|
|
|
| 421 |
|
| 422 |
|
| 423 |
def _is_effort_config_error(error: Exception) -> bool:
|
|
@@ -429,11 +448,14 @@ def _is_effort_config_error(error: Exception) -> bool:
|
|
| 429 |
doesn't work for the current model. We heal the cache and retry once.
|
| 430 |
"""
|
| 431 |
from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported
|
|
|
|
| 432 |
return _is_thinking_unsupported(error) or _is_invalid_effort(error)
|
| 433 |
|
| 434 |
|
| 435 |
async def _heal_effort_and_rebuild_params(
|
| 436 |
-
session: Session,
|
|
|
|
|
|
|
| 437 |
) -> dict:
|
| 438 |
"""Update the session's effort cache based on ``error`` and return new
|
| 439 |
llm_params. Called only when ``_is_effort_config_error(error)`` is True.
|
|
@@ -444,7 +466,11 @@ async def _heal_effort_and_rebuild_params(
|
|
| 444 |
• invalid-effort → re-run the full cascade probe; the result lands
|
| 445 |
in the cache
|
| 446 |
"""
|
| 447 |
-
from agent.core.effort_probe import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
model = session.config.model_name
|
| 450 |
if _is_thinking_unsupported(error):
|
|
@@ -453,12 +479,16 @@ async def _heal_effort_and_rebuild_params(
|
|
| 453 |
else:
|
| 454 |
try:
|
| 455 |
outcome = await probe_effort(
|
| 456 |
-
model,
|
|
|
|
|
|
|
| 457 |
session=session,
|
| 458 |
)
|
| 459 |
session.model_effective_effort[model] = outcome.effective_effort
|
| 460 |
logger.info(
|
| 461 |
-
"healed: %s effort cascade → %s",
|
|
|
|
|
|
|
| 462 |
)
|
| 463 |
except ProbeInconclusive:
|
| 464 |
# Transient during healing — strip thinking for safety, next
|
|
@@ -477,7 +507,11 @@ def _friendly_error_message(error: Exception) -> str | None:
|
|
| 477 |
"""Return a user-friendly message for known error types, or None to fall back to traceback."""
|
| 478 |
err_str = str(error).lower()
|
| 479 |
|
| 480 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
return (
|
| 482 |
"Authentication failed — your API key is missing or invalid.\n\n"
|
| 483 |
"To fix this, set the API key for your model provider:\n"
|
|
@@ -503,8 +537,7 @@ def _friendly_error_message(error: Exception) -> str | None:
|
|
| 503 |
)
|
| 504 |
|
| 505 |
if "model_not_found" in err_str or (
|
| 506 |
-
"model" in err_str
|
| 507 |
-
and ("not found" in err_str or "does not exist" in err_str)
|
| 508 |
):
|
| 509 |
return (
|
| 510 |
"Model not found. Use '/model' to list suggestions, or paste an "
|
|
@@ -530,7 +563,10 @@ async def _compact_and_notify(session: Session) -> None:
|
|
| 530 |
old_usage = cm.running_context_usage
|
| 531 |
logger.debug(
|
| 532 |
"Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s",
|
| 533 |
-
old_usage,
|
|
|
|
|
|
|
|
|
|
| 534 |
)
|
| 535 |
try:
|
| 536 |
await cm.compact(
|
|
@@ -542,24 +578,27 @@ async def _compact_and_notify(session: Session) -> None:
|
|
| 542 |
except CompactionFailedError as e:
|
| 543 |
logger.error(
|
| 544 |
"Compaction failed for session %s: %s — terminating session",
|
| 545 |
-
session.session_id,
|
|
|
|
| 546 |
)
|
| 547 |
# Persist the failure event so the dataset has a record of WHY this
|
| 548 |
# session ended (and the cost it incurred up to that point) even if
|
| 549 |
# save_and_upload_detached has issues downstream.
|
| 550 |
-
await session.send_event(
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
"
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
|
|
|
|
|
|
| 563 |
# Stop the agent loop; the finally in _run_session will fire
|
| 564 |
# cleanup_sandbox + save_trajectory so the dataset captures
|
| 565 |
# everything that did happen.
|
|
@@ -570,7 +609,10 @@ async def _compact_and_notify(session: Session) -> None:
|
|
| 570 |
if new_usage != old_usage:
|
| 571 |
logger.warning(
|
| 572 |
"Context compacted: %d -> %d tokens (max=%d, %d messages)",
|
| 573 |
-
old_usage,
|
|
|
|
|
|
|
|
|
|
| 574 |
)
|
| 575 |
await session.send_event(
|
| 576 |
Event(
|
|
@@ -609,6 +651,7 @@ async def _cleanup_on_cancel(session: Session) -> None:
|
|
| 609 |
@dataclass
|
| 610 |
class LLMResult:
|
| 611 |
"""Result from an LLM call (streaming or non-streaming)."""
|
|
|
|
| 612 |
content: str | None
|
| 613 |
tool_calls_acc: dict[int, dict]
|
| 614 |
token_count: int
|
|
@@ -728,16 +771,18 @@ async def _maybe_heal_invalid_thinking_signature(
|
|
| 728 |
if not stripped:
|
| 729 |
return False
|
| 730 |
|
| 731 |
-
await session.send_event(
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
"
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
|
|
|
|
|
|
|
| 741 |
return True
|
| 742 |
|
| 743 |
|
|
@@ -762,7 +807,9 @@ def _assistant_message_from_result(
|
|
| 762 |
return Message(**kwargs)
|
| 763 |
|
| 764 |
|
| 765 |
-
async def _call_llm_streaming(
|
|
|
|
|
|
|
| 766 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 767 |
response = None
|
| 768 |
_healed_effort = False # one-shot safety net per call
|
|
@@ -788,11 +835,18 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 788 |
raise ContextWindowExceededError(str(e)) from e
|
| 789 |
if not _healed_effort and _is_effort_config_error(e):
|
| 790 |
_healed_effort = True
|
| 791 |
-
llm_params = await _heal_effort_and_rebuild_params(
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
continue
|
| 797 |
if await _maybe_heal_invalid_thinking_signature(
|
| 798 |
session,
|
|
@@ -806,12 +860,20 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 806 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 807 |
logger.warning(
|
| 808 |
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
|
| 809 |
-
_llm_attempt + 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 810 |
)
|
| 811 |
-
await session.send_event(Event(
|
| 812 |
-
event_type="tool_log",
|
| 813 |
-
data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
|
| 814 |
-
))
|
| 815 |
await asyncio.sleep(_delay)
|
| 816 |
continue
|
| 817 |
raise
|
|
@@ -852,16 +914,21 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 852 |
idx = tc_delta.index
|
| 853 |
if idx not in tool_calls_acc:
|
| 854 |
tool_calls_acc[idx] = {
|
| 855 |
-
"id": "",
|
|
|
|
| 856 |
"function": {"name": "", "arguments": ""},
|
| 857 |
}
|
| 858 |
if tc_delta.id:
|
| 859 |
tool_calls_acc[idx]["id"] = tc_delta.id
|
| 860 |
if tc_delta.function:
|
| 861 |
if tc_delta.function.name:
|
| 862 |
-
tool_calls_acc[idx]["function"]["name"] +=
|
|
|
|
|
|
|
| 863 |
if tc_delta.function.arguments:
|
| 864 |
-
tool_calls_acc[idx]["function"]["arguments"] +=
|
|
|
|
|
|
|
| 865 |
|
| 866 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 867 |
token_count = chunk.usage.total_tokens
|
|
@@ -881,7 +948,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 881 |
rebuilt = stream_chunk_builder(chunks, messages=messages)
|
| 882 |
if rebuilt and getattr(rebuilt, "choices", None):
|
| 883 |
rebuilt_msg = rebuilt.choices[0].message
|
| 884 |
-
thinking_blocks, reasoning_content = _extract_thinking_state(
|
|
|
|
|
|
|
| 885 |
except Exception:
|
| 886 |
logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
|
| 887 |
|
|
@@ -896,7 +965,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 896 |
)
|
| 897 |
|
| 898 |
|
| 899 |
-
async def _call_llm_non_streaming(
|
|
|
|
|
|
|
| 900 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 901 |
response = None
|
| 902 |
_healed_effort = False
|
|
@@ -921,11 +992,18 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 921 |
raise ContextWindowExceededError(str(e)) from e
|
| 922 |
if not _healed_effort and _is_effort_config_error(e):
|
| 923 |
_healed_effort = True
|
| 924 |
-
llm_params = await _heal_effort_and_rebuild_params(
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 929 |
continue
|
| 930 |
if await _maybe_heal_invalid_thinking_signature(
|
| 931 |
session,
|
|
@@ -939,12 +1017,20 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 939 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 940 |
logger.warning(
|
| 941 |
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
|
| 942 |
-
_llm_attempt + 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 943 |
)
|
| 944 |
-
await session.send_event(Event(
|
| 945 |
-
event_type="tool_log",
|
| 946 |
-
data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
|
| 947 |
-
))
|
| 948 |
await asyncio.sleep(_delay)
|
| 949 |
continue
|
| 950 |
raise
|
|
@@ -1037,7 +1123,8 @@ class Handlers:
|
|
| 1037 |
|
| 1038 |
@staticmethod
|
| 1039 |
async def run_agent(
|
| 1040 |
-
session: Session,
|
|
|
|
| 1041 |
) -> str | None:
|
| 1042 |
"""
|
| 1043 |
Handle user input (like user_input_or_turn in codex.rs:1291)
|
|
@@ -1124,12 +1211,18 @@ class Handlers:
|
|
| 1124 |
llm_params = _resolve_llm_params(
|
| 1125 |
session.config.model_name,
|
| 1126 |
session.hf_token,
|
| 1127 |
-
reasoning_effort=session.effective_effort_for(
|
|
|
|
|
|
|
| 1128 |
)
|
| 1129 |
if session.stream:
|
| 1130 |
-
llm_result = await _call_llm_streaming(
|
|
|
|
|
|
|
| 1131 |
else:
|
| 1132 |
-
llm_result = await _call_llm_non_streaming(
|
|
|
|
|
|
|
| 1133 |
|
| 1134 |
content = llm_result.content
|
| 1135 |
tool_calls_acc = llm_result.tool_calls_acc
|
|
@@ -1176,7 +1269,10 @@ class Handlers:
|
|
| 1176 |
await session.send_event(
|
| 1177 |
Event(
|
| 1178 |
event_type="tool_log",
|
| 1179 |
-
data={
|
|
|
|
|
|
|
|
|
|
| 1180 |
)
|
| 1181 |
)
|
| 1182 |
iteration += 1
|
|
@@ -1239,7 +1335,8 @@ class Handlers:
|
|
| 1239 |
except (json.JSONDecodeError, TypeError, ValueError):
|
| 1240 |
logger.warning(
|
| 1241 |
"Malformed arguments for tool_call %s (%s) — skipping",
|
| 1242 |
-
tc.id,
|
|
|
|
| 1243 |
)
|
| 1244 |
tc.function.arguments = "{}"
|
| 1245 |
bad_tools.append(tc)
|
|
@@ -1260,20 +1357,35 @@ class Handlers:
|
|
| 1260 |
f"arguments and was NOT executed. Retry with smaller content — "
|
| 1261 |
f"for 'write', split into multiple smaller writes using 'edit'."
|
| 1262 |
)
|
| 1263 |
-
session.context_manager.add_message(
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
|
| 1268 |
-
|
| 1269 |
-
|
| 1270 |
-
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
|
| 1274 |
-
|
| 1275 |
-
|
| 1276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1277 |
|
| 1278 |
# ── Cancellation check: before tool execution ──
|
| 1279 |
if session.is_cancelled:
|
|
@@ -1298,7 +1410,9 @@ class Handlers:
|
|
| 1298 |
reserved_spend_usd=reserved_auto_spend_usd,
|
| 1299 |
)
|
| 1300 |
if decision.requires_approval:
|
| 1301 |
-
approval_required_tools.append(
|
|
|
|
|
|
|
| 1302 |
else:
|
| 1303 |
non_approval_tools.append((tc, tool_name, tool_args, decision))
|
| 1304 |
if (
|
|
@@ -1321,7 +1435,14 @@ class Handlers:
|
|
| 1321 |
)
|
| 1322 |
|
| 1323 |
# 2. Send all tool_call events upfront (so frontend shows them all)
|
| 1324 |
-
for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1325 |
if args_valid:
|
| 1326 |
await session.send_event(
|
| 1327 |
Event(
|
|
@@ -1352,12 +1473,14 @@ class Handlers:
|
|
| 1352 |
)
|
| 1353 |
return (tc, name, args, out, ok)
|
| 1354 |
|
| 1355 |
-
gather_task = asyncio.ensure_future(
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
|
|
|
|
|
|
| 1361 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
| 1362 |
|
| 1363 |
done, _ = await asyncio.wait(
|
|
@@ -1374,10 +1497,16 @@ class Handlers:
|
|
| 1374 |
# Notify frontend that in-flight tools were cancelled
|
| 1375 |
for tc, name, _args, _decision, valid, _ in parsed_tools:
|
| 1376 |
if valid:
|
| 1377 |
-
await session.send_event(
|
| 1378 |
-
|
| 1379 |
-
|
| 1380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1381 |
await _cleanup_on_cancel(session)
|
| 1382 |
break
|
| 1383 |
|
|
@@ -1414,10 +1543,15 @@ class Handlers:
|
|
| 1414 |
for tc, tool_name, tool_args, decision in approval_required_tools:
|
| 1415 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 1416 |
# frontend can display & edit the actual file content.
|
| 1417 |
-
if tool_name == "hf_jobs" and isinstance(
|
|
|
|
|
|
|
| 1418 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
|
|
|
| 1419 |
sandbox = getattr(session, "sandbox", None)
|
| 1420 |
-
resolved, _ = await resolve_sandbox_script(
|
|
|
|
|
|
|
| 1421 |
if resolved:
|
| 1422 |
tool_args = {**tool_args, "script": resolved}
|
| 1423 |
|
|
@@ -1449,10 +1583,12 @@ class Handlers:
|
|
| 1449 |
"remaining_cap_usd": first.get("remaining_cap_usd"),
|
| 1450 |
}
|
| 1451 |
)
|
| 1452 |
-
await session.send_event(
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
|
|
|
|
|
|
| 1456 |
|
| 1457 |
# Store all approval-requiring tools (ToolCall objects for execution)
|
| 1458 |
session.pending_approval = {
|
|
@@ -1470,7 +1606,10 @@ class Handlers:
|
|
| 1470 |
logger.warning(
|
| 1471 |
"ContextWindowExceededError at iteration %d — forcing compaction "
|
| 1472 |
"(usage=%d, model_max_tokens=%d, messages=%d)",
|
| 1473 |
-
iteration,
|
|
|
|
|
|
|
|
|
|
| 1474 |
)
|
| 1475 |
cm.running_context_usage = cm.model_max_tokens + 1
|
| 1476 |
await _compact_and_notify(session)
|
|
@@ -1662,13 +1801,15 @@ class Handlers:
|
|
| 1662 |
|
| 1663 |
# Execute all approved tools concurrently (cancellable)
|
| 1664 |
if approved_tasks:
|
| 1665 |
-
gather_task = asyncio.ensure_future(
|
| 1666 |
-
|
| 1667 |
-
|
| 1668 |
-
|
| 1669 |
-
|
| 1670 |
-
|
| 1671 |
-
|
|
|
|
|
|
|
| 1672 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
| 1673 |
|
| 1674 |
done, _ = await asyncio.wait(
|
|
@@ -1684,10 +1825,16 @@ class Handlers:
|
|
| 1684 |
pass
|
| 1685 |
# Notify frontend that approved tools were cancelled
|
| 1686 |
for tc, tool_name, _args, _was_edited in approved_tasks:
|
| 1687 |
-
await session.send_event(
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1691 |
await _cleanup_on_cancel(session)
|
| 1692 |
await session.send_event(Event(event_type="interrupted"))
|
| 1693 |
session.increment_turn()
|
|
@@ -1839,14 +1986,20 @@ async def submission_loop(
|
|
| 1839 |
|
| 1840 |
# Create session with tool router
|
| 1841 |
session = Session(
|
| 1842 |
-
event_queue,
|
| 1843 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1844 |
notification_gateway=notification_gateway,
|
| 1845 |
notification_destinations=notification_destinations,
|
| 1846 |
defer_turn_complete_notification=defer_turn_complete_notification,
|
| 1847 |
)
|
| 1848 |
if session_holder is not None:
|
| 1849 |
session_holder[0] = session
|
|
|
|
| 1850 |
logger.info("Agent loop started")
|
| 1851 |
|
| 1852 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
@@ -1864,10 +2017,13 @@ async def submission_loop(
|
|
| 1864 |
async with tool_router:
|
| 1865 |
# Emit ready event after initialization
|
| 1866 |
await session.send_event(
|
| 1867 |
-
Event(
|
| 1868 |
-
"
|
| 1869 |
-
|
| 1870 |
-
|
|
|
|
|
|
|
|
|
|
| 1871 |
)
|
| 1872 |
|
| 1873 |
while session.is_running:
|
|
|
|
| 5 |
import asyncio
|
| 6 |
import json
|
| 7 |
import logging
|
|
|
|
| 8 |
import time
|
| 9 |
from dataclasses import dataclass, field
|
| 10 |
from typing import Any
|
|
|
|
| 26 |
from agent.messaging.gateway import NotificationGateway
|
| 27 |
from agent.core import telemetry
|
| 28 |
from agent.core.doom_loop import check_for_doom_loop
|
| 29 |
+
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 30 |
from agent.core.llm_params import _resolve_llm_params
|
| 31 |
from agent.core.prompt_caching import with_prompt_caching
|
| 32 |
from agent.core.session import Event, OpType, Session
|
|
|
|
| 54 |
end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
|
| 55 |
if end == -1:
|
| 56 |
return None
|
| 57 |
+
return content[len(_MALFORMED_TOOL_PREFIX) : end]
|
| 58 |
|
| 59 |
|
| 60 |
def _detect_repeated_malformed(
|
| 61 |
+
items: list[Message],
|
| 62 |
+
threshold: int = 2,
|
| 63 |
) -> str | None:
|
| 64 |
"""Return the repeated malformed tool name if the tail contains a streak.
|
| 65 |
|
|
|
|
| 119 |
|
| 120 |
_IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
|
| 121 |
|
| 122 |
+
|
| 123 |
@dataclass(frozen=True)
|
| 124 |
class ApprovalDecision:
|
| 125 |
requires_approval: bool
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
|
| 147 |
+
return tool_name == "sandbox_create" or _is_immediate_hf_job_run(
|
| 148 |
+
tool_name, tool_args
|
| 149 |
+
)
|
| 150 |
|
| 151 |
|
| 152 |
def _base_needs_approval(
|
|
|
|
| 235 |
|
| 236 |
|
| 237 |
def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
|
| 238 |
+
return bool(
|
| 239 |
+
(config and config.yolo_mode) or _session_auto_approval_enabled(session)
|
| 240 |
+
)
|
| 241 |
|
| 242 |
|
| 243 |
def _remaining_budget_after_reservations(
|
|
|
|
| 257 |
) -> str | None:
|
| 258 |
if estimate.estimated_cost_usd is None:
|
| 259 |
return estimate.block_reason or "Could not estimate the cost safely."
|
| 260 |
+
if (
|
| 261 |
+
remaining_cap_usd is not None
|
| 262 |
+
and estimate.estimated_cost_usd > remaining_cap_usd
|
| 263 |
+
):
|
| 264 |
return (
|
| 265 |
f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds "
|
| 266 |
f"remaining YOLO cap ${remaining_cap_usd:.2f}."
|
|
|
|
| 418 |
"""Return True for errors that are likely transient and worth retrying."""
|
| 419 |
err_str = str(error).lower()
|
| 420 |
transient_patterns = [
|
| 421 |
+
"timeout",
|
| 422 |
+
"timed out",
|
| 423 |
+
"503",
|
| 424 |
+
"service unavailable",
|
| 425 |
+
"502",
|
| 426 |
+
"bad gateway",
|
| 427 |
+
"500",
|
| 428 |
+
"internal server error",
|
| 429 |
+
"overloaded",
|
| 430 |
+
"capacity",
|
| 431 |
+
"connection reset",
|
| 432 |
+
"connection refused",
|
| 433 |
+
"connection error",
|
| 434 |
+
"eof",
|
| 435 |
+
"broken pipe",
|
| 436 |
]
|
| 437 |
+
return _is_rate_limit_error(error) or any(
|
| 438 |
+
pattern in err_str for pattern in transient_patterns
|
| 439 |
+
)
|
| 440 |
|
| 441 |
|
| 442 |
def _is_effort_config_error(error: Exception) -> bool:
|
|
|
|
| 448 |
doesn't work for the current model. We heal the cache and retry once.
|
| 449 |
"""
|
| 450 |
from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported
|
| 451 |
+
|
| 452 |
return _is_thinking_unsupported(error) or _is_invalid_effort(error)
|
| 453 |
|
| 454 |
|
| 455 |
async def _heal_effort_and_rebuild_params(
|
| 456 |
+
session: Session,
|
| 457 |
+
error: Exception,
|
| 458 |
+
llm_params: dict,
|
| 459 |
) -> dict:
|
| 460 |
"""Update the session's effort cache based on ``error`` and return new
|
| 461 |
llm_params. Called only when ``_is_effort_config_error(error)`` is True.
|
|
|
|
| 466 |
• invalid-effort → re-run the full cascade probe; the result lands
|
| 467 |
in the cache
|
| 468 |
"""
|
| 469 |
+
from agent.core.effort_probe import (
|
| 470 |
+
ProbeInconclusive,
|
| 471 |
+
_is_thinking_unsupported,
|
| 472 |
+
probe_effort,
|
| 473 |
+
)
|
| 474 |
|
| 475 |
model = session.config.model_name
|
| 476 |
if _is_thinking_unsupported(error):
|
|
|
|
| 479 |
else:
|
| 480 |
try:
|
| 481 |
outcome = await probe_effort(
|
| 482 |
+
model,
|
| 483 |
+
session.config.reasoning_effort,
|
| 484 |
+
session.hf_token,
|
| 485 |
session=session,
|
| 486 |
)
|
| 487 |
session.model_effective_effort[model] = outcome.effective_effort
|
| 488 |
logger.info(
|
| 489 |
+
"healed: %s effort cascade → %s",
|
| 490 |
+
model,
|
| 491 |
+
outcome.effective_effort,
|
| 492 |
)
|
| 493 |
except ProbeInconclusive:
|
| 494 |
# Transient during healing — strip thinking for safety, next
|
|
|
|
| 507 |
"""Return a user-friendly message for known error types, or None to fall back to traceback."""
|
| 508 |
err_str = str(error).lower()
|
| 509 |
|
| 510 |
+
if (
|
| 511 |
+
"authentication" in err_str
|
| 512 |
+
or "unauthorized" in err_str
|
| 513 |
+
or "invalid x-api-key" in err_str
|
| 514 |
+
):
|
| 515 |
return (
|
| 516 |
"Authentication failed — your API key is missing or invalid.\n\n"
|
| 517 |
"To fix this, set the API key for your model provider:\n"
|
|
|
|
| 537 |
)
|
| 538 |
|
| 539 |
if "model_not_found" in err_str or (
|
| 540 |
+
"model" in err_str and ("not found" in err_str or "does not exist" in err_str)
|
|
|
|
| 541 |
):
|
| 542 |
return (
|
| 543 |
"Model not found. Use '/model' to list suggestions, or paste an "
|
|
|
|
| 563 |
old_usage = cm.running_context_usage
|
| 564 |
logger.debug(
|
| 565 |
"Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s",
|
| 566 |
+
old_usage,
|
| 567 |
+
cm.model_max_tokens,
|
| 568 |
+
cm.compaction_threshold,
|
| 569 |
+
cm.needs_compaction,
|
| 570 |
)
|
| 571 |
try:
|
| 572 |
await cm.compact(
|
|
|
|
| 578 |
except CompactionFailedError as e:
|
| 579 |
logger.error(
|
| 580 |
"Compaction failed for session %s: %s — terminating session",
|
| 581 |
+
session.session_id,
|
| 582 |
+
e,
|
| 583 |
)
|
| 584 |
# Persist the failure event so the dataset has a record of WHY this
|
| 585 |
# session ended (and the cost it incurred up to that point) even if
|
| 586 |
# save_and_upload_detached has issues downstream.
|
| 587 |
+
await session.send_event(
|
| 588 |
+
Event(
|
| 589 |
+
event_type="session_terminated",
|
| 590 |
+
data={
|
| 591 |
+
"reason": "compaction_failed",
|
| 592 |
+
"context_usage": cm.running_context_usage,
|
| 593 |
+
"context_threshold": cm.compaction_threshold,
|
| 594 |
+
"error": str(e)[:300],
|
| 595 |
+
"user_message": (
|
| 596 |
+
"Your conversation has grown too large to continue. "
|
| 597 |
+
"The work you've done is saved — start a new session to keep going."
|
| 598 |
+
),
|
| 599 |
+
},
|
| 600 |
+
)
|
| 601 |
+
)
|
| 602 |
# Stop the agent loop; the finally in _run_session will fire
|
| 603 |
# cleanup_sandbox + save_trajectory so the dataset captures
|
| 604 |
# everything that did happen.
|
|
|
|
| 609 |
if new_usage != old_usage:
|
| 610 |
logger.warning(
|
| 611 |
"Context compacted: %d -> %d tokens (max=%d, %d messages)",
|
| 612 |
+
old_usage,
|
| 613 |
+
new_usage,
|
| 614 |
+
cm.model_max_tokens,
|
| 615 |
+
len(cm.items),
|
| 616 |
)
|
| 617 |
await session.send_event(
|
| 618 |
Event(
|
|
|
|
| 651 |
@dataclass
|
| 652 |
class LLMResult:
|
| 653 |
"""Result from an LLM call (streaming or non-streaming)."""
|
| 654 |
+
|
| 655 |
content: str | None
|
| 656 |
tool_calls_acc: dict[int, dict]
|
| 657 |
token_count: int
|
|
|
|
| 771 |
if not stripped:
|
| 772 |
return False
|
| 773 |
|
| 774 |
+
await session.send_event(
|
| 775 |
+
Event(
|
| 776 |
+
event_type="tool_log",
|
| 777 |
+
data={
|
| 778 |
+
"tool": "system",
|
| 779 |
+
"log": (
|
| 780 |
+
"Anthropic rejected stale thinking signatures; retrying "
|
| 781 |
+
"without replayed thinking metadata."
|
| 782 |
+
),
|
| 783 |
+
},
|
| 784 |
+
)
|
| 785 |
+
)
|
| 786 |
return True
|
| 787 |
|
| 788 |
|
|
|
|
| 807 |
return Message(**kwargs)
|
| 808 |
|
| 809 |
|
| 810 |
+
async def _call_llm_streaming(
|
| 811 |
+
session: Session, messages, tools, llm_params
|
| 812 |
+
) -> LLMResult:
|
| 813 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 814 |
response = None
|
| 815 |
_healed_effort = False # one-shot safety net per call
|
|
|
|
| 835 |
raise ContextWindowExceededError(str(e)) from e
|
| 836 |
if not _healed_effort and _is_effort_config_error(e):
|
| 837 |
_healed_effort = True
|
| 838 |
+
llm_params = await _heal_effort_and_rebuild_params(
|
| 839 |
+
session, e, llm_params
|
| 840 |
+
)
|
| 841 |
+
await session.send_event(
|
| 842 |
+
Event(
|
| 843 |
+
event_type="tool_log",
|
| 844 |
+
data={
|
| 845 |
+
"tool": "system",
|
| 846 |
+
"log": "Reasoning effort not supported for this model — adjusting and retrying.",
|
| 847 |
+
},
|
| 848 |
+
)
|
| 849 |
+
)
|
| 850 |
continue
|
| 851 |
if await _maybe_heal_invalid_thinking_signature(
|
| 852 |
session,
|
|
|
|
| 860 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 861 |
logger.warning(
|
| 862 |
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
|
| 863 |
+
_llm_attempt + 1,
|
| 864 |
+
_MAX_LLM_RETRIES,
|
| 865 |
+
e,
|
| 866 |
+
_delay,
|
| 867 |
+
)
|
| 868 |
+
await session.send_event(
|
| 869 |
+
Event(
|
| 870 |
+
event_type="tool_log",
|
| 871 |
+
data={
|
| 872 |
+
"tool": "system",
|
| 873 |
+
"log": f"LLM connection error, retrying in {_delay}s...",
|
| 874 |
+
},
|
| 875 |
+
)
|
| 876 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
await asyncio.sleep(_delay)
|
| 878 |
continue
|
| 879 |
raise
|
|
|
|
| 914 |
idx = tc_delta.index
|
| 915 |
if idx not in tool_calls_acc:
|
| 916 |
tool_calls_acc[idx] = {
|
| 917 |
+
"id": "",
|
| 918 |
+
"type": "function",
|
| 919 |
"function": {"name": "", "arguments": ""},
|
| 920 |
}
|
| 921 |
if tc_delta.id:
|
| 922 |
tool_calls_acc[idx]["id"] = tc_delta.id
|
| 923 |
if tc_delta.function:
|
| 924 |
if tc_delta.function.name:
|
| 925 |
+
tool_calls_acc[idx]["function"]["name"] += (
|
| 926 |
+
tc_delta.function.name
|
| 927 |
+
)
|
| 928 |
if tc_delta.function.arguments:
|
| 929 |
+
tool_calls_acc[idx]["function"]["arguments"] += (
|
| 930 |
+
tc_delta.function.arguments
|
| 931 |
+
)
|
| 932 |
|
| 933 |
if hasattr(chunk, "usage") and chunk.usage:
|
| 934 |
token_count = chunk.usage.total_tokens
|
|
|
|
| 948 |
rebuilt = stream_chunk_builder(chunks, messages=messages)
|
| 949 |
if rebuilt and getattr(rebuilt, "choices", None):
|
| 950 |
rebuilt_msg = rebuilt.choices[0].message
|
| 951 |
+
thinking_blocks, reasoning_content = _extract_thinking_state(
|
| 952 |
+
rebuilt_msg
|
| 953 |
+
)
|
| 954 |
except Exception:
|
| 955 |
logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
|
| 956 |
|
|
|
|
| 965 |
)
|
| 966 |
|
| 967 |
|
| 968 |
+
async def _call_llm_non_streaming(
|
| 969 |
+
session: Session, messages, tools, llm_params
|
| 970 |
+
) -> LLMResult:
|
| 971 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 972 |
response = None
|
| 973 |
_healed_effort = False
|
|
|
|
| 992 |
raise ContextWindowExceededError(str(e)) from e
|
| 993 |
if not _healed_effort and _is_effort_config_error(e):
|
| 994 |
_healed_effort = True
|
| 995 |
+
llm_params = await _heal_effort_and_rebuild_params(
|
| 996 |
+
session, e, llm_params
|
| 997 |
+
)
|
| 998 |
+
await session.send_event(
|
| 999 |
+
Event(
|
| 1000 |
+
event_type="tool_log",
|
| 1001 |
+
data={
|
| 1002 |
+
"tool": "system",
|
| 1003 |
+
"log": "Reasoning effort not supported for this model — adjusting and retrying.",
|
| 1004 |
+
},
|
| 1005 |
+
)
|
| 1006 |
+
)
|
| 1007 |
continue
|
| 1008 |
if await _maybe_heal_invalid_thinking_signature(
|
| 1009 |
session,
|
|
|
|
| 1017 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 1018 |
logger.warning(
|
| 1019 |
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
|
| 1020 |
+
_llm_attempt + 1,
|
| 1021 |
+
_MAX_LLM_RETRIES,
|
| 1022 |
+
e,
|
| 1023 |
+
_delay,
|
| 1024 |
+
)
|
| 1025 |
+
await session.send_event(
|
| 1026 |
+
Event(
|
| 1027 |
+
event_type="tool_log",
|
| 1028 |
+
data={
|
| 1029 |
+
"tool": "system",
|
| 1030 |
+
"log": f"LLM connection error, retrying in {_delay}s...",
|
| 1031 |
+
},
|
| 1032 |
+
)
|
| 1033 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
await asyncio.sleep(_delay)
|
| 1035 |
continue
|
| 1036 |
raise
|
|
|
|
| 1123 |
|
| 1124 |
@staticmethod
|
| 1125 |
async def run_agent(
|
| 1126 |
+
session: Session,
|
| 1127 |
+
text: str,
|
| 1128 |
) -> str | None:
|
| 1129 |
"""
|
| 1130 |
Handle user input (like user_input_or_turn in codex.rs:1291)
|
|
|
|
| 1211 |
llm_params = _resolve_llm_params(
|
| 1212 |
session.config.model_name,
|
| 1213 |
session.hf_token,
|
| 1214 |
+
reasoning_effort=session.effective_effort_for(
|
| 1215 |
+
session.config.model_name
|
| 1216 |
+
),
|
| 1217 |
)
|
| 1218 |
if session.stream:
|
| 1219 |
+
llm_result = await _call_llm_streaming(
|
| 1220 |
+
session, messages, tools, llm_params
|
| 1221 |
+
)
|
| 1222 |
else:
|
| 1223 |
+
llm_result = await _call_llm_non_streaming(
|
| 1224 |
+
session, messages, tools, llm_params
|
| 1225 |
+
)
|
| 1226 |
|
| 1227 |
content = llm_result.content
|
| 1228 |
tool_calls_acc = llm_result.tool_calls_acc
|
|
|
|
| 1269 |
await session.send_event(
|
| 1270 |
Event(
|
| 1271 |
event_type="tool_log",
|
| 1272 |
+
data={
|
| 1273 |
+
"tool": "system",
|
| 1274 |
+
"log": f"Output truncated — retrying with smaller content ({dropped_names})",
|
| 1275 |
+
},
|
| 1276 |
)
|
| 1277 |
)
|
| 1278 |
iteration += 1
|
|
|
|
| 1335 |
except (json.JSONDecodeError, TypeError, ValueError):
|
| 1336 |
logger.warning(
|
| 1337 |
"Malformed arguments for tool_call %s (%s) — skipping",
|
| 1338 |
+
tc.id,
|
| 1339 |
+
tc.function.name,
|
| 1340 |
)
|
| 1341 |
tc.function.arguments = "{}"
|
| 1342 |
bad_tools.append(tc)
|
|
|
|
| 1357 |
f"arguments and was NOT executed. Retry with smaller content — "
|
| 1358 |
f"for 'write', split into multiple smaller writes using 'edit'."
|
| 1359 |
)
|
| 1360 |
+
session.context_manager.add_message(
|
| 1361 |
+
Message(
|
| 1362 |
+
role="tool",
|
| 1363 |
+
content=error_msg,
|
| 1364 |
+
tool_call_id=tc.id,
|
| 1365 |
+
name=tc.function.name,
|
| 1366 |
+
)
|
| 1367 |
+
)
|
| 1368 |
+
await session.send_event(
|
| 1369 |
+
Event(
|
| 1370 |
+
event_type="tool_call",
|
| 1371 |
+
data={
|
| 1372 |
+
"tool": tc.function.name,
|
| 1373 |
+
"arguments": {},
|
| 1374 |
+
"tool_call_id": tc.id,
|
| 1375 |
+
},
|
| 1376 |
+
)
|
| 1377 |
+
)
|
| 1378 |
+
await session.send_event(
|
| 1379 |
+
Event(
|
| 1380 |
+
event_type="tool_output",
|
| 1381 |
+
data={
|
| 1382 |
+
"tool": tc.function.name,
|
| 1383 |
+
"tool_call_id": tc.id,
|
| 1384 |
+
"output": error_msg,
|
| 1385 |
+
"success": False,
|
| 1386 |
+
},
|
| 1387 |
+
)
|
| 1388 |
+
)
|
| 1389 |
|
| 1390 |
# ── Cancellation check: before tool execution ──
|
| 1391 |
if session.is_cancelled:
|
|
|
|
| 1410 |
reserved_spend_usd=reserved_auto_spend_usd,
|
| 1411 |
)
|
| 1412 |
if decision.requires_approval:
|
| 1413 |
+
approval_required_tools.append(
|
| 1414 |
+
(tc, tool_name, tool_args, decision)
|
| 1415 |
+
)
|
| 1416 |
else:
|
| 1417 |
non_approval_tools.append((tc, tool_name, tool_args, decision))
|
| 1418 |
if (
|
|
|
|
| 1435 |
)
|
| 1436 |
|
| 1437 |
# 2. Send all tool_call events upfront (so frontend shows them all)
|
| 1438 |
+
for (
|
| 1439 |
+
tc,
|
| 1440 |
+
tool_name,
|
| 1441 |
+
tool_args,
|
| 1442 |
+
_decision,
|
| 1443 |
+
args_valid,
|
| 1444 |
+
_,
|
| 1445 |
+
) in parsed_tools:
|
| 1446 |
if args_valid:
|
| 1447 |
await session.send_event(
|
| 1448 |
Event(
|
|
|
|
| 1473 |
)
|
| 1474 |
return (tc, name, args, out, ok)
|
| 1475 |
|
| 1476 |
+
gather_task = asyncio.ensure_future(
|
| 1477 |
+
asyncio.gather(
|
| 1478 |
+
*[
|
| 1479 |
+
_exec_tool(tc, name, args, decision, valid, err)
|
| 1480 |
+
for tc, name, args, decision, valid, err in parsed_tools
|
| 1481 |
+
]
|
| 1482 |
+
)
|
| 1483 |
+
)
|
| 1484 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
| 1485 |
|
| 1486 |
done, _ = await asyncio.wait(
|
|
|
|
| 1497 |
# Notify frontend that in-flight tools were cancelled
|
| 1498 |
for tc, name, _args, _decision, valid, _ in parsed_tools:
|
| 1499 |
if valid:
|
| 1500 |
+
await session.send_event(
|
| 1501 |
+
Event(
|
| 1502 |
+
event_type="tool_state_change",
|
| 1503 |
+
data={
|
| 1504 |
+
"tool_call_id": tc.id,
|
| 1505 |
+
"tool": name,
|
| 1506 |
+
"state": "cancelled",
|
| 1507 |
+
},
|
| 1508 |
+
)
|
| 1509 |
+
)
|
| 1510 |
await _cleanup_on_cancel(session)
|
| 1511 |
break
|
| 1512 |
|
|
|
|
| 1543 |
for tc, tool_name, tool_args, decision in approval_required_tools:
|
| 1544 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 1545 |
# frontend can display & edit the actual file content.
|
| 1546 |
+
if tool_name == "hf_jobs" and isinstance(
|
| 1547 |
+
tool_args.get("script"), str
|
| 1548 |
+
):
|
| 1549 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
| 1550 |
+
|
| 1551 |
sandbox = getattr(session, "sandbox", None)
|
| 1552 |
+
resolved, _ = await resolve_sandbox_script(
|
| 1553 |
+
sandbox, tool_args["script"]
|
| 1554 |
+
)
|
| 1555 |
if resolved:
|
| 1556 |
tool_args = {**tool_args, "script": resolved}
|
| 1557 |
|
|
|
|
| 1583 |
"remaining_cap_usd": first.get("remaining_cap_usd"),
|
| 1584 |
}
|
| 1585 |
)
|
| 1586 |
+
await session.send_event(
|
| 1587 |
+
Event(
|
| 1588 |
+
event_type="approval_required",
|
| 1589 |
+
data=event_data,
|
| 1590 |
+
)
|
| 1591 |
+
)
|
| 1592 |
|
| 1593 |
# Store all approval-requiring tools (ToolCall objects for execution)
|
| 1594 |
session.pending_approval = {
|
|
|
|
| 1606 |
logger.warning(
|
| 1607 |
"ContextWindowExceededError at iteration %d — forcing compaction "
|
| 1608 |
"(usage=%d, model_max_tokens=%d, messages=%d)",
|
| 1609 |
+
iteration,
|
| 1610 |
+
cm.running_context_usage,
|
| 1611 |
+
cm.model_max_tokens,
|
| 1612 |
+
len(cm.items),
|
| 1613 |
)
|
| 1614 |
cm.running_context_usage = cm.model_max_tokens + 1
|
| 1615 |
await _compact_and_notify(session)
|
|
|
|
| 1801 |
|
| 1802 |
# Execute all approved tools concurrently (cancellable)
|
| 1803 |
if approved_tasks:
|
| 1804 |
+
gather_task = asyncio.ensure_future(
|
| 1805 |
+
asyncio.gather(
|
| 1806 |
+
*[
|
| 1807 |
+
execute_tool(tc, tool_name, tool_args, was_edited)
|
| 1808 |
+
for tc, tool_name, tool_args, was_edited in approved_tasks
|
| 1809 |
+
],
|
| 1810 |
+
return_exceptions=True,
|
| 1811 |
+
)
|
| 1812 |
+
)
|
| 1813 |
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
| 1814 |
|
| 1815 |
done, _ = await asyncio.wait(
|
|
|
|
| 1825 |
pass
|
| 1826 |
# Notify frontend that approved tools were cancelled
|
| 1827 |
for tc, tool_name, _args, _was_edited in approved_tasks:
|
| 1828 |
+
await session.send_event(
|
| 1829 |
+
Event(
|
| 1830 |
+
event_type="tool_state_change",
|
| 1831 |
+
data={
|
| 1832 |
+
"tool_call_id": tc.id,
|
| 1833 |
+
"tool": tool_name,
|
| 1834 |
+
"state": "cancelled",
|
| 1835 |
+
},
|
| 1836 |
+
)
|
| 1837 |
+
)
|
| 1838 |
await _cleanup_on_cancel(session)
|
| 1839 |
await session.send_event(Event(event_type="interrupted"))
|
| 1840 |
session.increment_turn()
|
|
|
|
| 1986 |
|
| 1987 |
# Create session with tool router
|
| 1988 |
session = Session(
|
| 1989 |
+
event_queue,
|
| 1990 |
+
config=config,
|
| 1991 |
+
tool_router=tool_router,
|
| 1992 |
+
hf_token=hf_token,
|
| 1993 |
+
user_id=user_id,
|
| 1994 |
+
local_mode=local_mode,
|
| 1995 |
+
stream=stream,
|
| 1996 |
notification_gateway=notification_gateway,
|
| 1997 |
notification_destinations=notification_destinations,
|
| 1998 |
defer_turn_complete_notification=defer_turn_complete_notification,
|
| 1999 |
)
|
| 2000 |
if session_holder is not None:
|
| 2001 |
session_holder[0] = session
|
| 2002 |
+
start_session_artifact_collection_task(session, token=hf_token)
|
| 2003 |
logger.info("Agent loop started")
|
| 2004 |
|
| 2005 |
# Retry any failed uploads from previous sessions (fire-and-forget).
|
|
|
|
| 2017 |
async with tool_router:
|
| 2018 |
# Emit ready event after initialization
|
| 2019 |
await session.send_event(
|
| 2020 |
+
Event(
|
| 2021 |
+
event_type="ready",
|
| 2022 |
+
data={
|
| 2023 |
+
"message": "Agent initialized",
|
| 2024 |
+
"tool_count": len(tool_router.tools),
|
| 2025 |
+
},
|
| 2026 |
+
)
|
| 2027 |
)
|
| 2028 |
|
| 2029 |
while session.is_running:
|
agent/core/cost_estimation.py
CHANGED
|
@@ -88,7 +88,9 @@ class CostEstimate:
|
|
| 88 |
label: str | None = None
|
| 89 |
|
| 90 |
|
| 91 |
-
def parse_timeout_hours(
|
|
|
|
|
|
|
| 92 |
"""Parse HF timeout values into hours.
|
| 93 |
|
| 94 |
Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
|
|
@@ -247,7 +249,9 @@ async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate:
|
|
| 247 |
)
|
| 248 |
|
| 249 |
|
| 250 |
-
async def estimate_sandbox_cost(
|
|
|
|
|
|
|
| 251 |
if session is not None and getattr(session, "sandbox", None):
|
| 252 |
return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")
|
| 253 |
|
|
|
|
| 88 |
label: str | None = None
|
| 89 |
|
| 90 |
|
| 91 |
+
def parse_timeout_hours(
|
| 92 |
+
value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS
|
| 93 |
+
) -> float | None:
|
| 94 |
"""Parse HF timeout values into hours.
|
| 95 |
|
| 96 |
Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
|
|
|
|
| 249 |
)
|
| 250 |
|
| 251 |
|
| 252 |
+
async def estimate_sandbox_cost(
|
| 253 |
+
args: dict[str, Any], *, session: Any = None
|
| 254 |
+
) -> CostEstimate:
|
| 255 |
if session is not None and getattr(session, "sandbox", None):
|
| 256 |
return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")
|
| 257 |
|
agent/core/doom_loop.py
CHANGED
|
@@ -81,9 +81,11 @@ def extract_recent_tool_signatures(
|
|
| 81 |
name = getattr(fn, "name", "") or ""
|
| 82 |
args_str = getattr(fn, "arguments", "") or ""
|
| 83 |
result_hash = None
|
| 84 |
-
for follow in recent[idx + 1:]:
|
| 85 |
role = getattr(follow, "role", None)
|
| 86 |
-
if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(
|
|
|
|
|
|
|
| 87 |
result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
|
| 88 |
break
|
| 89 |
if role in {"assistant", "user"}:
|
|
@@ -174,7 +176,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
|
|
| 174 |
pattern = detect_repeating_sequence(signatures)
|
| 175 |
if pattern:
|
| 176 |
pattern_desc = " → ".join(s.name for s in pattern)
|
| 177 |
-
logger.warning(
|
|
|
|
|
|
|
| 178 |
return (
|
| 179 |
f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
|
| 180 |
f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
|
|
|
|
| 81 |
name = getattr(fn, "name", "") or ""
|
| 82 |
args_str = getattr(fn, "arguments", "") or ""
|
| 83 |
result_hash = None
|
| 84 |
+
for follow in recent[idx + 1 :]:
|
| 85 |
role = getattr(follow, "role", None)
|
| 86 |
+
if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(
|
| 87 |
+
tc, "id", None
|
| 88 |
+
):
|
| 89 |
result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
|
| 90 |
break
|
| 91 |
if role in {"assistant", "user"}:
|
|
|
|
| 176 |
pattern = detect_repeating_sequence(signatures)
|
| 177 |
if pattern:
|
| 178 |
pattern_desc = " → ".join(s.name for s in pattern)
|
| 179 |
+
logger.warning(
|
| 180 |
+
"Repetition guard activated: repeating sequence [%s]", pattern_desc
|
| 181 |
+
)
|
| 182 |
return (
|
| 183 |
f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
|
| 184 |
f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
|
agent/core/effort_probe.py
CHANGED
|
@@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
|
|
| 39 |
# requested level raise ``UnsupportedEffortError`` synchronously (no wasted
|
| 40 |
# network round-trip) and we advance to the next level.
|
| 41 |
_EFFORT_CASCADE: dict[str, list[str]] = {
|
| 42 |
-
"max":
|
| 43 |
-
"xhigh":
|
| 44 |
-
"high":
|
| 45 |
-
"medium":
|
| 46 |
"minimal": ["minimal", "low"],
|
| 47 |
-
"low":
|
| 48 |
}
|
| 49 |
|
| 50 |
_PROBE_TIMEOUT = 15.0
|
|
@@ -69,6 +69,7 @@ class ProbeOutcome:
|
|
| 69 |
* str → send this level
|
| 70 |
* None → model doesn't support thinking; strip it
|
| 71 |
"""
|
|
|
|
| 72 |
effective_effort: str | None
|
| 73 |
attempts: int
|
| 74 |
elapsed_ms: int
|
|
@@ -108,10 +109,15 @@ def _is_invalid_effort(e: Exception) -> bool:
|
|
| 108 |
return any(
|
| 109 |
phrase in s
|
| 110 |
for phrase in (
|
| 111 |
-
"invalid",
|
| 112 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
# LiteLLM's own pre-flight validation phrasing.
|
| 114 |
-
"only supported by",
|
|
|
|
| 115 |
)
|
| 116 |
)
|
| 117 |
|
|
@@ -128,11 +134,23 @@ def _is_transient(e: Exception) -> bool:
|
|
| 128 |
return any(
|
| 129 |
p in s
|
| 130 |
for p in (
|
| 131 |
-
"timeout",
|
| 132 |
-
"
|
| 133 |
-
"
|
| 134 |
-
"
|
| 135 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
)
|
| 137 |
)
|
| 138 |
|
|
@@ -173,7 +191,10 @@ async def probe_effort(
|
|
| 173 |
for effort in cascade:
|
| 174 |
try:
|
| 175 |
params = _resolve_llm_params(
|
| 176 |
-
model_name,
|
|
|
|
|
|
|
|
|
|
| 177 |
)
|
| 178 |
except UnsupportedEffortError:
|
| 179 |
# Provider can't even accept this effort name (e.g. "max" on
|
|
@@ -198,12 +219,15 @@ async def probe_effort(
|
|
| 198 |
# out of the probe and break model switching.
|
| 199 |
try:
|
| 200 |
from agent.core import telemetry
|
|
|
|
| 201 |
await telemetry.record_llm_call(
|
| 202 |
session,
|
| 203 |
model=model_name,
|
| 204 |
response=response,
|
| 205 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 206 |
-
finish_reason=response.choices[0].finish_reason
|
|
|
|
|
|
|
| 207 |
kind="effort_probe",
|
| 208 |
)
|
| 209 |
except Exception as _telem_err:
|
|
@@ -219,7 +243,9 @@ async def probe_effort(
|
|
| 219 |
note="model doesn't support reasoning, dropped",
|
| 220 |
)
|
| 221 |
if _is_invalid_effort(e):
|
| 222 |
-
logger.debug(
|
|
|
|
|
|
|
| 223 |
continue
|
| 224 |
if _is_transient(e):
|
| 225 |
raise ProbeInconclusive(str(e)) from e
|
|
|
|
| 39 |
# requested level raise ``UnsupportedEffortError`` synchronously (no wasted
|
| 40 |
# network round-trip) and we advance to the next level.
|
| 41 |
_EFFORT_CASCADE: dict[str, list[str]] = {
|
| 42 |
+
"max": ["max", "xhigh", "high", "medium", "low"],
|
| 43 |
+
"xhigh": ["xhigh", "high", "medium", "low"],
|
| 44 |
+
"high": ["high", "medium", "low"],
|
| 45 |
+
"medium": ["medium", "low"],
|
| 46 |
"minimal": ["minimal", "low"],
|
| 47 |
+
"low": ["low"],
|
| 48 |
}
|
| 49 |
|
| 50 |
_PROBE_TIMEOUT = 15.0
|
|
|
|
| 69 |
* str → send this level
|
| 70 |
* None → model doesn't support thinking; strip it
|
| 71 |
"""
|
| 72 |
+
|
| 73 |
effective_effort: str | None
|
| 74 |
attempts: int
|
| 75 |
elapsed_ms: int
|
|
|
|
| 109 |
return any(
|
| 110 |
phrase in s
|
| 111 |
for phrase in (
|
| 112 |
+
"invalid",
|
| 113 |
+
"not supported",
|
| 114 |
+
"must be one of",
|
| 115 |
+
"not a valid",
|
| 116 |
+
"unrecognized",
|
| 117 |
+
"unknown",
|
| 118 |
# LiteLLM's own pre-flight validation phrasing.
|
| 119 |
+
"only supported by",
|
| 120 |
+
"is only supported",
|
| 121 |
)
|
| 122 |
)
|
| 123 |
|
|
|
|
| 134 |
return any(
|
| 135 |
p in s
|
| 136 |
for p in (
|
| 137 |
+
"timeout",
|
| 138 |
+
"timed out",
|
| 139 |
+
"429",
|
| 140 |
+
"rate limit",
|
| 141 |
+
"503",
|
| 142 |
+
"service unavailable",
|
| 143 |
+
"502",
|
| 144 |
+
"bad gateway",
|
| 145 |
+
"500",
|
| 146 |
+
"internal server error",
|
| 147 |
+
"overloaded",
|
| 148 |
+
"capacity",
|
| 149 |
+
"connection reset",
|
| 150 |
+
"connection refused",
|
| 151 |
+
"connection error",
|
| 152 |
+
"eof",
|
| 153 |
+
"broken pipe",
|
| 154 |
)
|
| 155 |
)
|
| 156 |
|
|
|
|
| 191 |
for effort in cascade:
|
| 192 |
try:
|
| 193 |
params = _resolve_llm_params(
|
| 194 |
+
model_name,
|
| 195 |
+
hf_token,
|
| 196 |
+
reasoning_effort=effort,
|
| 197 |
+
strict=True,
|
| 198 |
)
|
| 199 |
except UnsupportedEffortError:
|
| 200 |
# Provider can't even accept this effort name (e.g. "max" on
|
|
|
|
| 219 |
# out of the probe and break model switching.
|
| 220 |
try:
|
| 221 |
from agent.core import telemetry
|
| 222 |
+
|
| 223 |
await telemetry.record_llm_call(
|
| 224 |
session,
|
| 225 |
model=model_name,
|
| 226 |
response=response,
|
| 227 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 228 |
+
finish_reason=response.choices[0].finish_reason
|
| 229 |
+
if response.choices
|
| 230 |
+
else None,
|
| 231 |
kind="effort_probe",
|
| 232 |
)
|
| 233 |
except Exception as _telem_err:
|
|
|
|
| 243 |
note="model doesn't support reasoning, dropped",
|
| 244 |
)
|
| 245 |
if _is_invalid_effort(e):
|
| 246 |
+
logger.debug(
|
| 247 |
+
"probe: %s rejected effort=%s, trying next", model_name, effort
|
| 248 |
+
)
|
| 249 |
continue
|
| 250 |
if _is_transient(e):
|
| 251 |
raise ProbeInconclusive(str(e)) from e
|
agent/core/hf_router_catalog.py
CHANGED
|
@@ -92,7 +92,9 @@ def _parse_entry(entry: dict) -> ModelInfo:
|
|
| 92 |
input_price=pricing.get("input"),
|
| 93 |
output_price=pricing.get("output"),
|
| 94 |
supports_tools=bool(p.get("supports_tools", False)),
|
| 95 |
-
supports_structured_output=bool(
|
|
|
|
|
|
|
| 96 |
)
|
| 97 |
)
|
| 98 |
return ModelInfo(id=entry.get("id", ""), providers=providers)
|
|
|
|
| 92 |
input_price=pricing.get("input"),
|
| 93 |
output_price=pricing.get("output"),
|
| 94 |
supports_tools=bool(p.get("supports_tools", False)),
|
| 95 |
+
supports_structured_output=bool(
|
| 96 |
+
p.get("supports_structured_output", False)
|
| 97 |
+
),
|
| 98 |
)
|
| 99 |
)
|
| 100 |
return ModelInfo(id=entry.get("id", ""), providers=providers)
|
agent/core/hub_artifacts.py
ADDED
|
@@ -0,0 +1,765 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import base64
|
| 5 |
+
import logging
|
| 6 |
+
import re
|
| 7 |
+
import shlex
|
| 8 |
+
import tempfile
|
| 9 |
+
import textwrap
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 15 |
+
from huggingface_hub.repocard import metadata_load, metadata_save
|
| 16 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Tag appended to the ``tags`` metadata of every repo card this module updates.
ML_INTERN_TAG = "ml-intern"
# Repo types accepted by register_hub_artifact; any other type is ignored.
SUPPORTED_REPO_TYPES = {"model", "dataset", "space"}
# HTML comment that opens the provenance section; its presence in a README
# is how the code detects that the section was already appended.
PROVENANCE_MARKER = "<!-- ml-intern-provenance -->"
# Collection titles are "<prefix>-<YYYY-MM-DD>-<session id fragment>".
_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts"
# Budget used when truncating non-UUID session ids inside a collection title.
# NOTE(review): presumably tied to a Hub title-length limit — confirm.
_COLLECTION_TITLE_MAX_LENGTH = 59
# Matches a canonical 8-4-4-4-12 hexadecimal UUID string.
_UUID_SESSION_ID_RE = re.compile(
    r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
    r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
)
# Names of the attributes this module attaches to session objects.
_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
_COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
# Process-local storage used when an attribute cannot be set on a session;
# keyed by (id(session), attribute name).
_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
# Detects an existing level-2..6 markdown heading that already documents usage.
_USAGE_HEADING_RE = re.compile(
    r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
    re.IGNORECASE | re.MULTILINE,
)
# Matches a leading ``---`` delimited YAML front-matter block.
_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _safe_session_id(session: Any) -> str:
|
| 42 |
+
raw = str(getattr(session, "session_id", "") or "unknown-session")
|
| 43 |
+
safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-")
|
| 44 |
+
return safe or "unknown-session"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def session_artifact_date(session: Any) -> str:
    """Return the YYYY-MM-DD partition date for a session.

    The date comes from ``session.session_start_time`` when that attribute
    holds a parseable ISO-8601 timestamp (a trailing ``Z`` is accepted).
    When the attribute is missing, empty, or unparseable, the current UTC
    date is returned instead.
    """
    # Local import: the module imports only ``datetime`` from ``datetime``.
    from datetime import timezone

    raw = getattr(session, "session_start_time", None)
    if raw:
        try:
            # ``fromisoformat`` rejects a "Z" suffix before Python 3.11, so
            # spell it as an explicit UTC offset first.
            return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime(
                "%Y-%m-%d"
            )
        except ValueError:
            logger.debug("Could not parse session_start_time=%r", raw)
    # ``datetime.utcnow()`` is deprecated since Python 3.12 and returns a
    # naive value; an aware UTC "now" gives the same calendar date.
    return datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _collection_session_id_fragment(session: Any) -> str:
    """Return the session-id part of the collection title.

    UUID-shaped ids are shortened to their first eight characters. Any other
    id is kept whole when it fits the remaining title budget; otherwise it is
    cut to that budget and trailing ``-``, ``.`` and ``_`` are removed (unless
    that would leave nothing).
    """
    session_id = _safe_session_id(session)
    if _UUID_SESSION_ID_RE.match(session_id):
        return session_id[:8]

    title_stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
    # Always leave room for at least one character of the id.
    budget = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(title_stem))
    if len(session_id) > budget:
        clipped = session_id[:budget]
        trimmed = clipped.rstrip("-._")
        return trimmed if trimmed else clipped
    return session_id
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def artifact_collection_title(session: Any) -> str:
    """Return ``<prefix>-<YYYY-MM-DD>-<session id fragment>`` for a session."""
    date_part = session_artifact_date(session)
    id_part = _collection_session_id_fragment(session)
    return f"{_COLLECTION_TITLE_PREFIX}-{date_part}-{id_part}"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _artifact_key(repo_id: str, repo_type: str | None) -> str:
|
| 79 |
+
return f"{repo_type or 'model'}:{repo_id}"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _session_artifact_set(session: Any, attr: str) -> set[str]:
|
| 83 |
+
current = getattr(session, attr, None)
|
| 84 |
+
if isinstance(current, set):
|
| 85 |
+
return current
|
| 86 |
+
current = set()
|
| 87 |
+
try:
|
| 88 |
+
setattr(session, attr, current)
|
| 89 |
+
except Exception:
|
| 90 |
+
logger.warning(
|
| 91 |
+
"Could not attach %s to session; using process-local fallback state",
|
| 92 |
+
attr,
|
| 93 |
+
)
|
| 94 |
+
return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set())
|
| 95 |
+
return current
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None:
|
| 99 |
+
if session is None or not repo_id:
|
| 100 |
+
return
|
| 101 |
+
_session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add(
|
| 102 |
+
_artifact_key(repo_id, repo_type)
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool:
|
| 107 |
+
if session is None or not repo_id:
|
| 108 |
+
return False
|
| 109 |
+
return _artifact_key(repo_id, repo_type) in _session_artifact_set(
|
| 110 |
+
session, _KNOWN_ARTIFACTS_ATTR
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
    """Return a shallow copy of ``metadata`` whose ``tags`` list contains ``tag``.

    The existing ``tags`` value is normalised to a list of strings: ``None``
    becomes an empty list, a string becomes a one-element list, list items
    are stringified, and any other value is stringified as a single tag.
    The input mapping is not modified.
    """
    result = dict(metadata)
    existing = result.get("tags")

    if existing is None:
        tag_list: list[str] = []
    elif isinstance(existing, list):
        tag_list = [str(entry) for entry in existing]
    elif isinstance(existing, str):
        tag_list = [existing]
    else:
        tag_list = [str(existing)]

    if tag not in tag_list:
        tag_list.append(tag)
    result["tags"] = tag_list
    return result
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _metadata_from_content(content: str) -> dict[str, Any]:
    """Parse the metadata block of README ``content``.

    ``metadata_load`` works on files, so the text is written to a temporary
    ``README.md`` first. A falsy result is returned as an empty dict.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        readme = Path(scratch_dir) / "README.md"
        readme.write_text(content, encoding="utf-8")
        loaded = metadata_load(readme)
        return loaded if loaded else {}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
    """Return README ``content`` after ``metadata_save`` has written ``metadata``.

    ``metadata_save`` works on files, so the text makes a round trip through a
    temporary ``README.md``.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        readme = Path(scratch_dir) / "README.md"
        readme.write_text(content, encoding="utf-8")
        metadata_save(readme, metadata)
        rewritten = readme.read_text(encoding="utf-8")
    return rewritten
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _body_without_metadata(content: str) -> str:
    """Return ``content`` minus a leading front-matter block, whitespace-stripped."""
    without_front_matter = _FRONT_MATTER_RE.sub("", content, count=1)
    return without_front_matter.strip()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _append_section(content: str, section: str) -> str:
|
| 152 |
+
base = content.rstrip()
|
| 153 |
+
if base:
|
| 154 |
+
return f"{base}\n\n{section.strip()}\n"
|
| 155 |
+
return f"{section.strip()}\n"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _provenance_section(repo_type: str) -> str:
    """Return the markdown provenance section, starting with the marker line.

    The sentence names the repository as a "model" or "dataset"; any other
    ``repo_type`` is described as "Hub".
    """
    labels = {"model": "model", "dataset": "dataset"}
    label = labels.get(repo_type, "Hub")
    lines = [
        PROVENANCE_MARKER,
        "## Generated by ML Intern",
        "",
        f"This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.",
        "",
        "- Try ML Intern: https://smolagents-ml-intern.hf.space",
        "- Source code: https://github.com/huggingface/ml-intern",
    ]
    return "\n".join(lines) + "\n"
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _usage_section(repo_id: str, repo_type: str) -> str:
|
| 171 |
+
if repo_type == "dataset":
|
| 172 |
+
return f"""## Usage
|
| 173 |
+
|
| 174 |
+
```python
|
| 175 |
+
from datasets import load_dataset
|
| 176 |
+
|
| 177 |
+
dataset = load_dataset("{repo_id}")
|
| 178 |
+
```
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
return f"""## Usage
|
| 182 |
+
|
| 183 |
+
```python
|
| 184 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 185 |
+
|
| 186 |
+
model_id = "{repo_id}"
|
| 187 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 188 |
+
model = AutoModelForCausalLM.from_pretrained(model_id)
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def augment_repo_card_content(
    content: str | None,
    repo_id: str,
    repo_type: str = "model",
    *,
    extra_metadata: dict[str, Any] | None = None,
) -> str:
    """Return README content with ML Intern metadata and provenance added.

    Args:
        content: Current README text; ``None`` is treated as empty.
        repo_id: Repository id, used for the fallback title and usage snippet.
        repo_type: Repository type; a falsy value is treated as ``"model"``.
        extra_metadata: Defaults merged under the README's own metadata, so
            keys already present in the README win.

    Returns:
        The README text with the ML Intern tag in its metadata, a
        ``# <repo_id>`` title when the body was empty, and, for model and
        dataset repos that lack the provenance marker, the provenance section.
    """
    repo_type = repo_type or "model"
    content = content or ""
    metadata = _metadata_from_content(content)
    if extra_metadata:
        # README values take precedence over the supplied defaults.
        metadata = {**extra_metadata, **metadata}
    metadata = _merge_tags(metadata)
    updated = _content_with_metadata(content, metadata)

    # A card that holds nothing but metadata gets the repo id as its title.
    if not _body_without_metadata(updated):
        updated = _append_section(updated, f"# {repo_id}")

    # The marker check keeps the provenance section from being appended twice.
    if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated:
        updated = _append_section(updated, _provenance_section(repo_type))
        # NOTE(review): the usage check is assumed to be nested under the
        # provenance branch — confirm against the original indentation.
        # It inspects the original content, not the augmented text.
        if not _USAGE_HEADING_RE.search(content):
            updated = _append_section(updated, _usage_section(repo_id, repo_type))

    return updated
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _read_remote_readme(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> str:
    """Download and return the repo's ``README.md`` text.

    ``token`` falls back to ``api.token`` when it is ``None``. A missing file
    or repository is reported as an empty string; other download errors
    propagate.
    """
    effective_token = getattr(api, "token", None) if token is None else token
    try:
        local_copy = hf_hub_download(
            repo_id=repo_id,
            filename="README.md",
            repo_type=repo_type,
            token=effective_token,
        )
    except (EntryNotFoundError, RepositoryNotFoundError):
        return ""
    else:
        return Path(local_copy).read_text(encoding="utf-8")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def _update_repo_card(
    api: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
    extra_metadata: dict[str, Any] | None = None,
) -> None:
    """Upload an augmented ``README.md`` when augmentation changes the text.

    The remote README is read, passed through ``augment_repo_card_content``,
    and uploaded via ``api.upload_file`` only if the result differs.
    """
    existing = _read_remote_readme(api, repo_id, repo_type, token=token)
    augmented = augment_repo_card_content(
        existing,
        repo_id,
        repo_type,
        extra_metadata=extra_metadata,
    )
    if augmented != existing:
        api.upload_file(
            path_or_fileobj=augmented.encode("utf-8"),
            path_in_repo="README.md",
            repo_id=repo_id,
            repo_type=repo_type,
            token=token,
            commit_message="Update ML Intern artifact metadata",
        )
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def _ensure_collection_slug(
    api: Any,
    session: Any,
    *,
    token: str | bool | None = None,
) -> str | None:
    """Return the slug of the session's artifact collection, creating it if needed.

    A slug already cached on the session is returned without any API call.
    Otherwise ``api.create_collection`` is called (private, ``exists_ok``) and
    a truthy slug from the result is cached on the session.
    """
    cached = getattr(session, _COLLECTION_SLUG_ATTR, None)
    if cached:
        return cached

    title = artifact_collection_title(session)
    description = (
        f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
        f"on {session_artifact_date(session)}."
    )
    collection = api.create_collection(
        title=title,
        description=description,
        private=True,
        exists_ok=True,
        token=token,
    )
    new_slug = getattr(collection, "slug", None)
    if new_slug:
        setattr(session, _COLLECTION_SLUG_ATTR, new_slug)
    return new_slug
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
async def ensure_session_artifact_collection(
|
| 297 |
+
session: Any,
|
| 298 |
+
*,
|
| 299 |
+
token: str | bool | None = None,
|
| 300 |
+
) -> str | None:
|
| 301 |
+
"""Create/cache the per-session artifact collection without raising."""
|
| 302 |
+
if session is None or not getattr(session, "session_id", None):
|
| 303 |
+
return None
|
| 304 |
+
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 305 |
+
if not token_value:
|
| 306 |
+
return None
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
api = HfApi(token=token_value)
|
| 310 |
+
return await asyncio.to_thread(
|
| 311 |
+
_ensure_collection_slug,
|
| 312 |
+
api,
|
| 313 |
+
session,
|
| 314 |
+
token=token_value,
|
| 315 |
+
)
|
| 316 |
+
except Exception as e:
|
| 317 |
+
logger.warning(
|
| 318 |
+
"ML Intern session collection creation failed for %s: %s",
|
| 319 |
+
_safe_session_id(session),
|
| 320 |
+
e,
|
| 321 |
+
)
|
| 322 |
+
return None
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def start_session_artifact_collection_task(
|
| 326 |
+
session: Any,
|
| 327 |
+
*,
|
| 328 |
+
token: str | bool | None = None,
|
| 329 |
+
) -> asyncio.Task | None:
|
| 330 |
+
"""Schedule best-effort collection creation for a newly started session."""
|
| 331 |
+
if session is None or not getattr(session, "session_id", None):
|
| 332 |
+
return None
|
| 333 |
+
if getattr(session, _COLLECTION_SLUG_ATTR, None):
|
| 334 |
+
return None
|
| 335 |
+
|
| 336 |
+
token_value = token if token is not None else getattr(session, "hf_token", None)
|
| 337 |
+
if not token_value:
|
| 338 |
+
return None
|
| 339 |
+
|
| 340 |
+
existing = getattr(session, _COLLECTION_TASK_ATTR, None)
|
| 341 |
+
if isinstance(existing, asyncio.Task) and not existing.done():
|
| 342 |
+
return existing
|
| 343 |
+
|
| 344 |
+
try:
|
| 345 |
+
loop = asyncio.get_running_loop()
|
| 346 |
+
except RuntimeError:
|
| 347 |
+
return None
|
| 348 |
+
|
| 349 |
+
async def _run() -> None:
|
| 350 |
+
await ensure_session_artifact_collection(session, token=token_value)
|
| 351 |
+
|
| 352 |
+
task = loop.create_task(_run())
|
| 353 |
+
try:
|
| 354 |
+
setattr(session, _COLLECTION_TASK_ATTR, task)
|
| 355 |
+
except Exception:
|
| 356 |
+
logger.debug("Could not attach ML Intern collection task to session")
|
| 357 |
+
return task
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _add_to_collection(
    api: Any,
    session: Any,
    repo_id: str,
    repo_type: str,
    *,
    token: str | bool | None = None,
) -> None:
    """Add the repo to the session's artifact collection.

    The collection is created on demand; when no slug can be obtained the
    call does nothing.
    """
    slug = _ensure_collection_slug(api, session, token=token)
    if slug:
        note = (
            f"Generated by ML Intern session {_safe_session_id(session)} "
            f"on {session_artifact_date(session)}."
        )
        api.add_collection_item(
            collection_slug=slug,
            item_id=repo_id,
            item_type=repo_type,
            note=note,
            exists_ok=True,
            token=token,
        )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def register_hub_artifact(
|
| 385 |
+
api: Any,
|
| 386 |
+
repo_id: str,
|
| 387 |
+
repo_type: str = "model",
|
| 388 |
+
*,
|
| 389 |
+
session: Any = None,
|
| 390 |
+
token: str | bool | None = None,
|
| 391 |
+
extra_metadata: dict[str, Any] | None = None,
|
| 392 |
+
force: bool = False,
|
| 393 |
+
) -> bool:
|
| 394 |
+
"""Tag, card, and collection-register a Hub artifact without raising."""
|
| 395 |
+
if session is None or not repo_id:
|
| 396 |
+
return False
|
| 397 |
+
repo_type = repo_type or "model"
|
| 398 |
+
if repo_type not in SUPPORTED_REPO_TYPES:
|
| 399 |
+
return False
|
| 400 |
+
|
| 401 |
+
key = _artifact_key(repo_id, repo_type)
|
| 402 |
+
remember_hub_artifact(session, repo_id, repo_type)
|
| 403 |
+
registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR)
|
| 404 |
+
if key in registered and not force:
|
| 405 |
+
return True
|
| 406 |
+
|
| 407 |
+
token_value = token if token is not None else getattr(api, "token", None)
|
| 408 |
+
card_updated = False
|
| 409 |
+
collection_updated = False
|
| 410 |
+
try:
|
| 411 |
+
_update_repo_card(
|
| 412 |
+
api,
|
| 413 |
+
repo_id,
|
| 414 |
+
repo_type,
|
| 415 |
+
token=token_value,
|
| 416 |
+
extra_metadata=extra_metadata,
|
| 417 |
+
)
|
| 418 |
+
card_updated = True
|
| 419 |
+
except Exception as e:
|
| 420 |
+
logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
|
| 421 |
+
|
| 422 |
+
try:
|
| 423 |
+
_add_to_collection(api, session, repo_id, repo_type, token=token_value)
|
| 424 |
+
collection_updated = True
|
| 425 |
+
except Exception as e:
|
| 426 |
+
logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
|
| 427 |
+
|
| 428 |
+
if card_updated and collection_updated:
|
| 429 |
+
registered.add(key)
|
| 430 |
+
return True
|
| 431 |
+
return False
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def build_hub_artifact_sitecustomize(session: Any) -> str:
    """Build standalone sitecustomize.py code for HF Jobs Python processes.

    The returned source embeds the session id, partition date, collection
    title, and any cached collection slug as literals. It defines and calls
    ``_install_ml_intern_artifact_hooks``, which wraps
    ``HfApi.create_repo``/``upload_file`` (and ``upload_folder``/
    ``create_commit`` when present) and the same-named ``huggingface_hub``
    module functions. After the wrapped call returns, the hook updates the
    repo's README metadata and adds the repo to the session collection. All
    hook failures are swallowed in the generated code.

    Returns:
        The generated source ending in a newline, or ``""`` when ``session``
        is ``None`` or has no ``session_id``.
    """
    if session is None or not getattr(session, "session_id", None):
        return ""

    session_id = _safe_session_id(session)
    session_date = session_artifact_date(session)
    collection_title = artifact_collection_title(session)
    # May be None when the collection has not been created yet; the generated
    # code then creates it lazily.
    collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)

    # Inside the f-string, doubled braces are literal braces in the generated
    # code and doubled backslashes are single backslashes.
    return (
        textwrap.dedent(
            f"""
            # Auto-generated by ML Intern. Best-effort Hub artifact metadata only.
            def _install_ml_intern_artifact_hooks():
                import os
                import re
                import tempfile
                from pathlib import Path

                try:
                    import huggingface_hub as _hub
                    from huggingface_hub import HfApi, hf_hub_download
                    from huggingface_hub.repocard import metadata_load, metadata_save
                    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
                except Exception:
                    return

                session_id = {session_id!r}
                session_date = {session_date!r}
                collection_title = {collection_title!r}
                tag = {ML_INTERN_TAG!r}
                marker = {PROVENANCE_MARKER!r}
                supported = {sorted(SUPPORTED_REPO_TYPES)!r}
                registering = False
                collection_slug = {collection_slug!r}
                registered = set()
                usage_re = re.compile(
                    r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b",
                    re.IGNORECASE | re.MULTILINE,
                )
                front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)

                def _token(value=None, api=None):
                    if isinstance(value, str) and value:
                        return value
                    api_token = getattr(api, "token", None)
                    if isinstance(api_token, str) and api_token:
                        return api_token
                    return (
                        os.environ.get("HF_TOKEN")
                        or os.environ.get("HUGGINGFACE_HUB_TOKEN")
                        or None
                    )

                def _merge_tags(metadata):
                    metadata = dict(metadata or {{}})
                    raw_tags = metadata.get("tags")
                    if raw_tags is None:
                        tags = []
                    elif isinstance(raw_tags, str):
                        tags = [raw_tags]
                    elif isinstance(raw_tags, list):
                        tags = [str(item) for item in raw_tags]
                    else:
                        tags = [str(raw_tags)]
                    if tag not in tags:
                        tags.append(tag)
                    metadata["tags"] = tags
                    return metadata

                def _metadata_from_content(content):
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        path = Path(tmp_dir) / "README.md"
                        path.write_text(content or "", encoding="utf-8")
                        return metadata_load(path) or {{}}

                def _content_with_metadata(content, metadata):
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        path = Path(tmp_dir) / "README.md"
                        path.write_text(content or "", encoding="utf-8")
                        metadata_save(path, metadata)
                        return path.read_text(encoding="utf-8")

                def _body_without_metadata(content):
                    return front_matter_re.sub("", content or "", count=1).strip()

                def _append_section(content, section):
                    base = (content or "").rstrip()
                    if base:
                        return base + "\\n\\n" + section.strip() + "\\n"
                    return section.strip() + "\\n"

                def _provenance(repo_type):
                    label = {{"model": "model", "dataset": "dataset"}}.get(
                        repo_type, "Hub"
                    )
                    return (
                        marker
                        + "\\n## Generated by ML Intern\\n\\n"
                        + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n"
                        + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n"
                        + "- Source code: https://github.com/huggingface/ml-intern\\n"
                    )

                def _usage(repo_id, repo_type):
                    if repo_type == "dataset":
                        return (
                            "## Usage\\n\\n"
                            "```python\\n"
                            "from datasets import load_dataset\\n\\n"
                            f"dataset = load_dataset({{repo_id!r}})\\n"
                            "```\\n"
                        )
                    return (
                        "## Usage\\n\\n"
                        "```python\\n"
                        "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n"
                        f"model_id = {{repo_id!r}}\\n"
                        "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n"
                        "model = AutoModelForCausalLM.from_pretrained(model_id)\\n"
                        "```\\n\\n"
                        "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n"
                    )

                def _augment(content, repo_id, repo_type, extra_metadata=None):
                    metadata = _metadata_from_content(content or "")
                    if extra_metadata:
                        metadata = {{**extra_metadata, **metadata}}
                    updated = _content_with_metadata(content or "", _merge_tags(metadata))
                    if not _body_without_metadata(updated):
                        updated = _append_section(updated, f"# {{repo_id}}")
                    if repo_type in {{"model", "dataset"}} and marker not in updated:
                        updated = _append_section(updated, _provenance(repo_type))
                        if not usage_re.search(content or ""):
                            updated = _append_section(updated, _usage(repo_id, repo_type))
                    return updated

                def _readme(api, repo_id, repo_type, token_value):
                    try:
                        path = hf_hub_download(
                            repo_id=repo_id,
                            filename="README.md",
                            repo_type=repo_type,
                            token=token_value,
                        )
                    except (EntryNotFoundError, RepositoryNotFoundError):
                        return ""
                    return Path(path).read_text(encoding="utf-8")

                def _ensure_collection(api, token_value):
                    nonlocal collection_slug
                    if collection_slug:
                        return collection_slug
                    collection = api.create_collection(
                        title=collection_title,
                        description=(
                            f"Artifacts generated by ML Intern session {{session_id}} "
                            f"on {{session_date}}."
                        ),
                        private=True,
                        exists_ok=True,
                        token=token_value,
                    )
                    collection_slug = getattr(collection, "slug", None)
                    return collection_slug

                def _register(
                    repo_id,
                    repo_type="model",
                    token_value=None,
                    extra_metadata=None,
                    force=False,
                ):
                    nonlocal registering
                    if registering or not repo_id:
                        return
                    repo_type = repo_type or "model"
                    if repo_type not in supported:
                        return
                    key = f"{{repo_type}}:{{repo_id}}"
                    if key in registered and not force:
                        return
                    registering = True
                    try:
                        token_value = _token(token_value)
                        api = HfApi(token=token_value)
                        try:
                            current = _readme(api, repo_id, repo_type, token_value)
                            updated = _augment(
                                current, repo_id, repo_type, extra_metadata=extra_metadata
                            )
                            if updated != current:
                                _original_upload_file(
                                    api,
                                    path_or_fileobj=updated.encode("utf-8"),
                                    path_in_repo="README.md",
                                    repo_id=repo_id,
                                    repo_type=repo_type,
                                    token=token_value,
                                    commit_message="Update ML Intern artifact metadata",
                                )
                        except Exception:
                            pass
                        try:
                            slug = _ensure_collection(api, token_value)
                            if slug:
                                api.add_collection_item(
                                    collection_slug=slug,
                                    item_id=repo_id,
                                    item_type=repo_type,
                                    note=(
                                        f"Generated by ML Intern session {{session_id}} "
                                        f"on {{session_date}}."
                                    ),
                                    exists_ok=True,
                                    token=token_value,
                                )
                        except Exception:
                            pass
                        registered.add(key)
                    finally:
                        registering = False

                _original_create_repo = HfApi.create_repo
                _original_upload_file = HfApi.upload_file
                _original_upload_folder = getattr(HfApi, "upload_folder", None)
                _original_create_commit = getattr(HfApi, "create_commit", None)

                def _repo_id(args, kwargs):
                    return kwargs.get("repo_id") or (args[0] if args else None)

                def _repo_type(kwargs):
                    return kwargs.get("repo_type") or "model"

                def _patched_create_repo(self, *args, **kwargs):
                    result = _original_create_repo(self, *args, **kwargs)
                    repo_id = _repo_id(args, kwargs)
                    repo_type = _repo_type(kwargs)
                    extra = None
                    if repo_type == "space" and kwargs.get("space_sdk"):
                        extra = {{"sdk": kwargs.get("space_sdk")}}
                    _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra)
                    return result

                def _patched_upload_file(self, *args, **kwargs):
                    result = _original_upload_file(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        force = kwargs.get("path_in_repo") == "README.md"
                        _register(
                            kwargs.get("repo_id"),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=force,
                        )
                    return result

                def _patched_upload_folder(self, *args, **kwargs):
                    result = _original_upload_folder(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        _register(
                            kwargs.get("repo_id"),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=True,
                        )
                    return result

                def _patched_create_commit(self, *args, **kwargs):
                    result = _original_create_commit(self, *args, **kwargs)
                    if not kwargs.get("create_pr"):
                        _register(
                            _repo_id(args, kwargs),
                            _repo_type(kwargs),
                            _token(kwargs.get("token"), self),
                            force=True,
                        )
                    return result

                HfApi.create_repo = _patched_create_repo
                HfApi.upload_file = _patched_upload_file
                if _original_upload_folder is not None:
                    HfApi.upload_folder = _patched_upload_folder
                if _original_create_commit is not None:
                    HfApi.create_commit = _patched_create_commit

                def _patch_module_func(name, method_name):
                    original = getattr(_hub, name, None)
                    if original is None:
                        return
                    method = getattr(HfApi, method_name)

                    def _patched(*args, **kwargs):
                        api = HfApi(token=_token(kwargs.get("token")))
                        return method(api, *args, **kwargs)

                    setattr(_hub, name, _patched)

                _patch_module_func("create_repo", "create_repo")
                _patch_module_func("upload_file", "upload_file")
                if _original_upload_folder is not None:
                    _patch_module_func("upload_folder", "upload_folder")
                if _original_create_commit is not None:
                    _patch_module_func("create_commit", "create_commit")

            try:
                _install_ml_intern_artifact_hooks()
            except Exception:
                pass
            """
        ).strip()
        + "\n"
    )
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def wrap_shell_command_with_hub_artifact_bootstrap(
|
| 750 |
+
command: str,
|
| 751 |
+
session: Any,
|
| 752 |
+
) -> str:
|
| 753 |
+
"""Prefix a shell command so child Python processes load Hub hooks."""
|
| 754 |
+
sitecustomize = build_hub_artifact_sitecustomize(session)
|
| 755 |
+
if not sitecustomize or not command:
|
| 756 |
+
return command
|
| 757 |
+
|
| 758 |
+
encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
|
| 759 |
+
bootstrap = (
|
| 760 |
+
'_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" '
|
| 761 |
+
f"&& printf %s {shlex.quote(encoded)} | base64 -d "
|
| 762 |
+
'> "$_ml_intern_artifacts_dir/sitecustomize.py" '
|
| 763 |
+
'&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"'
|
| 764 |
+
)
|
| 765 |
+
return f"{bootstrap}; {command}"
|
agent/core/llm_params.py
CHANGED
|
@@ -56,9 +56,16 @@ def _patch_litellm_effort_validation() -> None:
|
|
| 56 |
# to return True for families where "max" / "xhigh" are acceptable
|
| 57 |
# at the API; the cascade handles the case when they're not.
|
| 58 |
return any(
|
| 59 |
-
v in m
|
| 60 |
-
|
| 61 |
-
"opus-4-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
)
|
| 63 |
)
|
| 64 |
|
|
|
|
| 56 |
# to return True for families where "max" / "xhigh" are acceptable
|
| 57 |
# at the API; the cascade handles the case when they're not.
|
| 58 |
return any(
|
| 59 |
+
v in m
|
| 60 |
+
for v in (
|
| 61 |
+
"opus-4-6",
|
| 62 |
+
"opus_4_6",
|
| 63 |
+
"opus-4.6",
|
| 64 |
+
"opus_4.6",
|
| 65 |
+
"opus-4-7",
|
| 66 |
+
"opus_4_7",
|
| 67 |
+
"opus-4.7",
|
| 68 |
+
"opus_4.7",
|
| 69 |
)
|
| 70 |
)
|
| 71 |
|
agent/core/model_switcher.py
CHANGED
|
@@ -28,7 +28,10 @@ SUGGESTED_MODELS = [
|
|
| 28 |
{"id": "openai/gpt-5.4", "label": "GPT-5.4"},
|
| 29 |
{"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
|
| 30 |
{"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
|
| 31 |
-
{
|
|
|
|
|
|
|
|
|
|
| 32 |
{"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
|
| 33 |
{"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
|
| 34 |
{"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
|
|
@@ -122,9 +125,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
|
|
| 122 |
)
|
| 123 |
ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a"
|
| 124 |
tools = "tools" if p.supports_tools else "no tools"
|
| 125 |
-
console.print(
|
| 126 |
-
f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]"
|
| 127 |
-
)
|
| 128 |
return True
|
| 129 |
|
| 130 |
|
|
@@ -183,7 +184,9 @@ async def probe_and_switch_model(
|
|
| 183 |
# Nothing to validate with a ping that we couldn't validate on the
|
| 184 |
# first real call just as cheaply. Skip the probe entirely.
|
| 185 |
_commit_switch(model_id, config, session, effective=None, cache=False)
|
| 186 |
-
console.print(
|
|
|
|
|
|
|
| 187 |
return
|
| 188 |
|
| 189 |
console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]")
|
|
@@ -203,8 +206,11 @@ async def probe_and_switch_model(
|
|
| 203 |
return
|
| 204 |
|
| 205 |
_commit_switch(
|
| 206 |
-
model_id,
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
| 208 |
)
|
| 209 |
effort_label = outcome.effective_effort or "off"
|
| 210 |
suffix = f" — {outcome.note}" if outcome.note else ""
|
|
|
|
| 28 |
{"id": "openai/gpt-5.4", "label": "GPT-5.4"},
|
| 29 |
{"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
|
| 30 |
{"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
|
| 31 |
+
{
|
| 32 |
+
"id": "bedrock/us.anthropic.claude-opus-4-6-v1",
|
| 33 |
+
"label": "Claude Opus 4.6 via Bedrock",
|
| 34 |
+
},
|
| 35 |
{"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
|
| 36 |
{"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
|
| 37 |
{"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
|
|
|
|
| 125 |
)
|
| 126 |
ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a"
|
| 127 |
tools = "tools" if p.supports_tools else "no tools"
|
| 128 |
+
console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]")
|
|
|
|
|
|
|
| 129 |
return True
|
| 130 |
|
| 131 |
|
|
|
|
| 184 |
# Nothing to validate with a ping that we couldn't validate on the
|
| 185 |
# first real call just as cheaply. Skip the probe entirely.
|
| 186 |
_commit_switch(model_id, config, session, effective=None, cache=False)
|
| 187 |
+
console.print(
|
| 188 |
+
f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]"
|
| 189 |
+
)
|
| 190 |
return
|
| 191 |
|
| 192 |
console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]")
|
|
|
|
| 206 |
return
|
| 207 |
|
| 208 |
_commit_switch(
|
| 209 |
+
model_id,
|
| 210 |
+
config,
|
| 211 |
+
session,
|
| 212 |
+
effective=outcome.effective_effort,
|
| 213 |
+
cache=True,
|
| 214 |
)
|
| 215 |
effort_label = outcome.effective_effort or "off"
|
| 216 |
suffix = f" — {outcome.note}" if outcome.note else ""
|
agent/core/prompt_caching.py
CHANGED
|
@@ -40,7 +40,11 @@ def with_prompt_caching(
|
|
| 40 |
|
| 41 |
if messages:
|
| 42 |
first = messages[0]
|
| 43 |
-
role =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
if role == "system":
|
| 45 |
content = (
|
| 46 |
first.get("content")
|
|
@@ -48,11 +52,13 @@ def with_prompt_caching(
|
|
| 48 |
else getattr(first, "content", None)
|
| 49 |
)
|
| 50 |
if isinstance(content, str) and content:
|
| 51 |
-
cached_block = [
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
new_first = {"role": "system", "content": cached_block}
|
| 57 |
messages = [new_first] + list(messages[1:])
|
| 58 |
|
|
|
|
| 40 |
|
| 41 |
if messages:
|
| 42 |
first = messages[0]
|
| 43 |
+
role = (
|
| 44 |
+
first.get("role")
|
| 45 |
+
if isinstance(first, dict)
|
| 46 |
+
else getattr(first, "role", None)
|
| 47 |
+
)
|
| 48 |
if role == "system":
|
| 49 |
content = (
|
| 50 |
first.get("content")
|
|
|
|
| 52 |
else getattr(first, "content", None)
|
| 53 |
)
|
| 54 |
if isinstance(content, str) and content:
|
| 55 |
+
cached_block = [
|
| 56 |
+
{
|
| 57 |
+
"type": "text",
|
| 58 |
+
"text": content,
|
| 59 |
+
"cache_control": {"type": "ephemeral"},
|
| 60 |
+
}
|
| 61 |
+
]
|
| 62 |
new_first = {"role": "system", "content": cached_block}
|
| 63 |
messages = [new_first] + list(messages[1:])
|
| 64 |
|
agent/core/session.py
CHANGED
|
@@ -48,7 +48,8 @@ def _get_max_tokens_safe(model_name: str) -> int:
|
|
| 48 |
continue
|
| 49 |
logger.info(
|
| 50 |
"No litellm.get_model_info entry for %s, falling back to %d",
|
| 51 |
-
model_name,
|
|
|
|
| 52 |
)
|
| 53 |
return _DEFAULT_MAX_TOKENS
|
| 54 |
|
|
@@ -277,8 +278,7 @@ class Session:
|
|
| 277 |
if summary:
|
| 278 |
summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
|
| 279 |
message = (
|
| 280 |
-
f"Session {self.session_id} completed successfully.\n"
|
| 281 |
-
f"{summary}"
|
| 282 |
)
|
| 283 |
else:
|
| 284 |
message = f"Session {self.session_id} completed successfully."
|
|
@@ -444,6 +444,7 @@ class Session:
|
|
| 444 |
# snapshot between heartbeats would otherwise leak them.
|
| 445 |
try:
|
| 446 |
from agent.core.redact import scrub
|
|
|
|
| 447 |
for key in ("messages", "events", "tools"):
|
| 448 |
if key in trajectory:
|
| 449 |
trajectory[key] = scrub(trajectory[key])
|
|
|
|
| 48 |
continue
|
| 49 |
logger.info(
|
| 50 |
"No litellm.get_model_info entry for %s, falling back to %d",
|
| 51 |
+
model_name,
|
| 52 |
+
_DEFAULT_MAX_TOKENS,
|
| 53 |
)
|
| 54 |
return _DEFAULT_MAX_TOKENS
|
| 55 |
|
|
|
|
| 278 |
if summary:
|
| 279 |
summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
|
| 280 |
message = (
|
| 281 |
+
f"Session {self.session_id} completed successfully.\n{summary}"
|
|
|
|
| 282 |
)
|
| 283 |
else:
|
| 284 |
message = f"Session {self.session_id} completed successfully."
|
|
|
|
| 444 |
# snapshot between heartbeats would otherwise leak them.
|
| 445 |
try:
|
| 446 |
from agent.core.redact import scrub
|
| 447 |
+
|
| 448 |
for key in ("messages", "events", "tools"):
|
| 449 |
if key in trajectory:
|
| 450 |
trajectory[key] = scrub(trajectory[key])
|
agent/core/session_persistence.py
CHANGED
|
@@ -271,7 +271,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 271 |
upsert=True,
|
| 272 |
)
|
| 273 |
)
|
| 274 |
-
ops.append(
|
|
|
|
|
|
|
| 275 |
try:
|
| 276 |
if ops:
|
| 277 |
await self.db.session_messages.bulk_write(ops, ordered=False)
|
|
@@ -288,7 +290,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 288 |
return None
|
| 289 |
if meta.get("visibility") == "deleted" and not include_deleted:
|
| 290 |
return None
|
| 291 |
-
cursor = self.db.session_messages.find({"session_id": session_id}).sort(
|
|
|
|
|
|
|
| 292 |
messages = [row.get("message") async for row in cursor]
|
| 293 |
return {"metadata": meta, "messages": messages}
|
| 294 |
|
|
@@ -356,7 +360,9 @@ class MongoSessionStore(NoopSessionStore):
|
|
| 356 |
logger.debug("Failed to append event for %s: %s", session_id, e)
|
| 357 |
return None
|
| 358 |
|
| 359 |
-
async def load_events_after(
|
|
|
|
|
|
|
| 360 |
if not self._ready():
|
| 361 |
return []
|
| 362 |
cursor = self.db.session_events.find(
|
|
@@ -496,6 +502,8 @@ def get_session_store() -> NoopSessionStore | MongoSessionStore:
|
|
| 496 |
return _store
|
| 497 |
|
| 498 |
|
| 499 |
-
def _reset_store_for_tests(
|
|
|
|
|
|
|
| 500 |
global _store
|
| 501 |
_store = store
|
|
|
|
| 271 |
upsert=True,
|
| 272 |
)
|
| 273 |
)
|
| 274 |
+
ops.append(
|
| 275 |
+
DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}})
|
| 276 |
+
)
|
| 277 |
try:
|
| 278 |
if ops:
|
| 279 |
await self.db.session_messages.bulk_write(ops, ordered=False)
|
|
|
|
| 290 |
return None
|
| 291 |
if meta.get("visibility") == "deleted" and not include_deleted:
|
| 292 |
return None
|
| 293 |
+
cursor = self.db.session_messages.find({"session_id": session_id}).sort(
|
| 294 |
+
"idx", 1
|
| 295 |
+
)
|
| 296 |
messages = [row.get("message") async for row in cursor]
|
| 297 |
return {"metadata": meta, "messages": messages}
|
| 298 |
|
|
|
|
| 360 |
logger.debug("Failed to append event for %s: %s", session_id, e)
|
| 361 |
return None
|
| 362 |
|
| 363 |
+
async def load_events_after(
|
| 364 |
+
self, session_id: str, after_seq: int = 0
|
| 365 |
+
) -> list[dict[str, Any]]:
|
| 366 |
if not self._ready():
|
| 367 |
return []
|
| 368 |
cursor = self.db.session_events.find(
|
|
|
|
| 502 |
return _store
|
| 503 |
|
| 504 |
|
| 505 |
+
def _reset_store_for_tests(
|
| 506 |
+
store: NoopSessionStore | MongoSessionStore | None = None,
|
| 507 |
+
) -> None:
|
| 508 |
global _store
|
| 509 |
_store = store
|
agent/core/session_uploader.py
CHANGED
|
@@ -94,8 +94,7 @@ def _msg_uuid(session_id: str, role: str, idx: int) -> str:
|
|
| 94 |
digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest()
|
| 95 |
# Format like a UUID for visual familiarity (32 hex chars w/ dashes).
|
| 96 |
return (
|
| 97 |
-
f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-"
|
| 98 |
-
f"{digest[16:20]}-{digest[20:32]}"
|
| 99 |
)
|
| 100 |
|
| 101 |
|
|
@@ -347,7 +346,7 @@ def _update_upload_status(
|
|
| 347 |
|
| 348 |
def dataset_card_readme(repo_id: str) -> str:
|
| 349 |
"""Dataset card for personal ML Intern session trace repos."""
|
| 350 |
-
return
|
| 351 |
pretty_name: "ML Intern Session Traces"
|
| 352 |
language:
|
| 353 |
- en
|
|
|
|
| 94 |
digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest()
|
| 95 |
# Format like a UUID for visual familiarity (32 hex chars w/ dashes).
|
| 96 |
return (
|
| 97 |
+
f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}"
|
|
|
|
| 98 |
)
|
| 99 |
|
| 100 |
|
|
|
|
| 346 |
|
| 347 |
def dataset_card_readme(repo_id: str) -> str:
|
| 348 |
"""Dataset card for personal ML Intern session trace repos."""
|
| 349 |
+
return """---
|
| 350 |
pretty_name: "ML Intern Session Traces"
|
| 351 |
language:
|
| 352 |
- en
|
agent/core/telemetry.py
CHANGED
|
@@ -26,6 +26,7 @@ logger = logging.getLogger(__name__)
|
|
| 26 |
|
| 27 |
# ── usage extraction ────────────────────────────────────────────────────────
|
| 28 |
|
|
|
|
| 29 |
def extract_usage(response_or_chunk: Any) -> dict:
|
| 30 |
"""Flat usage dict from a litellm response or final-chunk usage object.
|
| 31 |
|
|
@@ -71,6 +72,7 @@ def extract_usage(response_or_chunk: Any) -> dict:
|
|
| 71 |
|
| 72 |
# ── llm_call ────────────────────────────────────────────────────────────────
|
| 73 |
|
|
|
|
| 74 |
async def record_llm_call(
|
| 75 |
session: Any,
|
| 76 |
*,
|
|
@@ -106,22 +108,26 @@ async def record_llm_call(
|
|
| 106 |
if response is not None:
|
| 107 |
try:
|
| 108 |
from litellm import completion_cost
|
|
|
|
| 109 |
cost_usd = float(completion_cost(completion_response=response) or 0.0)
|
| 110 |
except Exception:
|
| 111 |
cost_usd = 0.0
|
| 112 |
from agent.core.session import Event # local import to avoid cycle
|
|
|
|
| 113 |
try:
|
| 114 |
-
await session.send_event(
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
except Exception as e:
|
| 126 |
logger.debug("record_llm_call failed (non-fatal): %s", e)
|
| 127 |
return usage
|
|
@@ -129,6 +135,7 @@ async def record_llm_call(
|
|
| 129 |
|
| 130 |
# ── hf_jobs ────────────────────────────────────────────────────────────────
|
| 131 |
|
|
|
|
| 132 |
def _infer_push_to_hub(script_or_cmd: Any) -> bool:
|
| 133 |
if not isinstance(script_or_cmd, str):
|
| 134 |
return False
|
|
@@ -150,22 +157,25 @@ async def record_hf_job_submit(
|
|
| 150 |
"""Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
|
| 151 |
caller can pass it back into :func:`record_hf_job_complete`."""
|
| 152 |
from agent.core.session import Event
|
|
|
|
| 153 |
t_start = time.monotonic()
|
| 154 |
try:
|
| 155 |
script_text = args.get("script") or args.get("command") or ""
|
| 156 |
-
await session.send_event(
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
except Exception as e:
|
| 170 |
logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
|
| 171 |
return t_start
|
|
@@ -180,23 +190,27 @@ async def record_hf_job_complete(
|
|
| 180 |
submit_ts: float,
|
| 181 |
) -> None:
|
| 182 |
from agent.core.session import Event
|
|
|
|
| 183 |
try:
|
| 184 |
wall_time_s = int(time.monotonic() - submit_ts)
|
| 185 |
-
await session.send_event(
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
except Exception as e:
|
| 195 |
logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
|
| 196 |
|
| 197 |
|
| 198 |
# ── sandbox ────────────────────────────────────���────────────────────────────
|
| 199 |
|
|
|
|
| 200 |
async def record_sandbox_create(
|
| 201 |
session: Any,
|
| 202 |
sandbox: Any,
|
|
@@ -205,39 +219,46 @@ async def record_sandbox_create(
|
|
| 205 |
create_latency_s: int,
|
| 206 |
) -> None:
|
| 207 |
from agent.core.session import Event
|
|
|
|
| 208 |
try:
|
| 209 |
# Pin created-at on the session so record_sandbox_destroy can diff.
|
| 210 |
session._sandbox_created_at = time.monotonic() - create_latency_s
|
| 211 |
-
await session.send_event(
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
| 219 |
except Exception as e:
|
| 220 |
logger.debug("record_sandbox_create failed (non-fatal): %s", e)
|
| 221 |
|
| 222 |
|
| 223 |
async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
|
| 224 |
from agent.core.session import Event
|
|
|
|
| 225 |
try:
|
| 226 |
created = getattr(session, "_sandbox_created_at", None)
|
| 227 |
lifetime_s = int(time.monotonic() - created) if created else None
|
| 228 |
-
await session.send_event(
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
except Exception as e:
|
| 236 |
logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
|
| 237 |
|
| 238 |
|
| 239 |
# ── feedback ───────────────────────────────────────────────────────────────
|
| 240 |
|
|
|
|
| 241 |
async def record_feedback(
|
| 242 |
session: Any,
|
| 243 |
*,
|
|
@@ -247,16 +268,19 @@ async def record_feedback(
|
|
| 247 |
comment: str | None = None,
|
| 248 |
) -> None:
|
| 249 |
from agent.core.session import Event
|
|
|
|
| 250 |
try:
|
| 251 |
-
await session.send_event(
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
|
|
|
| 260 |
except Exception as e:
|
| 261 |
logger.debug("record_feedback failed (non-fatal): %s", e)
|
| 262 |
|
|
@@ -269,15 +293,18 @@ async def record_jobs_access_blocked(
|
|
| 269 |
eligible_namespaces: list[str],
|
| 270 |
) -> None:
|
| 271 |
from agent.core.session import Event
|
|
|
|
| 272 |
try:
|
| 273 |
-
await session.send_event(
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
except Exception as e:
|
| 282 |
logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
|
| 283 |
|
|
@@ -289,11 +316,14 @@ async def record_pro_cta_click(
|
|
| 289 |
target: str = "pro_pricing",
|
| 290 |
) -> None:
|
| 291 |
from agent.core.session import Event
|
|
|
|
| 292 |
try:
|
| 293 |
-
await session.send_event(
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
| 297 |
except Exception as e:
|
| 298 |
logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
|
| 299 |
|
|
@@ -308,11 +338,14 @@ async def record_pro_conversion(
|
|
| 308 |
``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro
|
| 309 |
session so the rollup picks it up alongside other event-driven KPIs."""
|
| 310 |
from agent.core.session import Event
|
|
|
|
| 311 |
try:
|
| 312 |
-
await session.send_event(
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
| 316 |
except Exception as e:
|
| 317 |
logger.debug("record_pro_conversion failed (non-fatal): %s", e)
|
| 318 |
|
|
@@ -327,11 +360,14 @@ async def record_credits_topped_up(
|
|
| 327 |
came back from the HF billing top-up flow and unblocked themselves.
|
| 328 |
Caller is responsible for firing this at most once per session."""
|
| 329 |
from agent.core.session import Event
|
|
|
|
| 330 |
try:
|
| 331 |
-
await session.send_event(
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
except Exception as e:
|
| 336 |
logger.debug("record_credits_topped_up failed (non-fatal): %s", e)
|
| 337 |
|
|
|
|
| 26 |
|
| 27 |
# ── usage extraction ────────────────────────────────────────────────────────
|
| 28 |
|
| 29 |
+
|
| 30 |
def extract_usage(response_or_chunk: Any) -> dict:
|
| 31 |
"""Flat usage dict from a litellm response or final-chunk usage object.
|
| 32 |
|
|
|
|
| 72 |
|
| 73 |
# ── llm_call ────────────────────────────────────────────────────────────────
|
| 74 |
|
| 75 |
+
|
| 76 |
async def record_llm_call(
|
| 77 |
session: Any,
|
| 78 |
*,
|
|
|
|
| 108 |
if response is not None:
|
| 109 |
try:
|
| 110 |
from litellm import completion_cost
|
| 111 |
+
|
| 112 |
cost_usd = float(completion_cost(completion_response=response) or 0.0)
|
| 113 |
except Exception:
|
| 114 |
cost_usd = 0.0
|
| 115 |
from agent.core.session import Event # local import to avoid cycle
|
| 116 |
+
|
| 117 |
try:
|
| 118 |
+
await session.send_event(
|
| 119 |
+
Event(
|
| 120 |
+
event_type="llm_call",
|
| 121 |
+
data={
|
| 122 |
+
"model": model,
|
| 123 |
+
"latency_ms": latency_ms,
|
| 124 |
+
"finish_reason": finish_reason,
|
| 125 |
+
"cost_usd": cost_usd,
|
| 126 |
+
"kind": kind,
|
| 127 |
+
**usage,
|
| 128 |
+
},
|
| 129 |
+
)
|
| 130 |
+
)
|
| 131 |
except Exception as e:
|
| 132 |
logger.debug("record_llm_call failed (non-fatal): %s", e)
|
| 133 |
return usage
|
|
|
|
| 135 |
|
| 136 |
# ── hf_jobs ────────────────────────────────────────────────────────────────
|
| 137 |
|
| 138 |
+
|
| 139 |
def _infer_push_to_hub(script_or_cmd: Any) -> bool:
|
| 140 |
if not isinstance(script_or_cmd, str):
|
| 141 |
return False
|
|
|
|
| 157 |
"""Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
|
| 158 |
caller can pass it back into :func:`record_hf_job_complete`."""
|
| 159 |
from agent.core.session import Event
|
| 160 |
+
|
| 161 |
t_start = time.monotonic()
|
| 162 |
try:
|
| 163 |
script_text = args.get("script") or args.get("command") or ""
|
| 164 |
+
await session.send_event(
|
| 165 |
+
Event(
|
| 166 |
+
event_type="hf_job_submit",
|
| 167 |
+
data={
|
| 168 |
+
"job_id": getattr(job, "id", None),
|
| 169 |
+
"job_url": getattr(job, "url", None),
|
| 170 |
+
"flavor": args.get("hardware_flavor", "cpu-basic"),
|
| 171 |
+
"timeout": args.get("timeout", "30m"),
|
| 172 |
+
"job_type": job_type,
|
| 173 |
+
"image": image,
|
| 174 |
+
"namespace": args.get("namespace"),
|
| 175 |
+
"push_to_hub": _infer_push_to_hub(script_text),
|
| 176 |
+
},
|
| 177 |
+
)
|
| 178 |
+
)
|
| 179 |
except Exception as e:
|
| 180 |
logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
|
| 181 |
return t_start
|
|
|
|
| 190 |
submit_ts: float,
|
| 191 |
) -> None:
|
| 192 |
from agent.core.session import Event
|
| 193 |
+
|
| 194 |
try:
|
| 195 |
wall_time_s = int(time.monotonic() - submit_ts)
|
| 196 |
+
await session.send_event(
|
| 197 |
+
Event(
|
| 198 |
+
event_type="hf_job_complete",
|
| 199 |
+
data={
|
| 200 |
+
"job_id": getattr(job, "id", None),
|
| 201 |
+
"flavor": flavor,
|
| 202 |
+
"final_status": final_status,
|
| 203 |
+
"wall_time_s": wall_time_s,
|
| 204 |
+
},
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
except Exception as e:
|
| 208 |
logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
|
| 209 |
|
| 210 |
|
| 211 |
# ── sandbox ────────────────────────────────────���────────────────────────────
|
| 212 |
|
| 213 |
+
|
| 214 |
async def record_sandbox_create(
|
| 215 |
session: Any,
|
| 216 |
sandbox: Any,
|
|
|
|
| 219 |
create_latency_s: int,
|
| 220 |
) -> None:
|
| 221 |
from agent.core.session import Event
|
| 222 |
+
|
| 223 |
try:
|
| 224 |
# Pin created-at on the session so record_sandbox_destroy can diff.
|
| 225 |
session._sandbox_created_at = time.monotonic() - create_latency_s
|
| 226 |
+
await session.send_event(
|
| 227 |
+
Event(
|
| 228 |
+
event_type="sandbox_create",
|
| 229 |
+
data={
|
| 230 |
+
"sandbox_id": getattr(sandbox, "space_id", None),
|
| 231 |
+
"hardware": hardware,
|
| 232 |
+
"create_latency_s": int(create_latency_s),
|
| 233 |
+
},
|
| 234 |
+
)
|
| 235 |
+
)
|
| 236 |
except Exception as e:
|
| 237 |
logger.debug("record_sandbox_create failed (non-fatal): %s", e)
|
| 238 |
|
| 239 |
|
| 240 |
async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
|
| 241 |
from agent.core.session import Event
|
| 242 |
+
|
| 243 |
try:
|
| 244 |
created = getattr(session, "_sandbox_created_at", None)
|
| 245 |
lifetime_s = int(time.monotonic() - created) if created else None
|
| 246 |
+
await session.send_event(
|
| 247 |
+
Event(
|
| 248 |
+
event_type="sandbox_destroy",
|
| 249 |
+
data={
|
| 250 |
+
"sandbox_id": getattr(sandbox, "space_id", None),
|
| 251 |
+
"lifetime_s": lifetime_s,
|
| 252 |
+
},
|
| 253 |
+
)
|
| 254 |
+
)
|
| 255 |
except Exception as e:
|
| 256 |
logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
|
| 257 |
|
| 258 |
|
| 259 |
# ── feedback ───────────────────────────────────────────────────────────────
|
| 260 |
|
| 261 |
+
|
| 262 |
async def record_feedback(
|
| 263 |
session: Any,
|
| 264 |
*,
|
|
|
|
| 268 |
comment: str | None = None,
|
| 269 |
) -> None:
|
| 270 |
from agent.core.session import Event
|
| 271 |
+
|
| 272 |
try:
|
| 273 |
+
await session.send_event(
|
| 274 |
+
Event(
|
| 275 |
+
event_type="feedback",
|
| 276 |
+
data={
|
| 277 |
+
"rating": rating,
|
| 278 |
+
"turn_index": turn_index,
|
| 279 |
+
"message_id": message_id,
|
| 280 |
+
"comment": (comment or "")[:500],
|
| 281 |
+
},
|
| 282 |
+
)
|
| 283 |
+
)
|
| 284 |
except Exception as e:
|
| 285 |
logger.debug("record_feedback failed (non-fatal): %s", e)
|
| 286 |
|
|
|
|
| 293 |
eligible_namespaces: list[str],
|
| 294 |
) -> None:
|
| 295 |
from agent.core.session import Event
|
| 296 |
+
|
| 297 |
try:
|
| 298 |
+
await session.send_event(
|
| 299 |
+
Event(
|
| 300 |
+
event_type="jobs_access_blocked",
|
| 301 |
+
data={
|
| 302 |
+
"tool_call_ids": tool_call_ids,
|
| 303 |
+
"plan": plan,
|
| 304 |
+
"eligible_namespaces": eligible_namespaces,
|
| 305 |
+
},
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
except Exception as e:
|
| 309 |
logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
|
| 310 |
|
|
|
|
| 316 |
target: str = "pro_pricing",
|
| 317 |
) -> None:
|
| 318 |
from agent.core.session import Event
|
| 319 |
+
|
| 320 |
try:
|
| 321 |
+
await session.send_event(
|
| 322 |
+
Event(
|
| 323 |
+
event_type="pro_cta_click",
|
| 324 |
+
data={"source": source, "target": target},
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
except Exception as e:
|
| 328 |
logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
|
| 329 |
|
|
|
|
| 338 |
``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro
|
| 339 |
session so the rollup picks it up alongside other event-driven KPIs."""
|
| 340 |
from agent.core.session import Event
|
| 341 |
+
|
| 342 |
try:
|
| 343 |
+
await session.send_event(
|
| 344 |
+
Event(
|
| 345 |
+
event_type="pro_conversion",
|
| 346 |
+
data={"first_seen_at": first_seen_at},
|
| 347 |
+
)
|
| 348 |
+
)
|
| 349 |
except Exception as e:
|
| 350 |
logger.debug("record_pro_conversion failed (non-fatal): %s", e)
|
| 351 |
|
|
|
|
| 360 |
came back from the HF billing top-up flow and unblocked themselves.
|
| 361 |
Caller is responsible for firing this at most once per session."""
|
| 362 |
from agent.core.session import Event
|
| 363 |
+
|
| 364 |
try:
|
| 365 |
+
await session.send_event(
|
| 366 |
+
Event(
|
| 367 |
+
event_type="credits_topped_up",
|
| 368 |
+
data={"namespace": namespace},
|
| 369 |
+
)
|
| 370 |
+
)
|
| 371 |
except Exception as e:
|
| 372 |
logger.debug("record_credits_topped_up failed (non-fatal): %s", e)
|
| 373 |
|
agent/core/tools.py
CHANGED
|
@@ -8,8 +8,6 @@ import warnings
|
|
| 8 |
from dataclasses import dataclass
|
| 9 |
from typing import Any, Awaitable, Callable, Optional
|
| 10 |
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
from fastmcp import Client
|
| 14 |
from fastmcp.exceptions import ToolError
|
| 15 |
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
|
@@ -64,6 +62,8 @@ warnings.filterwarnings(
|
|
| 64 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 65 |
)
|
| 66 |
|
|
|
|
|
|
|
| 67 |
NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
|
| 68 |
|
| 69 |
|
|
@@ -131,7 +131,12 @@ class ToolRouter:
|
|
| 131 |
Based on codex-rs/core/src/tools/router.rs
|
| 132 |
"""
|
| 133 |
|
| 134 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
self.tools: dict[str, ToolSpec] = {}
|
| 136 |
self.mcp_servers: dict[str, dict[str, Any]] = {}
|
| 137 |
|
|
@@ -144,7 +149,9 @@ class ToolRouter:
|
|
| 144 |
for name, server in mcp_servers.items():
|
| 145 |
data = server.model_dump()
|
| 146 |
if hf_token:
|
| 147 |
-
data.setdefault("headers", {})["Authorization"] =
|
|
|
|
|
|
|
| 148 |
mcp_servers_payload[name] = data
|
| 149 |
self.mcp_client = Client({"mcpServers": mcp_servers_payload})
|
| 150 |
self._mcp_initialized = False
|
|
@@ -218,7 +225,9 @@ class ToolRouter:
|
|
| 218 |
await self.register_mcp_tools()
|
| 219 |
self._mcp_initialized = True
|
| 220 |
except Exception as e:
|
| 221 |
-
logger.warning(
|
|
|
|
|
|
|
| 222 |
self.mcp_client = None
|
| 223 |
|
| 224 |
await self.register_openapi_tool()
|
|
@@ -380,6 +389,7 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
|
|
| 380 |
# Sandbox or local tools (highest priority)
|
| 381 |
if local_mode:
|
| 382 |
from agent.tools.local_tools import get_local_tools
|
|
|
|
| 383 |
tools = get_local_tools() + tools
|
| 384 |
else:
|
| 385 |
tools = get_sandbox_tools() + tools
|
|
|
|
| 8 |
from dataclasses import dataclass
|
| 9 |
from typing import Any, Awaitable, Callable, Optional
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from fastmcp import Client
|
| 12 |
from fastmcp.exceptions import ToolError
|
| 13 |
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
|
|
|
| 62 |
"ignore", category=DeprecationWarning, module="aiohttp.connector"
|
| 63 |
)
|
| 64 |
|
| 65 |
+
logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
|
| 68 |
|
| 69 |
|
|
|
|
| 131 |
Based on codex-rs/core/src/tools/router.rs
|
| 132 |
"""
|
| 133 |
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
mcp_servers: dict[str, MCPServerConfig],
|
| 137 |
+
hf_token: str | None = None,
|
| 138 |
+
local_mode: bool = False,
|
| 139 |
+
):
|
| 140 |
self.tools: dict[str, ToolSpec] = {}
|
| 141 |
self.mcp_servers: dict[str, dict[str, Any]] = {}
|
| 142 |
|
|
|
|
| 149 |
for name, server in mcp_servers.items():
|
| 150 |
data = server.model_dump()
|
| 151 |
if hf_token:
|
| 152 |
+
data.setdefault("headers", {})["Authorization"] = (
|
| 153 |
+
f"Bearer {hf_token}"
|
| 154 |
+
)
|
| 155 |
mcp_servers_payload[name] = data
|
| 156 |
self.mcp_client = Client({"mcpServers": mcp_servers_payload})
|
| 157 |
self._mcp_initialized = False
|
|
|
|
| 225 |
await self.register_mcp_tools()
|
| 226 |
self._mcp_initialized = True
|
| 227 |
except Exception as e:
|
| 228 |
+
logger.warning(
|
| 229 |
+
"MCP connection failed, continuing without MCP tools: %s", e
|
| 230 |
+
)
|
| 231 |
self.mcp_client = None
|
| 232 |
|
| 233 |
await self.register_openapi_tool()
|
|
|
|
| 389 |
# Sandbox or local tools (highest priority)
|
| 390 |
if local_mode:
|
| 391 |
from agent.tools.local_tools import get_local_tools
|
| 392 |
+
|
| 393 |
tools = get_local_tools() + tools
|
| 394 |
else:
|
| 395 |
tools = get_sandbox_tools() + tools
|
agent/main.py
CHANGED
|
@@ -77,6 +77,7 @@ def _configure_runtime_logging() -> None:
|
|
| 77 |
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
| 78 |
logging.getLogger("litellm").setLevel(logging.ERROR)
|
| 79 |
|
|
|
|
| 80 |
def _safe_get_args(arguments: dict) -> dict:
|
| 81 |
"""Safely extract args dict from arguments, handling cases where LLM passes string."""
|
| 82 |
args = arguments.get("args", {})
|
|
@@ -92,6 +93,7 @@ def _get_hf_user(token: str | None) -> str | None:
|
|
| 92 |
return None
|
| 93 |
try:
|
| 94 |
from huggingface_hub import HfApi
|
|
|
|
| 95 |
return HfApi(token=token).whoami().get("name")
|
| 96 |
except Exception:
|
| 97 |
return None
|
|
@@ -134,10 +136,13 @@ async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
|
|
| 134 |
login(token=token, add_to_git_credential=False)
|
| 135 |
print("Token saved to ~/.cache/huggingface/token")
|
| 136 |
except Exception as e:
|
| 137 |
-
print(
|
|
|
|
|
|
|
| 138 |
|
| 139 |
return token
|
| 140 |
|
|
|
|
| 141 |
@dataclass
|
| 142 |
class Operation:
|
| 143 |
"""Operation to be executed by the agent"""
|
|
@@ -162,9 +167,9 @@ def _create_rich_console():
|
|
| 162 |
class _ThinkingShimmer:
|
| 163 |
"""Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text."""
|
| 164 |
|
| 165 |
-
_BASE = (90, 90, 110)
|
| 166 |
-
_HIGHLIGHT = (255, 200, 80)
|
| 167 |
-
_WIDTH = 5
|
| 168 |
_FPS = 24
|
| 169 |
|
| 170 |
def __init__(self, console):
|
|
@@ -245,7 +250,7 @@ class _StreamBuffer:
|
|
| 245 |
if idx == -1:
|
| 246 |
return None
|
| 247 |
block = self._buffer[:idx]
|
| 248 |
-
self._buffer = self._buffer[idx + 2:]
|
| 249 |
return block
|
| 250 |
|
| 251 |
async def flush_ready(
|
|
@@ -271,7 +276,9 @@ class _StreamBuffer:
|
|
| 271 |
"""Flush complete blocks, then render whatever incomplete tail remains."""
|
| 272 |
await self.flush_ready(cancel_event=cancel_event, instant=instant)
|
| 273 |
if self._buffer.strip():
|
| 274 |
-
await print_markdown(
|
|
|
|
|
|
|
| 275 |
self._buffer = ""
|
| 276 |
|
| 277 |
def discard(self):
|
|
@@ -372,7 +379,11 @@ async def event_listener(
|
|
| 372 |
elif event.event_type == "error":
|
| 373 |
shimmer.stop()
|
| 374 |
stream_buf.discard()
|
| 375 |
-
error =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
print_error(error)
|
| 377 |
turn_complete_event.set()
|
| 378 |
elif event.event_type == "shutdown":
|
|
@@ -392,8 +403,10 @@ async def event_listener(
|
|
| 392 |
|
| 393 |
# If yolo mode is active, auto-approve everything except
|
| 394 |
# scheduled HF jobs, whose recurring cost stays manual.
|
| 395 |
-
if
|
| 396 |
-
|
|
|
|
|
|
|
| 397 |
):
|
| 398 |
approvals = [
|
| 399 |
{
|
|
@@ -637,7 +650,9 @@ async def event_listener(
|
|
| 637 |
f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
|
| 638 |
)
|
| 639 |
except (KeyboardInterrupt, EOFError):
|
| 640 |
-
get_console().print(
|
|
|
|
|
|
|
| 641 |
approvals.append(
|
| 642 |
{
|
| 643 |
"tool_call_id": tool_call_id,
|
|
@@ -770,7 +785,11 @@ async def _handle_slash_command(
|
|
| 770 |
normalized = arg.removeprefix("huggingface/")
|
| 771 |
session = session_holder[0] if session_holder else None
|
| 772 |
await model_switcher.probe_and_switch_model(
|
| 773 |
-
normalized,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
)
|
| 775 |
return None
|
| 776 |
|
|
@@ -965,6 +984,7 @@ async def main(model: str | None = None):
|
|
| 965 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 966 |
# don't block on a network fetch.
|
| 967 |
from agent.core import hf_router_catalog
|
|
|
|
| 968 |
asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm))
|
| 969 |
|
| 970 |
# Create queues for communication
|
|
@@ -1110,7 +1130,11 @@ async def main(model: str | None = None):
|
|
| 1110 |
# Handle slash commands
|
| 1111 |
if user_input.strip().startswith("/"):
|
| 1112 |
sub = await _handle_slash_command(
|
| 1113 |
-
user_input.strip(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
)
|
| 1115 |
if sub is None:
|
| 1116 |
# Command handled locally, loop back for input
|
|
@@ -1176,10 +1200,13 @@ async def headless_main(
|
|
| 1176 |
|
| 1177 |
hf_token = resolve_hf_token()
|
| 1178 |
if not hf_token:
|
| 1179 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 1180 |
sys.exit(1)
|
| 1181 |
|
| 1182 |
-
print(
|
| 1183 |
|
| 1184 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1185 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
|
@@ -1327,26 +1354,35 @@ async def headless_main(
|
|
| 1327 |
for t in tools_data
|
| 1328 |
]
|
| 1329 |
_hl_sub_id[0] += 1
|
| 1330 |
-
await submission_queue.put(
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
|
|
|
|
|
|
| 1337 |
elif event.event_type == "compacted":
|
| 1338 |
old_tokens = event.data.get("old_tokens", 0) if event.data else 0
|
| 1339 |
new_tokens = event.data.get("new_tokens", 0) if event.data else 0
|
| 1340 |
print_compacted(old_tokens, new_tokens)
|
| 1341 |
elif event.event_type == "error":
|
| 1342 |
stream_buf.discard()
|
| 1343 |
-
error =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1344 |
print_error(error)
|
| 1345 |
break
|
| 1346 |
elif event.event_type in ("turn_complete", "interrupted"):
|
| 1347 |
stream_buf.discard()
|
| 1348 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1349 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 1350 |
if event.event_type == "turn_complete":
|
| 1351 |
session = session_holder[0] if session_holder else None
|
| 1352 |
if session is not None:
|
|
@@ -1372,6 +1408,7 @@ def cli():
|
|
| 1372 |
"""Entry point for the ml-intern CLI command."""
|
| 1373 |
import logging as _logging
|
| 1374 |
import warnings
|
|
|
|
| 1375 |
# Suppress aiohttp "Unclosed client session" noise during event loop teardown
|
| 1376 |
_logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
|
| 1377 |
_configure_runtime_logging()
|
|
@@ -1381,12 +1418,23 @@ def cli():
|
|
| 1381 |
warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh")
|
| 1382 |
|
| 1383 |
parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
|
| 1384 |
-
parser.add_argument(
|
| 1385 |
-
|
| 1386 |
-
|
| 1387 |
-
|
| 1388 |
-
|
| 1389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1390 |
args = parser.parse_args()
|
| 1391 |
|
| 1392 |
try:
|
|
@@ -1394,7 +1442,14 @@ def cli():
|
|
| 1394 |
max_iter = args.max_iterations
|
| 1395 |
if max_iter is not None and max_iter < 0:
|
| 1396 |
max_iter = 10_000 # effectively unlimited
|
| 1397 |
-
asyncio.run(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1398 |
else:
|
| 1399 |
asyncio.run(main(model=args.model))
|
| 1400 |
except KeyboardInterrupt:
|
|
|
|
| 77 |
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
| 78 |
logging.getLogger("litellm").setLevel(logging.ERROR)
|
| 79 |
|
| 80 |
+
|
| 81 |
def _safe_get_args(arguments: dict) -> dict:
|
| 82 |
"""Safely extract args dict from arguments, handling cases where LLM passes string."""
|
| 83 |
args = arguments.get("args", {})
|
|
|
|
| 93 |
return None
|
| 94 |
try:
|
| 95 |
from huggingface_hub import HfApi
|
| 96 |
+
|
| 97 |
return HfApi(token=token).whoami().get("name")
|
| 98 |
except Exception:
|
| 99 |
return None
|
|
|
|
| 136 |
login(token=token, add_to_git_credential=False)
|
| 137 |
print("Token saved to ~/.cache/huggingface/token")
|
| 138 |
except Exception as e:
|
| 139 |
+
print(
|
| 140 |
+
f"Warning: could not persist token ({e}), using for this session only."
|
| 141 |
+
)
|
| 142 |
|
| 143 |
return token
|
| 144 |
|
| 145 |
+
|
| 146 |
@dataclass
|
| 147 |
class Operation:
|
| 148 |
"""Operation to be executed by the agent"""
|
|
|
|
| 167 |
class _ThinkingShimmer:
|
| 168 |
"""Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text."""
|
| 169 |
|
| 170 |
+
_BASE = (90, 90, 110) # dim base color
|
| 171 |
+
_HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold)
|
| 172 |
+
_WIDTH = 5 # shimmer width in characters
|
| 173 |
_FPS = 24
|
| 174 |
|
| 175 |
def __init__(self, console):
|
|
|
|
| 250 |
if idx == -1:
|
| 251 |
return None
|
| 252 |
block = self._buffer[:idx]
|
| 253 |
+
self._buffer = self._buffer[idx + 2 :]
|
| 254 |
return block
|
| 255 |
|
| 256 |
async def flush_ready(
|
|
|
|
| 276 |
"""Flush complete blocks, then render whatever incomplete tail remains."""
|
| 277 |
await self.flush_ready(cancel_event=cancel_event, instant=instant)
|
| 278 |
if self._buffer.strip():
|
| 279 |
+
await print_markdown(
|
| 280 |
+
self._buffer, cancel_event=cancel_event, instant=instant
|
| 281 |
+
)
|
| 282 |
self._buffer = ""
|
| 283 |
|
| 284 |
def discard(self):
|
|
|
|
| 379 |
elif event.event_type == "error":
|
| 380 |
shimmer.stop()
|
| 381 |
stream_buf.discard()
|
| 382 |
+
error = (
|
| 383 |
+
event.data.get("error", "Unknown error")
|
| 384 |
+
if event.data
|
| 385 |
+
else "Unknown error"
|
| 386 |
+
)
|
| 387 |
print_error(error)
|
| 388 |
turn_complete_event.set()
|
| 389 |
elif event.event_type == "shutdown":
|
|
|
|
| 403 |
|
| 404 |
# If yolo mode is active, auto-approve everything except
|
| 405 |
# scheduled HF jobs, whose recurring cost stays manual.
|
| 406 |
+
if (
|
| 407 |
+
config
|
| 408 |
+
and config.yolo_mode
|
| 409 |
+
and not any(_is_scheduled_hf_job_tool(t) for t in tools_data)
|
| 410 |
):
|
| 411 |
approvals = [
|
| 412 |
{
|
|
|
|
| 650 |
f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
|
| 651 |
)
|
| 652 |
except (KeyboardInterrupt, EOFError):
|
| 653 |
+
get_console().print(
|
| 654 |
+
"[dim]Approval cancelled — rejecting remaining items[/dim]"
|
| 655 |
+
)
|
| 656 |
approvals.append(
|
| 657 |
{
|
| 658 |
"tool_call_id": tool_call_id,
|
|
|
|
| 785 |
normalized = arg.removeprefix("huggingface/")
|
| 786 |
session = session_holder[0] if session_holder else None
|
| 787 |
await model_switcher.probe_and_switch_model(
|
| 788 |
+
normalized,
|
| 789 |
+
config,
|
| 790 |
+
session,
|
| 791 |
+
console,
|
| 792 |
+
resolve_hf_token(),
|
| 793 |
)
|
| 794 |
return None
|
| 795 |
|
|
|
|
| 984 |
# Pre-warm the HF router catalog in the background so /model switches
|
| 985 |
# don't block on a network fetch.
|
| 986 |
from agent.core import hf_router_catalog
|
| 987 |
+
|
| 988 |
asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm))
|
| 989 |
|
| 990 |
# Create queues for communication
|
|
|
|
| 1130 |
# Handle slash commands
|
| 1131 |
if user_input.strip().startswith("/"):
|
| 1132 |
sub = await _handle_slash_command(
|
| 1133 |
+
user_input.strip(),
|
| 1134 |
+
config,
|
| 1135 |
+
session_holder,
|
| 1136 |
+
submission_queue,
|
| 1137 |
+
submission_id,
|
| 1138 |
)
|
| 1139 |
if sub is None:
|
| 1140 |
# Command handled locally, loop back for input
|
|
|
|
| 1200 |
|
| 1201 |
hf_token = resolve_hf_token()
|
| 1202 |
if not hf_token:
|
| 1203 |
+
print(
|
| 1204 |
+
"ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.",
|
| 1205 |
+
file=sys.stderr,
|
| 1206 |
+
)
|
| 1207 |
sys.exit(1)
|
| 1208 |
|
| 1209 |
+
print("HF token loaded", file=sys.stderr)
|
| 1210 |
|
| 1211 |
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1212 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
|
|
|
| 1354 |
for t in tools_data
|
| 1355 |
]
|
| 1356 |
_hl_sub_id[0] += 1
|
| 1357 |
+
await submission_queue.put(
|
| 1358 |
+
Submission(
|
| 1359 |
+
id=f"hl_approval_{_hl_sub_id[0]}",
|
| 1360 |
+
operation=Operation(
|
| 1361 |
+
op_type=OpType.EXEC_APPROVAL,
|
| 1362 |
+
data={"approvals": approvals},
|
| 1363 |
+
),
|
| 1364 |
+
)
|
| 1365 |
+
)
|
| 1366 |
elif event.event_type == "compacted":
|
| 1367 |
old_tokens = event.data.get("old_tokens", 0) if event.data else 0
|
| 1368 |
new_tokens = event.data.get("new_tokens", 0) if event.data else 0
|
| 1369 |
print_compacted(old_tokens, new_tokens)
|
| 1370 |
elif event.event_type == "error":
|
| 1371 |
stream_buf.discard()
|
| 1372 |
+
error = (
|
| 1373 |
+
event.data.get("error", "Unknown error")
|
| 1374 |
+
if event.data
|
| 1375 |
+
else "Unknown error"
|
| 1376 |
+
)
|
| 1377 |
print_error(error)
|
| 1378 |
break
|
| 1379 |
elif event.event_type in ("turn_complete", "interrupted"):
|
| 1380 |
stream_buf.discard()
|
| 1381 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1382 |
+
print(
|
| 1383 |
+
f"\n--- Agent {event.event_type} (history_size={history_size}) ---",
|
| 1384 |
+
file=sys.stderr,
|
| 1385 |
+
)
|
| 1386 |
if event.event_type == "turn_complete":
|
| 1387 |
session = session_holder[0] if session_holder else None
|
| 1388 |
if session is not None:
|
|
|
|
| 1408 |
"""Entry point for the ml-intern CLI command."""
|
| 1409 |
import logging as _logging
|
| 1410 |
import warnings
|
| 1411 |
+
|
| 1412 |
# Suppress aiohttp "Unclosed client session" noise during event loop teardown
|
| 1413 |
_logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
|
| 1414 |
_configure_runtime_logging()
|
|
|
|
| 1418 |
warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh")
|
| 1419 |
|
| 1420 |
parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
|
| 1421 |
+
parser.add_argument(
|
| 1422 |
+
"prompt", nargs="?", default=None, help="Run headlessly with this prompt"
|
| 1423 |
+
)
|
| 1424 |
+
parser.add_argument(
|
| 1425 |
+
"--model", "-m", default=None, help="Model to use (default: from config)"
|
| 1426 |
+
)
|
| 1427 |
+
parser.add_argument(
|
| 1428 |
+
"--max-iterations",
|
| 1429 |
+
type=int,
|
| 1430 |
+
default=None,
|
| 1431 |
+
help="Max LLM requests per turn (default: 50, use -1 for unlimited)",
|
| 1432 |
+
)
|
| 1433 |
+
parser.add_argument(
|
| 1434 |
+
"--no-stream",
|
| 1435 |
+
action="store_true",
|
| 1436 |
+
help="Disable token streaming (use non-streaming LLM calls)",
|
| 1437 |
+
)
|
| 1438 |
args = parser.parse_args()
|
| 1439 |
|
| 1440 |
try:
|
|
|
|
| 1442 |
max_iter = args.max_iterations
|
| 1443 |
if max_iter is not None and max_iter < 0:
|
| 1444 |
max_iter = 10_000 # effectively unlimited
|
| 1445 |
+
asyncio.run(
|
| 1446 |
+
headless_main(
|
| 1447 |
+
args.prompt,
|
| 1448 |
+
model=args.model,
|
| 1449 |
+
max_iterations=max_iter,
|
| 1450 |
+
stream=not args.no_stream,
|
| 1451 |
+
)
|
| 1452 |
+
)
|
| 1453 |
else:
|
| 1454 |
asyncio.run(main(model=args.model))
|
| 1455 |
except KeyboardInterrupt:
|
agent/messaging/base.py
CHANGED
|
@@ -2,7 +2,11 @@ from abc import ABC, abstractmethod
|
|
| 2 |
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
-
from agent.messaging.models import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class NotificationError(Exception):
|
|
|
|
| 2 |
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
+
from agent.messaging.models import (
|
| 6 |
+
DestinationConfig,
|
| 7 |
+
NotificationRequest,
|
| 8 |
+
NotificationResult,
|
| 9 |
+
)
|
| 10 |
|
| 11 |
|
| 12 |
class NotificationError(Exception):
|
agent/messaging/gateway.py
CHANGED
|
@@ -39,7 +39,9 @@ class NotificationGateway:
|
|
| 39 |
if not self.enabled or self._worker_task is not None:
|
| 40 |
return
|
| 41 |
self._client = httpx.AsyncClient(timeout=10.0)
|
| 42 |
-
self._worker_task = asyncio.create_task(
|
|
|
|
|
|
|
| 43 |
|
| 44 |
async def flush(self) -> None:
|
| 45 |
if not self.enabled:
|
|
@@ -87,7 +89,9 @@ class NotificationGateway:
|
|
| 87 |
provider=destination.provider,
|
| 88 |
error=f"No provider implementation for '{destination.provider}'",
|
| 89 |
)
|
| 90 |
-
return await self._send_with_retries(
|
|
|
|
|
|
|
| 91 |
|
| 92 |
async def send_many(
|
| 93 |
self, requests: Iterable[NotificationRequest]
|
|
@@ -131,7 +135,9 @@ class NotificationGateway:
|
|
| 131 |
try:
|
| 132 |
for attempt in range(len(_RETRY_DELAYS) + 1):
|
| 133 |
try:
|
| 134 |
-
return await provider.send(
|
|
|
|
|
|
|
| 135 |
except RetryableNotificationError as exc:
|
| 136 |
if attempt >= len(_RETRY_DELAYS):
|
| 137 |
return NotificationResult(
|
|
|
|
| 39 |
if not self.enabled or self._worker_task is not None:
|
| 40 |
return
|
| 41 |
self._client = httpx.AsyncClient(timeout=10.0)
|
| 42 |
+
self._worker_task = asyncio.create_task(
|
| 43 |
+
self._worker(), name="notification-gateway"
|
| 44 |
+
)
|
| 45 |
|
| 46 |
async def flush(self) -> None:
|
| 47 |
if not self.enabled:
|
|
|
|
| 89 |
provider=destination.provider,
|
| 90 |
error=f"No provider implementation for '{destination.provider}'",
|
| 91 |
)
|
| 92 |
+
return await self._send_with_retries(
|
| 93 |
+
provider, request.destination, destination, request
|
| 94 |
+
)
|
| 95 |
|
| 96 |
async def send_many(
|
| 97 |
self, requests: Iterable[NotificationRequest]
|
|
|
|
| 135 |
try:
|
| 136 |
for attempt in range(len(_RETRY_DELAYS) + 1):
|
| 137 |
try:
|
| 138 |
+
return await provider.send(
|
| 139 |
+
client, destination_name, destination, request
|
| 140 |
+
)
|
| 141 |
except RetryableNotificationError as exc:
|
| 142 |
if attempt >= len(_RETRY_DELAYS):
|
| 143 |
return NotificationResult(
|
agent/messaging/models.py
CHANGED
|
@@ -55,9 +55,7 @@ class MessagingConfig(BaseModel):
|
|
| 55 |
seen: set[str] = set()
|
| 56 |
for event_type in event_types:
|
| 57 |
if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
|
| 58 |
-
raise ValueError(
|
| 59 |
-
f"unsupported auto event type '{event_type}'"
|
| 60 |
-
)
|
| 61 |
if event_type not in seen:
|
| 62 |
normalized.append(event_type)
|
| 63 |
seen.add(event_type)
|
|
@@ -83,11 +81,7 @@ class MessagingConfig(BaseModel):
|
|
| 83 |
def default_auto_destinations(self) -> list[str]:
|
| 84 |
if not self.enabled:
|
| 85 |
return []
|
| 86 |
-
return [
|
| 87 |
-
name
|
| 88 |
-
for name in self.destinations
|
| 89 |
-
if self.can_auto_send(name)
|
| 90 |
-
]
|
| 91 |
|
| 92 |
|
| 93 |
class NotificationRequest(BaseModel):
|
|
|
|
| 55 |
seen: set[str] = set()
|
| 56 |
for event_type in event_types:
|
| 57 |
if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
|
| 58 |
+
raise ValueError(f"unsupported auto event type '{event_type}'")
|
|
|
|
|
|
|
| 59 |
if event_type not in seen:
|
| 60 |
normalized.append(event_type)
|
| 61 |
seen.add(event_type)
|
|
|
|
| 81 |
def default_auto_destinations(self) -> list[str]:
|
| 82 |
if not self.enabled:
|
| 83 |
return []
|
| 84 |
+
return [name for name in self.destinations if self.can_auto_send(name)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
class NotificationRequest(BaseModel):
|
agent/messaging/slack.py
CHANGED
|
@@ -160,9 +160,7 @@ class SlackProvider(NotificationProvider):
|
|
| 160 |
raise RetryableNotificationError("Slack transport error") from exc
|
| 161 |
|
| 162 |
if response.status_code == 429 or response.status_code >= 500:
|
| 163 |
-
raise RetryableNotificationError(
|
| 164 |
-
f"Slack HTTP {response.status_code}"
|
| 165 |
-
)
|
| 166 |
if response.status_code >= 400:
|
| 167 |
raise NotificationError(f"Slack HTTP {response.status_code}")
|
| 168 |
|
|
|
|
| 160 |
raise RetryableNotificationError("Slack transport error") from exc
|
| 161 |
|
| 162 |
if response.status_code == 429 or response.status_code >= 500:
|
| 163 |
+
raise RetryableNotificationError(f"Slack HTTP {response.status_code}")
|
|
|
|
|
|
|
| 164 |
if response.status_code >= 400:
|
| 165 |
raise NotificationError(f"Slack HTTP {response.status_code}")
|
| 166 |
|
agent/sft/tagger.py
CHANGED
|
@@ -27,19 +27,29 @@ Tags are deduplicated before returning.
|
|
| 27 |
|
| 28 |
from __future__ import annotations
|
| 29 |
|
| 30 |
-
from typing import
|
| 31 |
|
| 32 |
# Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
|
| 33 |
_GPU_FAMILY = {
|
| 34 |
-
"cpu-basic": "none",
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
}
|
| 44 |
|
| 45 |
# Substrings that count a flavor as multi-GPU.
|
|
@@ -48,9 +58,17 @@ _MULTI_GPU_MARKERS = ("x2", "x4", "x8")
|
|
| 48 |
# Tool names that don't touch training/inference or sandbox/jobs. If a session
|
| 49 |
# only used these, we tag it research_only.
|
| 50 |
_RESEARCH_ONLY_TOOLS = {
|
| 51 |
-
"research",
|
| 52 |
-
"
|
| 53 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
}
|
| 55 |
|
| 56 |
# Tool names that signal data manipulation workflows.
|
|
@@ -126,11 +144,22 @@ def _infer_task_tag(
|
|
| 126 |
# hf_jobs at all and a script mentions training APIs.
|
| 127 |
for script in hf_job_submit_scripts:
|
| 128 |
low = script.lower()
|
| 129 |
-
if any(
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
return "training"
|
| 135 |
|
| 136 |
# inference: sessions that use inference tools but never hf_jobs/sandbox
|
|
|
|
| 27 |
|
| 28 |
from __future__ import annotations
|
| 29 |
|
| 30 |
+
from typing import Iterable
|
| 31 |
|
| 32 |
# Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
|
| 33 |
_GPU_FAMILY = {
|
| 34 |
+
"cpu-basic": "none",
|
| 35 |
+
"cpu-upgrade": "none",
|
| 36 |
+
"t4-small": "t4",
|
| 37 |
+
"t4-medium": "t4",
|
| 38 |
+
"l4x1": "l40s",
|
| 39 |
+
"l4x4": "l40s",
|
| 40 |
+
"l40sx1": "l40s",
|
| 41 |
+
"l40sx4": "l40s",
|
| 42 |
+
"l40sx8": "l40s",
|
| 43 |
+
"a10g-small": "a10g",
|
| 44 |
+
"a10g-large": "a10g",
|
| 45 |
+
"a10g-largex2": "a10g",
|
| 46 |
+
"a10g-largex4": "a10g",
|
| 47 |
+
"a100-large": "a100",
|
| 48 |
+
"a100x2": "a100",
|
| 49 |
+
"a100x4": "a100",
|
| 50 |
+
"a100x8": "a100",
|
| 51 |
+
"h100": "h100",
|
| 52 |
+
"h100x8": "h100",
|
| 53 |
}
|
| 54 |
|
| 55 |
# Substrings that count a flavor as multi-GPU.
|
|
|
|
| 58 |
# Tool names that don't touch training/inference or sandbox/jobs. If a session
|
| 59 |
# only used these, we tag it research_only.
|
| 60 |
_RESEARCH_ONLY_TOOLS = {
|
| 61 |
+
"research",
|
| 62 |
+
"github_find_examples",
|
| 63 |
+
"github_read_file",
|
| 64 |
+
"github_list_repos",
|
| 65 |
+
"hf_papers",
|
| 66 |
+
"explore_hf_docs",
|
| 67 |
+
"fetch_hf_docs",
|
| 68 |
+
"hub_repo_details",
|
| 69 |
+
"plan",
|
| 70 |
+
"hf_inspect_dataset",
|
| 71 |
+
"web_search",
|
| 72 |
}
|
| 73 |
|
| 74 |
# Tool names that signal data manipulation workflows.
|
|
|
|
| 144 |
# hf_jobs at all and a script mentions training APIs.
|
| 145 |
for script in hf_job_submit_scripts:
|
| 146 |
low = script.lower()
|
| 147 |
+
if any(
|
| 148 |
+
k in low
|
| 149 |
+
for k in (
|
| 150 |
+
"sftconfig",
|
| 151 |
+
"sfttrainer",
|
| 152 |
+
"trainer(",
|
| 153 |
+
"trainingarguments",
|
| 154 |
+
"grpo",
|
| 155 |
+
"dpo",
|
| 156 |
+
".train(",
|
| 157 |
+
"transformers import",
|
| 158 |
+
"trainer import",
|
| 159 |
+
"fine-tune",
|
| 160 |
+
"finetune",
|
| 161 |
+
)
|
| 162 |
+
):
|
| 163 |
return "training"
|
| 164 |
|
| 165 |
# inference: sessions that use inference tools but never hf_jobs/sandbox
|
agent/tools/dataset_tools.py
CHANGED
|
@@ -423,7 +423,9 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
|
|
| 423 |
}
|
| 424 |
|
| 425 |
|
| 426 |
-
async def hf_inspect_dataset_handler(
|
|
|
|
|
|
|
| 427 |
"""Handler for agent tool router"""
|
| 428 |
try:
|
| 429 |
hf_token = session.hf_token if session else None
|
|
|
|
| 423 |
}
|
| 424 |
|
| 425 |
|
| 426 |
+
async def hf_inspect_dataset_handler(
|
| 427 |
+
arguments: dict[str, Any], session=None
|
| 428 |
+
) -> tuple[str, bool]:
|
| 429 |
"""Handler for agent tool router"""
|
| 430 |
try:
|
| 431 |
hf_token = session.hf_token if session else None
|
agent/tools/edit_utils.py
CHANGED
|
@@ -10,18 +10,18 @@ from __future__ import annotations
|
|
| 10 |
# ── Unicode normalization map ────────────────────────────────────────────
|
| 11 |
|
| 12 |
UNICODE_MAP = {
|
| 13 |
-
"\u2013": "-",
|
| 14 |
-
"\u2014": "-",
|
| 15 |
-
"\u2212": "-",
|
| 16 |
-
"\u2018": "'",
|
| 17 |
-
"\u2019": "'",
|
| 18 |
-
"\u201c": '"',
|
| 19 |
-
"\u201d": '"',
|
| 20 |
-
"\u00a0": " ",
|
| 21 |
-
"\u2003": " ",
|
| 22 |
-
"\u2002": " ",
|
| 23 |
-
"\u200b": "",
|
| 24 |
-
"\ufeff": "",
|
| 25 |
}
|
| 26 |
|
| 27 |
|
|
@@ -59,12 +59,12 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]:
|
|
| 59 |
line_start_map[i] = original byte offset of the start of line i.
|
| 60 |
"""
|
| 61 |
orig_lines = text.split("\n")
|
| 62 |
-
stripped_lines = [strip_fn(
|
| 63 |
return "\n".join(stripped_lines), orig_lines, stripped_lines
|
| 64 |
|
| 65 |
# Pass 2 — right-trim
|
| 66 |
c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip)
|
| 67 |
-
p_rt = "\n".join(
|
| 68 |
idx = c_rt.find(p_rt)
|
| 69 |
if idx != -1:
|
| 70 |
orig_idx = _map_back(idx, c_orig_lines, c_rt_lines)
|
|
@@ -72,7 +72,7 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]:
|
|
| 72 |
|
| 73 |
# Pass 3 — both-sides trim
|
| 74 |
c_st, _, c_st_lines = _build_stripped(content, str.strip)
|
| 75 |
-
p_st = "\n".join(
|
| 76 |
idx = c_st.find(p_st)
|
| 77 |
if idx != -1:
|
| 78 |
orig_idx = _map_back(idx, c_orig_lines, c_st_lines)
|
|
@@ -114,7 +114,9 @@ def _map_back(
|
|
| 114 |
return 0
|
| 115 |
|
| 116 |
|
| 117 |
-
def fuzzy_find_original_match(
|
|
|
|
|
|
|
| 118 |
"""Find the *original* text in content that matches pattern fuzzily.
|
| 119 |
|
| 120 |
Returns (original_matched_text, match_note) or (None, None).
|
|
@@ -224,7 +226,9 @@ def apply_edit(
|
|
| 224 |
return new_content, 1, fuzzy_note
|
| 225 |
|
| 226 |
else:
|
| 227 |
-
raise ValueError(
|
|
|
|
|
|
|
| 228 |
|
| 229 |
|
| 230 |
# ── Syntax validation (Python) ───────────────────────────────────────────
|
|
@@ -255,14 +259,15 @@ def validate_python(content: str, path: str = "") -> list[str]:
|
|
| 255 |
return warnings
|
| 256 |
|
| 257 |
# 2. Training script heuristics
|
| 258 |
-
if any(
|
|
|
|
|
|
|
|
|
|
| 259 |
if "push_to_hub" not in content:
|
| 260 |
warnings.append(
|
| 261 |
"Training script warning: no 'push_to_hub' found — model may be lost when job ends"
|
| 262 |
)
|
| 263 |
if "hub_model_id" not in content:
|
| 264 |
-
warnings.append(
|
| 265 |
-
"Training script warning: no 'hub_model_id' found"
|
| 266 |
-
)
|
| 267 |
|
| 268 |
return warnings
|
|
|
|
| 10 |
# ── Unicode normalization map ────────────────────────────────────────────
|
| 11 |
|
| 12 |
UNICODE_MAP = {
|
| 13 |
+
"\u2013": "-", # en-dash
|
| 14 |
+
"\u2014": "-", # em-dash
|
| 15 |
+
"\u2212": "-", # minus sign
|
| 16 |
+
"\u2018": "'", # left single quote
|
| 17 |
+
"\u2019": "'", # right single quote
|
| 18 |
+
"\u201c": '"', # left double quote
|
| 19 |
+
"\u201d": '"', # right double quote
|
| 20 |
+
"\u00a0": " ", # non-breaking space
|
| 21 |
+
"\u2003": " ", # em space
|
| 22 |
+
"\u2002": " ", # en space
|
| 23 |
+
"\u200b": "", # zero-width space
|
| 24 |
+
"\ufeff": "", # BOM
|
| 25 |
}
|
| 26 |
|
| 27 |
|
|
|
|
| 59 |
line_start_map[i] = original byte offset of the start of line i.
|
| 60 |
"""
|
| 61 |
orig_lines = text.split("\n")
|
| 62 |
+
stripped_lines = [strip_fn(line) for line in orig_lines]
|
| 63 |
return "\n".join(stripped_lines), orig_lines, stripped_lines
|
| 64 |
|
| 65 |
# Pass 2 — right-trim
|
| 66 |
c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip)
|
| 67 |
+
p_rt = "\n".join(line.rstrip() for line in pattern.split("\n"))
|
| 68 |
idx = c_rt.find(p_rt)
|
| 69 |
if idx != -1:
|
| 70 |
orig_idx = _map_back(idx, c_orig_lines, c_rt_lines)
|
|
|
|
| 72 |
|
| 73 |
# Pass 3 — both-sides trim
|
| 74 |
c_st, _, c_st_lines = _build_stripped(content, str.strip)
|
| 75 |
+
p_st = "\n".join(line.strip() for line in pattern.split("\n"))
|
| 76 |
idx = c_st.find(p_st)
|
| 77 |
if idx != -1:
|
| 78 |
orig_idx = _map_back(idx, c_orig_lines, c_st_lines)
|
|
|
|
| 114 |
return 0
|
| 115 |
|
| 116 |
|
| 117 |
+
def fuzzy_find_original_match(
|
| 118 |
+
content: str, pattern: str
|
| 119 |
+
) -> tuple[str | None, str | None]:
|
| 120 |
"""Find the *original* text in content that matches pattern fuzzily.
|
| 121 |
|
| 122 |
Returns (original_matched_text, match_note) or (None, None).
|
|
|
|
| 226 |
return new_content, 1, fuzzy_note
|
| 227 |
|
| 228 |
else:
|
| 229 |
+
raise ValueError(
|
| 230 |
+
f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before."
|
| 231 |
+
)
|
| 232 |
|
| 233 |
|
| 234 |
# ── Syntax validation (Python) ───────────────────────────────────────────
|
|
|
|
| 259 |
return warnings
|
| 260 |
|
| 261 |
# 2. Training script heuristics
|
| 262 |
+
if any(
|
| 263 |
+
kw in content
|
| 264 |
+
for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")
|
| 265 |
+
):
|
| 266 |
if "push_to_hub" not in content:
|
| 267 |
warnings.append(
|
| 268 |
"Training script warning: no 'push_to_hub' found — model may be lost when job ends"
|
| 269 |
)
|
| 270 |
if "hub_model_id" not in content:
|
| 271 |
+
warnings.append("Training script warning: no 'hub_model_id' found")
|
|
|
|
|
|
|
| 272 |
|
| 273 |
return warnings
|
agent/tools/hf_repo_files_tool.py
CHANGED
|
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
|
|
| 10 |
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
|
|
|
|
| 13 |
from agent.tools.types import ToolResult
|
| 14 |
|
| 15 |
OperationType = Literal["list", "read", "upload", "delete"]
|
|
@@ -39,8 +40,9 @@ def _format_size(size_bytes: int) -> str:
|
|
| 39 |
class HfRepoFilesTool:
|
| 40 |
"""Tool for file operations on HF repos."""
|
| 41 |
|
| 42 |
-
def __init__(self, hf_token: Optional[str] = None):
|
| 43 |
self.api = HfApi(token=hf_token)
|
|
|
|
| 44 |
|
| 45 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 46 |
"""Execute the specified operation."""
|
|
@@ -61,7 +63,9 @@ class HfRepoFilesTool:
|
|
| 61 |
if handler:
|
| 62 |
return await handler(args)
|
| 63 |
else:
|
| 64 |
-
return self._error(
|
|
|
|
|
|
|
| 65 |
|
| 66 |
except RepositoryNotFoundError:
|
| 67 |
return self._error(f"Repository not found: {args.get('repo_id')}")
|
|
@@ -96,17 +100,23 @@ class HfRepoFilesTool:
|
|
| 96 |
revision = args.get("revision", "main")
|
| 97 |
path = args.get("path", "")
|
| 98 |
|
| 99 |
-
items = list(
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if not items:
|
| 109 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
lines = []
|
| 112 |
total_size = 0
|
|
@@ -118,9 +128,16 @@ class HfRepoFilesTool:
|
|
| 118 |
lines.append(f"{item.path}/")
|
| 119 |
|
| 120 |
url = _build_repo_url(repo_id, repo_type)
|
| 121 |
-
response =
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
async def _read(self, args: Dict[str, Any]) -> ToolResult:
|
| 126 |
"""Read file content from a repository."""
|
|
@@ -160,8 +177,13 @@ class HfRepoFilesTool:
|
|
| 160 |
|
| 161 |
except UnicodeDecodeError:
|
| 162 |
import os
|
|
|
|
| 163 |
size = os.path.getsize(file_path)
|
| 164 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
async def _upload(self, args: Dict[str, Any]) -> ToolResult:
|
| 167 |
"""Upload content to a repository."""
|
|
@@ -194,6 +216,16 @@ class HfRepoFilesTool:
|
|
| 194 |
create_pr=create_pr,
|
| 195 |
)
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
url = _build_repo_url(repo_id, repo_type)
|
| 198 |
if create_pr and hasattr(result, "pr_url"):
|
| 199 |
response = f"**Uploaded as PR**\n{result.pr_url}"
|
|
@@ -235,7 +267,12 @@ class HfRepoFilesTool:
|
|
| 235 |
|
| 236 |
def _error(self, message: str) -> ToolResult:
|
| 237 |
"""Return an error result."""
|
| 238 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
# Tool specification
|
|
@@ -312,11 +349,13 @@ HF_REPO_FILES_TOOL_SPEC = {
|
|
| 312 |
}
|
| 313 |
|
| 314 |
|
| 315 |
-
async def hf_repo_files_handler(
|
|
|
|
|
|
|
| 316 |
"""Handler for agent tool router."""
|
| 317 |
try:
|
| 318 |
hf_token = session.hf_token if session else None
|
| 319 |
-
tool = HfRepoFilesTool(hf_token=hf_token)
|
| 320 |
result = await tool.execute(arguments)
|
| 321 |
return result["formatted"], not result.get("isError", False)
|
| 322 |
except Exception as e:
|
|
|
|
| 10 |
from huggingface_hub import HfApi, hf_hub_download
|
| 11 |
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 12 |
|
| 13 |
+
from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact
|
| 14 |
from agent.tools.types import ToolResult
|
| 15 |
|
| 16 |
OperationType = Literal["list", "read", "upload", "delete"]
|
|
|
|
| 40 |
class HfRepoFilesTool:
|
| 41 |
"""Tool for file operations on HF repos."""
|
| 42 |
|
| 43 |
+
def __init__(self, hf_token: Optional[str] = None, session: Any = None):
|
| 44 |
self.api = HfApi(token=hf_token)
|
| 45 |
+
self.session = session
|
| 46 |
|
| 47 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 48 |
"""Execute the specified operation."""
|
|
|
|
| 63 |
if handler:
|
| 64 |
return await handler(args)
|
| 65 |
else:
|
| 66 |
+
return self._error(
|
| 67 |
+
f"Unknown operation: {operation}. Valid: list, read, upload, delete"
|
| 68 |
+
)
|
| 69 |
|
| 70 |
except RepositoryNotFoundError:
|
| 71 |
return self._error(f"Repository not found: {args.get('repo_id')}")
|
|
|
|
| 100 |
revision = args.get("revision", "main")
|
| 101 |
path = args.get("path", "")
|
| 102 |
|
| 103 |
+
items = list(
|
| 104 |
+
await _async_call(
|
| 105 |
+
self.api.list_repo_tree,
|
| 106 |
+
repo_id=repo_id,
|
| 107 |
+
repo_type=repo_type,
|
| 108 |
+
revision=revision,
|
| 109 |
+
path_in_repo=path,
|
| 110 |
+
recursive=True,
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
|
| 114 |
if not items:
|
| 115 |
+
return {
|
| 116 |
+
"formatted": f"No files in {repo_id}",
|
| 117 |
+
"totalResults": 0,
|
| 118 |
+
"resultsShared": 0,
|
| 119 |
+
}
|
| 120 |
|
| 121 |
lines = []
|
| 122 |
total_size = 0
|
|
|
|
| 128 |
lines.append(f"{item.path}/")
|
| 129 |
|
| 130 |
url = _build_repo_url(repo_id, repo_type)
|
| 131 |
+
response = (
|
| 132 |
+
f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n"
|
| 133 |
+
+ "\n".join(lines)
|
| 134 |
+
)
|
| 135 |
|
| 136 |
+
return {
|
| 137 |
+
"formatted": response,
|
| 138 |
+
"totalResults": len(items),
|
| 139 |
+
"resultsShared": len(items),
|
| 140 |
+
}
|
| 141 |
|
| 142 |
async def _read(self, args: Dict[str, Any]) -> ToolResult:
|
| 143 |
"""Read file content from a repository."""
|
|
|
|
| 177 |
|
| 178 |
except UnicodeDecodeError:
|
| 179 |
import os
|
| 180 |
+
|
| 181 |
size = os.path.getsize(file_path)
|
| 182 |
+
return {
|
| 183 |
+
"formatted": f"Binary file ({_format_size(size)})",
|
| 184 |
+
"totalResults": 1,
|
| 185 |
+
"resultsShared": 1,
|
| 186 |
+
}
|
| 187 |
|
| 188 |
async def _upload(self, args: Dict[str, Any]) -> ToolResult:
|
| 189 |
"""Upload content to a repository."""
|
|
|
|
| 216 |
create_pr=create_pr,
|
| 217 |
)
|
| 218 |
|
| 219 |
+
if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type):
|
| 220 |
+
await _async_call(
|
| 221 |
+
register_hub_artifact,
|
| 222 |
+
self.api,
|
| 223 |
+
repo_id,
|
| 224 |
+
repo_type,
|
| 225 |
+
session=self.session,
|
| 226 |
+
force=path == "README.md",
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
url = _build_repo_url(repo_id, repo_type)
|
| 230 |
if create_pr and hasattr(result, "pr_url"):
|
| 231 |
response = f"**Uploaded as PR**\n{result.pr_url}"
|
|
|
|
| 267 |
|
| 268 |
def _error(self, message: str) -> ToolResult:
|
| 269 |
"""Return an error result."""
|
| 270 |
+
return {
|
| 271 |
+
"formatted": message,
|
| 272 |
+
"totalResults": 0,
|
| 273 |
+
"resultsShared": 0,
|
| 274 |
+
"isError": True,
|
| 275 |
+
}
|
| 276 |
|
| 277 |
|
| 278 |
# Tool specification
|
|
|
|
| 349 |
}
|
| 350 |
|
| 351 |
|
| 352 |
+
async def hf_repo_files_handler(
|
| 353 |
+
arguments: Dict[str, Any], session=None
|
| 354 |
+
) -> tuple[str, bool]:
|
| 355 |
"""Handler for agent tool router."""
|
| 356 |
try:
|
| 357 |
hf_token = session.hf_token if session else None
|
| 358 |
+
tool = HfRepoFilesTool(hf_token=hf_token, session=session)
|
| 359 |
result = await tool.execute(arguments)
|
| 360 |
return result["formatted"], not result.get("isError", False)
|
| 361 |
except Exception as e:
|
agent/tools/hf_repo_git_tool.py
CHANGED
|
@@ -10,14 +10,24 @@ from typing import Any, Dict, Literal, Optional
|
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
|
|
|
|
| 13 |
from agent.tools.types import ToolResult
|
| 14 |
|
| 15 |
OperationType = Literal[
|
| 16 |
-
"create_branch",
|
| 17 |
-
"
|
|
|
|
|
|
|
| 18 |
"list_refs",
|
| 19 |
-
"create_pr",
|
| 20 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
]
|
| 22 |
|
| 23 |
|
|
@@ -36,8 +46,9 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
|
|
| 36 |
class HfRepoGitTool:
|
| 37 |
"""Tool for git-like operations on HF repos."""
|
| 38 |
|
| 39 |
-
def __init__(self, hf_token: Optional[str] = None):
|
| 40 |
self.api = HfApi(token=hf_token)
|
|
|
|
| 41 |
|
| 42 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 43 |
"""Execute the specified operation."""
|
|
@@ -131,7 +142,11 @@ class HfRepoGitTool:
|
|
| 131 |
)
|
| 132 |
|
| 133 |
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
|
| 134 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
|
| 137 |
"""Delete a branch."""
|
|
@@ -152,7 +167,11 @@ class HfRepoGitTool:
|
|
| 152 |
repo_type=repo_type,
|
| 153 |
)
|
| 154 |
|
| 155 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
# =========================================================================
|
| 158 |
# TAG OPERATIONS
|
|
@@ -183,7 +202,11 @@ class HfRepoGitTool:
|
|
| 183 |
)
|
| 184 |
|
| 185 |
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
|
| 186 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
|
| 189 |
"""Delete a tag."""
|
|
@@ -204,7 +227,11 @@ class HfRepoGitTool:
|
|
| 204 |
repo_type=repo_type,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
# =========================================================================
|
| 210 |
# LIST REFS
|
|
@@ -226,7 +253,9 @@ class HfRepoGitTool:
|
|
| 226 |
)
|
| 227 |
|
| 228 |
branches = [b.name for b in refs.branches] if refs.branches else []
|
| 229 |
-
tags =
|
|
|
|
|
|
|
| 230 |
|
| 231 |
url = _build_repo_url(repo_id, repo_type)
|
| 232 |
lines = [f"**{repo_id}**", url, ""]
|
|
@@ -241,7 +270,11 @@ class HfRepoGitTool:
|
|
| 241 |
else:
|
| 242 |
lines.append("**Tags:** none")
|
| 243 |
|
| 244 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# =========================================================================
|
| 247 |
# PR OPERATIONS
|
|
@@ -270,7 +303,7 @@ class HfRepoGitTool:
|
|
| 270 |
|
| 271 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
|
| 272 |
return {
|
| 273 |
-
"formatted": f
|
| 274 |
"totalResults": 1,
|
| 275 |
"resultsShared": 1,
|
| 276 |
}
|
|
@@ -285,17 +318,27 @@ class HfRepoGitTool:
|
|
| 285 |
repo_type = args.get("repo_type", "model")
|
| 286 |
status = args.get("status", "all") # open, closed, all
|
| 287 |
|
| 288 |
-
discussions = list(
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
| 293 |
|
| 294 |
if not discussions:
|
| 295 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
url = _build_repo_url(repo_id, repo_type)
|
| 298 |
-
lines = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
for d in discussions[:20]:
|
| 301 |
if d.status == "draft":
|
|
@@ -309,7 +352,11 @@ class HfRepoGitTool:
|
|
| 309 |
type_label = "PR" if d.is_pull_request else "D"
|
| 310 |
lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
|
| 311 |
|
| 312 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 315 |
"""Get PR details."""
|
|
@@ -335,7 +382,7 @@ class HfRepoGitTool:
|
|
| 335 |
"draft": "Draft",
|
| 336 |
"open": "Open",
|
| 337 |
"merged": "Merged",
|
| 338 |
-
"closed": "Closed"
|
| 339 |
}
|
| 340 |
status = status_map.get(pr.status, pr.status.capitalize())
|
| 341 |
type_label = "Pull Request" if pr.is_pull_request else "Discussion"
|
|
@@ -349,9 +396,13 @@ class HfRepoGitTool:
|
|
| 349 |
|
| 350 |
if pr.is_pull_request:
|
| 351 |
if pr.status == "draft":
|
| 352 |
-
lines.append(
|
|
|
|
|
|
|
| 353 |
elif pr.status == "open":
|
| 354 |
-
lines.append(
|
|
|
|
|
|
|
| 355 |
|
| 356 |
return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
|
| 357 |
|
|
@@ -377,7 +428,11 @@ class HfRepoGitTool:
|
|
| 377 |
)
|
| 378 |
|
| 379 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 380 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 383 |
"""Close a PR/discussion."""
|
|
@@ -401,7 +456,11 @@ class HfRepoGitTool:
|
|
| 401 |
repo_type=repo_type,
|
| 402 |
)
|
| 403 |
|
| 404 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 407 |
"""Add a comment to a PR/discussion."""
|
|
@@ -427,7 +486,11 @@ class HfRepoGitTool:
|
|
| 427 |
)
|
| 428 |
|
| 429 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 430 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
|
| 432 |
async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
|
| 433 |
"""Change PR/discussion status (mainly to convert draft to open)."""
|
|
@@ -455,7 +518,11 @@ class HfRepoGitTool:
|
|
| 455 |
)
|
| 456 |
|
| 457 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 458 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
|
| 460 |
# =========================================================================
|
| 461 |
# REPO MANAGEMENT
|
|
@@ -473,7 +540,9 @@ class HfRepoGitTool:
|
|
| 473 |
space_sdk = args.get("space_sdk")
|
| 474 |
|
| 475 |
if repo_type == "space" and not space_sdk:
|
| 476 |
-
return self._error(
|
|
|
|
|
|
|
| 477 |
|
| 478 |
kwargs = {
|
| 479 |
"repo_id": repo_id,
|
|
@@ -485,6 +554,17 @@ class HfRepoGitTool:
|
|
| 485 |
kwargs["space_sdk"] = space_sdk
|
| 486 |
|
| 487 |
result = await _async_call(self.api.create_repo, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
|
| 489 |
return {
|
| 490 |
"formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
|
|
@@ -504,7 +584,9 @@ class HfRepoGitTool:
|
|
| 504 |
gated = args.get("gated")
|
| 505 |
|
| 506 |
if private is None and gated is None:
|
| 507 |
-
return self._error(
|
|
|
|
|
|
|
| 508 |
|
| 509 |
kwargs = {"repo_id": repo_id, "repo_type": repo_type}
|
| 510 |
if private is not None:
|
|
@@ -521,11 +603,20 @@ class HfRepoGitTool:
|
|
| 521 |
changes.append(f"gated={gated}")
|
| 522 |
|
| 523 |
url = f"{_build_repo_url(repo_id, repo_type)}/settings"
|
| 524 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
|
| 526 |
def _error(self, message: str) -> ToolResult:
|
| 527 |
"""Return an error result."""
|
| 528 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
|
| 530 |
|
| 531 |
# Tool specification
|
|
@@ -571,10 +662,20 @@ HF_REPO_GIT_TOOL_SPEC = {
|
|
| 571 |
"operation": {
|
| 572 |
"type": "string",
|
| 573 |
"enum": [
|
| 574 |
-
"create_branch",
|
| 575 |
-
"
|
| 576 |
-
"
|
| 577 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
],
|
| 579 |
"description": "Operation to execute",
|
| 580 |
},
|
|
@@ -653,11 +754,13 @@ HF_REPO_GIT_TOOL_SPEC = {
|
|
| 653 |
}
|
| 654 |
|
| 655 |
|
| 656 |
-
async def hf_repo_git_handler(
|
|
|
|
|
|
|
| 657 |
"""Handler for agent tool router."""
|
| 658 |
try:
|
| 659 |
hf_token = session.hf_token if session else None
|
| 660 |
-
tool = HfRepoGitTool(hf_token=hf_token)
|
| 661 |
result = await tool.execute(arguments)
|
| 662 |
return result["formatted"], not result.get("isError", False)
|
| 663 |
except Exception as e:
|
|
|
|
| 10 |
from huggingface_hub import HfApi
|
| 11 |
from huggingface_hub.utils import RepositoryNotFoundError
|
| 12 |
|
| 13 |
+
from agent.core.hub_artifacts import register_hub_artifact
|
| 14 |
from agent.tools.types import ToolResult
|
| 15 |
|
| 16 |
OperationType = Literal[
|
| 17 |
+
"create_branch",
|
| 18 |
+
"delete_branch",
|
| 19 |
+
"create_tag",
|
| 20 |
+
"delete_tag",
|
| 21 |
"list_refs",
|
| 22 |
+
"create_pr",
|
| 23 |
+
"list_prs",
|
| 24 |
+
"get_pr",
|
| 25 |
+
"merge_pr",
|
| 26 |
+
"close_pr",
|
| 27 |
+
"comment_pr",
|
| 28 |
+
"change_pr_status",
|
| 29 |
+
"create_repo",
|
| 30 |
+
"update_repo",
|
| 31 |
]
|
| 32 |
|
| 33 |
|
|
|
|
| 46 |
class HfRepoGitTool:
|
| 47 |
"""Tool for git-like operations on HF repos."""
|
| 48 |
|
| 49 |
+
def __init__(self, hf_token: Optional[str] = None, session: Any = None):
|
| 50 |
self.api = HfApi(token=hf_token)
|
| 51 |
+
self.session = session
|
| 52 |
|
| 53 |
async def execute(self, args: Dict[str, Any]) -> ToolResult:
|
| 54 |
"""Execute the specified operation."""
|
|
|
|
| 142 |
)
|
| 143 |
|
| 144 |
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
|
| 145 |
+
return {
|
| 146 |
+
"formatted": f"**Branch created:** {branch}\n{url}",
|
| 147 |
+
"totalResults": 1,
|
| 148 |
+
"resultsShared": 1,
|
| 149 |
+
}
|
| 150 |
|
| 151 |
async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
|
| 152 |
"""Delete a branch."""
|
|
|
|
| 167 |
repo_type=repo_type,
|
| 168 |
)
|
| 169 |
|
| 170 |
+
return {
|
| 171 |
+
"formatted": f"**Branch deleted:** {branch}",
|
| 172 |
+
"totalResults": 1,
|
| 173 |
+
"resultsShared": 1,
|
| 174 |
+
}
|
| 175 |
|
| 176 |
# =========================================================================
|
| 177 |
# TAG OPERATIONS
|
|
|
|
| 202 |
)
|
| 203 |
|
| 204 |
url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
|
| 205 |
+
return {
|
| 206 |
+
"formatted": f"**Tag created:** {tag}\n{url}",
|
| 207 |
+
"totalResults": 1,
|
| 208 |
+
"resultsShared": 1,
|
| 209 |
+
}
|
| 210 |
|
| 211 |
async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
|
| 212 |
"""Delete a tag."""
|
|
|
|
| 227 |
repo_type=repo_type,
|
| 228 |
)
|
| 229 |
|
| 230 |
+
return {
|
| 231 |
+
"formatted": f"**Tag deleted:** {tag}",
|
| 232 |
+
"totalResults": 1,
|
| 233 |
+
"resultsShared": 1,
|
| 234 |
+
}
|
| 235 |
|
| 236 |
# =========================================================================
|
| 237 |
# LIST REFS
|
|
|
|
| 253 |
)
|
| 254 |
|
| 255 |
branches = [b.name for b in refs.branches] if refs.branches else []
|
| 256 |
+
tags = (
|
| 257 |
+
[t.name for t in refs.tags] if hasattr(refs, "tags") and refs.tags else []
|
| 258 |
+
)
|
| 259 |
|
| 260 |
url = _build_repo_url(repo_id, repo_type)
|
| 261 |
lines = [f"**{repo_id}**", url, ""]
|
|
|
|
| 270 |
else:
|
| 271 |
lines.append("**Tags:** none")
|
| 272 |
|
| 273 |
+
return {
|
| 274 |
+
"formatted": "\n".join(lines),
|
| 275 |
+
"totalResults": len(branches) + len(tags),
|
| 276 |
+
"resultsShared": len(branches) + len(tags),
|
| 277 |
+
}
|
| 278 |
|
| 279 |
# =========================================================================
|
| 280 |
# PR OPERATIONS
|
|
|
|
| 303 |
|
| 304 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
|
| 305 |
return {
|
| 306 |
+
"formatted": f'**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision="refs/pr/{result.num}"',
|
| 307 |
"totalResults": 1,
|
| 308 |
"resultsShared": 1,
|
| 309 |
}
|
|
|
|
| 318 |
repo_type = args.get("repo_type", "model")
|
| 319 |
status = args.get("status", "all") # open, closed, all
|
| 320 |
|
| 321 |
+
discussions = list(
|
| 322 |
+
self.api.get_repo_discussions(
|
| 323 |
+
repo_id=repo_id,
|
| 324 |
+
repo_type=repo_type,
|
| 325 |
+
discussion_status=status if status != "all" else None,
|
| 326 |
+
)
|
| 327 |
+
)
|
| 328 |
|
| 329 |
if not discussions:
|
| 330 |
+
return {
|
| 331 |
+
"formatted": f"No discussions in {repo_id}",
|
| 332 |
+
"totalResults": 0,
|
| 333 |
+
"resultsShared": 0,
|
| 334 |
+
}
|
| 335 |
|
| 336 |
url = _build_repo_url(repo_id, repo_type)
|
| 337 |
+
lines = [
|
| 338 |
+
f"**{repo_id}** - {len(discussions)} discussions",
|
| 339 |
+
f"{url}/discussions",
|
| 340 |
+
"",
|
| 341 |
+
]
|
| 342 |
|
| 343 |
for d in discussions[:20]:
|
| 344 |
if d.status == "draft":
|
|
|
|
| 352 |
type_label = "PR" if d.is_pull_request else "D"
|
| 353 |
lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
|
| 354 |
|
| 355 |
+
return {
|
| 356 |
+
"formatted": "\n".join(lines),
|
| 357 |
+
"totalResults": len(discussions),
|
| 358 |
+
"resultsShared": min(20, len(discussions)),
|
| 359 |
+
}
|
| 360 |
|
| 361 |
async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 362 |
"""Get PR details."""
|
|
|
|
| 382 |
"draft": "Draft",
|
| 383 |
"open": "Open",
|
| 384 |
"merged": "Merged",
|
| 385 |
+
"closed": "Closed",
|
| 386 |
}
|
| 387 |
status = status_map.get(pr.status, pr.status.capitalize())
|
| 388 |
type_label = "Pull Request" if pr.is_pull_request else "Discussion"
|
|
|
|
| 396 |
|
| 397 |
if pr.is_pull_request:
|
| 398 |
if pr.status == "draft":
|
| 399 |
+
lines.append(
|
| 400 |
+
f'\nTo add commits: upload with revision="refs/pr/{pr_num}"'
|
| 401 |
+
)
|
| 402 |
elif pr.status == "open":
|
| 403 |
+
lines.append(
|
| 404 |
+
f'\nTo add commits: upload with revision="refs/pr/{pr_num}"'
|
| 405 |
+
)
|
| 406 |
|
| 407 |
return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
|
| 408 |
|
|
|
|
| 428 |
)
|
| 429 |
|
| 430 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 431 |
+
return {
|
| 432 |
+
"formatted": f"**PR #{pr_num} merged**\n{url}",
|
| 433 |
+
"totalResults": 1,
|
| 434 |
+
"resultsShared": 1,
|
| 435 |
+
}
|
| 436 |
|
| 437 |
async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 438 |
"""Close a PR/discussion."""
|
|
|
|
| 456 |
repo_type=repo_type,
|
| 457 |
)
|
| 458 |
|
| 459 |
+
return {
|
| 460 |
+
"formatted": f"**Discussion #{pr_num} closed**",
|
| 461 |
+
"totalResults": 1,
|
| 462 |
+
"resultsShared": 1,
|
| 463 |
+
}
|
| 464 |
|
| 465 |
async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
|
| 466 |
"""Add a comment to a PR/discussion."""
|
|
|
|
| 486 |
)
|
| 487 |
|
| 488 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 489 |
+
return {
|
| 490 |
+
"formatted": f"**Comment added to #{pr_num}**\n{url}",
|
| 491 |
+
"totalResults": 1,
|
| 492 |
+
"resultsShared": 1,
|
| 493 |
+
}
|
| 494 |
|
| 495 |
async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
|
| 496 |
"""Change PR/discussion status (mainly to convert draft to open)."""
|
|
|
|
| 518 |
)
|
| 519 |
|
| 520 |
url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
|
| 521 |
+
return {
|
| 522 |
+
"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}",
|
| 523 |
+
"totalResults": 1,
|
| 524 |
+
"resultsShared": 1,
|
| 525 |
+
}
|
| 526 |
|
| 527 |
# =========================================================================
|
| 528 |
# REPO MANAGEMENT
|
|
|
|
| 540 |
space_sdk = args.get("space_sdk")
|
| 541 |
|
| 542 |
if repo_type == "space" and not space_sdk:
|
| 543 |
+
return self._error(
|
| 544 |
+
"space_sdk required for spaces (gradio/streamlit/docker/static)"
|
| 545 |
+
)
|
| 546 |
|
| 547 |
kwargs = {
|
| 548 |
"repo_id": repo_id,
|
|
|
|
| 554 |
kwargs["space_sdk"] = space_sdk
|
| 555 |
|
| 556 |
result = await _async_call(self.api.create_repo, **kwargs)
|
| 557 |
+
extra_metadata = None
|
| 558 |
+
if repo_type == "space" and space_sdk:
|
| 559 |
+
extra_metadata = {"sdk": space_sdk}
|
| 560 |
+
await _async_call(
|
| 561 |
+
register_hub_artifact,
|
| 562 |
+
self.api,
|
| 563 |
+
repo_id,
|
| 564 |
+
repo_type,
|
| 565 |
+
session=self.session,
|
| 566 |
+
extra_metadata=extra_metadata,
|
| 567 |
+
)
|
| 568 |
|
| 569 |
return {
|
| 570 |
"formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
|
|
|
|
| 584 |
gated = args.get("gated")
|
| 585 |
|
| 586 |
if private is None and gated is None:
|
| 587 |
+
return self._error(
|
| 588 |
+
"Specify private (bool) or gated ('auto'/'manual'/false)"
|
| 589 |
+
)
|
| 590 |
|
| 591 |
kwargs = {"repo_id": repo_id, "repo_type": repo_type}
|
| 592 |
if private is not None:
|
|
|
|
| 603 |
changes.append(f"gated={gated}")
|
| 604 |
|
| 605 |
url = f"{_build_repo_url(repo_id, repo_type)}/settings"
|
| 606 |
+
return {
|
| 607 |
+
"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}",
|
| 608 |
+
"totalResults": 1,
|
| 609 |
+
"resultsShared": 1,
|
| 610 |
+
}
|
| 611 |
|
| 612 |
def _error(self, message: str) -> ToolResult:
|
| 613 |
"""Return an error result."""
|
| 614 |
+
return {
|
| 615 |
+
"formatted": message,
|
| 616 |
+
"totalResults": 0,
|
| 617 |
+
"resultsShared": 0,
|
| 618 |
+
"isError": True,
|
| 619 |
+
}
|
| 620 |
|
| 621 |
|
| 622 |
# Tool specification
|
|
|
|
| 662 |
"operation": {
|
| 663 |
"type": "string",
|
| 664 |
"enum": [
|
| 665 |
+
"create_branch",
|
| 666 |
+
"delete_branch",
|
| 667 |
+
"create_tag",
|
| 668 |
+
"delete_tag",
|
| 669 |
+
"list_refs",
|
| 670 |
+
"create_pr",
|
| 671 |
+
"list_prs",
|
| 672 |
+
"get_pr",
|
| 673 |
+
"merge_pr",
|
| 674 |
+
"close_pr",
|
| 675 |
+
"comment_pr",
|
| 676 |
+
"change_pr_status",
|
| 677 |
+
"create_repo",
|
| 678 |
+
"update_repo",
|
| 679 |
],
|
| 680 |
"description": "Operation to execute",
|
| 681 |
},
|
|
|
|
| 754 |
}
|
| 755 |
|
| 756 |
|
| 757 |
+
async def hf_repo_git_handler(
|
| 758 |
+
arguments: Dict[str, Any], session=None
|
| 759 |
+
) -> tuple[str, bool]:
|
| 760 |
"""Handler for agent tool router."""
|
| 761 |
try:
|
| 762 |
hf_token = session.hf_token if session else None
|
| 763 |
+
tool = HfRepoGitTool(hf_token=hf_token, session=session)
|
| 764 |
result = await tool.execute(arguments)
|
| 765 |
return result["formatted"], not result.get("isError", False)
|
| 766 |
except Exception as e:
|
agent/tools/jobs_tool.py
CHANGED
|
@@ -7,22 +7,24 @@ Refactored to use official huggingface-hub library instead of custom HTTP client
|
|
| 7 |
import asyncio
|
| 8 |
import base64
|
| 9 |
import http.client
|
| 10 |
-
import os
|
| 11 |
-
import re
|
| 12 |
-
from typing import Any, Dict, Literal, Optional, Callable, Awaitable
|
| 13 |
-
|
| 14 |
import logging
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
import httpx
|
| 17 |
from huggingface_hub import HfApi
|
| 18 |
from huggingface_hub.utils import HfHubHTTPError
|
| 19 |
|
| 20 |
-
from agent.core.hf_access import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
from agent.core.session import Event
|
| 22 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 23 |
from agent.tools.types import ToolResult
|
| 24 |
-
|
| 25 |
-
logger = logging.getLogger(__name__)
|
| 26 |
from agent.tools.utilities import (
|
| 27 |
format_job_details,
|
| 28 |
format_jobs_table,
|
|
@@ -30,6 +32,8 @@ from agent.tools.utilities import (
|
|
| 30 |
format_scheduled_jobs_table,
|
| 31 |
)
|
| 32 |
|
|
|
|
|
|
|
| 33 |
# Hardware flavors
|
| 34 |
CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"]
|
| 35 |
GPU_FLAVORS = [
|
|
@@ -119,11 +123,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]:
|
|
| 119 |
return logs
|
| 120 |
|
| 121 |
|
| 122 |
-
_ANSI_RE = re.compile(r
|
| 123 |
|
| 124 |
|
| 125 |
def _strip_ansi(text: str) -> str:
|
| 126 |
-
return _ANSI_RE.sub(
|
| 127 |
|
| 128 |
|
| 129 |
_DEFAULT_ENV = {
|
|
@@ -235,6 +239,26 @@ def _resolve_uv_command(
|
|
| 235 |
return _build_uv_command(script, with_deps, python, script_args)
|
| 236 |
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
async def _async_call(func, *args, **kwargs):
|
| 239 |
"""Wrap synchronous HfApi calls for async context"""
|
| 240 |
return await asyncio.to_thread(func, *args, **kwargs)
|
|
@@ -432,7 +456,9 @@ class HfJobsTool:
|
|
| 432 |
def log_producer():
|
| 433 |
try:
|
| 434 |
# fetch_job_logs is a blocking sync generator
|
| 435 |
-
logs_gen = self.api.fetch_job_logs(
|
|
|
|
|
|
|
| 436 |
for line in logs_gen:
|
| 437 |
# Push line to queue thread-safely
|
| 438 |
loop.call_soon_threadsafe(queue.put_nowait, line)
|
|
@@ -556,6 +582,8 @@ class HfJobsTool:
|
|
| 556 |
image = args.get("image", "python:3.12")
|
| 557 |
job_type = "Docker"
|
| 558 |
|
|
|
|
|
|
|
| 559 |
# Run the job
|
| 560 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 561 |
timeout_str = args.get("timeout", "30m")
|
|
@@ -578,7 +606,9 @@ class HfJobsTool:
|
|
| 578 |
image=image,
|
| 579 |
command=command,
|
| 580 |
env=env_dict,
|
| 581 |
-
secrets=_add_environment_variables(
|
|
|
|
|
|
|
| 582 |
flavor=flavor,
|
| 583 |
timeout=timeout_str,
|
| 584 |
namespace=self.namespace,
|
|
@@ -636,10 +666,18 @@ class HfJobsTool:
|
|
| 636 |
submit_ts = None
|
| 637 |
if self.session:
|
| 638 |
from agent.core import telemetry
|
|
|
|
| 639 |
submit_ts = await telemetry.record_hf_job_submit(
|
| 640 |
-
self.session,
|
| 641 |
-
|
| 642 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
)
|
| 644 |
# Top-up signal: this submit succeeded after a prior billing
|
| 645 |
# block in the same session, and we haven't fired the event
|
|
@@ -656,7 +694,8 @@ class HfJobsTool:
|
|
| 656 |
)
|
| 657 |
if blocked:
|
| 658 |
await telemetry.record_credits_topped_up(
|
| 659 |
-
self.session,
|
|
|
|
| 660 |
)
|
| 661 |
|
| 662 |
# Wait for completion and stream logs
|
|
@@ -670,9 +709,13 @@ class HfJobsTool:
|
|
| 670 |
|
| 671 |
if self.session and submit_ts is not None:
|
| 672 |
from agent.core import telemetry
|
|
|
|
| 673 |
await telemetry.record_hf_job_complete(
|
| 674 |
-
self.session,
|
| 675 |
-
|
|
|
|
|
|
|
|
|
|
| 676 |
)
|
| 677 |
|
| 678 |
# Untrack job ID (completed or failed, no longer needs cancellation)
|
|
@@ -699,7 +742,9 @@ class HfJobsTool:
|
|
| 699 |
filtered_logs = _filter_uv_install_output(all_logs)
|
| 700 |
|
| 701 |
# Format all logs for the agent
|
| 702 |
-
log_text =
|
|
|
|
|
|
|
| 703 |
|
| 704 |
response = f"""{job_type} job completed!
|
| 705 |
|
|
@@ -891,6 +936,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}
|
|
| 891 |
image = args.get("image", "python:3.12")
|
| 892 |
job_type = "Docker"
|
| 893 |
|
|
|
|
|
|
|
| 894 |
# Create scheduled job
|
| 895 |
scheduled_job = await _async_call(
|
| 896 |
self.api.create_scheduled_job,
|
|
@@ -1215,6 +1262,7 @@ async def hf_jobs_handler(
|
|
| 1215 |
sandbox = getattr(session, "sandbox", None) if session else None
|
| 1216 |
if sandbox and script:
|
| 1217 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
|
|
|
| 1218 |
content, error = await resolve_sandbox_script(sandbox, script)
|
| 1219 |
if error:
|
| 1220 |
return error, False
|
|
|
|
| 7 |
import asyncio
|
| 8 |
import base64
|
| 9 |
import http.client
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import logging
|
| 11 |
+
import re
|
| 12 |
+
import shlex
|
| 13 |
+
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
|
| 14 |
|
| 15 |
import httpx
|
| 16 |
from huggingface_hub import HfApi
|
| 17 |
from huggingface_hub.utils import HfHubHTTPError
|
| 18 |
|
| 19 |
+
from agent.core.hf_access import (
|
| 20 |
+
JobsAccessError,
|
| 21 |
+
is_billing_error,
|
| 22 |
+
resolve_jobs_namespace,
|
| 23 |
+
)
|
| 24 |
+
from agent.core.hub_artifacts import build_hub_artifact_sitecustomize
|
| 25 |
from agent.core.session import Event
|
| 26 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 27 |
from agent.tools.types import ToolResult
|
|
|
|
|
|
|
| 28 |
from agent.tools.utilities import (
|
| 29 |
format_job_details,
|
| 30 |
format_jobs_table,
|
|
|
|
| 32 |
format_scheduled_jobs_table,
|
| 33 |
)
|
| 34 |
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
# Hardware flavors
|
| 38 |
CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"]
|
| 39 |
GPU_FLAVORS = [
|
|
|
|
| 123 |
return logs
|
| 124 |
|
| 125 |
|
| 126 |
+
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07")
|
| 127 |
|
| 128 |
|
| 129 |
def _strip_ansi(text: str) -> str:
|
| 130 |
+
return _ANSI_RE.sub("", text)
|
| 131 |
|
| 132 |
|
| 133 |
_DEFAULT_ENV = {
|
|
|
|
| 239 |
return _build_uv_command(script, with_deps, python, script_args)
|
| 240 |
|
| 241 |
|
| 242 |
+
def _wrap_command_with_artifact_bootstrap(
|
| 243 |
+
command: list[str], session: Any = None
|
| 244 |
+
) -> list[str]:
|
| 245 |
+
"""Install sitecustomize hooks before the user command runs in HF Jobs."""
|
| 246 |
+
sitecustomize = build_hub_artifact_sitecustomize(session)
|
| 247 |
+
if not sitecustomize:
|
| 248 |
+
return command
|
| 249 |
+
|
| 250 |
+
encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
|
| 251 |
+
original_command = shlex.join(command)
|
| 252 |
+
shell = (
|
| 253 |
+
'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; '
|
| 254 |
+
f"printf %s {shlex.quote(encoded)} | base64 -d "
|
| 255 |
+
'> "$_ml_intern_artifacts_dir/sitecustomize.py"; '
|
| 256 |
+
'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; '
|
| 257 |
+
f"exec {original_command}"
|
| 258 |
+
)
|
| 259 |
+
return ["/bin/sh", "-lc", shell]
|
| 260 |
+
|
| 261 |
+
|
| 262 |
async def _async_call(func, *args, **kwargs):
|
| 263 |
"""Wrap synchronous HfApi calls for async context"""
|
| 264 |
return await asyncio.to_thread(func, *args, **kwargs)
|
|
|
|
| 456 |
def log_producer():
|
| 457 |
try:
|
| 458 |
# fetch_job_logs is a blocking sync generator
|
| 459 |
+
logs_gen = self.api.fetch_job_logs(
|
| 460 |
+
job_id=job_id, namespace=namespace
|
| 461 |
+
)
|
| 462 |
for line in logs_gen:
|
| 463 |
# Push line to queue thread-safely
|
| 464 |
loop.call_soon_threadsafe(queue.put_nowait, line)
|
|
|
|
| 582 |
image = args.get("image", "python:3.12")
|
| 583 |
job_type = "Docker"
|
| 584 |
|
| 585 |
+
command = _wrap_command_with_artifact_bootstrap(command, self.session)
|
| 586 |
+
|
| 587 |
# Run the job
|
| 588 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 589 |
timeout_str = args.get("timeout", "30m")
|
|
|
|
| 606 |
image=image,
|
| 607 |
command=command,
|
| 608 |
env=env_dict,
|
| 609 |
+
secrets=_add_environment_variables(
|
| 610 |
+
args.get("secrets"), self.hf_token
|
| 611 |
+
),
|
| 612 |
flavor=flavor,
|
| 613 |
timeout=timeout_str,
|
| 614 |
namespace=self.namespace,
|
|
|
|
| 666 |
submit_ts = None
|
| 667 |
if self.session:
|
| 668 |
from agent.core import telemetry
|
| 669 |
+
|
| 670 |
submit_ts = await telemetry.record_hf_job_submit(
|
| 671 |
+
self.session,
|
| 672 |
+
job,
|
| 673 |
+
{
|
| 674 |
+
**args,
|
| 675 |
+
"hardware_flavor": flavor,
|
| 676 |
+
"timeout": timeout_str,
|
| 677 |
+
"namespace": self.namespace,
|
| 678 |
+
},
|
| 679 |
+
image=image,
|
| 680 |
+
job_type=job_type,
|
| 681 |
)
|
| 682 |
# Top-up signal: this submit succeeded after a prior billing
|
| 683 |
# block in the same session, and we haven't fired the event
|
|
|
|
| 694 |
)
|
| 695 |
if blocked:
|
| 696 |
await telemetry.record_credits_topped_up(
|
| 697 |
+
self.session,
|
| 698 |
+
namespace=self.namespace,
|
| 699 |
)
|
| 700 |
|
| 701 |
# Wait for completion and stream logs
|
|
|
|
| 709 |
|
| 710 |
if self.session and submit_ts is not None:
|
| 711 |
from agent.core import telemetry
|
| 712 |
+
|
| 713 |
await telemetry.record_hf_job_complete(
|
| 714 |
+
self.session,
|
| 715 |
+
job,
|
| 716 |
+
flavor=flavor,
|
| 717 |
+
final_status=final_status,
|
| 718 |
+
submit_ts=submit_ts,
|
| 719 |
)
|
| 720 |
|
| 721 |
# Untrack job ID (completed or failed, no longer needs cancellation)
|
|
|
|
| 742 |
filtered_logs = _filter_uv_install_output(all_logs)
|
| 743 |
|
| 744 |
# Format all logs for the agent
|
| 745 |
+
log_text = (
|
| 746 |
+
_strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)"
|
| 747 |
+
)
|
| 748 |
|
| 749 |
response = f"""{job_type} job completed!
|
| 750 |
|
|
|
|
| 936 |
image = args.get("image", "python:3.12")
|
| 937 |
job_type = "Docker"
|
| 938 |
|
| 939 |
+
command = _wrap_command_with_artifact_bootstrap(command, self.session)
|
| 940 |
+
|
| 941 |
# Create scheduled job
|
| 942 |
scheduled_job = await _async_call(
|
| 943 |
self.api.create_scheduled_job,
|
|
|
|
| 1262 |
sandbox = getattr(session, "sandbox", None) if session else None
|
| 1263 |
if sandbox and script:
|
| 1264 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
| 1265 |
+
|
| 1266 |
content, error = await resolve_sandbox_script(sandbox, script)
|
| 1267 |
if error:
|
| 1268 |
return error, False
|
agent/tools/local_tools.py
CHANGED
|
@@ -15,6 +15,8 @@ import tempfile
|
|
| 15 |
from pathlib import Path
|
| 16 |
from typing import Any
|
| 17 |
|
|
|
|
|
|
|
| 18 |
|
| 19 |
MAX_OUTPUT_CHARS = 25_000
|
| 20 |
MAX_LINE_LENGTH = 4000
|
|
@@ -22,7 +24,7 @@ DEFAULT_READ_LINES = 2000
|
|
| 22 |
DEFAULT_TIMEOUT = 120
|
| 23 |
MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench)
|
| 24 |
|
| 25 |
-
_ANSI_RE = re.compile(r
|
| 26 |
|
| 27 |
# Track files that have been read this session (enforces read-before-write/edit)
|
| 28 |
_files_read: set[str] = set()
|
|
@@ -63,17 +65,21 @@ def _atomic_write(path: Path, content: str) -> None:
|
|
| 63 |
|
| 64 |
|
| 65 |
def _strip_ansi(text: str) -> str:
|
| 66 |
-
return _ANSI_RE.sub(
|
| 67 |
|
| 68 |
|
| 69 |
-
def _truncate_output(
|
|
|
|
|
|
|
| 70 |
"""Tail-biased truncation with temp file spillover for full output access."""
|
| 71 |
if len(output) <= max_chars:
|
| 72 |
return output
|
| 73 |
# Write full output to temp file so LLM can read specific sections
|
| 74 |
spill_path = None
|
| 75 |
try:
|
| 76 |
-
with tempfile.NamedTemporaryFile(
|
|
|
|
|
|
|
| 77 |
f.write(output)
|
| 78 |
spill_path = f.name
|
| 79 |
except Exception:
|
|
@@ -93,10 +99,14 @@ def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio:
|
|
| 93 |
|
| 94 |
# ── Handlers ────────────────────────────────────────────────────────────
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
command = args.get("command", "")
|
| 98 |
if not command:
|
| 99 |
return "No command provided.", False
|
|
|
|
| 100 |
work_dir = args.get("work_dir", ".")
|
| 101 |
timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
|
| 102 |
try:
|
|
@@ -174,9 +184,12 @@ async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
|
|
| 174 |
# Syntax validation for Python files
|
| 175 |
if p.suffix == ".py":
|
| 176 |
from agent.tools.edit_utils import validate_python
|
|
|
|
| 177 |
warnings = validate_python(content, file_path)
|
| 178 |
if warnings:
|
| 179 |
-
msg += "\n\nValidation warnings:\n" + "\n".join(
|
|
|
|
|
|
|
| 180 |
return msg, True
|
| 181 |
except Exception as e:
|
| 182 |
return f"write error: {e}", False
|
|
@@ -229,7 +242,9 @@ async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
|
|
| 229 |
if p.suffix == ".py":
|
| 230 |
warnings = validate_python(new_text, file_path)
|
| 231 |
if warnings:
|
| 232 |
-
msg += "\n\nValidation warnings:\n" + "\n".join(
|
|
|
|
|
|
|
| 233 |
return msg, True
|
| 234 |
|
| 235 |
|
|
|
|
| 15 |
from pathlib import Path
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
+
from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
|
| 19 |
+
|
| 20 |
|
| 21 |
MAX_OUTPUT_CHARS = 25_000
|
| 22 |
MAX_LINE_LENGTH = 4000
|
|
|
|
| 24 |
DEFAULT_TIMEOUT = 120
|
| 25 |
MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench)
|
| 26 |
|
| 27 |
+
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07")
|
| 28 |
|
| 29 |
# Track files that have been read this session (enforces read-before-write/edit)
|
| 30 |
_files_read: set[str] = set()
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def _strip_ansi(text: str) -> str:
|
| 68 |
+
return _ANSI_RE.sub("", text)
|
| 69 |
|
| 70 |
|
| 71 |
+
def _truncate_output(
|
| 72 |
+
output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25
|
| 73 |
+
) -> str:
|
| 74 |
"""Tail-biased truncation with temp file spillover for full output access."""
|
| 75 |
if len(output) <= max_chars:
|
| 76 |
return output
|
| 77 |
# Write full output to temp file so LLM can read specific sections
|
| 78 |
spill_path = None
|
| 79 |
try:
|
| 80 |
+
with tempfile.NamedTemporaryFile(
|
| 81 |
+
mode="w", suffix=".txt", prefix="bash_output_", delete=False
|
| 82 |
+
) as f:
|
| 83 |
f.write(output)
|
| 84 |
spill_path = f.name
|
| 85 |
except Exception:
|
|
|
|
| 99 |
|
| 100 |
# ── Handlers ────────────────────────────────────────────────────────────
|
| 101 |
|
| 102 |
+
|
| 103 |
+
async def _bash_handler(
|
| 104 |
+
args: dict[str, Any], session: Any = None, **_kw
|
| 105 |
+
) -> tuple[str, bool]:
|
| 106 |
command = args.get("command", "")
|
| 107 |
if not command:
|
| 108 |
return "No command provided.", False
|
| 109 |
+
command = wrap_shell_command_with_hub_artifact_bootstrap(command, session)
|
| 110 |
work_dir = args.get("work_dir", ".")
|
| 111 |
timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
|
| 112 |
try:
|
|
|
|
| 184 |
# Syntax validation for Python files
|
| 185 |
if p.suffix == ".py":
|
| 186 |
from agent.tools.edit_utils import validate_python
|
| 187 |
+
|
| 188 |
warnings = validate_python(content, file_path)
|
| 189 |
if warnings:
|
| 190 |
+
msg += "\n\nValidation warnings:\n" + "\n".join(
|
| 191 |
+
f" ⚠ {w}" for w in warnings
|
| 192 |
+
)
|
| 193 |
return msg, True
|
| 194 |
except Exception as e:
|
| 195 |
return f"write error: {e}", False
|
|
|
|
| 242 |
if p.suffix == ".py":
|
| 243 |
warnings = validate_python(new_text, file_path)
|
| 244 |
if warnings:
|
| 245 |
+
msg += "\n\nValidation warnings:\n" + "\n".join(
|
| 246 |
+
f" ⚠ {w}" for w in warnings
|
| 247 |
+
)
|
| 248 |
return msg, True
|
| 249 |
|
| 250 |
|
agent/tools/papers_tool.py
CHANGED
|
@@ -102,7 +102,9 @@ async def _s2_request(
|
|
| 102 |
|
| 103 |
|
| 104 |
async def _s2_get_json(
|
| 105 |
-
client: httpx.AsyncClient,
|
|
|
|
|
|
|
| 106 |
) -> dict | None:
|
| 107 |
"""Cached S2 GET returning parsed JSON or None."""
|
| 108 |
key = _s2_cache_key(path, params)
|
|
@@ -119,7 +121,9 @@ async def _s2_get_json(
|
|
| 119 |
|
| 120 |
|
| 121 |
async def _s2_get_paper(
|
| 122 |
-
client: httpx.AsyncClient,
|
|
|
|
|
|
|
| 123 |
) -> dict | None:
|
| 124 |
"""Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
|
| 125 |
return await _s2_get_json(
|
|
@@ -322,7 +326,9 @@ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str:
|
|
| 322 |
if keywords:
|
| 323 |
lines.append(f"**Keywords:** {', '.join(keywords)}")
|
| 324 |
if s2_data and s2_data.get("s2FieldsOfStudy"):
|
| 325 |
-
fields = [
|
|
|
|
|
|
|
| 326 |
if fields:
|
| 327 |
lines.append(f"**Fields:** {', '.join(fields)}")
|
| 328 |
if s2_data and s2_data.get("venue"):
|
|
@@ -393,7 +399,9 @@ def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str:
|
|
| 393 |
ds_id = ds.get("id", "unknown")
|
| 394 |
downloads = ds.get("downloads", 0)
|
| 395 |
likes = ds.get("likes", 0)
|
| 396 |
-
desc = _truncate(
|
|
|
|
|
|
|
| 397 |
tags = ds.get("tags") or []
|
| 398 |
interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]
|
| 399 |
|
|
@@ -582,11 +590,15 @@ def _format_s2_paper_list(papers: list[dict], title: str) -> str:
|
|
| 582 |
lines.append(f"**TL;DR:** {tldr}")
|
| 583 |
lines.append("")
|
| 584 |
|
| 585 |
-
lines.append(
|
|
|
|
|
|
|
| 586 |
return "\n".join(lines)
|
| 587 |
|
| 588 |
|
| 589 |
-
async def _s2_bulk_search(
|
|
|
|
|
|
|
| 590 |
"""Search via S2 bulk endpoint with filters. Returns None on failure."""
|
| 591 |
params: dict[str, Any] = {
|
| 592 |
"query": query,
|
|
@@ -616,7 +628,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR
|
|
| 616 |
params["sort"] = f"{sort_by}:desc"
|
| 617 |
|
| 618 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 619 |
-
resp = await _s2_request(
|
|
|
|
|
|
|
| 620 |
if not resp or resp.status_code != 200:
|
| 621 |
return None
|
| 622 |
data = resp.json()
|
|
@@ -629,7 +643,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR
|
|
| 629 |
"resultsShared": 0,
|
| 630 |
}
|
| 631 |
|
| 632 |
-
formatted = _format_s2_paper_list(
|
|
|
|
|
|
|
| 633 |
return {
|
| 634 |
"formatted": formatted,
|
| 635 |
"totalResults": data.get("total", len(papers)),
|
|
@@ -643,7 +659,10 @@ async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 643 |
return _error("'query' is required for search operation.")
|
| 644 |
|
| 645 |
# Route to S2 when filters are present
|
| 646 |
-
use_s2 = any(
|
|
|
|
|
|
|
|
|
|
| 647 |
if use_s2:
|
| 648 |
result = await _s2_bulk_search(query, args, limit)
|
| 649 |
if result is not None:
|
|
@@ -806,7 +825,9 @@ def _format_citation_graph(
|
|
| 806 |
lines.append("No citations found.")
|
| 807 |
lines.append("")
|
| 808 |
|
| 809 |
-
lines.append(
|
|
|
|
|
|
|
| 810 |
return "\n".join(lines)
|
| 811 |
|
| 812 |
|
|
@@ -824,9 +845,13 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 824 |
refs, cites = None, None
|
| 825 |
coros = []
|
| 826 |
if direction in ("references", "both"):
|
| 827 |
-
coros.append(
|
|
|
|
|
|
|
| 828 |
if direction in ("citations", "both"):
|
| 829 |
-
coros.append(
|
|
|
|
|
|
|
| 830 |
|
| 831 |
results = await asyncio.gather(*coros, return_exceptions=True)
|
| 832 |
idx = 0
|
|
@@ -841,7 +866,9 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 841 |
cites = r.get("data", [])
|
| 842 |
|
| 843 |
if refs is None and cites is None:
|
| 844 |
-
return _error(
|
|
|
|
|
|
|
| 845 |
|
| 846 |
total = (len(refs) if refs else 0) + (len(cites) if cites else 0)
|
| 847 |
return {
|
|
@@ -1039,7 +1066,9 @@ def _format_snippets(snippets: list[dict], query: str) -> str:
|
|
| 1039 |
lines.append(f"> {_truncate(text, 400)}")
|
| 1040 |
lines.append("")
|
| 1041 |
|
| 1042 |
-
lines.append(
|
|
|
|
|
|
|
| 1043 |
return "\n".join(lines)
|
| 1044 |
|
| 1045 |
|
|
@@ -1065,7 +1094,9 @@ async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 1065 |
params["minCitationCount"] = str(args["min_citations"])
|
| 1066 |
|
| 1067 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 1068 |
-
resp = await _s2_request(
|
|
|
|
|
|
|
| 1069 |
if not resp or resp.status_code != 200:
|
| 1070 |
return _error("Snippet search failed. Semantic Scholar may be unavailable.")
|
| 1071 |
data = resp.json()
|
|
@@ -1102,16 +1133,28 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 1102 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 1103 |
if positive_ids and not arxiv_id:
|
| 1104 |
# Multi-paper recommendations (POST, not cached)
|
| 1105 |
-
pos = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1106 |
neg_raw = args.get("negative_ids", "")
|
| 1107 |
-
neg =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1108 |
resp = await _s2_request(
|
| 1109 |
-
client,
|
|
|
|
|
|
|
| 1110 |
json={"positivePaperIds": pos, "negativePaperIds": neg},
|
| 1111 |
params={"fields": fields, "limit": limit},
|
| 1112 |
)
|
| 1113 |
if not resp or resp.status_code != 200:
|
| 1114 |
-
return _error(
|
|
|
|
|
|
|
| 1115 |
data = resp.json()
|
| 1116 |
else:
|
| 1117 |
# Single-paper recommendations (cached)
|
|
@@ -1121,7 +1164,9 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
|
|
| 1121 |
{"fields": fields, "limit": limit, "from": "recent"},
|
| 1122 |
)
|
| 1123 |
if not data:
|
| 1124 |
-
return _error(
|
|
|
|
|
|
|
| 1125 |
|
| 1126 |
papers = data.get("recommendedPapers") or []
|
| 1127 |
if not papers:
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
async def _s2_get_json(
|
| 105 |
+
client: httpx.AsyncClient,
|
| 106 |
+
path: str,
|
| 107 |
+
params: dict | None = None,
|
| 108 |
) -> dict | None:
|
| 109 |
"""Cached S2 GET returning parsed JSON or None."""
|
| 110 |
key = _s2_cache_key(path, params)
|
|
|
|
| 121 |
|
| 122 |
|
| 123 |
async def _s2_get_paper(
|
| 124 |
+
client: httpx.AsyncClient,
|
| 125 |
+
arxiv_id: str,
|
| 126 |
+
fields: str,
|
| 127 |
) -> dict | None:
|
| 128 |
"""Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
|
| 129 |
return await _s2_get_json(
|
|
|
|
| 326 |
if keywords:
|
| 327 |
lines.append(f"**Keywords:** {', '.join(keywords)}")
|
| 328 |
if s2_data and s2_data.get("s2FieldsOfStudy"):
|
| 329 |
+
fields = [
|
| 330 |
+
f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")
|
| 331 |
+
]
|
| 332 |
if fields:
|
| 333 |
lines.append(f"**Fields:** {', '.join(fields)}")
|
| 334 |
if s2_data and s2_data.get("venue"):
|
|
|
|
| 399 |
ds_id = ds.get("id", "unknown")
|
| 400 |
downloads = ds.get("downloads", 0)
|
| 401 |
likes = ds.get("likes", 0)
|
| 402 |
+
desc = _truncate(
|
| 403 |
+
_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN
|
| 404 |
+
)
|
| 405 |
tags = ds.get("tags") or []
|
| 406 |
interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]
|
| 407 |
|
|
|
|
| 590 |
lines.append(f"**TL;DR:** {tldr}")
|
| 591 |
lines.append("")
|
| 592 |
|
| 593 |
+
lines.append(
|
| 594 |
+
"Use paper_details with arxiv_id for full info, or read_paper to read sections."
|
| 595 |
+
)
|
| 596 |
return "\n".join(lines)
|
| 597 |
|
| 598 |
|
| 599 |
+
async def _s2_bulk_search(
|
| 600 |
+
query: str, args: dict[str, Any], limit: int
|
| 601 |
+
) -> ToolResult | None:
|
| 602 |
"""Search via S2 bulk endpoint with filters. Returns None on failure."""
|
| 603 |
params: dict[str, Any] = {
|
| 604 |
"query": query,
|
|
|
|
| 628 |
params["sort"] = f"{sort_by}:desc"
|
| 629 |
|
| 630 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 631 |
+
resp = await _s2_request(
|
| 632 |
+
client, "GET", "/graph/v1/paper/search/bulk", params=params
|
| 633 |
+
)
|
| 634 |
if not resp or resp.status_code != 200:
|
| 635 |
return None
|
| 636 |
data = resp.json()
|
|
|
|
| 643 |
"resultsShared": 0,
|
| 644 |
}
|
| 645 |
|
| 646 |
+
formatted = _format_s2_paper_list(
|
| 647 |
+
papers[:limit], f"Papers matching '{query}' (Semantic Scholar)"
|
| 648 |
+
)
|
| 649 |
return {
|
| 650 |
"formatted": formatted,
|
| 651 |
"totalResults": data.get("total", len(papers)),
|
|
|
|
| 659 |
return _error("'query' is required for search operation.")
|
| 660 |
|
| 661 |
# Route to S2 when filters are present
|
| 662 |
+
use_s2 = any(
|
| 663 |
+
args.get(k)
|
| 664 |
+
for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")
|
| 665 |
+
)
|
| 666 |
if use_s2:
|
| 667 |
result = await _s2_bulk_search(query, args, limit)
|
| 668 |
if result is not None:
|
|
|
|
| 825 |
lines.append("No citations found.")
|
| 826 |
lines.append("")
|
| 827 |
|
| 828 |
+
lines.append(
|
| 829 |
+
"**Tip:** Use paper_details with an arxiv_id from above to explore further."
|
| 830 |
+
)
|
| 831 |
return "\n".join(lines)
|
| 832 |
|
| 833 |
|
|
|
|
| 845 |
refs, cites = None, None
|
| 846 |
coros = []
|
| 847 |
if direction in ("references", "both"):
|
| 848 |
+
coros.append(
|
| 849 |
+
_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)
|
| 850 |
+
)
|
| 851 |
if direction in ("citations", "both"):
|
| 852 |
+
coros.append(
|
| 853 |
+
_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)
|
| 854 |
+
)
|
| 855 |
|
| 856 |
results = await asyncio.gather(*coros, return_exceptions=True)
|
| 857 |
idx = 0
|
|
|
|
| 866 |
cites = r.get("data", [])
|
| 867 |
|
| 868 |
if refs is None and cites is None:
|
| 869 |
+
return _error(
|
| 870 |
+
f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar."
|
| 871 |
+
)
|
| 872 |
|
| 873 |
total = (len(refs) if refs else 0) + (len(cites) if cites else 0)
|
| 874 |
return {
|
|
|
|
| 1066 |
lines.append(f"> {_truncate(text, 400)}")
|
| 1067 |
lines.append("")
|
| 1068 |
|
| 1069 |
+
lines.append(
|
| 1070 |
+
"Use paper_details or read_paper with arxiv_id to explore a paper further."
|
| 1071 |
+
)
|
| 1072 |
return "\n".join(lines)
|
| 1073 |
|
| 1074 |
|
|
|
|
| 1094 |
params["minCitationCount"] = str(args["min_citations"])
|
| 1095 |
|
| 1096 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 1097 |
+
resp = await _s2_request(
|
| 1098 |
+
client, "GET", "/graph/v1/snippet/search", params=params
|
| 1099 |
+
)
|
| 1100 |
if not resp or resp.status_code != 200:
|
| 1101 |
return _error("Snippet search failed. Semantic Scholar may be unavailable.")
|
| 1102 |
data = resp.json()
|
|
|
|
| 1133 |
async with httpx.AsyncClient(timeout=15) as client:
|
| 1134 |
if positive_ids and not arxiv_id:
|
| 1135 |
# Multi-paper recommendations (POST, not cached)
|
| 1136 |
+
pos = [
|
| 1137 |
+
_s2_paper_id(pid.strip())
|
| 1138 |
+
for pid in positive_ids.split(",")
|
| 1139 |
+
if pid.strip()
|
| 1140 |
+
]
|
| 1141 |
neg_raw = args.get("negative_ids", "")
|
| 1142 |
+
neg = (
|
| 1143 |
+
[_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()]
|
| 1144 |
+
if neg_raw
|
| 1145 |
+
else []
|
| 1146 |
+
)
|
| 1147 |
resp = await _s2_request(
|
| 1148 |
+
client,
|
| 1149 |
+
"POST",
|
| 1150 |
+
"/recommendations/v1/papers/",
|
| 1151 |
json={"positivePaperIds": pos, "negativePaperIds": neg},
|
| 1152 |
params={"fields": fields, "limit": limit},
|
| 1153 |
)
|
| 1154 |
if not resp or resp.status_code != 200:
|
| 1155 |
+
return _error(
|
| 1156 |
+
"Recommendation request failed. Semantic Scholar may be unavailable."
|
| 1157 |
+
)
|
| 1158 |
data = resp.json()
|
| 1159 |
else:
|
| 1160 |
# Single-paper recommendations (cached)
|
|
|
|
| 1164 |
{"fields": fields, "limit": limit, "from": "recent"},
|
| 1165 |
)
|
| 1166 |
if not data:
|
| 1167 |
+
return _error(
|
| 1168 |
+
"Recommendation request failed. Semantic Scholar may be unavailable."
|
| 1169 |
+
)
|
| 1170 |
|
| 1171 |
papers = data.get("recommendedPapers") or []
|
| 1172 |
if not papers:
|
agent/tools/research_tool.py
CHANGED
|
@@ -282,6 +282,7 @@ async def research_handler(
|
|
| 282 |
_agent_id = tool_call_id
|
| 283 |
else:
|
| 284 |
import uuid
|
|
|
|
| 285 |
_agent_id = uuid.uuid4().hex[:8]
|
| 286 |
_agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task)
|
| 287 |
|
|
@@ -289,12 +290,15 @@ async def research_handler(
|
|
| 289 |
"""Send a progress event to the UI so it doesn't look frozen."""
|
| 290 |
try:
|
| 291 |
await session.send_event(
|
| 292 |
-
Event(
|
| 293 |
-
"
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
|
|
|
|
|
|
|
|
|
| 298 |
)
|
| 299 |
except Exception:
|
| 300 |
pass
|
|
@@ -323,15 +327,19 @@ async def research_handler(
|
|
| 323 |
"Research sub-agent hit context max (%d tokens) — forcing summary",
|
| 324 |
_total_tokens,
|
| 325 |
)
|
| 326 |
-
await _log(
|
|
|
|
|
|
|
| 327 |
# Ask for a final summary with no tools
|
| 328 |
-
messages.append(
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
|
|
|
|
|
|
| 335 |
try:
|
| 336 |
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
|
| 337 |
_t0 = time.monotonic()
|
|
@@ -351,27 +359,34 @@ async def research_handler(
|
|
| 351 |
model=research_model,
|
| 352 |
response=response,
|
| 353 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 354 |
-
finish_reason=response.choices[0].finish_reason
|
|
|
|
|
|
|
| 355 |
kind="research",
|
| 356 |
)
|
| 357 |
except Exception as _telem_err:
|
| 358 |
logger.debug("research telemetry failed: %s", _telem_err)
|
| 359 |
content = response.choices[0].message.content or ""
|
| 360 |
-
return
|
|
|
|
|
|
|
|
|
|
| 361 |
except Exception:
|
| 362 |
return "Research context exhausted and summary call failed.", False
|
| 363 |
|
| 364 |
if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN:
|
| 365 |
_warned_context = True
|
| 366 |
await _log(f"Context at {_total_tokens} tokens — nudging to wrap up")
|
| 367 |
-
messages.append(
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
|
|
|
|
|
|
| 375 |
|
| 376 |
try:
|
| 377 |
_msgs, _tools = with_prompt_caching(
|
|
@@ -392,7 +407,9 @@ async def research_handler(
|
|
| 392 |
model=research_model,
|
| 393 |
response=response,
|
| 394 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 395 |
-
finish_reason=response.choices[0].finish_reason
|
|
|
|
|
|
|
| 396 |
kind="research",
|
| 397 |
)
|
| 398 |
except Exception as _telem_err:
|
|
@@ -420,11 +437,13 @@ async def research_handler(
|
|
| 420 |
# LiteLLM's raw Message carries `provider_specific_fields` and
|
| 421 |
# `reasoning_content`, which the HF router's OpenAI schema rejects
|
| 422 |
# if we echo them back in the next request.
|
| 423 |
-
messages.append(
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
| 428 |
for tc in msg.tool_calls:
|
| 429 |
try:
|
| 430 |
tool_args = json.loads(tc.function.arguments)
|
|
@@ -479,13 +498,15 @@ async def research_handler(
|
|
| 479 |
|
| 480 |
# ── Iteration limit: try to salvage findings ──
|
| 481 |
await _log("Iteration limit reached — extracting summary")
|
| 482 |
-
messages.append(
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
|
|
|
|
|
|
| 489 |
try:
|
| 490 |
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
|
| 491 |
_t0 = time.monotonic()
|
|
@@ -502,7 +523,9 @@ async def research_handler(
|
|
| 502 |
model=research_model,
|
| 503 |
response=response,
|
| 504 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 505 |
-
finish_reason=response.choices[0].finish_reason
|
|
|
|
|
|
|
| 506 |
kind="research",
|
| 507 |
)
|
| 508 |
except Exception as _telem_err:
|
|
|
|
| 282 |
_agent_id = tool_call_id
|
| 283 |
else:
|
| 284 |
import uuid
|
| 285 |
+
|
| 286 |
_agent_id = uuid.uuid4().hex[:8]
|
| 287 |
_agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task)
|
| 288 |
|
|
|
|
| 290 |
"""Send a progress event to the UI so it doesn't look frozen."""
|
| 291 |
try:
|
| 292 |
await session.send_event(
|
| 293 |
+
Event(
|
| 294 |
+
event_type="tool_log",
|
| 295 |
+
data={
|
| 296 |
+
"tool": "research",
|
| 297 |
+
"log": text,
|
| 298 |
+
"agent_id": _agent_id,
|
| 299 |
+
"label": _agent_label,
|
| 300 |
+
},
|
| 301 |
+
)
|
| 302 |
)
|
| 303 |
except Exception:
|
| 304 |
pass
|
|
|
|
| 327 |
"Research sub-agent hit context max (%d tokens) — forcing summary",
|
| 328 |
_total_tokens,
|
| 329 |
)
|
| 330 |
+
await _log(
|
| 331 |
+
f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up"
|
| 332 |
+
)
|
| 333 |
# Ask for a final summary with no tools
|
| 334 |
+
messages.append(
|
| 335 |
+
Message(
|
| 336 |
+
role="user",
|
| 337 |
+
content=(
|
| 338 |
+
"[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. "
|
| 339 |
+
"Summarize your findings NOW. Do NOT call any more tools."
|
| 340 |
+
),
|
| 341 |
+
)
|
| 342 |
+
)
|
| 343 |
try:
|
| 344 |
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
|
| 345 |
_t0 = time.monotonic()
|
|
|
|
| 359 |
model=research_model,
|
| 360 |
response=response,
|
| 361 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 362 |
+
finish_reason=response.choices[0].finish_reason
|
| 363 |
+
if response.choices
|
| 364 |
+
else None,
|
| 365 |
kind="research",
|
| 366 |
)
|
| 367 |
except Exception as _telem_err:
|
| 368 |
logger.debug("research telemetry failed: %s", _telem_err)
|
| 369 |
content = response.choices[0].message.content or ""
|
| 370 |
+
return (
|
| 371 |
+
content or "Research context exhausted — no summary produced.",
|
| 372 |
+
bool(content),
|
| 373 |
+
)
|
| 374 |
except Exception:
|
| 375 |
return "Research context exhausted and summary call failed.", False
|
| 376 |
|
| 377 |
if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN:
|
| 378 |
_warned_context = True
|
| 379 |
await _log(f"Context at {_total_tokens} tokens — nudging to wrap up")
|
| 380 |
+
messages.append(
|
| 381 |
+
Message(
|
| 382 |
+
role="user",
|
| 383 |
+
content=(
|
| 384 |
+
"[SYSTEM: You have used 75% of your context budget. "
|
| 385 |
+
"Start wrapping up: finish any critical lookups, then "
|
| 386 |
+
"produce your final summary within the next 1-2 iterations.]"
|
| 387 |
+
),
|
| 388 |
+
)
|
| 389 |
+
)
|
| 390 |
|
| 391 |
try:
|
| 392 |
_msgs, _tools = with_prompt_caching(
|
|
|
|
| 407 |
model=research_model,
|
| 408 |
response=response,
|
| 409 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 410 |
+
finish_reason=response.choices[0].finish_reason
|
| 411 |
+
if response.choices
|
| 412 |
+
else None,
|
| 413 |
kind="research",
|
| 414 |
)
|
| 415 |
except Exception as _telem_err:
|
|
|
|
| 437 |
# LiteLLM's raw Message carries `provider_specific_fields` and
|
| 438 |
# `reasoning_content`, which the HF router's OpenAI schema rejects
|
| 439 |
# if we echo them back in the next request.
|
| 440 |
+
messages.append(
|
| 441 |
+
Message(
|
| 442 |
+
role="assistant",
|
| 443 |
+
content=msg.content,
|
| 444 |
+
tool_calls=msg.tool_calls,
|
| 445 |
+
)
|
| 446 |
+
)
|
| 447 |
for tc in msg.tool_calls:
|
| 448 |
try:
|
| 449 |
tool_args = json.loads(tc.function.arguments)
|
|
|
|
| 498 |
|
| 499 |
# ── Iteration limit: try to salvage findings ──
|
| 500 |
await _log("Iteration limit reached — extracting summary")
|
| 501 |
+
messages.append(
|
| 502 |
+
Message(
|
| 503 |
+
role="user",
|
| 504 |
+
content=(
|
| 505 |
+
"[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research "
|
| 506 |
+
"iterations. Summarize ALL findings so far. Do NOT call any more tools."
|
| 507 |
+
),
|
| 508 |
+
)
|
| 509 |
+
)
|
| 510 |
try:
|
| 511 |
_msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
|
| 512 |
_t0 = time.monotonic()
|
|
|
|
| 523 |
model=research_model,
|
| 524 |
response=response,
|
| 525 |
latency_ms=int((time.monotonic() - _t0) * 1000),
|
| 526 |
+
finish_reason=response.choices[0].finish_reason
|
| 527 |
+
if response.choices
|
| 528 |
+
else None,
|
| 529 |
kind="research",
|
| 530 |
)
|
| 531 |
except Exception as _telem_err:
|
agent/tools/sandbox_client.py
CHANGED
|
@@ -729,9 +729,7 @@ class Sandbox:
|
|
| 729 |
runtime, "requested_hardware", None
|
| 730 |
)
|
| 731 |
if current_hardware != hardware:
|
| 732 |
-
_log(
|
| 733 |
-
f" RUNNING on {current_hardware}; waiting for {hardware}..."
|
| 734 |
-
)
|
| 735 |
time.sleep(WAIT_INTERVAL)
|
| 736 |
continue
|
| 737 |
_log(f"Space is running (hardware: {runtime.hardware})")
|
|
@@ -767,7 +765,9 @@ class Sandbox:
|
|
| 767 |
return sb
|
| 768 |
|
| 769 |
@staticmethod
|
| 770 |
-
def _setup_server(
|
|
|
|
|
|
|
| 771 |
"""Upload embedded sandbox server + Dockerfile to the Space (single commit)."""
|
| 772 |
log(f"Uploading sandbox server to {space_id}...")
|
| 773 |
api.create_commit(
|
|
@@ -809,7 +809,9 @@ class Sandbox:
|
|
| 809 |
sb._wait_for_api(timeout=60)
|
| 810 |
return sb
|
| 811 |
|
| 812 |
-
def _wait_for_api(
|
|
|
|
|
|
|
| 813 |
"""Poll the health endpoint until the server responds."""
|
| 814 |
deadline = time.time() + timeout
|
| 815 |
last_err = None
|
|
@@ -986,7 +988,12 @@ class Sandbox:
|
|
| 986 |
return result
|
| 987 |
|
| 988 |
def edit(
|
| 989 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 990 |
mode: str = "replace",
|
| 991 |
) -> ToolResult:
|
| 992 |
if old_str == new_str:
|
|
|
|
| 729 |
runtime, "requested_hardware", None
|
| 730 |
)
|
| 731 |
if current_hardware != hardware:
|
| 732 |
+
_log(f" RUNNING on {current_hardware}; waiting for {hardware}...")
|
|
|
|
|
|
|
| 733 |
time.sleep(WAIT_INTERVAL)
|
| 734 |
continue
|
| 735 |
_log(f"Space is running (hardware: {runtime.hardware})")
|
|
|
|
| 765 |
return sb
|
| 766 |
|
| 767 |
@staticmethod
|
| 768 |
+
def _setup_server(
|
| 769 |
+
space_id: str, api: HfApi, *, log: Callable[[str], object] = print
|
| 770 |
+
) -> None:
|
| 771 |
"""Upload embedded sandbox server + Dockerfile to the Space (single commit)."""
|
| 772 |
log(f"Uploading sandbox server to {space_id}...")
|
| 773 |
api.create_commit(
|
|
|
|
| 809 |
sb._wait_for_api(timeout=60)
|
| 810 |
return sb
|
| 811 |
|
| 812 |
+
def _wait_for_api(
|
| 813 |
+
self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print
|
| 814 |
+
):
|
| 815 |
"""Poll the health endpoint until the server responds."""
|
| 816 |
deadline = time.time() + timeout
|
| 817 |
last_err = None
|
|
|
|
| 988 |
return result
|
| 989 |
|
| 990 |
def edit(
|
| 991 |
+
self,
|
| 992 |
+
path: str,
|
| 993 |
+
old_str: str,
|
| 994 |
+
new_str: str,
|
| 995 |
+
*,
|
| 996 |
+
replace_all: bool = False,
|
| 997 |
mode: str = "replace",
|
| 998 |
) -> ToolResult:
|
| 999 |
if old_str == new_str:
|
agent/tools/sandbox_tool.py
CHANGED
|
@@ -21,6 +21,7 @@ from typing import Any
|
|
| 21 |
|
| 22 |
from huggingface_hub import HfApi, SpaceHardware
|
| 23 |
|
|
|
|
| 24 |
from agent.core.session import Event
|
| 25 |
from agent.tools.sandbox_client import Sandbox
|
| 26 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
|
@@ -197,7 +198,9 @@ def _cleanup_user_orphan_sandboxes(
|
|
| 197 |
if not _SANDBOX_NAME_RE.match(space_name):
|
| 198 |
continue
|
| 199 |
|
| 200 |
-
last_mod = getattr(space, "lastModified", None) or getattr(
|
|
|
|
|
|
|
| 201 |
if isinstance(last_mod, str):
|
| 202 |
try:
|
| 203 |
last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
|
|
@@ -337,6 +340,7 @@ async def _create_sandbox_locked(
|
|
| 337 |
if hardware != DEFAULT_CPU_SANDBOX_HARDWARE:
|
| 338 |
kwargs["sleep_time"] = 2700
|
| 339 |
import time as _t
|
|
|
|
| 340 |
_t_start = _t.monotonic()
|
| 341 |
try:
|
| 342 |
sb = await asyncio.to_thread(Sandbox.create, **kwargs)
|
|
@@ -350,7 +354,9 @@ async def _create_sandbox_locked(
|
|
| 350 |
try:
|
| 351 |
await asyncio.to_thread(sb.delete)
|
| 352 |
except Exception as e:
|
| 353 |
-
logger.warning(
|
|
|
|
|
|
|
| 354 |
return None, "Sandbox creation cancelled by user."
|
| 355 |
|
| 356 |
session.sandbox = sb
|
|
@@ -360,8 +366,11 @@ async def _create_sandbox_locked(
|
|
| 360 |
|
| 361 |
# Telemetry: sandbox creation (infra consumption signal)
|
| 362 |
from agent.core import telemetry
|
|
|
|
| 363 |
await telemetry.record_sandbox_create(
|
| 364 |
-
session,
|
|
|
|
|
|
|
| 365 |
create_latency_s=int(_t.monotonic() - _t_start),
|
| 366 |
)
|
| 367 |
|
|
@@ -510,12 +519,13 @@ async def teardown_session_sandbox(session: Any) -> None:
|
|
| 510 |
)
|
| 511 |
await asyncio.to_thread(sandbox.delete)
|
| 512 |
from agent.core import telemetry
|
|
|
|
| 513 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
| 514 |
return
|
| 515 |
except Exception as e:
|
| 516 |
last_err = e
|
| 517 |
if attempt < 2:
|
| 518 |
-
await asyncio.sleep(2
|
| 519 |
logger.error(
|
| 520 |
"Failed to delete sandbox %s after 3 attempts: %s. "
|
| 521 |
"Orphan — sweep script will pick it up.",
|
|
@@ -720,6 +730,14 @@ def _make_tool_handler(sandbox_tool_name: str):
|
|
| 720 |
return "Sandbox is still starting. Please retry shortly.", False
|
| 721 |
|
| 722 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
|
| 724 |
if result.success:
|
| 725 |
output = result.output or "(no output)"
|
|
@@ -758,8 +776,7 @@ def get_sandbox_tools():
|
|
| 758 |
description = (
|
| 759 |
"Uses the session's active sandbox. A private cpu-basic sandbox is "
|
| 760 |
"started automatically for normal CPU work; call sandbox_create only "
|
| 761 |
-
"for GPU or other non-default hardware.\n\n"
|
| 762 |
-
+ spec["description"]
|
| 763 |
)
|
| 764 |
tools.append(
|
| 765 |
ToolSpec(
|
|
|
|
| 21 |
|
| 22 |
from huggingface_hub import HfApi, SpaceHardware
|
| 23 |
|
| 24 |
+
from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
|
| 25 |
from agent.core.session import Event
|
| 26 |
from agent.tools.sandbox_client import Sandbox
|
| 27 |
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
|
|
|
| 198 |
if not _SANDBOX_NAME_RE.match(space_name):
|
| 199 |
continue
|
| 200 |
|
| 201 |
+
last_mod = getattr(space, "lastModified", None) or getattr(
|
| 202 |
+
space, "last_modified", None
|
| 203 |
+
)
|
| 204 |
if isinstance(last_mod, str):
|
| 205 |
try:
|
| 206 |
last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
|
|
|
|
| 340 |
if hardware != DEFAULT_CPU_SANDBOX_HARDWARE:
|
| 341 |
kwargs["sleep_time"] = 2700
|
| 342 |
import time as _t
|
| 343 |
+
|
| 344 |
_t_start = _t.monotonic()
|
| 345 |
try:
|
| 346 |
sb = await asyncio.to_thread(Sandbox.create, **kwargs)
|
|
|
|
| 354 |
try:
|
| 355 |
await asyncio.to_thread(sb.delete)
|
| 356 |
except Exception as e:
|
| 357 |
+
logger.warning(
|
| 358 |
+
"Failed to delete cancelled sandbox %s: %s", sb.space_id, e
|
| 359 |
+
)
|
| 360 |
return None, "Sandbox creation cancelled by user."
|
| 361 |
|
| 362 |
session.sandbox = sb
|
|
|
|
| 366 |
|
| 367 |
# Telemetry: sandbox creation (infra consumption signal)
|
| 368 |
from agent.core import telemetry
|
| 369 |
+
|
| 370 |
await telemetry.record_sandbox_create(
|
| 371 |
+
session,
|
| 372 |
+
sb,
|
| 373 |
+
hardware=hardware,
|
| 374 |
create_latency_s=int(_t.monotonic() - _t_start),
|
| 375 |
)
|
| 376 |
|
|
|
|
| 519 |
)
|
| 520 |
await asyncio.to_thread(sandbox.delete)
|
| 521 |
from agent.core import telemetry
|
| 522 |
+
|
| 523 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
| 524 |
return
|
| 525 |
except Exception as e:
|
| 526 |
last_err = e
|
| 527 |
if attempt < 2:
|
| 528 |
+
await asyncio.sleep(2**attempt)
|
| 529 |
logger.error(
|
| 530 |
"Failed to delete sandbox %s after 3 attempts: %s. "
|
| 531 |
"Orphan — sweep script will pick it up.",
|
|
|
|
| 730 |
return "Sandbox is still starting. Please retry shortly.", False
|
| 731 |
|
| 732 |
try:
|
| 733 |
+
if sandbox_tool_name == "bash" and args.get("command"):
|
| 734 |
+
args = {
|
| 735 |
+
**args,
|
| 736 |
+
"command": wrap_shell_command_with_hub_artifact_bootstrap(
|
| 737 |
+
args["command"],
|
| 738 |
+
session,
|
| 739 |
+
),
|
| 740 |
+
}
|
| 741 |
result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
|
| 742 |
if result.success:
|
| 743 |
output = result.output or "(no output)"
|
|
|
|
| 776 |
description = (
|
| 777 |
"Uses the session's active sandbox. A private cpu-basic sandbox is "
|
| 778 |
"started automatically for normal CPU work; call sandbox_create only "
|
| 779 |
+
"for GPU or other non-default hardware.\n\n" + spec["description"]
|
|
|
|
| 780 |
)
|
| 781 |
tools.append(
|
| 782 |
ToolSpec(
|
agent/tools/web_search_tool.py
CHANGED
|
@@ -253,7 +253,10 @@ async def web_search_handler(
|
|
| 253 |
) -> tuple[str, bool]:
|
| 254 |
query_value = arguments.get("query", "")
|
| 255 |
if not isinstance(query_value, str):
|
| 256 |
-
return
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
query = query_value.strip()
|
| 259 |
if len(query) < 2:
|
|
|
|
| 253 |
) -> tuple[str, bool]:
|
| 254 |
query_value = arguments.get("query", "")
|
| 255 |
if not isinstance(query_value, str):
|
| 256 |
+
return (
|
| 257 |
+
"Error: web_search requires a query string with at least 2 characters.",
|
| 258 |
+
False,
|
| 259 |
+
)
|
| 260 |
|
| 261 |
query = query_value.strip()
|
| 262 |
if len(query) < 2:
|
agent/utils/braille.py
CHANGED
|
@@ -41,8 +41,7 @@ class BrailleCanvas:
|
|
| 41 |
for row in range(self.term_height):
|
| 42 |
offset = row * self.term_width
|
| 43 |
line = "".join(
|
| 44 |
-
chr(0x2800 + self._buf[offset + col])
|
| 45 |
-
for col in range(self.term_width)
|
| 46 |
)
|
| 47 |
lines.append(line)
|
| 48 |
return lines
|
|
@@ -52,6 +51,7 @@ class BrailleCanvas:
|
|
| 52 |
|
| 53 |
_FONT: dict[str, list[str]] = {}
|
| 54 |
|
|
|
|
| 55 |
def _define_font() -> None:
|
| 56 |
"""Define a simple 5×7 bitmap font for uppercase ASCII."""
|
| 57 |
glyphs = {
|
|
@@ -113,8 +113,9 @@ def text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]:
|
|
| 113 |
if cell == "#":
|
| 114 |
for sy in range(scale):
|
| 115 |
for sx in range(scale):
|
| 116 |
-
pixels.append(
|
| 117 |
-
|
|
|
|
| 118 |
glyph_width = max(len(r) for r in glyph)
|
| 119 |
cursor_x += (glyph_width + 1) * scale
|
| 120 |
return pixels
|
|
|
|
| 41 |
for row in range(self.term_height):
|
| 42 |
offset = row * self.term_width
|
| 43 |
line = "".join(
|
| 44 |
+
chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width)
|
|
|
|
| 45 |
)
|
| 46 |
lines.append(line)
|
| 47 |
return lines
|
|
|
|
| 51 |
|
| 52 |
_FONT: dict[str, list[str]] = {}
|
| 53 |
|
| 54 |
+
|
| 55 |
def _define_font() -> None:
|
| 56 |
"""Define a simple 5×7 bitmap font for uppercase ASCII."""
|
| 57 |
glyphs = {
|
|
|
|
| 113 |
if cell == "#":
|
| 114 |
for sy in range(scale):
|
| 115 |
for sx in range(scale):
|
| 116 |
+
pixels.append(
|
| 117 |
+
(cursor_x + col_idx * scale + sx, row_idx * scale + sy)
|
| 118 |
+
)
|
| 119 |
glyph_width = max(len(r) for r in glyph)
|
| 120 |
cursor_x += (glyph_width + 1) * scale
|
| 121 |
return pixels
|
agent/utils/crt_boot.py
CHANGED
|
@@ -55,7 +55,10 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No
|
|
| 55 |
# Render previously completed lines
|
| 56 |
for prev_text, prev_style in displayed_lines:
|
| 57 |
if rng.random() < prev_glitch_chance:
|
| 58 |
-
result.append(
|
|
|
|
|
|
|
|
|
|
| 59 |
else:
|
| 60 |
result.append(prev_text, style=prev_style)
|
| 61 |
result.append("\n")
|
|
@@ -86,7 +89,7 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No
|
|
| 86 |
live.update(result)
|
| 87 |
|
| 88 |
# Variable typing speed
|
| 89 |
-
if line_text[char_idx - 1:char_idx] in " .":
|
| 90 |
time.sleep(0.025)
|
| 91 |
else:
|
| 92 |
time.sleep(0.010)
|
|
|
|
| 55 |
# Render previously completed lines
|
| 56 |
for prev_text, prev_style in displayed_lines:
|
| 57 |
if rng.random() < prev_glitch_chance:
|
| 58 |
+
result.append(
|
| 59 |
+
_glitch_text(prev_text, prev_glitch_intensity, rng),
|
| 60 |
+
style=prev_style,
|
| 61 |
+
)
|
| 62 |
else:
|
| 63 |
result.append(prev_text, style=prev_style)
|
| 64 |
result.append("\n")
|
|
|
|
| 89 |
live.update(result)
|
| 90 |
|
| 91 |
# Variable typing speed
|
| 92 |
+
if line_text[char_idx - 1 : char_idx] in " .":
|
| 93 |
time.sleep(0.025)
|
| 94 |
else:
|
| 95 |
time.sleep(0.010)
|
agent/utils/particle_logo.py
CHANGED
|
@@ -23,7 +23,9 @@ from agent.utils.boot_timing import settle_curve, warm_gold_from_white
|
|
| 23 |
class Particle:
|
| 24 |
__slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay")
|
| 25 |
|
| 26 |
-
def __init__(
|
|
|
|
|
|
|
| 27 |
self.x = x
|
| 28 |
self.y = y
|
| 29 |
self.target_x = target_x
|
|
|
|
| 23 |
class Particle:
|
| 24 |
__slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay")
|
| 25 |
|
| 26 |
+
def __init__(
|
| 27 |
+
self, x: float, y: float, target_x: float, target_y: float, delay: float = 0
|
| 28 |
+
):
|
| 29 |
self.x = x
|
| 30 |
self.y = y
|
| 31 |
self.target_x = target_x
|
agent/utils/terminal_display.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
Terminal display utilities — rich-powered CLI formatting.
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
import re
|
| 6 |
|
| 7 |
from rich.console import Console
|
|
@@ -57,23 +58,26 @@ def _clip_to_width(s: str, width: int) -> str:
|
|
| 57 |
out.append("\033[0m…")
|
| 58 |
return "".join(out)
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
_console = Console(theme=_THEME, highlight=False)
|
| 79 |
|
|
@@ -87,6 +91,7 @@ def get_console() -> Console:
|
|
| 87 |
|
| 88 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 89 |
|
|
|
|
| 90 |
def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
|
| 91 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 92 |
from agent.utils.particle_logo import run_particle_logo
|
|
@@ -120,12 +125,16 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
|
|
| 120 |
|
| 121 |
# ── Init progress ──────────────────────────────────────────────────────
|
| 122 |
|
|
|
|
| 123 |
def print_init_done(tool_count: int = 0) -> None:
|
| 124 |
import time
|
|
|
|
| 125 |
f = _console.file
|
| 126 |
# Overwrite the "Tools: loading..." line with actual count
|
| 127 |
-
f.write(
|
| 128 |
-
|
|
|
|
|
|
|
| 129 |
gold = "\033[38;2;180;140;40m"
|
| 130 |
reset = "\033[0m"
|
| 131 |
tool_text = f"{_I} Tools: {tool_count} loaded"
|
|
@@ -135,16 +144,22 @@ def print_init_done(tool_count: int = 0) -> None:
|
|
| 135 |
time.sleep(0.012)
|
| 136 |
f.write("\n\n")
|
| 137 |
# Reprint the help line
|
| 138 |
-
f.write(
|
|
|
|
|
|
|
| 139 |
# Ready message — minimal padding
|
| 140 |
-
f.write(
|
|
|
|
|
|
|
| 141 |
f.flush()
|
| 142 |
|
| 143 |
|
| 144 |
# ── Tool calls ─────────────────────────────────────────────────────────
|
| 145 |
|
|
|
|
| 146 |
def print_tool_call(tool_name: str, args_preview: str) -> None:
|
| 147 |
import time
|
|
|
|
| 148 |
f = _console.file
|
| 149 |
# CRT-style: type out tool name in HF yellow
|
| 150 |
gold = "\033[38;2;255;200;80m"
|
|
@@ -183,6 +198,7 @@ class SubAgentDisplayManager:
|
|
| 183 |
|
| 184 |
def start(self, agent_id: str, label: str = "research") -> None:
|
| 185 |
import time
|
|
|
|
| 186 |
self._agents[agent_id] = {
|
| 187 |
"label": label,
|
| 188 |
"calls": [],
|
|
@@ -234,6 +250,7 @@ class SubAgentDisplayManager:
|
|
| 234 |
@staticmethod
|
| 235 |
def _format_stats(agent: dict) -> str:
|
| 236 |
import time
|
|
|
|
| 237 |
start = agent["start_time"]
|
| 238 |
if start is None:
|
| 239 |
return ""
|
|
@@ -276,7 +293,7 @@ class SubAgentDisplayManager:
|
|
| 276 |
header += f" \033[2m·\033[0m \033[2m{short}\033[0m"
|
| 277 |
return [header]
|
| 278 |
lines = [header]
|
| 279 |
-
visible = agent["calls"][-self._MAX_VISIBLE:]
|
| 280 |
for desc in visible:
|
| 281 |
lines.append(f"{_I} \033[2m{desc}\033[0m")
|
| 282 |
return lines
|
|
@@ -319,13 +336,14 @@ def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") ->
|
|
| 319 |
|
| 320 |
# ── Messages ───────────────────────────────────────────────────────────
|
| 321 |
|
|
|
|
| 322 |
async def print_markdown(
|
| 323 |
text: str,
|
| 324 |
cancel_event: "asyncio.Event | None" = None,
|
| 325 |
instant: bool = False,
|
| 326 |
) -> None:
|
| 327 |
-
import
|
| 328 |
-
import
|
| 329 |
from rich.padding import Padding
|
| 330 |
|
| 331 |
_console.print()
|
|
@@ -395,23 +413,35 @@ def print_interrupted() -> None:
|
|
| 395 |
|
| 396 |
|
| 397 |
def print_compacted(old_tokens: int, new_tokens: int) -> None:
|
| 398 |
-
_console.print(
|
|
|
|
|
|
|
| 399 |
|
| 400 |
|
| 401 |
# ── Approval ───────────────────────────────────────────────────────────
|
| 402 |
|
|
|
|
| 403 |
def print_approval_header(count: int) -> None:
|
| 404 |
label = f"Approval required — {count} item{'s' if count != 1 else ''}"
|
| 405 |
_console.print()
|
| 406 |
-
_console.print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
|
| 409 |
def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None:
|
| 410 |
-
_console.print(
|
|
|
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
def print_yolo_approve(count: int) -> None:
|
| 414 |
-
_console.print(
|
|
|
|
|
|
|
| 415 |
|
| 416 |
|
| 417 |
# ── Help ───────────────────────────────────────────────────────────────
|
|
@@ -437,6 +467,7 @@ def print_help() -> None:
|
|
| 437 |
|
| 438 |
# ── Plan display ───────────────────────────────────────────────────────
|
| 439 |
|
|
|
|
| 440 |
def format_plan_display() -> str:
|
| 441 |
"""Format the current plan for display."""
|
| 442 |
from agent.tools.plan_tool import get_current_plan
|
|
@@ -470,6 +501,7 @@ def print_plan() -> None:
|
|
| 470 |
|
| 471 |
# ── Formatting for plan_tool output (used by plan_tool handler) ────────
|
| 472 |
|
|
|
|
| 473 |
def format_plan_tool_output(todos: list) -> str:
|
| 474 |
if not todos:
|
| 475 |
return "Plan is empty."
|
|
@@ -492,6 +524,7 @@ def format_plan_tool_output(todos: list) -> str:
|
|
| 492 |
|
| 493 |
# ── Internal helpers ───────────────────────────────────────────────────
|
| 494 |
|
|
|
|
| 495 |
def _truncate(text: str, max_lines: int = 6) -> str:
|
| 496 |
lines = text.split("\n")
|
| 497 |
if len(lines) <= max_lines:
|
|
|
|
| 2 |
Terminal display utilities — rich-powered CLI formatting.
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import asyncio
|
| 6 |
import re
|
| 7 |
|
| 8 |
from rich.console import Console
|
|
|
|
| 58 |
out.append("\033[0m…")
|
| 59 |
return "".join(out)
|
| 60 |
|
| 61 |
+
|
| 62 |
+
_THEME = Theme(
|
| 63 |
+
{
|
| 64 |
+
"tool.name": "bold rgb(255,200,80)",
|
| 65 |
+
"tool.args": "dim",
|
| 66 |
+
"tool.ok": "dim green",
|
| 67 |
+
"tool.fail": "dim red",
|
| 68 |
+
"info": "dim",
|
| 69 |
+
"muted": "dim",
|
| 70 |
+
# Markdown emphasis colors
|
| 71 |
+
"markdown.strong": "bold rgb(255,200,80)",
|
| 72 |
+
"markdown.emphasis": "italic rgb(180,140,40)",
|
| 73 |
+
"markdown.code": "rgb(120,220,255)",
|
| 74 |
+
"markdown.code_block": "rgb(120,220,255)",
|
| 75 |
+
"markdown.link": "underline rgb(90,180,255)",
|
| 76 |
+
"markdown.h1": "bold rgb(255,200,80)",
|
| 77 |
+
"markdown.h2": "bold rgb(240,180,95)",
|
| 78 |
+
"markdown.h3": "bold rgb(220,165,100)",
|
| 79 |
+
}
|
| 80 |
+
)
|
| 81 |
|
| 82 |
_console = Console(theme=_THEME, highlight=False)
|
| 83 |
|
|
|
|
| 91 |
|
| 92 |
# ── Banner ─────────────────────────────────────────────────────────────
|
| 93 |
|
| 94 |
+
|
| 95 |
def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
|
| 96 |
"""Print particle logo then CRT boot sequence with system info."""
|
| 97 |
from agent.utils.particle_logo import run_particle_logo
|
|
|
|
| 125 |
|
| 126 |
# ── Init progress ──────────────────────────────────────────────────────
|
| 127 |
|
| 128 |
+
|
| 129 |
def print_init_done(tool_count: int = 0) -> None:
|
| 130 |
import time
|
| 131 |
+
|
| 132 |
f = _console.file
|
| 133 |
# Overwrite the "Tools: loading..." line with actual count
|
| 134 |
+
f.write(
|
| 135 |
+
"\033[A\033[A\033[A\033[K"
|
| 136 |
+
) # Move up 3 lines (blank + help + blank) then up to tools line
|
| 137 |
+
f.write("\033[A\033[K")
|
| 138 |
gold = "\033[38;2;180;140;40m"
|
| 139 |
reset = "\033[0m"
|
| 140 |
tool_text = f"{_I} Tools: {tool_count} loaded"
|
|
|
|
| 144 |
time.sleep(0.012)
|
| 145 |
f.write("\n\n")
|
| 146 |
# Reprint the help line
|
| 147 |
+
f.write(
|
| 148 |
+
f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n"
|
| 149 |
+
)
|
| 150 |
# Ready message — minimal padding
|
| 151 |
+
f.write(
|
| 152 |
+
f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n"
|
| 153 |
+
)
|
| 154 |
f.flush()
|
| 155 |
|
| 156 |
|
| 157 |
# ── Tool calls ─────────────────────────────────────────────────────────
|
| 158 |
|
| 159 |
+
|
| 160 |
def print_tool_call(tool_name: str, args_preview: str) -> None:
|
| 161 |
import time
|
| 162 |
+
|
| 163 |
f = _console.file
|
| 164 |
# CRT-style: type out tool name in HF yellow
|
| 165 |
gold = "\033[38;2;255;200;80m"
|
|
|
|
| 198 |
|
| 199 |
def start(self, agent_id: str, label: str = "research") -> None:
|
| 200 |
import time
|
| 201 |
+
|
| 202 |
self._agents[agent_id] = {
|
| 203 |
"label": label,
|
| 204 |
"calls": [],
|
|
|
|
| 250 |
@staticmethod
|
| 251 |
def _format_stats(agent: dict) -> str:
|
| 252 |
import time
|
| 253 |
+
|
| 254 |
start = agent["start_time"]
|
| 255 |
if start is None:
|
| 256 |
return ""
|
|
|
|
| 293 |
header += f" \033[2m·\033[0m \033[2m{short}\033[0m"
|
| 294 |
return [header]
|
| 295 |
lines = [header]
|
| 296 |
+
visible = agent["calls"][-self._MAX_VISIBLE :]
|
| 297 |
for desc in visible:
|
| 298 |
lines.append(f"{_I} \033[2m{desc}\033[0m")
|
| 299 |
return lines
|
|
|
|
| 336 |
|
| 337 |
# ── Messages ───────────────────────────────────────────────────────────
|
| 338 |
|
| 339 |
+
|
| 340 |
async def print_markdown(
|
| 341 |
text: str,
|
| 342 |
cancel_event: "asyncio.Event | None" = None,
|
| 343 |
instant: bool = False,
|
| 344 |
) -> None:
|
| 345 |
+
import io
|
| 346 |
+
import random
|
| 347 |
from rich.padding import Padding
|
| 348 |
|
| 349 |
_console.print()
|
|
|
|
| 413 |
|
| 414 |
|
| 415 |
def print_compacted(old_tokens: int, new_tokens: int) -> None:
|
| 416 |
+
_console.print(
|
| 417 |
+
f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]"
|
| 418 |
+
)
|
| 419 |
|
| 420 |
|
| 421 |
# ── Approval ───────────────────────────────────────────────────────────
|
| 422 |
|
| 423 |
+
|
| 424 |
def print_approval_header(count: int) -> None:
|
| 425 |
label = f"Approval required — {count} item{'s' if count != 1 else ''}"
|
| 426 |
_console.print()
|
| 427 |
+
_console.print(
|
| 428 |
+
f"{_I}",
|
| 429 |
+
Panel(
|
| 430 |
+
f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False
|
| 431 |
+
),
|
| 432 |
+
)
|
| 433 |
|
| 434 |
|
| 435 |
def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None:
|
| 436 |
+
_console.print(
|
| 437 |
+
f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}"
|
| 438 |
+
)
|
| 439 |
|
| 440 |
|
| 441 |
def print_yolo_approve(count: int) -> None:
|
| 442 |
+
_console.print(
|
| 443 |
+
f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)"
|
| 444 |
+
)
|
| 445 |
|
| 446 |
|
| 447 |
# ── Help ───────────────────────────────────────────────────────────────
|
|
|
|
| 467 |
|
| 468 |
# ── Plan display ───────────────────────────────────────────────────────
|
| 469 |
|
| 470 |
+
|
| 471 |
def format_plan_display() -> str:
|
| 472 |
"""Format the current plan for display."""
|
| 473 |
from agent.tools.plan_tool import get_current_plan
|
|
|
|
| 501 |
|
| 502 |
# ── Formatting for plan_tool output (used by plan_tool handler) ────────
|
| 503 |
|
| 504 |
+
|
| 505 |
def format_plan_tool_output(todos: list) -> str:
|
| 506 |
if not todos:
|
| 507 |
return "Plan is empty."
|
|
|
|
| 524 |
|
| 525 |
# ── Internal helpers ───────────────────────────────────────────────────
|
| 526 |
|
| 527 |
+
|
| 528 |
def _truncate(text: str, max_lines: int = 6) -> str:
|
| 529 |
lines = text.split("\n")
|
| 530 |
if len(lines) <= max_lines:
|
backend/dependencies.py
CHANGED
|
@@ -102,7 +102,9 @@ async def _fetch_user_plan(token: str) -> str:
|
|
| 102 |
_WHOAMI_SHAPE_LOGGED = True
|
| 103 |
logger.debug(
|
| 104 |
"whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)",
|
| 105 |
-
sorted(whoami.keys())
|
|
|
|
|
|
|
| 106 |
whoami.get("plan") if isinstance(whoami, dict) else None,
|
| 107 |
whoami.get("type") if isinstance(whoami, dict) else None,
|
| 108 |
whoami.get("isPro") if isinstance(whoami, dict) else None,
|
|
|
|
| 102 |
_WHOAMI_SHAPE_LOGGED = True
|
| 103 |
logger.debug(
|
| 104 |
"whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)",
|
| 105 |
+
sorted(whoami.keys())
|
| 106 |
+
if isinstance(whoami, dict)
|
| 107 |
+
else type(whoami).__name__,
|
| 108 |
whoami.get("plan") if isinstance(whoami, dict) else None,
|
| 109 |
whoami.get("type") if isinstance(whoami, dict) else None,
|
| 110 |
whoami.get("isPro") if isinstance(whoami, dict) else None,
|
backend/kpis_scheduler.py
CHANGED
|
@@ -58,7 +58,8 @@ def _resolve_token() -> Optional[str]:
|
|
| 58 |
def _load_build_kpis():
|
| 59 |
"""Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
|
| 60 |
spec = importlib.util.spec_from_file_location(
|
| 61 |
-
"build_kpis",
|
|
|
|
| 62 |
)
|
| 63 |
mod = importlib.util.module_from_spec(spec)
|
| 64 |
assert spec.loader is not None
|
|
@@ -75,6 +76,7 @@ async def _run_hour(hour_dt: datetime) -> None:
|
|
| 75 |
try:
|
| 76 |
mod = _load_build_kpis()
|
| 77 |
from huggingface_hub import HfApi
|
|
|
|
| 78 |
api = HfApi()
|
| 79 |
source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
|
| 80 |
target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
|
|
@@ -118,7 +120,7 @@ def start(backfill_hours: int = 6) -> None:
|
|
| 118 |
CronTrigger(minute=5),
|
| 119 |
id="kpis_hourly",
|
| 120 |
misfire_grace_time=600, # tolerate a 10-min misfire window
|
| 121 |
-
coalesce=True,
|
| 122 |
max_instances=1,
|
| 123 |
replace_existing=True,
|
| 124 |
)
|
|
|
|
| 58 |
def _load_build_kpis():
|
| 59 |
"""Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
|
| 60 |
spec = importlib.util.spec_from_file_location(
|
| 61 |
+
"build_kpis",
|
| 62 |
+
_PROJECT_ROOT / "scripts" / "build_kpis.py",
|
| 63 |
)
|
| 64 |
mod = importlib.util.module_from_spec(spec)
|
| 65 |
assert spec.loader is not None
|
|
|
|
| 76 |
try:
|
| 77 |
mod = _load_build_kpis()
|
| 78 |
from huggingface_hub import HfApi
|
| 79 |
+
|
| 80 |
api = HfApi()
|
| 81 |
source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
|
| 82 |
target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
|
|
|
|
| 120 |
CronTrigger(minute=5),
|
| 121 |
id="kpis_hourly",
|
| 122 |
misfire_grace_time=600, # tolerate a 10-min misfire window
|
| 123 |
+
coalesce=True, # collapse multiple missed fires into one
|
| 124 |
max_instances=1,
|
| 125 |
replace_existing=True,
|
| 126 |
)
|
backend/main.py
CHANGED
|
@@ -6,17 +6,17 @@ from contextlib import asynccontextmanager
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Load .env before importing routes/session_manager so persistence and quota
|
| 11 |
# modules see local Mongo settings during startup.
|
| 12 |
load_dotenv(Path(__file__).parent.parent / ".env")
|
| 13 |
|
| 14 |
-
from
|
| 15 |
-
from
|
| 16 |
-
from
|
| 17 |
-
from routes.agent import router as agent_router
|
| 18 |
-
from routes.auth import router as auth_router
|
| 19 |
-
from session_manager import session_manager
|
| 20 |
|
| 21 |
# Configure logging
|
| 22 |
logging.basicConfig(
|
|
@@ -35,6 +35,7 @@ async def lifespan(app: FastAPI):
|
|
| 35 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 36 |
try:
|
| 37 |
import kpis_scheduler
|
|
|
|
| 38 |
kpis_scheduler.start()
|
| 39 |
except Exception as e:
|
| 40 |
logger.warning("KPI scheduler failed to start: %s", e)
|
|
@@ -43,6 +44,7 @@ async def lifespan(app: FastAPI):
|
|
| 43 |
logger.info("Shutting down HF Agent backend...")
|
| 44 |
try:
|
| 45 |
import kpis_scheduler
|
|
|
|
| 46 |
await kpis_scheduler.shutdown()
|
| 47 |
except Exception as e:
|
| 48 |
logger.warning("KPI scheduler shutdown failed: %s", e)
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
+
from fastapi import FastAPI
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from fastapi.staticfiles import StaticFiles
|
| 12 |
|
| 13 |
# Load .env before importing routes/session_manager so persistence and quota
|
| 14 |
# modules see local Mongo settings during startup.
|
| 15 |
load_dotenv(Path(__file__).parent.parent / ".env")
|
| 16 |
|
| 17 |
+
from routes.agent import router as agent_router # noqa: E402
|
| 18 |
+
from routes.auth import router as auth_router # noqa: E402
|
| 19 |
+
from session_manager import session_manager # noqa: E402
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Configure logging
|
| 22 |
logging.basicConfig(
|
|
|
|
| 35 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 36 |
try:
|
| 37 |
import kpis_scheduler
|
| 38 |
+
|
| 39 |
kpis_scheduler.start()
|
| 40 |
except Exception as e:
|
| 41 |
logger.warning("KPI scheduler failed to start: %s", e)
|
|
|
|
| 44 |
logger.info("Shutting down HF Agent backend...")
|
| 45 |
try:
|
| 46 |
import kpis_scheduler
|
| 47 |
+
|
| 48 |
await kpis_scheduler.shutdown()
|
| 49 |
except Exception as e:
|
| 50 |
logger.warning("KPI scheduler shutdown failed: %s", e)
|
backend/models.py
CHANGED
|
@@ -131,4 +131,6 @@ class LLMHealthResponse(BaseModel):
|
|
| 131 |
status: str # "ok" | "error"
|
| 132 |
model: str
|
| 133 |
error: str | None = None
|
| 134 |
-
error_type: str | None =
|
|
|
|
|
|
|
|
|
| 131 |
status: str # "ok" | "error"
|
| 132 |
model: str
|
| 133 |
error: str | None = None
|
| 134 |
+
error_type: str | None = (
|
| 135 |
+
None # "auth" | "credits" | "rate_limit" | "network" | "unknown"
|
| 136 |
+
)
|
backend/routes/agent.py
CHANGED
|
@@ -7,7 +7,6 @@ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically.
|
|
| 7 |
import asyncio
|
| 8 |
import json
|
| 9 |
import logging
|
| 10 |
-
import os
|
| 11 |
from typing import Any
|
| 12 |
|
| 13 |
from dependencies import (
|
|
@@ -34,7 +33,12 @@ from models import (
|
|
| 34 |
SubmitRequest,
|
| 35 |
TruncateRequest,
|
| 36 |
)
|
| 37 |
-
from session_manager import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
import user_quotas
|
| 40 |
|
|
@@ -136,7 +140,7 @@ async def _require_hf_for_gated_model(request: Request, model_id: str) -> None:
|
|
| 136 |
"""403 if a non-``huggingface``-org user tries to select a gated model.
|
| 137 |
|
| 138 |
Gated models are deployed paid endpoints backed by service-owned
|
| 139 |
-
credentials. The gate only fires for deployed paid models so non-HF users
|
| 140 |
can still freely switch between the free models.
|
| 141 |
"""
|
| 142 |
if not _is_gated_model(model_id):
|
|
@@ -226,7 +230,11 @@ async def _check_session_access(
|
|
| 226 |
preload_sandbox: bool = True,
|
| 227 |
) -> AgentSession:
|
| 228 |
"""Verify and lazily load the user's session. Raises 403 or 404."""
|
| 229 |
-
hf_token =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
agent_session = await session_manager.ensure_session_loaded(
|
| 231 |
session_id,
|
| 232 |
user["user_id"],
|
|
@@ -236,7 +244,10 @@ async def _check_session_access(
|
|
| 236 |
)
|
| 237 |
if not agent_session:
|
| 238 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 239 |
-
if user["user_id"] != "dev" and agent_session.user_id not in {
|
|
|
|
|
|
|
|
|
|
| 240 |
raise HTTPException(status_code=403, detail="Access denied to this session")
|
| 241 |
return agent_session
|
| 242 |
|
|
@@ -362,7 +373,9 @@ async def generate_title(
|
|
| 362 |
await _check_session_access(request.session_id, user)
|
| 363 |
await session_manager.update_session_title(request.session_id, title)
|
| 364 |
except Exception:
|
| 365 |
-
logger.debug(
|
|
|
|
|
|
|
| 366 |
return {"title": title}
|
| 367 |
except Exception as e:
|
| 368 |
logger.warning(f"Title generation failed: {e}")
|
|
@@ -372,7 +385,10 @@ async def generate_title(
|
|
| 372 |
await _check_session_access(request.session_id, user)
|
| 373 |
await session_manager.update_session_title(request.session_id, title)
|
| 374 |
except Exception:
|
| 375 |
-
logger.debug(
|
|
|
|
|
|
|
|
|
|
| 376 |
return {"title": title}
|
| 377 |
|
| 378 |
|
|
@@ -586,7 +602,9 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
|
| 586 |
|
| 587 |
|
| 588 |
@router.get("/user/jobs-access")
|
| 589 |
-
async def get_jobs_access_info(
|
|
|
|
|
|
|
| 590 |
"""Return the namespaces the current token can run HF Jobs under.
|
| 591 |
|
| 592 |
Credits are enforced by the HF API at job-creation time, not here —
|
|
@@ -652,7 +670,7 @@ async def submit_approval(
|
|
| 652 |
request: ApprovalRequest, user: dict = Depends(get_current_user)
|
| 653 |
) -> dict:
|
| 654 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 655 |
-
|
| 656 |
approvals = [
|
| 657 |
{
|
| 658 |
"tool_call_id": a.tool_call_id,
|
|
@@ -719,7 +737,9 @@ async def chat_sse(
|
|
| 719 |
success = await session_manager.submit_user_input(session_id, text)
|
| 720 |
else:
|
| 721 |
broadcaster.unsubscribe(sub_id)
|
| 722 |
-
raise HTTPException(
|
|
|
|
|
|
|
| 723 |
|
| 724 |
if not success:
|
| 725 |
broadcaster.unsubscribe(sub_id)
|
|
@@ -744,6 +764,7 @@ async def record_pro_click(
|
|
| 744 |
agent_session = await _check_session_access(session_id, user)
|
| 745 |
|
| 746 |
from agent.core import telemetry
|
|
|
|
| 747 |
await telemetry.record_pro_cta_click(
|
| 748 |
agent_session.session,
|
| 749 |
source=str(body.get("source") or "unknown"),
|
|
@@ -759,12 +780,20 @@ async def record_pro_click(
|
|
| 759 |
# ---------------------------------------------------------------------------
|
| 760 |
# Shared SSE helpers
|
| 761 |
# ---------------------------------------------------------------------------
|
| 762 |
-
_TERMINAL_EVENTS = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
_SSE_KEEPALIVE_SECONDS = 15
|
| 764 |
|
| 765 |
|
| 766 |
def _last_event_seq(request: Request) -> int:
|
| 767 |
-
raw =
|
|
|
|
|
|
|
| 768 |
try:
|
| 769 |
return max(0, int(raw))
|
| 770 |
except (TypeError, ValueError):
|
|
@@ -853,7 +882,9 @@ async def subscribe_events(
|
|
| 853 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 854 |
|
| 855 |
after_seq = _last_event_seq(request)
|
| 856 |
-
replay_events = await session_manager._store().load_events_after(
|
|
|
|
|
|
|
| 857 |
broadcaster = agent_session.broadcaster
|
| 858 |
sub_id, event_queue = broadcaster.subscribe()
|
| 859 |
return _sse_response(
|
|
@@ -885,7 +916,10 @@ async def get_session_messages(
|
|
| 885 |
agent_session = await _check_session_access(session_id, user)
|
| 886 |
if not agent_session or not agent_session.is_active:
|
| 887 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 888 |
-
return [
|
|
|
|
|
|
|
|
|
|
| 889 |
|
| 890 |
|
| 891 |
@router.post("/undo/{session_id}")
|
|
@@ -906,7 +940,10 @@ async def truncate_session(
|
|
| 906 |
await _check_session_access(session_id, user)
|
| 907 |
success = await session_manager.truncate(session_id, body.user_message_index)
|
| 908 |
if not success:
|
| 909 |
-
raise HTTPException(
|
|
|
|
|
|
|
|
|
|
| 910 |
return {"status": "truncated", "session_id": session_id}
|
| 911 |
|
| 912 |
|
|
@@ -933,6 +970,7 @@ async def shutdown_session(
|
|
| 933 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 934 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 935 |
|
|
|
|
| 936 |
@router.post("/feedback/{session_id}")
|
| 937 |
async def submit_feedback(
|
| 938 |
session_id: str,
|
|
@@ -952,6 +990,7 @@ async def submit_feedback(
|
|
| 952 |
raise HTTPException(status_code=400, detail="invalid rating")
|
| 953 |
|
| 954 |
from agent.core import telemetry
|
|
|
|
| 955 |
await telemetry.record_feedback(
|
| 956 |
agent_session.session,
|
| 957 |
rating=rating,
|
|
|
|
| 7 |
import asyncio
|
| 8 |
import json
|
| 9 |
import logging
|
|
|
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
from dependencies import (
|
|
|
|
| 33 |
SubmitRequest,
|
| 34 |
TruncateRequest,
|
| 35 |
)
|
| 36 |
+
from session_manager import (
|
| 37 |
+
MAX_SESSIONS,
|
| 38 |
+
AgentSession,
|
| 39 |
+
SessionCapacityError,
|
| 40 |
+
session_manager,
|
| 41 |
+
)
|
| 42 |
|
| 43 |
import user_quotas
|
| 44 |
|
|
|
|
| 140 |
"""403 if a non-``huggingface``-org user tries to select a gated model.
|
| 141 |
|
| 142 |
Gated models are deployed paid endpoints backed by service-owned
|
| 143 |
+
credentials. The gate only fires for deployed paid models so non-HF users
|
| 144 |
can still freely switch between the free models.
|
| 145 |
"""
|
| 146 |
if not _is_gated_model(model_id):
|
|
|
|
| 230 |
preload_sandbox: bool = True,
|
| 231 |
) -> AgentSession:
|
| 232 |
"""Verify and lazily load the user's session. Raises 403 or 404."""
|
| 233 |
+
hf_token = (
|
| 234 |
+
resolve_hf_request_token(request)
|
| 235 |
+
if request is not None
|
| 236 |
+
else _user_hf_token(user)
|
| 237 |
+
)
|
| 238 |
agent_session = await session_manager.ensure_session_loaded(
|
| 239 |
session_id,
|
| 240 |
user["user_id"],
|
|
|
|
| 244 |
)
|
| 245 |
if not agent_session:
|
| 246 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 247 |
+
if user["user_id"] != "dev" and agent_session.user_id not in {
|
| 248 |
+
user["user_id"],
|
| 249 |
+
"dev",
|
| 250 |
+
}:
|
| 251 |
raise HTTPException(status_code=403, detail="Access denied to this session")
|
| 252 |
return agent_session
|
| 253 |
|
|
|
|
| 373 |
await _check_session_access(request.session_id, user)
|
| 374 |
await session_manager.update_session_title(request.session_id, title)
|
| 375 |
except Exception:
|
| 376 |
+
logger.debug(
|
| 377 |
+
"Skipping title persistence for missing session %s", request.session_id
|
| 378 |
+
)
|
| 379 |
return {"title": title}
|
| 380 |
except Exception as e:
|
| 381 |
logger.warning(f"Title generation failed: {e}")
|
|
|
|
| 385 |
await _check_session_access(request.session_id, user)
|
| 386 |
await session_manager.update_session_title(request.session_id, title)
|
| 387 |
except Exception:
|
| 388 |
+
logger.debug(
|
| 389 |
+
"Skipping fallback title persistence for missing session %s",
|
| 390 |
+
request.session_id,
|
| 391 |
+
)
|
| 392 |
return {"title": title}
|
| 393 |
|
| 394 |
|
|
|
|
| 602 |
|
| 603 |
|
| 604 |
@router.get("/user/jobs-access")
|
| 605 |
+
async def get_jobs_access_info(
|
| 606 |
+
request: Request, user: dict = Depends(get_current_user)
|
| 607 |
+
) -> dict:
|
| 608 |
"""Return the namespaces the current token can run HF Jobs under.
|
| 609 |
|
| 610 |
Credits are enforced by the HF API at job-creation time, not here —
|
|
|
|
| 670 |
request: ApprovalRequest, user: dict = Depends(get_current_user)
|
| 671 |
) -> dict:
|
| 672 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 673 |
+
await _check_session_access(request.session_id, user)
|
| 674 |
approvals = [
|
| 675 |
{
|
| 676 |
"tool_call_id": a.tool_call_id,
|
|
|
|
| 737 |
success = await session_manager.submit_user_input(session_id, text)
|
| 738 |
else:
|
| 739 |
broadcaster.unsubscribe(sub_id)
|
| 740 |
+
raise HTTPException(
|
| 741 |
+
status_code=400, detail="Must provide 'text' or 'approvals'"
|
| 742 |
+
)
|
| 743 |
|
| 744 |
if not success:
|
| 745 |
broadcaster.unsubscribe(sub_id)
|
|
|
|
| 764 |
agent_session = await _check_session_access(session_id, user)
|
| 765 |
|
| 766 |
from agent.core import telemetry
|
| 767 |
+
|
| 768 |
await telemetry.record_pro_cta_click(
|
| 769 |
agent_session.session,
|
| 770 |
source=str(body.get("source") or "unknown"),
|
|
|
|
| 780 |
# ---------------------------------------------------------------------------
|
| 781 |
# Shared SSE helpers
|
| 782 |
# ---------------------------------------------------------------------------
|
| 783 |
+
_TERMINAL_EVENTS = {
|
| 784 |
+
"turn_complete",
|
| 785 |
+
"approval_required",
|
| 786 |
+
"error",
|
| 787 |
+
"interrupted",
|
| 788 |
+
"shutdown",
|
| 789 |
+
}
|
| 790 |
_SSE_KEEPALIVE_SECONDS = 15
|
| 791 |
|
| 792 |
|
| 793 |
def _last_event_seq(request: Request) -> int:
|
| 794 |
+
raw = (
|
| 795 |
+
request.headers.get("last-event-id") or request.query_params.get("after") or "0"
|
| 796 |
+
)
|
| 797 |
try:
|
| 798 |
return max(0, int(raw))
|
| 799 |
except (TypeError, ValueError):
|
|
|
|
| 882 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 883 |
|
| 884 |
after_seq = _last_event_seq(request)
|
| 885 |
+
replay_events = await session_manager._store().load_events_after(
|
| 886 |
+
session_id, after_seq
|
| 887 |
+
)
|
| 888 |
broadcaster = agent_session.broadcaster
|
| 889 |
sub_id, event_queue = broadcaster.subscribe()
|
| 890 |
return _sse_response(
|
|
|
|
| 916 |
agent_session = await _check_session_access(session_id, user)
|
| 917 |
if not agent_session or not agent_session.is_active:
|
| 918 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 919 |
+
return [
|
| 920 |
+
msg.model_dump(mode="json")
|
| 921 |
+
for msg in agent_session.session.context_manager.items
|
| 922 |
+
]
|
| 923 |
|
| 924 |
|
| 925 |
@router.post("/undo/{session_id}")
|
|
|
|
| 940 |
await _check_session_access(session_id, user)
|
| 941 |
success = await session_manager.truncate(session_id, body.user_message_index)
|
| 942 |
if not success:
|
| 943 |
+
raise HTTPException(
|
| 944 |
+
status_code=404,
|
| 945 |
+
detail="Session not found, inactive, or message index out of range",
|
| 946 |
+
)
|
| 947 |
return {"status": "truncated", "session_id": session_id}
|
| 948 |
|
| 949 |
|
|
|
|
| 970 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 971 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 972 |
|
| 973 |
+
|
| 974 |
@router.post("/feedback/{session_id}")
|
| 975 |
async def submit_feedback(
|
| 976 |
session_id: str,
|
|
|
|
| 990 |
raise HTTPException(status_code=400, detail="invalid rating")
|
| 991 |
|
| 992 |
from agent.core import telemetry
|
| 993 |
+
|
| 994 |
await telemetry.record_feedback(
|
| 995 |
agent_session.session,
|
| 996 |
rating=rating,
|
backend/routes/auth.py
CHANGED
|
@@ -168,4 +168,3 @@ async def get_me(user: dict = Depends(get_current_user)) -> dict:
|
|
| 168 |
Uses the shared auth dependency which handles cookie + Bearer token.
|
| 169 |
"""
|
| 170 |
return {key: value for key, value in user.items() if not key.startswith("_")}
|
| 171 |
-
|
|
|
|
| 168 |
Uses the shared auth dependency which handles cookie + Bearer token.
|
| 169 |
"""
|
| 170 |
return {key: value for key, value in user.items() if not key.startswith("_")}
|
|
|
backend/session_manager.py
CHANGED
|
@@ -12,10 +12,11 @@ from typing import Any, Optional
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
| 15 |
-
from agent.
|
| 16 |
from agent.core.session import Event, OpType, Session
|
| 17 |
from agent.core.session_persistence import get_session_store
|
| 18 |
from agent.core.tools import ToolRouter
|
|
|
|
| 19 |
|
| 20 |
# Get project root (parent of backend directory)
|
| 21 |
PROJECT_ROOT = Path(__file__).parent.parent
|
|
@@ -70,7 +71,11 @@ class EventBroadcaster:
|
|
| 70 |
while True:
|
| 71 |
try:
|
| 72 |
event: Event = await self._source.get()
|
| 73 |
-
msg = {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
for q in self._subscribers.values():
|
| 75 |
await q.put(msg)
|
| 76 |
except asyncio.CancelledError:
|
|
@@ -131,6 +136,7 @@ class SessionManager:
|
|
| 131 |
self.sessions: dict[str, AgentSession] = {}
|
| 132 |
self._lock = asyncio.Lock()
|
| 133 |
self.persistence_store = None
|
|
|
|
| 134 |
|
| 135 |
async def start(self) -> None:
|
| 136 |
"""Start shared background resources."""
|
|
@@ -153,9 +159,7 @@ class SessionManager:
|
|
| 153 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 154 |
"""Count active sessions owned by a specific user."""
|
| 155 |
return sum(
|
| 156 |
-
1
|
| 157 |
-
for s in self.sessions.values()
|
| 158 |
-
if s.user_id == user_id and s.is_active
|
| 159 |
)
|
| 160 |
|
| 161 |
def _create_session_sync(
|
|
@@ -196,10 +200,7 @@ class SessionManager:
|
|
| 196 |
return tool_router, session
|
| 197 |
|
| 198 |
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
|
| 199 |
-
return [
|
| 200 |
-
msg.model_dump(mode="json")
|
| 201 |
-
for msg in session.context_manager.items
|
| 202 |
-
]
|
| 203 |
|
| 204 |
def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
|
| 205 |
pending = session.pending_approval or {}
|
|
@@ -307,7 +308,9 @@ class SessionManager:
|
|
| 307 |
if hasattr(session, "auto_approval_policy_summary"):
|
| 308 |
return session.auto_approval_policy_summary()
|
| 309 |
cap = getattr(session, "auto_approval_cost_cap_usd", None)
|
| 310 |
-
estimated = float(
|
|
|
|
|
|
|
| 311 |
remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
|
| 312 |
return {
|
| 313 |
"enabled": bool(getattr(session, "auto_approval_enabled", False)),
|
|
@@ -410,6 +413,28 @@ class SessionManager:
|
|
| 410 |
session.sandbox_preload_cancel_event = None
|
| 411 |
self._start_cpu_sandbox_preload(agent_session)
|
| 412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 414 |
try:
|
| 415 |
await self._store().update_session_fields(
|
|
@@ -514,7 +539,9 @@ class SessionManager:
|
|
| 514 |
runtime_state=runtime_state or self._runtime_state(agent_session),
|
| 515 |
status=status,
|
| 516 |
turn_count=agent_session.session.turn_count,
|
| 517 |
-
pending_approval=self._serialize_pending_approval(
|
|
|
|
|
|
|
| 518 |
claude_counted=agent_session.claude_counted,
|
| 519 |
created_at=agent_session.created_at,
|
| 520 |
notification_destinations=list(
|
|
@@ -564,6 +591,7 @@ class SessionManager:
|
|
| 564 |
existing,
|
| 565 |
preload_sandbox=preload_sandbox,
|
| 566 |
)
|
|
|
|
| 567 |
return existing
|
| 568 |
return None
|
| 569 |
|
|
@@ -585,6 +613,7 @@ class SessionManager:
|
|
| 585 |
existing,
|
| 586 |
preload_sandbox=preload_sandbox,
|
| 587 |
)
|
|
|
|
| 588 |
return existing
|
| 589 |
return None
|
| 590 |
|
|
@@ -626,7 +655,10 @@ class SessionManager:
|
|
| 626 |
if restored_messages:
|
| 627 |
# Keep the freshly-rendered system prompt, then attach the durable
|
| 628 |
# non-system context so tools/date/user context stay current.
|
| 629 |
-
session.context_manager.items = [
|
|
|
|
|
|
|
|
|
|
| 630 |
|
| 631 |
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 632 |
session.turn_count = int(meta.get("turn_count") or 0)
|
|
@@ -668,7 +700,9 @@ class SessionManager:
|
|
| 668 |
hf_token=hf_token,
|
| 669 |
hf_username=hf_username,
|
| 670 |
)
|
|
|
|
| 671 |
return started
|
|
|
|
| 672 |
if preload_sandbox:
|
| 673 |
self._start_cpu_sandbox_preload(agent_session)
|
| 674 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
@@ -751,6 +785,7 @@ class SessionManager:
|
|
| 751 |
event_queue=event_queue,
|
| 752 |
tool_router=tool_router,
|
| 753 |
)
|
|
|
|
| 754 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 755 |
self._start_cpu_sandbox_preload(agent_session)
|
| 756 |
|
|
@@ -760,7 +795,9 @@ class SessionManager:
|
|
| 760 |
logger.info(f"Created session {session_id} for user {user_id}")
|
| 761 |
return session_id
|
| 762 |
|
| 763 |
-
async def _track_pro_status(
|
|
|
|
|
|
|
| 764 |
"""Update Mongo per-user Pro state and emit a one-shot conversion
|
| 765 |
event if the store reports a free→Pro transition. Best-effort: any
|
| 766 |
Mongo failure is swallowed so we never fail session creation on
|
|
@@ -777,6 +814,7 @@ class SessionManager:
|
|
| 777 |
return
|
| 778 |
try:
|
| 779 |
from agent.core import telemetry
|
|
|
|
| 780 |
await telemetry.record_pro_conversion(
|
| 781 |
agent_session.session,
|
| 782 |
first_seen_at=result.get("first_seen_at"),
|
|
@@ -933,7 +971,9 @@ class SessionManager:
|
|
| 933 |
)
|
| 934 |
agent_session.is_processing = True
|
| 935 |
try:
|
| 936 |
-
should_continue = await process_submission(
|
|
|
|
|
|
|
| 937 |
finally:
|
| 938 |
agent_session.is_processing = False
|
| 939 |
await self.persist_session_snapshot(agent_session)
|
|
@@ -964,7 +1004,9 @@ class SessionManager:
|
|
| 964 |
# Idempotent via session_id key; detached subprocess.
|
| 965 |
if session.config.save_sessions:
|
| 966 |
try:
|
| 967 |
-
session.save_and_upload_detached(
|
|
|
|
|
|
|
| 968 |
except Exception as e:
|
| 969 |
logger.warning(f"Final-flush failed for {session_id}: {e}")
|
| 970 |
|
|
@@ -1025,7 +1067,9 @@ class SessionManager:
|
|
| 1025 |
agent_session = self.sessions.get(session_id)
|
| 1026 |
if not agent_session or not agent_session.is_active:
|
| 1027 |
return False
|
| 1028 |
-
success = agent_session.session.context_manager.truncate_to_user_message(
|
|
|
|
|
|
|
| 1029 |
if success:
|
| 1030 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 1031 |
return success
|
|
@@ -1118,9 +1162,7 @@ class SessionManager:
|
|
| 1118 |
session = agent_session.session
|
| 1119 |
if enabled:
|
| 1120 |
if not cap_provided and cost_cap_usd is None:
|
| 1121 |
-
cost_cap_usd = getattr(
|
| 1122 |
-
session, "auto_approval_cost_cap_usd", None
|
| 1123 |
-
)
|
| 1124 |
if cost_cap_usd is None:
|
| 1125 |
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
|
| 1126 |
elif cost_cap_usd is None:
|
|
@@ -1203,9 +1245,7 @@ class SessionManager:
|
|
| 1203 |
if destination is None:
|
| 1204 |
raise ValueError(f"Unknown destination '{name}'")
|
| 1205 |
if not destination.allow_auto_events:
|
| 1206 |
-
raise ValueError(
|
| 1207 |
-
f"Destination '{name}' is not enabled for auto events"
|
| 1208 |
-
)
|
| 1209 |
if name not in seen:
|
| 1210 |
normalized.append(name)
|
| 1211 |
seen.add(name)
|
|
@@ -1248,7 +1288,10 @@ class SessionManager:
|
|
| 1248 |
"pending_approval": pending or None,
|
| 1249 |
"model": row.get("model"),
|
| 1250 |
"title": row.get("title"),
|
| 1251 |
-
"notification_destinations": row.get(
|
|
|
|
|
|
|
|
|
|
| 1252 |
"auto_approval": {
|
| 1253 |
"enabled": bool(row.get("auto_approval_enabled", False)),
|
| 1254 |
"cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
|
|
@@ -1261,8 +1304,13 @@ class SessionManager:
|
|
| 1261 |
else round(
|
| 1262 |
max(
|
| 1263 |
0.0,
|
| 1264 |
-
float(
|
| 1265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1266 |
),
|
| 1267 |
4,
|
| 1268 |
)
|
|
|
|
| 12 |
|
| 13 |
from agent.config import load_config
|
| 14 |
from agent.core.agent_loop import process_submission
|
| 15 |
+
from agent.core.hub_artifacts import start_session_artifact_collection_task
|
| 16 |
from agent.core.session import Event, OpType, Session
|
| 17 |
from agent.core.session_persistence import get_session_store
|
| 18 |
from agent.core.tools import ToolRouter
|
| 19 |
+
from agent.messaging.gateway import NotificationGateway
|
| 20 |
|
| 21 |
# Get project root (parent of backend directory)
|
| 22 |
PROJECT_ROOT = Path(__file__).parent.parent
|
|
|
|
| 71 |
while True:
|
| 72 |
try:
|
| 73 |
event: Event = await self._source.get()
|
| 74 |
+
msg = {
|
| 75 |
+
"event_type": event.event_type,
|
| 76 |
+
"data": event.data,
|
| 77 |
+
"seq": event.seq,
|
| 78 |
+
}
|
| 79 |
for q in self._subscribers.values():
|
| 80 |
await q.put(msg)
|
| 81 |
except asyncio.CancelledError:
|
|
|
|
| 136 |
self.sessions: dict[str, AgentSession] = {}
|
| 137 |
self._lock = asyncio.Lock()
|
| 138 |
self.persistence_store = None
|
| 139 |
+
self.enable_hub_artifact_collections = True
|
| 140 |
|
| 141 |
async def start(self) -> None:
|
| 142 |
"""Start shared background resources."""
|
|
|
|
| 159 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 160 |
"""Count active sessions owned by a specific user."""
|
| 161 |
return sum(
|
| 162 |
+
1 for s in self.sessions.values() if s.user_id == user_id and s.is_active
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
|
| 165 |
def _create_session_sync(
|
|
|
|
| 200 |
return tool_router, session
|
| 201 |
|
| 202 |
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
|
| 203 |
+
return [msg.model_dump(mode="json") for msg in session.context_manager.items]
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
|
| 206 |
pending = session.pending_approval or {}
|
|
|
|
| 308 |
if hasattr(session, "auto_approval_policy_summary"):
|
| 309 |
return session.auto_approval_policy_summary()
|
| 310 |
cap = getattr(session, "auto_approval_cost_cap_usd", None)
|
| 311 |
+
estimated = float(
|
| 312 |
+
getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
|
| 313 |
+
)
|
| 314 |
remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
|
| 315 |
return {
|
| 316 |
"enabled": bool(getattr(session, "auto_approval_enabled", False)),
|
|
|
|
| 413 |
session.sandbox_preload_cancel_event = None
|
| 414 |
self._start_cpu_sandbox_preload(agent_session)
|
| 415 |
|
| 416 |
+
def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
|
| 417 |
+
"""Kick off best-effort Hub collection creation for the session."""
|
| 418 |
+
if not getattr(self, "enable_hub_artifact_collections", False):
|
| 419 |
+
return
|
| 420 |
+
session = agent_session.session
|
| 421 |
+
if not getattr(session, "session_id", None):
|
| 422 |
+
try:
|
| 423 |
+
session.session_id = agent_session.session_id
|
| 424 |
+
except Exception:
|
| 425 |
+
logger.debug("Could not attach session id for Hub artifact collection")
|
| 426 |
+
token = agent_session.hf_token or getattr(session, "hf_token", None)
|
| 427 |
+
if not token:
|
| 428 |
+
return
|
| 429 |
+
try:
|
| 430 |
+
start_session_artifact_collection_task(session, token=token)
|
| 431 |
+
except Exception as e:
|
| 432 |
+
logger.debug(
|
| 433 |
+
"Failed to schedule Hub artifact collection for %s: %s",
|
| 434 |
+
agent_session.session_id,
|
| 435 |
+
e,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
|
| 439 |
try:
|
| 440 |
await self._store().update_session_fields(
|
|
|
|
| 539 |
runtime_state=runtime_state or self._runtime_state(agent_session),
|
| 540 |
status=status,
|
| 541 |
turn_count=agent_session.session.turn_count,
|
| 542 |
+
pending_approval=self._serialize_pending_approval(
|
| 543 |
+
agent_session.session
|
| 544 |
+
),
|
| 545 |
claude_counted=agent_session.claude_counted,
|
| 546 |
created_at=agent_session.created_at,
|
| 547 |
notification_destinations=list(
|
|
|
|
| 591 |
existing,
|
| 592 |
preload_sandbox=preload_sandbox,
|
| 593 |
)
|
| 594 |
+
self._start_hub_artifact_collection(existing)
|
| 595 |
return existing
|
| 596 |
return None
|
| 597 |
|
|
|
|
| 613 |
existing,
|
| 614 |
preload_sandbox=preload_sandbox,
|
| 615 |
)
|
| 616 |
+
self._start_hub_artifact_collection(existing)
|
| 617 |
return existing
|
| 618 |
return None
|
| 619 |
|
|
|
|
| 655 |
if restored_messages:
|
| 656 |
# Keep the freshly-rendered system prompt, then attach the durable
|
| 657 |
# non-system context so tools/date/user context stay current.
|
| 658 |
+
session.context_manager.items = [
|
| 659 |
+
session.context_manager.items[0],
|
| 660 |
+
*restored_messages,
|
| 661 |
+
]
|
| 662 |
|
| 663 |
self._restore_pending_approval(session, meta.get("pending_approval") or [])
|
| 664 |
session.turn_count = int(meta.get("turn_count") or 0)
|
|
|
|
| 700 |
hf_token=hf_token,
|
| 701 |
hf_username=hf_username,
|
| 702 |
)
|
| 703 |
+
self._start_hub_artifact_collection(started)
|
| 704 |
return started
|
| 705 |
+
self._start_hub_artifact_collection(agent_session)
|
| 706 |
if preload_sandbox:
|
| 707 |
self._start_cpu_sandbox_preload(agent_session)
|
| 708 |
logger.info("Restored session %s for user %s", session_id, owner or user_id)
|
|
|
|
| 785 |
event_queue=event_queue,
|
| 786 |
tool_router=tool_router,
|
| 787 |
)
|
| 788 |
+
self._start_hub_artifact_collection(agent_session)
|
| 789 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 790 |
self._start_cpu_sandbox_preload(agent_session)
|
| 791 |
|
|
|
|
| 795 |
logger.info(f"Created session {session_id} for user {user_id}")
|
| 796 |
return session_id
|
| 797 |
|
| 798 |
+
async def _track_pro_status(
|
| 799 |
+
self, agent_session: AgentSession, *, is_pro: bool
|
| 800 |
+
) -> None:
|
| 801 |
"""Update Mongo per-user Pro state and emit a one-shot conversion
|
| 802 |
event if the store reports a free→Pro transition. Best-effort: any
|
| 803 |
Mongo failure is swallowed so we never fail session creation on
|
|
|
|
| 814 |
return
|
| 815 |
try:
|
| 816 |
from agent.core import telemetry
|
| 817 |
+
|
| 818 |
await telemetry.record_pro_conversion(
|
| 819 |
agent_session.session,
|
| 820 |
first_seen_at=result.get("first_seen_at"),
|
|
|
|
| 971 |
)
|
| 972 |
agent_session.is_processing = True
|
| 973 |
try:
|
| 974 |
+
should_continue = await process_submission(
|
| 975 |
+
session, submission
|
| 976 |
+
)
|
| 977 |
finally:
|
| 978 |
agent_session.is_processing = False
|
| 979 |
await self.persist_session_snapshot(agent_session)
|
|
|
|
| 1004 |
# Idempotent via session_id key; detached subprocess.
|
| 1005 |
if session.config.save_sessions:
|
| 1006 |
try:
|
| 1007 |
+
session.save_and_upload_detached(
|
| 1008 |
+
session.config.session_dataset_repo
|
| 1009 |
+
)
|
| 1010 |
except Exception as e:
|
| 1011 |
logger.warning(f"Final-flush failed for {session_id}: {e}")
|
| 1012 |
|
|
|
|
| 1067 |
agent_session = self.sessions.get(session_id)
|
| 1068 |
if not agent_session or not agent_session.is_active:
|
| 1069 |
return False
|
| 1070 |
+
success = agent_session.session.context_manager.truncate_to_user_message(
|
| 1071 |
+
user_message_index
|
| 1072 |
+
)
|
| 1073 |
if success:
|
| 1074 |
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 1075 |
return success
|
|
|
|
| 1162 |
session = agent_session.session
|
| 1163 |
if enabled:
|
| 1164 |
if not cap_provided and cost_cap_usd is None:
|
| 1165 |
+
cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
|
|
|
|
|
|
|
| 1166 |
if cost_cap_usd is None:
|
| 1167 |
cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
|
| 1168 |
elif cost_cap_usd is None:
|
|
|
|
| 1245 |
if destination is None:
|
| 1246 |
raise ValueError(f"Unknown destination '{name}'")
|
| 1247 |
if not destination.allow_auto_events:
|
| 1248 |
+
raise ValueError(f"Destination '{name}' is not enabled for auto events")
|
|
|
|
|
|
|
| 1249 |
if name not in seen:
|
| 1250 |
normalized.append(name)
|
| 1251 |
seen.add(name)
|
|
|
|
| 1288 |
"pending_approval": pending or None,
|
| 1289 |
"model": row.get("model"),
|
| 1290 |
"title": row.get("title"),
|
| 1291 |
+
"notification_destinations": row.get(
|
| 1292 |
+
"notification_destinations"
|
| 1293 |
+
)
|
| 1294 |
+
or [],
|
| 1295 |
"auto_approval": {
|
| 1296 |
"enabled": bool(row.get("auto_approval_enabled", False)),
|
| 1297 |
"cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
|
|
|
|
| 1304 |
else round(
|
| 1305 |
max(
|
| 1306 |
0.0,
|
| 1307 |
+
float(
|
| 1308 |
+
row.get("auto_approval_cost_cap_usd") or 0.0
|
| 1309 |
+
)
|
| 1310 |
+
- float(
|
| 1311 |
+
row.get("auto_approval_estimated_spend_usd")
|
| 1312 |
+
or 0.0
|
| 1313 |
+
),
|
| 1314 |
),
|
| 1315 |
4,
|
| 1316 |
)
|
backend/user_quotas.py
CHANGED
|
@@ -20,7 +20,11 @@ import asyncio
|
|
| 20 |
import os
|
| 21 |
from datetime import UTC, datetime
|
| 22 |
|
| 23 |
-
from agent.core.session_persistence import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
|
| 26 |
CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
|
|
|
|
| 20 |
import os
|
| 21 |
from datetime import UTC, datetime
|
| 22 |
|
| 23 |
+
from agent.core.session_persistence import (
|
| 24 |
+
NoopSessionStore,
|
| 25 |
+
get_session_store,
|
| 26 |
+
_reset_store_for_tests,
|
| 27 |
+
)
|
| 28 |
|
| 29 |
CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
|
| 30 |
CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
|
pyproject.toml
CHANGED
|
@@ -44,6 +44,7 @@ eval = [
|
|
| 44 |
dev = [
|
| 45 |
"pytest>=9.0.2",
|
| 46 |
"pytest-asyncio>=1.2.0",
|
|
|
|
| 47 |
]
|
| 48 |
|
| 49 |
# All dependencies (eval + dev)
|
|
|
|
| 44 |
dev = [
|
| 45 |
"pytest>=9.0.2",
|
| 46 |
"pytest-asyncio>=1.2.0",
|
| 47 |
+
"ruff>=0.15.12",
|
| 48 |
]
|
| 49 |
|
| 50 |
# All dependencies (eval + dev)
|
scripts/build_kpis.py
CHANGED
|
@@ -99,7 +99,6 @@ import sys
|
|
| 99 |
import tempfile
|
| 100 |
from collections import defaultdict
|
| 101 |
from datetime import date, datetime, timedelta, timezone
|
| 102 |
-
from pathlib import Path
|
| 103 |
from typing import Any, Iterable
|
| 104 |
|
| 105 |
logger = logging.getLogger("build_kpis")
|
|
@@ -107,13 +106,25 @@ logger = logging.getLogger("build_kpis")
|
|
| 107 |
# Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used
|
| 108 |
# only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count.
|
| 109 |
_FLAVOR_GPU_COUNT = {
|
| 110 |
-
"cpu-basic": 0,
|
| 111 |
-
"
|
| 112 |
-
"
|
| 113 |
-
"
|
| 114 |
-
"
|
| 115 |
-
"
|
| 116 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
}
|
| 118 |
|
| 119 |
|
|
@@ -160,9 +171,13 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None:
|
|
| 160 |
directory is near-free.
|
| 161 |
"""
|
| 162 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 163 |
try:
|
| 164 |
local = hf_hub_download(
|
| 165 |
-
repo_id=repo_id,
|
|
|
|
|
|
|
|
|
|
| 166 |
)
|
| 167 |
except Exception as e:
|
| 168 |
logger.warning("hf_hub_download(%s) failed: %s", path, e)
|
|
@@ -188,7 +203,9 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None:
|
|
| 188 |
|
| 189 |
|
| 190 |
def _filter_session_to_window(
|
| 191 |
-
session: dict,
|
|
|
|
|
|
|
| 192 |
) -> dict | None:
|
| 193 |
"""Return a copy of ``session`` whose events are only those in ``[start, end)``.
|
| 194 |
|
|
@@ -216,16 +233,29 @@ def _session_metrics(session: dict) -> dict:
|
|
| 216 |
# Pre-seed every numeric key so downstream aggregation can sum without
|
| 217 |
# having to special-case empty sessions.
|
| 218 |
out: dict = {
|
| 219 |
-
"sessions": 0,
|
| 220 |
-
"
|
| 221 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
"cost_usd": 0.0,
|
| 223 |
-
"tool_calls_total": 0,
|
| 224 |
-
"
|
| 225 |
-
"
|
| 226 |
-
"
|
| 227 |
-
"
|
| 228 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
"first_tool_s": -1,
|
| 230 |
}
|
| 231 |
events = session.get("events") or []
|
|
@@ -373,7 +403,9 @@ def _session_metrics(session: dict) -> dict:
|
|
| 373 |
|
| 374 |
def _aggregate(per_session: list[dict]) -> dict:
|
| 375 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 376 |
-
ttfa_values = [
|
|
|
|
|
|
|
| 377 |
gpu_hours: dict[str, float] = defaultdict(float)
|
| 378 |
for s in per_session:
|
| 379 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
|
@@ -395,9 +427,21 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 395 |
# never reached for the relevant signal — otherwise quiet hours
|
| 396 |
# (status-check sessions, abandoned new conversations) drag every median
|
| 397 |
# to 0 and the chart tells you nothing.
|
| 398 |
-
research_calls_nz = [
|
| 399 |
-
|
| 400 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
# Per-turn intensity: turns>0 is the natural filter here (a session with
|
| 402 |
# 5 turns and 0 tools is a meaningful 0). Don't strip those.
|
| 403 |
calls_per_turn_values = [
|
|
@@ -415,7 +459,9 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 415 |
failures = int(sum(s["failures"] for s in per_session))
|
| 416 |
regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
|
| 417 |
research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
|
| 418 |
-
sessions_with_research = sum(
|
|
|
|
|
|
|
| 419 |
|
| 420 |
# Per-session cost percentiles — chart "median session cost" alongside the
|
| 421 |
# mean so a few $700 outliers don't make you think every session is pricey.
|
|
@@ -433,17 +479,23 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 433 |
"tokens_prompt": int(tokens_prompt),
|
| 434 |
"tokens_completion": int(sum(s["tokens_completion"] for s in per_session)),
|
| 435 |
"tokens_cache_read": int(tokens_cache_read),
|
| 436 |
-
"tokens_cache_creation": int(
|
|
|
|
|
|
|
| 437 |
"cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
|
| 438 |
# Per-session cost summaries.
|
| 439 |
"cost_per_session_mean": round(
|
| 440 |
sum(s["cost_usd"] for s in per_session) / total_sessions, 6
|
| 441 |
-
)
|
|
|
|
|
|
|
| 442 |
"cost_per_session_p50": round(cost_p50, 6),
|
| 443 |
"cost_per_session_p95": round(cost_p95, 6),
|
| 444 |
"cache_hit_ratio": round(
|
| 445 |
tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
|
| 446 |
-
)
|
|
|
|
|
|
|
| 447 |
# Raw reliability COUNTS (these are what the dashboard shows directly).
|
| 448 |
"tool_calls_total": int(tool_total),
|
| 449 |
"tool_calls_succeeded": int(tool_success),
|
|
@@ -458,38 +510,56 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 458 |
"regenerated_sessions": regenerates,
|
| 459 |
# Rates kept for backwards compatibility with anything reading the
|
| 460 |
# KPI dataset directly.
|
| 461 |
-
"tool_success_rate": round(tool_success / tool_total, 4)
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
"time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
|
| 465 |
"time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
|
| 466 |
"thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
|
| 467 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 468 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 469 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
| 470 |
-
"sandboxes_created": int(
|
|
|
|
|
|
|
| 471 |
"sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
|
| 472 |
"sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
|
| 473 |
"hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
|
| 474 |
"pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
|
| 475 |
"pro_conversions": int(sum(s.get("pro_conversions", 0) for s in per_session)),
|
| 476 |
-
"credits_topped_up": int(
|
|
|
|
|
|
|
| 477 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
| 478 |
# Research KPIs — answer "is the agent reaching for research?".
|
| 479 |
"research_calls": research_calls_total,
|
| 480 |
"sessions_with_research": int(sessions_with_research),
|
| 481 |
"research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
|
| 482 |
-
"research_calls_per_session_p95": round(
|
|
|
|
|
|
|
| 483 |
# Intra-session breadth + intensity. p50 + p95 over per-session values.
|
| 484 |
-
"distinct_tools_per_session_p50": round(
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
"tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
|
| 487 |
"tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
|
| 488 |
"tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
|
| 489 |
"tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
|
| 490 |
# JSON columns let the dashboard add/remove tools without schema churn.
|
| 491 |
"tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
|
| 492 |
-
"sessions_using_tool_json": json.dumps(
|
|
|
|
|
|
|
| 493 |
# Surface split — answers "is research dropping on Bedrock specifically?".
|
| 494 |
"sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
|
| 495 |
}
|
|
@@ -507,7 +577,12 @@ def _csv_cell(v: Any) -> str:
|
|
| 507 |
|
| 508 |
|
| 509 |
def _write_csv(
|
| 510 |
-
api,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
) -> None:
|
| 512 |
"""Render ``row`` to CSV with a leading ``bucket`` column and upload.
|
| 513 |
|
|
@@ -527,7 +602,10 @@ def _write_csv(
|
|
| 527 |
|
| 528 |
try:
|
| 529 |
api.create_repo(
|
| 530 |
-
repo_id=target_repo,
|
|
|
|
|
|
|
|
|
|
| 531 |
)
|
| 532 |
api.upload_file(
|
| 533 |
path_or_fileobj=tmp_path,
|
|
@@ -545,7 +623,11 @@ def _write_csv(
|
|
| 545 |
|
| 546 |
|
| 547 |
def run_for_hour(
|
| 548 |
-
api,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
) -> dict:
|
| 550 |
"""Roll up one UTC hour [hour_dt, hour_dt+1h).
|
| 551 |
|
|
@@ -579,10 +661,16 @@ def run_for_hour(
|
|
| 579 |
|
| 580 |
row = _aggregate(per_session)
|
| 581 |
bucket_key = window_start.strftime("%Y-%m-%dT%H")
|
| 582 |
-
path_in_repo =
|
|
|
|
|
|
|
| 583 |
_write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
|
| 584 |
-
logger.info(
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
return row
|
| 587 |
|
| 588 |
|
|
@@ -618,17 +706,23 @@ def main(argv: list[str] | None = None) -> int:
|
|
| 618 |
ap.add_argument("--source", default="smolagents/ml-intern-sessions")
|
| 619 |
ap.add_argument("--target", default="smolagents/ml-intern-kpis")
|
| 620 |
ap.add_argument(
|
| 621 |
-
"--hours",
|
|
|
|
|
|
|
| 622 |
help="Number of trailing hours to roll up (default: 1 = last completed hour).",
|
| 623 |
)
|
| 624 |
ap.add_argument(
|
| 625 |
-
"--datetime",
|
|
|
|
|
|
|
| 626 |
help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
|
| 627 |
)
|
| 628 |
ap.add_argument(
|
| 629 |
-
"--daily-backfill",
|
|
|
|
|
|
|
| 630 |
help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
|
| 631 |
-
|
| 632 |
)
|
| 633 |
args = ap.parse_args(argv)
|
| 634 |
|
|
@@ -646,10 +740,17 @@ def main(argv: list[str] | None = None) -> int:
|
|
| 646 |
return 1
|
| 647 |
|
| 648 |
from huggingface_hub import HfApi
|
|
|
|
| 649 |
api = HfApi()
|
| 650 |
|
| 651 |
if args.daily_backfill:
|
| 652 |
-
run_for_day(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
return 0
|
| 654 |
|
| 655 |
if args.datetime:
|
|
|
|
| 99 |
import tempfile
|
| 100 |
from collections import defaultdict
|
| 101 |
from datetime import date, datetime, timedelta, timezone
|
|
|
|
| 102 |
from typing import Any, Iterable
|
| 103 |
|
| 104 |
logger = logging.getLogger("build_kpis")
|
|
|
|
| 106 |
# Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used
|
| 107 |
# only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count.
|
| 108 |
_FLAVOR_GPU_COUNT = {
|
| 109 |
+
"cpu-basic": 0,
|
| 110 |
+
"cpu-upgrade": 0,
|
| 111 |
+
"t4-small": 1,
|
| 112 |
+
"t4-medium": 1,
|
| 113 |
+
"l4x1": 1,
|
| 114 |
+
"l4x4": 4,
|
| 115 |
+
"l40sx1": 1,
|
| 116 |
+
"l40sx4": 4,
|
| 117 |
+
"l40sx8": 8,
|
| 118 |
+
"a10g-small": 1,
|
| 119 |
+
"a10g-large": 1,
|
| 120 |
+
"a10g-largex2": 2,
|
| 121 |
+
"a10g-largex4": 4,
|
| 122 |
+
"a100-large": 1,
|
| 123 |
+
"a100x2": 2,
|
| 124 |
+
"a100x4": 4,
|
| 125 |
+
"a100x8": 8,
|
| 126 |
+
"h100": 1,
|
| 127 |
+
"h100x8": 8,
|
| 128 |
}
|
| 129 |
|
| 130 |
|
|
|
|
| 171 |
directory is near-free.
|
| 172 |
"""
|
| 173 |
from huggingface_hub import hf_hub_download
|
| 174 |
+
|
| 175 |
try:
|
| 176 |
local = hf_hub_download(
|
| 177 |
+
repo_id=repo_id,
|
| 178 |
+
filename=path,
|
| 179 |
+
repo_type="dataset",
|
| 180 |
+
token=token,
|
| 181 |
)
|
| 182 |
except Exception as e:
|
| 183 |
logger.warning("hf_hub_download(%s) failed: %s", path, e)
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
def _filter_session_to_window(
|
| 206 |
+
session: dict,
|
| 207 |
+
start: datetime,
|
| 208 |
+
end: datetime,
|
| 209 |
) -> dict | None:
|
| 210 |
"""Return a copy of ``session`` whose events are only those in ``[start, end)``.
|
| 211 |
|
|
|
|
| 233 |
# Pre-seed every numeric key so downstream aggregation can sum without
|
| 234 |
# having to special-case empty sessions.
|
| 235 |
out: dict = {
|
| 236 |
+
"sessions": 0,
|
| 237 |
+
"turns": 0,
|
| 238 |
+
"llm_calls": 0,
|
| 239 |
+
"tokens_prompt": 0,
|
| 240 |
+
"tokens_completion": 0,
|
| 241 |
+
"tokens_cache_read": 0,
|
| 242 |
+
"tokens_cache_creation": 0,
|
| 243 |
"cost_usd": 0.0,
|
| 244 |
+
"tool_calls_total": 0,
|
| 245 |
+
"tool_calls_success": 0,
|
| 246 |
+
"failures": 0,
|
| 247 |
+
"regenerate_sessions": 0,
|
| 248 |
+
"thumbs_up": 0,
|
| 249 |
+
"thumbs_down": 0,
|
| 250 |
+
"hf_jobs_submitted": 0,
|
| 251 |
+
"hf_jobs_succeeded": 0,
|
| 252 |
+
"hf_jobs_blocked": 0,
|
| 253 |
+
"pro_cta_clicks": 0,
|
| 254 |
+
"pro_conversions": 0,
|
| 255 |
+
"credits_topped_up": 0,
|
| 256 |
+
"sandboxes_created": 0,
|
| 257 |
+
"sandboxes_cpu": 0,
|
| 258 |
+
"sandboxes_gpu": 0,
|
| 259 |
"first_tool_s": -1,
|
| 260 |
}
|
| 261 |
events = session.get("events") or []
|
|
|
|
| 403 |
|
| 404 |
def _aggregate(per_session: list[dict]) -> dict:
|
| 405 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 406 |
+
ttfa_values = [
|
| 407 |
+
s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0
|
| 408 |
+
]
|
| 409 |
gpu_hours: dict[str, float] = defaultdict(float)
|
| 410 |
for s in per_session:
|
| 411 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
|
|
|
| 427 |
# never reached for the relevant signal — otherwise quiet hours
|
| 428 |
# (status-check sessions, abandoned new conversations) drag every median
|
| 429 |
# to 0 and the chart tells you nothing.
|
| 430 |
+
research_calls_nz = [
|
| 431 |
+
s.get("_research_calls", 0)
|
| 432 |
+
for s in per_session
|
| 433 |
+
if s.get("_research_calls", 0) > 0
|
| 434 |
+
]
|
| 435 |
+
distinct_tools_values = [
|
| 436 |
+
s.get("_distinct_tools_used", 0)
|
| 437 |
+
for s in per_session
|
| 438 |
+
if s.get("_distinct_tools_used", 0) > 0
|
| 439 |
+
]
|
| 440 |
+
total_calls_values = [
|
| 441 |
+
s.get("_total_named_tool_calls", 0)
|
| 442 |
+
for s in per_session
|
| 443 |
+
if s.get("_total_named_tool_calls", 0) > 0
|
| 444 |
+
]
|
| 445 |
# Per-turn intensity: turns>0 is the natural filter here (a session with
|
| 446 |
# 5 turns and 0 tools is a meaningful 0). Don't strip those.
|
| 447 |
calls_per_turn_values = [
|
|
|
|
| 459 |
failures = int(sum(s["failures"] for s in per_session))
|
| 460 |
regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
|
| 461 |
research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
|
| 462 |
+
sessions_with_research = sum(
|
| 463 |
+
1 for s in per_session if s.get("_research_calls", 0) > 0
|
| 464 |
+
)
|
| 465 |
|
| 466 |
# Per-session cost percentiles — chart "median session cost" alongside the
|
| 467 |
# mean so a few $700 outliers don't make you think every session is pricey.
|
|
|
|
| 479 |
"tokens_prompt": int(tokens_prompt),
|
| 480 |
"tokens_completion": int(sum(s["tokens_completion"] for s in per_session)),
|
| 481 |
"tokens_cache_read": int(tokens_cache_read),
|
| 482 |
+
"tokens_cache_creation": int(
|
| 483 |
+
sum(s["tokens_cache_creation"] for s in per_session)
|
| 484 |
+
),
|
| 485 |
"cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
|
| 486 |
# Per-session cost summaries.
|
| 487 |
"cost_per_session_mean": round(
|
| 488 |
sum(s["cost_usd"] for s in per_session) / total_sessions, 6
|
| 489 |
+
)
|
| 490 |
+
if total_sessions > 0
|
| 491 |
+
else 0.0,
|
| 492 |
"cost_per_session_p50": round(cost_p50, 6),
|
| 493 |
"cost_per_session_p95": round(cost_p95, 6),
|
| 494 |
"cache_hit_ratio": round(
|
| 495 |
tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
|
| 496 |
+
)
|
| 497 |
+
if (tokens_cache_read + tokens_prompt) > 0
|
| 498 |
+
else 0.0,
|
| 499 |
# Raw reliability COUNTS (these are what the dashboard shows directly).
|
| 500 |
"tool_calls_total": int(tool_total),
|
| 501 |
"tool_calls_succeeded": int(tool_success),
|
|
|
|
| 510 |
"regenerated_sessions": regenerates,
|
| 511 |
# Rates kept for backwards compatibility with anything reading the
|
| 512 |
# KPI dataset directly.
|
| 513 |
+
"tool_success_rate": round(tool_success / tool_total, 4)
|
| 514 |
+
if tool_total > 0
|
| 515 |
+
else 0.0,
|
| 516 |
+
"failure_rate": round(failures / total_sessions, 4)
|
| 517 |
+
if total_sessions > 0
|
| 518 |
+
else 0.0,
|
| 519 |
+
"regenerate_rate": round(regenerates / total_sessions, 4)
|
| 520 |
+
if total_sessions > 0
|
| 521 |
+
else 0.0,
|
| 522 |
"time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
|
| 523 |
"time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
|
| 524 |
"thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
|
| 525 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 526 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 527 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
| 528 |
+
"sandboxes_created": int(
|
| 529 |
+
sum(s.get("sandboxes_created", 0) for s in per_session)
|
| 530 |
+
),
|
| 531 |
"sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
|
| 532 |
"sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
|
| 533 |
"hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
|
| 534 |
"pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
|
| 535 |
"pro_conversions": int(sum(s.get("pro_conversions", 0) for s in per_session)),
|
| 536 |
+
"credits_topped_up": int(
|
| 537 |
+
sum(s.get("credits_topped_up", 0) for s in per_session)
|
| 538 |
+
),
|
| 539 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
| 540 |
# Research KPIs — answer "is the agent reaching for research?".
|
| 541 |
"research_calls": research_calls_total,
|
| 542 |
"sessions_with_research": int(sessions_with_research),
|
| 543 |
"research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
|
| 544 |
+
"research_calls_per_session_p95": round(
|
| 545 |
+
_percentile(research_calls_nz, 0.95), 2
|
| 546 |
+
),
|
| 547 |
# Intra-session breadth + intensity. p50 + p95 over per-session values.
|
| 548 |
+
"distinct_tools_per_session_p50": round(
|
| 549 |
+
_percentile(distinct_tools_values, 0.5), 2
|
| 550 |
+
),
|
| 551 |
+
"distinct_tools_per_session_p95": round(
|
| 552 |
+
_percentile(distinct_tools_values, 0.95), 2
|
| 553 |
+
),
|
| 554 |
"tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
|
| 555 |
"tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
|
| 556 |
"tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
|
| 557 |
"tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
|
| 558 |
# JSON columns let the dashboard add/remove tools without schema churn.
|
| 559 |
"tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
|
| 560 |
+
"sessions_using_tool_json": json.dumps(
|
| 561 |
+
dict(sessions_using_tool), sort_keys=True
|
| 562 |
+
),
|
| 563 |
# Surface split — answers "is research dropping on Bedrock specifically?".
|
| 564 |
"sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
|
| 565 |
}
|
|
|
|
| 577 |
|
| 578 |
|
| 579 |
def _write_csv(
|
| 580 |
+
api,
|
| 581 |
+
row: dict,
|
| 582 |
+
bucket_key: str,
|
| 583 |
+
path_in_repo: str,
|
| 584 |
+
target_repo: str,
|
| 585 |
+
token: str,
|
| 586 |
) -> None:
|
| 587 |
"""Render ``row`` to CSV with a leading ``bucket`` column and upload.
|
| 588 |
|
|
|
|
| 602 |
|
| 603 |
try:
|
| 604 |
api.create_repo(
|
| 605 |
+
repo_id=target_repo,
|
| 606 |
+
repo_type="dataset",
|
| 607 |
+
exist_ok=True,
|
| 608 |
+
token=token,
|
| 609 |
)
|
| 610 |
api.upload_file(
|
| 611 |
path_or_fileobj=tmp_path,
|
|
|
|
| 623 |
|
| 624 |
|
| 625 |
def run_for_hour(
|
| 626 |
+
api,
|
| 627 |
+
source_repo: str,
|
| 628 |
+
target_repo: str,
|
| 629 |
+
hour_dt: datetime,
|
| 630 |
+
token: str,
|
| 631 |
) -> dict:
|
| 632 |
"""Roll up one UTC hour [hour_dt, hour_dt+1h).
|
| 633 |
|
|
|
|
| 661 |
|
| 662 |
row = _aggregate(per_session)
|
| 663 |
bucket_key = window_start.strftime("%Y-%m-%dT%H")
|
| 664 |
+
path_in_repo = (
|
| 665 |
+
f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv"
|
| 666 |
+
)
|
| 667 |
_write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
|
| 668 |
+
logger.info(
|
| 669 |
+
"Wrote KPIs for %s (%d sessions): %s",
|
| 670 |
+
bucket_key,
|
| 671 |
+
per_session and len(per_session),
|
| 672 |
+
row,
|
| 673 |
+
)
|
| 674 |
return row
|
| 675 |
|
| 676 |
|
|
|
|
| 706 |
ap.add_argument("--source", default="smolagents/ml-intern-sessions")
|
| 707 |
ap.add_argument("--target", default="smolagents/ml-intern-kpis")
|
| 708 |
ap.add_argument(
|
| 709 |
+
"--hours",
|
| 710 |
+
type=int,
|
| 711 |
+
default=1,
|
| 712 |
help="Number of trailing hours to roll up (default: 1 = last completed hour).",
|
| 713 |
)
|
| 714 |
ap.add_argument(
|
| 715 |
+
"--datetime",
|
| 716 |
+
type=str,
|
| 717 |
+
default=None,
|
| 718 |
help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
|
| 719 |
)
|
| 720 |
ap.add_argument(
|
| 721 |
+
"--daily-backfill",
|
| 722 |
+
type=str,
|
| 723 |
+
default=None,
|
| 724 |
help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
|
| 725 |
+
"Writes to daily/<date>.csv. Use for historical backfill only.",
|
| 726 |
)
|
| 727 |
args = ap.parse_args(argv)
|
| 728 |
|
|
|
|
| 740 |
return 1
|
| 741 |
|
| 742 |
from huggingface_hub import HfApi
|
| 743 |
+
|
| 744 |
api = HfApi()
|
| 745 |
|
| 746 |
if args.daily_backfill:
|
| 747 |
+
run_for_day(
|
| 748 |
+
api,
|
| 749 |
+
args.source,
|
| 750 |
+
args.target,
|
| 751 |
+
date.fromisoformat(args.daily_backfill),
|
| 752 |
+
token,
|
| 753 |
+
)
|
| 754 |
return 0
|
| 755 |
|
| 756 |
if args.datetime:
|
scripts/build_sft.py
CHANGED
|
@@ -62,9 +62,13 @@ def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[st
|
|
| 62 |
|
| 63 |
def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
|
| 64 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 65 |
try:
|
| 66 |
local = hf_hub_download(
|
| 67 |
-
repo_id=repo_id,
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
except Exception as e:
|
| 70 |
logger.warning("hf_hub_download(%s) failed: %s", path, e)
|
|
@@ -118,7 +122,10 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None
|
|
| 118 |
tmp_path = tmp.name
|
| 119 |
try:
|
| 120 |
api.create_repo(
|
| 121 |
-
repo_id=target_repo,
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
api.upload_file(
|
| 124 |
path_or_fileobj=tmp_path,
|
|
@@ -136,7 +143,11 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None
|
|
| 136 |
|
| 137 |
|
| 138 |
def run_for_day(
|
| 139 |
-
api,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
) -> int:
|
| 141 |
paths = _iter_session_files(api, source_repo, day, token)
|
| 142 |
n = 0
|
|
@@ -162,11 +173,15 @@ def main(argv: list[str] | None = None) -> int:
|
|
| 162 |
ap.add_argument("--source", default="smolagents/ml-intern-sessions")
|
| 163 |
ap.add_argument("--target", default="smolagents/ml-intern-sft")
|
| 164 |
ap.add_argument(
|
| 165 |
-
"--days",
|
|
|
|
|
|
|
| 166 |
help="Number of trailing days to export (default: 1 = yesterday).",
|
| 167 |
)
|
| 168 |
ap.add_argument(
|
| 169 |
-
"--date",
|
|
|
|
|
|
|
| 170 |
help="Single YYYY-MM-DD to export; overrides --days.",
|
| 171 |
)
|
| 172 |
args = ap.parse_args(argv)
|
|
@@ -185,6 +200,7 @@ def main(argv: list[str] | None = None) -> int:
|
|
| 185 |
return 1
|
| 186 |
|
| 187 |
from huggingface_hub import HfApi
|
|
|
|
| 188 |
api = HfApi()
|
| 189 |
|
| 190 |
if args.date:
|
|
|
|
| 62 |
|
| 63 |
def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
|
| 64 |
from huggingface_hub import hf_hub_download
|
| 65 |
+
|
| 66 |
try:
|
| 67 |
local = hf_hub_download(
|
| 68 |
+
repo_id=repo_id,
|
| 69 |
+
filename=path,
|
| 70 |
+
repo_type="dataset",
|
| 71 |
+
token=token,
|
| 72 |
)
|
| 73 |
except Exception as e:
|
| 74 |
logger.warning("hf_hub_download(%s) failed: %s", path, e)
|
|
|
|
| 122 |
tmp_path = tmp.name
|
| 123 |
try:
|
| 124 |
api.create_repo(
|
| 125 |
+
repo_id=target_repo,
|
| 126 |
+
repo_type="dataset",
|
| 127 |
+
exist_ok=True,
|
| 128 |
+
token=token,
|
| 129 |
)
|
| 130 |
api.upload_file(
|
| 131 |
path_or_fileobj=tmp_path,
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
def run_for_day(
|
| 146 |
+
api,
|
| 147 |
+
source_repo: str,
|
| 148 |
+
target_repo: str,
|
| 149 |
+
day: date,
|
| 150 |
+
token: str,
|
| 151 |
) -> int:
|
| 152 |
paths = _iter_session_files(api, source_repo, day, token)
|
| 153 |
n = 0
|
|
|
|
| 173 |
ap.add_argument("--source", default="smolagents/ml-intern-sessions")
|
| 174 |
ap.add_argument("--target", default="smolagents/ml-intern-sft")
|
| 175 |
ap.add_argument(
|
| 176 |
+
"--days",
|
| 177 |
+
type=int,
|
| 178 |
+
default=1,
|
| 179 |
help="Number of trailing days to export (default: 1 = yesterday).",
|
| 180 |
)
|
| 181 |
ap.add_argument(
|
| 182 |
+
"--date",
|
| 183 |
+
type=str,
|
| 184 |
+
default=None,
|
| 185 |
help="Single YYYY-MM-DD to export; overrides --days.",
|
| 186 |
)
|
| 187 |
args = ap.parse_args(argv)
|
|
|
|
| 200 |
return 1
|
| 201 |
|
| 202 |
from huggingface_hub import HfApi
|
| 203 |
+
|
| 204 |
api = HfApi()
|
| 205 |
|
| 206 |
if args.date:
|