| | """loop |
| | Main agent implementation with integrated tool system and MCP support |
| | """ |
| |
|
| | import asyncio |
| | import json |
| | import logging |
| | import os |
| |
|
| | from litellm import ChatCompletionMessageToolCall, Message, acompletion |
| | from lmnr import observe |
| |
|
| | from agent.config import Config |
| | from agent.core.session import Event, OpType, Session |
| | from agent.core.tools import ToolRouter |
| | from agent.tools.jobs_tool import CPU_FLAVORS |
| |
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Short alias for litellm's tool-call type, used when rebuilding streamed
# tool calls in Handlers.run_agent.
ToolCall = ChatCompletionMessageToolCall

# API key read once, at import time, from the INFERENCE_TOKEN environment
# variable (None when the variable is unset). Handlers.run_agent forwards it
# to the completion call only for model names starting with "huggingface/".
_INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN")
| |
|
| |
|
| | def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: |
| | """ |
| | Validate tool arguments structure. |
| | |
| | Returns: |
| | (is_valid, error_message) |
| | """ |
| | args = tool_args.get("args", {}) |
| | |
| | if isinstance(args, str): |
| | return ( |
| | False, |
| | f"Tool call error: 'args' must be a JSON object, not a string. You passed: {repr(args)}", |
| | ) |
| | if not isinstance(args, dict) and args is not None: |
| | return ( |
| | False, |
| | f"Tool call error: 'args' must be a JSON object. You passed type: {type(args).__name__}", |
| | ) |
| | return True, None |
| |
|
| |
|
def _needs_approval(
    tool_name: str, tool_args: dict, config: Config | None = None
) -> bool:
    """Decide whether a tool call must be confirmed by the user before it runs.

    Args:
        tool_name: Name of the tool the model wants to call.
        tool_args: Decoded arguments of the call; the ``"operation"`` entry
            selects the action for the tools handled here.
        config: Optional configuration. ``yolo_mode`` disables every
            confirmation, ``confirm_cpu_jobs`` controls CPU-only jobs and
            ``auto_file_upload`` controls ``upload_file``.

    Returns:
        True when the call has to wait for user approval, False otherwise.
    """
    # YOLO mode turns every confirmation off.
    if config and config.yolo_mode:
        return False

    # Structurally invalid calls are never queued for approval; the caller
    # reports the validation error back to the model instead.
    is_valid, _unused_error = _validate_tool_args(tool_args)
    if not is_valid:
        return False

    op = tool_args.get("operation", "")

    if tool_name == "hf_jobs":
        # Only operations that actually launch a job are gated.
        if op not in ("run", "uv", "scheduled run", "scheduled uv"):
            return False

        # The flavor may arrive under several keys; default to the basic CPU.
        flavor = (
            tool_args.get("hardware_flavor")
            or tool_args.get("flavor")
            or tool_args.get("hardware")
            or "cpu-basic"
        )
        if flavor in CPU_FLAVORS:
            # CPU jobs are confirmed unless the config explicitly opts out.
            return not (config and not config.confirm_cpu_jobs)
        # Any non-CPU hardware always needs confirmation.
        return True

    if tool_name == "hf_private_repos":
        if op == "upload_file":
            # Uploads are confirmed unless auto-upload is enabled.
            return not (config and config.auto_file_upload)
        return op == "create_repo"

    # Remaining tools: a fixed set of mutating operations per tool.
    gated_operations = {
        "hf_repo_files": ("upload", "delete"),
        "hf_repo_git": (
            "delete_branch",
            "delete_tag",
            "merge_pr",
            "create_repo",
            "update_repo",
        ),
    }
    return op in gated_operations.get(tool_name, ())
| |
|
| |
|
class Handlers:
    """Handler functions for each operation type.

    Every handler is a static coroutine that receives the active ``Session``;
    ``process_submission`` dispatches to them based on the operation type.
    """

    @staticmethod
    def _parse_tool_args(tc: ToolCall) -> dict:
        """Decode a tool call's JSON argument string into a dict.

        Malformed JSON, and JSON that decodes to something other than an
        object, both yield an empty dict (a warning is logged) so that every
        caller can rely on getting a mapping back.
        """
        try:
            parsed = json.loads(tc.function.arguments)
        except (json.JSONDecodeError, TypeError) as e:
            logger.warning(f"Malformed tool arguments for {tc.function.name}: {e}")
            return {}
        if not isinstance(parsed, dict):
            logger.warning(
                "Tool arguments for %s are not a JSON object (got %s)",
                tc.function.name,
                type(parsed).__name__,
            )
            return {}
        return parsed

    @staticmethod
    @observe(name="run_agent")
    async def run_agent(
        session: Session, text: str, max_iterations: int = 10
    ) -> str | None:
        """
        Handle user input (like user_input_or_turn in codex.rs:1291).

        Adds ``text`` (when non-empty) to the conversation, then repeatedly
        streams a completion and runs the requested tool calls until the model
        answers without tool calls, a tool call needs user approval, an error
        occurs, or ``max_iterations`` rounds have run.

        Args:
            session: Active session (context manager, tool router, config).
            text: User message; an empty string continues the conversation
                without adding a user message (as ``exec_approval`` does).
            max_iterations: Upper bound on completion/tool rounds.

        Returns:
            The final assistant response content, if any. ``None`` is also
            returned as soon as tool calls are parked for approval; in that
            case compaction, the ``turn_complete`` event and auto-save are
            skipped for this call.
        """
        # Tag the trace with the session id when the session exposes one.
        if hasattr(session, "session_id"):
            from lmnr import Laminar

            Laminar.set_trace_session_id(session_id=session.session_id)

        # Add the user message to the history (skipped for empty text).
        if text:
            user_msg = Message(role="user", content=text)
            session.context_manager.add_message(user_msg)

        await session.send_event(
            Event(event_type="processing", data={"message": "Processing user input"})
        )

        iteration = 0
        final_response = None

        while iteration < max_iterations:
            messages = session.context_manager.get_messages()
            tools = session.tool_router.get_tool_specs_for_llm()
            try:
                response = await acompletion(
                    model=session.config.model_name,
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    stream=True,
                    stream_options={"include_usage": True},
                    # The inference token is only forwarded for huggingface/ models.
                    api_key=_INFERENCE_API_KEY
                    if _INFERENCE_API_KEY
                    and session.config.model_name.startswith("huggingface/")
                    else None,
                )

                full_content = ""
                # Tool calls arrive as fragments keyed by their stream index.
                tool_calls_acc: dict[int, dict] = {}
                token_count = 0

                async for chunk in response:
                    choice = chunk.choices[0] if chunk.choices else None
                    if not choice:
                        # A chunk without choices may still carry the usage totals.
                        if hasattr(chunk, "usage") and chunk.usage:
                            token_count = chunk.usage.total_tokens
                        continue

                    delta = choice.delta

                    # Forward text to the client as it streams in.
                    if delta.content:
                        full_content += delta.content
                        await session.send_event(
                            Event(
                                event_type="assistant_chunk",
                                data={"content": delta.content},
                            )
                        )

                    # Accumulate tool-call fragments (id, name, argument text).
                    if delta.tool_calls:
                        for tc_delta in delta.tool_calls:
                            idx = tc_delta.index
                            if idx not in tool_calls_acc:
                                tool_calls_acc[idx] = {
                                    "id": "",
                                    "type": "function",
                                    "function": {"name": "", "arguments": ""},
                                }
                            if tc_delta.id:
                                tool_calls_acc[idx]["id"] = tc_delta.id
                            if tc_delta.function:
                                if tc_delta.function.name:
                                    tool_calls_acc[idx]["function"]["name"] += (
                                        tc_delta.function.name
                                    )
                                if tc_delta.function.arguments:
                                    tool_calls_acc[idx]["function"]["arguments"] += (
                                        tc_delta.function.arguments
                                    )

                    if hasattr(chunk, "usage") and chunk.usage:
                        token_count = chunk.usage.total_tokens

                content = full_content or None

                # Rebuild complete tool-call objects in stream-index order.
                tool_calls: list[ToolCall] = [
                    ToolCall(
                        id=tc_data["id"],
                        type="function",
                        function={
                            "name": tc_data["function"]["name"],
                            "arguments": tc_data["function"]["arguments"],
                        },
                    )
                    for _, tc_data in sorted(tool_calls_acc.items())
                ]

                await session.send_event(
                    Event(event_type="assistant_stream_end", data={})
                )

                # No tool calls: record the answer (if any) and finish the turn.
                if not tool_calls:
                    if content:
                        assistant_msg = Message(role="assistant", content=content)
                        session.context_manager.add_message(assistant_msg, token_count)
                        final_response = content
                    break

                assistant_msg = Message(
                    role="assistant",
                    content=content,
                    tool_calls=tool_calls,
                )
                session.context_manager.add_message(assistant_msg, token_count)

                # Decode each call's arguments once and split the calls by
                # whether they need user approval.
                approval_required_tools: list[tuple[ToolCall, str, dict]] = []
                non_approval_tools: list[tuple[ToolCall, str, dict]] = []

                for tc in tool_calls:
                    tool_name = tc.function.name
                    tool_args = Handlers._parse_tool_args(tc)
                    if _needs_approval(tool_name, tool_args, session.config):
                        approval_required_tools.append((tc, tool_name, tool_args))
                    else:
                        non_approval_tools.append((tc, tool_name, tool_args))

                if non_approval_tools:
                    # Validate up front so invalid calls are answered with an
                    # error message instead of being executed.
                    parsed_tools: list[
                        tuple[ToolCall, str, dict, bool, str | None]
                    ] = []
                    for tc, tool_name, tool_args in non_approval_tools:
                        args_valid, error_msg = _validate_tool_args(tool_args)
                        parsed_tools.append(
                            (tc, tool_name, tool_args, args_valid, error_msg)
                        )

                    # Announce every valid call before any of them starts.
                    for tc, tool_name, tool_args, args_valid, _ in parsed_tools:
                        if args_valid:
                            await session.send_event(
                                Event(
                                    event_type="tool_call",
                                    data={
                                        "tool": tool_name,
                                        "arguments": tool_args,
                                        "tool_call_id": tc.id,
                                    },
                                )
                            )

                    async def _exec_tool(
                        tc: ToolCall,
                        name: str,
                        args: dict,
                        valid: bool,
                        err: str | None,
                    ) -> tuple[ToolCall, str, dict, str, bool]:
                        """Run one tool; failures become an unsuccessful result."""
                        if not valid:
                            return (tc, name, args, err, False)
                        try:
                            out, ok = await session.tool_router.call_tool(
                                name, args, session=session
                            )
                        except Exception as exc:
                            # Every tool call recorded on the assistant message
                            # must get a tool result; letting one failure
                            # propagate would abort the whole round and leave
                            # all calls of this round unanswered in history.
                            logger.exception("Tool execution error in %s", name)
                            return (
                                tc,
                                name,
                                args,
                                f"Tool execution error: {exc}",
                                False,
                            )
                        return (tc, name, args, out, ok)

                    # Run the tools concurrently; results keep submission order.
                    results = await asyncio.gather(
                        *[
                            _exec_tool(tc, name, args, valid, err)
                            for tc, name, args, valid, err in parsed_tools
                        ]
                    )

                    for tc, tool_name, tool_args, output, success in results:
                        tool_msg = Message(
                            role="tool",
                            content=output,
                            tool_call_id=tc.id,
                            name=tool_name,
                        )
                        session.context_manager.add_message(tool_msg)

                        await session.send_event(
                            Event(
                                event_type="tool_output",
                                data={
                                    "tool": tool_name,
                                    "tool_call_id": tc.id,
                                    "output": output,
                                    "success": success,
                                },
                            )
                        )

                if approval_required_tools:
                    tools_data = [
                        {
                            "tool": tool_name,
                            "arguments": tool_args,
                            "tool_call_id": tc.id,
                        }
                        for tc, tool_name, tool_args in approval_required_tools
                    ]

                    await session.send_event(
                        Event(
                            event_type="approval_required",
                            data={
                                "tools": tools_data,
                                "count": len(tools_data),
                            },
                        )
                    )

                    # Park the raw tool calls; exec_approval picks them up.
                    session.pending_approval = {
                        "tool_calls": [tc for tc, _, _ in approval_required_tools],
                    }

                    # Stop here; the turn resumes once the user has decided.
                    return None

                iteration += 1

            except Exception as e:
                import traceback

                await session.send_event(
                    Event(
                        event_type="error",
                        data={"error": str(e) + "\n" + traceback.format_exc()},
                    )
                )
                break

        # End of turn: compact the context and report it if the size changed.
        old_length = session.context_manager.context_length
        await session.context_manager.compact(model_name=session.config.model_name)
        new_length = session.context_manager.context_length

        if new_length != old_length:
            await session.send_event(
                Event(
                    event_type="compacted",
                    data={"old_tokens": old_length, "new_tokens": new_length},
                )
            )

        await session.send_event(
            Event(
                event_type="turn_complete",
                data={"history_size": len(session.context_manager.items)},
            )
        )

        session.increment_turn()
        await session.auto_save_if_needed()

        return final_response

    @staticmethod
    async def interrupt(session: Session) -> None:
        """Handle interrupt (like interrupt in codex.rs:1266)"""
        session.interrupt()
        await session.send_event(Event(event_type="interrupted"))

    @staticmethod
    async def compact(session: Session) -> None:
        """Handle compact (like compact in codex.rs:1317)"""
        old_length = session.context_manager.context_length
        await session.context_manager.compact(model_name=session.config.model_name)
        new_length = session.context_manager.context_length

        # NOTE(review): "removed" carries the context length *before*
        # compaction, not the amount removed; the key is kept unchanged
        # because event consumers may depend on it.
        await session.send_event(
            Event(
                event_type="compacted",
                data={"removed": old_length, "remaining": new_length},
            )
        )

    @staticmethod
    async def undo(session: Session) -> None:
        """Remove the last complete turn (user msg + all assistant/tool msgs that follow).

        Anthropic requires every tool_use to have a matching tool_result,
        so we can't just pop 2 items — we must pop everything back to
        (and including) the last user message to keep the history valid.
        """
        items = session.context_manager.items
        if not items:
            await session.send_event(Event(event_type="undo_complete"))
            return

        # Pop from the end until a user message has been removed.
        removed_user = False
        while items:
            msg = items.pop()
            if getattr(msg, "role", None) == "user":
                removed_user = True
                break

        if not removed_user:
            logger.warning("Undo: no user message found to remove")

        await session.send_event(Event(event_type="undo_complete"))

    @staticmethod
    async def exec_approval(session: Session, approvals: list[dict]) -> None:
        """Handle batch job execution approval.

        Runs the approved pending tool calls concurrently, records a
        cancellation result for the others, clears the pending approval and
        resumes the agent loop.

        Args:
            session: Active session holding ``pending_approval``.
            approvals: One dict per decision with ``"tool_call_id"``,
                ``"approved"`` and optionally ``"feedback"``. Pending calls
                without a matching entry are treated as rejected.
        """
        if not session.pending_approval:
            await session.send_event(
                Event(
                    event_type="error",
                    data={"error": "No pending approval to process"},
                )
            )
            return

        tool_calls = session.pending_approval.get("tool_calls", [])
        if not tool_calls:
            await session.send_event(
                Event(
                    event_type="error",
                    data={"error": "No pending tool calls found"},
                )
            )
            return

        # Index the decisions by tool call id for quick lookup.
        approval_map = {a["tool_call_id"]: a for a in approvals}

        approved_tasks = []
        rejected_tasks = []

        for tc in tool_calls:
            tool_name = tc.function.name
            tool_args = Handlers._parse_tool_args(tc)
            approval_decision = approval_map.get(tc.id, {"approved": False})

            if approval_decision.get("approved", False):
                approved_tasks.append((tc, tool_name, tool_args))
            else:
                rejected_tasks.append((tc, tool_name, approval_decision))

        async def execute_tool(tc, tool_name, tool_args):
            """Execute a single tool and return its result"""
            await session.send_event(
                Event(
                    event_type="tool_call",
                    data={
                        "tool": tool_name,
                        "arguments": tool_args,
                        "tool_call_id": tc.id,
                    },
                )
            )

            output, success = await session.tool_router.call_tool(
                tool_name, tool_args, session=session
            )

            return (tc, tool_name, output, success)

        if approved_tasks:
            results = await asyncio.gather(
                *[
                    execute_tool(tc, tool_name, tool_args)
                    for tc, tool_name, tool_args in approved_tasks
                ],
                return_exceptions=True,
            )

            # gather preserves order, so each result pairs with its task.
            for (tc, tool_name, _), result in zip(approved_tasks, results):
                if isinstance(result, BaseException):
                    # A failed tool still needs a tool result in the history;
                    # skipping it would leave its tool call unanswered and
                    # invalidate the conversation sent to the model next.
                    logger.error(f"Tool execution error: {result}")
                    output = f"Tool execution error: {result}"
                    success = False
                else:
                    _, _, output, success = result

                tool_msg = Message(
                    role="tool",
                    content=output,
                    tool_call_id=tc.id,
                    name=tool_name,
                )
                session.context_manager.add_message(tool_msg)

                await session.send_event(
                    Event(
                        event_type="tool_output",
                        data={
                            "tool": tool_name,
                            "tool_call_id": tc.id,
                            "output": output,
                            "success": success,
                        },
                    )
                )

        # Rejected calls get a cancellation result (plus any user feedback).
        for tc, tool_name, approval_decision in rejected_tasks:
            rejection_msg = "Job execution cancelled by user"
            user_feedback = approval_decision.get("feedback")
            if user_feedback:
                rejection_msg += f". User feedback: {user_feedback}"

            tool_msg = Message(
                role="tool",
                content=rejection_msg,
                tool_call_id=tc.id,
                name=tool_name,
            )
            session.context_manager.add_message(tool_msg)

            await session.send_event(
                Event(
                    event_type="tool_output",
                    data={
                        "tool": tool_name,
                        "tool_call_id": tc.id,
                        "output": rejection_msg,
                        "success": False,
                    },
                )
            )

        session.pending_approval = None

        # Resume the agent loop; empty text means no new user message.
        await Handlers.run_agent(session, "")

    @staticmethod
    async def shutdown(session: Session) -> bool:
        """Handle shutdown (like shutdown in codex.rs:1329)"""
        # Kick off a final save when session saving is enabled.
        if session.config.save_sessions:
            logger.info("Saving session...")
            repo_id = session.config.session_dataset_repo
            _ = session.save_and_upload_detached(repo_id)

        session.is_running = False
        await session.send_event(Event(event_type="shutdown"))
        return True
| |
|
| |
|
| | async def process_submission(session: Session, submission) -> bool: |
| | """ |
| | Process a single submission and return whether to continue running. |
| | |
| | Returns: |
| | bool: True to continue, False to shutdown |
| | """ |
| | op = submission.operation |
| | logger.debug("Received operation: %s", op.op_type.value) |
| |
|
| | if op.op_type == OpType.USER_INPUT: |
| | text = op.data.get("text", "") if op.data else "" |
| | await Handlers.run_agent(session, text) |
| | return True |
| |
|
| | if op.op_type == OpType.INTERRUPT: |
| | await Handlers.interrupt(session) |
| | return True |
| |
|
| | if op.op_type == OpType.COMPACT: |
| | await Handlers.compact(session) |
| | return True |
| |
|
| | if op.op_type == OpType.UNDO: |
| | await Handlers.undo(session) |
| | return True |
| |
|
| | if op.op_type == OpType.EXEC_APPROVAL: |
| | approvals = op.data.get("approvals", []) if op.data else [] |
| | await Handlers.exec_approval(session, approvals) |
| | return True |
| |
|
| | if op.op_type == OpType.SHUTDOWN: |
| | return not await Handlers.shutdown(session) |
| |
|
| | logger.warning(f"Unknown operation: {op.op_type}") |
| | return True |
| |
|
| |
|
@observe(name="submission_loop")
async def submission_loop(
    submission_queue: asyncio.Queue,
    event_queue: asyncio.Queue,
    config: Config | None = None,
    tool_router: ToolRouter | None = None,
) -> None:
    """
    Main agent loop - pulls submissions off the queue and dispatches them.
    This is the core of the agent (like submission_loop in codex.rs:1259-1340)

    Args:
        submission_queue: Queue the loop reads submissions from.
        event_queue: Queue handed to the ``Session`` created here.
        config: Optional configuration; when ``save_sessions`` is set, a
            retry of previously failed uploads is started before the loop.
        tool_router: Tool router, entered as an async context manager for the
            lifetime of the loop.
    """
    session = Session(event_queue, config=config, tool_router=tool_router)
    logger.info("Agent loop started")

    # Retry uploads left over from earlier runs, if saving is enabled.
    if config and config.save_sessions:
        Session.retry_failed_uploads_detached(
            directory="session_logs", repo_id=config.session_dataset_repo
        )

    try:
        # NOTE(review): tool_router defaults to None but is entered
        # unconditionally here -- confirm that callers always pass a router.
        async with tool_router:
            ready_event = Event(event_type="ready", data={"message": "Agent initialized"})
            await session.send_event(ready_event)

            while session.is_running:
                pending = await submission_queue.get()

                try:
                    if not await process_submission(session, pending):
                        break
                except asyncio.CancelledError:
                    logger.warning("Agent loop cancelled")
                    break
                except Exception as exc:
                    # A failing submission is reported but does not stop the loop.
                    logger.error(f"Error in agent loop: {exc}")
                    await session.send_event(
                        Event(event_type="error", data={"error": str(exc)})
                    )

            logger.info("Agent loop exited")

    finally:
        # Emergency save: only when saving is enabled and the session is
        # still marked as running on the way out.
        emergency_save = session.config.save_sessions and session.is_running
        if emergency_save:
            logger.info("Emergency save: preserving session before exit...")
            try:
                saved_path = session.save_and_upload_detached(
                    session.config.session_dataset_repo
                )
                if saved_path:
                    logger.info("Emergency save successful, upload in progress")
            except Exception as exc:
                logger.error(f"Emergency save failed: {exc}")
| |
|