Spaces:
Runtime error
Runtime error
| """Shared macro-F1 utilities for classification / MCQ / claim tasks. | |
| Provides the macro-averaged F1 aggregation and a common process_results | |
| helper used across all language-specific classification utils modules. | |
| Registers: | |
| - ``regex_last``: like lm_eval's ``regex`` filter, but picks a match from | |
| ``findall`` using ``group_select``; default ``group_select=-1`` is the | |
| **last** match (needed when CoT/reasoning mentions labels before the answer). | |
| - ``strip_think_recover``: drop ``<think>…</think>`` so | |
| downstream ``regex`` sees only the final answer channel; if that tail is empty | |
| (e.g. stop at ``\\n\\n`` before content), fall back to the last non-empty line | |
| of the reasoning block (see ``run_eval.py`` merge format). | |
| - ``regex_label_set``: pick the last occurrence of any allowed label from a | |
| per-doc field (e.g. ``labels_str`` for SIB-200, ``intents_str`` for InjongoIntent). | |
| Robust to channel-marker leak (e.g. a leaked ``<|channel|>`` header before the | |
| answer), models that say "the answer is X", and substring collisions | |
| (``science/technology`` vs ``science``) -- labels are matched longest-first. | |
| - ``strip_channel_header``: drop a Harmony-style channel-marker prefix | |
| (``<channel|>`` / ``<|channel|>`` and optional trailing ``<|message|>``) from | |
| the start of the response. Useful for open-text generation tasks | |
| (summarization / QA / open generation) where the actual answer is correct | |
| but the chat template leaks tokens at the start. No-op when no marker found. | |
| """ | |
| import re | |
| from lm_eval.api.filter import Filter | |
| from lm_eval.api.registry import register_filter | |
def _strip_think_tags(text: str) -> str:
    """Return the part of ``text`` that follows the final ``</think>`` tag.

    Reasoning models (e.g. Qwen thinking models) wrap their chain of thought
    in ``<think>...</think>``. Everything up to and including the last closing
    tag is discarded and the remainder is whitespace-stripped. Text that has
    no closing tag is returned untouched (it is not stripped).
    """
    _reasoning, closing_tag, answer = text.rpartition("</think>")
    if not closing_tag:
        # rpartition yields an empty separator when the tag is absent.
        return text
    return answer.strip()
class StripThinkRecoverFilter(Filter):
    """Reduce each response to its answer channel before label extraction.

    When ``run_eval.py`` merges API ``reasoning`` + ``content``, a plain
    ``regex`` filter such as ``([ABCD])`` / ``([ABCDE])`` would otherwise match
    the **first** letter inside the reasoning block. This filter keeps only
    the text after ``</think>``. If that tail is empty (common when generation
    stops early), it falls back to the last non-empty line of the reasoning.
    """

    def __init__(self) -> None:
        """Take no configuration; the filter keeps no state."""

    @staticmethod
    def _answer_or_last_thought(resp):
        """Return the post-``</think>`` text of one response, or a recovered line."""
        # Non-string responses are treated as empty text.
        text = resp if isinstance(resp, str) else ""
        answer = _strip_think_tags(text)
        if answer or "</think>" not in text:
            return answer

        # Empty answer channel: look inside the reasoning that precedes the
        # first closing tag, dropping anything before the opening tag.
        thoughts = text.partition("</think>")[0]
        if "<think>" in thoughts:
            thoughts = thoughts.partition("<think>")[2]

        for line in reversed(thoughts.strip().splitlines()):
            trimmed = line.strip()
            if trimmed:
                return trimmed
        return ""

    def apply(self, resps, docs):
        """Filter every response of every instance; ``docs`` is not used.

        Returns a list with one list of cleaned strings per instance, in the
        same order as ``resps``.
        """
        return [
            [self._answer_or_last_thought(resp) for resp in inst] for inst in resps
        ]
| def macro_f1_agg(items): | |
| """Compute macro-averaged F1 over all class labels. | |
| ``items`` is a list of ``(pred, gold)`` tuples per document. Use | |
| ``resps`` / ``reasoning_content`` on logged samples for the full trace. | |
| """ | |
| preds = [item[0] for item in items] | |
| golds = [item[1] for item in items] | |
| # Only compute F1 over labels that actually appear in the gold set | |
| all_labels = sorted(set(golds)) | |
| f1_scores = [] | |
| for label in all_labels: | |
| tp = sum(1 for p, g in zip(preds, golds) if p == label and g == label) | |
| fp = sum(1 for p, g in zip(preds, golds) if p == label and g != label) | |
| fn = sum(1 for p, g in zip(preds, golds) if p != label and g == label) | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| f1 = ( | |
| 2 * precision * recall / (precision + recall) | |
| if (precision + recall) > 0 | |
| else 0.0 | |
| ) | |
| f1_scores.append(f1) | |
| return sum(f1_scores) / len(f1_scores) if f1_scores else 0.0 | |
class RegexLastFilter(Filter):
    """Regex extraction that can select which ``findall`` hit to keep.

    ``group_select`` indexes into the list of matches: non-negative values
    count from the start, negative values count from the end, and both are
    clamped to the available range. The default ``-1`` keeps the **last**
    match. Responses with no match yield ``fallback``.
    """

    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
        group_select: int = -1,
        fallback: str = "[invalid]",
    ) -> None:
        self.regex = re.compile(regex_pattern)
        self.group_select = group_select
        self.fallback = fallback

    def _extract(self, resp) -> str:
        """Return the selected match for one response, or ``fallback``."""
        # Non-string responses are treated as empty text.
        hits = self.regex.findall(resp if isinstance(resp, str) else "")
        if not hits:
            return self.fallback

        if self.group_select < 0:
            # Count from the end, clamping to the first hit.
            position = max(0, len(hits) + self.group_select)
        else:
            # Count from the start, clamping to the last hit.
            position = min(self.group_select, len(hits) - 1)

        chosen = hits[position]
        if isinstance(chosen, tuple):
            # Patterns with several groups give tuples; keep the first
            # non-empty group, or the fallback when every group is empty.
            chosen = next((group for group in chosen if group), self.fallback)
        return str(chosen).strip()

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Extract one string per response; ``docs`` is not used."""
        return [[self._extract(resp) for resp in inst] for inst in resps]
class RegexLabelSetFilter(Filter):
    """Return the LAST allowed label that appears in each response.

    The allowed labels come from a per-doc field (default ``labels_str``,
    e.g. ``"entertainment, geography, ..., science/technology, ..."``) split
    on ``separator``. Labels are tried **longest-first**, so a multi-segment
    label like ``science/technology`` wins over a substring collision like
    ``science``.

    Robust to:

    - ``<think>...</think>`` reasoning leak -- typically chain
      ``strip_think_recover`` first to drop the reasoning block, then run this
      filter on the answer tail.
    - Harmony / channel-marker leak (e.g. a leaked ``<|channel|>`` header
      followed by the actual label) -- the trailing label is still found.
    - "The answer is X" / "Final: X" patterns -- the LAST occurrence wins.

    If the doc field is missing or empty, or no label occurs in the response,
    the result is ``fallback`` (default ``"[invalid]"``), so the row counts as
    wrong in F1.
    """

    def __init__(
        self,
        labels_field: str = "labels_str",
        separator: str = ",",
        fallback: str = "[invalid]",
    ) -> None:
        self.labels_field = labels_field
        self.separator = separator
        self.fallback = fallback

    def _label_pattern(self, doc):
        """Compile an alternation over the doc's labels, or ``None`` if it has none."""
        field_value = (doc or {}).get(self.labels_field, "")
        unique_labels = set()
        for piece in str(field_value or "").split(self.separator):
            label = piece.strip()
            if label:
                unique_labels.add(label)
        if not unique_labels:
            return None

        # Longest-first, so e.g. "science/technology" is tried before the
        # bare "science" substring inside it.
        ordered = sorted(unique_labels, key=len, reverse=True)
        return re.compile("(" + "|".join(map(re.escape, ordered)) + ")")

    def _last_label(self, pattern, resp) -> str:
        """Return the last label occurrence in one response, or ``fallback``."""
        if pattern is None:
            return self.fallback
        # Non-string responses are treated as empty text.
        found = pattern.findall(resp if isinstance(resp, str) else "")
        return found[-1].strip() if found else self.fallback

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Pick one label per response, pairing each response set with its doc."""
        picked_per_doc: list[list[str]] = []
        for resp_set, doc in zip(resps, docs):
            pattern = self._label_pattern(doc)
            picked_per_doc.append(
                [self._last_label(pattern, resp) for resp in resp_set]
            )
        return picked_per_doc
class StripChannelHeaderFilter(Filter):
    """Remove a leaked Harmony-style channel/message header from a response.

    Some providers (notably deepinfra fp8 Gemma) leak chat template tokens
    like ``<|channel|>final<|message|>`` -- or partial fragments such as
    ``s.<channel|>`` (where ``s.`` is the tail of the previous token) -- into
    the assistant ``content``. For open-text tasks (summarization / QA / open
    generation) this hurts every metric (ROUGE / BLEU / SAS-encoder / LLM
    judge), because the garbage prefix lowers the score even when the answer
    that follows is correct.

    Strategy: anchored at the start of the response, match up to
    ``max_prefix_chars`` (default 80) characters of any text followed by a
    ``<channel|>`` / ``<|channel|>`` marker, optionally followed by a short
    run of text (such as ``final``) and a ``<|message|>`` / ``<message|>``
    marker. Everything matched is dropped. Responses with no marker near the
    start are returned unchanged.

    Order tip: chain after ``strip_think_recover`` so reasoning is dropped
    first and this filter operates on the answer tail only.
    """

    def __init__(self, max_prefix_chars: int = 80) -> None:
        self.max_prefix_chars = int(max_prefix_chars)
        # ^                               - anchored at the start
        # .{0,N}?                         - up to N chars of garbage prefix (non-greedy)
        # <\|?channel\|?>                 - both <channel|> and <|channel|>
        # (?:[^<]{0,40}<\|?message\|?>)?  - optional "<up to 40 chars><|message|>" tail
        # \s*                             - eat trailing whitespace
        # DOTALL lets the garbage prefix span line breaks.
        prefix_and_marker = (
            r"^.{0," + str(self.max_prefix_chars) + r"}?<\|?channel\|?>"
        )
        optional_message_tail = r"(?:[^<]{0,40}<\|?message\|?>)?\s*"
        self._pattern = re.compile(prefix_and_marker + optional_message_tail, re.DOTALL)

    def _strip_header(self, resp) -> str:
        """Return one response with any leading channel header removed."""
        # Non-string responses are treated as empty text.
        text = resp if isinstance(resp, str) else ""
        # The pattern is anchored at the start, so matching at position 0 and
        # slicing off the match is all that is needed.
        header = self._pattern.match(text)
        if header is None:
            return text
        return text[header.end():]

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Strip the header from every response; ``docs`` is not used."""
        return [[self._strip_header(resp) for resp in resp_set] for resp_set in resps]
def process_results_f1(doc, results, *, gold_key="target"):
    """Return ``{"f1_macro": (pred, gold)}`` for macro-F1 aggregation.

    ``pred`` is the first entry of ``results`` after whitespace stripping and
    removal of any ``<think>...</think>`` wrapper. Full reasoning is logged as
    ``reasoning_content`` when using ``run_eval.py``.

    Args:
        doc: The dataset row; the gold label is read from ``doc[gold_key]``.
        results: Model outputs for this document; only ``results[0]`` is used.
        gold_key: Field that holds the gold label. Most tasks use
            ``"target"``; override for tasks that store it elsewhere.

    Returns:
        A dict with the single key ``"f1_macro"`` mapped to ``(pred, gold)``,
        both strings.

    Edge cases: empty ``results`` or a non-string first result gives an empty
    ``pred`` (the filters in this module treat non-string responses the same
    way). A missing or ``None`` gold value gives an empty ``gold``; any other
    non-string gold value (e.g. an integer class id) is converted with
    ``str`` instead of raising ``AttributeError``.
    """
    first = results[0] if results else ""
    raw_response = first.strip() if isinstance(first, str) else ""
    pred = _strip_think_tags(raw_response)

    gold_value = doc.get(gold_key, "")
    gold = "" if gold_value is None else str(gold_value).strip()
    return {"f1_macro": (pred, gold)}