Spaces:
Sleeping
Sleeping
| import asyncio | |
| import os | |
| import textwrap | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| from environment.api_triage_env import APITriageEnv | |
| from environment.action_space import get_all_actions | |
| from environment.incident_generator import get_incident_by_type | |
# ============================================
# Environment Variables
# ============================================
# Inference endpoint and model, overridable from the process environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# The Hugging Face token is used directly as the OpenAI-compatible API key
# (None when HF_TOKEN is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
API_KEY = HF_TOKEN

# Task / benchmark labels, overridable from the process environment.
TASK_NAME = os.environ.get("TASK_NAME", "api_triage")
BENCHMARK = os.environ.get("BENCHMARK", "api_triage_agent")

# Episode length and sampling parameters for the chat-completion call.
MAX_STEPS = 10
TEMPERATURE = 0.7
MAX_TOKENS = 50
# NOTE(review): not referenced by the code visible in this file — confirm it is still needed.
SUCCESS_SCORE_THRESHOLD = 0.5
# ============================================
# System Prompt
# ============================================
AVAILABLE_ACTIONS = get_all_actions()

# System message for every model call: lists the legal actions, steers the
# model toward an inspect -> fix -> resolve sequence, and asks for a bare
# action name as the reply. Lines are joined explicitly so no dedent/strip
# pass is needed.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an API debugging agent. Your job is to diagnose and fix API failures.",
        f"Available actions: {AVAILABLE_ACTIONS}",
        "Rules:",
        '- First use "inspect_logs" to understand the problem',
        "- Then take the correct fix action based on the error",
        '- Finally use "resolve" to end the episode',
        "Reply with ONLY the action name. No explanations. No quotes.",
    )
)
# ============================================
# Logging Functions (Required Format)
# ============================================
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a logged run.

    Args:
        task: Value written after ``task=``.
        env: Value written after ``env=``.
        model: Value written after ``model=``.
    """
    parts = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    print(" ".join(parts), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line for a single environment step.

    Args:
        step: 1-based step index.
        action: Action name sent to the environment.
        reward: Reward for the step, printed with two decimals.
        done: Episode-finished flag, printed lower-cased (``true``/``false``).
        error: Error text, or a falsy value to print ``null``.
    """
    fields = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line for an episode.

    Args:
        success: Printed lower-cased (``true``/``false``).
        steps: Number of steps taken.
        score: Episode score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    reward_list = ",".join(map("{:.2f}".format, rewards))
    summary = f"success={str(success).lower()} steps={steps} score={score:.3f} rewards={reward_list}"
    print("[END] " + summary, flush=True)
# ============================================
# Prompt Builder
# ============================================
def build_user_prompt(step: int, observation, last_reward: float, history: List[str]) -> str:
    """Render the per-step user message sent to the model.

    Args:
        step: Current 1-based step index.
        observation: Object exposing ``incident_summary``, ``response_code``,
            ``logs`` and ``fix_applied`` attributes.
        last_reward: Reward from the previous step (two decimals in the text).
        history: Earlier step summaries; only the last four are included,
            and ``"None"`` is shown when the list is empty.

    Returns:
        The prompt text, one field per line, ending with the reply instruction.
    """
    if not history:
        recent_actions = "None"
    else:
        recent_actions = "\n".join(history[-4:])

    lines = [
        f"Step: {step}",
        f"Incident: {observation.incident_summary}",
        f"Response Code: {observation.response_code}",
        f"Logs: {observation.logs}",
        f"Fix Applied: {observation.fix_applied}",
        f"Last Reward: {last_reward:.2f}",
        "Previous Actions:",
        recent_actions,
        f"Choose an action from: {AVAILABLE_ACTIONS}",
        "Reply with ONLY the action name.",
    ]
    # Same dedent + strip pass as before: blanks out whitespace-only lines that
    # may arrive via interpolated fields and trims the ends.
    return textwrap.dedent("\n".join(lines)).strip()
# ============================================
# LLM Caller
# ============================================
def _normalize_action(raw: str) -> str:
    """Reduce a raw model reply to a lower-cased action-name candidate.

    Chat models often ignore the "no quotes / no explanations" instruction.
    They may wrap the name in quotes, backticks or ``**bold**``, add trailing
    punctuation, or put an explanation or code fence on other lines. Return
    the first line that is still non-empty after trimming those characters,
    or "" when there is none.
    """
    for line in raw.splitlines():
        candidate = line.strip().strip("`'\"*.,:;!").strip().lower()
        if candidate:
            return candidate
    return ""


def get_model_action(client: OpenAI, step: int, observation, last_reward: float, history: List[str]) -> str:
    """Ask the model for the next action, falling back to ``"inspect_logs"``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        step: Current 1-based step index (forwarded to the prompt).
        observation: Current environment observation (forwarded to the prompt).
        last_reward: Reward from the previous step (forwarded to the prompt).
        history: Earlier step summaries (forwarded to the prompt).

    Returns:
        A member of ``AVAILABLE_ACTIONS``. ``"inspect_logs"`` is returned, with
        a ``[DEBUG]`` line printed, when the request fails or the normalized
        reply is not a known action.
    """
    user_prompt = build_user_prompt(step, observation, last_reward, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_reply = completion.choices[0].message.content or ""
    except Exception as exc:  # network/API boundary: degrade to a safe action
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "inspect_logs"

    # Tolerate quotes, backticks, punctuation and extra lines around the name.
    action = _normalize_action(raw_reply)
    if action not in AVAILABLE_ACTIONS:
        print(f"[DEBUG] Invalid action '{action}', defaulting to inspect_logs", flush=True)
        return "inspect_logs"
    return action
# ============================================
# Main Async Function
# ============================================
def _run_task(client: OpenAI, env: APITriageEnv, task_id: str) -> None:
    """Play one episode of ``task_id`` and emit its [START]/[STEP]/[END] lines.

    The environment is reset and its ``incident`` is overwritten with
    ``get_incident_by_type(task_id)`` so the episode is not random. The model
    chooses up to ``MAX_STEPS`` actions; the episode counts as a success only
    when the environment reports ``done`` with ``info["resolution"] ==
    "success"``. Any exception ends the episode as a failure and prints a
    ``[DEBUG]`` line. The ``[END]`` line is always printed and reports the
    steps and rewards actually collected up to that point.
    """
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    success = False
    # One [START] per episode so every [END] has a matching [START].
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Reset env and FORCE the specific incident type (no randomness).
        env.reset()
        env.incident = get_incident_by_type(task_id)
        observation = env.state()  # refresh observation with forced incident
        last_reward = 0.0
        for step in range(1, MAX_STEPS + 1):
            action = get_model_action(client, step, observation, last_reward, history)
            observation, reward, done, info = env.step(action)
            rewards.append(reward)
            steps_taken = step
            last_reward = reward
            log_step(step=step, action=action, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action} -> reward {reward:.2f}")
            if done:
                success = info.get("resolution") == "success"
                break
    except Exception as exc:  # per-task boundary: one failing task must not abort the run
        print(f"[DEBUG] Error in task {task_id}: {exc}", flush=True)
        success = False
    # Score strictly between 0 and 1.
    task_score = 0.95 if success else 0.05
    # Keep the [0.0] placeholder when no step completed, as before.
    log_end(success=success, steps=steps_taken, score=task_score, rewards=rewards or [0.0])


async def main() -> None:
    """Evaluate the model on every task ID, one logged episode per task.

    Prints an ``[ERROR]`` line and returns early when no API key is set.
    """
    if not API_KEY:
        print("[ERROR] HF_TOKEN environment variable not set", flush=True)
        return
    # The context manager closes the client's HTTP resources when the run ends.
    with OpenAI(base_url=API_BASE_URL, api_key=API_KEY) as client:
        env = APITriageEnv(max_steps=MAX_STEPS)
        # All 6 task IDs matching openenv.yaml — each evaluated explicitly
        task_ids = ["auth_error", "missing_fields", "rate_limit", "timeout", "wrong_endpoint", "server_error"]
        for tid in task_ids:
            _run_task(client, env, tid)


# ============================================
# Run
# ============================================
if __name__ == "__main__":
    asyncio.run(main())