"""Baseline inference script for the Feature Flag Cleanup environment.

Uses the OpenAI API client to run an LLM agent against all 3 tasks.
Reads API credentials from environment variables.
Produces structured [START], [STEP], [END] logs.
"""

import json
import os
import re
import time

import requests
from openai import OpenAI

# --- Configuration from environment variables ---
# Judges inject API_BASE_URL, API_KEY, MODEL_NAME — use those first
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Key precedence: API_KEY (judge's proxy key), then OPENAI_API_KEY (local dev),
# then HF_TOKEN as a last resort (it was read but never used before).
API_KEY = (
    os.environ.get("API_KEY", "")
    or os.environ.get("OPENAI_API_KEY", "")
    or HF_TOKEN
)

# Environment URL — local by default, override for remote
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")

# Timeouts and limits
LLM_TIMEOUT = 30  # seconds per LLM call
HTTP_TIMEOUT = 15  # seconds per env HTTP call
MAX_RETRIES = 2  # retries per LLM call after the first attempt
RETRY_DELAY = 3  # seconds between retries (flat, not exponential)
# Safety cap on steps per task: if the environment never reports done=True the
# agent would otherwise loop (and spend LLM calls) forever.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "1000"))

# Actions accepted from the model's response. "investigate" is accepted even
# though SYSTEM_PROMPT does not advertise it.
VALID_ACTIONS = ("remove", "keep", "deprecate", "escalate", "investigate")

# Matches an `"action": "<word>"` field (double or single quotes), which still
# works when the JSON object was truncated by max_tokens.
_ACTION_FIELD_RE = re.compile(r"""["']action["']\s*:\s*["']([A-Za-z]+)["']""")
# Matches any action word at a word start ("keep", "keeping", "removed", ...).
_ACTION_WORD_RE = re.compile(r"\b(" + "|".join(VALID_ACTIONS) + r")", re.IGNORECASE)

# Sentinel distinguishing "no JSON could be decoded" from a decoded `null`.
_NO_JSON = object()

# Initialize OpenAI client with timeout — uses judge's proxy via API_BASE_URL + API_KEY
client = OpenAI(
    api_key=API_KEY,
    base_url=API_BASE_URL,
    timeout=LLM_TIMEOUT,
    max_retries=0,  # We handle retries ourselves
)

SYSTEM_PROMPT = (
    "You are a senior engineer cleaning up stale feature flags. "
    "For each flag, pick ONE action: "
    '- "remove": Safe to delete (100% rolled out, no deps, no incidents) '
    '- "keep": Still needed (active experiment, kill switch, partial rollout, active dev) '
    '- "deprecate": Schedule removal (100% but has deps or inactive owner) '
    '- "escalate": Needs human review (complex deps, multi-service, ambiguous) '
    "Rules: NEVER remove kill switches, active incidents, or active experiments. \n"
    'Respond ONLY with JSON: {"action": "", "reasoning": ""}'
)


def _format_observation(observation: dict) -> str:
    """Render an environment observation as the compact user message for the LLM.

    Only essential fields are included, and optional free-text context is
    truncated, to keep token usage low.

    Raises:
        KeyError: if a required observation field is missing.
    """
    deps = observation["dependent_flags"]
    lines = [
        f"Flag: {observation['flag_name']}",
        f"Desc: {observation['description']}",
        # `:g` avoids float noise such as 28.999999999999996% in the prompt.
        f"Rollout: {observation['rollout_percentage'] * 100:g}% | "
        f"Age: {observation['age_days']}d | "
        f"Modified: {observation['last_modified_days']}d ago",
        f"Owner: {observation['owner']} (active={observation['owner_active']})",
        f"Code refs: {observation['num_code_references']} | "
        f"Usage 30d: {observation['usage_last_30d']}",
        f"Services: {', '.join(map(str, observation['services']))}",
        f"Kill switch: {observation['is_kill_switch']} | "
        f"Active incident: {observation['has_active_incident']} | "
        f"In experiment: {observation['in_active_experiment']}",
        f"Dependencies: {', '.join(map(str, deps)) if deps else 'None'}",
    ]

    # Rich context, only when present (compact).
    if observation.get("code_snippet"):
        lines.append(f"Code: {observation['code_snippet'][:200]}")
    if observation.get("pr_context"):
        lines.append(f"PR: {observation['pr_context'][:150]}")
    if observation.get("related_incidents"):
        incidents = "; ".join(map(str, observation["related_incidents"][:2]))
        lines.append(f"Incidents: {incidents}")
    if observation.get("cascade_warning"):
        lines.append(f"CASCADE WARNING: {observation['cascade_warning']}")
    if observation.get("investigation_notes"):
        lines.append(f"Investigation: {observation['investigation_notes'][:200]}")

    return "".join(line + "\n" for line in lines)


def _strip_code_fence(text: str) -> str:
    """Remove a leading Markdown code fence (optionally tagged `json`)."""
    if text.startswith("```"):
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
    return text.strip()


def _try_json(text: str):
    """Decode *text* as JSON, falling back to its outermost {...} span.

    Returns the decoded value, or the `_NO_JSON` sentinel if nothing decodes.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # The model may have wrapped the object in prose; try the outermost braces.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        try:
            return json.loads(text[start : end + 1])
        except json.JSONDecodeError:
            pass
    return _NO_JSON


def _action_from_text(text: str) -> dict:
    """Recover an action from non-JSON (or truncated-JSON) model output.

    Prefers an explicit `"action": "..."` field. Otherwise accepts a bare action
    word only when exactly one distinct action is mentioned; anything ambiguous
    (e.g. "not safe to remove, keep it") escalates rather than guessing, since
    picking the wrong word could mean removing a flag that must be kept.
    """
    match = _ACTION_FIELD_RE.search(text)
    if match and match.group(1).lower() in VALID_ACTIONS:
        return {"action": match.group(1).lower(), "reasoning": "Parsed from text"}

    mentioned = {word.lower() for word in _ACTION_WORD_RE.findall(text)}
    if len(mentioned) == 1:
        return {"action": mentioned.pop(), "reasoning": "Parsed from text"}
    return {"action": "escalate", "reasoning": "Unparseable response"}


def _parse_llm_response(content: str) -> dict:
    """Convert raw model output into an action dict; never raises.

    A decoded JSON object with a valid (case-insensitive) "action" is returned
    with the action normalised to lowercase and any other keys preserved.
    """
    result = _try_json(_strip_code_fence(content.strip()))
    if result is _NO_JSON:
        return _action_from_text(content)
    if isinstance(result, dict):
        action = str(result.get("action", "")).strip().lower()
        if action in VALID_ACTIONS:
            result["action"] = action
            return result
    return {"action": "escalate", "reasoning": "Invalid action in response"}


def call_llm(observation: dict) -> dict:
    """Call the LLM to decide on a feature flag action.

    API failures are retried up to MAX_RETRIES times with a flat RETRY_DELAY;
    if every attempt fails, or the reply cannot be interpreted, the flag is
    escalated instead of raising.

    Returns:
        A dict with at least "action" (one of VALID_ACTIONS) and "reasoning".
    """
    flag_info = _format_observation(observation)

    for attempt in range(MAX_RETRIES + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": flag_info},
                ],
                temperature=0.0,
                max_tokens=100,
            )
            # message.content can be None (e.g. refusals); treat it as empty text.
            content = response.choices[0].message.content or ""
        except Exception as e:  # network, timeout, rate-limit, malformed response
            if attempt < MAX_RETRIES:
                print(
                    f" [RETRY] attempt {attempt+1}/{MAX_RETRIES}, waiting {RETRY_DELAY}s: {str(e)[:80]}",
                    flush=True,
                )
                time.sleep(RETRY_DELAY)
                continue
            return {"action": "escalate", "reasoning": f"API error: {str(e)[:80]}"}
        # Parsing is deterministic, so a bad reply is not retried.
        return _parse_llm_response(content)

    return {"action": "escalate", "reasoning": "All retries exhausted"}


def run_task(task_id: str) -> float:
    """Run the agent on a single task and return the score.

    Resets the environment (POST /reset), steps it with call_llm's decisions
    (POST /step) until it reports done or MAX_STEPS is reached, then grades the
    episode (POST /grade) and returns its "score".

    Raises:
        requests.RequestException: on transport errors or non-2xx env replies.
        KeyError: if an env reply lacks an expected field.
    """
    # Printed before any HTTP call so every [END] (including the error one
    # printed by main()) has a matching [START].
    print(f'[START] task_id={task_id}', flush=True)

    reset_resp = requests.post(
        f"{ENV_URL}/reset", json={"task_id": task_id}, timeout=HTTP_TIMEOUT
    )
    reset_resp.raise_for_status()
    reset_data = reset_resp.json()
    observation = reset_data["observation"]
    done = reset_data.get("done", False)

    step_num = 0
    while not done:
        if step_num >= MAX_STEPS:
            print(
                f"[WARN] task_id={task_id} reached MAX_STEPS={MAX_STEPS} "
                f"without done=True; grading anyway",
                flush=True,
            )
            break
        step_num += 1
        action = call_llm(observation)

        step_resp = requests.post(
            f"{ENV_URL}/step", json={"action": action}, timeout=HTTP_TIMEOUT
        )
        step_resp.raise_for_status()
        step_data = step_resp.json()

        observation = step_data["observation"]
        reward = step_data["reward"]
        done = step_data["done"]
        # `or {}` also covers an explicit JSON null for "info".
        info = step_data.get("info") or {}

        print(
            f'[STEP] task_id={task_id} step={step_num} '
            f'flag={info.get("flag_name", "unknown")} '
            f'action={info.get("agent_action", action.get("action", "unknown"))} '
            f'correct={info.get("correct_action", "unknown")} '
            f'reward={reward}',
            flush=True,
        )

    grade_resp = requests.post(f"{ENV_URL}/grade", timeout=HTTP_TIMEOUT)
    grade_resp.raise_for_status()
    grade_data = grade_resp.json()
    score = grade_data["score"]

    print(f'[END] task_id={task_id} score={score}', flush=True)
    return score


def main():
    """Run baseline inference on all 3 tasks and print a results summary.

    A task that raises is logged with an error [END] line and scored 0.0 so
    the remaining tasks still run.
    """
    print("=" * 60)
    print("Feature Flag Cleanup Agent — Baseline Inference")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"API Base: {API_BASE_URL}")
    print(f"Environment: {ENV_URL}")
    print("=" * 60)

    total_start = time.time()
    tasks = ["easy", "medium", "hard"]
    scores = {}

    for task_id in tasks:
        print(f"\n--- Running task: {task_id} ---")
        task_start = time.time()
        try:
            scores[task_id] = run_task(task_id)
        except Exception as e:
            # Collapse whitespace so the structured [END] record stays on one line.
            error = " ".join(str(e).split())
            print(f"[END] task_id={task_id} score=0.0 error={error}", flush=True)
            scores[task_id] = 0.0
        print(f" Task {task_id} took {time.time()-task_start:.1f}s", flush=True)

    total_time = time.time() - total_start

    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)
    for task_id, score in scores.items():
        print(f" {task_id:10s}: {score:.4f}")
    avg_score = sum(scores.values()) / len(scores) if scores else 0.0
    print(f" {'average':10s}: {avg_score:.4f}")
    print(f" {'runtime':10s}: {total_time:.1f}s")
    print("=" * 60)


if __name__ == "__main__":
    main()