File size: 7,248 Bytes
cd688d7
 
aa4f7bc
cd688d7
 
 
 
 
 
aa4f7bc
 
 
cd688d7
 
 
 
 
 
 
 
 
 
 
 
 
ba2722e
 
 
cd688d7
 
ba2722e
 
 
cd688d7
 
aa4f7bc
cd688d7
aa4f7bc
 
 
 
 
 
 
 
 
 
 
 
 
0c24081
 
ba2722e
 
0c24081
 
 
aa4f7bc
 
0c24081
 
 
 
cd688d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa4f7bc
 
 
 
 
cd688d7
aa4f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd688d7
aa4f7bc
 
 
cd688d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af3f703
cd688d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba2722e
cd688d7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import asyncio
import json
import logging
import os
import time
from typing import List, Optional

from openai import OpenAI

from env.environment import SupportTicketEnv
from env.models import Action

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")

MAX_STEPS = 10
MAX_TOTAL_REWARD = 1.0
SUCCESS_SCORE_THRESHOLD = 0.8

def log_start(task: str, env: str, model: str):
    """Emit the machine-readable [START] marker line for a task run."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)

def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None):
    """Emit the machine-readable [STEP] marker line for one environment step.

    A falsy/absent error is rendered as the literal string "null".
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)

def log_end(success: bool, steps: int, score: float, rewards: list):
    """Emit the machine-readable [END] marker summarizing a finished task run."""
    reward_csv = ",".join(format(r, ".2f") for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={reward_csv}",
        flush=True,
    )

def parse_action(text: str) -> Action:
    """Extract the first valid JSON object from *text* and convert it to an Action.

    Scans forward for '{' characters and attempts a JSON decode at each one.
    The first decoded dict is validated with Pydantic; if validation fails, a
    best-effort Action is built from the dict's fields, coercing an unknown
    action_type to 'close_ticket'. If no usable JSON is found at all, a safe
    default close_ticket Action is returned so the episode can still end.
    """
    try:
        decoder = json.JSONDecoder()
        idx = 0
        while True:
            idx = text.find('{', idx)
            if idx == -1:
                break
            try:
                obj, _end = decoder.raw_decode(text, idx)
            except json.JSONDecodeError:
                idx += 1
                continue
            if not isinstance(obj, dict):
                # Bug fix: the scan position must advance here; previously it
                # was left unchanged, which would re-find the same '{' and
                # loop forever if a non-dict ever decoded at this position.
                idx += 1
                continue
            try:
                return Action.model_validate(obj)
            except Exception as val_err:
                logger.warning("Action validation failed: %s", val_err)
                # Fallback to manual construction with validation.
                action_type = obj.get("action_type", "close_ticket")
                valid_actions = [
                    "fetch_user_data", "check_policy", "issue_refund",
                    "reply_to_customer", "escalate", "close_ticket",
                ]
                if action_type not in valid_actions:
                    logger.error("Invalid action_type: %s. Defaulting to 'close_ticket'.", action_type)
                    action_type = "close_ticket"
                return Action(action_type=action_type, parameters=obj.get("parameters", {}))
    except Exception as e:
        logger.error("Failed to parse action: %s", e)
    # Default fallback if no valid action is found.
    return Action(action_type="close_ticket", parameters={})

def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
    """Request the model's next action as raw text.

    Supports two client interfaces: chat-completions
    (``client.chat.completions.create``) and a responses-style API
    (``client.responses.create``). Requests are retried with exponential
    backoff; on total failure "{}" is returned so parse_action() falls back
    to its safe default.

    Args:
        client: OpenAI-style client object.
        step: Current step index (not used in prompt construction; kept for
            interface stability).
        env_state: JSON dump of the latest environment observation.
        history: Prior step summaries, joined into the user prompt.

    Returns:
        The model's reply text, or "{}" when every attempt fails.
    """
    system_prompt = (
        "You are an AI support agent resolving customer tickets.\n"
        "Available Actions:\n"
        "- fetch_user_data(user_id)\n"
        "- check_policy(issue_type)\n"
        "- issue_refund(amount)\n"
        "- reply_to_customer(message)\n"
        "- escalate(reason)\n"
        "- close_ticket(resolution)\n\n"
        "Must respond with JSON format:\n"
        "{\"action_type\": \"...\", \"parameters\": {\"...\": \"...\"}}"
    )
    history_str = "\n".join(history)
    user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"

    # Retry/backoff parameters: up to 3 attempts, 0.5s delay doubling each time.
    max_retries = 3
    backoff_base = 0.5

    try:
        # Support a few possible client interfaces (chat.completions or responses).
        for attempt in range(1, max_retries + 1):
            try:
                if hasattr(client, "chat") and hasattr(client.chat, "completions"):
                    completion = client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=[
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt}
                        ],
                        temperature=0.1
                    )
                    text = (completion.choices[0].message.content or "").strip()
                    return text if text else "{}"

                if hasattr(client, "responses") and hasattr(client.responses, "create"):
                    completion = client.responses.create(model=MODEL_NAME, input=user_prompt, temperature=0.1)
                    text = getattr(completion, "output_text", None)
                    if text:
                        return text.strip()

                    # NOTE(review): assumes `output` items are dicts holding a
                    # "content" list of {"type", "text"} entries — confirm
                    # against the actual client library's response objects.
                    out = []
                    for item in getattr(completion, "output", []) or []:
                        for c in item.get("content", []):
                            if c.get("type") == "output_text":
                                out.append(c.get("text", ""))
                    if out:
                        return "".join(out).strip()

                raise RuntimeError("No supported model client method available")
            except Exception as exc:
                logger.warning("Model request attempt %d failed: %s", attempt, exc)
                if attempt == max_retries:
                    break
                # Exponential backoff before the next attempt.
                time.sleep(backoff_base * (2 ** (attempt - 1)))
    except Exception as exc:
        logger.exception("Unexpected error in get_model_message: %s", exc)

    return "{}"

async def run_task(task_id: str, client: OpenAI) -> None:
    """Play one support-ticket episode end to end, logging START/STEP/END markers.

    The [END] marker is emitted from the finally block, so downstream log
    parsers always see a terminal record even if the episode raises.
    """
    env = SupportTicketEnv(task_id=task_id)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_id, env="SupportTicketEnv", model=MODEL_NAME)
    try:
        observation = env.reset()
        latest_obs_json = observation.model_dump_json(indent=2)

        for step_num in range(1, MAX_STEPS + 1):
            # Stop early if a previous step already finished the episode.
            if env.get_state().is_done:
                break

            reply = get_model_message(client, step_num, latest_obs_json, history)
            action = parse_action(reply)

            next_obs, _step_reward, done, info = env.step(action)
            latest_obs_json = next_obs.model_dump_json(indent=2)

            # The environment reports the authoritative reward via `info`.
            actual_reward = info.get("current_reward", 0.0)
            rewards.append(actual_reward)
            steps_taken = step_num

            log_step(step=step_num, action=reply, reward=actual_reward, done=done, error=None)
            history.append(f"Step {step_num}: {reply!r} -> reward {actual_reward:+.2f}")

            if done:
                score = actual_reward
                break

        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

async def main() -> None:
    """Run every benchmark task sequentially with a single shared client."""
    client = OpenAI(base_url=API_BASE_URL, api_key=os.getenv("HF_TOKEN"))
    for task_id in ("task_easy_1", "task_medium_1", "task_hard_1"):
        await run_task(task_id, client)

if __name__ == "__main__":
    asyncio.run(main())