#!/usr/bin/env python3 """Baseline inference script for DNS-Env (OpenEnv Hackathon). STDOUT FORMAT (mandatory — any deviation = incorrect scoring): [START] task= env= model= [STEP] step= action= reward=<0.00> done= error= [END] success= steps= score= rewards= """ from __future__ import annotations import json import os import re import sys import time import traceback from typing import Any, List, Optional import requests from openai import OpenAI # --------------------------------------------------------------------------- # Configuration from environment variables # --------------------------------------------------------------------------- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1") MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") HF_TOKEN = os.getenv("HF_TOKEN") ENV_URL = os.getenv("ENV_URL", "http://localhost:7860") # Optional — if you use from_docker_image(): LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") TASKS = ["fix_single_record", "configure_mail", "debug_delegation"] BENCHMARK = "dns_env" # Safety limits MAX_STEPS_PER_TASK = 25 MAX_RETRIES_HTTP = 3 HTTP_TIMEOUT = 60 SUCCESS_SCORE_THRESHOLD = 0.5 # --------------------------------------------------------------------------- # OpenAI client # --------------------------------------------------------------------------- client = OpenAI( base_url=API_BASE_URL, api_key=HF_TOKEN or os.getenv("OPENAI_API_KEY", ""), ) # --------------------------------------------------------------------------- # Mandatory stdout log functions # --------------------------------------------------------------------------- def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = error if error else "null" done_val = str(done).lower() print( f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_str = ",".join(f"{r:.2f}" for r in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True, ) # --------------------------------------------------------------------------- # System prompt # --------------------------------------------------------------------------- SYSTEM_PROMPT = """\ You are an expert DNS administrator and zone-file debugger. You are \ interacting with a DNS zone-file debugging environment. Your goal is to \ diagnose and fix errors in DNS zone files so that all records resolve \ correctly. ## Available Commands Respond with exactly ONE JSON object per turn. The JSON must have \ "command" and "args" keys. 1. **view_zone** -- View the zone file with indexed records. {"command": "view_zone", "args": {"zone": ""}} 2. **check_zone** -- Validate the zone and list any errors. {"command": "check_zone", "args": {"zone": ""}} 3. **edit_record** -- Modify an existing record by its index. Only \ include the fields you want to change. {"command": "edit_record", "args": {"zone": "", "index": , "rdata": ""}} You can also change "name", "rtype", or "ttl" in the same call. 4. **add_record** -- Add a new record to the zone. {"command": "add_record", "args": {"zone": "", "name": "", "rtype": "", "rdata": ""}} 5. **delete_record** -- Remove a record by index. {"command": "delete_record", "args": {"zone": "", "index": }} 6. **dig** -- Simulate a DNS query to test resolution. {"command": "dig", "args": {"zone": "", "qname": "", "qtype": ""}} 7. **submit** -- Submit your work for grading. Use this when you are \ confident all fixes are correct. {"command": "submit", "args": {}} ## DNS Debugging Tips - **Trailing dots on FQDNs**: CNAME targets, NS targets, and MX targets \ that are fully-qualified domain names MUST end with a trailing dot \ (e.g., "example.com." not "example.com"). Without the dot, the name is \ treated as relative to the zone origin. - **A record IPs**: Must be valid IPv4 addresses (each octet 0-255). - **MX records**: Format is " ". The target must \ end with a dot if it is an FQDN. - **CNAME exclusivity**: A name with a CNAME record cannot have any \ other record types at the same name. - **TXT records**: The rdata should be enclosed in double quotes. - **SPF records**: Use TXT record type. Example: \ "v=spf1 ip4:10.0.1.0/24 -all" - **DMARC records**: Use TXT record at _dmarc. Example: \ "v=DMARC1; p=quarantine; rua=mailto:postmaster@example.com" - **NS delegation**: Parent zone NS records and glue records must be \ consistent with the child zone's NS records and A records. - **SOA serial**: The child zone SOA serial should generally be >= the \ parent zone's serial. - **Zone indices**: Records are labeled with [N] indices. Use these \ indices when editing or deleting records. ## Strategy 1. Start by viewing all available zones with view_zone. 2. Run check_zone to identify validation errors. 3. Fix each error using edit_record, add_record, or delete_record. 4. Use dig to verify your fixes resolve correctly. 5. When all issues are fixed, submit. ## Response Format You MUST respond with a single JSON object and nothing else. \ Do not include explanations outside the JSON. Example: {"command": "view_zone", "args": {"zone": "example.com"}} """ # --------------------------------------------------------------------------- # HTTP helpers # --------------------------------------------------------------------------- def _post(endpoint: str, body: dict[str, Any]) -> dict[str, Any]: """POST to the environment server with retries.""" url = f"{ENV_URL}{endpoint}" for attempt in range(1, MAX_RETRIES_HTTP + 1): try: resp = requests.post(url, json=body, timeout=HTTP_TIMEOUT) resp.raise_for_status() return resp.json() except (requests.RequestException, ValueError) as exc: if attempt == MAX_RETRIES_HTTP: raise time.sleep(1.0 * attempt) return {} def reset_env(task_id: str) -> dict[str, Any]: body: dict[str, Any] = {"session_id": "default", "options": {"task_id": task_id}} return _post("/reset", body) def step_env(action: dict[str, Any]) -> dict[str, Any]: body = {"session_id": "default", "action": action} return _post("/step", body) # --------------------------------------------------------------------------- # Prompt construction # --------------------------------------------------------------------------- def build_prompt( obs: dict[str, Any], task_id: str, step_num: int, max_steps: int, history: list[dict[str, str]], ) -> str: parts: list[str] = [] task_desc = obs.get("task_description", "") if task_desc: parts.append(f"## Task: {task_id}\n{task_desc}") zone_names = obs.get("zone_names", []) if zone_names: parts.append(f"Available zones: {', '.join(zone_names)}") remaining = max_steps - step_num parts.append(f"Step {step_num}/{max_steps} (remaining: {remaining})") if history: recent = history[-3:] lines = [] for h in recent: lines.append(f" Action: {h['action']}") preview = h["result"][:300] if len(h["result"]) > 300: preview += "..." lines.append(f" Result: {preview}") parts.append("## Recent History\n" + "\n".join(lines)) output = obs.get("output", "") if output: parts.append(f"## Current Output\n{output}") if remaining <= 3: parts.append( 'WARNING: Running low on steps. Submit now: {"command": "submit", "args": {}}' ) parts.append('Respond with a single JSON object: {"command": "...", "args": {...}}') return "\n\n".join(parts) # --------------------------------------------------------------------------- # LLM response parsing # --------------------------------------------------------------------------- _JSON_BLOCK_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL) _JSON_OBJECT_RE = re.compile(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", re.DOTALL) def parse_llm_response(text: str | None) -> dict[str, Any]: default_action: dict[str, Any] = {"command": "view_zone", "args": {}} if not text: return default_action text = text.strip() action = _try_parse_json(text) if action is not None: return action match = _JSON_BLOCK_RE.search(text) if match: action = _try_parse_json(match.group(1).strip()) if action is not None: return action match = _JSON_OBJECT_RE.search(text) if match: action = _try_parse_json(match.group(0)) if action is not None: return action return default_action def _try_parse_json(text: str) -> dict[str, Any] | None: try: data = json.loads(text) if isinstance(data, dict) and "command" in data: if "args" not in data or not isinstance(data.get("args"), dict): data["args"] = data.get("args", {}) or {} return {"command": str(data["command"]), "args": data["args"]} except (json.JSONDecodeError, TypeError, ValueError): pass return None # --------------------------------------------------------------------------- # LLM interaction # --------------------------------------------------------------------------- def call_llm(messages: list[dict[str, str]], temperature: float = 0.0) -> str: for attempt in range(1, 3): try: response = client.chat.completions.create( model=MODEL_NAME, messages=messages, temperature=temperature, ) content = response.choices[0].message.content return content if content else "" except Exception as exc: if attempt == 2: raise time.sleep(2.0) return "" # --------------------------------------------------------------------------- # Main inference loop # --------------------------------------------------------------------------- def run_task(task_id: str) -> float: """Run a single task. Returns score in [0, 1].""" rewards: List[float] = [] steps_taken = 0 score = 0.0 success = False log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) try: obs = reset_env(task_id) except Exception as exc: print(f"[ERROR] Failed to reset: {exc}", file=sys.stderr) log_end(success=False, steps=0, score=0.0, rewards=[]) return 0.0 max_steps = MAX_STEPS_PER_TASK history: list[dict[str, str]] = [] messages: list[dict[str, str]] = [{"role": "system", "content": SYSTEM_PROMPT}] try: while not obs.get("done", False): steps_taken += 1 prompt = build_prompt(obs, task_id, steps_taken, max_steps, history) messages.append({"role": "user", "content": prompt}) try: llm_text = call_llm(messages) except Exception: action = {"command": "submit", "args": {}} llm_text = json.dumps(action) messages.append({"role": "assistant", "content": llm_text}) action = parse_llm_response(llm_text) history.append({"action": json.dumps(action), "result": ""}) try: obs = step_env(action) except Exception: try: obs = step_env({"command": "submit", "args": {}}) except Exception: obs = {"done": True, "reward": 0.0} break if history: history[-1]["result"] = obs.get("output", "")[:500] reward = obs.get("reward") reward_val = float(reward) if reward is not None else 0.0 done = obs.get("done", False) error = None rewards.append(reward_val) action_str = json.dumps(action) log_step(step=steps_taken, action=action_str, reward=reward_val, done=done, error=error) if steps_taken >= max_steps and not obs.get("done", False): try: obs = step_env({"command": "submit", "args": {}}) reward = obs.get("reward") reward_val = float(reward) if reward is not None else 0.0 rewards.append(reward_val) steps_taken += 1 log_step( step=steps_taken, action='{"command":"submit","args":{}}', reward=reward_val, done=obs.get("done", False), error=None, ) except Exception: obs = {"done": True, "reward": 0.0} break if len(messages) > 41: messages = [messages[0]] + messages[-40:] score = obs.get("reward", 0.0) if score is None: score = 0.0 score = float(score) score = min(max(score, 0.0), 1.0) success = score >= SUCCESS_SCORE_THRESHOLD finally: log_end(success=success, steps=steps_taken, score=score, rewards=rewards) return score def main() -> None: try: resp = requests.get(f"{ENV_URL}/health", timeout=10) resp.raise_for_status() except Exception as exc: print(f"[FATAL] Cannot reach environment at {ENV_URL}: {exc}", file=sys.stderr) sys.exit(1) scores: dict[str, float] = {} for task_id in TASKS: try: score = run_task(task_id) scores[task_id] = score except Exception as exc: traceback.print_exc(file=sys.stderr) scores[task_id] = 0.0 log_end(success=False, steps=0, score=0.0, rewards=[]) total = sum(scores.values()) avg = total / len(scores) if scores else 0.0 print(f"\nAverage score: {avg:.2f}", file=sys.stderr) if __name__ == "__main__": main()