analyst-buddy / server /verifier.py
hjerpe's picture
F006/F008: serve Qwen models + model switcher (vanilla-first)
656f91e verified
Raw
History Blame Contribute Delete
14.2 kB
"""Answer verification for SQLEnv using type-aware comparisons."""
from __future__ import annotations
import ast
import math
import re
# Dual-path import shim (package vs flat execution), mirroring
# sql_environment.py's strip_chart_block import. chart_intent is a leaf
# (json/re/pydantic only — no verifier import, no gradio/torch), so this is
# leaf->leaf with no import cycle and no heavy deps pulled in.
try:
from .chart_intent import strip_chart_block
except ImportError: # flat execution
from chart_intent import strip_chart_block # type: ignore[no-redef]
def _strip_answer_wrapping(text: str) -> str:
"""Remove common LLM wrapping artifacts from an answer string.
Strips markdown code fences, surrounding quotes, "Answer: " prefix,
and extra whitespace so the type-aware comparators see clean values.
"""
s = text.strip()
# Markdown code blocks: ```...``` or ```sql\n...\n```
if s.startswith("```") and s.endswith("```"):
# Language tag only if followed by newline (e.g. ```sql\n)
s = re.sub(r"^```(?:\w+\n|\n?)", "", s)
s = re.sub(r"\n?```$", "", s)
s = s.strip()
# "Answer: " or "answer:" prefix
s = re.sub(r"^[Aa]nswer:\s*", "", s)
# Surrounding quotes (single or double)
if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"):
s = s[1:-1].strip()
return s
def _gold_is_degenerate(gold_text: str, gold_rows: list[tuple] | None) -> bool:
"""Return True when the gold target is empty/NULL (no legitimate pass exists).
Last-line complement to the env's upstream ``is_degenerate_gold`` HarnessError
in ``reset`` — re-expressed locally so ``verifier`` stays a leaf (no env import).
Degenerate when:
- ``gold_rows`` is an EXPLICIT empty list (``[]``) — an empty result set is
never a legitimate gold target in this corpus, regardless of ``gold_text``.
This closes the standalone hole where ``verify_answer(",", x,
answer_type="list", gold_rows=[])`` would otherwise score a comparator-empty
prediction as a pass (F007: ``verify_answer`` is a directly-callable leaf,
so its empty-gold contract must hold even though the live env never passes
``gold_rows=[]``), OR
- ``gold_rows`` is falsy/``None`` AND ``gold_text`` is empty/whitespace, OR
- ``gold_rows`` is a single all-NULL row (``[(None,)]`` / ``[(None, None)]``)
— mirrors ``sql_environment.is_degenerate_gold``'s all-NULL arm
(e.g. ``MAX()``/``MIN()`` over an empty set), regardless of column count, OR
- ``gold_text`` is a ``[``-prefixed literal that parses to an empty list
(or a list of empty/whitespace-only items).
A malformed ``[``-prefixed literal is caught and treated as NON-degenerate
(falls through to normal dispatch), matching the parsing pattern elsewhere.
"""
text = gold_text.strip()
if gold_rows is not None and len(gold_rows) == 0:
return True
if not gold_rows and not text:
return True
if gold_rows and len(gold_rows) == 1 and all(cell is None for cell in gold_rows[0]):
return True
if text.startswith("["):
try:
parsed = ast.literal_eval(text)
except (ValueError, SyntaxError):
return False
if isinstance(parsed, list):
return all(not str(item).strip() for item in parsed)
return False
def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Decide whether ``predicted`` matches ``gold`` using a type-aware comparator.

    Steps:

    1. Coerce both answers to ``str``; ``None`` becomes "".
    2. If the prediction contains a ```chart fence, scrub chart blocks from it
       with ``strip_chart_block``.
    3. Strip LLM wrapping (code fences, "Answer:" label, quotes).
    4. Reject a blank prediction.
    5. Reject a degenerate (empty/NULL) gold.
    6. Dispatch on ``answer_type``: "integer", "float", "list" and "table" get
       dedicated comparators. "string", ``None`` and any unrecognized value use
       normalized string equality.

    Args:
        predicted: The submitted answer text.
        gold: The gold answer text.
        answer_type: Optional comparator selector (see above).
        gold_rows: Raw gold result rows. These are consulted by the degeneracy
            check and by the list/table comparators.

    Returns:
        True only when the prediction is judged correct.
    """
    candidate = str(predicted) if predicted is not None else ""
    target = str(gold) if gold is not None else ""

    # Scrub chart blocks here so direct callers get correct scoring without the
    # env's pre-strip. Only do it when a ```chart marker is present: on a plain
    # ```sql / ``` fence, strip_chart_block's orphan-fence scrub would remove
    # only the closing fence and break the unwrap below.
    if "```chart" in candidate.lower():
        candidate = strip_chart_block(candidate)
    candidate = _strip_answer_wrapping(candidate)

    if not candidate.strip():
        return False
    # An empty/NULL result is never a legitimate gold target in this corpus.
    if _gold_is_degenerate(target, gold_rows):
        return False

    if answer_type == "integer":
        return _compare_integer(candidate, target)
    if answer_type == "float":
        return _compare_float(candidate, target)
    if answer_type == "list":
        return _compare_list(candidate, target, gold_rows)
    if answer_type == "table":
        return _compare_table(candidate, target, gold_rows)
    # "string" and every other/unknown type share normalized-text equality.
    return _compare_string(candidate, target)
def _numeric_canonical(value: str) -> str | None:
"""Return a canonical numeric string for ``value``, or None if it must stay text.
Collapses numerically-equal forms so set/sorted cell comparison treats them
as one element: "3"/"3.0"/"3.00"/"3e0" -> "3"; "1.5"/"1.50" -> "1.5".
Returns None (caller falls back to text normalization) when ``value``:
- is empty/whitespace,
- is a leading-zero integer ("01234", "+0123") — protects zip/ID codes
(BUT "0", "0.0", "0.5", "-0.5" are valid numbers, not leading-zero),
- is non-finite ("nan", "inf", "-inf", "Infinity") — these are NOT
canonicalizable numbers, so they fall back to string comparison and
never reach ``int(f)`` (which would raise ValueError/OverflowError — C1),
- is not a recognised number shape ("$5", "1,000", "1_000", "abc", "Feil").
Note the regex rejects underscores, so Python's "1_000" int literal stays
a string and never collapses to "1000" (C4).
Integer-shaped input is canonicalized with EXACT ``int()`` (unbounded), so two
distinct 16+ digit IDs never collapse via float53 rounding (C2:
"9999999999999999" stays itself, not "10000000000000000"). Decimal/exponent
forms use ``float`` then a fixed 12-significant-figure format (``:.12g``) to
kill float-repr noise (C3: 0.1+0.2 == "0.3"); a whole-valued float collapses
to its int string ("3.0"->"3", "476090.0"->"476090").
Note: "1e3" DOES parse and canonicalizes to "1000" (numerically correct;
documented + covered by a test).
"""
s = str(value).strip()
if not s:
return None
# Reject leading-zero integers ("01234") and "+0123" to protect zip/ID codes,
# but allow "0", "0.0", "0.5" (a single 0 before a "." is a real number).
body = s[1:] if s[:1] in "+-" else s
if len(body) > 1 and body[0] == "0" and body[1] != ".":
return None
# Integer-shaped: canonicalize with EXACT, unbounded int() so large IDs never
# collapse via float53 rounding (C2). The leading-zero guard above already
# rejected "01234"/"+0123", so these stay strings.
if re.fullmatch(r"[+-]?\d+", s):
return str(int(s))
# Decimal/exponent forms only (must contain a ".", "e", or "E"): float-parse,
# then fix to 12 significant figures to kill repr noise (C3). A non-finite
# parse (nan/inf) is NOT a canonicalizable number — return None (C1).
if re.fullmatch(r"[+-]?(\d+\.\d*|\.\d+|\d+)([eE][+-]?\d+)?", s) and any(
c in s for c in ".eE"
):
f = float(s)
if not math.isfinite(f):
return None
if f == int(f):
return str(int(f))
return f"{f:.12g}"
return None
# Non-finite numeric tokens (the string forms ``float()`` accepts). A NaN/inf
# gold is a degenerate/unusable target for this oracle, and IEEE NaN is never
# equal to itself — so a predicted "nan"/"inf" must NEVER score as a pass, even
# against a NaN/inf gold cell that stringifies the same way. ``_normalize_value``
# poisons these to a per-occurrence sentinel so two non-finite cells never
# compare equal.
_NONFINITE_TOKENS = frozenset(
{"nan", "inf", "-inf", "+inf", "infinity", "-infinity", "+infinity"}
)
# Monotonic counter making each non-finite normalization unique. The COUNTER
# value is run-dependent, but the comparison OUTCOME is deterministic: a
# non-finite cell never equals any other cell (including another "nan"), exactly
# as IEEE NaN-vs-NaN. Module-level mutable int guarded by ``global``.
_nonfinite_counter = 0
def _normalize_value(value: str) -> str:
    """Map a raw answer/cell string onto the key used for equality checks.

    Resolution order:

    1. If ``_numeric_canonical`` recognizes a safe number, return its canonical
       form, so "3" and "3.0" compare equal.
    2. Otherwise lowercase and collapse internal whitespace runs to one space.
    3. If that text is a non-finite token ("nan", "inf", ...), return a fresh,
       never-repeated sentinel instead. Such a cell therefore matches nothing,
       mirroring IEEE NaN semantics and refusing to score a NaN/inf gold as a
       pass (C1).

    String, list and table comparison all route cells through this one
    function, so the numeric rule needs no changes to the set/sorted/multiset
    logic. Leading-zero and currency strings stay text because
    ``_numeric_canonical`` returns None for them. ``None`` is treated as "".
    """
    global _nonfinite_counter

    raw = str(value) if value is not None else ""
    numeric = _numeric_canonical(raw)
    if numeric is not None:
        return numeric

    folded = " ".join(raw.strip().lower().split())
    if folded not in _NONFINITE_TOKENS:
        return folded

    _nonfinite_counter += 1
    return f"\x00nonfinite:{_nonfinite_counter}"
def _compare_integer(predicted: str, gold: str) -> bool:
"""Compare integer values after coercing with ``int(float(x))``."""
try:
return int(float(predicted)) == int(float(gold))
except (TypeError, ValueError):
return False
def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
"""Compare float values using a relative tolerance."""
try:
predicted_value = float(predicted)
gold_value = float(gold)
except (TypeError, ValueError):
return False
if gold_value == 0.0:
return abs(predicted_value - gold_value) <= 1e-9
return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)
def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both strings reduce to the same comparison key.

    Both sides go through ``_normalize_value``, giving case- and
    whitespace-insensitive text equality. Safe numeric forms are unified
    ("3" == "3.0"), and non-finite tokens never match.
    """
    left = _normalize_value(predicted)
    right = _normalize_value(gold)
    return left == right
def _parse_list_values(raw: str) -> set[str]:
    """Parse a list-shaped answer into a set of normalized values.

    Accepts Python list literals and comma/newline/pipe-delimited text:
        "121\\n111\\n171"      -> {"121", "111", "171"}
        "[121, 111, 171]"     -> {"121", "111", "171"}
        "['Feil', 'Fisher']"  -> {"feil", "fisher"}

    Blank items are dropped, and each remaining item goes through
    ``_normalize_value``.

    A ``[``-prefixed string that ``ast.literal_eval`` rejects falls back to
    delimiter splitting. ``raw`` is typically model output, so every exception
    ``literal_eval`` documents for malformed input is treated as "not a
    literal": ValueError, SyntaxError, TypeError (e.g. "[{[1]: 2}]"),
    MemoryError and RecursionError. Previously TypeError and the others escaped
    and crashed the verifier.
    """
    text = raw.strip()
    # Try a Python literal first (e.g. [121, 111] or ['Feil', 'Fisher']).
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
            if isinstance(parsed, list):
                return {
                    _normalize_value(str(item)) for item in parsed if str(item).strip()
                }
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass  # Not a usable literal: fall back to delimiter splitting.
    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    return {_normalize_value(token) for token in tokens if token.strip()}
def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Compare list answers as order-insensitive sets of normalized values.

    The prediction is always parsed from text. When ``gold_rows`` is given,
    the expected set is every non-blank cell of every row, flattened.
    Otherwise the gold text is parsed the same way as the prediction.
    Duplicates collapse because both sides are sets.
    """
    candidate_values = _parse_list_values(predicted)

    if gold_rows is None:
        expected_values = _parse_list_values(gold)
    else:
        expected_values = set()
        for row in gold_rows:
            for cell in row:
                cell_text = str(cell)
                if cell_text.strip():
                    expected_values.add(_normalize_value(cell_text))

    return candidate_values == expected_values
def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into rows of normalized cells.

    Supported formats:
      - Pipe-separated rows: "111 | 1\\n121 | 2"
      - Python list-of-lists: "[[111, 1], [121, 2]]"
      - Numbered rows: "1. 111 | 1\\n2. 121 | 2"

    In the list-literal form, an item that is not itself a list/tuple becomes
    a single-cell row instead of being silently dropped. Dropping it let
    "[[1, 2], 3]" match the one-row gold [[1, 2]]; it also makes a flat
    "[1, 2]" read as a one-column table.

    A numbering prefix is only stripped when "N." is followed by whitespace.
    The old ``^\\d+\\.\\s*`` also ate the integer part of a leading decimal
    cell, turning "3.5 | x" into "5 | x".

    A ``[``-prefixed string that ``ast.literal_eval`` rejects falls back to
    line parsing. The rejection covers every exception ``literal_eval``
    documents: ValueError, SyntaxError, TypeError, MemoryError and
    RecursionError. In the line form, rows whose cells are all blank after
    normalization are skipped.
    """
    text = raw.strip()
    if not text:
        return []
    # Try a Python literal first (list-of-lists from gold_answer storage).
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
            if isinstance(parsed, list):
                return [
                    tuple(_normalize_value(str(cell)) for cell in row)
                    if isinstance(row, (list, tuple))
                    else (_normalize_value(str(row)),)
                    for row in parsed
                ]
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            pass  # Not a usable literal: fall back to line parsing.
    rows = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading numbering "1. value | value", but never the integer
        # part of a decimal such as "3.5".
        line = re.sub(r"^\d+\.\s+", "", line)
        cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)]
        if any(cells):
            rows.append(tuple(cells))
    return rows
def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Compare table answers as multisets of normalized rows.

    The prediction is parsed from text. The gold side comes from
    ``gold_rows`` (every cell stringified and normalized) when provided, and
    otherwise from parsing the gold text. Both row lists are sorted before the
    equality check, so row order is ignored but duplicate rows still count.
    """
    candidate = sorted(_parse_table_rows(predicted))

    if gold_rows is None:
        reference = sorted(_parse_table_rows(gold))
    else:
        reference = sorted(
            tuple(_normalize_value(str(value)) for value in record)
            for record in gold_rows
        )

    return candidate == reference