File size: 6,634 Bytes
a3d65ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#!/usr/bin/env python3
"""
baseline.py — Baseline inference script for the Support Ticket Environment.

Runs an OpenAI-compatible model against all 3 tasks and reports scores.

Usage:
    OPENAI_API_KEY=sk-... python baseline.py --base-url http://localhost:7860

Environment variables:
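    OPENAI_API_KEY   : required for api.openai.com; when unset, the placeholder "not-needed" is sent
    OPENAI_BASE_URL  : optional override (default https://api.openai.com/v1)
    OPENAI_MODEL     : optional model name (default gpt-4o-mini)
    ENV_BASE_URL     : optional default for --base-url (default http://localhost:7860)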
"""

import argparse
import json
import os
import asyncio
import re

from openai import AsyncOpenAI
from support_ticket_env.client import SupportTicketEnv
from support_ticket_env.models import SupportAction

# ─────────────────────────── Config ────────────────────────────

# Category labels; these are the same values SYSTEM_PROMPT offers for the
# "category" field of a classify action.
VALID_CATEGORIES = ["billing", "technical", "account", "general", "refund"]
# Action names; these are the same values SYSTEM_PROMPT offers for the
# "action_type" field.
# NOTE(review): neither list is referenced by the functions in this script;
# presumably kept for reference or for external importers — confirm.
VALID_ACTIONS = ["classify", "reply", "escalate", "close"]

# System message placed at the start of every episode's conversation in
# run_task(). It describes the observation fields, the JSON action schema the
# model must answer with, and a per-task strategy. The "//" annotations inside
# the schema are guidance for the model only; replies are parsed as strict JSON.
SYSTEM_PROMPT = """You are a customer support AI agent operating in a ticket triage environment.

On each turn you receive a JSON observation with:
  - ticket_text : the customer's message
  - feedback    : what happened last step
  - task_id     : 1=classify only, 2=classify then act, 3=full resolution

You must respond with a JSON object (no markdown) matching this schema:
{
  "action_type": "classify" | "reply" | "escalate" | "close",
  "category": "billing" | "technical" | "account" | "general" | "refund",  // only for classify
  "reply_text": "...",  // only for reply
  "reason": "..."       // optional
}

Strategy:
- For task 1: only classify (use action_type="classify" with a category).
- For task 2: first classify, then choose the best action.
- For task 3: classify each ticket, then reply/escalate/close as appropriate.

Always produce valid JSON and nothing else.
"""


def parse_llm_response(text: str) -> dict:
    """Extract the first JSON object from an LLM reply.

    The reply is stripped of surrounding whitespace and of a leading/trailing
    markdown code fence (```json ... ```), then parsed as JSON. If the whole
    reply is not a JSON object, the text is scanned for the first position at
    which a complete JSON object can be decoded, so surrounding prose, trailing
    braces, or several concatenated objects no longer defeat the fallback.

    Args:
        text: Raw assistant message content. ``None`` or an empty string is
            tolerated and reported as a decode error.

    Returns:
        The decoded JSON object as a ``dict``.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from ``text``.
    """
    text = (text or "").strip()
    # Strip ```json ... ``` fences
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```$", "", text)

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as exc:
        # Keep the original error: it is the most informative one to surface
        # if the fallback scan below finds nothing either.
        first_error = exc
    else:
        if isinstance(parsed, dict):
            return parsed
        # Valid JSON, but not an object (e.g. a list or a bare number).
        first_error = json.JSONDecodeError("Expected a JSON object", text, 0)

    # Fallback: try every "{" as a potential start of an object. raw_decode
    # stops at the end of the first complete value, so text after the object
    # (including stray braces) is ignored.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return candidate

    raise first_error


async def run_task(
    env_base_url: str,
    llm: AsyncOpenAI,
    model: str,
    task_id: int,
    seed: int = 42,
    max_steps: int = 10,
) -> float:
    """Play a single episode of ``task_id`` and return its summed reward.

    Each turn serialises the current observation to JSON, appends it to the
    running conversation, asks the model for an action, and forwards that
    action to the environment. The episode ends when the environment reports
    ``done``, when ``max_steps`` turns have been played, or as soon as the
    model's reply cannot be parsed or does not build a ``SupportAction``.

    Args:
        env_base_url: Base URL handed to ``SupportTicketEnv``.
        llm: Async OpenAI-compatible client used for chat completions.
        model: Model name passed to the completion call.
        task_id: Task identifier passed to ``env.reset``.
        seed: Seed passed to ``env.reset``.
        max_steps: Upper bound on the number of turns.

    Returns:
        The sum of step rewards, rounded to 4 decimal places.
    """
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    episode_reward = 0.0

    async with SupportTicketEnv(base_url=env_base_url) as env:
        outcome = await env.reset(task_id=task_id, seed=seed)
        obs = outcome.observation

        for step in range(max_steps):
            # The observation becomes the next user turn.
            payload = {
                "ticket_id": obs.ticket_id,
                "ticket_text": obs.ticket_text,
                "task_id": obs.task_id,
                "current_category": obs.current_category,
                "resolved": obs.resolved,
                "step_count": obs.step_count,
                "feedback": obs.feedback,
            }
            conversation.append(
                {"role": "user", "content": json.dumps(payload, indent=2)}
            )

            # Ask the model for its next action and record the reply verbatim.
            completion = await llm.chat.completions.create(
                model=model,
                messages=conversation,
                temperature=0.0,
                max_tokens=256,
            )
            reply = completion.choices[0].message.content
            conversation.append({"role": "assistant", "content": reply})

            # An unparseable reply ends the episode with the reward so far.
            try:
                action_dict = parse_llm_response(reply)
            except Exception as e:
                print(f"  [step {step+1}] Failed to parse LLM response: {e}")
                break

            # So does a reply that does not build a SupportAction.
            try:
                action = SupportAction(**action_dict)
            except Exception as e:
                print(f"  [step {step+1}] Invalid action schema: {e}")
                break

            # Apply the action and accumulate its reward (None counts as 0).
            outcome = await env.step(action)
            obs = outcome.observation
            reward = outcome.reward or 0.0
            episode_reward += reward

            header = f"  [step {step+1}] action={action.action_type}"
            category_suffix = f"/{action.category}" if action.category else ""
            tail = f"  reward={reward:.3f}  feedback={obs.feedback[:60]}"
            print(header + category_suffix + tail)

            if outcome.done:
                break

    return round(episode_reward, 4)


async def main(env_base_url: str, model: str, seeds: list[int]) -> None:
    """Run every task once per seed and print per-task and overall averages.

    The LLM client is configured from ``OPENAI_API_KEY`` (placeholder
    ``"not-needed"`` when unset) and ``OPENAI_BASE_URL`` (default
    ``https://api.openai.com/v1``). Tasks 1, 2 and 3 are each run with every
    seed in ``seeds``; the overall figure is the unweighted mean of the three
    per-task averages.

    Args:
        env_base_url: Base URL of the running environment server.
        model: Model name forwarded to ``run_task``.
        seeds: Seeds to evaluate each task with; must not be empty.

    Raises:
        ValueError: If ``seeds`` is empty (no average could be computed).
    """
    if not seeds:
        raise ValueError("at least one seed is required")

    api_key = os.environ.get("OPENAI_API_KEY", "not-needed")
    openai_base = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")

    results = {}
    # The context manager closes the client's HTTP connections on exit, even
    # if one of the runs raises.
    async with AsyncOpenAI(api_key=api_key, base_url=openai_base) as llm:
        for task_id in [1, 2, 3]:
            task_scores = []
            print(f"\n{'='*60}")
            # Show every seed that will be run, not just the first one.
            print(f"  TASK {task_id}  (seeds={seeds})")
            print(f"{'='*60}")
            for seed in seeds:
                score = await run_task(env_base_url, llm, model, task_id, seed=seed)
                task_scores.append(score)
                print(f"  → total_reward for seed {seed}: {score}")
            avg = round(sum(task_scores) / len(task_scores), 4)
            results[f"task{task_id}"] = {"scores": task_scores, "avg": avg}
            print(f"  ► Average: {avg}")

    print("\n" + "="*60)
    print("  BASELINE SUMMARY")
    print("="*60)
    for k, v in results.items():
        print(f"  {k}: avg={v['avg']:.4f}  scores={v['scores']}")
    overall = round(
        sum(v["avg"] for v in results.values()) / len(results), 4
    )
    print(f"\n  Overall avg: {overall:.4f}")
    print("="*60)


if __name__ == "__main__":
    # Command-line entry point: collect the options, then run the async driver.
    cli = argparse.ArgumentParser(description="Baseline inference for support_ticket_env")

    # (flag, add_argument keyword arguments), registered in display order.
    option_specs = (
        (
            "--base-url",
            {
                "default": os.environ.get("ENV_BASE_URL", "http://localhost:7860"),
                "help": "Base URL of the running environment server",
            },
        ),
        (
            "--model",
            {
                "default": os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
                "help": "OpenAI model name",
            },
        ),
        (
            "--seeds",
            {
                "nargs": "+",
                "type": int,
                "default": [42, 7, 123],
                "help": "Random seeds for reproducibility",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    asyncio.run(main(opts.base_url, opts.model, opts.seeds))