"""Pre-encoded sharded datasets generated by the Rust virtual dataset tool."""

from __future__ import annotations

import bisect
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset


class ShardedEncodedDataset(Dataset):
    """Map-style dataset backed by pre-encoded `.npy` shards.

    The Rust generator writes compact uint16/int16/u8 arrays. This class
    loads one shard at a time and relies on sequential sampling over
    pre-shuffled shards, so Python does no tokenization or BIO permutation
    during training.

    Only a single shard is kept in memory. Random access across shards
    still works, but every shard switch re-reads three `.npy` files.
    """

    # NOTE(review): presumably read by the training loop to disable
    # shuffling; the consumer is not visible in this file.
    preserve_order = True

    # Arrays stored per shard. Each name is both the manifest key holding the
    # file path and the key of the sample dict returned by ``__getitem__``.
    _ARRAY_KEYS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset_dir: str | Path, manifest_name: str = "manifest.json"):
        """Read the manifest and build the global-row -> shard index.

        Args:
            dataset_dir: Directory containing the manifest and shard files.
            manifest_name: File name of the JSON manifest inside
                ``dataset_dir``.

        Raises:
            ValueError: If the manifest format tag is unsupported, the
                manifest lists no shards, a shard declares a negative row
                count, or the declared ``total_rows`` disagrees with the sum
                of the shard row counts.
        """
        self.dataset_dir = Path(dataset_dir)
        self.manifest_path = self.dataset_dir / manifest_name
        self.manifest = json.loads(self.manifest_path.read_text(encoding="utf-8"))
        if self.manifest.get("format") != "anifilebert.virtual_dataset.shards.v1":
            raise ValueError(f"Unsupported virtual dataset manifest: {self.manifest_path}")

        self.max_length = int(self.manifest["max_length"])
        self.shards: List[Dict] = list(self.manifest.get("shards") or [])
        if not self.shards:
            raise ValueError(f"Virtual dataset has no shards: {self.manifest_path}")

        # _starts[i] is the global index of the first row of shard i. It must
        # be non-decreasing for the bisect lookup in __getitem__ to be valid.
        self._starts: List[int] = []
        total = 0
        for position, shard in enumerate(self.shards):
            rows = int(shard["rows"])
            if rows < 0:
                # A negative count would make _starts non-monotonic and could
                # still slip past the total_rows check below.
                raise ValueError(
                    f"Shard {position} declares a negative row count: {rows}"
                )
            self._starts.append(total)
            total += rows
        self.total_rows = total

        declared_total = int(self.manifest.get("total_rows", total))
        if declared_total != total:
            raise ValueError(
                f"Virtual dataset row count mismatch: manifest total_rows={declared_total}, "
                f"shard rows={total}"
            )

        # Single-entry cache: the index of the resident shard and its arrays.
        self._cache_index: Optional[int] = None
        self._cache: Optional[Dict[str, np.ndarray]] = None

    def __len__(self) -> int:
        """Return the total number of rows across all shards."""
        return self.total_rows

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one encoded sample.

        Args:
            idx: Global row index; negative values count from the end.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels``
            tensors. The tensors are created with ``torch.from_numpy`` and
            therefore share memory with the loaded shard arrays.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        if idx < 0:
            idx += self.total_rows
        if idx < 0 or idx >= self.total_rows:
            raise IndexError(idx)

        # bisect_right picks the last shard whose start is <= idx, which also
        # skips over zero-row shards that share a start offset.
        shard_idx = bisect.bisect_right(self._starts, idx) - 1
        row_idx = idx - self._starts[shard_idx]
        cache = self._load_shard(shard_idx)
        return {key: torch.from_numpy(cache[key][row_idx]) for key in self._ARRAY_KEYS}

    def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
        """Return the arrays of shard ``shard_idx``, loading them if needed.

        Raises:
            ValueError: If any loaded array does not have shape
                ``(rows, max_length)`` as declared by the manifest.
        """
        if self._cache_index == shard_idx and self._cache is not None:
            return self._cache

        shard = self.shards[shard_idx]

        # Release the previous shard before reading the next one so that two
        # shards are never held by the cache at the same time.
        self._cache_index = None
        self._cache = None

        arrays = {
            key: np.load(self.dataset_dir / shard[key], allow_pickle=False)
            for key in self._ARRAY_KEYS
        }
        expected_shape = (int(shard["rows"]), self.max_length)
        for key, array in arrays.items():
            if array.shape != expected_shape:
                raise ValueError(
                    f"Shard {shard_idx} {key} has shape {array.shape}, expected {expected_shape}"
                )

        self._cache_index = shard_idx
        self._cache = arrays
        return arrays


class DatasetRangeView(Dataset):
    """A contiguous range view over another dataset."""

    # NOTE(review): presumably read by the training loop to disable
    # shuffling; the consumer is not visible in this file.
    preserve_order = True

    def __init__(self, dataset: Dataset, start: int, end: int):
        """Expose rows ``[start, end)`` of ``dataset``.

        Args:
            dataset: The underlying dataset; must support ``len()``.
            start: First index of the view (inclusive).
            end: Index one past the last row of the view (exclusive).

        Raises:
            ValueError: If the range is negative, reversed, or extends past
                the end of ``dataset``.
        """
        if start < 0 or end < start or end > len(dataset):
            raise ValueError(f"Invalid dataset range [{start}, {end}) for length {len(dataset)}")
        self.dataset = dataset
        self.start = start
        self.end = end

    def __len__(self) -> int:
        """Return the number of rows in the view."""
        return self.end - self.start

    def __getitem__(self, idx: int):
        """Return row ``idx`` of the view; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside the view.
        """
        if idx < 0:
            idx += len(self)
        if idx < 0 or idx >= len(self):
            raise IndexError(idx)
        return self.dataset[self.start + idx]