# API_DEBUG_SOLVER / inference.py
# Author: Siteshcodes
# Commit 9fecec8 — "Fix: clamp all scores/rewards strictly to (0, 1) exclusive range"
import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from environment.api_triage_env import APITriageEnv
from environment.action_space import get_all_actions
from environment.incident_generator import get_incident_by_type
# ============================================
# Environment Variables
# ============================================
# Endpoint, model and credentials — each can be overridden from the environment.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = HF_TOKEN  # the Hugging Face token is passed as the OpenAI-client API key

# Run identifiers. BENCHMARK is reported as `env=` in the [START] log line.
TASK_NAME = os.getenv("TASK_NAME", "api_triage")
BENCHMARK = os.getenv("BENCHMARK", "api_triage_agent")

# Episode length and sampling settings for each model call.
MAX_STEPS = 10
TEMPERATURE = 0.7
MAX_TOKENS = 50

# Scoring: the normalised score sum(rewards) / MAX_TOTAL_REWARD must reach
# this threshold for an episode to count as a success.
SUCCESS_SCORE_THRESHOLD = 0.5
# Best case: inspect_logs (0.5) + correct fix (5.0) + resolve (15.0).
MAX_TOTAL_REWARD = 0.5 + 5.0 + 15.0
# ============================================
# System Prompt
# ============================================
# Action names the model may choose from, taken from environment.action_space.
AVAILABLE_ACTIONS = get_all_actions()

# System message: describes the inspect -> fix -> resolve workflow and asks
# for a bare action name so the reply can be matched against AVAILABLE_ACTIONS.
SYSTEM_PROMPT = "\n".join(
    [
        "You are an API debugging agent. Your job is to diagnose and fix API failures.",
        f"Available actions: {AVAILABLE_ACTIONS}",
        "Rules:",
        '- First use "inspect_logs" to understand the problem',
        "- Then take the correct fix action based on the error",
        '- Finally use "resolve" to end the episode',
        "Reply with ONLY the action name. No explanations. No quotes.",
    ]
)
# ============================================
# Logging Functions (Required Format)
# ============================================
def log_start(task: str, env: str, model: str) -> None:
    """Print the [START] line that opens an episode's log."""
    fields = f"task={task} env={env} model={model}"
    print(f"[START] {fields}", flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] line.

    The reward is shown with two decimals, ``done`` is lower-cased, and a
    falsy ``error`` (``None`` or ``""``) is printed as ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the [END] summary line.

    The score is shown with three decimals and each reward with two decimals.
    Rewards are comma-separated; an empty list prints an empty ``rewards=``.
    """
    formatted_rewards = [format(r, ".2f") for r in rewards]
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted_rewards)}"
    )
    print(summary, flush=True)
# ============================================
# Prompt Builder
# ============================================
def build_user_prompt(
    step: int,
    observation,
    last_reward: float,
    history: List[str],
    *,
    actions=None,
) -> str:
    """Render the per-step user message sent to the model.

    Args:
        step: step number shown to the model.
        observation: environment observation; must expose ``incident_summary``,
            ``response_code``, ``logs`` and ``fix_applied``.
        last_reward: reward from the previous step, shown with two decimals.
        history: previous-action lines; only the last four are included, and
            ``"None"`` is shown when the list is empty.
        actions: action names to offer the model. Defaults to
            ``AVAILABLE_ACTIONS``, which is looked up at call time.

    Returns:
        The prompt, one field per line, with no leading indentation.
    """
    if actions is None:
        actions = AVAILABLE_ACTIONS
    history_block = "\n".join(history[-4:]) if history else "None"
    # Build the prompt line by line instead of calling textwrap.dedent() on an
    # indented f-string. Interpolated multi-line values (history, logs) have
    # unindented continuation lines, so dedent would find no common prefix and
    # the whole template would be sent to the model still indented.
    lines = [
        f"Step: {step}",
        f"Incident: {observation.incident_summary}",
        f"Response Code: {observation.response_code}",
        f"Logs: {observation.logs}",
        f"Fix Applied: {observation.fix_applied}",
        f"Last Reward: {last_reward:.2f}",
        "Previous Actions:",
        history_block,
        f"Choose an action from: {actions}",
        "Reply with ONLY the action name.",
    ]
    return "\n".join(lines)
# ============================================
# LLM Caller
# ============================================
# Characters a model may wrap around an action name despite being told not to:
# quotes, backticks, trailing punctuation and whitespace.
_ACTION_NOISE = "`'\".,;: \t"


def _normalize_action(raw: str) -> str:
    """Reduce a raw model reply to a lowercase candidate action name.

    Returns the first line that is still non-empty after the characters in
    ``_ACTION_NOISE`` are stripped from both ends, or ``""`` if there is none.
    """
    for line in raw.splitlines():
        candidate = line.strip(_ACTION_NOISE).lower()
        if candidate:
            return candidate
    return ""


def get_model_action(client: OpenAI, step: int, observation, last_reward: float, history: List[str]) -> str:
    """Ask the model for the next action and return a name from AVAILABLE_ACTIONS.

    Returns ``"inspect_logs"`` in either of these cases, so the episode can
    always continue:
        * the request fails;
        * the normalised reply is not in ``AVAILABLE_ACTIONS``.
    """
    user_prompt = build_user_prompt(step, observation, last_reward, history)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_reply = completion.choices[0].message.content or ""
    except Exception as exc:
        # Best-effort: a failed or empty model call must not abort the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "inspect_logs"
    # Tolerate quoted, back-ticked or multi-line replies instead of rejecting
    # an otherwise valid action name.
    action = _normalize_action(raw_reply)
    if action not in AVAILABLE_ACTIONS:
        print(f"[DEBUG] Invalid action '{action}', defaulting to inspect_logs", flush=True)
        return "inspect_logs"
    return action
# ============================================
# Run a single task episode
# ============================================
def run_task(client: OpenAI, task_id: str) -> None:
    """Run one task: [START] -> steps -> [END].

    Episode flow:
        * The environment is reset and then forced onto the incident returned
          by ``get_incident_by_type(task_id)``.
        * The model picks one action per step until the environment reports
          ``done`` or ``MAX_STEPS`` steps have run.

    Scoring:
        * The score is ``sum(rewards) / MAX_TOTAL_REWARD``, clamped to
          [0.001, 0.999] so it lies strictly inside (0, 1).
        * Success means the score reached ``SUCCESS_SCORE_THRESHOLD``. The
          environment's ``info`` dict is not used to decide success.

    Any exception is reported on stdout. The [END] line is printed regardless.
    """
    env = APITriageEnv(max_steps=MAX_STEPS)
    history: List[str] = []
    rewards: List[float] = []
    steps_taken = 0
    # Floor values reported if the episode errors out before it is scored.
    score = 0.001
    success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        env.reset()
        # Replace whatever incident reset() chose with the requested one, and
        # clear the per-episode state so the episode starts fresh.
        # NOTE(review): relies on APITriageEnv's internal attribute names — confirm.
        env.incident = get_incident_by_type(task_id)
        env.fix_applied = False
        env.done = False
        env.step_counter = 0
        env.total_reward = 0.0
        observation = env.state()
        last_reward = 0.0
        for step in range(1, MAX_STEPS + 1):
            action = get_model_action(client, step, observation, last_reward, history)
            observation, reward, done, _info = env.step(action)
            rewards.append(reward)
            steps_taken = step
            last_reward = reward
            log_step(step=step, action=action, reward=reward, done=done, error=None)
            history.append(f"Step {step}: {action} -> reward {reward:.2f}")
            if done:
                break
        # Normalise the episode return against the best achievable total and
        # keep it strictly inside (0, 1).
        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.001
        score = min(max(score, 0.001), 0.999)
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        # Per-task boundary: report the error so main() can go on to the next task.
        print(f"[DEBUG] Error in task {task_id}: {exc}", flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
# ============================================
# Main
# ============================================
async def main() -> None:
    """Run every benchmark task in sequence; stop early if no API token is set."""
    if API_KEY:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        # The six task IDs declared in openenv.yaml, run one after another.
        for task_id in (
            "auth_error",
            "missing_fields",
            "rate_limit",
            "timeout",
            "wrong_endpoint",
            "server_error",
        ):
            run_task(client, task_id)
    else:
        print("[ERROR] HF_TOKEN environment variable not set", flush=True)


if __name__ == "__main__":
    asyncio.run(main())