| from __future__ import annotations |
|
|
| from typing import Iterable, Iterator, List |
|
|
| from datasets import IterableDataset, load_dataset |
| from loguru import logger |
| from torch import device as TorchDevice |
| from transformers import PreTrainedTokenizerBase |
|
|
| from .config import CalibrationConfig |
|
|
|
|
def collect_calibration_texts(config: CalibrationConfig) -> List[str]:
    """
    Fetch calibration samples from a Hugging Face dataset.

    Handles both streaming (``IterableDataset``) and in-memory datasets. Rows
    whose text column is missing or falsy are skipped; remaining values are
    coerced to ``str`` so downstream tokenization always receives strings.

    Returns a list of raw text prompts that will be consumed by the calibration
    loop — possibly fewer than ``config.sample_count`` if the dataset is small.
    Returns an empty list when the dataset cannot be loaded or yields no usable
    rows (a warning is logged instead of raising).
    """
    if config.sample_count <= 0:
        logger.warning("Calibration requested with zero samples; skipping collection.")
        return []

    logger.info(
        f"Loading calibration dataset '{config.dataset_name}' "
        f"(split={config.dataset_split}, streaming={config.streaming})..."
    )
    try:
        dataset = load_dataset(
            config.dataset_name,
            split=config.dataset_split,
            streaming=config.streaming,
        )
    except Exception as exc:
        # Best-effort: calibration data is optional, so a load failure is
        # downgraded to a warning rather than aborting the caller.
        logger.warning(
            f"Unable to load dataset '{config.dataset_name}': {exc}"
        )
        return []

    if isinstance(dataset, IterableDataset):
        # Fix: the shuffle flag was previously ignored for streaming datasets.
        # IterableDataset.shuffle uses a bounded reservoir buffer, so this
        # stays memory-safe even for huge corpora.
        if config.shuffle:
            dataset = dataset.shuffle(seed=config.seed)
        samples: List[str] = []
        for example in dataset:
            text = example.get(config.text_column)
            if not text:  # skip missing/empty values
                continue
            samples.append(str(text))
            if len(samples) >= config.sample_count:
                break
        if not samples:
            # Consistency fix: warn on an empty result, as the in-memory
            # branch already did.
            logger.warning(
                f"Failed to collect calibration texts from column '{config.text_column}'."
            )
        return samples

    if len(dataset) == 0:
        logger.warning(f"Dataset '{config.dataset_name}' returned no rows.")
        return []

    upper = min(config.sample_count, len(dataset))
    if config.shuffle:
        dataset = dataset.shuffle(seed=config.seed)
    selected = dataset.select(range(upper))
    texts: List[str] = []
    for entry in selected:
        text = entry.get(config.text_column)
        # Consistency fix: match the streaming branch — accept any truthy
        # value and coerce to str (the old isinstance check silently dropped
        # truthy non-str values that the streaming path accepted).
        if not text:
            continue
        texts.append(str(text))
    if not texts:
        logger.warning(
            f"Failed to collect calibration texts from column '{config.text_column}'."
        )
    return texts
|
|
|
|
def iter_tokenized_batches(
    tokenizer: PreTrainedTokenizerBase,
    texts: Iterable[str],
    device: TorchDevice,
    batch_size: int,
    max_length: int,
) -> Iterator[dict]:
    """
    Yield tokenized calibration inputs suitable for feeding into the model.

    Texts are grouped into chunks of ``batch_size`` (the final chunk may be
    smaller), tokenized with padding and truncation to ``max_length``, and
    moved to ``device`` before being yielded.

    Raises:
        ValueError: if ``batch_size`` is less than 1. Note this is raised
            lazily, on the first iteration of the returned generator.
    """
    if batch_size <= 0:
        raise ValueError("Batch size must be at least 1 for calibration.")

    def _encode(chunk: List[str]) -> dict:
        # Single tokenization path shared by full and trailing partial
        # batches (previously duplicated inline).
        return tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

    buffer: List[str] = []
    for text in texts:
        buffer.append(text)
        if len(buffer) == batch_size:
            yield _encode(buffer)
            # Rebind instead of buffer.clear(): the old list was handed to
            # the tokenizer, so mutating it in place is needlessly fragile.
            buffer = []

    if buffer:  # trailing partial batch
        yield _encode(buffer)
|
|