File size: 5,341 Bytes
8c486a8
 
 
 
b439619
 
 
8c486a8
 
b439619
 
8c486a8
49d1c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b439619
 
49d1c75
 
 
d2809e9
 
 
 
 
 
 
 
 
 
8c486a8
49d1c75
 
 
8c486a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d1c75
 
 
 
 
b439619
 
 
 
 
 
49d1c75
 
 
b439619
d2809e9
49d1c75
d2809e9
 
b439619
49d1c75
 
595e190
49d1c75
 
 
595e190
 
 
 
 
 
 
 
 
 
49d1c75
 
 
 
 
 
 
 
 
 
8c486a8
49d1c75
8c486a8
 
49d1c75
8c486a8
 
 
 
595e190
8c486a8
 
 
595e190
 
 
 
 
 
 
 
 
 
8c486a8
 
 
 
 
 
 
 
49d1c75
8c486a8
 
 
 
49d1c75
8c486a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Check 5: Reward grounding — verify flag values exist at expected paths."""

from __future__ import annotations

import re
import shlex

from open_range.protocols import CheckResult, ContainerSet, SnapshotSpec

# A bare SQL identifier: a letter or underscore followed by letters, digits
# or underscores.  Used to validate every component of a DB flag path before
# it is interpolated into a query.
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _parse_db_path(path: str) -> tuple[str, str, str] | None:
    """Split a DB flag reference into its three identifiers.

    Accepted form is ``<scheme>:<database>.<table>.<column>`` where
    ``<scheme>`` is ``db`` or ``mysql`` and each dotted component matches
    ``_IDENTIFIER_RE``.

    Returns ``(database, table, column)`` on success, or *None* when the
    scheme is missing/unknown, the dotted part does not have exactly three
    components, or any component is not a plain identifier.
    """
    scheme, separator, remainder = path.partition(":")
    # An empty separator means there was no ":" at all.
    if not separator or scheme not in ("db", "mysql"):
        return None

    segments = remainder.split(".")
    if len(segments) != 3:
        return None

    for segment in segments:
        if _IDENTIFIER_RE.fullmatch(segment) is None:
            return None

    database, table, column = segments
    return database, table, column


def _mysql_root_password(snapshot: SnapshotSpec) -> str:
    """Pick the MySQL root password used by the validator's DB queries.

    Reads ``snapshot.topology``; when it is a dict whose
    ``"mysql_root_password"`` entry is a non-empty string, that string is
    returned.  In every other case the fallback ``"root"`` is returned.
    """
    topology = snapshot.topology
    if not isinstance(topology, dict):
        return "root"

    configured = topology.get("mysql_root_password")
    return configured if isinstance(configured, str) and configured else "root"


class RewardGroundingCheck:
    """Confirm that each declared flag value is present where the snapshot says.

    Paths starting with ``db:`` or ``mysql:`` are parsed as
    ``<scheme>:<database>.<table>.<column>`` and read with the ``mysql``
    client on the flag's host.  Any other path containing ``/`` is read with
    ``cat``.  Everything else is reported as an unknown path format.
    """

    async def check(self, snapshot: SnapshotSpec, containers: ContainerSet) -> CheckResult:
        """Run the grounding check for every flag declared in *snapshot*.

        Returns a ``CheckResult`` named ``reward_grounding``.  The check
        fails immediately when the snapshot declares no flags; otherwise it
        passes only if no flag produced a failure record.  On that path
        ``details["results"]`` holds one dict per failing flag and
        ``details["total_flags"]`` the number of declared flags.
        """
        declared = snapshot.flags
        if not declared:
            return CheckResult(
                name="reward_grounding",
                passed=False,
                error="no flags defined in snapshot",
            )

        failures: list[dict] = []
        for flag in declared:
            host = flag.host
            path = flag.path

            # Decide which command reads the flag's storage location.
            if path.startswith(("db:", "mysql:")):
                db_ref = _parse_db_path(path)
                if db_ref is None:
                    # "db:sql" / "mysql:sql" denote deployment artifacts, not
                    # flag locations: skip them without recording a failure.
                    if path not in {"db:sql", "mysql:sql"}:
                        failures.append({
                            "flag": flag.id,
                            "error": f"invalid db flag path format: {path}",
                        })
                    continue
                tool = "mysql"
                command = self._mysql_command(db_ref, snapshot)
            elif "/" in path:
                tool = "cat"
                command = f"cat -- {shlex.quote(path)}"
            else:
                # Neither a DB reference nor something that looks like a
                # filesystem path.
                failures.append({
                    "flag": flag.id,
                    "error": f"unknown flag path format: {path}",
                })
                continue

            failure = await self._read_and_match(containers, host, flag, command, tool)
            if failure is not None:
                failures.append(failure)

        passed = not failures
        return CheckResult(
            name="reward_grounding",
            passed=passed,
            details={"results": failures, "total_flags": len(declared)},
            error="" if passed else f"{len(failures)} flag(s) not found at expected location",
        )

    @staticmethod
    def _mysql_command(db_ref: tuple[str, str, str], snapshot: SnapshotSpec) -> str:
        """Build the shell command that selects the flag column via ``mysql``.

        The identifiers in *db_ref* have already been validated by
        ``_parse_db_path``; the password and the query are shell-quoted.
        """
        database, table, column = db_ref
        # NOTE(review): LIMIT 1 means only a single row is compared against
        # the flag value -- confirm flag tables are expected to be one-row.
        query = f"SELECT `{column}` FROM `{database}`.`{table}` LIMIT 1"
        mysql_pwd = _mysql_root_password(snapshot)
        return (
            f"MYSQL_PWD={shlex.quote(mysql_pwd)} "
            f"mysql -u root -N -e {shlex.quote(query)}"
        )

    @staticmethod
    async def _read_and_match(
        containers: ContainerSet,
        host,
        flag,
        command: str,
        tool: str,
    ) -> dict | None:
        """Run *command* on *host* and look for ``flag.value`` in its output.

        Returns *None* when the value is found, otherwise a failure record:
        the exception text if ``exec_run`` raised, the command's combined
        output (or a generic "<tool> command failed" message) on a non-zero
        exit code, or the expected value plus a 200-character output snippet
        when the command succeeded but the value is absent.
        """
        try:
            result = await containers.exec_run(host, command)
        except Exception as exc:  # noqa: BLE001
            return {"flag": flag.id, "error": str(exc)}

        if result.exit_code != 0:
            return {
                "flag": flag.id,
                "error": (
                    result.combined_output
                    or f"{tool} command failed (exit_code={result.exit_code})"
                ),
            }

        # Prefer stdout; fall back to the combined output when stdout is
        # empty after stripping.
        output = result.stdout.strip() or result.combined_output.strip()
        if flag.value in output:
            return None
        return {
            "flag": flag.id,
            "expected": flag.value,
            "got_snippet": output[:200],
        }