Spaces:
Sleeping
Sleeping
File size: 6,238 Bytes
d416acc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from environment.api_triage_env import APITriageEnv
from environment.action_space import get_all_actions
from environment.incident_generator import get_incident_by_type
# ============================================
# Environment Variables
# ============================================
# Endpoint / model selection (overridable through the environment).
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# The Hugging Face token is used directly as the OpenAI-compatible API key.
HF_TOKEN = os.environ.get("HF_TOKEN")
API_KEY = HF_TOKEN

# Labels reported in the [START] log line.
TASK_NAME = os.environ.get("TASK_NAME", "api_triage")
BENCHMARK = os.environ.get("BENCHMARK", "api_triage_agent")

# Episode length and sampling parameters.
MAX_STEPS = 10
TEMPERATURE = 0.7
MAX_TOKENS = 50
# NOTE(review): not referenced anywhere in the visible code.
SUCCESS_SCORE_THRESHOLD = 0.5
# ============================================
# System Prompt
# ============================================
# Action names supplied by the environment package; interpolated into the
# system prompt below and used elsewhere in this file to validate replies.
AVAILABLE_ACTIONS = get_all_actions()

# System message sent with every request, assembled one line at a time.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an API debugging agent. Your job is to diagnose and fix API failures.",
        f"Available actions: {AVAILABLE_ACTIONS}",
        "Rules:",
        '- First use "inspect_logs" to understand the problem',
        "- Then take the correct fix action based on the error",
        '- Finally use "resolve" to end the episode',
        "Reply with ONLY the action name. No explanations. No quotes.",
    )
)
# ============================================
# Logging Functions (Required Format)
# ============================================
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log.

    Args:
        task: Task label written as ``task=...``.
        env: Benchmark/environment label written as ``env=...``.
        model: Model identifier written as ``model=...``.
    """
    fields = (("task", task), ("env", env), ("model", model))
    print("[START] " + " ".join(f"{key}={value}" for key, value in fields), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line for an environment step.

    Args:
        step: Step number.
        action: Action that was executed.
        reward: Reward for the step, printed with two decimals.
        done: Episode-finished flag, printed lower-cased (``true``/``false``).
        error: Error text; any falsy value (``None``, ``""``) prints as ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line for an episode.

    Args:
        success: Whether the episode succeeded, printed lower-cased.
        steps: Number of steps taken.
        score: Episode score, printed with three decimals.
        rewards: Per-step rewards, printed comma-joined with two decimals each
            (an empty list prints as an empty value).
    """
    reward_csv = ",".join(map("{:.2f}".format, rewards))
    success_flag = str(success).lower()
    print(
        f"[END] success={success_flag} steps={steps} score={score:.3f} rewards={reward_csv}",
        flush=True,
    )
# ============================================
# Prompt Builder
# ============================================
def build_user_prompt(
    step: int,
    observation,
    last_reward: float,
    history: List[str],
    *,
    actions=None,
) -> str:
    """Render the per-step user message shown to the model.

    Args:
        step: Step number within the current episode.
        observation: Environment observation; only its ``incident_summary``,
            ``response_code``, ``logs`` and ``fix_applied`` attributes are read.
        last_reward: Reward returned by the previous step.
        history: One line per earlier step; only the 4 most recent are shown
            (``"None"`` when empty).
        actions: Action names offered to the model; defaults to the
            module-level ``AVAILABLE_ACTIONS``.

    Returns:
        The prompt text, one field per line, with no leading indentation.
    """
    if actions is None:
        actions = AVAILABLE_ACTIONS
    history_block = "\n".join(history[-4:]) if history else "None"
    # Built from explicit lines rather than textwrap.dedent(f"""...""") on an
    # indented template: interpolated multi-line values (history, logs) put
    # their continuation lines at column 0, which collapses dedent's common
    # margin and leaves the source indentation inside the prompt.
    lines = [
        f"Step: {step}",
        f"Incident: {observation.incident_summary}",
        f"Response Code: {observation.response_code}",
        f"Logs: {observation.logs}",
        f"Fix Applied: {observation.fix_applied}",
        f"Last Reward: {last_reward:.2f}",
        "Previous Actions:",
        history_block,
        f"Choose an action from: {actions}",
        "Reply with ONLY the action name.",
    ]
    return "\n".join(lines)
# ============================================
# LLM Caller
# ============================================
def _normalize_action(raw: Optional[str]) -> str:
    """Reduce a raw model reply to a bare, lower-cased action token.

    Models can ignore the "No quotes" instruction and wrap the answer in
    quotes or backticks, add trailing punctuation, or answer inside a fenced
    code block. This returns the first non-empty, non-fence line with
    surrounding whitespace, quotes, backticks and ``.,;:`` trimmed, or ``""``
    when nothing remains.
    """
    for line in (raw or "").splitlines():
        if line.lstrip().startswith("```"):
            continue  # skip code-fence markers such as ``` or ```text
        token = line.strip().strip("\"'`.,;:").strip()
        if token:
            return token.lower()
    return ""


def get_model_action(client: OpenAI, step: int, observation, last_reward: float, history: List[str]) -> str:
    """Ask the model for the next action, falling back to ``"inspect_logs"``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        step: Step number, forwarded to the prompt builder.
        observation: Current environment observation, forwarded to the prompt.
        last_reward: Reward from the previous step, forwarded to the prompt.
        history: Earlier step summaries, forwarded to the prompt.

    Returns:
        The normalized reply when it is in ``AVAILABLE_ACTIONS``; otherwise
        ``"inspect_logs"`` (also returned when the request itself fails).
    """
    user_prompt = build_user_prompt(step, observation, last_reward, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_reply = completion.choices[0].message.content
    except Exception as exc:
        # Best-effort boundary: any API/network/response-shape failure degrades
        # to a safe default action instead of aborting the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "inspect_logs"
    action = _normalize_action(raw_reply)
    if action not in AVAILABLE_ACTIONS:
        print(f"[DEBUG] Invalid action '{action}', defaulting to inspect_logs", flush=True)
        return "inspect_logs"
    return action
# ============================================
# Main Async Function
# ============================================
def _run_task(client: "OpenAI", env: "APITriageEnv", task_id: str) -> None:
    """Run one forced-incident episode for ``task_id`` and log START/STEP/END."""
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    success = False
    # One [START] per episode so that every [END] below has a matching opener.
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Reset env and FORCE the specific incident type (no randomness).
        env.reset()
        env.incident = get_incident_by_type(task_id)
        observation = env.state()  # refresh observation with the forced incident
        last_reward = 0.0
        for step in range(1, MAX_STEPS + 1):
            action = get_model_action(client, step, observation, last_reward, history)
            observation, reward, done, info = env.step(action)
            rewards.append(reward)
            steps_taken = step
            last_reward = reward
            log_step(step=step, action=action, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action} -> reward {reward:.2f}")
            if done:
                success = info.get("resolution") == "success"
                break
    except Exception as exc:
        # Keep evaluating the remaining tasks. Report the steps/rewards that
        # were already logged for this episode rather than a fabricated 0.
        print(f"[DEBUG] Error in task {task_id}: {exc}", flush=True)
        log_end(success=False, steps=steps_taken, score=0.05, rewards=rewards or [0.0])
        return
    # Score strictly between 0 and 1
    task_score = 0.95 if success else 0.05
    log_end(success=success, steps=steps_taken, score=task_score, rewards=rewards)


async def main() -> None:
    """Evaluate the model on every incident type, one logged episode per task.

    Prints an error and returns early when no API key is configured. Declared
    ``async`` so it can be driven by ``asyncio.run``; the work inside is
    synchronous.
    """
    if not API_KEY:
        print("[ERROR] HF_TOKEN environment variable not set", flush=True)
        return
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = APITriageEnv(max_steps=MAX_STEPS)
    # All 6 task IDs matching openenv.yaml — each evaluated explicitly
    task_ids = ["auth_error", "missing_fields", "rate_limit", "timeout", "wrong_endpoint", "server_error"]
    for tid in task_ids:
        _run_task(client, env, tid)
# ============================================
# Run
# ============================================
if __name__ == "__main__":
    # Script entry point: run the async evaluation driver on a fresh event loop.
    asyncio.run(main())
|