Spaces:
Running
Running
| """ | |
| PyTorch Dataset and helpers for protein localization (multilabel) training. | |
| Paths are resolved with pathlib.Path for cross-platform use. | |
| """ | |
| from __future__ import annotations | |
| import importlib | |
| import json | |
| import warnings | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset, Subset | |
| from src.data.residue_dataset import ResidueDataset | |
| def _collate_localization_batch( | |
| batch: List[Tuple[torch.Tensor, torch.Tensor, str]], | |
| ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: | |
| embeddings = torch.stack([b[0] for b in batch], dim=0) | |
| targets = torch.stack([b[1] for b in batch], dim=0) | |
| accessions = [b[2] for b in batch] | |
| return embeddings, targets, accessions | |
class ProteinLocalizationDataset(Dataset):
    """
    Loads precomputed embeddings and multilabel targets from a model-specific folder.

    The folder must contain ``embeddings.npy``, ``accessions.npy``,
    ``multilabel_targets.npy`` and ``label_columns.json`` (a JSON object with a
    ``"label_columns"`` list). Embeddings and targets are converted to float32
    tensors; each item is ``(embedding, target, accession)``.

    Properties ``label_names``, ``num_labels``, and ``embedding_dim`` reflect the
    files on disk (e.g. DeepLoc commonly uses 10 or 11 compartment columns).
    """

    def __init__(self, embeddings_dir: Union[str, Path]) -> None:
        """
        Load and validate all arrays from ``embeddings_dir``.

        Raises:
            FileNotFoundError: if any of the four required files is missing.
            ValueError: if the targets are not 2-D, if the row counts of
                embeddings, targets and accessions differ, or if the number of
                target columns does not match ``label_columns``.
        """
        self.embeddings_dir = Path(embeddings_dir).expanduser().resolve()
        emb_path = self.embeddings_dir / "embeddings.npy"
        acc_path = self.embeddings_dir / "accessions.npy"
        tgt_path = self.embeddings_dir / "multilabel_targets.npy"
        labels_path = self.embeddings_dir / "label_columns.json"
        for p in (emb_path, acc_path, tgt_path, labels_path):
            if not p.is_file():
                raise FileNotFoundError(f"Missing required file: {p}")

        emb = np.load(emb_path)
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code; presumably needed because accessions
        # may be stored as an object array. Only load trusted files.
        acc = np.load(acc_path, allow_pickle=True)
        tgt = np.load(tgt_path)
        with labels_path.open(encoding="utf-8") as f:
            label_meta: dict[str, Any] = json.load(f)
        self._label_names: List[str] = list(label_meta["label_columns"])

        self._embeddings = torch.from_numpy(np.asarray(emb, dtype=np.float32))
        self._targets = torch.from_numpy(np.asarray(tgt, dtype=np.float32))

        # Multilabel targets must be (num_samples, num_labels); checking ndim
        # up front replaces a cryptic IndexError on shape[1] with a clear error.
        if self._targets.ndim != 2:
            raise ValueError(
                f"Expected 2-D multilabel targets, got shape {tuple(self._targets.shape)}"
            )
        n = self._embeddings.shape[0]
        if self._targets.shape[0] != n or acc.shape[0] != n:
            raise ValueError(
                f"Length mismatch: embeddings {n}, targets {self._targets.shape[0]}, "
                f"accessions {acc.shape[0]}"
            )
        if self._targets.shape[1] != len(self._label_names):
            raise ValueError(
                f"Target width {self._targets.shape[1]} != len(label_columns) "
                f"{len(self._label_names)}"
            )
        self._accessions = acc

    # These are exposed as properties: callers iterate ``dataset.label_names``
    # and read ``num_labels`` / ``embedding_dim`` as attributes.
    @property
    def label_names(self) -> List[str]:
        """Label column names, as a fresh list (mutating it does not affect the dataset)."""
        return list(self._label_names)

    @property
    def embedding_dim(self) -> int:
        """Size of the second axis of the embeddings array."""
        return int(self._embeddings.shape[1])

    @property
    def num_labels(self) -> int:
        """Number of target columns (equals ``len(label_names)``)."""
        return int(self._targets.shape[1])

    def __len__(self) -> int:
        return int(self._embeddings.shape[0])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, str]:
        """Return ``(embedding, target, accession)`` for row ``idx``."""
        emb = self._embeddings[idx]
        tgt = self._targets[idx]
        acc = self._accessions[idx]
        # Unwrap 0-d arrays to their scalar; any other ndarray is stringified.
        if isinstance(acc, np.ndarray):
            acc = acc.item() if acc.ndim == 0 else str(acc)
        return emb, tgt, str(acc)
| def _targets_for_indices( | |
| dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int] | |
| ) -> np.ndarray: | |
| idx_t = torch.as_tensor(indices, dtype=torch.long) | |
| return dataset._targets[idx_t].numpy() | |
def _print_split_label_distribution(
    name: str, dataset: Union[ProteinLocalizationDataset, ResidueDataset], indices: List[int]
) -> None:
    """
    Print how many rows of one split are positive for each label.

    Percentages are relative to the split size; the denominator is clamped to 1
    so an empty split reports 0% instead of dividing by zero.
    """
    split_targets = _targets_for_indices(dataset, indices)
    split_size = len(indices)
    positives = split_targets.sum(axis=0)
    percents = 100.0 * positives / max(split_size, 1)
    print(f"\n{name} split (n={split_size}) — positives per label (count, % of split):")
    for col, label in enumerate(dataset.label_names):
        print(f" {label}: {int(positives[col])} ({percents[col]:.2f}%)")
| def _split_indices_iterstrat( | |
| y: np.ndarray, | |
| train_ratio: float, | |
| val_ratio: float, | |
| test_ratio: float, | |
| random_seed: int, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| mss_mod = importlib.import_module("iterstrat.ml_stratifiers") | |
| MultilabelStratifiedShuffleSplit = getattr(mss_mod, "MultilabelStratifiedShuffleSplit") | |
| n = y.shape[0] | |
| X = np.zeros((n, 1), dtype=np.float32) | |
| outer_test_size = val_ratio + test_ratio | |
| mss1 = MultilabelStratifiedShuffleSplit( | |
| n_splits=1, | |
| test_size=outer_test_size, | |
| random_state=random_seed, | |
| ) | |
| train_idx, temp_idx = next(mss1.split(X, y)) | |
| y_temp = y[temp_idx] | |
| X_temp = np.zeros((len(temp_idx), 1), dtype=np.float32) | |
| inner_test_frac = test_ratio / outer_test_size | |
| mss2 = MultilabelStratifiedShuffleSplit( | |
| n_splits=1, | |
| test_size=inner_test_frac, | |
| random_state=random_seed + 1, | |
| ) | |
| rel_val_idx, rel_test_idx = next(mss2.split(X_temp, y_temp)) | |
| val_idx = temp_idx[rel_val_idx] | |
| test_idx = temp_idx[rel_test_idx] | |
| return train_idx, val_idx, test_idx | |
| def _split_indices_random( | |
| y: np.ndarray, | |
| train_ratio: float, | |
| val_ratio: float, | |
| test_ratio: float, | |
| random_seed: int, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| n = y.shape[0] | |
| rng = np.random.default_rng(random_seed) | |
| perm = rng.permutation(n) | |
| n_train = int(round(train_ratio * n)) | |
| n_val = int(round(val_ratio * n)) | |
| train_idx = perm[:n_train] | |
| val_idx = perm[n_train : n_train + n_val] | |
| test_idx = perm[n_train + n_val :] | |
| return train_idx, val_idx, test_idx | |
def create_splits(
    dataset: Union[ProteinLocalizationDataset, ResidueDataset],
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    random_seed: int = 42,
) -> Tuple[Subset, Subset, Subset]:
    """
    Split ``dataset`` into train/val/test ``Subset`` objects.

    Tries a multilabel-stratified split via ``iterstrat``. If that package cannot
    be imported, emits a ``UserWarning`` and falls back to a seeded random
    permutation. The per-label positive counts of each split are printed.

    Args:
        dataset: dataset exposing a ``_targets`` tensor and ``label_names``.
        train_ratio, val_ratio, test_ratio: non-negative fractions summing to 1.0.
        random_seed: seed forwarded to the split helper.

    Returns:
        ``(train, val, test)`` subsets over ``dataset``.

    Raises:
        ValueError: if any ratio is negative or the ratios do not sum to 1.0
            (within 1e-6).
    """
    ratios = (train_ratio, val_ratio, test_ratio)
    # A negative ratio can still sum to 1.0 (e.g. 1.2, -0.2, 0.0) but yields
    # meaningless split sizes, so reject it before touching the data.
    if any(r < 0 for r in ratios):
        raise ValueError(f"Ratios must be non-negative, got {ratios}")
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError(
            f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}"
        )

    y = dataset._targets.numpy()
    train_idx: np.ndarray
    val_idx: np.ndarray
    test_idx: np.ndarray
    try:
        train_idx, val_idx, test_idx = _split_indices_iterstrat(
            y, train_ratio, val_ratio, test_ratio, random_seed
        )
    except ImportError:
        warnings.warn(
            "iterstrat is not installed; using a random split (not multilabel-stratified). "
            "scikit-learn does not ship MultilabelStratifiedShuffleSplit — "
            "install iterstrat for IterativeStratification-style splits.",
            UserWarning,
            stacklevel=2,
        )
        train_idx, val_idx, test_idx = _split_indices_random(
            y, train_ratio, val_ratio, test_ratio, random_seed
        )
    else:
        print("Splitting with iterstrat (MultilabelStratifiedShuffleSplit).")

    train_idx_list = train_idx.tolist()
    val_idx_list = val_idx.tolist()
    test_idx_list = test_idx.tolist()
    train_dataset = Subset(dataset, train_idx_list)
    val_dataset = Subset(dataset, val_idx_list)
    test_dataset = Subset(dataset, test_idx_list)
    _print_split_label_distribution("Train", dataset, train_idx_list)
    _print_split_label_distribution("Val", dataset, val_idx_list)
    _print_split_label_distribution("Test", dataset, test_idx_list)
    return train_dataset, val_dataset, test_dataset
def create_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    test_dataset: Dataset,
    batch_size: int = 64,
    num_workers: int = 0,
) -> Dict[str, DataLoader]:
    """
    Build ``"train"``, ``"val"`` and ``"test"`` DataLoaders sharing one collate function.

    Only the training loader shuffles; all three use the same ``batch_size`` and
    ``num_workers``.
    """
    # (key, dataset, shuffle) — built in this order, so dict order matches.
    split_plan = (
        ("train", train_dataset, True),
        ("val", val_dataset, False),
        ("test", test_dataset, False),
    )
    return {
        split_name: DataLoader(
            split_data,
            batch_size=batch_size,
            shuffle=should_shuffle,
            num_workers=num_workers,
            collate_fn=_collate_localization_batch,
        )
        for split_name, split_data, should_shuffle in split_plan
    }
def compute_class_weights(
    train_dataset: Union[ProteinLocalizationDataset, ResidueDataset, Subset],
) -> torch.Tensor:
    """
    Per-class pos_weight for BCEWithLogitsLoss: num_negatives / num_positives per label.

    ``train_dataset`` may be a dataset exposing ``_targets`` or a (possibly
    nested) ``Subset`` of one. Nested subsets are unwrapped and their indices
    composed, so only rows actually in the training subset are counted.

    Returns:
        float32 tensor of shape (num_labels,) matching the training subset. A
        label with zero positives gets num_negatives / 1e-8 (a very large weight),
        which the printed summary flags as extreme.
    """
    base: Any = train_dataset
    indices: Union[List[int], None] = None
    # Subset(S, idx)[i] == S[idx[i]]: map the outer indices through each
    # inner level's indices while walking down to the underlying dataset.
    while isinstance(base, Subset):
        level = list(base.indices)
        indices = level if indices is None else [level[i] for i in indices]
        base = base.dataset

    if indices is None:
        y = base._targets.numpy()
    else:
        y = _targets_for_indices(base, indices)

    num_pos = y.sum(axis=0).astype(np.float64)
    num_neg = y.shape[0] - num_pos
    # Clamp the denominator so a label with no positives does not divide by zero.
    pos_weight = num_neg / np.maximum(num_pos, 1e-8)
    w = torch.tensor(pos_weight, dtype=torch.float32)

    # Fall back to positional names when the dataset has no usable label names.
    names = list(getattr(base, "label_names", None) or [])
    if len(names) != w.shape[0]:
        names = [f"label_{i}" for i in range(w.shape[0])]

    print("BCE pos_weight (neg/pos) per label:")
    for name, val in zip(names, w.tolist()):
        flag = " ** extreme **" if val > 50.0 or val < 0.1 else ""
        print(f" {name}: {val:.6f}{flag}")
    return w