# Streaming, weight-mixed token-chunk dataset utilities for language-model training.
import bisect
import itertools
import random
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch
import yaml
from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase
@dataclass
class DataSource:
    """Configuration for one Hugging Face dataset used as a training stream."""

    # Human-readable identifier, used only in warning messages.
    name: str
    # Dataset path, first positional argument to `datasets.load_dataset`.
    hf_path: str
    # Optional dataset config name, second positional argument to `load_dataset`.
    hf_name: Optional[str]
    # Split to read from, e.g. "train".
    split: str
    # Column holding the raw text in each example row.
    text_field: str
    # Relative sampling weight; sources are mixed proportionally to this.
    weight: int = 1
    # Whether to open the dataset in streaming mode (no full download).
    streaming: bool = True
def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config file into a list of :class:`DataSource` specs.

    Expected schema: a top-level ``sources:`` list of mappings with keys
    ``name``, ``hf_path``, and optionally ``hf_name``, ``split``,
    ``text_field``, ``weight``, ``streaming``.

    Args:
        path: Filesystem path to the YAML config file.

    Returns:
        Non-empty list of DataSource entries, in file order.

    Raises:
        ValueError: If the file defines no sources. (Previously an
            ``assert``, which is silently stripped under ``python -O``.)
    """
    with open(path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty file; normalize to a dict.
        cfg = yaml.safe_load(f) or {}
    srcs = [
        DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        )
        for s in cfg.get("sources", [])
    ]
    if not srcs:
        raise ValueError("No data sources configured")
    return srcs
def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one iterator per configured source, in the same order as `sources`."""
    return [
        iter(load_dataset(src.hf_path, src.hf_name, split=src.split, streaming=src.streaming))
        for src in sources
    ]
def weighted_choice(weights: List[int]) -> int:
    """Sample an index with probability proportional to its weight.

    Draws a single ``random.randint`` (same RNG consumption as the previous
    linear-scan implementation, so seeded runs are unchanged) and locates the
    chosen bucket in the cumulative-weight table with a binary search.

    Args:
        weights: Non-negative per-index weights; must be non-empty and sum
            to a positive total.

    Returns:
        An index in ``[0, len(weights))``, chosen proportionally to `weights`.

    Raises:
        ValueError: If `weights` is empty or sums to <= 0. (Previously this
            surfaced as an opaque "empty range" error from ``random.randint``.)
    """
    cumulative = list(itertools.accumulate(weights))
    if not cumulative or cumulative[-1] <= 0:
        raise ValueError("weights must be non-empty with a positive sum")
    r = random.randint(1, cumulative[-1])
    # First bucket whose cumulative weight reaches r — identical to the
    # original "first i with r <= running sum" scan.
    return bisect.bisect_left(cumulative, r)
class TokenChunkDataset(IterableDataset):
    """Streams text from weighted HF sources and yields fixed-length LM chunks.

    Texts are sampled proportionally to each source's weight, tokenized, and
    concatenated into one endless token stream; ``__iter__`` slices that stream
    into ``(input, target)`` pairs of length `seq_len`, targets shifted by one.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        sources: List[DataSource],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: Tokenizer used to encode raw text.
            sources: Weighted dataset specs to mix.
            seq_len: Length of each training sequence (inputs and targets).
            eos_token_id: Separator appended after every document; falls back
                to the tokenizer's own ``eos_token_id`` attribute when absent.
        """
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp to >= 1 so a misconfigured 0/negative weight cannot zero out
        # a source or break the weighted sampler.
        self.weights = [max(1, s.weight) for s in sources]

    def _iter_texts(self) -> Iterator[str]:
        """Endlessly yield non-empty text fields, sampling sources by weight."""
        iters = build_streams(self.sources)
        while True:
            i = weighted_choice(self.weights)
            try:
                row = next(iters[i])
            except StopIteration:
                # A finite (non-streaming) source ran dry: reopen it so the
                # mix keeps its configured proportions indefinitely.
                try:
                    src = self.sources[i]
                    ds = load_dataset(
                        src.hf_path,
                        src.hf_name,
                        split=src.split,
                        streaming=src.streaming,
                    )
                    iters[i] = iter(ds)
                    row = next(iters[i])
                # StopIteration subclasses Exception, so the original
                # (StopIteration, Exception) tuple was redundant.
                except Exception as e:
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    continue  # Skip this iteration and try next source
            text = row.get(self.sources[i].text_field, None)
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> List[int]:
        """Tokenize `text`, returning [] instead of raising on encoder errors."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Flatten sampled texts into one token stream, EOS-separated if known."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if self.eos_id is not None:
                ids.append(self.eos_id)
            yield from ids

    def __iter__(self):
        """Yield ``(x, y)`` LongTensor pairs where ``y`` is ``x`` shifted by one.

        Consecutive chunks share one boundary token: only `seq_len` tokens are
        consumed per chunk, so the last target of chunk k becomes the first
        input of chunk k+1.
        """
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                del buf[:self.seq_len]
                yield x, y

    def __len__(self):
        # Streaming datasets have no true length; return a large fixed value
        # so progress bars / schedulers that call len() still work.
        return 1000000