leideng/QCFuse / test /simple_eval_longbench_v2.py
leideng's picture
download
raw
12.2 kB
# Adapted from https://github.com/openai/simple-evals/
"""
LongBench v2: Towards Deeper Understanding and Reasoning on Realistic Long-Context Multitasks
Yushi Bai, Shangqing Tu, Jiajie Zhang, Hao Peng, Xiaozhi Wang, Xin Lv, Shulin Cao, Jiazheng Xu, Lei Hou, Yuxiao Dong, Jie Tang, Juanzi Li
https://arxiv.org/abs/2412.15204
"""
import csv
import json
import os
import re
from typing import Any, Dict, List, Optional
from transformers import AutoTokenizer
from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
ANSWER_PATTERN_MULTICHOICE,
HTML_JINJA,
Eval,
EvalResult,
SamplerBase,
SingleEvalResult,
)
# LongBench-v2 task categories
TASK_CATEGORIES = {
"single_document_qa",
"multi_document_qa",
"long_in_context_learning",
"long_dialogue_history",
"code_repo_understanding",
"long_structured_data",
}
DEFAULT_DATASET = "THUDM/LongBench-v2"
DEFAULT_DATASET_SPLIT = "train"
def format_longbench_v2_question(row: dict) -> str:
"""Format a LongBench-v2 question using the official template."""
context = row.get("context", "")
question = row.get("question", "")
# Handle both standard format (A, B, C, D) and alternative format (choices list)
if "choices" in row:
choices = row["choices"]
choice_A = choices[0] if len(choices) > 0 else ""
choice_B = choices[1] if len(choices) > 1 else ""
choice_C = choices[2] if len(choices) > 2 else ""
choice_D = choices[3] if len(choices) > 3 else ""
else:
choice_A = row.get("A", row.get("choice_A", ""))
choice_B = row.get("B", row.get("choice_B", ""))
choice_C = row.get("C", row.get("choice_C", ""))
choice_D = row.get("D", row.get("choice_D", ""))
# Official LongBench-v2 template
prompt = f"""
Please read the following text and answer the question below.
<text>
{context.strip()}
</text>
What is the correct answer to this question: {question.strip()}
Choices:
(A) {choice_A.strip()}
(B) {choice_B.strip()}
(C) {choice_C.strip()}
(D) {choice_D.strip()}
Format your response as follows: "The correct answer is (insert answer here)"."""
return prompt
def extract_longbench_v2_answer(response: str) -> Optional[str]:
"""Extract answer from model response using official LongBench-v2 method."""
response = response.replace("*", "")
# First try: "The correct answer is (A)"
match = re.search(r"The correct answer is \(([A-D])\)", response, re.IGNORECASE)
if match:
return match.group(1).upper()
# Second try: "The correct answer is A"
match = re.search(r"The correct answer is ([A-D])", response, re.IGNORECASE)
if match:
return match.group(1).upper()
# Fallback: Standard SGLang multichoice pattern
match = re.search(ANSWER_PATTERN_MULTICHOICE, response)
if match:
return match.group(1).upper()
# Generic fallback when model says "answer is A"
match = re.search(r"answer\s+is\s*\(?([A-D])\)?", response, re.IGNORECASE)
if match:
return match.group(1).upper()
return None
class LongBenchV2Eval(Eval):
"""
Evaluation utility for LongBench-v2 dataset.
LongBench-v2 is designed to assess the ability of LLMs to handle long-context problems
requiring deep understanding and reasoning across real-world multitasks.
"""
def __init__(
self,
model: str = None,
data_source: str = DEFAULT_DATASET,
num_examples: Optional[int] = None,
num_threads: int = 1,
n_repeats: int = 1,
categories: Optional[List[str]] = None,
max_context_length: Optional[int] = None,
min_context_length: Optional[int] = None,
):
"""
Initialize LongBench-v2 evaluation.
Args:
data_source: HuggingFace dataset name, local file path (CSV/JSON)
num_examples: Number of examples to evaluate (None for all)
num_threads: Number of threads for parallel processing
n_repeats: Number of times to repeat evaluation for error bars
categories: List of task categories to include (None for all)
max_context_length: Maximum context length in characters
min_context_length: Minimum context length in characters
"""
self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
self.min_context_length = min_context_length
self.max_context_length = max_context_length
# Load dataset based on data source type
examples = self._load_dataset(data_source)
# Apply filtering
if categories:
examples = [ex for ex in examples if ex.get("category") in categories]
# Sample examples if specified
if num_examples:
assert n_repeats == 1, "n_repeats only supported when not sampling examples"
examples = examples[: min(num_examples, len(examples))]
# Repeat examples for multiple runs
examples = examples * n_repeats
if not examples:
raise ValueError(
"No examples available for LongBench-v2 evaluation after filtering"
)
self.examples = examples
self.n_repeats = n_repeats
self.num_threads = num_threads
print(f"Loaded {len(self.examples)} examples from LongBench-v2")
if categories:
print(f"Filtered to categories: {categories}")
if min_context_length or max_context_length:
print(
f"Context length filter: {min_context_length}-{max_context_length} characters"
)
def _load_dataset(self, data_source: str) -> List[Dict[str, Any]]:
"""Load dataset from HuggingFace hub or local files."""
if not data_source:
data_source = DEFAULT_DATASET
if os.path.exists(data_source):
raw_examples = self._load_local_file(data_source)
else:
raw_examples = self._load_hf_dataset(data_source)
return [self._normalize_example(example) for example in raw_examples]
def _load_local_file(self, path: str) -> List[Dict[str, Any]]:
"""Load examples from a local CSV/JSON/JSONL file."""
suffix = os.path.splitext(path)[1].lower()
if suffix in {".json", ".jsonl"}:
with open(path, "r", encoding="utf-8") as fh:
if suffix == ".jsonl":
data = [json.loads(line) for line in fh if line.strip()]
else:
data = json.load(fh)
elif suffix == ".csv":
with open(path, "r", encoding="utf-8") as fh:
reader = csv.DictReader(fh)
data = list(reader)
else:
# Try JSON, then CSV as fallback
try:
with open(path, "r", encoding="utf-8") as fh:
data = json.load(fh)
except json.JSONDecodeError:
with open(path, "r", encoding="utf-8") as fh:
reader = csv.DictReader(fh)
data = list(reader)
if isinstance(data, dict):
data = data.get("data", [])
if not isinstance(data, list):
raise ValueError("Expected list of examples from local file")
return data
def _load_hf_dataset(self, identifier: str) -> List[Dict[str, Any]]:
"""Load the dataset from HuggingFace Hub."""
parts = identifier.split(":", maxsplit=1)
dataset_name = parts[0]
split = parts[1] if len(parts) == 2 else DEFAULT_DATASET_SPLIT
try:
from datasets import load_dataset # type: ignore
except ImportError as exc:
raise ImportError(
"Please install the 'datasets' package to load LongBench-v2 from HuggingFace: pip install datasets"
) from exc
dataset = load_dataset(dataset_name, split=split)
return [dict(row) for row in dataset]
def _normalize_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure each example exposes the expected keys."""
normalized = dict(example)
for letter in ["A", "B", "C", "D"]:
choice_key = f"choice_{letter}"
if letter not in normalized and choice_key in normalized:
normalized[letter] = normalized[choice_key]
if "category" not in normalized and "domain" in normalized:
normalized["category"] = normalized["domain"]
answer = normalized.get("answer")
if isinstance(answer, str):
normalized["answer"] = answer.strip().upper()
elif isinstance(answer, int) and 0 <= answer < 4:
normalized["answer"] = ["A", "B", "C", "D"][answer]
return normalized
def _check_context_length(
self,
formatted_question: str,
tokenizer: AutoTokenizer,
min_length: Optional[int],
max_length: Optional[int],
) -> bool:
"""Filter examples by context length measured in characters."""
input_ids = tokenizer.encode(formatted_question)
context_length = len(input_ids)
if min_length is not None and context_length < min_length:
return False
if max_length is not None and context_length > max_length:
return False
return True
def __call__(self, sampler: SamplerBase) -> EvalResult:
"""Run the evaluation."""
def fn(row: dict):
# Format the question using official template
formatted_question = format_longbench_v2_question(row)
if self.min_context_length or self.max_context_length:
if not self._check_context_length(
formatted_question,
self.tokenizer,
self.min_context_length,
self.max_context_length,
):
# Skip this example
return None
prompt_messages = [
sampler._pack_message(content=formatted_question, role="user")
]
# Get model response
response_text = sampler(prompt_messages)
if response_text is None:
response_text = ""
# Extract answer using official method
extracted_answer = extract_longbench_v2_answer(response_text)
# Get correct answer
correct_answer = row.get("answer", "")
if isinstance(correct_answer, str):
correct_answer = correct_answer.strip().upper()
elif isinstance(correct_answer, int) and 0 <= correct_answer < 4:
correct_answer = ["A", "B", "C", "D"][correct_answer]
# Calculate score
score = 1.0 if extracted_answer == correct_answer else 0.0
# Generate HTML report
html = common.jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
correct_answer=correct_answer,
extracted_answer=extracted_answer,
)
# Build conversation
convo = prompt_messages + [dict(content=response_text, role="assistant")]
# Prepare metrics
metrics = {"chars": len(response_text)}
# Add category-specific metrics
category = row.get("category", row.get("domain", "unknown"))
if category in TASK_CATEGORIES:
metrics[category] = score
difficulty = row.get("difficulty")
if isinstance(difficulty, str) and difficulty:
metrics[f"difficulty_{difficulty.lower()}"] = score
return SingleEvalResult(
html=html,
score=score,
convo=convo,
metrics=metrics,
)
# Run evaluation with progress tracking
results = common.map_with_progress(fn, self.examples, self.num_threads)
return common.aggregate_results(results)

Xet Storage Details

Size:
12.2 kB
·
Xet hash:
f1e8e3b23a0af726f78549af0e66929407bc45678a83f2314be6befbb3b4f6a6

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.