akseljoonas committed on
Commit
649feee
·
2 Parent(s): d0d08fcd9d9785

Deploy 2026-04-28

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +50 -0
  2. agent/config.py +108 -3
  3. agent/context_manager/manager.py +3 -0
  4. agent/core/agent_loop.py +217 -19
  5. agent/core/doom_loop.py +37 -6
  6. agent/core/hf_access.py +7 -3
  7. agent/core/hf_tokens.py +85 -0
  8. agent/core/llm_params.py +10 -9
  9. agent/core/session.py +169 -7
  10. agent/core/session_persistence.py +428 -0
  11. agent/core/session_uploader.py +2 -0
  12. agent/core/tools.py +14 -0
  13. agent/main.py +42 -30
  14. agent/messaging/__init__.py +15 -0
  15. agent/messaging/base.py +27 -0
  16. agent/messaging/gateway.py +166 -0
  17. agent/messaging/models.py +123 -0
  18. agent/messaging/slack.py +186 -0
  19. agent/prompts/system_prompt_v3.yaml +36 -3
  20. agent/tools/__init__.py +3 -0
  21. agent/tools/jobs_tool.py +86 -20
  22. agent/tools/notify_tool.py +108 -0
  23. agent/tools/research_tool.py +8 -3
  24. agent/tools/sandbox_client.py +55 -14
  25. agent/tools/sandbox_tool.py +171 -5
  26. agent/tools/trackio_seed.py +205 -0
  27. agent/tools/web_search_tool.py +273 -0
  28. backend/dependencies.py +7 -7
  29. backend/main.py +8 -5
  30. backend/models.py +9 -1
  31. backend/routes/agent.py +131 -76
  32. backend/session_manager.py +466 -58
  33. backend/user_quotas.py +42 -4
  34. configs/__init__.py +0 -0
  35. configs/cli_agent_config.json +5 -0
  36. frontend/src/components/Chat/MarkdownContent.tsx +10 -2
  37. frontend/src/components/Chat/ToolCallGroup.tsx +201 -1
  38. frontend/src/components/JobsUpgradeDialog.tsx +63 -54
  39. frontend/src/components/SessionSidebar/SessionSidebar.tsx +19 -2
  40. frontend/src/components/WelcomeScreen/WelcomeScreen.tsx +157 -141
  41. frontend/src/hooks/useAgentChat.ts +115 -55
  42. frontend/src/lib/sse-chat-transport.ts +79 -20
  43. frontend/src/store/agentStore.ts +83 -0
  44. frontend/src/store/sessionStore.ts +47 -0
  45. frontend/src/types/events.ts +1 -0
  46. pyproject.toml +17 -3
  47. scripts/build_kpis.py +133 -16
  48. scripts/sweep_orphan_sandboxes.py +206 -0
  49. tests/integration/test_live_sandbox_auth.py +90 -0
  50. tests/integration/test_live_thinking_models.py +151 -0
README.md CHANGED
@@ -75,6 +75,56 @@ ml-intern --max-iterations 100 "your prompt"
75
  ml-intern --no-stream "your prompt"
76
  ```
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  ## Architecture
79
 
80
  ### Component Overview
 
75
  ml-intern --no-stream "your prompt"
76
  ```
77
 
78
+ ## Supported Gateways
79
+
80
+ ML Intern currently supports one-way notification gateways from CLI sessions.
81
+ These gateways send out-of-band status updates; they do not accept inbound chat
82
+ messages.
83
+
84
+ ### Slack
85
+
86
+ Slack notifications use the Slack Web API to post messages when the agent needs
87
+ approval, hits an error, or completes a turn. Create a Slack app with a bot token
88
+ that has `chat:write`, invite the bot to the target channel, then set:
89
+
90
+ ```bash
91
+ SLACK_BOT_TOKEN=xoxb-...
92
+ SLACK_CHANNEL_ID=C...
93
+ ```
94
+
95
+ The CLI automatically creates a `slack.default` destination when both variables
96
+ are present. Optional environment variables that adjust this env-created default destination:
97
+
98
+ ```bash
99
+ ML_INTERN_SLACK_NOTIFICATIONS=false
100
+ ML_INTERN_SLACK_DESTINATION=slack.ops
101
+ ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete
102
+ ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true
103
+ ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true
104
+ ```
105
+
106
+ For a persistent user-level config, put overrides in
107
+ `~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a
108
+ JSON file:
109
+
110
+ ```json
111
+ {
112
+ "messaging": {
113
+ "enabled": true,
114
+ "auto_event_types": ["approval_required", "error", "turn_complete"],
115
+ "destinations": {
116
+ "slack.ops": {
117
+ "provider": "slack",
118
+ "token": "${SLACK_BOT_TOKEN}",
119
+ "channel": "${SLACK_CHANNEL_ID}",
120
+ "allow_agent_tool": true,
121
+ "allow_auto_events": true
122
+ }
123
+ }
124
+ }
125
+ }
126
+ ```
127
+
128
  ## Architecture
129
 
130
  ### Component Overview
agent/config.py CHANGED
@@ -6,6 +6,8 @@ from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
8
 
 
 
9
  # Project root: two levels up from this file (agent/config.py -> project root)
10
  _PROJECT_ROOT = Path(__file__).resolve().parent.parent
11
  from fastmcp.mcp_config import (
@@ -47,6 +49,104 @@ class Config(BaseModel):
47
  # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
48
  # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
49
  reasoning_effort: str | None = "max"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  def substitute_env_vars(obj: Any) -> Any:
@@ -86,7 +186,10 @@ def substitute_env_vars(obj: Any) -> Any:
86
  return obj
87
 
88
 
89
- def load_config(config_path: str = "config.json") -> Config:
 
 
 
90
  """
91
  Load configuration with environment variable substitution.
92
 
@@ -98,8 +201,10 @@ def load_config(config_path: str = "config.json") -> Config:
98
  load_dotenv(_PROJECT_ROOT / ".env")
99
  load_dotenv(override=False)
100
 
101
- with open(config_path, "r") as f:
102
- raw_config = json.load(f)
 
 
103
 
104
  config_with_env = substitute_env_vars(raw_config)
105
  return Config.model_validate(config_with_env)
 
6
 
7
  from dotenv import load_dotenv
8
 
9
+ from agent.messaging.models import MessagingConfig
10
+
11
  # Project root: two levels up from this file (agent/config.py -> project root)
12
  _PROJECT_ROOT = Path(__file__).resolve().parent.parent
13
  from fastmcp.mcp_config import (
 
49
  # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
50
  # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
51
  reasoning_effort: str | None = "max"
52
+ messaging: MessagingConfig = MessagingConfig()
53
+
54
+
55
+ USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
56
+ DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
57
+ SLACK_DEFAULT_DESTINATION = "slack.default"
58
+ SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
59
+
60
+
61
def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict holding ``base`` recursively overlaid with ``override``.

    Keys whose values are dicts on both sides are merged recursively; for
    every other key the value from ``override`` replaces the one from
    ``base``. Neither input is modified, and the result shares no mutable
    containers with them, so callers may edit the merged config freely.

    Args:
        base: Lower-priority mapping.
        override: Higher-priority mapping; its values win on conflict.

    Returns:
        An independent merged dict.
    """
    import copy  # local import keeps this helper self-contained

    # Deep-copy instead of dict(base): a shallow copy would leave nested
    # dicts/lists aliased to the caller's objects.
    merged = copy.deepcopy(base)
    for key, value in override.items():
        current = merged.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = _deep_merge_config(current, value)
        else:
            # Copy so the result never aliases containers owned by override.
            merged[key] = copy.deepcopy(value)
    return merged
70
+
71
+
72
def _load_json_config(path: Path) -> dict[str, Any]:
    """Parse the UTF-8 JSON file at ``path`` and return its top-level object.

    Args:
        path: Location of the JSON config file.

    Returns:
        The decoded top-level JSON object.

    Raises:
        ValueError: If the document's top-level value is not a JSON object.
    """
    with open(path, encoding="utf-8") as handle:
        parsed = json.loads(handle.read())
    if isinstance(parsed, dict):
        return parsed
    raise ValueError(f"Config file {path} must contain a JSON object")
78
+
79
+
80
def _load_user_config() -> dict[str, Any]:
    """Load the user-level config overrides, if any.

    The path named by the ``USER_CONFIG_ENV_VAR`` environment variable takes
    priority (with ``~`` expanded); otherwise ``DEFAULT_USER_CONFIG_PATH`` is
    used when it exists.

    Returns:
        The parsed JSON object, or an empty dict when no user config exists.

    Raises:
        FileNotFoundError: If the environment variable is set but the file it
            names does not exist.
    """
    explicit = os.environ.get(USER_CONFIG_ENV_VAR)
    if not explicit:
        # No explicit override: the default location is optional.
        if DEFAULT_USER_CONFIG_PATH.exists():
            return _load_json_config(DEFAULT_USER_CONFIG_PATH)
        return {}

    candidate = Path(explicit).expanduser()
    if candidate.exists():
        return _load_json_config(candidate)
    # An explicitly requested file that is missing is an error, not a no-op.
    raise FileNotFoundError(
        f"{USER_CONFIG_ENV_VAR} points to missing config file: {candidate}"
    )
93
+
94
+
95
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Matching is case-insensitive and ignores surrounding whitespace:
    ``1/true/yes/on`` map to ``True`` and ``0/false/no/off`` to ``False``.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or its content is
            not one of the recognised spellings.

    Returns:
        The parsed flag, or ``default``.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    recognised = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return recognised.get(raw.strip().lower(), default)
105
+
106
+
107
+ def _env_list(name: str) -> list[str] | None:
108
+ value = os.environ.get(name)
109
+ if value is None:
110
+ return None
111
+ return [item.strip() for item in value.split(",") if item.strip()]
112
+
113
+
114
def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]:
    """Add an env-derived Slack destination to ``raw_config`` when configured.

    The input is returned unchanged when ``ML_INTERN_SLACK_NOTIFICATIONS`` is
    set to a false value, or when ``SLACK_BOT_TOKEN`` or the channel
    (``SLACK_CHANNEL_ID``, falling back to ``SLACK_CHANNEL``) is missing or
    empty. Otherwise a copy of the config is returned in which:

    * a Slack destination named by ``ML_INTERN_SLACK_DESTINATION`` (default
      ``SLACK_DEFAULT_DESTINATION``) is added, unless a destination with that
      name already exists;
    * ``messaging.auto_event_types`` is taken from
      ``ML_INTERN_SLACK_AUTO_EVENTS`` when that variable is set, or filled
      with the built-in default when the key is absent;
    * ``messaging.enabled`` is set to ``True``.

    Args:
        raw_config: Parsed configuration mapping. It is not modified; the
            top-level, ``messaging`` and ``destinations`` dicts are copied
            before being changed.

    Returns:
        The original mapping (when Slack is not configured) or the updated
        copy.
    """
    if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True):
        return raw_config

    token = os.environ.get("SLACK_BOT_TOKEN")
    channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL")
    if not token or not channel:
        return raw_config

    config = dict(raw_config)
    messaging = dict(config.get("messaging") or {})
    destinations = dict(messaging.get("destinations") or {})

    # Strip *before* falling back to the default so a blank or
    # whitespace-only override cannot register a destination under "".
    destination_name = (
        (os.environ.get("ML_INTERN_SLACK_DESTINATION") or "").strip()
        or SLACK_DEFAULT_DESTINATION
    )

    # An explicitly configured destination of the same name wins over env.
    if destination_name not in destinations:
        destinations[destination_name] = {
            "provider": "slack",
            "token": token,
            "channel": channel,
            "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True),
            "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True),
        }

    auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS")
    if auto_events is not None:
        messaging["auto_event_types"] = auto_events
    elif "auto_event_types" not in messaging:
        # Copy so later edits to the config cannot mutate the module constant.
        messaging["auto_event_types"] = list(SLACK_DEFAULT_AUTO_EVENT_TYPES)

    messaging["enabled"] = True
    messaging["destinations"] = destinations
    config["messaging"] = messaging
    return config
150
 
151
 
152
  def substitute_env_vars(obj: Any) -> Any:
 
186
  return obj
187
 
188
 
189
+ def load_config(
190
+ config_path: str = "config.json",
191
+ include_user_defaults: bool = False,
192
+ ) -> Config:
193
  """
194
  Load configuration with environment variable substitution.
195
 
 
201
  load_dotenv(_PROJECT_ROOT / ".env")
202
  load_dotenv(override=False)
203
 
204
+ raw_config = _load_json_config(Path(config_path))
205
+ if include_user_defaults:
206
+ raw_config = _deep_merge_config(raw_config, _load_user_config())
207
+ raw_config = apply_slack_user_defaults(raw_config)
208
 
209
  config_with_env = substitute_env_vars(raw_config)
210
  return Config.model_validate(config_with_env)
agent/context_manager/manager.py CHANGED
@@ -160,6 +160,7 @@ class ContextManager:
160
  self.running_context_usage = 0
161
  self.untouched_messages = untouched_messages
162
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
 
163
 
164
  def _load_system_prompt(
165
  self,
@@ -219,6 +220,8 @@ class ContextManager:
219
  if token_count:
220
  self.running_context_usage = token_count
221
  self.items.append(message)
 
 
222
 
223
  def get_messages(self) -> list[Message]:
224
  """Get all messages for sending to LLM.
 
160
  self.running_context_usage = 0
161
  self.untouched_messages = untouched_messages
162
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
163
+ self.on_message_added = None
164
 
165
  def _load_system_prompt(
166
  self,
 
220
  if token_count:
221
  self.running_context_usage = token_count
222
  self.items.append(message)
223
+ if self.on_message_added:
224
+ self.on_message_added(message)
225
 
226
  def get_messages(self) -> list[Message]:
227
  """Get all messages for sending to LLM.
agent/core/agent_loop.py CHANGED
@@ -8,11 +8,18 @@ import logging
8
  import os
9
  import time
10
  from dataclasses import dataclass, field
11
-
12
- from litellm import ChatCompletionMessageToolCall, Message, acompletion
 
 
 
 
 
 
13
  from litellm.exceptions import ContextWindowExceededError
14
 
15
  from agent.config import Config
 
16
  from agent.core import telemetry
17
  from agent.core.doom_loop import check_for_doom_loop
18
  from agent.core.llm_params import _resolve_llm_params
@@ -396,12 +403,159 @@ class LLMResult:
396
  token_count: int
397
  finish_reason: str | None
398
  usage: dict = field(default_factory=dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
 
400
 
401
  async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
402
  """Call the LLM with streaming, emitting assistant_chunk events."""
403
  response = None
404
  _healed_effort = False # one-shot safety net per call
 
405
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
406
  t_start = time.monotonic()
407
  for _llm_attempt in range(_MAX_LLM_RETRIES):
@@ -429,6 +583,14 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
429
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
430
  ))
431
  continue
 
 
 
 
 
 
 
 
432
  _delay = _retry_delay_for(e, _llm_attempt)
433
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
434
  logger.warning(
@@ -448,8 +610,11 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
448
  token_count = 0
449
  finish_reason = None
450
  final_usage_chunk = None
 
 
451
 
452
  async for chunk in response:
 
453
  if session.is_cancelled:
454
  tool_calls_acc.clear()
455
  break
@@ -498,6 +663,16 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
498
  latency_ms=int((time.monotonic() - t_start) * 1000),
499
  finish_reason=finish_reason,
500
  )
 
 
 
 
 
 
 
 
 
 
501
 
502
  return LLMResult(
503
  content=full_content or None,
@@ -505,6 +680,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
505
  token_count=token_count,
506
  finish_reason=finish_reason,
507
  usage=usage,
 
 
508
  )
509
 
510
 
@@ -512,6 +689,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
512
  """Call the LLM without streaming, emit assistant_message at the end."""
513
  response = None
514
  _healed_effort = False
 
515
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
516
  t_start = time.monotonic()
517
  for _llm_attempt in range(_MAX_LLM_RETRIES):
@@ -538,6 +716,14 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
538
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
539
  ))
540
  continue
 
 
 
 
 
 
 
 
541
  _delay = _retry_delay_for(e, _llm_attempt)
542
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
543
  logger.warning(
@@ -557,6 +743,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
557
  content = message.content or None
558
  finish_reason = choice.finish_reason
559
  token_count = response.usage.total_tokens if response.usage else 0
 
560
 
561
  # Build tool_calls_acc in the same format as streaming
562
  tool_calls_acc: dict[int, dict] = {}
@@ -591,6 +778,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
591
  token_count=token_count,
592
  finish_reason=finish_reason,
593
  usage=usage,
 
 
594
  )
595
 
596
 
@@ -681,15 +870,6 @@ class Handlers:
681
  session.context_manager.add_message(
682
  Message(role="user", content=doom_prompt)
683
  )
684
- await session.send_event(
685
- Event(
686
- event_type="tool_log",
687
- data={
688
- "tool": "system",
689
- "log": "Doom loop detected — injecting corrective prompt",
690
- },
691
- )
692
- )
693
 
694
  malformed_tool = _detect_repeated_malformed(session.context_manager.items)
695
  if malformed_tool:
@@ -763,7 +943,10 @@ class Handlers:
763
  " • For other tools: reduce the size of your arguments or use bash."
764
  )
765
  if content:
766
- assistant_msg = Message(role="assistant", content=content)
 
 
 
767
  session.context_manager.add_message(assistant_msg, token_count)
768
  session.context_manager.add_message(
769
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
@@ -819,7 +1002,10 @@ class Handlers:
819
  (content or "")[:500],
820
  )
821
  if content:
822
- assistant_msg = Message(role="assistant", content=content)
 
 
 
823
  session.context_manager.add_message(assistant_msg, token_count)
824
  final_response = content
825
  break
@@ -841,9 +1027,9 @@ class Handlers:
841
  bad_tools.append(tc)
842
 
843
  # Add assistant message with all tool calls to context
844
- assistant_msg = Message(
845
- role="assistant",
846
- content=content,
847
  tool_calls=tool_calls,
848
  )
849
  session.context_manager.add_message(assistant_msg, token_count)
@@ -1049,7 +1235,12 @@ class Handlers:
1049
  await session.send_event(
1050
  Event(
1051
  event_type="turn_complete",
1052
- data={"history_size": len(session.context_manager.items)},
 
 
 
 
 
1053
  )
1054
  )
1055
 
@@ -1358,12 +1549,16 @@ async def process_submission(session: Session, submission) -> bool:
1358
  async def submission_loop(
1359
  submission_queue: asyncio.Queue,
1360
  event_queue: asyncio.Queue,
1361
- config: Config | None = None,
1362
  tool_router: ToolRouter | None = None,
1363
  session_holder: list | None = None,
1364
  hf_token: str | None = None,
 
1365
  local_mode: bool = False,
1366
  stream: bool = True,
 
 
 
1367
  ) -> None:
1368
  """
1369
  Main agent loop - processes submissions and dispatches to handlers.
@@ -1373,7 +1568,10 @@ async def submission_loop(
1373
  # Create session with tool router
1374
  session = Session(
1375
  event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1376
- local_mode=local_mode, stream=stream,
 
 
 
1377
  )
1378
  if session_holder is not None:
1379
  session_holder[0] = session
 
8
  import os
9
  import time
10
  from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ from litellm import (
14
+ ChatCompletionMessageToolCall,
15
+ Message,
16
+ acompletion,
17
+ stream_chunk_builder,
18
+ )
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
22
+ from agent.messaging.gateway import NotificationGateway
23
  from agent.core import telemetry
24
  from agent.core.doom_loop import check_for_doom_loop
25
  from agent.core.llm_params import _resolve_llm_params
 
403
  token_count: int
404
  finish_reason: str | None
405
  usage: dict = field(default_factory=dict)
406
+ thinking_blocks: list[dict[str, Any]] | None = None
407
+ reasoning_content: str | None = None
408
+
409
+
410
+ def _extract_thinking_state(
411
+ message: Any,
412
+ ) -> tuple[list[dict[str, Any]] | None, str | None]:
413
+ """Return provider reasoning fields that must be replayed after tool calls."""
414
+ provider_fields = getattr(message, "provider_specific_fields", None)
415
+ if not isinstance(provider_fields, dict):
416
+ provider_fields = {}
417
+
418
+ thinking_blocks = (
419
+ getattr(message, "thinking_blocks", None)
420
+ or provider_fields.get("thinking_blocks")
421
+ or None
422
+ )
423
+ reasoning_content = (
424
+ getattr(message, "reasoning_content", None)
425
+ or provider_fields.get("reasoning_content")
426
+ or None
427
+ )
428
+ return thinking_blocks, reasoning_content
429
+
430
+
431
+ def _should_replay_thinking_state(model_name: str | None) -> bool:
432
+ """Only Anthropic's native adapter accepts replayed thinking metadata."""
433
+ return bool(model_name and model_name.startswith("anthropic/"))
434
+
435
+
436
def _is_invalid_thinking_signature_error(exc: Exception) -> bool:
    """Tell whether ``exc`` reports an invalid thinking-block signature.

    The check is a case-sensitive substring match of ``str(exc)`` against the
    two known spellings of the error text (with and without backticks).
    """
    rendered = str(exc)
    markers = (
        "Invalid `signature` in `thinking` block",
        "Invalid signature in thinking block",
    )
    return any(marker in rendered for marker in markers)
443
+
444
+
445
def _strip_thinking_state_from_messages(messages: list[Any]) -> int:
    """Remove thinking metadata from assistant messages, in place.

    Works on both dict-style messages and attribute-style message objects.
    For every message whose role is ``"assistant"`` it clears:

    * the ``thinking_blocks`` and ``reasoning_content`` fields (dict keys are
      removed, object attributes are set to ``None``);
    * the same two keys inside ``provider_specific_fields`` (the dict is
      replaced by a cleaned copy when it changed);
    * ``thinking`` / ``redacted_thinking`` dict blocks inside a list-valued
      ``content``.

    Args:
        messages: Conversation history; entries are mutated in place.

    Returns:
        How many non-``None`` fields and content blocks were removed
        (``0`` means nothing was stripped).
    """
    thinking_fields = ("thinking_blocks", "reasoning_content")
    thinking_block_types = {"thinking", "redacted_thinking"}
    removed = 0

    for msg in messages:
        is_mapping = isinstance(msg, dict)
        role = msg.get("role") if is_mapping else getattr(msg, "role", None)
        if role != "assistant":
            continue

        # Top-level fields: dict keys are dropped, attributes are nulled.
        if is_mapping:
            for field_name in thinking_fields:
                if msg.pop(field_name, None) is not None:
                    removed += 1
            extras = msg.get("provider_specific_fields")
            body = msg.get("content")
        else:
            for field_name in thinking_fields:
                if getattr(msg, field_name, None) is not None:
                    setattr(msg, field_name, None)
                    removed += 1
            extras = getattr(msg, "provider_specific_fields", None)
            body = getattr(msg, "content", None)

        # Provider-specific fields: edit a copy, write back only on change.
        if isinstance(extras, dict):
            pruned = dict(extras)
            for field_name in thinking_fields:
                if pruned.pop(field_name, None) is not None:
                    removed += 1
            if pruned != extras:
                if is_mapping:
                    msg["provider_specific_fields"] = pruned
                else:
                    msg.provider_specific_fields = pruned

        # Structured content: filter out thinking blocks, keep everything else.
        if isinstance(body, list):
            kept = [
                part
                for part in body
                if not (
                    isinstance(part, dict)
                    and part.get("type") in thinking_block_types
                )
            ]
            dropped = len(body) - len(kept)
            if dropped:
                removed += dropped
                if is_mapping:
                    msg["content"] = kept
                else:
                    msg.content = kept

    return removed
504
+
505
+
506
async def _maybe_heal_invalid_thinking_signature(
    session: Session,
    messages: list[Any],
    exc: Exception,
    *,
    already_healed: bool,
) -> bool:
    """Strip thinking metadata from ``messages`` after a signature error.

    Args:
        session: Session used to emit a ``tool_log`` event on success.
        messages: Message history; mutated in place when healing happens.
        exc: The exception raised by the failed LLM call.
        already_healed: True when a previous attempt in the same call already
            healed, in which case nothing is done.

    Returns:
        ``True`` when metadata was stripped (and the event was sent), so the
        caller can retry; ``False`` when healing was already done, ``exc`` is
        not a thinking-signature error, or there was nothing to strip.
    """
    if already_healed:
        return False
    if not _is_invalid_thinking_signature_error(exc):
        return False

    # Nothing removed means a retry would send the exact same payload.
    if _strip_thinking_state_from_messages(messages) == 0:
        return False

    notice = (
        "Anthropic rejected stale thinking signatures; retrying "
        "without replayed thinking metadata."
    )
    log_event = Event(
        event_type="tool_log",
        data={"tool": "system", "log": notice},
    )
    await session.send_event(log_event)
    return True
531
+
532
+
533
def _assistant_message_from_result(
    llm_result: LLMResult,
    *,
    model_name: str | None,
    tool_calls: list[ToolCall] | None = None,
) -> Message:
    """Turn an ``LLMResult`` into an assistant ``Message`` for the history.

    Args:
        llm_result: Result whose content (and possibly thinking state) is
            copied into the message.
        model_name: Model identifier; thinking state is attached only when
            ``_should_replay_thinking_state`` accepts it.
        tool_calls: Tool calls to attach; omitted from the message when
            ``None``.

    Returns:
        The assistant message, carrying ``thinking_blocks`` and
        ``reasoning_content`` when they are truthy and replay is enabled.
    """
    fields: dict[str, Any] = {"role": "assistant", "content": llm_result.content}
    if tool_calls is not None:
        fields["tool_calls"] = tool_calls

    if not _should_replay_thinking_state(model_name):
        return Message(**fields)

    # Only forward reasoning fields that actually hold something.
    for attr in ("thinking_blocks", "reasoning_content"):
        value = getattr(llm_result, attr)
        if value:
            fields[attr] = value
    return Message(**fields)
552
 
553
 
554
  async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
555
  """Call the LLM with streaming, emitting assistant_chunk events."""
556
  response = None
557
  _healed_effort = False # one-shot safety net per call
558
+ _healed_thinking_signature = False
559
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
560
  t_start = time.monotonic()
561
  for _llm_attempt in range(_MAX_LLM_RETRIES):
 
583
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
584
  ))
585
  continue
586
+ if await _maybe_heal_invalid_thinking_signature(
587
+ session,
588
+ messages,
589
+ e,
590
+ already_healed=_healed_thinking_signature,
591
+ ):
592
+ _healed_thinking_signature = True
593
+ continue
594
  _delay = _retry_delay_for(e, _llm_attempt)
595
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
596
  logger.warning(
 
610
  token_count = 0
611
  finish_reason = None
612
  final_usage_chunk = None
613
+ chunks = []
614
+ should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
615
 
616
  async for chunk in response:
617
+ chunks.append(chunk)
618
  if session.is_cancelled:
619
  tool_calls_acc.clear()
620
  break
 
663
  latency_ms=int((time.monotonic() - t_start) * 1000),
664
  finish_reason=finish_reason,
665
  )
666
+ thinking_blocks = None
667
+ reasoning_content = None
668
+ if chunks and should_replay_thinking:
669
+ try:
670
+ rebuilt = stream_chunk_builder(chunks, messages=messages)
671
+ if rebuilt and getattr(rebuilt, "choices", None):
672
+ rebuilt_msg = rebuilt.choices[0].message
673
+ thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg)
674
+ except Exception:
675
+ logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
676
 
677
  return LLMResult(
678
  content=full_content or None,
 
680
  token_count=token_count,
681
  finish_reason=finish_reason,
682
  usage=usage,
683
+ thinking_blocks=thinking_blocks,
684
+ reasoning_content=reasoning_content,
685
  )
686
 
687
 
 
689
  """Call the LLM without streaming, emit assistant_message at the end."""
690
  response = None
691
  _healed_effort = False
692
+ _healed_thinking_signature = False
693
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
694
  t_start = time.monotonic()
695
  for _llm_attempt in range(_MAX_LLM_RETRIES):
 
716
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
717
  ))
718
  continue
719
+ if await _maybe_heal_invalid_thinking_signature(
720
+ session,
721
+ messages,
722
+ e,
723
+ already_healed=_healed_thinking_signature,
724
+ ):
725
+ _healed_thinking_signature = True
726
+ continue
727
  _delay = _retry_delay_for(e, _llm_attempt)
728
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
729
  logger.warning(
 
743
  content = message.content or None
744
  finish_reason = choice.finish_reason
745
  token_count = response.usage.total_tokens if response.usage else 0
746
+ thinking_blocks, reasoning_content = _extract_thinking_state(message)
747
 
748
  # Build tool_calls_acc in the same format as streaming
749
  tool_calls_acc: dict[int, dict] = {}
 
778
  token_count=token_count,
779
  finish_reason=finish_reason,
780
  usage=usage,
781
+ thinking_blocks=thinking_blocks,
782
+ reasoning_content=reasoning_content,
783
  )
784
 
785
 
 
870
  session.context_manager.add_message(
871
  Message(role="user", content=doom_prompt)
872
  )
 
 
 
 
 
 
 
 
 
873
 
874
  malformed_tool = _detect_repeated_malformed(session.context_manager.items)
875
  if malformed_tool:
 
943
  " • For other tools: reduce the size of your arguments or use bash."
944
  )
945
  if content:
946
+ assistant_msg = _assistant_message_from_result(
947
+ llm_result,
948
+ model_name=llm_params.get("model"),
949
+ )
950
  session.context_manager.add_message(assistant_msg, token_count)
951
  session.context_manager.add_message(
952
  Message(role="user", content=f"[SYSTEM: {truncation_hint}]")
 
1002
  (content or "")[:500],
1003
  )
1004
  if content:
1005
+ assistant_msg = _assistant_message_from_result(
1006
+ llm_result,
1007
+ model_name=llm_params.get("model"),
1008
+ )
1009
  session.context_manager.add_message(assistant_msg, token_count)
1010
  final_response = content
1011
  break
 
1027
  bad_tools.append(tc)
1028
 
1029
  # Add assistant message with all tool calls to context
1030
+ assistant_msg = _assistant_message_from_result(
1031
+ llm_result,
1032
+ model_name=llm_params.get("model"),
1033
  tool_calls=tool_calls,
1034
  )
1035
  session.context_manager.add_message(assistant_msg, token_count)
 
1235
  await session.send_event(
1236
  Event(
1237
  event_type="turn_complete",
1238
+ data={
1239
+ "history_size": len(session.context_manager.items),
1240
+ "final_response": final_response
1241
+ if isinstance(final_response, str)
1242
+ else None,
1243
+ },
1244
  )
1245
  )
1246
 
 
1549
  async def submission_loop(
1550
  submission_queue: asyncio.Queue,
1551
  event_queue: asyncio.Queue,
1552
+ config: Config,
1553
  tool_router: ToolRouter | None = None,
1554
  session_holder: list | None = None,
1555
  hf_token: str | None = None,
1556
+ user_id: str | None = None,
1557
  local_mode: bool = False,
1558
  stream: bool = True,
1559
+ notification_gateway: NotificationGateway | None = None,
1560
+ notification_destinations: list[str] | None = None,
1561
+ defer_turn_complete_notification: bool = False,
1562
  ) -> None:
1563
  """
1564
  Main agent loop - processes submissions and dispatches to handlers.
 
1568
  # Create session with tool router
1569
  session = Session(
1570
  event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1571
+ user_id=user_id, local_mode=local_mode, stream=stream,
1572
+ notification_gateway=notification_gateway,
1573
+ notification_destinations=notification_destinations,
1574
+ defer_turn_complete_notification=defer_turn_complete_notification,
1575
  )
1576
  if session_holder is not None:
1577
  session_holder[0] = session
agent/core/doom_loop.py CHANGED
@@ -24,9 +24,36 @@ class ToolCallSignature:
24
  result_hash: str | None = None
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def _hash_args(args_str: str) -> str:
28
- """Return a short hash of the JSON arguments string."""
29
- return hashlib.md5(args_str.encode()).hexdigest()[:12]
 
 
 
 
 
30
 
31
 
32
  def extract_recent_tool_signatures(
@@ -129,9 +156,13 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
129
  # Check for identical consecutive calls
130
  tool_name = detect_identical_consecutive(signatures, threshold=3)
131
  if tool_name:
132
- logger.warning("Doom loop detected: %d+ identical consecutive calls to '%s'", 3, tool_name)
 
 
 
 
133
  return (
134
- f"[SYSTEM: DOOM LOOP DETECTED] You have called '{tool_name}' with the same "
135
  f"arguments multiple times in a row, getting the same result each time. "
136
  f"STOP repeating this approach — it is not working. "
137
  f"Step back and try a fundamentally different strategy. "
@@ -143,9 +174,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
143
  pattern = detect_repeating_sequence(signatures)
144
  if pattern:
145
  pattern_desc = " → ".join(s.name for s in pattern)
146
- logger.warning("Doom loop detected: repeating sequence [%s]", pattern_desc)
147
  return (
148
- f"[SYSTEM: DOOM LOOP DETECTED] You are stuck in a repeating cycle of tool calls: "
149
  f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
150
  f"STOP this cycle and try a fundamentally different approach. "
151
  f"Consider: breaking down the problem differently, using alternative tools, "
 
24
  result_hash: str | None = None
25
 
26
 
27
+ def _normalize_args(args_str: str) -> str:
28
+ """Canonicalise a tool-call arguments string before hashing.
29
+
30
+ LLMs can emit semantically-identical JSON for the same call with different
31
+ key orderings (``{"a": 1, "b": 2}`` vs ``{"b": 2, "a": 1}``) or whitespace
32
+ (``{"a":1}`` vs ``{"a": 1}``). Hashing the raw bytes makes the doom-loop
33
+ detector miss those repeats. We parse-and-redump with ``sort_keys=True``
34
+ plus the most compact separators so trivially-different spellings collapse
35
+ to the same canonical form.
36
+
37
+ Falls back to the original string if the input isn't valid JSON (e.g. a
38
+ handful of providers occasionally pass a bare string for ``arguments``);
39
+ that path keeps the legacy behaviour and never raises.
40
+ """
41
+ if not args_str:
42
+ return ""
43
+ try:
44
+ return json.dumps(json.loads(args_str), sort_keys=True, separators=(",", ":"))
45
+ except (json.JSONDecodeError, TypeError, ValueError):
46
+ return args_str
47
+
48
+
49
def _hash_args(args_str: str) -> str:
    """Return a 12-hex-character fingerprint of a tool-call arguments string.

    The string goes through :func:`_normalize_args` before hashing, so calls
    that differ only in key order or whitespace share one fingerprint.
    """
    canonical = _normalize_args(args_str)
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:12]
57
 
58
 
59
  def extract_recent_tool_signatures(
 
156
  # Check for identical consecutive calls
157
  tool_name = detect_identical_consecutive(signatures, threshold=3)
158
  if tool_name:
159
+ logger.warning(
160
+ "Repetition guard activated: %d+ identical consecutive calls to '%s'",
161
+ 3,
162
+ tool_name,
163
+ )
164
  return (
165
+ f"[SYSTEM: REPETITION GUARD] You have called '{tool_name}' with the same "
166
  f"arguments multiple times in a row, getting the same result each time. "
167
  f"STOP repeating this approach — it is not working. "
168
  f"Step back and try a fundamentally different strategy. "
 
174
  pattern = detect_repeating_sequence(signatures)
175
  if pattern:
176
  pattern_desc = " → ".join(s.name for s in pattern)
177
+ logger.warning("Repetition guard activated: repeating sequence [%s]", pattern_desc)
178
  return (
179
+ f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
180
  f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
181
  f"STOP this cycle and try a fundamentally different approach. "
182
  f"Consider: breaking down the problem differently, using alternative tools, "
agent/core/hf_access.py CHANGED
@@ -55,6 +55,13 @@ def _extract_username(whoami: dict[str, Any]) -> str | None:
55
 
56
 
57
  def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
 
 
 
 
 
 
 
58
  plan_str = ""
59
  for key in ("plan", "type", "accountType"):
60
  value = whoami.get(key)
@@ -62,9 +69,6 @@ def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
62
  plan_str = value.lower()
63
  break
64
 
65
- if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True):
66
- return "pro"
67
-
68
  if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
69
  return "pro"
70
  return "free"
 
55
 
56
 
57
  def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
58
+ # OAuth whoami responses set `type: "user"` and surface Pro status only via
59
+ # the `isPro` boolean. Check the boolean first so a generic `type` value
60
+ # doesn't shadow it — otherwise Pro OAuth users get classified as free and
61
+ # blocked from running Jobs (smolagents/ml-intern Space discussion #21).
62
+ if whoami.get("isPro") is True or whoami.get("is_pro") is True:
63
+ return "pro"
64
+
65
  plan_str = ""
66
  for key in ("plan", "type", "accountType"):
67
  value = whoami.get(key)
 
69
  plan_str = value.lower()
70
  break
71
 
 
 
 
72
  if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
73
  return "pro"
74
  return "free"
agent/core/hf_tokens.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hugging Face token resolution helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Any
7
+
8
+
9
def clean_hf_token(token: str | None) -> str | None:
    """Strip CR/LF characters and surrounding whitespace from a token.

    Returns ``None`` when *token* is ``None`` or nothing is left after
    cleaning.
    """
    if token is None:
        return None
    cleaned = token
    for line_break in ("\r", "\n"):
        cleaned = cleaned.replace(line_break, "")
    cleaned = cleaned.strip()
    return cleaned if cleaned else None
14
+
15
+
16
def get_cached_hf_token() -> str | None:
    """Return whatever ``huggingface_hub.get_token()`` reports, or ``None``.

    The import happens lazily inside the function; any failure (missing
    package or an error raised by the lookup) is treated as "no token".
    """
    try:
        from huggingface_hub import get_token

        cached = get_token()
    except Exception:
        cached = None
    return cached
24
+
25
+
26
+ def resolve_hf_token(
27
+ *candidates: str | None,
28
+ include_cached: bool = True,
29
+ ) -> str | None:
30
+ """Return the first non-empty explicit token, then optionally HF cache."""
31
+ for token in candidates:
32
+ cleaned = clean_hf_token(token)
33
+ if cleaned:
34
+ return cleaned
35
+ if include_cached:
36
+ return get_cached_hf_token()
37
+ return None
38
+
39
+
40
def resolve_hf_router_token(session_hf_token: str | None = None) -> str | None:
    """Choose the token for Hugging Face Router LLM calls.

    Precedence, highest first:
        1. The ``INFERENCE_TOKEN`` environment variable (shared
           hosted-Space inference token).
        2. *session_hf_token*, the active user/session token.
        3. ``huggingface_hub.get_token()``: ``HF_TOKEN`` /
           ``HUGGING_FACE_HUB_TOKEN`` or the local ``hf auth login`` cache.
    """
    shared_token = os.environ.get("INFERENCE_TOKEN")
    return resolve_hf_token(shared_token, session_hf_token)
50
+
51
+
52
def get_hf_bill_to() -> str | None:
    """Return the ``X-HF-Bill-To`` value, but only with a shared token set.

    When ``INFERENCE_TOKEN`` is unset or cleans to nothing the result is
    ``None``. Otherwise it is ``HF_BILL_TO`` from the environment, falling
    back to ``"smolagents"``.
    """
    shared_token = clean_hf_token(os.environ.get("INFERENCE_TOKEN"))
    if not shared_token:
        return None
    return os.environ.get("HF_BILL_TO", "smolagents")
57
+
58
+
59
def bearer_token_from_header(auth_header: str | None) -> str | None:
    """Extract a cleaned bearer token from an ``Authorization`` header value.

    The scheme name is compared case-insensitively, because HTTP
    authentication scheme names are case-insensitive: ``Bearer x``,
    ``bearer x`` and ``BEARER x`` are all accepted.

    Args:
        auth_header: Raw header value, or ``None``/empty when absent.

    Returns:
        The credentials after the scheme, cleaned by :func:`clean_hf_token`,
        or ``None`` when the header is missing, uses another scheme, or
        carries no credentials.
    """
    if not auth_header:
        return None
    # Split on the first space only; the scheme must be followed by a space.
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None
    return clean_hf_token(credentials)
64
+
65
+
66
def resolve_hf_request_token(
    request: Any,
    *,
    include_env_fallback: bool = True,
) -> str | None:
    """Work out which user token a request should act with.

    Sources are tried in order: the ``Authorization`` bearer header, the
    ``hf_access_token`` cookie, and finally (only when
    *include_env_fallback* is true) the server's ``HF_TOKEN`` environment
    variable. The local ``hf auth login`` cache is deliberately not
    consulted, so backend request paths act as the browser user or an
    explicitly configured server token.

    *request* only needs ``headers.get`` and ``cookies.get``.
    """
    header_token = bearer_token_from_header(request.headers.get("Authorization", ""))
    if header_token:
        return header_token

    cookie_token = clean_hf_token(request.cookies.get("hf_access_token"))
    if cookie_token:
        return cookie_token

    if not include_env_fallback:
        return None
    return clean_hf_token(os.environ.get("HF_TOKEN"))
agent/core/llm_params.py CHANGED
@@ -5,7 +5,12 @@ can import it without pulling in the whole agent loop / tool router and
5
  creating circular imports.
6
  """
7
 
8
- import os
 
 
 
 
 
9
 
10
 
11
  def _patch_litellm_effort_validation() -> None:
@@ -129,7 +134,8 @@ def _resolve_llm_params(
129
  1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is
130
  free for users, billed to the Space owner via ``X-HF-Bill-To``).
131
  2. session.hf_token — the user's own token (CLI / OAuth / cache file).
132
- 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users.
 
133
  """
134
  if model_name.startswith("anthropic/"):
135
  params: dict = {"model": model_name}
@@ -175,18 +181,13 @@ def _resolve_llm_params(
175
  return params
176
 
177
  hf_model = model_name.removeprefix("huggingface/")
178
- api_key = (
179
- os.environ.get("INFERENCE_TOKEN")
180
- or session_hf_token
181
- or os.environ.get("HF_TOKEN")
182
- )
183
  params = {
184
  "model": f"openai/{hf_model}",
185
  "api_base": "https://router.huggingface.co/v1",
186
  "api_key": api_key,
187
  }
188
- if os.environ.get("INFERENCE_TOKEN"):
189
- bill_to = os.environ.get("HF_BILL_TO", "smolagents")
190
  params["extra_headers"] = {"X-HF-Bill-To": bill_to}
191
  if reasoning_effort:
192
  hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort
 
5
  creating circular imports.
6
  """
7
 
8
+ from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token
9
+
10
+
11
def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None:
    """Private alias kept for tests and older imports.

    Simply forwards to ``resolve_hf_router_token`` from
    ``agent.core.hf_tokens``.
    """
    token = resolve_hf_router_token(session_hf_token)
    return token
14
 
15
 
16
  def _patch_litellm_effort_validation() -> None:
 
134
  1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is
135
  free for users, billed to the Space owner via ``X-HF-Bill-To``).
136
  2. session.hf_token — the user's own token (CLI / OAuth / cache file).
137
+ 3. huggingface_hub cache — ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` /
138
+ local ``hf auth login`` cache.
139
  """
140
  if model_name.startswith("anthropic/"):
141
  params: dict = {"model": model_name}
 
181
  return params
182
 
183
  hf_model = model_name.removeprefix("huggingface/")
184
+ api_key = _resolve_hf_router_token(session_hf_token)
 
 
 
 
185
  params = {
186
  "model": f"openai/{hf_model}",
187
  "api_base": "https://router.huggingface.co/v1",
188
  "api_key": api_key,
189
  }
190
+ if bill_to := get_hf_bill_to():
 
191
  params["extra_headers"] = {"X-HF-Bill-To": bill_to}
192
  if reasoning_effort:
193
  hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort
agent/core/session.py CHANGED
@@ -12,10 +12,13 @@ from typing import Any, Optional
12
 
13
  from agent.config import Config
14
  from agent.context_manager.manager import ContextManager
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  _DEFAULT_MAX_TOKENS = 200_000
 
19
 
20
 
21
  def _get_max_tokens_safe(model_name: str) -> int:
@@ -62,6 +65,7 @@ class OpType(Enum):
62
  class Event:
63
  event_type: str
64
  data: Optional[dict[str, Any]] = None
 
65
 
66
 
67
  class Session:
@@ -73,16 +77,26 @@ class Session:
73
  def __init__(
74
  self,
75
  event_queue: asyncio.Queue,
76
- config: Config | None = None,
77
  tool_router=None,
78
  context_manager: ContextManager | None = None,
79
  hf_token: str | None = None,
80
  local_mode: bool = False,
81
  stream: bool = True,
 
 
 
 
 
 
82
  ):
83
  self.hf_token: Optional[str] = hf_token
 
 
84
  self.tool_router = tool_router
85
  self.stream = stream
 
 
86
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
87
  self.context_manager = context_manager or ContextManager(
88
  model_max_tokens=_get_max_tokens_safe(config.model_name),
@@ -93,15 +107,16 @@ class Session:
93
  local_mode=local_mode,
94
  )
95
  self.event_queue = event_queue
96
- self.session_id = str(uuid.uuid4())
97
- self.config = config or Config(
98
- model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
99
- )
100
  self.is_running = True
101
  self._cancelled = asyncio.Event()
102
  self.pending_approval: Optional[dict[str, Any]] = None
103
  self.sandbox = None
104
  self._running_job_ids: set[str] = set() # HF job IDs currently executing
 
 
 
105
 
106
  # Session trajectory logging
107
  self.logged_events: list[dict] = []
@@ -123,11 +138,10 @@ class Session:
123
  # thinking params at all
124
  # Key absent → not probed yet; fall back to the raw preference.
125
  self.model_effective_effort: dict[str, str | None] = {}
 
126
 
127
  async def send_event(self, event: Event) -> None:
128
  """Send event back to client and log to trajectory"""
129
- await self.event_queue.put(event)
130
-
131
  # Log event to trajectory
132
  self.logged_events.append(
133
  {
@@ -136,11 +150,149 @@ class Session:
136
  "data": event.data,
137
  }
138
  )
 
 
 
 
 
 
 
 
 
 
139
 
140
  # Mid-turn heartbeat flush (owned by telemetry module).
141
  from agent.core.telemetry import HeartbeatSaver
 
142
  HeartbeatSaver.maybe_fire(self)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def cancel(self) -> None:
145
  """Signal cancellation to the running agent loop."""
146
  self._cancelled.set()
@@ -199,11 +351,21 @@ class Session:
199
  tools = self.tool_router.get_tool_specs_for_llm() or []
200
  except Exception:
201
  tools = []
 
 
 
 
 
 
 
 
202
  return {
203
  "session_id": self.session_id,
 
204
  "session_start_time": self.session_start_time,
205
  "session_end_time": datetime.now().isoformat(),
206
  "model_name": self.config.model_name,
 
207
  "messages": [msg.model_dump() for msg in self.context_manager.items],
208
  "events": self.logged_events,
209
  "tools": tools,
 
12
 
13
  from agent.config import Config
14
  from agent.context_manager.manager import ContextManager
15
+ from agent.messaging.gateway import NotificationGateway
16
+ from agent.messaging.models import NotificationRequest
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  _DEFAULT_MAX_TOKENS = 200_000
21
+ _TURN_COMPLETE_NOTIFICATION_CHARS = 39000
22
 
23
 
24
  def _get_max_tokens_safe(model_name: str) -> int:
 
65
  class Event:
66
  event_type: str
67
  data: Optional[dict[str, Any]] = None
68
+ seq: Optional[int] = None
69
 
70
 
71
  class Session:
 
77
  def __init__(
78
  self,
79
  event_queue: asyncio.Queue,
80
+ config: Config,
81
  tool_router=None,
82
  context_manager: ContextManager | None = None,
83
  hf_token: str | None = None,
84
  local_mode: bool = False,
85
  stream: bool = True,
86
+ notification_gateway: NotificationGateway | None = None,
87
+ notification_destinations: list[str] | None = None,
88
+ defer_turn_complete_notification: bool = False,
89
+ session_id: str | None = None,
90
+ user_id: str | None = None,
91
+ persistence_store: Any | None = None,
92
  ):
93
  self.hf_token: Optional[str] = hf_token
94
+ self.user_id: Optional[str] = user_id
95
+ self.persistence_store = persistence_store
96
  self.tool_router = tool_router
97
  self.stream = stream
98
+ if config is None:
99
+ raise ValueError("Session requires a Config")
100
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
101
  self.context_manager = context_manager or ContextManager(
102
  model_max_tokens=_get_max_tokens_safe(config.model_name),
 
107
  local_mode=local_mode,
108
  )
109
  self.event_queue = event_queue
110
+ self.session_id = session_id or str(uuid.uuid4())
111
+ self.config = config
 
 
112
  self.is_running = True
113
  self._cancelled = asyncio.Event()
114
  self.pending_approval: Optional[dict[str, Any]] = None
115
  self.sandbox = None
116
  self._running_job_ids: set[str] = set() # HF job IDs currently executing
117
+ self.notification_gateway = notification_gateway
118
+ self.notification_destinations = list(notification_destinations or [])
119
+ self.defer_turn_complete_notification = defer_turn_complete_notification
120
 
121
  # Session trajectory logging
122
  self.logged_events: list[dict] = []
 
138
  # thinking params at all
139
  # Key absent → not probed yet; fall back to the raw preference.
140
  self.model_effective_effort: dict[str, str | None] = {}
141
+ self.context_manager.on_message_added = self._schedule_trace_message
142
 
143
  async def send_event(self, event: Event) -> None:
144
  """Send event back to client and log to trajectory"""
 
 
145
  # Log event to trajectory
146
  self.logged_events.append(
147
  {
 
150
  "data": event.data,
151
  }
152
  )
153
+ if self.persistence_store is not None:
154
+ try:
155
+ event.seq = await self.persistence_store.append_event(
156
+ self.session_id, event.event_type, event.data
157
+ )
158
+ except Exception as e:
159
+ logger.debug("Event persistence failed for %s: %s", self.session_id, e)
160
+
161
+ await self.event_queue.put(event)
162
+ await self._enqueue_auto_notification_requests(event)
163
 
164
  # Mid-turn heartbeat flush (owned by telemetry module).
165
  from agent.core.telemetry import HeartbeatSaver
166
+
167
  HeartbeatSaver.maybe_fire(self)
168
 
169
+ def _schedule_trace_message(self, message: Any) -> None:
170
+ """Best-effort append-only trace save for SFT/KPI export."""
171
+ if self.persistence_store is None:
172
+ return
173
+ try:
174
+ payload = message.model_dump(mode="json")
175
+ except Exception:
176
+ return
177
+ try:
178
+ loop = asyncio.get_running_loop()
179
+ except RuntimeError:
180
+ return
181
+ source = str(payload.get("role") or "message")
182
+ loop.create_task(
183
+ self.persistence_store.append_trace_message(
184
+ self.session_id, payload, source=source
185
+ )
186
+ )
187
+
188
def set_notification_destinations(self, destinations: list[str]) -> None:
    """Replace the opted-in auto-notification destinations for this session.

    Duplicates are dropped; the first occurrence of each destination keeps
    its position.
    """
    # dict keys are unique and remember insertion order.
    self.notification_destinations = list(dict.fromkeys(destinations))
197
+
198
async def send_deferred_turn_complete_notification(self, event: Event) -> None:
    """Emit the notification for a ``turn_complete`` event that was deferred.

    Events of any other type are ignored. For ``turn_complete`` the normal
    auto-notification path runs with the deferral check bypassed.
    """
    if event.event_type == "turn_complete":
        await self._enqueue_auto_notification_requests(
            event,
            include_deferred_turn_complete=True,
        )
205
+
206
+ async def _enqueue_auto_notification_requests(
207
+ self,
208
+ event: Event,
209
+ include_deferred_turn_complete: bool = False,
210
+ ) -> None:
211
+ if self.notification_gateway is None:
212
+ return
213
+ if not self.notification_destinations:
214
+ return
215
+ auto_events = set(self.config.messaging.auto_event_types)
216
+ if event.event_type not in auto_events:
217
+ return
218
+ if (
219
+ self.defer_turn_complete_notification
220
+ and event.event_type == "turn_complete"
221
+ and not include_deferred_turn_complete
222
+ ):
223
+ return
224
+
225
+ requests = self._build_auto_notification_requests(event)
226
+ for request in requests:
227
+ await self.notification_gateway.enqueue(request)
228
+
229
+ def _build_auto_notification_requests(
230
+ self, event: Event
231
+ ) -> list[NotificationRequest]:
232
+ metadata = {
233
+ "session_id": self.session_id,
234
+ "model": self.config.model_name,
235
+ "event_type": event.event_type,
236
+ }
237
+
238
+ title: str | None = None
239
+ message: str | None = None
240
+ severity = "info"
241
+ data = event.data or {}
242
+ if event.event_type == "approval_required":
243
+ tools = data.get("tools", [])
244
+ tool_names = []
245
+ for tool in tools if isinstance(tools, list) else []:
246
+ if isinstance(tool, dict):
247
+ tool_name = str(tool.get("tool") or "").strip()
248
+ if tool_name and tool_name not in tool_names:
249
+ tool_names.append(tool_name)
250
+ count = len(tools) if isinstance(tools, list) else 0
251
+ title = "Agent approval required"
252
+ message = (
253
+ f"Session {self.session_id} is waiting for approval "
254
+ f"for {count} tool call(s)."
255
+ )
256
+ if tool_names:
257
+ message += " Tools: " + ", ".join(tool_names)
258
+ severity = "warning"
259
+ elif event.event_type == "error":
260
+ title = "Agent error"
261
+ error = str(data.get("error") or "Unknown error")
262
+ message = f"Session {self.session_id} hit an error.\n{error[:500]}"
263
+ severity = "error"
264
+ elif event.event_type == "turn_complete":
265
+ title = "Agent task complete"
266
+ summary = str(data.get("final_response") or "").strip()
267
+ if summary:
268
+ summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
269
+ message = (
270
+ f"Session {self.session_id} completed successfully.\n"
271
+ f"{summary}"
272
+ )
273
+ else:
274
+ message = f"Session {self.session_id} completed successfully."
275
+ severity = "success"
276
+
277
+ if message is None:
278
+ return []
279
+
280
+ requests: list[NotificationRequest] = []
281
+ for destination in self.notification_destinations:
282
+ if not self.config.messaging.can_auto_send(destination):
283
+ continue
284
+ requests.append(
285
+ NotificationRequest(
286
+ destination=destination,
287
+ title=title,
288
+ message=message,
289
+ severity=severity,
290
+ metadata=metadata,
291
+ event_type=event.event_type,
292
+ )
293
+ )
294
+ return requests
295
+
296
  def cancel(self) -> None:
297
  """Signal cancellation to the running agent loop."""
298
  self._cancelled.set()
 
351
  tools = self.tool_router.get_tool_specs_for_llm() or []
352
  except Exception:
353
  tools = []
354
+ # Sum per-call cost from llm_call events so analyzers don't have to
355
+ # walk the events array themselves. Each `llm_call` event already
356
+ # carries cost_usd from `agent.core.telemetry.record_llm_call`.
357
+ total_cost_usd = sum(
358
+ float((e.get("data") or {}).get("cost_usd") or 0.0)
359
+ for e in self.logged_events
360
+ if e.get("event_type") == "llm_call"
361
+ )
362
  return {
363
  "session_id": self.session_id,
364
+ "user_id": self.user_id,
365
  "session_start_time": self.session_start_time,
366
  "session_end_time": datetime.now().isoformat(),
367
  "model_name": self.config.model_name,
368
+ "total_cost_usd": total_cost_usd,
369
  "messages": [msg.model_dump() for msg in self.context_manager.items],
370
  "events": self.logged_events,
371
  "tools": tools,
agent/core/session_persistence.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Optional durable session persistence for the hosted backend.
2
+
3
+ The public CLI must keep working without MongoDB. This module therefore
4
+ exposes one small async store interface and returns a no-op implementation
5
+ unless ``MONGODB_URI`` is configured and reachable.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import os
12
+ from datetime import UTC, datetime
13
+ from typing import Any
14
+
15
+ from bson import BSON
16
+ from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne
17
+ from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ SCHEMA_VERSION = 1
22
+ MAX_BSON_BYTES = 15 * 1024 * 1024
23
+
24
+
25
+ def _now() -> datetime:
26
+ return datetime.now(UTC)
27
+
28
+
29
+ def _doc_id(session_id: str, idx: int) -> str:
30
+ return f"{session_id}:{idx}"
31
+
32
+
33
def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]:
    """Return *message* if it can be stored, else a small marker document.

    The message is trial-encoded as ``{"message": message}``. If the encoded
    size is at most ``MAX_BSON_BYTES`` (kept under Mongo's 16 MB hard limit)
    the original dict is returned untouched. If it is larger, or encoding
    raises ``InvalidDocument``/``OverflowError``, a marker replaces it so one
    huge tool log does not fail the whole snapshot.
    """
    try:
        encoded_size = len(BSON.encode({"message": message}))
    except (InvalidDocument, OverflowError):
        encoded_size = None

    if encoded_size is not None and encoded_size <= MAX_BSON_BYTES:
        return message

    marker_text = (
        "[SYSTEM: A single persisted message exceeded MongoDB's document "
        "size/encoding limit and was replaced by this marker.]"
    )
    return {
        "role": "tool",
        "content": marker_text,
        "ml_intern_persistence_error": "message_too_large_or_invalid",
    }
52
+
53
+
54
class NoopSessionStore:
    """Do-nothing async store used when Mongo is not configured.

    Every method accepts arbitrary arguments and answers with an empty
    result (``None`` or ``[]``) without performing any work.
    """

    enabled = False

    async def init(self) -> None:
        """Nothing to connect to."""

    async def close(self) -> None:
        """Nothing to release."""

    async def upsert_session(self, **_: Any) -> None:
        """Discard the session metadata."""

    async def save_snapshot(self, **_: Any) -> None:
        """Discard the snapshot."""

    async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None:
        """Report that no session is stored."""
        return None

    async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]:
        """Report an empty session list."""
        return []

    async def soft_delete_session(self, *_: Any, **__: Any) -> None:
        """Ignore the delete request."""

    async def update_session_fields(self, *_: Any, **__: Any) -> None:
        """Ignore the field update."""

    async def append_event(self, *_: Any, **__: Any) -> int | None:
        """Drop the event; no sequence number is assigned."""
        return None

    async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]:
        """Report that no events are stored."""
        return []

    async def append_trace_message(self, *_: Any, **__: Any) -> int | None:
        """Drop the trace message; no sequence number is assigned."""
        return None

    async def get_quota(self, *_: Any, **__: Any) -> int | None:
        """Report that no quota value is available."""
        return None

    async def try_increment_quota(self, *_: Any, **__: Any) -> int | None:
        """Report that no quota value is available."""
        return None

    async def refund_quota(self, *_: Any, **__: Any) -> None:
        """Ignore the refund request."""
100
+
101
+
102
+ class MongoSessionStore(NoopSessionStore):
103
+ """MongoDB-backed session store."""
104
+
105
+ enabled = True
106
+
107
+ def __init__(self, uri: str, db_name: str) -> None:
108
+ self.uri = uri
109
+ self.db_name = db_name
110
+ self.enabled = False
111
+ self.client: AsyncMongoClient | None = None
112
+ self.db = None
113
+
114
async def init(self) -> None:
    """Connect to MongoDB and enable the store, or leave it disabled.

    Opens a client with a 3-second server-selection timeout, pings the
    server and creates the indexes. Any exception is logged as a warning;
    the client is then closed and the store stays disabled, so a failed
    connection attempt does not raise to the caller.
    """
    try:
        client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000)
        self.client = client
        self.db = client[self.db_name]
        # Ping first so an unreachable server fails fast.
        await client.admin.command("ping")
        await self._create_indexes()
        self.enabled = True
        logger.info("Mongo session persistence enabled (db=%s)", self.db_name)
    except Exception as e:
        logger.warning("Mongo session persistence disabled: %s", e)
        self.enabled = False
        if self.client is not None:
            await self.client.close()
        self.client = None
        self.db = None
129
+
130
async def close(self) -> None:
    """Close the Mongo client, if any, and drop the client/db handles."""
    client = self.client
    if client is not None:
        await client.close()
    self.client = self.db = None
135
+
136
+ async def _create_indexes(self) -> None:
137
+ if self.db is None:
138
+ return
139
+ await self.db.sessions.create_index(
140
+ [("user_id", 1), ("visibility", 1), ("updated_at", -1)]
141
+ )
142
+ await self.db.sessions.create_index(
143
+ [("visibility", 1), ("status", 1), ("last_active_at", -1)]
144
+ )
145
+ await self.db.session_messages.create_index(
146
+ [("session_id", 1), ("idx", 1)], unique=True
147
+ )
148
+ await self.db.session_events.create_index(
149
+ [("session_id", 1), ("seq", 1)], unique=True
150
+ )
151
+ await self.db.session_trace_messages.create_index(
152
+ [("session_id", 1), ("seq", 1)], unique=True
153
+ )
154
+ await self.db.session_trace_messages.create_index([("created_at", -1)])
155
+
156
+ def _ready(self) -> bool:
157
+ return bool(self.enabled and self.db is not None)
158
+
159
+ async def upsert_session(
160
+ self,
161
+ *,
162
+ session_id: str,
163
+ user_id: str,
164
+ model: str,
165
+ title: str | None = None,
166
+ surface: str = "frontend",
167
+ created_at: datetime | None = None,
168
+ runtime_state: str = "idle",
169
+ status: str = "active",
170
+ message_count: int = 0,
171
+ turn_count: int = 0,
172
+ pending_approval: list[dict[str, Any]] | None = None,
173
+ claude_counted: bool = False,
174
+ notification_destinations: list[str] | None = None,
175
+ ) -> None:
176
+ if not self._ready():
177
+ return
178
+ now = _now()
179
+ await self.db.sessions.update_one(
180
+ {"_id": session_id},
181
+ {
182
+ "$setOnInsert": {
183
+ "_id": session_id,
184
+ "session_id": session_id,
185
+ "user_id": user_id,
186
+ "surface": surface,
187
+ "created_at": created_at or now,
188
+ "schema_version": SCHEMA_VERSION,
189
+ "visibility": "live",
190
+ },
191
+ "$set": {
192
+ "title": title,
193
+ "model": model,
194
+ "status": status,
195
+ "runtime_state": runtime_state,
196
+ "updated_at": now,
197
+ "last_active_at": now,
198
+ "message_count": message_count,
199
+ "turn_count": turn_count,
200
+ "pending_approval": pending_approval or [],
201
+ "claude_counted": claude_counted,
202
+ "notification_destinations": notification_destinations or [],
203
+ },
204
+ },
205
+ upsert=True,
206
+ )
207
+
208
+ async def save_snapshot(
209
+ self,
210
+ *,
211
+ session_id: str,
212
+ user_id: str,
213
+ model: str,
214
+ messages: list[dict[str, Any]],
215
+ title: str | None = None,
216
+ runtime_state: str = "idle",
217
+ status: str = "active",
218
+ turn_count: int = 0,
219
+ pending_approval: list[dict[str, Any]] | None = None,
220
+ claude_counted: bool = False,
221
+ created_at: datetime | None = None,
222
+ notification_destinations: list[str] | None = None,
223
+ ) -> None:
224
+ if not self._ready():
225
+ return
226
+ now = _now()
227
+ await self.upsert_session(
228
+ session_id=session_id,
229
+ user_id=user_id,
230
+ model=model,
231
+ title=title,
232
+ created_at=created_at,
233
+ runtime_state=runtime_state,
234
+ status=status,
235
+ message_count=len(messages),
236
+ turn_count=turn_count,
237
+ pending_approval=pending_approval,
238
+ claude_counted=claude_counted,
239
+ notification_destinations=notification_destinations,
240
+ )
241
+ ops: list[Any] = []
242
+ for idx, raw in enumerate(messages):
243
+ ops.append(
244
+ UpdateOne(
245
+ {"_id": _doc_id(session_id, idx)},
246
+ {
247
+ "$set": {
248
+ "session_id": session_id,
249
+ "idx": idx,
250
+ "message": _safe_message_doc(raw),
251
+ "updated_at": now,
252
+ },
253
+ "$setOnInsert": {"created_at": now},
254
+ },
255
+ upsert=True,
256
+ )
257
+ )
258
+ ops.append(DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}))
259
+ try:
260
+ if ops:
261
+ await self.db.session_messages.bulk_write(ops, ordered=False)
262
+ except PyMongoError as e:
263
+ logger.warning("Failed to persist session %s snapshot: %s", session_id, e)
264
+
265
async def load_session(
    self, session_id: str, *, include_deleted: bool = False
) -> dict[str, Any] | None:
    """Load one session's metadata and its messages ordered by ``idx``.

    Returns ``None`` when the store is not ready, the session is unknown, or
    it is marked deleted and *include_deleted* is false. Otherwise returns
    ``{"metadata": <session doc>, "messages": [...]}``.
    """
    if not self._ready():
        return None
    meta = await self.db.sessions.find_one({"_id": session_id})
    if not meta:
        return None
    hidden = meta.get("visibility") == "deleted" and not include_deleted
    if hidden:
        return None
    rows = self.db.session_messages.find({"session_id": session_id}).sort("idx", 1)
    messages = []
    async for row in rows:
        messages.append(row.get("message"))
    return {"metadata": meta, "messages": messages}
278
+
279
async def list_sessions(
    self, user_id: str, *, include_deleted: bool = False
) -> list[dict[str, Any]]:
    """List session metadata documents, most recently updated first.

    Returns ``[]`` when the store is not ready. Results are restricted to
    *user_id*, except for the literal user ``"dev"``, which gets no user
    filter at all. Sessions marked deleted are excluded unless
    *include_deleted* is true.
    """
    if not self._ready():
        return []
    query: dict[str, Any] = {} if user_id == "dev" else {"user_id": user_id}
    if not include_deleted:
        query["visibility"] = {"$ne": "deleted"}
    rows = self.db.sessions.find(query).sort("updated_at", -1)
    sessions = []
    async for row in rows:
        sessions.append(row)
    return sessions
291
+
292
async def soft_delete_session(self, session_id: str) -> None:
    """Mark a session as deleted without removing any stored data.

    Sets ``visibility`` to ``"deleted"``, resets ``runtime_state`` to
    ``"idle"`` and refreshes ``updated_at``. Does nothing when the store is
    not ready.
    """
    if not self._ready():
        return

    tombstone = {
        "visibility": "deleted",
        "runtime_state": "idle",
        "updated_at": _now(),
    }
    await self.db.sessions.update_one({"_id": session_id}, {"$set": tombstone})
305
+
306
async def update_session_fields(self, session_id: str, **fields: Any) -> None:
    """Apply a ``$set`` of arbitrary fields to one session document.

    ``updated_at`` is always overwritten with the current timestamp, even
    if the caller supplied it. The call is a no-op when the store is not
    ready or when no fields were given.
    """
    if not (self._ready() and fields):
        return

    changes = {**fields, "updated_at": _now()}
    await self.db.sessions.update_one({"_id": session_id}, {"$set": changes})
311
+
312
async def _next_seq(self, counter_id: str) -> int:
    """Atomically bump and return the sequence number for ``counter_id``.

    The counter document is created on first use (upsert), and the value
    returned is the one *after* the increment.
    """
    selector = {"_id": counter_id}
    bump = {"$inc": {"seq": 1}}
    counter = await self.db.counters.find_one_and_update(
        selector,
        bump,
        upsert=True,
        return_document=ReturnDocument.AFTER,
    )
    return int(counter["seq"])
320
+
321
+ async def append_event(
322
+ self, session_id: str, event_type: str, data: dict[str, Any] | None
323
+ ) -> int | None:
324
+ if not self._ready():
325
+ return None
326
+ try:
327
+ seq = await self._next_seq(f"event:{session_id}")
328
+ await self.db.session_events.insert_one(
329
+ {
330
+ "_id": _doc_id(session_id, seq),
331
+ "session_id": session_id,
332
+ "seq": seq,
333
+ "event_type": event_type,
334
+ "data": data or {},
335
+ "created_at": _now(),
336
+ }
337
+ )
338
+ return seq
339
+ except PyMongoError as e:
340
+ logger.debug("Failed to append event for %s: %s", session_id, e)
341
+ return None
342
+
343
async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[dict[str, Any]]:
    """Return the session's events with ``seq`` strictly greater than ``after_seq``.

    A falsy ``after_seq`` (``None``, ``0``) is treated as 0; any other value
    is coerced with ``int()``. Results are ordered by ascending ``seq``. An
    empty list is returned when the store is not ready.
    """
    if not self._ready():
        return []

    floor = int(after_seq or 0)
    query = {"session_id": session_id, "seq": {"$gt": floor}}
    events: list[dict[str, Any]] = []
    async for doc in self.db.session_events.find(query).sort("seq", 1):
        events.append(doc)
    return events
350
+
351
+ async def append_trace_message(
352
+ self, session_id: str, message: dict[str, Any], source: str = "message"
353
+ ) -> int | None:
354
+ if not self._ready():
355
+ return None
356
+ try:
357
+ seq = await self._next_seq(f"trace:{session_id}")
358
+ await self.db.session_trace_messages.insert_one(
359
+ {
360
+ "_id": _doc_id(session_id, seq),
361
+ "session_id": session_id,
362
+ "seq": seq,
363
+ "role": message.get("role"),
364
+ "message": _safe_message_doc(message),
365
+ "source": source,
366
+ "created_at": _now(),
367
+ }
368
+ )
369
+ return seq
370
+ except PyMongoError as e:
371
+ logger.debug("Failed to append trace message for %s: %s", session_id, e)
372
+ return None
373
+
374
+ async def get_quota(self, user_id: str, day: str) -> int | None:
375
+ if not self._ready():
376
+ return None
377
+ doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"})
378
+ return int(doc.get("count", 0)) if doc else 0
379
+
380
+ async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None:
381
+ if not self._ready():
382
+ return None
383
+ key = f"{user_id}:{day}"
384
+ now = _now()
385
+ try:
386
+ await self.db.claude_quotas.insert_one(
387
+ {
388
+ "_id": key,
389
+ "user_id": user_id,
390
+ "day": day,
391
+ "count": 1,
392
+ "updated_at": now,
393
+ }
394
+ )
395
+ return 1
396
+ except DuplicateKeyError:
397
+ pass
398
+ doc = await self.db.claude_quotas.find_one_and_update(
399
+ {"_id": key, "count": {"$lt": cap}},
400
+ {"$inc": {"count": 1}, "$set": {"updated_at": now}},
401
+ return_document=ReturnDocument.AFTER,
402
+ )
403
+ return int(doc["count"]) if doc else None
404
+
405
async def refund_quota(self, user_id: str, day: str) -> None:
    """Give back one unit of quota for ``user_id`` on ``day``.

    The ``count > 0`` filter keeps the counter from going negative; if no
    positive counter exists the update matches nothing. Does nothing when
    the store is not ready.
    """
    if not self._ready():
        return

    selector = {"_id": f"{user_id}:{day}", "count": {"$gt": 0}}
    change = {"$inc": {"count": -1}, "$set": {"updated_at": _now()}}
    await self.db.claude_quotas.update_one(selector, change)
412
+
413
+
414
+ _store: NoopSessionStore | MongoSessionStore | None = None
415
+
416
+
417
def get_session_store() -> NoopSessionStore | MongoSessionStore:
    """Return the process-wide session store, creating it on first use.

    A ``MongoSessionStore`` is built when the ``MONGODB_URI`` environment
    variable is set to a non-empty value (database name from ``MONGODB_DB``,
    default ``"ml-intern"``); otherwise a ``NoopSessionStore`` is used. The
    instance is cached in the module-level ``_store``.
    """
    global _store
    if _store is not None:
        return _store

    uri = os.environ.get("MONGODB_URI")
    db_name = os.environ.get("MONGODB_DB", "ml-intern")
    if uri:
        _store = MongoSessionStore(uri, db_name)
    else:
        _store = NoopSessionStore()
    return _store
424
+
425
+
426
def _reset_store_for_tests(
    store: NoopSessionStore | MongoSessionStore | None = None,
) -> None:
    """Replace the cached module-level store.

    Passing ``None`` (the default) clears the cache so the next
    ``get_session_store()`` call builds a fresh instance.
    """
    global _store
    _store = store
agent/core/session_uploader.py CHANGED
@@ -90,9 +90,11 @@ def upload_session_as_file(
90
  # across sessions with different tool rosters.
91
  session_row = {
92
  "session_id": data["session_id"],
 
93
  "session_start_time": data["session_start_time"],
94
  "session_end_time": data["session_end_time"],
95
  "model_name": data["model_name"],
 
96
  "messages": json.dumps(scrubbed_messages),
97
  "events": json.dumps(scrubbed_events),
98
  "tools": json.dumps(scrubbed_tools),
 
90
  # across sessions with different tool rosters.
91
  session_row = {
92
  "session_id": data["session_id"],
93
+ "user_id": data.get("user_id"),
94
  "session_start_time": data["session_start_time"],
95
  "session_end_time": data["session_end_time"],
96
  "model_name": data["model_name"],
97
+ "total_cost_usd": data.get("total_cost_usd"),
98
  "messages": json.dumps(scrubbed_messages),
99
  "events": json.dumps(scrubbed_events),
100
  "tools": json.dumps(scrubbed_tools),
agent/core/tools.py CHANGED
@@ -46,10 +46,12 @@ from agent.tools.hf_repo_git_tool import (
46
  hf_repo_git_handler,
47
  )
48
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
 
49
  from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
50
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
51
  from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
52
  from agent.tools.sandbox_tool import get_sandbox_tools
 
53
 
54
  # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
55
  # from agent.tools.private_hf_repo_tools import (
@@ -310,6 +312,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
310
  parameters=HF_PAPERS_TOOL_SPEC["parameters"],
311
  handler=hf_papers_handler,
312
  ),
 
 
 
 
 
 
313
  # Dataset inspection tool (unified)
314
  ToolSpec(
315
  name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
@@ -324,6 +332,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
324
  parameters=PLAN_TOOL_SPEC["parameters"],
325
  handler=plan_tool_handler,
326
  ),
 
 
 
 
 
 
327
  ToolSpec(
328
  name=HF_JOBS_TOOL_SPEC["name"],
329
  description=HF_JOBS_TOOL_SPEC["description"],
 
46
  hf_repo_git_handler,
47
  )
48
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
49
+ from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler
50
  from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
51
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
52
  from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
53
  from agent.tools.sandbox_tool import get_sandbox_tools
54
+ from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler
55
 
56
  # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
57
  # from agent.tools.private_hf_repo_tools import (
 
312
  parameters=HF_PAPERS_TOOL_SPEC["parameters"],
313
  handler=hf_papers_handler,
314
  ),
315
+ ToolSpec(
316
+ name=WEB_SEARCH_TOOL_SPEC["name"],
317
+ description=WEB_SEARCH_TOOL_SPEC["description"],
318
+ parameters=WEB_SEARCH_TOOL_SPEC["parameters"],
319
+ handler=web_search_handler,
320
+ ),
321
  # Dataset inspection tool (unified)
322
  ToolSpec(
323
  name=HF_INSPECT_DATASET_TOOL_SPEC["name"],
 
332
  parameters=PLAN_TOOL_SPEC["parameters"],
333
  handler=plan_tool_handler,
334
  ),
335
+ ToolSpec(
336
+ name=NOTIFY_TOOL_SPEC["name"],
337
+ description=NOTIFY_TOOL_SPEC["description"],
338
+ parameters=NOTIFY_TOOL_SPEC["parameters"],
339
+ handler=notify_handler,
340
+ ),
341
  ToolSpec(
342
  name=HF_JOBS_TOOL_SPEC["name"],
343
  description=HF_JOBS_TOOL_SPEC["description"],
agent/main.py CHANGED
@@ -23,8 +23,10 @@ from prompt_toolkit import PromptSession
23
  from agent.config import load_config
24
  from agent.core.agent_loop import submission_loop
25
  from agent.core import model_switcher
 
26
  from agent.core.session import OpType
27
  from agent.core.tools import ToolRouter
 
28
  from agent.utils.reliability_checks import check_training_script_save_pattern
29
  from agent.utils.terminal_display import (
30
  get_console,
@@ -69,26 +71,15 @@ def _safe_get_args(arguments: dict) -> dict:
69
  return args if isinstance(args, dict) else {}
70
 
71
 
72
- def _get_hf_token() -> str | None:
73
- """Get HF token from environment, huggingface_hub API, or cached token file."""
74
- token = os.environ.get("HF_TOKEN")
75
- if token:
76
- return token
77
  try:
78
  from huggingface_hub import HfApi
79
- api = HfApi()
80
- token = api.token
81
- if token:
82
- return token
83
  except Exception:
84
- pass
85
- # Fallback: read the cached token file directly
86
- token_path = Path.home() / ".cache" / "huggingface" / "token"
87
- if token_path.exists():
88
- token = token_path.read_text().strip()
89
- if token:
90
- return token
91
- return None
92
 
93
 
94
  async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
@@ -342,6 +333,9 @@ async def event_listener(
342
  stream_buf.discard()
343
  print_turn_complete()
344
  print_plan()
 
 
 
345
  turn_complete_event.set()
346
  elif event.event_type == "interrupted":
347
  shimmer.stop()
@@ -758,7 +752,7 @@ async def _handle_slash_command(
758
  normalized = arg.removeprefix("huggingface/")
759
  session = session_holder[0] if session_holder else None
760
  await model_switcher.probe_and_switch_model(
761
- normalized, config, session, console, _get_hf_token(),
762
  )
763
  return None
764
 
@@ -817,7 +811,7 @@ async def _handle_slash_command(
817
  return None
818
 
819
 
820
- async def main():
821
  """Interactive chat with the agent"""
822
 
823
  # Clear screen
@@ -827,19 +821,16 @@ async def main():
827
  prompt_session = PromptSession()
828
 
829
  # HF token — required, prompt if missing
830
- hf_token = _get_hf_token()
831
  if not hf_token:
832
  hf_token = await _prompt_and_save_hf_token(prompt_session)
833
 
834
- config = load_config(CLI_CONFIG_PATH)
 
 
835
 
836
  # Resolve username for banner
837
- hf_user = None
838
- try:
839
- from huggingface_hub import HfApi
840
- hf_user = HfApi(token=hf_token).whoami().get("name")
841
- except Exception:
842
- pass
843
 
844
  print_banner(model=config.model_name, hf_user=hf_user)
845
 
@@ -857,6 +848,8 @@ async def main():
857
  turn_complete_event.set()
858
  ready_event = asyncio.Event()
859
 
 
 
860
  # Create tool router with local mode
861
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
862
 
@@ -871,8 +864,12 @@ async def main():
871
  tool_router=tool_router,
872
  session_holder=session_holder,
873
  hf_token=hf_token,
 
874
  local_mode=True,
875
  stream=True,
 
 
 
876
  )
877
  )
878
 
@@ -1028,6 +1025,8 @@ async def main():
1028
  agent_task.cancel()
1029
  # Agent didn't shut down cleanly — close MCP explicitly
1030
  await tool_router.__aexit__(None, None, None)
 
 
1031
 
1032
  # Now safe to cancel the listener (agent is done emitting events)
1033
  listener_task.cancel()
@@ -1047,15 +1046,18 @@ async def headless_main(
1047
  logging.basicConfig(level=logging.WARNING)
1048
  _configure_runtime_logging()
1049
 
1050
- hf_token = _get_hf_token()
1051
  if not hf_token:
1052
  print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
1053
  sys.exit(1)
1054
 
1055
  print(f"HF token loaded", file=sys.stderr)
1056
 
1057
- config = load_config(CLI_CONFIG_PATH)
1058
  config.yolo_mode = True # Auto-approve everything in headless mode
 
 
 
1059
 
1060
  if model:
1061
  config.model_name = model
@@ -1082,8 +1084,12 @@ async def headless_main(
1082
  tool_router=tool_router,
1083
  session_holder=session_holder,
1084
  hf_token=hf_token,
 
1085
  local_mode=True,
1086
  stream=stream,
 
 
 
1087
  )
1088
  )
1089
 
@@ -1209,6 +1215,10 @@ async def headless_main(
1209
  stream_buf.discard()
1210
  history_size = event.data.get("history_size", "?") if event.data else "?"
1211
  print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
 
 
 
 
1212
  break
1213
 
1214
  # Shutdown
@@ -1222,6 +1232,8 @@ async def headless_main(
1222
  except asyncio.TimeoutError:
1223
  agent_task.cancel()
1224
  await tool_router.__aexit__(None, None, None)
 
 
1225
 
1226
 
1227
  def cli():
@@ -1252,7 +1264,7 @@ def cli():
1252
  max_iter = 10_000 # effectively unlimited
1253
  asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
1254
  else:
1255
- asyncio.run(main())
1256
  except KeyboardInterrupt:
1257
  print("\n\nGoodbye!")
1258
 
 
23
  from agent.config import load_config
24
  from agent.core.agent_loop import submission_loop
25
  from agent.core import model_switcher
26
+ from agent.core.hf_tokens import resolve_hf_token
27
  from agent.core.session import OpType
28
  from agent.core.tools import ToolRouter
29
+ from agent.messaging.gateway import NotificationGateway
30
  from agent.utils.reliability_checks import check_training_script_save_pattern
31
  from agent.utils.terminal_display import (
32
  get_console,
 
71
  return args if isinstance(args, dict) else {}
72
 
73
 
74
+ def _get_hf_user(token: str | None) -> str | None:
75
+ """Resolve the HF username for a token, if available."""
76
+ if not token:
77
+ return None
 
78
  try:
79
  from huggingface_hub import HfApi
80
+ return HfApi(token=token).whoami().get("name")
 
 
 
81
  except Exception:
82
+ return None
 
 
 
 
 
 
 
83
 
84
 
85
  async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
 
333
  stream_buf.discard()
334
  print_turn_complete()
335
  print_plan()
336
+ session = session_holder[0] if session_holder else None
337
+ if session is not None:
338
+ await session.send_deferred_turn_complete_notification(event)
339
  turn_complete_event.set()
340
  elif event.event_type == "interrupted":
341
  shimmer.stop()
 
752
  normalized = arg.removeprefix("huggingface/")
753
  session = session_holder[0] if session_holder else None
754
  await model_switcher.probe_and_switch_model(
755
+ normalized, config, session, console, resolve_hf_token(),
756
  )
757
  return None
758
 
 
811
  return None
812
 
813
 
814
+ async def main(model: str | None = None):
815
  """Interactive chat with the agent"""
816
 
817
  # Clear screen
 
821
  prompt_session = PromptSession()
822
 
823
  # HF token — required, prompt if missing
824
+ hf_token = resolve_hf_token()
825
  if not hf_token:
826
  hf_token = await _prompt_and_save_hf_token(prompt_session)
827
 
828
+ config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
829
+ if model:
830
+ config.model_name = model
831
 
832
  # Resolve username for banner
833
+ hf_user = _get_hf_user(hf_token)
 
 
 
 
 
834
 
835
  print_banner(model=config.model_name, hf_user=hf_user)
836
 
 
848
  turn_complete_event.set()
849
  ready_event = asyncio.Event()
850
 
851
+ notification_gateway = NotificationGateway(config.messaging)
852
+ await notification_gateway.start()
853
  # Create tool router with local mode
854
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
855
 
 
864
  tool_router=tool_router,
865
  session_holder=session_holder,
866
  hf_token=hf_token,
867
+ user_id=hf_user,
868
  local_mode=True,
869
  stream=True,
870
+ notification_gateway=notification_gateway,
871
+ notification_destinations=config.messaging.default_auto_destinations(),
872
+ defer_turn_complete_notification=True,
873
  )
874
  )
875
 
 
1025
  agent_task.cancel()
1026
  # Agent didn't shut down cleanly — close MCP explicitly
1027
  await tool_router.__aexit__(None, None, None)
1028
+ finally:
1029
+ await notification_gateway.close()
1030
 
1031
  # Now safe to cancel the listener (agent is done emitting events)
1032
  listener_task.cancel()
 
1046
  logging.basicConfig(level=logging.WARNING)
1047
  _configure_runtime_logging()
1048
 
1049
+ hf_token = resolve_hf_token()
1050
  if not hf_token:
1051
  print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
1052
  sys.exit(1)
1053
 
1054
  print(f"HF token loaded", file=sys.stderr)
1055
 
1056
+ config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1057
  config.yolo_mode = True # Auto-approve everything in headless mode
1058
+ notification_gateway = NotificationGateway(config.messaging)
1059
+ await notification_gateway.start()
1060
+ hf_user = _get_hf_user(hf_token)
1061
 
1062
  if model:
1063
  config.model_name = model
 
1084
  tool_router=tool_router,
1085
  session_holder=session_holder,
1086
  hf_token=hf_token,
1087
+ user_id=hf_user,
1088
  local_mode=True,
1089
  stream=stream,
1090
+ notification_gateway=notification_gateway,
1091
+ notification_destinations=config.messaging.default_auto_destinations(),
1092
+ defer_turn_complete_notification=True,
1093
  )
1094
  )
1095
 
 
1215
  stream_buf.discard()
1216
  history_size = event.data.get("history_size", "?") if event.data else "?"
1217
  print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
1218
+ if event.event_type == "turn_complete":
1219
+ session = session_holder[0] if session_holder else None
1220
+ if session is not None:
1221
+ await session.send_deferred_turn_complete_notification(event)
1222
  break
1223
 
1224
  # Shutdown
 
1232
  except asyncio.TimeoutError:
1233
  agent_task.cancel()
1234
  await tool_router.__aexit__(None, None, None)
1235
+ finally:
1236
+ await notification_gateway.close()
1237
 
1238
 
1239
  def cli():
 
1264
  max_iter = 10_000 # effectively unlimited
1265
  asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
1266
  else:
1267
+ asyncio.run(main(model=args.model))
1268
  except KeyboardInterrupt:
1269
  print("\n\nGoodbye!")
1270
 
agent/messaging/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agent.messaging.gateway import NotificationGateway
2
+ from agent.messaging.models import (
3
+ MessagingConfig,
4
+ NotificationRequest,
5
+ NotificationResult,
6
+ SUPPORTED_AUTO_EVENT_TYPES,
7
+ )
8
+
9
+ __all__ = [
10
+ "MessagingConfig",
11
+ "NotificationGateway",
12
+ "NotificationRequest",
13
+ "NotificationResult",
14
+ "SUPPORTED_AUTO_EVENT_TYPES",
15
+ ]
agent/messaging/base.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import httpx
4
+
5
+ from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult
6
+
7
+
8
class NotificationError(Exception):
    """A notification could not be delivered and retrying will not help."""
10
+
11
+
12
class RetryableNotificationError(NotificationError):
    """A notification failed for a transient reason; a retry may succeed."""
14
+
15
+
16
class NotificationProvider(ABC):
    """Abstract interface every notification backend implements."""

    # Identifier of the backend; concrete providers assign a value.
    provider_name: str

    @abstractmethod
    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: DestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Send ``request`` to a single destination.

        Args:
            client: HTTP client supplied by the caller for the request.
            destination_name: Name under which the destination is configured.
            destination: Configuration of that destination.
            request: The notification to deliver.

        Returns:
            A ``NotificationResult`` describing the outcome.
        """
agent/messaging/gateway.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import Iterable
4
+
5
+ import httpx
6
+
7
+ from agent.messaging.base import (
8
+ NotificationError,
9
+ NotificationProvider,
10
+ RetryableNotificationError,
11
+ )
12
+ from agent.messaging.models import (
13
+ MessagingConfig,
14
+ NotificationRequest,
15
+ NotificationResult,
16
+ )
17
+ from agent.messaging.slack import SlackProvider
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RETRY_DELAYS = (1, 2, 4)
22
+
23
+
24
class NotificationGateway:
    """Dispatches notification requests to the configured providers.

    Requests can be sent inline with ``send``/``send_many`` or queued with
    ``enqueue`` for a background worker task started by ``start``. When the
    messaging config is disabled, ``start``/``flush``/``close`` are no-ops and
    ``send`` returns a failed result.
    """

    def __init__(self, config: MessagingConfig):
        self.config = config
        # Provider implementations keyed by the destination's ``provider`` value.
        self._providers: dict[str, NotificationProvider] = {"slack": SlackProvider()}
        self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue()
        self._worker_task: asyncio.Task | None = None
        self._client: httpx.AsyncClient | None = None

    @property
    def enabled(self) -> bool:
        """Whether messaging is switched on in the config."""
        return self.config.enabled

    async def start(self) -> None:
        """Create the shared HTTP client and the worker task (idempotent)."""
        already_running = self._worker_task is not None
        if already_running or not self.enabled:
            return
        self._client = httpx.AsyncClient(timeout=10.0)
        self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway")

    async def flush(self) -> None:
        """Wait until every queued request has been processed."""
        if self.enabled:
            await self._queue.join()

    async def close(self) -> None:
        """Drain the queue, stop the worker and close the shared HTTP client."""
        if not self.enabled:
            return
        await self.flush()

        worker = self._worker_task
        if worker is not None:
            worker.cancel()
            try:
                await worker
            except asyncio.CancelledError:
                pass
            self._worker_task = None

        client = self._client
        if client is not None:
            await client.aclose()
            self._client = None

    @staticmethod
    def _failed(destination: str, provider: str, error: str) -> NotificationResult:
        """Build a failed ``NotificationResult``."""
        return NotificationResult(
            destination=destination,
            ok=False,
            provider=provider,
            error=error,
        )

    async def send(self, request: NotificationRequest) -> NotificationResult:
        """Deliver ``request`` now and report the outcome.

        Configuration problems (messaging disabled, unknown destination, no
        provider implementation) are returned as failed results.
        """
        name = request.destination
        if not self.enabled:
            return self._failed(name, "disabled", "Messaging is disabled")

        destination = self.config.get_destination(name)
        if destination is None:
            return self._failed(name, "unknown", f"Unknown destination '{name}'")

        provider = self._providers.get(destination.provider)
        if provider is None:
            return self._failed(
                name,
                destination.provider,
                f"No provider implementation for '{destination.provider}'",
            )
        return await self._send_with_retries(provider, name, destination, request)

    async def send_many(
        self, requests: Iterable[NotificationRequest]
    ) -> list[NotificationResult]:
        """Send requests one after another, returning results in input order."""
        return [await self.send(request) for request in requests]

    async def enqueue(self, request: NotificationRequest) -> bool:
        """Queue ``request`` for the worker; ``False`` if no worker is running."""
        if self._worker_task is None or not self.enabled:
            return False
        await self._queue.put(request)
        return True

    async def _worker(self) -> None:
        """Consume the queue forever, logging failed deliveries."""
        while True:
            request = await self._queue.get()
            try:
                outcome = await self.send(request)
                if not outcome.ok:
                    logger.warning(
                        "Notification delivery failed for %s: %s",
                        request.destination,
                        outcome.error,
                    )
            except Exception:
                # One bad request must not stop the worker loop.
                logger.exception("Unexpected notification worker failure")
            finally:
                # Always acknowledge so ``flush`` cannot hang on a failed item.
                self._queue.task_done()

    async def _send_with_retries(
        self,
        provider: NotificationProvider,
        destination_name: str,
        destination,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Call the provider, retrying transient failures.

        ``RetryableNotificationError`` is retried after each delay in
        ``_RETRY_DELAYS``; ``NotificationError`` fails immediately. If the
        gateway has no shared client, a temporary one is created and closed
        before returning.
        """
        owns_client = self._client is None
        client = httpx.AsyncClient(timeout=10.0) if owns_client else self._client
        last_attempt = len(_RETRY_DELAYS)
        try:
            for attempt in range(last_attempt + 1):
                try:
                    return await provider.send(client, destination_name, destination, request)
                except RetryableNotificationError as exc:
                    if attempt >= last_attempt:
                        return self._failed(
                            destination_name, provider.provider_name, str(exc)
                        )
                    delay = _RETRY_DELAYS[attempt]
                    logger.warning(
                        "Retrying notification to %s in %ss after transient error: %s",
                        destination_name,
                        delay,
                        exc,
                    )
                    await asyncio.sleep(delay)
                except NotificationError as exc:
                    return self._failed(destination_name, provider.provider_name, str(exc))
            return self._failed(
                destination_name,
                provider.provider_name,
                "Notification delivery exhausted retries",
            )
        finally:
            if owns_client:
                await client.aclose()
agent/messaging/models.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Literal
2
+
3
+ from pydantic import BaseModel, Field, field_validator, model_validator
4
+
5
+ _DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-")
6
+ SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"}
7
+
8
+
9
class SlackDestinationConfig(BaseModel):
    """Configuration of a single Slack destination."""

    provider: Literal["slack"] = "slack"
    token: str
    channel: str
    # Both permissions are opt-in; see MessagingConfig.can_agent_tool_send
    # and MessagingConfig.can_auto_send.
    allow_agent_tool: bool = False
    allow_auto_events: bool = False
    username: str | None = None
    icon_emoji: str | None = None

    @field_validator("token", "channel")
    @classmethod
    def _require_non_empty(cls, value: str) -> str:
        """Strip surrounding whitespace and reject blank values."""
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("must not be empty")
25
+
26
+
27
+ DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")]
28
+
29
+
30
class MessagingConfig(BaseModel):
    """Top-level messaging settings: on/off switch, auto events, destinations."""

    enabled: bool = False
    auto_event_types: list[str] = Field(
        default_factory=lambda: ["approval_required", "error", "turn_complete"]
    )
    destinations: dict[str, DestinationConfig] = Field(default_factory=dict)

    @field_validator("destinations")
    @classmethod
    def _validate_destination_names(
        cls, destinations: dict[str, DestinationConfig]
    ) -> dict[str, DestinationConfig]:
        """Reject empty names and names with characters outside the allowed set."""
        for name in destinations:
            if not name or not set(name) <= _DESTINATION_NAME_CHARS:
                raise ValueError(
                    "destination names must use lowercase letters, digits, '.', '_' or '-'"
                )
        return destinations

    @field_validator("auto_event_types")
    @classmethod
    def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]:
        """Reject unsupported event types and drop duplicates, keeping order."""
        for event_type in event_types:
            if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
                raise ValueError(
                    f"unsupported auto event type '{event_type}'"
                )
        # dict.fromkeys keeps the first occurrence of each entry, in order.
        return list(dict.fromkeys(event_types))

    @model_validator(mode="after")
    def _require_destinations_when_enabled(self) -> "MessagingConfig":
        """Enabling messaging without any destination is a configuration error."""
        if self.enabled and not self.destinations:
            raise ValueError("messaging.enabled requires at least one destination")
        return self

    def get_destination(self, name: str) -> DestinationConfig | None:
        """Return the destination configured under ``name``, or ``None``."""
        return self.destinations.get(name)

    def can_agent_tool_send(self, name: str) -> bool:
        """True when ``name`` exists and permits sends from the agent tool."""
        destination = self.get_destination(name)
        return destination is not None and bool(destination.allow_agent_tool)

    def can_auto_send(self, name: str) -> bool:
        """True when ``name`` exists and permits automatic event notifications."""
        destination = self.get_destination(name)
        return destination is not None and bool(destination.allow_auto_events)

    def default_auto_destinations(self) -> list[str]:
        """Names of destinations that accept auto events; empty when disabled."""
        if not self.enabled:
            return []
        return list(filter(self.can_auto_send, self.destinations))
91
+
92
+
93
class NotificationRequest(BaseModel):
    """One notification addressed to a named destination."""

    destination: str
    title: str | None = None
    message: str
    severity: Literal["info", "success", "warning", "error"] = "info"
    metadata: dict[str, str] = Field(default_factory=dict)
    event_type: str | None = None

    @field_validator("destination", "message")
    @classmethod
    def _require_text(cls, value: str) -> str:
        """Strip surrounding whitespace and reject blank values."""
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("must not be empty")

    @field_validator("title")
    @classmethod
    def _normalize_title(cls, value: str | None) -> str | None:
        """Strip the title; a missing or blank title becomes ``None``."""
        if value is None:
            return None
        stripped = value.strip()
        return stripped if stripped else None
116
+
117
+
118
class NotificationResult(BaseModel):
    """Outcome of one delivery attempt to one destination."""

    destination: str
    ok: bool
    provider: str
    # Failure description; left as None by default.
    error: str | None = None
    # Identifier assigned by the provider, when one is available.
    external_id: str | None = None
agent/messaging/slack.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+ import httpx
5
+
6
+ from agent.messaging.base import (
7
+ NotificationError,
8
+ NotificationProvider,
9
+ RetryableNotificationError,
10
+ )
11
+ from agent.messaging.models import (
12
+ NotificationRequest,
13
+ NotificationResult,
14
+ SlackDestinationConfig,
15
+ )
16
+
17
+ _SEVERITY_PREFIX = {
18
+ "info": "[INFO]",
19
+ "success": "[SUCCESS]",
20
+ "warning": "[WARNING]",
21
+ "error": "[ERROR]",
22
+ }
23
+
24
+
25
def _format_slack_mrkdwn(content: str) -> str:
    """Translate common Markdown into Slack's mrkdwn dialect.

    Fragments that must survive untouched (code, links, Slack entities,
    blockquote markers, already-converted spans) are swapped for opaque
    tokens while the rest of the text is escaped and converted, then
    restored at the end. Empty input is returned unchanged.
    """
    if not content:
        return content

    stash: dict[str, str] = {}

    def protect(fragment: str) -> str:
        # Tokens are numbered by creation order; NUL delimiters keep them
        # from matching any of the patterns applied below.
        token = f"\x00SLACK{len(stash)}\x00"
        stash[token] = fragment
        return token

    def protect_match(match: re.Match[str]) -> str:
        return protect(match.group(0))

    # 1. Code is shielded first so nothing inside backticks gets rewritten.
    text = re.sub(r"(```(?:[^\n]*\n)?[\s\S]*?```)", protect_match, content)
    text = re.sub(r"(`[^`\n]+`)", protect_match, text)

    # 2. Markdown links become Slack's <url|label> form.
    def to_slack_link(match: re.Match[str]) -> str:
        label = match.group(1)
        target = match.group(2).strip()
        if target.startswith("<") and target.endswith(">"):
            target = target[1:-1].strip()
        return protect(f"<{target}|{label}>")

    text = re.sub(
        r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)",
        to_slack_link,
        text,
    )

    # 3. Existing Slack entities / manual links and blockquote markers are
    #    shielded so the escaping step leaves their angle brackets alone.
    text = re.sub(
        r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)",
        protect_match,
        text,
    )
    text = re.sub(r"^(>+\s)", protect_match, text, flags=re.MULTILINE)

    # 4. Unescape first so already-escaped input is not double-escaped.
    for entity, char in (("&amp;", "&"), ("&lt;", "<"), ("&gt;", ">")):
        text = text.replace(entity, char)
    for char, entity in (("&", "&amp;"), ("<", "&lt;"), (">", "&gt;")):
        text = text.replace(char, entity)

    # 5. Headers turn into bold lines (inner ** markers are dropped).
    def to_bold_line(match: re.Match[str]) -> str:
        title = re.sub(r"\*\*(.+?)\*\*", r"\1", match.group(1).strip())
        return protect(f"*{title}*")

    text = re.sub(r"^#{1,6}\s+(.+)$", to_bold_line, text, flags=re.MULTILINE)

    # 6. Inline emphasis, applied in this order: bold-italic, bold, italic,
    #    strikethrough.
    inline_rules = (
        (r"\*\*\*(.+?)\*\*\*", "*_{}_*"),
        (r"\*\*(.+?)\*\*", "*{}*"),
        (r"(?<!\*)\*([^*\n]+)\*(?!\*)", "_{}_"),
        (r"~~(.+?)~~", "~{}~"),
    )
    for pattern, template in inline_rules:
        text = re.sub(
            pattern,
            lambda match, template=template: protect(template.format(match.group(1))),
            text,
        )

    # 7. Restore newest-first: a later fragment may embed an earlier token.
    for token in reversed(stash):
        text = text.replace(token, stash[token])

    return text
111
+
112
+
113
def _format_text(request: NotificationRequest) -> str:
    """Render a request as Slack message text.

    Layout: a heading line (severity prefix, plus the title when present),
    the message, then one ``key: value`` line per metadata entry. The joined
    text is passed through ``_format_slack_mrkdwn``.
    """
    prefix = _SEVERITY_PREFIX[request.severity]
    heading = f"{prefix} {request.title}" if request.title else prefix
    metadata_lines = (f"{key}: {value}" for key, value in request.metadata.items())
    return _format_slack_mrkdwn("\n".join([heading, request.message, *metadata_lines]))
124
+
125
+
126
class SlackProvider(NotificationProvider):
    """Delivers notifications through Slack's ``chat.postMessage`` endpoint."""

    provider_name = "slack"

    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: SlackDestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Post ``request`` to the destination's channel.

        Raises:
            RetryableNotificationError: on timeouts, transport errors,
                HTTP 429 / 5xx, a non-JSON body, or a ``ratelimited`` error.
            NotificationError: on other HTTP 4xx responses or any other
                error reported in the response body.
        """
        # Optional overrides are only sent when configured with a truthy value.
        overrides = {
            "username": destination.username,
            "icon_emoji": destination.icon_emoji,
        }
        payload = {
            "channel": destination.channel,
            "text": _format_text(request),
            "mrkdwn": True,
            "unfurl_links": False,
            "unfurl_media": False,
            **{key: value for key, value in overrides.items() if value},
        }
        headers = {
            "Authorization": f"Bearer {destination.token}",
            "Content-Type": "application/json; charset=utf-8",
        }

        try:
            response = await client.post(
                "https://slack.com/api/chat.postMessage",
                headers=headers,
                content=json.dumps(payload),
            )
        except httpx.TimeoutException as exc:
            raise RetryableNotificationError("Slack request timed out") from exc
        except httpx.TransportError as exc:
            raise RetryableNotificationError("Slack transport error") from exc

        status = response.status_code
        if status == 429 or status >= 500:
            raise RetryableNotificationError(
                f"Slack HTTP {status}"
            )
        if status >= 400:
            raise NotificationError(f"Slack HTTP {status}")

        try:
            data = response.json()
        except ValueError as exc:
            raise RetryableNotificationError("Slack returned invalid JSON") from exc

        # The body's "ok" flag decides success once the HTTP status has passed.
        if data.get("ok"):
            return NotificationResult(
                destination=destination_name,
                ok=True,
                provider=self.provider_name,
                external_id=str(data.get("ts") or ""),
                error=None,
            )

        error = str(data.get("error") or "unknown_error")
        if error == "ratelimited":
            raise RetryableNotificationError(error)
        raise NotificationError(error)
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -1,5 +1,5 @@
1
  system_prompt: |
2
- You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem.
3
 
4
  Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation.
5
 
@@ -28,7 +28,7 @@ system_prompt: |
28
 
29
  # Mistakes you WILL make without research
30
 
31
- HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first.
32
 
33
  WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.
34
 
@@ -60,6 +60,38 @@ system_prompt: |
60
  DPO: "prompt", "chosen", "rejected"
61
  GRPO: "prompt"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Data audit
64
 
65
  Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it.
@@ -75,7 +107,7 @@ system_prompt: |
75
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
76
  - push_to_hub=True and hub_model_id set
77
  - timeout: [value] (based on: [model size] on [hardware])
78
- - Trackio monitoring included and working
79
 
80
  If you cannot fill in all items, stop and complete the missing steps first.
81
 
@@ -156,6 +188,7 @@ system_prompt: |
156
  - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
157
  - For errors: state what went wrong, why, and what you're doing to fix it.
158
  - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
 
159
 
160
  # Tool usage
161
 
 
1
  system_prompt: |
2
+ You are ML Intern, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face (HF) ecosystem.
3
 
4
  Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation.
5
 
 
28
 
29
  # Mistakes you WILL make without research
30
 
31
+ HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio config field names. Fix: read a current example script first.
32
 
33
  WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs.
34
 
 
60
  DPO: "prompt", "chosen", "rejected"
61
  GRPO: "prompt"
62
 
63
+ # Trackio
64
+
65
+ Trackio is natively integrated with Transformers Trainer and all TRL trainers — the built-in TrackioCallback handles init/log/finish. In TrainingArguments/SFTConfig/DPOConfig/GRPOConfig set:
66
+ report_to="trackio"
67
+ run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128"
68
+ project="<descriptive-project-name>" # keeps related runs grouped so you can compare them
69
+ trackio_space_id="<username>/mlintern-<8-char-id>" # creates a public dashboard Space
70
+ `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars.
71
+
72
+ Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels:
73
+ ERROR — stop and change approach (divergence, NaN, OOM)
74
+ WARN — tweak hyperparameters (overfitting, early stopping, KL spike, reward collapse, slow convergence)
75
+ INFO — milestones (training complete, target reached, checkpoint saved)
76
+ Always include numeric values and an actionable suggestion in `text`, e.g. "loss=12.4 at step 200 — lr likely too high, try ×0.1". A future call must be able to parse it and act on it.
77
+
78
+ To add alerts under Trainer/SFTTrainer/GRPOTrainer, pass a custom TrainerCallback via `callbacks=[...]` that calls trackio.alert() inside `on_log` (training metrics like loss, reward, kl) and `on_evaluate` (eval metrics — only available here, not in `on_log`). Keep each `if` simple: one metric, one threshold. Conditions stay easy to adjust between runs.
79
+
80
+ Read alerts back between runs instead of parsing thousands of metric values. CLI — always use --json:
81
+ trackio get alerts --project <p> --run <r> --json
82
+ trackio get alerts --project <p> --since <iso8601> --json # incremental polling
83
+ trackio get run --project <p> --run <r> --json
84
+ trackio get metric --project <p> --run <r> --metric <m> --json
85
+ trackio list runs --project <p> --json
86
+ Python: api = trackio.Api(); api.alerts(<p>, run=<r>, since=<ts>); api.runs(<p>) (each run has .name, .config, .alerts()).
87
+
88
+ Drive the next config from prior alerts:
89
+ diverged → lr × 0.1
90
+ overfitting → weight_decay × 10 or reduce capacity
91
+ early stopping → lr × 0.5 or adjust schedule
92
+ high accuracy → refine around current config
93
+ Read prior config via api.runs(...).config and only mutate keys the alerts justify changing.
94
+
95
  # Data audit
96
 
97
  Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it.
 
107
  - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details]
108
  - push_to_hub=True and hub_model_id set
109
  - timeout: [value] (based on: [model size] on [hardware])
110
+ - Trackio monitoring included and deploying metrics to a public Space
111
 
112
  If you cannot fill in all items, stop and complete the missing steps first.
113
 
 
188
  - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
189
  - For errors: state what went wrong, why, and what you're doing to fix it.
190
  - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
191
+ - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates.
192
 
193
  # Tool usage
194
 
agent/tools/__init__.py CHANGED
@@ -20,6 +20,7 @@ from agent.tools.github_read_file import (
20
  )
21
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
22
  from agent.tools.types import ToolResult
 
23
 
24
  __all__ = [
25
  "ToolResult",
@@ -36,4 +37,6 @@ __all__ = [
36
  "github_search_code_handler",
37
  "HF_INSPECT_DATASET_TOOL_SPEC",
38
  "hf_inspect_dataset_handler",
 
 
39
  ]
 
20
  )
21
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
22
  from agent.tools.types import ToolResult
23
+ from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler
24
 
25
  __all__ = [
26
  "ToolResult",
 
37
  "github_search_code_handler",
38
  "HF_INSPECT_DATASET_TOOL_SPEC",
39
  "hf_inspect_dataset_handler",
40
+ "WEB_SEARCH_TOOL_SPEC",
41
+ "web_search_handler",
42
  ]
agent/tools/jobs_tool.py CHANGED
@@ -19,6 +19,7 @@ from huggingface_hub.utils import HfHubHTTPError
19
 
20
  from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
21
  from agent.core.session import Event
 
22
  from agent.tools.types import ToolResult
23
 
24
  logger = logging.getLogger(__name__)
@@ -382,6 +383,31 @@ class HfJobsTool:
382
  "isError": True,
383
  }
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  async def _wait_for_job_completion(
386
  self, job_id: str, namespace: Optional[str] = None
387
  ) -> tuple[str, list[str]]:
@@ -533,11 +559,24 @@ class HfJobsTool:
533
  # Run the job
534
  flavor = args.get("hardware_flavor", "cpu-basic")
535
  timeout_str = args.get("timeout", "30m")
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  job = await _async_call(
537
  self.api.run_job,
538
  image=image,
539
  command=command,
540
- env=_add_default_env(args.get("env")),
541
  secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
542
  flavor=flavor,
543
  timeout=timeout_str,
@@ -550,16 +589,18 @@ class HfJobsTool:
550
 
551
  # Send job URL immediately after job creation (before waiting for completion)
552
  if self.session and self.tool_call_id:
 
 
 
 
 
 
 
 
 
 
553
  await self.session.send_event(
554
- Event(
555
- event_type="tool_state_change",
556
- data={
557
- "tool_call_id": self.tool_call_id,
558
- "tool": "hf_jobs",
559
- "state": "running",
560
- "jobUrl": job.url,
561
- },
562
- )
563
  )
564
 
565
  # Telemetry: job submission + completion (infra consumption signal).
@@ -594,16 +635,18 @@ class HfJobsTool:
594
 
595
  # Notify frontend of final status
596
  if self.session and self.tool_call_id:
 
 
 
 
 
 
 
 
 
 
597
  await self.session.send_event(
598
- Event(
599
- event_type="tool_state_change",
600
- data={
601
- "tool_call_id": self.tool_call_id,
602
- "tool": "hf_jobs",
603
- "state": final_status.lower(),
604
- "jobUrl": job.url,
605
- },
606
- )
607
  )
608
 
609
  # Filter out UV package installation output
@@ -977,7 +1020,10 @@ HF_JOBS_TOOL_SPEC = {
977
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
978
  "- Training config MUST include push_to_hub=True and hub_model_id. "
979
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
980
- "- Include trackio monitoring and provide the dashboard URL to the user.\n\n"
 
 
 
981
  "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
982
  "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
983
  "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n"
@@ -1060,6 +1106,26 @@ HF_JOBS_TOOL_SPEC = {
1060
  "type": "object",
1061
  "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
1062
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1063
  "namespace": {
1064
  "type": "string",
1065
  "description": (
 
19
 
20
  from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
21
  from agent.core.session import Event
22
+ from agent.tools.trackio_seed import ensure_trackio_dashboard
23
  from agent.tools.types import ToolResult
24
 
25
  logger = logging.getLogger(__name__)
 
383
  "isError": True,
384
  }
385
 
386
async def _seed_trackio_dashboard(self, space_id: str) -> None:
    """Best-effort install of the trackio dashboard files into *space_id*.

    Runs ``ensure_trackio_dashboard`` in a worker thread and forwards its
    progress messages to the session as ``tool_log`` events. Any exception is
    logged and reported as a ``tool_log`` event rather than propagated, so a
    failed seed does not block job submission.
    """
    event_loop = asyncio.get_running_loop()

    def _emit(msg: str) -> None:
        # May be invoked from the worker thread, so hand the event to the
        # loop thread instead of touching the queue directly.
        if self.session is not None:
            log_event = Event(
                event_type="tool_log", data={"tool": "hf_jobs", "log": msg}
            )
            event_loop.call_soon_threadsafe(
                self.session.event_queue.put_nowait, log_event
            )

    try:
        await asyncio.to_thread(
            ensure_trackio_dashboard, space_id, self.hf_token, _emit
        )
    except Exception as e:
        logger.warning(f"trackio dashboard seed failed for {space_id}: {e}")
        _emit(f"trackio dashboard seed failed: {e}")
410
+
411
  async def _wait_for_job_completion(
412
  self, job_id: str, namespace: Optional[str] = None
413
  ) -> tuple[str, list[str]]:
 
559
  # Run the job
560
  flavor = args.get("hardware_flavor", "cpu-basic")
561
  timeout_str = args.get("timeout", "30m")
562
+
563
+ # Trackio: agent-declared space + project become env vars on the job
564
+ # so trackio.init() picks them up automatically. We also surface them
565
+ # in tool_state_change so the frontend can embed the dashboard.
566
+ env_dict = _add_default_env(args.get("env"))
567
+ trackio_space_id = args.get("trackio_space_id")
568
+ trackio_project = args.get("trackio_project")
569
+ if trackio_space_id:
570
+ env_dict["TRACKIO_SPACE_ID"] = trackio_space_id
571
+ await self._seed_trackio_dashboard(trackio_space_id)
572
+ if trackio_project:
573
+ env_dict["TRACKIO_PROJECT"] = trackio_project
574
+
575
  job = await _async_call(
576
  self.api.run_job,
577
  image=image,
578
  command=command,
579
+ env=env_dict,
580
  secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
581
  flavor=flavor,
582
  timeout=timeout_str,
 
589
 
590
  # Send job URL immediately after job creation (before waiting for completion)
591
  if self.session and self.tool_call_id:
592
+ state_data: Dict[str, Any] = {
593
+ "tool_call_id": self.tool_call_id,
594
+ "tool": "hf_jobs",
595
+ "state": "running",
596
+ "jobUrl": job.url,
597
+ }
598
+ if trackio_space_id:
599
+ state_data["trackioSpaceId"] = trackio_space_id
600
+ if trackio_project:
601
+ state_data["trackioProject"] = trackio_project
602
  await self.session.send_event(
603
+ Event(event_type="tool_state_change", data=state_data)
 
 
 
 
 
 
 
 
604
  )
605
 
606
  # Telemetry: job submission + completion (infra consumption signal).
 
635
 
636
  # Notify frontend of final status
637
  if self.session and self.tool_call_id:
638
+ final_data: Dict[str, Any] = {
639
+ "tool_call_id": self.tool_call_id,
640
+ "tool": "hf_jobs",
641
+ "state": final_status.lower(),
642
+ "jobUrl": job.url,
643
+ }
644
+ if trackio_space_id:
645
+ final_data["trackioSpaceId"] = trackio_space_id
646
+ if trackio_project:
647
+ final_data["trackioProject"] = trackio_project
648
  await self.session.send_event(
649
+ Event(event_type="tool_state_change", data=final_data)
 
 
 
 
 
 
 
 
650
  )
651
 
652
  # Filter out UV package installation output
 
1020
  "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n"
1021
  "- Training config MUST include push_to_hub=True and hub_model_id. "
1022
  "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n"
1023
+ "- Include trackio monitoring and provide the dashboard URL to the user. "
1024
+ "When the script uses report_to='trackio', also pass `trackio_space_id` "
1025
+ "(e.g. '<username>/mlintern-<8char>') and `trackio_project` as tool args — "
1026
+ "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n"
1027
  "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. "
1028
  "Only then submit the remaining jobs. Never submit all at once — if there's a bug, all jobs fail.\n\n"
1029
  "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n"
 
1106
  "type": "object",
1107
  "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
1108
  },
1109
+ "trackio_space_id": {
1110
+ "type": "string",
1111
+ "description": (
1112
+ "Optional. The HF Space hosting the trackio dashboard for this run "
1113
+ "(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). "
1114
+ "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed "
1115
+ "the live dashboard. Set this whenever the script uses "
1116
+ "report_to='trackio'. The Space is auto-created and seeded with the "
1117
+ "trackio dashboard before the job starts — DO NOT pre-create it via "
1118
+ "hf_repo_git, that produces an empty Space that breaks the embed."
1119
+ ),
1120
+ },
1121
+ "trackio_project": {
1122
+ "type": "string",
1123
+ "description": (
1124
+ "Optional. The trackio project name to log this run under. "
1125
+ "Injected as TRACKIO_PROJECT env var and used by the UI to filter "
1126
+ "the embedded dashboard to this project."
1127
+ ),
1128
+ },
1129
  "namespace": {
1130
  "type": "string",
1131
  "description": (
agent/tools/notify_tool.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from agent.messaging.models import NotificationRequest
4
+
5
# Tool schema advertised to the model for the `notify` tool. The handler
# falls back to severity "info" when the field is omitted, so the schema
# description says so explicitly.
NOTIFY_TOOL_SPEC = {
    "name": "notify",
    "description": (
        "Send an out-of-band notification to configured messaging destinations. "
        "Use this only when the user explicitly asked for proactive notifications "
        "or when the task requires reporting progress outside the chat. "
        "Destinations must be named server-side configs such as 'slack.ops'."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "destinations": {
                "type": "array",
                "description": "Named messaging destinations to notify.",
                "items": {"type": "string"},
                "minItems": 1,
            },
            "message": {
                "type": "string",
                "description": "Main notification body.",
            },
            "title": {
                "type": "string",
                "description": "Optional short title line.",
            },
            "severity": {
                "type": "string",
                "enum": ["info", "success", "warning", "error"],
                "description": (
                    "Notification severity label. Defaults to 'info' when omitted."
                ),
            },
        },
        "required": ["destinations", "message"],
    },
}
39
+
40
+
41
async def notify_handler(
    arguments: dict[str, Any], session=None, **_kwargs
) -> tuple[str, bool]:
    """Validate a ``notify`` tool call and send it through the session gateway.

    Args:
        arguments: Raw tool arguments. Reads ``destinations``, ``message``,
            ``title`` and ``severity``.
        session: Active session. Its ``notification_gateway``, ``config`` and
            ``session_id`` attributes are used.
        **_kwargs: Extra keyword arguments; ignored.

    Returns:
        ``(text, ok)``. On a validation failure, ``text`` is the error message
        and ``ok`` is ``False``. Otherwise ``text`` has one line per
        destination and ``ok`` is ``True`` only if every send succeeded.
    """
    if session is None or session.notification_gateway is None:
        return "Messaging is not configured for this session.", False

    raw_destinations = arguments.get("destinations", [])
    if not isinstance(raw_destinations, list) or not raw_destinations:
        return "destinations must be a non-empty array of destination names.", False

    # Strip and de-duplicate while keeping the caller's order.
    destinations: list[str] = []
    seen: set[str] = set()
    for raw_name in raw_destinations:
        if not isinstance(raw_name, str):
            return "Each destination must be a string.", False
        name = raw_name.strip()
        if not name:
            return "Destination names must not be empty.", False
        if name not in seen:
            destinations.append(name)
            seen.add(name)

    disallowed = [
        name
        for name in destinations
        if not session.config.messaging.can_agent_tool_send(name)
    ]
    if disallowed:
        return (
            "These destinations are unavailable for the notify tool: "
            + ", ".join(disallowed)
        ), False

    message = arguments.get("message", "")
    if not isinstance(message, str) or not message.strip():
        return "message must be a non-empty string.", False

    title = arguments.get("title")
    if title is not None and not isinstance(title, str):
        return "title must be a string when provided.", False

    severity = arguments.get("severity")
    if severity is None:
        # Optional field: omitted and explicit null both mean the default.
        severity = "info"
    # Check the type before the membership test: an unhashable value such as
    # a list or dict would make `in {...}` raise TypeError instead of
    # producing the validation error below.
    if not isinstance(severity, str) or severity not in {
        "info",
        "success",
        "warning",
        "error",
    }:
        return "severity must be one of: info, success, warning, error.", False

    requests = [
        NotificationRequest(
            destination=name,
            title=title,
            message=message,
            severity=severity,
            metadata={
                "session_id": session.session_id,
                "model": session.config.model_name,
            },
        )
        for name in destinations
    ]
    results = await session.notification_gateway.send_many(requests)

    lines = []
    all_ok = True
    for result in results:
        if result.ok:
            lines.append(f"{result.destination}: sent")
        else:
            all_ok = False
            lines.append(f"{result.destination}: failed ({result.error})")
    return "\n".join(lines), all_ok
agent/tools/research_tool.py CHANGED
@@ -37,6 +37,7 @@ RESEARCH_TOOL_NAMES = {
37
  "github_find_examples",
38
  "github_list_repos",
39
  "github_read_file",
 
40
  "hf_inspect_dataset",
41
  "hf_repo_files",
42
  }
@@ -102,6 +103,8 @@ tell you what actually works.
102
  - `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc.
103
  - `fetch_hf_docs(url)`: Fetch full page content from explore results
104
  - `find_hf_api(query=..., tag=...)`: Find REST API endpoints
 
 
105
 
106
  ## Hub repo inspection
107
  - `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
@@ -306,8 +309,10 @@ async def research_handler(
306
  # ── Doom-loop detection ──
307
  doom_prompt = check_for_doom_loop(messages)
308
  if doom_prompt:
309
- logger.warning("Research sub-agent doom loop detected at iteration %d", _iteration)
310
- await _log("Doom loop detected injecting corrective prompt")
 
 
311
  messages.append(Message(role="user", content=doom_prompt))
312
 
313
  # ── Context budget: warn at 75%, hard-stop at 95% ──
@@ -424,7 +429,7 @@ async def research_handler(
424
  await _log(f"▸ {tool_name} {args_str}")
425
 
426
  output, _success = await session.tool_router.call_tool(
427
- tool_name, tool_args, session=session
428
  )
429
  _tool_uses += 1
430
  await _log(f"tools:{_tool_uses}")
 
37
  "github_find_examples",
38
  "github_list_repos",
39
  "github_read_file",
40
+ "web_search",
41
  "hf_inspect_dataset",
42
  "hf_repo_files",
43
  }
 
103
  - `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc.
104
  - `fetch_hf_docs(url)`: Fetch full page content from explore results
105
  - `find_hf_api(query=..., tag=...)`: Find REST API endpoints
106
+ - `web_search(query=..., allowed_domains=[...], blocked_domains=[...])`:
107
+ Search the current web when papers/docs/GitHub are not enough.
108
 
109
  ## Hub repo inspection
110
  - `hf_repo_files`: List/read files in any HF repo (model, dataset, space)
 
309
  # ── Doom-loop detection ──
310
  doom_prompt = check_for_doom_loop(messages)
311
  if doom_prompt:
312
+ logger.warning(
313
+ "Research sub-agent repetition guard activated at iteration %d",
314
+ _iteration,
315
+ )
316
  messages.append(Message(role="user", content=doom_prompt))
317
 
318
  # ── Context budget: warn at 75%, hard-stop at 95% ──
 
429
  await _log(f"▸ {tool_name} {args_str}")
430
 
431
  output, _success = await session.tool_router.call_tool(
432
+ tool_name, tool_args, session=session, tool_call_id=tc.id
433
  )
434
  _tool_uses += 1
435
  await _log(f"tools:{_tool_uses}")
agent/tools/sandbox_client.py CHANGED
@@ -37,6 +37,7 @@ Tools: bash, read, write, edit, upload
37
  from __future__ import annotations
38
 
39
  import io
 
40
  import sys
41
  import time
42
  import uuid
@@ -99,8 +100,8 @@ CMD ["python", "sandbox_server.py"]
99
 
100
  _SANDBOX_SERVER = '''\
101
  """Minimal FastAPI server for sandbox operations."""
102
- import os, subprocess, pathlib, signal, threading, re, tempfile
103
- from fastapi import FastAPI
104
  from pydantic import BaseModel
105
  from typing import Optional
106
  import uvicorn
@@ -156,6 +157,22 @@ def _atomic_write(path: pathlib.Path, content: str):
156
 
157
  app = FastAPI()
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # Track active bash processes so they can be killed on cancel
160
  _active_procs = {} # pid -> subprocess.Popen
161
  _proc_lock = threading.Lock()
@@ -344,7 +361,7 @@ def _validate_python(content, path=""):
344
  def health():
345
  return {"status": "ok"}
346
 
347
- @app.post("/api/bash")
348
  def bash(req: BashReq):
349
  try:
350
  proc = subprocess.Popen(
@@ -371,7 +388,7 @@ def bash(req: BashReq):
371
  except Exception as e:
372
  return {"success": False, "output": "", "error": str(e)}
373
 
374
- @app.post("/api/kill")
375
  def kill_all():
376
  """Kill all active bash processes. Called when user cancels."""
377
  with _proc_lock:
@@ -389,7 +406,7 @@ def kill_all():
389
  pass
390
  return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""}
391
 
392
- @app.post("/api/read")
393
  def read(req: ReadReq):
394
  try:
395
  p = pathlib.Path(req.path)
@@ -406,7 +423,7 @@ def read(req: ReadReq):
406
  except Exception as e:
407
  return {"success": False, "output": "", "error": str(e)}
408
 
409
- @app.post("/api/write")
410
  def write(req: WriteReq):
411
  try:
412
  p = pathlib.Path(req.path)
@@ -420,7 +437,7 @@ def write(req: WriteReq):
420
  except Exception as e:
421
  return {"success": False, "output": "", "error": str(e)}
422
 
423
- @app.post("/api/edit")
424
  def edit(req: EditReq):
425
  try:
426
  p = pathlib.Path(req.path)
@@ -447,7 +464,7 @@ def edit(req: EditReq):
447
  except Exception as e:
448
  return {"success": False, "output": "", "error": str(e)}
449
 
450
- @app.post("/api/exists")
451
  def exists(req: ExistsReq):
452
  return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""}
453
 
@@ -482,6 +499,7 @@ class Sandbox:
482
 
483
  space_id: str
484
  token: str | None = None
 
485
  work_dir: str = "/app"
486
  timeout: int = DEFAULT_TIMEOUT
487
  _owns_space: bool = field(default=False, repr=False)
@@ -495,9 +513,10 @@ class Sandbox:
495
  # Trailing slash is critical: httpx resolves relative paths against base_url.
496
  # Without it, client.get("health") resolves to /health instead of /api/health.
497
  self._base_url = f"https://{slug}.hf.space/api/"
 
498
  self._client = httpx.Client(
499
  base_url=self._base_url,
500
- headers={"Authorization": f"Bearer {self.token}"} if self.token else {},
501
  timeout=httpx.Timeout(MAX_TIMEOUT, connect=30),
502
  follow_redirects=True,
503
  )
@@ -563,6 +582,7 @@ class Sandbox:
563
  base = name or "sandbox"
564
  suffix = uuid.uuid4().hex[:8]
565
  space_id = f"{owner}/{base}-{suffix}"
 
566
 
567
  _log(f"Creating sandbox: {space_id} (from {template})...")
568
 
@@ -583,8 +603,9 @@ class Sandbox:
583
  # Inject secrets BEFORE uploading server files (which triggers rebuild).
584
  # Secrets added after a Space is running aren't available until restart,
585
  # so they must be set before the build/start cycle.
586
- if secrets:
587
- for key, val in secrets.items():
 
588
  api.add_space_secret(space_id, key, val)
589
 
590
  # Upload sandbox server and Dockerfile (triggers rebuild)
@@ -617,7 +638,12 @@ class Sandbox:
617
  _check_cancel()
618
 
619
  # Wait for the API server to be responsive (non-fatal)
620
- sb = cls(space_id=space_id, token=token, _owns_space=True)
 
 
 
 
 
621
  try:
622
  sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log)
623
  except TimeoutError as e:
@@ -648,13 +674,24 @@ class Sandbox:
648
  log("Server files uploaded, rebuild triggered.")
649
 
650
  @classmethod
651
- def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox:
 
 
 
 
 
 
652
  """
653
  Connect to an existing running Space.
654
 
655
  Does a health check to verify the Space is reachable.
656
  """
657
- sb = cls(space_id=space_id, token=token, _owns_space=False)
 
 
 
 
 
658
  sb._wait_for_api(timeout=60)
659
  return sb
660
 
@@ -687,6 +724,10 @@ class Sandbox:
687
  )
688
  print(f"Deleting sandbox: {self.space_id}...")
689
  self._hf_api.delete_repo(self.space_id, repo_type="space")
 
 
 
 
690
  self._client.close()
691
  print("Deleted.")
692
 
 
37
  from __future__ import annotations
38
 
39
  import io
40
+ import secrets as secrets_lib
41
  import sys
42
  import time
43
  import uuid
 
100
 
101
  _SANDBOX_SERVER = '''\
102
  """Minimal FastAPI server for sandbox operations."""
103
+ import hmac, os, subprocess, pathlib, signal, threading, re, tempfile
104
+ from fastapi import Depends, FastAPI, HTTPException, Request
105
  from pydantic import BaseModel
106
  from typing import Optional
107
  import uvicorn
 
157
 
158
  app = FastAPI()
159
 
160
+ def _expected_api_token() -> str:
161
+ return os.environ.get("SANDBOX_API_TOKEN") or os.environ.get("HF_TOKEN") or ""
162
+
163
def _require_auth(request: Request) -> None:
    """Reject the request unless it carries the expected bearer token.

    Raises:
        HTTPException: 503 when no token is configured on the server; 401
            when the Authorization header is missing, is not a bearer
            credential, or does not match the expected token.
    """
    expected = _expected_api_token()
    if not expected:
        raise HTTPException(status_code=503, detail="Sandbox API token not configured")
    auth_header = request.headers.get("authorization", "")
    scheme, _, supplied = auth_header.partition(" ")
    if scheme.lower() != "bearer" or not supplied:
        raise HTTPException(status_code=401, detail="Missing bearer token")
    # Compare bytes, not str: hmac.compare_digest raises TypeError when a str
    # argument contains non-ASCII characters, so a crafted header would turn
    # into an unhandled error instead of a clean 401.
    if not hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid bearer token")
173
+
174
+ _AUTH = [Depends(_require_auth)]
175
+
176
  # Track active bash processes so they can be killed on cancel
177
  _active_procs = {} # pid -> subprocess.Popen
178
  _proc_lock = threading.Lock()
 
361
  def health():
362
  return {"status": "ok"}
363
 
364
+ @app.post("/api/bash", dependencies=_AUTH)
365
  def bash(req: BashReq):
366
  try:
367
  proc = subprocess.Popen(
 
388
  except Exception as e:
389
  return {"success": False, "output": "", "error": str(e)}
390
 
391
+ @app.post("/api/kill", dependencies=_AUTH)
392
  def kill_all():
393
  """Kill all active bash processes. Called when user cancels."""
394
  with _proc_lock:
 
406
  pass
407
  return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""}
408
 
409
+ @app.post("/api/read", dependencies=_AUTH)
410
  def read(req: ReadReq):
411
  try:
412
  p = pathlib.Path(req.path)
 
423
  except Exception as e:
424
  return {"success": False, "output": "", "error": str(e)}
425
 
426
+ @app.post("/api/write", dependencies=_AUTH)
427
  def write(req: WriteReq):
428
  try:
429
  p = pathlib.Path(req.path)
 
437
  except Exception as e:
438
  return {"success": False, "output": "", "error": str(e)}
439
 
440
+ @app.post("/api/edit", dependencies=_AUTH)
441
  def edit(req: EditReq):
442
  try:
443
  p = pathlib.Path(req.path)
 
464
  except Exception as e:
465
  return {"success": False, "output": "", "error": str(e)}
466
 
467
+ @app.post("/api/exists", dependencies=_AUTH)
468
  def exists(req: ExistsReq):
469
  return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""}
470
 
 
499
 
500
  space_id: str
501
  token: str | None = None
502
+ api_token: str | None = field(default=None, repr=False)
503
  work_dir: str = "/app"
504
  timeout: int = DEFAULT_TIMEOUT
505
  _owns_space: bool = field(default=False, repr=False)
 
513
  # Trailing slash is critical: httpx resolves relative paths against base_url.
514
  # Without it, client.get("health") resolves to /health instead of /api/health.
515
  self._base_url = f"https://{slug}.hf.space/api/"
516
+ api_token = self.api_token or self.token
517
  self._client = httpx.Client(
518
  base_url=self._base_url,
519
+ headers={"Authorization": f"Bearer {api_token}"} if api_token else {},
520
  timeout=httpx.Timeout(MAX_TIMEOUT, connect=30),
521
  follow_redirects=True,
522
  )
 
582
  base = name or "sandbox"
583
  suffix = uuid.uuid4().hex[:8]
584
  space_id = f"{owner}/{base}-{suffix}"
585
+ sandbox_api_token = secrets_lib.token_urlsafe(32)
586
 
587
  _log(f"Creating sandbox: {space_id} (from {template})...")
588
 
 
603
  # Inject secrets BEFORE uploading server files (which triggers rebuild).
604
  # Secrets added after a Space is running aren't available until restart,
605
  # so they must be set before the build/start cycle.
606
+ sandbox_secrets = {**(secrets or {}), "SANDBOX_API_TOKEN": sandbox_api_token}
607
+ if sandbox_secrets:
608
+ for key, val in sandbox_secrets.items():
609
  api.add_space_secret(space_id, key, val)
610
 
611
  # Upload sandbox server and Dockerfile (triggers rebuild)
 
638
  _check_cancel()
639
 
640
  # Wait for the API server to be responsive (non-fatal)
641
+ sb = cls(
642
+ space_id=space_id,
643
+ token=token,
644
+ api_token=sandbox_api_token,
645
+ _owns_space=True,
646
+ )
647
  try:
648
  sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log)
649
  except TimeoutError as e:
 
674
  log("Server files uploaded, rebuild triggered.")
675
 
676
@classmethod
def connect(
    cls,
    space_id: str,
    *,
    token: str | None = None,
    api_token: str | None = None,
) -> Sandbox:
    """
    Attach to a Space that already exists, without taking ownership of it.

    Args:
        space_id: Identifier of the Space to connect to.
        token: Token stored on the instance as ``token``.
        api_token: Token stored on the instance as ``api_token``.

    Returns:
        The connected instance, returned after ``_wait_for_api(timeout=60)``
        has completed.
    """
    sandbox = cls(
        space_id=space_id,
        token=token,
        api_token=api_token,
        _owns_space=False,
    )
    # Verify the Space is reachable before handing the instance back.
    sandbox._wait_for_api(timeout=60)
    return sandbox
697
 
 
724
  )
725
  print(f"Deleting sandbox: {self.space_id}...")
726
  self._hf_api.delete_repo(self.space_id, repo_type="space")
727
+ # Clear ownership so a second cleanup call (e.g. delete_session +
728
+ # _run_session.finally both fire) early-returns instead of retrying
729
+ # a 404 delete and emitting a spurious ERROR log.
730
+ self._owns_space = False
731
  self._client.close()
732
  print("Deleted.")
733
 
agent/tools/sandbox_tool.py CHANGED
@@ -12,13 +12,29 @@ a cpu-basic sandbox is auto-created (no approval needed).
12
  from __future__ import annotations
13
 
14
  import asyncio
 
 
15
  import threading
 
16
  from typing import Any
17
 
18
  from huggingface_hub import HfApi, SpaceHardware
19
 
20
  from agent.core.session import Event
21
  from agent.tools.sandbox_client import Sandbox
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
  def _looks_like_path(script: str) -> bool:
@@ -62,11 +78,89 @@ async def resolve_sandbox_script(
62
  return None, f"Failed to read {script} from sandbox: {e}"
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # ── Tool name mapping (short agent names → Sandbox client names) ──────
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  async def _ensure_sandbox(
69
- session: Any, hardware: str = "cpu-basic", **create_kwargs
 
 
 
70
  ) -> tuple[Sandbox | None, str | None]:
71
  """
72
  Ensure a sandbox exists on the session. Auto-creates with given hardware if needed.
@@ -109,6 +203,23 @@ async def _ensure_sandbox(
109
  Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
110
  )
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
113
  # We poll session._cancelled from the main loop in a background task and set
114
  # a threading.Event that Sandbox.create checks during its polling loops.
@@ -120,11 +231,15 @@ async def _ensure_sandbox(
120
 
121
  watcher_task = asyncio.create_task(_watch_cancel())
122
 
 
 
 
 
123
  kwargs = {
124
  "owner": owner,
125
  "hardware": hardware,
126
  "token": token,
127
- "secrets": {"HF_TOKEN": token},
128
  "log": _log,
129
  "cancel_event": cancel_flag,
130
  **create_kwargs,
@@ -188,6 +303,9 @@ SANDBOX_CREATE_TOOL_SPEC = {
188
  "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n"
189
  "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
190
  "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
 
 
 
191
  "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
192
  ),
193
  "parameters": {
@@ -204,16 +322,49 @@ SANDBOX_CREATE_TOOL_SPEC = {
204
  "type": "boolean",
205
  "description": "If true, create a private Space",
206
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  },
208
  },
209
  }
210
 
211
 
212
  async def sandbox_create_handler(
213
- args: dict[str, Any], session: Any = None
214
  ) -> tuple[str, bool]:
215
  """Handle sandbox_create tool calls."""
216
  hardware = args.get("hardware", "cpu-basic")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
  # If sandbox already exists, return its info
219
  if session and getattr(session, "sandbox", None):
@@ -226,6 +377,7 @@ async def sandbox_create_handler(
226
  "Hardware cannot be changed by calling sandbox_create again. "
227
  "Delete the existing sandbox first if you need a different tier."
228
  )
 
229
  return (
230
  f"Sandbox already active: {sb.space_id}\n"
231
  f"URL: {sb.url}\n"
@@ -233,18 +385,32 @@ async def sandbox_create_handler(
233
  f"Use bash/read/write/edit to interact with it."
234
  ), True
235
 
236
- create_kwargs = {}
237
  if "private" in args:
238
  create_kwargs["private"] = args["private"]
239
 
 
 
 
 
 
 
 
240
  try:
241
- sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs)
 
 
 
 
 
242
  except Exception as e:
243
  return f"Failed to create sandbox: {e}", False
244
 
245
  if error:
246
  return error, False
247
 
 
 
248
  return (
249
  f"Sandbox created: {sb.space_id}\n"
250
  f"URL: {sb.url}\n"
 
12
  from __future__ import annotations
13
 
14
  import asyncio
15
+ import logging
16
+ import re
17
  import threading
18
+ from datetime import datetime, timedelta, timezone
19
  from typing import Any
20
 
21
  from huggingface_hub import HfApi, SpaceHardware
22
 
23
  from agent.core.session import Event
24
  from agent.tools.sandbox_client import Sandbox
25
+ from agent.tools.trackio_seed import ensure_trackio_dashboard
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Match the exact suffix pattern Sandbox.create produces: "sandbox-<8 hex>".
30
+ # Used to identify orphan sandboxes from prior sessions safely (won't match
31
+ # user-renamed lookalikes).
32
+ _SANDBOX_NAME_RE = re.compile(r"^sandbox-[a-f0-9]{8}$")
33
+
34
+ # How stale a sandbox must be before we treat it as definitely orphan.
35
+ # Anything more recent could be tied to a still-live session in another tab,
36
+ # so we leave it alone.
37
+ _ORPHAN_STALE_AFTER = timedelta(hours=1)
38
 
39
 
40
  def _looks_like_path(script: str) -> bool:
 
78
  return None, f"Failed to read {script} from sandbox: {e}"
79
 
80
 
81
async def _seed_trackio_dashboard_safe(session: Any, space_id: str) -> None:
    """Best-effort seeding of the trackio dashboard files into *space_id*.

    Runs ``ensure_trackio_dashboard`` in a worker thread with the session's
    HF token and forwards its progress messages to the session event queue
    as ``tool_log`` events. Returns immediately when there is no session or
    the session carries no ``hf_token``. Any exception raised by the seed is
    reported through the same log channel and swallowed, so a failed seed
    never blocks sandbox creation.
    """
    hf_token = getattr(session, "hf_token", None) if session else None
    if not hf_token:
        return

    running_loop = asyncio.get_running_loop()

    def _log(msg: str) -> None:
        # The seed runs off-loop, so enqueue the event via the loop's
        # thread-safe scheduling hook instead of touching the queue directly.
        event = Event(
            event_type="tool_log",
            data={"tool": "sandbox_create", "log": msg},
        )
        running_loop.call_soon_threadsafe(session.event_queue.put_nowait, event)

    try:
        await asyncio.to_thread(ensure_trackio_dashboard, space_id, hf_token, _log)
    except Exception as exc:
        _log(f"trackio dashboard seed failed: {exc}")
101
+
102
+
103
  # ── Tool name mapping (short agent names → Sandbox client names) ──────
104
 
105
 
106
def _cleanup_user_orphan_sandboxes(
    api: HfApi,
    owner: str,
    log: Any,
) -> int:
    """Delete stale ``sandbox-<8hex>`` Spaces in ``owner``'s account.

    "Stale" = not modified in the last hour. The naming pattern + staleness
    filter together make this safe:

    * Naming: only matches ``sandbox-<exactly 8 lowercase hex>``, the
      pattern Sandbox.create produces. Won't touch user-renamed Spaces.
    * Staleness: anything modified in the last hour might still be tied
      to a live session in another tab/replica, so we leave it alone.
    * Unknown age: a Space whose last-modified time is missing or cannot be
      parsed is skipped rather than deleted, because its staleness cannot be
      established.

    Runs blocking — call via ``asyncio.to_thread``. Best-effort: failures
    are logged but never raised, so a flaky HF API never blocks creation.

    Args:
        api: Hub client used to list and delete Spaces.
        owner: Namespace whose Spaces are swept.
        log: Callable taking one message string.

    Returns:
        The number of Spaces that were deleted.
    """
    cutoff = datetime.now(timezone.utc) - _ORPHAN_STALE_AFTER
    try:
        # full=True: per the huggingface_hub docs this is what guarantees
        # ``last_modified`` is populated on the listed Spaces.
        spaces = list(api.list_spaces(author=owner, limit=200, full=True))
    except Exception as e:
        log(f"orphan sweep: list_spaces failed: {e}")
        return 0

    deleted = 0
    unknown_age = 0
    for space in spaces:
        space_name = space.id.rsplit("/", 1)[-1]
        if not _SANDBOX_NAME_RE.match(space_name):
            continue

        last_mod = getattr(space, "lastModified", None) or getattr(
            space, "last_modified", None
        )
        if isinstance(last_mod, str):
            try:
                last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
            except ValueError:
                last_mod = None
        if not isinstance(last_mod, datetime):
            # No trustworthy timestamp: this could be a live sandbox, so the
            # safe choice is to leave it alone.
            unknown_age += 1
            continue
        if last_mod.tzinfo is None:
            # Naive timestamps cannot be compared with the aware cutoff;
            # treat them as UTC.
            last_mod = last_mod.replace(tzinfo=timezone.utc)
        if last_mod > cutoff:
            # Recent — could be a concurrent live session. Skip.
            continue

        try:
            api.delete_repo(repo_id=space.id, repo_type="space")
            deleted += 1
            log(f"orphan sweep: deleted {space.id}")
        except Exception as e:
            log(f"orphan sweep: failed to delete {space.id}: {e}")

    if unknown_age:
        log(
            f"orphan sweep: skipped {unknown_age} sandbox(es) with unknown "
            "last-modified time"
        )
    if deleted:
        log(f"orphan sweep: cleaned up {deleted} stale sandbox(es) before create")
    return deleted
157
+
158
+
159
  async def _ensure_sandbox(
160
+ session: Any,
161
+ hardware: str = "cpu-basic",
162
+ extra_secrets: dict[str, str] | None = None,
163
+ **create_kwargs,
164
  ) -> tuple[Sandbox | None, str | None]:
165
  """
166
  Ensure a sandbox exists on the session. Auto-creates with given hardware if needed.
 
203
  Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}),
204
  )
205
 
206
+ # Before we create a new sandbox, sweep this user's stale sandboxes from
207
+ # prior sessions. ``_cleanup_sandbox`` in session_manager fires only on
208
+ # clean session exit; pod kills, WebSocket drops, etc. leave orphans
209
+ # behind, and they accumulate on every new session forever (observed
210
+ # 2310 leaked across the Hub on 2026-04-27). Doing the cleanup here at
211
+ # session start = self-healing, no separate cron needed.
212
+ #
213
+ # The 1h staleness filter is the safety: a sandbox modified in the last
214
+ # hour might still be tied to a live session in another tab, so we skip.
215
+ # Anything older has no realistic chance of being active given typical
216
+ # session lengths.
217
+ try:
218
+ await asyncio.to_thread(_cleanup_user_orphan_sandboxes, api, owner, _log)
219
+ except Exception as e:
220
+ # Cleanup is best-effort — never block sandbox_create on it.
221
+ _log(f"orphan sandbox sweep failed (non-fatal): {e}")
222
+
223
  # Bridge asyncio cancel event to a threading.Event for the blocking create call.
224
  # We poll session._cancelled from the main loop in a background task and set
225
  # a threading.Event that Sandbox.create checks during its polling loops.
 
231
 
232
  watcher_task = asyncio.create_task(_watch_cancel())
233
 
234
+ secrets: dict[str, str] = {"HF_TOKEN": token}
235
+ if extra_secrets:
236
+ secrets.update({k: v for k, v in extra_secrets.items() if v})
237
+
238
  kwargs = {
239
  "owner": owner,
240
  "hardware": hardware,
241
  "token": token,
242
+ "secrets": secrets,
243
  "log": _log,
244
  "cancel_event": cancel_flag,
245
  **create_kwargs,
 
303
  "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n"
304
  "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). "
305
  "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n"
306
+ "If you intend to run a training script in this sandbox that uses report_to='trackio', "
307
+ "pass `trackio_space_id` (e.g. '<username>/mlintern-<8char>') and `trackio_project` so they "
308
+ "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n"
309
  "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n"
310
  ),
311
  "parameters": {
 
322
  "type": "boolean",
323
  "description": "If true, create a private Space",
324
  },
325
+ "trackio_space_id": {
326
+ "type": "string",
327
+ "description": (
328
+ "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox "
329
+ "(e.g. '<username>/mlintern-<8char>', under YOUR HF namespace). Injected as "
330
+ "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and "
331
+ "seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, "
332
+ "that produces an empty Space that breaks the embed."
333
+ ),
334
+ },
335
+ "trackio_project": {
336
+ "type": "string",
337
+ "description": (
338
+ "Optional. The trackio project name. Injected as TRACKIO_PROJECT secret and "
339
+ "used by the UI to filter the embedded dashboard to this project."
340
+ ),
341
+ },
342
  },
343
  },
344
  }
345
 
346
 
347
  async def sandbox_create_handler(
348
+ args: dict[str, Any], session: Any = None, tool_call_id: str | None = None
349
  ) -> tuple[str, bool]:
350
  """Handle sandbox_create tool calls."""
351
  hardware = args.get("hardware", "cpu-basic")
352
+ trackio_space_id = args.get("trackio_space_id") or None
353
+ trackio_project = args.get("trackio_project") or None
354
+
355
+ async def _emit_trackio_state(sb: Sandbox) -> None:
356
+ """Tell the frontend which trackio dashboard to embed for this sandbox."""
357
+ if not (session and tool_call_id and trackio_space_id):
358
+ return
359
+ data: dict[str, Any] = {
360
+ "tool_call_id": tool_call_id,
361
+ "tool": "sandbox_create",
362
+ "state": "running",
363
+ "trackioSpaceId": trackio_space_id,
364
+ }
365
+ if trackio_project:
366
+ data["trackioProject"] = trackio_project
367
+ await session.send_event(Event(event_type="tool_state_change", data=data))
368
 
369
  # If sandbox already exists, return its info
370
  if session and getattr(session, "sandbox", None):
 
377
  "Hardware cannot be changed by calling sandbox_create again. "
378
  "Delete the existing sandbox first if you need a different tier."
379
  )
380
+ await _emit_trackio_state(sb)
381
  return (
382
  f"Sandbox already active: {sb.space_id}\n"
383
  f"URL: {sb.url}\n"
 
385
  f"Use bash/read/write/edit to interact with it."
386
  ), True
387
 
388
+ create_kwargs: dict[str, Any] = {}
389
  if "private" in args:
390
  create_kwargs["private"] = args["private"]
391
 
392
+ extra_secrets: dict[str, str] = {}
393
+ if trackio_space_id:
394
+ extra_secrets["TRACKIO_SPACE_ID"] = trackio_space_id
395
+ await _seed_trackio_dashboard_safe(session, trackio_space_id)
396
+ if trackio_project:
397
+ extra_secrets["TRACKIO_PROJECT"] = trackio_project
398
+
399
  try:
400
+ sb, error = await _ensure_sandbox(
401
+ session,
402
+ hardware=hardware,
403
+ extra_secrets=extra_secrets or None,
404
+ **create_kwargs,
405
+ )
406
  except Exception as e:
407
  return f"Failed to create sandbox: {e}", False
408
 
409
  if error:
410
  return error, False
411
 
412
+ await _emit_trackio_state(sb)
413
+
414
  return (
415
  f"Sandbox created: {sb.space_id}\n"
416
  f"URL: {sb.url}\n"
agent/tools/trackio_seed.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Seed an HF Space with the trackio dashboard.
2
+
3
+ Background: when the agent creates a Space via `hf_repo_git create_repo` (or
4
+ the user pre-creates one), it ships with no app.py — so the iframe shows the
5
+ default Gradio "Get started" template instead of charts. Trackio's `init()`
6
+ detects the existing Space but does NOT auto-bootstrap dashboard files into it,
7
+ so the dashboard never materializes.
8
+
9
+ This helper writes the three files trackio's runtime expects (README.md,
10
+ requirements.txt, app.py) into the Space, idempotently, BEFORE the job that
11
+ will call `trackio.init()` runs. We deliberately omit `hf_oauth: true` from
12
+ the README so the embedded iframe in ml-intern renders without a login click —
13
+ per-user privacy is enforced by namespace ownership instead.
14
+
15
+ Beyond the dashboard files, the helper also creates the metrics bucket and
16
+ mounts it on the Space at `/data` (with `TRACKIO_DIR` / `TRACKIO_BUCKET_ID`
17
+ Space variables). Without this, the running job writes metrics into a bucket
18
+ that the dashboard Space can't read, and the iframe shows "No projects".
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import io
24
+ from typing import Callable, Optional
25
+
26
+ from huggingface_hub import (
27
+ HfApi,
28
+ Volume,
29
+ add_space_variable,
30
+ create_bucket,
31
+ create_repo,
32
+ )
33
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
34
+
35
+
36
+ _README = """---
37
+ title: Trackio Dashboard
38
+ emoji: 📊
39
+ colorFrom: pink
40
+ colorTo: gray
41
+ sdk: gradio
42
+ app_file: app.py
43
+ pinned: false
44
+ tags:
45
+ - trackio
46
+ ---
47
+
48
+ Embedded trackio dashboard for ml-intern runs.
49
+ """
50
+
51
+ _REQUIREMENTS = "trackio\n"
52
+ _APP_PY = "import trackio\ntrackio.show()\n"
53
+
54
+ # ml-intern brand mark surfaced inside the trackio dashboard. Trackio reads
55
+ # `TRACKIO_LOGO_LIGHT_URL` / `TRACKIO_LOGO_DARK_URL` from Space variables and
56
+ # renders them in place of its own logo. We point at the publicly-resolvable
57
+ # copy on the smolagents/ml-intern Space repo so any seeded dashboard inherits
58
+ # the ml-intern branding without each user having to host the asset.
59
+ _LOGO_URL = (
60
+ "https://huggingface.co/spaces/smolagents/ml-intern/"
61
+ "resolve/main/frontend/public/smolagents.webp"
62
+ )
63
+
64
+ _FILES = {
65
+ "README.md": _README,
66
+ "requirements.txt": _REQUIREMENTS,
67
+ "app.py": _APP_PY,
68
+ }
69
+
70
+
71
def _already_seeded(api: HfApi, space_id: str) -> bool:
    """Report whether the Space's ``app.py`` already calls ``trackio.show``.

    Downloads ``app.py`` from the Space repo and looks for the literal
    ``trackio.show`` text. A missing file, missing repo, or any OS-level
    error while downloading or reading counts as "not seeded", so the caller
    re-uploads the dashboard files.
    """
    try:
        local_app = api.hf_hub_download(
            repo_id=space_id, repo_type="space", filename="app.py"
        )
        with open(local_app, "r", encoding="utf-8") as handle:
            source = handle.read()
    except (EntryNotFoundError, RepositoryNotFoundError, OSError):
        return False
    return "trackio.show" in source
90
+
91
+
92
def _get_space_volumes(api: HfApi, space_id: str) -> list:
    """Return the volumes mounted on a Space, or ``[]`` when none are found.

    Reads ``volumes`` from ``get_space_runtime()`` first. When that is empty
    or absent, falls back to ``space_info().runtime.volumes``, since the
    runtime endpoint does not always report existing mounts.
    """
    runtime_volumes = getattr(api.get_space_runtime(space_id), "volumes", None)
    if runtime_volumes:
        return list(runtime_volumes)

    info_runtime = api.space_info(space_id).runtime
    info_volumes = getattr(info_runtime, "volumes", None) if info_runtime else None
    return list(info_volumes) if info_volumes else []
105
+
106
+
107
def _ensure_bucket_mounted(
    api: HfApi,
    space_id: str,
    bucket_id: str,
    hf_token: str,
    log: Optional[Callable[[str], None]] = None,
) -> None:
    """Wire the metrics bucket into the dashboard Space.

    Three steps, each skipped when already satisfied:

    1. Create the private bucket (``exist_ok=True``).
    2. Mount it at ``/data`` on the Space. Existing volumes are preserved,
       except bucket volumes that use the same source or the same mount path,
       which are replaced by the new mount.
    3. Set the ``TRACKIO_DIR``, ``TRACKIO_BUCKET_ID`` and logo URL Space
       variables to their desired values when they differ.
    """
    create_bucket(bucket_id, private=True, exist_ok=True, token=hf_token)

    def _is_bucket(vol) -> bool:
        return getattr(vol, "type", None) == "bucket"

    current = _get_space_volumes(api, space_id)

    target_present = False
    for vol in current:
        if (
            _is_bucket(vol)
            and getattr(vol, "source", None) == bucket_id
            and getattr(vol, "mount_path", None) == "/data"
        ):
            target_present = True
            break

    if not target_present:
        kept = []
        for vol in current:
            # Drop bucket volumes that would clash with the new mount.
            clashes = _is_bucket(vol) and (
                getattr(vol, "source", None) == bucket_id
                or getattr(vol, "mount_path", None) == "/data"
            )
            if not clashes:
                kept.append(vol)
        new_mount = Volume(type="bucket", source=bucket_id, mount_path="/data")
        api.set_space_volumes(space_id, kept + [new_mount])
        if log:
            log(f"mounted bucket {bucket_id} at /data on {space_id}")

    current_vars = api.get_space_variables(space_id)
    wanted = {
        "TRACKIO_DIR": "/data/trackio",
        "TRACKIO_BUCKET_ID": bucket_id,
        "TRACKIO_LOGO_LIGHT_URL": _LOGO_URL,
        "TRACKIO_LOGO_DARK_URL": _LOGO_URL,
    }
    for name, wanted_value in wanted.items():
        if getattr(current_vars.get(name), "value", None) == wanted_value:
            continue
        add_space_variable(space_id, name, wanted_value, token=hf_token)
156
+
157
+
158
def ensure_trackio_dashboard(
    space_id: str,
    hf_token: str,
    log: Optional[Callable[[str], None]] = None,
) -> bool:
    """Make sure *space_id* is fully wired as a trackio dashboard.

    1. Creates the gradio Space if needed and uploads the dashboard files
       (README without ``hf_oauth``, ``requirements.txt``, ``app.py`` calling
       ``trackio.show``) unless ``app.py`` already contains that call.
    2. Ensures the bucket ``<space_id>-bucket`` exists, is mounted at
       ``/data`` and that the trackio Space variables are set. This step
       always runs.

    Returns True when the dashboard files were uploaded in step 1, False
    when they were already in place.
    """

    def _say(message: str) -> None:
        if log:
            log(message)

    api = HfApi(token=hf_token)
    create_repo(
        repo_id=space_id,
        repo_type="space",
        space_sdk="gradio",
        exist_ok=True,
        token=hf_token,
    )

    needs_seed = not _already_seeded(api, space_id)
    if needs_seed:
        _say(f"seeding trackio dashboard files into {space_id}")
        for path_in_repo, content in _FILES.items():
            api.upload_file(
                path_or_fileobj=io.BytesIO(content.encode("utf-8")),
                path_in_repo=path_in_repo,
                repo_id=space_id,
                repo_type="space",
                commit_message=f"ml-intern: seed trackio dashboard ({path_in_repo})",
            )
    else:
        _say(f"trackio dashboard already seeded on {space_id}")

    _ensure_bucket_mounted(api, space_id, f"{space_id}-bucket", hf_token, log)

    _say(f"trackio dashboard ready: https://huggingface.co/spaces/{space_id}")
    return needs_seed
agent/tools/web_search_tool.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DuckDuckGo HTML web search tool.
2
+
3
+ This mirrors Claw Code's Rust WebSearch behavior: fetch DuckDuckGo's HTML
4
+ endpoint, extract result links, optionally filter domains, and return a
5
+ JSON payload the model can cite.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import asyncio
11
+ import html
12
+ import json
13
+ import os
14
+ import time
15
+ from dataclasses import dataclass
16
+ from html.parser import HTMLParser
17
+ from typing import Any
18
+ from urllib.parse import parse_qsl, parse_qs, urlencode, urlparse, urlunparse
19
+
20
+ import requests
21
+
22
+ DEFAULT_SEARCH_URL = "https://html.duckduckgo.com/html/"
23
+ WEB_SEARCH_BASE_URL_ENV = "CLAWD_WEB_SEARCH_BASE_URL"
24
+ USER_AGENT = "clawd-rust-tools/0.1"
25
+ REQUEST_TIMEOUT_SECONDS = 20
26
+ MAX_RESULTS = 8
27
+
28
+
29
@dataclass(frozen=True)
class SearchHit:
    """Immutable title/URL pair for a single search result."""

    title: str
    url: str

    def as_json(self) -> dict[str, str]:
        """Return the hit as a plain dict with ``title`` and ``url`` keys."""
        return dict(title=self.title, url=self.url)
36
+
37
+
38
+ class _AnchorParser(HTMLParser):
39
+ def __init__(self, *, require_result_class: bool) -> None:
40
+ super().__init__(convert_charrefs=True)
41
+ self.require_result_class = require_result_class
42
+ self.hits: list[tuple[str, str]] = []
43
+ self._active_href: str | None = None
44
+ self._active_text: list[str] = []
45
+
46
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
47
+ if tag.lower() != "a":
48
+ return
49
+ attr_map = {key.lower(): value or "" for key, value in attrs}
50
+ href = attr_map.get("href")
51
+ if not href:
52
+ return
53
+ if self.require_result_class and "result__a" not in attr_map.get("class", ""):
54
+ return
55
+ self._active_href = href
56
+ self._active_text = []
57
+
58
+ def handle_data(self, data: str) -> None:
59
+ if self._active_href is not None:
60
+ self._active_text.append(data)
61
+
62
+ def handle_entityref(self, name: str) -> None:
63
+ if self._active_href is not None:
64
+ self._active_text.append(f"&{name};")
65
+
66
+ def handle_charref(self, name: str) -> None:
67
+ if self._active_href is not None:
68
+ self._active_text.append(f"&#{name};")
69
+
70
+ def handle_endtag(self, tag: str) -> None:
71
+ if tag.lower() != "a" or self._active_href is None:
72
+ return
73
+ title = collapse_whitespace(html.unescape("".join(self._active_text))).strip()
74
+ self.hits.append((self._active_href, title))
75
+ self._active_href = None
76
+ self._active_text = []
77
+
78
+
79
def build_search_url(query: str) -> str:
    """Return the search endpoint URL with ``q=<query>`` appended.

    The base URL comes from the environment variable named by
    ``WEB_SEARCH_BASE_URL_ENV``, falling back to ``DEFAULT_SEARCH_URL``.
    Query parameters already present on the base URL are kept.

    Raises:
        ValueError: if the base URL is not http(s) or has no network location.
    """
    base = os.environ.get(WEB_SEARCH_BASE_URL_ENV, DEFAULT_SEARCH_URL)
    parts = urlparse(base)
    usable = parts.scheme in ("http", "https") and bool(parts.netloc)
    if not usable:
        raise ValueError(f"invalid search base URL: {base}")

    params = [*parse_qsl(parts.query, keep_blank_values=True), ("q", query)]
    return urlunparse(parts._replace(query=urlencode(params)))
88
+
89
+
90
def collapse_whitespace(value: str) -> str:
    """Replace every run of whitespace with one space and trim both ends."""
    words = value.split()
    return " ".join(words)
92
+
93
+
94
def decode_duckduckgo_redirect(url: str) -> str | None:
    """Resolve a link from a DuckDuckGo results page to an absolute URL.

    * Absolute http(s) URLs are returned entity-unescaped.
    * Protocol-relative (``//host/...``) and root-relative (``/path``) links
      are made absolute (the latter against ``duckduckgo.com``). If the result
      is a ``/l`` redirect carrying a ``uddg`` parameter, the unescaped target
      from that parameter is returned instead.
    * Anything else yields ``None``.
    """
    if url.startswith(("http://", "https://")):
        return html.unescape(url)

    if url.startswith("//"):
        absolute = "https:" + url
    elif url.startswith("/"):
        absolute = "https://duckduckgo.com" + url
    else:
        return None

    parts = urlparse(absolute)
    if parts.path in ("/l", "/l/"):
        targets = parse_qs(parts.query).get("uddg")
        if targets:
            return html.unescape(targets[0])
    return absolute
110
+
111
+
112
def _extract_links(search_html: str, *, require_result_class: bool) -> list[SearchHit]:
    """Parse anchors out of *search_html* and return them as search hits.

    Anchors with empty text are dropped, as are links that do not resolve to
    an absolute http(s) URL after DuckDuckGo redirect decoding.
    """
    anchor_parser = _AnchorParser(require_result_class=require_result_class)
    anchor_parser.feed(search_html)

    results: list[SearchHit] = []
    for href, text in anchor_parser.hits:
        if not text:
            continue
        resolved = decode_duckduckgo_redirect(href)
        if not resolved:
            continue
        if resolved.startswith("http://") or resolved.startswith("https://"):
            results.append(SearchHit(title=text, url=resolved))
    return results
126
+
127
+
128
def extract_search_hits(search_html: str) -> list[SearchHit]:
    """Return hits from anchors carrying the ``result__a`` class only."""
    only_result_anchors = True
    return _extract_links(search_html, require_result_class=only_result_anchors)
130
+
131
+
132
def extract_search_hits_from_generic_links(search_html: str) -> list[SearchHit]:
    """Return hits from every anchor on the page, regardless of its class."""
    only_result_anchors = False
    return _extract_links(search_html, require_result_class=only_result_anchors)
134
+
135
+
136
def normalize_domain_filter(domain: str) -> str:
    """Reduce a domain-filter entry to a bare lowercase hostname.

    Accepts a plain domain (``example.com``), a full URL
    (``https://example.com/path``), or a scheme-less URL with a port and/or
    path (``example.com:8080``, ``example.com/docs``). Leading dots and
    trailing slashes are stripped. When no hostname can be parsed, the
    trimmed input itself is cleaned up and returned, so the result may be
    an empty string.
    """
    trimmed = domain.strip()
    try:
        parsed = urlparse(trimmed)
        host = parsed.hostname if parsed.scheme else None
        if not host:
            # urlparse only recognises a network location after "//", so a
            # scheme-less entry such as "example.com/docs" must be re-parsed
            # with that prefix to separate the host from the path and port.
            host = urlparse(f"//{trimmed}").hostname
    except ValueError:
        # Malformed bracketed hosts make urlparse raise; fall back to the
        # raw text rather than failing the whole filter.
        host = None
    candidate = host or trimmed
    return candidate.strip().lstrip(".").rstrip("/").lower()
141
+
142
+
143
def host_matches_list(url: str, domains: list[str]) -> bool:
    """Return True when the host of *url* equals, or is a subdomain of, any entry.

    Each entry is normalized with ``normalize_domain_filter``; entries that
    normalize to an empty string are ignored. A URL without a hostname never
    matches.
    """
    hostname = urlparse(url).hostname
    if not hostname:
        return False
    host = hostname.lower()

    normalized_entries = (normalize_domain_filter(entry) for entry in domains)
    return any(
        bool(wanted) and (host == wanted or host.endswith(f".{wanted}"))
        for wanted in normalized_entries
    )
155
+
156
+
157
def dedupe_hits(hits: list[SearchHit]) -> list[SearchHit]:
    """Drop hits whose URL was already seen, keeping first occurrences in order."""
    first_by_url: dict[str, SearchHit] = {}
    for hit in hits:
        # setdefault leaves an existing entry untouched, so the earliest hit
        # for each URL wins and insertion order is preserved.
        first_by_url.setdefault(hit.url, hit)
    return list(first_by_url.values())
166
+
167
+
168
def execute_web_search(
    query: str,
    allowed_domains: list[str] | None = None,
    blocked_domains: list[str] | None = None,
    tool_use_id: str = "web_search_1",
) -> dict[str, Any]:
    """Run a blocking web search and return a JSON-serializable payload.

    Fetches the search page, extracts result anchors (falling back to every
    link on the page when no result-classed anchor is found), applies the
    optional allow/block domain filters, de-duplicates by URL and keeps at
    most ``MAX_RESULTS`` hits.

    Returns a dict with ``query``, ``results`` (a text summary followed by a
    ``{"tool_use_id", "content"}`` entry holding the hits) and
    ``durationSeconds``.
    """
    t0 = time.monotonic()
    search_url = build_search_url(query)
    response = requests.get(
        search_url,
        headers={"User-Agent": USER_AGENT},
        timeout=REQUEST_TIMEOUT_SECONDS,
        allow_redirects=True,
    )
    page = response.text

    found = extract_search_hits(page)
    if not found:
        final_url = response.url or search_url
        if urlparse(final_url).hostname:
            found = extract_search_hits_from_generic_links(page)

    def _permitted(hit: SearchHit) -> bool:
        # A filter passed as None is disabled; an empty allowlist rejects all.
        if allowed_domains is not None and not host_matches_list(
            hit.url, allowed_domains
        ):
            return False
        if blocked_domains is not None and host_matches_list(
            hit.url, blocked_domains
        ):
            return False
        return True

    top_hits = dedupe_hits([hit for hit in found if _permitted(hit)])[:MAX_RESULTS]

    if top_hits:
        rendered_hits = "\n".join(f"- [{hit.title}]({hit.url})" for hit in top_hits)
        summary = (
            f"Search results for {query!r}. Include a Sources section in the final answer.\n"
            f"{rendered_hits}"
        )
    else:
        summary = f"No web search results matched the query {query!r}."

    citation_block = {
        "tool_use_id": tool_use_id,
        "content": [hit.as_json() for hit in top_hits],
    }
    return {
        "query": query,
        "results": [summary, citation_block],
        "durationSeconds": time.monotonic() - t0,
    }
213
+
214
+
215
+ WEB_SEARCH_TOOL_SPEC = {
216
+ "name": "web_search",
217
+ "description": "Search the web for current information and return cited results.",
218
+ "parameters": {
219
+ "type": "object",
220
+ "properties": {
221
+ "query": {"type": "string", "minLength": 2},
222
+ "allowed_domains": {
223
+ "type": "array",
224
+ "items": {"type": "string"},
225
+ "description": "Optional allowlist of domains or URLs. Subdomains match.",
226
+ },
227
+ "blocked_domains": {
228
+ "type": "array",
229
+ "items": {"type": "string"},
230
+ "description": "Optional blocklist of domains or URLs. Subdomains match.",
231
+ },
232
+ },
233
+ "required": ["query"],
234
+ "additionalProperties": False,
235
+ },
236
+ }
237
+
238
+
239
+ def _optional_string_list(arguments: dict[str, Any], key: str) -> list[str] | None:
240
+ value = arguments.get(key)
241
+ if value is None:
242
+ return None
243
+ if not isinstance(value, list) or not all(isinstance(item, str) for item in value):
244
+ raise ValueError(f"{key} must be an array of strings")
245
+ return value
246
+
247
+
248
async def web_search_handler(
    arguments: dict[str, Any],
    session: Any = None,
    tool_call_id: str | None = None,
    **_kw: Any,
) -> tuple[str, bool]:
    """Handle a ``web_search`` tool call.

    Validates the query, runs the blocking search in a worker thread and
    returns ``(json_payload, True)`` on success. Validation problems and any
    exception raised while preparing or running the search are returned as
    ``(error_message, False)`` instead of being raised.
    """
    raw_query = arguments.get("query", "")
    if not isinstance(raw_query, str):
        return "Error: web_search requires a query string with at least 2 characters.", False

    query = raw_query.strip()
    if len(query) < 2:
        return "Error: web_search requires a query with at least 2 characters.", False

    try:
        # Domain-list validation happens inside the try so a malformed list
        # is reported like any other search failure.
        allowed = _optional_string_list(arguments, "allowed_domains")
        blocked = _optional_string_list(arguments, "blocked_domains")
        payload = await asyncio.to_thread(
            execute_web_search,
            query=query,
            allowed_domains=allowed,
            blocked_domains=blocked,
            tool_use_id=tool_call_id or "web_search_1",
        )
    except Exception as exc:
        return f"Error executing web search: {exc}", False

    return json.dumps(payload, indent=2), True
backend/dependencies.py CHANGED
@@ -12,6 +12,8 @@ from typing import Any
12
  import httpx
13
  from fastapi import HTTPException, Request, status
14
 
 
 
15
  from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
16
 
17
  logger = logging.getLogger(__name__)
@@ -157,9 +159,8 @@ async def get_current_user(request: Request) -> dict[str, Any]:
157
  return DEV_USER
158
 
159
  # Try Authorization header
160
- auth_header = request.headers.get("Authorization", "")
161
- if auth_header.startswith("Bearer "):
162
- token = auth_header[7:]
163
  user = await _extract_user_from_token(token)
164
  if user:
165
  return user
@@ -183,9 +184,9 @@ def _extract_token(request: Request) -> str | None:
183
 
184
  Mirrors the lookup order used by ``get_current_user``.
185
  """
186
- auth_header = request.headers.get("Authorization", "")
187
- if auth_header.startswith("Bearer "):
188
- return auth_header[7:]
189
  return request.cookies.get("hf_access_token")
190
 
191
 
@@ -202,4 +203,3 @@ async def require_huggingface_org_member(request: Request) -> bool:
202
  if not token:
203
  return False
204
  return await check_org_membership(token, HF_EMPLOYEE_ORG)
205
-
 
12
  import httpx
13
  from fastapi import HTTPException, Request, status
14
 
15
+ from agent.core.hf_tokens import bearer_token_from_header
16
+
17
  from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
18
 
19
  logger = logging.getLogger(__name__)
 
159
  return DEV_USER
160
 
161
  # Try Authorization header
162
+ token = bearer_token_from_header(request.headers.get("Authorization", ""))
163
+ if token:
 
164
  user = await _extract_user_from_token(token)
165
  if user:
166
  return user
 
184
 
185
  Mirrors the lookup order used by ``get_current_user``.
186
  """
187
+ token = bearer_token_from_header(request.headers.get("Authorization", ""))
188
+ if token:
189
+ return token
190
  return request.cookies.get("hf_access_token")
191
 
192
 
 
203
  if not token:
204
  return False
205
  return await check_org_membership(token, HF_EMPLOYEE_ORG)
 
backend/main.py CHANGED
@@ -6,14 +6,17 @@ from contextlib import asynccontextmanager
6
  from pathlib import Path
7
 
8
  from dotenv import load_dotenv
 
 
 
 
 
9
  from fastapi import FastAPI
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
  from routes.agent import router as agent_router
13
  from routes.auth import router as auth_router
14
-
15
- # Load .env from project root (parent directory)
16
- load_dotenv(Path(__file__).parent.parent / ".env")
17
 
18
  # Configure logging
19
  logging.basicConfig(
@@ -27,6 +30,7 @@ logger = logging.getLogger(__name__)
27
  async def lifespan(app: FastAPI):
28
  """Application lifespan handler."""
29
  logger.info("Starting HF Agent backend...")
 
30
  # Start in-process hourly KPI rollup. Replaces an external cron so the
31
  # rollup lives next to the data and reuses the Space's HF token.
32
  try:
@@ -34,7 +38,6 @@ async def lifespan(app: FastAPI):
34
  kpis_scheduler.start()
35
  except Exception as e:
36
  logger.warning("KPI scheduler failed to start: %s", e)
37
-
38
  yield
39
 
40
  logger.info("Shutting down HF Agent backend...")
@@ -47,7 +50,6 @@ async def lifespan(app: FastAPI):
47
  # Final-flush: save every still-active session so we don't lose traces on
48
  # server restart. Uploads are detached subprocesses — this is fast.
49
  try:
50
- from session_manager import session_manager
51
  for sid, agent_session in list(session_manager.sessions.items()):
52
  sess = agent_session.session
53
  if sess.config.save_sessions:
@@ -58,6 +60,7 @@ async def lifespan(app: FastAPI):
58
  logger.warning("Failed to flush session %s: %s", sid, e)
59
  except Exception as e:
60
  logger.warning("Lifespan final-flush skipped: %s", e)
 
61
 
62
 
63
  app = FastAPI(
 
6
  from pathlib import Path
7
 
8
  from dotenv import load_dotenv
9
+
10
+ # Load .env before importing routes/session_manager so persistence and quota
11
+ # modules see local Mongo settings during startup.
12
+ load_dotenv(Path(__file__).parent.parent / ".env")
13
+
14
  from fastapi import FastAPI
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from fastapi.staticfiles import StaticFiles
17
  from routes.agent import router as agent_router
18
  from routes.auth import router as auth_router
19
+ from session_manager import session_manager
 
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
 
30
  async def lifespan(app: FastAPI):
31
  """Application lifespan handler."""
32
  logger.info("Starting HF Agent backend...")
33
+ await session_manager.start()
34
  # Start in-process hourly KPI rollup. Replaces an external cron so the
35
  # rollup lives next to the data and reuses the Space's HF token.
36
  try:
 
38
  kpis_scheduler.start()
39
  except Exception as e:
40
  logger.warning("KPI scheduler failed to start: %s", e)
 
41
  yield
42
 
43
  logger.info("Shutting down HF Agent backend...")
 
50
  # Final-flush: save every still-active session so we don't lose traces on
51
  # server restart. Uploads are detached subprocesses — this is fast.
52
  try:
 
53
  for sid, agent_session in list(session_manager.sessions.items()):
54
  sess = agent_session.session
55
  if sess.config.save_sessions:
 
60
  logger.warning("Failed to flush session %s: %s", sid, e)
61
  except Exception as e:
62
  logger.warning("Lifespan final-flush skipped: %s", e)
63
+ await session_manager.close()
64
 
65
 
66
  app = FastAPI(
backend/models.py CHANGED
@@ -3,7 +3,7 @@
3
  from enum import Enum
4
  from typing import Any
5
 
6
- from pydantic import BaseModel
7
 
8
 
9
  class OpType(str, Enum):
@@ -87,6 +87,14 @@ class SessionInfo(BaseModel):
87
  user_id: str = "dev"
88
  pending_approval: list[PendingApprovalTool] | None = None
89
  model: str | None = None
 
 
 
 
 
 
 
 
90
 
91
 
92
  class HealthResponse(BaseModel):
 
3
  from enum import Enum
4
  from typing import Any
5
 
6
+ from pydantic import BaseModel, Field
7
 
8
 
9
  class OpType(str, Enum):
 
87
  user_id: str = "dev"
88
  pending_approval: list[PendingApprovalTool] | None = None
89
  model: str | None = None
90
+ title: str | None = None
91
+ notification_destinations: list[str] = Field(default_factory=list)
92
+
93
+
94
+ class SessionNotificationsRequest(BaseModel):
95
+ """Replace the session's auto-notification destinations."""
96
+
97
+ destinations: list[str]
98
 
99
 
100
  class HealthResponse(BaseModel):
backend/routes/agent.py CHANGED
@@ -24,6 +24,7 @@ from models import (
24
  HealthResponse,
25
  LLMHealthResponse,
26
  SessionInfo,
 
27
  SessionResponse,
28
  SubmitRequest,
29
  TruncateRequest,
@@ -33,6 +34,7 @@ from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, se
33
  import user_quotas
34
 
35
  from agent.core.hf_access import get_jobs_access
 
36
  from agent.core.llm_params import _resolve_llm_params
37
 
38
  logger = logging.getLogger(__name__)
@@ -118,9 +120,9 @@ async def _enforce_claude_quota(
118
  if not _is_anthropic_model(model_name):
119
  return
120
  user_id = user["user_id"]
121
- used = await user_quotas.get_claude_used_today(user_id)
122
  cap = user_quotas.daily_cap_for(user.get("plan"))
123
- if used >= cap:
 
124
  raise HTTPException(
125
  status_code=429,
126
  detail={
@@ -133,8 +135,8 @@ async def _enforce_claude_quota(
133
  ),
134
  },
135
  )
136
- await user_quotas.increment_claude(user_id)
137
  agent_session.claude_counted = True
 
138
 
139
 
140
  async def _enforce_jobs_access_for_approvals(
@@ -193,6 +195,9 @@ async def _enforce_jobs_access_for_approvals(
193
  "The selected jobs namespace is not one of your eligible paid organizations. "
194
  f"Allowed namespaces: {', '.join(access.paid_org_names)}"
195
  ),
 
 
 
196
  },
197
  )
198
  missing_namespace = [
@@ -236,13 +241,23 @@ async def _enforce_jobs_access_for_approvals(
236
  )
237
 
238
 
239
- def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
240
- """Verify the user has access to the given session. Raises 403 or 404."""
241
- info = session_manager.get_session_info(session_id)
242
- if not info:
 
 
 
 
 
 
 
 
 
243
  raise HTTPException(status_code=404, detail="Session not found")
244
- if not session_manager.verify_session_access(session_id, user["user_id"]):
245
  raise HTTPException(status_code=403, detail="Access denied to this session")
 
246
 
247
 
248
  @router.get("/health", response_model=HealthResponse)
@@ -332,10 +347,8 @@ async def generate_title(
332
  reasoning model — reasoning_effort=low keeps the reasoning budget small
333
  so the 60-token output budget isn't consumed before the title is written.
334
  """
335
- api_key = (
336
- os.environ.get("INFERENCE_TOKEN")
337
- or (user.get("hf_token") if isinstance(user, dict) else None)
338
- or os.environ.get("HF_TOKEN")
339
  )
340
  try:
341
  response = await acompletion(
@@ -366,11 +379,21 @@ async def generate_title(
366
  title = title.translate(_TITLE_STRIP_CHARS).strip()
367
  if len(title) > 50:
368
  title = title[:50].rstrip() + "…"
 
 
 
 
 
369
  return {"title": title}
370
  except Exception as e:
371
  logger.warning(f"Title generation failed: {e}")
372
  fallback = request.text.strip()
373
  title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
 
 
 
 
 
374
  return {"title": title}
375
 
376
 
@@ -391,14 +414,7 @@ async def create_session(
391
  Returns 503 if the server or user has reached the session limit.
392
  """
393
  # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var)
394
- hf_token = None
395
- auth_header = request.headers.get("Authorization", "")
396
- if auth_header.startswith("Bearer "):
397
- hf_token = auth_header[7:]
398
- if not hf_token:
399
- hf_token = request.cookies.get("hf_access_token")
400
- if not hf_token:
401
- hf_token = os.environ.get("HF_TOKEN")
402
 
403
  # Optional model override. Empty body falls back to the config default.
404
  model: str | None = None
@@ -444,14 +460,7 @@ async def restore_session_summary(
444
  if not isinstance(messages, list) or not messages:
445
  raise HTTPException(status_code=400, detail="Missing 'messages' array")
446
 
447
- hf_token = None
448
- auth_header = request.headers.get("Authorization", "")
449
- if auth_header.startswith("Bearer "):
450
- hf_token = auth_header[7:]
451
- if not hf_token:
452
- hf_token = request.cookies.get("hf_access_token")
453
- if not hf_token:
454
- hf_token = os.environ.get("HF_TOKEN")
455
 
456
  model = body.get("model")
457
  valid_ids = {m["id"] for m in AVAILABLE_MODELS}
@@ -488,7 +497,7 @@ async def get_session(
488
  session_id: str, user: dict = Depends(get_current_user)
489
  ) -> SessionInfo:
490
  """Get session information. Only accessible by the session owner."""
491
- _check_session_access(session_id, user)
492
  info = session_manager.get_session_info(session_id)
493
  return SessionInfo(**info)
494
 
@@ -509,7 +518,7 @@ async def set_session_model(
509
  Switching TO an Anthropic model requires HF org membership (PR #63);
510
  free-model switches are unrestricted.
511
  """
512
- _check_session_access(session_id, user)
513
  model_id = body.get("model")
514
  if not model_id:
515
  raise HTTPException(status_code=400, detail="Missing 'model' field")
@@ -517,10 +526,9 @@ async def set_session_model(
517
  if model_id not in valid_ids:
518
  raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
519
  await _require_hf_for_anthropic(request, model_id)
520
- agent_session = session_manager.sessions.get(session_id)
521
  if not agent_session:
522
  raise HTTPException(status_code=404, detail="Session not found")
523
- agent_session.session.update_model(model_id)
524
  logger.info(
525
  f"Session {session_id} model → {model_id} "
526
  f"(by {user.get('username', 'unknown')})"
@@ -528,6 +536,27 @@ async def set_session_model(
528
  return {"session_id": session_id, "model": model_id}
529
 
530
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  @router.get("/user/quota")
532
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
533
  """Return the user's plan tier and today's Claude-session quota state."""
@@ -545,14 +574,7 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
545
  @router.get("/user/jobs-access")
546
  async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
547
  """Return whether the current token can run HF Jobs and under which namespaces."""
548
- token = None
549
- auth_header = request.headers.get("Authorization", "")
550
- if auth_header.startswith("Bearer "):
551
- token = auth_header[7:]
552
- if not token:
553
- token = request.cookies.get("hf_access_token")
554
- if not token:
555
- token = os.environ.get("HF_TOKEN")
556
 
557
  access = await get_jobs_access(token or "")
558
  return {
@@ -566,7 +588,7 @@ async def get_jobs_access_info(request: Request, user: dict = Depends(get_curren
566
  @router.get("/sessions", response_model=list[SessionInfo])
567
  async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
568
  """List sessions belonging to the authenticated user."""
569
- sessions = session_manager.list_sessions(user_id=user["user_id"])
570
  return [SessionInfo(**s) for s in sessions]
571
 
572
 
@@ -575,7 +597,7 @@ async def delete_session(
575
  session_id: str, user: dict = Depends(get_current_user)
576
  ) -> dict:
577
  """Delete a session. Only accessible by the session owner."""
578
- _check_session_access(session_id, user)
579
  success = await session_manager.delete_session(session_id)
580
  if not success:
581
  raise HTTPException(status_code=404, detail="Session not found")
@@ -587,10 +609,8 @@ async def submit_input(
587
  request: SubmitRequest, user: dict = Depends(get_current_user)
588
  ) -> dict:
589
  """Submit user input to a session. Only accessible by the session owner."""
590
- _check_session_access(request.session_id, user)
591
- agent_session = session_manager.sessions.get(request.session_id)
592
- if agent_session is not None:
593
- await _enforce_claude_quota(user, agent_session)
594
  success = await session_manager.submit_user_input(request.session_id, request.text)
595
  if not success:
596
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -602,10 +622,7 @@ async def submit_approval(
602
  request: ApprovalRequest, user: dict = Depends(get_current_user)
603
  ) -> dict:
604
  """Submit tool approvals to a session. Only accessible by the session owner."""
605
- _check_session_access(request.session_id, user)
606
- agent_session = session_manager.sessions.get(request.session_id)
607
- if agent_session is None:
608
- raise HTTPException(status_code=404, detail="Session not found or inactive")
609
  approvals = [
610
  {
611
  "tool_call_id": a.tool_call_id,
@@ -630,9 +647,7 @@ async def chat_sse(
630
  user: dict = Depends(get_current_user),
631
  ) -> StreamingResponse:
632
  """SSE endpoint: submit input or approval, then stream events until turn ends."""
633
- _check_session_access(session_id, user)
634
-
635
- agent_session = session_manager.sessions.get(session_id)
636
  if not agent_session or not agent_session.is_active:
637
  raise HTTPException(status_code=404, detail="Session not found or inactive")
638
 
@@ -698,10 +713,7 @@ async def record_pro_click(
698
  user: dict = Depends(get_current_user),
699
  ) -> dict:
700
  """Record a click on a Pro upgrade CTA shown from inside a session."""
701
- _check_session_access(session_id, user)
702
- agent_session = session_manager.sessions.get(session_id)
703
- if not agent_session:
704
- raise HTTPException(status_code=404, detail="Session not found")
705
 
706
  from agent.core import telemetry
707
  await telemetry.record_pro_cta_click(
@@ -723,12 +735,53 @@ _TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted"
723
  _SSE_KEEPALIVE_SECONDS = 15
724
 
725
 
726
- def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  """Build a StreamingResponse that drains *event_queue* as SSE,
728
  sending keepalive comments every 15 s to prevent proxy timeouts."""
729
 
730
  async def event_generator():
731
  try:
 
 
 
 
 
 
 
 
 
732
  while True:
733
  try:
734
  msg = await asyncio.wait_for(
@@ -739,7 +792,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse:
739
  yield ": keepalive\n\n"
740
  continue
741
  event_type = msg.get("event_type", "")
742
- yield f"data: {json.dumps(msg)}\n\n"
743
  if event_type in _TERMINAL_EVENTS:
744
  break
745
  finally:
@@ -759,6 +812,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse:
759
  @router.get("/events/{session_id}")
760
  async def subscribe_events(
761
  session_id: str,
 
762
  user: dict = Depends(get_current_user),
763
  ) -> StreamingResponse:
764
  """Subscribe to events for a running session without submitting new input.
@@ -766,15 +820,21 @@ async def subscribe_events(
766
  Used by the frontend to re-attach after a connection drop (e.g. screen
767
  sleep). Returns 404 if the session isn't active or isn't processing.
768
  """
769
- _check_session_access(session_id, user)
770
-
771
- agent_session = session_manager.sessions.get(session_id)
772
  if not agent_session or not agent_session.is_active:
773
  raise HTTPException(status_code=404, detail="Session not found or inactive")
774
 
 
 
775
  broadcaster = agent_session.broadcaster
776
  sub_id, event_queue = broadcaster.subscribe()
777
- return _sse_response(broadcaster, event_queue, sub_id)
 
 
 
 
 
 
778
 
779
 
780
  @router.post("/interrupt/{session_id}")
@@ -782,7 +842,7 @@ async def interrupt_session(
782
  session_id: str, user: dict = Depends(get_current_user)
783
  ) -> dict:
784
  """Interrupt the current operation in a session."""
785
- _check_session_access(session_id, user)
786
  success = await session_manager.interrupt(session_id)
787
  if not success:
788
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -794,17 +854,16 @@ async def get_session_messages(
794
  session_id: str, user: dict = Depends(get_current_user)
795
  ) -> list[dict]:
796
  """Return the session's message history from memory."""
797
- _check_session_access(session_id, user)
798
- agent_session = session_manager.sessions.get(session_id)
799
  if not agent_session or not agent_session.is_active:
800
  raise HTTPException(status_code=404, detail="Session not found or inactive")
801
- return [msg.model_dump() for msg in agent_session.session.context_manager.items]
802
 
803
 
804
  @router.post("/undo/{session_id}")
805
  async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
806
  """Undo the last turn in a session."""
807
- _check_session_access(session_id, user)
808
  success = await session_manager.undo(session_id)
809
  if not success:
810
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -816,7 +875,7 @@ async def truncate_session(
816
  session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user)
817
  ) -> dict:
818
  """Truncate conversation to before a specific user message."""
819
- _check_session_access(session_id, user)
820
  success = await session_manager.truncate(session_id, body.user_message_index)
821
  if not success:
822
  raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
@@ -828,7 +887,7 @@ async def compact_session(
828
  session_id: str, user: dict = Depends(get_current_user)
829
  ) -> dict:
830
  """Compact the context in a session."""
831
- _check_session_access(session_id, user)
832
  success = await session_manager.compact(session_id)
833
  if not success:
834
  raise HTTPException(status_code=404, detail="Session not found or inactive")
@@ -840,13 +899,12 @@ async def shutdown_session(
840
  session_id: str, user: dict = Depends(get_current_user)
841
  ) -> dict:
842
  """Shutdown a session."""
843
- _check_session_access(session_id, user)
844
  success = await session_manager.shutdown_session(session_id)
845
  if not success:
846
  raise HTTPException(status_code=404, detail="Session not found or inactive")
847
  return {"status": "shutdown_requested", "session_id": session_id}
848
 
849
-
850
  @router.post("/feedback/{session_id}")
851
  async def submit_feedback(
852
  session_id: str,
@@ -859,10 +917,7 @@ async def submit_feedback(
859
  turn_index?: int, comment?: str, message_id?: str}
860
  Appended as a `feedback` event and saved with the session trajectory.
861
  """
862
- _check_session_access(session_id, user)
863
- agent_session = session_manager.sessions.get(session_id)
864
- if not agent_session:
865
- raise HTTPException(status_code=404, detail="Session not found")
866
 
867
  rating = body.get("rating")
868
  if rating not in {"up", "down", "outcome_success", "outcome_fail"}:
 
24
  HealthResponse,
25
  LLMHealthResponse,
26
  SessionInfo,
27
+ SessionNotificationsRequest,
28
  SessionResponse,
29
  SubmitRequest,
30
  TruncateRequest,
 
34
  import user_quotas
35
 
36
  from agent.core.hf_access import get_jobs_access
37
+ from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token
38
  from agent.core.llm_params import _resolve_llm_params
39
 
40
  logger = logging.getLogger(__name__)
 
120
  if not _is_anthropic_model(model_name):
121
  return
122
  user_id = user["user_id"]
 
123
  cap = user_quotas.daily_cap_for(user.get("plan"))
124
+ new_count = await user_quotas.try_increment_claude(user_id, cap)
125
+ if new_count is None:
126
  raise HTTPException(
127
  status_code=429,
128
  detail={
 
135
  ),
136
  },
137
  )
 
138
  agent_session.claude_counted = True
139
+ await session_manager.persist_session_snapshot(agent_session)
140
 
141
 
142
  async def _enforce_jobs_access_for_approvals(
 
195
  "The selected jobs namespace is not one of your eligible paid organizations. "
196
  f"Allowed namespaces: {', '.join(access.paid_org_names)}"
197
  ),
198
+ "plan": user.get("plan", "free"),
199
+ "tool_call_ids": invalid_namespace,
200
+ "eligible_namespaces": access.paid_org_names,
201
  },
202
  )
203
  missing_namespace = [
 
241
  )
242
 
243
 
244
async def _check_session_access(
    session_id: str,
    user: dict[str, Any],
    request: Request | None = None,
) -> AgentSession:
    """Load the session on demand and confirm the caller may use it.

    The HF token handed to the loader comes from *request* when one is
    supplied, otherwise from the ``hf_token`` entry of *user*.

    Raises:
        HTTPException: 404 when the session cannot be loaded; 403 when the
            caller is not ``"dev"`` and the session belongs to neither the
            caller nor ``"dev"``.
    """
    if request is None:
        hf_token = user.get("hf_token")
    else:
        hf_token = resolve_hf_request_token(request)

    caller_id = user["user_id"]
    agent_session = await session_manager.ensure_session_loaded(
        session_id,
        caller_id,
        hf_token=hf_token,
    )
    if not agent_session:
        raise HTTPException(status_code=404, detail="Session not found")

    # The "dev" caller is let through without inspecting the session owner.
    if not (caller_id == "dev" or agent_session.user_id in {caller_id, "dev"}):
        raise HTTPException(status_code=403, detail="Access denied to this session")
    return agent_session
261
 
262
 
263
  @router.get("/health", response_model=HealthResponse)
 
347
  reasoning model — reasoning_effort=low keeps the reasoning budget small
348
  so the 60-token output budget isn't consumed before the title is written.
349
  """
350
+ api_key = resolve_hf_router_token(
351
+ user.get("hf_token") if isinstance(user, dict) else None
 
 
352
  )
353
  try:
354
  response = await acompletion(
 
379
  title = title.translate(_TITLE_STRIP_CHARS).strip()
380
  if len(title) > 50:
381
  title = title[:50].rstrip() + "…"
382
+ try:
383
+ await _check_session_access(request.session_id, user)
384
+ await session_manager.update_session_title(request.session_id, title)
385
+ except Exception:
386
+ logger.debug("Skipping title persistence for missing session %s", request.session_id)
387
  return {"title": title}
388
  except Exception as e:
389
  logger.warning(f"Title generation failed: {e}")
390
  fallback = request.text.strip()
391
  title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
392
+ try:
393
+ await _check_session_access(request.session_id, user)
394
+ await session_manager.update_session_title(request.session_id, title)
395
+ except Exception:
396
+ logger.debug("Skipping fallback title persistence for missing session %s", request.session_id)
397
  return {"title": title}
398
 
399
 
 
414
  Returns 503 if the server or user has reached the session limit.
415
  """
416
  # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var)
417
+ hf_token = resolve_hf_request_token(request)
 
 
 
 
 
 
 
418
 
419
  # Optional model override. Empty body falls back to the config default.
420
  model: str | None = None
 
460
  if not isinstance(messages, list) or not messages:
461
  raise HTTPException(status_code=400, detail="Missing 'messages' array")
462
 
463
+ hf_token = resolve_hf_request_token(request)
 
 
 
 
 
 
 
464
 
465
  model = body.get("model")
466
  valid_ids = {m["id"] for m in AVAILABLE_MODELS}
 
497
  session_id: str, user: dict = Depends(get_current_user)
498
  ) -> SessionInfo:
499
  """Get session information. Only accessible by the session owner."""
500
+ await _check_session_access(session_id, user)
501
  info = session_manager.get_session_info(session_id)
502
  return SessionInfo(**info)
503
 
 
518
  Switching TO an Anthropic model requires HF org membership (PR #63);
519
  free-model switches are unrestricted.
520
  """
521
+ agent_session = await _check_session_access(session_id, user, request)
522
  model_id = body.get("model")
523
  if not model_id:
524
  raise HTTPException(status_code=400, detail="Missing 'model' field")
 
526
  if model_id not in valid_ids:
527
  raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
528
  await _require_hf_for_anthropic(request, model_id)
 
529
  if not agent_session:
530
  raise HTTPException(status_code=404, detail="Session not found")
531
+ await session_manager.update_session_model(session_id, model_id)
532
  logger.info(
533
  f"Session {session_id} model → {model_id} "
534
  f"(by {user.get('username', 'unknown')})"
 
536
  return {"session_id": session_id, "model": model_id}
537
 
538
 
539
@router.post("/session/{session_id}/notifications")
async def set_session_notifications(
    session_id: str,
    body: SessionNotificationsRequest,
    user: dict = Depends(get_current_user),
) -> dict:
    """Replace the session's auto-notification destinations.

    Access is checked first via ``_check_session_access``. The new
    destination list is then applied through the session manager and the
    session snapshot is persisted.

    Raises:
        HTTPException: 400 when the session manager rejects the destinations
            with a ``ValueError``; the original error is kept as the cause.

    Returns:
        A dict with ``session_id`` and the ``notification_destinations``
        value returned by the session manager.
    """
    agent_session = await _check_session_access(session_id, user)
    try:
        destinations = session_manager.set_notification_destinations(
            session_id, body.destinations
        )
    except ValueError as e:
        # Chain explicitly so the validation error stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    await session_manager.persist_session_snapshot(agent_session)
    return {
        "session_id": session_id,
        "notification_destinations": destinations,
    }
558
+
559
+
560
  @router.get("/user/quota")
561
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
562
  """Return the user's plan tier and today's Claude-session quota state."""
 
574
  @router.get("/user/jobs-access")
575
  async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
576
  """Return whether the current token can run HF Jobs and under which namespaces."""
577
+ token = resolve_hf_request_token(request)
 
 
 
 
 
 
 
578
 
579
  access = await get_jobs_access(token or "")
580
  return {
 
588
  @router.get("/sessions", response_model=list[SessionInfo])
589
  async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
590
  """List sessions belonging to the authenticated user."""
591
+ sessions = await session_manager.list_sessions(user_id=user["user_id"])
592
  return [SessionInfo(**s) for s in sessions]
593
 
594
 
 
597
  session_id: str, user: dict = Depends(get_current_user)
598
  ) -> dict:
599
  """Delete a session. Only accessible by the session owner."""
600
+ await _check_session_access(session_id, user)
601
  success = await session_manager.delete_session(session_id)
602
  if not success:
603
  raise HTTPException(status_code=404, detail="Session not found")
 
609
  request: SubmitRequest, user: dict = Depends(get_current_user)
610
  ) -> dict:
611
  """Submit user input to a session. Only accessible by the session owner."""
612
+ agent_session = await _check_session_access(request.session_id, user)
613
+ await _enforce_claude_quota(user, agent_session)
 
 
614
  success = await session_manager.submit_user_input(request.session_id, request.text)
615
  if not success:
616
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
622
  request: ApprovalRequest, user: dict = Depends(get_current_user)
623
  ) -> dict:
624
  """Submit tool approvals to a session. Only accessible by the session owner."""
625
+ agent_session = await _check_session_access(request.session_id, user)
 
 
 
626
  approvals = [
627
  {
628
  "tool_call_id": a.tool_call_id,
 
647
  user: dict = Depends(get_current_user),
648
  ) -> StreamingResponse:
649
  """SSE endpoint: submit input or approval, then stream events until turn ends."""
650
+ agent_session = await _check_session_access(session_id, user, request)
 
 
651
  if not agent_session or not agent_session.is_active:
652
  raise HTTPException(status_code=404, detail="Session not found or inactive")
653
 
 
713
  user: dict = Depends(get_current_user),
714
  ) -> dict:
715
  """Record a click on a Pro upgrade CTA shown from inside a session."""
716
+ agent_session = await _check_session_access(session_id, user)
 
 
 
717
 
718
  from agent.core import telemetry
719
  await telemetry.record_pro_cta_click(
 
735
  _SSE_KEEPALIVE_SECONDS = 15
736
 
737
 
738
def _last_event_seq(request: Request) -> int:
    """Return the event sequence number the client has already seen.

    The ``last-event-id`` header is tried first, then the ``after`` query
    parameter. The first value that parses as an integer wins and is clamped
    to be non-negative. A missing, empty, or malformed value in one source no
    longer hides a usable value in the other. Returns 0 when neither source
    yields an integer.
    """
    candidates = (
        request.headers.get("last-event-id"),
        request.query_params.get("after"),
    )
    for raw in candidates:
        if not raw:
            continue
        try:
            return max(0, int(raw))
        except (TypeError, ValueError):
            # Malformed value: fall through to the next source.
            continue
    return 0
744
+
745
+
746
def _format_sse(msg: dict[str, Any]) -> str:
    """Render *msg* as one Server-Sent Events frame.

    The JSON payload always carries ``event_type`` and ``data`` (an empty
    dict when ``data`` is missing or falsy). When the message has a ``seq``,
    it is added to the payload and also emitted as the frame's ``id:`` line.
    """
    payload: dict[str, Any] = {
        "event_type": msg.get("event_type"),
        "data": msg.get("data") or {},
    }
    sequence = msg.get("seq")
    if sequence is None:
        return f"data: {json.dumps(payload)}\n\n"

    payload["seq"] = sequence
    return f"id: {sequence}\ndata: {json.dumps(payload)}\n\n"
753
+
754
+
755
def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]:
    """Project a stored event document onto the message shape used for SSE.

    Only ``event_type``, ``data`` and ``seq`` are kept; a missing or falsy
    ``data`` becomes an empty dict.
    """
    msg = {field: doc.get(field) for field in ("event_type", "data", "seq")}
    msg["data"] = msg["data"] or {}
    return msg
761
+
762
+
763
+ def _sse_response(
764
+ broadcaster,
765
+ event_queue,
766
+ sub_id,
767
+ *,
768
+ replay_events: list[dict[str, Any]] | None = None,
769
+ after_seq: int = 0,
770
+ ) -> StreamingResponse:
771
  """Build a StreamingResponse that drains *event_queue* as SSE,
772
  sending keepalive comments every 15 s to prevent proxy timeouts."""
773
 
774
  async def event_generator():
775
  try:
776
+ for doc in replay_events or []:
777
+ msg = _event_doc_to_msg(doc)
778
+ seq = msg.get("seq")
779
+ if isinstance(seq, int) and seq <= after_seq:
780
+ continue
781
+ yield _format_sse(msg)
782
+ if msg.get("event_type", "") in _TERMINAL_EVENTS:
783
+ return
784
+
785
  while True:
786
  try:
787
  msg = await asyncio.wait_for(
 
792
  yield ": keepalive\n\n"
793
  continue
794
  event_type = msg.get("event_type", "")
795
+ yield _format_sse(msg)
796
  if event_type in _TERMINAL_EVENTS:
797
  break
798
  finally:
 
812
  @router.get("/events/{session_id}")
813
  async def subscribe_events(
814
  session_id: str,
815
+ request: Request,
816
  user: dict = Depends(get_current_user),
817
  ) -> StreamingResponse:
818
  """Subscribe to events for a running session without submitting new input.
 
820
  Used by the frontend to re-attach after a connection drop (e.g. screen
821
  sleep). Returns 404 if the session isn't active or isn't processing.
822
  """
823
+ agent_session = await _check_session_access(session_id, user, request)
 
 
824
  if not agent_session or not agent_session.is_active:
825
  raise HTTPException(status_code=404, detail="Session not found or inactive")
826
 
827
+ after_seq = _last_event_seq(request)
828
+ replay_events = await session_manager._store().load_events_after(session_id, after_seq)
829
  broadcaster = agent_session.broadcaster
830
  sub_id, event_queue = broadcaster.subscribe()
831
+ return _sse_response(
832
+ broadcaster,
833
+ event_queue,
834
+ sub_id,
835
+ replay_events=replay_events,
836
+ after_seq=after_seq,
837
+ )
838
 
839
 
840
  @router.post("/interrupt/{session_id}")
 
842
  session_id: str, user: dict = Depends(get_current_user)
843
  ) -> dict:
844
  """Interrupt the current operation in a session."""
845
+ await _check_session_access(session_id, user)
846
  success = await session_manager.interrupt(session_id)
847
  if not success:
848
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
854
  session_id: str, user: dict = Depends(get_current_user)
855
  ) -> list[dict]:
856
  """Return the session's message history from memory."""
857
+ agent_session = await _check_session_access(session_id, user)
 
858
  if not agent_session or not agent_session.is_active:
859
  raise HTTPException(status_code=404, detail="Session not found or inactive")
860
+ return [msg.model_dump(mode="json") for msg in agent_session.session.context_manager.items]
861
 
862
 
863
  @router.post("/undo/{session_id}")
864
  async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
865
  """Undo the last turn in a session."""
866
+ await _check_session_access(session_id, user)
867
  success = await session_manager.undo(session_id)
868
  if not success:
869
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
875
  session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user)
876
  ) -> dict:
877
  """Truncate conversation to before a specific user message."""
878
+ await _check_session_access(session_id, user)
879
  success = await session_manager.truncate(session_id, body.user_message_index)
880
  if not success:
881
  raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
 
887
  session_id: str, user: dict = Depends(get_current_user)
888
  ) -> dict:
889
  """Compact the context in a session."""
890
+ await _check_session_access(session_id, user)
891
  success = await session_manager.compact(session_id)
892
  if not success:
893
  raise HTTPException(status_code=404, detail="Session not found or inactive")
 
899
  session_id: str, user: dict = Depends(get_current_user)
900
  ) -> dict:
901
  """Shutdown a session."""
902
+ await _check_session_access(session_id, user)
903
  success = await session_manager.shutdown_session(session_id)
904
  if not success:
905
  raise HTTPException(status_code=404, detail="Session not found or inactive")
906
  return {"status": "shutdown_requested", "session_id": session_id}
907
 
 
908
  @router.post("/feedback/{session_id}")
909
  async def submit_feedback(
910
  session_id: str,
 
917
  turn_index?: int, comment?: str, message_id?: str}
918
  Appended as a `feedback` event and saved with the session trajectory.
919
  """
920
+ agent_session = await _check_session_access(session_id, user)
 
 
 
921
 
922
  rating = body.get("rating")
923
  if rating not in {"up", "down", "outcome_success", "outcome_fail"}:
backend/session_manager.py CHANGED
@@ -1,6 +1,7 @@
1
  """Session manager for handling multiple concurrent agent sessions."""
2
 
3
  import asyncio
 
4
  import logging
5
  import uuid
6
  from dataclasses import dataclass, field
@@ -10,7 +11,9 @@ from typing import Any, Optional
10
 
11
  from agent.config import load_config
12
  from agent.core.agent_loop import process_submission
 
13
  from agent.core.session import Event, OpType, Session
 
14
  from agent.core.tools import ToolRouter
15
 
16
  # Get project root (parent of backend directory)
@@ -41,9 +44,8 @@ logger = logging.getLogger(__name__)
41
  class EventBroadcaster:
42
  """Reads from the agent's event queue and fans out to SSE subscribers.
43
 
44
- Events that arrive when no subscribers are listening are discarded.
45
- With SSE each turn is a separate request, so there is no reconnect
46
- scenario that would need buffered replay.
47
  """
48
 
49
  def __init__(self, event_queue: asyncio.Queue):
@@ -67,7 +69,7 @@ class EventBroadcaster:
67
  while True:
68
  try:
69
  event: Event = await self._source.get()
70
- msg = {"event_type": event.event_type, "data": event.data}
71
  for q in self._subscribers.values():
72
  await q.put(msg)
73
  except asyncio.CancelledError:
@@ -91,6 +93,7 @@ class AgentSession:
91
  is_active: bool = True
92
  is_processing: bool = False # True while a submission is being executed
93
  broadcaster: Any = None
 
94
  # True once this session has been counted against the user's daily
95
  # Claude quota. Guards double-counting when the user re-selects an
96
  # Anthropic model mid-session.
@@ -119,8 +122,27 @@ class SessionManager:
119
 
120
  def __init__(self, config_path: str | None = None) -> None:
121
  self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
 
122
  self.sessions: dict[str, AgentSession] = {}
123
  self._lock = asyncio.Lock()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def _count_user_sessions(self, user_id: str) -> int:
126
  """Count active sessions owned by a specific user."""
@@ -130,6 +152,314 @@ class SessionManager:
130
  if s.user_id == user_id and s.is_active
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  async def create_session(
134
  self,
135
  user_id: str = "dev",
@@ -178,27 +508,14 @@ class SessionManager:
178
  event_queue: asyncio.Queue = asyncio.Queue()
179
 
180
  # Run blocking constructors in a thread to keep the event loop responsive.
181
- # Without this, Session.__init__ ContextManager → litellm.get_max_tokens()
182
- # blocks all HTTP/SSE handling.
183
- import time as _time
184
-
185
- def _create_session_sync():
186
- t0 = _time.monotonic()
187
- tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token)
188
- # Deep-copy config so each session's model switches independently —
189
- # tab A picking GLM doesn't flip tab B off Claude.
190
- session_config = self.config.model_copy(deep=True)
191
- if model:
192
- session_config.model_name = model
193
- session = Session(
194
- event_queue, config=session_config, tool_router=tool_router,
195
- hf_token=hf_token,
196
- )
197
- t1 = _time.monotonic()
198
- logger.info(f"Session initialized in {t1 - t0:.2f}s")
199
- return tool_router, session
200
-
201
- tool_router, session = await asyncio.to_thread(_create_session_sync)
202
 
203
  # Create wrapper
204
  agent_session = AgentSession(
@@ -210,14 +527,12 @@ class SessionManager:
210
  hf_token=hf_token,
211
  )
212
 
213
- async with self._lock:
214
- self.sessions[session_id] = agent_session
215
-
216
- # Start the agent loop task
217
- task = asyncio.create_task(
218
- self._run_session(session_id, submission_queue, event_queue, tool_router)
219
  )
220
- agent_session.task = task
221
 
222
  logger.info(f"Created session {session_id} for user {user_id}")
223
  return session_id
@@ -283,21 +598,38 @@ class SessionManager:
283
  ),
284
  )
285
  session.context_manager.items.append(seed)
 
286
  return len(parsed)
287
 
288
  @staticmethod
289
  async def _cleanup_sandbox(session: Session) -> None:
290
- """Delete the sandbox Space if one was created for this session."""
 
 
 
 
 
291
  sandbox = getattr(session, "sandbox", None)
292
- if sandbox and getattr(sandbox, "_owns_space", False):
293
- space_id = getattr(sandbox, "space_id", None)
 
 
 
 
294
  try:
295
- logger.info(f"Deleting sandbox {space_id}...")
296
  await asyncio.to_thread(sandbox.delete)
297
  from agent.core import telemetry
298
  await telemetry.record_sandbox_destroy(session, sandbox)
 
299
  except Exception as e:
300
- logger.warning(f"Failed to delete sandbox {space_id}: {e}")
 
 
 
 
 
 
301
 
302
  async def _run_session(
303
  self,
@@ -337,6 +669,7 @@ class SessionManager:
337
  should_continue = await process_submission(session, submission)
338
  finally:
339
  agent_session.is_processing = False
 
340
  if not should_continue:
341
  break
342
  except asyncio.TimeoutError:
@@ -371,6 +704,11 @@ class SessionManager:
371
  async with self._lock:
372
  if session_id in self.sessions:
373
  self.sessions[session_id].is_active = False
 
 
 
 
 
374
 
375
  logger.info(f"Session {session_id} ended")
376
 
@@ -420,7 +758,10 @@ class SessionManager:
420
  agent_session = self.sessions.get(session_id)
421
  if not agent_session or not agent_session.is_active:
422
  return False
423
- return agent_session.session.context_manager.truncate_to_user_message(user_message_index)
 
 
 
424
 
425
  async def compact(self, session_id: str) -> bool:
426
  """Compact context in a session."""
@@ -445,12 +786,15 @@ class SessionManager:
445
  return success
446
 
447
  async def delete_session(self, session_id: str) -> bool:
448
- """Delete a session entirely."""
449
  async with self._lock:
450
  agent_session = self.sessions.pop(session_id, None)
451
 
452
  if not agent_session:
453
- return False
 
 
 
454
 
455
  # Clean up sandbox Space before cancelling the task
456
  await self._cleanup_sandbox(agent_session.session)
@@ -465,6 +809,21 @@ class SessionManager:
465
 
466
  return True
467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  def get_session_owner(self, session_id: str) -> str | None:
469
  """Get the user_id that owns a session, or None if session doesn't exist."""
470
  agent_session = self.sessions.get(session_id)
@@ -492,22 +851,7 @@ class SessionManager:
492
  if not agent_session:
493
  return None
494
 
495
- # Extract pending approval tools if any
496
- pending_approval = None
497
- pa = agent_session.session.pending_approval
498
- if pa and pa.get("tool_calls"):
499
- pending_approval = []
500
- for tc in pa["tool_calls"]:
501
- import json
502
- try:
503
- args = json.loads(tc.function.arguments)
504
- except (json.JSONDecodeError, AttributeError):
505
- args = {}
506
- pending_approval.append({
507
- "tool": tc.function.name,
508
- "tool_call_id": tc.id,
509
- "arguments": args,
510
- })
511
 
512
  return {
513
  "session_id": session_id,
@@ -518,16 +862,80 @@ class SessionManager:
518
  "user_id": agent_session.user_id,
519
  "pending_approval": pending_approval,
520
  "model": agent_session.session.config.model_name,
 
 
 
 
521
  }
522
 
523
- def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  """List sessions, optionally filtered by user.
525
 
526
  Args:
527
  user_id: If provided, only return sessions owned by this user.
528
  If "dev", return all sessions (dev mode).
529
  """
530
- results = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  for sid in self.sessions:
532
  info = self.get_session_info(sid)
533
  if not info:
 
1
  """Session manager for handling multiple concurrent agent sessions."""
2
 
3
  import asyncio
4
+ import json
5
  import logging
6
  import uuid
7
  from dataclasses import dataclass, field
 
11
 
12
  from agent.config import load_config
13
  from agent.core.agent_loop import process_submission
14
+ from agent.messaging.gateway import NotificationGateway
15
  from agent.core.session import Event, OpType, Session
16
+ from agent.core.session_persistence import get_session_store
17
  from agent.core.tools import ToolRouter
18
 
19
  # Get project root (parent of backend directory)
 
44
  class EventBroadcaster:
45
  """Reads from the agent's event queue and fans out to SSE subscribers.
46
 
47
+ Events that arrive when no subscribers are listening are discarded by
48
+ this in-memory fanout. Durable replay is handled by session_persistence.
 
49
  """
50
 
51
  def __init__(self, event_queue: asyncio.Queue):
 
69
  while True:
70
  try:
71
  event: Event = await self._source.get()
72
+ msg = {"event_type": event.event_type, "data": event.data, "seq": event.seq}
73
  for q in self._subscribers.values():
74
  await q.put(msg)
75
  except asyncio.CancelledError:
 
93
  is_active: bool = True
94
  is_processing: bool = False # True while a submission is being executed
95
  broadcaster: Any = None
96
+ title: str | None = None
97
  # True once this session has been counted against the user's daily
98
  # Claude quota. Guards double-counting when the user re-selects an
99
  # Anthropic model mid-session.
 
122
 
123
  def __init__(self, config_path: str | None = None) -> None:
124
  self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
125
+ self.messaging_gateway = NotificationGateway(self.config.messaging)
126
  self.sessions: dict[str, AgentSession] = {}
127
  self._lock = asyncio.Lock()
128
+ self.persistence_store = None
129
+
130
async def start(self) -> None:
    """Bring up the shared background resources.

    Binds the session store returned by ``get_session_store()``,
    initialises it, and then starts the messaging gateway.
    """
    store = get_session_store()
    self.persistence_store = store
    await store.init()
    await self.messaging_gateway.start()
135
+
136
async def close(self) -> None:
    """Flush and close shared background resources.

    The persistence store is closed even when shutting down the messaging
    gateway raises, so a gateway failure cannot leave the store open. The
    gateway's exception still propagates to the caller afterwards.
    """
    try:
        await self.messaging_gateway.close()
    finally:
        # Runs on both the success and the failure path of the gateway close.
        if self.persistence_store is not None:
            await self.persistence_store.close()
141
+
142
def _store(self):
    """Return the persistence store, binding the shared one on first use."""
    store = self.persistence_store
    if store is None:
        store = self.persistence_store = get_session_store()
    return store
146
 
147
  def _count_user_sessions(self, user_id: str) -> int:
148
  """Count active sessions owned by a specific user."""
 
152
  if s.user_id == user_id and s.is_active
153
  )
154
 
155
def _create_session_sync(
    self,
    *,
    session_id: str,
    user_id: str,
    hf_token: str | None,
    model: str | None,
    event_queue: asyncio.Queue,
    notification_destinations: list[str] | None = None,
) -> tuple[ToolRouter, Session]:
    """Construct the ToolRouter and Session for one agent session.

    Callers run this through ``asyncio.to_thread`` because the constructors
    block. Returns the ``(tool_router, session)`` pair.
    """
    from time import monotonic

    started = monotonic()
    router = ToolRouter(self.config.mcpServers, hf_token=hf_token)

    # Every session works on its own deep copy of the config, so switching
    # the model in one session does not affect any other session.
    per_session_config = self.config.model_copy(deep=True)
    if model:
        per_session_config.model_name = model

    destinations = notification_destinations or []
    new_session = Session(
        event_queue=event_queue,
        config=per_session_config,
        tool_router=router,
        hf_token=hf_token,
        user_id=user_id,
        notification_gateway=self.messaging_gateway,
        notification_destinations=destinations,
        session_id=session_id,
        persistence_store=self._store(),
    )

    elapsed = monotonic() - started
    logger.info("Session initialized in %.2fs", elapsed)
    return router, new_session
189
+
190
def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
    """Dump every context item of *session* via ``model_dump(mode="json")``."""
    dumped: list[dict[str, Any]] = []
    for item in session.context_manager.items:
        dumped.append(item.model_dump(mode="json"))
    return dumped
195
+
196
def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
    """Turn the session's pending tool calls into plain dicts.

    Objects exposing ``model_dump`` are dumped in JSON mode, dicts are passed
    through unchanged, and anything else is skipped.
    """
    calls = (session.pending_approval or {}).get("tool_calls") or []
    payloads: list[dict[str, Any]] = []
    for call in calls:
        if hasattr(call, "model_dump"):
            payload = call.model_dump(mode="json")
        elif isinstance(call, dict):
            payload = call
        else:
            continue
        payloads.append(payload)
    return payloads
206
+
207
+ @staticmethod
208
+ def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None:
209
+ pending = session.pending_approval or {}
210
+ tool_calls = pending.get("tool_calls") or []
211
+ if not tool_calls:
212
+ return None
213
+ result: list[dict[str, Any]] = []
214
+ for tc in tool_calls:
215
+ try:
216
+ args = json.loads(tc.function.arguments)
217
+ except (json.JSONDecodeError, AttributeError, TypeError):
218
+ args = {}
219
+ result.append(
220
+ {
221
+ "tool": getattr(tc.function, "name", None),
222
+ "tool_call_id": getattr(tc, "id", None),
223
+ "arguments": args,
224
+ }
225
+ )
226
+ return result
227
+
228
def _restore_pending_approval(
    self, session: Session, pending_approval: list[dict[str, Any]] | None
) -> None:
    """Rebuild ``session.pending_approval`` from persisted documents.

    Each document is either a dumped tool call (has a ``"function"`` key) or
    the flattened ``tool`` / ``tool_call_id`` / ``arguments`` shape. Documents
    that fail to convert are logged and dropped. The attribute ends up as
    ``{"tool_calls": [...]}``, or ``None`` when nothing could be restored.
    """
    if not pending_approval:
        session.pending_approval = None
        return

    from litellm import ChatCompletionMessageToolCall as ToolCall

    def _rebuild(doc):
        # Dumped tool calls round-trip directly; the flattened shape has to be
        # re-wrapped with its arguments encoded back to a JSON string.
        if "function" in doc:
            return ToolCall(**doc)
        return ToolCall(
            id=doc["tool_call_id"],
            type="function",
            function={
                "name": doc["tool"],
                "arguments": json.dumps(doc.get("arguments") or {}),
            },
        )

    rebuilt = []
    for raw in pending_approval:
        try:
            rebuilt.append(_rebuild(raw))
        except Exception as e:
            logger.warning("Dropping malformed pending approval: %s", e)

    if rebuilt:
        session.pending_approval = {"tool_calls": rebuilt}
    else:
        session.pending_approval = None
255
+
256
+ @staticmethod
257
+ def _pending_docs_for_api(
258
+ pending_approval: list[dict[str, Any]] | None,
259
+ ) -> list[dict[str, Any]] | None:
260
+ if not pending_approval:
261
+ return None
262
+ result: list[dict[str, Any]] = []
263
+ for raw in pending_approval:
264
+ if "function" in raw:
265
+ function = raw.get("function") or {}
266
+ try:
267
+ args = json.loads(function.get("arguments") or "{}")
268
+ except (json.JSONDecodeError, TypeError):
269
+ args = {}
270
+ result.append(
271
+ {
272
+ "tool": function.get("name"),
273
+ "tool_call_id": raw.get("id"),
274
+ "arguments": args,
275
+ }
276
+ )
277
+ elif {"tool", "tool_call_id"}.issubset(raw):
278
+ result.append(
279
+ {
280
+ "tool": raw.get("tool"),
281
+ "tool_call_id": raw.get("tool_call_id"),
282
+ "arguments": raw.get("arguments") or {},
283
+ }
284
+ )
285
+ return result or None
286
+
287
+ @staticmethod
288
+ def _runtime_state(agent_session: AgentSession) -> str:
289
+ if agent_session.session.pending_approval:
290
+ return "waiting_approval"
291
+ if agent_session.is_processing:
292
+ return "processing"
293
+ if not agent_session.is_active:
294
+ return "ended"
295
+ return "idle"
296
+
297
async def _start_agent_session(
    self,
    *,
    agent_session: AgentSession,
    event_queue: asyncio.Queue,
    tool_router: ToolRouter,
) -> AgentSession:
    """Register *agent_session* and launch its ``_run_session`` task.

    If a session with the same id is already registered, that one is
    returned and no task is created for the new wrapper.
    """
    sid = agent_session.session_id
    async with self._lock:
        already_registered = self.sessions.get(sid)
        if already_registered:
            return already_registered
        self.sessions[sid] = agent_session

    runner = self._run_session(
        sid,
        agent_session.submission_queue,
        event_queue,
        tool_router,
    )
    agent_session.task = asyncio.create_task(runner)
    return agent_session
320
+
321
+ @staticmethod
322
+ def _can_access_session(agent_session: AgentSession, user_id: str) -> bool:
323
+ return (
324
+ user_id == "dev"
325
+ or agent_session.user_id == "dev"
326
+ or agent_session.user_id == user_id
327
+ )
328
+
329
+ @staticmethod
330
+ def _update_hf_token(agent_session: AgentSession, hf_token: str | None) -> None:
331
+ if not hf_token:
332
+ return
333
+ agent_session.hf_token = hf_token
334
+ agent_session.session.hf_token = hf_token
335
+
336
async def persist_session_snapshot(
    self,
    agent_session: AgentSession,
    *,
    runtime_state: str | None = None,
    status: str = "active",
) -> None:
    """Write the session's current runtime context to the store.

    Does nothing when the store is not ``enabled``. When *runtime_state* is
    omitted it is derived with ``_runtime_state``. Failures while building or
    saving the snapshot are logged as warnings and never raised.
    """
    store = self._store()
    if not getattr(store, "enabled", False):
        return

    try:
        live = agent_session.session
        # Building the payload stays inside the try block so that a
        # serialisation error is logged like a save error.
        snapshot = {
            "session_id": agent_session.session_id,
            "user_id": agent_session.user_id,
            "model": live.config.model_name,
            "title": agent_session.title,
            "messages": self._serialize_messages(live),
            "runtime_state": runtime_state or self._runtime_state(agent_session),
            "status": status,
            "turn_count": live.turn_count,
            "pending_approval": self._serialize_pending_approval(live),
            "claude_counted": agent_session.claude_counted,
            "created_at": agent_session.created_at,
            "notification_destinations": list(live.notification_destinations),
        }
        await store.save_snapshot(**snapshot)
    except Exception as e:
        logger.warning(
            "Failed to persist snapshot for %s: %s",
            agent_session.session_id,
            e,
        )
370
+
371
async def ensure_session_loaded(
    self,
    session_id: str,
    user_id: str,
    hf_token: str | None = None,
) -> AgentSession | None:
    """Return a live runtime session, restoring it from the store if needed.

    A session already held in memory is returned (with its HF token
    refreshed) when *user_id* may access it. Otherwise the persisted session
    is loaded, rebuilt, registered and started. Returns ``None`` when the
    session does not exist or *user_id* is not allowed to use it.
    """
    _absent = object()

    async def _lookup_live():
        # Yields the live session, None when access is denied, or the
        # _absent marker when nothing is registered under session_id.
        async with self._lock:
            live = self.sessions.get(session_id)
            if not live:
                return _absent
            if not self._can_access_session(live, user_id):
                return None
            self._update_hf_token(live, hf_token)
            return live

    hit = await _lookup_live()
    if hit is not _absent:
        return hit

    loaded = await self._store().load_session(session_id)
    if not loaded:
        return None

    # The session may have been registered while the store load was awaited.
    hit = await _lookup_live()
    if hit is not _absent:
        return hit

    meta = loaded.get("metadata") or {}
    owner = str(meta.get("user_id") or "")
    if not (user_id == "dev" or owner == "dev" or owner == user_id):
        return None
    effective_owner = owner or user_id

    from litellm import Message

    event_queue: asyncio.Queue = asyncio.Queue()
    submission_queue: asyncio.Queue = asyncio.Queue()
    tool_router, session = await asyncio.to_thread(
        self._create_session_sync,
        session_id=session_id,
        user_id=effective_owner,
        hf_token=hf_token,
        model=meta.get("model") or self.config.model_name,
        event_queue=event_queue,
        notification_destinations=meta.get("notification_destinations") or [],
    )

    history: list[Message] = []
    for raw in loaded.get("messages") or []:
        restorable = isinstance(raw, dict) and raw.get("role") != "system"
        if not restorable:
            continue
        try:
            history.append(Message.model_validate(raw))
        except Exception as e:
            logger.warning("Dropping malformed restored message: %s", e)
    if history:
        # Keep the freshly rendered system prompt and append the persisted
        # non-system messages behind it.
        fresh_system_prompt = session.context_manager.items[0]
        session.context_manager.items = [fresh_system_prompt, *history]

    self._restore_pending_approval(session, meta.get("pending_approval") or [])
    session.turn_count = int(meta.get("turn_count") or 0)

    created_at = meta.get("created_at")
    if not isinstance(created_at, datetime):
        created_at = datetime.utcnow()

    agent_session = AgentSession(
        session_id=session_id,
        session=session,
        tool_router=tool_router,
        submission_queue=submission_queue,
        user_id=effective_owner,
        hf_token=hf_token,
        created_at=created_at,
        is_active=True,
        is_processing=False,
        claude_counted=bool(meta.get("claude_counted")),
        title=meta.get("title"),
    )
    started = await self._start_agent_session(
        agent_session=agent_session,
        event_queue=event_queue,
        tool_router=tool_router,
    )
    if started is agent_session:
        logger.info("Restored session %s for user %s", session_id, effective_owner)
        return agent_session

    # Another caller registered the session first; hand back that one.
    self._update_hf_token(started, hf_token)
    return started
462
+
463
  async def create_session(
464
  self,
465
  user_id: str = "dev",
 
508
  event_queue: asyncio.Queue = asyncio.Queue()
509
 
510
  # Run blocking constructors in a thread to keep the event loop responsive.
511
+ tool_router, session = await asyncio.to_thread(
512
+ self._create_session_sync,
513
+ session_id=session_id,
514
+ user_id=user_id,
515
+ hf_token=hf_token,
516
+ model=model,
517
+ event_queue=event_queue,
518
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  # Create wrapper
521
  agent_session = AgentSession(
 
527
  hf_token=hf_token,
528
  )
529
 
530
+ await self._start_agent_session(
531
+ agent_session=agent_session,
532
+ event_queue=event_queue,
533
+ tool_router=tool_router,
 
 
534
  )
535
+ await self.persist_session_snapshot(agent_session, runtime_state="idle")
536
 
537
  logger.info(f"Created session {session_id} for user {user_id}")
538
  return session_id
 
598
  ),
599
  )
600
  session.context_manager.items.append(seed)
601
+ await self.persist_session_snapshot(agent_session, runtime_state="idle")
602
  return len(parsed)
603
 
604
  @staticmethod
605
  async def _cleanup_sandbox(session: Session) -> None:
606
+ """Delete the sandbox Space if one was created for this session.
607
+
608
+ Retries on transient failures (HF API 5xx, rate-limit, network blips)
609
+ with exponential backoff. A single missed delete = a permanently
610
+ orphaned Space, so the cost of an extra retry beats the alternative.
611
+ """
612
  sandbox = getattr(session, "sandbox", None)
613
+ if not (sandbox and getattr(sandbox, "_owns_space", False)):
614
+ return
615
+
616
+ space_id = getattr(sandbox, "space_id", None)
617
+ last_err: Exception | None = None
618
+ for attempt in range(3):
619
  try:
620
+ logger.info(f"Deleting sandbox {space_id} (attempt {attempt + 1}/3)...")
621
  await asyncio.to_thread(sandbox.delete)
622
  from agent.core import telemetry
623
  await telemetry.record_sandbox_destroy(session, sandbox)
624
+ return
625
  except Exception as e:
626
+ last_err = e
627
+ if attempt < 2:
628
+ await asyncio.sleep(2 ** attempt)
629
+ logger.error(
630
+ f"Failed to delete sandbox {space_id} after 3 attempts: {last_err}. "
631
+ f"Orphan — sweep script will pick it up."
632
+ )
633
 
634
  async def _run_session(
635
  self,
 
669
  should_continue = await process_submission(session, submission)
670
  finally:
671
  agent_session.is_processing = False
672
+ await self.persist_session_snapshot(agent_session)
673
  if not should_continue:
674
  break
675
  except asyncio.TimeoutError:
 
704
  async with self._lock:
705
  if session_id in self.sessions:
706
  self.sessions[session_id].is_active = False
707
+ await self.persist_session_snapshot(
708
+ self.sessions[session_id],
709
+ runtime_state="ended",
710
+ status="ended",
711
+ )
712
 
713
  logger.info(f"Session {session_id} ended")
714
 
 
758
  agent_session = self.sessions.get(session_id)
759
  if not agent_session or not agent_session.is_active:
760
  return False
761
+ success = agent_session.session.context_manager.truncate_to_user_message(user_message_index)
762
+ if success:
763
+ await self.persist_session_snapshot(agent_session, runtime_state="idle")
764
+ return success
765
 
766
  async def compact(self, session_id: str) -> bool:
767
  """Compact context in a session."""
 
786
  return success
787
 
788
  async def delete_session(self, session_id: str) -> bool:
789
+ """Soft-delete a session and stop its runtime resources."""
790
  async with self._lock:
791
  agent_session = self.sessions.pop(session_id, None)
792
 
793
  if not agent_session:
794
+ await self._store().soft_delete_session(session_id)
795
+ return True
796
+
797
+ await self._store().soft_delete_session(session_id)
798
 
799
  # Clean up sandbox Space before cancelling the task
800
  await self._cleanup_sandbox(agent_session.session)
 
809
 
810
  return True
811
 
812
async def update_session_title(self, session_id: str, title: str | None) -> None:
    """Store a user-visible title for the session.

    The in-memory session is updated when one is loaded; the store is
    updated in every case.
    """
    if live := self.sessions.get(session_id):
        live.title = title
    store = self._store()
    await store.update_session_fields(session_id, title=title)
818
+
819
async def update_session_model(self, session_id: str, model_id: str) -> bool:
    """Switch the model of a loaded, active session and persist a snapshot.

    Returns ``False`` when the session is not loaded or no longer active.
    """
    target = self.sessions.get(session_id)
    if target and target.is_active:
        target.session.update_model(model_id)
        await self.persist_session_snapshot(target, runtime_state="idle")
        return True
    return False
826
+
827
  def get_session_owner(self, session_id: str) -> str | None:
828
  """Get the user_id that owns a session, or None if session doesn't exist."""
829
  agent_session = self.sessions.get(session_id)
 
851
  if not agent_session:
852
  return None
853
 
854
+ pending_approval = self._pending_tools_for_api(agent_session.session)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
 
856
  return {
857
  "session_id": session_id,
 
862
  "user_id": agent_session.user_id,
863
  "pending_approval": pending_approval,
864
  "model": agent_session.session.config.model_name,
865
+ "title": agent_session.title,
866
+ "notification_destinations": list(
867
+ agent_session.session.notification_destinations
868
+ ),
869
  }
870
 
871
def set_notification_destinations(
    self, session_id: str, destinations: list[str]
) -> list[str]:
    """Replace the session's opted-in auto-notification destinations.

    Names are stripped, validated against the messaging config and
    de-duplicated in first-seen order. Returns the normalized list that was
    applied to the session.

    Raises:
        ValueError: if the session is missing or inactive, a name is empty,
            a destination is unknown, or it does not allow auto events.
    """
    target = self.sessions.get(session_id)
    if not (target and target.is_active):
        raise ValueError("Session not found or inactive")

    # A dict keeps first-seen order while collapsing duplicates.
    ordered: dict[str, None] = {}
    for raw_name in destinations:
        name = raw_name.strip()
        if not name:
            raise ValueError("Destination names must not be empty")
        destination = self.config.messaging.get_destination(name)
        if destination is None:
            raise ValueError(f"Unknown destination '{name}'")
        if not destination.allow_auto_events:
            raise ValueError(
                f"Destination '{name}' is not enabled for auto events"
            )
        ordered.setdefault(name, None)

    normalized = list(ordered)
    target.session.set_notification_destinations(normalized)
    return normalized
898
+
899
+ async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
900
  """List sessions, optionally filtered by user.
901
 
902
  Args:
903
  user_id: If provided, only return sessions owned by this user.
904
  If "dev", return all sessions (dev mode).
905
  """
906
+ results: list[dict[str, Any]] = []
907
+ store = self._store()
908
+ if getattr(store, "enabled", False):
909
+ for row in await store.list_sessions(user_id or "dev"):
910
+ sid = row.get("session_id") or row.get("_id")
911
+ if not sid:
912
+ continue
913
+ runtime_info = self.get_session_info(str(sid))
914
+ if runtime_info:
915
+ results.append(runtime_info)
916
+ continue
917
+ created_at = row.get("created_at")
918
+ if isinstance(created_at, datetime):
919
+ created_at_str = created_at.isoformat()
920
+ else:
921
+ created_at_str = str(created_at or datetime.utcnow().isoformat())
922
+ pending = self._pending_docs_for_api(row.get("pending_approval") or [])
923
+ results.append(
924
+ {
925
+ "session_id": str(sid),
926
+ "created_at": created_at_str,
927
+ "is_active": row.get("status") != "ended",
928
+ "is_processing": row.get("runtime_state") == "processing",
929
+ "message_count": int(row.get("message_count") or 0),
930
+ "user_id": row.get("user_id") or "dev",
931
+ "pending_approval": pending or None,
932
+ "model": row.get("model"),
933
+ "title": row.get("title"),
934
+ "notification_destinations": row.get("notification_destinations") or [],
935
+ }
936
+ )
937
+ return results
938
+
939
  for sid in self.sessions:
940
  info = self.get_session_info(sid)
941
  if not info:
backend/user_quotas.py CHANGED
@@ -1,9 +1,8 @@
1
- """In-memory daily quota for Claude session creations.
2
 
3
  Tracks per-user Claude session starts against a daily cap derived from the
4
- user's HF plan. Caps reset at UTC midnight; the store itself is in-process
5
- and wipes on restart (deliberate — the cost of occasional over-subsidy at
6
- restart is much lower than running a DB).
7
 
8
  Unit: session *creations*, not messages. A user who selects Claude in a new
9
  session consumes one quota point; switching an existing Claude session to
@@ -18,6 +17,8 @@ import asyncio
18
  import os
19
  from datetime import UTC, datetime
20
 
 
 
21
  CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
22
  CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
23
 
@@ -37,6 +38,11 @@ def daily_cap_for(plan: str | None) -> int:
37
 
38
  async def get_claude_used_today(user_id: str) -> int:
39
  """Return today's Claude session count for the user (0 if none / stale day)."""
 
 
 
 
 
40
  async with _lock:
41
  entry = _claude_counts.get(user_id)
42
  if entry is None:
@@ -51,11 +57,37 @@ async def get_claude_used_today(user_id: str) -> int:
51
 
52
  async def increment_claude(user_id: str) -> int:
53
  """Bump today's Claude session count for the user. Returns the new value."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  async with _lock:
55
  today = _today()
56
  day, count = _claude_counts.get(user_id, (today, 0))
57
  if day != today:
58
  count = 0
 
 
59
  count += 1
60
  _claude_counts[user_id] = (today, count)
61
  return count
@@ -63,6 +95,11 @@ async def increment_claude(user_id: str) -> int:
63
 
64
  async def refund_claude(user_id: str) -> None:
65
  """Decrement today's count — used when session creation fails after a successful gate."""
 
 
 
 
 
66
  async with _lock:
67
  entry = _claude_counts.get(user_id)
68
  if entry is None:
@@ -81,3 +118,4 @@ async def refund_claude(user_id: str) -> None:
81
  def _reset_for_tests() -> None:
82
  """Test-only: clear the in-memory store."""
83
  _claude_counts.clear()
 
 
1
+ """Daily quota for Claude session creations.
2
 
3
  Tracks per-user Claude session starts against a daily cap derived from the
4
+ user's HF plan. MongoDB is the source of truth when configured; the
5
+ in-process dict remains the fallback for local/dev/test runs.
 
6
 
7
  Unit: session *creations*, not messages. A user who selects Claude in a new
8
  session consumes one quota point; switching an existing Claude session to
 
17
  import os
18
  from datetime import UTC, datetime
19
 
20
+ from agent.core.session_persistence import NoopSessionStore, get_session_store, _reset_store_for_tests
21
+
22
  CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
23
  CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
24
 
 
38
 
39
  async def get_claude_used_today(user_id: str) -> int:
40
  """Return today's Claude session count for the user (0 if none / stale day)."""
41
+ store = get_session_store()
42
+ if getattr(store, "enabled", False):
43
+ db_count = await store.get_quota(user_id, _today())
44
+ return db_count or 0
45
+
46
  async with _lock:
47
  entry = _claude_counts.get(user_id)
48
  if entry is None:
 
57
 
58
async def increment_claude(user_id: str) -> int:
    """Bump today's Claude session count for the user. Returns the new value.

    This is the unconditional counterpart of ``try_increment_claude``. It
    delegates to that helper so the store-vs-in-memory selection and the
    day-rollover logic live in exactly one place.

    Args:
        user_id: Identifier whose daily counter is bumped.

    Returns:
        The count after the increment, or 0 if the helper returned ``None``
        (or a falsy count), matching the previous ``db_count or 0`` fallback.
    """
    # Same effectively-unlimited cap the store call used before; no real
    # daily count gets anywhere near it, so the bump always goes through.
    unlimited_cap = 10**9
    new_count = await try_increment_claude(user_id, cap=unlimited_cap)
    return new_count or 0
73
+
74
+
75
async def try_increment_claude(user_id: str, cap: int) -> int | None:
    """Bump today's count for *user_id* only while it is still below *cap*.

    When the session store reports itself as enabled, the check-and-bump is
    handed to the store's ``try_increment_quota``. Otherwise the in-process
    dict is updated while holding the module lock.

    Returns the new count, or None when the user is already at the cap.
    """
    backend = get_session_store()
    if getattr(backend, "enabled", False):
        return await backend.try_increment_quota(user_id, _today(), cap)

    async with _lock:
        current_day = _today()
        stored_day, used = _claude_counts.get(user_id, (current_day, 0))
        # A counter recorded on a different day is stale: start again at zero.
        used = used if stored_day == current_day else 0
        if used >= cap:
            return None
        bumped = used + 1
        _claude_counts[user_id] = (current_day, bumped)
        return bumped
 
95
 
96
  async def refund_claude(user_id: str) -> None:
97
  """Decrement today's count — used when session creation fails after a successful gate."""
98
+ store = get_session_store()
99
+ if getattr(store, "enabled", False):
100
+ await store.refund_quota(user_id, _today())
101
+ return
102
+
103
  async with _lock:
104
  entry = _claude_counts.get(user_id)
105
  if entry is None:
 
118
def _reset_for_tests() -> None:
    """Test-only: reset quota state between test cases.

    Empties the in-memory counter dict, then passes a fresh
    ``NoopSessionStore`` to ``_reset_store_for_tests`` (presumably so tests
    fall back to the in-process path; confirm in ``session_persistence``).
    """
    _claude_counts.clear()
    _reset_store_for_tests(NoopSessionStore())
configs/__init__.py ADDED
File without changes
configs/cli_agent_config.json CHANGED
@@ -5,6 +5,11 @@
5
  "yolo_mode": false,
6
  "confirm_cpu_jobs": true,
7
  "auto_file_upload": true,
 
 
 
 
 
8
  "mcpServers": {
9
  "hf-mcp-server": {
10
  "transport": "http",
 
5
  "yolo_mode": false,
6
  "confirm_cpu_jobs": true,
7
  "auto_file_upload": true,
8
+ "messaging": {
9
+ "enabled": false,
10
+ "auto_event_types": ["approval_required", "error", "turn_complete"],
11
+ "destinations": {}
12
+ },
13
  "mcpServers": {
14
  "hf-mcp-server": {
15
  "transport": "http",
frontend/src/components/Chat/MarkdownContent.tsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useMemo, useRef, useState, useEffect } from 'react';
2
  import { Box } from '@mui/material';
3
  import ReactMarkdown from 'react-markdown';
4
  import remarkGfm from 'remark-gfm';
@@ -166,9 +166,17 @@ export default function MarkdownContent({ content, sx, isStreaming = false }: Ma
166
 
167
  const remarkPlugins = useMemo(() => [remarkGfm], []);
168
 
 
 
 
 
 
 
 
 
169
  return (
170
  <Box sx={[markdownSx, ...(Array.isArray(sx) ? sx : sx ? [sx] : [])]}>
171
- <ReactMarkdown remarkPlugins={remarkPlugins}>{displayContent}</ReactMarkdown>
172
  </Box>
173
  );
174
  }
 
1
+ import { useMemo, useRef, useState, useEffect, type ComponentPropsWithoutRef } from 'react';
2
  import { Box } from '@mui/material';
3
  import ReactMarkdown from 'react-markdown';
4
  import remarkGfm from 'remark-gfm';
 
166
 
167
  const remarkPlugins = useMemo(() => [remarkGfm], []);
168
 
169
+ const components = useMemo(() => ({
170
+ a: ({ href, children, ...props }: ComponentPropsWithoutRef<'a'>) => (
171
+ <a href={href} target="_blank" rel="noopener noreferrer" {...props}>
172
+ {children}
173
+ </a>
174
+ ),
175
+ }), []);
176
+
177
  return (
178
  <Box sx={[markdownSx, ...(Array.isArray(sx) ? sx : sx ? [sx] : [])]}>
179
+ <ReactMarkdown remarkPlugins={remarkPlugins} components={components}>{displayContent}</ReactMarkdown>
180
  </Box>
181
  );
182
  }
frontend/src/components/Chat/ToolCallGroup.tsx CHANGED
@@ -220,6 +220,194 @@ function ResearchSteps({ steps }: { steps: string[] }) {
220
  );
221
  }
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  // ---------------------------------------------------------------------------
224
  // Hardware pricing ($/hr) — from HF Spaces & Jobs pricing
225
  // ---------------------------------------------------------------------------
@@ -517,7 +705,7 @@ function InlineApproval({
517
  const EMPTY_AGENTS: Record<string, ResearchAgentState> = {};
518
 
519
  export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) {
520
- const { setPanel, lockPanel, getJobUrl, getEditedScript, setJobStatus, getJobStatus, setToolError, getToolError, setToolRejected, getToolRejected } = useAgentStore();
521
  const researchAgents = useAgentStore(s => {
522
  const activeId = s.activeSessionId;
523
  return (activeId && s.sessionStates[activeId]?.researchAgents) || EMPTY_AGENTS;
@@ -1063,6 +1251,18 @@ export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProp
1063
  <ResearchSteps steps={researchAgents[tool.toolCallId].steps} />
1064
  )}
1065
 
 
 
 
 
 
 
 
 
 
 
 
 
1066
  {/* Per-tool approval: undecided */}
1067
  {isPending && !localDecision && !isSubmitting && (
1068
  <InlineApproval
 
220
  );
221
  }
222
 
223
+ // ---------------------------------------------------------------------------
224
+ // Trackio dashboard embed
225
+ // ---------------------------------------------------------------------------
226
+
227
// HF repo IDs are `<owner>/<name>` where each segment is alphanumerics plus
// `_`, `.`, `-`. Anything else (slashes, spaces, query params, missing owner)
// would let an attacker-controlled string redirect the embed to a different
// Space, so we refuse to render rather than build a malformed URL.
//
// On top of the character set, every segment must start and end with an
// alphanumeric or `_`, and the ID must not contain `..`. Without that, an ID
// such as `../..` passes the character check, but URL parsing resolves the
// dot segments (so the "Open" link leaves `/spaces/`), and
// `spaceIdToSubdomain` collapses it to an empty host label.
const SPACE_ID_PATTERN =
  /^(?!.*\.\.)[a-zA-Z0-9_](?:[a-zA-Z0-9_.-]*[a-zA-Z0-9_])?\/[a-zA-Z0-9_](?:[a-zA-Z0-9_.-]*[a-zA-Z0-9_])?$/;

/** Return true only for a well-formed `<owner>/<name>` Space ID. */
function isValidSpaceId(spaceId: string): boolean {
  return SPACE_ID_PATTERN.test(spaceId);
}
236
+
237
/**
 * Derive the `*.hf.space` subdomain label for a Space ID,
 * e.g. 'user/space_name' → 'user-space-name'.
 */
function spaceIdToSubdomain(spaceId: string): string {
  // Lowercase, cut on every run of separators (`/`, `_`, `.`, `-`), drop the
  // empty pieces produced by leading/trailing separators, and re-join with a
  // single dash.
  return spaceId
    .toLowerCase()
    .split(/[/_.-]+/)
    .filter((piece) => piece.length > 0)
    .join('-');
}
245
+
246
/**
 * Build the iframe URL for a trackio dashboard hosted on `<subdomain>.hf.space`.
 * `project`, when given, is appended as a query parameter.
 */
function buildTrackioEmbedUrl(spaceId: string, project?: string): string {
  // __theme=dark is gradio's standard query param to force the embedded
  // dashboard into dark mode so it blends with the surrounding chat instead
  // of flashing a bright white panel inside the dark UI.
  const entries: [string, string][] = [
    ['sidebar', 'hidden'],
    ['footer', 'false'],
    ['__theme', 'dark'],
  ];
  if (project) {
    entries.push(['project', project]);
  }
  const query = new URLSearchParams(entries).toString();
  const host = `${spaceIdToSubdomain(spaceId)}.hf.space`;
  return `https://${host}/?${query}`;
}
258
+
259
/**
 * Build the huggingface.co page URL for a Space, adding `?project=…`
 * only when a project name is supplied.
 */
function buildTrackioPageUrl(spaceId: string, project?: string): string {
  const base = `https://huggingface.co/spaces/${spaceId}`;
  if (!project) {
    return base;
  }
  return `${base}?${new URLSearchParams({ project }).toString()}`;
}
263
+
264
/**
 * Collapsible inline embed of a trackio dashboard Space.
 *
 * Renders a header (label, "Open" link to the Space page, Hide/Show toggle)
 * and, while expanded, an iframe pointing at the Space's `*.hf.space` host.
 * A loading overlay covers the iframe until its `load` event fires.
 * Renders nothing when `spaceId` fails `isValidSpaceId`.
 */
function TrackioEmbed({ spaceId, project }: { spaceId: string; project?: string }) {
  // Hooks run before the validity check so their order is the same on every render.
  const [expanded, setExpanded] = useState(true);
  const [iframeLoaded, setIframeLoaded] = useState(false);
  const embedUrl = useMemo(() => buildTrackioEmbedUrl(spaceId, project), [spaceId, project]);
  const pageUrl = useMemo(() => buildTrackioPageUrl(spaceId, project), [spaceId, project]);
  const label = project ? `${spaceId} · ${project}` : spaceId;
  const monoFont = '"JetBrains Mono", ui-monospace, SFMono-Regular, monospace';

  if (!isValidSpaceId(spaceId)) return null;

  return (
    <Box sx={{ pl: 4.5, pr: 1.5, pb: 1, pt: 0.25 }}>
      <Box
        sx={{
          border: '1px solid var(--tool-border)',
          borderRadius: '8px',
          overflow: 'hidden',
          bgcolor: 'var(--code-panel-bg)',
        }}
      >
        <Stack
          direction="row"
          alignItems="center"
          spacing={1}
          onClick={(e) => e.stopPropagation()}
          sx={{
            px: 1.25,
            py: 0.5,
            borderBottom: expanded ? '1px solid var(--tool-border)' : 'none',
          }}
        >
          <Typography
            sx={{
              fontFamily: monoFont,
              fontSize: '0.65rem',
              fontWeight: 600,
              color: 'var(--accent-yellow)',
              letterSpacing: '0.04em',
            }}
          >
            trackio
          </Typography>
          <Typography
            sx={{
              fontFamily: monoFont,
              fontSize: '0.65rem',
              color: 'var(--muted-text)',
              flex: 1,
              minWidth: 0,
              overflow: 'hidden',
              textOverflow: 'ellipsis',
              whiteSpace: 'nowrap',
            }}
          >
            {label}
          </Typography>
          <Link
            href={pageUrl}
            target="_blank"
            rel="noopener noreferrer"
            onClick={(e) => e.stopPropagation()}
            sx={{
              display: 'inline-flex',
              alignItems: 'center',
              gap: 0.4,
              color: 'var(--accent-yellow)',
              fontSize: '0.65rem',
              textDecoration: 'none',
              '&:hover': { textDecoration: 'underline' },
            }}
          >
            <LaunchIcon sx={{ fontSize: 11 }} />
            Open
          </Link>
          <Button
            size="small"
            onClick={(e) => {
              e.stopPropagation();
              // The iframe is unmounted while collapsed and loads from scratch
              // when shown again, so the loading overlay must be re-armed;
              // otherwise the panel stays blank during the reload.
              setIframeLoaded(false);
              setExpanded((v) => !v);
            }}
            sx={{
              textTransform: 'none',
              minWidth: 'auto',
              px: 0.75,
              py: 0,
              fontSize: '0.65rem',
              color: 'var(--muted-text)',
              '&:hover': { color: 'var(--text)', bgcolor: 'transparent' },
            }}
          >
            {expanded ? 'Hide' : 'Show'}
          </Button>
        </Stack>
        {expanded && (
          <Box sx={{ position: 'relative', width: '100%', height: 480, bgcolor: 'var(--code-panel-bg)' }}>
            <iframe
              src={embedUrl}
              title={`Trackio dashboard ${label}`}
              loading="lazy"
              onLoad={() => setIframeLoaded(true)}
              sandbox="allow-scripts allow-same-origin allow-forms allow-popups allow-downloads allow-modals"
              style={{ border: 0, width: '100%', height: '100%', display: 'block' }}
            />
            {/* Overlay sits on top of the iframe until its load event fires. */}
            {!iframeLoaded && (
              <Stack
                direction="column"
                alignItems="center"
                justifyContent="center"
                spacing={1.5}
                sx={{
                  position: 'absolute',
                  inset: 0,
                  bgcolor: 'var(--code-panel-bg)',
                  color: 'var(--muted-text)',
                  pointerEvents: 'none',
                }}
              >
                <CircularProgress size={20} sx={{ color: 'var(--accent-yellow)' }} />
                <Typography
                  sx={{
                    fontFamily: monoFont,
                    fontSize: '0.75rem',
                    color: 'var(--text)',
                  }}
                >
                  Spinning up the trackio dashboard…
                </Typography>
                <Typography
                  sx={{
                    fontFamily: monoFont,
                    fontSize: '0.65rem',
                    color: 'var(--muted-text)',
                    textAlign: 'center',
                    maxWidth: 360,
                    px: 2,
                  }}
                >
                  First load takes 30–60 seconds. Charts appear automatically once the run starts logging.
                </Typography>
              </Stack>
            )}
          </Box>
        )}
      </Box>
    </Box>
  );
}
410
+
411
  // ---------------------------------------------------------------------------
412
  // Hardware pricing ($/hr) — from HF Spaces & Jobs pricing
413
  // ---------------------------------------------------------------------------
 
705
  const EMPTY_AGENTS: Record<string, ResearchAgentState> = {};
706
 
707
  export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) {
708
+ const { setPanel, lockPanel, getJobUrl, getEditedScript, setJobStatus, getJobStatus, getTrackioDashboard, setToolError, getToolError, setToolRejected, getToolRejected } = useAgentStore();
709
  const researchAgents = useAgentStore(s => {
710
  const activeId = s.activeSessionId;
711
  return (activeId && s.sessionStates[activeId]?.researchAgents) || EMPTY_AGENTS;
 
1251
  <ResearchSteps steps={researchAgents[tool.toolCallId].steps} />
1252
  )}
1253
 
1254
+ {/* Trackio dashboard embed — shown for hf_jobs / sandbox_create runs that declared a trackio space */}
1255
+ {(tool.toolName === 'hf_jobs' || tool.toolName === 'sandbox_create')
1256
+ && !isPending
1257
+ && !isRejected
1258
+ && !cancelled
1259
+ && (() => {
1260
+ const trackio = getTrackioDashboard(tool.toolCallId);
1261
+ return trackio
1262
+ ? <TrackioEmbed spaceId={trackio.spaceId} project={trackio.project} />
1263
+ : null;
1264
+ })()}
1265
+
1266
  {/* Per-tool approval: undecided */}
1267
  {isPending && !localDecision && !isSubmitting && (
1268
  <InlineApproval
frontend/src/components/JobsUpgradeDialog.tsx CHANGED
@@ -8,7 +8,6 @@ import {
8
  DialogContentText,
9
  DialogTitle,
10
  FormControl,
11
- InputLabel,
12
  MenuItem,
13
  Select,
14
  Typography,
@@ -37,13 +36,20 @@ export default function JobsUpgradeDialog({
37
  onClose,
38
  onContinueWithNamespace,
39
  }: JobsUpgradeDialogProps) {
40
- const [selectedNamespace, setSelectedNamespace] = useState('');
41
 
42
  useEffect(() => {
43
  if (!open) return;
44
  setSelectedNamespace(eligibleNamespaces[0] || '');
45
  }, [open, eligibleNamespaces]);
46
 
 
 
 
 
 
 
 
47
  return (
48
  <Dialog
49
  open={open}
@@ -57,7 +63,7 @@ export default function JobsUpgradeDialog({
57
  border: '1px solid var(--border)',
58
  borderRadius: 'var(--radius-md)',
59
  boxShadow: 'var(--shadow-1)',
60
- maxWidth: 500,
61
  mx: 2,
62
  },
63
  }}
@@ -65,72 +71,75 @@ export default function JobsUpgradeDialog({
65
  <DialogTitle
66
  sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
67
  >
68
- {mode === 'namespace' ? 'Choose the org for this job' : 'Jobs need Pro or a paid org'}
69
  </DialogTitle>
70
  <DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
71
  <DialogContentText
72
  sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
73
  >
74
- {message}
75
  </DialogContentText>
76
- {eligibleNamespaces.length > 0 && (
77
- <Box
78
- sx={{
79
- mt: 2,
80
- p: 1.5,
81
- borderRadius: '8px',
82
- bgcolor: 'var(--accent-yellow-weak)',
83
- border: '1px solid var(--border)',
84
- }}
85
- >
86
- <Typography
87
- variant="caption"
88
  sx={{
89
- display: 'block',
90
- fontWeight: 700,
91
  color: 'var(--text)',
92
- fontSize: '0.78rem',
93
- mb: 1,
94
- letterSpacing: '0.02em',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  }}
96
  >
97
- Eligible namespaces
98
- </Typography>
99
- {mode === 'namespace' ? (
100
- <FormControl fullWidth size="small">
101
- <InputLabel id="jobs-namespace-label">Organization</InputLabel>
102
- <Select
103
- labelId="jobs-namespace-label"
104
- value={selectedNamespace}
105
- label="Organization"
106
- onChange={(e) => setSelectedNamespace(String(e.target.value))}
107
  >
108
- {eligibleNamespaces.map((namespace) => (
109
- <MenuItem key={namespace} value={namespace}>
110
- {namespace}
111
- </MenuItem>
112
- ))}
113
- </Select>
114
- </FormControl>
115
- ) : (
116
  <Typography
117
  variant="caption"
118
- sx={{ display: 'block', color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
119
  >
120
- {eligibleNamespaces.join(', ')}
121
  </Typography>
122
- )}
123
- </Box>
124
  )}
125
- <Typography
126
- variant="caption"
127
- sx={{ display: 'block', mt: 2, color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
128
- >
129
- If you decline, the agent will have to find another way forward without `hf_jobs`.
130
- </Typography>
131
  </DialogContent>
132
- <DialogActions sx={{ px: 3, pb: 2.5, pt: 2, gap: 1 }}>
133
- {mode === 'namespace' ? (
134
  <Button
135
  onClick={() => onContinueWithNamespace(selectedNamespace)}
136
  disabled={!selectedNamespace}
@@ -147,7 +156,7 @@ export default function JobsUpgradeDialog({
147
  '&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
148
  }}
149
  >
150
- Run under selected org
151
  </Button>
152
  ) : (
153
  <Button
@@ -183,7 +192,7 @@ export default function JobsUpgradeDialog({
183
  '&:hover': { bgcolor: 'var(--hover-bg)' },
184
  }}
185
  >
186
- Decline tool call
187
  </Button>
188
  </DialogActions>
189
  </Dialog>
 
8
  DialogContentText,
9
  DialogTitle,
10
  FormControl,
 
11
  MenuItem,
12
  Select,
13
  Typography,
 
36
  onClose,
37
  onContinueWithNamespace,
38
  }: JobsUpgradeDialogProps) {
39
+ const [selectedNamespace, setSelectedNamespace] = useState(() => eligibleNamespaces[0] || '');
40
 
41
  useEffect(() => {
42
  if (!open) return;
43
  setSelectedNamespace(eligibleNamespaces[0] || '');
44
  }, [open, eligibleNamespaces]);
45
 
46
+ const isNamespace = mode === 'namespace';
47
+ const title = isNamespace ? 'Run jobs as' : 'Jobs need Pro or a paid org';
48
+
49
+ const body = isNamespace
50
+ ? "Pick which paid organization should pay for and own this job. We'll use the same one for the rest of this browser."
51
+ : message;
52
+
53
  return (
54
  <Dialog
55
  open={open}
 
63
  border: '1px solid var(--border)',
64
  borderRadius: 'var(--radius-md)',
65
  boxShadow: 'var(--shadow-1)',
66
+ maxWidth: 460,
67
  mx: 2,
68
  },
69
  }}
 
71
  <DialogTitle
72
  sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
73
  >
74
+ {title}
75
  </DialogTitle>
76
  <DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
77
  <DialogContentText
78
  sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
79
  >
80
+ {body}
81
  </DialogContentText>
82
+
83
+ {isNamespace ? (
84
+ <FormControl fullWidth size="small" sx={{ mt: 2 }}>
85
+ <Select
86
+ value={selectedNamespace}
87
+ displayEmpty
88
+ onChange={(e) => setSelectedNamespace(String(e.target.value))}
 
 
 
 
 
89
  sx={{
90
+ bgcolor: 'var(--composer-bg)',
 
91
  color: 'var(--text)',
92
+ fontSize: '0.88rem',
93
+ fontWeight: 600,
94
+ '& .MuiOutlinedInput-notchedOutline': { borderColor: 'var(--border)' },
95
+ '&:hover .MuiOutlinedInput-notchedOutline': { borderColor: 'var(--border)' },
96
+ '&.Mui-focused .MuiOutlinedInput-notchedOutline': {
97
+ borderColor: 'var(--accent-yellow)',
98
+ borderWidth: 1,
99
+ },
100
+ '& .MuiSelect-icon': { color: 'var(--muted-text)' },
101
+ }}
102
+ MenuProps={{
103
+ PaperProps: {
104
+ sx: {
105
+ bgcolor: 'var(--panel)',
106
+ border: '1px solid var(--border)',
107
+ borderRadius: '8px',
108
+ mt: 0.5,
109
+ },
110
+ },
111
  }}
112
  >
113
+ {eligibleNamespaces.map((namespace) => (
114
+ <MenuItem
115
+ key={namespace}
116
+ value={namespace}
117
+ sx={{
118
+ fontSize: '0.88rem',
119
+ color: 'var(--text)',
120
+ '&.Mui-selected': { bgcolor: 'rgba(255,255,255,0.05)' },
121
+ }}
 
122
  >
123
+ {namespace}
124
+ </MenuItem>
125
+ ))}
126
+ </Select>
127
+ </FormControl>
128
+ ) : (
129
+ eligibleNamespaces.length > 0 && (
130
+ <Box sx={{ mt: 1.5 }}>
131
  <Typography
132
  variant="caption"
133
+ sx={{ color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
134
  >
135
+ Eligible namespaces: {eligibleNamespaces.join(', ')}
136
  </Typography>
137
+ </Box>
138
+ )
139
  )}
 
 
 
 
 
 
140
  </DialogContent>
141
+ <DialogActions sx={{ px: 3, pb: 2.5, pt: 2.5, gap: 1 }}>
142
+ {isNamespace ? (
143
  <Button
144
  onClick={() => onContinueWithNamespace(selectedNamespace)}
145
  disabled={!selectedNamespace}
 
156
  '&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
157
  }}
158
  >
159
+ Continue
160
  </Button>
161
  ) : (
162
  <Button
 
192
  '&:hover': { bgcolor: 'var(--hover-bg)' },
193
  }}
194
  >
195
+ {isNamespace ? 'Skip this tool call' : 'Decline tool call'}
196
  </Button>
197
  </DialogActions>
198
  </Dialog>
frontend/src/components/SessionSidebar/SessionSidebar.tsx CHANGED
@@ -1,4 +1,4 @@
1
- import { useCallback, useState } from 'react';
2
  import {
3
  Alert,
4
  Box,
@@ -25,13 +25,30 @@ interface SessionSidebarProps {
25
  }
26
 
27
  export default function SessionSidebar({ onClose }: SessionSidebarProps) {
28
- const { sessions, activeSessionId, createSession, deleteSession, switchSession } =
29
  useSessionStore();
30
  const { setPlan, clearPanel } =
31
  useAgentStore();
32
  const [isCreatingSession, setIsCreatingSession] = useState(false);
33
  const [capacityError, setCapacityError] = useState<string | null>(null);
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  // -- Handlers -----------------------------------------------------------
36
 
37
  const handleNewSession = useCallback(async () => {
 
1
+ import { useCallback, useEffect, useState } from 'react';
2
  import {
3
  Alert,
4
  Box,
 
25
  }
26
 
27
  export default function SessionSidebar({ onClose }: SessionSidebarProps) {
28
+ const { sessions, activeSessionId, createSession, deleteSession, switchSession, mergeServerSessions } =
29
  useSessionStore();
30
  const { setPlan, clearPanel } =
31
  useAgentStore();
32
  const [isCreatingSession, setIsCreatingSession] = useState(false);
33
  const [capacityError, setCapacityError] = useState<string | null>(null);
34
 
35
+ useEffect(() => {
36
+ let cancelled = false;
37
+ (async () => {
38
+ try {
39
+ const response = await apiFetch('/api/sessions');
40
+ if (!response.ok) return;
41
+ const data = await response.json();
42
+ if (!cancelled && Array.isArray(data)) {
43
+ mergeServerSessions(data);
44
+ }
45
+ } catch {
46
+ /* local sidebar metadata is still usable */
47
+ }
48
+ })();
49
+ return () => { cancelled = true; };
50
+ }, [mergeServerSessions]);
51
+
52
  // -- Handlers -----------------------------------------------------------
53
 
54
  const handleNewSession = useCallback(async () => {
frontend/src/components/WelcomeScreen/WelcomeScreen.tsx CHANGED
@@ -280,6 +280,12 @@ export default function WelcomeScreen() {
280
  : '';
281
 
282
  return (
 
 
 
 
 
 
283
  <Box
284
  sx={{
285
  width: '100%',
@@ -287,172 +293,182 @@ export default function WelcomeScreen() {
287
  display: 'flex',
288
  flexDirection: 'column',
289
  alignItems: 'center',
290
- justifyContent: 'center',
291
  background: 'var(--body-gradient)',
292
- py: 8,
293
  }}
294
  >
295
- {/* Logo */}
296
  <Box
297
- component="img"
298
- src="/smolagents.webp"
299
- alt="smolagents"
300
- sx={{ width: 80, height: 80, mb: 2.5, display: 'block' }}
301
- />
302
-
303
- {/* Title */}
304
- <Typography
305
- variant="h2"
306
  sx={{
307
- fontWeight: 800,
308
- color: 'var(--text)',
309
- mb: 1,
310
- letterSpacing: '-0.02em',
311
- fontSize: { xs: '1.8rem', md: '2.4rem' },
 
312
  }}
313
  >
314
- ML Intern
315
- </Typography>
 
 
 
 
 
316
 
317
- {/* Description */}
318
- <Typography
319
- variant="body1"
320
- sx={{
321
- color: 'var(--muted-text)',
322
- maxWidth: 480,
323
- mb: 4,
324
- lineHeight: 1.7,
325
- fontSize: '0.9rem',
326
- textAlign: 'center',
327
- px: 2,
328
- '& strong': { color: 'var(--text)', fontWeight: 600 },
329
- }}
330
- >
331
- Your personal <strong>ML agent</strong>. It reads <strong>papers</strong>, finds <strong>datasets</strong>, trains <strong>models</strong>, and iterates until the numbers go up. Instructions in. Trained model out.
332
- </Typography>
333
 
334
- {/* ── Checklist ──────────────────────────────────────────── */}
335
- <Box
336
- sx={{
337
- width: '100%',
338
- maxWidth: 520,
339
- bgcolor: 'var(--surface)',
340
- border: '1px solid var(--border)',
341
- borderRadius: '12px',
342
- overflow: 'hidden',
343
- mx: 2,
344
- }}
345
- >
346
- {isDevUser ? (
347
- /* Dev mode: single step */
348
- <ChecklistStep
349
- stepNumber={1}
350
- title="Start Session"
351
- description="Launch an AI agent session for ML engineering."
352
- status="active"
353
- actionLabel="Start Session"
354
- actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
355
- onAction={handleStartSession}
356
- loading={isCreating}
357
- isLast
358
- />
359
- ) : inIframe ? (
360
- /* Iframe: 2 steps */
361
- <>
362
- <ChecklistStep
363
- stepNumber={1}
364
- title="Join ML Agent Explorers"
365
- description="Get free access to GPUs, inference APIs, and Hub resources."
366
- status={isOrgMember ? 'completed' : 'active'}
367
- actionLabel="Join Organization"
368
- actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
369
- onAction={handleJoinOrg}
370
- />
371
- <ChecklistStep
372
- stepNumber={2}
373
- title="Open ML Intern"
374
- description="Open the agent in a full browser tab to get started."
375
- status={isOrgMember ? 'active' : 'locked'}
376
- lockedReason="Join the organization first."
377
- actionLabel="Open ML Intern"
378
- actionIcon={<OpenInNewIcon sx={{ fontSize: 16 }} />}
379
- actionHref={spaceHost}
380
- isLast
381
- />
382
- </>
383
- ) : (
384
- /* Direct access: 3 steps */
385
- <>
386
  <ChecklistStep
387
  stepNumber={1}
388
- title="Sign in with Hugging Face"
389
- description="Authenticate to access GPU resources and model APIs."
390
- status={signInStatus}
391
- actionLabel="Sign in"
392
- actionIcon={<LoginIcon sx={{ fontSize: 16 }} />}
393
- onAction={() => triggerLogin()}
394
- />
395
- <ChecklistStep
396
- stepNumber={2}
397
- title="Join ML Agent Explorers"
398
- description="Get free access to GPUs, inference APIs, and Hub resources."
399
- status={joinOrgStatus}
400
- lockedReason="Sign in first to continue."
401
- actionLabel="Join Organization"
402
- actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
403
- onAction={handleJoinOrg}
404
- />
405
- <ChecklistStep
406
- stepNumber={3}
407
  title="Start Session"
408
  description="Launch an AI agent session for ML engineering."
409
- status={startStatus}
410
- lockedReason="Complete the steps above to continue."
411
  actionLabel="Start Session"
412
  actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
413
  onAction={handleStartSession}
414
  loading={isCreating}
415
  isLast
416
  />
417
- </>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  )}
419
- </Box>
420
 
421
- {/* Polling hint when waiting for org join */}
422
- {isAuthenticated && !isOrgMember && !isDevUser && !inIframe && (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  <Typography
424
  variant="caption"
425
- sx={{ mt: 2, color: 'var(--muted-text)', fontSize: '0.75rem', textAlign: 'center' }}
426
  >
427
- This page updates automatically when you join the organization.
428
  </Typography>
429
- )}
430
-
431
- {/* Error */}
432
- {error && (
433
- <Alert
434
- severity="warning"
435
- variant="outlined"
436
- onClose={() => setError(null)}
437
- sx={{
438
- mt: 3,
439
- maxWidth: 400,
440
- fontSize: '0.8rem',
441
- borderColor: HF_ORANGE,
442
- color: 'var(--text)',
443
- }}
444
- >
445
- {error}
446
- </Alert>
447
- )}
448
-
449
- {/* Footnote */}
450
- <Typography
451
- variant="caption"
452
- sx={{ mt: 4, color: 'var(--muted-text)', opacity: 0.5, fontSize: '0.7rem' }}
453
- >
454
- Conversations are stored locally in your browser.
455
- </Typography>
456
  </Box>
457
  );
458
  }
 
280
  : '';
281
 
282
  return (
283
+ // Outer container scrolls; inner uses `margin: auto` so the checklist
284
+ // centers vertically when the viewport has room and falls back to top-
285
+ // aligned + scrollable when it doesn't. The previous setup hardcoded
286
+ // `justify-content: center` with no overflow, so on short viewports
287
+ // (1366×768 Chrome was the reported case) the bottom of the card —
288
+ // including the "Start session" CTA — got clipped with no way to scroll.
289
  <Box
290
  sx={{
291
  width: '100%',
 
293
  display: 'flex',
294
  flexDirection: 'column',
295
  alignItems: 'center',
296
+ overflowY: 'auto',
297
  background: 'var(--body-gradient)',
 
298
  }}
299
  >
 
300
  <Box
 
 
 
 
 
 
 
 
 
301
  sx={{
302
+ display: 'flex',
303
+ flexDirection: 'column',
304
+ alignItems: 'center',
305
+ width: '100%',
306
+ margin: 'auto',
307
+ py: 8,
308
  }}
309
  >
310
+ {/* Logo */}
311
+ <Box
312
+ component="img"
313
+ src="/smolagents.webp"
314
+ alt="smolagents"
315
+ sx={{ width: 80, height: 80, mb: 2.5, display: 'block' }}
316
+ />
317
 
318
+ {/* Title */}
319
+ <Typography
320
+ variant="h2"
321
+ sx={{
322
+ fontWeight: 800,
323
+ color: 'var(--text)',
324
+ mb: 1,
325
+ letterSpacing: '-0.02em',
326
+ fontSize: { xs: '1.8rem', md: '2.4rem' },
327
+ }}
328
+ >
329
+ ML Intern
330
+ </Typography>
 
 
 
331
 
332
+ {/* Description */}
333
+ <Typography
334
+ variant="body1"
335
+ sx={{
336
+ color: 'var(--muted-text)',
337
+ maxWidth: 480,
338
+ mb: 4,
339
+ lineHeight: 1.7,
340
+ fontSize: '0.9rem',
341
+ textAlign: 'center',
342
+ px: 2,
343
+ '& strong': { color: 'var(--text)', fontWeight: 600 },
344
+ }}
345
+ >
346
+ Your personal <strong>ML agent</strong>. It reads <strong>papers</strong>, finds <strong>datasets</strong>, trains <strong>models</strong>, and iterates until the numbers go up. Instructions in. Trained model out.
347
+ </Typography>
348
+
349
+ {/* ── Checklist ──────────────────────────────────────────── */}
350
+ <Box
351
+ sx={{
352
+ width: '100%',
353
+ maxWidth: 520,
354
+ bgcolor: 'var(--surface)',
355
+ border: '1px solid var(--border)',
356
+ borderRadius: '12px',
357
+ overflow: 'hidden',
358
+ mx: 2,
359
+ }}
360
+ >
361
+ {isDevUser ? (
362
+ /* Dev mode: single step */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  <ChecklistStep
364
  stepNumber={1}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  title="Start Session"
366
  description="Launch an AI agent session for ML engineering."
367
+ status="active"
 
368
  actionLabel="Start Session"
369
  actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
370
  onAction={handleStartSession}
371
  loading={isCreating}
372
  isLast
373
  />
374
+ ) : inIframe ? (
375
+ /* Iframe: 2 steps */
376
+ <>
377
+ <ChecklistStep
378
+ stepNumber={1}
379
+ title="Join ML Agent Explorers"
380
+ description="Get free access to GPUs, inference APIs, and Hub resources."
381
+ status={isOrgMember ? 'completed' : 'active'}
382
+ actionLabel="Join Organization"
383
+ actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
384
+ onAction={handleJoinOrg}
385
+ />
386
+ <ChecklistStep
387
+ stepNumber={2}
388
+ title="Open ML Intern"
389
+ description="Open the agent in a full browser tab to get started."
390
+ status={isOrgMember ? 'active' : 'locked'}
391
+ lockedReason="Join the organization first."
392
+ actionLabel="Open ML Intern"
393
+ actionIcon={<OpenInNewIcon sx={{ fontSize: 16 }} />}
394
+ actionHref={spaceHost}
395
+ isLast
396
+ />
397
+ </>
398
+ ) : (
399
+ /* Direct access: 3 steps */
400
+ <>
401
+ <ChecklistStep
402
+ stepNumber={1}
403
+ title="Sign in with Hugging Face"
404
+ description="Authenticate to access GPU resources and model APIs."
405
+ status={signInStatus}
406
+ actionLabel="Sign in"
407
+ actionIcon={<LoginIcon sx={{ fontSize: 16 }} />}
408
+ onAction={() => triggerLogin()}
409
+ />
410
+ <ChecklistStep
411
+ stepNumber={2}
412
+ title="Join ML Agent Explorers"
413
+ description="Get free access to GPUs, inference APIs, and Hub resources."
414
+ status={joinOrgStatus}
415
+ lockedReason="Sign in first to continue."
416
+ actionLabel="Join Organization"
417
+ actionIcon={<GroupAddIcon sx={{ fontSize: 16 }} />}
418
+ onAction={handleJoinOrg}
419
+ />
420
+ <ChecklistStep
421
+ stepNumber={3}
422
+ title="Start Session"
423
+ description="Launch an AI agent session for ML engineering."
424
+ status={startStatus}
425
+ lockedReason="Complete the steps above to continue."
426
+ actionLabel="Start Session"
427
+ actionIcon={<RocketLaunchIcon sx={{ fontSize: 16 }} />}
428
+ onAction={handleStartSession}
429
+ loading={isCreating}
430
+ isLast
431
+ />
432
+ </>
433
+ )}
434
+ </Box>
435
+
436
+ {/* Polling hint when waiting for org join */}
437
+ {isAuthenticated && !isOrgMember && !isDevUser && !inIframe && (
438
+ <Typography
439
+ variant="caption"
440
+ sx={{ mt: 2, color: 'var(--muted-text)', fontSize: '0.75rem', textAlign: 'center' }}
441
+ >
442
+ This page updates automatically when you join the organization.
443
+ </Typography>
444
  )}
 
445
 
446
+ {/* Error */}
447
+ {error && (
448
+ <Alert
449
+ severity="warning"
450
+ variant="outlined"
451
+ onClose={() => setError(null)}
452
+ sx={{
453
+ mt: 3,
454
+ maxWidth: 400,
455
+ fontSize: '0.8rem',
456
+ borderColor: HF_ORANGE,
457
+ color: 'var(--text)',
458
+ }}
459
+ >
460
+ {error}
461
+ </Alert>
462
+ )}
463
+
464
+ {/* Footnote */}
465
  <Typography
466
  variant="caption"
467
+ sx={{ mt: 4, color: 'var(--muted-text)', opacity: 0.5, fontSize: '0.7rem' }}
468
  >
469
+ Conversations are stored locally in your browser.
470
  </Typography>
471
+ </Box>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  </Box>
473
  );
474
  }
frontend/src/hooks/useAgentChat.ts CHANGED
@@ -371,7 +371,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
371
  } catch {
372
  return null;
373
  }
374
- }, [sessionId, setNeedsAttention]);
375
 
376
  // -- useChat from Vercel AI SDK -----------------------------------------
377
  const chat = useChat({
@@ -447,6 +447,33 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
447
  }
448
  return;
449
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  logger.error('useChat error:', error);
451
  if (isActiveRef.current) {
452
  useAgentStore.getState().setError(error.message);
@@ -594,7 +621,10 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
594
  /** Read the event stream from GET /api/events and forward to side-channel. */
595
  const consumeEventStream = async (signal: AbortSignal) => {
596
  try {
597
- const res = await apiFetch(`/api/events/${sessionId}`, {
 
 
 
598
  headers: { 'Accept': 'text/event-stream' },
599
  signal,
600
  });
@@ -602,6 +632,71 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
602
 
603
  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
604
  let buf = '';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  while (true) {
606
  const { value, done } = await reader.read();
607
  if (done || signal.aborted) break;
@@ -609,59 +704,21 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
609
  const lines = buf.split('\n');
610
  buf = lines.pop() || '';
611
  for (const line of lines) {
612
- const trimmed = line.trim();
613
- if (!trimmed.startsWith('data: ')) continue;
614
- try {
615
- const event = JSON.parse(trimmed.slice(6));
616
- // Forward to side-channel for real-time UI updates
617
- const et = event.event_type as string;
618
- if (et === 'processing') sideChannel.onProcessing();
619
- else if (et === 'assistant_chunk') sideChannel.onStreaming();
620
- else if (et === 'tool_call') {
621
- const t = event.data?.tool as string;
622
- const d = event.data?.arguments?.description as string | undefined;
623
- sideChannel.onToolRunning(t, d);
624
- sideChannel.onToolCallPanel(t, (event.data?.arguments || {}) as Record<string, unknown>);
625
- } else if (et === 'tool_output') {
626
- sideChannel.onToolOutputPanel(
627
- event.data?.tool as string,
628
- event.data?.tool_call_id as string,
629
- event.data?.output as string,
630
- event.data?.success as boolean,
631
- );
632
- } else if (et === 'tool_state_change') {
633
- const state = event.data?.state as string;
634
- const toolName = event.data?.tool as string;
635
- if (state === 'running' && toolName) sideChannel.onToolRunning(toolName);
636
- } else if (et === 'turn_complete' || et === 'error' || et === 'interrupted') {
637
- sideChannel.onProcessingDone();
638
- stopReconnect();
639
- // Final hydration to get the complete message state
640
- const result = await hydrateMessages();
641
- if (result) {
642
- const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
643
- if (uiMsgs.length > 0) {
644
- chat.setMessages(uiMsgs);
645
- saveMessages(sessionId, uiMsgs);
646
- }
647
- }
648
- return;
649
- } else if (et === 'approval_required') {
650
- sideChannel.onApprovalRequired(
651
- (event.data?.tools || []) as Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>,
652
- );
653
- stopReconnect();
654
- const result = await hydrateMessages();
655
- if (result) {
656
- const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
657
- if (uiMsgs.length > 0) {
658
- chat.setMessages(uiMsgs);
659
- saveMessages(sessionId, uiMsgs);
660
- }
661
- }
662
- return;
663
- }
664
- } catch { /* ignore parse errors */ }
665
  }
666
  }
667
  } catch {
@@ -830,6 +887,9 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
830
  : approval.namespace,
831
  }));
832
 
 
 
 
833
  useAgentStore.getState().setJobsUpgradeRequired(null);
834
  return approveTools(approvals);
835
  }, [approveTools]);
 
371
  } catch {
372
  return null;
373
  }
374
+ }, [sessionId, setNeedsAttention, updateSession]);
375
 
376
  // -- useChat from Vercel AI SDK -----------------------------------------
377
  const chat = useChat({
 
447
  }
448
  return;
449
  }
450
+ if (error.message === 'HF_JOBS_INVALID_NAMESPACE') {
451
+ // Saved preference is no longer one of the user's eligible namespaces
452
+ // (e.g. they left the org). Clear it and reopen the picker.
453
+ const typed = error as Error & {
454
+ detail?: Record<string, unknown>;
455
+ approvals?: Array<{
456
+ tool_call_id: string;
457
+ approved: boolean;
458
+ feedback?: string | null;
459
+ edited_script?: string | null;
460
+ namespace?: string | null;
461
+ }>;
462
+ };
463
+ useAgentStore.getState().setPreferredJobsNamespace(null);
464
+ void hydrateFromBackend();
465
+ if (isActiveRef.current) {
466
+ useAgentStore.getState().setJobsUpgradeRequired({
467
+ approvals: typed.approvals || [],
468
+ toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
469
+ message: String(typed.detail?.message || 'Pick a different organization for this job run.'),
470
+ eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
471
+ plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
472
+ mode: 'namespace',
473
+ });
474
+ }
475
+ return;
476
+ }
477
  logger.error('useChat error:', error);
478
  if (isActiveRef.current) {
479
  useAgentStore.getState().setError(error.message);
 
621
  /** Read the event stream from GET /api/events and forward to side-channel. */
622
  const consumeEventStream = async (signal: AbortSignal) => {
623
  try {
624
+ const lastEventKey = `hf-agent-last-event:${sessionId}`;
625
+ const lastSeq = localStorage.getItem(lastEventKey);
626
+ const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
627
+ const res = await apiFetch(`/api/events/${sessionId}${qs}`, {
628
  headers: { 'Accept': 'text/event-stream' },
629
  signal,
630
  });
 
632
 
633
  const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();
634
  let buf = '';
635
+ let eventId: string | null = null;
636
+ let eventData = '';
637
+ const dispatch = async () => {
638
+ if (!eventData.trim()) {
639
+ eventId = null;
640
+ eventData = '';
641
+ return false;
642
+ }
643
+ const event = JSON.parse(eventData.trim());
644
+ const seq = event.seq ?? (eventId ? Number(eventId) : undefined);
645
+ if (Number.isFinite(seq)) {
646
+ localStorage.setItem(lastEventKey, String(seq));
647
+ }
648
+ eventId = null;
649
+ eventData = '';
650
+ // Forward to side-channel for real-time UI updates
651
+ const et = event.event_type as string;
652
+ if (et === 'processing') sideChannel.onProcessing();
653
+ else if (et === 'assistant_chunk') sideChannel.onStreaming();
654
+ else if (et === 'tool_call') {
655
+ const t = event.data?.tool as string;
656
+ const d = event.data?.arguments?.description as string | undefined;
657
+ sideChannel.onToolRunning(t, d);
658
+ sideChannel.onToolCallPanel(t, (event.data?.arguments || {}) as Record<string, unknown>);
659
+ } else if (et === 'tool_output') {
660
+ sideChannel.onToolOutputPanel(
661
+ event.data?.tool as string,
662
+ event.data?.tool_call_id as string,
663
+ event.data?.output as string,
664
+ event.data?.success as boolean,
665
+ );
666
+ } else if (et === 'tool_state_change') {
667
+ const state = event.data?.state as string;
668
+ const toolName = event.data?.tool as string;
669
+ if (state === 'running' && toolName) sideChannel.onToolRunning(toolName);
670
+ } else if (et === 'turn_complete' || et === 'error' || et === 'interrupted') {
671
+ sideChannel.onProcessingDone();
672
+ stopReconnect();
673
+ // Final hydration to get the complete message state
674
+ const result = await hydrateMessages();
675
+ if (result) {
676
+ const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
677
+ if (uiMsgs.length > 0) {
678
+ chat.setMessages(uiMsgs);
679
+ saveMessages(sessionId, uiMsgs);
680
+ }
681
+ }
682
+ return true;
683
+ } else if (et === 'approval_required') {
684
+ sideChannel.onApprovalRequired(
685
+ (event.data?.tools || []) as Array<{ tool: string; arguments: Record<string, unknown>; tool_call_id: string }>,
686
+ );
687
+ stopReconnect();
688
+ const result = await hydrateMessages();
689
+ if (result) {
690
+ const uiMsgs = llmMessagesToUIMessages(result.data, result.pendingIds, chatActionsRef.current.messages);
691
+ if (uiMsgs.length > 0) {
692
+ chat.setMessages(uiMsgs);
693
+ saveMessages(sessionId, uiMsgs);
694
+ }
695
+ }
696
+ return true;
697
+ }
698
+ return false;
699
+ };
700
  while (true) {
701
  const { value, done } = await reader.read();
702
  if (done || signal.aborted) break;
 
704
  const lines = buf.split('\n');
705
  buf = lines.pop() || '';
706
  for (const line of lines) {
707
+ const trimmed = line.replace(/\r$/, '');
708
+ if (trimmed === '') {
709
+ try {
710
+ if (await dispatch()) return;
711
+ } catch { /* ignore parse errors */ }
712
+ continue;
713
+ }
714
+ if (trimmed.startsWith(':')) continue;
715
+ if (trimmed.startsWith('id:')) {
716
+ eventId = trimmed.slice(3).trim();
717
+ continue;
718
+ }
719
+ if (trimmed.startsWith('data:')) {
720
+ eventData += trimmed.slice(5).trimStart() + '\n';
721
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
722
  }
723
  }
724
  } catch {
 
887
  : approval.namespace,
888
  }));
889
 
890
+ // Remember this choice so the picker doesn't reappear for every
891
+ // subsequent hf_jobs call.
892
+ useAgentStore.getState().setPreferredJobsNamespace(namespace);
893
  useAgentStore.getState().setJobsUpgradeRequired(null);
894
  return approveTools(approvals);
895
  }, [approveTools]);
frontend/src/lib/sse-chat-transport.ts CHANGED
@@ -42,35 +42,66 @@ function nextPartId(prefix: string): string {
42
  return `${prefix}-${Date.now()}-${++partIdCounter}`;
43
  }
44
 
 
 
 
 
45
  /** Parse an SSE text stream into AgentEvent objects. */
46
- function createSSEParserStream(): TransformStream<string, AgentEvent> {
47
  let buffer = '';
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return new TransformStream<string, AgentEvent>({
49
  transform(chunk, controller) {
50
  buffer += chunk;
51
  const lines = buffer.split('\n');
52
  // Keep the last (possibly incomplete) line in the buffer
53
  buffer = lines.pop() || '';
54
- for (const line of lines) {
55
- const trimmed = line.trim();
56
- if (trimmed.startsWith('data: ')) {
57
- try {
58
- const json = JSON.parse(trimmed.slice(6));
59
- controller.enqueue(json as AgentEvent);
60
- } catch {
61
- logger.warn('SSE parse error:', trimmed);
62
- }
 
 
63
  }
64
  }
65
  },
66
  flush(controller) {
67
- // Process any remaining data in buffer
68
- if (buffer.trim().startsWith('data: ')) {
69
- try {
70
- const json = JSON.parse(buffer.trim().slice(6));
71
- controller.enqueue(json as AgentEvent);
72
- } catch { /* ignore incomplete */ }
73
  }
 
74
  },
75
  });
76
  }
@@ -226,12 +257,17 @@ function createEventToChunkStream(sideChannel: SideChannelCallbacks): TransformS
226
  const state = (event.data?.state as string) || '';
227
  const toolName = (event.data?.tool as string) || '';
228
  const jobUrl = (event.data?.jobUrl as string) || undefined;
 
 
229
 
230
  if (tcId.startsWith('plan_tool')) break;
231
 
232
  if (jobUrl && tcId) {
233
  useAgentStore.getState().setJobUrl(tcId, jobUrl);
234
  }
 
 
 
235
  if (state === 'running' && toolName) {
236
  sideChannel.onToolRunning(toolName);
237
  }
@@ -320,7 +356,14 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
320
  const approved = p.approval?.approved ?? true;
321
  // Get edited script from agentStore if available
322
  const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
323
- const namespace = useAgentStore.getState().getApprovalNamespace(p.toolCallId);
 
 
 
 
 
 
 
324
  return {
325
  tool_call_id: p.toolCallId,
326
  approved,
@@ -388,6 +431,20 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
388
  throw err;
389
  }
390
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  if (!response.ok) {
392
  const errorText = await response.text().catch(() => 'Request failed');
393
  throw new Error(`Chat request failed: ${response.status} ${errorText}`);
@@ -400,7 +457,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
400
  // Pipe: response bytes → text → SSE events → UIMessageChunks
401
  return response.body
402
  .pipeThrough(new TextDecoderStream())
403
- .pipeThrough(createSSEParserStream())
404
  .pipeThrough(createEventToChunkStream(this.sideChannel));
405
  }
406
 
@@ -415,7 +472,9 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
415
  if (!info.is_processing) return null;
416
 
417
  // Session is mid-turn — subscribe to its event broadcast.
418
- const response = await apiFetch(`/api/events/${this.sessionId}`, {
 
 
419
  headers: { 'Accept': 'text/event-stream' },
420
  });
421
  if (!response.ok || !response.body) return null;
@@ -424,7 +483,7 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
424
 
425
  return response.body
426
  .pipeThrough(new TextDecoderStream())
427
- .pipeThrough(createSSEParserStream())
428
  .pipeThrough(createEventToChunkStream(this.sideChannel));
429
  } catch {
430
  return null;
 
42
  return `${prefix}-${Date.now()}-${++partIdCounter}`;
43
  }
44
 
45
+ function lastEventKey(sessionId: string): string {
46
+ return `hf-agent-last-event:${sessionId}`;
47
+ }
48
+
49
  /** Parse an SSE text stream into AgentEvent objects. */
50
+ function createSSEParserStream(sessionId: string): TransformStream<string, AgentEvent> {
51
  let buffer = '';
52
+ let eventId: string | null = null;
53
+ let data = '';
54
+
55
+ const dispatch = (controller: TransformStreamDefaultController<AgentEvent>) => {
56
+ if (!data.trim()) {
57
+ eventId = null;
58
+ data = '';
59
+ return;
60
+ }
61
+ try {
62
+ const json = JSON.parse(data.trim()) as AgentEvent;
63
+ const seq = json.seq ?? (eventId ? Number(eventId) : undefined);
64
+ if (Number.isFinite(seq)) {
65
+ json.seq = seq;
66
+ localStorage.setItem(lastEventKey(sessionId), String(seq));
67
+ }
68
+ controller.enqueue(json);
69
+ } catch {
70
+ logger.warn('SSE parse error:', data.trim());
71
+ } finally {
72
+ eventId = null;
73
+ data = '';
74
+ }
75
+ };
76
+
77
  return new TransformStream<string, AgentEvent>({
78
  transform(chunk, controller) {
79
  buffer += chunk;
80
  const lines = buffer.split('\n');
81
  // Keep the last (possibly incomplete) line in the buffer
82
  buffer = lines.pop() || '';
83
+ for (const rawLine of lines) {
84
+ const line = rawLine.replace(/\r$/, '');
85
+ if (line === '') {
86
+ dispatch(controller);
87
+ continue;
88
+ }
89
+ if (line.startsWith(':')) continue;
90
+ if (line.startsWith('id:')) {
91
+ eventId = line.slice(3).trim();
92
+ } else if (line.startsWith('data:')) {
93
+ data += line.slice(5).trimStart() + '\n';
94
  }
95
  }
96
  },
97
  flush(controller) {
98
+ const line = buffer.replace(/\r$/, '');
99
+ if (line.startsWith('id:')) {
100
+ eventId = line.slice(3).trim();
101
+ } else if (line.startsWith('data:')) {
102
+ data += line.slice(5).trimStart() + '\n';
 
103
  }
104
+ dispatch(controller);
105
  },
106
  });
107
  }
 
257
  const state = (event.data?.state as string) || '';
258
  const toolName = (event.data?.tool as string) || '';
259
  const jobUrl = (event.data?.jobUrl as string) || undefined;
260
+ const trackioSpaceId = (event.data?.trackioSpaceId as string) || undefined;
261
+ const trackioProject = (event.data?.trackioProject as string) || undefined;
262
 
263
  if (tcId.startsWith('plan_tool')) break;
264
 
265
  if (jobUrl && tcId) {
266
  useAgentStore.getState().setJobUrl(tcId, jobUrl);
267
  }
268
+ if (trackioSpaceId && tcId) {
269
+ useAgentStore.getState().setTrackioDashboard(tcId, trackioSpaceId, trackioProject);
270
+ }
271
  if (state === 'running' && toolName) {
272
  sideChannel.onToolRunning(toolName);
273
  }
 
356
  const approved = p.approval?.approved ?? true;
357
  // Get edited script from agentStore if available
358
  const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
359
+ const explicitNamespace = useAgentStore.getState().getApprovalNamespace(p.toolCallId);
360
+ // Fall back to the user's persisted choice so we don't re-prompt
361
+ // every hf_jobs call. Backend will 400 if the saved namespace is
362
+ // no longer valid; the error handler clears the preference and
363
+ // reopens the picker.
364
+ const preferred = useAgentStore.getState().preferredJobsNamespace;
365
+ const namespace = explicitNamespace
366
+ ?? (approved && p.toolName === 'hf_jobs' ? preferred ?? null : null);
367
  return {
368
  tool_call_id: p.toolCallId,
369
  approved,
 
431
  throw err;
432
  }
433
  }
434
+ if (response.status === 400) {
435
+ const payload = await response.json().catch(() => null);
436
+ if (payload?.detail?.error === 'hf_jobs_invalid_namespace') {
437
+ // Stored namespace is no longer eligible — surface so the UI can
438
+ // clear the saved preference and reopen the picker.
439
+ const err = new Error('HF_JOBS_INVALID_NAMESPACE') as Error & {
440
+ detail?: Record<string, unknown>;
441
+ approvals?: Array<Record<string, unknown>>;
442
+ };
443
+ err.detail = payload.detail as Record<string, unknown>;
444
+ err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
445
+ throw err;
446
+ }
447
+ }
448
  if (!response.ok) {
449
  const errorText = await response.text().catch(() => 'Request failed');
450
  throw new Error(`Chat request failed: ${response.status} ${errorText}`);
 
457
  // Pipe: response bytes → text → SSE events → UIMessageChunks
458
  return response.body
459
  .pipeThrough(new TextDecoderStream())
460
+ .pipeThrough(createSSEParserStream(sessionId))
461
  .pipeThrough(createEventToChunkStream(this.sideChannel));
462
  }
463
 
 
472
  if (!info.is_processing) return null;
473
 
474
  // Session is mid-turn — subscribe to its event broadcast.
475
+ const lastSeq = localStorage.getItem(lastEventKey(this.sessionId));
476
+ const qs = lastSeq ? `?after=${encodeURIComponent(lastSeq)}` : '';
477
+ const response = await apiFetch(`/api/events/${this.sessionId}${qs}`, {
478
  headers: { 'Accept': 'text/event-stream' },
479
  });
480
  if (!response.ok || !response.body) return null;
 
483
 
484
  return response.body
485
  .pipeThrough(new TextDecoderStream())
486
+ .pipeThrough(createSSEParserStream(this.sessionId))
487
  .pipeThrough(createEventToChunkStream(this.sideChannel));
488
  } catch {
489
  return null;
frontend/src/store/agentStore.ts CHANGED
@@ -141,12 +141,21 @@ interface AgentStore {
141
  // Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
142
  approvalNamespaces: Record<string, string>;
143
 
 
 
 
 
144
  // Job URLs (tool_call_id -> job URL) for HF jobs
145
  jobUrls: Record<string, string>;
146
 
147
  // Job statuses (tool_call_id -> job status) for HF jobs
148
  jobStatuses: Record<string, string>;
149
 
 
 
 
 
 
150
  // Tool error states (tool_call_id -> true if errored) - persisted across renders
151
  toolErrors: Record<string, boolean>;
152
 
@@ -194,12 +203,17 @@ interface AgentStore {
194
  getApprovalNamespace: (toolCallId: string) => string | undefined;
195
  clearApprovalNamespaces: () => void;
196
 
 
 
197
  setJobUrl: (toolCallId: string, jobUrl: string) => void;
198
  getJobUrl: (toolCallId: string) => string | undefined;
199
 
200
  setJobStatus: (toolCallId: string, status: string) => void;
201
  getJobStatus: (toolCallId: string) => string | undefined;
202
 
 
 
 
203
  setToolError: (toolCallId: string, hasError: boolean) => void;
204
  getToolError: (toolCallId: string) => boolean | undefined;
205
 
@@ -264,6 +278,48 @@ function saveRejectedTools(rejected: Record<string, boolean>): void {
264
  }
265
  }
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  export const useAgentStore = create<AgentStore>()((set, get) => ({
268
  sessionStates: {},
269
  activeSessionId: null,
@@ -285,8 +341,10 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
285
 
286
  editedScripts: {},
287
  approvalNamespaces: {},
 
288
  jobUrls: {},
289
  jobStatuses: {},
 
290
  toolErrors: loadToolErrors(),
291
  rejectedTools: loadRejectedTools(),
292
 
@@ -465,6 +523,11 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
465
 
466
  clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
467
 
 
 
 
 
 
468
  // ── Job URLs ────────────────────────────────────────────────────────
469
 
470
  setJobUrl: (toolCallId, jobUrl) => {
@@ -485,6 +548,26 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
485
 
486
  getJobStatus: (toolCallId) => get().jobStatuses[toolCallId],
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  // ── Tool Errors ─────────────────────────────────────────────────────
489
 
490
  setToolError: (toolCallId, hasError) => {
 
141
  // Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
142
  approvalNamespaces: Record<string, string>;
143
 
144
+ // Persisted preferred namespace for hf_jobs (auto-applied to future approvals
145
+ // so the user only picks once)
146
+ preferredJobsNamespace: string | null;
147
+
148
  // Job URLs (tool_call_id -> job URL) for HF jobs
149
  jobUrls: Record<string, string>;
150
 
151
  // Job statuses (tool_call_id -> job status) for HF jobs
152
  jobStatuses: Record<string, string>;
153
 
154
+ // Trackio dashboard config per tool call (tool_call_id -> {spaceId, project?})
155
+ // Set by hf_jobs / sandbox_create tools when the agent declares trackio_space_id;
156
+ // the UI uses it to embed the live dashboard via an iframe.
157
+ trackioDashboards: Record<string, { spaceId: string; project?: string }>;
158
+
159
  // Tool error states (tool_call_id -> true if errored) - persisted across renders
160
  toolErrors: Record<string, boolean>;
161
 
 
203
  getApprovalNamespace: (toolCallId: string) => string | undefined;
204
  clearApprovalNamespaces: () => void;
205
 
206
+ setPreferredJobsNamespace: (namespace: string | null) => void;
207
+
208
  setJobUrl: (toolCallId: string, jobUrl: string) => void;
209
  getJobUrl: (toolCallId: string) => string | undefined;
210
 
211
  setJobStatus: (toolCallId: string, status: string) => void;
212
  getJobStatus: (toolCallId: string) => string | undefined;
213
 
214
+ setTrackioDashboard: (toolCallId: string, spaceId: string, project?: string) => void;
215
+ getTrackioDashboard: (toolCallId: string) => { spaceId: string; project?: string } | undefined;
216
+
217
  setToolError: (toolCallId: string, hasError: boolean) => void;
218
  getToolError: (toolCallId: string) => boolean | undefined;
219
 
 
278
  }
279
  }
280
 
281
+ // Trackio dashboards survive a page reload — without persistence the iframe
282
+ // disappears whenever the user refreshes mid-job, which is the exact moment
283
+ // they'd want to keep watching it.
284
+ function loadTrackioDashboards(): Record<string, { spaceId: string; project?: string }> {
285
+ try {
286
+ const stored = localStorage.getItem('hf-agent-trackio-dashboards');
287
+ return stored ? JSON.parse(stored) : {};
288
+ } catch {
289
+ return {};
290
+ }
291
+ }
292
+
293
+ function saveTrackioDashboards(dashboards: Record<string, { spaceId: string; project?: string }>): void {
294
+ try {
295
+ localStorage.setItem('hf-agent-trackio-dashboards', JSON.stringify(dashboards));
296
+ } catch (e) {
297
+ console.warn('Failed to persist trackio dashboards:', e);
298
+ }
299
+ }
300
+
301
+ const PREFERRED_JOBS_NAMESPACE_KEY = 'hf-agent-preferred-jobs-namespace';
302
+
303
+ function loadPreferredJobsNamespace(): string | null {
304
+ try {
305
+ return localStorage.getItem(PREFERRED_JOBS_NAMESPACE_KEY);
306
+ } catch {
307
+ return null;
308
+ }
309
+ }
310
+
311
+ function savePreferredJobsNamespace(namespace: string | null): void {
312
+ try {
313
+ if (namespace) {
314
+ localStorage.setItem(PREFERRED_JOBS_NAMESPACE_KEY, namespace);
315
+ } else {
316
+ localStorage.removeItem(PREFERRED_JOBS_NAMESPACE_KEY);
317
+ }
318
+ } catch (e) {
319
+ console.warn('Failed to persist preferred jobs namespace:', e);
320
+ }
321
+ }
322
+
323
  export const useAgentStore = create<AgentStore>()((set, get) => ({
324
  sessionStates: {},
325
  activeSessionId: null,
 
341
 
342
  editedScripts: {},
343
  approvalNamespaces: {},
344
+ preferredJobsNamespace: loadPreferredJobsNamespace(),
345
  jobUrls: {},
346
  jobStatuses: {},
347
+ trackioDashboards: loadTrackioDashboards(),
348
  toolErrors: loadToolErrors(),
349
  rejectedTools: loadRejectedTools(),
350
 
 
523
 
524
  clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
525
 
526
+ setPreferredJobsNamespace: (namespace) => {
527
+ savePreferredJobsNamespace(namespace);
528
+ set({ preferredJobsNamespace: namespace });
529
+ },
530
+
531
  // ── Job URLs ────────────────────────────────────────────────────────
532
 
533
  setJobUrl: (toolCallId, jobUrl) => {
 
548
 
549
  getJobStatus: (toolCallId) => get().jobStatuses[toolCallId],
550
 
551
+ // ── Trackio Dashboards ──────────────────────────────────────────────
552
+
553
+ setTrackioDashboard: (toolCallId, spaceId, project) => {
554
+ set((state) => {
555
+ const existing = state.trackioDashboards[toolCallId];
556
+ // Don't churn the object if nothing changed (avoids extra renders).
557
+ if (existing && existing.spaceId === spaceId && existing.project === project) {
558
+ return {};
559
+ }
560
+ const updated = {
561
+ ...state.trackioDashboards,
562
+ [toolCallId]: { spaceId, ...(project ? { project } : {}) },
563
+ };
564
+ saveTrackioDashboards(updated);
565
+ return { trackioDashboards: updated };
566
+ });
567
+ },
568
+
569
+ getTrackioDashboard: (toolCallId) => get().trackioDashboards[toolCallId],
570
+
571
  // ── Tool Errors ─────────────────────────────────────────────────────
572
 
573
  setToolError: (toolCallId, hasError) => {
frontend/src/store/sessionStore.ts CHANGED
@@ -20,6 +20,14 @@ interface SessionStore {
20
  markExpired: (id: string) => void;
21
  /** Clear the expired flag (used after restore-with-summary succeeds). */
22
  clearExpired: (id: string) => void;
 
 
 
 
 
 
 
 
23
  /** Atomically swap a session's id in the list + both localStorage caches.
24
  * Used when we rehydrate an expired session into a freshly-created backend
25
  * session — preserves title, timestamps, and messages. */
@@ -76,6 +84,45 @@ export const useSessionStore = create<SessionStore>()(
76
  }));
77
  },
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  renameSession: (oldId: string, newId: string) => {
80
  if (oldId === newId) return;
81
  moveMessages(oldId, newId);
 
20
  markExpired: (id: string) => void;
21
  /** Clear the expired flag (used after restore-with-summary succeeds). */
22
  clearExpired: (id: string) => void;
23
+ /** Merge durable server-side sessions into local sidebar metadata. */
24
+ mergeServerSessions: (sessions: Array<{
25
+ session_id: string;
26
+ title?: string | null;
27
+ created_at: string;
28
+ is_active?: boolean;
29
+ pending_approval?: unknown[] | null;
30
+ }>) => void;
31
  /** Atomically swap a session's id in the list + both localStorage caches.
32
  * Used when we rehydrate an expired session into a freshly-created backend
33
  * session — preserves title, timestamps, and messages. */
 
84
  }));
85
  },
86
 
87
+ mergeServerSessions: (serverSessions) => {
88
+ set((state) => {
89
+ const byId = new Map(state.sessions.map((s) => [s.id, s]));
90
+ const merged = [...state.sessions];
91
+ for (const server of serverSessions) {
92
+ const id = server.session_id;
93
+ if (!id) continue;
94
+ const existing = byId.get(id);
95
+ if (existing) {
96
+ const updated = {
97
+ ...existing,
98
+ title: server.title || existing.title,
99
+ isActive: server.is_active ?? existing.isActive,
100
+ needsAttention: Boolean(server.pending_approval?.length) || existing.needsAttention,
101
+ expired: false,
102
+ };
103
+ const idx = merged.findIndex((s) => s.id === id);
104
+ if (idx >= 0) merged[idx] = updated;
105
+ byId.set(id, updated);
106
+ continue;
107
+ }
108
+ const newSession: SessionMeta = {
109
+ id,
110
+ title: server.title || `Chat ${merged.length + 1}`,
111
+ createdAt: server.created_at || new Date().toISOString(),
112
+ isActive: server.is_active ?? true,
113
+ needsAttention: Boolean(server.pending_approval?.length),
114
+ expired: false,
115
+ };
116
+ merged.push(newSession);
117
+ byId.set(id, newSession);
118
+ }
119
+ return {
120
+ sessions: merged,
121
+ activeSessionId: state.activeSessionId || merged[merged.length - 1]?.id || null,
122
+ };
123
+ });
124
+ },
125
+
126
  renameSession: (oldId: string, newId: string) => {
127
  if (oldId === newId) return;
128
  moveMessages(oldId, newId);
frontend/src/types/events.ts CHANGED
@@ -24,6 +24,7 @@ export type EventType =
24
  export interface AgentEvent {
25
  event_type: EventType;
26
  data?: Record<string, unknown>;
 
27
  }
28
 
29
  export interface ReadyEventData {
 
24
  export interface AgentEvent {
25
  event_type: EventType;
26
  data?: Record<string, unknown>;
27
+ seq?: number;
28
  }
29
 
30
  export interface ReadyEventData {
pyproject.toml CHANGED
@@ -13,7 +13,7 @@ dependencies = [
13
  "requests>=2.33.0",
14
  "litellm>=1.83.0",
15
  "boto3>=1.35.0",
16
- "huggingface-hub>=1.0.1",
17
  "fastmcp>=3.2.0",
18
  "prompt-toolkit>=3.0.0",
19
  "thefuzz>=0.22.1",
@@ -27,6 +27,7 @@ dependencies = [
27
  "httpx>=0.27.0",
28
  "websockets>=13.0",
29
  "apscheduler>=3.10,<4",
 
30
  ]
31
 
32
  [project.optional-dependencies]
@@ -42,7 +43,7 @@ eval = [
42
  # Development and testing dependencies
43
  dev = [
44
  "pytest>=9.0.2",
45
- "pytest-asyncio>=0.26.0",
46
  ]
47
 
48
  # All dependencies (eval + dev)
@@ -58,7 +59,20 @@ requires = ["setuptools>=64"]
58
  build-backend = "setuptools.build_meta"
59
 
60
  [tool.setuptools.packages.find]
61
- include = ["agent*"]
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  [tool.uv]
64
  package = true
 
13
  "requests>=2.33.0",
14
  "litellm>=1.83.0",
15
  "boto3>=1.35.0",
16
+ "huggingface-hub>=1.12.0",
17
  "fastmcp>=3.2.0",
18
  "prompt-toolkit>=3.0.0",
19
  "thefuzz>=0.22.1",
 
27
  "httpx>=0.27.0",
28
  "websockets>=13.0",
29
  "apscheduler>=3.10,<4",
30
+ "pymongo>=4.17.0",
31
  ]
32
 
33
  [project.optional-dependencies]
 
43
  # Development and testing dependencies
44
  dev = [
45
  "pytest>=9.0.2",
46
+ "pytest-asyncio>=1.2.0",
47
  ]
48
 
49
  # All dependencies (eval + dev)
 
59
  build-backend = "setuptools.build_meta"
60
 
61
  [tool.setuptools.packages.find]
62
+ # `configs` ships the JSON files loaded by agent.main.CLI_CONFIG_PATH at
63
+ # runtime (resolves to <site-packages>/configs/cli_agent_config.json).
64
+ # Without it, `uv tool install` / `pip install` produce a broken install
65
+ # that imports fine but crashes at startup with FileNotFoundError.
66
+ include = ["agent*", "configs"]
67
+
68
+ [tool.setuptools.package-data]
69
+ configs = ["*.json"]
70
+ # Agent data files: system prompts loaded by ContextManager._load_system_prompt
71
+ # at runtime (`<site-packages>/agent/prompts/system_prompt_v3.yaml`), plus the
72
+ # package README. Without these, headless_main hangs forever — submission_loop
73
+ # crashes with FileNotFoundError but headless_main doesn't check agent_task.done()
74
+ # and just keeps awaiting the "ready" event_queue item that will never come.
75
+ agent = ["README.md", "prompts/*.yaml"]
76
 
77
  [tool.uv]
78
  package = true
scripts/build_kpis.py CHANGED
@@ -38,15 +38,27 @@ re-running the same hour overwrites.
38
  llm_calls — count of llm_call events
39
  tokens_prompt / _completion / _cache_read / _cache_creation
40
  cost_usd — sum of llm_call.cost_usd
 
41
  cache_hit_ratio — cache_read / (cache_read + prompt)
42
- tool_success_rate — tool_output success=True / total tool_output
43
- failure_rate sessions that ended with an `error` event / sessions
44
- regenerate_rate — sessions with any `undo_complete` event / sessions
 
45
  time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
46
  thumbs_up / thumbs_down
47
  hf_jobs_submitted / _succeeded / _blocked
 
48
  pro_cta_clicks
49
  gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
 
 
 
 
 
 
 
 
 
50
 
51
  ================================================================================
52
  Usage
@@ -213,6 +225,7 @@ def _session_metrics(session: dict) -> dict:
213
  "thumbs_up": 0, "thumbs_down": 0,
214
  "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
215
  "pro_cta_clicks": 0,
 
216
  "first_tool_s": -1,
217
  }
218
  events = session.get("events") or []
@@ -231,11 +244,19 @@ def _session_metrics(session: dict) -> dict:
231
  gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
232
  jobs_submitted = 0
233
  jobs_succeeded = 0
234
- jobs_blocked = 0
235
  thumbs_up = 0
236
  thumbs_down = 0
 
 
 
 
237
  pro_cta_clicks = 0
238
  pro_cta_by_source: dict[str, int] = defaultdict(int)
 
 
 
 
 
239
 
240
  start_dt = _parse_ts(session_start)
241
 
@@ -260,6 +281,10 @@ def _session_metrics(session: dict) -> dict:
260
  first_tool_ts = (ts - start_dt).total_seconds()
261
 
262
  elif et == "tool_call":
 
 
 
 
263
  if first_tool_ts is None and ts is not None and start_dt is not None:
264
  first_tool_ts = (ts - start_dt).total_seconds()
265
 
@@ -296,6 +321,19 @@ def _session_metrics(session: dict) -> dict:
296
  source = str(data.get("source") or "unknown")
297
  pro_cta_by_source[source] += 1
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  out["tool_calls_total"] = tool_total
300
  out["tool_calls_success"] = tool_success
301
  out["failures"] = 1 if had_error else 0
@@ -304,12 +342,22 @@ def _session_metrics(session: dict) -> dict:
304
  out["thumbs_down"] = thumbs_down
305
  out["hf_jobs_submitted"] = jobs_submitted
306
  out["hf_jobs_succeeded"] = jobs_succeeded
 
 
 
307
  out["hf_jobs_blocked"] = jobs_blocked
308
  out["pro_cta_clicks"] = pro_cta_clicks
309
  out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
310
  out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
311
  out["_pro_cta_by_source"] = dict(pro_cta_by_source)
312
  out["_user"] = session.get("user_id") or session.get("session_id")
 
 
 
 
 
 
 
313
  return dict(out)
314
 
315
 
@@ -317,12 +365,36 @@ def _aggregate(per_session: list[dict]) -> dict:
317
  """Collapse a bucket's worth of session rollups into the final KPI row."""
318
  ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
319
  gpu_hours: dict[str, float] = defaultdict(float)
320
- pro_cta_by_source: dict[str, int] = defaultdict(int)
321
  for s in per_session:
322
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
323
  gpu_hours[f] += h
324
- for source, count in (s.get("_pro_cta_by_source") or {}).items():
325
- pro_cta_by_source[source] += int(count)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  total_sessions = sum(s["sessions"] for s in per_session)
328
  total_turns = sum(s["turns"] for s in per_session)
@@ -330,6 +402,16 @@ def _aggregate(per_session: list[dict]) -> dict:
330
  tokens_cache_read = sum(s["tokens_cache_read"] for s in per_session)
331
  tool_total = sum(s["tool_calls_total"] for s in per_session)
332
  tool_success = sum(s["tool_calls_success"] for s in per_session)
 
 
 
 
 
 
 
 
 
 
333
 
334
  unique_users = {s.get("_user") for s in per_session if s.get("_user")}
335
 
@@ -343,26 +425,61 @@ def _aggregate(per_session: list[dict]) -> dict:
343
  "tokens_cache_read": int(tokens_cache_read),
344
  "tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
345
  "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
 
 
 
 
 
 
346
  "cache_hit_ratio": round(
347
  tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
348
  ) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  "tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
350
- "failure_rate": round(
351
- sum(s["failures"] for s in per_session) / total_sessions, 4
352
- ) if total_sessions > 0 else 0.0,
353
- "regenerate_rate": round(
354
- sum(s["regenerate_sessions"] for s in per_session) / total_sessions, 4
355
- ) if total_sessions > 0 else 0.0,
356
  "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
357
  "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
358
  "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
359
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
360
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
361
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
362
- "hf_jobs_blocked": int(sum(s["hf_jobs_blocked"] for s in per_session)),
363
- "pro_cta_clicks": int(sum(s["pro_cta_clicks"] for s in per_session)),
 
 
 
364
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
365
- "pro_cta_by_source_json": json.dumps(dict(pro_cta_by_source), sort_keys=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  }
367
 
368
 
 
38
  llm_calls — count of llm_call events
39
  tokens_prompt / _completion / _cache_read / _cache_creation
40
  cost_usd — sum of llm_call.cost_usd
41
+ cost_per_session_mean / _p50 / _p95 — per-session cost distribution
42
  cache_hit_ratio — cache_read / (cache_read + prompt)
43
+ tool_calls_total / _succeeded / _failed — per-tool_output reliability counts
44
+ tool_success_rate — succeeded / total (kept for back-compat)
45
+ successful_sessions / errored_sessions / regenerated_sessions — outcome counts
46
+ failure_rate / regenerate_rate — kept for back-compat
47
  time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
48
  thumbs_up / thumbs_down
49
  hf_jobs_submitted / _succeeded / _blocked
50
+ sandboxes_created / _cpu / _gpu — sandbox_create events bucketed by hardware
51
  pro_cta_clicks
52
  gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
53
+ research_calls — total `research` tool_call events
54
+ sessions_with_research — sessions that called `research` ≥1
55
+ research_calls_per_session_p50 / _p95 — among sessions that did any (zero-only sessions excluded)
56
+ distinct_tools_per_session_p50 / _p95 — among sessions with ≥1 named tool_call
57
+ tool_calls_per_session_p50 / _p95 — among sessions with ≥1 named tool_call
58
+ tool_calls_per_turn_p50 / _p95 — calls / turns, among sessions with turns>0
59
+ tool_calls_by_name_json — JSON {tool: total_calls} (all tools seen)
60
+ sessions_using_tool_json — JSON {tool: distinct_sessions_using}
61
+ sessions_by_model_json — JSON {model_name: count} (CLI vs Bedrock split)
62
 
63
  ================================================================================
64
  Usage
 
225
  "thumbs_up": 0, "thumbs_down": 0,
226
  "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
227
  "pro_cta_clicks": 0,
228
+ "sandboxes_created": 0, "sandboxes_cpu": 0, "sandboxes_gpu": 0,
229
  "first_tool_s": -1,
230
  }
231
  events = session.get("events") or []
 
244
  gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
245
  jobs_submitted = 0
246
  jobs_succeeded = 0
 
247
  thumbs_up = 0
248
  thumbs_down = 0
249
+ sandboxes_created = 0
250
+ sandboxes_cpu = 0
251
+ sandboxes_gpu = 0
252
+ jobs_blocked = 0
253
  pro_cta_clicks = 0
254
  pro_cta_by_source: dict[str, int] = defaultdict(int)
255
+ # Per-tool counters from tool_call events. Counted off tool_call (which
256
+ # carries data["tool"]) rather than tool_output (which only carries
257
+ # success/output) so we can attribute calls to specific tools.
258
+ tool_calls_by_name: dict[str, int] = defaultdict(int)
259
+ total_named_tool_calls = 0
260
 
261
  start_dt = _parse_ts(session_start)
262
 
 
281
  first_tool_ts = (ts - start_dt).total_seconds()
282
 
283
  elif et == "tool_call":
284
+ name = data.get("tool")
285
+ if name:
286
+ tool_calls_by_name[name] += 1
287
+ total_named_tool_calls += 1
288
  if first_tool_ts is None and ts is not None and start_dt is not None:
289
  first_tool_ts = (ts - start_dt).total_seconds()
290
 
 
321
  source = str(data.get("source") or "unknown")
322
  pro_cta_by_source[source] += 1
323
 
324
+ elif et == "sandbox_create":
325
+ sandboxes_created += 1
326
+ hardware = (data.get("hardware") or "").lower()
327
+ # CPU flavors are explicitly named "cpu-*". Everything else
328
+ # (including unknown/missing hardware strings) lands in the GPU
329
+ # bucket, since the auto-create default is "cpu-basic" which is
330
+ # matched here — anything that isn't is almost always an explicit
331
+ # GPU choice.
332
+ if hardware.startswith("cpu-"):
333
+ sandboxes_cpu += 1
334
+ else:
335
+ sandboxes_gpu += 1
336
+
337
  out["tool_calls_total"] = tool_total
338
  out["tool_calls_success"] = tool_success
339
  out["failures"] = 1 if had_error else 0
 
342
  out["thumbs_down"] = thumbs_down
343
  out["hf_jobs_submitted"] = jobs_submitted
344
  out["hf_jobs_succeeded"] = jobs_succeeded
345
+ out["sandboxes_created"] = sandboxes_created
346
+ out["sandboxes_cpu"] = sandboxes_cpu
347
+ out["sandboxes_gpu"] = sandboxes_gpu
348
  out["hf_jobs_blocked"] = jobs_blocked
349
  out["pro_cta_clicks"] = pro_cta_clicks
350
  out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
351
  out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
352
  out["_pro_cta_by_source"] = dict(pro_cta_by_source)
353
  out["_user"] = session.get("user_id") or session.get("session_id")
354
+ # Intra-session tool fields. Underscore-prefixed = consumed by _aggregate
355
+ # only, never written to CSV directly.
356
+ out["_tool_calls_by_name"] = dict(tool_calls_by_name)
357
+ out["_research_calls"] = tool_calls_by_name.get("research", 0)
358
+ out["_distinct_tools_used"] = len(tool_calls_by_name)
359
+ out["_total_named_tool_calls"] = total_named_tool_calls
360
+ out["_model_name"] = session.get("model_name") or "unknown"
361
  return dict(out)
362
 
363
 
 
365
  """Collapse a bucket's worth of session rollups into the final KPI row."""
366
  ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
367
  gpu_hours: dict[str, float] = defaultdict(float)
 
368
  for s in per_session:
369
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
370
  gpu_hours[f] += h
371
+
372
+ # Per-tool aggregates. ``sessions_using_tool`` counts each session at most
373
+ # once per tool, so the dashboard can show "how many sessions reached for
374
+ # research" alongside "how many research calls overall".
375
+ tool_calls_by_name: dict[str, int] = defaultdict(int)
376
+ sessions_using_tool: dict[str, int] = defaultdict(int)
377
+ sessions_by_model: dict[str, int] = defaultdict(int)
378
+ for s in per_session:
379
+ for name, count in (s.get("_tool_calls_by_name") or {}).items():
380
+ tool_calls_by_name[name] += int(count)
381
+ sessions_using_tool[name] += 1
382
+ sessions_by_model[s.get("_model_name") or "unknown"] += 1
383
+
384
+ # Percentile inputs. All "per session" percentiles exclude sessions that
385
+ # never reached for the relevant signal — otherwise quiet hours
386
+ # (status-check sessions, abandoned new conversations) drag every median
387
+ # to 0 and the chart tells you nothing.
388
+ research_calls_nz = [s.get("_research_calls", 0) for s in per_session if s.get("_research_calls", 0) > 0]
389
+ distinct_tools_values = [s.get("_distinct_tools_used", 0) for s in per_session if s.get("_distinct_tools_used", 0) > 0]
390
+ total_calls_values = [s.get("_total_named_tool_calls", 0) for s in per_session if s.get("_total_named_tool_calls", 0) > 0]
391
+ # Per-turn intensity: turns>0 is the natural filter here (a session with
392
+ # 5 turns and 0 tools is a meaningful 0). Don't strip those.
393
+ calls_per_turn_values = [
394
+ s.get("_total_named_tool_calls", 0) / s["turns"]
395
+ for s in per_session
396
+ if s.get("turns", 0) > 0
397
+ ]
398
 
399
  total_sessions = sum(s["sessions"] for s in per_session)
400
  total_turns = sum(s["turns"] for s in per_session)
 
402
  tokens_cache_read = sum(s["tokens_cache_read"] for s in per_session)
403
  tool_total = sum(s["tool_calls_total"] for s in per_session)
404
  tool_success = sum(s["tool_calls_success"] for s in per_session)
405
+ failures = int(sum(s["failures"] for s in per_session))
406
+ regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
407
+ research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
408
+ sessions_with_research = sum(1 for s in per_session if s.get("_research_calls", 0) > 0)
409
+
410
+ # Per-session cost percentiles — chart "median session cost" alongside the
411
+ # mean so a few $700 outliers don't make you think every session is pricey.
412
+ session_costs = [float(s.get("cost_usd") or 0.0) for s in per_session]
413
+ cost_p50 = _percentile(session_costs, 0.5)
414
+ cost_p95 = _percentile(session_costs, 0.95)
415
 
416
  unique_users = {s.get("_user") for s in per_session if s.get("_user")}
417
 
 
425
  "tokens_cache_read": int(tokens_cache_read),
426
  "tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
427
  "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
428
+ # Per-session cost summaries.
429
+ "cost_per_session_mean": round(
430
+ sum(s["cost_usd"] for s in per_session) / total_sessions, 6
431
+ ) if total_sessions > 0 else 0.0,
432
+ "cost_per_session_p50": round(cost_p50, 6),
433
+ "cost_per_session_p95": round(cost_p95, 6),
434
  "cache_hit_ratio": round(
435
  tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
436
  ) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
437
+ # Raw reliability COUNTS (these are what the dashboard shows directly).
438
+ "tool_calls_total": int(tool_total),
439
+ "tool_calls_succeeded": int(tool_success),
440
+ "tool_calls_failed": int(tool_total - tool_success),
441
+ "errored_sessions": failures,
442
+ # Successful = "did not raise an error event". Mutually exclusive
443
+ # with errored_sessions; sums with errored_sessions to total sessions.
444
+ "successful_sessions": int(total_sessions - failures),
445
+ # Regenerated is an orthogonal dimension (the user retried) — a
446
+ # session can be both successful and regenerated, or both errored
447
+ # and regenerated.
448
+ "regenerated_sessions": regenerates,
449
+ # Rates kept for backwards compatibility with anything reading the
450
+ # KPI dataset directly.
451
  "tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
452
+ "failure_rate": round(failures / total_sessions, 4) if total_sessions > 0 else 0.0,
453
+ "regenerate_rate": round(regenerates / total_sessions, 4) if total_sessions > 0 else 0.0,
 
 
 
 
454
  "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
455
  "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
456
  "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
457
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
458
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
459
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
460
+ "sandboxes_created": int(sum(s.get("sandboxes_created", 0) for s in per_session)),
461
+ "sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
462
+ "sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
463
+ "hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
464
+ "pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
465
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
466
+ # Research KPIs — answer "is the agent reaching for research?".
467
+ "research_calls": research_calls_total,
468
+ "sessions_with_research": int(sessions_with_research),
469
+ "research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
470
+ "research_calls_per_session_p95": round(_percentile(research_calls_nz, 0.95), 2),
471
+ # Intra-session breadth + intensity. p50 + p95 over per-session values.
472
+ "distinct_tools_per_session_p50": round(_percentile(distinct_tools_values, 0.5), 2),
473
+ "distinct_tools_per_session_p95": round(_percentile(distinct_tools_values, 0.95), 2),
474
+ "tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
475
+ "tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
476
+ "tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
477
+ "tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
478
+ # JSON columns let the dashboard add/remove tools without schema churn.
479
+ "tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
480
+ "sessions_using_tool_json": json.dumps(dict(sessions_using_tool), sort_keys=True),
481
+ # Surface split — answers "is research dropping on Bedrock specifically?".
482
+ "sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
483
  }
484
 
485
 
scripts/sweep_orphan_sandboxes.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Backstop sweeper for orphan ml-intern sandbox Spaces.
3
+
4
+ ================================================================================
5
+ Why this script exists
6
+ ================================================================================
7
+
8
+ The agent creates a sandbox Space per session (template duplicated from
9
+ ``burtenshaw/sandbox`` into the user's account, named ``<owner>/sandbox-<8hex>``).
10
+ ``backend.session_manager.SessionManager._cleanup_sandbox`` deletes it at end of
11
+ session. In practice the cleanup misses some sandboxes:
12
+
13
+ - pod killed / OOM / pre-emption / deploy rollouts → ``finally`` block skipped
14
+ - WebSocket dropped without ``/shutdown`` from the client
15
+ - HF API transient failure on ``delete_repo`` (we retry now, but not infinitely)
16
+
17
+ The result observed 2026-04-27 was 2,310 orphan ``sandbox-*`` Spaces — every
18
+ sandbox ever created was still around. This script is the backstop: list every
19
+ ``sandbox-*`` fork of ``burtenshaw/sandbox`` that hasn't been touched in N days
20
+ and delete it.
21
+
22
+ ================================================================================
23
+ Identification rules
24
+ ================================================================================
25
+
26
+ A Space is considered an orphan ml-intern sandbox iff ALL hold:
27
+
28
+ 1. Repo type = ``space``
29
+ 2. Name matches ``<owner>/sandbox-[a-f0-9]{8}$`` (the agent's naming convention)
30
+ 3. (Intended, NOT currently enforced) ``originRepo`` points at ``burtenshaw/sandbox``.
31
+ The HF API does not expose the origin repo, so ``is_sandbox_fork`` relies on rule 2 alone.
32
+ 4. ``lastModified`` older than ``--max-age-days`` (default 7)
33
+
34
+ We DO NOT use the ``runtime.stage`` (sleeping/running) as a filter — a sandbox
35
+ that has been sleeping for 7 days is just as orphan as a deleted one but uses
36
+ no compute. The cleanup is about repo/storage hygiene, not about waking
37
+ something up to kill it.
38
+
39
+ ================================================================================
40
+ Safety
41
+ ================================================================================
42
+
43
+ - Dry-run is the default (there is no ``--dry-run`` flag): without ``--apply`` the script prints what would be deleted and deletes nothing.
44
+ - ``--apply`` actually calls ``HfApi.delete_repo``.
45
+ - Hard cap ``--max-deletes`` (default 200) so a misconfigured run can't nuke
46
+ thousands at once.
47
+ - Requires a token with admin rights via ``HF_ADMIN_TOKEN`` env var (the only
48
+ way to delete a Space owned by another user).
49
+ - Logs every action to stdout in JSON Lines for downstream auditing.
50
+
51
+ ================================================================================
52
+ Cron suggestion
53
+ ================================================================================
54
+
55
+ GitHub Actions, daily at 04:00 UTC:
56
+
57
+ schedule:
58
+ - cron: "0 4 * * *"
59
+ env:
60
+ HF_ADMIN_TOKEN: ${{ secrets.HF_ADMIN_TOKEN }}
61
+ steps:
62
+ - run: python scripts/sweep_orphan_sandboxes.py --apply --max-age-days 7
63
+ """
64
+
65
+ import argparse
66
+ import json
67
+ import os
68
+ import re
69
+ import sys
70
+ import time
71
+ from datetime import datetime, timedelta, timezone
72
+
73
+ from huggingface_hub import HfApi
74
+ from huggingface_hub.utils import HfHubHTTPError
75
+
76
# Naming convention for agent-created sandboxes: ``<owner>/sandbox-<8 lowercase hex>``.
SANDBOX_NAME_RE = re.compile(r"^[^/]+/sandbox-[a-f0-9]{8}$")
# Template Space the sandboxes are duplicated from.
# NOTE(review): not referenced anywhere else in this script — kept for documentation only?
TEMPLATE_REPO = "burtenshaw/sandbox"
78
+
79
+
80
def log(record: dict) -> None:
    """Print *record* to stdout as a single JSON Lines entry.

    A ``ts`` field holding the current UTC time (ISO 8601) is added to a
    copy of *record*, so the caller's dict is never mutated and can be
    reused or inspected after logging.
    """
    stamped = {**record, "ts": datetime.now(timezone.utc).isoformat()}
    # flush=True so lines appear immediately even when stdout is piped.
    print(json.dumps(stamped), flush=True)
84
+
85
+
86
def is_sandbox_fork(space) -> bool:
    """Return True when ``space.id`` follows the ml-intern sandbox naming scheme.

    Only the name is checked. An extra ``duplicated_from == burtenshaw/sandbox``
    filter was tried for safety but is not viable: the HF REST API does not
    expose ``duplicated_from`` on ``SpaceInfo`` (verified against
    ``huggingface-hub`` 1.11+ and direct ``GET /api/spaces/{id}``: the field is
    None). The origin repo lives in MongoDB but isn't surfaced.

    The pattern is considered specific enough on its own: ``Sandbox.create()``
    is the sole producer of ``<owner>/sandbox-<8 lowercase hex>`` and a
    collision with a user-created Space is unlikely in practice. The dry-run
    default is the safety net for the rare false positive.
    """
    name_matches = SANDBOX_NAME_RE.match(space.id) is not None
    return name_matches
100
+
101
+
102
def _last_modified_utc(space) -> "datetime | None":
    """Return the Space's last-modified time as an aware datetime, or None if unknown."""
    # Both attribute spellings are tried, as the listing object may carry either.
    raw = getattr(space, "lastModified", None) or getattr(space, "last_modified", None)
    if isinstance(raw, str):
        try:
            raw = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        except ValueError:
            # Unparseable timestamp: report as unknown instead of aborting the sweep.
            return None
    if not isinstance(raw, datetime):
        return None
    if raw.tzinfo is None:
        # assumes naive timestamps are UTC — TODO confirm. Needed because
        # comparing a naive datetime against the aware cutoff raises TypeError.
        raw = raw.replace(tzinfo=timezone.utc)
    return raw


def main() -> int:
    """Scan sandbox Spaces and delete (or, by default, only report) stale ones.

    Returns:
        0 on success, 1 when ``HF_ADMIN_TOKEN`` is not set, 2 when at least
        one deletion failed.
    """
    parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0])
    parser.add_argument(
        "--max-age-days",
        type=int,
        default=7,
        help="Delete sandboxes whose lastModified is older than this many days (default: 7)",
    )
    parser.add_argument(
        "--max-deletes",
        type=int,
        default=200,
        help="Hard cap on deletions per run, safety guard (default: 200)",
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="Actually delete. Without this flag, dry-run only.",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10000,
        help="Max number of candidate Spaces to scan via list_spaces (default: 10000)",
    )
    args = parser.parse_args()

    token = os.environ.get("HF_ADMIN_TOKEN")
    if not token:
        log({"level": "error", "msg": "HF_ADMIN_TOKEN env var not set"})
        return 1

    api = HfApi(token=token)
    cutoff = datetime.now(timezone.utc) - timedelta(days=args.max_age_days)
    log({"level": "info", "msg": "sweep_start", "cutoff": cutoff.isoformat(),
         "max_deletes": args.max_deletes, "apply": args.apply})

    # ``list_spaces`` doesn't filter by name pattern — we scan and filter
    # client-side. ``search="sandbox"`` narrows the network payload.
    candidates = api.list_spaces(
        search="sandbox", full=True, limit=args.limit
    )

    scanned = 0
    matched = 0
    deleted = 0
    failed = 0
    skipped_too_recent = 0
    skipped_unknown_age = 0
    skipped_capped = 0

    for space in candidates:
        scanned += 1
        if not is_sandbox_fork(space):
            continue
        matched += 1

        last_mod = _last_modified_utc(space)
        if last_mod is None:
            # Fail safe: without a usable timestamp we cannot prove the Space
            # is older than the cutoff, so never delete it.
            skipped_unknown_age += 1
            log({"level": "warning", "msg": "skipped_unknown_age",
                 "space_id": space.id})
            continue
        if last_mod > cutoff:
            skipped_too_recent += 1
            continue

        log({"level": "info", "msg": "candidate", "space_id": space.id,
             "last_modified": last_mod.isoformat()})

        if not args.apply:
            continue

        # When we hit the deletion cap, keep scanning so the final ``matched``
        # count reflects the *true* orphan size — not just what was scanned
        # before we stopped deleting. Operators planning multi-pass cleanups
        # need an accurate denominator to know when they're done.
        if deleted >= args.max_deletes:
            skipped_capped += 1
            continue

        try:
            api.delete_repo(repo_id=space.id, repo_type="space", token=token)
        except HfHubHTTPError as e:
            failed += 1
            # ``response`` may be absent on some error paths; read it defensively.
            status = getattr(getattr(e, "response", None), "status_code", None)
            log({"level": "error", "msg": "delete_failed", "space_id": space.id,
                 "status": status, "error": str(e)[:200]})
        except Exception as e:
            # Best-effort sweep: record the failure and move on to the next Space.
            failed += 1
            log({"level": "error", "msg": "delete_failed", "space_id": space.id,
                 "error": str(e)[:200]})
        else:
            deleted += 1
            log({"level": "info", "msg": "deleted", "space_id": space.id})
            # Light throttle to avoid hitting HF API rate limits.
            time.sleep(0.2)

    log({"level": "info", "msg": "sweep_end",
         "scanned": scanned, "matched": matched,
         "skipped_too_recent": skipped_too_recent,
         "skipped_unknown_age": skipped_unknown_age,
         "skipped_capped": skipped_capped,
         "deleted": deleted, "failed": failed,
         "capped": skipped_capped > 0,
         "apply": args.apply})

    return 0 if failed == 0 else 2
203
+
204
+
205
+ if __name__ == "__main__":
206
+ sys.exit(main())
tests/integration/test_live_sandbox_auth.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Opt-in live sandbox communication test.
2
+
3
+ This test creates a real Hugging Face Space sandbox, verifies that unauthenticated
4
+ requests are rejected, then exercises the authenticated agent client end-to-end.
5
+ It is skipped unless ``ML_INTERN_LIVE_SANDBOX_TESTS=1`` and ``HF_TOKEN`` are set.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ from pathlib import Path
12
+
13
+ import httpx
14
+ import pytest
15
+ from dotenv import load_dotenv
16
+ from huggingface_hub import HfApi
17
+
18
+ from agent.tools.sandbox_client import Sandbox
19
+
20
+
21
+ if env_file := os.environ.get("ML_INTERN_LIVE_ENV_FILE"):
22
+ load_dotenv(Path(env_file))
23
+
24
+
25
def _skip_without_live_sandbox() -> None:
    """Skip the calling test unless the live-sandbox opt-in flag and ``HF_TOKEN`` are set."""
    opted_in = os.environ.get("ML_INTERN_LIVE_SANDBOX_TESTS") == "1"
    has_token = bool(os.environ.get("HF_TOKEN"))
    # The opt-in flag is checked first so its message wins when both are missing.
    if not opted_in:
        pytest.skip("set ML_INTERN_LIVE_SANDBOX_TESTS=1 to create a real sandbox")
    if not has_token:
        pytest.skip("set HF_TOKEN to create a real sandbox")
30
+
31
+
32
def test_live_sandbox_authenticated_agent_communication():
    """Create a real sandbox, check anonymous calls get 401, then exercise the authed client.

    The sandbox is deleted in the outer ``finally`` whether or not the
    assertions pass.
    """
    _skip_without_live_sandbox()

    hf_token = os.environ["HF_TOKEN"]
    owner = HfApi(token=hf_token).whoami()["name"]
    sandbox = None

    try:
        sandbox = Sandbox.create(
            owner=owner,
            name="ml-intern-live-auth",
            hardware="cpu-basic",
            private=False,
            token=hf_token,
            secrets={"HF_TOKEN": hf_token},
            wait_timeout=900,
        )

        # A client that sends no credentials must be rejected.
        with httpx.Client(
            base_url=sandbox._base_url,
            timeout=30,
            follow_redirects=True,
        ) as anonymous:
            denied = anonymous.post("exists", json={"path": "/tmp"})
            assert denied.status_code == 401

        # Authenticated round trip: shell, write, exists, read.
        bash_result = sandbox.bash("printf sandbox-live-ok", timeout=30)
        assert bash_result.success, bash_result.error
        assert "sandbox-live-ok" in bash_result.output

        write_result = sandbox.write("/tmp/ml_intern_live_auth.txt", "alpha\nbeta\n")
        assert write_result.success, write_result.error

        exists_result = sandbox._call("exists", {"path": "/tmp/ml_intern_live_auth.txt"})
        assert exists_result.success, exists_result.error
        assert exists_result.output == "true"

        read_result = sandbox.read("/tmp/ml_intern_live_auth.txt")
        assert read_result.success, read_result.error
        assert "alpha" in read_result.output
        assert "beta" in read_result.output

        # Re-attach to the same Space with the saved API token and read again.
        reattached = Sandbox.connect(
            sandbox.space_id,
            token=hf_token,
            api_token=sandbox.api_token,
        )
        try:
            reread_result = reattached.read("/tmp/ml_intern_live_auth.txt")
            assert reread_result.success, reread_result.error
            assert "alpha" in reread_result.output
        finally:
            reattached._client.close()
    finally:
        if sandbox is not None:
            sandbox.delete()
tests/integration/test_live_thinking_models.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Opt-in live provider checks for thinking metadata replay.
2
+
3
+ These tests intentionally call paid model APIs and are skipped unless
4
+ ``ML_INTERN_LIVE_LLM_TESTS=1`` plus the relevant provider key are set.
5
+ They cover the concrete model families involved in #87 without making
6
+ default CI depend on external credentials or provider availability.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import os
12
+ from pathlib import Path
13
+ from types import SimpleNamespace
14
+
15
+ import pytest
16
+ from dotenv import load_dotenv
17
+ from litellm import Message
18
+
19
+ from agent.core.agent_loop import (
20
+ _assistant_message_from_result,
21
+ _call_llm_streaming,
22
+ )
23
+ from agent.core.llm_params import _resolve_llm_params
24
+
25
+
26
# Optionally load credentials from a dotenv file whose path is given by
# the ML_INTERN_LIVE_ENV_FILE environment variable.
env_file = os.environ.get("ML_INTERN_LIVE_ENV_FILE")
if env_file:
    load_dotenv(Path(env_file))

# Live tests run only when the opt-in flag is exactly "1".
LIVE_TESTS_ENABLED = os.environ.get("ML_INTERN_LIVE_LLM_TESTS") == "1"

# Model identifiers exercised by the live tests in this module.
OPUS_47_MODEL = "anthropic/claude-opus-4-7"
LATEST_GPT_MODEL = "openai/gpt-5.2"

# Single function-tool schema handed to the model; it takes one required
# string argument named "answer".
REPORT_RESULT_TOOL = [
    dict(
        type="function",
        function=dict(
            name="report_result",
            description="Report the final test result.",
            parameters=dict(
                type="object",
                properties=dict(
                    answer=dict(
                        type="string",
                        description="The exact marker requested by the test.",
                    ),
                ),
                required=["answer"],
            ),
        ),
    ),
]
51
+
52
+
53
def _skip_without_live_flag() -> None:
    """Skip the calling test unless ``LIVE_TESTS_ENABLED`` is true."""
    if LIVE_TESTS_ENABLED:
        return
    pytest.skip("set ML_INTERN_LIVE_LLM_TESTS=1 to run paid live LLM tests")
56
+
57
+
58
def _skip_without_env(name: str) -> None:
    """Skip the calling test when environment variable ``name`` is unset or empty.

    Args:
        name: Name of the environment variable that must hold a value.
    """
    configured = os.environ.get(name)
    if configured:
        return
    pytest.skip(f"set {name} to run this live provider test")
61
+
62
+
63
def _session(model_name: str):
    """Build a minimal stand-in session object for the streaming helpers.

    Args:
        model_name: Stored as ``config.model_name`` on the returned object.

    Returns:
        A ``SimpleNamespace`` with ``config``, ``is_cancelled`` (False),
        an async ``send_event`` callable, and ``events`` -- the list that
        ``send_event`` appends each received event to.
    """
    recorded = []

    async def _record(event):
        # Collect events so a test can inspect what was emitted.
        recorded.append(event)

    stub = SimpleNamespace()
    stub.config = SimpleNamespace(model_name=model_name)
    stub.is_cancelled = False
    stub.send_event = _record
    stub.events = recorded
    return stub
75
+
76
+
77
@pytest.mark.asyncio
async def test_live_opus_47_preserves_thinking_metadata_for_replay():
    """Live check: the replay message keeps the thinking metadata from the result.

    Streams one tool-calling turn from ``OPUS_47_MODEL`` with
    ``reasoning_effort="high"``, then asserts that the assistant message
    built for replay carries the same ``thinking_blocks`` and
    ``reasoning_content`` as the streamed result.
    """
    _skip_without_live_flag()
    _skip_without_env("ANTHROPIC_API_KEY")

    session = _session(OPUS_47_MODEL)
    llm_params = _resolve_llm_params(OPUS_47_MODEL, reasoning_effort="high")

    prompt = Message(
        role="user",
        content=(
            "Use careful reasoning for this small check. "
            "If 17 * 19 = 323, call report_result with answer OPUS_OK."
        ),
    )
    result = await _call_llm_streaming(
        session,
        messages=[prompt],
        tools=REPORT_RESULT_TOOL,
        llm_params=llm_params,
    )

    replay = _assistant_message_from_result(result, model_name=OPUS_47_MODEL)

    # The model must have produced either text or a tool call.
    assert result.content or result.tool_calls_acc
    assert result.thinking_blocks, (
        "Opus returned no thinking_blocks with reasoning_effort='high' - "
        "check that adaptive thinking params are being forwarded correctly"
    )

    replayed_blocks = getattr(replay, "thinking_blocks", None)
    replayed_reasoning = getattr(replay, "reasoning_content", None)
    assert replayed_blocks == result.thinking_blocks
    assert replayed_reasoning == result.reasoning_content
115
+
116
+
117
@pytest.mark.asyncio
async def test_live_latest_gpt_does_not_replay_reasoning_metadata():
    """Live check: the replay message for the GPT model carries no reasoning metadata.

    Streams one tool-calling turn from ``LATEST_GPT_MODEL`` with
    ``reasoning_effort="low"``, guarantees ``reasoning_content`` is
    non-empty on the result, and asserts that the assistant message built
    for replay exposes neither ``thinking_blocks`` nor
    ``reasoning_content``.
    """
    _skip_without_live_flag()
    _skip_without_env("OPENAI_API_KEY")

    session = _session(LATEST_GPT_MODEL)
    llm_params = _resolve_llm_params(LATEST_GPT_MODEL, reasoning_effort="low")

    prompt = Message(
        role="user",
        content="Call report_result with answer GPT_OK.",
    )
    result = await _call_llm_streaming(
        session,
        messages=[prompt],
        tools=REPORT_RESULT_TOOL,
        llm_params=llm_params,
    )

    # Even if a GPT-family response carries provider reasoning internally,
    # OpenAI-compatible history must not echo it back on the next tool turn.
    # Force the non-None strip path when the live model omits reasoning details.
    result.reasoning_content = result.reasoning_content or "synthetic-reasoning"
    replay = _assistant_message_from_result(result, model_name=LATEST_GPT_MODEL)

    # The model must have produced either text or a tool call.
    assert result.content or result.tool_calls_acc

    replayed_blocks = getattr(replay, "thinking_blocks", None)
    replayed_reasoning = getattr(replay, "reasoning_content", None)
    assert replayed_blocks is None
    assert replayed_reasoning is None