lewtun HF Staff OpenAI Codex commited on
Commit
6155b26
·
unverified ·
1 Parent(s): 2d4ec20

Add Slack gateway (#116)

Browse files

* Add messaging gateway and notify tool

Co-authored-by: OpenAI Codex <codex@openai.com>

* Handle Bedrock streaming permission denials

Co-authored-by: OpenAI Codex <codex@openai.com>

* Revert "Handle Bedrock streaming permission denials"

Co-authored-by: OpenAI Codex <codex@openai.com>

* Add automatic completion notifications

Co-authored-by: OpenAI Codex <codex@openai.com>

* Auto-attach CLI notification destinations

Co-authored-by: OpenAI Codex <codex@openai.com>

* Defer CLI completion notifications until render

Co-authored-by: OpenAI Codex <codex@openai.com>

* Increase completion notification summary cap

Co-authored-by: OpenAI Codex <codex@openai.com>

* Add Slack user notification defaults

Co-authored-by: Codex <codex@openai.com>

* Increase Slack turn completion summary limit

Co-authored-by: Codex <codex@openai.com>

* Remove legacy auto notification event upgrade

Co-authored-by: Codex <codex@openai.com>

* Require session config instead of hard-coded model fallback

Co-authored-by: Codex <codex@openai.com>

* Address Slack notification review findings

Co-authored-by: Codex <codex@openai.com>

* Fix Anthropic thinking signature replay

Rebuild signed Anthropic thinking blocks from streaming chunks instead of replaying raw deltas, and recover stale histories by retrying once without thinking metadata when Anthropic rejects a signature.

Co-authored-by: OpenAI Codex <codex@openai.com>

* Format Slack notifications with mrkdwn

Convert common Markdown constructs in Slack notification bodies to Slack mrkdwn before posting, while preserving code spans and fenced code blocks.

Co-authored-by: OpenAI Codex <codex@openai.com>

---------

Co-authored-by: OpenAI Codex <codex@openai.com>

README.md CHANGED
@@ -56,6 +56,56 @@ ml-intern --max-iterations 100 "your prompt"
56
  ml-intern --no-stream "your prompt"
57
  ```
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  ## Architecture
60
 
61
  ### Component Overview
 
56
  ml-intern --no-stream "your prompt"
57
  ```
58
 
59
+ ## Supported Gateways
60
+
61
+ ML Intern currently supports one-way notification gateways from CLI sessions.
62
+ These gateways send out-of-band status updates; they do not accept inbound chat
63
+ messages.
64
+
65
+ ### Slack
66
+
67
+ Slack notifications use the Slack Web API to post messages when the agent needs
68
+ approval, hits an error, or completes a turn. Create a Slack app with a bot token
69
+ that has `chat:write`, invite the bot to the target channel, then set:
70
+
71
+ ```bash
72
+ SLACK_BOT_TOKEN=xoxb-...
73
+ SLACK_CHANNEL_ID=C...
74
+ ```
75
+
76
+ The CLI automatically creates a `slack.default` destination when both variables
77
+ are present. Optional environment variables for the env-only default:
78
+
79
+ ```bash
80
+ ML_INTERN_SLACK_NOTIFICATIONS=false
81
+ ML_INTERN_SLACK_DESTINATION=slack.ops
82
+ ML_INTERN_SLACK_AUTO_EVENTS=approval_required,error,turn_complete
83
+ ML_INTERN_SLACK_ALLOW_AGENT_TOOL=true
84
+ ML_INTERN_SLACK_ALLOW_AUTO_EVENTS=true
85
+ ```
86
+
87
+ For a persistent user-level config, put overrides in
88
+ `~/.config/ml-intern/cli_agent_config.json` or point `ML_INTERN_CLI_CONFIG` at a
89
+ JSON file:
90
+
91
+ ```json
92
+ {
93
+ "messaging": {
94
+ "enabled": true,
95
+ "auto_event_types": ["approval_required", "error", "turn_complete"],
96
+ "destinations": {
97
+ "slack.ops": {
98
+ "provider": "slack",
99
+ "token": "${SLACK_BOT_TOKEN}",
100
+ "channel": "${SLACK_CHANNEL_ID}",
101
+ "allow_agent_tool": true,
102
+ "allow_auto_events": true
103
+ }
104
+ }
105
+ }
106
+ }
107
+ ```
108
+
109
  ## Architecture
110
 
111
  ### Component Overview
agent/config.py CHANGED
@@ -6,6 +6,8 @@ from typing import Any, Union
6
 
7
  from dotenv import load_dotenv
8
 
 
 
9
  # Project root: two levels up from this file (agent/config.py -> project root)
10
  _PROJECT_ROOT = Path(__file__).resolve().parent.parent
11
  from fastmcp.mcp_config import (
@@ -47,6 +49,104 @@ class Config(BaseModel):
47
  # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
48
  # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
49
  reasoning_effort: str | None = "max"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  def substitute_env_vars(obj: Any) -> Any:
@@ -86,7 +186,10 @@ def substitute_env_vars(obj: Any) -> Any:
86
  return obj
87
 
88
 
89
- def load_config(config_path: str = "config.json") -> Config:
 
 
 
90
  """
91
  Load configuration with environment variable substitution.
92
 
@@ -98,8 +201,10 @@ def load_config(config_path: str = "config.json") -> Config:
98
  load_dotenv(_PROJECT_ROOT / ".env")
99
  load_dotenv(override=False)
100
 
101
- with open(config_path, "r") as f:
102
- raw_config = json.load(f)
 
 
103
 
104
  config_with_env = substitute_env_vars(raw_config)
105
  return Config.model_validate(config_with_env)
 
6
 
7
  from dotenv import load_dotenv
8
 
9
+ from agent.messaging.models import MessagingConfig
10
+
11
  # Project root: two levels up from this file (agent/config.py -> project root)
12
  _PROJECT_ROOT = Path(__file__).resolve().parent.parent
13
  from fastmcp.mcp_config import (
 
49
  # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
50
  # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
51
  reasoning_effort: str | None = "max"
52
+ messaging: MessagingConfig = MessagingConfig()
53
+
54
+
55
+ USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
56
+ DEFAULT_USER_CONFIG_PATH = Path.home() / ".config" / "ml-intern" / "cli_agent_config.json"
57
+ SLACK_DEFAULT_DESTINATION = "slack.default"
58
+ SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]
59
+
60
+
61
+ def _deep_merge_config(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
62
+ merged = dict(base)
63
+ for key, value in override.items():
64
+ current = merged.get(key)
65
+ if isinstance(current, dict) and isinstance(value, dict):
66
+ merged[key] = _deep_merge_config(current, value)
67
+ else:
68
+ merged[key] = value
69
+ return merged
70
+
71
+
72
+ def _load_json_config(path: Path) -> dict[str, Any]:
73
+ with open(path, "r", encoding="utf-8") as f:
74
+ data = json.load(f)
75
+ if not isinstance(data, dict):
76
+ raise ValueError(f"Config file {path} must contain a JSON object")
77
+ return data
78
+
79
+
80
+ def _load_user_config() -> dict[str, Any]:
81
+ raw_path = os.environ.get(USER_CONFIG_ENV_VAR)
82
+ if raw_path:
83
+ path = Path(raw_path).expanduser()
84
+ if not path.exists():
85
+ raise FileNotFoundError(
86
+ f"{USER_CONFIG_ENV_VAR} points to missing config file: {path}"
87
+ )
88
+ return _load_json_config(path)
89
+
90
+ if DEFAULT_USER_CONFIG_PATH.exists():
91
+ return _load_json_config(DEFAULT_USER_CONFIG_PATH)
92
+ return {}
93
+
94
+
95
+ def _env_bool(name: str, default: bool) -> bool:
96
+ value = os.environ.get(name)
97
+ if value is None:
98
+ return default
99
+ normalized = value.strip().lower()
100
+ if normalized in {"1", "true", "yes", "on"}:
101
+ return True
102
+ if normalized in {"0", "false", "no", "off"}:
103
+ return False
104
+ return default
105
+
106
+
107
+ def _env_list(name: str) -> list[str] | None:
108
+ value = os.environ.get(name)
109
+ if value is None:
110
+ return None
111
+ return [item.strip() for item in value.split(",") if item.strip()]
112
+
113
+
114
+ def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]:
115
+ """Enable a default Slack destination from user env vars, when present."""
116
+ if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True):
117
+ return raw_config
118
+
119
+ token = os.environ.get("SLACK_BOT_TOKEN")
120
+ channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL")
121
+ if not token or not channel:
122
+ return raw_config
123
+
124
+ config = dict(raw_config)
125
+ messaging = dict(config.get("messaging") or {})
126
+ destinations = dict(messaging.get("destinations") or {})
127
+ destination_name = (
128
+ os.environ.get("ML_INTERN_SLACK_DESTINATION") or SLACK_DEFAULT_DESTINATION
129
+ ).strip()
130
+
131
+ if destination_name not in destinations:
132
+ destinations[destination_name] = {
133
+ "provider": "slack",
134
+ "token": token,
135
+ "channel": channel,
136
+ "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True),
137
+ "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True),
138
+ }
139
+
140
+ auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS")
141
+ if auto_events is not None:
142
+ messaging["auto_event_types"] = auto_events
143
+ elif "auto_event_types" not in messaging:
144
+ messaging["auto_event_types"] = SLACK_DEFAULT_AUTO_EVENT_TYPES
145
+
146
+ messaging["enabled"] = True
147
+ messaging["destinations"] = destinations
148
+ config["messaging"] = messaging
149
+ return config
150
 
151
 
152
  def substitute_env_vars(obj: Any) -> Any:
 
186
  return obj
187
 
188
 
189
+ def load_config(
190
+ config_path: str = "config.json",
191
+ include_user_defaults: bool = False,
192
+ ) -> Config:
193
  """
194
  Load configuration with environment variable substitution.
195
 
 
201
  load_dotenv(_PROJECT_ROOT / ".env")
202
  load_dotenv(override=False)
203
 
204
+ raw_config = _load_json_config(Path(config_path))
205
+ if include_user_defaults:
206
+ raw_config = _deep_merge_config(raw_config, _load_user_config())
207
+ raw_config = apply_slack_user_defaults(raw_config)
208
 
209
  config_with_env = substitute_env_vars(raw_config)
210
  return Config.model_validate(config_with_env)
agent/core/agent_loop.py CHANGED
@@ -19,6 +19,7 @@ from litellm import (
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
 
22
  from agent.core import telemetry
23
  from agent.core.doom_loop import check_for_doom_loop
24
  from agent.core.llm_params import _resolve_llm_params
@@ -432,6 +433,103 @@ def _should_replay_thinking_state(model_name: str | None) -> bool:
432
  return bool(model_name and model_name.startswith("anthropic/"))
433
 
434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  def _assistant_message_from_result(
436
  llm_result: LLMResult,
437
  *,
@@ -457,6 +555,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
457
  """Call the LLM with streaming, emitting assistant_chunk events."""
458
  response = None
459
  _healed_effort = False # one-shot safety net per call
 
460
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
461
  t_start = time.monotonic()
462
  for _llm_attempt in range(_MAX_LLM_RETRIES):
@@ -484,6 +583,14 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
484
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
485
  ))
486
  continue
 
 
 
 
 
 
 
 
487
  _delay = _retry_delay_for(e, _llm_attempt)
488
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
489
  logger.warning(
@@ -505,8 +612,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
505
  final_usage_chunk = None
506
  chunks = []
507
  should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
508
- collected_thinking_blocks: list[dict[str, Any]] = []
509
- collected_reasoning_content: list[str] = []
510
 
511
  async for chunk in response:
512
  chunks.append(chunk)
@@ -525,13 +630,6 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
525
  if choice.finish_reason:
526
  finish_reason = choice.finish_reason
527
 
528
- if should_replay_thinking:
529
- delta_thinking_blocks, delta_reasoning_content = _extract_thinking_state(delta)
530
- if delta_thinking_blocks:
531
- collected_thinking_blocks.extend(delta_thinking_blocks)
532
- if delta_reasoning_content:
533
- collected_reasoning_content.append(delta_reasoning_content)
534
-
535
  if delta.content:
536
  full_content += delta.content
537
  await session.send_event(
@@ -565,9 +663,9 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
565
  latency_ms=int((time.monotonic() - t_start) * 1000),
566
  finish_reason=finish_reason,
567
  )
568
- thinking_blocks = collected_thinking_blocks or None
569
- reasoning_content = "".join(collected_reasoning_content) or None
570
- if chunks and should_replay_thinking and not (thinking_blocks or reasoning_content):
571
  try:
572
  rebuilt = stream_chunk_builder(chunks, messages=messages)
573
  if rebuilt and getattr(rebuilt, "choices", None):
@@ -591,6 +689,7 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
591
  """Call the LLM without streaming, emit assistant_message at the end."""
592
  response = None
593
  _healed_effort = False
 
594
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
595
  t_start = time.monotonic()
596
  for _llm_attempt in range(_MAX_LLM_RETRIES):
@@ -617,6 +716,14 @@ async def _call_llm_non_streaming(session: Session, messages, tools, llm_params)
617
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
618
  ))
619
  continue
 
 
 
 
 
 
 
 
620
  _delay = _retry_delay_for(e, _llm_attempt)
621
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
622
  logger.warning(
@@ -1128,7 +1235,12 @@ class Handlers:
1128
  await session.send_event(
1129
  Event(
1130
  event_type="turn_complete",
1131
- data={"history_size": len(session.context_manager.items)},
 
 
 
 
 
1132
  )
1133
  )
1134
 
@@ -1437,13 +1549,16 @@ async def process_submission(session: Session, submission) -> bool:
1437
  async def submission_loop(
1438
  submission_queue: asyncio.Queue,
1439
  event_queue: asyncio.Queue,
1440
- config: Config | None = None,
1441
  tool_router: ToolRouter | None = None,
1442
  session_holder: list | None = None,
1443
  hf_token: str | None = None,
1444
  user_id: str | None = None,
1445
  local_mode: bool = False,
1446
  stream: bool = True,
 
 
 
1447
  ) -> None:
1448
  """
1449
  Main agent loop - processes submissions and dispatches to handlers.
@@ -1454,6 +1569,9 @@ async def submission_loop(
1454
  session = Session(
1455
  event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1456
  user_id=user_id, local_mode=local_mode, stream=stream,
 
 
 
1457
  )
1458
  if session_holder is not None:
1459
  session_holder[0] = session
 
19
  from litellm.exceptions import ContextWindowExceededError
20
 
21
  from agent.config import Config
22
+ from agent.messaging.gateway import NotificationGateway
23
  from agent.core import telemetry
24
  from agent.core.doom_loop import check_for_doom_loop
25
  from agent.core.llm_params import _resolve_llm_params
 
433
  return bool(model_name and model_name.startswith("anthropic/"))
434
 
435
 
436
+ def _is_invalid_thinking_signature_error(exc: Exception) -> bool:
437
+ """Return True when Anthropic rejected replayed extended-thinking state."""
438
+ text = str(exc)
439
+ return (
440
+ "Invalid `signature` in `thinking` block" in text
441
+ or "Invalid signature in thinking block" in text
442
+ )
443
+
444
+
445
+ def _strip_thinking_state_from_messages(messages: list[Any]) -> int:
446
+ """Remove replayed thinking metadata from assistant history messages."""
447
+ stripped = 0
448
+
449
+ for message in messages:
450
+ role = (
451
+ message.get("role")
452
+ if isinstance(message, dict)
453
+ else getattr(message, "role", None)
454
+ )
455
+ if role != "assistant":
456
+ continue
457
+
458
+ if isinstance(message, dict):
459
+ if message.pop("thinking_blocks", None) is not None:
460
+ stripped += 1
461
+ if message.pop("reasoning_content", None) is not None:
462
+ stripped += 1
463
+ provider_fields = message.get("provider_specific_fields")
464
+ content = message.get("content")
465
+ else:
466
+ if getattr(message, "thinking_blocks", None) is not None:
467
+ message.thinking_blocks = None
468
+ stripped += 1
469
+ if getattr(message, "reasoning_content", None) is not None:
470
+ message.reasoning_content = None
471
+ stripped += 1
472
+ provider_fields = getattr(message, "provider_specific_fields", None)
473
+ content = getattr(message, "content", None)
474
+
475
+ if isinstance(provider_fields, dict):
476
+ cleaned_fields = dict(provider_fields)
477
+ if cleaned_fields.pop("thinking_blocks", None) is not None:
478
+ stripped += 1
479
+ if cleaned_fields.pop("reasoning_content", None) is not None:
480
+ stripped += 1
481
+ if cleaned_fields != provider_fields:
482
+ if isinstance(message, dict):
483
+ message["provider_specific_fields"] = cleaned_fields
484
+ else:
485
+ message.provider_specific_fields = cleaned_fields
486
+
487
+ if isinstance(content, list):
488
+ cleaned_content = [
489
+ block
490
+ for block in content
491
+ if not (
492
+ isinstance(block, dict)
493
+ and block.get("type") in {"thinking", "redacted_thinking"}
494
+ )
495
+ ]
496
+ if len(cleaned_content) != len(content):
497
+ stripped += len(content) - len(cleaned_content)
498
+ if isinstance(message, dict):
499
+ message["content"] = cleaned_content
500
+ else:
501
+ message.content = cleaned_content
502
+
503
+ return stripped
504
+
505
+
506
+ async def _maybe_heal_invalid_thinking_signature(
507
+ session: Session,
508
+ messages: list[Any],
509
+ exc: Exception,
510
+ *,
511
+ already_healed: bool,
512
+ ) -> bool:
513
+ if already_healed or not _is_invalid_thinking_signature_error(exc):
514
+ return False
515
+
516
+ stripped = _strip_thinking_state_from_messages(messages)
517
+ if not stripped:
518
+ return False
519
+
520
+ await session.send_event(Event(
521
+ event_type="tool_log",
522
+ data={
523
+ "tool": "system",
524
+ "log": (
525
+ "Anthropic rejected stale thinking signatures; retrying "
526
+ "without replayed thinking metadata."
527
+ ),
528
+ },
529
+ ))
530
+ return True
531
+
532
+
533
  def _assistant_message_from_result(
534
  llm_result: LLMResult,
535
  *,
 
555
  """Call the LLM with streaming, emitting assistant_chunk events."""
556
  response = None
557
  _healed_effort = False # one-shot safety net per call
558
+ _healed_thinking_signature = False
559
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
560
  t_start = time.monotonic()
561
  for _llm_attempt in range(_MAX_LLM_RETRIES):
 
583
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
584
  ))
585
  continue
586
+ if await _maybe_heal_invalid_thinking_signature(
587
+ session,
588
+ messages,
589
+ e,
590
+ already_healed=_healed_thinking_signature,
591
+ ):
592
+ _healed_thinking_signature = True
593
+ continue
594
  _delay = _retry_delay_for(e, _llm_attempt)
595
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
596
  logger.warning(
 
612
  final_usage_chunk = None
613
  chunks = []
614
  should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
 
 
615
 
616
  async for chunk in response:
617
  chunks.append(chunk)
 
630
  if choice.finish_reason:
631
  finish_reason = choice.finish_reason
632
 
 
 
 
 
 
 
 
633
  if delta.content:
634
  full_content += delta.content
635
  await session.send_event(
 
663
  latency_ms=int((time.monotonic() - t_start) * 1000),
664
  finish_reason=finish_reason,
665
  )
666
+ thinking_blocks = None
667
+ reasoning_content = None
668
+ if chunks and should_replay_thinking:
669
  try:
670
  rebuilt = stream_chunk_builder(chunks, messages=messages)
671
  if rebuilt and getattr(rebuilt, "choices", None):
 
689
  """Call the LLM without streaming, emit assistant_message at the end."""
690
  response = None
691
  _healed_effort = False
692
+ _healed_thinking_signature = False
693
  messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
694
  t_start = time.monotonic()
695
  for _llm_attempt in range(_MAX_LLM_RETRIES):
 
716
  data={"tool": "system", "log": "Reasoning effort not supported for this model — adjusting and retrying."},
717
  ))
718
  continue
719
+ if await _maybe_heal_invalid_thinking_signature(
720
+ session,
721
+ messages,
722
+ e,
723
+ already_healed=_healed_thinking_signature,
724
+ ):
725
+ _healed_thinking_signature = True
726
+ continue
727
  _delay = _retry_delay_for(e, _llm_attempt)
728
  if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
729
  logger.warning(
 
1235
  await session.send_event(
1236
  Event(
1237
  event_type="turn_complete",
1238
+ data={
1239
+ "history_size": len(session.context_manager.items),
1240
+ "final_response": final_response
1241
+ if isinstance(final_response, str)
1242
+ else None,
1243
+ },
1244
  )
1245
  )
1246
 
 
1549
  async def submission_loop(
1550
  submission_queue: asyncio.Queue,
1551
  event_queue: asyncio.Queue,
1552
+ config: Config,
1553
  tool_router: ToolRouter | None = None,
1554
  session_holder: list | None = None,
1555
  hf_token: str | None = None,
1556
  user_id: str | None = None,
1557
  local_mode: bool = False,
1558
  stream: bool = True,
1559
+ notification_gateway: NotificationGateway | None = None,
1560
+ notification_destinations: list[str] | None = None,
1561
+ defer_turn_complete_notification: bool = False,
1562
  ) -> None:
1563
  """
1564
  Main agent loop - processes submissions and dispatches to handlers.
 
1569
  session = Session(
1570
  event_queue, config=config, tool_router=tool_router, hf_token=hf_token,
1571
  user_id=user_id, local_mode=local_mode, stream=stream,
1572
+ notification_gateway=notification_gateway,
1573
+ notification_destinations=notification_destinations,
1574
+ defer_turn_complete_notification=defer_turn_complete_notification,
1575
  )
1576
  if session_holder is not None:
1577
  session_holder[0] = session
agent/core/session.py CHANGED
@@ -12,10 +12,13 @@ from typing import Any, Optional
12
 
13
  from agent.config import Config
14
  from agent.context_manager.manager import ContextManager
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
  _DEFAULT_MAX_TOKENS = 200_000
 
19
 
20
 
21
  def _get_max_tokens_safe(model_name: str) -> int:
@@ -73,18 +76,24 @@ class Session:
73
  def __init__(
74
  self,
75
  event_queue: asyncio.Queue,
76
- config: Config | None = None,
77
  tool_router=None,
78
  context_manager: ContextManager | None = None,
79
  hf_token: str | None = None,
80
  local_mode: bool = False,
81
  stream: bool = True,
 
 
 
 
82
  user_id: str | None = None,
83
  ):
84
  self.hf_token: Optional[str] = hf_token
85
  self.user_id: Optional[str] = user_id
86
  self.tool_router = tool_router
87
  self.stream = stream
 
 
88
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
89
  self.context_manager = context_manager or ContextManager(
90
  model_max_tokens=_get_max_tokens_safe(config.model_name),
@@ -95,15 +104,16 @@ class Session:
95
  local_mode=local_mode,
96
  )
97
  self.event_queue = event_queue
98
- self.session_id = str(uuid.uuid4())
99
- self.config = config or Config(
100
- model_name="bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0",
101
- )
102
  self.is_running = True
103
  self._cancelled = asyncio.Event()
104
  self.pending_approval: Optional[dict[str, Any]] = None
105
  self.sandbox = None
106
  self._running_job_ids: set[str] = set() # HF job IDs currently executing
 
 
 
107
 
108
  # Session trajectory logging
109
  self.logged_events: list[dict] = []
@@ -138,11 +148,121 @@ class Session:
138
  "data": event.data,
139
  }
140
  )
 
141
 
142
  # Mid-turn heartbeat flush (owned by telemetry module).
143
  from agent.core.telemetry import HeartbeatSaver
 
144
  HeartbeatSaver.maybe_fire(self)
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def cancel(self) -> None:
147
  """Signal cancellation to the running agent loop."""
148
  self._cancelled.set()
 
12
 
13
  from agent.config import Config
14
  from agent.context_manager.manager import ContextManager
15
+ from agent.messaging.gateway import NotificationGateway
16
+ from agent.messaging.models import NotificationRequest
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
  _DEFAULT_MAX_TOKENS = 200_000
21
+ _TURN_COMPLETE_NOTIFICATION_CHARS = 39000
22
 
23
 
24
  def _get_max_tokens_safe(model_name: str) -> int:
 
76
  def __init__(
77
  self,
78
  event_queue: asyncio.Queue,
79
+ config: Config,
80
  tool_router=None,
81
  context_manager: ContextManager | None = None,
82
  hf_token: str | None = None,
83
  local_mode: bool = False,
84
  stream: bool = True,
85
+ notification_gateway: NotificationGateway | None = None,
86
+ notification_destinations: list[str] | None = None,
87
+ defer_turn_complete_notification: bool = False,
88
+ session_id: str | None = None,
89
  user_id: str | None = None,
90
  ):
91
  self.hf_token: Optional[str] = hf_token
92
  self.user_id: Optional[str] = user_id
93
  self.tool_router = tool_router
94
  self.stream = stream
95
+ if config is None:
96
+ raise ValueError("Session requires a Config")
97
  tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
98
  self.context_manager = context_manager or ContextManager(
99
  model_max_tokens=_get_max_tokens_safe(config.model_name),
 
104
  local_mode=local_mode,
105
  )
106
  self.event_queue = event_queue
107
+ self.session_id = session_id or str(uuid.uuid4())
108
+ self.config = config
 
 
109
  self.is_running = True
110
  self._cancelled = asyncio.Event()
111
  self.pending_approval: Optional[dict[str, Any]] = None
112
  self.sandbox = None
113
  self._running_job_ids: set[str] = set() # HF job IDs currently executing
114
+ self.notification_gateway = notification_gateway
115
+ self.notification_destinations = list(notification_destinations or [])
116
+ self.defer_turn_complete_notification = defer_turn_complete_notification
117
 
118
  # Session trajectory logging
119
  self.logged_events: list[dict] = []
 
148
  "data": event.data,
149
  }
150
  )
151
+ await self._enqueue_auto_notification_requests(event)
152
 
153
  # Mid-turn heartbeat flush (owned by telemetry module).
154
  from agent.core.telemetry import HeartbeatSaver
155
+
156
  HeartbeatSaver.maybe_fire(self)
157
 
158
+ def set_notification_destinations(self, destinations: list[str]) -> None:
159
+ """Replace the session's opted-in auto-notification destinations."""
160
+ deduped: list[str] = []
161
+ seen: set[str] = set()
162
+ for destination in destinations:
163
+ if destination not in seen:
164
+ deduped.append(destination)
165
+ seen.add(destination)
166
+ self.notification_destinations = deduped
167
+
168
+ async def send_deferred_turn_complete_notification(self, event: Event) -> None:
169
+ if event.event_type != "turn_complete":
170
+ return
171
+ await self._enqueue_auto_notification_requests(
172
+ event,
173
+ include_deferred_turn_complete=True,
174
+ )
175
+
176
+ async def _enqueue_auto_notification_requests(
177
+ self,
178
+ event: Event,
179
+ include_deferred_turn_complete: bool = False,
180
+ ) -> None:
181
+ if self.notification_gateway is None:
182
+ return
183
+ if not self.notification_destinations:
184
+ return
185
+ auto_events = set(self.config.messaging.auto_event_types)
186
+ if event.event_type not in auto_events:
187
+ return
188
+ if (
189
+ self.defer_turn_complete_notification
190
+ and event.event_type == "turn_complete"
191
+ and not include_deferred_turn_complete
192
+ ):
193
+ return
194
+
195
+ requests = self._build_auto_notification_requests(event)
196
+ for request in requests:
197
+ await self.notification_gateway.enqueue(request)
198
+
199
+ def _build_auto_notification_requests(
200
+ self, event: Event
201
+ ) -> list[NotificationRequest]:
202
+ metadata = {
203
+ "session_id": self.session_id,
204
+ "model": self.config.model_name,
205
+ "event_type": event.event_type,
206
+ }
207
+
208
+ title: str | None = None
209
+ message: str | None = None
210
+ severity = "info"
211
+ data = event.data or {}
212
+ if event.event_type == "approval_required":
213
+ tools = data.get("tools", [])
214
+ tool_names = []
215
+ for tool in tools if isinstance(tools, list) else []:
216
+ if isinstance(tool, dict):
217
+ tool_name = str(tool.get("tool") or "").strip()
218
+ if tool_name and tool_name not in tool_names:
219
+ tool_names.append(tool_name)
220
+ count = len(tools) if isinstance(tools, list) else 0
221
+ title = "Agent approval required"
222
+ message = (
223
+ f"Session {self.session_id} is waiting for approval "
224
+ f"for {count} tool call(s)."
225
+ )
226
+ if tool_names:
227
+ message += " Tools: " + ", ".join(tool_names)
228
+ severity = "warning"
229
+ elif event.event_type == "error":
230
+ title = "Agent error"
231
+ error = str(data.get("error") or "Unknown error")
232
+ message = f"Session {self.session_id} hit an error.\n{error[:500]}"
233
+ severity = "error"
234
+ elif event.event_type == "turn_complete":
235
+ title = "Agent task complete"
236
+ summary = str(data.get("final_response") or "").strip()
237
+ if summary:
238
+ summary = summary[:_TURN_COMPLETE_NOTIFICATION_CHARS]
239
+ message = (
240
+ f"Session {self.session_id} completed successfully.\n"
241
+ f"{summary}"
242
+ )
243
+ else:
244
+ message = f"Session {self.session_id} completed successfully."
245
+ severity = "success"
246
+
247
+ if message is None:
248
+ return []
249
+
250
+ requests: list[NotificationRequest] = []
251
+ for destination in self.notification_destinations:
252
+ if not self.config.messaging.can_auto_send(destination):
253
+ continue
254
+ requests.append(
255
+ NotificationRequest(
256
+ destination=destination,
257
+ title=title,
258
+ message=message,
259
+ severity=severity,
260
+ metadata=metadata,
261
+ event_type=event.event_type,
262
+ )
263
+ )
264
+ return requests
265
+
266
  def cancel(self) -> None:
267
  """Signal cancellation to the running agent loop."""
268
  self._cancelled.set()
agent/core/tools.py CHANGED
@@ -46,6 +46,7 @@ from agent.tools.hf_repo_git_tool import (
46
  hf_repo_git_handler,
47
  )
48
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
 
49
  from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
50
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
51
  from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
@@ -324,6 +325,12 @@ def create_builtin_tools(local_mode: bool = False) -> list[ToolSpec]:
324
  parameters=PLAN_TOOL_SPEC["parameters"],
325
  handler=plan_tool_handler,
326
  ),
 
 
 
 
 
 
327
  ToolSpec(
328
  name=HF_JOBS_TOOL_SPEC["name"],
329
  description=HF_JOBS_TOOL_SPEC["description"],
 
46
  hf_repo_git_handler,
47
  )
48
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
49
+ from agent.tools.notify_tool import NOTIFY_TOOL_SPEC, notify_handler
50
  from agent.tools.papers_tool import HF_PAPERS_TOOL_SPEC, hf_papers_handler
51
  from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
52
  from agent.tools.research_tool import RESEARCH_TOOL_SPEC, research_handler
 
325
  parameters=PLAN_TOOL_SPEC["parameters"],
326
  handler=plan_tool_handler,
327
  ),
328
+ ToolSpec(
329
+ name=NOTIFY_TOOL_SPEC["name"],
330
+ description=NOTIFY_TOOL_SPEC["description"],
331
+ parameters=NOTIFY_TOOL_SPEC["parameters"],
332
+ handler=notify_handler,
333
+ ),
334
  ToolSpec(
335
  name=HF_JOBS_TOOL_SPEC["name"],
336
  description=HF_JOBS_TOOL_SPEC["description"],
agent/main.py CHANGED
@@ -26,6 +26,7 @@ from agent.core import model_switcher
26
  from agent.core.hf_tokens import resolve_hf_token
27
  from agent.core.session import OpType
28
  from agent.core.tools import ToolRouter
 
29
  from agent.utils.reliability_checks import check_training_script_save_pattern
30
  from agent.utils.terminal_display import (
31
  get_console,
@@ -332,6 +333,9 @@ async def event_listener(
332
  stream_buf.discard()
333
  print_turn_complete()
334
  print_plan()
 
 
 
335
  turn_complete_event.set()
336
  elif event.event_type == "interrupted":
337
  shimmer.stop()
@@ -821,7 +825,7 @@ async def main(model: str | None = None):
821
  if not hf_token:
822
  hf_token = await _prompt_and_save_hf_token(prompt_session)
823
 
824
- config = load_config(CLI_CONFIG_PATH)
825
  if model:
826
  config.model_name = model
827
 
@@ -844,6 +848,8 @@ async def main(model: str | None = None):
844
  turn_complete_event.set()
845
  ready_event = asyncio.Event()
846
 
 
 
847
  # Create tool router with local mode
848
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
849
 
@@ -861,6 +867,9 @@ async def main(model: str | None = None):
861
  user_id=hf_user,
862
  local_mode=True,
863
  stream=True,
 
 
 
864
  )
865
  )
866
 
@@ -1016,6 +1025,8 @@ async def main(model: str | None = None):
1016
  agent_task.cancel()
1017
  # Agent didn't shut down cleanly — close MCP explicitly
1018
  await tool_router.__aexit__(None, None, None)
 
 
1019
 
1020
  # Now safe to cancel the listener (agent is done emitting events)
1021
  listener_task.cancel()
@@ -1042,8 +1053,10 @@ async def headless_main(
1042
 
1043
  print(f"HF token loaded", file=sys.stderr)
1044
 
1045
- config = load_config(CLI_CONFIG_PATH)
1046
  config.yolo_mode = True # Auto-approve everything in headless mode
 
 
1047
  hf_user = _get_hf_user(hf_token)
1048
 
1049
  if model:
@@ -1074,6 +1087,9 @@ async def headless_main(
1074
  user_id=hf_user,
1075
  local_mode=True,
1076
  stream=stream,
 
 
 
1077
  )
1078
  )
1079
 
@@ -1199,6 +1215,10 @@ async def headless_main(
1199
  stream_buf.discard()
1200
  history_size = event.data.get("history_size", "?") if event.data else "?"
1201
  print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
 
 
 
 
1202
  break
1203
 
1204
  # Shutdown
@@ -1212,6 +1232,8 @@ async def headless_main(
1212
  except asyncio.TimeoutError:
1213
  agent_task.cancel()
1214
  await tool_router.__aexit__(None, None, None)
 
 
1215
 
1216
 
1217
  def cli():
 
26
  from agent.core.hf_tokens import resolve_hf_token
27
  from agent.core.session import OpType
28
  from agent.core.tools import ToolRouter
29
+ from agent.messaging.gateway import NotificationGateway
30
  from agent.utils.reliability_checks import check_training_script_save_pattern
31
  from agent.utils.terminal_display import (
32
  get_console,
 
333
  stream_buf.discard()
334
  print_turn_complete()
335
  print_plan()
336
+ session = session_holder[0] if session_holder else None
337
+ if session is not None:
338
+ await session.send_deferred_turn_complete_notification(event)
339
  turn_complete_event.set()
340
  elif event.event_type == "interrupted":
341
  shimmer.stop()
 
825
  if not hf_token:
826
  hf_token = await _prompt_and_save_hf_token(prompt_session)
827
 
828
+ config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
829
  if model:
830
  config.model_name = model
831
 
 
848
  turn_complete_event.set()
849
  ready_event = asyncio.Event()
850
 
851
+ notification_gateway = NotificationGateway(config.messaging)
852
+ await notification_gateway.start()
853
  # Create tool router with local mode
854
  tool_router = ToolRouter(config.mcpServers, hf_token=hf_token, local_mode=True)
855
 
 
867
  user_id=hf_user,
868
  local_mode=True,
869
  stream=True,
870
+ notification_gateway=notification_gateway,
871
+ notification_destinations=config.messaging.default_auto_destinations(),
872
+ defer_turn_complete_notification=True,
873
  )
874
  )
875
 
 
1025
  agent_task.cancel()
1026
  # Agent didn't shut down cleanly — close MCP explicitly
1027
  await tool_router.__aexit__(None, None, None)
1028
+ finally:
1029
+ await notification_gateway.close()
1030
 
1031
  # Now safe to cancel the listener (agent is done emitting events)
1032
  listener_task.cancel()
 
1053
 
1054
  print(f"HF token loaded", file=sys.stderr)
1055
 
1056
+ config = load_config(CLI_CONFIG_PATH, include_user_defaults=True)
1057
  config.yolo_mode = True # Auto-approve everything in headless mode
1058
+ notification_gateway = NotificationGateway(config.messaging)
1059
+ await notification_gateway.start()
1060
  hf_user = _get_hf_user(hf_token)
1061
 
1062
  if model:
 
1087
  user_id=hf_user,
1088
  local_mode=True,
1089
  stream=stream,
1090
+ notification_gateway=notification_gateway,
1091
+ notification_destinations=config.messaging.default_auto_destinations(),
1092
+ defer_turn_complete_notification=True,
1093
  )
1094
  )
1095
 
 
1215
  stream_buf.discard()
1216
  history_size = event.data.get("history_size", "?") if event.data else "?"
1217
  print(f"\n--- Agent {event.event_type} (history_size={history_size}) ---", file=sys.stderr)
1218
+ if event.event_type == "turn_complete":
1219
+ session = session_holder[0] if session_holder else None
1220
+ if session is not None:
1221
+ await session.send_deferred_turn_complete_notification(event)
1222
  break
1223
 
1224
  # Shutdown
 
1232
  except asyncio.TimeoutError:
1233
  agent_task.cancel()
1234
  await tool_router.__aexit__(None, None, None)
1235
+ finally:
1236
+ await notification_gateway.close()
1237
 
1238
 
1239
  def cli():
agent/messaging/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from agent.messaging.gateway import NotificationGateway
2
+ from agent.messaging.models import (
3
+ MessagingConfig,
4
+ NotificationRequest,
5
+ NotificationResult,
6
+ SUPPORTED_AUTO_EVENT_TYPES,
7
+ )
8
+
9
+ __all__ = [
10
+ "MessagingConfig",
11
+ "NotificationGateway",
12
+ "NotificationRequest",
13
+ "NotificationResult",
14
+ "SUPPORTED_AUTO_EVENT_TYPES",
15
+ ]
agent/messaging/base.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import httpx
4
+
5
+ from agent.messaging.models import DestinationConfig, NotificationRequest, NotificationResult
6
+
7
+
8
+ class NotificationError(Exception):
9
+ """Delivery failed and should not be retried."""
10
+
11
+
12
+ class RetryableNotificationError(NotificationError):
13
+ """Delivery failed transiently and can be retried."""
14
+
15
+
16
+ class NotificationProvider(ABC):
17
+ provider_name: str
18
+
19
+ @abstractmethod
20
+ async def send(
21
+ self,
22
+ client: httpx.AsyncClient,
23
+ destination_name: str,
24
+ destination: DestinationConfig,
25
+ request: NotificationRequest,
26
+ ) -> NotificationResult:
27
+ """Deliver a notification to one destination."""
agent/messaging/gateway.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from collections.abc import Iterable
4
+
5
+ import httpx
6
+
7
+ from agent.messaging.base import (
8
+ NotificationError,
9
+ NotificationProvider,
10
+ RetryableNotificationError,
11
+ )
12
+ from agent.messaging.models import (
13
+ MessagingConfig,
14
+ NotificationRequest,
15
+ NotificationResult,
16
+ )
17
+ from agent.messaging.slack import SlackProvider
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _RETRY_DELAYS = (1, 2, 4)
22
+
23
+
24
+ class NotificationGateway:
25
+ def __init__(self, config: MessagingConfig):
26
+ self.config = config
27
+ self._providers: dict[str, NotificationProvider] = {
28
+ "slack": SlackProvider(),
29
+ }
30
+ self._queue: asyncio.Queue[NotificationRequest] = asyncio.Queue()
31
+ self._worker_task: asyncio.Task | None = None
32
+ self._client: httpx.AsyncClient | None = None
33
+
34
+ @property
35
+ def enabled(self) -> bool:
36
+ return self.config.enabled
37
+
38
+ async def start(self) -> None:
39
+ if not self.enabled or self._worker_task is not None:
40
+ return
41
+ self._client = httpx.AsyncClient(timeout=10.0)
42
+ self._worker_task = asyncio.create_task(self._worker(), name="notification-gateway")
43
+
44
+ async def flush(self) -> None:
45
+ if not self.enabled:
46
+ return
47
+ await self._queue.join()
48
+
49
+ async def close(self) -> None:
50
+ if not self.enabled:
51
+ return
52
+ await self.flush()
53
+ if self._worker_task is not None:
54
+ self._worker_task.cancel()
55
+ try:
56
+ await self._worker_task
57
+ except asyncio.CancelledError:
58
+ pass
59
+ self._worker_task = None
60
+ if self._client is not None:
61
+ await self._client.aclose()
62
+ self._client = None
63
+
64
+ async def send(self, request: NotificationRequest) -> NotificationResult:
65
+ if not self.enabled:
66
+ return NotificationResult(
67
+ destination=request.destination,
68
+ ok=False,
69
+ provider="disabled",
70
+ error="Messaging is disabled",
71
+ )
72
+
73
+ destination = self.config.get_destination(request.destination)
74
+ if destination is None:
75
+ return NotificationResult(
76
+ destination=request.destination,
77
+ ok=False,
78
+ provider="unknown",
79
+ error=f"Unknown destination '{request.destination}'",
80
+ )
81
+
82
+ provider = self._providers.get(destination.provider)
83
+ if provider is None:
84
+ return NotificationResult(
85
+ destination=request.destination,
86
+ ok=False,
87
+ provider=destination.provider,
88
+ error=f"No provider implementation for '{destination.provider}'",
89
+ )
90
+ return await self._send_with_retries(provider, request.destination, destination, request)
91
+
92
+ async def send_many(
93
+ self, requests: Iterable[NotificationRequest]
94
+ ) -> list[NotificationResult]:
95
+ results: list[NotificationResult] = []
96
+ for request in requests:
97
+ results.append(await self.send(request))
98
+ return results
99
+
100
+ async def enqueue(self, request: NotificationRequest) -> bool:
101
+ if not self.enabled or self._worker_task is None:
102
+ return False
103
+ await self._queue.put(request)
104
+ return True
105
+
106
+ async def _worker(self) -> None:
107
+ while True:
108
+ request = await self._queue.get()
109
+ try:
110
+ result = await self.send(request)
111
+ if not result.ok:
112
+ logger.warning(
113
+ "Notification delivery failed for %s: %s",
114
+ request.destination,
115
+ result.error,
116
+ )
117
+ except Exception:
118
+ logger.exception("Unexpected notification worker failure")
119
+ finally:
120
+ self._queue.task_done()
121
+
122
+ async def _send_with_retries(
123
+ self,
124
+ provider: NotificationProvider,
125
+ destination_name: str,
126
+ destination,
127
+ request: NotificationRequest,
128
+ ) -> NotificationResult:
129
+ client = self._client or httpx.AsyncClient(timeout=10.0)
130
+ owns_client = self._client is None
131
+ try:
132
+ for attempt in range(len(_RETRY_DELAYS) + 1):
133
+ try:
134
+ return await provider.send(client, destination_name, destination, request)
135
+ except RetryableNotificationError as exc:
136
+ if attempt >= len(_RETRY_DELAYS):
137
+ return NotificationResult(
138
+ destination=destination_name,
139
+ ok=False,
140
+ provider=provider.provider_name,
141
+ error=str(exc),
142
+ )
143
+ delay = _RETRY_DELAYS[attempt]
144
+ logger.warning(
145
+ "Retrying notification to %s in %ss after transient error: %s",
146
+ destination_name,
147
+ delay,
148
+ exc,
149
+ )
150
+ await asyncio.sleep(delay)
151
+ except NotificationError as exc:
152
+ return NotificationResult(
153
+ destination=destination_name,
154
+ ok=False,
155
+ provider=provider.provider_name,
156
+ error=str(exc),
157
+ )
158
+ return NotificationResult(
159
+ destination=destination_name,
160
+ ok=False,
161
+ provider=provider.provider_name,
162
+ error="Notification delivery exhausted retries",
163
+ )
164
+ finally:
165
+ if owns_client:
166
+ await client.aclose()
agent/messaging/models.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Literal
2
+
3
+ from pydantic import BaseModel, Field, field_validator, model_validator
4
+
5
+ _DESTINATION_NAME_CHARS = set("abcdefghijklmnopqrstuvwxyz0123456789._-")
6
+ SUPPORTED_AUTO_EVENT_TYPES = {"approval_required", "error", "turn_complete"}
7
+
8
+
9
+ class SlackDestinationConfig(BaseModel):
10
+ provider: Literal["slack"] = "slack"
11
+ token: str
12
+ channel: str
13
+ allow_agent_tool: bool = False
14
+ allow_auto_events: bool = False
15
+ username: str | None = None
16
+ icon_emoji: str | None = None
17
+
18
+ @field_validator("token", "channel")
19
+ @classmethod
20
+ def _require_non_empty(cls, value: str) -> str:
21
+ value = value.strip()
22
+ if not value:
23
+ raise ValueError("must not be empty")
24
+ return value
25
+
26
+
27
+ DestinationConfig = Annotated[SlackDestinationConfig, Field(discriminator="provider")]
28
+
29
+
30
+ class MessagingConfig(BaseModel):
31
+ enabled: bool = False
32
+ auto_event_types: list[str] = Field(
33
+ default_factory=lambda: ["approval_required", "error", "turn_complete"]
34
+ )
35
+ destinations: dict[str, DestinationConfig] = Field(default_factory=dict)
36
+
37
+ @field_validator("destinations")
38
+ @classmethod
39
+ def _validate_destination_names(
40
+ cls, destinations: dict[str, DestinationConfig]
41
+ ) -> dict[str, DestinationConfig]:
42
+ for name in destinations:
43
+ if not name or any(char not in _DESTINATION_NAME_CHARS for char in name):
44
+ raise ValueError(
45
+ "destination names must use lowercase letters, digits, '.', '_' or '-'"
46
+ )
47
+ return destinations
48
+
49
+ @field_validator("auto_event_types")
50
+ @classmethod
51
+ def _validate_auto_event_types(cls, event_types: list[str]) -> list[str]:
52
+ if not event_types:
53
+ return []
54
+ normalized: list[str] = []
55
+ seen: set[str] = set()
56
+ for event_type in event_types:
57
+ if event_type not in SUPPORTED_AUTO_EVENT_TYPES:
58
+ raise ValueError(
59
+ f"unsupported auto event type '{event_type}'"
60
+ )
61
+ if event_type not in seen:
62
+ normalized.append(event_type)
63
+ seen.add(event_type)
64
+ return normalized
65
+
66
+ @model_validator(mode="after")
67
+ def _require_destinations_when_enabled(self) -> "MessagingConfig":
68
+ if self.enabled and not self.destinations:
69
+ raise ValueError("messaging.enabled requires at least one destination")
70
+ return self
71
+
72
+ def get_destination(self, name: str) -> DestinationConfig | None:
73
+ return self.destinations.get(name)
74
+
75
+ def can_agent_tool_send(self, name: str) -> bool:
76
+ destination = self.get_destination(name)
77
+ return bool(destination and destination.allow_agent_tool)
78
+
79
+ def can_auto_send(self, name: str) -> bool:
80
+ destination = self.get_destination(name)
81
+ return bool(destination and destination.allow_auto_events)
82
+
83
+ def default_auto_destinations(self) -> list[str]:
84
+ if not self.enabled:
85
+ return []
86
+ return [
87
+ name
88
+ for name in self.destinations
89
+ if self.can_auto_send(name)
90
+ ]
91
+
92
+
93
+ class NotificationRequest(BaseModel):
94
+ destination: str
95
+ title: str | None = None
96
+ message: str
97
+ severity: Literal["info", "success", "warning", "error"] = "info"
98
+ metadata: dict[str, str] = Field(default_factory=dict)
99
+ event_type: str | None = None
100
+
101
+ @field_validator("destination", "message")
102
+ @classmethod
103
+ def _require_text(cls, value: str) -> str:
104
+ value = value.strip()
105
+ if not value:
106
+ raise ValueError("must not be empty")
107
+ return value
108
+
109
+ @field_validator("title")
110
+ @classmethod
111
+ def _normalize_title(cls, value: str | None) -> str | None:
112
+ if value is None:
113
+ return None
114
+ value = value.strip()
115
+ return value or None
116
+
117
+
118
+ class NotificationResult(BaseModel):
119
+ destination: str
120
+ ok: bool
121
+ provider: str
122
+ error: str | None = None
123
+ external_id: str | None = None
agent/messaging/slack.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+ import httpx
5
+
6
+ from agent.messaging.base import (
7
+ NotificationError,
8
+ NotificationProvider,
9
+ RetryableNotificationError,
10
+ )
11
+ from agent.messaging.models import (
12
+ NotificationRequest,
13
+ NotificationResult,
14
+ SlackDestinationConfig,
15
+ )
16
+
17
+ _SEVERITY_PREFIX = {
18
+ "info": "[INFO]",
19
+ "success": "[SUCCESS]",
20
+ "warning": "[WARNING]",
21
+ "error": "[ERROR]",
22
+ }
23
+
24
+
25
+ def _format_slack_mrkdwn(content: str) -> str:
26
+ """Convert common Markdown constructs to Slack's mrkdwn syntax."""
27
+ if not content:
28
+ return content
29
+
30
+ placeholders: dict[str, str] = {}
31
+ placeholder_index = 0
32
+
33
+ def placeholder(value: str) -> str:
34
+ nonlocal placeholder_index
35
+ key = f"\x00SLACK{placeholder_index}\x00"
36
+ placeholder_index += 1
37
+ placeholders[key] = value
38
+ return key
39
+
40
+ text = content
41
+
42
+ # Protect code before any formatting conversion. Slack's mrkdwn ignores
43
+ # formatting inside backticks, so these regions should stay byte-for-byte.
44
+ text = re.sub(
45
+ r"(```(?:[^\n]*\n)?[\s\S]*?```)",
46
+ lambda match: placeholder(match.group(0)),
47
+ text,
48
+ )
49
+ text = re.sub(r"(`[^`\n]+`)", lambda match: placeholder(match.group(0)), text)
50
+
51
+ def convert_markdown_link(match: re.Match[str]) -> str:
52
+ label = match.group(1)
53
+ url = match.group(2).strip()
54
+ if url.startswith("<") and url.endswith(">"):
55
+ url = url[1:-1].strip()
56
+ return placeholder(f"<{url}|{label}>")
57
+
58
+ text = re.sub(
59
+ r"\[([^\]]+)\]\(([^()]*(?:\([^()]*\)[^()]*)*)\)",
60
+ convert_markdown_link,
61
+ text,
62
+ )
63
+
64
+ # Preserve existing Slack entities and manual mrkdwn links before escaping.
65
+ text = re.sub(
66
+ r"(<(?:[@#!]|(?:https?|mailto|tel):)[^>\n]+>)",
67
+ lambda match: placeholder(match.group(1)),
68
+ text,
69
+ )
70
+ text = re.sub(
71
+ r"^(>+\s)",
72
+ lambda match: placeholder(match.group(0)),
73
+ text,
74
+ flags=re.MULTILINE,
75
+ )
76
+
77
+ text = text.replace("&amp;", "&").replace("&lt;", "<").replace("&gt;", ">")
78
+ text = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
79
+
80
+ def convert_header(match: re.Match[str]) -> str:
81
+ header = match.group(1).strip()
82
+ header = re.sub(r"\*\*(.+?)\*\*", r"\1", header)
83
+ return placeholder(f"*{header}*")
84
+
85
+ text = re.sub(r"^#{1,6}\s+(.+)$", convert_header, text, flags=re.MULTILINE)
86
+ text = re.sub(
87
+ r"\*\*\*(.+?)\*\*\*",
88
+ lambda match: placeholder(f"*_{match.group(1)}_*"),
89
+ text,
90
+ )
91
+ text = re.sub(
92
+ r"\*\*(.+?)\*\*",
93
+ lambda match: placeholder(f"*{match.group(1)}*"),
94
+ text,
95
+ )
96
+ text = re.sub(
97
+ r"(?<!\*)\*([^*\n]+)\*(?!\*)",
98
+ lambda match: placeholder(f"_{match.group(1)}_"),
99
+ text,
100
+ )
101
+ text = re.sub(
102
+ r"~~(.+?)~~",
103
+ lambda match: placeholder(f"~{match.group(1)}~"),
104
+ text,
105
+ )
106
+
107
+ for key in reversed(placeholders):
108
+ text = text.replace(key, placeholders[key])
109
+
110
+ return text
111
+
112
+
113
+ def _format_text(request: NotificationRequest) -> str:
114
+ lines: list[str] = []
115
+ prefix = _SEVERITY_PREFIX[request.severity]
116
+ if request.title:
117
+ lines.append(f"{prefix} {request.title}")
118
+ else:
119
+ lines.append(prefix)
120
+ lines.append(request.message)
121
+ for key, value in request.metadata.items():
122
+ lines.append(f"{key}: {value}")
123
+ return _format_slack_mrkdwn("\n".join(lines))
124
+
125
+
126
+ class SlackProvider(NotificationProvider):
127
+ provider_name = "slack"
128
+
129
+ async def send(
130
+ self,
131
+ client: httpx.AsyncClient,
132
+ destination_name: str,
133
+ destination: SlackDestinationConfig,
134
+ request: NotificationRequest,
135
+ ) -> NotificationResult:
136
+ payload = {
137
+ "channel": destination.channel,
138
+ "text": _format_text(request),
139
+ "mrkdwn": True,
140
+ "unfurl_links": False,
141
+ "unfurl_media": False,
142
+ }
143
+ if destination.username:
144
+ payload["username"] = destination.username
145
+ if destination.icon_emoji:
146
+ payload["icon_emoji"] = destination.icon_emoji
147
+
148
+ try:
149
+ response = await client.post(
150
+ "https://slack.com/api/chat.postMessage",
151
+ headers={
152
+ "Authorization": f"Bearer {destination.token}",
153
+ "Content-Type": "application/json; charset=utf-8",
154
+ },
155
+ content=json.dumps(payload),
156
+ )
157
+ except httpx.TimeoutException as exc:
158
+ raise RetryableNotificationError("Slack request timed out") from exc
159
+ except httpx.TransportError as exc:
160
+ raise RetryableNotificationError("Slack transport error") from exc
161
+
162
+ if response.status_code == 429 or response.status_code >= 500:
163
+ raise RetryableNotificationError(
164
+ f"Slack HTTP {response.status_code}"
165
+ )
166
+ if response.status_code >= 400:
167
+ raise NotificationError(f"Slack HTTP {response.status_code}")
168
+
169
+ try:
170
+ data = response.json()
171
+ except ValueError as exc:
172
+ raise RetryableNotificationError("Slack returned invalid JSON") from exc
173
+
174
+ if not data.get("ok"):
175
+ error = str(data.get("error") or "unknown_error")
176
+ if error == "ratelimited":
177
+ raise RetryableNotificationError(error)
178
+ raise NotificationError(error)
179
+
180
+ return NotificationResult(
181
+ destination=destination_name,
182
+ ok=True,
183
+ provider=self.provider_name,
184
+ external_id=str(data.get("ts") or ""),
185
+ error=None,
186
+ )
agent/prompts/system_prompt_v3.yaml CHANGED
@@ -157,6 +157,7 @@ system_prompt: |
157
  - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
158
  - For errors: state what went wrong, why, and what you're doing to fix it.
159
  - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
 
160
 
161
  # Tool usage
162
 
 
157
  - Always include direct Hub URLs when referencing models, datasets, Spaces, or jobs.
158
  - For errors: state what went wrong, why, and what you're doing to fix it.
159
  - Do not over-explain or present elaborate option menus for simple tasks. When the user's intent is clear, act on it. Present options only when there's genuine ambiguity.
160
+ - Use the `notify` tool only when the user explicitly asked for out-of-band notifications or when the task clearly requires reporting to a configured messaging destination. Do not use it for routine chat updates.
161
 
162
  # Tool usage
163
 
agent/tools/notify_tool.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from agent.messaging.models import NotificationRequest
4
+
5
+ NOTIFY_TOOL_SPEC = {
6
+ "name": "notify",
7
+ "description": (
8
+ "Send an out-of-band notification to configured messaging destinations. "
9
+ "Use this only when the user explicitly asked for proactive notifications "
10
+ "or when the task requires reporting progress outside the chat. "
11
+ "Destinations must be named server-side configs such as 'slack.ops'."
12
+ ),
13
+ "parameters": {
14
+ "type": "object",
15
+ "properties": {
16
+ "destinations": {
17
+ "type": "array",
18
+ "description": "Named messaging destinations to notify.",
19
+ "items": {"type": "string"},
20
+ "minItems": 1,
21
+ },
22
+ "message": {
23
+ "type": "string",
24
+ "description": "Main notification body.",
25
+ },
26
+ "title": {
27
+ "type": "string",
28
+ "description": "Optional short title line.",
29
+ },
30
+ "severity": {
31
+ "type": "string",
32
+ "enum": ["info", "success", "warning", "error"],
33
+ "description": "Notification severity label.",
34
+ },
35
+ },
36
+ "required": ["destinations", "message"],
37
+ },
38
+ }
39
+
40
+
41
+ async def notify_handler(
42
+ arguments: dict[str, Any], session=None, **_kwargs
43
+ ) -> tuple[str, bool]:
44
+ if session is None or session.notification_gateway is None:
45
+ return "Messaging is not configured for this session.", False
46
+
47
+ raw_destinations = arguments.get("destinations", [])
48
+ if not isinstance(raw_destinations, list) or not raw_destinations:
49
+ return "destinations must be a non-empty array of destination names.", False
50
+
51
+ destinations: list[str] = []
52
+ seen: set[str] = set()
53
+ for raw_name in raw_destinations:
54
+ if not isinstance(raw_name, str):
55
+ return "Each destination must be a string.", False
56
+ name = raw_name.strip()
57
+ if not name:
58
+ return "Destination names must not be empty.", False
59
+ if name not in seen:
60
+ destinations.append(name)
61
+ seen.add(name)
62
+
63
+ disallowed = [
64
+ name
65
+ for name in destinations
66
+ if not session.config.messaging.can_agent_tool_send(name)
67
+ ]
68
+ if disallowed:
69
+ return (
70
+ "These destinations are unavailable for the notify tool: "
71
+ + ", ".join(disallowed)
72
+ ), False
73
+
74
+ message = arguments.get("message", "")
75
+ if not isinstance(message, str) or not message.strip():
76
+ return "message must be a non-empty string.", False
77
+
78
+ title = arguments.get("title")
79
+ severity = arguments.get("severity", "info")
80
+ if title is not None and not isinstance(title, str):
81
+ return "title must be a string when provided.", False
82
+ if severity not in {"info", "success", "warning", "error"}:
83
+ return "severity must be one of: info, success, warning, error.", False
84
+
85
+ requests = [
86
+ NotificationRequest(
87
+ destination=name,
88
+ title=title,
89
+ message=message,
90
+ severity=severity,
91
+ metadata={
92
+ "session_id": session.session_id,
93
+ "model": session.config.model_name,
94
+ },
95
+ )
96
+ for name in destinations
97
+ ]
98
+ results = await session.notification_gateway.send_many(requests)
99
+
100
+ lines = []
101
+ all_ok = True
102
+ for result in results:
103
+ if result.ok:
104
+ lines.append(f"{result.destination}: sent")
105
+ else:
106
+ all_ok = False
107
+ lines.append(f"{result.destination}: failed ({result.error})")
108
+ return "\n".join(lines), all_ok
backend/main.py CHANGED
@@ -11,6 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
  from routes.agent import router as agent_router
13
  from routes.auth import router as auth_router
 
14
 
15
  # Load .env from project root (parent directory)
16
  load_dotenv(Path(__file__).parent.parent / ".env")
@@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
27
  async def lifespan(app: FastAPI):
28
  """Application lifespan handler."""
29
  logger.info("Starting HF Agent backend...")
 
30
  # Start in-process hourly KPI rollup. Replaces an external cron so the
31
  # rollup lives next to the data and reuses the Space's HF token.
32
  try:
@@ -34,7 +36,6 @@ async def lifespan(app: FastAPI):
34
  kpis_scheduler.start()
35
  except Exception as e:
36
  logger.warning("KPI scheduler failed to start: %s", e)
37
-
38
  yield
39
 
40
  logger.info("Shutting down HF Agent backend...")
@@ -47,7 +48,6 @@ async def lifespan(app: FastAPI):
47
  # Final-flush: save every still-active session so we don't lose traces on
48
  # server restart. Uploads are detached subprocesses — this is fast.
49
  try:
50
- from session_manager import session_manager
51
  for sid, agent_session in list(session_manager.sessions.items()):
52
  sess = agent_session.session
53
  if sess.config.save_sessions:
@@ -58,6 +58,7 @@ async def lifespan(app: FastAPI):
58
  logger.warning("Failed to flush session %s: %s", sid, e)
59
  except Exception as e:
60
  logger.warning("Lifespan final-flush skipped: %s", e)
 
61
 
62
 
63
  app = FastAPI(
 
11
  from fastapi.staticfiles import StaticFiles
12
  from routes.agent import router as agent_router
13
  from routes.auth import router as auth_router
14
+ from session_manager import session_manager
15
 
16
  # Load .env from project root (parent directory)
17
  load_dotenv(Path(__file__).parent.parent / ".env")
 
28
  async def lifespan(app: FastAPI):
29
  """Application lifespan handler."""
30
  logger.info("Starting HF Agent backend...")
31
+ await session_manager.start()
32
  # Start in-process hourly KPI rollup. Replaces an external cron so the
33
  # rollup lives next to the data and reuses the Space's HF token.
34
  try:
 
36
  kpis_scheduler.start()
37
  except Exception as e:
38
  logger.warning("KPI scheduler failed to start: %s", e)
 
39
  yield
40
 
41
  logger.info("Shutting down HF Agent backend...")
 
48
  # Final-flush: save every still-active session so we don't lose traces on
49
  # server restart. Uploads are detached subprocesses — this is fast.
50
  try:
 
51
  for sid, agent_session in list(session_manager.sessions.items()):
52
  sess = agent_session.session
53
  if sess.config.save_sessions:
 
58
  logger.warning("Failed to flush session %s: %s", sid, e)
59
  except Exception as e:
60
  logger.warning("Lifespan final-flush skipped: %s", e)
61
+ await session_manager.close()
62
 
63
 
64
  app = FastAPI(
backend/models.py CHANGED
@@ -3,7 +3,7 @@
3
  from enum import Enum
4
  from typing import Any
5
 
6
- from pydantic import BaseModel
7
 
8
 
9
  class OpType(str, Enum):
@@ -87,6 +87,13 @@ class SessionInfo(BaseModel):
87
  user_id: str = "dev"
88
  pending_approval: list[PendingApprovalTool] | None = None
89
  model: str | None = None
 
 
 
 
 
 
 
90
 
91
 
92
  class HealthResponse(BaseModel):
 
3
  from enum import Enum
4
  from typing import Any
5
 
6
+ from pydantic import BaseModel, Field
7
 
8
 
9
  class OpType(str, Enum):
 
87
  user_id: str = "dev"
88
  pending_approval: list[PendingApprovalTool] | None = None
89
  model: str | None = None
90
+ notification_destinations: list[str] = Field(default_factory=list)
91
+
92
+
93
+ class SessionNotificationsRequest(BaseModel):
94
+ """Replace the session's auto-notification destinations."""
95
+
96
+ destinations: list[str]
97
 
98
 
99
  class HealthResponse(BaseModel):
backend/routes/agent.py CHANGED
@@ -24,6 +24,7 @@ from models import (
24
  HealthResponse,
25
  LLMHealthResponse,
26
  SessionInfo,
 
27
  SessionResponse,
28
  SubmitRequest,
29
  TruncateRequest,
@@ -513,6 +514,26 @@ async def set_session_model(
513
  return {"session_id": session_id, "model": model_id}
514
 
515
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  @router.get("/user/quota")
517
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
518
  """Return the user's plan tier and today's Claude-session quota state."""
@@ -824,7 +845,6 @@ async def shutdown_session(
824
  raise HTTPException(status_code=404, detail="Session not found or inactive")
825
  return {"status": "shutdown_requested", "session_id": session_id}
826
 
827
-
828
  @router.post("/feedback/{session_id}")
829
  async def submit_feedback(
830
  session_id: str,
 
24
  HealthResponse,
25
  LLMHealthResponse,
26
  SessionInfo,
27
+ SessionNotificationsRequest,
28
  SessionResponse,
29
  SubmitRequest,
30
  TruncateRequest,
 
514
  return {"session_id": session_id, "model": model_id}
515
 
516
 
517
+ @router.post("/session/{session_id}/notifications")
518
+ async def set_session_notifications(
519
+ session_id: str,
520
+ body: SessionNotificationsRequest,
521
+ user: dict = Depends(get_current_user),
522
+ ) -> dict:
523
+ """Replace the session's auto-notification destinations."""
524
+ _check_session_access(session_id, user)
525
+ try:
526
+ destinations = session_manager.set_notification_destinations(
527
+ session_id, body.destinations
528
+ )
529
+ except ValueError as e:
530
+ raise HTTPException(status_code=400, detail=str(e))
531
+ return {
532
+ "session_id": session_id,
533
+ "notification_destinations": destinations,
534
+ }
535
+
536
+
537
  @router.get("/user/quota")
538
  async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
539
  """Return the user's plan tier and today's Claude-session quota state."""
 
845
  raise HTTPException(status_code=404, detail="Session not found or inactive")
846
  return {"status": "shutdown_requested", "session_id": session_id}
847
 
 
848
  @router.post("/feedback/{session_id}")
849
  async def submit_feedback(
850
  session_id: str,
backend/session_manager.py CHANGED
@@ -10,6 +10,7 @@ from typing import Any, Optional
10
 
11
  from agent.config import load_config
12
  from agent.core.agent_loop import process_submission
 
13
  from agent.core.session import Event, OpType, Session
14
  from agent.core.tools import ToolRouter
15
 
@@ -119,9 +120,18 @@ class SessionManager:
119
 
120
  def __init__(self, config_path: str | None = None) -> None:
121
  self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
 
122
  self.sessions: dict[str, AgentSession] = {}
123
  self._lock = asyncio.Lock()
124
 
 
 
 
 
 
 
 
 
125
  def _count_user_sessions(self, user_id: str) -> int:
126
  """Count active sessions owned by a specific user."""
127
  return sum(
@@ -192,7 +202,11 @@ class SessionManager:
192
  session_config.model_name = model
193
  session = Session(
194
  event_queue, config=session_config, tool_router=tool_router,
195
- hf_token=hf_token, user_id=user_id,
 
 
 
 
196
  )
197
  t1 = _time.monotonic()
198
  logger.info(f"Session initialized in {t1 - t0:.2f}s")
@@ -518,8 +532,39 @@ class SessionManager:
518
  "user_id": agent_session.user_id,
519
  "pending_approval": pending_approval,
520
  "model": agent_session.session.config.model_name,
 
 
 
521
  }
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
524
  """List sessions, optionally filtered by user.
525
 
 
10
 
11
  from agent.config import load_config
12
  from agent.core.agent_loop import process_submission
13
+ from agent.messaging.gateway import NotificationGateway
14
  from agent.core.session import Event, OpType, Session
15
  from agent.core.tools import ToolRouter
16
 
 
120
 
121
  def __init__(self, config_path: str | None = None) -> None:
122
  self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
123
+ self.messaging_gateway = NotificationGateway(self.config.messaging)
124
  self.sessions: dict[str, AgentSession] = {}
125
  self._lock = asyncio.Lock()
126
 
127
+ async def start(self) -> None:
128
+ """Start shared background resources."""
129
+ await self.messaging_gateway.start()
130
+
131
+ async def close(self) -> None:
132
+ """Flush and close shared background resources."""
133
+ await self.messaging_gateway.close()
134
+
135
  def _count_user_sessions(self, user_id: str) -> int:
136
  """Count active sessions owned by a specific user."""
137
  return sum(
 
202
  session_config.model_name = model
203
  session = Session(
204
  event_queue, config=session_config, tool_router=tool_router,
205
+ hf_token=hf_token,
206
+ user_id=user_id,
207
+ notification_gateway=self.messaging_gateway,
208
+ notification_destinations=[],
209
+ session_id=session_id,
210
  )
211
  t1 = _time.monotonic()
212
  logger.info(f"Session initialized in {t1 - t0:.2f}s")
 
532
  "user_id": agent_session.user_id,
533
  "pending_approval": pending_approval,
534
  "model": agent_session.session.config.model_name,
535
+ "notification_destinations": list(
536
+ agent_session.session.notification_destinations
537
+ ),
538
  }
539
 
540
+ def set_notification_destinations(
541
+ self, session_id: str, destinations: list[str]
542
+ ) -> list[str]:
543
+ """Replace the session's opted-in auto-notification destinations."""
544
+ agent_session = self.sessions.get(session_id)
545
+ if not agent_session or not agent_session.is_active:
546
+ raise ValueError("Session not found or inactive")
547
+
548
+ normalized: list[str] = []
549
+ seen: set[str] = set()
550
+ for raw_name in destinations:
551
+ name = raw_name.strip()
552
+ if not name:
553
+ raise ValueError("Destination names must not be empty")
554
+ destination = self.config.messaging.get_destination(name)
555
+ if destination is None:
556
+ raise ValueError(f"Unknown destination '{name}'")
557
+ if not destination.allow_auto_events:
558
+ raise ValueError(
559
+ f"Destination '{name}' is not enabled for auto events"
560
+ )
561
+ if name not in seen:
562
+ normalized.append(name)
563
+ seen.add(name)
564
+
565
+ agent_session.session.set_notification_destinations(normalized)
566
+ return normalized
567
+
568
  def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
569
  """List sessions, optionally filtered by user.
570
 
configs/cli_agent_config.json CHANGED
@@ -5,6 +5,11 @@
5
  "yolo_mode": false,
6
  "confirm_cpu_jobs": true,
7
  "auto_file_upload": true,
 
 
 
 
 
8
  "mcpServers": {
9
  "hf-mcp-server": {
10
  "transport": "http",
 
5
  "yolo_mode": false,
6
  "confirm_cpu_jobs": true,
7
  "auto_file_upload": true,
8
+ "messaging": {
9
+ "enabled": false,
10
+ "auto_event_types": ["approval_required", "error", "turn_complete"],
11
+ "destinations": {}
12
+ },
13
  "mcpServers": {
14
  "hf-mcp-server": {
15
  "transport": "http",
pyproject.toml CHANGED
@@ -42,7 +42,7 @@ eval = [
42
  # Development and testing dependencies
43
  dev = [
44
  "pytest>=9.0.2",
45
- "pytest-asyncio>=0.26.0",
46
  ]
47
 
48
  # All dependencies (eval + dev)
 
42
  # Development and testing dependencies
43
  dev = [
44
  "pytest>=9.0.2",
45
+ "pytest-asyncio>=1.2.0",
46
  ]
47
 
48
  # All dependencies (eval + dev)
tests/unit/test_cli_rendering.py CHANGED
@@ -79,7 +79,7 @@ async def test_interactive_main_applies_model_override_before_banner(monkeypatch
79
  monkeypatch.setattr(
80
  main_mod,
81
  "load_config",
82
- lambda _path: SimpleNamespace(
83
  model_name="moonshotai/Kimi-K2.6",
84
  mcpServers={},
85
  ),
 
79
  monkeypatch.setattr(
80
  main_mod,
81
  "load_config",
82
+ lambda _path, **_kwargs: SimpleNamespace(
83
  model_name="moonshotai/Kimi-K2.6",
84
  mcpServers={},
85
  ),
tests/unit/test_config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from agent import config as config_module
4
+
5
+
6
+ def _write_json(path, data):
7
+ path.write_text(json.dumps(data), encoding="utf-8")
8
+
9
+
10
+ def test_load_config_does_not_apply_slack_user_defaults_by_default(tmp_path, monkeypatch):
11
+ config_path = tmp_path / "config.json"
12
+ _write_json(
13
+ config_path,
14
+ {
15
+ "model_name": "moonshotai/Kimi-K2.6",
16
+ "messaging": {
17
+ "enabled": False,
18
+ "destinations": {},
19
+ },
20
+ },
21
+ )
22
+ monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
23
+ monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
24
+
25
+ config = config_module.load_config(str(config_path))
26
+
27
+ assert not config.messaging.enabled
28
+ assert config.messaging.destinations == {}
29
+
30
+
31
+ def test_load_config_applies_slack_user_defaults_from_env(tmp_path, monkeypatch):
32
+ config_path = tmp_path / "config.json"
33
+ _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
34
+ monkeypatch.delenv("ML_INTERN_CLI_CONFIG", raising=False)
35
+ monkeypatch.setattr(
36
+ config_module,
37
+ "DEFAULT_USER_CONFIG_PATH",
38
+ tmp_path / "missing-user-config.json",
39
+ )
40
+ monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
41
+ monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
42
+
43
+ config = config_module.load_config(str(config_path), include_user_defaults=True)
44
+
45
+ assert config.messaging.enabled
46
+ assert config.messaging.auto_event_types == [
47
+ "approval_required",
48
+ "error",
49
+ "turn_complete",
50
+ ]
51
+ destination = config.messaging.destinations["slack.default"]
52
+ assert destination.token == "xoxb-test"
53
+ assert destination.channel == "C123"
54
+ assert destination.allow_agent_tool
55
+ assert destination.allow_auto_events
56
+
57
+
58
+ def test_load_config_merges_user_config_before_env_substitution(tmp_path, monkeypatch):
59
+ config_path = tmp_path / "config.json"
60
+ user_config_path = tmp_path / "user-config.json"
61
+ _write_json(config_path, {"model_name": "moonshotai/Kimi-K2.6"})
62
+ _write_json(
63
+ user_config_path,
64
+ {
65
+ "messaging": {
66
+ "enabled": True,
67
+ "auto_event_types": ["approval_required"],
68
+ "destinations": {
69
+ "slack.team": {
70
+ "provider": "slack",
71
+ "token": "${USER_SLACK_TOKEN}",
72
+ "channel": "C999",
73
+ "allow_agent_tool": False,
74
+ "allow_auto_events": True,
75
+ },
76
+ },
77
+ },
78
+ },
79
+ )
80
+ monkeypatch.setenv("ML_INTERN_CLI_CONFIG", str(user_config_path))
81
+ monkeypatch.setenv("ML_INTERN_SLACK_NOTIFICATIONS", "0")
82
+ monkeypatch.setenv("USER_SLACK_TOKEN", "xoxb-user")
83
+
84
+ config = config_module.load_config(str(config_path), include_user_defaults=True)
85
+
86
+ assert config.messaging.enabled
87
+ assert config.messaging.auto_event_types == ["approval_required"]
88
+ assert set(config.messaging.destinations) == {"slack.team"}
89
+ destination = config.messaging.destinations["slack.team"]
90
+ assert destination.token == "xoxb-user"
91
+ assert destination.channel == "C999"
92
+ assert not destination.allow_agent_tool
93
+ assert destination.allow_auto_events
94
+
95
+
96
+ def test_slack_user_defaults_can_be_disabled(tmp_path, monkeypatch):
97
+ config_path = tmp_path / "config.json"
98
+ _write_json(
99
+ config_path,
100
+ {
101
+ "model_name": "moonshotai/Kimi-K2.6",
102
+ "messaging": {
103
+ "enabled": False,
104
+ "destinations": {},
105
+ },
106
+ },
107
+ )
108
+ monkeypatch.delenv("ML_INTERN_CLI_CONFIG", raising=False)
109
+ monkeypatch.setattr(
110
+ config_module,
111
+ "DEFAULT_USER_CONFIG_PATH",
112
+ tmp_path / "missing-user-config.json",
113
+ )
114
+ monkeypatch.setenv("ML_INTERN_SLACK_NOTIFICATIONS", "false")
115
+ monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")
116
+ monkeypatch.setenv("SLACK_CHANNEL_ID", "C123")
117
+
118
+ config = config_module.load_config(str(config_path), include_user_defaults=True)
119
+
120
+ assert not config.messaging.enabled
121
+ assert config.messaging.destinations == {}
tests/unit/test_messaging.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ from pathlib import Path
4
+ from types import SimpleNamespace
5
+
6
+ import httpx
7
+ import pytest
8
+ from pydantic import ValidationError
9
+
10
+ from agent.config import Config
11
+ from agent.core.session import Event, Session
12
+ from agent.messaging.gateway import NotificationGateway
13
+ from agent.messaging.models import NotificationRequest, NotificationResult
14
+ from agent.messaging.slack import SlackProvider, _format_slack_mrkdwn
15
+ from agent.tools.notify_tool import notify_handler
16
+ from backend.session_manager import AgentSession, SessionManager
17
+
18
+
19
+ class DummyToolRouter:
20
+ def get_tool_specs_for_llm(self) -> list[dict]:
21
+ return []
22
+
23
+
24
+ class RecordingGateway:
25
+ def __init__(self):
26
+ self.enqueued: list[NotificationRequest] = []
27
+ self.sent: list[NotificationRequest] = []
28
+
29
+ async def enqueue(self, request: NotificationRequest) -> bool:
30
+ self.enqueued.append(request)
31
+ return True
32
+
33
+ async def send_many(
34
+ self, requests: list[NotificationRequest]
35
+ ) -> list[NotificationResult]:
36
+ self.sent.extend(requests)
37
+ return [
38
+ NotificationResult(
39
+ destination=request.destination,
40
+ ok=True,
41
+ provider="test",
42
+ )
43
+ for request in requests
44
+ ]
45
+
46
+
47
+ def _config_with_messaging(**destination_overrides) -> Config:
48
+ destination = {
49
+ "provider": "slack",
50
+ "token": "xoxb-test",
51
+ "channel": "C123",
52
+ **destination_overrides,
53
+ }
54
+ return Config.model_validate(
55
+ {
56
+ "model_name": "moonshotai/Kimi-K2.6",
57
+ "messaging": {
58
+ "enabled": True,
59
+ "destinations": {
60
+ "slack.ops": destination,
61
+ },
62
+ },
63
+ }
64
+ )
65
+
66
+
67
+ def _test_session(
68
+ config: Config, gateway, session_id: str = "session-test"
69
+ ) -> Session:
70
+ return Session(
71
+ asyncio.Queue(),
72
+ config=config,
73
+ tool_router=DummyToolRouter(),
74
+ context_manager=SimpleNamespace(items=[]),
75
+ notification_gateway=gateway,
76
+ session_id=session_id,
77
+ )
78
+
79
+
80
+ def test_messaging_config_validates_destination_names():
81
+ with pytest.raises(ValidationError):
82
+ Config.model_validate(
83
+ {
84
+ "model_name": "moonshotai/Kimi-K2.6",
85
+ "messaging": {
86
+ "enabled": True,
87
+ "destinations": {
88
+ "Slack Ops": {
89
+ "provider": "slack",
90
+ "token": "x",
91
+ "channel": "C123",
92
+ }
93
+ },
94
+ },
95
+ }
96
+ )
97
+
98
+ config = _config_with_messaging(allow_agent_tool=True, allow_auto_events=True)
99
+ assert config.messaging.can_agent_tool_send("slack.ops")
100
+ assert config.messaging.can_auto_send("slack.ops")
101
+
102
+
103
+ def test_messaging_config_default_auto_destinations_only_returns_auto_enabled():
104
+ config = Config.model_validate(
105
+ {
106
+ "model_name": "moonshotai/Kimi-K2.6",
107
+ "messaging": {
108
+ "enabled": True,
109
+ "destinations": {
110
+ "slack.ops": {
111
+ "provider": "slack",
112
+ "token": "xoxb-test",
113
+ "channel": "C123",
114
+ "allow_auto_events": True,
115
+ },
116
+ "slack.tool": {
117
+ "provider": "slack",
118
+ "token": "xoxb-test",
119
+ "channel": "C999",
120
+ "allow_agent_tool": True,
121
+ },
122
+ },
123
+ },
124
+ }
125
+ )
126
+
127
+ assert config.messaging.default_auto_destinations() == ["slack.ops"]
128
+
129
+
130
+ def test_messaging_config_default_auto_destinations_empty_when_disabled():
131
+ config = Config.model_validate(
132
+ {
133
+ "model_name": "moonshotai/Kimi-K2.6",
134
+ "messaging": {
135
+ "enabled": False,
136
+ "destinations": {
137
+ "slack.ops": {
138
+ "provider": "slack",
139
+ "token": "xoxb-test",
140
+ "channel": "C123",
141
+ "allow_auto_events": True,
142
+ },
143
+ },
144
+ },
145
+ }
146
+ )
147
+
148
+ assert config.messaging.default_auto_destinations() == []
149
+
150
+
151
+ def test_slack_mrkdwn_formatter_converts_common_markdown():
152
+ formatted = _format_slack_mrkdwn(
153
+ "# Result\n"
154
+ "**Done** with *details* and ~~old text~~.\n"
155
+ "See [PR](https://github.com/huggingface/ml-intern/pull/116).\n"
156
+ "Keep `**literal**` and ```python\nx < 3\n``` untouched.\n"
157
+ "Escape <raw> & text."
158
+ )
159
+
160
+ assert "*Result*" in formatted
161
+ assert "*Done*" in formatted
162
+ assert "_details_" in formatted
163
+ assert "~old text~" in formatted
164
+ assert "<https://github.com/huggingface/ml-intern/pull/116|PR>" in formatted
165
+ assert "`**literal**`" in formatted
166
+ assert "```python\nx < 3\n```" in formatted
167
+ assert "Escape &lt;raw&gt; &amp; text." in formatted
168
+
169
+
170
+ @pytest.mark.asyncio
171
+ async def test_slack_provider_formats_and_sends_payload():
172
+ seen: dict[str, object] = {}
173
+
174
+ def handler(request: httpx.Request) -> httpx.Response:
175
+ seen["auth"] = request.headers["Authorization"]
176
+ seen["content_type"] = request.headers["Content-Type"]
177
+ seen["json"] = request.read().decode("utf-8")
178
+ return httpx.Response(200, json={"ok": True, "ts": "123.456"})
179
+
180
+ async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
181
+ provider = SlackProvider()
182
+ result = await provider.send(
183
+ client,
184
+ "slack.ops",
185
+ _config_with_messaging().messaging.destinations["slack.ops"],
186
+ NotificationRequest(
187
+ destination="slack.ops",
188
+ title="Approval required",
189
+ message="A **run** is waiting. See [details](https://example.com).",
190
+ severity="warning",
191
+ metadata={"session_id": "sess-1"},
192
+ ),
193
+ )
194
+
195
+ assert result.ok
196
+ assert result.external_id == "123.456"
197
+ assert seen["auth"] == "Bearer xoxb-test"
198
+ assert seen["content_type"].startswith("application/json")
199
+ payload = json.loads(str(seen["json"]))
200
+ assert payload["channel"] == "C123"
201
+ assert payload["mrkdwn"] is True
202
+ assert payload["text"] == (
203
+ "[WARNING] Approval required\n"
204
+ "A *run* is waiting. See <https://example.com|details>.\n"
205
+ "session_id: sess-1"
206
+ )
207
+
208
+
209
+ @pytest.mark.asyncio
210
+ async def test_notification_gateway_retries_transient_failures(monkeypatch):
211
+ attempts = {"count": 0}
212
+
213
+ def handler(_request: httpx.Request) -> httpx.Response:
214
+ attempts["count"] += 1
215
+ if attempts["count"] == 1:
216
+ return httpx.Response(503, json={"ok": False})
217
+ return httpx.Response(200, json={"ok": True, "ts": "999.1"})
218
+
219
+ async def fake_sleep(_delay: float) -> None:
220
+ return None
221
+
222
+ monkeypatch.setattr("agent.messaging.gateway.asyncio.sleep", fake_sleep)
223
+
224
+ config = _config_with_messaging(allow_agent_tool=True)
225
+ gateway = NotificationGateway(config.messaging)
226
+ async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
227
+ gateway._client = client
228
+ result = await gateway.send(
229
+ NotificationRequest(
230
+ destination="slack.ops",
231
+ message="hello",
232
+ )
233
+ )
234
+ gateway._client = None
235
+
236
+ assert attempts["count"] == 2
237
+ assert result.ok
238
+
239
+
240
+ @pytest.mark.asyncio
241
+ async def test_notify_tool_rejects_non_allowlisted_destinations():
242
+ config = _config_with_messaging(allow_agent_tool=False)
243
+ gateway = RecordingGateway()
244
+ session = _test_session(config, gateway)
245
+
246
+ output, ok = await notify_handler(
247
+ {"destinations": ["slack.ops"], "message": "done"},
248
+ session=session,
249
+ )
250
+
251
+ assert not ok
252
+ assert "unavailable for the notify tool" in output
253
+ assert gateway.sent == []
254
+
255
+
256
+ @pytest.mark.asyncio
257
+ async def test_notify_tool_sends_to_allowlisted_destinations():
258
+ config = _config_with_messaging(allow_agent_tool=True)
259
+ gateway = RecordingGateway()
260
+ session = _test_session(config, gateway, session_id="sess-42")
261
+
262
+ output, ok = await notify_handler(
263
+ {
264
+ "destinations": ["slack.ops"],
265
+ "title": "Training complete",
266
+ "message": "The run finished successfully.",
267
+ "severity": "success",
268
+ },
269
+ session=session,
270
+ )
271
+
272
+ assert ok
273
+ assert output == "slack.ops: sent"
274
+ assert len(gateway.sent) == 1
275
+ sent = gateway.sent[0]
276
+ assert sent.metadata["session_id"] == "sess-42"
277
+ assert sent.metadata["model"] == "moonshotai/Kimi-K2.6"
278
+
279
+
280
+ @pytest.mark.asyncio
281
+ async def test_session_auto_notifications_only_send_opted_in_auto_destinations():
282
+ config = Config.model_validate(
283
+ {
284
+ "model_name": "moonshotai/Kimi-K2.6",
285
+ "messaging": {
286
+ "enabled": True,
287
+ "destinations": {
288
+ "slack.ops": {
289
+ "provider": "slack",
290
+ "token": "xoxb-test",
291
+ "channel": "C123",
292
+ "allow_auto_events": True,
293
+ },
294
+ "slack.tool": {
295
+ "provider": "slack",
296
+ "token": "xoxb-test",
297
+ "channel": "C999",
298
+ "allow_agent_tool": True,
299
+ },
300
+ },
301
+ },
302
+ }
303
+ )
304
+ gateway = RecordingGateway()
305
+ session = _test_session(config, gateway, session_id="sess-auto")
306
+ session.set_notification_destinations(["slack.ops", "slack.tool"])
307
+
308
+ await session.send_event(
309
+ Event(
310
+ event_type="approval_required",
311
+ data={"tools": [{"tool": "hf_jobs", "tool_call_id": "tc-1"}]},
312
+ )
313
+ )
314
+ await session.send_event(
315
+ Event(event_type="assistant_message", data={"content": "normal message"})
316
+ )
317
+
318
+ assert len(gateway.enqueued) == 1
319
+ request = gateway.enqueued[0]
320
+ assert request.destination == "slack.ops"
321
+ assert request.severity == "warning"
322
+ assert request.event_type == "approval_required"
323
+ assert "hf_jobs" in request.message
324
+
325
+
326
+ @pytest.mark.asyncio
327
+ async def test_turn_complete_auto_notification_includes_final_response_summary():
328
+ config = Config.model_validate(
329
+ {
330
+ "model_name": "moonshotai/Kimi-K2.6",
331
+ "messaging": {
332
+ "enabled": True,
333
+ "destinations": {
334
+ "slack.ops": {
335
+ "provider": "slack",
336
+ "token": "xoxb-test",
337
+ "channel": "C123",
338
+ "allow_auto_events": True,
339
+ }
340
+ },
341
+ },
342
+ }
343
+ )
344
+ gateway = RecordingGateway()
345
+ session = _test_session(config, gateway, session_id="sess-done")
346
+ session.set_notification_destinations(["slack.ops"])
347
+
348
+ await session.send_event(
349
+ Event(
350
+ event_type="turn_complete",
351
+ data={
352
+ "history_size": 12,
353
+ "final_response": "Evaluation finished. Accuracy: 84.2% on the validation split.",
354
+ },
355
+ )
356
+ )
357
+
358
+ assert len(gateway.enqueued) == 1
359
+ request = gateway.enqueued[0]
360
+ assert request.destination == "slack.ops"
361
+ assert request.severity == "success"
362
+ assert request.event_type == "turn_complete"
363
+ assert "completed successfully" in request.message
364
+ assert "Accuracy: 84.2%" in request.message
365
+
366
+
367
+ @pytest.mark.asyncio
368
+ async def test_turn_complete_auto_notification_supports_longer_summary():
369
+ config = Config.model_validate(
370
+ {
371
+ "model_name": "moonshotai/Kimi-K2.6",
372
+ "messaging": {
373
+ "enabled": True,
374
+ "destinations": {
375
+ "slack.ops": {
376
+ "provider": "slack",
377
+ "token": "xoxb-test",
378
+ "channel": "C123",
379
+ "allow_auto_events": True,
380
+ }
381
+ },
382
+ },
383
+ }
384
+ )
385
+ gateway = RecordingGateway()
386
+ session = _test_session(config, gateway, session_id="sess-long")
387
+ session.set_notification_destinations(["slack.ops"])
388
+
389
+ long_summary = "A" * 1200 + " END"
390
+ await session.send_event(
391
+ Event(
392
+ event_type="turn_complete",
393
+ data={
394
+ "history_size": 12,
395
+ "final_response": long_summary,
396
+ },
397
+ )
398
+ )
399
+
400
+ assert len(gateway.enqueued) == 1
401
+ request = gateway.enqueued[0]
402
+ assert request.event_type == "turn_complete"
403
+ assert "A" * 1200 in request.message
404
+ assert request.message.endswith("END")
405
+
406
+
407
+ @pytest.mark.asyncio
408
+ async def test_turn_complete_auto_notification_can_be_deferred():
409
+ config = Config.model_validate(
410
+ {
411
+ "model_name": "moonshotai/Kimi-K2.6",
412
+ "messaging": {
413
+ "enabled": True,
414
+ "destinations": {
415
+ "slack.ops": {
416
+ "provider": "slack",
417
+ "token": "xoxb-test",
418
+ "channel": "C123",
419
+ "allow_auto_events": True,
420
+ }
421
+ },
422
+ },
423
+ }
424
+ )
425
+ gateway = RecordingGateway()
426
+ session = Session(
427
+ asyncio.Queue(),
428
+ config=config,
429
+ tool_router=DummyToolRouter(),
430
+ context_manager=SimpleNamespace(items=[]),
431
+ notification_gateway=gateway,
432
+ notification_destinations=["slack.ops"],
433
+ defer_turn_complete_notification=True,
434
+ session_id="sess-deferred",
435
+ )
436
+ event = Event(
437
+ event_type="turn_complete",
438
+ data={"final_response": "Finished after the CLI drained the stream."},
439
+ )
440
+
441
+ await session.send_event(event)
442
+ assert gateway.enqueued == []
443
+
444
+ await session.send_deferred_turn_complete_notification(event)
445
+
446
+ assert len(gateway.enqueued) == 1
447
+ request = gateway.enqueued[0]
448
+ assert request.destination == "slack.ops"
449
+ assert request.event_type == "turn_complete"
450
+ assert "Finished after the CLI drained the stream." in request.message
451
+
452
+
453
+ @pytest.mark.asyncio
454
+ async def test_turn_complete_can_be_disabled_by_custom_auto_event_config():
455
+ config = Config.model_validate(
456
+ {
457
+ "model_name": "moonshotai/Kimi-K2.6",
458
+ "messaging": {
459
+ "enabled": True,
460
+ "auto_event_types": ["error"],
461
+ "destinations": {
462
+ "slack.ops": {
463
+ "provider": "slack",
464
+ "token": "xoxb-test",
465
+ "channel": "C123",
466
+ "allow_auto_events": True,
467
+ }
468
+ },
469
+ },
470
+ }
471
+ )
472
+ gateway = RecordingGateway()
473
+ session = _test_session(config, gateway, session_id="sess-optout")
474
+ session.set_notification_destinations(["slack.ops"])
475
+
476
+ await session.send_event(
477
+ Event(
478
+ event_type="turn_complete",
479
+ data={"final_response": "This should not notify."},
480
+ )
481
+ )
482
+
483
+ assert gateway.enqueued == []
484
+
485
+
486
+ def test_session_manager_updates_notification_destinations_in_session_info():
487
+ config = _config_with_messaging(allow_auto_events=True)
488
+ manager = SessionManager(str(Path(__file__).resolve().parents[2] / "configs" / "cli_agent_config.json"))
489
+ manager.config = config
490
+ manager.sessions = {}
491
+
492
+ session = _test_session(config, RecordingGateway(), session_id="sess-manager")
493
+ manager.sessions["sess-manager"] = AgentSession(
494
+ session_id="sess-manager",
495
+ session=session,
496
+ tool_router=DummyToolRouter(),
497
+ submission_queue=asyncio.Queue(),
498
+ )
499
+
500
+ updated = manager.set_notification_destinations(
501
+ "sess-manager",
502
+ ["slack.ops", "slack.ops"],
503
+ )
504
+
505
+ assert updated == ["slack.ops"]
506
+ info = manager.get_session_info("sess-manager")
507
+ assert info is not None
508
+ assert info["notification_destinations"] == ["slack.ops"]
509
+
510
+ with pytest.raises(ValueError):
511
+ manager.set_notification_destinations("sess-manager", ["slack.unknown"])
tests/unit/test_thinking_history.py CHANGED
@@ -159,7 +159,7 @@ async def test_streaming_call_rebuilds_anthropic_thinking_state(monkeypatch):
159
 
160
 
161
  @pytest.mark.asyncio
162
- async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatch):
163
  async def fake_stream():
164
  yield SimpleNamespace(
165
  choices=[
@@ -167,7 +167,31 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
167
  delta=SimpleNamespace(
168
  content=None,
169
  tool_calls=None,
170
- thinking_blocks=[{"type": "thinking", "thinking": "reasoned"}],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  ),
172
  finish_reason=None,
173
  )
@@ -186,8 +210,26 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
186
  async def fake_acompletion(**_kwargs):
187
  return fake_stream()
188
 
189
- def fail_chunk_builder(*_args, **_kwargs):
190
- raise AssertionError("stream_chunk_builder should not run when deltas include thinking")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  events = []
193
  async def send_event(event):
@@ -199,7 +241,7 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
199
  send_event=send_event,
200
  )
201
  monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
202
- monkeypatch.setattr(agent_loop, "stream_chunk_builder", fail_chunk_builder)
203
 
204
  result = await _call_llm_streaming(
205
  session,
@@ -209,7 +251,10 @@ async def test_streaming_call_collects_anthropic_delta_thinking_state(monkeypatc
209
  )
210
 
211
  assert result.content == "done"
212
- assert result.thinking_blocks == [{"type": "thinking", "thinking": "reasoned"}]
 
 
 
213
 
214
 
215
  @pytest.mark.asyncio
 
159
 
160
 
161
  @pytest.mark.asyncio
162
+ async def test_streaming_call_rebuilds_anthropic_delta_thinking_state(monkeypatch):
163
  async def fake_stream():
164
  yield SimpleNamespace(
165
  choices=[
 
167
  delta=SimpleNamespace(
168
  content=None,
169
  tool_calls=None,
170
+ thinking_blocks=[
171
+ {
172
+ "type": "thinking",
173
+ "thinking": "reasoned",
174
+ "signature": "",
175
+ }
176
+ ],
177
+ ),
178
+ finish_reason=None,
179
+ )
180
+ ],
181
+ )
182
+ yield SimpleNamespace(
183
+ choices=[
184
+ SimpleNamespace(
185
+ delta=SimpleNamespace(
186
+ content=None,
187
+ tool_calls=None,
188
+ thinking_blocks=[
189
+ {
190
+ "type": "thinking",
191
+ "thinking": "",
192
+ "signature": "signed",
193
+ }
194
+ ],
195
  ),
196
  finish_reason=None,
197
  )
 
210
  async def fake_acompletion(**_kwargs):
211
  return fake_stream()
212
 
213
+ def fake_chunk_builder(chunks, **_kwargs):
214
+ assert len(chunks) == 4
215
+ return SimpleNamespace(
216
+ choices=[
217
+ SimpleNamespace(
218
+ message=Message(
219
+ role="assistant",
220
+ content="done",
221
+ thinking_blocks=[
222
+ {
223
+ "type": "thinking",
224
+ "thinking": "reasoned",
225
+ "signature": "signed",
226
+ }
227
+ ],
228
+ reasoning_content="reasoned",
229
+ )
230
+ )
231
+ ]
232
+ )
233
 
234
  events = []
235
  async def send_event(event):
 
241
  send_event=send_event,
242
  )
243
  monkeypatch.setattr(agent_loop, "acompletion", fake_acompletion)
244
+ monkeypatch.setattr(agent_loop, "stream_chunk_builder", fake_chunk_builder)
245
 
246
  result = await _call_llm_streaming(
247
  session,
 
251
  )
252
 
253
  assert result.content == "done"
254
+ assert result.thinking_blocks == [
255
+ {"type": "thinking", "thinking": "reasoned", "signature": "signed"}
256
+ ]
257
+ assert result.reasoning_content == "reasoned"
258
 
259
 
260
  @pytest.mark.asyncio
uv.lock CHANGED
@@ -1832,7 +1832,7 @@ requires-dist = [
1832
  { name = "prompt-toolkit", specifier = ">=3.0.0" },
1833
  { name = "pydantic", specifier = ">=2.12.3" },
1834
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1835
- { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.26.0" },
1836
  { name = "python-dotenv", specifier = ">=1.2.1" },
1837
  { name = "requests", specifier = ">=2.33.0" },
1838
  { name = "rich", specifier = ">=13.0.0" },
 
1832
  { name = "prompt-toolkit", specifier = ">=3.0.0" },
1833
  { name = "pydantic", specifier = ">=2.12.3" },
1834
  { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.2" },
1835
+ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.2.0" },
1836
  { name = "python-dotenv", specifier = ">=1.2.1" },
1837
  { name = "requests", specifier = ">=2.33.0" },
1838
  { name = "rich", specifier = ">=13.0.0" },