Tanoj22
Initial commit: ProtLoc-AI project setup and core app
cb6f1ba

src/data/dataset.py — documentation

This module wires precomputed protein embeddings and multilabel localization targets into PyTorch so you can train heads (e.g. linear + BCEWithLogitsLoss) without recomputing ESM embeddings inside the training loop.


Data flow (high level)

  1. embeddings.py writes a model-specific folder (e.g. data/processed/embeddings/esm2_t12_35M/) containing NumPy arrays and JSON metadata.
  2. ProteinLocalizationDataset loads those files once, keeps tensors in memory, and exposes standard Dataset indexing.
  3. create_splits partitions indices into train / val / test while trying to preserve label balance per split.
  4. create_dataloaders wraps splits in DataLoader instances with a custom collate function for batched tensors and string accessions.
  5. compute_class_weights derives per-label pos_weight vectors for imbalanced multilabel classification.

All filesystem paths use pathlib.Path so the same code runs on Windows, macOS, and Linux.


Expected files per embeddings directory

| File | Role |
|---|---|
| embeddings.npy | 2D float array, shape (N, embedding_dim); one row per protein |
| multilabel_targets.npy | 2D float array, shape (N, num_labels); binary (or soft) multilabel targets |
| accessions.npy | 1D object array of accession strings, length N, in the same row order as above |
| label_columns.json | JSON with key "label_columns": a list of label names, length num_labels, in the same column order as multilabel_targets.npy |

On init, the dataset checks that row counts align and that the target matrix width matches label_columns.json.
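The file contract and the two init checks can be exercised without any ESM model. The following is a minimal sketch that writes a toy directory and then validates it. The helper names write_toy_embeddings_dir and check_embeddings_dir are hypothetical, and the label names and accession IDs are made-up toy values, not project data.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def write_toy_embeddings_dir(out_dir: Path, n: int = 8, dim: int = 4,
                             labels=("Cytoplasm", "Nucleus", "Membrane")) -> Path:
    """Hypothetical helper: write the four expected files filled with random toy data."""
    out_dir.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(0)
    np.save(out_dir / "embeddings.npy", rng.standard_normal((n, dim)).astype(np.float32))
    np.save(out_dir / "multilabel_targets.npy",
            (rng.random((n, len(labels))) > 0.5).astype(np.float32))
    # Accessions saved as an object array, which is why loading needs allow_pickle=True.
    np.save(out_dir / "accessions.npy", np.array([f"P{i:05d}" for i in range(n)], dtype=object))
    (out_dir / "label_columns.json").write_text(json.dumps({"label_columns": list(labels)}))
    return out_dir


def check_embeddings_dir(embeddings_dir: Path) -> tuple[int, int, int]:
    """Hypothetical helper: load the files and run the row-count and width checks."""
    d = Path(embeddings_dir).expanduser().resolve()
    emb = np.load(d / "embeddings.npy")
    y = np.load(d / "multilabel_targets.npy")
    acc = np.load(d / "accessions.npy", allow_pickle=True)
    label_columns = json.loads((d / "label_columns.json").read_text())["label_columns"]
    if not (emb.shape[0] == y.shape[0] == acc.shape[0]):
        raise ValueError(f"row mismatch: {emb.shape[0]}, {y.shape[0]}, {acc.shape[0]}")
    if y.shape[1] != len(label_columns):
        raise ValueError(f"{y.shape[1]} target columns but {len(label_columns)} label names")
    return emb.shape[0], emb.shape[1], y.shape[1]


with tempfile.TemporaryDirectory() as tmp:
    n, dim, num_labels = check_embeddings_dir(write_toy_embeddings_dir(Path(tmp) / "toy"))
    print(n, dim, num_labels)  # 8 4 3
```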


ProteinLocalizationDataset

Purpose

Implements torch.utils.data.Dataset. It does not read raw sequences or run transformers; it only serves tensors already produced offline.

__init__(embeddings_dir)

  • Resolves embeddings_dir with expanduser().resolve().
  • Loads the four files above; accessions.npy uses allow_pickle=True because accession strings are often stored as a NumPy object array.
  • Converts embeddings and targets to float32 PyTorch tensors (suitable for mixed-precision training downstream).
  • Stores accessions as the raw NumPy array; __getitem__ converts the indexed element to a Python str.

__getitem__(idx) → (embedding, target, accession)

Returns a single sample:

  • embedding: torch.Tensor of shape (embedding_dim,)
  • target: torch.Tensor of shape (num_labels,)
  • accession: Python str (protein ID)

Properties

| Property | Meaning |
|---|---|
| label_names | Copy of the label-name list from label_columns.json |
| embedding_dim | Second dimension of the embedding matrix |
| num_labels | Second dimension of the target matrix (number of multilabel columns) |
| __len__ | Number of samples N |

The number of labels is data-driven: it is whatever your preprocessing and embeddings pipeline wrote. DeepLoc-style tables, for example, often use 10 or 11 compartment columns.
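The behaviour described for __init__, __getitem__ and the properties can be condensed into a minimal sketch. This is not the project's ProteinLocalizationDataset. The class name MinimalLocalizationDataset and the attribute names (embeddings, targets, accessions) are assumptions. The real class may also differ in its error messages and in details such as whether it copies the arrays.

```python
import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class MinimalLocalizationDataset(Dataset):
    """Hypothetical sketch of the documented behaviour, not the project's class."""

    def __init__(self, embeddings_dir):
        d = Path(embeddings_dir).expanduser().resolve()
        # Load everything once and convert to float32 tensors, so __getitem__ does no I/O.
        self.embeddings = torch.as_tensor(np.load(d / "embeddings.npy"), dtype=torch.float32)
        self.targets = torch.as_tensor(np.load(d / "multilabel_targets.npy"), dtype=torch.float32)
        # Object arrays of strings require allow_pickle=True.
        self.accessions = np.load(d / "accessions.npy", allow_pickle=True)
        self._label_names = json.loads((d / "label_columns.json").read_text())["label_columns"]
        n = self.embeddings.shape[0]
        if self.targets.shape[0] != n or len(self.accessions) != n:
            raise ValueError("embeddings, targets and accessions have different row counts")
        if self.targets.shape[1] != len(self._label_names):
            raise ValueError("target width does not match label_columns.json")

    def __len__(self):
        return self.embeddings.shape[0]

    def __getitem__(self, idx):
        # Accession elements may be NumPy scalars; normalize to a plain Python str.
        return self.embeddings[idx], self.targets[idx], str(self.accessions[idx])

    @property
    def label_names(self):
        return list(self._label_names)  # copy, so callers cannot mutate the stored list

    @property
    def embedding_dim(self):
        return self.embeddings.shape[1]

    @property
    def num_labels(self):
        return self.targets.shape[1]
```

Usage with a directory written by embeddings.py: ds = MinimalLocalizationDataset("data/processed/embeddings/esm2_t12_35M"), then emb, tgt, acc = ds[0].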


_collate_localization_batch

Rather than relying on PyTorch's default collate behaviour for mixed tensor/string samples, this helper makes the batch layout explicit:

  • Stacks embeddings and targets into tensors of shape (batch, embedding_dim) and (batch, num_labels), respectively.
  • Collects accessions into a Python list of strings (one per row in the batch).

So each DataLoader batch is (emb_batch, tgt_batch, acc_list).
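A collate function with this behaviour fits in a few lines. The sketch below is a plausible equivalent of _collate_localization_batch, not a copy of it. The public name collate_localization_batch and the toy samples are hypothetical.

```python
import torch


def collate_localization_batch(batch):
    """Turn a list of (embedding, target, accession) samples into one batch."""
    embeddings, targets, accessions = zip(*batch)
    emb_batch = torch.stack(embeddings)        # (batch, embedding_dim)
    tgt_batch = torch.stack(targets)           # (batch, num_labels)
    acc_list = [str(a) for a in accessions]    # plain Python list, one str per row
    return emb_batch, tgt_batch, acc_list


samples = [
    (torch.randn(4), torch.tensor([1.0, 0.0, 1.0]), "P00001"),
    (torch.randn(4), torch.tensor([0.0, 1.0, 0.0]), "P00002"),
]
emb, tgt, acc = collate_localization_batch(samples)
print(emb.shape, tgt.shape, acc)  # torch.Size([2, 4]) torch.Size([2, 3]) ['P00001', 'P00002']
```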


Train / validation / test splitting

create_splits(dataset, train_ratio, val_ratio, test_ratio, random_seed)

  • Requires train_ratio + val_ratio + test_ratio ≈ 1.0.
  • Reads the full multilabel matrix y from the dataset (all N rows).

Preferred path — iterstrat:

  • Uses importlib.import_module("iterstrat.ml_stratifiers") and MultilabelStratifiedShuffleSplit (avoids static import resolution issues in some IDEs).
  • Two-stage split:
    1. Split off train vs temp where temp has fraction val_ratio + test_ratio of the data, stratified on y.
    2. Split temp into val and test with a second stratified split; the inner test_size is test_ratio / (val_ratio + test_ratio) so val and test each get the intended fraction of the full dataset (up to rounding).
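A short check of the inner test_size confirms that the two stages land on the intended fractions. The notation r_val and r_test for the ratios and N for the number of proteins is introduced here:

```latex
\begin{aligned}
|\text{temp}| &\approx (r_{\text{val}} + r_{\text{test}})\,N \\
\text{inner test\_size} &= \frac{r_{\text{test}}}{r_{\text{val}} + r_{\text{test}}} \\
|\text{test}| &\approx (r_{\text{val}} + r_{\text{test}})\,N \cdot \frac{r_{\text{test}}}{r_{\text{val}} + r_{\text{test}}} = r_{\text{test}}\,N \\
|\text{val}| &\approx (r_{\text{val}} + r_{\text{test}})\,N - r_{\text{test}}\,N = r_{\text{val}}\,N
\end{aligned}
```

For example, with illustrative ratios 0.70 / 0.15 / 0.15, temp holds 30% of the rows and the inner test_size is 0.15 / 0.30 = 0.5.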

Fallback — random:

  • If iterstrat is not installed (ImportError), emits a UserWarning and uses _split_indices_random: a single permutation of 0..N-1, then contiguous slices for train / val / test lengths from the rounded ratios. This preserves approximate sizes but does not stratify labels.
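The import-or-fallback pattern and the random fallback can be sketched together. The function names load_multilabel_splitter and split_indices_random are hypothetical stand-ins for the module's private helpers. Several details are assumptions: the 1e-6 tolerance on the ratio sum, the default seed of 42, and giving the test split the remainder after rounding.

```python
import importlib
import warnings

import numpy as np


def load_multilabel_splitter():
    """Return iterstrat's MultilabelStratifiedShuffleSplit class, or None if it is missing."""
    try:
        module = importlib.import_module("iterstrat.ml_stratifiers")
    except ImportError:
        warnings.warn("iterstrat is not installed; using an unstratified random split", UserWarning)
        return None
    return module.MultilabelStratifiedShuffleSplit


def split_indices_random(n, train_ratio, val_ratio, test_ratio, seed=42):
    """Unstratified fallback: one permutation of 0..n-1, then contiguous slices."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:  # tolerance is an assumption
        raise ValueError("train_ratio + val_ratio + test_ratio must be ~1.0")
    perm = np.random.default_rng(seed).permutation(n)
    n_train = int(round(n * train_ratio))
    n_val = int(round(n * val_ratio))
    # Assumption: test takes the remainder, so rounding never drops or duplicates rows.
    train_idx = perm[:n_train]
    val_idx = perm[n_train:n_train + n_val]
    test_idx = perm[n_train + n_val:]
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = split_indices_random(101, 0.7, 0.15, 0.15)
print(len(train_idx), len(val_idx), len(test_idx))  # 71 15 15
```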

After splitting, _print_split_label_distribution prints per-split label statistics: for each label, the number of positive samples and their percentage of that split.
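The per-split statistics reduce to a column sum over the split's rows. This sketch computes them with NumPy. The function name label_distribution is hypothetical, and treating a target above 0.5 as positive is an assumption for soft labels. The toy matrix and label names are made up.

```python
import numpy as np


def label_distribution(y, indices, label_names):
    """For one split: per-label positive count and percentage of that split."""
    sub = np.asarray(y)[np.asarray(indices)]
    positives = (sub > 0.5).sum(axis=0)  # threshold for soft targets is an assumption
    pct = 100.0 * positives / max(len(indices), 1)
    return {name: (int(c), float(p)) for name, c, p in zip(label_names, positives, pct)}


y = np.array([[1, 0], [1, 1], [0, 1], [0, 0]], dtype=np.float32)
stats = label_distribution(y, np.array([0, 1, 3]), ["Nucleus", "Cytoplasm"])
for name, (count, pct) in stats.items():
    print(f"{name:<10} {count:>4} ({pct:5.1f}%)")
```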

Return value: three torch.utils.data.Subset objects sharing the same underlying ProteinLocalizationDataset, indexed by the chosen row indices.
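Subset indexing is what keeps the splits aligned with the base dataset: position i in a Subset reads row indices[i] of the shared dataset. The toy TensorDataset below is hypothetical and stands in for ProteinLocalizationDataset. Its single feature is the row number itself, so the mapping is visible.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical base dataset: the feature of row r is just the value r.
base = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
train_split = Subset(base, [7, 2, 5])

# Position i in the Subset reads row indices[i] of the shared base dataset.
print([train_split[i][0].item() for i in range(len(train_split))])  # [7.0, 2.0, 5.0]
print(train_split.dataset is base)  # True
```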


create_dataloaders(train, val, test, batch_size, num_workers)

Builds a dictionary:

{"train": DataLoader, "val": DataLoader, "test": DataLoader}
  • Train: shuffle=True.
  • Val / test: shuffle=False.
  • All use the custom collate_fn described above.

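num_workers defaults to 0, so batches are loaded in the main process. This is the simplest setting for debugging, especially on Windows, where DataLoader worker processes are spawned rather than forked. Increase it for faster loading if needed.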
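A sketch of the loader dictionary follows. The collate function is repeated so the block runs on its own. The default batch_size of 32 and the ToyDataset class (a hypothetical stand-in returning (embedding, target, accession) triples) are assumptions, not the module's code.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def collate_localization_batch(batch):
    embeddings, targets, accessions = zip(*batch)
    return torch.stack(embeddings), torch.stack(targets), [str(a) for a in accessions]


def create_dataloaders(train, val, test, batch_size=32, num_workers=0):
    """Shuffle only the training split; every split uses the same collate_fn."""
    common = dict(batch_size=batch_size, num_workers=num_workers,
                  collate_fn=collate_localization_batch)
    return {
        "train": DataLoader(train, shuffle=True, **common),
        "val": DataLoader(val, shuffle=False, **common),
        "test": DataLoader(test, shuffle=False, **common),
    }


class ToyDataset(Dataset):
    """Hypothetical stand-in yielding (embedding, target, accession) triples."""

    def __init__(self, n, dim=4, num_labels=3):
        g = torch.Generator().manual_seed(0)
        self.x = torch.randn(n, dim, generator=g)
        self.y = (torch.rand(n, num_labels, generator=g) > 0.5).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx], f"P{idx:05d}"


loaders = create_dataloaders(ToyDataset(10), ToyDataset(4), ToyDataset(4), batch_size=4)
emb, tgt, acc = next(iter(loaders["val"]))  # val is not shuffled, so rows 0..3
print(emb.shape, tgt.shape, acc)
```

Shuffling only the training split keeps validation and test metrics reproducible across epochs.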


compute_class_weights(train_dataset)

Used for torch.nn.BCEWithLogitsLoss(..., pos_weight=...):

  • Accepts either a full ProteinLocalizationDataset or a Subset (e.g. training split only).
  • For each label k:
    \( \text{pos\_weight}_k = \dfrac{\#\text{negatives}_k}{\#\text{positives}_k} \)
    with a small floor on the positive count to avoid division by zero.
  • Returns a torch.float32 tensor of shape (num_labels,).
  • Prints each weight and marks extreme values (> 50 or < 0.1) for inspection.

Label names for printing come from the base dataset’s label_names property.
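The weight formula and the Subset handling can be sketched with NumPy alone. Several names and choices here are assumptions: the targets attribute on the base dataset, the helper names targets_of and pos_weights, and the floor min_pos=1.0 (the text above only says "a small floor"). SimpleNamespace stands in for a real dataset and for torch.utils.data.Subset, which exposes the same .dataset and .indices attributes.

```python
from types import SimpleNamespace

import numpy as np


def targets_of(dataset):
    """Target rows covered by a dataset or a (possibly nested) Subset-like object."""
    if hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        return np.asarray(targets_of(dataset.dataset))[np.asarray(dataset.indices)]
    return np.asarray(dataset.targets)  # hypothetical attribute name on the base dataset


def pos_weights(y, min_pos=1.0):
    """pos_weight_k = #negatives_k / max(#positives_k, min_pos); soft labels count fractionally."""
    y = np.asarray(y, dtype=np.float64)
    positives = y.sum(axis=0)
    negatives = y.shape[0] - positives
    return (negatives / np.maximum(positives, min_pos)).astype(np.float32)


y = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0]], dtype=np.float32)
base = SimpleNamespace(targets=y)
train_split = SimpleNamespace(dataset=base, indices=[0, 2])  # stands in for a torch Subset

w_full = pos_weights(targets_of(base))
w_sub = pos_weights(targets_of(train_split))
for name, w in zip(["Nucleus", "Cytoplasm", "Membrane"], w_full):
    flag = "  <-- check" if w > 50 or w < 0.1 else ""  # same extreme-value thresholds as above
    print(f"{name:<10} {w:8.3f}{flag}")
```

To use the result in training, convert it with torch.from_numpy(w_sub) and pass it as pos_weight to torch.nn.BCEWithLogitsLoss. Computing the weights from the training split only keeps validation and test labels out of the loss.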


Dependencies

  • Required: numpy, torch.
  • Optional (recommended): iterstrat (published on PyPI as iterative-stratification), used for multilabel stratified splits. It is listed in requirements.txt; install it under a Python version the package supports, since very new Python releases may not have wheels yet.

Related scripts

  • scripts/test_dataset.py — smoke test: load a folder, run splits, one train batch, class weights.

Design notes

  • No I/O in __getitem__: all arrays are loaded in __init__, so training loops are not disk-bound per epoch (at the cost of RAM for large N).
  • Subset-aware helpers: splitting only changes which indices are used; embedding rows and targets stay aligned by construction.
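A back-of-the-envelope estimate gives a sense of the RAM cost in the first design note. The helper in_memory_bytes is hypothetical. It counts only the float32 embedding and target tensors (4 bytes per value) and ignores accession strings and any temporary NumPy copies. The 20,000-protein size and 10 labels are illustrative; 480 is the hidden size of esm2_t12_35M.

```python
def in_memory_bytes(n, embedding_dim, num_labels, bytes_per_value=4):
    """Approximate bytes held by the embedding and target tensors (accessions excluded)."""
    return n * (embedding_dim + num_labels) * bytes_per_value


# Illustrative size: 20,000 proteins, 480-dim ESM-2 35M embeddings, 10 labels.
print(in_memory_bytes(20_000, 480, 10) / 1e6, "MB")  # 39.2 MB
```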