DNSArenaEnv / server /dns_environment.py
PraneshkumarR's picture
Fix 3 critical bugs: trailing-dot grading, hard task difficulty, invalid task_id crash, add IP to easy task description
b4d1f04
"""DNS zone-file debugging environment for the OpenEnv hackathon.
Provides ``DNSEnvironment`` -- a stateful, step/reset/state RL environment
where an agent diagnoses and fixes broken DNS zone files.
"""
from __future__ import annotations
import copy
import random
import uuid
from typing import Any
# ---------------------------------------------------------------------------
# Dual-import pattern so the module works both as part of the package
# (``from dns_env.server.dns_environment import …``) and when executed
# directly (``python dns_environment.py``).
# ---------------------------------------------------------------------------
try:
from .dns_utils import (
DNSRecord,
render_zone_file,
validate_zone,
simulate_dig,
grade_zone,
)
from .tasks import get_task, TASK_IDS
except ImportError:
from dns_utils import ( # type: ignore[no-redef]
DNSRecord,
render_zone_file,
validate_zone,
simulate_dig,
grade_zone,
)
from tasks import get_task, TASK_IDS # type: ignore[no-redef]
try:
from ..models import Action, Observation, State
except ImportError:
from models import Action, Observation, State # type: ignore[no-redef]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _render_zone_indexed(records: list[DNSRecord], origin: str) -> str:
"""Render a zone file with ``[N]`` index comments for each record.
The indices let the agent reference records by number when editing or
deleting.
"""
lines: list[str] = []
lines.append(f"$ORIGIN {origin}.")
# Determine default TTL from SOA if available
default_ttl = 86400
if records:
first = records[0]
if first.ttl is not None:
default_ttl = first.ttl
lines.append(f"$TTL {default_ttl}")
lines.append("")
for idx, rec in enumerate(records):
ttl_str = f"{rec.ttl}" if rec.ttl is not None else ""
parts = [rec.name, ttl_str, rec.rclass, rec.rtype, rec.rdata]
line = "\t".join(p for p in parts if p)
lines.append(f"{line} ; [{idx}]")
return "\n".join(lines)
def _resolve_zone(zones: dict[str, list[DNSRecord]], requested: str | None) -> tuple[str, list[DNSRecord] | None]:
"""Return ``(zone_name, records)`` for the requested zone.
If *requested* is ``None``, the first zone is used. Returns
``(zone_name, None)`` when the zone cannot be found.
"""
if not zones:
return ("", None)
if requested is None:
zone_name = next(iter(zones))
else:
# Normalise: strip trailing dot if present
zone_name = requested.rstrip(".")
records = zones.get(zone_name)
return (zone_name, records)
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class DNSEnvironment:
"""Stateful DNS zone-file debugging environment.
Lifecycle
---------
1. ``reset()`` -- load a task, populate ``self.zones``.
2. ``step(action)`` -- process agent commands, return observations.
3. ``state`` -- read-only snapshot of episode metadata.
"""
VALID_COMMANDS: list[str] = [
"view_zone",
"add_record",
"edit_record",
"delete_record",
"check_zone",
"dig",
"submit",
]
# ------------------------------------------------------------------
# Initialisation
# ------------------------------------------------------------------
def __init__(self) -> None:
self.zones: dict[str, list[DNSRecord]] = {}
self.task_config: dict[str, Any] = {}
self.episode_id: str | None = None
self.step_count: int = 0
self._done: bool = False
self._last_reward: float | None = None
self._task_cycle_index: int = 0
# ------------------------------------------------------------------
# reset / step / state
# ------------------------------------------------------------------
def reset(
self,
seed: int | None = None,
episode_id: str | None = None,
options: dict[str, Any] | None = None,
) -> Observation:
"""Reset the environment and load a new task.
Parameters
----------
seed:
Optional RNG seed for reproducibility.
episode_id:
Optional caller-supplied episode identifier. A UUID is
generated when not provided.
options:
Dict that may contain ``task_id`` (one of
:pydata:`TASK_IDS`). When omitted the environment cycles
through the available tasks round-robin.
"""
if seed is not None:
random.seed(seed)
options = options or {}
task_id: str | None = options.get("task_id")
# Pick a task -------------------------------------------------
if task_id is None:
task_id = TASK_IDS[self._task_cycle_index % len(TASK_IDS)]
self._task_cycle_index += 1
try:
self.task_config = get_task(task_id)
except ValueError:
# Invalid task_id — fall back to first task
task_id = TASK_IDS[0]
self.task_config = get_task(task_id)
# Deep-copy zones so original task data stays pristine ---------
self.zones = {
name: copy.deepcopy(records)
for name, records in self.task_config["zones"].items()
}
# Episode bookkeeping -----------------------------------------
self.episode_id = episode_id or uuid.uuid4().hex
self.step_count = 0
self._done = False
self._last_reward = None
return Observation(
output=(
f"Episode {self.episode_id} started.\n\n"
f"Task: {self.task_config['description']}\n\n"
f"Available zones: {', '.join(self.zones.keys())}\n"
"Use 'view_zone' to inspect a zone, then fix any issues and 'submit' when ready."
),
task_description=self.task_config["description"],
zone_names=list(self.zones.keys()),
done=False,
)
def step(self, action: Action) -> Observation:
"""Process *action* and return the resulting observation.
If the episode is already done, further steps return a terminal
observation. When ``step_count`` reaches ``max_steps`` the
environment auto-submits.
"""
if not self.task_config:
return Observation(
output="Error: environment has not been reset. Call /reset first.",
done=False,
)
if self._done:
return Observation(
output="Episode is already done. Call /reset to start a new episode.",
done=True,
reward=self._last_reward,
zone_names=list(self.zones.keys()),
task_description=self.task_config.get("description", ""),
)
self.step_count += 1
max_steps: int = self.task_config.get("max_steps", 30)
# Auto-submit when budget exhausted ----------------------------
if self.step_count >= max_steps and action.command != "submit":
auto_obs = self._handle_submit({})
auto_obs.output = (
f"Step limit ({max_steps}) reached -- auto-submitting.\n\n"
+ auto_obs.output
)
return auto_obs
# Dispatch to handler ------------------------------------------
command = action.command.strip().lower()
handler = {
"view_zone": self._handle_view_zone,
"add_record": self._handle_add_record,
"edit_record": self._handle_edit_record,
"delete_record": self._handle_delete_record,
"check_zone": self._handle_check_zone,
"dig": self._handle_dig,
"submit": self._handle_submit,
}.get(command)
if handler is None:
return self._obs(
f"Error: unknown command '{action.command}'. "
f"Available commands: {', '.join(self.VALID_COMMANDS)}"
)
try:
return handler(action.args)
except Exception as exc: # pragma: no cover -- safety net
return self._obs(f"Error processing '{command}': {exc}")
@property
def state(self) -> State:
"""Read-only snapshot of the current episode state."""
return State(
episode_id=self.episode_id,
step_count=self.step_count,
task_id=self.task_config.get("task_id", ""),
max_steps=self.task_config.get("max_steps", 30),
)
# ------------------------------------------------------------------
# Command handlers (private)
# ------------------------------------------------------------------
def _handle_view_zone(self, args: dict[str, Any]) -> Observation:
zone_name, records = _resolve_zone(self.zones, args.get("zone"))
if records is None:
return self._obs(self._zone_not_found_msg(args.get("zone")))
rendered = _render_zone_indexed(records, zone_name)
return self._obs(f"Zone: {zone_name}\n\n{rendered}")
# -- add_record ----------------------------------------------------
def _handle_add_record(self, args: dict[str, Any]) -> Observation:
zone_name, records = _resolve_zone(self.zones, args.get("zone"))
if records is None:
return self._obs(self._zone_not_found_msg(args.get("zone")))
# Validate required fields
missing = [f for f in ("name", "rtype", "rdata") if f not in args]
if missing:
return self._obs(
f"Error: add_record requires args: name, rtype, rdata. "
f"Missing: {', '.join(missing)}"
)
new_record = DNSRecord(
name=str(args["name"]),
rtype=str(args["rtype"]).upper(),
rdata=str(args["rdata"]),
ttl=int(args["ttl"]) if "ttl" in args else None,
rclass=str(args.get("rclass", "IN")),
)
records.append(new_record)
rendered = _render_zone_indexed(records, zone_name)
return self._obs(
f"Record added to {zone_name} at index [{len(records) - 1}].\n\n{rendered}"
)
# -- edit_record ---------------------------------------------------
def _handle_edit_record(self, args: dict[str, Any]) -> Observation:
zone_name, records = _resolve_zone(self.zones, args.get("zone"))
if records is None:
return self._obs(self._zone_not_found_msg(args.get("zone")))
if "index" not in args:
return self._obs("Error: edit_record requires 'index' in args.")
try:
index = int(args["index"])
except (ValueError, TypeError):
return self._obs(f"Error: 'index' must be an integer, got '{args['index']}'.")
if index < 0 or index >= len(records):
return self._obs(
f"Error: index {index} out of range. "
f"Valid range: 0..{len(records) - 1}."
)
rec = records[index]
# Partial update -- only touch fields the agent supplies
if "name" in args:
rec.name = str(args["name"])
if "rtype" in args:
rec.rtype = str(args["rtype"]).upper()
if "rdata" in args:
rec.rdata = str(args["rdata"])
if "ttl" in args:
rec.ttl = int(args["ttl"]) if args["ttl"] is not None else None
if "rclass" in args:
rec.rclass = str(args["rclass"])
rendered = _render_zone_indexed(records, zone_name)
return self._obs(
f"Record [{index}] in {zone_name} updated.\n\n{rendered}"
)
# -- delete_record -------------------------------------------------
def _handle_delete_record(self, args: dict[str, Any]) -> Observation:
zone_name, records = _resolve_zone(self.zones, args.get("zone"))
if records is None:
return self._obs(self._zone_not_found_msg(args.get("zone")))
if "index" not in args:
return self._obs("Error: delete_record requires 'index' in args.")
try:
index = int(args["index"])
except (ValueError, TypeError):
return self._obs(f"Error: 'index' must be an integer, got '{args['index']}'.")
if index < 0 or index >= len(records):
return self._obs(
f"Error: index {index} out of range. "
f"Valid range: 0..{len(records) - 1}."
)
# Protect the SOA record
if index == 0 and records[0].rtype.upper() == "SOA":
return self._obs(
"Error: cannot delete the SOA record (index 0). "
"Edit it instead if it needs changes."
)
deleted = records.pop(index)
rendered = _render_zone_indexed(records, zone_name)
return self._obs(
f"Record [{index}] ({deleted.name} {deleted.rtype} {deleted.rdata}) "
f"deleted from {zone_name}.\n\n{rendered}"
)
# -- check_zone ----------------------------------------------------
def _handle_check_zone(self, args: dict[str, Any]) -> Observation:
zone_name, records = _resolve_zone(self.zones, args.get("zone"))
if records is None:
return self._obs(self._zone_not_found_msg(args.get("zone")))
errors = validate_zone(records, zone_name)
if not errors:
return self._obs(
f"Zone validation passed for {zone_name}. No errors found."
)
error_lines = "\n".join(
f" {i + 1}. {err}" for i, err in enumerate(errors)
)
return self._obs(
f"Zone {zone_name} has {len(errors)} error(s):\n{error_lines}"
)
# -- dig -----------------------------------------------------------
def _handle_dig(self, args: dict[str, Any]) -> Observation:
qname: str | None = args.get("qname")
qtype: str | None = args.get("qtype")
if not qname or not qtype:
return self._obs(
"Error: dig requires 'qname' and 'qtype' in args. "
"Example: {\"qname\": \"www.example.com\", \"qtype\": \"A\"}"
)
# Determine which zone to query --------------------------------
zone_name_arg = args.get("zone")
if zone_name_arg is not None:
zone_name, records = _resolve_zone(self.zones, zone_name_arg)
if records is None:
return self._obs(self._zone_not_found_msg(zone_name_arg))
else:
# Auto-detect zone: pick the zone whose origin is a suffix
# of qname, preferring the longest match.
zone_name, records = self._match_zone_for_qname(qname)
if records is None:
# Fall back to first zone
zone_name, records = _resolve_zone(self.zones, None)
if records is None:
return self._obs("Error: no zones available.")
result = simulate_dig(records, zone_name, str(qname), str(qtype).upper())
return self._obs(f";; Querying {zone_name} for {qname} {qtype}\n\n{result}")
# -- submit --------------------------------------------------------
def _handle_submit(self, args: dict[str, Any]) -> Observation:
required_checks = self.task_config.get("required_checks", [])
original_correct_raw = self.task_config.get("original_correct", {})
# original_correct is a dict {zone_name: [(name, type, rdata), ...]}
# grade_zone expects a flat list of tuples for a single zone
if not isinstance(original_correct_raw, dict):
original_correct_raw = {}
if len(self.zones) == 1:
# Single-zone task
zone_name = next(iter(self.zones))
records = self.zones[zone_name]
oc = original_correct_raw.get(zone_name, [])
result = grade_zone(
records, zone_name, required_checks, oc or None
)
score = result.get("score", 0.0)
breakdown = self._format_grading(zone_name, result)
else:
# Multi-zone task: grade each zone independently and average
zone_results: dict[str, dict[str, Any]] = {}
total_score = 0.0
for zn, recs in self.zones.items():
# Filter checks relevant to this zone if checks carry a
# ``zone`` key; otherwise pass all checks to every zone.
zone_checks = [
c for c in required_checks
if c.get("zone", zn) == zn
]
# Skip checks that don't have a zone key or qname (like delegation_consistency)
if not zone_checks:
zone_checks = [
c for c in required_checks
if "zone" not in c and "qname" in c
]
oc = original_correct_raw.get(zn, [])
res = grade_zone(recs, zn, zone_checks, oc or None)
zone_results[zn] = res
total_score += res.get("score", 0.0)
score = total_score / max(len(zone_results), 1)
parts = [
self._format_grading(zn, res)
for zn, res in zone_results.items()
]
breakdown = "\n---\n".join(parts)
self._done = True
self._last_reward = score
return Observation(
output=(
f"=== Submission Graded ===\n\n"
f"{breakdown}\n\n"
f"Final score: {score:.2f}"
),
task_description=self.task_config.get("description", ""),
zone_names=list(self.zones.keys()),
done=True,
reward=score,
)
# ------------------------------------------------------------------
# Utilities (private)
# ------------------------------------------------------------------
def _obs(self, output: str) -> Observation:
"""Convenience builder for a non-terminal observation."""
return Observation(
output=output,
task_description=self.task_config.get("description", ""),
zone_names=list(self.zones.keys()),
done=False,
)
@staticmethod
def _zone_not_found_msg(requested: str | None) -> str:
if requested:
return f"Error: zone '{requested}' not found. Use 'view_zone' to list available zones."
return "Error: no zones available. Has the environment been reset?"
def _match_zone_for_qname(self, qname: str) -> tuple[str, list[DNSRecord] | None]:
"""Find the zone whose origin is the longest suffix of *qname*."""
qname_lower = qname.rstrip(".").lower()
best_name: str | None = None
best_len = -1
for zn in self.zones:
zn_lower = zn.lower()
if qname_lower == zn_lower or qname_lower.endswith("." + zn_lower):
if len(zn_lower) > best_len:
best_name = zn
best_len = len(zn_lower)
if best_name is not None:
return (best_name, self.zones[best_name])
return ("", None)
@staticmethod
def _format_grading(zone_name: str, result: dict[str, Any]) -> str:
"""Format a single zone's grading result into readable text."""
lines: list[str] = [f"Zone: {zone_name}"]
lines.append(f" Score: {result.get('score', 0.0):.2f}")
lines.append(f" Passed: {result.get('passed', 0)}, Failed: {result.get('failed', 0)}")
details = result.get("details", [])
if details:
lines.append(" Details:")
for d in details:
lines.append(f" {d}")
return "\n".join(lines)