File size: 4,245 Bytes
359ff82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Pre-encoded sharded datasets generated by the Rust virtual dataset tool."""

from __future__ import annotations

import bisect
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset


class ShardedEncodedDataset(Dataset):
    """Map-style dataset backed by pre-encoded `.npy` shards.

    The Rust generator writes compact uint16/int16/u8 arrays. This class loads
    one shard at a time and relies on sequential sampling over pre-shuffled
    shards, so Python does no tokenization or BIO permutation during training.

    Each shard entry in the manifest provides ``rows`` plus the file names
    (relative to ``dataset_dir``) of its ``input_ids``, ``attention_mask`` and
    ``labels`` arrays, each of which must have shape ``(rows, max_length)``.

    NOTE(review): older torch releases reject uint16 arrays in
    ``torch.from_numpy``; confirm the pinned torch version accepts the dtypes
    the generator writes.
    """

    preserve_order = True

    # Arrays stored per shard; also the key order of every returned sample.
    _ARRAY_KEYS = ("input_ids", "attention_mask", "labels")

    def __init__(self, dataset_dir: str | Path, manifest_name: str = "manifest.json"):
        """Read and validate the shard manifest; shard arrays are loaded lazily.

        Args:
            dataset_dir: Directory holding the manifest and the shard files.
            manifest_name: Manifest file name inside ``dataset_dir``.

        Raises:
            ValueError: If the manifest format is unsupported, it lists no
                shards, a shard declares a negative row count, or the declared
                ``total_rows`` disagrees with the sum of shard rows.
            KeyError: If ``max_length`` or a shard's ``rows`` entry is missing.
        """
        self.dataset_dir = Path(dataset_dir)
        self.manifest_path = self.dataset_dir / manifest_name
        self.manifest = json.loads(self.manifest_path.read_text(encoding="utf-8"))
        if self.manifest.get("format") != "anifilebert.virtual_dataset.shards.v1":
            raise ValueError(f"Unsupported virtual dataset manifest: {self.manifest_path}")

        self.max_length = int(self.manifest["max_length"])
        self.shards: List[Dict] = list(self.manifest.get("shards") or [])
        if not self.shards:
            raise ValueError(f"Virtual dataset has no shards: {self.manifest_path}")

        # _starts[i] is the global index of shard i's first row (prefix sums).
        self._starts: List[int] = []
        total = 0
        for shard_idx, shard in enumerate(self.shards):
            rows = int(shard["rows"])
            if rows < 0:
                # A negative count would make _starts non-monotonic and break
                # the bisect lookup in __getitem__.
                raise ValueError(
                    f"Shard {shard_idx} has negative row count {rows}: {self.manifest_path}"
                )
            self._starts.append(total)
            total += rows
        self.total_rows = total

        declared_total = int(self.manifest.get("total_rows", total))
        if declared_total != total:
            raise ValueError(
                f"Virtual dataset row count mismatch: manifest total_rows={declared_total}, "
                f"shard rows={total}"
            )

        # Single-entry cache: index and arrays of the most recently loaded shard.
        self._cache_index: Optional[int] = None
        self._cache: Optional[Dict[str, np.ndarray]] = None

    def __getstate__(self) -> Dict:
        """Return picklable state without the cached shard arrays.

        When the dataset is pickled (e.g. handed to DataLoader workers), the
        cached shard is dropped so it is not serialized; the unpickled copy
        reloads shards lazily on first access.
        """
        state = self.__dict__.copy()
        state["_cache_index"] = None
        state["_cache"] = None
        return state

    def __len__(self) -> int:
        return self.total_rows

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the encoded row at global position ``idx``.

        Negative indices count from the end.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels``
            tensors. Each tensor owns its data and does not alias the cache.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.total_rows
        if idx < 0 or idx >= self.total_rows:
            raise IndexError(idx)

        # bisect_right selects the last shard whose start offset is <= idx,
        # which steps past zero-row shards sharing a start with their successor.
        shard_idx = bisect.bisect_right(self._starts, idx) - 1
        row_idx = idx - self._starts[shard_idx]
        cache = self._load_shard(shard_idx)
        # Copy each row: a bare view would let in-place edits by callers
        # corrupt the cached shard, and would keep the whole shard array alive
        # for as long as any returned tensor is referenced.
        return {key: torch.from_numpy(cache[key][row_idx].copy()) for key in self._ARRAY_KEYS}

    def _load_shard(self, shard_idx: int) -> Dict[str, np.ndarray]:
        """Return the arrays of shard ``shard_idx``, loading them if not cached.

        Only the most recently used shard is kept, so interleaved access to
        different shards reloads from disk each time.

        Raises:
            ValueError: If any loaded array's shape is not ``(rows, max_length)``.
        """
        if self._cache_index == shard_idx and self._cache is not None:
            return self._cache

        # Release the previous shard first so it can be freed before the next
        # one is read.
        self._cache_index = None
        self._cache = None

        shard = self.shards[shard_idx]
        cache = {
            key: np.load(self.dataset_dir / shard[key], allow_pickle=False)
            for key in self._ARRAY_KEYS
        }
        expected_shape = (int(shard["rows"]), self.max_length)
        for key, array in cache.items():
            if array.shape != expected_shape:
                raise ValueError(
                    f"Shard {shard_idx} {key} has shape {array.shape}, expected {expected_shape}"
                )
        self._cache_index = shard_idx
        self._cache = cache
        return cache


class DatasetRangeView(Dataset):
    """Expose the half-open slice ``[start, end)`` of another dataset.

    Indices are relative to ``start``; negative indices count back from
    ``end``. Nothing is copied: each item is fetched from the wrapped dataset
    on demand.
    """

    preserve_order = True

    def __init__(self, dataset: Dataset, start: int, end: int):
        """Wrap ``dataset`` and validate that the range lies within it.

        Raises:
            ValueError: Unless ``0 <= start <= end <= len(dataset)``.
        """
        size = len(dataset)
        if not 0 <= start <= end <= size:
            raise ValueError(f"Invalid dataset range [{start}, {end}) for length {size}")
        self.dataset = dataset
        self.start = start
        self.end = end

    def __len__(self) -> int:
        return self.end - self.start

    def __getitem__(self, idx: int):
        """Return the item at ``start + idx`` of the wrapped dataset.

        Raises:
            IndexError: If ``idx`` (after negative wrap-around) is outside the view.
        """
        span = len(self)
        position = idx + span if idx < 0 else idx
        if not 0 <= position < span:
            raise IndexError(position)
        return self.dataset[self.start + position]