import os
import json
import logging
import asyncio
from typing import List, Optional
from openai import OpenAI
from env.environment import SupportTicketEnv
from env.models import Action
logger = logging.getLogger(__name__)
# Root logging config; INFO keeps retry/validation warnings visible on stderr.
logging.basicConfig(level=logging.INFO)

# --- Configuration (each overridable via environment variables) ---
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")  # model used for every request
HF_TOKEN = os.getenv("HF_TOKEN")  # API key; may be None if unset
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")  # unused in this file; presumably read by the harness -- TODO confirm

# --- Episode limits and scoring ---
MAX_STEPS = 10  # hard cap on environment steps per task
MAX_TOTAL_REWARD = 1.0  # unused in this file; documented upper bound on score
SUCCESS_SCORE_THRESHOLD = 0.8  # episode counts as success at or above this score
def log_start(task: str, env: str, model: str):
    """Emit the machine-readable [START] marker for one task run."""
    line = f"[START] task={task} env={env} model={model}"
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None):
    """Emit one [STEP] marker; a falsy error renders as 'null', done as lowercase."""
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list):
    """Emit the final [END] marker summarising the whole episode."""
    joined_rewards = ",".join(format(r, ".2f") for r in rewards)
    flag = "true" if success else "false"
    print(
        f"[END] success={flag} steps={steps} score={score:.2f} rewards={joined_rewards}",
        flush=True,
    )
def parse_action(text: str) -> Action:
    """Pull the first JSON object out of *text* and turn it into an Action.

    Scans for '{' characters, attempts a raw JSON decode at each one, and
    validates the first successfully decoded dict with Pydantic.  If Pydantic
    validation fails, falls back to manual construction with a whitelist check
    on action_type.  Any unrecoverable failure yields a safe close_ticket
    default so the episode can always proceed.
    """
    decoder = json.JSONDecoder()
    allowed = ["fetch_user_data", "check_policy", "issue_refund", "reply_to_customer", "escalate", "close_ticket"]
    try:
        pos = text.find('{')
        while pos != -1:
            try:
                candidate, _ = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                # Not valid JSON at this brace; keep scanning.
                pos = text.find('{', pos + 1)
                continue
            if isinstance(candidate, dict):
                try:
                    return Action.model_validate(candidate)
                except Exception as val_err:
                    logger.warning("Action validation failed: %s", val_err)
                    # Fallback: build the Action by hand, forcing a known type.
                    kind = candidate.get("action_type", "close_ticket")
                    if kind not in allowed:
                        logger.error("Invalid action_type: %s. Defaulting to 'close_ticket'.", kind)
                        kind = "close_ticket"
                    return Action(action_type=kind, parameters=candidate.get("parameters", {}))
            pos = text.find('{', pos + 1)
    except Exception as e:
        logger.error("Failed to parse action: %s", e)
    # Default fallback if no valid action is found
    return Action(action_type="close_ticket", parameters={})
def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
    """Request the next action JSON from the model, with retry/backoff.

    Args:
        client: An OpenAI-style client.  Both the ``chat.completions`` and
            the ``responses`` interfaces are supported via duck typing.
        step: Current step number (currently unused; kept for interface
            stability).
        env_state: JSON dump of the latest environment observation.
        history: Prior "Step N: ..." summary lines, newest last.

    Returns:
        The model's raw text reply, or "{}" when every attempt fails so the
        caller's parse_action falls back to a safe default action.
    """
    system_prompt = (
        "You are an AI support agent resolving customer tickets.\n"
        "Available Actions:\n"
        "- fetch_user_data(user_id)\n"
        "- check_policy(issue_type)\n"
        "- issue_refund(amount)\n"
        "- reply_to_customer(message)\n"
        "- escalate(reason)\n"
        "- close_ticket(resolution)\n\n"
        "Must respond with JSON format:\n"
        "{\"action_type\": \"...\", \"parameters\": {\"...\": \"...\"}}"
    )
    history_str = "\n".join(history)
    user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"
    import time  # local import kept to avoid touching the module's import block

    max_retries = 3
    backoff_base = 0.5  # seconds; doubled each retry (0.5, 1.0, ...)
    for attempt in range(1, max_retries + 1):
        try:
            # Preferred interface: chat.completions.
            if hasattr(client, "chat") and hasattr(client.chat, "completions"):
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.1,
                )
                text = (completion.choices[0].message.content or "").strip()
                return text if text else "{}"
            # Fallback interface: the responses API.
            if hasattr(client, "responses") and hasattr(client.responses, "create"):
                completion = client.responses.create(model=MODEL_NAME, input=user_prompt, temperature=0.1)
                text = getattr(completion, "output_text", None)
                if text:
                    return text.strip()
                out = []
                for item in getattr(completion, "output", []) or []:
                    # NOTE(review): assumes output items are dicts -- confirm
                    # against the installed client version.
                    for c in item.get("content", []):
                        if c.get("type") == "output_text":
                            out.append(c.get("text", ""))
                if out:
                    return "".join(out).strip()
            raise RuntimeError("No supported model client method available")
        except Exception as exc:
            logger.warning("Model request attempt %d failed: %s", attempt, exc)
            if attempt == max_retries:
                break
            time.sleep(backoff_base * (2 ** (attempt - 1)))
    # BUG FIX: the original fell off the end of the function after exhausting
    # retries and implicitly returned None despite the `-> str` annotation;
    # callers (parse_action) expect a string.
    return "{}"
async def run_task(task_id: str, client: OpenAI) -> None:
    """Run one support-ticket episode and emit START/STEP/END log markers.

    Drives the environment for up to MAX_STEPS steps, asking the model for an
    action each step.  The END marker is always emitted, even if the episode
    raises partway through (the finally block guarantees it).
    """
    env = SupportTicketEnv(task_id=task_id)
    transcript: List[str] = []
    reward_trace: List[float] = []
    steps_taken = 0
    final_score = 0.0
    succeeded = False
    log_start(task=task_id, env="SupportTicketEnv", model=MODEL_NAME)
    try:
        observation = env.reset()
        observation_json = observation.model_dump_json(indent=2)
        for step_no in range(1, MAX_STEPS + 1):
            if env.get_state().is_done:
                break
            reply = get_model_message(client, step_no, observation_json, transcript)
            action = parse_action(reply)
            next_obs, _reward, done, info = env.step(action)
            observation_json = next_obs.model_dump_json(indent=2)
            # Reward is taken from info, not the raw step return value.
            step_reward = info.get("current_reward", 0.0)
            reward_trace.append(step_reward)
            steps_taken = step_no
            log_step(step=step_no, action=reply, reward=step_reward, done=done, error=None)
            transcript.append(f"Step {step_no}: {reply!r} -> reward {step_reward:+.2f}")
            if done:
                final_score = step_reward
                break
        # Clamp into [0, 1] before deciding success.
        final_score = min(max(final_score, 0.0), 1.0)
        succeeded = final_score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=succeeded, steps=steps_taken, score=final_score, rewards=reward_trace)
async def main() -> None:
    """Entry point: run each benchmark task sequentially with one shared client.

    CONSISTENCY FIX: the original re-read os.getenv("HF_TOKEN") here even
    though the module already defines the HF_TOKEN constant (which was
    otherwise unused); use the constant as the single source of truth.
    """
    if not HF_TOKEN:
        # Surface the misconfiguration early; requests would fail with 401.
        logger.warning("HF_TOKEN is not set; API calls will likely fail.")
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    tasks = ["task_easy_1", "task_medium_1", "task_hard_1"]
    for task_id in tasks:
        await run_task(task_id, client)


if __name__ == "__main__":
    asyncio.run(main())