MVP / massspecgym /models /retrieval /from_dict.py
yzhouchen001's picture
partial push
94aa6f9
import pickle
import typing as T
from pathlib import Path
import torch
import torch.nn as nn
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
class FromDictRetrieval(RetrievalMassSpecGymModel):
"""
Read predictions from dictionary with MassSpecGym ids as keys. Currently, the class
only implements reading fingerprints from the dictionary.
"""
def __init__(
self,
dct: T.Optional[dict[str, T.Any]] = None,
dct_path: T.Optional[T.Union[str, Path]] = None, # pickled dict path
*args,
**kwargs
):
super().__init__(*args, **kwargs)
if dct is None and dct_path is None:
raise ValueError("Either dct or dct_path must be provided.")
if dct is not None and dct_path is not None:
raise ValueError("Only one of dct or dct_path must be provided.")
if dct_path is not None:
with open(dct_path, "rb") as file:
dct = pickle.load(file)
dct = {k: torch.tensor(v) for k, v in dct.items()}
self.dct = dct
def step(
self, batch: dict, stage: Stage = Stage.NONE
) -> tuple[torch.Tensor, torch.Tensor]:
# Unpack inputs
ids = batch["identifier"]
fp_true = batch["mol"]
cands = batch["candidates"]
batch_ptr = batch["batch_ptr"]
# Read predicted fingerprints from dictionary
fp_pred = torch.stack([self.dct[id] for id in ids]).to(fp_true.device)
# Evaluation performance on fingerprint prediction (optional)
self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
# Calculate final similarity scores between predicted fingerprints and corresponding
# candidate fingerprints for retrieval
fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
scores = nn.functional.cosine_similarity(fp_pred_repeated, cands).to(fp_true.device)
# Random baseline, so we return a dummy loss
loss = torch.tensor(0.0, requires_grad=True, device=fp_true.device)
return dict(loss=loss, scores=scores)
def configure_optimizers(self):
# No training, so no optimizers
return None