File size: 20,220 Bytes
0d58c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4d1f04
 
 
 
 
 
0d58c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
"""DNS zone-file debugging environment for the OpenEnv hackathon.

Provides ``DNSEnvironment`` -- a stateful, step/reset/state RL environment
where an agent diagnoses and fixes broken DNS zone files.
"""

from __future__ import annotations

import copy
import random
import uuid
from typing import Any

# ---------------------------------------------------------------------------
# Dual-import pattern so the module works both as part of the package
# (``from dns_env.server.dns_environment import …``) and when executed
# directly (``python dns_environment.py``).
# ---------------------------------------------------------------------------

try:
    from .dns_utils import (
        DNSRecord,
        render_zone_file,
        validate_zone,
        simulate_dig,
        grade_zone,
    )
    from .tasks import get_task, TASK_IDS
except ImportError:
    from dns_utils import (  # type: ignore[no-redef]
        DNSRecord,
        render_zone_file,
        validate_zone,
        simulate_dig,
        grade_zone,
    )
    from tasks import get_task, TASK_IDS  # type: ignore[no-redef]

try:
    from ..models import Action, Observation, State
except ImportError:
    from models import Action, Observation, State  # type: ignore[no-redef]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _render_zone_indexed(records: list[DNSRecord], origin: str) -> str:
    """Produce zone-file text in which every record carries a ``; [N]`` tag.

    The output starts with ``$ORIGIN`` and ``$TTL`` directives followed by
    a blank line, then one tab-separated line per record.  ``N`` is the
    record's position in *records*, which is the index the agent passes to
    ``edit_record`` / ``delete_record``.

    Parameters
    ----------
    records:
        Records to render, in display order.
    origin:
        Zone origin without a trailing dot (one is appended here).

    Returns
    -------
    str
        The rendered text, lines joined with ``\\n`` and no final newline
        after the last record.
    """
    # The $TTL directive mirrors the explicit TTL of the first record
    # (normally the SOA); 86400 is used when there is none to copy.
    zone_ttl = 86400
    if records and records[0].ttl is not None:
        zone_ttl = records[0].ttl

    def _format(position: int, record: DNSRecord) -> str:
        # Empty columns (e.g. an inherited TTL) are dropped rather than
        # rendered as consecutive tabs.
        ttl_text = "" if record.ttl is None else f"{record.ttl}"
        columns = (record.name, ttl_text, record.rclass, record.rtype, record.rdata)
        joined = "\t".join(column for column in columns if column)
        return f"{joined}  ; [{position}]"

    header = [f"$ORIGIN {origin}.", f"$TTL {zone_ttl}", ""]
    body = [_format(position, record) for position, record in enumerate(records)]
    return "\n".join(header + body)


def _resolve_zone(zones: dict[str, list[DNSRecord]], requested: str | None) -> tuple[str, list[DNSRecord] | None]:
    """Return ``(zone_name, records)`` for the requested zone.

    Parameters
    ----------
    zones:
        Mapping of zone name to that zone's record list.
    requested:
        Zone name supplied by the agent, or ``None`` to select the first
        zone in *zones*.  Surrounding whitespace and trailing dots are
        ignored.

    Returns
    -------
    tuple
        ``(zone_name, records)``.  ``zone_name`` is the key actually
        stored in *zones* when a match is found.  When nothing matches,
        ``records`` is ``None`` and ``zone_name`` is the normalised request
        (or ``""`` when *zones* is empty).

    Lookup is attempted with the exact normalised name first; if that
    fails, a case-insensitive comparison is used because DNS names are
    case-insensitive and the qname-based zone matching in this module
    already compares lower-cased names.
    """
    if not zones:
        return ("", None)

    if requested is None:
        first_name = next(iter(zones))
        return (first_name, zones[first_name])

    # Normalise: tolerate non-string input, padding and the trailing dot
    # of a fully-qualified name.
    wanted = str(requested).strip().rstrip(".")
    if wanted in zones:
        return (wanted, zones[wanted])

    # Fall back to a case-insensitive match, also ignoring a trailing dot
    # on the stored key.
    folded = wanted.lower()
    for name, records in zones.items():
        if name.rstrip(".").lower() == folded:
            return (name, records)

    return (wanted, None)


# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------

class DNSEnvironment:
    """Stateful DNS zone-file debugging environment.

    Lifecycle
    ---------
    1. ``reset()`` -- load a task, populate ``self.zones``.
    2. ``step(action)`` -- process agent commands, return observations.
    3. ``state`` -- read-only snapshot of episode metadata.

    An episode ends when the agent issues ``submit`` or when the step
    budget (``max_steps`` from the task config, default 30) is used up, in
    which case the zones are graded as they stand.
    """

    VALID_COMMANDS: list[str] = [
        "view_zone",
        "add_record",
        "edit_record",
        "delete_record",
        "check_zone",
        "dig",
        "submit",
    ]

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------

    def __init__(self) -> None:
        """Create an environment with no task loaded; call ``reset()`` next."""
        # Working copy of the task's zones: zone name -> mutable record list.
        self.zones: dict[str, list[DNSRecord]] = {}
        # Task definition returned by ``get_task``; empty until ``reset()``.
        self.task_config: dict[str, Any] = {}
        self.episode_id: str | None = None
        self.step_count: int = 0
        # True once the episode has been graded (submit or auto-submit).
        self._done: bool = False
        # Score of the graded submission, replayed on post-terminal steps.
        self._last_reward: float | None = None
        # Position in TASK_IDS used when ``reset()`` gets no explicit task.
        self._task_cycle_index: int = 0

    # ------------------------------------------------------------------
    # reset / step / state
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        options: dict[str, Any] | None = None,
    ) -> Observation:
        """Reset the environment and load a new task.

        Parameters
        ----------
        seed:
            Optional RNG seed for reproducibility.  Seeds the global
            ``random`` module.
        episode_id:
            Optional caller-supplied episode identifier.  A UUID is
            generated when not provided.
        options:
            Dict that may contain ``task_id`` (one of
            :pydata:`TASK_IDS`).  When omitted the environment cycles
            through the available tasks round-robin.  An unknown
            ``task_id`` falls back to the first task, and the returned
            observation says so.

        Returns
        -------
        Observation
            Non-terminal observation carrying the task description and the
            list of zone names.
        """
        if seed is not None:
            random.seed(seed)

        options = options or {}
        task_id: str | None = options.get("task_id")

        # Pick a task -------------------------------------------------
        if task_id is None:
            task_id = TASK_IDS[self._task_cycle_index % len(TASK_IDS)]
            self._task_cycle_index += 1

        fallback_note = ""
        try:
            self.task_config = get_task(task_id)
        except ValueError:
            # Invalid task_id -- fall back to the first task, but tell the
            # caller instead of silently running a different task.
            fallback_note = (
                f"Note: unknown task_id '{task_id}'; "
                f"using '{TASK_IDS[0]}' instead.\n\n"
            )
            task_id = TASK_IDS[0]
            self.task_config = get_task(task_id)

        # Deep-copy zones so original task data stays pristine ---------
        self.zones = {
            name: copy.deepcopy(records)
            for name, records in self.task_config["zones"].items()
        }

        # Episode bookkeeping -----------------------------------------
        self.episode_id = episode_id or uuid.uuid4().hex
        self.step_count = 0
        self._done = False
        self._last_reward = None

        return Observation(
            output=(
                f"Episode {self.episode_id} started.\n\n"
                f"{fallback_note}"
                f"Task: {self.task_config['description']}\n\n"
                f"Available zones: {', '.join(self.zones.keys())}\n"
                "Use 'view_zone' to inspect a zone, then fix any issues and 'submit' when ready."
            ),
            task_description=self.task_config["description"],
            zone_names=list(self.zones.keys()),
            done=False,
        )

    def step(self, action: Action) -> Observation:
        """Process *action* and return the resulting observation.

        Behaviour by state:

        * Before ``reset()``: returns an error observation, nothing else
          happens.
        * After the episode is done: returns a terminal observation that
          repeats the last reward.
        * Otherwise ``step_count`` is incremented.  When it reaches
          ``max_steps`` and the command is not ``submit``, the supplied
          action is **not** executed; the zones are graded as they are
          (auto-submit).

        The command name is matched case-insensitively and ignoring
        surrounding whitespace.  Exceptions raised by a command handler
        are converted into a non-terminal error observation.
        """
        if not self.task_config:
            return Observation(
                output="Error: environment has not been reset. Call /reset first.",
                done=False,
            )

        if self._done:
            return Observation(
                output="Episode is already done. Call /reset to start a new episode.",
                done=True,
                reward=self._last_reward,
                zone_names=list(self.zones.keys()),
                task_description=self.task_config.get("description", ""),
            )

        self.step_count += 1
        max_steps: int = self.task_config.get("max_steps", 30)

        # Normalise once so the step-limit check and the dispatch below
        # agree on what counts as "submit"; a missing command becomes "".
        raw_command = action.command
        command = "" if raw_command is None else str(raw_command).strip().lower()

        # Auto-submit when budget exhausted ----------------------------
        if self.step_count >= max_steps and command != "submit":
            auto_obs = self._handle_submit({})
            auto_obs.output = (
                f"Step limit ({max_steps}) reached -- auto-submitting.\n\n"
                + auto_obs.output
            )
            return auto_obs

        # Dispatch to handler ------------------------------------------
        handler = {
            "view_zone": self._handle_view_zone,
            "add_record": self._handle_add_record,
            "edit_record": self._handle_edit_record,
            "delete_record": self._handle_delete_record,
            "check_zone": self._handle_check_zone,
            "dig": self._handle_dig,
            "submit": self._handle_submit,
        }.get(command)

        if handler is None:
            return self._obs(
                f"Error: unknown command '{action.command}'. "
                f"Available commands: {', '.join(self.VALID_COMMANDS)}"
            )

        # Commands without arguments may arrive with ``args`` unset.
        args = {} if action.args is None else action.args

        try:
            return handler(args)
        except Exception as exc:  # pragma: no cover -- safety net
            return self._obs(f"Error processing '{command}': {exc}")

    @property
    def state(self) -> State:
        """Read-only snapshot of the current episode state."""
        return State(
            episode_id=self.episode_id,
            step_count=self.step_count,
            task_id=self.task_config.get("task_id", ""),
            max_steps=self.task_config.get("max_steps", 30),
        )

    # ------------------------------------------------------------------
    # Command handlers (private)
    # ------------------------------------------------------------------

    def _handle_view_zone(self, args: dict[str, Any]) -> Observation:
        """Render the zone named by ``args['zone']`` (default: first zone)."""
        zone_name, records = _resolve_zone(self.zones, args.get("zone"))
        if records is None:
            return self._obs(self._zone_not_found_msg(args.get("zone")))
        rendered = _render_zone_indexed(records, zone_name)
        return self._obs(f"Zone: {zone_name}\n\n{rendered}")

    # -- add_record ----------------------------------------------------

    def _handle_add_record(self, args: dict[str, Any]) -> Observation:
        """Append a record to a zone.

        Required args: ``name``, ``rtype``, ``rdata``.  Optional args:
        ``zone``, ``ttl`` (integer, or ``None``/absent for no explicit
        TTL) and ``rclass`` (default ``"IN"``).  The zone is left
        untouched when any argument is rejected.
        """
        zone_name, records = _resolve_zone(self.zones, args.get("zone"))
        if records is None:
            return self._obs(self._zone_not_found_msg(args.get("zone")))

        # Validate required fields
        missing = [f for f in ("name", "rtype", "rdata") if f not in args]
        if missing:
            return self._obs(
                f"Error: add_record requires args: name, rtype, rdata. "
                f"Missing: {', '.join(missing)}"
            )

        # ``ttl: None`` means the same as omitting it, matching edit_record.
        ttl, ttl_error = self._coerce_ttl(args.get("ttl"))
        if ttl_error is not None:
            return self._obs(ttl_error)

        new_record = DNSRecord(
            name=str(args["name"]),
            rtype=str(args["rtype"]).upper(),
            rdata=str(args["rdata"]),
            ttl=ttl,
            rclass=str(args.get("rclass", "IN")),
        )
        records.append(new_record)

        rendered = _render_zone_indexed(records, zone_name)
        return self._obs(
            f"Record added to {zone_name} at index [{len(records) - 1}].\n\n{rendered}"
        )

    # -- edit_record ---------------------------------------------------

    def _handle_edit_record(self, args: dict[str, Any]) -> Observation:
        """Update selected fields of the record at ``args['index']``.

        Only the fields present in *args* (``name``, ``rtype``, ``rdata``,
        ``ttl``, ``rclass``) are changed.  ``ttl: None`` clears the
        explicit TTL.  All arguments are validated before the record is
        modified, so a rejected edit leaves the record exactly as it was.
        """
        zone_name, records = _resolve_zone(self.zones, args.get("zone"))
        if records is None:
            return self._obs(self._zone_not_found_msg(args.get("zone")))

        index, index_error = self._parse_index(args, len(records), "edit_record")
        if index_error is not None:
            return self._obs(index_error)

        # Convert the TTL up front: it is the only field whose conversion
        # can fail, and failing after other fields were written would
        # leave a half-applied edit behind an error message.
        new_ttl: int | None = None
        if "ttl" in args:
            new_ttl, ttl_error = self._coerce_ttl(args["ttl"])
            if ttl_error is not None:
                return self._obs(ttl_error)

        rec = records[index]

        # Partial update -- only touch fields the agent supplies
        if "name" in args:
            rec.name = str(args["name"])
        if "rtype" in args:
            rec.rtype = str(args["rtype"]).upper()
        if "rdata" in args:
            rec.rdata = str(args["rdata"])
        if "ttl" in args:
            rec.ttl = new_ttl
        if "rclass" in args:
            rec.rclass = str(args["rclass"])

        rendered = _render_zone_indexed(records, zone_name)
        return self._obs(
            f"Record [{index}] in {zone_name} updated.\n\n{rendered}"
        )

    # -- delete_record -------------------------------------------------

    def _handle_delete_record(self, args: dict[str, Any]) -> Observation:
        """Remove the record at ``args['index']``.

        A record at index 0 whose type is SOA cannot be deleted.  Indices
        of the records after the deleted one shift down by one.
        """
        zone_name, records = _resolve_zone(self.zones, args.get("zone"))
        if records is None:
            return self._obs(self._zone_not_found_msg(args.get("zone")))

        index, index_error = self._parse_index(args, len(records), "delete_record")
        if index_error is not None:
            return self._obs(index_error)

        # Protect the SOA record
        if index == 0 and records[0].rtype.upper() == "SOA":
            return self._obs(
                "Error: cannot delete the SOA record (index 0). "
                "Edit it instead if it needs changes."
            )

        deleted = records.pop(index)
        rendered = _render_zone_indexed(records, zone_name)
        return self._obs(
            f"Record [{index}] ({deleted.name} {deleted.rtype} {deleted.rdata}) "
            f"deleted from {zone_name}.\n\n{rendered}"
        )

    # -- check_zone ----------------------------------------------------

    def _handle_check_zone(self, args: dict[str, Any]) -> Observation:
        """Run ``validate_zone`` on a zone and list the reported errors."""
        zone_name, records = _resolve_zone(self.zones, args.get("zone"))
        if records is None:
            return self._obs(self._zone_not_found_msg(args.get("zone")))

        errors = validate_zone(records, zone_name)

        if not errors:
            return self._obs(
                f"Zone validation passed for {zone_name}. No errors found."
            )

        error_lines = "\n".join(
            f"  {i + 1}. {err}" for i, err in enumerate(errors)
        )
        return self._obs(
            f"Zone {zone_name} has {len(errors)} error(s):\n{error_lines}"
        )

    # -- dig -----------------------------------------------------------

    def _handle_dig(self, args: dict[str, Any]) -> Observation:
        """Answer a simulated ``dig`` query via ``simulate_dig``.

        Required args: ``qname`` and ``qtype``.  With an explicit ``zone``
        arg that zone is queried.  Otherwise the zone whose name is the
        longest suffix of ``qname`` is used, falling back to the first
        zone when none matches.
        """
        qname: str | None = args.get("qname")
        qtype: str | None = args.get("qtype")

        if not qname or not qtype:
            return self._obs(
                "Error: dig requires 'qname' and 'qtype' in args. "
                "Example: {\"qname\": \"www.example.com\", \"qtype\": \"A\"}"
            )

        # Determine which zone to query --------------------------------
        zone_name_arg = args.get("zone")
        if zone_name_arg is not None:
            zone_name, records = _resolve_zone(self.zones, zone_name_arg)
            if records is None:
                return self._obs(self._zone_not_found_msg(zone_name_arg))
        else:
            # Auto-detect zone: pick the zone whose origin is a suffix
            # of qname, preferring the longest match.
            zone_name, records = self._match_zone_for_qname(qname)
            if records is None:
                # Fall back to first zone
                zone_name, records = _resolve_zone(self.zones, None)
                if records is None:
                    return self._obs("Error: no zones available.")

        result = simulate_dig(records, zone_name, str(qname), str(qtype).upper())
        return self._obs(f";; Querying {zone_name} for {qname} {qtype}\n\n{result}")

    # -- submit --------------------------------------------------------

    def _handle_submit(self, args: dict[str, Any]) -> Observation:
        """Grade the current zones and end the episode.

        With exactly one zone, that zone is graded against all
        ``required_checks``.  Otherwise each zone is graded separately and
        the final score is the mean of the per-zone scores; a check is
        given to a zone when its ``zone`` key equals the zone name or when
        it has no ``zone`` key at all.

        *args* is accepted for handler-signature uniformity and is not
        read.  Sets ``_done`` and ``_last_reward`` and returns a terminal
        observation whose ``reward`` is the final score.
        """
        required_checks = self.task_config.get("required_checks", [])
        original_correct_raw = self.task_config.get("original_correct", {})

        # original_correct is a dict {zone_name: [(name, type, rdata), ...]}
        # grade_zone expects a flat list of tuples for a single zone
        if not isinstance(original_correct_raw, dict):
            original_correct_raw = {}

        if len(self.zones) == 1:
            # Single-zone task
            zone_name = next(iter(self.zones))
            records = self.zones[zone_name]
            oc = original_correct_raw.get(zone_name, [])
            result = grade_zone(
                records, zone_name, required_checks, oc or None
            )
            score = result.get("score", 0.0)
            breakdown = self._format_grading(zone_name, result)
        else:
            # Multi-zone task: grade each zone independently and average
            zone_results: dict[str, dict[str, Any]] = {}
            total_score = 0.0
            for zn, recs in self.zones.items():
                # Checks tagged with another zone are excluded; untagged
                # checks go to every zone.  (An earlier fallback that
                # re-selected untagged checks when this list was empty
                # could never select anything -- an empty list here means
                # every check carries a ``zone`` key -- so it is omitted.)
                zone_checks = [
                    c for c in required_checks
                    if c.get("zone", zn) == zn
                ]
                oc = original_correct_raw.get(zn, [])
                res = grade_zone(recs, zn, zone_checks, oc or None)
                zone_results[zn] = res
                total_score += res.get("score", 0.0)
            # ``max(..., 1)`` keeps the no-zones case at 0.0 instead of
            # dividing by zero.
            score = total_score / max(len(zone_results), 1)
            breakdown = "\n---\n".join(
                self._format_grading(zn, res)
                for zn, res in zone_results.items()
            )

        self._done = True
        self._last_reward = score

        return Observation(
            output=(
                f"=== Submission Graded ===\n\n"
                f"{breakdown}\n\n"
                f"Final score: {score:.2f}"
            ),
            task_description=self.task_config.get("description", ""),
            zone_names=list(self.zones.keys()),
            done=True,
            reward=score,
        )

    # ------------------------------------------------------------------
    # Utilities (private)
    # ------------------------------------------------------------------

    def _obs(self, output: str) -> Observation:
        """Convenience builder for a non-terminal observation."""
        return Observation(
            output=output,
            task_description=self.task_config.get("description", ""),
            zone_names=list(self.zones.keys()),
            done=False,
        )

    @staticmethod
    def _zone_not_found_msg(requested: str | None) -> str:
        """Return the error text for a zone lookup that found nothing."""
        if requested:
            return f"Error: zone '{requested}' not found. Use 'view_zone' to list available zones."
        return "Error: no zones available. Has the environment been reset?"

    @staticmethod
    def _parse_index(
        args: dict[str, Any], record_count: int, command: str
    ) -> tuple[int | None, str | None]:
        """Extract and bounds-check ``args['index']``.

        Returns ``(index, None)`` on success and ``(None, error_message)``
        otherwise.  *command* is only used to word the "missing index"
        message; *record_count* is the length of the zone being indexed.
        """
        if "index" not in args:
            return (None, f"Error: {command} requires 'index' in args.")

        try:
            index = int(args["index"])
        except (ValueError, TypeError):
            return (None, f"Error: 'index' must be an integer, got '{args['index']}'.")

        if record_count == 0:
            # Avoid advertising the nonsensical range "0..-1".
            return (None, f"Error: index {index} out of range. The zone has no records.")

        if not 0 <= index < record_count:
            return (
                None,
                f"Error: index {index} out of range. "
                f"Valid range: 0..{record_count - 1}.",
            )
        return (index, None)

    @staticmethod
    def _coerce_ttl(value: Any) -> tuple[int | None, str | None]:
        """Convert an agent-supplied TTL to ``int``.

        Returns ``(ttl, None)`` on success -- ``ttl`` is ``None`` when
        *value* is ``None`` -- and ``(None, error_message)`` when *value*
        cannot be converted.
        """
        if value is None:
            return (None, None)
        try:
            return (int(value), None)
        except (ValueError, TypeError):
            return (None, f"Error: 'ttl' must be an integer, got '{value}'.")

    def _match_zone_for_qname(self, qname: str) -> tuple[str, list[DNSRecord] | None]:
        """Find the zone whose origin is the longest suffix of *qname*.

        Comparison is case-insensitive and ignores trailing dots on both
        the query name and the zone names.  A zone matches when *qname*
        equals it or ends with ``"." + zone``.  Returns ``("", None)`` when
        no zone matches; on equal-length matches the zone that comes first
        in ``self.zones`` wins.
        """
        target = str(qname).rstrip(".").lower()
        best_name: str | None = None
        best_len = -1
        for zn in self.zones:
            origin = zn.rstrip(".").lower()
            if target != origin and not target.endswith("." + origin):
                continue
            if len(origin) > best_len:
                best_name = zn
                best_len = len(origin)
        if best_name is None:
            return ("", None)
        return (best_name, self.zones[best_name])

    @staticmethod
    def _format_grading(zone_name: str, result: dict[str, Any]) -> str:
        """Format a single zone's grading result into readable text.

        Reads the ``score``, ``passed``, ``failed`` and ``details`` keys of
        *result*; missing keys default to ``0.0``, ``0``, ``0`` and no
        details respectively.
        """
        lines: list[str] = [
            f"Zone: {zone_name}",
            f"  Score: {result.get('score', 0.0):.2f}",
            f"  Passed: {result.get('passed', 0)}, Failed: {result.get('failed', 0)}",
        ]

        details = result.get("details", [])
        if details:
            lines.append("  Details:")
            lines.extend(f"    {d}" for d in details)

        return "\n".join(lines)