Deploy 2026-04-28
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- README.md +50 -0
- agent/config.py +108 -3
- agent/context_manager/manager.py +3 -0
- agent/core/agent_loop.py +217 -19
- agent/core/doom_loop.py +37 -6
- agent/core/hf_access.py +7 -3
- agent/core/hf_tokens.py +85 -0
- agent/core/llm_params.py +10 -9
- agent/core/session.py +169 -7
- agent/core/session_persistence.py +428 -0
- agent/core/session_uploader.py +2 -0
- agent/core/tools.py +14 -0
- agent/main.py +42 -30
- agent/messaging/__init__.py +15 -0
- agent/messaging/base.py +27 -0
- agent/messaging/gateway.py +166 -0
- agent/messaging/models.py +123 -0
- agent/messaging/slack.py +186 -0
- agent/prompts/system_prompt_v3.yaml +36 -3
- agent/tools/__init__.py +3 -0
- agent/tools/jobs_tool.py +86 -20
- agent/tools/notify_tool.py +108 -0
- agent/tools/research_tool.py +8 -3
- agent/tools/sandbox_client.py +55 -14
- agent/tools/sandbox_tool.py +171 -5
- agent/tools/trackio_seed.py +205 -0
- agent/tools/web_search_tool.py +273 -0
- backend/dependencies.py +7 -7
- backend/main.py +8 -5
- backend/models.py +9 -1
- backend/routes/agent.py +131 -76
- backend/session_manager.py +466 -58
- backend/user_quotas.py +42 -4
- configs/__init__.py +0 -0
- configs/cli_agent_config.json +5 -0
- frontend/src/components/Chat/MarkdownContent.tsx +10 -2
- frontend/src/components/Chat/ToolCallGroup.tsx +201 -1
- frontend/src/components/JobsUpgradeDialog.tsx +63 -54
- frontend/src/components/SessionSidebar/SessionSidebar.tsx +19 -2
- frontend/src/components/WelcomeScreen/WelcomeScreen.tsx +157 -141
- frontend/src/hooks/useAgentChat.ts +115 -55
- frontend/src/lib/sse-chat-transport.ts +79 -20
- frontend/src/store/agentStore.ts +83 -0
- frontend/src/store/sessionStore.ts +47 -0
- frontend/src/types/events.ts +1 -0
- pyproject.toml +17 -3
- scripts/build_kpis.py +133 -16
- scripts/sweep_orphan_sandboxes.py +206 -0
- tests/integration/test_live_sandbox_auth.py +90 -0
- tests/integration/test_live_thinking_models.py +151 -0
README.md
CHANGED
|
@@ -75,6 +75,56 @@ ml-intern --max-iterations 100 "your prompt"
|
|
| 75 |
ml-intern --no-stream "your prompt"
|
| 76 |
```
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
## Architecture
|
| 79 |
|
| 80 |
### Component Overview
|
|
|
|
| 75 |
ml-intern --no-stream "your prompt"
|
| 76 |
```
|
| 77 |
|
| 78 |
+
## Supported Gateways
|
| 79 |
+
|
| 80 |
+
ML Intern currently supports one-way notification gateways from CLI sessions.
|
| 81 |
+
These gateways send out-of-band status updates; they do not accept inbound chat
|
| 82 |
+
messages.
|
| 83 |
+
|
| 84 |
+
### Slack
|
| 85 |
+
|
| 86 |
+
Slack notifications use the Slack Web API to post messages when the agent needs
|
| 87 |
+
approval, hits an error, or completes a turn. Create a Slack app with a bot token
|
| 88 |
+
that has `chat:write`, invite the bot to the target channel, then set:
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
SLACK_BOT_TOKEN=xoxb-...
|
| 92 |
+
SLACK_CHANNEL_ID=C...
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
The CLI automatically creates a `slack.default` destination when both variables
|
| 96 |
+
are present. Optional environment variables for the env-only default:
|
| 97 |
+
|
| 98 |
+
```bash
|
| 99 |
+
ML_INTERN_SLACK_NOTIFICATIONS=false
|
| 100 |
+
ML_INTERN_SLACK_DESTINATION=slack.ops
|
| 101 |
+
ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete
|
| 102 |
+
ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true
|
| 103 |
+
ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
For a persistent user-level config, put overrides in
|
| 107 |
+
`~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a
|
| 108 |
+
JSON file:
|
| 109 |
+
|
| 110 |
+
```json
|
| 111 |
+
{
|
| 112 |
+
"messaging": {
|
| 113 |
+
"enabled": true,
|
| 114 |
+
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
| 115 |
+
"destinations": {
|
| 116 |
+
"slack.ops": {
|
| 117 |
+
"provider": "slack",
|
| 118 |
+
"token": "${SLACK_BOT_TOKEN}",
|
| 119 |
+
"channel": "${SLACK_CHANNEL_ID}",
|
| 120 |
+
"allow_agent_tool": true,
|
| 121 |
+
"allow_auto_events": true
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
## Architecture
|
| 129 |
|
| 130 |
### Component Overview
|
agent/config.py
CHANGED
|
@@ -6,6 +6,8 @@ from typing import Any, Union
|
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
|
|
|
|
|
|
| 9 |
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 10 |
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 11 |
from fastmcp.mcp_config import (
|
|
@@ -47,6 +49,104 @@ class Config(BaseModel):
|
|
| 47 |
# ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
|
| 48 |
# Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
|
| 49 |
reasoning_effort: str | None = "max"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def substitute_env_vars(obj: Any) -> Any:
|
|
@@ -86,7 +186,10 @@ def substitute_env_vars(obj: Any) -> Any:
|
|
| 86 |
return obj
|
| 87 |
|
| 88 |
|
| 89 |
-
def load_config(
|
|
|
|
|
|
|
|
|
|
| 90 |
"""
|
| 91 |
Load configuration with environment variable substitution.
|
| 92 |
|
|
@@ -98,8 +201,10 @@ def load_config(config_path: str = "config.json") -> Config:
|
|
| 98 |
load_dotenv(_PROJECT_ROOT / ".env")
|
| 99 |
load_dotenv(override=False)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
config_with_env = substitute_env_vars(raw_config)
|
| 105 |
return Config.model_validate(config_with_env)
|
|
|
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
| 9 |
+
from agent.messaging.models import MessagingConfig
|
| 10 |
+
|
| 11 |
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 12 |
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 13 |
from fastmcp.mcp_config import (
|
|
|
|
| 49 |
# ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
|
| 50 |
# Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
|
| 51 |
reasoning_effort: str | None = "max"
|
| 52 |
+
messaging: MessagingConfig = MessagingConfig()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Name of the environment variable that, when set, points at the user-level
# JSON config file to load (read by _load_user_config).
USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
|
| 56 |
+
# User-level config file consulted when USER_CONFIG_ENV_VAR is not set.
DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
|
| 57 |
+
# Destination name used for the env-derived Slack destination when
# ML_INTERN_SLACK_DESTINATION does not provide one.
SLACK_DEFAULT_DESTINATION = "slack.default"
|
| 58 |
+
# Event types applied as messaging["auto_event_types"] when neither the
# config nor ML_INTERN_SLACK_AUTO_EVENTS specifies them.
SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with ``override`` layered on top of ``base``.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the value from ``override`` replaces the
    one from ``base``. Neither input dict is modified at the level being
    merged.

    Args:
        base: Lower-priority configuration mapping.
        override: Higher-priority configuration mapping.

    Returns:
        The merged mapping.
    """
    result = {**base}
    for name, incoming in override.items():
        existing = result.get(name)
        both_are_dicts = isinstance(existing, dict) and isinstance(incoming, dict)
        # Only dict-vs-dict collisions recurse; anything else is a plain overwrite.
        result[name] = (
            _deep_merge_config(existing, incoming) if both_are_dicts else incoming
        )
    return result
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _load_json_config(path: Path) -> dict[str, Any]:
    """Read ``path`` as UTF-8 JSON and return its top-level object.

    Args:
        path: Location of the JSON config file.

    Returns:
        The parsed JSON object as a dict.

    Raises:
        ValueError: If the file parses but its top-level value is not a
            JSON object.
    """
    with open(path, "r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"Config file {path} must contain a JSON object")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _load_user_config() -> dict[str, Any]:
    """Load the user-level config overrides, if any.

    The file named by the ``USER_CONFIG_ENV_VAR`` environment variable takes
    priority. When that variable is unset or empty, ``DEFAULT_USER_CONFIG_PATH``
    is used if it is a regular file.

    Returns:
        The parsed JSON object, or an empty dict when no user config applies.

    Raises:
        FileNotFoundError: If the environment variable is set but does not
            point at an existing file.
        ValueError: If the selected file does not contain a JSON object
            (raised by ``_load_json_config``).
    """
    raw_path = os.environ.get(USER_CONFIG_ENV_VAR)
    if raw_path:
        path = Path(raw_path).expanduser()
        # is_file() rather than exists(): a directory at this path would pass
        # an existence check and then fail inside open() with an unrelated error.
        if not path.is_file():
            raise FileNotFoundError(
                f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}"
            )
        return _load_json_config(path)

    # The default location is optional, so anything that is not a regular file
    # is treated the same as "no user config".
    if DEFAULT_USER_CONFIG_PATH.is_file():
        return _load_json_config(DEFAULT_USER_CONFIG_PATH)
    return {}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _env_bool(name: str, default: bool) -> bool:
    """Interpret environment variable ``name`` as a boolean flag.

    Matching is case-insensitive and ignores surrounding whitespace.
    ``1/true/yes/on`` map to True and ``0/false/no/off`` map to False.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or holds an
            unrecognised value.

    Returns:
        The parsed flag, or ``default``.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default

    recognised = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return recognised.get(raw.strip().lower(), default)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _env_list(name: str) -> list[str] | None:
|
| 108 |
+
value = os.environ.get(name)
|
| 109 |
+
if value is None:
|
| 110 |
+
return None
|
| 111 |
+
return [item.strip() for item in value.split(",") if item.strip()]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]:
    """Enable a default Slack destination from user env vars, when present.

    ``raw_config`` is returned unchanged when ``ML_INTERN_SLACK_NOTIFICATIONS``
    parses as false, or when ``SLACK_BOT_TOKEN`` or a channel
    (``SLACK_CHANNEL_ID``, falling back to ``SLACK_CHANNEL``) is missing.
    Otherwise a copy is returned whose ``messaging`` section is enabled and
    contains the env-derived destination. A destination already configured
    under the same name is left untouched.

    Args:
        raw_config: Parsed configuration mapping; it is not mutated.

    Returns:
        Either ``raw_config`` itself or an updated copy.
    """
    if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True):
        return raw_config

    token = os.environ.get("SLACK_BOT_TOKEN")
    channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL")
    if not token or not channel:
        return raw_config

    # Copy every level that gets written so the caller's dicts stay intact.
    config = dict(raw_config)
    messaging = dict(config.get("messaging") or {})
    destinations = dict(messaging.get("destinations") or {})

    # Strip *before* applying the fallback: a whitespace-only override would
    # otherwise survive the `or` and become an empty destination name.
    destination_name = (
        (os.environ.get("ML_INTERN_SLACK_DESTINATION") or "").strip()
        or SLACK_DEFAULT_DESTINATION
    )

    # An explicitly configured destination with this name takes precedence.
    if destination_name not in destinations:
        destinations[destination_name] = {
            "provider": "slack",
            "token": token,
            "channel": channel,
            "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True),
            "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True),
        }

    auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS")
    if auto_events is not None:
        messaging["auto_event_types"] = auto_events
    elif "auto_event_types" not in messaging:
        # Hand out a copy so later mutation of the config cannot alter the
        # module-level default list.
        messaging["auto_event_types"] = list(SLACK_DEFAULT_AUTO_EVENT_TYPES)

    messaging["enabled"] = True
    messaging["destinations"] = destinations
    config["messaging"] = messaging
    return config
|
| 150 |
|
| 151 |
|
| 152 |
def substitute_env_vars(obj: Any) -> Any:
|
|
|
|
| 186 |
return obj
|
| 187 |
|
| 188 |
|
| 189 |
+
def load_config(
|
| 190 |
+
config_path: str = "config.json",
|
| 191 |
+
include_user_defaults: bool = False,
|
| 192 |
+
) -> Config:
|
| 193 |
"""
|
| 194 |
Load configuration with environment variable substitution.
|
| 195 |
|
|
|
|
| 201 |
load_dotenv(_PROJECT_ROOT / ".env")
|
| 202 |
load_dotenv(override=False)
|
| 203 |
|
| 204 |
+
raw_config = _load_json_config(Path(config_path))
|
| 205 |
+
if include_user_defaults:
|
| 206 |
+
raw_config = _deep_merge_config(raw_config, _load_user_config())
|
| 207 |
+
raw_config = apply_slack_user_defaults(raw_config)
|
| 208 |
|
| 209 |
config_with_env = substitute_env_vars(raw_config)
|
| 210 |
return Config.model_validate(config_with_env)
|
agent/context_manager/manager.py
CHANGED
|
@@ -160,6 +160,7 @@ class ContextManager:
|
|
| 160 |
self.running_context_usage = 0
|
| 161 |
self.untouched_messages = untouched_messages
|
| 162 |
self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
|
|
|
|
| 163 |
|
| 164 |
def _load_system_prompt(
|
| 165 |
self,
|
|
@@ -219,6 +220,8 @@ class ContextManager:
|
|
| 219 |
if token_count:
|
| 220 |
self.running_context_usage = token_count
|
| 221 |
self.items.append(message)
|
|
|
|
|
|
|
| 222 |
|
| 223 |
def get_messages(self) -> list[Message]:
|
| 224 |
"""Get all messages for sending to LLM.
|
|
|
|
| 160 |
self.running_context_usage = 0
|
| 161 |
self.untouched_messages = untouched_messages
|
| 162 |
self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
|
| 163 |
+
self.on_message_added = None
|
| 164 |
|
| 165 |
def _load_system_prompt(
|
| 166 |
self,
|
|
|
|
| 220 |
if token_count:
|
| 221 |
self.running_context_usage = token_count
|
| 222 |
self.items.append(message)
|
| 223 |
+
if self.on_message_added:
|
| 224 |
+
self.on_message_added(message)
|
| 225 |
|
| 226 |
def get_messages(self) -> list[Message]:
|
| 227 |
"""Get all messages for sending to LLM.
|
agent/core/agent_loop.py
CHANGED
|
@@ -8,11 +8,18 @@ import logging
|
|
| 8 |
import os
|
| 9 |
import time
|
| 10 |
from dataclasses import dataclass, field
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from litellm.exceptions import ContextWindowExceededError
|
| 14 |
|
| 15 |
from agent.config import Config
|
|
|
|
| 16 |
from agent.core import telemetry
|
| 17 |
from agent.core.doom_loop import check_for_doom_loop
|
| 18 |
from agent.core.llm_params import _resolve_llm_params
|
|
@@ -396,12 +403,159 @@ class LLMResult:
|
|
| 396 |
token_count: int
|
| 397 |
finish_reason: str | None
|
| 398 |
usage: dict = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
|
| 401 |
async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
|
| 402 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 403 |
response = None
|
| 404 |
_healed_effort = False # one-shot safety net per call
|
|
|
|
| 405 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 406 |
t_start = time.monotonic()
|
| 407 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
@@ -429,6 +583,14 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 429 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 430 |
))
|
| 431 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 433 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 434 |
logger.warning(
|
|
@@ -448,8 +610,11 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 448 |
token_count = 0
|
| 449 |
finish_reason = None
|
| 450 |
final_usage_chunk = None
|
|
|
|
|
|
|
| 451 |
|
| 452 |
async for chunk in response:
|
|
|
|
| 453 |
if session.is_cancelled:
|
| 454 |
tool_calls_acc.clear()
|
| 455 |
break
|
|
@@ -498,6 +663,16 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 498 |
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 499 |
finish_reason=finish_reason,
|
| 500 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
return LLMResult(
|
| 503 |
content=full_content or None,
|
|
@@ -505,6 +680,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 505 |
token_count=token_count,
|
| 506 |
finish_reason=finish_reason,
|
| 507 |
usage=usage,
|
|
|
|
|
|
|
| 508 |
)
|
| 509 |
|
| 510 |
|
|
@@ -512,6 +689,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 512 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 513 |
response = None
|
| 514 |
_healed_effort = False
|
|
|
|
| 515 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 516 |
t_start = time.monotonic()
|
| 517 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
@@ -538,6 +716,14 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 538 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 539 |
))
|
| 540 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 542 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 543 |
logger.warning(
|
|
@@ -557,6 +743,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 557 |
content = message.content or None
|
| 558 |
finish_reason = choice.finish_reason
|
| 559 |
token_count = response.usage.total_tokens if response.usage else 0
|
|
|
|
| 560 |
|
| 561 |
# Build tool_calls_acc in the same format as streaming
|
| 562 |
tool_calls_acc: dict[int, dict] = {}
|
|
@@ -591,6 +778,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 591 |
token_count=token_count,
|
| 592 |
finish_reason=finish_reason,
|
| 593 |
usage=usage,
|
|
|
|
|
|
|
| 594 |
)
|
| 595 |
|
| 596 |
|
|
@@ -681,15 +870,6 @@ class Handlers:
|
|
| 681 |
session.context_manager.add_message(
|
| 682 |
Message(role="user", content=doom_prompt)
|
| 683 |
)
|
| 684 |
-
await session.send_event(
|
| 685 |
-
Event(
|
| 686 |
-
event_type="tool_log",
|
| 687 |
-
data={
|
| 688 |
-
"tool": "system",
|
| 689 |
-
"log": "Doom loop detected — injecting corrective prompt",
|
| 690 |
-
},
|
| 691 |
-
)
|
| 692 |
-
)
|
| 693 |
|
| 694 |
malformed_tool = _detect_repeated_malformed(session.context_manager.items)
|
| 695 |
if malformed_tool:
|
|
@@ -763,7 +943,10 @@ class Handlers:
|
|
| 763 |
" • For other tools: reduce the size of your arguments or use bash."
|
| 764 |
)
|
| 765 |
if content:
|
| 766 |
-
assistant_msg =
|
|
|
|
|
|
|
|
|
|
| 767 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 768 |
session.context_manager.add_message(
|
| 769 |
Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
|
|
@@ -819,7 +1002,10 @@ class Handlers:
|
|
| 819 |
(content or "")[:500],
|
| 820 |
)
|
| 821 |
if content:
|
| 822 |
-
assistant_msg =
|
|
|
|
|
|
|
|
|
|
| 823 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 824 |
final_response = content
|
| 825 |
break
|
|
@@ -841,9 +1027,9 @@ class Handlers:
|
|
| 841 |
bad_tools.append(tc)
|
| 842 |
|
| 843 |
# Add assistant message with all tool calls to context
|
| 844 |
-
assistant_msg =
|
| 845 |
-
|
| 846 |
-
|
| 847 |
tool_calls=tool_calls,
|
| 848 |
)
|
| 849 |
session.context_manager.add_message(assistant_msg, token_count)
|
|
@@ -1049,7 +1235,12 @@ class Handlers:
|
|
| 1049 |
await session.send_event(
|
| 1050 |
Event(
|
| 1051 |
event_type="turn_complete",
|
| 1052 |
-
data={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
)
|
| 1054 |
)
|
| 1055 |
|
|
@@ -1358,12 +1549,16 @@ async def process_submission(session: Session, submission) -> bool:
|
|
| 1358 |
async def submission_loop(
|
| 1359 |
submission_queue: asyncio.Queue,
|
| 1360 |
event_queue: asyncio.Queue,
|
| 1361 |
-
config: Config
|
| 1362 |
tool_router: ToolRouter | None = None,
|
| 1363 |
session_holder: list | None = None,
|
| 1364 |
hf_token: str | None = None,
|
|
|
|
| 1365 |
local_mode: bool = False,
|
| 1366 |
stream: bool = True,
|
|
|
|
|
|
|
|
|
|
| 1367 |
) -> None:
|
| 1368 |
"""
|
| 1369 |
Main agent loop - processes submissions and dispatches to handlers.
|
|
@@ -1373,7 +1568,10 @@ async def submission_loop(
|
|
| 1373 |
# Create session with tool router
|
| 1374 |
session = Session(
|
| 1375 |
event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
|
| 1376 |
-
local_mode=local_mode, stream=stream,
|
|
|
|
|
|
|
|
|
|
| 1377 |
)
|
| 1378 |
if session_holder is not None:
|
| 1379 |
session_holder[0] = session
|
|
|
|
| 8 |
import os
|
| 9 |
import time
|
| 10 |
from dataclasses import dataclass, field
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from litellm import (
|
| 14 |
+
ChatCompletionMessageToolCall,
|
| 15 |
+
Message,
|
| 16 |
+
acompletion,
|
| 17 |
+
stream_chunk_builder,
|
| 18 |
+
)
|
| 19 |
from litellm.exceptions import ContextWindowExceededError
|
| 20 |
|
| 21 |
from agent.config import Config
|
| 22 |
+
from agent.messaging.gateway import NotificationGateway
|
| 23 |
from agent.core import telemetry
|
| 24 |
from agent.core.doom_loop import check_for_doom_loop
|
| 25 |
from agent.core.llm_params import _resolve_llm_params
|
|
|
|
| 403 |
token_count: int
|
| 404 |
finish_reason: str | None
|
| 405 |
usage: dict = field(default_factory=dict)
|
| 406 |
+
thinking_blocks: list[dict[str, Any]] | None = None
|
| 407 |
+
reasoning_content: str | None = None
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def _extract_thinking_state(
|
| 411 |
+
message: Any,
|
| 412 |
+
) -> tuple[list[dict[str, Any]] | None, str | None]:
|
| 413 |
+
"""Return provider reasoning fields that must be replayed after tool calls."""
|
| 414 |
+
provider_fields = getattr(message, "provider_specific_fields", None)
|
| 415 |
+
if not isinstance(provider_fields, dict):
|
| 416 |
+
provider_fields = {}
|
| 417 |
+
|
| 418 |
+
thinking_blocks = (
|
| 419 |
+
getattr(message, "thinking_blocks", None)
|
| 420 |
+
or provider_fields.get("thinking_blocks")
|
| 421 |
+
or None
|
| 422 |
+
)
|
| 423 |
+
reasoning_content = (
|
| 424 |
+
getattr(message, "reasoning_content", None)
|
| 425 |
+
or provider_fields.get("reasoning_content")
|
| 426 |
+
or None
|
| 427 |
+
)
|
| 428 |
+
return thinking_blocks, reasoning_content
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def _should_replay_thinking_state(model_name: str | None) -> bool:
|
| 432 |
+
"""Only Anthropic's native adapter accepts replayed thinking metadata."""
|
| 433 |
+
return bool(model_name and model_name.startswith("anthropic/"))
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _is_invalid_thinking_signature_error(exc: Exception) -> bool:
    """Return True when Anthropic rejected replayed extended-thinking state.

    Detection is a substring match on ``str(exc)`` against the two known
    spellings of the error (with and without backticks).

    Args:
        exc: Exception raised by the LLM call.

    Returns:
        Whether the exception text contains one of the known markers.
    """
    rendered = str(exc)
    markers = (
        "Invalid `signature` in `thinking` block",
        "Invalid signature in thinking block",
    )
    return any(marker in rendered for marker in markers)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _strip_thinking_state_from_messages(messages: list[Any]) -> int:
    """Remove replayed thinking metadata from assistant history messages.

    Messages may be dicts or attribute-style objects and are edited in place.
    Only messages whose role is ``"assistant"`` are touched. For each one the
    ``thinking_blocks`` / ``reasoning_content`` fields are cleared, the same
    keys are dropped from ``provider_specific_fields``, and ``thinking`` /
    ``redacted_thinking`` blocks are filtered out of list-style content.

    Args:
        messages: Conversation history to scrub.

    Returns:
        Number of non-``None`` fields and content blocks that were removed.
    """
    thinking_keys = ("thinking_blocks", "reasoning_content")
    thinking_block_types = {"thinking", "redacted_thinking"}
    removed = 0

    for msg in messages:
        is_mapping = isinstance(msg, dict)
        role = msg.get("role") if is_mapping else getattr(msg, "role", None)
        if role != "assistant":
            continue

        # Top-level fields: dict keys are deleted, object attributes are
        # reset to None. Only values that were not already None are counted.
        if is_mapping:
            for key in thinking_keys:
                if msg.pop(key, None) is not None:
                    removed += 1
            extras = msg.get("provider_specific_fields")
            body = msg.get("content")
        else:
            for key in thinking_keys:
                if getattr(msg, key, None) is not None:
                    setattr(msg, key, None)
                    removed += 1
            extras = getattr(msg, "provider_specific_fields", None)
            body = getattr(msg, "content", None)

        # Provider-specific fields: work on a copy and write it back only
        # when something was actually dropped.
        if isinstance(extras, dict):
            scrubbed = dict(extras)
            for key in thinking_keys:
                if scrubbed.pop(key, None) is not None:
                    removed += 1
            if scrubbed != extras:
                if is_mapping:
                    msg["provider_specific_fields"] = scrubbed
                else:
                    msg.provider_specific_fields = scrubbed

        # Structured content: keep everything that is not a thinking block.
        if isinstance(body, list):
            kept = [
                part
                for part in body
                if not isinstance(part, dict)
                or part.get("type") not in thinking_block_types
            ]
            dropped = len(body) - len(kept)
            if dropped:
                removed += dropped
                if is_mapping:
                    msg["content"] = kept
                else:
                    msg.content = kept

    return removed
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
async def _maybe_heal_invalid_thinking_signature(
    session: Session,
    messages: list[Any],
    exc: Exception,
    *,
    already_healed: bool,
) -> bool:
    """Try to recover from an invalid-thinking-signature error.

    When ``exc`` is such an error and no heal has happened yet, thinking
    metadata is stripped from ``messages`` in place and a ``tool_log`` event
    is sent on ``session``.

    Args:
        session: Session used to emit the log event.
        messages: Message history to scrub in place.
        exc: Exception raised by the LLM call.
        already_healed: Whether a heal was already performed; if so nothing
            is done.

    Returns:
        True when metadata was stripped (the caller may retry), else False.
    """
    if already_healed:
        return False
    if not _is_invalid_thinking_signature_error(exc):
        return False

    # Nothing removed means a retry would send the same payload again.
    if _strip_thinking_state_from_messages(messages) == 0:
        return False

    notice = Event(
        event_type="tool_log",
        data={
            "tool": "system",
            "log": (
                "Anthropic rejected stale thinking signatures; retrying "
                "without replayed thinking metadata."
            ),
        },
    )
    await session.send_event(notice)
    return True
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _assistant_message_from_result(
    llm_result: LLMResult,
    *,
    model_name: str | None,
    tool_calls: list[ToolCall] | None = None,
) -> Message:
    """Build an assistant history message without dropping reasoning state.

    Args:
        llm_result: Result of the LLM call; supplies the content and any
            thinking state.
        model_name: Model identifier, used to decide whether thinking state
            is attached to the message.
        tool_calls: Tool calls to attach; omitted from the message when None.

    Returns:
        The assistant ``Message``.
    """
    fields: dict[str, Any] = {"role": "assistant", "content": llm_result.content}
    if tool_calls is not None:
        fields["tool_calls"] = tool_calls

    # Thinking state is attached only for models that accept it on replay,
    # and only when the field is non-empty.
    if _should_replay_thinking_state(model_name):
        for attr in ("thinking_blocks", "reasoning_content"):
            value = getattr(llm_result, attr)
            if value:
                fields[attr] = value

    return Message(**fields)
|
| 552 |
|
| 553 |
|
| 554 |
async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
|
| 555 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 556 |
response = None
|
| 557 |
_healed_effort = False # one-shot safety net per call
|
| 558 |
+
_healed_thinking_signature = False
|
| 559 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 560 |
t_start = time.monotonic()
|
| 561 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
|
|
| 583 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 584 |
))
|
| 585 |
continue
|
| 586 |
+
if await _maybe_heal_invalid_thinking_signature(
|
| 587 |
+
session,
|
| 588 |
+
messages,
|
| 589 |
+
e,
|
| 590 |
+
already_healed=_healed_thinking_signature,
|
| 591 |
+
):
|
| 592 |
+
_healed_thinking_signature = True
|
| 593 |
+
continue
|
| 594 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 595 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 596 |
logger.warning(
|
|
|
|
| 610 |
token_count = 0
|
| 611 |
finish_reason = None
|
| 612 |
final_usage_chunk = None
|
| 613 |
+
chunks = []
|
| 614 |
+
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
|
| 615 |
|
| 616 |
async for chunk in response:
|
| 617 |
+
chunks.append(chunk)
|
| 618 |
if session.is_cancelled:
|
| 619 |
tool_calls_acc.clear()
|
| 620 |
break
|
|
|
|
| 663 |
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 664 |
finish_reason=finish_reason,
|
| 665 |
)
|
| 666 |
+
thinking_blocks = None
|
| 667 |
+
reasoning_content = None
|
| 668 |
+
if chunks and should_replay_thinking:
|
| 669 |
+
try:
|
| 670 |
+
rebuilt = stream_chunk_builder(chunks, messages=messages)
|
| 671 |
+
if rebuilt and getattr(rebuilt, "choices", None):
|
| 672 |
+
rebuilt_msg = rebuilt.choices[0].message
|
| 673 |
+
thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg)
|
| 674 |
+
except Exception:
|
| 675 |
+
logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
|
| 676 |
|
| 677 |
return LLMResult(
|
| 678 |
content=full_content or None,
|
|
|
|
| 680 |
token_count=token_count,
|
| 681 |
finish_reason=finish_reason,
|
| 682 |
usage=usage,
|
| 683 |
+
thinking_blocks=thinking_blocks,
|
| 684 |
+
reasoning_content=reasoning_content,
|
| 685 |
)
|
| 686 |
|
| 687 |
|
|
|
|
| 689 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 690 |
response = None
|
| 691 |
_healed_effort = False
|
| 692 |
+
_healed_thinking_signature = False
|
| 693 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 694 |
t_start = time.monotonic()
|
| 695 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
|
|
| 716 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 717 |
))
|
| 718 |
continue
|
| 719 |
+
if await _maybe_heal_invalid_thinking_signature(
|
| 720 |
+
session,
|
| 721 |
+
messages,
|
| 722 |
+
e,
|
| 723 |
+
already_healed=_healed_thinking_signature,
|
| 724 |
+
):
|
| 725 |
+
_healed_thinking_signature = True
|
| 726 |
+
continue
|
| 727 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 728 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 729 |
logger.warning(
|
|
|
|
| 743 |
content = message.content or None
|
| 744 |
finish_reason = choice.finish_reason
|
| 745 |
token_count = response.usage.total_tokens if response.usage else 0
|
| 746 |
+
thinking_blocks, reasoning_content = _extract_thinking_state(message)
|
| 747 |
|
| 748 |
# Build tool_calls_acc in the same format as streaming
|
| 749 |
tool_calls_acc: dict[int, dict] = {}
|
|
|
|
| 778 |
token_count=token_count,
|
| 779 |
finish_reason=finish_reason,
|
| 780 |
usage=usage,
|
| 781 |
+
thinking_blocks=thinking_blocks,
|
| 782 |
+
reasoning_content=reasoning_content,
|
| 783 |
)
|
| 784 |
|
| 785 |
|
|
|
|
| 870 |
session.context_manager.add_message(
|
| 871 |
Message(role="user", content=doom_prompt)
|
| 872 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
|
| 874 |
malformed_tool = _detect_repeated_malformed(session.context_manager.items)
|
| 875 |
if malformed_tool:
|
|
|
|
| 943 |
" • For other tools: reduce the size of your arguments or use bash."
|
| 944 |
)
|
| 945 |
if content:
|
| 946 |
+
assistant_msg = _assistant_message_from_result(
|
| 947 |
+
llm_result,
|
| 948 |
+
model_name=llm_params.get("model"),
|
| 949 |
+
)
|
| 950 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 951 |
session.context_manager.add_message(
|
| 952 |
Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
|
|
|
|
| 1002 |
(content or "")[:500],
|
| 1003 |
)
|
| 1004 |
if content:
|
| 1005 |
+
assistant_msg = _assistant_message_from_result(
|
| 1006 |
+
llm_result,
|
| 1007 |
+
model_name=llm_params.get("model"),
|
| 1008 |
+
)
|
| 1009 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 1010 |
final_response = content
|
| 1011 |
break
|
|
|
|
| 1027 |
bad_tools.append(tc)
|
| 1028 |
|
| 1029 |
# Add assistant message with all tool calls to context
|
| 1030 |
+
assistant_msg = _assistant_message_from_result(
|
| 1031 |
+
llm_result,
|
| 1032 |
+
model_name=llm_params.get("model"),
|
| 1033 |
tool_calls=tool_calls,
|
| 1034 |
)
|
| 1035 |
session.context_manager.add_message(assistant_msg, token_count)
|
|
|
|
| 1235 |
await session.send_event(
|
| 1236 |
Event(
|
| 1237 |
event_type="turn_complete",
|
| 1238 |
+
data={
|
| 1239 |
+
"history_size": len(session.context_manager.items),
|
| 1240 |
+
"final_response": final_response
|
| 1241 |
+
if isinstance(final_response, str)
|
| 1242 |
+
else None,
|
| 1243 |
+
},
|
| 1244 |
)
|
| 1245 |
)
|
| 1246 |
|
|
|
|
| 1549 |
async def submission_loop(
|
| 1550 |
submission_queue: asyncio.Queue,
|
| 1551 |
event_queue: asyncio.Queue,
|
| 1552 |
+
config: Config,
|
| 1553 |
tool_router: ToolRouter | None = None,
|
| 1554 |
session_holder: list | None = None,
|
| 1555 |
hf_token: str | None = None,
|
| 1556 |
+
user_id: str | None = None,
|
| 1557 |
local_mode: bool = False,
|
| 1558 |
stream: bool = True,
|
| 1559 |
+
notification_gateway: NotificationGateway | None = None,
|
| 1560 |
+
notification_destinations: list[str] | None = None,
|
| 1561 |
+
defer_turn_complete_notification: bool = False,
|
| 1562 |
) -> None:
|
| 1563 |
"""
|
| 1564 |
Main agent loop - processes submissions and dispatches to handlers.
|
|
|
|
| 1568 |
# Create session with tool router
|
| 1569 |
session = Session(
|
| 1570 |
event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
|
| 1571 |
+
user_id=user_id, local_mode=local_mode, stream=stream,
|
| 1572 |
+
notification_gateway=notification_gateway,
|
| 1573 |
+
notification_destinations=notification_destinations,
|
| 1574 |
+
defer_turn_complete_notification=defer_turn_complete_notification,
|
| 1575 |
)
|
| 1576 |
if session_holder is not None:
|
| 1577 |
session_holder[0] = session
|
agent/core/doom_loop.py
CHANGED
|
@@ -24,9 +24,36 @@ class ToolCallSignature:
|
|
| 24 |
result_hash: str | None = None
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def _hash_args(args_str: str) -> str:
|
| 28 |
-
"""Return a short hash of the JSON arguments string.
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def extract_recent_tool_signatures(
|
|
@@ -129,9 +156,13 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
|
|
| 129 |
# Check for identical consecutive calls
|
| 130 |
tool_name = detect_identical_consecutive(signatures, threshold=3)
|
| 131 |
if tool_name:
|
| 132 |
-
logger.warning(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
return (
|
| 134 |
-
f"[SYSTEM:
|
| 135 |
f"arguments multiple times in a row, getting the same result each time. "
|
| 136 |
f"STOP repeating this approach — it is not working. "
|
| 137 |
f"Step back and try a fundamentally different strategy. "
|
|
@@ -143,9 +174,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
|
|
| 143 |
pattern = detect_repeating_sequence(signatures)
|
| 144 |
if pattern:
|
| 145 |
pattern_desc = " → ".join(s.name for s in pattern)
|
| 146 |
-
logger.warning("
|
| 147 |
return (
|
| 148 |
-
f"[SYSTEM:
|
| 149 |
f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
|
| 150 |
f"STOP this cycle and try a fundamentally different approach. "
|
| 151 |
f"Consider: breaking down the problem differently, using alternative tools, "
|
|
|
|
| 24 |
result_hash: str | None = None
|
| 25 |
|
| 26 |
|
| 27 |
+
def _normalize_args(args_str: str) -> str:
    """Return a canonical spelling of a tool-call ``arguments`` string.

    Models may emit the same call with different key orderings
    (``{"a": 1, "b": 2}`` vs ``{"b": 2, "a": 1}``) or different whitespace
    (``{"a":1}`` vs ``{"a": 1}``). Hashing those raw strings would hide the
    repetition from the doom-loop detector, so the JSON is parsed and
    re-serialised with sorted keys and the most compact separators.

    Args:
        args_str: Raw ``arguments`` string of a tool call.

    Returns:
        ``""`` for an empty/falsy input, the canonical JSON text when the
        input parses as JSON, and the input unchanged otherwise. Never raises
        for unparseable input.
    """
    if not args_str:
        return ""
    try:
        parsed = json.loads(args_str)
        canonical = json.dumps(parsed, sort_keys=True, separators=(",", ":"))
    except (json.JSONDecodeError, TypeError, ValueError):
        # Not valid JSON (some providers pass a bare string): keep the raw
        # text so the legacy byte-for-byte comparison still applies.
        return args_str
    return canonical
|
| 47 |
+
|
| 48 |
+
|
| 49 |
def _hash_args(args_str: str) -> str:
    """Return a short, stable fingerprint of a tool call's JSON arguments.

    The input is canonicalised with :func:`_normalize_args` first, so
    semantically identical calls produce the same fingerprint regardless of
    key order or whitespace.

    Args:
        args_str: Raw ``arguments`` string of a tool call.

    Returns:
        The first 12 hexadecimal characters of the MD5 digest of the
        normalised arguments.
    """
    canonical = _normalize_args(args_str).encode()
    # MD5 is only a cheap equality fingerprint here, not a security
    # primitive. Declaring that keeps the call working on FIPS-restricted
    # builds where security use of MD5 is blocked; the digest is unchanged.
    return hashlib.md5(canonical, usedforsecurity=False).hexdigest()[:12]
|
| 57 |
|
| 58 |
|
| 59 |
def extract_recent_tool_signatures(
|
|
|
|
| 156 |
# Check for identical consecutive calls
|
| 157 |
tool_name = detect_identical_consecutive(signatures, threshold=3)
|
| 158 |
if tool_name:
|
| 159 |
+
logger.warning(
|
| 160 |
+
"Repetition guard activated: %d+ identical consecutive calls to '%s'",
|
| 161 |
+
3,
|
| 162 |
+
tool_name,
|
| 163 |
+
)
|
| 164 |
return (
|
| 165 |
+
f"[SYSTEM: REPETITION GUARD] You have called '{tool_name}' with the same "
|
| 166 |
f"arguments multiple times in a row, getting the same result each time. "
|
| 167 |
f"STOP repeating this approach — it is not working. "
|
| 168 |
f"Step back and try a fundamentally different strategy. "
|
|
|
|
| 174 |
pattern = detect_repeating_sequence(signatures)
|
| 175 |
if pattern:
|
| 176 |
pattern_desc = " → ".join(s.name for s in pattern)
|
| 177 |
+
logger.warning("Repetition guard activated: repeating sequence [%s]", pattern_desc)
|
| 178 |
return (
|
| 179 |
+
f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
|
| 180 |
f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
|
| 181 |
f"STOP this cycle and try a fundamentally different approach. "
|
| 182 |
f"Consider: breaking down the problem differently, using alternative tools, "
|
agent/core/hf_access.py
CHANGED
|
@@ -55,6 +55,13 @@ def _extract_username(whoami: dict[str, Any]) -> str | None:
|
|
| 55 |
|
| 56 |
|
| 57 |
def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
plan_str = ""
|
| 59 |
for key in ("plan", "type", "accountType"):
|
| 60 |
value = whoami.get(key)
|
|
@@ -62,9 +69,6 @@ def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
|
|
| 62 |
plan_str = value.lower()
|
| 63 |
break
|
| 64 |
|
| 65 |
-
if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True):
|
| 66 |
-
return "pro"
|
| 67 |
-
|
| 68 |
if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
|
| 69 |
return "pro"
|
| 70 |
return "free"
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
|
| 58 |
+
# OAuth whoami responses set `type: "user"` and surface Pro status only via
|
| 59 |
+
# the `isPro` boolean. Check the boolean first so a generic `type` value
|
| 60 |
+
# doesn't shadow it — otherwise Pro OAuth users get classified as free and
|
| 61 |
+
# blocked from running Jobs (smolagents/ml-intern Space discussion #21).
|
| 62 |
+
if whoami.get("isPro") is True or whoami.get("is_pro") is True:
|
| 63 |
+
return "pro"
|
| 64 |
+
|
| 65 |
plan_str = ""
|
| 66 |
for key in ("plan", "type", "accountType"):
|
| 67 |
value = whoami.get(key)
|
|
|
|
| 69 |
plan_str = value.lower()
|
| 70 |
break
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
|
| 73 |
return "pro"
|
| 74 |
return "free"
|
agent/core/hf_tokens.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face token resolution helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def clean_hf_token(token: str | None) -> str | None:
    """Normalize a token string the same way huggingface_hub does.

    Every carriage return and line feed is removed (wherever it appears),
    then surrounding whitespace is stripped.

    Args:
        token: Raw token text, or ``None``.

    Returns:
        The cleaned token, or ``None`` when the input is ``None`` or nothing
        is left after cleaning.
    """
    if token is None:
        return None
    without_newlines = "".join(ch for ch in token if ch not in "\r\n")
    cleaned = without_newlines.strip()
    return cleaned if cleaned else None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_cached_hf_token() -> str | None:
    """Return the token found by huggingface_hub's own env/cache lookup.

    ``huggingface_hub`` is imported lazily inside the function. Any failure,
    whether from the import or from the lookup itself, is treated as
    "no token available".

    Returns:
        Whatever ``huggingface_hub.get_token()`` returns, or ``None`` if the
        import or the call raised.
    """
    try:
        from huggingface_hub import get_token

        cached = get_token()
    except Exception:
        # Deliberately broad: this helper is a best-effort fallback.
        cached = None
    return cached
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def resolve_hf_token(
    *candidates: str | None,
    include_cached: bool = True,
) -> str | None:
    """Return the first usable explicit token, then optionally the HF cache.

    Args:
        *candidates: Explicit token strings in priority order; each one is
            cleaned with :func:`clean_hf_token` and skipped when empty.
        include_cached: When true and no candidate is usable, fall back to
            :func:`get_cached_hf_token`.

    Returns:
        The first non-empty cleaned candidate, else the cached token when
        ``include_cached`` is true, else ``None``.
    """
    # The generator is lazy, so cleaning stops at the first usable candidate.
    explicit = next(
        (cleaned for cleaned in (clean_hf_token(c) for c in candidates) if cleaned),
        None,
    )
    if explicit is not None:
        return explicit
    return get_cached_hf_token() if include_cached else None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def resolve_hf_router_token(session_hf_token: str | None = None) -> str | None:
    """Resolve the token used for Hugging Face Router LLM calls.

    App-specific precedence:
      1. INFERENCE_TOKEN: shared hosted-Space inference token.
      2. session_hf_token: the active user/session token.
      3. huggingface_hub.get_token(): HF_TOKEN/HUGGING_FACE_HUB_TOKEN or
         local ``hf auth login`` cache.

    Args:
        session_hf_token: Token belonging to the active session, if any.

    Returns:
        The first usable token in the order above, or ``None``.
    """
    shared_token = os.environ.get("INFERENCE_TOKEN")
    return resolve_hf_token(shared_token, session_hf_token)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_hf_bill_to() -> str | None:
    """Return the X-HF-Bill-To value only when a shared inference token is active.

    Returns:
        ``None`` when ``INFERENCE_TOKEN`` is unset or cleans to nothing;
        otherwise the ``HF_BILL_TO`` environment value, defaulting to
        ``"smolagents"``.
    """
    shared_token = clean_hf_token(os.environ.get("INFERENCE_TOKEN"))
    if not shared_token:
        return None
    return os.environ.get("HF_BILL_TO", "smolagents")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def bearer_token_from_header(auth_header: str | None) -> str | None:
    """Extract a cleaned bearer token from an Authorization header value.

    The scheme name is matched case-insensitively (``Bearer``, ``bearer``,
    ``BEARER``), since HTTP authentication scheme names are
    case-insensitive. The credentials are everything after the first space
    and are normalised with :func:`clean_hf_token`.

    Args:
        auth_header: Raw ``Authorization`` header value, or ``None``.

    Returns:
        The cleaned token, or ``None`` when the header is missing/empty, has
        no space-separated credentials, uses another scheme, or the
        credentials are empty after cleaning.
    """
    if not auth_header:
        return None
    scheme, separator, credentials = auth_header.partition(" ")
    if not separator or scheme.lower() != "bearer":
        return None
    return clean_hf_token(credentials)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def resolve_hf_request_token(
    request: Any,
    *,
    include_env_fallback: bool = True,
) -> str | None:
    """Resolve a user token from a FastAPI request.

    The local ``hf auth login`` cache is intentionally not consulted here.
    Backend request paths should act as the browser user taken from the
    Authorization header or cookie, or fall back only to an explicit server
    ``HF_TOKEN`` in dev/server contexts.

    Args:
        request: Object exposing ``headers`` and ``cookies`` mappings.
        include_env_fallback: Whether the ``HF_TOKEN`` environment variable
            may be used when the request carries no token.

    Returns:
        The bearer token from the ``Authorization`` header, else the
        ``hf_access_token`` cookie, else (optionally) ``HF_TOKEN``, each
        cleaned; ``None`` when none is usable.
    """
    # ``or`` short-circuits, so later sources are read only when needed.
    return (
        bearer_token_from_header(request.headers.get("Authorization", ""))
        or clean_hf_token(request.cookies.get("hf_access_token"))
        or (
            clean_hf_token(os.environ.get("HF_TOKEN"))
            if include_env_fallback
            else None
        )
    )
|
agent/core/llm_params.py
CHANGED
|
@@ -5,7 +5,12 @@ can import it without pulling in the whole agent loop / tool router and
|
|
| 5 |
creating circular imports.
|
| 6 |
"""
|
| 7 |
|
| 8 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def _patch_litellm_effort_validation() -> None:
|
|
@@ -129,7 +134,8 @@ def _resolve_llm_params(
|
|
| 129 |
1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is
|
| 130 |
free for users, billed to the Space owner via ``X-HF-Bill-To``).
|
| 131 |
2. session.hf_token — the user's own token (CLI / OAuth / cache file).
|
| 132 |
-
3.
|
|
|
|
| 133 |
"""
|
| 134 |
if model_name.startswith("anthropic/"):
|
| 135 |
params: dict = {"model": model_name}
|
|
@@ -175,18 +181,13 @@ def _resolve_llm_params(
|
|
| 175 |
return params
|
| 176 |
|
| 177 |
hf_model = model_name.removeprefix("huggingface/")
|
| 178 |
-
api_key = (
|
| 179 |
-
os.environ.get("INFERENCE_TOKEN")
|
| 180 |
-
or session_hf_token
|
| 181 |
-
or os.environ.get("HF_TOKEN")
|
| 182 |
-
)
|
| 183 |
params = {
|
| 184 |
"model": f"openai/{hf_model}",
|
| 185 |
"api_base": "https://router.huggingface.co/v1",
|
| 186 |
"api_key": api_key,
|
| 187 |
}
|
| 188 |
-
if
|
| 189 |
-
bill_to = os.environ.get("HF_BILL_TO", "smolagents")
|
| 190 |
params["extra_headers"] = {"X-HF-Bill-To": bill_to}
|
| 191 |
if reasoning_effort:
|
| 192 |
hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort
|
|
|
|
| 5 |
creating circular imports.
|
| 6 |
"""
|
| 7 |
|
| 8 |
+
from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None:
    """Backward-compatible private wrapper used by tests and older imports.

    Delegates to ``resolve_hf_router_token`` from ``agent.core.hf_tokens``.

    Args:
        session_hf_token: Token of the active session, if any.

    Returns:
        Whatever ``resolve_hf_router_token`` returns for that token.
    """
    router_token = resolve_hf_router_token(session_hf_token)
    return router_token
|
| 14 |
|
| 15 |
|
| 16 |
def _patch_litellm_effort_validation() -> None:
|
|
|
|
| 134 |
1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is
|
| 135 |
free for users, billed to the Space owner via ``X-HF-Bill-To``).
|
| 136 |
2. session.hf_token — the user's own token (CLI / OAuth / cache file).
|
| 137 |
+
3. huggingface_hub cache — ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` /
|
| 138 |
+
local ``hf auth login`` cache.
|
| 139 |
"""
|
| 140 |
if model_name.startswith("anthropic/"):
|
| 141 |
params: dict = {"model": model_name}
|
|
|
|
| 181 |
return params
|
| 182 |
|
| 183 |
hf_model = model_name.removeprefix("huggingface/")
|
| 184 |
+
api_key = _resolve_hf_router_token(session_hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
params = {
|
| 186 |
"model": f"openai/{hf_model}",
|
| 187 |
"api_base": "https://router.huggingface.co/v1",
|
| 188 |
"api_key": api_key,
|
| 189 |
}
|
| 190 |
+
if bill_to := get_hf_bill_to():
|
|
|
|
| 191 |
params["extra_headers"] = {"X-HF-Bill-To": bill_to}
|
| 192 |
if reasoning_effort:
|
| 193 |
hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort
|
agent/core/session.py
CHANGED
|
@@ -12,10 +12,13 @@ from typing import Any, Optional
|
|
| 12 |
|
| 13 |
from agent.config import Config
|
| 14 |
from agent.context_manager.manager import ContextManager
|
|
|
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
_DEFAULT_MAX_TOKENS = 200_000
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def _get_max_tokens_safe(model_name: str) -> int:
|
|
@@ -62,6 +65,7 @@ class OpType(Enum):
|
|
| 62 |
class Event:
|
| 63 |
event_type: str
|
| 64 |
data: Optional[dict[str, Any]] = None
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
class Session:
|
|
@@ -73,16 +77,26 @@ class Session:
|
|
| 73 |
def __init__(
|
| 74 |
self,
|
| 75 |
event_queue: asyncio.Queue,
|
| 76 |
-
config: Config
|
| 77 |
tool_router=None,
|
| 78 |
context_manager: ContextManager | None = None,
|
| 79 |
hf_token: str | None = None,
|
| 80 |
local_mode: bool = False,
|
| 81 |
stream: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
):
|
| 83 |
self.hf_token: Optional[str] = hf_token
|
|
|
|
|
|
|
| 84 |
self.tool_router = tool_router
|
| 85 |
self.stream = stream
|
|
|
|
|
|
|
| 86 |
tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
|
| 87 |
self.context_manager = context_manager or ContextManager(
|
| 88 |
model_max_tokens=_get_max_tokens_safe(config.model_name),
|
|
@@ -93,15 +107,16 @@ class Session:
|
|
| 93 |
local_mode=local_mode,
|
| 94 |
)
|
| 95 |
self.event_queue = event_queue
|
| 96 |
-
self.session_id = str(uuid.uuid4())
|
| 97 |
-
self.config = config
|
| 98 |
-
model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
| 99 |
-
)
|
| 100 |
self.is_running = True
|
| 101 |
self._cancelled = asyncio.Event()
|
| 102 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 103 |
self.sandbox = None
|
| 104 |
self._running_job_ids: set[str] = set() # HF job IDs currently executing
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Session trajectory logging
|
| 107 |
self.logged_events: list[dict] = []
|
|
@@ -123,11 +138,10 @@ class Session:
|
|
| 123 |
# thinking params at all
|
| 124 |
# Key absent → not probed yet; fall back to the raw preference.
|
| 125 |
self.model_effective_effort: dict[str, str | None] = {}
|
|
|
|
| 126 |
|
| 127 |
async def send_event(self, event: Event) -> None:
|
| 128 |
"""Send event back to client and log to trajectory"""
|
| 129 |
-
await self.event_queue.put(event)
|
| 130 |
-
|
| 131 |
# Log event to trajectory
|
| 132 |
self.logged_events.append(
|
| 133 |
{
|
|
@@ -136,11 +150,149 @@ class Session:
|
|
| 136 |
"data": event.data,
|
| 137 |
}
|
| 138 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
# Mid-turn heartbeat flush (owned by telemetry module).
|
| 141 |
from agent.core.telemetry import HeartbeatSaver
|
|
|
|
| 142 |
HeartbeatSaver.maybe_fire(self)
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
def cancel(self) -> None:
|
| 145 |
"""Signal cancellation to the running agent loop."""
|
| 146 |
self._cancelled.set()
|
|
@@ -199,11 +351,21 @@ class Session:
|
|
| 199 |
tools = self.tool_router.get_tool_specs_for_llm() or []
|
| 200 |
except Exception:
|
| 201 |
tools = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
return {
|
| 203 |
"session_id": self.session_id,
|
|
|
|
| 204 |
"session_start_time": self.session_start_time,
|
| 205 |
"session_end_time": datetime.now().isoformat(),
|
| 206 |
"model_name": self.config.model_name,
|
|
|
|
| 207 |
"messages": [msg.model_dump() for msg in self.context_manager.items],
|
| 208 |
"events": self.logged_events,
|
| 209 |
"tools": tools,
|
|
|
|
| 12 |
|
| 13 |
from agent.config import Config
|
| 14 |
from agent.context_manager.manager import ContextManager
|
| 15 |
+
from agent.messaging.gateway import NotificationGateway
|
| 16 |
+
from agent.messaging.models import NotificationRequest
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
_DEFAULT_MAX_TOKENS = 200_000
|
| 21 |
+
_TURN_COMPLETE_NOTIFICATION_CHARS = 39000
|
| 22 |
|
| 23 |
|
| 24 |
def _get_max_tokens_safe(model_name: str) -> int:
|
|
|
|
| 65 |
class Event:
|
| 66 |
event_type: str
|
| 67 |
data: Optional[dict[str, Any]] = None
|
| 68 |
+
seq: Optional[int] = None
|
| 69 |
|
| 70 |
|
| 71 |
class Session:
|
|
|
|
| 77 |
def __init__(
|
| 78 |
self,
|
| 79 |
event_queue: asyncio.Queue,
|
| 80 |
+
config: Config,
|
| 81 |
tool_router=None,
|
| 82 |
context_manager: ContextManager | None = None,
|
| 83 |
hf_token: str | None = None,
|
| 84 |
local_mode: bool = False,
|
| 85 |
stream: bool = True,
|
| 86 |
+
notification_gateway: NotificationGateway | None = None,
|
| 87 |
+
notification_destinations: list[str] | None = None,
|
| 88 |
+
defer_turn_complete_notification: bool = False,
|
| 89 |
+
session_id: str | None = None,
|
| 90 |
+
user_id: str | None = None,
|
| 91 |
+
persistence_store: Any | None = None,
|
| 92 |
):
|
| 93 |
self.hf_token: Optional[str] = hf_token
|
| 94 |
+
self.user_id: Optional[str] = user_id
|
| 95 |
+
self.persistence_store = persistence_store
|
| 96 |
self.tool_router = tool_router
|
| 97 |
self.stream = stream
|
| 98 |
+
if config is None:
|
| 99 |
+
raise ValueError("Session requires a Config")
|
| 100 |
tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
|
| 101 |
self.context_manager = context_manager or ContextManager(
|
| 102 |
model_max_tokens=_get_max_tokens_safe(config.model_name),
|
|
|
|
| 107 |
local_mode=local_mode,
|
| 108 |
)
|
| 109 |
self.event_queue = event_queue
|
| 110 |
+
self.session_id = session_id or str(uuid.uuid4())
|
| 111 |
+
self.config = config
|
|
|
|
|
|
|
| 112 |
self.is_running = True
|
| 113 |
self._cancelled = asyncio.Event()
|
| 114 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 115 |
self.sandbox = None
|
| 116 |
self._running_job_ids: set[str] = set() # HF job IDs currently executing
|
| 117 |
+
self.notification_gateway = notification_gateway
|
| 118 |
+
self.notification_destinations = list(notification_destinations or [])
|
| 119 |
+
self.defer_turn_complete_notification = defer_turn_complete_notification
|
| 120 |
|
| 121 |
# Session trajectory logging
|
| 122 |
self.logged_events: list[dict] = []
|
|
|
|
| 138 |
# thinking params at all
|
| 139 |
# Key absent → not probed yet; fall back to the raw preference.
|
| 140 |
self.model_effective_effort: dict[str, str | None] = {}
|
| 141 |
+
self.context_manager.on_message_added = self._schedule_trace_message
|
| 142 |
|
| 143 |
async def send_event(self, event: Event) -> None:
|
| 144 |
"""Send event back to client and log to trajectory"""
|
|
|
|
|
|
|
| 145 |
# Log event to trajectory
|
| 146 |
self.logged_events.append(
|
| 147 |
{
|
|
|
|
| 150 |
"data": event.data,
|
| 151 |
}
|
| 152 |
)
|
| 153 |
+
if self.persistence_store is not None:
|
| 154 |
+
try:
|
| 155 |
+
event.seq = await self.persistence_store.append_event(
|
| 156 |
+
self.session_id, event.event_type, event.data
|
| 157 |
+
)
|
| 158 |
+
except Exception as e:
|
| 159 |
+
logger.debug("Event persistence failed for %s: %s", self.session_id, e)
|
| 160 |
+
|
| 161 |
+
await self.event_queue.put(event)
|
| 162 |
+
await self._enqueue_auto_notification_requests(event)
|
| 163 |
|
| 164 |
# Mid-turn heartbeat flush (owned by telemetry module).
|
| 165 |
from agent.core.telemetry import HeartbeatSaver
|
| 166 |
+
|
| 167 |
HeartbeatSaver.maybe_fire(self)
|
| 168 |
|
| 169 |
+
def _schedule_trace_message(self, message: Any) -> None:
    """Queue a best-effort, append-only trace save for SFT/KPI export.

    Does nothing when no persistence store is configured, when the message
    cannot be dumped to a JSON-compatible dict, or when no event loop is
    running in the current thread. Otherwise the save is scheduled as a
    background task and this method returns immediately.

    Args:
        message: Object exposing ``model_dump(mode="json")``.
    """
    store = self.persistence_store
    if store is None:
        return
    try:
        payload = message.model_dump(mode="json")
    except Exception:
        # Tracing is best-effort: skip a message that cannot be serialised.
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread, so there is nothing to schedule on.
        return

    source = str(payload.get("role") or "message")
    session_id = self.session_id
    task = loop.create_task(
        store.append_trace_message(session_id, payload, source=source)
    )

    # The event loop keeps only weak references to tasks, so a task nobody
    # holds can be garbage-collected before it finishes. Keep a strong
    # reference until it is done. The set is created lazily so this method
    # does not depend on ``__init__`` having prepared it.
    pending = getattr(self, "_trace_tasks", None)
    if pending is None:
        pending = self._trace_tasks = set()
    pending.add(task)

    def _finish(done: asyncio.Task) -> None:
        """Release the finished task and report (not raise) its failure."""
        pending.discard(done)
        if done.cancelled():
            return
        error = done.exception()
        if error is not None:
            logger.debug("Trace persistence failed for %s: %s", session_id, error)

    task.add_done_callback(_finish)
|
| 187 |
+
|
| 188 |
+
def set_notification_destinations(self, destinations: list[str]) -> None:
    """Replace the session's opted-in auto-notification destinations.

    Duplicates are dropped; the first occurrence of each destination keeps
    its position.

    Args:
        destinations: Destination names, possibly containing repeats.
    """
    # dict keys are unique and keep insertion order.
    self.notification_destinations = list(dict.fromkeys(destinations))
|
| 197 |
+
|
| 198 |
+
async def send_deferred_turn_complete_notification(self, event: Event) -> None:
    """Send the notification for a ``turn_complete`` event that was deferred.

    Events of any other type are ignored. For ``turn_complete`` the event is
    handed to ``_enqueue_auto_notification_requests`` with
    ``include_deferred_turn_complete=True`` so the deferral check is bypassed.

    Args:
        event: The event to notify about.
    """
    if event.event_type == "turn_complete":
        await self._enqueue_auto_notification_requests(
            event,
            include_deferred_turn_complete=True,
        )
|
| 205 |
+
|
| 206 |
+
async def _enqueue_auto_notification_requests(
    self,
    event: Event,
    include_deferred_turn_complete: bool = False,
) -> None:
    """Enqueue automatic notifications for ``event`` on the gateway.

    Nothing is enqueued when there is no gateway, when no destinations are
    configured, when the event type is not listed in
    ``config.messaging.auto_event_types``, or when the event is a
    ``turn_complete`` that this session defers and the caller did not ask to
    include it.

    Args:
        event: The event that may trigger notifications.
        include_deferred_turn_complete: Send a ``turn_complete`` notification
            even when the session is configured to defer it.
    """
    gateway = self.notification_gateway
    if gateway is None or not self.notification_destinations:
        return
    if event.event_type not in set(self.config.messaging.auto_event_types):
        return

    held_back = (
        self.defer_turn_complete_notification
        and event.event_type == "turn_complete"
        and not include_deferred_turn_complete
    )
    if held_back:
        return

    for request in self._build_auto_notification_requests(event):
        await gateway.enqueue(request)
|
| 228 |
+
|
| 229 |
+
def _build_auto_notification_requests(
    self, event: Event
) -> list[NotificationRequest]:
    """Build one notification request per allowed destination for ``event``.

    Handles three event types: ``approval_required`` (warning, listing the
    distinct tool names found under ``data["tools"]``), ``error`` (error, with
    the first 500 characters of ``data["error"]``) and ``turn_complete``
    (success, with ``data["final_response"]`` truncated to
    ``_TURN_COMPLETE_NOTIFICATION_CHARS``). Any other event type yields no
    requests.

    Args:
        event: The event to describe.

    Returns:
        A request for every destination in ``self.notification_destinations``
        that ``config.messaging.can_auto_send`` accepts; empty when the event
        type is not handled.
    """
    kind = event.event_type
    metadata = {
        "session_id": self.session_id,
        "model": self.config.model_name,
        "event_type": kind,
    }
    payload = event.data or {}

    if kind == "approval_required":
        raw_tools = payload.get("tools", [])
        # Anything other than a list is treated as "no tools".
        tool_list = raw_tools if isinstance(raw_tools, list) else []
        stripped_names = (
            str(entry.get("tool") or "").strip()
            for entry in tool_list
            if isinstance(entry, dict)
        )
        # Drop blanks, then de-duplicate while keeping first-seen order.
        tool_names = list(dict.fromkeys(name for name in stripped_names if name))
        title = "Agent approval required"
        message = (
            f"Session {self.session_id} is waiting for approval "
            f"for {len(tool_list)} tool call(s)."
        )
        if tool_names:
            message += " Tools: " + ", ".join(tool_names)
        severity = "warning"
    elif kind == "error":
        title = "Agent error"
        error = str(payload.get("error") or "Unknown error")
        message = f"Session {self.session_id} hit an error.\n{error[:500]}"
        severity = "error"
    elif kind == "turn_complete":
        title = "Agent task complete"
        summary = str(payload.get("final_response") or "").strip()
        if not summary:
            message = f"Session {self.session_id} completed successfully."
        else:
            summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
            message = (
                f"Session {self.session_id} completed successfully.\n"
                f"{summary}"
            )
        severity = "success"
    else:
        # No message is defined for other event types.
        return []

    return [
        NotificationRequest(
            destination=destination,
            title=title,
            message=message,
            severity=severity,
            metadata=metadata,
            event_type=kind,
        )
        for destination in self.notification_destinations
        if self.config.messaging.can_auto_send(destination)
    ]
|
| 295 |
+
|
| 296 |
def cancel(self) -> None:
|
| 297 |
"""Signal cancellation to the running agent loop."""
|
| 298 |
self._cancelled.set()
|
|
|
|
| 351 |
tools = self.tool_router.get_tool_specs_for_llm() or []
|
| 352 |
except Exception:
|
| 353 |
tools = []
|
| 354 |
+
# Sum per-call cost from llm_call events so analyzers don't have to
|
| 355 |
+
# walk the events array themselves. Each `llm_call` event already
|
| 356 |
+
# carries cost_usd from `agent.core.telemetry.record_llm_call`.
|
| 357 |
+
total_cost_usd = sum(
|
| 358 |
+
float((e.get("data") or {}).get("cost_usd") or 0.0)
|
| 359 |
+
for e in self.logged_events
|
| 360 |
+
if e.get("event_type") == "llm_call"
|
| 361 |
+
)
|
| 362 |
return {
|
| 363 |
"session_id": self.session_id,
|
| 364 |
+
"user_id": self.user_id,
|
| 365 |
"session_start_time": self.session_start_time,
|
| 366 |
"session_end_time": datetime.now().isoformat(),
|
| 367 |
"model_name": self.config.model_name,
|
| 368 |
+
"total_cost_usd": total_cost_usd,
|
| 369 |
"messages": [msg.model_dump() for msg in self.context_manager.items],
|
| 370 |
"events": self.logged_events,
|
| 371 |
"tools": tools,
|
agent/core/session_persistence.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Optional durable session persistence for the hosted backend.
|
| 2 |
+
|
| 3 |
+
The public CLI must keep working without MongoDB. This module therefore
|
| 4 |
+
exposes one small async store interface and returns a no-op implementation
|
| 5 |
+
unless ``MONGODB_URI`` is configured and reachable.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from datetime import UTC, datetime
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
from bson import BSON
|
| 16 |
+
from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne
|
| 17 |
+
from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
SCHEMA_VERSION = 1
|
| 22 |
+
MAX_BSON_BYTES = 15 * 1024 * 1024
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _now() -> datetime:
|
| 26 |
+
return datetime.now(UTC)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _doc_id(session_id: str, idx: int) -> str:
|
| 30 |
+
return f"{session_id}:{idx}"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]:
    """Return ``message`` if it can be stored in Mongo, else a placeholder.

    The message is trial-encoded (wrapped as ``{"message": message}``) and its
    encoded size is compared against ``MAX_BSON_BYTES``, which is kept below
    MongoDB's 16 MB hard document limit. When the message is too large, or
    encoding raises ``InvalidDocument`` / ``OverflowError``, a small marker
    document is returned instead, so one oversized tool log does not fail the
    whole snapshot.
    """
    try:
        encoded_size = len(BSON.encode({"message": message}))
    except (InvalidDocument, OverflowError):
        # Not encodable at all: treat exactly like an oversized message.
        encoded_size = None

    if encoded_size is not None and encoded_size <= MAX_BSON_BYTES:
        return message

    marker_text = (
        "[SYSTEM: A single persisted message exceeded MongoDB's document "
        "size/encoding limit and was replaced by this marker.]"
    )
    return {
        "role": "tool",
        "content": marker_text,
        "ml_intern_persistence_error": "message_too_large_or_invalid",
    }
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class NoopSessionStore:
|
| 55 |
+
"""Async no-op store used when Mongo is not configured."""
|
| 56 |
+
|
| 57 |
+
enabled = False
|
| 58 |
+
|
| 59 |
+
async def init(self) -> None:
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
async def close(self) -> None:
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
async def upsert_session(self, **_: Any) -> None:
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
async def save_snapshot(self, **_: Any) -> None:
|
| 69 |
+
return None
|
| 70 |
+
|
| 71 |
+
async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None:
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]:
|
| 75 |
+
return []
|
| 76 |
+
|
| 77 |
+
async def soft_delete_session(self, *_: Any, **__: Any) -> None:
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
async def update_session_fields(self, *_: Any, **__: Any) -> None:
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
async def append_event(self, *_: Any, **__: Any) -> int | None:
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]:
|
| 87 |
+
return []
|
| 88 |
+
|
| 89 |
+
async def append_trace_message(self, *_: Any, **__: Any) -> int | None:
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
async def get_quota(self, *_: Any, **__: Any) -> int | None:
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
async def try_increment_quota(self, *_: Any, **__: Any) -> int | None:
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
async def refund_quota(self, *_: Any, **__: Any) -> None:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MongoSessionStore(NoopSessionStore):
    """MongoDB-backed session store.

    An instance starts disabled. ``init()`` connects, pings the server and
    creates indexes; only if all of that succeeds does the store become
    enabled. While the store is not ready, every public method behaves like
    the no-op base class (returns ``None`` or an empty list).

    Collections used: ``sessions``, ``session_messages``, ``session_events``,
    ``session_trace_messages``, ``counters`` and ``claude_quotas``.
    """

    enabled = True

    def __init__(self, uri: str, db_name: str) -> None:
        """Record connection settings without touching the network."""
        self.uri = uri
        self.db_name = db_name
        # Instance-level flag shadows the class attribute until init() succeeds.
        self.enabled = False
        self.client: AsyncMongoClient | None = None
        self.db = None

    async def init(self) -> None:
        """Connect to Mongo and enable the store, or stay disabled on failure.

        Any exception raised while connecting, pinging or creating indexes is
        logged as a warning and leaves the store disabled with no client.
        """
        try:
            self.client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000)
            self.db = self.client[self.db_name]
            await self.client.admin.command("ping")
            await self._create_indexes()
            self.enabled = True
            logger.info("Mongo session persistence enabled (db=%s)", self.db_name)
        except Exception as e:
            logger.warning("Mongo session persistence disabled: %s", e)
            self.enabled = False
            # Reset state before closing so the store is consistently disabled
            # even if the cleanup itself fails.
            client, self.client, self.db = self.client, None, None
            if client is not None:
                try:
                    await client.close()
                except Exception as close_err:
                    # A failed cleanup must not escape: init() promises to
                    # degrade to "disabled" rather than raise.
                    logger.debug(
                        "Ignoring error while closing Mongo client: %s", close_err
                    )

    async def close(self) -> None:
        """Close the client, if any, and drop the client/database handles."""
        # Detach the handles first so _ready() is False while the close is in
        # flight and stays False even if close() raises.
        client, self.client, self.db = self.client, None, None
        if client is not None:
            await client.close()

    async def _create_indexes(self) -> None:
        """Create the indexes this store's queries and uniqueness rules use."""
        if self.db is None:
            return
        await self.db.sessions.create_index(
            [("user_id", 1), ("visibility", 1), ("updated_at", -1)]
        )
        await self.db.sessions.create_index(
            [("visibility", 1), ("status", 1), ("last_active_at", -1)]
        )
        await self.db.session_messages.create_index(
            [("session_id", 1), ("idx", 1)], unique=True
        )
        await self.db.session_events.create_index(
            [("session_id", 1), ("seq", 1)], unique=True
        )
        await self.db.session_trace_messages.create_index(
            [("session_id", 1), ("seq", 1)], unique=True
        )
        await self.db.session_trace_messages.create_index([("created_at", -1)])

    def _ready(self) -> bool:
        """Return True when the store is enabled and has a database handle."""
        return bool(self.enabled and self.db is not None)

    async def upsert_session(
        self,
        *,
        session_id: str,
        user_id: str,
        model: str,
        title: str | None = None,
        surface: str = "frontend",
        created_at: datetime | None = None,
        runtime_state: str = "idle",
        status: str = "active",
        message_count: int = 0,
        turn_count: int = 0,
        pending_approval: list[dict[str, Any]] | None = None,
        claude_counted: bool = False,
        notification_destinations: list[str] | None = None,
    ) -> None:
        """Create or update the metadata document for ``session_id``.

        Identity fields (``user_id``, ``surface``, ``created_at``,
        ``schema_version``, ``visibility``) are written only when the document
        is first inserted. All other fields are overwritten on every call,
        including ``title`` when it is ``None``.
        """
        if not self._ready():
            return
        now = _now()
        await self.db.sessions.update_one(
            {"_id": session_id},
            {
                "$setOnInsert": {
                    "_id": session_id,
                    "session_id": session_id,
                    "user_id": user_id,
                    "surface": surface,
                    "created_at": created_at or now,
                    "schema_version": SCHEMA_VERSION,
                    "visibility": "live",
                },
                "$set": {
                    "title": title,
                    "model": model,
                    "status": status,
                    "runtime_state": runtime_state,
                    "updated_at": now,
                    "last_active_at": now,
                    "message_count": message_count,
                    "turn_count": turn_count,
                    "pending_approval": pending_approval or [],
                    "claude_counted": claude_counted,
                    "notification_destinations": notification_destinations or [],
                },
            },
            upsert=True,
        )

    async def save_snapshot(
        self,
        *,
        session_id: str,
        user_id: str,
        model: str,
        messages: list[dict[str, Any]],
        title: str | None = None,
        runtime_state: str = "idle",
        status: str = "active",
        turn_count: int = 0,
        pending_approval: list[dict[str, Any]] | None = None,
        claude_counted: bool = False,
        created_at: datetime | None = None,
        notification_destinations: list[str] | None = None,
    ) -> None:
        """Persist session metadata plus the full ordered message list.

        Each message is upserted under ``"<session_id>:<idx>"`` and any stored
        message whose index is at or beyond ``len(messages)`` is deleted, so
        the stored list matches ``messages`` after a successful write. A
        ``PyMongoError`` from the bulk write is logged and swallowed; errors
        from the metadata upsert are not caught here.
        """
        if not self._ready():
            return
        now = _now()
        await self.upsert_session(
            session_id=session_id,
            user_id=user_id,
            model=model,
            title=title,
            created_at=created_at,
            runtime_state=runtime_state,
            status=status,
            message_count=len(messages),
            turn_count=turn_count,
            pending_approval=pending_approval,
            claude_counted=claude_counted,
            notification_destinations=notification_destinations,
        )
        ops: list[Any] = [
            UpdateOne(
                {"_id": _doc_id(session_id, idx)},
                {
                    "$set": {
                        "session_id": session_id,
                        "idx": idx,
                        "message": _safe_message_doc(raw),
                        "updated_at": now,
                    },
                    "$setOnInsert": {"created_at": now},
                },
                upsert=True,
            )
            for idx, raw in enumerate(messages)
        ]
        # Trim the tail left over from a previously longer history. This also
        # guarantees the batch is never empty, so no emptiness check is needed.
        ops.append(DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}))
        try:
            # The upserts and the trailing delete touch disjoint idx ranges,
            # so unordered execution is safe.
            await self.db.session_messages.bulk_write(ops, ordered=False)
        except PyMongoError as e:
            logger.warning("Failed to persist session %s snapshot: %s", session_id, e)

    async def load_session(
        self, session_id: str, *, include_deleted: bool = False
    ) -> dict[str, Any] | None:
        """Load a session's metadata and messages ordered by ``idx``.

        Returns ``None`` when the store is not ready, the session does not
        exist, or it is soft-deleted and ``include_deleted`` is false.
        Otherwise returns ``{"metadata": ..., "messages": [...]}``.
        """
        if not self._ready():
            return None
        meta = await self.db.sessions.find_one({"_id": session_id})
        if not meta:
            return None
        if meta.get("visibility") == "deleted" and not include_deleted:
            return None
        cursor = self.db.session_messages.find({"session_id": session_id}).sort("idx", 1)
        messages = [row.get("message") async for row in cursor]
        return {"metadata": meta, "messages": messages}

    async def list_sessions(
        self, user_id: str, *, include_deleted: bool = False
    ) -> list[dict[str, Any]]:
        """Return session metadata documents, most recently updated first.

        Results are restricted to ``user_id`` and, unless ``include_deleted``
        is true, exclude soft-deleted sessions.
        """
        if not self._ready():
            return []
        query: dict[str, Any] = {"user_id": user_id}
        if user_id == "dev":
            # The literal user id "dev" lists every user's sessions.
            # NOTE(review): presumably a local-development identity; confirm
            # that no real account can be assigned this id.
            query = {}
        if not include_deleted:
            query["visibility"] = {"$ne": "deleted"}
        cursor = self.db.sessions.find(query).sort("updated_at", -1)
        return [row async for row in cursor]

    async def soft_delete_session(self, session_id: str) -> None:
        """Mark a session as deleted without removing any stored data."""
        if not self._ready():
            return
        await self.db.sessions.update_one(
            {"_id": session_id},
            {
                "$set": {
                    "visibility": "deleted",
                    "runtime_state": "idle",
                    "updated_at": _now(),
                }
            },
        )

    async def update_session_fields(self, session_id: str, **fields: Any) -> None:
        """Set arbitrary metadata fields and refresh ``updated_at``.

        Does nothing when no fields are given. No document is created if the
        session does not exist (the update is not an upsert).
        """
        if not self._ready() or not fields:
            return
        fields["updated_at"] = _now()
        await self.db.sessions.update_one({"_id": session_id}, {"$set": fields})

    async def _next_seq(self, counter_id: str) -> int:
        """Atomically increment and return the counter named ``counter_id``.

        The counter document is created on first use, so the first value
        returned for a new counter is 1.
        """
        doc = await self.db.counters.find_one_and_update(
            {"_id": counter_id},
            {"$inc": {"seq": 1}},
            upsert=True,
            return_document=ReturnDocument.AFTER,
        )
        return int(doc["seq"])

    async def append_event(
        self, session_id: str, event_type: str, data: dict[str, Any] | None
    ) -> int | None:
        """Append an event and return its per-session sequence number.

        Returns ``None`` when the store is not ready or the write fails with
        a ``PyMongoError`` (logged at debug level).
        """
        if not self._ready():
            return None
        try:
            seq = await self._next_seq(f"event:{session_id}")
            await self.db.session_events.insert_one(
                {
                    "_id": _doc_id(session_id, seq),
                    "session_id": session_id,
                    "seq": seq,
                    "event_type": event_type,
                    "data": data or {},
                    "created_at": _now(),
                }
            )
            return seq
        except PyMongoError as e:
            logger.debug("Failed to append event for %s: %s", session_id, e)
            return None

    async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[dict[str, Any]]:
        """Return the session's events with ``seq > after_seq``, in order.

        A falsy ``after_seq`` (including ``None``) is treated as 0.
        """
        if not self._ready():
            return []
        cursor = self.db.session_events.find(
            {"session_id": session_id, "seq": {"$gt": int(after_seq or 0)}}
        ).sort("seq", 1)
        return [row async for row in cursor]

    async def append_trace_message(
        self, session_id: str, message: dict[str, Any], source: str = "message"
    ) -> int | None:
        """Append a trace message and return its per-session sequence number.

        The message is passed through ``_safe_message_doc`` before storage.
        Returns ``None`` when the store is not ready or the write fails with
        a ``PyMongoError`` (logged at debug level).
        """
        if not self._ready():
            return None
        try:
            seq = await self._next_seq(f"trace:{session_id}")
            await self.db.session_trace_messages.insert_one(
                {
                    "_id": _doc_id(session_id, seq),
                    "session_id": session_id,
                    "seq": seq,
                    "role": message.get("role"),
                    "message": _safe_message_doc(message),
                    "source": source,
                    "created_at": _now(),
                }
            )
            return seq
        except PyMongoError as e:
            logger.debug("Failed to append trace message for %s: %s", session_id, e)
            return None

    async def get_quota(self, user_id: str, day: str) -> int | None:
        """Return the recorded count for ``user_id`` on ``day``.

        Returns 0 when nothing is recorded and ``None`` when the store is not
        ready.
        """
        if not self._ready():
            return None
        doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"})
        return int(doc.get("count", 0)) if doc else 0

    async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None:
        """Increment the daily count if it is still below ``cap``.

        Returns the new count on success, or ``None`` when the store is not
        ready or the cap has been reached (including any ``cap <= 0``).
        """
        if not self._ready():
            return None
        if cap <= 0:
            # The insert path below would otherwise grant the first request
            # of the day without ever comparing against the cap.
            return None
        key = f"{user_id}:{day}"
        now = _now()
        try:
            await self.db.claude_quotas.insert_one(
                {
                    "_id": key,
                    "user_id": user_id,
                    "day": day,
                    "count": 1,
                    "updated_at": now,
                }
            )
            return 1
        except DuplicateKeyError:
            # A document for this user/day already exists; fall through to
            # the conditional increment.
            pass
        doc = await self.db.claude_quotas.find_one_and_update(
            {"_id": key, "count": {"$lt": cap}},
            {"$inc": {"count": 1}, "$set": {"updated_at": now}},
            return_document=ReturnDocument.AFTER,
        )
        return int(doc["count"]) if doc else None

    async def refund_quota(self, user_id: str, day: str) -> None:
        """Decrement the daily count by one, never going below zero."""
        if not self._ready():
            return
        await self.db.claude_quotas.update_one(
            {"_id": f"{user_id}:{day}", "count": {"$gt": 0}},
            {"$inc": {"count": -1}, "$set": {"updated_at": _now()}},
        )
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
_store: NoopSessionStore | MongoSessionStore | None = None
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def get_session_store() -> NoopSessionStore | MongoSessionStore:
    """Return the process-wide session store, building it on first use.

    A ``MongoSessionStore`` is constructed when ``MONGODB_URI`` is set to a
    non-empty value (database name from ``MONGODB_DB``, default
    ``"ml-intern"``); otherwise a ``NoopSessionStore`` is used. This function
    only constructs the store; it does not call ``init()``.
    """
    global _store
    if _store is not None:
        return _store

    mongo_uri = os.environ.get("MONGODB_URI")
    database = os.environ.get("MONGODB_DB", "ml-intern")
    if mongo_uri:
        _store = MongoSessionStore(mongo_uri, database)
    else:
        _store = NoopSessionStore()
    return _store
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def _reset_store_for_tests(store: NoopSessionStore | MongoSessionStore | None = None) -> None:
    """Replace the cached module-level store.

    Passing ``None`` (the default) clears the cache so the next
    ``get_session_store()`` call builds a fresh store from the environment.
    """
    global _store
    _store = store
|
agent/core/session_uploader.py
CHANGED
|
@@ -90,9 +90,11 @@ def upload_session_as_file(
|
|
| 90 |
# across sessions with different tool rosters.
|
| 91 |
session_row = {
|
| 92 |
"session_id": data["session_id"],
|
|
|
|
| 93 |
"session_start_time": data["session_start_time"],
|
| 94 |
"session_end_time": data["session_end_time"],
|
| 95 |
"model_name": data["model_name"],
|
|
|
|
| 96 |
"messages": json.dumps(scrubbed_messages),
|
| 97 |
"events": json.dumps(scrubbed_events),
|
| 98 |
"tools": json.dumps(scrubbed_tools),
|
|
|
|
| 90 |
# across sessions with different tool rosters.
|
| 91 |
session_row = {
|
| 92 |
"session_id": data["session_id"],
|
| 93 |
+
"user_id": data.get("user_id"),
|
| 94 |
"session_start_time": data["session_start_time"],
|
| 95 |
"session_end_time": data["session_end_time"],
|
| 96 |
"model_name": data["model_name"],
|
| 97 |
+
"total_cost_usd": data.get("total_cost_usd"),
|
| 98 |
"messages": json.dumps(scrubbed_messages),
|
| 99 |
"events": json.dumps(scrubbed_events),
|
| 100 |
"tools": json.dumps(scrubbed_tools),
|
agent/core/tools.py
CHANGED
|
@@ -46,10 +46,12 @@ from agent.tools.hf_repo_git_tool import (
|
|
| 46 |
hf_repo_git_handler,
|
| 47 |
)
|
| 48 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
|
|
|
| 49 |
from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
|
| 50 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 51 |
from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
|
| 52 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
|
|
|
| 53 |
|
| 54 |
# NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
|
| 55 |
# from agent.tools.private_hf_repo_tools import (
|
|
@@ -310,6 +312,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
|
|
| 310 |
parameters=HF_PAPERS_TOOL_SPEC["parameters"],
|
| 311 |
handler=hf_papers_handler,
|
| 312 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
# Dataset inspection tool (unified)
|
| 314 |
ToolSpec(
|
| 315 |
name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
|
|
@@ -324,6 +332,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
|
|
| 324 |
parameters=PLAN_TOOL_SPEC["parameters"],
|
| 325 |
handler=plan_tool_handler,
|
| 326 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
ToolSpec(
|
| 328 |
name=HF_JOBS_TOOL_SPEC["name"],
|
| 329 |
description=HF_JOBS_TOOL_SPEC["description"],
|
|
|
|
| 46 |
hf_repo_git_handler,
|
| 47 |
)
|
| 48 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 49 |
+
from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler
|
| 50 |
from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
|
| 51 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 52 |
from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
|
| 53 |
from agent.tools.sandbox_tool import get_sandbox_tools
|
| 54 |
+
from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler
|
| 55 |
|
| 56 |
# NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
|
| 57 |
# from agent.tools.private_hf_repo_tools import (
|
|
|
|
| 312 |
parameters=HF_PAPERS_TOOL_SPEC["parameters"],
|
| 313 |
handler=hf_papers_handler,
|
| 314 |
),
|
| 315 |
+
ToolSpec(
|
| 316 |
+
name=WEB_SEARCH_TOOL_SPEC["name"],
|
| 317 |
+
description=WEB_SEARCH_TOOL_SPEC["description"],
|
| 318 |
+
parameters=WEB_SEARCH_TOOL_SPEC["parameters"],
|
| 319 |
+
handler=web_search_handler,
|
| 320 |
+
),
|
| 321 |
# Dataset inspection tool (unified)
|
| 322 |
ToolSpec(
|
| 323 |
name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
|
|
|
|
| 332 |
parameters=PLAN_TOOL_SPEC["parameters"],
|
| 333 |
handler=plan_tool_handler,
|
| 334 |
),
|
| 335 |
+
ToolSpec(
|
| 336 |
+
name=NOTIFY_TOOL_SPEC["name"],
|
| 337 |
+
description=NOTIFY_TOOL_SPEC["description"],
|
| 338 |
+
parameters=NOTIFY_TOOL_SPEC["parameters"],
|
| 339 |
+
handler=notify_handler,
|
| 340 |
+
),
|
| 341 |
ToolSpec(
|
| 342 |
name=HF_JOBS_TOOL_SPEC["name"],
|
| 343 |
description=HF_JOBS_TOOL_SPEC["description"],
|
agent/main.py
CHANGED
|
@@ -23,8 +23,10 @@ from prompt_toolkit import PromptSession
|
|
| 23 |
from agent.config import load_config
|
| 24 |
from agent.core.agent_loop import submission_loop
|
| 25 |
from agent.core import model_switcher
|
|
|
|
| 26 |
from agent.core.session import OpType
|
| 27 |
from agent.core.tools import ToolRouter
|
|
|
|
| 28 |
from agent.utils.reliability_checks import check_training_script_save_pattern
|
| 29 |
from agent.utils.terminal_display import (
|
| 30 |
get_console,
|
|
@@ -69,26 +71,15 @@ def _safe_get_args(arguments: dict) -> dict:
|
|
| 69 |
return args if isinstance(args, dict) else {}
|
| 70 |
|
| 71 |
|
| 72 |
-
def
|
| 73 |
-
"""
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
return token
|
| 77 |
try:
|
| 78 |
from huggingface_hub import HfApi
|
| 79 |
-
|
| 80 |
-
token = api.token
|
| 81 |
-
if token:
|
| 82 |
-
return token
|
| 83 |
except Exception:
|
| 84 |
-
|
| 85 |
-
# Fallback: read the cached token file directly
|
| 86 |
-
token_path = Path.home() / ".cache" / "huggingface" / "token"
|
| 87 |
-
if token_path.exists():
|
| 88 |
-
token = token_path.read_text().strip()
|
| 89 |
-
if token:
|
| 90 |
-
return token
|
| 91 |
-
return None
|
| 92 |
|
| 93 |
|
| 94 |
async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
|
|
@@ -342,6 +333,9 @@ async def event_listener(
|
|
| 342 |
stream_buf.discard()
|
| 343 |
print_turn_complete()
|
| 344 |
print_plan()
|
|
|
|
|
|
|
|
|
|
| 345 |
turn_complete_event.set()
|
| 346 |
elif event.event_type == "interrupted":
|
| 347 |
shimmer.stop()
|
|
@@ -758,7 +752,7 @@ async def _handle_slash_command(
|
|
| 758 |
normalized = arg.removeprefix("huggingface/")
|
| 759 |
session = session_holder[0] if session_holder else None
|
| 760 |
await model_switcher.probe_and_switch_model(
|
| 761 |
-
normalized, config, session, console,
|
| 762 |
)
|
| 763 |
return None
|
| 764 |
|
|
@@ -817,7 +811,7 @@ async def _handle_slash_command(
|
|
| 817 |
return None
|
| 818 |
|
| 819 |
|
| 820 |
-
async def main():
|
| 821 |
"""Interactive chat with the agent"""
|
| 822 |
|
| 823 |
# Clear screen
|
|
@@ -827,19 +821,16 @@ async def main():
|
|
| 827 |
prompt_session = PromptSession()
|
| 828 |
|
| 829 |
# HF token — required, prompt if missing
|
| 830 |
-
hf_token =
|
| 831 |
if not hf_token:
|
| 832 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 833 |
|
| 834 |
-
config = load_config(CLI_CONFIG_PATH)
|
|
|
|
|
|
|
| 835 |
|
| 836 |
# Resolve username for banner
|
| 837 |
-
hf_user =
|
| 838 |
-
try:
|
| 839 |
-
from huggingface_hub import HfApi
|
| 840 |
-
hf_user = HfApi(token=hf_token).whoami().get("name")
|
| 841 |
-
except Exception:
|
| 842 |
-
pass
|
| 843 |
|
| 844 |
print_banner(model=config.model_name, hf_user=hf_user)
|
| 845 |
|
|
@@ -857,6 +848,8 @@ async def main():
|
|
| 857 |
turn_complete_event.set()
|
| 858 |
ready_event = asyncio.Event()
|
| 859 |
|
|
|
|
|
|
|
| 860 |
# Create tool router with local mode
|
| 861 |
tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
|
| 862 |
|
|
@@ -871,8 +864,12 @@ async def main():
|
|
| 871 |
tool_router=tool_router,
|
| 872 |
session_holder=session_holder,
|
| 873 |
hf_token=hf_token,
|
|
|
|
| 874 |
local_mode=True,
|
| 875 |
stream=True,
|
|
|
|
|
|
|
|
|
|
| 876 |
)
|
| 877 |
)
|
| 878 |
|
|
@@ -1028,6 +1025,8 @@ async def main():
|
|
| 1028 |
agent_task.cancel()
|
| 1029 |
# Agent didn't shut down cleanly — close MCP explicitly
|
| 1030 |
await tool_router.__aexit__(None, None, None)
|
|
|
|
|
|
|
| 1031 |
|
| 1032 |
# Now safe to cancel the listener (agent is done emitting events)
|
| 1033 |
listener_task.cancel()
|
|
@@ -1047,15 +1046,18 @@ async def headless_main(
|
|
| 1047 |
logging.basicConfig(level=logging.WARNING)
|
| 1048 |
_configure_runtime_logging()
|
| 1049 |
|
| 1050 |
-
hf_token =
|
| 1051 |
if not hf_token:
|
| 1052 |
print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
|
| 1053 |
sys.exit(1)
|
| 1054 |
|
| 1055 |
print(f"HF token loaded", file=sys.stderr)
|
| 1056 |
|
| 1057 |
-
config = load_config(CLI_CONFIG_PATH)
|
| 1058 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
|
|
|
|
|
|
|
|
|
| 1059 |
|
| 1060 |
if model:
|
| 1061 |
config.model_name = model
|
|
@@ -1082,8 +1084,12 @@ async def headless_main(
|
|
| 1082 |
tool_router=tool_router,
|
| 1083 |
session_holder=session_holder,
|
| 1084 |
hf_token=hf_token,
|
|
|
|
| 1085 |
local_mode=True,
|
| 1086 |
stream=stream,
|
|
|
|
|
|
|
|
|
|
| 1087 |
)
|
| 1088 |
)
|
| 1089 |
|
|
@@ -1209,6 +1215,10 @@ async def headless_main(
|
|
| 1209 |
stream_buf.discard()
|
| 1210 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1211 |
print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1212 |
break
|
| 1213 |
|
| 1214 |
# Shutdown
|
|
@@ -1222,6 +1232,8 @@ async def headless_main(
|
|
| 1222 |
except asyncio.TimeoutError:
|
| 1223 |
agent_task.cancel()
|
| 1224 |
await tool_router.__aexit__(None, None, None)
|
|
|
|
|
|
|
| 1225 |
|
| 1226 |
|
| 1227 |
def cli():
|
|
@@ -1252,7 +1264,7 @@ def cli():
|
|
| 1252 |
max_iter = 10_000 # effectively unlimited
|
| 1253 |
asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
|
| 1254 |
else:
|
| 1255 |
-
asyncio.run(main())
|
| 1256 |
except KeyboardInterrupt:
|
| 1257 |
print("\n\nGoodbye!")
|
| 1258 |
|
|
|
|
| 23 |
from agent.config import load_config
|
| 24 |
from agent.core.agent_loop import submission_loop
|
| 25 |
from agent.core import model_switcher
|
| 26 |
+
from agent.core.hf_tokens import resolve_hf_token
|
| 27 |
from agent.core.session import OpType
|
| 28 |
from agent.core.tools import ToolRouter
|
| 29 |
+
from agent.messaging.gateway import NotificationGateway
|
| 30 |
from agent.utils.reliability_checks import check_training_script_save_pattern
|
| 31 |
from agent.utils.terminal_display import (
|
| 32 |
get_console,
|
|
|
|
| 71 |
return args if isinstance(args, dict) else {}
|
| 72 |
|
| 73 |
|
| 74 |
+
def _get_hf_user(token: str | None) -> str | None:
|
| 75 |
+
"""Resolve the HF username for a token, if available."""
|
| 76 |
+
if not token:
|
| 77 |
+
return None
|
|
|
|
| 78 |
try:
|
| 79 |
from huggingface_hub import HfApi
|
| 80 |
+
return HfApi(token=token).whoami().get("name")
|
|
|
|
|
|
|
|
|
|
| 81 |
except Exception:
|
| 82 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
|
|
|
|
| 333 |
stream_buf.discard()
|
| 334 |
print_turn_complete()
|
| 335 |
print_plan()
|
| 336 |
+
session = session_holder[0] if session_holder else None
|
| 337 |
+
if session is not None:
|
| 338 |
+
await session.send_deferred_turn_complete_notification(event)
|
| 339 |
turn_complete_event.set()
|
| 340 |
elif event.event_type == "interrupted":
|
| 341 |
shimmer.stop()
|
|
|
|
| 752 |
normalized = arg.removeprefix("huggingface/")
|
| 753 |
session = session_holder[0] if session_holder else None
|
| 754 |
await model_switcher.probe_and_switch_model(
|
| 755 |
+
normalized, config, session, console, resolve_hf_token(),
|
| 756 |
)
|
| 757 |
return None
|
| 758 |
|
|
|
|
| 811 |
return None
|
| 812 |
|
| 813 |
|
| 814 |
+
async def main(model: str | None = None):
|
| 815 |
"""Interactive chat with the agent"""
|
| 816 |
|
| 817 |
# Clear screen
|
|
|
|
| 821 |
prompt_session = PromptSession()
|
| 822 |
|
| 823 |
# HF token — required, prompt if missing
|
| 824 |
+
hf_token = resolve_hf_token()
|
| 825 |
if not hf_token:
|
| 826 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 827 |
|
| 828 |
+
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 829 |
+
if model:
|
| 830 |
+
config.model_name = model
|
| 831 |
|
| 832 |
# Resolve username for banner
|
| 833 |
+
hf_user = _get_hf_user(hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 834 |
|
| 835 |
print_banner(model=config.model_name, hf_user=hf_user)
|
| 836 |
|
|
|
|
| 848 |
turn_complete_event.set()
|
| 849 |
ready_event = asyncio.Event()
|
| 850 |
|
| 851 |
+
notification_gateway = NotificationGateway(config.messaging)
|
| 852 |
+
await notification_gateway.start()
|
| 853 |
# Create tool router with local mode
|
| 854 |
tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
|
| 855 |
|
|
|
|
| 864 |
tool_router=tool_router,
|
| 865 |
session_holder=session_holder,
|
| 866 |
hf_token=hf_token,
|
| 867 |
+
user_id=hf_user,
|
| 868 |
local_mode=True,
|
| 869 |
stream=True,
|
| 870 |
+
notification_gateway=notification_gateway,
|
| 871 |
+
notification_destinations=config.messaging.default_auto_destinations(),
|
| 872 |
+
defer_turn_complete_notification=True,
|
| 873 |
)
|
| 874 |
)
|
| 875 |
|
|
|
|
| 1025 |
agent_task.cancel()
|
| 1026 |
# Agent didn't shut down cleanly — close MCP explicitly
|
| 1027 |
await tool_router.__aexit__(None, None, None)
|
| 1028 |
+
finally:
|
| 1029 |
+
await notification_gateway.close()
|
| 1030 |
|
| 1031 |
# Now safe to cancel the listener (agent is done emitting events)
|
| 1032 |
listener_task.cancel()
|
|
|
|
| 1046 |
logging.basicConfig(level=logging.WARNING)
|
| 1047 |
_configure_runtime_logging()
|
| 1048 |
|
| 1049 |
+
hf_token = resolve_hf_token()
|
| 1050 |
if not hf_token:
|
| 1051 |
print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
|
| 1052 |
sys.exit(1)
|
| 1053 |
|
| 1054 |
print(f"HF token loaded", file=sys.stderr)
|
| 1055 |
|
| 1056 |
+
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1057 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
| 1058 |
+
notification_gateway = NotificationGateway(config.messaging)
|
| 1059 |
+
await notification_gateway.start()
|
| 1060 |
+
hf_user = _get_hf_user(hf_token)
|
| 1061 |
|
| 1062 |
if model:
|
| 1063 |
config.model_name = model
|
|
|
|
| 1084 |
tool_router=tool_router,
|
| 1085 |
session_holder=session_holder,
|
| 1086 |
hf_token=hf_token,
|
| 1087 |
+
user_id=hf_user,
|
| 1088 |
local_mode=True,
|
| 1089 |
stream=stream,
|
| 1090 |
+
notification_gateway=notification_gateway,
|
| 1091 |
+
notification_destinations=config.messaging.default_auto_destinations(),
|
| 1092 |
+
defer_turn_complete_notification=True,
|
| 1093 |
)
|
| 1094 |
)
|
| 1095 |
|
|
|
|
| 1215 |
stream_buf.discard()
|
| 1216 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1217 |
print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
|
| 1218 |
+
if event.event_type == "turn_complete":
|
| 1219 |
+
session = session_holder[0] if session_holder else None
|
| 1220 |
+
if session is not None:
|
| 1221 |
+
await session.send_deferred_turn_complete_notification(event)
|
| 1222 |
break
|
| 1223 |
|
| 1224 |
# Shutdown
|
|
|
|
| 1232 |
except asyncio.TimeoutError:
|
| 1233 |
agent_task.cancel()
|
| 1234 |
await tool_router.__aexit__(None, None, None)
|
| 1235 |
+
finally:
|
| 1236 |
+
await notification_gateway.close()
|
| 1237 |
|
| 1238 |
|
| 1239 |
def cli():
|
|
|
|
| 1264 |
max_iter = 10_000 # effectively unlimited
|
| 1265 |
asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
|
| 1266 |
else:
|
| 1267 |
+
asyncio.run(main(model=args.model))
|
| 1268 |
except KeyboardInterrupt:
|
| 1269 |
print("\n\nGoodbye!")
|
| 1270 |
|
agent/messaging/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agent.messaging.gateway import NotificationGateway
|
| 2 |
+
from agent.messaging.models import (
|
| 3 |
+
MessagingConfig,
|
| 4 |
+
NotificationRequest,
|
| 5 |
+
NotificationResult,
|
| 6 |
+
SUPPORTED_AUTO_EVENT_TYPES,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"MessagingConfig",
|
| 11 |
+
"NotificationGateway",
|
| 12 |
+
"NotificationRequest",
|
| 13 |
+
"NotificationResult",
|
| 14 |
+
"SUPPORTED_AUTO_EVENT_TYPES",
|
| 15 |
+
]
|
agent/messaging/base.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
|
| 5 |
+
from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NotificationError(Exception):
    """Raised when a notification could not be delivered; do not retry."""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RetryableNotificationError(NotificationError):
    """Raised when delivery failed transiently; the send may be retried."""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NotificationProvider(ABC):
    """Abstract interface implemented by every notification backend."""

    # Short identifier of the backend; concrete providers set this as a class
    # attribute.
    provider_name: str

    @abstractmethod
    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: DestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Deliver *request* to a single destination.

        Args:
            client: HTTP client to issue the delivery call with.
            destination_name: Name under which the destination is configured.
            destination: Settings of the destination being targeted.
            request: The notification to deliver.

        Returns:
            The outcome of the delivery.
        """
|
agent/messaging/gateway.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
|
| 5 |
+
import httpx
|
| 6 |
+
|
| 7 |
+
from agent.messaging.base import (
|
| 8 |
+
NotificationError,
|
| 9 |
+
NotificationProvider,
|
| 10 |
+
RetryableNotificationError,
|
| 11 |
+
)
|
| 12 |
+
from agent.messaging.models import (
|
| 13 |
+
MessagingConfig,
|
| 14 |
+
NotificationRequest,
|
| 15 |
+
NotificationResult,
|
| 16 |
+
)
|
| 17 |
+
from agent.messaging.slack import SlackProvider
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
_RETRY_DELAYS = (1, 2, 4)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class NotificationGateway:
    """Deliver notification requests to the destinations of a MessagingConfig.

    ``send`` delivers a single request immediately (retrying transient
    failures), while ``enqueue`` hands a request to a background worker that
    ``start`` creates. ``close`` drains the queue and then releases the worker
    task and the shared HTTP client.
    """

    def __init__(self, config: MessagingConfig):
        self.config = config
        # Provider implementations keyed by a destination's ``provider`` value.
        self._providers: dict[str, NotificationProvider] = {
            "slack": SlackProvider(),
        }
        self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue()
        self._worker_task: asyncio.Task | None = None
        self._client: httpx.AsyncClient | None = None

    @property
    def enabled(self) -> bool:
        """Whether messaging is switched on in the configuration."""
        return self.config.enabled

    async def start(self) -> None:
        """Create the shared HTTP client and the background worker.

        Does nothing when messaging is disabled or a worker already exists.
        """
        if not self.enabled or self._worker_task is not None:
            return
        self._client = httpx.AsyncClient(timeout=10.0)
        self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway")

    async def flush(self) -> None:
        """Wait until every queued request has been processed by the worker."""
        if not self.enabled:
            return
        await self._queue.join()

    async def close(self) -> None:
        """Drain the queue, then stop the worker and close the HTTP client.

        The teardown runs in a ``finally`` block so the worker task and the
        client are released even if draining the queue is cancelled or raises
        (for example when shutdown itself is interrupted).
        """
        if not self.enabled:
            return
        try:
            await self.flush()
        finally:
            if self._worker_task is not None:
                self._worker_task.cancel()
                try:
                    await self._worker_task
                except asyncio.CancelledError:
                    # Expected: we cancelled the worker ourselves just above.
                    pass
                self._worker_task = None
            if self._client is not None:
                await self._client.aclose()
                self._client = None

    async def send(self, request: NotificationRequest) -> NotificationResult:
        """Deliver *request* now and report the outcome.

        Configuration problems (messaging disabled, unknown destination, no
        provider implementation) are reported as a failed
        ``NotificationResult`` rather than raised.
        """
        if not self.enabled:
            return NotificationResult(
                destination=request.destination,
                ok=False,
                provider="disabled",
                error="Messaging is disabled",
            )

        destination = self.config.get_destination(request.destination)
        if destination is None:
            return NotificationResult(
                destination=request.destination,
                ok=False,
                provider="unknown",
                error=f"Unknown destination '{request.destination}'",
            )

        provider = self._providers.get(destination.provider)
        if provider is None:
            return NotificationResult(
                destination=request.destination,
                ok=False,
                provider=destination.provider,
                error=f"No provider implementation for '{destination.provider}'",
            )
        return await self._send_with_retries(provider, request.destination, destination, request)

    async def send_many(
        self, requests: Iterable[NotificationRequest]
    ) -> list[NotificationResult]:
        """Deliver *requests* one after another, returning results in order."""
        results: list[NotificationResult] = []
        for request in requests:
            results.append(await self.send(request))
        return results

    async def enqueue(self, request: NotificationRequest) -> bool:
        """Queue *request* for the background worker.

        Returns:
            ``True`` if the request was queued, ``False`` when messaging is
            disabled or no worker has been started.
        """
        if not self.enabled or self._worker_task is None:
            return False
        await self._queue.put(request)
        return True

    async def _worker(self) -> None:
        """Process queued requests forever; runs until the task is cancelled."""
        while True:
            request = await self._queue.get()
            try:
                result = await self.send(request)
                if not result.ok:
                    logger.warning(
                        "Notification delivery failed for %s: %s",
                        request.destination,
                        result.error,
                    )
            except Exception:
                # One bad request must not kill the worker loop.
                logger.exception("Unexpected notification worker failure")
            finally:
                # Always balance the get() so flush()/join() can complete.
                self._queue.task_done()

    async def _send_with_retries(
        self,
        provider: NotificationProvider,
        destination_name: str,
        destination,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Call ``provider.send``, retrying transient failures.

        A ``RetryableNotificationError`` is retried after each delay in
        ``_RETRY_DELAYS``; once those are used up, or on any other
        ``NotificationError``, a failed result carrying the error text is
        returned instead of raising.
        """
        # Reuse the long-lived client when start() created one; otherwise use
        # a throwaway client that is closed before returning.
        client = self._client or httpx.AsyncClient(timeout=10.0)
        owns_client = self._client is None
        try:
            for attempt in range(len(_RETRY_DELAYS) + 1):
                try:
                    return await provider.send(client, destination_name, destination, request)
                except RetryableNotificationError as exc:
                    if attempt >= len(_RETRY_DELAYS):
                        return NotificationResult(
                            destination=destination_name,
                            ok=False,
                            provider=provider.provider_name,
                            error=str(exc),
                        )
                    delay = _RETRY_DELAYS[attempt]
                    logger.warning(
                        "Retrying notification to %s in %ss after transient error: %s",
                        destination_name,
                        delay,
                        exc,
                    )
                    await asyncio.sleep(delay)
                except NotificationError as exc:
                    return NotificationResult(
                        destination=destination_name,
                        ok=False,
                        provider=provider.provider_name,
                        error=str(exc),
                    )
            # Defensive fallback: every loop iteration above returns or retries.
            return NotificationResult(
                destination=destination_name,
                ok=False,
                provider=provider.provider_name,
                error="Notification delivery exhausted retries",
            )
        finally:
            if owns_client:
                await client.aclose()
|
agent/messaging/models.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated, Literal
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
| 4 |
+
|
| 5 |
+
_DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-")
|
| 6 |
+
SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SlackDestinationConfig(BaseModel):
    """Settings for a single Slack destination."""

    provider: Literal["slack"] = "slack"
    token: str
    channel: str
    # Permission switches, both off by default.
    allow_agent_tool: bool = False
    allow_auto_events: bool = False
    # Optional overrides for how the posted message is attributed.
    username: str | None = None
    icon_emoji: str | None = None

    @field_validator("token", "channel")
    @classmethod
    def _require_non_empty(cls, value: str) -> str:
        """Strip surrounding whitespace and reject values that end up empty."""
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("must not be empty")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MessagingConfig(BaseModel):
    """Messaging settings: feature switch, auto-event types and destinations."""

    enabled: bool = False
    auto_event_types: list[str] = Field(
        default_factory=lambda: ["approval_required", "error", "turn_complete"]
    )
    # Destination settings keyed by destination name.
    destinations: dict[str, DestinationConfig] = Field(default_factory=dict)

    @field_validator("destinations")
    @classmethod
    def _validate_destination_names(
        cls, destinations: dict[str, DestinationConfig]
    ) -> dict[str, DestinationConfig]:
        """Reject empty names and names using characters outside the allowed set."""
        for name in destinations:
            if name and set(name) <= _DESTINATION_NAME_CHARS:
                continue
            raise ValueError(
                "destination names must use lowercase letters, digits, '.', '_' or '-'"
            )
        return destinations

    @field_validator("auto_event_types")
    @classmethod
    def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]:
        """Reject unsupported event types and drop duplicates, keeping first-seen order."""
        # A dict preserves insertion order, which gives an ordered de-duplication.
        unique: dict[str, None] = {}
        for event_type in event_types:
            if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
                raise ValueError(
                    f"unsupported auto event type '{event_type}'"
                )
            unique.setdefault(event_type, None)
        return list(unique)

    @model_validator(mode="after")
    def _require_destinations_when_enabled(self) -> "MessagingConfig":
        """Enabling messaging without any destination is a configuration error."""
        if self.enabled and not self.destinations:
            raise ValueError("messaging.enabled requires at least one destination")
        return self

    def get_destination(self, name: str) -> DestinationConfig | None:
        """Return the destination configured under *name*, or ``None``."""
        return self.destinations.get(name)

    def can_agent_tool_send(self, name: str) -> bool:
        """Return whether destination *name* exists and has ``allow_agent_tool`` set."""
        target = self.get_destination(name)
        return target is not None and target.allow_agent_tool

    def can_auto_send(self, name: str) -> bool:
        """Return whether destination *name* exists and has ``allow_auto_events`` set."""
        target = self.get_destination(name)
        return target is not None and target.allow_auto_events

    def default_auto_destinations(self) -> list[str]:
        """List the destinations that accept automatic events (none when disabled)."""
        if not self.enabled:
            return []
        return [
            name
            for name, target in self.destinations.items()
            if target.allow_auto_events
        ]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class NotificationRequest(BaseModel):
    """One notification to deliver to a named destination."""

    destination: str
    title: str | None = None
    message: str
    severity: Literal["info", "success", "warning", "error"] = "info"
    # Extra key/value pairs attached to the notification.
    metadata: dict[str, str] = Field(default_factory=dict)
    event_type: str | None = None

    @field_validator("destination", "message")
    @classmethod
    def _require_text(cls, value: str) -> str:
        """Strip surrounding whitespace and reject values that end up empty."""
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("must not be empty")

    @field_validator("title")
    @classmethod
    def _normalize_title(cls, value: str | None) -> str | None:
        """Strip the title; a missing or blank title becomes ``None``."""
        stripped = "" if value is None else value.strip()
        return stripped or None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class NotificationResult(BaseModel):
    """Outcome of attempting to deliver one notification."""

    # Name of the destination the request was addressed to.
    destination: str
    # True when the delivery succeeded.
    ok: bool
    # Name of the provider that produced this result.
    provider: str
    # Failure description; None when nothing went wrong.
    error: str | None = None
    # Provider-side identifier of the delivered message, when available.
    external_id: str | None = None
|
agent/messaging/slack.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import httpx
|
| 5 |
+
|
| 6 |
+
from agent.messaging.base import (
|
| 7 |
+
NotificationError,
|
| 8 |
+
NotificationProvider,
|
| 9 |
+
RetryableNotificationError,
|
| 10 |
+
)
|
| 11 |
+
from agent.messaging.models import (
|
| 12 |
+
NotificationRequest,
|
| 13 |
+
NotificationResult,
|
| 14 |
+
SlackDestinationConfig,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
_SEVERITY_PREFIX = {
|
| 18 |
+
"info": "[INFO]",
|
| 19 |
+
"success": "[SUCCESS]",
|
| 20 |
+
"warning": "[WARNING]",
|
| 21 |
+
"error": "[ERROR]",
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _format_slack_mrkdwn(content: str) -> str:
|
| 26 |
+
"""Convert common Markdown constructs to Slack's mrkdwn syntax."""
|
| 27 |
+
if not content:
|
| 28 |
+
return content
|
| 29 |
+
|
| 30 |
+
placeholders: dict[str, str] = {}
|
| 31 |
+
placeholder_index = 0
|
| 32 |
+
|
| 33 |
+
def placeholder(value: str) -> str:
|
| 34 |
+
nonlocal placeholder_index
|
| 35 |
+
key = f"\x00SLACK{placeholder_index}\x00"
|
| 36 |
+
placeholder_index += 1
|
| 37 |
+
placeholders[key] = value
|
| 38 |
+
return key
|
| 39 |
+
|
| 40 |
+
text = content
|
| 41 |
+
|
| 42 |
+
# Protect code before any formatting conversion. Slack's mrkdwn ignores
|
| 43 |
+
# formatting inside backticks, so these regions should stay byte-for-byte.
|
| 44 |
+
text = re.sub(
|
| 45 |
+
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
| 46 |
+
lambda match: placeholder(match.group(0)),
|
| 47 |
+
text,
|
| 48 |
+
)
|
| 49 |
+
text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text)
|
| 50 |
+
|
| 51 |
+
def convert_markdown_link(match: re.Match[str]) -> str:
|
| 52 |
+
label = match.group(1)
|
| 53 |
+
url = match.group(2).strip()
|
| 54 |
+
if url.startswith("<") and url.endswith(">"):
|
| 55 |
+
url = url[1:-1].strip()
|
| 56 |
+
return placeholder(f"<{url}|{label}>")
|
| 57 |
+
|
| 58 |
+
text = re.sub(
|
| 59 |
+
r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)",
|
| 60 |
+
convert_markdown_link,
|
| 61 |
+
text,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Preserve existing Slack entities and manual mrkdwn links before escaping.
|
| 65 |
+
text = re.sub(
|
| 66 |
+
r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)",
|
| 67 |
+
lambda match: placeholder(match.group(1)),
|
| 68 |
+
text,
|
| 69 |
+
)
|
| 70 |
+
text = re.sub(
|
| 71 |
+
r"^(>+\s)",
|
| 72 |
+
lambda match: placeholder(match.group(0)),
|
| 73 |
+
text,
|
| 74 |
+
flags=re.MULTILINE,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 78 |
+
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 79 |
+
|
| 80 |
+
def convert_header(match: re.Match[str]) -> str:
|
| 81 |
+
header = match.group(1).strip()
|
| 82 |
+
header = re.sub(r"\*\*(.+?)\*\*", r"\1", header)
|
| 83 |
+
return placeholder(f"*{header}*")
|
| 84 |
+
|
| 85 |
+
text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE)
|
| 86 |
+
text = re.sub(
|
| 87 |
+
r"\*\*\*(.+?)\*\*\*",
|
| 88 |
+
lambda match: placeholder(f"*_{match.group(1)}_*"),
|
| 89 |
+
text,
|
| 90 |
+
)
|
| 91 |
+
text = re.sub(
|
| 92 |
+
r"\*\*(.+?)\*\*",
|
| 93 |
+
lambda match: placeholder(f"*{match.group(1)}*"),
|
| 94 |
+
text,
|
| 95 |
+
)
|
| 96 |
+
text = re.sub(
|
| 97 |
+
r"(?<!\*)\*([^*\n]+)\*(?!\*)",
|
| 98 |
+
lambda match: placeholder(f"_{match.group(1)}_"),
|
| 99 |
+
text,
|
| 100 |
+
)
|
| 101 |
+
text = re.sub(
|
| 102 |
+
r"~~(.+?)~~",
|
| 103 |
+
lambda match: placeholder(f"~{match.group(1)}~"),
|
| 104 |
+
text,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
for key in reversed(placeholders):
|
| 108 |
+
text = text.replace(key, placeholders[key])
|
| 109 |
+
|
| 110 |
+
return text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _format_text(request: NotificationRequest) -> str:
    """Build the Slack message text for *request*.

    The first line is the severity prefix followed by the title when one is
    set; then comes the message, then one ``key: value`` line per metadata
    entry. The joined text is converted to Slack mrkdwn.
    """
    prefix = _SEVERITY_PREFIX[request.severity]
    heading = f"{prefix} {request.title}" if request.title else prefix
    metadata_lines = [f"{key}: {value}" for key, value in request.metadata.items()]
    return _format_slack_mrkdwn("\n".join([heading, request.message, *metadata_lines]))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class SlackProvider(NotificationProvider):
    """Notification provider that posts through Slack's ``chat.postMessage``."""

    provider_name = "slack"

    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: SlackDestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Post *request* to the destination's Slack channel.

        Raises:
            RetryableNotificationError: on timeouts, transport errors, HTTP
                429 or 5xx, an unparsable JSON body, or a ``ratelimited``
                API error.
            NotificationError: on any other HTTP 4xx or Slack API error.
        """
        body = {
            "channel": destination.channel,
            "text": _format_text(request),
            "mrkdwn": True,
            "unfurl_links": False,
            "unfurl_media": False,
        }
        # Attribution overrides are only sent when configured.
        for optional_key in ("username", "icon_emoji"):
            optional_value = getattr(destination, optional_key)
            if optional_value:
                body[optional_key] = optional_value

        request_headers = {
            "Authorization": f"Bearer {destination.token}",
            "Content-Type": "application/json; charset=utf-8",
        }
        try:
            response = await client.post(
                "https://slack.com/api/chat.postMessage",
                headers=request_headers,
                content=json.dumps(body),
            )
        except httpx.TimeoutException as exc:
            raise RetryableNotificationError("Slack request timed out") from exc
        except httpx.TransportError as exc:
            raise RetryableNotificationError("Slack transport error") from exc

        status = response.status_code
        if status == 429 or status >= 500:
            raise RetryableNotificationError(
                f"Slack HTTP {status}"
            )
        if status >= 400:
            raise NotificationError(f"Slack HTTP {status}")

        try:
            data = response.json()
        except ValueError as exc:
            raise RetryableNotificationError("Slack returned invalid JSON") from exc

        # A successful HTTP status can still carry an API-level failure,
        # signalled by a falsy "ok" field.
        if data.get("ok"):
            return NotificationResult(
                destination=destination_name,
                ok=True,
                provider=self.provider_name,
                external_id=str(data.get("ts") or ""),
                error=None,
            )

        error = str(data.get("error") or "unknown_error")
        if error == "ratelimited":
            raise RetryableNotificationError(error)
        raise NotificationError(error)
|
agent/prompts/system_prompt_v3.yaml
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
system_prompt: |
|
| 2 |
-
You are
|
| 3 |
|
| 4 |
Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation.
|
| 5 |
|
|
@@ -28,7 +28,7 @@ system_prompt: |
|
|
| 28 |
|
| 29 |
# Mistakes you WILL make without research
|
| 30 |
|
| 31 |
-
HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio
|
| 32 |
|
| 33 |
WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.
|
| 34 |
|
|
@@ -60,6 +60,38 @@ system_prompt: |
|
|
| 60 |
DPO: "prompt", "chosen", "rejected"
|
| 61 |
GRPO: "prompt"
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# Data audit
|
| 64 |
|
| 65 |
Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it.
|
|
@@ -75,7 +107,7 @@ system_prompt: |
|
|
| 75 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
| 76 |
- push_to_hub=True and hub_model_id set
|
| 77 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 78 |
-
- Trackio monitoring included and
|
| 79 |
|
| 80 |
If you cannot fill in all items, stop and complete the missing steps first.
|
| 81 |
|
|
@@ -156,6 +188,7 @@ system_prompt: |
|
|
| 156 |
- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
|
| 157 |
- For errors: state what went wrong, why, and what you're doing to fix it.
|
| 158 |
- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
|
|
|
|
| 159 |
|
| 160 |
# Tool usage
|
| 161 |
|
|
|
|
| 1 |
system_prompt: |
|
| 2 |
+
You are ML Intern, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face (HF) ecosystem.
|
| 3 |
|
| 4 |
Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation.
|
| 5 |
|
|
|
|
| 28 |
|
| 29 |
# Mistakes you WILL make without research
|
| 30 |
|
| 31 |
+
HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio config field names. Fix: read a current example script first.
|
| 32 |
|
| 33 |
WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.
|
| 34 |
|
|
|
|
| 60 |
DPO: "prompt", "chosen", "rejected"
|
| 61 |
GRPO: "prompt"
|
| 62 |
|
| 63 |
+
# Trackio
|
| 64 |
+
|
| 65 |
+
Trackio is natively integrated with Transformers Trainer and all TRL trainers — the built-in TrackioCallback handles init/log/finish. In TrainingArguments/SFTConfig/DPOConfig/GRPOConfig set:
|
| 66 |
+
report_to="trackio"
|
| 67 |
+
run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
|
| 68 |
+
project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
|
| 69 |
+
trackio_space_id="<username>/mlintern-<8-char-id>" # creates a public dashboard Space
|
| 70 |
+
`project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
|
| 71 |
+
|
| 72 |
+
Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
|
| 73 |
+
ERROR — stop and change approach (divergence, NaN, OOM)
|
| 74 |
+
WARN — tweak hyperparameters (overfitting, early stopping, KL spike, reward collapse, slow convergence)
|
| 75 |
+
INFO — milestones (training complete, target reached, checkpoint saved)
|
| 76 |
+
Always include numeric values and an actionable suggestion in `text`, e.g. "loss=12.4 at step 200 — lr likely too high, try ×0.1". A future call must be able to parse it and act on it.
|
| 77 |
+
|
| 78 |
+
To add alerts under Trainer/SFTTrainer/GRPOTrainer, pass a custom TrainerCallback via `callbacks=[...]` that calls trackio.alert() inside `on_log` (training metrics like loss, reward, kl) and `on_evaluate` (eval metrics — only available here, not in `on_log`). Keep each `if` simple: one metric, one threshold. Conditions stay easy to adjust between runs.
|
| 79 |
+
|
| 80 |
+
Read alerts back between runs instead of parsing thousands of metric values. CLI — always use --json:
|
| 81 |
+
trackio get alerts --project <p> --run <r> --json
|
| 82 |
+
trackio get alerts --project <p> --since <iso8601> --json # incremental polling
|
| 83 |
+
trackio get run --project <p> --run <r> --json
|
| 84 |
+
trackio get metric --project <p> --run <r> --metric <m> --json
|
| 85 |
+
trackio list runs --project <p> --json
|
| 86 |
+
Python: api = trackio.Api(); api.alerts(<p>, run=<r>, since=<ts>); api.runs(<p>) (each run has .name, .config, .alerts()).
|
| 87 |
+
|
| 88 |
+
Drive the next config from prior alerts:
|
| 89 |
+
diverged → lr × 0.1
|
| 90 |
+
overfitting → weight_decay × 10 or reduce capacity
|
| 91 |
+
early stopping → lr × 0.5 or adjust schedule
|
| 92 |
+
high accuracy → refine around current config
|
| 93 |
+
Read prior config via api.runs(...).config and only mutate keys the alerts justify changing.
|
| 94 |
+
|
| 95 |
# Data audit
|
| 96 |
|
| 97 |
Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it.
|
|
|
|
| 107 |
- Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
|
| 108 |
- push_to_hub=True and hub_model_id set
|
| 109 |
- timeout: [value] (based on: [model size] on [hardware])
|
| 110 |
+
- Trackio monitoring included and deploying metrics to a public Space
|
| 111 |
|
| 112 |
If you cannot fill in all items, stop and complete the missing steps first.
|
| 113 |
|
|
|
|
| 188 |
- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
|
| 189 |
- For errors: state what went wrong, why, and what you're doing to fix it.
|
| 190 |
- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
|
| 191 |
+
- Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates.
|
| 192 |
|
| 193 |
# Tool usage
|
| 194 |
|
agent/tools/__init__.py
CHANGED
|
@@ -20,6 +20,7 @@ from agent.tools.github_read_file import (
|
|
| 20 |
)
|
| 21 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
|
| 22 |
from agent.tools.types import ToolResult
|
|
|
|
| 23 |
|
| 24 |
__all__ = [
|
| 25 |
"ToolResult",
|
|
@@ -36,4 +37,6 @@ __all__ = [
|
|
| 36 |
"github_search_code_handler",
|
| 37 |
"HF_INSPECT_DATASET_TOOL_SPEC",
|
| 38 |
"hf_inspect_dataset_handler",
|
|
|
|
|
|
|
| 39 |
]
|
|
|
|
| 20 |
)
|
| 21 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
|
| 22 |
from agent.tools.types import ToolResult
|
| 23 |
+
from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler
|
| 24 |
|
| 25 |
__all__ = [
|
| 26 |
"ToolResult",
|
|
|
|
| 37 |
"github_search_code_handler",
|
| 38 |
"HF_INSPECT_DATASET_TOOL_SPEC",
|
| 39 |
"hf_inspect_dataset_handler",
|
| 40 |
+
"WEB_SEARCH_TOOL_SPEC",
|
| 41 |
+
"web_search_handler",
|
| 42 |
]
|
agent/tools/jobs_tool.py
CHANGED
|
@@ -19,6 +19,7 @@ from huggingface_hub.utils import HfHubHTTPError
|
|
| 19 |
|
| 20 |
from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
|
| 21 |
from agent.core.session import Event
|
|
|
|
| 22 |
from agent.tools.types import ToolResult
|
| 23 |
|
| 24 |
logger = logging.getLogger(__name__)
|
|
@@ -382,6 +383,31 @@ class HfJobsTool:
|
|
| 382 |
"isError": True,
|
| 383 |
}
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
async def _wait_for_job_completion(
|
| 386 |
self, job_id: str, namespace: Optional[str] = None
|
| 387 |
) -> tuple[str, list[str]]:
|
|
@@ -533,11 +559,24 @@ class HfJobsTool:
|
|
| 533 |
# Run the job
|
| 534 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 535 |
timeout_str = args.get("timeout", "30m")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
job = await _async_call(
|
| 537 |
self.api.run_job,
|
| 538 |
image=image,
|
| 539 |
command=command,
|
| 540 |
-
env=
|
| 541 |
secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
|
| 542 |
flavor=flavor,
|
| 543 |
timeout=timeout_str,
|
|
@@ -550,16 +589,18 @@ class HfJobsTool:
|
|
| 550 |
|
| 551 |
# Send job URL immediately after job creation (before waiting for completion)
|
| 552 |
if self.session and self.tool_call_id:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
await self.session.send_event(
|
| 554 |
-
Event(
|
| 555 |
-
event_type="tool_state_change",
|
| 556 |
-
data={
|
| 557 |
-
"tool_call_id": self.tool_call_id,
|
| 558 |
-
"tool": "hf_jobs",
|
| 559 |
-
"state": "running",
|
| 560 |
-
"jobUrl": job.url,
|
| 561 |
-
},
|
| 562 |
-
)
|
| 563 |
)
|
| 564 |
|
| 565 |
# Telemetry: job submission + completion (infra consumption signal).
|
|
@@ -594,16 +635,18 @@ class HfJobsTool:
|
|
| 594 |
|
| 595 |
# Notify frontend of final status
|
| 596 |
if self.session and self.tool_call_id:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
await self.session.send_event(
|
| 598 |
-
Event(
|
| 599 |
-
event_type="tool_state_change",
|
| 600 |
-
data={
|
| 601 |
-
"tool_call_id": self.tool_call_id,
|
| 602 |
-
"tool": "hf_jobs",
|
| 603 |
-
"state": final_status.lower(),
|
| 604 |
-
"jobUrl": job.url,
|
| 605 |
-
},
|
| 606 |
-
)
|
| 607 |
)
|
| 608 |
|
| 609 |
# Filter out UV package installation output
|
|
@@ -977,7 +1020,10 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 977 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
| 978 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 979 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 980 |
-
"- Include trackio monitoring and provide the dashboard URL to the user.
|
|
|
|
|
|
|
|
|
|
| 981 |
"BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
|
| 982 |
"Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
|
| 983 |
"Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n"
|
|
@@ -1060,6 +1106,26 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1060 |
"type": "object",
|
| 1061 |
"description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
|
| 1062 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1063 |
"namespace": {
|
| 1064 |
"type": "string",
|
| 1065 |
"description": (
|
|
|
|
| 19 |
|
| 20 |
from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
|
| 21 |
from agent.core.session import Event
|
| 22 |
+
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 23 |
from agent.tools.types import ToolResult
|
| 24 |
|
| 25 |
logger = logging.getLogger(__name__)
|
|
|
|
| 383 |
"isError": True,
|
| 384 |
}
|
| 385 |
|
| 386 |
+
async def _seed_trackio_dashboard(self, space_id: str) -> None:
    """Best-effort seeding of the trackio dashboard files into *space_id*.

    Runs ``ensure_trackio_dashboard`` in a worker thread and forwards the
    messages it reports to the session as ``tool_log`` events. Any exception
    is logged and surfaced as a ``tool_log`` event instead of being raised,
    so a failed seed never blocks job submission.
    """
    running_loop = asyncio.get_running_loop()

    def _emit(text: str) -> None:
        # The callback is handed to the worker thread, so the event is
        # scheduled onto the loop rather than put on the queue directly.
        if self.session is not None:
            running_loop.call_soon_threadsafe(
                self.session.event_queue.put_nowait,
                Event(event_type="tool_log", data={"tool": "hf_jobs", "log": text}),
            )

    try:
        await asyncio.to_thread(
            ensure_trackio_dashboard, space_id, self.hf_token, _emit
        )
    except Exception as exc:
        logger.warning(f"trackio dashboard seed failed for {space_id}: {exc}")
        _emit(f"trackio dashboard seed failed: {exc}")
|
| 410 |
+
|
| 411 |
async def _wait_for_job_completion(
|
| 412 |
self, job_id: str, namespace: Optional[str] = None
|
| 413 |
) -> tuple[str, list[str]]:
|
|
|
|
| 559 |
# Run the job
|
| 560 |
flavor = args.get("hardware_flavor", "cpu-basic")
|
| 561 |
timeout_str = args.get("timeout", "30m")
|
| 562 |
+
|
| 563 |
+
# Trackio: agent-declared space + project become env vars on the job
|
| 564 |
+
# so trackio.init() picks them up automatically. We also surface them
|
| 565 |
+
# in tool_state_change so the frontend can embed the dashboard.
|
| 566 |
+
env_dict = _add_default_env(args.get("env"))
|
| 567 |
+
trackio_space_id = args.get("trackio_space_id")
|
| 568 |
+
trackio_project = args.get("trackio_project")
|
| 569 |
+
if trackio_space_id:
|
| 570 |
+
env_dict["TRACKIO_SPACE_ID"] = trackio_space_id
|
| 571 |
+
await self._seed_trackio_dashboard(trackio_space_id)
|
| 572 |
+
if trackio_project:
|
| 573 |
+
env_dict["TRACKIO_PROJECT"] = trackio_project
|
| 574 |
+
|
| 575 |
job = await _async_call(
|
| 576 |
self.api.run_job,
|
| 577 |
image=image,
|
| 578 |
command=command,
|
| 579 |
+
env=env_dict,
|
| 580 |
secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
|
| 581 |
flavor=flavor,
|
| 582 |
timeout=timeout_str,
|
|
|
|
| 589 |
|
| 590 |
# Send job URL immediately after job creation (before waiting for completion)
|
| 591 |
if self.session and self.tool_call_id:
|
| 592 |
+
state_data: Dict[str, Any] = {
|
| 593 |
+
"tool_call_id": self.tool_call_id,
|
| 594 |
+
"tool": "hf_jobs",
|
| 595 |
+
"state": "running",
|
| 596 |
+
"jobUrl": job.url,
|
| 597 |
+
}
|
| 598 |
+
if trackio_space_id:
|
| 599 |
+
state_data["trackioSpaceId"] = trackio_space_id
|
| 600 |
+
if trackio_project:
|
| 601 |
+
state_data["trackioProject"] = trackio_project
|
| 602 |
await self.session.send_event(
|
| 603 |
+
Event(event_type="tool_state_change", data=state_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 604 |
)
|
| 605 |
|
| 606 |
# Telemetry: job submission + completion (infra consumption signal).
|
|
|
|
| 635 |
|
| 636 |
# Notify frontend of final status
|
| 637 |
if self.session and self.tool_call_id:
|
| 638 |
+
final_data: Dict[str, Any] = {
|
| 639 |
+
"tool_call_id": self.tool_call_id,
|
| 640 |
+
"tool": "hf_jobs",
|
| 641 |
+
"state": final_status.lower(),
|
| 642 |
+
"jobUrl": job.url,
|
| 643 |
+
}
|
| 644 |
+
if trackio_space_id:
|
| 645 |
+
final_data["trackioSpaceId"] = trackio_space_id
|
| 646 |
+
if trackio_project:
|
| 647 |
+
final_data["trackioProject"] = trackio_project
|
| 648 |
await self.session.send_event(
|
| 649 |
+
Event(event_type="tool_state_change", data=final_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
)
|
| 651 |
|
| 652 |
# Filter out UV package installation output
|
|
|
|
| 1020 |
"- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
|
| 1021 |
"- Training config MUST include push_to_hub=True and hub_model_id. "
|
| 1022 |
"Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
|
| 1023 |
+
"- Include trackio monitoring and provide the dashboard URL to the user. "
|
| 1024 |
+
"When the script uses report_to='trackio', also pass `trackio_space_id` "
|
| 1025 |
+
"(e.g. '<username>/mlintern-<8char>') and `trackio_project` as tool args — "
|
| 1026 |
+
"they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
|
| 1027 |
"BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
|
| 1028 |
"Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
|
| 1029 |
"Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n"
|
|
|
|
| 1106 |
"type": "object",
|
| 1107 |
"description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
|
| 1108 |
},
|
| 1109 |
+
"trackio_space_id": {
|
| 1110 |
+
"type": "string",
|
| 1111 |
+
"description": (
|
| 1112 |
+
"Optional. The HF Space hosting the trackio dashboard for this run "
|
| 1113 |
+
"(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). "
|
| 1114 |
+
"Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
|
| 1115 |
+
"the live dashboard. Set this whenever the script uses "
|
| 1116 |
+
"report_to='trackio'. The Space is auto-created and seeded with the "
|
| 1117 |
+
"trackio dashboard before the job starts — DO NOT pre-create it via "
|
| 1118 |
+
"hf_repo_git, that produces an empty Space that breaks the embed."
|
| 1119 |
+
),
|
| 1120 |
+
},
|
| 1121 |
+
"trackio_project": {
|
| 1122 |
+
"type": "string",
|
| 1123 |
+
"description": (
|
| 1124 |
+
"Optional. The trackio project name to log this run under. "
|
| 1125 |
+
"Injected as TRACKIO_PROJECT env var and used by the UI to filter "
|
| 1126 |
+
"the embedded dashboard to this project."
|
| 1127 |
+
),
|
| 1128 |
+
},
|
| 1129 |
"namespace": {
|
| 1130 |
"type": "string",
|
| 1131 |
"description": (
|
agent/tools/notify_tool.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from agent.messaging.models import NotificationRequest
|
| 4 |
+
|
| 5 |
+
# Tool schema advertised to the model for the out-of-band ``notify`` tool.
# Only ``destinations`` and ``message`` are mandatory; ``title`` and
# ``severity`` are optional decorations.
NOTIFY_TOOL_SPEC = {
    "name": "notify",
    "description": (
        "Send an out-of-band notification to configured messaging destinations. "
        "Use this only when the user explicitly asked for proactive notifications "
        "or when the task requires reporting progress outside the chat. "
        "Destinations must be named server-side configs such as 'slack.ops'."
    ),
    "parameters": {
        "type": "object",
        "required": ["destinations", "message"],
        "properties": {
            # One or more named destinations.
            "destinations": {
                "type": "array",
                "items": {"type": "string"},
                "minItems": 1,
                "description": "Named messaging destinations to notify.",
            },
            "message": {
                "type": "string",
                "description": "Main notification body.",
            },
            "title": {
                "type": "string",
                "description": "Optional short title line.",
            },
            "severity": {
                "type": "string",
                "enum": ["info", "success", "warning", "error"],
                "description": "Notification severity label.",
            },
        },
    },
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def notify_handler(
    arguments: dict[str, Any], session=None, **_kwargs
) -> tuple[str, bool]:
    """Validate a ``notify`` tool call and dispatch it via the session gateway.

    Args:
        arguments: Raw tool arguments (``destinations``, ``message`` and the
            optional ``title`` / ``severity``).
        session: Active agent session; must expose ``notification_gateway``,
            ``config`` and ``session_id``.
        **_kwargs: Extra router keyword arguments, ignored.

    Returns:
        ``(text, ok)``. On a validation failure ``text`` is the error message
        and ``ok`` is ``False``. Otherwise ``text`` has one status line per
        destination and ``ok`` is ``True`` only if every send succeeded.
    """
    if session is None or session.notification_gateway is None:
        return "Messaging is not configured for this session.", False

    raw_destinations = arguments.get("destinations", [])
    if not isinstance(raw_destinations, list) or not raw_destinations:
        return "destinations must be a non-empty array of destination names.", False

    # Normalise names and drop duplicates while keeping first-seen order.
    destinations: list[str] = []
    seen: set[str] = set()
    for raw_name in raw_destinations:
        if not isinstance(raw_name, str):
            return "Each destination must be a string.", False
        name = raw_name.strip()
        if not name:
            return "Destination names must not be empty.", False
        if name not in seen:
            destinations.append(name)
            seen.add(name)

    # Reject the whole call if any destination is not permitted for this tool.
    disallowed = [
        name
        for name in destinations
        if not session.config.messaging.can_agent_tool_send(name)
    ]
    if disallowed:
        return (
            "These destinations are unavailable for the notify tool: "
            + ", ".join(disallowed)
        ), False

    message = arguments.get("message", "")
    if not isinstance(message, str) or not message.strip():
        return "message must be a non-empty string.", False

    title = arguments.get("title")
    severity = arguments.get("severity", "info")
    if title is not None and not isinstance(title, str):
        return "title must be a string when provided.", False
    # Check the type first: an unhashable value (e.g. a list) would otherwise
    # raise TypeError from the membership test instead of a clean error.
    allowed_severities = ("info", "success", "warning", "error")
    if not isinstance(severity, str) or severity not in allowed_severities:
        return "severity must be one of: info, success, warning, error.", False

    requests = [
        NotificationRequest(
            destination=name,
            title=title,
            message=message,
            severity=severity,
            metadata={
                "session_id": session.session_id,
                "model": session.config.model_name,
            },
        )
        for name in destinations
    ]
    results = await session.notification_gateway.send_many(requests)

    lines = []
    all_ok = True
    for result in results:
        if result.ok:
            lines.append(f"{result.destination}: sent")
        else:
            all_ok = False
            lines.append(f"{result.destination}: failed ({result.error})")
    return "\n".join(lines), all_ok
|
agent/tools/research_tool.py
CHANGED
|
@@ -37,6 +37,7 @@ RESEARCH_TOOL_NAMES = {
|
|
| 37 |
"github_find_examples",
|
| 38 |
"github_list_repos",
|
| 39 |
"github_read_file",
|
|
|
|
| 40 |
"hf_inspect_dataset",
|
| 41 |
"hf_repo_files",
|
| 42 |
}
|
|
@@ -102,6 +103,8 @@ tell you what actually works.
|
|
| 102 |
- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc.
|
| 103 |
- `fetch_hf_docs(url)`: Fetch full page content from explore results
|
| 104 |
- `find_hf_api(query=..., tag=...)`: Find REST API endpoints
|
|
|
|
|
|
|
| 105 |
|
| 106 |
## Hub repo inspection
|
| 107 |
- `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
|
|
@@ -306,8 +309,10 @@ async def research_handler(
|
|
| 306 |
# ── Doom-loop detection ──
|
| 307 |
doom_prompt = check_for_doom_loop(messages)
|
| 308 |
if doom_prompt:
|
| 309 |
-
logger.warning(
|
| 310 |
-
|
|
|
|
|
|
|
| 311 |
messages.append(Message(role="user", content=doom_prompt))
|
| 312 |
|
| 313 |
# ── Context budget: warn at 75%, hard-stop at 95% ──
|
|
@@ -424,7 +429,7 @@ async def research_handler(
|
|
| 424 |
await _log(f"▸ {tool_name} {args_str}")
|
| 425 |
|
| 426 |
output, _success = await session.tool_router.call_tool(
|
| 427 |
-
tool_name, tool_args, session=session
|
| 428 |
)
|
| 429 |
_tool_uses += 1
|
| 430 |
await _log(f"tools:{_tool_uses}")
|
|
|
|
| 37 |
"github_find_examples",
|
| 38 |
"github_list_repos",
|
| 39 |
"github_read_file",
|
| 40 |
+
"web_search",
|
| 41 |
"hf_inspect_dataset",
|
| 42 |
"hf_repo_files",
|
| 43 |
}
|
|
|
|
| 103 |
- `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc.
|
| 104 |
- `fetch_hf_docs(url)`: Fetch full page content from explore results
|
| 105 |
- `find_hf_api(query=..., tag=...)`: Find REST API endpoints
|
| 106 |
+
- `web_search(query=..., allowed_domains=[...], blocked_domains=[...])`:
|
| 107 |
+
Search the current web when papers/docs/GitHub are not enough.
|
| 108 |
|
| 109 |
## Hub repo inspection
|
| 110 |
- `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
|
|
|
|
| 309 |
# ── Doom-loop detection ──
|
| 310 |
doom_prompt = check_for_doom_loop(messages)
|
| 311 |
if doom_prompt:
|
| 312 |
+
logger.warning(
|
| 313 |
+
"Research sub-agent repetition guard activated at iteration %d",
|
| 314 |
+
_iteration,
|
| 315 |
+
)
|
| 316 |
messages.append(Message(role="user", content=doom_prompt))
|
| 317 |
|
| 318 |
# ── Context budget: warn at 75%, hard-stop at 95% ──
|
|
|
|
| 429 |
await _log(f"▸ {tool_name} {args_str}")
|
| 430 |
|
| 431 |
output, _success = await session.tool_router.call_tool(
|
| 432 |
+
tool_name, tool_args, session=session, tool_call_id=tc.id
|
| 433 |
)
|
| 434 |
_tool_uses += 1
|
| 435 |
await _log(f"tools:{_tool_uses}")
|
agent/tools/sandbox_client.py
CHANGED
|
@@ -37,6 +37,7 @@ Tools: bash, read, write, edit, upload
|
|
| 37 |
from __future__ import annotations
|
| 38 |
|
| 39 |
import io
|
|
|
|
| 40 |
import sys
|
| 41 |
import time
|
| 42 |
import uuid
|
|
@@ -99,8 +100,8 @@ CMD ["python", "sandbox_server.py"]
|
|
| 99 |
|
| 100 |
_SANDBOX_SERVER = '''\
|
| 101 |
"""Minimal FastAPI server for sandbox operations."""
|
| 102 |
-
import os, subprocess, pathlib, signal, threading, re, tempfile
|
| 103 |
-
from fastapi import FastAPI
|
| 104 |
from pydantic import BaseModel
|
| 105 |
from typing import Optional
|
| 106 |
import uvicorn
|
|
@@ -156,6 +157,22 @@ def _atomic_write(path: pathlib.Path, content: str):
|
|
| 156 |
|
| 157 |
app = FastAPI()
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
# Track active bash processes so they can be killed on cancel
|
| 160 |
_active_procs = {} # pid -> subprocess.Popen
|
| 161 |
_proc_lock = threading.Lock()
|
|
@@ -344,7 +361,7 @@ def _validate_python(content, path=""):
|
|
| 344 |
def health():
|
| 345 |
return {"status": "ok"}
|
| 346 |
|
| 347 |
-
@app.post("/api/bash")
|
| 348 |
def bash(req: BashReq):
|
| 349 |
try:
|
| 350 |
proc = subprocess.Popen(
|
|
@@ -371,7 +388,7 @@ def bash(req: BashReq):
|
|
| 371 |
except Exception as e:
|
| 372 |
return {"success": False, "output": "", "error": str(e)}
|
| 373 |
|
| 374 |
-
@app.post("/api/kill")
|
| 375 |
def kill_all():
|
| 376 |
"""Kill all active bash processes. Called when user cancels."""
|
| 377 |
with _proc_lock:
|
|
@@ -389,7 +406,7 @@ def kill_all():
|
|
| 389 |
pass
|
| 390 |
return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""}
|
| 391 |
|
| 392 |
-
@app.post("/api/read")
|
| 393 |
def read(req: ReadReq):
|
| 394 |
try:
|
| 395 |
p = pathlib.Path(req.path)
|
|
@@ -406,7 +423,7 @@ def read(req: ReadReq):
|
|
| 406 |
except Exception as e:
|
| 407 |
return {"success": False, "output": "", "error": str(e)}
|
| 408 |
|
| 409 |
-
@app.post("/api/write")
|
| 410 |
def write(req: WriteReq):
|
| 411 |
try:
|
| 412 |
p = pathlib.Path(req.path)
|
|
@@ -420,7 +437,7 @@ def write(req: WriteReq):
|
|
| 420 |
except Exception as e:
|
| 421 |
return {"success": False, "output": "", "error": str(e)}
|
| 422 |
|
| 423 |
-
@app.post("/api/edit")
|
| 424 |
def edit(req: EditReq):
|
| 425 |
try:
|
| 426 |
p = pathlib.Path(req.path)
|
|
@@ -447,7 +464,7 @@ def edit(req: EditReq):
|
|
| 447 |
except Exception as e:
|
| 448 |
return {"success": False, "output": "", "error": str(e)}
|
| 449 |
|
| 450 |
-
@app.post("/api/exists")
|
| 451 |
def exists(req: ExistsReq):
|
| 452 |
return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""}
|
| 453 |
|
|
@@ -482,6 +499,7 @@ class Sandbox:
|
|
| 482 |
|
| 483 |
space_id: str
|
| 484 |
token: str | None = None
|
|
|
|
| 485 |
work_dir: str = "/app"
|
| 486 |
timeout: int = DEFAULT_TIMEOUT
|
| 487 |
_owns_space: bool = field(default=False, repr=False)
|
|
@@ -495,9 +513,10 @@ class Sandbox:
|
|
| 495 |
# Trailing slash is critical: httpx resolves relative paths against base_url.
|
| 496 |
# Without it, client.get("health") resolves to /health instead of /api/health.
|
| 497 |
self._base_url = f"https://{slug}.hf.space/api/"
|
|
|
|
| 498 |
self._client = httpx.Client(
|
| 499 |
base_url=self._base_url,
|
| 500 |
-
headers={"Authorization": f"Bearer {
|
| 501 |
timeout=httpx.Timeout(MAX_TIMEOUT, connect=30),
|
| 502 |
follow_redirects=True,
|
| 503 |
)
|
|
@@ -563,6 +582,7 @@ class Sandbox:
|
|
| 563 |
base = name or "sandbox"
|
| 564 |
suffix = uuid.uuid4().hex[:8]
|
| 565 |
space_id = f"{owner}/{base}-{suffix}"
|
|
|
|
| 566 |
|
| 567 |
_log(f"Creating sandbox: {space_id} (from {template})...")
|
| 568 |
|
|
@@ -583,8 +603,9 @@ class Sandbox:
|
|
| 583 |
# Inject secrets BEFORE uploading server files (which triggers rebuild).
|
| 584 |
# Secrets added after a Space is running aren't available until restart,
|
| 585 |
# so they must be set before the build/start cycle.
|
| 586 |
-
|
| 587 |
-
|
|
|
|
| 588 |
api.add_space_secret(space_id, key, val)
|
| 589 |
|
| 590 |
# Upload sandbox server and Dockerfile (triggers rebuild)
|
|
@@ -617,7 +638,12 @@ class Sandbox:
|
|
| 617 |
_check_cancel()
|
| 618 |
|
| 619 |
# Wait for the API server to be responsive (non-fatal)
|
| 620 |
-
sb = cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
try:
|
| 622 |
sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log)
|
| 623 |
except TimeoutError as e:
|
|
@@ -648,13 +674,24 @@ class Sandbox:
|
|
| 648 |
log("Server files uploaded, rebuild triggered.")
|
| 649 |
|
| 650 |
@classmethod
|
| 651 |
-
def connect(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
"""
|
| 653 |
Connect to an existing running Space.
|
| 654 |
|
| 655 |
Does a health check to verify the Space is reachable.
|
| 656 |
"""
|
| 657 |
-
sb = cls(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 658 |
sb._wait_for_api(timeout=60)
|
| 659 |
return sb
|
| 660 |
|
|
@@ -687,6 +724,10 @@ class Sandbox:
|
|
| 687 |
)
|
| 688 |
print(f"Deleting sandbox: {self.space_id}...")
|
| 689 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
self._client.close()
|
| 691 |
print("Deleted.")
|
| 692 |
|
|
|
|
| 37 |
from __future__ import annotations
|
| 38 |
|
| 39 |
import io
|
| 40 |
+
import secrets as secrets_lib
|
| 41 |
import sys
|
| 42 |
import time
|
| 43 |
import uuid
|
|
|
|
| 100 |
|
| 101 |
_SANDBOX_SERVER = '''\
|
| 102 |
"""Minimal FastAPI server for sandbox operations."""
|
| 103 |
+
import hmac, os, subprocess, pathlib, signal, threading, re, tempfile
|
| 104 |
+
from fastapi import Depends, FastAPI, HTTPException, Request
|
| 105 |
from pydantic import BaseModel
|
| 106 |
from typing import Optional
|
| 107 |
import uvicorn
|
|
|
|
| 157 |
|
| 158 |
app = FastAPI()
|
| 159 |
|
| 160 |
+
def _expected_api_token() -> str:
|
| 161 |
+
return os.environ.get("SANDBOX_API_TOKEN") or os.environ.get("HF_TOKEN") or ""
|
| 162 |
+
|
| 163 |
+
def _require_auth(request: Request) -> None:
    """FastAPI dependency that enforces bearer-token auth on a route.

    Raises:
        HTTPException: 503 when no expected token is configured, 401 when the
            Authorization header is missing, is not a bearer credential, or
            carries the wrong token.
    """
    expected = _expected_api_token()
    if not expected:
        # Fail closed: without a configured token no caller can be trusted.
        raise HTTPException(status_code=503, detail="Sandbox API token not configured")
    auth_header = request.headers.get("authorization", "")
    scheme, _, supplied = auth_header.partition(" ")
    if scheme.lower() != "bearer" or not supplied:
        raise HTTPException(status_code=401, detail="Missing bearer token")
    # Compare as bytes: hmac.compare_digest raises TypeError for str arguments
    # containing non-ASCII characters, which would turn a bad token into a
    # 500 response instead of a 401.
    if not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid bearer token")
|
| 173 |
+
|
| 174 |
+
_AUTH = [Depends(_require_auth)]
|
| 175 |
+
|
| 176 |
# Track active bash processes so they can be killed on cancel
|
| 177 |
_active_procs = {} # pid -> subprocess.Popen
|
| 178 |
_proc_lock = threading.Lock()
|
|
|
|
| 361 |
def health():
|
| 362 |
return {"status": "ok"}
|
| 363 |
|
| 364 |
+
@app.post("/api/bash", dependencies=_AUTH)
|
| 365 |
def bash(req: BashReq):
|
| 366 |
try:
|
| 367 |
proc = subprocess.Popen(
|
|
|
|
| 388 |
except Exception as e:
|
| 389 |
return {"success": False, "output": "", "error": str(e)}
|
| 390 |
|
| 391 |
+
@app.post("/api/kill", dependencies=_AUTH)
|
| 392 |
def kill_all():
|
| 393 |
"""Kill all active bash processes. Called when user cancels."""
|
| 394 |
with _proc_lock:
|
|
|
|
| 406 |
pass
|
| 407 |
return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""}
|
| 408 |
|
| 409 |
+
@app.post("/api/read", dependencies=_AUTH)
|
| 410 |
def read(req: ReadReq):
|
| 411 |
try:
|
| 412 |
p = pathlib.Path(req.path)
|
|
|
|
| 423 |
except Exception as e:
|
| 424 |
return {"success": False, "output": "", "error": str(e)}
|
| 425 |
|
| 426 |
+
@app.post("/api/write", dependencies=_AUTH)
|
| 427 |
def write(req: WriteReq):
|
| 428 |
try:
|
| 429 |
p = pathlib.Path(req.path)
|
|
|
|
| 437 |
except Exception as e:
|
| 438 |
return {"success": False, "output": "", "error": str(e)}
|
| 439 |
|
| 440 |
+
@app.post("/api/edit", dependencies=_AUTH)
|
| 441 |
def edit(req: EditReq):
|
| 442 |
try:
|
| 443 |
p = pathlib.Path(req.path)
|
|
|
|
| 464 |
except Exception as e:
|
| 465 |
return {"success": False, "output": "", "error": str(e)}
|
| 466 |
|
| 467 |
+
@app.post("/api/exists", dependencies=_AUTH)
|
| 468 |
def exists(req: ExistsReq):
|
| 469 |
return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""}
|
| 470 |
|
|
|
|
| 499 |
|
| 500 |
space_id: str
|
| 501 |
token: str | None = None
|
| 502 |
+
api_token: str | None = field(default=None, repr=False)
|
| 503 |
work_dir: str = "/app"
|
| 504 |
timeout: int = DEFAULT_TIMEOUT
|
| 505 |
_owns_space: bool = field(default=False, repr=False)
|
|
|
|
| 513 |
# Trailing slash is critical: httpx resolves relative paths against base_url.
|
| 514 |
# Without it, client.get("health") resolves to /health instead of /api/health.
|
| 515 |
self._base_url = f"https://{slug}.hf.space/api/"
|
| 516 |
+
api_token = self.api_token or self.token
|
| 517 |
self._client = httpx.Client(
|
| 518 |
base_url=self._base_url,
|
| 519 |
+
headers={"Authorization": f"Bearer {api_token}"} if api_token else {},
|
| 520 |
timeout=httpx.Timeout(MAX_TIMEOUT, connect=30),
|
| 521 |
follow_redirects=True,
|
| 522 |
)
|
|
|
|
| 582 |
base = name or "sandbox"
|
| 583 |
suffix = uuid.uuid4().hex[:8]
|
| 584 |
space_id = f"{owner}/{base}-{suffix}"
|
| 585 |
+
sandbox_api_token = secrets_lib.token_urlsafe(32)
|
| 586 |
|
| 587 |
_log(f"Creating sandbox: {space_id} (from {template})...")
|
| 588 |
|
|
|
|
| 603 |
# Inject secrets BEFORE uploading server files (which triggers rebuild).
|
| 604 |
# Secrets added after a Space is running aren't available until restart,
|
| 605 |
# so they must be set before the build/start cycle.
|
| 606 |
+
sandbox_secrets = {**(secrets or {}), "SANDBOX_API_TOKEN": sandbox_api_token}
|
| 607 |
+
if sandbox_secrets:
|
| 608 |
+
for key, val in sandbox_secrets.items():
|
| 609 |
api.add_space_secret(space_id, key, val)
|
| 610 |
|
| 611 |
# Upload sandbox server and Dockerfile (triggers rebuild)
|
|
|
|
| 638 |
_check_cancel()
|
| 639 |
|
| 640 |
# Wait for the API server to be responsive (non-fatal)
|
| 641 |
+
sb = cls(
|
| 642 |
+
space_id=space_id,
|
| 643 |
+
token=token,
|
| 644 |
+
api_token=sandbox_api_token,
|
| 645 |
+
_owns_space=True,
|
| 646 |
+
)
|
| 647 |
try:
|
| 648 |
sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log)
|
| 649 |
except TimeoutError as e:
|
|
|
|
| 674 |
log("Server files uploaded, rebuild triggered.")
|
| 675 |
|
| 676 |
@classmethod
def connect(
    cls,
    space_id: str,
    *,
    token: str | None = None,
    api_token: str | None = None,
) -> Sandbox:
    """
    Attach to a Space that already exists and is running.

    The returned instance is created with ``_owns_space=False``. Before
    returning, waits up to 60 seconds for the sandbox API to respond.
    """
    sandbox = cls(
        space_id=space_id, token=token, api_token=api_token, _owns_space=False
    )
    sandbox._wait_for_api(timeout=60)
    return sandbox
|
| 697 |
|
|
|
|
| 724 |
)
|
| 725 |
print(f"Deleting sandbox: {self.space_id}...")
|
| 726 |
self._hf_api.delete_repo(self.space_id, repo_type="space")
|
| 727 |
+
# Clear ownership so a second cleanup call (e.g. delete_session +
|
| 728 |
+
# _run_session.finally both fire) early-returns instead of retrying
|
| 729 |
+
# a 404 delete and emitting a spurious ERROR log.
|
| 730 |
+
self._owns_space = False
|
| 731 |
self._client.close()
|
| 732 |
print("Deleted.")
|
| 733 |
|
agent/tools/sandbox_tool.py
CHANGED
|
@@ -12,13 +12,29 @@ a cpu-basic sandbox is auto-created (no approval needed).
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import asyncio
|
|
|
|
|
|
|
| 15 |
import threading
|
|
|
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
from huggingface_hub import HfApi, SpaceHardware
|
| 19 |
|
| 20 |
from agent.core.session import Event
|
| 21 |
from agent.tools.sandbox_client import Sandbox
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def _looks_like_path(script: str) -> bool:
|
|
@@ -62,11 +78,89 @@ async def resolve_sandbox_script(
|
|
| 62 |
return None, f"Failed to read {script} from sandbox: {e}"
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# ── Tool name mapping (short agent names → Sandbox client names) ──────
|
| 66 |
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
async def _ensure_sandbox(
|
| 69 |
-
session: Any,
|
|
|
|
|
|
|
|
|
|
| 70 |
) -> tuple[Sandbox | None, str | None]:
|
| 71 |
"""
|
| 72 |
Ensure a sandbox exists on the session. Auto-creates with given hardware if needed.
|
|
@@ -109,6 +203,23 @@ async def _ensure_sandbox(
|
|
| 109 |
Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
|
| 110 |
)
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 113 |
# We poll session._cancelled from the main loop in a background task and set
|
| 114 |
# a threading.Event that Sandbox.create checks during its polling loops.
|
|
@@ -120,11 +231,15 @@ async def _ensure_sandbox(
|
|
| 120 |
|
| 121 |
watcher_task = asyncio.create_task(_watch_cancel())
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
kwargs = {
|
| 124 |
"owner": owner,
|
| 125 |
"hardware": hardware,
|
| 126 |
"token": token,
|
| 127 |
-
"secrets":
|
| 128 |
"log": _log,
|
| 129 |
"cancel_event": cancel_flag,
|
| 130 |
**create_kwargs,
|
|
@@ -188,6 +303,9 @@ SANDBOX_CREATE_TOOL_SPEC = {
|
|
| 188 |
"fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n"
|
| 189 |
"Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
|
| 190 |
"If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
|
|
|
|
|
|
|
|
|
|
| 191 |
"Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
|
| 192 |
),
|
| 193 |
"parameters": {
|
|
@@ -204,16 +322,49 @@ SANDBOX_CREATE_TOOL_SPEC = {
|
|
| 204 |
"type": "boolean",
|
| 205 |
"description": "If true, create a private Space",
|
| 206 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
},
|
| 208 |
},
|
| 209 |
}
|
| 210 |
|
| 211 |
|
| 212 |
async def sandbox_create_handler(
|
| 213 |
-
args: dict[str, Any], session: Any = None
|
| 214 |
) -> tuple[str, bool]:
|
| 215 |
"""Handle sandbox_create tool calls."""
|
| 216 |
hardware = args.get("hardware", "cpu-basic")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
# If sandbox already exists, return its info
|
| 219 |
if session and getattr(session, "sandbox", None):
|
|
@@ -226,6 +377,7 @@ async def sandbox_create_handler(
|
|
| 226 |
"Hardware cannot be changed by calling sandbox_create again. "
|
| 227 |
"Delete the existing sandbox first if you need a different tier."
|
| 228 |
)
|
|
|
|
| 229 |
return (
|
| 230 |
f"Sandbox already active: {sb.space_id}\n"
|
| 231 |
f"URL: {sb.url}\n"
|
|
@@ -233,18 +385,32 @@ async def sandbox_create_handler(
|
|
| 233 |
f"Use bash/read/write/edit to interact with it."
|
| 234 |
), True
|
| 235 |
|
| 236 |
-
create_kwargs = {}
|
| 237 |
if "private" in args:
|
| 238 |
create_kwargs["private"] = args["private"]
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
try:
|
| 241 |
-
sb, error = await _ensure_sandbox(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
except Exception as e:
|
| 243 |
return f"Failed to create sandbox: {e}", False
|
| 244 |
|
| 245 |
if error:
|
| 246 |
return error, False
|
| 247 |
|
|
|
|
|
|
|
| 248 |
return (
|
| 249 |
f"Sandbox created: {sb.space_id}\n"
|
| 250 |
f"URL: {sb.url}\n"
|
|
|
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
import asyncio
|
| 15 |
+
import logging
|
| 16 |
+
import re
|
| 17 |
import threading
|
| 18 |
+
from datetime import datetime, timedelta, timezone
|
| 19 |
from typing import Any
|
| 20 |
|
| 21 |
from huggingface_hub import HfApi, SpaceHardware
|
| 22 |
|
| 23 |
from agent.core.session import Event
|
| 24 |
from agent.tools.sandbox_client import Sandbox
|
| 25 |
+
from agent.tools.trackio_seed import ensure_trackio_dashboard
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
# Match the exact suffix pattern Sandbox.create produces: "sandbox-<8 hex>".
|
| 30 |
+
# Used to identify orphan sandboxes from prior sessions safely (won't match
|
| 31 |
+
# user-renamed lookalikes).
|
| 32 |
+
_SANDBOX_NAME_RE = re.compile(r"^sandbox-[a-f0-9]{8}$")
|
| 33 |
+
|
| 34 |
+
# How stale a sandbox must be before we treat it as definitely orphan.
|
| 35 |
+
# Anything more recent could be tied to a still-live session in another tab,
|
| 36 |
+
# so we leave it alone.
|
| 37 |
+
_ORPHAN_STALE_AFTER = timedelta(hours=1)
|
| 38 |
|
| 39 |
|
| 40 |
def _looks_like_path(script: str) -> bool:
|
|
|
|
| 78 |
return None, f"Failed to read {script} from sandbox: {e}"
|
| 79 |
|
| 80 |
|
| 81 |
+
async def _seed_trackio_dashboard_safe(session: Any, space_id: str) -> None:
    """Best-effort seeding of the trackio dashboard Space *space_id*.

    Returns immediately when there is no session or the session carries no
    ``hf_token``. Otherwise ``ensure_trackio_dashboard`` runs in a worker
    thread; its progress lines and any failure are forwarded to the session's
    event queue as ``tool_log`` events. Exceptions are reported, never
    re-raised, so a failed seed cannot block sandbox creation.
    """
    has_token = bool(session) and bool(getattr(session, "hf_token", None))
    if not has_token:
        return

    event_loop = asyncio.get_running_loop()

    def _forward(line: str) -> None:
        # The seeding helper calls this from the worker thread, so the event
        # is handed back to the loop instead of touching the queue directly.
        payload = {"tool": "sandbox_create", "log": line}
        event_loop.call_soon_threadsafe(
            session.event_queue.put_nowait,
            Event(event_type="tool_log", data=payload),
        )

    try:
        await asyncio.to_thread(
            ensure_trackio_dashboard, space_id, session.hf_token, _forward
        )
    except Exception as exc:
        _forward(f"trackio dashboard seed failed: {exc}")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
# ── Tool name mapping (short agent names → Sandbox client names) ──────
|
| 104 |
|
| 105 |
|
| 106 |
+
def _cleanup_user_orphan_sandboxes(
    api: HfApi,
    owner: str,
    log: Any,
) -> int:
    """Delete stale ``sandbox-<8hex>`` Spaces in ``owner``'s account.

    "Stale" = not modified in the last hour. The naming pattern + staleness
    filter together make this safe:

    * Naming: only matches ``sandbox-<exactly 8 lowercase hex>``, the
      pattern Sandbox.create produces. Won't touch user-renamed Spaces.
    * Staleness: anything modified in the last hour might still be tied
      to a live session in another tab/replica, so we leave it alone.
      A Space whose last-modified time is missing or unparseable is also
      left alone: staleness cannot be established, and deletion is
      irreversible.

    Runs blocking — call via ``asyncio.to_thread``. Best-effort: listing and
    deletion failures are logged through ``log`` and never raised.

    Args:
        api: Hub client used for ``list_spaces`` and ``delete_repo``.
        owner: Account whose Spaces are scanned.
        log: Callable taking one message string.

    Returns:
        Number of Spaces deleted.
    """
    cutoff = datetime.now(timezone.utc) - _ORPHAN_STALE_AFTER
    try:
        # full=True: per the huggingface_hub docs this is what makes the
        # listing include each Space's last-modified timestamp.
        spaces = list(api.list_spaces(author=owner, limit=200, full=True))
    except Exception as e:
        log(f"orphan sweep: list_spaces failed: {e}")
        return 0

    deleted = 0
    for space in spaces:
        space_name = space.id.rsplit("/", 1)[-1]
        if not _SANDBOX_NAME_RE.match(space_name):
            continue

        last_mod = getattr(space, "lastModified", None) or getattr(
            space, "last_modified", None
        )
        if isinstance(last_mod, str):
            try:
                last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
            except ValueError:
                last_mod = None
        if not isinstance(last_mod, datetime):
            # Unknown age: never delete something we cannot prove is stale.
            log(f"orphan sweep: skipping {space.id} (unknown last-modified time)")
            continue
        if last_mod.tzinfo is None:
            # Treat naive timestamps as UTC so the comparison with the
            # timezone-aware cutoff cannot raise TypeError.
            last_mod = last_mod.replace(tzinfo=timezone.utc)
        if last_mod > cutoff:
            # Recent — could be a concurrent live session. Skip.
            continue

        try:
            api.delete_repo(repo_id=space.id, repo_type="space")
        except Exception as e:
            log(f"orphan sweep: failed to delete {space.id}: {e}")
        else:
            deleted += 1
            log(f"orphan sweep: deleted {space.id}")

    if deleted:
        log(f"orphan sweep: cleaned up {deleted} stale sandbox(es) before create")
    return deleted
|
| 157 |
+
|
| 158 |
+
|
| 159 |
async def _ensure_sandbox(
|
| 160 |
+
session: Any,
|
| 161 |
+
hardware: str = "cpu-basic",
|
| 162 |
+
extra_secrets: dict[str, str] | None = None,
|
| 163 |
+
**create_kwargs,
|
| 164 |
) -> tuple[Sandbox | None, str | None]:
|
| 165 |
"""
|
| 166 |
Ensure a sandbox exists on the session. Auto-creates with given hardware if needed.
|
|
|
|
| 203 |
Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
|
| 204 |
)
|
| 205 |
|
| 206 |
+
# Before we create a new sandbox, sweep this user's stale sandboxes from
|
| 207 |
+
# prior sessions. ``_cleanup_sandbox`` in session_manager fires only on
|
| 208 |
+
# clean session exit; pod kills, WebSocket drops, etc. leave orphans
|
| 209 |
+
# behind, and they accumulate on every new session forever (observed
|
| 210 |
+
# 2310 leaked across the Hub on 2026-04-27). Doing the cleanup here at
|
| 211 |
+
# session start = self-healing, no separate cron needed.
|
| 212 |
+
#
|
| 213 |
+
# The 1h staleness filter is the safety: a sandbox modified in the last
|
| 214 |
+
# hour might still be tied to a live session in another tab, so we skip.
|
| 215 |
+
# Anything older has no realistic chance of being active given typical
|
| 216 |
+
# session lengths.
|
| 217 |
+
try:
|
| 218 |
+
await asyncio.to_thread(_cleanup_user_orphan_sandboxes, api, owner, _log)
|
| 219 |
+
except Exception as e:
|
| 220 |
+
# Cleanup is best-effort — never block sandbox_create on it.
|
| 221 |
+
_log(f"orphan sandbox sweep failed (non-fatal): {e}")
|
| 222 |
+
|
| 223 |
# Bridge asyncio cancel event to a threading.Event for the blocking create call.
|
| 224 |
# We poll session._cancelled from the main loop in a background task and set
|
| 225 |
# a threading.Event that Sandbox.create checks during its polling loops.
|
|
|
|
| 231 |
|
| 232 |
watcher_task = asyncio.create_task(_watch_cancel())
|
| 233 |
|
| 234 |
+
secrets: dict[str, str] = {"HF_TOKEN": token}
|
| 235 |
+
if extra_secrets:
|
| 236 |
+
secrets.update({k: v for k, v in extra_secrets.items() if v})
|
| 237 |
+
|
| 238 |
kwargs = {
|
| 239 |
"owner": owner,
|
| 240 |
"hardware": hardware,
|
| 241 |
"token": token,
|
| 242 |
+
"secrets": secrets,
|
| 243 |
"log": _log,
|
| 244 |
"cancel_event": cancel_flag,
|
| 245 |
**create_kwargs,
|
|
|
|
| 303 |
"fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n"
|
| 304 |
"Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
|
| 305 |
"If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
|
| 306 |
+
"If you intend to run a training script in this sandbox that uses report_to='trackio', "
|
| 307 |
+
"pass `trackio_space_id` (e.g. '<username>/mlintern-<8char>') and `trackio_project` so they "
|
| 308 |
+
"are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
|
| 309 |
"Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
|
| 310 |
),
|
| 311 |
"parameters": {
|
|
|
|
| 322 |
"type": "boolean",
|
| 323 |
"description": "If true, create a private Space",
|
| 324 |
},
|
| 325 |
+
"trackio_space_id": {
|
| 326 |
+
"type": "string",
|
| 327 |
+
"description": (
|
| 328 |
+
"Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
|
| 329 |
+
"(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). Injected as "
|
| 330 |
+
"TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
|
| 331 |
+
"seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
|
| 332 |
+
"that produces an empty Space that breaks the embed."
|
| 333 |
+
),
|
| 334 |
+
},
|
| 335 |
+
"trackio_project": {
|
| 336 |
+
"type": "string",
|
| 337 |
+
"description": (
|
| 338 |
+
"Optional. The trackio project name. Injected as TRACKIO_PROJECT secret and "
|
| 339 |
+
"used by the UI to filter the embedded dashboard to this project."
|
| 340 |
+
),
|
| 341 |
+
},
|
| 342 |
},
|
| 343 |
},
|
| 344 |
}
|
| 345 |
|
| 346 |
|
| 347 |
async def sandbox_create_handler(
|
| 348 |
+
args: dict[str, Any], session: Any = None, tool_call_id: str | None = None
|
| 349 |
) -> tuple[str, bool]:
|
| 350 |
"""Handle sandbox_create tool calls."""
|
| 351 |
hardware = args.get("hardware", "cpu-basic")
|
| 352 |
+
trackio_space_id = args.get("trackio_space_id") or None
|
| 353 |
+
trackio_project = args.get("trackio_project") or None
|
| 354 |
+
|
| 355 |
+
async def _emit_trackio_state(sb: Sandbox) -> None:
|
| 356 |
+
"""Tell the frontend which trackio dashboard to embed for this sandbox."""
|
| 357 |
+
if not (session and tool_call_id and trackio_space_id):
|
| 358 |
+
return
|
| 359 |
+
data: dict[str, Any] = {
|
| 360 |
+
"tool_call_id": tool_call_id,
|
| 361 |
+
"tool": "sandbox_create",
|
| 362 |
+
"state": "running",
|
| 363 |
+
"trackioSpaceId": trackio_space_id,
|
| 364 |
+
}
|
| 365 |
+
if trackio_project:
|
| 366 |
+
data["trackioProject"] = trackio_project
|
| 367 |
+
await session.send_event(Event(event_type="tool_state_change", data=data))
|
| 368 |
|
| 369 |
# If sandbox already exists, return its info
|
| 370 |
if session and getattr(session, "sandbox", None):
|
|
|
|
| 377 |
"Hardware cannot be changed by calling sandbox_create again. "
|
| 378 |
"Delete the existing sandbox first if you need a different tier."
|
| 379 |
)
|
| 380 |
+
await _emit_trackio_state(sb)
|
| 381 |
return (
|
| 382 |
f"Sandbox already active: {sb.space_id}\n"
|
| 383 |
f"URL: {sb.url}\n"
|
|
|
|
| 385 |
f"Use bash/read/write/edit to interact with it."
|
| 386 |
), True
|
| 387 |
|
| 388 |
+
create_kwargs: dict[str, Any] = {}
|
| 389 |
if "private" in args:
|
| 390 |
create_kwargs["private"] = args["private"]
|
| 391 |
|
| 392 |
+
extra_secrets: dict[str, str] = {}
|
| 393 |
+
if trackio_space_id:
|
| 394 |
+
extra_secrets["TRACKIO_SPACE_ID"] = trackio_space_id
|
| 395 |
+
await _seed_trackio_dashboard_safe(session, trackio_space_id)
|
| 396 |
+
if trackio_project:
|
| 397 |
+
extra_secrets["TRACKIO_PROJECT"] = trackio_project
|
| 398 |
+
|
| 399 |
try:
|
| 400 |
+
sb, error = await _ensure_sandbox(
|
| 401 |
+
session,
|
| 402 |
+
hardware=hardware,
|
| 403 |
+
extra_secrets=extra_secrets or None,
|
| 404 |
+
**create_kwargs,
|
| 405 |
+
)
|
| 406 |
except Exception as e:
|
| 407 |
return f"Failed to create sandbox: {e}", False
|
| 408 |
|
| 409 |
if error:
|
| 410 |
return error, False
|
| 411 |
|
| 412 |
+
await _emit_trackio_state(sb)
|
| 413 |
+
|
| 414 |
return (
|
| 415 |
f"Sandbox created: {sb.space_id}\n"
|
| 416 |
f"URL: {sb.url}\n"
|
agent/tools/trackio_seed.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Seed an HF Space with the trackio dashboard.
|
| 2 |
+
|
| 3 |
+
Background: when the agent creates a Space via `hf_repo_git create_repo` (or
|
| 4 |
+
the user pre-creates one), it ships with no app.py — so the iframe shows the
|
| 5 |
+
default Gradio "Get started" template instead of charts. Trackio's `init()`
|
| 6 |
+
detects the existing Space but does NOT auto-bootstrap dashboard files into it,
|
| 7 |
+
so the dashboard never materializes.
|
| 8 |
+
|
| 9 |
+
This helper writes the three files trackio's runtime expects (README.md,
|
| 10 |
+
requirements.txt, app.py) into the Space, idempotently, BEFORE the job that
|
| 11 |
+
will call `trackio.init()` runs. We deliberately omit `hf_oauth: true` from
|
| 12 |
+
the README so the embedded iframe in ml-intern renders without a login click —
|
| 13 |
+
per-user privacy is enforced by namespace ownership instead.
|
| 14 |
+
|
| 15 |
+
Beyond the dashboard files, the helper also creates the metrics bucket and
|
| 16 |
+
mounts it on the Space at `/data` (with `TRACKIO_DIR` / `TRACKIO_BUCKET_ID`
|
| 17 |
+
Space variables). Without this, the running job writes metrics into a bucket
|
| 18 |
+
that the dashboard Space can't read, and the iframe shows "No projects".
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import io
|
| 24 |
+
from typing import Callable, Optional
|
| 25 |
+
|
| 26 |
+
from huggingface_hub import (
|
| 27 |
+
HfApi,
|
| 28 |
+
Volume,
|
| 29 |
+
add_space_variable,
|
| 30 |
+
create_bucket,
|
| 31 |
+
create_repo,
|
| 32 |
+
)
|
| 33 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
_README = """---
|
| 37 |
+
title: Trackio Dashboard
|
| 38 |
+
emoji: 📊
|
| 39 |
+
colorFrom: pink
|
| 40 |
+
colorTo: gray
|
| 41 |
+
sdk: gradio
|
| 42 |
+
app_file: app.py
|
| 43 |
+
pinned: false
|
| 44 |
+
tags:
|
| 45 |
+
- trackio
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
Embedded trackio dashboard for ml-intern runs.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
_REQUIREMENTS = "trackio\n"
|
| 52 |
+
_APP_PY = "import trackio\ntrackio.show()\n"
|
| 53 |
+
|
| 54 |
+
# ml-intern brand mark surfaced inside the trackio dashboard. Trackio reads
|
| 55 |
+
# `TRACKIO_LOGO_LIGHT_URL` / `TRACKIO_LOGO_DARK_URL` from Space variables and
|
| 56 |
+
# renders them in place of its own logo. We point at the publicly-resolvable
|
| 57 |
+
# copy on the smolagents/ml-intern Space repo so any seeded dashboard inherits
|
| 58 |
+
# the ml-intern branding without each user having to host the asset.
|
| 59 |
+
_LOGO_URL = (
|
| 60 |
+
"https://huggingface.co/spaces/smolagents/ml-intern/"
|
| 61 |
+
"resolve/main/frontend/public/smolagents.webp"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
_FILES = {
|
| 65 |
+
"README.md": _README,
|
| 66 |
+
"requirements.txt": _REQUIREMENTS,
|
| 67 |
+
"app.py": _APP_PY,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _already_seeded(api: HfApi, space_id: str) -> bool:
    """Report whether the Space's ``app.py`` already calls ``trackio.show``.

    Fetches ``app.py`` from the Space and looks for the literal text
    ``trackio.show``. A missing file, a missing repo, or any ``OSError``
    while downloading or reading is treated as "not seeded", so the caller
    re-uploads the dashboard files.
    """
    try:
        local_app = api.hf_hub_download(
            filename="app.py", repo_type="space", repo_id=space_id
        )
    except (EntryNotFoundError, RepositoryNotFoundError, OSError):
        return False

    try:
        with open(local_app, "r", encoding="utf-8") as handle:
            source = handle.read()
    except OSError:
        return False
    return "trackio.show" in source
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _get_space_volumes(api: HfApi, space_id: str) -> list:
    """Return the volumes mounted on a Space, or ``[]`` when none are reported.

    The runtime returned by ``get_space_runtime()`` is consulted first. If it
    reports no volumes, ``space_info().runtime.volumes`` is used as a
    fallback.
    """
    primary = getattr(api.get_space_runtime(space_id), "volumes", None)
    if primary:
        return list(primary)

    fallback_runtime = api.space_info(space_id).runtime
    fallback = getattr(fallback_runtime, "volumes", None) if fallback_runtime else None
    return list(fallback) if fallback else []
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _ensure_bucket_mounted(
    api: HfApi,
    space_id: str,
    bucket_id: str,
    hf_token: str,
    log: Optional[Callable[[str], None]] = None,
) -> None:
    """Ensure ``bucket_id`` exists, is mounted at ``/data``, and is advertised.

    Steps, each skipped when already satisfied:

    1. Create the private bucket (``exist_ok=True``).
    2. If no bucket volume with this source is mounted at ``/data``, rewrite
       the Space's volume list: volumes unrelated to this bucket and to
       ``/data`` are kept, and the bucket mount is appended.
    3. Set the ``TRACKIO_DIR``, ``TRACKIO_BUCKET_ID`` and logo Space
       variables whose current value differs from the desired one.
    """
    create_bucket(bucket_id, private=True, exist_ok=True, token=hf_token)

    def _describe(volume) -> tuple:
        # (type, source, mount_path), with None for any missing attribute.
        return (
            getattr(volume, "type", None),
            getattr(volume, "source", None),
            getattr(volume, "mount_path", None),
        )

    current = _get_space_volumes(api, space_id)
    described = [_describe(volume) for volume in current]

    if ("bucket", bucket_id, "/data") not in described:
        # Drop bucket volumes that clash with the mount we are about to add
        # (same bucket elsewhere, or another bucket already on /data).
        kept = [
            volume
            for volume, (kind, source, mount) in zip(current, described)
            if kind != "bucket" or (source != bucket_id and mount != "/data")
        ]
        new_mount = Volume(type="bucket", source=bucket_id, mount_path="/data")
        api.set_space_volumes(space_id, kept + [new_mount])
        if log:
            log(f"mounted bucket {bucket_id} at /data on {space_id}")

    current_vars = api.get_space_variables(space_id)
    wanted = {
        "TRACKIO_DIR": "/data/trackio",
        "TRACKIO_BUCKET_ID": bucket_id,
        "TRACKIO_LOGO_LIGHT_URL": _LOGO_URL,
        "TRACKIO_LOGO_DARK_URL": _LOGO_URL,
    }
    for name, target in wanted.items():
        present = getattr(current_vars.get(name), "value", None)
        if present != target:
            add_space_variable(space_id, name, target, token=hf_token)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def ensure_trackio_dashboard(
    space_id: str,
    hf_token: str,
    log: Optional[Callable[[str], None]] = None,
) -> bool:
    """Wire *space_id* up as a trackio dashboard Space.

    1. Create the gradio Space if it does not exist yet.
    2. Upload the dashboard files (``README.md``, ``requirements.txt``,
       ``app.py``) unless ``app.py`` already calls ``trackio.show``.
    3. Ensure bucket ``<space_id>-bucket`` is mounted and the trackio Space
       variables are set; this step runs on every call.

    Args:
        space_id: Target Space, ``<namespace>/<name>``.
        hf_token: Token used for every Hub call.
        log: Optional callable receiving progress messages.

    Returns:
        True when the dashboard files were uploaded in step 2, False when
        they were already in place.
    """

    def _say(message: str) -> None:
        if log:
            log(message)

    api = HfApi(token=hf_token)
    create_repo(
        repo_id=space_id,
        repo_type="space",
        space_sdk="gradio",
        exist_ok=True,
        token=hf_token,
    )

    needs_seed = not _already_seeded(api, space_id)
    if needs_seed:
        _say(f"seeding trackio dashboard files into {space_id}")
        for name, text in _FILES.items():
            api.upload_file(
                path_or_fileobj=io.BytesIO(text.encode("utf-8")),
                path_in_repo=name,
                repo_id=space_id,
                repo_type="space",
                commit_message=f"ml-intern: seed trackio dashboard ({name})",
            )
    else:
        _say(f"trackio dashboard already seeded on {space_id}")

    _ensure_bucket_mounted(api, space_id, f"{space_id}-bucket", hf_token, log)

    _say(f"trackio dashboard ready: https://huggingface.co/spaces/{space_id}")
    return needs_seed
|
agent/tools/web_search_tool.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DuckDuckGo HTML web search tool.
|
| 2 |
+
|
| 3 |
+
This mirrors Claw Code's Rust WebSearch behavior: fetch DuckDuckGo's HTML
|
| 4 |
+
endpoint, extract result links, optionally filter domains, and return a
|
| 5 |
+
JSON payload the model can cite.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import asyncio
|
| 11 |
+
import html
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import time
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from html.parser import HTMLParser
|
| 17 |
+
from typing import Any
|
| 18 |
+
from urllib.parse import parse_qsl, parse_qs, urlencode, urlparse, urlunparse
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
|
| 22 |
+
DEFAULT_SEARCH_URL = "https://html.duckduckgo.com/html/"
|
| 23 |
+
WEB_SEARCH_BASE_URL_ENV = "CLAWD_WEB_SEARCH_BASE_URL"
|
| 24 |
+
USER_AGENT = "clawd-rust-tools/0.1"
|
| 25 |
+
REQUEST_TIMEOUT_SECONDS = 20
|
| 26 |
+
MAX_RESULTS = 8
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
class SearchHit:
    """One search result: the link text and the URL it points to."""

    title: str
    url: str

    def as_json(self) -> dict[str, str]:
        """Return the hit as a ``{"title": ..., "url": ...}`` mapping."""
        return dict(title=self.title, url=self.url)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class _AnchorParser(HTMLParser):
    """Collect ``(href, text)`` pairs for the ``<a>`` elements of a page.

    After ``feed()``, ``hits`` holds one ``(href, title)`` tuple per closed
    anchor that had a non-empty ``href`` (and, when ``require_result_class``
    is true, a ``class`` attribute containing ``result__a``). ``title`` is
    the anchor's text with whitespace runs collapsed to single spaces.

    The parser runs with ``convert_charrefs=True``, so character references
    in text reach ``handle_data`` already decoded, and ``handle_entityref`` /
    ``handle_charref`` are never invoked. Titles are therefore NOT passed
    through ``html.unescape`` again: doing so would decode twice and turn
    page text such as ``&amp;lt;`` into ``<`` instead of ``&lt;``.
    """

    def __init__(self, *, require_result_class: bool) -> None:
        super().__init__(convert_charrefs=True)
        self.require_result_class = require_result_class
        self.hits: list[tuple[str, str]] = []
        # href of the anchor currently being read, or None outside an anchor.
        self._active_href: str | None = None
        # Text chunks seen since the current anchor opened.
        self._active_text: list[str] = []

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        """Start capturing text when a qualifying ``<a href=...>`` opens."""
        if tag.lower() != "a":
            return
        attr_map = {key.lower(): value or "" for key, value in attrs}
        href = attr_map.get("href")
        if not href:
            return
        if self.require_result_class and "result__a" not in attr_map.get("class", ""):
            return
        self._active_href = href
        self._active_text = []

    def handle_data(self, data: str) -> None:
        """Accumulate (already decoded) text while inside a captured anchor."""
        if self._active_href is not None:
            self._active_text.append(data)

    def handle_endtag(self, tag: str) -> None:
        """Record the captured anchor when its ``</a>`` is reached."""
        if tag.lower() != "a" or self._active_href is None:
            return
        title = collapse_whitespace("".join(self._active_text)).strip()
        self.hits.append((self._active_href, title))
        self._active_href = None
        self._active_text = []
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def build_search_url(query: str) -> str:
    """Return the search endpoint URL with ``q=<query>`` appended.

    The base URL comes from the environment variable named by
    ``WEB_SEARCH_BASE_URL_ENV`` and falls back to ``DEFAULT_SEARCH_URL``.
    Query parameters already present on the base URL are kept, including
    blank ones.

    Raises:
        ValueError: if the base URL is not http(s) or has no host part.
    """
    base = os.environ.get(WEB_SEARCH_BASE_URL_ENV, DEFAULT_SEARCH_URL)
    parts = urlparse(base)
    scheme_ok = parts.scheme in {"http", "https"}
    if not (scheme_ok and parts.netloc):
        raise ValueError(f"invalid search base URL: {base}")

    params = [*parse_qsl(parts.query, keep_blank_values=True), ("q", query)]
    return urlunparse(parts._replace(query=urlencode(params)))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def collapse_whitespace(value: str) -> str:
    """Replace each run of whitespace with one space and trim both ends."""
    words = value.split()
    return " ".join(words)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def decode_duckduckgo_redirect(url: str) -> str | None:
    """Resolve a result ``href`` to an absolute URL.

    * Absolute http(s) URLs are returned HTML-unescaped.
    * ``//host/...`` gets an ``https:`` scheme; ``/path`` is resolved against
      ``https://duckduckgo.com``.
    * If the resolved path is ``/l`` or ``/l/`` and carries a non-empty
      ``uddg`` query parameter, that parameter's first value (HTML-unescaped)
      is returned instead.
    * Anything else yields ``None``.
    """
    if url.startswith(("http://", "https://")):
        return html.unescape(url)

    if url.startswith("//"):
        absolute = "https:" + url
    elif url.startswith("/"):
        absolute = "https://duckduckgo.com" + url
    else:
        return None

    parts = urlparse(absolute)
    if parts.path in ("/l", "/l/"):
        targets = parse_qs(parts.query).get("uddg")
        if targets:
            return html.unescape(targets[0])
    return absolute
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _extract_links(search_html: str, *, require_result_class: bool) -> list[SearchHit]:
    """Parse *search_html* and return its anchors as ``SearchHit`` objects.

    Anchors with an empty title are dropped, each ``href`` is resolved with
    ``decode_duckduckgo_redirect``, and only results that resolve to an
    http(s) URL are kept. Document order is preserved.
    """
    parser = _AnchorParser(require_result_class=require_result_class)
    parser.feed(search_html)

    resolved = (
        (decode_duckduckgo_redirect(href), text)
        for href, text in parser.hits
        if text
    )
    return [
        SearchHit(title=text, url=target)
        for target, text in resolved
        if target and target.startswith(("http://", "https://"))
    ]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def extract_search_hits(search_html: str) -> list[SearchHit]:
    """Extract hits from anchors whose class contains ``result__a``."""
    strict_hits = _extract_links(search_html, require_result_class=True)
    return strict_hits
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def extract_search_hits_from_generic_links(search_html: str) -> list[SearchHit]:
    """Extract hits from every anchor, regardless of its class attribute."""
    loose_hits = _extract_links(search_html, require_result_class=False)
    return loose_hits
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def normalize_domain_filter(domain: str) -> str:
    """Reduce a domain-filter entry to a bare lowercase host name.

    When the entry parses as a URL with both a scheme and a host, the host is
    used. Otherwise the trimmed entry itself is used. Leading dots and
    trailing slashes are removed from the result.
    """
    text = domain.strip()
    parts = urlparse(text)
    if parts.scheme and parts.hostname:
        text = parts.hostname
    return text.strip().lstrip(".").rstrip("/").lower()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def host_matches_list(url: str, domains: list[str]) -> bool:
    """Tell whether *url*'s host equals, or is a subdomain of, any entry.

    Entries are normalized with ``normalize_domain_filter``. Entries that
    normalize to an empty string never match. A URL without a host matches
    nothing.
    """
    host = urlparse(url).hostname
    if not host:
        return False

    host = host.lower()
    candidates = (normalize_domain_filter(entry) for entry in domains)
    return any(
        name and (host == name or host.endswith(f".{name}"))
        for name in candidates
    )
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def dedupe_hits(hits: list[SearchHit]) -> list[SearchHit]:
    """Drop hits whose URL was already seen, keeping first occurrences in order."""
    first_by_url: dict[str, SearchHit] = {}
    for candidate in hits:
        # setdefault keeps the earliest hit for a URL; dict order preserves
        # the order in which URLs were first encountered.
        first_by_url.setdefault(candidate.url, candidate)
    return list(first_by_url.values())
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def execute_web_search(
    query: str,
    allowed_domains: list[str] | None = None,
    blocked_domains: list[str] | None = None,
    tool_use_id: str = "web_search_1",
) -> dict[str, Any]:
    """Run one web search and return a JSON-serializable result payload.

    Fetches the search endpoint, extracts result anchors (falling back to
    every anchor on the page when no ``result__a`` anchors are found),
    applies the optional allow/block domain filters, de-duplicates by URL and
    keeps at most ``MAX_RESULTS`` hits.

    Args:
        query: Search terms.
        allowed_domains: When not None, keep only hits whose host matches.
        blocked_domains: When not None, drop hits whose host matches.
        tool_use_id: Identifier echoed back in the payload.

    Returns:
        ``{"query", "results": [summary_text, {"tool_use_id", "content"}],
        "durationSeconds"}``.

    Raises:
        ValueError: if the configured search base URL is invalid.
        requests.RequestException: on network failure or when the endpoint
            answers with a 4xx/5xx status.
    """
    started = time.monotonic()
    search_url = build_search_url(query)
    response = requests.get(
        search_url,
        headers={"User-Agent": USER_AGENT},
        timeout=REQUEST_TIMEOUT_SECONDS,
        allow_redirects=True,
    )
    # Without this an error page (rate limit, server error) would be parsed
    # like a result page and reported as "no results matched".
    response.raise_for_status()

    hits = extract_search_hits(response.text)
    if not hits and urlparse(response.url or search_url).hostname:
        hits = extract_search_hits_from_generic_links(response.text)

    if allowed_domains is not None:
        hits = [hit for hit in hits if host_matches_list(hit.url, allowed_domains)]
    if blocked_domains is not None:
        hits = [hit for hit in hits if not host_matches_list(hit.url, blocked_domains)]

    hits = dedupe_hits(hits)[:MAX_RESULTS]
    rendered_hits = "\n".join(f"- [{hit.title}]({hit.url})" for hit in hits)
    if hits:
        summary = (
            f"Search results for {query!r}. Include a Sources section in the final answer.\n"
            f"{rendered_hits}"
        )
    else:
        summary = f"No web search results matched the query {query!r}."

    return {
        "query": query,
        "results": [
            summary,
            {
                "tool_use_id": tool_use_id,
                "content": [hit.as_json() for hit in hits],
            },
        ],
        "durationSeconds": time.monotonic() - started,
    }
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
WEB_SEARCH_TOOL_SPEC = {
|
| 216 |
+
"name": "web_search",
|
| 217 |
+
"description": "Search the web for current information and return cited results.",
|
| 218 |
+
"parameters": {
|
| 219 |
+
"type": "object",
|
| 220 |
+
"properties": {
|
| 221 |
+
"query": {"type": "string", "minLength": 2},
|
| 222 |
+
"allowed_domains": {
|
| 223 |
+
"type": "array",
|
| 224 |
+
"items": {"type": "string"},
|
| 225 |
+
"description": "Optional allowlist of domains or URLs. Subdomains match.",
|
| 226 |
+
},
|
| 227 |
+
"blocked_domains": {
|
| 228 |
+
"type": "array",
|
| 229 |
+
"items": {"type": "string"},
|
| 230 |
+
"description": "Optional blocklist of domains or URLs. Subdomains match.",
|
| 231 |
+
},
|
| 232 |
+
},
|
| 233 |
+
"required": ["query"],
|
| 234 |
+
"additionalProperties": False,
|
| 235 |
+
},
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def _optional_string_list(arguments: dict[str, Any], key: str) -> list[str] | None:
    """Return ``arguments[key]`` when it is a list of strings, else fail.

    A missing key or an explicit ``None`` yields ``None``.

    Raises:
        ValueError: if the value is present but is not a list of strings.
    """
    raw = arguments.get(key)
    if raw is None:
        return None
    if isinstance(raw, list) and all(isinstance(entry, str) for entry in raw):
        return raw
    raise ValueError(f"{key} must be an array of strings")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
async def web_search_handler(
    arguments: dict[str, Any],
    session: Any = None,
    tool_call_id: str | None = None,
    **_kw: Any,
) -> tuple[str, bool]:
    """Handle a ``web_search`` tool call.

    Validates ``query`` (a string of at least two characters after
    trimming), runs ``execute_web_search`` in a worker thread, and returns
    ``(text, ok)``: the JSON-encoded payload with ``True`` on success, or an
    error message with ``False``. Invalid domain lists and search failures
    are reported as error messages rather than raised.
    """
    raw_query = arguments.get("query", "")
    if not isinstance(raw_query, str):
        return "Error: web_search requires a query string with at least 2 characters.", False

    query = raw_query.strip()
    if len(query) < 2:
        return "Error: web_search requires a query with at least 2 characters.", False

    try:
        # Domain-list validation happens inside the try so that a malformed
        # list is reported through the same error path as a failed search.
        allow = _optional_string_list(arguments, "allowed_domains")
        deny = _optional_string_list(arguments, "blocked_domains")
        payload = await asyncio.to_thread(
            execute_web_search,
            query=query,
            allowed_domains=allow,
            blocked_domains=deny,
            tool_use_id=tool_call_id or "web_search_1",
        )
    except Exception as exc:
        return f"Error executing web search: {exc}", False

    return json.dumps(payload, indent=2), True
|
backend/dependencies.py
CHANGED
|
@@ -12,6 +12,8 @@ from typing import Any
|
|
| 12 |
import httpx
|
| 13 |
from fastapi import HTTPException, Request, status
|
| 14 |
|
|
|
|
|
|
|
| 15 |
from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
|
| 16 |
|
| 17 |
logger = logging.getLogger(__name__)
|
|
@@ -157,9 +159,8 @@ async def get_current_user(request: Request) -> dict[str, Any]:
|
|
| 157 |
return DEV_USER
|
| 158 |
|
| 159 |
# Try Authorization header
|
| 160 |
-
|
| 161 |
-
if
|
| 162 |
-
token = auth_header[7:]
|
| 163 |
user = await _extract_user_from_token(token)
|
| 164 |
if user:
|
| 165 |
return user
|
|
@@ -183,9 +184,9 @@ def _extract_token(request: Request) -> str | None:
|
|
| 183 |
|
| 184 |
Mirrors the lookup order used by ``get_current_user``.
|
| 185 |
"""
|
| 186 |
-
|
| 187 |
-
if
|
| 188 |
-
return
|
| 189 |
return request.cookies.get("hf_access_token")
|
| 190 |
|
| 191 |
|
|
@@ -202,4 +203,3 @@ async def require_huggingface_org_member(request: Request) -> bool:
|
|
| 202 |
if not token:
|
| 203 |
return False
|
| 204 |
return await check_org_membership(token, HF_EMPLOYEE_ORG)
|
| 205 |
-
|
|
|
|
| 12 |
import httpx
|
| 13 |
from fastapi import HTTPException, Request, status
|
| 14 |
|
| 15 |
+
from agent.core.hf_tokens import bearer_token_from_header
|
| 16 |
+
|
| 17 |
from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
|
|
|
| 159 |
return DEV_USER
|
| 160 |
|
| 161 |
# Try Authorization header
|
| 162 |
+
token = bearer_token_from_header(request.headers.get("Authorization", ""))
|
| 163 |
+
if token:
|
|
|
|
| 164 |
user = await _extract_user_from_token(token)
|
| 165 |
if user:
|
| 166 |
return user
|
|
|
|
| 184 |
|
| 185 |
Mirrors the lookup order used by ``get_current_user``.
|
| 186 |
"""
|
| 187 |
+
token = bearer_token_from_header(request.headers.get("Authorization", ""))
|
| 188 |
+
if token:
|
| 189 |
+
return token
|
| 190 |
return request.cookies.get("hf_access_token")
|
| 191 |
|
| 192 |
|
|
|
|
| 203 |
if not token:
|
| 204 |
return False
|
| 205 |
return await check_org_membership(token, HF_EMPLOYEE_ORG)
|
|
|
backend/main.py
CHANGED
|
@@ -6,14 +6,17 @@ from contextlib import asynccontextmanager
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from fastapi import FastAPI
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
from fastapi.staticfiles import StaticFiles
|
| 12 |
from routes.agent import router as agent_router
|
| 13 |
from routes.auth import router as auth_router
|
| 14 |
-
|
| 15 |
-
# Load .env from project root (parent directory)
|
| 16 |
-
load_dotenv(Path(__file__).parent.parent / ".env")
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(
|
|
@@ -27,6 +30,7 @@ logger = logging.getLogger(__name__)
|
|
| 27 |
async def lifespan(app: FastAPI):
|
| 28 |
"""Application lifespan handler."""
|
| 29 |
logger.info("Starting HF Agent backend...")
|
|
|
|
| 30 |
# Start in-process hourly KPI rollup. Replaces an external cron so the
|
| 31 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 32 |
try:
|
|
@@ -34,7 +38,6 @@ async def lifespan(app: FastAPI):
|
|
| 34 |
kpis_scheduler.start()
|
| 35 |
except Exception as e:
|
| 36 |
logger.warning("KPI scheduler failed to start: %s", e)
|
| 37 |
-
|
| 38 |
yield
|
| 39 |
|
| 40 |
logger.info("Shutting down HF Agent backend...")
|
|
@@ -47,7 +50,6 @@ async def lifespan(app: FastAPI):
|
|
| 47 |
# Final-flush: save every still-active session so we don't lose traces on
|
| 48 |
# server restart. Uploads are detached subprocesses — this is fast.
|
| 49 |
try:
|
| 50 |
-
from session_manager import session_manager
|
| 51 |
for sid, agent_session in list(session_manager.sessions.items()):
|
| 52 |
sess = agent_session.session
|
| 53 |
if sess.config.save_sessions:
|
|
@@ -58,6 +60,7 @@ async def lifespan(app: FastAPI):
|
|
| 58 |
logger.warning("Failed to flush session %s: %s", sid, e)
|
| 59 |
except Exception as e:
|
| 60 |
logger.warning("Lifespan final-flush skipped: %s", e)
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
app = FastAPI(
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
+
|
| 10 |
+
# Load .env before importing routes/session_manager so persistence and quota
|
| 11 |
+
# modules see local Mongo settings during startup.
|
| 12 |
+
load_dotenv(Path(__file__).parent.parent / ".env")
|
| 13 |
+
|
| 14 |
from fastapi import FastAPI
|
| 15 |
from fastapi.middleware.cors import CORSMiddleware
|
| 16 |
from fastapi.staticfiles import StaticFiles
|
| 17 |
from routes.agent import router as agent_router
|
| 18 |
from routes.auth import router as auth_router
|
| 19 |
+
from session_manager import session_manager
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Configure logging
|
| 22 |
logging.basicConfig(
|
|
|
|
| 30 |
async def lifespan(app: FastAPI):
|
| 31 |
"""Application lifespan handler."""
|
| 32 |
logger.info("Starting HF Agent backend...")
|
| 33 |
+
await session_manager.start()
|
| 34 |
# Start in-process hourly KPI rollup. Replaces an external cron so the
|
| 35 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 36 |
try:
|
|
|
|
| 38 |
kpis_scheduler.start()
|
| 39 |
except Exception as e:
|
| 40 |
logger.warning("KPI scheduler failed to start: %s", e)
|
|
|
|
| 41 |
yield
|
| 42 |
|
| 43 |
logger.info("Shutting down HF Agent backend...")
|
|
|
|
| 50 |
# Final-flush: save every still-active session so we don't lose traces on
|
| 51 |
# server restart. Uploads are detached subprocesses — this is fast.
|
| 52 |
try:
|
|
|
|
| 53 |
for sid, agent_session in list(session_manager.sessions.items()):
|
| 54 |
sess = agent_session.session
|
| 55 |
if sess.config.save_sessions:
|
|
|
|
| 60 |
logger.warning("Failed to flush session %s: %s", sid, e)
|
| 61 |
except Exception as e:
|
| 62 |
logger.warning("Lifespan final-flush skipped: %s", e)
|
| 63 |
+
await session_manager.close()
|
| 64 |
|
| 65 |
|
| 66 |
app = FastAPI(
|
backend/models.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
-
from pydantic import BaseModel
|
| 7 |
|
| 8 |
|
| 9 |
class OpType(str, Enum):
|
|
@@ -87,6 +87,14 @@ class SessionInfo(BaseModel):
|
|
| 87 |
user_id: str = "dev"
|
| 88 |
pending_approval: list[PendingApprovalTool] | None = None
|
| 89 |
model: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
class HealthResponse(BaseModel):
|
|
|
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
|
| 8 |
|
| 9 |
class OpType(str, Enum):
|
|
|
|
| 87 |
user_id: str = "dev"
|
| 88 |
pending_approval: list[PendingApprovalTool] | None = None
|
| 89 |
model: str | None = None
|
| 90 |
+
title: str | None = None
|
| 91 |
+
notification_destinations: list[str] = Field(default_factory=list)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class SessionNotificationsRequest(BaseModel):
|
| 95 |
+
"""Replace the session's auto-notification destinations."""
|
| 96 |
+
|
| 97 |
+
destinations: list[str]
|
| 98 |
|
| 99 |
|
| 100 |
class HealthResponse(BaseModel):
|
backend/routes/agent.py
CHANGED
|
@@ -24,6 +24,7 @@ from models import (
|
|
| 24 |
HealthResponse,
|
| 25 |
LLMHealthResponse,
|
| 26 |
SessionInfo,
|
|
|
|
| 27 |
SessionResponse,
|
| 28 |
SubmitRequest,
|
| 29 |
TruncateRequest,
|
|
@@ -33,6 +34,7 @@ from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, se
|
|
| 33 |
import user_quotas
|
| 34 |
|
| 35 |
from agent.core.hf_access import get_jobs_access
|
|
|
|
| 36 |
from agent.core.llm_params import _resolve_llm_params
|
| 37 |
|
| 38 |
logger = logging.getLogger(__name__)
|
|
@@ -118,9 +120,9 @@ async def _enforce_claude_quota(
|
|
| 118 |
if not _is_anthropic_model(model_name):
|
| 119 |
return
|
| 120 |
user_id = user["user_id"]
|
| 121 |
-
used = await user_quotas.get_claude_used_today(user_id)
|
| 122 |
cap = user_quotas.daily_cap_for(user.get("plan"))
|
| 123 |
-
|
|
|
|
| 124 |
raise HTTPException(
|
| 125 |
status_code=429,
|
| 126 |
detail={
|
|
@@ -133,8 +135,8 @@ async def _enforce_claude_quota(
|
|
| 133 |
),
|
| 134 |
},
|
| 135 |
)
|
| 136 |
-
await user_quotas.increment_claude(user_id)
|
| 137 |
agent_session.claude_counted = True
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
async def _enforce_jobs_access_for_approvals(
|
|
@@ -193,6 +195,9 @@ async def _enforce_jobs_access_for_approvals(
|
|
| 193 |
"The selected jobs namespace is not one of your eligible paid organizations. "
|
| 194 |
f"Allowed namespaces: {', '.join(access.paid_org_names)}"
|
| 195 |
),
|
|
|
|
|
|
|
|
|
|
| 196 |
},
|
| 197 |
)
|
| 198 |
missing_namespace = [
|
|
@@ -236,13 +241,23 @@ async def _enforce_jobs_access_for_approvals(
|
|
| 236 |
)
|
| 237 |
|
| 238 |
|
| 239 |
-
def _check_session_access(
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 244 |
-
if
|
| 245 |
raise HTTPException(status_code=403, detail="Access denied to this session")
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
@router.get("/health", response_model=HealthResponse)
|
|
@@ -332,10 +347,8 @@ async def generate_title(
|
|
| 332 |
reasoning model — reasoning_effort=low keeps the reasoning budget small
|
| 333 |
so the 60-token output budget isn't consumed before the title is written.
|
| 334 |
"""
|
| 335 |
-
api_key = (
|
| 336 |
-
|
| 337 |
-
or (user.get("hf_token") if isinstance(user, dict) else None)
|
| 338 |
-
or os.environ.get("HF_TOKEN")
|
| 339 |
)
|
| 340 |
try:
|
| 341 |
response = await acompletion(
|
|
@@ -366,11 +379,21 @@ async def generate_title(
|
|
| 366 |
title = title.translate(_TITLE_STRIP_CHARS).strip()
|
| 367 |
if len(title) > 50:
|
| 368 |
title = title[:50].rstrip() + "…"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
return {"title": title}
|
| 370 |
except Exception as e:
|
| 371 |
logger.warning(f"Title generation failed: {e}")
|
| 372 |
fallback = request.text.strip()
|
| 373 |
title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
return {"title": title}
|
| 375 |
|
| 376 |
|
|
@@ -391,14 +414,7 @@ async def create_session(
|
|
| 391 |
Returns 503 if the server or user has reached the session limit.
|
| 392 |
"""
|
| 393 |
# Extract the user's HF token (Bearer header, HttpOnly cookie, or env var)
|
| 394 |
-
hf_token =
|
| 395 |
-
auth_header = request.headers.get("Authorization", "")
|
| 396 |
-
if auth_header.startswith("Bearer "):
|
| 397 |
-
hf_token = auth_header[7:]
|
| 398 |
-
if not hf_token:
|
| 399 |
-
hf_token = request.cookies.get("hf_access_token")
|
| 400 |
-
if not hf_token:
|
| 401 |
-
hf_token = os.environ.get("HF_TOKEN")
|
| 402 |
|
| 403 |
# Optional model override. Empty body falls back to the config default.
|
| 404 |
model: str | None = None
|
|
@@ -444,14 +460,7 @@ async def restore_session_summary(
|
|
| 444 |
if not isinstance(messages, list) or not messages:
|
| 445 |
raise HTTPException(status_code=400, detail="Missing 'messages' array")
|
| 446 |
|
| 447 |
-
hf_token =
|
| 448 |
-
auth_header = request.headers.get("Authorization", "")
|
| 449 |
-
if auth_header.startswith("Bearer "):
|
| 450 |
-
hf_token = auth_header[7:]
|
| 451 |
-
if not hf_token:
|
| 452 |
-
hf_token = request.cookies.get("hf_access_token")
|
| 453 |
-
if not hf_token:
|
| 454 |
-
hf_token = os.environ.get("HF_TOKEN")
|
| 455 |
|
| 456 |
model = body.get("model")
|
| 457 |
valid_ids = {m["id"] for m in AVAILABLE_MODELS}
|
|
@@ -488,7 +497,7 @@ async def get_session(
|
|
| 488 |
session_id: str, user: dict = Depends(get_current_user)
|
| 489 |
) -> SessionInfo:
|
| 490 |
"""Get session information. Only accessible by the session owner."""
|
| 491 |
-
_check_session_access(session_id, user)
|
| 492 |
info = session_manager.get_session_info(session_id)
|
| 493 |
return SessionInfo(**info)
|
| 494 |
|
|
@@ -509,7 +518,7 @@ async def set_session_model(
|
|
| 509 |
Switching TO an Anthropic model requires HF org membership (PR #63);
|
| 510 |
free-model switches are unrestricted.
|
| 511 |
"""
|
| 512 |
-
_check_session_access(session_id, user)
|
| 513 |
model_id = body.get("model")
|
| 514 |
if not model_id:
|
| 515 |
raise HTTPException(status_code=400, detail="Missing 'model' field")
|
|
@@ -517,10 +526,9 @@ async def set_session_model(
|
|
| 517 |
if model_id not in valid_ids:
|
| 518 |
raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
|
| 519 |
await _require_hf_for_anthropic(request, model_id)
|
| 520 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 521 |
if not agent_session:
|
| 522 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 523 |
-
|
| 524 |
logger.info(
|
| 525 |
f"Session {session_id} model → {model_id} "
|
| 526 |
f"(by {user.get('username', 'unknown')})"
|
|
@@ -528,6 +536,27 @@ async def set_session_model(
|
|
| 528 |
return {"session_id": session_id, "model": model_id}
|
| 529 |
|
| 530 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
@router.get("/user/quota")
|
| 532 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 533 |
"""Return the user's plan tier and today's Claude-session quota state."""
|
|
@@ -545,14 +574,7 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
|
| 545 |
@router.get("/user/jobs-access")
|
| 546 |
async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
|
| 547 |
"""Return whether the current token can run HF Jobs and under which namespaces."""
|
| 548 |
-
token =
|
| 549 |
-
auth_header = request.headers.get("Authorization", "")
|
| 550 |
-
if auth_header.startswith("Bearer "):
|
| 551 |
-
token = auth_header[7:]
|
| 552 |
-
if not token:
|
| 553 |
-
token = request.cookies.get("hf_access_token")
|
| 554 |
-
if not token:
|
| 555 |
-
token = os.environ.get("HF_TOKEN")
|
| 556 |
|
| 557 |
access = await get_jobs_access(token or "")
|
| 558 |
return {
|
|
@@ -566,7 +588,7 @@ async def get_jobs_access_info(request: Request, user: dict = Depends(get_curren
|
|
| 566 |
@router.get("/sessions", response_model=list[SessionInfo])
|
| 567 |
async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
|
| 568 |
"""List sessions belonging to the authenticated user."""
|
| 569 |
-
sessions = session_manager.list_sessions(user_id=user["user_id"])
|
| 570 |
return [SessionInfo(**s) for s in sessions]
|
| 571 |
|
| 572 |
|
|
@@ -575,7 +597,7 @@ async def delete_session(
|
|
| 575 |
session_id: str, user: dict = Depends(get_current_user)
|
| 576 |
) -> dict:
|
| 577 |
"""Delete a session. Only accessible by the session owner."""
|
| 578 |
-
_check_session_access(session_id, user)
|
| 579 |
success = await session_manager.delete_session(session_id)
|
| 580 |
if not success:
|
| 581 |
raise HTTPException(status_code=404, detail="Session not found")
|
|
@@ -587,10 +609,8 @@ async def submit_input(
|
|
| 587 |
request: SubmitRequest, user: dict = Depends(get_current_user)
|
| 588 |
) -> dict:
|
| 589 |
"""Submit user input to a session. Only accessible by the session owner."""
|
| 590 |
-
_check_session_access(request.session_id, user)
|
| 591 |
-
|
| 592 |
-
if agent_session is not None:
|
| 593 |
-
await _enforce_claude_quota(user, agent_session)
|
| 594 |
success = await session_manager.submit_user_input(request.session_id, request.text)
|
| 595 |
if not success:
|
| 596 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
@@ -602,10 +622,7 @@ async def submit_approval(
|
|
| 602 |
request: ApprovalRequest, user: dict = Depends(get_current_user)
|
| 603 |
) -> dict:
|
| 604 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 605 |
-
_check_session_access(request.session_id, user)
|
| 606 |
-
agent_session = session_manager.sessions.get(request.session_id)
|
| 607 |
-
if agent_session is None:
|
| 608 |
-
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 609 |
approvals = [
|
| 610 |
{
|
| 611 |
"tool_call_id": a.tool_call_id,
|
|
@@ -630,9 +647,7 @@ async def chat_sse(
|
|
| 630 |
user: dict = Depends(get_current_user),
|
| 631 |
) -> StreamingResponse:
|
| 632 |
"""SSE endpoint: submit input or approval, then stream events until turn ends."""
|
| 633 |
-
_check_session_access(session_id, user)
|
| 634 |
-
|
| 635 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 636 |
if not agent_session or not agent_session.is_active:
|
| 637 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 638 |
|
|
@@ -698,10 +713,7 @@ async def record_pro_click(
|
|
| 698 |
user: dict = Depends(get_current_user),
|
| 699 |
) -> dict:
|
| 700 |
"""Record a click on a Pro upgrade CTA shown from inside a session."""
|
| 701 |
-
_check_session_access(session_id, user)
|
| 702 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 703 |
-
if not agent_session:
|
| 704 |
-
raise HTTPException(status_code=404, detail="Session not found")
|
| 705 |
|
| 706 |
from agent.core import telemetry
|
| 707 |
await telemetry.record_pro_cta_click(
|
|
@@ -723,12 +735,53 @@ _TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted"
|
|
| 723 |
_SSE_KEEPALIVE_SECONDS = 15
|
| 724 |
|
| 725 |
|
| 726 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
"""Build a StreamingResponse that drains *event_queue* as SSE,
|
| 728 |
sending keepalive comments every 15 s to prevent proxy timeouts."""
|
| 729 |
|
| 730 |
async def event_generator():
|
| 731 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
while True:
|
| 733 |
try:
|
| 734 |
msg = await asyncio.wait_for(
|
|
@@ -739,7 +792,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse:
|
|
| 739 |
yield ": keepalive\n\n"
|
| 740 |
continue
|
| 741 |
event_type = msg.get("event_type", "")
|
| 742 |
-
yield
|
| 743 |
if event_type in _TERMINAL_EVENTS:
|
| 744 |
break
|
| 745 |
finally:
|
|
@@ -759,6 +812,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse:
|
|
| 759 |
@router.get("/events/{session_id}")
|
| 760 |
async def subscribe_events(
|
| 761 |
session_id: str,
|
|
|
|
| 762 |
user: dict = Depends(get_current_user),
|
| 763 |
) -> StreamingResponse:
|
| 764 |
"""Subscribe to events for a running session without submitting new input.
|
|
@@ -766,15 +820,21 @@ async def subscribe_events(
|
|
| 766 |
Used by the frontend to re-attach after a connection drop (e.g. screen
|
| 767 |
sleep). Returns 404 if the session isn't active or isn't processing.
|
| 768 |
"""
|
| 769 |
-
_check_session_access(session_id, user)
|
| 770 |
-
|
| 771 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 772 |
if not agent_session or not agent_session.is_active:
|
| 773 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 774 |
|
|
|
|
|
|
|
| 775 |
broadcaster = agent_session.broadcaster
|
| 776 |
sub_id, event_queue = broadcaster.subscribe()
|
| 777 |
-
return _sse_response(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 778 |
|
| 779 |
|
| 780 |
@router.post("/interrupt/{session_id}")
|
|
@@ -782,7 +842,7 @@ async def interrupt_session(
|
|
| 782 |
session_id: str, user: dict = Depends(get_current_user)
|
| 783 |
) -> dict:
|
| 784 |
"""Interrupt the current operation in a session."""
|
| 785 |
-
_check_session_access(session_id, user)
|
| 786 |
success = await session_manager.interrupt(session_id)
|
| 787 |
if not success:
|
| 788 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
@@ -794,17 +854,16 @@ async def get_session_messages(
|
|
| 794 |
session_id: str, user: dict = Depends(get_current_user)
|
| 795 |
) -> list[dict]:
|
| 796 |
"""Return the session's message history from memory."""
|
| 797 |
-
_check_session_access(session_id, user)
|
| 798 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 799 |
if not agent_session or not agent_session.is_active:
|
| 800 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 801 |
-
return [msg.model_dump() for msg in agent_session.session.context_manager.items]
|
| 802 |
|
| 803 |
|
| 804 |
@router.post("/undo/{session_id}")
|
| 805 |
async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
|
| 806 |
"""Undo the last turn in a session."""
|
| 807 |
-
_check_session_access(session_id, user)
|
| 808 |
success = await session_manager.undo(session_id)
|
| 809 |
if not success:
|
| 810 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
@@ -816,7 +875,7 @@ async def truncate_session(
|
|
| 816 |
session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user)
|
| 817 |
) -> dict:
|
| 818 |
"""Truncate conversation to before a specific user message."""
|
| 819 |
-
_check_session_access(session_id, user)
|
| 820 |
success = await session_manager.truncate(session_id, body.user_message_index)
|
| 821 |
if not success:
|
| 822 |
raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
|
|
@@ -828,7 +887,7 @@ async def compact_session(
|
|
| 828 |
session_id: str, user: dict = Depends(get_current_user)
|
| 829 |
) -> dict:
|
| 830 |
"""Compact the context in a session."""
|
| 831 |
-
_check_session_access(session_id, user)
|
| 832 |
success = await session_manager.compact(session_id)
|
| 833 |
if not success:
|
| 834 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
@@ -840,13 +899,12 @@ async def shutdown_session(
|
|
| 840 |
session_id: str, user: dict = Depends(get_current_user)
|
| 841 |
) -> dict:
|
| 842 |
"""Shutdown a session."""
|
| 843 |
-
_check_session_access(session_id, user)
|
| 844 |
success = await session_manager.shutdown_session(session_id)
|
| 845 |
if not success:
|
| 846 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 847 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 848 |
|
| 849 |
-
|
| 850 |
@router.post("/feedback/{session_id}")
|
| 851 |
async def submit_feedback(
|
| 852 |
session_id: str,
|
|
@@ -859,10 +917,7 @@ async def submit_feedback(
|
|
| 859 |
turn_index?: int, comment?: str, message_id?: str}
|
| 860 |
Appended as a `feedback` event and saved with the session trajectory.
|
| 861 |
"""
|
| 862 |
-
_check_session_access(session_id, user)
|
| 863 |
-
agent_session = session_manager.sessions.get(session_id)
|
| 864 |
-
if not agent_session:
|
| 865 |
-
raise HTTPException(status_code=404, detail="Session not found")
|
| 866 |
|
| 867 |
rating = body.get("rating")
|
| 868 |
if rating not in {"up", "down", "outcome_success", "outcome_fail"}:
|
|
|
|
| 24 |
HealthResponse,
|
| 25 |
LLMHealthResponse,
|
| 26 |
SessionInfo,
|
| 27 |
+
SessionNotificationsRequest,
|
| 28 |
SessionResponse,
|
| 29 |
SubmitRequest,
|
| 30 |
TruncateRequest,
|
|
|
|
| 34 |
import user_quotas
|
| 35 |
|
| 36 |
from agent.core.hf_access import get_jobs_access
|
| 37 |
+
from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token
|
| 38 |
from agent.core.llm_params import _resolve_llm_params
|
| 39 |
|
| 40 |
logger = logging.getLogger(__name__)
|
|
|
|
| 120 |
if not _is_anthropic_model(model_name):
|
| 121 |
return
|
| 122 |
user_id = user["user_id"]
|
|
|
|
| 123 |
cap = user_quotas.daily_cap_for(user.get("plan"))
|
| 124 |
+
new_count = await user_quotas.try_increment_claude(user_id, cap)
|
| 125 |
+
if new_count is None:
|
| 126 |
raise HTTPException(
|
| 127 |
status_code=429,
|
| 128 |
detail={
|
|
|
|
| 135 |
),
|
| 136 |
},
|
| 137 |
)
|
|
|
|
| 138 |
agent_session.claude_counted = True
|
| 139 |
+
await session_manager.persist_session_snapshot(agent_session)
|
| 140 |
|
| 141 |
|
| 142 |
async def _enforce_jobs_access_for_approvals(
|
|
|
|
| 195 |
"The selected jobs namespace is not one of your eligible paid organizations. "
|
| 196 |
f"Allowed namespaces: {', '.join(access.paid_org_names)}"
|
| 197 |
),
|
| 198 |
+
"plan": user.get("plan", "free"),
|
| 199 |
+
"tool_call_ids": invalid_namespace,
|
| 200 |
+
"eligible_namespaces": access.paid_org_names,
|
| 201 |
},
|
| 202 |
)
|
| 203 |
missing_namespace = [
|
|
|
|
| 241 |
)
|
| 242 |
|
| 243 |
|
| 244 |
+
async def _check_session_access(
    session_id: str,
    user: dict[str, Any],
    request: Request | None = None,
) -> AgentSession:
    """Load the session for *user* and confirm they may use it.

    The HF token passed to the loader comes from *request* when one is given,
    otherwise from ``user["hf_token"]`` (which may be absent).

    Raises:
        HTTPException: 404 when ``ensure_session_loaded`` yields no session;
            403 when the caller is not ``"dev"`` and the session belongs to
            neither the caller nor ``"dev"``.
    """
    if request is None:
        token = user.get("hf_token")
    else:
        token = resolve_hf_request_token(request)

    caller_id = user["user_id"]
    loaded = await session_manager.ensure_session_loaded(
        session_id,
        caller_id,
        hf_token=token,
    )
    if not loaded:
        raise HTTPException(status_code=404, detail="Session not found")

    # The "dev" caller bypasses ownership; "dev"-owned sessions are open to all.
    may_access = caller_id == "dev" or loaded.user_id in {caller_id, "dev"}
    if not may_access:
        raise HTTPException(status_code=403, detail="Access denied to this session")
    return loaded
|
| 261 |
|
| 262 |
|
| 263 |
@router.get("/health", response_model=HealthResponse)
|
|
|
|
| 347 |
reasoning model — reasoning_effort=low keeps the reasoning budget small
|
| 348 |
so the 60-token output budget isn't consumed before the title is written.
|
| 349 |
"""
|
| 350 |
+
api_key = resolve_hf_router_token(
|
| 351 |
+
user.get("hf_token") if isinstance(user, dict) else None
|
|
|
|
|
|
|
| 352 |
)
|
| 353 |
try:
|
| 354 |
response = await acompletion(
|
|
|
|
| 379 |
title = title.translate(_TITLE_STRIP_CHARS).strip()
|
| 380 |
if len(title) > 50:
|
| 381 |
title = title[:50].rstrip() + "…"
|
| 382 |
+
try:
|
| 383 |
+
await _check_session_access(request.session_id, user)
|
| 384 |
+
await session_manager.update_session_title(request.session_id, title)
|
| 385 |
+
except Exception:
|
| 386 |
+
logger.debug("Skipping title persistence for missing session %s", request.session_id)
|
| 387 |
return {"title": title}
|
| 388 |
except Exception as e:
|
| 389 |
logger.warning(f"Title generation failed: {e}")
|
| 390 |
fallback = request.text.strip()
|
| 391 |
title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
|
| 392 |
+
try:
|
| 393 |
+
await _check_session_access(request.session_id, user)
|
| 394 |
+
await session_manager.update_session_title(request.session_id, title)
|
| 395 |
+
except Exception:
|
| 396 |
+
logger.debug("Skipping fallback title persistence for missing session %s", request.session_id)
|
| 397 |
return {"title": title}
|
| 398 |
|
| 399 |
|
|
|
|
| 414 |
Returns 503 if the server or user has reached the session limit.
|
| 415 |
"""
|
| 416 |
# Extract the user's HF token (Bearer header, HttpOnly cookie, or env var)
|
| 417 |
+
hf_token = resolve_hf_request_token(request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
# Optional model override. Empty body falls back to the config default.
|
| 420 |
model: str | None = None
|
|
|
|
| 460 |
if not isinstance(messages, list) or not messages:
|
| 461 |
raise HTTPException(status_code=400, detail="Missing 'messages' array")
|
| 462 |
|
| 463 |
+
hf_token = resolve_hf_request_token(request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
|
| 465 |
model = body.get("model")
|
| 466 |
valid_ids = {m["id"] for m in AVAILABLE_MODELS}
|
|
|
|
| 497 |
session_id: str, user: dict = Depends(get_current_user)
|
| 498 |
) -> SessionInfo:
|
| 499 |
"""Get session information. Only accessible by the session owner."""
|
| 500 |
+
await _check_session_access(session_id, user)
|
| 501 |
info = session_manager.get_session_info(session_id)
|
| 502 |
return SessionInfo(**info)
|
| 503 |
|
|
|
|
| 518 |
Switching TO an Anthropic model requires HF org membership (PR #63);
|
| 519 |
free-model switches are unrestricted.
|
| 520 |
"""
|
| 521 |
+
agent_session = await _check_session_access(session_id, user, request)
|
| 522 |
model_id = body.get("model")
|
| 523 |
if not model_id:
|
| 524 |
raise HTTPException(status_code=400, detail="Missing 'model' field")
|
|
|
|
| 526 |
if model_id not in valid_ids:
|
| 527 |
raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
|
| 528 |
await _require_hf_for_anthropic(request, model_id)
|
|
|
|
| 529 |
if not agent_session:
|
| 530 |
raise HTTPException(status_code=404, detail="Session not found")
|
| 531 |
+
await session_manager.update_session_model(session_id, model_id)
|
| 532 |
logger.info(
|
| 533 |
f"Session {session_id} model → {model_id} "
|
| 534 |
f"(by {user.get('username', 'unknown')})"
|
|
|
|
| 536 |
return {"session_id": session_id, "model": model_id}
|
| 537 |
|
| 538 |
|
| 539 |
+
@router.post("/session/{session_id}/notifications")
async def set_session_notifications(
    session_id: str,
    body: SessionNotificationsRequest,
    user: dict = Depends(get_current_user),
) -> dict:
    """Replace the session's auto-notification destinations.

    Access is checked first via ``_check_session_access``. The new
    destinations are then applied through the session manager and the
    session snapshot is persisted.

    Returns:
        A dict with ``session_id`` and the ``notification_destinations``
        value returned by the session manager.

    Raises:
        HTTPException: 400 when the session manager rejects the destinations
            with ``ValueError``; access errors propagate from
            ``_check_session_access``.
    """
    agent_session = await _check_session_access(session_id, user)
    try:
        destinations = session_manager.set_notification_destinations(
            session_id, body.destinations
        )
    except ValueError as e:
        # Chain the cause so the original validation error stays in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    await session_manager.persist_session_snapshot(agent_session)
    return {
        "session_id": session_id,
        "notification_destinations": destinations,
    }
|
| 558 |
+
|
| 559 |
+
|
| 560 |
@router.get("/user/quota")
|
| 561 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 562 |
"""Return the user's plan tier and today's Claude-session quota state."""
|
|
|
|
| 574 |
@router.get("/user/jobs-access")
|
| 575 |
async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
|
| 576 |
"""Return whether the current token can run HF Jobs and under which namespaces."""
|
| 577 |
+
token = resolve_hf_request_token(request)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
|
| 579 |
access = await get_jobs_access(token or "")
|
| 580 |
return {
|
|
|
|
| 588 |
@router.get("/sessions", response_model=list[SessionInfo])
|
| 589 |
async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
|
| 590 |
"""List sessions belonging to the authenticated user."""
|
| 591 |
+
sessions = await session_manager.list_sessions(user_id=user["user_id"])
|
| 592 |
return [SessionInfo(**s) for s in sessions]
|
| 593 |
|
| 594 |
|
|
|
|
| 597 |
session_id: str, user: dict = Depends(get_current_user)
|
| 598 |
) -> dict:
|
| 599 |
"""Delete a session. Only accessible by the session owner."""
|
| 600 |
+
await _check_session_access(session_id, user)
|
| 601 |
success = await session_manager.delete_session(session_id)
|
| 602 |
if not success:
|
| 603 |
raise HTTPException(status_code=404, detail="Session not found")
|
|
|
|
| 609 |
request: SubmitRequest, user: dict = Depends(get_current_user)
|
| 610 |
) -> dict:
|
| 611 |
"""Submit user input to a session. Only accessible by the session owner."""
|
| 612 |
+
agent_session = await _check_session_access(request.session_id, user)
|
| 613 |
+
await _enforce_claude_quota(user, agent_session)
|
|
|
|
|
|
|
| 614 |
success = await session_manager.submit_user_input(request.session_id, request.text)
|
| 615 |
if not success:
|
| 616 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
|
|
| 622 |
request: ApprovalRequest, user: dict = Depends(get_current_user)
|
| 623 |
) -> dict:
|
| 624 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 625 |
+
agent_session = await _check_session_access(request.session_id, user)
|
|
|
|
|
|
|
|
|
|
| 626 |
approvals = [
|
| 627 |
{
|
| 628 |
"tool_call_id": a.tool_call_id,
|
|
|
|
| 647 |
user: dict = Depends(get_current_user),
|
| 648 |
) -> StreamingResponse:
|
| 649 |
"""SSE endpoint: submit input or approval, then stream events until turn ends."""
|
| 650 |
+
agent_session = await _check_session_access(session_id, user, request)
|
|
|
|
|
|
|
| 651 |
if not agent_session or not agent_session.is_active:
|
| 652 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 653 |
|
|
|
|
| 713 |
user: dict = Depends(get_current_user),
|
| 714 |
) -> dict:
|
| 715 |
"""Record a click on a Pro upgrade CTA shown from inside a session."""
|
| 716 |
+
agent_session = await _check_session_access(session_id, user)
|
|
|
|
|
|
|
|
|
|
| 717 |
|
| 718 |
from agent.core import telemetry
|
| 719 |
await telemetry.record_pro_cta_click(
|
|
|
|
| 735 |
_SSE_KEEPALIVE_SECONDS = 15
|
| 736 |
|
| 737 |
|
| 738 |
+
def _last_event_seq(request: Request) -> int:
|
| 739 |
+
raw = request.headers.get("last-event-id") or request.query_params.get("after") or "0"
|
| 740 |
+
try:
|
| 741 |
+
return max(0, int(raw))
|
| 742 |
+
except (TypeError, ValueError):
|
| 743 |
+
return 0
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def _format_sse(msg: dict[str, Any]) -> str:
|
| 747 |
+
seq = msg.get("seq")
|
| 748 |
+
body = {"event_type": msg.get("event_type"), "data": msg.get("data") or {}}
|
| 749 |
+
if seq is not None:
|
| 750 |
+
body["seq"] = seq
|
| 751 |
+
return f"id: {seq}\ndata: {json.dumps(body)}\n\n"
|
| 752 |
+
return f"data: {json.dumps(body)}\n\n"
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]:
|
| 756 |
+
return {
|
| 757 |
+
"event_type": doc.get("event_type"),
|
| 758 |
+
"data": doc.get("data") or {},
|
| 759 |
+
"seq": doc.get("seq"),
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
def _sse_response(
|
| 764 |
+
broadcaster,
|
| 765 |
+
event_queue,
|
| 766 |
+
sub_id,
|
| 767 |
+
*,
|
| 768 |
+
replay_events: list[dict[str, Any]] | None = None,
|
| 769 |
+
after_seq: int = 0,
|
| 770 |
+
) -> StreamingResponse:
|
| 771 |
"""Build a StreamingResponse that drains *event_queue* as SSE,
|
| 772 |
sending keepalive comments every 15 s to prevent proxy timeouts."""
|
| 773 |
|
| 774 |
async def event_generator():
|
| 775 |
try:
|
| 776 |
+
for doc in replay_events or []:
|
| 777 |
+
msg = _event_doc_to_msg(doc)
|
| 778 |
+
seq = msg.get("seq")
|
| 779 |
+
if isinstance(seq, int) and seq <= after_seq:
|
| 780 |
+
continue
|
| 781 |
+
yield _format_sse(msg)
|
| 782 |
+
if msg.get("event_type", "") in _TERMINAL_EVENTS:
|
| 783 |
+
return
|
| 784 |
+
|
| 785 |
while True:
|
| 786 |
try:
|
| 787 |
msg = await asyncio.wait_for(
|
|
|
|
| 792 |
yield ": keepalive\n\n"
|
| 793 |
continue
|
| 794 |
event_type = msg.get("event_type", "")
|
| 795 |
+
yield _format_sse(msg)
|
| 796 |
if event_type in _TERMINAL_EVENTS:
|
| 797 |
break
|
| 798 |
finally:
|
|
|
|
| 812 |
@router.get("/events/{session_id}")
|
| 813 |
async def subscribe_events(
|
| 814 |
session_id: str,
|
| 815 |
+
request: Request,
|
| 816 |
user: dict = Depends(get_current_user),
|
| 817 |
) -> StreamingResponse:
|
| 818 |
"""Subscribe to events for a running session without submitting new input.
|
|
|
|
| 820 |
Used by the frontend to re-attach after a connection drop (e.g. screen
|
| 821 |
sleep). Returns 404 if the session isn't active or isn't processing.
|
| 822 |
"""
|
| 823 |
+
agent_session = await _check_session_access(session_id, user, request)
|
|
|
|
|
|
|
| 824 |
if not agent_session or not agent_session.is_active:
|
| 825 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 826 |
|
| 827 |
+
after_seq = _last_event_seq(request)
|
| 828 |
+
replay_events = await session_manager._store().load_events_after(session_id, after_seq)
|
| 829 |
broadcaster = agent_session.broadcaster
|
| 830 |
sub_id, event_queue = broadcaster.subscribe()
|
| 831 |
+
return _sse_response(
|
| 832 |
+
broadcaster,
|
| 833 |
+
event_queue,
|
| 834 |
+
sub_id,
|
| 835 |
+
replay_events=replay_events,
|
| 836 |
+
after_seq=after_seq,
|
| 837 |
+
)
|
| 838 |
|
| 839 |
|
| 840 |
@router.post("/interrupt/{session_id}")
|
|
|
|
| 842 |
session_id: str, user: dict = Depends(get_current_user)
|
| 843 |
) -> dict:
|
| 844 |
"""Interrupt the current operation in a session."""
|
| 845 |
+
await _check_session_access(session_id, user)
|
| 846 |
success = await session_manager.interrupt(session_id)
|
| 847 |
if not success:
|
| 848 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
|
|
| 854 |
session_id: str, user: dict = Depends(get_current_user)
|
| 855 |
) -> list[dict]:
|
| 856 |
"""Return the session's message history from memory."""
|
| 857 |
+
agent_session = await _check_session_access(session_id, user)
|
|
|
|
| 858 |
if not agent_session or not agent_session.is_active:
|
| 859 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 860 |
+
return [msg.model_dump(mode="json") for msg in agent_session.session.context_manager.items]
|
| 861 |
|
| 862 |
|
| 863 |
@router.post("/undo/{session_id}")
|
| 864 |
async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
|
| 865 |
"""Undo the last turn in a session."""
|
| 866 |
+
await _check_session_access(session_id, user)
|
| 867 |
success = await session_manager.undo(session_id)
|
| 868 |
if not success:
|
| 869 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
|
|
| 875 |
session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user)
|
| 876 |
) -> dict:
|
| 877 |
"""Truncate conversation to before a specific user message."""
|
| 878 |
+
await _check_session_access(session_id, user)
|
| 879 |
success = await session_manager.truncate(session_id, body.user_message_index)
|
| 880 |
if not success:
|
| 881 |
raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
|
|
|
|
| 887 |
session_id: str, user: dict = Depends(get_current_user)
|
| 888 |
) -> dict:
|
| 889 |
"""Compact the context in a session."""
|
| 890 |
+
await _check_session_access(session_id, user)
|
| 891 |
success = await session_manager.compact(session_id)
|
| 892 |
if not success:
|
| 893 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
|
|
| 899 |
session_id: str, user: dict = Depends(get_current_user)
|
| 900 |
) -> dict:
|
| 901 |
"""Shutdown a session."""
|
| 902 |
+
await _check_session_access(session_id, user)
|
| 903 |
success = await session_manager.shutdown_session(session_id)
|
| 904 |
if not success:
|
| 905 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 906 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 907 |
|
|
|
|
| 908 |
@router.post("/feedback/{session_id}")
|
| 909 |
async def submit_feedback(
|
| 910 |
session_id: str,
|
|
|
|
| 917 |
turn_index?: int, comment?: str, message_id?: str}
|
| 918 |
Appended as a `feedback` event and saved with the session trajectory.
|
| 919 |
"""
|
| 920 |
+
agent_session = await _check_session_access(session_id, user)
|
|
|
|
|
|
|
|
|
|
| 921 |
|
| 922 |
rating = body.get("rating")
|
| 923 |
if rating not in {"up", "down", "outcome_success", "outcome_fail"}:
|
backend/session_manager.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""Session manager for handling multiple concurrent agent sessions."""
|
| 2 |
|
| 3 |
import asyncio
|
|
|
|
| 4 |
import logging
|
| 5 |
import uuid
|
| 6 |
from dataclasses import dataclass, field
|
|
@@ -10,7 +11,9 @@ from typing import Any, Optional
|
|
| 10 |
|
| 11 |
from agent.config import load_config
|
| 12 |
from agent.core.agent_loop import process_submission
|
|
|
|
| 13 |
from agent.core.session import Event, OpType, Session
|
|
|
|
| 14 |
from agent.core.tools import ToolRouter
|
| 15 |
|
| 16 |
# Get project root (parent of backend directory)
|
|
@@ -41,9 +44,8 @@ logger = logging.getLogger(__name__)
|
|
| 41 |
class EventBroadcaster:
|
| 42 |
"""Reads from the agent's event queue and fans out to SSE subscribers.
|
| 43 |
|
| 44 |
-
Events that arrive when no subscribers are listening are discarded
|
| 45 |
-
|
| 46 |
-
scenario that would need buffered replay.
|
| 47 |
"""
|
| 48 |
|
| 49 |
def __init__(self, event_queue: asyncio.Queue):
|
|
@@ -67,7 +69,7 @@ class EventBroadcaster:
|
|
| 67 |
while True:
|
| 68 |
try:
|
| 69 |
event: Event = await self._source.get()
|
| 70 |
-
msg = {"event_type": event.event_type, "data": event.data}
|
| 71 |
for q in self._subscribers.values():
|
| 72 |
await q.put(msg)
|
| 73 |
except asyncio.CancelledError:
|
|
@@ -91,6 +93,7 @@ class AgentSession:
|
|
| 91 |
is_active: bool = True
|
| 92 |
is_processing: bool = False # True while a submission is being executed
|
| 93 |
broadcaster: Any = None
|
|
|
|
| 94 |
# True once this session has been counted against the user's daily
|
| 95 |
# Claude quota. Guards double-counting when the user re-selects an
|
| 96 |
# Anthropic model mid-session.
|
|
@@ -119,8 +122,27 @@ class SessionManager:
|
|
| 119 |
|
| 120 |
def __init__(self, config_path: str | None = None) -> None:
|
| 121 |
self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
|
|
|
|
| 122 |
self.sessions: dict[str, AgentSession] = {}
|
| 123 |
self._lock = asyncio.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 126 |
"""Count active sessions owned by a specific user."""
|
|
@@ -130,6 +152,314 @@ class SessionManager:
|
|
| 130 |
if s.user_id == user_id and s.is_active
|
| 131 |
)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
async def create_session(
|
| 134 |
self,
|
| 135 |
user_id: str = "dev",
|
|
@@ -178,27 +508,14 @@ class SessionManager:
|
|
| 178 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 179 |
|
| 180 |
# Run blocking constructors in a thread to keep the event loop responsive.
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
# tab A picking GLM doesn't flip tab B off Claude.
|
| 190 |
-
session_config = self.config.model_copy(deep=True)
|
| 191 |
-
if model:
|
| 192 |
-
session_config.model_name = model
|
| 193 |
-
session = Session(
|
| 194 |
-
event_queue, config=session_config, tool_router=tool_router,
|
| 195 |
-
hf_token=hf_token,
|
| 196 |
-
)
|
| 197 |
-
t1 = _time.monotonic()
|
| 198 |
-
logger.info(f"Session initialized in {t1 - t0:.2f}s")
|
| 199 |
-
return tool_router, session
|
| 200 |
-
|
| 201 |
-
tool_router, session = await asyncio.to_thread(_create_session_sync)
|
| 202 |
|
| 203 |
# Create wrapper
|
| 204 |
agent_session = AgentSession(
|
|
@@ -210,14 +527,12 @@ class SessionManager:
|
|
| 210 |
hf_token=hf_token,
|
| 211 |
)
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
task = asyncio.create_task(
|
| 218 |
-
self._run_session(session_id, submission_queue, event_queue, tool_router)
|
| 219 |
)
|
| 220 |
-
|
| 221 |
|
| 222 |
logger.info(f"Created session {session_id} for user {user_id}")
|
| 223 |
return session_id
|
|
@@ -283,21 +598,38 @@ class SessionManager:
|
|
| 283 |
),
|
| 284 |
)
|
| 285 |
session.context_manager.items.append(seed)
|
|
|
|
| 286 |
return len(parsed)
|
| 287 |
|
| 288 |
@staticmethod
|
| 289 |
async def _cleanup_sandbox(session: Session) -> None:
|
| 290 |
-
"""Delete the sandbox Space if one was created for this session.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
sandbox = getattr(session, "sandbox", None)
|
| 292 |
-
if sandbox and getattr(sandbox, "_owns_space", False):
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
try:
|
| 295 |
-
logger.info(f"Deleting sandbox {space_id}...")
|
| 296 |
await asyncio.to_thread(sandbox.delete)
|
| 297 |
from agent.core import telemetry
|
| 298 |
await telemetry.record_sandbox_destroy(session, sandbox)
|
|
|
|
| 299 |
except Exception as e:
|
| 300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
async def _run_session(
|
| 303 |
self,
|
|
@@ -337,6 +669,7 @@ class SessionManager:
|
|
| 337 |
should_continue = await process_submission(session, submission)
|
| 338 |
finally:
|
| 339 |
agent_session.is_processing = False
|
|
|
|
| 340 |
if not should_continue:
|
| 341 |
break
|
| 342 |
except asyncio.TimeoutError:
|
|
@@ -371,6 +704,11 @@ class SessionManager:
|
|
| 371 |
async with self._lock:
|
| 372 |
if session_id in self.sessions:
|
| 373 |
self.sessions[session_id].is_active = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
logger.info(f"Session {session_id} ended")
|
| 376 |
|
|
@@ -420,7 +758,10 @@ class SessionManager:
|
|
| 420 |
agent_session = self.sessions.get(session_id)
|
| 421 |
if not agent_session or not agent_session.is_active:
|
| 422 |
return False
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
| 424 |
|
| 425 |
async def compact(self, session_id: str) -> bool:
|
| 426 |
"""Compact context in a session."""
|
|
@@ -445,12 +786,15 @@ class SessionManager:
|
|
| 445 |
return success
|
| 446 |
|
| 447 |
async def delete_session(self, session_id: str) -> bool:
|
| 448 |
-
"""
|
| 449 |
async with self._lock:
|
| 450 |
agent_session = self.sessions.pop(session_id, None)
|
| 451 |
|
| 452 |
if not agent_session:
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
# Clean up sandbox Space before cancelling the task
|
| 456 |
await self._cleanup_sandbox(agent_session.session)
|
|
@@ -465,6 +809,21 @@ class SessionManager:
|
|
| 465 |
|
| 466 |
return True
|
| 467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
def get_session_owner(self, session_id: str) -> str | None:
|
| 469 |
"""Get the user_id that owns a session, or None if session doesn't exist."""
|
| 470 |
agent_session = self.sessions.get(session_id)
|
|
@@ -492,22 +851,7 @@ class SessionManager:
|
|
| 492 |
if not agent_session:
|
| 493 |
return None
|
| 494 |
|
| 495 |
-
|
| 496 |
-
pending_approval = None
|
| 497 |
-
pa = agent_session.session.pending_approval
|
| 498 |
-
if pa and pa.get("tool_calls"):
|
| 499 |
-
pending_approval = []
|
| 500 |
-
for tc in pa["tool_calls"]:
|
| 501 |
-
import json
|
| 502 |
-
try:
|
| 503 |
-
args = json.loads(tc.function.arguments)
|
| 504 |
-
except (json.JSONDecodeError, AttributeError):
|
| 505 |
-
args = {}
|
| 506 |
-
pending_approval.append({
|
| 507 |
-
"tool": tc.function.name,
|
| 508 |
-
"tool_call_id": tc.id,
|
| 509 |
-
"arguments": args,
|
| 510 |
-
})
|
| 511 |
|
| 512 |
return {
|
| 513 |
"session_id": session_id,
|
|
@@ -518,16 +862,80 @@ class SessionManager:
|
|
| 518 |
"user_id": agent_session.user_id,
|
| 519 |
"pending_approval": pending_approval,
|
| 520 |
"model": agent_session.session.config.model_name,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 521 |
}
|
| 522 |
|
| 523 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
"""List sessions, optionally filtered by user.
|
| 525 |
|
| 526 |
Args:
|
| 527 |
user_id: If provided, only return sessions owned by this user.
|
| 528 |
If "dev", return all sessions (dev mode).
|
| 529 |
"""
|
| 530 |
-
results = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
for sid in self.sessions:
|
| 532 |
info = self.get_session_info(sid)
|
| 533 |
if not info:
|
|
|
|
| 1 |
"""Session manager for handling multiple concurrent agent sessions."""
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
+
import json
|
| 5 |
import logging
|
| 6 |
import uuid
|
| 7 |
from dataclasses import dataclass, field
|
|
|
|
| 11 |
|
| 12 |
from agent.config import load_config
|
| 13 |
from agent.core.agent_loop import process_submission
|
| 14 |
+
from agent.messaging.gateway import NotificationGateway
|
| 15 |
from agent.core.session import Event, OpType, Session
|
| 16 |
+
from agent.core.session_persistence import get_session_store
|
| 17 |
from agent.core.tools import ToolRouter
|
| 18 |
|
| 19 |
# Get project root (parent of backend directory)
|
|
|
|
| 44 |
class EventBroadcaster:
|
| 45 |
"""Reads from the agent's event queue and fans out to SSE subscribers.
|
| 46 |
|
| 47 |
+
Events that arrive when no subscribers are listening are discarded by
|
| 48 |
+
this in-memory fanout. Durable replay is handled by session_persistence.
|
|
|
|
| 49 |
"""
|
| 50 |
|
| 51 |
def __init__(self, event_queue: asyncio.Queue):
|
|
|
|
| 69 |
while True:
|
| 70 |
try:
|
| 71 |
event: Event = await self._source.get()
|
| 72 |
+
msg = {"event_type": event.event_type, "data": event.data, "seq": event.seq}
|
| 73 |
for q in self._subscribers.values():
|
| 74 |
await q.put(msg)
|
| 75 |
except asyncio.CancelledError:
|
|
|
|
| 93 |
is_active: bool = True
|
| 94 |
is_processing: bool = False # True while a submission is being executed
|
| 95 |
broadcaster: Any = None
|
| 96 |
+
title: str | None = None
|
| 97 |
# True once this session has been counted against the user's daily
|
| 98 |
# Claude quota. Guards double-counting when the user re-selects an
|
| 99 |
# Anthropic model mid-session.
|
|
|
|
| 122 |
|
| 123 |
def __init__(self, config_path: str | None = None) -> None:
    """Load configuration and set up empty manager state.

    Args:
        config_path: Path of the config to load; ``DEFAULT_CONFIG_PATH`` is
            used when it is ``None`` or empty.
    """
    resolved_path = config_path if config_path else DEFAULT_CONFIG_PATH
    config = load_config(resolved_path)
    self.config = config
    self.messaging_gateway = NotificationGateway(config.messaging)
    # Registry of runtime sessions keyed by session id, guarded by _lock.
    self.sessions: dict[str, AgentSession] = {}
    self._lock = asyncio.Lock()
    # Assigned later by start() or on first use through _store().
    self.persistence_store = None
|
| 129 |
+
|
| 130 |
+
async def start(self) -> None:
    """Initialise the session store, then start the messaging gateway."""
    store = get_session_store()
    self.persistence_store = store
    await store.init()
    await self.messaging_gateway.start()
|
| 135 |
+
|
| 136 |
+
async def close(self) -> None:
    """Close the messaging gateway, then the session store if one was set."""
    await self.messaging_gateway.close()
    store = self.persistence_store
    if store is None:
        return
    await store.close()
|
| 141 |
+
|
| 142 |
+
def _store(self):
|
| 143 |
+
if self.persistence_store is None:
|
| 144 |
+
self.persistence_store = get_session_store()
|
| 145 |
+
return self.persistence_store
|
| 146 |
|
| 147 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 148 |
"""Count active sessions owned by a specific user."""
|
|
|
|
| 152 |
if s.user_id == user_id and s.is_active
|
| 153 |
)
|
| 154 |
|
| 155 |
+
def _create_session_sync(
    self,
    *,
    session_id: str,
    user_id: str,
    hf_token: str | None,
    model: str | None,
    event_queue: asyncio.Queue,
    notification_destinations: list[str] | None = None,
) -> tuple[ToolRouter, Session]:
    """Construct the tool router and ``Session`` for one agent session.

    The constructors block, so this is meant to be run in a worker thread.

    Args:
        session_id: Identifier handed to the new ``Session``.
        user_id: Owner handed to the new ``Session``.
        hf_token: Token passed to both the tool router and the session.
        model: Optional model name overriding the configured default.
        event_queue: Queue handed to the session as its ``event_queue``.
        notification_destinations: Initial destinations; ``None`` or empty
            becomes an empty list.

    Returns:
        The ``(tool_router, session)`` pair.
    """
    from time import monotonic as _monotonic

    started = _monotonic()
    router = ToolRouter(self.config.mcpServers, hf_token=hf_token)
    # Every session gets its own deep copy of the config, so switching the
    # model in one session (tab A picking GLM) never changes another session
    # (tab B staying on Claude).
    per_session_config = self.config.model_copy(deep=True)
    if model:
        per_session_config.model_name = model
    new_session = Session(
        event_queue=event_queue,
        config=per_session_config,
        tool_router=router,
        hf_token=hf_token,
        user_id=user_id,
        notification_gateway=self.messaging_gateway,
        notification_destinations=notification_destinations or [],
        session_id=session_id,
        persistence_store=self._store(),
    )
    finished = _monotonic()
    logger.info("Session initialized in %.2fs", finished - started)
    return router, new_session
|
| 189 |
+
|
| 190 |
+
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
|
| 191 |
+
return [
|
| 192 |
+
msg.model_dump(mode="json")
|
| 193 |
+
for msg in session.context_manager.items
|
| 194 |
+
]
|
| 195 |
+
|
| 196 |
+
def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
|
| 197 |
+
pending = session.pending_approval or {}
|
| 198 |
+
tool_calls = pending.get("tool_calls") or []
|
| 199 |
+
serialized: list[dict[str, Any]] = []
|
| 200 |
+
for tc in tool_calls:
|
| 201 |
+
if hasattr(tc, "model_dump"):
|
| 202 |
+
serialized.append(tc.model_dump(mode="json"))
|
| 203 |
+
elif isinstance(tc, dict):
|
| 204 |
+
serialized.append(tc)
|
| 205 |
+
return serialized
|
| 206 |
+
|
| 207 |
+
@staticmethod
|
| 208 |
+
def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None:
|
| 209 |
+
pending = session.pending_approval or {}
|
| 210 |
+
tool_calls = pending.get("tool_calls") or []
|
| 211 |
+
if not tool_calls:
|
| 212 |
+
return None
|
| 213 |
+
result: list[dict[str, Any]] = []
|
| 214 |
+
for tc in tool_calls:
|
| 215 |
+
try:
|
| 216 |
+
args = json.loads(tc.function.arguments)
|
| 217 |
+
except (json.JSONDecodeError, AttributeError, TypeError):
|
| 218 |
+
args = {}
|
| 219 |
+
result.append(
|
| 220 |
+
{
|
| 221 |
+
"tool": getattr(tc.function, "name", None),
|
| 222 |
+
"tool_call_id": getattr(tc, "id", None),
|
| 223 |
+
"arguments": args,
|
| 224 |
+
}
|
| 225 |
+
)
|
| 226 |
+
return result
|
| 227 |
+
|
| 228 |
+
def _restore_pending_approval(
|
| 229 |
+
self, session: Session, pending_approval: list[dict[str, Any]] | None
|
| 230 |
+
) -> None:
|
| 231 |
+
if not pending_approval:
|
| 232 |
+
session.pending_approval = None
|
| 233 |
+
return
|
| 234 |
+
from litellm import ChatCompletionMessageToolCall as ToolCall
|
| 235 |
+
|
| 236 |
+
restored = []
|
| 237 |
+
for raw in pending_approval:
|
| 238 |
+
try:
|
| 239 |
+
if "function" in raw:
|
| 240 |
+
restored.append(ToolCall(**raw))
|
| 241 |
+
else:
|
| 242 |
+
restored.append(
|
| 243 |
+
ToolCall(
|
| 244 |
+
id=raw["tool_call_id"],
|
| 245 |
+
type="function",
|
| 246 |
+
function={
|
| 247 |
+
"name": raw["tool"],
|
| 248 |
+
"arguments": json.dumps(raw.get("arguments") or {}),
|
| 249 |
+
},
|
| 250 |
+
)
|
| 251 |
+
)
|
| 252 |
+
except Exception as e:
|
| 253 |
+
logger.warning("Dropping malformed pending approval: %s", e)
|
| 254 |
+
session.pending_approval = {"tool_calls": restored} if restored else None
|
| 255 |
+
|
| 256 |
+
@staticmethod
|
| 257 |
+
def _pending_docs_for_api(
|
| 258 |
+
pending_approval: list[dict[str, Any]] | None,
|
| 259 |
+
) -> list[dict[str, Any]] | None:
|
| 260 |
+
if not pending_approval:
|
| 261 |
+
return None
|
| 262 |
+
result: list[dict[str, Any]] = []
|
| 263 |
+
for raw in pending_approval:
|
| 264 |
+
if "function" in raw:
|
| 265 |
+
function = raw.get("function") or {}
|
| 266 |
+
try:
|
| 267 |
+
args = json.loads(function.get("arguments") or "{}")
|
| 268 |
+
except (json.JSONDecodeError, TypeError):
|
| 269 |
+
args = {}
|
| 270 |
+
result.append(
|
| 271 |
+
{
|
| 272 |
+
"tool": function.get("name"),
|
| 273 |
+
"tool_call_id": raw.get("id"),
|
| 274 |
+
"arguments": args,
|
| 275 |
+
}
|
| 276 |
+
)
|
| 277 |
+
elif {"tool", "tool_call_id"}.issubset(raw):
|
| 278 |
+
result.append(
|
| 279 |
+
{
|
| 280 |
+
"tool": raw.get("tool"),
|
| 281 |
+
"tool_call_id": raw.get("tool_call_id"),
|
| 282 |
+
"arguments": raw.get("arguments") or {},
|
| 283 |
+
}
|
| 284 |
+
)
|
| 285 |
+
return result or None
|
| 286 |
+
|
| 287 |
+
@staticmethod
|
| 288 |
+
def _runtime_state(agent_session: AgentSession) -> str:
|
| 289 |
+
if agent_session.session.pending_approval:
|
| 290 |
+
return "waiting_approval"
|
| 291 |
+
if agent_session.is_processing:
|
| 292 |
+
return "processing"
|
| 293 |
+
if not agent_session.is_active:
|
| 294 |
+
return "ended"
|
| 295 |
+
return "idle"
|
| 296 |
+
|
| 297 |
+
async def _start_agent_session(
|
| 298 |
+
self,
|
| 299 |
+
*,
|
| 300 |
+
agent_session: AgentSession,
|
| 301 |
+
event_queue: asyncio.Queue,
|
| 302 |
+
tool_router: ToolRouter,
|
| 303 |
+
) -> AgentSession:
|
| 304 |
+
async with self._lock:
|
| 305 |
+
existing = self.sessions.get(agent_session.session_id)
|
| 306 |
+
if existing:
|
| 307 |
+
return existing
|
| 308 |
+
self.sessions[agent_session.session_id] = agent_session
|
| 309 |
+
|
| 310 |
+
task = asyncio.create_task(
|
| 311 |
+
self._run_session(
|
| 312 |
+
agent_session.session_id,
|
| 313 |
+
agent_session.submission_queue,
|
| 314 |
+
event_queue,
|
| 315 |
+
tool_router,
|
| 316 |
+
)
|
| 317 |
+
)
|
| 318 |
+
agent_session.task = task
|
| 319 |
+
return agent_session
|
| 320 |
+
|
| 321 |
+
@staticmethod
|
| 322 |
+
def _can_access_session(agent_session: AgentSession, user_id: str) -> bool:
|
| 323 |
+
return (
|
| 324 |
+
user_id == "dev"
|
| 325 |
+
or agent_session.user_id == "dev"
|
| 326 |
+
or agent_session.user_id == user_id
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
@staticmethod
|
| 330 |
+
def _update_hf_token(agent_session: AgentSession, hf_token: str | None) -> None:
|
| 331 |
+
if not hf_token:
|
| 332 |
+
return
|
| 333 |
+
agent_session.hf_token = hf_token
|
| 334 |
+
agent_session.session.hf_token = hf_token
|
| 335 |
+
|
| 336 |
+
async def persist_session_snapshot(
    self,
    agent_session: AgentSession,
    *,
    runtime_state: str | None = None,
    status: str = "active",
) -> None:
    """Save a snapshot of *agent_session* to the persistence store.

    Does nothing when the store has no truthy ``enabled`` attribute. Any
    failure while building or saving the snapshot is logged as a warning and
    swallowed, so persistence problems never propagate to the caller.

    Args:
        agent_session: The runtime session to snapshot.
        runtime_state: Explicit state label; computed from the session via
            ``_runtime_state`` when omitted or empty.
        status: Status value stored with the snapshot.
    """
    store = self._store()
    if not getattr(store, "enabled", False):
        return
    try:
        snapshot = {
            "session_id": agent_session.session_id,
            "user_id": agent_session.user_id,
            "model": agent_session.session.config.model_name,
            "title": agent_session.title,
            "messages": self._serialize_messages(agent_session.session),
            "runtime_state": (
                runtime_state
                if runtime_state
                else self._runtime_state(agent_session)
            ),
            "status": status,
            "turn_count": agent_session.session.turn_count,
            "pending_approval": self._serialize_pending_approval(
                agent_session.session
            ),
            "claude_counted": agent_session.claude_counted,
            "created_at": agent_session.created_at,
            "notification_destinations": list(
                agent_session.session.notification_destinations
            ),
        }
        await store.save_snapshot(**snapshot)
    except Exception as e:
        logger.warning(
            "Failed to persist snapshot for %s: %s",
            agent_session.session_id,
            e,
        )
|
| 370 |
+
|
| 371 |
+
async def ensure_session_loaded(
    self,
    session_id: str,
    user_id: str,
    hf_token: str | None = None,
) -> AgentSession | None:
    """Return the in-memory session for ``session_id``, restoring it from the store if absent.

    Args:
        session_id: Identifier of the session to look up or restore.
        user_id: Caller identity used for the access check.
        hf_token: Optional token; passed to ``_update_hf_token`` when an
            already-live session is returned, and used to build a restored one.

    Returns:
        The live ``AgentSession``, or ``None`` when the store has no such
        session or the caller is not allowed to access it.
    """
    # Fast path: the session is already resident in this process.
    async with self._lock:
        live = self.sessions.get(session_id)
        if live:
            if not self._can_access_session(live, user_id):
                return None
            self._update_hf_token(live, hf_token)
            return live

    # The store round-trip is awaited without holding the lock.
    persisted = await self._store().load_session(session_id)
    if not persisted:
        return None

    # Look again: the session may have been registered while the load awaited.
    async with self._lock:
        live = self.sessions.get(session_id)
        if live:
            if not self._can_access_session(live, user_id):
                return None
            self._update_hf_token(live, hf_token)
            return live

    metadata = persisted.get("metadata") or {}
    owner = str(metadata.get("user_id") or "")
    # "dev" on either side bypasses the ownership match.
    caller_allowed = user_id == "dev" or owner == "dev" or owner == user_id
    if not caller_allowed:
        return None

    from litellm import Message

    effective_user = owner or user_id
    event_queue: asyncio.Queue = asyncio.Queue()
    submission_queue: asyncio.Queue = asyncio.Queue()
    # Session construction runs in a worker thread, as in create_session.
    tool_router, session = await asyncio.to_thread(
        self._create_session_sync,
        session_id=session_id,
        user_id=effective_user,
        hf_token=hf_token,
        model=metadata.get("model") or self.config.model_name,
        event_queue=event_queue,
        notification_destinations=metadata.get("notification_destinations") or [],
    )

    history: list[Message] = []
    for record in persisted.get("messages") or []:
        # Persisted system messages are skipped; non-dict rows are ignored.
        if not isinstance(record, dict) or record.get("role") == "system":
            continue
        try:
            history.append(Message.model_validate(record))
        except Exception as e:
            logger.warning("Dropping malformed restored message: %s", e)
    if history:
        # Keep the freshly-rendered system prompt, then attach the durable
        # non-system context so tools/date/user context stay current.
        system_prompt = session.context_manager.items[0]
        session.context_manager.items = [system_prompt, *history]

    self._restore_pending_approval(session, metadata.get("pending_approval") or [])
    session.turn_count = int(metadata.get("turn_count") or 0)

    created_at = metadata.get("created_at")
    if not isinstance(created_at, datetime):
        # Missing or non-datetime value in the stored metadata: fall back to now.
        created_at = datetime.utcnow()

    restored = AgentSession(
        session_id=session_id,
        session=session,
        tool_router=tool_router,
        submission_queue=submission_queue,
        user_id=effective_user,
        hf_token=hf_token,
        created_at=created_at,
        is_active=True,
        is_processing=False,
        claude_counted=bool(metadata.get("claude_counted")),
        title=metadata.get("title"),
    )
    started = await self._start_agent_session(
        agent_session=restored,
        event_queue=event_queue,
        tool_router=tool_router,
    )
    if started is restored:
        logger.info("Restored session %s for user %s", session_id, effective_user)
        return restored
    # _start_agent_session handed back a different object (presumably one
    # registered concurrently - confirm); refresh its token and use it.
    self._update_hf_token(started, hf_token)
    return started
|
| 462 |
+
|
| 463 |
async def create_session(
|
| 464 |
self,
|
| 465 |
user_id: str = "dev",
|
|
|
|
| 508 |
event_queue: asyncio.Queue = asyncio.Queue()
|
| 509 |
|
| 510 |
# Run blocking constructors in a thread to keep the event loop responsive.
|
| 511 |
+
tool_router, session = await asyncio.to_thread(
|
| 512 |
+
self._create_session_sync,
|
| 513 |
+
session_id=session_id,
|
| 514 |
+
user_id=user_id,
|
| 515 |
+
hf_token=hf_token,
|
| 516 |
+
model=model,
|
| 517 |
+
event_queue=event_queue,
|
| 518 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
|
| 520 |
# Create wrapper
|
| 521 |
agent_session = AgentSession(
|
|
|
|
| 527 |
hf_token=hf_token,
|
| 528 |
)
|
| 529 |
|
| 530 |
+
await self._start_agent_session(
|
| 531 |
+
agent_session=agent_session,
|
| 532 |
+
event_queue=event_queue,
|
| 533 |
+
tool_router=tool_router,
|
|
|
|
|
|
|
| 534 |
)
|
| 535 |
+
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 536 |
|
| 537 |
logger.info(f"Created session {session_id} for user {user_id}")
|
| 538 |
return session_id
|
|
|
|
| 598 |
),
|
| 599 |
)
|
| 600 |
session.context_manager.items.append(seed)
|
| 601 |
+
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 602 |
return len(parsed)
|
| 603 |
|
| 604 |
@staticmethod
async def _cleanup_sandbox(session: Session) -> None:
    """Delete the sandbox Space if one was created for this session.

    Retries on transient failures (HF API 5xx, rate-limit, network blips)
    with exponential backoff. A single missed delete = a permanently
    orphaned Space, so the cost of an extra retry beats the alternative.

    Only the delete call is retried. Telemetry is recorded once after a
    successful delete and is best-effort: a telemetry failure must not
    re-issue a delete for a Space that is already gone, nor be reported
    as an orphan.

    Args:
        session: Session whose ``sandbox`` attribute (if any) is inspected.
            Nothing happens unless the sandbox has a truthy ``_owns_space``.
    """
    sandbox = getattr(session, "sandbox", None)
    if not (sandbox and getattr(sandbox, "_owns_space", False)):
        return

    space_id = getattr(sandbox, "space_id", None)
    max_attempts = 3
    last_err: Exception | None = None
    for attempt in range(max_attempts):
        try:
            logger.info(f"Deleting sandbox {space_id} (attempt {attempt + 1}/3)...")
            await asyncio.to_thread(sandbox.delete)
        except Exception as e:
            last_err = e
            if attempt < max_attempts - 1:
                # Backoff of 1s, then 2s, between attempts.
                await asyncio.sleep(2 ** attempt)
            continue

        # Delete succeeded. Record telemetry separately so its failure
        # cannot trigger another delete attempt.
        try:
            from agent.core import telemetry
            await telemetry.record_sandbox_destroy(session, sandbox)
        except Exception as e:
            logger.warning(
                "Sandbox %s deleted but telemetry recording failed: %s",
                space_id,
                e,
            )
        return

    logger.error(
        f"Failed to delete sandbox {space_id} after 3 attempts: {last_err}. "
        f"Orphan — sweep script will pick it up."
    )
|
| 633 |
|
| 634 |
async def _run_session(
|
| 635 |
self,
|
|
|
|
| 669 |
should_continue = await process_submission(session, submission)
|
| 670 |
finally:
|
| 671 |
agent_session.is_processing = False
|
| 672 |
+
await self.persist_session_snapshot(agent_session)
|
| 673 |
if not should_continue:
|
| 674 |
break
|
| 675 |
except asyncio.TimeoutError:
|
|
|
|
| 704 |
async with self._lock:
|
| 705 |
if session_id in self.sessions:
|
| 706 |
self.sessions[session_id].is_active = False
|
| 707 |
+
await self.persist_session_snapshot(
|
| 708 |
+
self.sessions[session_id],
|
| 709 |
+
runtime_state="ended",
|
| 710 |
+
status="ended",
|
| 711 |
+
)
|
| 712 |
|
| 713 |
logger.info(f"Session {session_id} ended")
|
| 714 |
|
|
|
|
| 758 |
agent_session = self.sessions.get(session_id)
|
| 759 |
if not agent_session or not agent_session.is_active:
|
| 760 |
return False
|
| 761 |
+
success = agent_session.session.context_manager.truncate_to_user_message(user_message_index)
|
| 762 |
+
if success:
|
| 763 |
+
await self.persist_session_snapshot(agent_session, runtime_state="idle")
|
| 764 |
+
return success
|
| 765 |
|
| 766 |
async def compact(self, session_id: str) -> bool:
|
| 767 |
"""Compact context in a session."""
|
|
|
|
| 786 |
return success
|
| 787 |
|
| 788 |
async def delete_session(self, session_id: str) -> bool:
|
| 789 |
+
"""Soft-delete a session and stop its runtime resources."""
|
| 790 |
async with self._lock:
|
| 791 |
agent_session = self.sessions.pop(session_id, None)
|
| 792 |
|
| 793 |
if not agent_session:
|
| 794 |
+
await self._store().soft_delete_session(session_id)
|
| 795 |
+
return True
|
| 796 |
+
|
| 797 |
+
await self._store().soft_delete_session(session_id)
|
| 798 |
|
| 799 |
# Clean up sandbox Space before cancelling the task
|
| 800 |
await self._cleanup_sandbox(agent_session.session)
|
|
|
|
| 809 |
|
| 810 |
return True
|
| 811 |
|
| 812 |
+
async def update_session_title(self, session_id: str, title: str | None) -> None:
    """Store a user-visible title for the session (used for sidebar rehydration).

    The in-memory session, when present, is updated first; the store is
    updated whether or not the session is currently loaded.
    """
    live = self.sessions.get(session_id)
    if live:
        live.title = title
    store = self._store()
    await store.update_session_fields(session_id, title=title)
|
| 818 |
+
|
| 819 |
+
async def update_session_model(self, session_id: str, model_id: str) -> bool:
    """Switch the model of a live, active session and persist a snapshot.

    Returns:
        ``True`` once the model was updated and the snapshot persisted;
        ``False`` when the session is not loaded or is inactive.
    """
    live = self.sessions.get(session_id)
    if live and live.is_active:
        live.session.update_model(model_id)
        await self.persist_session_snapshot(live, runtime_state="idle")
        return True
    return False
|
| 826 |
+
|
| 827 |
def get_session_owner(self, session_id: str) -> str | None:
|
| 828 |
"""Get the user_id that owns a session, or None if session doesn't exist."""
|
| 829 |
agent_session = self.sessions.get(session_id)
|
|
|
|
| 851 |
if not agent_session:
|
| 852 |
return None
|
| 853 |
|
| 854 |
+
pending_approval = self._pending_tools_for_api(agent_session.session)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
|
| 856 |
return {
|
| 857 |
"session_id": session_id,
|
|
|
|
| 862 |
"user_id": agent_session.user_id,
|
| 863 |
"pending_approval": pending_approval,
|
| 864 |
"model": agent_session.session.config.model_name,
|
| 865 |
+
"title": agent_session.title,
|
| 866 |
+
"notification_destinations": list(
|
| 867 |
+
agent_session.session.notification_destinations
|
| 868 |
+
),
|
| 869 |
}
|
| 870 |
|
| 871 |
+
def set_notification_destinations(
    self, session_id: str, destinations: list[str]
) -> list[str]:
    """Replace the session's opted-in auto-notification destinations.

    Names are stripped of surrounding whitespace and de-duplicated while
    keeping first-seen order. The whole list is validated before the
    session is touched, so a bad entry leaves the session unchanged.

    Args:
        session_id: Identifier of a loaded, active session.
        destinations: Destination names to opt in to.

    Returns:
        The normalized list that was applied to the session.

    Raises:
        ValueError: If the session is missing or inactive, or if any name
            is not a string, is blank, is unknown to the messaging config,
            or refers to a destination with auto events disabled.
    """
    agent_session = self.sessions.get(session_id)
    if not agent_session or not agent_session.is_active:
        raise ValueError("Session not found or inactive")

    normalized: list[str] = []
    seen: set[str] = set()
    for raw_name in destinations:
        # Non-string entries (e.g. a JSON null) would otherwise fail with
        # AttributeError on .strip(); report them like other bad input.
        if not isinstance(raw_name, str):
            raise ValueError("Destination names must be strings")
        name = raw_name.strip()
        if not name:
            raise ValueError("Destination names must not be empty")
        destination = self.config.messaging.get_destination(name)
        if destination is None:
            raise ValueError(f"Unknown destination '{name}'")
        if not destination.allow_auto_events:
            raise ValueError(
                f"Destination '{name}' is not enabled for auto events"
            )
        if name not in seen:
            normalized.append(name)
            seen.add(name)

    agent_session.session.set_notification_destinations(normalized)
    return normalized
|
| 898 |
+
|
| 899 |
+
async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
| 900 |
"""List sessions, optionally filtered by user.
|
| 901 |
|
| 902 |
Args:
|
| 903 |
user_id: If provided, only return sessions owned by this user.
|
| 904 |
If "dev", return all sessions (dev mode).
|
| 905 |
"""
|
| 906 |
+
results: list[dict[str, Any]] = []
|
| 907 |
+
store = self._store()
|
| 908 |
+
if getattr(store, "enabled", False):
|
| 909 |
+
for row in await store.list_sessions(user_id or "dev"):
|
| 910 |
+
sid = row.get("session_id") or row.get("_id")
|
| 911 |
+
if not sid:
|
| 912 |
+
continue
|
| 913 |
+
runtime_info = self.get_session_info(str(sid))
|
| 914 |
+
if runtime_info:
|
| 915 |
+
results.append(runtime_info)
|
| 916 |
+
continue
|
| 917 |
+
created_at = row.get("created_at")
|
| 918 |
+
if isinstance(created_at, datetime):
|
| 919 |
+
created_at_str = created_at.isoformat()
|
| 920 |
+
else:
|
| 921 |
+
created_at_str = str(created_at or datetime.utcnow().isoformat())
|
| 922 |
+
pending = self._pending_docs_for_api(row.get("pending_approval") or [])
|
| 923 |
+
results.append(
|
| 924 |
+
{
|
| 925 |
+
"session_id": str(sid),
|
| 926 |
+
"created_at": created_at_str,
|
| 927 |
+
"is_active": row.get("status") != "ended",
|
| 928 |
+
"is_processing": row.get("runtime_state") == "processing",
|
| 929 |
+
"message_count": int(row.get("message_count") or 0),
|
| 930 |
+
"user_id": row.get("user_id") or "dev",
|
| 931 |
+
"pending_approval": pending or None,
|
| 932 |
+
"model": row.get("model"),
|
| 933 |
+
"title": row.get("title"),
|
| 934 |
+
"notification_destinations": row.get("notification_destinations") or [],
|
| 935 |
+
}
|
| 936 |
+
)
|
| 937 |
+
return results
|
| 938 |
+
|
| 939 |
for sid in self.sessions:
|
| 940 |
info = self.get_session_info(sid)
|
| 941 |
if not info:
|
backend/user_quotas.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
Tracks per-user Claude session starts against a daily cap derived from the
|
| 4 |
-
user's HF plan.
|
| 5 |
-
|
| 6 |
-
restart is much lower than running a DB).
|
| 7 |
|
| 8 |
Unit: session *creations*, not messages. A user who selects Claude in a new
|
| 9 |
session consumes one quota point; switching an existing Claude session to
|
|
@@ -18,6 +17,8 @@ import asyncio
|
|
| 18 |
import os
|
| 19 |
from datetime import UTC, datetime
|
| 20 |
|
|
|
|
|
|
|
| 21 |
CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
|
| 22 |
CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
|
| 23 |
|
|
@@ -37,6 +38,11 @@ def daily_cap_for(plan: str | None) -> int:
|
|
| 37 |
|
| 38 |
async def get_claude_used_today(user_id: str) -> int:
|
| 39 |
"""Return today's Claude session count for the user (0 if none / stale day)."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
async with _lock:
|
| 41 |
entry = _claude_counts.get(user_id)
|
| 42 |
if entry is None:
|
|
@@ -51,11 +57,37 @@ async def get_claude_used_today(user_id: str) -> int:
|
|
| 51 |
|
| 52 |
async def increment_claude(user_id: str) -> int:
|
| 53 |
"""Bump today's Claude session count for the user. Returns the new value."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
async with _lock:
|
| 55 |
today = _today()
|
| 56 |
day, count = _claude_counts.get(user_id, (today, 0))
|
| 57 |
if day != today:
|
| 58 |
count = 0
|
|
|
|
|
|
|
| 59 |
count += 1
|
| 60 |
_claude_counts[user_id] = (today, count)
|
| 61 |
return count
|
|
@@ -63,6 +95,11 @@ async def increment_claude(user_id: str) -> int:
|
|
| 63 |
|
| 64 |
async def refund_claude(user_id: str) -> None:
|
| 65 |
"""Decrement today's count — used when session creation fails after a successful gate."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
async with _lock:
|
| 67 |
entry = _claude_counts.get(user_id)
|
| 68 |
if entry is None:
|
|
@@ -81,3 +118,4 @@ async def refund_claude(user_id: str) -> None:
|
|
| 81 |
def _reset_for_tests() -> None:
|
| 82 |
"""Test-only: clear the in-memory store."""
|
| 83 |
_claude_counts.clear()
|
|
|
|
|
|
| 1 |
+
"""Daily quota for Claude session creations.
|
| 2 |
|
| 3 |
Tracks per-user Claude session starts against a daily cap derived from the
|
| 4 |
+
user's HF plan. MongoDB is the source of truth when configured; the
|
| 5 |
+
in-process dict remains the fallback for local/dev/test runs.
|
|
|
|
| 6 |
|
| 7 |
Unit: session *creations*, not messages. A user who selects Claude in a new
|
| 8 |
session consumes one quota point; switching an existing Claude session to
|
|
|
|
| 17 |
import os
|
| 18 |
from datetime import UTC, datetime
|
| 19 |
|
| 20 |
+
from agent.core.session_persistence import NoopSessionStore, get_session_store, _reset_store_for_tests
|
| 21 |
+
|
| 22 |
CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
|
| 23 |
CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
|
| 24 |
|
|
|
|
| 38 |
|
| 39 |
async def get_claude_used_today(user_id: str) -> int:
|
| 40 |
"""Return today's Claude session count for the user (0 if none / stale day)."""
|
| 41 |
+
store = get_session_store()
|
| 42 |
+
if getattr(store, "enabled", False):
|
| 43 |
+
db_count = await store.get_quota(user_id, _today())
|
| 44 |
+
return db_count or 0
|
| 45 |
+
|
| 46 |
async with _lock:
|
| 47 |
entry = _claude_counts.get(user_id)
|
| 48 |
if entry is None:
|
|
|
|
| 57 |
|
| 58 |
async def increment_claude(user_id: str) -> int:
    """Bump today's Claude session count for the user. Returns the new value.

    When the session store reports itself enabled the count lives there;
    otherwise the in-process dict is used.
    """
    backend = get_session_store()
    if getattr(backend, "enabled", False):
        # A very large cap is passed so this call is effectively uncapped.
        bumped = await backend.try_increment_quota(user_id, _today(), cap=10**9)
        return bumped or 0

    async with _lock:
        today = _today()
        stored_day, used = _claude_counts.get(user_id, (today, 0))
        # A count recorded on another day is stale and restarts from zero.
        new_total = 1 if stored_day != today else used + 1
        _claude_counts[user_id] = (today, new_total)
        return new_total
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
async def try_increment_claude(user_id: str, cap: int) -> int | None:
    """Atomically bump today's count if below *cap*.

    Delegates to the session store when it reports itself enabled;
    otherwise the check-and-increment runs on the in-process dict while
    holding the module lock.

    Returns the new count, or None when the user is already at the cap.
    """
    backend = get_session_store()
    if getattr(backend, "enabled", False):
        return await backend.try_increment_quota(user_id, _today(), cap)

    async with _lock:
        today = _today()
        stored_day, used = _claude_counts.get(user_id, (today, 0))
        # A count recorded on another day is stale and restarts from zero.
        if stored_day != today:
            used = 0
        if used >= cap:
            return None
        new_total = used + 1
        _claude_counts[user_id] = (today, new_total)
        return new_total
|
|
|
|
| 95 |
|
| 96 |
async def refund_claude(user_id: str) -> None:
|
| 97 |
"""Decrement today's count — used when session creation fails after a successful gate."""
|
| 98 |
+
store = get_session_store()
|
| 99 |
+
if getattr(store, "enabled", False):
|
| 100 |
+
await store.refund_quota(user_id, _today())
|
| 101 |
+
return
|
| 102 |
+
|
| 103 |
async with _lock:
|
| 104 |
entry = _claude_counts.get(user_id)
|
| 105 |
if entry is None:
|
|
|
|
| 118 |
def _reset_for_tests() -> None:
    """Test-only: clear the in-memory counts and reset the session store.

    The store is reset by passing a fresh ``NoopSessionStore`` to
    ``_reset_store_for_tests``.
    """
    _claude_counts.clear()
    _reset_store_for_tests(NoopSessionStore())
|
configs/__init__.py
ADDED
|
File without changes
|
configs/cli_agent_config.json
CHANGED
|
@@ -5,6 +5,11 @@
|
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"mcpServers": {
|
| 9 |
"hf-mcp-server": {
|
| 10 |
"transport": "http",
|
|
|
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
| 8 |
+
"messaging": {
|
| 9 |
+
"enabled": false,
|
| 10 |
+
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
| 11 |
+
"destinations": {}
|
| 12 |
+
},
|
| 13 |
"mcpServers": {
|
| 14 |
"hf-mcp-server": {
|
| 15 |
"transport": "http",
|
frontend/src/components/Chat/MarkdownContent.tsx
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import { useMemo, useRef, useState, useEffect } from 'react';
|
| 2 |
import { Box } from '@mui/material';
|
| 3 |
import ReactMarkdown from 'react-markdown';
|
| 4 |
import remarkGfm from 'remark-gfm';
|
|
@@ -166,9 +166,17 @@ export default function MarkdownContent({ content, sx, isStreaming = false }: Ma
|
|
| 166 |
|
| 167 |
const remarkPlugins = useMemo(() => [remarkGfm], []);
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
return (
|
| 170 |
<Box sx={[markdownSx, ...(Array.isArray(sx) ? sx : sx ? [sx] : [])]}>
|
| 171 |
-
<ReactMarkdown remarkPlugins={remarkPlugins}>{displayContent}</ReactMarkdown>
|
| 172 |
</Box>
|
| 173 |
);
|
| 174 |
}
|
|
|
|
| 1 |
+
import { useMemo, useRef, useState, useEffect, type ComponentPropsWithoutRef } from 'react';
|
| 2 |
import { Box } from '@mui/material';
|
| 3 |
import ReactMarkdown from 'react-markdown';
|
| 4 |
import remarkGfm from 'remark-gfm';
|
|
|
|
| 166 |
|
| 167 |
const remarkPlugins = useMemo(() => [remarkGfm], []);
|
| 168 |
|
| 169 |
+
const components = useMemo(() => ({
|
| 170 |
+
a: ({ href, children, ...props }: ComponentPropsWithoutRef<'a'>) => (
|
| 171 |
+
<a href={href} target="_blank" rel="noopener noreferrer" {...props}>
|
| 172 |
+
{children}
|
| 173 |
+
</a>
|
| 174 |
+
),
|
| 175 |
+
}), []);
|
| 176 |
+
|
| 177 |
return (
|
| 178 |
<Box sx={[markdownSx, ...(Array.isArray(sx) ? sx : sx ? [sx] : [])]}>
|
| 179 |
+
<ReactMarkdown remarkPlugins={remarkPlugins} components={components}>{displayContent}</ReactMarkdown>
|
| 180 |
</Box>
|
| 181 |
);
|
| 182 |
}
|
frontend/src/components/Chat/ToolCallGroup.tsx
CHANGED
|
@@ -220,6 +220,194 @@ function ResearchSteps({ steps }: { steps: string[] }) {
|
|
| 220 |
);
|
| 221 |
}
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
// ---------------------------------------------------------------------------
|
| 224 |
// Hardware pricing ($/hr) — from HF Spaces & Jobs pricing
|
| 225 |
// ---------------------------------------------------------------------------
|
|
@@ -517,7 +705,7 @@ function InlineApproval({
|
|
| 517 |
const EMPTY_AGENTS: Record<string, ResearchAgentState> = {};
|
| 518 |
|
| 519 |
export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) {
|
| 520 |
-
const { setPanel, lockPanel, getJobUrl, getEditedScript, setJobStatus, getJobStatus, setToolError, getToolError, setToolRejected, getToolRejected } = useAgentStore();
|
| 521 |
const researchAgents = useAgentStore(s => {
|
| 522 |
const activeId = s.activeSessionId;
|
| 523 |
return (activeId && s.sessionStates[activeId]?.researchAgents) || EMPTY_AGENTS;
|
|
@@ -1063,6 +1251,18 @@ export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProp
|
|
| 1063 |
<ResearchSteps steps={researchAgents[tool.toolCallId].steps} />
|
| 1064 |
)}
|
| 1065 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1066 |
{/* Per-tool approval: undecided */}
|
| 1067 |
{isPending && !localDecision && !isSubmitting && (
|
| 1068 |
<InlineApproval
|
|
|
|
| 220 |
);
|
| 221 |
}
|
| 222 |
|
| 223 |
+
// ---------------------------------------------------------------------------
|
| 224 |
+
// Trackio dashboard embed
|
| 225 |
+
// ---------------------------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
// HF repo IDs are `<owner>/<name>` where each segment is alphanumerics plus
// `_`, `.`, `-`. Anything else (slashes, spaces, query params, missing owner)
// would let an attacker-controlled string redirect the embed to a different
// Space, so we refuse to render rather than build a malformed URL.
const SPACE_ID_PATTERN = /^[a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+$/;

/**
 * Return true when `spaceId` is a two-segment `<owner>/<name>` ID that is
 * safe to interpolate into a URL.
 *
 * The character class alone admits `.` and `..` as whole segments. Those are
 * path dot-segments: `https://huggingface.co/spaces/../x` resolves outside
 * `/spaces/`, so they are rejected explicitly.
 */
function isValidSpaceId(spaceId: string): boolean {
  if (!SPACE_ID_PATTERN.test(spaceId)) return false;
  return spaceId.split('/').every((segment) => segment !== '.' && segment !== '..');
}
|
| 236 |
+
|
| 237 |
+
/**
 * HF Space embed subdomain: 'user/space_name' → 'user-space-name'.
 *
 * Lower-cases the ID, turns `/`, `_` and `.` into dashes, collapses runs of
 * dashes into one, and drops a single leading and trailing dash.
 */
function spaceIdToSubdomain(spaceId: string): string {
  const lowered = spaceId.toLowerCase();
  const dashed = lowered.replace(/[/_.]/g, '-');
  const collapsed = dashed.replace(/-+/g, '-');
  return collapsed.replace(/^-|-$/g, '');
}
|
| 245 |
+
|
| 246 |
+
/**
 * Build the `*.hf.space` URL used as the iframe source for a trackio Space.
 *
 * `project`, when given, is appended as a `project` query parameter.
 */
function buildTrackioEmbedUrl(spaceId: string, project?: string): string {
  const query = new URLSearchParams();
  query.set('sidebar', 'hidden');
  query.set('footer', 'false');
  // __theme=dark is gradio's standard query param to force the embedded
  // dashboard into dark mode so it blends with the surrounding chat instead
  // of flashing a bright white panel inside the dark UI.
  query.set('__theme', 'dark');
  if (project) query.set('project', project);

  const host = `${spaceIdToSubdomain(spaceId)}.hf.space`;
  return `https://${host}/?${query.toString()}`;
}
|
| 258 |
+
|
| 259 |
+
/**
 * Build the huggingface.co page URL for a Space, adding a `project` query
 * parameter when one is given.
 */
function buildTrackioPageUrl(spaceId: string, project?: string): string {
  const base = `https://huggingface.co/spaces/${spaceId}`;
  if (!project) return base;
  return `${base}?${new URLSearchParams({ project }).toString()}`;
}
|
| 263 |
+
|
| 264 |
+
/**
 * Collapsible panel that embeds a trackio dashboard Space in an iframe.
 *
 * Shows a header (Space ID, optional project, an "Open" link to the Space
 * page and a Hide/Show toggle) and, while expanded, the iframe with a
 * loading overlay that stays up until the iframe fires `onLoad`.
 * Renders nothing when `spaceId` fails `isValidSpaceId`.
 */
function TrackioEmbed({ spaceId, project }: { spaceId: string; project?: string }) {
  // Starts expanded; the header button toggles it.
  const [expanded, setExpanded] = useState(true);
  // Flipped by the iframe's onLoad; drives the loading overlay.
  const [iframeLoaded, setIframeLoaded] = useState(false);
  const embedUrl = useMemo(() => buildTrackioEmbedUrl(spaceId, project), [spaceId, project]);
  const pageUrl = useMemo(() => buildTrackioPageUrl(spaceId, project), [spaceId, project]);
  const label = project ? `${spaceId} · ${project}` : spaceId;

  // The validity check sits after the hooks, so every render calls the same hooks.
  if (!isValidSpaceId(spaceId)) return null;

  return (
    <Box sx={{ pl: 4.5, pr: 1.5, pb: 1, pt: 0.25 }}>
      <Box
        sx={{
          border: '1px solid var(--tool-border)',
          borderRadius: '8px',
          overflow: 'hidden',
          bgcolor: 'var(--code-panel-bg)',
        }}
      >
        <Stack
          direction="row"
          alignItems="center"
          spacing={1}
          onClick={(e) => e.stopPropagation()}
          sx={{
            px: 1.25,
            py: 0.5,
            borderBottom: expanded ? '1px solid var(--tool-border)' : 'none',
          }}
        >
          <Typography
            sx={{
              fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
              fontSize: '0.65rem',
              fontWeight: 600,
              color: 'var(--accent-yellow)',
              letterSpacing: '0.04em',
            }}
          >
            trackio
          </Typography>
          <Typography
            sx={{
              fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
              fontSize: '0.65rem',
              color: 'var(--muted-text)',
              flex: 1,
              minWidth: 0,
              overflow: 'hidden',
              textOverflow: 'ellipsis',
              whiteSpace: 'nowrap',
            }}
          >
            {label}
          </Typography>
          <Link
            href={pageUrl}
            target="_blank"
            rel="noopener noreferrer"
            onClick={(e) => e.stopPropagation()}
            sx={{
              display: 'inline-flex',
              alignItems: 'center',
              gap: 0.4,
              color: 'var(--accent-yellow)',
              fontSize: '0.65rem',
              textDecoration: 'none',
              '&:hover': { textDecoration: 'underline' },
            }}
          >
            <LaunchIcon sx={{ fontSize: 11 }} />
            Open
          </Link>
          <Button
            size="small"
            onClick={(e) => {
              e.stopPropagation();
              setExpanded((v) => !v);
            }}
            sx={{
              textTransform: 'none',
              minWidth: 'auto',
              px: 0.75,
              py: 0,
              fontSize: '0.65rem',
              color: 'var(--muted-text)',
              '&:hover': { color: 'var(--text)', bgcolor: 'transparent' },
            }}
          >
            {expanded ? 'Hide' : 'Show'}
          </Button>
        </Stack>
        {expanded && (
          <Box sx={{ position: 'relative', width: '100%', height: 480, bgcolor: 'var(--code-panel-bg)' }}>
            <iframe
              src={embedUrl}
              title={`Trackio dashboard ${label}`}
              loading="lazy"
              onLoad={() => setIframeLoaded(true)}
              sandbox="allow-scripts allow-same-origin allow-forms allow-popups allow-downloads allow-modals"
              style={{ border: 0, width: '100%', height: '100%', display: 'block' }}
            />
            {!iframeLoaded && (
              <Stack
                direction="column"
                alignItems="center"
                justifyContent="center"
                spacing={1.5}
                sx={{
                  position: 'absolute',
                  inset: 0,
                  bgcolor: 'var(--code-panel-bg)',
                  color: 'var(--muted-text)',
                  pointerEvents: 'none',
                }}
              >
                <CircularProgress size={20} sx={{ color: 'var(--accent-yellow)' }} />
                <Typography
                  sx={{
                    fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
                    fontSize: '0.75rem',
                    color: 'var(--text)',
                  }}
                >
                  Spinning up the trackio dashboard…
                </Typography>
                <Typography
                  sx={{
                    fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace',
                    fontSize: '0.65rem',
                    color: 'var(--muted-text)',
                    textAlign: 'center',
                    maxWidth: 360,
                    px: 2,
                  }}
                >
                  First load takes 30–60 seconds. Charts appear automatically once the run starts logging.
                </Typography>
              </Stack>
            )}
          </Box>
        )}
      </Box>
    </Box>
  );
}
|
| 410 |
+
|
| 411 |
// ---------------------------------------------------------------------------
|
| 412 |
// Hardware pricing ($/hr) — from HF Spaces & Jobs pricing
|
| 413 |
// ---------------------------------------------------------------------------
|
|
|
|
| 705 |
const EMPTY_AGENTS: Record<string, ResearchAgentState> = {};
|
| 706 |
|
| 707 |
export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) {
|
| 708 |
+
const { setPanel, lockPanel, getJobUrl, getEditedScript, setJobStatus, getJobStatus, getTrackioDashboard, setToolError, getToolError, setToolRejected, getToolRejected } = useAgentStore();
|
| 709 |
const researchAgents = useAgentStore(s => {
|
| 710 |
const activeId = s.activeSessionId;
|
| 711 |
return (activeId && s.sessionStates[activeId]?.researchAgents) || EMPTY_AGENTS;
|
|
|
|
| 1251 |
<ResearchSteps steps={researchAgents[tool.toolCallId].steps} />
|
| 1252 |
)}
|
| 1253 |
|
| 1254 |
+
{/* Trackio dashboard embed — shown for hf_jobs / sandbox_create runs that declared a trackio space */}
|
| 1255 |
+
{(tool.toolName === 'hf_jobs' || tool.toolName === 'sandbox_create')
|
| 1256 |
+
&& !isPending
|
| 1257 |
+
&& !isRejected
|
| 1258 |
+
&& !cancelled
|
| 1259 |
+
&& (() => {
|
| 1260 |
+
const trackio = getTrackioDashboard(tool.toolCallId);
|
| 1261 |
+
return trackio
|
| 1262 |
+
? <TrackioEmbed spaceId={trackio.spaceId} project={trackio.project} />
|
| 1263 |
+
: null;
|
| 1264 |
+
})()}
|
| 1265 |
+
|
| 1266 |
{/* Per-tool approval: undecided */}
|
| 1267 |
{isPending && !localDecision && !isSubmitting && (
|
| 1268 |
<InlineApproval
|
frontend/src/components/JobsUpgradeDialog.tsx
CHANGED
|
@@ -8,7 +8,6 @@ import {
|
|
| 8 |
DialogContentText,
|
| 9 |
DialogTitle,
|
| 10 |
FormControl,
|
| 11 |
-
InputLabel,
|
| 12 |
MenuItem,
|
| 13 |
Select,
|
| 14 |
Typography,
|
|
@@ -37,13 +36,20 @@ export default function JobsUpgradeDialog({
|
|
| 37 |
onClose,
|
| 38 |
onContinueWithNamespace,
|
| 39 |
}: JobsUpgradeDialogProps) {
|
| 40 |
-
const [selectedNamespace, setSelectedNamespace] = useState('');
|
| 41 |
|
| 42 |
useEffect(() => {
|
| 43 |
if (!open) return;
|
| 44 |
setSelectedNamespace(eligibleNamespaces[0] || '');
|
| 45 |
}, [open, eligibleNamespaces]);
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
return (
|
| 48 |
<Dialog
|
| 49 |
open={open}
|
|
@@ -57,7 +63,7 @@ export default function JobsUpgradeDialog({
|
|
| 57 |
border: '1px solid var(--border)',
|
| 58 |
borderRadius: 'var(--radius-md)',
|
| 59 |
boxShadow: 'var(--shadow-1)',
|
| 60 |
-
maxWidth:
|
| 61 |
mx: 2,
|
| 62 |
},
|
| 63 |
}}
|
|
@@ -65,72 +71,75 @@ export default function JobsUpgradeDialog({
|
|
| 65 |
<DialogTitle
|
| 66 |
sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
|
| 67 |
>
|
| 68 |
-
{
|
| 69 |
</DialogTitle>
|
| 70 |
<DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
|
| 71 |
<DialogContentText
|
| 72 |
sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
|
| 73 |
>
|
| 74 |
-
{
|
| 75 |
</DialogContentText>
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
border: '1px solid var(--border)',
|
| 84 |
-
}}
|
| 85 |
-
>
|
| 86 |
-
<Typography
|
| 87 |
-
variant="caption"
|
| 88 |
sx={{
|
| 89 |
-
|
| 90 |
-
fontWeight: 700,
|
| 91 |
color: 'var(--text)',
|
| 92 |
-
fontSize: '0.
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
}}
|
| 96 |
>
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
onChange={(e) => setSelectedNamespace(String(e.target.value))}
|
| 107 |
>
|
| 108 |
-
{
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
<Typography
|
| 117 |
variant="caption"
|
| 118 |
-
sx={{
|
| 119 |
>
|
| 120 |
-
{eligibleNamespaces.join(', ')}
|
| 121 |
</Typography>
|
| 122 |
-
|
| 123 |
-
|
| 124 |
)}
|
| 125 |
-
<Typography
|
| 126 |
-
variant="caption"
|
| 127 |
-
sx={{ display: 'block', mt: 2, color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
|
| 128 |
-
>
|
| 129 |
-
If you decline, the agent will have to find another way forward without `hf_jobs`.
|
| 130 |
-
</Typography>
|
| 131 |
</DialogContent>
|
| 132 |
-
<DialogActions sx={{ px: 3, pb: 2.5, pt: 2, gap: 1 }}>
|
| 133 |
-
{
|
| 134 |
<Button
|
| 135 |
onClick={() => onContinueWithNamespace(selectedNamespace)}
|
| 136 |
disabled={!selectedNamespace}
|
|
@@ -147,7 +156,7 @@ export default function JobsUpgradeDialog({
|
|
| 147 |
'&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
|
| 148 |
}}
|
| 149 |
>
|
| 150 |
-
|
| 151 |
</Button>
|
| 152 |
) : (
|
| 153 |
<Button
|
|
@@ -183,7 +192,7 @@ export default function JobsUpgradeDialog({
|
|
| 183 |
'&:hover': { bgcolor: 'var(--hover-bg)' },
|
| 184 |
}}
|
| 185 |
>
|
| 186 |
-
Decline tool call
|
| 187 |
</Button>
|
| 188 |
</DialogActions>
|
| 189 |
</Dialog>
|
|
|
|
| 8 |
DialogContentText,
|
| 9 |
DialogTitle,
|
| 10 |
FormControl,
|
|
|
|
| 11 |
MenuItem,
|
| 12 |
Select,
|
| 13 |
Typography,
|
|
|
|
| 36 |
onClose,
|
| 37 |
onContinueWithNamespace,
|
| 38 |
}: JobsUpgradeDialogProps) {
|
| 39 |
+
const [selectedNamespace, setSelectedNamespace] = useState(() => eligibleNamespaces[0] || '');
|
| 40 |
|
| 41 |
useEffect(() => {
|
| 42 |
if (!open) return;
|
| 43 |
setSelectedNamespace(eligibleNamespaces[0] || '');
|
| 44 |
}, [open, eligibleNamespaces]);
|
| 45 |
|
| 46 |
+
const isNamespace = mode === 'namespace';
|
| 47 |
+
const title = isNamespace ? 'Run jobs as' : 'Jobs need Pro or a paid org';
|
| 48 |
+
|
| 49 |
+
const body = isNamespace
|
| 50 |
+
? "Pick which paid organization should pay for and own this job. We'll use the same one for the rest of this browser."
|
| 51 |
+
: message;
|
| 52 |
+
|
| 53 |
return (
|
| 54 |
<Dialog
|
| 55 |
open={open}
|
|
|
|
| 63 |
border: '1px solid var(--border)',
|
| 64 |
borderRadius: 'var(--radius-md)',
|
| 65 |
boxShadow: 'var(--shadow-1)',
|
| 66 |
+
maxWidth: 460,
|
| 67 |
mx: 2,
|
| 68 |
},
|
| 69 |
}}
|
|
|
|
| 71 |
<DialogTitle
|
| 72 |
sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
|
| 73 |
>
|
| 74 |
+
{title}
|
| 75 |
</DialogTitle>
|
| 76 |
<DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
|
| 77 |
<DialogContentText
|
| 78 |
sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
|
| 79 |
>
|
| 80 |
+
{body}
|
| 81 |
</DialogContentText>
|
| 82 |
+
|
| 83 |
+
{isNamespace ? (
|
| 84 |
+
<FormControl fullWidth size="small" sx={{ mt: 2 }}>
|
| 85 |
+
<Select
|
| 86 |
+
value={selectedNamespace}
|
| 87 |
+
displayEmpty
|
| 88 |
+
onChange={(e) => setSelectedNamespace(String(e.target.value))}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
sx={{
|
| 90 |
+
bgcolor: 'var(--composer-bg)',
|
|
|
|
| 91 |
color: 'var(--text)',
|
| 92 |
+
fontSize: '0.88rem',
|
| 93 |
+
fontWeight: 600,
|
| 94 |
+
'& .MuiOutlinedInput-notchedOutline': { borderColor: 'var(--border)' },
|
| 95 |
+
'&:hover .MuiOutlinedInput-notchedOutline': { borderColor: 'var(--border)' },
|
| 96 |
+
'&.Mui-focused .MuiOutlinedInput-notchedOutline': {
|
| 97 |
+
borderColor: 'var(--accent-yellow)',
|
| 98 |
+
borderWidth: 1,
|
| 99 |
+
},
|
| 100 |
+
'& .MuiSelect-icon': { color: 'var(--muted-text)' },
|
| 101 |
+
}}
|
| 102 |
+
MenuProps={{
|
| 103 |
+
PaperProps: {
|
| 104 |
+
sx: {
|
| 105 |
+
bgcolor: 'var(--panel)',
|
| 106 |
+
border: '1px solid var(--border)',
|
| 107 |
+
borderRadius: '8px',
|
| 108 |
+
mt: 0.5,
|
| 109 |
+
},
|
| 110 |
+
},
|
| 111 |
}}
|
| 112 |
>
|
| 113 |
+
{eligibleNamespaces.map((namespace) => (
|
| 114 |
+
<MenuItem
|
| 115 |
+
key={namespace}
|
| 116 |
+
value={namespace}
|
| 117 |
+
sx={{
|
| 118 |
+
fontSize: '0.88rem',
|
| 119 |
+
color: 'var(--text)',
|
| 120 |
+
'&.Mui-selected': { bgcolor: 'rgba(255,255,255,0.05)' },
|
| 121 |
+
}}
|
|
|
|
| 122 |
>
|
| 123 |
+
{namespace}
|
| 124 |
+
</MenuItem>
|
| 125 |
+
))}
|
| 126 |
+
</Select>
|
| 127 |
+
</FormControl>
|
| 128 |
+
) : (
|
| 129 |
+
eligibleNamespaces.length > 0 && (
|
| 130 |
+
<Box sx={{ mt: 1.5 }}>
|
| 131 |
<Typography
|
| 132 |
variant="caption"
|
| 133 |
+
sx={{ color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
|
| 134 |
>
|
| 135 |
+
Eligible namespaces: {eligibleNamespaces.join(', ')}
|
| 136 |
</Typography>
|
| 137 |
+
</Box>
|
| 138 |
+
)
|
| 139 |
)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
</DialogContent>
|
| 141 |
+
<DialogActions sx={{ px: 3, pb: 2.5, pt: 2.5, gap: 1 }}>
|
| 142 |
+
{isNamespace ? (
|
| 143 |
<Button
|
| 144 |
onClick={() => onContinueWithNamespace(selectedNamespace)}
|
| 145 |
disabled={!selectedNamespace}
|
|
|
|
| 156 |
'&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
|
| 157 |
}}
|
| 158 |
>
|
| 159 |
+
Continue
|
| 160 |
</Button>
|
| 161 |
) : (
|
| 162 |
<Button
|
|
|
|
| 192 |
'&:hover': { bgcolor: 'var(--hover-bg)' },
|
| 193 |
}}
|
| 194 |
>
|
| 195 |
+
{isNamespace ? 'Skip this tool call' : 'Decline tool call'}
|
| 196 |
</Button>
|
| 197 |
</DialogActions>
|
| 198 |
</Dialog>
|
frontend/src/components/SessionSidebar/SessionSidebar.tsx
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import { useCallback, useState } from 'react';
|
| 2 |
import {
|
| 3 |
Alert,
|
| 4 |
Box,
|
|
@@ -25,13 +25,30 @@ interface SessionSidebarProps {
|
|
| 25 |
}
|
| 26 |
|
| 27 |
export default function SessionSidebar({ onClose }: SessionSidebarProps) {
|
| 28 |
-
const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
|
| 29 |
useSessionStore();
|
| 30 |
const { setPlan, clearPanel } =
|
| 31 |
useAgentStore();
|
| 32 |
const [isCreatingSession, setIsCreatingSession] = useState(false);
|
| 33 |
const [capacityError, setCapacityError] = useState<string | null>(null);
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
// -- Handlers -----------------------------------------------------------
|
| 36 |
|
| 37 |
const handleNewSession = useCallback(async () => {
|
|
|
|
| 1 |
+
import { useCallback, useEffect, useState } from 'react';
|
| 2 |
import {
|
| 3 |
Alert,
|
| 4 |
Box,
|
|
|
|
| 25 |
}
|
| 26 |
|
| 27 |
export default function SessionSidebar({ onClose }: SessionSidebarProps) {
|
| 28 |
+
const { sessions, activeSessionId, createSession, deleteSession, switchSession, mergeServerSessions } =
|
| 29 |
useSessionStore();
|
| 30 |
const { setPlan, clearPanel } =
|
| 31 |
useAgentStore();
|
| 32 |
const [isCreatingSession, setIsCreatingSession] = useState(false);
|
| 33 |
const [capacityError, setCapacityError] = useState<string | null>(null);
|
| 34 |
|
| 35 |
+
useEffect(() => {
|
| 36 |
+
let cancelled = false;
|
| 37 |
+
(async () => {
|
| 38 |
+
try {
|
| 39 |
+
const response = await apiFetch('/api/sessions');
|
| 40 |
+
if (!response.ok) return;
|
| 41 |
+
const data = await response.json();
|
| 42 |
+
if (!cancelled && Array.isArray(data)) {
|
| 43 |
+
mergeServerSessions(data);
|
| 44 |
+
}
|
| 45 |
+
} catch {
|
| 46 |
+
/* local sidebar metadata is still usable */
|
| 47 |
+
}
|
| 48 |
+
})();
|
| 49 |
+
return () => { cancelled = true; };
|
| 50 |
+
}, [mergeServerSessions]);
|
| 51 |
+
|
| 52 |
// -- Handlers -----------------------------------------------------------
|
| 53 |
|
| 54 |
const handleNewSession = useCallback(async () => {
|
frontend/src/components/WelcomeScreen/WelcomeScreen.tsx
CHANGED
|
@@ -280,6 +280,12 @@ export default function WelcomeScreen() {
|
|
| 280 |
: '';
|
| 281 |
|
| 282 |
return (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
<Box
|
| 284 |
sx={{
|
| 285 |
width: '100%',
|
|
@@ -287,172 +293,182 @@ export default function WelcomeScreen() {
|
|
| 287 |
display: 'flex',
|
| 288 |
flexDirection: 'column',
|
| 289 |
alignItems: 'center',
|
| 290 |
-
|
| 291 |
background: 'var(--body-gradient)',
|
| 292 |
-
py: 8,
|
| 293 |
}}
|
| 294 |
>
|
| 295 |
-
{/* Logo */}
|
| 296 |
<Box
|
| 297 |
-
component="img"
|
| 298 |
-
src="/smolagents.webp"
|
| 299 |
-
alt="smolagents"
|
| 300 |
-
sx={{ width: 80, height: 80, mb: 2.5, display: 'block' }}
|
| 301 |
-
/>
|
| 302 |
-
|
| 303 |
-
{/* Title */}
|
| 304 |
-
<Typography
|
| 305 |
-
variant="h2"
|
| 306 |
sx={{
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
|
|
|
| 312 |
}}
|
| 313 |
>
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
>
|
| 331 |
-
Your personal <strong>ML agent</strong>. It reads <strong>papers</strong>, finds <strong>datasets</strong>, trains <strong>models</strong>, and iterates until the numbers go up. Instructions in. Trained model out.
|
| 332 |
-
</Typography>
|
| 333 |
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
<
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
description="Get free access to GPUs, inference APIs, and Hub resources."
|
| 366 |
-
status={isOrgMember ? 'completed' : 'active'}
|
| 367 |
-
actionLabel="Join Organization"
|
| 368 |
-
actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
|
| 369 |
-
onAction={handleJoinOrg}
|
| 370 |
-
/>
|
| 371 |
-
<ChecklistStep
|
| 372 |
-
stepNumber={2}
|
| 373 |
-
title="Open ML Intern"
|
| 374 |
-
description="Open the agent in a full browser tab to get started."
|
| 375 |
-
status={isOrgMember ? 'active' : 'locked'}
|
| 376 |
-
lockedReason="Join the organization first."
|
| 377 |
-
actionLabel="Open ML Intern"
|
| 378 |
-
actionIcon={<OpenInNewIcon sx={{ fontSize: 16 }} />}
|
| 379 |
-
actionHref={spaceHost}
|
| 380 |
-
isLast
|
| 381 |
-
/>
|
| 382 |
-
</>
|
| 383 |
-
) : (
|
| 384 |
-
/* Direct access: 3 steps */
|
| 385 |
-
<>
|
| 386 |
<ChecklistStep
|
| 387 |
stepNumber={1}
|
| 388 |
-
title="Sign in with Hugging Face"
|
| 389 |
-
description="Authenticate to access GPU resources and model APIs."
|
| 390 |
-
status={signInStatus}
|
| 391 |
-
actionLabel="Sign in"
|
| 392 |
-
actionIcon={<LoginIcon sx={{ fontSize: 16 }} />}
|
| 393 |
-
onAction={() => triggerLogin()}
|
| 394 |
-
/>
|
| 395 |
-
<ChecklistStep
|
| 396 |
-
stepNumber={2}
|
| 397 |
-
title="Join ML Agent Explorers"
|
| 398 |
-
description="Get free access to GPUs, inference APIs, and Hub resources."
|
| 399 |
-
status={joinOrgStatus}
|
| 400 |
-
lockedReason="Sign in first to continue."
|
| 401 |
-
actionLabel="Join Organization"
|
| 402 |
-
actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
|
| 403 |
-
onAction={handleJoinOrg}
|
| 404 |
-
/>
|
| 405 |
-
<ChecklistStep
|
| 406 |
-
stepNumber={3}
|
| 407 |
title="Start Session"
|
| 408 |
description="Launch an AI agent session for ML engineering."
|
| 409 |
-
status=
|
| 410 |
-
lockedReason="Complete the steps above to continue."
|
| 411 |
actionLabel="Start Session"
|
| 412 |
actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
|
| 413 |
onAction={handleStartSession}
|
| 414 |
loading={isCreating}
|
| 415 |
isLast
|
| 416 |
/>
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
)}
|
| 419 |
-
</Box>
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
<Typography
|
| 424 |
variant="caption"
|
| 425 |
-
sx={{ mt:
|
| 426 |
>
|
| 427 |
-
|
| 428 |
</Typography>
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
{/* Error */}
|
| 432 |
-
{error && (
|
| 433 |
-
<Alert
|
| 434 |
-
severity="warning"
|
| 435 |
-
variant="outlined"
|
| 436 |
-
onClose={() => setError(null)}
|
| 437 |
-
sx={{
|
| 438 |
-
mt: 3,
|
| 439 |
-
maxWidth: 400,
|
| 440 |
-
fontSize: '0.8rem',
|
| 441 |
-
borderColor: HF_ORANGE,
|
| 442 |
-
color: 'var(--text)',
|
| 443 |
-
}}
|
| 444 |
-
>
|
| 445 |
-
{error}
|
| 446 |
-
</Alert>
|
| 447 |
-
)}
|
| 448 |
-
|
| 449 |
-
{/* Footnote */}
|
| 450 |
-
<Typography
|
| 451 |
-
variant="caption"
|
| 452 |
-
sx={{ mt: 4, color: 'var(--muted-text)', opacity: 0.5, fontSize: '0.7rem' }}
|
| 453 |
-
>
|
| 454 |
-
Conversations are stored locally in your browser.
|
| 455 |
-
</Typography>
|
| 456 |
</Box>
|
| 457 |
);
|
| 458 |
}
|
|
|
|
| 280 |
: '';
|
| 281 |
|
| 282 |
return (
|
| 283 |
+
// Outer container scrolls; inner uses `margin: auto` so the checklist
|
| 284 |
+
// centers vertically when the viewport has room and falls back to top-
|
| 285 |
+
// aligned + scrollable when it doesn't. The previous setup hardcoded
|
| 286 |
+
// `justify-content: center` with no overflow, so on short viewports
|
| 287 |
+
// (1366×768 Chrome was the reported case) the bottom of the card —
|
| 288 |
+
// including the "Start session" CTA — got clipped with no way to scroll.
|
| 289 |
<Box
|
| 290 |
sx={{
|
| 291 |
width: '100%',
|
|
|
|
| 293 |
display: 'flex',
|
| 294 |
flexDirection: 'column',
|
| 295 |
alignItems: 'center',
|
| 296 |
+
overflowY: 'auto',
|
| 297 |
background: 'var(--body-gradient)',
|
|
|
|
| 298 |
}}
|
| 299 |
>
|
|
|
|
| 300 |
<Box
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
sx={{
|
| 302 |
+
display: 'flex',
|
| 303 |
+
flexDirection: 'column',
|
| 304 |
+
alignItems: 'center',
|
| 305 |
+
width: '100%',
|
| 306 |
+
margin: 'auto',
|
| 307 |
+
py: 8,
|
| 308 |
}}
|
| 309 |
>
|
| 310 |
+
{/* Logo */}
|
| 311 |
+
<Box
|
| 312 |
+
component="img"
|
| 313 |
+
src="/smolagents.webp"
|
| 314 |
+
alt="smolagents"
|
| 315 |
+
sx={{ width: 80, height: 80, mb: 2.5, display: 'block' }}
|
| 316 |
+
/>
|
| 317 |
|
| 318 |
+
{/* Title */}
|
| 319 |
+
<Typography
|
| 320 |
+
variant="h2"
|
| 321 |
+
sx={{
|
| 322 |
+
fontWeight: 800,
|
| 323 |
+
color: 'var(--text)',
|
| 324 |
+
mb: 1,
|
| 325 |
+
letterSpacing: '-0.02em',
|
| 326 |
+
fontSize: { xs: '1.8rem', md: '2.4rem' },
|
| 327 |
+
}}
|
| 328 |
+
>
|
| 329 |
+
ML Intern
|
| 330 |
+
</Typography>
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
+
{/* Description */}
|
| 333 |
+
<Typography
|
| 334 |
+
variant="body1"
|
| 335 |
+
sx={{
|
| 336 |
+
color: 'var(--muted-text)',
|
| 337 |
+
maxWidth: 480,
|
| 338 |
+
mb: 4,
|
| 339 |
+
lineHeight: 1.7,
|
| 340 |
+
fontSize: '0.9rem',
|
| 341 |
+
textAlign: 'center',
|
| 342 |
+
px: 2,
|
| 343 |
+
'& strong': { color: 'var(--text)', fontWeight: 600 },
|
| 344 |
+
}}
|
| 345 |
+
>
|
| 346 |
+
Your personal <strong>ML agent</strong>. It reads <strong>papers</strong>, finds <strong>datasets</strong>, trains <strong>models</strong>, and iterates until the numbers go up. Instructions in. Trained model out.
|
| 347 |
+
</Typography>
|
| 348 |
+
|
| 349 |
+
{/* ── Checklist ──────────────────────────────────────────── */}
|
| 350 |
+
<Box
|
| 351 |
+
sx={{
|
| 352 |
+
width: '100%',
|
| 353 |
+
maxWidth: 520,
|
| 354 |
+
bgcolor: 'var(--surface)',
|
| 355 |
+
border: '1px solid var(--border)',
|
| 356 |
+
borderRadius: '12px',
|
| 357 |
+
overflow: 'hidden',
|
| 358 |
+
mx: 2,
|
| 359 |
+
}}
|
| 360 |
+
>
|
| 361 |
+
{isDevUser ? (
|
| 362 |
+
/* Dev mode: single step */
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
<ChecklistStep
|
| 364 |
stepNumber={1}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
title="Start Session"
|
| 366 |
description="Launch an AI agent session for ML engineering."
|
| 367 |
+
status="active"
|
|
|
|
| 368 |
actionLabel="Start Session"
|
| 369 |
actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
|
| 370 |
onAction={handleStartSession}
|
| 371 |
loading={isCreating}
|
| 372 |
isLast
|
| 373 |
/>
|
| 374 |
+
) : inIframe ? (
|
| 375 |
+
/* Iframe: 2 steps */
|
| 376 |
+
<>
|
| 377 |
+
<ChecklistStep
|
| 378 |
+
stepNumber={1}
|
| 379 |
+
title="Join ML Agent Explorers"
|
| 380 |
+
description="Get free access to GPUs, inference APIs, and Hub resources."
|
| 381 |
+
status={isOrgMember ? 'completed' : 'active'}
|
| 382 |
+
actionLabel="Join Organization"
|
| 383 |
+
actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
|
| 384 |
+
onAction={handleJoinOrg}
|
| 385 |
+
/>
|
| 386 |
+
<ChecklistStep
|
| 387 |
+
stepNumber={2}
|
| 388 |
+
title="Open ML Intern"
|
| 389 |
+
description="Open the agent in a full browser tab to get started."
|
| 390 |
+
status={isOrgMember ? 'active' : 'locked'}
|
| 391 |
+
lockedReason="Join the organization first."
|
| 392 |
+
actionLabel="Open ML Intern"
|
| 393 |
+
actionIcon={<OpenInNewIcon sx={{ fontSize: 16 }} />}
|
| 394 |
+
actionHref={spaceHost}
|
| 395 |
+
isLast
|
| 396 |
+
/>
|
| 397 |
+
</>
|
| 398 |
+
) : (
|
| 399 |
+
/* Direct access: 3 steps */
|
| 400 |
+
<>
|
| 401 |
+
<ChecklistStep
|
| 402 |
+
stepNumber={1}
|
| 403 |
+
title="Sign in with Hugging Face"
|
| 404 |
+
description="Authenticate to access GPU resources and model APIs."
|
| 405 |
+
status={signInStatus}
|
| 406 |
+
actionLabel="Sign in"
|
| 407 |
+
actionIcon={<LoginIcon sx={{ fontSize: 16 }} />}
|
| 408 |
+
onAction={() => triggerLogin()}
|
| 409 |
+
/>
|
| 410 |
+
<ChecklistStep
|
| 411 |
+
stepNumber={2}
|
| 412 |
+
title="Join ML Agent Explorers"
|
| 413 |
+
description="Get free access to GPUs, inference APIs, and Hub resources."
|
| 414 |
+
status={joinOrgStatus}
|
| 415 |
+
lockedReason="Sign in first to continue."
|
| 416 |
+
actionLabel="Join Organization"
|
| 417 |
+
actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
|
| 418 |
+
onAction={handleJoinOrg}
|
| 419 |
+
/>
|
| 420 |
+
<ChecklistStep
|
| 421 |
+
stepNumber={3}
|
| 422 |
+
title="Start Session"
|
| 423 |
+
description="Launch an AI agent session for ML engineering."
|
| 424 |
+
status={startStatus}
|
| 425 |
+
lockedReason="Complete the steps above to continue."
|
| 426 |
+
actionLabel="Start Session"
|
| 427 |
+
actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
|
| 428 |
+
onAction={handleStartSession}
|
| 429 |
+
loading={isCreating}
|
| 430 |
+
isLast
|
| 431 |
+
/>
|
| 432 |
+
</>
|
| 433 |
+
)}
|
| 434 |
+
</Box>
|
| 435 |
+
|
| 436 |
+
{/* Polling hint when waiting for org join */}
|
| 437 |
+
{isAuthenticated && !isOrgMember && !isDevUser && !inIframe && (
|
| 438 |
+
<Typography
|
| 439 |
+
variant="caption"
|
| 440 |
+
sx={{ mt: 2, color: 'var(--muted-text)', fontSize: '0.75rem', textAlign: 'center' }}
|
| 441 |
+
>
|
| 442 |
+
This page updates automatically when you join the organization.
|
| 443 |
+
</Typography>
|
| 444 |
)}
|
|
|
|
| 445 |
|
| 446 |
+
{/* Error */}
|
| 447 |
+
{error && (
|
| 448 |
+
<Alert
|
| 449 |
+
severity="warning"
|
| 450 |
+
variant="outlined"
|
| 451 |
+
onClose={() => setError(null)}
|
| 452 |
+
sx={{
|
| 453 |
+
mt: 3,
|
| 454 |
+
maxWidth: 400,
|
| 455 |
+
fontSize: '0.8rem',
|
| 456 |
+
borderColor: HF_ORANGE,
|
| 457 |
+
color: 'var(--text)',
|
| 458 |
+
}}
|
| 459 |
+
>
|
| 460 |
+
{error}
|
| 461 |
+
</Alert>
|
| 462 |
+
)}
|
| 463 |
+
|
| 464 |
+
{/* Footnote */}
|
| 465 |
<Typography
|
| 466 |
variant="caption"
|
| 467 |
+
sx={{ mt: 4, color: 'var(--muted-text)', opacity: 0.5, fontSize: '0.7rem' }}
|
| 468 |
>
|
| 469 |
+
Conversations are stored locally in your browser.
|
| 470 |
</Typography>
|
| 471 |
+
</Box>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
</Box>
|
| 473 |
);
|
| 474 |
}
|
frontend/src/hooks/useAgentChat.ts
CHANGED
|
@@ -371,7 +371,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 371 |
} catch {
|
| 372 |
return null;
|
| 373 |
}
|
| 374 |
-
}, [sessionId, setNeedsAttention]);
|
| 375 |
|
| 376 |
// -- useChat from Vercel AI SDK -----------------------------------------
|
| 377 |
const chat = useChat({
|
|
@@ -447,6 +447,33 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 447 |
}
|
| 448 |
return;
|
| 449 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
logger.error('useChat error:', error);
|
| 451 |
if (isActiveRef.current) {
|
| 452 |
useAgentStore.getState().setError(error.message);
|
|
@@ -594,7 +621,10 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 594 |
/** Read the event stream from GET /api/events and forward to side-channel. */
|
| 595 |
const consumeEventStream = async (signal: AbortSignal) => {
|
| 596 |
try {
|
| 597 |
-
const
|
|
|
|
|
|
|
|
|
|
| 598 |
headers: { 'Accept': 'text/event-stream' },
|
| 599 |
signal,
|
| 600 |
});
|
|
@@ -602,6 +632,71 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 602 |
|
| 603 |
const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
|
| 604 |
let buf = '';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
while (true) {
|
| 606 |
const { value, done } = await reader.read();
|
| 607 |
if (done || signal.aborted) break;
|
|
@@ -609,59 +704,21 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 609 |
const lines = buf.split('\n');
|
| 610 |
buf = lines.pop() || '';
|
| 611 |
for (const line of lines) {
|
| 612 |
-
const trimmed = line.
|
| 613 |
-
if (
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
event.data?.tool as string,
|
| 628 |
-
event.data?.tool_call_id as string,
|
| 629 |
-
event.data?.output as string,
|
| 630 |
-
event.data?.success as boolean,
|
| 631 |
-
);
|
| 632 |
-
} else if (et === 'tool_state_change') {
|
| 633 |
-
const state = event.data?.state as string;
|
| 634 |
-
const toolName = event.data?.tool as string;
|
| 635 |
-
if (state === 'running' && toolName) sideChannel.onToolRunning(toolName);
|
| 636 |
-
} else if (et === 'turn_complete' || et === 'error' || et === 'interrupted') {
|
| 637 |
-
sideChannel.onProcessingDone();
|
| 638 |
-
stopReconnect();
|
| 639 |
-
// Final hydration to get the complete message state
|
| 640 |
-
const result = await hydrateMessages();
|
| 641 |
-
if (result) {
|
| 642 |
-
const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
|
| 643 |
-
if (uiMsgs.length > 0) {
|
| 644 |
-
chat.setMessages(uiMsgs);
|
| 645 |
-
saveMessages(sessionId, uiMsgs);
|
| 646 |
-
}
|
| 647 |
-
}
|
| 648 |
-
return;
|
| 649 |
-
} else if (et === 'approval_required') {
|
| 650 |
-
sideChannel.onApprovalRequired(
|
| 651 |
-
(event.data?.tools || []) as Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>,
|
| 652 |
-
);
|
| 653 |
-
stopReconnect();
|
| 654 |
-
const result = await hydrateMessages();
|
| 655 |
-
if (result) {
|
| 656 |
-
const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
|
| 657 |
-
if (uiMsgs.length > 0) {
|
| 658 |
-
chat.setMessages(uiMsgs);
|
| 659 |
-
saveMessages(sessionId, uiMsgs);
|
| 660 |
-
}
|
| 661 |
-
}
|
| 662 |
-
return;
|
| 663 |
-
}
|
| 664 |
-
} catch { /* ignore parse errors */ }
|
| 665 |
}
|
| 666 |
}
|
| 667 |
} catch {
|
|
@@ -830,6 +887,9 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 830 |
: approval.namespace,
|
| 831 |
}));
|
| 832 |
|
|
|
|
|
|
|
|
|
|
| 833 |
useAgentStore.getState().setJobsUpgradeRequired(null);
|
| 834 |
return approveTools(approvals);
|
| 835 |
}, [approveTools]);
|
|
|
|
| 371 |
} catch {
|
| 372 |
return null;
|
| 373 |
}
|
| 374 |
+
}, [sessionId, setNeedsAttention, updateSession]);
|
| 375 |
|
| 376 |
// -- useChat from Vercel AI SDK -----------------------------------------
|
| 377 |
const chat = useChat({
|
|
|
|
| 447 |
}
|
| 448 |
return;
|
| 449 |
}
|
| 450 |
+
if (error.message === 'HF_JOBS_INVALID_NAMESPACE') {
|
| 451 |
+
// Saved preference is no longer one of the user's eligible namespaces
|
| 452 |
+
// (e.g. they left the org). Clear it and reopen the picker.
|
| 453 |
+
const typed = error as Error & {
|
| 454 |
+
detail?: Record<string, unknown>;
|
| 455 |
+
approvals?: Array<{
|
| 456 |
+
tool_call_id: string;
|
| 457 |
+
approved: boolean;
|
| 458 |
+
feedback?: string | null;
|
| 459 |
+
edited_script?: string | null;
|
| 460 |
+
namespace?: string | null;
|
| 461 |
+
}>;
|
| 462 |
+
};
|
| 463 |
+
useAgentStore.getState().setPreferredJobsNamespace(null);
|
| 464 |
+
void hydrateFromBackend();
|
| 465 |
+
if (isActiveRef.current) {
|
| 466 |
+
useAgentStore.getState().setJobsUpgradeRequired({
|
| 467 |
+
approvals: typed.approvals || [],
|
| 468 |
+
toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
|
| 469 |
+
message: String(typed.detail?.message || 'Pick a different organization for this job run.'),
|
| 470 |
+
eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
|
| 471 |
+
plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
|
| 472 |
+
mode: 'namespace',
|
| 473 |
+
});
|
| 474 |
+
}
|
| 475 |
+
return;
|
| 476 |
+
}
|
| 477 |
logger.error('useChat error:', error);
|
| 478 |
if (isActiveRef.current) {
|
| 479 |
useAgentStore.getState().setError(error.message);
|
|
|
|
| 621 |
/** Read the event stream from GET /api/events and forward to side-channel. */
|
| 622 |
const consumeEventStream = async (signal: AbortSignal) => {
|
| 623 |
try {
|
| 624 |
+
const lastEventKey = `hf-agent-last-event:${sessionId}`;
|
| 625 |
+
const lastSeq = localStorage.getItem(lastEventKey);
|
| 626 |
+
const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
|
| 627 |
+
const res = await apiFetch(`/api/events/${sessionId}${qs}`, {
|
| 628 |
headers: { 'Accept': 'text/event-stream' },
|
| 629 |
signal,
|
| 630 |
});
|
|
|
|
| 632 |
|
| 633 |
const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
|
| 634 |
let buf = '';
|
| 635 |
+
let eventId: string | null = null;
|
| 636 |
+
let eventData = '';
|
| 637 |
+
const dispatch = async () => {
|
| 638 |
+
if (!eventData.trim()) {
|
| 639 |
+
eventId = null;
|
| 640 |
+
eventData = '';
|
| 641 |
+
return false;
|
| 642 |
+
}
|
| 643 |
+
const event = JSON.parse(eventData.trim());
|
| 644 |
+
const seq = event.seq ?? (eventId ? Number(eventId) : undefined);
|
| 645 |
+
if (Number.isFinite(seq)) {
|
| 646 |
+
localStorage.setItem(lastEventKey, String(seq));
|
| 647 |
+
}
|
| 648 |
+
eventId = null;
|
| 649 |
+
eventData = '';
|
| 650 |
+
// Forward to side-channel for real-time UI updates
|
| 651 |
+
const et = event.event_type as string;
|
| 652 |
+
if (et === 'processing') sideChannel.onProcessing();
|
| 653 |
+
else if (et === 'assistant_chunk') sideChannel.onStreaming();
|
| 654 |
+
else if (et === 'tool_call') {
|
| 655 |
+
const t = event.data?.tool as string;
|
| 656 |
+
const d = event.data?.arguments?.description as string | undefined;
|
| 657 |
+
sideChannel.onToolRunning(t, d);
|
| 658 |
+
sideChannel.onToolCallPanel(t, (event.data?.arguments || {}) as Record<string, unknown>);
|
| 659 |
+
} else if (et === 'tool_output') {
|
| 660 |
+
sideChannel.onToolOutputPanel(
|
| 661 |
+
event.data?.tool as string,
|
| 662 |
+
event.data?.tool_call_id as string,
|
| 663 |
+
event.data?.output as string,
|
| 664 |
+
event.data?.success as boolean,
|
| 665 |
+
);
|
| 666 |
+
} else if (et === 'tool_state_change') {
|
| 667 |
+
const state = event.data?.state as string;
|
| 668 |
+
const toolName = event.data?.tool as string;
|
| 669 |
+
if (state === 'running' && toolName) sideChannel.onToolRunning(toolName);
|
| 670 |
+
} else if (et === 'turn_complete' || et === 'error' || et === 'interrupted') {
|
| 671 |
+
sideChannel.onProcessingDone();
|
| 672 |
+
stopReconnect();
|
| 673 |
+
// Final hydration to get the complete message state
|
| 674 |
+
const result = await hydrateMessages();
|
| 675 |
+
if (result) {
|
| 676 |
+
const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
|
| 677 |
+
if (uiMsgs.length > 0) {
|
| 678 |
+
chat.setMessages(uiMsgs);
|
| 679 |
+
saveMessages(sessionId, uiMsgs);
|
| 680 |
+
}
|
| 681 |
+
}
|
| 682 |
+
return true;
|
| 683 |
+
} else if (et === 'approval_required') {
|
| 684 |
+
sideChannel.onApprovalRequired(
|
| 685 |
+
(event.data?.tools || []) as Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>,
|
| 686 |
+
);
|
| 687 |
+
stopReconnect();
|
| 688 |
+
const result = await hydrateMessages();
|
| 689 |
+
if (result) {
|
| 690 |
+
const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
|
| 691 |
+
if (uiMsgs.length > 0) {
|
| 692 |
+
chat.setMessages(uiMsgs);
|
| 693 |
+
saveMessages(sessionId, uiMsgs);
|
| 694 |
+
}
|
| 695 |
+
}
|
| 696 |
+
return true;
|
| 697 |
+
}
|
| 698 |
+
return false;
|
| 699 |
+
};
|
| 700 |
while (true) {
|
| 701 |
const { value, done } = await reader.read();
|
| 702 |
if (done || signal.aborted) break;
|
|
|
|
| 704 |
const lines = buf.split('\n');
|
| 705 |
buf = lines.pop() || '';
|
| 706 |
for (const line of lines) {
|
| 707 |
+
const trimmed = line.replace(/\r$/, '');
|
| 708 |
+
if (trimmed === '') {
|
| 709 |
+
try {
|
| 710 |
+
if (await dispatch()) return;
|
| 711 |
+
} catch { /* ignore parse errors */ }
|
| 712 |
+
continue;
|
| 713 |
+
}
|
| 714 |
+
if (trimmed.startsWith(':')) continue;
|
| 715 |
+
if (trimmed.startsWith('id:')) {
|
| 716 |
+
eventId = trimmed.slice(3).trim();
|
| 717 |
+
continue;
|
| 718 |
+
}
|
| 719 |
+
if (trimmed.startsWith('data:')) {
|
| 720 |
+
eventData += trimmed.slice(5).trimStart() + '\n';
|
| 721 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
}
|
| 723 |
}
|
| 724 |
} catch {
|
|
|
|
| 887 |
: approval.namespace,
|
| 888 |
}));
|
| 889 |
|
| 890 |
+
// Remember this choice so the picker doesn't reappear for every
|
| 891 |
+
// subsequent hf_jobs call.
|
| 892 |
+
useAgentStore.getState().setPreferredJobsNamespace(namespace);
|
| 893 |
useAgentStore.getState().setJobsUpgradeRequired(null);
|
| 894 |
return approveTools(approvals);
|
| 895 |
}, [approveTools]);
|
frontend/src/lib/sse-chat-transport.ts
CHANGED
|
@@ -42,35 +42,66 @@ function nextPartId(prefix: string): string {
|
|
| 42 |
return `${prefix}-${Date.now()}-${++partIdCounter}`;
|
| 43 |
}
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
/** Parse an SSE text stream into AgentEvent objects. */
|
| 46 |
-
function createSSEParserStream(): TransformStream<string, AgentEvent> {
|
| 47 |
let buffer = '';
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return new TransformStream<string, AgentEvent>({
|
| 49 |
transform(chunk, controller) {
|
| 50 |
buffer += chunk;
|
| 51 |
const lines = buffer.split('\n');
|
| 52 |
// Keep the last (possibly incomplete) line in the buffer
|
| 53 |
buffer = lines.pop() || '';
|
| 54 |
-
for (const
|
| 55 |
-
const
|
| 56 |
-
if (
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
}
|
| 64 |
}
|
| 65 |
},
|
| 66 |
flush(controller) {
|
| 67 |
-
|
| 68 |
-
if (
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
} catch { /* ignore incomplete */ }
|
| 73 |
}
|
|
|
|
| 74 |
},
|
| 75 |
});
|
| 76 |
}
|
|
@@ -226,12 +257,17 @@ function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformS
|
|
| 226 |
const state = (event.data?.state as string) || '';
|
| 227 |
const toolName = (event.data?.tool as string) || '';
|
| 228 |
const jobUrl = (event.data?.jobUrl as string) || undefined;
|
|
|
|
|
|
|
| 229 |
|
| 230 |
if (tcId.startsWith('plan_tool')) break;
|
| 231 |
|
| 232 |
if (jobUrl && tcId) {
|
| 233 |
useAgentStore.getState().setJobUrl(tcId, jobUrl);
|
| 234 |
}
|
|
|
|
|
|
|
|
|
|
| 235 |
if (state === 'running' && toolName) {
|
| 236 |
sideChannel.onToolRunning(toolName);
|
| 237 |
}
|
|
@@ -320,7 +356,14 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 320 |
const approved = p.approval?.approved ?? true;
|
| 321 |
// Get edited script from agentStore if available
|
| 322 |
const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
|
| 323 |
-
const
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
return {
|
| 325 |
tool_call_id: p.toolCallId,
|
| 326 |
approved,
|
|
@@ -388,6 +431,20 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 388 |
throw err;
|
| 389 |
}
|
| 390 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
if (!response.ok) {
|
| 392 |
const errorText = await response.text().catch(() => 'Request failed');
|
| 393 |
throw new Error(`Chat request failed: ${response.status} ${errorText}`);
|
|
@@ -400,7 +457,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 400 |
// Pipe: response bytes → text → SSE events → UIMessageChunks
|
| 401 |
return response.body
|
| 402 |
.pipeThrough(new TextDecoderStream())
|
| 403 |
-
.pipeThrough(createSSEParserStream())
|
| 404 |
.pipeThrough(createEventToChunkStream(this.sideChannel));
|
| 405 |
}
|
| 406 |
|
|
@@ -415,7 +472,9 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 415 |
if (!info.is_processing) return null;
|
| 416 |
|
| 417 |
// Session is mid-turn — subscribe to its event broadcast.
|
| 418 |
-
const
|
|
|
|
|
|
|
| 419 |
headers: { 'Accept': 'text/event-stream' },
|
| 420 |
});
|
| 421 |
if (!response.ok || !response.body) return null;
|
|
@@ -424,7 +483,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 424 |
|
| 425 |
return response.body
|
| 426 |
.pipeThrough(new TextDecoderStream())
|
| 427 |
-
.pipeThrough(createSSEParserStream())
|
| 428 |
.pipeThrough(createEventToChunkStream(this.sideChannel));
|
| 429 |
} catch {
|
| 430 |
return null;
|
|
|
|
| 42 |
return `${prefix}-${Date.now()}-${++partIdCounter}`;
|
| 43 |
}
|
| 44 |
|
| 45 |
+
function lastEventKey(sessionId: string): string {
|
| 46 |
+
return `hf-agent-last-event:${sessionId}`;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
/** Parse an SSE text stream into AgentEvent objects. */
|
| 50 |
+
function createSSEParserStream(sessionId: string): TransformStream<string, AgentEvent> {
|
| 51 |
let buffer = '';
|
| 52 |
+
let eventId: string | null = null;
|
| 53 |
+
let data = '';
|
| 54 |
+
|
| 55 |
+
const dispatch = (controller: TransformStreamDefaultController<AgentEvent>) => {
|
| 56 |
+
if (!data.trim()) {
|
| 57 |
+
eventId = null;
|
| 58 |
+
data = '';
|
| 59 |
+
return;
|
| 60 |
+
}
|
| 61 |
+
try {
|
| 62 |
+
const json = JSON.parse(data.trim()) as AgentEvent;
|
| 63 |
+
const seq = json.seq ?? (eventId ? Number(eventId) : undefined);
|
| 64 |
+
if (Number.isFinite(seq)) {
|
| 65 |
+
json.seq = seq;
|
| 66 |
+
localStorage.setItem(lastEventKey(sessionId), String(seq));
|
| 67 |
+
}
|
| 68 |
+
controller.enqueue(json);
|
| 69 |
+
} catch {
|
| 70 |
+
logger.warn('SSE parse error:', data.trim());
|
| 71 |
+
} finally {
|
| 72 |
+
eventId = null;
|
| 73 |
+
data = '';
|
| 74 |
+
}
|
| 75 |
+
};
|
| 76 |
+
|
| 77 |
return new TransformStream<string, AgentEvent>({
|
| 78 |
transform(chunk, controller) {
|
| 79 |
buffer += chunk;
|
| 80 |
const lines = buffer.split('\n');
|
| 81 |
// Keep the last (possibly incomplete) line in the buffer
|
| 82 |
buffer = lines.pop() || '';
|
| 83 |
+
for (const rawLine of lines) {
|
| 84 |
+
const line = rawLine.replace(/\r$/, '');
|
| 85 |
+
if (line === '') {
|
| 86 |
+
dispatch(controller);
|
| 87 |
+
continue;
|
| 88 |
+
}
|
| 89 |
+
if (line.startsWith(':')) continue;
|
| 90 |
+
if (line.startsWith('id:')) {
|
| 91 |
+
eventId = line.slice(3).trim();
|
| 92 |
+
} else if (line.startsWith('data:')) {
|
| 93 |
+
data += line.slice(5).trimStart() + '\n';
|
| 94 |
}
|
| 95 |
}
|
| 96 |
},
|
| 97 |
flush(controller) {
|
| 98 |
+
const line = buffer.replace(/\r$/, '');
|
| 99 |
+
if (line.startsWith('id:')) {
|
| 100 |
+
eventId = line.slice(3).trim();
|
| 101 |
+
} else if (line.startsWith('data:')) {
|
| 102 |
+
data += line.slice(5).trimStart() + '\n';
|
|
|
|
| 103 |
}
|
| 104 |
+
dispatch(controller);
|
| 105 |
},
|
| 106 |
});
|
| 107 |
}
|
|
|
|
| 257 |
const state = (event.data?.state as string) || '';
|
| 258 |
const toolName = (event.data?.tool as string) || '';
|
| 259 |
const jobUrl = (event.data?.jobUrl as string) || undefined;
|
| 260 |
+
const trackioSpaceId = (event.data?.trackioSpaceId as string) || undefined;
|
| 261 |
+
const trackioProject = (event.data?.trackioProject as string) || undefined;
|
| 262 |
|
| 263 |
if (tcId.startsWith('plan_tool')) break;
|
| 264 |
|
| 265 |
if (jobUrl && tcId) {
|
| 266 |
useAgentStore.getState().setJobUrl(tcId, jobUrl);
|
| 267 |
}
|
| 268 |
+
if (trackioSpaceId && tcId) {
|
| 269 |
+
useAgentStore.getState().setTrackioDashboard(tcId, trackioSpaceId, trackioProject);
|
| 270 |
+
}
|
| 271 |
if (state === 'running' && toolName) {
|
| 272 |
sideChannel.onToolRunning(toolName);
|
| 273 |
}
|
|
|
|
| 356 |
const approved = p.approval?.approved ?? true;
|
| 357 |
// Get edited script from agentStore if available
|
| 358 |
const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
|
| 359 |
+
const explicitNamespace = useAgentStore.getState().getApprovalNamespace(p.toolCallId);
|
| 360 |
+
// Fall back to the user's persisted choice so we don't re-prompt
|
| 361 |
+
// every hf_jobs call. Backend will 400 if the saved namespace is
|
| 362 |
+
// no longer valid; the error handler clears the preference and
|
| 363 |
+
// reopens the picker.
|
| 364 |
+
const preferred = useAgentStore.getState().preferredJobsNamespace;
|
| 365 |
+
const namespace = explicitNamespace
|
| 366 |
+
?? (approved && p.toolName === 'hf_jobs' ? preferred ?? null : null);
|
| 367 |
return {
|
| 368 |
tool_call_id: p.toolCallId,
|
| 369 |
approved,
|
|
|
|
| 431 |
throw err;
|
| 432 |
}
|
| 433 |
}
|
| 434 |
+
if (response.status === 400) {
|
| 435 |
+
const payload = await response.json().catch(() => null);
|
| 436 |
+
if (payload?.detail?.error === 'hf_jobs_invalid_namespace') {
|
| 437 |
+
// Stored namespace is no longer eligible — surface so the UI can
|
| 438 |
+
// clear the saved preference and reopen the picker.
|
| 439 |
+
const err = new Error('HF_JOBS_INVALID_NAMESPACE') as Error & {
|
| 440 |
+
detail?: Record<string, unknown>;
|
| 441 |
+
approvals?: Array<Record<string, unknown>>;
|
| 442 |
+
};
|
| 443 |
+
err.detail = payload.detail as Record<string, unknown>;
|
| 444 |
+
err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
|
| 445 |
+
throw err;
|
| 446 |
+
}
|
| 447 |
+
}
|
| 448 |
if (!response.ok) {
|
| 449 |
const errorText = await response.text().catch(() => 'Request failed');
|
| 450 |
throw new Error(`Chat request failed: ${response.status} ${errorText}`);
|
|
|
|
| 457 |
// Pipe: response bytes → text → SSE events → UIMessageChunks
|
| 458 |
return response.body
|
| 459 |
.pipeThrough(new TextDecoderStream())
|
| 460 |
+
.pipeThrough(createSSEParserStream(sessionId))
|
| 461 |
.pipeThrough(createEventToChunkStream(this.sideChannel));
|
| 462 |
}
|
| 463 |
|
|
|
|
| 472 |
if (!info.is_processing) return null;
|
| 473 |
|
| 474 |
// Session is mid-turn — subscribe to its event broadcast.
|
| 475 |
+
const lastSeq = localStorage.getItem(lastEventKey(this.sessionId));
|
| 476 |
+
const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
|
| 477 |
+
const response = await apiFetch(`/api/events/${this.sessionId}${qs}`, {
|
| 478 |
headers: { 'Accept': 'text/event-stream' },
|
| 479 |
});
|
| 480 |
if (!response.ok || !response.body) return null;
|
|
|
|
| 483 |
|
| 484 |
return response.body
|
| 485 |
.pipeThrough(new TextDecoderStream())
|
| 486 |
+
.pipeThrough(createSSEParserStream(this.sessionId))
|
| 487 |
.pipeThrough(createEventToChunkStream(this.sideChannel));
|
| 488 |
} catch {
|
| 489 |
return null;
|
frontend/src/store/agentStore.ts
CHANGED
|
@@ -141,12 +141,21 @@ interface AgentStore {
|
|
| 141 |
// Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
|
| 142 |
approvalNamespaces: Record<string, string>;
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
// Job URLs (tool_call_id -> job URL) for HF jobs
|
| 145 |
jobUrls: Record<string, string>;
|
| 146 |
|
| 147 |
// Job statuses (tool_call_id -> job status) for HF jobs
|
| 148 |
jobStatuses: Record<string, string>;
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
// Tool error states (tool_call_id -> true if errored) - persisted across renders
|
| 151 |
toolErrors: Record<string, boolean>;
|
| 152 |
|
|
@@ -194,12 +203,17 @@ interface AgentStore {
|
|
| 194 |
getApprovalNamespace: (toolCallId: string) => string | undefined;
|
| 195 |
clearApprovalNamespaces: () => void;
|
| 196 |
|
|
|
|
|
|
|
| 197 |
setJobUrl: (toolCallId: string, jobUrl: string) => void;
|
| 198 |
getJobUrl: (toolCallId: string) => string | undefined;
|
| 199 |
|
| 200 |
setJobStatus: (toolCallId: string, status: string) => void;
|
| 201 |
getJobStatus: (toolCallId: string) => string | undefined;
|
| 202 |
|
|
|
|
|
|
|
|
|
|
| 203 |
setToolError: (toolCallId: string, hasError: boolean) => void;
|
| 204 |
getToolError: (toolCallId: string) => boolean | undefined;
|
| 205 |
|
|
@@ -264,6 +278,48 @@ function saveRejectedTools(rejected: Record<string, boolean>): void {
|
|
| 264 |
}
|
| 265 |
}
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
export const useAgentStore = create<AgentStore>()((set, get) => ({
|
| 268 |
sessionStates: {},
|
| 269 |
activeSessionId: null,
|
|
@@ -285,8 +341,10 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 285 |
|
| 286 |
editedScripts: {},
|
| 287 |
approvalNamespaces: {},
|
|
|
|
| 288 |
jobUrls: {},
|
| 289 |
jobStatuses: {},
|
|
|
|
| 290 |
toolErrors: loadToolErrors(),
|
| 291 |
rejectedTools: loadRejectedTools(),
|
| 292 |
|
|
@@ -465,6 +523,11 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 465 |
|
| 466 |
clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
|
| 467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
// ── Job URLs ────────────────────────────────────────────────────────
|
| 469 |
|
| 470 |
setJobUrl: (toolCallId, jobUrl) => {
|
|
@@ -485,6 +548,26 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 485 |
|
| 486 |
getJobStatus: (toolCallId) => get().jobStatuses[toolCallId],
|
| 487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
// ── Tool Errors ─────────────────────────────────────────────────────
|
| 489 |
|
| 490 |
setToolError: (toolCallId, hasError) => {
|
|
|
|
| 141 |
// Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
|
| 142 |
approvalNamespaces: Record<string, string>;
|
| 143 |
|
| 144 |
+
// Persisted preferred namespace for hf_jobs (auto-applied to future approvals
|
| 145 |
+
// so the user only picks once)
|
| 146 |
+
preferredJobsNamespace: string | null;
|
| 147 |
+
|
| 148 |
// Job URLs (tool_call_id -> job URL) for HF jobs
|
| 149 |
jobUrls: Record<string, string>;
|
| 150 |
|
| 151 |
// Job statuses (tool_call_id -> job status) for HF jobs
|
| 152 |
jobStatuses: Record<string, string>;
|
| 153 |
|
| 154 |
+
// Trackio dashboard config per tool call (tool_call_id -> {spaceId, project?})
|
| 155 |
+
// Set by hf_jobs / sandbox_create tools when the agent declares trackio_space_id;
|
| 156 |
+
// the UI uses it to embed the live dashboard via an iframe.
|
| 157 |
+
trackioDashboards: Record<string, { spaceId: string; project?: string }>;
|
| 158 |
+
|
| 159 |
// Tool error states (tool_call_id -> true if errored) - persisted across renders
|
| 160 |
toolErrors: Record<string, boolean>;
|
| 161 |
|
|
|
|
| 203 |
getApprovalNamespace: (toolCallId: string) => string | undefined;
|
| 204 |
clearApprovalNamespaces: () => void;
|
| 205 |
|
| 206 |
+
setPreferredJobsNamespace: (namespace: string | null) => void;
|
| 207 |
+
|
| 208 |
setJobUrl: (toolCallId: string, jobUrl: string) => void;
|
| 209 |
getJobUrl: (toolCallId: string) => string | undefined;
|
| 210 |
|
| 211 |
setJobStatus: (toolCallId: string, status: string) => void;
|
| 212 |
getJobStatus: (toolCallId: string) => string | undefined;
|
| 213 |
|
| 214 |
+
setTrackioDashboard: (toolCallId: string, spaceId: string, project?: string) => void;
|
| 215 |
+
getTrackioDashboard: (toolCallId: string) => { spaceId: string; project?: string } | undefined;
|
| 216 |
+
|
| 217 |
setToolError: (toolCallId: string, hasError: boolean) => void;
|
| 218 |
getToolError: (toolCallId: string) => boolean | undefined;
|
| 219 |
|
|
|
|
| 278 |
}
|
| 279 |
}
|
| 280 |
|
| 281 |
+
// Trackio dashboards survive a page reload — without persistence the iframe
|
| 282 |
+
// disappears whenever the user refreshes mid-job, which is the exact moment
|
| 283 |
+
// they'd want to keep watching it.
|
| 284 |
+
function loadTrackioDashboards(): Record<string, { spaceId: string; project?: string }> {
|
| 285 |
+
try {
|
| 286 |
+
const stored = localStorage.getItem('hf-agent-trackio-dashboards');
|
| 287 |
+
return stored ? JSON.parse(stored) : {};
|
| 288 |
+
} catch {
|
| 289 |
+
return {};
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
function saveTrackioDashboards(dashboards: Record<string, { spaceId: string; project?: string }>): void {
|
| 294 |
+
try {
|
| 295 |
+
localStorage.setItem('hf-agent-trackio-dashboards', JSON.stringify(dashboards));
|
| 296 |
+
} catch (e) {
|
| 297 |
+
console.warn('Failed to persist trackio dashboards:', e);
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
const PREFERRED_JOBS_NAMESPACE_KEY = 'hf-agent-preferred-jobs-namespace';
|
| 302 |
+
|
| 303 |
+
function loadPreferredJobsNamespace(): string | null {
|
| 304 |
+
try {
|
| 305 |
+
return localStorage.getItem(PREFERRED_JOBS_NAMESPACE_KEY);
|
| 306 |
+
} catch {
|
| 307 |
+
return null;
|
| 308 |
+
}
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
function savePreferredJobsNamespace(namespace: string | null): void {
|
| 312 |
+
try {
|
| 313 |
+
if (namespace) {
|
| 314 |
+
localStorage.setItem(PREFERRED_JOBS_NAMESPACE_KEY, namespace);
|
| 315 |
+
} else {
|
| 316 |
+
localStorage.removeItem(PREFERRED_JOBS_NAMESPACE_KEY);
|
| 317 |
+
}
|
| 318 |
+
} catch (e) {
|
| 319 |
+
console.warn('Failed to persist preferred jobs namespace:', e);
|
| 320 |
+
}
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
export const useAgentStore = create<AgentStore>()((set, get) => ({
|
| 324 |
sessionStates: {},
|
| 325 |
activeSessionId: null,
|
|
|
|
| 341 |
|
| 342 |
editedScripts: {},
|
| 343 |
approvalNamespaces: {},
|
| 344 |
+
preferredJobsNamespace: loadPreferredJobsNamespace(),
|
| 345 |
jobUrls: {},
|
| 346 |
jobStatuses: {},
|
| 347 |
+
trackioDashboards: loadTrackioDashboards(),
|
| 348 |
toolErrors: loadToolErrors(),
|
| 349 |
rejectedTools: loadRejectedTools(),
|
| 350 |
|
|
|
|
| 523 |
|
| 524 |
clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
|
| 525 |
|
| 526 |
+
setPreferredJobsNamespace: (namespace) => {
|
| 527 |
+
savePreferredJobsNamespace(namespace);
|
| 528 |
+
set({ preferredJobsNamespace: namespace });
|
| 529 |
+
},
|
| 530 |
+
|
| 531 |
// ── Job URLs ────────────────────────────────────────────────────────
|
| 532 |
|
| 533 |
setJobUrl: (toolCallId, jobUrl) => {
|
|
|
|
| 548 |
|
| 549 |
getJobStatus: (toolCallId) => get().jobStatuses[toolCallId],
|
| 550 |
|
| 551 |
+
// ── Trackio Dashboards ──────────────────────────────────────────────
|
| 552 |
+
|
| 553 |
+
setTrackioDashboard: (toolCallId, spaceId, project) => {
|
| 554 |
+
set((state) => {
|
| 555 |
+
const existing = state.trackioDashboards[toolCallId];
|
| 556 |
+
// Don't churn the object if nothing changed (avoids extra renders).
|
| 557 |
+
if (existing && existing.spaceId === spaceId && existing.project === project) {
|
| 558 |
+
return {};
|
| 559 |
+
}
|
| 560 |
+
const updated = {
|
| 561 |
+
...state.trackioDashboards,
|
| 562 |
+
[toolCallId]: { spaceId, ...(project ? { project } : {}) },
|
| 563 |
+
};
|
| 564 |
+
saveTrackioDashboards(updated);
|
| 565 |
+
return { trackioDashboards: updated };
|
| 566 |
+
});
|
| 567 |
+
},
|
| 568 |
+
|
| 569 |
+
getTrackioDashboard: (toolCallId) => get().trackioDashboards[toolCallId],
|
| 570 |
+
|
| 571 |
// ── Tool Errors ─────────────────────────────────────────────────────
|
| 572 |
|
| 573 |
setToolError: (toolCallId, hasError) => {
|
frontend/src/store/sessionStore.ts
CHANGED
|
@@ -20,6 +20,14 @@ interface SessionStore {
|
|
| 20 |
markExpired: (id: string) => void;
|
| 21 |
/** Clear the expired flag (used after restore-with-summary succeeds). */
|
| 22 |
clearExpired: (id: string) => void;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
/** Atomically swap a session's id in the list + both localStorage caches.
|
| 24 |
* Used when we rehydrate an expired session into a freshly-created backend
|
| 25 |
* session — preserves title, timestamps, and messages. */
|
|
@@ -76,6 +84,45 @@ export const useSessionStore = create<SessionStore>()(
|
|
| 76 |
}));
|
| 77 |
},
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
renameSession: (oldId: string, newId: string) => {
|
| 80 |
if (oldId === newId) return;
|
| 81 |
moveMessages(oldId, newId);
|
|
|
|
| 20 |
markExpired: (id: string) => void;
|
| 21 |
/** Clear the expired flag (used after restore-with-summary succeeds). */
|
| 22 |
clearExpired: (id: string) => void;
|
| 23 |
+
/** Merge durable server-side sessions into local sidebar metadata. */
|
| 24 |
+
mergeServerSessions: (sessions: Array<{
|
| 25 |
+
session_id: string;
|
| 26 |
+
title?: string | null;
|
| 27 |
+
created_at: string;
|
| 28 |
+
is_active?: boolean;
|
| 29 |
+
pending_approval?: unknown[] | null;
|
| 30 |
+
}>) => void;
|
| 31 |
/** Atomically swap a session's id in the list + both localStorage caches.
|
| 32 |
* Used when we rehydrate an expired session into a freshly-created backend
|
| 33 |
* session — preserves title, timestamps, and messages. */
|
|
|
|
| 84 |
}));
|
| 85 |
},
|
| 86 |
|
| 87 |
+
mergeServerSessions: (serverSessions) => {
|
| 88 |
+
set((state) => {
|
| 89 |
+
const byId = new Map(state.sessions.map((s) => [s.id, s]));
|
| 90 |
+
const merged = [...state.sessions];
|
| 91 |
+
for (const server of serverSessions) {
|
| 92 |
+
const id = server.session_id;
|
| 93 |
+
if (!id) continue;
|
| 94 |
+
const existing = byId.get(id);
|
| 95 |
+
if (existing) {
|
| 96 |
+
const updated = {
|
| 97 |
+
...existing,
|
| 98 |
+
title: server.title || existing.title,
|
| 99 |
+
isActive: server.is_active ?? existing.isActive,
|
| 100 |
+
needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
|
| 101 |
+
expired: false,
|
| 102 |
+
};
|
| 103 |
+
const idx = merged.findIndex((s) => s.id === id);
|
| 104 |
+
if (idx >= 0) merged[idx] = updated;
|
| 105 |
+
byId.set(id, updated);
|
| 106 |
+
continue;
|
| 107 |
+
}
|
| 108 |
+
const newSession: SessionMeta = {
|
| 109 |
+
id,
|
| 110 |
+
title: server.title || `Chat ${merged.length + 1}`,
|
| 111 |
+
createdAt: server.created_at || new Date().toISOString(),
|
| 112 |
+
isActive: server.is_active ?? true,
|
| 113 |
+
needsAttention: Boolean(server.pending_approval?.length),
|
| 114 |
+
expired: false,
|
| 115 |
+
};
|
| 116 |
+
merged.push(newSession);
|
| 117 |
+
byId.set(id, newSession);
|
| 118 |
+
}
|
| 119 |
+
return {
|
| 120 |
+
sessions: merged,
|
| 121 |
+
activeSessionId: state.activeSessionId || merged[merged.length - 1]?.id || null,
|
| 122 |
+
};
|
| 123 |
+
});
|
| 124 |
+
},
|
| 125 |
+
|
| 126 |
renameSession: (oldId: string, newId: string) => {
|
| 127 |
if (oldId === newId) return;
|
| 128 |
moveMessages(oldId, newId);
|
frontend/src/types/events.ts
CHANGED
|
@@ -24,6 +24,7 @@ export type EventType =
|
|
| 24 |
export interface AgentEvent {
|
| 25 |
event_type: EventType;
|
| 26 |
data?: Record<string, unknown>;
|
|
|
|
| 27 |
}
|
| 28 |
|
| 29 |
export interface ReadyEventData {
|
|
|
|
| 24 |
export interface AgentEvent {
|
| 25 |
event_type: EventType;
|
| 26 |
data?: Record<string, unknown>;
|
| 27 |
+
seq?: number;
|
| 28 |
}
|
| 29 |
|
| 30 |
export interface ReadyEventData {
|
pyproject.toml
CHANGED
|
@@ -13,7 +13,7 @@ dependencies = [
|
|
| 13 |
"requests>=2.33.0",
|
| 14 |
"litellm>=1.83.0",
|
| 15 |
"boto3>=1.35.0",
|
| 16 |
-
"huggingface-hub>=1.
|
| 17 |
"fastmcp>=3.2.0",
|
| 18 |
"prompt-toolkit>=3.0.0",
|
| 19 |
"thefuzz>=0.22.1",
|
|
@@ -27,6 +27,7 @@ dependencies = [
|
|
| 27 |
"httpx>=0.27.0",
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
|
|
|
| 30 |
]
|
| 31 |
|
| 32 |
[project.optional-dependencies]
|
|
@@ -42,7 +43,7 @@ eval = [
|
|
| 42 |
# Development and testing dependencies
|
| 43 |
dev = [
|
| 44 |
"pytest>=9.0.2",
|
| 45 |
-
"pytest-asyncio>=
|
| 46 |
]
|
| 47 |
|
| 48 |
# All dependencies (eval + dev)
|
|
@@ -58,7 +59,20 @@ requires = ["setuptools>=64"]
|
|
| 58 |
build-backend = "setuptools.build_meta"
|
| 59 |
|
| 60 |
[tool.setuptools.packages.find]
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
[tool.uv]
|
| 64 |
package = true
|
|
|
|
| 13 |
"requests>=2.33.0",
|
| 14 |
"litellm>=1.83.0",
|
| 15 |
"boto3>=1.35.0",
|
| 16 |
+
"huggingface-hub>=1.12.0",
|
| 17 |
"fastmcp>=3.2.0",
|
| 18 |
"prompt-toolkit>=3.0.0",
|
| 19 |
"thefuzz>=0.22.1",
|
|
|
|
| 27 |
"httpx>=0.27.0",
|
| 28 |
"websockets>=13.0",
|
| 29 |
"apscheduler>=3.10,<4",
|
| 30 |
+
"pymongo>=4.17.0",
|
| 31 |
]
|
| 32 |
|
| 33 |
[project.optional-dependencies]
|
|
|
|
| 43 |
# Development and testing dependencies
|
| 44 |
dev = [
|
| 45 |
"pytest>=9.0.2",
|
| 46 |
+
"pytest-asyncio>=1.2.0",
|
| 47 |
]
|
| 48 |
|
| 49 |
# All dependencies (eval + dev)
|
|
|
|
| 59 |
build-backend = "setuptools.build_meta"
|
| 60 |
|
| 61 |
[tool.setuptools.packages.find]
|
| 62 |
+
# `configs` ships the JSON files loaded by agent.main.CLI_CONFIG_PATH at
|
| 63 |
+
# runtime (resolves to <site-packages>/configs/cli_agent_config.json).
|
| 64 |
+
# Without it, `uv tool install` / `pip install` produce a broken install
|
| 65 |
+
# that imports fine but crashes at startup with FileNotFoundError.
|
| 66 |
+
include = ["agent*", "configs"]
|
| 67 |
+
|
| 68 |
+
[tool.setuptools.package-data]
|
| 69 |
+
configs = ["*.json"]
|
| 70 |
+
# Agent data files: system prompts loaded by ContextManager._load_system_prompt
|
| 71 |
+
# at runtime (`<site-packages>/agent/prompts/system_prompt_v3.yaml`), plus the
|
| 72 |
+
# package README. Without these, headless_main hangs forever — submission_loop
|
| 73 |
+
# crashes with FileNotFoundError but headless_main doesn't check agent_task.done()
|
| 74 |
+
# and just keeps awaiting the "ready" event_queue item that will never come.
|
| 75 |
+
agent = ["README.md", "prompts/*.yaml"]
|
| 76 |
|
| 77 |
[tool.uv]
|
| 78 |
package = true
|
scripts/build_kpis.py
CHANGED
|
@@ -38,15 +38,27 @@ re-running the same hour overwrites.
|
|
| 38 |
llm_calls — count of llm_call events
|
| 39 |
tokens_prompt / _completion / _cache_read / _cache_creation
|
| 40 |
cost_usd — sum of llm_call.cost_usd
|
|
|
|
| 41 |
cache_hit_ratio — cache_read / (cache_read + prompt)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
|
| 46 |
thumbs_up / thumbs_down
|
| 47 |
hf_jobs_submitted / _succeeded / _blocked
|
|
|
|
| 48 |
pro_cta_clicks
|
| 49 |
gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
================================================================================
|
| 52 |
Usage
|
|
@@ -213,6 +225,7 @@ def _session_metrics(session: dict) -> dict:
|
|
| 213 |
"thumbs_up": 0, "thumbs_down": 0,
|
| 214 |
"hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
|
| 215 |
"pro_cta_clicks": 0,
|
|
|
|
| 216 |
"first_tool_s": -1,
|
| 217 |
}
|
| 218 |
events = session.get("events") or []
|
|
@@ -231,11 +244,19 @@ def _session_metrics(session: dict) -> dict:
|
|
| 231 |
gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
|
| 232 |
jobs_submitted = 0
|
| 233 |
jobs_succeeded = 0
|
| 234 |
-
jobs_blocked = 0
|
| 235 |
thumbs_up = 0
|
| 236 |
thumbs_down = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
pro_cta_clicks = 0
|
| 238 |
pro_cta_by_source: dict[str, int] = defaultdict(int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
start_dt = _parse_ts(session_start)
|
| 241 |
|
|
@@ -260,6 +281,10 @@ def _session_metrics(session: dict) -> dict:
|
|
| 260 |
first_tool_ts = (ts - start_dt).total_seconds()
|
| 261 |
|
| 262 |
elif et == "tool_call":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
if first_tool_ts is None and ts is not None and start_dt is not None:
|
| 264 |
first_tool_ts = (ts - start_dt).total_seconds()
|
| 265 |
|
|
@@ -296,6 +321,19 @@ def _session_metrics(session: dict) -> dict:
|
|
| 296 |
source = str(data.get("source") or "unknown")
|
| 297 |
pro_cta_by_source[source] += 1
|
| 298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
out["tool_calls_total"] = tool_total
|
| 300 |
out["tool_calls_success"] = tool_success
|
| 301 |
out["failures"] = 1 if had_error else 0
|
|
@@ -304,12 +342,22 @@ def _session_metrics(session: dict) -> dict:
|
|
| 304 |
out["thumbs_down"] = thumbs_down
|
| 305 |
out["hf_jobs_submitted"] = jobs_submitted
|
| 306 |
out["hf_jobs_succeeded"] = jobs_succeeded
|
|
|
|
|
|
|
|
|
|
| 307 |
out["hf_jobs_blocked"] = jobs_blocked
|
| 308 |
out["pro_cta_clicks"] = pro_cta_clicks
|
| 309 |
out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
|
| 310 |
out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
|
| 311 |
out["_pro_cta_by_source"] = dict(pro_cta_by_source)
|
| 312 |
out["_user"] = session.get("user_id") or session.get("session_id")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
return dict(out)
|
| 314 |
|
| 315 |
|
|
@@ -317,12 +365,36 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 317 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 318 |
ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
|
| 319 |
gpu_hours: dict[str, float] = defaultdict(float)
|
| 320 |
-
pro_cta_by_source: dict[str, int] = defaultdict(int)
|
| 321 |
for s in per_session:
|
| 322 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
| 323 |
gpu_hours[f] += h
|
| 324 |
-
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
total_sessions = sum(s["sessions"] for s in per_session)
|
| 328 |
total_turns = sum(s["turns"] for s in per_session)
|
|
@@ -330,6 +402,16 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 330 |
tokens_cache_read = sum(s["tokens_cache_read"] for s in per_session)
|
| 331 |
tool_total = sum(s["tool_calls_total"] for s in per_session)
|
| 332 |
tool_success = sum(s["tool_calls_success"] for s in per_session)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
unique_users = {s.get("_user") for s in per_session if s.get("_user")}
|
| 335 |
|
|
@@ -343,26 +425,61 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 343 |
"tokens_cache_read": int(tokens_cache_read),
|
| 344 |
"tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
|
| 345 |
"cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
"cache_hit_ratio": round(
|
| 347 |
tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
|
| 348 |
) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
"tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
|
| 350 |
-
"failure_rate": round(
|
| 351 |
-
|
| 352 |
-
) if total_sessions > 0 else 0.0,
|
| 353 |
-
"regenerate_rate": round(
|
| 354 |
-
sum(s["regenerate_sessions"] for s in per_session) / total_sessions, 4
|
| 355 |
-
) if total_sessions > 0 else 0.0,
|
| 356 |
"time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
|
| 357 |
"time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
|
| 358 |
"thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
|
| 359 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 360 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 361 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
| 362 |
-
"
|
| 363 |
-
"
|
|
|
|
|
|
|
|
|
|
| 364 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
| 365 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
}
|
| 367 |
|
| 368 |
|
|
|
|
| 38 |
llm_calls — count of llm_call events
|
| 39 |
tokens_prompt / _completion / _cache_read / _cache_creation
|
| 40 |
cost_usd — sum of llm_call.cost_usd
|
| 41 |
+
cost_per_session_mean / _p50 / _p95 — per-session cost distribution
|
| 42 |
cache_hit_ratio — cache_read / (cache_read + prompt)
|
| 43 |
+
tool_calls_total / _succeeded / _failed — per-tool_output reliability counts
|
| 44 |
+
tool_success_rate — succeeded / total (kept for back-compat)
|
| 45 |
+
successful_sessions / errored_sessions / regenerated_sessions — outcome counts
|
| 46 |
+
failure_rate / regenerate_rate — kept for back-compat
|
| 47 |
time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
|
| 48 |
thumbs_up / thumbs_down
|
| 49 |
hf_jobs_submitted / _succeeded / _blocked
|
| 50 |
+
sandboxes_created / _cpu / _gpu — sandbox_create events bucketed by hardware
|
| 51 |
pro_cta_clicks
|
| 52 |
gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
|
| 53 |
+
research_calls — total `research` tool_call events
|
| 54 |
+
sessions_with_research — sessions that called `research` ≥1
|
| 55 |
+
research_calls_per_session_p50 / _p95 — among sessions that did any (zero-only sessions excluded)
|
| 56 |
+
distinct_tools_per_session_p50 / _p95 — among sessions with ≥1 named tool_call
|
| 57 |
+
tool_calls_per_session_p50 / _p95 — among sessions with ≥1 named tool_call
|
| 58 |
+
tool_calls_per_turn_p50 / _p95 — calls / turns, among sessions with turns>0
|
| 59 |
+
tool_calls_by_name_json — JSON {tool: total_calls} (all tools seen)
|
| 60 |
+
sessions_using_tool_json — JSON {tool: distinct_sessions_using}
|
| 61 |
+
sessions_by_model_json — JSON {model_name: count} (CLI vs Bedrock split)
|
| 62 |
|
| 63 |
================================================================================
|
| 64 |
Usage
|
|
|
|
| 225 |
"thumbs_up": 0, "thumbs_down": 0,
|
| 226 |
"hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
|
| 227 |
"pro_cta_clicks": 0,
|
| 228 |
+
"sandboxes_created": 0, "sandboxes_cpu": 0, "sandboxes_gpu": 0,
|
| 229 |
"first_tool_s": -1,
|
| 230 |
}
|
| 231 |
events = session.get("events") or []
|
|
|
|
| 244 |
gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
|
| 245 |
jobs_submitted = 0
|
| 246 |
jobs_succeeded = 0
|
|
|
|
| 247 |
thumbs_up = 0
|
| 248 |
thumbs_down = 0
|
| 249 |
+
sandboxes_created = 0
|
| 250 |
+
sandboxes_cpu = 0
|
| 251 |
+
sandboxes_gpu = 0
|
| 252 |
+
jobs_blocked = 0
|
| 253 |
pro_cta_clicks = 0
|
| 254 |
pro_cta_by_source: dict[str, int] = defaultdict(int)
|
| 255 |
+
# Per-tool counters from tool_call events. Counted off tool_call (which
|
| 256 |
+
# carries data["tool"]) rather than tool_output (which only carries
|
| 257 |
+
# success/output) so we can attribute calls to specific tools.
|
| 258 |
+
tool_calls_by_name: dict[str, int] = defaultdict(int)
|
| 259 |
+
total_named_tool_calls = 0
|
| 260 |
|
| 261 |
start_dt = _parse_ts(session_start)
|
| 262 |
|
|
|
|
| 281 |
first_tool_ts = (ts - start_dt).total_seconds()
|
| 282 |
|
| 283 |
elif et == "tool_call":
|
| 284 |
+
name = data.get("tool")
|
| 285 |
+
if name:
|
| 286 |
+
tool_calls_by_name[name] += 1
|
| 287 |
+
total_named_tool_calls += 1
|
| 288 |
if first_tool_ts is None and ts is not None and start_dt is not None:
|
| 289 |
first_tool_ts = (ts - start_dt).total_seconds()
|
| 290 |
|
|
|
|
| 321 |
source = str(data.get("source") or "unknown")
|
| 322 |
pro_cta_by_source[source] += 1
|
| 323 |
|
| 324 |
+
elif et == "sandbox_create":
|
| 325 |
+
sandboxes_created += 1
|
| 326 |
+
hardware = (data.get("hardware") or "").lower()
|
| 327 |
+
# CPU flavors are explicitly named "cpu-*". Everything else
|
| 328 |
+
# (including unknown/missing hardware strings) lands in the GPU
|
| 329 |
+
# bucket, since the auto-create default is "cpu-basic" which is
|
| 330 |
+
# matched here — anything that isn't is almost always an explicit
|
| 331 |
+
# GPU choice.
|
| 332 |
+
if hardware.startswith("cpu-"):
|
| 333 |
+
sandboxes_cpu += 1
|
| 334 |
+
else:
|
| 335 |
+
sandboxes_gpu += 1
|
| 336 |
+
|
| 337 |
out["tool_calls_total"] = tool_total
|
| 338 |
out["tool_calls_success"] = tool_success
|
| 339 |
out["failures"] = 1 if had_error else 0
|
|
|
|
| 342 |
out["thumbs_down"] = thumbs_down
|
| 343 |
out["hf_jobs_submitted"] = jobs_submitted
|
| 344 |
out["hf_jobs_succeeded"] = jobs_succeeded
|
| 345 |
+
out["sandboxes_created"] = sandboxes_created
|
| 346 |
+
out["sandboxes_cpu"] = sandboxes_cpu
|
| 347 |
+
out["sandboxes_gpu"] = sandboxes_gpu
|
| 348 |
out["hf_jobs_blocked"] = jobs_blocked
|
| 349 |
out["pro_cta_clicks"] = pro_cta_clicks
|
| 350 |
out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
|
| 351 |
out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
|
| 352 |
out["_pro_cta_by_source"] = dict(pro_cta_by_source)
|
| 353 |
out["_user"] = session.get("user_id") or session.get("session_id")
|
| 354 |
+
# Intra-session tool fields. Underscore-prefixed = consumed by _aggregate
|
| 355 |
+
# only, never written to CSV directly.
|
| 356 |
+
out["_tool_calls_by_name"] = dict(tool_calls_by_name)
|
| 357 |
+
out["_research_calls"] = tool_calls_by_name.get("research", 0)
|
| 358 |
+
out["_distinct_tools_used"] = len(tool_calls_by_name)
|
| 359 |
+
out["_total_named_tool_calls"] = total_named_tool_calls
|
| 360 |
+
out["_model_name"] = session.get("model_name") or "unknown"
|
| 361 |
return dict(out)
|
| 362 |
|
| 363 |
|
|
|
|
| 365 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 366 |
ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
|
| 367 |
gpu_hours: dict[str, float] = defaultdict(float)
|
|
|
|
| 368 |
for s in per_session:
|
| 369 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
| 370 |
gpu_hours[f] += h
|
| 371 |
+
|
| 372 |
+
# Per-tool aggregates. ``sessions_using_tool`` counts each session at most
|
| 373 |
+
# once per tool, so the dashboard can show "how many sessions reached for
|
| 374 |
+
# research" alongside "how many research calls overall".
|
| 375 |
+
tool_calls_by_name: dict[str, int] = defaultdict(int)
|
| 376 |
+
sessions_using_tool: dict[str, int] = defaultdict(int)
|
| 377 |
+
sessions_by_model: dict[str, int] = defaultdict(int)
|
| 378 |
+
for s in per_session:
|
| 379 |
+
for name, count in (s.get("_tool_calls_by_name") or {}).items():
|
| 380 |
+
tool_calls_by_name[name] += int(count)
|
| 381 |
+
sessions_using_tool[name] += 1
|
| 382 |
+
sessions_by_model[s.get("_model_name") or "unknown"] += 1
|
| 383 |
+
|
| 384 |
+
# Percentile inputs. All "per session" percentiles exclude sessions that
|
| 385 |
+
# never reached for the relevant signal — otherwise quiet hours
|
| 386 |
+
# (status-check sessions, abandoned new conversations) drag every median
|
| 387 |
+
# to 0 and the chart tells you nothing.
|
| 388 |
+
research_calls_nz = [s.get("_research_calls", 0) for s in per_session if s.get("_research_calls", 0) > 0]
|
| 389 |
+
distinct_tools_values = [s.get("_distinct_tools_used", 0) for s in per_session if s.get("_distinct_tools_used", 0) > 0]
|
| 390 |
+
total_calls_values = [s.get("_total_named_tool_calls", 0) for s in per_session if s.get("_total_named_tool_calls", 0) > 0]
|
| 391 |
+
# Per-turn intensity: turns>0 is the natural filter here (a session with
|
| 392 |
+
# 5 turns and 0 tools is a meaningful 0). Don't strip those.
|
| 393 |
+
calls_per_turn_values = [
|
| 394 |
+
s.get("_total_named_tool_calls", 0) / s["turns"]
|
| 395 |
+
for s in per_session
|
| 396 |
+
if s.get("turns", 0) > 0
|
| 397 |
+
]
|
| 398 |
|
| 399 |
total_sessions = sum(s["sessions"] for s in per_session)
|
| 400 |
total_turns = sum(s["turns"] for s in per_session)
|
|
|
|
| 402 |
tokens_cache_read = sum(s["tokens_cache_read"] for s in per_session)
|
| 403 |
tool_total = sum(s["tool_calls_total"] for s in per_session)
|
| 404 |
tool_success = sum(s["tool_calls_success"] for s in per_session)
|
| 405 |
+
failures = int(sum(s["failures"] for s in per_session))
|
| 406 |
+
regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
|
| 407 |
+
research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
|
| 408 |
+
sessions_with_research = sum(1 for s in per_session if s.get("_research_calls", 0) > 0)
|
| 409 |
+
|
| 410 |
+
# Per-session cost percentiles — chart "median session cost" alongside the
|
| 411 |
+
# mean so a few $700 outliers don't make you think every session is pricey.
|
| 412 |
+
session_costs = [float(s.get("cost_usd") or 0.0) for s in per_session]
|
| 413 |
+
cost_p50 = _percentile(session_costs, 0.5)
|
| 414 |
+
cost_p95 = _percentile(session_costs, 0.95)
|
| 415 |
|
| 416 |
unique_users = {s.get("_user") for s in per_session if s.get("_user")}
|
| 417 |
|
|
|
|
| 425 |
"tokens_cache_read": int(tokens_cache_read),
|
| 426 |
"tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
|
| 427 |
"cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
|
| 428 |
+
# Per-session cost summaries.
|
| 429 |
+
"cost_per_session_mean": round(
|
| 430 |
+
sum(s["cost_usd"] for s in per_session) / total_sessions, 6
|
| 431 |
+
) if total_sessions > 0 else 0.0,
|
| 432 |
+
"cost_per_session_p50": round(cost_p50, 6),
|
| 433 |
+
"cost_per_session_p95": round(cost_p95, 6),
|
| 434 |
"cache_hit_ratio": round(
|
| 435 |
tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
|
| 436 |
) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
|
| 437 |
+
# Raw reliability COUNTS (these are what the dashboard shows directly).
|
| 438 |
+
"tool_calls_total": int(tool_total),
|
| 439 |
+
"tool_calls_succeeded": int(tool_success),
|
| 440 |
+
"tool_calls_failed": int(tool_total - tool_success),
|
| 441 |
+
"errored_sessions": failures,
|
| 442 |
+
# Successful = "did not raise an error event". Mutually exclusive
|
| 443 |
+
# with errored_sessions; sums with errored_sessions to total sessions.
|
| 444 |
+
"successful_sessions": int(total_sessions - failures),
|
| 445 |
+
# Regenerated is an orthogonal dimension (the user retried) — a
|
| 446 |
+
# session can be both successful and regenerated, or both errored
|
| 447 |
+
# and regenerated.
|
| 448 |
+
"regenerated_sessions": regenerates,
|
| 449 |
+
# Rates kept for backwards compatibility with anything reading the
|
| 450 |
+
# KPI dataset directly.
|
| 451 |
"tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
|
| 452 |
+
"failure_rate": round(failures / total_sessions, 4) if total_sessions > 0 else 0.0,
|
| 453 |
+
"regenerate_rate": round(regenerates / total_sessions, 4) if total_sessions > 0 else 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
"time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
|
| 455 |
"time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
|
| 456 |
"thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
|
| 457 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 458 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 459 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
| 460 |
+
"sandboxes_created": int(sum(s.get("sandboxes_created", 0) for s in per_session)),
|
| 461 |
+
"sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
|
| 462 |
+
"sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
|
| 463 |
+
"hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
|
| 464 |
+
"pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
|
| 465 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
| 466 |
+
# Research KPIs — answer "is the agent reaching for research?".
|
| 467 |
+
"research_calls": research_calls_total,
|
| 468 |
+
"sessions_with_research": int(sessions_with_research),
|
| 469 |
+
"research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
|
| 470 |
+
"research_calls_per_session_p95": round(_percentile(research_calls_nz, 0.95), 2),
|
| 471 |
+
# Intra-session breadth + intensity. p50 + p95 over per-session values.
|
| 472 |
+
"distinct_tools_per_session_p50": round(_percentile(distinct_tools_values, 0.5), 2),
|
| 473 |
+
"distinct_tools_per_session_p95": round(_percentile(distinct_tools_values, 0.95), 2),
|
| 474 |
+
"tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
|
| 475 |
+
"tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
|
| 476 |
+
"tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
|
| 477 |
+
"tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
|
| 478 |
+
# JSON columns let the dashboard add/remove tools without schema churn.
|
| 479 |
+
"tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
|
| 480 |
+
"sessions_using_tool_json": json.dumps(dict(sessions_using_tool), sort_keys=True),
|
| 481 |
+
# Surface split — answers "is research dropping on Bedrock specifically?".
|
| 482 |
+
"sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
|
| 483 |
}
|
| 484 |
|
| 485 |
|
scripts/sweep_orphan_sandboxes.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Backstop sweeper for orphan ml-intern sandbox Spaces.
|
| 3 |
+
|
| 4 |
+
================================================================================
|
| 5 |
+
Why this script exists
|
| 6 |
+
================================================================================
|
| 7 |
+
|
| 8 |
+
The agent creates a sandbox Space per session (template duplicated from
|
| 9 |
+
``burtenshaw/sandbox`` into the user's account, named ``<owner>/sandbox-<8hex>``).
|
| 10 |
+
``backend.session_manager.SessionManager._cleanup_sandbox`` deletes it at end of
|
| 11 |
+
session. In practice the cleanup misses some sandboxes:
|
| 12 |
+
|
| 13 |
+
- pod killed / OOM / pre-emption / deploy rollouts → ``finally`` block skipped
|
| 14 |
+
- WebSocket dropped without ``/shutdown`` from the client
|
| 15 |
+
- HF API transient failure on ``delete_repo`` (we retry now, but not infinitely)
|
| 16 |
+
|
| 17 |
+
The result observed 2026-04-27 was 2,310 orphan ``sandbox-*`` Spaces — every
|
| 18 |
+
sandbox ever created was still around. This script is the backstop: list every
|
| 19 |
+
``sandbox-*`` fork of ``burtenshaw/sandbox`` that hasn't been touched in N days
|
| 20 |
+
and delete it.
|
| 21 |
+
|
| 22 |
+
================================================================================
|
| 23 |
+
Identification rules
|
| 24 |
+
================================================================================
|
| 25 |
+
|
| 26 |
+
A Space is considered an orphan ml-intern sandbox iff ALL hold:
|
| 27 |
+
|
| 28 |
+
1. Repo type = ``space``
|
| 29 |
+
2. Name matches ``<owner>/sandbox-[a-f0-9]{8}$`` (the agent's naming convention)
|
| 30 |
+
3. (Intended, but NOT enforced) ``originRepo`` points at ``burtenshaw/sandbox``; the HF
|
| 31 |
+
API does not expose the origin repo, so ``is_sandbox_fork`` checks the name only
|
| 32 |
+
4. ``lastModified`` older than ``--max-age-days`` (default 7)
|
| 33 |
+
|
| 34 |
+
We DO NOT use the ``runtime.stage`` (sleeping/running) as a filter — a sandbox
|
| 35 |
+
that has been sleeping for 7 days is just as orphan as a deleted one but uses
|
| 36 |
+
no compute. The cleanup is about repo/storage hygiene, not about waking
|
| 37 |
+
something up to kill it.
|
| 38 |
+
|
| 39 |
+
================================================================================
|
| 40 |
+
Safety
|
| 41 |
+
================================================================================
|
| 42 |
+
|
| 43 |
+
- Dry-run is the default (there is no flag for it): prints what would be deleted, deletes nothing.
|
| 44 |
+
- ``--apply`` actually calls ``HfApi.delete_repo``.
|
| 45 |
+
- Hard cap ``--max-deletes`` (default 200) so a misconfigured run can't nuke
|
| 46 |
+
thousands at once.
|
| 47 |
+
- Requires a token with admin rights via ``HF_ADMIN_TOKEN`` env var (the only
|
| 48 |
+
way to delete a Space owned by another user).
|
| 49 |
+
- Logs every action to stdout in JSON Lines for downstream auditing.
|
| 50 |
+
|
| 51 |
+
================================================================================
|
| 52 |
+
Cron suggestion
|
| 53 |
+
================================================================================
|
| 54 |
+
|
| 55 |
+
GitHub Actions, daily at 04:00 UTC:
|
| 56 |
+
|
| 57 |
+
schedule:
|
| 58 |
+
- cron: "0 4 * * *"
|
| 59 |
+
env:
|
| 60 |
+
HF_ADMIN_TOKEN: ${{ secrets.HF_ADMIN_TOKEN }}
|
| 61 |
+
steps:
|
| 62 |
+
- run: python scripts/sweep_orphan_sandboxes.py --apply --max-age-days 7
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
import argparse
|
| 66 |
+
import json
|
| 67 |
+
import os
|
| 68 |
+
import re
|
| 69 |
+
import sys
|
| 70 |
+
import time
|
| 71 |
+
from datetime import datetime, timedelta, timezone
|
| 72 |
+
|
| 73 |
+
from huggingface_hub import HfApi
|
| 74 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 75 |
+
|
| 76 |
+
SANDBOX_NAME_RE = re.compile(r"^[^/]+/sandbox-[a-f0-9]{8}$")
|
| 77 |
+
TEMPLATE_REPO = "burtenshaw/sandbox"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def log(record: dict) -> None:
    """Print ``record`` to stdout as one JSON line with a ``ts`` field added.

    ``ts`` is the current UTC time in ISO-8601 form. The timestamp is added to
    a shallow copy, so the dict passed in by the caller is left unmodified. If
    ``record`` already has a ``ts`` key, its value is replaced in the emitted
    line.

    Args:
        record: JSON-serialisable mapping describing the event.
    """
    stamped = {**record, "ts": datetime.now(timezone.utc).isoformat()}
    # flush=True so lines show up immediately when stdout is piped (e.g. CI logs).
    print(json.dumps(stamped), flush=True)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def is_sandbox_fork(space) -> bool:
    """Return True when ``space.id`` follows the ml-intern sandbox naming scheme.

    Only the repo id is inspected: it must match ``SANDBOX_NAME_RE``
    (``<owner>/sandbox-<8 lowercase hex>``).

    NOTE: An additional ``duplicated_from == burtenshaw/sandbox`` check was
    tried first for extra safety and abandoned. According to the authors, the
    HF REST API does not expose ``duplicated_from`` on ``SpaceInfo`` (checked
    against ``huggingface-hub`` 1.11+ and a direct ``GET /api/spaces/{id}``:
    the field is None), so the origin repo is not available to this script.
    The name pattern is considered specific enough on its own because
    ``Sandbox.create()`` is the sole producer of such names and a collision
    with a user-created Space is unlikely in practice. Running without
    ``--apply`` (the dry-run default) is the safety net for the rare
    false positive.
    """
    return SANDBOX_NAME_RE.match(space.id) is not None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main() -> int:
    """Run one sweep over ``sandbox-*`` Spaces and report or delete stale ones.

    Reads the admin token from ``HF_ADMIN_TOKEN``, lists Spaces matching the
    search term ``sandbox``, keeps those accepted by ``is_sandbox_fork`` and
    treats every match whose last-modified time is at or before the cutoff
    (now minus ``--max-age-days``) as a deletion candidate. Candidates are
    only deleted when ``--apply`` is given, and at most ``--max-deletes`` of
    them per run; the scan continues past the cap so the final counts describe
    the whole backlog.

    A Space whose last-modified time is missing or unparseable is never
    deleted: its age cannot be established, so it is logged and skipped.

    Returns:
        0 when no deletion failed, 1 on invalid configuration (missing token
        or negative ``--max-age-days``), 2 when at least one ``delete_repo``
        call failed.
    """
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument(
        "--max-age-days",
        type=int,
        default=7,
        help="Delete sandboxes whose lastModified is older than this many days (default: 7)",
    )
    parser.add_argument(
        "--max-deletes",
        type=int,
        default=200,
        help="Hard cap on deletions per run, safety guard (default: 200)",
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="Actually delete. Without this flag, dry-run only.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10000,
        help="Max number of candidate Spaces to scan via list_spaces (default: 10000)",
    )
    args = parser.parse_args()

    if args.max_age_days < 0:
        # A negative age would put the cutoff in the future and turn every
        # sandbox, including ones in active use, into a deletion candidate.
        log({"level": "error", "msg": "--max-age-days must be >= 0"})
        return 1

    token = os.environ.get("HF_ADMIN_TOKEN")
    if not token:
        log({"level": "error", "msg": "HF_ADMIN_TOKEN env var not set"})
        return 1

    api = HfApi(token=token)
    cutoff = datetime.now(timezone.utc) - timedelta(days=args.max_age_days)
    log({"level": "info", "msg": "sweep_start", "cutoff": cutoff.isoformat(),
         "max_deletes": args.max_deletes, "apply": args.apply})

    # ``list_spaces`` doesn't filter by name pattern — we scan and filter
    # client-side. ``search="sandbox"`` narrows the network payload.
    candidates = api.list_spaces(
        search="sandbox", full=True, limit=args.limit
    )

    scanned = 0
    matched = 0
    deleted = 0
    failed = 0
    skipped_too_recent = 0
    skipped_unknown_age = 0
    skipped_capped = 0

    for space in candidates:
        scanned += 1
        if not is_sandbox_fork(space):
            continue
        matched += 1

        # Accept both the legacy camelCase attribute and the snake_case one.
        last_mod = getattr(space, "lastModified", None) or getattr(space, "last_modified", None)
        if isinstance(last_mod, str):
            try:
                last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
            except ValueError:
                # One malformed timestamp must not abort the whole sweep.
                last_mod = None
        if last_mod is None:
            # Unknown age: never delete something we cannot prove is stale.
            skipped_unknown_age += 1
            log({"level": "warning", "msg": "skipped_unknown_age", "space_id": space.id})
            continue
        if last_mod.tzinfo is None:
            # Naive timestamps cannot be compared with the aware cutoff;
            # interpret them as UTC.
            last_mod = last_mod.replace(tzinfo=timezone.utc)
        if last_mod > cutoff:
            skipped_too_recent += 1
            continue

        log({"level": "info", "msg": "candidate", "space_id": space.id,
             "last_modified": last_mod.isoformat()})

        if not args.apply:
            continue

        # When we hit the deletion cap, keep scanning so the final ``matched``
        # count reflects the *true* orphan size — not just what was scanned
        # before we stopped deleting. Operators planning multi-pass cleanups
        # need an accurate denominator to know when they're done.
        if deleted >= args.max_deletes:
            skipped_capped += 1
            continue

        try:
            api.delete_repo(repo_id=space.id, repo_type="space", token=token)
            deleted += 1
            log({"level": "info", "msg": "deleted", "space_id": space.id})
            # Light throttle to avoid hitting HF API rate limits.
            time.sleep(0.2)
        except HfHubHTTPError as e:
            failed += 1
            # The exception may carry no response object; don't let the
            # error reporting itself raise.
            response = getattr(e, "response", None)
            log({"level": "error", "msg": "delete_failed", "space_id": space.id,
                 "status": getattr(response, "status_code", None), "error": str(e)[:200]})
        except Exception as e:
            # Per-item boundary: record the failure and keep sweeping.
            failed += 1
            log({"level": "error", "msg": "delete_failed", "space_id": space.id,
                 "error": str(e)[:200]})

    log({"level": "info", "msg": "sweep_end",
         "scanned": scanned, "matched": matched,
         "skipped_too_recent": skipped_too_recent,
         "skipped_unknown_age": skipped_unknown_age,
         "skipped_capped": skipped_capped,
         "deleted": deleted, "failed": failed,
         "capped": skipped_capped > 0,
         "apply": args.apply})

    return 0 if failed == 0 else 2
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
tests/integration/test_live_sandbox_auth.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Opt-in live sandbox communication test.
|
| 2 |
+
|
| 3 |
+
This test creates a real Hugging Face Space sandbox, verifies that unauthenticated
|
| 4 |
+
requests are rejected, then exercises the authenticated agent client end-to-end.
|
| 5 |
+
It is skipped unless ``ML_INTERN_LIVE_SANDBOX_TESTS=1`` and ``HF_TOKEN`` are set.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import httpx
|
| 14 |
+
import pytest
|
| 15 |
+
from dotenv import load_dotenv
|
| 16 |
+
from huggingface_hub import HfApi
|
| 17 |
+
|
| 18 |
+
from agent.tools.sandbox_client import Sandbox
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
if env_file := os.environ.get("ML_INTERN_LIVE_ENV_FILE"):
|
| 22 |
+
load_dotenv(Path(env_file))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _skip_without_live_sandbox() -> None:
    """Skip the calling test unless live sandbox testing is opted into.

    Two conditions are checked in order: ``ML_INTERN_LIVE_SANDBOX_TESTS`` must
    equal ``"1"`` and ``HF_TOKEN`` must be set to a non-empty value.
    """
    opted_in = os.environ.get("ML_INTERN_LIVE_SANDBOX_TESTS") == "1"
    if not opted_in:
        pytest.skip("set ML_INTERN_LIVE_SANDBOX_TESTS=1 to create a real sandbox")

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        pytest.skip("set HF_TOKEN to create a real sandbox")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_live_sandbox_authenticated_agent_communication():
    """Exercise a real sandbox Space end-to-end through the agent client.

    Steps: create the sandbox, confirm a credential-less HTTP client gets a
    401, then run bash / write / exists / read through the authenticated
    client, and finally reconnect with ``Sandbox.connect`` and read the same
    file again. The sandbox is deleted in the outer ``finally`` even when an
    assertion fails.
    """
    _skip_without_live_sandbox()

    hf_token = os.environ["HF_TOKEN"]
    owner = HfApi(token=hf_token).whoami()["name"]
    remote_file = "/tmp/ml_intern_live_auth.txt"
    sandbox = None

    try:
        sandbox = Sandbox.create(
            owner=owner,
            name="ml-intern-live-auth",
            hardware="cpu-basic",
            private=False,
            token=hf_token,
            secrets={"HF_TOKEN": hf_token},
            wait_timeout=900,
        )

        # A client that sends no credentials must be rejected.
        with httpx.Client(
            base_url=sandbox._base_url,
            timeout=30,
            follow_redirects=True,
        ) as anonymous:
            rejected = anonymous.post("exists", json={"path": "/tmp"})
            assert rejected.status_code == 401

        bash_result = sandbox.bash("printf sandbox-live-ok", timeout=30)
        assert bash_result.success, bash_result.error
        assert "sandbox-live-ok" in bash_result.output

        write_result = sandbox.write(remote_file, "alpha\nbeta\n")
        assert write_result.success, write_result.error

        exists_result = sandbox._call("exists", {"path": remote_file})
        assert exists_result.success, exists_result.error
        assert exists_result.output == "true"

        read_result = sandbox.read(remote_file)
        assert read_result.success, read_result.error
        assert "alpha" in read_result.output
        assert "beta" in read_result.output

        # A second client attached to the same Space must see the same file.
        reattached = Sandbox.connect(
            sandbox.space_id,
            token=hf_token,
            api_token=sandbox.api_token,
        )
        try:
            reread_result = reattached.read(remote_file)
            assert reread_result.success, reread_result.error
            assert "alpha" in reread_result.output
        finally:
            reattached._client.close()
    finally:
        if sandbox is not None:
            sandbox.delete()
|
tests/integration/test_live_thinking_models.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Opt-in live provider checks for thinking metadata replay.
|
| 2 |
+
|
| 3 |
+
These tests intentionally call paid model APIs and are skipped unless
|
| 4 |
+
``ML_INTERN_LIVE_LLM_TESTS=1`` plus the relevant provider key are set.
|
| 5 |
+
They cover the concrete model families involved in #87 without making
|
| 6 |
+
default CI depend on external credentials or provider availability.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from types import SimpleNamespace
|
| 14 |
+
|
| 15 |
+
import pytest
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
+
from litellm import Message
|
| 18 |
+
|
| 19 |
+
from agent.core.agent_loop import (
|
| 20 |
+
_assistant_message_from_result,
|
| 21 |
+
_call_llm_streaming,
|
| 22 |
+
)
|
| 23 |
+
from agent.core.llm_params import _resolve_llm_params
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if env_file := os.environ.get("ML_INTERN_LIVE_ENV_FILE"):
|
| 27 |
+
load_dotenv(Path(env_file))
|
| 28 |
+
|
| 29 |
+
LIVE_TESTS_ENABLED = os.environ.get("ML_INTERN_LIVE_LLM_TESTS") == "1"
|
| 30 |
+
OPUS_47_MODEL = "anthropic/claude-opus-4-7"
|
| 31 |
+
LATEST_GPT_MODEL = "openai/gpt-5.2"
|
| 32 |
+
REPORT_RESULT_TOOL = [
|
| 33 |
+
{
|
| 34 |
+
"type": "function",
|
| 35 |
+
"function": {
|
| 36 |
+
"name": "report_result",
|
| 37 |
+
"description": "Report the final test result.",
|
| 38 |
+
"parameters": {
|
| 39 |
+
"type": "object",
|
| 40 |
+
"properties": {
|
| 41 |
+
"answer": {
|
| 42 |
+
"type": "string",
|
| 43 |
+
"description": "The exact marker requested by the test.",
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
"required": ["answer"],
|
| 47 |
+
},
|
| 48 |
+
},
|
| 49 |
+
}
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _skip_without_live_flag() -> None:
    """Skip the calling test unless ``LIVE_TESTS_ENABLED`` is true.

    Called at the top of each live test so the paid API calls only happen
    when ``ML_INTERN_LIVE_LLM_TESTS=1`` was set.
    """
    if LIVE_TESTS_ENABLED:
        return
    pytest.skip("set ML_INTERN_LIVE_LLM_TESTS=1 to run paid live LLM tests")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _skip_without_env(name: str) -> None:
|
| 59 |
+
if not os.environ.get(name):
|
| 60 |
+
pytest.skip(f"set {name} to run this live provider test")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _session(model_name: str):
|
| 64 |
+
events = []
|
| 65 |
+
|
| 66 |
+
async def send_event(event):
|
| 67 |
+
events.append(event)
|
| 68 |
+
|
| 69 |
+
return SimpleNamespace(
|
| 70 |
+
config=SimpleNamespace(model_name=model_name),
|
| 71 |
+
is_cancelled=False,
|
| 72 |
+
send_event=send_event,
|
| 73 |
+
events=events,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@pytest.mark.asyncio
async def test_live_opus_47_preserves_thinking_metadata_for_replay():
    """Live check that Opus thinking metadata is carried onto the replay message.

    Streams one request to ``OPUS_47_MODEL`` with ``reasoning_effort="high"``
    and the ``report_result`` tool, builds the assistant replay message from
    the result, and asserts that the replay's ``thinking_blocks`` and
    ``reasoning_content`` equal those of the streamed result.

    Skipped unless the live flag and ``ANTHROPIC_API_KEY`` are set.
    """
    _skip_without_live_flag()
    _skip_without_env("ANTHROPIC_API_KEY")

    stub_session = _session(OPUS_47_MODEL)
    params = _resolve_llm_params(OPUS_47_MODEL, reasoning_effort="high")

    prompt = Message(
        role="user",
        content=(
            "Use careful reasoning for this small check. "
            "If 17 * 19 = 323, call report_result with answer OPUS_OK."
        ),
    )
    result = await _call_llm_streaming(
        stub_session,
        messages=[prompt],
        tools=REPORT_RESULT_TOOL,
        llm_params=params,
    )

    replay = _assistant_message_from_result(result, model_name=OPUS_47_MODEL)

    # The model must have produced either text or a tool call.
    assert result.content or result.tool_calls_acc
    assert result.thinking_blocks, (
        "Opus returned no thinking_blocks with reasoning_effort='high' - "
        "check that adaptive thinking params are being forwarded correctly"
    )
    # The replay message must carry the thinking metadata unchanged.
    assert getattr(replay, "thinking_blocks", None) == result.thinking_blocks
    assert getattr(replay, "reasoning_content", None) == result.reasoning_content
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@pytest.mark.asyncio
async def test_live_latest_gpt_does_not_replay_reasoning_metadata():
    """Live check that reasoning metadata is stripped from GPT replay messages.

    Streams one request to ``LATEST_GPT_MODEL`` with ``reasoning_effort="low"``
    and the ``report_result`` tool, then asserts that the assistant replay
    message built from the result has neither ``thinking_blocks`` nor
    ``reasoning_content``.

    Skipped unless the live flag and ``OPENAI_API_KEY`` are set.
    """
    _skip_without_live_flag()
    _skip_without_env("OPENAI_API_KEY")

    stub_session = _session(LATEST_GPT_MODEL)
    params = _resolve_llm_params(LATEST_GPT_MODEL, reasoning_effort="low")

    prompt = Message(
        role="user",
        content="Call report_result with answer GPT_OK.",
    )
    result = await _call_llm_streaming(
        stub_session,
        messages=[prompt],
        tools=REPORT_RESULT_TOOL,
        llm_params=params,
    )

    # Even if a GPT-family response carries provider reasoning internally,
    # OpenAI-compatible history must not echo it back on the next tool turn.
    # If the live response came back without reasoning content, inject a
    # placeholder so the replay builder always has something to strip.
    result.reasoning_content = result.reasoning_content or "synthetic-reasoning"
    replay = _assistant_message_from_result(result, model_name=LATEST_GPT_MODEL)

    # The model must have produced either text or a tool call.
    assert result.content or result.tool_calls_acc
    # Neither kind of reasoning metadata may appear on the replay message.
    assert getattr(replay, "thinking_blocks", None) is None
    assert getattr(replay, "reasoning_content", None) is None
|