Spaces:
Sleeping
Sleeping
| import typing as T | |
| from abc import ABC | |
| import torch | |
| import torch.nn as nn | |
| import pytorch_lightning as pl | |
| from torchmetrics import RetrievalHitRate, CosineSimilarity | |
| from massspecgym.models.base import MassSpecGymModel | |
class SimulationMassSpecGymModel(MassSpecGymModel, ABC):
    """
    Abstract base for spectrum-simulation models.

    Wires the per-batch hooks to two evaluation steps that concrete subclasses
    are expected to provide: `evaluate_cos_similarity_step` and
    `evaluate_hit_rate_step`. Both raise `NotImplementedError` here.
    """

    def on_batch_end(
        self,
        outputs: T.Any,
        batch: dict,
        batch_idx: int,
        metric_pref: str = "",
    ) -> None:
        """
        Score the predicted spectra of one batch with cosine similarity.

        Intended as the shared handler for the train and validation batch-end
        hooks; the test hook is overridden separately in this class because it
        additionally evaluates the hit rate.

        Args:
            outputs: Step outputs; must support `outputs["spec_pred"]`.
            batch: Input batch; must contain the key `"spec"`.
            batch_idx: Index of the batch (not used by this method).
            metric_pref: Prefix forwarded unchanged to the evaluation step.
        """
        predicted_specs = outputs["spec_pred"]
        reference_specs = batch["spec"]
        self.evaluate_cos_similarity_step(
            predicted_specs, reference_specs, metric_pref=metric_pref
        )

    def on_test_batch_end(
        self, outputs: T.Any, batch: dict, batch_idx: int
    ) -> None:
        """
        Score one test batch with cosine similarity and then hit rate.

        Both evaluation steps receive `outputs["spec_pred"]`, `batch["spec"]`
        and the fixed metric prefix `"_test"`.

        Args:
            outputs: Step outputs; must support `outputs["spec_pred"]`.
            batch: Input batch; must contain the key `"spec"`.
            batch_idx: Index of the batch (not used by this method).
        """
        metric_pref = "_test"
        test_evaluators = (
            self.evaluate_cos_similarity_step,
            self.evaluate_hit_rate_step,
        )
        # Cosine similarity runs first, hit rate second; the inputs are looked
        # up again for every evaluator rather than cached.
        for evaluate in test_evaluators:
            evaluate(
                outputs["spec_pred"],
                batch["spec"],
                metric_pref=metric_pref,
            )

    def evaluate_cos_similarity_step(
        self,
        specs_pred: torch.Tensor,
        specs: torch.Tensor,
        metric_pref: str = "",
    ) -> None:
        """
        Evaluate cosine similarity between predicted and reference spectra.

        Placeholder to be overridden by subclasses.

        Args:
            specs_pred: Predicted spectra.
            specs: Reference spectra.
            metric_pref: Prefix for the metric name.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def evaluate_hit_rate_step(
        self,
        specs_pred: torch.Tensor,
        specs: torch.Tensor,
        metric_pref: str = "",
    ) -> None:
        """
        Evaluate Hit rate @ {1, 5, 20} (typically reported as Accuracy @ {1, 5, 20}).

        Placeholder to be overridden by subclasses.

        Args:
            specs_pred: Predicted spectra.
            specs: Reference spectra.
            metric_pref: Prefix for the metric name.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError