File size: 4,063 Bytes
a38941f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55cadf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a38941f
 
 
 
a55cadf
 
 
 
 
 
 
a38941f
 
a55cadf
a38941f
 
 
 
 
 
 
 
 
 
a55cadf
 
 
a38941f
a55cadf
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import random
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from torch.utils.data import IterableDataset
from datasets import load_dataset
from transformers import PreTrainedTokenizerBase
import yaml

@dataclass
class DataSource:
    """Configuration for one Hugging Face dataset in the training mixture."""

    # Human-readable label; used in warning messages when a stream fails.
    name: str
    # Dataset repository path passed as the first argument to load_dataset().
    hf_path: str
    # Optional dataset config/subset name (second argument to load_dataset()).
    hf_name: Optional[str]
    # Split to load, e.g. "train".
    split: str
    # Name of the column holding raw text in each dataset row.
    text_field: str
    # Relative sampling weight in the mixture (consumers clamp it to >= 1).
    weight: int = 1
    # Whether to open the dataset in streaming mode.
    streaming: bool = True

def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config file into a list of ``DataSource`` entries.

    The file is expected to contain a top-level ``sources`` list of
    mappings; missing optional keys fall back to the ``DataSource``
    defaults used below.

    Args:
        path: Filesystem path to the YAML configuration file.

    Returns:
        Non-empty list of ``DataSource`` objects.

    Raises:
        ValueError: If the config defines no sources (including an empty
            or null YAML document).
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty document; treat anything that is
    # not a mapping as "no sources" instead of crashing on cfg.get().
    if not isinstance(cfg, dict):
        cfg = {}
    srcs = [
        DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        )
        for s in cfg.get("sources", [])
    ]
    # Raise rather than assert so the check survives `python -O`.
    if not srcs:
        raise ValueError("No data sources configured")
    return srcs

def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one row-iterator per configured data source, in order."""
    return [
        iter(
            load_dataset(
                src.hf_path,
                src.hf_name,
                split=src.split,
                streaming=src.streaming,
            )
        )
        for src in sources
    ]

def weighted_choice(weights: List[int]) -> int:
    """Sample an index with probability proportional to its weight.

    Args:
        weights: Non-negative weights; at least one must be positive.

    Returns:
        An index ``i`` in ``[0, len(weights))``, drawn with probability
        ``weights[i] / sum(weights)``.

    Raises:
        ValueError: If ``weights`` is empty or sums to zero.
    """
    # random.choices already implements cumulative-weight sampling; no need
    # to hand-roll the accumulator loop.
    return random.choices(range(len(weights)), weights=weights)[0]

class TokenChunkDataset(IterableDataset):
    """Streams weighted-mixed text sources as fixed-length token chunks.

    Texts are drawn from the configured sources in proportion to their
    weights, tokenized, concatenated into a single flat token stream
    (with an EOS token between documents when one is known), and sliced
    into ``(input, target)`` LongTensor pairs of length ``seq_len``
    where the target is the input shifted by one position.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        sources: List[DataSource],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        """Build the dataset.

        Args:
            tokenizer: Tokenizer providing ``encode`` (and optionally
                ``eos_token_id``).
            sources: Data sources to mix.
            seq_len: Length of each yielded input/target sequence.
            eos_token_id: Explicit EOS id; defaults to the tokenizer's
                ``eos_token_id`` attribute when omitted.
        """
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        # Fall back to the tokenizer's own EOS id; may still be None, in
        # which case no separator token is inserted between documents.
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp to >= 1 so every source keeps a nonzero sampling chance.
        self.weights = [max(1, s.weight) for s in sources]

    def _iter_texts(self) -> Iterator[str]:
        """Yield raw text strings sampled from the sources by weight.

        Exhausted streams are re-opened so iteration is effectively
        endless; a source that cannot be restarted is skipped with a
        warning and re-sampled on the next pass.
        """
        iters = build_streams(self.sources)
        while True:
            i = weighted_choice(self.weights)
            try:
                row = next(iters[i])
            except StopIteration:
                # Stream exhausted: re-open the source and pull the first row.
                try:
                    src = self.sources[i]
                    ds = load_dataset(
                        src.hf_path,
                        src.hf_name,
                        split=src.split,
                        streaming=src.streaming,
                    )
                    iters[i] = iter(ds)
                    row = next(iters[i])
                except Exception as e:
                    # StopIteration subclasses Exception, so one broad catch
                    # covers both an empty restart and a broken one.
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    continue  # Skip this iteration and try next source
            text = row.get(self.sources[i].text_field, None)
            # Drop rows whose text field is missing, non-string, or empty.
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> List[int]:
        """Tokenize ``text``, returning [] instead of raising on encoder errors."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            # Best-effort: log and drop the document rather than kill the stream.
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Yield a flat stream of token ids, EOS-separated per document."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if self.eos_id is not None:
                ids.append(self.eos_id)
            yield from ids

    def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Yield ``(x, y)`` LongTensor pairs with ``y`` shifted one token right."""
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            # seq_len + 1 tokens are needed so the target can be shifted by one.
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                # Drop seq_len tokens: the final target token is reused as
                # the first input token of the next chunk.
                del buf[:self.seq_len]
                yield x, y

    def __len__(self) -> int:
        # Streaming sources have no true length; report a large constant so
        # progress bars / components that require __len__ keep working.
        # NOTE(review): callers must not treat this as an exact dataset size.
        return 1000000  # Large number for streaming datasets