Add Slack gateway (#116)
Browse files* Add messaging gateway and notify tool
Co-authored-by: OpenAI Codex <codex@openai.com>
* Handle Bedrock streaming permission denials
Co-authored-by: OpenAI Codex <codex@openai.com>
* Revert "Handle Bedrock streaming permission denials"
Co-authored-by: OpenAI Codex <codex@openai.com>
* Add automatic completion notifications
Co-authored-by: OpenAI Codex <codex@openai.com>
* Auto-attach CLI notification destinations
Co-authored-by: OpenAI Codex <codex@openai.com>
* Defer CLI completion notifications until render
Co-authored-by: OpenAI Codex <codex@openai.com>
* Increase completion notification summary cap
Co-authored-by: OpenAI Codex <codex@openai.com>
* Add Slack user notification defaults
Co-authored-by: Codex <codex@openai.com>
* Increase Slack turn completion summary limit
Co-authored-by: Codex <codex@openai.com>
* Remove legacy auto notification event upgrade
Co-authored-by: Codex <codex@openai.com>
* Require session config instead of hard-coded model fallback
Co-authored-by: Codex <codex@openai.com>
* Address Slack notification review findings
Co-authored-by: Codex <codex@openai.com>
* Fix Anthropic thinking signature replay
Rebuild signed Anthropic thinking blocks from streaming chunks instead of replaying raw deltas, and recover stale histories by retrying once without thinking metadata when Anthropic rejects a signature.
Co-authored-by: OpenAI Codex <codex@openai.com>
* Format Slack notifications with mrkdwn
Convert common Markdown constructs in Slack notification bodies to Slack mrkdwn before posting, while preserving code spans and fenced code blocks.
Co-authored-by: OpenAI Codex <codex@openai.com>
---------
Co-authored-by: OpenAI Codex <codex@openai.com>
- README.md +50 -0
- agent/config.py +108 -3
- agent/core/agent_loop.py +132 -14
- agent/core/session.py +125 -5
- agent/core/tools.py +7 -0
- agent/main.py +24 -2
- agent/messaging/__init__.py +15 -0
- agent/messaging/base.py +27 -0
- agent/messaging/gateway.py +166 -0
- agent/messaging/models.py +123 -0
- agent/messaging/slack.py +186 -0
- agent/prompts/system_prompt_v3.yaml +1 -0
- agent/tools/notify_tool.py +108 -0
- backend/main.py +3 -2
- backend/models.py +8 -1
- backend/routes/agent.py +21 -1
- backend/session_manager.py +46 -1
- configs/cli_agent_config.json +5 -0
- pyproject.toml +1 -1
- tests/unit/test_cli_rendering.py +1 -1
- tests/unit/test_config.py +121 -0
- tests/unit/test_messaging.py +511 -0
- tests/unit/test_thinking_history.py +51 -6
- uv.lock +1 -1
|
@@ -56,6 +56,56 @@ ml-intern --max-iterations 100 "your prompt"
|
|
| 56 |
ml-intern --no-stream "your prompt"
|
| 57 |
```
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
## Architecture
|
| 60 |
|
| 61 |
### Component Overview
|
|
|
|
| 56 |
ml-intern --no-stream "your prompt"
|
| 57 |
```
|
| 58 |
|
| 59 |
+
## Supported Gateways
|
| 60 |
+
|
| 61 |
+
ML Intern currently supports one-way notification gateways from CLI sessions.
|
| 62 |
+
These gateways send out-of-band status updates; they do not accept inbound chat
|
| 63 |
+
messages.
|
| 64 |
+
|
| 65 |
+
### Slack
|
| 66 |
+
|
| 67 |
+
Slack notifications use the Slack Web API to post messages when the agent needs
|
| 68 |
+
approval, hits an error, or completes a turn. Create a Slack app with a bot token
|
| 69 |
+
that has `chat:write`, invite the bot to the target channel, then set:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
SLACK_BOT_TOKEN=xoxb-...
|
| 73 |
+
SLACK_CHANNEL_ID=C...
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
The CLI automatically creates a `slack.default` destination when both variables
|
| 77 |
+
are present. Optional environment variables for the env-only default:
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
ML_INTERN_SLACK_NOTIFICATIONS=false
|
| 81 |
+
ML_INTERN_SLACK_DESTINATION=slack.ops
|
| 82 |
+
ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete
|
| 83 |
+
ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true
|
| 84 |
+
ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
For a persistent user-level config, put overrides in
|
| 88 |
+
`~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a
|
| 89 |
+
JSON file:
|
| 90 |
+
|
| 91 |
+
```json
|
| 92 |
+
{
|
| 93 |
+
"messaging": {
|
| 94 |
+
"enabled": true,
|
| 95 |
+
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
| 96 |
+
"destinations": {
|
| 97 |
+
"slack.ops": {
|
| 98 |
+
"provider": "slack",
|
| 99 |
+
"token": "${SLACK_BOT_TOKEN}",
|
| 100 |
+
"channel": "${SLACK_CHANNEL_ID}",
|
| 101 |
+
"allow_agent_tool": true,
|
| 102 |
+
"allow_auto_events": true
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
## Architecture
|
| 110 |
|
| 111 |
### Component Overview
|
|
@@ -6,6 +6,8 @@ from typing import Any, Union
|
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
|
|
|
|
|
|
| 9 |
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 10 |
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 11 |
from fastmcp.mcp_config import (
|
|
@@ -47,6 +49,104 @@ class Config(BaseModel):
|
|
| 47 |
# ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
|
| 48 |
# Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
|
| 49 |
reasoning_effort: str | None = "max"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
def substitute_env_vars(obj: Any) -> Any:
|
|
@@ -86,7 +186,10 @@ def substitute_env_vars(obj: Any) -> Any:
|
|
| 86 |
return obj
|
| 87 |
|
| 88 |
|
| 89 |
-
def load_config(
|
|
|
|
|
|
|
|
|
|
| 90 |
"""
|
| 91 |
Load configuration with environment variable substitution.
|
| 92 |
|
|
@@ -98,8 +201,10 @@ def load_config(config_path: str = "config.json") -> Config:
|
|
| 98 |
load_dotenv(_PROJECT_ROOT / ".env")
|
| 99 |
load_dotenv(override=False)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
config_with_env = substitute_env_vars(raw_config)
|
| 105 |
return Config.model_validate(config_with_env)
|
|
|
|
| 6 |
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
| 9 |
+
from agent.messaging.models import MessagingConfig
|
| 10 |
+
|
| 11 |
# Project root: two levels up from this file (agent/config.py -> project root)
|
| 12 |
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 13 |
from fastmcp.mcp_config import (
|
|
|
|
| 49 |
# ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
|
| 50 |
# Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
|
| 51 |
reasoning_effort: str | None = "max"
|
| 52 |
+
messaging: MessagingConfig = MessagingConfig()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
|
| 56 |
+
DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
|
| 57 |
+
SLACK_DEFAULT_DESTINATION = "slack.default"
|
| 58 |
+
SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
| 62 |
+
merged = dict(base)
|
| 63 |
+
for key, value in override.items():
|
| 64 |
+
current = merged.get(key)
|
| 65 |
+
if isinstance(current, dict) and isinstance(value, dict):
|
| 66 |
+
merged[key] = _deep_merge_config(current, value)
|
| 67 |
+
else:
|
| 68 |
+
merged[key] = value
|
| 69 |
+
return merged
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _load_json_config(path: Path) -> dict[str, Any]:
|
| 73 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 74 |
+
data = json.load(f)
|
| 75 |
+
if not isinstance(data, dict):
|
| 76 |
+
raise ValueError(f"Config file {path} must contain a JSON object")
|
| 77 |
+
return data
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _load_user_config() -> dict[str, Any]:
|
| 81 |
+
raw_path = os.environ.get(USER_CONFIG_ENV_VAR)
|
| 82 |
+
if raw_path:
|
| 83 |
+
path = Path(raw_path).expanduser()
|
| 84 |
+
if not path.exists():
|
| 85 |
+
raise FileNotFoundError(
|
| 86 |
+
f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}"
|
| 87 |
+
)
|
| 88 |
+
return _load_json_config(path)
|
| 89 |
+
|
| 90 |
+
if DEFAULT_USER_CONFIG_PATH.exists():
|
| 91 |
+
return _load_json_config(DEFAULT_USER_CONFIG_PATH)
|
| 92 |
+
return {}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _env_bool(name: str, default: bool) -> bool:
|
| 96 |
+
value = os.environ.get(name)
|
| 97 |
+
if value is None:
|
| 98 |
+
return default
|
| 99 |
+
normalized = value.strip().lower()
|
| 100 |
+
if normalized in {"1", "true", "yes", "on"}:
|
| 101 |
+
return True
|
| 102 |
+
if normalized in {"0", "false", "no", "off"}:
|
| 103 |
+
return False
|
| 104 |
+
return default
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _env_list(name: str) -> list[str] | None:
|
| 108 |
+
value = os.environ.get(name)
|
| 109 |
+
if value is None:
|
| 110 |
+
return None
|
| 111 |
+
return [item.strip() for item in value.split(",") if item.strip()]
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]:
|
| 115 |
+
"""Enable a default Slack destination from user env vars, when present."""
|
| 116 |
+
if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True):
|
| 117 |
+
return raw_config
|
| 118 |
+
|
| 119 |
+
token = os.environ.get("SLACK_BOT_TOKEN")
|
| 120 |
+
channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL")
|
| 121 |
+
if not token or not channel:
|
| 122 |
+
return raw_config
|
| 123 |
+
|
| 124 |
+
config = dict(raw_config)
|
| 125 |
+
messaging = dict(config.get("messaging") or {})
|
| 126 |
+
destinations = dict(messaging.get("destinations") or {})
|
| 127 |
+
destination_name = (
|
| 128 |
+
os.environ.get("ML_INTERN_SLACK_DESTINATION") or SLACK_DEFAULT_DESTINATION
|
| 129 |
+
).strip()
|
| 130 |
+
|
| 131 |
+
if destination_name not in destinations:
|
| 132 |
+
destinations[destination_name] = {
|
| 133 |
+
"provider": "slack",
|
| 134 |
+
"token": token,
|
| 135 |
+
"channel": channel,
|
| 136 |
+
"allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True),
|
| 137 |
+
"allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS")
|
| 141 |
+
if auto_events is not None:
|
| 142 |
+
messaging["auto_event_types"] = auto_events
|
| 143 |
+
elif "auto_event_types" not in messaging:
|
| 144 |
+
messaging["auto_event_types"] = SLACK_DEFAULT_AUTO_EVENT_TYPES
|
| 145 |
+
|
| 146 |
+
messaging["enabled"] = True
|
| 147 |
+
messaging["destinations"] = destinations
|
| 148 |
+
config["messaging"] = messaging
|
| 149 |
+
return config
|
| 150 |
|
| 151 |
|
| 152 |
def substitute_env_vars(obj: Any) -> Any:
|
|
|
|
| 186 |
return obj
|
| 187 |
|
| 188 |
|
| 189 |
+
def load_config(
|
| 190 |
+
config_path: str = "config.json",
|
| 191 |
+
include_user_defaults: bool = False,
|
| 192 |
+
) -> Config:
|
| 193 |
"""
|
| 194 |
Load configuration with environment variable substitution.
|
| 195 |
|
|
|
|
| 201 |
load_dotenv(_PROJECT_ROOT / ".env")
|
| 202 |
load_dotenv(override=False)
|
| 203 |
|
| 204 |
+
raw_config = _load_json_config(Path(config_path))
|
| 205 |
+
if include_user_defaults:
|
| 206 |
+
raw_config = _deep_merge_config(raw_config, _load_user_config())
|
| 207 |
+
raw_config = apply_slack_user_defaults(raw_config)
|
| 208 |
|
| 209 |
config_with_env = substitute_env_vars(raw_config)
|
| 210 |
return Config.model_validate(config_with_env)
|
|
@@ -19,6 +19,7 @@ from litellm import (
|
|
| 19 |
from litellm.exceptions import ContextWindowExceededError
|
| 20 |
|
| 21 |
from agent.config import Config
|
|
|
|
| 22 |
from agent.core import telemetry
|
| 23 |
from agent.core.doom_loop import check_for_doom_loop
|
| 24 |
from agent.core.llm_params import _resolve_llm_params
|
|
@@ -432,6 +433,103 @@ def _should_replay_thinking_state(model_name: str | None) -> bool:
|
|
| 432 |
return bool(model_name and model_name.startswith("anthropic/"))
|
| 433 |
|
| 434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
def _assistant_message_from_result(
|
| 436 |
llm_result: LLMResult,
|
| 437 |
*,
|
|
@@ -457,6 +555,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 457 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 458 |
response = None
|
| 459 |
_healed_effort = False # one-shot safety net per call
|
|
|
|
| 460 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 461 |
t_start = time.monotonic()
|
| 462 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
@@ -484,6 +583,14 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 484 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 485 |
))
|
| 486 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 488 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 489 |
logger.warning(
|
|
@@ -505,8 +612,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 505 |
final_usage_chunk = None
|
| 506 |
chunks = []
|
| 507 |
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
|
| 508 |
-
collected_thinking_blocks: list[dict[str, Any]] = []
|
| 509 |
-
collected_reasoning_content: list[str] = []
|
| 510 |
|
| 511 |
async for chunk in response:
|
| 512 |
chunks.append(chunk)
|
|
@@ -525,13 +630,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 525 |
if choice.finish_reason:
|
| 526 |
finish_reason = choice.finish_reason
|
| 527 |
|
| 528 |
-
if should_replay_thinking:
|
| 529 |
-
delta_thinking_blocks, delta_reasoning_content = _extract_thinking_state(delta)
|
| 530 |
-
if delta_thinking_blocks:
|
| 531 |
-
collected_thinking_blocks.extend(delta_thinking_blocks)
|
| 532 |
-
if delta_reasoning_content:
|
| 533 |
-
collected_reasoning_content.append(delta_reasoning_content)
|
| 534 |
-
|
| 535 |
if delta.content:
|
| 536 |
full_content += delta.content
|
| 537 |
await session.send_event(
|
|
@@ -565,9 +663,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
|
|
| 565 |
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 566 |
finish_reason=finish_reason,
|
| 567 |
)
|
| 568 |
-
thinking_blocks =
|
| 569 |
-
reasoning_content =
|
| 570 |
-
if chunks and should_replay_thinking
|
| 571 |
try:
|
| 572 |
rebuilt = stream_chunk_builder(chunks, messages=messages)
|
| 573 |
if rebuilt and getattr(rebuilt, "choices", None):
|
|
@@ -591,6 +689,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 591 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 592 |
response = None
|
| 593 |
_healed_effort = False
|
|
|
|
| 594 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 595 |
t_start = time.monotonic()
|
| 596 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
@@ -617,6 +716,14 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
|
|
| 617 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 618 |
))
|
| 619 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 621 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 622 |
logger.warning(
|
|
@@ -1128,7 +1235,12 @@ class Handlers:
|
|
| 1128 |
await session.send_event(
|
| 1129 |
Event(
|
| 1130 |
event_type="turn_complete",
|
| 1131 |
-
data={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1132 |
)
|
| 1133 |
)
|
| 1134 |
|
|
@@ -1437,13 +1549,16 @@ async def process_submission(session: Session, submission) -> bool:
|
|
| 1437 |
async def submission_loop(
|
| 1438 |
submission_queue: asyncio.Queue,
|
| 1439 |
event_queue: asyncio.Queue,
|
| 1440 |
-
config: Config
|
| 1441 |
tool_router: ToolRouter | None = None,
|
| 1442 |
session_holder: list | None = None,
|
| 1443 |
hf_token: str | None = None,
|
| 1444 |
user_id: str | None = None,
|
| 1445 |
local_mode: bool = False,
|
| 1446 |
stream: bool = True,
|
|
|
|
|
|
|
|
|
|
| 1447 |
) -> None:
|
| 1448 |
"""
|
| 1449 |
Main agent loop - processes submissions and dispatches to handlers.
|
|
@@ -1454,6 +1569,9 @@ async def submission_loop(
|
|
| 1454 |
session = Session(
|
| 1455 |
event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
|
| 1456 |
user_id=user_id, local_mode=local_mode, stream=stream,
|
|
|
|
|
|
|
|
|
|
| 1457 |
)
|
| 1458 |
if session_holder is not None:
|
| 1459 |
session_holder[0] = session
|
|
|
|
| 19 |
from litellm.exceptions import ContextWindowExceededError
|
| 20 |
|
| 21 |
from agent.config import Config
|
| 22 |
+
from agent.messaging.gateway import NotificationGateway
|
| 23 |
from agent.core import telemetry
|
| 24 |
from agent.core.doom_loop import check_for_doom_loop
|
| 25 |
from agent.core.llm_params import _resolve_llm_params
|
|
|
|
| 433 |
return bool(model_name and model_name.startswith("anthropic/"))
|
| 434 |
|
| 435 |
|
| 436 |
+
def _is_invalid_thinking_signature_error(exc: Exception) -> bool:
|
| 437 |
+
"""Return True when Anthropic rejected replayed extended-thinking state."""
|
| 438 |
+
text = str(exc)
|
| 439 |
+
return (
|
| 440 |
+
"Invalid `signature` in `thinking` block" in text
|
| 441 |
+
or "Invalid signature in thinking block" in text
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def _strip_thinking_state_from_messages(messages: list[Any]) -> int:
|
| 446 |
+
"""Remove replayed thinking metadata from assistant history messages."""
|
| 447 |
+
stripped = 0
|
| 448 |
+
|
| 449 |
+
for message in messages:
|
| 450 |
+
role = (
|
| 451 |
+
message.get("role")
|
| 452 |
+
if isinstance(message, dict)
|
| 453 |
+
else getattr(message, "role", None)
|
| 454 |
+
)
|
| 455 |
+
if role != "assistant":
|
| 456 |
+
continue
|
| 457 |
+
|
| 458 |
+
if isinstance(message, dict):
|
| 459 |
+
if message.pop("thinking_blocks", None) is not None:
|
| 460 |
+
stripped += 1
|
| 461 |
+
if message.pop("reasoning_content", None) is not None:
|
| 462 |
+
stripped += 1
|
| 463 |
+
provider_fields = message.get("provider_specific_fields")
|
| 464 |
+
content = message.get("content")
|
| 465 |
+
else:
|
| 466 |
+
if getattr(message, "thinking_blocks", None) is not None:
|
| 467 |
+
message.thinking_blocks = None
|
| 468 |
+
stripped += 1
|
| 469 |
+
if getattr(message, "reasoning_content", None) is not None:
|
| 470 |
+
message.reasoning_content = None
|
| 471 |
+
stripped += 1
|
| 472 |
+
provider_fields = getattr(message, "provider_specific_fields", None)
|
| 473 |
+
content = getattr(message, "content", None)
|
| 474 |
+
|
| 475 |
+
if isinstance(provider_fields, dict):
|
| 476 |
+
cleaned_fields = dict(provider_fields)
|
| 477 |
+
if cleaned_fields.pop("thinking_blocks", None) is not None:
|
| 478 |
+
stripped += 1
|
| 479 |
+
if cleaned_fields.pop("reasoning_content", None) is not None:
|
| 480 |
+
stripped += 1
|
| 481 |
+
if cleaned_fields != provider_fields:
|
| 482 |
+
if isinstance(message, dict):
|
| 483 |
+
message["provider_specific_fields"] = cleaned_fields
|
| 484 |
+
else:
|
| 485 |
+
message.provider_specific_fields = cleaned_fields
|
| 486 |
+
|
| 487 |
+
if isinstance(content, list):
|
| 488 |
+
cleaned_content = [
|
| 489 |
+
block
|
| 490 |
+
for block in content
|
| 491 |
+
if not (
|
| 492 |
+
isinstance(block, dict)
|
| 493 |
+
and block.get("type") in {"thinking", "redacted_thinking"}
|
| 494 |
+
)
|
| 495 |
+
]
|
| 496 |
+
if len(cleaned_content) != len(content):
|
| 497 |
+
stripped += len(content) - len(cleaned_content)
|
| 498 |
+
if isinstance(message, dict):
|
| 499 |
+
message["content"] = cleaned_content
|
| 500 |
+
else:
|
| 501 |
+
message.content = cleaned_content
|
| 502 |
+
|
| 503 |
+
return stripped
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
async def _maybe_heal_invalid_thinking_signature(
|
| 507 |
+
session: Session,
|
| 508 |
+
messages: list[Any],
|
| 509 |
+
exc: Exception,
|
| 510 |
+
*,
|
| 511 |
+
already_healed: bool,
|
| 512 |
+
) -> bool:
|
| 513 |
+
if already_healed or not _is_invalid_thinking_signature_error(exc):
|
| 514 |
+
return False
|
| 515 |
+
|
| 516 |
+
stripped = _strip_thinking_state_from_messages(messages)
|
| 517 |
+
if not stripped:
|
| 518 |
+
return False
|
| 519 |
+
|
| 520 |
+
await session.send_event(Event(
|
| 521 |
+
event_type="tool_log",
|
| 522 |
+
data={
|
| 523 |
+
"tool": "system",
|
| 524 |
+
"log": (
|
| 525 |
+
"Anthropic rejected stale thinking signatures; retrying "
|
| 526 |
+
"without replayed thinking metadata."
|
| 527 |
+
),
|
| 528 |
+
},
|
| 529 |
+
))
|
| 530 |
+
return True
|
| 531 |
+
|
| 532 |
+
|
| 533 |
def _assistant_message_from_result(
|
| 534 |
llm_result: LLMResult,
|
| 535 |
*,
|
|
|
|
| 555 |
"""Call the LLM with streaming, emitting assistant_chunk events."""
|
| 556 |
response = None
|
| 557 |
_healed_effort = False # one-shot safety net per call
|
| 558 |
+
_healed_thinking_signature = False
|
| 559 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 560 |
t_start = time.monotonic()
|
| 561 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
|
|
| 583 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 584 |
))
|
| 585 |
continue
|
| 586 |
+
if await _maybe_heal_invalid_thinking_signature(
|
| 587 |
+
session,
|
| 588 |
+
messages,
|
| 589 |
+
e,
|
| 590 |
+
already_healed=_healed_thinking_signature,
|
| 591 |
+
):
|
| 592 |
+
_healed_thinking_signature = True
|
| 593 |
+
continue
|
| 594 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 595 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 596 |
logger.warning(
|
|
|
|
| 612 |
final_usage_chunk = None
|
| 613 |
chunks = []
|
| 614 |
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
|
|
|
|
|
|
|
| 615 |
|
| 616 |
async for chunk in response:
|
| 617 |
chunks.append(chunk)
|
|
|
|
| 630 |
if choice.finish_reason:
|
| 631 |
finish_reason = choice.finish_reason
|
| 632 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
if delta.content:
|
| 634 |
full_content += delta.content
|
| 635 |
await session.send_event(
|
|
|
|
| 663 |
latency_ms=int((time.monotonic() - t_start) * 1000),
|
| 664 |
finish_reason=finish_reason,
|
| 665 |
)
|
| 666 |
+
thinking_blocks = None
|
| 667 |
+
reasoning_content = None
|
| 668 |
+
if chunks and should_replay_thinking:
|
| 669 |
try:
|
| 670 |
rebuilt = stream_chunk_builder(chunks, messages=messages)
|
| 671 |
if rebuilt and getattr(rebuilt, "choices", None):
|
|
|
|
| 689 |
"""Call the LLM without streaming, emit assistant_message at the end."""
|
| 690 |
response = None
|
| 691 |
_healed_effort = False
|
| 692 |
+
_healed_thinking_signature = False
|
| 693 |
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
|
| 694 |
t_start = time.monotonic()
|
| 695 |
for _llm_attempt in range(_MAX_LLM_RETRIES):
|
|
|
|
| 716 |
data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
|
| 717 |
))
|
| 718 |
continue
|
| 719 |
+
if await _maybe_heal_invalid_thinking_signature(
|
| 720 |
+
session,
|
| 721 |
+
messages,
|
| 722 |
+
e,
|
| 723 |
+
already_healed=_healed_thinking_signature,
|
| 724 |
+
):
|
| 725 |
+
_healed_thinking_signature = True
|
| 726 |
+
continue
|
| 727 |
_delay = _retry_delay_for(e, _llm_attempt)
|
| 728 |
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
|
| 729 |
logger.warning(
|
|
|
|
| 1235 |
await session.send_event(
|
| 1236 |
Event(
|
| 1237 |
event_type="turn_complete",
|
| 1238 |
+
data={
|
| 1239 |
+
"history_size": len(session.context_manager.items),
|
| 1240 |
+
"final_response": final_response
|
| 1241 |
+
if isinstance(final_response, str)
|
| 1242 |
+
else None,
|
| 1243 |
+
},
|
| 1244 |
)
|
| 1245 |
)
|
| 1246 |
|
|
|
|
| 1549 |
async def submission_loop(
|
| 1550 |
submission_queue: asyncio.Queue,
|
| 1551 |
event_queue: asyncio.Queue,
|
| 1552 |
+
config: Config,
|
| 1553 |
tool_router: ToolRouter | None = None,
|
| 1554 |
session_holder: list | None = None,
|
| 1555 |
hf_token: str | None = None,
|
| 1556 |
user_id: str | None = None,
|
| 1557 |
local_mode: bool = False,
|
| 1558 |
stream: bool = True,
|
| 1559 |
+
notification_gateway: NotificationGateway | None = None,
|
| 1560 |
+
notification_destinations: list[str] | None = None,
|
| 1561 |
+
defer_turn_complete_notification: bool = False,
|
| 1562 |
) -> None:
|
| 1563 |
"""
|
| 1564 |
Main agent loop - processes submissions and dispatches to handlers.
|
|
|
|
| 1569 |
session = Session(
|
| 1570 |
event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
|
| 1571 |
user_id=user_id, local_mode=local_mode, stream=stream,
|
| 1572 |
+
notification_gateway=notification_gateway,
|
| 1573 |
+
notification_destinations=notification_destinations,
|
| 1574 |
+
defer_turn_complete_notification=defer_turn_complete_notification,
|
| 1575 |
)
|
| 1576 |
if session_holder is not None:
|
| 1577 |
session_holder[0] = session
|
|
@@ -12,10 +12,13 @@ from typing import Any, Optional
|
|
| 12 |
|
| 13 |
from agent.config import Config
|
| 14 |
from agent.context_manager.manager import ContextManager
|
|
|
|
|
|
|
| 15 |
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
_DEFAULT_MAX_TOKENS = 200_000
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def _get_max_tokens_safe(model_name: str) -> int:
|
|
@@ -73,18 +76,24 @@ class Session:
|
|
| 73 |
def __init__(
|
| 74 |
self,
|
| 75 |
event_queue: asyncio.Queue,
|
| 76 |
-
config: Config
|
| 77 |
tool_router=None,
|
| 78 |
context_manager: ContextManager | None = None,
|
| 79 |
hf_token: str | None = None,
|
| 80 |
local_mode: bool = False,
|
| 81 |
stream: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
user_id: str | None = None,
|
| 83 |
):
|
| 84 |
self.hf_token: Optional[str] = hf_token
|
| 85 |
self.user_id: Optional[str] = user_id
|
| 86 |
self.tool_router = tool_router
|
| 87 |
self.stream = stream
|
|
|
|
|
|
|
| 88 |
tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
|
| 89 |
self.context_manager = context_manager or ContextManager(
|
| 90 |
model_max_tokens=_get_max_tokens_safe(config.model_name),
|
|
@@ -95,15 +104,16 @@ class Session:
|
|
| 95 |
local_mode=local_mode,
|
| 96 |
)
|
| 97 |
self.event_queue = event_queue
|
| 98 |
-
self.session_id = str(uuid.uuid4())
|
| 99 |
-
self.config = config
|
| 100 |
-
model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
| 101 |
-
)
|
| 102 |
self.is_running = True
|
| 103 |
self._cancelled = asyncio.Event()
|
| 104 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 105 |
self.sandbox = None
|
| 106 |
self._running_job_ids: set[str] = set() # HF job IDs currently executing
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# Session trajectory logging
|
| 109 |
self.logged_events: list[dict] = []
|
|
@@ -138,11 +148,121 @@ class Session:
|
|
| 138 |
"data": event.data,
|
| 139 |
}
|
| 140 |
)
|
|
|
|
| 141 |
|
| 142 |
# Mid-turn heartbeat flush (owned by telemetry module).
|
| 143 |
from agent.core.telemetry import HeartbeatSaver
|
|
|
|
| 144 |
HeartbeatSaver.maybe_fire(self)
|
| 145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
def cancel(self) -> None:
|
| 147 |
"""Signal cancellation to the running agent loop."""
|
| 148 |
self._cancelled.set()
|
|
|
|
| 12 |
|
| 13 |
from agent.config import Config
|
| 14 |
from agent.context_manager.manager import ContextManager
|
| 15 |
+
from agent.messaging.gateway import NotificationGateway
|
| 16 |
+
from agent.messaging.models import NotificationRequest
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
_DEFAULT_MAX_TOKENS = 200_000
|
| 21 |
+
_TURN_COMPLETE_NOTIFICATION_CHARS = 39000
|
| 22 |
|
| 23 |
|
| 24 |
def _get_max_tokens_safe(model_name: str) -> int:
|
|
|
|
| 76 |
def __init__(
|
| 77 |
self,
|
| 78 |
event_queue: asyncio.Queue,
|
| 79 |
+
config: Config,
|
| 80 |
tool_router=None,
|
| 81 |
context_manager: ContextManager | None = None,
|
| 82 |
hf_token: str | None = None,
|
| 83 |
local_mode: bool = False,
|
| 84 |
stream: bool = True,
|
| 85 |
+
notification_gateway: NotificationGateway | None = None,
|
| 86 |
+
notification_destinations: list[str] | None = None,
|
| 87 |
+
defer_turn_complete_notification: bool = False,
|
| 88 |
+
session_id: str | None = None,
|
| 89 |
user_id: str | None = None,
|
| 90 |
):
|
| 91 |
self.hf_token: Optional[str] = hf_token
|
| 92 |
self.user_id: Optional[str] = user_id
|
| 93 |
self.tool_router = tool_router
|
| 94 |
self.stream = stream
|
| 95 |
+
if config is None:
|
| 96 |
+
raise ValueError("Session requires a Config")
|
| 97 |
tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
|
| 98 |
self.context_manager = context_manager or ContextManager(
|
| 99 |
model_max_tokens=_get_max_tokens_safe(config.model_name),
|
|
|
|
| 104 |
local_mode=local_mode,
|
| 105 |
)
|
| 106 |
self.event_queue = event_queue
|
| 107 |
+
self.session_id = session_id or str(uuid.uuid4())
|
| 108 |
+
self.config = config
|
|
|
|
|
|
|
| 109 |
self.is_running = True
|
| 110 |
self._cancelled = asyncio.Event()
|
| 111 |
self.pending_approval: Optional[dict[str, Any]] = None
|
| 112 |
self.sandbox = None
|
| 113 |
self._running_job_ids: set[str] = set() # HF job IDs currently executing
|
| 114 |
+
self.notification_gateway = notification_gateway
|
| 115 |
+
self.notification_destinations = list(notification_destinations or [])
|
| 116 |
+
self.defer_turn_complete_notification = defer_turn_complete_notification
|
| 117 |
|
| 118 |
# Session trajectory logging
|
| 119 |
self.logged_events: list[dict] = []
|
|
|
|
| 148 |
"data": event.data,
|
| 149 |
}
|
| 150 |
)
|
| 151 |
+
await self._enqueue_auto_notification_requests(event)
|
| 152 |
|
| 153 |
# Mid-turn heartbeat flush (owned by telemetry module).
|
| 154 |
from agent.core.telemetry import HeartbeatSaver
|
| 155 |
+
|
| 156 |
HeartbeatSaver.maybe_fire(self)
|
| 157 |
|
| 158 |
+
def set_notification_destinations(self, destinations: list[str]) -> None:
|
| 159 |
+
"""Replace the session's opted-in auto-notification destinations."""
|
| 160 |
+
deduped: list[str] = []
|
| 161 |
+
seen: set[str] = set()
|
| 162 |
+
for destination in destinations:
|
| 163 |
+
if destination not in seen:
|
| 164 |
+
deduped.append(destination)
|
| 165 |
+
seen.add(destination)
|
| 166 |
+
self.notification_destinations = deduped
|
| 167 |
+
|
| 168 |
+
async def send_deferred_turn_complete_notification(self, event: Event) -> None:
|
| 169 |
+
if event.event_type != "turn_complete":
|
| 170 |
+
return
|
| 171 |
+
await self._enqueue_auto_notification_requests(
|
| 172 |
+
event,
|
| 173 |
+
include_deferred_turn_complete=True,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
async def _enqueue_auto_notification_requests(
|
| 177 |
+
self,
|
| 178 |
+
event: Event,
|
| 179 |
+
include_deferred_turn_complete: bool = False,
|
| 180 |
+
) -> None:
|
| 181 |
+
if self.notification_gateway is None:
|
| 182 |
+
return
|
| 183 |
+
if not self.notification_destinations:
|
| 184 |
+
return
|
| 185 |
+
auto_events = set(self.config.messaging.auto_event_types)
|
| 186 |
+
if event.event_type not in auto_events:
|
| 187 |
+
return
|
| 188 |
+
if (
|
| 189 |
+
self.defer_turn_complete_notification
|
| 190 |
+
and event.event_type == "turn_complete"
|
| 191 |
+
and not include_deferred_turn_complete
|
| 192 |
+
):
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
+
requests = self._build_auto_notification_requests(event)
|
| 196 |
+
for request in requests:
|
| 197 |
+
await self.notification_gateway.enqueue(request)
|
| 198 |
+
|
| 199 |
+
def _build_auto_notification_requests(
|
| 200 |
+
self, event: Event
|
| 201 |
+
) -> list[NotificationRequest]:
|
| 202 |
+
metadata = {
|
| 203 |
+
"session_id": self.session_id,
|
| 204 |
+
"model": self.config.model_name,
|
| 205 |
+
"event_type": event.event_type,
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
title: str | None = None
|
| 209 |
+
message: str | None = None
|
| 210 |
+
severity = "info"
|
| 211 |
+
data = event.data or {}
|
| 212 |
+
if event.event_type == "approval_required":
|
| 213 |
+
tools = data.get("tools", [])
|
| 214 |
+
tool_names = []
|
| 215 |
+
for tool in tools if isinstance(tools, list) else []:
|
| 216 |
+
if isinstance(tool, dict):
|
| 217 |
+
tool_name = str(tool.get("tool") or "").strip()
|
| 218 |
+
if tool_name and tool_name not in tool_names:
|
| 219 |
+
tool_names.append(tool_name)
|
| 220 |
+
count = len(tools) if isinstance(tools, list) else 0
|
| 221 |
+
title = "Agent approval required"
|
| 222 |
+
message = (
|
| 223 |
+
f"Session {self.session_id} is waiting for approval "
|
| 224 |
+
f"for {count} tool call(s)."
|
| 225 |
+
)
|
| 226 |
+
if tool_names:
|
| 227 |
+
message += " Tools: " + ", ".join(tool_names)
|
| 228 |
+
severity = "warning"
|
| 229 |
+
elif event.event_type == "error":
|
| 230 |
+
title = "Agent error"
|
| 231 |
+
error = str(data.get("error") or "Unknown error")
|
| 232 |
+
message = f"Session {self.session_id} hit an error.\n{error[:500]}"
|
| 233 |
+
severity = "error"
|
| 234 |
+
elif event.event_type == "turn_complete":
|
| 235 |
+
title = "Agent task complete"
|
| 236 |
+
summary = str(data.get("final_response") or "").strip()
|
| 237 |
+
if summary:
|
| 238 |
+
summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
|
| 239 |
+
message = (
|
| 240 |
+
f"Session {self.session_id} completed successfully.\n"
|
| 241 |
+
f"{summary}"
|
| 242 |
+
)
|
| 243 |
+
else:
|
| 244 |
+
message = f"Session {self.session_id} completed successfully."
|
| 245 |
+
severity = "success"
|
| 246 |
+
|
| 247 |
+
if message is None:
|
| 248 |
+
return []
|
| 249 |
+
|
| 250 |
+
requests: list[NotificationRequest] = []
|
| 251 |
+
for destination in self.notification_destinations:
|
| 252 |
+
if not self.config.messaging.can_auto_send(destination):
|
| 253 |
+
continue
|
| 254 |
+
requests.append(
|
| 255 |
+
NotificationRequest(
|
| 256 |
+
destination=destination,
|
| 257 |
+
title=title,
|
| 258 |
+
message=message,
|
| 259 |
+
severity=severity,
|
| 260 |
+
metadata=metadata,
|
| 261 |
+
event_type=event.event_type,
|
| 262 |
+
)
|
| 263 |
+
)
|
| 264 |
+
return requests
|
| 265 |
+
|
| 266 |
def cancel(self) -> None:
|
| 267 |
"""Signal cancellation to the running agent loop."""
|
| 268 |
self._cancelled.set()
|
|
@@ -46,6 +46,7 @@ from agent.tools.hf_repo_git_tool import (
|
|
| 46 |
hf_repo_git_handler,
|
| 47 |
)
|
| 48 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
|
|
|
| 49 |
from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
|
| 50 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 51 |
from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
|
|
@@ -324,6 +325,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
|
|
| 324 |
parameters=PLAN_TOOL_SPEC["parameters"],
|
| 325 |
handler=plan_tool_handler,
|
| 326 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
ToolSpec(
|
| 328 |
name=HF_JOBS_TOOL_SPEC["name"],
|
| 329 |
description=HF_JOBS_TOOL_SPEC["description"],
|
|
|
|
| 46 |
hf_repo_git_handler,
|
| 47 |
)
|
| 48 |
from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
|
| 49 |
+
from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler
|
| 50 |
from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
|
| 51 |
from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
|
| 52 |
from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
|
|
|
|
| 325 |
parameters=PLAN_TOOL_SPEC["parameters"],
|
| 326 |
handler=plan_tool_handler,
|
| 327 |
),
|
| 328 |
+
ToolSpec(
|
| 329 |
+
name=NOTIFY_TOOL_SPEC["name"],
|
| 330 |
+
description=NOTIFY_TOOL_SPEC["description"],
|
| 331 |
+
parameters=NOTIFY_TOOL_SPEC["parameters"],
|
| 332 |
+
handler=notify_handler,
|
| 333 |
+
),
|
| 334 |
ToolSpec(
|
| 335 |
name=HF_JOBS_TOOL_SPEC["name"],
|
| 336 |
description=HF_JOBS_TOOL_SPEC["description"],
|
|
@@ -26,6 +26,7 @@ from agent.core import model_switcher
|
|
| 26 |
from agent.core.hf_tokens import resolve_hf_token
|
| 27 |
from agent.core.session import OpType
|
| 28 |
from agent.core.tools import ToolRouter
|
|
|
|
| 29 |
from agent.utils.reliability_checks import check_training_script_save_pattern
|
| 30 |
from agent.utils.terminal_display import (
|
| 31 |
get_console,
|
|
@@ -332,6 +333,9 @@ async def event_listener(
|
|
| 332 |
stream_buf.discard()
|
| 333 |
print_turn_complete()
|
| 334 |
print_plan()
|
|
|
|
|
|
|
|
|
|
| 335 |
turn_complete_event.set()
|
| 336 |
elif event.event_type == "interrupted":
|
| 337 |
shimmer.stop()
|
|
@@ -821,7 +825,7 @@ async def main(model: str | None = None):
|
|
| 821 |
if not hf_token:
|
| 822 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 823 |
|
| 824 |
-
config = load_config(CLI_CONFIG_PATH)
|
| 825 |
if model:
|
| 826 |
config.model_name = model
|
| 827 |
|
|
@@ -844,6 +848,8 @@ async def main(model: str | None = None):
|
|
| 844 |
turn_complete_event.set()
|
| 845 |
ready_event = asyncio.Event()
|
| 846 |
|
|
|
|
|
|
|
| 847 |
# Create tool router with local mode
|
| 848 |
tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
|
| 849 |
|
|
@@ -861,6 +867,9 @@ async def main(model: str | None = None):
|
|
| 861 |
user_id=hf_user,
|
| 862 |
local_mode=True,
|
| 863 |
stream=True,
|
|
|
|
|
|
|
|
|
|
| 864 |
)
|
| 865 |
)
|
| 866 |
|
|
@@ -1016,6 +1025,8 @@ async def main(model: str | None = None):
|
|
| 1016 |
agent_task.cancel()
|
| 1017 |
# Agent didn't shut down cleanly — close MCP explicitly
|
| 1018 |
await tool_router.__aexit__(None, None, None)
|
|
|
|
|
|
|
| 1019 |
|
| 1020 |
# Now safe to cancel the listener (agent is done emitting events)
|
| 1021 |
listener_task.cancel()
|
|
@@ -1042,8 +1053,10 @@ async def headless_main(
|
|
| 1042 |
|
| 1043 |
print(f"HF token loaded", file=sys.stderr)
|
| 1044 |
|
| 1045 |
-
config = load_config(CLI_CONFIG_PATH)
|
| 1046 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
|
|
|
|
|
|
| 1047 |
hf_user = _get_hf_user(hf_token)
|
| 1048 |
|
| 1049 |
if model:
|
|
@@ -1074,6 +1087,9 @@ async def headless_main(
|
|
| 1074 |
user_id=hf_user,
|
| 1075 |
local_mode=True,
|
| 1076 |
stream=stream,
|
|
|
|
|
|
|
|
|
|
| 1077 |
)
|
| 1078 |
)
|
| 1079 |
|
|
@@ -1199,6 +1215,10 @@ async def headless_main(
|
|
| 1199 |
stream_buf.discard()
|
| 1200 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1201 |
print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1202 |
break
|
| 1203 |
|
| 1204 |
# Shutdown
|
|
@@ -1212,6 +1232,8 @@ async def headless_main(
|
|
| 1212 |
except asyncio.TimeoutError:
|
| 1213 |
agent_task.cancel()
|
| 1214 |
await tool_router.__aexit__(None, None, None)
|
|
|
|
|
|
|
| 1215 |
|
| 1216 |
|
| 1217 |
def cli():
|
|
|
|
| 26 |
from agent.core.hf_tokens import resolve_hf_token
|
| 27 |
from agent.core.session import OpType
|
| 28 |
from agent.core.tools import ToolRouter
|
| 29 |
+
from agent.messaging.gateway import NotificationGateway
|
| 30 |
from agent.utils.reliability_checks import check_training_script_save_pattern
|
| 31 |
from agent.utils.terminal_display import (
|
| 32 |
get_console,
|
|
|
|
| 333 |
stream_buf.discard()
|
| 334 |
print_turn_complete()
|
| 335 |
print_plan()
|
| 336 |
+
session = session_holder[0] if session_holder else None
|
| 337 |
+
if session is not None:
|
| 338 |
+
await session.send_deferred_turn_complete_notification(event)
|
| 339 |
turn_complete_event.set()
|
| 340 |
elif event.event_type == "interrupted":
|
| 341 |
shimmer.stop()
|
|
|
|
| 825 |
if not hf_token:
|
| 826 |
hf_token = await _prompt_and_save_hf_token(prompt_session)
|
| 827 |
|
| 828 |
+
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 829 |
if model:
|
| 830 |
config.model_name = model
|
| 831 |
|
|
|
|
| 848 |
turn_complete_event.set()
|
| 849 |
ready_event = asyncio.Event()
|
| 850 |
|
| 851 |
+
notification_gateway = NotificationGateway(config.messaging)
|
| 852 |
+
await notification_gateway.start()
|
| 853 |
# Create tool router with local mode
|
| 854 |
tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
|
| 855 |
|
|
|
|
| 867 |
user_id=hf_user,
|
| 868 |
local_mode=True,
|
| 869 |
stream=True,
|
| 870 |
+
notification_gateway=notification_gateway,
|
| 871 |
+
notification_destinations=config.messaging.default_auto_destinations(),
|
| 872 |
+
defer_turn_complete_notification=True,
|
| 873 |
)
|
| 874 |
)
|
| 875 |
|
|
|
|
| 1025 |
agent_task.cancel()
|
| 1026 |
# Agent didn't shut down cleanly — close MCP explicitly
|
| 1027 |
await tool_router.__aexit__(None, None, None)
|
| 1028 |
+
finally:
|
| 1029 |
+
await notification_gateway.close()
|
| 1030 |
|
| 1031 |
# Now safe to cancel the listener (agent is done emitting events)
|
| 1032 |
listener_task.cancel()
|
|
|
|
| 1053 |
|
| 1054 |
print(f"HF token loaded", file=sys.stderr)
|
| 1055 |
|
| 1056 |
+
config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
|
| 1057 |
config.yolo_mode = True # Auto-approve everything in headless mode
|
| 1058 |
+
notification_gateway = NotificationGateway(config.messaging)
|
| 1059 |
+
await notification_gateway.start()
|
| 1060 |
hf_user = _get_hf_user(hf_token)
|
| 1061 |
|
| 1062 |
if model:
|
|
|
|
| 1087 |
user_id=hf_user,
|
| 1088 |
local_mode=True,
|
| 1089 |
stream=stream,
|
| 1090 |
+
notification_gateway=notification_gateway,
|
| 1091 |
+
notification_destinations=config.messaging.default_auto_destinations(),
|
| 1092 |
+
defer_turn_complete_notification=True,
|
| 1093 |
)
|
| 1094 |
)
|
| 1095 |
|
|
|
|
| 1215 |
stream_buf.discard()
|
| 1216 |
history_size = event.data.get("history_size", "?") if event.data else "?"
|
| 1217 |
print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
|
| 1218 |
+
if event.event_type == "turn_complete":
|
| 1219 |
+
session = session_holder[0] if session_holder else None
|
| 1220 |
+
if session is not None:
|
| 1221 |
+
await session.send_deferred_turn_complete_notification(event)
|
| 1222 |
break
|
| 1223 |
|
| 1224 |
# Shutdown
|
|
|
|
| 1232 |
except asyncio.TimeoutError:
|
| 1233 |
agent_task.cancel()
|
| 1234 |
await tool_router.__aexit__(None, None, None)
|
| 1235 |
+
finally:
|
| 1236 |
+
await notification_gateway.close()
|
| 1237 |
|
| 1238 |
|
| 1239 |
def cli():
|
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agent.messaging.gateway import NotificationGateway
|
| 2 |
+
from agent.messaging.models import (
|
| 3 |
+
MessagingConfig,
|
| 4 |
+
NotificationRequest,
|
| 5 |
+
NotificationResult,
|
| 6 |
+
SUPPORTED_AUTO_EVENT_TYPES,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"MessagingConfig",
|
| 11 |
+
"NotificationGateway",
|
| 12 |
+
"NotificationRequest",
|
| 13 |
+
"NotificationResult",
|
| 14 |
+
"SUPPORTED_AUTO_EVENT_TYPES",
|
| 15 |
+
]
|
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
|
| 5 |
+
from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NotificationError(Exception):
|
| 9 |
+
"""Delivery failed and should not be retried."""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RetryableNotificationError(NotificationError):
|
| 13 |
+
"""Delivery failed transiently and can be retried."""
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NotificationProvider(ABC):
|
| 17 |
+
provider_name: str
|
| 18 |
+
|
| 19 |
+
@abstractmethod
|
| 20 |
+
async def send(
|
| 21 |
+
self,
|
| 22 |
+
client: httpx.AsyncClient,
|
| 23 |
+
destination_name: str,
|
| 24 |
+
destination: DestinationConfig,
|
| 25 |
+
request: NotificationRequest,
|
| 26 |
+
) -> NotificationResult:
|
| 27 |
+
"""Deliver a notification to one destination."""
|
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
|
| 5 |
+
import httpx
|
| 6 |
+
|
| 7 |
+
from agent.messaging.base import (
|
| 8 |
+
NotificationError,
|
| 9 |
+
NotificationProvider,
|
| 10 |
+
RetryableNotificationError,
|
| 11 |
+
)
|
| 12 |
+
from agent.messaging.models import (
|
| 13 |
+
MessagingConfig,
|
| 14 |
+
NotificationRequest,
|
| 15 |
+
NotificationResult,
|
| 16 |
+
)
|
| 17 |
+
from agent.messaging.slack import SlackProvider
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
_RETRY_DELAYS = (1, 2, 4)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class NotificationGateway:
|
| 25 |
+
def __init__(self, config: MessagingConfig):
|
| 26 |
+
self.config = config
|
| 27 |
+
self._providers: dict[str, NotificationProvider] = {
|
| 28 |
+
"slack": SlackProvider(),
|
| 29 |
+
}
|
| 30 |
+
self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue()
|
| 31 |
+
self._worker_task: asyncio.Task | None = None
|
| 32 |
+
self._client: httpx.AsyncClient | None = None
|
| 33 |
+
|
| 34 |
+
@property
|
| 35 |
+
def enabled(self) -> bool:
|
| 36 |
+
return self.config.enabled
|
| 37 |
+
|
| 38 |
+
async def start(self) -> None:
|
| 39 |
+
if not self.enabled or self._worker_task is not None:
|
| 40 |
+
return
|
| 41 |
+
self._client = httpx.AsyncClient(timeout=10.0)
|
| 42 |
+
self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway")
|
| 43 |
+
|
| 44 |
+
async def flush(self) -> None:
|
| 45 |
+
if not self.enabled:
|
| 46 |
+
return
|
| 47 |
+
await self._queue.join()
|
| 48 |
+
|
| 49 |
+
async def close(self) -> None:
|
| 50 |
+
if not self.enabled:
|
| 51 |
+
return
|
| 52 |
+
await self.flush()
|
| 53 |
+
if self._worker_task is not None:
|
| 54 |
+
self._worker_task.cancel()
|
| 55 |
+
try:
|
| 56 |
+
await self._worker_task
|
| 57 |
+
except asyncio.CancelledError:
|
| 58 |
+
pass
|
| 59 |
+
self._worker_task = None
|
| 60 |
+
if self._client is not None:
|
| 61 |
+
await self._client.aclose()
|
| 62 |
+
self._client = None
|
| 63 |
+
|
| 64 |
+
async def send(self, request: NotificationRequest) -> NotificationResult:
|
| 65 |
+
if not self.enabled:
|
| 66 |
+
return NotificationResult(
|
| 67 |
+
destination=request.destination,
|
| 68 |
+
ok=False,
|
| 69 |
+
provider="disabled",
|
| 70 |
+
error="Messaging is disabled",
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
destination = self.config.get_destination(request.destination)
|
| 74 |
+
if destination is None:
|
| 75 |
+
return NotificationResult(
|
| 76 |
+
destination=request.destination,
|
| 77 |
+
ok=False,
|
| 78 |
+
provider="unknown",
|
| 79 |
+
error=f"Unknown destination '{request.destination}'",
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
provider = self._providers.get(destination.provider)
|
| 83 |
+
if provider is None:
|
| 84 |
+
return NotificationResult(
|
| 85 |
+
destination=request.destination,
|
| 86 |
+
ok=False,
|
| 87 |
+
provider=destination.provider,
|
| 88 |
+
error=f"No provider implementation for '{destination.provider}'",
|
| 89 |
+
)
|
| 90 |
+
return await self._send_with_retries(provider, request.destination, destination, request)
|
| 91 |
+
|
| 92 |
+
async def send_many(
|
| 93 |
+
self, requests: Iterable[NotificationRequest]
|
| 94 |
+
) -> list[NotificationResult]:
|
| 95 |
+
results: list[NotificationResult] = []
|
| 96 |
+
for request in requests:
|
| 97 |
+
results.append(await self.send(request))
|
| 98 |
+
return results
|
| 99 |
+
|
| 100 |
+
async def enqueue(self, request: NotificationRequest) -> bool:
|
| 101 |
+
if not self.enabled or self._worker_task is None:
|
| 102 |
+
return False
|
| 103 |
+
await self._queue.put(request)
|
| 104 |
+
return True
|
| 105 |
+
|
| 106 |
+
async def _worker(self) -> None:
|
| 107 |
+
while True:
|
| 108 |
+
request = await self._queue.get()
|
| 109 |
+
try:
|
| 110 |
+
result = await self.send(request)
|
| 111 |
+
if not result.ok:
|
| 112 |
+
logger.warning(
|
| 113 |
+
"Notification delivery failed for %s: %s",
|
| 114 |
+
request.destination,
|
| 115 |
+
result.error,
|
| 116 |
+
)
|
| 117 |
+
except Exception:
|
| 118 |
+
logger.exception("Unexpected notification worker failure")
|
| 119 |
+
finally:
|
| 120 |
+
self._queue.task_done()
|
| 121 |
+
|
| 122 |
+
async def _send_with_retries(
|
| 123 |
+
self,
|
| 124 |
+
provider: NotificationProvider,
|
| 125 |
+
destination_name: str,
|
| 126 |
+
destination,
|
| 127 |
+
request: NotificationRequest,
|
| 128 |
+
) -> NotificationResult:
|
| 129 |
+
client = self._client or httpx.AsyncClient(timeout=10.0)
|
| 130 |
+
owns_client = self._client is None
|
| 131 |
+
try:
|
| 132 |
+
for attempt in range(len(_RETRY_DELAYS) + 1):
|
| 133 |
+
try:
|
| 134 |
+
return await provider.send(client, destination_name, destination, request)
|
| 135 |
+
except RetryableNotificationError as exc:
|
| 136 |
+
if attempt >= len(_RETRY_DELAYS):
|
| 137 |
+
return NotificationResult(
|
| 138 |
+
destination=destination_name,
|
| 139 |
+
ok=False,
|
| 140 |
+
provider=provider.provider_name,
|
| 141 |
+
error=str(exc),
|
| 142 |
+
)
|
| 143 |
+
delay = _RETRY_DELAYS[attempt]
|
| 144 |
+
logger.warning(
|
| 145 |
+
"Retrying notification to %s in %ss after transient error: %s",
|
| 146 |
+
destination_name,
|
| 147 |
+
delay,
|
| 148 |
+
exc,
|
| 149 |
+
)
|
| 150 |
+
await asyncio.sleep(delay)
|
| 151 |
+
except NotificationError as exc:
|
| 152 |
+
return NotificationResult(
|
| 153 |
+
destination=destination_name,
|
| 154 |
+
ok=False,
|
| 155 |
+
provider=provider.provider_name,
|
| 156 |
+
error=str(exc),
|
| 157 |
+
)
|
| 158 |
+
return NotificationResult(
|
| 159 |
+
destination=destination_name,
|
| 160 |
+
ok=False,
|
| 161 |
+
provider=provider.provider_name,
|
| 162 |
+
error="Notification delivery exhausted retries",
|
| 163 |
+
)
|
| 164 |
+
finally:
|
| 165 |
+
if owns_client:
|
| 166 |
+
await client.aclose()
|
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated, Literal
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
| 4 |
+
|
| 5 |
+
_DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-")
|
| 6 |
+
SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"}
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SlackDestinationConfig(BaseModel):
|
| 10 |
+
provider: Literal["slack"] = "slack"
|
| 11 |
+
token: str
|
| 12 |
+
channel: str
|
| 13 |
+
allow_agent_tool: bool = False
|
| 14 |
+
allow_auto_events: bool = False
|
| 15 |
+
username: str | None = None
|
| 16 |
+
icon_emoji: str | None = None
|
| 17 |
+
|
| 18 |
+
@field_validator("token", "channel")
|
| 19 |
+
@classmethod
|
| 20 |
+
def _require_non_empty(cls, value: str) -> str:
|
| 21 |
+
value = value.strip()
|
| 22 |
+
if not value:
|
| 23 |
+
raise ValueError("must not be empty")
|
| 24 |
+
return value
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MessagingConfig(BaseModel):
|
| 31 |
+
enabled: bool = False
|
| 32 |
+
auto_event_types: list[str] = Field(
|
| 33 |
+
default_factory=lambda: ["approval_required", "error", "turn_complete"]
|
| 34 |
+
)
|
| 35 |
+
destinations: dict[str, DestinationConfig] = Field(default_factory=dict)
|
| 36 |
+
|
| 37 |
+
@field_validator("destinations")
|
| 38 |
+
@classmethod
|
| 39 |
+
def _validate_destination_names(
|
| 40 |
+
cls, destinations: dict[str, DestinationConfig]
|
| 41 |
+
) -> dict[str, DestinationConfig]:
|
| 42 |
+
for name in destinations:
|
| 43 |
+
if not name or any(char not in _DESTINATION_NAME_CHARS for char in name):
|
| 44 |
+
raise ValueError(
|
| 45 |
+
"destination names must use lowercase letters, digits, '.', '_' or '-'"
|
| 46 |
+
)
|
| 47 |
+
return destinations
|
| 48 |
+
|
| 49 |
+
@field_validator("auto_event_types")
|
| 50 |
+
@classmethod
|
| 51 |
+
def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]:
|
| 52 |
+
if not event_types:
|
| 53 |
+
return []
|
| 54 |
+
normalized: list[str] = []
|
| 55 |
+
seen: set[str] = set()
|
| 56 |
+
for event_type in event_types:
|
| 57 |
+
if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
f"unsupported auto event type '{event_type}'"
|
| 60 |
+
)
|
| 61 |
+
if event_type not in seen:
|
| 62 |
+
normalized.append(event_type)
|
| 63 |
+
seen.add(event_type)
|
| 64 |
+
return normalized
|
| 65 |
+
|
| 66 |
+
@model_validator(mode="after")
|
| 67 |
+
def _require_destinations_when_enabled(self) -> "MessagingConfig":
|
| 68 |
+
if self.enabled and not self.destinations:
|
| 69 |
+
raise ValueError("messaging.enabled requires at least one destination")
|
| 70 |
+
return self
|
| 71 |
+
|
| 72 |
+
def get_destination(self, name: str) -> DestinationConfig | None:
|
| 73 |
+
return self.destinations.get(name)
|
| 74 |
+
|
| 75 |
+
def can_agent_tool_send(self, name: str) -> bool:
|
| 76 |
+
destination = self.get_destination(name)
|
| 77 |
+
return bool(destination and destination.allow_agent_tool)
|
| 78 |
+
|
| 79 |
+
def can_auto_send(self, name: str) -> bool:
|
| 80 |
+
destination = self.get_destination(name)
|
| 81 |
+
return bool(destination and destination.allow_auto_events)
|
| 82 |
+
|
| 83 |
+
def default_auto_destinations(self) -> list[str]:
|
| 84 |
+
if not self.enabled:
|
| 85 |
+
return []
|
| 86 |
+
return [
|
| 87 |
+
name
|
| 88 |
+
for name in self.destinations
|
| 89 |
+
if self.can_auto_send(name)
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class NotificationRequest(BaseModel):
|
| 94 |
+
destination: str
|
| 95 |
+
title: str | None = None
|
| 96 |
+
message: str
|
| 97 |
+
severity: Literal["info", "success", "warning", "error"] = "info"
|
| 98 |
+
metadata: dict[str, str] = Field(default_factory=dict)
|
| 99 |
+
event_type: str | None = None
|
| 100 |
+
|
| 101 |
+
@field_validator("destination", "message")
|
| 102 |
+
@classmethod
|
| 103 |
+
def _require_text(cls, value: str) -> str:
|
| 104 |
+
value = value.strip()
|
| 105 |
+
if not value:
|
| 106 |
+
raise ValueError("must not be empty")
|
| 107 |
+
return value
|
| 108 |
+
|
| 109 |
+
@field_validator("title")
|
| 110 |
+
@classmethod
|
| 111 |
+
def _normalize_title(cls, value: str | None) -> str | None:
|
| 112 |
+
if value is None:
|
| 113 |
+
return None
|
| 114 |
+
value = value.strip()
|
| 115 |
+
return value or None
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class NotificationResult(BaseModel):
|
| 119 |
+
destination: str
|
| 120 |
+
ok: bool
|
| 121 |
+
provider: str
|
| 122 |
+
error: str | None = None
|
| 123 |
+
external_id: str | None = None
|
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import httpx
|
| 5 |
+
|
| 6 |
+
from agent.messaging.base import (
|
| 7 |
+
NotificationError,
|
| 8 |
+
NotificationProvider,
|
| 9 |
+
RetryableNotificationError,
|
| 10 |
+
)
|
| 11 |
+
from agent.messaging.models import (
|
| 12 |
+
NotificationRequest,
|
| 13 |
+
NotificationResult,
|
| 14 |
+
SlackDestinationConfig,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
_SEVERITY_PREFIX = {
|
| 18 |
+
"info": "[INFO]",
|
| 19 |
+
"success": "[SUCCESS]",
|
| 20 |
+
"warning": "[WARNING]",
|
| 21 |
+
"error": "[ERROR]",
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _format_slack_mrkdwn(content: str) -> str:
|
| 26 |
+
"""Convert common Markdown constructs to Slack's mrkdwn syntax."""
|
| 27 |
+
if not content:
|
| 28 |
+
return content
|
| 29 |
+
|
| 30 |
+
placeholders: dict[str, str] = {}
|
| 31 |
+
placeholder_index = 0
|
| 32 |
+
|
| 33 |
+
def placeholder(value: str) -> str:
|
| 34 |
+
nonlocal placeholder_index
|
| 35 |
+
key = f"\x00SLACK{placeholder_index}\x00"
|
| 36 |
+
placeholder_index += 1
|
| 37 |
+
placeholders[key] = value
|
| 38 |
+
return key
|
| 39 |
+
|
| 40 |
+
text = content
|
| 41 |
+
|
| 42 |
+
# Protect code before any formatting conversion. Slack's mrkdwn ignores
|
| 43 |
+
# formatting inside backticks, so these regions should stay byte-for-byte.
|
| 44 |
+
text = re.sub(
|
| 45 |
+
r"(```(?:[^\n]*\n)?[\s\S]*?```)",
|
| 46 |
+
lambda match: placeholder(match.group(0)),
|
| 47 |
+
text,
|
| 48 |
+
)
|
| 49 |
+
text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text)
|
| 50 |
+
|
| 51 |
+
def convert_markdown_link(match: re.Match[str]) -> str:
|
| 52 |
+
label = match.group(1)
|
| 53 |
+
url = match.group(2).strip()
|
| 54 |
+
if url.startswith("<") and url.endswith(">"):
|
| 55 |
+
url = url[1:-1].strip()
|
| 56 |
+
return placeholder(f"<{url}|{label}>")
|
| 57 |
+
|
| 58 |
+
text = re.sub(
|
| 59 |
+
r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)",
|
| 60 |
+
convert_markdown_link,
|
| 61 |
+
text,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Preserve existing Slack entities and manual mrkdwn links before escaping.
|
| 65 |
+
text = re.sub(
|
| 66 |
+
r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)",
|
| 67 |
+
lambda match: placeholder(match.group(1)),
|
| 68 |
+
text,
|
| 69 |
+
)
|
| 70 |
+
text = re.sub(
|
| 71 |
+
r"^(>+\s)",
|
| 72 |
+
lambda match: placeholder(match.group(0)),
|
| 73 |
+
text,
|
| 74 |
+
flags=re.MULTILINE,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 78 |
+
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
| 79 |
+
|
| 80 |
+
def convert_header(match: re.Match[str]) -> str:
|
| 81 |
+
header = match.group(1).strip()
|
| 82 |
+
header = re.sub(r"\*\*(.+?)\*\*", r"\1", header)
|
| 83 |
+
return placeholder(f"*{header}*")
|
| 84 |
+
|
| 85 |
+
text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE)
|
| 86 |
+
text = re.sub(
|
| 87 |
+
r"\*\*\*(.+?)\*\*\*",
|
| 88 |
+
lambda match: placeholder(f"*_{match.group(1)}_*"),
|
| 89 |
+
text,
|
| 90 |
+
)
|
| 91 |
+
text = re.sub(
|
| 92 |
+
r"\*\*(.+?)\*\*",
|
| 93 |
+
lambda match: placeholder(f"*{match.group(1)}*"),
|
| 94 |
+
text,
|
| 95 |
+
)
|
| 96 |
+
text = re.sub(
|
| 97 |
+
r"(?<!\*)\*([^*\n]+)\*(?!\*)",
|
| 98 |
+
lambda match: placeholder(f"_{match.group(1)}_"),
|
| 99 |
+
text,
|
| 100 |
+
)
|
| 101 |
+
text = re.sub(
|
| 102 |
+
r"~~(.+?)~~",
|
| 103 |
+
lambda match: placeholder(f"~{match.group(1)}~"),
|
| 104 |
+
text,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
for key in reversed(placeholders):
|
| 108 |
+
text = text.replace(key, placeholders[key])
|
| 109 |
+
|
| 110 |
+
return text
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _format_text(request: NotificationRequest) -> str:
|
| 114 |
+
lines: list[str] = []
|
| 115 |
+
prefix = _SEVERITY_PREFIX[request.severity]
|
| 116 |
+
if request.title:
|
| 117 |
+
lines.append(f"{prefix} {request.title}")
|
| 118 |
+
else:
|
| 119 |
+
lines.append(prefix)
|
| 120 |
+
lines.append(request.message)
|
| 121 |
+
for key, value in request.metadata.items():
|
| 122 |
+
lines.append(f"{key}: {value}")
|
| 123 |
+
return _format_slack_mrkdwn("\n".join(lines))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class SlackProvider(NotificationProvider):
|
| 127 |
+
provider_name = "slack"
|
| 128 |
+
|
| 129 |
+
async def send(
|
| 130 |
+
self,
|
| 131 |
+
client: httpx.AsyncClient,
|
| 132 |
+
destination_name: str,
|
| 133 |
+
destination: SlackDestinationConfig,
|
| 134 |
+
request: NotificationRequest,
|
| 135 |
+
) -> NotificationResult:
|
| 136 |
+
payload = {
|
| 137 |
+
"channel": destination.channel,
|
| 138 |
+
"text": _format_text(request),
|
| 139 |
+
"mrkdwn": True,
|
| 140 |
+
"unfurl_links": False,
|
| 141 |
+
"unfurl_media": False,
|
| 142 |
+
}
|
| 143 |
+
if destination.username:
|
| 144 |
+
payload["username"] = destination.username
|
| 145 |
+
if destination.icon_emoji:
|
| 146 |
+
payload["icon_emoji"] = destination.icon_emoji
|
| 147 |
+
|
| 148 |
+
try:
|
| 149 |
+
response = await client.post(
|
| 150 |
+
"https://slack.com/api/chat.postMessage",
|
| 151 |
+
headers={
|
| 152 |
+
"Authorization": f"Bearer {destination.token}",
|
| 153 |
+
"Content-Type": "application/json; charset=utf-8",
|
| 154 |
+
},
|
| 155 |
+
content=json.dumps(payload),
|
| 156 |
+
)
|
| 157 |
+
except httpx.TimeoutException as exc:
|
| 158 |
+
raise RetryableNotificationError("Slack request timed out") from exc
|
| 159 |
+
except httpx.TransportError as exc:
|
| 160 |
+
raise RetryableNotificationError("Slack transport error") from exc
|
| 161 |
+
|
| 162 |
+
if response.status_code == 429 or response.status_code >= 500:
|
| 163 |
+
raise RetryableNotificationError(
|
| 164 |
+
f"Slack HTTP {response.status_code}"
|
| 165 |
+
)
|
| 166 |
+
if response.status_code >= 400:
|
| 167 |
+
raise NotificationError(f"Slack HTTP {response.status_code}")
|
| 168 |
+
|
| 169 |
+
try:
|
| 170 |
+
data = response.json()
|
| 171 |
+
except ValueError as exc:
|
| 172 |
+
raise RetryableNotificationError("Slack returned invalid JSON") from exc
|
| 173 |
+
|
| 174 |
+
if not data.get("ok"):
|
| 175 |
+
error = str(data.get("error") or "unknown_error")
|
| 176 |
+
if error == "ratelimited":
|
| 177 |
+
raise RetryableNotificationError(error)
|
| 178 |
+
raise NotificationError(error)
|
| 179 |
+
|
| 180 |
+
return NotificationResult(
|
| 181 |
+
destination=destination_name,
|
| 182 |
+
ok=True,
|
| 183 |
+
provider=self.provider_name,
|
| 184 |
+
external_id=str(data.get("ts") or ""),
|
| 185 |
+
error=None,
|
| 186 |
+
)
|
|
@@ -157,6 +157,7 @@ system_prompt: |
|
|
| 157 |
- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
|
| 158 |
- For errors: state what went wrong, why, and what you're doing to fix it.
|
| 159 |
- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
|
|
|
|
| 160 |
|
| 161 |
# Tool usage
|
| 162 |
|
|
|
|
| 157 |
- Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
|
| 158 |
- For errors: state what went wrong, why, and what you're doing to fix it.
|
| 159 |
- Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
|
| 160 |
+
- Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates.
|
| 161 |
|
| 162 |
# Tool usage
|
| 163 |
|
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from agent.messaging.models import NotificationRequest
|
| 4 |
+
|
| 5 |
+
NOTIFY_TOOL_SPEC = {
|
| 6 |
+
"name": "notify",
|
| 7 |
+
"description": (
|
| 8 |
+
"Send an out-of-band notification to configured messaging destinations. "
|
| 9 |
+
"Use this only when the user explicitly asked for proactive notifications "
|
| 10 |
+
"or when the task requires reporting progress outside the chat. "
|
| 11 |
+
"Destinations must be named server-side configs such as 'slack.ops'."
|
| 12 |
+
),
|
| 13 |
+
"parameters": {
|
| 14 |
+
"type": "object",
|
| 15 |
+
"properties": {
|
| 16 |
+
"destinations": {
|
| 17 |
+
"type": "array",
|
| 18 |
+
"description": "Named messaging destinations to notify.",
|
| 19 |
+
"items": {"type": "string"},
|
| 20 |
+
"minItems": 1,
|
| 21 |
+
},
|
| 22 |
+
"message": {
|
| 23 |
+
"type": "string",
|
| 24 |
+
"description": "Main notification body.",
|
| 25 |
+
},
|
| 26 |
+
"title": {
|
| 27 |
+
"type": "string",
|
| 28 |
+
"description": "Optional short title line.",
|
| 29 |
+
},
|
| 30 |
+
"severity": {
|
| 31 |
+
"type": "string",
|
| 32 |
+
"enum": ["info", "success", "warning", "error"],
|
| 33 |
+
"description": "Notification severity label.",
|
| 34 |
+
},
|
| 35 |
+
},
|
| 36 |
+
"required": ["destinations", "message"],
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def notify_handler(
|
| 42 |
+
arguments: dict[str, Any], session=None, **_kwargs
|
| 43 |
+
) -> tuple[str, bool]:
|
| 44 |
+
if session is None or session.notification_gateway is None:
|
| 45 |
+
return "Messaging is not configured for this session.", False
|
| 46 |
+
|
| 47 |
+
raw_destinations = arguments.get("destinations", [])
|
| 48 |
+
if not isinstance(raw_destinations, list) or not raw_destinations:
|
| 49 |
+
return "destinations must be a non-empty array of destination names.", False
|
| 50 |
+
|
| 51 |
+
destinations: list[str] = []
|
| 52 |
+
seen: set[str] = set()
|
| 53 |
+
for raw_name in raw_destinations:
|
| 54 |
+
if not isinstance(raw_name, str):
|
| 55 |
+
return "Each destination must be a string.", False
|
| 56 |
+
name = raw_name.strip()
|
| 57 |
+
if not name:
|
| 58 |
+
return "Destination names must not be empty.", False
|
| 59 |
+
if name not in seen:
|
| 60 |
+
destinations.append(name)
|
| 61 |
+
seen.add(name)
|
| 62 |
+
|
| 63 |
+
disallowed = [
|
| 64 |
+
name
|
| 65 |
+
for name in destinations
|
| 66 |
+
if not session.config.messaging.can_agent_tool_send(name)
|
| 67 |
+
]
|
| 68 |
+
if disallowed:
|
| 69 |
+
return (
|
| 70 |
+
"These destinations are unavailable for the notify tool: "
|
| 71 |
+
+ ", ".join(disallowed)
|
| 72 |
+
), False
|
| 73 |
+
|
| 74 |
+
message = arguments.get("message", "")
|
| 75 |
+
if not isinstance(message, str) or not message.strip():
|
| 76 |
+
return "message must be a non-empty string.", False
|
| 77 |
+
|
| 78 |
+
title = arguments.get("title")
|
| 79 |
+
severity = arguments.get("severity", "info")
|
| 80 |
+
if title is not None and not isinstance(title, str):
|
| 81 |
+
return "title must be a string when provided.", False
|
| 82 |
+
if severity not in {"info", "success", "warning", "error"}:
|
| 83 |
+
return "severity must be one of: info, success, warning, error.", False
|
| 84 |
+
|
| 85 |
+
requests = [
|
| 86 |
+
NotificationRequest(
|
| 87 |
+
destination=name,
|
| 88 |
+
title=title,
|
| 89 |
+
message=message,
|
| 90 |
+
severity=severity,
|
| 91 |
+
metadata={
|
| 92 |
+
"session_id": session.session_id,
|
| 93 |
+
"model": session.config.model_name,
|
| 94 |
+
},
|
| 95 |
+
)
|
| 96 |
+
for name in destinations
|
| 97 |
+
]
|
| 98 |
+
results = await session.notification_gateway.send_many(requests)
|
| 99 |
+
|
| 100 |
+
lines = []
|
| 101 |
+
all_ok = True
|
| 102 |
+
for result in results:
|
| 103 |
+
if result.ok:
|
| 104 |
+
lines.append(f"{result.destination}: sent")
|
| 105 |
+
else:
|
| 106 |
+
all_ok = False
|
| 107 |
+
lines.append(f"{result.destination}: failed ({result.error})")
|
| 108 |
+
return "\n".join(lines), all_ok
|
|
@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 11 |
from fastapi.staticfiles import StaticFiles
|
| 12 |
from routes.agent import router as agent_router
|
| 13 |
from routes.auth import router as auth_router
|
|
|
|
| 14 |
|
| 15 |
# Load .env from project root (parent directory)
|
| 16 |
load_dotenv(Path(__file__).parent.parent / ".env")
|
|
@@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
|
|
| 27 |
async def lifespan(app: FastAPI):
|
| 28 |
"""Application lifespan handler."""
|
| 29 |
logger.info("Starting HF Agent backend...")
|
|
|
|
| 30 |
# Start in-process hourly KPI rollup. Replaces an external cron so the
|
| 31 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 32 |
try:
|
|
@@ -34,7 +36,6 @@ async def lifespan(app: FastAPI):
|
|
| 34 |
kpis_scheduler.start()
|
| 35 |
except Exception as e:
|
| 36 |
logger.warning("KPI scheduler failed to start: %s", e)
|
| 37 |
-
|
| 38 |
yield
|
| 39 |
|
| 40 |
logger.info("Shutting down HF Agent backend...")
|
|
@@ -47,7 +48,6 @@ async def lifespan(app: FastAPI):
|
|
| 47 |
# Final-flush: save every still-active session so we don't lose traces on
|
| 48 |
# server restart. Uploads are detached subprocesses — this is fast.
|
| 49 |
try:
|
| 50 |
-
from session_manager import session_manager
|
| 51 |
for sid, agent_session in list(session_manager.sessions.items()):
|
| 52 |
sess = agent_session.session
|
| 53 |
if sess.config.save_sessions:
|
|
@@ -58,6 +58,7 @@ async def lifespan(app: FastAPI):
|
|
| 58 |
logger.warning("Failed to flush session %s: %s", sid, e)
|
| 59 |
except Exception as e:
|
| 60 |
logger.warning("Lifespan final-flush skipped: %s", e)
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
app = FastAPI(
|
|
|
|
| 11 |
from fastapi.staticfiles import StaticFiles
|
| 12 |
from routes.agent import router as agent_router
|
| 13 |
from routes.auth import router as auth_router
|
| 14 |
+
from session_manager import session_manager
|
| 15 |
|
| 16 |
# Load .env from project root (parent directory)
|
| 17 |
load_dotenv(Path(__file__).parent.parent / ".env")
|
|
|
|
| 28 |
async def lifespan(app: FastAPI):
|
| 29 |
"""Application lifespan handler."""
|
| 30 |
logger.info("Starting HF Agent backend...")
|
| 31 |
+
await session_manager.start()
|
| 32 |
# Start in-process hourly KPI rollup. Replaces an external cron so the
|
| 33 |
# rollup lives next to the data and reuses the Space's HF token.
|
| 34 |
try:
|
|
|
|
| 36 |
kpis_scheduler.start()
|
| 37 |
except Exception as e:
|
| 38 |
logger.warning("KPI scheduler failed to start: %s", e)
|
|
|
|
| 39 |
yield
|
| 40 |
|
| 41 |
logger.info("Shutting down HF Agent backend...")
|
|
|
|
| 48 |
# Final-flush: save every still-active session so we don't lose traces on
|
| 49 |
# server restart. Uploads are detached subprocesses — this is fast.
|
| 50 |
try:
|
|
|
|
| 51 |
for sid, agent_session in list(session_manager.sessions.items()):
|
| 52 |
sess = agent_session.session
|
| 53 |
if sess.config.save_sessions:
|
|
|
|
| 58 |
logger.warning("Failed to flush session %s: %s", sid, e)
|
| 59 |
except Exception as e:
|
| 60 |
logger.warning("Lifespan final-flush skipped: %s", e)
|
| 61 |
+
await session_manager.close()
|
| 62 |
|
| 63 |
|
| 64 |
app = FastAPI(
|
|
@@ -3,7 +3,7 @@
|
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
-
from pydantic import BaseModel
|
| 7 |
|
| 8 |
|
| 9 |
class OpType(str, Enum):
|
|
@@ -87,6 +87,13 @@ class SessionInfo(BaseModel):
|
|
| 87 |
user_id: str = "dev"
|
| 88 |
pending_approval: list[PendingApprovalTool] | None = None
|
| 89 |
model: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
class HealthResponse(BaseModel):
|
|
|
|
| 3 |
from enum import Enum
|
| 4 |
from typing import Any
|
| 5 |
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
|
| 8 |
|
| 9 |
class OpType(str, Enum):
|
|
|
|
| 87 |
user_id: str = "dev"
|
| 88 |
pending_approval: list[PendingApprovalTool] | None = None
|
| 89 |
model: str | None = None
|
| 90 |
+
notification_destinations: list[str] = Field(default_factory=list)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SessionNotificationsRequest(BaseModel):
|
| 94 |
+
"""Replace the session's auto-notification destinations."""
|
| 95 |
+
|
| 96 |
+
destinations: list[str]
|
| 97 |
|
| 98 |
|
| 99 |
class HealthResponse(BaseModel):
|
|
@@ -24,6 +24,7 @@ from models import (
|
|
| 24 |
HealthResponse,
|
| 25 |
LLMHealthResponse,
|
| 26 |
SessionInfo,
|
|
|
|
| 27 |
SessionResponse,
|
| 28 |
SubmitRequest,
|
| 29 |
TruncateRequest,
|
|
@@ -513,6 +514,26 @@ async def set_session_model(
|
|
| 513 |
return {"session_id": session_id, "model": model_id}
|
| 514 |
|
| 515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
@router.get("/user/quota")
|
| 517 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 518 |
"""Return the user's plan tier and today's Claude-session quota state."""
|
|
@@ -824,7 +845,6 @@ async def shutdown_session(
|
|
| 824 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 825 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 826 |
|
| 827 |
-
|
| 828 |
@router.post("/feedback/{session_id}")
|
| 829 |
async def submit_feedback(
|
| 830 |
session_id: str,
|
|
|
|
| 24 |
HealthResponse,
|
| 25 |
LLMHealthResponse,
|
| 26 |
SessionInfo,
|
| 27 |
+
SessionNotificationsRequest,
|
| 28 |
SessionResponse,
|
| 29 |
SubmitRequest,
|
| 30 |
TruncateRequest,
|
|
|
|
| 514 |
return {"session_id": session_id, "model": model_id}
|
| 515 |
|
| 516 |
|
| 517 |
+
@router.post("/session/{session_id}/notifications")
|
| 518 |
+
async def set_session_notifications(
|
| 519 |
+
session_id: str,
|
| 520 |
+
body: SessionNotificationsRequest,
|
| 521 |
+
user: dict = Depends(get_current_user),
|
| 522 |
+
) -> dict:
|
| 523 |
+
"""Replace the session's auto-notification destinations."""
|
| 524 |
+
_check_session_access(session_id, user)
|
| 525 |
+
try:
|
| 526 |
+
destinations = session_manager.set_notification_destinations(
|
| 527 |
+
session_id, body.destinations
|
| 528 |
+
)
|
| 529 |
+
except ValueError as e:
|
| 530 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 531 |
+
return {
|
| 532 |
+
"session_id": session_id,
|
| 533 |
+
"notification_destinations": destinations,
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
|
| 537 |
@router.get("/user/quota")
|
| 538 |
async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
| 539 |
"""Return the user's plan tier and today's Claude-session quota state."""
|
|
|
|
| 845 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 846 |
return {"status": "shutdown_requested", "session_id": session_id}
|
| 847 |
|
|
|
|
| 848 |
@router.post("/feedback/{session_id}")
|
| 849 |
async def submit_feedback(
|
| 850 |
session_id: str,
|
|
@@ -10,6 +10,7 @@ from typing import Any, Optional
|
|
| 10 |
|
| 11 |
from agent.config import load_config
|
| 12 |
from agent.core.agent_loop import process_submission
|
|
|
|
| 13 |
from agent.core.session import Event, OpType, Session
|
| 14 |
from agent.core.tools import ToolRouter
|
| 15 |
|
|
@@ -119,9 +120,18 @@ class SessionManager:
|
|
| 119 |
|
| 120 |
def __init__(self, config_path: str | None = None) -> None:
|
| 121 |
self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
|
|
|
|
| 122 |
self.sessions: dict[str, AgentSession] = {}
|
| 123 |
self._lock = asyncio.Lock()
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 126 |
"""Count active sessions owned by a specific user."""
|
| 127 |
return sum(
|
|
@@ -192,7 +202,11 @@ class SessionManager:
|
|
| 192 |
session_config.model_name = model
|
| 193 |
session = Session(
|
| 194 |
event_queue, config=session_config, tool_router=tool_router,
|
| 195 |
-
hf_token=hf_token,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
)
|
| 197 |
t1 = _time.monotonic()
|
| 198 |
logger.info(f"Session initialized in {t1 - t0:.2f}s")
|
|
@@ -518,8 +532,39 @@ class SessionManager:
|
|
| 518 |
"user_id": agent_session.user_id,
|
| 519 |
"pending_approval": pending_approval,
|
| 520 |
"model": agent_session.session.config.model_name,
|
|
|
|
|
|
|
|
|
|
| 521 |
}
|
| 522 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
| 524 |
"""List sessions, optionally filtered by user.
|
| 525 |
|
|
|
|
| 10 |
|
| 11 |
from agent.config import load_config
|
| 12 |
from agent.core.agent_loop import process_submission
|
| 13 |
+
from agent.messaging.gateway import NotificationGateway
|
| 14 |
from agent.core.session import Event, OpType, Session
|
| 15 |
from agent.core.tools import ToolRouter
|
| 16 |
|
|
|
|
| 120 |
|
| 121 |
def __init__(self, config_path: str | None = None) -> None:
|
| 122 |
self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
|
| 123 |
+
self.messaging_gateway = NotificationGateway(self.config.messaging)
|
| 124 |
self.sessions: dict[str, AgentSession] = {}
|
| 125 |
self._lock = asyncio.Lock()
|
| 126 |
|
| 127 |
+
async def start(self) -> None:
|
| 128 |
+
"""Start shared background resources."""
|
| 129 |
+
await self.messaging_gateway.start()
|
| 130 |
+
|
| 131 |
+
async def close(self) -> None:
|
| 132 |
+
"""Flush and close shared background resources."""
|
| 133 |
+
await self.messaging_gateway.close()
|
| 134 |
+
|
| 135 |
def _count_user_sessions(self, user_id: str) -> int:
|
| 136 |
"""Count active sessions owned by a specific user."""
|
| 137 |
return sum(
|
|
|
|
| 202 |
session_config.model_name = model
|
| 203 |
session = Session(
|
| 204 |
event_queue, config=session_config, tool_router=tool_router,
|
| 205 |
+
hf_token=hf_token,
|
| 206 |
+
user_id=user_id,
|
| 207 |
+
notification_gateway=self.messaging_gateway,
|
| 208 |
+
notification_destinations=[],
|
| 209 |
+
session_id=session_id,
|
| 210 |
)
|
| 211 |
t1 = _time.monotonic()
|
| 212 |
logger.info(f"Session initialized in {t1 - t0:.2f}s")
|
|
|
|
| 532 |
"user_id": agent_session.user_id,
|
| 533 |
"pending_approval": pending_approval,
|
| 534 |
"model": agent_session.session.config.model_name,
|
| 535 |
+
"notification_destinations": list(
|
| 536 |
+
agent_session.session.notification_destinations
|
| 537 |
+
),
|
| 538 |
}
|
| 539 |
|
| 540 |
+
def set_notification_destinations(
|
| 541 |
+
self, session_id: str, destinations: list[str]
|
| 542 |
+
) -> list[str]:
|
| 543 |
+
"""Replace the session's opted-in auto-notification destinations."""
|
| 544 |
+
agent_session = self.sessions.get(session_id)
|
| 545 |
+
if not agent_session or not agent_session.is_active:
|
| 546 |
+
raise ValueError("Session not found or inactive")
|
| 547 |
+
|
| 548 |
+
normalized: list[str] = []
|
| 549 |
+
seen: set[str] = set()
|
| 550 |
+
for raw_name in destinations:
|
| 551 |
+
name = raw_name.strip()
|
| 552 |
+
if not name:
|
| 553 |
+
raise ValueError("Destination names must not be empty")
|
| 554 |
+
destination = self.config.messaging.get_destination(name)
|
| 555 |
+
if destination is None:
|
| 556 |
+
raise ValueError(f"Unknown destination '{name}'")
|
| 557 |
+
if not destination.allow_auto_events:
|
| 558 |
+
raise ValueError(
|
| 559 |
+
f"Destination '{name}' is not enabled for auto events"
|
| 560 |
+
)
|
| 561 |
+
if name not in seen:
|
| 562 |
+
normalized.append(name)
|
| 563 |
+
seen.add(name)
|
| 564 |
+
|
| 565 |
+
agent_session.session.set_notification_destinations(normalized)
|
| 566 |
+
return normalized
|
| 567 |
+
|
| 568 |
def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
|
| 569 |
"""List sessions, optionally filtered by user.
|
| 570 |
|
|
@@ -5,6 +5,11 @@
|
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"mcpServers": {
|
| 9 |
"hf-mcp-server": {
|
| 10 |
"transport": "http",
|
|
|
|
| 5 |
"yolo_mode": false,
|
| 6 |
"confirm_cpu_jobs": true,
|
| 7 |
"auto_file_upload": true,
|
| 8 |
+
"messaging": {
|
| 9 |
+
"enabled": false,
|
| 10 |
+
"auto_event_types": ["approval_required", "error", "turn_complete"],
|
| 11 |
+
"destinations": {}
|
| 12 |
+
},
|
| 13 |
"mcpServers": {
|
| 14 |
"hf-mcp-server": {
|
| 15 |
"transport": "http",
|
|
@@ -42,7 +42,7 @@ eval = [
|
|
| 42 |
# Development and testing dependencies
|
| 43 |
dev = [
|
| 44 |
"pytest>=9.0.2",
|
| 45 |
-
"pytest-asyncio>=
|
| 46 |
]
|
| 47 |
|
| 48 |
# All dependencies (eval + dev)
|
|
|
|
| 42 |
# Development and testing dependencies
|
| 43 |
dev = [
|
| 44 |
"pytest>=9.0.2",
|
| 45 |
+
"pytest-asyncio>=1.2.0",
|
| 46 |
]
|
| 47 |
|
| 48 |
# All dependencies (eval + dev)
|
|
@@ -79,7 +79,7 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
|
|
| 79 |
monkeypatch.setattr(
|
| 80 |
main_mod,
|
| 81 |
"load_config",
|
| 82 |
-
lambda _path: SimpleNamespace(
|
| 83 |
model_name="moonshotai/Kimi-K2.6",
|
| 84 |
mcpServers={},
|
| 85 |
),
|
|
|
|
| 79 |
monkeypatch.setattr(
|
| 80 |
main_mod,
|
| 81 |
"load_config",
|
| 82 |
+
lambda _path, **_kwargs: SimpleNamespace(
|
| 83 |
model_name="moonshotai/Kimi-K2.6",
|
| 84 |
mcpServers={},
|
| 85 |
),
|
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from agent import config as config_module
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _write_json(path, data):
|
| 7 |
+
path.write_text(json.dumps(data), encoding="utf-8")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def test_load_config_does_not_apply_slack_user_defaults_by_default(tmp_path, monkeypatch):
|
| 11 |
+
config_path = tmp_path / "config.json"
|
| 12 |
+
_write_json(
|
| 13 |
+
config_path,
|
| 14 |
+
{
|
| 15 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 16 |
+
"messaging": {
|
| 17 |
+
"enabled": False,
|
| 18 |
+
"destinations": {},
|
| 19 |
+
},
|
| 20 |
+
},
|
| 21 |
+
)
|
| 22 |
+
monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
|
| 23 |
+
monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
|
| 24 |
+
|
| 25 |
+
config = config_module.load_config(str(config_path))
|
| 26 |
+
|
| 27 |
+
assert not config.messaging.enabled
|
| 28 |
+
assert config.messaging.destinations == {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_load_config_applies_slack_user_defaults_from_env(tmp_path, monkeypatch):
|
| 32 |
+
config_path = tmp_path / "config.json"
|
| 33 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 34 |
+
monkeypatch.delenv("ML_INTERN_CLI_CONFIG", raising=False)
|
| 35 |
+
monkeypatch.setattr(
|
| 36 |
+
config_module,
|
| 37 |
+
"DEFAULT_USER_CONFIG_PATH",
|
| 38 |
+
tmp_path / "missing-user-config.json",
|
| 39 |
+
)
|
| 40 |
+
monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
|
| 41 |
+
monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
|
| 42 |
+
|
| 43 |
+
config = config_module.load_config(str(config_path), include_user_defaults=True)
|
| 44 |
+
|
| 45 |
+
assert config.messaging.enabled
|
| 46 |
+
assert config.messaging.auto_event_types == [
|
| 47 |
+
"approval_required",
|
| 48 |
+
"error",
|
| 49 |
+
"turn_complete",
|
| 50 |
+
]
|
| 51 |
+
destination = config.messaging.destinations["slack.default"]
|
| 52 |
+
assert destination.token == "xoxb-test"
|
| 53 |
+
assert destination.channel == "C123"
|
| 54 |
+
assert destination.allow_agent_tool
|
| 55 |
+
assert destination.allow_auto_events
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_load_config_merges_user_config_before_env_substitution(tmp_path, monkeypatch):
|
| 59 |
+
config_path = tmp_path / "config.json"
|
| 60 |
+
user_config_path = tmp_path / "user-config.json"
|
| 61 |
+
_write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
|
| 62 |
+
_write_json(
|
| 63 |
+
user_config_path,
|
| 64 |
+
{
|
| 65 |
+
"messaging": {
|
| 66 |
+
"enabled": True,
|
| 67 |
+
"auto_event_types": ["approval_required"],
|
| 68 |
+
"destinations": {
|
| 69 |
+
"slack.team": {
|
| 70 |
+
"provider": "slack",
|
| 71 |
+
"token": "${USER_SLACK_TOKEN}",
|
| 72 |
+
"channel": "C999",
|
| 73 |
+
"allow_agent_tool": False,
|
| 74 |
+
"allow_auto_events": True,
|
| 75 |
+
},
|
| 76 |
+
},
|
| 77 |
+
},
|
| 78 |
+
},
|
| 79 |
+
)
|
| 80 |
+
monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
|
| 81 |
+
monkeypatch.setenv("ML_INTERN_SLACK_NOTIFICATIONS", "0")
|
| 82 |
+
monkeypatch.setenv("USER_SLACK_TOKEN", "xoxb-user")
|
| 83 |
+
|
| 84 |
+
config = config_module.load_config(str(config_path), include_user_defaults=True)
|
| 85 |
+
|
| 86 |
+
assert config.messaging.enabled
|
| 87 |
+
assert config.messaging.auto_event_types == ["approval_required"]
|
| 88 |
+
assert set(config.messaging.destinations) == {"slack.team"}
|
| 89 |
+
destination = config.messaging.destinations["slack.team"]
|
| 90 |
+
assert destination.token == "xoxb-user"
|
| 91 |
+
assert destination.channel == "C999"
|
| 92 |
+
assert not destination.allow_agent_tool
|
| 93 |
+
assert destination.allow_auto_events
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
|
| 97 |
+
config_path = tmp_path / "config.json"
|
| 98 |
+
_write_json(
|
| 99 |
+
config_path,
|
| 100 |
+
{
|
| 101 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 102 |
+
"messaging": {
|
| 103 |
+
"enabled": False,
|
| 104 |
+
"destinations": {},
|
| 105 |
+
},
|
| 106 |
+
},
|
| 107 |
+
)
|
| 108 |
+
monkeypatch.delenv("ML_INTERN_CLI_CONFIG", raising=False)
|
| 109 |
+
monkeypatch.setattr(
|
| 110 |
+
config_module,
|
| 111 |
+
"DEFAULT_USER_CONFIG_PATH",
|
| 112 |
+
tmp_path / "missing-user-config.json",
|
| 113 |
+
)
|
| 114 |
+
monkeypatch.setenv("ML_INTERN_SLACK_NOTIFICATIONS", "false")
|
| 115 |
+
monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
|
| 116 |
+
monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
|
| 117 |
+
|
| 118 |
+
config = config_module.load_config(str(config_path), include_user_defaults=True)
|
| 119 |
+
|
| 120 |
+
assert not config.messaging.enabled
|
| 121 |
+
assert config.messaging.destinations == {}
|
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
|
| 6 |
+
import httpx
|
| 7 |
+
import pytest
|
| 8 |
+
from pydantic import ValidationError
|
| 9 |
+
|
| 10 |
+
from agent.config import Config
|
| 11 |
+
from agent.core.session import Event, Session
|
| 12 |
+
from agent.messaging.gateway import NotificationGateway
|
| 13 |
+
from agent.messaging.models import NotificationRequest, NotificationResult
|
| 14 |
+
from agent.messaging.slack import SlackProvider, _format_slack_mrkdwn
|
| 15 |
+
from agent.tools.notify_tool import notify_handler
|
| 16 |
+
from backend.session_manager import AgentSession, SessionManager
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DummyToolRouter:
|
| 20 |
+
def get_tool_specs_for_llm(self) -> list[dict]:
|
| 21 |
+
return []
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class RecordingGateway:
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.enqueued: list[NotificationRequest] = []
|
| 27 |
+
self.sent: list[NotificationRequest] = []
|
| 28 |
+
|
| 29 |
+
async def enqueue(self, request: NotificationRequest) -> bool:
|
| 30 |
+
self.enqueued.append(request)
|
| 31 |
+
return True
|
| 32 |
+
|
| 33 |
+
async def send_many(
|
| 34 |
+
self, requests: list[NotificationRequest]
|
| 35 |
+
) -> list[NotificationResult]:
|
| 36 |
+
self.sent.extend(requests)
|
| 37 |
+
return [
|
| 38 |
+
NotificationResult(
|
| 39 |
+
destination=request.destination,
|
| 40 |
+
ok=True,
|
| 41 |
+
provider="test",
|
| 42 |
+
)
|
| 43 |
+
for request in requests
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _config_with_messaging(**destination_overrides) -> Config:
|
| 48 |
+
destination = {
|
| 49 |
+
"provider": "slack",
|
| 50 |
+
"token": "xoxb-test",
|
| 51 |
+
"channel": "C123",
|
| 52 |
+
**destination_overrides,
|
| 53 |
+
}
|
| 54 |
+
return Config.model_validate(
|
| 55 |
+
{
|
| 56 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 57 |
+
"messaging": {
|
| 58 |
+
"enabled": True,
|
| 59 |
+
"destinations": {
|
| 60 |
+
"slack.ops": destination,
|
| 61 |
+
},
|
| 62 |
+
},
|
| 63 |
+
}
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _test_session(
|
| 68 |
+
config: Config, gateway, session_id: str = "session-test"
|
| 69 |
+
) -> Session:
|
| 70 |
+
return Session(
|
| 71 |
+
asyncio.Queue(),
|
| 72 |
+
config=config,
|
| 73 |
+
tool_router=DummyToolRouter(),
|
| 74 |
+
context_manager=SimpleNamespace(items=[]),
|
| 75 |
+
notification_gateway=gateway,
|
| 76 |
+
session_id=session_id,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def test_messaging_config_validates_destination_names():
|
| 81 |
+
with pytest.raises(ValidationError):
|
| 82 |
+
Config.model_validate(
|
| 83 |
+
{
|
| 84 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 85 |
+
"messaging": {
|
| 86 |
+
"enabled": True,
|
| 87 |
+
"destinations": {
|
| 88 |
+
"Slack Ops": {
|
| 89 |
+
"provider": "slack",
|
| 90 |
+
"token": "x",
|
| 91 |
+
"channel": "C123",
|
| 92 |
+
}
|
| 93 |
+
},
|
| 94 |
+
},
|
| 95 |
+
}
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
config = _config_with_messaging(allow_agent_tool=True, allow_auto_events=True)
|
| 99 |
+
assert config.messaging.can_agent_tool_send("slack.ops")
|
| 100 |
+
assert config.messaging.can_auto_send("slack.ops")
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_messaging_config_default_auto_destinations_only_returns_auto_enabled():
|
| 104 |
+
config = Config.model_validate(
|
| 105 |
+
{
|
| 106 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 107 |
+
"messaging": {
|
| 108 |
+
"enabled": True,
|
| 109 |
+
"destinations": {
|
| 110 |
+
"slack.ops": {
|
| 111 |
+
"provider": "slack",
|
| 112 |
+
"token": "xoxb-test",
|
| 113 |
+
"channel": "C123",
|
| 114 |
+
"allow_auto_events": True,
|
| 115 |
+
},
|
| 116 |
+
"slack.tool": {
|
| 117 |
+
"provider": "slack",
|
| 118 |
+
"token": "xoxb-test",
|
| 119 |
+
"channel": "C999",
|
| 120 |
+
"allow_agent_tool": True,
|
| 121 |
+
},
|
| 122 |
+
},
|
| 123 |
+
},
|
| 124 |
+
}
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
assert config.messaging.default_auto_destinations() == ["slack.ops"]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_messaging_config_default_auto_destinations_empty_when_disabled():
|
| 131 |
+
config = Config.model_validate(
|
| 132 |
+
{
|
| 133 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 134 |
+
"messaging": {
|
| 135 |
+
"enabled": False,
|
| 136 |
+
"destinations": {
|
| 137 |
+
"slack.ops": {
|
| 138 |
+
"provider": "slack",
|
| 139 |
+
"token": "xoxb-test",
|
| 140 |
+
"channel": "C123",
|
| 141 |
+
"allow_auto_events": True,
|
| 142 |
+
},
|
| 143 |
+
},
|
| 144 |
+
},
|
| 145 |
+
}
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
assert config.messaging.default_auto_destinations() == []
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_slack_mrkdwn_formatter_converts_common_markdown():
|
| 152 |
+
formatted = _format_slack_mrkdwn(
|
| 153 |
+
"# Result\n"
|
| 154 |
+
"**Done** with *details* and ~~old text~~.\n"
|
| 155 |
+
"See [PR](https://github.com/huggingface/ml-intern/pull/116).\n"
|
| 156 |
+
"Keep `**literal**` and ```python\nx < 3\n``` untouched.\n"
|
| 157 |
+
"Escape <raw> & text."
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
assert "*Result*" in formatted
|
| 161 |
+
assert "*Done*" in formatted
|
| 162 |
+
assert "_details_" in formatted
|
| 163 |
+
assert "~old text~" in formatted
|
| 164 |
+
assert "<https://github.com/huggingface/ml-intern/pull/116|PR>" in formatted
|
| 165 |
+
assert "`**literal**`" in formatted
|
| 166 |
+
assert "```python\nx < 3\n```" in formatted
|
| 167 |
+
assert "Escape <raw> & text." in formatted
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@pytest.mark.asyncio
|
| 171 |
+
async def test_slack_provider_formats_and_sends_payload():
|
| 172 |
+
seen: dict[str, object] = {}
|
| 173 |
+
|
| 174 |
+
def handler(request: httpx.Request) -> httpx.Response:
|
| 175 |
+
seen["auth"] = request.headers["Authorization"]
|
| 176 |
+
seen["content_type"] = request.headers["Content-Type"]
|
| 177 |
+
seen["json"] = request.read().decode("utf-8")
|
| 178 |
+
return httpx.Response(200, json={"ok": True, "ts": "123.456"})
|
| 179 |
+
|
| 180 |
+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
|
| 181 |
+
provider = SlackProvider()
|
| 182 |
+
result = await provider.send(
|
| 183 |
+
client,
|
| 184 |
+
"slack.ops",
|
| 185 |
+
_config_with_messaging().messaging.destinations["slack.ops"],
|
| 186 |
+
NotificationRequest(
|
| 187 |
+
destination="slack.ops",
|
| 188 |
+
title="Approval required",
|
| 189 |
+
message="A **run** is waiting. See [details](https://example.com).",
|
| 190 |
+
severity="warning",
|
| 191 |
+
metadata={"session_id": "sess-1"},
|
| 192 |
+
),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
assert result.ok
|
| 196 |
+
assert result.external_id == "123.456"
|
| 197 |
+
assert seen["auth"] == "Bearer xoxb-test"
|
| 198 |
+
assert seen["content_type"].startswith("application/json")
|
| 199 |
+
payload = json.loads(str(seen["json"]))
|
| 200 |
+
assert payload["channel"] == "C123"
|
| 201 |
+
assert payload["mrkdwn"] is True
|
| 202 |
+
assert payload["text"] == (
|
| 203 |
+
"[WARNING] Approval required\n"
|
| 204 |
+
"A *run* is waiting. See <https://example.com|details>.\n"
|
| 205 |
+
"session_id: sess-1"
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@pytest.mark.asyncio
|
| 210 |
+
async def test_notification_gateway_retries_transient_failures(monkeypatch):
|
| 211 |
+
attempts = {"count": 0}
|
| 212 |
+
|
| 213 |
+
def handler(_request: httpx.Request) -> httpx.Response:
|
| 214 |
+
attempts["count"] += 1
|
| 215 |
+
if attempts["count"] == 1:
|
| 216 |
+
return httpx.Response(503, json={"ok": False})
|
| 217 |
+
return httpx.Response(200, json={"ok": True, "ts": "999.1"})
|
| 218 |
+
|
| 219 |
+
async def fake_sleep(_delay: float) -> None:
|
| 220 |
+
return None
|
| 221 |
+
|
| 222 |
+
monkeypatch.setattr("agent.messaging.gateway.asyncio.sleep", fake_sleep)
|
| 223 |
+
|
| 224 |
+
config = _config_with_messaging(allow_agent_tool=True)
|
| 225 |
+
gateway = NotificationGateway(config.messaging)
|
| 226 |
+
async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
|
| 227 |
+
gateway._client = client
|
| 228 |
+
result = await gateway.send(
|
| 229 |
+
NotificationRequest(
|
| 230 |
+
destination="slack.ops",
|
| 231 |
+
message="hello",
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
gateway._client = None
|
| 235 |
+
|
| 236 |
+
assert attempts["count"] == 2
|
| 237 |
+
assert result.ok
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@pytest.mark.asyncio
|
| 241 |
+
async def test_notify_tool_rejects_non_allowlisted_destinations():
|
| 242 |
+
config = _config_with_messaging(allow_agent_tool=False)
|
| 243 |
+
gateway = RecordingGateway()
|
| 244 |
+
session = _test_session(config, gateway)
|
| 245 |
+
|
| 246 |
+
output, ok = await notify_handler(
|
| 247 |
+
{"destinations": ["slack.ops"], "message": "done"},
|
| 248 |
+
session=session,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
assert not ok
|
| 252 |
+
assert "unavailable for the notify tool" in output
|
| 253 |
+
assert gateway.sent == []
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
@pytest.mark.asyncio
|
| 257 |
+
async def test_notify_tool_sends_to_allowlisted_destinations():
|
| 258 |
+
config = _config_with_messaging(allow_agent_tool=True)
|
| 259 |
+
gateway = RecordingGateway()
|
| 260 |
+
session = _test_session(config, gateway, session_id="sess-42")
|
| 261 |
+
|
| 262 |
+
output, ok = await notify_handler(
|
| 263 |
+
{
|
| 264 |
+
"destinations": ["slack.ops"],
|
| 265 |
+
"title": "Training complete",
|
| 266 |
+
"message": "The run finished successfully.",
|
| 267 |
+
"severity": "success",
|
| 268 |
+
},
|
| 269 |
+
session=session,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
assert ok
|
| 273 |
+
assert output == "slack.ops: sent"
|
| 274 |
+
assert len(gateway.sent) == 1
|
| 275 |
+
sent = gateway.sent[0]
|
| 276 |
+
assert sent.metadata["session_id"] == "sess-42"
|
| 277 |
+
assert sent.metadata["model"] == "moonshotai/Kimi-K2.6"
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@pytest.mark.asyncio
|
| 281 |
+
async def test_session_auto_notifications_only_send_opted_in_auto_destinations():
|
| 282 |
+
config = Config.model_validate(
|
| 283 |
+
{
|
| 284 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 285 |
+
"messaging": {
|
| 286 |
+
"enabled": True,
|
| 287 |
+
"destinations": {
|
| 288 |
+
"slack.ops": {
|
| 289 |
+
"provider": "slack",
|
| 290 |
+
"token": "xoxb-test",
|
| 291 |
+
"channel": "C123",
|
| 292 |
+
"allow_auto_events": True,
|
| 293 |
+
},
|
| 294 |
+
"slack.tool": {
|
| 295 |
+
"provider": "slack",
|
| 296 |
+
"token": "xoxb-test",
|
| 297 |
+
"channel": "C999",
|
| 298 |
+
"allow_agent_tool": True,
|
| 299 |
+
},
|
| 300 |
+
},
|
| 301 |
+
},
|
| 302 |
+
}
|
| 303 |
+
)
|
| 304 |
+
gateway = RecordingGateway()
|
| 305 |
+
session = _test_session(config, gateway, session_id="sess-auto")
|
| 306 |
+
session.set_notification_destinations(["slack.ops", "slack.tool"])
|
| 307 |
+
|
| 308 |
+
await session.send_event(
|
| 309 |
+
Event(
|
| 310 |
+
event_type="approval_required",
|
| 311 |
+
data={"tools": [{"tool": "hf_jobs", "tool_call_id": "tc-1"}]},
|
| 312 |
+
)
|
| 313 |
+
)
|
| 314 |
+
await session.send_event(
|
| 315 |
+
Event(event_type="assistant_message", data={"content": "normal message"})
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
assert len(gateway.enqueued) == 1
|
| 319 |
+
request = gateway.enqueued[0]
|
| 320 |
+
assert request.destination == "slack.ops"
|
| 321 |
+
assert request.severity == "warning"
|
| 322 |
+
assert request.event_type == "approval_required"
|
| 323 |
+
assert "hf_jobs" in request.message
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
@pytest.mark.asyncio
|
| 327 |
+
async def test_turn_complete_auto_notification_includes_final_response_summary():
|
| 328 |
+
config = Config.model_validate(
|
| 329 |
+
{
|
| 330 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 331 |
+
"messaging": {
|
| 332 |
+
"enabled": True,
|
| 333 |
+
"destinations": {
|
| 334 |
+
"slack.ops": {
|
| 335 |
+
"provider": "slack",
|
| 336 |
+
"token": "xoxb-test",
|
| 337 |
+
"channel": "C123",
|
| 338 |
+
"allow_auto_events": True,
|
| 339 |
+
}
|
| 340 |
+
},
|
| 341 |
+
},
|
| 342 |
+
}
|
| 343 |
+
)
|
| 344 |
+
gateway = RecordingGateway()
|
| 345 |
+
session = _test_session(config, gateway, session_id="sess-done")
|
| 346 |
+
session.set_notification_destinations(["slack.ops"])
|
| 347 |
+
|
| 348 |
+
await session.send_event(
|
| 349 |
+
Event(
|
| 350 |
+
event_type="turn_complete",
|
| 351 |
+
data={
|
| 352 |
+
"history_size": 12,
|
| 353 |
+
"final_response": "Evaluation finished. Accuracy: 84.2% on the validation split.",
|
| 354 |
+
},
|
| 355 |
+
)
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
assert len(gateway.enqueued) == 1
|
| 359 |
+
request = gateway.enqueued[0]
|
| 360 |
+
assert request.destination == "slack.ops"
|
| 361 |
+
assert request.severity == "success"
|
| 362 |
+
assert request.event_type == "turn_complete"
|
| 363 |
+
assert "completed successfully" in request.message
|
| 364 |
+
assert "Accuracy: 84.2%" in request.message
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
@pytest.mark.asyncio
|
| 368 |
+
async def test_turn_complete_auto_notification_supports_longer_summary():
|
| 369 |
+
config = Config.model_validate(
|
| 370 |
+
{
|
| 371 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 372 |
+
"messaging": {
|
| 373 |
+
"enabled": True,
|
| 374 |
+
"destinations": {
|
| 375 |
+
"slack.ops": {
|
| 376 |
+
"provider": "slack",
|
| 377 |
+
"token": "xoxb-test",
|
| 378 |
+
"channel": "C123",
|
| 379 |
+
"allow_auto_events": True,
|
| 380 |
+
}
|
| 381 |
+
},
|
| 382 |
+
},
|
| 383 |
+
}
|
| 384 |
+
)
|
| 385 |
+
gateway = RecordingGateway()
|
| 386 |
+
session = _test_session(config, gateway, session_id="sess-long")
|
| 387 |
+
session.set_notification_destinations(["slack.ops"])
|
| 388 |
+
|
| 389 |
+
long_summary = "A" * 1200 + " END"
|
| 390 |
+
await session.send_event(
|
| 391 |
+
Event(
|
| 392 |
+
event_type="turn_complete",
|
| 393 |
+
data={
|
| 394 |
+
"history_size": 12,
|
| 395 |
+
"final_response": long_summary,
|
| 396 |
+
},
|
| 397 |
+
)
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
assert len(gateway.enqueued) == 1
|
| 401 |
+
request = gateway.enqueued[0]
|
| 402 |
+
assert request.event_type == "turn_complete"
|
| 403 |
+
assert "A" * 1200 in request.message
|
| 404 |
+
assert request.message.endswith("END")
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
@pytest.mark.asyncio
|
| 408 |
+
async def test_turn_complete_auto_notification_can_be_deferred():
|
| 409 |
+
config = Config.model_validate(
|
| 410 |
+
{
|
| 411 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 412 |
+
"messaging": {
|
| 413 |
+
"enabled": True,
|
| 414 |
+
"destinations": {
|
| 415 |
+
"slack.ops": {
|
| 416 |
+
"provider": "slack",
|
| 417 |
+
"token": "xoxb-test",
|
| 418 |
+
"channel": "C123",
|
| 419 |
+
"allow_auto_events": True,
|
| 420 |
+
}
|
| 421 |
+
},
|
| 422 |
+
},
|
| 423 |
+
}
|
| 424 |
+
)
|
| 425 |
+
gateway = RecordingGateway()
|
| 426 |
+
session = Session(
|
| 427 |
+
asyncio.Queue(),
|
| 428 |
+
config=config,
|
| 429 |
+
tool_router=DummyToolRouter(),
|
| 430 |
+
context_manager=SimpleNamespace(items=[]),
|
| 431 |
+
notification_gateway=gateway,
|
| 432 |
+
notification_destinations=["slack.ops"],
|
| 433 |
+
defer_turn_complete_notification=True,
|
| 434 |
+
session_id="sess-deferred",
|
| 435 |
+
)
|
| 436 |
+
event = Event(
|
| 437 |
+
event_type="turn_complete",
|
| 438 |
+
data={"final_response": "Finished after the CLI drained the stream."},
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
await session.send_event(event)
|
| 442 |
+
assert gateway.enqueued == []
|
| 443 |
+
|
| 444 |
+
await session.send_deferred_turn_complete_notification(event)
|
| 445 |
+
|
| 446 |
+
assert len(gateway.enqueued) == 1
|
| 447 |
+
request = gateway.enqueued[0]
|
| 448 |
+
assert request.destination == "slack.ops"
|
| 449 |
+
assert request.event_type == "turn_complete"
|
| 450 |
+
assert "Finished after the CLI drained the stream." in request.message
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
@pytest.mark.asyncio
|
| 454 |
+
async def test_turn_complete_can_be_disabled_by_custom_auto_event_config():
|
| 455 |
+
config = Config.model_validate(
|
| 456 |
+
{
|
| 457 |
+
"model_name": "moonshotai/Kimi-K2.6",
|
| 458 |
+
"messaging": {
|
| 459 |
+
"enabled": True,
|
| 460 |
+
"auto_event_types": ["error"],
|
| 461 |
+
"destinations": {
|
| 462 |
+
"slack.ops": {
|
| 463 |
+
"provider": "slack",
|
| 464 |
+
"token": "xoxb-test",
|
| 465 |
+
"channel": "C123",
|
| 466 |
+
"allow_auto_events": True,
|
| 467 |
+
}
|
| 468 |
+
},
|
| 469 |
+
},
|
| 470 |
+
}
|
| 471 |
+
)
|
| 472 |
+
gateway = RecordingGateway()
|
| 473 |
+
session = _test_session(config, gateway, session_id="sess-optout")
|
| 474 |
+
session.set_notification_destinations(["slack.ops"])
|
| 475 |
+
|
| 476 |
+
await session.send_event(
|
| 477 |
+
Event(
|
| 478 |
+
event_type="turn_complete",
|
| 479 |
+
data={"final_response": "This should not notify."},
|
| 480 |
+
)
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
assert gateway.enqueued == []
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def test_session_manager_updates_notification_destinations_in_session_info():
|
| 487 |
+
config = _config_with_messaging(allow_auto_events=True)
|
| 488 |
+
manager = SessionManager(str(Path(__file__).resolve().parents[2] / "configs" / "cli_agent_config.json"))
|
| 489 |
+
manager.config = config
|
| 490 |
+
manager.sessions = {}
|
| 491 |
+
|
| 492 |
+
session = _test_session(config, RecordingGateway(), session_id="sess-manager")
|
| 493 |
+
manager.sessions["sess-manager"] = AgentSession(
|
| 494 |
+
session_id="sess-manager",
|
| 495 |
+
session=session,
|
| 496 |
+
tool_router=DummyToolRouter(),
|
| 497 |
+
submission_queue=asyncio.Queue(),
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
updated = manager.set_notification_destinations(
|
| 501 |
+
"sess-manager",
|
| 502 |
+
["slack.ops", "slack.ops"],
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
assert updated == ["slack.ops"]
|
| 506 |
+
info = manager.get_session_info("sess-manager")
|
| 507 |
+
assert info is not None
|
| 508 |
+
assert info["notification_destinations"] == ["slack.ops"]
|
| 509 |
+
|
| 510 |
+
with pytest.raises(ValueError):
|
| 511 |
+
manager.set_notification_destinations("sess-manager", ["slack.unknown"])
|
|
@@ -159,7 +159,7 @@ async def test_streaming_call_rebuilds_anthropic_thinking_state(monkeypatch):
|
|
| 159 |
|
| 160 |
|
| 161 |
@pytest.mark.asyncio
|
| 162 |
-
async def
|
| 163 |
async def fake_stream():
|
| 164 |
yield SimpleNamespace(
|
| 165 |
choices=[
|
|
@@ -167,7 +167,31 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
|
|
| 167 |
delta=SimpleNamespace(
|
| 168 |
content=None,
|
| 169 |
tool_calls=None,
|
| 170 |
-
thinking_blocks=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
),
|
| 172 |
finish_reason=None,
|
| 173 |
)
|
|
@@ -186,8 +210,26 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
|
|
| 186 |
async def fake_acompletion(**_kwargs):
|
| 187 |
return fake_stream()
|
| 188 |
|
| 189 |
-
def
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
events = []
|
| 193 |
async def send_event(event):
|
|
@@ -199,7 +241,7 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
|
|
| 199 |
send_event=send_event,
|
| 200 |
)
|
| 201 |
monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
|
| 202 |
-
monkeypatch.setattr(agent_loop, "stream_chunk_builder",
|
| 203 |
|
| 204 |
result = await _call_llm_streaming(
|
| 205 |
session,
|
|
@@ -209,7 +251,10 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
|
|
| 209 |
)
|
| 210 |
|
| 211 |
assert result.content == "done"
|
| 212 |
-
assert result.thinking_blocks == [
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
|
| 215 |
@pytest.mark.asyncio
|
|
|
|
| 159 |
|
| 160 |
|
| 161 |
@pytest.mark.asyncio
|
| 162 |
+
async def test_streaming_call_rebuilds_anthropic_delta_thinking_state(monkeypatch):
|
| 163 |
async def fake_stream():
|
| 164 |
yield SimpleNamespace(
|
| 165 |
choices=[
|
|
|
|
| 167 |
delta=SimpleNamespace(
|
| 168 |
content=None,
|
| 169 |
tool_calls=None,
|
| 170 |
+
thinking_blocks=[
|
| 171 |
+
{
|
| 172 |
+
"type": "thinking",
|
| 173 |
+
"thinking": "reasoned",
|
| 174 |
+
"signature": "",
|
| 175 |
+
}
|
| 176 |
+
],
|
| 177 |
+
),
|
| 178 |
+
finish_reason=None,
|
| 179 |
+
)
|
| 180 |
+
],
|
| 181 |
+
)
|
| 182 |
+
yield SimpleNamespace(
|
| 183 |
+
choices=[
|
| 184 |
+
SimpleNamespace(
|
| 185 |
+
delta=SimpleNamespace(
|
| 186 |
+
content=None,
|
| 187 |
+
tool_calls=None,
|
| 188 |
+
thinking_blocks=[
|
| 189 |
+
{
|
| 190 |
+
"type": "thinking",
|
| 191 |
+
"thinking": "",
|
| 192 |
+
"signature": "signed",
|
| 193 |
+
}
|
| 194 |
+
],
|
| 195 |
),
|
| 196 |
finish_reason=None,
|
| 197 |
)
|
|
|
|
| 210 |
async def fake_acompletion(**_kwargs):
|
| 211 |
return fake_stream()
|
| 212 |
|
| 213 |
+
def fake_chunk_builder(chunks, **_kwargs):
|
| 214 |
+
assert len(chunks) == 4
|
| 215 |
+
return SimpleNamespace(
|
| 216 |
+
choices=[
|
| 217 |
+
SimpleNamespace(
|
| 218 |
+
message=Message(
|
| 219 |
+
role="assistant",
|
| 220 |
+
content="done",
|
| 221 |
+
thinking_blocks=[
|
| 222 |
+
{
|
| 223 |
+
"type": "thinking",
|
| 224 |
+
"thinking": "reasoned",
|
| 225 |
+
"signature": "signed",
|
| 226 |
+
}
|
| 227 |
+
],
|
| 228 |
+
reasoning_content="reasoned",
|
| 229 |
+
)
|
| 230 |
+
)
|
| 231 |
+
]
|
| 232 |
+
)
|
| 233 |
|
| 234 |
events = []
|
| 235 |
async def send_event(event):
|
|
|
|
| 241 |
send_event=send_event,
|
| 242 |
)
|
| 243 |
monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
|
| 244 |
+
monkeypatch.setattr(agent_loop, "stream_chunk_builder", fake_chunk_builder)
|
| 245 |
|
| 246 |
result = await _call_llm_streaming(
|
| 247 |
session,
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
assert result.content == "done"
|
| 254 |
+
assert result.thinking_blocks == [
|
| 255 |
+
{"type": "thinking", "thinking": "reasoned", "signature": "signed"}
|
| 256 |
+
]
|
| 257 |
+
assert result.reasoning_content == "reasoned"
|
| 258 |
|
| 259 |
|
| 260 |
@pytest.mark.asyncio
|
|
@@ -1832,7 +1832,7 @@ requires-dist = [
|
|
| 1832 |
{ name = "prompt-toolkit", specifier = ">=3.0.0" },
|
| 1833 |
{ name = "pydantic", specifier = ">=2.12.3" },
|
| 1834 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1835 |
-
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=
|
| 1836 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
| 1837 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1838 |
{ name = "rich", specifier = ">=13.0.0" },
|
|
|
|
| 1832 |
{ name = "prompt-toolkit", specifier = ">=3.0.0" },
|
| 1833 |
{ name = "pydantic", specifier = ">=2.12.3" },
|
| 1834 |
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
|
| 1835 |
+
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
|
| 1836 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
| 1837 |
{ name = "requests", specifier = ">=2.33.0" },
|
| 1838 |
{ name = "rich", specifier = ">=13.0.0" },
|