savetrees's picture
Upload folder using huggingface_hub
9b47159 verified
"""
Baseline inference script for the Bug Triage OpenEnv environment.
Provider priority:
1. OpenAI API client (OPENAI_API_KEY) - spec-required primary
2. Google Gemini SDK (GEMINI_API_KEY) - fallback
3. Random actions (no key) - last resort
Usage:
OPENAI_API_KEY="sk-..." python -m bug_triage_env.baseline --all-tasks --json
GEMINI_API_KEY="AI..." python -m bug_triage_env.baseline --all-tasks --episodes 5
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional
import requests
# Root logger configuration: timestamped INFO-level messages.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
# -- Configuration -------------------------------------------------------------
# Base URL of the running Bug Triage environment server.
ENV_URL = os.getenv("BUG_TRIAGE_ENV_URL", "http://localhost:8000")
# Provider credentials: presence of a key selects the provider (OpenAI first).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
# Model identifiers, overridable via environment variables.
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
# Default number of episodes per task (see the --episodes CLI flag).
EPISODES_PER_TASK = 10
SYSTEM_PROMPT = (
"You are a senior software engineer performing bug triage.\n"
"You will receive a bug report and must respond with a JSON object.\n"
"Be concise. Think carefully about severity and impact before deciding.\n"
"\n"
"Available bug types: crash, ui, performance, security, data_loss, compatibility\n"
"Available priorities: low, medium, high, critical\n"
"Available developers: Alice (crash/performance/data_loss), Bob (crash/security),\n"
" Carol (ui/compatibility), David (security/data_loss), Eve (ui/performance/compatibility)\n"
"Available actions: fix_immediately, schedule_sprint, needs_more_info, wontfix, duplicate\n"
"\n"
"IMPORTANT: Respond with ONLY valid JSON, no markdown, no explanation."
)
# Per-task instruction suffixes appended to each rendered bug report; each
# entry pins the exact JSON schema the model must emit for that task.
# task_1: type classification; task_2: priority; task_3: full triage.
TASK_PROMPTS: Dict[str, str] = {
    "task_1": (
        "Classify the bug type only.\n"
        'Respond ONLY with valid JSON: {"task_id": "task_1", "bug_type": "<type>", '
        '"confidence": <0.0-1.0>}'
    ),
    "task_2": (
        "Assign the priority level only.\n"
        'Respond ONLY with valid JSON: {"task_id": "task_2", "priority": "<level>", '
        '"confidence": <0.0-1.0>}'
    ),
    "task_3": (
        "Perform full triage: classify type, assign priority, assign developer, "
        "suggest action.\n"
        "Include a confidence score (0.0=guessing, 1.0=certain).\n"
        'Respond ONLY with valid JSON:\n'
        '{"task_id": "task_3", "bug_type": "<type>", "priority": "<level>", '
        '"assigned_developer": "<name>", "suggested_action": "<action>", '
        '"confidence": <0.0-1.0>, "reasoning": "<brief reasoning>"}'
    ),
}
def build_user_prompt(bug_report: Dict[str, Any], task_id: str) -> str:
    """Render a bug report plus task-specific instructions into one prompt.

    Title and description always appear (falling back to 'N/A'); logs,
    environment, and metadata are included only when present and truthy.
    Raises KeyError if ``task_id`` is not a key of TASK_PROMPTS.
    """
    lines: List[str] = [
        f"Title: {bug_report.get('title', 'N/A')}",
        f"Description: {bug_report.get('description', 'N/A')}",
    ]
    logs = bug_report.get("logs")
    if logs:
        lines.append(f"Logs:\n{logs}")
    environment = bug_report.get("environment")
    if environment:
        lines.append(f"Environment: {environment}")
    metadata = bug_report.get("metadata")
    if metadata:
        lines.append(f"Metadata: {json.dumps(metadata)}")
    # Blank separator line, then the task's response-format instructions.
    lines.extend(["", TASK_PROMPTS[task_id]])
    return "\n".join(lines)
# -- Provider 1: OpenAI (spec-required) ---------------------------------------
# Lazily-created OpenAI client, cached at module level so repeated calls
# reuse a single client instance.
_openai_client = None
def _get_openai_client():
    """Return a cached OpenAI client, or None if unavailable.

    Returns None when OPENAI_API_KEY is unset or the ``openai`` package is
    not installed; otherwise creates the client once and caches it.
    """
    global _openai_client
    if _openai_client is None and OPENAI_API_KEY:
        try:
            # Imported lazily so the script runs without the package installed.
            from openai import OpenAI
            _openai_client = OpenAI(api_key=OPENAI_API_KEY)
        except ImportError:
            logger.error("openai package not installed. Run: pip install openai")
            return None
    return _openai_client
def call_openai(user_prompt: str, max_retries: int = 3) -> Optional[Dict[str, Any]]:
    """Call the OpenAI chat-completions API and parse its JSON reply.

    Rate-limit / availability errors (429, "rate", 503 in the message)
    are retried with escalating delays; any other error fails at once.

    Args:
        user_prompt: Fully rendered bug-report prompt.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The parsed JSON action dict, or None on failure (no client
        available, non-retryable error, or retries exhausted).
    """
    client = _get_openai_client()
    if client is None:
        return None
    backoff_delays = [5, 15, 30]
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.2,
                # Ask the API to guarantee a JSON-object response body.
                response_format={"type": "json_object"},
            )
            content = response.choices[0].message.content
            return json.loads(content)
        except Exception as exc:  # SDK raises many error types; inspect message.
            err_str = str(exc)
            retryable = (
                "429" in err_str or "rate" in err_str.lower() or "503" in err_str
            )
            if not retryable:
                logger.error("OpenAI call failed: %s", exc)
                return None
            # BUGFIX: previously this slept even after the final attempt,
            # wasting up to 30s before giving up. Sleep only when another
            # attempt will actually follow.
            if attempt + 1 < max_retries:
                delay = backoff_delays[min(attempt, len(backoff_delays) - 1)]
                logger.warning(
                    "OpenAI rate limited (attempt %d/%d). Retrying in %ds...",
                    attempt + 1, max_retries, delay,
                )
                time.sleep(delay)
    logger.error("OpenAI call failed after %d retries.", max_retries)
    return None
# -- Provider 2: Google Gemini (fallback) --------------------------------------
# Lazily-created Gemini client, cached at module level (mirrors the OpenAI
# client cache above).
_gemini_client = None
def _get_gemini_client():
    """Return a cached Google Gemini client, or None if unavailable.

    Returns None when GEMINI_API_KEY is unset or the ``google-genai``
    package is not installed; otherwise creates the client once and caches it.
    """
    global _gemini_client
    if _gemini_client is None and GEMINI_API_KEY:
        try:
            # Imported lazily so the script runs without the package installed.
            from google import genai
            _gemini_client = genai.Client(api_key=GEMINI_API_KEY)
        except ImportError:
            logger.error("google-genai not installed. Run: pip install google-genai")
            return None
    return _gemini_client
def call_gemini(user_prompt: str, max_retries: int = 3) -> Optional[Dict[str, Any]]:
    """Call the Google Gemini API and parse its JSON reply.

    Quota / availability errors (429, RESOURCE_EXHAUSTED, 503, UNAVAILABLE
    in the message) are retried with escalating delays; any other error
    fails at once.

    Args:
        user_prompt: Fully rendered bug-report prompt.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The parsed JSON action dict, or None on failure (no client
        available, non-retryable error, or retries exhausted).
    """
    client = _get_gemini_client()
    if client is None:
        return None
    try:
        from google.genai import types
    except ImportError:
        return None
    backoff_delays = [10, 30, 60]
    for attempt in range(max_retries):
        try:
            response = client.models.generate_content(
                model=GEMINI_MODEL,
                config=types.GenerateContentConfig(
                    system_instruction=SYSTEM_PROMPT,
                    temperature=0.2,
                    # Ask the API to emit JSON directly.
                    response_mime_type="application/json",
                ),
                contents=user_prompt,
            )
            return json.loads(response.text)
        except Exception as exc:  # SDK raises many error types; inspect message.
            err_str = str(exc)
            retryable = any(
                keyword in err_str
                for keyword in ("429", "RESOURCE_EXHAUSTED", "503", "UNAVAILABLE")
            )
            if not retryable:
                logger.error("Gemini call failed: %s", exc)
                return None
            # BUGFIX: previously this slept even after the final attempt,
            # wasting up to 60s before giving up. Sleep only when another
            # attempt will actually follow.
            if attempt + 1 < max_retries:
                delay = backoff_delays[min(attempt, len(backoff_delays) - 1)]
                logger.warning(
                    "Gemini rate limited (attempt %d/%d). Retrying in %ds...",
                    attempt + 1, max_retries, delay,
                )
                time.sleep(delay)
    logger.error("Gemini call failed after %d retries.", max_retries)
    return None
# -- Unified LLM dispatcher ---------------------------------------------------
def call_llm(user_prompt: str) -> Optional[Dict[str, Any]]:
    """Dispatch to the first configured provider: OpenAI, then Gemini.

    Returns the parsed action dict from the first provider that succeeds,
    or None when every configured provider fails or no API key is set
    (the caller then falls back to random actions).
    """
    providers = []
    if OPENAI_API_KEY:
        providers.append(call_openai)
    if GEMINI_API_KEY:
        providers.append(call_gemini)
    if not providers:
        logger.warning(
            "No API key set (OPENAI_API_KEY or GEMINI_API_KEY). "
            "Using random actions."
        )
        return None
    for provider in providers:
        outcome = provider(user_prompt)
        if outcome is not None:
            return outcome
    return None
def get_active_model() -> str:
    """Return the model name call_llm will use, or "random" without keys."""
    for api_key, model_name in (
        (OPENAI_API_KEY, OPENAI_MODEL),
        (GEMINI_API_KEY, GEMINI_MODEL),
    ):
        if api_key:
            return model_name
    return "random"
def random_action(task_id: str) -> Dict[str, Any]:
    """Build a random but schema-valid action for the given task.

    Used as the last-resort fallback when no LLM provider is configured or
    every provider call failed. Fields populated per task:
    task_1 -> bug_type; task_2 -> priority; task_3 -> all four fields.
    Unknown task ids yield only the task_id field.
    """
    import random
    # Valid values for each action field, mirroring SYSTEM_PROMPT.
    field_choices: Dict[str, List[str]] = {
        "bug_type": [
            "crash", "ui", "performance", "security", "data_loss", "compatibility",
        ],
        "priority": ["low", "medium", "high", "critical"],
        "assigned_developer": ["Alice", "Bob", "Carol", "David", "Eve"],
        "suggested_action": [
            "fix_immediately", "schedule_sprint", "needs_more_info",
            "wontfix", "duplicate",
        ],
    }
    # Which fields each task's schema requires.
    required_fields = {
        "task_1": ["bug_type"],
        "task_2": ["priority"],
        "task_3": [
            "bug_type", "priority", "assigned_developer", "suggested_action",
        ],
    }
    picked: Dict[str, Any] = {"task_id": task_id}
    for field in required_fields.get(task_id, []):
        picked[field] = random.choice(field_choices[field])
    return picked
# -- Episode runner ------------------------------------------------------------
def run_episode(task_id: str, base_url: str) -> float:
    """Play one episode against the environment server and return its score.

    Flow: POST /reset -> build prompt -> LLM (or random fallback) ->
    POST /step -> read grader_score (0.0 if the server omits it).

    Raises requests.HTTPError when either server call returns an error
    status; the caller treats that as a zero-score episode.
    """
    reset = requests.post(
        f"{base_url}/reset", json={"task_id": task_id}, timeout=30,
    )
    reset.raise_for_status()
    observation = reset.json()
    episode_id = observation["episode_id"]
    prompt = build_user_prompt(observation.get("bug_report", {}), task_id)
    # Fall back to a random action when no provider produced a result.
    action = call_llm(prompt) or random_action(task_id)
    action.setdefault("task_id", task_id)
    step = requests.post(
        f"{base_url}/step",
        json={"episode_id": episode_id, "action": action},
        timeout=30,
    )
    step.raise_for_status()
    score = step.json().get("grader_score", 0.0)
    logger.info("Episode %s | task=%s | score=%.3f", episode_id, task_id, score)
    return score
def run_all_tasks(
    base_url: str = ENV_URL,
    n_episodes: int = EPISODES_PER_TASK,
    tasks: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Run n_episodes for each task and aggregate scores.

    A failed episode is logged and scored 0.0, so every task's mean always
    reflects n_episodes attempts. With ``tasks`` None or empty, all three
    tasks are run.

    Returns:
        {"baseline_model": ..., "results": {task_id: metrics}, "mean_score": ...}
    """
    task_ids = tasks or ["task_1", "task_2", "task_3"]
    results: Dict[str, Any] = {}
    for task_id in task_ids:
        scores: List[float] = []
        for ep_idx in range(n_episodes):
            try:
                scores.append(run_episode(task_id, base_url))
            except Exception as exc:
                logger.error(
                    "Episode %d for %s failed: %s", ep_idx, task_id, exc,
                )
                scores.append(0.0)
        task_mean = sum(scores) / len(scores) if scores else 0.0
        results[task_id] = {
            "mean_score": round(task_mean, 4),
            "min_score": round(min(scores), 4) if scores else 0.0,
            "max_score": round(max(scores), 4) if scores else 0.0,
            "episodes": n_episodes,
        }
        logger.info("%s mean score: %.4f", task_id, task_mean)
    per_task_means = [entry["mean_score"] for entry in results.values()]
    overall = sum(per_task_means) / len(per_task_means) if per_task_means else 0.0
    return {
        "baseline_model": get_active_model(),
        "results": results,
        "mean_score": round(overall, 4),
    }
# -- CLI entry point -----------------------------------------------------------
# Parses flags, runs the requested tasks, and prints either JSON or a
# human-readable summary.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Bug Triage Baseline Inference",
    )
    parser.add_argument(
        "--all-tasks", action="store_true", help="Run all 3 tasks",
    )
    parser.add_argument(
        "--task",
        choices=["task_1", "task_2", "task_3"],
        help="Run a single task",
    )
    # Episodes per task; defaults to the module-level constant (10).
    parser.add_argument(
        "--episodes", type=int, default=EPISODES_PER_TASK,
    )
    parser.add_argument("--env-url", default=ENV_URL)
    parser.add_argument(
        "--json", action="store_true", help="Output JSON to stdout",
    )
    args = parser.parse_args()
    # --task takes precedence over --all-tasks; with neither flag,
    # selected_tasks stays None and run_all_tasks runs all three anyway.
    selected_tasks = None
    if args.task:
        selected_tasks = [args.task]
    elif args.all_tasks:
        selected_tasks = ["task_1", "task_2", "task_3"]
    output = run_all_tasks(
        base_url=args.env_url,
        n_episodes=args.episodes,
        tasks=selected_tasks,
    )
    if args.json:
        # Machine-readable: single JSON object on stdout.
        print(json.dumps(output))
    else:
        # Human-readable summary: per-task mean plus [min - max] range.
        print("\n=== Bug Triage Baseline Results ===")
        for tid, metrics in output["results"].items():
            print(
                f" {tid}: mean={metrics['mean_score']:.4f} "
                f"[{metrics['min_score']:.2f} - {metrics['max_score']:.2f}]"
            )
        print(f"\n Overall mean: {output['mean_score']:.4f}")
        print(f" Model: {output['baseline_model']}")
        if not OPENAI_API_KEY and not GEMINI_API_KEY:
            print(" Warning: No API key set. Used random actions.")
            print(" Set OPENAI_API_KEY or GEMINI_API_KEY for real baseline.")