import os
import typing as T
from collections import defaultdict

import dgllife
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.nn import global_mean_pool

from massspecgym import utils
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel

import flare.utils.models as model_utils
from flare.models.encoders import CrossAttention
from flare.utils.general import pad_graph_nodes, filip_similarity_batch
from flare.utils.loss import contrastive_loss, filip_loss_with_mask, global_infonce_loss, pcgrad_combine


class ContrastiveModel(RetrievalMassSpecGymModel):
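    """Dual-encoder retrieval model: a spectrum encoder and a molecule (graph) encoder
    are trained with a contrastive loss so that matching spectrum/molecule pairs end up
    close in embedding space; candidates are ranked by cosine similarity at test time.
    """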

    def __init__(
        self,
        external_test: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.external_test = external_test
        # Backward-compatible defaults for optional hyperparameters
        if 'use_fp' not in self.hparams:
            self.hparams.use_fp = False
        if 'use_NL_spec' not in self.hparams:
            self.hparams.use_NL_spec = False
        self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
        self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
        self.spec_view = self.hparams.spectra_view
        # Per-identifier storage for test-time results
        self.result_dct = defaultdict(lambda: defaultdict(list))

    def forward(self, batch, stage):
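        """Encode the spectra and, when present, the candidate/target molecules.

        Returns a `(spec_enc, mol_enc)` tuple; `mol_enc` is `None` when the batch
        contains neither candidate graphs (`'cand'`) nor target graphs (`'mol'`).
        """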
        if 'cand' in batch:
            g = batch['cand']
        elif 'mol' in batch:
            g = batch['mol']
        else:
            g = None
        spec = batch[self.spec_view]
        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
        spec_enc = self.spec_enc_model(spec, n_peaks)
        if g is None:
            return spec_enc, None
        fp = batch['fp'] if self.hparams.use_fp else None
        mol_enc = self.mol_enc_model(g, fp=fp)
        return spec_enc, mol_enc

    def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
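        """Compute the training losses; the base model uses only the spectrum-molecule
        contrastive loss, and the auxiliary predictions in `output` are not used here.
        """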
        loss = 0
        losses = {}
        contr_loss, _, _ = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
        losses['contr_loss'] = contr_loss.detach().item()
        loss += contr_loss
        losses['loss'] = loss
        return losses

    def step(self, batch: dict, stage=Stage.NONE):
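        """Run a forward pass and, outside of testing, compute the losses."""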
        # Compute spectrum and molecule encodings
        spec_enc, mol_enc = self.forward(batch, stage)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc)
        # Auxiliary tasks
        output = {}
        if self.hparams.pred_fp:
            output['fp'] = self.fp_pred_model(mol_enc)
        if self.hparams.use_cons_spec:
            spec = batch[self.spec_view]
            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
            output['ind_spec'] = self.spec_enc_model(spec, n_peaks)
        # Calculate loss
        losses = self.compute_loss(batch, spec_enc, mol_enc, output)
        return losses

    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
        # Log the total loss
        self.log(
            f'{stage.to_pref()}loss',
            outputs['loss'],
            batch_size=len(batch['identifier']),
            sync_dist=True,
            prog_bar=True,
            on_epoch=True,
        )

    def test_step(self, batch, batch_idx):
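        """Score every candidate molecule against its spectrum by cosine similarity and
        group the resulting scores per query identifier.
        """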
        # Unpack inputs
        identifiers = batch['identifier']
        cand_smiles = batch['cand_smiles']
        # Count how many candidates belong to each spectrum identifier
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))
        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']
        # Calculate scores
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = nn.functional.cosine_similarity(spec_enc, mol_enc)
        scores = torch.split(scores, list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)
        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        # Save per-identifier candidates, scores and labels
        for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['scores'].extend(scores.cpu().tolist())
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])

    def _compute_rank(self, scores, labels):
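        """Return the 1-based rank of the true candidate (ties count against it),
        or -1 if no candidate is labelled as correct.
        """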
        if not any(labels):
            return -1
        scores = np.array(scores)
        target_score = scores[np.asarray(labels, dtype=bool)][0]
        rank = np.count_nonzero(scores >= target_score)
        return rank

    def _get_top_cand(self, scores, candidates):
        return candidates[np.argmax(np.array(scores))]

    def on_test_epoch_end(self) -> None:
        # Assemble one row per identifier
        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
        # Compute the rank of the true candidate (only possible when labels are available)
        if not self.external_test:
            self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
        if self.external_test:
            self.df_test.drop('labels', axis=1, inplace=True)
            self.df_test['top_cand'] = self.df_test.apply(lambda row: self._get_top_cand(row['scores'], row['candidates']), axis=1)
        if not self.df_test_path:
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)

    def get_checkpoint_monitors(self) -> T.List[dict]:
        monitors = [
            {"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor train loss
            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor val loss
        ]
        return monitors


class FilipContrastive(ContrastiveModel):
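    """Contrastive model trained with the fine-grained FILIP-style loss, which matches
    per-peak spectrum embeddings against per-node molecule embeddings instead of a
    single global embedding per side.
    """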

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask):
        losses = {}
        loss = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
        losses['loss'] = loss
        return losses

    def step(self, batch: dict, stage=Stage.NONE):
        # Compute per-peak spectrum encodings and per-node molecule encodings
        spec_enc, mol_enc = self.forward(batch, stage)
        # Pad molecule nodes to the largest graph in the batch (spectra are already padded)
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch['mol_n_nodes'])
        # Positions whose features all equal the padding value (-5) are padding
        spec_mask = ~torch.all(spec_enc == -5, dim=-1)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
        # Calculate loss
        losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask)
        return losses

    def test_step(self, batch, batch_idx):
        # Unpack inputs
        identifiers = batch['identifier']
        cand_smiles = batch['cand_smiles']
        # Count how many candidates belong to each spectrum identifier
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))
        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']
        spec_mask = outputs['spec_mask']
        mol_mask = outputs['mol_mask']
        # Calculate token-level similarity scores
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
        scores = torch.split(scores, list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)
        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)


# ============================================================
# Combined FILIP + Global InfoNCE
# ============================================================
class FilipGlobalContrastive(ContrastiveModel):
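    """Contrastive model that combines the fine-grained FILIP loss with a global
    InfoNCE loss over pooled embeddings, with several ways to combine the two.
    """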

    def __init__(self, loss_mode="sum", loss_weight=1.0, agg_fn="mean", **kwargs):
        """
        Args:
            loss_mode: one of ["sum", "weighted", "pcgrad"]
            loss_weight: weight of the global loss when loss_mode == "weighted"
            agg_fn: aggregation used by the global InfoNCE loss ("mean", "max", "cls")
        """
        super().__init__(**kwargs)
        self.loss_mode = loss_mode
        self.loss_weight = loss_weight
        self.agg_fn = agg_fn

    # -------------- loss computation --------------
    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask, stage=Stage.NONE):
        losses = {}
        # Fine-grained FILIP loss
        loss_fine = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
        # Global InfoNCE loss
        loss_global = global_infonce_loss(spec_enc, mol_enc, spec_mask, mol_mask,
                                          temperature=self.hparams.contr_temp, agg_fn=self.agg_fn)
        # Choose how the two losses are combined
        if self.loss_mode == "sum":
            loss = loss_fine + loss_global
        elif self.loss_mode == "weighted":
            loss = loss_fine + self.loss_weight * loss_global
        elif self.loss_mode == "pcgrad":
            if stage == Stage.TRAIN:
                # PCGrad over both losses (training only)
                shared_params = list(self.spec_enc_model.parameters()) + list(self.mol_enc_model.parameters())
                self.zero_grad(set_to_none=True)
                loss = pcgrad_combine([loss_fine, loss_global], shared_params)
            else:
                loss = (loss_fine + loss_global).detach()
        else:
            raise ValueError(f"Unsupported loss_mode: {self.loss_mode}")
        losses["loss"] = loss
        losses["loss_fine"] = loss_fine.detach()
        losses["loss_global"] = loss_global.detach()
        return losses

    def step(self, batch: dict, stage=Stage.NONE):
        spec_enc, mol_enc = self.forward(batch, stage)
        # Pad molecule nodes to the largest graph in the batch (spectra are already padded)
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch["mol_n_nodes"])
        spec_mask = ~torch.all(spec_enc == -5, dim=-1)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
        losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask, stage=stage)
        return losses

    # -------------- TEST step with different score variants --------------
    def test_step(self, batch, batch_idx):
        identifiers = batch["identifier"]
        cand_smiles = batch["cand_smiles"]
        # Count how many candidates belong to each spectrum identifier
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()), device=self.device)
        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs["spec_enc"]
        mol_enc = outputs["mol_enc"]
        spec_mask = outputs["spec_mask"]
        mol_mask = outputs["mol_mask"]
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        # --- fine-grained score ---
        fine_scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
        # --- global cosine score over masked mean-pooled embeddings ---
        spec_global = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
        mol_global = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
        global_scores = F.cosine_similarity(spec_global, mol_global, dim=-1)
        # --- combined scores (for evaluation) ---
        combined_sum = fine_scores + global_scores
        combined_weighted = fine_scores + self.loss_weight * global_scores
        combined_avg = 0.5 * (fine_scores + global_scores)  # simple average baseline
        scores_dict = {
            "fine": fine_scores,
            "global": global_scores,
            "sum": combined_sum,
            "weighted": combined_weighted,
            "avg": combined_avg,
        }
        # Split the scores back per identifier
        for key in scores_dict:
            scores_dict[key] = torch.split(scores_dict[key], list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch["cand_smiles"], indexes)
        labels = utils.unbatch_list(batch["label"], indexes)
        return dict(
            identifiers=list(id_to_ct.keys()),
            scores=scores_dict,
            cand_smiles=cand_smiles,
            labels=labels,
        )

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """
        Collect the test batch outputs into self.result_dct.

        Unlike the parent class, outputs['scores'] is a dict mapping each score-variant
        name to a tuple of per-identifier score tensors.
        """
        scores_out = outputs["scores"]
        for k, (i, cands, l) in enumerate(zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels'])):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
            for variant_name, score_list in scores_out.items():
                self.result_dct[i][f"scores_{variant_name}"].extend(score_list[k].cpu().tolist())

    def on_test_epoch_end(self) -> None:
        """
        Combine the accumulated results into one DataFrame with one row per identifier,
        adding a rank (internal test) or top-candidate (external test) column for every
        score variant.
        """
        records = []
        for identifier, val in self.result_dct.items():
            row = {"identifier": identifier, "candidates": val["candidates"]}
            if not self.external_test:
                row["labels"] = val["labels"]
            # For every scores_* key, compute the rank or the top candidate
            for key, scores in val.items():
                if not key.startswith("scores_"):
                    continue
                variant = key.replace("scores_", "")
                if not self.external_test:
                    row[f"rank_{variant}"] = self._compute_rank(scores, val["labels"])
                else:
                    row[f"top_cand_{variant}"] = self._get_top_cand(scores, val["candidates"])
                row[key] = scores
            records.append(row)
        self.df_test = pd.DataFrame(records)
        if self.external_test and "labels" in self.df_test.columns:
            self.df_test.drop(columns=["labels"], inplace=True)
        # Save once
        if not getattr(self, "df_test_path", None):
            self.df_test_path = os.path.join(self.hparams["experiment_dir"], "result_combined.pkl")
        self.df_test.to_pickle(self.df_test_path)


class CrossAttenContrastive(ContrastiveModel):
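    """Contrastive model in which peak and node embeddings attend to each other through
    cross-attention before each side is mean-pooled into a single embedding.
    """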

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.specMolCrossAttentionModel = CrossAttention(
            self.hparams.formula_dims[-1], self.hparams.gnn_channels[-1], self.hparams.cross_attn_heads,
            dim_out=self.hparams.final_embedding_dim, dropout=0.3,
        )
        self.molSpecCrossAttentionModel = CrossAttention(
            self.hparams.gnn_channels[-1], self.hparams.formula_dims[-1], self.hparams.cross_attn_heads,
            dim_out=self.hparams.final_embedding_dim, dropout=0.3,
        )

    def forward(self, batch, stage) -> tuple[torch.Tensor, torch.Tensor]:
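        """Encode peaks and atoms, fuse them with cross-attention, and mean-pool the
        non-padded positions of each side into one embedding per spectrum/molecule.
        """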
        # Unpack inputs
        spec = batch[self.spec_view]
        spec_n_forms = batch['n_peaks']
        g = batch['cand'] if stage == Stage.TEST else batch['mol']
        g_n_nodes = batch['mol_n_nodes']
        # Encode peaks and nodes
        spec_enc = self.spec_enc_model(spec, spec_n_forms)
        mol_enc = self.mol_enc_model(g)
        # Pad mol_enc and spec_enc to the same sequence length
        max_nodes = max(g_n_nodes)
        max_forms = max(spec_n_forms)
        if max_forms > max_nodes:
            # Pad mol_enc: append a dummy sequence of length max_forms so that
            # pad_sequence pads every graph up to max_forms, then drop the dummy
            mol_enc = torch.cat((mol_enc, torch.rand(max_forms, self.hparams.gnn_channels[-1]).to(spec.device)))
            mol_enc = torch.split(mol_enc, g_n_nodes + [max_forms])
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)[:-1, :, :]
        elif max_nodes > max_forms:
            # Pad spec_enc along the sequence dimension up to max_nodes
            dim_diff = max_nodes - max_forms
            spec_enc = F.pad(spec_enc, (0, 0, 0, dim_diff, 0, 0), value=-5)
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        else:
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        # Padding masks (padded positions have every feature equal to -5)
        spec_pad = torch.all(spec_enc == -5, -1)
        mol_pad = torch.all(mol_enc == -5, -1)
        # Cross attention: spectra attend to molecule nodes and vice versa
        tmp_spec_enc = spec_enc * 1.0  # keep a copy of the pre-attention spectrum encodings
        spec_enc = self.specMolCrossAttentionModel(spec_enc, mol_enc, mol_enc, mask=mol_pad)
        mol_enc = self.molSpecCrossAttentionModel(mol_enc, tmp_spec_enc, tmp_spec_enc, mask=spec_pad)
        # Mean-pool the non-padded positions of each spectrum/molecule
        spec_indices = torch.tensor([i for i, count in enumerate(spec_n_forms) for _ in range(count)]).to(spec_enc.device)
        mol_indices = torch.tensor([i for i, count in enumerate(g_n_nodes) for _ in range(count)]).to(mol_enc.device)
        spec_enc = spec_enc[~spec_pad].reshape(-1, spec_enc.shape[-1])
        mol_enc = mol_enc[~mol_pad].reshape(-1, mol_enc.shape[-1])
        spec_enc = global_mean_pool(spec_enc, spec_indices)
        mol_enc = global_mean_pool(mol_enc, mol_indices)
        return spec_enc, mol_enc
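

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the keyword arguments below mirror the
# hyperparameters read from `self.hparams` in this module, but the concrete
# values and encoder keys are assumptions, not project defaults.
#
#   model = FilipGlobalContrastive(
#       loss_mode="weighted",
#       loss_weight=0.5,
#       agg_fn="mean",
#       spec_enc=...,          # key understood by model_utils.get_spec_encoder
#       mol_enc=...,           # key understood by model_utils.get_mol_encoder
#       spectra_view=...,      # batch key holding the spectrum tensor
#       contr_temp=0.07,
#       experiment_dir="...",  # directory where result pickles are written
#   )
#   trainer = pytorch_lightning.Trainer(...)
#   trainer.fit(model, datamodule=data)
#   trainer.test(model, datamodule=data)
# ---------------------------------------------------------------------------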