# Page-scrape artifacts from the original hosting page, preserved as comments
# so the module parses:
# AsadIsmail's picture
# Bundle ternary_quant package directly (private repo fix)
# 162f86a verified
"""Calibration data loading utilities."""
import torch
from typing import Optional
from transformers import AutoTokenizer
def get_calibration_data(
    model_name_or_path: str,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    split: str = "train",
    n_samples: int = 128,
    seq_len: int = 2048,
    seed: int = 42,
    tokenizer: Optional[AutoTokenizer] = None,
) -> torch.Tensor:
    """
    Load and tokenize calibration data.

    Concatenates every document in the dataset split into one text stream,
    tokenizes it once, then draws ``n_samples`` random contiguous windows of
    ``seq_len`` tokens.

    Args:
        model_name_or_path: HF model id/path used to load a tokenizer when
            ``tokenizer`` is not supplied.
        dataset_name: HF datasets repository name.
        dataset_config: Dataset configuration (e.g. ``wikitext-2-raw-v1``).
        split: Dataset split to draw text from.
        n_samples: Number of calibration sequences to return.
        seq_len: Length, in tokens, of each sequence.
        seed: Seed for the sampling RNG; fixes which windows are chosen.
        tokenizer: Optional pre-loaded tokenizer; loaded from
            ``model_name_or_path`` when ``None``.

    Returns:
        Tensor of token IDs with shape ``[n_samples, seq_len]``.

    Raises:
        ValueError: If the tokenized corpus is shorter than ``seq_len``,
            so no full window can be sampled.
    """
    from datasets import load_dataset

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    dataset = load_dataset(dataset_name, dataset_config, split=split)
    # Concatenate all text into one stream so windows may span documents.
    text = "\n\n".join(dataset["text"])
    tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]
    total_len = tokens.shape[0]
    if total_len < seq_len:
        raise ValueError(
            f"Tokenized corpus has only {total_len} tokens but seq_len={seq_len}. "
            f"Reduce seq_len or use a larger dataset."
        )
    # Sample random contiguous chunks. randint's upper bound is exclusive,
    # so +1 makes the final valid start (total_len - seq_len) reachable and
    # avoids randint(0, 0) crashing when total_len == seq_len.
    rng = torch.Generator()
    rng.manual_seed(seed)
    samples = []
    for _ in range(n_samples):
        start = torch.randint(0, total_len - seq_len + 1, (1,), generator=rng).item()
        samples.append(tokens[start : start + seq_len])
    return torch.stack(samples)  # [n_samples, seq_len]
def get_calibration_data_c4(
    model_name_or_path: str,
    n_samples: int = 128,
    seq_len: int = 2048,
    seed: int = 42,
    tokenizer: Optional[AutoTokenizer] = None,
) -> torch.Tensor:
    """
    Load calibration data from the C4 dataset (alternative to WikiText).

    Streams ``allenai/c4`` (English) and keeps the first ``n_samples``
    documents that tokenize to at least ``seq_len`` tokens, truncating
    each to exactly ``seq_len``.

    Args:
        model_name_or_path: HF model id/path used to load a tokenizer when
            ``tokenizer`` is not supplied.
        n_samples: Number of calibration sequences to return.
        seq_len: Length, in tokens, of each sequence.
        seed: Currently unused — document selection is deterministic (first
            qualifying documents in stream order). Kept for interface parity
            with ``get_calibration_data``.
        tokenizer: Optional pre-loaded tokenizer; loaded from
            ``model_name_or_path`` when ``None``.

    Returns:
        Tensor of token IDs with shape ``[n_samples, seq_len]``.

    Raises:
        ValueError: If fewer than ``n_samples`` sufficiently long documents
            are found in the stream.
    """
    from datasets import load_dataset

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Streaming avoids downloading the full (very large) C4 corpus.
    dataset = load_dataset(
        "allenai/c4",
        "en",
        split="train",
        streaming=True,
    )
    samples = []
    for sample in dataset:
        if len(samples) >= n_samples:
            break
        # Truncate at tokenization time so overly long documents are cheap;
        # with truncation on, tokens can never exceed seq_len.
        tokens = tokenizer(
            sample["text"], return_tensors="pt", truncation=True, max_length=seq_len
        )["input_ids"][0]
        if tokens.shape[0] >= seq_len:
            samples.append(tokens[:seq_len])
    if len(samples) < n_samples:
        raise ValueError(
            f"Only found {len(samples)} samples with seq_len >= {seq_len}. "
            f"Try reducing seq_len or using a different dataset."
        )
    return torch.stack(samples)