leideng/QCFuse / data / build_longbench_data.py
leideng's picture
download
raw
8.24 kB
#!/usr/bin/env python3
"""Build Blend-format LongBench JSONL files for musique, 2wikimqa, hotpotqa."""
from __future__ import annotations
import argparse
import hashlib
import json
from pathlib import Path
from typing import Any
# Version tag written into the metadata file so generated outputs can be traced
# back to this build recipe. The "top20-200" suffix appears to mirror the
# default --context_topk=20 / --max_samples=200 options.
SCRIPT_VERSION = "qcfuse-longbench-v1-top20-200"

# LongBench subsets this script accepts via --datasets.
DATASETS = (
    "musique",
    "2wikimqa",
    "hotpotqa",
)
def load_tokenizer(path: str) -> Any:
    """Load a Hugging Face tokenizer from ``path``.

    A fast tokenizer (``use_fast=True``) is requested first. If that attempt
    raises anything, loading is retried once with the library defaults.

    NOTE(review): ``trust_remote_code=True`` allows the checkpoint to run its
    own Python code; only point ``path`` at trusted models.
    """
    from transformers import AutoTokenizer

    shared_options = {"trust_remote_code": True}
    try:
        tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True, **shared_options)
    except Exception:
        # Presumably for checkpoints without a usable fast tokenizer: retry
        # with the default loader instead of failing outright.
        tokenizer = AutoTokenizer.from_pretrained(path, **shared_options)
    return tokenizer
def normalize_context(value: Any) -> str:
    """Coerce a raw ``context`` field into a single string.

    - ``None`` becomes ``""``.
    - A list has each element passed through ``str`` and is joined with blank
      lines (``"\\n\\n"``).
    - Any other value is simply ``str()``-ed.
    """
    if value is None:
        return ""
    if not isinstance(value, list):
        return str(value)
    return "\n\n".join(map(str, value))
def normalize_answers(value: Any) -> list[str]:
    """Coerce a raw ``answers`` field into a list of strings.

    - A list is converted element-wise with ``str``.
    - ``None`` yields an empty list.
    - Any other single value is wrapped as a one-element list.
    """
    if isinstance(value, list):
        return list(map(str, value))
    return [] if value is None else [str(value)]
def make_splitter(tokenizer: Any, chunk_size: int, chunk_overlap: int) -> Any:
    """Create a LangChain ``RecursiveCharacterTextSplitter`` measured in tokens.

    ``chunk_size`` and ``chunk_overlap`` are in units of ``tokenizer`` tokens,
    because the splitter's ``length_function`` counts the ids returned by
    ``tokenizer.encode(text, add_special_tokens=False)``.

    The standalone ``langchain_text_splitters`` package is tried first. If it
    is missing, the legacy ``langchain.text_splitter`` import path is used.
    """
    try:
        from langchain_text_splitters import RecursiveCharacterTextSplitter
    except ImportError:  # pragma: no cover - compatibility with older langchain.
        from langchain.text_splitter import RecursiveCharacterTextSplitter

    def count_tokens(text: str) -> int:
        """Return the number of tokenizer tokens in ``text``, excluding special tokens."""
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        return len(token_ids)

    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "length_function": count_tokens,
    }
    return RecursiveCharacterTextSplitter(**splitter_options)
def split_context(splitter: Any, context: str) -> list[str]:
    """Split ``context`` into chunks with ``splitter.split_text``.

    A whitespace-only context is never handed to the splitter. Whenever no
    chunks result, ``[""]`` is returned so every sample keeps at least one
    (empty) chunk.
    """
    pieces = splitter.split_text(context) if context.strip() else []
    return pieces if pieces else [""]
def sort_by_similarity(
    model: Any,
    query: str,
    chunks: list[str],
    batch_size: int,
) -> tuple[list[str], list[float]]:
    """Order ``chunks`` by similarity to ``query``, most similar first.

    The query and chunks are embedded with ``model.encode(...,
    normalize_embeddings=True)``. Similarity is their dot product, which is
    cosine similarity assuming ``encode`` honours that flag.

    Chunks with equal scores keep their original (document) order. This makes
    the output deterministic and consistent with the blank-query path, which
    also preserves document order.

    Args:
        model: Object with a SentenceTransformer-style ``encode`` method.
        query: Text to rank against. A blank query skips embedding entirely.
        chunks: Candidate passages.
        batch_size: Batch size forwarded to ``model.encode`` for the chunks.

    Returns:
        ``(sorted_chunks, scores)``, where ``scores[i]`` is the Python-float
        similarity of ``sorted_chunks[i]``. Empty ``chunks`` yields
        ``([], [])``. A blank query returns ``chunks`` unchanged with
        all-zero scores.
    """
    import numpy as np

    if not chunks:
        return [], []
    if not query.strip():
        return chunks, [0.0] * len(chunks)
    query_embedding = model.encode(query, normalize_embeddings=True, show_progress_bar=False)
    chunk_embeddings = model.encode(
        chunks,
        normalize_embeddings=True,
        batch_size=batch_size,
        show_progress_bar=False,
    )
    scores = np.dot(np.asarray(chunk_embeddings), np.asarray(query_embedding))
    # Stable sort on negated scores gives descending order with ties kept in
    # input order. argsort(...)[::-1] used an unstable sort and then reversed
    # the order of tied chunks.
    order = np.argsort(-scores, kind="stable")
    return [chunks[i] for i in order], [float(scores[i]) for i in order]
def limit_context(
    chunks: list[str],
    scores: list[float],
    context_topk: int,
) -> tuple[list[str], list[float]]:
    """Keep only the first ``context_topk`` chunks and their scores.

    A non-positive ``context_topk`` disables truncation, and the input lists
    are returned as-is. Otherwise both lists are cut to
    ``min(len(chunks), context_topk)`` entries.
    """
    if context_topk > 0:
        keep = min(context_topk, len(chunks))
        chunks, scores = chunks[:keep], scores[:keep]
    return chunks, scores
def file_sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in binary mode in 1 MiB blocks, so large inputs are
    hashed without being loaded into memory at once.
    """
    hasher = hashlib.sha256()
    block_size = 1 << 20  # 1 MiB
    with path.open("rb") as stream:
        while True:
            block = stream.read(block_size)
            if not block:
                break
            hasher.update(block)
    return hasher.hexdigest()
def _parse_record(line: str, in_path: Path, line_no: int) -> dict[str, Any]:
    """Decode one JSONL line and require it to be a JSON object.

    Raises:
        ValueError: if the line is not valid JSON or not a JSON object.
    """
    try:
        raw = json.loads(line)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Bad JSON in {in_path}:{line_no}: {exc}") from exc
    if not isinstance(raw, dict):
        raise ValueError(
            f"Expected a JSON object in {in_path}:{line_no}, got {type(raw).__name__}"
        )
    return raw


def _convert_record(
    raw: dict[str, Any],
    splitter: Any,
    embedding_model: Any,
    batch_size: int,
    context_topk: int,
) -> dict[str, Any]:
    """Chunk, rank, and truncate one LongBench record into a Blend-format item."""
    raw_query = raw.get("input")
    # A null ``input`` would otherwise become the literal text "None" and be embedded.
    query = "" if raw_query is None else str(raw_query)
    chunks = split_context(splitter, normalize_context(raw.get("context", "")))
    chunks, scores = sort_by_similarity(embedding_model, query, chunks, batch_size)
    chunks, scores = limit_context(chunks, scores, context_topk)
    return {
        "input": query,
        "context": chunks,
        "answers": normalize_answers(raw.get("answers", [])),
        "num_chunks": len(chunks),
        "similarity_scores": scores,
    }


def convert_file(
    in_path: Path,
    out_path: Path,
    tokenizer: Any,
    embedding_model: Any,
    chunk_size: int,
    chunk_overlap: int,
    batch_size: int,
    max_samples: int,
    context_topk: int,
) -> int:
    """Convert one LongBench JSONL file into Blend-format JSONL.

    Each non-blank input line is turned into one output line with keys
    ``input``, ``context`` (ranked, top-k chunks), ``answers``, ``num_chunks``
    and ``similarity_scores``. Conversion stops after ``max_samples`` records.

    Output is first written to ``<out_path>.tmp`` and moved over ``out_path``
    only after every record succeeded. A failure part-way through therefore
    leaves neither a truncated file nor a clobbered previous output.

    Returns:
        The number of records written.

    Raises:
        ValueError: if a non-blank input line is not valid JSON or is not a
            JSON object.
    """
    splitter = make_splitter(tokenizer, chunk_size, chunk_overlap)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    count = 0
    try:
        with in_path.open("r", encoding="utf-8") as fin, tmp_path.open("w", encoding="utf-8") as fout:
            for line_no, line in enumerate(fin, 1):
                if count >= max_samples:
                    break
                if not line.strip():
                    continue
                raw = _parse_record(line, in_path, line_no)
                item = _convert_record(raw, splitter, embedding_model, batch_size, context_topk)
                fout.write(json.dumps(item, ensure_ascii=False) + "\n")
                count += 1
        # Only reached after the output file is fully written and closed.
        tmp_path.replace(out_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial temp file, then re-raise.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    return count
def write_metadata(
    output_dir: Path,
    args: argparse.Namespace,
    datasets: list[str],
    output_counts: dict[str, int],
    input_hashes: dict[str, str],
) -> None:
    """Write ``longbench_blend_metadata.json`` describing this build into ``output_dir``.

    The JSON records the script version, the requested datasets, every
    relevant CLI option, the SHA-256 of each input file and the number of
    samples written per dataset. It is pretty-printed with a trailing newline,
    and a confirmation line is printed afterwards.
    """
    # Keyword order here is the key order of the emitted JSON.
    metadata = dict(
        script_version=SCRIPT_VERSION,
        datasets=datasets,
        input_dir=str(args.input_dir),
        output_dir=str(args.output_dir),
        tokenizer_path=args.tokenizer_path,
        embedding_model=args.embedding_model,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        context_topk=args.context_topk,
        max_samples_per_dataset=args.max_samples,
        batch_size=args.batch_size,
        device=args.device,
        input_sha256=input_hashes,
        output_counts=output_counts,
    )
    path = output_dir / "longbench_blend_metadata.json"
    serialized = json.dumps(metadata, ensure_ascii=False, indent=2)
    path.write_text(f"{serialized}\n", encoding="utf-8")
    print(f"[meta] wrote {path}")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the LongBench to Blend conversion.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before. Passing an explicit list allows
            programmatic use and testing.

    Returns:
        The parsed namespace. Value-range checks are not performed here.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input_dir", type=Path, default=Path("raw_longbench"),
        help="directory containing <dataset>.jsonl LongBench files",
    )
    parser.add_argument(
        "--output_dir", type=Path, default=Path("final_data"),
        help="directory for converted JSONL files and the metadata file",
    )
    parser.add_argument(
        "--datasets", default=",".join(DATASETS),
        help="comma-separated dataset names to convert (default: %(default)s)",
    )
    parser.add_argument(
        "--tokenizer_path", required=True,
        help="tokenizer used to measure chunk lengths in tokens",
    )
    parser.add_argument(
        "--embedding_model", required=True,
        help="SentenceTransformer model name or path used to rank chunks",
    )
    parser.add_argument(
        "--chunk_size", type=int, default=512,
        help="maximum chunk length in tokenizer tokens",
    )
    parser.add_argument(
        "--chunk_overlap", type=int, default=50,
        help="overlap between consecutive chunks in tokenizer tokens",
    )
    parser.add_argument(
        "--context_topk", type=int, default=20,
        help="number of most query-similar chunks kept per sample",
    )
    parser.add_argument(
        "--max_samples", type=int, default=200,
        help="maximum number of samples written per dataset",
    )
    parser.add_argument(
        "--batch_size", type=int, default=128,
        help="batch size for chunk embedding",
    )
    parser.add_argument("--device", default="auto", help="auto, cpu, cuda, or cuda device id")
    return parser.parse_args(argv)
def _resolve_device(value: str) -> str | None:
    """Translate the --device option into a ``SentenceTransformer(device=...)`` value.

    ``"auto"`` maps to ``None`` so the library picks a device. A bare integer
    id such as ``"1"`` maps to ``"cuda:1"``, because torch device strings need
    a type prefix. Anything else is passed through unchanged.
    """
    if value == "auto":
        return None
    if value.isdigit():
        return f"cuda:{value}"
    return value


def main() -> None:
    """Convert the selected LongBench datasets and write a metadata file.

    Options are validated before the tokenizer and embedding model are loaded,
    so configuration mistakes fail fast. Exits via ``SystemExit`` on invalid
    options or when no sample was produced.
    """
    args = parse_args()
    # Deduplicate while keeping the user's order. A repeated name would
    # otherwise be converted twice and double-counted.
    datasets = list(
        dict.fromkeys(name.strip() for name in args.datasets.split(",") if name.strip())
    )
    if not datasets:
        raise SystemExit("No datasets selected. Pass a comma-separated --datasets list.")
    unknown = sorted(set(datasets) - set(DATASETS))
    if unknown:
        raise SystemExit(f"Unsupported datasets: {', '.join(unknown)}")
    if args.chunk_size <= 0 or args.chunk_overlap < 0 or args.context_topk <= 0 or args.max_samples <= 0:
        raise SystemExit(
            "--chunk_size, --context_topk, and --max_samples must be positive; "
            "--chunk_overlap must be non-negative"
        )
    # The text splitter rejects this too, but only after both models are loaded.
    if args.chunk_overlap > args.chunk_size:
        raise SystemExit("--chunk_overlap must not exceed --chunk_size")

    tokenizer = load_tokenizer(args.tokenizer_path)
    from sentence_transformers import SentenceTransformer

    embedding_model = SentenceTransformer(args.embedding_model, device=_resolve_device(args.device))

    processed = 0
    output_counts: dict[str, int] = {}
    input_hashes: dict[str, str] = {}
    for dataset in datasets:
        in_path = args.input_dir / f"{dataset}.jsonl"
        if not in_path.exists():
            print(f"[skip] missing {in_path}")
            continue
        input_hashes[dataset] = file_sha256(in_path)
        out_path = args.output_dir / f"{dataset}.jsonl"
        count = convert_file(
            in_path=in_path,
            out_path=out_path,
            tokenizer=tokenizer,
            embedding_model=embedding_model,
            chunk_size=args.chunk_size,
            chunk_overlap=args.chunk_overlap,
            batch_size=args.batch_size,
            max_samples=args.max_samples,
            context_topk=args.context_topk,
        )
        processed += count
        output_counts[dataset] = count
        print(f"[ok] {dataset}: {count} samples -> {out_path}")
    if processed == 0:
        raise SystemExit("No samples were processed. Check --input_dir and --datasets.")
    write_metadata(args.output_dir, args, datasets, output_counts, input_hashes)
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()

Xet Storage Details

Size:
8.24 kB
·
Xet hash:
427bc93389ce5e64f30f81beeead02b1f20d7a8f3967ddb756cb1ffb7ab288b0

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.