Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import re | |
| from openai import OpenAI | |
| from baseline.policy import heuristic_policy | |
| from env.environment import OpenEnv | |
| from env.models import Action, Observation | |
| from env.runtime_config import RuntimeConfig | |
# --- Model endpoint configuration (read once at import time) ---------------
# When API_BASE_URL or HF_TOKEN is unset, choose_action() never calls the
# model and uses the heuristic policy instead.
API_BASE_URL = os.getenv("API_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN")

# Environment construction settings, resolved from the process environment.
runtime_config = RuntimeConfig.from_env()

# Built unconditionally. "EMPTY" stands in for a missing HF_TOKEN so the client
# can still be constructed; it is not used for requests in that case because
# choose_action() short-circuits to the heuristic policy.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=HF_TOKEN or "EMPTY",
)

# Action types accepted from the model; anything else triggers the fallback.
VALID_ACTIONS = {"triage", "respond", "resolve", "escalate"}
def compute_partial_score(total_reward: float, max_steps: int) -> float:
    """Normalise an episode's cumulative reward into the range [0.0, 1.0].

    The reward is divided by ``max_steps`` (floored at 1.0 so a zero or
    negative step budget cannot divide by zero), i.e. the best achievable
    total is treated as 1.0 per step.

    Args:
        total_reward: Sum of per-step rewards collected in the episode.
        max_steps: The episode's step budget.

    Returns:
        The ratio clamped to [0.0, 1.0].
    """
    denominator = max(float(max_steps), 1.0)
    ratio = total_reward / denominator
    # Clamp the top first, then the bottom (same order as before).
    capped = min(1.0, ratio)
    return max(0.0, capped)
def safe_parse(content: str) -> dict:
    """Recover a JSON object from a model reply.

    Tries, in order:

    1. the whole of *content* as JSON, if it decodes to an object;
    2. the first ``{`` in *content* from which a complete JSON object can be
       decoded. Surrounding prose, code fences, and any text or braces after
       the object are ignored.

    Args:
        content: Raw text returned by the model.

    Returns:
        The decoded dict, or ``{"action_type": "triage", "note": "fallback"}``
        when no JSON object can be recovered.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Decode one object starting at each "{" instead of matching
    # first-"{"-to-last-"}": that span breaks whenever another brace follows
    # the object (a second object, an emoticon, ...).
    decoder = json.JSONDecoder()
    start = content.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(content, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = content.find("{", start + 1)

    return {"action_type": "triage", "note": "fallback"}
def choose_action(observation: Observation) -> Action:
    """Pick the next action, preferring the LLM when it is configured.

    Falls back to ``heuristic_policy`` when:

    - ``API_BASE_URL`` or ``HF_TOKEN`` is unset;
    - the model proposes an action type outside ``VALID_ACTIONS``;
    - the request, parsing, or ``Action`` construction raises.

    Args:
        observation: Current environment observation. It is serialised with
            ``model_dump_json()`` and embedded in the prompt.

    Returns:
        The model-proposed ``Action``, or the heuristic policy's action.
    """
    if not API_BASE_URL or not HF_TOKEN:
        return heuristic_policy(observation)

    prompt = (
        "Return one action_type from [triage, respond, resolve, escalate] and a short note in JSON "
        "with keys action_type and note. "
        f"Observation: {observation.model_dump_json()}"
    )
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        content = response.choices[0].message.content or "{}"
        payload = safe_parse(content)

        # Normalise case/whitespace so e.g. "Respond" or " resolve\n" is
        # accepted instead of discarding the model's choice.
        action_type = payload.get("action_type")
        if isinstance(action_type, str):
            action_type = action_type.strip().lower()
        if action_type not in VALID_ACTIONS:
            return heuristic_policy(observation)
        payload["action_type"] = action_type

        # A JSON null note must not become the literal string "None".
        note = payload.get("note")
        payload["note"] = "" if note is None else str(note)
        return Action(**payload)
    except Exception:
        # Deliberate best-effort boundary: any network, parsing, or validation
        # failure degrades to the heuristic policy instead of aborting the episode.
        return heuristic_policy(observation)
def run_episode(task: str, max_steps: int = 20) -> dict:
    """Play one episode at difficulty *task* and return a summary.

    Writes one ``[START]`` line, one ``[STEP]`` line per step, and one
    ``[END]`` line to stdout. The episode stops when the environment reports
    done or after *max_steps* steps.

    Args:
        task: Difficulty passed to ``OpenEnv`` as ``difficulty``.
        max_steps: Step budget. It is also the denominator of the partial score.

    Returns:
        A dict with task, success flag, step count, rounded reward total,
        the chosen score, the environment and partial scores, and whether
        the OpenAI client was configured.
    """
    env_kwargs = {**runtime_config.to_env_kwargs(), "difficulty": task}
    env = OpenEnv(**env_kwargs)
    observation = env.reset()

    uses_llm = bool(API_BASE_URL and HF_TOKEN)
    client_mode = "enabled" if uses_llm else "fallback"
    print(f"[START] task={task} env=workflow model={MODEL_NAME} openai_client={client_mode}")

    finished = False
    total_reward = 0.0
    reward_labels: list[str] = []
    step_count = 0
    while step_count < max_steps and not finished:
        action = choose_action(observation)
        observation, reward, finished, _ = env.step(action)
        total_reward += reward
        reward_labels.append(f"{reward:.2f}")
        step_count += 1
        print(
            f"[STEP] step={step_count} action={action.action_type} "
            f"reward={reward:.2f} done={str(finished).lower()} error=null"
        )

    success = observation.ticket_status == "resolved"
    print(
        f"[END] success={str(success).lower()} steps={step_count} "
        f"rewards={','.join(reward_labels)}"
    )

    env_score = float(env.state().get("score", 0.0))
    partial_score = compute_partial_score(total_reward, max_steps)
    # Prefer the environment's own score; fall back to the reward-based ratio.
    if env_score > 0.0:
        score = env_score
    else:
        score = partial_score

    return {
        "task": task,
        "success": success,
        "steps": step_count,
        "rewards": round(total_reward, 2),
        "score": round(score, 4),
        "env_score": round(env_score, 4),
        "partial_score": round(partial_score, 4),
        "openai_client_configured": uses_llm,
    }
def run_all_tasks() -> list[dict]:
    """Run one episode per difficulty, in the order easy, medium, hard.

    Returns:
        The ``run_episode`` summaries in execution order.
    """
    return [run_episode(difficulty) for difficulty in ("easy", "medium", "hard")]
if __name__ == "__main__":
    # Script entry point: run every task, then print the summaries as indented JSON.
    summary = run_all_tasks()
    rendered = json.dumps(summary, indent=2)
    print(rendered)