Spaces:
Sleeping
Sleeping
File size: 1,804 Bytes
94aa6f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import typing as T
from abc import ABC
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import RetrievalHitRate, CosineSimilarity
from massspecgym.models.base import MassSpecGymModel
class SimulationMassSpecGymModel(MassSpecGymModel, ABC):
    """
    Base class for MassSpecGym spectrum-simulation models.

    The batch-end hooks below read predicted spectra from ``outputs["spec_pred"]``
    and ground-truth spectra from ``batch["spec"]`` and forward them to
    ``evaluate_cos_similarity_step`` / ``evaluate_hit_rate_step``. Both evaluators
    raise ``NotImplementedError`` here and are meant to be provided by subclasses.
    """

    def on_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int, metric_pref: str = ""
    ) -> None:
        """
        Compute evaluation metrics for the simulation model from a batch and its predictions.

        Intended to be used by `on_train_batch_end` and `on_validation_batch_end`;
        `on_test_batch_end` is overridden below and additionally evaluates hit rate.

        Args:
            outputs: Step output; must contain the key ``"spec_pred"``.
            batch: Input batch; must contain the key ``"spec"``.
            batch_idx: Index of the batch (unused here, kept for hook compatibility).
            metric_pref: Prefix passed through to the metric logging code.
        """
        self.evaluate_cos_similarity_step(
            outputs["spec_pred"],
            batch["spec"],
            metric_pref=metric_pref,
        )

    def on_test_batch_end(
        self,
        outputs: T.Any,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """
        Compute test-time metrics: cosine similarity and hit rate, both prefixed with ``"_test"``.

        Args:
            outputs: Step output; must contain the key ``"spec_pred"``.
            batch: Input batch; must contain the key ``"spec"``.
            batch_idx: Index of the batch (unused here).
            dataloader_idx: Index of the test dataloader. Accepted so that the
                override matches Lightning's hook signature, which passes this
                argument when multiple test dataloaders are used (unused here).
        """
        metric_pref = "_test"
        specs_pred = outputs["spec_pred"]
        specs = batch["spec"]
        self.evaluate_cos_similarity_step(
            specs_pred,
            specs,
            metric_pref=metric_pref,
        )
        self.evaluate_hit_rate_step(
            specs_pred,
            specs,
            metric_pref=metric_pref,
        )

    def evaluate_cos_similarity_step(
        self,
        specs_pred: torch.Tensor,
        specs: torch.Tensor,
        metric_pref: str = ""
    ) -> None:
        """
        Evaluate cosine similarity between predicted and ground-truth spectra.

        Placeholder: subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def evaluate_hit_rate_step(
        self,
        specs_pred: torch.Tensor,
        specs: torch.Tensor,
        metric_pref: str = ""
    ) -> None:
        """
        Evaluate hit rate @ {1, 5, 20} (typically reported as accuracy @ {1, 5, 20}).

        Placeholder: subclasses must override this method.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|