# Copyright (c) 2024-present, BAAI. All Rights Reserved.
# Licensed under Apache License, Version 2.0
"""Prompt-only dataset for on-policy distillation.

Supports:
  1) .txt: one prompt per line (small files)
  2) .csv shards (Koala-36M): stream caption column from many CSV files

This is an IterableDataset to avoid loading 10x 4.89GB CSV shards into memory.

Distributed sharding
--------------------
``__iter__`` automatically detects whether torch.distributed is initialised
and shards files across ranks using file-level slicing:

    rank-r reads ``files[r::world_size]``

The rank's files are then sliced the same way across DataLoader workers
(``files[worker_id::num_workers]``), because every worker process iterates its
own copy of an IterableDataset.

This keeps each rank's file I/O independent and requires zero coordination
overhead.  Duplicate prompts are avoided as long as there are at least as many
files as shards; when a slice would be empty the shard falls back to reading
all files it was offered (a warning is logged), so prompts are duplicated
across shards in that case.
"""

import csv
import glob
import logging
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
from torch.utils.data import IterableDataset

logger = logging.getLogger(__name__)


@dataclass
class CSVSpec:
    """Column names and optional score thresholds used when reading CSV shards."""

    # Column holding the prompt text.
    caption_field: str = "caption"
    # Score columns; only consulted when the matching ``min_*`` is set.
    clarity_field: str = "clarity_score"
    aesthetic_field: str = "aesthetic_score"
    # Rows whose score is missing, unparsable, NaN or below the threshold
    # are dropped.  ``None`` disables the filter.
    min_clarity: Optional[float] = None
    min_aesthetic: Optional[float] = None


def _expand_sources(path_or_glob: str) -> List[str]:
    """Resolve a prompt source specification into a list of existing files.

    Accepts:
      - file path
      - directory (all *.csv inside)
      - glob (Koala_36M_*.csv)
      - comma-separated list of any of the above

    Raises:
        FileNotFoundError: if nothing in the specification resolves to an
            existing path.
    """
    parts = [p.strip() for p in path_or_glob.split(",") if p.strip()]
    out: List[str] = []
    for p in parts:
        if any(ch in p for ch in "*?[]"):
            matches = sorted(glob.glob(p))
            if matches:
                out.extend(matches)
                continue
            # No match: the brackets may be literal characters of a real path,
            # so fall through and treat the entry as a plain path.
        pp = Path(p)
        if pp.is_dir():
            out.extend(sorted(str(x) for x in pp.glob("*.csv")))
        else:
            out.append(str(pp))

    out = [x for x in out if os.path.exists(x)]
    if not out:
        raise FileNotFoundError(f"No files found for prompt source: {path_or_glob}")
    return out


def _maybe_float(x: Optional[str]) -> Optional[float]:
    """Parse ``x`` as a float, returning ``None`` when it cannot be parsed."""
    try:
        return float(x)
    except (TypeError, ValueError):
        # TypeError covers ``None`` (column absent in a short row).
        return None


class PromptDataset(IterableDataset):
    """Stream prompts from txt or csv shards.

    Args:
        prompt_source: path/dir/glob/comma-list.
            For Koala, pass something like:
                --prompt_file "/data/Koala_36M_*.csv"
        shuffle_files: randomize shard order each epoch.
        shuffle_buffer: >0 enables approximate shuffle within a sliding buffer.
        seed: RNG seed.
        infinite: if True, loops over shards forever (recommended for num_steps training).
        csv: CSVSpec for caption field and optional score filtering.
        encoding: file encoding.

    Raises:
        FileNotFoundError: if ``prompt_source`` resolves to no existing file.
        ValueError: if a resolved file is neither .txt nor .csv, or the
            resolved files mix both types.
    """

    def __init__(
        self,
        prompt_source: str,
        shuffle_files: bool = True,
        shuffle_buffer: int = 0,
        seed: int = 42,
        infinite: bool = True,
        csv: Optional[CSVSpec] = None,
        encoding: str = "utf-8",
    ):
        super().__init__()
        self.files = _expand_sources(prompt_source)
        self.shuffle_files = shuffle_files
        self.shuffle_buffer = int(shuffle_buffer)
        self.seed = int(seed)
        self.infinite = bool(infinite)
        # NOTE: the ``csv`` parameter shadows the ``csv`` module inside this
        # method only; the name is part of the public signature.
        self.csvspec = csv or CSVSpec()
        self.encoding = encoding
        self.mode = self._detect_mode(self.files)

    @staticmethod
    def _detect_mode(files: List[str]) -> str:
        """Return "txt" or "csv" after checking every file has that extension."""
        modes = set()
        for fp in files:
            lowered = fp.lower()
            if lowered.endswith(".txt"):
                modes.add("txt")
            elif lowered.endswith(".csv"):
                modes.add("csv")
            else:
                raise ValueError(f"Unsupported prompt source type: {fp} (expect .txt or .csv)")
        if len(modes) != 1:
            # A single reader is used for all files, so mixing types would
            # silently parse some of them with the wrong reader.
            raise ValueError(
                f"Prompt source mixes .txt and .csv files; pass only one type: {files[:20]}"
            )
        return modes.pop()

    # -----------------------------------------------------------------------
    # Distributed sharding helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _dist_info() -> Tuple[int, int]:
        """Return ``(rank, world_size)``; ``(0, 1)`` when not distributed."""
        try:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                return dist.get_rank(), dist.get_world_size()
        except (ImportError, AttributeError, RuntimeError, ValueError) as err:
            # Best effort: a broken or absent distributed backend must not
            # prevent single-process iteration.
            logger.debug("torch.distributed probe failed, assuming single process: %s", err)
        return 0, 1

    @staticmethod
    def _shard_files(files: List[str], index: int, count: int) -> List[str]:
        """Return ``files[index::count]``, or all ``files`` if that is empty.

        The fallback keeps a shard from starving when there are fewer files
        than shards, at the cost of that shard duplicating data read by others.
        """
        if count <= 1:
            return list(files)
        sliced = list(files[index::count])
        if sliced:
            return sliced
        logger.warning(
            "Only %d file(s) for %d shards; shard %d falls back to all files "
            "and will duplicate prompts read by other shards.",
            len(files),
            count,
            index,
        )
        return list(files)

    def _get_dist_files(self) -> List[str]:
        """Return files assigned to this distributed rank (file-level sharding).

        With world_size=8 and 80 CSV shards, rank-r reads shards [r, r+8, r+16, ...].
        Falls back to all files when:
          - torch.distributed is not initialised, OR
          - world_size == 1, OR
          - sliced list would be empty (fewer files than ranks).
        """
        rank, world_size = self._dist_info()
        return self._shard_files(self.files, rank, world_size)

    def _get_rank(self) -> int:
        """Return current distributed rank (0 when not initialised)."""
        return self._dist_info()[0]

    # -----------------------------------------------------------------------
    # Iteration helpers
    # -----------------------------------------------------------------------

    def _buffered_shuffle(self, stream: Iterator[str], rng: random.Random) -> Iterator[str]:
        """Approximately shuffle ``stream`` through a ``shuffle_buffer``-sized pool."""
        buf: List[str] = []

        def pop_random() -> str:
            # Swap the pick with the tail so removal is O(1) instead of the
            # O(len(buf)) shift a mid-list ``pop(j)`` costs.
            j = rng.randrange(len(buf))
            buf[j], buf[-1] = buf[-1], buf[j]
            return buf.pop()

        for item in stream:
            buf.append(item)
            if len(buf) >= self.shuffle_buffer:
                yield pop_random()
        # Drain remainder
        while buf:
            yield pop_random()

    def _iter_shards(
        self,
        rng: random.Random,
        files: List[str],
        read_file: Callable[[str], Iterator[str]],
    ) -> Iterator[str]:
        """Yield prompts from ``files`` via ``read_file``, one pass per epoch.

        Honours ``shuffle_files`` (file order reshuffled every pass),
        ``shuffle_buffer`` and ``infinite``.

        Raises:
            RuntimeError: if ``infinite`` is set and a complete pass produced
                no prompt; looping again would spin forever without yielding.
        """
        order = list(files)
        while True:
            if self.shuffle_files:
                rng.shuffle(order)

            stream: Iterator[str] = (item for fp in order for item in read_file(fp))
            if self.shuffle_buffer > 0:
                stream = self._buffered_shuffle(stream, rng)

            produced = False
            for item in stream:
                produced = True
                yield item

            if not self.infinite:
                break
            if not produced:
                raise RuntimeError(
                    f"No prompts produced from {len(order)} file(s) "
                    f"(first: {order[0] if order else None}); "
                    "check file contents and CSV score filters."
                )

    def _iter_txt_file(self, fp: str) -> Iterator[str]:
        """Yield stripped lines of ``fp``, skipping blanks and '#' comments."""
        with open(fp, "r", encoding=self.encoding) as f:
            for line in f:
                s = line.strip()
                if not s or s.startswith("#"):
                    continue
                yield s

    def _iter_txt(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
        """Stream lines from .txt files, looping if infinite."""
        files = self.files if files is None else files
        yield from self._iter_shards(rng, files, self._iter_txt_file)

    @staticmethod
    def _meets_min(row: dict, field: str, threshold: Optional[float]) -> bool:
        """Return True if the filter is disabled or ``row[field] >= threshold``.

        Written as ``>=`` (not ``not <``) so that NaN scores are rejected.
        """
        if threshold is None:
            return True
        value = _maybe_float(row.get(field))
        return value is not None and value >= threshold

    def _iter_csv_file(self, fp: str) -> Iterator[str]:
        """Yield the caption column of one CSV file, applying score filters.

        Raises:
            KeyError: if the header lacks the caption column or a score column
                required by an enabled filter (including an empty file).
        """
        cs = self.csvspec
        with open(fp, "r", encoding=self.encoding, newline="") as f:
            reader = csv.DictReader(f)

            # Validate schema once.  ``fieldnames`` is None for an empty file.
            fieldnames = list(reader.fieldnames or [])
            required = [cs.caption_field]
            if cs.min_clarity is not None:
                required.append(cs.clarity_field)
            if cs.min_aesthetic is not None:
                required.append(cs.aesthetic_field)
            missing = [name for name in required if name not in fieldnames]
            if missing:
                raise KeyError(
                    f"CSV {fp} missing required field(s) {missing}. "
                    f"Got fields: {fieldnames[:20]}"
                )

            for row in reader:
                cap = (row.get(cs.caption_field) or "").strip()
                if not cap:
                    continue
                if not self._meets_min(row, cs.clarity_field, cs.min_clarity):
                    continue
                if not self._meets_min(row, cs.aesthetic_field, cs.min_aesthetic):
                    continue
                yield cap

    def _iter_csv(self, rng: random.Random, files: Optional[List[str]] = None) -> Iterator[str]:
        """Iterate CSV shards; approximate shuffle with buffer if requested."""
        files = self.files if files is None else files
        yield from self._iter_shards(rng, files, self._iter_csv_file)

    def __iter__(self) -> Iterator[str]:
        # Distributed rank-level file sharding (file-level, assigned once).
        dist_files = self._get_dist_files()
        rank = self._get_rank()

        # Every DataLoader worker iterates its own copy of this dataset, so
        # the rank's files are split across workers as well; otherwise each
        # worker would stream the same files.
        wi = torch.utils.data.get_worker_info()
        if wi is None:
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = wi.id, wi.num_workers
        files = self._shard_files(dist_files, worker_id, num_workers)

        # Each dataloader worker gets its own RNG (independent per worker AND
        # per distributed rank so shuffles don't collide across processes).
        rng = random.Random(self.seed + 1009 * worker_id + 97 * rank)

        if self.mode == "txt":
            yield from self._iter_txt(rng, files)
        else:
            yield from self._iter_csv(rng, files)


def make_collate_fn(tokenizer, max_prompt_length: int, device: torch.device):
    """Build a collate function that tokenizes List[str] into ``input_ids``.

    IMPORTANT: ``device`` is accepted for interface compatibility but is not
    used; the tokenizer output is returned as-is (a CPU tensor for standard
    tokenizers).  Move to GPU inside the training step (not in the dataloader
    worker).

    Args:
        tokenizer: callable accepting ``(prompts, **kwargs)`` and returning an
            object with an ``input_ids`` attribute.
        max_prompt_length: length every prompt is left-padded / truncated to.
        device: unused (see above).

    Returns:
        ``collate_fn(prompts) -> input_ids``.
    """
    tok_kwargs = {
        "max_length": max_prompt_length,
        "padding": "max_length",
        "padding_side": "left",
        "truncation": True,
        "return_tensors": "pt",
    }

    def collate_fn(prompts: List[str]) -> torch.Tensor:
        return tokenizer(prompts, **tok_kwargs).input_ids  # CPU tensor

    return collate_fn


class InfiniteDataLoader:
    """Wraps a DataLoader and cycles indefinitely (works for both map/iter datasets).

    When the wrapped loader is exhausted it is re-iterated from the start.
    If the fresh iterator is immediately exhausted too (empty loader),
    ``StopIteration`` propagates to the caller.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self._iter = iter(dataloader)

    def __next__(self):
        try:
            return next(self._iter)
        except StopIteration:
            # Start a new pass over the underlying loader.
            self._iter = iter(self.dataloader)
            return next(self._iter)

    def __iter__(self):
        return self