File size: 4,992 Bytes
bd67155
 
 
 
 
 
 
3e50fa2
 
 
bd67155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import os
import textwrap
from typing import List, Optional

from openai import OpenAI

from server.support_ops_env.env import SupportOpsEnv
from server.support_ops_env.models import Action, Observation
from server.support_ops_env.tasks import list_task_ids


# Runtime configuration. Every knob is read once, at import time, from the
# process environment; the second argument to .get() is the fallback used
# when the variable is unset.
_environ = os.environ

# NOTE(review): not referenced by the code visible in this file.
LOCAL_IMAGE_NAME = _environ.get("LOCAL_IMAGE_NAME")

# Credential for the OpenAI-compatible endpoint: the first of these variables
# holding a non-empty value wins (plain `or` chaining).
API_KEY = (
    _environ.get("HF_TOKEN")
    or _environ.get("OPENAI_API_KEY")
    or _environ.get("API_KEY")
)

# Endpoint and model used for every chat completion request.
API_BASE_URL = _environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = _environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Task to run (validated against the registered tasks by ensure_known_task)
# and the benchmark label printed in the [START] line.
TASK_NAME = _environ.get("SUPPORT_OPS_TASK", "easy_account_takeover")
BENCHMARK = _environ.get("SUPPORT_OPS_BENCHMARK", "support_ops_env")

# Episode length cap, sampling parameters, and the score needed to count as success.
MAX_STEPS = int(_environ.get("MAX_STEPS", "16"))
TEMPERATURE = float(_environ.get("TEMPERATURE", "0.1"))
MAX_TOKENS = int(_environ.get("MAX_TOKENS", "220"))
SUCCESS_SCORE_THRESHOLD = float(_environ.get("SUCCESS_SCORE_THRESHOLD", "0.8"))


# Action verbs the model may emit; rendered one per line ("- verb") below.
# NOTE(review): keep in sync with the values the Action model accepts for
# action_type — nothing here checks that.
_ALLOWED_ACTION_TYPES = (
    "inspect_ticket",
    "request_context",
    "set_priority",
    "set_route",
    "set_resolution",
    "escalate",
    "rank_queue",
    "finalize",
)

# System message sent ahead of the per-step user prompt in every chat request.
SYSTEM_PROMPT = "\n".join(
    [
        "You are operating a customer support triage environment.",
        "Return exactly one JSON object with keys: action_type, target, value.",
        "Allowed action_type values:",
        *(f"- {action_type}" for action_type in _ALLOWED_ACTION_TYPES),
        "Choose only valid ticket ids from the observation.",
        "Use concise string values.",
        "Finalize only after enough evidence is gathered.",
    ]
)


def log_start(task: str, env: str, model: str) -> None:
    """Print the one-line ``[START]`` marker that opens an episode log.

    Args:
        task: Task id being run.
        env: Benchmark/environment label.
        model: Model name used for the policy.
    """
    banner = f"[START] task={task} env={env} model={model}"
    # Flush immediately so the marker is visible to whatever tails stdout.
    print(banner, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: Step index as supplied by the caller.
        action: Pre-serialized action text, printed verbatim.
        reward: Step reward, printed with two decimals.
        done: Whether the episode ended on this step; printed as ``true``/``false``.
        error: Error message for the step; any falsy value (``None`` or ``""``)
            is printed as ``null``.
    """
    # Falsy errors collapse to the literal "null".
    error_val = error or "null"
    done_val = str(done).lower()
    message = f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}"
    print(message, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` summary line for an episode.

    Args:
        success: Whether the episode counts as a success; printed lowercase.
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals
            each (empty list prints nothing after ``rewards=``).
    """
    formatted_rewards = [f"{reward:.2f}" for reward in rewards]
    rewards_str = ",".join(formatted_rewards)
    summary = f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}"
    print(summary, flush=True)


def build_user_prompt(observation: Observation, step: int, rewards: List[float]) -> str:
    """Render the per-step user message sent to the model.

    Args:
        observation: Current environment observation; its ``task_id``,
            ``difficulty`` and full ``model_dump()`` are embedded.
        step: Step index within the episode, as passed by the caller.
        rewards: Rewards received so far; only the last five are shown,
            or ``none`` when the list is empty.

    Returns:
        A plain-text prompt with one ``Label: value`` field per line, followed
        by the observation as pretty-printed, key-sorted JSON and a closing
        instruction. No line carries template indentation.
    """
    # Only the five most recent rewards are shown, to keep the prompt short.
    recent_rewards = rewards[-5:]
    reward_history = ",".join(f"{reward:.2f}" for reward in recent_rewards) if recent_rewards else "none"
    observation_json = json.dumps(observation.model_dump(), indent=2, sort_keys=True)
    # Assemble line by line instead of dedenting an indented f-string template:
    # the interpolated multi-line JSON has lines at column 0 (e.g. the closing
    # brace), so textwrap.dedent found no common indent and left every header
    # line after the first prefixed with the template's 8 spaces.
    lines = [
        f"Step: {step}",
        f"Task: {observation.task_id}",
        f"Difficulty: {observation.difficulty}",
        f"Reward history: {reward_history}",
        "Observation JSON:",
        observation_json,
        "Return one JSON action.",
    ]
    return "\n".join(lines)


def _parse_json_object(content: str) -> object:
    """Decode the JSON value that starts at the first ``{`` in *content*.

    Chat models frequently wrap the requested object in markdown code fences
    (```json ... ```) or add a sentence before/after it, which plain
    ``json.loads`` rejects. Decoding from the first brace and ignoring any
    trailing text tolerates that wrapping.

    Raises:
        json.JSONDecodeError: if no decodable JSON value is found.
    """
    start = content.find("{")
    if start == -1:
        # Nothing object-like: defer to json.loads for its usual error.
        return json.loads(content)
    payload, _end = json.JSONDecoder().raw_decode(content, start)
    return payload


def get_model_action(client: OpenAI, observation: Observation, step: int, rewards: List[float]) -> tuple[Action, Optional[str]]:
    """Ask the model for the next action and parse it into an ``Action``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        observation: Current observation, rendered via ``build_user_prompt``.
        step: Current step index, included in the prompt.
        rewards: Rewards so far, included in the prompt.

    Returns:
        ``(action, None)`` on success. If the request, JSON extraction, or
        ``Action.model_validate`` fails, returns
        ``(Action(action_type="finalize"), message)`` where ``message`` is the
        exception text with newlines replaced by spaces (so it fits on one
        ``[STEP]`` log line).
    """
    user_prompt = build_user_prompt(observation, step, rewards)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        content = (completion.choices[0].message.content or "").strip()
        payload = _parse_json_object(content)
        action = Action.model_validate(payload)
        return action, None
    except Exception as exc:
        # Deliberate catch-all boundary: any API, parsing, or validation
        # failure degrades to a "finalize" action instead of crashing the run;
        # the error text is surfaced to the caller for logging.
        fallback = Action(action_type="finalize")
        return fallback, str(exc).replace("\n", " ")


def ensure_known_task(task_name: str) -> str:
    """Return *task_name* if it is a registered task id, else the first registered id.

    The fallback is silent: an unknown name is replaced rather than rejected.

    Raises:
        IndexError: if ``list_task_ids()`` returns no ids.
    """
    # Query the registry once so the membership test and the fallback see the
    # same snapshot (presumably a list of str — TODO confirm).
    known_task_ids = list_task_ids()
    if task_name in known_task_ids:
        return task_name
    return known_task_ids[0]


def main() -> None:
    """Run one episode of the configured task with the model as the policy.

    Prints a ``[START]`` line, one ``[STEP]`` line per environment step (at
    most ``MAX_STEPS``), and always an ``[END]`` line, even if an exception
    escapes the episode. The reported score is the last ``task_score`` found
    in the step ``info`` dict, clamped to ``[0, 1]``. Success requires the
    episode loop to finish without raising and the score to reach
    ``SUCCESS_SCORE_THRESHOLD``.
    """
    task_name = ensure_known_task(TASK_NAME)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = SupportOpsEnv(task_id=task_name)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    # Set only once the loop exits normally; an exception keeps success False.
    completed = False

    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)

    try:
        observation = env.reset(task_id=task_name)

        for step in range(1, MAX_STEPS + 1):
            action, action_error = get_model_action(client, observation, step, rewards)
            # Compact JSON keeps the action on a single log line.
            action_str = json.dumps(action.model_dump(), separators=(",", ":"))

            observation, reward, done, info = env.step(action)
            reward_value = reward.value
            rewards.append(reward_value)
            steps_taken = step

            log_step(
                step=step,
                action=action_str,
                reward=reward_value,
                done=done,
                error=action_error,
            )

            # The latest task_score reported by the env is the episode score.
            score = float(info.get("task_score", 0.0))
            if done:
                break

        completed = True
    finally:
        # Clamp here rather than only on the success path, so the [END] line
        # reports a score in [0, 1] even when an exception escaped mid-episode.
        score = min(max(score, 0.0), 1.0)
        success = completed and score >= SUCCESS_SCORE_THRESHOLD
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


if __name__ == "__main__":
    main()