File size: 4,356 Bytes
9fdf681 d40922a 9fdf681 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import os
import json
from openai import OpenAI
from env import CustomerSupportEnv
from models import Action
from tasks import TASKS
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
# 1. Required Environment Variables
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
raise ValueError("HF_TOKEN environment variable is required")
# 2. Open AI Client Only
client = OpenAI(
base_url=API_BASE_URL,
api_key=HF_TOKEN
)
def run_inference():
env = CustomerSupportEnv()
for idx, task in enumerate(TASKS):
obs = env.reset(task_idx=idx)
done = False
step_idx = 0
rewards_history = []
# [START] FORMAT
print(f"[START] task={task.name} env=customer_support model={MODEL_NAME}")
while not done:
step_idx += 1
error_msg = "null"
reward_val = 0.00
action_str = ""
# 🚀 HEAVILY ENGINEERED PROMPT FOR STRICT COMPLIANCE
prompt = (
"System: You are an automated customer support AI. You MUST respond strictly in JSON format matching this schema: "
"{\"action_type\": \"ROUTE\"|\"ASK_INFO\"|\"REFUND\"|\"CLOSE\", \"argument\": \"string\"}\n\n"
"CRITICAL RULES:\n"
"1. If 'missing_info' in the observation is empty ([]), DO NOT use ASK_INFO. You must take action (ROUTE or REFUND).\n"
"2. If 'missing_info' contains items, you MUST use ASK_INFO. The 'argument' MUST contain the EXACT string from 'missing_info' (e.g., 'serial_number', 'order_id', 'photo_evidence'). Ask for ONLY ONE missing item at a time\n"
"3. When using ROUTE, the 'argument' MUST be exactly one of these three codes: 'IT_SUPPORT', 'HARDWARE_SUPPORT', or 'BILLING'. Do not output full sentences.\n"
"4. If the user wants a refund, and you have collected 'order_id', you MUST first use the REFUND action. Then, in the next step, use ROUTE with 'BILLING'.\n\n"
f"Observation: {obs.model_dump_json()}"
)
try:
response = client.chat.completions.create(
model=MODEL_NAME,
messages=[{"role": "user", "content": prompt}],
response_format={"type": "json_object"}
)
# Safely parse JSON in case Qwen outputs markdown ticks
raw_action = response.choices[0].message.content.strip()
if raw_action.startswith("```json"):
raw_action = raw_action[7:-3].strip()
elif raw_action.startswith("```"):
raw_action = raw_action[3:-3].strip()
action_data = json.loads(raw_action)
# Pydantic validation
action = Action(**action_data)
action_str = f"{action.action_type}('{action.argument}')"
# Env Step
obs = env.step(action)
done = obs.done
reward_val = float(obs.reward) if obs.reward is not None else 0.0
rewards_history.append(reward_val)
except Exception as e:
error_msg = str(e).replace('\n', ' ')
action_str = "ERROR"
done = True
rewards_history.append(0.00)
# [STEP] FORMAT
print(f"[STEP] step={step_idx} action={action_str} reward={reward_val:.2f} done={str(done).lower()} error={error_msg}")
# [END] FORMAT
# Use the last reward (grader's final score) as the task score
# Clamp strictly between 0.01 and 0.99 to satisfy validator
if rewards_history:
final_score = max(0.01, min(0.99, rewards_history[-1]))
else:
final_score = 0.01
success = final_score > 0.8
rewards_str = ",".join([f"{r:.2f}" for r in rewards_history])
print(f"[END] success={str(success).lower()} steps={step_idx} score={final_score:.2f} rewards={rewards_str}")
if __name__ == "__main__":
run_inference() |