"""
Multi-Step Tool Calling Orchestrator
=====================================

A production-ready orchestrator for multi-step LLM tool calling workflows.
Handles step isolation, Pydantic validation, retry logic, and error feedback.

This is the architecture pattern that makes complex tool calling reliable:
- Each step has its own isolated set of tools
- LLM responses are validated against Pydantic schemas
- Failed validations are fed back to the LLM with structured error messages
- Validation tracks whether tools PASSED, not just whether they were CALLED

Usage:
    python multi_step_orchestrator.py --url http://localhost:8000 --model NousResearch/Hermes-3-Llama-3.1-70B-FP8

Example workflow (generic):
    Step 1: Discover what components are needed (search, list, get_info tools)
    Step 2: Configure each component (get_details, validate tools)
"""
| |
|
| | import argparse |
| | import json |
| | import requests |
| | from typing import Any, Callable, Dict, List, Optional |
| | from pydantic import BaseModel, ValidationError |
| |
|
| | from robust_json_extraction import extract_json, extract_tool_calls |
| | from pydantic_tool_schemas import ( |
| | FunctionCall, ToolCall, StepResponse, ValidationTracker, make_tool_schema |
| | ) |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class VLLMClient:
    """Minimal client for VLLM's OpenAI-compatible chat completions API."""

    def __init__(self, base_url: str, model: str):
        # Normalize so f-strings below can safely append "/v1/...".
        self.base_url = base_url.rstrip('/')
        self.model = model

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None,
             temperature: float = 0.1, max_tokens: int = 4096) -> str:
        """Send a chat completion and return the assistant's content.

        Args:
            messages: OpenAI-format message dicts.
            tools: Optional OpenAI-format tool schemas; when provided,
                tool_choice is set to "auto".
            temperature: Sampling temperature.
            max_tokens: Completion token budget.

        Returns:
            The assistant message content as a string ("" when absent or null).

        Raises:
            requests.HTTPError: If the server returns a non-2xx status.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            timeout=120
        )
        response.raise_for_status()
        # BUG FIX: OpenAI-compatible servers return "content": null when the
        # model emits native tool calls; dict.get's default does NOT apply to
        # an explicit null, so `.get("content", "")` could return None and
        # break the declared -> str contract. Coerce None/null to "".
        return response.json()["choices"][0]["message"].get("content") or ""
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ToolRegistry:
    """Holds callable tools plus their schemas, filtered per workflow step."""

    def __init__(self):
        # name -> {"schema": ..., "function": ..., "steps": ...}
        self._tools: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, description: str, parameters: dict,
                 function: Callable, steps: Optional[List[int]] = None):
        """Add a tool; `steps` restricts availability (None = every step)."""
        entry = {
            "schema": make_tool_schema(name, description, parameters),
            "function": function,
            "steps": steps,
        }
        self._tools[name] = entry

    def get_schemas(self, step: Optional[int] = None) -> List[Dict]:
        """Return the OpenAI-format schemas visible at the given step."""
        return [
            entry["schema"]
            for entry in self._tools.values()
            if entry["steps"] is None
            or (step is not None and step in entry["steps"])
        ]

    def execute(self, name: str, arguments: dict) -> dict:
        """Run the named tool, converting any failure into an error dict."""
        entry = self._tools.get(name)
        if entry is None:
            return {"error": f"Unknown tool: {name}"}
        try:
            return entry["function"](**arguments)
        except Exception as e:
            # Deliberately broad: errors become structured feedback for the LLM.
            return {"error": f"Error executing {name}: {str(e)}"}
| |
|
| |
|
| | |
| | |
| | |
| |
|
def run_step(
    client: VLLMClient,
    registry: ToolRegistry,
    step_num: int,
    system_prompt: str,
    initial_context: str,
    schema_class: type = StepResponse,
    max_iterations: int = 10,
    validation_tools: Optional[List[str]] = None,
) -> Optional[Dict]:
    """
    Run a single step of a multi-step workflow.

    Each of up to ``max_iterations`` LLM turns takes one of three paths:
      1. Tool-call text found by ``extract_tool_calls`` -> execute the tools,
         append results as a user message, loop again.
      2. JSON that validates against ``schema_class`` -> either execute its
         embedded ``tool_calls`` (loop again) or accept it as the final
         answer, provided all validation tools have passed.
      3. Unparseable/invalid output -> feed the error back to the LLM, loop.

    Args:
        client: VLLM client
        registry: Tool registry
        step_num: Step number (for tool filtering)
        system_prompt: System prompt for this step
        initial_context: User message / context from previous step
        schema_class: Pydantic schema for validating responses
        max_iterations: Max LLM turns before giving up
        validation_tools: Names of tools that must be called AND pass

    Returns:
        Parsed response dict, or None if step failed.
    """

    # Step isolation: only expose the tools registered for this step.
    tool_schemas = registry.get_schemas(step=step_num)

    # Tracks whether each validation tool was called AND passed (pass/fail
    # semantics live in ValidationTracker, defined in pydantic_tool_schemas).
    tracker = None
    if validation_tools:
        tracker = ValidationTracker(validation_tools)

    # Conversation transcript; grows as tool results and feedback are appended.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_context},
    ]

    for iteration in range(1, max_iterations + 1):
        print(f"\n [Step {step_num}] Iteration {iteration}/{max_iterations}")

        completion = client.chat(messages, tools=tool_schemas if tool_schemas else None)

        # Empty completion: nudge the model and spend one iteration retrying.
        if not completion:
            print(f" [Step {step_num}] Empty response")
            messages.append({"role": "assistant", "content": ""})
            messages.append({
                "role": "user",
                "content": "Your response was empty. Please provide a valid JSON response."
            })
            continue

        # Path 1: free-form tool-call syntax detected in the completion text.
        tool_call_list = extract_tool_calls(completion)
        if tool_call_list:
            print(f" [Step {step_num}] Found {len(tool_call_list)} tool call(s)")
            messages.append({"role": "assistant", "content": completion})

            for tc in tool_call_list:
                print(f" -> {tc.get('name')}({json.dumps(tc.get('arguments', {}))})")
                result = registry.execute(tc["name"], tc.get("arguments", {}))

                # Record pass/fail only for tools this step must validate with.
                if tracker and tc["name"] in (validation_tools or []):
                    tracker.record_call(tc["name"], result)

                # registry.execute signals failure via an "error" key.
                status = "ERROR" if "error" in result else "SUCCESS"
                response_text = (
                    f"<tool_response>\n"
                    f"<tool_name>{tc['name']}</tool_name>\n"
                    f"<status>{status}</status>\n"
                    f"<result>{json.dumps(result, indent=2)}</result>\n"
                    f"</tool_response>"
                )
                if "error" in result:
                    response_text += (
                        "\nIMPORTANT: This tool call failed. "
                        "Read the error, understand the issue, fix your parameters, and retry."
                    )

                # Tool results go back as a user turn so the next completion
                # sees them in context.
                messages.append({"role": "user", "content": response_text})
            continue

        # Paths 2/3: no tool-call text — expect schema-conformant JSON.
        try:
            json_data = extract_json(completion)
            # Validate shape only; the raw dict remains the working result.
            schema_class(**json_data)
            result_data = json_data

            # Schema-embedded tool calls: execute them, then loop for another turn.
            if result_data.get("tool_calls"):
                messages.append({"role": "assistant", "content": completion})
                for tc in result_data["tool_calls"]:
                    result = registry.execute(tc["name"], tc.get("arguments", {}))
                    if tracker and tc["name"] in (validation_tools or []):
                        tracker.record_call(tc["name"], result)
                    messages.append({
                        "role": "user",
                        "content": f"Tool result for {tc['name']}: {json.dumps(result)}"
                    })
                continue

            # Final answer offered: refuse it until every validation tool passed.
            if tracker and not tracker.all_passed():
                error_feedback = tracker.format_errors()
                enforcement_msg = (
                    f"You returned a final response but validations have not all passed.\n"
                    f"{error_feedback}\n"
                    f"Please fix the errors and call the validation tools again."
                )
                messages.append({"role": "assistant", "content": completion})
                messages.append({"role": "user", "content": enforcement_msg})
                continue

            print(f" [Step {step_num}] Final response received and validated")
            return result_data

        # NOTE(review): assumes extract_json raises json.JSONDecodeError on
        # failure — confirm against robust_json_extraction; a plain ValueError
        # from it would propagate uncaught.
        except (json.JSONDecodeError, ValidationError) as e:
            print(f" [Step {step_num}] Parse/validation error: {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({
                "role": "user",
                "content": (
                    f"Your response could not be parsed as valid JSON. Error: {str(e)}\n"
                    f"Please respond with ONLY valid JSON matching the required schema. "
                    f"No markdown, no explanatory text."
                )
            })

    # All turns consumed without an accepted final response.
    print(f" [Step {step_num}] Exhausted {max_iterations} iterations")
    return None
| |
|
| |
|
| | |
| | |
| | |
| |
|
def run_workflow(
    client: VLLMClient,
    registry: ToolRegistry,
    steps: List[Dict],
    initial_query: str,
    max_step_retries: int = 3,
) -> Optional[Dict]:
    """
    Run a multi-step workflow, feeding each step's result to the next.

    Args:
        client: VLLM client
        registry: Tool registry
        steps: List of step configs, each with:
            - step_num: int
            - system_prompt: str
            - max_iterations: int (optional, default 10)
            - validation_tools: Optional[List[str]]
            - schema_class: optional Pydantic model for this step's responses
        initial_query: User's original request
        max_step_retries: How many times to retry each step

    Returns:
        Final result dict, or None if workflow failed.
    """
    # Context chain: step N receives step N-1's output (or the original query).
    previous_result = initial_query

    for step_config in steps:
        step_num = step_config["step_num"]
        print(f"\n{'='*60}")
        print(f"STEP {step_num}: {step_config.get('name', 'Unnamed')}")
        print(f"{'='*60}")

        for retry in range(max_step_retries):
            if retry > 0:
                print(f"\n Retry {retry}/{max_step_retries}")

            # Structured results are serialized so they can pass as a user message.
            context = previous_result if isinstance(previous_result, str) else json.dumps(previous_result)

            # GENERALIZATION: forward a per-step response schema when the step
            # config provides one; when omitted, run_step keeps its own default
            # (fully backward compatible with existing step configs).
            step_kwargs = {}
            if "schema_class" in step_config:
                step_kwargs["schema_class"] = step_config["schema_class"]

            result = run_step(
                client=client,
                registry=registry,
                step_num=step_num,
                system_prompt=step_config["system_prompt"],
                initial_context=context,
                max_iterations=step_config.get("max_iterations", 10),
                validation_tools=step_config.get("validation_tools"),
                **step_kwargs,
            )

            if result:
                previous_result = result
                break
        else:
            # for-else: no retry produced a result — abort the whole workflow.
            print(f"\n Step {step_num} failed after {max_step_retries} retries")
            return None

    return previous_result
| |
|
| |
|
| | |
| | |
| | |
| |
|
def example_search(query: str) -> dict:
    """Example search tool (replace with real implementation)."""
    match = {
        "name": f"result_for_{query}",
        "type": "example",
        "description": f"Found match for '{query}'",
    }
    return {"results": [match]}
| |
|
| |
|
def example_get_details(name: str) -> dict:
    """Example detail-fetching tool (replace with real implementation)."""
    details = {"name": name, "version": 2}
    details["required_fields"] = ["field_a", "field_b"]
    details["examples"] = [{"field_a": "value1", "field_b": "value2"}]
    return details
| |
|
| |
|
def example_validate(name: str, config: dict) -> dict:
    """Example validation tool (replace with real implementation)."""
    errors = [
        {"property": field, "message": "Required field missing"}
        for field in ("field_a", "field_b")
        if field not in config
    ]
    return {"valid": not errors, "errors": errors}
| |
|
| |
|
def _main() -> None:
    """CLI entry point: wire up the demo registry and run the 2-step workflow."""
    parser = argparse.ArgumentParser(description="Multi-step tool calling orchestrator")
    parser.add_argument("--url", default="http://localhost:8000", help="VLLM server URL")
    parser.add_argument("--model", default="NousResearch/Hermes-3-Llama-3.1-70B-FP8")
    parser.add_argument("--query", default="Find and configure components for a data processing pipeline")
    parser.add_argument("--max-iterations", type=int, default=10)
    parser.add_argument("--max-retries", type=int, default=3)
    args = parser.parse_args()

    client = VLLMClient(args.url, args.model)

    # Demo tools: each bound to the step(s) where it may be used.
    registry = ToolRegistry()
    registry.register(
        name="search",
        description="Search for components by keyword",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
        function=example_search,
        steps=[1],
    )
    registry.register(
        name="get_details",
        description="Get configuration details for a component",
        parameters={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
        function=example_get_details,
        steps=[1, 2],
    )
    registry.register(
        name="validate",
        description="Validate a component configuration",
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "config": {"type": "object"}
            },
            "required": ["name", "config"]
        },
        function=example_validate,
        steps=[2],
    )

    # Two-step workflow: discover components, then configure + validate them.
    steps = [
        {
            "step_num": 1,
            "name": "Component Discovery",
            "system_prompt": (
                "You are a component selection expert. Use the available tools to find "
                "the right components for the user's request.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"components\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
        },
        {
            "step_num": 2,
            "name": "Component Configuration",
            "system_prompt": (
                "You are a configuration expert. For each component from the previous step, "
                "get its details, configure all required fields, and validate the configuration.\n\n"
                "You MUST call 'validate' for each component before returning.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"configured\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
            "validation_tools": ["validate"],
        },
    ]

    print(f"\nQuery: {args.query}")
    result = run_workflow(client, registry, steps, args.query, args.max_retries)

    banner = "=" * 60
    if result:
        print(f"\n{banner}")
        print("WORKFLOW COMPLETE")
        print(banner)
        print(json.dumps(result, indent=2))
    else:
        print(f"\n{banner}")
        print("WORKFLOW FAILED")
        print(banner)


if __name__ == "__main__":
    _main()
|