AniFileBERT / anifilebert /virtual_dataset.py
ModerRAS's picture
Train virtual-shard anime parser
359ff82
"""Pre-encoded sharded datasets generated by the Rust virtual dataset tool."""
from __future__ import annotations
import bisect
import json
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
import torch
from torch.utils.data import Dataset
class ShardedEncodedDataset(Dataset):
"""Map-style dataset backed by pre-encoded `.npy` shards.
The Rust generator writes compact uint16/int16/u8 arrays. This class loads
one shard at a time and relies on sequential sampling over pre-shuffled
shards, so Python does no tokenization or BIO permutation during training.
"""
preserve_order = True
def __init__(self, dataset_dir: str | Path, manifest_name: str = "manifest.json"):
self.dataset_dir = Path(dataset_dir)
self.manifest_path = self.dataset_dir / manifest_name
self.manifest = json.loads(self.manifest_path.read_text(encoding="utf-8"))
if self.manifest.get("format") != "anifilebert.virtual_dataset.shards.v1":
raise ValueError(f"Unsupported virtual dataset manifest: {self.manifest_path}")
self.max_length = int(self.manifest["max_length"])
self.shards: List[Dict] = list(self.manifest.get("shards") or [])
if not self.shards:
raise ValueError(f"Virtual dataset has no shards: {self.manifest_path}")
self._starts: List[int] = []
total = 0
for shard in self.shards:
self._starts.append(total)
total += int(shard["rows"])
self.total_rows = total
declared_total = int(self.manifest.get("total_rows", total))
if declared_total != total:
raise ValueError(
f"Virtual dataset row count mismatch: manifest total_rows={declared_total}, "
f"shard rows={total}"
)
self._cache_index: Optional[int] = None
self._cache: Optional[Dict[str, np.ndarray]] = None
def __len__(self) -> int:
return self.total_rows
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
if idx < 0:
idx += self.total_rows
if idx < 0 or idx >= self.total_rows:
raise IndexError(idx)
shard_idx = bisect.bisect_right(self._starts, idx) - 1
shard_start = self._starts[shard_idx]
row_idx = idx - shard_start
cache = self._load_shard(shard_idx)
return {
"input_ids": torch.from_numpy(cache["input_ids"][row_idx]),
"attention_mask": torch.from_numpy(cache["attention_mask"][row_idx]),
"labels": torch.from_numpy(cache["labels"][row_idx]),
}
def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
if self._cache_index == shard_idx and self._cache is not None:
return self._cache
shard = self.shards[shard_idx]
cache = {
"input_ids": np.load(self.dataset_dir / shard["input_ids"], allow_pickle=False),
"attention_mask": np.load(self.dataset_dir / shard["attention_mask"], allow_pickle=False),
"labels": np.load(self.dataset_dir / shard["labels"], allow_pickle=False),
}
expected_shape = (int(shard["rows"]), self.max_length)
for key, array in cache.items():
if array.shape != expected_shape:
raise ValueError(
f"Shard {shard_idx} {key} has shape {array.shape}, expected {expected_shape}"
)
self._cache_index = shard_idx
self._cache = cache
return cache
class DatasetRangeView(Dataset):
    """Expose the half-open slice ``[start, end)`` of another dataset as a dataset.

    Position ``i`` of the view reads ``dataset[start + i]``. Negative positions
    count back from the end of the view, not from the end of the wrapped
    dataset. The wrapped dataset is referenced, not copied.
    """

    preserve_order = True

    def __init__(self, dataset: Dataset, start: int, end: int):
        """Validate ``0 <= start <= end <= len(dataset)`` and store the bounds.

        Raises:
            ValueError: If the range does not fit inside ``dataset``.
        """
        length = len(dataset)
        if not 0 <= start <= end <= length:
            raise ValueError(f"Invalid dataset range [{start}, {end}) for length {length}")
        self.dataset = dataset
        self.start = start
        self.end = end

    def __len__(self) -> int:
        """Return the number of rows covered by the view."""
        return self.end - self.start

    def __getitem__(self, idx: int):
        """Return the wrapped item at view position ``idx``.

        Raises:
            IndexError: If ``idx`` falls outside the view; the exception
                carries the index after negative-index adjustment.
        """
        size = len(self)
        position = idx + size if idx < 0 else idx
        if not 0 <= position < size:
            raise IndexError(position)
        return self.dataset[self.start + position]