diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..5d79742fe97daa25e23740b7904a69439fd38368 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,63 @@ +name: CI + +on: + pull_request: + push: + branches: [main] + +permissions: + contents: read + +concurrency: + group: ci-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + ruff: + name: Ruff + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: uv.lock + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: uv sync --locked --extra dev + + - name: Run Ruff + run: uv run ruff check . + + - name: Check formatting + run: uv run ruff format --check . + + tests: + name: Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: uv.lock + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: uv sync --locked --extra dev + + - name: Run tests + run: uv run pytest diff --git a/AGENTS.md b/AGENTS.md index 0e09a85087c8963c25b612a86f3e934ce24cbfbc..5a31cb60fb3d86caf17546877e867b182e381fb6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -15,6 +15,11 @@ Notes: - Prefer `npm ci` over `npm install` for setup, since `npm install` may rewrite `frontend/package-lock.json` metadata depending on npm version. - Production defaults to the Bedrock Claude model. For local development with a personal Anthropic key, set `ANTHROPIC_API_KEY` and `ML_INTERN_CLAUDE_MODEL_ID=anthropic/claude-opus-4-6` before starting the backend. Other models are selected through the app's model switcher. +## Development Checks + +- Before every commit, run `uv run ruff check .` and `uv run ruff format --check .`. +- If formatting fails, run `uv run ruff format .`, then re-run the Ruff checks before committing. + ## GitHub CLI - For multiline PR descriptions, prefer `gh pr edit --body-file ` over inline `--body` so shell quoting, `$` env-var names, backticks, and newlines are preserved correctly. diff --git a/agent/config.py b/agent/config.py index 87f2a9c59331dfc542a0a22aada2e7d728e3126e..35b095c328fe64b53eb51ef5126ebec7e6f546e4 100644 --- a/agent/config.py +++ b/agent/config.py @@ -5,20 +5,20 @@ from pathlib import Path from typing import Any, Union from dotenv import load_dotenv - -from agent.messaging.models import MessagingConfig - -# Project root: two levels up from this file (agent/config.py -> project root) -_PROJECT_ROOT = Path(__file__).resolve().parent.parent from fastmcp.mcp_config import ( RemoteMCPServer, StdioMCPServer, ) from pydantic import BaseModel +from agent.messaging.models import MessagingConfig + # These two are the canonical server config types for MCP servers. MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer] +# Project root: two levels up from this file (agent/config.py -> project root) +_PROJECT_ROOT = Path(__file__).resolve().parent.parent + class Config(BaseModel): """Configuration manager""" @@ -60,12 +60,16 @@ class Config(BaseModel): USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG" -DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json" +DEFAULT_USER_CONFIG_PATH = ( + Path.home() / ".config" / "ml-intern" / "cli_agent_config.json" +) SLACK_DEFAULT_DESTINATION = "slack.default" SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"] -def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: +def _deep_merge_config( + base: dict[str, Any], override: dict[str, Any] +) -> dict[str, Any]: merged = dict(base) for key, value in override.items(): current = merged.get(key) diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 6e43f147ba77d2650ee405742cee332905b562cd..85e96af0f6f3fa6d0426acddcd308281d502558b 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -3,7 +3,6 @@ Context management for conversation history """ import logging -import os import time import zoneinfo from datetime import datetime @@ -96,6 +95,7 @@ class CompactionFailedError(Exception): burns Bedrock budget for free (~$3 per re-attempt on Opus). """ + # Used when seeding a brand-new session from prior browser-cached messages. # Here we're writing a note to *ourselves* — so preserve the tool-call trail, # files produced, and planned next steps in first person. Optimized for @@ -155,12 +155,15 @@ async def summarize_messages( ) if session is not None: from agent.core import telemetry + await telemetry.record_llm_call( session, model=model_name, response=response, latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=response.choices[0].finish_reason + if response.choices + else None, kind=kind, ) summary = response.choices[0].message.content or "" @@ -233,6 +236,7 @@ class ContextManager: # CLI-specific context for local mode if local_mode: import os + cwd = os.getcwd() local_context = ( f"\n\n# CLI / Local mode\n\n" @@ -305,7 +309,9 @@ class ContextManager: i = 0 while i < len(self.items): msg = self.items[i] - if getattr(msg, "role", None) != "assistant" or not getattr(msg, "tool_calls", None): + if getattr(msg, "role", None) != "assistant" or not getattr( + msg, "tool_calls", None + ): i += 1 continue @@ -316,7 +322,9 @@ class ContextManager: # before the next non-tool message to satisfy provider ordering. j = i + 1 immediate_ids: set[str | None] = set() - while j < len(self.items) and getattr(self.items[j], "role", None) == "tool": + while ( + j < len(self.items) and getattr(self.items[j], "role", None) == "tool" + ): immediate_ids.add(getattr(self.items[j], "tool_call_id", None)) j += 1 @@ -386,7 +394,9 @@ class ContextManager: @property def needs_compaction(self) -> bool: - return self.running_context_usage > self.compaction_threshold and bool(self.items) + return self.running_context_usage > self.compaction_threshold and bool( + self.items + ) def _truncate_oversized( self, messages: list[Message], model_name: str @@ -425,7 +435,9 @@ class ContextManager: ) logger.warning( "Truncating %s message: %d -> %d tokens for compaction", - msg.role, n, len(placeholder) // 4, + msg.role, + n, + len(placeholder) // 4, ) # Preserve all known assistant-side fields (tool_calls, thinking_blocks, # reasoning_content, provider_specific_fields) even when content is @@ -459,9 +471,9 @@ class ContextManager: except Exception as e: logger.warning("token_counter failed (%s); rough estimate", e) # Rough fallback: 4 chars per token. - self.running_context_usage = sum( - len(getattr(m, "content", "") or "") for m in self.items - ) // 4 + self.running_context_usage = ( + sum(len(getattr(m, "content", "") or "") for m in self.items) // 4 + ) async def compact( self, @@ -516,7 +528,7 @@ class ContextManager: idx = first_user_idx + 1 recent_messages = self.items[idx:] - messages_to_summarize = self.items[first_user_idx + 1:idx] + messages_to_summarize = self.items[first_user_idx + 1 : idx] # Truncate any message that's larger than _MAX_TOKENS_PER_MESSAGE in # the parts we PRESERVE through compaction (first_user + recent_tail). diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 38c0447dc92871515418000e79a4a4b1676a02ba..0eaa6e9d64b7bdb6e1addc09a4a837e80dff2cd2 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -5,7 +5,6 @@ Main agent implementation with integrated tool system and MCP support import asyncio import json import logging -import os import time from dataclasses import dataclass, field from typing import Any @@ -27,6 +26,7 @@ from agent.core.cost_estimation import CostEstimate, estimate_tool_cost from agent.messaging.gateway import NotificationGateway from agent.core import telemetry from agent.core.doom_loop import check_for_doom_loop +from agent.core.hub_artifacts import start_session_artifact_collection_task from agent.core.llm_params import _resolve_llm_params from agent.core.prompt_caching import with_prompt_caching from agent.core.session import Event, OpType, Session @@ -54,11 +54,12 @@ def _malformed_tool_name(message: Message) -> str | None: end = content.find(_MALFORMED_TOOL_SUFFIX, len(_MALFORMED_TOOL_PREFIX)) if end == -1: return None - return content[len(_MALFORMED_TOOL_PREFIX):end] + return content[len(_MALFORMED_TOOL_PREFIX) : end] def _detect_repeated_malformed( - items: list[Message], threshold: int = 2, + items: list[Message], + threshold: int = 2, ) -> str | None: """Return the repeated malformed tool name if the tail contains a streak. @@ -118,6 +119,7 @@ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: _IMMEDIATE_HF_JOB_RUNS = {"run", "uv"} + @dataclass(frozen=True) class ApprovalDecision: requires_approval: bool @@ -142,7 +144,9 @@ def _is_scheduled_hf_job_run(tool_name: str, tool_args: dict) -> bool: def _is_budgeted_auto_approval_target(tool_name: str, tool_args: dict) -> bool: - return tool_name == "sandbox_create" or _is_immediate_hf_job_run(tool_name, tool_args) + return tool_name == "sandbox_create" or _is_immediate_hf_job_run( + tool_name, tool_args + ) def _base_needs_approval( @@ -231,7 +235,9 @@ def _session_auto_approval_enabled(session: Session | None) -> bool: def _effective_yolo_enabled(session: Session | None, config: Config | None) -> bool: - return bool((config and config.yolo_mode) or _session_auto_approval_enabled(session)) + return bool( + (config and config.yolo_mode) or _session_auto_approval_enabled(session) + ) def _remaining_budget_after_reservations( @@ -251,7 +257,10 @@ def _budget_block_reason( ) -> str | None: if estimate.estimated_cost_usd is None: return estimate.block_reason or "Could not estimate the cost safely." - if remaining_cap_usd is not None and estimate.estimated_cost_usd > remaining_cap_usd: + if ( + remaining_cap_usd is not None + and estimate.estimated_cost_usd > remaining_cap_usd + ): return ( f"Estimated cost ${estimate.estimated_cost_usd:.2f} exceeds " f"remaining YOLO cap ${remaining_cap_usd:.2f}." @@ -409,15 +418,25 @@ def _is_transient_error(error: Exception) -> bool: """Return True for errors that are likely transient and worth retrying.""" err_str = str(error).lower() transient_patterns = [ - "timeout", "timed out", - "503", "service unavailable", - "502", "bad gateway", - "500", "internal server error", - "overloaded", "capacity", - "connection reset", "connection refused", "connection error", - "eof", "broken pipe", + "timeout", + "timed out", + "503", + "service unavailable", + "502", + "bad gateway", + "500", + "internal server error", + "overloaded", + "capacity", + "connection reset", + "connection refused", + "connection error", + "eof", + "broken pipe", ] - return _is_rate_limit_error(error) or any(pattern in err_str for pattern in transient_patterns) + return _is_rate_limit_error(error) or any( + pattern in err_str for pattern in transient_patterns + ) def _is_effort_config_error(error: Exception) -> bool: @@ -429,11 +448,14 @@ def _is_effort_config_error(error: Exception) -> bool: doesn't work for the current model. We heal the cache and retry once. """ from agent.core.effort_probe import _is_invalid_effort, _is_thinking_unsupported + return _is_thinking_unsupported(error) or _is_invalid_effort(error) async def _heal_effort_and_rebuild_params( - session: Session, error: Exception, llm_params: dict, + session: Session, + error: Exception, + llm_params: dict, ) -> dict: """Update the session's effort cache based on ``error`` and return new llm_params. Called only when ``_is_effort_config_error(error)`` is True. @@ -444,7 +466,11 @@ async def _heal_effort_and_rebuild_params( • invalid-effort → re-run the full cascade probe; the result lands in the cache """ - from agent.core.effort_probe import ProbeInconclusive, _is_thinking_unsupported, probe_effort + from agent.core.effort_probe import ( + ProbeInconclusive, + _is_thinking_unsupported, + probe_effort, + ) model = session.config.model_name if _is_thinking_unsupported(error): @@ -453,12 +479,16 @@ async def _heal_effort_and_rebuild_params( else: try: outcome = await probe_effort( - model, session.config.reasoning_effort, session.hf_token, + model, + session.config.reasoning_effort, + session.hf_token, session=session, ) session.model_effective_effort[model] = outcome.effective_effort logger.info( - "healed: %s effort cascade → %s", model, outcome.effective_effort, + "healed: %s effort cascade → %s", + model, + outcome.effective_effort, ) except ProbeInconclusive: # Transient during healing — strip thinking for safety, next @@ -477,7 +507,11 @@ def _friendly_error_message(error: Exception) -> str | None: """Return a user-friendly message for known error types, or None to fall back to traceback.""" err_str = str(error).lower() - if "authentication" in err_str or "unauthorized" in err_str or "invalid x-api-key" in err_str: + if ( + "authentication" in err_str + or "unauthorized" in err_str + or "invalid x-api-key" in err_str + ): return ( "Authentication failed — your API key is missing or invalid.\n\n" "To fix this, set the API key for your model provider:\n" @@ -503,8 +537,7 @@ def _friendly_error_message(error: Exception) -> str | None: ) if "model_not_found" in err_str or ( - "model" in err_str - and ("not found" in err_str or "does not exist" in err_str) + "model" in err_str and ("not found" in err_str or "does not exist" in err_str) ): return ( "Model not found. Use '/model' to list suggestions, or paste an " @@ -530,7 +563,10 @@ async def _compact_and_notify(session: Session) -> None: old_usage = cm.running_context_usage logger.debug( "Compaction check: usage=%d, max=%d, threshold=%d, needs_compact=%s", - old_usage, cm.model_max_tokens, cm.compaction_threshold, cm.needs_compaction, + old_usage, + cm.model_max_tokens, + cm.compaction_threshold, + cm.needs_compaction, ) try: await cm.compact( @@ -542,24 +578,27 @@ async def _compact_and_notify(session: Session) -> None: except CompactionFailedError as e: logger.error( "Compaction failed for session %s: %s — terminating session", - session.session_id, e, + session.session_id, + e, ) # Persist the failure event so the dataset has a record of WHY this # session ended (and the cost it incurred up to that point) even if # save_and_upload_detached has issues downstream. - await session.send_event(Event( - event_type="session_terminated", - data={ - "reason": "compaction_failed", - "context_usage": cm.running_context_usage, - "context_threshold": cm.compaction_threshold, - "error": str(e)[:300], - "user_message": ( - "Your conversation has grown too large to continue. " - "The work you've done is saved — start a new session to keep going." - ), - }, - )) + await session.send_event( + Event( + event_type="session_terminated", + data={ + "reason": "compaction_failed", + "context_usage": cm.running_context_usage, + "context_threshold": cm.compaction_threshold, + "error": str(e)[:300], + "user_message": ( + "Your conversation has grown too large to continue. " + "The work you've done is saved — start a new session to keep going." + ), + }, + ) + ) # Stop the agent loop; the finally in _run_session will fire # cleanup_sandbox + save_trajectory so the dataset captures # everything that did happen. @@ -570,7 +609,10 @@ async def _compact_and_notify(session: Session) -> None: if new_usage != old_usage: logger.warning( "Context compacted: %d -> %d tokens (max=%d, %d messages)", - old_usage, new_usage, cm.model_max_tokens, len(cm.items), + old_usage, + new_usage, + cm.model_max_tokens, + len(cm.items), ) await session.send_event( Event( @@ -609,6 +651,7 @@ async def _cleanup_on_cancel(session: Session) -> None: @dataclass class LLMResult: """Result from an LLM call (streaming or non-streaming).""" + content: str | None tool_calls_acc: dict[int, dict] token_count: int @@ -728,16 +771,18 @@ async def _maybe_heal_invalid_thinking_signature( if not stripped: return False - await session.send_event(Event( - event_type="tool_log", - data={ - "tool": "system", - "log": ( - "Anthropic rejected stale thinking signatures; retrying " - "without replayed thinking metadata." - ), - }, - )) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "Anthropic rejected stale thinking signatures; retrying " + "without replayed thinking metadata." + ), + }, + ) + ) return True @@ -762,7 +807,9 @@ def _assistant_message_from_result( return Message(**kwargs) -async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: +async def _call_llm_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" response = None _healed_effort = False # one-shot safety net per call @@ -788,11 +835,18 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> raise ContextWindowExceededError(str(e)) from e if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, - )) + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) continue if await _maybe_heal_invalid_thinking_signature( session, @@ -806,12 +860,20 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) ) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, - )) await asyncio.sleep(_delay) continue raise @@ -852,16 +914,21 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> idx = tc_delta.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { - "id": "", "type": "function", + "id": "", + "type": "function", "function": {"name": "", "arguments": ""}, } if tc_delta.id: tool_calls_acc[idx]["id"] = tc_delta.id if tc_delta.function: if tc_delta.function.name: - tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) if tc_delta.function.arguments: - tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments + tool_calls_acc[idx]["function"]["arguments"] += ( + tc_delta.function.arguments + ) if hasattr(chunk, "usage") and chunk.usage: token_count = chunk.usage.total_tokens @@ -881,7 +948,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> rebuilt = stream_chunk_builder(chunks, messages=messages) if rebuilt and getattr(rebuilt, "choices", None): rebuilt_msg = rebuilt.choices[0].message - thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg) + thinking_blocks, reasoning_content = _extract_thinking_state( + rebuilt_msg + ) except Exception: logger.debug("Failed to rebuild streaming thinking state", exc_info=True) @@ -896,7 +965,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> ) -async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult: +async def _call_llm_non_streaming( + session: Session, messages, tools, llm_params +) -> LLMResult: """Call the LLM without streaming, emit assistant_message at the end.""" response = None _healed_effort = False @@ -921,11 +992,18 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) raise ContextWindowExceededError(str(e)) from e if not _healed_effort and _is_effort_config_error(e): _healed_effort = True - llm_params = await _heal_effort_and_rebuild_params(session, e, llm_params) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, - )) + llm_params = await _heal_effort_and_rebuild_params( + session, e, llm_params + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": "Reasoning effort not supported for this model — adjusting and retrying.", + }, + ) + ) continue if await _maybe_heal_invalid_thinking_signature( session, @@ -939,12 +1017,20 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: logger.warning( "Transient LLM error (attempt %d/%d): %s — retrying in %ds", - _llm_attempt + 1, _MAX_LLM_RETRIES, e, _delay, + _llm_attempt + 1, + _MAX_LLM_RETRIES, + e, + _delay, + ) + await session.send_event( + Event( + event_type="tool_log", + data={ + "tool": "system", + "log": f"LLM connection error, retrying in {_delay}s...", + }, + ) ) - await session.send_event(Event( - event_type="tool_log", - data={"tool": "system", "log": f"LLM connection error, retrying in {_delay}s..."}, - )) await asyncio.sleep(_delay) continue raise @@ -1037,7 +1123,8 @@ class Handlers: @staticmethod async def run_agent( - session: Session, text: str, + session: Session, + text: str, ) -> str | None: """ Handle user input (like user_input_or_turn in codex.rs:1291) @@ -1124,12 +1211,18 @@ class Handlers: llm_params = _resolve_llm_params( session.config.model_name, session.hf_token, - reasoning_effort=session.effective_effort_for(session.config.model_name), + reasoning_effort=session.effective_effort_for( + session.config.model_name + ), ) if session.stream: - llm_result = await _call_llm_streaming(session, messages, tools, llm_params) + llm_result = await _call_llm_streaming( + session, messages, tools, llm_params + ) else: - llm_result = await _call_llm_non_streaming(session, messages, tools, llm_params) + llm_result = await _call_llm_non_streaming( + session, messages, tools, llm_params + ) content = llm_result.content tool_calls_acc = llm_result.tool_calls_acc @@ -1176,7 +1269,10 @@ class Handlers: await session.send_event( Event( event_type="tool_log", - data={"tool": "system", "log": f"Output truncated — retrying with smaller content ({dropped_names})"}, + data={ + "tool": "system", + "log": f"Output truncated — retrying with smaller content ({dropped_names})", + }, ) ) iteration += 1 @@ -1239,7 +1335,8 @@ class Handlers: except (json.JSONDecodeError, TypeError, ValueError): logger.warning( "Malformed arguments for tool_call %s (%s) — skipping", - tc.id, tc.function.name, + tc.id, + tc.function.name, ) tc.function.arguments = "{}" bad_tools.append(tc) @@ -1260,20 +1357,35 @@ class Handlers: f"arguments and was NOT executed. Retry with smaller content — " f"for 'write', split into multiple smaller writes using 'edit'." ) - session.context_manager.add_message(Message( - role="tool", - content=error_msg, - tool_call_id=tc.id, - name=tc.function.name, - )) - await session.send_event(Event( - event_type="tool_call", - data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id}, - )) - await session.send_event(Event( - event_type="tool_output", - data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False}, - )) + session.context_manager.add_message( + Message( + role="tool", + content=error_msg, + tool_call_id=tc.id, + name=tc.function.name, + ) + ) + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tc.function.name, + "arguments": {}, + "tool_call_id": tc.id, + }, + ) + ) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tc.function.name, + "tool_call_id": tc.id, + "output": error_msg, + "success": False, + }, + ) + ) # ── Cancellation check: before tool execution ── if session.is_cancelled: @@ -1298,7 +1410,9 @@ class Handlers: reserved_spend_usd=reserved_auto_spend_usd, ) if decision.requires_approval: - approval_required_tools.append((tc, tool_name, tool_args, decision)) + approval_required_tools.append( + (tc, tool_name, tool_args, decision) + ) else: non_approval_tools.append((tc, tool_name, tool_args, decision)) if ( @@ -1321,7 +1435,14 @@ class Handlers: ) # 2. Send all tool_call events upfront (so frontend shows them all) - for tc, tool_name, tool_args, _decision, args_valid, _ in parsed_tools: + for ( + tc, + tool_name, + tool_args, + _decision, + args_valid, + _, + ) in parsed_tools: if args_valid: await session.send_event( Event( @@ -1352,12 +1473,14 @@ class Handlers: ) return (tc, name, args, out, ok) - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - _exec_tool(tc, name, args, decision, valid, err) - for tc, name, args, decision, valid, err in parsed_tools - ] - )) + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + _exec_tool(tc, name, args, decision, valid, err) + for tc, name, args, decision, valid, err in parsed_tools + ] + ) + ) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1374,10 +1497,16 @@ class Handlers: # Notify frontend that in-flight tools were cancelled for tc, name, _args, _decision, valid, _ in parsed_tools: if valid: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": name, "state": "cancelled"}, - )) + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": name, + "state": "cancelled", + }, + ) + ) await _cleanup_on_cancel(session) break @@ -1414,10 +1543,15 @@ class Handlers: for tc, tool_name, tool_args, decision in approval_required_tools: # Resolve sandbox file paths for hf_jobs scripts so the # frontend can display & edit the actual file content. - if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str): + if tool_name == "hf_jobs" and isinstance( + tool_args.get("script"), str + ): from agent.tools.sandbox_tool import resolve_sandbox_script + sandbox = getattr(session, "sandbox", None) - resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"]) + resolved, _ = await resolve_sandbox_script( + sandbox, tool_args["script"] + ) if resolved: tool_args = {**tool_args, "script": resolved} @@ -1449,10 +1583,12 @@ class Handlers: "remaining_cap_usd": first.get("remaining_cap_usd"), } ) - await session.send_event(Event( - event_type="approval_required", - data=event_data, - )) + await session.send_event( + Event( + event_type="approval_required", + data=event_data, + ) + ) # Store all approval-requiring tools (ToolCall objects for execution) session.pending_approval = { @@ -1470,7 +1606,10 @@ class Handlers: logger.warning( "ContextWindowExceededError at iteration %d — forcing compaction " "(usage=%d, model_max_tokens=%d, messages=%d)", - iteration, cm.running_context_usage, cm.model_max_tokens, len(cm.items), + iteration, + cm.running_context_usage, + cm.model_max_tokens, + len(cm.items), ) cm.running_context_usage = cm.model_max_tokens + 1 await _compact_and_notify(session) @@ -1662,13 +1801,15 @@ class Handlers: # Execute all approved tools concurrently (cancellable) if approved_tasks: - gather_task = asyncio.ensure_future(asyncio.gather( - *[ - execute_tool(tc, tool_name, tool_args, was_edited) - for tc, tool_name, tool_args, was_edited in approved_tasks - ], - return_exceptions=True, - )) + gather_task = asyncio.ensure_future( + asyncio.gather( + *[ + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks + ], + return_exceptions=True, + ) + ) cancel_task = asyncio.ensure_future(session._cancelled.wait()) done, _ = await asyncio.wait( @@ -1684,10 +1825,16 @@ class Handlers: pass # Notify frontend that approved tools were cancelled for tc, tool_name, _args, _was_edited in approved_tasks: - await session.send_event(Event( - event_type="tool_state_change", - data={"tool_call_id": tc.id, "tool": tool_name, "state": "cancelled"}, - )) + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "cancelled", + }, + ) + ) await _cleanup_on_cancel(session) await session.send_event(Event(event_type="interrupted")) session.increment_turn() @@ -1839,14 +1986,20 @@ async def submission_loop( # Create session with tool router session = Session( - event_queue, config=config, tool_router=tool_router, hf_token=hf_token, - user_id=user_id, local_mode=local_mode, stream=stream, + event_queue, + config=config, + tool_router=tool_router, + hf_token=hf_token, + user_id=user_id, + local_mode=local_mode, + stream=stream, notification_gateway=notification_gateway, notification_destinations=notification_destinations, defer_turn_complete_notification=defer_turn_complete_notification, ) if session_holder is not None: session_holder[0] = session + start_session_artifact_collection_task(session, token=hf_token) logger.info("Agent loop started") # Retry any failed uploads from previous sessions (fire-and-forget). @@ -1864,10 +2017,13 @@ async def submission_loop( async with tool_router: # Emit ready event after initialization await session.send_event( - Event(event_type="ready", data={ - "message": "Agent initialized", - "tool_count": len(tool_router.tools), - }) + Event( + event_type="ready", + data={ + "message": "Agent initialized", + "tool_count": len(tool_router.tools), + }, + ) ) while session.is_running: diff --git a/agent/core/cost_estimation.py b/agent/core/cost_estimation.py index f1f98ec828ea1e1639192eadb884ca7283379038..a41ad196efec7495c7ca9141d2f7f3a4f38e6dbd 100644 --- a/agent/core/cost_estimation.py +++ b/agent/core/cost_estimation.py @@ -88,7 +88,9 @@ class CostEstimate: label: str | None = None -def parse_timeout_hours(value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS) -> float | None: +def parse_timeout_hours( + value: Any, *, default_hours: float = DEFAULT_JOB_TIMEOUT_HOURS +) -> float | None: """Parse HF timeout values into hours. Strings accept ``s``, ``m``, ``h``, or ``d`` suffixes. Numeric values are @@ -247,7 +249,9 @@ async def estimate_hf_job_cost(args: dict[str, Any]) -> CostEstimate: ) -async def estimate_sandbox_cost(args: dict[str, Any], *, session: Any = None) -> CostEstimate: +async def estimate_sandbox_cost( + args: dict[str, Any], *, session: Any = None +) -> CostEstimate: if session is not None and getattr(session, "sandbox", None): return CostEstimate(estimated_cost_usd=0.0, billable=False, label="existing") diff --git a/agent/core/doom_loop.py b/agent/core/doom_loop.py index 878c7c00adfb4f8ea3fa7f068493ed8358d76b8d..3b57fe2cc3cffd07b466db9ac98cc0d0b665de79 100644 --- a/agent/core/doom_loop.py +++ b/agent/core/doom_loop.py @@ -81,9 +81,11 @@ def extract_recent_tool_signatures( name = getattr(fn, "name", "") or "" args_str = getattr(fn, "arguments", "") or "" result_hash = None - for follow in recent[idx + 1:]: + for follow in recent[idx + 1 :]: role = getattr(follow, "role", None) - if role == "tool" and getattr(follow, "tool_call_id", None) == getattr(tc, "id", None): + if role == "tool" and getattr(follow, "tool_call_id", None) == getattr( + tc, "id", None + ): result_hash = _hash_args(str(getattr(follow, "content", "") or "")) break if role in {"assistant", "user"}: @@ -174,7 +176,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None: pattern = detect_repeating_sequence(signatures) if pattern: pattern_desc = " → ".join(s.name for s in pattern) - logger.warning("Repetition guard activated: repeating sequence [%s]", pattern_desc) + logger.warning( + "Repetition guard activated: repeating sequence [%s]", pattern_desc + ) return ( f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: " f"[{pattern_desc}]. This pattern has repeated multiple times without progress. " diff --git a/agent/core/effort_probe.py b/agent/core/effort_probe.py index b6ac91f633b75c72b3baae3537847bf71dfb6137..dbad4c3da95e939ec9d2dae5c6c7408bcc6ea156 100644 --- a/agent/core/effort_probe.py +++ b/agent/core/effort_probe.py @@ -39,12 +39,12 @@ logger = logging.getLogger(__name__) # requested level raise ``UnsupportedEffortError`` synchronously (no wasted # network round-trip) and we advance to the next level. _EFFORT_CASCADE: dict[str, list[str]] = { - "max": ["max", "xhigh", "high", "medium", "low"], - "xhigh": ["xhigh", "high", "medium", "low"], - "high": ["high", "medium", "low"], - "medium": ["medium", "low"], + "max": ["max", "xhigh", "high", "medium", "low"], + "xhigh": ["xhigh", "high", "medium", "low"], + "high": ["high", "medium", "low"], + "medium": ["medium", "low"], "minimal": ["minimal", "low"], - "low": ["low"], + "low": ["low"], } _PROBE_TIMEOUT = 15.0 @@ -69,6 +69,7 @@ class ProbeOutcome: * str → send this level * None → model doesn't support thinking; strip it """ + effective_effort: str | None attempts: int elapsed_ms: int @@ -108,10 +109,15 @@ def _is_invalid_effort(e: Exception) -> bool: return any( phrase in s for phrase in ( - "invalid", "not supported", "must be one of", "not a valid", - "unrecognized", "unknown", + "invalid", + "not supported", + "must be one of", + "not a valid", + "unrecognized", + "unknown", # LiteLLM's own pre-flight validation phrasing. - "only supported by", "is only supported", + "only supported by", + "is only supported", ) ) @@ -128,11 +134,23 @@ def _is_transient(e: Exception) -> bool: return any( p in s for p in ( - "timeout", "timed out", "429", "rate limit", - "503", "service unavailable", "502", "bad gateway", - "500", "internal server error", "overloaded", "capacity", - "connection reset", "connection refused", "connection error", - "eof", "broken pipe", + "timeout", + "timed out", + "429", + "rate limit", + "503", + "service unavailable", + "502", + "bad gateway", + "500", + "internal server error", + "overloaded", + "capacity", + "connection reset", + "connection refused", + "connection error", + "eof", + "broken pipe", ) ) @@ -173,7 +191,10 @@ async def probe_effort( for effort in cascade: try: params = _resolve_llm_params( - model_name, hf_token, reasoning_effort=effort, strict=True, + model_name, + hf_token, + reasoning_effort=effort, + strict=True, ) except UnsupportedEffortError: # Provider can't even accept this effort name (e.g. "max" on @@ -198,12 +219,15 @@ async def probe_effort( # out of the probe and break model switching. try: from agent.core import telemetry + await telemetry.record_llm_call( session, model=model_name, response=response, latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=response.choices[0].finish_reason + if response.choices + else None, kind="effort_probe", ) except Exception as _telem_err: @@ -219,7 +243,9 @@ async def probe_effort( note="model doesn't support reasoning, dropped", ) if _is_invalid_effort(e): - logger.debug("probe: %s rejected effort=%s, trying next", model_name, effort) + logger.debug( + "probe: %s rejected effort=%s, trying next", model_name, effort + ) continue if _is_transient(e): raise ProbeInconclusive(str(e)) from e diff --git a/agent/core/hf_router_catalog.py b/agent/core/hf_router_catalog.py index f6f519d034b18ea68366a0d61eb7f47319b2530a..625ccf4fb85498e229fe63dc0faac56628d0be39 100644 --- a/agent/core/hf_router_catalog.py +++ b/agent/core/hf_router_catalog.py @@ -92,7 +92,9 @@ def _parse_entry(entry: dict) -> ModelInfo: input_price=pricing.get("input"), output_price=pricing.get("output"), supports_tools=bool(p.get("supports_tools", False)), - supports_structured_output=bool(p.get("supports_structured_output", False)), + supports_structured_output=bool( + p.get("supports_structured_output", False) + ), ) ) return ModelInfo(id=entry.get("id", ""), providers=providers) diff --git a/agent/core/hub_artifacts.py b/agent/core/hub_artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..d317ee6704de42459291fcb4a205258a78b02e13 --- /dev/null +++ b/agent/core/hub_artifacts.py @@ -0,0 +1,765 @@ +"""Best-effort Hub metadata for artifacts generated by ML Intern sessions.""" + +import asyncio +import base64 +import logging +import re +import shlex +import tempfile +import textwrap +from datetime import datetime +from pathlib import Path +from typing import Any + +from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub.repocard import metadata_load, metadata_save +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + +logger = logging.getLogger(__name__) + +ML_INTERN_TAG = "ml-intern" +SUPPORTED_REPO_TYPES = {"model", "dataset", "space"} +PROVENANCE_MARKER = "" +_COLLECTION_TITLE_PREFIX = "ml-intern-artifacts" +_COLLECTION_TITLE_MAX_LENGTH = 59 +_UUID_SESSION_ID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" + r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$" +) +_KNOWN_ARTIFACTS_ATTR = "_ml_intern_known_hub_artifacts" +_REGISTERED_ARTIFACTS_ATTR = "_ml_intern_registered_hub_artifacts" +_COLLECTION_SLUG_ATTR = "_ml_intern_artifact_collection_slug" +_COLLECTION_TASK_ATTR = "_ml_intern_artifact_collection_task" +_SESSION_ARTIFACT_SET_FALLBACK: dict[tuple[int, str], set[str]] = {} +_USAGE_HEADING_RE = re.compile( + r"^#{2,6}\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\b", + re.IGNORECASE | re.MULTILINE, +) +_FRONT_MATTER_RE = re.compile(r"\A---\s*\n.*?\n---\s*\n?", re.DOTALL) + + +def _safe_session_id(session: Any) -> str: + raw = str(getattr(session, "session_id", "") or "unknown-session") + safe = re.sub(r"[^A-Za-z0-9._-]+", "-", raw).strip("-") + return safe or "unknown-session" + + +def session_artifact_date(session: Any) -> str: + """Return the YYYY-MM-DD partition date for a session.""" + raw = getattr(session, "session_start_time", None) + if raw: + try: + return datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime( + "%Y-%m-%d" + ) + except ValueError: + logger.debug("Could not parse session_start_time=%r", raw) + return datetime.utcnow().strftime("%Y-%m-%d") + + +def _collection_session_id_fragment(session: Any) -> str: + safe_id = _safe_session_id(session) + if _UUID_SESSION_ID_RE.match(safe_id): + return safe_id[:8] + stem = f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" + max_id_length = max(1, _COLLECTION_TITLE_MAX_LENGTH - len(stem)) + if len(safe_id) <= max_id_length: + return safe_id + return safe_id[:max_id_length].rstrip("-._") or safe_id[:max_id_length] + + +def artifact_collection_title(session: Any) -> str: + return ( + f"{_COLLECTION_TITLE_PREFIX}-{session_artifact_date(session)}-" + f"{_collection_session_id_fragment(session)}" + ) + + +def _artifact_key(repo_id: str, repo_type: str | None) -> str: + return f"{repo_type or 'model'}:{repo_id}" + + +def _session_artifact_set(session: Any, attr: str) -> set[str]: + current = getattr(session, attr, None) + if isinstance(current, set): + return current + current = set() + try: + setattr(session, attr, current) + except Exception: + logger.warning( + "Could not attach %s to session; using process-local fallback state", + attr, + ) + return _SESSION_ARTIFACT_SET_FALLBACK.setdefault((id(session), attr), set()) + return current + + +def remember_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> None: + if session is None or not repo_id: + return + _session_artifact_set(session, _KNOWN_ARTIFACTS_ATTR).add( + _artifact_key(repo_id, repo_type) + ) + + +def is_known_hub_artifact(session: Any, repo_id: str, repo_type: str | None) -> bool: + if session is None or not repo_id: + return False + return _artifact_key(repo_id, repo_type) in _session_artifact_set( + session, _KNOWN_ARTIFACTS_ATTR + ) + + +def _merge_tags(metadata: dict[str, Any], tag: str = ML_INTERN_TAG) -> dict[str, Any]: + merged = dict(metadata) + raw_tags = merged.get("tags") + if raw_tags is None: + tags: list[str] = [] + elif isinstance(raw_tags, str): + tags = [raw_tags] + elif isinstance(raw_tags, list): + tags = [str(item) for item in raw_tags] + else: + tags = [str(raw_tags)] + + if tag not in tags: + tags.append(tag) + merged["tags"] = tags + return merged + + +def _metadata_from_content(content: str) -> dict[str, Any]: + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) / "README.md" + path.write_text(content, encoding="utf-8") + return metadata_load(path) or {} + + +def _content_with_metadata(content: str, metadata: dict[str, Any]) -> str: + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) / "README.md" + path.write_text(content, encoding="utf-8") + metadata_save(path, metadata) + return path.read_text(encoding="utf-8") + + +def _body_without_metadata(content: str) -> str: + return _FRONT_MATTER_RE.sub("", content, count=1).strip() + + +def _append_section(content: str, section: str) -> str: + base = content.rstrip() + if base: + return f"{base}\n\n{section.strip()}\n" + return f"{section.strip()}\n" + + +def _provenance_section(repo_type: str) -> str: + label = {"model": "model", "dataset": "dataset"}.get(repo_type, "Hub") + return f"""{PROVENANCE_MARKER} +## Generated by ML Intern + +This {label} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub. + +- Try ML Intern: https://smolagents-ml-intern.hf.space +- Source code: https://github.com/huggingface/ml-intern +""" + + +def _usage_section(repo_id: str, repo_type: str) -> str: + if repo_type == "dataset": + return f"""## Usage + +```python +from datasets import load_dataset + +dataset = load_dataset("{repo_id}") +``` +""" + + return f"""## Usage + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "{repo_id}" +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id) +``` + +For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class. +""" + + +def augment_repo_card_content( + content: str | None, + repo_id: str, + repo_type: str = "model", + *, + extra_metadata: dict[str, Any] | None = None, +) -> str: + """Return README content with ML Intern metadata and provenance added.""" + repo_type = repo_type or "model" + content = content or "" + metadata = _metadata_from_content(content) + if extra_metadata: + metadata = {**extra_metadata, **metadata} + metadata = _merge_tags(metadata) + updated = _content_with_metadata(content, metadata) + + if not _body_without_metadata(updated): + updated = _append_section(updated, f"# {repo_id}") + + if repo_type in {"model", "dataset"} and PROVENANCE_MARKER not in updated: + updated = _append_section(updated, _provenance_section(repo_type)) + if not _USAGE_HEADING_RE.search(content): + updated = _append_section(updated, _usage_section(repo_id, repo_type)) + + return updated + + +def _read_remote_readme( + api: Any, + repo_id: str, + repo_type: str, + *, + token: str | bool | None = None, +) -> str: + token_value = token if token is not None else getattr(api, "token", None) + try: + readme_path = hf_hub_download( + repo_id=repo_id, + filename="README.md", + repo_type=repo_type, + token=token_value, + ) + except (EntryNotFoundError, RepositoryNotFoundError): + return "" + return Path(readme_path).read_text(encoding="utf-8") + + +def _update_repo_card( + api: Any, + repo_id: str, + repo_type: str, + *, + token: str | bool | None = None, + extra_metadata: dict[str, Any] | None = None, +) -> None: + current = _read_remote_readme(api, repo_id, repo_type, token=token) + updated = augment_repo_card_content( + current, + repo_id, + repo_type, + extra_metadata=extra_metadata, + ) + if updated == current: + return + api.upload_file( + path_or_fileobj=updated.encode("utf-8"), + path_in_repo="README.md", + repo_id=repo_id, + repo_type=repo_type, + token=token, + commit_message="Update ML Intern artifact metadata", + ) + + +def _ensure_collection_slug( + api: Any, + session: Any, + *, + token: str | bool | None = None, +) -> str | None: + slug = getattr(session, _COLLECTION_SLUG_ATTR, None) + if slug: + return slug + + title = artifact_collection_title(session) + collection = api.create_collection( + title=title, + description=( + f"Artifacts generated by ML Intern session {_safe_session_id(session)} " + f"on {session_artifact_date(session)}." + ), + private=True, + exists_ok=True, + token=token, + ) + slug = getattr(collection, "slug", None) + if slug: + setattr(session, _COLLECTION_SLUG_ATTR, slug) + return slug + + +async def ensure_session_artifact_collection( + session: Any, + *, + token: str | bool | None = None, +) -> str | None: + """Create/cache the per-session artifact collection without raising.""" + if session is None or not getattr(session, "session_id", None): + return None + token_value = token if token is not None else getattr(session, "hf_token", None) + if not token_value: + return None + + try: + api = HfApi(token=token_value) + return await asyncio.to_thread( + _ensure_collection_slug, + api, + session, + token=token_value, + ) + except Exception as e: + logger.warning( + "ML Intern session collection creation failed for %s: %s", + _safe_session_id(session), + e, + ) + return None + + +def start_session_artifact_collection_task( + session: Any, + *, + token: str | bool | None = None, +) -> asyncio.Task | None: + """Schedule best-effort collection creation for a newly started session.""" + if session is None or not getattr(session, "session_id", None): + return None + if getattr(session, _COLLECTION_SLUG_ATTR, None): + return None + + token_value = token if token is not None else getattr(session, "hf_token", None) + if not token_value: + return None + + existing = getattr(session, _COLLECTION_TASK_ATTR, None) + if isinstance(existing, asyncio.Task) and not existing.done(): + return existing + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return None + + async def _run() -> None: + await ensure_session_artifact_collection(session, token=token_value) + + task = loop.create_task(_run()) + try: + setattr(session, _COLLECTION_TASK_ATTR, task) + except Exception: + logger.debug("Could not attach ML Intern collection task to session") + return task + + +def _add_to_collection( + api: Any, + session: Any, + repo_id: str, + repo_type: str, + *, + token: str | bool | None = None, +) -> None: + slug = _ensure_collection_slug(api, session, token=token) + if not slug: + return + api.add_collection_item( + collection_slug=slug, + item_id=repo_id, + item_type=repo_type, + note=( + f"Generated by ML Intern session {_safe_session_id(session)} " + f"on {session_artifact_date(session)}." + ), + exists_ok=True, + token=token, + ) + + +def register_hub_artifact( + api: Any, + repo_id: str, + repo_type: str = "model", + *, + session: Any = None, + token: str | bool | None = None, + extra_metadata: dict[str, Any] | None = None, + force: bool = False, +) -> bool: + """Tag, card, and collection-register a Hub artifact without raising.""" + if session is None or not repo_id: + return False + repo_type = repo_type or "model" + if repo_type not in SUPPORTED_REPO_TYPES: + return False + + key = _artifact_key(repo_id, repo_type) + remember_hub_artifact(session, repo_id, repo_type) + registered = _session_artifact_set(session, _REGISTERED_ARTIFACTS_ATTR) + if key in registered and not force: + return True + + token_value = token if token is not None else getattr(api, "token", None) + card_updated = False + collection_updated = False + try: + _update_repo_card( + api, + repo_id, + repo_type, + token=token_value, + extra_metadata=extra_metadata, + ) + card_updated = True + except Exception as e: + logger.debug("ML Intern repo-card update failed for %s: %s", repo_id, e) + + try: + _add_to_collection(api, session, repo_id, repo_type, token=token_value) + collection_updated = True + except Exception as e: + logger.debug("ML Intern collection update failed for %s: %s", repo_id, e) + + if card_updated and collection_updated: + registered.add(key) + return True + return False + + +def build_hub_artifact_sitecustomize(session: Any) -> str: + """Build standalone sitecustomize.py code for HF Jobs Python processes.""" + if session is None or not getattr(session, "session_id", None): + return "" + + session_id = _safe_session_id(session) + session_date = session_artifact_date(session) + collection_title = artifact_collection_title(session) + collection_slug = getattr(session, _COLLECTION_SLUG_ATTR, None) + + return ( + textwrap.dedent( + f""" + # Auto-generated by ML Intern. Best-effort Hub artifact metadata only. + def _install_ml_intern_artifact_hooks(): + import os + import re + import tempfile + from pathlib import Path + + try: + import huggingface_hub as _hub + from huggingface_hub import HfApi, hf_hub_download + from huggingface_hub.repocard import metadata_load, metadata_save + from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + except Exception: + return + + session_id = {session_id!r} + session_date = {session_date!r} + collection_title = {collection_title!r} + tag = {ML_INTERN_TAG!r} + marker = {PROVENANCE_MARKER!r} + supported = {sorted(SUPPORTED_REPO_TYPES)!r} + registering = False + collection_slug = {collection_slug!r} + registered = set() + usage_re = re.compile( + r"^#{{2,6}}\\s+(usage|how to use|using this (model|dataset)|use this (model|dataset))\\b", + re.IGNORECASE | re.MULTILINE, + ) + front_matter_re = re.compile(r"\\A---\\s*\\n.*?\\n---\\s*\\n?", re.DOTALL) + + def _token(value=None, api=None): + if isinstance(value, str) and value: + return value + api_token = getattr(api, "token", None) + if isinstance(api_token, str) and api_token: + return api_token + return ( + os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_HUB_TOKEN") + or None + ) + + def _merge_tags(metadata): + metadata = dict(metadata or {{}}) + raw_tags = metadata.get("tags") + if raw_tags is None: + tags = [] + elif isinstance(raw_tags, str): + tags = [raw_tags] + elif isinstance(raw_tags, list): + tags = [str(item) for item in raw_tags] + else: + tags = [str(raw_tags)] + if tag not in tags: + tags.append(tag) + metadata["tags"] = tags + return metadata + + def _metadata_from_content(content): + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) / "README.md" + path.write_text(content or "", encoding="utf-8") + return metadata_load(path) or {{}} + + def _content_with_metadata(content, metadata): + with tempfile.TemporaryDirectory() as tmp_dir: + path = Path(tmp_dir) / "README.md" + path.write_text(content or "", encoding="utf-8") + metadata_save(path, metadata) + return path.read_text(encoding="utf-8") + + def _body_without_metadata(content): + return front_matter_re.sub("", content or "", count=1).strip() + + def _append_section(content, section): + base = (content or "").rstrip() + if base: + return base + "\\n\\n" + section.strip() + "\\n" + return section.strip() + "\\n" + + def _provenance(repo_type): + label = {{"model": "model", "dataset": "dataset"}}.get( + repo_type, "Hub" + ) + return ( + marker + + "\\n## Generated by ML Intern\\n\\n" + + f"This {{label}} repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.\\n\\n" + + "- Try ML Intern: https://smolagents-ml-intern.hf.space\\n" + + "- Source code: https://github.com/huggingface/ml-intern\\n" + ) + + def _usage(repo_id, repo_type): + if repo_type == "dataset": + return ( + "## Usage\\n\\n" + "```python\\n" + "from datasets import load_dataset\\n\\n" + f"dataset = load_dataset({{repo_id!r}})\\n" + "```\\n" + ) + return ( + "## Usage\\n\\n" + "```python\\n" + "from transformers import AutoModelForCausalLM, AutoTokenizer\\n\\n" + f"model_id = {{repo_id!r}}\\n" + "tokenizer = AutoTokenizer.from_pretrained(model_id)\\n" + "model = AutoModelForCausalLM.from_pretrained(model_id)\\n" + "```\\n\\n" + "For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.\\n" + ) + + def _augment(content, repo_id, repo_type, extra_metadata=None): + metadata = _metadata_from_content(content or "") + if extra_metadata: + metadata = {{**extra_metadata, **metadata}} + updated = _content_with_metadata(content or "", _merge_tags(metadata)) + if not _body_without_metadata(updated): + updated = _append_section(updated, f"# {{repo_id}}") + if repo_type in {{"model", "dataset"}} and marker not in updated: + updated = _append_section(updated, _provenance(repo_type)) + if not usage_re.search(content or ""): + updated = _append_section(updated, _usage(repo_id, repo_type)) + return updated + + def _readme(api, repo_id, repo_type, token_value): + try: + path = hf_hub_download( + repo_id=repo_id, + filename="README.md", + repo_type=repo_type, + token=token_value, + ) + except (EntryNotFoundError, RepositoryNotFoundError): + return "" + return Path(path).read_text(encoding="utf-8") + + def _ensure_collection(api, token_value): + nonlocal collection_slug + if collection_slug: + return collection_slug + collection = api.create_collection( + title=collection_title, + description=( + f"Artifacts generated by ML Intern session {{session_id}} " + f"on {{session_date}}." + ), + private=True, + exists_ok=True, + token=token_value, + ) + collection_slug = getattr(collection, "slug", None) + return collection_slug + + def _register( + repo_id, + repo_type="model", + token_value=None, + extra_metadata=None, + force=False, + ): + nonlocal registering + if registering or not repo_id: + return + repo_type = repo_type or "model" + if repo_type not in supported: + return + key = f"{{repo_type}}:{{repo_id}}" + if key in registered and not force: + return + registering = True + try: + token_value = _token(token_value) + api = HfApi(token=token_value) + try: + current = _readme(api, repo_id, repo_type, token_value) + updated = _augment( + current, repo_id, repo_type, extra_metadata=extra_metadata + ) + if updated != current: + _original_upload_file( + api, + path_or_fileobj=updated.encode("utf-8"), + path_in_repo="README.md", + repo_id=repo_id, + repo_type=repo_type, + token=token_value, + commit_message="Update ML Intern artifact metadata", + ) + except Exception: + pass + try: + slug = _ensure_collection(api, token_value) + if slug: + api.add_collection_item( + collection_slug=slug, + item_id=repo_id, + item_type=repo_type, + note=( + f"Generated by ML Intern session {{session_id}} " + f"on {{session_date}}." + ), + exists_ok=True, + token=token_value, + ) + except Exception: + pass + registered.add(key) + finally: + registering = False + + _original_create_repo = HfApi.create_repo + _original_upload_file = HfApi.upload_file + _original_upload_folder = getattr(HfApi, "upload_folder", None) + _original_create_commit = getattr(HfApi, "create_commit", None) + + def _repo_id(args, kwargs): + return kwargs.get("repo_id") or (args[0] if args else None) + + def _repo_type(kwargs): + return kwargs.get("repo_type") or "model" + + def _patched_create_repo(self, *args, **kwargs): + result = _original_create_repo(self, *args, **kwargs) + repo_id = _repo_id(args, kwargs) + repo_type = _repo_type(kwargs) + extra = None + if repo_type == "space" and kwargs.get("space_sdk"): + extra = {{"sdk": kwargs.get("space_sdk")}} + _register(repo_id, repo_type, _token(kwargs.get("token"), self), extra) + return result + + def _patched_upload_file(self, *args, **kwargs): + result = _original_upload_file(self, *args, **kwargs) + if not kwargs.get("create_pr"): + force = kwargs.get("path_in_repo") == "README.md" + _register( + kwargs.get("repo_id"), + _repo_type(kwargs), + _token(kwargs.get("token"), self), + force=force, + ) + return result + + def _patched_upload_folder(self, *args, **kwargs): + result = _original_upload_folder(self, *args, **kwargs) + if not kwargs.get("create_pr"): + _register( + kwargs.get("repo_id"), + _repo_type(kwargs), + _token(kwargs.get("token"), self), + force=True, + ) + return result + + def _patched_create_commit(self, *args, **kwargs): + result = _original_create_commit(self, *args, **kwargs) + if not kwargs.get("create_pr"): + _register( + _repo_id(args, kwargs), + _repo_type(kwargs), + _token(kwargs.get("token"), self), + force=True, + ) + return result + + HfApi.create_repo = _patched_create_repo + HfApi.upload_file = _patched_upload_file + if _original_upload_folder is not None: + HfApi.upload_folder = _patched_upload_folder + if _original_create_commit is not None: + HfApi.create_commit = _patched_create_commit + + def _patch_module_func(name, method_name): + original = getattr(_hub, name, None) + if original is None: + return + method = getattr(HfApi, method_name) + + def _patched(*args, **kwargs): + api = HfApi(token=_token(kwargs.get("token"))) + return method(api, *args, **kwargs) + + setattr(_hub, name, _patched) + + _patch_module_func("create_repo", "create_repo") + _patch_module_func("upload_file", "upload_file") + if _original_upload_folder is not None: + _patch_module_func("upload_folder", "upload_folder") + if _original_create_commit is not None: + _patch_module_func("create_commit", "create_commit") + + try: + _install_ml_intern_artifact_hooks() + except Exception: + pass + """ + ).strip() + + "\n" + ) + + +def wrap_shell_command_with_hub_artifact_bootstrap( + command: str, + session: Any, +) -> str: + """Prefix a shell command so child Python processes load Hub hooks.""" + sitecustomize = build_hub_artifact_sitecustomize(session) + if not sitecustomize or not command: + return command + + encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") + bootstrap = ( + '_ml_intern_artifacts_dir="$(mktemp -d 2>/dev/null)" ' + f"&& printf %s {shlex.quote(encoded)} | base64 -d " + '> "$_ml_intern_artifacts_dir/sitecustomize.py" ' + '&& export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"' + ) + return f"{bootstrap}; {command}" diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index 880886b3e1e2919f31d35934c6f9a4c3fb5e9525..028dd6df0f01ebc9be7f49cf1cd4e82bdf25c20f 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -56,9 +56,16 @@ def _patch_litellm_effort_validation() -> None: # to return True for families where "max" / "xhigh" are acceptable # at the API; the cascade handles the case when they're not. return any( - v in m for v in ( - "opus-4-6", "opus_4_6", "opus-4.6", "opus_4.6", - "opus-4-7", "opus_4_7", "opus-4.7", "opus_4.7", + v in m + for v in ( + "opus-4-6", + "opus_4_6", + "opus-4.6", + "opus_4.6", + "opus-4-7", + "opus_4_7", + "opus-4.7", + "opus_4.7", ) ) diff --git a/agent/core/model_switcher.py b/agent/core/model_switcher.py index ea419db1b294db5d32be5aeeace9418d172a2617..14b5233dc525e3c418f75412bcba92bed15d4bcd 100644 --- a/agent/core/model_switcher.py +++ b/agent/core/model_switcher.py @@ -28,7 +28,10 @@ SUGGESTED_MODELS = [ {"id": "openai/gpt-5.4", "label": "GPT-5.4"}, {"id": "anthropic/claude-opus-4-7", "label": "Claude Opus 4.7"}, {"id": "anthropic/claude-opus-4-6", "label": "Claude Opus 4.6"}, - {"id": "bedrock/us.anthropic.claude-opus-4-6-v1", "label": "Claude Opus 4.6 via Bedrock"}, + { + "id": "bedrock/us.anthropic.claude-opus-4-6-v1", + "label": "Claude Opus 4.6 via Bedrock", + }, {"id": "MiniMaxAI/MiniMax-M2.7", "label": "MiniMax M2.7"}, {"id": "moonshotai/Kimi-K2.6", "label": "Kimi K2.6"}, {"id": "zai-org/GLM-5.1", "label": "GLM 5.1"}, @@ -122,9 +125,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool: ) ctx = f"{p.context_length:,} ctx" if p.context_length else "ctx n/a" tools = "tools" if p.supports_tools else "no tools" - console.print( - f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]" - ) + console.print(f" [dim]{p.provider}: {price}, {ctx}, {tools}[/dim]") return True @@ -183,7 +184,9 @@ async def probe_and_switch_model( # Nothing to validate with a ping that we couldn't validate on the # first real call just as cheaply. Skip the probe entirely. _commit_switch(model_id, config, session, effective=None, cache=False) - console.print(f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]") + console.print( + f"[green]Model switched to {model_id}[/green] [dim](effort: off)[/dim]" + ) return console.print(f"[dim]checking {model_id} (effort: {preference})...[/dim]") @@ -203,8 +206,11 @@ async def probe_and_switch_model( return _commit_switch( - model_id, config, session, - effective=outcome.effective_effort, cache=True, + model_id, + config, + session, + effective=outcome.effective_effort, + cache=True, ) effort_label = outcome.effective_effort or "off" suffix = f" — {outcome.note}" if outcome.note else "" diff --git a/agent/core/prompt_caching.py b/agent/core/prompt_caching.py index 56685a5ef64020be59c3a402750375745406221a..b30edd9fc4845738c08e972fdab712bf2ae3988d 100644 --- a/agent/core/prompt_caching.py +++ b/agent/core/prompt_caching.py @@ -40,7 +40,11 @@ def with_prompt_caching( if messages: first = messages[0] - role = first.get("role") if isinstance(first, dict) else getattr(first, "role", None) + role = ( + first.get("role") + if isinstance(first, dict) + else getattr(first, "role", None) + ) if role == "system": content = ( first.get("content") @@ -48,11 +52,13 @@ def with_prompt_caching( else getattr(first, "content", None) ) if isinstance(content, str) and content: - cached_block = [{ - "type": "text", - "text": content, - "cache_control": {"type": "ephemeral"}, - }] + cached_block = [ + { + "type": "text", + "text": content, + "cache_control": {"type": "ephemeral"}, + } + ] new_first = {"role": "system", "content": cached_block} messages = [new_first] + list(messages[1:]) diff --git a/agent/core/session.py b/agent/core/session.py index 370bb3a6383e2d5e6626e3f55de5efe2e8b4503d..fb08c75f8a4d5aff0e86cdfb7bb276d585b93176 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -48,7 +48,8 @@ def _get_max_tokens_safe(model_name: str) -> int: continue logger.info( "No litellm.get_model_info entry for %s, falling back to %d", - model_name, _DEFAULT_MAX_TOKENS, + model_name, + _DEFAULT_MAX_TOKENS, ) return _DEFAULT_MAX_TOKENS @@ -277,8 +278,7 @@ class Session: if summary: summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS] message = ( - f"Session {self.session_id} completed successfully.\n" - f"{summary}" + f"Session {self.session_id} completed successfully.\n{summary}" ) else: message = f"Session {self.session_id} completed successfully." @@ -444,6 +444,7 @@ class Session: # snapshot between heartbeats would otherwise leak them. try: from agent.core.redact import scrub + for key in ("messages", "events", "tools"): if key in trajectory: trajectory[key] = scrub(trajectory[key]) diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py index f2c2d3674593eadece69ab1268ff5fa8f9df9078..e12467211b16fe12ec75fbf5b60edb5ee54f4072 100644 --- a/agent/core/session_persistence.py +++ b/agent/core/session_persistence.py @@ -271,7 +271,9 @@ class MongoSessionStore(NoopSessionStore): upsert=True, ) ) - ops.append(DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}})) + ops.append( + DeleteMany({"session_id": session_id, "idx": {"$gte": len(messages)}}) + ) try: if ops: await self.db.session_messages.bulk_write(ops, ordered=False) @@ -288,7 +290,9 @@ class MongoSessionStore(NoopSessionStore): return None if meta.get("visibility") == "deleted" and not include_deleted: return None - cursor = self.db.session_messages.find({"session_id": session_id}).sort("idx", 1) + cursor = self.db.session_messages.find({"session_id": session_id}).sort( + "idx", 1 + ) messages = [row.get("message") async for row in cursor] return {"metadata": meta, "messages": messages} @@ -356,7 +360,9 @@ class MongoSessionStore(NoopSessionStore): logger.debug("Failed to append event for %s: %s", session_id, e) return None - async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[dict[str, Any]]: + async def load_events_after( + self, session_id: str, after_seq: int = 0 + ) -> list[dict[str, Any]]: if not self._ready(): return [] cursor = self.db.session_events.find( @@ -496,6 +502,8 @@ def get_session_store() -> NoopSessionStore | MongoSessionStore: return _store -def _reset_store_for_tests(store: NoopSessionStore | MongoSessionStore | None = None) -> None: +def _reset_store_for_tests( + store: NoopSessionStore | MongoSessionStore | None = None, +) -> None: global _store _store = store diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py index 035a235db4aec3a46b9d421c977983880dc8a1d3..404fd224563cdae3d91c2b93e05e8306ee91fb7e 100644 --- a/agent/core/session_uploader.py +++ b/agent/core/session_uploader.py @@ -94,8 +94,7 @@ def _msg_uuid(session_id: str, role: str, idx: int) -> str: digest = hashlib.sha1(f"{session_id}::{role}::{idx}".encode("utf-8")).hexdigest() # Format like a UUID for visual familiarity (32 hex chars w/ dashes). return ( - f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-" - f"{digest[16:20]}-{digest[20:32]}" + f"{digest[0:8]}-{digest[8:12]}-{digest[12:16]}-{digest[16:20]}-{digest[20:32]}" ) @@ -347,7 +346,7 @@ def _update_upload_status( def dataset_card_readme(repo_id: str) -> str: """Dataset card for personal ML Intern session trace repos.""" - return f"""--- + return """--- pretty_name: "ML Intern Session Traces" language: - en diff --git a/agent/core/telemetry.py b/agent/core/telemetry.py index 6de45a96b10325b5ea3dfea9d6cf3ed9458d878c..38d2bbe761fee99d7c8051d6788fc849df8a8fae 100644 --- a/agent/core/telemetry.py +++ b/agent/core/telemetry.py @@ -26,6 +26,7 @@ logger = logging.getLogger(__name__) # ── usage extraction ──────────────────────────────────────────────────────── + def extract_usage(response_or_chunk: Any) -> dict: """Flat usage dict from a litellm response or final-chunk usage object. @@ -71,6 +72,7 @@ def extract_usage(response_or_chunk: Any) -> dict: # ── llm_call ──────────────────────────────────────────────────────────────── + async def record_llm_call( session: Any, *, @@ -106,22 +108,26 @@ async def record_llm_call( if response is not None: try: from litellm import completion_cost + cost_usd = float(completion_cost(completion_response=response) or 0.0) except Exception: cost_usd = 0.0 from agent.core.session import Event # local import to avoid cycle + try: - await session.send_event(Event( - event_type="llm_call", - data={ - "model": model, - "latency_ms": latency_ms, - "finish_reason": finish_reason, - "cost_usd": cost_usd, - "kind": kind, - **usage, - }, - )) + await session.send_event( + Event( + event_type="llm_call", + data={ + "model": model, + "latency_ms": latency_ms, + "finish_reason": finish_reason, + "cost_usd": cost_usd, + "kind": kind, + **usage, + }, + ) + ) except Exception as e: logger.debug("record_llm_call failed (non-fatal): %s", e) return usage @@ -129,6 +135,7 @@ async def record_llm_call( # ── hf_jobs ──────────────────────────────────────────────────────────────── + def _infer_push_to_hub(script_or_cmd: Any) -> bool: if not isinstance(script_or_cmd, str): return False @@ -150,22 +157,25 @@ async def record_hf_job_submit( """Emit ``hf_job_submit``. Returns the monotonic start timestamp so the caller can pass it back into :func:`record_hf_job_complete`.""" from agent.core.session import Event + t_start = time.monotonic() try: script_text = args.get("script") or args.get("command") or "" - await session.send_event(Event( - event_type="hf_job_submit", - data={ - "job_id": getattr(job, "id", None), - "job_url": getattr(job, "url", None), - "flavor": args.get("hardware_flavor", "cpu-basic"), - "timeout": args.get("timeout", "30m"), - "job_type": job_type, - "image": image, - "namespace": args.get("namespace"), - "push_to_hub": _infer_push_to_hub(script_text), - }, - )) + await session.send_event( + Event( + event_type="hf_job_submit", + data={ + "job_id": getattr(job, "id", None), + "job_url": getattr(job, "url", None), + "flavor": args.get("hardware_flavor", "cpu-basic"), + "timeout": args.get("timeout", "30m"), + "job_type": job_type, + "image": image, + "namespace": args.get("namespace"), + "push_to_hub": _infer_push_to_hub(script_text), + }, + ) + ) except Exception as e: logger.debug("record_hf_job_submit failed (non-fatal): %s", e) return t_start @@ -180,23 +190,27 @@ async def record_hf_job_complete( submit_ts: float, ) -> None: from agent.core.session import Event + try: wall_time_s = int(time.monotonic() - submit_ts) - await session.send_event(Event( - event_type="hf_job_complete", - data={ - "job_id": getattr(job, "id", None), - "flavor": flavor, - "final_status": final_status, - "wall_time_s": wall_time_s, - }, - )) + await session.send_event( + Event( + event_type="hf_job_complete", + data={ + "job_id": getattr(job, "id", None), + "flavor": flavor, + "final_status": final_status, + "wall_time_s": wall_time_s, + }, + ) + ) except Exception as e: logger.debug("record_hf_job_complete failed (non-fatal): %s", e) # ── sandbox ───────────────────────────────────────────────────────────────── + async def record_sandbox_create( session: Any, sandbox: Any, @@ -205,39 +219,46 @@ async def record_sandbox_create( create_latency_s: int, ) -> None: from agent.core.session import Event + try: # Pin created-at on the session so record_sandbox_destroy can diff. session._sandbox_created_at = time.monotonic() - create_latency_s - await session.send_event(Event( - event_type="sandbox_create", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "hardware": hardware, - "create_latency_s": int(create_latency_s), - }, - )) + await session.send_event( + Event( + event_type="sandbox_create", + data={ + "sandbox_id": getattr(sandbox, "space_id", None), + "hardware": hardware, + "create_latency_s": int(create_latency_s), + }, + ) + ) except Exception as e: logger.debug("record_sandbox_create failed (non-fatal): %s", e) async def record_sandbox_destroy(session: Any, sandbox: Any) -> None: from agent.core.session import Event + try: created = getattr(session, "_sandbox_created_at", None) lifetime_s = int(time.monotonic() - created) if created else None - await session.send_event(Event( - event_type="sandbox_destroy", - data={ - "sandbox_id": getattr(sandbox, "space_id", None), - "lifetime_s": lifetime_s, - }, - )) + await session.send_event( + Event( + event_type="sandbox_destroy", + data={ + "sandbox_id": getattr(sandbox, "space_id", None), + "lifetime_s": lifetime_s, + }, + ) + ) except Exception as e: logger.debug("record_sandbox_destroy failed (non-fatal): %s", e) # ── feedback ─────────────────────────────────────────────────────────────── + async def record_feedback( session: Any, *, @@ -247,16 +268,19 @@ async def record_feedback( comment: str | None = None, ) -> None: from agent.core.session import Event + try: - await session.send_event(Event( - event_type="feedback", - data={ - "rating": rating, - "turn_index": turn_index, - "message_id": message_id, - "comment": (comment or "")[:500], - }, - )) + await session.send_event( + Event( + event_type="feedback", + data={ + "rating": rating, + "turn_index": turn_index, + "message_id": message_id, + "comment": (comment or "")[:500], + }, + ) + ) except Exception as e: logger.debug("record_feedback failed (non-fatal): %s", e) @@ -269,15 +293,18 @@ async def record_jobs_access_blocked( eligible_namespaces: list[str], ) -> None: from agent.core.session import Event + try: - await session.send_event(Event( - event_type="jobs_access_blocked", - data={ - "tool_call_ids": tool_call_ids, - "plan": plan, - "eligible_namespaces": eligible_namespaces, - }, - )) + await session.send_event( + Event( + event_type="jobs_access_blocked", + data={ + "tool_call_ids": tool_call_ids, + "plan": plan, + "eligible_namespaces": eligible_namespaces, + }, + ) + ) except Exception as e: logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e) @@ -289,11 +316,14 @@ async def record_pro_cta_click( target: str = "pro_pricing", ) -> None: from agent.core.session import Event + try: - await session.send_event(Event( - event_type="pro_cta_click", - data={"source": source, "target": target}, - )) + await session.send_event( + Event( + event_type="pro_cta_click", + data={"source": source, "target": target}, + ) + ) except Exception as e: logger.debug("record_pro_cta_click failed (non-fatal): %s", e) @@ -308,11 +338,14 @@ async def record_pro_conversion( ``MongoSessionStore.mark_pro_seen``; fired into the user's first Pro session so the rollup picks it up alongside other event-driven KPIs.""" from agent.core.session import Event + try: - await session.send_event(Event( - event_type="pro_conversion", - data={"first_seen_at": first_seen_at}, - )) + await session.send_event( + Event( + event_type="pro_conversion", + data={"first_seen_at": first_seen_at}, + ) + ) except Exception as e: logger.debug("record_pro_conversion failed (non-fatal): %s", e) @@ -327,11 +360,14 @@ async def record_credits_topped_up( came back from the HF billing top-up flow and unblocked themselves. Caller is responsible for firing this at most once per session.""" from agent.core.session import Event + try: - await session.send_event(Event( - event_type="credits_topped_up", - data={"namespace": namespace}, - )) + await session.send_event( + Event( + event_type="credits_topped_up", + data={"namespace": namespace}, + ) + ) except Exception as e: logger.debug("record_credits_topped_up failed (non-fatal): %s", e) diff --git a/agent/core/tools.py b/agent/core/tools.py index ef2c57bc19478043996597083cba54a243cdf4cc..1b750671605143958f1193c38ef7c1ee083a3cdc 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -8,8 +8,6 @@ import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional -logger = logging.getLogger(__name__) - from fastmcp import Client from fastmcp.exceptions import ToolError from mcp.types import EmbeddedResource, ImageContent, TextContent @@ -64,6 +62,8 @@ warnings.filterwarnings( "ignore", category=DeprecationWarning, module="aiohttp.connector" ) +logger = logging.getLogger(__name__) + NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"] @@ -131,7 +131,12 @@ class ToolRouter: Based on codex-rs/core/src/tools/router.rs """ - def __init__(self, mcp_servers: dict[str, MCPServerConfig], hf_token: str | None = None, local_mode: bool = False): + def __init__( + self, + mcp_servers: dict[str, MCPServerConfig], + hf_token: str | None = None, + local_mode: bool = False, + ): self.tools: dict[str, ToolSpec] = {} self.mcp_servers: dict[str, dict[str, Any]] = {} @@ -144,7 +149,9 @@ class ToolRouter: for name, server in mcp_servers.items(): data = server.model_dump() if hf_token: - data.setdefault("headers", {})["Authorization"] = f"Bearer {hf_token}" + data.setdefault("headers", {})["Authorization"] = ( + f"Bearer {hf_token}" + ) mcp_servers_payload[name] = data self.mcp_client = Client({"mcpServers": mcp_servers_payload}) self._mcp_initialized = False @@ -218,7 +225,9 @@ class ToolRouter: await self.register_mcp_tools() self._mcp_initialized = True except Exception as e: - logger.warning("MCP connection failed, continuing without MCP tools: %s", e) + logger.warning( + "MCP connection failed, continuing without MCP tools: %s", e + ) self.mcp_client = None await self.register_openapi_tool() @@ -380,6 +389,7 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: # Sandbox or local tools (highest priority) if local_mode: from agent.tools.local_tools import get_local_tools + tools = get_local_tools() + tools else: tools = get_sandbox_tools() + tools diff --git a/agent/main.py b/agent/main.py index 606aaf8e1afc3347f1780f273129cac2b01b1977..1fadcbdae0a2a6ff03b8aba424243d78bd11c8af 100644 --- a/agent/main.py +++ b/agent/main.py @@ -77,6 +77,7 @@ def _configure_runtime_logging() -> None: logging.getLogger("LiteLLM").setLevel(logging.ERROR) logging.getLogger("litellm").setLevel(logging.ERROR) + def _safe_get_args(arguments: dict) -> dict: """Safely extract args dict from arguments, handling cases where LLM passes string.""" args = arguments.get("args", {}) @@ -92,6 +93,7 @@ def _get_hf_user(token: str | None) -> str | None: return None try: from huggingface_hub import HfApi + return HfApi(token=token).whoami().get("name") except Exception: return None @@ -134,10 +136,13 @@ async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: login(token=token, add_to_git_credential=False) print("Token saved to ~/.cache/huggingface/token") except Exception as e: - print(f"Warning: could not persist token ({e}), using for this session only.") + print( + f"Warning: could not persist token ({e}), using for this session only." + ) return token + @dataclass class Operation: """Operation to be executed by the agent""" @@ -162,9 +167,9 @@ def _create_rich_console(): class _ThinkingShimmer: """Animated shiny/shimmer thinking indicator — a bright gradient sweeps across the text.""" - _BASE = (90, 90, 110) # dim base color - _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) - _WIDTH = 5 # shimmer width in characters + _BASE = (90, 90, 110) # dim base color + _HIGHLIGHT = (255, 200, 80) # bright shimmer highlight (warm gold) + _WIDTH = 5 # shimmer width in characters _FPS = 24 def __init__(self, console): @@ -245,7 +250,7 @@ class _StreamBuffer: if idx == -1: return None block = self._buffer[:idx] - self._buffer = self._buffer[idx + 2:] + self._buffer = self._buffer[idx + 2 :] return block async def flush_ready( @@ -271,7 +276,9 @@ class _StreamBuffer: """Flush complete blocks, then render whatever incomplete tail remains.""" await self.flush_ready(cancel_event=cancel_event, instant=instant) if self._buffer.strip(): - await print_markdown(self._buffer, cancel_event=cancel_event, instant=instant) + await print_markdown( + self._buffer, cancel_event=cancel_event, instant=instant + ) self._buffer = "" def discard(self): @@ -372,7 +379,11 @@ async def event_listener( elif event.event_type == "error": shimmer.stop() stream_buf.discard() - error = event.data.get("error", "Unknown error") if event.data else "Unknown error" + error = ( + event.data.get("error", "Unknown error") + if event.data + else "Unknown error" + ) print_error(error) turn_complete_event.set() elif event.event_type == "shutdown": @@ -392,8 +403,10 @@ async def event_listener( # If yolo mode is active, auto-approve everything except # scheduled HF jobs, whose recurring cost stays manual. - if config and config.yolo_mode and not any( - _is_scheduled_hf_job_tool(t) for t in tools_data + if ( + config + and config.yolo_mode + and not any(_is_scheduled_hf_job_tool(t) for t in tools_data) ): approvals = [ { @@ -637,7 +650,9 @@ async def event_listener( f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " ) except (KeyboardInterrupt, EOFError): - get_console().print("[dim]Approval cancelled — rejecting remaining items[/dim]") + get_console().print( + "[dim]Approval cancelled — rejecting remaining items[/dim]" + ) approvals.append( { "tool_call_id": tool_call_id, @@ -770,7 +785,11 @@ async def _handle_slash_command( normalized = arg.removeprefix("huggingface/") session = session_holder[0] if session_holder else None await model_switcher.probe_and_switch_model( - normalized, config, session, console, resolve_hf_token(), + normalized, + config, + session, + console, + resolve_hf_token(), ) return None @@ -965,6 +984,7 @@ async def main(model: str | None = None): # Pre-warm the HF router catalog in the background so /model switches # don't block on a network fetch. from agent.core import hf_router_catalog + asyncio.create_task(asyncio.to_thread(hf_router_catalog.prewarm)) # Create queues for communication @@ -1110,7 +1130,11 @@ async def main(model: str | None = None): # Handle slash commands if user_input.strip().startswith("/"): sub = await _handle_slash_command( - user_input.strip(), config, session_holder, submission_queue, submission_id + user_input.strip(), + config, + session_holder, + submission_queue, + submission_id, ) if sub is None: # Command handled locally, loop back for input @@ -1176,10 +1200,13 @@ async def headless_main( hf_token = resolve_hf_token() if not hf_token: - print("ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr) + print( + "ERROR: No HF token found. Set HF_TOKEN or run `huggingface-cli login`.", + file=sys.stderr, + ) sys.exit(1) - print(f"HF token loaded", file=sys.stderr) + print("HF token loaded", file=sys.stderr) config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) config.yolo_mode = True # Auto-approve everything in headless mode @@ -1327,26 +1354,35 @@ async def headless_main( for t in tools_data ] _hl_sub_id[0] += 1 - await submission_queue.put(Submission( - id=f"hl_approval_{_hl_sub_id[0]}", - operation=Operation( - op_type=OpType.EXEC_APPROVAL, - data={"approvals": approvals}, - ), - )) + await submission_queue.put( + Submission( + id=f"hl_approval_{_hl_sub_id[0]}", + operation=Operation( + op_type=OpType.EXEC_APPROVAL, + data={"approvals": approvals}, + ), + ) + ) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print_compacted(old_tokens, new_tokens) elif event.event_type == "error": stream_buf.discard() - error = event.data.get("error", "Unknown error") if event.data else "Unknown error" + error = ( + event.data.get("error", "Unknown error") + if event.data + else "Unknown error" + ) print_error(error) break elif event.event_type in ("turn_complete", "interrupted"): stream_buf.discard() history_size = event.data.get("history_size", "?") if event.data else "?" - print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr) + print( + f"\n--- Agent {event.event_type} (history_size={history_size}) ---", + file=sys.stderr, + ) if event.event_type == "turn_complete": session = session_holder[0] if session_holder else None if session is not None: @@ -1372,6 +1408,7 @@ def cli(): """Entry point for the ml-intern CLI command.""" import logging as _logging import warnings + # Suppress aiohttp "Unclosed client session" noise during event loop teardown _logging.getLogger("asyncio").setLevel(_logging.CRITICAL) _configure_runtime_logging() @@ -1381,12 +1418,23 @@ def cli(): warnings.filterwarnings("ignore", category=SyntaxWarning, module="whoosh") parser = argparse.ArgumentParser(description="Hugging Face Agent CLI") - parser.add_argument("prompt", nargs="?", default=None, help="Run headlessly with this prompt") - parser.add_argument("--model", "-m", default=None, help=f"Model to use (default: from config)") - parser.add_argument("--max-iterations", type=int, default=None, - help="Max LLM requests per turn (default: 50, use -1 for unlimited)") - parser.add_argument("--no-stream", action="store_true", - help="Disable token streaming (use non-streaming LLM calls)") + parser.add_argument( + "prompt", nargs="?", default=None, help="Run headlessly with this prompt" + ) + parser.add_argument( + "--model", "-m", default=None, help="Model to use (default: from config)" + ) + parser.add_argument( + "--max-iterations", + type=int, + default=None, + help="Max LLM requests per turn (default: 50, use -1 for unlimited)", + ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Disable token streaming (use non-streaming LLM calls)", + ) args = parser.parse_args() try: @@ -1394,7 +1442,14 @@ def cli(): max_iter = args.max_iterations if max_iter is not None and max_iter < 0: max_iter = 10_000 # effectively unlimited - asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) + asyncio.run( + headless_main( + args.prompt, + model=args.model, + max_iterations=max_iter, + stream=not args.no_stream, + ) + ) else: asyncio.run(main(model=args.model)) except KeyboardInterrupt: diff --git a/agent/messaging/base.py b/agent/messaging/base.py index bf1d73894fa85ce066fa289902c4d6b783ceaa11..a74f9cf0d1cb2a77328124414b04de9ebbd6b582 100644 --- a/agent/messaging/base.py +++ b/agent/messaging/base.py @@ -2,7 +2,11 @@ from abc import ABC, abstractmethod import httpx -from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult +from agent.messaging.models import ( + DestinationConfig, + NotificationRequest, + NotificationResult, +) class NotificationError(Exception): diff --git a/agent/messaging/gateway.py b/agent/messaging/gateway.py index 83c4704baafe9eadea23a336f691dc96db934e79..1de9438f5c5c8ae2847ef1bf4a398d10e8903048 100644 --- a/agent/messaging/gateway.py +++ b/agent/messaging/gateway.py @@ -39,7 +39,9 @@ class NotificationGateway: if not self.enabled or self._worker_task is not None: return self._client = httpx.AsyncClient(timeout=10.0) - self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway") + self._worker_task = asyncio.create_task( + self._worker(), name="notification-gateway" + ) async def flush(self) -> None: if not self.enabled: @@ -87,7 +89,9 @@ class NotificationGateway: provider=destination.provider, error=f"No provider implementation for '{destination.provider}'", ) - return await self._send_with_retries(provider, request.destination, destination, request) + return await self._send_with_retries( + provider, request.destination, destination, request + ) async def send_many( self, requests: Iterable[NotificationRequest] @@ -131,7 +135,9 @@ class NotificationGateway: try: for attempt in range(len(_RETRY_DELAYS) + 1): try: - return await provider.send(client, destination_name, destination, request) + return await provider.send( + client, destination_name, destination, request + ) except RetryableNotificationError as exc: if attempt >= len(_RETRY_DELAYS): return NotificationResult( diff --git a/agent/messaging/models.py b/agent/messaging/models.py index 25f645fe92fa70901843e68be82d82f3a78e0d16..16148a8179f5de3fa38b36ce76166a48e9f54a83 100644 --- a/agent/messaging/models.py +++ b/agent/messaging/models.py @@ -55,9 +55,7 @@ class MessagingConfig(BaseModel): seen: set[str] = set() for event_type in event_types: if event_type not in SUPPORTED_AUTO_EVENT_TYPES: - raise ValueError( - f"unsupported auto event type '{event_type}'" - ) + raise ValueError(f"unsupported auto event type '{event_type}'") if event_type not in seen: normalized.append(event_type) seen.add(event_type) @@ -83,11 +81,7 @@ class MessagingConfig(BaseModel): def default_auto_destinations(self) -> list[str]: if not self.enabled: return [] - return [ - name - for name in self.destinations - if self.can_auto_send(name) - ] + return [name for name in self.destinations if self.can_auto_send(name)] class NotificationRequest(BaseModel): diff --git a/agent/messaging/slack.py b/agent/messaging/slack.py index a1fb7c18eef91396e566fb04b4f6411f9184a2be..3790e44af790db8579a9a8efb88a2a16283ec71d 100644 --- a/agent/messaging/slack.py +++ b/agent/messaging/slack.py @@ -160,9 +160,7 @@ class SlackProvider(NotificationProvider): raise RetryableNotificationError("Slack transport error") from exc if response.status_code == 429 or response.status_code >= 500: - raise RetryableNotificationError( - f"Slack HTTP {response.status_code}" - ) + raise RetryableNotificationError(f"Slack HTTP {response.status_code}") if response.status_code >= 400: raise NotificationError(f"Slack HTTP {response.status_code}") diff --git a/agent/sft/tagger.py b/agent/sft/tagger.py index 7c47434d931f0a2ac81d3924c1bbc19d52fc9b54..528bc9d0d80b7e63bc63f527e94cabf59b214966 100644 --- a/agent/sft/tagger.py +++ b/agent/sft/tagger.py @@ -27,19 +27,29 @@ Tags are deduplicated before returning. from __future__ import annotations -from typing import Any, Iterable +from typing import Iterable # Flavor → GPU-family mapping. Keep conservative; unknown flavors → "none". _GPU_FAMILY = { - "cpu-basic": "none", "cpu-upgrade": "none", - "t4-small": "t4", "t4-medium": "t4", - "l4x1": "l40s", "l4x4": "l40s", - "l40sx1": "l40s", "l40sx4": "l40s", "l40sx8": "l40s", - "a10g-small": "a10g", "a10g-large": "a10g", - "a10g-largex2": "a10g", "a10g-largex4": "a10g", - "a100-large": "a100", "a100x2": "a100", - "a100x4": "a100", "a100x8": "a100", - "h100": "h100", "h100x8": "h100", + "cpu-basic": "none", + "cpu-upgrade": "none", + "t4-small": "t4", + "t4-medium": "t4", + "l4x1": "l40s", + "l4x4": "l40s", + "l40sx1": "l40s", + "l40sx4": "l40s", + "l40sx8": "l40s", + "a10g-small": "a10g", + "a10g-large": "a10g", + "a10g-largex2": "a10g", + "a10g-largex4": "a10g", + "a100-large": "a100", + "a100x2": "a100", + "a100x4": "a100", + "a100x8": "a100", + "h100": "h100", + "h100x8": "h100", } # Substrings that count a flavor as multi-GPU. @@ -48,9 +58,17 @@ _MULTI_GPU_MARKERS = ("x2", "x4", "x8") # Tool names that don't touch training/inference or sandbox/jobs. If a session # only used these, we tag it research_only. _RESEARCH_ONLY_TOOLS = { - "research", "github_find_examples", "github_read_file", "github_list_repos", - "hf_papers", "explore_hf_docs", "fetch_hf_docs", "hub_repo_details", - "plan", "hf_inspect_dataset", "web_search", + "research", + "github_find_examples", + "github_read_file", + "github_list_repos", + "hf_papers", + "explore_hf_docs", + "fetch_hf_docs", + "hub_repo_details", + "plan", + "hf_inspect_dataset", + "web_search", } # Tool names that signal data manipulation workflows. @@ -126,11 +144,22 @@ def _infer_task_tag( # hf_jobs at all and a script mentions training APIs. for script in hf_job_submit_scripts: low = script.lower() - if any(k in low for k in ( - "sftconfig", "sfttrainer", "trainer(", "trainingarguments", - "grpo", "dpo", ".train(", "transformers import", - "trainer import", "fine-tune", "finetune", - )): + if any( + k in low + for k in ( + "sftconfig", + "sfttrainer", + "trainer(", + "trainingarguments", + "grpo", + "dpo", + ".train(", + "transformers import", + "trainer import", + "fine-tune", + "finetune", + ) + ): return "training" # inference: sessions that use inference tools but never hf_jobs/sandbox diff --git a/agent/tools/dataset_tools.py b/agent/tools/dataset_tools.py index ef3f3c81b629d0f937a91255622629396e5a2534..20add683d40c3b0f550daaae046408d64f23ddbd 100644 --- a/agent/tools/dataset_tools.py +++ b/agent/tools/dataset_tools.py @@ -423,7 +423,9 @@ HF_INSPECT_DATASET_TOOL_SPEC = { } -async def hf_inspect_dataset_handler(arguments: dict[str, Any], session=None) -> tuple[str, bool]: +async def hf_inspect_dataset_handler( + arguments: dict[str, Any], session=None +) -> tuple[str, bool]: """Handler for agent tool router""" try: hf_token = session.hf_token if session else None diff --git a/agent/tools/edit_utils.py b/agent/tools/edit_utils.py index 6a9a3295e2e25313758d633e0b733f57c373cd5a..1c6b958192ad8a90c9b3268f6fdb688787d97ea6 100644 --- a/agent/tools/edit_utils.py +++ b/agent/tools/edit_utils.py @@ -10,18 +10,18 @@ from __future__ import annotations # ── Unicode normalization map ──────────────────────────────────────────── UNICODE_MAP = { - "\u2013": "-", # en-dash - "\u2014": "-", # em-dash - "\u2212": "-", # minus sign - "\u2018": "'", # left single quote - "\u2019": "'", # right single quote - "\u201c": '"', # left double quote - "\u201d": '"', # right double quote - "\u00a0": " ", # non-breaking space - "\u2003": " ", # em space - "\u2002": " ", # en space - "\u200b": "", # zero-width space - "\ufeff": "", # BOM + "\u2013": "-", # en-dash + "\u2014": "-", # em-dash + "\u2212": "-", # minus sign + "\u2018": "'", # left single quote + "\u2019": "'", # right single quote + "\u201c": '"', # left double quote + "\u201d": '"', # right double quote + "\u00a0": " ", # non-breaking space + "\u2003": " ", # em space + "\u2002": " ", # en space + "\u200b": "", # zero-width space + "\ufeff": "", # BOM } @@ -59,12 +59,12 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: line_start_map[i] = original byte offset of the start of line i. """ orig_lines = text.split("\n") - stripped_lines = [strip_fn(l) for l in orig_lines] + stripped_lines = [strip_fn(line) for line in orig_lines] return "\n".join(stripped_lines), orig_lines, stripped_lines # Pass 2 — right-trim c_rt, c_orig_lines, c_rt_lines = _build_stripped(content, str.rstrip) - p_rt = "\n".join(l.rstrip() for l in pattern.split("\n")) + p_rt = "\n".join(line.rstrip() for line in pattern.split("\n")) idx = c_rt.find(p_rt) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_rt_lines) @@ -72,7 +72,7 @@ def fuzzy_find(content: str, pattern: str) -> tuple[int | None, str | None]: # Pass 3 — both-sides trim c_st, _, c_st_lines = _build_stripped(content, str.strip) - p_st = "\n".join(l.strip() for l in pattern.split("\n")) + p_st = "\n".join(line.strip() for line in pattern.split("\n")) idx = c_st.find(p_st) if idx != -1: orig_idx = _map_back(idx, c_orig_lines, c_st_lines) @@ -114,7 +114,9 @@ def _map_back( return 0 -def fuzzy_find_original_match(content: str, pattern: str) -> tuple[str | None, str | None]: +def fuzzy_find_original_match( + content: str, pattern: str +) -> tuple[str | None, str | None]: """Find the *original* text in content that matches pattern fuzzily. Returns (original_matched_text, match_note) or (None, None). @@ -224,7 +226,9 @@ def apply_edit( return new_content, 1, fuzzy_note else: - raise ValueError(f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before.") + raise ValueError( + f"Unknown edit mode: {mode}. Use replace, append_after, or prepend_before." + ) # ── Syntax validation (Python) ─────────────────────────────────────────── @@ -255,14 +259,15 @@ def validate_python(content: str, path: str = "") -> list[str]: return warnings # 2. Training script heuristics - if any(kw in content for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig")): + if any( + kw in content + for kw in ("TrainingArguments", "SFTConfig", "DPOConfig", "GRPOConfig") + ): if "push_to_hub" not in content: warnings.append( "Training script warning: no 'push_to_hub' found — model may be lost when job ends" ) if "hub_model_id" not in content: - warnings.append( - "Training script warning: no 'hub_model_id' found" - ) + warnings.append("Training script warning: no 'hub_model_id' found") return warnings diff --git a/agent/tools/hf_repo_files_tool.py b/agent/tools/hf_repo_files_tool.py index fd39a488fc5610b665d2e0ddb7584d892104644a..aee00b741662838769d25711602b5afefcb623e8 100644 --- a/agent/tools/hf_repo_files_tool.py +++ b/agent/tools/hf_repo_files_tool.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError +from agent.core.hub_artifacts import is_known_hub_artifact, register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal["list", "read", "upload", "delete"] @@ -39,8 +40,9 @@ def _format_size(size_bytes: int) -> str: class HfRepoFilesTool: """Tool for file operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None): + def __init__(self, hf_token: Optional[str] = None, session: Any = None): self.api = HfApi(token=hf_token) + self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -61,7 +63,9 @@ class HfRepoFilesTool: if handler: return await handler(args) else: - return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete") + return self._error( + f"Unknown operation: {operation}. Valid: list, read, upload, delete" + ) except RepositoryNotFoundError: return self._error(f"Repository not found: {args.get('repo_id')}") @@ -96,17 +100,23 @@ class HfRepoFilesTool: revision = args.get("revision", "main") path = args.get("path", "") - items = list(await _async_call( - self.api.list_repo_tree, - repo_id=repo_id, - repo_type=repo_type, - revision=revision, - path_in_repo=path, - recursive=True, - )) + items = list( + await _async_call( + self.api.list_repo_tree, + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + path_in_repo=path, + recursive=True, + ) + ) if not items: - return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0} + return { + "formatted": f"No files in {repo_id}", + "totalResults": 0, + "resultsShared": 0, + } lines = [] total_size = 0 @@ -118,9 +128,16 @@ class HfRepoFilesTool: lines.append(f"{item.path}/") url = _build_repo_url(repo_id, repo_type) - response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines) + response = ( + f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + + "\n".join(lines) + ) - return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)} + return { + "formatted": response, + "totalResults": len(items), + "resultsShared": len(items), + } async def _read(self, args: Dict[str, Any]) -> ToolResult: """Read file content from a repository.""" @@ -160,8 +177,13 @@ class HfRepoFilesTool: except UnicodeDecodeError: import os + size = os.path.getsize(file_path) - return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"Binary file ({_format_size(size)})", + "totalResults": 1, + "resultsShared": 1, + } async def _upload(self, args: Dict[str, Any]) -> ToolResult: """Upload content to a repository.""" @@ -194,6 +216,16 @@ class HfRepoFilesTool: create_pr=create_pr, ) + if not create_pr and is_known_hub_artifact(self.session, repo_id, repo_type): + await _async_call( + register_hub_artifact, + self.api, + repo_id, + repo_type, + session=self.session, + force=path == "README.md", + ) + url = _build_repo_url(repo_id, repo_type) if create_pr and hasattr(result, "pr_url"): response = f"**Uploaded as PR**\n{result.pr_url}" @@ -235,7 +267,12 @@ class HfRepoFilesTool: def _error(self, message: str) -> ToolResult: """Return an error result.""" - return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} + return { + "formatted": message, + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } # Tool specification @@ -312,11 +349,13 @@ HF_REPO_FILES_TOOL_SPEC = { } -async def hf_repo_files_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: +async def hf_repo_files_handler( + arguments: Dict[str, Any], session=None +) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None - tool = HfRepoFilesTool(hf_token=hf_token) + tool = HfRepoFilesTool(hf_token=hf_token, session=session) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/hf_repo_git_tool.py b/agent/tools/hf_repo_git_tool.py index d7e2323a361efd578cf97d78be2593b71f4d8ac8..cfff4120b089aa7923c2a46c5c3da22cf201457f 100644 --- a/agent/tools/hf_repo_git_tool.py +++ b/agent/tools/hf_repo_git_tool.py @@ -10,14 +10,24 @@ from typing import Any, Dict, Literal, Optional from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError +from agent.core.hub_artifacts import register_hub_artifact from agent.tools.types import ToolResult OperationType = Literal[ - "create_branch", "delete_branch", - "create_tag", "delete_tag", + "create_branch", + "delete_branch", + "create_tag", + "delete_tag", "list_refs", - "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", - "create_repo", "update_repo", + "create_pr", + "list_prs", + "get_pr", + "merge_pr", + "close_pr", + "comment_pr", + "change_pr_status", + "create_repo", + "update_repo", ] @@ -36,8 +46,9 @@ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str: class HfRepoGitTool: """Tool for git-like operations on HF repos.""" - def __init__(self, hf_token: Optional[str] = None): + def __init__(self, hf_token: Optional[str] = None, session: Any = None): self.api = HfApi(token=hf_token) + self.session = session async def execute(self, args: Dict[str, Any]) -> ToolResult: """Execute the specified operation.""" @@ -131,7 +142,11 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}" - return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Branch created:** {branch}\n{url}", + "totalResults": 1, + "resultsShared": 1, + } async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult: """Delete a branch.""" @@ -152,7 +167,11 @@ class HfRepoGitTool: repo_type=repo_type, ) - return {"formatted": f"**Branch deleted:** {branch}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Branch deleted:** {branch}", + "totalResults": 1, + "resultsShared": 1, + } # ========================================================================= # TAG OPERATIONS @@ -183,7 +202,11 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}" - return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Tag created:** {tag}\n{url}", + "totalResults": 1, + "resultsShared": 1, + } async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult: """Delete a tag.""" @@ -204,7 +227,11 @@ class HfRepoGitTool: repo_type=repo_type, ) - return {"formatted": f"**Tag deleted:** {tag}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Tag deleted:** {tag}", + "totalResults": 1, + "resultsShared": 1, + } # ========================================================================= # LIST REFS @@ -226,7 +253,9 @@ class HfRepoGitTool: ) branches = [b.name for b in refs.branches] if refs.branches else [] - tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else [] + tags = ( + [t.name for t in refs.tags] if hasattr(refs, "tags") and refs.tags else [] + ) url = _build_repo_url(repo_id, repo_type) lines = [f"**{repo_id}**", url, ""] @@ -241,7 +270,11 @@ class HfRepoGitTool: else: lines.append("**Tags:** none") - return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)} + return { + "formatted": "\n".join(lines), + "totalResults": len(branches) + len(tags), + "resultsShared": len(branches) + len(tags), + } # ========================================================================= # PR OPERATIONS @@ -270,7 +303,7 @@ class HfRepoGitTool: url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}" return { - "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"", + "formatted": f'**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision="refs/pr/{result.num}"', "totalResults": 1, "resultsShared": 1, } @@ -285,17 +318,27 @@ class HfRepoGitTool: repo_type = args.get("repo_type", "model") status = args.get("status", "all") # open, closed, all - discussions = list(self.api.get_repo_discussions( - repo_id=repo_id, - repo_type=repo_type, - discussion_status=status if status != "all" else None, - )) + discussions = list( + self.api.get_repo_discussions( + repo_id=repo_id, + repo_type=repo_type, + discussion_status=status if status != "all" else None, + ) + ) if not discussions: - return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0} + return { + "formatted": f"No discussions in {repo_id}", + "totalResults": 0, + "resultsShared": 0, + } url = _build_repo_url(repo_id, repo_type) - lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""] + lines = [ + f"**{repo_id}** - {len(discussions)} discussions", + f"{url}/discussions", + "", + ] for d in discussions[:20]: if d.status == "draft": @@ -309,7 +352,11 @@ class HfRepoGitTool: type_label = "PR" if d.is_pull_request else "D" lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}") - return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))} + return { + "formatted": "\n".join(lines), + "totalResults": len(discussions), + "resultsShared": min(20, len(discussions)), + } async def _get_pr(self, args: Dict[str, Any]) -> ToolResult: """Get PR details.""" @@ -335,7 +382,7 @@ class HfRepoGitTool: "draft": "Draft", "open": "Open", "merged": "Merged", - "closed": "Closed" + "closed": "Closed", } status = status_map.get(pr.status, pr.status.capitalize()) type_label = "Pull Request" if pr.is_pull_request else "Discussion" @@ -349,9 +396,13 @@ class HfRepoGitTool: if pr.is_pull_request: if pr.status == "draft": - lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") + lines.append( + f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' + ) elif pr.status == "open": - lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"") + lines.append( + f'\nTo add commits: upload with revision="refs/pr/{pr_num}"' + ) return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1} @@ -377,7 +428,11 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**PR #{pr_num} merged**\n{url}", + "totalResults": 1, + "resultsShared": 1, + } async def _close_pr(self, args: Dict[str, Any]) -> ToolResult: """Close a PR/discussion.""" @@ -401,7 +456,11 @@ class HfRepoGitTool: repo_type=repo_type, ) - return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Discussion #{pr_num} closed**", + "totalResults": 1, + "resultsShared": 1, + } async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult: """Add a comment to a PR/discussion.""" @@ -427,7 +486,11 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Comment added to #{pr_num}**\n{url}", + "totalResults": 1, + "resultsShared": 1, + } async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult: """Change PR/discussion status (mainly to convert draft to open).""" @@ -455,7 +518,11 @@ class HfRepoGitTool: ) url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}" - return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", + "totalResults": 1, + "resultsShared": 1, + } # ========================================================================= # REPO MANAGEMENT @@ -473,7 +540,9 @@ class HfRepoGitTool: space_sdk = args.get("space_sdk") if repo_type == "space" and not space_sdk: - return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)") + return self._error( + "space_sdk required for spaces (gradio/streamlit/docker/static)" + ) kwargs = { "repo_id": repo_id, @@ -485,6 +554,17 @@ class HfRepoGitTool: kwargs["space_sdk"] = space_sdk result = await _async_call(self.api.create_repo, **kwargs) + extra_metadata = None + if repo_type == "space" and space_sdk: + extra_metadata = {"sdk": space_sdk} + await _async_call( + register_hub_artifact, + self.api, + repo_id, + repo_type, + session=self.session, + extra_metadata=extra_metadata, + ) return { "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}", @@ -504,7 +584,9 @@ class HfRepoGitTool: gated = args.get("gated") if private is None and gated is None: - return self._error("Specify private (bool) or gated ('auto'/'manual'/false)") + return self._error( + "Specify private (bool) or gated ('auto'/'manual'/false)" + ) kwargs = {"repo_id": repo_id, "repo_type": repo_type} if private is not None: @@ -521,11 +603,20 @@ class HfRepoGitTool: changes.append(f"gated={gated}") url = f"{_build_repo_url(repo_id, repo_type)}/settings" - return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1} + return { + "formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", + "totalResults": 1, + "resultsShared": 1, + } def _error(self, message: str) -> ToolResult: """Return an error result.""" - return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True} + return { + "formatted": message, + "totalResults": 0, + "resultsShared": 0, + "isError": True, + } # Tool specification @@ -571,10 +662,20 @@ HF_REPO_GIT_TOOL_SPEC = { "operation": { "type": "string", "enum": [ - "create_branch", "delete_branch", - "create_tag", "delete_tag", "list_refs", - "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status", - "create_repo", "update_repo", + "create_branch", + "delete_branch", + "create_tag", + "delete_tag", + "list_refs", + "create_pr", + "list_prs", + "get_pr", + "merge_pr", + "close_pr", + "comment_pr", + "change_pr_status", + "create_repo", + "update_repo", ], "description": "Operation to execute", }, @@ -653,11 +754,13 @@ HF_REPO_GIT_TOOL_SPEC = { } -async def hf_repo_git_handler(arguments: Dict[str, Any], session=None) -> tuple[str, bool]: +async def hf_repo_git_handler( + arguments: Dict[str, Any], session=None +) -> tuple[str, bool]: """Handler for agent tool router.""" try: hf_token = session.hf_token if session else None - tool = HfRepoGitTool(hf_token=hf_token) + tool = HfRepoGitTool(hf_token=hf_token, session=session) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index 44533a8725e37548ccbd0529409452f60a1e5d8d..c058921a2f18bb631c77cbfd8f021253a28ecee6 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -7,22 +7,24 @@ Refactored to use official huggingface-hub library instead of custom HTTP client import asyncio import base64 import http.client -import os -import re -from typing import Any, Dict, Literal, Optional, Callable, Awaitable - import logging +import re +import shlex +from typing import Any, Awaitable, Callable, Dict, Literal, Optional import httpx from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError -from agent.core.hf_access import JobsAccessError, is_billing_error, resolve_jobs_namespace +from agent.core.hf_access import ( + JobsAccessError, + is_billing_error, + resolve_jobs_namespace, +) +from agent.core.hub_artifacts import build_hub_artifact_sitecustomize from agent.core.session import Event from agent.tools.trackio_seed import ensure_trackio_dashboard from agent.tools.types import ToolResult - -logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, @@ -30,6 +32,8 @@ from agent.tools.utilities import ( format_scheduled_jobs_table, ) +logger = logging.getLogger(__name__) + # Hardware flavors CPU_FLAVORS = ["cpu-basic", "cpu-upgrade"] GPU_FLAVORS = [ @@ -119,11 +123,11 @@ def _filter_uv_install_output(logs: list[str]) -> list[str]: return logs -_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') +_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub('', text) + return _ANSI_RE.sub("", text) _DEFAULT_ENV = { @@ -235,6 +239,26 @@ def _resolve_uv_command( return _build_uv_command(script, with_deps, python, script_args) +def _wrap_command_with_artifact_bootstrap( + command: list[str], session: Any = None +) -> list[str]: + """Install sitecustomize hooks before the user command runs in HF Jobs.""" + sitecustomize = build_hub_artifact_sitecustomize(session) + if not sitecustomize: + return command + + encoded = base64.b64encode(sitecustomize.encode("utf-8")).decode("ascii") + original_command = shlex.join(command) + shell = ( + 'set -e; _ml_intern_artifacts_dir="$(mktemp -d)"; ' + f"printf %s {shlex.quote(encoded)} | base64 -d " + '> "$_ml_intern_artifacts_dir/sitecustomize.py"; ' + 'export PYTHONPATH="$_ml_intern_artifacts_dir${PYTHONPATH:+:$PYTHONPATH}"; ' + f"exec {original_command}" + ) + return ["/bin/sh", "-lc", shell] + + async def _async_call(func, *args, **kwargs): """Wrap synchronous HfApi calls for async context""" return await asyncio.to_thread(func, *args, **kwargs) @@ -432,7 +456,9 @@ class HfJobsTool: def log_producer(): try: # fetch_job_logs is a blocking sync generator - logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) + logs_gen = self.api.fetch_job_logs( + job_id=job_id, namespace=namespace + ) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) @@ -556,6 +582,8 @@ class HfJobsTool: image = args.get("image", "python:3.12") job_type = "Docker" + command = _wrap_command_with_artifact_bootstrap(command, self.session) + # Run the job flavor = args.get("hardware_flavor", "cpu-basic") timeout_str = args.get("timeout", "30m") @@ -578,7 +606,9 @@ class HfJobsTool: image=image, command=command, env=env_dict, - secrets=_add_environment_variables(args.get("secrets"), self.hf_token), + secrets=_add_environment_variables( + args.get("secrets"), self.hf_token + ), flavor=flavor, timeout=timeout_str, namespace=self.namespace, @@ -636,10 +666,18 @@ class HfJobsTool: submit_ts = None if self.session: from agent.core import telemetry + submit_ts = await telemetry.record_hf_job_submit( - self.session, job, - {**args, "hardware_flavor": flavor, "timeout": timeout_str, "namespace": self.namespace}, - image=image, job_type=job_type, + self.session, + job, + { + **args, + "hardware_flavor": flavor, + "timeout": timeout_str, + "namespace": self.namespace, + }, + image=image, + job_type=job_type, ) # Top-up signal: this submit succeeded after a prior billing # block in the same session, and we haven't fired the event @@ -656,7 +694,8 @@ class HfJobsTool: ) if blocked: await telemetry.record_credits_topped_up( - self.session, namespace=self.namespace, + self.session, + namespace=self.namespace, ) # Wait for completion and stream logs @@ -670,9 +709,13 @@ class HfJobsTool: if self.session and submit_ts is not None: from agent.core import telemetry + await telemetry.record_hf_job_complete( - self.session, job, - flavor=flavor, final_status=final_status, submit_ts=submit_ts, + self.session, + job, + flavor=flavor, + final_status=final_status, + submit_ts=submit_ts, ) # Untrack job ID (completed or failed, no longer needs cancellation) @@ -699,7 +742,9 @@ class HfJobsTool: filtered_logs = _filter_uv_install_output(all_logs) # Format all logs for the agent - log_text = _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" + log_text = ( + _strip_ansi("\n".join(filtered_logs)) if filtered_logs else "(no logs)" + ) response = f"""{job_type} job completed! @@ -891,6 +936,8 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}} image = args.get("image", "python:3.12") job_type = "Docker" + command = _wrap_command_with_artifact_bootstrap(command, self.session) + # Create scheduled job scheduled_job = await _async_call( self.api.create_scheduled_job, @@ -1215,6 +1262,7 @@ async def hf_jobs_handler( sandbox = getattr(session, "sandbox", None) if session else None if sandbox and script: from agent.tools.sandbox_tool import resolve_sandbox_script + content, error = await resolve_sandbox_script(sandbox, script) if error: return error, False diff --git a/agent/tools/local_tools.py b/agent/tools/local_tools.py index fc456f682eb54fec8a2ee29d5fba07a7d6a4a324..50cd5bd65b517f8855ceeb87ffade52a04e25a15 100644 --- a/agent/tools/local_tools.py +++ b/agent/tools/local_tools.py @@ -15,6 +15,8 @@ import tempfile from pathlib import Path from typing import Any +from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap + MAX_OUTPUT_CHARS = 25_000 MAX_LINE_LENGTH = 4000 @@ -22,7 +24,7 @@ DEFAULT_READ_LINES = 2000 DEFAULT_TIMEOUT = 120 MAX_TIMEOUT = 36000 # 10 hours — needed for long training runs (e.g. PostTrainBench) -_ANSI_RE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07') +_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\].*?\x07") # Track files that have been read this session (enforces read-before-write/edit) _files_read: set[str] = set() @@ -63,17 +65,21 @@ def _atomic_write(path: Path, content: str) -> None: def _strip_ansi(text: str) -> str: - return _ANSI_RE.sub('', text) + return _ANSI_RE.sub("", text) -def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25) -> str: +def _truncate_output( + output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: float = 0.25 +) -> str: """Tail-biased truncation with temp file spillover for full output access.""" if len(output) <= max_chars: return output # Write full output to temp file so LLM can read specific sections spill_path = None try: - with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', prefix='bash_output_', delete=False) as f: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".txt", prefix="bash_output_", delete=False + ) as f: f.write(output) spill_path = f.name except Exception: @@ -93,10 +99,14 @@ def _truncate_output(output: str, max_chars: int = MAX_OUTPUT_CHARS, head_ratio: # ── Handlers ──────────────────────────────────────────────────────────── -async def _bash_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: + +async def _bash_handler( + args: dict[str, Any], session: Any = None, **_kw +) -> tuple[str, bool]: command = args.get("command", "") if not command: return "No command provided.", False + command = wrap_shell_command_with_hub_artifact_bootstrap(command, session) work_dir = args.get("work_dir", ".") timeout = min(args.get("timeout") or DEFAULT_TIMEOUT, MAX_TIMEOUT) try: @@ -174,9 +184,12 @@ async def _write_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: # Syntax validation for Python files if p.suffix == ".py": from agent.tools.edit_utils import validate_python + warnings = validate_python(content, file_path) if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) + msg += "\n\nValidation warnings:\n" + "\n".join( + f" ⚠ {w}" for w in warnings + ) return msg, True except Exception as e: return f"write error: {e}", False @@ -229,7 +242,9 @@ async def _edit_handler(args: dict[str, Any], **_kw) -> tuple[str, bool]: if p.suffix == ".py": warnings = validate_python(new_text, file_path) if warnings: - msg += "\n\nValidation warnings:\n" + "\n".join(f" ⚠ {w}" for w in warnings) + msg += "\n\nValidation warnings:\n" + "\n".join( + f" ⚠ {w}" for w in warnings + ) return msg, True diff --git a/agent/tools/papers_tool.py b/agent/tools/papers_tool.py index 4032a770367ff5ba68504c001eb142df45a4f394..dea63d7d327999303e76c7e3e155d90107a2fd4f 100644 --- a/agent/tools/papers_tool.py +++ b/agent/tools/papers_tool.py @@ -102,7 +102,9 @@ async def _s2_request( async def _s2_get_json( - client: httpx.AsyncClient, path: str, params: dict | None = None, + client: httpx.AsyncClient, + path: str, + params: dict | None = None, ) -> dict | None: """Cached S2 GET returning parsed JSON or None.""" key = _s2_cache_key(path, params) @@ -119,7 +121,9 @@ async def _s2_get_json( async def _s2_get_paper( - client: httpx.AsyncClient, arxiv_id: str, fields: str, + client: httpx.AsyncClient, + arxiv_id: str, + fields: str, ) -> dict | None: """Fetch a single paper from S2 by arxiv ID. Returns None on failure.""" return await _s2_get_json( @@ -322,7 +326,9 @@ def _format_paper_detail(paper: dict, s2_data: dict | None = None) -> str: if keywords: lines.append(f"**Keywords:** {', '.join(keywords)}") if s2_data and s2_data.get("s2FieldsOfStudy"): - fields = [f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category")] + fields = [ + f["category"] for f in s2_data["s2FieldsOfStudy"] if f.get("category") + ] if fields: lines.append(f"**Fields:** {', '.join(fields)}") if s2_data and s2_data.get("venue"): @@ -393,7 +399,9 @@ def _format_datasets(datasets: list, arxiv_id: str, sort: str) -> str: ds_id = ds.get("id", "unknown") downloads = ds.get("downloads", 0) likes = ds.get("likes", 0) - desc = _truncate(_clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN) + desc = _truncate( + _clean_description(ds.get("description") or ""), MAX_SUMMARY_LEN + ) tags = ds.get("tags") or [] interesting = [t for t in tags if not t.startswith(("arxiv:", "region:"))][:5] @@ -582,11 +590,15 @@ def _format_s2_paper_list(papers: list[dict], title: str) -> str: lines.append(f"**TL;DR:** {tldr}") lines.append("") - lines.append("Use paper_details with arxiv_id for full info, or read_paper to read sections.") + lines.append( + "Use paper_details with arxiv_id for full info, or read_paper to read sections." + ) return "\n".join(lines) -async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolResult | None: +async def _s2_bulk_search( + query: str, args: dict[str, Any], limit: int +) -> ToolResult | None: """Search via S2 bulk endpoint with filters. Returns None on failure.""" params: dict[str, Any] = { "query": query, @@ -616,7 +628,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR params["sort"] = f"{sort_by}:desc" async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request(client, "GET", "/graph/v1/paper/search/bulk", params=params) + resp = await _s2_request( + client, "GET", "/graph/v1/paper/search/bulk", params=params + ) if not resp or resp.status_code != 200: return None data = resp.json() @@ -629,7 +643,9 @@ async def _s2_bulk_search(query: str, args: dict[str, Any], limit: int) -> ToolR "resultsShared": 0, } - formatted = _format_s2_paper_list(papers[:limit], f"Papers matching '{query}' (Semantic Scholar)") + formatted = _format_s2_paper_list( + papers[:limit], f"Papers matching '{query}' (Semantic Scholar)" + ) return { "formatted": formatted, "totalResults": data.get("total", len(papers)), @@ -643,7 +659,10 @@ async def _op_search(args: dict[str, Any], limit: int) -> ToolResult: return _error("'query' is required for search operation.") # Route to S2 when filters are present - use_s2 = any(args.get(k) for k in ("date_from", "date_to", "categories", "min_citations", "sort_by")) + use_s2 = any( + args.get(k) + for k in ("date_from", "date_to", "categories", "min_citations", "sort_by") + ) if use_s2: result = await _s2_bulk_search(query, args, limit) if result is not None: @@ -806,7 +825,9 @@ def _format_citation_graph( lines.append("No citations found.") lines.append("") - lines.append("**Tip:** Use paper_details with an arxiv_id from above to explore further.") + lines.append( + "**Tip:** Use paper_details with an arxiv_id from above to explore further." + ) return "\n".join(lines) @@ -824,9 +845,13 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: refs, cites = None, None coros = [] if direction in ("references", "both"): - coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params)) + coros.append( + _s2_get_json(client, f"/graph/v1/paper/{s2_id}/references", params) + ) if direction in ("citations", "both"): - coros.append(_s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params)) + coros.append( + _s2_get_json(client, f"/graph/v1/paper/{s2_id}/citations", params) + ) results = await asyncio.gather(*coros, return_exceptions=True) idx = 0 @@ -841,7 +866,9 @@ async def _op_citation_graph(args: dict[str, Any], limit: int) -> ToolResult: cites = r.get("data", []) if refs is None and cites is None: - return _error(f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar.") + return _error( + f"Could not fetch citation data for {arxiv_id}. Paper may not be indexed by Semantic Scholar." + ) total = (len(refs) if refs else 0) + (len(cites) if cites else 0) return { @@ -1039,7 +1066,9 @@ def _format_snippets(snippets: list[dict], query: str) -> str: lines.append(f"> {_truncate(text, 400)}") lines.append("") - lines.append("Use paper_details or read_paper with arxiv_id to explore a paper further.") + lines.append( + "Use paper_details or read_paper with arxiv_id to explore a paper further." + ) return "\n".join(lines) @@ -1065,7 +1094,9 @@ async def _op_snippet_search(args: dict[str, Any], limit: int) -> ToolResult: params["minCitationCount"] = str(args["min_citations"]) async with httpx.AsyncClient(timeout=15) as client: - resp = await _s2_request(client, "GET", "/graph/v1/snippet/search", params=params) + resp = await _s2_request( + client, "GET", "/graph/v1/snippet/search", params=params + ) if not resp or resp.status_code != 200: return _error("Snippet search failed. Semantic Scholar may be unavailable.") data = resp.json() @@ -1102,16 +1133,28 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: async with httpx.AsyncClient(timeout=15) as client: if positive_ids and not arxiv_id: # Multi-paper recommendations (POST, not cached) - pos = [_s2_paper_id(pid.strip()) for pid in positive_ids.split(",") if pid.strip()] + pos = [ + _s2_paper_id(pid.strip()) + for pid in positive_ids.split(",") + if pid.strip() + ] neg_raw = args.get("negative_ids", "") - neg = [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] if neg_raw else [] + neg = ( + [_s2_paper_id(pid.strip()) for pid in neg_raw.split(",") if pid.strip()] + if neg_raw + else [] + ) resp = await _s2_request( - client, "POST", "/recommendations/v1/papers/", + client, + "POST", + "/recommendations/v1/papers/", json={"positivePaperIds": pos, "negativePaperIds": neg}, params={"fields": fields, "limit": limit}, ) if not resp or resp.status_code != 200: - return _error("Recommendation request failed. Semantic Scholar may be unavailable.") + return _error( + "Recommendation request failed. Semantic Scholar may be unavailable." + ) data = resp.json() else: # Single-paper recommendations (cached) @@ -1121,7 +1164,9 @@ async def _op_recommend(args: dict[str, Any], limit: int) -> ToolResult: {"fields": fields, "limit": limit, "from": "recent"}, ) if not data: - return _error("Recommendation request failed. Semantic Scholar may be unavailable.") + return _error( + "Recommendation request failed. Semantic Scholar may be unavailable." + ) papers = data.get("recommendedPapers") or [] if not papers: diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py index c4480b974fcb57a51a775f1bd6df6ff7db614103..f5815be8332ef371d3e863652bfc6cdd5127bbc2 100644 --- a/agent/tools/research_tool.py +++ b/agent/tools/research_tool.py @@ -282,6 +282,7 @@ async def research_handler( _agent_id = tool_call_id else: import uuid + _agent_id = uuid.uuid4().hex[:8] _agent_label = "research: " + (task[:50] + "…" if len(task) > 50 else task) @@ -289,12 +290,15 @@ async def research_handler( """Send a progress event to the UI so it doesn't look frozen.""" try: await session.send_event( - Event(event_type="tool_log", data={ - "tool": "research", - "log": text, - "agent_id": _agent_id, - "label": _agent_label, - }) + Event( + event_type="tool_log", + data={ + "tool": "research", + "log": text, + "agent_id": _agent_id, + "label": _agent_label, + }, + ) ) except Exception: pass @@ -323,15 +327,19 @@ async def research_handler( "Research sub-agent hit context max (%d tokens) — forcing summary", _total_tokens, ) - await _log(f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up") + await _log( + f"Context limit reached ({_total_tokens} tokens) — forcing wrap-up" + ) # Ask for a final summary with no tools - messages.append(Message( - role="user", - content=( - "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " - "Summarize your findings NOW. Do NOT call any more tools." - ), - )) + messages.append( + Message( + role="user", + content=( + "[SYSTEM: CONTEXT LIMIT REACHED] You have used all available context. " + "Summarize your findings NOW. Do NOT call any more tools." + ), + ) + ) try: _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) _t0 = time.monotonic() @@ -351,27 +359,34 @@ async def research_handler( model=research_model, response=response, latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=response.choices[0].finish_reason + if response.choices + else None, kind="research", ) except Exception as _telem_err: logger.debug("research telemetry failed: %s", _telem_err) content = response.choices[0].message.content or "" - return content or "Research context exhausted — no summary produced.", bool(content) + return ( + content or "Research context exhausted — no summary produced.", + bool(content), + ) except Exception: return "Research context exhausted and summary call failed.", False if not _warned_context and _total_tokens >= _RESEARCH_CONTEXT_WARN: _warned_context = True await _log(f"Context at {_total_tokens} tokens — nudging to wrap up") - messages.append(Message( - role="user", - content=( - "[SYSTEM: You have used 75% of your context budget. " - "Start wrapping up: finish any critical lookups, then " - "produce your final summary within the next 1-2 iterations.]" - ), - )) + messages.append( + Message( + role="user", + content=( + "[SYSTEM: You have used 75% of your context budget. " + "Start wrapping up: finish any critical lookups, then " + "produce your final summary within the next 1-2 iterations.]" + ), + ) + ) try: _msgs, _tools = with_prompt_caching( @@ -392,7 +407,9 @@ async def research_handler( model=research_model, response=response, latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=response.choices[0].finish_reason + if response.choices + else None, kind="research", ) except Exception as _telem_err: @@ -420,11 +437,13 @@ async def research_handler( # LiteLLM's raw Message carries `provider_specific_fields` and # `reasoning_content`, which the HF router's OpenAI schema rejects # if we echo them back in the next request. - messages.append(Message( - role="assistant", - content=msg.content, - tool_calls=msg.tool_calls, - )) + messages.append( + Message( + role="assistant", + content=msg.content, + tool_calls=msg.tool_calls, + ) + ) for tc in msg.tool_calls: try: tool_args = json.loads(tc.function.arguments) @@ -479,13 +498,15 @@ async def research_handler( # ── Iteration limit: try to salvage findings ── await _log("Iteration limit reached — extracting summary") - messages.append(Message( - role="user", - content=( - "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " - "iterations. Summarize ALL findings so far. Do NOT call any more tools." - ), - )) + messages.append( + Message( + role="user", + content=( + "[SYSTEM: ITERATION LIMIT] You have reached the maximum number of research " + "iterations. Summarize ALL findings so far. Do NOT call any more tools." + ), + ) + ) try: _msgs, _ = with_prompt_caching(messages, None, llm_params.get("model")) _t0 = time.monotonic() @@ -502,7 +523,9 @@ async def research_handler( model=research_model, response=response, latency_ms=int((time.monotonic() - _t0) * 1000), - finish_reason=response.choices[0].finish_reason if response.choices else None, + finish_reason=response.choices[0].finish_reason + if response.choices + else None, kind="research", ) except Exception as _telem_err: diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py index 4590b8d67b7251654ab7ca8c0b5b9f9f2ee2b49a..24f85e1fb719549ddca4925490e44e9a275986af 100644 --- a/agent/tools/sandbox_client.py +++ b/agent/tools/sandbox_client.py @@ -729,9 +729,7 @@ class Sandbox: runtime, "requested_hardware", None ) if current_hardware != hardware: - _log( - f" RUNNING on {current_hardware}; waiting for {hardware}..." - ) + _log(f" RUNNING on {current_hardware}; waiting for {hardware}...") time.sleep(WAIT_INTERVAL) continue _log(f"Space is running (hardware: {runtime.hardware})") @@ -767,7 +765,9 @@ class Sandbox: return sb @staticmethod - def _setup_server(space_id: str, api: HfApi, *, log: Callable[[str], object] = print) -> None: + def _setup_server( + space_id: str, api: HfApi, *, log: Callable[[str], object] = print + ) -> None: """Upload embedded sandbox server + Dockerfile to the Space (single commit).""" log(f"Uploading sandbox server to {space_id}...") api.create_commit( @@ -809,7 +809,9 @@ class Sandbox: sb._wait_for_api(timeout=60) return sb - def _wait_for_api(self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print): + def _wait_for_api( + self, timeout: int = API_WAIT_TIMEOUT, log: Callable[[str], object] = print + ): """Poll the health endpoint until the server responds.""" deadline = time.time() + timeout last_err = None @@ -986,7 +988,12 @@ class Sandbox: return result def edit( - self, path: str, old_str: str, new_str: str, *, replace_all: bool = False, + self, + path: str, + old_str: str, + new_str: str, + *, + replace_all: bool = False, mode: str = "replace", ) -> ToolResult: if old_str == new_str: diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py index 81d3f8e9e96be5daa805cff5370aca341007d37e..4d643f4ce5ac5e4c887b1d31fada2d7746d68ccd 100644 --- a/agent/tools/sandbox_tool.py +++ b/agent/tools/sandbox_tool.py @@ -21,6 +21,7 @@ from typing import Any from huggingface_hub import HfApi, SpaceHardware +from agent.core.hub_artifacts import wrap_shell_command_with_hub_artifact_bootstrap from agent.core.session import Event from agent.tools.sandbox_client import Sandbox from agent.tools.trackio_seed import ensure_trackio_dashboard @@ -197,7 +198,9 @@ def _cleanup_user_orphan_sandboxes( if not _SANDBOX_NAME_RE.match(space_name): continue - last_mod = getattr(space, "lastModified", None) or getattr(space, "last_modified", None) + last_mod = getattr(space, "lastModified", None) or getattr( + space, "last_modified", None + ) if isinstance(last_mod, str): try: last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00")) @@ -337,6 +340,7 @@ async def _create_sandbox_locked( if hardware != DEFAULT_CPU_SANDBOX_HARDWARE: kwargs["sleep_time"] = 2700 import time as _t + _t_start = _t.monotonic() try: sb = await asyncio.to_thread(Sandbox.create, **kwargs) @@ -350,7 +354,9 @@ async def _create_sandbox_locked( try: await asyncio.to_thread(sb.delete) except Exception as e: - logger.warning("Failed to delete cancelled sandbox %s: %s", sb.space_id, e) + logger.warning( + "Failed to delete cancelled sandbox %s: %s", sb.space_id, e + ) return None, "Sandbox creation cancelled by user." session.sandbox = sb @@ -360,8 +366,11 @@ async def _create_sandbox_locked( # Telemetry: sandbox creation (infra consumption signal) from agent.core import telemetry + await telemetry.record_sandbox_create( - session, sb, hardware=hardware, + session, + sb, + hardware=hardware, create_latency_s=int(_t.monotonic() - _t_start), ) @@ -510,12 +519,13 @@ async def teardown_session_sandbox(session: Any) -> None: ) await asyncio.to_thread(sandbox.delete) from agent.core import telemetry + await telemetry.record_sandbox_destroy(session, sandbox) return except Exception as e: last_err = e if attempt < 2: - await asyncio.sleep(2 ** attempt) + await asyncio.sleep(2**attempt) logger.error( "Failed to delete sandbox %s after 3 attempts: %s. " "Orphan — sweep script will pick it up.", @@ -720,6 +730,14 @@ def _make_tool_handler(sandbox_tool_name: str): return "Sandbox is still starting. Please retry shortly.", False try: + if sandbox_tool_name == "bash" and args.get("command"): + args = { + **args, + "command": wrap_shell_command_with_hub_artifact_bootstrap( + args["command"], + session, + ), + } result = await asyncio.to_thread(sb.call_tool, sandbox_tool_name, args) if result.success: output = result.output or "(no output)" @@ -758,8 +776,7 @@ def get_sandbox_tools(): description = ( "Uses the session's active sandbox. A private cpu-basic sandbox is " "started automatically for normal CPU work; call sandbox_create only " - "for GPU or other non-default hardware.\n\n" - + spec["description"] + "for GPU or other non-default hardware.\n\n" + spec["description"] ) tools.append( ToolSpec( diff --git a/agent/tools/web_search_tool.py b/agent/tools/web_search_tool.py index 3e52ded03c1f405076e1ecc537d0b4250862f562..5c18410855bebdee305997d90de4c9e56f942461 100644 --- a/agent/tools/web_search_tool.py +++ b/agent/tools/web_search_tool.py @@ -253,7 +253,10 @@ async def web_search_handler( ) -> tuple[str, bool]: query_value = arguments.get("query", "") if not isinstance(query_value, str): - return "Error: web_search requires a query string with at least 2 characters.", False + return ( + "Error: web_search requires a query string with at least 2 characters.", + False, + ) query = query_value.strip() if len(query) < 2: diff --git a/agent/utils/braille.py b/agent/utils/braille.py index 3b6ee43187dc5b1482109e36a7548f9e8199731a..4621b735b7cff25d453afbc93f443f2bae4e7e4b 100644 --- a/agent/utils/braille.py +++ b/agent/utils/braille.py @@ -41,8 +41,7 @@ class BrailleCanvas: for row in range(self.term_height): offset = row * self.term_width line = "".join( - chr(0x2800 + self._buf[offset + col]) - for col in range(self.term_width) + chr(0x2800 + self._buf[offset + col]) for col in range(self.term_width) ) lines.append(line) return lines @@ -52,6 +51,7 @@ class BrailleCanvas: _FONT: dict[str, list[str]] = {} + def _define_font() -> None: """Define a simple 5×7 bitmap font for uppercase ASCII.""" glyphs = { @@ -113,8 +113,9 @@ def text_to_pixels(text: str, scale: int = 1) -> list[tuple[int, int]]: if cell == "#": for sy in range(scale): for sx in range(scale): - pixels.append((cursor_x + col_idx * scale + sx, - row_idx * scale + sy)) + pixels.append( + (cursor_x + col_idx * scale + sx, row_idx * scale + sy) + ) glyph_width = max(len(r) for r in glyph) cursor_x += (glyph_width + 1) * scale return pixels diff --git a/agent/utils/crt_boot.py b/agent/utils/crt_boot.py index f36ea50e863f6a16e315439ac5bb68eceffd4a40..da0867188961ff08952005c7d098879dfd2a4279 100644 --- a/agent/utils/crt_boot.py +++ b/agent/utils/crt_boot.py @@ -55,7 +55,10 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No # Render previously completed lines for prev_text, prev_style in displayed_lines: if rng.random() < prev_glitch_chance: - result.append(_glitch_text(prev_text, prev_glitch_intensity, rng), style=prev_style) + result.append( + _glitch_text(prev_text, prev_glitch_intensity, rng), + style=prev_style, + ) else: result.append(prev_text, style=prev_style) result.append("\n") @@ -86,7 +89,7 @@ def run_boot_sequence(console: Console, boot_lines: list[tuple[str, str]]) -> No live.update(result) # Variable typing speed - if line_text[char_idx - 1:char_idx] in " .": + if line_text[char_idx - 1 : char_idx] in " .": time.sleep(0.025) else: time.sleep(0.010) diff --git a/agent/utils/particle_logo.py b/agent/utils/particle_logo.py index 7b2ff5a3621c8dada105afb0be710ea71f1327cc..9c3338152a8b2fd29031c4eadaa19e9078f6da2b 100644 --- a/agent/utils/particle_logo.py +++ b/agent/utils/particle_logo.py @@ -23,7 +23,9 @@ from agent.utils.boot_timing import settle_curve, warm_gold_from_white class Particle: __slots__ = ("x", "y", "target_x", "target_y", "vx", "vy", "phase", "delay") - def __init__(self, x: float, y: float, target_x: float, target_y: float, delay: float = 0): + def __init__( + self, x: float, y: float, target_x: float, target_y: float, delay: float = 0 + ): self.x = x self.y = y self.target_x = target_x diff --git a/agent/utils/terminal_display.py b/agent/utils/terminal_display.py index f2b73301b6b39b77db9bab90e4086a2f551d8f38..d464fd8a727de131268a420753c811273943bfde 100644 --- a/agent/utils/terminal_display.py +++ b/agent/utils/terminal_display.py @@ -2,6 +2,7 @@ Terminal display utilities — rich-powered CLI formatting. """ +import asyncio import re from rich.console import Console @@ -57,23 +58,26 @@ def _clip_to_width(s: str, width: int) -> str: out.append("\033[0m…") return "".join(out) -_THEME = Theme({ - "tool.name": "bold rgb(255,200,80)", - "tool.args": "dim", - "tool.ok": "dim green", - "tool.fail": "dim red", - "info": "dim", - "muted": "dim", - # Markdown emphasis colors - "markdown.strong": "bold rgb(255,200,80)", - "markdown.emphasis": "italic rgb(180,140,40)", - "markdown.code": "rgb(120,220,255)", - "markdown.code_block": "rgb(120,220,255)", - "markdown.link": "underline rgb(90,180,255)", - "markdown.h1": "bold rgb(255,200,80)", - "markdown.h2": "bold rgb(240,180,95)", - "markdown.h3": "bold rgb(220,165,100)", -}) + +_THEME = Theme( + { + "tool.name": "bold rgb(255,200,80)", + "tool.args": "dim", + "tool.ok": "dim green", + "tool.fail": "dim red", + "info": "dim", + "muted": "dim", + # Markdown emphasis colors + "markdown.strong": "bold rgb(255,200,80)", + "markdown.emphasis": "italic rgb(180,140,40)", + "markdown.code": "rgb(120,220,255)", + "markdown.code_block": "rgb(120,220,255)", + "markdown.link": "underline rgb(90,180,255)", + "markdown.h1": "bold rgb(255,200,80)", + "markdown.h2": "bold rgb(240,180,95)", + "markdown.h3": "bold rgb(220,165,100)", + } +) _console = Console(theme=_THEME, highlight=False) @@ -87,6 +91,7 @@ def get_console() -> Console: # ── Banner ───────────────────────────────────────────────────────────── + def print_banner(model: str | None = None, hf_user: str | None = None) -> None: """Print particle logo then CRT boot sequence with system info.""" from agent.utils.particle_logo import run_particle_logo @@ -120,12 +125,16 @@ def print_banner(model: str | None = None, hf_user: str | None = None) -> None: # ── Init progress ────────────────────────────────────────────────────── + def print_init_done(tool_count: int = 0) -> None: import time + f = _console.file # Overwrite the "Tools: loading..." line with actual count - f.write(f"\033[A\033[A\033[A\033[K") # Move up 3 lines (blank + help + blank) then up to tools line - f.write(f"\033[A\033[K") + f.write( + "\033[A\033[A\033[A\033[K" + ) # Move up 3 lines (blank + help + blank) then up to tools line + f.write("\033[A\033[K") gold = "\033[38;2;180;140;40m" reset = "\033[0m" tool_text = f"{_I} Tools: {tool_count} loaded" @@ -135,16 +144,22 @@ def print_init_done(tool_count: int = 0) -> None: time.sleep(0.012) f.write("\n\n") # Reprint the help line - f.write(f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n") + f.write( + f"{_I}\033[38;2;255;200;80m/help for commands · /model to switch · /quit to exit{reset}\n\n" + ) # Ready message — minimal padding - f.write(f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n") + f.write( + f"{_I}\033[38;2;255;200;80mReady. Let's build something impressive.{reset}\n" + ) f.flush() # ── Tool calls ───────────────────────────────────────────────────────── + def print_tool_call(tool_name: str, args_preview: str) -> None: import time + f = _console.file # CRT-style: type out tool name in HF yellow gold = "\033[38;2;255;200;80m" @@ -183,6 +198,7 @@ class SubAgentDisplayManager: def start(self, agent_id: str, label: str = "research") -> None: import time + self._agents[agent_id] = { "label": label, "calls": [], @@ -234,6 +250,7 @@ class SubAgentDisplayManager: @staticmethod def _format_stats(agent: dict) -> str: import time + start = agent["start_time"] if start is None: return "" @@ -276,7 +293,7 @@ class SubAgentDisplayManager: header += f" \033[2m·\033[0m \033[2m{short}\033[0m" return [header] lines = [header] - visible = agent["calls"][-self._MAX_VISIBLE:] + visible = agent["calls"][-self._MAX_VISIBLE :] for desc in visible: lines.append(f"{_I} \033[2m{desc}\033[0m") return lines @@ -319,13 +336,14 @@ def print_tool_log(tool: str, log: str, agent_id: str = "", label: str = "") -> # ── Messages ─────────────────────────────────────────────────────────── + async def print_markdown( text: str, cancel_event: "asyncio.Event | None" = None, instant: bool = False, ) -> None: - import asyncio - import io, random + import io + import random from rich.padding import Padding _console.print() @@ -395,23 +413,35 @@ def print_interrupted() -> None: def print_compacted(old_tokens: int, new_tokens: int) -> None: - _console.print(f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]") + _console.print( + f"{_I}[dim]context compacted: {old_tokens:,} → {new_tokens:,} tokens[/dim]" + ) # ── Approval ─────────────────────────────────────────────────────────── + def print_approval_header(count: int) -> None: label = f"Approval required — {count} item{'s' if count != 1 else ''}" _console.print() - _console.print(f"{_I}", Panel(f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False)) + _console.print( + f"{_I}", + Panel( + f"[bold yellow]{label}[/bold yellow]", border_style="yellow", expand=False + ), + ) def print_approval_item(index: int, total: int, tool_name: str, operation: str) -> None: - _console.print(f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}") + _console.print( + f"\n{_I}[bold]\\[{index}/{total}][/bold] [tool.name]{tool_name}[/tool.name] {operation}" + ) def print_yolo_approve(count: int) -> None: - _console.print(f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)") + _console.print( + f"{_I}[bold yellow]yolo →[/bold yellow] auto-approved {count} item(s)" + ) # ── Help ─────────────────────────────────────────────────────────────── @@ -437,6 +467,7 @@ def print_help() -> None: # ── Plan display ─────────────────────────────────────────────────────── + def format_plan_display() -> str: """Format the current plan for display.""" from agent.tools.plan_tool import get_current_plan @@ -470,6 +501,7 @@ def print_plan() -> None: # ── Formatting for plan_tool output (used by plan_tool handler) ──────── + def format_plan_tool_output(todos: list) -> str: if not todos: return "Plan is empty." @@ -492,6 +524,7 @@ def format_plan_tool_output(todos: list) -> str: # ── Internal helpers ─────────────────────────────────────────────────── + def _truncate(text: str, max_lines: int = 6) -> str: lines = text.split("\n") if len(lines) <= max_lines: diff --git a/backend/dependencies.py b/backend/dependencies.py index 3ca14136ac29273627906a183b1a57f6196c7e75..c207890ad731b3e43e0acecc9f9092bf8210a570 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -102,7 +102,9 @@ async def _fetch_user_plan(token: str) -> str: _WHOAMI_SHAPE_LOGGED = True logger.debug( "whoami-v2 payload keys: %s (sample values: plan=%r type=%r isPro=%r)", - sorted(whoami.keys()) if isinstance(whoami, dict) else type(whoami).__name__, + sorted(whoami.keys()) + if isinstance(whoami, dict) + else type(whoami).__name__, whoami.get("plan") if isinstance(whoami, dict) else None, whoami.get("type") if isinstance(whoami, dict) else None, whoami.get("isPro") if isinstance(whoami, dict) else None, diff --git a/backend/kpis_scheduler.py b/backend/kpis_scheduler.py index f044c8ee5fc8e3a855528a41d6b17769cd8c0b2b..9b2199c69151118762ed2cfaddde579fb5a694d3 100644 --- a/backend/kpis_scheduler.py +++ b/backend/kpis_scheduler.py @@ -58,7 +58,8 @@ def _resolve_token() -> Optional[str]: def _load_build_kpis(): """Import ``scripts/build_kpis.py`` without putting ``scripts/`` on sys.path.""" spec = importlib.util.spec_from_file_location( - "build_kpis", _PROJECT_ROOT / "scripts" / "build_kpis.py", + "build_kpis", + _PROJECT_ROOT / "scripts" / "build_kpis.py", ) mod = importlib.util.module_from_spec(spec) assert spec.loader is not None @@ -75,6 +76,7 @@ async def _run_hour(hour_dt: datetime) -> None: try: mod = _load_build_kpis() from huggingface_hub import HfApi + api = HfApi() source = os.environ.get("KPI_SOURCE_REPO", "smolagents/ml-intern-sessions") target = os.environ.get("KPI_TARGET_REPO", "smolagents/ml-intern-kpis") @@ -118,7 +120,7 @@ def start(backfill_hours: int = 6) -> None: CronTrigger(minute=5), id="kpis_hourly", misfire_grace_time=600, # tolerate a 10-min misfire window - coalesce=True, # collapse multiple missed fires into one + coalesce=True, # collapse multiple missed fires into one max_instances=1, replace_existing=True, ) diff --git a/backend/main.py b/backend/main.py index f6bc64d10167de32763d5c2f9f4bcc01f69eab57..3a6871f266637f17da0645aae6403155e4f7d6bf 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,17 +6,17 @@ from contextlib import asynccontextmanager from pathlib import Path from dotenv import load_dotenv +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles # Load .env before importing routes/session_manager so persistence and quota # modules see local Mongo settings during startup. load_dotenv(Path(__file__).parent.parent / ".env") -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from routes.agent import router as agent_router -from routes.auth import router as auth_router -from session_manager import session_manager +from routes.agent import router as agent_router # noqa: E402 +from routes.auth import router as auth_router # noqa: E402 +from session_manager import session_manager # noqa: E402 # Configure logging logging.basicConfig( @@ -35,6 +35,7 @@ async def lifespan(app: FastAPI): # rollup lives next to the data and reuses the Space's HF token. try: import kpis_scheduler + kpis_scheduler.start() except Exception as e: logger.warning("KPI scheduler failed to start: %s", e) @@ -43,6 +44,7 @@ async def lifespan(app: FastAPI): logger.info("Shutting down HF Agent backend...") try: import kpis_scheduler + await kpis_scheduler.shutdown() except Exception as e: logger.warning("KPI scheduler shutdown failed: %s", e) diff --git a/backend/models.py b/backend/models.py index 40a725096968cff9c46f26d981daef3a0ea82810..d74557a51df06392381cc2d3aa0463f547cb8496 100644 --- a/backend/models.py +++ b/backend/models.py @@ -131,4 +131,6 @@ class LLMHealthResponse(BaseModel): status: str # "ok" | "error" model: str error: str | None = None - error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" + error_type: str | None = ( + None # "auth" | "credits" | "rate_limit" | "network" | "unknown" + ) diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 74af774b6895f4b9eeb94e007b0d70a4caa56327..5de4ace5848e4c03f7173dda46736498e0548c2c 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -7,7 +7,6 @@ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. import asyncio import json import logging -import os from typing import Any from dependencies import ( @@ -34,7 +33,12 @@ from models import ( SubmitRequest, TruncateRequest, ) -from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, session_manager +from session_manager import ( + MAX_SESSIONS, + AgentSession, + SessionCapacityError, + session_manager, +) import user_quotas @@ -136,7 +140,7 @@ async def _require_hf_for_gated_model(request: Request, model_id: str) -> None: """403 if a non-``huggingface``-org user tries to select a gated model. Gated models are deployed paid endpoints backed by service-owned - credentials. The gate only fires for deployed paid models so non-HF users + credentials. The gate only fires for deployed paid models so non-HF users can still freely switch between the free models. """ if not _is_gated_model(model_id): @@ -226,7 +230,11 @@ async def _check_session_access( preload_sandbox: bool = True, ) -> AgentSession: """Verify and lazily load the user's session. Raises 403 or 404.""" - hf_token = resolve_hf_request_token(request) if request is not None else _user_hf_token(user) + hf_token = ( + resolve_hf_request_token(request) + if request is not None + else _user_hf_token(user) + ) agent_session = await session_manager.ensure_session_loaded( session_id, user["user_id"], @@ -236,7 +244,10 @@ async def _check_session_access( ) if not agent_session: raise HTTPException(status_code=404, detail="Session not found") - if user["user_id"] != "dev" and agent_session.user_id not in {user["user_id"], "dev"}: + if user["user_id"] != "dev" and agent_session.user_id not in { + user["user_id"], + "dev", + }: raise HTTPException(status_code=403, detail="Access denied to this session") return agent_session @@ -362,7 +373,9 @@ async def generate_title( await _check_session_access(request.session_id, user) await session_manager.update_session_title(request.session_id, title) except Exception: - logger.debug("Skipping title persistence for missing session %s", request.session_id) + logger.debug( + "Skipping title persistence for missing session %s", request.session_id + ) return {"title": title} except Exception as e: logger.warning(f"Title generation failed: {e}") @@ -372,7 +385,10 @@ async def generate_title( await _check_session_access(request.session_id, user) await session_manager.update_session_title(request.session_id, title) except Exception: - logger.debug("Skipping fallback title persistence for missing session %s", request.session_id) + logger.debug( + "Skipping fallback title persistence for missing session %s", + request.session_id, + ) return {"title": title} @@ -586,7 +602,9 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: @router.get("/user/jobs-access") -async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict: +async def get_jobs_access_info( + request: Request, user: dict = Depends(get_current_user) +) -> dict: """Return the namespaces the current token can run HF Jobs under. Credits are enforced by the HF API at job-creation time, not here — @@ -652,7 +670,7 @@ async def submit_approval( request: ApprovalRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit tool approvals to a session. Only accessible by the session owner.""" - agent_session = await _check_session_access(request.session_id, user) + await _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, @@ -719,7 +737,9 @@ async def chat_sse( success = await session_manager.submit_user_input(session_id, text) else: broadcaster.unsubscribe(sub_id) - raise HTTPException(status_code=400, detail="Must provide 'text' or 'approvals'") + raise HTTPException( + status_code=400, detail="Must provide 'text' or 'approvals'" + ) if not success: broadcaster.unsubscribe(sub_id) @@ -744,6 +764,7 @@ async def record_pro_click( agent_session = await _check_session_access(session_id, user) from agent.core import telemetry + await telemetry.record_pro_cta_click( agent_session.session, source=str(body.get("source") or "unknown"), @@ -759,12 +780,20 @@ async def record_pro_click( # --------------------------------------------------------------------------- # Shared SSE helpers # --------------------------------------------------------------------------- -_TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted", "shutdown"} +_TERMINAL_EVENTS = { + "turn_complete", + "approval_required", + "error", + "interrupted", + "shutdown", +} _SSE_KEEPALIVE_SECONDS = 15 def _last_event_seq(request: Request) -> int: - raw = request.headers.get("last-event-id") or request.query_params.get("after") or "0" + raw = ( + request.headers.get("last-event-id") or request.query_params.get("after") or "0" + ) try: return max(0, int(raw)) except (TypeError, ValueError): @@ -853,7 +882,9 @@ async def subscribe_events( raise HTTPException(status_code=404, detail="Session not found or inactive") after_seq = _last_event_seq(request) - replay_events = await session_manager._store().load_events_after(session_id, after_seq) + replay_events = await session_manager._store().load_events_after( + session_id, after_seq + ) broadcaster = agent_session.broadcaster sub_id, event_queue = broadcaster.subscribe() return _sse_response( @@ -885,7 +916,10 @@ async def get_session_messages( agent_session = await _check_session_access(session_id, user) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") - return [msg.model_dump(mode="json") for msg in agent_session.session.context_manager.items] + return [ + msg.model_dump(mode="json") + for msg in agent_session.session.context_manager.items + ] @router.post("/undo/{session_id}") @@ -906,7 +940,10 @@ async def truncate_session( await _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: - raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") + raise HTTPException( + status_code=404, + detail="Session not found, inactive, or message index out of range", + ) return {"status": "truncated", "session_id": session_id} @@ -933,6 +970,7 @@ async def shutdown_session( raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} + @router.post("/feedback/{session_id}") async def submit_feedback( session_id: str, @@ -952,6 +990,7 @@ async def submit_feedback( raise HTTPException(status_code=400, detail="invalid rating") from agent.core import telemetry + await telemetry.record_feedback( agent_session.session, rating=rating, diff --git a/backend/routes/auth.py b/backend/routes/auth.py index 5b7895c1b8e35a840606f5f443846d9f2f431b04..13df4fd1db76015e7418988ac8509a4b13a398bd 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -168,4 +168,3 @@ async def get_me(user: dict = Depends(get_current_user)) -> dict: Uses the shared auth dependency which handles cookie + Bearer token. """ return {key: value for key, value in user.items() if not key.startswith("_")} - diff --git a/backend/session_manager.py b/backend/session_manager.py index 04bca4107c3b825ade474130da9602c1c0bcb664..449ce3a0e06737ec5470d5fafa182c9a92b2eae0 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -12,10 +12,11 @@ from typing import Any, Optional from agent.config import load_config from agent.core.agent_loop import process_submission -from agent.messaging.gateway import NotificationGateway +from agent.core.hub_artifacts import start_session_artifact_collection_task from agent.core.session import Event, OpType, Session from agent.core.session_persistence import get_session_store from agent.core.tools import ToolRouter +from agent.messaging.gateway import NotificationGateway # Get project root (parent of backend directory) PROJECT_ROOT = Path(__file__).parent.parent @@ -70,7 +71,11 @@ class EventBroadcaster: while True: try: event: Event = await self._source.get() - msg = {"event_type": event.event_type, "data": event.data, "seq": event.seq} + msg = { + "event_type": event.event_type, + "data": event.data, + "seq": event.seq, + } for q in self._subscribers.values(): await q.put(msg) except asyncio.CancelledError: @@ -131,6 +136,7 @@ class SessionManager: self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() self.persistence_store = None + self.enable_hub_artifact_collections = True async def start(self) -> None: """Start shared background resources.""" @@ -153,9 +159,7 @@ class SessionManager: def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" return sum( - 1 - for s in self.sessions.values() - if s.user_id == user_id and s.is_active + 1 for s in self.sessions.values() if s.user_id == user_id and s.is_active ) def _create_session_sync( @@ -196,10 +200,7 @@ class SessionManager: return tool_router, session def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: - return [ - msg.model_dump(mode="json") - for msg in session.context_manager.items - ] + return [msg.model_dump(mode="json") for msg in session.context_manager.items] def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: pending = session.pending_approval or {} @@ -307,7 +308,9 @@ class SessionManager: if hasattr(session, "auto_approval_policy_summary"): return session.auto_approval_policy_summary() cap = getattr(session, "auto_approval_cost_cap_usd", None) - estimated = float(getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0) + estimated = float( + getattr(session, "auto_approval_estimated_spend_usd", 0.0) or 0.0 + ) remaining = None if cap is None else round(max(0.0, float(cap) - estimated), 4) return { "enabled": bool(getattr(session, "auto_approval_enabled", False)), @@ -410,6 +413,28 @@ class SessionManager: session.sandbox_preload_cancel_event = None self._start_cpu_sandbox_preload(agent_session) + def _start_hub_artifact_collection(self, agent_session: AgentSession) -> None: + """Kick off best-effort Hub collection creation for the session.""" + if not getattr(self, "enable_hub_artifact_collections", False): + return + session = agent_session.session + if not getattr(session, "session_id", None): + try: + session.session_id = agent_session.session_id + except Exception: + logger.debug("Could not attach session id for Hub artifact collection") + token = agent_session.hf_token or getattr(session, "hf_token", None) + if not token: + return + try: + start_session_artifact_collection_task(session, token=token) + except Exception as e: + logger.debug( + "Failed to schedule Hub artifact collection for %s: %s", + agent_session.session_id, + e, + ) + async def _clear_persisted_sandbox_metadata(self, session_id: str) -> None: try: await self._store().update_session_fields( @@ -514,7 +539,9 @@ class SessionManager: runtime_state=runtime_state or self._runtime_state(agent_session), status=status, turn_count=agent_session.session.turn_count, - pending_approval=self._serialize_pending_approval(agent_session.session), + pending_approval=self._serialize_pending_approval( + agent_session.session + ), claude_counted=agent_session.claude_counted, created_at=agent_session.created_at, notification_destinations=list( @@ -564,6 +591,7 @@ class SessionManager: existing, preload_sandbox=preload_sandbox, ) + self._start_hub_artifact_collection(existing) return existing return None @@ -585,6 +613,7 @@ class SessionManager: existing, preload_sandbox=preload_sandbox, ) + self._start_hub_artifact_collection(existing) return existing return None @@ -626,7 +655,10 @@ class SessionManager: if restored_messages: # Keep the freshly-rendered system prompt, then attach the durable # non-system context so tools/date/user context stay current. - session.context_manager.items = [session.context_manager.items[0], *restored_messages] + session.context_manager.items = [ + session.context_manager.items[0], + *restored_messages, + ] self._restore_pending_approval(session, meta.get("pending_approval") or []) session.turn_count = int(meta.get("turn_count") or 0) @@ -668,7 +700,9 @@ class SessionManager: hf_token=hf_token, hf_username=hf_username, ) + self._start_hub_artifact_collection(started) return started + self._start_hub_artifact_collection(agent_session) if preload_sandbox: self._start_cpu_sandbox_preload(agent_session) logger.info("Restored session %s for user %s", session_id, owner or user_id) @@ -751,6 +785,7 @@ class SessionManager: event_queue=event_queue, tool_router=tool_router, ) + self._start_hub_artifact_collection(agent_session) await self.persist_session_snapshot(agent_session, runtime_state="idle") self._start_cpu_sandbox_preload(agent_session) @@ -760,7 +795,9 @@ class SessionManager: logger.info(f"Created session {session_id} for user {user_id}") return session_id - async def _track_pro_status(self, agent_session: AgentSession, *, is_pro: bool) -> None: + async def _track_pro_status( + self, agent_session: AgentSession, *, is_pro: bool + ) -> None: """Update Mongo per-user Pro state and emit a one-shot conversion event if the store reports a free→Pro transition. Best-effort: any Mongo failure is swallowed so we never fail session creation on @@ -777,6 +814,7 @@ class SessionManager: return try: from agent.core import telemetry + await telemetry.record_pro_conversion( agent_session.session, first_seen_at=result.get("first_seen_at"), @@ -933,7 +971,9 @@ class SessionManager: ) agent_session.is_processing = True try: - should_continue = await process_submission(session, submission) + should_continue = await process_submission( + session, submission + ) finally: agent_session.is_processing = False await self.persist_session_snapshot(agent_session) @@ -964,7 +1004,9 @@ class SessionManager: # Idempotent via session_id key; detached subprocess. if session.config.save_sessions: try: - session.save_and_upload_detached(session.config.session_dataset_repo) + session.save_and_upload_detached( + session.config.session_dataset_repo + ) except Exception as e: logger.warning(f"Final-flush failed for {session_id}: {e}") @@ -1025,7 +1067,9 @@ class SessionManager: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False - success = agent_session.session.context_manager.truncate_to_user_message(user_message_index) + success = agent_session.session.context_manager.truncate_to_user_message( + user_message_index + ) if success: await self.persist_session_snapshot(agent_session, runtime_state="idle") return success @@ -1118,9 +1162,7 @@ class SessionManager: session = agent_session.session if enabled: if not cap_provided and cost_cap_usd is None: - cost_cap_usd = getattr( - session, "auto_approval_cost_cap_usd", None - ) + cost_cap_usd = getattr(session, "auto_approval_cost_cap_usd", None) if cost_cap_usd is None: cost_cap_usd = DEFAULT_YOLO_COST_CAP_USD elif cost_cap_usd is None: @@ -1203,9 +1245,7 @@ class SessionManager: if destination is None: raise ValueError(f"Unknown destination '{name}'") if not destination.allow_auto_events: - raise ValueError( - f"Destination '{name}' is not enabled for auto events" - ) + raise ValueError(f"Destination '{name}' is not enabled for auto events") if name not in seen: normalized.append(name) seen.add(name) @@ -1248,7 +1288,10 @@ class SessionManager: "pending_approval": pending or None, "model": row.get("model"), "title": row.get("title"), - "notification_destinations": row.get("notification_destinations") or [], + "notification_destinations": row.get( + "notification_destinations" + ) + or [], "auto_approval": { "enabled": bool(row.get("auto_approval_enabled", False)), "cost_cap_usd": row.get("auto_approval_cost_cap_usd"), @@ -1261,8 +1304,13 @@ class SessionManager: else round( max( 0.0, - float(row.get("auto_approval_cost_cap_usd") or 0.0) - - float(row.get("auto_approval_estimated_spend_usd") or 0.0), + float( + row.get("auto_approval_cost_cap_usd") or 0.0 + ) + - float( + row.get("auto_approval_estimated_spend_usd") + or 0.0 + ), ), 4, ) diff --git a/backend/user_quotas.py b/backend/user_quotas.py index 94ce92f0663202fb89de57f62c193fde025f6ef2..0d135f1545a686996fa437e13f9190beb1b449a9 100644 --- a/backend/user_quotas.py +++ b/backend/user_quotas.py @@ -20,7 +20,11 @@ import asyncio import os from datetime import UTC, datetime -from agent.core.session_persistence import NoopSessionStore, get_session_store, _reset_store_for_tests +from agent.core.session_persistence import ( + NoopSessionStore, + get_session_store, + _reset_store_for_tests, +) CLAUDE_FREE_DAILY: int = int(os.environ.get("CLAUDE_FREE_DAILY", "1")) CLAUDE_PRO_DAILY: int = int(os.environ.get("CLAUDE_PRO_DAILY", "20")) diff --git a/pyproject.toml b/pyproject.toml index c97737534612e4477e5d0f10a14783f854b15255..5642a6dbe57eec315fbdf102e50a681dec774933 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ eval = [ dev = [ "pytest>=9.0.2", "pytest-asyncio>=1.2.0", + "ruff>=0.15.12", ] # All dependencies (eval + dev) diff --git a/scripts/build_kpis.py b/scripts/build_kpis.py index 67ab9183f5a324741cf0e588d9dd571078b37710..47a6515cef6339937f329e4007292d94b2a99b56 100644 --- a/scripts/build_kpis.py +++ b/scripts/build_kpis.py @@ -99,7 +99,6 @@ import sys import tempfile from collections import defaultdict from datetime import date, datetime, timedelta, timezone -from pathlib import Path from typing import Any, Iterable logger = logging.getLogger("build_kpis") @@ -107,13 +106,25 @@ logger = logging.getLogger("build_kpis") # Rough gpu-hour pricing for hf_jobs flavor strings. Keep conservative; used # only to compute gpu-hours (not dollars) — wall_time_s * flavor_gpu_count. _FLAVOR_GPU_COUNT = { - "cpu-basic": 0, "cpu-upgrade": 0, - "t4-small": 1, "t4-medium": 1, - "l4x1": 1, "l4x4": 4, - "l40sx1": 1, "l40sx4": 4, "l40sx8": 8, - "a10g-small": 1, "a10g-large": 1, "a10g-largex2": 2, "a10g-largex4": 4, - "a100-large": 1, "a100x2": 2, "a100x4": 4, "a100x8": 8, - "h100": 1, "h100x8": 8, + "cpu-basic": 0, + "cpu-upgrade": 0, + "t4-small": 1, + "t4-medium": 1, + "l4x1": 1, + "l4x4": 4, + "l40sx1": 1, + "l40sx4": 4, + "l40sx8": 8, + "a10g-small": 1, + "a10g-large": 1, + "a10g-largex2": 2, + "a10g-largex4": 4, + "a100-large": 1, + "a100x2": 2, + "a100x4": 4, + "a100x8": 8, + "h100": 1, + "h100x8": 8, } @@ -160,9 +171,13 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None: directory is near-free. """ from huggingface_hub import hf_hub_download + try: local = hf_hub_download( - repo_id=repo_id, filename=path, repo_type="dataset", token=token, + repo_id=repo_id, + filename=path, + repo_type="dataset", + token=token, ) except Exception as e: logger.warning("hf_hub_download(%s) failed: %s", path, e) @@ -188,7 +203,9 @@ def _download_session(repo_id: str, path: str, token: str) -> dict | None: def _filter_session_to_window( - session: dict, start: datetime, end: datetime, + session: dict, + start: datetime, + end: datetime, ) -> dict | None: """Return a copy of ``session`` whose events are only those in ``[start, end)``. @@ -216,16 +233,29 @@ def _session_metrics(session: dict) -> dict: # Pre-seed every numeric key so downstream aggregation can sum without # having to special-case empty sessions. out: dict = { - "sessions": 0, "turns": 0, "llm_calls": 0, - "tokens_prompt": 0, "tokens_completion": 0, - "tokens_cache_read": 0, "tokens_cache_creation": 0, + "sessions": 0, + "turns": 0, + "llm_calls": 0, + "tokens_prompt": 0, + "tokens_completion": 0, + "tokens_cache_read": 0, + "tokens_cache_creation": 0, "cost_usd": 0.0, - "tool_calls_total": 0, "tool_calls_success": 0, - "failures": 0, "regenerate_sessions": 0, - "thumbs_up": 0, "thumbs_down": 0, - "hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0, - "pro_cta_clicks": 0, "pro_conversions": 0, "credits_topped_up": 0, - "sandboxes_created": 0, "sandboxes_cpu": 0, "sandboxes_gpu": 0, + "tool_calls_total": 0, + "tool_calls_success": 0, + "failures": 0, + "regenerate_sessions": 0, + "thumbs_up": 0, + "thumbs_down": 0, + "hf_jobs_submitted": 0, + "hf_jobs_succeeded": 0, + "hf_jobs_blocked": 0, + "pro_cta_clicks": 0, + "pro_conversions": 0, + "credits_topped_up": 0, + "sandboxes_created": 0, + "sandboxes_cpu": 0, + "sandboxes_gpu": 0, "first_tool_s": -1, } events = session.get("events") or [] @@ -373,7 +403,9 @@ def _session_metrics(session: dict) -> dict: def _aggregate(per_session: list[dict]) -> dict: """Collapse a bucket's worth of session rollups into the final KPI row.""" - ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0] + ttfa_values = [ + s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0 + ] gpu_hours: dict[str, float] = defaultdict(float) for s in per_session: for f, h in (s.get("_gpu_hours_by_flavor") or {}).items(): @@ -395,9 +427,21 @@ def _aggregate(per_session: list[dict]) -> dict: # never reached for the relevant signal — otherwise quiet hours # (status-check sessions, abandoned new conversations) drag every median # to 0 and the chart tells you nothing. - research_calls_nz = [s.get("_research_calls", 0) for s in per_session if s.get("_research_calls", 0) > 0] - distinct_tools_values = [s.get("_distinct_tools_used", 0) for s in per_session if s.get("_distinct_tools_used", 0) > 0] - total_calls_values = [s.get("_total_named_tool_calls", 0) for s in per_session if s.get("_total_named_tool_calls", 0) > 0] + research_calls_nz = [ + s.get("_research_calls", 0) + for s in per_session + if s.get("_research_calls", 0) > 0 + ] + distinct_tools_values = [ + s.get("_distinct_tools_used", 0) + for s in per_session + if s.get("_distinct_tools_used", 0) > 0 + ] + total_calls_values = [ + s.get("_total_named_tool_calls", 0) + for s in per_session + if s.get("_total_named_tool_calls", 0) > 0 + ] # Per-turn intensity: turns>0 is the natural filter here (a session with # 5 turns and 0 tools is a meaningful 0). Don't strip those. calls_per_turn_values = [ @@ -415,7 +459,9 @@ def _aggregate(per_session: list[dict]) -> dict: failures = int(sum(s["failures"] for s in per_session)) regenerates = int(sum(s["regenerate_sessions"] for s in per_session)) research_calls_total = int(sum(s.get("_research_calls", 0) for s in per_session)) - sessions_with_research = sum(1 for s in per_session if s.get("_research_calls", 0) > 0) + sessions_with_research = sum( + 1 for s in per_session if s.get("_research_calls", 0) > 0 + ) # Per-session cost percentiles — chart "median session cost" alongside the # mean so a few $700 outliers don't make you think every session is pricey. @@ -433,17 +479,23 @@ def _aggregate(per_session: list[dict]) -> dict: "tokens_prompt": int(tokens_prompt), "tokens_completion": int(sum(s["tokens_completion"] for s in per_session)), "tokens_cache_read": int(tokens_cache_read), - "tokens_cache_creation": int(sum(s["tokens_cache_creation"] for s in per_session)), + "tokens_cache_creation": int( + sum(s["tokens_cache_creation"] for s in per_session) + ), "cost_usd": round(sum(s["cost_usd"] for s in per_session), 4), # Per-session cost summaries. "cost_per_session_mean": round( sum(s["cost_usd"] for s in per_session) / total_sessions, 6 - ) if total_sessions > 0 else 0.0, + ) + if total_sessions > 0 + else 0.0, "cost_per_session_p50": round(cost_p50, 6), "cost_per_session_p95": round(cost_p95, 6), "cache_hit_ratio": round( tokens_cache_read / (tokens_cache_read + tokens_prompt), 4 - ) if (tokens_cache_read + tokens_prompt) > 0 else 0.0, + ) + if (tokens_cache_read + tokens_prompt) > 0 + else 0.0, # Raw reliability COUNTS (these are what the dashboard shows directly). "tool_calls_total": int(tool_total), "tool_calls_succeeded": int(tool_success), @@ -458,38 +510,56 @@ def _aggregate(per_session: list[dict]) -> dict: "regenerated_sessions": regenerates, # Rates kept for backwards compatibility with anything reading the # KPI dataset directly. - "tool_success_rate": round(tool_success / tool_total, 4) if tool_total > 0 else 0.0, - "failure_rate": round(failures / total_sessions, 4) if total_sessions > 0 else 0.0, - "regenerate_rate": round(regenerates / total_sessions, 4) if total_sessions > 0 else 0.0, + "tool_success_rate": round(tool_success / tool_total, 4) + if tool_total > 0 + else 0.0, + "failure_rate": round(failures / total_sessions, 4) + if total_sessions > 0 + else 0.0, + "regenerate_rate": round(regenerates / total_sessions, 4) + if total_sessions > 0 + else 0.0, "time_to_first_action_s_p50": round(_percentile(ttfa_values, 0.5), 2), "time_to_first_action_s_p95": round(_percentile(ttfa_values, 0.95), 2), "thumbs_up": int(sum(s["thumbs_up"] for s in per_session)), "thumbs_down": int(sum(s["thumbs_down"] for s in per_session)), "hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)), "hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)), - "sandboxes_created": int(sum(s.get("sandboxes_created", 0) for s in per_session)), + "sandboxes_created": int( + sum(s.get("sandboxes_created", 0) for s in per_session) + ), "sandboxes_cpu": int(sum(s.get("sandboxes_cpu", 0) for s in per_session)), "sandboxes_gpu": int(sum(s.get("sandboxes_gpu", 0) for s in per_session)), "hf_jobs_blocked": int(sum(s.get("hf_jobs_blocked", 0) for s in per_session)), "pro_cta_clicks": int(sum(s.get("pro_cta_clicks", 0) for s in per_session)), "pro_conversions": int(sum(s.get("pro_conversions", 0) for s in per_session)), - "credits_topped_up": int(sum(s.get("credits_topped_up", 0) for s in per_session)), + "credits_topped_up": int( + sum(s.get("credits_topped_up", 0) for s in per_session) + ), "gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True), # Research KPIs — answer "is the agent reaching for research?". "research_calls": research_calls_total, "sessions_with_research": int(sessions_with_research), "research_calls_per_session_p50": round(_percentile(research_calls_nz, 0.5), 2), - "research_calls_per_session_p95": round(_percentile(research_calls_nz, 0.95), 2), + "research_calls_per_session_p95": round( + _percentile(research_calls_nz, 0.95), 2 + ), # Intra-session breadth + intensity. p50 + p95 over per-session values. - "distinct_tools_per_session_p50": round(_percentile(distinct_tools_values, 0.5), 2), - "distinct_tools_per_session_p95": round(_percentile(distinct_tools_values, 0.95), 2), + "distinct_tools_per_session_p50": round( + _percentile(distinct_tools_values, 0.5), 2 + ), + "distinct_tools_per_session_p95": round( + _percentile(distinct_tools_values, 0.95), 2 + ), "tool_calls_per_session_p50": round(_percentile(total_calls_values, 0.5), 2), "tool_calls_per_session_p95": round(_percentile(total_calls_values, 0.95), 2), "tool_calls_per_turn_p50": round(_percentile(calls_per_turn_values, 0.5), 2), "tool_calls_per_turn_p95": round(_percentile(calls_per_turn_values, 0.95), 2), # JSON columns let the dashboard add/remove tools without schema churn. "tool_calls_by_name_json": json.dumps(dict(tool_calls_by_name), sort_keys=True), - "sessions_using_tool_json": json.dumps(dict(sessions_using_tool), sort_keys=True), + "sessions_using_tool_json": json.dumps( + dict(sessions_using_tool), sort_keys=True + ), # Surface split — answers "is research dropping on Bedrock specifically?". "sessions_by_model_json": json.dumps(dict(sessions_by_model), sort_keys=True), } @@ -507,7 +577,12 @@ def _csv_cell(v: Any) -> str: def _write_csv( - api, row: dict, bucket_key: str, path_in_repo: str, target_repo: str, token: str, + api, + row: dict, + bucket_key: str, + path_in_repo: str, + target_repo: str, + token: str, ) -> None: """Render ``row`` to CSV with a leading ``bucket`` column and upload. @@ -527,7 +602,10 @@ def _write_csv( try: api.create_repo( - repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token, + repo_id=target_repo, + repo_type="dataset", + exist_ok=True, + token=token, ) api.upload_file( path_or_fileobj=tmp_path, @@ -545,7 +623,11 @@ def _write_csv( def run_for_hour( - api, source_repo: str, target_repo: str, hour_dt: datetime, token: str, + api, + source_repo: str, + target_repo: str, + hour_dt: datetime, + token: str, ) -> dict: """Roll up one UTC hour [hour_dt, hour_dt+1h). @@ -579,10 +661,16 @@ def run_for_hour( row = _aggregate(per_session) bucket_key = window_start.strftime("%Y-%m-%dT%H") - path_in_repo = f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv" + path_in_repo = ( + f"hourly/{window_start.strftime('%Y-%m-%d')}/{window_start.strftime('%H')}.csv" + ) _write_csv(api, row, bucket_key, path_in_repo, target_repo, token) - logger.info("Wrote KPIs for %s (%d sessions): %s", - bucket_key, per_session and len(per_session), row) + logger.info( + "Wrote KPIs for %s (%d sessions): %s", + bucket_key, + per_session and len(per_session), + row, + ) return row @@ -618,17 +706,23 @@ def main(argv: list[str] | None = None) -> int: ap.add_argument("--source", default="smolagents/ml-intern-sessions") ap.add_argument("--target", default="smolagents/ml-intern-kpis") ap.add_argument( - "--hours", type=int, default=1, + "--hours", + type=int, + default=1, help="Number of trailing hours to roll up (default: 1 = last completed hour).", ) ap.add_argument( - "--datetime", type=str, default=None, + "--datetime", + type=str, + default=None, help="Single hour, ISO ``YYYY-MM-DDTHH`` (UTC); overrides --hours.", ) ap.add_argument( - "--daily-backfill", type=str, default=None, + "--daily-backfill", + type=str, + default=None, help="Escape hatch: aggregate a whole day at once (YYYY-MM-DD). " - "Writes to daily/.csv. Use for historical backfill only.", + "Writes to daily/.csv. Use for historical backfill only.", ) args = ap.parse_args(argv) @@ -646,10 +740,17 @@ def main(argv: list[str] | None = None) -> int: return 1 from huggingface_hub import HfApi + api = HfApi() if args.daily_backfill: - run_for_day(api, args.source, args.target, date.fromisoformat(args.daily_backfill), token) + run_for_day( + api, + args.source, + args.target, + date.fromisoformat(args.daily_backfill), + token, + ) return 0 if args.datetime: diff --git a/scripts/build_sft.py b/scripts/build_sft.py index ac2344c9a45ffec0bf930fddefd77dedd862c390..7a89989a1cc3e13a6638e612874e68ce01e5144c 100644 --- a/scripts/build_sft.py +++ b/scripts/build_sft.py @@ -62,9 +62,13 @@ def _iter_session_files(api, repo_id: str, day: date, token: str) -> Iterable[st def _download_and_parse(repo_id: str, path: str, token: str) -> dict | None: from huggingface_hub import hf_hub_download + try: local = hf_hub_download( - repo_id=repo_id, filename=path, repo_type="dataset", token=token, + repo_id=repo_id, + filename=path, + repo_type="dataset", + token=token, ) except Exception as e: logger.warning("hf_hub_download(%s) failed: %s", path, e) @@ -118,7 +122,10 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None tmp_path = tmp.name try: api.create_repo( - repo_id=target_repo, repo_type="dataset", exist_ok=True, token=token, + repo_id=target_repo, + repo_type="dataset", + exist_ok=True, + token=token, ) api.upload_file( path_or_fileobj=tmp_path, @@ -136,7 +143,11 @@ def _upload_row(api, row: dict, day: date, target_repo: str, token: str) -> None def run_for_day( - api, source_repo: str, target_repo: str, day: date, token: str, + api, + source_repo: str, + target_repo: str, + day: date, + token: str, ) -> int: paths = _iter_session_files(api, source_repo, day, token) n = 0 @@ -162,11 +173,15 @@ def main(argv: list[str] | None = None) -> int: ap.add_argument("--source", default="smolagents/ml-intern-sessions") ap.add_argument("--target", default="smolagents/ml-intern-sft") ap.add_argument( - "--days", type=int, default=1, + "--days", + type=int, + default=1, help="Number of trailing days to export (default: 1 = yesterday).", ) ap.add_argument( - "--date", type=str, default=None, + "--date", + type=str, + default=None, help="Single YYYY-MM-DD to export; overrides --days.", ) args = ap.parse_args(argv) @@ -185,6 +200,7 @@ def main(argv: list[str] | None = None) -> int: return 1 from huggingface_hub import HfApi + api = HfApi() if args.date: diff --git a/scripts/sweep_orphan_sandboxes.py b/scripts/sweep_orphan_sandboxes.py index 6525b55b35a5f78fbe5e0d50a36b2fe5e5ed5251..cbe7b9ebcca78b6a497cbc2705a074c91e17443a 100644 --- a/scripts/sweep_orphan_sandboxes.py +++ b/scripts/sweep_orphan_sandboxes.py @@ -128,14 +128,19 @@ def main() -> int: api = HfApi(token=token) cutoff = datetime.now(timezone.utc) - timedelta(days=args.max_age_days) - log({"level": "info", "msg": "sweep_start", "cutoff": cutoff.isoformat(), - "max_deletes": args.max_deletes, "apply": args.apply}) + log( + { + "level": "info", + "msg": "sweep_start", + "cutoff": cutoff.isoformat(), + "max_deletes": args.max_deletes, + "apply": args.apply, + } + ) # ``list_spaces`` doesn't filter by name pattern — we scan and filter # client-side. ``search="sandbox"`` narrows the network payload. - candidates = api.list_spaces( - search="sandbox", full=True, limit=args.limit - ) + candidates = api.list_spaces(search="sandbox", full=True, limit=args.limit) scanned = 0 matched = 0 @@ -150,15 +155,23 @@ def main() -> int: continue matched += 1 - last_mod = getattr(space, "lastModified", None) or getattr(space, "last_modified", None) + last_mod = getattr(space, "lastModified", None) or getattr( + space, "last_modified", None + ) if isinstance(last_mod, str): last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00")) if last_mod and last_mod > cutoff: skipped_too_recent += 1 continue - log({"level": "info", "msg": "candidate", "space_id": space.id, - "last_modified": last_mod.isoformat() if last_mod else None}) + log( + { + "level": "info", + "msg": "candidate", + "space_id": space.id, + "last_modified": last_mod.isoformat() if last_mod else None, + } + ) if not args.apply: continue @@ -179,20 +192,40 @@ def main() -> int: time.sleep(0.2) except HfHubHTTPError as e: failed += 1 - log({"level": "error", "msg": "delete_failed", "space_id": space.id, - "status": e.response.status_code, "error": str(e)[:200]}) + log( + { + "level": "error", + "msg": "delete_failed", + "space_id": space.id, + "status": e.response.status_code, + "error": str(e)[:200], + } + ) except Exception as e: failed += 1 - log({"level": "error", "msg": "delete_failed", "space_id": space.id, - "error": str(e)[:200]}) - - log({"level": "info", "msg": "sweep_end", - "scanned": scanned, "matched": matched, - "skipped_too_recent": skipped_too_recent, - "skipped_capped": skipped_capped, - "deleted": deleted, "failed": failed, - "capped": skipped_capped > 0, - "apply": args.apply}) + log( + { + "level": "error", + "msg": "delete_failed", + "space_id": space.id, + "error": str(e)[:200], + } + ) + + log( + { + "level": "info", + "msg": "sweep_end", + "scanned": scanned, + "matched": matched, + "skipped_too_recent": skipped_too_recent, + "skipped_capped": skipped_capped, + "deleted": deleted, + "failed": failed, + "capped": skipped_capped > 0, + "apply": args.apply, + } + ) return 0 if failed == 0 else 2 diff --git a/tests/integration/test_live_sandbox_auth.py b/tests/integration/test_live_sandbox_auth.py index f070919d0bb09d5da1c6eb5c5c4e91d4cd846d3c..ac099ee0dd60ebdb060cf402475000ca39b6e940 100644 --- a/tests/integration/test_live_sandbox_auth.py +++ b/tests/integration/test_live_sandbox_auth.py @@ -55,7 +55,11 @@ def test_live_sandbox_authenticated_agent_communication(): ) try: denied = unauthenticated.post("exists", json={"path": "/tmp"}) - assert denied.status_code in {401, 403, 404} # HF private-Space edge may 404 to avoid leaking existence + assert denied.status_code in { + 401, + 403, + 404, + } # HF private-Space edge may 404 to avoid leaking existence finally: unauthenticated.close() diff --git a/tests/unit/test_agent_model_gating.py b/tests/unit/test_agent_model_gating.py index 9a93cf90be44ba99bf105d628009fe51b8d9b8a4..2d0ab080ed241db5984c34daa5642cfb7956ec98 100644 --- a/tests/unit/test_agent_model_gating.py +++ b/tests/unit/test_agent_model_gating.py @@ -115,7 +115,9 @@ async def test_explicit_gated_session_request_still_rejects_non_hf_user(monkeypa async def fake_require_hf_org_member(_request): return False - monkeypatch.setattr(agent, "require_huggingface_org_member", fake_require_hf_org_member) + monkeypatch.setattr( + agent, "require_huggingface_org_member", fake_require_hf_org_member + ) with pytest.raises(HTTPException) as exc_info: await agent._model_override_for_new_session(None, agent.DEFAULT_CLAUDE_MODEL_ID) @@ -301,7 +303,9 @@ async def test_teardown_session_access_check_skips_sandbox_preload(monkeypatch): "ensure_session_loaded", fake_ensure_session_loaded, ) - monkeypatch.setattr(agent.session_manager, "teardown_sandbox", fake_teardown_sandbox) + monkeypatch.setattr( + agent.session_manager, "teardown_sandbox", fake_teardown_sandbox + ) response = await agent.teardown_session_sandbox("s1", {"user_id": "u1"}) await asyncio.sleep(0) diff --git a/tests/unit/test_auto_approval_policy.py b/tests/unit/test_auto_approval_policy.py index 3d8b37fe9bef1d407fc6fe5660ab70a2d5f98100..4785e979f4b19d4493ae71e8efbf9f990b519b23 100644 --- a/tests/unit/test_auto_approval_policy.py +++ b/tests/unit/test_auto_approval_policy.py @@ -58,7 +58,9 @@ async def test_scheduled_hf_jobs_always_require_manual_approval(operation): assert decision.requires_approval is True assert decision.auto_approval_blocked is True assert "Scheduled HF jobs" in decision.block_reason - assert agent_loop._needs_approval("hf_jobs", {"operation": operation}, session.config) + assert agent_loop._needs_approval( + "hf_jobs", {"operation": operation}, session.config + ) @pytest.mark.asyncio @@ -147,7 +149,9 @@ async def test_batch_reservation_blocks_second_over_budget_job(monkeypatch): @pytest.mark.asyncio -async def test_manual_approval_does_not_record_spend_when_session_yolo_disabled(monkeypatch): +async def test_manual_approval_does_not_record_spend_when_session_yolo_disabled( + monkeypatch, +): called = False async def fake_estimate(*args, **kwargs): diff --git a/tests/unit/test_build_kpis.py b/tests/unit/test_build_kpis.py index 2c5b53b02af9dd36c2ecb7aaa819f57adab35a1c..e792e5b4a6b9a05b46d570a0c98cbf3fa63b2118 100644 --- a/tests/unit/test_build_kpis.py +++ b/tests/unit/test_build_kpis.py @@ -38,15 +38,25 @@ def _session(events, user_id="u1", start="2026-04-24T09:59:00"): def test_llm_call_accumulates_tokens_and_cost(): mod = _load() events = [ - _ev("llm_call", { - "prompt_tokens": 100, "completion_tokens": 50, - "cache_read_tokens": 40, "cache_creation_tokens": 10, - "cost_usd": 0.01, - }), - _ev("llm_call", { - "prompt_tokens": 200, "completion_tokens": 100, - "cache_read_tokens": 80, "cost_usd": 0.02, - }), + _ev( + "llm_call", + { + "prompt_tokens": 100, + "completion_tokens": 50, + "cache_read_tokens": 40, + "cache_creation_tokens": 10, + "cost_usd": 0.01, + }, + ), + _ev( + "llm_call", + { + "prompt_tokens": 200, + "completion_tokens": 100, + "cache_read_tokens": 80, + "cost_usd": 0.02, + }, + ), ] m = mod._session_metrics(_session(events)) assert m["llm_calls"] == 2 @@ -75,11 +85,14 @@ def test_hf_job_gpu_hours(): mod = _load() events = [ _ev("hf_job_submit", {"flavor": "a100-large", "job_id": "j1"}), - _ev("hf_job_complete", { - "flavor": "a100-large", - "final_status": "COMPLETED", - "wall_time_s": 3600, - }), + _ev( + "hf_job_complete", + { + "flavor": "a100-large", + "final_status": "COMPLETED", + "wall_time_s": 3600, + }, + ), ] m = mod._session_metrics(_session(events)) assert m["hf_jobs_submitted"] == 1 @@ -118,12 +131,22 @@ def test_pro_conversions_and_credits_topped_up_per_session(): def test_aggregate_sums_pro_conversions_and_credits_topped_up(): mod = _load() - s1 = mod._session_metrics(_session([ - _ev("pro_conversion", {}), - ], user_id="u1")) - s2 = mod._session_metrics(_session([ - _ev("credits_topped_up", {"namespace": "ns"}), - ], user_id="u2")) + s1 = mod._session_metrics( + _session( + [ + _ev("pro_conversion", {}), + ], + user_id="u1", + ) + ) + s2 = mod._session_metrics( + _session( + [ + _ev("credits_topped_up", {"namespace": "ns"}), + ], + user_id="u2", + ) + ) s3 = mod._session_metrics(_session([], user_id="u3")) row = mod._aggregate([s1, s2, s3]) assert row["pro_conversions"] == 1 @@ -144,14 +167,28 @@ def test_feedback_counts(): def test_aggregate_day_cache_hit_and_users(): mod = _load() - s1 = mod._session_metrics(_session( - [_ev("llm_call", {"prompt_tokens": 100, "cache_read_tokens": 400, "cost_usd": 0.5})], - user_id="u1", - )) - s2 = mod._session_metrics(_session( - [_ev("llm_call", {"prompt_tokens": 200, "cache_read_tokens": 100, "cost_usd": 1.0})], - user_id="u2", - )) + s1 = mod._session_metrics( + _session( + [ + _ev( + "llm_call", + {"prompt_tokens": 100, "cache_read_tokens": 400, "cost_usd": 0.5}, + ) + ], + user_id="u1", + ) + ) + s2 = mod._session_metrics( + _session( + [ + _ev( + "llm_call", + {"prompt_tokens": 200, "cache_read_tokens": 100, "cost_usd": 1.0}, + ) + ], + user_id="u2", + ) + ) row = mod._aggregate_day([s1, s2]) assert row["sessions"] == 2 assert row["users"] == 2 @@ -181,17 +218,32 @@ def test_per_tool_counts_in_session_metrics(): def test_aggregate_research_kpis_only_count_doer_sessions(): mod = _load() - s1 = mod._session_metrics(_session([ - _ev("tool_call", {"tool": "research"}), - _ev("tool_call", {"tool": "research"}), - _ev("tool_call", {"tool": "research"}), - ], user_id="u1")) - s2 = mod._session_metrics(_session([ - _ev("tool_call", {"tool": "research"}), - ], user_id="u2")) - s3 = mod._session_metrics(_session([ - _ev("tool_call", {"tool": "bash"}), - ], user_id="u3")) + s1 = mod._session_metrics( + _session( + [ + _ev("tool_call", {"tool": "research"}), + _ev("tool_call", {"tool": "research"}), + _ev("tool_call", {"tool": "research"}), + ], + user_id="u1", + ) + ) + s2 = mod._session_metrics( + _session( + [ + _ev("tool_call", {"tool": "research"}), + ], + user_id="u2", + ) + ) + s3 = mod._session_metrics( + _session( + [ + _ev("tool_call", {"tool": "bash"}), + ], + user_id="u3", + ) + ) row = mod._aggregate([s1, s2, s3]) assert row["sessions"] == 3 assert row["sessions_with_research"] == 2 @@ -202,26 +254,39 @@ def test_aggregate_research_kpis_only_count_doer_sessions(): def test_aggregate_tool_breadth_and_intensity(): import json as _json + mod = _load() - s1 = mod._session_metrics(_session([ - _ev("tool_call", {"tool": "bash"}), - _ev("tool_call", {"tool": "research"}), - ], user_id="u1")) + s1 = mod._session_metrics( + _session( + [ + _ev("tool_call", {"tool": "bash"}), + _ev("tool_call", {"tool": "research"}), + ], + user_id="u1", + ) + ) # Two user turns so calls/turn = 4/2 = 2 - s2 = _session([ - _ev("tool_call", {"tool": "bash"}), - _ev("tool_call", {"tool": "bash"}), - _ev("tool_call", {"tool": "edit"}), - _ev("tool_call", {"tool": "edit"}), - ], user_id="u2") + s2 = _session( + [ + _ev("tool_call", {"tool": "bash"}), + _ev("tool_call", {"tool": "bash"}), + _ev("tool_call", {"tool": "edit"}), + _ev("tool_call", {"tool": "edit"}), + ], + user_id="u2", + ) s2["messages"] = [{"role": "user"}, {"role": "user"}] s2_metrics = mod._session_metrics(s2) row = mod._aggregate([s1, s2_metrics]) assert _json.loads(row["tool_calls_by_name_json"]) == { - "bash": 3, "research": 1, "edit": 2, + "bash": 3, + "research": 1, + "edit": 2, } assert _json.loads(row["sessions_using_tool_json"]) == { - "bash": 2, "research": 1, "edit": 1, + "bash": 2, + "research": 1, + "edit": 1, } # u1: 2 distinct, u2: 2 distinct -> p50 = 2 assert row["distinct_tools_per_session_p50"] == 2.0 @@ -236,16 +301,24 @@ def test_breadth_intensity_percentiles_exclude_zero_tool_sessions(): mod = _load() # Two productive sessions and three idle ones (no tool calls). Without # the doer-only filter, median of [0,0,0,2,4] = 0, which is useless. - productive_a = mod._session_metrics(_session([ - _ev("tool_call", {"tool": "bash"}), - _ev("tool_call", {"tool": "research"}), - ], user_id="prod_a")) - productive_b = _session([ - _ev("tool_call", {"tool": "bash"}), - _ev("tool_call", {"tool": "edit"}), - _ev("tool_call", {"tool": "edit"}), - _ev("tool_call", {"tool": "edit"}), - ], user_id="prod_b") + productive_a = mod._session_metrics( + _session( + [ + _ev("tool_call", {"tool": "bash"}), + _ev("tool_call", {"tool": "research"}), + ], + user_id="prod_a", + ) + ) + productive_b = _session( + [ + _ev("tool_call", {"tool": "bash"}), + _ev("tool_call", {"tool": "edit"}), + _ev("tool_call", {"tool": "edit"}), + _ev("tool_call", {"tool": "edit"}), + ], + user_id="prod_b", + ) productive_b["messages"] = [{"role": "user"}, {"role": "user"}] productive_b_metrics = mod._session_metrics(productive_b) idle = [ @@ -265,15 +338,25 @@ def test_pro_clicks_and_blocked_jobs_in_aggregate(): even if the dashboard doesn't currently chart them — they're cheap to keep and downstream consumers may still depend on the schema.""" mod = _load() - s1 = mod._session_metrics(_session([ - _ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}), - _ev("pro_cta_click", {"source": "claude_cap_dialog"}), - _ev("jobs_access_blocked", {}), - ], user_id="u1")) - s2 = mod._session_metrics(_session([ - _ev("jobs_access_blocked", {}), - _ev("jobs_access_blocked", {}), - ], user_id="u2")) + s1 = mod._session_metrics( + _session( + [ + _ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}), + _ev("pro_cta_click", {"source": "claude_cap_dialog"}), + _ev("jobs_access_blocked", {}), + ], + user_id="u1", + ) + ) + s2 = mod._session_metrics( + _session( + [ + _ev("jobs_access_blocked", {}), + _ev("jobs_access_blocked", {}), + ], + user_id="u2", + ) + ) row = mod._aggregate([s1, s2]) assert row["pro_cta_clicks"] == 2 assert row["hf_jobs_blocked"] == 3 @@ -281,6 +364,7 @@ def test_pro_clicks_and_blocked_jobs_in_aggregate(): def test_aggregate_sessions_by_model_split(): import json as _json + mod = _load() s_anthropic = _session([], user_id="a") s_anthropic["model_name"] = "anthropic/claude-opus-4-6" @@ -288,11 +372,13 @@ def test_aggregate_sessions_by_model_split(): s_bedrock["model_name"] = "bedrock/us.anthropic.claude-opus-4-6-v1" s_bedrock2 = _session([], user_id="c") s_bedrock2["model_name"] = "bedrock/us.anthropic.claude-opus-4-6-v1" - row = mod._aggregate([ - mod._session_metrics(s_anthropic), - mod._session_metrics(s_bedrock), - mod._session_metrics(s_bedrock2), - ]) + row = mod._aggregate( + [ + mod._session_metrics(s_anthropic), + mod._session_metrics(s_bedrock), + mod._session_metrics(s_bedrock2), + ] + ) assert _json.loads(row["sessions_by_model_json"]) == { "anthropic/claude-opus-4-6": 1, "bedrock/us.anthropic.claude-opus-4-6-v1": 2, @@ -311,6 +397,7 @@ def test_failure_and_regenerate_rates(): def test_window_filter_keeps_only_events_in_range(): from datetime import datetime, timezone + mod = _load() events = [ _ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00"), @@ -335,6 +422,7 @@ def test_window_filter_keeps_only_events_in_range(): def test_window_filter_returns_none_when_nothing_in_range(): from datetime import datetime, timezone + mod = _load() events = [_ev("llm_call", {"prompt_tokens": 100}, ts="2026-04-24T09:45:00")] session = _session(events) diff --git a/tests/unit/test_build_sft.py b/tests/unit/test_build_sft.py index 538ede29db9af7cb35e26a877472ac807f75fbea..ab24ec5dc5589ac1f11413a43a9ac640427b0250 100644 --- a/tests/unit/test_build_sft.py +++ b/tests/unit/test_build_sft.py @@ -23,24 +23,51 @@ def _session_row(): "messages": [ {"role": "system", "content": "You are an agent"}, {"role": "user", "content": "fine-tune llama"}, - {"role": "assistant", "content": None, "tool_calls": [ - {"id": "c1", "type": "function", - "function": {"name": "hf_jobs", "arguments": '{"script":"from trl import SFTTrainer"}'}}, - ]}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": { + "name": "hf_jobs", + "arguments": '{"script":"from trl import SFTTrainer"}', + }, + }, + ], + }, {"role": "tool", "tool_call_id": "c1", "content": "ok"}, {"role": "assistant", "content": "done"}, ], "events": [ - {"timestamp": "2026-04-24T10:00:05", "event_type": "tool_call", - "data": {"tool": "hf_jobs", - "arguments": {"script": "from trl import SFTTrainer"}}}, - {"timestamp": "2026-04-24T10:00:06", "event_type": "hf_job_submit", - "data": {"flavor": "a100-large", "push_to_hub": True}}, - {"timestamp": "2026-04-24T10:45:00", "event_type": "hf_job_complete", - "data": {"flavor": "a100-large", "final_status": "COMPLETED", - "wall_time_s": 2700}}, - {"timestamp": "2026-04-24T10:45:05", "event_type": "turn_complete", - "data": {}}, + { + "timestamp": "2026-04-24T10:00:05", + "event_type": "tool_call", + "data": { + "tool": "hf_jobs", + "arguments": {"script": "from trl import SFTTrainer"}, + }, + }, + { + "timestamp": "2026-04-24T10:00:06", + "event_type": "hf_job_submit", + "data": {"flavor": "a100-large", "push_to_hub": True}, + }, + { + "timestamp": "2026-04-24T10:45:00", + "event_type": "hf_job_complete", + "data": { + "flavor": "a100-large", + "final_status": "COMPLETED", + "wall_time_s": 2700, + }, + }, + { + "timestamp": "2026-04-24T10:45:05", + "event_type": "turn_complete", + "data": {}, + }, ], "tools": [{"type": "function", "function": {"name": "hf_jobs"}}], } diff --git a/tests/unit/test_cli_rendering.py b/tests/unit/test_cli_rendering.py index ff633c0673c4bbdafda9a1de46e69bbb007bdfad..e94700bfe96c112d6617239cb88be0ad3544ccb5 100644 --- a/tests/unit/test_cli_rendering.py +++ b/tests/unit/test_cli_rendering.py @@ -12,7 +12,10 @@ from agent.utils import terminal_display def test_direct_anthropic_research_model_stays_off_bedrock(): - assert _get_research_model("anthropic/claude-opus-4-6") == "anthropic/claude-sonnet-4-6" + assert ( + _get_research_model("anthropic/claude-opus-4-6") + == "anthropic/claude-sonnet-4-6" + ) def test_bedrock_anthropic_research_model_stays_on_bedrock(): @@ -42,7 +45,7 @@ def test_subagent_display_does_not_spawn_background_redraw(monkeypatch): mgr = terminal_display.SubAgentDisplayManager() mgr.start("agent-1", "research") - mgr.add_call("agent-1", "▸ hf_papers {\"operation\": \"search\"}") + mgr.add_call("agent-1", '▸ hf_papers {"operation": "search"}') mgr.clear("agent-1") assert calls == [] diff --git a/tests/unit/test_compaction_loop_break.py b/tests/unit/test_compaction_loop_break.py index 8e65df32d2109f87514053f3ec061309cda0b6fe..2ce5ead521c2539b5595bee7bf5d61767919f192 100644 --- a/tests/unit/test_compaction_loop_break.py +++ b/tests/unit/test_compaction_loop_break.py @@ -164,7 +164,10 @@ async def test_compact_raises_when_post_compact_still_over_threshold(): self.running_context_usage = 95_000 with ( - patch("agent.context_manager.manager.summarize_messages", side_effect=fake_summarize), + patch( + "agent.context_manager.manager.summarize_messages", + side_effect=fake_summarize, + ), patch.object(ContextManager, "_recompute_usage", fake_recompute), # Avoid token_counter calls in _truncate_oversized patch("litellm.token_counter", return_value=100), @@ -195,8 +198,8 @@ async def test_compact_does_not_duplicate_system_when_idx_is_zero(): Message(role="system", content="system"), Message(role="user", content="task"), Message(role="assistant", content="ok"), # would be the only - # message_to_summarize but the - # idx bug pulls it into recent + # message_to_summarize but the + # idx bug pulls it into recent Message(role="user", content="followup"), Message(role="assistant", content="reply"), ] # exactly 5 = untouched_messages, so idx initialises to 0 @@ -209,7 +212,10 @@ async def test_compact_does_not_duplicate_system_when_idx_is_zero(): self.running_context_usage = 5_000 with ( - patch("agent.context_manager.manager.summarize_messages", side_effect=fake_summarize), + patch( + "agent.context_manager.manager.summarize_messages", + side_effect=fake_summarize, + ), patch.object(ContextManager, "_recompute_usage", fake_recompute), patch("litellm.token_counter", return_value=100), ): @@ -232,8 +238,7 @@ async def test_compact_does_not_duplicate_system_when_idx_is_zero(): # so first_user_msg ends up in BOTH head and recent_messages → # duplicate user message → Anthropic 400 (two consecutive user roles). task_count = sum( - 1 for m in cm.items - if m.role == "user" and (m.content or "") == "task" + 1 for m in cm.items if m.role == "user" and (m.content or "") == "task" ) assert task_count == 1, ( f"Expected exactly 1 'task' user message, found {task_count}. " @@ -243,9 +248,9 @@ async def test_compact_does_not_duplicate_system_when_idx_is_zero(): # API contract). System counts separately. non_system = [m for m in cm.items if m.role != "system"] for i in range(1, len(non_system)): - assert non_system[i].role != non_system[i-1].role, ( + assert non_system[i].role != non_system[i - 1].role, ( f"Two consecutive {non_system[i].role} messages at non-system " - f"position {i-1},{i} — Anthropic API rejects this. " + f"position {i - 1},{i} — Anthropic API rejects this. " f"Roles: {[m.role for m in cm.items]}" ) @@ -272,7 +277,10 @@ async def test_compact_succeeds_when_post_compact_under_threshold(): self.running_context_usage = 5_000 # well under threshold with ( - patch("agent.context_manager.manager.summarize_messages", side_effect=fake_summarize), + patch( + "agent.context_manager.manager.summarize_messages", + side_effect=fake_summarize, + ), patch.object(ContextManager, "_recompute_usage", fake_recompute), patch("litellm.token_counter", return_value=100), ): diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 71f92b2a44e9b4fcb2dbe73179baefd05e25b683..c99f05ee4b2288bb984891e01b8609b56bbfddf8 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -7,7 +7,9 @@ def _write_json(path, data): path.write_text(json.dumps(data), encoding="utf-8") -def test_load_config_does_not_apply_slack_user_defaults_by_default(tmp_path, monkeypatch): +def test_load_config_does_not_apply_slack_user_defaults_by_default( + tmp_path, monkeypatch +): config_path = tmp_path / "config.json" _write_json( config_path, diff --git a/tests/unit/test_dangling_tool_calls.py b/tests/unit/test_dangling_tool_calls.py index e217ae97dfdc02f72f61e0ede6ed194bfeeefa45..b4215f173b30deb947095ea9d8ac6a0eca58fe35 100644 --- a/tests/unit/test_dangling_tool_calls.py +++ b/tests/unit/test_dangling_tool_calls.py @@ -34,33 +34,42 @@ def _make_cm() -> ContextManager: def test_orphan_tool_use_followed_by_user_message_is_patched(): cm = _make_cm() - cm.items.extend([ - Message(role="user", content="Research X"), - Message( - role="assistant", - content=None, - tool_calls=[_tool_call("call_abc", "research")], - ), - Message(role="user", content="??"), - ]) + cm.items.extend( + [ + Message(role="user", content="Research X"), + Message( + role="assistant", + content=None, + tool_calls=[_tool_call("call_abc", "research")], + ), + Message(role="user", content="??"), + ] + ) msgs = cm.get_messages() tool_msgs = [m for m in msgs if getattr(m, "role", None) == "tool"] assert len(tool_msgs) == 1 assert tool_msgs[0].tool_call_id == "call_abc" - assert "interrupted" in (tool_msgs[0].content or "").lower() or "not executed" in (tool_msgs[0].content or "").lower() + assert ( + "interrupted" in (tool_msgs[0].content or "").lower() + or "not executed" in (tool_msgs[0].content or "").lower() + ) def test_no_orphan_means_no_stub(): cm = _make_cm() - cm.items.extend([ - Message(role="user", content="Research X"), - Message( - role="assistant", - content=None, - tool_calls=[_tool_call("call_abc", "research")], - ), - Message(role="tool", content="ok", tool_call_id="call_abc", name="research"), - ]) + cm.items.extend( + [ + Message(role="user", content="Research X"), + Message( + role="assistant", + content=None, + tool_calls=[_tool_call("call_abc", "research")], + ), + Message( + role="tool", content="ok", tool_call_id="call_abc", name="research" + ), + ] + ) cm.get_messages() tool_msgs = [m for m in cm.items if getattr(m, "role", None) == "tool"] assert len(tool_msgs) == 1 @@ -69,18 +78,20 @@ def test_no_orphan_means_no_stub(): def test_multiple_dangling_tool_calls_in_one_assistant_message_are_all_patched(): cm = _make_cm() - cm.items.extend([ - Message(role="user", content="do two things"), - Message( - role="assistant", - content=None, - tool_calls=[ - _tool_call("call_1", "research"), - _tool_call("call_2", "bash"), - ], - ), - Message(role="user", content="follow up"), - ]) + cm.items.extend( + [ + Message(role="user", content="do two things"), + Message( + role="assistant", + content=None, + tool_calls=[ + _tool_call("call_1", "research"), + _tool_call("call_2", "bash"), + ], + ), + Message(role="user", content="follow up"), + ] + ) cm.get_messages() tool_ids = { getattr(m, "tool_call_id", None) @@ -97,21 +108,23 @@ def test_orphan_in_earlier_turn_still_gets_patched(): backwards, so this case never got fixed and Bedrock rejected. """ cm = _make_cm() - cm.items.extend([ - Message(role="user", content="turn 1"), - Message( - role="assistant", - content=None, - tool_calls=[_tool_call("call_old", "research")], - ), - Message(role="user", content="turn 2 — please retry"), - Message( - role="assistant", - content=None, - tool_calls=[_tool_call("call_new", "bash")], - ), - Message(role="tool", content="ok", tool_call_id="call_new", name="bash"), - ]) + cm.items.extend( + [ + Message(role="user", content="turn 1"), + Message( + role="assistant", + content=None, + tool_calls=[_tool_call("call_old", "research")], + ), + Message(role="user", content="turn 2 — please retry"), + Message( + role="assistant", + content=None, + tool_calls=[_tool_call("call_new", "bash")], + ), + Message(role="tool", content="ok", tool_call_id="call_new", name="bash"), + ] + ) cm.get_messages() tool_ids = { getattr(m, "tool_call_id", None) diff --git a/tests/unit/test_doom_loop.py b/tests/unit/test_doom_loop.py index bbdac454d9c9b0b15e27747fe4c09e75fdf8b670..3a31a5a4874e17365609f4fd69d046b83373a688 100644 --- a/tests/unit/test_doom_loop.py +++ b/tests/unit/test_doom_loop.py @@ -207,7 +207,7 @@ def test_check_for_doom_loop_returns_corrective_prompt_for_identical_run(): msgs = [_assistant_call("read", '{"p": 1}')] * 3 out = check_for_doom_loop(msgs) assert out is not None - assert "DOOM LOOP DETECTED" in out + assert "REPETITION GUARD" in out assert "'read'" in out @@ -218,7 +218,7 @@ def test_check_for_doom_loop_returns_corrective_prompt_for_cycle(): msgs.append(_assistant_call("b", "{}")) out = check_for_doom_loop(msgs) assert out is not None - assert "DOOM LOOP DETECTED" in out + assert "REPETITION GUARD" in out assert "a → b" in out diff --git a/tests/unit/test_heartbeat.py b/tests/unit/test_heartbeat.py index 29d8079fd63d402743aec31ba44787007ef6f9a8..56161be801d97fcdacff96cac297e696251a9406 100644 --- a/tests/unit/test_heartbeat.py +++ b/tests/unit/test_heartbeat.py @@ -10,8 +10,6 @@ import json from pathlib import Path from unittest.mock import patch -import pytest - from agent.core.session import Event, Session @@ -29,11 +27,11 @@ class _FakeConfig: mcpServers: dict = {} -def _mk_session(tmp_path: Path) -> Session: - import os - os.chdir(tmp_path) # so session_logs/ lands under tmp_path +def _mk_session(tmp_path: Path, monkeypatch) -> Session: + monkeypatch.chdir(tmp_path) # so session_logs/ lands under tmp_path # Stub out the context manager to avoid litellm lookups. from agent.context_manager.manager import ContextManager + cm = ContextManager.__new__(ContextManager) cm.items = [] cm.tool_specs = [] @@ -58,7 +56,7 @@ def test_heartbeat_fires_after_interval(tmp_path, monkeypatch): # Use asyncio.run rather than pytest-asyncio so the test works without the # plugin installed (same pattern elsewhere in this repo). async def body(): - s = _mk_session(tmp_path) + s = _mk_session(tmp_path, monkeypatch) calls = [] def fake_upload(repo_id): @@ -94,10 +92,10 @@ def test_heartbeat_fires_after_interval(tmp_path, monkeypatch): asyncio.run(body()) -def test_stable_local_path_overwrites(tmp_path): - import os - os.chdir(tmp_path) +def test_stable_local_path_overwrites(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) from agent.context_manager.manager import ContextManager + cm = ContextManager.__new__(ContextManager) cm.items = [] cm.tool_specs = [] diff --git a/tests/unit/test_hf_access.py b/tests/unit/test_hf_access.py index 9e222ea7f7eef554f58d28f26950b09f733e2cc3..e59524a0e40024fbe8349faac448740445266928 100644 --- a/tests/unit/test_hf_access.py +++ b/tests/unit/test_hf_access.py @@ -2,10 +2,12 @@ from agent.core.hf_access import is_billing_error, jobs_access_from_whoami def test_personal_user_lists_username_namespace(): - access = jobs_access_from_whoami({ - "name": "alice", - "orgs": [], - }) + access = jobs_access_from_whoami( + { + "name": "alice", + "orgs": [], + } + ) assert access.username == "alice" assert access.org_names == [] assert access.eligible_namespaces == ["alice"] @@ -17,13 +19,15 @@ def test_user_with_orgs_lists_all_namespaces_regardless_of_plan(): # org the user belongs to is eligible. We sort orgs alphabetically and # always put the personal namespace first so the picker default is the # user's own account. - access = jobs_access_from_whoami({ - "name": "alice", - "orgs": [ - {"name": "team-a", "plan": "team"}, - {"name": "oss-friends", "plan": "free"}, - ], - }) + access = jobs_access_from_whoami( + { + "name": "alice", + "orgs": [ + {"name": "team-a", "plan": "team"}, + {"name": "oss-friends", "plan": "free"}, + ], + } + ) assert access.username == "alice" assert access.org_names == ["oss-friends", "team-a"] assert access.eligible_namespaces == ["alice", "oss-friends", "team-a"] @@ -34,19 +38,23 @@ def test_free_user_without_org_still_eligible_under_personal_namespace(): # Pro is no longer required — the user is offered their personal # namespace; whether they actually have credits is decided at job # creation time when HF returns a 402 / billing error. - access = jobs_access_from_whoami({ - "name": "alice", - "orgs": [], - }) + access = jobs_access_from_whoami( + { + "name": "alice", + "orgs": [], + } + ) assert access.eligible_namespaces == ["alice"] assert access.default_namespace == "alice" def test_org_only_token_falls_back_to_first_org(): - access = jobs_access_from_whoami({ - "name": None, - "orgs": [{"name": "team-a"}, {"name": "team-b"}], - }) + access = jobs_access_from_whoami( + { + "name": None, + "orgs": [{"name": "team-a"}, {"name": "team-b"}], + } + ) assert access.username is None assert access.eligible_namespaces == ["team-a", "team-b"] assert access.default_namespace == "team-a" diff --git a/tests/unit/test_hub_artifacts.py b/tests/unit/test_hub_artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..41854454131c3bbbc1ae373df7c23e708918c899 --- /dev/null +++ b/tests/unit/test_hub_artifacts.py @@ -0,0 +1,505 @@ +import asyncio +import logging +from types import SimpleNamespace + +import pytest + +from agent.core import hub_artifacts +from agent.core.hub_artifacts import ( + ML_INTERN_TAG, + PROVENANCE_MARKER, + artifact_collection_title, + augment_repo_card_content, + build_hub_artifact_sitecustomize, + ensure_session_artifact_collection, + is_known_hub_artifact, + register_hub_artifact, + remember_hub_artifact, + start_session_artifact_collection_task, + wrap_shell_command_with_hub_artifact_bootstrap, +) +from agent.tools import local_tools, sandbox_tool +from agent.tools.hf_repo_files_tool import HfRepoFilesTool +from agent.tools.hf_repo_git_tool import HfRepoGitTool +from agent.tools.jobs_tool import _wrap_command_with_artifact_bootstrap + + +def _session() -> SimpleNamespace: + return SimpleNamespace( + session_id="session-123", + session_start_time="2026-05-05T10:20:30", + ) + + +def test_artifact_collection_title_uses_session_date_and_id(): + assert ( + artifact_collection_title(_session()) + == "ml-intern-artifacts-2026-05-05-session-123" + ) + + +def test_artifact_collection_title_uses_short_uuid_fragment(): + session = SimpleNamespace( + session_id="fadcbc77-3439-4c2b-bc52-50d7f6353af3", + session_start_time="2026-05-05T10:20:30", + ) + + title = artifact_collection_title(session) + + assert title == "ml-intern-artifacts-2026-05-05-fadcbc77" + assert len(title) < 60 + + +def test_artifact_collection_title_still_truncates_long_non_uuid_ids(): + session = SimpleNamespace( + session_id="custom-session-id-that-is-longer-than-the-hub-title-limit", + session_start_time="2026-05-05T10:20:30", + ) + + title = artifact_collection_title(session) + + assert title.startswith("ml-intern-artifacts-2026-05-05-custom-session-id") + assert len(title) < 60 + + +def test_model_card_merges_tags_and_appends_provenance_and_usage(): + content = """--- +license: apache-2.0 +tags: +- text-generation +--- +# Existing Model + +Existing details stay here. +""" + + updated = augment_repo_card_content(content, "alice/model", "model") + second_pass = augment_repo_card_content(updated, "alice/model", "model") + + assert "license: apache-2.0" in updated + assert "- text-generation" in updated + assert f"- {ML_INTERN_TAG}" in updated + assert "# Existing Model" in updated + assert "Existing details stay here." in updated + assert PROVENANCE_MARKER in updated + assert "AutoModelForCausalLM" in updated + assert second_pass.count(PROVENANCE_MARKER) == 1 + assert second_pass.count("AutoModelForCausalLM") == updated.count( + "AutoModelForCausalLM" + ) + + +def test_dataset_card_adds_load_dataset_usage(): + updated = augment_repo_card_content("", "alice/dataset", "dataset") + + assert f"- {ML_INTERN_TAG}" in updated + assert "# alice/dataset" in updated + assert "from datasets import load_dataset" in updated + assert 'load_dataset("alice/dataset")' in updated + + +def test_existing_usage_section_is_preserved_without_duplicate_usage(): + content = """# Existing Dataset + +## Usage + +Use the custom loader in this repository. +""" + + updated = augment_repo_card_content(content, "alice/dataset", "dataset") + + assert "Use the custom loader in this repository." in updated + assert "from datasets import load_dataset" not in updated + assert PROVENANCE_MARKER in updated + + +def test_space_card_gets_metadata_without_provenance_body(): + updated = augment_repo_card_content("# Existing Space\n", "alice/space", "space") + + assert f"- {ML_INTERN_TAG}" in updated + assert "# Existing Space" in updated + assert PROVENANCE_MARKER not in updated + + +def test_register_hub_artifact_creates_private_collection_and_adds_item_once( + monkeypatch, +): + session = _session() + + class FakeApi: + token = "hf-token" + + def __init__(self): + self.created_collections = [] + self.collection_items = [] + self.uploads = [] + + def create_collection(self, **kwargs): + self.created_collections.append(kwargs) + return SimpleNamespace(slug="alice/ml-intern-artifacts") + + def add_collection_item(self, **kwargs): + self.collection_items.append(kwargs) + + def upload_file(self, **kwargs): + self.uploads.append(kwargs) + + api = FakeApi() + monkeypatch.setattr(hub_artifacts, "_read_remote_readme", lambda *_, **__: "") + + assert register_hub_artifact(api, "alice/model", "model", session=session) + assert register_hub_artifact(api, "alice/model", "model", session=session) + + assert is_known_hub_artifact(session, "alice/model", "model") + assert len(api.created_collections) == 1 + assert api.created_collections[0]["title"] == artifact_collection_title(session) + assert api.created_collections[0]["private"] is True + assert len(api.collection_items) == 1 + assert api.collection_items[0]["item_id"] == "alice/model" + assert api.collection_items[0]["item_type"] == "model" + assert api.collection_items[0]["exists_ok"] is True + assert len(api.uploads) == 1 + assert b"ml-intern" in api.uploads[0]["path_or_fileobj"] + + +def test_register_hub_artifact_retries_after_partial_failure(monkeypatch): + session = _session() + api = SimpleNamespace(token="hf-token") + card_attempts = 0 + collection_attempts = 0 + + def flaky_update_repo_card(*args, **kwargs): + nonlocal card_attempts + card_attempts += 1 + if card_attempts == 1: + raise RuntimeError("temporary card failure") + + def add_to_collection(*args, **kwargs): + nonlocal collection_attempts + collection_attempts += 1 + + monkeypatch.setattr( + hub_artifacts, + "_update_repo_card", + flaky_update_repo_card, + ) + monkeypatch.setattr(hub_artifacts, "_add_to_collection", add_to_collection) + + assert not register_hub_artifact(api, "alice/model", "model", session=session) + assert register_hub_artifact(api, "alice/model", "model", session=session) + assert register_hub_artifact(api, "alice/model", "model", session=session) + + assert card_attempts == 2 + assert collection_attempts == 2 + + +def test_register_hub_artifact_retries_after_collection_failure(monkeypatch): + session = _session() + api = SimpleNamespace(token="hf-token") + card_attempts = 0 + collection_attempts = 0 + + def update_repo_card(*args, **kwargs): + nonlocal card_attempts + card_attempts += 1 + + def flaky_add_to_collection(*args, **kwargs): + nonlocal collection_attempts + collection_attempts += 1 + if collection_attempts == 1: + raise RuntimeError("temporary collection failure") + + monkeypatch.setattr(hub_artifacts, "_update_repo_card", update_repo_card) + monkeypatch.setattr( + hub_artifacts, + "_add_to_collection", + flaky_add_to_collection, + ) + + assert not register_hub_artifact(api, "alice/model", "model", session=session) + assert register_hub_artifact(api, "alice/model", "model", session=session) + assert register_hub_artifact(api, "alice/model", "model", session=session) + + assert card_attempts == 2 + assert collection_attempts == 2 + + +def test_session_artifact_set_falls_back_when_session_rejects_attrs(caplog): + class SlottedSession: + __slots__ = ("session_id", "session_start_time") + + def __init__(self): + self.session_id = "session-123" + self.session_start_time = "2026-05-05T10:20:30" + + session = SlottedSession() + + with caplog.at_level(logging.WARNING): + remember_hub_artifact(session, "alice/model", "model") + + assert is_known_hub_artifact(session, "alice/model", "model") + assert "using process-local fallback state" in caplog.text + + +@pytest.mark.asyncio +async def test_ensure_session_artifact_collection_uses_user_token(monkeypatch): + session = _session() + calls = [] + + class FakeApi: + def __init__(self, token): + self.token = token + + def fake_ensure_collection_slug(api, seen_session, **kwargs): + calls.append((api.token, seen_session, kwargs)) + return "alice/ml-intern-artifacts" + + monkeypatch.setattr(hub_artifacts, "HfApi", FakeApi) + monkeypatch.setattr( + hub_artifacts, + "_ensure_collection_slug", + fake_ensure_collection_slug, + ) + + slug = await ensure_session_artifact_collection(session, token="hf-token") + + assert slug == "alice/ml-intern-artifacts" + assert calls == [ + ("hf-token", session, {"token": "hf-token"}), + ] + + +@pytest.mark.asyncio +async def test_start_session_artifact_collection_task_dedupes(monkeypatch): + session = _session() + calls = [] + + async def fake_ensure_session_artifact_collection(seen_session, **kwargs): + calls.append((seen_session, kwargs)) + await asyncio.sleep(0) + return "alice/ml-intern-artifacts" + + monkeypatch.setattr( + hub_artifacts, + "ensure_session_artifact_collection", + fake_ensure_session_artifact_collection, + ) + + task = start_session_artifact_collection_task(session, token="hf-token") + second = start_session_artifact_collection_task(session, token="hf-token") + + assert task is not None + assert second is task + await task + assert calls == [(session, {"token": "hf-token"})] + + +def test_start_session_artifact_collection_task_skips_without_token(): + assert start_session_artifact_collection_task(_session()) is None + + +@pytest.mark.asyncio +async def test_hf_repo_git_create_repo_registers_artifact(monkeypatch): + session = _session() + calls = [] + + class FakeApi: + token = "hf-token" + + def create_repo(self, **kwargs): + self.create_kwargs = kwargs + return "https://huggingface.co/spaces/alice/demo" + + def fake_register(api, repo_id, repo_type, **kwargs): + calls.append((api, repo_id, repo_type, kwargs)) + return True + + monkeypatch.setattr( + "agent.tools.hf_repo_git_tool.register_hub_artifact", + fake_register, + ) + tool = HfRepoGitTool(hf_token="hf-token", session=session) + tool.api = FakeApi() + + result = await tool._create_repo( + { + "repo_id": "alice/demo", + "repo_type": "space", + "space_sdk": "gradio", + "private": True, + } + ) + + assert result["totalResults"] == 1 + assert calls == [ + ( + tool.api, + "alice/demo", + "space", + {"session": session, "extra_metadata": {"sdk": "gradio"}}, + ) + ] + + +@pytest.mark.asyncio +async def test_hf_repo_files_upload_registers_known_artifact_with_force(monkeypatch): + session = _session() + calls = [] + uploads = [] + + class FakeApi: + token = "hf-token" + + def upload_file(self, **kwargs): + uploads.append(kwargs) + return SimpleNamespace() + + def fake_register(api, repo_id, repo_type, **kwargs): + calls.append((api, repo_id, repo_type, kwargs)) + return True + + monkeypatch.setattr( + "agent.tools.hf_repo_files_tool.register_hub_artifact", + fake_register, + ) + remember_hub_artifact(session, "alice/model", "model") + + tool = HfRepoFilesTool(hf_token="hf-token", session=session) + tool.api = FakeApi() + + result = await tool._upload( + { + "repo_id": "alice/model", + "repo_type": "model", + "path": "weights.bin", + "content": b"weights", + } + ) + readme_result = await tool._upload( + { + "repo_id": "alice/model", + "repo_type": "model", + "path": "README.md", + "content": "# Model", + } + ) + + assert result["totalResults"] == 1 + assert readme_result["totalResults"] == 1 + assert [upload["path_in_repo"] for upload in uploads] == [ + "weights.bin", + "README.md", + ] + assert calls == [ + ( + tool.api, + "alice/model", + "model", + {"session": session, "force": False}, + ), + ( + tool.api, + "alice/model", + "model", + {"session": session, "force": True}, + ), + ] + + +def test_hf_jobs_artifact_bootstrap_wraps_command_without_changing_exec_target(): + command = ["uv", "run", "train.py"] + wrapped = _wrap_command_with_artifact_bootstrap(command, _session()) + + assert wrapped[0:2] == ["/bin/sh", "-lc"] + assert "sitecustomize.py" in wrapped[2] + assert "PYTHONPATH" in wrapped[2] + assert "exec uv run train.py" in wrapped[2] + assert _wrap_command_with_artifact_bootstrap(command, None) == command + + +def test_shell_bootstrap_wraps_capybara_push_to_hub_pattern(): + command = ( + "pip install -q datasets huggingface_hub && python -c " + "\"subset.push_to_hub('lewtun/Capybara-100', private=False)\"" + ) + + wrapped = wrap_shell_command_with_hub_artifact_bootstrap(command, _session()) + + assert "sitecustomize.py" in wrapped + assert "PYTHONPATH" in wrapped + assert command in wrapped + assert wrap_shell_command_with_hub_artifact_bootstrap(command, None) == command + assert ( + wrap_shell_command_with_hub_artifact_bootstrap( + command, + SimpleNamespace(session_start_time="2026-05-05T10:20:30"), + ) + == command + ) + + +@pytest.mark.asyncio +async def test_sandbox_bash_wraps_command_for_session_artifact_hooks(): + calls = [] + + class FakeSandbox: + def call_tool(self, name, args): + calls.append((name, args)) + return SimpleNamespace(success=True, output="ok", error="") + + session = _session() + session.sandbox = FakeSandbox() + + handler = sandbox_tool._make_tool_handler("bash") + output, ok = await handler({"command": "python make_dataset.py"}, session=session) + + assert ok is True + assert output == "ok" + assert calls[0][0] == "bash" + assert "sitecustomize.py" in calls[0][1]["command"] + assert "python make_dataset.py" in calls[0][1]["command"] + + +@pytest.mark.asyncio +async def test_local_bash_wraps_command_for_session_artifact_hooks(monkeypatch): + seen = {} + + def fake_run(command, **kwargs): + seen["command"] = command + seen["kwargs"] = kwargs + return SimpleNamespace(stdout="ok", stderr="", returncode=0) + + monkeypatch.setattr(local_tools.subprocess, "run", fake_run) + + output, ok = await local_tools._bash_handler( + {"command": "python make_dataset.py"}, + session=_session(), + ) + + assert ok is True + assert output == "ok" + assert "sitecustomize.py" in seen["command"] + assert "python make_dataset.py" in seen["command"] + + +def test_sitecustomize_bootstrap_is_valid_python(): + code = build_hub_artifact_sitecustomize(_session()) + + compile(code, "sitecustomize.py", "exec") + assert "ml-intern-artifacts-2026-05-05-session-123" in code + + +def test_sitecustomize_bootstrap_reuses_existing_collection_slug(): + session = _session() + setattr( + session, + hub_artifacts._COLLECTION_SLUG_ATTR, + "alice/ml-intern-artifacts-2026-05-05-session-123", + ) + + code = build_hub_artifact_sitecustomize(session) + + compile(code, "sitecustomize.py", "exec") + assert ( + "collection_slug = 'alice/ml-intern-artifacts-2026-05-05-session-123'" in code + ) diff --git a/tests/unit/test_kpis_scheduler.py b/tests/unit/test_kpis_scheduler.py index 8c52f0513d20d8234f440efde37670cb1567c3d5..cba24d7f0990ff647c8433955ff729fd24c413d6 100644 --- a/tests/unit/test_kpis_scheduler.py +++ b/tests/unit/test_kpis_scheduler.py @@ -28,7 +28,12 @@ def _load(): def test_token_resolution_order(monkeypatch): mod = _load() - for var in ("HF_KPI_WRITE_TOKEN", "HF_SESSION_UPLOAD_TOKEN", "HF_TOKEN", "HF_ADMIN_TOKEN"): + for var in ( + "HF_KPI_WRITE_TOKEN", + "HF_SESSION_UPLOAD_TOKEN", + "HF_TOKEN", + "HF_ADMIN_TOKEN", + ): monkeypatch.delenv(var, raising=False) assert mod._resolve_token() is None @@ -86,7 +91,11 @@ def test_start_skips_cleanly_without_apscheduler(monkeypatch): monkeypatch.delenv("ML_INTERN_KPIS_DISABLED", raising=False) # Force the apscheduler import to fail — start() should log and return. - real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + real_import = ( + __builtins__["__import__"] + if isinstance(__builtins__, dict) + else __builtins__.__import__ + ) def fake_import(name, *args, **kwargs): if name.startswith("apscheduler"): diff --git a/tests/unit/test_llm_error_classification.py b/tests/unit/test_llm_error_classification.py index 0b019574024a05087a8eb98a091869d03ea464ce..8bcd54fd20119e614a51cffe62ea3896186cb19d 100644 --- a/tests/unit/test_llm_error_classification.py +++ b/tests/unit/test_llm_error_classification.py @@ -75,7 +75,9 @@ def test_timeout_is_transient_but_not_rate_limit(): def test_rate_limit_uses_longer_schedule(): err = Exception("Too many tokens, please wait before trying again.") - delays = [_retry_delay_for(err, i) for i in range(len(_LLM_RATE_LIMIT_RETRY_DELAYS))] + delays = [ + _retry_delay_for(err, i) for i in range(len(_LLM_RATE_LIMIT_RETRY_DELAYS)) + ] assert delays == _LLM_RATE_LIMIT_RETRY_DELAYS # Just past the schedule → None (stop retrying). assert _retry_delay_for(err, len(_LLM_RATE_LIMIT_RETRY_DELAYS)) is None diff --git a/tests/unit/test_messaging.py b/tests/unit/test_messaging.py index 968622c1aa038f5615c31f610c03f912a186ed3e..f3228e68fdb34b24a41a69c2e93af24ac16c052d 100644 --- a/tests/unit/test_messaging.py +++ b/tests/unit/test_messaging.py @@ -64,9 +64,7 @@ def _config_with_messaging(**destination_overrides) -> Config: ) -def _test_session( - config: Config, gateway, session_id: str = "session-test" -) -> Session: +def _test_session(config: Config, gateway, session_id: str = "session-test") -> Session: return Session( asyncio.Queue(), config=config, @@ -485,7 +483,9 @@ async def test_turn_complete_can_be_disabled_by_custom_auto_event_config(): def test_session_manager_updates_notification_destinations_in_session_info(): config = _config_with_messaging(allow_auto_events=True) - manager = SessionManager(str(Path(__file__).resolve().parents[2] / "configs" / "cli_agent_config.json")) + manager = SessionManager( + str(Path(__file__).resolve().parents[2] / "configs" / "cli_agent_config.json") + ) manager.config = config manager.sessions = {} diff --git a/tests/unit/test_sandbox_auto_start.py b/tests/unit/test_sandbox_auto_start.py index b99e28cab228af9f1ade18a4dd1bc42cb2c1bf05..1ad27fca37c12fff40e4fe9b1601031fbdc3dede 100644 --- a/tests/unit/test_sandbox_auto_start.py +++ b/tests/unit/test_sandbox_auto_start.py @@ -15,7 +15,9 @@ def test_default_cpu_sandbox_create_does_not_require_approval(): def test_non_default_sandbox_create_still_requires_approval(): config = SimpleNamespace(yolo_mode=False) - assert _needs_approval("sandbox_create", {"hardware": "cpu-upgrade"}, config) is True + assert ( + _needs_approval("sandbox_create", {"hardware": "cpu-upgrade"}, config) is True + ) assert _needs_approval("sandbox_create", {"hardware": "t4-small"}, config) is True @@ -27,5 +29,8 @@ def test_prompt_and_tool_specs_do_not_require_cpu_sandbox_create(): assert "Do NOT call sandbox_create before normal CPU work" in prompt assert "cpu-basic sandbox is already available" in prompt - assert "cpu-basic sandbox is already started automatically" in tool_specs["sandbox_create"] + assert ( + "cpu-basic sandbox is already started automatically" + in tool_specs["sandbox_create"] + ) assert "started automatically for normal CPU work" in tool_specs["bash"] diff --git a/tests/unit/test_sandbox_private_spaces.py b/tests/unit/test_sandbox_private_spaces.py index 7c896ea581289446883a307be4b7e57cc700353f..d29e9a9a8805ebb9b5a4db7f1671d0f589adc3f2 100644 --- a/tests/unit/test_sandbox_private_spaces.py +++ b/tests/unit/test_sandbox_private_spaces.py @@ -246,7 +246,9 @@ def test_orphan_sweep_preserves_spaces_without_last_modified(): assert count == 0 assert deleted == [] - assert logs == ["orphan sweep: skipping alice/sandbox-12345678; missing lastModified"] + assert logs == [ + "orphan sweep: skipping alice/sandbox-12345678; missing lastModified" + ] def test_ensure_sandbox_overrides_private_argument(monkeypatch): @@ -455,7 +457,9 @@ def test_sandbox_create_replaces_auto_cpu_sandbox(monkeypatch): pass monkeypatch.setattr(sandbox_tool, "_ensure_sandbox", fake_ensure_sandbox) - monkeypatch.setattr(telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy) + monkeypatch.setattr( + telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy + ) session = FakeSession() out, ok = asyncio.run( @@ -479,7 +483,9 @@ def test_teardown_cancels_preload_and_deletes_owned_sandbox(monkeypatch): async def fake_record_sandbox_destroy(*args, **kwargs): pass - monkeypatch.setattr(telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy) + monkeypatch.setattr( + telemetry, "record_sandbox_destroy", fake_record_sandbox_destroy + ) async def run(): cancel_event = threading.Event() diff --git a/tests/unit/test_session_manager_persistence.py b/tests/unit/test_session_manager_persistence.py index 95d1f1acc0e037a16826f35cb3dd4a25cd11e52e..db63d5995de19bd489faa9dce102403a084b51b0 100644 --- a/tests/unit/test_session_manager_persistence.py +++ b/tests/unit/test_session_manager_persistence.py @@ -38,7 +38,11 @@ class FakeRuntimeSession: def auto_approval_policy_summary(self): cap = self.auto_approval_cost_cap_usd - remaining = None if cap is None else max(0, cap - self.auto_approval_estimated_spend_usd) + remaining = ( + None + if cap is None + else max(0, cap - self.auto_approval_estimated_spend_usd) + ) return { "enabled": self.auto_approval_enabled, "cost_cap_usd": cap, @@ -426,6 +430,32 @@ async def test_create_session_schedules_cpu_sandbox_preload(): await _cancel_runtime_tasks(manager) +@pytest.mark.asyncio +async def test_create_session_starts_hub_artifact_collection(monkeypatch): + manager = _manager_with_store(NoopSessionStore()) + manager.enable_hub_artifact_collections = True + stop = _install_fake_runtime(manager) + started: list[tuple[str, str]] = [] + + def fake_start_session_artifact_collection_task(session, **kwargs): + started.append((session.session_id, kwargs["token"])) + return None + + monkeypatch.setattr( + "session_manager.start_session_artifact_collection_task", + fake_start_session_artifact_collection_task, + ) + manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign] + + try: + session_id = await manager.create_session(user_id="owner", hf_token="token") + + assert started == [(session_id, "token")] + finally: + stop.set() + await _cancel_runtime_tasks(manager) + + @pytest.mark.asyncio async def test_lazy_restore_schedules_cpu_sandbox_preload(): manager = _manager_with_store(RestoreStore()) @@ -438,7 +468,9 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload(): manager._start_cpu_sandbox_preload = fake_start_cpu_sandbox_preload # type: ignore[method-assign] try: - restored = await manager.ensure_session_loaded("persisted-session", user_id="owner") + restored = await manager.ensure_session_loaded( + "persisted-session", user_id="owner" + ) assert restored is not None assert scheduled == ["persisted-session"] @@ -448,6 +480,37 @@ async def test_lazy_restore_schedules_cpu_sandbox_preload(): await _cancel_runtime_tasks(manager) +@pytest.mark.asyncio +async def test_lazy_restore_starts_hub_artifact_collection(monkeypatch): + manager = _manager_with_store(RestoreStore()) + manager.enable_hub_artifact_collections = True + stop = _install_fake_runtime(manager) + started: list[tuple[str, str]] = [] + + def fake_start_session_artifact_collection_task(session, **kwargs): + started.append((session.session_id, kwargs["token"])) + return None + + monkeypatch.setattr( + "session_manager.start_session_artifact_collection_task", + fake_start_session_artifact_collection_task, + ) + manager._start_cpu_sandbox_preload = lambda _: None # type: ignore[method-assign] + + try: + restored = await manager.ensure_session_loaded( + "persisted-session", + user_id="owner", + hf_token="token", + ) + + assert restored is not None + assert started == [("persisted-session", "token")] + finally: + stop.set() + await _cancel_runtime_tasks(manager) + + @pytest.mark.asyncio async def test_lazy_restore_deletes_persisted_sandbox_before_preload(monkeypatch): deleted: list[tuple[str, str, str]] = [] @@ -572,7 +635,9 @@ async def test_lazy_restore_preserves_pending_approval_tool_calls(): stop = _install_fake_runtime(manager) try: - restored = await manager.ensure_session_loaded("approval-session", user_id="owner") + restored = await manager.ensure_session_loaded( + "approval-session", user_id="owner" + ) assert restored is not None tool_calls = restored.session.pending_approval["tool_calls"] diff --git a/tests/unit/test_session_uploader.py b/tests/unit/test_session_uploader.py index dfbc27fb7365d933fb89b94a4a1617f5ebd6e78d..82f5503b575f6581331c782d97aaa008bbd148fd 100644 --- a/tests/unit/test_session_uploader.py +++ b/tests/unit/test_session_uploader.py @@ -56,7 +56,9 @@ def test_upload_dataset_card_only_for_claude_code_format(): assert api.calls[0]["repo_id"] == "lewtun/ml-intern-sessions" assert api.calls[0]["repo_type"] == "dataset" assert api.calls[0]["token"] == "hf_token" - assert b"no comprehensive redaction or human review" in api.calls[0]["path_or_fileobj"] + assert ( + b"no comprehensive redaction or human review" in api.calls[0]["path_or_fileobj"] + ) def test_personal_token_env_takes_precedence_for_hf_token(monkeypatch): diff --git a/tests/unit/test_sft_tagger.py b/tests/unit/test_sft_tagger.py index 70d4edd60b280df7b3e989a1aa3ce22927794290..cf02e7b8a35a4fef3cc476bf2eed18c87f75fbdc 100644 --- a/tests/unit/test_sft_tagger.py +++ b/tests/unit/test_sft_tagger.py @@ -90,11 +90,22 @@ def test_outcome_doom_loop_and_context(): def test_hf_job_tags(): events = [ - _ev("tool_call", {"tool": "hf_jobs", "arguments": {"script": "from trl import SFTTrainer"}}), - _ev("hf_job_submit", { - "flavor": "a100-large", "push_to_hub": True, "job_id": "j1", - }), - _ev("hf_job_complete", {"flavor": "a100-large", "final_status": "COMPLETED", "wall_time_s": 3600}), + _ev( + "tool_call", + {"tool": "hf_jobs", "arguments": {"script": "from trl import SFTTrainer"}}, + ), + _ev( + "hf_job_submit", + { + "flavor": "a100-large", + "push_to_hub": True, + "job_id": "j1", + }, + ), + _ev( + "hf_job_complete", + {"flavor": "a100-large", "final_status": "COMPLETED", "wall_time_s": 3600}, + ), _ev("hf_job_submit", {"flavor": "a100x4", "push_to_hub": False}), _ev("hf_job_complete", {"flavor": "a100x4", "final_status": "FAILED"}), ] @@ -112,7 +123,13 @@ def test_hf_job_oom(): events = [ _ev("tool_call", {"tool": "hf_jobs", "arguments": {}}), _ev("hf_job_submit", {"flavor": "a100-large"}), - _ev("tool_output", {"success": False, "output": "RuntimeError: CUDA out of memory. Tried to allocate..."}), + _ev( + "tool_output", + { + "success": False, + "output": "RuntimeError: CUDA out of memory. Tried to allocate...", + }, + ), ] tags = tag_session(_traj(events)) assert "hf_job:oom" in tags @@ -120,7 +137,10 @@ def test_hf_job_oom(): def test_sandbox_tags(): events = [ - _ev("sandbox_create", {"hardware": "t4-small", "sandbox_id": "s1", "create_latency_s": 5}), + _ev( + "sandbox_create", + {"hardware": "t4-small", "sandbox_id": "s1", "create_latency_s": 5}, + ), _ev("sandbox_destroy", {"sandbox_id": "s1", "lifetime_s": 3600}), ] tags = tag_session(_traj(events)) @@ -142,7 +162,9 @@ def test_sandbox_cpu_short(): def test_feedback_tags(): up_only = _traj(events=[_ev("feedback", {"rating": "up"})]) down_only = _traj(events=[_ev("feedback", {"rating": "down"})]) - mixed = _traj(events=[_ev("feedback", {"rating": "up"}), _ev("feedback", {"rating": "down"})]) + mixed = _traj( + events=[_ev("feedback", {"rating": "up"}), _ev("feedback", {"rating": "down"})] + ) none = _traj() assert "feedback:up" in tag_session(up_only) assert "feedback:down" in tag_session(down_only) @@ -152,9 +174,15 @@ def test_feedback_tags(): def test_task_training(): events = [ - _ev("tool_call", {"tool": "hf_jobs", "arguments": { - "script": "from trl import SFTTrainer\ntrainer = SFTTrainer(...)" - }}), + _ev( + "tool_call", + { + "tool": "hf_jobs", + "arguments": { + "script": "from trl import SFTTrainer\ntrainer = SFTTrainer(...)" + }, + }, + ), _ev("hf_job_submit", {"flavor": "a100-large"}), ] assert "task:training" in tag_session(_traj(events)) diff --git a/tests/unit/test_thinking_history.py b/tests/unit/test_thinking_history.py index 9ef4b2f61a023ac9899311f0e133dfe4438aa398..6ec92958e18d44b602e779cfeaea55c8f0e8ea5a 100644 --- a/tests/unit/test_thinking_history.py +++ b/tests/unit/test_thinking_history.py @@ -135,6 +135,7 @@ async def test_streaming_call_rebuilds_anthropic_thinking_state(monkeypatch): ) events = [] + async def send_event(event): events.append(event) @@ -232,6 +233,7 @@ async def test_streaming_call_rebuilds_anthropic_delta_thinking_state(monkeypatc ) events = [] + async def send_event(event): events.append(event) @@ -276,6 +278,7 @@ async def test_streaming_call_skips_chunk_rebuild_for_non_anthropic(monkeypatch) raise AssertionError("stream_chunk_builder should not run") events = [] + async def send_event(event): events.append(event) diff --git a/tests/unit/test_user_quotas.py b/tests/unit/test_user_quotas.py index 4475b1ebd52aa719c5fc1d91ede88212adbb9cbe..cb3690876206e6f4ee74407487af5800265e58c1 100644 --- a/tests/unit/test_user_quotas.py +++ b/tests/unit/test_user_quotas.py @@ -1,10 +1,8 @@ """Tests for backend/user_quotas.py — the in-memory Claude daily-quota store.""" import asyncio -import os import sys from pathlib import Path -from unittest.mock import patch import pytest diff --git a/tests/unit/test_web_search_tool.py b/tests/unit/test_web_search_tool.py index dd243447141349b1c67e04d8a5c6367356da6674..822bc731f3beebe5c35160baaff53b4fc2cfa51c 100644 --- a/tests/unit/test_web_search_tool.py +++ b/tests/unit/test_web_search_tool.py @@ -38,7 +38,9 @@ def test_web_search_extracts_duckduckgo_results_and_filters_domains(monkeypatch) url, ) - monkeypatch.setenv(web_search_tool.WEB_SEARCH_BASE_URL_ENV, "http://search.test/search") + monkeypatch.setenv( + web_search_tool.WEB_SEARCH_BASE_URL_ENV, "http://search.test/search" + ) monkeypatch.setattr(web_search_tool.requests, "get", fake_get) output = web_search_tool.execute_web_search( @@ -91,7 +93,9 @@ def test_web_search_generic_fallback_dedupes_and_rejects_bad_base_url(monkeypatc url, ) - monkeypatch.setenv(web_search_tool.WEB_SEARCH_BASE_URL_ENV, "http://search.test/fallback") + monkeypatch.setenv( + web_search_tool.WEB_SEARCH_BASE_URL_ENV, "http://search.test/fallback" + ) monkeypatch.setattr(web_search_tool.requests, "get", fake_get) output = web_search_tool.execute_web_search("generic links") @@ -119,7 +123,10 @@ async def test_web_search_handler_returns_pretty_json(monkeypatch): "execute_web_search", lambda **kwargs: { "query": kwargs["query"], - "results": ["No web search results matched the query 'x'.", {"content": []}], + "results": [ + "No web search results matched the query 'x'.", + {"content": []}, + ], "durationSeconds": 0.1, }, ) diff --git a/uv.lock b/uv.lock index 73df668c3f519dee440254013cbd65b6ba6b986e..7054363ee8becb5c913d8e54fd9f1596c682f401 100644 --- a/uv.lock +++ b/uv.lock @@ -1803,11 +1803,13 @@ all = [ { name = "pandas" }, { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "ruff" }, { name = "tenacity" }, ] dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "ruff" }, ] eval = [ { name = "datasets" }, @@ -1840,6 +1842,7 @@ requires-dist = [ { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "requests", specifier = ">=2.33.0" }, { name = "rich", specifier = ">=13.0.0" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.12" }, { name = "tenacity", marker = "extra == 'eval'", specifier = ">=8.0.0" }, { name = "thefuzz", specifier = ">=0.22.1" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.32.0" }, @@ -3384,6 +3387,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/f4/09ffb3ebd0cbb9e2c7c9b84d252557ecf434cd71584ee1e32f66013824df/rpds_py-0.29.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:f7728653900035fb7b8d06e1e5900545d8088efc9d5d4545782da7df03ec803f", size = 564054, upload-time = "2025-11-16T14:50:37.733Z" }, ] +[[package]] +name = "ruff" +version = "0.15.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/99/43/3291f1cc9106f4c63bdce7a8d0df5047fe8422a75b091c16b5e9355e0b11/ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6", size = 4643852, upload-time = "2026-04-24T18:17:14.305Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/6e/e78ffb61d4686f3d96ba3df2c801161843746dcbcbb17a1e927d4829312b/ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c", size = 10640713, upload-time = "2026-04-24T18:17:22.841Z" }, + { url = "https://files.pythonhosted.org/packages/ae/08/a317bc231fb9e7b93e4ef3089501e51922ff88d6936ce5cf870c4fe55419/ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c", size = 11069267, upload-time = "2026-04-24T18:17:30.105Z" }, + { url = "https://files.pythonhosted.org/packages/aa/a4/f828e9718d3dce1f5f11c39c4f65afd32783c8b2aebb2e3d259e492c47bd/ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5", size = 10397182, upload-time = "2026-04-24T18:17:07.177Z" }, + { url = "https://files.pythonhosted.org/packages/71/e0/3310fc6d1b5e1fdea22bf3b1b807c7e187b581021b0d7d4514cccdb5fb71/ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002", size = 10758012, upload-time = "2026-04-24T18:16:55.759Z" }, + { url = "https://files.pythonhosted.org/packages/11/c1/a606911aee04c324ddaa883ae418f3569792fd3c4a10c50e0dd0a2311e1e/ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5", size = 10447479, upload-time = "2026-04-24T18:16:51.677Z" }, + { url = "https://files.pythonhosted.org/packages/9d/68/4201e8444f0894f21ab4aeeaee68aa4f10b51613514a20d80bd628d57e88/ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6", size = 11234040, upload-time = "2026-04-24T18:17:16.529Z" }, + { url = "https://files.pythonhosted.org/packages/34/ff/8a6d6cf4ccc23fd67060874e832c18919d1557a0611ebef03fdb01fff11e/ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33", size = 12087377, upload-time = "2026-04-24T18:17:04.944Z" }, + { url = "https://files.pythonhosted.org/packages/85/f6/c669cf73f5152f623d34e69866a46d5e6185816b19fcd5b6dd8a2d299922/ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847", size = 11367784, upload-time = "2026-04-24T18:17:25.409Z" }, + { url = "https://files.pythonhosted.org/packages/e8/39/c61d193b8a1daaa8977f7dea9e8d8ba866e02ea7b65d32f6861693aa4c12/ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0", size = 11344088, upload-time = "2026-04-24T18:17:12.258Z" }, + { url = "https://files.pythonhosted.org/packages/c2/8d/49afab3645e31e12c590acb6d3b5b69d7aab5b81926dbaf7461f9441f37a/ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339", size = 11271770, upload-time = "2026-04-24T18:17:02.457Z" }, + { url = "https://files.pythonhosted.org/packages/46/06/33f41fe94403e2b755481cdfb9b7ef3e4e0ed031c4581124658d935d52b4/ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5", size = 10719355, upload-time = "2026-04-24T18:17:27.648Z" }, + { url = "https://files.pythonhosted.org/packages/0d/59/18aa4e014debbf559670e4048e39260a85c7fcee84acfd761ac01e7b8d35/ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd", size = 10462758, upload-time = "2026-04-24T18:17:32.347Z" }, + { url = "https://files.pythonhosted.org/packages/25/e7/cc9f16fd0f3b5fddcbd7ec3d6ae30c8f3fde1047f32a4093a98d633c6570/ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b", size = 10953498, upload-time = "2026-04-24T18:17:20.674Z" }, + { url = "https://files.pythonhosted.org/packages/72/7a/a9ba7f98c7a575978698f4230c5e8cc54bbc761af34f560818f933dafa0c/ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e", size = 11447765, upload-time = "2026-04-24T18:17:09.755Z" }, + { url = "https://files.pythonhosted.org/packages/ea/f9/0ae446942c846b8266059ad8a30702a35afae55f5cdc54c5adf8d7afdc27/ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20", size = 10657277, upload-time = "2026-04-24T18:17:18.591Z" }, + { url = "https://files.pythonhosted.org/packages/33/f1/9614e03e1cdcbf9437570b5400ced8a720b5db22b28d8e0f1bda429f660d/ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d", size = 11837758, upload-time = "2026-04-24T18:17:00.113Z" }, + { url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821, upload-time = "2026-04-24T18:16:57.979Z" }, +] + [[package]] name = "s3fs" version = "2025.9.0"