# Source bundled with the Ouro-2.6B_smoothquant_W8A8 upload (commit b144856, user rayf-07).
from __future__ import annotations
from typing import Iterable, Iterator, List
from datasets import IterableDataset, load_dataset
from loguru import logger
from torch import device as TorchDevice
from transformers import PreTrainedTokenizerBase
from .config import CalibrationConfig
def collect_calibration_texts(config: CalibrationConfig) -> List[str]:
    """
    Fetch calibration samples from a Hugging Face dataset.

    Returns a list of raw text prompts that will be consumed by the
    calibration loop. An empty list is returned when ``sample_count`` is
    non-positive, the dataset cannot be loaded, or no usable text is found.
    Only non-empty ``str`` values from ``config.text_column`` are kept,
    regardless of whether the dataset is streamed or map-style.
    """
    if config.sample_count <= 0:
        logger.warning("Calibration requested with zero samples; skipping collection.")
        return []
    logger.info(
        f"Loading calibration dataset '{config.dataset_name}' "
        f"(split={config.dataset_split}, streaming={config.streaming})..."
    )
    try:
        dataset = load_dataset(
            config.dataset_name,
            split=config.dataset_split,
            streaming=config.streaming,
        )
    except Exception as exc:  # noqa: BLE001 - surface dataset issues to the caller
        logger.warning(
            f"Unable to load dataset '{config.dataset_name}': {exc}"
        )
        return []
    if isinstance(dataset, IterableDataset):
        # Fix: the streaming path previously ignored config.shuffle entirely.
        # IterableDataset.shuffle performs approximate buffer-based shuffling.
        if config.shuffle:
            dataset = dataset.shuffle(seed=config.seed)
        samples: List[str] = []
        for example in dataset:
            text = example.get(config.text_column)
            # Keep only non-empty strings (previously non-str truthy values
            # were coerced with str(), inconsistent with the map-style branch).
            if not isinstance(text, str) or not text:
                continue
            samples.append(text)
            if len(samples) >= config.sample_count:
                break
        if not samples:
            # Parity with the map-style branch: surface an empty result.
            logger.warning(
                f"Failed to collect calibration texts from column '{config.text_column}'."
            )
        return samples
    if len(dataset) == 0:
        logger.warning(f"Dataset '{config.dataset_name}' returned no rows.")
        return []
    upper = min(config.sample_count, len(dataset))
    if config.shuffle:
        dataset = dataset.shuffle(seed=config.seed)
    selected = dataset.select(range(upper))
    texts: List[str] = []
    for entry in selected:
        text = entry.get(config.text_column)
        # Skip non-strings and empty strings — empty prompts add no
        # calibration signal (previously empty strings slipped through here).
        if not isinstance(text, str) or not text:
            continue
        texts.append(text)
    if not texts:
        logger.warning(
            f"Failed to collect calibration texts from column '{config.text_column}'."
        )
    return texts
def iter_tokenized_batches(
    tokenizer: PreTrainedTokenizerBase,
    texts: Iterable[str],
    device: TorchDevice,
    batch_size: int,
    max_length: int,
) -> Iterator[dict]:
    """
    Yield tokenized calibration inputs suitable for feeding into the model.

    Texts are accumulated into groups of ``batch_size`` (the final batch may
    be smaller), tokenized with padding and truncation to ``max_length``,
    and moved to ``device`` before being yielded.

    Raises:
        ValueError: If ``batch_size`` is not at least 1.
    """
    if batch_size <= 0:
        raise ValueError("Batch size must be at least 1 for calibration.")

    def _encode(chunk: List[str]) -> dict:
        # Single tokenization path shared by full batches and the tail flush
        # (previously duplicated verbatim in two places).
        return tokenizer(
            chunk,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

    buffer: List[str] = []
    for text in texts:
        buffer.append(text)
        if len(buffer) == batch_size:
            yield _encode(buffer)
            # Fresh list rather than clear(): the old one was handed off above.
            buffer = []
    if buffer:
        # Flush the trailing partial batch.
        yield _encode(buffer)