| import os |
| import json |
| from openai import OpenAI |
| from env import CustomerSupportEnv |
| from models import Action |
| from tasks import TASKS |
|
|
| try: |
| from dotenv import load_dotenv |
| load_dotenv() |
| except ImportError: |
| pass |
|
|
|
|
| |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1") |
| MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini") |
| HF_TOKEN = os.getenv("HF_TOKEN") |
|
|
| if HF_TOKEN is None: |
| raise ValueError("HF_TOKEN environment variable is required") |
|
|
| |
| client = OpenAI( |
| base_url=API_BASE_URL, |
| api_key=HF_TOKEN |
| ) |
|
|
| def run_inference(): |
| env = CustomerSupportEnv() |
| |
| for idx, task in enumerate(TASKS): |
| obs = env.reset(task_idx=idx) |
| done = False |
| step_idx = 0 |
| rewards_history = [] |
| |
| |
| print(f"[START] task={task.name} env=customer_support model={MODEL_NAME}") |
| |
| while not done: |
| step_idx += 1 |
| error_msg = "null" |
| reward_val = 0.00 |
| action_str = "" |
| |
| |
| prompt = ( |
| "System: You are an automated customer support AI. You MUST respond strictly in JSON format matching this schema: " |
| "{\"action_type\": \"ROUTE\"|\"ASK_INFO\"|\"REFUND\"|\"CLOSE\", \"argument\": \"string\"}\n\n" |
| "CRITICAL RULES:\n" |
| "1. If 'missing_info' in the observation is empty ([]), DO NOT use ASK_INFO. You must take action (ROUTE or REFUND).\n" |
| "2. If 'missing_info' contains items, you MUST use ASK_INFO. The 'argument' MUST contain the EXACT string from 'missing_info' (e.g., 'serial_number', 'order_id', 'photo_evidence'). Ask for ONLY ONE missing item at a time\n" |
| "3. When using ROUTE, the 'argument' MUST be exactly one of these three codes: 'IT_SUPPORT', 'HARDWARE_SUPPORT', or 'BILLING'. Do not output full sentences.\n" |
| "4. If the user wants a refund, and you have collected 'order_id', you MUST first use the REFUND action. Then, in the next step, use ROUTE with 'BILLING'.\n\n" |
| f"Observation: {obs.model_dump_json()}" |
| ) |
|
|
| try: |
| response = client.chat.completions.create( |
| model=MODEL_NAME, |
| messages=[{"role": "user", "content": prompt}], |
| response_format={"type": "json_object"} |
| ) |
| |
| |
| raw_action = response.choices[0].message.content.strip() |
| if raw_action.startswith("```json"): |
| raw_action = raw_action[7:-3].strip() |
| elif raw_action.startswith("```"): |
| raw_action = raw_action[3:-3].strip() |
| |
| action_data = json.loads(raw_action) |
| |
| |
| action = Action(**action_data) |
| action_str = f"{action.action_type}('{action.argument}')" |
| |
| |
| obs = env.step(action) |
| done = obs.done |
| reward_val = float(obs.reward) if obs.reward is not None else 0.0 |
| rewards_history.append(reward_val) |
| |
| except Exception as e: |
| error_msg = str(e).replace('\n', ' ') |
| action_str = "ERROR" |
| done = True |
| rewards_history.append(0.00) |
| |
| |
| print(f"[STEP] step={step_idx} action={action_str} reward={reward_val:.2f} done={str(done).lower()} error={error_msg}") |
|
|
| |
| |
| success = (sum(rewards_history) > 0.8) if rewards_history else False |
| rewards_str = ",".join([f"{r:.2f}" for r in rewards_history]) |
| print(f"[END] success={str(success).lower()} steps={step_idx} rewards={rewards_str}") |
|
|
| if __name__ == "__main__": |
| run_inference() |