# feature-flag-cleanup / inference.py
# Falgunisharma's picture
# Fix: read API_KEY env var (judge's LiteLLM proxy key) with OPENAI_API_KEY as fallback
# 81582eb
"""Baseline inference script for the Feature Flag Cleanup environment.
Uses the OpenAI API client to run an LLM agent against all 3 tasks.
Reads API credentials from environment variables.
Produces structured [START], [STEP], [END] logs.
"""
import json
import os
import re
import time

import requests
from openai import OpenAI
# --- Configuration: every setting can be overridden through the environment ---
# The judging harness injects API_BASE_URL, API_KEY and MODEL_NAME; the
# defaults below point at the public OpenAI endpoint for local runs.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# NOTE(review): HF_TOKEN is read here but not passed to any client in this
# script; confirm whether it is still needed.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Credential precedence: API_KEY (harness proxy key) wins, OPENAI_API_KEY is
# the local-development fallback. An empty value counts as "not set".
API_KEY = os.getenv("API_KEY", "") or os.getenv("OPENAI_API_KEY", "")
# Base URL of the environment server (local by default).
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

# --- Network budget ---
LLM_TIMEOUT = 30   # seconds allowed for a single LLM request
HTTP_TIMEOUT = 15  # seconds allowed for a single environment HTTP request
MAX_RETRIES = 2    # additional LLM attempts after a failure
RETRY_DELAY = 3    # fixed pause (seconds) between LLM attempts; no backoff

# Shared OpenAI-compatible client aimed at API_BASE_URL. The SDK's own retry
# logic is switched off because retries are handled by hand in this script.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=API_KEY,
    max_retries=0,
    timeout=LLM_TIMEOUT,
)

# Instructions sent as the system message with every flag decision.
SYSTEM_PROMPT = """You are a senior engineer cleaning up stale feature flags. For each flag, pick ONE action:
- "remove": Safe to delete (100% rolled out, no deps, no incidents)
- "keep": Still needed (active experiment, kill switch, partial rollout, active dev)
- "deprecate": Schedule removal (100% but has deps or inactive owner)
- "escalate": Needs human review (complex deps, multi-service, ambiguous)
Rules: NEVER remove kill switches, active incidents, or active experiments.
Respond ONLY with JSON: {"action": "<action>", "reasoning": "<brief>"}"""
# Actions accepted from a well-formed JSON reply.
_VALID_ACTIONS = ("remove", "keep", "deprecate", "escalate", "investigate")
# Bare words recognised when the reply is not JSON at all.
_TEXT_ACTIONS = ("remove", "keep", "deprecate", "escalate")
# Locates an "action": "<word>" field even inside truncated or malformed JSON;
# max_tokens=100 can cut a reply off in the middle of its "reasoning".
_ACTION_FIELD_RE = re.compile(r"""["']action["']\s*:\s*["']([A-Za-z]+)["']""")
# Whole-word matchers, so e.g. "removed" does not count as "remove".
_TEXT_ACTION_RES = {act: re.compile(rf"\b{act}\b") for act in _TEXT_ACTIONS}


def _format_observation(observation: dict) -> str:
    """Build the compact user message describing one flag.

    Required fields are read with ``[]`` (a missing one raises KeyError); the
    rich-context fields are optional and truncated to keep the prompt short.
    """
    flag_info = (
        f"Flag: {observation['flag_name']}\n"
        f"Desc: {observation['description']}\n"
        f"Rollout: {observation['rollout_percentage']*100}% | Age: {observation['age_days']}d | Modified: {observation['last_modified_days']}d ago\n"
        f"Owner: {observation['owner']} (active={observation['owner_active']})\n"
        f"Code refs: {observation['num_code_references']} | Usage 30d: {observation['usage_last_30d']}\n"
        f"Services: {', '.join(observation['services'])}\n"
        f"Kill switch: {observation['is_kill_switch']} | Active incident: {observation['has_active_incident']} | In experiment: {observation['in_active_experiment']}\n"
        f"Dependencies: {', '.join(observation['dependent_flags']) if observation['dependent_flags'] else 'None'}\n"
    )
    if observation.get("code_snippet"):
        flag_info += f"Code: {observation['code_snippet'][:200]}\n"
    if observation.get("pr_context"):
        flag_info += f"PR: {observation['pr_context'][:150]}\n"
    if observation.get("related_incidents"):
        flag_info += f"Incidents: {'; '.join(observation['related_incidents'][:2])}\n"
    if observation.get("cascade_warning"):
        flag_info += f"CASCADE WARNING: {observation['cascade_warning']}\n"
    if observation.get("investigation_notes"):
        flag_info += f"Investigation: {observation['investigation_notes'][:200]}\n"
    return flag_info


def _decode_json(text: str):
    """Decode *text* as JSON, retrying on its outermost ``{...}`` span.

    Raises:
        json.JSONDecodeError: if neither the whole text nor the brace span
            is valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
    # The model wrapped the JSON object in prose; try just the object.
    return json.loads(text[start:end + 1])


def _action_from_text(text: str) -> dict:
    """Salvage an action from a reply that could not be decoded as JSON.

    An explicit ``"action": "<x>"`` field wins. Otherwise a bare action word is
    used only when it is the single one mentioned, so a reply like
    "keep - never remove a kill switch" can never be read as "remove";
    ambiguous or empty text is escalated for human review.
    """
    match = _ACTION_FIELD_RE.search(text)
    if match and match.group(1).lower() in _VALID_ACTIONS:
        return {"action": match.group(1).lower(), "reasoning": "Parsed from text"}
    lowered = text.lower()
    mentioned = [act for act, pattern in _TEXT_ACTION_RES.items() if pattern.search(lowered)]
    if len(mentioned) == 1:
        return {"action": mentioned[0], "reasoning": "Parsed from text"}
    if mentioned:
        return {"action": "escalate", "reasoning": "Ambiguous text response"}
    return {"action": "escalate", "reasoning": "Unparseable response"}


def _parse_reply(content: str) -> dict:
    """Convert the raw model reply into an action dict.

    Malformed replies fall back to text salvage or to ``"escalate"`` instead
    of raising.
    """
    text = content.strip()
    # Drop a Markdown code fence (``` or ```json) around the payload.
    if text.startswith("```"):
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
        text = text.strip()
    try:
        result = _decode_json(text)
    except json.JSONDecodeError:
        return _action_from_text(text)
    # Valid JSON but possibly not an object (e.g. a bare number or list).
    action = result.get("action") if isinstance(result, dict) else None
    if isinstance(action, str) and action.strip().lower() in _VALID_ACTIONS:
        result["action"] = action.strip().lower()
        return result
    return {"action": "escalate", "reasoning": "Invalid action in response"}


def call_llm(observation: dict) -> dict:
    """Ask the LLM which action to take for one feature flag.

    Args:
        observation: Observation dict from the environment's /reset or /step
            response.

    Returns:
        A dict with an ``"action"`` key (one of ``_VALID_ACTIONS``) and
        normally a ``"reasoning"`` key. Unparseable replies, and API failures
        that persist after MAX_RETRIES retries, yield ``"escalate"`` rather
        than an exception.

    Raises:
        KeyError: if a required observation field is missing (raised before
            any API call is made).
    """
    flag_info = _format_observation(observation)
    for attempt in range(MAX_RETRIES + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": flag_info},
                ],
                temperature=0.0,
                max_tokens=100,
            )
            # The SDK types message.content as Optional[str]; treat None as "".
            content = response.choices[0].message.content or ""
        except Exception as e:  # transport, HTTP and SDK errors are all retried
            if attempt < MAX_RETRIES:
                print(f" [RETRY] attempt {attempt+1}/{MAX_RETRIES}, waiting {RETRY_DELAY}s: {str(e)[:80]}", flush=True)
                time.sleep(RETRY_DELAY)
                continue
            return {"action": "escalate", "reasoning": f"API error: {str(e)[:80]}"}
        return _parse_reply(content)
    # Only reachable if MAX_RETRIES is negative (the loop body never runs).
    return {"action": "escalate", "reasoning": "All retries exhausted"}
def run_task(task_id: str) -> float:
    """Play one episode of *task_id* against the environment server and return its grade.

    Resets the environment, asks call_llm for an action per observation until
    the server reports ``done``, then fetches the final score from /grade.
    Emits one [START] line, one [STEP] line per step and one [END] line.
    Non-2xx responses from any endpoint raise via ``raise_for_status``.
    """
    def post_json(endpoint, payload=None):
        # POST to the environment, fail on HTTP errors, return decoded JSON.
        kwargs = {"timeout": HTTP_TIMEOUT}
        if payload is not None:
            kwargs["json"] = payload
        reply = requests.post(f"{ENV_URL}/{endpoint}", **kwargs)
        reply.raise_for_status()
        return reply.json()

    first = post_json("reset", {"task_id": task_id})
    observation = first["observation"]
    done = first.get("done", False)
    print(f'[START] task_id={task_id}', flush=True)

    step = 0
    while not done:
        step += 1
        action = call_llm(observation)
        outcome = post_json("step", {"action": action})
        observation = outcome["observation"]
        reward = outcome["reward"]
        done = outcome["done"]
        info = outcome.get("info", {})
        flag = info.get("flag_name", "unknown")
        chosen = info.get("agent_action", action.get("action", "unknown"))
        expected = info.get("correct_action", "unknown")
        print(
            f"[STEP] task_id={task_id} step={step} flag={flag} "
            f"action={chosen} correct={expected} reward={reward}",
            flush=True,
        )

    score = post_json("grade")["score"]
    print(f"[END] task_id={task_id} score={score}", flush=True)
    return score
def main():
    """Run the agent on the easy, medium and hard tasks and print a summary.

    A task that raises is logged with an [END] line carrying score=0.0 and the
    error text, and counts as 0.0 in the summary. The summary lists each
    task's score, their mean and the total wall-clock runtime.
    """
    rule = "=" * 60
    for line in (
        rule,
        "Feature Flag Cleanup Agent — Baseline Inference",
        rule,
        f"Model: {MODEL_NAME}",
        f"API Base: {API_BASE_URL}",
        f"Environment: {ENV_URL}",
        rule,
    ):
        print(line)

    started = time.time()
    scores = {}
    for task_id in ("easy", "medium", "hard"):
        print(f"\n--- Running task: {task_id} ---")
        task_started = time.time()
        try:
            scores[task_id] = run_task(task_id)
        except Exception as e:
            print(f"[END] task_id={task_id} score=0.0 error={str(e)}", flush=True)
            scores[task_id] = 0.0
        print(f" Task {task_id} took {time.time()-task_started:.1f}s", flush=True)
    elapsed = time.time() - started

    print("\n" + rule)
    print("RESULTS SUMMARY")
    print(rule)
    for task_id, score in scores.items():
        print(f" {task_id:10s}: {score:.4f}")
    mean = sum(scores.values()) / len(scores) if scores else 0.0
    print(f" {'average':10s}: {mean:.4f}")
    print(f" {'runtime':10s}: {elapsed:.1f}s")
    print(rule)


if __name__ == "__main__":
    main()