"""
Interactive CLI chat with the agent
"""
|
| |
|
| | import asyncio
|
| | import json
|
| | import os
|
| | from dataclasses import dataclass
|
| | from pathlib import Path
|
| | from typing import Any, Optional
|
| |
|
| | import litellm
|
| | from lmnr import Laminar, LaminarLiteLLMCallback
|
| | from prompt_toolkit import PromptSession
|
| |
|
| | from agent.config import load_config
|
| | from agent.core.agent_loop import submission_loop
|
| | from agent.core.session import OpType
|
| | from agent.core.tools import ToolRouter
|
| | from agent.utils.reliability_checks import check_training_script_save_pattern
|
| | from agent.utils.terminal_display import (
|
| | format_error,
|
| | format_header,
|
| | format_plan_display,
|
| | format_separator,
|
| | format_success,
|
| | format_tool_call,
|
| | format_tool_output,
|
| | format_turn_complete,
|
| | )
|
| |
|
# NOTE(review): presumably tells litellm to drop request parameters that the
# target provider does not support instead of raising — confirm against the
# litellm documentation.
litellm.drop_params = True
|
| |
|
| |
|
| | def _safe_get_args(arguments: dict) -> dict:
|
| | """Safely extract args dict from arguments, handling cases where LLM passes string."""
|
| | args = arguments.get("args", {})
|
| |
|
| | if isinstance(args, str):
|
| | return {}
|
| | return args if isinstance(args, dict) else {}
|
| |
|
| |
|
# Optional Laminar integration: only attempted when LMNR_API_KEY is set, and a
# failure to initialise is reported on stdout without stopping the program.
lmnr_api_key = os.getenv("LMNR_API_KEY")
if lmnr_api_key:
    try:
        Laminar.initialize(project_api_key=lmnr_api_key)
        litellm.callbacks = [LaminarLiteLLMCallback()]
        print("Laminar initialized")
    except Exception as exc:
        print(f"Failed to initialize Laminar: {exc}")
|
| |
|
| |
|
@dataclass
class Operation:
    """A single unit of work handed to the agent."""

    # Which kind of operation this is (an ``OpType`` member).
    op_type: OpType
    # Optional payload accompanying the operation; ``None`` when there is none.
    data: Optional[dict[str, Any]] = None
|
| |
|
| |
|
@dataclass
class Submission:
    """An identified ``Operation`` placed on the agent's submission queue."""

    # Identifier chosen by the submitter (e.g. "sub_1", "approval_1001").
    id: str
    # The operation to execute.
    operation: Operation
|
| |
|
| |
|
def _event_field(event, key, default):
    """Return ``event.data[key]``, or ``default`` when data is empty or lacks the key."""
    return event.data.get(key, default) if event.data else default


def _approval_record(tool_call_id, approved=True, feedback=None) -> dict:
    """Build one entry of the ``approvals`` list sent back to the agent."""
    return {
        "tool_call_id": tool_call_id,
        "approved": approved,
        "feedback": feedback,
    }


def _approval_submission(sequence: int, approvals: list):
    """Wrap a list of approval records in an EXEC_APPROVAL submission."""
    return Submission(
        id=f"approval_{sequence}",
        operation=Operation(
            op_type=OpType.EXEC_APPROVAL,
            data={"approvals": approvals},
        ),
    )


def _repo_url(repo_id, repo_type) -> str:
    """Return the huggingface.co URL for a repo (models live at the root)."""
    if repo_type == "model":
        return f"https://huggingface.co/{repo_id}"
    return f"https://huggingface.co/{repo_type}s/{repo_id}"


def _print_text_stats(text: str, lines_label: str) -> list:
    """Print the line count and UTF-8 size of ``text``; return its lines."""
    lines = text.split("\n")
    size_kb = len(text.encode("utf-8")) / 1024

    print(f"{lines_label}: {len(lines)}")
    if size_kb < 1024:
        print(f"Size: {size_kb:.2f} KB")
    else:
        print(f"Size: {size_kb / 1024:.2f} MB")
    return lines


def _parse_tool_arguments(tool_name, arguments) -> dict:
    """Normalise tool arguments to a dict, decoding a JSON string if needed.

    Anything that cannot be turned into a dict yields ``{}`` plus a warning,
    so the approval prompt can still be shown.
    """
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            print(f"Warning: Failed to parse arguments for {tool_name}")
            return {}
    if not isinstance(arguments, dict):
        # Valid JSON that is not an object (list, number, null, ...).
        print(f"Warning: Failed to parse arguments for {tool_name}")
        return {}
    return arguments


def _describe_hf_jobs(arguments: dict) -> None:
    """Print the details of an ``hf_jobs`` request (script or docker command)."""
    script = arguments.get("script")
    command = arguments.get("command")

    if script:
        dependencies = arguments.get("dependencies", [])
        python_version = arguments.get("python")
        script_args = arguments.get("script_args", [])

        print(f"Script:\n{script}")
        if dependencies:
            print(f"Dependencies: {', '.join(map(str, dependencies))}")
        if python_version:
            print(f"Python version: {python_version}")
        if script_args:
            print(f"Script args: {' '.join(map(str, script_args))}")

        check_message = check_training_script_save_pattern(script)
        if check_message:
            print(check_message)
    elif command:
        image = arguments.get("image", "python:3.12")
        if isinstance(command, list):
            command_str = " ".join(map(str, command))
        else:
            command_str = str(command)
        print(f"Docker image: {image}")
        print(f"Command: {command_str}")

    # Settings shown for both script and command jobs.
    hardware_flavor = arguments.get("hardware_flavor", "cpu-basic")
    timeout = arguments.get("timeout", "30m")
    env = arguments.get("env", {})
    schedule = arguments.get("schedule")

    print(f"Hardware: {hardware_flavor}")
    print(f"Timeout: {timeout}")

    if env:
        # Only the variable names are printed, never the values.
        env_keys = ", ".join(map(str, env))
        print(f"Environment variables: {env_keys}")

    if schedule:
        print(f"Schedule: {schedule}")


def _describe_hf_private_repos(operation, arguments: dict) -> None:
    """Print the details of an ``hf_private_repos`` create/upload request."""
    args = _safe_get_args(arguments)

    if operation not in ("create_repo", "upload_file"):
        return

    repo_id = args.get("repo_id", "")
    repo_type = args.get("repo_type", "dataset")

    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print("Private: Yes")
    print(f"URL: {_repo_url(repo_id, repo_type)}")

    if operation == "upload_file":
        path_in_repo = args.get("path_in_repo", "")
        file_content = args.get("file_content", "")
        print(f"File: {path_in_repo}")

        if isinstance(file_content, str):
            all_lines = _print_text_stats(file_content, "Line count")
            preview = "\n".join(all_lines[:5])
            print(f"Content preview (first 5 lines):\n{preview}")
            if len(all_lines) > 5:
                print("...")


def _describe_hf_repo_files(operation, arguments: dict) -> None:
    """Print the details of an ``hf_repo_files`` upload/delete request."""
    repo_id = arguments.get("repo_id", "")
    repo_type = arguments.get("repo_type", "model")
    revision = arguments.get("revision", "main")

    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print(f"Branch: {revision}")
    print(f"URL: {_repo_url(repo_id, repo_type)}")

    if operation == "upload":
        path = arguments.get("path", "")
        content = arguments.get("content", "")
        create_pr = arguments.get("create_pr", False)

        print(f"File: {path}")
        if create_pr:
            print("Mode: Create PR")

        if isinstance(content, str):
            _print_text_stats(content, "Lines")
            print(f"Content:\n{content}")
    elif operation == "delete":
        patterns = arguments.get("patterns", [])
        if isinstance(patterns, str):
            patterns = [patterns]
        print(f"Patterns to delete: {', '.join(map(str, patterns))}")


def _describe_hf_repo_git(operation, arguments: dict) -> None:
    """Print the details of an ``hf_repo_git`` request."""
    repo_id = arguments.get("repo_id", "")
    repo_type = arguments.get("repo_type", "model")

    print(f"Repository: {repo_id}")
    print(f"Type: {repo_type}")
    print(f"URL: {_repo_url(repo_id, repo_type)}")

    if operation == "delete_branch":
        print(f"Branch to delete: {arguments.get('branch', '')}")
    elif operation == "delete_tag":
        print(f"Tag to delete: {arguments.get('tag', '')}")
    elif operation == "merge_pr":
        print(f"PR to merge: #{arguments.get('pr_num', '')}")
    elif operation == "create_repo":
        private = arguments.get("private", False)
        space_sdk = arguments.get("space_sdk")
        print(f"Private: {private}")
        if space_sdk:
            print(f"Space SDK: {space_sdk}")
    elif operation == "update_repo":
        private = arguments.get("private")
        gated = arguments.get("gated")
        if private is not None:
            print(f"Private: {private}")
        if gated is not None:
            print(f"Gated: {gated}")


def _describe_tool_request(tool_name, operation, arguments: dict) -> None:
    """Print the tool-specific part of an approval request, if the tool is known."""
    if tool_name == "hf_jobs":
        _describe_hf_jobs(arguments)
    elif tool_name == "hf_private_repos":
        _describe_hf_private_repos(operation, arguments)
    elif tool_name == "hf_repo_files":
        _describe_hf_repo_files(operation, arguments)
    elif tool_name == "hf_repo_git":
        _describe_hf_repo_git(operation, arguments)


async def _collect_approvals(prompt_session, tools_data, count) -> tuple:
    """Show each pending tool call and ask the user to approve or reject it.

    Returns:
        ``(approvals, yolo_requested)``. ``yolo_requested`` is True when the
        user answered ``yolo``; in that case the current item and every
        remaining one are recorded as approved without further prompts.
    """
    approvals = []

    for i, tool_info in enumerate(tools_data, 1):
        tool_name = tool_info.get("tool", "")
        tool_call_id = tool_info.get("tool_call_id", "")
        arguments = _parse_tool_arguments(tool_name, tool_info.get("arguments", {}))
        operation = arguments.get("operation", "")

        print(f"\n[Item {i}/{count}]")
        print(f"Tool: {tool_name}")
        print(f"Operation: {operation}")
        _describe_tool_request(tool_name, operation, arguments)

        reply = await prompt_session.prompt_async(
            f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): "
        )
        reply = reply.strip()
        # Keywords are matched case-insensitively; free-text feedback keeps
        # the user's original casing.
        choice = reply.lower()

        if choice == "yolo":
            print("⚡ YOLO MODE ACTIVATED - Auto-approving all future tool calls")
            approvals.append(_approval_record(tool_call_id))
            # ``i`` is 1-based, so tools_data[i:] is everything after this item.
            approvals.extend(
                _approval_record(remaining.get("tool_call_id", ""))
                for remaining in tools_data[i:]
            )
            return approvals, True

        approved = choice in ("y", "yes")
        feedback = None if approved or choice in ("n", "no") else reply
        approvals.append(_approval_record(tool_call_id, approved, feedback))

    return approvals, False


async def event_listener(
    event_queue: asyncio.Queue,
    submission_queue: asyncio.Queue,
    turn_complete_event: asyncio.Event,
    ready_event: asyncio.Event,
    prompt_session: PromptSession,
    config=None,
) -> None:
    """Background task that listens for events and displays them.

    Consumes events from ``event_queue`` until a ``"shutdown"`` event arrives
    or the task is cancelled. Handled event types:

    * ``ready`` - prints a notice and sets ``ready_event``.
    * ``assistant_message``, ``tool_call``, ``tool_output``, ``compacted`` -
      printed to stdout.
    * ``turn_complete`` / ``error`` - printed, then ``turn_complete_event`` is
      set so the input loop can prompt again.
    * ``approval_required`` - every listed tool call is shown and the user is
      asked to approve it; the answers are put on ``submission_queue`` as an
      ``EXEC_APPROVAL`` submission. In YOLO mode everything is approved
      without prompting.
    * ``processing`` and unknown types - ignored.

    Args:
        event_queue: Queue of event objects exposing ``event_type`` and ``data``.
        submission_queue: Queue that receives approval submissions.
        turn_complete_event: Set when a turn finishes or fails.
        ready_event: Set once the ``ready`` event has been seen.
        prompt_session: Used for the interactive approval prompt.
        config: Optional configuration object. Its ``yolo_mode`` attribute is
            read, and set to True when the user answers ``yolo``. When it is
            ``None`` the YOLO choice is remembered for this listener only.

    Any exception raised while handling one event is printed and the loop
    carries on with the next event.
    """
    approval_sequence = 1000
    last_tool_name = None
    # Fallback YOLO flag for callers that pass no config object.
    auto_approve = False

    while True:
        try:
            event = await event_queue.get()
            kind = event.event_type

            if kind == "ready":
                print(format_success("\U0001f917 Agent ready"))
                ready_event.set()

            elif kind == "assistant_message":
                content = _event_field(event, "content", "")
                if content:
                    print(f"\nAssistant: {content}")

            elif kind == "tool_call":
                tool_name = _event_field(event, "tool", "")
                arguments = _event_field(event, "arguments", {})
                if tool_name:
                    last_tool_name = tool_name
                    args_str = json.dumps(arguments)
                    # Mark truncation only when something was actually cut.
                    if len(args_str) > 100:
                        args_str = args_str[:100] + "..."
                    print(format_tool_call(tool_name, args_str))

            elif kind == "tool_output":
                output = _event_field(event, "output", "")
                success = _event_field(event, "success", False)
                if output:
                    # Output of the plan tool is always shown in full.
                    should_truncate = last_tool_name != "plan_tool"
                    print(format_tool_output(output, success, truncate=should_truncate))

            elif kind == "turn_complete":
                print(format_turn_complete())
                plan_display = format_plan_display()
                if plan_display:
                    print(plan_display)
                turn_complete_event.set()

            elif kind == "error":
                error = _event_field(event, "error", "Unknown error")
                print(format_error(error))
                turn_complete_event.set()

            elif kind == "shutdown":
                break

            elif kind == "compacted":
                old_tokens = _event_field(event, "old_tokens", 0)
                new_tokens = _event_field(event, "new_tokens", 0)
                print(f"Compacted context: {old_tokens} → {new_tokens} tokens")

            elif kind == "approval_required":
                tools_data = _event_field(event, "tools", [])
                count = _event_field(event, "count", 0)

                if auto_approve or (config and config.yolo_mode):
                    approvals = [
                        _approval_record(t.get("tool_call_id", ""))
                        for t in tools_data
                    ]
                    print(f"\n⚡ YOLO MODE: Auto-approving {count} item(s)")
                    approval_sequence += 1
                    await submission_queue.put(
                        _approval_submission(approval_sequence, approvals)
                    )
                    continue

                print("\n" + format_separator())
                print(
                    format_header(
                        f"APPROVAL REQUIRED ({count} item{'s' if count != 1 else ''})"
                    )
                )
                print(format_separator())

                approvals, yolo_requested = await _collect_approvals(
                    prompt_session, tools_data, count
                )
                if yolo_requested:
                    auto_approve = True
                    if config is not None:
                        config.yolo_mode = True

                approval_sequence += 1
                await submission_queue.put(
                    _approval_submission(approval_sequence, approvals)
                )
                print(format_separator() + "\n")

            # "processing" and unrecognised event types need no output.

        except asyncio.CancelledError:
            break
        except Exception as e:
            print(f"Event listener error: {e}")
|
| |
|
| |
|
async def get_user_input(prompt_session: PromptSession) -> str:
    """Show the chat prompt and return the text the user typed."""
    from prompt_toolkit.formatted_text import HTML

    prompt_markup = HTML("\n<b><cyan>></cyan></b> ")
    user_text = await prompt_session.prompt_async(prompt_markup)
    return user_text
|
| |
|
| |
|
async def main():
    """Interactive chat with the agent.

    Clears the terminal, prints the banner, loads the agent configuration,
    starts the agent loop and the event listener as background tasks, then
    forwards each line typed by the user to the agent until the user types
    an exit command, sends EOF, or interrupts. On exit a SHUTDOWN submission
    is sent and both background tasks are stopped.

    Raises:
        Exception: whatever the agent task raised, if it fails before
            signalling readiness or while shutting down.
    """
    from agent.utils.terminal_display import Colors

    os.system("clear" if os.name != "nt" else "cls")

    banner = r"""
  _   _                   _               _____                   _                    _
 | | | |_   _  __ _  __ _(_)_ __   __ _  |  ___|_ _  ___ ___     / \   __ _  ___ _ __ | |_
 | |_| | | | |/ _` |/ _` | | '_ \ / _` | | |_ / _` |/ __/ _ \   / _ \ / _` |/ _ \ '_ \| __|
 |  _  | |_| | (_| | (_| | | | | | (_| | |  _| (_| | (_|  __/  / ___ \ (_| |  __/ | | | |_
 |_| |_|\__,_|\__, |\__, |_|_| |_|\__, | |_|  \__,_|\___\___| /_/   \_\__, |\___|_| |_|\__|
              |___/ |___/         |___/                               |___/
"""

    print(format_separator())
    print(f"{Colors.YELLOW} {banner}{Colors.RESET}")
    print("Type your messages below. Type 'exit', 'quit', or '/quit' to end.\n")
    print(format_separator())

    print("Initializing agent...")

    submission_queue = asyncio.Queue()
    event_queue = asyncio.Queue()

    # Starts set so the first prompt is shown without waiting for a turn.
    turn_complete_event = asyncio.Event()
    turn_complete_event.set()
    ready_event = asyncio.Event()

    config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json"
    config = load_config(config_path)

    print(f"Loading MCP servers: {', '.join(config.mcpServers.keys())}")
    tool_router = ToolRouter(config.mcpServers)

    prompt_session = PromptSession()

    agent_task = asyncio.create_task(
        submission_loop(
            submission_queue,
            event_queue,
            config=config,
            tool_router=tool_router,
        )
    )

    listener_task = asyncio.create_task(
        event_listener(
            event_queue,
            submission_queue,
            turn_complete_event,
            ready_event,
            prompt_session,
            config,
        )
    )

    # Wait for readiness, but stop waiting if the agent task ends first;
    # otherwise a start-up failure would leave the CLI hanging forever.
    ready_waiter = asyncio.ensure_future(ready_event.wait())
    await asyncio.wait(
        {ready_waiter, agent_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if not ready_waiter.done():
        ready_waiter.cancel()
        listener_task.cancel()
        await asyncio.gather(listener_task, return_exceptions=True)
        agent_task.result()  # re-raises the start-up failure, if there was one
        print(format_error("Agent loop exited before it became ready"))
        return

    submission_id = 0

    try:
        while True:
            # Block until the previous turn has finished (or failed).
            await turn_complete_event.wait()
            turn_complete_event.clear()

            try:
                user_input = await get_user_input(prompt_session)
            except EOFError:
                break

            if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]:
                break

            if not user_input.strip():
                # Nothing was submitted, so no turn will complete: re-arm the
                # event to prompt again immediately.
                turn_complete_event.set()
                continue

            submission_id += 1
            submission = Submission(
                id=f"sub_{submission_id}",
                operation=Operation(
                    op_type=OpType.USER_INPUT, data={"text": user_input}
                ),
            )

            await submission_queue.put(submission)

    except KeyboardInterrupt:
        print("\n\nInterrupted by user")

    print("\n🛑 Shutting down agent...")
    shutdown_submission = Submission(
        id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN)
    )
    await submission_queue.put(shutdown_submission)

    try:
        await asyncio.wait_for(agent_task, timeout=5.0)
    except asyncio.TimeoutError:
        # wait_for has already cancelled agent_task at this point.
        print(format_error("Agent did not shut down within 5 seconds; cancelled it"))
    finally:
        # Always stop the listener, even when the agent task raised.
        listener_task.cancel()
        await asyncio.gather(listener_task, return_exceptions=True)

    print("✨ Goodbye!\n")
|
| |
|
| |
|
if __name__ == "__main__":
    # A Ctrl-C that reaches this level ends the program with a farewell
    # message rather than a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n✨ Goodbye!")
|
| |
|