import random
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from torch.utils.data import IterableDataset

from datasets import load_dataset
from transformers import PreTrainedTokenizerBase
import yaml


@dataclass
class DataSource:
    """Configuration for one HuggingFace dataset in the sampling mixture."""
    name: str
    hf_path: str
    hf_name: Optional[str]
    split: str
    text_field: str
    weight: int = 1        # relative sampling weight within the mixture
    streaming: bool = True  # pass-through to `load_dataset(streaming=...)`


def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config into a list of DataSource entries.

    Expects a top-level ``sources:`` list of mappings; missing optional keys
    fall back to the defaults seen below.

    Raises:
        ValueError: if the config declares no sources.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    srcs = []
    for s in cfg.get("sources", []):
        srcs.append(DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        ))
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not srcs:
        raise ValueError("No data sources configured")
    return srcs


def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one iterator per configured source via `datasets.load_dataset`."""
    iters = []
    for s in sources:
        ds = load_dataset(s.hf_path, s.hf_name, split=s.split, streaming=s.streaming)
        iters.append(iter(ds))
    return iters


def weighted_choice(weights: List[int]) -> int:
    """Return an index sampled proportionally to `weights`.

    Entries with weight 0 are never selected (r starts at 1).  Assumes
    ``sum(weights) > 0``; callers must guard against an all-zero list.
    """
    total = sum(weights)
    r = random.randint(1, total)
    acc = 0
    for i, w in enumerate(weights):
        acc += w
        if r <= acc:
            return i
    return len(weights) - 1  # unreachable fallback; kept for safety


class TokenChunkDataset(IterableDataset):
    """Streams weighted text sources, tokenizes, and yields (x, y) chunks.

    Each yielded pair is ``(input_ids, target_ids)`` of length ``seq_len``
    where ``y`` is ``x`` shifted by one token (next-token prediction).
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        sources: List[DataSource],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: anything with ``.encode(str) -> list[int]``.
            sources: mixture of datasets to sample from.
            seq_len: length of each yielded chunk.
            eos_token_id: appended after every document; falls back to the
                tokenizer's own ``eos_token_id`` attribute when omitted.
        """
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp to >= 1 so every configured source is selectable.
        self.weights = [max(1, s.weight) for s in sources]

    def _iter_texts(self) -> Iterator[str]:
        """Yield raw text documents, sampling sources by weight.

        Finite (non-streaming) sources are restarted when exhausted.  A
        source whose restart fails is disabled (weight set to 0) instead of
        being retried forever — the original `continue`-only handling could
        busy-loop on a permanently dead source.  Iteration ends when every
        source has been disabled.
        """
        iters = build_streams(self.sources)
        # Mutable copy: dead sources are zeroed here without touching self.weights.
        weights = list(self.weights)
        while True:
            if sum(weights) <= 0:
                # All sources exhausted and unrestartable.
                return
            i = weighted_choice(weights)
            try:
                row = next(iters[i])
            except StopIteration:
                try:
                    ds = load_dataset(
                        self.sources[i].hf_path,
                        self.sources[i].hf_name,
                        split=self.sources[i].split,
                        streaming=self.sources[i].streaming,
                    )
                    iters[i] = iter(ds)
                    row = next(iters[i])
                except Exception as e:
                    # Note: StopIteration is itself an Exception subclass, so
                    # the original `(StopIteration, Exception)` was redundant.
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    weights[i] = 0  # never pick this source again
                    continue
            text = row.get(self.sources[i].text_field, None)
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> List[int]:
        """Tokenize `text`; on any tokenizer error, warn and return []."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Flatten documents into one token stream, EOS-separated when known."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if self.eos_id is not None:
                ids.append(self.eos_id)
            yield from ids

    def __iter__(self):
        """Yield (x, y) long-tensor pairs of length `seq_len`.

        Buffers seq_len + 1 tokens so y can be x shifted by one; the final
        target token becomes the first input token of the next chunk.
        """
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                del buf[:self.seq_len]
                yield x, y

    def __len__(self):
        # NOTE(review): arbitrary placeholder so DataLoader progress bars work
        # with a streaming dataset — NOT the true number of chunks.
        return 1000000