"""Answer verification for SQLEnv using type-aware comparisons.""" from __future__ import annotations import ast import math import re # Dual-path import shim (package vs flat execution), mirroring # sql_environment.py's strip_chart_block import. chart_intent is a leaf # (json/re/pydantic only — no verifier import, no gradio/torch), so this is # leaf->leaf with no import cycle and no heavy deps pulled in. try: from .chart_intent import strip_chart_block except ImportError: # flat execution from chart_intent import strip_chart_block # type: ignore[no-redef] def _strip_answer_wrapping(text: str) -> str: """Remove common LLM wrapping artifacts from an answer string. Strips markdown code fences, surrounding quotes, "Answer: " prefix, and extra whitespace so the type-aware comparators see clean values. """ s = text.strip() # Markdown code blocks: ```...``` or ```sql\n...\n``` if s.startswith("```") and s.endswith("```"): # Language tag only if followed by newline (e.g. ```sql\n) s = re.sub(r"^```(?:\w+\n|\n?)", "", s) s = re.sub(r"\n?```$", "", s) s = s.strip() # "Answer: " or "answer:" prefix s = re.sub(r"^[Aa]nswer:\s*", "", s) # Surrounding quotes (single or double) if len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"): s = s[1:-1].strip() return s def _gold_is_degenerate(gold_text: str, gold_rows: list[tuple] | None) -> bool: """Return True when the gold target is empty/NULL (no legitimate pass exists). Last-line complement to the env's upstream ``is_degenerate_gold`` HarnessError in ``reset`` — re-expressed locally so ``verifier`` stays a leaf (no env import). Degenerate when: - ``gold_rows`` is an EXPLICIT empty list (``[]``) — an empty result set is never a legitimate gold target in this corpus, regardless of ``gold_text``. This closes the standalone hole where ``verify_answer(",", x, answer_type="list", gold_rows=[])`` would otherwise score a comparator-empty prediction as a pass (F007: ``verify_answer`` is a directly-callable leaf, so its empty-gold contract must hold even though the live env never passes ``gold_rows=[]``), OR - ``gold_rows`` is falsy/``None`` AND ``gold_text`` is empty/whitespace, OR - ``gold_rows`` is a single all-NULL row (``[(None,)]`` / ``[(None, None)]``) — mirrors ``sql_environment.is_degenerate_gold``'s all-NULL arm (e.g. ``MAX()``/``MIN()`` over an empty set), regardless of column count, OR - ``gold_text`` is a ``[``-prefixed literal that parses to an empty list (or a list of empty/whitespace-only items). A malformed ``[``-prefixed literal is caught and treated as NON-degenerate (falls through to normal dispatch), matching the parsing pattern elsewhere. """ text = gold_text.strip() if gold_rows is not None and len(gold_rows) == 0: return True if not gold_rows and not text: return True if gold_rows and len(gold_rows) == 1 and all(cell is None for cell in gold_rows[0]): return True if text.startswith("["): try: parsed = ast.literal_eval(text) except (ValueError, SyntaxError): return False if isinstance(parsed, list): return all(not str(item).strip() for item in parsed) return False def verify_answer( predicted: str, gold: str, answer_type: str | None = None, gold_rows: list[tuple] | None = None, ) -> bool: """Compare submitted and gold answers with type-aware dispatch.""" predicted_text = "" if predicted is None else str(predicted) gold_text = "" if gold is None else str(gold) # Self-sufficient scrubbing: strip any fenced chart block from the predicted # answer BEFORE type dispatch, so scoring is correct even when verify_answer # is called directly without the env's pre-strip. Gated on the ```chart``` # marker so a plain ```sql``` / ``` code fence is left for # _strip_answer_wrapping to unwrap (strip_chart_block's orphan-fence scrub # would otherwise remove only the closing fence and break the unwrap). if "```chart" in predicted_text.lower(): predicted_text = strip_chart_block(predicted_text) predicted_text = _strip_answer_wrapping(predicted_text) if not predicted_text.strip(): return False # Empty/NULL-gold guard: an empty result is never a legitimate correct # answer in this corpus, so never score a degenerate gold as a pass. if _gold_is_degenerate(gold_text, gold_rows): return False match answer_type: case "integer": return _compare_integer(predicted_text, gold_text) case "float": return _compare_float(predicted_text, gold_text) case "list": return _compare_list(predicted_text, gold_text, gold_rows) case "table": return _compare_table(predicted_text, gold_text, gold_rows) case "string": return _compare_string(predicted_text, gold_text) case _: return _compare_string(predicted_text, gold_text) def _numeric_canonical(value: str) -> str | None: """Return a canonical numeric string for ``value``, or None if it must stay text. Collapses numerically-equal forms so set/sorted cell comparison treats them as one element: "3"/"3.0"/"3.00"/"3e0" -> "3"; "1.5"/"1.50" -> "1.5". Returns None (caller falls back to text normalization) when ``value``: - is empty/whitespace, - is a leading-zero integer ("01234", "+0123") — protects zip/ID codes (BUT "0", "0.0", "0.5", "-0.5" are valid numbers, not leading-zero), - is non-finite ("nan", "inf", "-inf", "Infinity") — these are NOT canonicalizable numbers, so they fall back to string comparison and never reach ``int(f)`` (which would raise ValueError/OverflowError — C1), - is not a recognised number shape ("$5", "1,000", "1_000", "abc", "Feil"). Note the regex rejects underscores, so Python's "1_000" int literal stays a string and never collapses to "1000" (C4). Integer-shaped input is canonicalized with EXACT ``int()`` (unbounded), so two distinct 16+ digit IDs never collapse via float53 rounding (C2: "9999999999999999" stays itself, not "10000000000000000"). Decimal/exponent forms use ``float`` then a fixed 12-significant-figure format (``:.12g``) to kill float-repr noise (C3: 0.1+0.2 == "0.3"); a whole-valued float collapses to its int string ("3.0"->"3", "476090.0"->"476090"). Note: "1e3" DOES parse and canonicalizes to "1000" (numerically correct; documented + covered by a test). """ s = str(value).strip() if not s: return None # Reject leading-zero integers ("01234") and "+0123" to protect zip/ID codes, # but allow "0", "0.0", "0.5" (a single 0 before a "." is a real number). body = s[1:] if s[:1] in "+-" else s if len(body) > 1 and body[0] == "0" and body[1] != ".": return None # Integer-shaped: canonicalize with EXACT, unbounded int() so large IDs never # collapse via float53 rounding (C2). The leading-zero guard above already # rejected "01234"/"+0123", so these stay strings. if re.fullmatch(r"[+-]?\d+", s): return str(int(s)) # Decimal/exponent forms only (must contain a ".", "e", or "E"): float-parse, # then fix to 12 significant figures to kill repr noise (C3). A non-finite # parse (nan/inf) is NOT a canonicalizable number — return None (C1). if re.fullmatch(r"[+-]?(\d+\.\d*|\.\d+|\d+)([eE][+-]?\d+)?", s) and any( c in s for c in ".eE" ): f = float(s) if not math.isfinite(f): return None if f == int(f): return str(int(f)) return f"{f:.12g}" return None # Non-finite numeric tokens (the string forms ``float()`` accepts). A NaN/inf # gold is a degenerate/unusable target for this oracle, and IEEE NaN is never # equal to itself — so a predicted "nan"/"inf" must NEVER score as a pass, even # against a NaN/inf gold cell that stringifies the same way. ``_normalize_value`` # poisons these to a per-occurrence sentinel so two non-finite cells never # compare equal. _NONFINITE_TOKENS = frozenset( {"nan", "inf", "-inf", "+inf", "infinity", "-infinity", "+infinity"} ) # Monotonic counter making each non-finite normalization unique. The COUNTER # value is run-dependent, but the comparison OUTCOME is deterministic: a # non-finite cell never equals any other cell (including another "nan"), exactly # as IEEE NaN-vs-NaN. Module-level mutable int guarded by ``global``. _nonfinite_counter = 0 def _normalize_value(value: str) -> str: """Normalize for comparison. If ``value`` is safely numeric (per ``_numeric_canonical``), return its canonical numeric form so "3"=="3.0" holds in string/list/table cells; else lowercase + collapse-whitespace as before. This is the single seam: ``_parse_list_values``, ``_parse_table_rows``, and ``_compare_string`` all route cells through here, so the numeric rule applies without rewriting any set/sorted/multiset logic. Leading-zero/currency strings are unaffected because ``_numeric_canonical`` returns None for them. Non-finite tokens ("nan"/"inf"/...) get a per-occurrence sentinel so they never compare equal to anything (matching IEEE NaN semantics and refusing to score a degenerate NaN/inf gold as a pass — C1). """ text = "" if value is None else str(value) canonical = _numeric_canonical(text) if canonical is not None: return canonical collapsed = " ".join(text.strip().lower().split()) if collapsed in _NONFINITE_TOKENS: global _nonfinite_counter _nonfinite_counter += 1 return f"\x00nonfinite:{_nonfinite_counter}" return collapsed def _compare_integer(predicted: str, gold: str) -> bool: """Compare integer values after coercing with ``int(float(x))``.""" try: return int(float(predicted)) == int(float(gold)) except (TypeError, ValueError): return False def _compare_float(predicted: str, gold: str, tolerance: float = 0.01) -> bool: """Compare float values using a relative tolerance.""" try: predicted_value = float(predicted) gold_value = float(gold) except (TypeError, ValueError): return False if gold_value == 0.0: return abs(predicted_value - gold_value) <= 1e-9 return abs(predicted_value - gold_value) <= tolerance * abs(gold_value) def _compare_string(predicted: str, gold: str) -> bool: """Compare two strings with normalization.""" return _normalize_value(predicted) == _normalize_value(gold) def _parse_list_values(raw: str) -> set[str]: """Parse comma/newline/pipe-separated values into a normalized set. Handles plain delimited strings and Python list representations: "121\\n111\\n171" -> {"121", "111", "171"} "[121, 111, 171]" -> {"121", "111", "171"} "['Feil', 'Fisher']" -> {"feil", "fisher"} """ text = raw.strip() # Try Python literal (e.g., [121, 111] or ['Feil', 'Fisher']) if text.startswith("["): try: parsed = ast.literal_eval(text) if isinstance(parsed, list): return { _normalize_value(str(item)) for item in parsed if str(item).strip() } except (ValueError, SyntaxError): pass tokens = re.split(r"\s*(?:,|\n|\|)\s*", text) normalized = {_normalize_value(token) for token in tokens if token.strip()} return normalized def _compare_list( predicted: str, gold: str, gold_rows: list[tuple] | None = None, ) -> bool: """Compare list-like answers as order-insensitive sets.""" predicted_set = _parse_list_values(predicted) if gold_rows is not None: gold_set = { _normalize_value(str(cell)) for row in gold_rows for cell in row if str(cell).strip() } else: gold_set = _parse_list_values(gold) return predicted_set == gold_set def _parse_table_rows(raw: str) -> list[tuple[str, ...]]: """Parse a table answer string into normalized rows. Supports formats: - Pipe-separated rows: "111 | 1\\n121 | 2" - Python list-of-lists: "[[111, 1], [121, 2]]" - Numbered rows: "1. 111 | 1\\n2. 121 | 2" """ text = raw.strip() if not text: return [] # Try Python literal (list-of-lists from gold_answer storage) if text.startswith("["): try: parsed = ast.literal_eval(text) if isinstance(parsed, list): return [ tuple(_normalize_value(str(cell)) for cell in row) for row in parsed if isinstance(row, (list, tuple)) ] except (ValueError, SyntaxError): pass rows = [] for line in text.split("\n"): line = line.strip() if not line: continue # Strip leading numbering: "1. value | value" line = re.sub(r"^\d+\.\s*", "", line) cells = [_normalize_value(cell) for cell in re.split(r"\s*\|\s*", line)] if any(c for c in cells): rows.append(tuple(cells)) return rows def _compare_table( predicted: str, gold: str, gold_rows: list[tuple] | None = None, ) -> bool: """Compare table answers row-by-row with cell-level normalization. Order-insensitive: rows are compared as multisets (sorted). """ pred_rows = _parse_table_rows(predicted) if gold_rows is not None: gold_normalized = sorted( tuple(_normalize_value(str(cell)) for cell in row) for row in gold_rows ) else: gold_normalized = sorted(_parse_table_rows(gold)) # Sorted comparison preserves duplicate counts, acting as multiset equality pred_normalized = sorted(pred_rows) return pred_normalized == gold_normalized