# protloc-ai / src / models / classifier.py
# Tanoj22
# Force add src/models and src/data code files
# fe5a903
"""
Classifier head for protein localization from precomputed embeddings.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Mapping, Sequence
import torch
from torch import Tensor, nn
# Directory three levels above this module's resolved path
# (presumably the repository root -- confirm against the project layout).
ROOT = Path(__file__).resolve().parent.parent.parent

# Default location of the JSON file that lists the label columns.
DEFAULT_LABEL_COLUMNS_JSON = ROOT.joinpath(
    "data",
    "processed",
    "embeddings",
    "esm2_t33_650M",
    "label_columns.json",
)

# Label names used when no usable label-columns JSON file is available.
FALLBACK_LABEL_NAMES: List[str] = [
    "Membrane",
    "Cytoplasm",
    "Nucleus",
    "Extracellular",
    "Cell membrane",
    "Mitochondrion",
    "Plastid",
    "Endoplasmic reticulum",
    "Lysosome/Vacuole",
    "Golgi apparatus",
    "Peroxisome",
]
def _load_label_names_from_json(path: Path) -> List[str] | None:
    """Read label names from a JSON metadata file.

    The file must contain a JSON object whose ``"label_columns"`` entry is a
    non-empty list; every element is converted with ``str``.

    Args:
        path: Location of the JSON file.

    Returns:
        The label names as a list of strings, or ``None`` when the file does
        not exist, the payload is not an object, ``"label_columns"`` is not a
        list, or the list is empty.
    """
    if not path.is_file():
        return None

    payload: Any = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        return None

    columns = payload.get("label_columns")
    if not isinstance(columns, list) or not columns:
        return None

    return [str(item) for item in columns]
class ProteinLocalizationClassifier(nn.Module):
    """Multi-label classification head over precomputed protein embeddings.

    The network is an MLP with three hidden layers. Each hidden layer is
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` and a final ``Linear``
    produces one raw logit per label. ``forward`` returns logits only;
    ``predict_proba`` and ``predict`` apply the sigmoid.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_labels: int | None = None,
        dropout_rates: Sequence[float] = (0.3, 0.3, 0.2),
        hidden_dims: Sequence[int] = (512, 256, 128),
        label_names: Sequence[str] | None = None,
        label_columns_path: str | Path | None = None,
    ) -> None:
        """Build the classifier and resolve its label names.

        Args:
            embedding_dim: Size of the input embedding vector; must be > 0.
            num_labels: Number of output labels. When ``None`` it is taken
                from the resolved label names; otherwise it must equal their
                count.
            dropout_rates: Exactly three dropout probabilities, one per
                hidden layer.
            hidden_dims: Exactly three hidden layer widths.
            label_names: Names of the output labels, in output order. When
                ``None`` they are read from ``label_columns_path`` (or the
                module default path), falling back to the built-in names.
            label_columns_path: Optional JSON file to read label names from;
                only consulted when ``label_names`` is ``None``.

        Raises:
            ValueError: If ``dropout_rates`` or ``hidden_dims`` do not have
                three entries, ``embedding_dim`` or ``num_labels`` is not
                positive, no label names are available, or ``num_labels``
                disagrees with the number of label names.
        """
        super().__init__()
        if len(dropout_rates) != 3:
            raise ValueError(f"Expected 3 dropout rates, got {len(dropout_rates)}")
        if len(hidden_dims) != 3:
            raise ValueError(f"Expected 3 hidden dims, got {len(hidden_dims)}")
        if embedding_dim <= 0:
            raise ValueError("embedding_dim must be > 0")

        if label_names is None:
            # Resolve names from the JSON metadata file, then the built-ins.
            if label_columns_path is None:
                label_columns_file = DEFAULT_LABEL_COLUMNS_JSON
            else:
                label_columns_file = Path(label_columns_path).expanduser().resolve()
            resolved = _load_label_names_from_json(label_columns_file)
            label_names = resolved if resolved is not None else FALLBACK_LABEL_NAMES

        inferred_num_labels = len(label_names)
        if num_labels is None:
            # An empty name list would otherwise build a head with zero outputs.
            if inferred_num_labels == 0:
                raise ValueError("label_names must contain at least one label")
            self.num_labels = inferred_num_labels
        else:
            if num_labels <= 0:
                raise ValueError("num_labels must be > 0")
            self.num_labels = int(num_labels)
            if self.num_labels != inferred_num_labels:
                raise ValueError(
                    f"num_labels={self.num_labels} must match len(label_names)={inferred_num_labels}"
                )

        # Private copy so later mutation of the caller's sequence has no effect.
        self.label_names = list(label_names)

        h1, h2, h3 = [int(h) for h in hidden_dims]
        d1, d2, d3 = [float(d) for d in dropout_rates]
        # The output layer sits at index 12 of this Sequential.
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, h1),
            nn.BatchNorm1d(h1),
            nn.ReLU(inplace=True),
            nn.Dropout(d1),
            nn.Linear(h1, h2),
            nn.BatchNorm1d(h2),
            nn.ReLU(inplace=True),
            nn.Dropout(d2),
            nn.Linear(h2, h3),
            nn.BatchNorm1d(h3),
            nn.ReLU(inplace=True),
            nn.Dropout(d3),
            nn.Linear(h3, self.num_labels),
        )
        self._init_weights()

    def _init_weights(self) -> None:
        """Apply Kaiming-normal init to Linear weights and zero their biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return raw logits for ``x``."""
        # No sigmoid here; use BCEWithLogitsLoss during training.
        return self.net(x)

    def _ensure_batch(self, embedding: Tensor) -> tuple[Tensor, bool]:
        """Return ``embedding`` as a 2-D batch plus a flag for 1-D input.

        Raises:
            ValueError: If ``embedding`` is neither 1-D nor 2-D.
        """
        if embedding.dim() == 1:
            return embedding.unsqueeze(0), True
        if embedding.dim() == 2:
            return embedding, False
        raise ValueError(f"Expected tensor with dim 1 or 2, got shape {tuple(embedding.shape)}")

    def _predict_probabilities(self, embedding: Tensor) -> tuple[Tensor, bool]:
        """Compute sigmoid probabilities in eval mode without gradients.

        If the module was in training mode on entry it is switched back
        afterwards, even when input validation or the forward pass raises.

        Returns:
            ``(probs, single)`` where ``probs`` is 2-D and ``single`` tells
            whether the caller passed a 1-D embedding.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                batch, single = self._ensure_batch(embedding)
                probs = torch.sigmoid(self.forward(batch))
        finally:
            if was_training:
                self.train()
        return probs, single

    def predict_proba(self, embedding: Tensor) -> Dict[str, float] | List[Dict[str, float]]:
        """Return per-label probabilities keyed by label name.

        Args:
            embedding: A 1-D embedding or a 2-D batch of embeddings.

        Returns:
            One ``{label: probability}`` dict for 1-D input, or a list of such
            dicts (one per row) for 2-D input.

        Raises:
            ValueError: If ``embedding`` is neither 1-D nor 2-D.
        """
        probs, single = self._predict_probabilities(embedding)
        rows = probs.detach().cpu().tolist()
        output = [
            {name: float(value) for name, value in zip(self.label_names, row)}
            for row in rows
        ]
        return output[0] if single else output

    def predict(
        self,
        embedding: Tensor,
        thresholds: Mapping[str, float] | Tensor | None = None,
    ) -> Dict[str, int] | List[Dict[str, int]]:
        """Return binary per-label decisions keyed by label name.

        A label is predicted as 1 when its probability is >= its threshold.

        Args:
            embedding: A 1-D embedding or a 2-D batch of embeddings.
            thresholds: ``None`` for 0.5 on every label; a mapping from label
                name to threshold (labels missing from it use 0.5); or a
                tensor holding exactly ``num_labels`` thresholds in label
                order.

        Returns:
            One ``{label: 0 or 1}`` dict for 1-D input, or a list of such
            dicts (one per row) for 2-D input.

        Raises:
            ValueError: If ``embedding`` is neither 1-D nor 2-D, or a
                threshold tensor has the wrong number of values.
            TypeError: If ``thresholds`` is of an unsupported type.
        """
        probs, single = self._predict_probabilities(embedding)

        if thresholds is None:
            th = torch.full((self.num_labels,), 0.5, dtype=probs.dtype, device=probs.device)
        elif isinstance(thresholds, Mapping):
            th_vals = [float(thresholds.get(name, 0.5)) for name in self.label_names]
            th = torch.tensor(th_vals, dtype=probs.dtype, device=probs.device)
        elif isinstance(thresholds, Tensor):
            if thresholds.numel() != self.num_labels:
                raise ValueError(
                    f"threshold tensor must have {self.num_labels} values, got {thresholds.numel()}"
                )
            th = thresholds.detach().to(device=probs.device, dtype=probs.dtype).reshape(-1)
        else:
            raise TypeError("thresholds must be None, dict, or torch.Tensor")

        # Broadcast the per-label thresholds across the batch dimension.
        binary = (probs >= th.unsqueeze(0)).to(torch.int64).cpu().tolist()
        output = [
            {name: int(flag) for name, flag in zip(self.label_names, row)}
            for row in binary
        ]
        return output[0] if single else output
def count_parameters(model: nn.Module) -> None:
    """Print the total and trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Counts are printed with thousands separators; nothing is returned.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
def load_model(
    path: str | Path,
    embedding_dim: int,
    num_labels: int | None,
    device: torch.device | str,
) -> ProteinLocalizationClassifier:
    """Rebuild a classifier from a checkpoint file and return it in eval mode.

    The checkpoint may be a dict with a ``"state_dict"`` entry (optionally
    with ``"label_names"``, ``"dropout_rates"`` and ``"hidden_dims"``) or a
    bare mapping of parameter names to tensors.

    Args:
        path: Checkpoint file location (``~`` is expanded).
        embedding_dim: Input embedding size used to construct the model.
        num_labels: Number of labels. When ``None`` it is taken from the
            checkpoint's ``"label_names"`` or, failing that, from the first
            dimension of ``"net.12.weight"``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The classifier with weights loaded, on ``device``, in eval mode.

    Raises:
        ValueError: If the checkpoint is not a mapping, or ``num_labels``
            cannot be inferred.
    """
    target_device = torch.device(device)
    checkpoint_file = Path(path).expanduser().resolve()
    # NOTE(review): depending on the torch version and its ``weights_only``
    # default, torch.load may unpickle arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location=target_device)
    is_plain_dict = isinstance(checkpoint, dict)

    label_names: Sequence[str] | None = None
    if is_plain_dict and "label_names" in checkpoint:
        label_names = checkpoint["label_names"]

    if is_plain_dict and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif isinstance(checkpoint, Mapping):
        # Bare state dict saved without a wrapper.
        state_dict = checkpoint
    else:
        raise ValueError("Unsupported checkpoint format: expected dict or dict with 'state_dict'.")

    if num_labels is None:
        if label_names is not None:
            num_labels = len(label_names)
        else:
            # Fall back to the row count of the weight stored under "net.12".
            head_weight = state_dict.get("net.12.weight")
            if head_weight is None:
                raise ValueError("Could not infer num_labels from checkpoint; pass num_labels explicitly.")
            num_labels = int(head_weight.shape[0])

    # Start from the defaults and override with values stored in the checkpoint.
    dropout_rates: Sequence[float] = (0.3, 0.3, 0.2)
    hidden_dims: Sequence[int] = (512, 256, 128)
    if is_plain_dict:
        if "dropout_rates" in checkpoint:
            dropout_rates = tuple(checkpoint["dropout_rates"])
        if "hidden_dims" in checkpoint:
            hidden_dims = tuple(int(width) for width in checkpoint["hidden_dims"])

    model = ProteinLocalizationClassifier(
        embedding_dim=embedding_dim,
        num_labels=num_labels,
        label_names=label_names,
        dropout_rates=dropout_rates,
        hidden_dims=hidden_dims,
    )
    model.load_state_dict(state_dict)
    model.to(target_device)
    model.eval()
    return model