lewtun HF Staff OpenAI Codex commited on
Commit
19f9a99
·
2 Parent(s): 53d5c854fc6e96

Deploy 2026-05-05

Browse files

Co-authored-by: OpenAI Codex <codex@openai.com>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/workflows/ci.yml +63 -0
  2. AGENTS.md +5 -0
  3. agent/config.py +11 -7
  4. agent/context_manager/manager.py +22 -10
  5. agent/core/agent_loop.py +285 -129
  6. agent/core/cost_estimation.py +6 -2
  7. agent/core/doom_loop.py +7 -3
  8. agent/core/effort_probe.py +42 -16
  9. agent/core/hf_router_catalog.py +3 -1
  10. agent/core/hub_artifacts.py +765 -0
  11. agent/core/llm_params.py +10 -3
  12. agent/core/model_switcher.py +13 -7
  13. agent/core/prompt_caching.py +12 -6
  14. agent/core/session.py +4 -3
  15. agent/core/session_persistence.py +12 -4
  16. agent/core/session_uploader.py +2 -3
  17. agent/core/telemetry.py +113 -77
  18. agent/core/tools.py +15 -5
  19. agent/main.py +85 -30
  20. agent/messaging/base.py +5 -1
  21. agent/messaging/gateway.py +9 -3
  22. agent/messaging/models.py +2 -8
  23. agent/messaging/slack.py +1 -3
  24. agent/sft/tagger.py +47 -18
  25. agent/tools/dataset_tools.py +3 -1
  26. agent/tools/edit_utils.py +26 -21
  27. agent/tools/hf_repo_files_tool.py +56 -17
  28. agent/tools/hf_repo_git_tool.py +140 -37
  29. agent/tools/jobs_tool.py +66 -18
  30. agent/tools/local_tools.py +22 -7
  31. agent/tools/papers_tool.py +65 -20
  32. agent/tools/research_tool.py +61 -38
  33. agent/tools/sandbox_client.py +13 -6
  34. agent/tools/sandbox_tool.py +23 -6
  35. agent/tools/web_search_tool.py +4 -1
  36. agent/utils/braille.py +5 -4
  37. agent/utils/crt_boot.py +5 -2
  38. agent/utils/particle_logo.py +3 -1
  39. agent/utils/terminal_display.py +61 -28
  40. backend/dependencies.py +3 -1
  41. backend/kpis_scheduler.py +4 -2
  42. backend/main.py +8 -6
  43. backend/models.py +3 -1
  44. backend/routes/agent.py +54 -15
  45. backend/routes/auth.py +0 -1
  46. backend/session_manager.py +73 -25
  47. backend/user_quotas.py +5 -1
  48. pyproject.toml +1 -0
  49. scripts/build_kpis.py +148 -47
  50. scripts/build_sft.py +21 -5
.github/workflows/ci.yml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: CI
2
+
3
+ on:
4
+ pull_request:
5
+ push:
6
+ branches: [main]
7
+
8
+ permissions:
9
+ contents: read
10
+
11
+ concurrency:
12
+ group: ci-${{ github.workflow }}-${{ github.ref }}
13
+ cancel-in-progress: true
14
+
15
+ jobs:
16
+ ruff:
17
+ name: Ruff
18
+ runs-on: ubuntu-latest
19
+ steps:
20
+ - uses: actions/checkout@v4
21
+
22
+ - name: Install uv
23
+ uses: astral-sh/setup-uv@v5
24
+ with:
25
+ enable-cache: true
26
+ cache-dependency-glob: uv.lock
27
+
28
+ - name: Set up Python
29
+ uses: actions/setup-python@v5
30
+ with:
31
+ python-version: "3.12"
32
+
33
+ - name: Install dependencies
34
+ run: uv sync --locked --extra dev
35
+
36
+ - name: Run Ruff
37
+ run: uv run ruff check .
38
+
39
+ - name: Check formatting
40
+ run: uv run ruff format --check .
41
+
42
+ tests:
43
+ name: Tests
44
+ runs-on: ubuntu-latest
45
+ steps:
46
+ - uses: actions/checkout@v4
47
+
48
+ - name: Install uv
49
+ uses: astral-sh/setup-uv@v5
50
+ with:
51
+ enable-cache: true
52
+ cache-dependency-glob: uv.lock
53
+
54
+ - name: Set up Python
55
+ uses: actions/setup-python@v5
56
+ with:
57
+ python-version: "3.12"
58
+
59
+ - name: Install dependencies
60
+ run: uv sync --locked --extra dev
61
+
62
+ - name: Run tests
63
+ run: uv run pytest
AGENTS.md CHANGED
@@ -15,6 +15,11 @@ Notes:
15
  - Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version.
16
  - Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher.
17
 
 
 
 
 
 
18
  ## GitHub CLI
19
 
20
  - For multiline PR descriptions, prefer `gh pr edit <number> --body-file <file>` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly.
 
15
  - Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version.
16
  - Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher.
17
 
18
+ ## Development Checks
19
+
20
+ - Before every commit, run `uv run ruff check .` and `uv run ruff format --check .`.
21
+ - If formatting fails, run `uv run ruff format .`, then re-run the Ruff checks before committing.
22
+
23
  ## GitHub CLI
24
 
25
  - For multiline PR descriptions, prefer `gh pr edit <number> --body-file <file>` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly.
agent/config.py CHANGED
@@ -5,20 +5,20 @@ from pathlib import Path
5
  from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
8
-
9
- from agent.messaging.models import MessagingConfig
10
-
11
- # Project root: two levels up from this file (agent/config.py -> project root)
12
- _PROJECT_ROOT = Path(__file__).resolve().parent.parent
13
  from fastmcp.mcp_config import (
14
  RemoteMCPServer,
15
  StdioMCPServer,
16
  )
17
  from pydantic import BaseModel
18
 
 
 
19
  # These two are the canonical server config types for MCP servers.
20
  MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]
21
 
 
 
 
22
 
23
  class Config(BaseModel):
24
  """Configuration manager"""
@@ -60,12 +60,16 @@ class Config(BaseModel):
60
 
61
 
62
  USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
63
- DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
 
 
64
  SLACK_DEFAULT_DESTINATION = "slack.default"
65
  SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
66
 
67
 
68
- def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
 
 
69
  merged = dict(base)
70
  for key, value in override.items():
71
  current = merged.get(key)
 
5
  from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
 
 
 
 
 
8
  from fastmcp.mcp_config import (
9
  RemoteMCPServer,
10
  StdioMCPServer,
11
  )
12
  from pydantic import BaseModel
13
 
14
+ from agent.messaging.models import MessagingConfig
15
+
16
  # These two are the canonical server config types for MCP servers.
17
  MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]
18
 
19
+ # Project root: two levels up from this file (agent/config.py -> project root)
20
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
21
+
22
 
23
  class Config(BaseModel):
24
  """Configuration manager"""
 
60
 
61
 
62
  USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
63
+ DEFAULT_USER_CONFIG_PATH = (
64
+ Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
65
+ )
66
  SLACK_DEFAULT_DESTINATION = "slack.default"
67
  SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
68
 
69
 
70
+ def _deep_merge_config(
71
+ base: dict[str, Any], override: dict[str, Any]
72
+ ) -> dict[str, Any]:
73
  merged = dict(base)
74
  for key, value in override.items():
75
  current = merged.get(key)
agent/context_manager/manager.py CHANGED
@@ -3,7 +3,6 @@ Context management for conversation history
3
  """
4
 
5
  import logging
6
- import os
7
  import time
8
  import zoneinfo
9
  from datetime import datetime
@@ -96,6 +95,7 @@ class CompactionFailedError(Exception):
96
  burns Bedrock budget for free (~$3 per re-attempt on Opus).
97
  """
98
 
 
99
  # Used when seeding a brand-new session from prior browser-cached messages.
100
  # Here we're writing a note to *ourselves* — so preserve the tool-call trail,
101
  # files produced, and planned next steps in first person. Optimized for
@@ -155,12 +155,15 @@ async def summarize_messages(
155
  )
156
  if session is not None:
157
  from agent.core import telemetry
 
158
  await telemetry.record_llm_call(
159
  session,
160
  model=model_name,
161
  response=response,
162
  latency_ms=int((time.monotonic() - _t0) * 1000),
163
- finish_reason=response.choices[0].finish_reason if response.choices else None,
 
 
164
  kind=kind,
165
  )
166
  summary = response.choices[0].message.content or ""
@@ -233,6 +236,7 @@ class ContextManager:
233
  # CLI-specific context for local mode
234
  if local_mode:
235
  import os
 
236
  cwd = os.getcwd()
237
  local_context = (
238
  f"\n\n# CLI / Local mode\n\n"
@@ -305,7 +309,9 @@ class ContextManager:
305
  i = 0
306
  while i < len(self.items):
307
  msg = self.items[i]
308
- if getattr(msg, "role", None) != "assistant" or not getattr(msg, "tool_calls", None):
 
 
309
  i += 1
310
  continue
311
 
@@ -316,7 +322,9 @@ class ContextManager:
316
  # before the next non-tool message to satisfy provider ordering.
317
  j = i + 1
318
  immediate_ids: set[str | None] = set()
319
- while j < len(self.items) and getattr(self.items[j], "role", None) == "tool":
 
 
320
  immediate_ids.add(getattr(self.items[j], "tool_call_id", None))
321
  j += 1
322
 
@@ -386,7 +394,9 @@ class ContextManager:
386
 
387
  @property
388
  def needs_compaction(self) -> bool:
389
- return self.running_context_usage > self.compaction_threshold and bool(self.items)
 
 
390
 
391
  def _truncate_oversized(
392
  self, messages: list[Message], model_name: str
@@ -425,7 +435,9 @@ class ContextManager:
425
  )
426
  logger.warning(
427
  "Truncating %s message: %d -> %d tokens for compaction",
428
- msg.role, n, len(placeholder) // 4,
 
 
429
  )
430
  # Preserve all known assistant-side fields (tool_calls, thinking_blocks,
431
  # reasoning_content, provider_specific_fields) even when content is
@@ -459,9 +471,9 @@ class ContextManager:
459
  except Exception as e:
460
  logger.warning("token_counter failed (%s); rough estimate", e)
461
  # Rough fallback: 4 chars per token.
462
- self.running_context_usage = sum(
463
- len(getattr(m, "content", "") or "") for m in self.items
464
- ) // 4
465
 
466
  async def compact(
467
  self,
@@ -516,7 +528,7 @@ class ContextManager:
516
  idx = first_user_idx + 1
517
 
518
  recent_messages = self.items[idx:]
519
- messages_to_summarize = self.items[first_user_idx + 1:idx]
520
 
521
  # Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in
522
  # the parts we PRESERVE through compaction (first_user + recent_tail).
 
3
  """
4
 
5
  import logging
 
6
  import time
7
  import zoneinfo
8
  from datetime import datetime
 
95
  burns Bedrock budget for free (~$3 per re-attempt on Opus).
96
  """
97
 
98
+
99
  # Used when seeding a brand-new session from prior browser-cached messages.
100
  # Here we're writing a note to *ourselves* — so preserve the tool-call trail,
101
  # files produced, and planned next steps in first person. Optimized for
 
155
  )
156
  if session is not None:
157
  from agent.core import telemetry
158
+
159
  await telemetry.record_llm_call(
160
  session,
161
  model=model_name,
162
  response=response,
163
  latency_ms=int((time.monotonic() - _t0) * 1000),
164
+ finish_reason=response.choices[0].finish_reason
165
+ if response.choices
166
+ else None,
167
  kind=kind,
168
  )
169
  summary = response.choices[0].message.content or ""
 
236
  # CLI-specific context for local mode
237
  if local_mode:
238
  import os
239
+
240
  cwd = os.getcwd()
241
  local_context = (
242
  f"\n\n# CLI / Local mode\n\n"
 
309
  i = 0
310
  while i < len(self.items):
311
  msg = self.items[i]
312
+ if getattr(msg, "role", None) != "assistant" or not getattr(
313
+ msg, "tool_calls", None
314
+ ):
315
  i += 1
316
  continue
317
 
 
322
  # before the next non-tool message to satisfy provider ordering.
323
  j = i + 1
324
  immediate_ids: set[str | None] = set()
325
+ while (
326
+ j < len(self.items) and getattr(self.items[j], "role", None) == "tool"
327
+ ):
328
  immediate_ids.add(getattr(self.items[j], "tool_call_id", None))
329
  j += 1
330
 
 
394
 
395
  @property
396
  def needs_compaction(self) -> bool:
397
+ return self.running_context_usage > self.compaction_threshold and bool(
398
+ self.items
399
+ )
400
 
401
  def _truncate_oversized(
402
  self, messages: list[Message], model_name: str
 
435
  )
436
  logger.warning(
437
  "Truncating %s message: %d -> %d tokens for compaction",
438
+ msg.role,
439
+ n,
440
+ len(placeholder) // 4,
441
  )
442
  # Preserve all known assistant-side fields (tool_calls, thinking_blocks,
443
  # reasoning_content, provider_specific_fields) even when content is
 
471
  except Exception as e:
472
  logger.warning("token_counter failed (%s); rough estimate", e)
473
  # Rough fallback: 4 chars per token.
474
+ self.running_context_usage = (
475
+ sum(len(getattr(m, "content", "") or "") for m in self.items) // 4
476
+ )
477
 
478
  async def compact(
479
  self,
 
528
  idx = first_user_idx + 1
529
 
530
  recent_messages = self.items[idx:]
531
+ messages_to_summarize = self.items[first_user_idx + 1 : idx]
532
 
533
  # Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in
534
  # the parts we PRESERVE through compaction (first_user + recent_tail).
agent/core/agent_loop.py CHANGED
@@ -5,7 +5,6 @@ Main agent implementation with integrated tool system and MCP support
5
  import asyncio
6
  import json
7
  import logging
8
- import os
9
  import time
10
  from dataclasses import dataclass, field
11
  from typing import Any
@@ -27,6 +26,7 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost
27
  from agent.messaging.gateway import NotificationGateway
28
  from agent.core import telemetry
29
  from agent.core.doom_loop import check_for_doom_loop
 
30
  from agent.core.llm_params import _resolve_llm_params
31
  from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import Event, OpType, Session
@@ -54,11 +54,12 @@ def _malformed_tool_name(message: Message) -> str | None:
54
  end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
55
  if end == -1:
56
  return None
57
- return content[len(_MALFORMED_TOOL_PREFIX):end]
58
 
59
 
60
  def _detect_repeated_malformed(
61
- items: list[Message], threshold: int = 2,
 
62
  ) -> str | None:
63
  """Return the repeated malformed tool name if the tail contains a streak.
64
 
@@ -118,6 +119,7 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
118
 
119
  _IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
120
 
 
121
  @dataclass(frozen=True)
122
  class ApprovalDecision:
123
  requires_approval: bool
@@ -142,7 +144,9 @@ def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool:
142
 
143
 
144
  def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
145
- return tool_name == "sandbox_create" or _is_immediate_hf_job_run(tool_name, tool_args)
 
 
146
 
147
 
148
  def _base_needs_approval(
@@ -231,7 +235,9 @@ def _session_auto_approval_enabled(session: Session | None) -> bool:
231
 
232
 
233
  def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
234
- return bool((config and config.yolo_mode) or _session_auto_approval_enabled(session))
 
 
235
 
236
 
237
  def _remaining_budget_after_reservations(
@@ -251,7 +257,10 @@ def _budget_block_reason(
251
  ) -> str | None:
252
  if estimate.estimated_cost_usd is None:
253
  return estimate.block_reason or "Could not estimate the cost safely."
254
- if remaining_cap_usd is not None and estimate.estimated_cost_usd > remaining_cap_usd:
 
 
 
255
  return (
256
  f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds "
257
  f"remaining YOLO cap ${remaining_cap_usd:.2f}."
@@ -409,15 +418,25 @@ def _is_transient_error(error: Exception) -> bool:
409
  """Return True for errors that are likely transient and worth retrying."""
410
  err_str = str(error).lower()
411
  transient_patterns = [
412
- "timeout", "timed out",
413
- "503", "service unavailable",
414
- "502", "bad gateway",
415
- "500", "internal server error",
416
- "overloaded", "capacity",
417
- "connection reset", "connection refused", "connection error",
418
- "eof", "broken pipe",
 
 
 
 
 
 
 
 
419
  ]
420
- return _is_rate_limit_error(error) or any(pattern in err_str for pattern in transient_patterns)
 
 
421
 
422
 
423
  def _is_effort_config_error(error: Exception) -> bool:
@@ -429,11 +448,14 @@ def _is_effort_config_error(error: Exception) -> bool:
429
  doesn't work for the current model. We heal the cache and retry once.
430
  """
431
  from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported
 
432
  return _is_thinking_unsupported(error) or _is_invalid_effort(error)
433
 
434
 
435
  async def _heal_effort_and_rebuild_params(
436
- session: Session, error: Exception, llm_params: dict,
 
 
437
  ) -> dict:
438
  """Update the session's effort cache based on ``error`` and return new
439
  llm_params. Called only when ``_is_effort_config_error(error)`` is True.
@@ -444,7 +466,11 @@ async def _heal_effort_and_rebuild_params(
444
  • invalid-effort → re-run the full cascade probe; the result lands
445
  in the cache
446
  """
447
- from agent.core.effort_probe import ProbeInconclusive, _is_thinking_unsupported, probe_effort
 
 
 
 
448
 
449
  model = session.config.model_name
450
  if _is_thinking_unsupported(error):
@@ -453,12 +479,16 @@ async def _heal_effort_and_rebuild_params(
453
  else:
454
  try:
455
  outcome = await probe_effort(
456
- model, session.config.reasoning_effort, session.hf_token,
 
 
457
  session=session,
458
  )
459
  session.model_effective_effort[model] = outcome.effective_effort
460
  logger.info(
461
- "healed: %s effort cascade → %s", model, outcome.effective_effort,
 
 
462
  )
463
  except ProbeInconclusive:
464
  # Transient during healing — strip thinking for safety, next
@@ -477,7 +507,11 @@ def _friendly_error_message(error: Exception) -> str | None:
477
  """Return a user-friendly message for known error types, or None to fall back to traceback."""
478
  err_str = str(error).lower()
479
 
480
- if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str:
 
 
 
 
481
  return (
482
  "Authentication failed — your API key is missing or invalid.\n\n"
483
  "To fix this, set the API key for your model provider:\n"
@@ -503,8 +537,7 @@ def _friendly_error_message(error: Exception) -> str | None:
503
  )
504
 
505
  if "model_not_found" in err_str or (
506
- "model" in err_str
507
- and ("not found" in err_str or "does not exist" in err_str)
508
  ):
509
  return (
510
  "Model not found. Use '/model' to list suggestions, or paste an "
@@ -530,7 +563,10 @@ async def _compact_and_notify(session: Session) -> None:
530
  old_usage = cm.running_context_usage
531
  logger.debug(
532
  "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s",
533
- old_usage, cm.model_max_tokens, cm.compaction_threshold, cm.needs_compaction,
 
 
 
534
  )
535
  try:
536
  await cm.compact(
@@ -542,24 +578,27 @@ async def _compact_and_notify(session: Session) -> None:
542
  except CompactionFailedError as e:
543
  logger.error(
544
  "Compaction failed for session %s: %s — terminating session",
545
- session.session_id, e,
 
546
  )
547
  # Persist the failure event so the dataset has a record of WHY this
548
  # session ended (and the cost it incurred up to that point) even if
549
  # save_and_upload_detached has issues downstream.
550
- await session.send_event(Event(
551
- event_type="session_terminated",
552
- data={
553
- "reason": "compaction_failed",
554
- "context_usage": cm.running_context_usage,
555
- "context_threshold": cm.compaction_threshold,
556
- "error": str(e)[:300],
557
- "user_message": (
558
- "Your conversation has grown too large to continue. "
559
- "The work you've done is saved — start a new session to keep going."
560
- ),
561
- },
562
- ))
 
 
563
  # Stop the agent loop; the finally in _run_session will fire
564
  # cleanup_sandbox + save_trajectory so the dataset captures
565
  # everything that did happen.
@@ -570,7 +609,10 @@ async def _compact_and_notify(session: Session) -> None:
570
  if new_usage != old_usage:
571
  logger.warning(
572
  "Context compacted: %d -> %d tokens (max=%d, %d messages)",
573
- old_usage, new_usage, cm.model_max_tokens, len(cm.items),
 
 
 
574
  )
575
  await session.send_event(
576
  Event(
@@ -609,6 +651,7 @@ async def _cleanup_on_cancel(session: Session) -> None:
609
  @dataclass
610
  class LLMResult:
611
  """Result from an LLM call (streaming or non-streaming)."""
 
612
  content: str | None
613
  tool_calls_acc: dict[int, dict]
614
  token_count: int
@@ -728,16 +771,18 @@ async def _maybe_heal_invalid_thinking_signature(
728
  if not stripped:
729
  return False
730
 
731
- await session.send_event(Event(
732
- event_type="tool_log",
733
- data={
734
- "tool": "system",
735
- "log": (
736
- "Anthropic rejected stale thinking signatures; retrying "
737
- "without replayed thinking metadata."
738
- ),
739
- },
740
- ))
 
 
741
  return True
742
 
743
 
@@ -762,7 +807,9 @@ def _assistant_message_from_result(
762
  return Message(**kwargs)
763
 
764
 
765
- async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
 
 
766
  """Call the LLM with streaming, emitting assistant_chunk events."""
767
  response = None
768
  _healed_effort = False # one-shot safety net per call
@@ -788,11 +835,18 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
788
  raise ContextWindowExceededError(str(e)) from e
789
  if not _healed_effort and _is_effort_config_error(e):
790
  _healed_effort = True
791
- llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
792
- await session.send_event(Event(
793
- event_type="tool_log",
794
- data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
795
- ))
 
 
 
 
 
 
 
796
  continue
797
  if await _maybe_heal_invalid_thinking_signature(
798
  session,
@@ -806,12 +860,20 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
806
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
807
  logger.warning(
808
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
809
- _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
 
 
 
 
 
 
 
 
 
 
 
 
810
  )
811
- await session.send_event(Event(
812
- event_type="tool_log",
813
- data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
814
- ))
815
  await asyncio.sleep(_delay)
816
  continue
817
  raise
@@ -852,16 +914,21 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
852
  idx = tc_delta.index
853
  if idx not in tool_calls_acc:
854
  tool_calls_acc[idx] = {
855
- "id": "", "type": "function",
 
856
  "function": {"name": "", "arguments": ""},
857
  }
858
  if tc_delta.id:
859
  tool_calls_acc[idx]["id"] = tc_delta.id
860
  if tc_delta.function:
861
  if tc_delta.function.name:
862
- tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
 
 
863
  if tc_delta.function.arguments:
864
- tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments
 
 
865
 
866
  if hasattr(chunk, "usage") and chunk.usage:
867
  token_count = chunk.usage.total_tokens
@@ -881,7 +948,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
881
  rebuilt = stream_chunk_builder(chunks, messages=messages)
882
  if rebuilt and getattr(rebuilt, "choices", None):
883
  rebuilt_msg = rebuilt.choices[0].message
884
- thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg)
 
 
885
  except Exception:
886
  logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
887
 
@@ -896,7 +965,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
896
  )
897
 
898
 
899
- async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
 
 
900
  """Call the LLM without streaming, emit assistant_message at the end."""
901
  response = None
902
  _healed_effort = False
@@ -921,11 +992,18 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
921
  raise ContextWindowExceededError(str(e)) from e
922
  if not _healed_effort and _is_effort_config_error(e):
923
  _healed_effort = True
924
- llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params)
925
- await session.send_event(Event(
926
- event_type="tool_log",
927
- data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
928
- ))
 
 
 
 
 
 
 
929
  continue
930
  if await _maybe_heal_invalid_thinking_signature(
931
  session,
@@ -939,12 +1017,20 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
939
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
940
  logger.warning(
941
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
942
- _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay,
 
 
 
 
 
 
 
 
 
 
 
 
943
  )
944
- await session.send_event(Event(
945
- event_type="tool_log",
946
- data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."},
947
- ))
948
  await asyncio.sleep(_delay)
949
  continue
950
  raise
@@ -1037,7 +1123,8 @@ class Handlers:
1037
 
1038
  @staticmethod
1039
  async def run_agent(
1040
- session: Session, text: str,
 
1041
  ) -> str | None:
1042
  """
1043
  Handle user input (like user_input_or_turn in codex.rs:1291)
@@ -1124,12 +1211,18 @@ class Handlers:
1124
  llm_params = _resolve_llm_params(
1125
  session.config.model_name,
1126
  session.hf_token,
1127
- reasoning_effort=session.effective_effort_for(session.config.model_name),
 
 
1128
  )
1129
  if session.stream:
1130
- llm_result = await _call_llm_streaming(session, messages, tools, llm_params)
 
 
1131
  else:
1132
- llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params)
 
 
1133
 
1134
  content = llm_result.content
1135
  tool_calls_acc = llm_result.tool_calls_acc
@@ -1176,7 +1269,10 @@ class Handlers:
1176
  await session.send_event(
1177
  Event(
1178
  event_type="tool_log",
1179
- data={"tool": "system", "log": f"Output truncated — retrying with smaller content ({dropped_names})"},
 
 
 
1180
  )
1181
  )
1182
  iteration += 1
@@ -1239,7 +1335,8 @@ class Handlers:
1239
  except (json.JSONDecodeError, TypeError, ValueError):
1240
  logger.warning(
1241
  "Malformed arguments for tool_call %s (%s) — skipping",
1242
- tc.id, tc.function.name,
 
1243
  )
1244
  tc.function.arguments = "{}"
1245
  bad_tools.append(tc)
@@ -1260,20 +1357,35 @@ class Handlers:
1260
  f"arguments and was NOT executed. Retry with smaller content — "
1261
  f"for 'write', split into multiple smaller writes using 'edit'."
1262
  )
1263
- session.context_manager.add_message(Message(
1264
- role="tool",
1265
- content=error_msg,
1266
- tool_call_id=tc.id,
1267
- name=tc.function.name,
1268
- ))
1269
- await session.send_event(Event(
1270
- event_type="tool_call",
1271
- data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id},
1272
- ))
1273
- await session.send_event(Event(
1274
- event_type="tool_output",
1275
- data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False},
1276
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1277
 
1278
  # ── Cancellation check: before tool execution ──
1279
  if session.is_cancelled:
@@ -1298,7 +1410,9 @@ class Handlers:
1298
  reserved_spend_usd=reserved_auto_spend_usd,
1299
  )
1300
  if decision.requires_approval:
1301
- approval_required_tools.append((tc, tool_name, tool_args, decision))
 
 
1302
  else:
1303
  non_approval_tools.append((tc, tool_name, tool_args, decision))
1304
  if (
@@ -1321,7 +1435,14 @@ class Handlers:
1321
  )
1322
 
1323
  # 2. Send all tool_call events upfront (so frontend shows them all)
1324
- for tc, tool_name, tool_args, _decision, args_valid, _ in parsed_tools:
 
 
 
 
 
 
 
1325
  if args_valid:
1326
  await session.send_event(
1327
  Event(
@@ -1352,12 +1473,14 @@ class Handlers:
1352
  )
1353
  return (tc, name, args, out, ok)
1354
 
1355
- gather_task = asyncio.ensure_future(asyncio.gather(
1356
- *[
1357
- _exec_tool(tc, name, args, decision, valid, err)
1358
- for tc, name, args, decision, valid, err in parsed_tools
1359
- ]
1360
- ))
 
 
1361
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
1362
 
1363
  done, _ = await asyncio.wait(
@@ -1374,10 +1497,16 @@ class Handlers:
1374
  # Notify frontend that in-flight tools were cancelled
1375
  for tc, name, _args, _decision, valid, _ in parsed_tools:
1376
  if valid:
1377
- await session.send_event(Event(
1378
- event_type="tool_state_change",
1379
- data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"},
1380
- ))
 
 
 
 
 
 
1381
  await _cleanup_on_cancel(session)
1382
  break
1383
 
@@ -1414,10 +1543,15 @@ class Handlers:
1414
  for tc, tool_name, tool_args, decision in approval_required_tools:
1415
  # Resolve sandbox file paths for hf_jobs scripts so the
1416
  # frontend can display & edit the actual file content.
1417
- if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
 
 
1418
  from agent.tools.sandbox_tool import resolve_sandbox_script
 
1419
  sandbox = getattr(session, "sandbox", None)
1420
- resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"])
 
 
1421
  if resolved:
1422
  tool_args = {**tool_args, "script": resolved}
1423
 
@@ -1449,10 +1583,12 @@ class Handlers:
1449
  "remaining_cap_usd": first.get("remaining_cap_usd"),
1450
  }
1451
  )
1452
- await session.send_event(Event(
1453
- event_type="approval_required",
1454
- data=event_data,
1455
- ))
 
 
1456
 
1457
  # Store all approval-requiring tools (ToolCall objects for execution)
1458
  session.pending_approval = {
@@ -1470,7 +1606,10 @@ class Handlers:
1470
  logger.warning(
1471
  "ContextWindowExceededError at iteration %d — forcing compaction "
1472
  "(usage=%d, model_max_tokens=%d, messages=%d)",
1473
- iteration, cm.running_context_usage, cm.model_max_tokens, len(cm.items),
 
 
 
1474
  )
1475
  cm.running_context_usage = cm.model_max_tokens + 1
1476
  await _compact_and_notify(session)
@@ -1662,13 +1801,15 @@ class Handlers:
1662
 
1663
  # Execute all approved tools concurrently (cancellable)
1664
  if approved_tasks:
1665
- gather_task = asyncio.ensure_future(asyncio.gather(
1666
- *[
1667
- execute_tool(tc, tool_name, tool_args, was_edited)
1668
- for tc, tool_name, tool_args, was_edited in approved_tasks
1669
- ],
1670
- return_exceptions=True,
1671
- ))
 
 
1672
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
1673
 
1674
  done, _ = await asyncio.wait(
@@ -1684,10 +1825,16 @@ class Handlers:
1684
  pass
1685
  # Notify frontend that approved tools were cancelled
1686
  for tc, tool_name, _args, _was_edited in approved_tasks:
1687
- await session.send_event(Event(
1688
- event_type="tool_state_change",
1689
- data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"},
1690
- ))
 
 
 
 
 
 
1691
  await _cleanup_on_cancel(session)
1692
  await session.send_event(Event(event_type="interrupted"))
1693
  session.increment_turn()
@@ -1839,14 +1986,20 @@ async def submission_loop(
1839
 
1840
  # Create session with tool router
1841
  session = Session(
1842
- event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1843
- user_id=user_id, local_mode=local_mode, stream=stream,
 
 
 
 
 
1844
  notification_gateway=notification_gateway,
1845
  notification_destinations=notification_destinations,
1846
  defer_turn_complete_notification=defer_turn_complete_notification,
1847
  )
1848
  if session_holder is not None:
1849
  session_holder[0] = session
 
1850
  logger.info("Agent loop started")
1851
 
1852
  # Retry any failed uploads from previous sessions (fire-and-forget).
@@ -1864,10 +2017,13 @@ async def submission_loop(
1864
  async with tool_router:
1865
  # Emit ready event after initialization
1866
  await session.send_event(
1867
- Event(event_type="ready", data={
1868
- "message": "Agent initialized",
1869
- "tool_count": len(tool_router.tools),
1870
- })
 
 
 
1871
  )
1872
 
1873
  while session.is_running:
 
5
  import asyncio
6
  import json
7
  import logging
 
8
  import time
9
  from dataclasses import dataclass, field
10
  from typing import Any
 
26
  from agent.messaging.gateway import NotificationGateway
27
  from agent.core import telemetry
28
  from agent.core.doom_loop import check_for_doom_loop
29
+ from agent.core.hub_artifacts import start_session_artifact_collection_task
30
  from agent.core.llm_params import _resolve_llm_params
31
  from agent.core.prompt_caching import with_prompt_caching
32
  from agent.core.session import Event, OpType, Session
 
54
  end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX))
55
  if end == -1:
56
  return None
57
+ return content[len(_MALFORMED_TOOL_PREFIX) : end]
58
 
59
 
60
  def _detect_repeated_malformed(
61
+ items: list[Message],
62
+ threshold: int = 2,
63
  ) -> str | None:
64
  """Return the repeated malformed tool name if the tail contains a streak.
65
 
 
119
 
120
  _IMMEDIATE_HF_JOB_RUNS = {"run", "uv"}
121
 
122
+
123
  @dataclass(frozen=True)
124
  class ApprovalDecision:
125
  requires_approval: bool
 
144
 
145
 
146
  def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool:
147
+ return tool_name == "sandbox_create" or _is_immediate_hf_job_run(
148
+ tool_name, tool_args
149
+ )
150
 
151
 
152
  def _base_needs_approval(
 
235
 
236
 
237
  def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool:
238
+ return bool(
239
+ (config and config.yolo_mode) or _session_auto_approval_enabled(session)
240
+ )
241
 
242
 
243
  def _remaining_budget_after_reservations(
 
257
  ) -> str | None:
258
  if estimate.estimated_cost_usd is None:
259
  return estimate.block_reason or "Could not estimate the cost safely."
260
+ if (
261
+ remaining_cap_usd is not None
262
+ and estimate.estimated_cost_usd > remaining_cap_usd
263
+ ):
264
  return (
265
  f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds "
266
  f"remaining YOLO cap ${remaining_cap_usd:.2f}."
 
418
  """Return True for errors that are likely transient and worth retrying."""
419
  err_str = str(error).lower()
420
  transient_patterns = [
421
+ "timeout",
422
+ "timed out",
423
+ "503",
424
+ "service unavailable",
425
+ "502",
426
+ "bad gateway",
427
+ "500",
428
+ "internal server error",
429
+ "overloaded",
430
+ "capacity",
431
+ "connection reset",
432
+ "connection refused",
433
+ "connection error",
434
+ "eof",
435
+ "broken pipe",
436
  ]
437
+ return _is_rate_limit_error(error) or any(
438
+ pattern in err_str for pattern in transient_patterns
439
+ )
440
 
441
 
442
  def _is_effort_config_error(error: Exception) -> bool:
 
448
  doesn't work for the current model. We heal the cache and retry once.
449
  """
450
  from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported
451
+
452
  return _is_thinking_unsupported(error) or _is_invalid_effort(error)
453
 
454
 
455
  async def _heal_effort_and_rebuild_params(
456
+ session: Session,
457
+ error: Exception,
458
+ llm_params: dict,
459
  ) -> dict:
460
  """Update the session's effort cache based on ``error`` and return new
461
  llm_params. Called only when ``_is_effort_config_error(error)`` is True.
 
466
  • invalid-effort → re-run the full cascade probe; the result lands
467
  in the cache
468
  """
469
+ from agent.core.effort_probe import (
470
+ ProbeInconclusive,
471
+ _is_thinking_unsupported,
472
+ probe_effort,
473
+ )
474
 
475
  model = session.config.model_name
476
  if _is_thinking_unsupported(error):
 
479
  else:
480
  try:
481
  outcome = await probe_effort(
482
+ model,
483
+ session.config.reasoning_effort,
484
+ session.hf_token,
485
  session=session,
486
  )
487
  session.model_effective_effort[model] = outcome.effective_effort
488
  logger.info(
489
+ "healed: %s effort cascade → %s",
490
+ model,
491
+ outcome.effective_effort,
492
  )
493
  except ProbeInconclusive:
494
  # Transient during healing — strip thinking for safety, next
 
507
  """Return a user-friendly message for known error types, or None to fall back to traceback."""
508
  err_str = str(error).lower()
509
 
510
+ if (
511
+ "authentication" in err_str
512
+ or "unauthorized" in err_str
513
+ or "invalid x-api-key" in err_str
514
+ ):
515
  return (
516
  "Authentication failed — your API key is missing or invalid.\n\n"
517
  "To fix this, set the API key for your model provider:\n"
 
537
  )
538
 
539
  if "model_not_found" in err_str or (
540
+ "model" in err_str and ("not found" in err_str or "does not exist" in err_str)
 
541
  ):
542
  return (
543
  "Model not found. Use '/model' to list suggestions, or paste an "
 
563
  old_usage = cm.running_context_usage
564
  logger.debug(
565
  "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s",
566
+ old_usage,
567
+ cm.model_max_tokens,
568
+ cm.compaction_threshold,
569
+ cm.needs_compaction,
570
  )
571
  try:
572
  await cm.compact(
 
578
  except CompactionFailedError as e:
579
  logger.error(
580
  "Compaction failed for session %s: %s — terminating session",
581
+ session.session_id,
582
+ e,
583
  )
584
  # Persist the failure event so the dataset has a record of WHY this
585
  # session ended (and the cost it incurred up to that point) even if
586
  # save_and_upload_detached has issues downstream.
587
+ await session.send_event(
588
+ Event(
589
+ event_type="session_terminated",
590
+ data={
591
+ "reason": "compaction_failed",
592
+ "context_usage": cm.running_context_usage,
593
+ "context_threshold": cm.compaction_threshold,
594
+ "error": str(e)[:300],
595
+ "user_message": (
596
+ "Your conversation has grown too large to continue. "
597
+ "The work you've done is saved — start a new session to keep going."
598
+ ),
599
+ },
600
+ )
601
+ )
602
  # Stop the agent loop; the finally in _run_session will fire
603
  # cleanup_sandbox + save_trajectory so the dataset captures
604
  # everything that did happen.
 
609
  if new_usage != old_usage:
610
  logger.warning(
611
  "Context compacted: %d -> %d tokens (max=%d, %d messages)",
612
+ old_usage,
613
+ new_usage,
614
+ cm.model_max_tokens,
615
+ len(cm.items),
616
  )
617
  await session.send_event(
618
  Event(
 
651
  @dataclass
652
  class LLMResult:
653
  """Result from an LLM call (streaming or non-streaming)."""
654
+
655
  content: str | None
656
  tool_calls_acc: dict[int, dict]
657
  token_count: int
 
771
  if not stripped:
772
  return False
773
 
774
+ await session.send_event(
775
+ Event(
776
+ event_type="tool_log",
777
+ data={
778
+ "tool": "system",
779
+ "log": (
780
+ "Anthropic rejected stale thinking signatures; retrying "
781
+ "without replayed thinking metadata."
782
+ ),
783
+ },
784
+ )
785
+ )
786
  return True
787
 
788
 
 
807
  return Message(**kwargs)
808
 
809
 
810
+ async def _call_llm_streaming(
811
+ session: Session, messages, tools, llm_params
812
+ ) -> LLMResult:
813
  """Call the LLM with streaming, emitting assistant_chunk events."""
814
  response = None
815
  _healed_effort = False # one-shot safety net per call
 
835
  raise ContextWindowExceededError(str(e)) from e
836
  if not _healed_effort and _is_effort_config_error(e):
837
  _healed_effort = True
838
+ llm_params = await _heal_effort_and_rebuild_params(
839
+ session, e, llm_params
840
+ )
841
+ await session.send_event(
842
+ Event(
843
+ event_type="tool_log",
844
+ data={
845
+ "tool": "system",
846
+ "log": "Reasoning effort not supported for this model — adjusting and retrying.",
847
+ },
848
+ )
849
+ )
850
  continue
851
  if await _maybe_heal_invalid_thinking_signature(
852
  session,
 
860
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
861
  logger.warning(
862
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
863
+ _llm_attempt + 1,
864
+ _MAX_LLM_RETRIES,
865
+ e,
866
+ _delay,
867
+ )
868
+ await session.send_event(
869
+ Event(
870
+ event_type="tool_log",
871
+ data={
872
+ "tool": "system",
873
+ "log": f"LLM connection error, retrying in {_delay}s...",
874
+ },
875
+ )
876
  )
 
 
 
 
877
  await asyncio.sleep(_delay)
878
  continue
879
  raise
 
914
  idx = tc_delta.index
915
  if idx not in tool_calls_acc:
916
  tool_calls_acc[idx] = {
917
+ "id": "",
918
+ "type": "function",
919
  "function": {"name": "", "arguments": ""},
920
  }
921
  if tc_delta.id:
922
  tool_calls_acc[idx]["id"] = tc_delta.id
923
  if tc_delta.function:
924
  if tc_delta.function.name:
925
+ tool_calls_acc[idx]["function"]["name"] += (
926
+ tc_delta.function.name
927
+ )
928
  if tc_delta.function.arguments:
929
+ tool_calls_acc[idx]["function"]["arguments"] += (
930
+ tc_delta.function.arguments
931
+ )
932
 
933
  if hasattr(chunk, "usage") and chunk.usage:
934
  token_count = chunk.usage.total_tokens
 
948
  rebuilt = stream_chunk_builder(chunks, messages=messages)
949
  if rebuilt and getattr(rebuilt, "choices", None):
950
  rebuilt_msg = rebuilt.choices[0].message
951
+ thinking_blocks, reasoning_content = _extract_thinking_state(
952
+ rebuilt_msg
953
+ )
954
  except Exception:
955
  logger.debug("Failed to rebuild streaming thinking state", exc_info=True)
956
 
 
965
  )
966
 
967
 
968
+ async def _call_llm_non_streaming(
969
+ session: Session, messages, tools, llm_params
970
+ ) -> LLMResult:
971
  """Call the LLM without streaming, emit assistant_message at the end."""
972
  response = None
973
  _healed_effort = False
 
992
  raise ContextWindowExceededError(str(e)) from e
993
  if not _healed_effort and _is_effort_config_error(e):
994
  _healed_effort = True
995
+ llm_params = await _heal_effort_and_rebuild_params(
996
+ session, e, llm_params
997
+ )
998
+ await session.send_event(
999
+ Event(
1000
+ event_type="tool_log",
1001
+ data={
1002
+ "tool": "system",
1003
+ "log": "Reasoning effort not supported for this model — adjusting and retrying.",
1004
+ },
1005
+ )
1006
+ )
1007
  continue
1008
  if await _maybe_heal_invalid_thinking_signature(
1009
  session,
 
1017
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
1018
  logger.warning(
1019
  "Transient LLM error (attempt %d/%d): %s — retrying in %ds",
1020
+ _llm_attempt + 1,
1021
+ _MAX_LLM_RETRIES,
1022
+ e,
1023
+ _delay,
1024
+ )
1025
+ await session.send_event(
1026
+ Event(
1027
+ event_type="tool_log",
1028
+ data={
1029
+ "tool": "system",
1030
+ "log": f"LLM connection error, retrying in {_delay}s...",
1031
+ },
1032
+ )
1033
  )
 
 
 
 
1034
  await asyncio.sleep(_delay)
1035
  continue
1036
  raise
 
1123
 
1124
  @staticmethod
1125
  async def run_agent(
1126
+ session: Session,
1127
+ text: str,
1128
  ) -> str | None:
1129
  """
1130
  Handle user input (like user_input_or_turn in codex.rs:1291)
 
1211
  llm_params = _resolve_llm_params(
1212
  session.config.model_name,
1213
  session.hf_token,
1214
+ reasoning_effort=session.effective_effort_for(
1215
+ session.config.model_name
1216
+ ),
1217
  )
1218
  if session.stream:
1219
+ llm_result = await _call_llm_streaming(
1220
+ session, messages, tools, llm_params
1221
+ )
1222
  else:
1223
+ llm_result = await _call_llm_non_streaming(
1224
+ session, messages, tools, llm_params
1225
+ )
1226
 
1227
  content = llm_result.content
1228
  tool_calls_acc = llm_result.tool_calls_acc
 
1269
  await session.send_event(
1270
  Event(
1271
  event_type="tool_log",
1272
+ data={
1273
+ "tool": "system",
1274
+ "log": f"Output truncated — retrying with smaller content ({dropped_names})",
1275
+ },
1276
  )
1277
  )
1278
  iteration += 1
 
1335
  except (json.JSONDecodeError, TypeError, ValueError):
1336
  logger.warning(
1337
  "Malformed arguments for tool_call %s (%s) — skipping",
1338
+ tc.id,
1339
+ tc.function.name,
1340
  )
1341
  tc.function.arguments = "{}"
1342
  bad_tools.append(tc)
 
1357
  f"arguments and was NOT executed. Retry with smaller content — "
1358
  f"for 'write', split into multiple smaller writes using 'edit'."
1359
  )
1360
+ session.context_manager.add_message(
1361
+ Message(
1362
+ role="tool",
1363
+ content=error_msg,
1364
+ tool_call_id=tc.id,
1365
+ name=tc.function.name,
1366
+ )
1367
+ )
1368
+ await session.send_event(
1369
+ Event(
1370
+ event_type="tool_call",
1371
+ data={
1372
+ "tool": tc.function.name,
1373
+ "arguments": {},
1374
+ "tool_call_id": tc.id,
1375
+ },
1376
+ )
1377
+ )
1378
+ await session.send_event(
1379
+ Event(
1380
+ event_type="tool_output",
1381
+ data={
1382
+ "tool": tc.function.name,
1383
+ "tool_call_id": tc.id,
1384
+ "output": error_msg,
1385
+ "success": False,
1386
+ },
1387
+ )
1388
+ )
1389
 
1390
  # ── Cancellation check: before tool execution ──
1391
  if session.is_cancelled:
 
1410
  reserved_spend_usd=reserved_auto_spend_usd,
1411
  )
1412
  if decision.requires_approval:
1413
+ approval_required_tools.append(
1414
+ (tc, tool_name, tool_args, decision)
1415
+ )
1416
  else:
1417
  non_approval_tools.append((tc, tool_name, tool_args, decision))
1418
  if (
 
1435
  )
1436
 
1437
  # 2. Send all tool_call events upfront (so frontend shows them all)
1438
+ for (
1439
+ tc,
1440
+ tool_name,
1441
+ tool_args,
1442
+ _decision,
1443
+ args_valid,
1444
+ _,
1445
+ ) in parsed_tools:
1446
  if args_valid:
1447
  await session.send_event(
1448
  Event(
 
1473
  )
1474
  return (tc, name, args, out, ok)
1475
 
1476
+ gather_task = asyncio.ensure_future(
1477
+ asyncio.gather(
1478
+ *[
1479
+ _exec_tool(tc, name, args, decision, valid, err)
1480
+ for tc, name, args, decision, valid, err in parsed_tools
1481
+ ]
1482
+ )
1483
+ )
1484
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
1485
 
1486
  done, _ = await asyncio.wait(
 
1497
  # Notify frontend that in-flight tools were cancelled
1498
  for tc, name, _args, _decision, valid, _ in parsed_tools:
1499
  if valid:
1500
+ await session.send_event(
1501
+ Event(
1502
+ event_type="tool_state_change",
1503
+ data={
1504
+ "tool_call_id": tc.id,
1505
+ "tool": name,
1506
+ "state": "cancelled",
1507
+ },
1508
+ )
1509
+ )
1510
  await _cleanup_on_cancel(session)
1511
  break
1512
 
 
1543
  for tc, tool_name, tool_args, decision in approval_required_tools:
1544
  # Resolve sandbox file paths for hf_jobs scripts so the
1545
  # frontend can display & edit the actual file content.
1546
+ if tool_name == "hf_jobs" and isinstance(
1547
+ tool_args.get("script"), str
1548
+ ):
1549
  from agent.tools.sandbox_tool import resolve_sandbox_script
1550
+
1551
  sandbox = getattr(session, "sandbox", None)
1552
+ resolved, _ = await resolve_sandbox_script(
1553
+ sandbox, tool_args["script"]
1554
+ )
1555
  if resolved:
1556
  tool_args = {**tool_args, "script": resolved}
1557
 
 
1583
  "remaining_cap_usd": first.get("remaining_cap_usd"),
1584
  }
1585
  )
1586
+ await session.send_event(
1587
+ Event(
1588
+ event_type="approval_required",
1589
+ data=event_data,
1590
+ )
1591
+ )
1592
 
1593
  # Store all approval-requiring tools (ToolCall objects for execution)
1594
  session.pending_approval = {
 
1606
  logger.warning(
1607
  "ContextWindowExceededError at iteration %d — forcing compaction "
1608
  "(usage=%d, model_max_tokens=%d, messages=%d)",
1609
+ iteration,
1610
+ cm.running_context_usage,
1611
+ cm.model_max_tokens,
1612
+ len(cm.items),
1613
  )
1614
  cm.running_context_usage = cm.model_max_tokens + 1
1615
  await _compact_and_notify(session)
 
1801
 
1802
  # Execute all approved tools concurrently (cancellable)
1803
  if approved_tasks:
1804
+ gather_task = asyncio.ensure_future(
1805
+ asyncio.gather(
1806
+ *[
1807
+ execute_tool(tc, tool_name, tool_args, was_edited)
1808
+ for tc, tool_name, tool_args, was_edited in approved_tasks
1809
+ ],
1810
+ return_exceptions=True,
1811
+ )
1812
+ )
1813
  cancel_task = asyncio.ensure_future(session._cancelled.wait())
1814
 
1815
  done, _ = await asyncio.wait(
 
1825
  pass
1826
  # Notify frontend that approved tools were cancelled
1827
  for tc, tool_name, _args, _was_edited in approved_tasks:
1828
+ await session.send_event(
1829
+ Event(
1830
+ event_type="tool_state_change",
1831
+ data={
1832
+ "tool_call_id": tc.id,
1833
+ "tool": tool_name,
1834
+ "state": "cancelled",
1835
+ },
1836
+ )
1837
+ )
1838
  await _cleanup_on_cancel(session)
1839
  await session.send_event(Event(event_type="interrupted"))
1840
  session.increment_turn()
 
1986
 
1987
  # Create session with tool router
1988
  session = Session(
1989
+ event_queue,
1990
+ config=config,
1991
+ tool_router=tool_router,
1992
+ hf_token=hf_token,
1993
+ user_id=user_id,
1994
+ local_mode=local_mode,
1995
+ stream=stream,
1996
  notification_gateway=notification_gateway,
1997
  notification_destinations=notification_destinations,
1998
  defer_turn_complete_notification=defer_turn_complete_notification,
1999
  )
2000
  if session_holder is not None:
2001
  session_holder[0] = session
2002
+ start_session_artifact_collection_task(session, token=hf_token)
2003
  logger.info("Agent loop started")
2004
 
2005
  # Retry any failed uploads from previous sessions (fire-and-forget).
 
2017
  async with tool_router:
2018
  # Emit ready event after initialization
2019
  await session.send_event(
2020
+ Event(
2021
+ event_type="ready",
2022
+ data={
2023
+ "message": "Agent initialized",
2024
+ "tool_count": len(tool_router.tools),
2025
+ },
2026
+ )
2027
  )
2028
 
2029
  while session.is_running:
agent/core/cost_estimation.py CHANGED
@@ -88,7 +88,9 @@ class CostEstimate:
88
  label: str | None = None
89
 
90
 
91
- def parse_timeout_hours(value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS) -> float | None:
 
 
92
  """Parse HF timeout values into hours.
93
 
94
  Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
@@ -247,7 +249,9 @@ async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate:
247
  )
248
 
249
 
250
- async def estimate_sandbox_cost(args: dict[str, Any], *, session: Any = None) -> CostEstimate:
 
 
251
  if session is not None and getattr(session, "sandbox", None):
252
  return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")
253
 
 
88
  label: str | None = None
89
 
90
 
91
+ def parse_timeout_hours(
92
+ value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS
93
+ ) -> float | None:
94
  """Parse HF timeout values into hours.
95
 
96
  Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are
 
249
  )
250
 
251
 
252
+ async def estimate_sandbox_cost(
253
+ args: dict[str, Any], *, session: Any = None
254
+ ) -> CostEstimate:
255
  if session is not None and getattr(session, "sandbox", None):
256
  return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing")
257
 
agent/core/doom_loop.py CHANGED
@@ -81,9 +81,11 @@ def extract_recent_tool_signatures(
81
  name = getattr(fn, "name", "") or ""
82
  args_str = getattr(fn, "arguments", "") or ""
83
  result_hash = None
84
- for follow in recent[idx + 1:]:
85
  role = getattr(follow, "role", None)
86
- if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(tc, "id", None):
 
 
87
  result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
88
  break
89
  if role in {"assistant", "user"}:
@@ -174,7 +176,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None:
174
  pattern = detect_repeating_sequence(signatures)
175
  if pattern:
176
  pattern_desc = " → ".join(s.name for s in pattern)
177
- logger.warning("Repetition guard activated: repeating sequence [%s]", pattern_desc)
 
 
178
  return (
179
  f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
180
  f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
 
81
  name = getattr(fn, "name", "") or ""
82
  args_str = getattr(fn, "arguments", "") or ""
83
  result_hash = None
84
+ for follow in recent[idx + 1 :]:
85
  role = getattr(follow, "role", None)
86
+ if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(
87
+ tc, "id", None
88
+ ):
89
  result_hash = _hash_args(str(getattr(follow, "content", "") or ""))
90
  break
91
  if role in {"assistant", "user"}:
 
176
  pattern = detect_repeating_sequence(signatures)
177
  if pattern:
178
  pattern_desc = " → ".join(s.name for s in pattern)
179
+ logger.warning(
180
+ "Repetition guard activated: repeating sequence [%s]", pattern_desc
181
+ )
182
  return (
183
  f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: "
184
  f"[{pattern_desc}]. This pattern has repeated multiple times without progress. "
agent/core/effort_probe.py CHANGED
@@ -39,12 +39,12 @@ logger = logging.getLogger(__name__)
39
  # requested level raise ``UnsupportedEffortError`` synchronously (no wasted
40
  # network round-trip) and we advance to the next level.
41
  _EFFORT_CASCADE: dict[str, list[str]] = {
42
- "max": ["max", "xhigh", "high", "medium", "low"],
43
- "xhigh": ["xhigh", "high", "medium", "low"],
44
- "high": ["high", "medium", "low"],
45
- "medium": ["medium", "low"],
46
  "minimal": ["minimal", "low"],
47
- "low": ["low"],
48
  }
49
 
50
  _PROBE_TIMEOUT = 15.0
@@ -69,6 +69,7 @@ class ProbeOutcome:
69
  * str → send this level
70
  * None → model doesn't support thinking; strip it
71
  """
 
72
  effective_effort: str | None
73
  attempts: int
74
  elapsed_ms: int
@@ -108,10 +109,15 @@ def _is_invalid_effort(e: Exception) -> bool:
108
  return any(
109
  phrase in s
110
  for phrase in (
111
- "invalid", "not supported", "must be one of", "not a valid",
112
- "unrecognized", "unknown",
 
 
 
 
113
  # LiteLLM's own pre-flight validation phrasing.
114
- "only supported by", "is only supported",
 
115
  )
116
  )
117
 
@@ -128,11 +134,23 @@ def _is_transient(e: Exception) -> bool:
128
  return any(
129
  p in s
130
  for p in (
131
- "timeout", "timed out", "429", "rate limit",
132
- "503", "service unavailable", "502", "bad gateway",
133
- "500", "internal server error", "overloaded", "capacity",
134
- "connection reset", "connection refused", "connection error",
135
- "eof", "broken pipe",
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
  )
138
 
@@ -173,7 +191,10 @@ async def probe_effort(
173
  for effort in cascade:
174
  try:
175
  params = _resolve_llm_params(
176
- model_name, hf_token, reasoning_effort=effort, strict=True,
 
 
 
177
  )
178
  except UnsupportedEffortError:
179
  # Provider can't even accept this effort name (e.g. "max" on
@@ -198,12 +219,15 @@ async def probe_effort(
198
  # out of the probe and break model switching.
199
  try:
200
  from agent.core import telemetry
 
201
  await telemetry.record_llm_call(
202
  session,
203
  model=model_name,
204
  response=response,
205
  latency_ms=int((time.monotonic() - _t0) * 1000),
206
- finish_reason=response.choices[0].finish_reason if response.choices else None,
 
 
207
  kind="effort_probe",
208
  )
209
  except Exception as _telem_err:
@@ -219,7 +243,9 @@ async def probe_effort(
219
  note="model doesn't support reasoning, dropped",
220
  )
221
  if _is_invalid_effort(e):
222
- logger.debug("probe: %s rejected effort=%s, trying next", model_name, effort)
 
 
223
  continue
224
  if _is_transient(e):
225
  raise ProbeInconclusive(str(e)) from e
 
39
  # requested level raise ``UnsupportedEffortError`` synchronously (no wasted
40
  # network round-trip) and we advance to the next level.
41
  _EFFORT_CASCADE: dict[str, list[str]] = {
42
+ "max": ["max", "xhigh", "high", "medium", "low"],
43
+ "xhigh": ["xhigh", "high", "medium", "low"],
44
+ "high": ["high", "medium", "low"],
45
+ "medium": ["medium", "low"],
46
  "minimal": ["minimal", "low"],
47
+ "low": ["low"],
48
  }
49
 
50
  _PROBE_TIMEOUT = 15.0
 
69
  * str → send this level
70
  * None → model doesn't support thinking; strip it
71
  """
72
+
73
  effective_effort: str | None
74
  attempts: int
75
  elapsed_ms: int
 
109
  return any(
110
  phrase in s
111
  for phrase in (
112
+ "invalid",
113
+ "not supported",
114
+ "must be one of",
115
+ "not a valid",
116
+ "unrecognized",
117
+ "unknown",
118
  # LiteLLM's own pre-flight validation phrasing.
119
+ "only supported by",
120
+ "is only supported",
121
  )
122
  )
123
 
 
134
  return any(
135
  p in s
136
  for p in (
137
+ "timeout",
138
+ "timed out",
139
+ "429",
140
+ "rate limit",
141
+ "503",
142
+ "service unavailable",
143
+ "502",
144
+ "bad gateway",
145
+ "500",
146
+ "internal server error",
147
+ "overloaded",
148
+ "capacity",
149
+ "connection reset",
150
+ "connection refused",
151
+ "connection error",
152
+ "eof",
153
+ "broken pipe",
154
  )
155
  )
156
 
 
191
  for effort in cascade:
192
  try:
193
  params = _resolve_llm_params(
194
+ model_name,
195
+ hf_token,
196
+ reasoning_effort=effort,
197
+ strict=True,
198
  )
199
  except UnsupportedEffortError:
200
  # Provider can't even accept this effort name (e.g. "max" on
 
219
  # out of the probe and break model switching.
220
  try:
221
  from agent.core import telemetry
222
+
223
  await telemetry.record_llm_call(
224
  session,
225
  model=model_name,
226
  response=response,
227
  latency_ms=int((time.monotonic() - _t0) * 1000),
228
+ finish_reason=response.choices[0].finish_reason
229
+ if response.choices
230
+ else None,
231
  kind="effort_probe",
232
  )
233
  except Exception as _telem_err:
 
243
  note="model doesn't support reasoning, dropped",
244
  )
245
  if _is_invalid_effort(e):
246
+ logger.debug(
247
+ "probe: %s rejected effort=%s, trying next", model_name, effort
248
+ )
249
  continue
250
  if _is_transient(e):
251
  raise ProbeInconclusive(str(e)) from e
agent/core/hf_router_catalog.py CHANGED
@@ -92,7 +92,9 @@ def _parse_entry(entry: dict) -> ModelInfo:
92
  input_price=pricing.get("input"),
93
  output_price=pricing.get("output"),
94
  supports_tools=bool(p.get("supports_tools", False)),
95
- supports_structured_output=bool(p.get("supports_structured_output", False)),
 
 
96
  )
97
  )
98
  return ModelInfo(id=entry.get("id", ""), providers=providers)
 
92
  input_price=pricing.get("input"),
93
  output_price=pricing.get("output"),
94
  supports_tools=bool(p.get("supports_tools", False)),
95
+ supports_structured_output=bool(
96
+ p.get("supports_structured_output", False)
97
+ ),
98
  )
99
  )
100
  return ModelInfo(id=entry.get("id", ""), providers=providers)
agent/core/hub_artifacts.py ADDED
@@ -0,0 +1,765 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Best-effort Hub metadata for artifacts generated by ML Intern sessions."""
2
+
3
+ import asyncio
4
+ import base64
5
+ import logging
6
+ import re
7
+ import shlex
8
+ import tempfile
9
+ import textwrap
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ from huggingface_hub import HfApi, hf_hub_download
15
+ from huggingface_hub.repocard import metadata_load, metadata_save
16
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ ML_INTERN_TAG = "ml-intern"
21
+ SUPPORTED_REPO_TYPES = {"model", "dataset", "space"}
22
+ PROVENANCE_MARKER = "<!-- ml-intern-provenance -->"
23
+ _COLLECTION_TITLE_PREFIX = "ml-intern-artifacts"
24
+ _COLLECTION_TITLE_MAX_LENGTH = 59
25
+ _UUID_SESSION_ID_RE = re.compile(
26
+ r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-"
27
+ r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"
28
+ )
29
+ _KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts"
30
+ _REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts"
31
+ _COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug"
32
+ _COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task"
33
+ _SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {}
34
+ _USAGE_HEADING_RE = re.compile(
35
+ r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b",
36
+ re.IGNORECASE | re.MULTILINE,
37
+ )
38
+ _FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL)
39
+
40
+
41
+ def _safe_session_id(session: Any) -> str:
42
+ raw = str(getattr(session, "session_id", "") or "unknown-session")
43
+ safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-")
44
+ return safe or "unknown-session"
45
+
46
+
47
+ def session_artifact_date(session: Any) -> str:
48
+ """Return the YYYY-MM-DD partition date for a session."""
49
+ raw = getattr(session, "session_start_time", None)
50
+ if raw:
51
+ try:
52
+ return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime(
53
+ "%Y-%m-%d"
54
+ )
55
+ except ValueError:
56
+ logger.debug("Could not parse session_start_time=%r", raw)
57
+ return datetime.utcnow().strftime("%Y-%m-%d")
58
+
59
+
60
+ def _collection_session_id_fragment(session: Any) -> str:
61
+ safe_id = _safe_session_id(session)
62
+ if _UUID_SESSION_ID_RE.match(safe_id):
63
+ return safe_id[:8]
64
+ stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
65
+ max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem))
66
+ if len(safe_id) <= max_id_length:
67
+ return safe_id
68
+ return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length]
69
+
70
+
71
+ def artifact_collection_title(session: Any) -> str:
72
+ return (
73
+ f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-"
74
+ f"{_collection_session_id_fragment(session)}"
75
+ )
76
+
77
+
78
+ def _artifact_key(repo_id: str, repo_type: str | None) -> str:
79
+ return f"{repo_type or 'model'}:{repo_id}"
80
+
81
+
82
+ def _session_artifact_set(session: Any, attr: str) -> set[str]:
83
+ current = getattr(session, attr, None)
84
+ if isinstance(current, set):
85
+ return current
86
+ current = set()
87
+ try:
88
+ setattr(session, attr, current)
89
+ except Exception:
90
+ logger.warning(
91
+ "Could not attach %s to session; using process-local fallback state",
92
+ attr,
93
+ )
94
+ return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set())
95
+ return current
96
+
97
+
98
+ def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None:
99
+ if session is None or not repo_id:
100
+ return
101
+ _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add(
102
+ _artifact_key(repo_id, repo_type)
103
+ )
104
+
105
+
106
+ def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool:
107
+ if session is None or not repo_id:
108
+ return False
109
+ return _artifact_key(repo_id, repo_type) in _session_artifact_set(
110
+ session, _KNOWN_ARTIFACTS_ATTR
111
+ )
112
+
113
+
114
+ def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]:
115
+ merged = dict(metadata)
116
+ raw_tags = merged.get("tags")
117
+ if raw_tags is None:
118
+ tags: list[str] = []
119
+ elif isinstance(raw_tags, str):
120
+ tags = [raw_tags]
121
+ elif isinstance(raw_tags, list):
122
+ tags = [str(item) for item in raw_tags]
123
+ else:
124
+ tags = [str(raw_tags)]
125
+
126
+ if tag not in tags:
127
+ tags.append(tag)
128
+ merged["tags"] = tags
129
+ return merged
130
+
131
+
132
+ def _metadata_from_content(content: str) -> dict[str, Any]:
133
+ with tempfile.TemporaryDirectory() as tmp_dir:
134
+ path = Path(tmp_dir) / "README.md"
135
+ path.write_text(content, encoding="utf-8")
136
+ return metadata_load(path) or {}
137
+
138
+
139
+ def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str:
140
+ with tempfile.TemporaryDirectory() as tmp_dir:
141
+ path = Path(tmp_dir) / "README.md"
142
+ path.write_text(content, encoding="utf-8")
143
+ metadata_save(path, metadata)
144
+ return path.read_text(encoding="utf-8")
145
+
146
+
147
+ def _body_without_metadata(content: str) -> str:
148
+ return _FRONT_MATTER_RE.sub("", content, count=1).strip()
149
+
150
+
151
+ def _append_section(content: str, section: str) -> str:
152
+ base = content.rstrip()
153
+ if base:
154
+ return f"{base}\n\n{section.strip()}\n"
155
+ return f"{section.strip()}\n"
156
+
157
+
158
+ def _provenance_section(repo_type: str) -> str:
159
+ label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub")
160
+ return f"""{PROVENANCE_MARKER}
161
+ ## Generated by ML Intern
162
+
163
+ This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
164
+
165
+ - Try ML Intern: https://smolagents-ml-intern.hf.space
166
+ - Source code: https://github.com/huggingface/ml-intern
167
+ """
168
+
169
+
170
+ def _usage_section(repo_id: str, repo_type: str) -> str:
171
+ if repo_type == "dataset":
172
+ return f"""## Usage
173
+
174
+ ```python
175
+ from datasets import load_dataset
176
+
177
+ dataset = load_dataset("{repo_id}")
178
+ ```
179
+ """
180
+
181
+ return f"""## Usage
182
+
183
+ ```python
184
+ from transformers import AutoModelForCausalLM, AutoTokenizer
185
+
186
+ model_id = "{repo_id}"
187
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
188
+ model = AutoModelForCausalLM.from_pretrained(model_id)
189
+ ```
190
+
191
+ For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
192
+ """
193
+
194
+
195
+ def augment_repo_card_content(
196
+ content: str | None,
197
+ repo_id: str,
198
+ repo_type: str = "model",
199
+ *,
200
+ extra_metadata: dict[str, Any] | None = None,
201
+ ) -> str:
202
+ """Return README content with ML Intern metadata and provenance added."""
203
+ repo_type = repo_type or "model"
204
+ content = content or ""
205
+ metadata = _metadata_from_content(content)
206
+ if extra_metadata:
207
+ metadata = {**extra_metadata, **metadata}
208
+ metadata = _merge_tags(metadata)
209
+ updated = _content_with_metadata(content, metadata)
210
+
211
+ if not _body_without_metadata(updated):
212
+ updated = _append_section(updated, f"# {repo_id}")
213
+
214
+ if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated:
215
+ updated = _append_section(updated, _provenance_section(repo_type))
216
+ if not _USAGE_HEADING_RE.search(content):
217
+ updated = _append_section(updated, _usage_section(repo_id, repo_type))
218
+
219
+ return updated
220
+
221
+
222
+ def _read_remote_readme(
223
+ api: Any,
224
+ repo_id: str,
225
+ repo_type: str,
226
+ *,
227
+ token: str | bool | None = None,
228
+ ) -> str:
229
+ token_value = token if token is not None else getattr(api, "token", None)
230
+ try:
231
+ readme_path = hf_hub_download(
232
+ repo_id=repo_id,
233
+ filename="README.md",
234
+ repo_type=repo_type,
235
+ token=token_value,
236
+ )
237
+ except (EntryNotFoundError, RepositoryNotFoundError):
238
+ return ""
239
+ return Path(readme_path).read_text(encoding="utf-8")
240
+
241
+
242
+ def _update_repo_card(
243
+ api: Any,
244
+ repo_id: str,
245
+ repo_type: str,
246
+ *,
247
+ token: str | bool | None = None,
248
+ extra_metadata: dict[str, Any] | None = None,
249
+ ) -> None:
250
+ current = _read_remote_readme(api, repo_id, repo_type, token=token)
251
+ updated = augment_repo_card_content(
252
+ current,
253
+ repo_id,
254
+ repo_type,
255
+ extra_metadata=extra_metadata,
256
+ )
257
+ if updated == current:
258
+ return
259
+ api.upload_file(
260
+ path_or_fileobj=updated.encode("utf-8"),
261
+ path_in_repo="README.md",
262
+ repo_id=repo_id,
263
+ repo_type=repo_type,
264
+ token=token,
265
+ commit_message="Update ML Intern artifact metadata",
266
+ )
267
+
268
+
269
+ def _ensure_collection_slug(
270
+ api: Any,
271
+ session: Any,
272
+ *,
273
+ token: str | bool | None = None,
274
+ ) -> str | None:
275
+ slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
276
+ if slug:
277
+ return slug
278
+
279
+ title = artifact_collection_title(session)
280
+ collection = api.create_collection(
281
+ title=title,
282
+ description=(
283
+ f"Artifacts generated by ML Intern session {_safe_session_id(session)} "
284
+ f"on {session_artifact_date(session)}."
285
+ ),
286
+ private=True,
287
+ exists_ok=True,
288
+ token=token,
289
+ )
290
+ slug = getattr(collection, "slug", None)
291
+ if slug:
292
+ setattr(session, _COLLECTION_SLUG_ATTR, slug)
293
+ return slug
294
+
295
+
296
+ async def ensure_session_artifact_collection(
297
+ session: Any,
298
+ *,
299
+ token: str | bool | None = None,
300
+ ) -> str | None:
301
+ """Create/cache the per-session artifact collection without raising."""
302
+ if session is None or not getattr(session, "session_id", None):
303
+ return None
304
+ token_value = token if token is not None else getattr(session, "hf_token", None)
305
+ if not token_value:
306
+ return None
307
+
308
+ try:
309
+ api = HfApi(token=token_value)
310
+ return await asyncio.to_thread(
311
+ _ensure_collection_slug,
312
+ api,
313
+ session,
314
+ token=token_value,
315
+ )
316
+ except Exception as e:
317
+ logger.warning(
318
+ "ML Intern session collection creation failed for %s: %s",
319
+ _safe_session_id(session),
320
+ e,
321
+ )
322
+ return None
323
+
324
+
325
+ def start_session_artifact_collection_task(
326
+ session: Any,
327
+ *,
328
+ token: str | bool | None = None,
329
+ ) -> asyncio.Task | None:
330
+ """Schedule best-effort collection creation for a newly started session."""
331
+ if session is None or not getattr(session, "session_id", None):
332
+ return None
333
+ if getattr(session, _COLLECTION_SLUG_ATTR, None):
334
+ return None
335
+
336
+ token_value = token if token is not None else getattr(session, "hf_token", None)
337
+ if not token_value:
338
+ return None
339
+
340
+ existing = getattr(session, _COLLECTION_TASK_ATTR, None)
341
+ if isinstance(existing, asyncio.Task) and not existing.done():
342
+ return existing
343
+
344
+ try:
345
+ loop = asyncio.get_running_loop()
346
+ except RuntimeError:
347
+ return None
348
+
349
+ async def _run() -> None:
350
+ await ensure_session_artifact_collection(session, token=token_value)
351
+
352
+ task = loop.create_task(_run())
353
+ try:
354
+ setattr(session, _COLLECTION_TASK_ATTR, task)
355
+ except Exception:
356
+ logger.debug("Could not attach ML Intern collection task to session")
357
+ return task
358
+
359
+
360
+ def _add_to_collection(
361
+ api: Any,
362
+ session: Any,
363
+ repo_id: str,
364
+ repo_type: str,
365
+ *,
366
+ token: str | bool | None = None,
367
+ ) -> None:
368
+ slug = _ensure_collection_slug(api, session, token=token)
369
+ if not slug:
370
+ return
371
+ api.add_collection_item(
372
+ collection_slug=slug,
373
+ item_id=repo_id,
374
+ item_type=repo_type,
375
+ note=(
376
+ f"Generated by ML Intern session {_safe_session_id(session)} "
377
+ f"on {session_artifact_date(session)}."
378
+ ),
379
+ exists_ok=True,
380
+ token=token,
381
+ )
382
+
383
+
384
+ def register_hub_artifact(
385
+ api: Any,
386
+ repo_id: str,
387
+ repo_type: str = "model",
388
+ *,
389
+ session: Any = None,
390
+ token: str | bool | None = None,
391
+ extra_metadata: dict[str, Any] | None = None,
392
+ force: bool = False,
393
+ ) -> bool:
394
+ """Tag, card, and collection-register a Hub artifact without raising."""
395
+ if session is None or not repo_id:
396
+ return False
397
+ repo_type = repo_type or "model"
398
+ if repo_type not in SUPPORTED_REPO_TYPES:
399
+ return False
400
+
401
+ key = _artifact_key(repo_id, repo_type)
402
+ remember_hub_artifact(session, repo_id, repo_type)
403
+ registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR)
404
+ if key in registered and not force:
405
+ return True
406
+
407
+ token_value = token if token is not None else getattr(api, "token", None)
408
+ card_updated = False
409
+ collection_updated = False
410
+ try:
411
+ _update_repo_card(
412
+ api,
413
+ repo_id,
414
+ repo_type,
415
+ token=token_value,
416
+ extra_metadata=extra_metadata,
417
+ )
418
+ card_updated = True
419
+ except Exception as e:
420
+ logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e)
421
+
422
+ try:
423
+ _add_to_collection(api, session, repo_id, repo_type, token=token_value)
424
+ collection_updated = True
425
+ except Exception as e:
426
+ logger.debug("ML Intern collection update failed for %s: %s", repo_id, e)
427
+
428
+ if card_updated and collection_updated:
429
+ registered.add(key)
430
+ return True
431
+ return False
432
+
433
+
434
+ def build_hub_artifact_sitecustomize(session: Any) -> str:
435
+ """Build standalone sitecustomize.py code for HF Jobs Python processes."""
436
+ if session is None or not getattr(session, "session_id", None):
437
+ return ""
438
+
439
+ session_id = _safe_session_id(session)
440
+ session_date = session_artifact_date(session)
441
+ collection_title = artifact_collection_title(session)
442
+ collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None)
443
+
444
+ return (
445
+ textwrap.dedent(
446
+ f"""
447
+ # Auto-generated by ML Intern. Best-effort Hub artifact metadata only.
448
+ def _install_ml_intern_artifact_hooks():
449
+ import os
450
+ import re
451
+ import tempfile
452
+ from pathlib import Path
453
+
454
+ try:
455
+ import huggingface_hub as _hub
456
+ from huggingface_hub import HfApi, hf_hub_download
457
+ from huggingface_hub.repocard import metadata_load, metadata_save
458
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
459
+ except Exception:
460
+ return
461
+
462
+ session_id = {session_id!r}
463
+ session_date = {session_date!r}
464
+ collection_title = {collection_title!r}
465
+ tag = {ML_INTERN_TAG!r}
466
+ marker = {PROVENANCE_MARKER!r}
467
+ supported = {sorted(SUPPORTED_REPO_TYPES)!r}
468
+ registering = False
469
+ collection_slug = {collection_slug!r}
470
+ registered = set()
471
+ usage_re = re.compile(
472
+ r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b",
473
+ re.IGNORECASE | re.MULTILINE,
474
+ )
475
+ front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL)
476
+
477
+ def _token(value=None, api=None):
478
+ if isinstance(value, str) and value:
479
+ return value
480
+ api_token = getattr(api, "token", None)
481
+ if isinstance(api_token, str) and api_token:
482
+ return api_token
483
+ return (
484
+ os.environ.get("HF_TOKEN")
485
+ or os.environ.get("HUGGINGFACE_HUB_TOKEN")
486
+ or None
487
+ )
488
+
489
+ def _merge_tags(metadata):
490
+ metadata = dict(metadata or {{}})
491
+ raw_tags = metadata.get("tags")
492
+ if raw_tags is None:
493
+ tags = []
494
+ elif isinstance(raw_tags, str):
495
+ tags = [raw_tags]
496
+ elif isinstance(raw_tags, list):
497
+ tags = [str(item) for item in raw_tags]
498
+ else:
499
+ tags = [str(raw_tags)]
500
+ if tag not in tags:
501
+ tags.append(tag)
502
+ metadata["tags"] = tags
503
+ return metadata
504
+
505
+ def _metadata_from_content(content):
506
+ with tempfile.TemporaryDirectory() as tmp_dir:
507
+ path = Path(tmp_dir) / "README.md"
508
+ path.write_text(content or "", encoding="utf-8")
509
+ return metadata_load(path) or {{}}
510
+
511
+ def _content_with_metadata(content, metadata):
512
+ with tempfile.TemporaryDirectory() as tmp_dir:
513
+ path = Path(tmp_dir) / "README.md"
514
+ path.write_text(content or "", encoding="utf-8")
515
+ metadata_save(path, metadata)
516
+ return path.read_text(encoding="utf-8")
517
+
518
+ def _body_without_metadata(content):
519
+ return front_matter_re.sub("", content or "", count=1).strip()
520
+
521
+ def _append_section(content, section):
522
+ base = (content or "").rstrip()
523
+ if base:
524
+ return base + "\\n\\n" + section.strip() + "\\n"
525
+ return section.strip() + "\\n"
526
+
527
+ def _provenance(repo_type):
528
+ label = {{"model": "model", "dataset": "dataset"}}.get(
529
+ repo_type, "Hub"
530
+ )
531
+ return (
532
+ marker
533
+ + "\\n## Generated by ML Intern\\n\\n"
534
+ + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n"
535
+ + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n"
536
+ + "- Source code: https://github.com/huggingface/ml-intern\\n"
537
+ )
538
+
539
+ def _usage(repo_id, repo_type):
540
+ if repo_type == "dataset":
541
+ return (
542
+ "## Usage\\n\\n"
543
+ "```python\\n"
544
+ "from datasets import load_dataset\\n\\n"
545
+ f"dataset = load_dataset({{repo_id!r}})\\n"
546
+ "```\\n"
547
+ )
548
+ return (
549
+ "## Usage\\n\\n"
550
+ "```python\\n"
551
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n"
552
+ f"model_id = {{repo_id!r}}\\n"
553
+ "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n"
554
+ "model = AutoModelForCausalLM.from_pretrained(model_id)\\n"
555
+ "```\\n\\n"
556
+ "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n"
557
+ )
558
+
559
+ def _augment(content, repo_id, repo_type, extra_metadata=None):
560
+ metadata = _metadata_from_content(content or "")
561
+ if extra_metadata:
562
+ metadata = {{**extra_metadata, **metadata}}
563
+ updated = _content_with_metadata(content or "", _merge_tags(metadata))
564
+ if not _body_without_metadata(updated):
565
+ updated = _append_section(updated, f"# {{repo_id}}")
566
+ if repo_type in {{"model", "dataset"}} and marker not in updated:
567
+ updated = _append_section(updated, _provenance(repo_type))
568
+ if not usage_re.search(content or ""):
569
+ updated = _append_section(updated, _usage(repo_id, repo_type))
570
+ return updated
571
+
572
+ def _readme(api, repo_id, repo_type, token_value):
573
+ try:
574
+ path = hf_hub_download(
575
+ repo_id=repo_id,
576
+ filename="README.md",
577
+ repo_type=repo_type,
578
+ token=token_value,
579
+ )
580
+ except (EntryNotFoundError, RepositoryNotFoundError):
581
+ return ""
582
+ return Path(path).read_text(encoding="utf-8")
583
+
584
+ def _ensure_collection(api, token_value):
585
+ nonlocal collection_slug
586
+ if collection_slug:
587
+ return collection_slug
588
+ collection = api.create_collection(
589
+ title=collection_title,
590
+ description=(
591
+ f"Artifacts generated by ML Intern session {{session_id}} "
592
+ f"on {{session_date}}."
593
+ ),
594
+ private=True,
595
+ exists_ok=True,
596
+ token=token_value,
597
+ )
598
+ collection_slug = getattr(collection, "slug", None)
599
+ return collection_slug
600
+
601
+ def _register(
602
+ repo_id,
603
+ repo_type="model",
604
+ token_value=None,
605
+ extra_metadata=None,
606
+ force=False,
607
+ ):
608
+ nonlocal registering
609
+ if registering or not repo_id:
610
+ return
611
+ repo_type = repo_type or "model"
612
+ if repo_type not in supported:
613
+ return
614
+ key = f"{{repo_type}}:{{repo_id}}"
615
+ if key in registered and not force:
616
+ return
617
+ registering = True
618
+ try:
619
+ token_value = _token(token_value)
620
+ api = HfApi(token=token_value)
621
+ try:
622
+ current = _readme(api, repo_id, repo_type, token_value)
623
+ updated = _augment(
624
+ current, repo_id, repo_type, extra_metadata=extra_metadata
625
+ )
626
+ if updated != current:
627
+ _original_upload_file(
628
+ api,
629
+ path_or_fileobj=updated.encode("utf-8"),
630
+ path_in_repo="README.md",
631
+ repo_id=repo_id,
632
+ repo_type=repo_type,
633
+ token=token_value,
634
+ commit_message="Update ML Intern artifact metadata",
635
+ )
636
+ except Exception:
637
+ pass
638
+ try:
639
+ slug = _ensure_collection(api, token_value)
640
+ if slug:
641
+ api.add_collection_item(
642
+ collection_slug=slug,
643
+ item_id=repo_id,
644
+ item_type=repo_type,
645
+ note=(
646
+ f"Generated by ML Intern session {{session_id}} "
647
+ f"on {{session_date}}."
648
+ ),
649
+ exists_ok=True,
650
+ token=token_value,
651
+ )
652
+ except Exception:
653
+ pass
654
+ registered.add(key)
655
+ finally:
656
+ registering = False
657
+
658
+ _original_create_repo = HfApi.create_repo
659
+ _original_upload_file = HfApi.upload_file
660
+ _original_upload_folder = getattr(HfApi, "upload_folder", None)
661
+ _original_create_commit = getattr(HfApi, "create_commit", None)
662
+
663
+ def _repo_id(args, kwargs):
664
+ return kwargs.get("repo_id") or (args[0] if args else None)
665
+
666
+ def _repo_type(kwargs):
667
+ return kwargs.get("repo_type") or "model"
668
+
669
+ def _patched_create_repo(self, *args, **kwargs):
670
+ result = _original_create_repo(self, *args, **kwargs)
671
+ repo_id = _repo_id(args, kwargs)
672
+ repo_type = _repo_type(kwargs)
673
+ extra = None
674
+ if repo_type == "space" and kwargs.get("space_sdk"):
675
+ extra = {{"sdk": kwargs.get("space_sdk")}}
676
+ _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra)
677
+ return result
678
+
679
+ def _patched_upload_file(self, *args, **kwargs):
680
+ result = _original_upload_file(self, *args, **kwargs)
681
+ if not kwargs.get("create_pr"):
682
+ force = kwargs.get("path_in_repo") == "README.md"
683
+ _register(
684
+ kwargs.get("repo_id"),
685
+ _repo_type(kwargs),
686
+ _token(kwargs.get("token"), self),
687
+ force=force,
688
+ )
689
+ return result
690
+
691
+ def _patched_upload_folder(self, *args, **kwargs):
692
+ result = _original_upload_folder(self, *args, **kwargs)
693
+ if not kwargs.get("create_pr"):
694
+ _register(
695
+ kwargs.get("repo_id"),
696
+ _repo_type(kwargs),
697
+ _token(kwargs.get("token"), self),
698
+ force=True,
699
+ )
700
+ return result
701
+
702
+ def _patched_create_commit(self, *args, **kwargs):
703
+ result = _original_create_commit(self, *args, **kwargs)
704
+ if not kwargs.get("create_pr"):
705
+ _register(
706
+ _repo_id(args, kwargs),
707
+ _repo_type(kwargs),
708
+ _token(kwargs.get("token"), self),
709
+ force=True,
710
+ )
711
+ return result
712
+
713
+ HfApi.create_repo = _patched_create_repo
714
+ HfApi.upload_file = _patched_upload_file
715
+ if _original_upload_folder is not None:
716
+ HfApi.upload_folder = _patched_upload_folder
717
+ if _original_create_commit is not None:
718
+ HfApi.create_commit = _patched_create_commit
719
+
720
+ def _patch_module_func(name, method_name):
721
+ original = getattr(_hub, name, None)
722
+ if original is None:
723
+ return
724
+ method = getattr(HfApi, method_name)
725
+
726
+ def _patched(*args, **kwargs):
727
+ api = HfApi(token=_token(kwargs.get("token")))
728
+ return method(api, *args, **kwargs)
729
+
730
+ setattr(_hub, name, _patched)
731
+
732
+ _patch_module_func("create_repo", "create_repo")
733
+ _patch_module_func("upload_file", "upload_file")
734
+ if _original_upload_folder is not None:
735
+ _patch_module_func("upload_folder", "upload_folder")
736
+ if _original_create_commit is not None:
737
+ _patch_module_func("create_commit", "create_commit")
738
+
739
+ try:
740
+ _install_ml_intern_artifact_hooks()
741
+ except Exception:
742
+ pass
743
+ """
744
+ ).strip()
745
+ + "\n"
746
+ )
747
+
748
+
749
+ def wrap_shell_command_with_hub_artifact_bootstrap(
750
+ command: str,
751
+ session: Any,
752
+ ) -> str:
753
+ """Prefix a shell command so child Python processes load Hub hooks."""
754
+ sitecustomize = build_hub_artifact_sitecustomize(session)
755
+ if not sitecustomize or not command:
756
+ return command
757
+
758
+ encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
759
+ bootstrap = (
760
+ '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" '
761
+ f"&& printf %s {shlex.quote(encoded)} | base64 -d "
762
+ '> "$_ml_intern_artifacts_dir/sitecustomize.py" '
763
+ '&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"'
764
+ )
765
+ return f"{bootstrap}; {command}"
agent/core/llm_params.py CHANGED
@@ -56,9 +56,16 @@ def _patch_litellm_effort_validation() -> None:
56
  # to return True for families where "max" / "xhigh" are acceptable
57
  # at the API; the cascade handles the case when they're not.
58
  return any(
59
- v in m for v in (
60
- "opus-4-6", "opus_4_6", "opus-4.6", "opus_4.6",
61
- "opus-4-7", "opus_4_7", "opus-4.7", "opus_4.7",
 
 
 
 
 
 
 
62
  )
63
  )
64
 
 
56
  # to return True for families where "max" / "xhigh" are acceptable
57
  # at the API; the cascade handles the case when they're not.
58
  return any(
59
+ v in m
60
+ for v in (
61
+ "opus-4-6",
62
+ "opus_4_6",
63
+ "opus-4.6",
64
+ "opus_4.6",
65
+ "opus-4-7",
66
+ "opus_4_7",
67
+ "opus-4.7",
68
+ "opus_4.7",
69
  )
70
  )
71
 
agent/core/model_switcher.py CHANGED
@@ -28,7 +28,10 @@ SUGGESTED_MODELS = [
28
  {"id": "openai/gpt-5.4", "label": "GPT-5.4"},
29
  {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
30
  {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
31
- {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6 via Bedrock"},
 
 
 
32
  {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
33
  {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
34
  {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
@@ -122,9 +125,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
122
  )
123
  ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a"
124
  tools = "tools" if p.supports_tools else "no tools"
125
- console.print(
126
- f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]"
127
- )
128
  return True
129
 
130
 
@@ -183,7 +184,9 @@ async def probe_and_switch_model(
183
  # Nothing to validate with a ping that we couldn't validate on the
184
  # first real call just as cheaply. Skip the probe entirely.
185
  _commit_switch(model_id, config, session, effective=None, cache=False)
186
- console.print(f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]")
 
 
187
  return
188
 
189
  console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]")
@@ -203,8 +206,11 @@ async def probe_and_switch_model(
203
  return
204
 
205
  _commit_switch(
206
- model_id, config, session,
207
- effective=outcome.effective_effort, cache=True,
 
 
 
208
  )
209
  effort_label = outcome.effective_effort or "off"
210
  suffix = f" — {outcome.note}" if outcome.note else ""
 
28
  {"id": "openai/gpt-5.4", "label": "GPT-5.4"},
29
  {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"},
30
  {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"},
31
+ {
32
+ "id": "bedrock/us.anthropic.claude-opus-4-6-v1",
33
+ "label": "Claude Opus 4.6 via Bedrock",
34
+ },
35
  {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"},
36
  {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"},
37
  {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"},
 
125
  )
126
  ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a"
127
  tools = "tools" if p.supports_tools else "no tools"
128
+ console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]")
 
 
129
  return True
130
 
131
 
 
184
  # Nothing to validate with a ping that we couldn't validate on the
185
  # first real call just as cheaply. Skip the probe entirely.
186
  _commit_switch(model_id, config, session, effective=None, cache=False)
187
+ console.print(
188
+ f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]"
189
+ )
190
  return
191
 
192
  console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]")
 
206
  return
207
 
208
  _commit_switch(
209
+ model_id,
210
+ config,
211
+ session,
212
+ effective=outcome.effective_effort,
213
+ cache=True,
214
  )
215
  effort_label = outcome.effective_effort or "off"
216
  suffix = f" — {outcome.note}" if outcome.note else ""
agent/core/prompt_caching.py CHANGED
@@ -40,7 +40,11 @@ def with_prompt_caching(
40
 
41
  if messages:
42
  first = messages[0]
43
- role = first.get("role") if isinstance(first, dict) else getattr(first, "role", None)
 
 
 
 
44
  if role == "system":
45
  content = (
46
  first.get("content")
@@ -48,11 +52,13 @@ def with_prompt_caching(
48
  else getattr(first, "content", None)
49
  )
50
  if isinstance(content, str) and content:
51
- cached_block = [{
52
- "type": "text",
53
- "text": content,
54
- "cache_control": {"type": "ephemeral"},
55
- }]
 
 
56
  new_first = {"role": "system", "content": cached_block}
57
  messages = [new_first] + list(messages[1:])
58
 
 
40
 
41
  if messages:
42
  first = messages[0]
43
+ role = (
44
+ first.get("role")
45
+ if isinstance(first, dict)
46
+ else getattr(first, "role", None)
47
+ )
48
  if role == "system":
49
  content = (
50
  first.get("content")
 
52
  else getattr(first, "content", None)
53
  )
54
  if isinstance(content, str) and content:
55
+ cached_block = [
56
+ {
57
+ "type": "text",
58
+ "text": content,
59
+ "cache_control": {"type": "ephemeral"},
60
+ }
61
+ ]
62
  new_first = {"role": "system", "content": cached_block}
63
  messages = [new_first] + list(messages[1:])
64
 
agent/core/session.py CHANGED
@@ -48,7 +48,8 @@ def _get_max_tokens_safe(model_name: str) -> int:
48
  continue
49
  logger.info(
50
  "No litellm.get_model_info entry for %s, falling back to %d",
51
- model_name, _DEFAULT_MAX_TOKENS,
 
52
  )
53
  return _DEFAULT_MAX_TOKENS
54
 
@@ -277,8 +278,7 @@ class Session:
277
  if summary:
278
  summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
279
  message = (
280
- f"Session {self.session_id} completed successfully.\n"
281
- f"{summary}"
282
  )
283
  else:
284
  message = f"Session {self.session_id} completed successfully."
@@ -444,6 +444,7 @@ class Session:
444
  # snapshot between heartbeats would otherwise leak them.
445
  try:
446
  from agent.core.redact import scrub
 
447
  for key in ("messages", "events", "tools"):
448
  if key in trajectory:
449
  trajectory[key] = scrub(trajectory[key])
 
48
  continue
49
  logger.info(
50
  "No litellm.get_model_info entry for %s, falling back to %d",
51
+ model_name,
52
+ _DEFAULT_MAX_TOKENS,
53
  )
54
  return _DEFAULT_MAX_TOKENS
55
 
 
278
  if summary:
279
  summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
280
  message = (
281
+ f"Session {self.session_id} completed successfully.\n{summary}"
 
282
  )
283
  else:
284
  message = f"Session {self.session_id} completed successfully."
 
444
  # snapshot between heartbeats would otherwise leak them.
445
  try:
446
  from agent.core.redact import scrub
447
+
448
  for key in ("messages", "events", "tools"):
449
  if key in trajectory:
450
  trajectory[key] = scrub(trajectory[key])
agent/core/session_persistence.py CHANGED
@@ -271,7 +271,9 @@ class MongoSessionStore(NoopSessionStore):
271
  upsert=True,
272
  )
273
  )
274
- ops.append(DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}))
 
 
275
  try:
276
  if ops:
277
  await self.db.session_messages.bulk_write(ops, ordered=False)
@@ -288,7 +290,9 @@ class MongoSessionStore(NoopSessionStore):
288
  return None
289
  if meta.get("visibility") == "deleted" and not include_deleted:
290
  return None
291
- cursor = self.db.session_messages.find({"session_id": session_id}).sort("idx", 1)
 
 
292
  messages = [row.get("message") async for row in cursor]
293
  return {"metadata": meta, "messages": messages}
294
 
@@ -356,7 +360,9 @@ class MongoSessionStore(NoopSessionStore):
356
  logger.debug("Failed to append event for %s: %s", session_id, e)
357
  return None
358
 
359
- async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[dict[str, Any]]:
 
 
360
  if not self._ready():
361
  return []
362
  cursor = self.db.session_events.find(
@@ -496,6 +502,8 @@ def get_session_store() -> NoopSessionStore | MongoSessionStore:
496
  return _store
497
 
498
 
499
- def _reset_store_for_tests(store: NoopSessionStore | MongoSessionStore | None = None) -> None:
 
 
500
  global _store
501
  _store = store
 
271
  upsert=True,
272
  )
273
  )
274
+ ops.append(
275
+ DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}})
276
+ )
277
  try:
278
  if ops:
279
  await self.db.session_messages.bulk_write(ops, ordered=False)
 
290
  return None
291
  if meta.get("visibility") == "deleted" and not include_deleted:
292
  return None
293
+ cursor = self.db.session_messages.find({"session_id": session_id}).sort(
294
+ "idx", 1
295
+ )
296
  messages = [row.get("message") async for row in cursor]
297
  return {"metadata": meta, "messages": messages}
298
 
 
360
  logger.debug("Failed to append event for %s: %s", session_id, e)
361
  return None
362
 
363
+ async def load_events_after(
364
+ self, session_id: str, after_seq: int = 0
365
+ ) -> list[dict[str, Any]]:
366
  if not self._ready():
367
  return []
368
  cursor = self.db.session_events.find(
 
502
  return _store
503
 
504
 
505
+ def _reset_store_for_tests(
506
+ store: NoopSessionStore | MongoSessionStore | None = None,
507
+ ) -> None:
508
  global _store
509
  _store = store
agent/core/session_uploader.py CHANGED
@@ -94,8 +94,7 @@ def _msg_uuid(session_id: str, role: str, idx: int) -> str:
94
  digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest()
95
  # Format like a UUID for visual familiarity (32 hex chars w/ dashes).
96
  return (
97
- f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-"
98
- f"{digest[16:20]}-{digest[20:32]}"
99
  )
100
 
101
 
@@ -347,7 +346,7 @@ def _update_upload_status(
347
 
348
  def dataset_card_readme(repo_id: str) -> str:
349
  """Dataset card for personal ML Intern session trace repos."""
350
- return f"""---
351
  pretty_name: "ML Intern Session Traces"
352
  language:
353
  - en
 
94
  digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest()
95
  # Format like a UUID for visual familiarity (32 hex chars w/ dashes).
96
  return (
97
+ f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}"
 
98
  )
99
 
100
 
 
346
 
347
  def dataset_card_readme(repo_id: str) -> str:
348
  """Dataset card for personal ML Intern session trace repos."""
349
+ return """---
350
  pretty_name: "ML Intern Session Traces"
351
  language:
352
  - en
agent/core/telemetry.py CHANGED
@@ -26,6 +26,7 @@ logger = logging.getLogger(__name__)
26
 
27
  # ── usage extraction ────────────────────────────────────────────────────────
28
 
 
29
  def extract_usage(response_or_chunk: Any) -> dict:
30
  """Flat usage dict from a litellm response or final-chunk usage object.
31
 
@@ -71,6 +72,7 @@ def extract_usage(response_or_chunk: Any) -> dict:
71
 
72
  # ── llm_call ────────────────────────────────────────────────────────────────
73
 
 
74
  async def record_llm_call(
75
  session: Any,
76
  *,
@@ -106,22 +108,26 @@ async def record_llm_call(
106
  if response is not None:
107
  try:
108
  from litellm import completion_cost
 
109
  cost_usd = float(completion_cost(completion_response=response) or 0.0)
110
  except Exception:
111
  cost_usd = 0.0
112
  from agent.core.session import Event # local import to avoid cycle
 
113
  try:
114
- await session.send_event(Event(
115
- event_type="llm_call",
116
- data={
117
- "model": model,
118
- "latency_ms": latency_ms,
119
- "finish_reason": finish_reason,
120
- "cost_usd": cost_usd,
121
- "kind": kind,
122
- **usage,
123
- },
124
- ))
 
 
125
  except Exception as e:
126
  logger.debug("record_llm_call failed (non-fatal): %s", e)
127
  return usage
@@ -129,6 +135,7 @@ async def record_llm_call(
129
 
130
  # ── hf_jobs ────────────────────────────────────────────────────────────────
131
 
 
132
  def _infer_push_to_hub(script_or_cmd: Any) -> bool:
133
  if not isinstance(script_or_cmd, str):
134
  return False
@@ -150,22 +157,25 @@ async def record_hf_job_submit(
150
  """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
151
  caller can pass it back into :func:`record_hf_job_complete`."""
152
  from agent.core.session import Event
 
153
  t_start = time.monotonic()
154
  try:
155
  script_text = args.get("script") or args.get("command") or ""
156
- await session.send_event(Event(
157
- event_type="hf_job_submit",
158
- data={
159
- "job_id": getattr(job, "id", None),
160
- "job_url": getattr(job, "url", None),
161
- "flavor": args.get("hardware_flavor", "cpu-basic"),
162
- "timeout": args.get("timeout", "30m"),
163
- "job_type": job_type,
164
- "image": image,
165
- "namespace": args.get("namespace"),
166
- "push_to_hub": _infer_push_to_hub(script_text),
167
- },
168
- ))
 
 
169
  except Exception as e:
170
  logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
171
  return t_start
@@ -180,23 +190,27 @@ async def record_hf_job_complete(
180
  submit_ts: float,
181
  ) -> None:
182
  from agent.core.session import Event
 
183
  try:
184
  wall_time_s = int(time.monotonic() - submit_ts)
185
- await session.send_event(Event(
186
- event_type="hf_job_complete",
187
- data={
188
- "job_id": getattr(job, "id", None),
189
- "flavor": flavor,
190
- "final_status": final_status,
191
- "wall_time_s": wall_time_s,
192
- },
193
- ))
 
 
194
  except Exception as e:
195
  logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
196
 
197
 
198
  # ── sandbox ────────────────────────────────────���────────────────────────────
199
 
 
200
  async def record_sandbox_create(
201
  session: Any,
202
  sandbox: Any,
@@ -205,39 +219,46 @@ async def record_sandbox_create(
205
  create_latency_s: int,
206
  ) -> None:
207
  from agent.core.session import Event
 
208
  try:
209
  # Pin created-at on the session so record_sandbox_destroy can diff.
210
  session._sandbox_created_at = time.monotonic() - create_latency_s
211
- await session.send_event(Event(
212
- event_type="sandbox_create",
213
- data={
214
- "sandbox_id": getattr(sandbox, "space_id", None),
215
- "hardware": hardware,
216
- "create_latency_s": int(create_latency_s),
217
- },
218
- ))
 
 
219
  except Exception as e:
220
  logger.debug("record_sandbox_create failed (non-fatal): %s", e)
221
 
222
 
223
  async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
224
  from agent.core.session import Event
 
225
  try:
226
  created = getattr(session, "_sandbox_created_at", None)
227
  lifetime_s = int(time.monotonic() - created) if created else None
228
- await session.send_event(Event(
229
- event_type="sandbox_destroy",
230
- data={
231
- "sandbox_id": getattr(sandbox, "space_id", None),
232
- "lifetime_s": lifetime_s,
233
- },
234
- ))
 
 
235
  except Exception as e:
236
  logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
237
 
238
 
239
  # ── feedback ───────────────────────────────────────────────────────────────
240
 
 
241
  async def record_feedback(
242
  session: Any,
243
  *,
@@ -247,16 +268,19 @@ async def record_feedback(
247
  comment: str | None = None,
248
  ) -> None:
249
  from agent.core.session import Event
 
250
  try:
251
- await session.send_event(Event(
252
- event_type="feedback",
253
- data={
254
- "rating": rating,
255
- "turn_index": turn_index,
256
- "message_id": message_id,
257
- "comment": (comment or "")[:500],
258
- },
259
- ))
 
 
260
  except Exception as e:
261
  logger.debug("record_feedback failed (non-fatal): %s", e)
262
 
@@ -269,15 +293,18 @@ async def record_jobs_access_blocked(
269
  eligible_namespaces: list[str],
270
  ) -> None:
271
  from agent.core.session import Event
 
272
  try:
273
- await session.send_event(Event(
274
- event_type="jobs_access_blocked",
275
- data={
276
- "tool_call_ids": tool_call_ids,
277
- "plan": plan,
278
- "eligible_namespaces": eligible_namespaces,
279
- },
280
- ))
 
 
281
  except Exception as e:
282
  logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
283
 
@@ -289,11 +316,14 @@ async def record_pro_cta_click(
289
  target: str = "pro_pricing",
290
  ) -> None:
291
  from agent.core.session import Event
 
292
  try:
293
- await session.send_event(Event(
294
- event_type="pro_cta_click",
295
- data={"source": source, "target": target},
296
- ))
 
 
297
  except Exception as e:
298
  logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
299
 
@@ -308,11 +338,14 @@ async def record_pro_conversion(
308
  ``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro
309
  session so the rollup picks it up alongside other event-driven KPIs."""
310
  from agent.core.session import Event
 
311
  try:
312
- await session.send_event(Event(
313
- event_type="pro_conversion",
314
- data={"first_seen_at": first_seen_at},
315
- ))
 
 
316
  except Exception as e:
317
  logger.debug("record_pro_conversion failed (non-fatal): %s", e)
318
 
@@ -327,11 +360,14 @@ async def record_credits_topped_up(
327
  came back from the HF billing top-up flow and unblocked themselves.
328
  Caller is responsible for firing this at most once per session."""
329
  from agent.core.session import Event
 
330
  try:
331
- await session.send_event(Event(
332
- event_type="credits_topped_up",
333
- data={"namespace": namespace},
334
- ))
 
 
335
  except Exception as e:
336
  logger.debug("record_credits_topped_up failed (non-fatal): %s", e)
337
 
 
26
 
27
  # ── usage extraction ────────────────────────────────────────────────────────
28
 
29
+
30
  def extract_usage(response_or_chunk: Any) -> dict:
31
  """Flat usage dict from a litellm response or final-chunk usage object.
32
 
 
72
 
73
  # ── llm_call ────────────────────────────────────────────────────────────────
74
 
75
+
76
  async def record_llm_call(
77
  session: Any,
78
  *,
 
108
  if response is not None:
109
  try:
110
  from litellm import completion_cost
111
+
112
  cost_usd = float(completion_cost(completion_response=response) or 0.0)
113
  except Exception:
114
  cost_usd = 0.0
115
  from agent.core.session import Event # local import to avoid cycle
116
+
117
  try:
118
+ await session.send_event(
119
+ Event(
120
+ event_type="llm_call",
121
+ data={
122
+ "model": model,
123
+ "latency_ms": latency_ms,
124
+ "finish_reason": finish_reason,
125
+ "cost_usd": cost_usd,
126
+ "kind": kind,
127
+ **usage,
128
+ },
129
+ )
130
+ )
131
  except Exception as e:
132
  logger.debug("record_llm_call failed (non-fatal): %s", e)
133
  return usage
 
135
 
136
  # ── hf_jobs ────────────────────────────────────────────────────────────────
137
 
138
+
139
  def _infer_push_to_hub(script_or_cmd: Any) -> bool:
140
  if not isinstance(script_or_cmd, str):
141
  return False
 
157
  """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the
158
  caller can pass it back into :func:`record_hf_job_complete`."""
159
  from agent.core.session import Event
160
+
161
  t_start = time.monotonic()
162
  try:
163
  script_text = args.get("script") or args.get("command") or ""
164
+ await session.send_event(
165
+ Event(
166
+ event_type="hf_job_submit",
167
+ data={
168
+ "job_id": getattr(job, "id", None),
169
+ "job_url": getattr(job, "url", None),
170
+ "flavor": args.get("hardware_flavor", "cpu-basic"),
171
+ "timeout": args.get("timeout", "30m"),
172
+ "job_type": job_type,
173
+ "image": image,
174
+ "namespace": args.get("namespace"),
175
+ "push_to_hub": _infer_push_to_hub(script_text),
176
+ },
177
+ )
178
+ )
179
  except Exception as e:
180
  logger.debug("record_hf_job_submit failed (non-fatal): %s", e)
181
  return t_start
 
190
  submit_ts: float,
191
  ) -> None:
192
  from agent.core.session import Event
193
+
194
  try:
195
  wall_time_s = int(time.monotonic() - submit_ts)
196
+ await session.send_event(
197
+ Event(
198
+ event_type="hf_job_complete",
199
+ data={
200
+ "job_id": getattr(job, "id", None),
201
+ "flavor": flavor,
202
+ "final_status": final_status,
203
+ "wall_time_s": wall_time_s,
204
+ },
205
+ )
206
+ )
207
  except Exception as e:
208
  logger.debug("record_hf_job_complete failed (non-fatal): %s", e)
209
 
210
 
211
  # ── sandbox ────────────────────────────────────���────────────────────────────
212
 
213
+
214
  async def record_sandbox_create(
215
  session: Any,
216
  sandbox: Any,
 
219
  create_latency_s: int,
220
  ) -> None:
221
  from agent.core.session import Event
222
+
223
  try:
224
  # Pin created-at on the session so record_sandbox_destroy can diff.
225
  session._sandbox_created_at = time.monotonic() - create_latency_s
226
+ await session.send_event(
227
+ Event(
228
+ event_type="sandbox_create",
229
+ data={
230
+ "sandbox_id": getattr(sandbox, "space_id", None),
231
+ "hardware": hardware,
232
+ "create_latency_s": int(create_latency_s),
233
+ },
234
+ )
235
+ )
236
  except Exception as e:
237
  logger.debug("record_sandbox_create failed (non-fatal): %s", e)
238
 
239
 
240
  async def record_sandbox_destroy(session: Any, sandbox: Any) -> None:
241
  from agent.core.session import Event
242
+
243
  try:
244
  created = getattr(session, "_sandbox_created_at", None)
245
  lifetime_s = int(time.monotonic() - created) if created else None
246
+ await session.send_event(
247
+ Event(
248
+ event_type="sandbox_destroy",
249
+ data={
250
+ "sandbox_id": getattr(sandbox, "space_id", None),
251
+ "lifetime_s": lifetime_s,
252
+ },
253
+ )
254
+ )
255
  except Exception as e:
256
  logger.debug("record_sandbox_destroy failed (non-fatal): %s", e)
257
 
258
 
259
  # ── feedback ───────────────────────────────────────────────────────────────
260
 
261
+
262
  async def record_feedback(
263
  session: Any,
264
  *,
 
268
  comment: str | None = None,
269
  ) -> None:
270
  from agent.core.session import Event
271
+
272
  try:
273
+ await session.send_event(
274
+ Event(
275
+ event_type="feedback",
276
+ data={
277
+ "rating": rating,
278
+ "turn_index": turn_index,
279
+ "message_id": message_id,
280
+ "comment": (comment or "")[:500],
281
+ },
282
+ )
283
+ )
284
  except Exception as e:
285
  logger.debug("record_feedback failed (non-fatal): %s", e)
286
 
 
293
  eligible_namespaces: list[str],
294
  ) -> None:
295
  from agent.core.session import Event
296
+
297
  try:
298
+ await session.send_event(
299
+ Event(
300
+ event_type="jobs_access_blocked",
301
+ data={
302
+ "tool_call_ids": tool_call_ids,
303
+ "plan": plan,
304
+ "eligible_namespaces": eligible_namespaces,
305
+ },
306
+ )
307
+ )
308
  except Exception as e:
309
  logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
310
 
 
316
  target: str = "pro_pricing",
317
  ) -> None:
318
  from agent.core.session import Event
319
+
320
  try:
321
+ await session.send_event(
322
+ Event(
323
+ event_type="pro_cta_click",
324
+ data={"source": source, "target": target},
325
+ )
326
+ )
327
  except Exception as e:
328
  logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
329
 
 
338
  ``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro
339
  session so the rollup picks it up alongside other event-driven KPIs."""
340
  from agent.core.session import Event
341
+
342
  try:
343
+ await session.send_event(
344
+ Event(
345
+ event_type="pro_conversion",
346
+ data={"first_seen_at": first_seen_at},
347
+ )
348
+ )
349
  except Exception as e:
350
  logger.debug("record_pro_conversion failed (non-fatal): %s", e)
351
 
 
360
  came back from the HF billing top-up flow and unblocked themselves.
361
  Caller is responsible for firing this at most once per session."""
362
  from agent.core.session import Event
363
+
364
  try:
365
+ await session.send_event(
366
+ Event(
367
+ event_type="credits_topped_up",
368
+ data={"namespace": namespace},
369
+ )
370
+ )
371
  except Exception as e:
372
  logger.debug("record_credits_topped_up failed (non-fatal): %s", e)
373
 
agent/core/tools.py CHANGED
@@ -8,8 +8,6 @@ import warnings
8
  from dataclasses import dataclass
9
  from typing import Any, Awaitable, Callable, Optional
10
 
11
- logger = logging.getLogger(__name__)
12
-
13
  from fastmcp import Client
14
  from fastmcp.exceptions import ToolError
15
  from mcp.types import EmbeddedResource, ImageContent, TextContent
@@ -64,6 +62,8 @@ warnings.filterwarnings(
64
  "ignore", category=DeprecationWarning, module="aiohttp.connector"
65
  )
66
 
 
 
67
  NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
68
 
69
 
@@ -131,7 +131,12 @@ class ToolRouter:
131
  Based on codex-rs/core/src/tools/router.rs
132
  """
133
 
134
- def __init__(self, mcp_servers: dict[str, MCPServerConfig], hf_token: str | None = None, local_mode: bool = False):
 
 
 
 
 
135
  self.tools: dict[str, ToolSpec] = {}
136
  self.mcp_servers: dict[str, dict[str, Any]] = {}
137
 
@@ -144,7 +149,9 @@ class ToolRouter:
144
  for name, server in mcp_servers.items():
145
  data = server.model_dump()
146
  if hf_token:
147
- data.setdefault("headers", {})["Authorization"] = f"Bearer {hf_token}"
 
 
148
  mcp_servers_payload[name] = data
149
  self.mcp_client = Client({"mcpServers": mcp_servers_payload})
150
  self._mcp_initialized = False
@@ -218,7 +225,9 @@ class ToolRouter:
218
  await self.register_mcp_tools()
219
  self._mcp_initialized = True
220
  except Exception as e:
221
- logger.warning("MCP connection failed, continuing without MCP tools: %s", e)
 
 
222
  self.mcp_client = None
223
 
224
  await self.register_openapi_tool()
@@ -380,6 +389,7 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
380
  # Sandbox or local tools (highest priority)
381
  if local_mode:
382
  from agent.tools.local_tools import get_local_tools
 
383
  tools = get_local_tools() + tools
384
  else:
385
  tools = get_sandbox_tools() + tools
 
8
  from dataclasses import dataclass
9
  from typing import Any, Awaitable, Callable, Optional
10
 
 
 
11
  from fastmcp import Client
12
  from fastmcp.exceptions import ToolError
13
  from mcp.types import EmbeddedResource, ImageContent, TextContent
 
62
  "ignore", category=DeprecationWarning, module="aiohttp.connector"
63
  )
64
 
65
+ logger = logging.getLogger(__name__)
66
+
67
  NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
68
 
69
 
 
131
  Based on codex-rs/core/src/tools/router.rs
132
  """
133
 
134
+ def __init__(
135
+ self,
136
+ mcp_servers: dict[str, MCPServerConfig],
137
+ hf_token: str | None = None,
138
+ local_mode: bool = False,
139
+ ):
140
  self.tools: dict[str, ToolSpec] = {}
141
  self.mcp_servers: dict[str, dict[str, Any]] = {}
142
 
 
149
  for name, server in mcp_servers.items():
150
  data = server.model_dump()
151
  if hf_token:
152
+ data.setdefault("headers", {})["Authorization"] = (
153
+ f"Bearer {hf_token}"
154
+ )
155
  mcp_servers_payload[name] = data
156
  self.mcp_client = Client({"mcpServers": mcp_servers_payload})
157
  self._mcp_initialized = False
 
225
  await self.register_mcp_tools()
226
  self._mcp_initialized = True
227
  except Exception as e:
228
+ logger.warning(
229
+ "MCP connection failed, continuing without MCP tools: %s", e
230
+ )
231
  self.mcp_client = None
232
 
233
  await self.register_openapi_tool()
 
389
  # Sandbox or local tools (highest priority)
390
  if local_mode:
391
  from agent.tools.local_tools import get_local_tools
392
+
393
  tools = get_local_tools() + tools
394
  else:
395
  tools = get_sandbox_tools() + tools
agent/main.py CHANGED
@@ -77,6 +77,7 @@ def _configure_runtime_logging() -> None:
77
  logging.getLogger("LiteLLM").setLevel(logging.ERROR)
78
  logging.getLogger("litellm").setLevel(logging.ERROR)
79
 
 
80
  def _safe_get_args(arguments: dict) -> dict:
81
  """Safely extract args dict from arguments, handling cases where LLM passes string."""
82
  args = arguments.get("args", {})
@@ -92,6 +93,7 @@ def _get_hf_user(token: str | None) -> str | None:
92
  return None
93
  try:
94
  from huggingface_hub import HfApi
 
95
  return HfApi(token=token).whoami().get("name")
96
  except Exception:
97
  return None
@@ -134,10 +136,13 @@ async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str:
134
  login(token=token, add_to_git_credential=False)
135
  print("Token saved to ~/.cache/huggingface/token")
136
  except Exception as e:
137
- print(f"Warning: could not persist token ({e}), using for this session only.")
 
 
138
 
139
  return token
140
 
 
141
  @dataclass
142
  class Operation:
143
  """Operation to be executed by the agent"""
@@ -162,9 +167,9 @@ def _create_rich_console():
162
  class _ThinkingShimmer:
163
  """Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text."""
164
 
165
- _BASE = (90, 90, 110) # dim base color
166
- _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold)
167
- _WIDTH = 5 # shimmer width in characters
168
  _FPS = 24
169
 
170
  def __init__(self, console):
@@ -245,7 +250,7 @@ class _StreamBuffer:
245
  if idx == -1:
246
  return None
247
  block = self._buffer[:idx]
248
- self._buffer = self._buffer[idx + 2:]
249
  return block
250
 
251
  async def flush_ready(
@@ -271,7 +276,9 @@ class _StreamBuffer:
271
  """Flush complete blocks, then render whatever incomplete tail remains."""
272
  await self.flush_ready(cancel_event=cancel_event, instant=instant)
273
  if self._buffer.strip():
274
- await print_markdown(self._buffer, cancel_event=cancel_event, instant=instant)
 
 
275
  self._buffer = ""
276
 
277
  def discard(self):
@@ -372,7 +379,11 @@ async def event_listener(
372
  elif event.event_type == "error":
373
  shimmer.stop()
374
  stream_buf.discard()
375
- error = event.data.get("error", "Unknown error") if event.data else "Unknown error"
 
 
 
 
376
  print_error(error)
377
  turn_complete_event.set()
378
  elif event.event_type == "shutdown":
@@ -392,8 +403,10 @@ async def event_listener(
392
 
393
  # If yolo mode is active, auto-approve everything except
394
  # scheduled HF jobs, whose recurring cost stays manual.
395
- if config and config.yolo_mode and not any(
396
- _is_scheduled_hf_job_tool(t) for t in tools_data
 
 
397
  ):
398
  approvals = [
399
  {
@@ -637,7 +650,9 @@ async def event_listener(
637
  f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
638
  )
639
  except (KeyboardInterrupt, EOFError):
640
- get_console().print("[dim]Approval cancelled — rejecting remaining items[/dim]")
 
 
641
  approvals.append(
642
  {
643
  "tool_call_id": tool_call_id,
@@ -770,7 +785,11 @@ async def _handle_slash_command(
770
  normalized = arg.removeprefix("huggingface/")
771
  session = session_holder[0] if session_holder else None
772
  await model_switcher.probe_and_switch_model(
773
- normalized, config, session, console, resolve_hf_token(),
 
 
 
 
774
  )
775
  return None
776
 
@@ -965,6 +984,7 @@ async def main(model: str | None = None):
965
  # Pre-warm the HF router catalog in the background so /model switches
966
  # don't block on a network fetch.
967
  from agent.core import hf_router_catalog
 
968
  asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm))
969
 
970
  # Create queues for communication
@@ -1110,7 +1130,11 @@ async def main(model: str | None = None):
1110
  # Handle slash commands
1111
  if user_input.strip().startswith("/"):
1112
  sub = await _handle_slash_command(
1113
- user_input.strip(), config, session_holder, submission_queue, submission_id
 
 
 
 
1114
  )
1115
  if sub is None:
1116
  # Command handled locally, loop back for input
@@ -1176,10 +1200,13 @@ async def headless_main(
1176
 
1177
  hf_token = resolve_hf_token()
1178
  if not hf_token:
1179
- print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr)
 
 
 
1180
  sys.exit(1)
1181
 
1182
- print(f"HF token loaded", file=sys.stderr)
1183
 
1184
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1185
  config.yolo_mode = True # Auto-approve everything in headless mode
@@ -1327,26 +1354,35 @@ async def headless_main(
1327
  for t in tools_data
1328
  ]
1329
  _hl_sub_id[0] += 1
1330
- await submission_queue.put(Submission(
1331
- id=f"hl_approval_{_hl_sub_id[0]}",
1332
- operation=Operation(
1333
- op_type=OpType.EXEC_APPROVAL,
1334
- data={"approvals": approvals},
1335
- ),
1336
- ))
 
 
1337
  elif event.event_type == "compacted":
1338
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
1339
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
1340
  print_compacted(old_tokens, new_tokens)
1341
  elif event.event_type == "error":
1342
  stream_buf.discard()
1343
- error = event.data.get("error", "Unknown error") if event.data else "Unknown error"
 
 
 
 
1344
  print_error(error)
1345
  break
1346
  elif event.event_type in ("turn_complete", "interrupted"):
1347
  stream_buf.discard()
1348
  history_size = event.data.get("history_size", "?") if event.data else "?"
1349
- print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
 
 
 
1350
  if event.event_type == "turn_complete":
1351
  session = session_holder[0] if session_holder else None
1352
  if session is not None:
@@ -1372,6 +1408,7 @@ def cli():
1372
  """Entry point for the ml-intern CLI command."""
1373
  import logging as _logging
1374
  import warnings
 
1375
  # Suppress aiohttp "Unclosed client session" noise during event loop teardown
1376
  _logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
1377
  _configure_runtime_logging()
@@ -1381,12 +1418,23 @@ def cli():
1381
  warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh")
1382
 
1383
  parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
1384
- parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt")
1385
- parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)")
1386
- parser.add_argument("--max-iterations", type=int, default=None,
1387
- help="Max LLM requests per turn (default: 50, use -1 for unlimited)")
1388
- parser.add_argument("--no-stream", action="store_true",
1389
- help="Disable token streaming (use non-streaming LLM calls)")
 
 
 
 
 
 
 
 
 
 
 
1390
  args = parser.parse_args()
1391
 
1392
  try:
@@ -1394,7 +1442,14 @@ def cli():
1394
  max_iter = args.max_iterations
1395
  if max_iter is not None and max_iter < 0:
1396
  max_iter = 10_000 # effectively unlimited
1397
- asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream))
 
 
 
 
 
 
 
1398
  else:
1399
  asyncio.run(main(model=args.model))
1400
  except KeyboardInterrupt:
 
77
  logging.getLogger("LiteLLM").setLevel(logging.ERROR)
78
  logging.getLogger("litellm").setLevel(logging.ERROR)
79
 
80
+
81
  def _safe_get_args(arguments: dict) -> dict:
82
  """Safely extract args dict from arguments, handling cases where LLM passes string."""
83
  args = arguments.get("args", {})
 
93
  return None
94
  try:
95
  from huggingface_hub import HfApi
96
+
97
  return HfApi(token=token).whoami().get("name")
98
  except Exception:
99
  return None
 
136
  login(token=token, add_to_git_credential=False)
137
  print("Token saved to ~/.cache/huggingface/token")
138
  except Exception as e:
139
+ print(
140
+ f"Warning: could not persist token ({e}), using for this session only."
141
+ )
142
 
143
  return token
144
 
145
+
146
  @dataclass
147
  class Operation:
148
  """Operation to be executed by the agent"""
 
167
  class _ThinkingShimmer:
168
  """Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text."""
169
 
170
+ _BASE = (90, 90, 110) # dim base color
171
+ _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold)
172
+ _WIDTH = 5 # shimmer width in characters
173
  _FPS = 24
174
 
175
  def __init__(self, console):
 
250
  if idx == -1:
251
  return None
252
  block = self._buffer[:idx]
253
+ self._buffer = self._buffer[idx + 2 :]
254
  return block
255
 
256
  async def flush_ready(
 
276
  """Flush complete blocks, then render whatever incomplete tail remains."""
277
  await self.flush_ready(cancel_event=cancel_event, instant=instant)
278
  if self._buffer.strip():
279
+ await print_markdown(
280
+ self._buffer, cancel_event=cancel_event, instant=instant
281
+ )
282
  self._buffer = ""
283
 
284
  def discard(self):
 
379
  elif event.event_type == "error":
380
  shimmer.stop()
381
  stream_buf.discard()
382
+ error = (
383
+ event.data.get("error", "Unknown error")
384
+ if event.data
385
+ else "Unknown error"
386
+ )
387
  print_error(error)
388
  turn_complete_event.set()
389
  elif event.event_type == "shutdown":
 
403
 
404
  # If yolo mode is active, auto-approve everything except
405
  # scheduled HF jobs, whose recurring cost stays manual.
406
+ if (
407
+ config
408
+ and config.yolo_mode
409
+ and not any(_is_scheduled_hf_job_tool(t) for t in tools_data)
410
  ):
411
  approvals = [
412
  {
 
650
  f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
651
  )
652
  except (KeyboardInterrupt, EOFError):
653
+ get_console().print(
654
+ "[dim]Approval cancelled — rejecting remaining items[/dim]"
655
+ )
656
  approvals.append(
657
  {
658
  "tool_call_id": tool_call_id,
 
785
  normalized = arg.removeprefix("huggingface/")
786
  session = session_holder[0] if session_holder else None
787
  await model_switcher.probe_and_switch_model(
788
+ normalized,
789
+ config,
790
+ session,
791
+ console,
792
+ resolve_hf_token(),
793
  )
794
  return None
795
 
 
984
  # Pre-warm the HF router catalog in the background so /model switches
985
  # don't block on a network fetch.
986
  from agent.core import hf_router_catalog
987
+
988
  asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm))
989
 
990
  # Create queues for communication
 
1130
  # Handle slash commands
1131
  if user_input.strip().startswith("/"):
1132
  sub = await _handle_slash_command(
1133
+ user_input.strip(),
1134
+ config,
1135
+ session_holder,
1136
+ submission_queue,
1137
+ submission_id,
1138
  )
1139
  if sub is None:
1140
  # Command handled locally, loop back for input
 
1200
 
1201
  hf_token = resolve_hf_token()
1202
  if not hf_token:
1203
+ print(
1204
+ "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.",
1205
+ file=sys.stderr,
1206
+ )
1207
  sys.exit(1)
1208
 
1209
+ print("HF token loaded", file=sys.stderr)
1210
 
1211
  config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1212
  config.yolo_mode = True # Auto-approve everything in headless mode
 
1354
  for t in tools_data
1355
  ]
1356
  _hl_sub_id[0] += 1
1357
+ await submission_queue.put(
1358
+ Submission(
1359
+ id=f"hl_approval_{_hl_sub_id[0]}",
1360
+ operation=Operation(
1361
+ op_type=OpType.EXEC_APPROVAL,
1362
+ data={"approvals": approvals},
1363
+ ),
1364
+ )
1365
+ )
1366
  elif event.event_type == "compacted":
1367
  old_tokens = event.data.get("old_tokens", 0) if event.data else 0
1368
  new_tokens = event.data.get("new_tokens", 0) if event.data else 0
1369
  print_compacted(old_tokens, new_tokens)
1370
  elif event.event_type == "error":
1371
  stream_buf.discard()
1372
+ error = (
1373
+ event.data.get("error", "Unknown error")
1374
+ if event.data
1375
+ else "Unknown error"
1376
+ )
1377
  print_error(error)
1378
  break
1379
  elif event.event_type in ("turn_complete", "interrupted"):
1380
  stream_buf.discard()
1381
  history_size = event.data.get("history_size", "?") if event.data else "?"
1382
+ print(
1383
+ f"\n--- Agent {event.event_type} (history_size={history_size}) ---",
1384
+ file=sys.stderr,
1385
+ )
1386
  if event.event_type == "turn_complete":
1387
  session = session_holder[0] if session_holder else None
1388
  if session is not None:
 
1408
  """Entry point for the ml-intern CLI command."""
1409
  import logging as _logging
1410
  import warnings
1411
+
1412
  # Suppress aiohttp "Unclosed client session" noise during event loop teardown
1413
  _logging.getLogger("asyncio").setLevel(_logging.CRITICAL)
1414
  _configure_runtime_logging()
 
1418
  warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh")
1419
 
1420
  parser = argparse.ArgumentParser(description="Hugging Face Agent CLI")
1421
+ parser.add_argument(
1422
+ "prompt", nargs="?", default=None, help="Run headlessly with this prompt"
1423
+ )
1424
+ parser.add_argument(
1425
+ "--model", "-m", default=None, help="Model to use (default: from config)"
1426
+ )
1427
+ parser.add_argument(
1428
+ "--max-iterations",
1429
+ type=int,
1430
+ default=None,
1431
+ help="Max LLM requests per turn (default: 50, use -1 for unlimited)",
1432
+ )
1433
+ parser.add_argument(
1434
+ "--no-stream",
1435
+ action="store_true",
1436
+ help="Disable token streaming (use non-streaming LLM calls)",
1437
+ )
1438
  args = parser.parse_args()
1439
 
1440
  try:
 
1442
  max_iter = args.max_iterations
1443
  if max_iter is not None and max_iter < 0:
1444
  max_iter = 10_000 # effectively unlimited
1445
+ asyncio.run(
1446
+ headless_main(
1447
+ args.prompt,
1448
+ model=args.model,
1449
+ max_iterations=max_iter,
1450
+ stream=not args.no_stream,
1451
+ )
1452
+ )
1453
  else:
1454
  asyncio.run(main(model=args.model))
1455
  except KeyboardInterrupt:
agent/messaging/base.py CHANGED
@@ -2,7 +2,11 @@ from abc import ABC, abstractmethod
2
 
3
  import httpx
4
 
5
- from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult
 
 
 
 
6
 
7
 
8
  class NotificationError(Exception):
 
2
 
3
  import httpx
4
 
5
+ from agent.messaging.models import (
6
+ DestinationConfig,
7
+ NotificationRequest,
8
+ NotificationResult,
9
+ )
10
 
11
 
12
  class NotificationError(Exception):
agent/messaging/gateway.py CHANGED
@@ -39,7 +39,9 @@ class NotificationGateway:
39
  if not self.enabled or self._worker_task is not None:
40
  return
41
  self._client = httpx.AsyncClient(timeout=10.0)
42
- self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway")
 
 
43
 
44
  async def flush(self) -> None:
45
  if not self.enabled:
@@ -87,7 +89,9 @@ class NotificationGateway:
87
  provider=destination.provider,
88
  error=f"No provider implementation for '{destination.provider}'",
89
  )
90
- return await self._send_with_retries(provider, request.destination, destination, request)
 
 
91
 
92
  async def send_many(
93
  self, requests: Iterable[NotificationRequest]
@@ -131,7 +135,9 @@ class NotificationGateway:
131
  try:
132
  for attempt in range(len(_RETRY_DELAYS) + 1):
133
  try:
134
- return await provider.send(client, destination_name, destination, request)
 
 
135
  except RetryableNotificationError as exc:
136
  if attempt >= len(_RETRY_DELAYS):
137
  return NotificationResult(
 
39
  if not self.enabled or self._worker_task is not None:
40
  return
41
  self._client = httpx.AsyncClient(timeout=10.0)
42
+ self._worker_task = asyncio.create_task(
43
+ self._worker(), name="notification-gateway"
44
+ )
45
 
46
  async def flush(self) -> None:
47
  if not self.enabled:
 
89
  provider=destination.provider,
90
  error=f"No provider implementation for '{destination.provider}'",
91
  )
92
+ return await self._send_with_retries(
93
+ provider, request.destination, destination, request
94
+ )
95
 
96
  async def send_many(
97
  self, requests: Iterable[NotificationRequest]
 
135
  try:
136
  for attempt in range(len(_RETRY_DELAYS) + 1):
137
  try:
138
+ return await provider.send(
139
+ client, destination_name, destination, request
140
+ )
141
  except RetryableNotificationError as exc:
142
  if attempt >= len(_RETRY_DELAYS):
143
  return NotificationResult(
agent/messaging/models.py CHANGED
@@ -55,9 +55,7 @@ class MessagingConfig(BaseModel):
55
  seen: set[str] = set()
56
  for event_type in event_types:
57
  if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
58
- raise ValueError(
59
- f"unsupported auto event type '{event_type}'"
60
- )
61
  if event_type not in seen:
62
  normalized.append(event_type)
63
  seen.add(event_type)
@@ -83,11 +81,7 @@ class MessagingConfig(BaseModel):
83
  def default_auto_destinations(self) -> list[str]:
84
  if not self.enabled:
85
  return []
86
- return [
87
- name
88
- for name in self.destinations
89
- if self.can_auto_send(name)
90
- ]
91
 
92
 
93
  class NotificationRequest(BaseModel):
 
55
  seen: set[str] = set()
56
  for event_type in event_types:
57
  if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
58
+ raise ValueError(f"unsupported auto event type '{event_type}'")
 
 
59
  if event_type not in seen:
60
  normalized.append(event_type)
61
  seen.add(event_type)
 
81
  def default_auto_destinations(self) -> list[str]:
82
  if not self.enabled:
83
  return []
84
+ return [name for name in self.destinations if self.can_auto_send(name)]
 
 
 
 
85
 
86
 
87
  class NotificationRequest(BaseModel):
agent/messaging/slack.py CHANGED
@@ -160,9 +160,7 @@ class SlackProvider(NotificationProvider):
160
  raise RetryableNotificationError("Slack transport error") from exc
161
 
162
  if response.status_code == 429 or response.status_code >= 500:
163
- raise RetryableNotificationError(
164
- f"Slack HTTP {response.status_code}"
165
- )
166
  if response.status_code >= 400:
167
  raise NotificationError(f"Slack HTTP {response.status_code}")
168
 
 
160
  raise RetryableNotificationError("Slack transport error") from exc
161
 
162
  if response.status_code == 429 or response.status_code >= 500:
163
+ raise RetryableNotificationError(f"Slack HTTP {response.status_code}")
 
 
164
  if response.status_code >= 400:
165
  raise NotificationError(f"Slack HTTP {response.status_code}")
166
 
agent/sft/tagger.py CHANGED
@@ -27,19 +27,29 @@ Tags are deduplicated before returning.
27
 
28
  from __future__ import annotations
29
 
30
- from typing import Any, Iterable
31
 
32
  # Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
33
  _GPU_FAMILY = {
34
- "cpu-basic": "none", "cpu-upgrade": "none",
35
- "t4-small": "t4", "t4-medium": "t4",
36
- "l4x1": "l40s", "l4x4": "l40s",
37
- "l40sx1": "l40s", "l40sx4": "l40s", "l40sx8": "l40s",
38
- "a10g-small": "a10g", "a10g-large": "a10g",
39
- "a10g-largex2": "a10g", "a10g-largex4": "a10g",
40
- "a100-large": "a100", "a100x2": "a100",
41
- "a100x4": "a100", "a100x8": "a100",
42
- "h100": "h100", "h100x8": "h100",
 
 
 
 
 
 
 
 
 
 
43
  }
44
 
45
  # Substrings that count a flavor as multi-GPU.
@@ -48,9 +58,17 @@ _MULTI_GPU_MARKERS = ("x2", "x4", "x8")
48
  # Tool names that don't touch training/inference or sandbox/jobs. If a session
49
  # only used these, we tag it research_only.
50
  _RESEARCH_ONLY_TOOLS = {
51
- "research", "github_find_examples", "github_read_file", "github_list_repos",
52
- "hf_papers", "explore_hf_docs", "fetch_hf_docs", "hub_repo_details",
53
- "plan", "hf_inspect_dataset", "web_search",
 
 
 
 
 
 
 
 
54
  }
55
 
56
  # Tool names that signal data manipulation workflows.
@@ -126,11 +144,22 @@ def _infer_task_tag(
126
  # hf_jobs at all and a script mentions training APIs.
127
  for script in hf_job_submit_scripts:
128
  low = script.lower()
129
- if any(k in low for k in (
130
- "sftconfig", "sfttrainer", "trainer(", "trainingarguments",
131
- "grpo", "dpo", ".train(", "transformers import",
132
- "trainer import", "fine-tune", "finetune",
133
- )):
 
 
 
 
 
 
 
 
 
 
 
134
  return "training"
135
 
136
  # inference: sessions that use inference tools but never hf_jobs/sandbox
 
27
 
28
  from __future__ import annotations
29
 
30
+ from typing import Iterable
31
 
32
  # Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none".
33
  _GPU_FAMILY = {
34
+ "cpu-basic": "none",
35
+ "cpu-upgrade": "none",
36
+ "t4-small": "t4",
37
+ "t4-medium": "t4",
38
+ "l4x1": "l40s",
39
+ "l4x4": "l40s",
40
+ "l40sx1": "l40s",
41
+ "l40sx4": "l40s",
42
+ "l40sx8": "l40s",
43
+ "a10g-small": "a10g",
44
+ "a10g-large": "a10g",
45
+ "a10g-largex2": "a10g",
46
+ "a10g-largex4": "a10g",
47
+ "a100-large": "a100",
48
+ "a100x2": "a100",
49
+ "a100x4": "a100",
50
+ "a100x8": "a100",
51
+ "h100": "h100",
52
+ "h100x8": "h100",
53
  }
54
 
55
  # Substrings that count a flavor as multi-GPU.
 
58
  # Tool names that don't touch training/inference or sandbox/jobs. If a session
59
  # only used these, we tag it research_only.
60
  _RESEARCH_ONLY_TOOLS = {
61
+ "research",
62
+ "github_find_examples",
63
+ "github_read_file",
64
+ "github_list_repos",
65
+ "hf_papers",
66
+ "explore_hf_docs",
67
+ "fetch_hf_docs",
68
+ "hub_repo_details",
69
+ "plan",
70
+ "hf_inspect_dataset",
71
+ "web_search",
72
  }
73
 
74
  # Tool names that signal data manipulation workflows.
 
144
  # hf_jobs at all and a script mentions training APIs.
145
  for script in hf_job_submit_scripts:
146
  low = script.lower()
147
+ if any(
148
+ k in low
149
+ for k in (
150
+ "sftconfig",
151
+ "sfttrainer",
152
+ "trainer(",
153
+ "trainingarguments",
154
+ "grpo",
155
+ "dpo",
156
+ ".train(",
157
+ "transformers import",
158
+ "trainer import",
159
+ "fine-tune",
160
+ "finetune",
161
+ )
162
+ ):
163
  return "training"
164
 
165
  # inference: sessions that use inference tools but never hf_jobs/sandbox
agent/tools/dataset_tools.py CHANGED
@@ -423,7 +423,9 @@ HF_INSPECT_DATASET_TOOL_SPEC = {
423
  }
424
 
425
 
426
- async def hf_inspect_dataset_handler(arguments: dict[str, Any], session=None) -> tuple[str, bool]:
 
 
427
  """Handler for agent tool router"""
428
  try:
429
  hf_token = session.hf_token if session else None
 
423
  }
424
 
425
 
426
+ async def hf_inspect_dataset_handler(
427
+ arguments: dict[str, Any], session=None
428
+ ) -> tuple[str, bool]:
429
  """Handler for agent tool router"""
430
  try:
431
  hf_token = session.hf_token if session else None
agent/tools/edit_utils.py CHANGED
@@ -10,18 +10,18 @@ from __future__ import annotations
10
  # ── Unicode normalization map ────────────────────────────────────────────
11
 
12
  UNICODE_MAP = {
13
- "\u2013": "-", # en-dash
14
- "\u2014": "-", # em-dash
15
- "\u2212": "-", # minus sign
16
- "\u2018": "'", # left single quote
17
- "\u2019": "'", # right single quote
18
- "\u201c": '"', # left double quote
19
- "\u201d": '"', # right double quote
20
- "\u00a0": " ", # non-breaking space
21
- "\u2003": " ", # em space
22
- "\u2002": " ", # en space
23
- "\u200b": "", # zero-width space
24
- "\ufeff": "", # BOM
25
  }
26
 
27
 
@@ -59,12 +59,12 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]:
59
  line_start_map[i] = original byte offset of the start of line i.
60
  """
61
  orig_lines = text.split("\n")
62
- stripped_lines = [strip_fn(l) for l in orig_lines]
63
  return "\n".join(stripped_lines), orig_lines, stripped_lines
64
 
65
  # Pass 2 — right-trim
66
  c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip)
67
- p_rt = "\n".join(l.rstrip() for l in pattern.split("\n"))
68
  idx = c_rt.find(p_rt)
69
  if idx != -1:
70
  orig_idx = _map_back(idx, c_orig_lines, c_rt_lines)
@@ -72,7 +72,7 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]:
72
 
73
  # Pass 3 — both-sides trim
74
  c_st, _, c_st_lines = _build_stripped(content, str.strip)
75
- p_st = "\n".join(l.strip() for l in pattern.split("\n"))
76
  idx = c_st.find(p_st)
77
  if idx != -1:
78
  orig_idx = _map_back(idx, c_orig_lines, c_st_lines)
@@ -114,7 +114,9 @@ def _map_back(
114
  return 0
115
 
116
 
117
- def fuzzy_find_original_match(content: str, pattern: str) -> tuple[str | None, str | None]:
 
 
118
  """Find the *original* text in content that matches pattern fuzzily.
119
 
120
  Returns (original_matched_text, match_note) or (None, None).
@@ -224,7 +226,9 @@ def apply_edit(
224
  return new_content, 1, fuzzy_note
225
 
226
  else:
227
- raise ValueError(f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before.")
 
 
228
 
229
 
230
  # ── Syntax validation (Python) ───────────────────────────────────────────
@@ -255,14 +259,15 @@ def validate_python(content: str, path: str = "") -> list[str]:
255
  return warnings
256
 
257
  # 2. Training script heuristics
258
- if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")):
 
 
 
259
  if "push_to_hub" not in content:
260
  warnings.append(
261
  "Training script warning: no 'push_to_hub' found — model may be lost when job ends"
262
  )
263
  if "hub_model_id" not in content:
264
- warnings.append(
265
- "Training script warning: no 'hub_model_id' found"
266
- )
267
 
268
  return warnings
 
10
  # ── Unicode normalization map ────────────────────────────────────────────
11
 
12
  UNICODE_MAP = {
13
+ "\u2013": "-", # en-dash
14
+ "\u2014": "-", # em-dash
15
+ "\u2212": "-", # minus sign
16
+ "\u2018": "'", # left single quote
17
+ "\u2019": "'", # right single quote
18
+ "\u201c": '"', # left double quote
19
+ "\u201d": '"', # right double quote
20
+ "\u00a0": " ", # non-breaking space
21
+ "\u2003": " ", # em space
22
+ "\u2002": " ", # en space
23
+ "\u200b": "", # zero-width space
24
+ "\ufeff": "", # BOM
25
  }
26
 
27
 
 
59
  line_start_map[i] = original byte offset of the start of line i.
60
  """
61
  orig_lines = text.split("\n")
62
+ stripped_lines = [strip_fn(line) for line in orig_lines]
63
  return "\n".join(stripped_lines), orig_lines, stripped_lines
64
 
65
  # Pass 2 — right-trim
66
  c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip)
67
+ p_rt = "\n".join(line.rstrip() for line in pattern.split("\n"))
68
  idx = c_rt.find(p_rt)
69
  if idx != -1:
70
  orig_idx = _map_back(idx, c_orig_lines, c_rt_lines)
 
72
 
73
  # Pass 3 — both-sides trim
74
  c_st, _, c_st_lines = _build_stripped(content, str.strip)
75
+ p_st = "\n".join(line.strip() for line in pattern.split("\n"))
76
  idx = c_st.find(p_st)
77
  if idx != -1:
78
  orig_idx = _map_back(idx, c_orig_lines, c_st_lines)
 
114
  return 0
115
 
116
 
117
+ def fuzzy_find_original_match(
118
+ content: str, pattern: str
119
+ ) -> tuple[str | None, str | None]:
120
  """Find the *original* text in content that matches pattern fuzzily.
121
 
122
  Returns (original_matched_text, match_note) or (None, None).
 
226
  return new_content, 1, fuzzy_note
227
 
228
  else:
229
+ raise ValueError(
230
+ f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before."
231
+ )
232
 
233
 
234
  # ── Syntax validation (Python) ───────────────────────────────────────────
 
259
  return warnings
260
 
261
  # 2. Training script heuristics
262
+ if any(
263
+ kw in content
264
+ for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")
265
+ ):
266
  if "push_to_hub" not in content:
267
  warnings.append(
268
  "Training script warning: no 'push_to_hub' found — model may be lost when job ends"
269
  )
270
  if "hub_model_id" not in content:
271
+ warnings.append("Training script warning: no 'hub_model_id' found")
 
 
272
 
273
  return warnings
agent/tools/hf_repo_files_tool.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional
10
  from huggingface_hub import HfApi, hf_hub_download
11
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
 
 
13
  from agent.tools.types import ToolResult
14
 
15
  OperationType = Literal["list", "read", "upload", "delete"]
@@ -39,8 +40,9 @@ def _format_size(size_bytes: int) -> str:
39
  class HfRepoFilesTool:
40
  """Tool for file operations on HF repos."""
41
 
42
- def __init__(self, hf_token: Optional[str] = None):
43
  self.api = HfApi(token=hf_token)
 
44
 
45
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
46
  """Execute the specified operation."""
@@ -61,7 +63,9 @@ class HfRepoFilesTool:
61
  if handler:
62
  return await handler(args)
63
  else:
64
- return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")
 
 
65
 
66
  except RepositoryNotFoundError:
67
  return self._error(f"Repository not found: {args.get('repo_id')}")
@@ -96,17 +100,23 @@ class HfRepoFilesTool:
96
  revision = args.get("revision", "main")
97
  path = args.get("path", "")
98
 
99
- items = list(await _async_call(
100
- self.api.list_repo_tree,
101
- repo_id=repo_id,
102
- repo_type=repo_type,
103
- revision=revision,
104
- path_in_repo=path,
105
- recursive=True,
106
- ))
 
 
107
 
108
  if not items:
109
- return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}
 
 
 
 
110
 
111
  lines = []
112
  total_size = 0
@@ -118,9 +128,16 @@ class HfRepoFilesTool:
118
  lines.append(f"{item.path}/")
119
 
120
  url = _build_repo_url(repo_id, repo_type)
121
- response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines)
 
 
 
122
 
123
- return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}
 
 
 
 
124
 
125
  async def _read(self, args: Dict[str, Any]) -> ToolResult:
126
  """Read file content from a repository."""
@@ -160,8 +177,13 @@ class HfRepoFilesTool:
160
 
161
  except UnicodeDecodeError:
162
  import os
 
163
  size = os.path.getsize(file_path)
164
- return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}
 
 
 
 
165
 
166
  async def _upload(self, args: Dict[str, Any]) -> ToolResult:
167
  """Upload content to a repository."""
@@ -194,6 +216,16 @@ class HfRepoFilesTool:
194
  create_pr=create_pr,
195
  )
196
 
 
 
 
 
 
 
 
 
 
 
197
  url = _build_repo_url(repo_id, repo_type)
198
  if create_pr and hasattr(result, "pr_url"):
199
  response = f"**Uploaded as PR**\n{result.pr_url}"
@@ -235,7 +267,12 @@ class HfRepoFilesTool:
235
 
236
  def _error(self, message: str) -> ToolResult:
237
  """Return an error result."""
238
- return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
 
 
 
 
 
239
 
240
 
241
  # Tool specification
@@ -312,11 +349,13 @@ HF_REPO_FILES_TOOL_SPEC = {
312
  }
313
 
314
 
315
- async def hf_repo_files_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]:
 
 
316
  """Handler for agent tool router."""
317
  try:
318
  hf_token = session.hf_token if session else None
319
- tool = HfRepoFilesTool(hf_token=hf_token)
320
  result = await tool.execute(arguments)
321
  return result["formatted"], not result.get("isError", False)
322
  except Exception as e:
 
10
  from huggingface_hub import HfApi, hf_hub_download
11
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
 
13
+ from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact
14
  from agent.tools.types import ToolResult
15
 
16
  OperationType = Literal["list", "read", "upload", "delete"]
 
40
  class HfRepoFilesTool:
41
  """Tool for file operations on HF repos."""
42
 
43
+ def __init__(self, hf_token: Optional[str] = None, session: Any = None):
44
  self.api = HfApi(token=hf_token)
45
+ self.session = session
46
 
47
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
48
  """Execute the specified operation."""
 
63
  if handler:
64
  return await handler(args)
65
  else:
66
+ return self._error(
67
+ f"Unknown operation: {operation}. Valid: list, read, upload, delete"
68
+ )
69
 
70
  except RepositoryNotFoundError:
71
  return self._error(f"Repository not found: {args.get('repo_id')}")
 
100
  revision = args.get("revision", "main")
101
  path = args.get("path", "")
102
 
103
+ items = list(
104
+ await _async_call(
105
+ self.api.list_repo_tree,
106
+ repo_id=repo_id,
107
+ repo_type=repo_type,
108
+ revision=revision,
109
+ path_in_repo=path,
110
+ recursive=True,
111
+ )
112
+ )
113
 
114
  if not items:
115
+ return {
116
+ "formatted": f"No files in {repo_id}",
117
+ "totalResults": 0,
118
+ "resultsShared": 0,
119
+ }
120
 
121
  lines = []
122
  total_size = 0
 
128
  lines.append(f"{item.path}/")
129
 
130
  url = _build_repo_url(repo_id, repo_type)
131
+ response = (
132
+ f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n"
133
+ + "\n".join(lines)
134
+ )
135
 
136
+ return {
137
+ "formatted": response,
138
+ "totalResults": len(items),
139
+ "resultsShared": len(items),
140
+ }
141
 
142
  async def _read(self, args: Dict[str, Any]) -> ToolResult:
143
  """Read file content from a repository."""
 
177
 
178
  except UnicodeDecodeError:
179
  import os
180
+
181
  size = os.path.getsize(file_path)
182
+ return {
183
+ "formatted": f"Binary file ({_format_size(size)})",
184
+ "totalResults": 1,
185
+ "resultsShared": 1,
186
+ }
187
 
188
  async def _upload(self, args: Dict[str, Any]) -> ToolResult:
189
  """Upload content to a repository."""
 
216
  create_pr=create_pr,
217
  )
218
 
219
+ if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type):
220
+ await _async_call(
221
+ register_hub_artifact,
222
+ self.api,
223
+ repo_id,
224
+ repo_type,
225
+ session=self.session,
226
+ force=path == "README.md",
227
+ )
228
+
229
  url = _build_repo_url(repo_id, repo_type)
230
  if create_pr and hasattr(result, "pr_url"):
231
  response = f"**Uploaded as PR**\n{result.pr_url}"
 
267
 
268
  def _error(self, message: str) -> ToolResult:
269
  """Return an error result."""
270
+ return {
271
+ "formatted": message,
272
+ "totalResults": 0,
273
+ "resultsShared": 0,
274
+ "isError": True,
275
+ }
276
 
277
 
278
  # Tool specification
 
349
  }
350
 
351
 
352
+ async def hf_repo_files_handler(
353
+ arguments: Dict[str, Any], session=None
354
+ ) -> tuple[str, bool]:
355
  """Handler for agent tool router."""
356
  try:
357
  hf_token = session.hf_token if session else None
358
+ tool = HfRepoFilesTool(hf_token=hf_token, session=session)
359
  result = await tool.execute(arguments)
360
  return result["formatted"], not result.get("isError", False)
361
  except Exception as e:
agent/tools/hf_repo_git_tool.py CHANGED
@@ -10,14 +10,24 @@ from typing import Any, Dict, Literal, Optional
10
  from huggingface_hub import HfApi
11
  from huggingface_hub.utils import RepositoryNotFoundError
12
 
 
13
  from agent.tools.types import ToolResult
14
 
15
  OperationType = Literal[
16
- "create_branch", "delete_branch",
17
- "create_tag", "delete_tag",
 
 
18
  "list_refs",
19
- "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
20
- "create_repo", "update_repo",
 
 
 
 
 
 
 
21
  ]
22
 
23
 
@@ -36,8 +46,9 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
36
  class HfRepoGitTool:
37
  """Tool for git-like operations on HF repos."""
38
 
39
- def __init__(self, hf_token: Optional[str] = None):
40
  self.api = HfApi(token=hf_token)
 
41
 
42
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
43
  """Execute the specified operation."""
@@ -131,7 +142,11 @@ class HfRepoGitTool:
131
  )
132
 
133
  url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
134
- return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
135
 
136
  async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
137
  """Delete a branch."""
@@ -152,7 +167,11 @@ class HfRepoGitTool:
152
  repo_type=repo_type,
153
  )
154
 
155
- return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
156
 
157
  # =========================================================================
158
  # TAG OPERATIONS
@@ -183,7 +202,11 @@ class HfRepoGitTool:
183
  )
184
 
185
  url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
186
- return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
187
 
188
  async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
189
  """Delete a tag."""
@@ -204,7 +227,11 @@ class HfRepoGitTool:
204
  repo_type=repo_type,
205
  )
206
 
207
- return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
208
 
209
  # =========================================================================
210
  # LIST REFS
@@ -226,7 +253,9 @@ class HfRepoGitTool:
226
  )
227
 
228
  branches = [b.name for b in refs.branches] if refs.branches else []
229
- tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []
 
 
230
 
231
  url = _build_repo_url(repo_id, repo_type)
232
  lines = [f"**{repo_id}**", url, ""]
@@ -241,7 +270,11 @@ class HfRepoGitTool:
241
  else:
242
  lines.append("**Tags:** none")
243
 
244
- return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}
 
 
 
 
245
 
246
  # =========================================================================
247
  # PR OPERATIONS
@@ -270,7 +303,7 @@ class HfRepoGitTool:
270
 
271
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
272
  return {
273
- "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
274
  "totalResults": 1,
275
  "resultsShared": 1,
276
  }
@@ -285,17 +318,27 @@ class HfRepoGitTool:
285
  repo_type = args.get("repo_type", "model")
286
  status = args.get("status", "all") # open, closed, all
287
 
288
- discussions = list(self.api.get_repo_discussions(
289
- repo_id=repo_id,
290
- repo_type=repo_type,
291
- discussion_status=status if status != "all" else None,
292
- ))
 
 
293
 
294
  if not discussions:
295
- return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}
 
 
 
 
296
 
297
  url = _build_repo_url(repo_id, repo_type)
298
- lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]
 
 
 
 
299
 
300
  for d in discussions[:20]:
301
  if d.status == "draft":
@@ -309,7 +352,11 @@ class HfRepoGitTool:
309
  type_label = "PR" if d.is_pull_request else "D"
310
  lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
311
 
312
- return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}
 
 
 
 
313
 
314
  async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
315
  """Get PR details."""
@@ -335,7 +382,7 @@ class HfRepoGitTool:
335
  "draft": "Draft",
336
  "open": "Open",
337
  "merged": "Merged",
338
- "closed": "Closed"
339
  }
340
  status = status_map.get(pr.status, pr.status.capitalize())
341
  type_label = "Pull Request" if pr.is_pull_request else "Discussion"
@@ -349,9 +396,13 @@ class HfRepoGitTool:
349
 
350
  if pr.is_pull_request:
351
  if pr.status == "draft":
352
- lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
 
 
353
  elif pr.status == "open":
354
- lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")
 
 
355
 
356
  return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
357
 
@@ -377,7 +428,11 @@ class HfRepoGitTool:
377
  )
378
 
379
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
380
- return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
381
 
382
  async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
383
  """Close a PR/discussion."""
@@ -401,7 +456,11 @@ class HfRepoGitTool:
401
  repo_type=repo_type,
402
  )
403
 
404
- return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}
 
 
 
 
405
 
406
  async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
407
  """Add a comment to a PR/discussion."""
@@ -427,7 +486,11 @@ class HfRepoGitTool:
427
  )
428
 
429
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
430
- return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
431
 
432
  async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
433
  """Change PR/discussion status (mainly to convert draft to open)."""
@@ -455,7 +518,11 @@ class HfRepoGitTool:
455
  )
456
 
457
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
458
- return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
459
 
460
  # =========================================================================
461
  # REPO MANAGEMENT
@@ -473,7 +540,9 @@ class HfRepoGitTool:
473
  space_sdk = args.get("space_sdk")
474
 
475
  if repo_type == "space" and not space_sdk:
476
- return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")
 
 
477
 
478
  kwargs = {
479
  "repo_id": repo_id,
@@ -485,6 +554,17 @@ class HfRepoGitTool:
485
  kwargs["space_sdk"] = space_sdk
486
 
487
  result = await _async_call(self.api.create_repo, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
488
 
489
  return {
490
  "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
@@ -504,7 +584,9 @@ class HfRepoGitTool:
504
  gated = args.get("gated")
505
 
506
  if private is None and gated is None:
507
- return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")
 
 
508
 
509
  kwargs = {"repo_id": repo_id, "repo_type": repo_type}
510
  if private is not None:
@@ -521,11 +603,20 @@ class HfRepoGitTool:
521
  changes.append(f"gated={gated}")
522
 
523
  url = f"{_build_repo_url(repo_id, repo_type)}/settings"
524
- return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}
 
 
 
 
525
 
526
  def _error(self, message: str) -> ToolResult:
527
  """Return an error result."""
528
- return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
 
 
 
 
 
529
 
530
 
531
  # Tool specification
@@ -571,10 +662,20 @@ HF_REPO_GIT_TOOL_SPEC = {
571
  "operation": {
572
  "type": "string",
573
  "enum": [
574
- "create_branch", "delete_branch",
575
- "create_tag", "delete_tag", "list_refs",
576
- "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
577
- "create_repo", "update_repo",
 
 
 
 
 
 
 
 
 
 
578
  ],
579
  "description": "Operation to execute",
580
  },
@@ -653,11 +754,13 @@ HF_REPO_GIT_TOOL_SPEC = {
653
  }
654
 
655
 
656
- async def hf_repo_git_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]:
 
 
657
  """Handler for agent tool router."""
658
  try:
659
  hf_token = session.hf_token if session else None
660
- tool = HfRepoGitTool(hf_token=hf_token)
661
  result = await tool.execute(arguments)
662
  return result["formatted"], not result.get("isError", False)
663
  except Exception as e:
 
10
  from huggingface_hub import HfApi
11
  from huggingface_hub.utils import RepositoryNotFoundError
12
 
13
+ from agent.core.hub_artifacts import register_hub_artifact
14
  from agent.tools.types import ToolResult
15
 
16
  OperationType = Literal[
17
+ "create_branch",
18
+ "delete_branch",
19
+ "create_tag",
20
+ "delete_tag",
21
  "list_refs",
22
+ "create_pr",
23
+ "list_prs",
24
+ "get_pr",
25
+ "merge_pr",
26
+ "close_pr",
27
+ "comment_pr",
28
+ "change_pr_status",
29
+ "create_repo",
30
+ "update_repo",
31
  ]
32
 
33
 
 
46
  class HfRepoGitTool:
47
  """Tool for git-like operations on HF repos."""
48
 
49
+ def __init__(self, hf_token: Optional[str] = None, session: Any = None):
50
  self.api = HfApi(token=hf_token)
51
+ self.session = session
52
 
53
  async def execute(self, args: Dict[str, Any]) -> ToolResult:
54
  """Execute the specified operation."""
 
142
  )
143
 
144
  url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
145
+ return {
146
+ "formatted": f"**Branch created:** {branch}\n{url}",
147
+ "totalResults": 1,
148
+ "resultsShared": 1,
149
+ }
150
 
151
  async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
152
  """Delete a branch."""
 
167
  repo_type=repo_type,
168
  )
169
 
170
+ return {
171
+ "formatted": f"**Branch deleted:** {branch}",
172
+ "totalResults": 1,
173
+ "resultsShared": 1,
174
+ }
175
 
176
  # =========================================================================
177
  # TAG OPERATIONS
 
202
  )
203
 
204
  url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
205
+ return {
206
+ "formatted": f"**Tag created:** {tag}\n{url}",
207
+ "totalResults": 1,
208
+ "resultsShared": 1,
209
+ }
210
 
211
  async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
212
  """Delete a tag."""
 
227
  repo_type=repo_type,
228
  )
229
 
230
+ return {
231
+ "formatted": f"**Tag deleted:** {tag}",
232
+ "totalResults": 1,
233
+ "resultsShared": 1,
234
+ }
235
 
236
  # =========================================================================
237
  # LIST REFS
 
253
  )
254
 
255
  branches = [b.name for b in refs.branches] if refs.branches else []
256
+ tags = (
257
+ [t.name for t in refs.tags] if hasattr(refs, "tags") and refs.tags else []
258
+ )
259
 
260
  url = _build_repo_url(repo_id, repo_type)
261
  lines = [f"**{repo_id}**", url, ""]
 
270
  else:
271
  lines.append("**Tags:** none")
272
 
273
+ return {
274
+ "formatted": "\n".join(lines),
275
+ "totalResults": len(branches) + len(tags),
276
+ "resultsShared": len(branches) + len(tags),
277
+ }
278
 
279
  # =========================================================================
280
  # PR OPERATIONS
 
303
 
304
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
305
  return {
306
+ "formatted": f'**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision="refs/pr/{result.num}"',
307
  "totalResults": 1,
308
  "resultsShared": 1,
309
  }
 
318
  repo_type = args.get("repo_type", "model")
319
  status = args.get("status", "all") # open, closed, all
320
 
321
+ discussions = list(
322
+ self.api.get_repo_discussions(
323
+ repo_id=repo_id,
324
+ repo_type=repo_type,
325
+ discussion_status=status if status != "all" else None,
326
+ )
327
+ )
328
 
329
  if not discussions:
330
+ return {
331
+ "formatted": f"No discussions in {repo_id}",
332
+ "totalResults": 0,
333
+ "resultsShared": 0,
334
+ }
335
 
336
  url = _build_repo_url(repo_id, repo_type)
337
+ lines = [
338
+ f"**{repo_id}** - {len(discussions)} discussions",
339
+ f"{url}/discussions",
340
+ "",
341
+ ]
342
 
343
  for d in discussions[:20]:
344
  if d.status == "draft":
 
352
  type_label = "PR" if d.is_pull_request else "D"
353
  lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")
354
 
355
+ return {
356
+ "formatted": "\n".join(lines),
357
+ "totalResults": len(discussions),
358
+ "resultsShared": min(20, len(discussions)),
359
+ }
360
 
361
  async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
362
  """Get PR details."""
 
382
  "draft": "Draft",
383
  "open": "Open",
384
  "merged": "Merged",
385
+ "closed": "Closed",
386
  }
387
  status = status_map.get(pr.status, pr.status.capitalize())
388
  type_label = "Pull Request" if pr.is_pull_request else "Discussion"
 
396
 
397
  if pr.is_pull_request:
398
  if pr.status == "draft":
399
+ lines.append(
400
+ f'\nTo add commits: upload with revision="refs/pr/{pr_num}"'
401
+ )
402
  elif pr.status == "open":
403
+ lines.append(
404
+ f'\nTo add commits: upload with revision="refs/pr/{pr_num}"'
405
+ )
406
 
407
  return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}
408
 
 
428
  )
429
 
430
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
431
+ return {
432
+ "formatted": f"**PR #{pr_num} merged**\n{url}",
433
+ "totalResults": 1,
434
+ "resultsShared": 1,
435
+ }
436
 
437
  async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
438
  """Close a PR/discussion."""
 
456
  repo_type=repo_type,
457
  )
458
 
459
+ return {
460
+ "formatted": f"**Discussion #{pr_num} closed**",
461
+ "totalResults": 1,
462
+ "resultsShared": 1,
463
+ }
464
 
465
  async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
466
  """Add a comment to a PR/discussion."""
 
486
  )
487
 
488
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
489
+ return {
490
+ "formatted": f"**Comment added to #{pr_num}**\n{url}",
491
+ "totalResults": 1,
492
+ "resultsShared": 1,
493
+ }
494
 
495
  async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
496
  """Change PR/discussion status (mainly to convert draft to open)."""
 
518
  )
519
 
520
  url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
521
+ return {
522
+ "formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}",
523
+ "totalResults": 1,
524
+ "resultsShared": 1,
525
+ }
526
 
527
  # =========================================================================
528
  # REPO MANAGEMENT
 
540
  space_sdk = args.get("space_sdk")
541
 
542
  if repo_type == "space" and not space_sdk:
543
+ return self._error(
544
+ "space_sdk required for spaces (gradio/streamlit/docker/static)"
545
+ )
546
 
547
  kwargs = {
548
  "repo_id": repo_id,
 
554
  kwargs["space_sdk"] = space_sdk
555
 
556
  result = await _async_call(self.api.create_repo, **kwargs)
557
+ extra_metadata = None
558
+ if repo_type == "space" and space_sdk:
559
+ extra_metadata = {"sdk": space_sdk}
560
+ await _async_call(
561
+ register_hub_artifact,
562
+ self.api,
563
+ repo_id,
564
+ repo_type,
565
+ session=self.session,
566
+ extra_metadata=extra_metadata,
567
+ )
568
 
569
  return {
570
  "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
 
584
  gated = args.get("gated")
585
 
586
  if private is None and gated is None:
587
+ return self._error(
588
+ "Specify private (bool) or gated ('auto'/'manual'/false)"
589
+ )
590
 
591
  kwargs = {"repo_id": repo_id, "repo_type": repo_type}
592
  if private is not None:
 
603
  changes.append(f"gated={gated}")
604
 
605
  url = f"{_build_repo_url(repo_id, repo_type)}/settings"
606
+ return {
607
+ "formatted": f"**Settings updated:** {', '.join(changes)}\n{url}",
608
+ "totalResults": 1,
609
+ "resultsShared": 1,
610
+ }
611
 
612
  def _error(self, message: str) -> ToolResult:
613
  """Return an error result."""
614
+ return {
615
+ "formatted": message,
616
+ "totalResults": 0,
617
+ "resultsShared": 0,
618
+ "isError": True,
619
+ }
620
 
621
 
622
  # Tool specification
 
662
  "operation": {
663
  "type": "string",
664
  "enum": [
665
+ "create_branch",
666
+ "delete_branch",
667
+ "create_tag",
668
+ "delete_tag",
669
+ "list_refs",
670
+ "create_pr",
671
+ "list_prs",
672
+ "get_pr",
673
+ "merge_pr",
674
+ "close_pr",
675
+ "comment_pr",
676
+ "change_pr_status",
677
+ "create_repo",
678
+ "update_repo",
679
  ],
680
  "description": "Operation to execute",
681
  },
 
754
  }
755
 
756
 
757
+ async def hf_repo_git_handler(
758
+ arguments: Dict[str, Any], session=None
759
+ ) -> tuple[str, bool]:
760
  """Handler for agent tool router."""
761
  try:
762
  hf_token = session.hf_token if session else None
763
+ tool = HfRepoGitTool(hf_token=hf_token, session=session)
764
  result = await tool.execute(arguments)
765
  return result["formatted"], not result.get("isError", False)
766
  except Exception as e:
agent/tools/jobs_tool.py CHANGED
@@ -7,22 +7,24 @@ Refactored to use official huggingface-hub library instead of custom HTTP client
7
  import asyncio
8
  import base64
9
  import http.client
10
- import os
11
- import re
12
- from typing import Any, Dict, Literal, Optional, Callable, Awaitable
13
-
14
  import logging
 
 
 
15
 
16
  import httpx
17
  from huggingface_hub import HfApi
18
  from huggingface_hub.utils import HfHubHTTPError
19
 
20
- from agent.core.hf_access import JobsAccessError, is_billing_error, resolve_jobs_namespace
 
 
 
 
 
21
  from agent.core.session import Event
22
  from agent.tools.trackio_seed import ensure_trackio_dashboard
23
  from agent.tools.types import ToolResult
24
-
25
- logger = logging.getLogger(__name__)
26
  from agent.tools.utilities import (
27
  format_job_details,
28
  format_jobs_table,
@@ -30,6 +32,8 @@ from agent.tools.utilities import (
30
  format_scheduled_jobs_table,
31
  )
32
 
 
 
33
  # Hardware flavors
34
  CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"]
35
  GPU_FLAVORS = [
@@ -119,11 +123,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]:
119
  return logs
120
 
121
 
122
- _ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07')
123
 
124
 
125
  def _strip_ansi(text: str) -> str:
126
- return _ANSI_RE.sub('', text)
127
 
128
 
129
  _DEFAULT_ENV = {
@@ -235,6 +239,26 @@ def _resolve_uv_command(
235
  return _build_uv_command(script, with_deps, python, script_args)
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  async def _async_call(func, *args, **kwargs):
239
  """Wrap synchronous HfApi calls for async context"""
240
  return await asyncio.to_thread(func, *args, **kwargs)
@@ -432,7 +456,9 @@ class HfJobsTool:
432
  def log_producer():
433
  try:
434
  # fetch_job_logs is a blocking sync generator
435
- logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace)
 
 
436
  for line in logs_gen:
437
  # Push line to queue thread-safely
438
  loop.call_soon_threadsafe(queue.put_nowait, line)
@@ -556,6 +582,8 @@ class HfJobsTool:
556
  image = args.get("image", "python:3.12")
557
  job_type = "Docker"
558
 
 
 
559
  # Run the job
560
  flavor = args.get("hardware_flavor", "cpu-basic")
561
  timeout_str = args.get("timeout", "30m")
@@ -578,7 +606,9 @@ class HfJobsTool:
578
  image=image,
579
  command=command,
580
  env=env_dict,
581
- secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
 
 
582
  flavor=flavor,
583
  timeout=timeout_str,
584
  namespace=self.namespace,
@@ -636,10 +666,18 @@ class HfJobsTool:
636
  submit_ts = None
637
  if self.session:
638
  from agent.core import telemetry
 
639
  submit_ts = await telemetry.record_hf_job_submit(
640
- self.session, job,
641
- {**args, "hardware_flavor": flavor, "timeout": timeout_str, "namespace": self.namespace},
642
- image=image, job_type=job_type,
 
 
 
 
 
 
 
643
  )
644
  # Top-up signal: this submit succeeded after a prior billing
645
  # block in the same session, and we haven't fired the event
@@ -656,7 +694,8 @@ class HfJobsTool:
656
  )
657
  if blocked:
658
  await telemetry.record_credits_topped_up(
659
- self.session, namespace=self.namespace,
 
660
  )
661
 
662
  # Wait for completion and stream logs
@@ -670,9 +709,13 @@ class HfJobsTool:
670
 
671
  if self.session and submit_ts is not None:
672
  from agent.core import telemetry
 
673
  await telemetry.record_hf_job_complete(
674
- self.session, job,
675
- flavor=flavor, final_status=final_status, submit_ts=submit_ts,
 
 
 
676
  )
677
 
678
  # Untrack job ID (completed or failed, no longer needs cancellation)
@@ -699,7 +742,9 @@ class HfJobsTool:
699
  filtered_logs = _filter_uv_install_output(all_logs)
700
 
701
  # Format all logs for the agent
702
- log_text = _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)"
 
 
703
 
704
  response = f"""{job_type} job completed!
705
 
@@ -891,6 +936,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}
891
  image = args.get("image", "python:3.12")
892
  job_type = "Docker"
893
 
 
 
894
  # Create scheduled job
895
  scheduled_job = await _async_call(
896
  self.api.create_scheduled_job,
@@ -1215,6 +1262,7 @@ async def hf_jobs_handler(
1215
  sandbox = getattr(session, "sandbox", None) if session else None
1216
  if sandbox and script:
1217
  from agent.tools.sandbox_tool import resolve_sandbox_script
 
1218
  content, error = await resolve_sandbox_script(sandbox, script)
1219
  if error:
1220
  return error, False
 
7
  import asyncio
8
  import base64
9
  import http.client
 
 
 
 
10
  import logging
11
+ import re
12
+ import shlex
13
+ from typing import Any, Awaitable, Callable, Dict, Literal, Optional
14
 
15
  import httpx
16
  from huggingface_hub import HfApi
17
  from huggingface_hub.utils import HfHubHTTPError
18
 
19
+ from agent.core.hf_access import (
20
+ JobsAccessError,
21
+ is_billing_error,
22
+ resolve_jobs_namespace,
23
+ )
24
+ from agent.core.hub_artifacts import build_hub_artifact_sitecustomize
25
  from agent.core.session import Event
26
  from agent.tools.trackio_seed import ensure_trackio_dashboard
27
  from agent.tools.types import ToolResult
 
 
28
  from agent.tools.utilities import (
29
  format_job_details,
30
  format_jobs_table,
 
32
  format_scheduled_jobs_table,
33
  )
34
 
35
+ logger = logging.getLogger(__name__)
36
+
37
  # Hardware flavors
38
  CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"]
39
  GPU_FLAVORS = [
 
123
  return logs
124
 
125
 
126
+ _ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07")
127
 
128
 
129
  def _strip_ansi(text: str) -> str:
130
+ return _ANSI_RE.sub("", text)
131
 
132
 
133
  _DEFAULT_ENV = {
 
239
  return _build_uv_command(script, with_deps, python, script_args)
240
 
241
 
242
+ def _wrap_command_with_artifact_bootstrap(
243
+ command: list[str], session: Any = None
244
+ ) -> list[str]:
245
+ """Install sitecustomize hooks before the user command runs in HF Jobs."""
246
+ sitecustomize = build_hub_artifact_sitecustomize(session)
247
+ if not sitecustomize:
248
+ return command
249
+
250
+ encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii")
251
+ original_command = shlex.join(command)
252
+ shell = (
253
+ 'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; '
254
+ f"printf %s {shlex.quote(encoded)} | base64 -d "
255
+ '> "$_ml_intern_artifacts_dir/sitecustomize.py"; '
256
+ 'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; '
257
+ f"exec {original_command}"
258
+ )
259
+ return ["/bin/sh", "-lc", shell]
260
+
261
+
262
  async def _async_call(func, *args, **kwargs):
263
  """Wrap synchronous HfApi calls for async context"""
264
  return await asyncio.to_thread(func, *args, **kwargs)
 
456
  def log_producer():
457
  try:
458
  # fetch_job_logs is a blocking sync generator
459
+ logs_gen = self.api.fetch_job_logs(
460
+ job_id=job_id, namespace=namespace
461
+ )
462
  for line in logs_gen:
463
  # Push line to queue thread-safely
464
  loop.call_soon_threadsafe(queue.put_nowait, line)
 
582
  image = args.get("image", "python:3.12")
583
  job_type = "Docker"
584
 
585
+ command = _wrap_command_with_artifact_bootstrap(command, self.session)
586
+
587
  # Run the job
588
  flavor = args.get("hardware_flavor", "cpu-basic")
589
  timeout_str = args.get("timeout", "30m")
 
606
  image=image,
607
  command=command,
608
  env=env_dict,
609
+ secrets=_add_environment_variables(
610
+ args.get("secrets"), self.hf_token
611
+ ),
612
  flavor=flavor,
613
  timeout=timeout_str,
614
  namespace=self.namespace,
 
666
  submit_ts = None
667
  if self.session:
668
  from agent.core import telemetry
669
+
670
  submit_ts = await telemetry.record_hf_job_submit(
671
+ self.session,
672
+ job,
673
+ {
674
+ **args,
675
+ "hardware_flavor": flavor,
676
+ "timeout": timeout_str,
677
+ "namespace": self.namespace,
678
+ },
679
+ image=image,
680
+ job_type=job_type,
681
  )
682
  # Top-up signal: this submit succeeded after a prior billing
683
  # block in the same session, and we haven't fired the event
 
694
  )
695
  if blocked:
696
  await telemetry.record_credits_topped_up(
697
+ self.session,
698
+ namespace=self.namespace,
699
  )
700
 
701
  # Wait for completion and stream logs
 
709
 
710
  if self.session and submit_ts is not None:
711
  from agent.core import telemetry
712
+
713
  await telemetry.record_hf_job_complete(
714
+ self.session,
715
+ job,
716
+ flavor=flavor,
717
+ final_status=final_status,
718
+ submit_ts=submit_ts,
719
  )
720
 
721
  # Untrack job ID (completed or failed, no longer needs cancellation)
 
742
  filtered_logs = _filter_uv_install_output(all_logs)
743
 
744
  # Format all logs for the agent
745
+ log_text = (
746
+ _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)"
747
+ )
748
 
749
  response = f"""{job_type} job completed!
750
 
 
936
  image = args.get("image", "python:3.12")
937
  job_type = "Docker"
938
 
939
+ command = _wrap_command_with_artifact_bootstrap(command, self.session)
940
+
941
  # Create scheduled job
942
  scheduled_job = await _async_call(
943
  self.api.create_scheduled_job,
 
1262
  sandbox = getattr(session, "sandbox", None) if session else None
1263
  if sandbox and script:
1264
  from agent.tools.sandbox_tool import resolve_sandbox_script
1265
+
1266
  content, error = await resolve_sandbox_script(sandbox, script)
1267
  if error:
1268
  return error, False
agent/tools/local_tools.py CHANGED
@@ -15,6 +15,8 @@ import tempfile
15
  from pathlib import Path
16
  from typing import Any
17
 
 
 
18
 
19
  MAX_OUTPUT_CHARS = 25_000
20
  MAX_LINE_LENGTH = 4000
@@ -22,7 +24,7 @@ DEFAULT_READ_LINES = 2000
22
  DEFAULT_TIMEOUT = 120
23
  MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench)
24
 
25
- _ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07')
26
 
27
  # Track files that have been read this session (enforces read-before-write/edit)
28
  _files_read: set[str] = set()
@@ -63,17 +65,21 @@ def _atomic_write(path: Path, content: str) -> None:
63
 
64
 
65
  def _strip_ansi(text: str) -> str:
66
- return _ANSI_RE.sub('', text)
67
 
68
 
69
- def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25) -> str:
 
 
70
  """Tail-biased truncation with temp file spillover for full output access."""
71
  if len(output) <= max_chars:
72
  return output
73
  # Write full output to temp file so LLM can read specific sections
74
  spill_path = None
75
  try:
76
- with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', delete=False) as f:
 
 
77
  f.write(output)
78
  spill_path = f.name
79
  except Exception:
@@ -93,10 +99,14 @@ def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio:
93
 
94
  # ── Handlers ────────────────────────────────────────────────────────────
95
 
96
- async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
 
 
 
97
  command = args.get("command", "")
98
  if not command:
99
  return "No command provided.", False
 
100
  work_dir = args.get("work_dir", ".")
101
  timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
102
  try:
@@ -174,9 +184,12 @@ async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
174
  # Syntax validation for Python files
175
  if p.suffix == ".py":
176
  from agent.tools.edit_utils import validate_python
 
177
  warnings = validate_python(content, file_path)
178
  if warnings:
179
- msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings)
 
 
180
  return msg, True
181
  except Exception as e:
182
  return f"write error: {e}", False
@@ -229,7 +242,9 @@ async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]:
229
  if p.suffix == ".py":
230
  warnings = validate_python(new_text, file_path)
231
  if warnings:
232
- msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings)
 
 
233
  return msg, True
234
 
235
 
 
15
  from pathlib import Path
16
  from typing import Any
17
 
18
+ from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
19
+
20
 
21
  MAX_OUTPUT_CHARS = 25_000
22
  MAX_LINE_LENGTH = 4000
 
24
  DEFAULT_TIMEOUT = 120
25
  MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench)
26
 
27
+ _ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07")
28
 
29
  # Track files that have been read this session (enforces read-before-write/edit)
30
  _files_read: set[str] = set()
 
65
 
66
 
67
  def _strip_ansi(text: str) -> str:
68
+ return _ANSI_RE.sub("", text)
69
 
70
 
71
+ def _truncate_output(
72
+ output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25
73
+ ) -> str:
74
  """Tail-biased truncation with temp file spillover for full output access."""
75
  if len(output) <= max_chars:
76
  return output
77
  # Write full output to temp file so LLM can read specific sections
78
  spill_path = None
79
  try:
80
+ with tempfile.NamedTemporaryFile(
81
+ mode="w", suffix=".txt", prefix="bash_output_", delete=False
82
+ ) as f:
83
  f.write(output)
84
  spill_path = f.name
85
  except Exception:
 
99
 
100
  # ── Handlers ────────────────────────────────────────────────────────────
101
 
102
+
103
+ async def _bash_handler(
104
+ args: dict[str, Any], session: Any = None, **_kw
105
+ ) -> tuple[str, bool]:
106
  command = args.get("command", "")
107
  if not command:
108
  return "No command provided.", False
109
+ command = wrap_shell_command_with_hub_artifact_bootstrap(command, session)
110
  work_dir = args.get("work_dir", ".")
111
  timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT)
112
  try:
 
184
  # Syntax validation for Python files
185
  if p.suffix == ".py":
186
  from agent.tools.edit_utils import validate_python
187
+
188
  warnings = validate_python(content, file_path)
189
  if warnings:
190
+ msg += "\n\nValidation warnings:\n" + "\n".join(
191
+ f" ⚠ {w}" for w in warnings
192
+ )
193
  return msg, True
194
  except Exception as e:
195
  return f"write error: {e}", False
 
242
  if p.suffix == ".py":
243
  warnings = validate_python(new_text, file_path)
244
  if warnings:
245
+ msg += "\n\nValidation warnings:\n" + "\n".join(
246
+ f" ⚠ {w}" for w in warnings
247
+ )
248
  return msg, True
249
 
250
 
agent/tools/papers_tool.py CHANGED
@@ -102,7 +102,9 @@ async def _s2_request(
102
 
103
 
104
  async def _s2_get_json(
105
- client: httpx.AsyncClient, path: str, params: dict | None = None,
 
 
106
  ) -> dict | None:
107
  """Cached S2 GET returning parsed JSON or None."""
108
  key = _s2_cache_key(path, params)
@@ -119,7 +121,9 @@ async def _s2_get_json(
119
 
120
 
121
  async def _s2_get_paper(
122
- client: httpx.AsyncClient, arxiv_id: str, fields: str,
 
 
123
  ) -> dict | None:
124
  """Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
125
  return await _s2_get_json(
@@ -322,7 +326,9 @@ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str:
322
  if keywords:
323
  lines.append(f"**Keywords:** {', '.join(keywords)}")
324
  if s2_data and s2_data.get("s2FieldsOfStudy"):
325
- fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")]
 
 
326
  if fields:
327
  lines.append(f"**Fields:** {', '.join(fields)}")
328
  if s2_data and s2_data.get("venue"):
@@ -393,7 +399,9 @@ def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str:
393
  ds_id = ds.get("id", "unknown")
394
  downloads = ds.get("downloads", 0)
395
  likes = ds.get("likes", 0)
396
- desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN)
 
 
397
  tags = ds.get("tags") or []
398
  interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]
399
 
@@ -582,11 +590,15 @@ def _format_s2_paper_list(papers: list[dict], title: str) -> str:
582
  lines.append(f"**TL;DR:** {tldr}")
583
  lines.append("")
584
 
585
- lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.")
 
 
586
  return "\n".join(lines)
587
 
588
 
589
- async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None:
 
 
590
  """Search via S2 bulk endpoint with filters. Returns None on failure."""
591
  params: dict[str, Any] = {
592
  "query": query,
@@ -616,7 +628,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR
616
  params["sort"] = f"{sort_by}:desc"
617
 
618
  async with httpx.AsyncClient(timeout=15) as client:
619
- resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params)
 
 
620
  if not resp or resp.status_code != 200:
621
  return None
622
  data = resp.json()
@@ -629,7 +643,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR
629
  "resultsShared": 0,
630
  }
631
 
632
- formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)")
 
 
633
  return {
634
  "formatted": formatted,
635
  "totalResults": data.get("total", len(papers)),
@@ -643,7 +659,10 @@ async def _op_search(args: dict[str, Any], limit: int) -> ToolResult:
643
  return _error("'query' is required for search operation.")
644
 
645
  # Route to S2 when filters are present
646
- use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by"))
 
 
 
647
  if use_s2:
648
  result = await _s2_bulk_search(query, args, limit)
649
  if result is not None:
@@ -806,7 +825,9 @@ def _format_citation_graph(
806
  lines.append("No citations found.")
807
  lines.append("")
808
 
809
- lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.")
 
 
810
  return "\n".join(lines)
811
 
812
 
@@ -824,9 +845,13 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
824
  refs, cites = None, None
825
  coros = []
826
  if direction in ("references", "both"):
827
- coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params))
 
 
828
  if direction in ("citations", "both"):
829
- coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params))
 
 
830
 
831
  results = await asyncio.gather(*coros, return_exceptions=True)
832
  idx = 0
@@ -841,7 +866,9 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult:
841
  cites = r.get("data", [])
842
 
843
  if refs is None and cites is None:
844
- return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.")
 
 
845
 
846
  total = (len(refs) if refs else 0) + (len(cites) if cites else 0)
847
  return {
@@ -1039,7 +1066,9 @@ def _format_snippets(snippets: list[dict], query: str) -> str:
1039
  lines.append(f"> {_truncate(text, 400)}")
1040
  lines.append("")
1041
 
1042
- lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.")
 
 
1043
  return "\n".join(lines)
1044
 
1045
 
@@ -1065,7 +1094,9 @@ async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult:
1065
  params["minCitationCount"] = str(args["min_citations"])
1066
 
1067
  async with httpx.AsyncClient(timeout=15) as client:
1068
- resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params)
 
 
1069
  if not resp or resp.status_code != 200:
1070
  return _error("Snippet search failed. Semantic Scholar may be unavailable.")
1071
  data = resp.json()
@@ -1102,16 +1133,28 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
1102
  async with httpx.AsyncClient(timeout=15) as client:
1103
  if positive_ids and not arxiv_id:
1104
  # Multi-paper recommendations (POST, not cached)
1105
- pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()]
 
 
 
 
1106
  neg_raw = args.get("negative_ids", "")
1107
- neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else []
 
 
 
 
1108
  resp = await _s2_request(
1109
- client, "POST", "/recommendations/v1/papers/",
 
 
1110
  json={"positivePaperIds": pos, "negativePaperIds": neg},
1111
  params={"fields": fields, "limit": limit},
1112
  )
1113
  if not resp or resp.status_code != 200:
1114
- return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
 
 
1115
  data = resp.json()
1116
  else:
1117
  # Single-paper recommendations (cached)
@@ -1121,7 +1164,9 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult:
1121
  {"fields": fields, "limit": limit, "from": "recent"},
1122
  )
1123
  if not data:
1124
- return _error("Recommendation request failed. Semantic Scholar may be unavailable.")
 
 
1125
 
1126
  papers = data.get("recommendedPapers") or []
1127
  if not papers:
 
102
 
103
 
104
  async def _s2_get_json(
105
+ client: httpx.AsyncClient,
106
+ path: str,
107
+ params: dict | None = None,
108
  ) -> dict | None:
109
  """Cached S2 GET returning parsed JSON or None."""
110
  key = _s2_cache_key(path, params)
 
121
 
122
 
123
  async def _s2_get_paper(
124
+ client: httpx.AsyncClient,
125
+ arxiv_id: str,
126
+ fields: str,
127
  ) -> dict | None:
128
  """Fetch a single paper from S2 by arxiv ID. Returns None on failure."""
129
  return await _s2_get_json(
 
326
  if keywords:
327
  lines.append(f"**Keywords:** {', '.join(keywords)}")
328
  if s2_data and s2_data.get("s2FieldsOfStudy"):
329
+ fields = [
330
+ f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")
331
+ ]
332
  if fields:
333
  lines.append(f"**Fields:** {', '.join(fields)}")
334
  if s2_data and s2_data.get("venue"):
 
399
  ds_id = ds.get("id", "unknown")
400
  downloads = ds.get("downloads", 0)
401
  likes = ds.get("likes", 0)
402
+ desc = _truncate(
403
+ _clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN
404
+ )
405
  tags = ds.get("tags") or []
406
  interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5]
407
 
 
590
  lines.append(f"**TL;DR:** {tldr}")
591
  lines.append("")
592
 
593
+ lines.append(
594
+ "Use paper_details with arxiv_id for full info, or read_paper to read sections."
595
+ )
596
  return "\n".join(lines)
597
 
598
 
599
+ async def _s2_bulk_search(
600
+ query: str, args: dict[str, Any], limit: int
601
+ ) -> ToolResult | None:
602
  """Search via S2 bulk endpoint with filters. Returns None on failure."""
603
  params: dict[str, Any] = {
604
  "query": query,
 
628
  params["sort"] = f"{sort_by}:desc"
629
 
630
  async with httpx.AsyncClient(timeout=15) as client:
631
+ resp = await _s2_request(
632
+ client, "GET", "/graph/v1/paper/search/bulk", params=params
633
+ )
634
  if not resp or resp.status_code != 200:
635
  return None
636
  data = resp.json()
 
643
  "resultsShared": 0,
644
  }
645
 
646
+ formatted = _format_s2_paper_list(
647
+ papers[:limit], f"Papers matching '{query}' (Semantic Scholar)"
648
+ )
649
  return {
650
  "formatted": formatted,
651
  "totalResults": data.get("total", len(papers)),
 
659
  return _error("'query' is required for search operation.")
660
 
661
  # Route to S2 when filters are present
662
+ use_s2 = any(
663
+ args.get(k)
664
+ for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")
665
+ )
666
  if use_s2:
667
  result = await _s2_bulk_search(query, args, limit)
668
  if result is not None:
 
825
  lines.append("No citations found.")
826
  lines.append("")
827
 
828
+ lines.append(
829
+ "**Tip:** Use paper_details with an arxiv_id from above to explore further."
830
+ )
831
  return "\n".join(lines)
832
 
833
 
 
845
  refs, cites = None, None
846
  coros = []
847
  if direction in ("references", "both"):
848
+ coros.append(
849
+ _s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)
850
+ )
851
  if direction in ("citations", "both"):
852
+ coros.append(
853
+ _s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)
854
+ )
855
 
856
  results = await asyncio.gather(*coros, return_exceptions=True)
857
  idx = 0
 
866
  cites = r.get("data", [])
867
 
868
  if refs is None and cites is None:
869
+ return _error(
870
+ f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar."
871
+ )
872
 
873
  total = (len(refs) if refs else 0) + (len(cites) if cites else 0)
874
  return {
 
1066
  lines.append(f"> {_truncate(text, 400)}")
1067
  lines.append("")
1068
 
1069
+ lines.append(
1070
+ "Use paper_details or read_paper with arxiv_id to explore a paper further."
1071
+ )
1072
  return "\n".join(lines)
1073
 
1074
 
 
1094
  params["minCitationCount"] = str(args["min_citations"])
1095
 
1096
  async with httpx.AsyncClient(timeout=15) as client:
1097
+ resp = await _s2_request(
1098
+ client, "GET", "/graph/v1/snippet/search", params=params
1099
+ )
1100
  if not resp or resp.status_code != 200:
1101
  return _error("Snippet search failed. Semantic Scholar may be unavailable.")
1102
  data = resp.json()
 
1133
  async with httpx.AsyncClient(timeout=15) as client:
1134
  if positive_ids and not arxiv_id:
1135
  # Multi-paper recommendations (POST, not cached)
1136
+ pos = [
1137
+ _s2_paper_id(pid.strip())
1138
+ for pid in positive_ids.split(",")
1139
+ if pid.strip()
1140
+ ]
1141
  neg_raw = args.get("negative_ids", "")
1142
+ neg = (
1143
+ [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()]
1144
+ if neg_raw
1145
+ else []
1146
+ )
1147
  resp = await _s2_request(
1148
+ client,
1149
+ "POST",
1150
+ "/recommendations/v1/papers/",
1151
  json={"positivePaperIds": pos, "negativePaperIds": neg},
1152
  params={"fields": fields, "limit": limit},
1153
  )
1154
  if not resp or resp.status_code != 200:
1155
+ return _error(
1156
+ "Recommendation request failed. Semantic Scholar may be unavailable."
1157
+ )
1158
  data = resp.json()
1159
  else:
1160
  # Single-paper recommendations (cached)
 
1164
  {"fields": fields, "limit": limit, "from": "recent"},
1165
  )
1166
  if not data:
1167
+ return _error(
1168
+ "Recommendation request failed. Semantic Scholar may be unavailable."
1169
+ )
1170
 
1171
  papers = data.get("recommendedPapers") or []
1172
  if not papers:
agent/tools/research_tool.py CHANGED
@@ -282,6 +282,7 @@ async def research_handler(
282
  _agent_id = tool_call_id
283
  else:
284
  import uuid
 
285
  _agent_id = uuid.uuid4().hex[:8]
286
  _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task)
287
 
@@ -289,12 +290,15 @@ async def research_handler(
289
  """Send a progress event to the UI so it doesn't look frozen."""
290
  try:
291
  await session.send_event(
292
- Event(event_type="tool_log", data={
293
- "tool": "research",
294
- "log": text,
295
- "agent_id": _agent_id,
296
- "label": _agent_label,
297
- })
 
 
 
298
  )
299
  except Exception:
300
  pass
@@ -323,15 +327,19 @@ async def research_handler(
323
  "Research sub-agent hit context max (%d tokens) — forcing summary",
324
  _total_tokens,
325
  )
326
- await _log(f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up")
 
 
327
  # Ask for a final summary with no tools
328
- messages.append(Message(
329
- role="user",
330
- content=(
331
- "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. "
332
- "Summarize your findings NOW. Do NOT call any more tools."
333
- ),
334
- ))
 
 
335
  try:
336
  _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
337
  _t0 = time.monotonic()
@@ -351,27 +359,34 @@ async def research_handler(
351
  model=research_model,
352
  response=response,
353
  latency_ms=int((time.monotonic() - _t0) * 1000),
354
- finish_reason=response.choices[0].finish_reason if response.choices else None,
 
 
355
  kind="research",
356
  )
357
  except Exception as _telem_err:
358
  logger.debug("research telemetry failed: %s", _telem_err)
359
  content = response.choices[0].message.content or ""
360
- return content or "Research context exhausted — no summary produced.", bool(content)
 
 
 
361
  except Exception:
362
  return "Research context exhausted and summary call failed.", False
363
 
364
  if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN:
365
  _warned_context = True
366
  await _log(f"Context at {_total_tokens} tokens — nudging to wrap up")
367
- messages.append(Message(
368
- role="user",
369
- content=(
370
- "[SYSTEM: You have used 75% of your context budget. "
371
- "Start wrapping up: finish any critical lookups, then "
372
- "produce your final summary within the next 1-2 iterations.]"
373
- ),
374
- ))
 
 
375
 
376
  try:
377
  _msgs, _tools = with_prompt_caching(
@@ -392,7 +407,9 @@ async def research_handler(
392
  model=research_model,
393
  response=response,
394
  latency_ms=int((time.monotonic() - _t0) * 1000),
395
- finish_reason=response.choices[0].finish_reason if response.choices else None,
 
 
396
  kind="research",
397
  )
398
  except Exception as _telem_err:
@@ -420,11 +437,13 @@ async def research_handler(
420
  # LiteLLM's raw Message carries `provider_specific_fields` and
421
  # `reasoning_content`, which the HF router's OpenAI schema rejects
422
  # if we echo them back in the next request.
423
- messages.append(Message(
424
- role="assistant",
425
- content=msg.content,
426
- tool_calls=msg.tool_calls,
427
- ))
 
 
428
  for tc in msg.tool_calls:
429
  try:
430
  tool_args = json.loads(tc.function.arguments)
@@ -479,13 +498,15 @@ async def research_handler(
479
 
480
  # ── Iteration limit: try to salvage findings ──
481
  await _log("Iteration limit reached — extracting summary")
482
- messages.append(Message(
483
- role="user",
484
- content=(
485
- "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research "
486
- "iterations. Summarize ALL findings so far. Do NOT call any more tools."
487
- ),
488
- ))
 
 
489
  try:
490
  _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
491
  _t0 = time.monotonic()
@@ -502,7 +523,9 @@ async def research_handler(
502
  model=research_model,
503
  response=response,
504
  latency_ms=int((time.monotonic() - _t0) * 1000),
505
- finish_reason=response.choices[0].finish_reason if response.choices else None,
 
 
506
  kind="research",
507
  )
508
  except Exception as _telem_err:
 
282
  _agent_id = tool_call_id
283
  else:
284
  import uuid
285
+
286
  _agent_id = uuid.uuid4().hex[:8]
287
  _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task)
288
 
 
290
  """Send a progress event to the UI so it doesn't look frozen."""
291
  try:
292
  await session.send_event(
293
+ Event(
294
+ event_type="tool_log",
295
+ data={
296
+ "tool": "research",
297
+ "log": text,
298
+ "agent_id": _agent_id,
299
+ "label": _agent_label,
300
+ },
301
+ )
302
  )
303
  except Exception:
304
  pass
 
327
  "Research sub-agent hit context max (%d tokens) — forcing summary",
328
  _total_tokens,
329
  )
330
+ await _log(
331
+ f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up"
332
+ )
333
  # Ask for a final summary with no tools
334
+ messages.append(
335
+ Message(
336
+ role="user",
337
+ content=(
338
+ "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. "
339
+ "Summarize your findings NOW. Do NOT call any more tools."
340
+ ),
341
+ )
342
+ )
343
  try:
344
  _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
345
  _t0 = time.monotonic()
 
359
  model=research_model,
360
  response=response,
361
  latency_ms=int((time.monotonic() - _t0) * 1000),
362
+ finish_reason=response.choices[0].finish_reason
363
+ if response.choices
364
+ else None,
365
  kind="research",
366
  )
367
  except Exception as _telem_err:
368
  logger.debug("research telemetry failed: %s", _telem_err)
369
  content = response.choices[0].message.content or ""
370
+ return (
371
+ content or "Research context exhausted — no summary produced.",
372
+ bool(content),
373
+ )
374
  except Exception:
375
  return "Research context exhausted and summary call failed.", False
376
 
377
  if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN:
378
  _warned_context = True
379
  await _log(f"Context at {_total_tokens} tokens — nudging to wrap up")
380
+ messages.append(
381
+ Message(
382
+ role="user",
383
+ content=(
384
+ "[SYSTEM: You have used 75% of your context budget. "
385
+ "Start wrapping up: finish any critical lookups, then "
386
+ "produce your final summary within the next 1-2 iterations.]"
387
+ ),
388
+ )
389
+ )
390
 
391
  try:
392
  _msgs, _tools = with_prompt_caching(
 
407
  model=research_model,
408
  response=response,
409
  latency_ms=int((time.monotonic() - _t0) * 1000),
410
+ finish_reason=response.choices[0].finish_reason
411
+ if response.choices
412
+ else None,
413
  kind="research",
414
  )
415
  except Exception as _telem_err:
 
437
  # LiteLLM's raw Message carries `provider_specific_fields` and
438
  # `reasoning_content`, which the HF router's OpenAI schema rejects
439
  # if we echo them back in the next request.
440
+ messages.append(
441
+ Message(
442
+ role="assistant",
443
+ content=msg.content,
444
+ tool_calls=msg.tool_calls,
445
+ )
446
+ )
447
  for tc in msg.tool_calls:
448
  try:
449
  tool_args = json.loads(tc.function.arguments)
 
498
 
499
  # ── Iteration limit: try to salvage findings ──
500
  await _log("Iteration limit reached — extracting summary")
501
+ messages.append(
502
+ Message(
503
+ role="user",
504
+ content=(
505
+ "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research "
506
+ "iterations. Summarize ALL findings so far. Do NOT call any more tools."
507
+ ),
508
+ )
509
+ )
510
  try:
511
  _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model"))
512
  _t0 = time.monotonic()
 
523
  model=research_model,
524
  response=response,
525
  latency_ms=int((time.monotonic() - _t0) * 1000),
526
+ finish_reason=response.choices[0].finish_reason
527
+ if response.choices
528
+ else None,
529
  kind="research",
530
  )
531
  except Exception as _telem_err:
agent/tools/sandbox_client.py CHANGED
@@ -729,9 +729,7 @@ class Sandbox:
729
  runtime, "requested_hardware", None
730
  )
731
  if current_hardware != hardware:
732
- _log(
733
- f" RUNNING on {current_hardware}; waiting for {hardware}..."
734
- )
735
  time.sleep(WAIT_INTERVAL)
736
  continue
737
  _log(f"Space is running (hardware: {runtime.hardware})")
@@ -767,7 +765,9 @@ class Sandbox:
767
  return sb
768
 
769
  @staticmethod
770
- def _setup_server(space_id: str, api: HfApi, *, log: Callable[[str], object] = print) -> None:
 
 
771
  """Upload embedded sandbox server + Dockerfile to the Space (single commit)."""
772
  log(f"Uploading sandbox server to {space_id}...")
773
  api.create_commit(
@@ -809,7 +809,9 @@ class Sandbox:
809
  sb._wait_for_api(timeout=60)
810
  return sb
811
 
812
- def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print):
 
 
813
  """Poll the health endpoint until the server responds."""
814
  deadline = time.time() + timeout
815
  last_err = None
@@ -986,7 +988,12 @@ class Sandbox:
986
  return result
987
 
988
  def edit(
989
- self, path: str, old_str: str, new_str: str, *, replace_all: bool = False,
 
 
 
 
 
990
  mode: str = "replace",
991
  ) -> ToolResult:
992
  if old_str == new_str:
 
729
  runtime, "requested_hardware", None
730
  )
731
  if current_hardware != hardware:
732
+ _log(f" RUNNING on {current_hardware}; waiting for {hardware}...")
 
 
733
  time.sleep(WAIT_INTERVAL)
734
  continue
735
  _log(f"Space is running (hardware: {runtime.hardware})")
 
765
  return sb
766
 
767
  @staticmethod
768
+ def _setup_server(
769
+ space_id: str, api: HfApi, *, log: Callable[[str], object] = print
770
+ ) -> None:
771
  """Upload embedded sandbox server + Dockerfile to the Space (single commit)."""
772
  log(f"Uploading sandbox server to {space_id}...")
773
  api.create_commit(
 
809
  sb._wait_for_api(timeout=60)
810
  return sb
811
 
812
+ def _wait_for_api(
813
+ self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print
814
+ ):
815
  """Poll the health endpoint until the server responds."""
816
  deadline = time.time() + timeout
817
  last_err = None
 
988
  return result
989
 
990
  def edit(
991
+ self,
992
+ path: str,
993
+ old_str: str,
994
+ new_str: str,
995
+ *,
996
+ replace_all: bool = False,
997
  mode: str = "replace",
998
  ) -> ToolResult:
999
  if old_str == new_str:
agent/tools/sandbox_tool.py CHANGED
@@ -21,6 +21,7 @@ from typing import Any
21
 
22
  from huggingface_hub import HfApi, SpaceHardware
23
 
 
24
  from agent.core.session import Event
25
  from agent.tools.sandbox_client import Sandbox
26
  from agent.tools.trackio_seed import ensure_trackio_dashboard
@@ -197,7 +198,9 @@ def _cleanup_user_orphan_sandboxes(
197
  if not _SANDBOX_NAME_RE.match(space_name):
198
  continue
199
 
200
- last_mod = getattr(space, "lastModified", None) or getattr(space, "last_modified", None)
 
 
201
  if isinstance(last_mod, str):
202
  try:
203
  last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
@@ -337,6 +340,7 @@ async def _create_sandbox_locked(
337
  if hardware != DEFAULT_CPU_SANDBOX_HARDWARE:
338
  kwargs["sleep_time"] = 2700
339
  import time as _t
 
340
  _t_start = _t.monotonic()
341
  try:
342
  sb = await asyncio.to_thread(Sandbox.create, **kwargs)
@@ -350,7 +354,9 @@ async def _create_sandbox_locked(
350
  try:
351
  await asyncio.to_thread(sb.delete)
352
  except Exception as e:
353
- logger.warning("Failed to delete cancelled sandbox %s: %s", sb.space_id, e)
 
 
354
  return None, "Sandbox creation cancelled by user."
355
 
356
  session.sandbox = sb
@@ -360,8 +366,11 @@ async def _create_sandbox_locked(
360
 
361
  # Telemetry: sandbox creation (infra consumption signal)
362
  from agent.core import telemetry
 
363
  await telemetry.record_sandbox_create(
364
- session, sb, hardware=hardware,
 
 
365
  create_latency_s=int(_t.monotonic() - _t_start),
366
  )
367
 
@@ -510,12 +519,13 @@ async def teardown_session_sandbox(session: Any) -> None:
510
  )
511
  await asyncio.to_thread(sandbox.delete)
512
  from agent.core import telemetry
 
513
  await telemetry.record_sandbox_destroy(session, sandbox)
514
  return
515
  except Exception as e:
516
  last_err = e
517
  if attempt < 2:
518
- await asyncio.sleep(2 ** attempt)
519
  logger.error(
520
  "Failed to delete sandbox %s after 3 attempts: %s. "
521
  "Orphan — sweep script will pick it up.",
@@ -720,6 +730,14 @@ def _make_tool_handler(sandbox_tool_name: str):
720
  return "Sandbox is still starting. Please retry shortly.", False
721
 
722
  try:
 
 
 
 
 
 
 
 
723
  result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
724
  if result.success:
725
  output = result.output or "(no output)"
@@ -758,8 +776,7 @@ def get_sandbox_tools():
758
  description = (
759
  "Uses the session's active sandbox. A private cpu-basic sandbox is "
760
  "started automatically for normal CPU work; call sandbox_create only "
761
- "for GPU or other non-default hardware.\n\n"
762
- + spec["description"]
763
  )
764
  tools.append(
765
  ToolSpec(
 
21
 
22
  from huggingface_hub import HfApi, SpaceHardware
23
 
24
+ from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap
25
  from agent.core.session import Event
26
  from agent.tools.sandbox_client import Sandbox
27
  from agent.tools.trackio_seed import ensure_trackio_dashboard
 
198
  if not _SANDBOX_NAME_RE.match(space_name):
199
  continue
200
 
201
+ last_mod = getattr(space, "lastModified", None) or getattr(
202
+ space, "last_modified", None
203
+ )
204
  if isinstance(last_mod, str):
205
  try:
206
  last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00"))
 
340
  if hardware != DEFAULT_CPU_SANDBOX_HARDWARE:
341
  kwargs["sleep_time"] = 2700
342
  import time as _t
343
+
344
  _t_start = _t.monotonic()
345
  try:
346
  sb = await asyncio.to_thread(Sandbox.create, **kwargs)
 
354
  try:
355
  await asyncio.to_thread(sb.delete)
356
  except Exception as e:
357
+ logger.warning(
358
+ "Failed to delete cancelled sandbox %s: %s", sb.space_id, e
359
+ )
360
  return None, "Sandbox creation cancelled by user."
361
 
362
  session.sandbox = sb
 
366
 
367
  # Telemetry: sandbox creation (infra consumption signal)
368
  from agent.core import telemetry
369
+
370
  await telemetry.record_sandbox_create(
371
+ session,
372
+ sb,
373
+ hardware=hardware,
374
  create_latency_s=int(_t.monotonic() - _t_start),
375
  )
376
 
 
519
  )
520
  await asyncio.to_thread(sandbox.delete)
521
  from agent.core import telemetry
522
+
523
  await telemetry.record_sandbox_destroy(session, sandbox)
524
  return
525
  except Exception as e:
526
  last_err = e
527
  if attempt < 2:
528
+ await asyncio.sleep(2**attempt)
529
  logger.error(
530
  "Failed to delete sandbox %s after 3 attempts: %s. "
531
  "Orphan — sweep script will pick it up.",
 
730
  return "Sandbox is still starting. Please retry shortly.", False
731
 
732
  try:
733
+ if sandbox_tool_name == "bash" and args.get("command"):
734
+ args = {
735
+ **args,
736
+ "command": wrap_shell_command_with_hub_artifact_bootstrap(
737
+ args["command"],
738
+ session,
739
+ ),
740
+ }
741
  result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args)
742
  if result.success:
743
  output = result.output or "(no output)"
 
776
  description = (
777
  "Uses the session's active sandbox. A private cpu-basic sandbox is "
778
  "started automatically for normal CPU work; call sandbox_create only "
779
+ "for GPU or other non-default hardware.\n\n" + spec["description"]
 
780
  )
781
  tools.append(
782
  ToolSpec(
agent/tools/web_search_tool.py CHANGED
@@ -253,7 +253,10 @@ async def web_search_handler(
253
  ) -> tuple[str, bool]:
254
  query_value = arguments.get("query", "")
255
  if not isinstance(query_value, str):
256
- return "Error: web_search requires a query string with at least 2 characters.", False
 
 
 
257
 
258
  query = query_value.strip()
259
  if len(query) < 2:
 
253
  ) -> tuple[str, bool]:
254
  query_value = arguments.get("query", "")
255
  if not isinstance(query_value, str):
256
+ return (
257
+ "Error: web_search requires a query string with at least 2 characters.",
258
+ False,
259
+ )
260
 
261
  query = query_value.strip()
262
  if len(query) < 2:
agent/utils/braille.py CHANGED
@@ -41,8 +41,7 @@ class BrailleCanvas:
41
  for row in range(self.term_height):
42
  offset = row * self.term_width
43
  line = "".join(
44
- chr(0x2800 + self._buf[offset + col])
45
- for col in range(self.term_width)
46
  )
47
  lines.append(line)
48
  return lines
@@ -52,6 +51,7 @@ class BrailleCanvas:
52
 
53
  _FONT: dict[str, list[str]] = {}
54
 
 
55
  def _define_font() -> None:
56
  """Define a simple 5×7 bitmap font for uppercase ASCII."""
57
  glyphs = {
@@ -113,8 +113,9 @@ def text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]:
113
  if cell == "#":
114
  for sy in range(scale):
115
  for sx in range(scale):
116
- pixels.append((cursor_x + col_idx * scale + sx,
117
- row_idx * scale + sy))
 
118
  glyph_width = max(len(r) for r in glyph)
119
  cursor_x += (glyph_width + 1) * scale
120
  return pixels
 
41
  for row in range(self.term_height):
42
  offset = row * self.term_width
43
  line = "".join(
44
+ chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width)
 
45
  )
46
  lines.append(line)
47
  return lines
 
51
 
52
  _FONT: dict[str, list[str]] = {}
53
 
54
+
55
  def _define_font() -> None:
56
  """Define a simple 5×7 bitmap font for uppercase ASCII."""
57
  glyphs = {
 
113
  if cell == "#":
114
  for sy in range(scale):
115
  for sx in range(scale):
116
+ pixels.append(
117
+ (cursor_x + col_idx * scale + sx, row_idx * scale + sy)
118
+ )
119
  glyph_width = max(len(r) for r in glyph)
120
  cursor_x += (glyph_width + 1) * scale
121
  return pixels
agent/utils/crt_boot.py CHANGED
@@ -55,7 +55,10 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No
55
  # Render previously completed lines
56
  for prev_text, prev_style in displayed_lines:
57
  if rng.random() < prev_glitch_chance:
58
- result.append(_glitch_text(prev_text, prev_glitch_intensity, rng), style=prev_style)
 
 
 
59
  else:
60
  result.append(prev_text, style=prev_style)
61
  result.append("\n")
@@ -86,7 +89,7 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No
86
  live.update(result)
87
 
88
  # Variable typing speed
89
- if line_text[char_idx - 1:char_idx] in " .":
90
  time.sleep(0.025)
91
  else:
92
  time.sleep(0.010)
 
55
  # Render previously completed lines
56
  for prev_text, prev_style in displayed_lines:
57
  if rng.random() < prev_glitch_chance:
58
+ result.append(
59
+ _glitch_text(prev_text, prev_glitch_intensity, rng),
60
+ style=prev_style,
61
+ )
62
  else:
63
  result.append(prev_text, style=prev_style)
64
  result.append("\n")
 
89
  live.update(result)
90
 
91
  # Variable typing speed
92
+ if line_text[char_idx - 1 : char_idx] in " .":
93
  time.sleep(0.025)
94
  else:
95
  time.sleep(0.010)
agent/utils/particle_logo.py CHANGED
@@ -23,7 +23,9 @@ from agent.utils.boot_timing import settle_curve, warm_gold_from_white
23
  class Particle:
24
  __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay")
25
 
26
- def __init__(self, x: float, y: float, target_x: float, target_y: float, delay: float = 0):
 
 
27
  self.x = x
28
  self.y = y
29
  self.target_x = target_x
 
23
  class Particle:
24
  __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay")
25
 
26
+ def __init__(
27
+ self, x: float, y: float, target_x: float, target_y: float, delay: float = 0
28
+ ):
29
  self.x = x
30
  self.y = y
31
  self.target_x = target_x
agent/utils/terminal_display.py CHANGED
@@ -2,6 +2,7 @@
2
  Terminal display utilities — rich-powered CLI formatting.
3
  """
4
 
 
5
  import re
6
 
7
  from rich.console import Console
@@ -57,23 +58,26 @@ def _clip_to_width(s: str, width: int) -> str:
57
  out.append("\033[0m…")
58
  return "".join(out)
59
 
60
- _THEME = Theme({
61
- "tool.name": "bold rgb(255,200,80)",
62
- "tool.args": "dim",
63
- "tool.ok": "dim green",
64
- "tool.fail": "dim red",
65
- "info": "dim",
66
- "muted": "dim",
67
- # Markdown emphasis colors
68
- "markdown.strong": "bold rgb(255,200,80)",
69
- "markdown.emphasis": "italic rgb(180,140,40)",
70
- "markdown.code": "rgb(120,220,255)",
71
- "markdown.code_block": "rgb(120,220,255)",
72
- "markdown.link": "underline rgb(90,180,255)",
73
- "markdown.h1": "bold rgb(255,200,80)",
74
- "markdown.h2": "bold rgb(240,180,95)",
75
- "markdown.h3": "bold rgb(220,165,100)",
76
- })
 
 
 
77
 
78
  _console = Console(theme=_THEME, highlight=False)
79
 
@@ -87,6 +91,7 @@ def get_console() -> Console:
87
 
88
  # ── Banner ─────────────────────────────────────────────────────────────
89
 
 
90
  def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
91
  """Print particle logo then CRT boot sequence with system info."""
92
  from agent.utils.particle_logo import run_particle_logo
@@ -120,12 +125,16 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
120
 
121
  # ── Init progress ──────────────────────────────────────────────────────
122
 
 
123
  def print_init_done(tool_count: int = 0) -> None:
124
  import time
 
125
  f = _console.file
126
  # Overwrite the "Tools: loading..." line with actual count
127
- f.write(f"\033[A\033[A\033[A\033[K") # Move up 3 lines (blank + help + blank) then up to tools line
128
- f.write(f"\033[A\033[K")
 
 
129
  gold = "\033[38;2;180;140;40m"
130
  reset = "\033[0m"
131
  tool_text = f"{_I} Tools: {tool_count} loaded"
@@ -135,16 +144,22 @@ def print_init_done(tool_count: int = 0) -> None:
135
  time.sleep(0.012)
136
  f.write("\n\n")
137
  # Reprint the help line
138
- f.write(f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n")
 
 
139
  # Ready message — minimal padding
140
- f.write(f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n")
 
 
141
  f.flush()
142
 
143
 
144
  # ── Tool calls ─────────────────────────────────────────────────────────
145
 
 
146
  def print_tool_call(tool_name: str, args_preview: str) -> None:
147
  import time
 
148
  f = _console.file
149
  # CRT-style: type out tool name in HF yellow
150
  gold = "\033[38;2;255;200;80m"
@@ -183,6 +198,7 @@ class SubAgentDisplayManager:
183
 
184
  def start(self, agent_id: str, label: str = "research") -> None:
185
  import time
 
186
  self._agents[agent_id] = {
187
  "label": label,
188
  "calls": [],
@@ -234,6 +250,7 @@ class SubAgentDisplayManager:
234
  @staticmethod
235
  def _format_stats(agent: dict) -> str:
236
  import time
 
237
  start = agent["start_time"]
238
  if start is None:
239
  return ""
@@ -276,7 +293,7 @@ class SubAgentDisplayManager:
276
  header += f" \033[2m·\033[0m \033[2m{short}\033[0m"
277
  return [header]
278
  lines = [header]
279
- visible = agent["calls"][-self._MAX_VISIBLE:]
280
  for desc in visible:
281
  lines.append(f"{_I} \033[2m{desc}\033[0m")
282
  return lines
@@ -319,13 +336,14 @@ def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") ->
319
 
320
  # ── Messages ───────────────────────────────────────────────────────────
321
 
 
322
  async def print_markdown(
323
  text: str,
324
  cancel_event: "asyncio.Event | None" = None,
325
  instant: bool = False,
326
  ) -> None:
327
- import asyncio
328
- import io, random
329
  from rich.padding import Padding
330
 
331
  _console.print()
@@ -395,23 +413,35 @@ def print_interrupted() -> None:
395
 
396
 
397
  def print_compacted(old_tokens: int, new_tokens: int) -> None:
398
- _console.print(f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]")
 
 
399
 
400
 
401
  # ── Approval ───────────────────────────────────────────────────────────
402
 
 
403
  def print_approval_header(count: int) -> None:
404
  label = f"Approval required — {count} item{'s' if count != 1 else ''}"
405
  _console.print()
406
- _console.print(f"{_I}", Panel(f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False))
 
 
 
 
 
407
 
408
 
409
  def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None:
410
- _console.print(f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}")
 
 
411
 
412
 
413
  def print_yolo_approve(count: int) -> None:
414
- _console.print(f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)")
 
 
415
 
416
 
417
  # ── Help ───────────────────────────────────────────────────────────────
@@ -437,6 +467,7 @@ def print_help() -> None:
437
 
438
  # ── Plan display ───────────────────────────────────────────────────────
439
 
 
440
  def format_plan_display() -> str:
441
  """Format the current plan for display."""
442
  from agent.tools.plan_tool import get_current_plan
@@ -470,6 +501,7 @@ def print_plan() -> None:
470
 
471
  # ── Formatting for plan_tool output (used by plan_tool handler) ────────
472
 
 
473
  def format_plan_tool_output(todos: list) -> str:
474
  if not todos:
475
  return "Plan is empty."
@@ -492,6 +524,7 @@ def format_plan_tool_output(todos: list) -> str:
492
 
493
  # ── Internal helpers ───────────────────────────────────────────────────
494
 
 
495
  def _truncate(text: str, max_lines: int = 6) -> str:
496
  lines = text.split("\n")
497
  if len(lines) <= max_lines:
 
2
  Terminal display utilities — rich-powered CLI formatting.
3
  """
4
 
5
+ import asyncio
6
  import re
7
 
8
  from rich.console import Console
 
58
  out.append("\033[0m…")
59
  return "".join(out)
60
 
61
+
62
+ _THEME = Theme(
63
+ {
64
+ "tool.name": "bold rgb(255,200,80)",
65
+ "tool.args": "dim",
66
+ "tool.ok": "dim green",
67
+ "tool.fail": "dim red",
68
+ "info": "dim",
69
+ "muted": "dim",
70
+ # Markdown emphasis colors
71
+ "markdown.strong": "bold rgb(255,200,80)",
72
+ "markdown.emphasis": "italic rgb(180,140,40)",
73
+ "markdown.code": "rgb(120,220,255)",
74
+ "markdown.code_block": "rgb(120,220,255)",
75
+ "markdown.link": "underline rgb(90,180,255)",
76
+ "markdown.h1": "bold rgb(255,200,80)",
77
+ "markdown.h2": "bold rgb(240,180,95)",
78
+ "markdown.h3": "bold rgb(220,165,100)",
79
+ }
80
+ )
81
 
82
  _console = Console(theme=_THEME, highlight=False)
83
 
 
91
 
92
  # ── Banner ─────────────────────────────────────────────────────────────
93
 
94
+
95
  def print_banner(model: str | None = None, hf_user: str | None = None) -> None:
96
  """Print particle logo then CRT boot sequence with system info."""
97
  from agent.utils.particle_logo import run_particle_logo
 
125
 
126
  # ── Init progress ──────────────────────────────────────────────────────
127
 
128
+
129
  def print_init_done(tool_count: int = 0) -> None:
130
  import time
131
+
132
  f = _console.file
133
  # Overwrite the "Tools: loading..." line with actual count
134
+ f.write(
135
+ "\033[A\033[A\033[A\033[K"
136
+ ) # Move up 3 lines (blank + help + blank) then up to tools line
137
+ f.write("\033[A\033[K")
138
  gold = "\033[38;2;180;140;40m"
139
  reset = "\033[0m"
140
  tool_text = f"{_I} Tools: {tool_count} loaded"
 
144
  time.sleep(0.012)
145
  f.write("\n\n")
146
  # Reprint the help line
147
+ f.write(
148
+ f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n"
149
+ )
150
  # Ready message — minimal padding
151
+ f.write(
152
+ f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n"
153
+ )
154
  f.flush()
155
 
156
 
157
  # ── Tool calls ─────────────────────────────────────────────────────────
158
 
159
+
160
  def print_tool_call(tool_name: str, args_preview: str) -> None:
161
  import time
162
+
163
  f = _console.file
164
  # CRT-style: type out tool name in HF yellow
165
  gold = "\033[38;2;255;200;80m"
 
198
 
199
  def start(self, agent_id: str, label: str = "research") -> None:
200
  import time
201
+
202
  self._agents[agent_id] = {
203
  "label": label,
204
  "calls": [],
 
250
  @staticmethod
251
  def _format_stats(agent: dict) -> str:
252
  import time
253
+
254
  start = agent["start_time"]
255
  if start is None:
256
  return ""
 
293
  header += f" \033[2m·\033[0m \033[2m{short}\033[0m"
294
  return [header]
295
  lines = [header]
296
+ visible = agent["calls"][-self._MAX_VISIBLE :]
297
  for desc in visible:
298
  lines.append(f"{_I} \033[2m{desc}\033[0m")
299
  return lines
 
336
 
337
  # ── Messages ───────────────────────────────────────────────────────────
338
 
339
+
340
  async def print_markdown(
341
  text: str,
342
  cancel_event: "asyncio.Event | None" = None,
343
  instant: bool = False,
344
  ) -> None:
345
+ import io
346
+ import random
347
  from rich.padding import Padding
348
 
349
  _console.print()
 
413
 
414
 
415
  def print_compacted(old_tokens: int, new_tokens: int) -> None:
416
+ _console.print(
417
+ f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]"
418
+ )
419
 
420
 
421
  # ── Approval ───────────────────────────────────────────────────────────
422
 
423
+
424
  def print_approval_header(count: int) -> None:
425
  label = f"Approval required — {count} item{'s' if count != 1 else ''}"
426
  _console.print()
427
+ _console.print(
428
+ f"{_I}",
429
+ Panel(
430
+ f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False
431
+ ),
432
+ )
433
 
434
 
435
  def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None:
436
+ _console.print(
437
+ f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}"
438
+ )
439
 
440
 
441
  def print_yolo_approve(count: int) -> None:
442
+ _console.print(
443
+ f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)"
444
+ )
445
 
446
 
447
  # ── Help ───────────────────────────────────────────────────────────────
 
467
 
468
  # ── Plan display ───────────────────────────────────────────────────────
469
 
470
+
471
  def format_plan_display() -> str:
472
  """Format the current plan for display."""
473
  from agent.tools.plan_tool import get_current_plan
 
501
 
502
  # ── Formatting for plan_tool output (used by plan_tool handler) ────────
503
 
504
+
505
  def format_plan_tool_output(todos: list) -> str:
506
  if not todos:
507
  return "Plan is empty."
 
524
 
525
  # ── Internal helpers ───────────────────────────────────────────────────
526
 
527
+
528
  def _truncate(text: str, max_lines: int = 6) -> str:
529
  lines = text.split("\n")
530
  if len(lines) <= max_lines:
backend/dependencies.py CHANGED
@@ -102,7 +102,9 @@ async def _fetch_user_plan(token: str) -> str:
102
  _WHOAMI_SHAPE_LOGGED = True
103
  logger.debug(
104
  "whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)",
105
- sorted(whoami.keys()) if isinstance(whoami, dict) else type(whoami).__name__,
 
 
106
  whoami.get("plan") if isinstance(whoami, dict) else None,
107
  whoami.get("type") if isinstance(whoami, dict) else None,
108
  whoami.get("isPro") if isinstance(whoami, dict) else None,
 
102
  _WHOAMI_SHAPE_LOGGED = True
103
  logger.debug(
104
  "whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)",
105
+ sorted(whoami.keys())
106
+ if isinstance(whoami, dict)
107
+ else type(whoami).__name__,
108
  whoami.get("plan") if isinstance(whoami, dict) else None,
109
  whoami.get("type") if isinstance(whoami, dict) else None,
110
  whoami.get("isPro") if isinstance(whoami, dict) else None,
backend/kpis_scheduler.py CHANGED
@@ -58,7 +58,8 @@ def _resolve_token() -> Optional[str]:
58
  def _load_build_kpis():
59
  """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
60
  spec = importlib.util.spec_from_file_location(
61
- "build_kpis", _PROJECT_ROOT / "scripts" / "build_kpis.py",
 
62
  )
63
  mod = importlib.util.module_from_spec(spec)
64
  assert spec.loader is not None
@@ -75,6 +76,7 @@ async def _run_hour(hour_dt: datetime) -> None:
75
  try:
76
  mod = _load_build_kpis()
77
  from huggingface_hub import HfApi
 
78
  api = HfApi()
79
  source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
80
  target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
@@ -118,7 +120,7 @@ def start(backfill_hours: int = 6) -> None:
118
  CronTrigger(minute=5),
119
  id="kpis_hourly",
120
  misfire_grace_time=600, # tolerate a 10-min misfire window
121
- coalesce=True, # collapse multiple missed fires into one
122
  max_instances=1,
123
  replace_existing=True,
124
  )
 
58
  def _load_build_kpis():
59
  """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path."""
60
  spec = importlib.util.spec_from_file_location(
61
+ "build_kpis",
62
+ _PROJECT_ROOT / "scripts" / "build_kpis.py",
63
  )
64
  mod = importlib.util.module_from_spec(spec)
65
  assert spec.loader is not None
 
76
  try:
77
  mod = _load_build_kpis()
78
  from huggingface_hub import HfApi
79
+
80
  api = HfApi()
81
  source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions")
82
  target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis")
 
120
  CronTrigger(minute=5),
121
  id="kpis_hourly",
122
  misfire_grace_time=600, # tolerate a 10-min misfire window
123
+ coalesce=True, # collapse multiple missed fires into one
124
  max_instances=1,
125
  replace_existing=True,
126
  )
backend/main.py CHANGED
@@ -6,17 +6,17 @@ from contextlib import asynccontextmanager
6
  from pathlib import Path
7
 
8
  from dotenv import load_dotenv
 
 
 
9
 
10
  # Load .env before importing routes/session_manager so persistence and quota
11
  # modules see local Mongo settings during startup.
12
  load_dotenv(Path(__file__).parent.parent / ".env")
13
 
14
- from fastapi import FastAPI
15
- from fastapi.middleware.cors import CORSMiddleware
16
- from fastapi.staticfiles import StaticFiles
17
- from routes.agent import router as agent_router
18
- from routes.auth import router as auth_router
19
- from session_manager import session_manager
20
 
21
  # Configure logging
22
  logging.basicConfig(
@@ -35,6 +35,7 @@ async def lifespan(app: FastAPI):
35
  # rollup lives next to the data and reuses the Space's HF token.
36
  try:
37
  import kpis_scheduler
 
38
  kpis_scheduler.start()
39
  except Exception as e:
40
  logger.warning("KPI scheduler failed to start: %s", e)
@@ -43,6 +44,7 @@ async def lifespan(app: FastAPI):
43
  logger.info("Shutting down HF Agent backend...")
44
  try:
45
  import kpis_scheduler
 
46
  await kpis_scheduler.shutdown()
47
  except Exception as e:
48
  logger.warning("KPI scheduler shutdown failed: %s", e)
 
6
  from pathlib import Path
7
 
8
  from dotenv import load_dotenv
9
+ from fastapi import FastAPI
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.staticfiles import StaticFiles
12
 
13
  # Load .env before importing routes/session_manager so persistence and quota
14
  # modules see local Mongo settings during startup.
15
  load_dotenv(Path(__file__).parent.parent / ".env")
16
 
17
+ from routes.agent import router as agent_router # noqa: E402
18
+ from routes.auth import router as auth_router # noqa: E402
19
+ from session_manager import session_manager # noqa: E402
 
 
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
 
35
  # rollup lives next to the data and reuses the Space's HF token.
36
  try:
37
  import kpis_scheduler
38
+
39
  kpis_scheduler.start()
40
  except Exception as e:
41
  logger.warning("KPI scheduler failed to start: %s", e)
 
44
  logger.info("Shutting down HF Agent backend...")
45
  try:
46
  import kpis_scheduler
47
+
48
  await kpis_scheduler.shutdown()
49
  except Exception as e:
50
  logger.warning("KPI scheduler shutdown failed: %s", e)
backend/models.py CHANGED
@@ -131,4 +131,6 @@ class LLMHealthResponse(BaseModel):
131
  status: str # "ok" | "error"
132
  model: str
133
  error: str | None = None
134
- error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown"
 
 
 
131
  status: str # "ok" | "error"
132
  model: str
133
  error: str | None = None
134
+ error_type: str | None = (
135
+ None # "auth" | "credits" | "rate_limit" | "network" | "unknown"
136
+ )
backend/routes/agent.py CHANGED
@@ -7,7 +7,6 @@ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically.
7
  import asyncio
8
  import json
9
  import logging
10
- import os
11
  from typing import Any
12
 
13
  from dependencies import (
@@ -34,7 +33,12 @@ from models import (
34
  SubmitRequest,
35
  TruncateRequest,
36
  )
37
- from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager
 
 
 
 
 
38
 
39
  import user_quotas
40
 
@@ -136,7 +140,7 @@ async def _require_hf_for_gated_model(request: Request, model_id: str) -> None:
136
  """403 if a non-``huggingface``-org user tries to select a gated model.
137
 
138
  Gated models are deployed paid endpoints backed by service-owned
139
- credentials. The gate only fires for deployed paid models so non-HF users
140
  can still freely switch between the free models.
141
  """
142
  if not _is_gated_model(model_id):
@@ -226,7 +230,11 @@ async def _check_session_access(
226
  preload_sandbox: bool = True,
227
  ) -> AgentSession:
228
  """Verify and lazily load the user's session. Raises 403 or 404."""
229
- hf_token = resolve_hf_request_token(request) if request is not None else _user_hf_token(user)
 
 
 
 
230
  agent_session = await session_manager.ensure_session_loaded(
231
  session_id,
232
  user["user_id"],
@@ -236,7 +244,10 @@ async def _check_session_access(
236
  )
237
  if not agent_session:
238
  raise HTTPException(status_code=404, detail="Session not found")
239
- if user["user_id"] != "dev" and agent_session.user_id not in {user["user_id"], "dev"}:
 
 
 
240
  raise HTTPException(status_code=403, detail="Access denied to this session")
241
  return agent_session
242
 
@@ -362,7 +373,9 @@ async def generate_title(
362
  await _check_session_access(request.session_id, user)
363
  await session_manager.update_session_title(request.session_id, title)
364
  except Exception:
365
- logger.debug("Skipping title persistence for missing session %s", request.session_id)
 
 
366
  return {"title": title}
367
  except Exception as e:
368
  logger.warning(f"Title generation failed: {e}")
@@ -372,7 +385,10 @@ async def generate_title(
372
  await _check_session_access(request.session_id, user)
373
  await session_manager.update_session_title(request.session_id, title)
374
  except Exception:
375
- logger.debug("Skipping fallback title persistence for missing session %s", request.session_id)
 
 
 
376
  return {"title": title}
377
 
378
 
@@ -586,7 +602,9 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
586
 
587
 
588
  @router.get("/user/jobs-access")
589
- async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
 
 
590
  """Return the namespaces the current token can run HF Jobs under.
591
 
592
  Credits are enforced by the HF API at job-creation time, not here —
@@ -652,7 +670,7 @@ async def submit_approval(
652
  request: ApprovalRequest, user: dict = Depends(get_current_user)
653
  ) -> dict:
654
  """Submit tool approvals to a session. Only accessible by the session owner."""
655
- agent_session = await _check_session_access(request.session_id, user)
656
  approvals = [
657
  {
658
  "tool_call_id": a.tool_call_id,
@@ -719,7 +737,9 @@ async def chat_sse(
719
  success = await session_manager.submit_user_input(session_id, text)
720
  else:
721
  broadcaster.unsubscribe(sub_id)
722
- raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'")
 
 
723
 
724
  if not success:
725
  broadcaster.unsubscribe(sub_id)
@@ -744,6 +764,7 @@ async def record_pro_click(
744
  agent_session = await _check_session_access(session_id, user)
745
 
746
  from agent.core import telemetry
 
747
  await telemetry.record_pro_cta_click(
748
  agent_session.session,
749
  source=str(body.get("source") or "unknown"),
@@ -759,12 +780,20 @@ async def record_pro_click(
759
  # ---------------------------------------------------------------------------
760
  # Shared SSE helpers
761
  # ---------------------------------------------------------------------------
762
- _TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"}
 
 
 
 
 
 
763
  _SSE_KEEPALIVE_SECONDS = 15
764
 
765
 
766
  def _last_event_seq(request: Request) -> int:
767
- raw = request.headers.get("last-event-id") or request.query_params.get("after") or "0"
 
 
768
  try:
769
  return max(0, int(raw))
770
  except (TypeError, ValueError):
@@ -853,7 +882,9 @@ async def subscribe_events(
853
  raise HTTPException(status_code=404, detail="Session not found or inactive")
854
 
855
  after_seq = _last_event_seq(request)
856
- replay_events = await session_manager._store().load_events_after(session_id, after_seq)
 
 
857
  broadcaster = agent_session.broadcaster
858
  sub_id, event_queue = broadcaster.subscribe()
859
  return _sse_response(
@@ -885,7 +916,10 @@ async def get_session_messages(
885
  agent_session = await _check_session_access(session_id, user)
886
  if not agent_session or not agent_session.is_active:
887
  raise HTTPException(status_code=404, detail="Session not found or inactive")
888
- return [msg.model_dump(mode="json") for msg in agent_session.session.context_manager.items]
 
 
 
889
 
890
 
891
  @router.post("/undo/{session_id}")
@@ -906,7 +940,10 @@ async def truncate_session(
906
  await _check_session_access(session_id, user)
907
  success = await session_manager.truncate(session_id, body.user_message_index)
908
  if not success:
909
- raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range")
 
 
 
910
  return {"status": "truncated", "session_id": session_id}
911
 
912
 
@@ -933,6 +970,7 @@ async def shutdown_session(
933
  raise HTTPException(status_code=404, detail="Session not found or inactive")
934
  return {"status": "shutdown_requested", "session_id": session_id}
935
 
 
936
  @router.post("/feedback/{session_id}")
937
  async def submit_feedback(
938
  session_id: str,
@@ -952,6 +990,7 @@ async def submit_feedback(
952
  raise HTTPException(status_code=400, detail="invalid rating")
953
 
954
  from agent.core import telemetry
 
955
  await telemetry.record_feedback(
956
  agent_session.session,
957
  rating=rating,
 
7
  import asyncio
8
  import json
9
  import logging
 
10
  from typing import Any
11
 
12
  from dependencies import (
 
33
  SubmitRequest,
34
  TruncateRequest,
35
  )
36
+ from session_manager import (
37
+ MAX_SESSIONS,
38
+ AgentSession,
39
+ SessionCapacityError,
40
+ session_manager,
41
+ )
42
 
43
  import user_quotas
44
 
 
140
  """403 if a non-``huggingface``-org user tries to select a gated model.
141
 
142
  Gated models are deployed paid endpoints backed by service-owned
143
+ credentials. The gate only fires for deployed paid models so non-HF users
144
  can still freely switch between the free models.
145
  """
146
  if not _is_gated_model(model_id):
 
230
  preload_sandbox: bool = True,
231
  ) -> AgentSession:
232
  """Verify and lazily load the user's session. Raises 403 or 404."""
233
+ hf_token = (
234
+ resolve_hf_request_token(request)
235
+ if request is not None
236
+ else _user_hf_token(user)
237
+ )
238
  agent_session = await session_manager.ensure_session_loaded(
239
  session_id,
240
  user["user_id"],
 
244
  )
245
  if not agent_session:
246
  raise HTTPException(status_code=404, detail="Session not found")
247
+ if user["user_id"] != "dev" and agent_session.user_id not in {
248
+ user["user_id"],
249
+ "dev",
250
+ }:
251
  raise HTTPException(status_code=403, detail="Access denied to this session")
252
  return agent_session
253
 
 
373
  await _check_session_access(request.session_id, user)
374
  await session_manager.update_session_title(request.session_id, title)
375
  except Exception:
376
+ logger.debug(
377
+ "Skipping title persistence for missing session %s", request.session_id
378
+ )
379
  return {"title": title}
380
  except Exception as e:
381
  logger.warning(f"Title generation failed: {e}")
 
385
  await _check_session_access(request.session_id, user)
386
  await session_manager.update_session_title(request.session_id, title)
387
  except Exception:
388
+ logger.debug(
389
+ "Skipping fallback title persistence for missing session %s",
390
+ request.session_id,
391
+ )
392
  return {"title": title}
393
 
394
 
 
602
 
603
 
604
  @router.get("/user/jobs-access")
605
+ async def get_jobs_access_info(
606
+ request: Request, user: dict = Depends(get_current_user)
607
+ ) -> dict:
608
  """Return the namespaces the current token can run HF Jobs under.
609
 
610
  Credits are enforced by the HF API at job-creation time, not here —
 
670
  request: ApprovalRequest, user: dict = Depends(get_current_user)
671
  ) -> dict:
672
  """Submit tool approvals to a session. Only accessible by the session owner."""
673
+ await _check_session_access(request.session_id, user)
674
  approvals = [
675
  {
676
  "tool_call_id": a.tool_call_id,
 
737
  success = await session_manager.submit_user_input(session_id, text)
738
  else:
739
  broadcaster.unsubscribe(sub_id)
740
+ raise HTTPException(
741
+ status_code=400, detail="Must provide 'text' or 'approvals'"
742
+ )
743
 
744
  if not success:
745
  broadcaster.unsubscribe(sub_id)
 
764
  agent_session = await _check_session_access(session_id, user)
765
 
766
  from agent.core import telemetry
767
+
768
  await telemetry.record_pro_cta_click(
769
  agent_session.session,
770
  source=str(body.get("source") or "unknown"),
 
780
  # ---------------------------------------------------------------------------
781
  # Shared SSE helpers
782
  # ---------------------------------------------------------------------------
783
+ _TERMINAL_EVENTS = {
784
+ "turn_complete",
785
+ "approval_required",
786
+ "error",
787
+ "interrupted",
788
+ "shutdown",
789
+ }
790
  _SSE_KEEPALIVE_SECONDS = 15
791
 
792
 
793
  def _last_event_seq(request: Request) -> int:
794
+ raw = (
795
+ request.headers.get("last-event-id") or request.query_params.get("after") or "0"
796
+ )
797
  try:
798
  return max(0, int(raw))
799
  except (TypeError, ValueError):
 
882
  raise HTTPException(status_code=404, detail="Session not found or inactive")
883
 
884
  after_seq = _last_event_seq(request)
885
+ replay_events = await session_manager._store().load_events_after(
886
+ session_id, after_seq
887
+ )
888
  broadcaster = agent_session.broadcaster
889
  sub_id, event_queue = broadcaster.subscribe()
890
  return _sse_response(
 
916
  agent_session = await _check_session_access(session_id, user)
917
  if not agent_session or not agent_session.is_active:
918
  raise HTTPException(status_code=404, detail="Session not found or inactive")
919
+ return [
920
+ msg.model_dump(mode="json")
921
+ for msg in agent_session.session.context_manager.items
922
+ ]
923
 
924
 
925
  @router.post("/undo/{session_id}")
 
940
  await _check_session_access(session_id, user)
941
  success = await session_manager.truncate(session_id, body.user_message_index)
942
  if not success:
943
+ raise HTTPException(
944
+ status_code=404,
945
+ detail="Session not found, inactive, or message index out of range",
946
+ )
947
  return {"status": "truncated", "session_id": session_id}
948
 
949
 
 
970
  raise HTTPException(status_code=404, detail="Session not found or inactive")
971
  return {"status": "shutdown_requested", "session_id": session_id}
972
 
973
+
974
  @router.post("/feedback/{session_id}")
975
  async def submit_feedback(
976
  session_id: str,
 
990
  raise HTTPException(status_code=400, detail="invalid rating")
991
 
992
  from agent.core import telemetry
993
+
994
  await telemetry.record_feedback(
995
  agent_session.session,
996
  rating=rating,
backend/routes/auth.py CHANGED
@@ -168,4 +168,3 @@ async def get_me(user: dict = Depends(get_current_user)) -> dict:
168
  Uses the shared auth dependency which handles cookie + Bearer token.
169
  """
170
  return {key: value for key, value in user.items() if not key.startswith("_")}
171
-
 
168
  Uses the shared auth dependency which handles cookie + Bearer token.
169
  """
170
  return {key: value for key, value in user.items() if not key.startswith("_")}
 
backend/session_manager.py CHANGED
@@ -12,10 +12,11 @@ from typing import Any, Optional
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
15
- from agent.messaging.gateway import NotificationGateway
16
  from agent.core.session import Event, OpType, Session
17
  from agent.core.session_persistence import get_session_store
18
  from agent.core.tools import ToolRouter
 
19
 
20
  # Get project root (parent of backend directory)
21
  PROJECT_ROOT = Path(__file__).parent.parent
@@ -70,7 +71,11 @@ class EventBroadcaster:
70
  while True:
71
  try:
72
  event: Event = await self._source.get()
73
- msg = {"event_type": event.event_type, "data": event.data, "seq": event.seq}
 
 
 
 
74
  for q in self._subscribers.values():
75
  await q.put(msg)
76
  except asyncio.CancelledError:
@@ -131,6 +136,7 @@ class SessionManager:
131
  self.sessions: dict[str, AgentSession] = {}
132
  self._lock = asyncio.Lock()
133
  self.persistence_store = None
 
134
 
135
  async def start(self) -> None:
136
  """Start shared background resources."""
@@ -153,9 +159,7 @@ class SessionManager:
153
  def _count_user_sessions(self, user_id: str) -> int:
154
  """Count active sessions owned by a specific user."""
155
  return sum(
156
- 1
157
- for s in self.sessions.values()
158
- if s.user_id == user_id and s.is_active
159
  )
160
 
161
  def _create_session_sync(
@@ -196,10 +200,7 @@ class SessionManager:
196
  return tool_router, session
197
 
198
  def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
199
- return [
200
- msg.model_dump(mode="json")
201
- for msg in session.context_manager.items
202
- ]
203
 
204
  def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
205
  pending = session.pending_approval or {}
@@ -307,7 +308,9 @@ class SessionManager:
307
  if hasattr(session, "auto_approval_policy_summary"):
308
  return session.auto_approval_policy_summary()
309
  cap = getattr(session, "auto_approval_cost_cap_usd", None)
310
- estimated = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0)
 
 
311
  remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
312
  return {
313
  "enabled": bool(getattr(session, "auto_approval_enabled", False)),
@@ -410,6 +413,28 @@ class SessionManager:
410
  session.sandbox_preload_cancel_event = None
411
  self._start_cpu_sandbox_preload(agent_session)
412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
414
  try:
415
  await self._store().update_session_fields(
@@ -514,7 +539,9 @@ class SessionManager:
514
  runtime_state=runtime_state or self._runtime_state(agent_session),
515
  status=status,
516
  turn_count=agent_session.session.turn_count,
517
- pending_approval=self._serialize_pending_approval(agent_session.session),
 
 
518
  claude_counted=agent_session.claude_counted,
519
  created_at=agent_session.created_at,
520
  notification_destinations=list(
@@ -564,6 +591,7 @@ class SessionManager:
564
  existing,
565
  preload_sandbox=preload_sandbox,
566
  )
 
567
  return existing
568
  return None
569
 
@@ -585,6 +613,7 @@ class SessionManager:
585
  existing,
586
  preload_sandbox=preload_sandbox,
587
  )
 
588
  return existing
589
  return None
590
 
@@ -626,7 +655,10 @@ class SessionManager:
626
  if restored_messages:
627
  # Keep the freshly-rendered system prompt, then attach the durable
628
  # non-system context so tools/date/user context stay current.
629
- session.context_manager.items = [session.context_manager.items[0], *restored_messages]
 
 
 
630
 
631
  self._restore_pending_approval(session, meta.get("pending_approval") or [])
632
  session.turn_count = int(meta.get("turn_count") or 0)
@@ -668,7 +700,9 @@ class SessionManager:
668
  hf_token=hf_token,
669
  hf_username=hf_username,
670
  )
 
671
  return started
 
672
  if preload_sandbox:
673
  self._start_cpu_sandbox_preload(agent_session)
674
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
@@ -751,6 +785,7 @@ class SessionManager:
751
  event_queue=event_queue,
752
  tool_router=tool_router,
753
  )
 
754
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
755
  self._start_cpu_sandbox_preload(agent_session)
756
 
@@ -760,7 +795,9 @@ class SessionManager:
760
  logger.info(f"Created session {session_id} for user {user_id}")
761
  return session_id
762
 
763
- async def _track_pro_status(self, agent_session: AgentSession, *, is_pro: bool) -> None:
 
 
764
  """Update Mongo per-user Pro state and emit a one-shot conversion
765
  event if the store reports a free→Pro transition. Best-effort: any
766
  Mongo failure is swallowed so we never fail session creation on
@@ -777,6 +814,7 @@ class SessionManager:
777
  return
778
  try:
779
  from agent.core import telemetry
 
780
  await telemetry.record_pro_conversion(
781
  agent_session.session,
782
  first_seen_at=result.get("first_seen_at"),
@@ -933,7 +971,9 @@ class SessionManager:
933
  )
934
  agent_session.is_processing = True
935
  try:
936
- should_continue = await process_submission(session, submission)
 
 
937
  finally:
938
  agent_session.is_processing = False
939
  await self.persist_session_snapshot(agent_session)
@@ -964,7 +1004,9 @@ class SessionManager:
964
  # Idempotent via session_id key; detached subprocess.
965
  if session.config.save_sessions:
966
  try:
967
- session.save_and_upload_detached(session.config.session_dataset_repo)
 
 
968
  except Exception as e:
969
  logger.warning(f"Final-flush failed for {session_id}: {e}")
970
 
@@ -1025,7 +1067,9 @@ class SessionManager:
1025
  agent_session = self.sessions.get(session_id)
1026
  if not agent_session or not agent_session.is_active:
1027
  return False
1028
- success = agent_session.session.context_manager.truncate_to_user_message(user_message_index)
 
 
1029
  if success:
1030
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
1031
  return success
@@ -1118,9 +1162,7 @@ class SessionManager:
1118
  session = agent_session.session
1119
  if enabled:
1120
  if not cap_provided and cost_cap_usd is None:
1121
- cost_cap_usd = getattr(
1122
- session, "auto_approval_cost_cap_usd", None
1123
- )
1124
  if cost_cap_usd is None:
1125
  cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
1126
  elif cost_cap_usd is None:
@@ -1203,9 +1245,7 @@ class SessionManager:
1203
  if destination is None:
1204
  raise ValueError(f"Unknown destination '{name}'")
1205
  if not destination.allow_auto_events:
1206
- raise ValueError(
1207
- f"Destination '{name}' is not enabled for auto events"
1208
- )
1209
  if name not in seen:
1210
  normalized.append(name)
1211
  seen.add(name)
@@ -1248,7 +1288,10 @@ class SessionManager:
1248
  "pending_approval": pending or None,
1249
  "model": row.get("model"),
1250
  "title": row.get("title"),
1251
- "notification_destinations": row.get("notification_destinations") or [],
 
 
 
1252
  "auto_approval": {
1253
  "enabled": bool(row.get("auto_approval_enabled", False)),
1254
  "cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
@@ -1261,8 +1304,13 @@ class SessionManager:
1261
  else round(
1262
  max(
1263
  0.0,
1264
- float(row.get("auto_approval_cost_cap_usd") or 0.0)
1265
- - float(row.get("auto_approval_estimated_spend_usd") or 0.0),
 
 
 
 
 
1266
  ),
1267
  4,
1268
  )
 
12
 
13
  from agent.config import load_config
14
  from agent.core.agent_loop import process_submission
15
+ from agent.core.hub_artifacts import start_session_artifact_collection_task
16
  from agent.core.session import Event, OpType, Session
17
  from agent.core.session_persistence import get_session_store
18
  from agent.core.tools import ToolRouter
19
+ from agent.messaging.gateway import NotificationGateway
20
 
21
  # Get project root (parent of backend directory)
22
  PROJECT_ROOT = Path(__file__).parent.parent
 
71
  while True:
72
  try:
73
  event: Event = await self._source.get()
74
+ msg = {
75
+ "event_type": event.event_type,
76
+ "data": event.data,
77
+ "seq": event.seq,
78
+ }
79
  for q in self._subscribers.values():
80
  await q.put(msg)
81
  except asyncio.CancelledError:
 
136
  self.sessions: dict[str, AgentSession] = {}
137
  self._lock = asyncio.Lock()
138
  self.persistence_store = None
139
+ self.enable_hub_artifact_collections = True
140
 
141
  async def start(self) -> None:
142
  """Start shared background resources."""
 
159
  def _count_user_sessions(self, user_id: str) -> int:
160
  """Count active sessions owned by a specific user."""
161
  return sum(
162
+ 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active
 
 
163
  )
164
 
165
  def _create_session_sync(
 
200
  return tool_router, session
201
 
202
  def _serialize_messages(self, session: Session) -> list[dict[str, Any]]:
203
+ return [msg.model_dump(mode="json") for msg in session.context_manager.items]
 
 
 
204
 
205
  def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]:
206
  pending = session.pending_approval or {}
 
308
  if hasattr(session, "auto_approval_policy_summary"):
309
  return session.auto_approval_policy_summary()
310
  cap = getattr(session, "auto_approval_cost_cap_usd", None)
311
+ estimated = float(
312
+ getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0
313
+ )
314
  remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4)
315
  return {
316
  "enabled": bool(getattr(session, "auto_approval_enabled", False)),
 
413
  session.sandbox_preload_cancel_event = None
414
  self._start_cpu_sandbox_preload(agent_session)
415
 
416
+ def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None:
417
+ """Kick off best-effort Hub collection creation for the session."""
418
+ if not getattr(self, "enable_hub_artifact_collections", False):
419
+ return
420
+ session = agent_session.session
421
+ if not getattr(session, "session_id", None):
422
+ try:
423
+ session.session_id = agent_session.session_id
424
+ except Exception:
425
+ logger.debug("Could not attach session id for Hub artifact collection")
426
+ token = agent_session.hf_token or getattr(session, "hf_token", None)
427
+ if not token:
428
+ return
429
+ try:
430
+ start_session_artifact_collection_task(session, token=token)
431
+ except Exception as e:
432
+ logger.debug(
433
+ "Failed to schedule Hub artifact collection for %s: %s",
434
+ agent_session.session_id,
435
+ e,
436
+ )
437
+
438
  async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None:
439
  try:
440
  await self._store().update_session_fields(
 
539
  runtime_state=runtime_state or self._runtime_state(agent_session),
540
  status=status,
541
  turn_count=agent_session.session.turn_count,
542
+ pending_approval=self._serialize_pending_approval(
543
+ agent_session.session
544
+ ),
545
  claude_counted=agent_session.claude_counted,
546
  created_at=agent_session.created_at,
547
  notification_destinations=list(
 
591
  existing,
592
  preload_sandbox=preload_sandbox,
593
  )
594
+ self._start_hub_artifact_collection(existing)
595
  return existing
596
  return None
597
 
 
613
  existing,
614
  preload_sandbox=preload_sandbox,
615
  )
616
+ self._start_hub_artifact_collection(existing)
617
  return existing
618
  return None
619
 
 
655
  if restored_messages:
656
  # Keep the freshly-rendered system prompt, then attach the durable
657
  # non-system context so tools/date/user context stay current.
658
+ session.context_manager.items = [
659
+ session.context_manager.items[0],
660
+ *restored_messages,
661
+ ]
662
 
663
  self._restore_pending_approval(session, meta.get("pending_approval") or [])
664
  session.turn_count = int(meta.get("turn_count") or 0)
 
700
  hf_token=hf_token,
701
  hf_username=hf_username,
702
  )
703
+ self._start_hub_artifact_collection(started)
704
  return started
705
+ self._start_hub_artifact_collection(agent_session)
706
  if preload_sandbox:
707
  self._start_cpu_sandbox_preload(agent_session)
708
  logger.info("Restored session %s for user %s", session_id, owner or user_id)
 
785
  event_queue=event_queue,
786
  tool_router=tool_router,
787
  )
788
+ self._start_hub_artifact_collection(agent_session)
789
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
790
  self._start_cpu_sandbox_preload(agent_session)
791
 
 
795
  logger.info(f"Created session {session_id} for user {user_id}")
796
  return session_id
797
 
798
+ async def _track_pro_status(
799
+ self, agent_session: AgentSession, *, is_pro: bool
800
+ ) -> None:
801
  """Update Mongo per-user Pro state and emit a one-shot conversion
802
  event if the store reports a free→Pro transition. Best-effort: any
803
  Mongo failure is swallowed so we never fail session creation on
 
814
  return
815
  try:
816
  from agent.core import telemetry
817
+
818
  await telemetry.record_pro_conversion(
819
  agent_session.session,
820
  first_seen_at=result.get("first_seen_at"),
 
971
  )
972
  agent_session.is_processing = True
973
  try:
974
+ should_continue = await process_submission(
975
+ session, submission
976
+ )
977
  finally:
978
  agent_session.is_processing = False
979
  await self.persist_session_snapshot(agent_session)
 
1004
  # Idempotent via session_id key; detached subprocess.
1005
  if session.config.save_sessions:
1006
  try:
1007
+ session.save_and_upload_detached(
1008
+ session.config.session_dataset_repo
1009
+ )
1010
  except Exception as e:
1011
  logger.warning(f"Final-flush failed for {session_id}: {e}")
1012
 
 
1067
  agent_session = self.sessions.get(session_id)
1068
  if not agent_session or not agent_session.is_active:
1069
  return False
1070
+ success = agent_session.session.context_manager.truncate_to_user_message(
1071
+ user_message_index
1072
+ )
1073
  if success:
1074
  await self.persist_session_snapshot(agent_session, runtime_state="idle")
1075
  return success
 
1162
  session = agent_session.session
1163
  if enabled:
1164
  if not cap_provided and cost_cap_usd is None:
1165
+ cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None)
 
 
1166
  if cost_cap_usd is None:
1167
  cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD
1168
  elif cost_cap_usd is None:
 
1245
  if destination is None:
1246
  raise ValueError(f"Unknown destination '{name}'")
1247
  if not destination.allow_auto_events:
1248
+ raise ValueError(f"Destination '{name}' is not enabled for auto events")
 
 
1249
  if name not in seen:
1250
  normalized.append(name)
1251
  seen.add(name)
 
1288
  "pending_approval": pending or None,
1289
  "model": row.get("model"),
1290
  "title": row.get("title"),
1291
+ "notification_destinations": row.get(
1292
+ "notification_destinations"
1293
+ )
1294
+ or [],
1295
  "auto_approval": {
1296
  "enabled": bool(row.get("auto_approval_enabled", False)),
1297
  "cost_cap_usd": row.get("auto_approval_cost_cap_usd"),
 
1304
  else round(
1305
  max(
1306
  0.0,
1307
+ float(
1308
+ row.get("auto_approval_cost_cap_usd") or 0.0
1309
+ )
1310
+ - float(
1311
+ row.get("auto_approval_estimated_spend_usd")
1312
+ or 0.0
1313
+ ),
1314
  ),
1315
  4,
1316
  )
backend/user_quotas.py CHANGED
@@ -20,7 +20,11 @@ import asyncio
20
  import os
21
  from datetime import UTC, datetime
22
 
23
- from agent.core.session_persistence import NoopSessionStore, get_session_store, _reset_store_for_tests
 
 
 
 
24
 
25
  CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
26
  CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
 
20
  import os
21
  from datetime import UTC, datetime
22
 
23
+ from agent.core.session_persistence import (
24
+ NoopSessionStore,
25
+ get_session_store,
26
+ _reset_store_for_tests,
27
+ )
28
 
29
  CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1"))
30
  CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20"))
pyproject.toml CHANGED
@@ -44,6 +44,7 @@ eval = [
44
  dev = [
45
  "pytest>=9.0.2",
46
  "pytest-asyncio>=1.2.0",
 
47
  ]
48
 
49
  # All dependencies (eval + dev)
 
44
  dev = [
45
  "pytest>=9.0.2",
46
  "pytest-asyncio>=1.2.0",
47
+ "ruff>=0.15.12",
48
  ]
49
 
50
  # All dependencies (eval + dev)
scripts/build_kpis.py CHANGED
@@ -99,7 +99,6 @@ import sys
99
  import tempfile
100
  from collections import defaultdict
101
  from datetime import date, datetime, timedelta, timezone
102
- from pathlib import Path
103
  from typing import Any, Iterable
104
 
105
  logger = logging.getLogger("build_kpis")
@@ -107,13 +106,25 @@ logger = logging.getLogger("build_kpis")
107
  # Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used
108
  # only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count.
109
  _FLAVOR_GPU_COUNT = {
110
- "cpu-basic": 0, "cpu-upgrade": 0,
111
- "t4-small": 1, "t4-medium": 1,
112
- "l4x1": 1, "l4x4": 4,
113
- "l40sx1": 1, "l40sx4": 4, "l40sx8": 8,
114
- "a10g-small": 1, "a10g-large": 1, "a10g-largex2": 2, "a10g-largex4": 4,
115
- "a100-large": 1, "a100x2": 2, "a100x4": 4, "a100x8": 8,
116
- "h100": 1, "h100x8": 8,
 
 
 
 
 
 
 
 
 
 
 
 
117
  }
118
 
119
 
@@ -160,9 +171,13 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None:
160
  directory is near-free.
161
  """
162
  from huggingface_hub import hf_hub_download
 
163
  try:
164
  local = hf_hub_download(
165
- repo_id=repo_id, filename=path, repo_type="dataset", token=token,
 
 
 
166
  )
167
  except Exception as e:
168
  logger.warning("hf_hub_download(%s) failed: %s", path, e)
@@ -188,7 +203,9 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None:
188
 
189
 
190
  def _filter_session_to_window(
191
- session: dict, start: datetime, end: datetime,
 
 
192
  ) -> dict | None:
193
  """Return a copy of ``session`` whose events are only those in ``[start, end)``.
194
 
@@ -216,16 +233,29 @@ def _session_metrics(session: dict) -> dict:
216
  # Pre-seed every numeric key so downstream aggregation can sum without
217
  # having to special-case empty sessions.
218
  out: dict = {
219
- "sessions": 0, "turns": 0, "llm_calls": 0,
220
- "tokens_prompt": 0, "tokens_completion": 0,
221
- "tokens_cache_read": 0, "tokens_cache_creation": 0,
 
 
 
 
222
  "cost_usd": 0.0,
223
- "tool_calls_total": 0, "tool_calls_success": 0,
224
- "failures": 0, "regenerate_sessions": 0,
225
- "thumbs_up": 0, "thumbs_down": 0,
226
- "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
227
- "pro_cta_clicks": 0, "pro_conversions": 0, "credits_topped_up": 0,
228
- "sandboxes_created": 0, "sandboxes_cpu": 0, "sandboxes_gpu": 0,
 
 
 
 
 
 
 
 
 
229
  "first_tool_s": -1,
230
  }
231
  events = session.get("events") or []
@@ -373,7 +403,9 @@ def _session_metrics(session: dict) -> dict:
373
 
374
  def _aggregate(per_session: list[dict]) -> dict:
375
  """Collapse a bucket's worth of session rollups into the final KPI row."""
376
- ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
 
 
377
  gpu_hours: dict[str, float] = defaultdict(float)
378
  for s in per_session:
379
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
@@ -395,9 +427,21 @@ def _aggregate(per_session: list[dict]) -> dict:
395
  # never reached for the relevant signal — otherwise quiet hours
396
  # (status-check sessions, abandoned new conversations) drag every median
397
  # to 0 and the chart tells you nothing.
398
- research_calls_nz = [s.get("_research_calls", 0) for s in per_session if s.get("_research_calls", 0) > 0]
399
- distinct_tools_values = [s.get("_distinct_tools_used", 0) for s in per_session if s.get("_distinct_tools_used", 0) > 0]
400
- total_calls_values = [s.get("_total_named_tool_calls", 0) for s in per_session if s.get("_total_named_tool_calls", 0) > 0]
 
 
 
 
 
 
 
 
 
 
 
 
401
  # Per-turn intensity: turns>0 is the natural filter here (a session with
402
  # 5 turns and 0 tools is a meaningful 0). Don't strip those.
403
  calls_per_turn_values = [
@@ -415,7 +459,9 @@ def _aggregate(per_session: list[dict]) -> dict:
415
  failures = int(sum(s["failures"] for s in per_session))
416
  regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
417
  research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
418
- sessions_with_research = sum(1 for s in per_session if s.get("_research_calls", 0) > 0)
 
 
419
 
420
  # Per-session cost percentiles — chart "median session cost" alongside the
421
  # mean so a few $700 outliers don't make you think every session is pricey.
@@ -433,17 +479,23 @@ def _aggregate(per_session: list[dict]) -> dict:
433
  "tokens_prompt": int(tokens_prompt),
434
  "tokens_completion": int(sum(s["tokens_completion"] for s in per_session)),
435
  "tokens_cache_read": int(tokens_cache_read),
436
- "tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)),
 
 
437
  "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
438
  # Per-session cost summaries.
439
  "cost_per_session_mean": round(
440
  sum(s["cost_usd"] for s in per_session) / total_sessions, 6
441
- ) if total_sessions > 0 else 0.0,
 
 
442
  "cost_per_session_p50": round(cost_p50, 6),
443
  "cost_per_session_p95": round(cost_p95, 6),
444
  "cache_hit_ratio": round(
445
  tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
446
- ) if (tokens_cache_read + tokens_prompt) > 0 else 0.0,
 
 
447
  # Raw reliability COUNTS (these are what the dashboard shows directly).
448
  "tool_calls_total": int(tool_total),
449
  "tool_calls_succeeded": int(tool_success),
@@ -458,38 +510,56 @@ def _aggregate(per_session: list[dict]) -> dict:
458
  "regenerated_sessions": regenerates,
459
  # Rates kept for backwards compatibility with anything reading the
460
  # KPI dataset directly.
461
- "tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0,
462
- "failure_rate": round(failures / total_sessions, 4) if total_sessions > 0 else 0.0,
463
- "regenerate_rate": round(regenerates / total_sessions, 4) if total_sessions > 0 else 0.0,
 
 
 
 
 
 
464
  "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
465
  "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
466
  "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
467
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
468
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
469
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
470
- "sandboxes_created": int(sum(s.get("sandboxes_created", 0) for s in per_session)),
 
 
471
  "sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
472
  "sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
473
  "hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
474
  "pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
475
  "pro_conversions": int(sum(s.get("pro_conversions", 0) for s in per_session)),
476
- "credits_topped_up": int(sum(s.get("credits_topped_up", 0) for s in per_session)),
 
 
477
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
478
  # Research KPIs — answer "is the agent reaching for research?".
479
  "research_calls": research_calls_total,
480
  "sessions_with_research": int(sessions_with_research),
481
  "research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
482
- "research_calls_per_session_p95": round(_percentile(research_calls_nz, 0.95), 2),
 
 
483
  # Intra-session breadth + intensity. p50 + p95 over per-session values.
484
- "distinct_tools_per_session_p50": round(_percentile(distinct_tools_values, 0.5), 2),
485
- "distinct_tools_per_session_p95": round(_percentile(distinct_tools_values, 0.95), 2),
 
 
 
 
486
  "tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
487
  "tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
488
  "tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
489
  "tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
490
  # JSON columns let the dashboard add/remove tools without schema churn.
491
  "tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
492
- "sessions_using_tool_json": json.dumps(dict(sessions_using_tool), sort_keys=True),
 
 
493
  # Surface split — answers "is research dropping on Bedrock specifically?".
494
  "sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
495
  }
@@ -507,7 +577,12 @@ def _csv_cell(v: Any) -> str:
507
 
508
 
509
  def _write_csv(
510
- api, row: dict, bucket_key: str, path_in_repo: str, target_repo: str, token: str,
 
 
 
 
 
511
  ) -> None:
512
  """Render ``row`` to CSV with a leading ``bucket`` column and upload.
513
 
@@ -527,7 +602,10 @@ def _write_csv(
527
 
528
  try:
529
  api.create_repo(
530
- repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
 
 
 
531
  )
532
  api.upload_file(
533
  path_or_fileobj=tmp_path,
@@ -545,7 +623,11 @@ def _write_csv(
545
 
546
 
547
  def run_for_hour(
548
- api, source_repo: str, target_repo: str, hour_dt: datetime, token: str,
 
 
 
 
549
  ) -> dict:
550
  """Roll up one UTC hour [hour_dt, hour_dt+1h).
551
 
@@ -579,10 +661,16 @@ def run_for_hour(
579
 
580
  row = _aggregate(per_session)
581
  bucket_key = window_start.strftime("%Y-%m-%dT%H")
582
- path_in_repo = f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv"
 
 
583
  _write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
584
- logger.info("Wrote KPIs for %s (%d sessions): %s",
585
- bucket_key, per_session and len(per_session), row)
 
 
 
 
586
  return row
587
 
588
 
@@ -618,17 +706,23 @@ def main(argv: list[str] | None = None) -> int:
618
  ap.add_argument("--source", default="smolagents/ml-intern-sessions")
619
  ap.add_argument("--target", default="smolagents/ml-intern-kpis")
620
  ap.add_argument(
621
- "--hours", type=int, default=1,
 
 
622
  help="Number of trailing hours to roll up (default: 1 = last completed hour).",
623
  )
624
  ap.add_argument(
625
- "--datetime", type=str, default=None,
 
 
626
  help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
627
  )
628
  ap.add_argument(
629
- "--daily-backfill", type=str, default=None,
 
 
630
  help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
631
- "Writes to daily/<date>.csv. Use for historical backfill only.",
632
  )
633
  args = ap.parse_args(argv)
634
 
@@ -646,10 +740,17 @@ def main(argv: list[str] | None = None) -> int:
646
  return 1
647
 
648
  from huggingface_hub import HfApi
 
649
  api = HfApi()
650
 
651
  if args.daily_backfill:
652
- run_for_day(api, args.source, args.target, date.fromisoformat(args.daily_backfill), token)
 
 
 
 
 
 
653
  return 0
654
 
655
  if args.datetime:
 
99
  import tempfile
100
  from collections import defaultdict
101
  from datetime import date, datetime, timedelta, timezone
 
102
  from typing import Any, Iterable
103
 
104
  logger = logging.getLogger("build_kpis")
 
106
  # Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used
107
  # only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count.
108
  _FLAVOR_GPU_COUNT = {
109
+ "cpu-basic": 0,
110
+ "cpu-upgrade": 0,
111
+ "t4-small": 1,
112
+ "t4-medium": 1,
113
+ "l4x1": 1,
114
+ "l4x4": 4,
115
+ "l40sx1": 1,
116
+ "l40sx4": 4,
117
+ "l40sx8": 8,
118
+ "a10g-small": 1,
119
+ "a10g-large": 1,
120
+ "a10g-largex2": 2,
121
+ "a10g-largex4": 4,
122
+ "a100-large": 1,
123
+ "a100x2": 2,
124
+ "a100x4": 4,
125
+ "a100x8": 8,
126
+ "h100": 1,
127
+ "h100x8": 8,
128
  }
129
 
130
 
 
171
  directory is near-free.
172
  """
173
  from huggingface_hub import hf_hub_download
174
+
175
  try:
176
  local = hf_hub_download(
177
+ repo_id=repo_id,
178
+ filename=path,
179
+ repo_type="dataset",
180
+ token=token,
181
  )
182
  except Exception as e:
183
  logger.warning("hf_hub_download(%s) failed: %s", path, e)
 
203
 
204
 
205
  def _filter_session_to_window(
206
+ session: dict,
207
+ start: datetime,
208
+ end: datetime,
209
  ) -> dict | None:
210
  """Return a copy of ``session`` whose events are only those in ``[start, end)``.
211
 
 
233
  # Pre-seed every numeric key so downstream aggregation can sum without
234
  # having to special-case empty sessions.
235
  out: dict = {
236
+ "sessions": 0,
237
+ "turns": 0,
238
+ "llm_calls": 0,
239
+ "tokens_prompt": 0,
240
+ "tokens_completion": 0,
241
+ "tokens_cache_read": 0,
242
+ "tokens_cache_creation": 0,
243
  "cost_usd": 0.0,
244
+ "tool_calls_total": 0,
245
+ "tool_calls_success": 0,
246
+ "failures": 0,
247
+ "regenerate_sessions": 0,
248
+ "thumbs_up": 0,
249
+ "thumbs_down": 0,
250
+ "hf_jobs_submitted": 0,
251
+ "hf_jobs_succeeded": 0,
252
+ "hf_jobs_blocked": 0,
253
+ "pro_cta_clicks": 0,
254
+ "pro_conversions": 0,
255
+ "credits_topped_up": 0,
256
+ "sandboxes_created": 0,
257
+ "sandboxes_cpu": 0,
258
+ "sandboxes_gpu": 0,
259
  "first_tool_s": -1,
260
  }
261
  events = session.get("events") or []
 
403
 
404
  def _aggregate(per_session: list[dict]) -> dict:
405
  """Collapse a bucket's worth of session rollups into the final KPI row."""
406
+ ttfa_values = [
407
+ s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0
408
+ ]
409
  gpu_hours: dict[str, float] = defaultdict(float)
410
  for s in per_session:
411
  for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
 
427
  # never reached for the relevant signal — otherwise quiet hours
428
  # (status-check sessions, abandoned new conversations) drag every median
429
  # to 0 and the chart tells you nothing.
430
+ research_calls_nz = [
431
+ s.get("_research_calls", 0)
432
+ for s in per_session
433
+ if s.get("_research_calls", 0) > 0
434
+ ]
435
+ distinct_tools_values = [
436
+ s.get("_distinct_tools_used", 0)
437
+ for s in per_session
438
+ if s.get("_distinct_tools_used", 0) > 0
439
+ ]
440
+ total_calls_values = [
441
+ s.get("_total_named_tool_calls", 0)
442
+ for s in per_session
443
+ if s.get("_total_named_tool_calls", 0) > 0
444
+ ]
445
  # Per-turn intensity: turns>0 is the natural filter here (a session with
446
  # 5 turns and 0 tools is a meaningful 0). Don't strip those.
447
  calls_per_turn_values = [
 
459
  failures = int(sum(s["failures"] for s in per_session))
460
  regenerates = int(sum(s["regenerate_sessions"] for s in per_session))
461
  research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session))
462
+ sessions_with_research = sum(
463
+ 1 for s in per_session if s.get("_research_calls", 0) > 0
464
+ )
465
 
466
  # Per-session cost percentiles — chart "median session cost" alongside the
467
  # mean so a few $700 outliers don't make you think every session is pricey.
 
479
  "tokens_prompt": int(tokens_prompt),
480
  "tokens_completion": int(sum(s["tokens_completion"] for s in per_session)),
481
  "tokens_cache_read": int(tokens_cache_read),
482
+ "tokens_cache_creation": int(
483
+ sum(s["tokens_cache_creation"] for s in per_session)
484
+ ),
485
  "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4),
486
  # Per-session cost summaries.
487
  "cost_per_session_mean": round(
488
  sum(s["cost_usd"] for s in per_session) / total_sessions, 6
489
+ )
490
+ if total_sessions > 0
491
+ else 0.0,
492
  "cost_per_session_p50": round(cost_p50, 6),
493
  "cost_per_session_p95": round(cost_p95, 6),
494
  "cache_hit_ratio": round(
495
  tokens_cache_read / (tokens_cache_read + tokens_prompt), 4
496
+ )
497
+ if (tokens_cache_read + tokens_prompt) > 0
498
+ else 0.0,
499
  # Raw reliability COUNTS (these are what the dashboard shows directly).
500
  "tool_calls_total": int(tool_total),
501
  "tool_calls_succeeded": int(tool_success),
 
510
  "regenerated_sessions": regenerates,
511
  # Rates kept for backwards compatibility with anything reading the
512
  # KPI dataset directly.
513
+ "tool_success_rate": round(tool_success / tool_total, 4)
514
+ if tool_total > 0
515
+ else 0.0,
516
+ "failure_rate": round(failures / total_sessions, 4)
517
+ if total_sessions > 0
518
+ else 0.0,
519
+ "regenerate_rate": round(regenerates / total_sessions, 4)
520
+ if total_sessions > 0
521
+ else 0.0,
522
  "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2),
523
  "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2),
524
  "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)),
525
  "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
526
  "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
527
  "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
528
+ "sandboxes_created": int(
529
+ sum(s.get("sandboxes_created", 0) for s in per_session)
530
+ ),
531
  "sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)),
532
  "sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)),
533
  "hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)),
534
  "pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)),
535
  "pro_conversions": int(sum(s.get("pro_conversions", 0) for s in per_session)),
536
+ "credits_topped_up": int(
537
+ sum(s.get("credits_topped_up", 0) for s in per_session)
538
+ ),
539
  "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
540
  # Research KPIs — answer "is the agent reaching for research?".
541
  "research_calls": research_calls_total,
542
  "sessions_with_research": int(sessions_with_research),
543
  "research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2),
544
+ "research_calls_per_session_p95": round(
545
+ _percentile(research_calls_nz, 0.95), 2
546
+ ),
547
  # Intra-session breadth + intensity. p50 + p95 over per-session values.
548
+ "distinct_tools_per_session_p50": round(
549
+ _percentile(distinct_tools_values, 0.5), 2
550
+ ),
551
+ "distinct_tools_per_session_p95": round(
552
+ _percentile(distinct_tools_values, 0.95), 2
553
+ ),
554
  "tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2),
555
  "tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2),
556
  "tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2),
557
  "tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2),
558
  # JSON columns let the dashboard add/remove tools without schema churn.
559
  "tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True),
560
+ "sessions_using_tool_json": json.dumps(
561
+ dict(sessions_using_tool), sort_keys=True
562
+ ),
563
  # Surface split — answers "is research dropping on Bedrock specifically?".
564
  "sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True),
565
  }
 
577
 
578
 
579
  def _write_csv(
580
+ api,
581
+ row: dict,
582
+ bucket_key: str,
583
+ path_in_repo: str,
584
+ target_repo: str,
585
+ token: str,
586
  ) -> None:
587
  """Render ``row`` to CSV with a leading ``bucket`` column and upload.
588
 
 
602
 
603
  try:
604
  api.create_repo(
605
+ repo_id=target_repo,
606
+ repo_type="dataset",
607
+ exist_ok=True,
608
+ token=token,
609
  )
610
  api.upload_file(
611
  path_or_fileobj=tmp_path,
 
623
 
624
 
625
  def run_for_hour(
626
+ api,
627
+ source_repo: str,
628
+ target_repo: str,
629
+ hour_dt: datetime,
630
+ token: str,
631
  ) -> dict:
632
  """Roll up one UTC hour [hour_dt, hour_dt+1h).
633
 
 
661
 
662
  row = _aggregate(per_session)
663
  bucket_key = window_start.strftime("%Y-%m-%dT%H")
664
+ path_in_repo = (
665
+ f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv"
666
+ )
667
  _write_csv(api, row, bucket_key, path_in_repo, target_repo, token)
668
+ logger.info(
669
+ "Wrote KPIs for %s (%d sessions): %s",
670
+ bucket_key,
671
+ per_session and len(per_session),
672
+ row,
673
+ )
674
  return row
675
 
676
 
 
706
  ap.add_argument("--source", default="smolagents/ml-intern-sessions")
707
  ap.add_argument("--target", default="smolagents/ml-intern-kpis")
708
  ap.add_argument(
709
+ "--hours",
710
+ type=int,
711
+ default=1,
712
  help="Number of trailing hours to roll up (default: 1 = last completed hour).",
713
  )
714
  ap.add_argument(
715
+ "--datetime",
716
+ type=str,
717
+ default=None,
718
  help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.",
719
  )
720
  ap.add_argument(
721
+ "--daily-backfill",
722
+ type=str,
723
+ default=None,
724
  help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). "
725
+ "Writes to daily/<date>.csv. Use for historical backfill only.",
726
  )
727
  args = ap.parse_args(argv)
728
 
 
740
  return 1
741
 
742
  from huggingface_hub import HfApi
743
+
744
  api = HfApi()
745
 
746
  if args.daily_backfill:
747
+ run_for_day(
748
+ api,
749
+ args.source,
750
+ args.target,
751
+ date.fromisoformat(args.daily_backfill),
752
+ token,
753
+ )
754
  return 0
755
 
756
  if args.datetime:
scripts/build_sft.py CHANGED
@@ -62,9 +62,13 @@ def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[st
62
 
63
  def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
64
  from huggingface_hub import hf_hub_download
 
65
  try:
66
  local = hf_hub_download(
67
- repo_id=repo_id, filename=path, repo_type="dataset", token=token,
 
 
 
68
  )
69
  except Exception as e:
70
  logger.warning("hf_hub_download(%s) failed: %s", path, e)
@@ -118,7 +122,10 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None
118
  tmp_path = tmp.name
119
  try:
120
  api.create_repo(
121
- repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token,
 
 
 
122
  )
123
  api.upload_file(
124
  path_or_fileobj=tmp_path,
@@ -136,7 +143,11 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None
136
 
137
 
138
  def run_for_day(
139
- api, source_repo: str, target_repo: str, day: date, token: str,
 
 
 
 
140
  ) -> int:
141
  paths = _iter_session_files(api, source_repo, day, token)
142
  n = 0
@@ -162,11 +173,15 @@ def main(argv: list[str] | None = None) -> int:
162
  ap.add_argument("--source", default="smolagents/ml-intern-sessions")
163
  ap.add_argument("--target", default="smolagents/ml-intern-sft")
164
  ap.add_argument(
165
- "--days", type=int, default=1,
 
 
166
  help="Number of trailing days to export (default: 1 = yesterday).",
167
  )
168
  ap.add_argument(
169
- "--date", type=str, default=None,
 
 
170
  help="Single YYYY-MM-DD to export; overrides --days.",
171
  )
172
  args = ap.parse_args(argv)
@@ -185,6 +200,7 @@ def main(argv: list[str] | None = None) -> int:
185
  return 1
186
 
187
  from huggingface_hub import HfApi
 
188
  api = HfApi()
189
 
190
  if args.date:
 
62
 
63
  def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None:
64
  from huggingface_hub import hf_hub_download
65
+
66
  try:
67
  local = hf_hub_download(
68
+ repo_id=repo_id,
69
+ filename=path,
70
+ repo_type="dataset",
71
+ token=token,
72
  )
73
  except Exception as e:
74
  logger.warning("hf_hub_download(%s) failed: %s", path, e)
 
122
  tmp_path = tmp.name
123
  try:
124
  api.create_repo(
125
+ repo_id=target_repo,
126
+ repo_type="dataset",
127
+ exist_ok=True,
128
+ token=token,
129
  )
130
  api.upload_file(
131
  path_or_fileobj=tmp_path,
 
143
 
144
 
145
  def run_for_day(
146
+ api,
147
+ source_repo: str,
148
+ target_repo: str,
149
+ day: date,
150
+ token: str,
151
  ) -> int:
152
  paths = _iter_session_files(api, source_repo, day, token)
153
  n = 0
 
173
  ap.add_argument("--source", default="smolagents/ml-intern-sessions")
174
  ap.add_argument("--target", default="smolagents/ml-intern-sft")
175
  ap.add_argument(
176
+ "--days",
177
+ type=int,
178
+ default=1,
179
  help="Number of trailing days to export (default: 1 = yesterday).",
180
  )
181
  ap.add_argument(
182
+ "--date",
183
+ type=str,
184
+ default=None,
185
  help="Single YYYY-MM-DD to export; overrides --days.",
186
  )
187
  args = ap.parse_args(argv)
 
200
  return 1
201
 
202
  from huggingface_hub import HfApi
203
+
204
  api = HfApi()
205
 
206
  if args.date: