diff --git a/README.md b/README.md index f661514cddfcc1e0460213bff7d93da74e6e165f..19587d32b58395337b711f95f2b3c2e87c156981 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,56 @@ ml-intern --max-iterations 100 "your prompt" ml-intern --no-stream "your prompt" ``` +## Supported Gateways + +ML Intern currently supports one-way notification gateways from CLI sessions. +These gateways send out-of-band status updates; they do not accept inbound chat +messages. + +### Slack + +Slack notifications use the Slack Web API to post messages when the agent needs +approval, hits an error, or completes a turn. Create a Slack app with a bot token +that has `chat:write`, invite the bot to the target channel, then set: + +```bash +SLACK_BOT_TOKEN=xoxb-... +SLACK_CHANNEL_ID=C... +``` + +The CLI automatically creates a `slack.default` destination when both variables +are present. Optional environment variables for the env-only default: + +```bash +ML_INTERN_SLACK_NOTIFICATIONS=false +ML_INTERN_SLACK_DESTINATION=slack.ops +ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete +ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true +ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true +``` + +For a persistent user-level config, put overrides in +`~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a +JSON file: + +```json +{ + "messaging": { + "enabled": true, + "auto_event_types": ["approval_required", "error", "turn_complete"], + "destinations": { + "slack.ops": { + "provider": "slack", + "token": "${SLACK_BOT_TOKEN}", + "channel": "${SLACK_CHANNEL_ID}", + "allow_agent_tool": true, + "allow_auto_events": true + } + } + } +} +``` + ## Architecture ### Component Overview diff --git a/agent/config.py b/agent/config.py index 7e696dd78fdbf04a7f5ff527127583930acae3d0..5a6a8a45f796a6557404b8f401ad2fad3f264288 100644 --- a/agent/config.py +++ b/agent/config.py @@ -6,6 +6,8 @@ from typing import Any, Union from dotenv import load_dotenv +from agent.messaging.models import 
MessagingConfig + # Project root: two levels up from this file (agent/config.py -> project root) _PROJECT_ROOT = Path(__file__).resolve().parent.parent from fastmcp.mcp_config import ( @@ -47,6 +49,104 @@ class Config(BaseModel): # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off. # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max" reasoning_effort: str | None = "max" + messaging: MessagingConfig = MessagingConfig() + + +USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG" +DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json" +SLACK_DEFAULT_DESTINATION = "slack.default" +SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"] + + +def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, value in override.items(): + current = merged.get(key) + if isinstance(current, dict) and isinstance(value, dict): + merged[key] = _deep_merge_config(current, value) + else: + merged[key] = value + return merged + + +def _load_json_config(path: Path) -> dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"Config file {path} must contain a JSON object") + return data + + +def _load_user_config() -> dict[str, Any]: + raw_path = os.environ.get(USER_CONFIG_ENV_VAR) + if raw_path: + path = Path(raw_path).expanduser() + if not path.exists(): + raise FileNotFoundError( + f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}" + ) + return _load_json_config(path) + + if DEFAULT_USER_CONFIG_PATH.exists(): + return _load_json_config(DEFAULT_USER_CONFIG_PATH) + return {} + + +def _env_bool(name: str, default: bool) -> bool: + value = os.environ.get(name) + if value is None: + return default + normalized = value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + 
return False + return default + + +def _env_list(name: str) -> list[str] | None: + value = os.environ.get(name) + if value is None: + return None + return [item.strip() for item in value.split(",") if item.strip()] + + +def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]: + """Enable a default Slack destination from user env vars, when present.""" + if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True): + return raw_config + + token = os.environ.get("SLACK_BOT_TOKEN") + channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL") + if not token or not channel: + return raw_config + + config = dict(raw_config) + messaging = dict(config.get("messaging") or {}) + destinations = dict(messaging.get("destinations") or {}) + destination_name = ( + os.environ.get("ML_INTERN_SLACK_DESTINATION") or SLACK_DEFAULT_DESTINATION + ).strip() + + if destination_name not in destinations: + destinations[destination_name] = { + "provider": "slack", + "token": token, + "channel": channel, + "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True), + "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True), + } + + auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS") + if auto_events is not None: + messaging["auto_event_types"] = auto_events + elif "auto_event_types" not in messaging: + messaging["auto_event_types"] = SLACK_DEFAULT_AUTO_EVENT_TYPES + + messaging["enabled"] = True + messaging["destinations"] = destinations + config["messaging"] = messaging + return config def substitute_env_vars(obj: Any) -> Any: @@ -86,7 +186,10 @@ def substitute_env_vars(obj: Any) -> Any: return obj -def load_config(config_path: str = "config.json") -> Config: +def load_config( + config_path: str = "config.json", + include_user_defaults: bool = False, +) -> Config: """ Load configuration with environment variable substitution. 
@@ -98,8 +201,10 @@ def load_config(config_path: str = "config.json") -> Config: load_dotenv(_PROJECT_ROOT / ".env") load_dotenv(override=False) - with open(config_path, "r") as f: - raw_config = json.load(f) + raw_config = _load_json_config(Path(config_path)) + if include_user_defaults: + raw_config = _deep_merge_config(raw_config, _load_user_config()) + raw_config = apply_slack_user_defaults(raw_config) config_with_env = substitute_env_vars(raw_config) return Config.model_validate(config_with_env) diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 64584b6d56bd14073bbe5fae53e62481338c638f..c842c8842c215450e651a33abf93c2a115d68e54 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -160,6 +160,7 @@ class ContextManager: self.running_context_usage = 0 self.untouched_messages = untouched_messages self.items: list[Message] = [Message(role="system", content=self.system_prompt)] + self.on_message_added = None def _load_system_prompt( self, @@ -219,6 +220,8 @@ class ContextManager: if token_count: self.running_context_usage = token_count self.items.append(message) + if self.on_message_added: + self.on_message_added(message) def get_messages(self) -> list[Message]: """Get all messages for sending to LLM. 
diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index 26361d413a7c19f29c2e02bfcbf50e62271295c3..8b7a4572d0843ec473d0c6c21c1cd059b022ad28 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -8,11 +8,18 @@ import logging import os import time from dataclasses import dataclass, field - -from litellm import ChatCompletionMessageToolCall, Message, acompletion +from typing import Any + +from litellm import ( + ChatCompletionMessageToolCall, + Message, + acompletion, + stream_chunk_builder, +) from litellm.exceptions import ContextWindowExceededError from agent.config import Config +from agent.messaging.gateway import NotificationGateway from agent.core import telemetry from agent.core.doom_loop import check_for_doom_loop from agent.core.llm_params import _resolve_llm_params @@ -396,12 +403,159 @@ class LLMResult: token_count: int finish_reason: str | None usage: dict = field(default_factory=dict) + thinking_blocks: list[dict[str, Any]] | None = None + reasoning_content: str | None = None + + +def _extract_thinking_state( + message: Any, +) -> tuple[list[dict[str, Any]] | None, str | None]: + """Return provider reasoning fields that must be replayed after tool calls.""" + provider_fields = getattr(message, "provider_specific_fields", None) + if not isinstance(provider_fields, dict): + provider_fields = {} + + thinking_blocks = ( + getattr(message, "thinking_blocks", None) + or provider_fields.get("thinking_blocks") + or None + ) + reasoning_content = ( + getattr(message, "reasoning_content", None) + or provider_fields.get("reasoning_content") + or None + ) + return thinking_blocks, reasoning_content + + +def _should_replay_thinking_state(model_name: str | None) -> bool: + """Only Anthropic's native adapter accepts replayed thinking metadata.""" + return bool(model_name and model_name.startswith("anthropic/")) + + +def _is_invalid_thinking_signature_error(exc: Exception) -> bool: + """Return True when Anthropic rejected replayed 
extended-thinking state.""" + text = str(exc) + return ( + "Invalid `signature` in `thinking` block" in text + or "Invalid signature in thinking block" in text + ) + + +def _strip_thinking_state_from_messages(messages: list[Any]) -> int: + """Remove replayed thinking metadata from assistant history messages.""" + stripped = 0 + + for message in messages: + role = ( + message.get("role") + if isinstance(message, dict) + else getattr(message, "role", None) + ) + if role != "assistant": + continue + + if isinstance(message, dict): + if message.pop("thinking_blocks", None) is not None: + stripped += 1 + if message.pop("reasoning_content", None) is not None: + stripped += 1 + provider_fields = message.get("provider_specific_fields") + content = message.get("content") + else: + if getattr(message, "thinking_blocks", None) is not None: + message.thinking_blocks = None + stripped += 1 + if getattr(message, "reasoning_content", None) is not None: + message.reasoning_content = None + stripped += 1 + provider_fields = getattr(message, "provider_specific_fields", None) + content = getattr(message, "content", None) + + if isinstance(provider_fields, dict): + cleaned_fields = dict(provider_fields) + if cleaned_fields.pop("thinking_blocks", None) is not None: + stripped += 1 + if cleaned_fields.pop("reasoning_content", None) is not None: + stripped += 1 + if cleaned_fields != provider_fields: + if isinstance(message, dict): + message["provider_specific_fields"] = cleaned_fields + else: + message.provider_specific_fields = cleaned_fields + + if isinstance(content, list): + cleaned_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") in {"thinking", "redacted_thinking"} + ) + ] + if len(cleaned_content) != len(content): + stripped += len(content) - len(cleaned_content) + if isinstance(message, dict): + message["content"] = cleaned_content + else: + message.content = cleaned_content + + return stripped + + +async def 
_maybe_heal_invalid_thinking_signature( + session: Session, + messages: list[Any], + exc: Exception, + *, + already_healed: bool, +) -> bool: + if already_healed or not _is_invalid_thinking_signature_error(exc): + return False + + stripped = _strip_thinking_state_from_messages(messages) + if not stripped: + return False + + await session.send_event(Event( + event_type="tool_log", + data={ + "tool": "system", + "log": ( + "Anthropic rejected stale thinking signatures; retrying " + "without replayed thinking metadata." + ), + }, + )) + return True + + +def _assistant_message_from_result( + llm_result: LLMResult, + *, + model_name: str | None, + tool_calls: list[ToolCall] | None = None, +) -> Message: + """Build an assistant history message without dropping reasoning state.""" + kwargs: dict[str, Any] = { + "role": "assistant", + "content": llm_result.content, + } + if tool_calls is not None: + kwargs["tool_calls"] = tool_calls + if _should_replay_thinking_state(model_name): + if llm_result.thinking_blocks: + kwargs["thinking_blocks"] = llm_result.thinking_blocks + if llm_result.reasoning_content: + kwargs["reasoning_content"] = llm_result.reasoning_content + return Message(**kwargs) async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult: """Call the LLM with streaming, emitting assistant_chunk events.""" response = None _healed_effort = False # one-shot safety net per call + _healed_thinking_signature = False messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) t_start = time.monotonic() for _llm_attempt in range(_MAX_LLM_RETRIES): @@ -429,6 +583,14 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."}, )) continue + if await _maybe_heal_invalid_thinking_signature( + session, + messages, + e, + already_healed=_healed_thinking_signature, + ): + _healed_thinking_signature = 
True + continue _delay = _retry_delay_for(e, _llm_attempt) if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: logger.warning( @@ -448,8 +610,11 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> token_count = 0 finish_reason = None final_usage_chunk = None + chunks = [] + should_replay_thinking = _should_replay_thinking_state(llm_params.get("model")) async for chunk in response: + chunks.append(chunk) if session.is_cancelled: tool_calls_acc.clear() break @@ -498,6 +663,16 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> latency_ms=int((time.monotonic() - t_start) * 1000), finish_reason=finish_reason, ) + thinking_blocks = None + reasoning_content = None + if chunks and should_replay_thinking: + try: + rebuilt = stream_chunk_builder(chunks, messages=messages) + if rebuilt and getattr(rebuilt, "choices", None): + rebuilt_msg = rebuilt.choices[0].message + thinking_blocks, reasoning_content = _extract_thinking_state(rebuilt_msg) + except Exception: + logger.debug("Failed to rebuild streaming thinking state", exc_info=True) return LLMResult( content=full_content or None, @@ -505,6 +680,8 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> token_count=token_count, finish_reason=finish_reason, usage=usage, + thinking_blocks=thinking_blocks, + reasoning_content=reasoning_content, ) @@ -512,6 +689,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) """Call the LLM without streaming, emit assistant_message at the end.""" response = None _healed_effort = False + _healed_thinking_signature = False messages, tools = with_prompt_caching(messages, tools, llm_params.get("model")) t_start = time.monotonic() for _llm_attempt in range(_MAX_LLM_RETRIES): @@ -538,6 +716,14 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting 
and retrying."}, )) continue + if await _maybe_heal_invalid_thinking_signature( + session, + messages, + e, + already_healed=_healed_thinking_signature, + ): + _healed_thinking_signature = True + continue _delay = _retry_delay_for(e, _llm_attempt) if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None: logger.warning( @@ -557,6 +743,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) content = message.content or None finish_reason = choice.finish_reason token_count = response.usage.total_tokens if response.usage else 0 + thinking_blocks, reasoning_content = _extract_thinking_state(message) # Build tool_calls_acc in the same format as streaming tool_calls_acc: dict[int, dict] = {} @@ -591,6 +778,8 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) token_count=token_count, finish_reason=finish_reason, usage=usage, + thinking_blocks=thinking_blocks, + reasoning_content=reasoning_content, ) @@ -681,15 +870,6 @@ class Handlers: session.context_manager.add_message( Message(role="user", content=doom_prompt) ) - await session.send_event( - Event( - event_type="tool_log", - data={ - "tool": "system", - "log": "Doom loop detected — injecting corrective prompt", - }, - ) - ) malformed_tool = _detect_repeated_malformed(session.context_manager.items) if malformed_tool: @@ -763,7 +943,10 @@ class Handlers: " • For other tools: reduce the size of your arguments or use bash." 
) if content: - assistant_msg = Message(role="assistant", content=content) + assistant_msg = _assistant_message_from_result( + llm_result, + model_name=llm_params.get("model"), + ) session.context_manager.add_message(assistant_msg, token_count) session.context_manager.add_message( Message(role="user", content=f"[SYSTEM: {truncation_hint}]") @@ -819,7 +1002,10 @@ class Handlers: (content or "")[:500], ) if content: - assistant_msg = Message(role="assistant", content=content) + assistant_msg = _assistant_message_from_result( + llm_result, + model_name=llm_params.get("model"), + ) session.context_manager.add_message(assistant_msg, token_count) final_response = content break @@ -841,9 +1027,9 @@ class Handlers: bad_tools.append(tc) # Add assistant message with all tool calls to context - assistant_msg = Message( - role="assistant", - content=content, + assistant_msg = _assistant_message_from_result( + llm_result, + model_name=llm_params.get("model"), tool_calls=tool_calls, ) session.context_manager.add_message(assistant_msg, token_count) @@ -1049,7 +1235,12 @@ class Handlers: await session.send_event( Event( event_type="turn_complete", - data={"history_size": len(session.context_manager.items)}, + data={ + "history_size": len(session.context_manager.items), + "final_response": final_response + if isinstance(final_response, str) + else None, + }, ) ) @@ -1358,12 +1549,16 @@ async def process_submission(session: Session, submission) -> bool: async def submission_loop( submission_queue: asyncio.Queue, event_queue: asyncio.Queue, - config: Config | None = None, + config: Config, tool_router: ToolRouter | None = None, session_holder: list | None = None, hf_token: str | None = None, + user_id: str | None = None, local_mode: bool = False, stream: bool = True, + notification_gateway: NotificationGateway | None = None, + notification_destinations: list[str] | None = None, + defer_turn_complete_notification: bool = False, ) -> None: """ Main agent loop - processes submissions 
and dispatches to handlers. @@ -1373,7 +1568,10 @@ async def submission_loop( # Create session with tool router session = Session( event_queue, config=config, tool_router=tool_router, hf_token=hf_token, - local_mode=local_mode, stream=stream, + user_id=user_id, local_mode=local_mode, stream=stream, + notification_gateway=notification_gateway, + notification_destinations=notification_destinations, + defer_turn_complete_notification=defer_turn_complete_notification, ) if session_holder is not None: session_holder[0] = session diff --git a/agent/core/doom_loop.py b/agent/core/doom_loop.py index fbc3510a1222ae0888bc72a5288b4dc70a9de00f..878c7c00adfb4f8ea3fa7f068493ed8358d76b8d 100644 --- a/agent/core/doom_loop.py +++ b/agent/core/doom_loop.py @@ -24,9 +24,36 @@ class ToolCallSignature: result_hash: str | None = None +def _normalize_args(args_str: str) -> str: + """Canonicalise a tool-call arguments string before hashing. + + LLMs can emit semantically-identical JSON for the same call with different + key orderings (``{"a": 1, "b": 2}`` vs ``{"b": 2, "a": 1}``) or whitespace + (``{"a":1}`` vs ``{"a": 1}``). Hashing the raw bytes makes the doom-loop + detector miss those repeats. We parse-and-redump with ``sort_keys=True`` + plus the most compact separators so trivially-different spellings collapse + to the same canonical form. + + Falls back to the original string if the input isn't valid JSON (e.g. a + handful of providers occasionally pass a bare string for ``arguments``); + that path keeps the legacy behaviour and never raises. + """ + if not args_str: + return "" + try: + return json.dumps(json.loads(args_str), sort_keys=True, separators=(",", ":")) + except (json.JSONDecodeError, TypeError, ValueError): + return args_str + + def _hash_args(args_str: str) -> str: - """Return a short hash of the JSON arguments string.""" - return hashlib.md5(args_str.encode()).hexdigest()[:12] + """Return a short hash of the JSON arguments string. 
+ + The input is normalised via :func:`_normalize_args` first so that + semantically-identical tool calls produce the same hash regardless of key + order or whitespace. + """ + return hashlib.md5(_normalize_args(args_str).encode()).hexdigest()[:12] def extract_recent_tool_signatures( @@ -129,9 +156,13 @@ def check_for_doom_loop(messages: list[Message]) -> str | None: # Check for identical consecutive calls tool_name = detect_identical_consecutive(signatures, threshold=3) if tool_name: - logger.warning("Doom loop detected: %d+ identical consecutive calls to '%s'", 3, tool_name) + logger.warning( + "Repetition guard activated: %d+ identical consecutive calls to '%s'", + 3, + tool_name, + ) return ( - f"[SYSTEM: DOOM LOOP DETECTED] You have called '{tool_name}' with the same " + f"[SYSTEM: REPETITION GUARD] You have called '{tool_name}' with the same " f"arguments multiple times in a row, getting the same result each time. " f"STOP repeating this approach — it is not working. " f"Step back and try a fundamentally different strategy. " @@ -143,9 +174,9 @@ def check_for_doom_loop(messages: list[Message]) -> str | None: pattern = detect_repeating_sequence(signatures) if pattern: pattern_desc = " → ".join(s.name for s in pattern) - logger.warning("Doom loop detected: repeating sequence [%s]", pattern_desc) + logger.warning("Repetition guard activated: repeating sequence [%s]", pattern_desc) return ( - f"[SYSTEM: DOOM LOOP DETECTED] You are stuck in a repeating cycle of tool calls: " + f"[SYSTEM: REPETITION GUARD] You are stuck in a repeating cycle of tool calls: " f"[{pattern_desc}]. This pattern has repeated multiple times without progress. " f"STOP this cycle and try a fundamentally different approach. 
" f"Consider: breaking down the problem differently, using alternative tools, " diff --git a/agent/core/hf_access.py b/agent/core/hf_access.py index 400db5a5a70efeb3cc513f4322469d504821e973..11446349fba5a41e7e92db86e9060e39dab00ba1 100644 --- a/agent/core/hf_access.py +++ b/agent/core/hf_access.py @@ -55,6 +55,13 @@ def _extract_username(whoami: dict[str, Any]) -> str | None: def _normalize_personal_plan(whoami: dict[str, Any]) -> str: + # OAuth whoami responses set `type: "user"` and surface Pro status only via + # the `isPro` boolean. Check the boolean first so a generic `type` value + # doesn't shadow it — otherwise Pro OAuth users get classified as free and + # blocked from running Jobs (smolagents/ml-intern Space discussion #21). + if whoami.get("isPro") is True or whoami.get("is_pro") is True: + return "pro" + plan_str = "" for key in ("plan", "type", "accountType"): value = whoami.get(key) @@ -62,9 +69,6 @@ def _normalize_personal_plan(whoami: dict[str, Any]) -> str: plan_str = value.lower() break - if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True): - return "pro" - if any(tag in plan_str for tag in ("pro", "enterprise", "team")): return "pro" return "free" diff --git a/agent/core/hf_tokens.py b/agent/core/hf_tokens.py new file mode 100644 index 0000000000000000000000000000000000000000..3e72ccc128a9d9aaecb661c4c2ba3850a10b5dc0 --- /dev/null +++ b/agent/core/hf_tokens.py @@ -0,0 +1,85 @@ +"""Hugging Face token resolution helpers.""" + +from __future__ import annotations + +import os +from typing import Any + + +def clean_hf_token(token: str | None) -> str | None: + """Normalize token strings the same way huggingface_hub does.""" + if token is None: + return None + return token.replace("\r", "").replace("\n", "").strip() or None + + +def get_cached_hf_token() -> str | None: + """Return the token from huggingface_hub's normal env/cache lookup.""" + try: + from huggingface_hub import get_token + + return get_token() + except 
Exception: + return None + + +def resolve_hf_token( + *candidates: str | None, + include_cached: bool = True, +) -> str | None: + """Return the first non-empty explicit token, then optionally HF cache.""" + for token in candidates: + cleaned = clean_hf_token(token) + if cleaned: + return cleaned + if include_cached: + return get_cached_hf_token() + return None + + +def resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: + """Resolve the token used for Hugging Face Router LLM calls. + + App-specific precedence: + 1. INFERENCE_TOKEN: shared hosted-Space inference token. + 2. session_hf_token: the active user/session token. + 3. huggingface_hub.get_token(): HF_TOKEN/HUGGING_FACE_HUB_TOKEN or + local ``hf auth login`` cache. + """ + return resolve_hf_token(os.environ.get("INFERENCE_TOKEN"), session_hf_token) + + +def get_hf_bill_to() -> str | None: + """Return X-HF-Bill-To only when a shared inference token is active.""" + if clean_hf_token(os.environ.get("INFERENCE_TOKEN")): + return os.environ.get("HF_BILL_TO", "smolagents") + return None + + +def bearer_token_from_header(auth_header: str | None) -> str | None: + """Extract a cleaned bearer token from an Authorization header.""" + if not auth_header or not auth_header.startswith("Bearer "): + return None + return clean_hf_token(auth_header[7:]) + + +def resolve_hf_request_token( + request: Any, + *, + include_env_fallback: bool = True, +) -> str | None: + """Resolve a user token from a FastAPI request. + + This intentionally does not use the local ``hf auth login`` cache. Backend + request paths should act as the browser user from Authorization/cookie, or + fall back only to an explicit server ``HF_TOKEN`` in dev/server contexts. 
+ """ + token = bearer_token_from_header(request.headers.get("Authorization", "")) + if token: + return token + token = clean_hf_token(request.cookies.get("hf_access_token")) + if token: + return token + if include_env_fallback: + return clean_hf_token(os.environ.get("HF_TOKEN")) + return None diff --git a/agent/core/llm_params.py b/agent/core/llm_params.py index bac507354348fc2b8ca423c427005e9e6efc8bb2..880886b3e1e2919f31d35934c6f9a4c3fb5e9525 100644 --- a/agent/core/llm_params.py +++ b/agent/core/llm_params.py @@ -5,7 +5,12 @@ can import it without pulling in the whole agent loop / tool router and creating circular imports. """ -import os +from agent.core.hf_tokens import get_hf_bill_to, resolve_hf_router_token + + +def _resolve_hf_router_token(session_hf_token: str | None = None) -> str | None: + """Backward-compatible private wrapper used by tests and older imports.""" + return resolve_hf_router_token(session_hf_token) def _patch_litellm_effort_validation() -> None: @@ -129,7 +134,8 @@ def _resolve_llm_params( 1. INFERENCE_TOKEN env — shared key on the hosted Space (inference is free for users, billed to the Space owner via ``X-HF-Bill-To``). 2. session.hf_token — the user's own token (CLI / OAuth / cache file). - 3. HF_TOKEN env — belt-and-suspenders fallback for CLI users. + 3. huggingface_hub cache — ``HF_TOKEN`` / ``HUGGING_FACE_HUB_TOKEN`` / + local ``hf auth login`` cache. 
""" if model_name.startswith("anthropic/"): params: dict = {"model": model_name} @@ -175,18 +181,13 @@ def _resolve_llm_params( return params hf_model = model_name.removeprefix("huggingface/") - api_key = ( - os.environ.get("INFERENCE_TOKEN") - or session_hf_token - or os.environ.get("HF_TOKEN") - ) + api_key = _resolve_hf_router_token(session_hf_token) params = { "model": f"openai/{hf_model}", "api_base": "https://router.huggingface.co/v1", "api_key": api_key, } - if os.environ.get("INFERENCE_TOKEN"): - bill_to = os.environ.get("HF_BILL_TO", "smolagents") + if bill_to := get_hf_bill_to(): params["extra_headers"] = {"X-HF-Bill-To": bill_to} if reasoning_effort: hf_level = "low" if reasoning_effort == "minimal" else reasoning_effort diff --git a/agent/core/session.py b/agent/core/session.py index 0cf9524a12241f24aaa1ab1186fadc73620f177b..c53294cd251931cca8eff171f1ee694c2991e629 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -12,10 +12,13 @@ from typing import Any, Optional from agent.config import Config from agent.context_manager.manager import ContextManager +from agent.messaging.gateway import NotificationGateway +from agent.messaging.models import NotificationRequest logger = logging.getLogger(__name__) _DEFAULT_MAX_TOKENS = 200_000 +_TURN_COMPLETE_NOTIFICATION_CHARS = 39000 def _get_max_tokens_safe(model_name: str) -> int: @@ -62,6 +65,7 @@ class OpType(Enum): class Event: event_type: str data: Optional[dict[str, Any]] = None + seq: Optional[int] = None class Session: @@ -73,16 +77,26 @@ class Session: def __init__( self, event_queue: asyncio.Queue, - config: Config | None = None, + config: Config, tool_router=None, context_manager: ContextManager | None = None, hf_token: str | None = None, local_mode: bool = False, stream: bool = True, + notification_gateway: NotificationGateway | None = None, + notification_destinations: list[str] | None = None, + defer_turn_complete_notification: bool = False, + session_id: str | None = None, + user_id: 
str | None = None, + persistence_store: Any | None = None, ): self.hf_token: Optional[str] = hf_token + self.user_id: Optional[str] = user_id + self.persistence_store = persistence_store self.tool_router = tool_router self.stream = stream + if config is None: + raise ValueError("Session requires a Config") tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = context_manager or ContextManager( model_max_tokens=_get_max_tokens_safe(config.model_name), @@ -93,15 +107,16 @@ class Session: local_mode=local_mode, ) self.event_queue = event_queue - self.session_id = str(uuid.uuid4()) - self.config = config or Config( - model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0", - ) + self.session_id = session_id or str(uuid.uuid4()) + self.config = config self.is_running = True self._cancelled = asyncio.Event() self.pending_approval: Optional[dict[str, Any]] = None self.sandbox = None self._running_job_ids: set[str] = set() # HF job IDs currently executing + self.notification_gateway = notification_gateway + self.notification_destinations = list(notification_destinations or []) + self.defer_turn_complete_notification = defer_turn_complete_notification # Session trajectory logging self.logged_events: list[dict] = [] @@ -123,11 +138,10 @@ class Session: # thinking params at all # Key absent → not probed yet; fall back to the raw preference. 
self.model_effective_effort: dict[str, str | None] = {} + self.context_manager.on_message_added = self._schedule_trace_message async def send_event(self, event: Event) -> None: """Send event back to client and log to trajectory""" - await self.event_queue.put(event) - # Log event to trajectory self.logged_events.append( { @@ -136,11 +150,149 @@ class Session: "data": event.data, } ) + if self.persistence_store is not None: + try: + event.seq = await self.persistence_store.append_event( + self.session_id, event.event_type, event.data + ) + except Exception as e: + logger.debug("Event persistence failed for %s: %s", self.session_id, e) + + await self.event_queue.put(event) + await self._enqueue_auto_notification_requests(event) # Mid-turn heartbeat flush (owned by telemetry module). from agent.core.telemetry import HeartbeatSaver + HeartbeatSaver.maybe_fire(self) + def _schedule_trace_message(self, message: Any) -> None: + """Best-effort append-only trace save for SFT/KPI export.""" + if self.persistence_store is None: + return + try: + payload = message.model_dump(mode="json") + except Exception: + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + source = str(payload.get("role") or "message") + loop.create_task( + self.persistence_store.append_trace_message( + self.session_id, payload, source=source + ) + ) + + def set_notification_destinations(self, destinations: list[str]) -> None: + """Replace the session's opted-in auto-notification destinations.""" + deduped: list[str] = [] + seen: set[str] = set() + for destination in destinations: + if destination not in seen: + deduped.append(destination) + seen.add(destination) + self.notification_destinations = deduped + + async def send_deferred_turn_complete_notification(self, event: Event) -> None: + if event.event_type != "turn_complete": + return + await self._enqueue_auto_notification_requests( + event, + include_deferred_turn_complete=True, + ) + + async def 
def _build_auto_notification_requests(
    self, event: Event
) -> list[NotificationRequest]:
    """Translate a session event into one request per eligible destination.

    Only ``approval_required``, ``error`` and ``turn_complete`` events yield
    a message; any other event type returns an empty list. Destinations for
    which ``config.messaging.can_auto_send`` is false are skipped.

    Args:
        event: The session event to describe.

    Returns:
        Notification requests, in the order of
        ``self.notification_destinations``.
    """
    kind = event.event_type
    context = {
        "session_id": self.session_id,
        "model": self.config.model_name,
        "event_type": kind,
    }
    payload = event.data or {}

    heading: str | None = None
    body: str | None = None
    level = "info"

    if kind == "approval_required":
        raw_tools = payload.get("tools", [])
        tool_list = raw_tools if isinstance(raw_tools, list) else []
        # Ordered de-duplication of the non-empty tool names.
        cleaned = (
            str(entry.get("tool") or "").strip()
            for entry in tool_list
            if isinstance(entry, dict)
        )
        unique_names = list(dict.fromkeys(name for name in cleaned if name))
        heading = "Agent approval required"
        body = (
            f"Session {self.session_id} is waiting for approval "
            f"for {len(tool_list)} tool call(s)."
        )
        if unique_names:
            body += " Tools: " + ", ".join(unique_names)
        level = "warning"
    elif kind == "error":
        heading = "Agent error"
        detail = str(payload.get("error") or "Unknown error")
        # Cap the error text so a huge traceback does not flood the channel.
        body = f"Session {self.session_id} hit an error.\n{detail[:500]}"
        level = "error"
    elif kind == "turn_complete":
        heading = "Agent task complete"
        recap = str(payload.get("final_response") or "").strip()
        if recap:
            recap = recap[:_TURN_COMPLETE_NOTIFICATION_CHARS]
            body = (
                f"Session {self.session_id} completed successfully.\n"
                f"{recap}"
            )
        else:
            body = f"Session {self.session_id} completed successfully."
        level = "success"

    if body is None:
        return []

    messaging = self.config.messaging
    return [
        NotificationRequest(
            destination=target,
            title=heading,
            message=body,
            severity=level,
            metadata=context,
            event_type=kind,
        )
        for target in self.notification_destinations
        if messaging.can_auto_send(target)
    ]
+ total_cost_usd = sum( + float((e.get("data") or {}).get("cost_usd") or 0.0) + for e in self.logged_events + if e.get("event_type") == "llm_call" + ) return { "session_id": self.session_id, + "user_id": self.user_id, "session_start_time": self.session_start_time, "session_end_time": datetime.now().isoformat(), "model_name": self.config.model_name, + "total_cost_usd": total_cost_usd, "messages": [msg.model_dump() for msg in self.context_manager.items], "events": self.logged_events, "tools": tools, diff --git a/agent/core/session_persistence.py b/agent/core/session_persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..5c125b3883accfaf825596eb6345a5d6a6a1350f --- /dev/null +++ b/agent/core/session_persistence.py @@ -0,0 +1,428 @@ +"""Optional durable session persistence for the hosted backend. + +The public CLI must keep working without MongoDB. This module therefore +exposes one small async store interface and returns a no-op implementation +unless ``MONGODB_URI`` is configured and reachable. +""" + +from __future__ import annotations + +import logging +import os +from datetime import UTC, datetime +from typing import Any + +from bson import BSON +from pymongo import AsyncMongoClient, DeleteMany, ReturnDocument, UpdateOne +from pymongo.errors import DuplicateKeyError, InvalidDocument, PyMongoError + +logger = logging.getLogger(__name__) + +SCHEMA_VERSION = 1 +MAX_BSON_BYTES = 15 * 1024 * 1024 + + +def _now() -> datetime: + return datetime.now(UTC) + + +def _doc_id(session_id: str, idx: int) -> str: + return f"{session_id}:{idx}" + + +def _safe_message_doc(message: dict[str, Any]) -> dict[str, Any]: + """Return a Mongo-safe message document payload. + + Mongo's hard document limit is 16 MB. We stay below that and store an + explicit marker rather than failing the whole snapshot for one huge tool log. 
+ """ + try: + if len(BSON.encode({"message": message})) <= MAX_BSON_BYTES: + return message + except (InvalidDocument, OverflowError): + pass + return { + "role": "tool", + "content": ( + "[SYSTEM: A single persisted message exceeded MongoDB's document " + "size/encoding limit and was replaced by this marker.]" + ), + "ml_intern_persistence_error": "message_too_large_or_invalid", + } + + +class NoopSessionStore: + """Async no-op store used when Mongo is not configured.""" + + enabled = False + + async def init(self) -> None: + return None + + async def close(self) -> None: + return None + + async def upsert_session(self, **_: Any) -> None: + return None + + async def save_snapshot(self, **_: Any) -> None: + return None + + async def load_session(self, *_: Any, **__: Any) -> dict[str, Any] | None: + return None + + async def list_sessions(self, *_: Any, **__: Any) -> list[dict[str, Any]]: + return [] + + async def soft_delete_session(self, *_: Any, **__: Any) -> None: + return None + + async def update_session_fields(self, *_: Any, **__: Any) -> None: + return None + + async def append_event(self, *_: Any, **__: Any) -> int | None: + return None + + async def load_events_after(self, *_: Any, **__: Any) -> list[dict[str, Any]]: + return [] + + async def append_trace_message(self, *_: Any, **__: Any) -> int | None: + return None + + async def get_quota(self, *_: Any, **__: Any) -> int | None: + return None + + async def try_increment_quota(self, *_: Any, **__: Any) -> int | None: + return None + + async def refund_quota(self, *_: Any, **__: Any) -> None: + return None + + +class MongoSessionStore(NoopSessionStore): + """MongoDB-backed session store.""" + + enabled = True + + def __init__(self, uri: str, db_name: str) -> None: + self.uri = uri + self.db_name = db_name + self.enabled = False + self.client: AsyncMongoClient | None = None + self.db = None + + async def init(self) -> None: + try: + self.client = AsyncMongoClient(self.uri, serverSelectionTimeoutMS=3000) + 
async def close(self) -> None:
    """Close the Mongo client, if any, and drop the client/db references.

    The references are cleared *before* the client is closed, so a failing
    ``client.close()`` cannot leave the store pointing at a half-closed
    client. Any exception raised by ``client.close()`` still propagates.
    """
    client = self.client
    if client is None:
        return
    self.client = None
    self.db = None
    await client.close()
async def save_snapshot(
    self,
    *,
    session_id: str,
    user_id: str,
    model: str,
    messages: list[dict[str, Any]],
    title: str | None = None,
    runtime_state: str = "idle",
    status: str = "active",
    turn_count: int = 0,
    pending_approval: list[dict[str, Any]] | None = None,
    claude_counted: bool = False,
    created_at: datetime | None = None,
    notification_destinations: list[str] | None = None,
) -> None:
    """Persist session metadata plus the full, ordered message list.

    The session document is upserted first. Each message is then upserted
    under ``_doc_id(session_id, idx)`` and any stored message whose ``idx``
    is at or beyond ``len(messages)`` is deleted, so the stored list mirrors
    ``messages``. A ``PyMongoError`` from the bulk write is logged and
    swallowed; errors from the metadata upsert propagate.
    """
    if not self._ready():
        return

    stamp = _now()
    total = len(messages)
    await self.upsert_session(
        session_id=session_id,
        user_id=user_id,
        model=model,
        title=title,
        created_at=created_at,
        runtime_state=runtime_state,
        status=status,
        message_count=total,
        turn_count=turn_count,
        pending_approval=pending_approval,
        claude_counted=claude_counted,
        notification_destinations=notification_destinations,
    )

    writes: list[Any] = [
        UpdateOne(
            {"_id": _doc_id(session_id, position)},
            {
                "$set": {
                    "session_id": session_id,
                    "idx": position,
                    "message": _safe_message_doc(entry),
                    "updated_at": stamp,
                },
                "$setOnInsert": {"created_at": stamp},
            },
            upsert=True,
        )
        for position, entry in enumerate(messages)
    ]
    # Trim leftovers from an earlier, longer snapshot of this session.
    writes.append(DeleteMany({"session_id": session_id, "idx": {"$gte": total}}))

    try:
        await self.db.session_messages.bulk_write(writes, ordered=False)
    except PyMongoError as e:
        logger.warning("Failed to persist session %s snapshot: %s", session_id, e)
async def list_sessions(
    self, user_id: str, *, include_deleted: bool = False
) -> list[dict[str, Any]]:
    """Return session documents for ``user_id``, most recently updated first.

    The literal user id ``"dev"`` is not filtered by owner and therefore
    matches every session. Sessions whose ``visibility`` is ``"deleted"``
    are excluded unless ``include_deleted`` is true. Returns ``[]`` when the
    store is not ready.
    """
    if not self._ready():
        return []

    criteria: dict[str, Any] = {} if user_id == "dev" else {"user_id": user_id}
    if not include_deleted:
        criteria["visibility"] = {"$ne": "deleted"}

    rows: list[dict[str, Any]] = []
    async for row in self.db.sessions.find(criteria).sort("updated_at", -1):
        rows.append(row)
    return rows
logger.debug("Failed to append event for %s: %s", session_id, e) + return None + + async def load_events_after(self, session_id: str, after_seq: int = 0) -> list[dict[str, Any]]: + if not self._ready(): + return [] + cursor = self.db.session_events.find( + {"session_id": session_id, "seq": {"$gt": int(after_seq or 0)}} + ).sort("seq", 1) + return [row async for row in cursor] + + async def append_trace_message( + self, session_id: str, message: dict[str, Any], source: str = "message" + ) -> int | None: + if not self._ready(): + return None + try: + seq = await self._next_seq(f"trace:{session_id}") + await self.db.session_trace_messages.insert_one( + { + "_id": _doc_id(session_id, seq), + "session_id": session_id, + "seq": seq, + "role": message.get("role"), + "message": _safe_message_doc(message), + "source": source, + "created_at": _now(), + } + ) + return seq + except PyMongoError as e: + logger.debug("Failed to append trace message for %s: %s", session_id, e) + return None + + async def get_quota(self, user_id: str, day: str) -> int | None: + if not self._ready(): + return None + doc = await self.db.claude_quotas.find_one({"_id": f"{user_id}:{day}"}) + return int(doc.get("count", 0)) if doc else 0 + + async def try_increment_quota(self, user_id: str, day: str, cap: int) -> int | None: + if not self._ready(): + return None + key = f"{user_id}:{day}" + now = _now() + try: + await self.db.claude_quotas.insert_one( + { + "_id": key, + "user_id": user_id, + "day": day, + "count": 1, + "updated_at": now, + } + ) + return 1 + except DuplicateKeyError: + pass + doc = await self.db.claude_quotas.find_one_and_update( + {"_id": key, "count": {"$lt": cap}}, + {"$inc": {"count": 1}, "$set": {"updated_at": now}}, + return_document=ReturnDocument.AFTER, + ) + return int(doc["count"]) if doc else None + + async def refund_quota(self, user_id: str, day: str) -> None: + if not self._ready(): + return + await self.db.claude_quotas.update_one( + {"_id": f"{user_id}:{day}", "count": 
def get_session_store() -> NoopSessionStore | MongoSessionStore:
    """Return the process-wide session store, creating it on first use.

    A ``MongoSessionStore`` is built when ``MONGODB_URI`` is set and
    non-empty (database name from ``MONGODB_DB``, default ``"ml-intern"``);
    otherwise a ``NoopSessionStore`` is used. The store is only constructed
    here, not initialised.
    """
    global _store
    if _store is not None:
        return _store

    mongo_uri = os.environ.get("MONGODB_URI")
    if mongo_uri:
        database = os.environ.get("MONGODB_DB", "ml-intern")
        _store = MongoSessionStore(mongo_uri, database)
    else:
        _store = NoopSessionStore()
    return _store
agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git # from agent.tools.private_hf_repo_tools import ( @@ -310,6 +312,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=HF_PAPERS_TOOL_SPEC["parameters"], handler=hf_papers_handler, ), + ToolSpec( + name=WEB_SEARCH_TOOL_SPEC["name"], + description=WEB_SEARCH_TOOL_SPEC["description"], + parameters=WEB_SEARCH_TOOL_SPEC["parameters"], + handler=web_search_handler, + ), # Dataset inspection tool (unified) ToolSpec( name=HF_INSPECT_DATASET_TOOL_SPEC["name"], @@ -324,6 +332,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]: parameters=PLAN_TOOL_SPEC["parameters"], handler=plan_tool_handler, ), + ToolSpec( + name=NOTIFY_TOOL_SPEC["name"], + description=NOTIFY_TOOL_SPEC["description"], + parameters=NOTIFY_TOOL_SPEC["parameters"], + handler=notify_handler, + ), ToolSpec( name=HF_JOBS_TOOL_SPEC["name"], description=HF_JOBS_TOOL_SPEC["description"], diff --git a/agent/main.py b/agent/main.py index f601ab545687259ddbca3895235c5eb7fb31a027..f500cc5fe756e04e5fac90184b91c5d5198a51aa 100644 --- a/agent/main.py +++ b/agent/main.py @@ -23,8 +23,10 @@ from prompt_toolkit import PromptSession from agent.config import load_config from agent.core.agent_loop import submission_loop from agent.core import model_switcher +from agent.core.hf_tokens import resolve_hf_token from agent.core.session import OpType from agent.core.tools import ToolRouter +from agent.messaging.gateway import NotificationGateway from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.terminal_display import ( get_console, @@ -69,26 +71,15 @@ def _safe_get_args(arguments: dict) -> dict: return args if isinstance(args, dict) else {} -def _get_hf_token() -> str | None: - """Get HF token from environment, huggingface_hub API, or cached token file.""" - token = 
os.environ.get("HF_TOKEN") - if token: - return token +def _get_hf_user(token: str | None) -> str | None: + """Resolve the HF username for a token, if available.""" + if not token: + return None try: from huggingface_hub import HfApi - api = HfApi() - token = api.token - if token: - return token + return HfApi(token=token).whoami().get("name") except Exception: - pass - # Fallback: read the cached token file directly - token_path = Path.home() / ".cache" / "huggingface" / "token" - if token_path.exists(): - token = token_path.read_text().strip() - if token: - return token - return None + return None async def _prompt_and_save_hf_token(prompt_session: PromptSession) -> str: @@ -342,6 +333,9 @@ async def event_listener( stream_buf.discard() print_turn_complete() print_plan() + session = session_holder[0] if session_holder else None + if session is not None: + await session.send_deferred_turn_complete_notification(event) turn_complete_event.set() elif event.event_type == "interrupted": shimmer.stop() @@ -758,7 +752,7 @@ async def _handle_slash_command( normalized = arg.removeprefix("huggingface/") session = session_holder[0] if session_holder else None await model_switcher.probe_and_switch_model( - normalized, config, session, console, _get_hf_token(), + normalized, config, session, console, resolve_hf_token(), ) return None @@ -817,7 +811,7 @@ async def _handle_slash_command( return None -async def main(): +async def main(model: str | None = None): """Interactive chat with the agent""" # Clear screen @@ -827,19 +821,16 @@ async def main(): prompt_session = PromptSession() # HF token — required, prompt if missing - hf_token = _get_hf_token() + hf_token = resolve_hf_token() if not hf_token: hf_token = await _prompt_and_save_hf_token(prompt_session) - config = load_config(CLI_CONFIG_PATH) + config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) + if model: + config.model_name = model # Resolve username for banner - hf_user = None - try: - from huggingface_hub 
import HfApi - hf_user = HfApi(token=hf_token).whoami().get("name") - except Exception: - pass + hf_user = _get_hf_user(hf_token) print_banner(model=config.model_name, hf_user=hf_user) @@ -857,6 +848,8 @@ async def main(): turn_complete_event.set() ready_event = asyncio.Event() + notification_gateway = NotificationGateway(config.messaging) + await notification_gateway.start() # Create tool router with local mode tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True) @@ -871,8 +864,12 @@ async def main(): tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, + user_id=hf_user, local_mode=True, stream=True, + notification_gateway=notification_gateway, + notification_destinations=config.messaging.default_auto_destinations(), + defer_turn_complete_notification=True, ) ) @@ -1028,6 +1025,8 @@ async def main(): agent_task.cancel() # Agent didn't shut down cleanly — close MCP explicitly await tool_router.__aexit__(None, None, None) + finally: + await notification_gateway.close() # Now safe to cancel the listener (agent is done emitting events) listener_task.cancel() @@ -1047,15 +1046,18 @@ async def headless_main( logging.basicConfig(level=logging.WARNING) _configure_runtime_logging() - hf_token = _get_hf_token() + hf_token = resolve_hf_token() if not hf_token: print("ERROR: No HF token found. 
Set HF_TOKEN or run `huggingface-cli login`.", file=sys.stderr) sys.exit(1) print(f"HF token loaded", file=sys.stderr) - config = load_config(CLI_CONFIG_PATH) + config = load_config(CLI_CONFIG_PATH, include_user_defaults=True) config.yolo_mode = True # Auto-approve everything in headless mode + notification_gateway = NotificationGateway(config.messaging) + await notification_gateway.start() + hf_user = _get_hf_user(hf_token) if model: config.model_name = model @@ -1082,8 +1084,12 @@ async def headless_main( tool_router=tool_router, session_holder=session_holder, hf_token=hf_token, + user_id=hf_user, local_mode=True, stream=stream, + notification_gateway=notification_gateway, + notification_destinations=config.messaging.default_auto_destinations(), + defer_turn_complete_notification=True, ) ) @@ -1209,6 +1215,10 @@ async def headless_main( stream_buf.discard() history_size = event.data.get("history_size", "?") if event.data else "?" print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr) + if event.event_type == "turn_complete": + session = session_holder[0] if session_holder else None + if session is not None: + await session.send_deferred_turn_complete_notification(event) break # Shutdown @@ -1222,6 +1232,8 @@ async def headless_main( except asyncio.TimeoutError: agent_task.cancel() await tool_router.__aexit__(None, None, None) + finally: + await notification_gateway.close() def cli(): @@ -1252,7 +1264,7 @@ def cli(): max_iter = 10_000 # effectively unlimited asyncio.run(headless_main(args.prompt, model=args.model, max_iterations=max_iter, stream=not args.no_stream)) else: - asyncio.run(main()) + asyncio.run(main(model=args.model)) except KeyboardInterrupt: print("\n\nGoodbye!") diff --git a/agent/messaging/__init__.py b/agent/messaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c399d254e30fcbce555d6f51b810440b1171ec1a --- /dev/null +++ b/agent/messaging/__init__.py @@ -0,0 +1,15 @@ +from 
class NotificationProvider(ABC):
    """Abstract base for one delivery backend (e.g. Slack).

    Implementations deliver a single ``NotificationRequest`` to a single
    configured destination over the HTTP client they are given.
    """

    # Short provider identifier; the gateway copies it into the ``provider``
    # field of the failure results it builds.
    provider_name: str

    @abstractmethod
    async def send(
        self,
        client: httpx.AsyncClient,
        destination_name: str,
        destination: DestinationConfig,
        request: NotificationRequest,
    ) -> NotificationResult:
        """Deliver a notification to one destination.

        Args:
            client: HTTP client supplied by the caller.
            destination_name: Key of the destination in the messaging config.
            destination: Configuration of that destination.
            request: The notification to deliver.

        Returns:
            The delivery outcome.

        Raises:
            RetryableNotificationError: Transient failure; the gateway
                retries these after a delay.
            NotificationError: Failure that should not be retried; the
                gateway turns it into a failed result.
        """
async def close(self) -> None:
    """Flush queued notifications, then stop the worker and HTTP client.

    Does nothing when messaging is disabled. The worker task and the HTTP
    client are released in ``finally`` blocks, so a failing or cancelled
    ``flush()`` no longer leaks them; the original exception still
    propagates to the caller.
    """
    if not self.enabled:
        return
    try:
        await self.flush()
    finally:
        try:
            worker = self._worker_task
            if worker is not None:
                worker.cancel()
                try:
                    await worker
                except asyncio.CancelledError:
                    # Expected: we cancelled the worker ourselves.
                    pass
                finally:
                    self._worker_task = None
        finally:
            client = self._client
            if client is not None:
                # Clear the reference first so a failing aclose() cannot
                # leave a half-closed client behind.
                self._client = None
                await client.aclose()
async def _send_with_retries(
    self,
    provider: NotificationProvider,
    destination_name: str,
    destination,
    request: NotificationRequest,
) -> NotificationResult:
    """Deliver ``request`` through ``provider``, retrying transient failures.

    A ``RetryableNotificationError`` is retried after each delay in
    ``_RETRY_DELAYS``; once the delays are used up, or on any other
    ``NotificationError``, a failed ``NotificationResult`` carrying the
    error text is returned. When the gateway has no shared HTTP client, a
    temporary one is created and closed before returning.
    """

    def failure(reason: Exception) -> NotificationResult:
        return NotificationResult(
            destination=destination_name,
            ok=False,
            provider=provider.provider_name,
            error=str(reason),
        )

    client = self._client or httpx.AsyncClient(timeout=10.0)
    owns_client = self._client is None
    attempt = 0
    try:
        # Every iteration either returns or sleeps and retries, so the loop
        # can only be left through a return.
        while True:
            try:
                return await provider.send(client, destination_name, destination, request)
            except RetryableNotificationError as exc:
                if attempt >= len(_RETRY_DELAYS):
                    return failure(exc)
                delay = _RETRY_DELAYS[attempt]
                logger.warning(
                    "Retrying notification to %s in %ss after transient error: %s",
                    destination_name,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)
                attempt += 1
            except NotificationError as exc:
                return failure(exc)
    finally:
        if owns_client:
            await client.aclose()
class MessagingConfig(BaseModel):
    """Top-level messaging settings: master switch, auto events, destinations."""

    enabled: bool = False
    auto_event_types: list[str] = Field(
        default_factory=lambda: ["approval_required", "error", "turn_complete"]
    )
    destinations: dict[str, DestinationConfig] = Field(default_factory=dict)

    @field_validator("destinations")
    @classmethod
    def _validate_destination_names(
        cls, destinations: dict[str, DestinationConfig]
    ) -> dict[str, DestinationConfig]:
        """Reject empty names and names with characters outside the allowed set."""
        for name in destinations:
            if not name or not set(name) <= _DESTINATION_NAME_CHARS:
                raise ValueError(
                    "destination names must use lowercase letters, digits, '.', '_' or '-'"
                )
        return destinations

    @field_validator("auto_event_types")
    @classmethod
    def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]:
        """Reject unsupported event types and drop duplicates, keeping order."""
        if not event_types:
            return []
        for candidate in event_types:
            if candidate not in SUPPORTED_AUTO_EVENT_TYPES:
                raise ValueError(
                    f"unsupported auto event type '{candidate}'"
                )
        # dict.fromkeys keeps the first occurrence of each entry, in order.
        return list(dict.fromkeys(event_types))

    @model_validator(mode="after")
    def _require_destinations_when_enabled(self) -> "MessagingConfig":
        """Refuse an enabled config that has nowhere to send to."""
        if not self.enabled or self.destinations:
            return self
        raise ValueError("messaging.enabled requires at least one destination")

    def get_destination(self, name: str) -> DestinationConfig | None:
        """Return the destination registered under ``name``, or ``None``."""
        return self.destinations.get(name)

    def can_agent_tool_send(self, name: str) -> bool:
        """Tell whether ``name`` exists and allows agent-tool sends."""
        target = self.get_destination(name)
        if not target:
            return False
        return bool(target.allow_agent_tool)

    def can_auto_send(self, name: str) -> bool:
        """Tell whether ``name`` exists and allows automatic event sends."""
        target = self.get_destination(name)
        if not target:
            return False
        return bool(target.allow_auto_events)

    def default_auto_destinations(self) -> list[str]:
        """List destinations that allow auto events; empty when disabled."""
        if not self.enabled:
            return []
        return list(filter(self.can_auto_send, self.destinations))
RetryableNotificationError, +) +from agent.messaging.models import ( + NotificationRequest, + NotificationResult, + SlackDestinationConfig, +) + +_SEVERITY_PREFIX = { + "info": "[INFO]", + "success": "[SUCCESS]", + "warning": "[WARNING]", + "error": "[ERROR]", +} + + +def _format_slack_mrkdwn(content: str) -> str: + """Convert common Markdown constructs to Slack's mrkdwn syntax.""" + if not content: + return content + + placeholders: dict[str, str] = {} + placeholder_index = 0 + + def placeholder(value: str) -> str: + nonlocal placeholder_index + key = f"\x00SLACK{placeholder_index}\x00" + placeholder_index += 1 + placeholders[key] = value + return key + + text = content + + # Protect code before any formatting conversion. Slack's mrkdwn ignores + # formatting inside backticks, so these regions should stay byte-for-byte. + text = re.sub( + r"(```(?:[^\n]*\n)?[\s\S]*?```)", + lambda match: placeholder(match.group(0)), + text, + ) + text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text) + + def convert_markdown_link(match: re.Match[str]) -> str: + label = match.group(1) + url = match.group(2).strip() + if url.startswith("<") and url.endswith(">"): + url = url[1:-1].strip() + return placeholder(f"<{url}|{label}>") + + text = re.sub( + r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)", + convert_markdown_link, + text, + ) + + # Preserve existing Slack entities and manual mrkdwn links before escaping. 
+ text = re.sub( + r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)", + lambda match: placeholder(match.group(1)), + text, + ) + text = re.sub( + r"^(>+\s)", + lambda match: placeholder(match.group(0)), + text, + flags=re.MULTILINE, + ) + + text = text.replace("&", "&").replace("<", "<").replace(">", ">") + text = text.replace("&", "&").replace("<", "<").replace(">", ">") + + def convert_header(match: re.Match[str]) -> str: + header = match.group(1).strip() + header = re.sub(r"\*\*(.+?)\*\*", r"\1", header) + return placeholder(f"*{header}*") + + text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE) + text = re.sub( + r"\*\*\*(.+?)\*\*\*", + lambda match: placeholder(f"*_{match.group(1)}_*"), + text, + ) + text = re.sub( + r"\*\*(.+?)\*\*", + lambda match: placeholder(f"*{match.group(1)}*"), + text, + ) + text = re.sub( + r"(? str: + lines: list[str] = [] + prefix = _SEVERITY_PREFIX[request.severity] + if request.title: + lines.append(f"{prefix} {request.title}") + else: + lines.append(prefix) + lines.append(request.message) + for key, value in request.metadata.items(): + lines.append(f"{key}: {value}") + return _format_slack_mrkdwn("\n".join(lines)) + + +class SlackProvider(NotificationProvider): + provider_name = "slack" + + async def send( + self, + client: httpx.AsyncClient, + destination_name: str, + destination: SlackDestinationConfig, + request: NotificationRequest, + ) -> NotificationResult: + payload = { + "channel": destination.channel, + "text": _format_text(request), + "mrkdwn": True, + "unfurl_links": False, + "unfurl_media": False, + } + if destination.username: + payload["username"] = destination.username + if destination.icon_emoji: + payload["icon_emoji"] = destination.icon_emoji + + try: + response = await client.post( + "https://slack.com/api/chat.postMessage", + headers={ + "Authorization": f"Bearer {destination.token}", + "Content-Type": "application/json; charset=utf-8", + }, + content=json.dumps(payload), + ) + except 
httpx.TimeoutException as exc: + raise RetryableNotificationError("Slack request timed out") from exc + except httpx.TransportError as exc: + raise RetryableNotificationError("Slack transport error") from exc + + if response.status_code == 429 or response.status_code >= 500: + raise RetryableNotificationError( + f"Slack HTTP {response.status_code}" + ) + if response.status_code >= 400: + raise NotificationError(f"Slack HTTP {response.status_code}") + + try: + data = response.json() + except ValueError as exc: + raise RetryableNotificationError("Slack returned invalid JSON") from exc + + if not data.get("ok"): + error = str(data.get("error") or "unknown_error") + if error == "ratelimited": + raise RetryableNotificationError(error) + raise NotificationError(error) + + return NotificationResult( + destination=destination_name, + ok=True, + provider=self.provider_name, + external_id=str(data.get("ts") or ""), + error=None, + ) diff --git a/agent/prompts/system_prompt_v3.yaml b/agent/prompts/system_prompt_v3.yaml index befa56bf7a68d6a9a776fa64d8e27e28d3131cbe..cb63c901e699f68687353121f4d166a8a059efb5 100644 --- a/agent/prompts/system_prompt_v3.yaml +++ b/agent/prompts/system_prompt_v3.yaml @@ -1,5 +1,5 @@ system_prompt: | - You are Hugging Face Agent, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face ecosystem. + You are ML Intern, an ML engineering assistant with {{ num_tools }} tools for training, fine-tuning, data processing, inference, and evaluation on the Hugging Face (HF) ecosystem. Your goal is to complete what the user requested with zero errors. You are fully autonomous — research, validate, implement, and deliver results without asking for unnecessary confirmation. @@ -28,7 +28,7 @@ system_prompt: | # Mistakes you WILL make without research - HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. 
Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio parameter names (e.g. `run_name` instead of `name`). Fix: read a current example script first. + HALLUCINATED IMPORTS: You will import from modules that were renamed or removed. Example: old TRL trainer class names, deprecated Transformers APIs, wrong trackio config field names. Fix: read a current example script first. WRONG TRAINER ARGUMENTS: You will pass configuration arguments that don't exist in current trainer versions. Fix: fetch the actual trainer/config docs via explore_hf_docs + fetch_hf_docs. @@ -60,6 +60,38 @@ system_prompt: | DPO: "prompt", "chosen", "rejected" GRPO: "prompt" + # Trackio + + Trackio is natively integrated with Transformers Trainer and all TRL trainers — the built-in TrackioCallback handles init/log/finish. In TrainingArguments/SFTConfig/DPOConfig/GRPOConfig set: + report_to="trackio" + run_name="<descriptive-run-name>" # e.g. "sft_qwen3-4b_lr2e-5_bs128" + project="<project-name>" # keeps related runs grouped so you can compare them + trackio_space_id="<username>/mlintern-<8-char-id>" # creates a public dashboard Space + `project` and `trackio_space_id` can also be set via TRACKIO_PROJECT / TRACKIO_SPACE_ID env vars. + + Alerts are how iterations decide what to change. Use trackio.alert(title, text, level) at every decision point in training. Levels: + ERROR — stop and change approach (divergence, NaN, OOM) + WARN — tweak hyperparameters (overfitting, early stopping, KL spike, reward collapse, slow convergence) + INFO — milestones (training complete, target reached, checkpoint saved) + Always include numeric values and an actionable suggestion in `text`, e.g. "loss=12.4 at step 200 — lr likely too high, try ×0.1". A future call must be able to parse it and act on it. 
+ + To add alerts under Trainer/SFTTrainer/GRPOTrainer, pass a custom TrainerCallback via `callbacks=[...]` that calls trackio.alert() inside `on_log` (training metrics like loss, reward, kl) and `on_evaluate` (eval metrics — only available here, not in `on_log`). Keep each `if` simple: one metric, one threshold. Conditions stay easy to adjust between runs. + + Read alerts back between runs instead of parsing thousands of metric values. CLI — always use --json: + trackio get alerts --project <project> --run <run> --json + trackio get alerts --project <project> --since <timestamp> --json # incremental polling + trackio get run --project <project> --run <run> --json + trackio get metric --project <project> --run <run> --metric <metric> --json + trackio list runs --project <project> --json + Python: api = trackio.Api(); api.alerts(<project>, run=<run>, since=<timestamp>); api.runs(<project>) (each run has .name, .config, .alerts()). + + Drive the next config from prior alerts: + diverged → lr × 0.1 + overfitting → weight_decay × 10 or reduce capacity + early stopping → lr × 0.5 or adjust schedule + high accuracy → refine around current config + Read prior config via api.runs(...).config and only mutate keys the alerts justify changing. + # Data audit Before working with any dataset, audit it first. Do not assume you know what the data looks like — inspect it. @@ -75,7 +107,7 @@ system_prompt: | - Dataset format verified: [columns confirmed via hf_inspect_dataset/hub_repo_details] - push_to_hub=True and hub_model_id set - timeout: [value] (based on: [model size] on [hardware]) - - Trackio monitoring included and working + - Trackio monitoring included and deploying metrics to a public Space If you cannot fill in all items, stop and complete the missing steps first. @@ -156,6 +188,7 @@ system_prompt: | - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs. - For errors: state what went wrong, why, and what you're doing to fix it. - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity. + - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates. 
# Tool usage diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index 14ef45669bc443c1c005ddde69b4205eb02f46cb..65c793cbaad3b2f74eacaf1da6038ff0bef893d9 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -20,6 +20,7 @@ from agent.tools.github_read_file import ( ) from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler from agent.tools.types import ToolResult +from agent.tools.web_search_tool import WEB_SEARCH_TOOL_SPEC, web_search_handler __all__ = [ "ToolResult", @@ -36,4 +37,6 @@ __all__ = [ "github_search_code_handler", "HF_INSPECT_DATASET_TOOL_SPEC", "hf_inspect_dataset_handler", + "WEB_SEARCH_TOOL_SPEC", + "web_search_handler", ] diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index c18d47e298e445dbc232181ca9bd9520b942fa80..6518fa3cbd3d286716e4f0aaae49512937564a7b 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -19,6 +19,7 @@ from huggingface_hub.utils import HfHubHTTPError from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace from agent.core.session import Event +from agent.tools.trackio_seed import ensure_trackio_dashboard from agent.tools.types import ToolResult logger = logging.getLogger(__name__) @@ -382,6 +383,31 @@ class HfJobsTool: "isError": True, } + async def _seed_trackio_dashboard(self, space_id: str) -> None: + """Idempotently install trackio dashboard files into *space_id* before + the job runs. Surfaces seed progress as tool_log events but never + raises — a seed failure should not block job submission, since trackio + often still works when the Space already has dashboard code from a + previous run. 
+ """ + loop = asyncio.get_running_loop() + + def _log(msg: str) -> None: + if self.session is None: + return + loop.call_soon_threadsafe( + self.session.event_queue.put_nowait, + Event(event_type="tool_log", data={"tool": "hf_jobs", "log": msg}), + ) + + try: + await asyncio.to_thread( + ensure_trackio_dashboard, space_id, self.hf_token, _log + ) + except Exception as e: + logger.warning(f"trackio dashboard seed failed for {space_id}: {e}") + _log(f"trackio dashboard seed failed: {e}") + async def _wait_for_job_completion( self, job_id: str, namespace: Optional[str] = None ) -> tuple[str, list[str]]: @@ -533,11 +559,24 @@ class HfJobsTool: # Run the job flavor = args.get("hardware_flavor", "cpu-basic") timeout_str = args.get("timeout", "30m") + + # Trackio: agent-declared space + project become env vars on the job + # so trackio.init() picks them up automatically. We also surface them + # in tool_state_change so the frontend can embed the dashboard. + env_dict = _add_default_env(args.get("env")) + trackio_space_id = args.get("trackio_space_id") + trackio_project = args.get("trackio_project") + if trackio_space_id: + env_dict["TRACKIO_SPACE_ID"] = trackio_space_id + await self._seed_trackio_dashboard(trackio_space_id) + if trackio_project: + env_dict["TRACKIO_PROJECT"] = trackio_project + job = await _async_call( self.api.run_job, image=image, command=command, - env=_add_default_env(args.get("env")), + env=env_dict, secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=flavor, timeout=timeout_str, @@ -550,16 +589,18 @@ class HfJobsTool: # Send job URL immediately after job creation (before waiting for completion) if self.session and self.tool_call_id: + state_data: Dict[str, Any] = { + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": "running", + "jobUrl": job.url, + } + if trackio_space_id: + state_data["trackioSpaceId"] = trackio_space_id + if trackio_project: + state_data["trackioProject"] = trackio_project await 
self.session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": "running", - "jobUrl": job.url, - }, - ) + Event(event_type="tool_state_change", data=state_data) ) # Telemetry: job submission + completion (infra consumption signal). @@ -594,16 +635,18 @@ class HfJobsTool: # Notify frontend of final status if self.session and self.tool_call_id: + final_data: Dict[str, Any] = { + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": final_status.lower(), + "jobUrl": job.url, + } + if trackio_space_id: + final_data["trackioSpaceId"] = trackio_space_id + if trackio_project: + final_data["trackioProject"] = trackio_project await self.session.send_event( - Event( - event_type="tool_state_change", - data={ - "tool_call_id": self.tool_call_id, - "tool": "hf_jobs", - "state": final_status.lower(), - "jobUrl": job.url, - }, - ) + Event(event_type="tool_state_change", data=final_data) ) # Filter out UV package installation output @@ -977,7 +1020,10 @@ HF_JOBS_TOOL_SPEC = { "- You MUST have validated dataset format via hf_inspect_dataset or hub_repo_details.\n" "- Training config MUST include push_to_hub=True and hub_model_id. " "Job storage is EPHEMERAL — all files are deleted when the job ends. Without push_to_hub, trained models are lost permanently.\n" - "- Include trackio monitoring and provide the dashboard URL to the user.\n\n" + "- Include trackio monitoring and provide the dashboard URL to the user. " + "When the script uses report_to='trackio', also pass `trackio_space_id` " + "(e.g. '/mlintern-<8char>') and `trackio_project` as tool args — " + "they are injected as TRACKIO_SPACE_ID/TRACKIO_PROJECT env vars and let the UI embed the live dashboard.\n\n" "BATCH/ABLATION JOBS: Submit ONE job first. Check logs to confirm it starts training successfully. " "Only then submit the remaining jobs. 
Never submit all at once — if there's a bug, all jobs fail.\n\n" "Operations: run, ps, logs, inspect, cancel, scheduled run/ps/inspect/delete/suspend/resume.\n\n" @@ -1060,6 +1106,26 @@ HF_JOBS_TOOL_SPEC = { "type": "object", "description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.", }, + "trackio_space_id": { + "type": "string", + "description": ( + "Optional. The HF Space hosting the trackio dashboard for this run " + "(e.g. '/mlintern-<8char>', under YOUR HF namespace). " + "Injected as TRACKIO_SPACE_ID env var and used by the UI to embed " + "the live dashboard. Set this whenever the script uses " + "report_to='trackio'. The Space is auto-created and seeded with the " + "trackio dashboard before the job starts — DO NOT pre-create it via " + "hf_repo_git, that produces an empty Space that breaks the embed." + ), + }, + "trackio_project": { + "type": "string", + "description": ( + "Optional. The trackio project name to log this run under. " + "Injected as TRACKIO_PROJECT env var and used by the UI to filter " + "the embedded dashboard to this project." + ), + }, "namespace": { "type": "string", "description": ( diff --git a/agent/tools/notify_tool.py b/agent/tools/notify_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..f926d5a58d5f3c4b877cb8792f812f6e4fa322a7 --- /dev/null +++ b/agent/tools/notify_tool.py @@ -0,0 +1,108 @@ +from typing import Any + +from agent.messaging.models import NotificationRequest + +NOTIFY_TOOL_SPEC = { + "name": "notify", + "description": ( + "Send an out-of-band notification to configured messaging destinations. " + "Use this only when the user explicitly asked for proactive notifications " + "or when the task requires reporting progress outside the chat. " + "Destinations must be named server-side configs such as 'slack.ops'." 
+ ), + "parameters": { + "type": "object", + "properties": { + "destinations": { + "type": "array", + "description": "Named messaging destinations to notify.", + "items": {"type": "string"}, + "minItems": 1, + }, + "message": { + "type": "string", + "description": "Main notification body.", + }, + "title": { + "type": "string", + "description": "Optional short title line.", + }, + "severity": { + "type": "string", + "enum": ["info", "success", "warning", "error"], + "description": "Notification severity label.", + }, + }, + "required": ["destinations", "message"], + }, +} + + +async def notify_handler( + arguments: dict[str, Any], session=None, **_kwargs +) -> tuple[str, bool]: + if session is None or session.notification_gateway is None: + return "Messaging is not configured for this session.", False + + raw_destinations = arguments.get("destinations", []) + if not isinstance(raw_destinations, list) or not raw_destinations: + return "destinations must be a non-empty array of destination names.", False + + destinations: list[str] = [] + seen: set[str] = set() + for raw_name in raw_destinations: + if not isinstance(raw_name, str): + return "Each destination must be a string.", False + name = raw_name.strip() + if not name: + return "Destination names must not be empty.", False + if name not in seen: + destinations.append(name) + seen.add(name) + + disallowed = [ + name + for name in destinations + if not session.config.messaging.can_agent_tool_send(name) + ] + if disallowed: + return ( + "These destinations are unavailable for the notify tool: " + + ", ".join(disallowed) + ), False + + message = arguments.get("message", "") + if not isinstance(message, str) or not message.strip(): + return "message must be a non-empty string.", False + + title = arguments.get("title") + severity = arguments.get("severity", "info") + if title is not None and not isinstance(title, str): + return "title must be a string when provided.", False + if severity not in {"info", "success", 
"warning", "error"}: + return "severity must be one of: info, success, warning, error.", False + + requests = [ + NotificationRequest( + destination=name, + title=title, + message=message, + severity=severity, + metadata={ + "session_id": session.session_id, + "model": session.config.model_name, + }, + ) + for name in destinations + ] + results = await session.notification_gateway.send_many(requests) + + lines = [] + all_ok = True + for result in results: + if result.ok: + lines.append(f"{result.destination}: sent") + else: + all_ok = False + lines.append(f"{result.destination}: failed ({result.error})") + return "\n".join(lines), all_ok diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py index 18ae2ad6513d0ad98dff23e369f9222c59ef699c..11131766ee0262ba71805ff6a743c5da736e2386 100644 --- a/agent/tools/research_tool.py +++ b/agent/tools/research_tool.py @@ -37,6 +37,7 @@ RESEARCH_TOOL_NAMES = { "github_find_examples", "github_list_repos", "github_read_file", + "web_search", "hf_inspect_dataset", "hf_repo_files", } @@ -102,6 +103,8 @@ tell you what actually works. - `explore_hf_docs(endpoint)`: Search docs for a library. Endpoints: trl, transformers, datasets, peft, accelerate, trackio, vllm, inference-endpoints, etc. - `fetch_hf_docs(url)`: Fetch full page content from explore results - `find_hf_api(query=..., tag=...)`: Find REST API endpoints +- `web_search(query=..., allowed_domains=[...], blocked_domains=[...])`: + Search the current web when papers/docs/GitHub are not enough. 
## Hub repo inspection - `hf_repo_files`: List/read files in any HF repo (model, dataset, space) @@ -306,8 +309,10 @@ async def research_handler( # ── Doom-loop detection ── doom_prompt = check_for_doom_loop(messages) if doom_prompt: - logger.warning("Research sub-agent doom loop detected at iteration %d", _iteration) - await _log("Doom loop detected — injecting corrective prompt") + logger.warning( + "Research sub-agent repetition guard activated at iteration %d", + _iteration, + ) messages.append(Message(role="user", content=doom_prompt)) # ── Context budget: warn at 75%, hard-stop at 95% ── @@ -424,7 +429,7 @@ async def research_handler( await _log(f"▸ {tool_name} {args_str}") output, _success = await session.tool_router.call_tool( - tool_name, tool_args, session=session + tool_name, tool_args, session=session, tool_call_id=tc.id ) _tool_uses += 1 await _log(f"tools:{_tool_uses}") diff --git a/agent/tools/sandbox_client.py b/agent/tools/sandbox_client.py index 16982c76fe62dc66ac9dbd88acafdb80fc2b2a0f..967d946c12d37c25d76084e02a72548a5fa22bc7 100644 --- a/agent/tools/sandbox_client.py +++ b/agent/tools/sandbox_client.py @@ -37,6 +37,7 @@ Tools: bash, read, write, edit, upload from __future__ import annotations import io +import secrets as secrets_lib import sys import time import uuid @@ -99,8 +100,8 @@ CMD ["python", "sandbox_server.py"] _SANDBOX_SERVER = '''\ """Minimal FastAPI server for sandbox operations.""" -import os, subprocess, pathlib, signal, threading, re, tempfile -from fastapi import FastAPI +import hmac, os, subprocess, pathlib, signal, threading, re, tempfile +from fastapi import Depends, FastAPI, HTTPException, Request from pydantic import BaseModel from typing import Optional import uvicorn @@ -156,6 +157,22 @@ def _atomic_write(path: pathlib.Path, content: str): app = FastAPI() +def _expected_api_token() -> str: + return os.environ.get("SANDBOX_API_TOKEN") or os.environ.get("HF_TOKEN") or "" + +def _require_auth(request: Request) -> None: + 
expected = _expected_api_token() + if not expected: + raise HTTPException(status_code=503, detail="Sandbox API token not configured") + auth_header = request.headers.get("authorization", "") + scheme, _, supplied = auth_header.partition(" ") + if scheme.lower() != "bearer" or not supplied: + raise HTTPException(status_code=401, detail="Missing bearer token") + if not hmac.compare_digest(supplied, expected): + raise HTTPException(status_code=401, detail="Invalid bearer token") + +_AUTH = [Depends(_require_auth)] + # Track active bash processes so they can be killed on cancel _active_procs = {} # pid -> subprocess.Popen _proc_lock = threading.Lock() @@ -344,7 +361,7 @@ def _validate_python(content, path=""): def health(): return {"status": "ok"} -@app.post("/api/bash") +@app.post("/api/bash", dependencies=_AUTH) def bash(req: BashReq): try: proc = subprocess.Popen( @@ -371,7 +388,7 @@ def bash(req: BashReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/kill") +@app.post("/api/kill", dependencies=_AUTH) def kill_all(): """Kill all active bash processes. 
Called when user cancels.""" with _proc_lock: @@ -389,7 +406,7 @@ def kill_all(): pass return {"success": True, "output": f"Killed {len(killed)} process(es): {killed}", "error": ""} -@app.post("/api/read") +@app.post("/api/read", dependencies=_AUTH) def read(req: ReadReq): try: p = pathlib.Path(req.path) @@ -406,7 +423,7 @@ def read(req: ReadReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/write") +@app.post("/api/write", dependencies=_AUTH) def write(req: WriteReq): try: p = pathlib.Path(req.path) @@ -420,7 +437,7 @@ def write(req: WriteReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/edit") +@app.post("/api/edit", dependencies=_AUTH) def edit(req: EditReq): try: p = pathlib.Path(req.path) @@ -447,7 +464,7 @@ def edit(req: EditReq): except Exception as e: return {"success": False, "output": "", "error": str(e)} -@app.post("/api/exists") +@app.post("/api/exists", dependencies=_AUTH) def exists(req: ExistsReq): return {"success": True, "output": str(pathlib.Path(req.path).exists()).lower(), "error": ""} @@ -482,6 +499,7 @@ class Sandbox: space_id: str token: str | None = None + api_token: str | None = field(default=None, repr=False) work_dir: str = "/app" timeout: int = DEFAULT_TIMEOUT _owns_space: bool = field(default=False, repr=False) @@ -495,9 +513,10 @@ class Sandbox: # Trailing slash is critical: httpx resolves relative paths against base_url. # Without it, client.get("health") resolves to /health instead of /api/health. 
self._base_url = f"https://{slug}.hf.space/api/" + api_token = self.api_token or self.token self._client = httpx.Client( base_url=self._base_url, - headers={"Authorization": f"Bearer {self.token}"} if self.token else {}, + headers={"Authorization": f"Bearer {api_token}"} if api_token else {}, timeout=httpx.Timeout(MAX_TIMEOUT, connect=30), follow_redirects=True, ) @@ -563,6 +582,7 @@ class Sandbox: base = name or "sandbox" suffix = uuid.uuid4().hex[:8] space_id = f"{owner}/{base}-{suffix}" + sandbox_api_token = secrets_lib.token_urlsafe(32) _log(f"Creating sandbox: {space_id} (from {template})...") @@ -583,8 +603,9 @@ class Sandbox: # Inject secrets BEFORE uploading server files (which triggers rebuild). # Secrets added after a Space is running aren't available until restart, # so they must be set before the build/start cycle. - if secrets: - for key, val in secrets.items(): + sandbox_secrets = {**(secrets or {}), "SANDBOX_API_TOKEN": sandbox_api_token} + if sandbox_secrets: + for key, val in sandbox_secrets.items(): api.add_space_secret(space_id, key, val) # Upload sandbox server and Dockerfile (triggers rebuild) @@ -617,7 +638,12 @@ class Sandbox: _check_cancel() # Wait for the API server to be responsive (non-fatal) - sb = cls(space_id=space_id, token=token, _owns_space=True) + sb = cls( + space_id=space_id, + token=token, + api_token=sandbox_api_token, + _owns_space=True, + ) try: sb._wait_for_api(timeout=API_WAIT_TIMEOUT, log=_log) except TimeoutError as e: @@ -648,13 +674,24 @@ class Sandbox: log("Server files uploaded, rebuild triggered.") @classmethod - def connect(cls, space_id: str, *, token: str | None = None) -> Sandbox: + def connect( + cls, + space_id: str, + *, + token: str | None = None, + api_token: str | None = None, + ) -> Sandbox: """ Connect to an existing running Space. Does a health check to verify the Space is reachable. 
""" - sb = cls(space_id=space_id, token=token, _owns_space=False) + sb = cls( + space_id=space_id, + token=token, + api_token=api_token, + _owns_space=False, + ) sb._wait_for_api(timeout=60) return sb @@ -687,6 +724,10 @@ class Sandbox: ) print(f"Deleting sandbox: {self.space_id}...") self._hf_api.delete_repo(self.space_id, repo_type="space") + # Clear ownership so a second cleanup call (e.g. delete_session + + # _run_session.finally both fire) early-returns instead of retrying + # a 404 delete and emitting a spurious ERROR log. + self._owns_space = False self._client.close() print("Deleted.") diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py index 6dfd3db19876eed20e5703d3f959cc8a7cb317e8..a5c26acabee66f8baedb4a3b81062a39839a327b 100644 --- a/agent/tools/sandbox_tool.py +++ b/agent/tools/sandbox_tool.py @@ -12,13 +12,29 @@ a cpu-basic sandbox is auto-created (no approval needed). from __future__ import annotations import asyncio +import logging +import re import threading +from datetime import datetime, timedelta, timezone from typing import Any from huggingface_hub import HfApi, SpaceHardware from agent.core.session import Event from agent.tools.sandbox_client import Sandbox +from agent.tools.trackio_seed import ensure_trackio_dashboard + +logger = logging.getLogger(__name__) + +# Match the exact suffix pattern Sandbox.create produces: "sandbox-<8 hex>". +# Used to identify orphan sandboxes from prior sessions safely (won't match +# user-renamed lookalikes). +_SANDBOX_NAME_RE = re.compile(r"^sandbox-[a-f0-9]{8}$") + +# How stale a sandbox must be before we treat it as definitely orphan. +# Anything more recent could be tied to a still-live session in another tab, +# so we leave it alone. 
+_ORPHAN_STALE_AFTER = timedelta(hours=1) def _looks_like_path(script: str) -> bool: @@ -62,11 +78,89 @@ async def resolve_sandbox_script( return None, f"Failed to read {script} from sandbox: {e}" +async def _seed_trackio_dashboard_safe(session: Any, space_id: str) -> None: + """Idempotently seed *space_id* with trackio dashboard files using the + session's HF token. Logs progress, swallows errors — a failed seed should + not block sandbox creation.""" + if not session or not getattr(session, "hf_token", None): + return + loop = asyncio.get_running_loop() + + def _log(msg: str) -> None: + loop.call_soon_threadsafe( + session.event_queue.put_nowait, + Event(event_type="tool_log", data={"tool": "sandbox_create", "log": msg}), + ) + + try: + await asyncio.to_thread( + ensure_trackio_dashboard, space_id, session.hf_token, _log + ) + except Exception as e: + _log(f"trackio dashboard seed failed: {e}") + + # ── Tool name mapping (short agent names → Sandbox client names) ────── +def _cleanup_user_orphan_sandboxes( + api: HfApi, + owner: str, + log: Any, +) -> int: + """Delete stale ``sandbox-<8hex>`` Spaces in ``owner``'s account. + + "Stale" = not modified in the last hour. The naming pattern + staleness + filter together make this safe: + + * Naming: only matches ``sandbox-``, the + pattern Sandbox.create produces. Won't touch user-renamed Spaces. + * Staleness: anything modified in the last hour might still be tied + to a live session in another tab/replica, so we leave it alone. + + Runs blocking — call via ``asyncio.to_thread``. Best-effort: failures + are logged but never raised, so a flaky HF API never blocks creation. 
+ """ + cutoff = datetime.now(timezone.utc) - _ORPHAN_STALE_AFTER + deleted = 0 + try: + spaces = list(api.list_spaces(author=owner, limit=200)) + except Exception as e: + log(f"orphan sweep: list_spaces failed: {e}") + return 0 + + for space in spaces: + space_name = space.id.rsplit("/", 1)[-1] + if not _SANDBOX_NAME_RE.match(space_name): + continue + + last_mod = getattr(space, "lastModified", None) or getattr(space, "last_modified", None) + if isinstance(last_mod, str): + try: + last_mod = datetime.fromisoformat(last_mod.replace("Z", "+00:00")) + except ValueError: + last_mod = None + if last_mod and last_mod > cutoff: + # Recent — could be a concurrent live session. Skip. + continue + + try: + api.delete_repo(repo_id=space.id, repo_type="space") + deleted += 1 + log(f"orphan sweep: deleted {space.id}") + except Exception as e: + log(f"orphan sweep: failed to delete {space.id}: {e}") + + if deleted: + log(f"orphan sweep: cleaned up {deleted} stale sandbox(es) before create") + return deleted + + async def _ensure_sandbox( - session: Any, hardware: str = "cpu-basic", **create_kwargs + session: Any, + hardware: str = "cpu-basic", + extra_secrets: dict[str, str] | None = None, + **create_kwargs, ) -> tuple[Sandbox | None, str | None]: """ Ensure a sandbox exists on the session. Auto-creates with given hardware if needed. @@ -109,6 +203,23 @@ async def _ensure_sandbox( Event(event_type="tool_log", data={"tool": "sandbox", "log": msg}), ) + # Before we create a new sandbox, sweep this user's stale sandboxes from + # prior sessions. ``_cleanup_sandbox`` in session_manager fires only on + # clean session exit; pod kills, WebSocket drops, etc. leave orphans + # behind, and they accumulate on every new session forever (observed + # 2310 leaked across the Hub on 2026-04-27). Doing the cleanup here at + # session start = self-healing, no separate cron needed. 
+ # + # The 1h staleness filter is the safety: a sandbox modified in the last + # hour might still be tied to a live session in another tab, so we skip. + # Anything older has no realistic chance of being active given typical + # session lengths. + try: + await asyncio.to_thread(_cleanup_user_orphan_sandboxes, api, owner, _log) + except Exception as e: + # Cleanup is best-effort — never block sandbox_create on it. + _log(f"orphan sandbox sweep failed (non-fatal): {e}") + # Bridge asyncio cancel event to a threading.Event for the blocking create call. # We poll session._cancelled from the main loop in a background task and set # a threading.Event that Sandbox.create checks during its polling loops. @@ -120,11 +231,15 @@ async def _ensure_sandbox( watcher_task = asyncio.create_task(_watch_cancel()) + secrets: dict[str, str] = {"HF_TOKEN": token} + if extra_secrets: + secrets.update({k: v for k, v in extra_secrets.items() if v}) + kwargs = { "owner": owner, "hardware": hardware, "token": token, - "secrets": {"HF_TOKEN": token}, + "secrets": secrets, "log": _log, "cancel_event": cancel_flag, **create_kwargs, @@ -188,6 +303,9 @@ SANDBOX_CREATE_TOOL_SPEC = { "fp32 ≈ 4 bytes/param, plus ~20% overhead for optimizer states during training.\n" "Common picks: t4-small (16GB VRAM, fits ≤1-3B), a10g-small (24GB, ≤7B), a100-large (80GB, ≤30B). " "If the model won't fit, pick larger hardware upfront — OOM on a sandbox wastes time.\n\n" + "If you intend to run a training script in this sandbox that uses report_to='trackio', " + "pass `trackio_space_id` (e.g. 
'/mlintern-<8char>') and `trackio_project` so they " + "are set as TRACKIO_SPACE_ID/TRACKIO_PROJECT secrets in the sandbox and the UI can embed the live dashboard.\n\n" "Hardware: " + ", ".join([e.value for e in SpaceHardware]) + ".\n" ), "parameters": { @@ -204,16 +322,49 @@ SANDBOX_CREATE_TOOL_SPEC = { "type": "boolean", "description": "If true, create a private Space", }, + "trackio_space_id": { + "type": "string", + "description": ( + "Optional. The HF Space hosting the trackio dashboard for runs in this sandbox " + "(e.g. '/mlintern-<8char>', under YOUR HF namespace). Injected as " + "TRACKIO_SPACE_ID secret and surfaced to the UI. The Space is auto-created and " + "seeded with the trackio dashboard — DO NOT pre-create it via hf_repo_git, " + "that produces an empty Space that breaks the embed." + ), + }, + "trackio_project": { + "type": "string", + "description": ( + "Optional. The trackio project name. Injected as TRACKIO_PROJECT secret and " + "used by the UI to filter the embedded dashboard to this project." 
+ ), + }, }, }, } async def sandbox_create_handler( - args: dict[str, Any], session: Any = None + args: dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """Handle sandbox_create tool calls.""" hardware = args.get("hardware", "cpu-basic") + trackio_space_id = args.get("trackio_space_id") or None + trackio_project = args.get("trackio_project") or None + + async def _emit_trackio_state(sb: Sandbox) -> None: + """Tell the frontend which trackio dashboard to embed for this sandbox.""" + if not (session and tool_call_id and trackio_space_id): + return + data: dict[str, Any] = { + "tool_call_id": tool_call_id, + "tool": "sandbox_create", + "state": "running", + "trackioSpaceId": trackio_space_id, + } + if trackio_project: + data["trackioProject"] = trackio_project + await session.send_event(Event(event_type="tool_state_change", data=data)) # If sandbox already exists, return its info if session and getattr(session, "sandbox", None): @@ -226,6 +377,7 @@ async def sandbox_create_handler( "Hardware cannot be changed by calling sandbox_create again. " "Delete the existing sandbox first if you need a different tier." ) + await _emit_trackio_state(sb) return ( f"Sandbox already active: {sb.space_id}\n" f"URL: {sb.url}\n" @@ -233,18 +385,32 @@ async def sandbox_create_handler( f"Use bash/read/write/edit to interact with it." 
), True - create_kwargs = {} + create_kwargs: dict[str, Any] = {} if "private" in args: create_kwargs["private"] = args["private"] + extra_secrets: dict[str, str] = {} + if trackio_space_id: + extra_secrets["TRACKIO_SPACE_ID"] = trackio_space_id + await _seed_trackio_dashboard_safe(session, trackio_space_id) + if trackio_project: + extra_secrets["TRACKIO_PROJECT"] = trackio_project + try: - sb, error = await _ensure_sandbox(session, hardware=hardware, **create_kwargs) + sb, error = await _ensure_sandbox( + session, + hardware=hardware, + extra_secrets=extra_secrets or None, + **create_kwargs, + ) except Exception as e: return f"Failed to create sandbox: {e}", False if error: return error, False + await _emit_trackio_state(sb) + return ( f"Sandbox created: {sb.space_id}\n" f"URL: {sb.url}\n" diff --git a/agent/tools/trackio_seed.py b/agent/tools/trackio_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..1062e1b5eda2701833aad7c1c895727d7fbd191e --- /dev/null +++ b/agent/tools/trackio_seed.py @@ -0,0 +1,205 @@ +"""Seed an HF Space with the trackio dashboard. + +Background: when the agent creates a Space via `hf_repo_git create_repo` (or +the user pre-creates one), it ships with no app.py — so the iframe shows the +default Gradio "Get started" template instead of charts. Trackio's `init()` +detects the existing Space but does NOT auto-bootstrap dashboard files into it, +so the dashboard never materializes. + +This helper writes the three files trackio's runtime expects (README.md, +requirements.txt, app.py) into the Space, idempotently, BEFORE the job that +will call `trackio.init()` runs. We deliberately omit `hf_oauth: true` from +the README so the embedded iframe in ml-intern renders without a login click — +per-user privacy is enforced by namespace ownership instead. + +Beyond the dashboard files, the helper also creates the metrics bucket and +mounts it on the Space at `/data` (with `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` +Space variables). 
Without this, the running job writes metrics into a bucket +that the dashboard Space can't read, and the iframe shows "No projects". +""" + +from __future__ import annotations + +import io +from typing import Callable, Optional + +from huggingface_hub import ( + HfApi, + Volume, + add_space_variable, + create_bucket, + create_repo, +) +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + + +_README = """--- +title: Trackio Dashboard +emoji: 📊 +colorFrom: pink +colorTo: gray +sdk: gradio +app_file: app.py +pinned: false +tags: + - trackio +--- + +Embedded trackio dashboard for ml-intern runs. +""" + +_REQUIREMENTS = "trackio\n" +_APP_PY = "import trackio\ntrackio.show()\n" + +# ml-intern brand mark surfaced inside the trackio dashboard. Trackio reads +# `TRACKIO_LOGO_LIGHT_URL` / `TRACKIO_LOGO_DARK_URL` from Space variables and +# renders them in place of its own logo. We point at the publicly-resolvable +# copy on the smolagents/ml-intern Space repo so any seeded dashboard inherits +# the ml-intern branding without each user having to host the asset. +_LOGO_URL = ( + "https://huggingface.co/spaces/smolagents/ml-intern/" + "resolve/main/frontend/public/smolagents.webp" +) + +_FILES = { + "README.md": _README, + "requirements.txt": _REQUIREMENTS, + "app.py": _APP_PY, +} + + +def _already_seeded(api: HfApi, space_id: str) -> bool: + """Cheap check: does the Space already have a trackio dashboard app.py? + + Avoids re-uploading the same three files on every job submission. We look + for the literal `trackio.show` call which is the load-bearing line — any + other app.py shape (the default gradio shell, a stale custom one) means + we should re-seed. 
+ """ + try: + path = api.hf_hub_download( + repo_id=space_id, repo_type="space", filename="app.py" + ) + except (EntryNotFoundError, RepositoryNotFoundError, OSError): + return False + try: + with open(path, "r", encoding="utf-8") as f: + return "trackio.show" in f.read() + except OSError: + return False + + +def _get_space_volumes(api: HfApi, space_id: str) -> list: + """Return mounted volumes for a Space. + + `get_space_runtime()` doesn't always populate `volumes` even when the + mount exists; mirror trackio's fallback to `space_info().runtime.volumes`. + """ + runtime = api.get_space_runtime(space_id) + if getattr(runtime, "volumes", None): + return list(runtime.volumes) + info = api.space_info(space_id) + if info.runtime and getattr(info.runtime, "volumes", None): + return list(info.runtime.volumes) + return [] + + +def _ensure_bucket_mounted( + api: HfApi, + space_id: str, + bucket_id: str, + hf_token: str, + log: Optional[Callable[[str], None]] = None, +) -> None: + """Create the bucket if missing, mount it at `/data` on the Space, and + set the `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` Space variables. Idempotent — + skips work that has already been done. 
+ """ + create_bucket(bucket_id, private=True, exist_ok=True, token=hf_token) + + existing = _get_space_volumes(api, space_id) + already_mounted = any( + getattr(v, "type", None) == "bucket" + and getattr(v, "source", None) == bucket_id + and getattr(v, "mount_path", None) == "/data" + for v in existing + ) + if not already_mounted: + preserved = [ + v + for v in existing + if not ( + getattr(v, "type", None) == "bucket" + and ( + getattr(v, "source", None) == bucket_id + or getattr(v, "mount_path", None) == "/data" + ) + ) + ] + api.set_space_volumes( + space_id, + preserved + [Volume(type="bucket", source=bucket_id, mount_path="/data")], + ) + if log: + log(f"mounted bucket {bucket_id} at /data on {space_id}") + + variables = api.get_space_variables(space_id) + desired = { + "TRACKIO_DIR": "/data/trackio", + "TRACKIO_BUCKET_ID": bucket_id, + "TRACKIO_LOGO_LIGHT_URL": _LOGO_URL, + "TRACKIO_LOGO_DARK_URL": _LOGO_URL, + } + for key, value in desired.items(): + if getattr(variables.get(key), "value", None) != value: + add_space_variable(space_id, key, value, token=hf_token) + + +def ensure_trackio_dashboard( + space_id: str, + hf_token: str, + log: Optional[Callable[[str], None]] = None, +) -> bool: + """Make sure *space_id* is fully wired for trackio: + 1. Space exists with our dashboard files (README without `hf_oauth`, + `requirements.txt`, `app.py` calling `trackio.show`). + 2. Bucket `<space_id>-bucket` exists, is mounted at `/data`, and the + Space has `TRACKIO_DIR` / `TRACKIO_BUCKET_ID` variables set. + + Idempotent — re-running is cheap. Returns True if any seeding happened + in step (1), False if the dashboard files were already in place. Bucket + mount is always re-checked. 
+ """ + api = HfApi(token=hf_token) + + create_repo( + repo_id=space_id, + repo_type="space", + space_sdk="gradio", + exist_ok=True, + token=hf_token, + ) + + seeded_files = False + if _already_seeded(api, space_id): + if log: + log(f"trackio dashboard already seeded on {space_id}") + else: + if log: + log(f"seeding trackio dashboard files into {space_id}") + for path_in_repo, content in _FILES.items(): + api.upload_file( + path_or_fileobj=io.BytesIO(content.encode("utf-8")), + path_in_repo=path_in_repo, + repo_id=space_id, + repo_type="space", + commit_message=f"ml-intern: seed trackio dashboard ({path_in_repo})", + ) + seeded_files = True + + bucket_id = f"{space_id}-bucket" + _ensure_bucket_mounted(api, space_id, bucket_id, hf_token, log) + + if log: + log(f"trackio dashboard ready: https://huggingface.co/spaces/{space_id}") + return seeded_files diff --git a/agent/tools/web_search_tool.py b/agent/tools/web_search_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..3e52ded03c1f405076e1ecc537d0b4250862f562 --- /dev/null +++ b/agent/tools/web_search_tool.py @@ -0,0 +1,273 @@ +"""DuckDuckGo HTML web search tool. + +This mirrors Claw Code's Rust WebSearch behavior: fetch DuckDuckGo's HTML +endpoint, extract result links, optionally filter domains, and return a +JSON payload the model can cite. 
+""" + +from __future__ import annotations + +import asyncio +import html +import json +import os +import time +from dataclasses import dataclass +from html.parser import HTMLParser +from typing import Any +from urllib.parse import parse_qsl, parse_qs, urlencode, urlparse, urlunparse + +import requests + +DEFAULT_SEARCH_URL = "https://html.duckduckgo.com/html/" +WEB_SEARCH_BASE_URL_ENV = "CLAWD_WEB_SEARCH_BASE_URL" +USER_AGENT = "clawd-rust-tools/0.1" +REQUEST_TIMEOUT_SECONDS = 20 +MAX_RESULTS = 8 + + +@dataclass(frozen=True) +class SearchHit: + title: str + url: str + + def as_json(self) -> dict[str, str]: + return {"title": self.title, "url": self.url} + + +class _AnchorParser(HTMLParser): + def __init__(self, *, require_result_class: bool) -> None: + super().__init__(convert_charrefs=True) + self.require_result_class = require_result_class + self.hits: list[tuple[str, str]] = [] + self._active_href: str | None = None + self._active_text: list[str] = [] + + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: + if tag.lower() != "a": + return + attr_map = {key.lower(): value or "" for key, value in attrs} + href = attr_map.get("href") + if not href: + return + if self.require_result_class and "result__a" not in attr_map.get("class", ""): + return + self._active_href = href + self._active_text = [] + + def handle_data(self, data: str) -> None: + if self._active_href is not None: + self._active_text.append(data) + + def handle_entityref(self, name: str) -> None: + if self._active_href is not None: + self._active_text.append(f"&{name};") + + def handle_charref(self, name: str) -> None: + if self._active_href is not None: + self._active_text.append(f"&#{name};") + + def handle_endtag(self, tag: str) -> None: + if tag.lower() != "a" or self._active_href is None: + return + title = collapse_whitespace(html.unescape("".join(self._active_text))).strip() + self.hits.append((self._active_href, title)) + self._active_href = None + 
self._active_text = [] + + +def build_search_url(query: str) -> str: + base = os.environ.get(WEB_SEARCH_BASE_URL_ENV, DEFAULT_SEARCH_URL) + parsed = urlparse(base) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError(f"invalid search base URL: {base}") + + query_pairs = parse_qsl(parsed.query, keep_blank_values=True) + query_pairs.append(("q", query)) + return urlunparse(parsed._replace(query=urlencode(query_pairs))) + + +def collapse_whitespace(value: str) -> str: + return " ".join(value.split()) + + +def decode_duckduckgo_redirect(url: str) -> str | None: + if url.startswith("http://") or url.startswith("https://"): + return html.unescape(url) + if url.startswith("//"): + joined = f"https:{url}" + elif url.startswith("/"): + joined = f"https://duckduckgo.com{url}" + else: + return None + + parsed = urlparse(joined) + if parsed.path in {"/l", "/l/"}: + uddg = parse_qs(parsed.query).get("uddg", []) + if uddg: + return html.unescape(uddg[0]) + return joined + + +def _extract_links(search_html: str, *, require_result_class: bool) -> list[SearchHit]: + parser = _AnchorParser(require_result_class=require_result_class) + parser.feed(search_html) + + hits: list[SearchHit] = [] + for raw_url, title in parser.hits: + if not title: + continue + decoded_url = decode_duckduckgo_redirect(raw_url) + if decoded_url and ( + decoded_url.startswith("http://") or decoded_url.startswith("https://") + ): + hits.append(SearchHit(title=title, url=decoded_url)) + return hits + + +def extract_search_hits(search_html: str) -> list[SearchHit]: + return _extract_links(search_html, require_result_class=True) + + +def extract_search_hits_from_generic_links(search_html: str) -> list[SearchHit]: + return _extract_links(search_html, require_result_class=False) + + +def normalize_domain_filter(domain: str) -> str: + trimmed = domain.strip() + parsed = urlparse(trimmed) + candidate = parsed.hostname if parsed.scheme and parsed.hostname else trimmed + return 
candidate.strip().lstrip(".").rstrip("/").lower() + + +def host_matches_list(url: str, domains: list[str]) -> bool: + host = urlparse(url).hostname + if not host: + return False + normalized_host = host.lower() + for domain in domains: + normalized = normalize_domain_filter(domain) + if normalized and ( + normalized_host == normalized or normalized_host.endswith(f".{normalized}") + ): + return True + return False + + +def dedupe_hits(hits: list[SearchHit]) -> list[SearchHit]: + seen: set[str] = set() + deduped: list[SearchHit] = [] + for hit in hits: + if hit.url in seen: + continue + seen.add(hit.url) + deduped.append(hit) + return deduped + + +def execute_web_search( + query: str, + allowed_domains: list[str] | None = None, + blocked_domains: list[str] | None = None, + tool_use_id: str = "web_search_1", +) -> dict[str, Any]: + started = time.monotonic() + search_url = build_search_url(query) + response = requests.get( + search_url, + headers={"User-Agent": USER_AGENT}, + timeout=REQUEST_TIMEOUT_SECONDS, + allow_redirects=True, + ) + + hits = extract_search_hits(response.text) + if not hits and urlparse(response.url or search_url).hostname: + hits = extract_search_hits_from_generic_links(response.text) + + if allowed_domains is not None: + hits = [hit for hit in hits if host_matches_list(hit.url, allowed_domains)] + if blocked_domains is not None: + hits = [hit for hit in hits if not host_matches_list(hit.url, blocked_domains)] + + hits = dedupe_hits(hits)[:MAX_RESULTS] + rendered_hits = "\n".join(f"- [{hit.title}]({hit.url})" for hit in hits) + if hits: + summary = ( + f"Search results for {query!r}. Include a Sources section in the final answer.\n" + f"{rendered_hits}" + ) + else: + summary = f"No web search results matched the query {query!r}." 
+ + return { + "query": query, + "results": [ + summary, + { + "tool_use_id": tool_use_id, + "content": [hit.as_json() for hit in hits], + }, + ], + "durationSeconds": time.monotonic() - started, + } + + +WEB_SEARCH_TOOL_SPEC = { + "name": "web_search", + "description": "Search the web for current information and return cited results.", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "minLength": 2}, + "allowed_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional allowlist of domains or URLs. Subdomains match.", + }, + "blocked_domains": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional blocklist of domains or URLs. Subdomains match.", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, +} + + +def _optional_string_list(arguments: dict[str, Any], key: str) -> list[str] | None: + value = arguments.get(key) + if value is None: + return None + if not isinstance(value, list) or not all(isinstance(item, str) for item in value): + raise ValueError(f"{key} must be an array of strings") + return value + + +async def web_search_handler( + arguments: dict[str, Any], + session: Any = None, + tool_call_id: str | None = None, + **_kw: Any, +) -> tuple[str, bool]: + query_value = arguments.get("query", "") + if not isinstance(query_value, str): + return "Error: web_search requires a query string with at least 2 characters.", False + + query = query_value.strip() + if len(query) < 2: + return "Error: web_search requires a query with at least 2 characters.", False + + try: + output = await asyncio.to_thread( + execute_web_search, + query=query, + allowed_domains=_optional_string_list(arguments, "allowed_domains"), + blocked_domains=_optional_string_list(arguments, "blocked_domains"), + tool_use_id=tool_call_id or "web_search_1", + ) + except Exception as exc: + return f"Error executing web search: {exc}", False + + return json.dumps(output, indent=2), True 
diff --git a/backend/dependencies.py b/backend/dependencies.py index 0f97c448dc7f695c2606dbe15f5125f27e03609e..5ebc5385e2247343fc22509b8ea4b696080073a4 100644 --- a/backend/dependencies.py +++ b/backend/dependencies.py @@ -12,6 +12,8 @@ from typing import Any import httpx from fastapi import HTTPException, Request, status +from agent.core.hf_tokens import bearer_token_from_header + from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami logger = logging.getLogger(__name__) @@ -157,9 +159,8 @@ async def get_current_user(request: Request) -> dict[str, Any]: return DEV_USER # Try Authorization header - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - token = auth_header[7:] + token = bearer_token_from_header(request.headers.get("Authorization", "")) + if token: user = await _extract_user_from_token(token) if user: return user @@ -183,9 +184,9 @@ def _extract_token(request: Request) -> str | None: Mirrors the lookup order used by ``get_current_user``. """ - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - return auth_header[7:] + token = bearer_token_from_header(request.headers.get("Authorization", "")) + if token: + return token return request.cookies.get("hf_access_token") @@ -202,4 +203,3 @@ async def require_huggingface_org_member(request: Request) -> bool: if not token: return False return await check_org_membership(token, HF_EMPLOYEE_ORG) - diff --git a/backend/main.py b/backend/main.py index 9aa939a083e3b1230baafe8fd96361cd5b3a3c7c..f6bc64d10167de32763d5c2f9f4bcc01f69eab57 100644 --- a/backend/main.py +++ b/backend/main.py @@ -6,14 +6,17 @@ from contextlib import asynccontextmanager from pathlib import Path from dotenv import load_dotenv + +# Load .env before importing routes/session_manager so persistence and quota +# modules see local Mongo settings during startup. 
+load_dotenv(Path(__file__).parent.parent / ".env") + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from routes.agent import router as agent_router from routes.auth import router as auth_router - -# Load .env from project root (parent directory) -load_dotenv(Path(__file__).parent.parent / ".env") +from session_manager import session_manager # Configure logging logging.basicConfig( @@ -27,6 +30,7 @@ logger = logging.getLogger(__name__) async def lifespan(app: FastAPI): """Application lifespan handler.""" logger.info("Starting HF Agent backend...") + await session_manager.start() # Start in-process hourly KPI rollup. Replaces an external cron so the # rollup lives next to the data and reuses the Space's HF token. try: @@ -34,7 +38,6 @@ async def lifespan(app: FastAPI): kpis_scheduler.start() except Exception as e: logger.warning("KPI scheduler failed to start: %s", e) - yield logger.info("Shutting down HF Agent backend...") @@ -47,7 +50,6 @@ async def lifespan(app: FastAPI): # Final-flush: save every still-active session so we don't lose traces on # server restart. Uploads are detached subprocesses — this is fast. 
try: - from session_manager import session_manager for sid, agent_session in list(session_manager.sessions.items()): sess = agent_session.session if sess.config.save_sessions: @@ -58,6 +60,7 @@ async def lifespan(app: FastAPI): logger.warning("Failed to flush session %s: %s", sid, e) except Exception as e: logger.warning("Lifespan final-flush skipped: %s", e) + await session_manager.close() app = FastAPI( diff --git a/backend/models.py b/backend/models.py index 952365c23c22936499a64f6b9ac1638541f63dc6..04048013d71ebddffd46c0e8f39cb668727a807a 100644 --- a/backend/models.py +++ b/backend/models.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field class OpType(str, Enum): @@ -87,6 +87,14 @@ class SessionInfo(BaseModel): user_id: str = "dev" pending_approval: list[PendingApprovalTool] | None = None model: str | None = None + title: str | None = None + notification_destinations: list[str] = Field(default_factory=list) + + +class SessionNotificationsRequest(BaseModel): + """Replace the session's auto-notification destinations.""" + + destinations: list[str] class HealthResponse(BaseModel): diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 4895bbadbf206d4d4a78d4291a4c182e317964eb..3067f4fd2d25e6c136a195db82795241a91e66c3 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -24,6 +24,7 @@ from models import ( HealthResponse, LLMHealthResponse, SessionInfo, + SessionNotificationsRequest, SessionResponse, SubmitRequest, TruncateRequest, @@ -33,6 +34,7 @@ from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, se import user_quotas from agent.core.hf_access import get_jobs_access +from agent.core.hf_tokens import resolve_hf_request_token, resolve_hf_router_token from agent.core.llm_params import _resolve_llm_params logger = logging.getLogger(__name__) @@ -118,9 +120,9 @@ async def _enforce_claude_quota( if not 
_is_anthropic_model(model_name): return user_id = user["user_id"] - used = await user_quotas.get_claude_used_today(user_id) cap = user_quotas.daily_cap_for(user.get("plan")) - if used >= cap: + new_count = await user_quotas.try_increment_claude(user_id, cap) + if new_count is None: raise HTTPException( status_code=429, detail={ @@ -133,8 +135,8 @@ async def _enforce_claude_quota( ), }, ) - await user_quotas.increment_claude(user_id) agent_session.claude_counted = True + await session_manager.persist_session_snapshot(agent_session) async def _enforce_jobs_access_for_approvals( @@ -193,6 +195,9 @@ async def _enforce_jobs_access_for_approvals( "The selected jobs namespace is not one of your eligible paid organizations. " f"Allowed namespaces: {', '.join(access.paid_org_names)}" ), + "plan": user.get("plan", "free"), + "tool_call_ids": invalid_namespace, + "eligible_namespaces": access.paid_org_names, }, ) missing_namespace = [ @@ -236,13 +241,23 @@ async def _enforce_jobs_access_for_approvals( ) -def _check_session_access(session_id: str, user: dict[str, Any]) -> None: - """Verify the user has access to the given session. Raises 403 or 404.""" - info = session_manager.get_session_info(session_id) - if not info: +async def _check_session_access( + session_id: str, + user: dict[str, Any], + request: Request | None = None, +) -> AgentSession: + """Verify and lazily load the user's session. 
Raises 403 or 404.""" + hf_token = resolve_hf_request_token(request) if request is not None else user.get("hf_token") + agent_session = await session_manager.ensure_session_loaded( + session_id, + user["user_id"], + hf_token=hf_token, + ) + if not agent_session: raise HTTPException(status_code=404, detail="Session not found") - if not session_manager.verify_session_access(session_id, user["user_id"]): + if user["user_id"] != "dev" and agent_session.user_id not in {user["user_id"], "dev"}: raise HTTPException(status_code=403, detail="Access denied to this session") + return agent_session @router.get("/health", response_model=HealthResponse) @@ -332,10 +347,8 @@ async def generate_title( reasoning model — reasoning_effort=low keeps the reasoning budget small so the 60-token output budget isn't consumed before the title is written. """ - api_key = ( - os.environ.get("INFERENCE_TOKEN") - or (user.get("hf_token") if isinstance(user, dict) else None) - or os.environ.get("HF_TOKEN") + api_key = resolve_hf_router_token( + user.get("hf_token") if isinstance(user, dict) else None ) try: response = await acompletion( @@ -366,11 +379,21 @@ async def generate_title( title = title.translate(_TITLE_STRIP_CHARS).strip() if len(title) > 50: title = title[:50].rstrip() + "…" + try: + await _check_session_access(request.session_id, user) + await session_manager.update_session_title(request.session_id, title) + except Exception: + logger.debug("Skipping title persistence for missing session %s", request.session_id) return {"title": title} except Exception as e: logger.warning(f"Title generation failed: {e}") fallback = request.text.strip() title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback + try: + await _check_session_access(request.session_id, user) + await session_manager.update_session_title(request.session_id, title) + except Exception: + logger.debug("Skipping fallback title persistence for missing session %s", request.session_id) return {"title": title} @@ 
-391,14 +414,7 @@ async def create_session( Returns 503 if the server or user has reached the session limit. """ # Extract the user's HF token (Bearer header, HttpOnly cookie, or env var) - hf_token = None - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - hf_token = auth_header[7:] - if not hf_token: - hf_token = request.cookies.get("hf_access_token") - if not hf_token: - hf_token = os.environ.get("HF_TOKEN") + hf_token = resolve_hf_request_token(request) # Optional model override. Empty body falls back to the config default. model: str | None = None @@ -444,14 +460,7 @@ async def restore_session_summary( if not isinstance(messages, list) or not messages: raise HTTPException(status_code=400, detail="Missing 'messages' array") - hf_token = None - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - hf_token = auth_header[7:] - if not hf_token: - hf_token = request.cookies.get("hf_access_token") - if not hf_token: - hf_token = os.environ.get("HF_TOKEN") + hf_token = resolve_hf_request_token(request) model = body.get("model") valid_ids = {m["id"] for m in AVAILABLE_MODELS} @@ -488,7 +497,7 @@ async def get_session( session_id: str, user: dict = Depends(get_current_user) ) -> SessionInfo: """Get session information. Only accessible by the session owner.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) return SessionInfo(**info) @@ -509,7 +518,7 @@ async def set_session_model( Switching TO an Anthropic model requires HF org membership (PR #63); free-model switches are unrestricted. 
""" - _check_session_access(session_id, user) + agent_session = await _check_session_access(session_id, user, request) model_id = body.get("model") if not model_id: raise HTTPException(status_code=400, detail="Missing 'model' field") @@ -517,10 +526,9 @@ async def set_session_model( if model_id not in valid_ids: raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") await _require_hf_for_anthropic(request, model_id) - agent_session = session_manager.sessions.get(session_id) if not agent_session: raise HTTPException(status_code=404, detail="Session not found") - agent_session.session.update_model(model_id) + await session_manager.update_session_model(session_id, model_id) logger.info( f"Session {session_id} model → {model_id} " f"(by {user.get('username', 'unknown')})" @@ -528,6 +536,27 @@ async def set_session_model( return {"session_id": session_id, "model": model_id} +@router.post("/session/{session_id}/notifications") +async def set_session_notifications( + session_id: str, + body: SessionNotificationsRequest, + user: dict = Depends(get_current_user), +) -> dict: + """Replace the session's auto-notification destinations.""" + agent_session = await _check_session_access(session_id, user) + try: + destinations = session_manager.set_notification_destinations( + session_id, body.destinations + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + await session_manager.persist_session_snapshot(agent_session) + return { + "session_id": session_id, + "notification_destinations": destinations, + } + + @router.get("/user/quota") async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: """Return the user's plan tier and today's Claude-session quota state.""" @@ -545,14 +574,7 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict: @router.get("/user/jobs-access") async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict: """Return whether the 
current token can run HF Jobs and under which namespaces.""" - token = None - auth_header = request.headers.get("Authorization", "") - if auth_header.startswith("Bearer "): - token = auth_header[7:] - if not token: - token = request.cookies.get("hf_access_token") - if not token: - token = os.environ.get("HF_TOKEN") + token = resolve_hf_request_token(request) access = await get_jobs_access(token or "") return { @@ -566,7 +588,7 @@ async def get_jobs_access_info(request: Request, user: dict = Depends(get_curren @router.get("/sessions", response_model=list[SessionInfo]) async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: """List sessions belonging to the authenticated user.""" - sessions = session_manager.list_sessions(user_id=user["user_id"]) + sessions = await session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] @@ -575,7 +597,7 @@ async def delete_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Delete a session. Only accessible by the session owner.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") @@ -587,10 +609,8 @@ async def submit_input( request: SubmitRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit user input to a session. 
Only accessible by the session owner.""" - _check_session_access(request.session_id, user) - agent_session = session_manager.sessions.get(request.session_id) - if agent_session is not None: - await _enforce_claude_quota(user, agent_session) + agent_session = await _check_session_access(request.session_id, user) + await _enforce_claude_quota(user, agent_session) success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -602,10 +622,7 @@ async def submit_approval( request: ApprovalRequest, user: dict = Depends(get_current_user) ) -> dict: """Submit tool approvals to a session. Only accessible by the session owner.""" - _check_session_access(request.session_id, user) - agent_session = session_manager.sessions.get(request.session_id) - if agent_session is None: - raise HTTPException(status_code=404, detail="Session not found or inactive") + agent_session = await _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, @@ -630,9 +647,7 @@ async def chat_sse( user: dict = Depends(get_current_user), ) -> StreamingResponse: """SSE endpoint: submit input or approval, then stream events until turn ends.""" - _check_session_access(session_id, user) - - agent_session = session_manager.sessions.get(session_id) + agent_session = await _check_session_access(session_id, user, request) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -698,10 +713,7 @@ async def record_pro_click( user: dict = Depends(get_current_user), ) -> dict: """Record a click on a Pro upgrade CTA shown from inside a session.""" - _check_session_access(session_id, user) - agent_session = session_manager.sessions.get(session_id) - if not agent_session: - raise HTTPException(status_code=404, detail="Session not found") + agent_session = await 
_check_session_access(session_id, user) from agent.core import telemetry await telemetry.record_pro_cta_click( @@ -723,12 +735,53 @@ _TERMINAL_EVENTS = {"turn_complete", "approval_required", "error", "interrupted" _SSE_KEEPALIVE_SECONDS = 15 -def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse: +def _last_event_seq(request: Request) -> int: + raw = request.headers.get("last-event-id") or request.query_params.get("after") or "0" + try: + return max(0, int(raw)) + except (TypeError, ValueError): + return 0 + + +def _format_sse(msg: dict[str, Any]) -> str: + seq = msg.get("seq") + body = {"event_type": msg.get("event_type"), "data": msg.get("data") or {}} + if seq is not None: + body["seq"] = seq + return f"id: {seq}\ndata: {json.dumps(body)}\n\n" + return f"data: {json.dumps(body)}\n\n" + + +def _event_doc_to_msg(doc: dict[str, Any]) -> dict[str, Any]: + return { + "event_type": doc.get("event_type"), + "data": doc.get("data") or {}, + "seq": doc.get("seq"), + } + + +def _sse_response( + broadcaster, + event_queue, + sub_id, + *, + replay_events: list[dict[str, Any]] | None = None, + after_seq: int = 0, +) -> StreamingResponse: """Build a StreamingResponse that drains *event_queue* as SSE, sending keepalive comments every 15 s to prevent proxy timeouts.""" async def event_generator(): try: + for doc in replay_events or []: + msg = _event_doc_to_msg(doc) + seq = msg.get("seq") + if isinstance(seq, int) and seq <= after_seq: + continue + yield _format_sse(msg) + if msg.get("event_type", "") in _TERMINAL_EVENTS: + return + while True: try: msg = await asyncio.wait_for( @@ -739,7 +792,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse: yield ": keepalive\n\n" continue event_type = msg.get("event_type", "") - yield f"data: {json.dumps(msg)}\n\n" + yield _format_sse(msg) if event_type in _TERMINAL_EVENTS: break finally: @@ -759,6 +812,7 @@ def _sse_response(broadcaster, event_queue, sub_id) -> StreamingResponse: 
@router.get("/events/{session_id}") async def subscribe_events( session_id: str, + request: Request, user: dict = Depends(get_current_user), ) -> StreamingResponse: """Subscribe to events for a running session without submitting new input. @@ -766,15 +820,21 @@ async def subscribe_events( Used by the frontend to re-attach after a connection drop (e.g. screen sleep). Returns 404 if the session isn't active or isn't processing. """ - _check_session_access(session_id, user) - - agent_session = session_manager.sessions.get(session_id) + agent_session = await _check_session_access(session_id, user, request) if not agent_session or not agent_session.is_active: raise HTTPException(status_code=404, detail="Session not found or inactive") + after_seq = _last_event_seq(request) + replay_events = await session_manager._store().load_events_after(session_id, after_seq) broadcaster = agent_session.broadcaster sub_id, event_queue = broadcaster.subscribe() - return _sse_response(broadcaster, event_queue, sub_id) + return _sse_response( + broadcaster, + event_queue, + sub_id, + replay_events=replay_events, + after_seq=after_seq, + ) @router.post("/interrupt/{session_id}") @@ -782,7 +842,7 @@ async def interrupt_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Interrupt the current operation in a session.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -794,17 +854,16 @@ async def get_session_messages( session_id: str, user: dict = Depends(get_current_user) ) -> list[dict]: """Return the session's message history from memory.""" - _check_session_access(session_id, user) - agent_session = session_manager.sessions.get(session_id) + agent_session = await _check_session_access(session_id, user) if not agent_session or not agent_session.is_active: raise 
HTTPException(status_code=404, detail="Session not found or inactive") - return [msg.model_dump() for msg in agent_session.session.context_manager.items] + return [msg.model_dump(mode="json") for msg in agent_session.session.context_manager.items] @router.post("/undo/{session_id}") async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -816,7 +875,7 @@ async def truncate_session( session_id: str, body: TruncateRequest, user: dict = Depends(get_current_user) ) -> dict: """Truncate conversation to before a specific user message.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.truncate(session_id, body.user_message_index) if not success: raise HTTPException(status_code=404, detail="Session not found, inactive, or message index out of range") @@ -828,7 +887,7 @@ async def compact_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Compact the context in a session.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -840,13 +899,12 @@ async def shutdown_session( session_id: str, user: dict = Depends(get_current_user) ) -> dict: """Shutdown a session.""" - _check_session_access(session_id, user) + await _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") return {"status": "shutdown_requested", "session_id": session_id} - 
@router.post("/feedback/{session_id}") async def submit_feedback( session_id: str, @@ -859,10 +917,7 @@ async def submit_feedback( turn_index?: int, comment?: str, message_id?: str} Appended as a `feedback` event and saved with the session trajectory. """ - _check_session_access(session_id, user) - agent_session = session_manager.sessions.get(session_id) - if not agent_session: - raise HTTPException(status_code=404, detail="Session not found") + agent_session = await _check_session_access(session_id, user) rating = body.get("rating") if rating not in {"up", "down", "outcome_success", "outcome_fail"}: diff --git a/backend/session_manager.py b/backend/session_manager.py index 68177fc12280d07f339b468f3b0b9bdb0c24c475..bab1c3b2d55ffdeb2062ca0d22efb863cf773580 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -1,6 +1,7 @@ """Session manager for handling multiple concurrent agent sessions.""" import asyncio +import json import logging import uuid from dataclasses import dataclass, field @@ -10,7 +11,9 @@ from typing import Any, Optional from agent.config import load_config from agent.core.agent_loop import process_submission +from agent.messaging.gateway import NotificationGateway from agent.core.session import Event, OpType, Session +from agent.core.session_persistence import get_session_store from agent.core.tools import ToolRouter # Get project root (parent of backend directory) @@ -41,9 +44,8 @@ logger = logging.getLogger(__name__) class EventBroadcaster: """Reads from the agent's event queue and fans out to SSE subscribers. - Events that arrive when no subscribers are listening are discarded. - With SSE each turn is a separate request, so there is no reconnect - scenario that would need buffered replay. + Events that arrive when no subscribers are listening are discarded by + this in-memory fanout. Durable replay is handled by session_persistence. 
""" def __init__(self, event_queue: asyncio.Queue): @@ -67,7 +69,7 @@ class EventBroadcaster: while True: try: event: Event = await self._source.get() - msg = {"event_type": event.event_type, "data": event.data} + msg = {"event_type": event.event_type, "data": event.data, "seq": event.seq} for q in self._subscribers.values(): await q.put(msg) except asyncio.CancelledError: @@ -91,6 +93,7 @@ class AgentSession: is_active: bool = True is_processing: bool = False # True while a submission is being executed broadcaster: Any = None + title: str | None = None # True once this session has been counted against the user's daily # Claude quota. Guards double-counting when the user re-selects an # Anthropic model mid-session. @@ -119,8 +122,27 @@ class SessionManager: def __init__(self, config_path: str | None = None) -> None: self.config = load_config(config_path or DEFAULT_CONFIG_PATH) + self.messaging_gateway = NotificationGateway(self.config.messaging) self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() + self.persistence_store = None + + async def start(self) -> None: + """Start shared background resources.""" + self.persistence_store = get_session_store() + await self.persistence_store.init() + await self.messaging_gateway.start() + + async def close(self) -> None: + """Flush and close shared background resources.""" + await self.messaging_gateway.close() + if self.persistence_store is not None: + await self.persistence_store.close() + + def _store(self): + if self.persistence_store is None: + self.persistence_store = get_session_store() + return self.persistence_store def _count_user_sessions(self, user_id: str) -> int: """Count active sessions owned by a specific user.""" @@ -130,6 +152,314 @@ class SessionManager: if s.user_id == user_id and s.is_active ) + def _create_session_sync( + self, + *, + session_id: str, + user_id: str, + hf_token: str | None, + model: str | None, + event_queue: asyncio.Queue, + notification_destinations: list[str] | 
None = None, + ) -> tuple[ToolRouter, Session]: + """Build blocking per-session resources in a worker thread.""" + import time as _time + + t0 = _time.monotonic() + tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) + # Deep-copy config so each session's model switches independently — + # tab A picking GLM doesn't flip tab B off Claude. + session_config = self.config.model_copy(deep=True) + if model: + session_config.model_name = model + session = Session( + event_queue=event_queue, + config=session_config, + tool_router=tool_router, + hf_token=hf_token, + user_id=user_id, + notification_gateway=self.messaging_gateway, + notification_destinations=notification_destinations or [], + session_id=session_id, + persistence_store=self._store(), + ) + t1 = _time.monotonic() + logger.info("Session initialized in %.2fs", t1 - t0) + return tool_router, session + + def _serialize_messages(self, session: Session) -> list[dict[str, Any]]: + return [ + msg.model_dump(mode="json") + for msg in session.context_manager.items + ] + + def _serialize_pending_approval(self, session: Session) -> list[dict[str, Any]]: + pending = session.pending_approval or {} + tool_calls = pending.get("tool_calls") or [] + serialized: list[dict[str, Any]] = [] + for tc in tool_calls: + if hasattr(tc, "model_dump"): + serialized.append(tc.model_dump(mode="json")) + elif isinstance(tc, dict): + serialized.append(tc) + return serialized + + @staticmethod + def _pending_tools_for_api(session: Session) -> list[dict[str, Any]] | None: + pending = session.pending_approval or {} + tool_calls = pending.get("tool_calls") or [] + if not tool_calls: + return None + result: list[dict[str, Any]] = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError, TypeError): + args = {} + result.append( + { + "tool": getattr(tc.function, "name", None), + "tool_call_id": getattr(tc, "id", None), + "arguments": args, + } + ) + return result + + 
def _restore_pending_approval( + self, session: Session, pending_approval: list[dict[str, Any]] | None + ) -> None: + if not pending_approval: + session.pending_approval = None + return + from litellm import ChatCompletionMessageToolCall as ToolCall + + restored = [] + for raw in pending_approval: + try: + if "function" in raw: + restored.append(ToolCall(**raw)) + else: + restored.append( + ToolCall( + id=raw["tool_call_id"], + type="function", + function={ + "name": raw["tool"], + "arguments": json.dumps(raw.get("arguments") or {}), + }, + ) + ) + except Exception as e: + logger.warning("Dropping malformed pending approval: %s", e) + session.pending_approval = {"tool_calls": restored} if restored else None + + @staticmethod + def _pending_docs_for_api( + pending_approval: list[dict[str, Any]] | None, + ) -> list[dict[str, Any]] | None: + if not pending_approval: + return None + result: list[dict[str, Any]] = [] + for raw in pending_approval: + if "function" in raw: + function = raw.get("function") or {} + try: + args = json.loads(function.get("arguments") or "{}") + except (json.JSONDecodeError, TypeError): + args = {} + result.append( + { + "tool": function.get("name"), + "tool_call_id": raw.get("id"), + "arguments": args, + } + ) + elif {"tool", "tool_call_id"}.issubset(raw): + result.append( + { + "tool": raw.get("tool"), + "tool_call_id": raw.get("tool_call_id"), + "arguments": raw.get("arguments") or {}, + } + ) + return result or None + + @staticmethod + def _runtime_state(agent_session: AgentSession) -> str: + if agent_session.session.pending_approval: + return "waiting_approval" + if agent_session.is_processing: + return "processing" + if not agent_session.is_active: + return "ended" + return "idle" + + async def _start_agent_session( + self, + *, + agent_session: AgentSession, + event_queue: asyncio.Queue, + tool_router: ToolRouter, + ) -> AgentSession: + async with self._lock: + existing = self.sessions.get(agent_session.session_id) + if existing: + 
return existing + self.sessions[agent_session.session_id] = agent_session + + task = asyncio.create_task( + self._run_session( + agent_session.session_id, + agent_session.submission_queue, + event_queue, + tool_router, + ) + ) + agent_session.task = task + return agent_session + + @staticmethod + def _can_access_session(agent_session: AgentSession, user_id: str) -> bool: + return ( + user_id == "dev" + or agent_session.user_id == "dev" + or agent_session.user_id == user_id + ) + + @staticmethod + def _update_hf_token(agent_session: AgentSession, hf_token: str | None) -> None: + if not hf_token: + return + agent_session.hf_token = hf_token + agent_session.session.hf_token = hf_token + + async def persist_session_snapshot( + self, + agent_session: AgentSession, + *, + runtime_state: str | None = None, + status: str = "active", + ) -> None: + """Persist the current runtime context snapshot.""" + store = self._store() + if not getattr(store, "enabled", False): + return + try: + await store.save_snapshot( + session_id=agent_session.session_id, + user_id=agent_session.user_id, + model=agent_session.session.config.model_name, + title=agent_session.title, + messages=self._serialize_messages(agent_session.session), + runtime_state=runtime_state or self._runtime_state(agent_session), + status=status, + turn_count=agent_session.session.turn_count, + pending_approval=self._serialize_pending_approval(agent_session.session), + claude_counted=agent_session.claude_counted, + created_at=agent_session.created_at, + notification_destinations=list( + agent_session.session.notification_destinations + ), + ) + except Exception as e: + logger.warning( + "Failed to persist snapshot for %s: %s", + agent_session.session_id, + e, + ) + + async def ensure_session_loaded( + self, + session_id: str, + user_id: str, + hf_token: str | None = None, + ) -> AgentSession | None: + """Return a live runtime session, lazily restoring it from Mongo.""" + async with self._lock: + existing = 
self.sessions.get(session_id) + if existing: + if self._can_access_session(existing, user_id): + self._update_hf_token(existing, hf_token) + return existing + return None + + store = self._store() + loaded = await store.load_session(session_id) + if not loaded: + return None + + async with self._lock: + existing = self.sessions.get(session_id) + if existing: + if self._can_access_session(existing, user_id): + self._update_hf_token(existing, hf_token) + return existing + return None + + meta = loaded.get("metadata") or {} + owner = str(meta.get("user_id") or "") + if user_id != "dev" and owner != "dev" and owner != user_id: + return None + + from litellm import Message + + model = meta.get("model") or self.config.model_name + event_queue: asyncio.Queue = asyncio.Queue() + submission_queue: asyncio.Queue = asyncio.Queue() + tool_router, session = await asyncio.to_thread( + self._create_session_sync, + session_id=session_id, + user_id=owner or user_id, + hf_token=hf_token, + model=model, + event_queue=event_queue, + notification_destinations=meta.get("notification_destinations") or [], + ) + + restored_messages: list[Message] = [] + for raw in loaded.get("messages") or []: + if not isinstance(raw, dict) or raw.get("role") == "system": + continue + try: + restored_messages.append(Message.model_validate(raw)) + except Exception as e: + logger.warning("Dropping malformed restored message: %s", e) + if restored_messages: + # Keep the freshly-rendered system prompt, then attach the durable + # non-system context so tools/date/user context stay current. 
+ session.context_manager.items = [session.context_manager.items[0], *restored_messages] + + self._restore_pending_approval(session, meta.get("pending_approval") or []) + session.turn_count = int(meta.get("turn_count") or 0) + + created_at = meta.get("created_at") + if not isinstance(created_at, datetime): + created_at = datetime.utcnow() + + agent_session = AgentSession( + session_id=session_id, + session=session, + tool_router=tool_router, + submission_queue=submission_queue, + user_id=owner or user_id, + hf_token=hf_token, + created_at=created_at, + is_active=True, + is_processing=False, + claude_counted=bool(meta.get("claude_counted")), + title=meta.get("title"), + ) + started = await self._start_agent_session( + agent_session=agent_session, + event_queue=event_queue, + tool_router=tool_router, + ) + if started is not agent_session: + self._update_hf_token(started, hf_token) + return started + logger.info("Restored session %s for user %s", session_id, owner or user_id) + return agent_session + async def create_session( self, user_id: str = "dev", @@ -178,27 +508,14 @@ class SessionManager: event_queue: asyncio.Queue = asyncio.Queue() # Run blocking constructors in a thread to keep the event loop responsive. - # Without this, Session.__init__ → ContextManager → litellm.get_max_tokens() - # blocks all HTTP/SSE handling. - import time as _time - - def _create_session_sync(): - t0 = _time.monotonic() - tool_router = ToolRouter(self.config.mcpServers, hf_token=hf_token) - # Deep-copy config so each session's model switches independently — - # tab A picking GLM doesn't flip tab B off Claude. 
- session_config = self.config.model_copy(deep=True) - if model: - session_config.model_name = model - session = Session( - event_queue, config=session_config, tool_router=tool_router, - hf_token=hf_token, - ) - t1 = _time.monotonic() - logger.info(f"Session initialized in {t1 - t0:.2f}s") - return tool_router, session - - tool_router, session = await asyncio.to_thread(_create_session_sync) + tool_router, session = await asyncio.to_thread( + self._create_session_sync, + session_id=session_id, + user_id=user_id, + hf_token=hf_token, + model=model, + event_queue=event_queue, + ) # Create wrapper agent_session = AgentSession( @@ -210,14 +527,12 @@ class SessionManager: hf_token=hf_token, ) - async with self._lock: - self.sessions[session_id] = agent_session - - # Start the agent loop task - task = asyncio.create_task( - self._run_session(session_id, submission_queue, event_queue, tool_router) + await self._start_agent_session( + agent_session=agent_session, + event_queue=event_queue, + tool_router=tool_router, ) - agent_session.task = task + await self.persist_session_snapshot(agent_session, runtime_state="idle") logger.info(f"Created session {session_id} for user {user_id}") return session_id @@ -283,21 +598,38 @@ class SessionManager: ), ) session.context_manager.items.append(seed) + await self.persist_session_snapshot(agent_session, runtime_state="idle") return len(parsed) @staticmethod async def _cleanup_sandbox(session: Session) -> None: - """Delete the sandbox Space if one was created for this session.""" + """Delete the sandbox Space if one was created for this session. + + Retries on transient failures (HF API 5xx, rate-limit, network blips) + with exponential backoff. A single missed delete = a permanently + orphaned Space, so the cost of an extra retry beats the alternative. 
+ """ sandbox = getattr(session, "sandbox", None) - if sandbox and getattr(sandbox, "_owns_space", False): - space_id = getattr(sandbox, "space_id", None) + if not (sandbox and getattr(sandbox, "_owns_space", False)): + return + + space_id = getattr(sandbox, "space_id", None) + last_err: Exception | None = None + for attempt in range(3): try: - logger.info(f"Deleting sandbox {space_id}...") + logger.info(f"Deleting sandbox {space_id} (attempt {attempt + 1}/3)...") await asyncio.to_thread(sandbox.delete) from agent.core import telemetry await telemetry.record_sandbox_destroy(session, sandbox) + return except Exception as e: - logger.warning(f"Failed to delete sandbox {space_id}: {e}") + last_err = e + if attempt < 2: + await asyncio.sleep(2 ** attempt) + logger.error( + f"Failed to delete sandbox {space_id} after 3 attempts: {last_err}. " + f"Orphan — sweep script will pick it up." + ) async def _run_session( self, @@ -337,6 +669,7 @@ class SessionManager: should_continue = await process_submission(session, submission) finally: agent_session.is_processing = False + await self.persist_session_snapshot(agent_session) if not should_continue: break except asyncio.TimeoutError: @@ -371,6 +704,11 @@ class SessionManager: async with self._lock: if session_id in self.sessions: self.sessions[session_id].is_active = False + await self.persist_session_snapshot( + self.sessions[session_id], + runtime_state="ended", + status="ended", + ) logger.info(f"Session {session_id} ended") @@ -420,7 +758,10 @@ class SessionManager: agent_session = self.sessions.get(session_id) if not agent_session or not agent_session.is_active: return False - return agent_session.session.context_manager.truncate_to_user_message(user_message_index) + success = agent_session.session.context_manager.truncate_to_user_message(user_message_index) + if success: + await self.persist_session_snapshot(agent_session, runtime_state="idle") + return success async def compact(self, session_id: str) -> bool: """Compact 
context in a session.""" @@ -445,12 +786,15 @@ class SessionManager: return success async def delete_session(self, session_id: str) -> bool: - """Delete a session entirely.""" + """Soft-delete a session and stop its runtime resources.""" async with self._lock: agent_session = self.sessions.pop(session_id, None) if not agent_session: - return False + await self._store().soft_delete_session(session_id) + return True + + await self._store().soft_delete_session(session_id) # Clean up sandbox Space before cancelling the task await self._cleanup_sandbox(agent_session.session) @@ -465,6 +809,21 @@ class SessionManager: return True + async def update_session_title(self, session_id: str, title: str | None) -> None: + """Persist a user-visible title for sidebar rehydration.""" + agent_session = self.sessions.get(session_id) + if agent_session: + agent_session.title = title + await self._store().update_session_fields(session_id, title=title) + + async def update_session_model(self, session_id: str, model_id: str) -> bool: + agent_session = self.sessions.get(session_id) + if not agent_session or not agent_session.is_active: + return False + agent_session.session.update_model(model_id) + await self.persist_session_snapshot(agent_session, runtime_state="idle") + return True + def get_session_owner(self, session_id: str) -> str | None: """Get the user_id that owns a session, or None if session doesn't exist.""" agent_session = self.sessions.get(session_id) @@ -492,22 +851,7 @@ class SessionManager: if not agent_session: return None - # Extract pending approval tools if any - pending_approval = None - pa = agent_session.session.pending_approval - if pa and pa.get("tool_calls"): - pending_approval = [] - for tc in pa["tool_calls"]: - import json - try: - args = json.loads(tc.function.arguments) - except (json.JSONDecodeError, AttributeError): - args = {} - pending_approval.append({ - "tool": tc.function.name, - "tool_call_id": tc.id, - "arguments": args, - }) + pending_approval = 
self._pending_tools_for_api(agent_session.session) return { "session_id": session_id, @@ -518,16 +862,80 @@ class SessionManager: "user_id": agent_session.user_id, "pending_approval": pending_approval, "model": agent_session.session.config.model_name, + "title": agent_session.title, + "notification_destinations": list( + agent_session.session.notification_destinations + ), } - def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: + def set_notification_destinations( + self, session_id: str, destinations: list[str] + ) -> list[str]: + """Replace the session's opted-in auto-notification destinations.""" + agent_session = self.sessions.get(session_id) + if not agent_session or not agent_session.is_active: + raise ValueError("Session not found or inactive") + + normalized: list[str] = [] + seen: set[str] = set() + for raw_name in destinations: + name = raw_name.strip() + if not name: + raise ValueError("Destination names must not be empty") + destination = self.config.messaging.get_destination(name) + if destination is None: + raise ValueError(f"Unknown destination '{name}'") + if not destination.allow_auto_events: + raise ValueError( + f"Destination '{name}' is not enabled for auto events" + ) + if name not in seen: + normalized.append(name) + seen.add(name) + + agent_session.session.set_notification_destinations(normalized) + return normalized + + async def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: """List sessions, optionally filtered by user. Args: user_id: If provided, only return sessions owned by this user. If "dev", return all sessions (dev mode). 
""" - results = [] + results: list[dict[str, Any]] = [] + store = self._store() + if getattr(store, "enabled", False): + for row in await store.list_sessions(user_id or "dev"): + sid = row.get("session_id") or row.get("_id") + if not sid: + continue + runtime_info = self.get_session_info(str(sid)) + if runtime_info: + results.append(runtime_info) + continue + created_at = row.get("created_at") + if isinstance(created_at, datetime): + created_at_str = created_at.isoformat() + else: + created_at_str = str(created_at or datetime.utcnow().isoformat()) + pending = self._pending_docs_for_api(row.get("pending_approval") or []) + results.append( + { + "session_id": str(sid), + "created_at": created_at_str, + "is_active": row.get("status") != "ended", + "is_processing": row.get("runtime_state") == "processing", + "message_count": int(row.get("message_count") or 0), + "user_id": row.get("user_id") or "dev", + "pending_approval": pending or None, + "model": row.get("model"), + "title": row.get("title"), + "notification_destinations": row.get("notification_destinations") or [], + } + ) + return results + for sid in self.sessions: info = self.get_session_info(sid) if not info: diff --git a/backend/user_quotas.py b/backend/user_quotas.py index 2b38b1111fe9bc57d36eebcc87b3ac9c88f8326b..94b1b0274a7f9c8b046e8210b50da1215d7743e8 100644 --- a/backend/user_quotas.py +++ b/backend/user_quotas.py @@ -1,9 +1,8 @@ -"""In-memory daily quota for Claude session creations. +"""Daily quota for Claude session creations. Tracks per-user Claude session starts against a daily cap derived from the -user's HF plan. Caps reset at UTC midnight; the store itself is in-process -and wipes on restart (deliberate — the cost of occasional over-subsidy at -restart is much lower than running a DB). +user's HF plan. MongoDB is the source of truth when configured; the +in-process dict remains the fallback for local/dev/test runs. Unit: session *creations*, not messages. 
# Cap passed to the store when the caller wants an unconditional increment.
# The only store call used here is a capped increment, so a value far above
# any real daily quota is used to make the cap check a no-op.
_UNCAPPED_QUOTA: int = 10**9


async def _bump_in_memory(user_id: str, cap: int | None) -> int | None:
    """Increment today's in-process count for *user_id*.

    Shared by ``increment_claude`` and ``try_increment_claude`` so the
    "reset on a new day, then bump" sequence lives in one place.

    Args:
        user_id: Key into the module-level ``_claude_counts`` dict.
        cap: When not ``None``, refuse to increment once the current count
            has reached this value.

    Returns:
        The new count, or ``None`` if *cap* was given and already reached.
    """
    async with _lock:
        today = _today()
        day, count = _claude_counts.get(user_id, (today, 0))
        if day != today:
            # The stored entry belongs to another day: start again from zero.
            count = 0
        if cap is not None and count >= cap:
            return None
        count += 1
        _claude_counts[user_id] = (today, count)
        return count


async def increment_claude(user_id: str) -> int:
    """Bump today's Claude session count for the user. Returns the new value.

    Goes through the session store when it reports ``enabled``; otherwise
    falls back to the in-process counter.
    """
    store = get_session_store()
    if getattr(store, "enabled", False):
        db_count = await store.try_increment_quota(
            user_id, _today(), cap=_UNCAPPED_QUOTA
        )
        # A falsy result from the store (e.g. None) is reported as 0.
        return db_count or 0

    # Without a cap the helper never takes its refusal branch, so the
    # result is always the new integer count.
    return await _bump_in_memory(user_id, cap=None)


async def try_increment_claude(user_id: str, cap: int) -> int | None:
    """Atomically bump today's count if below *cap*.

    Returns the new count, or None when the user is already at the cap.
    When the session store reports ``enabled`` the decision is delegated to
    ``store.try_increment_quota``; otherwise the in-process counter is
    checked and updated under ``_lock``.
    """
    store = get_session_store()
    if getattr(store, "enabled", False):
        return await store.try_increment_quota(user_id, _today(), cap)

    return await _bump_in_memory(user_id, cap)
// HF repo IDs are `owner/name` where each segment is alphanumerics plus
// `_`, `.`, `-`. Anything else (extra slashes, spaces, query params, missing
// owner) would let an attacker-controlled string redirect the embed to a
// different Space, so we refuse to render rather than build a malformed URL.
const SPACE_ID_PATTERN = /^[a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+$/;

// A segment made only of punctuation (`.`, `..`, `_`, `-`) passes the
// character-class check above but is still unsafe: `..` is a dot-segment that
// URL parsing collapses (escaping `/spaces/`), and a punctuation-only ID maps
// to an empty embed subdomain.
const HAS_ALPHANUMERIC = /[a-zA-Z0-9]/;

/**
 * Return true when `spaceId` looks like `owner/name` and both segments
 * contain at least one alphanumeric character.
 */
function isValidSpaceId(spaceId: string): boolean {
  if (!SPACE_ID_PATTERN.test(spaceId)) return false;
  return spaceId.split('/').every((segment) => HAS_ALPHANUMERIC.test(segment));
}

/** HF Space embed subdomain: 'user/space_name' → 'user-space-name'. */
function spaceIdToSubdomain(spaceId: string): string {
  return spaceId
    .toLowerCase()
    // `/`, `_` and `.` all become `-` in the subdomain.
    .replace(/[/_.]/g, '-')
    // Collapse runs of dashes produced by adjacent separators.
    .replace(/-+/g, '-')
    // Drop a single leading and/or trailing dash.
    .replace(/^-|-$/g, '');
}

/**
 * Build the iframe URL for a Trackio dashboard hosted on `spaceId`.
 * `project`, when given, is appended as the `project` query parameter.
 */
function buildTrackioEmbedUrl(spaceId: string, project?: string): string {
  // __theme=dark is gradio's standard query param to force the embedded
  // dashboard into dark mode so it blends with the surrounding chat instead
  // of flashing a bright white panel inside the dark UI.
  const params = new URLSearchParams({
    sidebar: 'hidden',
    footer: 'false',
    __theme: 'dark',
  });
  if (project) params.set('project', project);
  return `https://${spaceIdToSubdomain(spaceId)}.hf.space/?${params.toString()}`;
}

/**
 * Build the huggingface.co page URL for `spaceId`, with an optional
 * `project` query parameter. `spaceId` is interpolated as-is, so callers
 * should check it with `isValidSpaceId` first.
 */
function buildTrackioPageUrl(spaceId: string, project?: string): string {
  const qs = project ? `?${new URLSearchParams({ project }).toString()}` : '';
  return `https://huggingface.co/spaces/${spaceId}${qs}`;
}