|
|
import random |
|
|
from dataclasses import dataclass |
|
|
from typing import Dict, Iterable, Iterator, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import IterableDataset |
|
|
from datasets import load_dataset |
|
|
from transformers import PreTrainedTokenizerBase |
|
|
import yaml |
|
|
|
|
|
@dataclass
class DataSource:
    """Configuration for one Hugging Face dataset in the training mixture.

    Defaults mirror the fallbacks applied by ``load_sources_from_yaml``
    (split "train", text field "text", no config name), so a DataSource can
    also be constructed directly with only the identifying fields.
    """

    name: str                       # human-readable label, used in warning messages
    hf_path: str                    # dataset path/repo id on the HF hub
    hf_name: Optional[str] = None   # optional dataset config (subset) name
    split: str = "train"            # which split to load
    text_field: str = "text"        # column that holds the raw text
    weight: int = 1                 # relative sampling weight (integer, see weighted_choice)
    streaming: bool = True          # stream from the hub instead of downloading fully
|
|
|
|
|
def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config file into a list of DataSource entries.

    The file is expected to contain a top-level ``sources:`` list of
    mappings; missing optional keys fall back to the DataSource defaults.

    Args:
        path: Filesystem path to the YAML config.

    Returns:
        A non-empty list of DataSource objects.

    Raises:
        ValueError: If the config defines no sources.  (Previously an
            ``assert``, which is silently stripped under ``python -O``.)
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # safe_load returns None for an empty file; treat that as "no sources"
    # instead of crashing on None.get.
    srcs = [
        DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        )
        for s in (cfg or {}).get("sources", [])
    ]
    if not srcs:
        raise ValueError(f"No data sources configured in {path}")
    return srcs
|
|
|
|
|
def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one Hugging Face dataset iterator per configured source.

    Args:
        sources: Data sources to open; each source's path/config/split and
            streaming flag are passed straight to ``load_dataset``.

    Returns:
        Iterators over raw dataset rows, in the same order as ``sources``.
    """
    return [
        iter(load_dataset(src.hf_path, src.hf_name, split=src.split, streaming=src.streaming))
        for src in sources
    ]
|
|
|
|
|
def weighted_choice(weights: List[int]) -> int:
    """Sample an index with probability proportional to its integer weight.

    Zero-weight entries are never selected.

    Args:
        weights: Non-negative integer weights, one per candidate index.

    Returns:
        An index in ``range(len(weights))``.

    Raises:
        ValueError: If ``weights`` is empty or sums to zero.  (Previously
            this surfaced as an opaque "empty range" error from randint.)
    """
    total = sum(weights)
    if total <= 0:
        raise ValueError("weighted_choice requires a positive total weight")
    # Draw r uniformly in [1, total]; the chosen index is the first one
    # whose cumulative weight reaches r.
    r = random.randint(1, total)
    acc = 0
    for i, w in enumerate(weights):
        acc += w
        if r <= acc:
            return i
    return len(weights) - 1  # defensive; unreachable since acc == total >= r
|
|
|
|
|
class TokenChunkDataset(IterableDataset):
    """Infinite iterable dataset of fixed-length language-model chunks.

    Interleaves text rows from several HF sources (sampled proportionally
    to their weights), tokenizes them into one continuous token stream
    (separating documents with EOS when a separator id is known), and
    yields ``(x, y)`` LongTensor pairs of shape ``(seq_len,)`` where ``y``
    is ``x`` shifted by one token (next-token-prediction targets).
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        sources: List[DataSource],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: Hugging Face tokenizer used to encode text.
            sources: Weighted data sources to interleave.
            seq_len: Length of each yielded input sequence.
            eos_token_id: Document-separator token id; defaults to the
                tokenizer's own ``eos_token_id`` (may be None, in which
                case no separator is inserted).
        """
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp to >= 1 so weighted_choice always sees a positive total.
        self.weights = [max(1, s.weight) for s in sources]

    def _iter_texts(self) -> Iterator[str]:
        """Yield raw text rows forever, sampling sources by weight.

        An exhausted (finite, non-streaming) source is reopened once; if
        reopening also fails, that draw is skipped and sampling continues
        with the remaining sources.
        """
        iters = build_streams(self.sources)
        while True:
            i = weighted_choice(self.weights)
            try:
                row = next(iters[i])
            except StopIteration:
                # Source ran dry: reopen it via the same path used at
                # startup and retry once.
                try:
                    iters[i] = build_streams([self.sources[i]])[0]
                    row = next(iters[i])
                # StopIteration is a subclass of Exception, so a single
                # clause covers both the retry and any load_dataset error.
                except Exception as e:
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    continue
            text = row.get(self.sources[i].text_field, None)
            # Skip rows whose text field is missing, empty, or not a string.
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> list:
        """Encode text to token ids; on tokenizer failure, log and return []."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Flatten the text stream into one stream of token ids, appending
        the EOS id after each document when a separator id is configured."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if self.eos_id is not None:
                ids.append(self.eos_id)
            yield from ids

    def __iter__(self):
        """Yield ``(x, y)`` LongTensor pairs of length ``seq_len``.

        ``y`` is ``x`` shifted one position (teacher-forcing targets).
        Only ``seq_len`` ids are consumed per chunk, so consecutive chunks
        share one boundary token (the last target becomes the next input).
        """
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                del buf[:self.seq_len]
                yield x, y

    def __len__(self):
        # Nominal length only: the stream is unbounded, but some trainer
        # utilities require __len__. This is an arbitrary large value, not
        # a true dataset size.
        return 1000000
|
|
|
|
|
|