# File size: 4,086 Bytes
# 846683d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import json
import os
import re

from openai import OpenAI

from baseline.policy import heuristic_policy
from env.environment import OpenEnv
from env.models import Action, Observation
from env.runtime_config import RuntimeConfig

# Endpoint, model name and token for the chat-completions client, all read
# from environment variables at import time. choose_action() skips the remote
# model and uses the heuristic policy when API_BASE_URL or HF_TOKEN is unset
# or empty.
API_BASE_URL = os.getenv("API_BASE_URL")
# Defaults to "gpt-4o-mini" when MODEL_NAME is not set.
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN")
# Settings loaded from the environment; run_episode() expands them into the
# keyword arguments used to construct OpenEnv.
runtime_config = RuntimeConfig.from_env()

# The client is constructed unconditionally at import time; "EMPTY" is the
# placeholder api_key used when HF_TOKEN is missing or empty.
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN or "EMPTY")
# Action types accepted from the model's reply; any other value makes
# choose_action() fall back to the heuristic policy.
VALID_ACTIONS = {"triage", "respond", "resolve", "escalate"}


def compute_partial_score(total_reward: float, max_steps: int) -> float:
    """Normalize a cumulative reward into a score clamped to [0.0, 1.0].

    The reward is divided by ``max_steps`` (floored at 1.0 so a zero or
    negative step budget cannot cause a division by zero or flip the sign)
    and the quotient is clamped to the unit interval.

    Args:
        total_reward: Sum of the per-step rewards for the episode.
        max_steps: Step budget used as the normalizing denominator.

    Returns:
        A float in [0.0, 1.0]. Non-positive and NaN rewards yield 0.0.
    """
    max_possible_reward = max(float(max_steps), 1.0)
    ratio = total_reward / max_possible_reward
    # Every comparison with NaN is false, so the plain
    # max(0.0, min(1.0, ratio)) clamp would turn a NaN reward into a perfect
    # 1.0. Testing "not greater than zero" sends NaN, zero and negative
    # values to 0.0 instead.
    if not ratio > 0.0:
        return 0.0
    return min(1.0, ratio)


def safe_parse(content: str) -> dict:
    """Extract a JSON object from a model reply, tolerating surrounding text.

    The whole string is tried as JSON first. If that does not produce a
    dict, the text is scanned left to right and the first ``{`` from which a
    complete JSON object can be decoded wins. This handles replies wrapped
    in prose or code fences, and replies containing stray braces after the
    object (a greedy first-``{``-to-last-``}`` match would fail on those).

    Args:
        content: Raw text returned by the model.

    Returns:
        The decoded dict, or ``{"action_type": "triage", "note": "fallback"}``
        when no JSON object can be found.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at `start` and
            # ignores whatever follows it, so trailing text is harmless.
            candidate, _ = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = content.find("{", start + 1)

    return {"action_type": "triage", "note": "fallback"}


def choose_action(observation: Observation) -> Action:
    """Pick the next action, asking the remote model when it is configured.

    Without both ``API_BASE_URL`` and ``HF_TOKEN`` the heuristic policy is
    used directly. Otherwise the observation is serialized into a prompt,
    the model's reply is parsed with ``safe_parse``, and the result is
    turned into an ``Action``. An unrecognized ``action_type`` or any
    exception raised while querying or parsing falls back to the heuristic
    policy.

    Args:
        observation: Current environment observation.

    Returns:
        The action to apply on the next environment step.
    """
    if not (API_BASE_URL and HF_TOKEN):
        return heuristic_policy(observation)

    prompt = (
        "Return one action_type from [triage, respond, resolve, escalate] and a short note in JSON "
        "with keys action_type and note. "
        f"Observation: {observation.model_dump_json()}"
    )

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        reply_text = completion.choices[0].message.content
        if not reply_text:
            # Treat a missing or empty reply as an empty JSON object.
            reply_text = "{}"
        fields = safe_parse(reply_text)
        if fields.get("action_type") in VALID_ACTIONS:
            # The note is always passed on as a string, whatever the model sent.
            fields["note"] = str(fields.get("note", ""))
            return Action(**fields)
        return heuristic_policy(observation)
    except Exception:
        # Best effort: any API, parsing or validation failure degrades to
        # the heuristic policy rather than aborting the episode.
        return heuristic_policy(observation)


def run_episode(task: str, max_steps: int = 20) -> dict:
    """Run one episode for ``task`` and return a summary of the outcome.

    An ``OpenEnv`` is built from the runtime configuration with
    ``difficulty`` set to ``task``. Actions from ``choose_action`` are
    applied until the environment reports done or ``max_steps`` steps have
    run. ``[START]``, ``[STEP]`` and ``[END]`` lines are printed to stdout
    along the way.

    Args:
        task: Difficulty value passed to the environment.
        max_steps: Maximum number of steps to take; also the denominator of
            the partial score.

    Returns:
        A dict with keys ``task``, ``success``, ``steps``, ``rewards``,
        ``score``, ``env_score``, ``partial_score`` and
        ``openai_client_configured``.
    """
    env_kwargs = {**runtime_config.to_env_kwargs(), "difficulty": task}
    env = OpenEnv(**env_kwargs)
    observation = env.reset()

    openai_client_configured = bool(API_BASE_URL and HF_TOKEN)
    client_mode = "enabled" if openai_client_configured else "fallback"
    print(
        f"[START] task={task} env=workflow model={MODEL_NAME} "
        f"openai_client={client_mode}"
    )

    total_reward = 0.0
    reward_trace: list[str] = []
    steps = 0

    while steps < max_steps:
        action = choose_action(observation)
        observation, reward, done, _ = env.step(action)
        steps += 1
        total_reward += reward
        reward_text = f"{reward:.2f}"
        reward_trace.append(reward_text)

        done_text = str(done).lower()
        print(
            f"[STEP] step={steps} action={action.action_type} reward={reward_text} "
            f"done={done_text} error=null"
        )

        if done:
            break

    # Success is judged from the last observation's ticket status.
    success = observation.ticket_status == "resolved"
    rewards_csv = ",".join(reward_trace)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"rewards={rewards_csv}"
    )

    final_state = env.state()
    env_score = float(final_state.get("score", 0.0))
    partial_score = compute_partial_score(total_reward, max_steps)
    # Prefer the environment's own score; fall back to the reward-based
    # partial score when the environment reports nothing positive.
    if env_score > 0.0:
        score = env_score
    else:
        score = partial_score

    return {
        "task": task,
        "success": success,
        "steps": steps,
        "rewards": round(total_reward, 2),
        "score": round(score, 4),
        "env_score": round(env_score, 4),
        "partial_score": round(partial_score, 4),
        "openai_client_configured": openai_client_configured,
    }


def run_all_tasks() -> list[dict]:
    """Run one episode per difficulty, in the order easy, medium, hard.

    Returns:
        The ``run_episode`` summaries, one per difficulty, in run order.
    """
    return [run_episode(task) for task in ("easy", "medium", "hard")]


if __name__ == "__main__":
    # Script entry point: run every task and print the summaries as
    # indented JSON.
    print(json.dumps(run_all_tasks(), indent=2))