Add MassSpecGym evaluation adapter and safetensors runtime loader
#1
by Allanatrix - opened
- README.md +5 -0
- config.json +8 -0
- evaluation/massspecgym/README.md +35 -0
- evaluation/massspecgym/figures/nexamass_massspecgym_hit20_position.png +0 -0
- evaluation/massspecgym/results/massspecgym_hitk_summary.json +16 -0
- evaluation/massspecgym/run_massspecgym_retrieval_hf.py +271 -0
- runtime/nexamass_encoder.py +38 -2
README.md
CHANGED

@@ -107,3 +107,8 @@ MS/MS structure inference can affect downstream scientific interpretation. Users
 ## Citation
 
 If you use this model, cite the NexaMass project release and the accompanying technical report when available. Relevant background work includes DreaMS for self-supervised MS/MS representation learning, MassSpecGym for benchmark framing, CSI:FingerID for fingerprint-mediated candidate search, and related spectra-structure retrieval and de novo generation systems such as MIST, MSNovelist, CMSSP, CSU-MS2, MSBERT, Spec2Mol, and MS2Mol.
+
+## MassSpecGym Adapter
+
+A safetensors-compatible MassSpecGym retrieval adapter is included under `evaluation/massspecgym/`. It loads `weights/NexaMass-V3-Struct-model_state.safetensors`, converts MassSpecGym tokenized spectra into the NexaMass batch contract, and reports Hit@k retrieval metrics through MassSpecGym's evaluator. The archived reference run reached test Hit@20 `0.3505` with the frozen projected-dot scorer. This should be read as evidence of transferable top-k signal, not of solved molecular ranking or calibrated confidence.
+
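As a quick local sanity check before invoking the adapter, the public checkpoint can be inspected directly with `safetensors`; a minimal sketch, assuming the weights file has already been downloaded into `weights/`:

```python
# List the tensor names in the public safetensors checkpoint without
# instantiating the model. safe_open reads only the header, so this is cheap.
from safetensors import safe_open

with safe_open("weights/NexaMass-V3-Struct-model_state.safetensors", framework="pt", device="cpu") as f:
    tensor_names = list(f.keys())

print(f"{len(tensor_names)} tensors, e.g. {tensor_names[:3]}")
```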
config.json
CHANGED

@@ -16,6 +16,14 @@
   "architectures": [
     "NexaMassSpectralEncoder"
   ],
+  "evaluation_adapters": {
+    "massspecgym": {
+      "benchmark": "MassSpecGym molecule retrieval",
+      "claim_boundary": "top-k transfer signal; ranking and confidence remain open decision-layer problems",
+      "path": "evaluation/massspecgym/run_massspecgym_retrieval_hf.py",
+      "reference_result": "test Hit@20 0.3505 with frozen V3 projected-dot scorer under Hit@k-only evaluation"
+    }
+  },
   "foundation_checkpoint": "weights/Final_V3-model_state.safetensors",
   "foundation_checkpoint_format": "safetensors",
   "full_training_checkpoints": {
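The new `evaluation_adapters` block is plain metadata, so downstream tooling can discover the adapter path and its claim boundary instead of hard-coding them; a minimal sketch, assuming a clone of this repo as the working directory:

```python
# Read the registered MassSpecGym adapter entry from config.json.
import json
from pathlib import Path

config = json.loads(Path("config.json").read_text(encoding="utf-8"))
adapter = config["evaluation_adapters"]["massspecgym"]
print(adapter["path"])            # script to run
print(adapter["claim_boundary"])  # what the result does and does not claim
```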
evaluation/massspecgym/README.md
ADDED

@@ -0,0 +1,35 @@
+# MassSpecGym Evaluation Adapter
+
+This directory contains the public Hugging Face adapter used to position `NexaMass-V3-Struct` on the MassSpecGym molecule-retrieval task.
+
+The adapter loads the safetensors-only public checkpoint and wraps MassSpecGym's own `RetrievalDataset`, `MassSpecDataModule`, and retrieval evaluator. It is meant for external benchmark positioning, not for claiming that ranking or confidence are solved.
+
+## Install
+
+Use an isolated environment, because MassSpecGym has its own dependency surface:
+
+```bash
+python -m pip install torch safetensors huggingface_hub massspecgym==1.3.1 pytorch-lightning
+```
+
+## Run From A Clone Of This HF Repo
+
+```bash
+python evaluation/massspecgym/run_massspecgym_retrieval_hf.py \
+    --checkpoint weights/NexaMass-V3-Struct-model_state.safetensors \
+    --config config.json \
+    --split test \
+    --scorer projected_dot \
+    --hit-only \
+    --batch-size 32 \
+    --num-workers 25 \
+    --output-json evaluation/massspecgym/results/local_massspecgym_test.json
+```
+
+If the checkpoint is not present locally, the script can download it from this repo through `huggingface_hub`.
+
+## Reported Reference Result
+
+The archived adapter run reached MassSpecGym test Hit@20 `0.3505` under Hit@k-only evaluation using the frozen V3 projected-dot scorer. This places the model above simpler baselines such as Random, DeepSets, Fingerprint FFN, and DeepSets+Fourier, while remaining below specialized retrieval systems such as MIST.
+
+Interpretation: the encoder transfers real top-k structure signal to retrieval, but exact local ranking and calibrated confidence remain separate downstream problems.
evaluation/massspecgym/figures/nexamass_massspecgym_hit20_position.png
ADDED
evaluation/massspecgym/results/massspecgym_hitk_summary.json
ADDED

@@ -0,0 +1,16 @@
+{
+  "benchmark": "MassSpecGym molecule retrieval",
+  "adapter": "evaluation/massspecgym/run_massspecgym_retrieval_hf.py",
+  "checkpoint": "weights/NexaMass-V3-Struct-model_state.safetensors",
+  "scorer": "projected_dot",
+  "evaluation_mode": "test dataloader through validation loop, Hit@k-only",
+  "metrics": {
+    "test_hit_at_1": 0.0627,
+    "test_hit_at_5": 0.1753,
+    "test_hit_at_20": 0.3505,
+    "val_hit_at_1": 0.1162,
+    "val_hit_at_5": 0.1915,
+    "val_hit_at_20": 0.3328
+  },
+  "claim_boundary": "External positioning sanity check; demonstrates top-k transfer signal, not solved ranking or confidence."
+}
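The summary file is stable JSON, so local runs can be diffed against it programmatically; a minimal sketch reading the archived metrics:

```python
# Print the archived Hit@k metrics, e.g. to eyeball the val/test gap at k=20.
import json
from pathlib import Path

summary = json.loads(
    Path("evaluation/massspecgym/results/massspecgym_hitk_summary.json").read_text(encoding="utf-8")
)
for name, value in sorted(summary["metrics"].items()):
    print(f"{name}: {value:.4f}")
```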
evaluation/massspecgym/run_massspecgym_retrieval_hf.py
ADDED

@@ -0,0 +1,271 @@
+#!/usr/bin/env python3
+"""Evaluate NexaMass-V3-Struct on MassSpecGym retrieval.
+
+This is the Hugging Face release adapter. It loads the public safetensors
+checkpoint from this repository and wraps MassSpecGym's official retrieval data
+module/evaluator. The adapter is for external benchmark positioning, not for
+claiming that ranking or confidence are solved.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+
+REPO_ROOT = Path(__file__).resolve().parents[2]
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from runtime.nexamass_encoder import ModelConfig, NexaMassSpectralEncoder, load_nexamass_model_state  # noqa: E402
+
+
+def _require_massspecgym() -> tuple[Any, Any, Any, Any, Any, Any]:
+    try:
+        from massspecgym.data import MassSpecDataModule, RetrievalDataset
+        from massspecgym.data.transforms import MolFingerprinter, SpecTokenizer
+        from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
+        from pytorch_lightning import Trainer
+    except ImportError as exc:
+        raise SystemExit(
+            "MassSpecGym dependencies are missing. Install in an isolated env with: "
+            "python -m pip install massspecgym==1.3.1 pytorch-lightning safetensors huggingface_hub"
+        ) from exc
+    return Trainer, MassSpecDataModule, RetrievalDataset, MolFingerprinter, SpecTokenizer, RetrievalMassSpecGymModel
+
+
+def _cfg_from_json(path: Path) -> ModelConfig:
+    if not path.exists():
+        return ModelConfig()
+    payload = json.loads(path.read_text(encoding="utf-8"))
+    arch = payload.get("architecture_config", payload)
+    allowed = ModelConfig.__dataclass_fields__.keys()
+    return ModelConfig(**{key: arch[key] for key in allowed if key in arch})
+
+
+def _resolve_checkpoint(path: Path, repo_id: str, filename: str) -> Path:
+    if path.exists():
+        return path
+    try:
+        from huggingface_hub import hf_hub_download
+    except ImportError as exc:
+        raise SystemExit("Checkpoint was not found locally and huggingface_hub is not installed.") from exc
+    return Path(hf_hub_download(repo_id=repo_id, repo_type="model", filename=filename))
+
+
+def _parse_limit_batches(raw: str) -> int | float:
+    value = raw.strip()
+    if value.isdigit():
+        return int(value)
+    return float(value)
+
+
+def _batch_from_massspecgym_spec(
+    spec: torch.Tensor,
+    cfg: ModelConfig,
+    device: torch.device,
+    *,
+    precursor_mz: torch.Tensor | None = None,
+) -> dict[str, torch.Tensor]:
+    """Convert MassSpecGym tokenized spectra into NexaMass' encoder batch contract."""
+
+    if spec.ndim != 3 or spec.shape[-1] < 2:
+        raise ValueError(f"Expected MassSpecGym spec shape [batch, peaks, >=2], got {tuple(spec.shape)}")
+
+    spec = spec.to(device=device, dtype=torch.float32)
+    mzs_raw = spec[..., 0].clamp(min=0.0)
+    ints_raw = spec[..., 1].clamp(min=0.0)
+    batch_size, peak_count = mzs_raw.shape
+    if peak_count > cfg.max_peaks:
+        mzs_raw = mzs_raw[:, : cfg.max_peaks]
+        ints_raw = ints_raw[:, : cfg.max_peaks]
+        peak_count = cfg.max_peaks
+
+    mask = (mzs_raw > 0) & torch.isfinite(mzs_raw) & torch.isfinite(ints_raw)
+    max_intensity = ints_raw.masked_fill(~mask, 0.0).amax(dim=1, keepdim=True).clamp(min=1e-6)
+    mzs_norm = (mzs_raw / cfg.mz_max).clamp(0.0, 1.5)
+    ints_norm = (ints_raw / max_intensity).masked_fill(~mask, 0.0)
+    if precursor_mz is not None:
+        precursor_raw = precursor_mz.to(device=device, dtype=torch.float32).view(-1).clamp(min=1e-6)
+        if precursor_raw.numel() != batch_size:
+            raise ValueError(f"Expected {batch_size} precursor_mz values, got {precursor_raw.numel()}")
+    else:
+        precursor_raw = mzs_raw.masked_fill(~mask, 0.0).amax(dim=1).clamp(min=1e-6)
+    mz_to_precursor = (mzs_raw / precursor_raw[:, None]).clamp(0.0, 2.0).masked_fill(~mask, 0.0)
+    ranks = torch.linspace(0.0, 1.0, peak_count, device=device, dtype=torch.float32)[None, :].expand(batch_size, -1)
+
+    if peak_count < cfg.max_peaks:
+        pad_width = cfg.max_peaks - peak_count
+
+        def pad(values: torch.Tensor, value: float = 0.0) -> torch.Tensor:
+            return F.pad(values, (0, pad_width), value=value)
+
+        mzs_norm = pad(mzs_norm)
+        ints_norm = pad(ints_norm)
+        mz_to_precursor = pad(mz_to_precursor)
+        ranks = pad(ranks)
+        mask = F.pad(mask, (0, pad_width), value=False)
+
+    observed_peak_count = mask.sum(dim=1).to(dtype=torch.float32).clamp(min=1.0)
+    return {
+        "mzs": mzs_norm,
+        "ints": ints_norm,
+        "mz_to_precursor": mz_to_precursor,
+        "peak_rank": ranks,
+        "mask": mask.to(dtype=torch.bool),
+        "precursor_mz": (precursor_raw / cfg.mz_max).clamp(max=2.0),
+        "charge": torch.zeros(batch_size, device=device, dtype=torch.float32),
+        "collision_energy": torch.zeros(batch_size, device=device, dtype=torch.float32),
+        "adduct_id": torch.zeros(batch_size, device=device, dtype=torch.long),
+        "instrument_id": torch.zeros(batch_size, device=device, dtype=torch.long),
+        "peak_count": observed_peak_count / float(cfg.max_peaks),
+    }
+
+
+def _scores_for_batch(
+    *,
+    scorer: str,
+    model: NexaMassSpectralEncoder,
+    cfg: ModelConfig,
+    spec: torch.Tensor,
+    candidates: torch.Tensor,
+    batch_ptr: torch.Tensor,
+    precursor_mz: torch.Tensor | None,
+    device: torch.device,
+) -> torch.Tensor:
+    batch = _batch_from_massspecgym_spec(spec, cfg, device, precursor_mz=precursor_mz)
+    candidates = candidates.to(device=device, dtype=torch.float32)
+    with torch.no_grad():
+        _embedding, _raw_projected, logits, query_raw = model.forward_with_heads(batch)
+    pred_probs = torch.sigmoid(logits)
+    if scorer == "predicted_fingerprint":
+        query_repeated = F.normalize(pred_probs, dim=-1).repeat_interleave(batch_ptr.to(device), dim=0)
+        return F.cosine_similarity(query_repeated, F.normalize(candidates, dim=-1), dim=-1).detach()
+    if scorer == "projected_dot":
+        query_repeated = F.normalize(query_raw, dim=-1).repeat_interleave(batch_ptr.to(device), dim=0)
+        target_projection = F.normalize(model.project_structure_targets(candidates), dim=-1)
+        return (query_repeated * target_projection).sum(dim=-1).detach()
+    raise ValueError(f"Unsupported scorer: {scorer}")
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--repo-id", default="AethronPhantom/NexaMass-V3-Struct")
+    parser.add_argument("--checkpoint", type=Path, default=REPO_ROOT / "weights/NexaMass-V3-Struct-model_state.safetensors")
+    parser.add_argument("--checkpoint-filename", default="weights/NexaMass-V3-Struct-model_state.safetensors")
+    parser.add_argument("--config", type=Path, default=REPO_ROOT / "config.json")
+    parser.add_argument("--scorer", choices=["projected_dot", "predicted_fingerprint"], default="projected_dot")
+    parser.add_argument("--split", choices=["val", "test"], default="test")
+    parser.add_argument("--batch-size", type=int, default=32)
+    parser.add_argument("--num-workers", type=int, default=8)
+    parser.add_argument("--n-peaks", type=int, default=256)
+    parser.add_argument("--accelerator", default="gpu")
+    parser.add_argument("--devices", default="1")
+    parser.add_argument("--limit-batches", default="1.0")
+    parser.add_argument("--hit-only", action="store_true", help="Run the test dataloader through the validation loop for Hit@k-only scoring.")
+    parser.add_argument("--inspect-batch-only", action="store_true")
+    parser.add_argument("--output-json", type=Path)
+    args = parser.parse_args()
+
+    Trainer, MassSpecDataModule, RetrievalDataset, MolFingerprinter, SpecTokenizer, RetrievalMassSpecGymModel = (
+        _require_massspecgym()
+    )
+
+    torch.set_float32_matmul_precision("high")
+    limit_batches = _parse_limit_batches(args.limit_batches)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    cfg = _cfg_from_json(args.config)
+    checkpoint = _resolve_checkpoint(args.checkpoint.expanduser(), args.repo_id, args.checkpoint_filename)
+    v3_model = load_nexamass_model_state(str(checkpoint), cfg=cfg, map_location="cpu")
+    v3_model.to(device)
+    v3_model.eval()
+
+    class NexaMassRetrievalModel(RetrievalMassSpecGymModel):  # type: ignore[misc, valid-type]
+        def __init__(self) -> None:
+            super().__init__()
+            self._inspected = False
+
+        def forward(self, spec: torch.Tensor) -> torch.Tensor:
+            batch = _batch_from_massspecgym_spec(spec, cfg, device)
+            with torch.no_grad():
+                _embedding, _raw, logits, query_raw = v3_model.forward_with_heads(batch)
+            return query_raw if args.scorer == "projected_dot" else torch.sigmoid(logits)
+
+        def step(self, batch: dict[str, Any], stage: Any) -> dict[str, torch.Tensor]:
+            if args.inspect_batch_only and not self._inspected:
+                print(
+                    json.dumps(
+                        {
+                            "batch_keys": sorted(batch.keys()),
+                            "spec_shape": list(batch["spec"].shape),
+                            "candidates_mol_shape": list(batch["candidates_mol"].shape),
+                            "batch_ptr_head": batch["batch_ptr"].detach().cpu().tolist()[:8],
+                        },
+                        indent=2,
+                    ),
+                    flush=True,
+                )
+                self._inspected = True
+            scores = _scores_for_batch(
+                scorer=args.scorer,
+                model=v3_model,
+                cfg=cfg,
+                spec=batch["spec"],
+                candidates=batch["candidates_mol"],
+                batch_ptr=batch["batch_ptr"],
+                precursor_mz=batch.get("precursor_mz"),
+                device=device,
+            )
+            return {"loss": torch.zeros((), device=scores.device), "scores": scores}
+
+    dataset = RetrievalDataset(
+        spec_transform=SpecTokenizer(n_peaks=args.n_peaks),
+        mol_transform=MolFingerprinter(fp_size=cfg.fingerprint_dim),
+    )
+    data_module = MassSpecDataModule(dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+    data_module.prepare_data()
+    data_module.setup(None if args.split == "val" else "test")
+    model = NexaMassRetrievalModel()
+    trainer = Trainer(
+        accelerator=args.accelerator,
+        devices=args.devices,
+        logger=False,
+        enable_checkpointing=False,
+        limit_val_batches=limit_batches if args.split == "val" or args.hit_only else 1.0,
+        limit_test_batches=limit_batches if args.split == "test" else 1.0,
+    )
+    if args.split == "val":
+        metrics = trainer.validate(model, datamodule=data_module)
+    elif args.hit_only:
+        metrics = trainer.validate(model, dataloaders=data_module.test_dataloader())
+    else:
+        metrics = trainer.test(model, datamodule=data_module)
+
+    payload = {
+        "checkpoint": str(checkpoint),
+        "scorer": args.scorer,
+        "split": args.split,
+        "metrics": metrics,
+        "massspecgym_adapter": {
+            "repo_id": args.repo_id,
+            "n_peaks": args.n_peaks,
+            "fingerprint_dim": cfg.fingerprint_dim,
+            "limit_batches": limit_batches,
+            "hit_only": args.hit_only,
+            "metadata_defaults": "charge/collision/adduct/instrument set to zero when absent from MassSpecGym batch",
+        },
+    }
+    print(json.dumps(payload, indent=2), flush=True)
+    if args.output_json:
+        args.output_json.parent.mkdir(parents=True, exist_ok=True)
+        args.output_json.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
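The batch-contract conversion can be smoke-tested without MassSpecGym installed, since `_batch_from_massspecgym_spec` depends only on torch and `ModelConfig`; a minimal sketch, assuming the repo root is the working directory (so both modules are importable) and that `ModelConfig` exposes the `max_peaks` field the script relies on:

```python
# Smoke test: feed dummy tokenized spectra through the batch-contract converter
# and check the padded shapes and mask dtype.
import torch

from runtime.nexamass_encoder import ModelConfig
from evaluation.massspecgym.run_massspecgym_retrieval_hf import _batch_from_massspecgym_spec

cfg = ModelConfig()
# Two dummy spectra with 64 peaks of (m/z, intensity); m/z scaled into [0, 500).
spec = torch.rand(2, 64, 2) * torch.tensor([500.0, 1.0])
batch = _batch_from_massspecgym_spec(spec, cfg, torch.device("cpu"))

assert batch["mzs"].shape == (2, cfg.max_peaks)  # padded or truncated to max_peaks
assert batch["mask"].dtype == torch.bool
print({key: tuple(value.shape) for key, value in batch.items()})
```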
runtime/nexamass_encoder.py
CHANGED

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from pathlib import Path
 
 import torch
 import torch.nn as nn

@@ -126,14 +127,49 @@
         return F.normalize(self.target_projection(targets), dim=-1)
 
 
+
+
+def load_nexamass_state_dict(
+    checkpoint_path: str,
+    map_location: str | torch.device = "cpu",
+) -> dict[str, torch.Tensor]:
+    """Load public NexaMass model-state weights from Safetensors or PyTorch.
+
+    Hugging Face public release weights are Safetensors-only. The PyTorch branch is
+    kept for internal/object-storage compatibility with full training checkpoints
+    and model-state fallbacks.
+    """
+
+    path = Path(checkpoint_path)
+    if path.suffix == ".safetensors":
+        try:
+            from safetensors.torch import load_file
+        except ImportError as exc:  # pragma: no cover - dependency message path
+            raise RuntimeError("Install safetensors to load NexaMass public weights: pip install safetensors") from exc
+        device = str(map_location) if isinstance(map_location, str) else "cpu"
+        if device not in {"cpu", "cuda"} and not device.startswith("cuda:"):
+            device = "cpu"
+        return load_file(str(path), device=device)
+
+    try:
+        payload = torch.load(path, map_location=map_location, weights_only=True)
+    except TypeError:  # older PyTorch
+        payload = torch.load(path, map_location=map_location)
+    if isinstance(payload, dict) and "model_state" in payload:
+        return payload["model_state"]
+    if isinstance(payload, dict):
+        return payload
+    raise TypeError(f"Unsupported NexaMass checkpoint payload type: {type(payload)!r}")
+
+
 def load_nexamass_model_state(
     checkpoint_path: str,
     cfg: ModelConfig | None = None,
     map_location: str | torch.device = "cpu",
 ) -> NexaMassSpectralEncoder:
-
+    state_dict = load_nexamass_state_dict(checkpoint_path, map_location=map_location)
     cfg = cfg or ModelConfig()
     model = NexaMassSpectralEncoder(cfg)
-    model.load_state_dict(
+    model.load_state_dict(state_dict, strict=True)
     model.eval()
     return model
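A minimal usage sketch for the new loader pair, assuming the default `ModelConfig` matches the public checkpoint's architecture (pass an explicit `cfg` otherwise, as the adapter does via `_cfg_from_json`):

```python
# Inspect the raw state dict, then build the encoder in eval mode from it.
from runtime.nexamass_encoder import load_nexamass_model_state, load_nexamass_state_dict

state = load_nexamass_state_dict("weights/NexaMass-V3-Struct-model_state.safetensors")
print(f"{len(state)} tensors in checkpoint")

model = load_nexamass_model_state("weights/NexaMass-V3-Struct-model_state.safetensors")
print(sum(p.numel() for p in model.parameters()), "parameters")
```

Because `load_state_dict` is called with `strict=True`, any architecture/checkpoint mismatch fails loudly at load time rather than silently at inference.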