"""Calibration data loading utilities."""

import torch
from typing import Optional

from transformers import AutoTokenizer
def get_calibration_data(
    model_name_or_path: str,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    split: str = "train",
    n_samples: int = 128,
    seq_len: int = 2048,
    seed: int = 42,
    tokenizer: Optional[AutoTokenizer] = None,
) -> torch.Tensor:
    """
    Load and tokenize calibration data.

    Concatenates every document in the requested dataset split, tokenizes the
    result once, and draws ``n_samples`` random contiguous windows of
    ``seq_len`` tokens.

    Args:
        model_name_or_path: Model ID or path used to load a tokenizer when
            ``tokenizer`` is not supplied.
        dataset_name: Hugging Face hub dataset name.
        dataset_config: Dataset configuration name.
        split: Dataset split to draw text from.
        n_samples: Number of calibration sequences to return.
        seq_len: Length, in tokens, of each calibration sequence.
        seed: Seed for the sampling RNG so the selection is reproducible.
        tokenizer: Optional pre-loaded tokenizer; loaded from
            ``model_name_or_path`` when ``None``.

    Returns:
        Tensor of token IDs with shape ``[n_samples, seq_len]``.

    Raises:
        ValueError: If the tokenized corpus is shorter than ``seq_len``.
    """
    from datasets import load_dataset

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    dataset = load_dataset(dataset_name, dataset_config, split=split)

    # Concatenate all text so sampled windows may span document boundaries.
    text = "\n\n".join(dataset["text"])
    tokens = tokenizer(text, return_tensors="pt", truncation=False)["input_ids"][0]

    total_len = tokens.shape[0]
    if total_len < seq_len:
        raise ValueError(
            f"Tokenized corpus has only {total_len} tokens, fewer than "
            f"seq_len={seq_len}; reduce seq_len or use a larger dataset."
        )

    # Sample random contiguous chunks reproducibly.
    rng = torch.Generator()
    rng.manual_seed(seed)
    samples = []
    for _ in range(n_samples):
        # randint's high bound is exclusive; the +1 makes the final valid
        # window start (total_len - seq_len) reachable and also handles the
        # total_len == seq_len edge case (single valid window).
        start = torch.randint(0, total_len - seq_len + 1, (1,), generator=rng).item()
        samples.append(tokens[start : start + seq_len])
    return torch.stack(samples)  # [n_samples, seq_len]
def get_calibration_data_c4(
    model_name_or_path: str,
    n_samples: int = 128,
    seq_len: int = 2048,
    seed: int = 42,
    tokenizer: Optional[AutoTokenizer] = None,
) -> torch.Tensor:
    """Load calibration data from C4 dataset (alternative to WikiText).

    Streams the English C4 training split, keeps documents that tokenize to
    at least ``seq_len`` tokens, and stacks the first ``n_samples`` of them,
    each truncated to exactly ``seq_len`` tokens.

    Args:
        model_name_or_path: Model ID or path used to load a tokenizer when
            ``tokenizer`` is not supplied.
        n_samples: Number of calibration sequences to return.
        seq_len: Length, in tokens, of each calibration sequence.
        seed: Seed for shuffling the stream, so the chosen documents are
            reproducible.
        tokenizer: Optional pre-loaded tokenizer; loaded from
            ``model_name_or_path`` when ``None``.

    Returns:
        Tensor of token IDs with shape ``[n_samples, seq_len]``.

    Raises:
        ValueError: If fewer than ``n_samples`` long-enough documents are
            found in the stream.
    """
    from datasets import load_dataset

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    dataset = load_dataset(
        "allenai/c4",
        "en",
        split="train",
        streaming=True,
    )
    # BUGFIX: the original built a torch.Generator that was never used, so
    # `seed` had no effect and the function always took the first qualifying
    # documents. Shuffle the stream with the seed so it actually controls
    # which documents are selected.
    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)

    samples = []
    for record in dataset:
        if len(samples) >= n_samples:
            break
        tokens = tokenizer(
            record["text"], return_tensors="pt", truncation=True, max_length=seq_len
        )["input_ids"][0]
        # truncation caps the length at seq_len, so `>=` keeps exactly the
        # documents that filled the whole window.
        if tokens.shape[0] >= seq_len:
            samples.append(tokens[:seq_len])

    if len(samples) < n_samples:
        raise ValueError(
            f"Only found {len(samples)} samples with seq_len >= {seq_len}. "
            f"Try reducing seq_len or using a different dataset."
        )
    return torch.stack(samples)