# src/dataset.py
"""
Production-ready dataset + dataloader utilities.

Rules (because we're adults):
- Data drives design. Inputs are rows with columns:
    ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"].
- Output per sample is a tiny dict the model actually needs. Nothing else.
- We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data.
- We shard by (FSDP rank × dataloader worker). No DistributedSampler needed.
- We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers".

Fields emitted per sample (for collate_fn and trainer):
{
  "species_name": str,
  "species_id": int,
  "protein_seq": str,      # raw AA (ESM tokenized later)
  "aa_len": int,
  "codon_ids": List[int],  # tokenized 3-mer ids + EOS at the end
  "refseq_id": str,
  "protein_refseq_id": str,
  "control_mode": "fixed",
  "meta": {"src": "parquet|csv", "file": basename, "row": int}
}

Invariants:
- cds_DNA length divisible by 3 after trimming to match protein length.
- DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data.
- We truncate both DNA and protein to the same min length (codon count).
- EOS appended to codon_ids; PAD is handled at collate time, not here.

Dependencies:
- pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly.
"""

from __future__ import annotations

import os
import json
import glob
import random
import logging
import heapq
from typing import Dict, List, Any, Optional, Iterable, Tuple
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info

try:
    from tqdm.auto import tqdm as _tqdm
except Exception:  # pragma: no cover - tqdm might be unavailable in minimal envs
    _tqdm = None

logger = logging.getLogger(__name__)


# ------------------------------
# Species Embedding Store (kept simple and stable)
# ------------------------------
class SpeciesEmbeddingStore:
    """Read-only, memory-mapped access to per-species embeddings.

    Two on-disk layouts are handled:

    - fixed-size: ``species_metadata.json`` + ``species_embeddings.bin``,
      mapped as a ``(num_species, Ds)`` matrix indexed by species id;
    - legacy: ``species_index.json`` + ``species_tok_emb.bin``, a flat
      ``(total_tokens, Ds)`` matrix where each species owns the
      ``offset``/``length`` slice recorded in the index.

    The legacy layout is chosen when ``pooling == "sequence"`` and the legacy
    files exist, or when the fixed-size files are missing.
    """

    def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"):
        """Load the vocabulary and memory-map whichever embedding layout applies.

        Raises:
            FileNotFoundError: if ``species_vocab.json`` (or a file required by
                the selected layout) is missing.
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.pin_memory = bool(pin_memory)  # stored only; not used by this class
        self.is_legacy = False
        self.pooling = pooling

        vocab_path = self.embeddings_dir / "species_vocab.json"
        if not vocab_path.exists():
            raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}")
        with open(vocab_path, "r") as f:
            self.vocab: Dict[str, int] = json.load(f)

        meta_path = self.embeddings_dir / "species_metadata.json"
        new_emb_path = self.embeddings_dir / "species_embeddings.bin"
        legacy_index = self.embeddings_dir / "species_index.json"
        legacy_emb = self.embeddings_dir / "species_tok_emb.bin"

        # Sequence pooling needs the variable-length layout when it is available.
        if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists():
            self.is_legacy = True
            self._load_legacy_format(dtype)
            return

        if meta_path.exists() and new_emb_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_species = int(meta["num_species"])
            self._ds = int(meta["embedding_dim"])
            self.embedding_type = str(meta.get("embedding_type", "fixed_size"))
            np_dtype = np.float16 if dtype == "float16" else np.float32
            self.embeddings = np.memmap(
                new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds)
            )
            self._np_dtype = np_dtype
            print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}")
        else:
            self.is_legacy = True
            self._load_legacy_format(dtype)

    def _load_legacy_format(self, dtype: str):
        """Memory-map the legacy variable-length layout.

        ``metadata.json`` (optional) may override the embedding width (default
        1024) and the on-disk dtype.

        Raises:
            FileNotFoundError: if the index or embedding file is missing.
            ValueError: if the file size is not a whole number of Ds-wide rows.
        """
        index_path = self.embeddings_dir / "species_index.json"
        if not index_path.exists():
            raise FileNotFoundError(f"Species index not found at {index_path}")
        with open(index_path, "r") as f:
            raw_index = json.load(f)
        # Index keys are normalized to strings; lookups use str(species_id).
        self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()}

        meta_path = self.embeddings_dir / "metadata.json"
        file_dtype = dtype
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self._ds = int(meta.get("embedding_dim", 1024))
            file_dtype = str(meta.get("dtype", dtype)).lower()
        else:
            self._ds = 1024

        emb_path = self.embeddings_dir / "species_tok_emb.bin"
        if not emb_path.exists():
            raise FileNotFoundError(f"Species embeddings not found at {emb_path}")

        np_dtype = np.float16 if file_dtype == "float16" else np.float32
        itemsize = np.dtype(np_dtype).itemsize
        file_bytes = os.path.getsize(emb_path)
        row_bytes = self._ds * itemsize
        if file_bytes % row_bytes != 0:
            raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})")
        total_tokens = file_bytes // row_bytes

        self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds))
        self._np_dtype = np_dtype
        self.num_species = len(self.vocab)
        print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.")

    def load_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the species-name -> id vocabulary."""
        return self.vocab.copy()

    def _deterministic_stub(self, length: Optional[int] = None) -> torch.FloatTensor:
        """Return an all-zero placeholder embedding for unknown species.

        Shape is ``(1, length, Ds)`` in legacy mode when ``length`` is truthy,
        otherwise ``(1, Ds)``.
        """
        if self.is_legacy and length:
            return torch.zeros(1, length, self._ds, dtype=torch.float32)
        return torch.zeros(1, self._ds, dtype=torch.float32)

    def get(self, species_id: int) -> torch.FloatTensor:
        """Return the float32 embedding for one species id.

        Fixed-size layout: ``(1, Ds)``; out-of-range ids give a zero stub.
        Legacy layout: ``(1, length, Ds)``; ids missing from the index give a
        zero stub of length 8.
        """
        if not self.is_legacy:
            if species_id < 0 or species_id >= getattr(self, "num_species", 0):
                return self._deterministic_stub()
            emb = self.embeddings[species_id]
            # Copy out of the read-only memmap before wrapping it in a tensor.
            return torch.from_numpy(np.asarray(emb).copy()).float().unsqueeze(0)

        entry = self.index.get(str(species_id))
        if entry is None:
            return self._deterministic_stub(length=8)
        offset = int(entry["offset"])
        length = int(entry["length"])
        view = self.embeddings[offset: offset + length]
        return torch.from_numpy(np.asarray(view).copy()).float().unsqueeze(0)

    def batch_get(self, species_ids: List[int]) -> Any:
        """Look up embeddings for a batch of species ids (list or tensor).

        Returns:
            Fixed-size layout: a ``(B, Ds)`` float32 tensor.
            Legacy layout: ``(padded, lengths)`` where ``padded`` is
            ``(B, Ls_max, Ds)`` zero-padded on the right and ``lengths`` is a
            ``(B,)`` long tensor.
        """
        if torch.is_tensor(species_ids):
            species_ids = species_ids.detach().cpu().tolist()
        else:
            species_ids = [int(x) for x in species_ids]

        B = len(species_ids)
        if not self.is_legacy:
            batch_emb = torch.zeros(B, self._ds, dtype=torch.float32)
            for i, sid in enumerate(species_ids):
                batch_emb[i] = self.get(sid).squeeze(0)
            return batch_emb

        tensors = [self.get(sid) for sid in species_ids]
        lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long)
        Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
        padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32)
        for i, t in enumerate(tensors):
            L = t.shape[1]
            padded[i, :L] = t.squeeze(0)
        return padded, lengths

    def Ds(self) -> int:
        """Return the embedding width."""
        return self._ds


_PARQUET_SUFFIXES = (".parquet", ".parq")
_CSV_SUFFIXES = (".csv", ".tsv", ".csv.gz", ".tsv.gz")
_TSV_SUFFIXES = (".tsv", ".tsv.gz")


def _is_parquet(path: str) -> bool:
    """True when ``path`` has a Parquet extension (case-insensitive)."""
    return path.lower().endswith(_PARQUET_SUFFIXES)


def _is_csv(path: str) -> bool:
    """True when ``path`` has a CSV/TSV extension, optionally gzipped."""
    return path.lower().endswith(_CSV_SUFFIXES)


def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]:
    """
    Expand a path/glob or list of them into a sorted, de-duplicated list of files.
    We prioritize parquet, then csv/tsv.

    Raises:
        FileNotFoundError: if nothing matches.
    """
    paths: List[str] = []
    if isinstance(maybe_path_or_glob, str):
        p = Path(maybe_path_or_glob)
        if p.is_dir():
            # Scan directory for parquet first, then csv/tsv
            for pattern in ("*.parquet", "*.parq", "*.csv", "*.tsv", "*.csv.gz", "*.tsv.gz"):
                paths.extend(sorted(str(x) for x in p.rglob(pattern)))
        else:
            paths = sorted(glob.glob(str(p)))
    else:
        for it in maybe_path_or_glob:
            paths.extend(_expand_paths(it))

    # Dedup while preserving order
    seen = set()
    out = []
    for x in paths:
        if x not in seen:
            out.append(x)
            seen.add(x)
    if not out:
        raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}")
    return out


def _dist_info() -> Tuple[int, int]:
    """
    Returns (num_global_workers, global_worker_id)
    where global_worker_id = rank * num_workers + worker_id.

    Falls back to a single rank when torch.distributed is unavailable or not
    initialized.
    """
    world_size = 1
    rank = 0
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
    except Exception:
        # Best effort by design: behave as a single process, but leave a trace.
        logger.debug("torch.distributed query failed; assuming single process", exc_info=True)
    wi = get_worker_info()
    nw = wi.num_workers if wi else 1
    wid = wi.id if wi else 0
    return world_size * nw, rank * nw + wid


class _ResumeSkipProgress:
    """Lightweight progress helper for resume skips.

    Uses a tqdm bar when tqdm is importable, otherwise periodic log lines.
    A non-positive ``total`` makes every method a no-op.
    """

    def __init__(self, total: int, label: str):
        self.total = int(max(0, total))
        self.label = label
        self.count = 0
        self._bar = None

        if self.total <= 0:
            return

        if _tqdm is not None:
            self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False)
        else:
            logger.info("%s: skipping %d samples to reach resume cursor", label, self.total)

    def update(self, n: int = 1):
        """Record ``n`` more skipped samples."""
        if self.total <= 0:
            return
        self.count += int(n)
        if self._bar is not None:
            self._bar.update(n)
        elif self.count == self.total or self.count % 10000 == 0:
            logger.info("%s: skipped %d / %d", self.label, self.count, self.total)

    def close(self):
        """Close the bar (if any) and log the final skip count."""
        if self.total <= 0:
            return
        if self._bar is not None:
            self._bar.close()
        logger.info("%s: resume skip finished (%d samples)", self.label, self.count)


class StreamSeqDataset(IterableDataset):
    """
    Streaming dataset with **non-overlapping Parquet row-group sharding**.

    - Accepts list of files (parquet and/or csv/tsv).
    - **Parquet**: we enumerate (file, row_group) tasks and distribute them across
      the *global* worker id to avoid duplicates and to keep all ranks busy even
      with few files.
    - **CSV/TSV**: assigned at file granularity (one worker reads a file). If you
      have only a few CSV files and many ranks, some ranks may get no CSV work.
      (Parquet is the recommended format at scale.)
    - CSV is read with pandas chunksize to keep memory usage sane; `.tsv` /
      `.tsv.gz` files are parsed with a tab separator.
    - Each Parquet task reads exactly **one row group** into pandas.

    Minimal resume support:
    - set_resume_skip(N) skips N samples across the worker's assigned tasks.
      (Use a **per-rank** skip value in your trainer so multi-node resumes stay
      in lockstep.)
    - Whole Parquet row groups are skipped by their metadata row count without
      being read, so rows that would have been filtered out still count toward
      the skip. CSV tasks are always skipped row by row because their weights
      are only size-based estimates.

    Output sample schema:
        {
          "species_name": str,
          "species_id": int,
          "protein_seq": str,      # raw AA (ESM tokenized later)
          "aa_len": int,
          "codon_ids": List[int],  # tokenized 3-mer ids + EOS at the end
          "refseq_id": str,
          "protein_refseq_id": str,
          "control_mode": "fixed",
          "meta": {"src": "parquet|csv", "file": basename, "row": int}
        }
    """

    # Canonical required columns. We also accept common aliases (e.g., 'taxon').
    REQUIRED = ["cds_DNA", "protein_seq", "Taxon"]

    def __init__(
        self,
        files: List[str],
        tokenizer,
        species_vocab_path: str,
        unknown_species_id: int = 0,
        csv_chunksize: int = 200_000,
        shuffle_buffer: int = 0,
        seed: int = 1234,
        shard_across_ranks: bool = True,
    ):
        """Store configuration and load the species vocabulary JSON.

        ``tokenizer`` must provide ``encode_codon_seq(seq, validate=...)`` and an
        EOS id at ``special_ids.eos`` (or ``_special_ids.eos``).
        """
        super().__init__()
        self.files = files
        self.tok = tokenizer
        with open(species_vocab_path, "r") as f:
            self.species_vocab: Dict[str, int] = json.load(f)
        self.unknown_species_id = int(unknown_species_id)
        self.csv_chunksize = int(max(1, csv_chunksize))
        self.shuffle_buffer = int(max(0, shuffle_buffer))
        self.seed = int(seed)
        # When False, every rank iterates over the full task list instead of
        # taking a disjoint shard. This keeps FSDP collectives aligned during
        # evaluation even if the validation dataset is smaller than WORLD_SIZE.
        self.shard_across_ranks = bool(shard_across_ranks)

        # Minimal resume cursor
        self._resume_skip_n: int = 0
        self._offset_start: int = 0
        self._emitted: int = 0

    # ---- resume cursor (minimal) ----
    def set_resume_skip(self, n: int) -> None:
        """Ask the next iteration to skip ``n`` samples (clamped to >= 0)."""
        n = int(max(0, n))
        self._resume_skip_n = n
        self._offset_start = n
        self._emitted = 0

    def get_stream_position(self) -> int:
        """Return the initial skip offset plus the samples yielded by this object."""
        return int(self._offset_start + self._emitted)

    # ---- core row-wise iterator on a pandas DataFrame ----
    def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]:
        """Clean one DataFrame and yield a sample dict per surviving row.

        Rows are dropped when the DNA is empty or contains anything other than
        ACGT, or when the shared codon/protein length is zero.

        Raises:
            ValueError: if a required column is missing.
        """
        # Normalize common column aliases before validating.
        # Some shards use lowercase `taxon` instead of `Taxon`.
        if "Taxon" not in df.columns and "taxon" in df.columns:
            df = df.rename(columns={"taxon": "Taxon"})

        # Hard fail if required missing
        for c in self.REQUIRED:
            if c not in df.columns:
                raise ValueError(f"Input missing required column '{c}' in {file}")

        has_refseq = "RefseqID" in df.columns
        keep_cols = self.REQUIRED + (["RefseqID"] if has_refseq else [])
        # Work on an explicit copy: the assignments below must not be chained
        # assignments on a slice of the caller's frame.
        df = df[keep_cols].copy()
        df["Taxon"] = df["Taxon"].astype(str).str.strip()
        df["protein_seq"] = df["protein_seq"].astype(str).str.strip().str.upper()
        df["cds_DNA"] = df["cds_DNA"].astype(str).str.strip().str.upper()

        # Filter DNA: ACGT only and length > 0
        ok_mask = (df["cds_DNA"].str.len() > 0) & df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False)
        df = df[ok_mask]
        if df.empty:
            return

        # Trim protein/DNA to shared min length (in codons)
        cds_codons = (df["cds_DNA"].str.len() // 3).astype(int).to_numpy()
        prot_len = df["protein_seq"].str.len().astype(int).to_numpy()
        min_len = np.minimum(cds_codons, prot_len)
        keep = min_len > 0
        if not keep.any():
            return
        df = df[keep]
        min_len = min_len[keep]

        vocab = self.species_vocab
        unknown = self.unknown_species_id

        def map_species(name: str) -> int:
            # A vocab value that is not int-like maps to the unknown id.
            try:
                return int(vocab.get(name, unknown))
            except (TypeError, ValueError):
                return unknown

        # Loop invariants, resolved once per frame instead of once per row.
        eos_id = self.tok.special_ids.eos if hasattr(self.tok, "special_ids") else self.tok._special_ids.eos
        file_stem = Path(file).stem
        file_base = os.path.basename(file)

        row_ids = df.index.tolist()
        taxa = df["Taxon"].tolist()
        cds_all = df["cds_DNA"].tolist()
        prot_all = df["protein_seq"].tolist()
        refs = df["RefseqID"].tolist() if has_refseq else [None] * len(row_ids)

        # Column lists + zip avoid building a Series per row (iterrows).
        for row_idx, taxon, cds_full, prot_full, ml, ref in zip(
            row_ids, taxa, cds_all, prot_all, min_len.tolist(), refs
        ):
            ml = int(ml)
            cds = cds_full[: ml * 3]
            prot = prot_full[:ml]

            # Tokenize DNA -> 3-mer ids; append EOS. list() keeps us from
            # mutating whatever sequence object the tokenizer handed back.
            codon_ids = list(self.tok.encode_codon_seq(cds, validate=False))
            codon_ids.append(eos_id)

            ref_id = ref if has_refseq else f"{file_stem}:{int(row_idx)}"

            yield {
                "species_name": taxon,
                "species_id": int(map_species(taxon)),
                "protein_seq": prot,
                "aa_len": len(prot),
                "codon_ids": codon_ids,
                "refseq_id": ref_id,
                "protein_refseq_id": ref_id,
                "control_mode": "fixed",
                "meta": {"src": src, "file": file_base, "row": int(row_idx)},
            }

    # ---- Parquet helpers: enumerate row-group tasks & read one row group ----
    def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]:
        """
        Return a task list of tuples:
            ("parquet", path, row_group_idx, weight) for each row group in each Parquet file
            ("csv", path, None, weight)              for each CSV/TSV file

        Parquet weights are row counts from file metadata (minimum 1); CSV
        weights are a size-based estimate. Files with neither extension are
        ignored.

        Raises:
            ImportError: if Parquet files are given but pyarrow is unavailable.
        """
        tasks: List[Tuple[str, str, Optional[int], int]] = []
        parquet_files = [f for f in files if _is_parquet(f)]
        csv_files = [f for f in files if _is_csv(f)]

        if parquet_files:
            try:
                import pyarrow.parquet as pq  # type: ignore
            except Exception as e:
                raise ImportError("pyarrow is required to read parquet files") from e

            for fp in parquet_files:
                pf = pq.ParquetFile(fp)
                nrg = int(pf.num_row_groups or 0)
                if nrg <= 0:
                    # No row groups means there is nothing to read; scheduling
                    # row group 0 would point at a group that does not exist.
                    logger.warning("Skipping parquet file without row groups: %s", fp)
                    continue
                meta = pf.metadata
                for rg in range(nrg):
                    num_rows = 0
                    if meta is not None:
                        rg_rows = meta.row_group(rg).num_rows
                        num_rows = rg_rows if rg_rows is not None else 0
                    tasks.append(("parquet", fp, rg, max(1, int(num_rows))))

        # CSV/TSV files remain file-level tasks
        for fp in csv_files:
            file_size = os.path.getsize(fp)
            # Assume ~256 bytes per record when estimating CSV row counts (empirical default)
            est_rows = max(1, int(file_size // 256))
            tasks.append(("csv", fp, None, est_rows))

        # Keep a deterministic order
        # (files are already sorted by _expand_paths)
        return tasks

    @staticmethod
    def _balanced_partition(
        tasks: List[Tuple[str, str, Optional[int], int]], groups: int
    ) -> List[List[Tuple[str, str, Optional[int], int]]]:
        """Split ``tasks`` into ``groups`` buckets of roughly equal total weight.

        Heaviest tasks are placed first, each into the currently lightest
        bucket; ties prefer the earlier task and the lower bucket index, so the
        result is deterministic. Within a bucket the original order is kept.
        """
        if groups <= 1:
            return [tasks]
        if not tasks:
            return [[] for _ in range(groups)]

        # Greedy load balancing: assign heavier tasks first to the lightest bucket.
        indexed = [(idx, kind, path, rg, weight) for idx, (kind, path, rg, weight) in enumerate(tasks)]
        tasks_sorted = sorted(
            indexed,
            key=lambda entry: (entry[4], -entry[0]),
            reverse=True,
        )
        heap: List[Tuple[int, int]] = [(0, bucket_idx) for bucket_idx in range(groups)]
        heapq.heapify(heap)
        buckets: List[List[Tuple[int, str, str, Optional[int], int]]] = [[] for _ in range(groups)]

        for original_index, kind, path, rg, weight in tasks_sorted:
            load, bucket_idx = heapq.heappop(heap)
            buckets[bucket_idx].append((original_index, kind, path, rg, weight))
            heapq.heappush(heap, (load + weight, bucket_idx))

        partitions: List[List[Tuple[str, str, Optional[int], int]]] = []
        for bucket in buckets:
            bucket.sort(key=lambda entry: entry[0])
            partitions.append([(kind, path, rg, weight) for (_idx, kind, path, rg, weight) in bucket])
        return partitions

    def _parquet_rowgroup_iter(
        self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Read one Parquet row group (needed columns only) and yield its samples."""
        import pyarrow.parquet as pq  # safe: checked in _enumerate_tasks

        pf = pq.ParquetFile(file)
        # Cache the column subset per file so we don't recompute
        if file not in cols_cache:
            names = set(pf.schema.names)
            cols: List[str] = []
            # Required columns, with alias support (notably Taxon vs taxon).
            # A column that is absent is left out here; _iter_df then raises.
            for c in self.REQUIRED:
                if c in names:
                    cols.append(c)
                elif c == "Taxon" and "taxon" in names:
                    cols.append("taxon")
            # Optional debug id
            if "RefseqID" in names:
                cols.append("RefseqID")
            cols_cache[file] = cols
        cols = cols_cache[file]
        table = pf.read_row_group(row_group_idx, columns=cols)
        df = table.to_pandas(types_mapper=None)
        yield from self._iter_df(df, "parquet", file)

    def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]:
        """Stream one CSV/TSV file in chunks and yield its samples."""
        # One worker owns this file (non-overlapping assignment)
        # _is_csv admits .tsv/.tsv.gz, which need a tab separator to parse.
        sep = "\t" if file.lower().endswith(_TSV_SUFFIXES) else ","
        reader = pd.read_csv(
            file, sep=sep, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False
        )
        try:
            for chunk in reader:
                yield from self._iter_df(chunk, "csv", file)
        finally:
            # Release the file handle even if the consumer stops early.
            reader.close()

    # ---- main iterator ----
    def __iter__(self):
        """Yield this worker's samples, honoring resume skip and shuffle buffer."""
        wi = get_worker_info()
        num_workers = wi.num_workers if wi else 1
        worker_id = wi.id if wi else 0
        num_global, gid = _dist_info()
        if not self.shard_across_ranks:
            num_global = max(1, num_workers)
            gid = worker_id
        workers_per_rank = max(1, num_workers)
        rank = gid // workers_per_rank if self.shard_across_ranks else 0
        world = max(1, num_global // workers_per_rank)

        # Each rank may have a non-zero per-rank resume skip. Split evenly across local
        # dataloader workers so the sum equals the per-rank target.
        base, rem = divmod(int(self._resume_skip_n), workers_per_rank)
        local_skip_target = base + (1 if worker_id < rem else 0)

        # Build the global task list (parquet row groups + csv files) and shard by gid
        tasks = self._enumerate_tasks(self.files)
        if tasks:
            partitions = self._balanced_partition(tasks, max(1, num_global))
            my_tasks_full = partitions[gid] if gid < len(partitions) else []
        else:
            my_tasks_full = []

        progress: Optional[_ResumeSkipProgress] = None
        if local_skip_target > 0 and worker_id == 0:
            label = (
                "resume skip"
                if world == 1
                else f"resume skip (rank {rank}/{world})"
            )
            progress = _ResumeSkipProgress(local_skip_target, label)

        # Fast task-level skip: drop leading Parquet row groups whose row count
        # fits inside the remaining skip, without reading them. CSV weights are
        # size-based guesses, so the fast path stops at the first CSV task and
        # the rest of the skip is done row by row below.
        skip_remaining = int(local_skip_target)
        start_idx = 0
        for idx, (kind, _path, _rg, weight) in enumerate(my_tasks_full):
            w = int(weight)
            if skip_remaining <= 0 or kind != "parquet" or w > skip_remaining:
                break
            skip_remaining -= w
            start_idx = idx + 1
            if progress is not None:
                progress.update(w)
        if skip_remaining == 0 and progress is not None:
            progress.close()
            progress = None

        # The first remaining task may be only partially skipped. It is handled
        # by the row-level skip in the single loop below, so it is read once.
        my_tasks = my_tasks_full[start_idx:]

        rng = random.Random(self.seed + gid)
        buffer: List[Dict[str, Any]] = []
        bufN = self.shuffle_buffer

        # Samples accounted for by the resume skip so far (task + row level)
        skipped = int(local_skip_target - skip_remaining)

        # Cache for per-file Parquet column selection
        cols_cache: Dict[str, List[str]] = {}

        try:
            for kind, path, rg, _w in my_tasks:
                if kind == "parquet":
                    if rg is None:
                        raise ValueError(f"Parquet task without row group index: {path}")
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    # Apply any remaining resume skip across the flattened stream
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                            if skip_remaining == 0:
                                # Finish the progress bar once we've consumed the target
                                progress.close()
                                progress = None
                        continue

                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                        continue

                    buffer.append(sample)
                    if len(buffer) >= bufN:
                        # Swap a random element to the end and emit it.
                        j = rng.randrange(len(buffer))
                        buffer[j], buffer[-1] = buffer[-1], buffer[j]
                        self._emitted += 1
                        yield buffer.pop()

            # Flush leftovers (the buffer is only ever filled when bufN > 0)
            if buffer:
                rng.shuffle(buffer)
                for it in buffer:
                    self._emitted += 1
                    yield it
                buffer.clear()
        finally:
            if progress is not None:
                progress.close()
            if local_skip_target > 0:
                # Persist any remaining leftover skip (including partial progress) per worker copy
                self._resume_skip_n = max(local_skip_target - skipped, 0)
# ------------------------------
# Simple collate: end-only pad for codon stream, pass-through everything else
# ------------------------------
def stage_collate_fn(batch: List[Dict[str, Any]], pad_id: int = 0) -> Dict[str, Any]:
    """Collate dataset samples into one batch dict.

    Args:
        batch: sample dicts as emitted by the dataset.
        pad_id: id used to right-pad ``codon_ids``. Defaults to 0, the value
            this function previously hard-coded.

    Returns:
        ``{}`` for an empty batch, otherwise a dict with
        ``species_ids`` (``(B,)`` long tensor), ``protein_seqs`` (list of str),
        ``codon_ids`` (``(B, max_len)`` long tensor, right-padded) and
        ``control_mode`` (taken from the first sample), plus ``refseq_id`` /
        ``protein_refseq_id`` lists when the first sample carries those keys.
    """
    B = len(batch)
    if B == 0:
        return {}

    # species ids
    species_ids = torch.tensor([int(x.get("species_id", 0)) for x in batch], dtype=torch.long)

    # raw protein sequences stay as list[str] (ESM handles tokenization)
    protein_seqs = [str(x.get("protein_seq", "M")) for x in batch]

    # Build padded codon ids (right padding). Keep EOS inside the sequence
    # (already appended in dataset).
    codon_lists = [x.get("codon_ids", []) for x in batch]
    max_len = max(len(c) for c in codon_lists)
    codon_ids = torch.full((B, max_len), int(pad_id), dtype=torch.long)
    for i, row in enumerate(codon_lists):
        if len(row) > 0:
            codon_ids[i, : len(row)] = torch.tensor(row, dtype=torch.long)

    first = batch[0]
    out: Dict[str, Any] = {
        "species_ids": species_ids,
        "protein_seqs": protein_seqs,
        "codon_ids": codon_ids,
        "control_mode": first.get("control_mode", "fixed"),
    }

    # Optional passthroughs
    for key in ("refseq_id", "protein_refseq_id"):
        if key in first:
            out[key] = [x.get(key) for x in batch]

    return out


def _build_dataset(
    path_or_paths: str | List[str],
    tokenizer,
    species_vocab_path: str,
    shuffle_buffer: int,
    csv_chunksize: int,
    shard_across_ranks: bool = True,
) -> StreamSeqDataset:
    """Expand the input path(s) and wrap them in a StreamSeqDataset.

    The unknown-species id (0) and the seed (1234) are fixed here.
    """
    files = _expand_paths(path_or_paths)
    return StreamSeqDataset(
        files=files,
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        unknown_species_id=0,
        csv_chunksize=csv_chunksize,
        shuffle_buffer=shuffle_buffer,
        seed=1234,
        shard_across_ranks=shard_across_ranks,
    )


def create_precomputed_dataloaders(
    train_path: str | List[str],
    val_path: Optional[str | List[str]],
    embeddings_dir: str,
    tokenizer,
    batch_size: int,
    num_workers: int = 4,
    species_pooling: str = "sequence",
    csv_chunksize: int = 200_000,
    train_shuffle_buffer: int = 8192,
    val_shuffle_buffer: int = 0,
) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]:
    """
    Build the train/val DataLoaders and the species embedding store.

    The species vocabulary is read from ``<embeddings_dir>/species_vocab.json``.
    A validation loader is created only when ``val_path`` is truthy.

    Returns:
      - train_loader, val_loader (optional), and the SpeciesEmbeddingStore
    """
    species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling)
    species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json")
    num_workers = int(max(0, num_workers))

    train_ds = _build_dataset(
        path_or_paths=train_path,
        tokenizer=tokenizer,
        species_vocab_path=species_vocab_path,
        shuffle_buffer=int(train_shuffle_buffer),
        csv_chunksize=int(csv_chunksize),
    )
    val_ds = None
    if val_path:
        val_ds = _build_dataset(
            path_or_paths=val_path,
            tokenizer=tokenizer,
            species_vocab_path=species_vocab_path,
            shuffle_buffer=int(val_shuffle_buffer),
            csv_chunksize=int(csv_chunksize),
        )

    # NOTE: IterableDataset can't be shuffled by DataLoader. We already "shuffle" inside the dataset.
    kwargs_common = dict(
        num_workers=num_workers,
        collate_fn=stage_collate_fn,
        pin_memory=True,
        persistent_workers=(num_workers > 0),
    )
    # prefetch_factor is only passed when worker processes are in use.
    if num_workers > 0:
        kwargs_common["prefetch_factor"] = 4

    # Drop last for train to keep batch shapes stable under FSDP.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **kwargs_common,
    )

    val_loader = None
    if val_ds is not None:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **kwargs_common,
        )

    return train_loader, val_loader, species_store