# vllm-tool-calling-guide / examples / multi_step_orchestrator.py
# Author: Joshua Odmark
# Initial release: vLLM tool calling guide for open source models
# Revision: 634c038
#!/usr/bin/env python3
"""
Multi-Step Tool Calling Orchestrator
=====================================
A production-ready orchestrator for multi-step LLM tool calling workflows.
Handles step isolation, Pydantic validation, retry logic, and error feedback.
This is the architecture pattern that makes complex tool calling reliable:
- Each step has its own isolated set of tools
- LLM responses are validated against Pydantic schemas
- Failed validations are fed back to the LLM with structured error messages
- Validation tracks whether tools PASSED, not just whether they were CALLED
Usage:
python multi_step_orchestrator.py --url http://localhost:8000 --model NousResearch/Hermes-3-Llama-3.1-70B-FP8
Example workflow (generic):
Step 1: Discover what components are needed (search, list, get_info tools)
Step 2: Configure each component (get_details, validate tools)
"""
import argparse
import json
import requests
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, ValidationError
from robust_json_extraction import extract_json, extract_tool_calls
from pydantic_tool_schemas import (
FunctionCall, ToolCall, StepResponse, ValidationTracker, make_tool_schema
)
# ============================================================================
# VLLM Client
# ============================================================================
class VLLMClient:
    """Simple client for VLLM's OpenAI-compatible API."""

    def __init__(self, base_url: str, model: str):
        # Normalize so the URL join in chat() never produces a double slash.
        self.base_url = base_url.rstrip('/')
        self.model = model

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None,
             temperature: float = 0.1, max_tokens: int = 4096) -> str:
        """Send a chat completion and return the assistant's content.

        Args:
            messages: Conversation history in OpenAI message format.
            tools: Optional OpenAI-format tool schemas; when provided,
                tool_choice is set to "auto".
            temperature: Sampling temperature.
            max_tokens: Completion token budget.

        Returns:
            The assistant message content as a string; empty string when
            the server omits content or sets it to null.

        Raises:
            requests.HTTPError: If the server returns a non-2xx status.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            timeout=120
        )
        response.raise_for_status()
        # FIX: .get("content", "") still returns None when the API sends
        # "content": null (common when the model answers with native
        # tool_calls instead of text). Coerce to "" so the declared
        # `-> str` contract holds and callers can safely truth-test it.
        return response.json()["choices"][0]["message"].get("content") or ""
# ============================================================================
# Tool Registry
# ============================================================================
class ToolRegistry:
    """Registry of available tools, organized by step."""

    def __init__(self):
        # Maps tool name -> {"schema": ..., "function": ..., "steps": ...}.
        self._tools: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, description: str, parameters: dict,
                 function: Callable, steps: Optional[List[int]] = None):
        """Register a tool available in specific steps (None = every step)."""
        entry = {
            "schema": make_tool_schema(name, description, parameters),
            "function": function,
            "steps": steps,  # None means all steps
        }
        self._tools[name] = entry

    def get_schemas(self, step: Optional[int] = None) -> List[Dict]:
        """Get OpenAI-format tool schemas visible in the given step."""
        def _visible(tool: Dict[str, Any]) -> bool:
            allowed = tool["steps"]
            return allowed is None or (step is not None and step in allowed)

        return [tool["schema"] for tool in self._tools.values() if _visible(tool)]

    def execute(self, name: str, arguments: dict) -> dict:
        """Invoke a registered tool; failures come back as {"error": ...} dicts."""
        entry = self._tools.get(name)
        if entry is None:
            return {"error": f"Unknown tool: {name}"}
        try:
            return entry["function"](**arguments)
        except Exception as e:
            return {"error": f"Error executing {name}: {str(e)}"}
# ============================================================================
# Step Runner
# ============================================================================
def run_step(
    client: VLLMClient,
    registry: ToolRegistry,
    step_num: int,
    system_prompt: str,
    initial_context: str,
    schema_class: type = StepResponse,
    max_iterations: int = 10,
    validation_tools: Optional[List[str]] = None,
) -> Optional[Dict]:
    """
    Run a single step of a multi-step workflow.

    Each step gets an isolated tool set (filtered by step_num), its own
    conversation, and optional validation tracking: every tool listed in
    validation_tools must have been called AND passed before a final
    response is accepted.

    Args:
        client: VLLM client used for chat completions.
        registry: Tool registry providing schemas and execution.
        step_num: Step number (for tool filtering).
        system_prompt: System prompt for this step.
        initial_context: User message / context from previous step.
        schema_class: Pydantic schema for validating responses.
        max_iterations: Max LLM turns before giving up.
        validation_tools: Names of tools that must be called AND pass.

    Returns:
        Parsed response dict, or None if step failed.
    """
    # Get tools for this step (step isolation: other steps' tools are hidden).
    tool_schemas = registry.get_schemas(step=step_num)

    # Set up validation tracking so a final answer can't bypass validation.
    tracker = None
    if validation_tools:
        tracker = ValidationTracker(validation_tools)

    # Build messages
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_context},
    ]

    for iteration in range(1, max_iterations + 1):
        print(f"\n [Step {step_num}] Iteration {iteration}/{max_iterations}")

        # Generate completion
        completion = client.chat(messages, tools=tool_schemas if tool_schemas else None)

        # Empty completions happen (e.g. native tool_calls with null content);
        # nudge the model back toward a JSON answer.
        if not completion:
            print(f" [Step {step_num}] Empty response")
            messages.append({"role": "assistant", "content": ""})
            messages.append({
                "role": "user",
                "content": "Your response was empty. Please provide a valid JSON response."
            })
            continue

        # Path 1: tool calls in Hermes XML format embedded in the content.
        tool_call_list = extract_tool_calls(completion)
        if tool_call_list:
            print(f" [Step {step_num}] Found {len(tool_call_list)} tool call(s)")
            messages.append({"role": "assistant", "content": completion})
            for tc in tool_call_list:
                name = tc.get("name")
                arguments = tc.get("arguments", {})
                if not name:
                    # FIX: a malformed call without a 'name' previously raised
                    # KeyError and crashed the step; feed the error back instead.
                    messages.append({
                        "role": "user",
                        "content": "A tool call was missing its 'name' field. Please retry with a valid tool call."
                    })
                    continue
                print(f" -> {name}({json.dumps(arguments)})")
                result = registry.execute(name, arguments)
                # Track validation results (PASSED, not just CALLED)
                if tracker and name in (validation_tools or []):
                    tracker.record_call(name, result)
                # Format tool response in the same XML envelope the model expects
                status = "ERROR" if "error" in result else "SUCCESS"
                response_text = (
                    f"<tool_response>\n"
                    f"<tool_name>{name}</tool_name>\n"
                    f"<status>{status}</status>\n"
                    f"<result>{json.dumps(result, indent=2)}</result>\n"
                    f"</tool_response>"
                )
                if "error" in result:
                    response_text += (
                        "\nIMPORTANT: This tool call failed. "
                        "Read the error, understand the issue, fix your parameters, and retry."
                    )
                messages.append({"role": "user", "content": response_text})
            continue

        # Path 2: no tool calls — try to parse as final JSON response.
        try:
            json_data = extract_json(completion)
            schema_class(**json_data)  # Validate structure
            result_data = json_data

            # The model may emit tool calls as JSON rather than XML.
            if result_data.get("tool_calls"):
                messages.append({"role": "assistant", "content": completion})
                for tc in result_data["tool_calls"]:
                    name = tc.get("name")
                    if not name:
                        # FIX: tc["name"] on a malformed call raised an uncaught
                        # KeyError (not covered by the except below) and aborted
                        # the whole step; report it back to the model instead.
                        messages.append({
                            "role": "user",
                            "content": "A tool call was missing its 'name' field. Please retry with a valid tool call."
                        })
                        continue
                    result = registry.execute(name, tc.get("arguments", {}))
                    if tracker and name in (validation_tools or []):
                        tracker.record_call(name, result)
                    messages.append({
                        "role": "user",
                        "content": f"Tool result for {name}: {json.dumps(result)}"
                    })
                continue

            # It's a final response — refuse it until all validations passed.
            if tracker and not tracker.all_passed():
                error_feedback = tracker.format_errors()
                enforcement_msg = (
                    f"You returned a final response but validations have not all passed.\n"
                    f"{error_feedback}\n"
                    f"Please fix the errors and call the validation tools again."
                )
                messages.append({"role": "assistant", "content": completion})
                messages.append({"role": "user", "content": enforcement_msg})
                continue

            print(f" [Step {step_num}] Final response received and validated")
            return result_data

        except (json.JSONDecodeError, ValidationError) as e:
            # Feed the structured error back so the model can self-correct.
            print(f" [Step {step_num}] Parse/validation error: {e}")
            messages.append({"role": "assistant", "content": completion})
            messages.append({
                "role": "user",
                "content": (
                    f"Your response could not be parsed as valid JSON. Error: {str(e)}\n"
                    f"Please respond with ONLY valid JSON matching the required schema. "
                    f"No markdown, no explanatory text."
                )
            })

    print(f" [Step {step_num}] Exhausted {max_iterations} iterations")
    return None
# ============================================================================
# Multi-Step Workflow Runner
# ============================================================================
def run_workflow(
    client: VLLMClient,
    registry: ToolRegistry,
    steps: List[Dict],
    initial_query: str,
    max_step_retries: int = 3,
) -> Optional[Dict]:
    """
    Run a multi-step workflow, feeding each step's result into the next.

    Args:
        client: VLLM client.
        registry: Tool registry.
        steps: List of step configs, each with:
            - step_num: int
            - system_prompt: str
            - max_iterations: int (optional, default 10)
            - validation_tools: Optional[List[str]]
            - name: str (optional, for logging)
        initial_query: User's original request (becomes step 1's context).
        max_step_retries: How many times to retry each step.

    Returns:
        Final result dict, or None if workflow failed.
    """
    # Context chain: each step's output is the next step's input.
    previous_result = initial_query
    for step_config in steps:
        step_num = step_config["step_num"]
        print(f"\n{'='*60}")
        print(f"STEP {step_num}: {step_config.get('name', 'Unnamed')}")
        print(f"{'='*60}")
        for retry in range(max_step_retries):
            if retry > 0:
                print(f"\n Retry {retry}/{max_step_retries}")
            # Serialize dict results so they can be sent as a user message.
            context = previous_result if isinstance(previous_result, str) else json.dumps(previous_result)
            result = run_step(
                client=client,
                registry=registry,
                step_num=step_num,
                system_prompt=step_config["system_prompt"],
                initial_context=context,
                max_iterations=step_config.get("max_iterations", 10),
                validation_tools=step_config.get("validation_tools"),
            )
            # FIX: run_step signals failure with None; the old truthiness test
            # (`if result:`) wrongly rejected valid-but-falsy results like {}.
            if result is not None:
                previous_result = result
                break
        else:
            print(f"\n Step {step_num} failed after {max_step_retries} retries")
            return None
    return previous_result
# ============================================================================
# Example: Generic Two-Step Workflow
# ============================================================================
def example_search(query: str) -> dict:
    """Example search tool (replace with real implementation)."""
    match = {
        "name": f"result_for_{query}",
        "type": "example",
        "description": f"Found match for '{query}'",
    }
    return {"results": [match]}
def example_get_details(name: str) -> dict:
    """Example detail-fetching tool (replace with real implementation)."""
    details: dict = {"name": name}
    details["required_fields"] = ["field_a", "field_b"]
    details["version"] = 2
    details["examples"] = [{"field_a": "value1", "field_b": "value2"}]
    return details
def example_validate(name: str, config: dict) -> dict:
    """Example validation tool (replace with real implementation)."""
    # Both fields are mandatory; report each missing one individually.
    errors = [
        {"property": field, "message": "Required field missing"}
        for field in ("field_a", "field_b")
        if field not in config
    ]
    return {"valid": not errors, "errors": errors}
if __name__ == "__main__":
    # ---- CLI --------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Multi-step tool calling orchestrator")
    parser.add_argument("--url", default="http://localhost:8000", help="VLLM server URL")
    parser.add_argument("--model", default="NousResearch/Hermes-3-Llama-3.1-70B-FP8")
    parser.add_argument("--query", default="Find and configure components for a data processing pipeline")
    parser.add_argument("--max-iterations", type=int, default=10)
    parser.add_argument("--max-retries", type=int, default=3)
    args = parser.parse_args()

    # ---- Client + tool registry -------------------------------------------
    llm = VLLMClient(args.url, args.model)
    tools = ToolRegistry()

    # (name, description, parameters, function, steps) for each example tool.
    tool_specs = [
        (
            "search",
            "Search for components by keyword",
            {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
            example_search,
            [1],  # Only available in step 1
        ),
        (
            "get_details",
            "Get configuration details for a component",
            {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
            example_get_details,
            [1, 2],  # Available in both steps
        ),
        (
            "validate",
            "Validate a component configuration",
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "config": {"type": "object"}
                },
                "required": ["name", "config"]
            },
            example_validate,
            [2],  # Only available in step 2
        ),
    ]
    for tool_name, tool_desc, tool_params, tool_fn, tool_steps in tool_specs:
        tools.register(
            name=tool_name,
            description=tool_desc,
            parameters=tool_params,
            function=tool_fn,
            steps=tool_steps,
        )

    # ---- Workflow definition ------------------------------------------------
    workflow = [
        {
            "step_num": 1,
            "name": "Component Discovery",
            "system_prompt": (
                "You are a component selection expert. Use the available tools to find "
                "the right components for the user's request.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"components\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
        },
        {
            "step_num": 2,
            "name": "Component Configuration",
            "system_prompt": (
                "You are a configuration expert. For each component from the previous step, "
                "get its details, configure all required fields, and validate the configuration.\n\n"
                "You MUST call 'validate' for each component before returning.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"configured\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
            "validation_tools": ["validate"],
        },
    ]

    # ---- Run ----------------------------------------------------------------
    print(f"\nQuery: {args.query}")
    final_result = run_workflow(llm, tools, workflow, args.query, args.max_retries)

    banner = "=" * 60
    if final_result:
        print(f"\n{banner}")
        print("WORKFLOW COMPLETE")
        print(f"{banner}")
        print(json.dumps(final_result, indent=2))
    else:
        print(f"\n{banner}")
        print("WORKFLOW FAILED")
        print(f"{banner}")