# NexaMass-V3-Struct/runtime/nexamass_encoder.py
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
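

# These defaults appear to mirror the released V3-Struct configuration:
# load_nexamass_model_state below instantiates ModelConfig() unchanged and
# loads checkpoints with strict=True.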
@dataclass(frozen=True)
class ModelConfig:
max_peaks: int = 256
mz_max: float = 2000.0
collision_max: float = 200.0
model_dim: int = 384
layers: int = 6
heads: int = 8
dropout: float = 0.1
projection_dim: int = 192
fingerprint_dim: int = 2048
target_projection_dim: int = 256
retrieval_mlp_hidden_dim: int = 512
metadata_scale: float = 0.02


class NexaMassSpectralEncoder(nn.Module):
"""Encoder-only MS/MS transformer used by NexaMass-V3-Struct.
Expected batch keys:
- mzs, ints, mz_to_precursor, peak_rank: float tensors [batch, max_peaks]
- precursor_mz, charge, collision_energy, peak_count: float tensors [batch]
- adduct_id, instrument_id: long tensors [batch]
- mask: bool tensor [batch, max_peaks], True for valid peaks
"""
def __init__(self, cfg: ModelConfig) -> None:
super().__init__()
self.cfg = cfg
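        # Learned metadata embeddings with a fixed 64-slot vocabulary each.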
self.adduct_embedding = nn.Embedding(64, cfg.model_dim)
self.instrument_embedding = nn.Embedding(64, cfg.model_dim)
self.input_projection = nn.Linear(8, cfg.model_dim)
encoder_layer = nn.TransformerEncoderLayer(
d_model=cfg.model_dim,
nhead=cfg.heads,
dim_feedforward=cfg.model_dim * 4,
dropout=cfg.dropout,
activation="gelu",
batch_first=True,
norm_first=True,
)
        try:
            # enable_nested_tensor is not accepted by older PyTorch releases.
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.layers, enable_nested_tensor=False)
        except TypeError:
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.layers)
self.final_norm = nn.LayerNorm(cfg.model_dim)
self.projection = nn.Sequential(
nn.Linear(cfg.model_dim, cfg.model_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.model_dim, cfg.projection_dim),
)
self.structure_head = nn.Sequential(
nn.Linear(cfg.model_dim, cfg.model_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.model_dim, cfg.fingerprint_dim),
)
self.structure_query = nn.Sequential(
nn.Linear(cfg.model_dim, cfg.model_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.model_dim, cfg.target_projection_dim),
)
self.target_projection = nn.Sequential(
nn.Linear(cfg.fingerprint_dim, cfg.model_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.model_dim, cfg.target_projection_dim),
)
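        # Retrieval scoring heads: the bilinear layer maps a spectrum-side query
        # into the target space; the pair MLPs score concatenated query/target
        # feature vectors (4 * target_projection_dim wide, plus one extra scalar
        # input for the local reranker).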
self.retrieval_bilinear = nn.Linear(cfg.target_projection_dim, cfg.target_projection_dim, bias=False)
self.retrieval_pair_mlp = nn.Sequential(
nn.Linear(cfg.target_projection_dim * 4, cfg.retrieval_mlp_hidden_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.retrieval_mlp_hidden_dim, 1),
)
self.local_rerank_mlp = nn.Sequential(
nn.Linear(cfg.target_projection_dim * 4 + 1, cfg.retrieval_mlp_hidden_dim),
nn.GELU(),
nn.Dropout(cfg.dropout),
nn.Linear(cfg.retrieval_mlp_hidden_dim, 1),
)

    def encode(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        # Build an 8-dim feature vector per peak: four per-peak channels plus
        # four spectrum-level scalars broadcast across the peak axis.
        features = torch.stack(
            [
                batch["mzs"],
                batch["ints"],
                batch["mz_to_precursor"],
                batch["peak_rank"],
                batch["precursor_mz"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["charge"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["collision_energy"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["peak_count"].unsqueeze(-1).expand_as(batch["mzs"]),
            ],
            dim=-1,
        )
        hidden = self.input_projection(features)
        # Inject spectrum-level metadata as small additive embeddings at every
        # peak position.
        hidden = hidden + self.adduct_embedding(batch["adduct_id"])[:, None, :] * self.cfg.metadata_scale
        hidden = hidden + self.instrument_embedding(batch["instrument_id"])[:, None, :] * self.cfg.metadata_scale
        # src_key_padding_mask expects True at padded positions, hence the inversion.
        encoded = self.encoder(hidden, src_key_padding_mask=~batch["mask"])
        encoded = self.final_norm(encoded)
        # Masked mean pooling over valid peaks only.
        mask = batch["mask"].unsqueeze(-1)
        return (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward_with_heads(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pooled = self.encode(batch)
        raw_projected = self.projection(pooled)
        structure_logits = self.structure_head(pooled)
        structure_query_raw = self.structure_query(pooled)
        # (normalized embedding, raw embedding, fingerprint logits, structure query)
        return F.normalize(raw_projected, dim=-1), raw_projected, structure_logits, structure_query_raw

    def project_structure_targets(self, targets: torch.Tensor) -> torch.Tensor:
        # Project fingerprint targets into the shared retrieval space.
        return F.normalize(self.target_projection(targets), dim=-1)
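

# --- Illustrative usage sketch ---
# A minimal example of assembling the batch dict documented on
# NexaMassSpectralEncoder and running the forward heads. All tensor values and
# the [q, t, q * t, |q - t|] pair-feature layout are assumptions made for
# illustration; only the batch keys and shapes come from the docstring above.
def _example_forward(cfg: ModelConfig | None = None) -> None:
    cfg = cfg or ModelConfig()
    model = NexaMassSpectralEncoder(cfg).eval()
    n = 2
    batch = {
        "mzs": torch.rand(n, cfg.max_peaks) * cfg.mz_max,
        "ints": torch.rand(n, cfg.max_peaks),
        "mz_to_precursor": torch.rand(n, cfg.max_peaks),
        "peak_rank": torch.rand(n, cfg.max_peaks),
        "precursor_mz": torch.rand(n) * cfg.mz_max,
        "charge": torch.ones(n),
        "collision_energy": torch.rand(n) * cfg.collision_max,
        "peak_count": torch.full((n,), float(cfg.max_peaks)),
        "adduct_id": torch.zeros(n, dtype=torch.long),
        "instrument_id": torch.zeros(n, dtype=torch.long),
        "mask": torch.ones(n, cfg.max_peaks, dtype=torch.bool),
    }
    with torch.no_grad():
        embedding, _, fingerprint_logits, query = model.forward_with_heads(batch)
        targets = model.project_structure_targets(torch.rand(n, cfg.fingerprint_dim))
        # One plausible pair-feature layout matching the 4x input width of
        # retrieval_pair_mlp; the layout used in training is not shown here.
        q = model.retrieval_bilinear(query)
        pair = torch.cat([q, targets, q * targets, (q - targets).abs()], dim=-1)
        scores = model.retrieval_pair_mlp(pair).squeeze(-1)
    print(embedding.shape, fingerprint_logits.shape, scores.shape)
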
def load_nexamass_state_dict(
checkpoint_path: str,
map_location: str | torch.device = "cpu",
) -> dict[str, torch.Tensor]:
"""Load public NexaMass model-state weights from Safetensors or PyTorch.
Hugging Face public release weights are Safetensors-only. The PyTorch branch is
kept for internal/object-storage compatibility with full training checkpoints
and model-state fallbacks.
"""
path = Path(checkpoint_path)
if path.suffix == ".safetensors":
try:
from safetensors.torch import load_file
except ImportError as exc: # pragma: no cover - dependency message path
raise RuntimeError("Install safetensors to load NexaMass public weights: pip install safetensors") from exc
        # safetensors takes a device string; convert torch.device values rather
        # than silently ignoring them.
        device = str(map_location)
        if device not in {"cpu", "cuda"} and not device.startswith("cuda:"):
            device = "cpu"
return load_file(str(path), device=device)
try:
payload = torch.load(path, map_location=map_location, weights_only=True)
except TypeError: # older PyTorch
payload = torch.load(path, map_location=map_location)
if isinstance(payload, dict) and "model_state" in payload:
return payload["model_state"]
if isinstance(payload, dict):
return payload
raise TypeError(f"Unsupported NexaMass checkpoint payload type: {type(payload)!r}")


def load_nexamass_model_state(
checkpoint_path: str,
cfg: ModelConfig | None = None,
map_location: str | torch.device = "cpu",
) -> NexaMassSpectralEncoder:
state_dict = load_nexamass_state_dict(checkpoint_path, map_location=map_location)
cfg = cfg or ModelConfig()
model = NexaMassSpectralEncoder(cfg)
model.load_state_dict(state_dict, strict=True)
model.eval()
return model
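

# A minimal loading example, assuming the weights file has already been
# downloaded locally; "nexamass_v3_struct.safetensors" is a placeholder
# filename, not a published artifact name.
if __name__ == "__main__":
    encoder = load_nexamass_model_state("nexamass_v3_struct.safetensors")
    print(sum(p.numel() for p in encoder.parameters()), "parameters")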