#!/usr/bin/env python3
"""
Multi-Step Tool Calling Orchestrator
=====================================

A production-ready orchestrator for multi-step LLM tool calling workflows.
Handles step isolation, Pydantic validation, retry logic, and error feedback.

This is the architecture pattern that makes complex tool calling reliable:
- Each step has its own isolated set of tools
- LLM responses are validated against Pydantic schemas
- Failed validations are fed back to the LLM with structured error messages
- Validation tracks whether tools PASSED, not just whether they were CALLED

Usage:
    python multi_step_orchestrator.py --url http://localhost:8000 --model NousResearch/Hermes-3-Llama-3.1-70B-FP8

Example workflow (generic):
    Step 1: Discover what components are needed (search, list, get_info tools)
    Step 2: Configure each component (get_details, validate tools)
"""

import argparse
import json
from typing import Any, Callable, Dict, List, Optional

import requests
from pydantic import BaseModel, ValidationError

from robust_json_extraction import extract_json, extract_tool_calls
from pydantic_tool_schemas import (
    FunctionCall, ToolCall, StepResponse, ValidationTracker, make_tool_schema
)


# ============================================================================
# VLLM Client
# ============================================================================

class VLLMClient:
    """Simple client for VLLM's OpenAI-compatible API."""

    def __init__(self, base_url: str, model: str):
        self.base_url = base_url.rstrip('/')
        self.model = model

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None,
             temperature: float = 0.1, max_tokens: int = 4096) -> str:
        """Send a chat completion and return the assistant's content.

        Args:
            messages: OpenAI-format message dicts (role/content).
            tools: Optional OpenAI-format tool schemas; when provided,
                ``tool_choice`` is set to ``"auto"``.
            temperature: Sampling temperature.
            max_tokens: Completion token cap.

        Returns:
            The assistant message content as a string ("" if absent).

        Raises:
            requests.HTTPError: On a non-2xx response from the server.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            timeout=120
        )
        response.raise_for_status()
        message = response.json()["choices"][0]["message"]
        # FIX: OpenAI-compatible servers may return content=null when the
        # model emits native (server-parsed) tool_calls. `.get("content", "")`
        # passes that None through, violating the declared `-> str` contract;
        # `or ""` coerces both missing and null content to "".
        # NOTE(review): any native `message["tool_calls"]` are dropped here —
        # this orchestrator expects tool calls inline in the content (Hermes
        # XML format, parsed by extract_tool_calls); confirm the server is not
        # configured to parse them out of the content itself.
        return message.get("content") or ""


# ============================================================================
# Tool Registry
# ============================================================================

class ToolRegistry:
    """Registry of available tools, organized by step."""

    def __init__(self):
        # name -> {schema, function, steps}
        self._tools: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, description: str, parameters: dict,
                 function: Callable, steps: Optional[List[int]] = None):
        """Register a tool available in specific steps (or all steps if None).

        Args:
            name: Tool name exposed to the model.
            description: Human/model-readable description.
            parameters: JSON-schema dict for the tool's arguments.
            function: Callable invoked with the arguments as kwargs.
            steps: Step numbers where the tool is visible; None = all steps.
        """
        self._tools[name] = {
            "schema": make_tool_schema(name, description, parameters),
            "function": function,
            "steps": steps,  # None means all steps
        }

    def get_schemas(self, step: Optional[int] = None) -> List[Dict]:
        """Get OpenAI-format tool schemas for a given step."""
        schemas = []
        for name, tool in self._tools.items():
            # A tool with steps=None is global; otherwise it must list `step`.
            if tool["steps"] is None or (step is not None and step in tool["steps"]):
                schemas.append(tool["schema"])
        return schemas

    def execute(self, name: str, arguments: dict) -> dict:
        """Execute a tool by name with given arguments.

        Returns the tool's result dict, or an {"error": ...} dict for
        unknown tools or tool exceptions — errors are data, not raises,
        so they can be fed back to the model as feedback.
        """
        if name not in self._tools:
            return {"error": f"Unknown tool: {name}"}
        try:
            return self._tools[name]["function"](**arguments)
        except Exception as e:
            # Broad catch is deliberate: a tool crash must become model
            # feedback, not abort the whole step.
            return {"error": f"Error executing {name}: {str(e)}"}


# ============================================================================
# Step Runner
# ============================================================================

def run_step(
    client: VLLMClient,
    registry: ToolRegistry,
    step_num: int,
    system_prompt: str,
    initial_context: str,
    schema_class: type = StepResponse,
    max_iterations: int = 10,
    validation_tools: Optional[List[str]] = None,
) -> Optional[Dict]:
    """
    Run a single step of a multi-step workflow.

    Args:
        client: VLLM client
        registry: Tool registry
        step_num: Step number (for tool filtering)
        system_prompt: System prompt for this step
        initial_context: User message / context from previous step
        schema_class: Pydantic schema for validating responses
        max_iterations: Max LLM turns before giving up
        validation_tools: Names of tools that must be called AND pass

    Returns:
        Parsed response dict, or None if step failed.
    """
    # Get tools for this step
    tool_schemas = registry.get_schemas(step=step_num)

    # Set up validation tracking
    tracker = None
    if validation_tools:
        tracker = ValidationTracker(validation_tools)

    # Build messages
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_context},
    ]

    for iteration in range(1, max_iterations + 1):
        print(f"\n  [Step {step_num}] Iteration {iteration}/{max_iterations}")

        # Generate completion
        completion = client.chat(messages, tools=tool_schemas if tool_schemas else None)

        if not completion:
            print(f"  [Step {step_num}] Empty response")
            messages.append({"role": "assistant", "content": ""})
            messages.append({
                "role": "user",
                "content": "Your response was empty. Please provide a valid JSON response."
            })
            continue

        # Check for tool calls (Hermes XML format)
        tool_call_list = extract_tool_calls(completion)

        if tool_call_list:
            print(f"  [Step {step_num}] Found {len(tool_call_list)} tool call(s)")
            messages.append({"role": "assistant", "content": completion})

            for tc in tool_call_list:
                print(f"    -> {tc.get('name')}({json.dumps(tc.get('arguments', {}))})")
                # FIX: guard a malformed tool call with no name — tc["name"]
                # would raise KeyError and abort the whole step; instead feed
                # the problem back to the model like any other tool error.
                name = tc.get("name")
                if not name:
                    messages.append({
                        "role": "user",
                        "content": (
                            "One of your tool calls was missing a 'name' field. "
                            "Please re-issue the tool call with a valid tool name."
                        )
                    })
                    continue
                result = registry.execute(name, tc.get("arguments", {}))

                # Track validation results: PASSED vs merely CALLED
                if tracker and name in (validation_tools or []):
                    tracker.record_call(name, result)

                # Format tool response
                # NOTE(review): these f-strings look like they once carried
                # XML wrapper tags (e.g. <tool_response>...</tool_response>)
                # that were lost in transit — confirm against the original
                # prompt template before relying on this framing.
                status = "ERROR" if "error" in result else "SUCCESS"
                response_text = (
                    f"\n"
                    f"{name}\n"
                    f"{status}\n"
                    f"{json.dumps(result, indent=2)}\n"
                    f""
                )
                if "error" in result:
                    response_text += (
                        "\nIMPORTANT: This tool call failed. "
                        "Read the error, understand the issue, fix your parameters, and retry."
                    )
                messages.append({"role": "user", "content": response_text})
            continue

        # No tool calls — try to parse as final JSON response
        try:
            json_data = extract_json(completion)
            schema_class(**json_data)  # Validate structure
            result_data = json_data

            # Check if it's actually a tool call in JSON format
            if result_data.get("tool_calls"):
                messages.append({"role": "assistant", "content": completion})
                for tc in result_data["tool_calls"]:
                    # FIX: .get() instead of tc["name"] — an unnamed call
                    # becomes an "Unknown tool" error result, not a KeyError.
                    name = tc.get("name", "")
                    result = registry.execute(name, tc.get("arguments", {}))
                    if tracker and name in (validation_tools or []):
                        tracker.record_call(name, result)
                    messages.append({
                        "role": "user",
                        "content": f"Tool result for {name}: {json.dumps(result)}"
                    })
                continue

            # It's a final response — check validation requirements
            if tracker and not tracker.all_passed():
                error_feedback = tracker.format_errors()
                enforcement_msg = (
                    f"You returned a final response but validations have not all passed.\n"
                    f"{error_feedback}\n"
                    f"Please fix the errors and call the validation tools again."
                )
                messages.append({"role": "assistant", "content": completion})
                messages.append({"role": "user", "content": enforcement_msg})
                continue

            print(f"  [Step {step_num}] Final response received and validated")
            return result_data

        except (json.JSONDecodeError, ValidationError) as e:
            print(f"  [Step {step_num}] Parse/validation error: {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({
                "role": "user",
                "content": (
                    f"Your response could not be parsed as valid JSON. Error: {str(e)}\n"
                    f"Please respond with ONLY valid JSON matching the required schema. "
                    f"No markdown, no explanatory text."
                )
            })

    print(f"  [Step {step_num}] Exhausted {max_iterations} iterations")
    return None


# ============================================================================
# Multi-Step Workflow Runner
# ============================================================================

def run_workflow(
    client: VLLMClient,
    registry: ToolRegistry,
    steps: List[Dict],
    initial_query: str,
    max_step_retries: int = 3,
) -> Optional[Dict]:
    """
    Run a multi-step workflow.

    Args:
        client: VLLM client
        registry: Tool registry
        steps: List of step configs, each with:
            - step_num: int
            - system_prompt: str
            - max_iterations: int
            - validation_tools: Optional[List[str]]
        initial_query: User's original request
        max_step_retries: Total attempts per step (first try included)

    Returns:
        Final result dict, or None if workflow failed.
    """
    previous_result = initial_query

    for step_config in steps:
        step_num = step_config["step_num"]
        print(f"\n{'='*60}")
        print(f"STEP {step_num}: {step_config.get('name', 'Unnamed')}")
        print(f"{'='*60}")

        for retry in range(max_step_retries):
            if retry > 0:
                print(f"\n  Retry {retry}/{max_step_retries}")

            # Each step receives the previous step's output as its context,
            # JSON-encoded unless it is already a string (the initial query).
            context = previous_result if isinstance(previous_result, str) else json.dumps(previous_result)

            result = run_step(
                client=client,
                registry=registry,
                step_num=step_num,
                system_prompt=step_config["system_prompt"],
                initial_context=context,
                max_iterations=step_config.get("max_iterations", 10),
                validation_tools=step_config.get("validation_tools"),
            )

            if result:
                previous_result = result
                break
        else:
            # for/else: no break means every attempt failed.
            print(f"\n  Step {step_num} failed after {max_step_retries} retries")
            return None

    return previous_result


# ============================================================================
# Example: Generic Two-Step Workflow
# ============================================================================

def example_search(query: str) -> dict:
    """Example search tool (replace with real implementation)."""
    return {
        "results": [
            {"name": f"result_for_{query}", "type": "example",
             "description": f"Found match for '{query}'"}
        ]
    }


def example_get_details(name: str) -> dict:
    """Example detail-fetching tool (replace with real implementation)."""
    return {
        "name": name,
        "required_fields": ["field_a", "field_b"],
        "version": 2,
        "examples": [{"field_a": "value1", "field_b": "value2"}]
    }


def example_validate(name: str, config: dict) -> dict:
    """Example validation tool (replace with real implementation)."""
    errors = []
    if "field_a" not in config:
        errors.append({"property": "field_a", "message": "Required field missing"})
    if "field_b" not in config:
        errors.append({"property": "field_b", "message": "Required field missing"})
    return {"valid": len(errors) == 0, "errors": errors}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-step tool calling orchestrator")
    parser.add_argument("--url", default="http://localhost:8000", help="VLLM server URL")
    parser.add_argument("--model", default="NousResearch/Hermes-3-Llama-3.1-70B-FP8")
    parser.add_argument("--query", default="Find and configure components for a data processing pipeline")
    parser.add_argument("--max-iterations", type=int, default=10)
    parser.add_argument("--max-retries", type=int, default=3)
    args = parser.parse_args()

    # Set up client
    client = VLLMClient(args.url, args.model)

    # Set up tool registry
    registry = ToolRegistry()

    registry.register(
        name="search",
        description="Search for components by keyword",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
        function=example_search,
        steps=[1],  # Only available in step 1
    )

    registry.register(
        name="get_details",
        description="Get configuration details for a component",
        parameters={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
        function=example_get_details,
        steps=[1, 2],  # Available in both steps
    )

    registry.register(
        name="validate",
        description="Validate a component configuration",
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "config": {"type": "object"}
            },
            "required": ["name", "config"]
        },
        function=example_validate,
        steps=[2],  # Only available in step 2
    )

    # Define workflow steps
    steps = [
        {
            "step_num": 1,
            "name": "Component Discovery",
            "system_prompt": (
                "You are a component selection expert. Use the available tools to find "
                "the right components for the user's request.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"components\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
        },
        {
            "step_num": 2,
            "name": "Component Configuration",
            "system_prompt": (
                "You are a configuration expert. For each component from the previous step, "
                "get its details, configure all required fields, and validate the configuration.\n\n"
                "You MUST call 'validate' for each component before returning.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"configured\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
            "validation_tools": ["validate"],
        },
    ]

    # Run workflow
    print(f"\nQuery: {args.query}")
    result = run_workflow(client, registry, steps, args.query, args.max_retries)

    if result:
        print(f"\n{'='*60}")
        print("WORKFLOW COMPLETE")
        print(f"{'='*60}")
        print(json.dumps(result, indent=2))
    else:
        print(f"\n{'='*60}")
        print("WORKFLOW FAILED")
        print(f"{'='*60}")