Spaces:
Running on Zero
Running on Zero
| """Answer verification for SQLEnv using type-aware comparisons.""" | |
| from __future__ import annotations | |
| import ast | |
| import math | |
| import re | |
| # Dual-path import shim (package vs flat execution), mirroring | |
| # sql_environment.py's strip_chart_block import. chart_intent is a leaf | |
| # (json/re/pydantic only — no verifier import, no gradio/torch), so this is | |
| # leaf->leaf with no import cycle and no heavy deps pulled in. | |
| try: | |
| from .chart_intent import strip_chart_block | |
| except ImportError: # flat execution | |
| from chart_intent import strip_chart_block # type: ignore[no-redef] | |
def _strip_answer_wrapping(text: str) -> str:
    """Peel LLM presentation artifacts off an answer so comparators see the value.

    Three layers are removed, in this order:

    1. A markdown code fence wrapping the whole answer (```...``` or
       ```lang\\n...\\n```); a language tag is only consumed when a newline
       immediately follows it.
    2. A leading "Answer:" / "answer:" label plus any whitespace after it.
    3. One pair of matching surrounding quotes (single or double).

    Leading/trailing whitespace is trimmed before the fence check, after
    unwrapping a fence, and after removing quotes.
    """
    cleaned = text.strip()

    fenced = cleaned.startswith("```") and cleaned.endswith("```")
    if fenced:
        without_open = re.sub(r"^```(?:\w+\n|\n?)", "", cleaned)
        cleaned = re.sub(r"\n?```$", "", without_open).strip()

    cleaned = re.sub(r"^[Aa]nswer:\s*", "", cleaned)

    quoted = (
        len(cleaned) >= 2
        and cleaned[0] in ('"', "'")
        and cleaned[-1] == cleaned[0]
    )
    return cleaned[1:-1].strip() if quoted else cleaned
| def _gold_is_degenerate(gold_text: str, gold_rows: list[tuple] | None) -> bool: | |
| """Return True when the gold target is empty/NULL (no legitimate pass exists). | |
| Last-line complement to the env's upstream ``is_degenerate_gold`` HarnessError | |
| in ``reset`` — re-expressed locally so ``verifier`` stays a leaf (no env import). | |
| Degenerate when: | |
| - ``gold_rows`` is an EXPLICIT empty list (``[]``) — an empty result set is | |
| never a legitimate gold target in this corpus, regardless of ``gold_text``. | |
| This closes the standalone hole where ``verify_answer(",", x, | |
| answer_type="list", gold_rows=[])`` would otherwise score a comparator-empty | |
| prediction as a pass (F007: ``verify_answer`` is a directly-callable leaf, | |
| so its empty-gold contract must hold even though the live env never passes | |
| ``gold_rows=[]``), OR | |
| - ``gold_rows`` is falsy/``None`` AND ``gold_text`` is empty/whitespace, OR | |
| - ``gold_rows`` is a single all-NULL row (``[(None,)]`` / ``[(None, None)]``) | |
| — mirrors ``sql_environment.is_degenerate_gold``'s all-NULL arm | |
| (e.g. ``MAX()``/``MIN()`` over an empty set), regardless of column count, OR | |
| - ``gold_text`` is a ``[``-prefixed literal that parses to an empty list | |
| (or a list of empty/whitespace-only items). | |
| A malformed ``[``-prefixed literal is caught and treated as NON-degenerate | |
| (falls through to normal dispatch), matching the parsing pattern elsewhere. | |
| """ | |
| text = gold_text.strip() | |
| if gold_rows is not None and len(gold_rows) == 0: | |
| return True | |
| if not gold_rows and not text: | |
| return True | |
| if gold_rows and len(gold_rows) == 1 and all(cell is None for cell in gold_rows[0]): | |
| return True | |
| if text.startswith("["): | |
| try: | |
| parsed = ast.literal_eval(text) | |
| except (ValueError, SyntaxError): | |
| return False | |
| if isinstance(parsed, list): | |
| return all(not str(item).strip() for item in parsed) | |
| return False | |
def verify_answer(
    predicted: str,
    gold: str,
    answer_type: str | None = None,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Decide whether ``predicted`` matches ``gold`` using type-aware rules.

    Args:
        predicted: The submitted answer. ``None`` is treated as "".
        gold: The reference answer text. ``None`` is treated as "".
        answer_type: Selects the comparator: "integer", "float", "list",
            "table", or "string". Any other value (including ``None``) falls
            back to normalized string comparison.
        gold_rows: Optional raw gold result rows. The list and table
            comparators use these instead of parsing ``gold`` when given.

    Returns:
        ``True`` only on a match. The result is always ``False`` in two cases:
        the predicted answer is empty after unwrapping, or the gold is
        degenerate (see ``_gold_is_degenerate``).
    """
    pred = "" if predicted is None else str(predicted)
    reference = "" if gold is None else str(gold)

    # Strip a fenced chart block here so direct callers get the same scoring
    # as the env, which pre-strips. Only a ```chart marker triggers this. A
    # plain ```sql / ``` fence must survive for _strip_answer_wrapping to
    # unwrap, because strip_chart_block's orphan-fence scrub would remove
    # only the closing fence.
    if "```chart" in pred.lower():
        pred = strip_chart_block(pred)
    pred = _strip_answer_wrapping(pred)

    if not pred.strip():
        return False
    # An empty/NULL gold is never a legitimate correct answer in this corpus.
    if _gold_is_degenerate(reference, gold_rows):
        return False

    if answer_type == "integer":
        return _compare_integer(pred, reference)
    if answer_type == "float":
        return _compare_float(pred, reference)
    if answer_type == "list":
        return _compare_list(pred, reference, gold_rows)
    if answer_type == "table":
        return _compare_table(pred, reference, gold_rows)
    # "string" and any unrecognized type share normalized string comparison.
    return _compare_string(pred, reference)
| def _numeric_canonical(value: str) -> str | None: | |
| """Return a canonical numeric string for ``value``, or None if it must stay text. | |
| Collapses numerically-equal forms so set/sorted cell comparison treats them | |
| as one element: "3"/"3.0"/"3.00"/"3e0" -> "3"; "1.5"/"1.50" -> "1.5". | |
| Returns None (caller falls back to text normalization) when ``value``: | |
| - is empty/whitespace, | |
| - is a leading-zero integer ("01234", "+0123") — protects zip/ID codes | |
| (BUT "0", "0.0", "0.5", "-0.5" are valid numbers, not leading-zero), | |
| - is non-finite ("nan", "inf", "-inf", "Infinity") — these are NOT | |
| canonicalizable numbers, so they fall back to string comparison and | |
| never reach ``int(f)`` (which would raise ValueError/OverflowError — C1), | |
| - is not a recognised number shape ("$5", "1,000", "1_000", "abc", "Feil"). | |
| Note the regex rejects underscores, so Python's "1_000" int literal stays | |
| a string and never collapses to "1000" (C4). | |
| Integer-shaped input is canonicalized with EXACT ``int()`` (unbounded), so two | |
| distinct 16+ digit IDs never collapse via float53 rounding (C2: | |
| "9999999999999999" stays itself, not "10000000000000000"). Decimal/exponent | |
| forms use ``float`` then a fixed 12-significant-figure format (``:.12g``) to | |
| kill float-repr noise (C3: 0.1+0.2 == "0.3"); a whole-valued float collapses | |
| to its int string ("3.0"->"3", "476090.0"->"476090"). | |
| Note: "1e3" DOES parse and canonicalizes to "1000" (numerically correct; | |
| documented + covered by a test). | |
| """ | |
| s = str(value).strip() | |
| if not s: | |
| return None | |
| # Reject leading-zero integers ("01234") and "+0123" to protect zip/ID codes, | |
| # but allow "0", "0.0", "0.5" (a single 0 before a "." is a real number). | |
| body = s[1:] if s[:1] in "+-" else s | |
| if len(body) > 1 and body[0] == "0" and body[1] != ".": | |
| return None | |
| # Integer-shaped: canonicalize with EXACT, unbounded int() so large IDs never | |
| # collapse via float53 rounding (C2). The leading-zero guard above already | |
| # rejected "01234"/"+0123", so these stay strings. | |
| if re.fullmatch(r"[+-]?\d+", s): | |
| return str(int(s)) | |
| # Decimal/exponent forms only (must contain a ".", "e", or "E"): float-parse, | |
| # then fix to 12 significant figures to kill repr noise (C3). A non-finite | |
| # parse (nan/inf) is NOT a canonicalizable number — return None (C1). | |
| if re.fullmatch(r"[+-]?(\d+\.\d*|\.\d+|\d+)([eE][+-]?\d+)?", s) and any( | |
| c in s for c in ".eE" | |
| ): | |
| f = float(s) | |
| if not math.isfinite(f): | |
| return None | |
| if f == int(f): | |
| return str(int(f)) | |
| return f"{f:.12g}" | |
| return None | |
# Non-finite numeric tokens: every spelling ``float()`` accepts for NaN and
# infinity, in the lower-cased, whitespace-collapsed form produced by
# ``_normalize_value``. That is "nan", "inf" and "infinity", each optionally
# signed. ``float("-nan")`` parses too, so the signed NaN forms are included.
#
# A NaN/inf gold is a degenerate/unusable target for this oracle, and IEEE NaN
# is never equal to itself. So a predicted "nan"/"inf" must NEVER score as a
# pass, even against a NaN/inf gold cell that stringifies the same way.
# ``_normalize_value`` replaces these with a per-occurrence sentinel so two
# non-finite cells never compare equal.
_NONFINITE_TOKENS = frozenset(
    {
        "nan",
        "+nan",
        "-nan",
        "inf",
        "+inf",
        "-inf",
        "infinity",
        "+infinity",
        "-infinity",
    }
)

# Monotonic counter that makes each non-finite normalization unique. The
# counter value is run-dependent, but the comparison OUTCOME is deterministic:
# a non-finite cell never equals any other cell (including another "nan"),
# exactly as with IEEE NaN-vs-NaN. This is a module-level int that
# ``_normalize_value`` increments through a ``global`` statement.
_nonfinite_counter = 0
def _normalize_value(value: str) -> str:
    """Map a cell or answer to the canonical string used for equality checks.

    This is the one normalization seam shared by ``_parse_list_values``,
    ``_parse_table_rows`` and ``_compare_string``:

    - A value ``_numeric_canonical`` accepts is returned in its canonical
      numeric form, so "3" and "3.0" compare equal.
    - Anything else is lower-cased with whitespace runs collapsed to single
      spaces. This includes leading-zero codes and currency strings, which
      ``_numeric_canonical`` rejects.
    - A value that folds to a non-finite token ("nan", "inf", ...) is replaced
      by a fresh, unique sentinel, so it never equals anything. This mirrors
      IEEE NaN and refuses to score a NaN/inf gold as a pass (C1).

    ``None`` is treated as "".
    """
    global _nonfinite_counter

    raw = "" if value is None else str(value)

    numeric = _numeric_canonical(raw)
    if numeric is not None:
        return numeric

    folded = " ".join(raw.strip().lower().split())
    if folded not in _NONFINITE_TOKENS:
        return folded

    _nonfinite_counter += 1
    return f"\x00nonfinite:{_nonfinite_counter}"
| def _compare_integer(predicted: str, gold: str) -> bool: | |
| """Compare integer values after coercing with ``int(float(x))``.""" | |
| try: | |
| return int(float(predicted)) == int(float(gold)) | |
| except (TypeError, ValueError): | |
| return False | |
def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool:
    """Compare float values using a relative tolerance.

    Args:
        predicted: Submitted answer text, parsed with ``float()``.
        gold: Reference answer text, parsed with ``float()``.
        tolerance: Allowed relative error as a fraction of ``|gold|``
            (default 0.01, i.e. 1%).

    Returns:
        True when ``|predicted - gold| <= tolerance * |gold|``. For a zero
        gold the bound is an absolute ``1e-9`` instead. False when either
        side is unparseable or non-finite (NaN/±inf).
    """
    try:
        predicted_value = float(predicted)
        gold_value = float(gold)
    except (TypeError, ValueError):
        return False
    # Non-finite values never match (C1). Without this guard an infinite gold
    # would accept ANY finite prediction: |x - inf| == inf and
    # tolerance * |inf| == inf, so ``inf <= inf`` holds.
    if not (math.isfinite(predicted_value) and math.isfinite(gold_value)):
        return False
    if gold_value == 0.0:
        return abs(predicted_value - gold_value) <= 1e-9
    return abs(predicted_value - gold_value) <= tolerance * abs(gold_value)
def _compare_string(predicted: str, gold: str) -> bool:
    """Return True when both values normalize to the same string via ``_normalize_value``."""
    left = _normalize_value(predicted)
    right = _normalize_value(gold)
    return left == right
def _parse_list_values(raw: str) -> set[str]:
    """Parse comma/newline/pipe-separated values into a normalized set.

    Handles plain delimited strings and Python list representations:
        "121\\n111\\n171"         -> {"121", "111", "171"}
        "[121, 111, 171]"        -> {"121", "111", "171"}
        "['Feil', 'Fisher']"     -> {"feil", "fisher"}

    Blank items are dropped and every item goes through ``_normalize_value``.
    Some ``[``-prefixed text cannot be evaluated by ``ast.literal_eval``,
    either because it is malformed or because it is a literal the function
    cannot build (e.g. "[{[1]}]" raises TypeError). Such text is split on
    delimiters instead of raising.
    """
    text = raw.strip()
    # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher'])
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
            if isinstance(parsed, list):
                return {
                    _normalize_value(str(item)) for item in parsed if str(item).strip()
                }
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on untrusted (model) input;
            # fall back to delimiter splitting rather than crash verification.
            pass
    tokens = re.split(r"\s*(?:,|\n|\|)\s*", text)
    normalized = {_normalize_value(token) for token in tokens if token.strip()}
    return normalized
def _compare_list(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Return True when predicted and gold contain the same set of values.

    Order and duplicates are ignored. When ``gold_rows`` is provided, the gold
    set is built from every non-blank cell of every row (stringified, then
    normalized) and ``gold`` is not consulted. Otherwise ``gold`` is parsed
    exactly like the prediction.
    """
    predicted_set = _parse_list_values(predicted)
    if gold_rows is None:
        return predicted_set == _parse_list_values(gold)

    gold_set = set()
    for row in gold_rows:
        for cell in row:
            cell_text = str(cell)
            if cell_text.strip():
                gold_set.add(_normalize_value(cell_text))
    return predicted_set == gold_set
def _parse_table_rows(raw: str) -> list[tuple[str, ...]]:
    """Parse a table answer string into normalized rows.

    Supports formats:
    - Pipe-separated rows: "111 | 1\\n121 | 2"
    - Python list-of-lists: "[[111, 1], [121, 2]]". A bare scalar element
      (e.g. each item of "[111, 121]") becomes a single-cell row instead of
      being dropped.
    - Numbered rows: "1. 111 | 1\\n2. 121 | 2"

    Every cell goes through ``_normalize_value``. In the line-based format,
    blank lines and lines whose cells are all empty are skipped. Returns []
    for blank input. A ``[``-prefixed string that ``ast.literal_eval`` cannot
    evaluate falls back to line parsing instead of raising.
    """
    text = raw.strip()
    if not text:
        return []
    # Try Python literal (list-of-lists from gold_answer storage)
    if text.startswith("["):
        try:
            parsed = ast.literal_eval(text)
            if isinstance(parsed, list):
                # Scalars become 1-cell rows. Dropping them would turn a flat
                # list into [], which then falsely equals an empty prediction.
                return [
                    tuple(_normalize_value(str(cell)) for cell in row)
                    if isinstance(row, (list, tuple))
                    else (_normalize_value(str(row)),)
                    for row in parsed
                ]
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # literal_eval can raise any of these on untrusted input.
            pass
    rows = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        # Strip leading numbering ("1. value | value"), but not when the dot
        # is a decimal point: "3.14 | x" must keep its first cell intact.
        line = re.sub(r"^\d+\.(?!\d)\s*", "", line)
        cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)]
        if any(cells):
            rows.append(tuple(cells))
    return rows
def _compare_table(
    predicted: str,
    gold: str,
    gold_rows: list[tuple] | None = None,
) -> bool:
    """Return True when predicted and gold hold the same rows, ignoring row order.

    Rows are compared as multisets: both sides are sorted, so a duplicated row
    must appear the same number of times on each side. Cell order within a
    row still matters. When ``gold_rows`` is provided it supplies the gold
    side (each cell stringified and normalized) and ``gold`` is ignored.
    """
    predicted_sorted = sorted(_parse_table_rows(predicted))
    if gold_rows is None:
        expected_sorted = sorted(_parse_table_rows(gold))
    else:
        expected_sorted = sorted(
            tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows
        )
    return predicted_sorted == expected_sorted