File size: 4,356 Bytes
9fdf681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d40922a
 
 
 
 
 
9fdf681
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import json
from openai import OpenAI
from env import CustomerSupportEnv
from models import Action
from tasks import TASKS

# Optional convenience: when python-dotenv is installed, load variables from a
# local .env file; when it is not installed, carry on with the process
# environment as-is.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass


# --- Runtime configuration, read from the environment ---
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Fail fast at import time: the token is used as the API key just below.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")

# --- Single shared OpenAI client, pointed at API_BASE_URL ---
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence.

    Handles a leading ``` (optionally followed by a ``json`` language tag)
    and removes a trailing ``` only when one is actually present, so a reply
    that opens a fence but never closes it does not lose its final characters.
    Text that does not start with a fence is returned stripped of surrounding
    whitespace and otherwise unchanged.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    cleaned = cleaned[3:]
    if cleaned[:4].lower() == "json":
        cleaned = cleaned[4:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def run_inference():
    """Run every task in ``TASKS`` through the environment and log the results.

    For each task the environment is reset, then the model is queried once per
    step with a fixed instruction block plus the current observation (as
    JSON). The reply is parsed into an ``Action`` and passed to ``env.step``.
    Progress is printed to stdout as one ``[START]`` line per task, one
    ``[STEP]`` line per step and one ``[END]`` line per task.

    Any exception raised while querying the model, parsing its reply or
    stepping the environment ends the current task: the step is logged with
    ``action=ERROR``, a reward of 0.00 and the exception text (newlines
    replaced by spaces) in the ``error`` field.

    Returns:
        None. All output goes to stdout.
    """
    env = CustomerSupportEnv()

    # The instruction block never changes, so build it once instead of on
    # every step; only the observation suffix varies.
    instructions = (
        "System: You are an automated customer support AI. You MUST respond strictly in JSON format matching this schema: "
        "{\"action_type\": \"ROUTE\"|\"ASK_INFO\"|\"REFUND\"|\"CLOSE\", \"argument\": \"string\"}\n\n"
        "CRITICAL RULES:\n"
        "1. If 'missing_info' in the observation is empty ([]), DO NOT use ASK_INFO. You must take action (ROUTE or REFUND).\n"
        "2. If 'missing_info' contains items, you MUST use ASK_INFO. The 'argument' MUST contain the EXACT string from 'missing_info' (e.g., 'serial_number', 'order_id', 'photo_evidence'). Ask for ONLY ONE missing item at a time\n"
        "3. When using ROUTE, the 'argument' MUST be exactly one of these three codes: 'IT_SUPPORT', 'HARDWARE_SUPPORT', or 'BILLING'. Do not output full sentences.\n"
        "4. If the user wants a refund, and you have collected 'order_id', you MUST first use the REFUND action. Then, in the next step, use ROUTE with 'BILLING'.\n\n"
    )

    for idx, task in enumerate(TASKS):
        obs = env.reset(task_idx=idx)
        done = False
        step_idx = 0
        rewards_history = []

        # [START] FORMAT
        print(f"[START] task={task.name} env=customer_support model={MODEL_NAME}")

        # NOTE(review): the loop ends only when the environment reports
        # done or a step raises; there is no step cap on this side.
        while not done:
            step_idx += 1
            error_msg = "null"
            reward_val = 0.00
            action_str = ""

            prompt = instructions + f"Observation: {obs.model_dump_json()}"

            try:
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[{"role": "user", "content": prompt}],
                    response_format={"type": "json_object"}
                )

                content = response.choices[0].message.content
                if content is None:
                    # Surface a readable error instead of an AttributeError.
                    raise ValueError("model returned no message content")

                # Some models wrap the JSON in Markdown fences despite the
                # requested response format; peel them off before parsing.
                action_data = json.loads(_strip_code_fence(content))

                # Pydantic validation
                action = Action(**action_data)
                action_str = f"{action.action_type}('{action.argument}')"

                # Env Step
                obs = env.step(action)
                done = obs.done
                reward_val = float(obs.reward) if obs.reward is not None else 0.0
                rewards_history.append(reward_val)

            except Exception as e:
                # Keep the log line single-line and end the task on failure.
                error_msg = str(e).replace('\n', ' ')
                action_str = "ERROR"
                done = True
                rewards_history.append(0.00)

            # [STEP] FORMAT
            print(f"[STEP] step={step_idx} action={action_str} reward={reward_val:.2f} done={str(done).lower()} error={error_msg}")

        # [END] FORMAT
        # The last recorded reward is used as the task score and clamped to
        # [0.01, 0.99] (per the original author, to satisfy the validator).
        # The empty-history fallback is defensive: every loop iteration
        # appends a reward, and the loop runs at least once.
        if rewards_history:
            final_score = max(0.01, min(0.99, rewards_history[-1]))
        else:
            final_score = 0.01
        success = final_score > 0.8
        rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
        print(f"[END] success={str(success).lower()} steps={step_idx} score={final_score:.2f} rewards={rewards_str}")

# Script entry point: run the evaluation loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_inference()