from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass(frozen=True)
class ModelConfig:
    """Hyper-parameters used to build :class:`NexaMassSpectralEncoder`."""

    # NOTE(review): max_peaks, mz_max and collision_max are not read by the
    # code visible here; presumably they configure preprocessing - confirm.
    max_peaks: int = 256
    mz_max: float = 2000.0
    collision_max: float = 200.0
    # Transformer width, depth, attention heads and dropout probability.
    model_dim: int = 384
    layers: int = 6
    heads: int = 8
    dropout: float = 0.1
    # Output width of the ``projection`` head.
    projection_dim: int = 192
    # Output width of ``structure_head`` and input width of
    # ``target_projection``.
    fingerprint_dim: int = 2048
    # Output width of ``structure_query`` and ``target_projection``.
    target_projection_dim: int = 256
    # Hidden width of the retrieval / rerank MLPs.
    retrieval_mlp_hidden_dim: int = 512
    # Multiplier applied to adduct / instrument embeddings before they are
    # added to the per-peak hidden states.
    metadata_scale: float = 0.02


def _mlp_head(in_dim: int, hidden_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
    """Build the Linear -> GELU -> Dropout -> Linear stack shared by all heads."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, out_dim),
    )


class NexaMassSpectralEncoder(nn.Module):
    """Encoder-only MS/MS transformer used by NexaMass-V3-Struct.

    Expected batch keys:
    - mzs, ints, mz_to_precursor, peak_rank: float tensors [batch, max_peaks]
    - precursor_mz, charge, collision_energy, peak_count: float tensors [batch]
    - adduct_id, instrument_id: long tensors [batch]
    - mask: bool tensor [batch, max_peaks], True for valid peaks
    """

    # Per-peak features, stacked in this order as the first input channels.
    _PEAK_KEYS = ("mzs", "ints", "mz_to_precursor", "peak_rank")
    # Per-spectrum scalars, broadcast across peaks and stacked afterwards.
    _SPECTRUM_KEYS = ("precursor_mz", "charge", "collision_energy", "peak_count")

    def __init__(self, cfg: ModelConfig) -> None:
        """Create all sub-modules from ``cfg``.

        Construction order and attribute names are kept stable so that
        parameter initialisation and state-dict keys do not change.
        """
        super().__init__()
        self.cfg = cfg
        dim = cfg.model_dim
        target_dim = cfg.target_projection_dim

        # Embedding tables hold 64 rows each, so ids must lie in [0, 64).
        self.adduct_embedding = nn.Embedding(64, dim)
        self.instrument_embedding = nn.Embedding(64, dim)
        # One input channel per feature key (4 per-peak + 4 per-spectrum = 8).
        self.input_projection = nn.Linear(len(self._PEAK_KEYS) + len(self._SPECTRUM_KEYS), dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=cfg.heads,
            dim_feedforward=dim * 4,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        try:
            self.encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=cfg.layers, enable_nested_tensor=False
            )
        except TypeError:
            # PyTorch versions whose TransformerEncoder does not accept
            # ``enable_nested_tensor``.
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.layers)
        self.final_norm = nn.LayerNorm(dim)

        # Heads applied to the pooled spectrum embedding.
        self.projection = _mlp_head(dim, dim, cfg.projection_dim, cfg.dropout)
        self.structure_head = _mlp_head(dim, dim, cfg.fingerprint_dim, cfg.dropout)
        self.structure_query = _mlp_head(dim, dim, target_dim, cfg.dropout)
        # Maps fingerprint-sized targets into the ``structure_query`` space.
        self.target_projection = _mlp_head(cfg.fingerprint_dim, dim, target_dim, cfg.dropout)

        # NOTE(review): the three retrieval modules below are not called by
        # any method visible here; they are presumably used by external
        # scoring code and are required for strict state-dict loading.
        self.retrieval_bilinear = nn.Linear(target_dim, target_dim, bias=False)
        self.retrieval_pair_mlp = _mlp_head(
            target_dim * 4, cfg.retrieval_mlp_hidden_dim, 1, cfg.dropout
        )
        self.local_rerank_mlp = _mlp_head(
            target_dim * 4 + 1, cfg.retrieval_mlp_hidden_dim, 1, cfg.dropout
        )

    def encode(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        """Return one pooled embedding per spectrum, shape [batch, model_dim].

        Peaks are embedded, contextualised by the transformer, layer-normed
        and mean-pooled over the valid (``mask == True``) positions. A
        spectrum with no valid peaks yields an all-zero vector.
        """
        mzs = batch["mzs"]
        per_peak = [batch[key] for key in self._PEAK_KEYS]
        per_spectrum = [batch[key].unsqueeze(-1).expand_as(mzs) for key in self._SPECTRUM_KEYS]
        features = torch.stack(per_peak + per_spectrum, dim=-1)

        hidden = self.input_projection(features)
        scale = self.cfg.metadata_scale
        # Spectrum-level metadata is added to every peak, down-weighted by
        # ``metadata_scale``.
        hidden = hidden + self.adduct_embedding(batch["adduct_id"])[:, None, :] * scale
        hidden = hidden + self.instrument_embedding(batch["instrument_id"])[:, None, :] * scale

        # Coerce so that 0/1 integer or float masks are accepted as well.
        valid = batch["mask"].bool()
        encoded = self.encoder(hidden, src_key_padding_mask=~valid)
        encoded = self.final_norm(encoded)

        # masked_fill, not multiplication: attention over a fully padded row
        # can produce NaN, and NaN * 0 would leak into the pooled sum.
        keep = valid.unsqueeze(-1)
        summed = encoded.masked_fill(~keep, 0.0).sum(dim=1)
        # Clamp so empty spectra divide by 1 instead of 0.
        return summed / keep.sum(dim=1).clamp(min=1)

    def forward_with_heads(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``batch`` and apply the projection and structure heads.

        Returns:
            A 4-tuple ``(normalized_projection, raw_projection,
            structure_logits, structure_query_raw)`` where the first element
            is the L2-normalised version of the second.
        """
        pooled = self.encode(batch)
        raw_projected = self.projection(pooled)
        structure_logits = self.structure_head(pooled)
        structure_query_raw = self.structure_query(pooled)
        return (
            F.normalize(raw_projected, dim=-1),
            raw_projected,
            structure_logits,
            structure_query_raw,
        )

    def project_structure_targets(self, targets: torch.Tensor) -> torch.Tensor:
        """Project ``targets`` ([..., fingerprint_dim]) and L2-normalise the result."""
        return F.normalize(self.target_projection(targets), dim=-1)


def load_nexamass_state_dict(
    checkpoint_path: str | Path,
    map_location: str | torch.device = "cpu",
) -> dict[str, torch.Tensor]:
    """Load public NexaMass model-state weights from Safetensors or PyTorch.

    Hugging Face public release weights are Safetensors-only. The PyTorch
    branch is kept for internal/object-storage compatibility with full
    training checkpoints and model-state fallbacks.

    Args:
        checkpoint_path: Path to a ``.safetensors`` file or a ``torch.save``
            checkpoint.
        map_location: Target device. On the Safetensors branch only ``"cpu"``,
            ``"cuda"`` and ``"cuda:N"`` (as a string or ``torch.device``) are
            honoured; anything else falls back to ``"cpu"``.

    Returns:
        The model state dict. For PyTorch checkpoints that are a dict with a
        ``"model_state"`` entry, that entry is returned.

    Raises:
        RuntimeError: A Safetensors file was given but ``safetensors`` is not
            installed.
        TypeError: The PyTorch payload is not a dict.
    """
    path = Path(checkpoint_path)
    if path.suffix.lower() == ".safetensors":
        try:
            from safetensors.torch import load_file
        except ImportError as exc:  # pragma: no cover - dependency message path
            raise RuntimeError(
                "Install safetensors to load NexaMass public weights: pip install safetensors"
            ) from exc
        # Accept torch.device as well as str; previously a torch.device was
        # silently replaced by "cpu".
        if isinstance(map_location, (str, torch.device)):
            device = str(map_location)
        else:
            device = "cpu"
        if device not in {"cpu", "cuda"} and not device.startswith("cuda:"):
            device = "cpu"
        return load_file(str(path), device=device)

    try:
        payload = torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:  # older PyTorch
        # SECURITY: without ``weights_only`` this is an unrestricted pickle
        # load; only use it on checkpoints from a trusted source.
        payload = torch.load(path, map_location=map_location)

    if not isinstance(payload, dict):
        raise TypeError(f"Unsupported NexaMass checkpoint payload type: {type(payload)!r}")
    if "model_state" in payload:
        return payload["model_state"]
    return payload


def load_nexamass_model_state(
    checkpoint_path: str | Path,
    cfg: ModelConfig | None = None,
    map_location: str | torch.device = "cpu",
) -> NexaMassSpectralEncoder:
    """Build an encoder, load checkpoint weights strictly, and return it in eval mode.

    Args:
        checkpoint_path: Checkpoint accepted by :func:`load_nexamass_state_dict`.
        cfg: Model configuration; a default ``ModelConfig()`` is used when omitted.
        map_location: Forwarded to :func:`load_nexamass_state_dict`.
    """
    state_dict = load_nexamass_state_dict(checkpoint_path, map_location=map_location)
    cfg = cfg or ModelConfig()
    model = NexaMassSpectralEncoder(cfg)
    # strict=True: any missing or unexpected key is an error.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model