# Page residue from the Hugging Face Spaces file viewer (Space status: Sleeping; file size: 6,369 bytes).
import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from environment.api_triage_env import APITriageEnv
from environment.action_space import get_all_actions
from environment.incident_generator import get_incident_by_type
# ============================================
# Environment Variables
# ============================================
# Endpoint and model selection; each value can be overridden through the environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# The Hugging Face token is reused as the API key handed to the OpenAI client.
HF_TOKEN = os.environ.get("HF_TOKEN")
API_KEY = HF_TOKEN

# Labels reported on the [START] line.
TASK_NAME = os.environ.get("TASK_NAME", "api_triage")
BENCHMARK = os.environ.get("BENCHMARK", "api_triage_agent")

# Episode length cap (also passed to the environment) and sampling settings
# forwarded to the chat-completion request.
MAX_STEPS = 10
TEMPERATURE = 0.7
MAX_TOKENS = 50

# A run counts as successful once its normalized score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.5

# Best-case episode reward: inspect_logs (0.5) + fix (5.0) + resolve (15.0).
MAX_TOTAL_REWARD = 0.5 + 5.0 + 15.0
# ============================================
# System Prompt
# ============================================
# Action names reported by the environment; interpolated into both prompts and
# used to validate the model's reply.
AVAILABLE_ACTIONS = get_all_actions()

# System prompt assembled line by line so the text does not depend on how the
# source file happens to be indented.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an API debugging agent. Your job is to diagnose and fix API failures.",
        f"Available actions: {AVAILABLE_ACTIONS}",
        "Rules:",
        '- First use "inspect_logs" to understand the problem',
        "- Then take the correct fix action based on the error",
        '- Finally use "resolve" to end the episode',
        "Reply with ONLY the action name. No explanations. No quotes.",
    )
)
# ============================================
# Logging Functions (Required Format)
# ============================================
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode's log.

    Args:
        task: Task identifier written after ``task=``.
        env: Benchmark/environment label written after ``env=``.
        model: Model name written after ``model=``.
    """
    # flush=True so the line is emitted immediately rather than sitting in a buffer.
    print(f"[START] task={task} env={env} model={model}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line describing a single environment step.

    Args:
        step: 1-based step number.
        action: Action name that was sent to the environment.
        reward: Reward for the step; printed with two decimals.
        done: Whether the episode ended on this step; printed as ``true``/``false``.
        error: Error text for the step. ``None`` or an empty string is printed
            as the literal ``null``.
    """
    error_val = error if error else "null"
    # Lower-case the boolean so the line reads done=true / done=false.
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` line that closes an episode's log.

    Args:
        success: Final success flag; printed as ``true``/``false``.
        steps: Number of steps that were executed.
        score: Final score; printed with three decimals.
        rewards: Per-step rewards; printed comma-separated with two decimals
            each. An empty list yields an empty ``rewards=`` field.
    """
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
# ============================================
# Prompt Builder
# ============================================
def build_user_prompt(step: int, observation, last_reward: float, history: List[str]) -> str:
    """Build the per-step user message sent to the model.

    The prompt is assembled from an explicit list of lines rather than a
    dedented triple-quoted template: interpolated values that span several
    lines (the history block, possibly the logs) have no leading indentation,
    which would stop ``textwrap.dedent`` from removing the template's margin.

    Args:
        step: Current step number.
        observation: Object exposing ``incident_summary``, ``response_code``,
            ``logs`` and ``fix_applied`` attributes.
        last_reward: Reward from the previous step; printed with two decimals.
        history: Earlier step descriptions; only the last four are included.

    Returns:
        The prompt text, one field per line, with no leading or trailing
        whitespace added by the template.
    """
    # Keep the prompt short: only the four most recent steps are shown.
    history_block = "\n".join(history[-4:]) if history else "None"
    prompt_lines = [
        f"Step: {step}",
        f"Incident: {observation.incident_summary}",
        f"Response Code: {observation.response_code}",
        f"Logs: {observation.logs}",
        f"Fix Applied: {observation.fix_applied}",
        f"Last Reward: {last_reward:.2f}",
        "Previous Actions:",
        history_block,
        f"Choose an action from: {AVAILABLE_ACTIONS}",
        "Reply with ONLY the action name.",
    ]
    return "\n".join(prompt_lines)
# ============================================
# LLM Caller
# ============================================
def get_model_action(client: OpenAI, step: int, observation, last_reward: float, history: List[str]) -> str:
    """Ask the model for the next action and validate its reply.

    Args:
        client: OpenAI-compatible client used for the chat-completion request.
        step: Current step number (forwarded to the prompt builder).
        observation: Current environment observation (forwarded to the prompt builder).
        last_reward: Reward from the previous step (forwarded to the prompt builder).
        history: Earlier step descriptions (forwarded to the prompt builder).

    Returns:
        A member of ``AVAILABLE_ACTIONS``. Falls back to ``"inspect_logs"``
        when the request fails or the reply is not a known action.
    """
    fallback_action = "inspect_logs"
    user_prompt = build_user_prompt(step, observation, last_reward, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        reply = (completion.choices[0].message.content or "").strip().lower()
    except Exception as exc:
        # Best-effort boundary: any request/response failure degrades to the
        # fallback action instead of aborting the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return fallback_action

    # Despite the "No quotes" instruction, replies can arrive wrapped in
    # quotes/backticks or with a trailing period; peel those off before
    # validating so an otherwise-correct answer is not discarded.
    action = reply.strip("\"'`. \t\r\n")
    if action not in AVAILABLE_ACTIONS:
        print(f"[DEBUG] Invalid action '{reply}', defaulting to inspect_logs", flush=True)
        return fallback_action
    return action
# ============================================
# Run a single task episode
# ============================================
def run_task(client: OpenAI, task_id: str) -> None:
    """Run one task episode and log it as ``[START]`` -> ``[STEP]``... -> ``[END]``.

    The environment is reset, the incident for ``task_id`` is forced onto it,
    and the model is queried for up to ``MAX_STEPS`` actions. The final score
    is the reward sum divided by ``MAX_TOTAL_REWARD``, clamped to
    ``[0.001, 0.999]``; success means the score reaches
    ``SUCCESS_SCORE_THRESHOLD``.

    Any exception raised during the episode is logged as a ``[DEBUG]`` line
    and swallowed; the ``[END]`` line is still printed, with the default
    score of 0.001 and ``success=false`` if scoring was never reached.

    Args:
        client: OpenAI-compatible client used to query the model.
        task_id: Incident type passed to ``get_incident_by_type``.
    """
    env = APITriageEnv(max_steps=MAX_STEPS)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    score = 0.001
    success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        # Reset env and force the specific incident, then re-initialize the
        # episode bookkeeping fields by hand so they match the forced incident.
        env.reset()
        env.incident = get_incident_by_type(task_id)
        env.fix_applied = False
        env.done = False
        env.step_counter = 0
        env.total_reward = 0.0
        observation = env.state()
        last_reward = 0.0
        for step in range(1, MAX_STEPS + 1):
            action = get_model_action(client, step, observation, last_reward, history)
            observation, reward, done, _info = env.step(action)
            rewards.append(reward)
            steps_taken = step
            last_reward = reward
            log_step(step=step, action=action, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action} -> reward {reward:.2f}")
            if done:
                break
        # Compute score from actual rewards, clamped strictly to (0, 1).
        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.001
        score = min(max(score, 0.001), 0.999)
        # NOTE(review): an earlier assignment from info["resolution"] was always
        # overwritten by this line, so success has only ever been score-based;
        # confirm whether the environment's own verdict should count instead.
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as e:
        print(f"[DEBUG] Error in task {task_id}: {e}", flush=True)
    finally:
        # Always close the log, even when the episode aborted part-way.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
# ============================================
# Main
# ============================================
async def main() -> None:
    """Run every task once, sequentially, against a single shared client.

    Prints an ``[ERROR]`` line and returns without running anything when no
    API key is configured (``HF_TOKEN`` unset or empty).

    The coroutine contains no ``await``: each ``run_task`` call is
    synchronous, so tasks execute one after another.
    """
    if not API_KEY:
        print("[ERROR] HF_TOKEN environment variable not set", flush=True)
        return
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    # All 6 task IDs from openenv.yaml
    task_ids = ["auth_error", "missing_fields", "rate_limit", "timeout", "wrong_endpoint", "server_error"]
    for task_id in task_ids:
        run_task(client, task_id)
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion.
    asyncio.run(main())
|