File size: 7,988 Bytes
de136d0
 
 
275319a
b05b6f5
de136d0
 
 
 
 
 
5fe810b
 
754345f
 
de136d0
 
8633be8
754345f
 
 
8633be8
5fe810b
 
 
 
de136d0
fa4ba99
2a2e170
0bd7547
 
 
 
71e1892
0bd7547
 
2a2e170
 
 
 
 
 
f4655f7
73882d9
de136d0
3fa386c
 
 
b05b6f5
3fa386c
e2552e8
 
 
 
 
 
 
 
 
6155b26
 
 
 
754345f
 
 
6155b26
 
 
 
754345f
 
 
6155b26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a9e96d
de136d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fe810b
 
6155b26
 
 
 
de136d0
 
 
 
 
 
275319a
 
 
 
de136d0
6155b26
 
 
 
de136d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import json
import os
import re
from pathlib import Path
from typing import Any, Literal, Union

from dotenv import load_dotenv
from fastmcp.mcp_config import (
    RemoteMCPServer,
    StdioMCPServer,
)
from pydantic import BaseModel

from agent.messaging.models import MessagingConfig

# Canonical server-config types accepted as values of ``Config.mcpServers``:
# fastmcp's stdio-based and remote MCP server definitions.
MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]

# Project root = grandparent directory of this file
# (agent/config.py -> agent/ -> project root). load_config reads the
# project-level .env from here so it is found regardless of the CWD.
_PROJECT_ROOT = Path(__file__).resolve().parents[1]


class Config(BaseModel):
    """Configuration manager.

    Pydantic model for the agent configuration. ``load_config`` builds it
    via ``Config.model_validate`` from the JSON config file after
    ``${VAR}`` environment substitution. Only ``model_name`` is required;
    every other field has a default.
    """

    # Model identifier used by the agent (format defined by the consumer).
    model_name: str
    # MCP servers keyed by server name; each value is a stdio or remote
    # server definition (see MCPServerConfig).
    mcpServers: dict[str, MCPServerConfig] = {}
    # Presumably toggles persisting/uploading sessions to the dataset repo
    # below — TODO confirm against the session code.
    save_sessions: bool = True
    session_dataset_repo: str = "smolagents/ml-intern-sessions"
    # Per-user private dataset that mirrors each session in Claude Code JSONL
    # format so the HF Agent Trace Viewer auto-renders it
    # (https://huggingface.co/changelog/agent-trace-viewer). Created private
    # on first use; user flips it public via /share-traces. ``{hf_user}`` is
    # substituted at upload time from the authenticated HF username.
    share_traces: bool = True
    personal_trace_repo_template: str = "{hf_user}/ml-intern-sessions"
    auto_save_interval: int = 1  # Save every N user turns (0 = disabled)
    # Mid-turn heartbeat: save + upload every N seconds while events are being
    # emitted. Guards against losing trace data on long-running turns that
    # crash before turn_complete (e.g. a multi-hour hf_jobs wait that OOMs).
    # 0 = disabled. Consumed by agent.core.telemetry.HeartbeatSaver.
    heartbeat_interval_s: int = 60
    yolo_mode: bool = False  # Auto-approve all tool calls without confirmation
    max_iterations: int = 300  # Max LLM calls per agent turn (-1 = unlimited)

    # Permission control parameters
    # NOTE(review): semantics below are inferred from the names — confirm
    # against the tool/approval code before relying on them.
    confirm_cpu_jobs: bool = True  # presumably: ask before launching CPU jobs
    auto_file_upload: bool = False  # presumably: upload files without asking
    # Where tools execute; restricted to exactly these two literals.
    tool_runtime: Literal["local", "sandbox"] = "local"

    # Reasoning effort *preference* — the ceiling the user wants. The probe
    # on `/model` walks a cascade down from here (``max`` → ``xhigh`` → ``high``
    # → …) and caches per-model what the provider actually accepted in
    # ``Session.model_effective_effort``. Default ``max`` because we'd rather
    # burn tokens thinking than ship a wrong ML recipe; the cascade lands on
    # whichever level the model supports (``high`` for GPT-5 / HF router,
    # ``xhigh`` or ``max`` for Anthropic 4.6 / 4.7). ``None`` = thinking off.
    # Valid values: None | "minimal" | "low" | "medium" | "high" | "xhigh" | "max"
    # NOTE(review): the annotation is ``str | None``, so these values are not
    # enforced at validation time — any string is accepted here.
    reasoning_effort: str | None = "max"
    # Outbound messaging settings; apply_slack_user_defaults may populate this
    # from Slack environment variables before validation.
    messaging: MessagingConfig = MessagingConfig()


# Environment variable naming an explicit user-level config file. When it is
# unset or empty, DEFAULT_USER_CONFIG_PATH is used instead (if it exists).
USER_CONFIG_ENV_VAR = "ML_INTERN_CLI_CONFIG"
DEFAULT_USER_CONFIG_PATH = Path.home().joinpath(
    ".config", "ml-intern", "cli_agent_config.json"
)
# Name of the env-derived Slack destination unless ML_INTERN_SLACK_DESTINATION
# provides another one.
SLACK_DEFAULT_DESTINATION = "slack.default"
# Fallback ``messaging.auto_event_types`` used when neither the config nor
# ML_INTERN_SLACK_AUTO_EVENTS specifies them.
SLACK_DEFAULT_AUTO_EVENT_TYPES = ["approval_required", "error", "turn_complete"]


def _deep_merge_config(
    base: dict[str, Any], override: dict[str, Any]
) -> dict[str, Any]:
    merged = dict(base)
    for key, value in override.items():
        current = merged.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = _deep_merge_config(current, value)
        else:
            merged[key] = value
    return merged


def _load_json_config(path: Path) -> dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"Config file {path} must contain a JSON object")
    return data


def _load_user_config() -> dict[str, Any]:
    """Return the user-level config dict, or ``{}`` when there is none.

    A non-empty ``USER_CONFIG_ENV_VAR`` environment variable names the file
    explicitly (``~`` is expanded) and that file must exist, otherwise
    FileNotFoundError is raised. Without it, DEFAULT_USER_CONFIG_PATH is read
    if it exists.
    """
    override = os.environ.get(USER_CONFIG_ENV_VAR)
    if not override:
        if DEFAULT_USER_CONFIG_PATH.exists():
            return _load_json_config(DEFAULT_USER_CONFIG_PATH)
        return {}

    explicit = Path(override).expanduser()
    if explicit.exists():
        return _load_json_config(explicit)
    raise FileNotFoundError(
        f"{USER_CONFIG_ENV_VAR} points to missing config file: {explicit}"
    )


def _env_bool(name: str, default: bool) -> bool:
    value = os.environ.get(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    return default


def _env_list(name: str) -> list[str] | None:
    value = os.environ.get(name)
    if value is None:
        return None
    return [item.strip() for item in value.split(",") if item.strip()]


def apply_slack_user_defaults(raw_config: dict[str, Any]) -> dict[str, Any]:
    """Enable a default Slack destination from user env vars, when present.

    Returns ``raw_config`` unchanged when ``ML_INTERN_SLACK_NOTIFICATIONS`` is
    falsy, or when ``SLACK_BOT_TOKEN`` or the channel (``SLACK_CHANNEL_ID``,
    falling back to ``SLACK_CHANNEL``) is missing. Otherwise returns a copy of
    ``raw_config`` whose ``messaging`` section has:

    * a Slack destination named by ``ML_INTERN_SLACK_DESTINATION`` (default
      ``SLACK_DEFAULT_DESTINATION``), added only if no destination with that
      name is already configured;
    * ``auto_event_types`` taken from ``ML_INTERN_SLACK_AUTO_EVENTS`` (comma
      separated) when set, else kept as configured, else defaulted to
      ``SLACK_DEFAULT_AUTO_EVENT_TYPES``;
    * ``enabled`` set to True.

    ``raw_config`` is not mutated: the top-level, ``messaging`` and
    ``destinations`` dicts are copied before being modified.
    """
    if not _env_bool("ML_INTERN_SLACK_NOTIFICATIONS", True):
        return raw_config

    token = os.environ.get("SLACK_BOT_TOKEN")
    channel = os.environ.get("SLACK_CHANNEL_ID") or os.environ.get("SLACK_CHANNEL")
    if not token or not channel:
        return raw_config

    config = dict(raw_config)
    messaging = dict(config.get("messaging") or {})
    destinations = dict(messaging.get("destinations") or {})
    # Strip *before* falling back, so a whitespace-only override resolves to
    # the default name instead of producing a destination keyed by "".
    destination_name = (
        os.environ.get("ML_INTERN_SLACK_DESTINATION", "").strip()
        or SLACK_DEFAULT_DESTINATION
    )

    if destination_name not in destinations:
        destinations[destination_name] = {
            "provider": "slack",
            "token": token,
            "channel": channel,
            "allow_agent_tool": _env_bool("ML_INTERN_SLACK_ALLOW_AGENT_TOOL", True),
            "allow_auto_events": _env_bool("ML_INTERN_SLACK_ALLOW_AUTO_EVENTS", True),
        }

    auto_events = _env_list("ML_INTERN_SLACK_AUTO_EVENTS")
    if auto_events is not None:
        messaging["auto_event_types"] = auto_events
    elif "auto_event_types" not in messaging:
        # Copy: handing out the module-level list would let any caller that
        # mutates the returned config silently change the default for all
        # later calls.
        messaging["auto_event_types"] = list(SLACK_DEFAULT_AUTO_EVENT_TYPES)

    messaging["enabled"] = True
    messaging["destinations"] = destinations
    config["messaging"] = messaging
    return config


def substitute_env_vars(obj: Any) -> Any:
    """
    Recursively substitute environment variables in any data structure.

    Supports ${VAR_NAME} syntax for required variables and ${VAR_NAME:-default} for optional.
    """
    if isinstance(obj, str):
        pattern = r"\$\{([^}:]+)(?::(-)?([^}]*))?\}"

        def replacer(match):
            var_name = match.group(1)
            has_default = match.group(2) is not None
            default_value = match.group(3) if has_default else None

            env_value = os.environ.get(var_name)

            if env_value is not None:
                return env_value
            elif has_default:
                return default_value or ""
            else:
                raise ValueError(
                    f"Environment variable '{var_name}' is not set. "
                    f"Add it to your .env file."
                )

        return re.sub(pattern, replacer, obj)

    elif isinstance(obj, dict):
        return {key: substitute_env_vars(value) for key, value in obj.items()}

    elif isinstance(obj, list):
        return [substitute_env_vars(item) for item in obj]

    return obj


def load_config(
    config_path: str = "config.json",
    include_user_defaults: bool = False,
) -> Config:
    """
    Load configuration with environment variable substitution.

    Use ${VAR_NAME} in your JSON for any secret (or ${VAR_NAME:-default} for
    an optional value). References are resolved from the process environment
    after .env files are loaded, with precedence:

        real environment  >  ./.env (current working directory)  >  <project root>/.env

    Args:
        config_path: Path to the base JSON config file (a relative path is
            resolved against the current working directory).
        include_user_defaults: When True, deep-merge the user-level config
            over the base config, then apply the Slack defaults derived from
            environment variables.

    Returns:
        The validated ``Config``.

    Raises:
        FileNotFoundError: If ``config_path``, or a user config file named
            explicitly via the environment, does not exist.
        ValueError: If a config file is not valid JSON / not a JSON object, or
            a required ``${VAR}`` is not set.
    """
    # load_dotenv(override=False) never replaces a variable that is already
    # set, so the higher-precedence file must be loaded FIRST. The CWD path is
    # passed explicitly: a bare load_dotenv() locates its file via
    # find_dotenv(), which searches upward from the calling module's directory
    # (i.e. this package), not from the working directory — so it would just
    # re-read the project-root .env and never see ./.env.
    load_dotenv(Path.cwd() / ".env", override=False)
    # Project-root .env fills in whatever is still unset, so it works from
    # any directory.
    load_dotenv(_PROJECT_ROOT / ".env", override=False)

    raw_config = _load_json_config(Path(config_path))
    if include_user_defaults:
        raw_config = _deep_merge_config(raw_config, _load_user_config())
        raw_config = apply_slack_user_defaults(raw_config)

    config_with_env = substitute_env_vars(raw_config)
    return Config.model_validate(config_with_env)