| | """loop
|
| | Main agent implementation with integrated tool system and MCP support
|
| | """
|
| |
|
| | import asyncio
|
| | import json
|
| | import logging
|
| | import os
|
| |
|
| | from litellm import ChatCompletionMessageToolCall, Message, acompletion
|
| | from lmnr import observe
|
| |
|
| | from agent.config import Config
|
| | from agent.core.session import Event, OpType, Session
|
| | from agent.core.tools import ToolRouter
|
| | from agent.tools.jobs_tool import CPU_FLAVORS
|
| |
|
| | logger = logging.getLogger(__name__)
|
| |
|
| | ToolCall = ChatCompletionMessageToolCall
|
| |
|
| |
|
| | _INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
|
| |
|
| |
|
| | def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]:
|
| | """
|
| | Validate tool arguments structure.
|
| |
|
| | Returns:
|
| | (is_valid, error_message)
|
| | """
|
| | args = tool_args.get("args", {})
|
| |
|
| | if isinstance(args, str):
|
| | return (
|
| | False,
|
| | f"Tool call error: 'args' must be a JSON object, not a string. You passed: {repr(args)}",
|
| | )
|
| | if not isinstance(args, dict) and args is not None:
|
| | return (
|
| | False,
|
| | f"Tool call error: 'args' must be a JSON object. You passed type: {type(args).__name__}",
|
| | )
|
| | return True, None
|
| |
|
| |
|
def _needs_approval(
    tool_name: str, tool_args: dict, config: Config | None = None
) -> bool:
    """Decide whether a tool call must be confirmed by the user first.

    Policy summary:
      * yolo_mode disables all confirmations.
      * Structurally invalid args never go through approval (they fail at
        execution time with a validation message instead).
      * hf_jobs: only job-launching operations need approval; CPU jobs can
        be exempted via config.confirm_cpu_jobs.
      * hf_private_repos / hf_repo_files / hf_repo_git: destructive or
        publishing operations need approval.
    """
    if config and config.yolo_mode:
        return False

    # Malformed args bypass approval — execution will reject them anyway.
    valid, _ = _validate_tool_args(tool_args)
    if not valid:
        return False

    operation = tool_args.get("operation", "")

    if tool_name == "hf_jobs":
        if operation not in ("run", "uv", "scheduled run", "scheduled uv"):
            return False

        # The flavor may arrive under several argument names; default to
        # the cheapest CPU tier.
        flavor = (
            tool_args.get("hardware_flavor")
            or tool_args.get("flavor")
            or tool_args.get("hardware")
            or "cpu-basic"
        )
        if flavor in CPU_FLAVORS and config and not config.confirm_cpu_jobs:
            return False
        return True

    if tool_name == "hf_private_repos":
        if operation == "upload_file":
            # Uploads are auto-approved only when the config opts in.
            return not (config and config.auto_file_upload)
        return operation == "create_repo"

    if tool_name == "hf_repo_files":
        return operation in ("upload", "delete")

    if tool_name == "hf_repo_git":
        return operation in (
            "delete_branch",
            "delete_tag",
            "merge_pr",
            "create_repo",
            "update_repo",
        )

    # Everything else runs without confirmation.
    return False
|
| |
|
| |
|
class Handlers:
    """Handler functions for each operation type.

    Each handler is a static coroutine taking the Session plus any
    operation payload; they communicate back to the client exclusively
    through session.send_event().
    """

    @staticmethod
    @observe(name="run_agent")
    async def run_agent(
        session: Session, text: str, max_iterations: int = 10
    ) -> str | None:
        """
        Handle user input (like user_input_or_turn in codex.rs:1291)
        Returns the final assistant response content, if any.

        An empty `text` resumes an in-progress turn (e.g. after tool
        approval) without appending a new user message. Returns None when
        the turn is suspended waiting for approval or ends without a final
        text response.
        """
        # Tie this turn's trace to the session for observability.
        if hasattr(session, "session_id"):
            from lmnr import Laminar

            Laminar.set_trace_session_id(session_id=session.session_id)

        if text:
            user_msg = Message(role="user", content=text)
            session.context_manager.add_message(user_msg)

        await session.send_event(
            Event(event_type="processing", data={"message": "Processing user input"})
        )

        iteration = 0
        final_response = None

        # Model <-> tool loop: each iteration is one streamed LLM call plus
        # any tool executions it requested, capped at max_iterations.
        while iteration < max_iterations:
            messages = session.context_manager.get_messages()
            tools = session.tool_router.get_tool_specs_for_llm()
            try:
                # Usage stats are requested so token counts can be stored
                # alongside the resulting message.
                response = await acompletion(
                    model=session.config.model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    stream=True,
                    stream_options={"include_usage": True},
                    # Only pass the HF inference token to huggingface/ models.
                    api_key=_INFERENCE_API_KEY
                    if _INFERENCE_API_KEY
                    and session.config.model_name.startswith("huggingface/")
                    else None,
                )

                full_content = ""
                # Streamed tool-call fragments, keyed by call index.
                tool_calls_acc: dict[int, dict] = {}
                token_count = 0

                async for chunk in response:
                    choice = chunk.choices[0] if chunk.choices else None
                    if not choice:
                        # Usage-only chunk (no choices) carries final counts.
                        if hasattr(chunk, "usage") and chunk.usage:
                            token_count = chunk.usage.total_tokens
                        continue

                    delta = choice.delta

                    # Forward assistant text to the client as it arrives.
                    if delta.content:
                        full_content += delta.content
                        await session.send_event(
                            Event(
                                event_type="assistant_chunk",
                                data={"content": delta.content},
                            )
                        )

                    # Tool calls stream as fragments; stitch id / name /
                    # arguments together per call index.
                    if delta.tool_calls:
                        for tc_delta in delta.tool_calls:
                            idx = tc_delta.index
                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                }
                            if tc_delta.id:
                                tool_calls_acc[idx]["id"] = tc_delta.id
                            if tc_delta.function:
                                if tc_delta.function.name:
                                    tool_calls_acc[idx]["function"]["name"] += (
                                        tc_delta.function.name
                                    )
                                if tc_delta.function.arguments:
                                    tool_calls_acc[idx]["function"]["arguments"] += (
                                        tc_delta.function.arguments
                                    )

                    if hasattr(chunk, "usage") and chunk.usage:
                        token_count = chunk.usage.total_tokens

                content = full_content or None

                # Materialize the accumulated fragments into ToolCall
                # objects in stream order.
                tool_calls: list[ToolCall] = []
                for idx in sorted(tool_calls_acc.keys()):
                    tc_data = tool_calls_acc[idx]
                    tool_calls.append(
                        ToolCall(
                            id=tc_data["id"],
                            type="function",
                            function={
                                "name": tc_data["function"]["name"],
                                "arguments": tc_data["function"]["arguments"],
                            },
                        )
                    )

                await session.send_event(
                    Event(event_type="assistant_stream_end", data={})
                )

                # No tool calls -> the model finished this turn.
                if not tool_calls:
                    if content:
                        assistant_msg = Message(role="assistant", content=content)
                        session.context_manager.add_message(assistant_msg, token_count)
                        final_response = content
                    break

                assistant_msg = Message(
                    role="assistant",
                    content=content,
                    tool_calls=tool_calls,
                )
                session.context_manager.add_message(assistant_msg, token_count)

                # Split requested calls into those that run immediately and
                # those that must wait for explicit user approval.
                approval_required_tools = []
                non_approval_tools = []

                for tc in tool_calls:
                    tool_name = tc.function.name
                    try:
                        tool_args = json.loads(tc.function.arguments)
                    except (json.JSONDecodeError, TypeError) as e:
                        logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
                        tool_args = {}

                    if _needs_approval(tool_name, tool_args, session.config):
                        approval_required_tools.append(tc)
                    else:
                        non_approval_tools.append(tc)

                if non_approval_tools:
                    # Parse and validate up front so invalid calls can
                    # short-circuit with an error message instead of running.
                    # Tuples: (tool_call, name, args, args_valid, error_msg)
                    parsed_tools: list[
                        tuple[ChatCompletionMessageToolCall, str, dict, bool, str]
                    ] = []
                    for tc in non_approval_tools:
                        tool_name = tc.function.name
                        try:
                            tool_args = json.loads(tc.function.arguments)
                        except (json.JSONDecodeError, TypeError):
                            tool_args = {}

                        args_valid, error_msg = _validate_tool_args(tool_args)
                        parsed_tools.append(
                            (tc, tool_name, tool_args, args_valid, error_msg)
                        )

                    # Announce valid calls to the client before executing.
                    for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
                        if args_valid:
                            await session.send_event(
                                Event(
                                    event_type="tool_call",
                                    data={
                                        "tool": tool_name,
                                        "arguments": tool_args,
                                        "tool_call_id": tc.id,
                                    },
                                )
                            )

                    async def _exec_tool(
                        tc: ChatCompletionMessageToolCall,
                        name: str,
                        args: dict,
                        valid: bool,
                        err: str,
                    ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]:
                        # An invalid call's "output" is the validation error,
                        # so the model can see it and retry.
                        if not valid:
                            return (tc, name, args, err, False)
                        out, ok = await session.tool_router.call_tool(
                            name, args, session=session
                        )
                        return (tc, name, args, out, ok)

                    # Run all non-approval tools concurrently.
                    results = await asyncio.gather(
                        *[
                            _exec_tool(tc, name, args, valid, err)
                            for tc, name, args, valid, err in parsed_tools
                        ]
                    )

                    # Record each result in context and notify the client.
                    for tc, tool_name, tool_args, output, success in results:
                        tool_msg = Message(
                            role="tool",
                            content=output,
                            tool_call_id=tc.id,
                            name=tool_name,
                        )
                        session.context_manager.add_message(tool_msg)

                        await session.send_event(
                            Event(
                                event_type="tool_output",
                                data={
                                    "tool": tool_name,
                                    "tool_call_id": tc.id,
                                    "output": output,
                                    "success": success,
                                },
                            )
                        )

                if approval_required_tools:
                    # Surface pending calls to the client; the turn suspends
                    # here and exec_approval resumes it after the decision.
                    tools_data = []
                    for tc in approval_required_tools:
                        tool_name = tc.function.name
                        try:
                            tool_args = json.loads(tc.function.arguments)
                        except (json.JSONDecodeError, TypeError):
                            tool_args = {}
                        tools_data.append(
                            {
                                "tool": tool_name,
                                "arguments": tool_args,
                                "tool_call_id": tc.id,
                            }
                        )

                    await session.send_event(
                        Event(
                            event_type="approval_required",
                            data={
                                "tools": tools_data,
                                "count": len(tools_data),
                            },
                        )
                    )

                    session.pending_approval = {
                        "tool_calls": approval_required_tools,
                    }

                    # Suspend the turn until an EXEC_APPROVAL arrives.
                    return None

                iteration += 1

            except Exception as e:
                import traceback

                await session.send_event(
                    Event(
                        event_type="error",
                        data={"error": str(e) + "\n" + traceback.format_exc()},
                    )
                )
                break

        # Turn finished: compact the context if needed, then notify.
        old_length = session.context_manager.context_length
        await session.context_manager.compact(model_name=session.config.model_name)
        new_length = session.context_manager.context_length

        if new_length != old_length:
            await session.send_event(
                Event(
                    event_type="compacted",
                    data={"old_tokens": old_length, "new_tokens": new_length},
                )
            )

        await session.send_event(
            Event(
                event_type="turn_complete",
                data={"history_size": len(session.context_manager.items)},
            )
        )

        session.increment_turn()
        await session.auto_save_if_needed()

        return final_response

    @staticmethod
    async def interrupt(session: Session) -> None:
        """Handle interrupt (like interrupt in codex.rs:1266)"""
        session.interrupt()
        await session.send_event(Event(event_type="interrupted"))

    @staticmethod
    async def compact(session: Session) -> None:
        """Handle compact (like compact in codex.rs:1317)

        Compacts the context and reports old/new token lengths.
        """
        old_length = session.context_manager.context_length
        await session.context_manager.compact(model_name=session.config.model_name)
        new_length = session.context_manager.context_length

        # NOTE(review): the keys here ("removed"/"remaining") actually carry
        # the old and new context lengths — confirm consumers expect that.
        await session.send_event(
            Event(
                event_type="compacted",
                data={"removed": old_length, "remaining": new_length},
            )
        )

    @staticmethod
    async def undo(session: Session) -> None:
        """Remove the last complete turn (user msg + all assistant/tool msgs that follow).

        Anthropic requires every tool_use to have a matching tool_result,
        so we can't just pop 2 items — we must pop everything back to
        (and including) the last user message to keep the history valid.
        """
        items = session.context_manager.items
        if not items:
            await session.send_event(Event(event_type="undo_complete"))
            return

        # Pop from the end until (and including) the most recent user message.
        removed_user = False
        while items:
            msg = items.pop()
            if getattr(msg, "role", None) == "user":
                removed_user = True
                break

        if not removed_user:
            logger.warning("Undo: no user message found to remove")

        await session.send_event(Event(event_type="undo_complete"))

    @staticmethod
    async def exec_approval(session: Session, approvals: list[dict]) -> None:
        """Handle batch job execution approval.

        Each entry of `approvals` is expected to carry at least
        "tool_call_id" and "approved" (and optionally "feedback").
        Approved calls execute concurrently; rejected calls get a
        cancellation tool message so the history stays valid. Finally the
        agent loop is resumed with an empty input.
        """
        if not session.pending_approval:
            await session.send_event(
                Event(
                    event_type="error",
                    data={"error": "No pending approval to process"},
                )
            )
            return

        tool_calls = session.pending_approval.get("tool_calls", [])
        if not tool_calls:
            await session.send_event(
                Event(
                    event_type="error",
                    data={"error": "No pending tool calls found"},
                )
            )
            return

        # Index the user's decisions by tool_call_id; missing ids are
        # treated as rejections below.
        approval_map = {a["tool_call_id"]: a for a in approvals}

        approved_tasks = []
        rejected_tasks = []

        for tc in tool_calls:
            tool_name = tc.function.name
            # NOTE(review): unlike run_agent, this parse has no
            # JSONDecodeError guard — malformed arguments that were
            # defaulted to {} earlier would raise here. Confirm intended.
            tool_args = json.loads(tc.function.arguments)
            approval_decision = approval_map.get(tc.id, {"approved": False})

            if approval_decision.get("approved", False):
                approved_tasks.append((tc, tool_name, tool_args))
            else:
                rejected_tasks.append((tc, tool_name, approval_decision))

        async def execute_tool(tc, tool_name, tool_args):
            """Execute a single tool and return its result"""
            await session.send_event(
                Event(
                    event_type="tool_call",
                    data={
                        "tool": tool_name,
                        "arguments": tool_args,
                        "tool_call_id": tc.id,
                    },
                )
            )

            output, success = await session.tool_router.call_tool(
                tool_name, tool_args, session=session
            )

            return (tc, tool_name, output, success)

        if approved_tasks:
            # Run approved tools concurrently; exceptions are collected
            # rather than cancelling the batch.
            results = await asyncio.gather(
                *[
                    execute_tool(tc, tool_name, tool_args)
                    for tc, tool_name, tool_args in approved_tasks
                ],
                return_exceptions=True,
            )

            for result in results:
                if isinstance(result, Exception):
                    # NOTE(review): a failed task produces no tool message,
                    # leaving its tool_use without a tool_result — verify
                    # downstream providers tolerate that.
                    logger.error(f"Tool execution error: {result}")
                    continue

                tc, tool_name, output, success = result

                tool_msg = Message(
                    role="tool",
                    content=output,
                    tool_call_id=tc.id,
                    name=tool_name,
                )
                session.context_manager.add_message(tool_msg)

                await session.send_event(
                    Event(
                        event_type="tool_output",
                        data={
                            "tool": tool_name,
                            "tool_call_id": tc.id,
                            "output": output,
                            "success": success,
                        },
                    )
                )

        # Rejected calls still need a tool message so the history stays
        # consistent; include any user feedback for the model.
        for tc, tool_name, approval_decision in rejected_tasks:
            rejection_msg = "Job execution cancelled by user"
            user_feedback = approval_decision.get("feedback")
            if user_feedback:
                rejection_msg += f". User feedback: {user_feedback}"

            tool_msg = Message(
                role="tool",
                content=rejection_msg,
                tool_call_id=tc.id,
                name=tool_name,
            )
            session.context_manager.add_message(tool_msg)

            await session.send_event(
                Event(
                    event_type="tool_output",
                    data={
                        "tool": tool_name,
                        "tool_call_id": tc.id,
                        "output": rejection_msg,
                        "success": False,
                    },
                )
            )

        session.pending_approval = None

        # Resume the suspended turn; empty text adds no user message.
        await Handlers.run_agent(session, "")

    @staticmethod
    async def shutdown(session: Session) -> bool:
        """Handle shutdown (like shutdown in codex.rs:1329)

        Kicks off a detached session save (when enabled), stops the loop,
        and emits a shutdown event. Always returns True.
        """
        if session.config.save_sessions:
            logger.info("Saving session...")
            repo_id = session.config.session_dataset_repo
            # Detached: the upload continues in the background.
            _ = session.save_and_upload_detached(repo_id)

        session.is_running = False
        await session.send_event(Event(event_type="shutdown"))
        return True
|
| |
|
| |
|
async def process_submission(session: Session, submission) -> bool:
    """
    Process a single submission and return whether to continue running.

    Dispatches on the submission's operation type to the matching Handlers
    coroutine. Unknown operation types are logged and ignored.

    Returns:
        bool: True to continue, False to shutdown
    """
    op = submission.operation
    logger.debug("Received operation: %s", op.op_type.value)

    payload = op.data if op.data else {}

    # Shutdown is the only operation that can stop the loop.
    if op.op_type == OpType.SHUTDOWN:
        return not await Handlers.shutdown(session)

    if op.op_type == OpType.USER_INPUT:
        await Handlers.run_agent(session, payload.get("text", ""))
    elif op.op_type == OpType.INTERRUPT:
        await Handlers.interrupt(session)
    elif op.op_type == OpType.COMPACT:
        await Handlers.compact(session)
    elif op.op_type == OpType.UNDO:
        await Handlers.undo(session)
    elif op.op_type == OpType.EXEC_APPROVAL:
        await Handlers.exec_approval(session, payload.get("approvals", []))
    else:
        logger.warning(f"Unknown operation: {op.op_type}")

    return True
|
| |
|
| |
|
@observe(name="submission_loop")
async def submission_loop(
    submission_queue: asyncio.Queue,
    event_queue: asyncio.Queue,
    config: Config | None = None,
    tool_router: ToolRouter | None = None,
) -> None:
    """
    Main agent loop - processes submissions and dispatches to handlers.
    This is the core of the agent (like submission_loop in codex.rs:1259-1340)

    Args:
        submission_queue: Queue of incoming submissions to process.
        event_queue: Queue the session publishes Events to.
        config: Optional agent configuration.
        tool_router: Optional tool router; when provided it is entered as an
            async context for the lifetime of the loop.
    """
    from contextlib import AsyncExitStack

    session = Session(event_queue, config=config, tool_router=tool_router)
    logger.info("Agent loop started")

    # Retry any session uploads that failed in previous runs (detached,
    # best-effort).
    if config and config.save_sessions:
        Session.retry_failed_uploads_detached(
            directory="session_logs", repo_id=config.session_dataset_repo
        )

    try:
        # Fix: `tool_router` defaults to None, and `async with None` raises
        # TypeError at startup. Enter the router's context only when one
        # was actually provided; behavior is unchanged otherwise.
        async with AsyncExitStack() as stack:
            if tool_router is not None:
                await stack.enter_async_context(tool_router)

            await session.send_event(
                Event(event_type="ready", data={"message": "Agent initialized"})
            )

            while session.is_running:
                submission = await submission_queue.get()

                try:
                    should_continue = await process_submission(session, submission)
                    if not should_continue:
                        break
                except asyncio.CancelledError:
                    logger.warning("Agent loop cancelled")
                    break
                except Exception as e:
                    # Per-submission failures are reported but don't kill
                    # the loop.
                    logger.error(f"Error in agent loop: {e}")
                    await session.send_event(
                        Event(event_type="error", data={"error": str(e)})
                    )

        logger.info("Agent loop exited")

    finally:
        # If we exited abnormally (session still marked running), make a
        # best-effort save so the conversation is not lost.
        if session.config.save_sessions and session.is_running:
            logger.info("Emergency save: preserving session before exit...")
            try:
                local_path = session.save_and_upload_detached(
                    session.config.session_dataset_repo
                )
                if local_path:
                    logger.info("Emergency save successful, upload in progress")
            except Exception as e:
                logger.error(f"Emergency save failed: {e}")
|
| |
|