"""Shared macro-F1 utilities for classification / MCQ / claim tasks.

Provides the macro-averaged F1 aggregation and a common process_results helper
used across all language-specific classification utils modules.

Registers:

- ``regex_last``: like lm_eval's ``regex`` filter, but picks a match from
  ``findall`` using ``group_select``; default ``group_select=-1`` is the
  **last** match (needed when CoT/reasoning mentions labels before the answer).
- ``strip_think_recover``: drop ``<think>...</think>`` so downstream ``regex``
  sees only the final answer channel; if that tail is empty (e.g. stop at
  ``\\n\\n`` before content), fall back to the last non-empty line of the
  reasoning block (see ``run_eval.py`` merge format).
- ``regex_label_set``: pick the last occurrence of any allowed label from a
  per-doc field (e.g. ``labels_str`` for SIB-200, ``intents_str`` for
  InjongoIntent). Robust to channel-marker leak (e.g. a leaked ``<|channel|>``
  header before the answer), models that say "the answer is X", and substring
  collisions (``science/technology`` vs ``science``) -- labels are matched
  longest-first.
- ``strip_channel_header``: drop a Harmony-style channel-marker prefix
  (``<channel|>`` / ``<|channel|>`` and optional trailing ``<|message|>``)
  from the start of the response. Useful for open-text generation tasks
  (summarization / QA / open generation) where the actual answer is correct
  but the chat template leaks tokens at the start. No-op when no marker found.
"""

import re
from collections import Counter

from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter

# Delimiters of the reasoning wrapper emitted by thinking models.
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def _strip_think_tags(text: str) -> str:
    """Strip a ``<think>...</think>`` reasoning wrapper (e.g. Qwen thinking models).

    Returns the stripped text after the **last** closing tag. When no closing
    tag is present the input is returned unchanged (not stripped).
    """
    if _THINK_CLOSE in text:
        return text.split(_THINK_CLOSE)[-1].strip()
    return text


@register_filter("strip_think_recover")
class StripThinkRecoverFilter(Filter):
    """Remove think wrapper so MCQ ``regex`` runs on the answer tail only.

    When ``run_eval.py`` merges API ``reasoning`` + ``content``, the built-in
    ``regex`` ``([ABCD])`` / ``([ABCDE])`` filter would otherwise match the
    **first** letter inside the reasoning block. This step keeps only text
    after ``</think>`` when non-empty; if that tail is empty, uses the last
    non-empty line inside the reasoning (common when generation stops early).
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _recover(resp) -> str:
        """Return the answer tail of one response, or a reasoning fallback."""
        if not isinstance(resp, str):
            resp = ""
        content = _strip_think_tags(resp)
        if content or _THINK_CLOSE not in resp:
            return content
        # The answer tail is empty: fall back to the reasoning that precedes
        # the first closing tag, minus anything up to the opening tag.
        reasoning = resp.split(_THINK_CLOSE)[0]
        if _THINK_OPEN in reasoning:
            reasoning = reasoning.split(_THINK_OPEN, 1)[1]
        lines = [ln.strip() for ln in reasoning.strip().splitlines() if ln.strip()]
        return lines[-1] if lines else ""

    def apply(self, resps, docs):
        """Apply the think-stripping step to every response of every doc.

        ``resps`` is an iterable of per-doc response collections; the result
        is a list of lists of strings with the same nesting. ``docs`` is
        unused.
        """
        return [[self._recover(resp) for resp in inst] for inst in resps]


def macro_f1_agg(items):
    """Compute macro-averaged F1 over all class labels.

    ``items`` is a list of ``(pred, gold)`` tuples per document. Use ``resps``
    / ``reasoning_content`` on logged samples for the full trace.

    Only labels that appear in the gold set contribute a per-class F1; a
    prediction outside the gold label set simply counts as a miss for its
    document. Returns ``0.0`` for empty input.
    """
    preds = [item[0] for item in items]
    golds = [item[1] for item in items]

    # One counting pass instead of three scans of the data per label.
    gold_counts = Counter(golds)
    pred_counts = Counter(preds)
    tp_counts = Counter(g for p, g in zip(preds, golds) if p == g)

    f1_scores = []
    # Only compute F1 over labels that actually appear in the gold set.
    for label in sorted(gold_counts):
        tp = tp_counts[label]
        predicted = pred_counts[label]  # tp + fp
        actual = gold_counts[label]  # tp + fn
        precision = tp / predicted if predicted > 0 else 0.0
        recall = tp / actual if actual > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )
        f1_scores.append(f1)

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0


@register_filter("regex_last")
class RegexLastFilter(Filter):
    """Regex extraction; ``group_select=-1`` uses the last ``findall`` hit."""

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select: int = -1,
        fallback: str = "[invalid]",
    ) -> None:
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def _extract(self, resp) -> str:
        """Return the selected match for one response, or ``fallback``."""
        if not isinstance(resp, str):
            resp = ""
        matches = self.regex.findall(resp)
        if not matches:
            return self.fallback
        # Clamp the requested index into range rather than raising.
        if self.group_select >= 0:
            idx = min(self.group_select, len(matches) - 1)
        else:
            idx = max(0, len(matches) + self.group_select)
        match = matches[idx]
        if isinstance(match, tuple):
            # Several capture groups: take the first non-empty one.
            non_empty = [m for m in match if m]
            match = non_empty[0] if non_empty else self.fallback
        return str(match).strip()

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Extract one match per response; ``docs`` is unused."""
        return [[self._extract(resp) for resp in inst] for inst in resps]


@register_filter("regex_label_set")
class RegexLabelSetFilter(Filter):
    """Pick the LAST occurrence of any allowed label from the response.

    The allowed-label list is read from a per-doc field (default
    ``labels_str``, e.g.
    ``"entertainment, geography, ..., science/technology, ..."``). Labels are
    matched **longest-first** so multi-segment labels like
    ``science/technology`` win over substring collisions like ``science``.

    Robust to:

    - ``<think>...</think>`` reasoning leak -- typically chain
      ``strip_think_recover`` first to drop the reasoning block, then this
      filter on the answer tail.
    - Harmony / channel-marker leak (e.g. a leaked ``<|channel|>`` header
      followed by the actual label) -- the regex still finds the trailing
      label substring.
    - "The answer is X" / "Final: X" patterns -- the LAST occurrence wins.

    If the doc field is missing, empty, or no label matches, returns
    ``fallback`` (default ``"[invalid]"``) so the row counts as wrong in F1.
    """

    def __init__(
        self,
        labels_field: str = "labels_str",
        separator: str = ",",
        fallback: str = "[invalid]",
    ) -> None:
        self.labels_field = labels_field
        self.separator = separator
        self.fallback = fallback

    def _build_pattern(self, raw_labels: str):
        """Compile the alternation regex for a label string, or ``None``."""
        labels = {
            lbl.strip() for lbl in raw_labels.split(self.separator) if lbl.strip()
        }
        if not labels:
            return None
        # Sort longest-first so e.g. "science/technology" matches before
        # the bare "science" substring inside it.
        ordered = sorted(labels, key=len, reverse=True)
        return re.compile("(" + "|".join(re.escape(lbl) for lbl in ordered) + ")")

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Return, per response, the last allowed label found or ``fallback``."""
        out: list[list[str]] = []
        # Remember the most recent label string so consecutive docs that share
        # it reuse the compiled pattern instead of rebuilding it.
        cached_raw = None
        cached_pattern = None
        for resp_set, doc in zip(resps, docs):
            raw_labels = str((doc or {}).get(self.labels_field, "") or "")
            if raw_labels != cached_raw:
                cached_raw = raw_labels
                cached_pattern = self._build_pattern(raw_labels)
            pattern = cached_pattern

            filtered: list[str] = []
            for resp in resp_set:
                if not isinstance(resp, str):
                    resp = ""
                if pattern is None:
                    filtered.append(self.fallback)
                    continue
                matches = pattern.findall(resp)
                filtered.append(matches[-1].strip() if matches else self.fallback)
            out.append(filtered)
        return out


@register_filter("strip_channel_header")
class StripChannelHeaderFilter(Filter):
    """Strip Harmony-style channel/message header leaks from the response.

    Some providers (notably deepinfra fp8 Gemma) leak chat template tokens
    like ``<|channel|>final<|message|>`` -- or partial fragments such as
    ``s.<channel|>`` (where ``s.`` is the tail of the previous token) -- into
    the assistant ``content``. For open-text tasks (summarization / QA / open
    generation) this hurts every metric (ROUGE / BLEU / SAS-encoder / LLM
    judge) because the prefix garbage drags down the score even when the
    actual answer that follows is correct.

    Strategy: anchored at the start of the response, match up to
    ``max_prefix_chars`` (default 80) characters of any text followed by a
    ``<channel|>`` / ``<|channel|>`` marker, optionally followed by a Harmony
    ``<|message|>`` / ``<message|>`` marker (and up to 40 characters of text
    in between, like ``final``). Drop everything matched. No-op when no marker
    is found near the start, so safe for clean responses.

    Order tip: chain after ``strip_think_recover`` so reasoning is dropped
    first and this filter operates on the answer tail only.
    """

    def __init__(self, max_prefix_chars: int = 80) -> None:
        # A negative bound would produce a malformed ``{0,N}`` quantifier.
        self.max_prefix_chars = max(0, int(max_prefix_chars))
        # ^                                - anchored at start
        # .{0,N}?                          - up to N chars of garbage prefix
        #                                    (non-greedy; DOTALL so newlines count)
        # <\|?channel\|?>                  - channel marker, each pipe optional
        # (?:[^<]{0,40}<\|?message\|?>)?   - optional "...<|message|>" tail with
        #                                    at most 40 non-"<" chars before it
        # \s*                              - eat trailing whitespace
        self._pattern = re.compile(
            r"^.{0," + str(self.max_prefix_chars) + r"}?<\|?channel\|?>"
            r"(?:[^<]{0,40}<\|?message\|?>)?\s*",
            re.DOTALL,
        )

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Remove a leading channel header from every response; ``docs`` unused."""
        out: list[list[str]] = []
        for resp_set in resps:
            stripped: list[str] = []
            for resp in resp_set:
                if not isinstance(resp, str):
                    resp = ""
                stripped.append(self._pattern.sub("", resp, count=1))
            out.append(stripped)
        return out


def process_results_f1(doc, results, *, gold_key="target"):
    """Return ``(pred, gold)`` for macro-F1 aggregation.

    ``pred`` is the label after stripping think wrappers. Full reasoning is
    logged as ``reasoning_content`` when using ``run_eval.py``.

    Most tasks use ``gold_key="target"``; override for tasks that store the
    gold label under a different field name.

    An empty ``results`` list or a non-string first result yields an empty
    prediction; a missing or ``None`` gold value yields an empty gold label,
    and any other non-string gold value is converted with ``str``.
    """
    first = results[0] if results else ""
    raw_response = first.strip() if isinstance(first, str) else ""
    pred = _strip_think_tags(raw_response)

    gold = doc.get(gold_key, "")
    gold = "" if gold is None else str(gold).strip()
    return {"f1_macro": (pred, gold)}