protloc-ai / src /data /dataset.py
Tanoj22
Force add src/models and src/data code files
fe5a903
"""
PyTorch Dataset and helpers for protein localization (multilabel) training.
Paths are resolved with pathlib.Path for cross-platform use.
"""
from __future__ import annotations
import importlib
import json
import warnings
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from src.data.residue_dataset import ResidueDataset
def _collate_localization_batch(
    batch: List[Tuple[torch.Tensor, torch.Tensor, str]],
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    """
    Combine per-sample ``(embedding, target, accession)`` triples into one batch.

    Embeddings and targets are each stacked along a new leading dimension
    (so every sample's tensors must share a shape, as ``torch.stack`` requires).
    Accessions are gathered into a plain list in batch order.
    """

    def column(position: int) -> list:
        # Pull the same field out of every sample, preserving batch order.
        return [sample[position] for sample in batch]

    stacked_embeddings = torch.stack(column(0), dim=0)
    stacked_targets = torch.stack(column(1), dim=0)
    return stacked_embeddings, stacked_targets, column(2)
class ProteinLocalizationDataset(Dataset):
    """
    Loads precomputed embeddings and multilabel targets from a model-specific folder.

    Required files in ``embeddings_dir``:

    - ``embeddings.npy``: array with one row per protein (cast to float32).
    - ``accessions.npy``: accession identifiers, one per row (loaded with
      ``allow_pickle=True`` so object arrays of Python strings are accepted).
    - ``multilabel_targets.npy``: 2-D array, one row per protein and one
      column per label (cast to float32).
    - ``label_columns.json``: JSON object whose ``"label_columns"`` entry
      lists one name per target column.

    Properties ``label_names``, ``num_labels``, and ``embedding_dim`` reflect the
    files on disk (e.g. DeepLoc commonly uses 10 or 11 compartment columns).
    """

    def __init__(self, embeddings_dir: Union[str, Path]) -> None:
        """
        Args:
            embeddings_dir: Folder holding the four required files.

        Raises:
            FileNotFoundError: If a required file is missing.
            KeyError: If ``label_columns.json`` has no ``"label_columns"`` key.
            ValueError: If an array has an unusable rank, the arrays disagree
                on the number of rows, or the target width differs from the
                number of label names.
        """
        self.embeddings_dir = Path(embeddings_dir).expanduser().resolve()
        emb_path = self.embeddings_dir / "embeddings.npy"
        acc_path = self.embeddings_dir / "accessions.npy"
        tgt_path = self.embeddings_dir / "multilabel_targets.npy"
        labels_path = self.embeddings_dir / "label_columns.json"
        for p in (emb_path, acc_path, tgt_path, labels_path):
            if not p.is_file():
                raise FileNotFoundError(f"Missing required file: {p}")
        emb = np.load(emb_path)
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load trusted directories.
        acc = np.load(acc_path, allow_pickle=True)
        tgt = np.load(tgt_path)
        with labels_path.open(encoding="utf-8") as f:
            label_meta: dict[str, Any] = json.load(f)
        self._label_names: List[str] = list(label_meta["label_columns"])

        # Validate ranks before anything indexes shape[0] / shape[1]; otherwise a
        # malformed file surfaces as an opaque IndexError here or later in the
        # embedding_dim / num_labels properties.
        if emb.ndim < 2:
            raise ValueError(
                f"embeddings must have shape (n, dim), got shape {emb.shape}"
            )
        if tgt.ndim != 2:
            raise ValueError(
                f"multilabel targets must have shape (n, num_labels), got shape {tgt.shape}"
            )
        if acc.ndim == 0:
            raise ValueError(
                "accessions must be an array with one entry per row, got a scalar"
            )

        self._embeddings = torch.from_numpy(np.asarray(emb, dtype=np.float32))
        self._targets = torch.from_numpy(np.asarray(tgt, dtype=np.float32))
        n = self._embeddings.shape[0]
        if self._targets.shape[0] != n or acc.shape[0] != n:
            raise ValueError(
                f"Length mismatch: embeddings {n}, targets {self._targets.shape[0]}, "
                f"accessions {acc.shape[0]}"
            )
        if self._targets.shape[1] != len(self._label_names):
            raise ValueError(
                f"Target width {self._targets.shape[1]} != len(label_columns) "
                f"{len(self._label_names)}"
            )
        self._accessions = acc

    @property
    def label_names(self) -> List[str]:
        """Label names in target-column order (a fresh copy on each access)."""
        return list(self._label_names)

    @property
    def embedding_dim(self) -> int:
        """Size of the second axis of the embedding array."""
        return int(self._embeddings.shape[1])

    @property
    def num_labels(self) -> int:
        """Number of target columns."""
        return int(self._targets.shape[1])

    def __len__(self) -> int:
        return int(self._embeddings.shape[0])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return ``(embedding, target, accession)`` for row ``idx``; the accession is a ``str``."""
        emb = self._embeddings[idx]
        tgt = self._targets[idx]
        acc = self._accessions[idx]
        if isinstance(acc, np.ndarray):
            acc = acc.item() if acc.ndim == 0 else str(acc)
        # Fixed-width byte arrays (dtype "S") yield numpy.bytes_ (a bytes
        # subclass); str() on those would give "b'P12345'", so decode instead.
        if isinstance(acc, bytes):
            acc = acc.decode("utf-8", errors="replace")
        return emb, tgt, str(acc)
def _targets_for_indices(
    dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int]
) -> np.ndarray:
    """
    Gather the target rows named by ``indices`` from ``dataset._targets``.

    The indices are turned into an int64 tensor and used for advanced
    indexing on the first axis; the selected rows come back as a NumPy array.
    """
    row_selector = torch.as_tensor(indices, dtype=torch.long)
    selected_rows = dataset._targets[row_selector]
    return selected_rows.numpy()
def _print_split_label_distribution(
    name: str, dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int]
) -> None:
    """
    Print how many rows of one split are positive for each label.

    Args:
        name: Split label shown in the header line (e.g. ``"Train"``).
        dataset: Base dataset; its targets are looked up by ``indices`` and
            its ``label_names`` name the columns.
        indices: Row indices (into ``dataset``) that make up the split.

    Percentages are relative to the split size; an empty split divides by 1
    instead of 0.
    """
    split_targets = _targets_for_indices(dataset, indices)
    split_size = len(indices)
    positives = split_targets.sum(axis=0)
    percentages = 100.0 * positives / max(split_size, 1)
    print(f"\n{name} split (n={split_size}) — positives per label (count, % of split):")
    for col, label in enumerate(dataset.label_names):
        count = int(positives[col])
        print(f" {label}: {count} ({percentages[col]:.2f}%)")
def _split_indices_iterstrat(
    y: np.ndarray,
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    random_seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Split row indices into train/val/test with multilabel stratification.

    Uses ``iterstrat.ml_stratifiers.MultilabelStratifiedShuffleSplit`` in two
    stages. The first stage holds out ``val_ratio + test_ratio`` of the rows.
    The second stage divides that held-out part between val and test in
    proportion ``val_ratio : test_ratio``.

    A stage whose target size is zero (e.g. ``test_ratio == 0`` or
    ``train_ratio == 0``) is resolved directly instead of calling the
    splitter. Scikit-learn-style shuffle splitters reject a float
    ``test_size`` of 0.0 or 1.0.

    Args:
        y: Multilabel target matrix, one row per sample.
        train_ratio, val_ratio, test_ratio: Non-negative split fractions.
        random_seed: Seed for the first stage; the second uses ``random_seed + 1``.

    Returns:
        ``(train_idx, val_idx, test_idx)`` integer index arrays into ``y``.

    Raises:
        ImportError: If iterstrat is not installed (callers fall back on this).
    """
    # Import first so a missing iterstrat always surfaces as ImportError,
    # regardless of the requested ratios.
    mss_mod = importlib.import_module("iterstrat.ml_stratifiers")
    MultilabelStratifiedShuffleSplit = getattr(mss_mod, "MultilabelStratifiedShuffleSplit")
    n = y.shape[0]
    all_idx = np.arange(n)
    empty_idx = np.empty(0, dtype=all_idx.dtype)
    outer_test_size = val_ratio + test_ratio

    # Stage 1: train vs. held-out (val + test).
    if outer_test_size <= 0:
        return all_idx, empty_idx, empty_idx.copy()
    if train_ratio <= 0:
        train_idx, temp_idx = empty_idx, all_idx
    else:
        X = np.zeros((n, 1), dtype=np.float32)  # features are unused by the splitter
        mss1 = MultilabelStratifiedShuffleSplit(
            n_splits=1,
            test_size=outer_test_size,
            random_state=random_seed,
        )
        train_idx, temp_idx = next(mss1.split(X, y))

    # Stage 2: divide the held-out rows between val and test.
    if test_ratio <= 0:
        return train_idx, temp_idx, empty_idx
    if val_ratio <= 0:
        return train_idx, empty_idx, temp_idx
    y_temp = y[temp_idx]
    X_temp = np.zeros((len(temp_idx), 1), dtype=np.float32)
    inner_test_frac = test_ratio / outer_test_size
    mss2 = MultilabelStratifiedShuffleSplit(
        n_splits=1,
        test_size=inner_test_frac,
        random_state=random_seed + 1,
    )
    rel_val_idx, rel_test_idx = next(mss2.split(X_temp, y_temp))
    # Map positions within the held-out subset back to original row indices.
    val_idx = temp_idx[rel_val_idx]
    test_idx = temp_idx[rel_test_idx]
    return train_idx, val_idx, test_idx
def _split_indices_random(
    y: np.ndarray,
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    random_seed: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Randomly partition row indices into train/val/test (no stratification).

    All rows are shuffled with ``np.random.default_rng(random_seed)``. Train
    and val sizes are ``round(ratio * n)``; test receives whatever remains, so
    ``test_ratio`` itself is not consulted.
    """
    total = y.shape[0]
    shuffled = np.random.default_rng(random_seed).permutation(total)
    train_end = int(round(train_ratio * total))
    val_end = train_end + int(round(val_ratio * total))
    # np.split with explicit cut points is equivalent to the three slices
    # [:train_end], [train_end:val_end], [val_end:] (out-of-range cuts yield empty parts).
    train_part, val_part, test_part = np.split(shuffled, [train_end, val_end])
    return train_part, val_part, test_part
def create_splits(
    dataset: Union[ProteinLocalizationDataset, ResidueDataset],
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    random_seed: int = 42,
) -> Tuple[Subset, Subset, Subset]:
    """
    Split ``dataset`` into train/val/test ``Subset`` objects.

    A multilabel-stratified split via iterstrat is attempted first. If
    iterstrat cannot be imported, a ``UserWarning`` is emitted and a seeded
    random permutation split is used instead. The per-label positive counts
    of each split are printed.

    Args:
        dataset: Dataset exposing a ``_targets`` tensor and ``label_names``.
        train_ratio, val_ratio, test_ratio: Non-negative fractions summing to 1.0.
        random_seed: Seed forwarded to the splitting helper.

    Returns:
        ``(train, val, test)`` Subsets of ``dataset``.

    Raises:
        ValueError: If the ratios do not sum to 1.0 or any ratio is negative.
    """
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError(
            f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}"
        )
    # Negative ratios can still sum to 1.0 (e.g. 1.2 / -0.1 / -0.1) but would
    # yield nonsensical split sizes, so reject them explicitly.
    negative = [
        f"{ratio_name}={value}"
        for ratio_name, value in (
            ("train_ratio", train_ratio),
            ("val_ratio", val_ratio),
            ("test_ratio", test_ratio),
        )
        if value < 0
    ]
    if negative:
        raise ValueError(f"Ratios must be non-negative, got {', '.join(negative)}")

    y = dataset._targets.numpy()
    train_idx: np.ndarray
    val_idx: np.ndarray
    test_idx: np.ndarray
    try:
        train_idx, val_idx, test_idx = _split_indices_iterstrat(
            y, train_ratio, val_ratio, test_ratio, random_seed
        )
        print("Splitting with iterstrat (MultilabelStratifiedShuffleSplit).")
    except ImportError:
        warnings.warn(
            "iterstrat is not installed; using a random split (not multilabel-stratified). "
            "scikit-learn does not ship MultilabelStratifiedShuffleSplit — "
            "install iterstrat for IterativeStratification-style splits.",
            UserWarning,
            stacklevel=2,
        )
        train_idx, val_idx, test_idx = _split_indices_random(
            y, train_ratio, val_ratio, test_ratio, random_seed
        )
    # Subset and the diagnostics take plain Python int lists.
    train_idx_list = train_idx.tolist()
    val_idx_list = val_idx.tolist()
    test_idx_list = test_idx.tolist()
    train_dataset = Subset(dataset, train_idx_list)
    val_dataset = Subset(dataset, val_idx_list)
    test_dataset = Subset(dataset, test_idx_list)
    _print_split_label_distribution("Train", dataset, train_idx_list)
    _print_split_label_distribution("Val", dataset, val_idx_list)
    _print_split_label_distribution("Test", dataset, test_idx_list)
    return train_dataset, val_dataset, test_dataset
def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    test_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 0,
) -> Dict[str, DataLoader]:
    """
    Wrap the three splits in DataLoaders that share batching settings.

    All loaders use the same ``batch_size``, ``num_workers`` and the
    localization collate function. Only the training loader shuffles, so
    validation and test batches keep subset order.

    Returns:
        Dict with keys ``"train"``, ``"val"`` and ``"test"`` (in that order).
    """
    split_specs = (
        ("train", train_dataset, True),
        ("val", val_dataset, False),
        ("test", test_dataset, False),
    )
    return {
        split_name: DataLoader(
            split_data,
            batch_size=batch_size,
            shuffle=should_shuffle,
            num_workers=num_workers,
            collate_fn=_collate_localization_batch,
        )
        for split_name, split_data, should_shuffle in split_specs
    }
def compute_class_weights(
    train_dataset: Union[ProteinLocalizationDataset, ResidueDataset, Subset],
) -> torch.Tensor:
    """
    Per-class pos_weight for BCEWithLogitsLoss: num_negatives / num_positives per label.
    Shape (num_labels,) matching the training subset.

    Args:
        train_dataset: A dataset exposing ``_targets`` and ``label_names``, or
            a ``torch.utils.data.Subset`` of one. Nested Subsets (a Subset of
            a Subset, e.g. after further sub-sampling a split) are unwrapped by
            composing their index lists down to the base dataset.

    Returns:
        Float32 tensor of per-label weights. The positive count is floored at
        1e-8, so a label with no positives avoids dividing by zero.

    Also prints each label's weight, flagging values above 50 or below 0.1.
    """
    # Peel Subset wrappers from the outside in, mapping the outer indices
    # through each inner Subset's indices so they address the base dataset.
    base = train_dataset
    indices = None
    while isinstance(base, Subset):
        layer = list(base.indices)
        indices = layer if indices is None else [layer[i] for i in indices]
        base = base.dataset

    if indices is None:
        y = base._targets.numpy()
    else:
        y = _targets_for_indices(base, indices)  # type: ignore[arg-type]

    num_pos = y.sum(axis=0).astype(np.float64)
    num_neg = y.shape[0] - num_pos
    pos_weight = num_neg / np.maximum(num_pos, 1e-8)
    w = torch.tensor(pos_weight, dtype=torch.float32)

    print("BCE pos_weight (neg/pos) per label:")
    if isinstance(train_dataset, Subset):
        names = getattr(base, "label_names", [])
    else:
        names = base.label_names
    for i, name in enumerate(names):
        val = float(w[i].item())
        flag = " ** extreme **" if val > 50.0 or val < 0.1 else ""
        print(f" {name}: {val:.6f}{flag}")
    return w