| """ | |
| LongBench / RULER utilities for Blend evaluation. | |
| Supports: hotpotqa, 2wikimqa, musique, ruler_vt, ruler_mq, ruler_mv | |
| Metrics : F1, string-match-all | |
| """ | |
| import re | |
| import json | |
| import string | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional, Any | |
| from collections import Counter | |
| from rouge_score import rouge_scorer as _rouge_scorer | |
| # ==================== Dataset Configuration ==================== | |
| RULER_DATASETS = ("ruler_vt", "ruler_mq", "ruler_mv") | |
| DATASETS = ("hotpotqa", "2wikimqa", "musique", *RULER_DATASETS) | |
| SYSTEM_PROMPT = ( | |
| "You are a highly precise question-answering assistant.\n\n" | |
| "## Task\n" | |
| "Read the provided passages and answer the user's question based strictly " | |
| "on the information within.\n\n" | |
| "## Output Rules\n" | |
| "- Direct Answer ONLY: Output nothing but the final exact answer.\n" | |
| "- No Explanations: Do not provide reasoning, context, conversational " | |
| "fillers, or extra words.\n\n" | |
| "## Passages\n" | |
| ) | |
| QUERY_PREFIX = ( | |
| "Remember to answer the question based strictly on the passages above. " | |
| "Output ONLY the answer and no other words.\n\n## Question\n" | |
| ) | |
| RULER_VT_FEWSHOT = ( | |
| "Example:\n" | |
| "Text:\n" | |
| "The maintenance report begins with routine notes about lighting, archived boxes, " | |
| "and a schedule change for the west corridor. A coordinator wrote that several " | |
| "labels had been moved after the weekly inspection, but most of the paragraph is " | |
| "ordinary filler that should not be treated as an assignment. Near the inventory " | |
| "table, the record states VAR ABCDE = 12345 before describing spare cables and " | |
| "a delayed delivery. Later, after a note about temperature readings, it says " | |
| "VAR FGHIJ = VAR ABCDE. The next page mentions visitor badges, old invoices, " | |
| "and two unrelated serial numbers. Hidden between those details is " | |
| "VAR KLMNO = VAR FGHIJ. The report then discusses a broken cart, a missing " | |
| "clipboard, and yesterday's storage request. After that, the chain continues: " | |
| "VAR PQRST = VAR KLMNO. The closing paragraph talks about cleaning supplies, " | |
| "meeting times, and duplicate copies of the same memo, then finally records " | |
| "VAR UVWXY = VAR PQRST. A later appendix lists department codes, desk numbers, " | |
| "and several dates from a training calendar. Those details are only background " | |
| "text, even when they contain digits or capitalized words. Another section says " | |
| "that the archive door was repaired, the finance folder was renamed, and the " | |
| "morning checklist should be reviewed before the next shift. The example also " | |
| "mentions that the copied forms were sorted by color, that the hallway map was " | |
| "reprinted, and that temporary notes should be discarded after confirmation. " | |
| "None of these surrounding statements changes the assignment chain above. " | |
| "No other variable in this example is assigned the " | |
| "value 12345 through this chain.\n" | |
| "Question: Find all variables that are assigned the value 12345 in the text above.\n" | |
| "Answer: ABCDE, FGHIJ, KLMNO, PQRST, UVWXY\n\n" | |
| ) | |
| RULER_SYS_INSTRUCT = { | |
| "ruler_mq": ( | |
| "Some special magic numbers are hidden within the following text. " | |
| "Make sure to memorize it. I will quiz you about the numbers afterwards.\n\n" | |
| "Return only the requested magic numbers. If there are multiple numbers, " | |
| "separate them with commas and do not explain.\n\n" | |
| "## Text\n" | |
| ), | |
| "ruler_mv": ( | |
| "Some special magic numbers are hidden within the following text. " | |
| "Make sure to memorize it. I will quiz you about the numbers afterwards.\n\n" | |
| "Return only the requested magic numbers. If there are multiple numbers, " | |
| "separate them with commas and do not explain.\n\n" | |
| "## Text\n" | |
| ), | |
| "ruler_vt": ( | |
| "Memorize and track the chain(s) of variable assignment hidden in the following text.\n\n" | |
| f"{RULER_VT_FEWSHOT}" | |
| "For the actual text below, return only the variable names assigned to the queried value. " | |
| "If there are multiple variable names, separate them with commas and do not explain.\n\n" | |
| "## Text\n" | |
| ), | |
| } | |
| RULER_QUERY_PREFIX = { | |
| "ruler_mq": ( | |
| "Answer the question using only the provided text. " | |
| "Return only the requested magic number or numbers, separated by commas.\n\n## Question\n" | |
| ), | |
| "ruler_mv": ( | |
| "Answer the question using only the provided text. " | |
| "Return only the requested magic number or numbers, separated by commas.\n\n## Question\n" | |
| ), | |
| "ruler_vt": ( | |
| "Answer the question using only the provided text. " | |
| "Return only the requested variable name or names, separated by commas.\n\n## Question\n" | |
| ), | |
| } | |
| MAX_NEW_TOKENS = 48 | |
| MAX_NEW_TOKENS_BY_DATASET = { | |
| "ruler_vt": 30, | |
| "ruler_mq": 128, | |
| "ruler_mv": 128, | |
| } | |
| DEFAULT_CHUNK_TOPK = 20 | |
| # ==================== Data Loading ==================== | |
| def load_dataset(dataset_path: str) -> List[Dict]: | |
| """Load a JSONL dataset file.""" | |
| path = Path(dataset_path) | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Dataset file not found: {dataset_path}") | |
| with open(path, "r", encoding="utf-8") as f: | |
| return [json.loads(line) for line in f if line.strip()] | |
| # ==================== Prompt Building ==================== | |
| def normalize_question(question: str) -> str: | |
| """Lowercase first letter and ensure trailing '?'.""" | |
| if not question: | |
| return "" | |
| if not question.endswith("?"): | |
| question += "?" | |
| return question[0].lower() + question[1:] | |
| def build_prompt_for_dataset( | |
| example: Dict, dataset_name: str | |
| ) -> Tuple[List[str], List[str]]: | |
| """Build document list and question prompt. | |
| Returns: | |
| docs: list of passage strings | |
| q_prompt: [query_prefix, input_text] | |
| """ | |
| context = example.get("context", "") | |
| if dataset_name not in RULER_DATASETS: | |
| context = context[: min(len(context), DEFAULT_CHUNK_TOPK)] | |
| docs = [f"Passage:\n{ctx}\n\n" for ctx in context] | |
| input_text = example.get("input", "") | |
| if dataset_name in RULER_DATASETS: | |
| return docs, [RULER_QUERY_PREFIX[dataset_name], input_text] | |
| return docs, [QUERY_PREFIX, normalize_question(input_text)] | |
| # ==================== Scoring Functions ==================== | |
| def _normalize_answer(s: str) -> str: | |
| """Lower text and remove punctuation, articles, extra whitespace.""" | |
| s = s.lower() | |
| s = re.sub(r"\b(a|an|the)\b", " ", s) | |
| s = "".join(ch for ch in s if ch not in set(string.punctuation)) | |
| return " ".join(s.split()) | |
| def _parse_generation(s: str) -> str: | |
| """Take the first non-empty line from the generation.""" | |
| s = s.lstrip("\n").strip() | |
| if not s: | |
| return "" | |
| first_line = s.split("\n")[0].strip() | |
| if first_line.lower().startswith("yes"): | |
| return "Yes" | |
| if first_line.split()[0].lower().startswith("no"): | |
| return "No" | |
| return first_line | |
| def scorer_f1( | |
| prediction: str, ground_truth: str, tokenizer: Optional[Any] = None | |
| ) -> float: | |
| """Token-level F1 score (word-level or sub-word-level with tokenizer).""" | |
| if tokenizer is None: | |
| pred_toks = _normalize_answer(prediction).split() | |
| gold_toks = _normalize_answer(ground_truth).split() | |
| else: | |
| prediction = _parse_generation(prediction) | |
| pred_toks = tokenizer.encode(_normalize_answer(prediction))[1:] | |
| gold_toks = tokenizer.encode(_normalize_answer(ground_truth))[1:] | |
| if not pred_toks or not gold_toks: | |
| return float(int(pred_toks == gold_toks)) if tokenizer else 0.0 | |
| common = Counter(pred_toks) & Counter(gold_toks) | |
| num_same = sum(common.values()) | |
| if num_same == 0: | |
| return 0.0 | |
| precision = num_same / len(pred_toks) | |
| recall = num_same / len(gold_toks) | |
| return 2 * precision * recall / (precision + recall) | |
| def scorer_rouge(prediction: str, ground_truth: str) -> float: | |
| """ROUGE-L F-measure.""" | |
| if not prediction.strip() or not ground_truth.strip(): | |
| return 0.0 | |
| scorer = _rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) | |
| scores = scorer.score(ground_truth, prediction) | |
| return scores["rougeL"].fmeasure | |
| def scorer_string_match_all(prediction: str, ground_truths: List[str]) -> float: | |
| """RULER-style score: fraction of references found in the prediction.""" | |
| if not ground_truths: | |
| return 0.0 | |
| prediction = prediction.lower() | |
| hits = sum(1.0 if str(truth).lower() in prediction else 0.0 for truth in ground_truths) | |
| return hits / len(ground_truths) | |
| # ==================== Unified Evaluation ==================== | |
| # metric_type -> scorer callable | |
| _METRIC_SCORERS = { | |
| "f1": scorer_f1, | |
| "rouge": scorer_rouge, | |
| } | |
| # dataset -> metric_type | |
| TASK_METRICS = { | |
| "hotpotqa": "f1", | |
| "2wikimqa": "f1", | |
| "musique": "f1", | |
| "ruler_vt": "string_match_all", | |
| "ruler_mq": "string_match_all", | |
| "ruler_mv": "string_match_all", | |
| } | |
| METRIC_DISPLAY = { | |
| "f1": "F1", | |
| "rouge": "ROUGE-L", | |
| "string_match_all": "StringMatchAll", | |
| } | |
| def evaluate_sample( | |
| prediction: str, | |
| ground_truths: List[str], | |
| dataset_name: str, | |
| tokenizer: Optional[Any] = None, | |
| ) -> float: | |
| """Evaluate prediction against ground truths; returns max score.""" | |
| if not ground_truths: | |
| return 0.0 | |
| metric_type = TASK_METRICS.get(dataset_name, "f1") | |
| if metric_type == "string_match_all": | |
| return scorer_string_match_all(prediction, ground_truths) | |
| scorer_fn = _METRIC_SCORERS[metric_type] | |
| best = 0.0 | |
| for truth in ground_truths: | |
| if metric_type == "f1": | |
| score = scorer_fn(prediction, truth, tokenizer) | |
| else: | |
| score = scorer_fn(prediction, truth) | |
| if score > best: | |
| best = score | |
| return best | |
| # ==================== Simple Accessors ==================== | |
| def get_system_prompt(_dataset_name: str) -> str: | |
| return RULER_SYS_INSTRUCT.get(_dataset_name, SYSTEM_PROMPT) | |
| def get_max_new_tokens(_dataset_name: str) -> int: | |
| return MAX_NEW_TOKENS_BY_DATASET.get(_dataset_name, MAX_NEW_TOKENS) | |
| def get_metric_name(dataset_name: str) -> str: | |
| metric_type = TASK_METRICS.get(dataset_name, "f1") | |
| return METRIC_DISPLAY.get(metric_type, "Score") | |
Xet Storage Details
- Size:
- 10.3 kB
- Xet hash:
- ef7a5cf7f7378aeeff2bb3e88fcc8da64fd9f29ad5e6f6b92284d235416555fa
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.