Add MassSpecGym evaluation adapter and safetensors runtime loader

#1
by Allanatrix - opened
README.md CHANGED
@@ -107,3 +107,8 @@ MS/MS structure inference can affect downstream scientific interpretation. Users
  ## Citation

  If you use this model, cite the NexaMass project release and the accompanying technical report when available. Relevant background work includes DreaMS for self-supervised MS/MS representation learning, MassSpecGym for benchmark framing, CSI:FingerID for fingerprint-mediated candidate search, and related spectra-structure retrieval and de novo generation systems such as MIST, MSNovelist, CMSSP, CSU-MS2, MSBERT, Spec2Mol, and MS2Mol.
+
+ ## MassSpecGym Adapter
+
+ A safetensors-compatible MassSpecGym retrieval adapter is included under `evaluation/massspecgym/`. It loads `weights/NexaMass-V3-Struct-model_state.safetensors`, converts MassSpecGym tokenized spectra into the NexaMass batch contract, and reports Hit@k retrieval metrics through MassSpecGym's evaluator. The archived reference run reached test Hit@20 `0.3505` with the frozen projected-dot scorer. This should be read as evidence of transferable top-k signal, not of solved molecular ranking or calibrated confidence.
+
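+ For quick programmatic use, a minimal loading sketch (assuming a local clone with the repo root on `sys.path` so the `runtime/` package resolves):
+
+ ```python
+ from runtime.nexamass_encoder import load_nexamass_model_state
+
+ # Loads the public safetensors weights and returns the encoder in eval mode.
+ model = load_nexamass_model_state("weights/NexaMass-V3-Struct-model_state.safetensors")
+ ```
+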
config.json CHANGED
@@ -16,6 +16,14 @@
  "architectures": [
    "NexaMassSpectralEncoder"
  ],
+ "evaluation_adapters": {
+   "massspecgym": {
+     "benchmark": "MassSpecGym molecule retrieval",
+     "claim_boundary": "top-k transfer signal; ranking and confidence remain open decision-layer problems",
+     "path": "evaluation/massspecgym/run_massspecgym_retrieval_hf.py",
+     "reference_result": "test Hit@20 0.3505 with frozen V3 projected-dot scorer under Hit@k-only evaluation"
+   }
+ },
  "foundation_checkpoint": "weights/Final_V3-model_state.safetensors",
  "foundation_checkpoint_format": "safetensors",
  "full_training_checkpoints": {
evaluation/massspecgym/README.md ADDED
@@ -0,0 +1,35 @@
+ # MassSpecGym Evaluation Adapter
+
+ This directory contains the public Hugging Face adapter used to position `NexaMass-V3-Struct` on the MassSpecGym molecule-retrieval task.
+
+ The adapter loads the safetensors-only public checkpoint and wraps MassSpecGym's own `RetrievalDataset`, `MassSpecDataModule`, and retrieval evaluator. It is meant for external benchmark positioning, not for claiming that ranking or confidence are solved.
+
+ ## Install
+
+ Use an isolated environment because MassSpecGym has its own dependency surface:
+
+ ```bash
+ python -m pip install torch safetensors huggingface_hub massspecgym==1.3.1 pytorch-lightning
+ ```
+
+ ## Run From A Clone Of This HF Repo
+
+ ```bash
+ python evaluation/massspecgym/run_massspecgym_retrieval_hf.py \
+     --checkpoint weights/NexaMass-V3-Struct-model_state.safetensors \
+     --config config.json \
+     --split test \
+     --scorer projected_dot \
+     --hit-only \
+     --batch-size 32 \
+     --num-workers 25 \
+     --output-json evaluation/massspecgym/results/local_massspecgym_test.json
+ ```
+
+ If the checkpoint is not present locally, the script can download it from this repo through `huggingface_hub`.
+
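+ As a minimal sketch of that manual download path (this mirrors the `hf_hub_download` call the script itself makes; the file lands in your local Hugging Face cache):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Fetch the public checkpoint from the model repo, then pass the
+ # returned path to --checkpoint.
+ path = hf_hub_download(
+     repo_id="AethronPhantom/NexaMass-V3-Struct",
+     repo_type="model",
+     filename="weights/NexaMass-V3-Struct-model_state.safetensors",
+ )
+ print(path)
+ ```
+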
+ ## Reported Reference Result
+
+ The archived adapter run reached MassSpecGym test Hit@20 `0.3505` under Hit@k-only evaluation using the frozen V3 projected-dot scorer. This placed the model above simpler baselines such as Random, DeepSets, Fingerprint FFN, and DeepSets+Fourier, while remaining below specialized retrieval systems such as MIST.
+
+ Interpretation: the encoder transfers real top-k structure signal to retrieval, but exact local ranking and calibrated confidence remain separate downstream problems.
evaluation/massspecgym/figures/nexamass_massspecgym_hit20_position.png ADDED
evaluation/massspecgym/results/massspecgym_hitk_summary.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "benchmark": "MassSpecGym molecule retrieval",
+   "adapter": "evaluation/massspecgym/run_massspecgym_retrieval_hf.py",
+   "checkpoint": "weights/NexaMass-V3-Struct-model_state.safetensors",
+   "scorer": "projected_dot",
+   "evaluation_mode": "test dataloader through validation loop, Hit@k-only",
+   "metrics": {
+     "test_hit_at_1": 0.0627,
+     "test_hit_at_5": 0.1753,
+     "test_hit_at_20": 0.3505,
+     "val_hit_at_1": 0.1162,
+     "val_hit_at_5": 0.1915,
+     "val_hit_at_20": 0.3328
+   },
+   "claim_boundary": "External positioning sanity check; demonstrates top-k transfer signal, not solved ranking or confidence."
+ }
evaluation/massspecgym/run_massspecgym_retrieval_hf.py ADDED
@@ -0,0 +1,271 @@
+ #!/usr/bin/env python3
+ """Evaluate NexaMass-V3-Struct on MassSpecGym retrieval.
+
+ This is the Hugging Face release adapter. It loads the public safetensors
+ checkpoint from this repository and wraps MassSpecGym's official retrieval data
+ module/evaluator. The adapter is for external benchmark positioning, not for
+ claiming that ranking or confidence are solved.
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import sys
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+
+ REPO_ROOT = Path(__file__).resolve().parents[2]
+ if str(REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(REPO_ROOT))
+
+ from runtime.nexamass_encoder import ModelConfig, NexaMassSpectralEncoder, load_nexamass_model_state  # noqa: E402
+
+
+ def _require_massspecgym() -> tuple[Any, Any, Any, Any, Any, Any]:
+     try:
+         from massspecgym.data import MassSpecDataModule, RetrievalDataset
+         from massspecgym.data.transforms import MolFingerprinter, SpecTokenizer
+         from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
+         from pytorch_lightning import Trainer
+     except ImportError as exc:
+         raise SystemExit(
+             "MassSpecGym dependencies are missing. Install in an isolated env with: "
+             "python -m pip install massspecgym==1.3.1 pytorch-lightning safetensors huggingface_hub"
+         ) from exc
+     return Trainer, MassSpecDataModule, RetrievalDataset, MolFingerprinter, SpecTokenizer, RetrievalMassSpecGymModel
+
+
+ def _cfg_from_json(path: Path) -> ModelConfig:
+     if not path.exists():
+         return ModelConfig()
+     payload = json.loads(path.read_text(encoding="utf-8"))
+     arch = payload.get("architecture_config", payload)
+     allowed = ModelConfig.__dataclass_fields__.keys()
+     return ModelConfig(**{key: arch[key] for key in allowed if key in arch})
+
+
+ def _resolve_checkpoint(path: Path, repo_id: str, filename: str) -> Path:
+     if path.exists():
+         return path
+     try:
+         from huggingface_hub import hf_hub_download
+     except ImportError as exc:
+         raise SystemExit("Checkpoint was not found locally and huggingface_hub is not installed.") from exc
+     return Path(hf_hub_download(repo_id=repo_id, repo_type="model", filename=filename))
+
+
+ def _parse_limit_batches(raw: str) -> int | float:
+     value = raw.strip()
+     if value.isdigit():
+         return int(value)
+     return float(value)
+
+
+ def _batch_from_massspecgym_spec(
+     spec: torch.Tensor,
+     cfg: ModelConfig,
+     device: torch.device,
+     *,
+     precursor_mz: torch.Tensor | None = None,
+ ) -> dict[str, torch.Tensor]:
+     """Convert MassSpecGym tokenized spectra into NexaMass' encoder batch contract."""
+
+     if spec.ndim != 3 or spec.shape[-1] < 2:
+         raise ValueError(f"Expected MassSpecGym spec shape [batch, peaks, >=2], got {tuple(spec.shape)}")
+
+     spec = spec.to(device=device, dtype=torch.float32)
+     mzs_raw = spec[..., 0].clamp(min=0.0)
+     ints_raw = spec[..., 1].clamp(min=0.0)
+     batch_size, peak_count = mzs_raw.shape
+     if peak_count > cfg.max_peaks:
+         mzs_raw = mzs_raw[:, : cfg.max_peaks]
+         ints_raw = ints_raw[:, : cfg.max_peaks]
+         peak_count = cfg.max_peaks
+
+     mask = (mzs_raw > 0) & torch.isfinite(mzs_raw) & torch.isfinite(ints_raw)
+     max_intensity = ints_raw.masked_fill(~mask, 0.0).amax(dim=1, keepdim=True).clamp(min=1e-6)
+     mzs_norm = (mzs_raw / cfg.mz_max).clamp(0.0, 1.5)
+     ints_norm = (ints_raw / max_intensity).masked_fill(~mask, 0.0)
+     if precursor_mz is not None:
+         precursor_raw = precursor_mz.to(device=device, dtype=torch.float32).view(-1).clamp(min=1e-6)
+         if precursor_raw.numel() != batch_size:
+             raise ValueError(f"Expected {batch_size} precursor_mz values, got {precursor_raw.numel()}")
+     else:
+         precursor_raw = mzs_raw.masked_fill(~mask, 0.0).amax(dim=1).clamp(min=1e-6)
+     mz_to_precursor = (mzs_raw / precursor_raw[:, None]).clamp(0.0, 2.0).masked_fill(~mask, 0.0)
+     ranks = torch.linspace(0.0, 1.0, peak_count, device=device, dtype=torch.float32)[None, :].expand(batch_size, -1)
+
+     if peak_count < cfg.max_peaks:
+         pad_width = cfg.max_peaks - peak_count
+
+         def pad(values: torch.Tensor, value: float = 0.0) -> torch.Tensor:
+             return F.pad(values, (0, pad_width), value=value)
+
+         mzs_norm = pad(mzs_norm)
+         ints_norm = pad(ints_norm)
+         mz_to_precursor = pad(mz_to_precursor)
+         ranks = pad(ranks)
+         mask = F.pad(mask, (0, pad_width), value=False)
+
+     observed_peak_count = mask.sum(dim=1).to(dtype=torch.float32).clamp(min=1.0)
+     return {
+         "mzs": mzs_norm,
+         "ints": ints_norm,
+         "mz_to_precursor": mz_to_precursor,
+         "peak_rank": ranks,
+         "mask": mask.to(dtype=torch.bool),
+         "precursor_mz": (precursor_raw / cfg.mz_max).clamp(max=2.0),
+         "charge": torch.zeros(batch_size, device=device, dtype=torch.float32),
+         "collision_energy": torch.zeros(batch_size, device=device, dtype=torch.float32),
+         "adduct_id": torch.zeros(batch_size, device=device, dtype=torch.long),
+         "instrument_id": torch.zeros(batch_size, device=device, dtype=torch.long),
+         "peak_count": observed_peak_count / float(cfg.max_peaks),
+     }
+
+
+ def _scores_for_batch(
+     *,
+     scorer: str,
+     model: NexaMassSpectralEncoder,
+     cfg: ModelConfig,
+     spec: torch.Tensor,
+     candidates: torch.Tensor,
+     batch_ptr: torch.Tensor,
+     precursor_mz: torch.Tensor | None,
+     device: torch.device,
+ ) -> torch.Tensor:
+     batch = _batch_from_massspecgym_spec(spec, cfg, device, precursor_mz=precursor_mz)
+     candidates = candidates.to(device=device, dtype=torch.float32)
+     with torch.no_grad():
+         _embedding, _raw_projected, logits, query_raw = model.forward_with_heads(batch)
+     pred_probs = torch.sigmoid(logits)
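+     # Two frozen scoring modes: cosine similarity between the predicted
+     # fingerprint probabilities and candidate fingerprints, or a dot product
+     # in the shared target-projection space (the reference Hit@20 0.3505 run
+     # used the latter, projected_dot).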
+     if scorer == "predicted_fingerprint":
+         query_repeated = F.normalize(pred_probs, dim=-1).repeat_interleave(batch_ptr.to(device), dim=0)
+         return F.cosine_similarity(query_repeated, F.normalize(candidates, dim=-1), dim=-1).detach()
+     if scorer == "projected_dot":
+         query_repeated = F.normalize(query_raw, dim=-1).repeat_interleave(batch_ptr.to(device), dim=0)
+         target_projection = F.normalize(model.project_structure_targets(candidates), dim=-1)
+         return (query_repeated * target_projection).sum(dim=-1).detach()
+     raise ValueError(f"Unsupported scorer: {scorer}")
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument("--repo-id", default="AethronPhantom/NexaMass-V3-Struct")
+     parser.add_argument("--checkpoint", type=Path, default=REPO_ROOT / "weights/NexaMass-V3-Struct-model_state.safetensors")
+     parser.add_argument("--checkpoint-filename", default="weights/NexaMass-V3-Struct-model_state.safetensors")
+     parser.add_argument("--config", type=Path, default=REPO_ROOT / "config.json")
+     parser.add_argument("--scorer", choices=["projected_dot", "predicted_fingerprint"], default="projected_dot")
+     parser.add_argument("--split", choices=["val", "test"], default="test")
+     parser.add_argument("--batch-size", type=int, default=32)
+     parser.add_argument("--num-workers", type=int, default=8)
+     parser.add_argument("--n-peaks", type=int, default=256)
+     parser.add_argument("--accelerator", default="gpu")
+     parser.add_argument("--devices", default="1")
+     parser.add_argument("--limit-batches", default="1.0")
+     parser.add_argument("--hit-only", action="store_true", help="Use validation loop over test dataloader for Hit@k-only scoring.")
+     parser.add_argument("--inspect-batch-only", action="store_true")
+     parser.add_argument("--output-json", type=Path)
+     args = parser.parse_args()
+
+     Trainer, MassSpecDataModule, RetrievalDataset, MolFingerprinter, SpecTokenizer, RetrievalMassSpecGymModel = (
+         _require_massspecgym()
+     )
+
+     torch.set_float32_matmul_precision("high")
+     limit_batches = _parse_limit_batches(args.limit_batches)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     cfg = _cfg_from_json(args.config)
+     checkpoint = _resolve_checkpoint(args.checkpoint.expanduser(), args.repo_id, args.checkpoint_filename)
+     v3_model = load_nexamass_model_state(str(checkpoint), cfg=cfg, map_location="cpu")
+     v3_model.to(device)
+     v3_model.eval()
+
+     class NexaMassRetrievalModel(RetrievalMassSpecGymModel):  # type: ignore[misc, valid-type]
+         def __init__(self) -> None:
+             super().__init__()
+             self._inspected = False
+
+         def forward(self, spec: torch.Tensor) -> torch.Tensor:
+             batch = _batch_from_massspecgym_spec(spec, cfg, device)
+             with torch.no_grad():
+                 _embedding, _raw, logits, query_raw = v3_model.forward_with_heads(batch)
+             return query_raw if args.scorer == "projected_dot" else torch.sigmoid(logits)
+
+         def step(self, batch: dict[str, Any], stage: Any) -> dict[str, torch.Tensor]:
+             if args.inspect_batch_only and not self._inspected:
+                 print(
+                     json.dumps(
+                         {
+                             "batch_keys": sorted(batch.keys()),
+                             "spec_shape": list(batch["spec"].shape),
+                             "candidates_mol_shape": list(batch["candidates_mol"].shape),
+                             "batch_ptr_head": batch["batch_ptr"].detach().cpu().tolist()[:8],
+                         },
+                         indent=2,
+                     ),
+                     flush=True,
+                 )
+                 self._inspected = True
+             scores = _scores_for_batch(
+                 scorer=args.scorer,
+                 model=v3_model,
+                 cfg=cfg,
+                 spec=batch["spec"],
+                 candidates=batch["candidates_mol"],
+                 batch_ptr=batch["batch_ptr"],
+                 precursor_mz=batch.get("precursor_mz"),
+                 device=device,
+             )
+             return {"loss": torch.zeros((), device=scores.device), "scores": scores}
+
+     dataset = RetrievalDataset(
+         spec_transform=SpecTokenizer(n_peaks=args.n_peaks),
+         mol_transform=MolFingerprinter(fp_size=cfg.fingerprint_dim),
+     )
+     data_module = MassSpecDataModule(dataset=dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+     data_module.prepare_data()
+     data_module.setup(None if args.split == "val" else "test")
+     model = NexaMassRetrievalModel()
+     trainer = Trainer(
+         accelerator=args.accelerator,
+         devices=args.devices,
+         logger=False,
+         enable_checkpointing=False,
+         limit_val_batches=limit_batches if args.split == "val" or args.hit_only else 1.0,
+         limit_test_batches=limit_batches if args.split == "test" else 1.0,
+     )
+     if args.split == "val":
+         metrics = trainer.validate(model, datamodule=data_module)
+     elif args.hit_only:
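+         # Hit@k-only mode: drive the test dataloader through the validation
+         # loop so only the Hit@k metrics are computed on the test split.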
+         metrics = trainer.validate(model, dataloaders=data_module.test_dataloader())
+     else:
+         metrics = trainer.test(model, datamodule=data_module)
+
+     payload = {
+         "checkpoint": str(checkpoint),
+         "scorer": args.scorer,
+         "split": args.split,
+         "metrics": metrics,
+         "massspecgym_adapter": {
+             "repo_id": args.repo_id,
+             "n_peaks": args.n_peaks,
+             "fingerprint_dim": cfg.fingerprint_dim,
+             "limit_batches": limit_batches,
+             "hit_only": args.hit_only,
+             "metadata_defaults": "charge/collision/adduct/instrument set to zero when absent from MassSpecGym batch",
+         },
+     }
+     print(json.dumps(payload, indent=2), flush=True)
+     if args.output_json:
+         args.output_json.parent.mkdir(parents=True, exist_ok=True)
+         args.output_json.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
+     return 0
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
runtime/nexamass_encoder.py CHANGED
@@ -1,6 +1,7 @@
  from __future__ import annotations

  from dataclasses import dataclass
+ from pathlib import Path

  import torch
  import torch.nn as nn
@@ -126,14 +127,49 @@ class NexaMassSpectralEncoder(nn.Module):
  return F.normalize(self.target_projection(targets), dim=-1)


+
+
+ def load_nexamass_state_dict(
+     checkpoint_path: str,
+     map_location: str | torch.device = "cpu",
+ ) -> dict[str, torch.Tensor]:
+     """Load public NexaMass model-state weights from Safetensors or PyTorch.
+
+     Hugging Face public release weights are Safetensors-only. The PyTorch branch is
+     kept for internal/object-storage compatibility with full training checkpoints
+     and model-state fallbacks.
+     """
+
+     path = Path(checkpoint_path)
+     if path.suffix == ".safetensors":
+         try:
+             from safetensors.torch import load_file
+         except ImportError as exc:  # pragma: no cover - dependency message path
+             raise RuntimeError("Install safetensors to load NexaMass public weights: pip install safetensors") from exc
+         device = str(map_location) if isinstance(map_location, str) else "cpu"
+         if device not in {"cpu", "cuda"} and not device.startswith("cuda:"):
+             device = "cpu"
+         return load_file(str(path), device=device)
+
+     try:
+         payload = torch.load(path, map_location=map_location, weights_only=True)
+     except TypeError:  # older PyTorch without weights_only
+         payload = torch.load(path, map_location=map_location)
+     if isinstance(payload, dict) and "model_state" in payload:
+         return payload["model_state"]
+     if isinstance(payload, dict):
+         return payload
+     raise TypeError(f"Unsupported NexaMass checkpoint payload type: {type(payload)!r}")
+
+
  def load_nexamass_model_state(
      checkpoint_path: str,
      cfg: ModelConfig | None = None,
      map_location: str | torch.device = "cpu",
  ) -> NexaMassSpectralEncoder:
-     payload = torch.load(checkpoint_path, map_location=map_location)
+     state_dict = load_nexamass_state_dict(checkpoint_path, map_location=map_location)
      cfg = cfg or ModelConfig()
      model = NexaMassSpectralEncoder(cfg)
-     model.load_state_dict(payload["model_state"], strict=True)
+     model.load_state_dict(state_dict, strict=True)
      model.eval()
      return model
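+
+ # Minimal usage sketch (hedged: the checkpoint path below is the public
+ # release default; any safetensors file or PyTorch "model_state" checkpoint
+ # accepted by load_nexamass_state_dict works):
+ #
+ #     from runtime.nexamass_encoder import load_nexamass_model_state
+ #
+ #     model = load_nexamass_model_state(
+ #         "weights/NexaMass-V3-Struct-model_state.safetensors",
+ #     )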