File size: 3,940 Bytes
d3b065a
834956c
 
 
 
 
 
 
 
b37fd8a
5e1996c
 
 
834956c
 
 
 
 
 
b37fd8a
 
 
834956c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b37fd8a
834956c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e1996c
 
b37fd8a
5e1996c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b37fd8a
5e1996c
 
 
 
b37fd8a
5e1996c
 
834956c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ba100e
 
b37fd8a
 
 
 
 
 
 
 
 
 
 
 
6ba100e
d3b065a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
from typing import List, Optional

from openai import OpenAI

from env import EmailTriageEnv
from app import smart_agent_logic


# Endpoint and credentials for the OpenAI-compatible API. Neither lookup
# supplies a default, so each value is None when its variable is unset.
API_BASE_URL = os.getenv("API_BASE_URL")
API_KEY = os.getenv("API_KEY")
# Model identifier; falls back to the Qwen instruct model when unset.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Environment name reported in the [START] log line.
BENCHMARK = "email_triage_env"

# Upper bound on the number of environment steps attempted per task.
MAX_STEPS = 20
# Minimum score for a task run to be reported as successful.
SUCCESS_SCORE_THRESHOLD = 0.5

# Task names handed to the environment, run in this order.
TASKS = ["easy", "medium", "hard"]


def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a task run.

    Args:
        task: Name of the task being run.
        env: Name of the environment/benchmark.
        model: Identifier of the model in use.
    """
    fields = " ".join((f"task={task}", f"env={env}", f"model={model}"))
    # Flush immediately so the line is emitted even if the process dies later.
    print(f"[START] {fields}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: Step number.
        action: Text representation of the action that was taken.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed lower-cased.
        error: Error message, or ``None``/empty to print ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        # Any falsy error (None or "") is rendered as the literal "null".
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes a task run.

    Args:
        success: Whether the run counted as successful; printed lower-cased.
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    success_flag = str(success).lower()
    joined_rewards = ",".join(format(value, ".2f") for value in rewards)
    summary = f"success={success_flag} steps={steps} score={score:.3f} rewards={joined_rewards}"
    print(f"[END] {summary}", flush=True)


def _parse_action(text: str) -> Optional[List[int]]:
    """Parse a model reply into an ``[urgency, routing, resolution]`` action.

    Commas are treated as whitespace and only the first three tokens are
    considered, matching the "2 1 2" reply format requested in the prompt.

    Args:
        text: Raw reply text from the model.

    Returns:
        A list of exactly three integers, each in the range 0-2 requested by
        the prompt, or ``None`` when the reply is not a valid action.
    """
    tokens = text.replace(",", " ").split()[:3]
    try:
        values = [int(token) for token in tokens]
    except ValueError:
        return None
    if len(values) != 3 or any(not 0 <= value <= 2 for value in values):
        return None
    return values


def _llm_action(client, description: str) -> Optional[List[int]]:
    """Ask the model to classify *description*; return ``None`` on any failure.

    This is deliberately best-effort: API errors and unusable replies are
    reported on a ``[DEBUG]`` line and never raised, so the caller can fall
    back to the heuristic agent.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": "Classify email into 3 integers: urgency (0-2), routing (0-2), resolution (0-2). Return only numbers like: 2 1 2"
                },
                {
                    "role": "user",
                    "content": description
                }
            ],
            max_tokens=20,
            temperature=0,
        )
        # The message content may be None; treat that as an empty reply.
        reply = (response.choices[0].message.content or "").strip()
    except Exception as llm_error:
        print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
        return None

    action = _parse_action(reply)
    if action is None:
        print(f"[DEBUG] LLM failed: invalid action reply {reply!r}", flush=True)
    return action


def run_task(client, TASK_NAME):
    """Run one episode of the email triage environment and log the outcome.

    For each step the action comes from the LLM when ``client`` is truthy and
    its reply is a valid action; otherwise ``smart_agent_logic`` supplies it.
    The episode stops after ``MAX_STEPS`` steps, when the environment reports
    it is done, or at the first step that raises.

    Args:
        client: OpenAI-style client exposing ``chat.completions.create``, or a
            falsy value to skip the LLM and always use the fallback logic.
        TASK_NAME: Task name passed to ``EmailTriageEnv``.

    The score is the mean step reward clamped to ``[0.0, 1.0]``; the run is a
    success when the score reaches ``SUCCESS_SCORE_THRESHOLD``. The ``[END]``
    line is always printed once the ``[START]`` line has been, even if
    ``env.reset()`` raises (that exception still propagates to the caller).
    """
    env = EmailTriageEnv(task=TASK_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        state = env.reset()

        for step in range(1, MAX_STEPS + 1):
            if state.get("done"):
                break

            try:
                desc = state["description"]

                # Only a fully validated LLM action is used. A partial or
                # out-of-range reply yields None and falls through to the
                # fallback instead of reaching env.step().
                action_list = _llm_action(client, desc) if client else None
                if action_list is None:
                    action_list = smart_agent_logic(desc)

                # Third element is treated as the episode-finished flag; the
                # last two elements of the step result are ignored.
                state, reward, done, _, _ = env.step(action_list)

                rewards.append(reward)
                steps_taken = step

                log_step(step, str(action_list), reward, done, None)

                if done:
                    break

            except Exception as step_error:
                # A failing step ends the episode; it is logged but not
                # counted in rewards or steps_taken.
                log_step(step, "error", 0.0, True, str(step_error))
                break

        if rewards:
            score = sum(rewards) / len(rewards)
            score = max(0.0, min(score, 1.0))

        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        log_end(success, steps_taken, score, rewards)


def main():
    """Build the API client and run every task in ``TASKS`` in order.

    If the client cannot be constructed, the failure is printed on a
    ``[DEBUG]`` line and the tasks run with ``client`` set to ``None``.
    """
    client = None
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as init_error:
        print(f"[DEBUG] OpenAI init failed: {init_error}", flush=True)

    for task_name in TASKS:
        run_task(client, task_name)


# Call main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()