ml-agent / agent /main.py
akseljoonas's picture
akseljoonas HF Staff
Add approval display details for hf_repo_files and hf_repo_git
5558a57
"""
Interactive CLI chat with the agent
"""
import asyncio
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import litellm
from lmnr import Laminar, LaminarLiteLLMCallback
from prompt_toolkit import PromptSession
from agent.config import load_config
from agent.core.agent_loop import submission_loop
from agent.core.session import OpType
from agent.core.tools import ToolRouter
from agent.utils.reliability_checks import check_training_script_save_pattern
from agent.utils.terminal_display import (
format_error,
format_header,
format_plan_display,
format_separator,
format_success,
format_tool_call,
format_tool_output,
format_turn_complete,
)
# Have LiteLLM drop request parameters that the selected provider does not
# support instead of erroring (LiteLLM's global `drop_params` setting).
litellm.drop_params = True
def _safe_get_args(arguments: dict) -> dict:
"""Safely extract args dict from arguments, handling cases where LLM passes string."""
args = arguments.get("args", {})
# Sometimes LLM passes args as string instead of dict
if isinstance(args, str):
return {}
return args if isinstance(args, dict) else {}
# Optional Laminar tracing: enabled only when LMNR_API_KEY is set. Any
# failure while setting it up is reported and otherwise ignored, so the CLI
# still starts without tracing.
if lmnr_api_key := os.environ.get("LMNR_API_KEY"):
    try:
        Laminar.initialize(project_api_key=lmnr_api_key)
        # Route LiteLLM callbacks through Laminar.
        litellm.callbacks = [LaminarLiteLLMCallback()]
        print("Laminar initialized")
    except Exception as exc:
        print(f"Failed to initialize Laminar: {exc}")
@dataclass
class Operation:
    """Operation to be executed by the agent.

    Pairs an ``OpType`` with an optional payload. In this file the payload is
    ``{"text": ...}`` for USER_INPUT, ``{"approvals": [...]}`` for
    EXEC_APPROVAL, and omitted (``None``) for SHUTDOWN.
    """

    op_type: OpType  # which kind of operation the agent loop should perform
    data: Optional[dict[str, Any]] = None  # op-specific payload; None when the op needs none
@dataclass
class Submission:
    """Submission to the agent loop.

    Wraps an ``Operation`` with an identifier. This file builds ids such as
    ``sub_<n>``, ``approval_<n>`` and ``sub_shutdown`` and puts submissions on
    the queue that ``main`` hands to ``submission_loop``.
    """

    id: str  # caller-chosen identifier for this submission
    operation: Operation  # the work being requested
def _repo_url(repo_id: str, repo_type: str) -> str:
    """Return the huggingface.co web URL of a repository.

    Model repos live at the site root; any other type is placed under the type
    name with an ``s`` appended (``dataset`` -> ``/datasets/<repo_id>``).
    """
    if repo_type == "model":
        return f"https://huggingface.co/{repo_id}"
    return f"https://huggingface.co/{repo_type}s/{repo_id}"


def _format_size(num_bytes: int) -> str:
    """Format a byte count with two decimals: KB below 1024 KB, MB otherwise."""
    size_kb = num_bytes / 1024
    if size_kb < 1024:
        return f"{size_kb:.2f} KB"
    return f"{size_kb / 1024:.2f} MB"


def _join_items(value: Any, sep: str) -> str:
    """Render a list/tuple argument joined by ``sep``; other values via ``str()``.

    LLM-produced arguments sometimes use a bare string (or non-string items)
    where a list of strings is expected; this avoids joining a string
    character by character and ``TypeError`` on non-string items.
    """
    if isinstance(value, (list, tuple)):
        return sep.join(str(item) for item in value)
    return str(value)


def _parse_tool_arguments(tool_name: str, arguments: Any) -> dict:
    """Return tool-call arguments as a dict, decoding a JSON string if needed.

    Anything that cannot be turned into a dict (invalid JSON, a list, None)
    becomes ``{}`` so the approval prompt can still be shown.
    """
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            print(f"Warning: Failed to parse arguments for {tool_name}")
            return {}
    return arguments if isinstance(arguments, dict) else {}


def _print_hf_jobs_details(_operation: str, arguments: dict) -> None:
    """Print what an ``hf_jobs`` call would run and with which resources."""
    script = arguments.get("script")
    command = arguments.get("command")
    if script:
        # Python mode: show the full script so the user reviews what will run.
        print(f"Script:\n{script}")
        dependencies = arguments.get("dependencies", [])
        if dependencies:
            print(f"Dependencies: {_join_items(dependencies, ', ')}")
        python_version = arguments.get("python")
        if python_version:
            print(f"Python version: {python_version}")
        script_args = arguments.get("script_args", [])
        if script_args:
            print(f"Script args: {_join_items(script_args, ' ')}")
        # Run reliability checks on the full script (not truncated).
        check_message = check_training_script_save_pattern(script)
        if check_message:
            print(check_message)
    elif command:
        # Docker mode.
        image = arguments.get("image", "python:3.12")
        print(f"Docker image: {image}")
        print(f"Command: {_join_items(command, ' ')}")

    # Parameters common to both modes.
    hardware_flavor = arguments.get("hardware_flavor", "cpu-basic")
    timeout = arguments.get("timeout", "30m")
    print(f"Hardware: {hardware_flavor}")
    print(f"Timeout: {timeout}")
    env = arguments.get("env", {})
    if env and isinstance(env, dict):
        # Only variable names are shown, never their values.
        print(f"Environment variables: {', '.join(map(str, env))}")
    schedule = arguments.get("schedule")
    if schedule:
        print(f"Schedule: {schedule}")


def _print_private_repo_details(operation: str, arguments: dict) -> None:
    """Print the target repo (and an upload preview) of an ``hf_private_repos`` call."""
    if operation not in ("create_repo", "upload_file"):
        return
    args = _safe_get_args(arguments)
    repo_id = args.get("repo_id", "")
    repo_type = args.get("repo_type", "dataset")
    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print("Private: Yes")
    print(f"URL: {_repo_url(repo_id, repo_type)}")
    if operation != "upload_file":
        return

    path_in_repo = args.get("path_in_repo", "")
    file_content = args.get("file_content", "")
    print(f"File: {path_in_repo}")
    if isinstance(file_content, str):
        all_lines = file_content.split("\n")
        preview = "\n".join(all_lines[:5])
        print(f"Line count: {len(all_lines)}")
        print(f"Size: {_format_size(len(file_content.encode('utf-8')))}")
        print(f"Content preview (first 5 lines):\n{preview}")
        if len(all_lines) > 5:
            print("...")


def _print_repo_files_details(operation: str, arguments: dict) -> None:
    """Print target repo/branch plus upload or delete details of ``hf_repo_files``."""
    repo_id = arguments.get("repo_id", "")
    repo_type = arguments.get("repo_type", "model")
    revision = arguments.get("revision", "main")
    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print(f"Branch: {revision}")
    print(f"URL: {_repo_url(repo_id, repo_type)}")

    if operation == "upload":
        path = arguments.get("path", "")
        content = arguments.get("content", "")
        print(f"File: {path}")
        if arguments.get("create_pr", False):
            print("Mode: Create PR")
        if isinstance(content, str):
            line_count = content.count("\n") + 1  # same as len(content.split("\n"))
            print(f"Lines: {line_count}")
            print(f"Size: {_format_size(len(content.encode('utf-8')))}")
            # Show the full content: this is exactly what would be written.
            print(f"Content:\n{content}")
    elif operation == "delete":
        patterns = arguments.get("patterns", [])
        print(f"Patterns to delete: {_join_items(patterns, ', ')}")


def _print_repo_git_details(operation: str, arguments: dict) -> None:
    """Print the repo and operation-specific target of an ``hf_repo_git`` call."""
    repo_id = arguments.get("repo_id", "")
    repo_type = arguments.get("repo_type", "model")
    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print(f"URL: {_repo_url(repo_id, repo_type)}")

    if operation == "delete_branch":
        print(f"Branch to delete: {arguments.get('branch', '')}")
    elif operation == "delete_tag":
        print(f"Tag to delete: {arguments.get('tag', '')}")
    elif operation == "merge_pr":
        print(f"PR to merge: #{arguments.get('pr_num', '')}")
    elif operation == "create_repo":
        print(f"Private: {arguments.get('private', False)}")
        space_sdk = arguments.get("space_sdk")
        if space_sdk:
            print(f"Space SDK: {space_sdk}")
    elif operation == "update_repo":
        private = arguments.get("private")
        gated = arguments.get("gated")
        if private is not None:
            print(f"Private: {private}")
        if gated is not None:
            print(f"Gated: {gated}")


# Per-tool printers for the extra details shown before asking for approval.
_APPROVAL_DETAIL_PRINTERS = {
    "hf_jobs": _print_hf_jobs_details,
    "hf_private_repos": _print_private_repo_details,
    "hf_repo_files": _print_repo_files_details,
    "hf_repo_git": _print_repo_git_details,
}


def _print_tool_details(tool_name: str, operation: str, arguments: dict) -> None:
    """Print tool-specific details for an approval item, if a printer exists."""
    printer = _APPROVAL_DETAIL_PRINTERS.get(tool_name)
    if printer is None:
        return
    try:
        printer(operation, arguments)
    except Exception as e:
        # A display problem (e.g. an argument of an unexpected type) must not
        # abort the approval flow; otherwise the approval submission would
        # never be sent back to the agent.
        print(f"Warning: could not display details for {tool_name}: {e}")


def _approval(
    tool_call_id: str, approved: bool, feedback: Optional[str] = None
) -> dict:
    """Build one entry of the EXEC_APPROVAL ``approvals`` list."""
    return {"tool_call_id": tool_call_id, "approved": approved, "feedback": feedback}


def _approval_submission(seq: int, approvals: list) -> Submission:
    """Wrap a list of approval entries into an EXEC_APPROVAL submission."""
    return Submission(
        id=f"approval_{seq}",
        operation=Operation(
            op_type=OpType.EXEC_APPROVAL,
            data={"approvals": approvals},
        ),
    )


async def _prompt_for_approvals(
    tools_data: list, count: int, prompt_session: PromptSession, config
) -> list:
    """Show each pending tool call and ask the user to approve it.

    Answers: ``y``/``yes`` approves, ``n``/``no`` rejects, ``yolo`` approves
    this and all remaining items (and sets ``config.yolo_mode`` when a config
    is available), anything else rejects and is sent back as feedback with
    its original casing. Ctrl-D at the prompt is treated as a rejection.

    Returns:
        One approval entry per item in ``tools_data``.
    """
    print("\n" + format_separator())
    print(
        format_header(
            f"APPROVAL REQUIRED ({count} item{'s' if count != 1 else ''})"
        )
    )
    print(format_separator())

    approvals = []
    for i, tool_info in enumerate(tools_data, 1):
        tool_name = tool_info.get("tool", "")
        tool_call_id = tool_info.get("tool_call_id", "")
        arguments = _parse_tool_arguments(tool_name, tool_info.get("arguments", {}))
        operation = arguments.get("operation", "")

        print(f"\n[Item {i}/{count}]")
        print(f"Tool: {tool_name}")
        print(f"Operation: {operation}")
        _print_tool_details(tool_name, operation, arguments)

        try:
            raw_response = await prompt_session.prompt_async(
                f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
            )
        except EOFError:
            # Ctrl-D: reject rather than leave the agent without an answer.
            raw_response = "n"
        raw_response = raw_response.strip()
        # Compare case-insensitively, but keep the user's casing for feedback.
        answer = raw_response.lower()

        if answer == "yolo":
            if config is not None:
                config.yolo_mode = True
            print("⚡ YOLO MODE ACTIVATED - Auto-approving all future tool calls")
            # Auto-approve this item and all remaining ones in the batch.
            approvals.append(_approval(tool_call_id, True))
            approvals.extend(
                _approval(remaining.get("tool_call_id", ""), True)
                for remaining in tools_data[i:]
            )
            break

        approved = answer in ("y", "yes")
        feedback = None if approved or answer in ("n", "no") else raw_response
        approvals.append(_approval(tool_call_id, approved, feedback))
    return approvals


async def event_listener(
    event_queue: asyncio.Queue,
    submission_queue: asyncio.Queue,
    turn_complete_event: asyncio.Event,
    ready_event: asyncio.Event,
    prompt_session: PromptSession,
    config=None,
) -> None:
    """Background task that consumes agent events and renders them.

    Handled event types: ``ready`` (sets ``ready_event``),
    ``assistant_message``, ``tool_call``/``tool_output``, ``turn_complete``
    and ``error`` (both set ``turn_complete_event``), ``compacted``,
    ``approval_required`` (prompts the user, or auto-approves when
    ``config.yolo_mode`` is on, then puts an EXEC_APPROVAL submission on
    ``submission_queue``) and ``shutdown`` (ends the task). Other event types
    are ignored. An exception while handling one event is printed and the
    loop continues; cancellation ends the loop.

    Args:
        event_queue: Queue of events emitted by the agent loop.
        submission_queue: Queue the approval submissions are put on.
        turn_complete_event: Set when the agent finished a turn (or errored).
        ready_event: Set once the agent reports it is ready.
        prompt_session: Session used to prompt the user for approvals.
        config: Agent config; only ``yolo_mode`` is read/written here.
    """
    approval_seq = 1000  # approval ids: approval_1001, approval_1002, ...
    last_tool_name: Optional[str] = None  # tells tool_output which tool produced it
    while True:
        try:
            event = await event_queue.get()
            data = event.data or {}
            event_type = event.event_type

            if event_type == "ready":
                print(format_success("\U0001f917 Agent ready"))
                ready_event.set()
            elif event_type == "assistant_message":
                content = data.get("content", "")
                if content:
                    print(f"\nAssistant: {content}")
            elif event_type == "tool_call":
                tool_name = data.get("tool", "")
                arguments = data.get("arguments", {})
                if tool_name:
                    last_tool_name = tool_name  # Store for tool_output event
                    args_json = json.dumps(arguments)
                    # Only mark the summary as truncated when it actually is.
                    args_str = (
                        args_json if len(args_json) <= 100 else args_json[:100] + "..."
                    )
                    print(format_tool_call(tool_name, args_str))
            elif event_type == "tool_output":
                output = data.get("output", "")
                success = data.get("success", False)
                if output:
                    # Don't truncate plan_tool output, truncate everything else
                    should_truncate = last_tool_name != "plan_tool"
                    print(format_tool_output(output, success, truncate=should_truncate))
            elif event_type == "turn_complete":
                print(format_turn_complete())
                # Display plan after turn complete
                plan_display = format_plan_display()
                if plan_display:
                    print(plan_display)
                turn_complete_event.set()
            elif event_type == "error":
                print(format_error(data.get("error", "Unknown error")))
                turn_complete_event.set()
            elif event_type == "shutdown":
                break
            elif event_type == "processing":
                pass
            elif event_type == "compacted":
                old_tokens = data.get("old_tokens", 0)
                new_tokens = data.get("new_tokens", 0)
                print(f"Compacted context: {old_tokens} → {new_tokens} tokens")
            elif event_type == "approval_required":
                tools_data = data.get("tools", [])
                # Fall back to the real number of items when "count" is absent.
                count = data.get("count", len(tools_data))
                auto_approve = bool(config and config.yolo_mode)
                if auto_approve:
                    approvals = [
                        _approval(t.get("tool_call_id", ""), True) for t in tools_data
                    ]
                    print(f"\n⚡ YOLO MODE: Auto-approving {count} item(s)")
                else:
                    approvals = await _prompt_for_approvals(
                        tools_data, count, prompt_session, config
                    )
                approval_seq += 1
                await submission_queue.put(_approval_submission(approval_seq, approvals))
                if not auto_approve:
                    print(format_separator() + "\n")
            # Silently ignore other events
        except asyncio.CancelledError:
            break
        except Exception as e:
            print(f"Event listener error: {e}")
async def get_user_input(prompt_session: PromptSession) -> str:
    """Read one line from the user through ``prompt_session`` asynchronously.

    The prompt is a bold cyan ``>`` preceded by a blank line.
    """
    from prompt_toolkit.formatted_text import HTML

    prompt_markup = HTML("\n<b><cyan>></cyan></b> ")
    user_text = await prompt_session.prompt_async(prompt_markup)
    return user_text
async def main():
    """Run the interactive CLI chat with the agent until the user exits.

    Starts the agent's ``submission_loop`` and the background
    ``event_listener``, waits for the agent to report ready, then forwards each
    non-empty line the user types as a USER_INPUT submission. Before prompting
    again it waits for the previous turn to complete. On exit (an exit
    command, EOF, or Ctrl+C at the prompt) a SHUTDOWN submission is sent. The
    agent task gets 5 seconds to finish (it is cancelled otherwise), and the
    listener task is cancelled and awaited.
    """
    from agent.utils.terminal_display import Colors

    # Clear screen
    os.system("clear" if os.name != "nt" else "cls")
    banner = r"""
_ _ _ _____ _ _
| | | |_ _ __ _ __ _(_)_ __ __ _ | ___|_ _ ___ ___ / \ __ _ ___ _ __ | |_
| |_| | | | |/ _` |/ _` | | '_ \ / _` | | |_ / _` |/ __/ _ \ / _ \ / _` |/ _ \ '_ \| __|
| _ | |_| | (_| | (_| | | | | | (_| | | _| (_| | (_| __/ / ___ \ (_| | __/ | | | |_
|_| |_|\__,_|\__, |\__, |_|_| |_|\__, | |_| \__,_|\___\___| /_/ \_\__, |\___|_| |_|\__|
|___/ |___/ |___/ |___/
"""
    print(format_separator())
    print(f"{Colors.YELLOW} {banner}{Colors.RESET}")
    print("Type your messages below. Type 'exit', 'quit', or '/quit' to end.\n")
    print(format_separator())
    print("Initializing agent...")

    # Queues for communication with the agent loop
    submission_queue = asyncio.Queue()
    event_queue = asyncio.Queue()
    # turn_complete_event starts set so the first prompt is shown immediately
    turn_complete_event = asyncio.Event()
    turn_complete_event.set()
    ready_event = asyncio.Event()

    config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
    config = load_config(config_path)
    print(f"Loading MCP servers: {', '.join(config.mcpServers.keys())}")
    tool_router = ToolRouter(config.mcpServers)
    prompt_session = PromptSession()

    # Start agent loop in background
    agent_task = asyncio.create_task(
        submission_loop(
            submission_queue,
            event_queue,
            config=config,
            tool_router=tool_router,
        )
    )
    # Start event listener in background
    listener_task = asyncio.create_task(
        event_listener(
            event_queue,
            submission_queue,
            turn_complete_event,
            ready_event,
            prompt_session,
            config,
        )
    )
    # Wait for agent to initialize
    await ready_event.wait()

    submission_id = 0
    try:
        while True:
            # Wait for previous turn to complete
            await turn_complete_event.wait()
            turn_complete_event.clear()
            try:
                user_input = await get_user_input(prompt_session)
            except EOFError:
                break
            if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]:
                break
            # Skip empty input (and re-arm the event so we prompt again)
            if not user_input.strip():
                turn_complete_event.set()
                continue
            submission_id += 1
            submission = Submission(
                id=f"sub_{submission_id}",
                operation=Operation(
                    op_type=OpType.USER_INPUT, data={"text": user_input}
                ),
            )
            await submission_queue.put(submission)
    except KeyboardInterrupt:
        print("\n\nInterrupted by user")

    # Shutdown
    print("\n🛑 Shutting down agent...")
    shutdown_submission = Submission(
        id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN)
    )
    await submission_queue.put(shutdown_submission)
    try:
        # On timeout, wait_for cancels agent_task before raising.
        await asyncio.wait_for(agent_task, timeout=5.0)
    except asyncio.TimeoutError:
        print("Agent did not stop within 5 seconds; it was cancelled.")
    except Exception as e:
        # The agent loop crashed; report it but still finish shutting down.
        print(f"Agent loop exited with an error: {e}")

    listener_task.cancel()
    try:
        await listener_task
    except asyncio.CancelledError:
        pass
    print("✨ Goodbye!\n")
if __name__ == "__main__":
    # main() only catches Ctrl+C inside its input loop; a Ctrl+C raised
    # elsewhere (e.g. during startup or shutdown) ends up here and exits
    # with a goodbye message instead of a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n✨ Goodbye!")