""" Classifier head for protein localization from precomputed embeddings. """ from __future__ import annotations import json from pathlib import Path from typing import Any, Dict, List, Mapping, Sequence import torch from torch import Tensor, nn ROOT = Path(__file__).resolve().parent.parent.parent DEFAULT_LABEL_COLUMNS_JSON = ROOT / "data" / "processed" / "embeddings" / "esm2_t33_650M" / "label_columns.json" FALLBACK_LABEL_NAMES: List[str] = [ "Membrane", "Cytoplasm", "Nucleus", "Extracellular", "Cell membrane", "Mitochondrion", "Plastid", "Endoplasmic reticulum", "Lysosome/Vacuole", "Golgi apparatus", "Peroxisome", ] def _load_label_names_from_json(path: Path) -> List[str] | None: if not path.is_file(): return None with path.open("r", encoding="utf-8") as f: payload: Any = json.load(f) if isinstance(payload, dict) and isinstance(payload.get("label_columns"), list): names = [str(x) for x in payload["label_columns"]] if names: return names return None class ProteinLocalizationClassifier(nn.Module): def __init__( self, embedding_dim: int, num_labels: int | None = None, dropout_rates: Sequence[float] = (0.3, 0.3, 0.2), hidden_dims: Sequence[int] = (512, 256, 128), label_names: Sequence[str] | None = None, label_columns_path: str | Path | None = None, ) -> None: super().__init__() if len(dropout_rates) != 3: raise ValueError(f"Expected 3 dropout rates, got {len(dropout_rates)}") if len(hidden_dims) != 3: raise ValueError(f"Expected 3 hidden dims, got {len(hidden_dims)}") if embedding_dim <= 0: raise ValueError("embedding_dim must be > 0") if label_names is None: if label_columns_path is None: label_columns_file = DEFAULT_LABEL_COLUMNS_JSON else: label_columns_file = Path(label_columns_path).expanduser().resolve() resolved = _load_label_names_from_json(label_columns_file) label_names = resolved if resolved is not None else FALLBACK_LABEL_NAMES inferred_num_labels = len(label_names) if num_labels is None: self.num_labels = inferred_num_labels else: if num_labels <= 0: raise ValueError("num_labels must be > 0") self.num_labels = int(num_labels) if self.num_labels != inferred_num_labels: raise ValueError( f"num_labels={self.num_labels} must match len(label_names)={inferred_num_labels}" ) self.label_names = list(label_names) h1, h2, h3 = [int(h) for h in hidden_dims] d1, d2, d3 = [float(d) for d in dropout_rates] self.net = nn.Sequential( nn.Linear(embedding_dim, h1), nn.BatchNorm1d(h1), nn.ReLU(inplace=True), nn.Dropout(d1), nn.Linear(h1, h2), nn.BatchNorm1d(h2), nn.ReLU(inplace=True), nn.Dropout(d2), nn.Linear(h2, h3), nn.BatchNorm1d(h3), nn.ReLU(inplace=True), nn.Dropout(d3), nn.Linear(h3, self.num_labels), ) self._init_weights() def _init_weights(self) -> None: for module in self.modules(): if isinstance(module, nn.Linear): nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu") if module.bias is not None: nn.init.zeros_(module.bias) def forward(self, x: Tensor) -> Tensor: # No sigmoid here; use BCEWithLogitsLoss during training. return self.net(x) def _ensure_batch(self, embedding: Tensor) -> tuple[Tensor, bool]: if embedding.dim() == 1: return embedding.unsqueeze(0), True if embedding.dim() == 2: return embedding, False raise ValueError(f"Expected tensor with dim 1 or 2, got shape {tuple(embedding.shape)}") def predict_proba(self, embedding: Tensor) -> Dict[str, float] | List[Dict[str, float]]: was_training = self.training self.eval() with torch.no_grad(): x, single = self._ensure_batch(embedding) probs = torch.sigmoid(self.forward(x)) probs_cpu = probs.detach().cpu().tolist() if was_training: self.train() output = [ {name: float(row[i]) for i, name in enumerate(self.label_names)} for row in probs_cpu ] return output[0] if single else output def predict( self, embedding: Tensor, thresholds: Dict[str, float] | Tensor | None = None, ) -> Dict[str, int] | List[Dict[str, int]]: was_training = self.training self.eval() with torch.no_grad(): x, single = self._ensure_batch(embedding) probs = torch.sigmoid(self.forward(x)) if thresholds is None: th = torch.full((self.num_labels,), 0.5, dtype=probs.dtype, device=probs.device) elif isinstance(thresholds, dict): th_vals = [float(thresholds.get(name, 0.5)) for name in self.label_names] th = torch.tensor(th_vals, dtype=probs.dtype, device=probs.device) elif isinstance(thresholds, Tensor): if thresholds.numel() != self.num_labels: raise ValueError( f"threshold tensor must have {self.num_labels} values, got {thresholds.numel()}" ) th = thresholds.to(device=probs.device, dtype=probs.dtype).reshape(-1) else: raise TypeError("thresholds must be None, dict, or torch.Tensor") binary = (probs >= th.unsqueeze(0)).to(torch.int64).detach().cpu().tolist() if was_training: self.train() output = [ {name: int(row[i]) for i, name in enumerate(self.label_names)} for row in binary ] return output[0] if single else output def count_parameters(model: nn.Module) -> None: total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters: {total:,}") print(f"Trainable parameters: {trainable:,}") def load_model( path: str | Path, embedding_dim: int, num_labels: int | None, device: torch.device | str, ) -> ProteinLocalizationClassifier: device = torch.device(device) ckpt_path = Path(path).expanduser().resolve() checkpoint = torch.load(ckpt_path, map_location=device) label_names: Sequence[str] | None = None if isinstance(checkpoint, dict) and "label_names" in checkpoint: label_names = checkpoint["label_names"] if isinstance(checkpoint, dict) and "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] elif isinstance(checkpoint, Mapping): state_dict = checkpoint else: raise ValueError("Unsupported checkpoint format: expected dict or dict with 'state_dict'.") if num_labels is None: if label_names is not None: num_labels = len(label_names) else: classifier_weight = state_dict.get("net.12.weight") if classifier_weight is None: raise ValueError("Could not infer num_labels from checkpoint; pass num_labels explicitly.") num_labels = int(classifier_weight.shape[0]) dropout_rates: Sequence[float] | None = None hidden_dims: Sequence[int] | None = None if isinstance(checkpoint, dict): if "dropout_rates" in checkpoint: dropout_rates = tuple(checkpoint["dropout_rates"]) # type: ignore[assignment] if "hidden_dims" in checkpoint: hidden_dims = tuple(int(x) for x in checkpoint["hidden_dims"]) # type: ignore[assignment] model = ProteinLocalizationClassifier( embedding_dim=embedding_dim, num_labels=num_labels, label_names=label_names, dropout_rates=dropout_rates if dropout_rates is not None else (0.3, 0.3, 0.2), hidden_dims=hidden_dims if hidden_dims is not None else (512, 256, 128), ) model.load_state_dict(state_dict) model.to(device) model.eval() return model