onewayto commited on
Commit
683b580
·
verified ·
1 Parent(s): 174be0d

Upload 49 files

Browse files
Files changed (49) hide show
  1. .gitignore +71 -0
  2. .python-version +1 -0
  3. Dockerfile +32 -0
  4. Procfile +1 -0
  5. __init__.py +1 -0
  6. agent/README.md +21 -0
  7. agent/__init__.py +7 -0
  8. agent/config.py +83 -0
  9. agent/context_manager/__init__.py +7 -0
  10. agent/context_manager/manager.py +197 -0
  11. agent/core/__init__.py +12 -0
  12. agent/core/agent_loop.py +711 -0
  13. agent/core/session.py +255 -0
  14. agent/core/session_uploader.py +202 -0
  15. agent/core/tools.py +337 -0
  16. agent/main.py +567 -0
  17. agent/prompts/system_prompt.yaml +170 -0
  18. agent/prompts/system_prompt_v2.yaml +626 -0
  19. agent/tools/__init__.py +39 -0
  20. agent/tools/dataset_tools.py +445 -0
  21. agent/tools/docs_tools.py +956 -0
  22. agent/tools/github_find_examples.py +499 -0
  23. agent/tools/github_list_repos.py +287 -0
  24. agent/tools/github_read_file.py +348 -0
  25. agent/tools/hf_repo_files_tool.py +322 -0
  26. agent/tools/hf_repo_git_tool.py +663 -0
  27. agent/tools/jobs_tool.py +1042 -0
  28. agent/tools/plan_tool.py +138 -0
  29. agent/tools/private_hf_repo_tools.py +650 -0
  30. agent/tools/types.py +16 -0
  31. agent/tools/utilities.py +142 -0
  32. agent/utils/__init__.py +3 -0
  33. agent/utils/reliability_checks.py +16 -0
  34. agent/utils/terminal_display.py +155 -0
  35. configs/main_agent_config.json +17 -0
  36. dependencies.py +144 -0
  37. main.py +96 -0
  38. models.py +87 -0
  39. pyproject.toml +51 -0
  40. requirements.txt +25 -0
  41. routes/__init__.py +1 -0
  42. routes/__pycache__/__init__.cpython-313.pyc +0 -0
  43. routes/__pycache__/agent.cpython-313.pyc +0 -0
  44. routes/agent.py +404 -0
  45. routes/auth.py +171 -0
  46. session_manager.py +376 -0
  47. start.sh +26 -0
  48. uv.lock +0 -0
  49. websocket.py +62 -0
.gitignore ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+ .pytest_cache/
9
+ .mypy_cache/
10
+ .tox/
11
+ .coverage
12
+ htmlcov/
13
+ .ipynb_checkpoints/
14
+
15
+ # Virtual environments
16
+ .venv/
17
+ venv/
18
+ ENV/
19
+ env/
20
+
21
+ # Environment and Secrets
22
+ .env
23
+ .env.local
24
+ .env.*
25
+ !.env.example
26
+ *.local
27
+ credentials*.json
28
+
29
+ # OS-specific
30
+ .DS_Store
31
+ Thumbs.db
32
+ *.swp
33
+
34
+ # IDE-specific
35
+ .vscode/
36
+ .idea/
37
+ .cursor/
38
+ .history/
39
+ *.sublime-project
40
+ *.sublime-workspace
41
+
42
+ # Frontend (Node.js)
43
+ frontend/node_modules/
44
+ frontend/dist/
45
+ frontend/.cache/
46
+ frontend/*.local
47
+ frontend/.eslintcache
48
+ frontend/npm-debug.log*
49
+ frontend/yarn-debug.log*
50
+ frontend/yarn-error.log*
51
+
52
+ # Docker
53
+ .docker/
54
+
55
+ # Project-specific
56
+ session_logs/
57
+ /logs
58
+ hf-agent-leaderboard/
59
+ skills/
60
+ .claude/
61
+ *.jsonl
62
+ *.csv
63
+
64
+ # ML / Data
65
+ data/
66
+ datasets/
67
+ models/
68
+ checkpoint-*/
69
+ runs/
70
+ wandb/
71
+ frontend/tsconfig.tsbuildinfo
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# HF Agent Backend - Docker Image
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies (gcc is needed to compile wheels that ship
# without prebuilt binaries for this platform)
RUN apt-get update && apt-get install -y \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Grant full write access (chmod 777) to /app directory and set root as owner
# NOTE(review): world-writable /app plus a root user is a security smell for
# general deployments — presumably required by the target runtime (HF Spaces
# style); confirm before reusing this image elsewhere.
RUN chmod -R 777 /app && chown -R root:root /app

# Run as root user
USER root

# Expose port
EXPOSE 7860

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/api/health')" || exit 1

# Run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: uvicorn main:app --host 0.0.0.0 --port ${PORT:-7860}
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Backend package for HF Agent web interface
agent/README.md ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Agent
2
+
3
+ Async agent loop with LiteLLM.
4
+
5
+ ## Architecture
6
+
7
+ **Queue-based async system:**
8
+ - Submissions in (user input) → Agent Loop → Events output for possible UI updates
9
+ - Session maintains state (context + tools) for possible future Context Engineering
10
+ - Handlers for operations (USER_INPUT, INTERRUPT, COMPACT, UNDO, SHUTDOWN) to allow possible UI control
11
+
12
+ ## Components
13
+
14
+ | Component | Purpose | Long Term Goal |
15
+ |-----------|---------|----------------|
16
+ | **`agent_loop.py`** | Core agentic loop: processes user input, calls LLM via LiteLLM, executes tool calls iteratively until completion, emits events | Support parallel tool execution, streaming responses, and advanced reasoning patterns |
17
+ | **`session.py`** | Maintains session state and interaction with potential UI (context, config, event queue), handles interrupts, assigns unique session IDs for tracing | Enable plugging in different UIs (CLI, web, API, programmatic etc.) |
18
+ | **`tools.py`** | `ToolRouter` manages potential built-in tools (e.g. bash, read_file, write_file which are dummy implementations rn) + MCP tools, converts specs to OpenAI format | Be the place for tools that can be used by the agent. All crazy tool design happens here. |
19
+ | **`context_manager/`** | Manages conversation history, very rudimentary context engineering support | Implement intelligent context engineering to keep the agent on track |
20
+ | **`config.py`** | Loads JSON config for the agent | Support different configs etc. |
21
+ | **`main.py`** | Interactive CLI with async queue architecture (submission→agent, agent→events) (simple way to interact with the agent now)| Serve as reference implementation for other UIs (web, API, programmatic) |
agent/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Agent - Main agent module
3
+ """
4
+
5
+ from agent.core.agent_loop import submission_loop
6
+
7
+ __all__ = ["submission_loop"]
agent/config.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ from typing import Any, Union
5
+
6
+ from dotenv import load_dotenv
7
+ from fastmcp.mcp_config import (
8
+ RemoteMCPServer,
9
+ StdioMCPServer,
10
+ )
11
+ from pydantic import BaseModel
12
+
13
+ # These two are the canonical server config types for MCP servers.
14
+ MCPServerConfig = Union[StdioMCPServer, RemoteMCPServer]
15
+
16
+
17
class Config(BaseModel):
    """Configuration manager.

    Pydantic model validated from the JSON config file (see load_config),
    after environment-variable substitution.
    """

    # LiteLLM model identifier, e.g. "huggingface/...".
    model_name: str
    # MCP server definitions keyed by server name (stdio or remote transport).
    mcpServers: dict[str, MCPServerConfig] = {}
    # Whether to persist sessions to the dataset repo below.
    save_sessions: bool = True
    session_dataset_repo: str = "akseljoonas/hf-agent-sessions"
    auto_save_interval: int = 3  # Save every N user turns (0 = disabled)
    yolo_mode: bool = False  # Auto-approve all tool calls without confirmation

    # Permission control parameters
    confirm_cpu_jobs: bool = True  # require approval even for CPU-flavored jobs
    auto_file_upload: bool = False  # skip approval for file-upload operations
30
+
31
+
32
def substitute_env_vars(obj: Any) -> Any:
    """
    Recursively substitute environment variables in any data structure.

    Strings may contain ``${VAR_NAME}`` (required — raises ValueError when the
    variable is unset) or ``${VAR_NAME:-default}`` (optional with fallback).
    Dicts and lists are walked recursively; any other value passes through
    unchanged.
    """
    # Containers: recurse into every element.
    if isinstance(obj, dict):
        return {key: substitute_env_vars(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [substitute_env_vars(element) for element in obj]
    # Non-string scalars need no substitution.
    if not isinstance(obj, str):
        return obj

    var_pattern = r"\$\{([^}:]+)(?::(-)?([^}]*))?\}"

    def _resolve(match) -> str:
        name = match.group(1)
        # The ":-" marker distinguishes "has a default" from a plain ${VAR}.
        fallback_given = match.group(2) is not None
        fallback = match.group(3) if fallback_given else None

        value = os.environ.get(name)
        if value is not None:
            return value
        if fallback_given:
            return fallback or ""
        raise ValueError(
            f"Environment variable '{name}' is not set. "
            f"Add it to your .env file."
        )

    return re.sub(var_pattern, _resolve, obj)
67
+
68
+
69
def load_config(config_path: str = "config.json") -> Config:
    """
    Load configuration with environment variable substitution.

    Use ${VAR_NAME} in your JSON for any secret.
    Automatically loads from .env file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        A validated Config instance.
    """
    # Pull secrets from .env into the process environment first, so the
    # ${VAR} placeholders in the JSON can resolve.
    load_dotenv()

    with open(config_path, "r") as f:
        parsed = json.load(f)

    resolved = substitute_env_vars(parsed)
    return Config.model_validate(resolved)
agent/context_manager/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Context manager for handling conversation history
3
+ """
4
+
5
+ from agent.context_manager.manager import ContextManager
6
+
7
+ __all__ = ["ContextManager"]
agent/context_manager/manager.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Context management for conversation history
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import zoneinfo
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import yaml
13
+ from jinja2 import Template
14
+ from litellm import Message, acompletion
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Module-level cache for HF username — avoids repeating the slow whoami() call
19
+ _hf_username_cache: str | None = None
20
+
21
+ _HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2"
22
+ _HF_WHOAMI_TIMEOUT = 5 # seconds
23
+
24
+
25
def _get_hf_username() -> str:
    """Return the HF username, cached after the first call.

    Uses subprocess + curl to avoid Python HTTP client IPv6 issues that
    cause 40+ second hangs (httpx/urllib try IPv6 first which times out
    at OS level before falling back to IPv4 — the "Happy Eyeballs" problem).

    Falls back to "unknown" (and caches it) when no token is configured or
    the whoami call fails.
    """
    import json
    import subprocess
    import time as _t

    global _hf_username_cache

    # Fast path: already resolved once in this process.
    if _hf_username_cache is not None:
        return _hf_username_cache

    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:
        logger.warning("No HF_TOKEN set, using 'unknown' as username")
        _hf_username_cache = "unknown"
        return _hf_username_cache

    started = _t.monotonic()
    curl_cmd = [
        "curl",
        "-s",
        "-4",  # force IPv4
        "-m",
        str(_HF_WHOAMI_TIMEOUT),  # max time
        "-H",
        f"Authorization: Bearer {token}",
        _HF_WHOAMI_URL,
    ]
    try:
        proc = subprocess.run(
            curl_cmd,
            capture_output=True,
            text=True,
            # Give curl's own -m timeout a little headroom before killing.
            timeout=_HF_WHOAMI_TIMEOUT + 2,
        )
        elapsed = _t.monotonic() - started
        if proc.returncode == 0 and proc.stdout:
            payload = json.loads(proc.stdout)
            _hf_username_cache = payload.get("name", "unknown")
            logger.info(
                f"HF username resolved to '{_hf_username_cache}' in {elapsed:.2f}s"
            )
        else:
            logger.warning(
                f"curl whoami failed (rc={proc.returncode}) in {elapsed:.2f}s"
            )
            _hf_username_cache = "unknown"
    except Exception as e:
        elapsed = _t.monotonic() - started
        logger.warning(f"HF whoami failed in {elapsed:.2f}s: {e}")
        _hf_username_cache = "unknown"

    return _hf_username_cache
81
+
82
+
83
+ class ContextManager:
84
+ """Manages conversation context and message history for the agent"""
85
+
86
+ def __init__(
87
+ self,
88
+ max_context: int = 180_000,
89
+ compact_size: float = 0.1,
90
+ untouched_messages: int = 5,
91
+ tool_specs: list[dict[str, Any]] | None = None,
92
+ prompt_file_suffix: str = "system_prompt_v2.yaml",
93
+ ):
94
+ self.system_prompt = self._load_system_prompt(
95
+ tool_specs or [],
96
+ prompt_file_suffix="system_prompt_v2.yaml",
97
+ )
98
+ self.max_context = max_context
99
+ self.compact_size = int(max_context * compact_size)
100
+ self.context_length = len(self.system_prompt) // 4
101
+ self.untouched_messages = untouched_messages
102
+ self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
103
+
104
+ def _load_system_prompt(
105
+ self,
106
+ tool_specs: list[dict[str, Any]],
107
+ prompt_file_suffix: str = "system_prompt.yaml",
108
+ ):
109
+ """Load and render the system prompt from YAML file with Jinja2"""
110
+ prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}"
111
+
112
+ with open(prompt_file, "r") as f:
113
+ prompt_data = yaml.safe_load(f)
114
+ template_str = prompt_data.get("system_prompt", "")
115
+
116
+ # Get current date and time
117
+ tz = zoneinfo.ZoneInfo("Europe/Paris")
118
+ now = datetime.now(tz)
119
+ current_date = now.strftime("%d-%m-%Y")
120
+ current_time = now.strftime("%H:%M:%S.%f")[:-3]
121
+ current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})"
122
+
123
+ # Get HF user info (cached after the first call)
124
+ hf_user_info = _get_hf_username()
125
+
126
+ template = Template(template_str)
127
+ return template.render(
128
+ tools=tool_specs,
129
+ num_tools=len(tool_specs),
130
+ current_date=current_date,
131
+ current_time=current_time,
132
+ current_timezone=current_timezone,
133
+ hf_user_info=hf_user_info,
134
+ )
135
+
136
+ def add_message(self, message: Message, token_count: int = None) -> None:
137
+ """Add a message to the history"""
138
+ if token_count:
139
+ self.context_length = token_count
140
+ self.items.append(message)
141
+
142
+ def get_messages(self) -> list[Message]:
143
+ """Get all messages for sending to LLM"""
144
+ return self.items
145
+
146
+ async def compact(self, model_name: str) -> None:
147
+ """Remove old messages to keep history under target size"""
148
+ if (self.context_length <= self.max_context) or not self.items:
149
+ return
150
+
151
+ system_msg = (
152
+ self.items[0] if self.items and self.items[0].role == "system" else None
153
+ )
154
+
155
+ # Don't summarize a certain number of just-preceding messages
156
+ # Walk back to find a user message to make sure we keep an assistant -> user ->
157
+ # assistant general conversation structure
158
+ idx = len(self.items) - self.untouched_messages
159
+ while idx > 1 and self.items[idx].role != "user":
160
+ idx -= 1
161
+
162
+ recent_messages = self.items[idx:]
163
+ messages_to_summarize = self.items[1:idx]
164
+
165
+ # improbable, messages would have to very long
166
+ if not messages_to_summarize:
167
+ return
168
+
169
+ messages_to_summarize.append(
170
+ Message(
171
+ role="user",
172
+ content="Please provide a concise summary of the conversation above, focusing on key decisions, code changes, problems solved, and important context needed for future turns.",
173
+ )
174
+ )
175
+
176
+ hf_key = os.environ.get("INFERENCE_TOKEN")
177
+ response = await acompletion(
178
+ model=model_name,
179
+ messages=messages_to_summarize,
180
+ max_completion_tokens=self.compact_size,
181
+ api_key=hf_key
182
+ if hf_key and model_name.startswith("huggingface/")
183
+ else None,
184
+ )
185
+ summarized_message = Message(
186
+ role="assistant", content=response.choices[0].message.content
187
+ )
188
+
189
+ # Reconstruct: system + summary + recent messages (includes tools)
190
+ if system_msg:
191
+ self.items = [system_msg, summarized_message] + recent_messages
192
+ else:
193
+ self.items = [summarized_message] + recent_messages
194
+
195
+ self.context_length = (
196
+ len(self.system_prompt) // 4 + response.usage.completion_tokens
197
+ )
agent/core/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Core agent implementation
3
+ Contains the main agent logic, decision-making, and orchestration
4
+ """
5
+
6
+ from agent.core.tools import ToolRouter, ToolSpec, create_builtin_tools
7
+
8
+ __all__ = [
9
+ "ToolRouter",
10
+ "ToolSpec",
11
+ "create_builtin_tools",
12
+ ]
agent/core/agent_loop.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """loop
2
+ Main agent implementation with integrated tool system and MCP support
3
+ """
4
+
5
+ import asyncio
6
+ import json
7
+ import logging
8
+ import os
9
+
10
+ from litellm import ChatCompletionMessageToolCall, Message, acompletion
11
+ from lmnr import observe
12
+
13
+ from agent.config import Config
14
+ from agent.core.session import Event, OpType, Session
15
+ from agent.core.tools import ToolRouter
16
+ from agent.tools.jobs_tool import CPU_FLAVORS
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ ToolCall = ChatCompletionMessageToolCall
21
+ # Explicit inference token — needed because litellm checks HF_TOKEN before
22
+ # HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions.
23
+ _INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
24
+
25
+
26
+ def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
27
+ """
28
+ Validate tool arguments structure.
29
+
30
+ Returns:
31
+ (is_valid, error_message)
32
+ """
33
+ args = tool_args.get("args", {})
34
+ # Sometimes LLM passes args as string instead of dict
35
+ if isinstance(args, str):
36
+ return (
37
+ False,
38
+ f"Tool call error: 'args' must be a JSON object, not a string. You passed: {repr(args)}",
39
+ )
40
+ if not isinstance(args, dict) and args is not None:
41
+ return (
42
+ False,
43
+ f"Tool call error: 'args' must be a JSON object. You passed type: {type(args).__name__}",
44
+ )
45
+ return True, None
46
+
47
+
48
def _needs_approval(
    tool_name: str, tool_args: dict, config: Config | None = None
) -> bool:
    """Check if a tool call requires user approval before execution.

    Returns True when the call must be confirmed by the user, False when it
    can run immediately.
    """
    # Yolo mode: every tool call is auto-approved.
    if config and config.yolo_mode:
        return False

    # Malformed args skip approval — the validation error is surfaced later.
    valid, _ = _validate_tool_args(tool_args)
    if not valid:
        return False

    operation = tool_args.get("operation", "")

    if tool_name == "hf_jobs":
        # Only job-launching operations are gated.
        if operation not in ("run", "uv", "scheduled run", "scheduled uv"):
            return False

        # hardware_flavor lives at the top level of tool_args, not nested
        # inside args; several aliases are accepted.
        flavor = (
            tool_args.get("hardware_flavor")
            or tool_args.get("flavor")
            or tool_args.get("hardware")
            or "cpu-basic"
        )
        if flavor in CPU_FLAVORS:
            # CPU jobs can be exempted from confirmation via config.
            return not (config and not config.confirm_cpu_jobs)
        # GPU (or unknown) flavors always need confirmation.
        return True

    if tool_name == "hf_private_repos":
        if operation == "upload_file":
            # Uploads can be auto-approved via config.
            return not (config and config.auto_file_upload)
        # Repo creation always requires approval; other operations do not.
        return operation in ("create_repo",)

    if tool_name == "hf_repo_files":
        # upload can overwrite existing files; delete is destructive.
        return operation in ("upload", "delete")

    if tool_name == "hf_repo_git":
        # Destructive git / repo-management operations require approval.
        return operation in (
            "delete_branch",
            "delete_tag",
            "merge_pr",
            "create_repo",
            "update_repo",
        )

    return False
113
+
114
+
115
+ class Handlers:
116
+ """Handler functions for each operation type"""
117
+
118
    @staticmethod
    @observe(name="run_agent")
    async def run_agent(
        session: Session, text: str, max_iterations: int = 10
    ) -> str | None:
        """
        Handle user input (like user_input_or_turn in codex.rs:1291).

        Runs the agentic loop: streams an LLM completion, accumulates any
        tool calls from the stream, executes non-approval tools in parallel,
        and pauses (returning None) when any tool requires user approval.

        Args:
            session: Active session holding context, config, and tool router.
            text: The user's input; empty text skips adding a user message.
            max_iterations: Cap on LLM-call rounds for one turn.

        Returns:
            The final assistant response content, if any; None when the turn
            ended early (pending approval or error).
        """
        # Set session ID for this trace
        if hasattr(session, "session_id"):
            from lmnr import Laminar

            Laminar.set_trace_session_id(session_id=session.session_id)

        # Add user message to history only if there's actual content
        if text:
            user_msg = Message(role="user", content=text)
            session.context_manager.add_message(user_msg)

        # Send event that we're processing
        await session.send_event(
            Event(event_type="processing", data={"message": "Processing user input"})
        )

        # Agentic loop - continue until model doesn't call tools or max iterations is reached
        iteration = 0
        final_response = None

        while iteration < max_iterations:
            messages = session.context_manager.get_messages()
            tools = session.tool_router.get_tool_specs_for_llm()
            try:
                # ── Stream the LLM response ──────────────────────────
                # api_key: dedicated inference token is only used for
                # huggingface/* models; otherwise litellm's defaults apply.
                response = await acompletion(
                    model=session.config.model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    stream=True,
                    stream_options={"include_usage": True},
                    api_key=_INFERENCE_API_KEY
                    if _INFERENCE_API_KEY
                    and session.config.model_name.startswith("huggingface/")
                    else None,
                )

                full_content = ""
                # Accumulator for streamed tool-call fragments, keyed by index.
                tool_calls_acc: dict[int, dict] = {}
                token_count = 0

                async for chunk in response:
                    choice = chunk.choices[0] if chunk.choices else None
                    if not choice:
                        # Last chunk may carry only usage info
                        if hasattr(chunk, "usage") and chunk.usage:
                            token_count = chunk.usage.total_tokens
                        continue

                    delta = choice.delta

                    # Stream text deltas to the frontend
                    if delta.content:
                        full_content += delta.content
                        await session.send_event(
                            Event(
                                event_type="assistant_chunk",
                                data={"content": delta.content},
                            )
                        )

                    # Accumulate tool-call deltas (name + args arrive in pieces)
                    if delta.tool_calls:
                        for tc_delta in delta.tool_calls:
                            idx = tc_delta.index
                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                }
                            if tc_delta.id:
                                tool_calls_acc[idx]["id"] = tc_delta.id
                            if tc_delta.function:
                                if tc_delta.function.name:
                                    tool_calls_acc[idx]["function"]["name"] += (
                                        tc_delta.function.name
                                    )
                                if tc_delta.function.arguments:
                                    tool_calls_acc[idx]["function"]["arguments"] += (
                                        tc_delta.function.arguments
                                    )

                    # Capture usage from the final chunk
                    if hasattr(chunk, "usage") and chunk.usage:
                        token_count = chunk.usage.total_tokens

                # ── Stream finished — reconstruct full message ───────
                content = full_content or None

                # Build tool_calls list from accumulated deltas
                tool_calls: list[ToolCall] = []
                for idx in sorted(tool_calls_acc.keys()):
                    tc_data = tool_calls_acc[idx]
                    tool_calls.append(
                        ToolCall(
                            id=tc_data["id"],
                            type="function",
                            function={
                                "name": tc_data["function"]["name"],
                                "arguments": tc_data["function"]["arguments"],
                            },
                        )
                    )

                # Signal end of streaming to the frontend
                await session.send_event(
                    Event(event_type="assistant_stream_end", data={})
                )

                # If no tool calls, add assistant message and we're done
                if not tool_calls:
                    if content:
                        assistant_msg = Message(role="assistant", content=content)
                        session.context_manager.add_message(assistant_msg, token_count)
                        final_response = content
                    break

                # Add assistant message with tool calls to history
                assistant_msg = Message(
                    role="assistant",
                    content=content,
                    tool_calls=tool_calls,
                )
                session.context_manager.add_message(assistant_msg, token_count)

                # Separate tools into those requiring approval and those that don't
                approval_required_tools = []
                non_approval_tools = []

                for tc in tool_calls:
                    tool_name = tc.function.name
                    try:
                        tool_args = json.loads(tc.function.arguments)
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
                        tool_args = {}

                    if _needs_approval(tool_name, tool_args, session.config):
                        approval_required_tools.append(tc)
                    else:
                        non_approval_tools.append(tc)

                # Execute non-approval tools (in parallel when possible)
                if non_approval_tools:
                    # 1. Parse args and validate upfront
                    parsed_tools: list[
                        tuple[ChatCompletionMessageToolCall, str, dict, bool, str]
                    ] = []
                    for tc in non_approval_tools:
                        tool_name = tc.function.name
                        try:
                            tool_args = json.loads(tc.function.arguments)
                        except (json.JSONDecodeError, TypeError):
                            tool_args = {}

                        args_valid, error_msg = _validate_tool_args(tool_args)
                        parsed_tools.append(
                            (tc, tool_name, tool_args, args_valid, error_msg)
                        )

                    # 2. Send all tool_call events upfront (so frontend shows them all)
                    for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
                        if args_valid:
                            await session.send_event(
                                Event(
                                    event_type="tool_call",
                                    data={
                                        "tool": tool_name,
                                        "arguments": tool_args,
                                        "tool_call_id": tc.id,
                                    },
                                )
                            )

                    # 3. Execute all valid tools in parallel
                    async def _exec_tool(
                        tc: ChatCompletionMessageToolCall,
                        name: str,
                        args: dict,
                        valid: bool,
                        err: str,
                    ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]:
                        # Invalid args short-circuit: the validation error
                        # becomes the tool "output" with success=False.
                        if not valid:
                            return (tc, name, args, err, False)
                        out, ok = await session.tool_router.call_tool(
                            name, args, session=session
                        )
                        return (tc, name, args, out, ok)

                    results = await asyncio.gather(
                        *[
                            _exec_tool(tc, name, args, valid, err)
                            for tc, name, args, valid, err in parsed_tools
                        ]
                    )

                    # 4. Record results and send outputs (order preserved)
                    for tc, tool_name, tool_args, output, success in results:
                        tool_msg = Message(
                            role="tool",
                            content=output,
                            tool_call_id=tc.id,
                            name=tool_name,
                        )
                        session.context_manager.add_message(tool_msg)

                        await session.send_event(
                            Event(
                                event_type="tool_output",
                                data={
                                    "tool": tool_name,
                                    "tool_call_id": tc.id,
                                    "output": output,
                                    "success": success,
                                },
                            )
                        )

                # If there are tools requiring approval, ask for batch approval
                if approval_required_tools:
                    # Prepare batch approval data
                    tools_data = []
                    for tc in approval_required_tools:
                        tool_name = tc.function.name
                        try:
                            tool_args = json.loads(tc.function.arguments)
                        except (json.JSONDecodeError, TypeError):
                            tool_args = {}
                        tools_data.append(
                            {
                                "tool": tool_name,
                                "arguments": tool_args,
                                "tool_call_id": tc.id,
                            }
                        )

                    await session.send_event(
                        Event(
                            event_type="approval_required",
                            data={
                                "tools": tools_data,  # Batch of tools
                                "count": len(tools_data),
                            },
                        )
                    )

                    # Store all approval-requiring tools
                    session.pending_approval = {
                        "tool_calls": approval_required_tools,
                    }

                    # Return early - wait for EXEC_APPROVAL operation
                    return None

                iteration += 1

            except Exception as e:
                import traceback

                # Surface the full traceback to the frontend and end the turn.
                await session.send_event(
                    Event(
                        event_type="error",
                        data={"error": str(e) + "\n" + traceback.format_exc()},
                    )
                )
                break

        # Compact history if the context grew past budget during this turn.
        old_length = session.context_manager.context_length
        await session.context_manager.compact(model_name=session.config.model_name)
        new_length = session.context_manager.context_length

        if new_length != old_length:
            await session.send_event(
                Event(
                    event_type="compacted",
                    data={"old_tokens": old_length, "new_tokens": new_length},
                )
            )

        await session.send_event(
            Event(
                event_type="turn_complete",
                data={"history_size": len(session.context_manager.items)},
            )
        )

        # Increment turn counter and check for auto-save
        session.increment_turn()
        await session.auto_save_if_needed()

        return final_response
420
+
421
    @staticmethod
    async def interrupt(session: Session) -> None:
        """Handle interrupt (like interrupt in codex.rs:1266).

        Calls session.interrupt() to set the session's interrupt state, then
        emits an "interrupted" event so listeners (e.g. the UI) can react.
        """
        session.interrupt()
        await session.send_event(Event(event_type="interrupted"))
426
+
427
+ @staticmethod
428
+ async def compact(session: Session) -> None:
429
+ """Handle compact (like compact in codex.rs:1317)"""
430
+ old_length = session.context_manager.context_length
431
+ await session.context_manager.compact(model_name=session.config.model_name)
432
+ new_length = session.context_manager.context_length
433
+
434
+ await session.send_event(
435
+ Event(
436
+ event_type="compacted",
437
+ data={"removed": old_length, "remaining": new_length},
438
+ )
439
+ )
440
+
441
+ @staticmethod
442
+ async def undo(session: Session) -> None:
443
+ """Remove the last complete turn (user msg + all assistant/tool msgs that follow).
444
+
445
+ Anthropic requires every tool_use to have a matching tool_result,
446
+ so we can't just pop 2 items — we must pop everything back to
447
+ (and including) the last user message to keep the history valid.
448
+ """
449
+ items = session.context_manager.items
450
+ if not items:
451
+ await session.send_event(Event(event_type="undo_complete"))
452
+ return
453
+
454
+ # Pop from the end until we've removed the last user message
455
+ removed_user = False
456
+ while items:
457
+ msg = items.pop()
458
+ if getattr(msg, "role", None) == "user":
459
+ removed_user = True
460
+ break
461
+
462
+ if not removed_user:
463
+ logger.warning("Undo: no user message found to remove")
464
+
465
+ await session.send_event(Event(event_type="undo_complete"))
466
+
467
+ @staticmethod
468
+ async def exec_approval(session: Session, approvals: list[dict]) -> None:
469
+ """Handle batch job execution approval"""
470
+ if not session.pending_approval:
471
+ await session.send_event(
472
+ Event(
473
+ event_type="error",
474
+ data={"error": "No pending approval to process"},
475
+ )
476
+ )
477
+ return
478
+
479
+ tool_calls = session.pending_approval.get("tool_calls", [])
480
+ if not tool_calls:
481
+ await session.send_event(
482
+ Event(
483
+ event_type="error",
484
+ data={"error": "No pending tool calls found"},
485
+ )
486
+ )
487
+ return
488
+
489
+ # Create a map of tool_call_id -> approval decision
490
+ approval_map = {a["tool_call_id"]: a for a in approvals}
491
+
492
+ # Separate approved and rejected tool calls
493
+ approved_tasks = []
494
+ rejected_tasks = []
495
+
496
+ for tc in tool_calls:
497
+ tool_name = tc.function.name
498
+ tool_args = json.loads(tc.function.arguments)
499
+ approval_decision = approval_map.get(tc.id, {"approved": False})
500
+
501
+ if approval_decision.get("approved", False):
502
+ approved_tasks.append((tc, tool_name, tool_args))
503
+ else:
504
+ rejected_tasks.append((tc, tool_name, approval_decision))
505
+
506
+ # Execute all approved tools concurrently
507
+ async def execute_tool(tc, tool_name, tool_args):
508
+ """Execute a single tool and return its result"""
509
+ await session.send_event(
510
+ Event(
511
+ event_type="tool_call",
512
+ data={
513
+ "tool": tool_name,
514
+ "arguments": tool_args,
515
+ "tool_call_id": tc.id,
516
+ },
517
+ )
518
+ )
519
+
520
+ output, success = await session.tool_router.call_tool(
521
+ tool_name, tool_args, session=session
522
+ )
523
+
524
+ return (tc, tool_name, output, success)
525
+
526
+ # Execute all approved tools concurrently and wait for ALL to complete
527
+ if approved_tasks:
528
+ results = await asyncio.gather(
529
+ *[
530
+ execute_tool(tc, tool_name, tool_args)
531
+ for tc, tool_name, tool_args in approved_tasks
532
+ ],
533
+ return_exceptions=True,
534
+ )
535
+
536
+ # Process results and add to context
537
+ for result in results:
538
+ if isinstance(result, Exception):
539
+ # Handle execution error
540
+ logger.error(f"Tool execution error: {result}")
541
+ continue
542
+
543
+ tc, tool_name, output, success = result
544
+
545
+ # Add tool result to context
546
+ tool_msg = Message(
547
+ role="tool",
548
+ content=output,
549
+ tool_call_id=tc.id,
550
+ name=tool_name,
551
+ )
552
+ session.context_manager.add_message(tool_msg)
553
+
554
+ await session.send_event(
555
+ Event(
556
+ event_type="tool_output",
557
+ data={
558
+ "tool": tool_name,
559
+ "tool_call_id": tc.id,
560
+ "output": output,
561
+ "success": success,
562
+ },
563
+ )
564
+ )
565
+
566
+ # Process rejected tools
567
+ for tc, tool_name, approval_decision in rejected_tasks:
568
+ rejection_msg = "Job execution cancelled by user"
569
+ user_feedback = approval_decision.get("feedback")
570
+ if user_feedback:
571
+ rejection_msg += f". User feedback: {user_feedback}"
572
+
573
+ tool_msg = Message(
574
+ role="tool",
575
+ content=rejection_msg,
576
+ tool_call_id=tc.id,
577
+ name=tool_name,
578
+ )
579
+ session.context_manager.add_message(tool_msg)
580
+
581
+ await session.send_event(
582
+ Event(
583
+ event_type="tool_output",
584
+ data={
585
+ "tool": tool_name,
586
+ "tool_call_id": tc.id,
587
+ "output": rejection_msg,
588
+ "success": False,
589
+ },
590
+ )
591
+ )
592
+
593
+ # Clear pending approval
594
+ session.pending_approval = None
595
+
596
+ # Continue agent loop with empty input to process the tool results
597
+ await Handlers.run_agent(session, "")
598
+
599
    @staticmethod
    async def shutdown(session: Session) -> bool:
        """Handle shutdown (like shutdown in codex.rs:1329).

        Optionally kicks off a detached trajectory save, marks the session
        as stopped, and emits a "shutdown" event.

        Returns:
            True (always); process_submission inverts this to stop the loop.
        """
        # Save session trajectory if enabled (fire-and-forget, returns immediately)
        if session.config.save_sessions:
            logger.info("Saving session...")
            repo_id = session.config.session_dataset_repo
            _ = session.save_and_upload_detached(repo_id)

        # Stopping the loop here also prevents the emergency save in
        # submission_loop's finally-block from firing a second save.
        session.is_running = False
        await session.send_event(Event(event_type="shutdown"))
        return True
611
+
612
+
613
async def process_submission(session: Session, submission) -> bool:
    """
    Process a single submission and return whether to continue running.

    Returns:
        bool: True to continue, False to shutdown
    """
    op = submission.operation
    logger.debug("Received operation: %s", op.op_type.value)

    payload = op.data or {}
    op_type = op.op_type

    if op_type == OpType.USER_INPUT:
        await Handlers.run_agent(session, payload.get("text", ""))
    elif op_type == OpType.INTERRUPT:
        await Handlers.interrupt(session)
    elif op_type == OpType.COMPACT:
        await Handlers.compact(session)
    elif op_type == OpType.UNDO:
        await Handlers.undo(session)
    elif op_type == OpType.EXEC_APPROVAL:
        await Handlers.exec_approval(session, payload.get("approvals", []))
    elif op_type == OpType.SHUTDOWN:
        # shutdown() returns True, so this yields False => stop the loop.
        return not await Handlers.shutdown(session)
    else:
        logger.warning(f"Unknown operation: {op.op_type}")

    return True
650
+
651
+
652
@observe(name="submission_loop")
async def submission_loop(
    submission_queue: asyncio.Queue,
    event_queue: asyncio.Queue,
    config: Config | None = None,
    tool_router: ToolRouter | None = None,
) -> None:
    """
    Main agent loop - processes submissions and dispatches to handlers.
    This is the core of the agent (like submission_loop in codex.rs:1259-1340)

    Args:
        submission_queue: Incoming client submissions (operations).
        event_queue: Outgoing events; handed to the Session.
        config: Optional agent configuration.
        tool_router: Tool router, entered as an async context manager for
            the lifetime of the loop.
    """

    # Create session with tool router
    session = Session(event_queue, config=config, tool_router=tool_router)
    logger.info("Agent loop started")

    # Retry any failed uploads from previous sessions (fire-and-forget)
    if config and config.save_sessions:
        Session.retry_failed_uploads_detached(
            directory="session_logs", repo_id=config.session_dataset_repo
        )

    try:
        # Main processing loop
        # NOTE(review): assumes tool_router is not None here despite the
        # Optional annotation — confirm all callers pass one.
        async with tool_router:
            # Emit ready event after initialization
            await session.send_event(
                Event(event_type="ready", data={"message": "Agent initialized"})
            )

            while session.is_running:
                submission = await submission_queue.get()

                try:
                    should_continue = await process_submission(session, submission)
                    if not should_continue:
                        break
                except asyncio.CancelledError:
                    # Surrounding task was cancelled mid-submission; exit cleanly.
                    logger.warning("Agent loop cancelled")
                    break
                except Exception as e:
                    # Surface handler failures to the client but keep looping.
                    logger.error(f"Error in agent loop: {e}")
                    await session.send_event(
                        Event(event_type="error", data={"error": str(e)})
                    )

        logger.info("Agent loop exited")

    finally:
        # Emergency save if session saving is enabled and shutdown wasn't called properly
        # (Handlers.shutdown sets is_running=False, so a still-True flag here
        # means we are exiting without a clean shutdown).
        if session.config.save_sessions and session.is_running:
            logger.info("Emergency save: preserving session before exit...")
            try:
                local_path = session.save_and_upload_detached(
                    session.config.session_dataset_repo
                )
                if local_path:
                    logger.info("Emergency save successful, upload in progress")
            except Exception as e:
                logger.error(f"Emergency save failed: {e}")
agent/core/session.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import subprocess
5
+ import sys
6
+ import uuid
7
+ from dataclasses import dataclass
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from pathlib import Path
11
+ from typing import Any, Optional
12
+
13
+ from agent.config import Config
14
+ from agent.context_manager.manager import ContextManager
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Local max-token lookup — avoids litellm.get_max_tokens() which can hang
19
+ # on network calls for certain providers (known litellm issue).
20
+ _MAX_TOKENS_MAP: dict[str, int] = {
21
+ # Anthropic
22
+ "anthropic/claude-opus-4-5-20251101": 200_000,
23
+ "anthropic/claude-sonnet-4-5-20250929": 200_000,
24
+ "anthropic/claude-sonnet-4-20250514": 200_000,
25
+ "anthropic/claude-haiku-3-5-20241022": 200_000,
26
+ "anthropic/claude-3-5-sonnet-20241022": 200_000,
27
+ "anthropic/claude-3-opus-20240229": 200_000,
28
+ "huggingface/novita/MiniMaxAI/MiniMax-M2.1": 196_608,
29
+ "huggingface/novita/moonshotai/Kimi-K2.5": 262_144,
30
+ "huggingface/novita/zai-org/GLM-5": 200_000,
31
+ }
32
+ _DEFAULT_MAX_TOKENS = 200_000
33
+
34
+
35
+ def _get_max_tokens_safe(model_name: str) -> int:
36
+ """Return the max context window for a model without network calls."""
37
+ tokens = _MAX_TOKENS_MAP.get(model_name)
38
+ if tokens:
39
+ return tokens
40
+ # Fallback: try litellm but with a short timeout via threading
41
+ try:
42
+ from litellm import get_max_tokens
43
+
44
+ result = get_max_tokens(model_name)
45
+ if result and isinstance(result, int):
46
+ return result
47
+ logger.warning(
48
+ f"get_max_tokens returned {result} for {model_name}, using default"
49
+ )
50
+ return _DEFAULT_MAX_TOKENS
51
+ except Exception as e:
52
+ logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}")
53
+ return _DEFAULT_MAX_TOKENS
54
+
55
+
56
class OpType(Enum):
    """Operation types a client can submit to the agent loop.

    Values are the wire-format strings carried by submissions; dispatch
    happens in process_submission().
    """

    USER_INPUT = "user_input"  # run the agent on new user text
    EXEC_APPROVAL = "exec_approval"  # approve/reject pending tool calls
    INTERRUPT = "interrupt"  # cancel the in-flight task
    UNDO = "undo"  # drop the last complete turn from history
    COMPACT = "compact"  # compact the conversation context
    SHUTDOWN = "shutdown"  # stop the agent loop
63
+
64
+
65
@dataclass
class Event:
    """Event emitted by the agent back to the client over the event queue."""

    event_type: str  # e.g. "ready", "tool_call", "compacted", "error", "shutdown"
    data: Optional[dict[str, Any]] = None  # event-specific payload, if any
69
+
70
+
71
class Session:
    """
    Maintains agent session state: config, conversation context, pending
    approvals, and a trajectory log that can be saved locally and uploaded
    to a HuggingFace dataset by a detached subprocess.
    Similar to Session in codex-rs/core/src/codex.rs
    """

    def __init__(
        self,
        event_queue: asyncio.Queue,
        config: Config | None = None,
        tool_router=None,
        context_manager: ContextManager | None = None,
    ):
        """Initialize session state.

        Args:
            event_queue: Queue used to push Events back to the client.
            config: Agent configuration; a default Config is created if None.
            tool_router: Router providing tool specs / execution (optional).
            context_manager: Existing context manager to reuse, or None to
                create one sized to the model's context window.
        """
        # BUGFIX: resolve the config default FIRST. The original code read
        # config.model_name (below) before assigning self.config, so passing
        # config=None without a context_manager raised AttributeError — the
        # exact call shape submission_loop uses.
        self.config = config or Config(
            model_name="anthropic/claude-sonnet-4-5-20250929",
        )
        self.tool_router = tool_router
        tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else []
        self.context_manager = context_manager or ContextManager(
            max_context=_get_max_tokens_safe(self.config.model_name),
            compact_size=0.1,
            untouched_messages=5,
            tool_specs=tool_specs,
        )
        self.event_queue = event_queue
        self.session_id = str(uuid.uuid4())
        self.is_running = True
        self.current_task: asyncio.Task | None = None
        self.pending_approval: Optional[dict[str, Any]] = None
        # User's HF OAuth token — set by session_manager after construction
        self.hf_token: Optional[str] = None

        # Session trajectory logging
        self.logged_events: list[dict] = []
        self.session_start_time = datetime.now().isoformat()
        self.turn_count: int = 0
        self.last_auto_save_turn: int = 0

    async def send_event(self, event: Event) -> None:
        """Send event back to client and record it in the trajectory log."""
        await self.event_queue.put(event)

        # Log event to trajectory
        self.logged_events.append(
            {
                "timestamp": datetime.now().isoformat(),
                "event_type": event.event_type,
                "data": event.data,
            }
        )

    def interrupt(self) -> None:
        """Interrupt current running task (no-op when nothing is running)."""
        if self.current_task and not self.current_task.done():
            self.current_task.cancel()

    def increment_turn(self) -> None:
        """Increment turn counter (called after each user interaction)."""
        self.turn_count += 1

    async def auto_save_if_needed(self) -> None:
        """Check if auto-save should trigger and save if so (completely non-blocking)."""
        if not self.config.save_sessions:
            return

        interval = self.config.auto_save_interval
        if interval <= 0:
            # Non-positive interval disables auto-save.
            return

        turns_since_last_save = self.turn_count - self.last_auto_save_turn
        if turns_since_last_save >= interval:
            logger.info(f"Auto-saving session (turn {self.turn_count})...")
            # Fire-and-forget save - returns immediately
            self.save_and_upload_detached(self.config.session_dataset_repo)
            self.last_auto_save_turn = self.turn_count

    def get_trajectory(self) -> dict:
        """Serialize complete session trajectory for logging.

        Returns:
            Dict with session metadata, the full message history, and all
            logged events.
        """
        return {
            "session_id": self.session_id,
            "session_start_time": self.session_start_time,
            "session_end_time": datetime.now().isoformat(),
            "model_name": self.config.model_name,
            "messages": [msg.model_dump() for msg in self.context_manager.items],
            "events": self.logged_events,
        }

    def save_trajectory_local(
        self,
        directory: str = "session_logs",
        upload_status: str = "pending",
        dataset_url: Optional[str] = None,
    ) -> Optional[str]:
        """
        Save trajectory to local JSON file as backup with upload status

        Args:
            directory: Directory to save logs (default: "session_logs")
            upload_status: Status of upload attempt ("pending", "success", "failed")
            dataset_url: URL of dataset if upload succeeded

        Returns:
            Path to saved file if successful, None otherwise
        """
        try:
            log_dir = Path(directory)
            log_dir.mkdir(parents=True, exist_ok=True)

            trajectory = self.get_trajectory()

            # Add upload metadata so the uploader subprocess can track state.
            trajectory["upload_status"] = upload_status
            trajectory["upload_url"] = dataset_url
            trajectory["last_save_time"] = datetime.now().isoformat()

            filename = f"session_{self.session_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            filepath = log_dir / filename

            with open(filepath, "w") as f:
                json.dump(trajectory, f, indent=2)

            return str(filepath)
        except Exception as e:
            # Best-effort: a failed local save is logged, never raised.
            logger.error(f"Failed to save session locally: {e}")
            return None

    def save_and_upload_detached(self, repo_id: str) -> Optional[str]:
        """
        Save session locally and spawn detached subprocess for upload (fire-and-forget)

        Args:
            repo_id: HuggingFace dataset repo ID

        Returns:
            Path to local save file, or None if the local save failed.
        """
        # Save locally first (fast, synchronous)
        local_path = self.save_trajectory_local(upload_status="pending")
        if not local_path:
            return None

        # Spawn detached subprocess for upload (fire-and-forget)
        try:
            uploader_script = Path(__file__).parent / "session_uploader.py"

            # Detached process: no inherited stdio, own session, not waited on.
            subprocess.Popen(
                [sys.executable, str(uploader_script), "upload", local_path, repo_id],
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True,  # Detach from parent
            )
        except Exception as e:
            # Upload is best-effort; the local file keeps status "pending"
            # so a later retry pass can pick it up.
            logger.warning(f"Failed to spawn upload subprocess: {e}")

        return local_path

    @staticmethod
    def retry_failed_uploads_detached(
        directory: str = "session_logs", repo_id: Optional[str] = None
    ) -> None:
        """
        Spawn detached subprocess to retry failed/pending uploads (fire-and-forget)

        Args:
            directory: Directory containing session logs
            repo_id: Target dataset repo ID; no-op when missing.
        """
        if not repo_id:
            return

        try:
            uploader_script = Path(__file__).parent / "session_uploader.py"

            # Spawn detached subprocess for retry
            subprocess.Popen(
                [sys.executable, str(uploader_script), "retry", directory, repo_id],
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True,  # Detach from parent
            )
        except Exception as e:
            logger.warning(f"Failed to spawn retry subprocess: {e}")
agent/core/session_uploader.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Standalone script for uploading session trajectories to HuggingFace.
4
+ This runs as a separate process to avoid blocking the main agent.
5
+ Uses individual file uploads to avoid race conditions.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import sys
11
+ from datetime import datetime
12
+ from pathlib import Path
13
+
14
+ from dotenv import load_dotenv
15
+
16
load_dotenv()  # pull HF_SESSION_UPLOAD_TOKEN (and friends) from a local .env

# Token for session uploads — loaded from env var (never hardcode tokens in source)
# An empty token makes upload_session_as_file() mark the session "failed".
_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "")
20
+
21
+
22
def upload_session_as_file(
    session_file: str, repo_id: str, max_retries: int = 3
) -> bool:
    """
    Upload a single session as an individual JSONL file (no race conditions)

    Each session becomes its own file under sessions/YYYY-MM-DD/, so
    concurrent uploader processes never rewrite a shared file. The local
    session JSON is updated in place with the upload outcome
    ("success"/"failed") so retries can skip completed work.

    Args:
        session_file: Path to local session JSON file
        repo_id: HuggingFace dataset repo ID
        max_retries: Number of retry attempts

    Returns:
        True if successful, False otherwise
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        print("Error: huggingface_hub library not available", file=sys.stderr)
        return False

    try:
        # Load session data
        with open(session_file, "r") as f:
            data = json.load(f)

        # Check if already uploaded (idempotent retries)
        upload_status = data.get("upload_status")
        if upload_status == "success":
            return True

        # Use dedicated session upload token (write-only access to session dataset)
        hf_token = _SESSION_TOKEN
        if not hf_token:
            # No token configured: mark failed so the retry pass can try again
            # once the environment is fixed.
            data["upload_status"] = "failed"
            with open(session_file, "w") as f:
                json.dump(data, f, indent=2)
            return False

        # Prepare JSONL content (single line)
        # Store messages and events as JSON strings to avoid schema conflicts
        session_row = {
            "session_id": data["session_id"],
            "session_start_time": data["session_start_time"],
            "session_end_time": data["session_end_time"],
            "model_name": data["model_name"],
            "messages": json.dumps(data["messages"]),
            "events": json.dumps(data["events"]),
        }

        # Create temporary JSONL file
        import tempfile

        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".jsonl", delete=False
        ) as tmp:
            json.dump(session_row, tmp)  # Single line JSON
            tmp_path = tmp.name

        try:
            # Generate unique path in repo: sessions/YYYY-MM-DD/session_id.jsonl
            session_id = data["session_id"]
            date_str = datetime.fromisoformat(data["session_start_time"]).strftime(
                "%Y-%m-%d"
            )
            repo_path = f"sessions/{date_str}/{session_id}.jsonl"

            # Upload with retries
            api = HfApi()
            for attempt in range(max_retries):
                try:
                    # Try to create repo if it doesn't exist (idempotent)
                    try:
                        api.create_repo(
                            repo_id=repo_id,
                            repo_type="dataset",
                            private=False,
                            token=hf_token,
                            exist_ok=True,  # Don't fail if already exists
                        )

                    except Exception:
                        # Repo might already exist, continue
                        pass

                    # Upload the session file
                    api.upload_file(
                        path_or_fileobj=tmp_path,
                        path_in_repo=repo_path,
                        repo_id=repo_id,
                        repo_type="dataset",
                        token=hf_token,
                        commit_message=f"Add session {session_id}",
                    )

                    # Update local status to success
                    data["upload_status"] = "success"
                    data["upload_url"] = f"https://huggingface.co/datasets/{repo_id}"
                    with open(session_file, "w") as f:
                        json.dump(data, f, indent=2)

                    return True

                except Exception:
                    if attempt < max_retries - 1:
                        import time

                        # Exponential backoff: 1s, 2s, 4s, ...
                        wait_time = 2**attempt
                        time.sleep(wait_time)
                    else:
                        # Final attempt failed: persist "failed" for retry pass
                        data["upload_status"] = "failed"
                        with open(session_file, "w") as f:
                            json.dump(data, f, indent=2)
                        return False

        finally:
            # Clean up temp file
            try:
                os.unlink(tmp_path)
            except Exception:
                pass

    except Exception as e:
        # Catch-all boundary for a standalone CLI process: report and exit False.
        print(f"Error uploading session: {e}", file=sys.stderr)
        return False
148
+
149
+
150
def retry_failed_uploads(directory: str, repo_id: str):
    """Re-attempt the upload of every session file still marked pending/failed.

    Silently skips unreadable files; sessions already marked "success" are
    left untouched.
    """
    base = Path(directory)
    if not base.exists():
        return

    for candidate in base.glob("session_*.json"):
        try:
            with open(candidate, "r") as fh:
                payload = json.load(fh)
            # Only retry pending or failed uploads
            if payload.get("upload_status", "unknown") in ("pending", "failed"):
                upload_session_as_file(str(candidate), repo_id)
        except Exception:
            # Best-effort pass: one bad file must not stop the others.
            pass
171
+
172
+
173
if __name__ == "__main__":
    # CLI entry point. Exit codes: 0 on success, 1 on usage error or failure.
    argv = sys.argv
    if len(argv) < 3:
        print("Usage: session_uploader.py <command> <args...>")
        sys.exit(1)

    cmd = argv[1]

    if cmd == "upload":
        # python session_uploader.py upload <session_file> <repo_id>
        if len(argv) < 4:
            print("Usage: session_uploader.py upload <session_file> <repo_id>")
            sys.exit(1)
        sys.exit(0 if upload_session_as_file(argv[2], argv[3]) else 1)

    elif cmd == "retry":
        # python session_uploader.py retry <directory> <repo_id>
        if len(argv) < 4:
            print("Usage: session_uploader.py retry <directory> <repo_id>")
            sys.exit(1)
        retry_failed_uploads(argv[2], argv[3])
        sys.exit(0)

    else:
        print(f"Unknown command: {cmd}")
        sys.exit(1)
agent/core/tools.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool system for the agent
3
+ Provides ToolSpec and ToolRouter for managing both built-in and MCP tools
4
+ """
5
+
6
+ import logging
7
+ import warnings
8
+ from dataclasses import dataclass
9
+ from typing import Any, Awaitable, Callable, Optional
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ from fastmcp import Client
14
+ from fastmcp.exceptions import ToolError
15
+ from lmnr import observe
16
+ from mcp.types import EmbeddedResource, ImageContent, TextContent
17
+
18
+ from agent.config import MCPServerConfig
19
+ from agent.tools.dataset_tools import (
20
+ HF_INSPECT_DATASET_TOOL_SPEC,
21
+ hf_inspect_dataset_handler,
22
+ )
23
+ from agent.tools.docs_tools import (
24
+ EXPLORE_HF_DOCS_TOOL_SPEC,
25
+ HF_DOCS_FETCH_TOOL_SPEC,
26
+ explore_hf_docs_handler,
27
+ hf_docs_fetch_handler,
28
+ )
29
+ from agent.tools.github_find_examples import (
30
+ GITHUB_FIND_EXAMPLES_TOOL_SPEC,
31
+ github_find_examples_handler,
32
+ )
33
+ from agent.tools.github_list_repos import (
34
+ GITHUB_LIST_REPOS_TOOL_SPEC,
35
+ github_list_repos_handler,
36
+ )
37
+ from agent.tools.github_read_file import (
38
+ GITHUB_READ_FILE_TOOL_SPEC,
39
+ github_read_file_handler,
40
+ )
41
+ from agent.tools.hf_repo_files_tool import (
42
+ HF_REPO_FILES_TOOL_SPEC,
43
+ hf_repo_files_handler,
44
+ )
45
+ from agent.tools.hf_repo_git_tool import (
46
+ HF_REPO_GIT_TOOL_SPEC,
47
+ hf_repo_git_handler,
48
+ )
49
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, hf_jobs_handler
50
+ from agent.tools.plan_tool import PLAN_TOOL_SPEC, plan_tool_handler
51
+
52
+ # NOTE: Private HF repo tool disabled - replaced by hf_repo_files and hf_repo_git
53
+ # from agent.tools.private_hf_repo_tools import (
54
+ # PRIVATE_HF_REPO_TOOL_SPEC,
55
+ # private_hf_repo_handler,
56
+ # )
57
+
58
# Suppress aiohttp deprecation warning
warnings.filterwarnings(
    "ignore", category=DeprecationWarning, module="aiohttp.connector"
)

# MCP-provided tools skipped during registration (see
# ToolRouter.register_mcp_tools). Built-in tools cover hf_jobs and the
# doc search/fetch capabilities.
NOT_ALLOWED_TOOL_NAMES = ["hf_jobs", "hf_doc_search", "hf_doc_fetch", "hf_whoami"]
64
+
65
+
66
def convert_mcp_content_to_string(content: list) -> str:
    """
    Convert MCP content blocks to a string format compatible with LLM messages.

    Handles the three MCP block types:
    - TextContent: rendered as its .text
    - ImageContent: rendered as an "[Image: <mime>]" placeholder
    - EmbeddedResource: rendered as its embedded text, or a placeholder for
      binary/unknown resources
    Anything else falls back to str().

    Args:
        content: List of MCP content blocks

    Returns:
        String representation of the content suitable for LLM consumption
    """
    if not content:
        return ""

    def _render(block) -> str:
        """Render a single MCP content block as display text."""
        if isinstance(block, TextContent):
            # Plain text blocks carry their payload directly.
            return block.text
        if isinstance(block, ImageContent):
            # TODO: Handle images properly instead of a placeholder
            return f"[Image: {block.mimeType}]"
        if isinstance(block, EmbeddedResource):
            # TODO: Handle embedded resources properly
            resource = block.resource
            if hasattr(resource, "text") and resource.text:
                return resource.text
            if hasattr(resource, "blob") and resource.blob:
                return f"[Binary data: {resource.mimeType if hasattr(resource, 'mimeType') else 'unknown'}]"
            return f"[Resource: {resource.uri if hasattr(resource, 'uri') else 'unknown'}]"
        # Fallback: try to convert to string
        return str(block)

    return "\n".join(_render(block) for block in content)
112
+
113
+
114
@dataclass
class ToolSpec:
    """Tool specification for LLM"""

    name: str  # unique tool name exposed to the model
    description: str  # model-facing description of what the tool does
    parameters: dict[str, Any]  # JSON-schema of the tool's arguments
    # Async callable returning (output, success); None for MCP-routed tools,
    # which ToolRouter.call_tool forwards to the MCP client instead.
    handler: Optional[Callable[[dict[str, Any]], Awaitable[tuple[str, bool]]]] = None
122
+
123
+
124
class ToolRouter:
    """
    Routes tool calls to appropriate handlers.
    Based on codex-rs/core/src/tools/router.rs

    Built-in tools carry their own async handler; MCP tools are registered
    with handler=None and routed through the shared fastmcp Client. The
    router is used as an async context manager: entering connects the MCP
    client (if any) and registers its tools.
    """

    def __init__(self, mcp_servers: dict[str, MCPServerConfig]):
        """Register built-in tools and (optionally) prepare an MCP client.

        Args:
            mcp_servers: Mapping of server name -> MCPServerConfig. Empty or
                falsy means no MCP client is created.
        """
        self.tools: dict[str, ToolSpec] = {}
        self.mcp_servers: dict[str, dict[str, Any]] = {}

        for tool in create_builtin_tools():
            self.register_tool(tool)

        self.mcp_client: Client | None = None
        if mcp_servers:
            # fastmcp's Client expects {"mcpServers": {name: config-dict}}.
            mcp_servers_payload = {}
            for name, server in mcp_servers.items():
                mcp_servers_payload[name] = server.model_dump()
            self.mcp_client = Client({"mcpServers": mcp_servers_payload})
        self._mcp_initialized = False

    def register_tool(self, tool: ToolSpec) -> None:
        """Register (or overwrite) a tool spec under its name."""
        self.tools[tool.name] = tool

    async def register_mcp_tools(self) -> None:
        """Discover tools from the MCP client and register the allowed ones.

        Tools listed in NOT_ALLOWED_TOOL_NAMES are skipped; the rest are
        registered with handler=None so call_tool() routes them via MCP.
        """
        tools = await self.mcp_client.list_tools()
        registered_names = []
        skipped_count = 0
        for tool in tools:
            if tool.name in NOT_ALLOWED_TOOL_NAMES:
                skipped_count += 1
                continue
            registered_names.append(tool.name)
            self.register_tool(
                ToolSpec(
                    name=tool.name,
                    description=tool.description,
                    parameters=tool.inputSchema,
                    handler=None,
                )
            )
        logger.info(
            f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)"
        )

    async def register_openapi_tool(self) -> None:
        """Register the OpenAPI search tool (requires async initialization)"""
        # Imported lazily to avoid pulling docs_tools' async setup at import time.
        from agent.tools.docs_tools import (
            _get_api_search_tool_spec,
            search_openapi_handler,
        )

        # Register search_hf_api_endpoints with dynamic spec
        openapi_spec = await _get_api_search_tool_spec()
        self.register_tool(
            ToolSpec(
                name=openapi_spec["name"],
                description=openapi_spec["description"],
                parameters=openapi_spec["parameters"],
                handler=search_openapi_handler,
            )
        )
        logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}")

    def get_tool_specs_for_llm(self) -> list[dict[str, Any]]:
        """Get tool specifications in OpenAI format"""
        specs = []
        for tool in self.tools.values():
            specs.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters,
                    },
                }
            )
        return specs

    async def __aenter__(self) -> "ToolRouter":
        """Connect the MCP client (if configured) and finish tool registration."""
        if self.mcp_client is not None:
            await self.mcp_client.__aenter__()
            await self.mcp_client.initialize()
            await self.register_mcp_tools()
            self._mcp_initialized = True

        # Register OpenAPI tool (requires async initialization)
        await self.register_openapi_tool()

        total_tools = len(self.tools)
        logger.info(f"Agent ready with {total_tools} tools total")

        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Tear down the MCP client connection, if one was opened."""
        if self.mcp_client is not None:
            await self.mcp_client.__aexit__(exc_type, exc, tb)
            self._mcp_initialized = False

    @observe(name="call_tool")
    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any], session: Any = None
    ) -> tuple[str, bool]:
        """
        Call a tool and return (output_string, success_bool).

        For MCP tools, converts the CallToolResult content blocks to a string.
        For built-in tools, calls their handler directly.
        """
        # Check if this is a built-in tool with a handler
        tool = self.tools.get(tool_name)
        if tool and tool.handler:
            import inspect

            # Check if handler accepts session argument; only pass it to
            # handlers whose signatures declare it.
            sig = inspect.signature(tool.handler)
            if "session" in sig.parameters:
                return await tool.handler(arguments, session=session)
            return await tool.handler(arguments)

        # Otherwise, use MCP client
        if self._mcp_initialized:
            try:
                result = await self.mcp_client.call_tool(tool_name, arguments)
                output = convert_mcp_content_to_string(result.content)
                return output, not result.is_error
            except ToolError as e:
                # Catch MCP tool errors and return them to the agent
                error_msg = f"Tool error: {str(e)}"
                return error_msg, False

        # Unknown tool name and no active MCP client.
        return "MCP client not initialized", False
257
+
258
+
259
+ # ============================================================================
260
+ # BUILT-IN TOOL HANDLERS
261
+ # ============================================================================
262
+
263
+
264
def create_builtin_tools() -> list[ToolSpec]:
    """Create built-in tool specifications"""
    # in order of importance: docs exploration/fetch, dataset inspection,
    # planning, job management, HF repo tools, then GitHub example tools.
    spec_handler_pairs = [
        (EXPLORE_HF_DOCS_TOOL_SPEC, explore_hf_docs_handler),
        (HF_DOCS_FETCH_TOOL_SPEC, hf_docs_fetch_handler),
        (HF_INSPECT_DATASET_TOOL_SPEC, hf_inspect_dataset_handler),
        (PLAN_TOOL_SPEC, plan_tool_handler),
        (HF_JOBS_TOOL_SPEC, hf_jobs_handler),
        (HF_REPO_FILES_TOOL_SPEC, hf_repo_files_handler),
        (HF_REPO_GIT_TOOL_SPEC, hf_repo_git_handler),
        (GITHUB_FIND_EXAMPLES_TOOL_SPEC, github_find_examples_handler),
        (GITHUB_LIST_REPOS_TOOL_SPEC, github_list_repos_handler),
        (GITHUB_READ_FILE_TOOL_SPEC, github_read_file_handler),
    ]

    tools = [
        ToolSpec(
            name=spec["name"],
            description=spec["description"],
            parameters=spec["parameters"],
            handler=handler,
        )
        for spec, handler in spec_handler_pairs
    ]

    tool_names = ", ".join([t.name for t in tools])
    logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}")

    return tools
agent/main.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Interactive CLI chat with the agent
3
+ """
4
+
5
+ import asyncio
6
+ import json
7
+ import os
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Optional
11
+
12
+ import litellm
13
+ from lmnr import Laminar, LaminarLiteLLMCallback
14
+ from prompt_toolkit import PromptSession
15
+
16
+ from agent.config import load_config
17
+ from agent.core.agent_loop import submission_loop
18
+ from agent.core.session import OpType
19
+ from agent.core.tools import ToolRouter
20
+ from agent.utils.reliability_checks import check_training_script_save_pattern
21
+ from agent.utils.terminal_display import (
22
+ format_error,
23
+ format_header,
24
+ format_plan_display,
25
+ format_separator,
26
+ format_success,
27
+ format_tool_call,
28
+ format_tool_output,
29
+ format_turn_complete,
30
+ )
31
+
32
+ litellm.drop_params = True
33
+
34
+
35
+ def _safe_get_args(arguments: dict) -> dict:
36
+ """Safely extract args dict from arguments, handling cases where LLM passes string."""
37
+ args = arguments.get("args", {})
38
+ # Sometimes LLM passes args as string instead of dict
39
+ if isinstance(args, str):
40
+ return {}
41
+ return args if isinstance(args, dict) else {}
42
+
43
+
44
+ lmnr_api_key = os.environ.get("LMNR_API_KEY")
45
+ if lmnr_api_key:
46
+ try:
47
+ Laminar.initialize(project_api_key=lmnr_api_key)
48
+ litellm.callbacks = [LaminarLiteLLMCallback()]
49
+ print("Laminar initialized")
50
+ except Exception as e:
51
+ print(f"Failed to initialize Laminar: {e}")
52
+
53
+
54
+ @dataclass
55
+ class Operation:
56
+ """Operation to be executed by the agent"""
57
+
58
+ op_type: OpType
59
+ data: Optional[dict[str, Any]] = None
60
+
61
+
62
+ @dataclass
63
+ class Submission:
64
+ """Submission to the agent loop"""
65
+
66
+ id: str
67
+ operation: Operation
68
+
69
+
70
+ async def event_listener(
71
+ event_queue: asyncio.Queue,
72
+ submission_queue: asyncio.Queue,
73
+ turn_complete_event: asyncio.Event,
74
+ ready_event: asyncio.Event,
75
+ prompt_session: PromptSession,
76
+ config=None,
77
+ ) -> None:
78
+ """Background task that listens for events and displays them"""
79
+ submission_id = [1000] # Use list to make it mutable in closure
80
+ last_tool_name = [None] # Track last tool called
81
+
82
+ while True:
83
+ try:
84
+ event = await event_queue.get()
85
+
86
+ # Display event
87
+ if event.event_type == "ready":
88
+ print(format_success("\U0001f917 Agent ready"))
89
+ ready_event.set()
90
+ elif event.event_type == "assistant_message":
91
+ content = event.data.get("content", "") if event.data else ""
92
+ if content:
93
+ print(f"\nAssistant: {content}")
94
+ elif event.event_type == "tool_call":
95
+ tool_name = event.data.get("tool", "") if event.data else ""
96
+ arguments = event.data.get("arguments", {}) if event.data else {}
97
+ if tool_name:
98
+ last_tool_name[0] = tool_name # Store for tool_output event
99
+ args_str = json.dumps(arguments)[:100] + "..."
100
+ print(format_tool_call(tool_name, args_str))
101
+ elif event.event_type == "tool_output":
102
+ output = event.data.get("output", "") if event.data else ""
103
+ success = event.data.get("success", False) if event.data else False
104
+ if output:
105
+ # Don't truncate plan_tool output, truncate everything else
106
+ should_truncate = last_tool_name[0] != "plan_tool"
107
+ print(format_tool_output(output, success, truncate=should_truncate))
108
+ elif event.event_type == "turn_complete":
109
+ print(format_turn_complete())
110
+ # Display plan after turn complete
111
+ plan_display = format_plan_display()
112
+ if plan_display:
113
+ print(plan_display)
114
+ turn_complete_event.set()
115
+ elif event.event_type == "error":
116
+ error = (
117
+ event.data.get("error", "Unknown error")
118
+ if event.data
119
+ else "Unknown error"
120
+ )
121
+ print(format_error(error))
122
+ turn_complete_event.set()
123
+ elif event.event_type == "shutdown":
124
+ break
125
+ elif event.event_type == "processing":
126
+ pass # print("Processing...", flush=True)
127
+ elif event.event_type == "compacted":
128
+ old_tokens = event.data.get("old_tokens", 0) if event.data else 0
129
+ new_tokens = event.data.get("new_tokens", 0) if event.data else 0
130
+ print(f"Compacted context: {old_tokens} → {new_tokens} tokens")
131
+ elif event.event_type == "approval_required":
132
+ # Handle batch approval format
133
+ tools_data = event.data.get("tools", []) if event.data else []
134
+ count = event.data.get("count", 0) if event.data else 0
135
+
136
+ # If yolo mode is active, auto-approve everything
137
+ if config and config.yolo_mode:
138
+ approvals = [
139
+ {
140
+ "tool_call_id": t.get("tool_call_id", ""),
141
+ "approved": True,
142
+ "feedback": None,
143
+ }
144
+ for t in tools_data
145
+ ]
146
+ print(f"\n⚡ YOLO MODE: Auto-approving {count} item(s)")
147
+ submission_id[0] += 1
148
+ approval_submission = Submission(
149
+ id=f"approval_{submission_id[0]}",
150
+ operation=Operation(
151
+ op_type=OpType.EXEC_APPROVAL,
152
+ data={"approvals": approvals},
153
+ ),
154
+ )
155
+ await submission_queue.put(approval_submission)
156
+ continue
157
+
158
+ print("\n" + format_separator())
159
+ print(
160
+ format_header(
161
+ f"APPROVAL REQUIRED ({count} item{'s' if count != 1 else ''})"
162
+ )
163
+ )
164
+ print(format_separator())
165
+
166
+ approvals = []
167
+
168
+ # Ask for approval for each tool
169
+ for i, tool_info in enumerate(tools_data, 1):
170
+ tool_name = tool_info.get("tool", "")
171
+ arguments = tool_info.get("arguments", {})
172
+ tool_call_id = tool_info.get("tool_call_id", "")
173
+
174
+ # Handle case where arguments might be a JSON string
175
+ if isinstance(arguments, str):
176
+ try:
177
+ arguments = json.loads(arguments)
178
+ except json.JSONDecodeError:
179
+ print(f"Warning: Failed to parse arguments for {tool_name}")
180
+ arguments = {}
181
+
182
+ operation = arguments.get("operation", "")
183
+
184
+ print(f"\n[Item {i}/{count}]")
185
+ print(f"Tool: {tool_name}")
186
+ print(f"Operation: {operation}")
187
+
188
+ # Handle different tool types
189
+ if tool_name == "hf_jobs":
190
+ # Check if this is Python mode (script) or Docker mode (command)
191
+ script = arguments.get("script")
192
+ command = arguments.get("command")
193
+
194
+ if script:
195
+ # Python mode
196
+ dependencies = arguments.get("dependencies", [])
197
+ python_version = arguments.get("python")
198
+ script_args = arguments.get("script_args", [])
199
+
200
+ # Show full script
201
+ print(f"Script:\n{script}")
202
+ if dependencies:
203
+ print(f"Dependencies: {', '.join(dependencies)}")
204
+ if python_version:
205
+ print(f"Python version: {python_version}")
206
+ if script_args:
207
+ print(f"Script args: {' '.join(script_args)}")
208
+
209
+ # Run reliability checks on the full script (not truncated)
210
+ check_message = check_training_script_save_pattern(script)
211
+ if check_message:
212
+ print(check_message)
213
+ elif command:
214
+ # Docker mode
215
+ image = arguments.get("image", "python:3.12")
216
+ command_str = (
217
+ " ".join(command)
218
+ if isinstance(command, list)
219
+ else str(command)
220
+ )
221
+ print(f"Docker image: {image}")
222
+ print(f"Command: {command_str}")
223
+
224
+ # Common parameters for jobs
225
+ hardware_flavor = arguments.get("hardware_flavor", "cpu-basic")
226
+ timeout = arguments.get("timeout", "30m")
227
+ env = arguments.get("env", {})
228
+ schedule = arguments.get("schedule")
229
+
230
+ print(f"Hardware: {hardware_flavor}")
231
+ print(f"Timeout: {timeout}")
232
+
233
+ if env:
234
+ env_keys = ", ".join(env.keys())
235
+ print(f"Environment variables: {env_keys}")
236
+
237
+ if schedule:
238
+ print(f"Schedule: {schedule}")
239
+
240
+ elif tool_name == "hf_private_repos":
241
+ # Handle private repo operations
242
+ args = _safe_get_args(arguments)
243
+
244
+ if operation in ["create_repo", "upload_file"]:
245
+ repo_id = args.get("repo_id", "")
246
+ repo_type = args.get("repo_type", "dataset")
247
+
248
+ # Build repo URL
249
+ type_path = "" if repo_type == "model" else f"{repo_type}s"
250
+ repo_url = (
251
+ f"https://huggingface.co/{type_path}/{repo_id}".replace(
252
+ "//", "/"
253
+ )
254
+ )
255
+
256
+ print(f"Repository: {repo_id}")
257
+ print(f"Type: {repo_type}")
258
+ print("Private: Yes")
259
+ print(f"URL: {repo_url}")
260
+
261
+ # Show file preview for upload_file operation
262
+ if operation == "upload_file":
263
+ path_in_repo = args.get("path_in_repo", "")
264
+ file_content = args.get("file_content", "")
265
+ print(f"File: {path_in_repo}")
266
+
267
+ if isinstance(file_content, str):
268
+ # Calculate metrics
269
+ all_lines = file_content.split("\n")
270
+ line_count = len(all_lines)
271
+ size_bytes = len(file_content.encode("utf-8"))
272
+ size_kb = size_bytes / 1024
273
+ size_mb = size_kb / 1024
274
+
275
+ print(f"Line count: {line_count}")
276
+ if size_kb < 1024:
277
+ print(f"Size: {size_kb:.2f} KB")
278
+ else:
279
+ print(f"Size: {size_mb:.2f} MB")
280
+
281
+ # Show preview
282
+ preview_lines = all_lines[:5]
283
+ preview = "\n".join(preview_lines)
284
+ print(
285
+ f"Content preview (first 5 lines):\n{preview}"
286
+ )
287
+ if len(all_lines) > 5:
288
+ print("...")
289
+
290
+ elif tool_name == "hf_repo_files":
291
+ # Handle repo files operations (upload, delete)
292
+ repo_id = arguments.get("repo_id", "")
293
+ repo_type = arguments.get("repo_type", "model")
294
+ revision = arguments.get("revision", "main")
295
+
296
+ # Build repo URL
297
+ if repo_type == "model":
298
+ repo_url = f"https://huggingface.co/{repo_id}"
299
+ else:
300
+ repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
301
+
302
+ print(f"Repository: {repo_id}")
303
+ print(f"Type: {repo_type}")
304
+ print(f"Branch: {revision}")
305
+ print(f"URL: {repo_url}")
306
+
307
+ if operation == "upload":
308
+ path = arguments.get("path", "")
309
+ content = arguments.get("content", "")
310
+ create_pr = arguments.get("create_pr", False)
311
+
312
+ print(f"File: {path}")
313
+ if create_pr:
314
+ print("Mode: Create PR")
315
+
316
+ if isinstance(content, str):
317
+ all_lines = content.split("\n")
318
+ line_count = len(all_lines)
319
+ size_bytes = len(content.encode("utf-8"))
320
+ size_kb = size_bytes / 1024
321
+
322
+ print(f"Lines: {line_count}")
323
+ if size_kb < 1024:
324
+ print(f"Size: {size_kb:.2f} KB")
325
+ else:
326
+ print(f"Size: {size_kb / 1024:.2f} MB")
327
+
328
+ # Show full content
329
+ print(f"Content:\n{content}")
330
+
331
+ elif operation == "delete":
332
+ patterns = arguments.get("patterns", [])
333
+ if isinstance(patterns, str):
334
+ patterns = [patterns]
335
+ print(f"Patterns to delete: {', '.join(patterns)}")
336
+
337
+ elif tool_name == "hf_repo_git":
338
+ # Handle git operations (branches, tags, PRs, repo management)
339
+ repo_id = arguments.get("repo_id", "")
340
+ repo_type = arguments.get("repo_type", "model")
341
+
342
+ # Build repo URL
343
+ if repo_type == "model":
344
+ repo_url = f"https://huggingface.co/{repo_id}"
345
+ else:
346
+ repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}"
347
+
348
+ print(f"Repository: {repo_id}")
349
+ print(f"Type: {repo_type}")
350
+ print(f"URL: {repo_url}")
351
+
352
+ if operation == "delete_branch":
353
+ branch = arguments.get("branch", "")
354
+ print(f"Branch to delete: {branch}")
355
+
356
+ elif operation == "delete_tag":
357
+ tag = arguments.get("tag", "")
358
+ print(f"Tag to delete: {tag}")
359
+
360
+ elif operation == "merge_pr":
361
+ pr_num = arguments.get("pr_num", "")
362
+ print(f"PR to merge: #{pr_num}")
363
+
364
+ elif operation == "create_repo":
365
+ private = arguments.get("private", False)
366
+ space_sdk = arguments.get("space_sdk")
367
+ print(f"Private: {private}")
368
+ if space_sdk:
369
+ print(f"Space SDK: {space_sdk}")
370
+
371
+ elif operation == "update_repo":
372
+ private = arguments.get("private")
373
+ gated = arguments.get("gated")
374
+ if private is not None:
375
+ print(f"Private: {private}")
376
+ if gated is not None:
377
+ print(f"Gated: {gated}")
378
+
379
+ # Get user decision for this item
380
+ response = await prompt_session.prompt_async(
381
+ f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
382
+ )
383
+
384
+ response = response.strip().lower()
385
+
386
+ # Handle yolo mode activation
387
+ if response == "yolo":
388
+ config.yolo_mode = True
389
+ print(
390
+ "⚡ YOLO MODE ACTIVATED - Auto-approving all future tool calls"
391
+ )
392
+ # Auto-approve this item and all remaining
393
+ approvals.append(
394
+ {
395
+ "tool_call_id": tool_call_id,
396
+ "approved": True,
397
+ "feedback": None,
398
+ }
399
+ )
400
+ for remaining in tools_data[i:]:
401
+ approvals.append(
402
+ {
403
+ "tool_call_id": remaining.get("tool_call_id", ""),
404
+ "approved": True,
405
+ "feedback": None,
406
+ }
407
+ )
408
+ break
409
+
410
+ approved = response in ["y", "yes"]
411
+ feedback = None if approved or response in ["n", "no"] else response
412
+
413
+ approvals.append(
414
+ {
415
+ "tool_call_id": tool_call_id,
416
+ "approved": approved,
417
+ "feedback": feedback,
418
+ }
419
+ )
420
+
421
+ # Submit batch approval
422
+ submission_id[0] += 1
423
+ approval_submission = Submission(
424
+ id=f"approval_{submission_id[0]}",
425
+ operation=Operation(
426
+ op_type=OpType.EXEC_APPROVAL,
427
+ data={"approvals": approvals},
428
+ ),
429
+ )
430
+ await submission_queue.put(approval_submission)
431
+ print(format_separator() + "\n")
432
+ # Silently ignore other events
433
+
434
+ except asyncio.CancelledError:
435
+ break
436
+ except Exception as e:
437
+ print(f"Event listener error: {e}")
438
+
439
+
440
+ async def get_user_input(prompt_session: PromptSession) -> str:
441
+ """Get user input asynchronously"""
442
+ from prompt_toolkit.formatted_text import HTML
443
+
444
+ return await prompt_session.prompt_async(HTML("\n<b><cyan>></cyan></b> "))
445
+
446
+
447
+ async def main():
448
+ """Interactive chat with the agent"""
449
+ from agent.utils.terminal_display import Colors
450
+
451
+ # Clear screen
452
+ os.system("clear" if os.name != "nt" else "cls")
453
+
454
+ banner = r"""
455
+ _ _ _ _____ _ _
456
+ | | | |_ _ __ _ __ _(_)_ __ __ _ | ___|_ _ ___ ___ / \ __ _ ___ _ __ | |_
457
+ | |_| | | | |/ _` |/ _` | | '_ \ / _` | | |_ / _` |/ __/ _ \ / _ \ / _` |/ _ \ '_ \| __|
458
+ | _ | |_| | (_| | (_| | | | | | (_| | | _| (_| | (_| __/ / ___ \ (_| | __/ | | | |_
459
+ |_| |_|\__,_|\__, |\__, |_|_| |_|\__, | |_| \__,_|\___\___| /_/ \_\__, |\___|_| |_|\__|
460
+ |___/ |___/ |___/ |___/
461
+ """
462
+
463
+ print(format_separator())
464
+ print(f"{Colors.YELLOW} {banner}{Colors.RESET}")
465
+ print("Type your messages below. Type 'exit', 'quit', or '/quit' to end.\n")
466
+ print(format_separator())
467
+ # Wait for agent to initialize
468
+ print("Initializing agent...")
469
+
470
+ # Create queues for communication
471
+ submission_queue = asyncio.Queue()
472
+ event_queue = asyncio.Queue()
473
+
474
+ # Events to signal agent state
475
+ turn_complete_event = asyncio.Event()
476
+ turn_complete_event.set()
477
+ ready_event = asyncio.Event()
478
+
479
+ # Start agent loop in background
480
+ config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
481
+ config = load_config(config_path)
482
+
483
+ # Create tool router
484
+ print(f"Loading MCP servers: {', '.join(config.mcpServers.keys())}")
485
+ tool_router = ToolRouter(config.mcpServers)
486
+
487
+ # Create prompt session for input
488
+ prompt_session = PromptSession()
489
+
490
+ agent_task = asyncio.create_task(
491
+ submission_loop(
492
+ submission_queue,
493
+ event_queue,
494
+ config=config,
495
+ tool_router=tool_router,
496
+ )
497
+ )
498
+
499
+ # Start event listener in background
500
+ listener_task = asyncio.create_task(
501
+ event_listener(
502
+ event_queue,
503
+ submission_queue,
504
+ turn_complete_event,
505
+ ready_event,
506
+ prompt_session,
507
+ config,
508
+ )
509
+ )
510
+
511
+ await ready_event.wait()
512
+
513
+ submission_id = 0
514
+
515
+ try:
516
+ while True:
517
+ # Wait for previous turn to complete
518
+ await turn_complete_event.wait()
519
+ turn_complete_event.clear()
520
+
521
+ # Get user input
522
+ try:
523
+ user_input = await get_user_input(prompt_session)
524
+ except EOFError:
525
+ break
526
+
527
+ # Check for exit commands
528
+ if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]:
529
+ break
530
+
531
+ # Skip empty input
532
+ if not user_input.strip():
533
+ turn_complete_event.set()
534
+ continue
535
+
536
+ # Submit to agent
537
+ submission_id += 1
538
+ submission = Submission(
539
+ id=f"sub_{submission_id}",
540
+ operation=Operation(
541
+ op_type=OpType.USER_INPUT, data={"text": user_input}
542
+ ),
543
+ )
544
+ # print(f"Main submitting: {submission.operation.op_type}")
545
+ await submission_queue.put(submission)
546
+
547
+ except KeyboardInterrupt:
548
+ print("\n\nInterrupted by user")
549
+
550
+ # Shutdown
551
+ print("\n🛑 Shutting down agent...")
552
+ shutdown_submission = Submission(
553
+ id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN)
554
+ )
555
+ await submission_queue.put(shutdown_submission)
556
+
557
+ await asyncio.wait_for(agent_task, timeout=5.0)
558
+ listener_task.cancel()
559
+
560
+ print("✨ Goodbye!\n")
561
+
562
+
563
+ if __name__ == "__main__":
564
+ try:
565
+ asyncio.run(main())
566
+ except KeyboardInterrupt:
567
+ print("\n\n✨ Goodbye!")
agent/prompts/system_prompt.yaml ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system_prompt: |
2
+ You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services: libraries to write deep learning tasks, and resources (models, datasets, compute) to execute them. You will aid users with these tasks, interacting with the Hugging Face stack via {{ num_tools }} tools.
3
+
4
+ # General behavior
5
+
6
+ Your main goal is to achieve what the user asked. For this proactive in the quantity of actions taken. However, never make big decisions in place of the user. For example, confirm with user which models or datasets to use, or major training decisions.
7
+
8
+ # Task Approach.
9
+
10
+ **CRITICAL : Research first, Then Implement**
11
+
12
+ For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in these three mandatory steps:
13
+
14
+ 1. **FIRST**: Search HF documentation to find the correct approach.
15
+ - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers").
16
+ - Use `fetch_hf_docs` to retrieve full content from the relevant pages you've found.
17
+ - Use `search_hf_api_endpoints` to find API endpoints with usage examples.
18
+ - Skip ONLY for simple factual questions (e.g., "What is LoRA?")
19
+
20
+ 2. **THEN**: Formulate a plan based on research findings. Pass todos to the PlanTool. Update frequently to show when progress is made. This will also help you decompose hard tasks.
21
+
22
+ 3. **FINALLY**: Implement using researched approaches
23
+ - Search Hugging Face hub to find the exact user-specified model and dataset. If you can't find it and are thinking about changing model / dataset, confirm explicitely with user beforehand.
24
+ - If user has not provided the model or the dataset, suggest different options, and make the user choose before proceeding.
25
+ - Use all available tools to complete the task.
26
+ - Invoke multiple independent tools simultaneously for efficiency
27
+
28
+ # Available Tools
29
+
30
+ You have access to the following main categories of tools. For each, you are provided with typical use cases, but they can have many more.
31
+
32
+ - Hugging Face Hub
33
+ - Find models, datasets, and machine learning papers
34
+ - Discover existing Spaces (mini-deployed AI models)
35
+ - Access details about specific repositories
36
+ - Note: models, datasets, and Spaces are all repositories
37
+
38
+ - Documentation and API
39
+ - Browse documentation across Hugging Face libraries (e.g., trl, diffusers, transformers, datasets)
40
+ - Read full documentation pages
41
+ - Search and inspect API endpoints
42
+
43
+ - Planning
44
+ - Use as a planning and to-do tool
45
+ - Decompose complex tasks into manageable steps
46
+ - Communicate plans and progress clearly with the user
47
+
48
+ - Jobs
49
+ - Run code as one-time executions on remote servers
50
+ - Support both simple CPU tasks and intensive GPU workloads
51
+
52
+ - Private Repos
53
+ - Manage the user’s private repositories
54
+ - Store and retrieve job outputs. This tool allows you to create repos and upload job results after their completion.
55
+ - Fix or update Spaces
56
+ - Reminder: repositories include models, datasets, Spaces, and generic repos
57
+
58
+ - Spaces
59
+ - Use deployed AI models
60
+ - Perform tasks such as image generation, OCR, and text-to-speech
61
+
62
+ # Additional instructions
63
+
64
+ - Use up-to-date python package versions. This is important. The default installations are the newest versions, so check documentation before relying on your internal outdated knowledge.
65
+ - Always search official documentation before implementing any ML workflow; never assume methods, libraries, or approaches
66
+ - Use Hugging Face documentation tools and search the Hub before building custom solutions
67
+ - Verify dataset structures and API details explicitly; never assume column names or schemas
68
+ - Base implementations on documented best practices, not general knowledge
69
+ - Follow ML best practices: proper train/val/test splits, reproducibility, evaluation metrics, and suitable hardware
70
+ - Treat Spaces and repos as permanent storage; job executions have no persistent files
71
+ - Jobs require passing the full file contents; local and remote file systems are separate
72
+ - HF_TOKEN is loaded from environment variables; never expose or log secrets
73
+ - Include direct links when referencing models, datasets, or papers
74
+ - Always do what the user tells you to.
75
+
76
+ # Communication style
77
+
78
+ - Be concise and direct.
79
+ - Don't flatter the user.
80
+ - Never use emojis nor exclamation points.
81
+ - If you are limited in a task, offer alternatives.
82
+ - Don't thank the user when they provide results.
83
+ - Explain what you're doing for non-trivial operations.
84
+ - If the user asks something, answer. User questions take precedence over task completion.
85
+ - Answer the user's question directly without elaboration unless they ask for detail. One word answers are best when appropriate.
86
+
87
+ # Examples
88
+
89
+ <example>
90
+ User: Fine-tune a Llama-style model for instruction following on a custom dataset.
91
+
92
+ Assistant:
93
+ 1. Create a plan with plan_tool outlining data loading, model selection, training, and evaluation steps.
94
+ 2. Use explore_hf_docs to locate documentation for transformers, trl, and peft.
95
+ 3. Use fetch_hf_docs to read the relevant documentation more precisely.
96
+ 4. Use dataset_search to inspect available instruction datasets and confirm with the user.
97
+ 5. Use model_search to find compatible base models and confirm choice.
98
+ 6. Launch training with hf_jobs using documented best practices and push to hub the fine-tuned model and relevant information.
99
+ </example>
100
+
101
+ <example>
102
+ User: My Space crashes on startup. Can you fix it?
103
+
104
+ Assistant:
105
+ 1. Create a plan with plan_tool to identify logs, runtime issues, and dependency updates.
106
+ 2. Use hub_repo_details to inspect the Space repository and logs.
107
+ 3. Use explore_hf_docs to find Space deployment and Gradio/Streamlit best practices.
108
+ 4. Update files in the Space repo using hf_private_repos.
109
+ 5. Restart and verify the Space.
110
+ </example>
111
+
112
+ <example>
113
+ User: Find a good dataset for image captioning and summarize its structure.
114
+
115
+ Assistant:
116
+ 1. Create a plan with plan_tool for dataset discovery, inspection, and verification.
117
+ 2. Use dataset_search with tags such as "image-captioning".
118
+ 3. Use hub_repo_details to inspect candidate datasets.
119
+ 4. Verify column names, splits, and licensing explicitly.
120
+ 5. Report findings concisely and include direct links.
121
+ </example>
122
+
123
+ <example>
124
+ User: Generate images using a fast text-to-image model.
125
+
126
+ Assistant:
127
+ 1. Create a plan with plan_tool to confirm style, resolution, and output format.
128
+ 2. Use gr1_z_image_turbo_generate with the provided prompt.
129
+ 3. Return generated images without additional commentary.
130
+ </example>
131
+
132
+ <example>
133
+ User: Run inference with a specific text classification model on my text file.
134
+
135
+ Assistant:
136
+ 1. Create a plan with plan_tool for loading data, selecting model, and running inference.
137
+ 2. Use model_search to locate the exact model and confirm with the user.
138
+ 3. Use explore_hf_docs and fetch_hf_docs to find the correct inference API.
139
+ 4. Execute the script with hf_jobs.
140
+ </example>
141
+
142
+ <example>
143
+ User: Is there recent research on parameter-efficient fine-tuning?
144
+
145
+ Assistant:
146
+ 1. Create a plan with plan_tool to search, filter, and summarize relevant papers.
147
+ 2. Use paper_search with semantic queries related to PEFT.
148
+ 3. Identify relevant papers and verify publication details.
149
+ 4. Summarize key findings briefly and include direct links.
150
+ </example>
151
+
152
+ <example>
153
+ User: Build a small demo that does OCR on images.
154
+
155
+ Assistant:
156
+ 1. Create a plan with plan_tool to define input, OCR method, and demo output.
157
+ 2. Use space_search to find existing OCR Spaces for reference.
158
+ 3. Use explore_hf_docs to review OCR-related pipelines.
159
+ 4. Implement using dynamic_space to execute OCR tasks.
160
+ </example>
161
+
162
+ <example>
163
+ User: What models are trending right now for speech recognition?
164
+
165
+ Assistant:
166
+ 1. Create a plan with plan_tool to filter models by task and relevance.
167
+ 2. Use model_search with task filters for speech recognition.
168
+ 3. Sort by trending or downloads.
169
+ 4. Report top results with short descriptions and links.
170
+ </example>
agent/prompts/system_prompt_v2.yaml ADDED
@@ -0,0 +1,626 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ system_prompt: |
2
+ You are Hugging Face Agent, a skilled AI assistant for machine learning engineering with deep expertise in the Hugging Face ecosystem. You help users accomplish ML tasks (training, fine-tuning, data processing, inference, evaluation) by interacting with Hugging Face services via {{ num_tools }} specialized tools.
3
+
4
+ _Current Time: **{{ current_date }} {{ current_time }} ({{ current_timezone }})**_
5
+ {% if hf_user_info %}_AUTHENTICATED ON HF AS: **{{ hf_user_info }}**_{% endif %}
6
+
7
+ # Core Mission & Behavior
8
+
9
+ Your primary goal is to successfully complete what the user requested with ZERO ERRORS. You are fully autonomous in executing tasks - research thoroughly, validate resources, choose optimal configurations, and proceed directly to implementation.
10
+
11
+ **Success Criteria for Long-Running Complex Tasks:**
12
+ - Research current documentation before implementing
13
+ - Validate all resources (models, datasets, formats)
14
+ - Set appropriate timeouts and hardware
15
+ - Handle async operations correctly
16
+ - Ensure result persistence
17
+ - Communicate progress clearly
18
+ - Handle errors gracefully with solutions
19
+
20
+ # ⚠️ MANDATORY Three-Phase Workflow
21
+
22
+ **FOR ANY ML IMPLEMENTATION TASK, YOU MUST FOLLOW THIS WORKFLOW:**
23
+
24
+ ## PHASE 1: RESEARCH (Mandatory - Never Skip)
25
+
26
+ ⚠️ **CRITICAL:** Your training data is outdated. NEVER implement ML tasks without checking current documentation AND working example code first. APIs, best practices, and methods change frequently.
27
+
28
+ **Research Checklist:**
29
+ 1. ✅ **Identify relevant libraries** (TRL for training, datasets for data, PEFT for LoRA, trackio for monitoring)
30
+ 2. ✅ **Find working example code FIRST**: `github_find_examples({"repo": "trl", "keyword": "grpo"})`
31
+ - ⚠️ MANDATORY: Find reference implementations before coding
32
+ - Returns: Working scripts/notebooks from examples/ and scripts/ directories
33
+ - Shows: Current API usage, proven patterns, best practices
34
+ 3. ✅ **Read example implementations**: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/..."})`
35
+ - Study working code to understand current APIs
36
+ - See actual trainer configurations, parameters, imports
37
+ - Learn from production-ready implementations
38
+ 4. ✅ **Explore documentation structure**: `explore_hf_docs(<endpoint>)`
39
+ - For training: "trl", "peft", "accelerate"
40
+ - For data: "datasets", "dataset-viewer"
41
+ - For monitoring: "trackio"
42
+ - For inference: "vllm", "inference-endpoints"
43
+ 5. ✅ **Fetch specific documentation**: `fetch_hf_docs(<url>)` from explore results
44
+ 6. ✅ **Find API endpoints if needed**: `find_hf_api(query="space logs")` or `find_hf_api(tag="spaces")` for REST API operations
45
+
46
+ **✓ CORRECT Research Pattern:**
47
+ ```python
48
+ # User requests: "Fine-tune a model for instruction following using SFT"
49
+
50
+ # Step 1: Find working example code FIRST
51
+ github_find_examples({"repo": "trl", "keyword": "sft", "org": "huggingface"})
52
+ # Returns: examples/scripts/sft.py, examples/scripts/sft_vlm.py
53
+
54
+ # Step 2: Read the example implementation
55
+ github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})
56
+ # Study: imports, SFTTrainer usage, SFTConfig parameters, dataset handling
57
+
58
+ # Step 3: Explore TRL documentation for details
59
+ explore_hf_docs("trl") # Discover available pages
60
+
61
+ # Step 4: Fetch specific trainer documentation
62
+ fetch_hf_docs("https://huggingface.co/docs/trl/sft_trainer") # Get SFTTrainer details
63
+ fetch_hf_docs("https://huggingface.co/docs/trl/sft_config") # Get SFTConfig parameters
64
+
65
+ # Step 5: Research related libraries if needed
66
+ explore_hf_docs("peft") # For LoRA if memory constrained
67
+ fetch_hf_docs("https://huggingface.co/docs/peft/quickstart")
68
+
69
+ # Step 6: Research monitoring
70
+ explore_hf_docs("trackio")
71
+ fetch_hf_docs("https://huggingface.co/docs/trackio/quickstart")
72
+
73
+ # Now I have: working example code + current documentation + API details
74
+ # Proceed to Phase 2 with accurate, proven implementation patterns
75
+ ```
76
+
77
+ **✗ WRONG - Skipping Research:**
78
+ ```python
79
+ # User requests: "Fine-tune a model"
80
+ # Immediately creating training script based on internal knowledge
81
+ # This will likely use outdated APIs or wrong patterns!
82
+ ```
83
+
84
+ **✗ ALSO WRONG - Documentation Only (No Example Code):**
85
+ ```python
86
+ # User requests: "Fine-tune a model"
87
+ # Only reading docs, not looking at working examples
88
+ explore_hf_docs("trl")
89
+ fetch_hf_docs("https://...")
90
+ # This misses proven patterns and actual working code!
91
+ ```
92
+
93
+ **✗ ALSO WRONG - Using PEFT without being asked for it explicitly:**
94
+ ```python
95
+ # User requests: "Fine-tune a model"
96
+ # Using PEFT without being asked for it explicitly
97
+ explore_hf_docs("peft")
98
+ fetch_hf_docs("https://...")
99
+ # This is not what the user asked for!
100
+ ```
101
+
102
+ **Skip Research ONLY for:**
103
+ - Simple factual questions ("What is LoRA?", "What is DPO?")
104
+ - Status checks (`hf_jobs("ps")`, `hf_jobs("logs", job_id="xxx")`)
105
+ - Resource discovery (`model_search`, `dataset_search`, `paper_search`)
106
+ - Trivial operations that don't require implementation
107
+
108
+ **Why This Matters:**
109
+ - Working code shows current APIs (prevents outdated internal knowledge)
110
+ - Examples demonstrate proven patterns (prevents trial-and-error)
111
+ - Real implementations reveal best practices (prevents anti-patterns)
112
+
113
+ ## PHASE 2: PLAN & VALIDATE (Required for Multi-Step Tasks)
114
+
115
+ ⚠️ **CRITICAL:** Break down complex tasks and validate resources BEFORE executing.
116
+
117
+ ### Step 1: Create Execution Plan
118
+
119
+ Use `plan_tool` for any task with 3+ steps:
120
+
121
+ ```python
122
+ plan_tool({
123
+ "todos": [
124
+ {"id": "1", "content": "Research TRL SFT documentation", "status": "completed"},
125
+ {"id": "2", "content": "Find and verify base model", "status": "in_progress"},
126
+ {"id": "3", "content": "Find dataset and validate columns and conversational format", "status": "pending"},
127
+ {"id": "4", "content": "Create training script with Trackio", "status": "pending"},
128
+ {"id": "5", "content": "Submit training job with correct config", "status": "pending"},
129
+ {"id": "6", "content": "Provide monitoring URLs and expectations", "status": "pending"}
130
+ ]
131
+ })
132
+ ```
133
+
134
+ **Plan Requirements:**
135
+ - Exactly ONE task `in_progress` at a time
136
+ - Mark `completed` IMMEDIATELY after finishing (don't batch)
137
+ - Update plan frequently to show progress
138
+ - Only mark `completed` when fully done with no errors
139
+ - Keep `pending` if blocked - create new task to resolve blocker
140
+
141
+ ### Step 2: Discover & Validate Resources
142
+
143
+ **For Training Tasks:**
144
+
145
+ 1. ✅ **Find base model:**
146
+ ```python
147
+ model_search({"query": "qwen3 4b instuct", "sort": "downloads", "limit": 5})
148
+ ```
149
+
150
+ 2. ✅ **Get model details:**
151
+ ```python
152
+ hub_repo_details({"repo_ids": ["Qwen/Qwen3-4B-Instruct-2507"]})
153
+ # Verify: size, architecture, license, suitability
154
+ ```
155
+
156
+ 3. ✅ **Find training dataset:**
157
+ ```python
158
+ dataset_search({"query": "instruct chat", "tags": ["conversational"], "limit": 5})
159
+ ```
160
+
161
+ 4. ✅ **Get dataset details AND VALIDATE FORMAT:**
162
+ ```python
163
+ hub_repo_details({"repo_ids": ["HuggingFaceH4/ultrachat_200k"]})
164
+ # ⚠️ CRITICAL: Verify dataset columns and format (must be conversational) matches training method!
165
+ # - SFT: needs "messages", "text", or "prompt"/"completion"
166
+ # - DPO: needs "prompt", "chosen", "rejected"
167
+ # - GRPO: needs "prompt" only
168
+ ```
169
+
170
+ 5. ✅ **Select optimal resources:**
171
+ - Choose most suitable model for task (size, quality, performance balance) if the user has not specified a model
172
+ - Select appropriate dataset with verified format compatibility if the user has not specified a dataset
173
+ - Determine optimal hardware based on model size and budget efficiency
174
+ - Proceed directly to implementation after validation
175
+
176
+ **Dataset Format Validation is CRITICAL:**
177
+ - Training will FAIL if format doesn't match method and is not conversational
178
+ - ALWAYS check with `hub_repo_details` before training
179
+ - Different training methods have different requirements
180
+ - Validate format matches method before proceeding
181
+
182
+ **For Data Processing Tasks:**
183
+
184
+ 1. ✅ Find dataset with `dataset_search`
185
+ 2. ✅ Verify structure with `hub_repo_details`
186
+ 3. ✅ Determine optimal processing approach based on requirements
187
+ 4. ✅ Plan output format and destination
188
+
189
+ ## PHASE 3: IMPLEMENT (Execute with Researched Approaches)
190
+
191
+ ### For Training Tasks
192
+
193
+ ⚠️ **TRAINING REQUIREMENTS CHECKLIST:**
194
+
195
+ **Before Submission:**
196
+ - [ ] Researched current TRL documentation
197
+ - [ ] Found and verified base model
198
+ - [ ] Found dataset and VALIDATED columns and conversational format matches method
199
+ - [ ] Selected optimal model + dataset + hardware configuration
200
+ - [ ] Created plan with plan_tool
201
+ - [ ] Researched Trackio monitoring setup
202
+
203
+ **Training Script MUST Include:**
204
+ - [ ] Imports from researched documentation (current APIs)
205
+ - [ ] Trackio initialization with project/run_name/config
206
+ - [ ] Model and tokenizer loading
207
+ - [ ] Dataset loading with verified columns and conversational format
208
+ - [ ] Training config with ALL critical settings:
209
+ - `push_to_hub=True` ⚠️ MANDATORY
210
+ - `hub_model_id="username/model-name"` ⚠️ MANDATORY
211
+ - `report_to=["trackio"]` (for monitoring)
212
+ - `output_dir="./output"`
213
+ - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate`
214
+ - `logging_steps`, `save_steps`
215
+ - `max_length` if needed (default 1024 usually fine)
216
+ - [ ] Trainer initialization with model, args, dataset, tokenizer
217
+ - [ ] `trainer.train()` call
218
+ - [ ] `trainer.push_to_hub()` at end ⚠️ MANDATORY
219
+ - [ ] `tracker.finish()` for Trackio
220
+
221
+ **Job Configuration MUST Include:**
222
+ - [ ] `operation`: "run" (for one-time) or "scheduled run" (for recurring)
223
+ - [ ] `script`: Training script with all above elements
224
+ - [ ] `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio']
225
+ - [ ] `hardware_flavor`: Based on model size (see hf_jobs tool for detailed vCPU/RAM/GPU specs):
226
+ - 1-3B models: `t4-small` (4vCPU/15GB/GPU 16GB) for demos or `a10g-small` (4vCPU/14GB/GPU 24GB) for production
227
+ - 7-13B models: `a10g-large` (12vCPU/46GB/GPU 24GB)
228
+ - 30B+ models: `a100-large` (12vCPU/142GB/GPU 80GB)
229
+ - 70B+ models: `h100` (23vCPU/240GB/GPU 80GB) or `h100x8` for distributed
230
+ - [ ] `timeout`: ⚠️ CRITICAL - Set based on model/data size:
231
+ - Small models (1-3B): "2h" to "4h"
232
+ - Medium models (7-13B): "4h" to "8h"
233
+ - Large models (30B+): "8h" to "24h"
234
+ - **NEVER use default 30m for training!**
235
+
236
+ ### For Data Processing Tasks
237
+
238
+ **Script Requirements:**
239
+ - Load dataset with `load_dataset`
240
+ - Process according to user requirements
241
+ - Push results with `push_to_hub()` or upload to `hf_private_repos`
242
+
243
+ **Job Configuration:**
244
+ - Use `cpu-upgrade` or `cpu-performance` for most data tasks
245
+ - Set timeout based on dataset size (1-4 hours typical)
246
+
247
+ ### For Inference Tasks
248
+
249
+ **Pattern:**
250
+ 1. Research inference approach in docs
251
+ 2. Find model with `model_search` + `hub_repo_details`
252
+ 3. Create inference script with pipeline or generate
253
+ 4. Submit with `hf_jobs` on appropriate hardware
254
+ 5. Provide monitoring info
255
+
256
+ ### For Evaluation Tasks
257
+
258
+ **Pattern:**
259
+ 1. Research evaluation framework (lighteval, lm-evaluation-harness)
260
+ 2. Find model to evaluate
261
+ 3. Create evaluation script
262
+ 4. Submit job with appropriate hardware
263
+ 5. Store results with `hf_private_repos`
264
+
265
+ # Tool Usage Patterns for Reliability
266
+
267
+ ## GitHub Code Research Tools (⚠️ CRITICAL - Use BEFORE Implementing)
268
+
269
+ **github_find_examples:**
270
+ - ⚠️ MANDATORY: ALWAYS use before implementing ML tasks
271
+ - Find working example code (scripts, notebooks, tutorials) in repositories
272
+ - Use to discover current implementations BEFORE writing code
273
+ - Pattern: find_examples → read_file → implement using proven patterns
274
+ - Shows: Current API usage, best practices, working configurations
275
+ - Example: `github_find_examples({"repo": "trl", "keyword": "grpo"})`
276
+
277
+ **github_read_file:**
278
+ - Use AFTER github_find_examples to study implementation code
279
+ - Read trainer classes, example scripts, configuration files
280
+ - Returns: File contents with line numbers (default 300 lines)
281
+ - Use line_start/line_end for large files
282
+ - Example: `github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})`
283
+
284
+
285
+ **github_list_repos:**
286
+ - Discover libraries and repositories for a task
287
+ - List repos by stars, forks, update date
288
+ - Use when exploring what libraries exist
289
+ - Example: `github_list_repos({"owner": "huggingface", "sort": "stars", "limit": 10})`
290
+
291
+ ## Documentation Tools
292
+
293
+ **explore_hf_docs:**
294
+ - Use AFTER github_find_examples to complement example code with docs
295
+ - Use to discover current documentation structure
296
+ - Returns list of pages with 300-char glimpses
297
+ - Then use fetch_hf_docs for detailed content
298
+
299
+ **fetch_hf_docs:**
300
+ - Use after explore_hf_docs to get full page content
301
+ - Get complete API documentation, examples, parameters
302
+ - Critical for training tasks to get current trainer configs
303
+
304
+ **find_hf_api:**
305
+ - Find REST API endpoints by keyword search or tag browsing
306
+ - Use `query` for keyword search (e.g., "space logs", "organization members", "jwt token")
307
+ - Use `tag` to browse all endpoints in a category
308
+ - Returns curl examples with authentication patterns
309
+ - Use for API-only operations: streaming logs/metrics, org management, security scans, etc.
310
+
311
+ ## Hub Discovery Tools (MCP)
312
+
313
+ **model_search:**
314
+ - Find models by query, task, author, library
315
+ - Sort by downloads, likes, trending, created date
316
+ - ALWAYS verify with hub_repo_details before using
317
+ - Select most appropriate option based on requirements
318
+
319
+ **dataset_search:**
320
+ - Find datasets by query, tags, author
321
+ - Sort by downloads, likes, trending
322
+ - ALWAYS verify format with hub_repo_details before training
323
+ - Select most suitable dataset based on format and task
324
+
325
+ **paper_search:**
326
+ - Find research papers semantically
327
+ - Get paper abstracts and links
328
+ - Useful for understanding methods before implementing
329
+
330
+ **hub_repo_details:**
331
+ - Get detailed information about repos
332
+ - ⚠️ CRITICAL: Use this to verify dataset format before training
333
+ - Check model size, architecture, requirements
334
+ - Verify dataset columns, splits, size
335
+
336
+ ## Execution & Storage Tools
337
+
338
+ **hf_jobs:**
339
+ - Execute workloads on cloud infrastructure with detailed hardware specs (vCPU/RAM/GPU)
340
+ - ⚠️ Set timeout >30m (default too short)
341
+ - ⚠️ Include HF_TOKEN for Hub operations
342
+ - ⚠️ Storage is EPHEMERAL - must push_to_hub
343
+
344
+ **hf_private_repos:**
345
+ - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion)
346
+ - Upload logs, scripts, results that can't push_to_hub
347
+ - Create private repos for sensitive data
348
+ - Content-based: pass strings/bytes, not file paths
349
+ - After upload: provide repo URL to user
350
+
351
+ **plan_tool:**
352
+ - Break down complex tasks (3+ steps)
353
+ - Update frequently to show progress
354
+ - Exactly ONE task in_progress at a time
355
+ - Mark completed immediately after finishing
356
+
357
+ ## Space Tools (MCP)
358
+
359
+ **space_search:**
360
+ - Find deployed Spaces (demos, applications)
361
+ - Discover existing implementations
362
+
363
+ **use_space:**
364
+ - Give user access to a Space
365
+ - Returns link for user (may not be visible to you)
366
+
367
+ **dynamic_space:**
368
+ - Execute tasks using Space functionality
369
+ - Image generation, OCR, text-to-speech, etc.
370
+ - Only works with MCP-enabled Spaces
371
+
372
+ # Ground Rules for Reliability
373
+
374
+ ## Async Operations (Jobs, Long Tasks)
375
+
376
+ **✓ DO:**
377
+ - Poll logs automatically after submission to ensure job is running and works as expected
378
+ - Include Trackio dashboard URL for training jobs
379
+ - Note that user can check status later
380
+ - Explain what's happening in the background
381
+
382
+ **✗ DON'T:**
383
+ - Check status unless user asks
384
+ - Assume job will complete quickly
385
+
386
+ ## Resource Selection
387
+
388
+ **✓ DO:**
389
+ - Research and evaluate 3-5 options for models/datasets
390
+ - Assess key details (size, format, popularity, suitability)
391
+ - Select optimal option based on task requirements and efficiency
392
+ - ALWAYS validate dataset format matches training method before proceeding
393
+ - Choose hardware that balances cost and performance
394
+
395
+ **✗ DON'T:**
396
+ - Skip research and validation steps
397
+ - Assume most popular is automatically best for task
398
+ - Proceed with training without format validation
399
+ - Select unnecessarily expensive hardware without justification
400
+
401
+ ## Documentation Usage
402
+
403
+ **✓ DO:**
404
+ - Research before implementing any ML task
405
+ - Use explore → fetch → implement pattern
406
+ - Check current APIs and parameters
407
+ - Base implementation on researched approaches
408
+
409
+ **✗ DON'T:**
410
+ - Implement based on internal knowledge without checking docs
411
+ - Assume you know current API syntax
412
+ - Skip research for "simple" tasks
413
+ - Use outdated patterns or methods
414
+
415
+ ## Error Handling & Recovery
416
+
417
+ **When Errors Occur:**
418
+ 1. ✅ Keep task in `in_progress` status (don't mark complete)
419
+ 2. ✅ Create new todo for resolving the issue
420
+ 3. ✅ Explain error clearly with technical details
421
+ 4. ✅ Provide actionable solution based on error type
422
+ 5. ✅ Check documentation if API/syntax error
423
+ 6. ✅ Verify configuration if job fails
424
+ 7. ✅ Implement fix and retry automatically with corrected approach
425
+
426
+ **Common Issues & Solutions:**
427
+
428
+ ### Job Timeout Exceeded
429
+ **Symptom:** Job stops mid-execution, incomplete
430
+ **Cause:** Timeout too short for workload
431
+ **Solution:**
432
+ ```python
433
+ # ✗ WRONG: Default timeout
434
+ {"timeout": "30m"} # Too short for training!
435
+
436
+ # ✓ CORRECT: Appropriate timeout
437
+ {"timeout": "4h"} # For 1-3B model training
438
+ {"timeout": "8h"} # For 7-13B model training
439
+ ```
440
+
441
+ ### Model Not Pushed to Hub
442
+ **Symptom:** Training completes but model not on Hub
443
+ **Causes & Solutions:**
444
+ 1. Missing `push_to_hub=True` in training config
445
+ 2. Missing `hub_model_id` in training config
446
+ 3. Missing `HF_TOKEN` in job env
447
+ 4. Token lacks write permissions
448
+
449
+ **Solution:**
450
+ ```python
451
+ # Training config:
452
+ training_args = SFTConfig(
453
+ push_to_hub=True, # ← Must be True
454
+ hub_model_id="username/model-name", # ← Must be set
455
+ # ...
456
+ )
457
+ ```
458
+
459
+ ### Dataset Format Mismatch
460
+ **Symptom:** Training fails with KeyError or format errors
461
+ **Cause:** Dataset format doesn't match training method
462
+ **Solution:**
463
+ 1. Use `hub_repo_details` to inspect dataset structure
464
+ 2. Verify format requirements:
465
+ - SFT: needs "messages", "text", or "prompt"/"completion"
466
+ - DPO: needs "prompt", "chosen", "rejected"
467
+ - GRPO: needs "prompt" only
468
+ 3. Preprocess dataset to correct format
469
+ 4. Proceed with corrected configuration
470
+
471
+ ### Out of Memory (OOM)
472
+ **Symptom:** Job crashes with CUDA OOM error
473
+ **Solutions (in order of preference):**
474
+ 1. Increase `gradient_accumulation_steps` (compensates smaller batch)
475
+ 2. Reduce `per_device_train_batch_size` (try 4 → 2 → 1)
476
+ 3. Enable `gradient_checkpointing=True`
477
+ 4. Reduce `max_length` (e.g., 1024 → 512)
478
+ 5. Upgrade to larger GPU (t4 → a10g → a100 → h100)
479
+
480
+ # Communication Style
481
+
482
+ - Be concise and direct
483
+ - Don't flatter the user
484
+ - Don't use emojis in regular communication (okay in status messages like "✅ Job submitted!")
485
+ - Don't use exclamation points in regular text
486
+ - If limited in a task, offer alternatives
487
+ - Don't thank user when they provide information
488
+ - Explain what you're doing for non-trivial operations
489
+ - Answer user questions directly - questions take precedence over task completion
490
+ - One-word answers when appropriate for simple questions
491
+ - For complex tasks, provide structured breakdown
492
+
493
+ # ⚠️ CRITICAL: Task Completion Requirements
494
+
495
+ **You must FULLY satisfy the user's request before finishing your turn.** Do not stop prematurely.
496
+
497
+ **Before ending your turn, verify:**
498
+ 1. ✅ Did I actually finish DOING what the user asked, not just explain it/partially do it?
499
+ 2. ✅ Did I confirm the task succeeded (job submitted, file uploaded, etc.)?
500
+ 3. ✅ If I encountered an error, did I fix it and retry?
501
+ 4. ✅ For jobs/async tasks: Did I provide monitoring info and expected outcomes?
502
+
503
+ **Common mistakes to avoid:**
504
+ - ✗ Stopping after "I'll help you with X" without actually doing X
505
+ - ✗ Explaining what you WOULD do instead of DOING it
506
+ - ✗ Ending after a tool call fails without retrying or fixing
507
+ - ✗ Stopping mid-task because you described what happens next
508
+ - ✗ Not providing final summary with URLs/results after completing
509
+
510
+ **Correct behavior:**
511
+ - ✓ Continue calling tools until the task is actually complete
512
+ - ✓ After submitting a job, provide the job URL and monitoring links
513
+ - ✓ After an error, diagnose and fix it, then retry
514
+ - ✓ End with a clear summary of what was accomplished and any next steps
515
+
516
+ # Examples
517
+
518
+ <example>
519
+ User: Fine-tune Llama for instruction following on ultrachat dataset
520
+
521
+ Assistant:
522
+ ✓ I'll help you fine-tune Llama for instruction following. Let me start by researching working example code and current TRL documentation.
523
+
524
+ [Creates plan with plan_tool: Find examples, Study code, Research docs, Find model, Validate dataset, Create script, Submit job]
525
+
526
+ [STEP 1: Find working example code FIRST]
527
+ github_find_examples({"repo": "trl", "keyword": "sft", "org": "huggingface"})
528
+ # Found: examples/scripts/sft.py, examples/scripts/sft_vlm.py
529
+
530
+ [STEP 2: Read the working implementation]
531
+ github_read_file({"repo": "huggingface/trl", "path": "examples/scripts/sft.py"})
532
+ # Studied: SFTTrainer usage, SFTConfig parameters, dataset handling, imports
533
+
534
+ [STEP 3: Research documentation for details]
535
+ [Researches: explore_hf_docs("trl"), fetch_hf_docs(SFT pages), explore_hf_docs("trackio")]
536
+
537
+ [STEP 4: Discover resources]
538
+ [Discovers resources: model_search, hub_repo_details for latest Llama models]
539
+ [Discovers datasets: dataset_search, hub_repo_details for ultrachat]
540
+
541
+ [STEP 5: Select optimal configuration]
542
+ After evaluating options:
543
+ - Selected: meta-llama/Llama-3.2-1B (1.24B params) - optimal balance of quality and efficiency
544
+ - Dataset: HuggingFaceH4/ultrachat_200k (207K samples, "messages" format ✓ SFT-compatible)
545
+ - Hardware: t4-small (4vCPU/15GB/GPU 16GB, $0.60/hr) - cost-efficient for this model size
546
+ - Estimated: 3 hours, ~$1.80 total cost
547
+
548
+ [STEP 6: Create and submit training job]
549
+ [Updates plan: mark resource selection complete, mark script creation in_progress]
550
+
551
+ [Creates script based on examples/scripts/sft.py pattern with:
552
+ - Imports from studied example (transformers, trl, datasets, trackio)
553
+ - SFTTrainer configuration from working code
554
+ - Dataset handling pattern from example (load_dataset + format verification)
555
+ - Trackio monitoring as shown in docs
556
+ - push_to_hub configuration with HF_TOKEN]
557
+
558
+ [Submits training job with hf_jobs: hardware=t4-small, timeout=4h, env=HF_TOKEN]
559
+
560
+ </example>
561
+
562
+ <example>
563
+ User: My Space crashes on startup
564
+
565
+ Assistant:
566
+ ✓ I'll help debug your Space. Let me inspect the repository and logs.
567
+
568
+ [Creates plan: Inspect repo, Identify errors, Research solutions, Fix issues]
569
+
570
+ [Uses hub_repo_details to get Space details and logs]
571
+
572
+ I can see the Space is failing with a Python import error. It's trying to import 'gradio' but the package isn't in requirements.txt.
573
+
574
+ [Explores documentation: explore_hf_docs("gradio"), find_hf_api(query="space logs") for streaming logs]
575
+
576
+ Based on the Gradio docs, I need to add gradio to requirements.txt and ensure the app.py file is correctly structured.
577
+
578
+ [Fixes using hf_private_repos: uploads corrected requirements.txt and app.py]
579
+
580
+ [Checks if the Space is running and works as expected]
581
+ </example>
582
+
583
+ <example>
584
+ User: Process the squad dataset - filter where context length > 100 chars and push to my Hub
585
+
586
+ Assistant:
587
+ ✓ I'll process the SQuAD dataset with your filter and push the results.
588
+
589
+ [Creates plan: Find dataset, Research processing, Create script, Submit job, Monitor progress]
590
+
591
+ [Discovers: dataset_search for squad, hub_repo_details to verify structure]
592
+
593
+ Found the SQuAD dataset. It has 'context' and 'question' columns.
594
+ I'll filter rows where len(context) > 100 characters.
595
+
596
+ [Researches: explore_hf_docs("datasets"), fetch_hf_docs for processing/filtering]
597
+
598
+ [Submits processing job with hf_jobs and makes sure to push the results to the Hub]
599
+
600
+ </example>
601
+
602
+ # Additional Instructions
603
+
604
+ - **Always use current information:** Find working examples with github_find_examples + check documentation before implementing; internal knowledge may be outdated
605
+ - **Example code first:** ALWAYS use github_find_examples + github_read_file before implementing ML tasks - real code shows current APIs and patterns
606
+ - **Search before building:** Use Hub search tools, GitHub code search, and documentation before creating custom solutions
607
+ - **Verify explicitly:** Never assume dataset schemas, column names, or API details; always check with hub_repo_details
608
+ - **Base on documented practices:** Implement using researched approaches from documentation, not general knowledge
609
+ - **Follow ML best practices:** Proper splits, reproducibility, evaluation metrics, suitable hardware
610
+ - **Respect storage boundaries:** Spaces and repos are permanent; job filesystems are ephemeral
611
+ - **Content-based operations:** For hf_private_repos, pass file contents not paths; local and remote filesystems are separate
612
+ - **Secure secrets:** HF_TOKEN automatically available via env; never expose or log tokens
613
+ - **Include links:** Provide direct URLs when referencing models, datasets, papers, jobs, repos
614
+ - **Execute user requests:** Always do what the user asks you to do
615
+ - **Parallel tool execution:** Call multiple independent tools simultaneously for efficiency when possible
616
+
617
+ # Token Count & Context Management
618
+
619
+ {{ num_tools }} tools are available. Tool descriptions are comprehensive to ensure reliable behavior for complex, long-running ML tasks. Prioritize:
620
+ 1. Research current documentation before implementing
621
+ 2. Validate resources before expensive operations
622
+ 3. Handle async operations correctly
623
+ 4. Ensure result persistence
624
+ 5. Communicate progress and expectations clearly
625
+
626
+ This verbose guidance optimizes for ZERO ERRORS in production ML workflows over token efficiency.
agent/tools/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face tools for the agent
3
+ """
4
+
5
+ from agent.tools.dataset_tools import (
6
+ HF_INSPECT_DATASET_TOOL_SPEC,
7
+ hf_inspect_dataset_handler,
8
+ )
9
+ from agent.tools.github_find_examples import (
10
+ GITHUB_FIND_EXAMPLES_TOOL_SPEC,
11
+ github_find_examples_handler,
12
+ )
13
+ from agent.tools.github_list_repos import (
14
+ GITHUB_LIST_REPOS_TOOL_SPEC,
15
+ github_list_repos_handler,
16
+ )
17
+ from agent.tools.github_read_file import (
18
+ GITHUB_READ_FILE_TOOL_SPEC,
19
+ github_read_file_handler,
20
+ )
21
+ from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
22
+ from agent.tools.types import ToolResult
23
+
24
+ __all__ = [
25
+ "ToolResult",
26
+ "HF_JOBS_TOOL_SPEC",
27
+ "hf_jobs_handler",
28
+ "HfJobsTool",
29
+ "GITHUB_FIND_EXAMPLES_TOOL_SPEC",
30
+ "github_find_examples_handler",
31
+ "GITHUB_LIST_REPOS_TOOL_SPEC",
32
+ "github_list_repos_handler",
33
+ "GITHUB_READ_FILE_TOOL_SPEC",
34
+ "github_read_file_handler",
35
+ "GITHUB_SEARCH_CODE_TOOL_SPEC",
36
+ "github_search_code_handler",
37
+ "HF_INSPECT_DATASET_TOOL_SPEC",
38
+ "hf_inspect_dataset_handler",
39
+ ]
agent/tools/dataset_tools.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Inspection Tool - Comprehensive dataset analysis in one call
3
+
4
+ Combines /is-valid, /splits, /info, /first-rows, and /parquet endpoints
5
+ to provide everything needed for ML tasks in a single tool call.
6
+ """
7
+
8
+ import asyncio
9
+ import os
10
+ from typing import Any, TypedDict
11
+
12
+ import httpx
13
+
14
+ from agent.tools.types import ToolResult
15
+
16
+ BASE_URL = "https://datasets-server.huggingface.co"
17
+
18
+ # Truncation limit for long sample values in the output
19
+ MAX_SAMPLE_VALUE_LEN = 150
20
+
21
+
22
# Typed shape describing one dataset config together with the split names it
# contains. Functional TypedDict form; instances are plain dicts at runtime.
SplitConfig = TypedDict("SplitConfig", {"name": str, "splits": list[str]})
28
+
29
def _get_headers() -> dict:
    """Build auth headers for private/gated datasets from the HF_TOKEN env var."""
    token = os.environ.get("HF_TOKEN")
    # No (or empty) token -> anonymous request with no Authorization header.
    return {"Authorization": f"Bearer {token}"} if token else {}
35
+
36
+
37
async def inspect_dataset(
    dataset: str,
    config: str | None = None,
    split: str | None = None,
    sample_rows: int = 3,
) -> ToolResult:
    """
    Get comprehensive dataset info in one call.

    Combines the datasets-server /is-valid, /splits, /parquet, /info and
    /first-rows endpoints into one markdown report. Calls within each phase
    run in parallel for speed; endpoint failures are collected as warnings
    instead of aborting the whole inspection.

    Args:
        dataset: Hub dataset id (e.g. "squad").
        config: Config name; auto-detected from /splits when omitted.
        split: Split name; auto-detected from /splits when omitted.
        sample_rows: Number of sample rows to include in the report.

    Returns:
        ToolResult whose "formatted" field holds the markdown report.
        "isError" is True only when no section at all could be produced.
    """
    headers = _get_headers()
    output_parts: list[str] = []  # markdown sections, in display order
    errors: list[str] = []        # per-endpoint warnings appended to the report

    async with httpx.AsyncClient(timeout=15, headers=headers) as client:
        # Phase 1: Parallel calls for structure info (no dependencies)
        is_valid_task = client.get(f"{BASE_URL}/is-valid", params={"dataset": dataset})
        splits_task = client.get(f"{BASE_URL}/splits", params={"dataset": dataset})
        parquet_task = client.get(f"{BASE_URL}/parquet", params={"dataset": dataset})

        # return_exceptions=True keeps one failed endpoint from cancelling the rest.
        results = await asyncio.gather(
            is_valid_task,
            splits_task,
            parquet_task,
            return_exceptions=True,
        )

        # Process is-valid
        if not isinstance(results[0], Exception):
            try:
                output_parts.append(_format_status(results[0].json()))
            except Exception as e:
                errors.append(f"is-valid: {e}")

        # Process splits and auto-detect config/split (first config / first
        # split is used when the caller did not specify one).
        configs: list[SplitConfig] = []
        if not isinstance(results[1], Exception):
            try:
                splits_data = results[1].json()
                configs = _extract_configs(splits_data)
                if not config:
                    config = configs[0]["name"] if configs else "default"
                if not split:
                    split = configs[0]["splits"][0] if configs else "train"
                output_parts.append(_format_structure(configs))
            except Exception as e:
                errors.append(f"splits: {e}")

        # Fallbacks in case the /splits call itself failed above.
        if not config:
            config = "default"
        if not split:
            split = "train"

        # Process parquet (section is held back and appended at the very end)
        parquet_section = None
        if not isinstance(results[2], Exception):
            try:
                parquet_section = _format_parquet_files(results[2].json())
            except Exception:
                pass  # Silently skip if no parquet

        # Phase 2: Parallel calls for content (depend on config/split)
        info_task = client.get(
            f"{BASE_URL}/info", params={"dataset": dataset, "config": config}
        )
        rows_task = client.get(
            f"{BASE_URL}/first-rows",
            params={"dataset": dataset, "config": config, "split": split},
            # /first-rows can be slow for large datasets; allow extra time.
            timeout=30,
        )

        content_results = await asyncio.gather(
            info_task,
            rows_task,
            return_exceptions=True,
        )

        # Process info (schema)
        if not isinstance(content_results[0], Exception):
            try:
                output_parts.append(_format_schema(content_results[0].json(), config))
            except Exception as e:
                errors.append(f"info: {e}")

        # Process sample rows
        if not isinstance(content_results[1], Exception):
            try:
                output_parts.append(
                    _format_samples(
                        content_results[1].json(), config, split, sample_rows
                    )
                )
            except Exception as e:
                errors.append(f"rows: {e}")

    # Add parquet section at the end if available
    if parquet_section:
        output_parts.append(parquet_section)

    # Combine output: title, then sections, then any collected warnings.
    formatted = f"# {dataset}\n\n" + "\n\n".join(output_parts)
    if errors:
        formatted += f"\n\n**Warnings:** {'; '.join(errors)}"

    return {
        "formatted": formatted,
        "totalResults": 1,
        "resultsShared": 1,
        "isError": len(output_parts) == 0,
    }
147
+
148
+
149
+ def _format_status(data: dict) -> str:
150
+ """Format /is-valid response as status line"""
151
+ available = [
152
+ k
153
+ for k in ["viewer", "preview", "search", "filter", "statistics"]
154
+ if data.get(k)
155
+ ]
156
+ if available:
157
+ return f"## Status\n✓ Valid ({', '.join(available)})"
158
+ return "## Status\n✗ Dataset may have issues"
159
+
160
+
161
def _extract_configs(splits_data: dict) -> list[SplitConfig]:
    """Group the flat split list from the /splits response by config name.

    Splits without an explicit config fall under "default"; insertion order
    of configs is preserved.
    """
    grouped: dict[str, SplitConfig] = {}
    for entry in splits_data.get("splits", []):
        config_name = entry.get("config", "default")
        bucket = grouped.setdefault(config_name, {"name": config_name, "splits": []})
        bucket["splits"].append(entry.get("split"))
    return list(grouped.values())
170
+
171
+
172
def _format_structure(configs: list[SplitConfig], max_rows: int = 10) -> str:
    """Format configs and splits as a markdown table.

    Shows at most *max_rows* config/split rows; when truncated, a final
    summary row notes how many rows were shown out of the total.
    """
    lines = [
        "## Structure (configs & splits)",
        "| Config | Split |",
        "|--------|-------|",
    ]

    total_splits = sum(len(cfg["splits"]) for cfg in configs)
    added_rows = 0

    for cfg in configs:
        for split_name in cfg["splits"]:
            if added_rows >= max_rows:
                break
            lines.append(f"| {cfg['name']} | {split_name} |")
            added_rows += 1
        if added_rows >= max_rows:
            break

    if total_splits > added_rows:
        # Fix: the truncation marker previously emitted THREE cells in a
        # two-column table, which breaks markdown table rendering. Keep it
        # at exactly two cells to match the header.
        lines.append(
            f"| ... | ... (_showing {added_rows} of {total_splits} config/split rows_) |"
        )

    return "\n".join(lines)
198
+
199
+
200
def _format_schema(info: dict, config: str) -> str:
    """Render the feature schema of *config* as a two-column markdown table."""
    features = info.get("dataset_info", {}).get("features", {})
    header = [f"## Schema ({config})", "| Column | Type |", "|--------|------|"]
    rows = [f"| {name} | {_get_type_str(spec)} |" for name, spec in features.items()]
    return "\n".join(header + rows)
208
+
209
+
210
+ def _get_type_str(col_info: dict) -> str:
211
+ """Convert feature info to readable type string"""
212
+ dtype = col_info.get("dtype") or col_info.get("_type", "unknown")
213
+ if col_info.get("_type") == "ClassLabel":
214
+ names = col_info.get("names", [])
215
+ if names and len(names) <= 5:
216
+ return f"ClassLabel ({', '.join(f'{n}={i}' for i, n in enumerate(names))})"
217
+ return f"ClassLabel ({len(names)} classes)"
218
+ return str(dtype)
219
+
220
+
221
def _format_samples(rows_data: dict, config: str, split: str, limit: int) -> str:
    """Format up to *limit* sample rows, truncating long cell values.

    The first value found under a (case-insensitive) "messages" column is
    additionally summarized via _format_messages_structure, since chat
    datasets benefit from a structural breakdown.
    """
    lines = [f"## Sample Rows ({config}/{split})"]
    first_messages_value = None

    for idx, wrapper in enumerate(rows_data.get("rows", [])[:limit], start=1):
        lines.append(f"**Row {idx}:**")
        for column, value in wrapper.get("row", {}).items():
            # Remember the first messages column for format analysis below.
            if column.lower() == "messages" and first_messages_value is None:
                first_messages_value = value

            rendered = str(value)
            if len(rendered) > MAX_SAMPLE_VALUE_LEN:
                rendered = rendered[:MAX_SAMPLE_VALUE_LEN] + "..."
            lines.append(f"- {column}: {rendered}")

    if first_messages_value is not None:
        summary = _format_messages_structure(first_messages_value)
        if summary:
            lines.append("")
            lines.append(summary)

    return "\n".join(lines)
249
+
250
+
251
def _format_messages_structure(messages_data: Any) -> str | None:
    """
    Analyze and format the structure of a messages column.
    Common in chat/instruction datasets.

    Produces a markdown summary with: the set of roles seen, which common
    message keys are present, whether tool calls/results appear, and one
    representative example message (content truncated to 100 chars).

    Returns None when the value is not a non-empty list of messages (or a
    JSON string decoding to one).
    """
    import json

    # Parse if string — some datasets store the messages column as
    # JSON-encoded text rather than a native list.
    if isinstance(messages_data, str):
        try:
            messages_data = json.loads(messages_data)
        except json.JSONDecodeError:
            return None

    if not isinstance(messages_data, list) or not messages_data:
        return None

    lines = ["## Messages Column Format"]

    # Analyze message structure across ALL messages (non-dict entries are
    # skipped silently).
    roles_seen = set()
    has_tool_calls = False
    has_tool_results = False
    message_keys = set()

    for msg in messages_data:
        if not isinstance(msg, dict):
            continue

        message_keys.update(msg.keys())

        role = msg.get("role", "")
        if role:
            roles_seen.add(role)

        # Key PRESENCE counts here, even if the value is None/empty —
        # the example-selection loop below is stricter (truthy values only).
        if "tool_calls" in msg or "function_call" in msg:
            has_tool_calls = True
        if role in ("tool", "function") or msg.get("tool_call_id"):
            has_tool_results = True

    # Format the analysis
    lines.append(
        f"**Roles:** {', '.join(sorted(roles_seen)) if roles_seen else 'unknown'}"
    )

    # Show common message keys with presence indicators
    common_keys = [
        "role",
        "content",
        "tool_calls",
        "tool_call_id",
        "name",
        "function_call",
    ]
    key_status = []
    for key in common_keys:
        if key in message_keys:
            key_status.append(f"{key} ✓")
        else:
            key_status.append(f"{key} ✗")
    lines.append(f"**Message keys:** {', '.join(key_status)}")

    if has_tool_calls:
        lines.append("**Tool calls:** ✓ Present")
    if has_tool_results:
        lines.append("**Tool results:** ✓ Present")

    # Show example message structure
    # Priority: 1) message with tool_calls, 2) first assistant message, 3) first non-system message
    example = None
    fallback = None
    for msg in messages_data:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "")
        # Check for actual tool_calls/function_call values (not None)
        if msg.get("tool_calls") or msg.get("function_call"):
            example = msg
            break
        if role == "assistant" and example is None:
            example = msg
        elif role != "system" and fallback is None:
            fallback = msg
    if example is None:
        example = fallback

    if example:
        lines.append("")
        lines.append("**Example message structure:**")
        # Build a copy with truncated content but keep all keys
        example_clean = {}
        for key, val in example.items():
            if key == "content" and isinstance(val, str) and len(val) > 100:
                example_clean[key] = val[:100] + "..."
            else:
                example_clean[key] = val
        lines.append("```json")
        lines.append(json.dumps(example_clean, indent=2, ensure_ascii=False))
        lines.append("```")

    return "\n".join(lines)
352
+
353
+
354
+ def _format_parquet_files(data: dict, max_rows: int = 10) -> str | None:
355
+ """Format parquet file info, return None if no files."""
356
+ files = data.get("parquet_files", [])
357
+ if not files:
358
+ return None
359
+
360
+ # Group by config/split
361
+ groups: dict[str, dict] = {}
362
+ for f in files:
363
+ key = f"{f.get('config', 'default')}/{f.get('split', 'train')}"
364
+ if key not in groups:
365
+ groups[key] = {"count": 0, "size": 0}
366
+ size = f.get("size") or 0
367
+ if not isinstance(size, (int, float)):
368
+ size = 0
369
+ groups[key]["count"] += 1
370
+ groups[key]["size"] += int(size)
371
+
372
+ lines = ["## Files (Parquet)"]
373
+ items = list(groups.items())
374
+ total_groups = len(items)
375
+
376
+ shown = 0
377
+ for key, info in items[:max_rows]:
378
+ size_mb = info["size"] / (1024 * 1024)
379
+ lines.append(f"- {key}: {info['count']} file(s) ({size_mb:.1f} MB)")
380
+ shown += 1
381
+
382
+ if total_groups > shown:
383
+ lines.append(f"- ... (_showing {shown} of {total_groups} parquet groups_)")
384
+ return "\n".join(lines)
385
+
386
+
387
+ # Tool specification
388
# JSON-schema-style tool specification consumed by the agent's tool router.
# The description text is surfaced to the model verbatim, so it doubles as
# usage documentation for the hf_inspect_dataset tool.
HF_INSPECT_DATASET_TOOL_SPEC = {
    "name": "hf_inspect_dataset",
    "description": (
        "Inspect a Hugging Face dataset comprehensively in one call.\n\n"
        "## What you get\n"
        "- Status check (validates dataset works without errors)\n"
        "- All configs and splits (row counts/shares may be '?' when metadata is missing)\n"
        "- Column names and types (schema)\n"
        "- Sample rows to understand data format\n"
        "- Parquet file structure and sizes\n\n"
        "## CRITICAL\n"
        "**Always inspect datasets before writing training code** to understand:\n"
        "- Column names for your dataloader\n"
        "- Data types and format\n"
        "- Available splits (train/test/validation)\n\n"
        "Supports private/gated datasets when HF_TOKEN is set.\n\n"
        "## Examples\n"
        '{"dataset": "stanfordnlp/imdb"}\n'
        '{"dataset": "nyu-mll/glue", "config": "mrpc", "sample_rows": 5}\n'
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset ID in 'org/name' format (e.g., 'stanfordnlp/imdb')",
            },
            "config": {
                "type": "string",
                "description": "Config/subset name. Auto-detected if not specified.",
            },
            "split": {
                "type": "string",
                "description": "Split for sample rows. Auto-detected if not specified.",
            },
            "sample_rows": {
                "type": "integer",
                "description": "Number of sample rows to show (default: 3, max: 10)",
                "default": 3,
            },
        },
        "required": ["dataset"],
    },
}
432
+
433
+
434
async def hf_inspect_dataset_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Adapter between the agent tool router and inspect_dataset.

    Returns (formatted_report, success); any exception is converted into
    an error string with success=False.
    """
    try:
        # Clamp the requested sample count to the documented maximum of 10.
        requested_rows = arguments.get("sample_rows", 3)
        report = await inspect_dataset(
            dataset=arguments["dataset"],
            config=arguments.get("config"),
            split=arguments.get("split"),
            sample_rows=min(requested_rows, 10),
        )
        return report["formatted"], not report.get("isError", False)
    except Exception as exc:
        return f"Error inspecting dataset: {str(exc)}", False
agent/tools/docs_tools.py ADDED
@@ -0,0 +1,956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Documentation search tools for exploring HuggingFace and Gradio documentation.
3
+ """
4
+
5
+ import asyncio
6
+ import json
7
+ import os
8
+ from typing import Any
9
+
10
+ import httpx
11
+ from bs4 import BeautifulSoup
12
+ from whoosh.analysis import StemmingAnalyzer
13
+ from whoosh.fields import ID, TEXT, Schema
14
+ from whoosh.filedb.filestore import RamStorage
15
+ from whoosh.qparser import MultifieldParser, OrGroup
16
+
17
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default and hard ceiling for the number of doc pages returned per call.
DEFAULT_MAX_RESULTS = 20
MAX_RESULTS_CAP = 50

# Full Gradio documentation dump (plain text, no query needed).
GRADIO_LLMS_TXT_URL = "https://gradio.app/llms.txt"
# Gradio playground worker used for embedding search over guides/docs/demos.
GRADIO_SEARCH_URL = "https://playground-worker.pages.dev/api/prompt"

# Virtual endpoints that fan out to several real huggingface.co/docs
# endpoints; fetched results are merged and cached under the composite key.
COMPOSITE_ENDPOINTS: dict[str, list[str]] = {
    "optimum": [
        "optimum",
        "optimum-habana",
        "optimum-neuron",
        "optimum-intel",
        "optimum-executorch",
        "optimum-tpu",
    ],
    "courses": [
        "llm-course",
        "robotics-course",
        "mcp-course",
        "smol-course",
        "agents-course",
        "deep-rl-course",
        "computer-vision-course",
        "audio-course",
        "ml-games-course",
        "diffusion-course",
        "ml-for-3d-course",
        "cookbook",
    ],
}

# ---------------------------------------------------------------------------
# Caches (process-lifetime, guarded by _cache_lock where noted)
# ---------------------------------------------------------------------------

# endpoint -> fetched doc pages (also keyed by composite endpoint names)
_docs_cache: dict[str, list[dict[str, str]]] = {}
# endpoint -> (whoosh index, query parser), built lazily from fetched docs
_index_cache: dict[str, tuple[Any, MultifieldParser]] = {}
# Guards the module-level caches across concurrent handler invocations.
_cache_lock = asyncio.Lock()
# Raw OpenAPI spec JSON, fetched once per process.
_openapi_cache: dict[str, Any] | None = None
# (index, parser, endpoints) built once from the OpenAPI spec.
_openapi_index_cache: tuple[Any, MultifieldParser, list[dict[str, Any]]] | None = None
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # Gradio Documentation
64
+ # ---------------------------------------------------------------------------
65
+
66
+
67
async def _fetch_gradio_docs(query: str | None = None) -> str:
    """
    Fetch Gradio documentation.
    Without query: Get full documentation from llms.txt
    With query: Run embedding search on guides/demos for relevant content

    Raises httpx.HTTPStatusError on non-2xx responses (raise_for_status).
    """
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        if not query:
            resp = await client.get(GRADIO_LLMS_TXT_URL)
            resp.raise_for_status()
            return resp.text

        resp = await client.post(
            GRADIO_SEARCH_URL,
            headers={
                "Content-Type": "application/json",
                # NOTE(review): Origin header appears to be expected by the
                # playground worker — confirm before removing.
                "Origin": "https://gradio-docs-mcp.up.railway.app",
            },
            json={
                "prompt_to_embed": query,
                # Sentinel values interpreted server-side;
                # $INSERT_GUIDES_DOCS_DEMOS selects the corpus to search.
                "SYSTEM_PROMPT": "$INSERT_GUIDES_DOCS_DEMOS",
                "FALLBACK_PROMPT": "No results found",
            },
        )
        resp.raise_for_status()
        # The worker returns the assembled text under the SYS_PROMPT key.
        return resp.json().get("SYS_PROMPT", "No results found")
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # HF Documentation - Fetching
97
+ # ---------------------------------------------------------------------------
98
+
99
+
100
async def _fetch_endpoint_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
    """Fetch all docs for an endpoint by parsing sidebar and fetching each page.

    Scrapes https://huggingface.co/docs/<endpoint>, collects every link in
    the navigation sidebar, then fetches each page's markdown rendition
    (page URL + ".md") concurrently. Pages that fail to fetch keep their
    entry with empty content and an error glimpse instead of failing the
    whole endpoint.

    Raises:
        httpx.HTTPStatusError: if the landing page itself cannot be fetched.
        ValueError: if the sidebar or its links cannot be located.
    """
    url = f"https://huggingface.co/docs/{endpoint}"
    headers = {"Authorization": f"Bearer {hf_token}"}

    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        resp = await client.get(url, headers=headers)
        resp.raise_for_status()

        soup = BeautifulSoup(resp.text, "html.parser")
        # NOTE(review): sidebar detection keys off a "flex-auto" CSS class
        # in the docs layout — brittle if the site markup changes.
        sidebar = soup.find("nav", class_=lambda x: x and "flex-auto" in x)
        if not sidebar:
            raise ValueError(f"Could not find navigation sidebar for '{endpoint}'")

        nav_items = []
        for link in sidebar.find_all("a", href=True):
            href = link["href"]
            # Relative hrefs are rooted at huggingface.co.
            page_url = f"https://huggingface.co{href}" if href.startswith("/") else href
            nav_items.append({"title": link.get_text(strip=True), "url": page_url})

        if not nav_items:
            raise ValueError(f"No navigation links found for '{endpoint}'")

        async def fetch_page(item: dict[str, str]) -> dict[str, str]:
            # Each docs page is also served as raw markdown at <url>.md.
            md_url = f"{item['url']}.md"
            try:
                r = await client.get(md_url, headers=headers)
                r.raise_for_status()
                content = r.text.strip()
                glimpse = content[:200] + "..." if len(content) > 200 else content
            except Exception as e:
                # Best-effort: keep the page entry but mark it unfetchable.
                content, glimpse = "", f"[Could not fetch: {str(e)[:50]}]"
            return {
                "title": item["title"],
                "url": item["url"],
                "md_url": md_url,
                "glimpse": glimpse,
                "content": content,
                "section": endpoint,
            }

        # Fetch all pages concurrently over the same client connection pool.
        return list(await asyncio.gather(*[fetch_page(item) for item in nav_items]))
142
+
143
+
144
async def _get_docs(hf_token: str, endpoint: str) -> list[dict[str, str]]:
    """Get docs for endpoint with caching. Expands composite endpoints.

    Composite endpoints (see COMPOSITE_ENDPOINTS) fan out to several real
    endpoints; the merged result is cached under the composite key as well
    as under each sub-endpoint key.

    NOTE(review): the lock is released between the cache check and the
    network fetch, so two concurrent callers may fetch the same endpoint
    in parallel (last write wins) — redundant work, but not incorrect.
    """
    # Fast path: the (possibly composite) key is already cached.
    async with _cache_lock:
        if endpoint in _docs_cache:
            return _docs_cache[endpoint]

    # A plain endpoint expands to a one-element list of itself.
    sub_endpoints = COMPOSITE_ENDPOINTS.get(endpoint, [endpoint])
    all_docs: list[dict[str, str]] = []

    for sub in sub_endpoints:
        async with _cache_lock:
            if sub in _docs_cache:
                all_docs.extend(_docs_cache[sub])
                continue

        # Fetch outside the lock so network I/O is not serialized.
        docs = await _fetch_endpoint_docs(hf_token, sub)
        async with _cache_lock:
            _docs_cache[sub] = docs
        all_docs.extend(docs)

    # Also cache the combined result under the original request key.
    async with _cache_lock:
        _docs_cache[endpoint] = all_docs
    return all_docs
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # HF Documentation - Search
171
+ # ---------------------------------------------------------------------------
172
+
173
+
174
async def _build_search_index(
    endpoint: str, docs: list[dict[str, str]]
) -> tuple[Any, MultifieldParser]:
    """Build or retrieve cached Whoosh search index.

    The index lives entirely in RAM (RamStorage) and is cached per endpoint
    in _index_cache. `content` is indexed but not stored to limit memory;
    title matches are boosted 2x over body matches.
    """
    async with _cache_lock:
        if endpoint in _index_cache:
            return _index_cache[endpoint]

    # Stemming analyzer so e.g. "training" also matches "train".
    analyzer = StemmingAnalyzer()
    schema = Schema(
        title=TEXT(stored=True, analyzer=analyzer),
        url=ID(stored=True, unique=True),
        md_url=ID(stored=True),
        section=ID(stored=True),
        glimpse=TEXT(stored=True, analyzer=analyzer),
        content=TEXT(stored=False, analyzer=analyzer),
    )
    storage = RamStorage()
    index = storage.create_index(schema)
    writer = index.writer()
    for doc in docs:
        writer.add_document(
            title=doc.get("title", ""),
            url=doc.get("url", ""),
            md_url=doc.get("md_url", ""),
            section=doc.get("section", endpoint),
            glimpse=doc.get("glimpse", ""),
            content=doc.get("content", ""),
        )
    writer.commit()

    # OrGroup: query terms are OR-ed so partial matches still rank.
    parser = MultifieldParser(
        ["title", "content"],
        schema=schema,
        fieldboosts={"title": 2.0, "content": 1.0},
        group=OrGroup,
    )

    async with _cache_lock:
        _index_cache[endpoint] = (index, parser)
    return index, parser
215
+
216
+
217
async def _search_docs(
    endpoint: str, docs: list[dict[str, str]], query: str, limit: int
) -> tuple[list[dict[str, Any]], str | None]:
    """Search docs using Whoosh. Returns (results, fallback_message).

    fallback_message is non-None when the caller should fall back to the
    default (unsearched) ordering: either the query failed to parse or it
    produced no hits.
    """
    index, parser = await _build_search_index(endpoint, docs)

    try:
        query_obj = parser.parse(query)
    except Exception:
        return [], "Query contained unsupported syntax; showing default ordering."

    with index.searcher() as searcher:
        results = searcher.search(query_obj, limit=limit)
        # Materialize hits while the searcher is still open — hit objects
        # are tied to the searcher's lifetime.
        matches = [
            {
                "title": hit["title"],
                "url": hit["url"],
                "md_url": hit.get("md_url", ""),
                "section": hit.get("section", endpoint),
                "glimpse": hit["glimpse"],
                "score": round(hit.score, 2),
            }
            for hit in results
        ]

    if not matches:
        return [], "No strong matches found; showing default ordering."
    return matches, None
245
+
246
+
247
+ # ---------------------------------------------------------------------------
248
+ # HF Documentation - Formatting
249
+ # ---------------------------------------------------------------------------
250
+
251
+
252
+ def _format_results(
253
+ endpoint: str,
254
+ items: list[dict[str, Any]],
255
+ total: int,
256
+ query: str | None = None,
257
+ note: str | None = None,
258
+ ) -> str:
259
+ """Format search results as readable text."""
260
+ base_url = f"https://huggingface.co/docs/{endpoint}"
261
+ out = f"Documentation structure for: {base_url}\n\n"
262
+
263
+ if query:
264
+ out += f"Query: '{query}' → showing {len(items)} result(s) out of {total} pages"
265
+ if note:
266
+ out += f" ({note})"
267
+ out += "\n\n"
268
+ else:
269
+ out += f"Found {len(items)} page(s) (total available: {total}).\n"
270
+ if note:
271
+ out += f"({note})\n"
272
+ out += "\n"
273
+
274
+ for i, item in enumerate(items, 1):
275
+ out += f"{i}. **{item['title']}**\n"
276
+ out += f" URL: {item['url']}\n"
277
+ out += f" Section: {item.get('section', endpoint)}\n"
278
+ if query and "score" in item:
279
+ out += f" Relevance score: {item['score']:.2f}\n"
280
+ out += f" Glimpse: {item['glimpse']}\n\n"
281
+
282
+ return out
283
+
284
+
285
+ # ---------------------------------------------------------------------------
286
+ # Handlers
287
+ # ---------------------------------------------------------------------------
288
+
289
+
290
async def explore_hf_docs_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Explore documentation structure with optional search query.

    Returns (formatted_text, success). Gradio is special-cased to its own
    documentation API; every other endpoint is scraped from
    huggingface.co/docs (requires HF_TOKEN) and searched locally.
    """
    endpoint = arguments.get("endpoint", "").lstrip("/")
    query = arguments.get("query")
    max_results = arguments.get("max_results")

    if not endpoint:
        return "Error: No endpoint provided", False

    # Gradio uses its own API
    if endpoint.lower() == "gradio":
        try:
            clean_query = (
                query.strip() if isinstance(query, str) and query.strip() else None
            )
            content = await _fetch_gradio_docs(clean_query)
            header = "# Gradio Documentation\n\n"
            if clean_query:
                header += f"Query: '{clean_query}'\n\n"
            header += "Source: https://gradio.app/docs\n\n---\n\n"
            return header + content, True
        except httpx.HTTPStatusError as e:
            return f"HTTP error fetching Gradio docs: {e.response.status_code}", False
        except httpx.RequestError as e:
            return f"Request error fetching Gradio docs: {str(e)}", False
        except Exception as e:
            return f"Error fetching Gradio docs: {str(e)}", False

    # HF docs require an authenticated scrape.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return "Error: HF_TOKEN environment variable not set", False

    # Validate max_results before any network work.
    try:
        max_results_int = int(max_results) if max_results is not None else None
    except (TypeError, ValueError):
        return "Error: max_results must be an integer", False

    if max_results_int is not None and max_results_int <= 0:
        return "Error: max_results must be greater than zero", False

    try:
        docs = await _get_docs(hf_token, endpoint)
        total = len(docs)

        # Determine limit: default, capped, or as requested.
        if max_results_int is None:
            limit = DEFAULT_MAX_RESULTS
            limit_note = f"Showing top {DEFAULT_MAX_RESULTS} results (set max_results to adjust)."
        elif max_results_int > MAX_RESULTS_CAP:
            limit = MAX_RESULTS_CAP
            limit_note = f"Requested {max_results_int} but showing top {MAX_RESULTS_CAP} (maximum)."
        else:
            limit = max_results_int
            limit_note = None

        # Search when a non-blank query was given; otherwise just paginate.
        clean_query = (
            query.strip() if isinstance(query, str) and query.strip() else None
        )
        fallback_msg = None

        if clean_query:
            results, fallback_msg = await _search_docs(
                endpoint, docs, clean_query, limit
            )
            # Empty search results fall back to default ordering
            # (fallback_msg explains why).
            if not results:
                results = docs[:limit]
        else:
            results = docs[:limit]

        # Combine search-fallback and limit notes into one header note.
        notes = []
        if fallback_msg:
            notes.append(fallback_msg)
        if limit_note:
            notes.append(limit_note)
        note = "; ".join(notes) if notes else None

        return _format_results(endpoint, results, total, clean_query, note), True

    except httpx.HTTPStatusError as e:
        return f"HTTP error: {e.response.status_code} - {e.response.text[:200]}", False
    except httpx.RequestError as e:
        return f"Request error: {str(e)}", False
    except ValueError as e:
        return f"Error: {str(e)}", False
    except Exception as e:
        return f"Unexpected error: {str(e)}", False
379
+
380
+
381
async def hf_docs_fetch_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Fetch full markdown content of a documentation page.

    Accepts either a docs page URL or its .md URL (a ".md" suffix is
    appended when missing). Requires HF_TOKEN. Returns
    (content_or_error, success).
    """
    url = arguments.get("url", "")
    if not url:
        return "Error: No URL provided", False

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        return "Error: HF_TOKEN environment variable not set", False

    # Docs pages are served as raw markdown at <page-url>.md.
    if not url.endswith(".md"):
        url = f"{url}.md"

    try:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            resp = await client.get(
                url, headers={"Authorization": f"Bearer {hf_token}"}
            )
            resp.raise_for_status()
            return f"Documentation from: {url}\n\n{resp.text}", True
    except httpx.HTTPStatusError as e:
        return (
            f"HTTP error fetching {url}: {e.response.status_code} - {e.response.text[:200]}",
            False,
        )
    except httpx.RequestError as e:
        return f"Request error fetching {url}: {str(e)}", False
    except Exception as e:
        return f"Error fetching documentation: {str(e)}", False
410
+
411
+
412
+ # ---------------------------------------------------------------------------
413
+ # OpenAPI Search
414
+ # ---------------------------------------------------------------------------
415
+
416
+
417
async def _fetch_openapi_spec() -> dict[str, Any]:
    """Fetch and cache HuggingFace OpenAPI specification.

    Cached for the process lifetime in _openapi_cache. The first fetch is
    not lock-protected, so concurrent first callers may each fetch once
    (last write wins) — redundant but harmless.
    """
    global _openapi_cache
    if _openapi_cache is not None:
        return _openapi_cache

    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        resp = await client.get("https://huggingface.co/.well-known/openapi.json")
        resp.raise_for_status()

    _openapi_cache = resp.json()
    return _openapi_cache
429
+
430
+
431
+ def _extract_all_tags(spec: dict[str, Any]) -> list[str]:
432
+ """Extract all unique tags from OpenAPI spec."""
433
+ tags = set()
434
+ for tag_obj in spec.get("tags", []):
435
+ if "name" in tag_obj:
436
+ tags.add(tag_obj["name"])
437
+ for path_item in spec.get("paths", {}).values():
438
+ for method, op in path_item.items():
439
+ if method in ["get", "post", "put", "delete", "patch", "head", "options"]:
440
+ for tag in op.get("tags", []):
441
+ tags.add(tag)
442
+ return sorted(tags)
443
+
444
+
445
+ def _extract_all_endpoints(spec: dict[str, Any]) -> list[dict[str, Any]]:
446
+ """Extract all endpoints from OpenAPI spec."""
447
+ servers = spec.get("servers", [])
448
+ base_url = (
449
+ servers[0].get("url", "https://huggingface.co")
450
+ if servers
451
+ else "https://huggingface.co"
452
+ )
453
+
454
+ endpoints = []
455
+ for path, path_item in spec.get("paths", {}).items():
456
+ for method, op in path_item.items():
457
+ if method not in ["get", "post", "put", "delete", "patch", "head", "options"]:
458
+ continue
459
+ endpoints.append({
460
+ "path": path,
461
+ "method": method.upper(),
462
+ "operationId": op.get("operationId", ""),
463
+ "summary": op.get("summary", ""),
464
+ "description": op.get("description", ""),
465
+ "tags": " ".join(op.get("tags", [])),
466
+ "parameters": op.get("parameters", []),
467
+ "request_body": op.get("requestBody", {}),
468
+ "responses": op.get("responses", {}),
469
+ "base_url": base_url,
470
+ })
471
+ return endpoints
472
+
473
+
474
async def _build_openapi_index() -> tuple[Any, MultifieldParser, list[dict[str, Any]]]:
    """Build or retrieve cached Whoosh index for OpenAPI endpoints.

    Fetches the spec (cached), flattens it to endpoint records, and indexes
    the searchable fields in RAM. Summary/operationId/tags are boosted over
    description text. The full endpoint records are returned alongside the
    index so callers can recover non-indexed fields (parameters, bodies).
    """
    global _openapi_index_cache
    async with _cache_lock:
        if _openapi_index_cache is not None:
            return _openapi_index_cache

    spec = await _fetch_openapi_spec()
    endpoints = _extract_all_endpoints(spec)

    analyzer = StemmingAnalyzer()
    schema = Schema(
        path=ID(stored=True, unique=True),
        method=ID(stored=True),
        operationId=TEXT(stored=True, analyzer=analyzer),
        summary=TEXT(stored=True, analyzer=analyzer),
        description=TEXT(stored=True, analyzer=analyzer),
        tags=TEXT(stored=True, analyzer=analyzer),
        param_names=TEXT(stored=False, analyzer=analyzer),
    )
    storage = RamStorage()
    index = storage.create_index(schema)
    writer = index.writer()

    for ep in endpoints:
        # Parameter names are indexed (not stored) so queries like
        # "repo_id" can surface the endpoints that accept them.
        param_names = " ".join(p.get("name", "") for p in ep.get("parameters", []))
        writer.add_document(
            path=ep["path"],
            method=ep["method"],
            operationId=ep.get("operationId", ""),
            summary=ep.get("summary", ""),
            description=ep.get("description", ""),
            tags=ep.get("tags", ""),
            param_names=param_names,
        )
    writer.commit()

    parser = MultifieldParser(
        ["summary", "description", "operationId", "tags", "param_names"],
        schema=schema,
        fieldboosts={"summary": 3.0, "operationId": 2.0, "description": 1.0, "tags": 1.5},
        group=OrGroup,
    )

    async with _cache_lock:
        _openapi_index_cache = (index, parser, endpoints)
    return index, parser, endpoints
521
+
522
+
523
async def _search_openapi(
    query: str, tag: str | None, limit: int = 20
) -> tuple[list[dict[str, Any]], str | None]:
    """Search OpenAPI endpoints using Whoosh. Returns (results, fallback_message).

    When *tag* is given, hits whose joined tag string does not contain it
    (substring match) are dropped. fallback_message is non-None when no
    results could be produced.
    """
    index, parser, endpoints = await _build_openapi_index()

    try:
        query_obj = parser.parse(query)
    except Exception:
        return [], "Query contained unsupported syntax."

    with index.searcher() as searcher:
        # Over-fetch so tag filtering can still fill up to `limit` results.
        results = searcher.search(query_obj, limit=limit * 2)  # Get extra for tag filtering
        matches = []
        for hit in results:
            # Find full endpoint data — the index stores only searchable
            # fields, not parameters/bodies.
            ep = next((e for e in endpoints if e["path"] == hit["path"] and e["method"] == hit["method"]), None)
            if ep is None:
                continue
            # Filter by tag if provided
            if tag and tag not in ep.get("tags", ""):
                continue
            matches.append({**ep, "score": round(hit.score, 2)})
            if len(matches) >= limit:
                break

    return matches, None if matches else "No matches found for query."
550
+
551
+
552
+ def _generate_curl_example(endpoint: dict[str, Any]) -> str:
553
+ """Generate curl command example for an endpoint."""
554
+ method = endpoint["method"]
555
+ path = endpoint["path"]
556
+ base_url = endpoint["base_url"]
557
+
558
+ # Build URL with path parameters
559
+ full_path = path
560
+ for param in endpoint.get("parameters", []):
561
+ if param.get("in") == "path" and param.get("required"):
562
+ name = param["name"]
563
+ example = param.get(
564
+ "example", param.get("schema", {}).get("example", f"<{name}>")
565
+ )
566
+ full_path = full_path.replace(f"{{{name}}}", str(example))
567
+
568
+ curl = f"curl -X {method} \\\n '{base_url}{full_path}'"
569
+
570
+ # Add query parameters
571
+ query_params = [p for p in endpoint.get("parameters", []) if p.get("in") == "query"]
572
+ if query_params and query_params[0].get("required"):
573
+ param = query_params[0]
574
+ example = param.get("example", param.get("schema", {}).get("example", "value"))
575
+ curl += f"?{param['name']}={example}"
576
+
577
+ curl += " \\\n -H 'Authorization: Bearer $HF_TOKEN'"
578
+
579
+ # Add request body
580
+ if method in ["POST", "PUT", "PATCH"] and endpoint.get("request_body"):
581
+ content = endpoint["request_body"].get("content", {})
582
+ if "application/json" in content:
583
+ curl += " \\\n -H 'Content-Type: application/json'"
584
+ schema = content["application/json"].get("schema", {})
585
+ example = schema.get("example", "{}")
586
+ if isinstance(example, dict):
587
+ example = json.dumps(example, indent=2)
588
+ curl += f" \\\n -d '{example}'"
589
+
590
+ return curl
591
+
592
+
593
+ def _format_parameters(parameters: list[dict[str, Any]]) -> str:
594
+ """Format parameter information from OpenAPI spec."""
595
+ if not parameters:
596
+ return ""
597
+
598
+ path_params = [p for p in parameters if p.get("in") == "path"]
599
+ query_params = [p for p in parameters if p.get("in") == "query"]
600
+ header_params = [p for p in parameters if p.get("in") == "header"]
601
+
602
+ output = []
603
+
604
+ for label, params in [
605
+ ("Path Parameters", path_params),
606
+ ("Query Parameters", query_params),
607
+ ("Header Parameters", header_params),
608
+ ]:
609
+ if not params:
610
+ continue
611
+ if output:
612
+ output.append("")
613
+ output.append(f"**{label}:**")
614
+ for p in params:
615
+ name = p.get("name", "")
616
+ required = " (required)" if p.get("required") else " (optional)"
617
+ desc = p.get("description", "")
618
+ ptype = p.get("schema", {}).get("type", "string")
619
+ example = p.get("example") or p.get("schema", {}).get("example", "")
620
+
621
+ output.append(f"- `{name}` ({ptype}){required}: {desc}")
622
+ if example:
623
+ output.append(f" Example: `{example}`")
624
+
625
+ return "\n".join(output)
626
+
627
+
628
+ def _format_response_info(responses: dict[str, Any]) -> str:
629
+ """Format response information from OpenAPI spec."""
630
+ if not responses:
631
+ return "No response information available"
632
+
633
+ output = []
634
+ for status, resp_obj in list(responses.items())[:3]:
635
+ desc = resp_obj.get("description", "")
636
+ output.append(f"- **{status}**: {desc}")
637
+ content = resp_obj.get("content", {})
638
+ if "application/json" in content:
639
+ schema = content["application/json"].get("schema", {})
640
+ if "type" in schema:
641
+ output.append(f" Returns: {schema.get('type', 'object')}")
642
+
643
+ return "\n".join(output)
644
+
645
+
646
def _format_openapi_results(
    results: list[dict[str, Any]],
    tag: str | None = None,
    query: str | None = None,
    note: str | None = None,
) -> str:
    """Format OpenAPI search results as markdown with curl examples.

    Args:
        results: Endpoint dicts (optionally carrying a Whoosh "score").
        tag: Tag filter used for the search, if any (shown in the header).
        query: Keyword query used for the search, if any.
        note: Extra annotation appended to the result count (e.g. fallback info).

    Returns:
        A markdown document, or a descriptive "no results" message.
    """
    if not results:
        if query and tag:
            return f"No API endpoints found matching '{query}' in tag '{tag}'"
        elif query:
            return f"No API endpoints found matching '{query}'"
        elif tag:
            return f"No API endpoints found with tag '{tag}'"
        return "No API endpoints found"

    # Build header
    if query and tag:
        out = f"# API Endpoints matching '{query}' (tag: `{tag}`)\n\n"
    elif query:
        out = f"# API Endpoints matching '{query}'\n\n"
    elif tag:
        out = f"# API Endpoints for tag: `{tag}`\n\n"
    else:
        out = "# API Endpoints\n\n"

    out += f"Found {len(results)} endpoint(s)"
    if note:
        out += f" ({note})"
    out += "\n\n---\n\n"

    for i, ep in enumerate(results, 1):
        out += f"## {i}. {ep['method']} {ep['path']}\n\n"

        # Relevance score is only meaningful for keyword searches.
        if query and "score" in ep:
            out += f"**Relevance:** {ep['score']:.2f}\n\n"

        if ep.get("summary"):
            out += f"**Summary:** {ep['summary']}\n\n"

        if ep.get("description"):
            desc = ep["description"][:300]
            if len(ep["description"]) > 300:
                desc += "..."
            out += f"**Description:** {desc}\n\n"

        if ep.get("tags"):
            out += f"**Tags:** {ep['tags']}\n\n"

        params_info = _format_parameters(ep.get("parameters", []))
        if params_info:
            out += params_info + "\n\n"

        out += "**Usage:**\n```bash\n"
        out += _generate_curl_example(ep)
        out += "\n```\n\n"

        out += "**Returns:**\n"
        # Robustness fix: tolerate endpoints without a "responses" entry.
        # Previously ep["responses"] raised KeyError and aborted formatting
        # of the whole result list; _format_response_info handles {} cleanly.
        out += _format_response_info(ep.get("responses", {}))
        out += "\n\n---\n\n"

    return out
708
+
709
+
710
async def search_openapi_handler(arguments: dict[str, Any]) -> tuple[str, bool]:
    """Search the HuggingFace OpenAPI specification by keyword and/or tag.

    Returns ``(formatted_markdown, success)``. Keyword search runs first
    when a query is given; if it yields nothing and a tag was also
    supplied, falls back to listing every endpoint under that tag.
    """
    tag = arguments.get("tag", "").strip() or None
    query = arguments.get("query", "").strip() or None

    if not (tag or query):
        return (
            "Error: Provide either 'query' (keyword search) or 'tag' (category filter), or both.",
            False,
        )

    try:
        note: str | None = None

        if query:
            results, search_note = await _search_openapi(query, tag, limit=20)
            if results:
                # Keyword search succeeded — done.
                return _format_openapi_results(results, tag=tag, query=query, note=search_note), True
            if not tag:
                # Nothing found and no tag to fall back to.
                return _format_openapi_results([], query=query), True
            # Fall through to the tag listing below with an explanatory note.
            note = f"No matches for '{query}'; showing all endpoints in tag '{tag}'"

        if tag:
            _, _, endpoints = await _build_openapi_index()
            tagged = [ep for ep in endpoints if tag in ep.get("tags", "")]
            return _format_openapi_results(tagged, tag=tag, query=None, note=note), True

        return "Error: No results found", False

    except httpx.HTTPStatusError as e:
        return f"HTTP error fetching OpenAPI spec: {e.response.status_code}", False
    except httpx.RequestError as e:
        return f"Request error: {str(e)}", False
    except Exception as e:
        return f"Error searching OpenAPI spec: {str(e)}", False
750
+
751
+
752
async def _get_api_search_tool_spec() -> dict[str, Any]:
    """Generate OpenAPI tool spec with tags populated at runtime.

    Fetches the live OpenAPI spec so the 'tag' enum always reflects the
    categories currently published by the Hub API rather than a stale,
    hard-coded list.
    """
    spec = await _fetch_openapi_spec()
    tags = _extract_all_tags(spec)

    return {
        "name": "find_hf_api",
        "description": (
            "Find HuggingFace Hub REST API endpoints to make HTTP requests. Returns curl examples with authentication. "
            "⚠️ USE THIS TOOL when you need to call the HF Hub API directly - for operations like: "
            "uploading/downloading files, managing repos, listing models/datasets, getting user info, "
            "managing webhooks, collections, discussions, or any Hub interaction not covered by other tools. "
            "**Use cases:** (1) 'Stream Space logs' → query='space logs', "
            "(2) 'Get Space metrics/Zero-GPU usage' → query='space metrics', "
            "(3) 'List organization members' → query='organization members', "
            "(4) 'Generate repo access token' → query='jwt token', "
            "(5) 'Check repo security scan' → query='security scan'. "
            "**Search modes:** Use 'query' for keyword search, 'tag' to browse a category, or both. "
            "If query finds no results, falls back to showing all endpoints in the tag. "
            "**Output:** Full endpoint details with method, path, parameters, curl command, and response schema."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": (
                        "Keyword search across endpoint summaries, descriptions, and operation IDs. "
                        "Examples: 'upload file', 'create repository', 'list user models', 'delete branch', "
                        "'webhook', 'collection', 'discussion comments'. Supports stemming (upload/uploading both work)."
                    ),
                },
                "tag": {
                    "type": "string",
                    # Enum is built from the live spec above, not hard-coded.
                    "enum": tags,
                    "description": (
                        "Filter by API category. Use alone to browse all endpoints in a category, "
                        "or combine with 'query' to search within a category."
                    ),
                },
            },
            # Both parameters are optional at the schema level; the handler
            # rejects calls that supply neither.
            "required": [],
        },
    }
796
+
797
+
798
# ---------------------------------------------------------------------------
# Tool Specifications
# ---------------------------------------------------------------------------

# Documentation sections accepted by explore_hf_docs' 'endpoint' parameter.
# Must stay in sync with the per-endpoint bullet list in the tool spec below.
DOC_ENDPOINTS = [
    "hub",
    "transformers",
    "diffusers",
    "datasets",
    "gradio",
    "trackio",
    "smolagents",
    "huggingface_hub",
    "huggingface.js",
    "transformers.js",
    "inference-providers",
    "inference-endpoints",
    "peft",
    "accelerate",
    "optimum",
    "tokenizers",
    "courses",
    "evaluate",
    "tasks",
    "dataset-viewer",
    "trl",
    "simulate",
    "sagemaker",
    "timm",
    "safetensors",
    "tgi",
    "setfit",
    "lerobot",
    "autotrain",
    "tei",
    "bitsandbytes",
    "sentence_transformers",
    "chat-ui",
    "leaderboards",
    "lighteval",
    "argilla",
    "distilabel",
    "microsoft-azure",
    "kernels",
    "google-cloud",
]

# Tool spec: discover documentation structure (sidebar navigation + previews)
# for one of the DOC_ENDPOINTS sections.
EXPLORE_HF_DOCS_TOOL_SPEC = {
    "name": "explore_hf_docs",
    "description": (
        "Explore Hugging Face documentation structure and discover available pages with 200-character previews. "
        "⚠️ MANDATORY: ALWAYS use this BEFORE implementing any ML task (training, fine-tuning, data processing, inference). "
        "Your training data may be outdated - current documentation is the source of truth. "
        "**Use when:** (1) Starting any implementation task, (2) User asks 'how to' questions, "
        "(3) Before writing training/processing code, (4) Researching library capabilities, "
        "(5) Verifying API syntax and parameters. "
        "**Pattern:** explore (discover structure) → fetch_hf_docs (get details) → implement with researched approach. "
        "Returns: Sidebar navigation with titles, URLs, and glimpses of all pages in the selected documentation. "
        "**Then:** Use fetch_hf_docs with specific URLs from results to get full content. "
        "**Critical for reliability:** Never implement based on internal knowledge without checking current docs first - APIs change frequently."
        " By default returns the top 20 results; set max_results (max 50) to adjust."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "endpoint": {
                "type": "string",
                "enum": DOC_ENDPOINTS,
                "description": (
                    "The documentation endpoint to explore. Each endpoint corresponds to a major section of the Hugging Face documentation:\n\n"
                    "• courses — All Hugging Face courses (LLM, robotics, MCP, smol (llm training), agents, deep RL, computer vision, games, diffusion, 3D, audio) and the cookbook recipes. Probably the best place for examples.\n"
                    "• hub — Find answers to questions about models/datasets/spaces, auth, versioning, metadata.\n"
                    "• transformers — Core model library: architectures, configs, tokenizers, training & inference APIs.\n"
                    "• diffusers — Diffusion pipelines, schedulers, fine-tuning, training, and deployment patterns.\n"
                    "• datasets — Dataset loading, streaming, processing, Arrow format, Hub integration.\n"
                    "• gradio — UI components and demos for ML models. Uses Gradio's native API: without query returns full docs (llms.txt), with query uses embedding search for precise results.\n"
                    "• trackio — Experiment tracking, metrics logging, and run comparison.\n"
                    "• smolagents — Lightweight agent abstractions and tool-using patterns.\n"
                    "• huggingface_hub — Python client for Hub operations (auth, upload/download, repo management).\n"
                    "• huggingface.js — JS/TS client for Hub APIs in browser and Node.\n"
                    "• transformers.js — Run Transformer models in browser/Node via WebGPU/WASM.\n"
                    "• inference-providers — Unified interface for third-party inference backends.\n"
                    "• inference-endpoints — Managed, scalable model deployments on HF infrastructure.\n"
                    "• peft — Parameter-efficient fine-tuning methods (LoRA, adapters, etc.).\n"
                    "• accelerate — Hardware-agnostic, distributed and mixed-precision training orchestration.\n"
                    "• optimum — Hardware-aware optimization and model export tooling, including Habana, Neuron, Intel, ExecuTorch, and TPU variants.\n"
                    "• tokenizers — Fast tokenizer internals, training, and low-level APIs.\n"
                    "• evaluate — Metrics, evaluation workflows, and training-loop integration.\n"
                    "• tasks — Canonical task definitions and model categorization.\n"
                    "• dataset-viewer — Dataset preview, streaming views, and viewer internals.\n"
                    "• trl — RLHF, DPO, PPO, and SFT utilities for LLMs.\n"
                    "• simulate — Experimental simulation tools and workflows.\n"
                    "• sagemaker — Deploying Hugging Face models on AWS SageMaker.\n"
                    "• timm — Image model zoo and utilities via HF integrations.\n"
                    "• safetensors — Safe, fast tensor serialization format.\n"
                    "• tgi — High-throughput text generation server for LLMs.\n"
                    "• setfit — Few-shot text classification via sentence embeddings.\n"
                    "• lerobot — Robotics datasets, policies, and learning workflows.\n"
                    "• autotrain — No/low-code model training on Hugging Face.\n"
                    "• tei — Optimized inference server for embedding workloads.\n"
                    "• bitsandbytes — Quantization and memory-efficient optimizers.\n"
                    "• sentence_transformers — Embedding models, training recipes, similarity/search workflows.\n"
                    "• chat-ui — Reference chat interfaces for LLM deployment.\n"
                    "• leaderboards — Evaluation leaderboards and submission mechanics.\n"
                    "• lighteval — Lightweight, reproducible LLM evaluation framework.\n"
                    "• argilla — Data annotation, feedback, and human-in-the-loop workflows.\n"
                    "• distilabel — Synthetic data generation and distillation pipelines.\n"
                    "• microsoft-azure — Azure deployment and integration guides.\n"
                    "• kernels — Lightweight execution environments and notebook-style workflows.\n"
                    "• google-cloud — GCP deployment and serving workflows.\n"
                ),
            },
            "query": {
                "type": "string",
                "description": (
                    "Optional keyword query to rank and filter documentation pages. "
                    "For Gradio, use concise queries like 'how to use the image component' or 'audio component demo'."
                ),
            },
            "max_results": {
                "type": "integer",
                "description": "Max results (default 20, max 50). Ignored for Gradio.",
                "minimum": 1,
                "maximum": 50,
            },
        },
        "required": ["endpoint"],
    },
}

# Tool spec: fetch one documentation page (full markdown) by URL, typically
# a URL discovered through explore_hf_docs.
HF_DOCS_FETCH_TOOL_SPEC = {
    "name": "fetch_hf_docs",
    "description": (
        "Fetch full markdown content of a specific HF documentation page. "
        "⚠️ CRITICAL: Use this after explore_hf_docs to get detailed implementation guidance. "
        "**Use when:** (1) Found relevant page in explore_hf_docs results, (2) Need complete API documentation, "
        "(3) Need training method details (SFT/DPO/GRPO), (4) Need configuration examples, "
        "(5) Need parameter descriptions and usage patterns. "
        "**Pattern:** explore_hf_docs (find relevant page) → fetch_hf_docs (get full content) → implement using documented approach. "
        "Provide full URL from explore_hf_docs results (e.g., 'https://huggingface.co/docs/trl/sft_trainer'). "
        "Returns: Complete markdown documentation with examples, parameters, and usage patterns. "
        "**For training tasks:** ALWAYS fetch trainer docs (SFTConfig, DPOConfig, etc.) before creating training scripts. "
        "**Critical for reliability:** This ensures you use current APIs and best practices."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": (
                    "The full URL to the documentation page. "
                    "Example: 'https://huggingface.co/docs/trl/dpo_trainer' "
                    "The .md extension will be added automatically if not present."
                ),
            },
        },
        "required": ["url"],
    },
}
agent/tools/github_find_examples.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GitHub Find Examples Tool - Discover examples, tutorials, and guides for any library
3
+
4
+ Lists all files in a repository and performs deterministic keyword search.
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Dict, List
9
+
10
+ import requests
11
+ from thefuzz import fuzz
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
# In order of priority (lower index = higher priority for sorting)
# Directory-name patterns that mark a file as example/tutorial content.
# _score_against_example_patterns fuzzy-matches paths against these, and
# _get_pattern_priority uses the list index as a tie-breaker (lower wins).
EXAMPLE_PATTERNS = [
    "scripts",
    # General example patterns (catch-all, lower priority)
    "examples",
    "example",
    # Notebook patterns
    "notebooks",
    "notebook",
    # Tutorial/learning patterns
    "tutorials",
    "tutorial",
    "quickstart",
    "walkthroughs",
    "walkthrough",
    # Cookbook/recipe patterns
    "cookbook",
    "cookbooks",
    "recipes",
    "recipe",
    # Demo/sample patterns
    "demos",
    "demo",
    "samples",
    "sample",
    # Other patterns
    "guides",
    "guide",
    "getting-started",
    "getting_started",
    "playground",
    "howto",
    "how-to",
    "use-cases",
    "usecases",
    "use_cases",
    "sandbox",
    "showcase",
]
54
+
55
+
56
def _get_repo_tree(org: str, repo: str, token: str) -> tuple[List[Dict[str, Any]], str]:
    """Recursively list every file (git blob) in a GitHub repository.

    Returns ``(files, error)``: on success ``error`` is "" and each file
    dict carries path, blob sha ("ref"), size, and an html URL on the
    default branch. On failure ``files`` is empty and ``error`` is
    "not_found" or a short description.
    """
    full_repo = f"{org}/{repo}"
    headers = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
        "Authorization": f"Bearer {token}",
    }

    # Resolve the default branch first; it names the tree fetched below.
    try:
        resp = requests.get(
            f"https://api.github.com/repos/{full_repo}", headers=headers, timeout=10
        )
        if resp.status_code == 404:
            return [], "not_found"
        if resp.status_code != 200:
            return [], f"API error: {resp.status_code}"
        default_branch = resp.json().get("default_branch", "main")
    except Exception as e:
        return [], f"Error fetching repo: {str(e)}"

    # Fetch the whole tree in a single recursive call.
    try:
        resp = requests.get(
            f"https://api.github.com/repos/{full_repo}/git/trees/{default_branch}",
            headers=headers,
            params={"recursive": "1"},
            timeout=30,
        )
        if resp.status_code != 200:
            return [], f"Error fetching tree: {resp.status_code}"

        # Keep only blobs (files); tree entries are directories.
        blobs = [
            {
                "path": entry["path"],
                "ref": entry["sha"],
                "size": entry.get("size", 0),
                "url": f"https://github.com/{full_repo}/blob/{default_branch}/{entry['path']}",
            }
            for entry in resp.json().get("tree", [])
            if entry["type"] == "blob"
        ]
        return blobs, ""
    except Exception as e:
        return [], f"Error processing tree: {str(e)}"
110
+
111
+
112
def _search_similar_repos(org: str, repo: str, token: str) -> List[Dict[str, Any]]:
    """Find repositories in ``org`` with names similar to ``repo``.

    Uses the GitHub search API, sorted by stars (top 10). Returns a
    possibly-empty list of summary dicts; all failures are swallowed and
    yield an empty list, since this is a best-effort suggestion path.
    """
    headers = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
        "Authorization": f"Bearer {token}",
    }

    try:
        resp = requests.get(
            "https://api.github.com/search/repositories",
            headers=headers,
            params={
                "q": f"org:{org} {repo}",
                "sort": "stars",
                "order": "desc",
                "per_page": 10,
            },
            timeout=30,
        )
        if resp.status_code != 200:
            return []

        return [
            {
                "name": item.get("name"),
                "full_name": item.get("full_name"),
                "description": item.get("description"),
                "stars": item.get("stargazers_count", 0),
                "url": item.get("html_url"),
            }
            for item in resp.json().get("items", [])
        ]
    except Exception:
        return []
149
+
150
+
151
def _score_against_example_patterns(file_path: str) -> int:
    """Return the best fuzzy score (0-100) of ``file_path`` vs EXAMPLE_PATTERNS.

    Uses thefuzz's token_set_ratio against every known example-directory
    pattern and keeps the maximum. Idiom fix: single max() over a generator
    with default=0 instead of building an intermediate score list, with the
    loop-invariant path lowercasing hoisted out of the loop.
    """
    path = file_path.lower()
    return max(
        (fuzz.token_set_ratio(pattern.lower(), path) for pattern in EXAMPLE_PATTERNS),
        default=0,  # preserved fallback for an empty pattern list
    )
158
+
159
+
160
def _score_against_keyword(file_path: str, keyword: str) -> int:
    """Fuzzy-match a file path against a keyword, returning a 0-100 score.

    Combines partial_ratio (substring-style matching, good for paths) with
    token_set_ratio (word-level matching) and keeps whichever is higher.
    """
    needle = keyword.lower()
    haystack = file_path.lower()
    return max(
        fuzz.partial_ratio(needle, haystack),
        fuzz.token_set_ratio(needle, haystack),
    )
169
+
170
+
171
def _get_pattern_priority(file_path: str) -> tuple[int, int, int]:
    """
    Get priority of a file path based on which example pattern directory it's in.

    Returns: (in_examples_dir, pattern_priority, path_depth)
    - in_examples_dir: 0 if in examples/ directory, 1 otherwise (lower is better)
    - pattern_priority: Index in EXAMPLE_PATTERNS (lower is better), or 999 if no match
    - path_depth: Number of path segments (lower is better)

    Note: Prioritizes files in "examples/" directory first, then by most specific pattern match.
    E.g., "examples/scripts/train.py" is better than "scripts/util.py"

    The tuple is designed to be used directly as a sort key (ascending).
    """
    path_lower = file_path.lower()
    path_parts = path_lower.split("/")

    # Check if file is in examples/ directory (highest priority);
    # only the FIRST path component is considered here.
    in_examples_dir = 0 if (path_parts[0] in ["examples", "example"]) else 1

    # Find ALL matching patterns and use the best (lowest index) one
    # But prefer deeper matches (more specific) over shallow ones
    best_priority = 999  # sentinel: "no pattern matched anywhere in the path"
    best_depth_at_match = -1

    for i, pattern in enumerate(EXAMPLE_PATTERNS):
        # Check if pattern appears as a directory component in the path
        # (exact component match, not substring).
        if pattern in path_parts:
            # Find the depth where this pattern appears (rightmost occurrence)
            depth = len(path_parts) - 1 - path_parts[::-1].index(pattern)

            # Prefer deeper matches, or better priority if at same depth
            if depth > best_depth_at_match or (
                depth == best_depth_at_match and i < best_priority
            ):
                best_priority = i
                best_depth_at_match = depth

    return (in_examples_dir, best_priority, len(path_parts))
208
+
209
+
210
def _handle_repo_tree_errors(
    all_files: List[Dict[str, Any]],
    error: str,
    org: str,
    repo: str,
    token: str,
) -> ToolResult | None:
    """Convert a repo-tree fetch failure into a user-facing ToolResult.

    Returns None when ``all_files`` is usable. Otherwise:
    - error == "not_found": suggests similarly-named repos in the same org,
    - any other error string: reports it verbatim,
    - empty file list with no error: reports an empty repository.
    """
    if error == "not_found":
        suggestions = _search_similar_repos(org, repo, token)

        if not suggestions:
            return {
                "formatted": f"Repository '{org}/{repo}' not found and no similar repositories found.",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        # Render the suggestion list as markdown.
        lines = [f"**Repository '{org}/{repo}' not found. Similar repositories:**\n"]
        for idx, candidate in enumerate(suggestions, 1):
            lines.append(f"{idx}. **{candidate['full_name']}** (⭐ {candidate['stars']:,} stars)")
            description = candidate["description"]
            if description:
                if len(description) > 100:
                    description = description[:100] + "..."
                lines.append(f" {description}")
            lines.append(f" {candidate['url']}\n")

        return {
            "formatted": "\n".join(lines),
            "totalResults": len(suggestions),
            "resultsShared": len(suggestions),
            "isError": True,
        }

    if error:
        return {
            "formatted": f"Error accessing repository '{org}/{repo}': {error}",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }

    if not all_files:
        return {
            "formatted": f"No files found in repository '{org}/{repo}'",
            "totalResults": 0,
            "resultsShared": 0,
        }

    return None
265
+
266
+
267
def find_examples(
    keyword: str = "",
    repo: str = "",
    org: str = "huggingface",
    max_results: int = 50,
    min_score: int = 60,
) -> ToolResult:
    """
    Find example files in a repository using fuzzy matching.

    Args:
        keyword: Keyword to fuzzy match against file paths (e.g., "grpo")
        repo: Repository name (e.g., "trl"). Required.
        org: GitHub organization (default: "huggingface")
        max_results: Maximum number of results (default 50)
        min_score: Minimum fuzzy match score (0-100, default 60)

    Returns:
        ToolResult with matching files, or similar repos if repo not found

    Note:
        Consistency fix: the defaults were 10/80 while this docstring, the
        tool spec, and github_find_examples_handler all advertise 50/60;
        the signature now matches the documented contract.
    """
    token = os.environ.get("GITHUB_TOKEN")
    if not token:
        return {
            "formatted": "Error: GITHUB_TOKEN environment variable is required",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }

    if not repo:
        return {
            "formatted": "Error: repo parameter is required",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }

    # Get all files in the repository
    all_files, error = _get_repo_tree(org, repo, token)

    # Handle errors (not found, API errors, empty repo)
    if error_result := _handle_repo_tree_errors(all_files, error, org, repo, token):
        return error_result

    # Step 1: Filter files by example patterns (score >= 60).
    # This threshold is intentionally independent of min_score, which only
    # applies to the user-supplied keyword below.
    example_threshold = 60
    example_files = []
    for file in all_files:
        example_score = _score_against_example_patterns(file["path"])
        if example_score >= example_threshold:
            example_files.append({**file, "example_score": example_score})

    if not example_files:
        return {
            "formatted": f"No example files found in {org}/{repo} (no files match example patterns with score >= {example_threshold}).",
            "totalResults": 0,
            "resultsShared": 0,
        }

    # Step 2: If keyword provided, score and filter by keyword
    if keyword:
        scored_files = []
        for file in example_files:
            keyword_score = _score_against_keyword(file["path"], keyword)
            if keyword_score >= min_score:
                scored_files.append({**file, "score": keyword_score})

        if not scored_files:
            return {
                "formatted": f"No files found in {org}/{repo} matching keyword '{keyword}' (min score: {min_score}) among {len(example_files)} example files.",
                "totalResults": 0,
                "resultsShared": 0,
            }

        # Sort by keyword score (descending) for best matches first
        scored_files.sort(key=lambda x: x["score"], reverse=True)
    else:
        # No keyword: prioritize by pattern directory, then path depth
        scored_files = []
        for file in example_files:
            in_examples_dir, pattern_priority, path_depth = _get_pattern_priority(
                file["path"]
            )
            scored_files.append(
                {
                    **file,
                    "score": file["example_score"],
                    "in_examples_dir": in_examples_dir,
                    "pattern_priority": pattern_priority,
                    "path_depth": path_depth,
                }
            )

        # NOTE: unreachable in practice (example_files is non-empty here);
        # kept for defensive parity with the keyword branch.
        if not scored_files:
            return {
                "formatted": f"No example files found in {org}/{repo}.",
                "totalResults": 0,
                "resultsShared": 0,
            }

        # Sort by: 1) files in examples/ dir first, 2) pattern priority (scripts > datasets > etc), 3) path depth, 4) path name
        scored_files.sort(
            key=lambda x: (
                x["in_examples_dir"],
                x["pattern_priority"],
                x["path_depth"],
                x["path"],
            )
        )

    # Limit results
    results = scored_files[:max_results]

    # Format output
    keyword_desc = f" matching '{keyword}'" if keyword else ""
    lines = [f"**Found {len(results)} example files in {org}/{repo}{keyword_desc}:**"]
    if len(scored_files) > max_results:
        lines[0] += f" (showing {max_results} of {len(scored_files)})"
    lines.append("")

    for i, file in enumerate(results, 1):
        lines.append(f"{i}. **{file['path']}**")
        lines.append(f" Size: {file['size']:,} bytes | Ref: {file['ref'][:7]}")
        lines.append(f" URL: {file['url']}")

        # Copyable parameters for read_file tool
        read_params = f"{{'repo': '{org}/{repo}', 'path': '{file['path']}'}}"
        lines.append(f" To read, use: {read_params}")
        lines.append("")

    return {
        "formatted": "\n".join(lines),
        "totalResults": len(results),
        "resultsShared": len(results),
    }
402
+
403
+
404
# Tool specification
# Exposed to the agent's tool router; parameter defaults here (50 / 60)
# must stay in sync with github_find_examples_handler below.
GITHUB_FIND_EXAMPLES_TOOL_SPEC = {
    "name": "github_find_examples",
    "description": (
        "Discover working code examples, tutorials, scripts, and demos in GitHub repositories. "
        "⚠️ CRITICAL: ALWAYS use this BEFORE implementing ML tasks - find working reference code first. "
        "Your training data may be outdated; real repository examples show current best practices. "
        "**Use when:** (1) Starting any ML implementation (training, inference, evaluation), "
        "(2) User asks 'how to' questions about libraries, (3) Need reference implementations, "
        "(4) Exploring library capabilities, (5) Before writing training/processing scripts. "
        "**Pattern:** github_find_examples (discover) → github_read_file (study code) → implement with researched approach. "
        "Returns: List of example files (scripts/notebooks/tutorials) with paths and URLs, sorted by relevance. "
        "**Then:** Use github_read_file to read the actual implementation code. "
        "**Critical for reliability:** Real examples prevent outdated API usage and show proven patterns. "
        "## How it works\n\n"
        "1. Fetches all example files (examples/, scripts/, tutorials/, demos/, notebooks/, etc.) from repository\n"
        "2. If keyword provided, scores files against keyword using fuzzy matching\n"
        "3. Returns best matches sorted by relevance and pattern priority\n"
        "4. Provides copyable parameters for github_read_file tool\n\n"
        "## Examples\n\n"
        "<example>\n"
        "// ML Workflow Step: Find GRPO training examples before implementation\n"
        "// Task: Starting GRPO fine-tuning project, need reference implementation\n"
        "{\n"
        " keyword: 'grpo',\n"
        " repo: 'trl',\n"
        " org: 'huggingface'\n"
        "}\n"
        "// Returns: examples/scripts/grpo_agent.py, examples/scripts/grpo_vlm.py\n"
        "// Next step: github_read_file to study working implementation\n"
        "</example>\n\n"
        "<example>\n"
        "// ML Workflow Step: Discover all available training methods\n"
        "// Task: Exploring TRL training options before choosing approach\n"
        "{\n"
        " repo: 'trl',\n"
        " org: 'huggingface',\n"
        " max_results: 20\n"
        "}\n"
        "// Lists: SFT, DPO, GRPO, PPO, reward modeling examples\n"
        "// Helps user choose appropriate method\n"
        "</example>\n\n"
        "<example>\n"
        "// ML Workflow Step: Find LoRA fine-tuning examples\n"
        "// Task: Learning parameter-efficient fine-tuning patterns\n"
        "{\n"
        " keyword: 'lora',\n"
        " repo: 'peft',\n"
        " org: 'huggingface'\n"
        "}\n"
        "// Discovers LoRA configuration and training examples\n"
        "// Shows current PEFT API usage patterns\n"
        "</example>"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "keyword": {
                "type": "string",
                "description": "Keyword to fuzzy match against file paths (e.g., 'grpo', 'sft').",
            },
            "repo": {
                "type": "string",
                "description": "Repository name (e.g., 'trl', 'transformers'). Required.",
            },
            "org": {
                "type": "string",
                "description": "GitHub organization or username. Default: 'huggingface'.",
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return. Default: 50.",
            },
            "min_score": {
                "type": "integer",
                "description": "Minimum fuzzy match score (0-100). Default: 60.",
            },
        },
        "required": ["repo"],
    },
}
485
+
486
+
487
async def github_find_examples_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapter between the agent tool router and find_examples().

    Returns ``(formatted_text, success)``; success is False when the result
    is flagged as an error. Any unexpected exception (including a missing
    'repo' argument) is reported as an error string rather than raised.
    """
    try:
        outcome = find_examples(
            keyword=arguments.get("keyword", ""),
            repo=arguments["repo"],
            org=arguments.get("org", "huggingface"),
            max_results=arguments.get("max_results", 50),
            min_score=arguments.get("min_score", 60),
        )
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as e:
        return f"Error finding examples: {str(e)}", False
agent/tools/github_list_repos.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GitHub List Repositories Tool - List and sort repositories for any user or organization
3
+
4
+ Efficiently discover repositories with flexible sorting options.
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ import requests
11
+
12
+ from agent.tools.types import ToolResult
13
+
14
+
15
+ def list_repos(
16
+ owner: str,
17
+ owner_type: Literal["user", "org"] = "org",
18
+ sort: Literal["stars", "forks", "updated", "created"] = "stars",
19
+ order: Literal["asc", "desc"] = "desc",
20
+ limit: Optional[int] = 30,
21
+ ) -> ToolResult:
22
+ """
23
+ List repositories for a user or organization using GitHub REST API.
24
+
25
+ Args:
26
+ owner: GitHub username or organization name
27
+ owner_type: Whether the owner is a "user" or "org" (default: "org")
28
+ sort: Sort field - "stars", "forks", "updated", or "created"
29
+ order: Sort order - "asc" or "desc" (default: "desc")
30
+ limit: Maximum number of repositories to return
31
+
32
+ Returns:
33
+ ToolResult with repository information
34
+ """
35
+ token = os.environ.get("GITHUB_TOKEN")
36
+ if not token:
37
+ return {
38
+ "formatted": "Error: GITHUB_TOKEN environment variable is required",
39
+ "totalResults": 0,
40
+ "resultsShared": 0,
41
+ "isError": True,
42
+ }
43
+
44
+ if owner_type == "org":
45
+ url = f"https://api.github.com/orgs/{owner}/repos"
46
+ else:
47
+ url = f"https://api.github.com/users/{owner}/repos"
48
+
49
+ headers = {
50
+ "Accept": "application/vnd.github+json",
51
+ "X-GitHub-Api-Version": "2022-11-28",
52
+ "Authorization": f"Bearer {token}",
53
+ }
54
+
55
+ all_repos = []
56
+ page = 1
57
+ per_page = 100 # Maximum allowed by GitHub
58
+
59
+ # Map our sort values to GitHub API sort values
60
+ # Note: GitHub list repos API doesn't support sorting by stars/forks
61
+ # We'll fetch all repos and sort in memory for those cases
62
+ api_sort_map = {
63
+ "created": "created",
64
+ "updated": "updated",
65
+ "stars": None, # Not supported by list API
66
+ "forks": None, # Not supported by list API
67
+ }
68
+
69
+ api_sort = api_sort_map.get(sort)
70
+ need_manual_sort = api_sort is None
71
+
72
+ try:
73
+ while True:
74
+ params = {
75
+ "page": page,
76
+ "per_page": per_page,
77
+ }
78
+
79
+ # Only add sort/direction if API supports it
80
+ if api_sort:
81
+ params["sort"] = api_sort
82
+ params["direction"] = order
83
+
84
+ response = requests.get(
85
+ url,
86
+ headers=headers,
87
+ params=params,
88
+ timeout=30,
89
+ )
90
+
91
+ if response.status_code == 403:
92
+ error_data = response.json()
93
+ return {
94
+ "formatted": f"GitHub API rate limit or permission error: {error_data.get('message', 'Unknown error')}",
95
+ "totalResults": 0,
96
+ "resultsShared": 0,
97
+ "isError": True,
98
+ }
99
+
100
+ if response.status_code != 200:
101
+ error_msg = f"GitHub API error (status {response.status_code})"
102
+ try:
103
+ error_data = response.json()
104
+ if "message" in error_data:
105
+ error_msg += f": {error_data['message']}"
106
+ except Exception:
107
+ pass
108
+ return {
109
+ "formatted": error_msg,
110
+ "totalResults": 0,
111
+ "resultsShared": 0,
112
+ "isError": True,
113
+ }
114
+
115
+ items = response.json()
116
+
117
+ if not items:
118
+ break
119
+
120
+ for item in items:
121
+ all_repos.append(
122
+ {
123
+ "name": item.get("name"),
124
+ "full_name": item.get("full_name"),
125
+ "description": item.get("description"),
126
+ "html_url": item.get("html_url"),
127
+ "language": item.get("language"),
128
+ "stars": item.get("stargazers_count", 0),
129
+ "forks": item.get("forks_count", 0),
130
+ "open_issues": item.get("open_issues_count", 0),
131
+ "topics": item.get("topics", []),
132
+ "updated_at": item.get("updated_at"),
133
+ "created_at": item.get("created_at"),
134
+ }
135
+ )
136
+
137
+ # Check if we got fewer results than requested (last page)
138
+ if len(items) < per_page:
139
+ break
140
+
141
+ # Stop if we have enough repos
142
+ if limit and len(all_repos) >= limit:
143
+ break
144
+
145
+ page += 1
146
+
147
+ except requests.exceptions.RequestException as e:
148
+ return {
149
+ "formatted": f"Failed to connect to GitHub API: {str(e)}",
150
+ "totalResults": 0,
151
+ "resultsShared": 0,
152
+ "isError": True,
153
+ }
154
+
155
+ # Manual sorting if needed (for stars/forks)
156
+ if need_manual_sort and all_repos:
157
+ reverse = order == "desc"
158
+ all_repos.sort(key=lambda x: x[sort], reverse=reverse)
159
+
160
+ # Apply limit after sorting
161
+ if limit:
162
+ all_repos = all_repos[:limit]
163
+
164
+ if not all_repos:
165
+ return {
166
+ "formatted": f"No repositories found for {owner_type} '{owner}'",
167
+ "totalResults": 0,
168
+ "resultsShared": 0,
169
+ }
170
+
171
+ # Format output
172
+ lines = [f"**Found {len(all_repos)} repositories for {owner}:**\n"]
173
+
174
+ for i, repo in enumerate(all_repos, 1):
175
+ lines.append(f"{i}. **{repo['full_name']}**")
176
+ lines.append(
177
+ f" ⭐ {repo['stars']:,} stars | 🍴 {repo['forks']:,} forks | Language: {repo['language'] or 'N/A'}"
178
+ )
179
+ if repo["description"]:
180
+ desc = (
181
+ repo["description"][:100] + "..."
182
+ if len(repo["description"]) > 100
183
+ else repo["description"]
184
+ )
185
+ lines.append(f" {desc}")
186
+ lines.append(f" URL: {repo['html_url']}")
187
+ if repo["topics"]:
188
+ lines.append(f" Topics: {', '.join(repo['topics'][:5])}")
189
+
190
+ # Copyable parameters for other tools
191
+ lines.append(f" Use in tools: {{'repo': '{repo['full_name']}'}}")
192
+ lines.append("")
193
+
194
+ return {
195
+ "formatted": "\n".join(lines),
196
+ "totalResults": len(all_repos),
197
+ "resultsShared": len(all_repos),
198
+ }
199
+
200
+
201
+ # Tool specification
202
+ GITHUB_LIST_REPOS_TOOL_SPEC = {
203
+ "name": "github_list_repos",
204
+ "description": (
205
+ "List and discover repositories for GitHub organizations or users with flexible sorting. "
206
+ "**Use when:** (1) Exploring what libraries exist for a task, (2) Finding the right library to use, "
207
+ "(3) Discovering popular or active projects, (4) Checking recently updated repos for latest features, "
208
+ "(5) Finding alternative libraries in an organization. "
209
+ "**Pattern:** github_list_repos (discover libraries) → github_find_examples (find usage examples) → implement. "
210
+ "Returns: Comprehensive repository information (stars, forks, language, topics, URLs), sorted by preference. "
211
+ "**Then:** Use github_find_examples on selected repo to discover example code. "
212
+ "Sorts by: stars (popularity), forks (community), updated (activity), created (age).\n\n"
213
+ "## When to use this tool\n\n"
214
+ "- When you need to find libraries to use in your implementation\n"
215
+ "- When exploring what repositories exist for a task or domain\n"
216
+ "- When debugging an error and looking up if others have similar issues in repos\n"
217
+ "- When finding the most popular or actively maintained projects for a user/org\n"
218
+ "## Examples\n\n"
219
+ "<example>\n"
220
+ "// ML Workflow Step: Discover HF libraries for RLHF/alignment\n"
221
+ "// Use case: Find the right library for training with human feedback\n"
222
+ "{\n"
223
+ " owner: 'huggingface',\n"
224
+ " owner_type: 'org',\n"
225
+ " sort: 'stars',\n"
226
+ " limit: 10\n"
227
+ "}\n"
228
+ "// Returns: transformers, trl, peft, accelerate, diffusers...\n"
229
+ "</example>\n\n"
230
+ "<example>\n"
231
+ "// ML Workflow Step: Check for recently updated HF repos\n"
232
+ "// Use case: Find actively maintained libraries with latest features\n"
233
+ "{\n"
234
+ " owner: 'huggingface',\n"
235
+ " owner_type: 'org',\n"
236
+ " sort: 'updated',\n"
237
+ " order: 'desc',\n"
238
+ " limit: 15\n"
239
+ "}\n"
240
+ "// Helps identify which repos have recent improvements/fixes\n"
241
+ "</example>"
242
+ ),
243
+ "parameters": {
244
+ "type": "object",
245
+ "properties": {
246
+ "owner": {
247
+ "type": "string",
248
+ "description": "GitHub username or organization name. Required.",
249
+ },
250
+ "owner_type": {
251
+ "type": "string",
252
+ "enum": ["user", "org"],
253
+ "description": "Whether the owner is a 'user' or 'org'. Default: 'org'.",
254
+ },
255
+ "sort": {
256
+ "type": "string",
257
+ "enum": ["stars", "forks", "updated", "created"],
258
+ "description": "Sort field. Options: 'stars', 'forks', 'updated', 'created'. Default: 'stars'.",
259
+ },
260
+ "order": {
261
+ "type": "string",
262
+ "enum": ["asc", "desc"],
263
+ "description": "Sort order. Options: 'asc', 'desc'. Default: 'desc'.",
264
+ },
265
+ "limit": {
266
+ "type": "integer",
267
+ "description": "Maximum number of repositories to return. No limit if not specified. Default: 30.",
268
+ },
269
+ },
270
+ "required": ["owner"],
271
+ },
272
+ }
273
+
274
+
275
+ async def github_list_repos_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
276
+ """Handler for agent tool router"""
277
+ try:
278
+ result = list_repos(
279
+ owner=arguments["owner"],
280
+ owner_type=arguments.get("owner_type", "org"),
281
+ sort=arguments.get("sort", "stars"),
282
+ order=arguments.get("order", "desc"),
283
+ limit=arguments.get("limit"),
284
+ )
285
+ return result["formatted"], not result.get("isError", False)
286
+ except Exception as e:
287
+ return f"Error listing repositories: {str(e)}", False
agent/tools/github_read_file.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GitHub Read File Tool - Read file contents from any GitHub repository with line range support
3
+
4
+ Fetch exact file contents with metadata, supporting line ranges for efficient reading.
5
+ """
6
+
7
+ import base64
8
+ import json
9
+ import os
10
+ from typing import Any, Dict, Optional
11
+
12
+ import nbformat
13
+ import requests
14
+ from nbconvert import MarkdownExporter
15
+ from nbconvert.preprocessors import ClearOutputPreprocessor, TagRemovePreprocessor
16
+
17
+ from agent.tools.types import ToolResult
18
+
19
+
20
+ def _convert_ipynb_to_markdown(content: str) -> str:
21
+ """
22
+ Convert Jupyter notebook JSON to LLM-friendly Markdown.
23
+
24
+ Args:
25
+ content: Raw notebook JSON string
26
+
27
+ Returns:
28
+ Converted Markdown string
29
+ """
30
+ try:
31
+ # Parse notebook JSON
32
+ nb_dict = json.loads(content)
33
+
34
+ # Normalize cell sources (can be string or list of strings)
35
+ if "cells" in nb_dict:
36
+ for cell in nb_dict["cells"]:
37
+ if "source" in cell and isinstance(cell["source"], list):
38
+ cell["source"] = "".join(cell["source"])
39
+
40
+ # Read notebook with explicit version
41
+ nb = nbformat.reads(json.dumps(nb_dict), as_version=4)
42
+
43
+ # Strip outputs for LLM readability (outputs can be noisy/large)
44
+ clear = ClearOutputPreprocessor()
45
+ nb, _ = clear.preprocess(nb, {})
46
+
47
+ # Optionally remove cells tagged with "hide" or similar
48
+ remove = TagRemovePreprocessor(
49
+ remove_cell_tags={"hide", "hidden", "remove"},
50
+ remove_input_tags=set(),
51
+ remove_all_outputs_tags=set(),
52
+ )
53
+ nb, _ = remove.preprocess(nb, {})
54
+
55
+ # Convert to markdown
56
+ exporter = MarkdownExporter()
57
+ markdown, _ = exporter.from_notebook_node(nb)
58
+
59
+ return markdown
60
+
61
+ except json.JSONDecodeError:
62
+ return content
63
+ except Exception:
64
+ return content
65
+
66
+
67
+ def read_file(
68
+ repo: str,
69
+ path: str,
70
+ ref: str = "HEAD",
71
+ line_start: Optional[int] = None,
72
+ line_end: Optional[int] = None,
73
+ ) -> ToolResult:
74
+ """
75
+ Read file contents from a GitHub repository with line range support.
76
+
77
+ Args:
78
+ repo: Repository in format "owner/repo" (e.g., "github/github-mcp-server")
79
+ path: Path to file in repository (e.g., "pkg/github/search.go")
80
+ ref: Git reference - branch name, tag, or commit SHA (default: "HEAD")
81
+ line_start: Starting line number (1-indexed, inclusive)
82
+ line_end: Ending line number (1-indexed, inclusive)
83
+
84
+ Returns:
85
+ ToolResult with file contents and metadata
86
+ """
87
+ token = os.environ.get("GITHUB_TOKEN")
88
+ if not token:
89
+ return {
90
+ "formatted": "Error: GITHUB_TOKEN environment variable is required",
91
+ "totalResults": 0,
92
+ "resultsShared": 0,
93
+ "isError": True,
94
+ }
95
+
96
+ # Parse repo
97
+ if "/" not in repo:
98
+ return {
99
+ "formatted": "Error: repo must be in format 'owner/repo'",
100
+ "totalResults": 0,
101
+ "resultsShared": 0,
102
+ "isError": True,
103
+ }
104
+
105
+ owner, repo_name = repo.split("/", 1)
106
+
107
+ headers = {
108
+ "Accept": "application/vnd.github+json",
109
+ "X-GitHub-Api-Version": "2022-11-28",
110
+ "Authorization": f"Bearer {token}",
111
+ }
112
+
113
+ # Fetch file contents
114
+ url = f"https://api.github.com/repos/{owner}/{repo_name}/contents/{path}"
115
+ params = {}
116
+ if ref and ref != "HEAD":
117
+ params["ref"] = ref
118
+
119
+ try:
120
+ response = requests.get(url, headers=headers, params=params, timeout=30)
121
+
122
+ if response.status_code == 404:
123
+ return {
124
+ "formatted": f"File not found: {path} in {repo} (ref: {ref})",
125
+ "totalResults": 0,
126
+ "resultsShared": 0,
127
+ "isError": True,
128
+ }
129
+
130
+ if response.status_code != 200:
131
+ error_msg = f"GitHub API error (status {response.status_code})"
132
+ try:
133
+ error_data = response.json()
134
+ if "message" in error_data:
135
+ error_msg += f": {error_data['message']}"
136
+ except Exception:
137
+ pass
138
+ return {
139
+ "formatted": error_msg,
140
+ "totalResults": 0,
141
+ "resultsShared": 0,
142
+ "isError": True,
143
+ }
144
+
145
+ data = response.json()
146
+
147
+ # Check if it's a file
148
+ if data.get("type") != "file":
149
+ return {
150
+ "formatted": f"Path {path} is not a file (type: {data.get('type')})",
151
+ "totalResults": 0,
152
+ "resultsShared": 0,
153
+ "isError": True,
154
+ }
155
+
156
+ # Decode content
157
+ content_b64 = data.get("content", "")
158
+ if content_b64:
159
+ content_b64 = content_b64.replace("\n", "").replace(" ", "")
160
+ content = base64.b64decode(content_b64).decode("utf-8", errors="replace")
161
+ else:
162
+ # For large files, fetch raw content
163
+ raw_headers = {
164
+ "Accept": "application/vnd.github.raw",
165
+ "X-GitHub-Api-Version": "2022-11-28",
166
+ "Authorization": f"Bearer {token}",
167
+ }
168
+ raw_response = requests.get(
169
+ url, headers=raw_headers, params=params, timeout=30
170
+ )
171
+ if raw_response.status_code != 200:
172
+ return {
173
+ "formatted": "Failed to fetch file content",
174
+ "totalResults": 0,
175
+ "resultsShared": 0,
176
+ "isError": True,
177
+ }
178
+ content = raw_response.text
179
+
180
+ if path.lower().endswith(".ipynb"):
181
+ content = _convert_ipynb_to_markdown(content)
182
+
183
+ # Process line ranges
184
+ lines = content.split("\n")
185
+ total_lines = len(lines)
186
+
187
+ truncated = False
188
+
189
+ if line_start is None and line_end is None:
190
+ # No range specified
191
+ if total_lines > 300:
192
+ line_start = 1
193
+ line_end = 300
194
+ truncated = True
195
+ else:
196
+ line_start = 1
197
+ line_end = total_lines
198
+ else:
199
+ # Range specified
200
+ if line_start is None:
201
+ line_start = 1
202
+ if line_end is None:
203
+ line_end = total_lines
204
+
205
+ # Validate range
206
+ line_start = max(1, line_start)
207
+ line_end = min(total_lines, line_end)
208
+ if line_start > line_end:
209
+ return {
210
+ "formatted": f"Invalid range: line_start ({line_start}) > line_end ({line_end})",
211
+ "totalResults": 0,
212
+ "resultsShared": 0,
213
+ "isError": True,
214
+ }
215
+
216
+ # Extract lines
217
+ selected_lines = lines[line_start - 1 : line_end]
218
+ selected_content = "\n".join(selected_lines)
219
+
220
+ # Format output
221
+ lines_output = [f"**Reading file from repo: {repo}, path: {path}**"]
222
+
223
+ if ref and ref != "HEAD":
224
+ lines_output.append(f"Ref: {ref}")
225
+
226
+ lines_output.append("\n**File content:")
227
+ lines_output.append("```")
228
+ lines_output.append(selected_content)
229
+ lines_output.append("```")
230
+ if truncated:
231
+ lines_output.append(
232
+ f"Currently showing lines {line_start}-{line_end} out of {total_lines} total lines. Use line_start and line_end to view more lines."
233
+ )
234
+ return {
235
+ "formatted": "\n".join(lines_output),
236
+ "totalResults": 1,
237
+ "resultsShared": 1,
238
+ }
239
+
240
+ except requests.exceptions.RequestException as e:
241
+ return {
242
+ "formatted": f"Failed to connect to GitHub API: {str(e)}",
243
+ "totalResults": 0,
244
+ "resultsShared": 0,
245
+ "isError": True,
246
+ }
247
+
248
+
249
+ # Tool specification
250
+ GITHUB_READ_FILE_TOOL_SPEC = {
251
+ "name": "github_read_file",
252
+ "description": (
253
+ "Read file contents from GitHub repositories with line range support (default 300 lines). "
254
+ "⚠️ CRITICAL: Use AFTER github_find_examples to study working implementation code. "
255
+ "**Use when:** (1) Found example file via github_find_examples and need full code, "
256
+ "(2) Need to read trainer class implementation, (3) Study configuration patterns, "
257
+ "(4) Read specific code sections with line ranges, (5) Review code from specific branches/commits. "
258
+ "**Pattern:** github_find_examples (discover files) → github_read_file (read code) → implement using researched patterns. "
259
+ "Returns: File contents with line numbers, formatted for LLM reading. Auto-converts Jupyter notebooks to markdown. "
260
+ "**Then:** Implement using patterns and APIs from the example code. "
261
+ "**Critical for reliability:** Reading working examples prevents API errors and shows current best practices. "
262
+ "Use line_start/line_end for large files (>300 lines) to read specific sections.\n\n"
263
+ "## When to use this tool\n\n"
264
+ "- When reading example code, trainer implementations, or configuration files\n"
265
+ "- After github_find_examples returns file paths you want to study\n"
266
+ "- When investigating specific code sections with line ranges\n"
267
+ "- When reading from specific branches, tags, or commits (use ref parameter)\n\n"
268
+ "## When NOT to use this tool\n\n"
269
+ "- When you don't know exact file path (use github_find_examples or github_search_code first)\n"
270
+ "- When searching for code patterns across repos (use github_search_code instead)\n\n"
271
+ "## Examples\n\n"
272
+ "<example>\n"
273
+ "// ML Workflow Step: Read GRPO trainer class after finding via github_find_examples\n"
274
+ "// Use case: Understand GRPOTrainer API, parameters, and methods\n"
275
+ "{\n"
276
+ " repo: 'huggingface/trl',\n"
277
+ " path: 'trl/trainer/grpo_trainer.py',\n"
278
+ " line_start: 1,\n"
279
+ " line_end: 200\n"
280
+ "}\n"
281
+ "// Read class definition and constructor to understand current API\n"
282
+ "// Shows: __init__ parameters, configuration, required arguments\n"
283
+ "</example>\n\n"
284
+ "<example>\n"
285
+ "// ML Workflow Step: Study complete training script from examples\n"
286
+ "// Use case: Learn end-to-end VLM fine-tuning workflow\n"
287
+ "{\n"
288
+ " repo: 'huggingface/trl',\n"
289
+ " path: 'examples/scripts/grpo_vlm.py'\n"
290
+ "}\n"
291
+ "// Returns first 300 lines - shows full training setup\n"
292
+ "// Use line_start/line_end if need to read more\n"
293
+ "</example>\n\n"
294
+ "<example>\n"
295
+ "// ML Workflow Step: Check TrainingArguments configuration patterns\n"
296
+ "// Use case: Learn how to structure training configs correctly\n"
297
+ "{\n"
298
+ " repo: 'huggingface/transformers',\n"
299
+ " path: 'examples/pytorch/language-modeling/run_clm.py',\n"
300
+ " line_start: 50,\n"
301
+ " line_end: 150\n"
302
+ "}\n"
303
+ "// Read argument parsing and config setup section\n"
304
+ "// Shows: current parameter names, default values, best practices\n"
305
+ "</example>"
306
+ ),
307
+ "parameters": {
308
+ "type": "object",
309
+ "properties": {
310
+ "repo": {
311
+ "type": "string",
312
+ "description": "Repository in format 'owner/repo' (e.g., 'github/github-mcp-server'). Required.",
313
+ },
314
+ "path": {
315
+ "type": "string",
316
+ "description": "Path to file in repository (e.g., 'src/index.js'). Required.",
317
+ },
318
+ "ref": {
319
+ "type": "string",
320
+ "description": "Git reference - branch name, tag, or commit SHA. Default: 'HEAD'.",
321
+ },
322
+ "line_start": {
323
+ "type": "integer",
324
+ "description": "Starting line number (1-indexed, inclusive). Optional.",
325
+ },
326
+ "line_end": {
327
+ "type": "integer",
328
+ "description": "Ending line number (1-indexed, inclusive). Optional.",
329
+ },
330
+ },
331
+ "required": ["repo", "path"],
332
+ },
333
+ }
334
+
335
+
336
+ async def github_read_file_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
337
+ """Handler for agent tool router"""
338
+ try:
339
+ result = read_file(
340
+ repo=arguments["repo"],
341
+ path=arguments["path"],
342
+ ref=arguments.get("ref", "HEAD"),
343
+ line_start=arguments.get("line_start"),
344
+ line_end=arguments.get("line_end"),
345
+ )
346
+ return result["formatted"], not result.get("isError", False)
347
+ except Exception as e:
348
+ return f"Error reading file: {str(e)}", False
agent/tools/hf_repo_files_tool.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Files Tool - File operations on Hugging Face repositories
3
+
4
+ Operations: list, read, upload, delete
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi, hf_hub_download
11
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
+ OperationType = Literal["list", "read", "upload", "delete"]
16
+
17
+
18
+ async def _async_call(func, *args, **kwargs):
19
+ """Wrap synchronous HfApi calls for async context."""
20
+ return await asyncio.to_thread(func, *args, **kwargs)
21
+
22
+
23
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
24
+ """Build the Hub URL for a repository."""
25
+ if repo_type == "model":
26
+ return f"https://huggingface.co/{repo_id}"
27
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
28
+
29
+
30
+ def _format_size(size_bytes: int) -> str:
31
+ """Format file size in human-readable form."""
32
+ for unit in ["B", "KB", "MB", "GB", "TB"]:
33
+ if size_bytes < 1024:
34
+ return f"{size_bytes:.1f}{unit}"
35
+ size_bytes /= 1024
36
+ return f"{size_bytes:.1f}PB"
37
+
38
+
39
+ class HfRepoFilesTool:
40
+ """Tool for file operations on HF repos."""
41
+
42
+ def __init__(self, hf_token: Optional[str] = None):
43
+ self.api = HfApi(token=hf_token)
44
+
45
+ async def execute(self, args: Dict[str, Any]) -> ToolResult:
46
+ """Execute the specified operation."""
47
+ operation = args.get("operation")
48
+
49
+ if not operation:
50
+ return self._help()
51
+
52
+ try:
53
+ handlers = {
54
+ "list": self._list,
55
+ "read": self._read,
56
+ "upload": self._upload,
57
+ "delete": self._delete,
58
+ }
59
+
60
+ handler = handlers.get(operation)
61
+ if handler:
62
+ return await handler(args)
63
+ else:
64
+ return self._error(f"Unknown operation: {operation}. Valid: list, read, upload, delete")
65
+
66
+ except RepositoryNotFoundError:
67
+ return self._error(f"Repository not found: {args.get('repo_id')}")
68
+ except EntryNotFoundError:
69
+ return self._error(f"File not found: {args.get('path')}")
70
+ except Exception as e:
71
+ return self._error(f"Error: {str(e)}")
72
+
73
+ def _help(self) -> ToolResult:
74
+ """Show usage instructions."""
75
+ return {
76
+ "formatted": """**hf_repo_files** - File operations on HF repos
77
+
78
+ **Operations:**
79
+ - `list` - List files: `{"operation": "list", "repo_id": "gpt2"}`
80
+ - `read` - Read file: `{"operation": "read", "repo_id": "gpt2", "path": "config.json"}`
81
+ - `upload` - Upload: `{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "..."}`
82
+ - `delete` - Delete: `{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp"]}`
83
+
84
+ **Common params:** repo_id (required), repo_type (model/dataset/space), revision (default: main)""",
85
+ "totalResults": 1,
86
+ "resultsShared": 1,
87
+ }
88
+
89
+ async def _list(self, args: Dict[str, Any]) -> ToolResult:
90
+ """List files in a repository."""
91
+ repo_id = args.get("repo_id")
92
+ if not repo_id:
93
+ return self._error("repo_id is required")
94
+
95
+ repo_type = args.get("repo_type", "model")
96
+ revision = args.get("revision", "main")
97
+ path = args.get("path", "")
98
+
99
+ items = list(await _async_call(
100
+ self.api.list_repo_tree,
101
+ repo_id=repo_id,
102
+ repo_type=repo_type,
103
+ revision=revision,
104
+ path_in_repo=path,
105
+ recursive=True,
106
+ ))
107
+
108
+ if not items:
109
+ return {"formatted": f"No files in {repo_id}", "totalResults": 0, "resultsShared": 0}
110
+
111
+ lines = []
112
+ total_size = 0
113
+ for item in sorted(items, key=lambda x: x.path):
114
+ if hasattr(item, "size") and item.size:
115
+ total_size += item.size
116
+ lines.append(f"{item.path} ({_format_size(item.size)})")
117
+ else:
118
+ lines.append(f"{item.path}/")
119
+
120
+ url = _build_repo_url(repo_id, repo_type)
121
+ response = f"**{repo_id}** ({len(items)} files, {_format_size(total_size)})\n{url}/tree/{revision}\n\n" + "\n".join(lines)
122
+
123
+ return {"formatted": response, "totalResults": len(items), "resultsShared": len(items)}
124
+
125
+ async def _read(self, args: Dict[str, Any]) -> ToolResult:
126
+ """Read file content from a repository."""
127
+ repo_id = args.get("repo_id")
128
+ path = args.get("path")
129
+
130
+ if not repo_id:
131
+ return self._error("repo_id is required")
132
+ if not path:
133
+ return self._error("path is required")
134
+
135
+ repo_type = args.get("repo_type", "model")
136
+ revision = args.get("revision", "main")
137
+ max_chars = args.get("max_chars", 50000)
138
+
139
+ file_path = await _async_call(
140
+ hf_hub_download,
141
+ repo_id=repo_id,
142
+ filename=path,
143
+ repo_type=repo_type,
144
+ revision=revision,
145
+ token=self.api.token,
146
+ )
147
+
148
+ try:
149
+ with open(file_path, "r", encoding="utf-8") as f:
150
+ content = f.read()
151
+
152
+ truncated = len(content) > max_chars
153
+ if truncated:
154
+ content = content[:max_chars]
155
+
156
+ url = f"{_build_repo_url(repo_id, repo_type)}/blob/{revision}/{path}"
157
+ response = f"**{path}**{' (truncated)' if truncated else ''}\n{url}\n\n```\n{content}\n```"
158
+
159
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
160
+
161
+ except UnicodeDecodeError:
162
+ import os
163
+ size = os.path.getsize(file_path)
164
+ return {"formatted": f"Binary file ({_format_size(size)})", "totalResults": 1, "resultsShared": 1}
165
+
166
+ async def _upload(self, args: Dict[str, Any]) -> ToolResult:
167
+ """Upload content to a repository."""
168
+ repo_id = args.get("repo_id")
169
+ path = args.get("path")
170
+ content = args.get("content")
171
+
172
+ if not repo_id:
173
+ return self._error("repo_id is required")
174
+ if not path:
175
+ return self._error("path is required")
176
+ if content is None:
177
+ return self._error("content is required")
178
+
179
+ repo_type = args.get("repo_type", "model")
180
+ revision = args.get("revision", "main")
181
+ create_pr = args.get("create_pr", False)
182
+ commit_message = args.get("commit_message", f"Upload {path}")
183
+
184
+ file_bytes = content.encode("utf-8") if isinstance(content, str) else content
185
+
186
+ result = await _async_call(
187
+ self.api.upload_file,
188
+ path_or_fileobj=file_bytes,
189
+ path_in_repo=path,
190
+ repo_id=repo_id,
191
+ repo_type=repo_type,
192
+ revision=revision,
193
+ commit_message=commit_message,
194
+ create_pr=create_pr,
195
+ )
196
+
197
+ url = _build_repo_url(repo_id, repo_type)
198
+ if create_pr and hasattr(result, "pr_url"):
199
+ response = f"**Uploaded as PR**\n{result.pr_url}"
200
+ else:
201
+ response = f"**Uploaded:** {path}\n{url}/blob/{revision}/{path}"
202
+
203
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
204
+
205
+ async def _delete(self, args: Dict[str, Any]) -> ToolResult:
206
+ """Delete files from a repository."""
207
+ repo_id = args.get("repo_id")
208
+ patterns = args.get("patterns")
209
+
210
+ if not repo_id:
211
+ return self._error("repo_id is required")
212
+ if not patterns:
213
+ return self._error("patterns is required (list of paths/wildcards)")
214
+
215
+ if isinstance(patterns, str):
216
+ patterns = [patterns]
217
+
218
+ repo_type = args.get("repo_type", "model")
219
+ revision = args.get("revision", "main")
220
+ create_pr = args.get("create_pr", False)
221
+ commit_message = args.get("commit_message", f"Delete {', '.join(patterns)}")
222
+
223
+ await _async_call(
224
+ self.api.delete_files,
225
+ repo_id=repo_id,
226
+ delete_patterns=patterns,
227
+ repo_type=repo_type,
228
+ revision=revision,
229
+ commit_message=commit_message,
230
+ create_pr=create_pr,
231
+ )
232
+
233
+ response = f"**Deleted:** {', '.join(patterns)} from {repo_id}"
234
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
235
+
236
+ def _error(self, message: str) -> ToolResult:
237
+ """Return an error result."""
238
+ return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
239
+
240
+
241
+ # Tool specification
242
+ HF_REPO_FILES_TOOL_SPEC = {
243
+ "name": "hf_repo_files",
244
+ "description": (
245
+ "Read and write files in HF repos (models/datasets/spaces).\n\n"
246
+ "## Operations\n"
247
+ "- **list**: List files with sizes and structure\n"
248
+ "- **read**: Read file content (text files only)\n"
249
+ "- **upload**: Upload content to repo (can create PR)\n"
250
+ "- **delete**: Delete files/folders (supports wildcards like *.tmp)\n\n"
251
+ "## Use when\n"
252
+ "- Need to see what files exist in a repo\n"
253
+ "- Want to read config.json, README.md, or other text files\n"
254
+ "- Uploading training scripts, configs, or results to a repo\n"
255
+ "- Cleaning up temporary files from a repo\n\n"
256
+ "## Examples\n"
257
+ '{"operation": "list", "repo_id": "meta-llama/Llama-2-7b"}\n'
258
+ '{"operation": "read", "repo_id": "gpt2", "path": "config.json"}\n'
259
+ '{"operation": "upload", "repo_id": "my-model", "path": "README.md", "content": "# My Model"}\n'
260
+ '{"operation": "upload", "repo_id": "org/model", "path": "fix.py", "content": "...", "create_pr": true}\n'
261
+ '{"operation": "delete", "repo_id": "my-model", "patterns": ["*.tmp", "logs/"]}\n\n'
262
+ "## Notes\n"
263
+ "- For binary files (safetensors, bin), use list to see them but can't read content\n"
264
+ "- upload/delete require approval (can overwrite/destroy data)\n"
265
+ "- Use create_pr=true to propose changes instead of direct commit\n"
266
+ ),
267
+ "parameters": {
268
+ "type": "object",
269
+ "properties": {
270
+ "operation": {
271
+ "type": "string",
272
+ "enum": ["list", "read", "upload", "delete"],
273
+ "description": "Operation: list, read, upload, delete",
274
+ },
275
+ "repo_id": {
276
+ "type": "string",
277
+ "description": "Repository ID (e.g., 'username/repo-name')",
278
+ },
279
+ "repo_type": {
280
+ "type": "string",
281
+ "enum": ["model", "dataset", "space"],
282
+ "description": "Repository type (default: model)",
283
+ },
284
+ "revision": {
285
+ "type": "string",
286
+ "description": "Branch/tag/commit (default: main)",
287
+ },
288
+ "path": {
289
+ "type": "string",
290
+ "description": "File path for read/upload",
291
+ },
292
+ "content": {
293
+ "type": "string",
294
+ "description": "File content for upload",
295
+ },
296
+ "patterns": {
297
+ "type": "array",
298
+ "items": {"type": "string"},
299
+ "description": "Patterns to delete (e.g., ['*.tmp', 'logs/'])",
300
+ },
301
+ "create_pr": {
302
+ "type": "boolean",
303
+ "description": "Create PR instead of direct commit",
304
+ },
305
+ "commit_message": {
306
+ "type": "string",
307
+ "description": "Custom commit message",
308
+ },
309
+ },
310
+ "required": ["operation"],
311
+ },
312
+ }
313
+
314
+
315
+ async def hf_repo_files_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
316
+ """Handler for agent tool router."""
317
+ try:
318
+ tool = HfRepoFilesTool()
319
+ result = await tool.execute(arguments)
320
+ return result["formatted"], not result.get("isError", False)
321
+ except Exception as e:
322
+ return f"Error: {str(e)}", False
agent/tools/hf_repo_git_tool.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Repo Git Tool - Git-like operations on Hugging Face repositories
3
+
4
+ Operations: branches, tags, PRs, repo management
5
+ """
6
+
7
+ import asyncio
8
+ from typing import Any, Dict, Literal, Optional
9
+
10
+ from huggingface_hub import HfApi
11
+ from huggingface_hub.utils import RepositoryNotFoundError
12
+
13
+ from agent.tools.types import ToolResult
14
+
15
# Closed set of operation names dispatched by HfRepoGitTool.execute;
# mirrors the "operation" enum in HF_REPO_GIT_TOOL_SPEC (kept in sync manually).
OperationType = Literal[
    "create_branch", "delete_branch",
    "create_tag", "delete_tag",
    "list_refs",
    "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
    "create_repo", "update_repo",
]
22
+
23
+
24
+ async def _async_call(func, *args, **kwargs):
25
+ """Wrap synchronous HfApi calls for async context."""
26
+ return await asyncio.to_thread(func, *args, **kwargs)
27
+
28
+
29
+ def _build_repo_url(repo_id: str, repo_type: str = "model") -> str:
30
+ """Build the Hub URL for a repository."""
31
+ if repo_type == "model":
32
+ return f"https://huggingface.co/{repo_id}"
33
+ return f"https://huggingface.co/{repo_type}s/{repo_id}"
34
+
35
+
36
class HfRepoGitTool:
    """Git-like operations (branches, tags, PRs, repo admin) on HF repos.

    Thin async wrapper around ``huggingface_hub.HfApi``: blocking API calls
    are offloaded to a worker thread via ``_async_call`` so the caller's
    event loop is never blocked.  Every entry point returns a ToolResult
    dict; failures are reported as error results rather than raised.
    """

    def __init__(self, hf_token: Optional[str] = None):
        # hf_token=None lets HfApi fall back to its own credential
        # resolution (HF_TOKEN env var / cached login).
        self.api = HfApi(token=hf_token)

    async def execute(self, args: Dict[str, Any]) -> ToolResult:
        """Dispatch ``args['operation']`` to its handler.

        No operation -> usage help.  Unknown operations, missing repos, and
        unexpected exceptions become error ToolResults (never raised).
        """
        operation = args.get("operation")

        if not operation:
            return self._help()

        try:
            handlers = {
                "create_branch": self._create_branch,
                "delete_branch": self._delete_branch,
                "create_tag": self._create_tag,
                "delete_tag": self._delete_tag,
                "list_refs": self._list_refs,
                "create_pr": self._create_pr,
                "list_prs": self._list_prs,
                "get_pr": self._get_pr,
                "merge_pr": self._merge_pr,
                "close_pr": self._close_pr,
                "comment_pr": self._comment_pr,
                "change_pr_status": self._change_pr_status,
                "create_repo": self._create_repo,
                "update_repo": self._update_repo,
            }

            handler = handlers.get(operation)
            if handler:
                return await handler(args)
            ops = ", ".join(handlers.keys())
            return self._error(f"Unknown operation: {operation}. Valid: {ops}")

        except RepositoryNotFoundError:
            return self._error(f"Repository not found: {args.get('repo_id')}")
        except Exception as e:
            return self._error(f"Error: {str(e)}")

    def _help(self) -> ToolResult:
        """Show usage instructions (returned when no operation is given)."""
        return {
            "formatted": """**hf_repo_git** - Git-like operations on HF repos

**Branch/Tag:**
- `create_branch`: `{"operation": "create_branch", "repo_id": "...", "branch": "dev"}`
- `delete_branch`: `{"operation": "delete_branch", "repo_id": "...", "branch": "dev"}`
- `create_tag`: `{"operation": "create_tag", "repo_id": "...", "tag": "v1.0"}`
- `delete_tag`: `{"operation": "delete_tag", "repo_id": "...", "tag": "v1.0"}`
- `list_refs`: `{"operation": "list_refs", "repo_id": "..."}`

**PRs:**
- `create_pr`: `{"operation": "create_pr", "repo_id": "...", "title": "..."}` (creates draft PR)
- `list_prs`: `{"operation": "list_prs", "repo_id": "..."}` (shows status: draft/open/merged/closed)
- `get_pr`: `{"operation": "get_pr", "repo_id": "...", "pr_num": 1}` (shows status)
- `change_pr_status`: `{"operation": "change_pr_status", "repo_id": "...", "pr_num": 1, "new_status": "open"}` (change draft to open)
- `merge_pr`: `{"operation": "merge_pr", "repo_id": "...", "pr_num": 1}`
- `close_pr`: `{"operation": "close_pr", "repo_id": "...", "pr_num": 1}`
- `comment_pr`: `{"operation": "comment_pr", "repo_id": "...", "pr_num": 1, "comment": "..."}`

**Repo:**
- `create_repo`: `{"operation": "create_repo", "repo_id": "my-model", "private": true}`
- `update_repo`: `{"operation": "update_repo", "repo_id": "...", "private": false}`""",
            "totalResults": 1,
            "resultsShared": 1,
        }

    def _require(self, args: Dict[str, Any], *names: str) -> Optional[ToolResult]:
        """Return an error result for the first missing/falsy required
        argument in *names*, or None when all are present."""
        for name in names:
            if not args.get(name):
                return self._error(f"{name} is required")
        return None

    # =========================================================================
    # BRANCH OPERATIONS
    # =========================================================================

    async def _create_branch(self, args: Dict[str, Any]) -> ToolResult:
        """Create a new branch from an existing revision (default: main)."""
        missing = self._require(args, "repo_id", "branch")
        if missing:
            return missing

        repo_id = args["repo_id"]
        branch = args["branch"]
        repo_type = args.get("repo_type", "model")
        from_rev = args.get("from_rev", "main")

        await _async_call(
            self.api.create_branch,
            repo_id=repo_id,
            branch=branch,
            revision=from_rev,
            repo_type=repo_type,
            exist_ok=args.get("exist_ok", False),
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/tree/{branch}"
        return {"formatted": f"**Branch created:** {branch}\n{url}", "totalResults": 1, "resultsShared": 1}

    async def _delete_branch(self, args: Dict[str, Any]) -> ToolResult:
        """Delete a branch."""
        missing = self._require(args, "repo_id", "branch")
        if missing:
            return missing

        await _async_call(
            self.api.delete_branch,
            repo_id=args["repo_id"],
            branch=args["branch"],
            repo_type=args.get("repo_type", "model"),
        )

        return {"formatted": f"**Branch deleted:** {args['branch']}", "totalResults": 1, "resultsShared": 1}

    # =========================================================================
    # TAG OPERATIONS
    # =========================================================================

    async def _create_tag(self, args: Dict[str, Any]) -> ToolResult:
        """Create a tag pointing at a revision (default: main)."""
        missing = self._require(args, "repo_id", "tag")
        if missing:
            return missing

        repo_id = args["repo_id"]
        tag = args["tag"]
        repo_type = args.get("repo_type", "model")

        await _async_call(
            self.api.create_tag,
            repo_id=repo_id,
            tag=tag,
            revision=args.get("revision", "main"),
            tag_message=args.get("tag_message", ""),
            repo_type=repo_type,
            exist_ok=args.get("exist_ok", False),
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/tree/{tag}"
        return {"formatted": f"**Tag created:** {tag}\n{url}", "totalResults": 1, "resultsShared": 1}

    async def _delete_tag(self, args: Dict[str, Any]) -> ToolResult:
        """Delete a tag."""
        missing = self._require(args, "repo_id", "tag")
        if missing:
            return missing

        await _async_call(
            self.api.delete_tag,
            repo_id=args["repo_id"],
            tag=args["tag"],
            repo_type=args.get("repo_type", "model"),
        )

        return {"formatted": f"**Tag deleted:** {args['tag']}", "totalResults": 1, "resultsShared": 1}

    # =========================================================================
    # LIST REFS
    # =========================================================================

    async def _list_refs(self, args: Dict[str, Any]) -> ToolResult:
        """List branches and tags of a repository."""
        missing = self._require(args, "repo_id")
        if missing:
            return missing

        repo_id = args["repo_id"]
        repo_type = args.get("repo_type", "model")

        refs = await _async_call(
            self.api.list_repo_refs,
            repo_id=repo_id,
            repo_type=repo_type,
        )

        branches = [b.name for b in refs.branches] if refs.branches else []
        # hasattr guard is defensive — presumably for hub versions whose refs
        # object lacks .tags; TODO confirm whether it is still needed.
        tags = [t.name for t in refs.tags] if hasattr(refs, 'tags') and refs.tags else []

        url = _build_repo_url(repo_id, repo_type)
        lines = [f"**{repo_id}**", url, ""]

        if branches:
            lines.append(f"**Branches ({len(branches)}):** " + ", ".join(branches))
        else:
            lines.append("**Branches:** none")

        if tags:
            lines.append(f"**Tags ({len(tags)}):** " + ", ".join(tags))
        else:
            lines.append("**Tags:** none")

        return {"formatted": "\n".join(lines), "totalResults": len(branches) + len(tags), "resultsShared": len(branches) + len(tags)}

    # =========================================================================
    # PR OPERATIONS
    # =========================================================================

    async def _create_pr(self, args: Dict[str, Any]) -> ToolResult:
        """Create a pull request (starts as an empty draft PR)."""
        missing = self._require(args, "repo_id", "title")
        if missing:
            return missing

        repo_id = args["repo_id"]
        title = args["title"]
        repo_type = args.get("repo_type", "model")

        result = await _async_call(
            self.api.create_pull_request,
            repo_id=repo_id,
            title=title,
            description=args.get("description", ""),
            repo_type=repo_type,
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{result.num}"
        return {
            "formatted": f"**Draft PR #{result.num} created:** {title}\n{url}\n\nAdd commits via upload with revision=\"refs/pr/{result.num}\"",
            "totalResults": 1,
            "resultsShared": 1,
        }

    async def _list_prs(self, args: Dict[str, Any]) -> ToolResult:
        """List PRs and discussions, optionally filtered by status."""
        missing = self._require(args, "repo_id")
        if missing:
            return missing

        repo_id = args["repo_id"]
        repo_type = args.get("repo_type", "model")
        status = args.get("status", "all")  # open, closed, all

        # BUGFIX: get_repo_discussions is a blocking paginated generator;
        # drain it in a worker thread like every other HfApi call here so the
        # event loop is not blocked while fetching pages.
        discussions = await _async_call(
            lambda: list(
                self.api.get_repo_discussions(
                    repo_id=repo_id,
                    repo_type=repo_type,
                    discussion_status=status if status != "all" else None,
                )
            )
        )

        if not discussions:
            return {"formatted": f"No discussions in {repo_id}", "totalResults": 0, "resultsShared": 0}

        url = _build_repo_url(repo_id, repo_type)
        lines = [f"**{repo_id}** - {len(discussions)} discussions", f"{url}/discussions", ""]

        # Anything that is not draft/open/merged is rendered as closed.
        status_labels = {"draft": "[DRAFT]", "open": "[OPEN]", "merged": "[MERGED]"}
        for d in discussions[:20]:
            status_label = status_labels.get(d.status, "[CLOSED]")
            type_label = "PR" if d.is_pull_request else "D"
            lines.append(f"{status_label} #{d.num} [{type_label}] {d.title}")

        return {"formatted": "\n".join(lines), "totalResults": len(discussions), "resultsShared": min(20, len(discussions))}

    async def _get_pr(self, args: Dict[str, Any]) -> ToolResult:
        """Get details of one PR/discussion."""
        missing = self._require(args, "repo_id", "pr_num")
        if missing:
            return missing

        repo_id = args["repo_id"]
        pr_num = args["pr_num"]
        repo_type = args.get("repo_type", "model")

        pr = await _async_call(
            self.api.get_discussion_details,
            repo_id=repo_id,
            discussion_num=int(pr_num),
            repo_type=repo_type,
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
        status_map = {
            "draft": "Draft",
            "open": "Open",
            "merged": "Merged",
            "closed": "Closed"
        }
        status = status_map.get(pr.status, pr.status.capitalize())
        type_label = "Pull Request" if pr.is_pull_request else "Discussion"

        lines = [
            f"**{type_label} #{pr_num}:** {pr.title}",
            f"**Status:** {status}",
            f"**Author:** {pr.author}",
            url,
        ]

        # Draft and open PRs can still receive commits (merged/closed cannot).
        if pr.is_pull_request and pr.status in ("draft", "open"):
            lines.append(f"\nTo add commits: upload with revision=\"refs/pr/{pr_num}\"")

        return {"formatted": "\n".join(lines), "totalResults": 1, "resultsShared": 1}

    async def _merge_pr(self, args: Dict[str, Any]) -> ToolResult:
        """Merge a pull request."""
        missing = self._require(args, "repo_id", "pr_num")
        if missing:
            return missing

        repo_id = args["repo_id"]
        pr_num = args["pr_num"]
        repo_type = args.get("repo_type", "model")

        await _async_call(
            self.api.merge_pull_request,
            repo_id=repo_id,
            discussion_num=int(pr_num),
            comment=args.get("comment", ""),
            repo_type=repo_type,
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
        return {"formatted": f"**PR #{pr_num} merged**\n{url}", "totalResults": 1, "resultsShared": 1}

    async def _close_pr(self, args: Dict[str, Any]) -> ToolResult:
        """Close a PR/discussion without merging."""
        missing = self._require(args, "repo_id", "pr_num")
        if missing:
            return missing

        pr_num = args["pr_num"]

        await _async_call(
            self.api.change_discussion_status,
            repo_id=args["repo_id"],
            discussion_num=int(pr_num),
            new_status="closed",
            comment=args.get("comment", ""),
            repo_type=args.get("repo_type", "model"),
        )

        return {"formatted": f"**Discussion #{pr_num} closed**", "totalResults": 1, "resultsShared": 1}

    async def _comment_pr(self, args: Dict[str, Any]) -> ToolResult:
        """Add a comment to a PR/discussion."""
        missing = self._require(args, "repo_id", "pr_num", "comment")
        if missing:
            return missing

        repo_id = args["repo_id"]
        pr_num = args["pr_num"]
        repo_type = args.get("repo_type", "model")

        await _async_call(
            self.api.comment_discussion,
            repo_id=repo_id,
            discussion_num=int(pr_num),
            comment=args["comment"],
            repo_type=repo_type,
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
        return {"formatted": f"**Comment added to #{pr_num}**\n{url}", "totalResults": 1, "resultsShared": 1}

    async def _change_pr_status(self, args: Dict[str, Any]) -> ToolResult:
        """Change PR/discussion status (mainly to convert draft to open)."""
        missing = self._require(args, "repo_id", "pr_num")
        if missing:
            return missing

        new_status = args.get("new_status")
        if not new_status:
            return self._error("new_status is required (open or closed)")

        repo_id = args["repo_id"]
        pr_num = args["pr_num"]
        repo_type = args.get("repo_type", "model")

        await _async_call(
            self.api.change_discussion_status,
            repo_id=repo_id,
            discussion_num=int(pr_num),
            new_status=new_status,
            comment=args.get("comment", ""),
            repo_type=repo_type,
        )

        url = f"{_build_repo_url(repo_id, repo_type)}/discussions/{pr_num}"
        return {"formatted": f"**PR #{pr_num} status changed to {new_status}**\n{url}", "totalResults": 1, "resultsShared": 1}

    # =========================================================================
    # REPO MANAGEMENT
    # =========================================================================

    async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
        """Create a new model/dataset/space repository (private by default)."""
        missing = self._require(args, "repo_id")
        if missing:
            return missing

        repo_id = args["repo_id"]
        repo_type = args.get("repo_type", "model")
        private = args.get("private", True)
        space_sdk = args.get("space_sdk")

        if repo_type == "space" and not space_sdk:
            return self._error("space_sdk required for spaces (gradio/streamlit/docker/static)")

        kwargs = {
            "repo_id": repo_id,
            "repo_type": repo_type,
            "private": private,
            "exist_ok": args.get("exist_ok", False),
        }
        if space_sdk:
            kwargs["space_sdk"] = space_sdk

        result = await _async_call(self.api.create_repo, **kwargs)

        return {
            "formatted": f"**Repository created:** {repo_id}\n**Private:** {private}\n{result}",
            "totalResults": 1,
            "resultsShared": 1,
        }

    async def _update_repo(self, args: Dict[str, Any]) -> ToolResult:
        """Update repository settings (visibility and/or gated access)."""
        missing = self._require(args, "repo_id")
        if missing:
            return missing

        repo_id = args["repo_id"]
        repo_type = args.get("repo_type", "model")
        private = args.get("private")
        gated = args.get("gated")

        if private is None and gated is None:
            return self._error("Specify private (bool) or gated ('auto'/'manual'/false)")

        # NOTE(review): the tool spec's gated enum yields the string "false",
        # while HfApi.update_repo_settings documents bool False for ungated —
        # confirm the Hub accepts the string form.
        kwargs = {"repo_id": repo_id, "repo_type": repo_type}
        if private is not None:
            kwargs["private"] = private
        if gated is not None:
            kwargs["gated"] = gated

        await _async_call(self.api.update_repo_settings, **kwargs)

        changes = []
        if private is not None:
            changes.append(f"private={private}")
        if gated is not None:
            changes.append(f"gated={gated}")

        url = f"{_build_repo_url(repo_id, repo_type)}/settings"
        return {"formatted": f"**Settings updated:** {', '.join(changes)}\n{url}", "totalResults": 1, "resultsShared": 1}

    def _error(self, message: str) -> ToolResult:
        """Return an error result."""
        return {"formatted": message, "totalResults": 0, "resultsShared": 0, "isError": True}
529
+
530
+
531
# Tool specification
# JSON-schema-style spec consumed by the agent's tool router; the description
# doubles as LLM-facing usage documentation, so its wording is part of the
# tool's behavior and should not be edited casually.
HF_REPO_GIT_TOOL_SPEC = {
    "name": "hf_repo_git",
    "description": (
        "Git-like operations on HF repos: branches, tags, PRs, and repo management.\n\n"
        "## Operations\n"
        "**Branches:** create_branch, delete_branch, list_refs\n"
        "**Tags:** create_tag, delete_tag\n"
        "**PRs:** create_pr, list_prs, get_pr, merge_pr, close_pr, comment_pr, change_pr_status\n"
        "**Repo:** create_repo, update_repo\n\n"
        "## Use when\n"
        "- Creating feature branches for experiments\n"
        "- Tagging model versions (v1.0, v2.0)\n"
        "- Opening PRs to contribute to repos you don't own\n"
        "- Reviewing and merging PRs on your repos\n"
        "- Creating new model/dataset/space repos\n"
        "- Changing repo visibility (public/private) or gated access\n\n"
        "## Examples\n"
        '{"operation": "list_refs", "repo_id": "my-model"}\n'
        '{"operation": "create_branch", "repo_id": "my-model", "branch": "experiment-v2"}\n'
        '{"operation": "create_tag", "repo_id": "my-model", "tag": "v1.0", "revision": "main"}\n'
        '{"operation": "create_pr", "repo_id": "org/model", "title": "Fix tokenizer config"}\n'
        '{"operation": "change_pr_status", "repo_id": "my-model", "pr_num": 1, "new_status": "open"}\n'
        '{"operation": "merge_pr", "repo_id": "my-model", "pr_num": 3}\n'
        '{"operation": "create_repo", "repo_id": "my-new-model", "private": true}\n'
        '{"operation": "update_repo", "repo_id": "my-model", "gated": "auto"}\n\n'
        "## PR Workflow\n"
        "1. create_pr → creates draft PR (empty by default)\n"
        "2. Upload files with revision='refs/pr/N' to add commits\n"
        "3. change_pr_status with new_status='open' to publish (convert draft to open)\n"
        "4. merge_pr when ready\n\n"
        "## Notes\n"
        "- PR status: draft (default), open, merged, closed\n"
        "- delete_branch, delete_tag, merge_pr, create_repo, update_repo require approval\n"
        "- For spaces, create_repo needs space_sdk (gradio/streamlit/docker/static)\n"
        "- gated options: 'auto' (instant), 'manual' (review), false (open)\n"
    ),
    # Parameter schema; "operation" is the only required field — per-operation
    # requirements (repo_id, branch, pr_num, ...) are enforced at runtime by
    # the tool itself.
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                # Must stay in sync with the OperationType Literal above.
                "enum": [
                    "create_branch", "delete_branch",
                    "create_tag", "delete_tag", "list_refs",
                    "create_pr", "list_prs", "get_pr", "merge_pr", "close_pr", "comment_pr", "change_pr_status",
                    "create_repo", "update_repo",
                ],
                "description": "Operation to execute",
            },
            "repo_id": {
                "type": "string",
                "description": "Repository ID (e.g., 'username/repo-name')",
            },
            "repo_type": {
                "type": "string",
                "enum": ["model", "dataset", "space"],
                "description": "Repository type (default: model)",
            },
            "branch": {
                "type": "string",
                "description": "Branch name (create_branch, delete_branch)",
            },
            "from_rev": {
                "type": "string",
                "description": "Create branch from this revision (default: main)",
            },
            "tag": {
                "type": "string",
                "description": "Tag name (create_tag, delete_tag)",
            },
            "revision": {
                "type": "string",
                "description": "Revision for tag (default: main)",
            },
            "tag_message": {
                "type": "string",
                "description": "Tag description",
            },
            "title": {
                "type": "string",
                "description": "PR title (create_pr)",
            },
            "description": {
                "type": "string",
                "description": "PR description (create_pr)",
            },
            "pr_num": {
                "type": "integer",
                "description": "PR/discussion number",
            },
            "comment": {
                "type": "string",
                "description": "Comment text",
            },
            "status": {
                "type": "string",
                "enum": ["open", "closed", "all"],
                "description": "Filter PRs by status (list_prs)",
            },
            "new_status": {
                "type": "string",
                "enum": ["open", "closed"],
                "description": "New status for PR/discussion (change_pr_status)",
            },
            "private": {
                "type": "boolean",
                "description": "Make repo private (create_repo, update_repo)",
            },
            "gated": {
                "type": "string",
                # NOTE(review): this enum produces the *string* "false", but
                # HfApi.update_repo_settings documents bool False for ungated —
                # confirm the string form is accepted by the Hub.
                "enum": ["auto", "manual", "false"],
                "description": "Gated access setting (update_repo)",
            },
            "space_sdk": {
                "type": "string",
                "enum": ["gradio", "streamlit", "docker", "static"],
                "description": "Space SDK (required for create_repo with space)",
            },
        },
        "required": ["operation"],
    },
}
654
+
655
+
656
async def hf_repo_git_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapter for the agent tool router.

    Runs the hf_repo_git tool and returns a ``(formatted_text, success)``
    pair; any exception is folded into an error message so the router never
    sees a raise from this handler.
    """
    try:
        outcome = await HfRepoGitTool().execute(arguments)
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as e:
        return f"Error: {str(e)}", False
agent/tools/jobs_tool.py ADDED
@@ -0,0 +1,1042 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Jobs Tool - Using huggingface-hub library
3
+
4
+ Refactored to use official huggingface-hub library instead of custom HTTP client
5
+ """
6
+
7
+ import asyncio
8
+ import base64
9
+ import http.client
10
+ import os
11
+ import re
12
+ from typing import Any, Dict, Literal, Optional, Callable, Awaitable
13
+
14
+ import logging
15
+
16
+ import httpx
17
+ from huggingface_hub import HfApi
18
+ from huggingface_hub.utils import HfHubHTTPError
19
+
20
+ from agent.core.session import Event
21
+ from agent.tools.types import ToolResult
22
+
23
+ logger = logging.getLogger(__name__)
24
+ from agent.tools.utilities import (
25
+ format_job_details,
26
+ format_jobs_table,
27
+ format_scheduled_job_details,
28
+ format_scheduled_jobs_table,
29
+ )
30
+
31
# Hardware flavors
# Flavor identifiers accepted by HF Jobs, split by hardware class so callers
# can validate or display CPU vs GPU options separately.
CPU_FLAVORS = ["cpu-basic", "cpu-upgrade", "cpu-performance", "cpu-xl"]
GPU_FLAVORS = [
    "sprx8",
    "zero-a10g",
    "t4-small",
    "t4-medium",
    "l4x1",
    "l4x4",
    "l40sx1",
    "l40sx4",
    "l40sx8",
    "a10g-small",
    "a10g-large",
    "a10g-largex2",
    "a10g-largex4",
    "a100-large",
    "h100",
    "h100x8",
]

# Detailed specs for display (vCPU/RAM/GPU VRAM)
# Human-readable spec summaries — presumably surfaced in tool descriptions or
# help text; the consuming code is not visible in this excerpt.
CPU_FLAVORS_DESC = (
    "cpu-basic(2vCPU/16GB), cpu-upgrade(8vCPU/32GB), cpu-performance, cpu-xl"
)
GPU_FLAVORS_DESC = (
    "t4-small(4vCPU/15GB/GPU 16GB), t4-medium(8vCPU/30GB/GPU 16GB), "
    "l4x1(8vCPU/30GB/GPU 24GB), l4x4(48vCPU/186GB/GPU 96GB), "
    "l40sx1(8vCPU/62GB/GPU 48GB), l40sx4(48vCPU/382GB/GPU 192GB), l40sx8(192vCPU/1534GB/GPU 384GB), "
    "a10g-small(4vCPU/14GB/GPU 24GB), a10g-large(12vCPU/46GB/GPU 24GB), "
    "a10g-largex2(24vCPU/92GB/GPU 48GB), a10g-largex4(48vCPU/184GB/GPU 96GB), "
    "a100-large(12vCPU/142GB/GPU 80GB), h100(23vCPU/240GB/GPU 80GB), h100x8(184vCPU/1920GB/GPU 640GB), "
    "zero-a10g(dynamic alloc)"
)
# Specialized accelerator flavors (neither plain CPU nor standard GPU).
SPECIALIZED_FLAVORS = ["inf2x6"]
ALL_FLAVORS = CPU_FLAVORS + GPU_FLAVORS + SPECIALIZED_FLAVORS

# Operation names
# Closed set of operations for the jobs tool; "scheduled *" variants are
# space-separated compound names, not separate arguments.
OperationType = Literal[
    "run",
    "ps",
    "logs",
    "inspect",
    "cancel",
    "scheduled run",
    "scheduled ps",
    "scheduled inspect",
    "scheduled delete",
    "scheduled suspend",
    "scheduled resume",
]

# Constants
# Default container image for UV-based jobs (its use site is not visible in
# this excerpt — presumably the job-submission path).
UV_DEFAULT_IMAGE = "ghcr.io/astral-sh/uv:python3.12-bookworm"
85
+
86
+
87
+ def _filter_uv_install_output(logs: list[str]) -> list[str]:
88
+ """
89
+ Filter out UV package installation output from logs.
90
+
91
+ Replaces installation details with "[installs truncated]" and keeps
92
+ the "Installed X packages in Y ms/s" summary line.
93
+
94
+ Args:
95
+ logs: List of log lines
96
+
97
+ Returns:
98
+ Filtered list of log lines
99
+ """
100
+ if not logs:
101
+ return logs
102
+
103
+ # Regex pattern to match: "Installed X packages in Y ms" or "Installed X package in Y s"
104
+ install_pattern = re.compile(
105
+ r"^Installed\s+\d+\s+packages?\s+in\s+\d+(?:\.\d+)?\s*(?:ms|s)$"
106
+ )
107
+
108
+ # Find the index of the "Installed X packages" line
109
+ install_line_idx = None
110
+ for idx, line in enumerate(logs):
111
+ if install_pattern.match(line.strip()):
112
+ install_line_idx = idx
113
+ break
114
+
115
+ # If pattern found, replace installation details with truncation message
116
+ if install_line_idx is not None and install_line_idx > 0:
117
+ # Keep logs from the "Installed X packages" line onward
118
+ # Add truncation message before the "Installed" line
119
+ return ["[installs truncated]"] + logs[install_line_idx:]
120
+
121
+ # If pattern not found, return original logs
122
+ return logs
123
+
124
+
125
+ def _add_environment_variables(
126
+ params: Dict[str, Any] | None, user_token: str | None = None
127
+ ) -> Dict[str, Any]:
128
+ # Prefer the authenticated user's OAuth token, fall back to global env var
129
+ token = user_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or ""
130
+
131
+ # Start with user-provided env vars, then force-set token last
132
+ result = dict(params or {})
133
+
134
+ # If the caller passed HF_TOKEN="$HF_TOKEN", ignore it.
135
+ if result.get("HF_TOKEN", "").strip().startswith("$"):
136
+ result.pop("HF_TOKEN", None)
137
+
138
+ # Set both names to be safe (different libs check different vars)
139
+ if token:
140
+ result["HF_TOKEN"] = token
141
+ result["HUGGINGFACE_HUB_TOKEN"] = token
142
+
143
+ return result
144
+
145
+
146
+ def _build_uv_command(
147
+ script: str,
148
+ with_deps: list[str] | None = None,
149
+ python: str | None = None,
150
+ script_args: list[str] | None = None,
151
+ ) -> list[str]:
152
+ """Build UV run command"""
153
+ parts = ["uv", "run"]
154
+
155
+ if with_deps:
156
+ for dep in with_deps:
157
+ parts.extend(["--with", dep])
158
+
159
+ if python:
160
+ parts.extend(["-p", python])
161
+
162
+ parts.append(script)
163
+
164
+ if script_args:
165
+ parts.extend(script_args)
166
+
167
+ # add defaults
168
+ # parts.extend(["--push_to_hub"])
169
+ return parts
170
+
171
+
172
def _wrap_inline_script(
    script: str,
    with_deps: list[str] | None = None,
    python: str | None = None,
    script_args: list[str] | None = None,
) -> str:
    """Package an inline script as a shell pipeline feeding ``uv run -``.

    Base64-encoding the body avoids creating a temp file and sidesteps shell
    quoting issues; the pipeline decodes it and pipes it to uv on stdin.
    """
    payload = base64.b64encode(script.encode("utf-8")).decode("utf-8")
    # "-" tells uv to read the script from stdin.
    runner = " ".join(_build_uv_command("-", with_deps, python, script_args))
    return f'echo "{payload}" | base64 -d | {runner}'
+ return f'echo "{encoded}" | base64 -d | {uv_command_str}'
185
+
186
+
187
+ def _ensure_hf_transfer_dependency(deps: list[str] | None) -> list[str]:
188
+ """Ensure hf-transfer is included in the dependencies list"""
189
+
190
+ if isinstance(deps, list):
191
+ deps_copy = deps.copy() # Don't modify the original
192
+ if "hf-transfer" not in deps_copy:
193
+ deps_copy.append("hf-transfer")
194
+ return deps_copy
195
+
196
+ return ["hf-transfer"]
197
+
198
+
199
def _resolve_uv_command(
    script: str,
    with_deps: list[str] | None = None,
    python: str | None = None,
    script_args: list[str] | None = None,
) -> list[str]:
    """Build the job command for a script given as URL, inline code, or path.

    - http(s) URLs are handed to `uv run` directly.
    - Multi-line strings are treated as inline source and wrapped in a
      base64/stdin shell pipeline.
    - Anything else is assumed to be a file path.
    """
    is_url = script.startswith(("http://", "https://"))

    if not is_url and "\n" in script:
        # Inline source: run through /bin/sh so the base64 pipeline works
        wrapped = _wrap_inline_script(script, with_deps, python, script_args)
        return ["/bin/sh", "-lc", wrapped]

    return _build_uv_command(script, with_deps, python, script_args)
217
+
218
+
219
+ async def _async_call(func, *args, **kwargs):
220
+ """Wrap synchronous HfApi calls for async context"""
221
+ return await asyncio.to_thread(func, *args, **kwargs)
222
+
223
+
224
+ def _job_info_to_dict(job_info) -> Dict[str, Any]:
225
+ """Convert JobInfo object to dictionary for formatting functions"""
226
+ return {
227
+ "id": job_info.id,
228
+ "status": {"stage": job_info.status.stage, "message": job_info.status.message},
229
+ "command": job_info.command,
230
+ "createdAt": job_info.created_at.isoformat(),
231
+ "dockerImage": job_info.docker_image,
232
+ "spaceId": job_info.space_id,
233
+ "hardware_flavor": job_info.flavor,
234
+ "owner": {"name": job_info.owner.name},
235
+ }
236
+
237
+
238
+ def _scheduled_job_info_to_dict(scheduled_job_info) -> Dict[str, Any]:
239
+ """Convert ScheduledJobInfo object to dictionary for formatting functions"""
240
+ job_spec = scheduled_job_info.job_spec
241
+
242
+ # Extract last run and next run from status
243
+ last_run = None
244
+ next_run = None
245
+ if scheduled_job_info.status:
246
+ if scheduled_job_info.status.last_job:
247
+ last_run = scheduled_job_info.status.last_job.created_at
248
+ if last_run:
249
+ last_run = (
250
+ last_run.isoformat()
251
+ if hasattr(last_run, "isoformat")
252
+ else str(last_run)
253
+ )
254
+ if scheduled_job_info.status.next_job_run_at:
255
+ next_run = scheduled_job_info.status.next_job_run_at
256
+ next_run = (
257
+ next_run.isoformat()
258
+ if hasattr(next_run, "isoformat")
259
+ else str(next_run)
260
+ )
261
+
262
+ return {
263
+ "id": scheduled_job_info.id,
264
+ "schedule": scheduled_job_info.schedule,
265
+ "suspend": scheduled_job_info.suspend,
266
+ "lastRun": last_run,
267
+ "nextRun": next_run,
268
+ "jobSpec": {
269
+ "dockerImage": job_spec.docker_image,
270
+ "spaceId": job_spec.space_id,
271
+ "command": job_spec.command or [],
272
+ "hardware_flavor": job_spec.flavor or "cpu-basic",
273
+ },
274
+ }
275
+
276
+
277
class HfJobsTool:
    """Tool for managing Hugging Face compute jobs using huggingface-hub library"""

    def __init__(
        self,
        hf_token: Optional[str] = None,
        namespace: Optional[str] = None,
        log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
    ):
        """Bind the tool to one HF account/namespace.

        Args:
            hf_token: Token passed to HfApi for all Hub calls; also injected
                into job secrets by the run/schedule operations.
            namespace: User/org namespace passed through to every jobs API
                call (None defers to the API's default).
            log_callback: Optional async callable invoked once per streamed
                log line while a job is being waited on.
        """
        self.hf_token = hf_token
        self.api = HfApi(token=hf_token)
        self.namespace = namespace
        self.log_callback = log_callback
290
+
291
+ async def execute(self, params: Dict[str, Any]) -> ToolResult:
292
+ """Execute the specified operation"""
293
+ operation = params.get("operation")
294
+
295
+ args = params
296
+
297
+ # If no operation provided, return error
298
+ if not operation:
299
+ return {
300
+ "formatted": "Error: 'operation' parameter is required. See tool description for available operations and usage examples.",
301
+ "totalResults": 0,
302
+ "resultsShared": 0,
303
+ "isError": True,
304
+ }
305
+
306
+ # Normalize operation name
307
+ operation = operation.lower()
308
+
309
+ try:
310
+ # Route to appropriate handler
311
+ if operation == "run":
312
+ return await self._run_job(args)
313
+ elif operation == "ps":
314
+ return await self._list_jobs(args)
315
+ elif operation == "logs":
316
+ return await self._get_logs(args)
317
+ elif operation == "inspect":
318
+ return await self._inspect_job(args)
319
+ elif operation == "cancel":
320
+ return await self._cancel_job(args)
321
+ elif operation == "scheduled run":
322
+ return await self._scheduled_run(args)
323
+ elif operation == "scheduled ps":
324
+ return await self._list_scheduled_jobs(args)
325
+ elif operation == "scheduled inspect":
326
+ return await self._inspect_scheduled_job(args)
327
+ elif operation == "scheduled delete":
328
+ return await self._delete_scheduled_job(args)
329
+ elif operation == "scheduled suspend":
330
+ return await self._suspend_scheduled_job(args)
331
+ elif operation == "scheduled resume":
332
+ return await self._resume_scheduled_job(args)
333
+ else:
334
+ return {
335
+ "formatted": f'Unknown operation: "{operation}"\n\n'
336
+ "Available operations:\n"
337
+ "- run, ps, logs, inspect, cancel\n"
338
+ "- scheduled run, scheduled ps, scheduled inspect, "
339
+ "scheduled delete, scheduled suspend, scheduled resume\n\n"
340
+ "Call this tool with no operation for full usage instructions.",
341
+ "totalResults": 0,
342
+ "resultsShared": 0,
343
+ "isError": True,
344
+ }
345
+
346
+ except HfHubHTTPError as e:
347
+ return {
348
+ "formatted": f"API Error: {str(e)}",
349
+ "totalResults": 0,
350
+ "resultsShared": 0,
351
+ "isError": True,
352
+ }
353
+ except Exception as e:
354
+ return {
355
+ "formatted": f"Error executing {operation}: {str(e)}",
356
+ "totalResults": 0,
357
+ "resultsShared": 0,
358
+ "isError": True,
359
+ }
360
+
361
    async def _wait_for_job_completion(
        self, job_id: str, namespace: Optional[str] = None
    ) -> tuple[str, list[str]]:
        """
        Stream job logs until completion, printing them in real-time.
        Implements retry logic to handle connection drops during long-running jobs.

        Each streamed line is forwarded to self.log_callback (if set) and
        collected; the job's final status is fetched once streaming ends.

        Returns:
            tuple: (final_status, all_logs)
        """
        all_logs: list[str] = []
        terminal_states = {"COMPLETED", "FAILED", "CANCELED", "ERROR"}
        max_retries = 100  # Allow many retries for 8h+ jobs
        retry_delay = 5  # Seconds between retries

        for _ in range(max_retries):
            try:
                # Use a queue to bridge sync generator to async consumer
                queue = asyncio.Queue()
                loop = asyncio.get_running_loop()

                def log_producer():
                    try:
                        # fetch_job_logs is a blocking sync generator
                        logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace)
                        for line in logs_gen:
                            # Push line to queue thread-safely
                            loop.call_soon_threadsafe(queue.put_nowait, line)
                        # Signal EOF
                        loop.call_soon_threadsafe(queue.put_nowait, None)
                    except Exception as e:
                        # Signal error: the exception object itself is the sentinel
                        loop.call_soon_threadsafe(queue.put_nowait, e)

                # Start producer in a background thread so it doesn't block the event loop
                producer_future = loop.run_in_executor(None, log_producer)

                # Consume logs from the queue as they arrive
                while True:
                    item = await queue.get()

                    # EOF sentinel
                    if item is None:
                        break

                    # Error occurred in producer; re-raise so the retry logic
                    # below can decide whether to reconnect
                    if isinstance(item, Exception):
                        raise item

                    # Process log line
                    log_line = item
                    logger.debug(log_line)
                    if self.log_callback:
                        await self.log_callback(log_line)
                    all_logs.append(log_line)

                # If we get here, streaming completed normally (EOF received)
                # Wait for thread to cleanup (should be done)
                await producer_future
                break

            except (
                ConnectionError,
                TimeoutError,
                OSError,
                http.client.IncompleteRead,
                httpx.RemoteProtocolError,
                httpx.ReadError,
                HfHubHTTPError,
            ) as e:
                # Connection dropped - check if job is still running
                try:
                    job_info = await _async_call(
                        self.api.inspect_job, job_id=job_id, namespace=namespace
                    )
                    current_status = job_info.status.stage

                    if current_status in terminal_states:
                        # Job finished, no need to retry
                        logger.info(f"Job reached terminal state: {current_status}")
                        break

                    # Job still running, retry connection
                    logger.warning(
                        f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..."
                    )
                    await asyncio.sleep(retry_delay)
                    continue

                except (ConnectionError, TimeoutError, OSError):
                    # Can't even check job status, wait and retry.
                    # NOTE(review): an HfHubHTTPError from inspect_job here is
                    # NOT caught and would abort the wait — confirm intended.
                    logger.warning(f"Connection error, retrying in {retry_delay}s...")
                    await asyncio.sleep(retry_delay)
                    continue

        # Fetch final job status (one extra inspect even after a clean EOF)
        job_info = await _async_call(
            self.api.inspect_job, job_id=job_id, namespace=namespace
        )
        final_status = job_info.status.stage

        return final_status, all_logs
463
+
464
    async def _run_job(self, args: Dict[str, Any]) -> ToolResult:
        """Run a job using HfApi.run_job() - smart detection of Python vs Docker mode.

        Python mode ('script' key): the script is wrapped into a `uv run`
        command and executed on a UV image. Docker mode ('command' key): the
        given argv runs on the given (or default) image. Blocks until the job
        reaches a terminal state, streaming logs, then returns status + logs.

        Raises:
            Exception: wraps any validation or API failure with a
                "Failed to run job" prefix (caller converts it to an error result).
        """
        try:
            script = args.get("script")
            command = args.get("command")

            # Validate mutually exclusive parameters
            if script and command:
                raise ValueError(
                    "'script' and 'command' are mutually exclusive. Provide one or the other, not both."
                )

            if not script and not command:
                raise ValueError(
                    "Either 'script' (for Python) or 'command' (for Docker) must be provided."
                )

            # Python mode: script provided
            if script:
                # Get dependencies and ensure hf-transfer is included
                deps = _ensure_hf_transfer_dependency(args.get("dependencies"))

                # Resolve the command based on script type (URL, inline, or file)
                command = _resolve_uv_command(
                    script=script,
                    with_deps=deps,
                    python=args.get("python"),
                    script_args=args.get("script_args"),
                )

                # Use UV image unless overridden
                image = args.get("image", UV_DEFAULT_IMAGE)
                job_type = "Python"

            # Docker mode: command provided
            else:
                image = args.get("image", "python:3.12")
                job_type = "Docker"

            # Run the job; the user's HF token is force-injected into secrets
            job = await _async_call(
                self.api.run_job,
                image=image,
                command=command,
                env=args.get("env"),
                secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
                flavor=args.get("hardware_flavor", "cpu-basic"),
                timeout=args.get("timeout", "30m"),
                namespace=self.namespace,
            )

            # Wait for completion and stream logs
            logger.info(f"{job_type} job started: {job.url}")
            logger.info("Streaming logs...")

            final_status, all_logs = await self._wait_for_job_completion(
                job_id=job.id,
                namespace=self.namespace,
            )

            # Filter out UV package installation output (keeps the agent context small)
            filtered_logs = _filter_uv_install_output(all_logs)

            # Format all logs for the agent
            log_text = "\n".join(filtered_logs) if filtered_logs else "(no logs)"

            response = f"""{job_type} job completed!

**Job ID:** {job.id}
**Final Status:** {final_status}
**View at:** {job.url}

**Logs:**
```
{log_text}
```"""
            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except Exception as e:
            raise Exception(f"Failed to run job: {str(e)}")
544
+
545
+ async def _list_jobs(self, args: Dict[str, Any]) -> ToolResult:
546
+ """List jobs using HfApi.list_jobs()"""
547
+ jobs_list = await _async_call(self.api.list_jobs, namespace=self.namespace)
548
+
549
+ # Filter jobs
550
+ if not args.get("all", False):
551
+ jobs_list = [j for j in jobs_list if j.status.stage == "RUNNING"]
552
+
553
+ if args.get("status"):
554
+ status_filter = args["status"].upper()
555
+ jobs_list = [j for j in jobs_list if status_filter in j.status.stage]
556
+
557
+ # Convert JobInfo objects to dicts for formatting
558
+ jobs_dicts = [_job_info_to_dict(j) for j in jobs_list]
559
+
560
+ table = format_jobs_table(jobs_dicts)
561
+
562
+ if len(jobs_list) == 0:
563
+ if args.get("all", False):
564
+ return {
565
+ "formatted": "No jobs found.",
566
+ "totalResults": 0,
567
+ "resultsShared": 0,
568
+ }
569
+ return {
570
+ "formatted": 'No running jobs found. Use `{"operation": "ps", "all": true}` to show all jobs.',
571
+ "totalResults": 0,
572
+ "resultsShared": 0,
573
+ }
574
+
575
+ response = f"**Jobs ({len(jobs_list)} total):**\n\n{table}"
576
+ return {
577
+ "formatted": response,
578
+ "totalResults": len(jobs_list),
579
+ "resultsShared": len(jobs_list),
580
+ }
581
+
582
+ async def _get_logs(self, args: Dict[str, Any]) -> ToolResult:
583
+ """Fetch logs using HfApi.fetch_job_logs()"""
584
+ job_id = args.get("job_id")
585
+ if not job_id:
586
+ return {
587
+ "formatted": "job_id is required",
588
+ "isError": True,
589
+ "totalResults": 0,
590
+ "resultsShared": 0,
591
+ }
592
+
593
+ try:
594
+ # Fetch logs (returns generator, convert to list)
595
+ logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=self.namespace)
596
+ logs = await _async_call(list, logs_gen)
597
+
598
+ if not logs:
599
+ return {
600
+ "formatted": f"No logs available for job {job_id}",
601
+ "totalResults": 0,
602
+ "resultsShared": 0,
603
+ }
604
+
605
+ log_text = "\n".join(logs)
606
+ return {
607
+ "formatted": f"**Logs for {job_id}:**\n\n```\n{log_text}\n```",
608
+ "totalResults": 1,
609
+ "resultsShared": 1,
610
+ }
611
+
612
+ except Exception as e:
613
+ return {
614
+ "formatted": f"Failed to fetch logs: {str(e)}",
615
+ "isError": True,
616
+ "totalResults": 0,
617
+ "resultsShared": 0,
618
+ }
619
+
620
+ async def _inspect_job(self, args: Dict[str, Any]) -> ToolResult:
621
+ """Inspect job using HfApi.inspect_job()"""
622
+ job_id = args.get("job_id")
623
+ if not job_id:
624
+ return {
625
+ "formatted": "job_id is required",
626
+ "totalResults": 0,
627
+ "resultsShared": 0,
628
+ "isError": True,
629
+ }
630
+
631
+ job_ids = job_id if isinstance(job_id, list) else [job_id]
632
+
633
+ jobs = []
634
+ for jid in job_ids:
635
+ try:
636
+ job = await _async_call(
637
+ self.api.inspect_job,
638
+ job_id=jid,
639
+ namespace=self.namespace,
640
+ )
641
+ jobs.append(_job_info_to_dict(job))
642
+ except Exception as e:
643
+ raise Exception(f"Failed to inspect job {jid}: {str(e)}")
644
+
645
+ formatted_details = format_job_details(jobs)
646
+ response = f"**Job Details** ({len(jobs)} job{'s' if len(jobs) > 1 else ''}):\n\n{formatted_details}"
647
+
648
+ return {
649
+ "formatted": response,
650
+ "totalResults": len(jobs),
651
+ "resultsShared": len(jobs),
652
+ }
653
+
654
+ async def _cancel_job(self, args: Dict[str, Any]) -> ToolResult:
655
+ """Cancel job using HfApi.cancel_job()"""
656
+ job_id = args.get("job_id")
657
+ if not job_id:
658
+ return {
659
+ "formatted": "job_id is required",
660
+ "totalResults": 0,
661
+ "resultsShared": 0,
662
+ "isError": True,
663
+ }
664
+
665
+ await _async_call(
666
+ self.api.cancel_job,
667
+ job_id=job_id,
668
+ namespace=self.namespace,
669
+ )
670
+
671
+ response = f"""✓ Job {job_id} has been cancelled.
672
+
673
+ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}}`"""
674
+
675
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
676
+
677
    async def _scheduled_run(self, args: Dict[str, Any]) -> ToolResult:
        """Create scheduled job using HfApi.create_scheduled_job() - smart detection of Python vs Docker mode.

        Mirrors _run_job's script/command resolution, but registers a cron-style
        schedule instead of executing immediately (so no logs are streamed).

        Raises:
            Exception: wraps any validation or API failure with a
                "Failed to create scheduled job" prefix.
        """
        try:
            script = args.get("script")
            command = args.get("command")
            schedule = args.get("schedule")

            if not schedule:
                raise ValueError("schedule is required for scheduled jobs")

            # Validate mutually exclusive parameters
            if script and command:
                raise ValueError(
                    "'script' and 'command' are mutually exclusive. Provide one or the other, not both."
                )

            if not script and not command:
                raise ValueError(
                    "Either 'script' (for Python) or 'command' (for Docker) must be provided."
                )

            # Python mode: script provided
            if script:
                # Get dependencies and ensure hf-transfer is included
                deps = _ensure_hf_transfer_dependency(args.get("dependencies"))

                # Resolve the command based on script type
                command = _resolve_uv_command(
                    script=script,
                    with_deps=deps,
                    python=args.get("python"),
                    script_args=args.get("script_args"),
                )

                # Use UV image unless overridden
                image = args.get("image", UV_DEFAULT_IMAGE)
                job_type = "Python"

            # Docker mode: command provided
            else:
                image = args.get("image", "python:3.12")
                job_type = "Docker"

            # Create scheduled job; the user's HF token is force-injected into secrets
            scheduled_job = await _async_call(
                self.api.create_scheduled_job,
                image=image,
                command=command,
                schedule=schedule,
                env=args.get("env"),
                secrets=_add_environment_variables(args.get("secrets"), self.hf_token),
                flavor=args.get("hardware_flavor", "cpu-basic"),
                timeout=args.get("timeout", "30m"),
                namespace=self.namespace,
            )

            scheduled_dict = _scheduled_job_info_to_dict(scheduled_job)

            response = f"""✓ Scheduled {job_type} job created successfully!

**Scheduled Job ID:** {scheduled_dict["id"]}
**Schedule:** {scheduled_dict["schedule"]}
**Suspended:** {"Yes" if scheduled_dict.get("suspend") else "No"}
**Next Run:** {scheduled_dict.get("nextRun", "N/A")}

To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_dict["id"]}"}}`
To list all, call this tool with `{{"operation": "scheduled ps"}}`"""

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except Exception as e:
            raise Exception(f"Failed to create scheduled job: {str(e)}")
749
+
750
+ async def _list_scheduled_jobs(self, args: Dict[str, Any]) -> ToolResult:
751
+ """List scheduled jobs using HfApi.list_scheduled_jobs()"""
752
+ scheduled_jobs_list = await _async_call(
753
+ self.api.list_scheduled_jobs,
754
+ namespace=self.namespace,
755
+ )
756
+
757
+ # Filter jobs - default: hide suspended jobs unless --all is specified
758
+ if not args.get("all", False):
759
+ scheduled_jobs_list = [j for j in scheduled_jobs_list if not j.suspend]
760
+
761
+ # Convert to dicts for formatting
762
+ scheduled_dicts = [_scheduled_job_info_to_dict(j) for j in scheduled_jobs_list]
763
+
764
+ table = format_scheduled_jobs_table(scheduled_dicts)
765
+
766
+ if len(scheduled_jobs_list) == 0:
767
+ if args.get("all", False):
768
+ return {
769
+ "formatted": "No scheduled jobs found.",
770
+ "totalResults": 0,
771
+ "resultsShared": 0,
772
+ }
773
+ return {
774
+ "formatted": 'No active scheduled jobs found. Use `{"operation": "scheduled ps", "all": true}` to show suspended jobs.',
775
+ "totalResults": 0,
776
+ "resultsShared": 0,
777
+ }
778
+
779
+ response = f"**Scheduled Jobs ({len(scheduled_jobs_list)} total):**\n\n{table}"
780
+ return {
781
+ "formatted": response,
782
+ "totalResults": len(scheduled_jobs_list),
783
+ "resultsShared": len(scheduled_jobs_list),
784
+ }
785
+
786
+ async def _inspect_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
787
+ """Inspect scheduled job using HfApi.inspect_scheduled_job()"""
788
+ scheduled_job_id = args.get("scheduled_job_id")
789
+ if not scheduled_job_id:
790
+ return {
791
+ "formatted": "scheduled_job_id is required",
792
+ "totalResults": 0,
793
+ "resultsShared": 0,
794
+ "isError": True,
795
+ }
796
+
797
+ scheduled_job = await _async_call(
798
+ self.api.inspect_scheduled_job,
799
+ scheduled_job_id=scheduled_job_id,
800
+ namespace=self.namespace,
801
+ )
802
+
803
+ scheduled_dict = _scheduled_job_info_to_dict(scheduled_job)
804
+ formatted_details = format_scheduled_job_details(scheduled_dict)
805
+
806
+ return {
807
+ "formatted": f"**Scheduled Job Details:**\n\n{formatted_details}",
808
+ "totalResults": 1,
809
+ "resultsShared": 1,
810
+ }
811
+
812
+ async def _delete_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
813
+ """Delete scheduled job using HfApi.delete_scheduled_job()"""
814
+ scheduled_job_id = args.get("scheduled_job_id")
815
+ if not scheduled_job_id:
816
+ return {
817
+ "formatted": "scheduled_job_id is required",
818
+ "totalResults": 0,
819
+ "resultsShared": 0,
820
+ "isError": True,
821
+ }
822
+
823
+ await _async_call(
824
+ self.api.delete_scheduled_job,
825
+ scheduled_job_id=scheduled_job_id,
826
+ namespace=self.namespace,
827
+ )
828
+
829
+ return {
830
+ "formatted": f"✓ Scheduled job {scheduled_job_id} has been deleted.",
831
+ "totalResults": 1,
832
+ "resultsShared": 1,
833
+ }
834
+
835
+ async def _suspend_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
836
+ """Suspend scheduled job using HfApi.suspend_scheduled_job()"""
837
+ scheduled_job_id = args.get("scheduled_job_id")
838
+ if not scheduled_job_id:
839
+ return {
840
+ "formatted": "scheduled_job_id is required",
841
+ "totalResults": 0,
842
+ "resultsShared": 0,
843
+ "isError": True,
844
+ }
845
+
846
+ await _async_call(
847
+ self.api.suspend_scheduled_job,
848
+ scheduled_job_id=scheduled_job_id,
849
+ namespace=self.namespace,
850
+ )
851
+
852
+ response = f"""✓ Scheduled job {scheduled_job_id} has been suspended.
853
+
854
+ To resume, call this tool with `{{"operation": "scheduled resume", "scheduled_job_id": "{scheduled_job_id}"}}`"""
855
+
856
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
857
+
858
+ async def _resume_scheduled_job(self, args: Dict[str, Any]) -> ToolResult:
859
+ """Resume scheduled job using HfApi.resume_scheduled_job()"""
860
+ scheduled_job_id = args.get("scheduled_job_id")
861
+ if not scheduled_job_id:
862
+ return {
863
+ "formatted": "scheduled_job_id is required",
864
+ "totalResults": 0,
865
+ "resultsShared": 0,
866
+ "isError": True,
867
+ }
868
+
869
+ await _async_call(
870
+ self.api.resume_scheduled_job,
871
+ scheduled_job_id=scheduled_job_id,
872
+ namespace=self.namespace,
873
+ )
874
+
875
+ response = f"""✓ Scheduled job {scheduled_job_id} has been resumed.
876
+
877
+ To inspect, call this tool with `{{"operation": "scheduled inspect", "scheduled_job_id": "{scheduled_job_id}"}}`"""
878
+
879
+ return {"formatted": response, "totalResults": 1, "resultsShared": 1}
880
+
881
+
882
# Tool specification for agent registration.
# Shape follows the function-calling convention: name + natural-language
# description + JSON-schema "parameters". The description doubles as the
# LLM-facing usage guide, so it is intentionally long and prescriptive.
HF_JOBS_TOOL_SPEC = {
    "name": "hf_jobs",
    # Human-readable guide injected into the model's tool list
    "description": (
        "Execute Python scripts or Docker containers on HF cloud infrastructure (CPUs/GPUs) in one of two modes. "
        "\n\n"
        "**Two Modes (mutually exclusive):**\n"
        "1. Python mode: using 'script' arg (REQUIRED) + 'dependencies'\n"
        "2. Docker mode: using 'command' arg (REQUIRED) + 'image'\n\n"
        "🚨 **REQUIRED:** You MUST provide exactly ONE of: 'script' (Python code as string) OR 'command' (Docker command as array). "
        "They are mutually exclusive - provide one or the other, never both, never neither. "
        "Do NOT call with just {'operation': 'run'} - always include your code. Example: {'operation': 'run', 'script': 'import torch; print(torch.cuda.is_available())', 'dependencies': ['torch']} or {'operation': 'run', 'command': ['duckdb', '-c', 'select 1 + 2']', 'image': 'duckdb/duckdb'}\n\n"
        "⚠️ CRITICAL for reliability: (1) Jobs run ASYNC - provide monitoring URL immediately, don't poll; "
        "(2) Set timeout >30min (default too short - training needs 2-8h); "
        "(3) HF_TOKEN auto-loaded to secrets for Hub ops (push_to_hub, private repos); "
        "(4) Job storage EPHEMERAL - MUST push_to_hub() or ALL work is LOST. "
        "**Use when:** User wants cloud compute, training models, data processing, batch inference, GPU workloads, scheduled tasks. "
        "ALWAYS use this tool (✓), never bash 'hf jobs' commands (✗). Pass script content inline (✓), don't save to files unless requested (✗). "
        "\n\n"
        "**Operations:** run, ps, logs, inspect, cancel, scheduled run, scheduled ps, scheduled inspect, scheduled delete, scheduled suspend, scheduled resume. "
        "**Available Hardware (vCPU/RAM/GPU):**\n"
        # Flavor lists are interpolated from module-level constants so the
        # spec stays in sync with the supported hardware
        f"• CPU: {CPU_FLAVORS_DESC}\n"
        f"• GPU: {GPU_FLAVORS_DESC}\n"
        " ◦ Common: t4-small ($0.60/hr, demos/1-3B models), a10g-small ($1/hr), a10g-large ($2/hr, production 7-13B), a100-large ($4/hr, 30B+), h100 ($6/hr, 70B+)\n\n"
        "**After Submission Ground Rules:**\n"
        "✓ Return immediately with job ID and monitoring URL\n"
        "✓ Provide expected completion time and cost estimate\n"
        "✓ For training: Include Trackio dashboard URL\n"
        "✓ Note user can check status later\n"
        "✗ DON'T poll logs automatically\n"
        "✗ DON'T wait for completion\n"
        "✗ DON'T check status unless user asks\n\n"
        "**For Training Tasks:**\n"
        "• ALWAYS research TRL docs first: explore_hf_docs('trl') → fetch_hf_docs(<trainer_url>)\n"
        "• ALWAYS validate dataset format with hub_repo_details (SFT needs messages/text, DPO needs chosen/rejected)\n"
        "• ALWAYS include Trackio monitoring in script (explore_hf_docs('trackio'))\n"
        "• ALWAYS enable push_to_hub=True in training config\n"
        "• Set timeout 2-8h for training (NOT default 30m)\n"
        "• Confirm model/dataset choices with user before submitting\n\n"
        "**Examples:**\n\n"
        "**Training - Fine-tune LLM:**\n"
        "{'operation': 'run', 'script': '# Training script with TRL\\nfrom trl import SFTConfig, SFTTrainer\\nfrom transformers import AutoModelForCausalLM\\nmodel = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen3-4B\")\\n# ... researched implementation from docs ...\\ntrainer.train()\\ntrainer.push_to_hub(\"user-name/my-model\")', 'dependencies': ['transformers', 'trl', 'torch', 'datasets', 'trackio'], 'hardware_flavor': 'a10g-large', 'timeout': '4h'}\n\n"
        "**Data Processing:**\n"
        "{'operation': 'run', 'script': 'from datasets import load_dataset\\nds = load_dataset(\"data\")\\n# process...\\nds.push_to_hub(\"user/processed\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-upgrade', 'timeout': '2h'}\n\n"
        "**Scheduled Daily Job:**\n"
        "{'operation': 'scheduled run', 'schedule': '@daily', 'script': 'from datasets import Dataset\\nimport pandas as pd\\n# scrape/generate data\\ndf = pd.DataFrame(data)\\nds = Dataset.from_pandas(df)\\nds.push_to_hub(\"user-name/daily-dataset\")', 'dependencies': ['datasets', 'pandas'], 'hardware_flavor': 'cpu-basic'}\n\n"
        "**Docker Mode:**\n"
        "{'operation': 'run', 'image': 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime', 'command': ['python', 'train.py', '--epochs', '10'], 'hardware_flavor': 'a100-large'}\n\n"
        "**Monitor Operations:**\n"
        "{'operation': 'ps'} - List all jobs\n"
        "{'operation': 'logs', 'job_id': 'xxx'} - Stream logs (only when user requests)\n"
        "{'operation': 'inspect', 'job_id': 'xxx'} - Get job details\n"
        "{'operation': 'cancel', 'job_id': 'xxx'} - Stop job\n\n"
        "⚠️ CRITICAL: Files created during execution are DELETED when job finishes. MUST push_to_hub() all outputs (models, datasets, artifacts) in script. For logs/scripts, use hf_private_repos after completion."
    ),
    # JSON-schema style declaration consumed by the agent's tool router
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": [
                    "run",
                    "ps",
                    "logs",
                    "inspect",
                    "cancel",
                    "scheduled run",
                    "scheduled ps",
                    "scheduled inspect",
                    "scheduled delete",
                    "scheduled suspend",
                    "scheduled resume",
                ],
                "description": (
                    "Operation to execute. Valid values: [run, ps, logs, inspect, cancel, "
                    "scheduled run, scheduled ps, scheduled inspect, scheduled delete, "
                    "scheduled suspend, scheduled resume]"
                ),
            },
            # Python/UV specific parameters
            "script": {
                "type": "string",
                "description": "Python code to execute. Triggers Python mode (auto pip install). Use with 'run'/'scheduled run'. Mutually exclusive with 'command'.",
            },
            "dependencies": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Pip packages to install. Example: ['trl', 'torch', 'datasets', 'transformers']. Only used with 'script'.",
            },
            # Docker specific parameters
            "image": {
                "type": "string",
                "description": "Docker image. Example: 'pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime'. Use with 'run'/'scheduled run'. Optional (auto-selected if not provided).",
            },
            "command": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Command to execute as list. Example: ['python', 'train.py', '--epochs', '10']. Triggers Docker mode. Use with 'run'/'scheduled run'. Mutually exclusive with 'script'.",
            },
            # Hardware and environment
            "hardware_flavor": {
                "type": "string",
                "description": f"Hardware type. Available CPU flavors: {CPU_FLAVORS}. Available GPU flavors: {GPU_FLAVORS}. Use with 'run'/'scheduled run'.",
            },
            "timeout": {
                "type": "string",
                "description": "Max runtime. Examples: '30m', '2h', '4h'. Default: '30m'. Important for long training jobs. Use with 'run'/'scheduled run'.",
            },
            "env": {
                "type": "object",
                "description": "Environment variables. Format: {'KEY': 'VALUE'}. HF_TOKEN is automatically included from your auth. Use with 'run'/'scheduled run'.",
            },
            # Job management parameters
            "job_id": {
                "type": "string",
                "description": "Job ID to operate on. Required for: 'logs', 'inspect', 'cancel'.",
            },
            # Scheduled job parameters
            "scheduled_job_id": {
                "type": "string",
                "description": "Scheduled job ID. Required for: 'scheduled inspect', 'scheduled delete', 'scheduled suspend', 'scheduled resume'.",
            },
            "schedule": {
                "type": "string",
                "description": "Schedule for recurring job. Presets: '@hourly', '@daily', '@weekly', '@monthly'. Cron: '0 9 * * 1' (Mon 9am). Required for: 'scheduled run'.",
            },
        },
        # Only 'operation' is schema-required; run/scheduled-run argument
        # validation (script XOR command, schedule, ids) happens at runtime
        "required": ["operation"],
    },
}
1012
+
1013
+
1014
async def hf_jobs_handler(
    arguments: Dict[str, Any], session: Any = None
) -> tuple[str, bool]:
    """Adapter between the agent's tool router and HfJobsTool.

    Resolves the auth token and namespace, runs the requested operation,
    and returns (formatted_output, success_flag); never raises.
    """
    try:
        # Forward streamed job log lines to the session as tool_log events
        log_callback = None
        if session:

            async def log_callback(log: str):
                await session.send_event(
                    Event(event_type="tool_log", data={"tool": "hf_jobs", "log": log})
                )

        # Prefer the authenticated user's OAuth token, fall back to global env
        hf_token = (
            (getattr(session, "hf_token", None) if session else None)
            or os.environ.get("HF_TOKEN")
            or os.environ.get("HUGGINGFACE_HUB_TOKEN")
        )

        # Namespace: explicit env override wins, else the token owner's name
        namespace = os.environ.get("HF_NAMESPACE")
        if not namespace and hf_token:
            namespace = HfApi(token=hf_token).whoami().get("name")

        tool = HfJobsTool(
            namespace=namespace,
            hf_token=hf_token,
            log_callback=log_callback,
        )
        result = await tool.execute(arguments)
        return result["formatted"], not result.get("isError", False)
    except Exception as e:
        return f"Error executing HF Jobs tool: {str(e)}", False
agent/tools/plan_tool.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from agent.core.session import Event
4
+ from agent.utils.terminal_display import format_plan_tool_output
5
+
6
+ from .types import ToolResult
7
+
8
+ # In-memory storage for the current plan (raw structure from agent)
9
+ _current_plan: List[Dict[str, str]] = []
10
+
11
+
12
class PlanTool:
    """Tool that records the agent's todo plan and reports it back."""

    def __init__(self, session: Any = None):
        # Optional session used to push plan updates to the client UI.
        self.session = session

    @staticmethod
    def _validation_error(todos: Any) -> str | None:
        """Return an error message for a malformed todo list, else None."""
        statuses = ["pending", "in_progress", "completed"]
        for entry in todos:
            if not isinstance(entry, dict):
                return "Error: Each todo must be an object. Re call the tool with correct format (mandatory)."
            for field in ("id", "content", "status"):
                if field not in entry:
                    return f"Error: Todo missing required field '{field}'. Re call the tool with correct format (mandatory)."
            if entry["status"] not in statuses:
                return (
                    f"Error: Invalid status '{entry['status']}'. "
                    f"Must be one of: {', '.join(statuses)}. Re call the tool with correct format (mandatory)."
                )
        return None

    async def execute(self, params: Dict[str, Any]) -> ToolResult:
        """Validate the plan, store it, notify the session, and format output.

        Args:
            params: Must contain "todos": a list of dicts, each with
                id, content, and status keys.

        Returns:
            ToolResult with the rendered plan, or an error description.
        """
        global _current_plan

        todos = params.get("todos", [])

        problem = self._validation_error(todos)
        if problem is not None:
            return {
                "formatted": problem,
                "isError": True,
            }

        # Each call replaces the whole plan (the tool contract is full-list).
        _current_plan = todos

        # Emit a plan update event so the client can render progress live.
        if self.session:
            await self.session.send_event(
                Event(
                    event_type="plan_update",
                    data={"plan": todos},
                )
            )

        return {
            "formatted": format_plan_tool_output(todos),
            "totalResults": len(todos),
            "isError": False,
        }
77
+
78
+
79
def get_current_plan() -> List[Dict[str, str]]:
    """Return the most recently stored plan (empty list when none was set)."""
    return _current_plan
82
+
83
+
84
# Tool specification
# JSON-schema style definition registered with the agent's tool router.
# The long description is deliberate: it is prompt guidance telling the
# model when and how to call the tool.
PLAN_TOOL_SPEC = {
    "name": "plan_tool",
    "description": (
        "Manage task planning and progress tracking with todo list (pending/in_progress/completed statuses). "
        "⚠️ CRITICAL: ALWAYS use for multi-step tasks (3+ steps) and MUST update frequently to show progress. "
        "**Use when:** (1) User provides multiple tasks, (2) Complex workflows (training, evaluation, data processing), "
        "(3) Tasks requiring multiple tool calls, (4) Need to communicate progress clearly to user, "
        "(5) Breaking down ambiguous requests into concrete steps. "
        "**Pattern:** Create plan at start → Mark in_progress when starting task → Mark completed immediately after finishing → User sees clear progress. "
        "Each call replaces entire plan (full list required). "
        "**Critical for reliability:** Exactly ONE task in_progress at a time (not zero, not multiple). "
        "Mark tasks completed IMMEDIATELY after finishing - don't batch completions. "
        "**For long-running tasks:** Update plan after each major step to keep user informed. "
        "**Only mark completed when:** Task fully accomplished, no errors, all requirements met. "
        "Keep tasks pending if blocked/errors occur - create new task to resolve blockers."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "todos": {
                "type": "array",
                "description": "List of todo items",
                "items": {
                    "type": "object",
                    "properties": {
                        "id": {
                            "type": "string",
                            "description": "Unique identifier for the todo",
                        },
                        "content": {
                            "type": "string",
                            "description": "Description of the todo task",
                        },
                        "status": {
                            "type": "string",
                            "enum": ["pending", "in_progress", "completed"],
                            "description": "Current status of the todo",
                        },
                    },
                    # All three keys are validated again at runtime by PlanTool.
                    "required": ["id", "content", "status"],
                },
            }
        },
        "required": ["todos"],
    },
}
131
+
132
+
133
async def plan_tool_handler(
    arguments: Dict[str, Any], session: Any = None
) -> tuple[str, bool]:
    """Adapter for the agent tool router: run PlanTool, return (text, ok)."""
    outcome = await PlanTool(session=session).execute(arguments)
    return outcome["formatted"], not outcome.get("isError", False)
agent/tools/private_hf_repo_tools.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Private HF Repos Tool - Manage private Hugging Face repositories
3
+
4
+ PRIMARY USE: Store job outputs, training scripts, and logs from HF Jobs.
5
+ Since job results are ephemeral, this tool provides persistent storage in private repos.
6
+
7
+ SECONDARY USE: Read back stored files and list repo contents.
8
+ """
9
+
10
+ import asyncio
11
+ from typing import Any, Dict, Literal, Optional
12
+
13
+ from huggingface_hub import HfApi, hf_hub_download
14
+ from huggingface_hub.utils import HfHubHTTPError
15
+
16
+ from agent.tools.types import ToolResult
17
+
18
+ # Operation names
19
+ OperationType = Literal[
20
+ "upload_file", "create_repo", "check_repo", "list_files", "read_file"
21
+ ]
22
+
23
+
24
+ async def _async_call(func, *args, **kwargs):
25
+ """Wrap synchronous HfApi calls for async context."""
26
+ return await asyncio.to_thread(func, *args, **kwargs)
27
+
28
+
29
+ def _build_repo_url(repo_id: str, repo_type: str = "dataset") -> str:
30
+ """Build the Hub URL for a repository."""
31
+ type_path = "" if repo_type == "model" else f"{repo_type}s"
32
+ return f"https://huggingface.co/{type_path}/{repo_id}".replace("//", "/")
33
+
34
+
35
+ def _content_to_bytes(content: str | bytes) -> bytes:
36
+ """Convert string or bytes content to bytes."""
37
+ if isinstance(content, str):
38
+ return content.encode("utf-8")
39
+ return content
40
+
41
+
42
class PrivateHfRepoTool:
    """Tool for managing private Hugging Face repositories.

    All blocking ``HfApi`` calls are dispatched through ``_async_call``
    (an ``asyncio.to_thread`` wrapper) so the agent loop stays responsive.
    """

    def __init__(self, hf_token: Optional[str] = None):
        # Token-scoped client; with hf_token=None, HfApi falls back to
        # ambient credentials (env vars / cached login).
        self.api = HfApi(token=hf_token)

    async def execute(self, params: Dict[str, Any]) -> ToolResult:
        """Execute the specified upload operation.

        Args:
            params: ``{"operation": <name>, "args": {...}}``. Calling with
                no operation returns full usage instructions instead of an
                error.

        Returns:
            ToolResult describing the outcome (``isError=True`` on failure).
        """
        operation = params.get("operation")
        args = params.get("args", {})

        # If no operation provided, return usage instructions
        if not operation:
            return self._show_help()

        # Normalize operation name
        operation = operation.lower()

        # Check if help is requested
        if args.get("help"):
            return self._show_operation_help(operation)

        try:
            # Route to appropriate handler
            if operation == "upload_file":
                return await self._upload_file(args)
            elif operation == "create_repo":
                return await self._create_repo(args)
            elif operation == "check_repo":
                return await self._check_repo(args)
            elif operation == "list_files":
                return await self._list_files(args)
            elif operation == "read_file":
                return await self._read_file(args)
            else:
                return {
                    "formatted": f'Unknown operation: "{operation}"\n\n'
                    "Available operations: upload_file, create_repo, check_repo, list_files, read_file\n\n"
                    "Call this tool with no operation for full usage instructions.",
                    "totalResults": 0,
                    "resultsShared": 0,
                    "isError": True,
                }

        except HfHubHTTPError as e:
            # Hub-side HTTP failures (auth, 404, rate limiting, ...)
            return {
                "formatted": f"API Error: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }
        except Exception as e:
            return {
                "formatted": f"Error executing {operation}: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

    def _show_help(self) -> ToolResult:
        """Show usage instructions when tool is called with no arguments."""
        usage_text = """# Private HF Repos Tool

**PRIMARY USE:** Store job outputs, scripts, and logs from HF Jobs to private repos.
Since job results are ephemeral, use this tool for persistent storage.

**SECONDARY USE:** Read back stored files and list repo contents.

## Available Commands

### Write Operations
- **upload_file** - Upload file content to a repository
- **create_repo** - Create a new private repository

### Read Operations
- **list_files** - List all files in a repository
- **read_file** - Read content of a specific file from a repository
- **check_repo** - Check if a repository exists

## Examples

### Upload a script to a dataset repo
Call this tool with:
```json
{
  "operation": "upload_file",
  "args": {
    "file_content": "import pandas as pd\\nprint('Hello from HF!')",
    "path_in_repo": "scripts/hello.py",
    "repo_id": "my-dataset",
    "repo_type": "dataset",
    "create_if_missing": true,
    "commit_message": "Add hello script"
  }
}
```

### Upload logs from a job
Call this tool with:
```json
{
  "operation": "upload_file",
  "args": {
    "file_content": "Job started...\\nJob completed successfully!",
    "path_in_repo": "jobs/job-abc123/logs.txt",
    "repo_id": "job-results",
    "create_if_missing": true
  }
}
```

### Create a repository
Call this tool with:
```json
{
  "operation": "create_repo",
  "args": {
    "repo_id": "my-results",
    "repo_type": "dataset"
  }
}
```

### Create a Space
Call this tool with:
```json
{
  "operation": "create_repo",
  "args": {
    "repo_id": "my-gradio-app",
    "repo_type": "space",
    "space_sdk": "gradio"
  }
}
```
Note: Repositories are always created as private. For spaces, `space_sdk` is required (gradio, streamlit, static, or docker).

### Check if a repository exists
Call this tool with:
```json
{
  "operation": "check_repo",
  "args": {
    "repo_id": "my-dataset",
    "repo_type": "dataset"
  }
}
```

### List files in a repository
Call this tool with:
```json
{
  "operation": "list_files",
  "args": {
    "repo_id": "job-results",
    "repo_type": "dataset"
  }
}
```

### Read a file from a repository
Call this tool with:
```json
{
  "operation": "read_file",
  "args": {
    "repo_id": "job-results",
    "path_in_repo": "jobs/job-abc123/script.py",
    "repo_type": "dataset"
  }
}
```

## Repository Types

- **dataset** (default) - For storing data, results, logs, scripts
- **model** - For ML models and related artifacts
- **space** - For Spaces and applications

## Tips

- **Content-based**: Pass file content directly as strings or bytes, not file paths
- **Repo ID format**: Use just the repo name (e.g., "my-dataset"). Username is automatically inferred from HF_TOKEN
- **Automatic repo creation**: Set `create_if_missing: true` to auto-create repos (requires user approval)
- **Organization**: Use path_in_repo to organize files (e.g., "jobs/job-123/script.py")
- **After jobs**: Upload job scripts and logs after compute jobs complete for reproducibility
"""
        return {"formatted": usage_text, "totalResults": 1, "resultsShared": 1}

    def _show_operation_help(self, operation: str) -> ToolResult:
        """Show help for a specific operation."""
        help_text = f"Help for operation: {operation}\n\nCall with appropriate arguments. Use the main help for examples."
        return {"formatted": help_text, "totalResults": 1, "resultsShared": 1}

    async def _upload_file(self, args: Dict[str, Any]) -> ToolResult:
        """Upload file content to a Hub repository."""
        # Validate required arguments.
        # NOTE: file_content is compared against None (not truthiness) so an
        # intentionally empty file ("" or b"") can still be uploaded.
        file_content = args.get("file_content")
        path_in_repo = args.get("path_in_repo")
        repo_id = args.get("repo_id")

        if file_content is None:
            return {
                "formatted": "file_content is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        if not path_in_repo:
            return {
                "formatted": "path_in_repo is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        if not repo_id:
            return {
                "formatted": "repo_id is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        repo_type = args.get("repo_type", "dataset")
        create_if_missing = args.get("create_if_missing", False)

        # Check if repo exists
        try:
            repo_exists = await _async_call(
                self.api.repo_exists, repo_id=repo_id, repo_type=repo_type
            )

            # Create repo if needed
            if not repo_exists and create_if_missing:
                create_args = {
                    "repo_id": repo_id,
                    "repo_type": repo_type,
                    "private": True,
                }
                # Pass through space_sdk if provided (required for spaces)
                if "space_sdk" in args:
                    create_args["space_sdk"] = args["space_sdk"]
                # BUGFIX: surface creation failures instead of ignoring the
                # result and attempting an upload that is doomed to fail with
                # a less useful error message.
                create_result = await self._create_repo(create_args)
                if create_result.get("isError"):
                    return create_result
            elif not repo_exists:
                return {
                    "formatted": f"Repository {repo_id} does not exist. Set create_if_missing: true to create it.",
                    "totalResults": 0,
                    "resultsShared": 0,
                    "isError": True,
                }

        except Exception as e:
            return {
                "formatted": f"Failed to check repository: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        # Convert content to bytes
        file_bytes = _content_to_bytes(file_content)

        # Upload file
        try:
            await _async_call(
                self.api.upload_file,
                path_or_fileobj=file_bytes,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type=repo_type,
                commit_message=args.get("commit_message", f"Upload {path_in_repo}"),
            )

            repo_url = _build_repo_url(repo_id, repo_type)
            file_url = f"{repo_url}/blob/main/{path_in_repo}"

            response = f"""✓ File uploaded successfully!

**Repository:** {repo_id}
**File:** {path_in_repo}
**View at:** {file_url}
**Browse repo:** {repo_url}"""

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except Exception as e:
            return {
                "formatted": f"Failed to upload file: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

    async def _create_repo(self, args: Dict[str, Any]) -> ToolResult:
        """Create a new Hub repository (always private)."""
        repo_id = args.get("repo_id")

        if not repo_id:
            return {
                "formatted": "repo_id is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        repo_type = args.get("repo_type", "dataset")
        private = True  # Always create private repos
        space_sdk = args.get("space_sdk")  # Required if repo_type is "space"

        try:
            # Check if repo already exists
            repo_exists = await _async_call(
                self.api.repo_exists, repo_id=repo_id, repo_type=repo_type
            )

            if repo_exists:
                repo_url = _build_repo_url(repo_id, repo_type)
                return {
                    "formatted": f"Repository {repo_id} already exists.\n**View at:** {repo_url}",
                    "totalResults": 1,
                    "resultsShared": 1,
                }

            # Validate space_sdk for spaces
            if repo_type == "space" and not space_sdk:
                return {
                    "formatted": "space_sdk is required when creating a space. Valid values: gradio, streamlit, static, docker",
                    "totalResults": 0,
                    "resultsShared": 0,
                    "isError": True,
                }

            # Create repository
            create_kwargs = {
                "repo_id": repo_id,
                "repo_type": repo_type,
                "private": private,
                "exist_ok": True,
            }
            # Add space_sdk only for spaces
            if repo_type == "space" and space_sdk:
                create_kwargs["space_sdk"] = space_sdk

            repo_url = await _async_call(self.api.create_repo, **create_kwargs)

            response = f"""✓ Repository created successfully!

**Repository:** {repo_id}
**Type:** {repo_type}
**Private:** Yes
**View at:** {repo_url}"""

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except Exception as e:
            return {
                "formatted": f"Failed to create repository: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

    async def _check_repo(self, args: Dict[str, Any]) -> ToolResult:
        """Check if a Hub repository exists."""
        repo_id = args.get("repo_id")

        if not repo_id:
            return {
                "formatted": "repo_id is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        repo_type = args.get("repo_type", "dataset")

        try:
            repo_exists = await _async_call(
                self.api.repo_exists, repo_id=repo_id, repo_type=repo_type
            )

            if repo_exists:
                repo_url = _build_repo_url(repo_id, repo_type)
                response = f"""✓ Repository exists!

**Repository:** {repo_id}
**Type:** {repo_type}
**View at:** {repo_url}"""
            else:
                # Include a ready-to-use create_repo invocation in the reply.
                response = f"""Repository does not exist: {repo_id}

To create it, call this tool with:
```json
{{
  "operation": "create_repo",
  "args": {{
    "repo_id": "{repo_id}",
    "repo_type": "{repo_type}"
  }}
}}
```"""

            return {
                "formatted": response,
                "totalResults": 1 if repo_exists else 0,
                "resultsShared": 1 if repo_exists else 0,
            }

        except Exception as e:
            return {
                "formatted": f"Failed to check repository: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

    async def _list_files(self, args: Dict[str, Any]) -> ToolResult:
        """List all files in a Hub repository."""
        repo_id = args.get("repo_id")

        if not repo_id:
            return {
                "formatted": "repo_id is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        repo_type = args.get("repo_type", "dataset")

        try:
            # List all files in the repository
            files = await _async_call(
                self.api.list_repo_files, repo_id=repo_id, repo_type=repo_type
            )

            if not files:
                return {
                    "formatted": f"No files found in repository: {repo_id}",
                    "totalResults": 0,
                    "resultsShared": 0,
                }

            # Format file list
            file_list = "\n".join(f"- {f}" for f in sorted(files))
            repo_url = _build_repo_url(repo_id, repo_type)

            response = f"""✓ Files in repository: {repo_id}

**Total files:** {len(files)}
**Repository URL:** {repo_url}

**Files:**
{file_list}"""

            return {
                "formatted": response,
                "totalResults": len(files),
                "resultsShared": len(files),
            }

        except Exception as e:
            return {
                "formatted": f"Failed to list files: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

    async def _read_file(self, args: Dict[str, Any]) -> ToolResult:
        """Read content of a specific file from a Hub repository."""
        repo_id = args.get("repo_id")
        path_in_repo = args.get("path_in_repo")

        if not repo_id:
            return {
                "formatted": "repo_id is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        if not path_in_repo:
            return {
                "formatted": "path_in_repo is required",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }

        repo_type = args.get("repo_type", "dataset")

        try:
            # Download file to the local HF cache and read it from there.
            file_path = await _async_call(
                hf_hub_download,
                repo_id=repo_id,
                filename=path_in_repo,
                repo_type=repo_type,
                token=self.api.token,
            )

            # Read file content
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()

            repo_url = _build_repo_url(repo_id, repo_type)
            file_url = f"{repo_url}/blob/main/{path_in_repo}"

            response = f"""✓ File read successfully!

**Repository:** {repo_id}
**File:** {path_in_repo}
**Size:** {len(content)} characters
**View at:** {file_url}

**Content:**
```
{content}
```"""

            return {"formatted": response, "totalResults": 1, "resultsShared": 1}

        except UnicodeDecodeError:
            # UnicodeDecodeError can only come from the read() above, so
            # file_path is guaranteed to be bound here. Report size instead.
            try:
                with open(file_path, "rb") as f:
                    binary_content = f.read()

                return {
                    "formatted": f"File is binary ({len(binary_content)} bytes). Cannot display as text.",
                    "totalResults": 1,
                    "resultsShared": 1,
                }
            except Exception as e:
                return {
                    "formatted": f"Failed to read binary file: {str(e)}",
                    "totalResults": 0,
                    "resultsShared": 0,
                    "isError": True,
                }
        except Exception as e:
            return {
                "formatted": f"Failed to read file: {str(e)}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": True,
            }
593
+
594
+
595
# Tool specification for agent registration
# JSON-schema style spec; the description text doubles as prompt guidance
# for the model on when and how to use the tool.
PRIVATE_HF_REPO_TOOL_SPEC = {
    "name": "hf_private_repos",
    "description": (
        "Manage private HF repositories - create, upload, read, list files in models/datasets/spaces. "
        "⚠️ PRIMARY USE: Store job outputs persistently (job storage is EPHEMERAL - everything deleted after completion). "
        "**Use when:** (1) Job completes and need to store logs/scripts/results, (2) Creating repos for training outputs, "
        "(3) Reading back stored files, (4) Managing Space files, (5) Organizing job artifacts by path. "
        "**Pattern:** hf_jobs (ephemeral) → hf_private_repos upload_file (persistent) → can read_file later. "
        "ALWAYS pass file_content as string/bytes (✓), never file paths (✗) - this is content-based, no filesystem access. "
        "**Operations:** create_repo (new private repo), upload_file (store content), read_file (retrieve content), list_files (browse), check_repo (verify exists). "
        "**Critical for reliability:** Jobs lose all files after completion - use this tool to preserve important outputs. "
        "Repositories created are ALWAYS private by default (good for sensitive training data/models). "
        "For Spaces: must provide space_sdk ('gradio', 'streamlit', 'static', 'docker') when creating. "
        "**Then:** After uploading, provide user with repository URL for viewing/sharing."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "enum": [
                    "upload_file",
                    "create_repo",
                    "check_repo",
                    "list_files",
                    "read_file",
                ],
                "description": (
                    "Operation to execute. Valid values: [upload_file, create_repo, check_repo, list_files, read_file]"
                ),
            },
            "args": {
                "type": "object",
                "description": (
                    "Operation-specific arguments as a JSON object. "
                    "Write ops: file_content (string/bytes), path_in_repo (string), repo_id (string), "
                    "repo_type (dataset/model/space), create_if_missing (boolean), commit_message (string), "
                    "space_sdk (gradio/streamlit/static/docker - required when repo_type=space). "
                    "Read ops: repo_id (string), path_in_repo (for read_file), repo_type (optional)."
                ),
                # Free-form per-operation args; validated by the tool itself.
                "additionalProperties": True,
            },
        },
    },
}
641
+
642
+
643
async def private_hf_repo_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapter for the agent tool router: run the tool, return (text, ok)."""
    try:
        outcome = await PrivateHfRepoTool().execute(arguments)
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as e:
        return f"Error executing Private HF Repo tool: {str(e)}", False
agent/tools/types.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Types for Hugging Face tools
3
+
4
+ Ported from: hf-mcp-server/packages/mcp/src/types/
5
+ """
6
+
7
+ from typing import TypedDict
8
+
9
+
10
class ToolResult(TypedDict, total=False):
    """Result returned by HF tool operations"""

    # Human-readable (often markdown) text presented to the agent/user.
    formatted: str
    # Total number of results the operation found.
    totalResults: int
    # How many of those results were included in `formatted`.
    resultsShared: int
    # True when the operation failed; `formatted` then holds the error text.
    isError: bool
agent/tools/utilities.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for Hugging Face tools
3
+
4
+ Ported from: hf-mcp-server/packages/mcp/src/jobs/formatters.ts
5
+ Includes GPU memory validation for job submissions
6
+ """
7
+
8
+ import json
9
+ from datetime import datetime
10
+ from typing import Any, Dict, List, Optional
11
+
12
+
13
def truncate(text: str, max_length: int) -> str:
    """Clip *text* to *max_length* characters, ending with '...' when clipped."""
    return text if len(text) <= max_length else text[: max_length - 3] + "..."
18
+
19
+
20
def format_date(date_str: Optional[str]) -> str:
    """Render an ISO-8601 timestamp as 'YYYY-MM-DD HH:MM:SS'.

    Returns "N/A" for missing/empty input; on a parse failure the raw
    input string is returned unchanged.
    """
    if not date_str:
        return "N/A"
    try:
        # 'Z' suffix is not accepted by fromisoformat on older Pythons.
        parsed = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
        return parsed.strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        return date_str
29
+
30
+
31
def format_command(command: Optional[List[str]]) -> str:
    """Join a command argv into one display string; "N/A" when absent/empty."""
    return " ".join(command) if command else "N/A"
36
+
37
+
38
def get_image_or_space(job: Dict[str, Any]) -> str:
    """Return the job's Space ID or Docker image, preferring the Space."""
    for key in ("spaceId", "dockerImage"):
        value = job.get(key)
        if value:
            return value
    return "N/A"
45
+
46
+
47
def format_jobs_table(jobs: List[Dict[str, Any]]) -> str:
    """Render job records as a fixed-width markdown table."""
    if not jobs:
        return "No jobs found."

    order = ("id", "image", "command", "created", "status")

    # The ID column grows to fit the longest ID (never narrower than header).
    widths = {
        "id": max(max(len(j["id"]) for j in jobs), len("JOB ID")),
        "image": 20,
        "command": 30,
        "created": 19,
        "status": 12,
    }
    headers = {
        "id": "JOB ID",
        "image": "IMAGE/SPACE",
        "command": "COMMAND",
        "created": "CREATED",
        "status": "STATUS",
    }

    def _row(cells: Dict[str, str]) -> str:
        return "| " + " | ".join(cells[k].ljust(widths[k]) for k in order) + " |"

    table = [
        _row(headers),
        "|" + "|".join("-" * (widths[k] + 2) for k in order) + "|",
    ]
    for job in jobs:
        table.append(
            _row(
                {
                    "id": job["id"],
                    "image": truncate(get_image_or_space(job), widths["image"]),
                    "command": truncate(
                        format_command(job.get("command")), widths["command"]
                    ),
                    "created": truncate(
                        format_date(job.get("createdAt")), widths["created"]
                    ),
                    "status": truncate(job["status"]["stage"], widths["status"]),
                }
            )
        )
    return "\n".join(table)
83
+
84
+
85
def format_scheduled_jobs_table(jobs: List[Dict[str, Any]]) -> str:
    """Render scheduled-job records as a fixed-width markdown table."""
    if not jobs:
        return "No scheduled jobs found."

    order = ("id", "schedule", "image", "command", "lastRun", "nextRun", "suspend")

    # The ID column grows to fit the longest ID (never narrower than header).
    widths = {
        "id": max(max(len(j["id"]) for j in jobs), len("ID")),
        "schedule": 12,
        "image": 18,
        "command": 25,
        "lastRun": 19,
        "nextRun": 19,
        "suspend": 9,
    }
    headers = {
        "id": "ID",
        "schedule": "SCHEDULE",
        "image": "IMAGE/SPACE",
        "command": "COMMAND",
        "lastRun": "LAST RUN",
        "nextRun": "NEXT RUN",
        "suspend": "SUSPENDED",
    }

    def _row(cells: Dict[str, str]) -> str:
        return "| " + " | ".join(cells[k].ljust(widths[k]) for k in order) + " |"

    table = [
        _row(headers),
        "|" + "|".join("-" * (widths[k] + 2) for k in order) + "|",
    ]
    for job in jobs:
        spec = job["jobSpec"]
        table.append(
            _row(
                {
                    "id": job["id"],
                    "schedule": truncate(job["schedule"], widths["schedule"]),
                    "image": truncate(get_image_or_space(spec), widths["image"]),
                    "command": truncate(
                        format_command(spec.get("command")), widths["command"]
                    ),
                    "lastRun": truncate(
                        format_date(job.get("lastRun")), widths["lastRun"]
                    ),
                    "nextRun": truncate(
                        format_date(job.get("nextRun")), widths["nextRun"]
                    ),
                    "suspend": "Yes" if job.get("suspend") else "No",
                }
            )
        )
    return "\n".join(table)
127
+
128
+
129
def format_job_details(jobs: Any) -> str:
    """Pretty-print one job (or a list of jobs) as a fenced JSON block."""
    payload = jobs if isinstance(jobs, list) else [jobs]
    return "```json\n" + json.dumps(payload, indent=2) + "\n```"
135
+
136
+
137
def format_scheduled_job_details(jobs: Any) -> str:
    """Pretty-print one scheduled job (or a list) as a fenced JSON block."""
    payload = jobs if isinstance(jobs, list) else [jobs]
    return "```json\n" + json.dumps(payload, indent=2) + "\n```"
agent/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Utility functions and helpers
3
+ """
agent/utils/reliability_checks.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reliability checks for job submissions and other operations"""
2
+
3
+ from agent.utils.terminal_display import Colors
4
+
5
+
6
+ def check_training_script_save_pattern(script: str) -> str | None:
7
+ """Check if a training script properly saves models."""
8
+ has_from_pretrained = "from_pretrained" in script
9
+ has_push_to_hub = "push_to_hub" in script
10
+
11
+ if has_from_pretrained and not has_push_to_hub:
12
+ return f"\n{Colors.RED}WARNING: We've detected that no model will be saved at the end of this training script. Please ensure this is what you want.{Colors.RESET}"
13
+ elif has_from_pretrained and has_push_to_hub:
14
+ return f"\n{Colors.GREEN}We've detected that a model will be pushed to hub at the end of this training.{Colors.RESET}"
15
+
16
+ return None
agent/utils/terminal_display.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Terminal display utilities with colors and formatting
3
+ """
4
+
5
+
6
+ # ANSI color codes
7
class Colors:
    """ANSI escape sequences for terminal text color and styling."""

    # Bright foreground colors (ANSI 90-series codes)
    RED = "\033[91m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    BLUE = "\033[94m"
    MAGENTA = "\033[95m"
    CYAN = "\033[96m"
    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    # Clears all colors and attributes back to the terminal default
    RESET = "\033[0m"
17
+
18
+
19
def truncate_to_lines(text: str, max_lines: int = 6) -> str:
    """Limit *text* to at most *max_lines* lines.

    When truncation happens, a cyan "... (N more lines)" marker is appended
    so the reader knows how much output was hidden.
    """
    lines = text.split("\n")
    hidden = len(lines) - max_lines
    if hidden <= 0:
        return text
    shown = "\n".join(lines[:max_lines])
    return f"{shown}\n{Colors.CYAN}... ({hidden} more lines){Colors.RESET}"
+ )
28
+
29
+
30
def format_header(text: str, emoji: str = "") -> str:
    """Render *text* as a bold header, optionally prefixed with an emoji."""
    label = text if not emoji else f"{emoji} {text}"
    return f"{Colors.BOLD}{label}{Colors.RESET}"
34
+
35
+
36
def format_plan_display() -> str:
    """Render the agent's current plan as a plain-text banner.

    Returns an empty string when no plan exists. Output is deliberately
    uncolored so the full plan stays readable in any terminal or log.
    """
    from agent.tools.plan_tool import get_current_plan

    plan = get_current_plan()
    if not plan:
        return ""

    out = ["\n" + "=" * 60, "CURRENT PLAN", "=" * 60 + "\n"]

    # Render each non-empty status group as its own section.
    sections = (
        ("Completed:", "[x]", [t for t in plan if t["status"] == "completed"]),
        ("In Progress:", "[~]", [t for t in plan if t["status"] == "in_progress"]),
        ("Pending:", "[ ]", [t for t in plan if t["status"] == "pending"]),
    )
    for title, marker, todos in sections:
        if not todos:
            continue
        out.append(title)
        out.extend(f"  {marker} {todo['id']}. {todo['content']}" for todo in todos)
        out.append("")

    done, doing, waiting = (len(group) for _, _, group in sections)
    out.append(
        f"Total: {len(plan)} todos ({done} completed, {doing} in progress, {waiting} pending)"
    )
    out.append("=" * 60 + "\n")

    return "\n".join(out)
77
+
78
+
79
def format_error(message: str) -> str:
    """Render *message* as a red ERROR line."""
    return Colors.RED + f"ERROR: {message}" + Colors.RESET
82
+
83
+
84
def format_success(message: str, emoji: str = "") -> str:
    """Render *message* in green, optionally prefixed with an emoji."""
    if emoji:
        return f"{Colors.GREEN}{emoji} {message}{Colors.RESET}"
    return f"{Colors.GREEN}{message}{Colors.RESET}"
88
+
89
+
90
def format_tool_call(tool_name: str, arguments: str) -> str:
    """Render the "calling tool X with arguments Y" line in yellow, name bolded."""
    bold_name = f"{Colors.BOLD}{tool_name}{Colors.RESET}{Colors.YELLOW}"
    return f"{Colors.YELLOW}Calling tool: {bold_name} with arguments: {arguments}{Colors.RESET}"
93
+
94
+
95
def format_tool_output(output: str, success: bool, truncate: bool = True) -> str:
    """Format tool output with a colored header and optional truncation.

    Args:
        output: Raw tool output text.
        success: Whether the tool call succeeded (yellow vs red header).
        truncate: When True, limit the displayed output to 6 lines.

    Returns:
        The colored header line followed by the (possibly truncated) output.
    """
    # Measure before truncating so the count reflects the full output.
    original_length = len(output)
    if truncate:
        output = truncate_to_lines(output, max_lines=6)

    # Fix: the two branches previously disagreed ("tkns" vs "tokens") and
    # both mislabelled a character count (len(output)) as tokens; use a
    # single accurate label for both paths.
    color = Colors.YELLOW if success else Colors.RED
    return f"{color}Tool output ({original_length} chars): {Colors.RESET}\n{output}"
109
+
110
+
111
def format_turn_complete() -> str:
    """Render the end-of-turn banner: bold green with a hugging-face emoji."""
    return Colors.GREEN + Colors.BOLD + "\U0001f917 Turn complete" + Colors.RESET + "\n"
114
+
115
+
116
def format_separator(char: str = "=", length: int = 60) -> str:
    """Return a horizontal rule made of *length* copies of *char*."""
    return length * char
119
+
120
+
121
def format_plan_tool_output(todos: list) -> str:
    """Render the plan tool's result as plain text (no ANSI colors, full visibility)."""
    if not todos:
        return "Plan is empty."

    out = ["Plan updated successfully", ""]

    # One section per non-empty status group, each with its own checkbox marker.
    sections = (
        ("Completed:", "[x]", [t for t in todos if t["status"] == "completed"]),
        ("In Progress:", "[~]", [t for t in todos if t["status"] == "in_progress"]),
        ("Pending:", "[ ]", [t for t in todos if t["status"] == "pending"]),
    )
    for title, marker, group in sections:
        if not group:
            continue
        out.append(title)
        out.extend(f"  {marker} {todo['id']}. {todo['content']}" for todo in group)
        out.append("")

    done, doing, waiting = (len(group) for _, _, group in sections)
    out.append(
        f"Total: {len(todos)} todos ({done} completed, {doing} in progress, {waiting} pending)"
    )
    return "\n".join(out)
configs/main_agent_config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "anthropic/claude-opus-4-5-20251101",
3
+ "save_sessions": true,
4
+ "session_dataset_repo": "akseljoonas/hf-agent-sessions",
5
+ "yolo_mode": false,
6
+ "confirm_cpu_jobs": false,
7
+ "auto_file_upload": true,
8
+ "mcpServers": {
9
+ "hf-mcp-server": {
10
+ "transport": "http",
11
+ "url": "https://huggingface.co/mcp?login",
12
+ "headers": {
13
+ "Authorization": "Bearer ${HF_TOKEN}"
14
+ }
15
+ }
16
+ }
17
+ }
dependencies.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication dependencies for FastAPI routes.
2
+
3
+ Provides auth validation for both REST and WebSocket endpoints.
4
+ - In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user.
5
+ - In production: validates Bearer tokens or cookies against HF OAuth.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ from typing import Any
12
+
13
+ import httpx
14
+ from fastapi import HTTPException, Request, WebSocket, status
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
19
+ AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", ""))
20
+
21
+ # Simple in-memory token cache: token -> (user_info, expiry_time)
22
+ _token_cache: dict[str, tuple[dict[str, Any], float]] = {}
23
+ TOKEN_CACHE_TTL = 300 # 5 minutes
24
+
25
+ DEV_USER: dict[str, Any] = {
26
+ "user_id": "dev",
27
+ "username": "dev",
28
+ "authenticated": True,
29
+ }
30
+
31
+
32
async def _validate_token(token: str) -> dict[str, Any] | None:
    """Validate a token against the HF OAuth userinfo endpoint.

    Results are cached for TOKEN_CACHE_TTL seconds to avoid excessive API
    calls. Returns the userinfo dict on success, None on any failure.
    """
    now = time.time()

    # Fix: previously only the queried token was ever evicted, so entries
    # for tokens that are never queried again accumulated forever. Prune
    # every expired entry up front (mirrors _cleanup_expired_states in
    # routes/auth.py) so the cache cannot grow without bound.
    for stale in [t for t, (_, expiry) in _token_cache.items() if now >= expiry]:
        del _token_cache[stale]

    # Cache hit (anything left in the cache is unexpired after the prune).
    cached = _token_cache.get(token)
    if cached is not None:
        return cached[0]

    # Validate against HF
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            response = await client.get(
                f"{OPENID_PROVIDER_URL}/oauth/userinfo",
                headers={"Authorization": f"Bearer {token}"},
            )
            if response.status_code != 200:
                logger.debug("Token validation failed: status %d", response.status_code)
                return None
            user_info = response.json()
            _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL)
            return user_info
        except httpx.HTTPError as e:
            logger.warning("Token validation error: %s", e)
            return None
62
+
63
+
64
+ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
65
+ """Build a normalized user dict from HF userinfo response."""
66
+ return {
67
+ "user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")),
68
+ "username": user_info.get("preferred_username", "unknown"),
69
+ "name": user_info.get("name"),
70
+ "picture": user_info.get("picture"),
71
+ "authenticated": True,
72
+ }
73
+
74
+
75
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
    """Validate a token and return a normalized user dict, or None on failure."""
    info = await _validate_token(token)
    return _user_from_info(info) if info else None
81
+
82
+
83
async def get_current_user(request: Request) -> dict[str, Any]:
    """FastAPI dependency: resolve the authenticated user for a request.

    Token sources, tried in priority order:
      1. ``Authorization: Bearer <token>`` header
      2. ``hf_access_token`` cookie

    In dev mode (AUTH_ENABLED=False) a default dev user is returned.
    Raises HTTP 401 when no candidate token validates.
    """
    if not AUTH_ENABLED:
        return DEV_USER

    # Gather candidate tokens in priority order, then try each in turn.
    candidates: list[str] = []
    auth_header = request.headers.get("Authorization", "")
    if auth_header.startswith("Bearer "):
        candidates.append(auth_header[7:])
    cookie_token = request.cookies.get("hf_access_token")
    if cookie_token:
        candidates.append(cookie_token)

    for token in candidates:
        user = await _extract_user_from_token(token)
        if user:
            return user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated. Please log in via /auth/login.",
        headers={"WWW-Authenticate": "Bearer"},
    )
115
+
116
+
117
async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None:
    """Extract and validate the user for a WebSocket connection.

    Browsers cannot attach custom headers to WebSocket upgrades, so the
    token is looked for in (priority order):
      1. ``?token=`` query parameter
      2. ``hf_access_token`` cookie (sent automatically for same-origin)

    Returns the user dict, or None when not authenticated.
    In dev mode the default dev user is returned.
    """
    if not AUTH_ENABLED:
        return DEV_USER

    candidates: list[str] = []
    query_token = websocket.query_params.get("token")
    if query_token:
        candidates.append(query_token)
    cookie_token = websocket.cookies.get("hf_access_token")
    if cookie_token:
        candidates.append(cookie_token)

    for token in candidates:
        user = await _extract_user_from_token(token)
        if user:
            return user

    return None
main.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI application for HF Agent web interface - API ONLY MODE.
2
+
3
+ This backend runs in API-only mode without serving static files.
4
+ The frontend is hosted separately and communicates via HTTP/WebSocket.
5
+ """
6
+
7
+ import logging
8
+ import os
9
+ from contextlib import asynccontextmanager
10
+
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv()
14
+
15
+ # Ensure HF_TOKEN is set — fall back to HF_ADMIN_TOKEN if available (HF Spaces)
16
+ if not os.environ.get("HF_TOKEN") and os.environ.get("HF_ADMIN_TOKEN"):
17
+ os.environ["HF_TOKEN"] = os.environ.get("HF_ADMIN_TOKEN")
18
+
19
+ from fastapi import FastAPI
20
+ from fastapi.middleware.cors import CORSMiddleware
21
+
22
+ from routes.agent import router as agent_router
23
+ from routes.auth import router as auth_router
24
+
25
+ # Configure logging
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
29
+ )
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ @asynccontextmanager
34
+ async def lifespan(app: FastAPI):
35
+ """Application lifespan handler."""
36
+ logger.info("Starting HF Agent backend (API-only mode)...")
37
+ logger.info(f"CORS allowed origins: {os.environ.get('CORS_ORIGINS', '*')}")
38
+ yield
39
+ logger.info("Shutting down HF Agent backend...")
40
+
41
+
42
+ app = FastAPI(
43
+ title="HF Agent API",
44
+ description="ML Engineering Assistant API - Separate Frontend/Backend Mode",
45
+ version="1.0.0",
46
+ lifespan=lifespan,
47
+ )
48
+
49
+ # CORS middleware - allow all origins for separate hosting
50
+ # In production, set CORS_ORIGINS env var to your frontend URL(s)
51
+ cors_origins = os.environ.get("CORS_ORIGINS", "*")
52
+ if cors_origins != "*":
53
+ cors_origins = [origin.strip() for origin in cors_origins.split(",")]
54
+
55
+ app.add_middleware(
56
+ CORSMiddleware,
57
+ allow_origins=cors_origins if isinstance(cors_origins, list) else ["*"],
58
+ allow_credentials=True,
59
+ allow_methods=["*"],
60
+ allow_headers=["*"],
61
+ )
62
+
63
+ # Include routers
64
+ app.include_router(agent_router)
65
+ app.include_router(auth_router)
66
+
67
+
68
+ @app.get("/api")
69
+ async def api_root():
70
+ """API root endpoint."""
71
+ return {
72
+ "name": "HF Agent API",
73
+ "version": "1.0.0",
74
+ "mode": "api-only",
75
+ "docs": "/docs",
76
+ }
77
+
78
+
79
+ @app.get("/")
80
+ async def root():
81
+ """Root endpoint - indicates API-only mode."""
82
+ return {
83
+ "status": "ok",
84
+ "mode": "api-only",
85
+ "message": "Backend is running in API-only mode. Frontend is hosted separately.",
86
+ "api_docs": "/docs",
87
+ "api_endpoints": "/api",
88
+ }
89
+
90
+
91
+ if __name__ == "__main__":
92
+ import uvicorn
93
+
94
+ port = int(os.environ.get("PORT", 7860))
95
+ host = os.environ.get("HOST", "0.0.0.0")
96
+ uvicorn.run(app, host=host, port=port)
models.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic models for API requests and responses."""
2
+
3
+ from enum import Enum
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel
7
+
8
+
9
class OpType(str, Enum):
    """Operation types matching agent/core/agent_loop.py."""

    USER_INPUT = "user_input"        # new message from the user
    EXEC_APPROVAL = "exec_approval"  # approve/reject pending tool calls
    INTERRUPT = "interrupt"          # cancel the in-flight operation
    UNDO = "undo"                    # undo the last turn
    COMPACT = "compact"              # compact the conversation context
    SHUTDOWN = "shutdown"            # shut the session down
18
+
19
+
20
+ class Operation(BaseModel):
21
+ """Operation to be submitted to the agent."""
22
+
23
+ op_type: OpType
24
+ data: dict[str, Any] | None = None
25
+
26
+
27
+ class Submission(BaseModel):
28
+ """Submission wrapper with ID and operation."""
29
+
30
+ id: str
31
+ operation: Operation
32
+
33
+
34
+ class ToolApproval(BaseModel):
35
+ """Approval decision for a single tool call."""
36
+
37
+ tool_call_id: str
38
+ approved: bool
39
+ feedback: str | None = None
40
+
41
+
42
+ class ApprovalRequest(BaseModel):
43
+ """Request to approve/reject tool calls."""
44
+
45
+ session_id: str
46
+ approvals: list[ToolApproval]
47
+
48
+
49
+ class SubmitRequest(BaseModel):
50
+ """Request to submit user input."""
51
+
52
+ session_id: str
53
+ text: str
54
+
55
+
56
+ class SessionResponse(BaseModel):
57
+ """Response when creating a new session."""
58
+
59
+ session_id: str
60
+ ready: bool = True
61
+
62
+
63
+ class SessionInfo(BaseModel):
64
+ """Session metadata."""
65
+
66
+ session_id: str
67
+ created_at: str
68
+ is_active: bool
69
+ message_count: int
70
+ user_id: str = "dev"
71
+
72
+
73
+ class HealthResponse(BaseModel):
74
+ """Health check response."""
75
+
76
+ status: str = "ok"
77
+ active_sessions: int = 0
78
+ max_sessions: int = 0
79
+
80
+
81
+ class LLMHealthResponse(BaseModel):
82
+ """LLM provider health check response."""
83
+
84
+ status: str # "ok" | "error"
85
+ model: str
86
+ error: str | None = None
87
+ error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown"
pyproject.toml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "hf-agent"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "datasets>=4.4.1",
9
+ # Core dependencies (always required)
10
+ "pydantic>=2.12.3",
11
+ "python-dotenv>=1.2.1",
12
+ ]
13
+
14
+ [project.optional-dependencies]
15
+ # Agent runtime dependencies
16
+ agent = [
17
+ "requests>=2.32.5",
18
+ "litellm>=1.0.0",
19
+ "huggingface-hub>=1.0.1",
20
+ "fastmcp>=2.4.0",
21
+ "lmnr>=0.7.23", # Note: Using base package to avoid torch/transformers from [all] extra
22
+ "prompt-toolkit>=3.0.0",
23
+ "thefuzz>=0.22.1",
24
+ "nbconvert>=7.16.6",
25
+ "nbformat>=5.10.4",
26
+ "datasets>=4.3.0", # For session logging to HF datasets
27
+ "whoosh>=2.7.4",
28
+ # Web backend dependencies
29
+ "fastapi>=0.115.0",
30
+ "uvicorn[standard]>=0.32.0",
31
+ "httpx>=0.27.0",
32
+ "websockets>=13.0",
33
+ ]
34
+
35
+ # Evaluation/benchmarking dependencies
36
+ eval = [
37
+ "inspect-ai>=0.3.149",
38
+ "pandas>=2.3.3",
39
+ "datasets>=4.3.0",
40
+ "tenacity>=8.0.0",
41
+ ]
42
+
43
+ # Development and testing dependencies
44
+ dev = [
45
+ "pytest>=9.0.2",
46
+ ]
47
+
48
+ # All dependencies (agent + eval + dev)
49
+ all = [
50
+ "hf-agent[agent,eval,dev]",
51
+ ]
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF Agent Backend - Requirements
2
+ # Python 3.12+
3
+
4
+ # Core dependencies
5
+ pydantic>=2.12.3
6
+ python-dotenv>=1.2.1
7
+
8
+ # Agent runtime dependencies
9
+ requests>=2.32.5
10
+ litellm>=1.0.0
11
+ huggingface-hub>=1.0.1
12
+ fastmcp>=2.4.0
13
+ lmnr>=0.7.23
14
+ prompt-toolkit>=3.0.0
15
+ thefuzz>=0.22.1
16
+ nbconvert>=7.16.6
17
+ nbformat>=5.10.4
18
+ datasets>=4.3.0
19
+ whoosh>=2.7.4
20
+
21
+ # Web backend dependencies
22
+ fastapi>=0.115.0
23
+ uvicorn[standard]>=0.32.0
24
+ httpx>=0.27.0
25
+ websockets>=13.0
routes/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Routes package
routes/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (186 Bytes). View file
 
routes/__pycache__/agent.cpython-313.pyc ADDED
Binary file (18.7 kB). View file
 
routes/agent.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent API routes - WebSocket and REST endpoints.
2
+
3
+ All routes (except /health) require authentication via the get_current_user
4
+ dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically.
5
+ """
6
+
7
+ import logging
8
+ import os
9
+ from typing import Any
10
+
11
+ from dependencies import get_current_user, get_ws_user
12
+ from fastapi import (
13
+ APIRouter,
14
+ Depends,
15
+ HTTPException,
16
+ Request,
17
+ WebSocket,
18
+ WebSocketDisconnect,
19
+ )
20
+ from litellm import acompletion
21
+ from models import (
22
+ ApprovalRequest,
23
+ HealthResponse,
24
+ LLMHealthResponse,
25
+ SessionInfo,
26
+ SessionResponse,
27
+ SubmitRequest,
28
+ )
29
+ from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager
30
+ from websocket import manager as ws_manager
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ router = APIRouter(prefix="/api", tags=["agent"])
35
+
36
+ AVAILABLE_MODELS = [
37
+ {
38
+ "id": "huggingface/novita/MiniMaxAI/MiniMax-M2.1",
39
+ "label": "MiniMax M2.1",
40
+ "provider": "huggingface",
41
+ "recommended": True,
42
+ },
43
+ {
44
+ "id": "anthropic/claude-opus-4-5-20251101",
45
+ "label": "Claude Opus 4.5",
46
+ "provider": "anthropic",
47
+ "recommended": True,
48
+ },
49
+ {
50
+ "id": "huggingface/novita/moonshotai/Kimi-K2.5",
51
+ "label": "Kimi K2.5",
52
+ "provider": "huggingface",
53
+ },
54
+ {
55
+ "id": "huggingface/novita/zai-org/GLM-5",
56
+ "label": "GLM 5",
57
+ "provider": "huggingface",
58
+ },
59
+ ]
60
+
61
+
62
def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
    """Ensure *user* may access *session_id*; raise 404 if missing, 403 if not owner."""
    if not session_manager.get_session_info(session_id):
        raise HTTPException(status_code=404, detail="Session not found")
    if not session_manager.verify_session_access(session_id, user["user_id"]):
        raise HTTPException(status_code=403, detail="Access denied to this session")
69
+
70
+
71
+ @router.get("/health", response_model=HealthResponse)
72
+ async def health_check() -> HealthResponse:
73
+ """Health check endpoint."""
74
+ return HealthResponse(
75
+ status="ok",
76
+ active_sessions=session_manager.active_session_count,
77
+ max_sessions=MAX_SESSIONS,
78
+ )
79
+
80
+
81
def _classify_llm_error(message: str) -> str:
    """Map a lowercased LLM provider error message to a coarse category.

    Categories: "auth", "credits", "rate_limit", "network", "unknown".
    Matching is substring-based and order-sensitive: auth markers are
    checked before billing, billing before rate limiting, etc.
    """
    if (
        "401" in message
        or "auth" in message
        or "invalid" in message
        or "api key" in message
    ):
        return "auth"
    if (
        "402" in message
        or "credit" in message
        or "quota" in message
        or "insufficient" in message
        or "billing" in message
    ):
        return "credits"
    if "429" in message or "rate" in message:
        return "rate_limit"
    if "timeout" in message or "connect" in message or "network" in message:
        return "network"
    return "unknown"


@router.get("/health/llm", response_model=LLMHealthResponse)
async def llm_health_check() -> LLMHealthResponse:
    """Check if the LLM provider is reachable and the API key is valid.

    Makes a minimal 1-token completion call. Catches common errors:
    - 401 → invalid API key
    - 402/insufficient_quota → out of credits
    - 429 → rate limited
    - timeout / network → provider unreachable
    """
    model = session_manager.config.model_name
    hf_key = os.environ.get("INFERENCE_TOKEN")
    try:
        await acompletion(
            model=model,
            messages=[{"role": "user", "content": "hi"}],
            max_tokens=1,
            timeout=10,
            # The INFERENCE_TOKEN key only applies to HF-routed models.
            api_key=hf_key if hf_key and model.startswith("huggingface/") else None,
        )
        return LLMHealthResponse(status="ok", model=model)
    except Exception as e:
        error_type = _classify_llm_error(str(e).lower())
        logger.warning(f"LLM health check failed ({error_type}): {e}")
        return LLMHealthResponse(
            status="error",
            model=model,
            error=str(e)[:500],  # cap detail length to keep the payload small
            error_type=error_type,
        )
133
+
134
+
135
+ @router.get("/config/model")
136
+ async def get_model() -> dict:
137
+ """Get current model and available models. No auth required."""
138
+ return {
139
+ "current": session_manager.config.model_name,
140
+ "available": AVAILABLE_MODELS,
141
+ }
142
+
143
+
144
+ @router.post("/config/model")
145
+ async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict:
146
+ """Set the LLM model. Applies to new conversations."""
147
+ model_id = body.get("model")
148
+ if not model_id:
149
+ raise HTTPException(status_code=400, detail="Missing 'model' field")
150
+ valid_ids = {m["id"] for m in AVAILABLE_MODELS}
151
+ if model_id not in valid_ids:
152
+ raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")
153
+ session_manager.config.model_name = model_id
154
+ logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}")
155
+ return {"model": model_id}
156
+
157
+
158
+ @router.post("/title")
159
+ async def generate_title(
160
+ request: SubmitRequest, user: dict = Depends(get_current_user)
161
+ ) -> dict:
162
+ """Generate a short title for a chat session based on the first user message."""
163
+ model = session_manager.config.model_name
164
+ hf_key = os.environ.get("INFERENCE_TOKEN")
165
+ try:
166
+ response = await acompletion(
167
+ model=model,
168
+ messages=[
169
+ {
170
+ "role": "system",
171
+ "content": (
172
+ "Generate a very short title (max 6 words) for a chat conversation "
173
+ "that starts with the following user message. "
174
+ "Reply with ONLY the title, no quotes, no punctuation at the end."
175
+ ),
176
+ },
177
+ {"role": "user", "content": request.text[:500]},
178
+ ],
179
+ max_tokens=20,
180
+ temperature=0.3,
181
+ timeout=8,
182
+ api_key=hf_key if hf_key and model.startswith("huggingface/") else None,
183
+ )
184
+ title = response.choices[0].message.content.strip().strip('"').strip("'")
185
+ # Safety: cap at 50 chars
186
+ if len(title) > 50:
187
+ title = title[:50].rstrip() + "…"
188
+ return {"title": title}
189
+ except Exception as e:
190
+ logger.warning(f"Title generation failed: {e}")
191
+ # Fallback: truncate the message
192
+ fallback = request.text.strip()
193
+ title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback
194
+ return {"title": title}
195
+
196
+
197
+ @router.post("/session", response_model=SessionResponse)
198
+ async def create_session(
199
+ request: Request, user: dict = Depends(get_current_user)
200
+ ) -> SessionResponse:
201
+ """Create a new agent session bound to the authenticated user.
202
+
203
+ The user's HF access token is extracted from the Authorization header
204
+ and stored in the session so that tools (e.g. hf_jobs) can act on
205
+ behalf of the user.
206
+
207
+ Returns 503 if the server or user has reached the session limit.
208
+ """
209
+ # Extract the user's HF token (Bearer header or HttpOnly cookie)
210
+ hf_token = None
211
+ auth_header = request.headers.get("Authorization", "")
212
+ if auth_header.startswith("Bearer "):
213
+ hf_token = auth_header[7:]
214
+ if not hf_token:
215
+ hf_token = request.cookies.get("hf_access_token")
216
+
217
+ try:
218
+ session_id = await session_manager.create_session(
219
+ user_id=user["user_id"], hf_token=hf_token
220
+ )
221
+ except SessionCapacityError as e:
222
+ raise HTTPException(status_code=503, detail=str(e))
223
+ return SessionResponse(session_id=session_id, ready=True)
224
+
225
+
226
+ @router.get("/session/{session_id}", response_model=SessionInfo)
227
+ async def get_session(
228
+ session_id: str, user: dict = Depends(get_current_user)
229
+ ) -> SessionInfo:
230
+ """Get session information. Only accessible by the session owner."""
231
+ _check_session_access(session_id, user)
232
+ info = session_manager.get_session_info(session_id)
233
+ return SessionInfo(**info)
234
+
235
+
236
+ @router.get("/sessions", response_model=list[SessionInfo])
237
+ async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
238
+ """List sessions belonging to the authenticated user."""
239
+ sessions = session_manager.list_sessions(user_id=user["user_id"])
240
+ return [SessionInfo(**s) for s in sessions]
241
+
242
+
243
+ @router.delete("/session/{session_id}")
244
+ async def delete_session(
245
+ session_id: str, user: dict = Depends(get_current_user)
246
+ ) -> dict:
247
+ """Delete a session. Only accessible by the session owner."""
248
+ _check_session_access(session_id, user)
249
+ success = await session_manager.delete_session(session_id)
250
+ if not success:
251
+ raise HTTPException(status_code=404, detail="Session not found")
252
+ return {"status": "deleted", "session_id": session_id}
253
+
254
+
255
+ @router.post("/submit")
256
+ async def submit_input(
257
+ request: SubmitRequest, user: dict = Depends(get_current_user)
258
+ ) -> dict:
259
+ """Submit user input to a session. Only accessible by the session owner."""
260
+ _check_session_access(request.session_id, user)
261
+ success = await session_manager.submit_user_input(request.session_id, request.text)
262
+ if not success:
263
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
264
+ return {"status": "submitted", "session_id": request.session_id}
265
+
266
+
267
+ @router.post("/approve")
268
+ async def submit_approval(
269
+ request: ApprovalRequest, user: dict = Depends(get_current_user)
270
+ ) -> dict:
271
+ """Submit tool approvals to a session. Only accessible by the session owner."""
272
+ _check_session_access(request.session_id, user)
273
+ approvals = [
274
+ {
275
+ "tool_call_id": a.tool_call_id,
276
+ "approved": a.approved,
277
+ "feedback": a.feedback,
278
+ }
279
+ for a in request.approvals
280
+ ]
281
+ success = await session_manager.submit_approval(request.session_id, approvals)
282
+ if not success:
283
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
284
+ return {"status": "submitted", "session_id": request.session_id}
285
+
286
+
287
+ @router.post("/interrupt/{session_id}")
288
+ async def interrupt_session(
289
+ session_id: str, user: dict = Depends(get_current_user)
290
+ ) -> dict:
291
+ """Interrupt the current operation in a session."""
292
+ _check_session_access(session_id, user)
293
+ success = await session_manager.interrupt(session_id)
294
+ if not success:
295
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
296
+ return {"status": "interrupted", "session_id": session_id}
297
+
298
+
299
+ @router.post("/undo/{session_id}")
300
+ async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict:
301
+ """Undo the last turn in a session."""
302
+ _check_session_access(session_id, user)
303
+ success = await session_manager.undo(session_id)
304
+ if not success:
305
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
306
+ return {"status": "undo_requested", "session_id": session_id}
307
+
308
+
309
+ @router.post("/compact/{session_id}")
310
+ async def compact_session(
311
+ session_id: str, user: dict = Depends(get_current_user)
312
+ ) -> dict:
313
+ """Compact the context in a session."""
314
+ _check_session_access(session_id, user)
315
+ success = await session_manager.compact(session_id)
316
+ if not success:
317
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
318
+ return {"status": "compact_requested", "session_id": session_id}
319
+
320
+
321
+ @router.post("/shutdown/{session_id}")
322
+ async def shutdown_session(
323
+ session_id: str, user: dict = Depends(get_current_user)
324
+ ) -> dict:
325
+ """Shutdown a session."""
326
+ _check_session_access(session_id, user)
327
+ success = await session_manager.shutdown_session(session_id)
328
+ if not success:
329
+ raise HTTPException(status_code=404, detail="Session not found or inactive")
330
+ return {"status": "shutdown_requested", "session_id": session_id}
331
+
332
+
333
+ @router.websocket("/ws/{session_id}")
334
+ async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None:
335
+ """WebSocket endpoint for real-time events.
336
+
337
+ Authentication is done via:
338
+ - ?token= query parameter (for browsers that can't send WS headers)
339
+ - Cookie (automatic for same-origin connections)
340
+ - Dev mode bypass (when OAUTH_CLIENT_ID is not set)
341
+
342
+ NOTE: We must accept() before close() so the browser receives our custom
343
+ close codes (4001, 4003, 4004). If we close() before accept(), Starlette
344
+ sends HTTP 403 and the browser only sees code 1006 (abnormal closure).
345
+ """
346
+ logger.info(f"WebSocket connection request for session {session_id}")
347
+
348
+ # Authenticate the WebSocket connection
349
+ user = await get_ws_user(websocket)
350
+ if not user:
351
+ logger.warning(
352
+ f"WebSocket rejected: authentication failed for session {session_id}"
353
+ )
354
+ await websocket.accept()
355
+ await websocket.close(code=4001, reason="Authentication required")
356
+ return
357
+
358
+ # Verify session exists
359
+ info = session_manager.get_session_info(session_id)
360
+ if not info:
361
+ logger.warning(f"WebSocket rejected: session {session_id} not found")
362
+ await websocket.accept()
363
+ await websocket.close(code=4004, reason="Session not found")
364
+ return
365
+
366
+ # Verify user owns the session
367
+ if not session_manager.verify_session_access(session_id, user["user_id"]):
368
+ logger.warning(
369
+ f"WebSocket rejected: user {user['user_id']} denied access to session {session_id}"
370
+ )
371
+ await websocket.accept()
372
+ await websocket.close(code=4003, reason="Access denied")
373
+ return
374
+
375
+ await ws_manager.connect(websocket, session_id)
376
+
377
+ # Send "ready" immediately on WebSocket connection so the frontend
378
+ # knows the session is alive. The original ready event from _run_session
379
+ # fires before the WS is connected and is always lost.
380
+ try:
381
+ await websocket.send_json(
382
+ {
383
+ "event_type": "ready",
384
+ "data": {"message": "Agent initialized"},
385
+ }
386
+ )
387
+ except Exception as e:
388
+ logger.error(f"Failed to send ready event for session {session_id}: {e}")
389
+
390
+ try:
391
+ while True:
392
+ # Keep connection alive, handle ping/pong
393
+ data = await websocket.receive_json()
394
+
395
+ # Handle client messages (e.g., ping)
396
+ if data.get("type") == "ping":
397
+ await websocket.send_json({"type": "pong"})
398
+
399
+ except WebSocketDisconnect:
400
+ logger.info(f"WebSocket disconnected for session {session_id}")
401
+ except Exception as e:
402
+ logger.error(f"WebSocket error for session {session_id}: {e}")
403
+ finally:
404
+ ws_manager.disconnect(session_id)
routes/auth.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Authentication routes for HF OAuth.
2
+
3
+ Handles the OAuth 2.0 authorization code flow with HF as provider.
4
+ After successful auth, sets an HttpOnly cookie with the access token.
5
+ """
6
+
7
+ import os
8
+ import secrets
9
+ import time
10
+ from urllib.parse import urlencode
11
+
12
+ import httpx
13
+ from dependencies import AUTH_ENABLED, get_current_user
14
+ from fastapi import APIRouter, Depends, HTTPException, Request
15
+ from fastapi.responses import RedirectResponse
16
+
17
+ router = APIRouter(prefix="/auth", tags=["auth"])
18
+
19
+ # OAuth configuration from environment
20
+ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "")
21
+ OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "")
22
+ OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
23
+
24
+ # In-memory OAuth state store with expiry (5 min TTL)
25
+ _OAUTH_STATE_TTL = 300
26
+ oauth_states: dict[str, dict] = {}
27
+
28
+
29
def _cleanup_expired_states() -> None:
    """Drop OAuth states whose TTL has elapsed so the store cannot grow forever."""
    cutoff = time.time()
    stale = [key for key, entry in oauth_states.items() if cutoff > entry.get("expires_at", 0)]
    for key in stale:
        oauth_states.pop(key, None)
35
+
36
+
37
def get_redirect_uri(request: Request) -> str:
    """Return the absolute OAuth callback URI for this deployment.

    On HF Spaces the public hostname comes from the SPACE_HOST env var;
    otherwise the URI is derived from the incoming request's route table.
    """
    space_host = os.environ.get("SPACE_HOST")
    if space_host:
        return f"https://{space_host}/auth/callback"
    return str(request.url_for("oauth_callback"))
45
+
46
+
47
+ @router.get("/login")
48
+ async def oauth_login(request: Request) -> RedirectResponse:
49
+ """Initiate OAuth login flow."""
50
+ if not OAUTH_CLIENT_ID:
51
+ raise HTTPException(
52
+ status_code=500,
53
+ detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.",
54
+ )
55
+
56
+ # Clean up expired states to prevent memory growth
57
+ _cleanup_expired_states()
58
+
59
+ # Generate state for CSRF protection
60
+ state = secrets.token_urlsafe(32)
61
+ oauth_states[state] = {
62
+ "redirect_uri": get_redirect_uri(request),
63
+ "expires_at": time.time() + _OAUTH_STATE_TTL,
64
+ }
65
+
66
+ # Build authorization URL
67
+ params = {
68
+ "client_id": OAUTH_CLIENT_ID,
69
+ "redirect_uri": get_redirect_uri(request),
70
+ "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions",
71
+ "response_type": "code",
72
+ "state": state,
73
+ "orgIds": os.environ.get(
74
+ "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1"
75
+ ), # ml-agent-explorers
76
+ }
77
+ auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}"
78
+
79
+ return RedirectResponse(url=auth_url)
80
+
81
+
82
+ @router.get("/callback")
83
+ async def oauth_callback(
84
+ request: Request, code: str = "", state: str = ""
85
+ ) -> RedirectResponse:
86
+ """Handle OAuth callback."""
87
+ # Verify state
88
+ if state not in oauth_states:
89
+ raise HTTPException(status_code=400, detail="Invalid state parameter")
90
+
91
+ stored_state = oauth_states.pop(state)
92
+ redirect_uri = stored_state["redirect_uri"]
93
+
94
+ if not code:
95
+ raise HTTPException(status_code=400, detail="No authorization code provided")
96
+
97
+ # Exchange code for token
98
+ token_url = f"{OPENID_PROVIDER_URL}/oauth/token"
99
+ async with httpx.AsyncClient() as client:
100
+ try:
101
+ response = await client.post(
102
+ token_url,
103
+ data={
104
+ "grant_type": "authorization_code",
105
+ "code": code,
106
+ "redirect_uri": redirect_uri,
107
+ "client_id": OAUTH_CLIENT_ID,
108
+ "client_secret": OAUTH_CLIENT_SECRET,
109
+ },
110
+ )
111
+ response.raise_for_status()
112
+ token_data = response.json()
113
+ except httpx.HTTPError as e:
114
+ raise HTTPException(status_code=500, detail=f"Token exchange failed: {e}")
115
+
116
+ # Get user info
117
+ access_token = token_data.get("access_token")
118
+ if not access_token:
119
+ raise HTTPException(
120
+ status_code=500,
121
+ detail="Token exchange succeeded but no access_token was returned.",
122
+ )
123
+
124
+ # Fetch user info (optional — failure is not fatal)
125
+ async with httpx.AsyncClient() as client:
126
+ try:
127
+ userinfo_response = await client.get(
128
+ f"{OPENID_PROVIDER_URL}/oauth/userinfo",
129
+ headers={"Authorization": f"Bearer {access_token}"},
130
+ )
131
+ userinfo_response.raise_for_status()
132
+ except httpx.HTTPError:
133
+ pass # user_info not required for auth flow
134
+
135
+ # Set access token as HttpOnly cookie (not in URL — avoids leaks via
136
+ # Referrer headers, browser history, and server logs)
137
+ is_production = bool(os.environ.get("SPACE_HOST"))
138
+ response = RedirectResponse(url="/", status_code=302)
139
+ response.set_cookie(
140
+ key="hf_access_token",
141
+ value=access_token,
142
+ httponly=True,
143
+ secure=is_production, # Secure flag only in production (HTTPS)
144
+ samesite="lax",
145
+ max_age=3600 * 24, # 24 hours
146
+ path="/",
147
+ )
148
+ return response
149
+
150
+
151
+ @router.get("/logout")
152
+ async def logout() -> RedirectResponse:
153
+ """Log out the user by clearing the auth cookie."""
154
+ response = RedirectResponse(url="/")
155
+ response.delete_cookie(key="hf_access_token", path="/")
156
+ return response
157
+
158
+
159
+ @router.get("/status")
160
+ async def auth_status() -> dict:
161
+ """Check if OAuth is enabled on this instance."""
162
+ return {"auth_enabled": AUTH_ENABLED}
163
+
164
+
165
+ @router.get("/me")
166
+ async def get_me(user: dict = Depends(get_current_user)) -> dict:
167
+ """Get current user info. Returns the authenticated user or dev user.
168
+
169
+ Uses the shared auth dependency which handles cookie + Bearer token.
170
+ """
171
+ return user
session_manager.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session manager for handling multiple concurrent agent sessions."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import uuid
6
+ from dataclasses import dataclass, field
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any, Optional
10
+
11
+ from websocket import manager as ws_manager
12
+
13
+ from agent.config import load_config
14
+ from agent.core.agent_loop import process_submission
15
+ from agent.core.session import Event, OpType, Session
16
+ from agent.core.tools import ToolRouter
17
+
18
+ # Get project root (parent of backend directory)
19
+ PROJECT_ROOT = Path(__file__).parent.parent
20
+ DEFAULT_CONFIG_PATH = str(PROJECT_ROOT / "configs" / "main_agent_config.json")
21
+
22
+
23
+ # These dataclasses match agent/main.py structure
24
@dataclass
class Operation:
    """A single operation for the agent loop to execute."""

    # Kind of operation (user input, approval, interrupt, ...).
    op_type: OpType
    # Optional payload whose shape depends on the operation type.
    data: Optional[dict[str, Any]] = None
30
+
31
+
32
@dataclass
class Submission:
    """An identified request handed to the agent loop."""

    # Unique submission id (e.g. "sub_1a2b3c4d").
    id: str
    # The operation to perform.
    operation: Operation
38
+
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
@dataclass
class AgentSession:
    """Bundle of everything owned by one live agent session."""

    session_id: str
    session: Session
    tool_router: ToolRouter
    submission_queue: asyncio.Queue
    # Owner of this session ("dev" when auth is disabled).
    user_id: str = "dev"
    # User's HF OAuth token, forwarded to tools that need it.
    hf_token: str | None = None
    # Background task running the agent loop for this session.
    task: asyncio.Task | None = None
    created_at: datetime = field(default_factory=datetime.utcnow)
    is_active: bool = True
56
+
57
+
58
class SessionCapacityError(Exception):
    """Raised when a new session cannot be created because a limit was hit.

    ``error_type`` distinguishes the server-wide cap ("global") from the
    per-user cap ("per_user").
    """

    def __init__(self, message: str, error_type: str = "global") -> None:
        super().__init__(message)
        self.error_type = error_type
64
+
65
+
66
+ # ── Capacity limits ─────────────────────────────────────────────────
67
+ # Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM).
68
+ # Each session uses ~10-20 MB (context, tools, queues, task).
69
+ MAX_SESSIONS: int = 50
70
+ MAX_SESSIONS_PER_USER: int = 10
71
+
72
+
73
class SessionManager:
    """Manages multiple concurrent agent sessions.

    Each session bundles a Session/ToolRouter pair, a submission queue, and a
    background asyncio task running the agent loop. ``self.sessions`` is only
    mutated while holding ``self._lock``.
    """

    def __init__(self, config_path: str | None = None) -> None:
        # Shared agent configuration; falls back to the repo default config.
        self.config = load_config(config_path or DEFAULT_CONFIG_PATH)
        self.sessions: dict[str, AgentSession] = {}
        self._lock = asyncio.Lock()

    def _count_user_sessions(self, user_id: str) -> int:
        """Count active sessions owned by a specific user."""
        return sum(
            1
            for s in self.sessions.values()
            if s.user_id == user_id and s.is_active
        )

    def _check_capacity(self, user_id: str) -> None:
        """Raise SessionCapacityError if global or per-user limits are reached.

        Caller must hold ``self._lock`` so the counts cannot change underneath.
        """
        active_count = self.active_session_count
        if active_count >= MAX_SESSIONS:
            raise SessionCapacityError(
                f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). "
                "Please try again later.",
                error_type="global",
            )
        if user_id != "dev":
            user_count = self._count_user_sessions(user_id)
            if user_count >= MAX_SESSIONS_PER_USER:
                raise SessionCapacityError(
                    f"You have reached the maximum of {MAX_SESSIONS_PER_USER} "
                    "concurrent sessions. Please close an existing session first.",
                    error_type="per_user",
                )

    async def create_session(
        self, user_id: str = "dev", hf_token: str | None = None
    ) -> str:
        """Create a new agent session and return its ID.

        Session() and ToolRouter() constructors contain blocking I/O
        (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are
        executed in a thread pool to avoid freezing the async event loop.

        Args:
            user_id: The ID of the user who owns this session.
            hf_token: The user's HF OAuth token, forwarded to tools.

        Raises:
            SessionCapacityError: If the server or user has reached the
                maximum number of concurrent sessions.
        """
        # Fast-fail capacity check before paying for the slow construction.
        async with self._lock:
            self._check_capacity(user_id)

        session_id = str(uuid.uuid4())

        # Per-session queues: submissions in, events out.
        submission_queue: asyncio.Queue = asyncio.Queue()
        event_queue: asyncio.Queue = asyncio.Queue()

        # Run blocking constructors in a thread to keep the event loop responsive.
        # Without this, Session.__init__ → ContextManager → litellm.get_max_tokens()
        # blocks all HTTP/WebSocket handling.
        import time as _time

        def _create_session_sync():
            t0 = _time.monotonic()
            tool_router = ToolRouter(self.config.mcpServers)
            session = Session(event_queue, config=self.config, tool_router=tool_router)
            logger.info(f"Session initialized in {_time.monotonic() - t0:.2f}s")
            return tool_router, session

        tool_router, session = await asyncio.to_thread(_create_session_sync)

        # Store user's HF token on the session so tools can use it.
        session.hf_token = hf_token

        agent_session = AgentSession(
            session_id=session_id,
            session=session,
            tool_router=tool_router,
            submission_queue=submission_queue,
            user_id=user_id,
            hf_token=hf_token,
        )

        async with self._lock:
            # Re-validate under the lock before inserting: other coroutines may
            # have created sessions while this one was being built, so the
            # earlier check alone would let the limits be exceeded (TOCTOU).
            self._check_capacity(user_id)
            self.sessions[session_id] = agent_session

        # Start the agent loop task.
        task = asyncio.create_task(
            self._run_session(session_id, submission_queue, event_queue, tool_router)
        )
        agent_session.task = task

        logger.info(f"Created session {session_id} for user {user_id}")
        return session_id

    async def _run_session(
        self,
        session_id: str,
        submission_queue: asyncio.Queue,
        event_queue: asyncio.Queue,
        tool_router: ToolRouter,
    ) -> None:
        """Run the agent loop for a session and forward events to WebSocket."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            logger.error(f"Session {session_id} not found")
            return

        session = agent_session.session

        # Start event forwarder task.
        event_forwarder = asyncio.create_task(
            self._forward_events(session_id, event_queue)
        )

        try:
            async with tool_router:
                # Signal readiness to any listener already attached.
                await session.send_event(
                    Event(event_type="ready", data={"message": "Agent initialized"})
                )

                while session.is_running:
                    try:
                        # Wait with a timeout so is_running is re-checked
                        # periodically even when no submissions arrive.
                        submission = await asyncio.wait_for(
                            submission_queue.get(), timeout=1.0
                        )
                        should_continue = await process_submission(session, submission)
                        if not should_continue:
                            break
                    except asyncio.TimeoutError:
                        continue
                    except asyncio.CancelledError:
                        logger.info(f"Session {session_id} cancelled")
                        break
                    except Exception as e:
                        logger.error(f"Error in session {session_id}: {e}")
                        await session.send_event(
                            Event(event_type="error", data={"error": str(e)})
                        )

        finally:
            event_forwarder.cancel()
            try:
                await event_forwarder
            except asyncio.CancelledError:
                pass

            # NOTE: this cleanup acquires self._lock — callers must never
            # await this task while holding the lock (see shutdown_session).
            async with self._lock:
                if session_id in self.sessions:
                    self.sessions[session_id].is_active = False

            logger.info(f"Session {session_id} ended")

    async def _forward_events(
        self, session_id: str, event_queue: asyncio.Queue
    ) -> None:
        """Forward events from the agent to the WebSocket."""
        while True:
            try:
                event: Event = await event_queue.get()
                await ws_manager.send_event(session_id, event.event_type, event.data)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error forwarding event for {session_id}: {e}")

    async def submit(self, session_id: str, operation: Operation) -> bool:
        """Submit an operation to a session. Returns False if it is gone/inactive."""
        async with self._lock:
            agent_session = self.sessions.get(session_id)

        if not agent_session or not agent_session.is_active:
            logger.warning(f"Session {session_id} not found or inactive")
            return False

        submission = Submission(id=f"sub_{uuid.uuid4().hex[:8]}", operation=operation)
        await agent_session.submission_queue.put(submission)
        return True

    async def submit_user_input(self, session_id: str, text: str) -> bool:
        """Submit user input to a session."""
        operation = Operation(op_type=OpType.USER_INPUT, data={"text": text})
        return await self.submit(session_id, operation)

    async def submit_approval(
        self, session_id: str, approvals: list[dict[str, Any]]
    ) -> bool:
        """Submit tool approvals to a session."""
        operation = Operation(
            op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}
        )
        return await self.submit(session_id, operation)

    async def interrupt(self, session_id: str) -> bool:
        """Interrupt a session."""
        return await self.submit(session_id, Operation(op_type=OpType.INTERRUPT))

    async def undo(self, session_id: str) -> bool:
        """Undo last turn in a session."""
        return await self.submit(session_id, Operation(op_type=OpType.UNDO))

    async def compact(self, session_id: str) -> bool:
        """Compact context in a session."""
        return await self.submit(session_id, Operation(op_type=OpType.COMPACT))

    async def shutdown_session(self, session_id: str) -> bool:
        """Shutdown a specific session gracefully, cancelling it on timeout."""
        success = await self.submit(session_id, Operation(op_type=OpType.SHUTDOWN))
        if not success:
            return False

        # Grab the task reference under the lock, but await it OUTSIDE the
        # lock: the session task's cleanup path acquires self._lock itself,
        # so awaiting it while holding the lock deadlocks until the timeout
        # and the task is cancelled instead of finishing cleanly.
        async with self._lock:
            agent_session = self.sessions.get(session_id)
            task = agent_session.task if agent_session else None

        if task:
            try:
                await asyncio.wait_for(task, timeout=5.0)
            except asyncio.TimeoutError:
                task.cancel()

        return success

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session entirely."""
        async with self._lock:
            agent_session = self.sessions.pop(session_id, None)

        if not agent_session:
            return False

        # Cancel the task if still running (awaited outside the lock).
        if agent_session.task and not agent_session.task.done():
            agent_session.task.cancel()
            try:
                await agent_session.task
            except asyncio.CancelledError:
                pass

        return True

    def get_session_owner(self, session_id: str) -> str | None:
        """Get the user_id that owns a session, or None if session doesn't exist."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None
        return agent_session.user_id

    def verify_session_access(self, session_id: str, user_id: str) -> bool:
        """Check if a user has access to a session.

        Returns True if:
        - The session exists AND the user owns it
        - The user_id is "dev" (dev mode bypass)
        """
        owner = self.get_session_owner(session_id)
        if owner is None:
            return False
        if user_id == "dev" or owner == "dev":
            return True
        return owner == user_id

    def get_session_info(self, session_id: str) -> dict[str, Any] | None:
        """Get information about a session."""
        agent_session = self.sessions.get(session_id)
        if not agent_session:
            return None

        return {
            "session_id": session_id,
            "created_at": agent_session.created_at.isoformat(),
            "is_active": agent_session.is_active,
            "message_count": len(agent_session.session.context_manager.items),
            "user_id": agent_session.user_id,
        }

    def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]:
        """List sessions, optionally filtered by user.

        Args:
            user_id: If provided, only return sessions owned by this user.
                If "dev", return all sessions (dev mode).
        """
        results = []
        for sid in self.sessions:
            info = self.get_session_info(sid)
            if not info:
                continue
            if user_id and user_id != "dev" and info.get("user_id") != user_id:
                continue
            results.append(info)
        return results

    @property
    def active_session_count(self) -> int:
        """Get count of active sessions."""
        return sum(1 for s in self.sessions.values() if s.is_active)
373
+
374
+
375
+ # Global session manager instance
376
+ session_manager = SessionManager()
start.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# HF Agent Backend startup script (API-only mode).

# Pull in local overrides from .env when present; `set -a` exports every
# variable the sourced file assigns.
if [ -f .env ]; then
    echo "Loading environment from .env file..."
    set -a
    source .env
    set +a
fi

# Defaults, overridable via the environment.
export PORT="${PORT:-7860}"
export HOST="${HOST:-0.0.0.0}"
export CORS_ORIGINS="${CORS_ORIGINS:-*}"

echo "=========================================="
echo "HF Agent Backend (API-only mode)"
echo "=========================================="
echo "Host: $HOST"
echo "Port: $PORT"
echo "CORS Origins: $CORS_ORIGINS"
echo "=========================================="

# Replace the shell with uvicorn so signals reach the server directly.
exec uvicorn main:app --host "$HOST" --port "$PORT" --reload
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
websocket.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WebSocket connection manager for real-time communication."""
2
+
3
+ import logging
4
+ from typing import Any
5
+
6
+ from fastapi import WebSocket
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class ConnectionManager:
    """Tracks at most one live WebSocket per session and routes events to it."""

    def __init__(self) -> None:
        # session_id -> the WebSocket currently registered for that session.
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str) -> None:
        """Accept a WebSocket connection and register it."""
        logger.info(f"Attempting to accept WebSocket for session {session_id}")
        await websocket.accept()
        self.active_connections[session_id] = websocket
        logger.info(f"WebSocket connected and registered for session {session_id}")

    def disconnect(self, session_id: str) -> None:
        """Drop a session's WebSocket registration, if one exists."""
        if self.active_connections.pop(session_id, None) is not None:
            logger.info(f"WebSocket disconnected for session {session_id}")

    async def send_event(
        self, session_id: str, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Send an event to a specific session's WebSocket."""
        connection = self.active_connections.get(session_id)
        if connection is None:
            logger.warning(f"No active connection for session {session_id}")
            return

        payload: dict[str, Any] = {"event_type": event_type}
        if data is not None:
            payload["data"] = data

        try:
            await connection.send_json(payload)
        except Exception as e:
            # A failed send means the socket is dead; forget it.
            logger.error(f"Error sending to session {session_id}: {e}")
            self.disconnect(session_id)

    async def broadcast(
        self, event_type: str, data: dict[str, Any] | None = None
    ) -> None:
        """Broadcast an event to all connected sessions."""
        # Snapshot the ids: send_event may disconnect (mutate) mid-iteration.
        for session_id in list(self.active_connections):
            await self.send_event(session_id, event_type, data)

    def is_connected(self, session_id: str) -> bool:
        """Check if a session has an active WebSocket connection."""
        return session_id in self.active_connections
59
+
60
+
61
+ # Global connection manager instance
62
+ manager = ConnectionManager()