File size: 13,206 Bytes
8cad0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
"""Offline data collection: run Qwen 3.5-4B against EnterpriseSim environment.

Collects episode trajectories and formats them as a GRPO-compatible HuggingFace Dataset.

Prerequisites:
  - vLLM serving Qwen 3.5-4B on localhost:8001
  - OPENAI_API_KEY set (for customer agent LLM responses)

Usage:
  python scripts/collect_data.py --vllm-url http://localhost:8001/v1
"""

import argparse
import json
import re
import random
import sys
import time
from pathlib import Path

from openai import OpenAI
from datasets import Dataset

# Add parent to path so we can import server modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from server.environment import CustomerSupportEnvironment, SupportObservation
from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction

# Repository data directory (script lives in scripts/, data/ is a sibling).
DATA_DIR = Path(__file__).resolve().parent.parent / "data"

# --- Tool call parsing (Qwen 3.5 XML format) ---

# Matches one <tool_call><function=NAME>...</function></tool_call> envelope:
# group(1) = tool name, group(2) = raw parameter XML. DOTALL lets parameter
# values span multiple lines.
TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*<function=(\w+)>(.*?)</function>\s*</tool_call>", re.DOTALL
)
# Matches each <parameter=NAME>value</parameter> inside the function body.
PARAM_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)


def parse_tool_call(text: str) -> tuple[str | None, dict | None]:
    """Extract tool name and arguments from Qwen XML tool call format."""
    match = TOOL_CALL_RE.search(text)
    if not match:
        return None, None
    tool_name = match.group(1).strip()
    args = {}
    for pm in PARAM_RE.finditer(match.group(2)):
        key = pm.group(1).strip()
        val = pm.group(2).strip()
        if key == "ticket_id":
            try:
                val = int(val)
            except ValueError:
                pass
        args[key] = val
    return tool_name, args


# --- Prompt engineering ---


def format_tools(tools) -> str:
    """Render the MCP tool list as readable text for the system prompt.

    Each tool becomes a ``### name`` heading, its description, and — when
    the JSON schema declares properties — an indented parameter list with
    type, required/optional marker, and description.
    """
    out: list[str] = []
    for tool in tools:
        out.append(f"### {tool.name}")
        out.append(f"{tool.description}")
        schema = tool.input_schema
        properties = schema.get("properties", {})
        required_names = set(schema.get("required", []))
        if properties:
            out.append("Parameters:")
            for pname, info in properties.items():
                marker = " (required)" if pname in required_names else " (optional)"
                ptype = info.get("type", "string")
                desc = info.get("description", "")
                out.append(f"  - {pname} ({ptype}{marker}): {desc}")
        out.append("")  # blank line between tools
    return "\n".join(out)


def build_system_prompt(env: CustomerSupportEnvironment) -> str:
    """Build the agent system prompt with tools + work context.

    Reads the handbook, escalation policy, and product catalog markdown
    files from ``DATA_DIR/work_context`` and embeds them — together with
    the environment's tool schemas — into one system-prompt string that
    also specifies the exact XML tool-call format the parser expects.

    Args:
        env: In-process environment used only to enumerate tool schemas.

    Returns:
        The complete system prompt text.
    """
    handbook = (DATA_DIR / "work_context/handbook.md").read_text()
    escalation = (DATA_DIR / "work_context/escalation_policy.md").read_text()
    catalog = (DATA_DIR / "work_context/product_catalog.md").read_text()

    # Get tool schemas from the MCP environment.
    # NOTE(review): _handle_list_tools is a private env method — confirm it
    # is stable, or prefer a public ListToolsAction round-trip if available.
    tools_obs = env._handle_list_tools()
    tool_text = format_tools(tools_obs.tools)

    # The XML format shown below must stay in sync with TOOL_CALL_RE /
    # PARAM_RE in this file, which parse the model's responses.
    return f"""You are a Customer Support Representative at Office Furniture Co. Help customers by investigating their issues and providing concrete solutions.

## Available Tools

{tool_text}

## Company Policies

{handbook}

## Escalation Policy

{escalation}

## Product Catalog

{catalog}

## How to Respond

Use EXACTLY this XML format for tool calls:
<tool_call>
<function=TOOL_NAME>
<parameter=PARAM_NAME>value</parameter>
</function>
</tool_call>

Strategy:
1. Look up the customer profile first
2. Check their order details
3. Send a helpful reply with a concrete solution
4. Resolve the ticket when the issue is addressed

Always investigate before replying. Be professional and empathetic."""


def format_initial_obs(obs: SupportObservation) -> str:
    """Render the reset observation as the opening user message."""
    lines = [
        "New support ticket received.",
        "",
        obs.ticket_context,
        "",
        "Customer message:",
        obs.customer_message,
        "",
        "What tool would you like to use to help this customer?",
    ]
    return "\n".join(lines)


def format_step_obs(obs: SupportObservation) -> str:
    """Render a post-tool-call observation as the next user message.

    Includes the tool result and/or customer reply when present, then a
    satisfaction/step-count status line and a prompt for the next action.
    """
    sections: list[str] = []

    if obs.tool_name:
        body = obs.tool_result if obs.tool_result else "(no result)"
        sections += [f'Tool "{obs.tool_name}" result:', body, ""]

    if obs.customer_message:
        sections += ["Customer responded:", obs.customer_message, ""]

    sections += [
        f"Satisfaction: {obs.satisfaction:.0%} | Steps: {obs.step_count}/10",
        "",
        "What would you like to do next?",
    ]
    return "\n".join(sections)


# --- Episode runner ---


def run_episode(env, generate_fn, system_prompt, task_id=None, seed=None):
    """Run one full episode, returning list of step records.

    Args:
        env: CustomerSupportEnvironment instance (reset per episode, reused
            across episodes).
        generate_fn: Callable mapping a chat ``messages`` list to the
            trainee model's text completion.
        system_prompt: System message placed at the head of the dialogue.
        task_id: Optional task identifier forwarded to ``env.reset``.
        seed: Optional RNG seed forwarded to ``env.reset``.

    Returns:
        List of per-step dicts (prompt snapshot, completion, tool call +
        result, reward/satisfaction fields). Every step is backfilled with
        the final ``episode_reward``/``episode_resolved``/``task_id``.
    """
    # Only pass reset kwargs that were actually provided, so env defaults apply.
    reset_kwargs = {}
    if seed is not None:
        reset_kwargs["seed"] = seed
    if task_id:
        reset_kwargs["task_id"] = task_id

    obs = env.reset(**reset_kwargs)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": format_initial_obs(obs)},
    ]

    steps = []
    ticket_id = obs.ticket_id

    while not obs.done:
        # Snapshot the prompt at this decision point: shallow-copy each
        # message dict so later appends/edits to `messages` cannot mutate
        # the stored training prompts.
        prompt_snapshot = [dict(m) for m in messages]

        # Generate with the trainee model
        try:
            response = generate_fn(messages)
        except Exception as e:
            print(f"    Generation error: {e}")
            break

        # Parse tool call from response
        tool_name, tool_args = parse_tool_call(response)

        if tool_name is None:
            # Fallback: treat raw text as a send_reply message
            # (truncated to 500 chars to bound the customer-facing reply).
            tool_name = "send_reply"
            tool_args = {"ticket_id": ticket_id, "message": response[:500]}

        # Ensure ticket_id is set for tools that need it
        if tool_name in ("send_reply", "update_ticket") and "ticket_id" not in tool_args:
            tool_args["ticket_id"] = ticket_id

        # Execute in environment
        action = CallToolAction(tool_name=tool_name, arguments=tool_args)
        try:
            obs = env.step(action)
        except Exception as e:
            # Record the failed step (marked with "error") and end the
            # episode; format_grpo_dataset filters these records out later.
            print(f"    Step error: {e}")
            steps.append({
                "prompt": prompt_snapshot,
                "completion": response,
                "error": str(e),
            })
            break

        # getattr defaults guard against observation variants that omit
        # optional fields.
        steps.append({
            "prompt": prompt_snapshot,
            "completion": response,
            "tool_name": tool_name,
            "tool_args": tool_args,
            "tool_result": getattr(obs, "tool_result", ""),
            "customer_message": getattr(obs, "customer_message", ""),
            "satisfaction": getattr(obs, "satisfaction", 0.0),
            "satisfaction_delta": getattr(obs, "satisfaction_delta", 0.0),
            "done": obs.done,
            "reward": obs.reward,
            "resolved": getattr(obs, "resolved", False),
            "step_count": getattr(obs, "step_count", 0),
        })

        # Extend conversation for next turn
        messages.append({"role": "assistant", "content": response})
        if not obs.done:
            messages.append({"role": "user", "content": format_step_obs(obs)})

    # Backfill final episode reward to all steps so every training example
    # carries the episode-level outcome.
    final_reward = obs.reward if hasattr(obs, "reward") else 0.0
    resolved = getattr(obs, "resolved", False)
    for step in steps:
        step["episode_reward"] = final_reward
        step["episode_resolved"] = resolved
        step["task_id"] = task_id

    return steps


# --- Dataset formatting ---


def load_tasks(tasks_dir: Path) -> dict:
    """Load every ``task_*.json`` file in *tasks_dir*, keyed by its "id"."""
    loaded: dict = {}
    # Sorted for a deterministic load order across filesystems.
    for task_file in sorted(tasks_dir.glob("task_*.json")):
        task = json.loads(task_file.read_text())
        loaded[task["id"]] = task
    return loaded


def format_grpo_dataset(all_steps, tasks):
    """Convert collected step records into a GRPO training Dataset.

    Each surviving step becomes one record: the message-list ``prompt``
    plus a JSON ``answer`` blob carrying the episode outcome and task
    metadata. Steps recorded with an "error" key (failed env.step) are
    dropped.
    """
    records = []
    for step in all_steps:
        if "error" in step:
            # Failed steps carry no outcome fields — skip them.
            continue

        task_id = step.get("task_id")
        task_data = tasks.get(task_id) if task_id else None
        truths = []
        if task_data is not None:
            for criterion in task_data.get("rubric", []):
                gt = criterion.get("ground_truth")
                if gt:
                    truths.append(gt)

        payload = {
            "episode_reward": step["episode_reward"],
            "resolved": step["episode_resolved"],
            "task_id": task_id,
            "ground_truth_values": truths,
            "valid_tools": ["lookup_customer", "check_order", "send_reply", "update_ticket"],
        }
        records.append({"prompt": step["prompt"], "answer": json.dumps(payload)})

    return Dataset.from_list(records)


# --- Main ---


def main():
    """CLI entry point: collect episodes and write the GRPO dataset.

    Pipeline: connect to vLLM, build the in-process environment and system
    prompt, run ``--runs-per-task`` rollouts per loaded task plus
    ``--random-episodes`` un-tasked rollouts, then save the formatted
    dataset (HF disk format + JSONL) and raw step records under
    ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Collect offline RL training data")
    parser.add_argument("--vllm-url", default="http://localhost:8001/v1", help="vLLM API URL")
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B", help="Model name for vLLM")
    parser.add_argument("--runs-per-task", type=int, default=8, help="Rollouts per task")
    parser.add_argument("--random-episodes", type=int, default=16, help="Random episodes (no task_id)")
    parser.add_argument("--output-dir", default="./data/trajectories", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Base random seed")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # vLLM client (local server ignores the API key, but the SDK requires one)
    client = OpenAI(base_url=args.vllm_url, api_key="none")

    def generate_fn(messages):
        # Sampling at temperature 0.7 — presumably so repeated rollouts of
        # the same task differ (GRPO groups need variance); confirm intent.
        resp = client.chat.completions.create(
            model=args.model,
            messages=messages,
            temperature=0.7,
            max_tokens=512,
        )
        return resp.choices[0].message.content or ""

    # Verify vLLM is reachable
    try:
        models = client.models.list()
        print(f"Connected to vLLM. Available models: {[m.id for m in models.data]}")
    except Exception as e:
        print(f"ERROR: Cannot connect to vLLM at {args.vllm_url}: {e}")
        print("Start vLLM first: python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen3.5-4B --port 8001")
        sys.exit(1)

    # Environment (in-process)
    print("Initializing environment...")
    env = CustomerSupportEnvironment()
    system_prompt = build_system_prompt(env)
    print(f"System prompt: {len(system_prompt)} chars")

    # Load tasks
    tasks = load_tasks(DATA_DIR / "tasks")
    print(f"Loaded {len(tasks)} tasks: {list(tasks.keys())}")

    # Collect episodes
    all_steps = []
    episode_count = 0
    total_episodes = len(tasks) * args.runs_per_task + args.random_episodes

    print(f"\n=== Collecting {total_episodes} episodes ===\n")

    # Task-based episodes: the same seed sequence is reused for every task,
    # so runs are comparable across tasks.
    for task_id in tasks:
        for run_idx in range(args.runs_per_task):
            seed = args.seed + run_idx
            episode_count += 1
            print(f"[{episode_count}/{total_episodes}] {task_id} (run {run_idx + 1})...", end=" ")

            t0 = time.time()
            steps = run_episode(env, generate_fn, system_prompt, task_id=task_id, seed=seed)
            elapsed = time.time() - t0

            all_steps.extend(steps)
            reward = steps[-1]["episode_reward"] if steps else 0.0
            resolved = steps[-1].get("episode_resolved", False) if steps else False
            print(f"{len(steps)} steps, reward={reward:.3f}, resolved={resolved}, {elapsed:.1f}s")

    # Random episodes (no task_id); seeds offset by 1000 to avoid
    # overlapping with the task-run seeds above.
    for i in range(args.random_episodes):
        seed = args.seed + 1000 + i
        episode_count += 1
        print(f"[{episode_count}/{total_episodes}] random (seed={seed})...", end=" ")

        t0 = time.time()
        steps = run_episode(env, generate_fn, system_prompt, seed=seed)
        elapsed = time.time() - t0

        all_steps.extend(steps)
        reward = steps[-1]["episode_reward"] if steps else 0.0
        resolved = steps[-1].get("episode_resolved", False) if steps else False
        print(f"{len(steps)} steps, reward={reward:.3f}, resolved={resolved}, {elapsed:.1f}s")

    # Format and save
    print(f"\n=== Formatting {len(all_steps)} steps as GRPO dataset ===")
    dataset = format_grpo_dataset(all_steps, tasks)

    # Save both HF on-disk format (for training) and JSONL (for inspection).
    dataset.save_to_disk(str(output_dir / "grpo_dataset"))
    dataset.to_json(str(output_dir / "grpo_dataset.jsonl"))

    # Save raw episodes for debugging (default=str handles non-JSON types)
    with open(output_dir / "episodes_raw.json", "w") as f:
        json.dump(all_steps, f, indent=2, default=str)

    # Summary stats
    rewards = [s["episode_reward"] for s in all_steps if "episode_reward" in s]
    resolved_count = sum(1 for s in all_steps if s.get("episode_resolved"))
    print(f"\nDone!")
    print(f"  Total training examples: {len(dataset)}")
    print(f"  Episodes: {total_episodes}")
    print(f"  Avg episode reward: {sum(rewards) / len(rewards):.3f}" if rewards else "  No rewards")
    print(f"  Steps with resolution: {resolved_count}/{len(all_steps)}")
    print(f"  Saved to: {output_dir}")

    env.close()


if __name__ == "__main__":
    main()