Spaces:
Running
Running
| """ | |
| Classifier head for protein localization from precomputed embeddings. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Mapping, Sequence | |
| import torch | |
| from torch import Tensor, nn | |
# Repository root: three directory levels above this file's resolved location.
ROOT = Path(__file__).resolve().parent.parent.parent

# Label-column metadata written alongside the ESM-2 (t33, 650M) embeddings.
DEFAULT_LABEL_COLUMNS_JSON = ROOT.joinpath(
    "data", "processed", "embeddings", "esm2_t33_650M", "label_columns.json"
)

# Label order used when no label-columns JSON can be read.
FALLBACK_LABEL_NAMES: List[str] = list(
    (
        "Membrane",
        "Cytoplasm",
        "Nucleus",
        "Extracellular",
        "Cell membrane",
        "Mitochondrion",
        "Plastid",
        "Endoplasmic reticulum",
        "Lysosome/Vacuole",
        "Golgi apparatus",
        "Peroxisome",
    )
)
def _load_label_names_from_json(path: Path) -> List[str] | None:
    """Read label names from a JSON file, or return ``None`` if unavailable.

    The file is expected to hold an object with a ``"label_columns"`` list;
    each entry is converted with ``str``. ``None`` is returned when the file
    does not exist, when the top-level JSON value is not an object, when the
    key is missing or not a list, or when the list is empty. Malformed JSON is
    not caught; ``json.load`` raises in that case.
    """
    if path.is_file():
        with path.open("r", encoding="utf-8") as handle:
            payload: Any = json.load(handle)
        columns = payload.get("label_columns") if isinstance(payload, dict) else None
        if isinstance(columns, list) and columns:
            return [str(column) for column in columns]
    return None
class ProteinLocalizationClassifier(nn.Module):
    """Multi-label MLP head that maps an embedding vector to per-label logits.

    Architecture: three ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages
    followed by a final ``Linear`` that produces ``num_labels`` logits.
    :meth:`forward` returns raw logits. :meth:`predict_proba` and
    :meth:`predict` apply a sigmoid per label (independent, multi-label
    outputs).

    Args:
        embedding_dim: Size of the input feature vector (last input dim).
        num_labels: Number of output labels. Inferred from ``label_names``
            when ``None``; otherwise it must equal ``len(label_names)``.
        dropout_rates: Exactly three dropout probabilities, one per stage.
        hidden_dims: Exactly three hidden layer widths.
        label_names: Output label names. When ``None`` they are read from
            ``label_columns_path`` (default ``DEFAULT_LABEL_COLUMNS_JSON``).
            If that file is missing or lacks a non-empty ``"label_columns"``
            list, ``FALLBACK_LABEL_NAMES`` is used.
        label_columns_path: Optional override for the label-columns JSON.

    Raises:
        ValueError: On wrong-length ``dropout_rates``/``hidden_dims``, a
            non-positive ``embedding_dim`` or ``num_labels``, an empty label
            set, or a ``num_labels`` / ``len(label_names)`` mismatch.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_labels: int | None = None,
        dropout_rates: Sequence[float] = (0.3, 0.3, 0.2),
        hidden_dims: Sequence[int] = (512, 256, 128),
        label_names: Sequence[str] | None = None,
        label_columns_path: str | Path | None = None,
    ) -> None:
        super().__init__()
        if len(dropout_rates) != 3:
            raise ValueError(f"Expected 3 dropout rates, got {len(dropout_rates)}")
        if len(hidden_dims) != 3:
            raise ValueError(f"Expected 3 hidden dims, got {len(hidden_dims)}")
        if embedding_dim <= 0:
            raise ValueError("embedding_dim must be > 0")

        if label_names is None:
            if label_columns_path is None:
                label_columns_file = DEFAULT_LABEL_COLUMNS_JSON
            else:
                label_columns_file = Path(label_columns_path).expanduser().resolve()
            resolved = _load_label_names_from_json(label_columns_file)
            label_names = resolved if resolved is not None else FALLBACK_LABEL_NAMES
        names = list(label_names)

        if num_labels is None:
            if not names:
                # Otherwise an empty label list would silently build a head
                # with zero outputs, and every prediction would be empty.
                raise ValueError("label_names must not be empty")
            self.num_labels = len(names)
        else:
            if num_labels <= 0:
                raise ValueError("num_labels must be > 0")
            self.num_labels = int(num_labels)
            if self.num_labels != len(names):
                raise ValueError(
                    f"num_labels={self.num_labels} must match len(label_names)={len(names)}"
                )
        self.label_names = names

        h1, h2, h3 = [int(h) for h in hidden_dims]
        d1, d2, d3 = [float(d) for d in dropout_rates]
        # Keep this layer order stable: checkpoints key parameters by index
        # (Linear layers at net.0, net.4, net.8 and the output layer at net.12).
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, h1),
            nn.BatchNorm1d(h1),
            nn.ReLU(inplace=True),
            nn.Dropout(d1),
            nn.Linear(h1, h2),
            nn.BatchNorm1d(h2),
            nn.ReLU(inplace=True),
            nn.Dropout(d2),
            nn.Linear(h2, h3),
            nn.BatchNorm1d(h3),
            nn.ReLU(inplace=True),
            nn.Dropout(d3),
            nn.Linear(h3, self.num_labels),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal (fan_in, ReLU) init for every Linear weight; zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return raw logits of shape ``(batch, num_labels)`` for a 2-D input."""
        # No sigmoid here; use BCEWithLogitsLoss during training.
        return self.net(x)

    def _ensure_batch(self, embedding: Tensor) -> tuple[Tensor, bool]:
        """Promote a 1-D embedding to a batch of one.

        Returns ``(batch, was_single)``. Raises ``ValueError`` for inputs that
        are neither 1-D nor 2-D.
        """
        if embedding.dim() == 1:
            return embedding.unsqueeze(0), True
        if embedding.dim() == 2:
            return embedding, False
        raise ValueError(f"Expected tensor with dim 1 or 2, got shape {tuple(embedding.shape)}")

    def _probabilities(self, embedding: Tensor) -> tuple[Tensor, bool]:
        """Run an inference pass and return ``(sigmoid probabilities, was_single)``.

        The module is switched to eval mode for the pass, so BatchNorm uses its
        running statistics and Dropout is disabled. The previous training mode
        is restored even if the pass raises.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                x, single = self._ensure_batch(embedding)
                # Align input with the parameters so CPU/float64 inputs work
                # against a GPU/float32 model instead of failing inside Linear.
                reference = next(self.parameters())
                x = x.to(device=reference.device, dtype=reference.dtype)
                probs = torch.sigmoid(self.forward(x))
        finally:
            if was_training:
                self.train()
        return probs, single

    def _threshold_tensor(
        self, thresholds: Dict[str, float] | Tensor | None, like: Tensor
    ) -> Tensor:
        """Build a ``(num_labels,)`` threshold vector on ``like``'s device/dtype.

        ``None`` means 0.5 for every label. A dict maps label name to
        threshold; labels missing from it use 0.5, and keys that are not label
        names are ignored. A tensor must have exactly ``num_labels`` elements.
        """
        if thresholds is None:
            return torch.full((self.num_labels,), 0.5, dtype=like.dtype, device=like.device)
        if isinstance(thresholds, dict):
            values = [float(thresholds.get(name, 0.5)) for name in self.label_names]
            return torch.tensor(values, dtype=like.dtype, device=like.device)
        if isinstance(thresholds, Tensor):
            if thresholds.numel() != self.num_labels:
                raise ValueError(
                    f"threshold tensor must have {self.num_labels} values, got {thresholds.numel()}"
                )
            return thresholds.to(device=like.device, dtype=like.dtype).reshape(-1)
        raise TypeError("thresholds must be None, dict, or torch.Tensor")

    def predict_proba(self, embedding: Tensor) -> Dict[str, float] | List[Dict[str, float]]:
        """Return per-label sigmoid probabilities.

        A 1-D input yields a single ``{label: prob}`` dict. A 2-D input yields
        one dict per row.
        """
        probs, single = self._probabilities(embedding)
        rows = probs.cpu().tolist()
        output = [
            {name: float(value) for name, value in zip(self.label_names, row)}
            for row in rows
        ]
        return output[0] if single else output

    def predict(
        self,
        embedding: Tensor,
        thresholds: Dict[str, float] | Tensor | None = None,
    ) -> Dict[str, int] | List[Dict[str, int]]:
        """Return 0/1 label assignments (``prob >= threshold`` gives 1).

        Args:
            embedding: A 1-D or 2-D input tensor.
            thresholds: ``None`` (0.5 everywhere), a ``{label: threshold}``
                dict, or a tensor with ``num_labels`` elements.

        Returns:
            A ``{label: 0|1}`` dict for 1-D input, or a list of them for 2-D.

        Raises:
            ValueError: For a wrong-sized threshold tensor or bad input rank.
            TypeError: For an unsupported ``thresholds`` type.
        """
        probs, single = self._probabilities(embedding)
        th = self._threshold_tensor(thresholds, probs)
        binary = (probs >= th.unsqueeze(0)).to(torch.int64).cpu().tolist()
        output = [
            {name: int(value) for name, value in zip(self.label_names, row)}
            for row in binary
        ]
        return output[0] if single else output
def count_parameters(model: nn.Module) -> None:
    """Print the total and trainable element counts of ``model``'s parameters.

    Counts are printed with thousands separators. A parameter is counted as
    trainable when its ``requires_grad`` flag is set.
    """
    total_count = 0
    trainable_count = 0
    for param in model.parameters():
        n_elements = param.numel()
        total_count += n_elements
        if param.requires_grad:
            trainable_count += n_elements
    print(f"Total parameters: {total_count:,}")
    print(f"Trainable parameters: {trainable_count:,}")
def load_model(
    path: str | Path,
    embedding_dim: int,
    num_labels: int | None,
    device: torch.device | str,
) -> ProteinLocalizationClassifier:
    """Load a ``ProteinLocalizationClassifier`` checkpoint, ready for inference.

    Accepted checkpoint formats:

    * a mapping with a ``"state_dict"`` entry plus optional ``"label_names"``,
      ``"dropout_rates"`` and ``"hidden_dims"`` metadata; or
    * a bare state dict.

    Hyperparameters missing from the checkpoint are recovered where possible.
    ``num_labels`` comes from ``label_names`` or the output layer's weight.
    ``hidden_dims`` is read from the hidden Linear weights, falling back to
    ``(512, 256, 128)`` when they are absent. ``dropout_rates`` falls back to
    ``(0.3, 0.3, 0.2)``; Dropout is inactive in eval mode, so this does not
    affect predictions.

    Args:
        path: Checkpoint file path (``~`` is expanded).
        embedding_dim: Input feature size the model was trained with.
        num_labels: Number of output labels, or ``None`` to infer it.
        device: Device to map tensors to and to place the model on.

    Returns:
        The model on ``device`` in eval mode.

    Raises:
        ValueError: If the checkpoint is not a mapping, or ``num_labels`` is
            ``None`` and cannot be inferred.
    """
    device = torch.device(device)
    ckpt_path = Path(path).expanduser().resolve()
    # SECURITY NOTE(review): depending on the installed torch version,
    # torch.load may fully unpickle the file, which can execute arbitrary code.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=device)

    if not isinstance(checkpoint, Mapping):
        raise ValueError("Unsupported checkpoint format: expected dict or dict with 'state_dict'.")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    label_names: Sequence[str] | None = checkpoint.get("label_names")

    if num_labels is None:
        if label_names is not None:
            num_labels = len(label_names)
        else:
            # "net.12" is the output Linear layer of the classifier's Sequential.
            classifier_weight = state_dict.get("net.12.weight")
            if classifier_weight is None:
                raise ValueError("Could not infer num_labels from checkpoint; pass num_labels explicitly.")
            num_labels = int(classifier_weight.shape[0])

    if "dropout_rates" in checkpoint:
        dropout_rates = tuple(checkpoint["dropout_rates"])
    else:
        dropout_rates = (0.3, 0.3, 0.2)

    if "hidden_dims" in checkpoint:
        hidden_dims = tuple(int(x) for x in checkpoint["hidden_dims"])
    else:
        # Hidden Linear layers sit at net.0 / net.4 / net.8. Their output sizes
        # are the hidden widths, so a checkpoint trained with non-default
        # widths still loads instead of failing in load_state_dict.
        hidden_weights = [state_dict.get(f"net.{i}.weight") for i in (0, 4, 8)]
        if all(isinstance(w, torch.Tensor) and w.dim() == 2 for w in hidden_weights):
            hidden_dims = tuple(int(w.shape[0]) for w in hidden_weights)
        else:
            hidden_dims = (512, 256, 128)

    model = ProteinLocalizationClassifier(
        embedding_dim=embedding_dim,
        num_labels=num_labels,
        label_names=label_names,
        dropout_rates=dropout_rates,
        hidden_dims=hidden_dims,
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model