File size: 16,373 Bytes
634c038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
#!/usr/bin/env python3
"""
Multi-Step Tool Calling Orchestrator
=====================================

A production-ready orchestrator for multi-step LLM tool calling workflows.
Handles step isolation, Pydantic validation, retry logic, and error feedback.

This is the architecture pattern that makes complex tool calling reliable:
- Each step has its own isolated set of tools
- LLM responses are validated against Pydantic schemas
- Failed validations are fed back to the LLM with structured error messages
- Validation tracks whether tools PASSED, not just whether they were CALLED

Usage:
    python multi_step_orchestrator.py --url http://localhost:8000 --model NousResearch/Hermes-3-Llama-3.1-70B-FP8

Example workflow (generic):
    Step 1: Discover what components are needed (search, list, get_info tools)
    Step 2: Configure each component (get_details, validate tools)
"""

import argparse
import json
import requests
from typing import Any, Callable, Dict, List, Optional
from pydantic import BaseModel, ValidationError

from robust_json_extraction import extract_json, extract_tool_calls
from pydantic_tool_schemas import (
    FunctionCall, ToolCall, StepResponse, ValidationTracker, make_tool_schema
)


# ============================================================================
# VLLM Client
# ============================================================================

class VLLMClient:
    """Simple client for vLLM's OpenAI-compatible chat completions API.

    ``chat`` always returns a plain string. If the server returns structured
    ``tool_calls`` (e.g. because a server-side tool-call parser handled
    ``tool_choice="auto"``), they are rendered back into Hermes-style
    ``<tool_call>`` blocks so text-based extraction downstream still sees them
    instead of receiving an empty/``None`` content.
    """

    def __init__(self, base_url: str, model: str, timeout: float = 120):
        """
        Args:
            base_url: Server root, e.g. ``http://localhost:8000``; a trailing
                slash is stripped.
            model: Model name sent with every request.
            timeout: Seconds to wait for each HTTP request (default 120).
        """
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.timeout = timeout

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None,
             temperature: float = 0.1, max_tokens: int = 4096) -> str:
        """Send a chat completion and return the assistant's reply as text.

        Args:
            messages: OpenAI-style chat messages.
            tools: OpenAI-format tool schemas; when non-empty they are sent
                with ``tool_choice="auto"``.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.

        Returns:
            The message content followed by any structured tool calls rendered
            as ``<tool_call>`` blocks; ``""`` if the reply has neither.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if tools:
            payload["tools"] = tools
            payload["tool_choice"] = "auto"

        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            json=payload,
            timeout=self.timeout
        )
        response.raise_for_status()
        return self._message_text(response.json()["choices"][0]["message"])

    @staticmethod
    def _message_text(message: Dict) -> str:
        """Flatten an OpenAI-style assistant message into one string."""
        parts = []
        # `content` may be present but null (e.g. when the reply is only tool
        # calls), so `.get("content", "")` is not enough.
        content = message.get("content")
        if content:
            parts.append(content)
        for call in message.get("tool_calls") or []:
            function = call.get("function") or {}
            arguments = function.get("arguments")
            # OpenAI-style arguments arrive as a JSON-encoded string; Hermes
            # blocks carry them as an object.
            if isinstance(arguments, str):
                try:
                    arguments = json.loads(arguments) if arguments.strip() else {}
                except json.JSONDecodeError:
                    pass  # keep the raw string; the tool layer reports the problem
            body = json.dumps({
                "name": function.get("name"),
                "arguments": arguments if arguments is not None else {},
            })
            parts.append(f"<tool_call>\n{body}\n</tool_call>")
        return "\n".join(parts)


# ============================================================================
# Tool Registry
# ============================================================================

class ToolRegistry:
    """Registry of callable tools, each optionally restricted to specific steps."""

    def __init__(self):
        # name -> {"schema": tool schema from make_tool_schema,
        #          "function": the callable,
        #          "steps": list of step numbers, or None for "all steps"}
        self._tools: Dict[str, Dict[str, Any]] = {}

    def register(self, name: str, description: str, parameters: dict,
                 function: Callable, steps: Optional[List[int]] = None):
        """Register a tool available in specific steps (or all steps if None).

        Re-registering an existing name replaces the previous entry.
        """
        self._tools[name] = {
            "schema": make_tool_schema(name, description, parameters),
            "function": function,
            "steps": steps,  # None means all steps
        }

    def is_available(self, name: str, step: Optional[int] = None) -> bool:
        """Return True if tool *name* is registered and offered for *step*.

        A tool registered with ``steps=None`` is available everywhere; a
        step-restricted tool is only available when *step* is in its list.
        """
        tool = self._tools.get(name)
        if tool is None:
            return False
        return tool["steps"] is None or (step is not None and step in tool["steps"])

    def get_schemas(self, step: Optional[int] = None) -> List[Dict]:
        """Get OpenAI-format tool schemas for a given step, in registration order."""
        return [tool["schema"] for name, tool in self._tools.items()
                if self.is_available(name, step)]

    def execute(self, name: str, arguments: Any, step: Optional[int] = None) -> dict:
        """Execute a tool by name and return its result.

        Args:
            name: Registered tool name.
            arguments: Keyword arguments as a dict, a JSON-encoded object
                string (OpenAI style), or None/"" for no arguments.
            step: If given, refuse tools that are not available in this step
                (enforces per-step tool isolation). None skips the check.

        Returns:
            The tool's return value, or ``{"error": ...}`` describing why the
            call could not be made or raised. Errors are returned rather than
            raised so they can be fed back to the LLM.
        """
        if name not in self._tools:
            return {"error": f"Unknown tool: {name}"}
        if step is not None and not self.is_available(name, step):
            return {"error": f"Tool {name} is not available in step {step}"}

        if arguments is None:
            arguments = {}
        elif isinstance(arguments, str):
            try:
                arguments = json.loads(arguments) if arguments.strip() else {}
            except json.JSONDecodeError as e:
                return {"error": f"Invalid JSON arguments for {name}: {e}"}
        if not isinstance(arguments, dict):
            return {"error": (f"Arguments for {name} must be a JSON object, "
                              f"got {type(arguments).__name__}")}

        try:
            return self._tools[name]["function"](**arguments)
        except Exception as e:  # boundary: report any tool failure to the LLM
            return {"error": f"Error executing {name}: {str(e)}"}


# ============================================================================
# Step Runner
# ============================================================================

def _normalize_tool_call(tc: Any):
    """Return ``(name, arguments)`` for a tool call in either accepted shape.

    Accepts the flat ``{"name": ..., "arguments": {...}}`` shape and the
    OpenAI-style ``{"function": {"name": ..., "arguments": ...}}`` shape.
    ``arguments`` may be a dict or a JSON-encoded string; anything that does
    not decode to a dict becomes ``{}``. ``name`` is None when absent.
    """
    if not isinstance(tc, dict):
        return None, {}
    payload = tc["function"] if isinstance(tc.get("function"), dict) else tc
    name = payload.get("name")
    arguments = payload.get("arguments") or {}
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}
    if not isinstance(arguments, dict):
        arguments = {}
    return name, arguments


def _format_tool_response(name: str, result: Any) -> str:
    """Render one tool result as the <tool_response> feedback sent to the LLM."""
    # Only a dict with an "error" key counts as a failure; a substring match on
    # a string/list result would misreport success as ERROR.
    failed = isinstance(result, dict) and "error" in result
    response_text = (
        f"<tool_response>\n"
        f"<tool_name>{name}</tool_name>\n"
        f"<status>{'ERROR' if failed else 'SUCCESS'}</status>\n"
        # default=str: this is feedback text only, so never crash on an
        # unserializable tool result.
        f"<result>{json.dumps(result, indent=2, default=str)}</result>\n"
        f"</tool_response>"
    )
    if failed:
        response_text += (
            "\nIMPORTANT: This tool call failed. "
            "Read the error, understand the issue, fix your parameters, and retry."
        )
    return response_text


def _execute_tool_calls(
    registry: ToolRegistry,
    tool_calls: List[Any],
    tracker: Any,
    validation_tools: Optional[List[str]],
) -> List[Dict]:
    """Run each tool call, record validation results, and return feedback messages."""
    tracked = set(validation_tools or ())
    feedback = []
    for tc in tool_calls:
        name, arguments = _normalize_tool_call(tc)
        if not name:
            name = "<missing>"
            result = {"error": ("Tool call is missing a tool name: "
                                f"{json.dumps(tc, default=str)}")}
        else:
            print(f"    -> {name}({json.dumps(arguments, default=str)})")
            result = registry.execute(name, arguments)
            if tracker is not None and name in tracked:
                tracker.record_call(name, result)
        feedback.append({"role": "user", "content": _format_tool_response(name, result)})
    return feedback


def _parse_final_response(completion: str, schema_class: type):
    """Return ``(data, None)`` if *completion* is a valid response, else ``(None, problem)``."""
    try:
        data = extract_json(completion)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        return None, f"Your response could not be parsed as valid JSON. Error: {e}"
    if not isinstance(data, dict):
        return None, f"Your response must be a single JSON object, not {type(data).__name__}."
    try:
        schema_class(**data)  # validate structure only; the raw dict is returned
    except (ValidationError, TypeError) as e:
        return None, f"Your response did not match the required schema. Error: {e}"
    return data, None


def run_step(
    client: VLLMClient,
    registry: ToolRegistry,
    step_num: int,
    system_prompt: str,
    initial_context: str,
    schema_class: type = StepResponse,
    max_iterations: int = 10,
    validation_tools: Optional[List[str]] = None,
) -> Optional[Dict]:
    """
    Run a single step of a multi-step workflow.

    Each iteration asks the model for a reply and handles it as one of:
    Hermes ``<tool_call>`` blocks, a JSON ``{"tool_calls": [...]}`` object, or
    a final JSON response. Tool results and parse/validation problems are fed
    back as user messages so the model can correct itself.

    Args:
        client: VLLM client
        registry: Tool registry
        step_num: Step number (for tool filtering)
        system_prompt: System prompt for this step
        initial_context: User message / context from previous step
        schema_class: Class used to validate final responses by calling it
            with the parsed JSON object as keyword arguments
        max_iterations: Max LLM turns before giving up
        validation_tools: Names of tools that must be called AND pass before
            a final response is accepted

    Returns:
        Parsed response dict, or None if the step did not finish within
        max_iterations.
    """
    tool_schemas = registry.get_schemas(step=step_num)
    tracker = ValidationTracker(validation_tools) if validation_tools else None

    messages: List[Dict] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": initial_context},
    ]

    for iteration in range(1, max_iterations + 1):
        print(f"\n  [Step {step_num}] Iteration {iteration}/{max_iterations}")

        completion = client.chat(messages, tools=tool_schemas if tool_schemas else None)

        if not completion:
            print(f"  [Step {step_num}] Empty response")
            messages.append({"role": "assistant", "content": ""})
            messages.append({
                "role": "user",
                "content": "Your response was empty. Please provide a valid JSON response."
            })
            continue

        # Every non-empty completion goes into the transcript; on success we
        # return immediately, so recording it there too is harmless.
        messages.append({"role": "assistant", "content": completion})

        # Hermes XML tool calls take precedence over a JSON reply.
        tool_call_list = extract_tool_calls(completion)
        if tool_call_list:
            print(f"  [Step {step_num}] Found {len(tool_call_list)} tool call(s)")
            messages.extend(
                _execute_tool_calls(registry, tool_call_list, tracker, validation_tools)
            )
            continue

        result_data, problem = _parse_final_response(completion, schema_class)
        if problem is not None:
            print(f"  [Step {step_num}] Parse/validation error: {problem}")
            messages.append({
                "role": "user",
                "content": (
                    f"{problem}\n"
                    f"Please respond with ONLY valid JSON matching the required schema. "
                    f"No markdown, no explanatory text."
                )
            })
            continue

        # A JSON object carrying "tool_calls" is a tool request, not a final answer.
        json_tool_calls = result_data.get("tool_calls")
        if json_tool_calls:
            if not isinstance(json_tool_calls, list):
                json_tool_calls = [json_tool_calls]
            print(f"  [Step {step_num}] Found {len(json_tool_calls)} JSON tool call(s)")
            messages.extend(
                _execute_tool_calls(registry, json_tool_calls, tracker, validation_tools)
            )
            continue

        # Final response: refuse it until every required validation tool has passed.
        if tracker is not None and not tracker.all_passed():
            error_feedback = tracker.format_errors()
            enforcement_msg = (
                f"You returned a final response but validations have not all passed.\n"
                f"{error_feedback}\n"
                f"Please fix the errors and call the validation tools again."
            )
            messages.append({"role": "user", "content": enforcement_msg})
            continue

        print(f"  [Step {step_num}] Final response received and validated")
        return result_data

    print(f"  [Step {step_num}] Exhausted {max_iterations} iterations")
    return None


# ============================================================================
# Multi-Step Workflow Runner
# ============================================================================

def run_workflow(
    client: VLLMClient,
    registry: ToolRegistry,
    steps: List[Dict],
    initial_query: str,
    max_step_retries: int = 3,
) -> Optional[Dict]:
    """
    Run a multi-step workflow, feeding each step's result into the next.

    Args:
        client: VLLM client
        registry: Tool registry
        steps: List of step configs, each with:
            - step_num: int
            - system_prompt: str
            - name: Optional[str] (for logging)
            - max_iterations: Optional[int] (default 10)
            - validation_tools: Optional[List[str]]
            - schema_class: Optional[type] (forwarded to run_step when present)
        initial_query: User's original request (context for the first step)
        max_step_retries: Total number of attempts per step (including the
            first); a value below 1 makes the first step fail immediately.

    Returns:
        The last step's result dict, or None if any step failed. With an
        empty ``steps`` list, ``initial_query`` is returned unchanged.
    """
    previous_result = initial_query

    for step_config in steps:
        step_num = step_config["step_num"]
        print(f"\n{'='*60}")
        print(f"STEP {step_num}: {step_config.get('name', 'Unnamed')}")
        print(f"{'='*60}")

        # Identical for every attempt of this step, so build it once.
        context = previous_result if isinstance(previous_result, str) else json.dumps(previous_result)

        # Only forward schema_class when configured so run_step keeps its default.
        extra_kwargs = {}
        if "schema_class" in step_config:
            extra_kwargs["schema_class"] = step_config["schema_class"]

        for attempt in range(1, max_step_retries + 1):
            if attempt > 1:
                print(f"\n  Retry {attempt - 1}/{max_step_retries - 1}")

            result = run_step(
                client=client,
                registry=registry,
                step_num=step_num,
                system_prompt=step_config["system_prompt"],
                initial_context=context,
                max_iterations=step_config.get("max_iterations", 10),
                validation_tools=step_config.get("validation_tools"),
                **extra_kwargs,
            )

            # run_step signals failure with None; an empty dict is still a result.
            if result is not None:
                previous_result = result
                break
        else:
            print(f"\n  Step {step_num} failed after {max_step_retries} attempt(s)")
            return None

    return previous_result


# ============================================================================
# Example: Generic Two-Step Workflow
# ============================================================================

def example_search(query: str) -> dict:
    """Stand-in search tool: always reports a single synthetic hit for *query*.

    Replace the body with a call to a real search backend.
    """
    hit = {
        "name": f"result_for_{query}",
        "type": "example",
        "description": f"Found match for '{query}'",
    }
    return {"results": [hit]}


def example_get_details(name: str) -> dict:
    """Stand-in details tool: returns a fixed configuration spec for *name*.

    Replace the body with a lookup against a real component catalogue.
    """
    sample_config = {"field_a": "value1", "field_b": "value2"}
    details = {"name": name}
    details["required_fields"] = ["field_a", "field_b"]
    details["version"] = 2
    details["examples"] = [sample_config]
    return details


def example_validate(name: str, config: dict) -> dict:
    """Example validation tool (replace with real implementation).

    Checks that *config* is an object containing ``field_a`` and ``field_b``.
    *name* is accepted for interface parity with a real validator but unused.

    Returns:
        ``{"valid": bool, "errors": [{"property": ..., "message": ...}, ...]}``
    """
    # Arguments come from the LLM, so config may not be a dict at all. A
    # membership test on a string would be a substring check and could pass
    # validation falsely (e.g. config="field_a field_b").
    if not isinstance(config, dict):
        return {
            "valid": False,
            "errors": [{"property": "config",
                        "message": f"Expected an object, got {type(config).__name__}"}],
        }
    errors = [
        {"property": field, "message": "Required field missing"}
        for field in ("field_a", "field_b")
        if field not in config
    ]
    return {"valid": not errors, "errors": errors}


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-step tool calling orchestrator")
    parser.add_argument("--url", default="http://localhost:8000", help="VLLM server URL")
    parser.add_argument("--model", default="NousResearch/Hermes-3-Llama-3.1-70B-FP8")
    parser.add_argument("--query", default="Find and configure components for a data processing pipeline")
    parser.add_argument("--max-iterations", type=int, default=10)
    parser.add_argument("--max-retries", type=int, default=3)
    args = parser.parse_args()

    # Values below 1 make every step fail without ever calling the model
    # (the iteration / attempt loops would run zero times).
    if args.max_iterations < 1:
        parser.error("--max-iterations must be at least 1")
    if args.max_retries < 1:
        parser.error("--max-retries must be at least 1")

    def _print_banner(title: str) -> None:
        """Print *title* between two 60-character '=' rules."""
        print(f"\n{'='*60}")
        print(title)
        print(f"{'='*60}")

    # Set up client
    client = VLLMClient(args.url, args.model)

    # Set up tool registry
    registry = ToolRegistry()
    registry.register(
        name="search",
        description="Search for components by keyword",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
        function=example_search,
        steps=[1],  # Only available in step 1
    )
    registry.register(
        name="get_details",
        description="Get configuration details for a component",
        parameters={"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
        function=example_get_details,
        steps=[1, 2],  # Available in both steps
    )
    registry.register(
        name="validate",
        description="Validate a component configuration",
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "config": {"type": "object"}
            },
            "required": ["name", "config"]
        },
        function=example_validate,
        steps=[2],  # Only available in step 2
    )

    # Define workflow steps
    steps = [
        {
            "step_num": 1,
            "name": "Component Discovery",
            "system_prompt": (
                "You are a component selection expert. Use the available tools to find "
                "the right components for the user's request.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"components\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
        },
        {
            "step_num": 2,
            "name": "Component Configuration",
            "system_prompt": (
                "You are a configuration expert. For each component from the previous step, "
                "get its details, configure all required fields, and validate the configuration.\n\n"
                "You MUST call 'validate' for each component before returning.\n\n"
                "Respond with JSON: either {\"tool_calls\": [...]} to call tools, "
                "or {\"success\": true, \"result\": {\"configured\": [...]}, \"reasoning\": \"...\"} "
                "when done.\n\n"
                "Do NOT wrap JSON in markdown. Do NOT add explanatory text."
            ),
            "max_iterations": args.max_iterations,
            "validation_tools": ["validate"],
        },
    ]

    # Run workflow
    print(f"\nQuery: {args.query}")
    result = run_workflow(client, registry, steps, args.query, args.max_retries)

    # run_workflow reports failure with None; any other value is a result.
    if result is not None:
        _print_banner("WORKFLOW COMPLETE")
        print(json.dumps(result, indent=2))
    else:
        _print_banner("WORKFLOW FAILED")