Spaces:
Runtime error
Runtime error
File size: 5,341 Bytes
8c486a8 b439619 8c486a8 b439619 8c486a8 49d1c75 b439619 49d1c75 d2809e9 8c486a8 49d1c75 8c486a8 49d1c75 b439619 49d1c75 b439619 d2809e9 49d1c75 d2809e9 b439619 49d1c75 595e190 49d1c75 595e190 49d1c75 8c486a8 49d1c75 8c486a8 49d1c75 8c486a8 595e190 8c486a8 595e190 8c486a8 49d1c75 8c486a8 49d1c75 8c486a8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """Check 5: Reward grounding — verify flag values exist at expected paths."""
from __future__ import annotations
import re
import shlex
from open_range.protocols import CheckResult, ContainerSet, SnapshotSpec
_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _parse_db_path(path: str) -> tuple[str, str, str] | None:
    """Split a DB flag location of the form ``<scheme>:<database>.<table>.<column>``.

    ``<scheme>`` must be ``db`` or ``mysql``, and each of the three dotted
    components must be a plain identifier (ASCII letters, digits and
    underscore, not starting with a digit). The components are later
    interpolated into a backtick-quoted SQL query. This strict whitelist is
    what keeps that interpolation safe.

    Returns the ``(database, table, column)`` triple, or ``None`` when *path*
    does not have exactly that shape.
    """
    scheme, sep, rest = path.partition(":")
    if not sep or scheme not in {"db", "mysql"}:
        return None
    pieces = rest.split(".")
    if len(pieces) != 3 or any(_IDENTIFIER_RE.fullmatch(p) is None for p in pieces):
        return None
    database, table, column = pieces
    return database, table, column
def _mysql_root_password(snapshot: SnapshotSpec) -> str:
    """Pick the MySQL root password used for validator DB queries.

    Uses ``snapshot.topology["mysql_root_password"]`` when the topology is a
    dict and that entry is a non-empty string. Otherwise it falls back to
    ``"root"``.
    """
    topo = snapshot.topology
    candidate = topo.get("mysql_root_password") if isinstance(topo, dict) else None
    return candidate if isinstance(candidate, str) and candidate else "root"
class RewardGroundingCheck:
    """Verify that every declared flag value is actually present where it claims.

    File-based flags (any path containing ``/``) are read with ``cat`` on the
    flag's host. DB-stored flags (``db:<database>.<table>.<column>`` or
    ``mysql:<database>.<table>.<column>``) are read with the ``mysql`` client
    as root. The deployment-artifact paths ``db:sql`` / ``mysql:sql`` are
    skipped. Any other path shape is reported as a failure.
    """

    async def check(self, snapshot: SnapshotSpec, containers: ContainerSet) -> CheckResult:
        """Run the grounding check over ``snapshot.flags``.

        Returns a ``CheckResult`` named ``"reward_grounding"`` that passes only
        when every flag value was found. ``details["results"]`` holds one
        record per failing flag, in flag order, and ``details["total_flags"]``
        is the number of declared flags. A snapshot with no flags fails.
        """
        flags = snapshot.flags
        if not flags:
            return CheckResult(
                name="reward_grounding",
                passed=False,
                error="no flags defined in snapshot",
            )

        bad: list[dict] = []
        for flag in flags:
            problem = await self._check_flag(flag, snapshot, containers)
            if problem is not None:
                bad.append(problem)

        passed = not bad
        return CheckResult(
            name="reward_grounding",
            passed=passed,
            details={"results": bad, "total_flags": len(flags)},
            error="" if passed else f"{len(bad)} flag(s) not found at expected location",
        )

    async def _check_flag(self, flag, snapshot: SnapshotSpec, containers: ContainerSet) -> dict | None:
        """Return a failure record for *flag*, or ``None`` if it is grounded or skipped."""
        path = flag.path

        # --- DB-stored flags -------------------------------------------
        if path.startswith(("db:", "mysql:")):
            db_ref = _parse_db_path(path)
            if db_ref is None:
                # Deployment artifacts like "db:sql" are not flag locations.
                if path in {"db:sql", "mysql:sql"}:
                    return None
                return {
                    "flag": flag.id,
                    "error": f"invalid db flag path format: {path}",
                }
            # The password is looked up lazily, only when a DB flag is present.
            cmd = self._mysql_command(db_ref, _mysql_root_password(snapshot))
            return await self._run_and_match(flag, containers, cmd, "mysql")

        # --- Filesystem flags ------------------------------------------
        if "/" not in path:
            # Non-filesystem, non-DB flag path we don't understand.
            return {
                "flag": flag.id,
                "error": f"unknown flag path format: {path}",
            }
        cmd = f"cat -- {shlex.quote(path)}"
        return await self._run_and_match(flag, containers, cmd, "cat")

    @staticmethod
    def _mysql_command(db_ref: tuple[str, str, str], password: str) -> str:
        """Build the shell command that dumps one column of a table as root.

        The identifiers in *db_ref* have already been whitelisted by
        ``_parse_db_path``, so backtick-quoting them is sufficient.
        """
        database, table, column = db_ref
        # No LIMIT: the flag may live in any row. An unordered ``LIMIT 1``
        # returns an arbitrary row and produces false "not found" failures.
        query = f"SELECT `{column}` FROM `{database}`.`{table}`"
        return (
            f"MYSQL_PWD={shlex.quote(password)} "
            "mysql -u root -N "
            f"-e {shlex.quote(query)}"
        )

    @staticmethod
    async def _run_and_match(flag, containers: ContainerSet, cmd: str, tool: str) -> dict | None:
        """Run *cmd* on the flag's host and check that its output contains the flag value.

        *tool* names the command in the fallback error message. Returns
        ``None`` on success, otherwise a failure record.
        """
        try:
            result = await containers.exec_run(flag.host, cmd)
        except Exception as exc:  # noqa: BLE001
            return {"flag": flag.id, "error": str(exc)}
        if result.exit_code != 0:
            return {
                "flag": flag.id,
                "error": (
                    result.combined_output
                    or f"{tool} command failed (exit_code={result.exit_code})"
                ),
            }
        output = result.stdout.strip() or result.combined_output.strip()
        if flag.value not in output:
            return {
                "flag": flag.id,
                "expected": flag.value,
                "got_snippet": output[:200],
            }
        return None
|