File size: 16,066 Bytes
841976b
 
bd67155
 
735d73f
bd67155
735d73f
bd67155
 
 
 
 
 
 
841976b
bd67155
 
 
 
 
 
841976b
bd67155
735d73f
bd67155
 
735d73f
 
 
a2abcaa
 
bd67155
735d73f
 
 
 
bd67155
 
 
735d73f
841976b
bd67155
 
 
 
 
 
 
 
 
841976b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735d73f
841976b
 
 
 
 
735d73f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
841976b
 
 
 
 
 
 
 
735d73f
841976b
 
735d73f
bd67155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735d73f
 
 
 
 
 
 
bd67155
841976b
735d73f
 
 
 
 
 
 
 
 
 
bd67155
 
 
 
 
 
841976b
735d73f
 
 
841976b
 
 
bd67155
 
841976b
735d73f
bd67155
 
 
 
735d73f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd67155
735d73f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd67155
 
a2abcaa
ebc798b
 
 
 
 
bd67155
 
a2abcaa
 
 
 
 
 
 
 
bd67155
a2abcaa
 
 
 
 
 
 
 
 
bd67155
841976b
735d73f
 
 
bd67155
 
 
 
 
 
 
 
 
 
735d73f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd67155
841976b
bd67155
735d73f
 
 
 
 
bd67155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2abcaa
bd67155
 
 
 
a2abcaa
 
 
 
 
 
841976b
a2abcaa
 
 
 
 
 
 
 
 
 
 
 
 
 
bd67155
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
from dotenv import load_dotenv
load_dotenv()
import json
import os
import re
import textwrap
import time
from typing import List, Optional

from openai import OpenAI

from support_ops_env.env import SupportOpsEnv
from support_ops_env.models import Action, Observation
from support_ops_env.tasks import list_task_ids

# Optional docker image override for local environment runs (None when unset).
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
# First non-empty credential wins: HF router token, then OpenAI key, then generic.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
# OpenAI-compatible endpoint; defaults to the Hugging Face inference router.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Primary task to run; select_tasks() pads this out to MIN_TASKS entries.
TASK_NAME = os.getenv("SUPPORT_OPS_TASK", "easy_account_takeover")
BENCHMARK = os.getenv("SUPPORT_OPS_BENCHMARK", "support_ops_env")
# Hard cap on env steps per episode (run_task's loop bound).
MAX_STEPS = int(os.getenv("MAX_STEPS", "24"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.1"))
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "4096"))  # reasoning models need budget for <think> blocks
# An episode counts as a success when its (clamped) score reaches this value.
SUCCESS_SCORE_THRESHOLD = float(os.getenv("SUCCESS_SCORE_THRESHOLD", "0.8"))

# FIX 1: Retry budget for malformed JSON responses before giving up
JSON_RETRY_LIMIT = int(os.getenv("JSON_RETRY_LIMIT", "3"))

# Minimum number of tasks required by the grader
MIN_TASKS = 3

# Actions that must be completed for every ticket before finalize is allowed.
# finalize without these is the #1 score killer based on the logs.
REQUIRED_PER_TICKET = {"set_priority", "set_route", "set_resolution"}

# System prompt sent on every model call: defines the single-JSON-object action
# schema, the exact enum values the env accepts, and the mandatory per-ticket
# workflow (inspect -> context -> priority/route/resolution -> finalize).
# NOTE: this is a runtime string consumed by the model — treat its exact
# wording/values as load-bearing.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are operating a customer support triage environment.
    Return exactly one JSON object with keys: action_type, target, value. No extra text, no markdown, no code fences.

    Allowed action_type values:
    - inspect_ticket
    - request_context
    - set_priority
    - set_route
    - set_resolution
    - escalate
    - rank_queue
    - finalize

    VALID VALUES β€” you MUST use these exact strings:

    priority values: urgent, high, normal, low
    route values: account_security, monetization_compliance, billing_refunds, policy_appeals
    resolution values: temporary_lock_and_manual_recovery, request_tax_renewal, approve_refund, expedited_human_review
    escalation teams: security_specialist (only when account compromise is confirmed; omit otherwise)

    ACTION FORMAT EXAMPLES β€” copy these exactly:
    {"action_type": "inspect_ticket",   "target": "T1", "value": ""}
    {"action_type": "request_context",  "target": "T1", "value": "tax_status"}
    {"action_type": "set_priority",     "target": "T1", "value": "urgent"}
    {"action_type": "set_route",        "target": "T1", "value": "account_security"}
    {"action_type": "set_resolution",   "target": "T1", "value": "temporary_lock_and_manual_recovery"}
    {"action_type": "escalate",         "target": "T1", "value": "security_specialist"}
    {"action_type": "rank_queue",       "target": "queue", "value": "T2,T1,T3"}
    {"action_type": "finalize",         "target": "T1", "value": ""}

    CRITICAL: For request_context, target = ticket ID (e.g. "T1"), value = context key name.
    NEVER put the context key name in target. target is ALWAYS a ticket ID.

    MANDATORY WORKFLOW β€” follow in this exact order for each ticket:
    1. inspect_ticket (target=ticket_id, value="")  ← ONCE per ticket, BEFORE any other action on it.
    2. request_context ONLY for keys in required_context_keys (these affect your score).
       Use target=ticket_id, value=key_name. One key per step. Request each key at most once.
       Do NOT request optional available_context_keys β€” they waste steps.
    3. set_priority  ← MANDATORY before finalize. Use valid priority values.
    4. set_route     ← MANDATORY before finalize. Use valid route values.
    5. set_resolution ← MANDATORY before finalize. Use valid resolution values.
    6. escalate only when account takeover / security compromise is confirmed.
    7. For queue tasks: rank_queue once, after ALL tickets are processed.
    8. finalize (target=ticket_id, value="") β€” ONLY after set_priority, set_route,
       and set_resolution have ALL been called for this ticket.

    *** YOU MUST call set_priority, set_route, and set_resolution on every ticket. ***
    *** Calling finalize before those three actions will score near 0. ***

    PRIORITY HINTS:
    - Account takeover / fraud / SLA <= 2h β†’ urgent
    - Tax/compliance holds, payment issues / SLA <= 12h β†’ high
    - Routine appeals, refunds / SLA >= 24h β†’ normal

    STRICT RULES:
    - NEVER repeat an action you have already taken (check your history).
    - inspect_ticket AT MOST ONCE per ticket, and ALWAYS before request_context on that ticket.
    - target is ALWAYS a ticket ID like "T1". NEVER put a context key in target.
    - Each request_context must use a different value (key name).
    - value must ALWAYS be a string β€” use "" (empty string), never null.
    """
).strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the [START] banner line the harness log parser keys on."""
    fields = {"task": task, "env": env, "model": model}
    rendered = " ".join(f"{key}={value}" for key, value in fields.items())
    print(f"[START] {rendered}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] line; a missing/empty error renders as the literal 'null'."""
    err_text = "null" if not error else error
    pieces = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={err_text}",
    ]
    print("[STEP] " + " ".join(pieces), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the [END] summary with the full comma-joined reward history."""
    formatted = [f"{r:.2f}" for r in rewards]
    print(
        "[END] "
        f"success={str(success).lower()} "
        f"steps={steps} "
        f"score={score:.3f} "
        f"rewards={','.join(formatted)}",
        flush=True,
    )


def build_user_prompt(
    observation: Observation,
    step: int,
    rewards: List[float],
    action_history: List[str],
    completed_per_ticket: dict,
) -> str:
    """Render the per-step user prompt.

    Combines: the last 5 rewards, the per-ticket list of still-missing
    mandatory actions, the full action history (repeat deterrent), and the
    observation serialized as JSON.
    """
    # Only the 5 most recent rewards are shown to keep the prompt compact.
    reward_history = ",".join(f"{reward:.2f}" for reward in rewards[-5:]) if rewards else "none"
    history_str = "\n".join(f"  {a}" for a in action_history) if action_history else "  none"

    # FIX 2: Summarise what mandatory actions are still missing per ticket so the
    # model can see at a glance what it still needs to do before finalize.
    pending_lines = []
    for tid, done_actions in sorted(completed_per_ticket.items()):
        missing = REQUIRED_PER_TICKET - done_actions
        if missing:
            pending_lines.append(f"  {tid}: still needs {', '.join(sorted(missing))}")
    pending_str = "\n".join(pending_lines) if pending_lines else "  all mandatory actions complete"

    # NOTE(review): {pending_str}/{history_str} sit at column 0 inside the
    # f-string, so their own 2-space indent becomes the common prefix that
    # textwrap.dedent strips — the template's 8-space lines therefore keep
    # 6 leading spaces in the rendered prompt. Looks deliberate; confirm the
    # rendering is intended before restructuring.
    return textwrap.dedent(
        f"""
        Step: {step}
        Task: {observation.task_id}
        Difficulty: {observation.difficulty}
        Reward history: {reward_history}

        Mandatory actions still PENDING (you MUST complete these before finalize):
{pending_str}

        Actions you have ALREADY taken this episode (do NOT repeat these):
{history_str}

        Observation JSON:
        {json.dumps(observation.model_dump(), indent=2, sort_keys=True)}
        Return one JSON action that you have NOT already taken.
        Remember: value must always be a string, never null.
        """
    ).strip()


def extract_json(text: str) -> dict:
    """
    Pull a single JSON object out of raw model output.

    Tolerates:
    - <think>...</think> reasoning blocks (DeepSeek-R1, Gemini thinking, etc.)
    - Markdown code fences (```json ... ```)
    - Stray text surrounding the object
    """
    # Drop <think> blocks first; the braces inside them can otherwise
    # masquerade as the action object.
    cleaned = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

    # Remove markdown fences, including any trailing backticks.
    cleaned = re.sub(r"```(?:json)?", "", cleaned).strip().rstrip("`").strip()

    # Fast path: the whole remaining text is the object.
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # Otherwise scan flat {...} candidates back-to-front: the action always
    # follows any preamble, so the last parseable match wins.
    candidates = re.findall(r"\{[^{}]+\}", cleaned)
    for candidate in reversed(candidates):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    raise ValueError(f"No valid JSON object found in: {cleaned!r}")


def get_model_action(
    client: OpenAI,
    observation: Observation,
    step: int,
    rewards: List[float],
    action_history: List[str],
    completed_per_ticket: dict,
) -> tuple[Action, Optional[str]]:
    """Query the model for the next action, with retries and a safe fallback.

    Retries up to JSON_RETRY_LIMIT times on malformed output or API errors,
    nudging temperature upward on each retry and honouring 429 retryDelay
    hints. Returns (action, error): error is None on a clean parse; when the
    retry budget is exhausted, action is a harmless inspect_ticket on the
    first not-yet-inspected ticket and error carries the last failure text.
    """
    user_prompt = build_user_prompt(observation, step, rewards, action_history, completed_per_ticket)
    last_exc: Optional[str] = None

    for attempt in range(1, JSON_RETRY_LIMIT + 1):
        # Slightly raise temperature on retries so we don't get the same bad output
        temp = TEMPERATURE if attempt == 1 else min(TEMPERATURE + 0.15 * attempt, 1.0)
        # FIX: reset per attempt — previously a request failure (before the
        # response was read) logged the PREVIOUS attempt's content as if it
        # belonged to this attempt, which made the [WARN] lines misleading.
        content = ""
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=temp,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            content = (completion.choices[0].message.content or "").strip()
            payload = extract_json(content)

            # FIX 4: Normalise null → "" so the Action model never sees None for value
            if payload.get("value") is None:
                payload["value"] = ""

            action = Action.model_validate(payload)
            return action, None
        except Exception as exc:
            last_exc = str(exc).replace("\n", " ")
            print(f"[WARN] attempt={attempt} parse_error={last_exc!r} content={content!r}", flush=True)

            # FIX 5a: Respect rate-limit retry-after delays instead of hammering the API.
            # The 429 body includes a retryDelay field (e.g. "16s"). Parse and sleep for it
            # so subsequent attempts actually succeed rather than burning the retry budget.
            if "429" in last_exc or "RESOURCE_EXHAUSTED" in last_exc:
                delay_match = re.search(r"retryDelay['\"]:\s*['\"](\d+(?:\.\d+)?)s", last_exc)
                delay = float(delay_match.group(1)) if delay_match else 20.0
                print(f"[WARN] rate-limited; sleeping {delay:.1f}s before retry", flush=True)
                time.sleep(delay)

    # FIX 5b: Exhausted retries — do NOT blindly finalize.
    # Fall back to a no-op inspect on the first visible ticket to keep the episode alive.
    print("[WARN] JSON retry limit exhausted; emitting safe no-op", flush=True)
    # observation.tickets may be a list of objects or a dict — handle both.
    obs_dump = observation.model_dump()
    raw_tickets = obs_dump.get("tickets", [])
    if isinstance(raw_tickets, dict):
        ticket_ids = list(raw_tickets.keys())
    else:
        # list of dicts — each item should have an "id" or similar field
        ticket_ids = [
            t.get("ticket_id") or t.get("id") or f"T{i+1}"
            for i, t in enumerate(raw_tickets)
        ]
    ticket_ids = ticket_ids or ["T1"]

    # FIX: parse each history entry once (the original set comprehension
    # called json.loads on every entry twice) while collecting the ticket
    # IDs that were already inspected.
    inspected = set()
    for entry in action_history:
        parsed = json.loads(entry)
        if parsed.get("action_type") == "inspect_ticket":
            inspected.add(parsed["target"])
    target = next((t for t in ticket_ids if t not in inspected), ticket_ids[0])
    fallback = Action(action_type="inspect_ticket", target=target, value="")
    return fallback, last_exc


def clamp_score(score: float) -> float:
    """Clamp *score* into the strictly open interval (0, 1).

    The bounds are 0.001/0.999 (rather than e.g. 1e-6) so the value survives
    :.3f log formatting — the submission parser reads that string, and a
    smaller epsilon would render as '0.000' and be rejected as exactly 0.0.
    """
    value = float(score)
    if value < 0.001:
        return 0.001
    if value > 0.999:
        return 0.999
    return value


def select_tasks(requested: str) -> List[str]:
    """
    Build the task list for this run, capped at MIN_TASKS entries.

    The requested task leads the list (falling back to the first available
    task when it is unknown), followed by the remaining available tasks.
    Raises RuntimeError when the environment exposes no tasks at all.
    """
    available = list_task_ids()
    if not available:
        raise RuntimeError("No tasks available in the environment.")

    if requested in available:
        primary = requested
    else:
        primary = available[0]
    ordered = [primary] + [task for task in available if task != primary]
    return ordered[:max(MIN_TASKS, 1)]


def run_task(client: OpenAI, task_name: str) -> dict:
    """Run a single task episode end-to-end.

    Drives the env for up to MAX_STEPS model-chosen actions, guards against
    premature finalize, and returns {"task", "success", "steps", "score"}.
    The [END] log line is always emitted (via ``finally``), even when reset
    or step raises.
    """
    env = SupportOpsEnv(task_id=task_name)
    rewards: List[float] = []
    action_history: List[str] = []
    # FIX 6: Track which mandatory actions have been completed per ticket
    # so we can warn the model and block premature finalize.
    completed_per_ticket: dict[str, set] = {}
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)

    try:
        observation = env.reset(task_id=task_name)

        for step in range(1, MAX_STEPS + 1):
            action, action_error = get_model_action(
                client, observation, step, rewards, action_history, completed_per_ticket
            )

            # FIX 7: Guard against premature finalize — if mandatory steps are still
            # missing for any ticket, redirect to the first pending mandatory action
            # instead of letting the model throw away the score.
            if action.action_type == "finalize":
                target = action.target or "T1"
                missing = REQUIRED_PER_TICKET - completed_per_ticket.get(target, set())
                if missing:
                    next_action_type = sorted(missing)[0]  # deterministic ordering
                    print(
                        f"[GUARD] Premature finalize on {target}; redirecting to {next_action_type}",
                        flush=True,
                    )
                    # Pick the first valid value for the missing action type
                    FALLBACK_VALUES = {
                        "set_priority": "normal",
                        "set_route": "policy_appeals",
                        "set_resolution": "expedited_human_review",
                    }
                    action = Action(
                        action_type=next_action_type,
                        target=target,
                        value=FALLBACK_VALUES[next_action_type],
                    )

            # Record the action (possibly the guard's replacement) in history
            # BEFORE stepping, so the next prompt's repeat deterrent sees it.
            action_str = json.dumps(action.model_dump(), separators=(",", ":"))
            action_history.append(action_str)

            # Update completion tracker
            if action.action_type in REQUIRED_PER_TICKET:
                t = action.target or "T1"
                completed_per_ticket.setdefault(t, set()).add(action.action_type)

            observation, reward, done, info = env.step(action)
            reward_value = reward.value
            rewards.append(reward_value)
            steps_taken = step

            log_step(
                step=step,
                action=action_str,
                reward=reward_value,
                done=done,
                error=action_error,
            )

            # task_score is read every step so the latest value survives
            # an early `done` break.
            score = float(info.get("task_score", 0.0))
            if done:
                break

        # NOTE(review): success is judged on the CLAMPED score, so a
        # threshold above 0.999 can never pass — confirm that is acceptable.
        score = clamp_score(score)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {"task": task_name, "success": success, "steps": steps_taken, "score": score}


def main() -> None:
    """Run every selected task sequentially, then print the [SUMMARY] line."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Run in reverse selection order (requested/primary task last).
    tasks = list(reversed(select_tasks(TASK_NAME)))

    all_results = [run_task(client, task_name) for task_name in tasks]

    total = len(all_results)
    passed = sum(1 for result in all_results if result["success"])
    avg_score = sum(result["score"] for result in all_results) / total if total else 0.0
    print(
        f"[SUMMARY] tasks={total} passed={passed} avg_score={avg_score:.3f}",
        flush=True,
    )


# Script entry point — run only when executed directly, not on import.
if __name__ == "__main__":
    main()