mathtok / evaluation /metrics.py
SurweeshSP's picture
Initial clean MathTok release
edede4c
"""
MathTok Evaluation Metrics
Implements the five core metrics for evaluating structural tokenization
quality, as described in the MathTok paper:
SCR — Structural Compression Ratio
CCS — Canonical Consistency Score
OPS — Operator Preservation Score
TS — Token Stability
TDF — Tree Depth Fidelity
Each metric is self-contained and operates on TokenizedOutput objects
or lists of token strings, enabling easy integration into benchmark runs.
Baseline comparisons are supported for:
- GPT-2 tokenizer (character-level BPE)
- SentencePiece unigram
- Character-level tokenization
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
from typing import Callable, Optional
logger = logging.getLogger(__name__)
# ── Metric result container ───────────────────────────────────────────────
@dataclass
class MetricResult:
    """Outcome of a single metric: a headline value plus supporting statistics."""

    # Short identifier of the metric (shown first in the repr).
    name: str
    # Headline score reported for the metric.
    value: float
    # Human-readable explanation shown in parentheses in the repr.
    description: str
    # Free-form supporting statistics; every instance gets its own dict.
    details: dict = field(default_factory=dict)

    def __repr__(self) -> str:
        # The value is always rendered with four decimal places.
        return "{}: {:.4f} ({})".format(self.name, self.value, self.description)
@dataclass
class EvaluationReport:
    """Bundle of the five MathTok metric results for one evaluation run."""

    scr: MetricResult
    ccs: MetricResult
    ops: MetricResult
    ts: MetricResult
    tdf: MetricResult
    # Number of examples the report was computed over.
    num_examples: int = 0

    def summary(self) -> str:
        """Return a banner-framed, multi-line text rendering of the report."""
        rule = "=" * 60
        rows = [rule, f" MathTok Evaluation Report (n={self.num_examples})", rule]
        # One row per metric, always in SCR, CCS, OPS, TS, TDF order.
        for metric in (self.scr, self.ccs, self.ops, self.ts, self.tdf):
            rows.append(f" {metric}")
        rows.append(rule)
        return "\n".join(rows)

    def to_dict(self) -> dict:
        """Return ``num_examples`` and each metric's value keyed by its acronym."""
        flat = {"num_examples": self.num_examples}
        keyed_metrics = (
            ("SCR", self.scr),
            ("CCS", self.ccs),
            ("OPS", self.ops),
            ("TS", self.ts),
            ("TDF", self.tdf),
        )
        for key, metric in keyed_metrics:
            flat[key] = metric.value
        return flat
# ── Metric 1: Structural Compression Ratio (SCR) ─────────────────────────
def structural_compression_ratio(
    expressions: list[str],
    tokenized_lengths: list[int],
) -> MetricResult:
    """
    SCR = mean( |AST_tokens| / |raw_chars| )

    Measures how efficiently the structural token stream represents the
    information content relative to raw character count.
    Lower SCR = more compressed. A ratio < 1.0 indicates compression.

    Parameters
    ----------
    expressions : list of raw input expression strings
    tokenized_lengths : list of token counts output by MathTok

    Returns
    -------
    MetricResult named "SCR" whose value is the mean ratio and whose details
    hold the min, max, sample std and count of the per-expression ratios.
    For empty input the value and all statistics are 0.0 with n = 0.

    Raises
    ------
    AssertionError
        If the two lists differ in length.
    """
    if len(expressions) != len(tokenized_lengths):
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        raise AssertionError("Length mismatch")

    description = "Structural Compression Ratio (tokens / chars); lower = more compressed"

    if not expressions:
        # Nothing to average: report a neutral result instead of dividing by zero.
        return MetricResult(
            name="SCR",
            value=0.0,
            description=description,
            details={"min": 0.0, "max": 0.0, "std": 0.0, "n": 0},
        )

    # An empty expression is treated as one character to avoid division by zero.
    ratios = [
        tlen / max(len(expr), 1)
        for expr, tlen in zip(expressions, tokenized_lengths)
    ]
    mean_scr = sum(ratios) / len(ratios)
    return MetricResult(
        name="SCR",
        value=mean_scr,
        description=description,
        details={
            "min": min(ratios),
            "max": max(ratios),
            "std": _std(ratios),
            "n": len(ratios),
        },
    )
# ── Metric 2: Canonical Consistency Score (CCS) ──────────────────────────
def canonical_consistency_score(
    equivalent_pairs: list[tuple[str, str]],
    tokenize_fn: Callable[[str], list[str]],
) -> MetricResult:
    """
    CCS = mean( Jaccard(tokens_A, tokens_B) ) over equivalent pairs.

    Measures how similar the token streams are for mathematically
    equivalent expressions. CCS → 1.0 means perfect consistency.

    Tokens are compared as sets, so order and multiplicity are ignored.
    Tokens starting with "[" are treated as boundary tokens and excluded.
    A pair on which the tokenizer raises scores 0.0 (logged at DEBUG level).

    Parameters
    ----------
    equivalent_pairs : list of (expr_A, expr_B) that are mathematically equal
    tokenize_fn : function str → list[str] (the tokenizer under test)

    Returns
    -------
    MetricResult named "CCS"; value is the mean Jaccard score (0.0 when no
    pairs are given). details holds the first 20 per-pair scores, the pair
    count and the sample std of the scores.
    """
    scores = []
    for expr_a, expr_b in equivalent_pairs:
        try:
            # Boundary tokens (those starting with "[") are left out of the Jaccard sets.
            toks_a = {t for t in tokenize_fn(expr_a) if not t.startswith("[")}
            toks_b = {t for t in tokenize_fn(expr_b) if not t.startswith("[")}
            union = toks_a | toks_b
            if not union:
                # Both sides are empty after filtering: count as identical.
                scores.append(1.0)
            else:
                scores.append(len(toks_a & toks_b) / len(union))
        except Exception as exc:
            # Best-effort metric: a failing pair counts as a total mismatch.
            logger.debug("CCS: failed on pair (%s, %s): %s", expr_a[:30], expr_b[:30], exc)
            scores.append(0.0)
    mean_ccs = sum(scores) / len(scores) if scores else 0.0
    return MetricResult(
        name="CCS",
        value=mean_ccs,
        description="Canonical Consistency Score — Jaccard overlap for equivalent forms (higher is better)",
        details={"scores": scores[:20], "n": len(scores), "std": _std(scores)},
    )
# ── Metric 3: Operator Preservation Score (OPS) ──────────────────────────
def operator_preservation_score(
    expressions: list[str],
    tokenize_fn: Callable[[str], list[str]],
    expected_operators: Optional[list[set[str]]] = None,
) -> MetricResult:
    """
    OPS = fraction of expressions where all expected operator tokens appear.

    If expected_operators is not provided, we auto-detect expected operators
    from simple substring heuristics on the lower-cased raw expression.
    The power operator "**" is not counted as a multiplication "*".

    Expressions with no expected operators are skipped. An expression on
    which the tokenizer raises counts as not preserved (logged at DEBUG).

    Parameters
    ----------
    expressions : list of raw expression strings
    tokenize_fn : str → list[str]
    expected_operators : optional list of sets of expected operator tokens,
        indexed in parallel with ``expressions``

    Returns
    -------
    MetricResult named "OPS"; value is preserved / total, or 1.0 when no
    expression had any expected operator. details holds both counts.
    """
    _OP_HEURISTICS: dict[str, str] = {
        "+": "OP_ADD", "*": "OP_MUL", "^": "OP_POW", "**": "OP_POW",
        "/": "FRAC", "sin": "FUNC_SIN", "cos": "FUNC_COS",
        "tan": "FUNC_TAN", "log": "FUNC_LOG", "exp": "FUNC_EXP",
        "sqrt": "FUNC_SQRT", "diff": "OP_DERIV", "integrate": "OP_INT",
        "lim": "OP_LIMIT", "sum": "OP_SUM", "factorial": "FUNC_FACTORIAL",
    }
    preserved = 0
    total = 0
    for i, expr in enumerate(expressions):
        if expected_operators is not None:
            expected = expected_operators[i]
        else:
            # Heuristic: derive expected operators from the raw expression.
            expr_lower = expr.lower()
            # "**" contains "*", so a plain substring test would demand OP_MUL
            # for a pure power such as "x**2". Look for "*" only after the
            # power operators have been blanked out.
            without_pow = expr_lower.replace("**", " ")
            expected = set()
            for key, op_tok in _OP_HEURISTICS.items():
                haystack = without_pow if key == "*" else expr_lower
                if key in haystack:
                    expected.add(op_tok)
        if not expected:
            continue  # skip if we can't determine expected operators
        try:
            tokens = set(tokenize_fn(expr))
        except Exception as exc:
            # Best-effort metric: a tokenizer failure counts as "not preserved".
            logger.debug("OPS: tokenizer failed on %s: %s", expr[:30], exc)
            tokens = set()
        if expected.issubset(tokens):
            preserved += 1
        total += 1
    ops_value = preserved / total if total > 0 else 1.0
    return MetricResult(
        name="OPS",
        value=ops_value,
        description="Operator Preservation Score — % of expressions with all expected ops (higher is better)",
        details={"preserved": preserved, "total": total},
    )
# ── Metric 4: Token Stability (TS) ───────────────────────────────────────
def token_stability(
    expression_groups: list[list[str]],
    tokenize_fn: Callable[[str], list[str]],
) -> MetricResult:
    """
    TS = 1 - mean( CoV(token_count) ) where CoV = std/mean.

    Measures how stable the token count is across syntactic rewritings
    of the same expression. TS → 1.0 means perfectly stable.

    A rewriting on which the tokenizer raises contributes a token count of 0
    (logged at DEBUG level). Groups with fewer than two members, or whose
    token counts are all zero, are skipped. The result is clamped at 0.0.

    Parameters
    ----------
    expression_groups : list of groups; each group = rewritings of one expr
    tokenize_fn : str → list[str]

    Returns
    -------
    MetricResult named "TS"; details holds the mean CoV and the number of
    groups that were actually scored.
    """
    covs = []
    for group in expression_groups:
        lengths = []
        for expr in group:
            try:
                lengths.append(len(tokenize_fn(expr)))
            except Exception as exc:
                # Best-effort metric: a failed rewriting counts as zero tokens.
                logger.debug("TS: tokenizer failed on %s: %s", expr[:30], exc)
                lengths.append(0)
        if len(lengths) < 2 or sum(lengths) == 0:
            continue
        # The sum is non-zero and counts are non-negative, so the mean is positive.
        mu = sum(lengths) / len(lengths)
        covs.append(_std(lengths) / mu)
    mean_cov = sum(covs) / len(covs) if covs else 0.0
    ts_value = max(0.0, 1.0 - mean_cov)
    return MetricResult(
        name="TS",
        value=ts_value,
        description="Token Stability — 1 - CoV(token count across rewritings) (higher is better)",
        details={"mean_cov": mean_cov, "n_groups": len(covs)},
    )
# ── Metric 5: Tree Depth Fidelity (TDF) ──────────────────────────────────
def tree_depth_fidelity(
    expressions: list[str],
    tokenize_fn_with_meta: Callable,  # returns TokenizedOutput
    expected_depth_fn: Optional[Callable] = None,
) -> MetricResult:
    """
    TDF = 1 - mean( |actual_max_depth - expected_max_depth| / expected_max_depth )

    Measures how accurately the metadata captures the true tree depth.
    Relies on metadata.depth fields being correctly computed.

    Each per-expression relative error is capped at 1.0. When the expected
    depth is 0 the error is 0.0 if the actual depth is also 0 and 1.0
    otherwise. An expression that raises anywhere in processing counts as
    error 1.0 (logged at DEBUG level). Expressions with no metadata, or with
    no usable ground truth, are skipped.

    Parameters
    ----------
    expressions : list of expression strings
    tokenize_fn_with_meta : pipeline.encode() or equivalent; its result must
        expose ``metadata`` (items with a ``depth``) and ``canon_results``
    expected_depth_fn : optional callable(expr) → int for ground-truth depth
        If None, the depth of the first successful canon result's expression
        tree is used as ground truth.

    Returns
    -------
    MetricResult named "TDF"; details holds the mean relative error and the
    number of expressions scored.
    """
    errors = []
    for expr in expressions:
        try:
            out = tokenize_fn_with_meta(expr)
            if not out.metadata:
                continue
            # Negative depths are ignored when taking the maximum.
            actual_depth = max((m.depth for m in out.metadata if m.depth >= 0), default=0)
            if expected_depth_fn is not None:
                expected_depth = expected_depth_fn(expr)
            elif out.canon_results and out.canon_results[0].success:
                # Use AST subtree height from first canon_result as ground truth
                expected_depth = _sympy_depth(out.canon_results[0].expr)
            else:
                continue
            if expected_depth == 0:
                # The relative error is undefined here; a flat tree is only
                # matched when the reported depth is also 0.
                errors.append(0.0 if actual_depth == 0 else 1.0)
            else:
                rel_err = abs(actual_depth - expected_depth) / expected_depth
                errors.append(min(rel_err, 1.0))
        except Exception as exc:
            logger.debug("TDF: error on %s: %s", expr[:30], exc)
            errors.append(1.0)
    mean_err = sum(errors) / len(errors) if errors else 0.0
    tdf_value = max(0.0, 1.0 - mean_err)
    return MetricResult(
        name="TDF",
        value=tdf_value,
        description="Tree Depth Fidelity — accuracy of depth metadata vs ground truth (higher is better)",
        details={"mean_relative_error": mean_err, "n": len(errors)},
    )
# ── Baseline comparators ──────────────────────────────────────────────────
def tokenize_character_level(expr: str) -> list[str]:
    """Baseline tokenizer: split the expression into one token per character."""
    return [*expr]
def make_gpt2_tokenizer():
    """
    Build a GPT-2 baseline tokenizer function (requires transformers).

    Returns a callable text -> tokens. If importing transformers or loading
    the "gpt2" tokenizer fails for any reason, a warning is logged and the
    character-level baseline is returned instead.
    """
    try:
        from transformers import AutoTokenizer

        gpt2 = AutoTokenizer.from_pretrained("gpt2")
    except Exception:
        logger.warning("GPT-2 tokenizer not available; using character baseline.")
        return tokenize_character_level

    def _tokenize(text):
        return gpt2.tokenize(text)

    return _tokenize
def make_sentencepiece_tokenizer(model_path: str):
    """
    Build a SentencePiece baseline tokenizer function.

    Parameters
    ----------
    model_path : path passed to SentencePieceProcessor as ``model_file``

    Returns
    -------
    A callable text -> list of string pieces. If the sentencepiece package
    cannot be imported or the model cannot be loaded, a warning naming the
    cause is logged and the character-level baseline is returned instead.
    """
    try:
        import sentencepiece as spm

        processor = spm.SentencePieceProcessor(model_file=model_path)
    except Exception as exc:
        # The failure may be a missing package or a bad model file, so the
        # warning reports the actual cause and states the fallback in use.
        logger.warning(
            "SentencePiece tokenizer unavailable (%s); using character baseline.",
            exc,
        )
        return tokenize_character_level
    return lambda text: processor.encode(text, out_type=str)
# ── Utility helpers ───────────────────────────────────────────────────────
def _std(values: list[float]) -> float:
    """Return the sample (n - 1) standard deviation; 0.0 for fewer than two values."""
    count = len(values)
    if count < 2:
        return 0.0
    mean = sum(values) / count
    squared_dev = sum((v - mean) ** 2 for v in values)
    return math.sqrt(squared_dev / (count - 1))
def _sympy_depth(expr) -> int:
    """
    Compute tree depth of a SymPy expression.

    A node with no ``args`` has depth 0; otherwise the depth is one more
    than the deepest child. The tree is walked iteratively with an explicit
    stack so that very deep expressions do not hit Python's recursion limit.
    """
    deepest = 0
    pending = [(expr, 0)]
    while pending:
        node, level = pending.pop()
        if level > deepest:
            deepest = level
        children = node.args
        if children:
            pending.extend((child, level + 1) for child in children)
    return deepest