from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass(frozen=True)
class ModelConfig:
    max_peaks: int = 256  # peak-axis length of padded spectra
    mz_max: float = 2000.0
    collision_max: float = 200.0
    model_dim: int = 384  # transformer hidden size
    layers: int = 6  # encoder depth
    heads: int = 8  # attention heads per layer
    dropout: float = 0.1
    projection_dim: int = 192  # output size of the embedding projection head
    fingerprint_dim: int = 2048  # output size of the structure-fingerprint head
    target_projection_dim: int = 256  # shared space for structure queries and targets
    retrieval_mlp_hidden_dim: int = 512  # hidden width of the retrieval/rerank MLPs
    metadata_scale: float = 0.02  # scale applied to adduct/instrument embeddings
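
    # ModelConfig is frozen, so variants are created by passing overrides at
    # construction time, e.g. (illustrative sizing, not an official preset):
    #   cfg = ModelConfig(model_dim=256, layers=4, heads=4)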


class NexaMassSpectralEncoder(nn.Module):
| """Encoder-only MS/MS transformer used by NexaMass-V3-Struct. |
| |
| Expected batch keys: |
| - mzs, ints, mz_to_precursor, peak_rank: float tensors [batch, max_peaks] |
| - precursor_mz, charge, collision_energy, peak_count: float tensors [batch] |
| - adduct_id, instrument_id: long tensors [batch] |
| - mask: bool tensor [batch, max_peaks], True for valid peaks |
| """ |
|
|
| def __init__(self, cfg: ModelConfig) -> None: |
        super().__init__()
        self.cfg = cfg
        self.adduct_embedding = nn.Embedding(64, cfg.model_dim)
        self.instrument_embedding = nn.Embedding(64, cfg.model_dim)
        # Eight per-peak input features are assembled in encode().
        self.input_projection = nn.Linear(8, cfg.model_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.model_dim,
            nhead=cfg.heads,
            dim_feedforward=cfg.model_dim * 4,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        try:
            self.encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=cfg.layers, enable_nested_tensor=False
            )
        except TypeError:
            # Older PyTorch releases do not accept enable_nested_tensor.
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.layers)
        self.final_norm = nn.LayerNorm(cfg.model_dim)
        # Contrastive embedding head.
        self.projection = nn.Sequential(
            nn.Linear(cfg.model_dim, cfg.model_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.model_dim, cfg.projection_dim),
        )
        # Structure-fingerprint prediction head.
        self.structure_head = nn.Sequential(
            nn.Linear(cfg.model_dim, cfg.model_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.model_dim, cfg.fingerprint_dim),
        )
        # Projects the pooled spectrum into the structure-target space.
        self.structure_query = nn.Sequential(
            nn.Linear(cfg.model_dim, cfg.model_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.model_dim, cfg.target_projection_dim),
        )
        # Projects candidate fingerprints into the same space.
        self.target_projection = nn.Sequential(
            nn.Linear(cfg.fingerprint_dim, cfg.model_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.model_dim, cfg.target_projection_dim),
        )
        self.retrieval_bilinear = nn.Linear(cfg.target_projection_dim, cfg.target_projection_dim, bias=False)
        self.retrieval_pair_mlp = nn.Sequential(
            nn.Linear(cfg.target_projection_dim * 4, cfg.retrieval_mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.retrieval_mlp_hidden_dim, 1),
        )
        self.local_rerank_mlp = nn.Sequential(
            nn.Linear(cfg.target_projection_dim * 4 + 1, cfg.retrieval_mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.retrieval_mlp_hidden_dim, 1),
        )

    def encode(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        # Per-peak features plus spectrum-level metadata broadcast across the
        # peak axis: [batch, max_peaks, 8].
        features = torch.stack(
            [
                batch["mzs"],
                batch["ints"],
                batch["mz_to_precursor"],
                batch["peak_rank"],
                batch["precursor_mz"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["charge"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["collision_energy"].unsqueeze(-1).expand_as(batch["mzs"]),
                batch["peak_count"].unsqueeze(-1).expand_as(batch["mzs"]),
            ],
            dim=-1,
        )
        hidden = self.input_projection(features)
        # Metadata embeddings are added to every peak token, scaled down so
        # they do not dominate the peak signal.
        hidden = hidden + self.adduct_embedding(batch["adduct_id"])[:, None, :] * self.cfg.metadata_scale
        hidden = hidden + self.instrument_embedding(batch["instrument_id"])[:, None, :] * self.cfg.metadata_scale
        # PyTorch expects True = padding, while batch["mask"] flags valid peaks.
        encoded = self.encoder(hidden, src_key_padding_mask=~batch["mask"])
        encoded = self.final_norm(encoded)
        # Masked mean pool over valid peaks; the clamp guards empty spectra.
        mask = batch["mask"].unsqueeze(-1)
        return (encoded * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    def forward_with_heads(
        self, batch: dict[str, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (normalized embedding, raw embedding, fingerprint logits, structure query)."""
        pooled = self.encode(batch)
        raw_projected = self.projection(pooled)
        structure_logits = self.structure_head(pooled)
        structure_query_raw = self.structure_query(pooled)
        return F.normalize(raw_projected, dim=-1), raw_projected, structure_logits, structure_query_raw

    def project_structure_targets(self, targets: torch.Tensor) -> torch.Tensor:
        """Map fingerprint targets into the shared retrieval space, L2-normalized."""
        return F.normalize(self.target_projection(targets), dim=-1)
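

def _random_batch(cfg: ModelConfig, batch_size: int = 2) -> dict[str, torch.Tensor]:
    # Illustrative helper, not part of the public API: a random batch matching
    # the schema documented on NexaMassSpectralEncoder. Values are noise, not
    # real spectra; it exists only to exercise tensor shapes end to end.
    return {
        "mzs": torch.rand(batch_size, cfg.max_peaks),
        "ints": torch.rand(batch_size, cfg.max_peaks),
        "mz_to_precursor": torch.rand(batch_size, cfg.max_peaks),
        "peak_rank": torch.rand(batch_size, cfg.max_peaks),
        "precursor_mz": torch.rand(batch_size),
        "charge": torch.ones(batch_size),
        "collision_energy": torch.rand(batch_size),
        "peak_count": torch.full((batch_size,), float(cfg.max_peaks)),
        "adduct_id": torch.zeros(batch_size, dtype=torch.long),
        "instrument_id": torch.zeros(batch_size, dtype=torch.long),
        "mask": torch.ones(batch_size, cfg.max_peaks, dtype=torch.bool),
    }


def _demo_forward_with_heads() -> None:
    # Sketch of the training-time forward pass on random data.
    cfg = ModelConfig()
    model = NexaMassSpectralEncoder(cfg).eval()
    with torch.no_grad():
        embedding, raw, fp_logits, query = model.forward_with_heads(_random_batch(cfg))
        targets = model.project_structure_targets(torch.rand(5, cfg.fingerprint_dim))
    # embedding/raw: [2, projection_dim]; fp_logits: [2, fingerprint_dim];
    # query: [2, target_projection_dim]; targets: [5, target_projection_dim].
    print(embedding.shape, raw.shape, fp_logits.shape, query.shape, targets.shape)


def _demo_retrieval_scoring() -> None:
    # Speculative sketch: this module defines the retrieval heads but not the
    # pairing logic. The [q, t, q * t, |q - t|] concatenation below matches the
    # target_projection_dim * 4 input width of retrieval_pair_mlp and is a
    # common pairing scheme, but it is an assumption, not the confirmed
    # NexaMass feature order.
    cfg = ModelConfig()
    model = NexaMassSpectralEncoder(cfg).eval()
    q = F.normalize(torch.rand(3, cfg.target_projection_dim), dim=-1)  # spectrum queries
    t = F.normalize(torch.rand(3, cfg.target_projection_dim), dim=-1)  # candidate targets
    pair = torch.cat([q, t, q * t, (q - t).abs()], dim=-1)  # [3, target_projection_dim * 4]
    with torch.no_grad():
        score = model.retrieval_pair_mlp(pair)  # [3, 1] pairwise scores
    print(score.shape)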


def load_nexamass_state_dict(
    checkpoint_path: str,
    map_location: str | torch.device = "cpu",
) -> dict[str, torch.Tensor]:
| """Load public NexaMass model-state weights from Safetensors or PyTorch. |
| |
| Hugging Face public release weights are Safetensors-only. The PyTorch branch is |
| kept for internal/object-storage compatibility with full training checkpoints |
| and model-state fallbacks. |
| """ |
|
|
| path = Path(checkpoint_path) |
| if path.suffix == ".safetensors": |
| try: |
| from safetensors.torch import load_file |
| except ImportError as exc: |
| raise RuntimeError("Install safetensors to load NexaMass public weights: pip install safetensors") from exc |
| device = str(map_location) if isinstance(map_location, str) else "cpu" |
| if device not in {"cpu", "cuda"} and not device.startswith("cuda:"): |
| device = "cpu" |
| return load_file(str(path), device=device) |

    # torch.load's weights_only keyword is missing on older PyTorch releases;
    # fall back to the legacy call if it is rejected.
    try:
        payload = torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        payload = torch.load(path, map_location=map_location)
    # Full training checkpoints nest the weights under "model_state"; bare
    # state dicts are returned unchanged.
    if isinstance(payload, dict) and "model_state" in payload:
        return payload["model_state"]
    if isinstance(payload, dict):
        return payload
    raise TypeError(f"Unsupported NexaMass checkpoint payload type: {type(payload)!r}")


def load_nexamass_model_state(
    checkpoint_path: str,
    cfg: ModelConfig | None = None,
    map_location: str | torch.device = "cpu",
) -> NexaMassSpectralEncoder:
    """Instantiate the encoder, load weights strictly, and return it in eval mode."""
    state_dict = load_nexamass_state_dict(checkpoint_path, map_location=map_location)
    cfg = cfg or ModelConfig()
    model = NexaMassSpectralEncoder(cfg)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
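

def _demo_load_and_encode() -> None:
    # Illustrative end-to-end sketch: the checkpoint path is a placeholder,
    # not a published filename. Embeds one random batch (see _random_batch).
    model = load_nexamass_model_state("weights/nexamass.safetensors")
    with torch.no_grad():
        embedding, _, _, _ = model.forward_with_heads(_random_batch(model.cfg))
    print(embedding.shape)  # [2, projection_dim]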