""" PyTorch Dataset and helpers for protein localization (multilabel) training. Paths are resolved with pathlib.Path for cross-platform use. """ from __future__ import annotations import importlib import json import warnings from pathlib import Path from typing import Any, Dict, List, Tuple, Union import numpy as np import torch from torch.utils.data import DataLoader, Dataset, Subset from src.data.residue_dataset import ResidueDataset def _collate_localization_batch( batch: List[Tuple[torch.Tensor, torch.Tensor, str]], ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: embeddings = torch.stack([b[0] for b in batch], dim=0) targets = torch.stack([b[1] for b in batch], dim=0) accessions = [b[2] for b in batch] return embeddings, targets, accessions class ProteinLocalizationDataset(Dataset): """ Loads precomputed embeddings and multilabel targets from a model-specific folder. Properties ``label_names``, ``num_labels``, and ``embedding_dim`` reflect the files on disk (e.g. DeepLoc commonly uses 10 or 11 compartment columns). """ def __init__(self, embeddings_dir: Union[str, Path]) -> None: self.embeddings_dir = Path(embeddings_dir).expanduser().resolve() emb_path = self.embeddings_dir / "embeddings.npy" acc_path = self.embeddings_dir / "accessions.npy" tgt_path = self.embeddings_dir / "multilabel_targets.npy" labels_path = self.embeddings_dir / "label_columns.json" for p in (emb_path, acc_path, tgt_path, labels_path): if not p.is_file(): raise FileNotFoundError(f"Missing required file: {p}") emb = np.load(emb_path) acc = np.load(acc_path, allow_pickle=True) tgt = np.load(tgt_path) with labels_path.open(encoding="utf-8") as f: label_meta: dict[str, Any] = json.load(f) self._label_names: List[str] = list(label_meta["label_columns"]) self._embeddings = torch.from_numpy(np.asarray(emb, dtype=np.float32)) self._targets = torch.from_numpy(np.asarray(tgt, dtype=np.float32)) n = self._embeddings.shape[0] if self._targets.shape[0] != n or acc.shape[0] != n: raise ValueError( f"Length mismatch: embeddings {n}, targets {self._targets.shape[0]}, " f"accessions {acc.shape[0]}" ) if self._targets.shape[1] != len(self._label_names): raise ValueError( f"Target width {self._targets.shape[1]} != len(label_columns) " f"{len(self._label_names)}" ) self._accessions = acc @property def label_names(self) -> List[str]: return list(self._label_names) @property def embedding_dim(self) -> int: return int(self._embeddings.shape[1]) @property def num_labels(self) -> int: return int(self._targets.shape[1]) def __len__(self) -> int: return int(self._embeddings.shape[0]) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]: emb = self._embeddings[idx] tgt = self._targets[idx] acc = self._accessions[idx] if isinstance(acc, np.ndarray): acc = acc.item() if acc.ndim == 0 else str(acc) return emb, tgt, str(acc) def _targets_for_indices( dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int] ) -> np.ndarray: idx_t = torch.as_tensor(indices, dtype=torch.long) return dataset._targets[idx_t].numpy() def _print_split_label_distribution( name: str, dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int] ) -> None: y = _targets_for_indices(dataset, indices) n = len(indices) pos = y.sum(axis=0) pos_pct = 100.0 * pos / max(n, 1) print(f"\n{name} split (n={n}) — positives per label (count, % of split):") for i, label in enumerate(dataset.label_names): print(f" {label}: {int(pos[i])} ({pos_pct[i]:.2f}%)") def _split_indices_iterstrat( y: np.ndarray, train_ratio: float, val_ratio: float, test_ratio: float, random_seed: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: mss_mod = importlib.import_module("iterstrat.ml_stratifiers") MultilabelStratifiedShuffleSplit = getattr(mss_mod, "MultilabelStratifiedShuffleSplit") n = y.shape[0] X = np.zeros((n, 1), dtype=np.float32) outer_test_size = val_ratio + test_ratio mss1 = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=outer_test_size, random_state=random_seed, ) train_idx, temp_idx = next(mss1.split(X, y)) y_temp = y[temp_idx] X_temp = np.zeros((len(temp_idx), 1), dtype=np.float32) inner_test_frac = test_ratio / outer_test_size mss2 = MultilabelStratifiedShuffleSplit( n_splits=1, test_size=inner_test_frac, random_state=random_seed + 1, ) rel_val_idx, rel_test_idx = next(mss2.split(X_temp, y_temp)) val_idx = temp_idx[rel_val_idx] test_idx = temp_idx[rel_test_idx] return train_idx, val_idx, test_idx def _split_indices_random( y: np.ndarray, train_ratio: float, val_ratio: float, test_ratio: float, random_seed: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: n = y.shape[0] rng = np.random.default_rng(random_seed) perm = rng.permutation(n) n_train = int(round(train_ratio * n)) n_val = int(round(val_ratio * n)) train_idx = perm[:n_train] val_idx = perm[n_train : n_train + n_val] test_idx = perm[n_train + n_val :] return train_idx, val_idx, test_idx def create_splits( dataset: Union[ProteinLocalizationDataset, ResidueDataset], train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15, random_seed: int = 42, ) -> Tuple[Subset, Subset, Subset]: if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6: raise ValueError( f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}" ) y = dataset._targets.numpy() train_idx: np.ndarray val_idx: np.ndarray test_idx: np.ndarray try: train_idx, val_idx, test_idx = _split_indices_iterstrat( y, train_ratio, val_ratio, test_ratio, random_seed ) print("Splitting with iterstrat (MultilabelStratifiedShuffleSplit).") except ImportError: warnings.warn( "iterstrat is not installed; using a random split (not multilabel-stratified). " "scikit-learn does not ship MultilabelStratifiedShuffleSplit — " "install iterstrat for IterativeStratification-style splits.", UserWarning, stacklevel=2, ) train_idx, val_idx, test_idx = _split_indices_random( y, train_ratio, val_ratio, test_ratio, random_seed ) train_idx_list = train_idx.tolist() val_idx_list = val_idx.tolist() test_idx_list = test_idx.tolist() train_dataset = Subset(dataset, train_idx_list) val_dataset = Subset(dataset, val_idx_list) test_dataset = Subset(dataset, test_idx_list) _print_split_label_distribution("Train", dataset, train_idx_list) _print_split_label_distribution("Val", dataset, val_idx_list) _print_split_label_distribution("Test", dataset, test_idx_list) return train_dataset, val_dataset, test_dataset def create_dataloaders( train_dataset: Dataset, val_dataset: Dataset, test_dataset: Dataset, batch_size: int = 64, num_workers: int = 0, ) -> Dict[str, DataLoader]: return { "train": DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=_collate_localization_batch, ), "val": DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=_collate_localization_batch, ), "test": DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=_collate_localization_batch, ), } def compute_class_weights( train_dataset: Union[ProteinLocalizationDataset, ResidueDataset, Subset], ) -> torch.Tensor: """ Per-class pos_weight for BCEWithLogitsLoss: num_negatives / num_positives per label. Shape (num_labels,) matching the training subset. """ if isinstance(train_dataset, Subset): base = train_dataset.dataset indices = list(train_dataset.indices) y = _targets_for_indices(base, indices) # type: ignore[arg-type] else: y = train_dataset._targets.numpy() num_pos = y.sum(axis=0).astype(np.float64) num_neg = y.shape[0] - num_pos pos_weight = num_neg / np.maximum(num_pos, 1e-8) w = torch.tensor(pos_weight, dtype=torch.float32) print("BCE pos_weight (neg/pos) per label:") if isinstance(train_dataset, Subset): names = getattr(train_dataset.dataset, "label_names", []) else: names = train_dataset.label_names for i, name in enumerate(names): val = float(w[i].item()) flag = "" if val > 50.0 or val < 0.1: flag = " ** extreme **" print(f" {name}: {val:.6f}{flag}") return w