# `src/data/dataset.py`: documentation

This module wires **precomputed protein embeddings** and **multilabel localization targets** into PyTorch so you can train heads (e.g. linear + `BCEWithLogitsLoss`) without recomputing ESM embeddings inside the training loop.

---

## Data flow (high level)

1. **`embeddings.py`** writes a model-specific folder (e.g. `data/processed/embeddings/esm2_t12_35M/`) containing NumPy arrays and JSON metadata.
2. **`ProteinLocalizationDataset`** loads those files once, keeps tensors in memory, and exposes standard `Dataset` indexing.
3. **`create_splits`** partitions indices into train / val / test while trying to preserve label balance per split.
4. **`create_dataloaders`** wraps splits in `DataLoader` instances with a custom collate function for batched tensors and string accessions.
5. **`compute_class_weights`** derives per-label `pos_weight` vectors for imbalanced multilabel classification.

All filesystem paths use **`pathlib.Path`** so the same code runs on Windows, macOS, and Linux.

---

## Expected files per embeddings directory

| File | Role |
|------|------|
| `embeddings.npy` | 2D float array, shape `(N, embedding_dim)`; one row per protein |
| `multilabel_targets.npy` | 2D float array, shape `(N, num_labels)`; binary (or soft) multilabel targets |
| `accessions.npy` | 1D object array of accession strings, length `N`, same row order as above |
| `label_columns.json` | JSON with key `"label_columns"`: list of label names, length `num_labels`, column order matching `multilabel_targets.npy` |

On init, the dataset checks that row counts align and that the target matrix width matches `label_columns.json`.

---

## `ProteinLocalizationDataset`

### Purpose

Implements `torch.utils.data.Dataset`. It does **not** read raw sequences or run transformers; it only serves tensors already produced offline.

### `__init__(embeddings_dir)`

- Resolves `embeddings_dir` with `expanduser().resolve()`.
- Loads the four files above; `accessions.npy` uses `allow_pickle=True` because accession strings are often stored as a NumPy object array.
- Converts embeddings and targets to **`float32`** PyTorch tensors (suitable for mixed-precision training downstream).
- Stores accessions as the raw NumPy array; `__getitem__` normalizes a scalar to a Python `str`.

### `__getitem__(idx)` → `(embedding, target, accession)`

Returns a single sample:

- **embedding**: `torch.Tensor` of shape `(embedding_dim,)`
- **target**: `torch.Tensor` of shape `(num_labels,)`
- **accession**: Python `str` (protein ID)

### Properties

| Property | Meaning |
|----------|--------|
| `label_names` | Copy of the label names list from `label_columns.json` |
| `embedding_dim` | Second dimension of the embedding matrix |
| `num_labels` | Second dimension of the target matrix (number of multilabel columns) |
| `__len__` | Number of samples `N` |

The number of labels is **data-driven** (whatever your preprocessing and embeddings pipeline wrote), e.g. DeepLoc-style tables often use **10 or 11** compartment columns.

---

## `_collate_localization_batch`

PyTorch’s default collate does not batch arbitrary Python strings cleanly. This helper:

- **Stacks** embeddings and targets into shape `(batch, …)`.
- Collects accessions into a **Python list of strings** (one per row in the batch).

So each `DataLoader` batch is `(emb_batch, tgt_batch, acc_list)`.

---

## Train / validation / test splitting

### `create_splits(dataset, train_ratio, val_ratio, test_ratio, random_seed)`

- Requires `train_ratio + val_ratio + test_ratio ≈ 1.0`.
- Reads the full multilabel matrix `y` from the dataset (all `N` rows).

**Preferred path (`iterstrat`):**

- Uses `importlib.import_module("iterstrat.ml_stratifiers")` and `MultilabelStratifiedShuffleSplit` (avoids static import resolution issues in some IDEs).
- **Two-stage split:**
  1. Split off **train** vs **temp** where `temp` has fraction `val_ratio + test_ratio` of the data, stratified on `y`.
  2. Split **temp** into **val** and **test** with a second stratified split; the inner `test_size` is `test_ratio / (val_ratio + test_ratio)` so val and test each get the intended fraction of the full dataset (up to rounding).

**Fallback (random):**

- If `iterstrat` is not installed (`ImportError`), emits a **`UserWarning`** and uses **`_split_indices_random`**: a single permutation of `0..N-1`, then contiguous slices for train / val / test lengths from the rounded ratios. This preserves approximate sizes but **does not** stratify labels.

After splitting, the function prints **per-split label statistics** (positive count and percentage of that split per label) via `_print_split_label_distribution`.

**Return value:** three `torch.utils.data.Subset` objects sharing the same underlying `ProteinLocalizationDataset`, indexed by the chosen row indices.

---

## `create_dataloaders(train, val, test, batch_size, num_workers)`

Builds a dictionary:

```text
{"train": DataLoader, "val": DataLoader, "test": DataLoader}
```

- **Train**: `shuffle=True`.
- **Val / test**: `shuffle=False`.
- All use the custom **`collate_fn`** described above.

`num_workers=0` is the default (simple debugging on Windows; increase for faster loading if needed).

---

## `compute_class_weights(train_dataset)`

Used for **`torch.nn.BCEWithLogitsLoss(..., pos_weight=...)`**:

- Accepts either a full **`ProteinLocalizationDataset`** or a **`Subset`** (e.g. training split only).
- For each label \(k\): \(\text{pos\_weight}_k = \frac{\#\text{negatives}}{\#\text{positives}}\) with a small floor on positives to avoid division by zero.
- Returns a **`torch.float32`** tensor of shape `(num_labels,)`.
- Prints each weight and marks **extreme** values (`> 50` or `< 0.1`) for inspection.

Label names for printing come from the base dataset’s `label_names` property.

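
The `pos_weight` rule above can be sketched in plain NumPy. This is an illustration, not the module's implementation: the names `pos_weight_sketch` and `flag_extreme` are hypothetical, the floor value `min_positives=1.0` is an assumption (the document only says "a small floor"), and the targets are assumed to be strictly binary.

```python
import numpy as np


def pos_weight_sketch(targets, min_positives=1.0):
    """Per-label negatives / positives for a binary (N, num_labels) matrix.

    `min_positives` is an assumed floor; the real module's value is not
    stated in this document.
    """
    targets = np.asarray(targets, dtype=np.float32)
    positives = targets.sum(axis=0)            # positives per label
    negatives = targets.shape[0] - positives   # negatives per label
    safe_positives = np.maximum(positives, min_positives)
    return (negatives / safe_positives).astype(np.float32)


def flag_extreme(weights, high=50.0, low=0.1):
    """Boolean mask of weights above `high` or below `low`."""
    weights = np.asarray(weights)
    return (weights > high) | (weights < low)


# Toy training targets: 4 proteins, 3 labels.
y = np.array(
    [
        [1, 0, 0],
        [1, 0, 0],
        [1, 1, 0],
        [0, 0, 0],
    ],
    dtype=np.float32,
)

weights = pos_weight_sketch(y)
# Label 0: 1 negative / 3 positives = 1/3
# Label 1: 3 negatives / 1 positive = 3
# Label 2: 4 negatives / floor of 1 = 4 (no positives, so the floor applies)
print(weights)
print(flag_extreme(weights))
```

A vector like this could then be handed to the loss as `torch.nn.BCEWithLogitsLoss(pos_weight=torch.from_numpy(weights))`. Computing it on the training split only keeps validation and test label frequencies out of the loss weighting.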

---

## Dependencies

- **Required:** `numpy`, `torch`.
- **Optional (recommended):** `iterstrat`, for multilabel stratified splits (the import name `iterstrat` is provided by the PyPI package `iterative-stratification`). Listed in `requirements.txt`; install in a Python version the package supports (some very new Python releases may not have wheels yet).

---

## Related scripts

- **`scripts/test_dataset.py`**: smoke test that loads a folder, runs splits, fetches one train batch, and computes class weights.

---

## Design notes

- **No I/O in `__getitem__`:** all arrays are loaded in `__init__`, so training loops are not disk-bound per epoch (at the cost of RAM for large `N`).
- **Subset-aware helpers:** splitting only changes which indices are used; embedding rows and targets stay aligned by construction.
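
The index-only design can be illustrated with a small NumPy sketch of a random split in the spirit of the fallback path: one permutation, then contiguous slices. The function name `split_indices_random_sketch` is hypothetical, and the rounding rule (test takes the remainder) is an assumption, since this document does not spell out how `_split_indices_random` rounds.

```python
import numpy as np


def split_indices_random_sketch(n, train_ratio, val_ratio, seed):
    """Permute 0..n-1 once, then cut contiguous train / val / test slices."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = int(round(n * train_ratio))
    n_val = int(round(n * val_ratio))
    # Assumption: test takes the remainder, so every row lands in exactly one split.
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


# Toy stand-ins for the in-memory arrays: row i of each belongs to protein i.
n = 20
embeddings = np.repeat(np.arange(n, dtype=np.float32)[:, None], 4, axis=1)
targets = (np.arange(n) % 2).astype(np.float32)[:, None]

train_idx, val_idx, test_idx = split_indices_random_sketch(n, 0.7, 0.15, seed=0)

# Indexing both arrays with the same indices keeps each embedding row
# paired with its own target row; the arrays themselves are never reordered.
train_emb, train_tgt = embeddings[train_idx], targets[train_idx]
print(len(train_idx), len(val_idx), len(test_idx))
```

In PyTorch the same idea is expressed with `torch.utils.data.Subset(dataset, indices)`, which stores only the index list and defers to the shared base dataset for the actual rows.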