""" Interactive CLI chat with the agent """ import asyncio import json import os from dataclasses import dataclass from pathlib import Path from typing import Any, Optional import litellm from lmnr import Laminar, LaminarLiteLLMCallback from prompt_toolkit import PromptSession from agent.config import load_config from agent.core.agent_loop import submission_loop from agent.core.session import OpType from agent.core.tools import ToolRouter from agent.utils.reliability_checks import check_training_script_save_pattern from agent.utils.terminal_display import ( format_error, format_header, format_plan_display, format_separator, format_success, format_tool_call, format_tool_output, format_turn_complete, ) litellm.drop_params = True def _safe_get_args(arguments: dict) -> dict: """Safely extract args dict from arguments, handling cases where LLM passes string.""" args = arguments.get("args", {}) # Sometimes LLM passes args as string instead of dict if isinstance(args, str): return {} return args if isinstance(args, dict) else {} lmnr_api_key = os.environ.get("LMNR_API_KEY") if lmnr_api_key: try: Laminar.initialize(project_api_key=lmnr_api_key) litellm.callbacks = [LaminarLiteLLMCallback()] print("Laminar initialized") except Exception as e: print(f"Failed to initialize Laminar: {e}") @dataclass class Operation: """Operation to be executed by the agent""" op_type: OpType data: Optional[dict[str, Any]] = None @dataclass class Submission: """Submission to the agent loop""" id: str operation: Operation async def event_listener( event_queue: asyncio.Queue, submission_queue: asyncio.Queue, turn_complete_event: asyncio.Event, ready_event: asyncio.Event, prompt_session: PromptSession, config=None, ) -> None: """Background task that listens for events and displays them""" submission_id = [1000] # Use list to make it mutable in closure last_tool_name = [None] # Track last tool called while True: try: event = await event_queue.get() # Display event if event.event_type == "ready": print(format_success("\U0001f917 Agent ready")) ready_event.set() elif event.event_type == "assistant_message": content = event.data.get("content", "") if event.data else "" if content: print(f"\nAssistant: {content}") elif event.event_type == "tool_call": tool_name = event.data.get("tool", "") if event.data else "" arguments = event.data.get("arguments", {}) if event.data else {} if tool_name: last_tool_name[0] = tool_name # Store for tool_output event args_str = json.dumps(arguments)[:100] + "..." print(format_tool_call(tool_name, args_str)) elif event.event_type == "tool_output": output = event.data.get("output", "") if event.data else "" success = event.data.get("success", False) if event.data else False if output: # Don't truncate plan_tool output, truncate everything else should_truncate = last_tool_name[0] != "plan_tool" print(format_tool_output(output, success, truncate=should_truncate)) elif event.event_type == "turn_complete": print(format_turn_complete()) # Display plan after turn complete plan_display = format_plan_display() if plan_display: print(plan_display) turn_complete_event.set() elif event.event_type == "error": error = ( event.data.get("error", "Unknown error") if event.data else "Unknown error" ) print(format_error(error)) turn_complete_event.set() elif event.event_type == "shutdown": break elif event.event_type == "processing": pass # print("Processing...", flush=True) elif event.event_type == "compacted": old_tokens = event.data.get("old_tokens", 0) if event.data else 0 new_tokens = event.data.get("new_tokens", 0) if event.data else 0 print(f"Compacted context: {old_tokens} → {new_tokens} tokens") elif event.event_type == "approval_required": # Handle batch approval format tools_data = event.data.get("tools", []) if event.data else [] count = event.data.get("count", 0) if event.data else 0 # If yolo mode is active, auto-approve everything if config and config.yolo_mode: approvals = [ { "tool_call_id": t.get("tool_call_id", ""), "approved": True, "feedback": None, } for t in tools_data ] print(f"\n⚡ YOLO MODE: Auto-approving {count} item(s)") submission_id[0] += 1 approval_submission = Submission( id=f"approval_{submission_id[0]}", operation=Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}, ), ) await submission_queue.put(approval_submission) continue print("\n" + format_separator()) print( format_header( f"APPROVAL REQUIRED ({count} item{'s' if count != 1 else ''})" ) ) print(format_separator()) approvals = [] # Ask for approval for each tool for i, tool_info in enumerate(tools_data, 1): tool_name = tool_info.get("tool", "") arguments = tool_info.get("arguments", {}) tool_call_id = tool_info.get("tool_call_id", "") # Handle case where arguments might be a JSON string if isinstance(arguments, str): try: arguments = json.loads(arguments) except json.JSONDecodeError: print(f"Warning: Failed to parse arguments for {tool_name}") arguments = {} operation = arguments.get("operation", "") print(f"\n[Item {i}/{count}]") print(f"Tool: {tool_name}") print(f"Operation: {operation}") # Handle different tool types if tool_name == "hf_jobs": # Check if this is Python mode (script) or Docker mode (command) script = arguments.get("script") command = arguments.get("command") if script: # Python mode dependencies = arguments.get("dependencies", []) python_version = arguments.get("python") script_args = arguments.get("script_args", []) # Show full script print(f"Script:\n{script}") if dependencies: print(f"Dependencies: {', '.join(dependencies)}") if python_version: print(f"Python version: {python_version}") if script_args: print(f"Script args: {' '.join(script_args)}") # Run reliability checks on the full script (not truncated) check_message = check_training_script_save_pattern(script) if check_message: print(check_message) elif command: # Docker mode image = arguments.get("image", "python:3.12") command_str = ( " ".join(command) if isinstance(command, list) else str(command) ) print(f"Docker image: {image}") print(f"Command: {command_str}") # Common parameters for jobs hardware_flavor = arguments.get("hardware_flavor", "cpu-basic") timeout = arguments.get("timeout", "30m") env = arguments.get("env", {}) schedule = arguments.get("schedule") print(f"Hardware: {hardware_flavor}") print(f"Timeout: {timeout}") if env: env_keys = ", ".join(env.keys()) print(f"Environment variables: {env_keys}") if schedule: print(f"Schedule: {schedule}") elif tool_name == "hf_private_repos": # Handle private repo operations args = _safe_get_args(arguments) if operation in ["create_repo", "upload_file"]: repo_id = args.get("repo_id", "") repo_type = args.get("repo_type", "dataset") # Build repo URL type_path = "" if repo_type == "model" else f"{repo_type}s" repo_url = ( f"https://huggingface.co/{type_path}/{repo_id}".replace( "//", "/" ) ) print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print("Private: Yes") print(f"URL: {repo_url}") # Show file preview for upload_file operation if operation == "upload_file": path_in_repo = args.get("path_in_repo", "") file_content = args.get("file_content", "") print(f"File: {path_in_repo}") if isinstance(file_content, str): # Calculate metrics all_lines = file_content.split("\n") line_count = len(all_lines) size_bytes = len(file_content.encode("utf-8")) size_kb = size_bytes / 1024 size_mb = size_kb / 1024 print(f"Line count: {line_count}") if size_kb < 1024: print(f"Size: {size_kb:.2f} KB") else: print(f"Size: {size_mb:.2f} MB") # Show preview preview_lines = all_lines[:5] preview = "\n".join(preview_lines) print( f"Content preview (first 5 lines):\n{preview}" ) if len(all_lines) > 5: print("...") elif tool_name == "hf_repo_files": # Handle repo files operations (upload, delete) repo_id = arguments.get("repo_id", "") repo_type = arguments.get("repo_type", "model") revision = arguments.get("revision", "main") # Build repo URL if repo_type == "model": repo_url = f"https://huggingface.co/{repo_id}" else: repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}" print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print(f"Branch: {revision}") print(f"URL: {repo_url}") if operation == "upload": path = arguments.get("path", "") content = arguments.get("content", "") create_pr = arguments.get("create_pr", False) print(f"File: {path}") if create_pr: print("Mode: Create PR") if isinstance(content, str): all_lines = content.split("\n") line_count = len(all_lines) size_bytes = len(content.encode("utf-8")) size_kb = size_bytes / 1024 print(f"Lines: {line_count}") if size_kb < 1024: print(f"Size: {size_kb:.2f} KB") else: print(f"Size: {size_kb / 1024:.2f} MB") # Show full content print(f"Content:\n{content}") elif operation == "delete": patterns = arguments.get("patterns", []) if isinstance(patterns, str): patterns = [patterns] print(f"Patterns to delete: {', '.join(patterns)}") elif tool_name == "hf_repo_git": # Handle git operations (branches, tags, PRs, repo management) repo_id = arguments.get("repo_id", "") repo_type = arguments.get("repo_type", "model") # Build repo URL if repo_type == "model": repo_url = f"https://huggingface.co/{repo_id}" else: repo_url = f"https://huggingface.co/{repo_type}s/{repo_id}" print(f"Repository: {repo_id}") print(f"Type: {repo_type}") print(f"URL: {repo_url}") if operation == "delete_branch": branch = arguments.get("branch", "") print(f"Branch to delete: {branch}") elif operation == "delete_tag": tag = arguments.get("tag", "") print(f"Tag to delete: {tag}") elif operation == "merge_pr": pr_num = arguments.get("pr_num", "") print(f"PR to merge: #{pr_num}") elif operation == "create_repo": private = arguments.get("private", False) space_sdk = arguments.get("space_sdk") print(f"Private: {private}") if space_sdk: print(f"Space SDK: {space_sdk}") elif operation == "update_repo": private = arguments.get("private") gated = arguments.get("gated") if private is not None: print(f"Private: {private}") if gated is not None: print(f"Gated: {gated}") # Get user decision for this item response = await prompt_session.prompt_async( f"Approve item {i}? (y=yes, yolo=approve all, n=no, or provide feedback): " ) response = response.strip().lower() # Handle yolo mode activation if response == "yolo": config.yolo_mode = True print( "⚡ YOLO MODE ACTIVATED - Auto-approving all future tool calls" ) # Auto-approve this item and all remaining approvals.append( { "tool_call_id": tool_call_id, "approved": True, "feedback": None, } ) for remaining in tools_data[i:]: approvals.append( { "tool_call_id": remaining.get("tool_call_id", ""), "approved": True, "feedback": None, } ) break approved = response in ["y", "yes"] feedback = None if approved or response in ["n", "no"] else response approvals.append( { "tool_call_id": tool_call_id, "approved": approved, "feedback": feedback, } ) # Submit batch approval submission_id[0] += 1 approval_submission = Submission( id=f"approval_{submission_id[0]}", operation=Operation( op_type=OpType.EXEC_APPROVAL, data={"approvals": approvals}, ), ) await submission_queue.put(approval_submission) print(format_separator() + "\n") # Silently ignore other events except asyncio.CancelledError: break except Exception as e: print(f"Event listener error: {e}") async def get_user_input(prompt_session: PromptSession) -> str: """Get user input asynchronously""" from prompt_toolkit.formatted_text import HTML return await prompt_session.prompt_async(HTML("\n> ")) async def main(): """Interactive chat with the agent""" from agent.utils.terminal_display import Colors # Clear screen os.system("clear" if os.name != "nt" else "cls") banner = r""" _ _ _ _____ _ _ | | | |_ _ __ _ __ _(_)_ __ __ _ | ___|_ _ ___ ___ / \ __ _ ___ _ __ | |_ | |_| | | | |/ _` |/ _` | | '_ \ / _` | | |_ / _` |/ __/ _ \ / _ \ / _` |/ _ \ '_ \| __| | _ | |_| | (_| | (_| | | | | | (_| | | _| (_| | (_| __/ / ___ \ (_| | __/ | | | |_ |_| |_|\__,_|\__, |\__, |_|_| |_|\__, | |_| \__,_|\___\___| /_/ \_\__, |\___|_| |_|\__| |___/ |___/ |___/ |___/ """ print(format_separator()) print(f"{Colors.YELLOW} {banner}{Colors.RESET}") print("Type your messages below. Type 'exit', 'quit', or '/quit' to end.\n") print(format_separator()) # Wait for agent to initialize print("Initializing agent...") # Create queues for communication submission_queue = asyncio.Queue() event_queue = asyncio.Queue() # Events to signal agent state turn_complete_event = asyncio.Event() turn_complete_event.set() ready_event = asyncio.Event() # Start agent loop in background config_path = Path(__file__).parent.parent / "configs" / "main_agent_config.json" config = load_config(config_path) # Create tool router print(f"Loading MCP servers: {', '.join(config.mcpServers.keys())}") tool_router = ToolRouter(config.mcpServers) # Create prompt session for input prompt_session = PromptSession() agent_task = asyncio.create_task( submission_loop( submission_queue, event_queue, config=config, tool_router=tool_router, ) ) # Start event listener in background listener_task = asyncio.create_task( event_listener( event_queue, submission_queue, turn_complete_event, ready_event, prompt_session, config, ) ) await ready_event.wait() submission_id = 0 try: while True: # Wait for previous turn to complete await turn_complete_event.wait() turn_complete_event.clear() # Get user input try: user_input = await get_user_input(prompt_session) except EOFError: break # Check for exit commands if user_input.strip().lower() in ["exit", "quit", "/quit", "/exit"]: break # Skip empty input if not user_input.strip(): turn_complete_event.set() continue # Submit to agent submission_id += 1 submission = Submission( id=f"sub_{submission_id}", operation=Operation( op_type=OpType.USER_INPUT, data={"text": user_input} ), ) # print(f"Main submitting: {submission.operation.op_type}") await submission_queue.put(submission) except KeyboardInterrupt: print("\n\nInterrupted by user") # Shutdown print("\n🛑 Shutting down agent...") shutdown_submission = Submission( id="sub_shutdown", operation=Operation(op_type=OpType.SHUTDOWN) ) await submission_queue.put(shutdown_submission) await asyncio.wait_for(agent_task, timeout=5.0) listener_task.cancel() print("✨ Goodbye!\n") if __name__ == "__main__": try: asyncio.run(main()) except KeyboardInterrupt: print("\n\n✨ Goodbye!")