| |
| """ |
| Production-ready dataset + dataloader utilities. |
| |
| Rules (because we're adults): |
| - Data drives design. Inputs are rows with columns: ["cds_DNA", "protein_seq", "Taxon", (optional) "RefseqID"]. |
| - Output per sample is a tiny dict the model actually needs. Nothing else. |
| - We stream Parquet by row groups, CSV by chunks. No full-file pandas nonsense on big data. |
| - We shard by (FSDP rank × dataloader worker). No DistributedSampler needed. |
| - We do a simple streaming shuffle buffer for train. Good enough. No fancy "epoch managers". |
| |
| Fields emitted per sample (for collate_fn and trainer): |
| { |
| "species_name": str, |
| "species_id": int, |
| "protein_seq": str, # raw AA (ESM tokenized later) |
| "aa_len": int, |
| "codon_ids": List[int], # tokenized 3-mer ids + EOS at the end |
| "refseq_id": str, |
| "protein_refseq_id": str, |
| "control_mode": "fixed", |
| "meta": {"src": "parquet|csv", "file": basename, "row": int} |
| } |
| |
| Invariants: |
| - cds_DNA length divisible by 3 after trimming to match protein length. |
| - DNA uses only ACGT (uppercase). If not, we skip the row. We don't "helpfully fix" broken data. |
| - We truncate both DNA and protein to the same min length (codon count). |
| - EOS appended to codon_ids; PAD is handled at collate time, not here. |
| |
| Dependencies: |
| - pyarrow only if you read parquet. If it isn't installed and you pass parquet files, we fail loudly. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import json |
| import glob |
| import random |
| import logging |
| import heapq |
| from typing import Dict, List, Any, Optional, Iterable, Tuple |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from torch.utils.data import IterableDataset, Dataset, DataLoader, get_worker_info |
|
|
| try: |
| from tqdm.auto import tqdm as _tqdm |
| except Exception: |
| _tqdm = None |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| |
| |
|
|
class SpeciesEmbeddingStore:
    """Memory-mapped lookup of precomputed species embeddings.

    Two on-disk layouts under ``embeddings_dir`` are understood:

    * Fixed-size: ``species_metadata.json`` + ``species_embeddings.bin``,
      one ``embedding_dim`` vector per species. ``get`` returns ``(1, Ds)``.
    * Legacy (variable length): ``species_index.json`` +
      ``species_tok_emb.bin`` (plus an optional ``metadata.json``). Each
      species owns a ``[offset, offset + length)`` slice of token rows and
      ``get`` returns ``(1, length, Ds)``.

    ``species_vocab.json`` (species name -> id) is required in both cases.
    """

    def __init__(self, embeddings_dir: str, dtype: str = "float32", pin_memory: bool = False, pooling: str = "last"):
        """Open the embedding files found in ``embeddings_dir``.

        Args:
            embeddings_dir: Directory holding the vocabulary and embedding files.
            dtype: ``"float16"`` selects half precision for the memmap; any
                other value means float32. The legacy loader lets
                ``metadata.json`` override it.
            pin_memory: Stored on the instance; nothing in this class reads it.
            pooling: ``"sequence"`` prefers the legacy variable-length files
                when both of them exist.

        Raises:
            FileNotFoundError: If the vocabulary or the files of the selected
                layout are missing.
            ValueError: If the legacy embedding file size is not a multiple of
                one embedding row.
        """
        self.embeddings_dir = Path(embeddings_dir)
        self.pin_memory = bool(pin_memory)
        self.is_legacy = False
        self.pooling = pooling

        vocab_path = self.embeddings_dir / "species_vocab.json"
        if not vocab_path.exists():
            raise FileNotFoundError(f"Species vocabulary not found at {vocab_path}")
        with open(vocab_path, "r") as f:
            self.vocab: Dict[str, int] = json.load(f)

        meta_path = self.embeddings_dir / "species_metadata.json"
        new_emb_path = self.embeddings_dir / "species_embeddings.bin"
        legacy_index = self.embeddings_dir / "species_index.json"
        legacy_emb = self.embeddings_dir / "species_tok_emb.bin"

        # Sequence pooling needs per-token embeddings, which only the legacy
        # layout provides, so it wins whenever its files are present.
        if self.pooling == "sequence" and legacy_index.exists() and legacy_emb.exists():
            self.is_legacy = True
            self._load_legacy_format(dtype)
            return

        if meta_path.exists() and new_emb_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self.num_species = int(meta["num_species"])
            self._ds = int(meta["embedding_dim"])
            self.embedding_type = str(meta.get("embedding_type", "fixed_size"))
            np_dtype = np.float16 if dtype == "float16" else np.float32
            self.embeddings = np.memmap(new_emb_path, dtype=np_dtype, mode="r", shape=(self.num_species, self._ds))
            self._np_dtype = np_dtype
            print(f"Loaded fixed-size species embeddings: {len(self.vocab)} species, Ds={self._ds}, dtype={self._np_dtype}")
        else:
            # No fixed-size files: fall back to the legacy layout (which raises
            # FileNotFoundError itself if its files are missing too).
            self.is_legacy = True
            self._load_legacy_format(dtype)

    def _load_legacy_format(self, dtype: str):
        """Memory-map the legacy variable-length embedding file and its index."""
        index_path = self.embeddings_dir / "species_index.json"
        if not index_path.exists():
            raise FileNotFoundError(f"Species index not found at {index_path}")
        with open(index_path, "r") as f:
            raw_index = json.load(f)
        # Keys are normalised to str because get() looks entries up by str(id).
        self.index: Dict[str, Dict[str, int]] = {str(k): v for k, v in raw_index.items()}

        meta_path = self.embeddings_dir / "metadata.json"
        file_dtype = dtype
        if meta_path.exists():
            with open(meta_path, "r") as f:
                meta = json.load(f)
            self._ds = int(meta.get("embedding_dim", 1024))
            file_dtype = str(meta.get("dtype", dtype)).lower()
        else:
            self._ds = 1024

        emb_path = self.embeddings_dir / "species_tok_emb.bin"
        if not emb_path.exists():
            raise FileNotFoundError(f"Species embeddings not found at {emb_path}")

        np_dtype = np.float16 if file_dtype == "float16" else np.float32
        itemsize = np.dtype(np_dtype).itemsize
        file_bytes = os.path.getsize(emb_path)
        row_bytes = self._ds * itemsize
        if file_bytes % row_bytes != 0:
            raise ValueError(f"Emb file size {file_bytes} not divisible by Ds*itemsize ({self._ds}*{itemsize})")
        total_tokens = file_bytes // row_bytes

        self.embeddings = np.memmap(emb_path, dtype=np_dtype, mode="r", shape=(total_tokens, self._ds))
        self._np_dtype = np_dtype
        self.num_species = len(self.vocab)
        print(f"[LEGACY] variable-length embeddings: {len(self.vocab)} species, {total_tokens} tokens total, Ds={self._ds}.")

    def load_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of the species-name -> id vocabulary."""
        return self.vocab.copy()

    def _deterministic_stub(self, length: Optional[int] = None) -> torch.FloatTensor:
        """Return an all-zero placeholder embedding for an unknown species.

        Shape is ``(1, length, Ds)`` in legacy mode when ``length`` is truthy,
        otherwise ``(1, Ds)``.
        """
        if self.is_legacy and length:
            return torch.zeros(1, length, self._ds, dtype=torch.float32)
        return torch.zeros(1, self._ds, dtype=torch.float32)

    def get(self, species_id: int) -> torch.FloatTensor:
        """Return the float32 embedding of one species with a leading batch dim.

        Unknown ids yield the zero stub instead of raising.
        """
        # Coerce first: the legacy index is keyed by str(id), and str() of a
        # tensor is "tensor(3)", which would silently miss and return zeros.
        species_id = int(species_id)

        if not self.is_legacy:
            if species_id < 0 or species_id >= self.num_species:
                return self._deterministic_stub()
            # np.array copies the row out of the read-only memmap.
            row = np.array(self.embeddings[species_id])
            return torch.from_numpy(row).float().unsqueeze(0)

        entry = self.index.get(str(species_id))
        if entry is None:
            return self._deterministic_stub(length=8)
        offset = int(entry["offset"])
        length = int(entry["length"])
        rows = np.array(self.embeddings[offset: offset + length])
        return torch.from_numpy(rows).float().unsqueeze(0)

    def batch_get(self, species_ids: List[int]) -> Any:
        """Look up several species at once.

        Args:
            species_ids: A sequence of ids or a tensor of ids.

        Returns:
            Fixed-size layout: a ``(B, Ds)`` float32 tensor.
            Legacy layout: ``(padded, lengths)`` where ``padded`` is a
            zero-padded ``(B, Ls_max, Ds)`` tensor and ``lengths`` holds each
            species' token count.
        """
        if torch.is_tensor(species_ids):
            species_ids = species_ids.detach().cpu().tolist()
        else:
            species_ids = [int(x) for x in species_ids]
        B = len(species_ids)

        if not self.is_legacy:
            batch_emb = torch.zeros(B, self._ds, dtype=torch.float32)
            for i, sid in enumerate(species_ids):
                batch_emb[i] = self.get(sid).squeeze(0)
            return batch_emb

        tensors = [self.get(sid) for sid in species_ids]
        lengths = torch.tensor([t.shape[1] for t in tensors], dtype=torch.long)
        Ls_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
        padded = torch.zeros(B, Ls_max, self._ds, dtype=torch.float32)
        for i, t in enumerate(tensors):
            padded[i, : t.shape[1]] = t.squeeze(0)
        return padded, lengths

    def Ds(self) -> int:
        """Return the embedding dimension."""
        return self._ds
|
|
def _is_parquet(path: str) -> bool:
    """Return True when ``path`` has a Parquet extension (case-insensitive)."""
    return path.lower().endswith((".parquet", ".parq"))
|
|
|
|
def _is_csv(path: str) -> bool:
    """Return True for CSV/TSV paths, plain or gzip-compressed (case-insensitive)."""
    delimited_suffixes = (".csv", ".tsv", ".csv.gz", ".tsv.gz")
    return path.lower().endswith(delimited_suffixes)
|
|
|
|
def _expand_paths(maybe_path_or_glob: str | List[str]) -> List[str]:
    """Expand a path, glob pattern, or list of them into a de-duplicated file list.

    A directory is searched recursively and contributes its files grouped by
    extension (parquet, parq, csv, tsv, csv.gz, tsv.gz), each group sorted, so
    Parquet files come first. Anything else is treated as a glob pattern and
    its matches are sorted. For a list, results are concatenated in input
    order. Only the first occurrence of a repeated path is kept.

    Args:
        maybe_path_or_glob: A ``str`` or ``os.PathLike`` path/glob, or an
            iterable of such items.

    Returns:
        The expanded file paths as strings.

    Raises:
        FileNotFoundError: If nothing matches (also raised for any single
            item of a list that matches nothing).
    """
    found: List[str] = []
    # os.PathLike is accepted as well as str: a pathlib.Path is not iterable,
    # so treating it as a list of items would fail with a TypeError.
    if isinstance(maybe_path_or_glob, (str, os.PathLike)):
        root = Path(os.fspath(maybe_path_or_glob))
        if root.is_dir():
            for pattern in ("*.parquet", "*.parq", "*.csv", "*.tsv", "*.csv.gz", "*.tsv.gz"):
                found.extend(sorted(str(x) for x in root.rglob(pattern)))
        else:
            found = sorted(glob.glob(str(root)))
    else:
        for item in maybe_path_or_glob:
            found.extend(_expand_paths(item))

    # dict preserves insertion order, so this drops repeats but keeps order.
    unique = list(dict.fromkeys(found))
    if not unique:
        raise FileNotFoundError(f"No input files found for: {maybe_path_or_glob}")
    return unique
|
|
|
|
def _dist_info() -> Tuple[int, int]:
    """Describe this process' slot among all dataloader workers of all ranks.

    Returns:
        ``(num_global_workers, global_worker_id)`` with
        ``num_global_workers = world_size * num_workers`` and
        ``global_worker_id = rank * num_workers + worker_id``. Without an
        initialised process group the rank is 0 of 1; outside a dataloader
        worker the worker is 0 of 1.
    """
    world_size, rank = 1, 0
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
    except Exception:
        # Best effort: any failure here means "behave as a single process".
        pass

    info = get_worker_info()
    if info:
        workers, local_id = info.num_workers, info.id
    else:
        workers, local_id = 1, 0
    return world_size * workers, rank * workers + local_id
|
|
|
|
class _ResumeSkipProgress:
    """Report progress while samples are skipped to reach a resume cursor.

    Uses a tqdm bar when tqdm could be imported, otherwise falls back to
    periodic log lines. With a non-positive ``total`` every method is a no-op.
    """

    def __init__(self, total: int, label: str):
        self.total = int(max(0, total))
        self.label = label
        self.count = 0
        self._bar = None

        if self.total > 0:
            if _tqdm is None:
                logger.info("%s: skipping %d samples to reach resume cursor", label, self.total)
            else:
                self._bar = _tqdm(total=self.total, desc=label, unit="sample", dynamic_ncols=True, leave=False)

    def update(self, n: int = 1):
        """Record ``n`` more skipped samples."""
        if self.total <= 0:
            return
        self.count += int(n)
        if self._bar is not None:
            self._bar.update(n)
            return
        # Without a bar, log only on completion or on exact multiples of 10000.
        reached_end = self.count == self.total
        if reached_end or self.count % 10000 == 0:
            logger.info("%s: skipped %d / %d", self.label, self.count, self.total)

    def close(self):
        """Close the bar (if any) and log the final skip count."""
        if self.total <= 0:
            return
        bar = self._bar
        if bar is not None:
            bar.close()
        logger.info("%s: resume skip finished (%d samples)", self.label, self.count)
|
|
|
|
class StreamSeqDataset(IterableDataset):
    """Streaming dataset with non-overlapping sharding across ranks and workers.

    - Accepts a list of files (parquet and/or csv/tsv); other files are ignored.
    - **Parquet**: every (file, row_group) pair is one task. Each task reads
      exactly one row group into pandas.
    - **CSV/TSV**: one task per file, read with a pandas ``chunksize`` to keep
      memory usage sane. ``.tsv`` / ``.tsv.gz`` files are parsed tab-separated.
      With few CSV files and many workers, some workers may get no CSV work.
    - Tasks are split across the *global* worker id with a greedy
      largest-weight-first balanced partition, so no two workers read the
      same task.

    Minimal resume support:
    - ``set_resume_skip(N)`` stores a **per-rank** number of samples to skip;
      ``__iter__`` divides it among this rank's dataloader workers.

    Output sample schema:
    {
      "species_name": str,
      "species_id": int,
      "protein_seq": str,        # raw AA (ESM tokenized later)
      "aa_len": int,
      "codon_ids": List[int],    # tokenized 3-mer ids + EOS at the end
      "refseq_id": str,
      "protein_refseq_id": str,
      "control_mode": "fixed",
      "meta": {"src": "parquet|csv", "file": basename, "row": int}
    }
    """

    # Columns every input table must provide ("taxon" is accepted for "Taxon").
    REQUIRED = ["cds_DNA", "protein_seq", "Taxon"]

    def __init__(
        self,
        files: List[str],
        tokenizer,
        species_vocab_path: str,
        unknown_species_id: int = 0,
        csv_chunksize: int = 200_000,
        shuffle_buffer: int = 0,
        seed: int = 1234,
        shard_across_ranks: bool = True,
    ):
        """Store the configuration and load the species vocabulary.

        Args:
            files: Input file paths.
            tokenizer: Object providing ``encode_codon_seq(cds, validate=False)``
                and an EOS id at ``special_ids.eos`` (or ``_special_ids.eos``).
            species_vocab_path: JSON file mapping species name -> id.
            unknown_species_id: Id used for species missing from the vocabulary.
            csv_chunksize: Rows per pandas chunk when reading CSV/TSV (min 1).
            shuffle_buffer: Size of the streaming shuffle buffer; 0 disables it.
            seed: Base seed; each global worker uses ``seed + global_worker_id``.
            shard_across_ranks: When False, tasks are split across this
                process' dataloader workers only (every rank sees all data).
        """
        super().__init__()
        self.files = files
        self.tok = tokenizer
        with open(species_vocab_path, "r") as f:
            self.species_vocab: Dict[str, int] = json.load(f)
        self.unknown_species_id = int(unknown_species_id)
        self.csv_chunksize = int(max(1, csv_chunksize))
        self.shuffle_buffer = int(max(0, shuffle_buffer))
        self.seed = int(seed)
        self.shard_across_ranks = bool(shard_across_ranks)

        # Resume bookkeeping: samples still to skip, where the stream started,
        # and how many samples this object has yielded since then.
        self._resume_skip_n: int = 0
        self._offset_start: int = 0
        self._emitted: int = 0

    def set_resume_skip(self, n: int) -> None:
        """Request that the next iteration skips ``n`` samples (per rank)."""
        n = int(max(0, n))
        self._resume_skip_n = n
        self._offset_start = n
        self._emitted = 0

    def get_stream_position(self) -> int:
        """Return the resume offset plus the samples emitted by this object."""
        return int(self._offset_start + self._emitted)

    def _iter_df(self, df: pd.DataFrame, src: str, file: str) -> Iterable[Dict[str, Any]]:
        """Validate one table chunk and yield a sample dict per usable row.

        Rows are dropped when a sequence is null, when the DNA is not pure
        upper-case ACGT after normalisation, or when fewer than one full
        codon/residue pair remains. DNA and protein are truncated to the same
        codon count.

        Raises:
            ValueError: If a required column is missing.
        """
        if "Taxon" not in df.columns and "taxon" in df.columns:
            df = df.rename(columns={"taxon": "Taxon"})

        for c in self.REQUIRED:
            if c not in df.columns:
                raise ValueError(f"Input missing required column '{c}' in {file}")

        has_refseq = "RefseqID" in df.columns
        keep_cols = self.REQUIRED + (["RefseqID"] if has_refseq else [])

        # Drop null sequences *before* astype(str): otherwise None/NaN become
        # the literal text "None"/"nan" and would be emitted as a protein.
        # .copy() makes the column assignments below act on an owned frame.
        present = df["cds_DNA"].notna() & df["protein_seq"].notna()
        df = df.loc[present, keep_cols].copy()
        if df.empty:
            return

        df["Taxon"] = df["Taxon"].astype(str).str.strip()
        df["protein_seq"] = df["protein_seq"].astype(str).str.strip().str.upper()
        df["cds_DNA"] = df["cds_DNA"].astype(str).str.strip().str.upper()

        # "[ACGT]+" also rejects the empty string. Bad rows are skipped, not fixed.
        df = df[df["cds_DNA"].str.fullmatch(r"[ACGT]+", na=False)]
        if df.empty:
            return

        # Usable length = min(full codons in the CDS, residues in the protein).
        cds_codons = (df["cds_DNA"].str.len() // 3).astype(int)
        prot_len = df["protein_seq"].str.len().astype(int)
        min_len = np.minimum(cds_codons.values, prot_len.values)
        usable = min_len > 0
        if not usable.any():
            return
        df = df[usable]
        min_len = min_len[usable]

        unknown = self.unknown_species_id
        vocab = self.species_vocab

        def map_species(name: str) -> int:
            try:
                return int(vocab.get(name, unknown))
            except (TypeError, ValueError):
                return unknown

        tok = self.tok
        eos_id = tok.special_ids.eos if hasattr(tok, "special_ids") else tok._special_ids.eos
        stem = Path(file).stem
        basename = os.path.basename(file)

        taxa = df["Taxon"].tolist()
        cds_all = df["cds_DNA"].tolist()
        prot_all = df["protein_seq"].tolist()
        ref_all = df["RefseqID"].tolist() if has_refseq else [None] * len(taxa)
        row_ids = df.index.tolist()

        # Plain column lists avoid building a Series per row (as iterrows does).
        for row_idx, taxon, cds_full, prot_full, ref, ml in zip(
            row_ids, taxa, cds_all, prot_all, ref_all, min_len.tolist()
        ):
            ml = int(ml)
            cds = cds_full[: ml * 3]
            prot = prot_full[:ml]
            if (len(cds) // 3) != len(prot):
                continue

            codon_ids = tok.encode_codon_seq(cds, validate=False)
            codon_ids.append(eos_id)

            # Without a RefseqID column, synthesise an id from file stem + row.
            ref_id = ref if has_refseq else f"{stem}:{int(row_idx)}"

            yield {
                "species_name": taxon,
                "species_id": int(map_species(taxon)),
                "protein_seq": prot,
                "aa_len": len(prot),
                "codon_ids": codon_ids,
                "refseq_id": ref_id,
                "protein_refseq_id": ref_id,
                "control_mode": "fixed",
                "meta": {"src": src, "file": basename, "row": int(row_idx)},
            }

    def _enumerate_tasks(self, files: List[str]) -> List[Tuple[str, str, Optional[int], int]]:
        """Build the task list for ``files``.

        Returns tuples of:
          ("parquet", path, row_group_idx, weight)  one per Parquet row group,
              weight = row count from the file metadata (at least 1)
          ("csv", path, None, weight)               one per CSV/TSV file,
              weight = file size // 256 as a rough row estimate (at least 1)

        Raises:
            ImportError: If Parquet files are present but pyarrow is missing.
        """
        tasks: List[Tuple[str, str, Optional[int], int]] = []
        parquet_files = [f for f in files if _is_parquet(f)]
        csv_files = [f for f in files if _is_csv(f)]

        if parquet_files:
            try:
                import pyarrow.parquet as pq
            except Exception as e:
                raise ImportError("pyarrow is required to read parquet files") from e

            for fp in parquet_files:
                pf = pq.ParquetFile(fp)
                nrg = int(pf.num_row_groups or 0)
                if nrg <= 0:
                    # NOTE(review): this still schedules row group 0 for a file
                    # that reports none - confirm such files cannot occur or
                    # that reading them is harmless.
                    total_rows = pf.metadata.num_rows if pf.metadata and pf.metadata.num_rows is not None else 1
                    tasks.append(("parquet", fp, 0, max(1, int(total_rows))))
                    continue
                for rg in range(nrg):
                    num_rows = 0
                    if pf.metadata is not None:
                        rg_meta = pf.metadata.row_group(rg)
                        if rg_meta.num_rows is not None:
                            num_rows = rg_meta.num_rows
                    tasks.append(("parquet", fp, rg, max(1, int(num_rows))))

        for fp in csv_files:
            est_rows = max(1, int(os.path.getsize(fp) // 256))
            tasks.append(("csv", fp, None, est_rows))

        return tasks

    @staticmethod
    def _balanced_partition(tasks: List[Tuple[str, str, Optional[int], int]], groups: int) -> List[List[Tuple[str, str, Optional[int], int]]]:
        """Split ``tasks`` into ``groups`` buckets of roughly equal total weight.

        Greedy: tasks are taken heaviest first (ties in original order) and
        each goes to the currently lightest bucket. Inside a bucket the
        original task order is restored. The result is deterministic, so every
        worker computes the same partition.
        """
        if groups <= 1:
            return [tasks]
        if not tasks:
            return [[] for _ in range(groups)]

        # Heaviest first; equal weights keep their original relative order.
        order = sorted(range(len(tasks)), key=lambda i: (-tasks[i][3], i))

        # Min-heap of (accumulated load, bucket index).
        heap: List[Tuple[int, int]] = [(0, b) for b in range(groups)]
        heapq.heapify(heap)
        members: List[List[int]] = [[] for _ in range(groups)]

        for i in order:
            load, b = heapq.heappop(heap)
            members[b].append(i)
            heapq.heappush(heap, (load + tasks[i][3], b))

        return [[tasks[i] for i in sorted(bucket)] for bucket in members]

    def _parquet_rowgroup_iter(
        self, file: str, row_group_idx: int, cols_cache: Dict[str, List[str]]
    ) -> Iterable[Dict[str, Any]]:
        """Yield samples from one Parquet row group, reading only needed columns.

        ``cols_cache`` maps file path -> column list so the schema is inspected
        once per file within a single iteration.
        """
        import pyarrow.parquet as pq

        pf = pq.ParquetFile(file)
        if file not in cols_cache:
            names = set(pf.schema.names)
            cols: List[str] = []
            for c in self.REQUIRED:
                if c in names:
                    cols.append(c)
                elif c == "Taxon" and "taxon" in names:
                    cols.append("taxon")
                # A column that is really missing is reported by _iter_df.
            if "RefseqID" in names:
                cols.append("RefseqID")
            cols_cache[file] = cols
        cols = cols_cache[file]
        table = pf.read_row_group(row_group_idx, columns=cols)
        df = table.to_pandas(types_mapper=None)
        yield from self._iter_df(df, "parquet", file)

    def _csv_file_iter(self, file: str) -> Iterable[Dict[str, Any]]:
        """Yield samples from a CSV/TSV file, read in ``csv_chunksize`` chunks."""
        # TSV files must be split on tabs; with the default comma separator
        # the whole header becomes one column and the required-column check fails.
        sep = "\t" if file.lower().endswith((".tsv", ".tsv.gz")) else ","
        for chunk in pd.read_csv(file, sep=sep, chunksize=self.csv_chunksize, dtype=str, keep_default_na=False):
            yield from self._iter_df(chunk, "csv", file)

    def _stream_tasks(
        self,
        tasks: List[Tuple[str, str, Optional[int], int]],
        skip_target: int,
        rng: random.Random,
        progress: Optional["_ResumeSkipProgress"] = None,
    ) -> Iterable[Dict[str, Any]]:
        """Stream samples from ``tasks`` after skipping ``skip_target`` samples.

        Skipping happens in two steps: whole leading tasks are dropped while
        their weight fits into the remaining skip budget, then the rest of the
        budget is consumed sample by sample from the first task that is read.
        Each task is read exactly once.

        NOTE(review): the task-level skip trusts task weights. Parquet weights
        are raw row counts and CSV weights are a size-based estimate, while
        rows can be dropped during validation, so the resume position is
        approximate whenever a whole task is skipped.
        """
        skip_remaining = int(skip_target)
        first_task = 0
        for idx, (_kind, _path, _rg, weight) in enumerate(tasks):
            if skip_remaining <= 0:
                break
            w = int(weight) if weight is not None else 0
            if w <= 0:
                continue
            if skip_remaining < w:
                # The cursor lies inside this task: it stays in the work list
                # and the per-sample skip below handles it.
                break
            skip_remaining -= w
            first_task = idx + 1
            if progress is not None:
                progress.update(w)

        if skip_remaining == 0 and progress is not None:
            progress.close()
            progress = None

        skipped = int(skip_target) - skip_remaining
        buffer: List[Dict[str, Any]] = []
        bufN = self.shuffle_buffer
        cols_cache: Dict[str, List[str]] = {}

        try:
            for kind, path, rg, _w in tasks[first_task:]:
                if kind == "parquet":
                    assert rg is not None
                    row_iter = self._parquet_rowgroup_iter(path, int(rg), cols_cache)
                elif kind == "csv":
                    row_iter = self._csv_file_iter(path)
                else:
                    raise ValueError(f"Unknown task kind: {kind}")

                for sample in row_iter:
                    if skip_remaining > 0:
                        skip_remaining -= 1
                        skipped += 1
                        if progress is not None:
                            progress.update(1)
                            if skip_remaining == 0:
                                progress.close()
                                progress = None
                        continue

                    if bufN <= 0:
                        self._emitted += 1
                        yield sample
                        continue

                    # Streaming shuffle: once the buffer is full, emit one
                    # uniformly chosen element per incoming sample.
                    buffer.append(sample)
                    if len(buffer) >= bufN:
                        j = rng.randrange(len(buffer))
                        buffer[j], buffer[-1] = buffer[-1], buffer[j]
                        self._emitted += 1
                        yield buffer.pop()

            # Flush what is left in the shuffle buffer, in random order.
            if buffer:
                rng.shuffle(buffer)
                for leftover in buffer:
                    self._emitted += 1
                    yield leftover
                buffer.clear()
        finally:
            if progress is not None:
                progress.close()
            if skip_target > 0:
                # Keep only the part of the skip that was not consumed.
                self._resume_skip_n = max(skip_target - skipped, 0)

    def __iter__(self):
        """Yield this worker's share of samples, honouring a pending resume skip."""
        wi = get_worker_info()
        num_workers = wi.num_workers if wi else 1
        worker_id = wi.id if wi else 0

        num_global, gid = _dist_info()
        if not self.shard_across_ranks:
            # Shard over this process' dataloader workers only.
            num_global = max(1, num_workers)
            gid = worker_id

        workers_per_rank = max(1, num_workers)
        rank = gid // workers_per_rank if self.shard_across_ranks else 0
        world = max(1, num_global // workers_per_rank)

        # The stored skip is per rank: spread it over this rank's workers,
        # giving the remainder to the lowest worker ids.
        base, rem = divmod(int(self._resume_skip_n), workers_per_rank)
        local_skip_target = base + (1 if worker_id < rem else 0)

        tasks = self._enumerate_tasks(self.files)
        my_tasks: List[Tuple[str, str, Optional[int], int]] = []
        if tasks:
            partitions = self._balanced_partition(tasks, max(1, num_global))
            if gid < len(partitions):
                my_tasks = partitions[gid]

        # Only worker 0 reports progress, to avoid one bar per worker.
        progress = None
        if local_skip_target > 0 and worker_id == 0:
            label = (
                "resume skip" if world == 1 else f"resume skip (rank {rank}/{world})"
            )
            progress = _ResumeSkipProgress(local_skip_target, label)

        rng = random.Random(self.seed + gid)
        yield from self._stream_tasks(my_tasks, local_skip_target, rng, progress)
|
|
|
|
| |
| |
| |
|
|
def stage_collate_fn(batch: List[Dict[str, Any]], pad_id: int = 0) -> Dict[str, Any]:
    """Collate sample dicts into one batch, right-padding the codon ids.

    Args:
        batch: Sample dicts carrying "species_id", "protein_seq" and
            "codon_ids" (missing keys fall back to 0, "M" and [] respectively).
        pad_id: Token id used to pad "codon_ids" up to the longest sequence
            in the batch. Defaults to 0; bind another value with
            ``functools.partial`` when passing this as a ``collate_fn``.

    Returns:
        ``{}`` for an empty batch, otherwise a dict with
        "species_ids" (LongTensor ``[B]``), "protein_seqs" (list of str),
        "codon_ids" (LongTensor ``[B, max_len]``) and "control_mode" (taken
        from the first sample). "refseq_id" / "protein_refseq_id" lists are
        included when the first sample has those keys.
    """
    if not batch:
        return {}
    B = len(batch)

    species_ids = torch.tensor([int(x.get("species_id", 0)) for x in batch], dtype=torch.long)
    protein_seqs = [str(x.get("protein_seq", "M")) for x in batch]

    codon_lists = [x.get("codon_ids", []) for x in batch]
    max_len = max(len(c) for c in codon_lists)
    codon_ids = torch.full((B, max_len), int(pad_id), dtype=torch.long)
    for i, row in enumerate(codon_lists):
        if len(row) > 0:
            codon_ids[i, : len(row)] = torch.as_tensor(row, dtype=torch.long)

    first = batch[0]
    out: Dict[str, Any] = {
        "species_ids": species_ids,
        "protein_seqs": protein_seqs,
        "codon_ids": codon_ids,
        "control_mode": first.get("control_mode", "fixed"),
    }

    # Identifier lists are passed through only when the samples carry them.
    for key in ("refseq_id", "protein_refseq_id"):
        if key in first:
            out[key] = [x.get(key) for x in batch]

    return out
|
|
def _build_dataset(
    path_or_paths: str | List[str],
    tokenizer,
    species_vocab_path: str,
    shuffle_buffer: int,
    csv_chunksize: int,
    shard_across_ranks: bool = True,
) -> StreamSeqDataset:
    """Expand the input paths and wrap them in a ``StreamSeqDataset``.

    The dataset is always created with ``unknown_species_id=0`` and
    ``seed=1234``; every other option is forwarded from the arguments.
    """
    dataset_kwargs = {
        "files": _expand_paths(path_or_paths),
        "tokenizer": tokenizer,
        "species_vocab_path": species_vocab_path,
        "unknown_species_id": 0,
        "csv_chunksize": csv_chunksize,
        "shuffle_buffer": shuffle_buffer,
        "seed": 1234,
        "shard_across_ranks": shard_across_ranks,
    }
    return StreamSeqDataset(**dataset_kwargs)
|
|
|
|
def create_precomputed_dataloaders(
    train_path: str | List[str],
    val_path: Optional[str | List[str]],
    embeddings_dir: str,
    tokenizer,
    batch_size: int,
    num_workers: int = 4,
    species_pooling: str = "sequence",
    csv_chunksize: int = 200_000,
    train_shuffle_buffer: int = 8192,
    val_shuffle_buffer: int = 0,
) -> Tuple[DataLoader, Optional[DataLoader], SpeciesEmbeddingStore]:
    """Build the streaming train/validation loaders and the embedding store.

    Both datasets read the species vocabulary from
    ``<embeddings_dir>/species_vocab.json``. The train loader drops the last
    incomplete batch; the validation loader (only built when ``val_path`` is
    truthy) keeps it. Loaders never shuffle themselves; shuffling comes from
    the datasets' shuffle buffers.

    Returns:
        ``(train_loader, val_loader or None, species_store)``.
    """
    species_store = SpeciesEmbeddingStore(embeddings_dir, pin_memory=True, pooling=species_pooling)
    species_vocab_path = os.path.join(embeddings_dir, "species_vocab.json")
    num_workers = int(max(0, num_workers))

    def _dataset_for(paths, buffer_size):
        # Same settings for both splits; only paths and buffer size differ.
        return _build_dataset(
            path_or_paths=paths,
            tokenizer=tokenizer,
            species_vocab_path=species_vocab_path,
            shuffle_buffer=int(buffer_size),
            csv_chunksize=int(csv_chunksize),
        )

    train_ds = _dataset_for(train_path, train_shuffle_buffer)
    val_ds = _dataset_for(val_path, val_shuffle_buffer) if val_path else None

    loader_kwargs = {
        "num_workers": num_workers,
        "collate_fn": stage_collate_fn,
        "pin_memory": True,
        "persistent_workers": num_workers > 0,
    }
    # prefetch_factor is only passed when worker processes are used.
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 4

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        **loader_kwargs,
    )

    if val_ds is None:
        val_loader = None
    else:
        val_loader = DataLoader(
            val_ds,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
            **loader_kwargs,
        )

    return train_loader, val_loader, species_store
|
|