Tanoj22
# `src/data/dataset.py` — documentation
This module wires **precomputed protein embeddings** and **multilabel localization targets** into PyTorch, so you can train classification heads (e.g. a linear layer with `BCEWithLogitsLoss`) without recomputing ESM embeddings inside the training loop.

---
## Data flow (high level)
1. **`embeddings.py`** writes a model-specific folder (e.g. `data/processed/embeddings/esm2_t12_35M/`) containing NumPy arrays and JSON metadata.
2. **`ProteinLocalizationDataset`** loads those files once, keeps tensors in memory, and exposes standard `Dataset` indexing.
3. **`create_splits`** partitions indices into train / val / test while trying to preserve label balance per split.
4. **`create_dataloaders`** wraps splits in `DataLoader` instances with a custom collate function for batched tensors and string accessions.
5. **`compute_class_weights`** derives per-label `pos_weight` vectors for imbalanced multilabel classification.

All filesystem paths use **`pathlib.Path`**, so the same code runs on Windows, macOS, and Linux.

---
## Expected files per embeddings directory
| File | Role |
|------|------|
| `embeddings.npy` | 2D float array, shape `(N, embedding_dim)` — one row per protein |
| `multilabel_targets.npy` | 2D float array, shape `(N, num_labels)` — binary (or soft) multilabel targets |
| `accessions.npy` | 1D object array of accession strings, length `N`, same row order as above |
| `label_columns.json` | JSON with key `"label_columns"`: list of label names, length `num_labels`, column order matching `multilabel_targets.npy` |

On initialization, the dataset checks that all three arrays have the same number of rows `N` and that the width of the target matrix equals the number of names in `label_columns.json`.

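
The following is a minimal sketch of the file contract. It writes a toy embeddings folder containing the four expected files, then runs the same two consistency checks. The helpers `write_toy_embeddings_dir` and `validate_embeddings_dir`, the toy label names, and the accession format are hypothetical and are not part of `src/data/dataset.py`. The real loader may word its errors differently.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def write_toy_embeddings_dir(root: Path, n: int = 4, dim: int = 8,
                             labels=("Nucleus", "Cytoplasm", "Mitochondrion")) -> Path:
    """Hypothetical helper: write the four expected files with random toy data."""
    root.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(0)
    np.save(root / "embeddings.npy", rng.standard_normal((n, dim)).astype(np.float32))
    np.save(root / "multilabel_targets.npy",
            rng.integers(0, 2, size=(n, len(labels))).astype(np.float32))
    # Accessions as an object array, which is why loading needs allow_pickle=True.
    np.save(root / "accessions.npy",
            np.array([f"P{i:05d}" for i in range(n)], dtype=object), allow_pickle=True)
    (root / "label_columns.json").write_text(json.dumps({"label_columns": list(labels)}))
    return root


def validate_embeddings_dir(embeddings_dir) -> tuple:
    """Hypothetical helper: load the files and apply the row/width checks."""
    d = Path(embeddings_dir).expanduser().resolve()
    x = np.load(d / "embeddings.npy")
    y = np.load(d / "multilabel_targets.npy")
    acc = np.load(d / "accessions.npy", allow_pickle=True)
    labels = json.loads((d / "label_columns.json").read_text())["label_columns"]
    if not (x.shape[0] == y.shape[0] == len(acc)):
        raise ValueError(f"row mismatch: {x.shape[0]}, {y.shape[0]}, {len(acc)}")
    if y.shape[1] != len(labels):
        raise ValueError(f"{y.shape[1]} target columns but {len(labels)} label names")
    return x.shape[0], x.shape[1], len(labels)  # (N, embedding_dim, num_labels)


with tempfile.TemporaryDirectory() as tmp:
    toy_dir = write_toy_embeddings_dir(Path(tmp) / "esm2_toy")
    print(validate_embeddings_dir(toy_dir))  # (4, 8, 3)
```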
---
## `ProteinLocalizationDataset`
### Purpose
Implements `torch.utils.data.Dataset`. It does **not** read raw sequences or run transformers; it only serves tensors already produced offline.
### `__init__(embeddings_dir)`
- Resolves `embeddings_dir` with `expanduser().resolve()`.
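- Loads the four files above. `accessions.npy` is loaded with `allow_pickle=True` because accession strings are often stored as a NumPy object array, and `np.load` only reads object arrays when pickling is allowed.
- Converts embeddings and targets to **`float32`** PyTorch tensors. This is the dtype PyTorch uses for model parameters by default; mixed-precision training can still run the forward pass in lower precision via `torch.autocast`.
- Stores accessions as the raw NumPy array. `__getitem__` converts the selected element to a Python `str`.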
### `__getitem__(idx)` → `(embedding, target, accession)`
Returns a single sample:
- **embedding**: `torch.Tensor` of shape `(embedding_dim,)`
- **target**: `torch.Tensor` of shape `(num_labels,)`
- **accession**: Python `str` (protein ID)
### Properties
| Property | Meaning |
|----------|--------|
| `label_names` | Copy of the label names list from `label_columns.json` |
| `embedding_dim` | Second dimension of the embedding matrix |
| `num_labels` | Second dimension of the target matrix (number of multilabel columns) |
| `__len__()` | Number of samples `N` (a method rather than a property; call it via `len(dataset)`) |

The number of labels is **data-driven**: it is whatever your preprocessing and embedding pipeline wrote. For example, DeepLoc-style tables often use **10 or 11** compartment columns.

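
The sample format and properties above can be illustrated with a small in-memory `Dataset`. This is a hypothetical stand-in (`InMemoryLocalizationDataset`) that takes arrays directly instead of an embeddings directory. It is not the module's class; it only demonstrates the same `(embedding, target, accession)` contract, assuming `numpy` and `torch` are installed.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InMemoryLocalizationDataset(Dataset):
    """Hypothetical sketch with the same sample format as ProteinLocalizationDataset."""

    def __init__(self, embeddings, targets, accessions, label_names):
        if not (len(embeddings) == len(targets) == len(accessions)):
            raise ValueError("row counts must match")
        if np.asarray(targets).shape[1] != len(label_names):
            raise ValueError("target width must match label_names")
        # Everything is converted once and kept in memory as float32 tensors.
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        self.targets = torch.as_tensor(targets, dtype=torch.float32)
        self.accessions = accessions
        self._label_names = list(label_names)

    def __len__(self) -> int:
        return self.embeddings.shape[0]

    def __getitem__(self, idx: int):
        # No disk I/O: just index the preloaded tensors.
        return self.embeddings[idx], self.targets[idx], str(self.accessions[idx])

    @property
    def label_names(self) -> list:
        return list(self._label_names)  # return a copy so callers cannot mutate it

    @property
    def embedding_dim(self) -> int:
        return self.embeddings.shape[1]

    @property
    def num_labels(self) -> int:
        return self.targets.shape[1]


ds = InMemoryLocalizationDataset(
    np.zeros((3, 5)),
    np.array([[1, 0], [0, 1], [1, 1]]),
    np.array(["P1", "P2", "P3"], dtype=object),
    ["Nucleus", "Cytoplasm"],
)
emb, tgt, acc = ds[1]
print(emb.shape, tgt.tolist(), acc)  # torch.Size([5]) [0.0, 1.0] P2
```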
---
## `_collate_localization_batch`
Each sample mixes tensors with a Python string, so the module uses an explicit collate function rather than depending on PyTorch's default collate rules. This helper:
- **Stacks** embeddings into shape `(batch, embedding_dim)` and targets into shape `(batch, num_labels)`.
- Collects accessions into a **Python list of strings** (one per row in the batch).

Each `DataLoader` batch is therefore `(emb_batch, tgt_batch, acc_list)`.

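
A collate function with this behavior might look like the sketch below. The name `collate_localization_sketch` is hypothetical; the module's own helper is `_collate_localization_batch`. The toy samples and accessions are made up. A plain Python list serves as the dataset here because a map-style `DataLoader` only needs `__getitem__` and `__len__`.

```python
import torch
from torch.utils.data import DataLoader


def collate_localization_sketch(batch):
    """Hypothetical collate: batch is a list of (embedding, target, accession) tuples."""
    embeddings, targets, accessions = zip(*batch)
    emb_batch = torch.stack(embeddings)        # (batch, embedding_dim)
    tgt_batch = torch.stack(targets)           # (batch, num_labels)
    acc_list = [str(a) for a in accessions]    # plain list of accession strings
    return emb_batch, tgt_batch, acc_list


samples = [
    (torch.randn(4), torch.tensor([1.0, 0.0, 1.0]), "P12345"),
    (torch.randn(4), torch.tensor([0.0, 1.0, 0.0]), "Q67890"),
]
loader = DataLoader(samples, batch_size=2, shuffle=False,
                    collate_fn=collate_localization_sketch)
emb, tgt, acc = next(iter(loader))
print(emb.shape, tgt.shape, acc)  # torch.Size([2, 4]) torch.Size([2, 3]) ['P12345', 'Q67890']
```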
---
## Train / validation / test splitting
### `create_splits(dataset, train_ratio, val_ratio, test_ratio, random_seed)`
- Requires `train_ratio + val_ratio + test_ratio ≈ 1.0`.
- Reads the full multilabel matrix `y` from the dataset (all `N` rows).
**Preferred path — `iterstrat`:**
- Loads `MultilabelStratifiedShuffleSplit` via `importlib.import_module("iterstrat.ml_stratifiers")`. Importing by name at runtime avoids unresolved-import warnings from static analysis in some IDEs.
- **Two-stage split:**
1. Split off **train** vs **temp** where `temp` has fraction `val_ratio + test_ratio` of the data, stratified on `y`.
2. Split **temp** into **val** and **test** with a second stratified split; the inner `test_size` is `test_ratio / (val_ratio + test_ratio)` so val and test each get the intended fraction of the full dataset (up to rounding).
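
The ratio arithmetic of the two-stage split can be checked by hand. The sketch below reproduces only that arithmetic, for a hypothetical `N = 1000` with 70/15/15 ratios. `two_stage_split_plan` is an illustrative helper, and `iterstrat` may round the per-split counts slightly differently.

```python
import math


def two_stage_split_plan(train_ratio, val_ratio, test_ratio):
    """Hypothetical helper: (stage-1 held-out fraction, stage-2 test_size)."""
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, abs_tol=1e-6):
        raise ValueError("ratios must sum to 1.0")
    temp_fraction = val_ratio + test_ratio        # stage 1: train vs temp
    inner_test_size = test_ratio / temp_fraction  # stage 2: val vs test inside temp
    return temp_fraction, inner_test_size


temp, inner = two_stage_split_plan(0.7, 0.15, 0.15)
n = 1000
n_temp = round(n * temp)        # rows that leave the training set
n_test = round(n_temp * inner)  # share of temp that becomes test
n_val = n_temp - n_test         # the rest of temp becomes val
print(n - n_temp, n_val, n_test)  # 700 150 150
```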
**Fallback — random:**
- If `iterstrat` is not installed (`ImportError`), emits a **`UserWarning`** and uses **`_split_indices_random`**: a single permutation of `0..N-1`, followed by contiguous train / val / test slices whose lengths come from the rounded ratios. This preserves approximate split sizes but **does not** stratify labels.
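
The optional-import pattern behind this fallback needs only the standard library. `import_optional` is a hypothetical helper, not the module's code. It relies on `ModuleNotFoundError` being a subclass of `ImportError`, so a missing `iterstrat` package is caught by the same `except` clause.

```python
import importlib
import warnings


def import_optional(module_name):
    """Hypothetical helper: return the module, or warn and return None if missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError:  # also catches ModuleNotFoundError
        warnings.warn(
            f"{module_name} is not installed; falling back to a random (non-stratified) split.",
            UserWarning,
            stacklevel=2,
        )
        return None


ml_stratifiers = import_optional("iterstrat.ml_stratifiers")
print("stratified" if ml_stratifiers is not None else "random fallback")
```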

After splitting, the function prints **per-split label statistics** via `_print_split_label_distribution`. For each label, it reports the number of positive samples in the split and their percentage of that split.

**Return value:** three `torch.utils.data.Subset` objects that share the same underlying `ProteinLocalizationDataset`, each indexed by its chosen row indices.

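
The random fallback and the per-split statistics can be sketched with NumPy. Both helpers (`split_indices_random` and `label_distribution`) and the toy targets are hypothetical. The sketch also assumes the permutation is seeded from a caller-supplied seed. That mirrors the `random_seed` argument of `create_splits`, but the source does not confirm how the fallback uses it.

```python
import numpy as np


def split_indices_random(n, train_ratio, val_ratio, seed=42):
    """Hypothetical: one random permutation, then contiguous train/val/test slices."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


def label_distribution(y, idx, label_names):
    """Hypothetical: per-label positive count and percentage within one split."""
    positives = y[idx].sum(axis=0)
    denom = max(len(idx), 1)  # guard against an empty split
    return {name: (int(p), float(100.0 * p / denom)) for name, p in zip(label_names, positives)}


y = np.array([[1, 0], [0, 1], [1, 1], [0, 0], [1, 0]] * 20, dtype=np.float32)  # 100 toy rows
train_idx, val_idx, test_idx = split_indices_random(len(y), 0.8, 0.1)
print(len(train_idx), len(val_idx), len(test_idx))  # 80 10 10
for name, (count, pct) in label_distribution(y, train_idx, ["Nucleus", "Cytoplasm"]).items():
    print(f"{name}: {count} positives ({pct:.1f}%)")
```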
---
## `create_dataloaders(train, val, test, batch_size, num_workers)`
Builds a dictionary:
```text
{"train": DataLoader, "val": DataLoader, "test": DataLoader}
```
- **Train**: `shuffle=True`.
- **Val / test**: `shuffle=False`.
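- All three pass **`_collate_localization_batch`** as `collate_fn`.

`num_workers` defaults to `0`, so batches are assembled in the main process. This is the easiest setting to debug and avoids multiprocessing start-up issues on Windows, where worker processes are spawned. Increase it if data loading becomes a bottleneck, although with every tensor already in memory the gain is usually small.
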
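
A `create_dataloaders`-style helper takes only a few lines. `make_dataloaders` is hypothetical. Unlike the module, it takes `collate_fn` as a parameter (defaulting to PyTorch's own collate) so that it can be demonstrated on a plain `TensorDataset` split into `Subset`s.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


def make_dataloaders(train, val, test, batch_size=32, num_workers=0, collate_fn=None):
    """Hypothetical sketch: build the {"train", "val", "test"} DataLoader dictionary."""
    splits = {"train": train, "val": val, "test": test}
    return {
        name: DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=(name == "train"),  # only the training loader reshuffles each epoch
            num_workers=num_workers,
            collate_fn=collate_fn,      # None selects PyTorch's default collate
        )
        for name, subset in splits.items()
    }


base = TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10, 3)).float())
loaders = make_dataloaders(Subset(base, range(0, 6)), Subset(base, range(6, 8)),
                           Subset(base, range(8, 10)), batch_size=4)
x, y = next(iter(loaders["train"]))
print(x.shape, y.shape)  # torch.Size([4, 4]) torch.Size([4, 3])
```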
---
## `compute_class_weights(train_dataset)`
Used for **`torch.nn.BCEWithLogitsLoss(..., pos_weight=...)`**:
- Accepts either a full **`ProteinLocalizationDataset`** or a **`Subset`** (e.g. training split only).
- For each label \(k\):
\(\text{pos\_weight}_k = \frac{\#\text{negatives}}{\#\text{positives}}\)
with a small floor on positives to avoid division by zero.
- Returns a **`torch.float32`** tensor of shape `(num_labels,)`.
- Prints each weight and marks **extreme** values (`> 50` or `< 0.1`) for inspection.

Label names for printing come from the base dataset's `label_names` property (for a `Subset`, the dataset it wraps).

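
The weight formula can be computed in NumPy on a toy target matrix. `pos_weight_from_targets` is a hypothetical helper, and `min_positives=1.0` is an assumed floor; the module's actual floor value is not stated here. To use the result, convert it with `torch.from_numpy(...)` and pass it as `pos_weight` to `torch.nn.BCEWithLogitsLoss`.

```python
import numpy as np


def pos_weight_from_targets(y, min_positives=1.0, label_names=None):
    """Hypothetical: pos_weight_k = #negatives_k / max(#positives_k, min_positives)."""
    y = np.asarray(y, dtype=np.float64)
    positives = y.sum(axis=0)
    negatives = y.shape[0] - positives
    weights = negatives / np.maximum(positives, min_positives)  # floor avoids division by zero
    if label_names is not None:
        for name, w in zip(label_names, weights):
            flag = "  <-- extreme" if w > 50 or w < 0.1 else ""
            print(f"{name}: {w:.3f}{flag}")
    return weights.astype(np.float32)


y_train = np.array([[1, 0, 0], [1, 0, 0], [1, 1, 0], [0, 0, 0]])
w = pos_weight_from_targets(y_train, label_names=["Nucleus", "Cytoplasm", "Golgi"])
# Nucleus: 0.333, Cytoplasm: 3.000, Golgi: 4.000 (no positives, so the floor of 1 applies)
```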
---
## Dependencies
- **Required:** `numpy`, `torch`.
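- **Optional (recommended):** `iterstrat` for multilabel stratified splits. It is listed in `requirements.txt`. The PyPI distribution is named `iterative-stratification` and depends on scikit-learn, so install it in a Python version for which those packages publish wheels (very new Python releases may lag behind).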
---
## Related scripts
- **`scripts/test_dataset.py`** — smoke test: load a folder, run splits, one train batch, class weights.
---
## Design notes
- **No I/O in `__getitem__`:** all arrays are loaded in `__init__`, so training loops are not disk-bound per epoch (at the cost of RAM for large `N`).
- **Subset-aware helpers:** splitting only changes which indices are used; embedding rows and targets stay aligned by construction.
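
The RAM cost of keeping everything in memory is easy to estimate. Each sample holds `embedding_dim + num_labels` float32 values at 4 bytes each, plus its accession string. The sketch below uses hypothetical sizes (100,000 proteins, 10 labels) and an embedding width of 480, the hidden size of ESM-2 35M. It ignores accession strings and Python object overhead, so treat the result as a lower bound.

```python
def in_memory_bytes(n, embedding_dim, num_labels, bytes_per_value=4):
    """Approximate bytes for the float32 embedding and target tensors (accessions excluded)."""
    return n * (embedding_dim + num_labels) * bytes_per_value


# Hypothetical sizes: 100,000 proteins, 480-dim embeddings, 10 labels.
size = in_memory_bytes(100_000, 480, 10)
print(size, f"{size / 2**20:.1f} MiB")  # 196000000 186.9 MiB
```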