# Spaces: Sleeping / Sleeping
# (page-header residue captured together with the source; not part of the script)
"""Baseline inference script for the Feature Flag Cleanup environment.

Uses the OpenAI API client to run an LLM agent against all 3 tasks.
Reads API credentials from environment variables.
Produces structured [START], [STEP], [END] logs.
"""
import json
import os
import re
import time

import requests
from openai import OpenAI
# --- Runtime configuration (every value can be overridden via the environment) ---
# API_BASE_URL, API_KEY and MODEL_NAME are the variables the judges inject.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Key precedence: API_KEY (judge's proxy key) first, then OPENAI_API_KEY
# (local development). An empty API_KEY is treated the same as an unset one.
API_KEY = os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY", "")
# Base URL of the environment server: local by default, override for remote.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")

# --- Timeouts and retry policy ---
LLM_TIMEOUT = 30  # seconds allowed per LLM call
HTTP_TIMEOUT = 15  # seconds allowed per environment HTTP call
MAX_RETRIES = 2  # retries per LLM call (down from 5)
RETRY_DELAY = 3  # seconds between retries (flat, not exponential)

# Shared OpenAI client; it reaches the judge's proxy through API_BASE_URL and
# API_KEY. The SDK's built-in retries are switched off because this script
# performs its own retry loop around each call.
client = OpenAI(
    api_key=API_KEY,
    base_url=API_BASE_URL,
    timeout=LLM_TIMEOUT,
    max_retries=0,
)
# System message sent with every LLM request: it lists the allowed actions,
# the hard safety rule, and the JSON reply format the agent must use.
SYSTEM_PROMPT = "\n".join(
    (
        "You are a senior engineer cleaning up stale feature flags. For each flag, pick ONE action:",
        '- "remove": Safe to delete (100% rolled out, no deps, no incidents)',
        '- "keep": Still needed (active experiment, kill switch, partial rollout, active dev)',
        '- "deprecate": Schedule removal (100% but has deps or inactive owner)',
        '- "escalate": Needs human review (complex deps, multi-service, ambiguous)',
        "Rules: NEVER remove kill switches, active incidents, or active experiments.",
        'Respond ONLY with JSON: {"action": "<action>", "reasoning": "<brief>"}',
    )
)
# Actions accepted from the model. Kept as a tuple (not a set) so the
# membership test also works if the model returns an unhashable JSON value.
_VALID_ACTIONS = ("remove", "keep", "deprecate", "escalate", "investigate")

# Explicit `"action": "<word>"` field; also matches inside truncated or
# single-quoted pseudo-JSON. Applied to lower-cased text.
_ACTION_FIELD_RE = re.compile(r"""["']action["']\s*:\s*["']([a-z_]+)""")

# Leftmost mention of any valid action keyword in free text.
_ACTION_WORD_RE = re.compile("|".join(_VALID_ACTIONS))


def _format_flag_prompt(observation: dict) -> str:
    """Render an observation as the compact user message sent to the LLM."""
    # round() hides float noise such as 0.07 * 100 == 7.000000000000001.
    rollout_pct = round(observation["rollout_percentage"] * 100, 2)
    dependencies = observation["dependent_flags"]
    lines = [
        f"Flag: {observation['flag_name']}",
        f"Desc: {observation['description']}",
        f"Rollout: {rollout_pct}% | Age: {observation['age_days']}d"
        f" | Modified: {observation['last_modified_days']}d ago",
        f"Owner: {observation['owner']} (active={observation['owner_active']})",
        f"Code refs: {observation['num_code_references']}"
        f" | Usage 30d: {observation['usage_last_30d']}",
        f"Services: {', '.join(observation['services'])}",
        f"Kill switch: {observation['is_kill_switch']}"
        f" | Active incident: {observation['has_active_incident']}"
        f" | In experiment: {observation['in_active_experiment']}",
        f"Dependencies: {', '.join(dependencies) if dependencies else 'None'}",
    ]

    # Optional rich context, truncated to keep the prompt small.
    if observation.get("code_snippet"):
        lines.append(f"Code: {observation['code_snippet'][:200]}")
    if observation.get("pr_context"):
        lines.append(f"PR: {observation['pr_context'][:150]}")
    if observation.get("related_incidents"):
        lines.append(f"Incidents: {'; '.join(observation['related_incidents'][:2])}")
    if observation.get("cascade_warning"):
        lines.append(f"CASCADE WARNING: {observation['cascade_warning']}")
    if observation.get("investigation_notes"):
        lines.append(f"Investigation: {observation['investigation_notes'][:200]}")

    return "\n".join(lines) + "\n"


def _action_from_text(text: str) -> dict:
    """Recover an action from a reply that is not valid JSON."""
    lowered = text.lower()

    # An explicit action field is authoritative. This matters for replies cut
    # off by max_tokens, e.g. '{"action": "keep", "reasoning": "do not remove'
    # must yield "keep", not the "remove" mentioned in the reasoning.
    field = _ACTION_FIELD_RE.search(lowered)
    if field is not None:
        if field.group(1) in _VALID_ACTIONS:
            return {"action": field.group(1), "reasoning": "Parsed from text"}
        return {"action": "escalate", "reasoning": "Invalid action in response"}

    # Heuristic for plain prose: take the action keyword mentioned first
    # rather than favouring "remove" through a fixed scan order.
    mention = _ACTION_WORD_RE.search(lowered)
    if mention is not None:
        return {"action": mention.group(0), "reasoning": "Parsed from text"}
    return {"action": "escalate", "reasoning": "Unparseable response"}


def _parse_llm_reply(content: str) -> dict:
    """Turn the raw model reply into an action dict; never raises on bad text."""
    text = content.strip()
    # Strip a Markdown code fence (``` or ```json) wrapped around the JSON.
    if text.startswith("```"):
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
        text = text.strip()

    try:
        result = json.loads(text)
    except json.JSONDecodeError:
        return _action_from_text(text)

    # Valid JSON that is not an object (e.g. a bare string) is a bad reply,
    # not an API failure, so it must not trigger the retry loop.
    if isinstance(result, dict) and result.get("action") in _VALID_ACTIONS:
        return result
    return {"action": "escalate", "reasoning": "Invalid action in response"}


def call_llm(observation: dict) -> dict:
    """Call the LLM to decide on a feature flag action.

    Args:
        observation: Flag observation from the environment. The core fields
            (``flag_name``, ``description``, ``rollout_percentage``, ...) are
            required; ``code_snippet``, ``pr_context``, ``related_incidents``,
            ``cascade_warning`` and ``investigation_notes`` are optional.

    Returns:
        The parsed JSON object from the model when it carries a valid
        ``action``, otherwise a dict with ``action`` and ``reasoning`` keys.
        Any failure (API error after all retries, unparseable or invalid
        reply) yields ``{"action": "escalate", ...}``.

    Raises:
        KeyError: If a required observation field is missing.
    """
    flag_info = _format_flag_prompt(observation)

    last_error = None
    for attempt in range(MAX_RETRIES + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": flag_info},
                ],
                temperature=0.0,
                max_tokens=100,
            )
            # The SDK types message.content as optional; treat a missing
            # body as an empty reply instead of an API failure.
            content = response.choices[0].message.content or ""
        except Exception as exc:  # boundary around the remote call
            last_error = exc
            if attempt < MAX_RETRIES:
                print(
                    f" [RETRY] attempt {attempt+1}/{MAX_RETRIES}, "
                    f"waiting {RETRY_DELAY}s: {str(exc)[:80]}",
                    flush=True,
                )
                time.sleep(RETRY_DELAY)
            continue
        # Parsing never raises, so a bad reply is answered once, not retried.
        return _parse_llm_reply(content)

    if last_error is None:
        # Only reachable when MAX_RETRIES is negative and no call was made.
        return {"action": "escalate", "reasoning": "All retries exhausted"}
    return {"action": "escalate", "reasoning": f"API error: {str(last_error)[:80]}"}
def run_task(task_id: str) -> float:
    """Play one full episode of ``task_id`` and return its graded score.

    Resets the environment, then repeatedly asks ``call_llm`` for an action
    and posts it to ``/step`` until the server reports ``done``. The final
    score is read from ``/grade``. Emits one ``[START]`` line, one ``[STEP]``
    line per step, and one ``[END]`` line.

    Args:
        task_id: Identifier of the task, sent to ``/reset``.

    Returns:
        The ``score`` field of the ``/grade`` response.

    Raises:
        requests.RequestException: On connection failure, timeout, or an
            error status code (via ``raise_for_status``).
        KeyError: If an expected field is missing from a server response.
    """

    def post_json(endpoint, payload=None):
        # One environment call: POST, fail on HTTP errors, decode the JSON body.
        reply = requests.post(f"{ENV_URL}{endpoint}", json=payload, timeout=HTTP_TIMEOUT)
        reply.raise_for_status()
        return reply.json()

    initial = post_json("/reset", {"task_id": task_id})
    observation = initial["observation"]
    done = initial.get("done", False)
    steps_taken = 0
    print(f'[START] task_id={task_id}', flush=True)

    while not done:
        steps_taken += 1
        action = call_llm(observation)
        outcome = post_json("/step", {"action": action})

        observation = outcome["observation"]
        reward = outcome["reward"]
        done = outcome["done"]
        info = outcome.get("info", {})

        # Prefer what the server reports; fall back to the locally chosen action.
        reported_action = info.get("agent_action", action.get("action", "unknown"))
        log_fields = (
            f'[STEP] task_id={task_id}',
            f'step={steps_taken}',
            f'flag={info.get("flag_name", "unknown")}',
            f'action={reported_action}',
            f'correct={info.get("correct_action", "unknown")}',
            f'reward={reward}',
        )
        print(" ".join(log_fields), flush=True)

    score = post_json("/grade")["score"]
    print(f'[END] task_id={task_id} score={score}', flush=True)
    return score
def main():
    """Run the baseline agent on the easy, medium and hard tasks.

    Prints a configuration banner, runs each task in turn, and finishes with
    a summary of per-task scores, their average, and the total runtime. A task
    that raises is logged as an ``[END]`` line with ``score=0.0`` and counted
    as 0.0 so the remaining tasks still run.
    """
    rule = "=" * 60
    print(rule)
    print("Feature Flag Cleanup Agent — Baseline Inference")
    print(rule)
    print(f"Model: {MODEL_NAME}")
    print(f"API Base: {API_BASE_URL}")
    print(f"Environment: {ENV_URL}")
    print(rule)

    run_started = time.time()
    scores = {}
    for task_id in ("easy", "medium", "hard"):
        print(f"\n--- Running task: {task_id} ---")
        task_started = time.time()
        try:
            result = run_task(task_id)
        except Exception as e:  # top-level boundary: one failed task must not stop the run
            print(f"[END] task_id={task_id} score=0.0 error={e!s}", flush=True)
            scores[task_id] = 0.0
        else:
            scores[task_id] = result
        elapsed = time.time() - task_started
        print(f" Task {task_id} took {elapsed:.1f}s", flush=True)

    total_time = time.time() - run_started

    print("\n" + rule)
    print("RESULTS SUMMARY")
    print(rule)
    for name in scores:
        print(f" {name:10s}: {scores[name]:.4f}")
    if scores:
        avg_score = sum(scores.values()) / len(scores)
    else:
        avg_score = 0.0
    print(f" {'average':10s}: {avg_score:.4f}")
    print(f" {'runtime':10s}: {total_time:.1f}s")
    print(rule)


if __name__ == "__main__":
    main()