"""Shared macro-F1 utilities for classification / MCQ / claim tasks.
Provides the macro-averaged F1 aggregation and a common process_results
helper used across all language-specific classification utils modules.
Registers:
- ``regex_last``: like lm_eval's ``regex`` filter, but picks a match from
``findall`` using ``group_select``; default ``group_select=-1`` is the
**last** match (needed when CoT/reasoning mentions labels before the answer).
- ``strip_think_recover``: drop ``<think>…</think>`` so
downstream ``regex`` sees only the final answer channel; if that tail is empty
(e.g. stop at ``\\n\\n`` before content), fall back to the last non-empty line
of the reasoning block (see ``run_eval.py`` merge format).
- ``regex_label_set``: pick the last occurrence of any allowed label from a
per-doc field (e.g. ``labels_str`` for SIB-200, ``intents_str`` for InjongoIntent).
Robust to channel-marker leak (e.g. a leaked ``<|channel|>`` header before the
answer), models that say "the answer is X", and substring collisions
(``science/technology`` vs ``science``) -- labels are matched longest-first.
- ``strip_channel_header``: drop a Harmony-style channel-marker prefix
(``<channel>`` / ``<|channel|>`` and optional trailing ``<|message|>``) from
the start of the response. Useful for open-text generation tasks
(summarization / QA / open generation) where the actual answer is correct
but the chat template leaks tokens at the start. No-op when no marker found.
"""
import re
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
def _strip_think_tags(text: str) -> str:
"""Strip ... reasoning wrapper (e.g. Qwen thinking models)."""
if "" in text:
return text.split("")[-1].strip()
return text
@register_filter("strip_think_recover")
class StripThinkRecoverFilter(Filter):
    """Remove the think wrapper so MCQ ``regex`` runs on the answer tail only.

    When reasoning and content are merged into one response, a plain
    ``regex`` filter such as ``([ABCD])`` would otherwise match the **first**
    letter inside the reasoning block. For each response this filter:

    - returns the response unchanged when it has no ``</think>`` marker;
    - otherwise keeps the stripped text after the last ``</think>``;
    - if that tail is empty, falls back to the last non-empty line of the
      reasoning that precedes the first ``</think>`` (common when generation
      stops before any answer content is produced).

    Non-string responses are treated as the empty string.
    """

    _OPEN_TAG = "<think>"
    _CLOSE_TAG = "</think>"

    def __init__(self) -> None:
        pass

    def apply(self, resps, docs):
        """Return ``resps`` with every response reduced to its answer text.

        ``docs`` is accepted for interface compatibility and is not read.
        """
        return [[self._recover(resp) for resp in inst] for inst in resps]

    @classmethod
    def _recover(cls, resp):
        """Extract the answer tail of one response, recovering from an empty tail."""
        if not isinstance(resp, str):
            resp = ""
        if cls._CLOSE_TAG not in resp:
            return resp

        tail = resp.rsplit(cls._CLOSE_TAG, 1)[-1].strip()
        if tail:
            return tail

        # Empty answer channel: salvage the last non-empty reasoning line.
        reasoning = resp.split(cls._CLOSE_TAG)[0]
        if cls._OPEN_TAG in reasoning:
            reasoning = reasoning.split(cls._OPEN_TAG, 1)[1]
        lines = [ln.strip() for ln in reasoning.strip().splitlines() if ln.strip()]
        return lines[-1] if lines else ""
def macro_f1_agg(items):
    """Return the macro-averaged F1 over the labels present in the gold set.

    ``items`` holds one ``(pred, gold)`` pair per document. Each distinct gold
    label gets its own F1 (0.0 when precision + recall is zero) and the result
    is the unweighted mean of those scores. Predictions that never occur as a
    gold label contribute only as false positives / false negatives of the
    gold labels. Returns 0.0 when there are no gold labels.
    """
    predicted = [pair[0] for pair in items]
    expected = [pair[1] for pair in items]
    pairs = list(zip(predicted, expected))

    def _label_f1(label):
        """F1 of a single label, treating it as the positive class."""
        tp = fp = fn = 0
        for pred, gold in pairs:
            if pred == label:
                if gold == label:
                    tp += 1
                else:
                    fp += 1
            elif gold == label:
                fn += 1
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if (precision + recall) > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0

    # Score only labels that actually appear as gold, in sorted order.
    per_label = [_label_f1(label) for label in sorted(set(expected))]
    if not per_label:
        return 0.0
    return sum(per_label) / len(per_label)
@register_filter("regex_last")
class RegexLastFilter(Filter):
    """Regex extraction that selects one ``findall`` hit by position.

    ``group_select`` indexes into the list of matches: non-negative values
    count from the start, negative values from the end, and out-of-range
    values are clamped to the first / last match. The default ``-1`` picks
    the last match. Responses with no match yield ``fallback``.
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select: int = -1,
        fallback: str = "[invalid]",
    ) -> None:
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Apply the extraction to every response; ``docs`` is not read."""
        return [[self._extract(resp) for resp in inst] for inst in resps]

    def _extract(self, resp):
        """Return the selected match for one response, or ``fallback``."""
        text = resp if isinstance(resp, str) else ""
        hits = self.regex.findall(text)
        if not hits:
            return self.fallback

        if self.group_select < 0:
            position = max(0, len(hits) + self.group_select)
        else:
            position = min(self.group_select, len(hits) - 1)

        chosen = hits[position]
        if isinstance(chosen, tuple):
            # Multi-group pattern: take the first non-empty group, if any.
            chosen = next((group for group in chosen if group), self.fallback)
        return str(chosen).strip()
@register_filter("regex_label_set")
class RegexLabelSetFilter(Filter):
    """Pick the LAST match of any allowed label found in the response.

    Allowed labels come from a per-doc field (default ``labels_str``), split
    on ``separator`` with surrounding whitespace and empty entries removed.
    The labels are escaped and joined into one alternation ordered
    **longest-first**, so a label such as ``science/technology`` is tried
    before a shorter label like ``science`` that is a substring of it.

    Because the label is searched for anywhere in the text, leading noise
    (leaked channel markers, "The answer is X" phrasing) does not prevent a
    match; the last of the ``findall`` hits is returned.

    When the doc field is missing or empty, or no label occurs in the
    response, ``fallback`` (default ``"[invalid]"``) is returned.
    """

    def __init__(
        self,
        labels_field: str = "labels_str",
        separator: str = ",",
        fallback: str = "[invalid]",
    ) -> None:
        self.labels_field = labels_field
        self.separator = separator
        self.fallback = fallback

    def _build_matcher(self, doc):
        """Compile the longest-first label alternation for ``doc``, or ``None``."""
        field_text = str((doc or {}).get(self.labels_field, "") or "")
        pieces = (part.strip() for part in field_text.split(self.separator))
        unique_labels = {piece for piece in pieces if piece}
        if not unique_labels:
            return None
        ordered = sorted(unique_labels, key=len, reverse=True)
        return re.compile("(" + "|".join(map(re.escape, ordered)) + ")")

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Map each response to its last matching label (or ``fallback``)."""
        results: list[list[str]] = []
        for resp_set, doc in zip(resps, docs):
            matcher = self._build_matcher(doc)
            if matcher is None:
                # No usable label list for this doc: every response is invalid.
                results.append([self.fallback for _ in resp_set])
                continue

            picked: list[str] = []
            for resp in resp_set:
                text = resp if isinstance(resp, str) else ""
                found = matcher.findall(text)
                picked.append(found[-1].strip() if found else self.fallback)
            results.append(picked)
        return results
@register_filter("strip_channel_header")
class StripChannelHeaderFilter(Filter):
    """Strip a leaked Harmony-style channel/message header from a response.

    Some providers leak chat-template tokens such as
    ``<|channel|>final<|message|>`` (sometimes preceded by a few stray
    characters) into the assistant content. For open-text tasks
    (summarization / QA / open generation) that prefix lowers text-overlap
    and judge-based scores even when the answer that follows is correct.

    Matching is anchored at the start of the response and removes, in order:

    - at most ``max_prefix_chars`` (default 80) arbitrary leading characters,
      matched non-greedily;
    - a channel marker, with or without the pipe on either side
      (``<channel>``, ``<|channel|>`` and the half-piped forms);
    - optionally, up to 40 non-``<`` characters (e.g. ``final``) followed by
      a message marker in the same piped / unpiped forms;
    - any whitespace that follows.

    Responses without a channel marker in that leading window are returned
    unchanged, and non-string responses become the empty string.

    Order tip: chain after ``strip_think_recover`` so the reasoning block is
    removed first and this filter sees only the answer tail.
    """

    def __init__(self, max_prefix_chars: int = 80) -> None:
        self.max_prefix_chars = int(max_prefix_chars)
        # ^                             - anchored at start of the response
        # .{0,N}?                       - up to N chars of prefix (non-greedy)
        # <\|?channel\|?>               - channel marker, pipes optional
        # (?:[^<]{0,40}<\|?message\|?>)? - optional "...<|message|>" tail
        # \s*                           - eat trailing whitespace
        prefix_part = r"^.{0," + str(self.max_prefix_chars) + r"}?<\|?channel\|?>"
        message_part = r"(?:[^<]{0,40}<\|?message\|?>)?\s*"
        # DOTALL lets the leading window span newlines.
        self._pattern = re.compile(prefix_part + message_part, re.DOTALL)

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Remove at most one leading header from each response; ``docs`` is not read."""
        return [
            [
                self._pattern.sub("", resp if isinstance(resp, str) else "", count=1)
                for resp in resp_set
            ]
            for resp_set in resps
        ]
def process_results_f1(doc, results, *, gold_key="target"):
    """Build the ``(pred, gold)`` pair consumed by the macro-F1 aggregation.

    Args:
        doc: Document mapping; the gold label is read from ``doc[gold_key]``
            (empty string when the key is absent) and whitespace-stripped.
        results: Model outputs for the document; only ``results[0]`` is used.
            A falsy first result is treated as the empty string.
        gold_key: Field holding the gold label. Defaults to ``"target"``;
            override for tasks that store it under another name.

    Returns:
        ``{"f1_macro": (pred, gold)}`` where ``pred`` is the stripped first
        result after ``_strip_think_tags`` has been applied to it.
    """
    first_result = results[0]
    response_text = first_result.strip() if first_result else ""
    prediction = _strip_think_tags(response_text)
    gold_label = doc.get(gold_key, "").strip()
    return {"f1_macro": (prediction, gold_label)}