# Spaces: Sleeping — Hugging Face Spaces status banner captured alongside the
# file during extraction; not part of the program itself.
import os
import json
import logging
import asyncio
from typing import List, Optional

from openai import OpenAI

from env.environment import SupportTicketEnv
from env.models import Action

# Module logger; basicConfig at import time is a deliberate script-level
# side effect so STEP/END markers and warnings are visible by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Endpoint / model configuration, all overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")

# Episode limits and scoring thresholds.
MAX_STEPS = 10
MAX_TOTAL_REWARD = 1.0  # NOTE(review): currently unused here — may be read elsewhere
SUCCESS_SCORE_THRESHOLD = 0.8
def log_start(task: str, env: str, model: str) -> None:
    """Emit the structured [START] marker line for a task run."""
    print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str] = None) -> None:
    """Emit the structured [STEP] marker line for one environment step.

    ``done`` is lower-cased and a falsy ``error`` is rendered as ``null`` so
    the line stays machine-parseable.
    """
    rendered_error = error if error else "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={rendered_error}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Emit the structured [END] marker summarizing the whole episode."""
    joined = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={joined}",
        flush=True,
    )
def parse_action(text: str) -> Action:
    """Extract the first JSON object embedded in *text* and turn it into an Action.

    Scans forward for ``{`` candidates, decodes the first parseable JSON
    object, and validates it with Pydantic. Falls back to a manually-built
    Action when validation fails, and ultimately to a safe ``close_ticket``
    Action when no usable JSON object can be recovered at all.
    """
    allowed = [
        "fetch_user_data", "check_policy", "issue_refund",
        "reply_to_customer", "escalate", "close_ticket",
    ]
    try:
        decoder = json.JSONDecoder()
        pos = text.find("{")
        while pos != -1:
            try:
                candidate, _end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                # Not valid JSON at this brace; resume from the next one.
                pos = text.find("{", pos + 1)
                continue
            if isinstance(candidate, dict):
                try:
                    return Action.model_validate(candidate)
                except Exception as val_err:
                    logger.warning("Action validation failed: %s", val_err)
                    # Fallback to manual construction with validation
                    kind = candidate.get("action_type", "close_ticket")
                    if kind not in allowed:
                        logger.error(
                            "Invalid action_type: %s. Defaulting to 'close_ticket'.",
                            kind,
                        )
                        kind = "close_ticket"
                    return Action(action_type=kind, parameters=candidate.get("parameters", {}))
            # Decoded something that is not a dict; keep scanning.
            pos = text.find("{", pos + 1)
    except Exception as e:
        logger.error("Failed to parse action: %s", e)
    # Default fallback if no valid action is found
    return Action(action_type="close_ticket", parameters={})
def get_model_message(client, step: int, env_state: str, history: List[str]) -> str:
    """Ask the model for its next action given the latest observation.

    Parameters:
        client: an OpenAI-style client exposing either
            ``chat.completions.create`` or ``responses.create`` (duck-typed).
        step: current step number (currently unused; kept for interface
            stability with callers).
        env_state: JSON-serialized current observation.
        history: prior "Step N: ..." lines fed back as conversational context.

    Returns:
        The raw model text, or ``"{}"`` when every attempt fails — so the
        caller's parser always receives a string and can fall back to a safe
        default action.
    """
    import time  # local import: only needed for retry backoff

    system_prompt = (
        "You are an AI support agent resolving customer tickets.\n"
        "Available Actions:\n"
        "- fetch_user_data(user_id)\n"
        "- check_policy(issue_type)\n"
        "- issue_refund(amount)\n"
        "- reply_to_customer(message)\n"
        "- escalate(reason)\n"
        "- close_ticket(resolution)\n\n"
        "Must respond with JSON format:\n"
        "{\"action_type\": \"...\", \"parameters\": {\"...\": \"...\"}}"
    )
    history_str = "\n".join(history)
    user_prompt = f"History:\n{history_str}\n\nCurrent Observation:\n{env_state}\n\nWhat is your next action JSON?"

    # Retry/backoff parameters: 3 attempts with exponential backoff 0.5s, 1s.
    max_retries = 3
    backoff_base = 0.5
    for attempt in range(1, max_retries + 1):
        try:
            # Support a few possible client interfaces (chat.completions or responses)
            if hasattr(client, "chat") and hasattr(client.chat, "completions"):
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.1,
                )
                text = (completion.choices[0].message.content or "").strip()
                return text if text else "{}"
            if hasattr(client, "responses") and hasattr(client.responses, "create"):
                completion = client.responses.create(model=MODEL_NAME, input=user_prompt, temperature=0.1)
                text = getattr(completion, "output_text", None)
                if text:
                    return text.strip()
                # NOTE(review): assumes each output item is dict-like with a
                # "content" list of {"type", "text"} entries — confirm against
                # the actual client's response shape.
                parts = []
                for item in getattr(completion, "output", []) or []:
                    for chunk in item.get("content", []):
                        if chunk.get("type") == "output_text":
                            parts.append(chunk.get("text", ""))
                if parts:
                    return "".join(parts).strip()
            raise RuntimeError("No supported model client method available")
        except Exception as exc:
            logger.warning("Model request attempt %d failed: %s", attempt, exc)
            if attempt == max_retries:
                break
            time.sleep(backoff_base * (2 ** (attempt - 1)))
    # BUG FIX: the original could fall off the end after exhausting retries
    # and implicitly return None despite the `-> str` annotation; always hand
    # back "{}" so callers get a parseable string.
    return "{}"
async def run_task(task_id: str, client: OpenAI) -> None:
    """Run one SupportTicketEnv episode end-to-end, emitting START/STEP/END markers.

    The [END] marker is always printed (via ``finally``), even when the
    episode raises partway through.
    """
    env = SupportTicketEnv(task_id=task_id)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_id, env="SupportTicketEnv", model=MODEL_NAME)
    try:
        observation = env.reset()
        last_echoed = observation.model_dump_json(indent=2)
        for step in range(1, MAX_STEPS + 1):
            # Stop early if a previous action already finished the episode.
            if env.get_state().is_done:
                break
            raw_reply = get_model_message(client, step, last_echoed, history)
            action = parse_action(raw_reply)
            next_obs, reward, done, info = env.step(action)
            last_echoed = next_obs.model_dump_json(indent=2)
            # Reward is taken from info, not the raw step reward.
            step_reward = info.get("current_reward", 0.0)
            rewards.append(step_reward)
            steps_taken = step
            log_step(step=step, action=raw_reply, reward=step_reward, done=done, error=None)
            history.append(f"Step {step}: {raw_reply!r} -> reward {step_reward:+.2f}")
            if done:
                score = step_reward
                break
        # Clamp the reported score into [0.01, 0.99], preserving the original
        # scoring convention (never exactly 0 or 1).
        score = min(max(score, 0.01), 0.99)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
async def main() -> None:
    """Build one API client and run the fixed benchmark task list sequentially."""
    # CONSISTENCY FIX: reuse the module-level HF_TOKEN constant instead of
    # re-reading os.getenv("HF_TOKEN") here, so configuration is resolved in
    # exactly one place (the constants block at the top of the file).
    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    tasks = ["task_easy_1", "task_medium_1", "task_hard_1"]
    # Tasks run sequentially so each episode's log lines stay contiguous.
    for task_id in tasks:
        await run_task(task_id, client)
if __name__ == "__main__":
    # Script entry point: drive the async benchmark loop.
    asyncio.run(main())