import typing as T
import os
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.nn import global_mean_pool

from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
from massspecgym.models.base import Stage
from massspecgym import utils

import flare.utils.models as model_utils
from flare.utils.loss import contrastive_loss, filip_loss_with_mask, global_infonce_loss, pcgrad_combine
from flare.utils.general import pad_graph_nodes, filip_similarity_batch
from flare.models.encoders import CrossAttention


class ContrastiveModel(RetrievalMassSpecGymModel):
    def __init__(self, external_test: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.external_test = external_test
        if 'use_fp' not in self.hparams:
            self.hparams.use_fp = False
        if 'use_NL_spec' not in self.hparams:
            self.hparams.use_NL_spec = False
        self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
        self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
        self.spec_view = self.hparams.spectra_view

        # Storage for test results, keyed by spectrum identifier
        self.result_dct = defaultdict(lambda: defaultdict(list))

    def forward(self, batch, stage):
        # Candidate molecules at test time, ground-truth molecules otherwise
        if 'cand' in batch:
            g = batch['cand']
        elif 'mol' in batch:
            g = batch['mol']
        else:
            g = None

        spec = batch[self.spec_view]
        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
        spec_enc = self.spec_enc_model(spec, n_peaks)
        if g is None:
            return spec_enc, None

        fp = batch['fp'] if self.hparams.use_fp else None
        mol_enc = self.mol_enc_model(g, fp=fp)
        return spec_enc, mol_enc

    def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
        loss = 0
        losses = {}
        contr_loss, _, _ = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
        losses['contr_loss'] = contr_loss.detach().item()
        loss += contr_loss
        losses['loss'] = loss
        return losses

    def step(self, batch: dict, stage: Stage = Stage.NONE):
        # Compute spectrum and molecule encodings
        spec_enc, mol_enc = self.forward(batch, stage)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc)

        # Auxiliary tasks
        output = {}
        if self.hparams.pred_fp:
            output['fp'] = self.fp_pred_model(mol_enc)
        if self.hparams.use_cons_spec:
            spec = batch[self.spec_view]
            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
            output['ind_spec'] = self.spec_enc_model(spec, n_peaks)

        # Calculate loss
        losses = self.compute_loss(batch, spec_enc, mol_enc, output)
        return losses

    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
        # Log the total loss
        self.log(
            f'{stage.to_pref()}loss',
            outputs['loss'],
            batch_size=len(batch['identifier']),
            sync_dist=True,
            prog_bar=True,
            on_epoch=True,
            # on_step=True
        )

    def test_step(self, batch, batch_idx):
        # Unpack inputs; identifiers repeat once per candidate of a spectrum
        identifiers = batch['identifier']
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))

        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']

        # Calculate one cosine-similarity score per (spectrum, candidate) pair
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = nn.functional.cosine_similarity(spec_enc, mol_enc)
        scores = torch.split(scores, list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)
        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        # Accumulate per-identifier candidates, scores, and labels
        for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['scores'].extend(scores.cpu().tolist())
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])

    def _compute_rank(self, scores, labels):
        # Rank of the true candidate among all candidates (1 = best); -1 if absent
        if not any(labels):
            return -1
        scores = np.array(scores)
        target_score = scores[labels][0]
        rank = np.count_nonzero(scores >= target_score)
        return rank

    def _get_top_cand(self, scores, candidates):
        return candidates[np.argmax(np.array(scores))]

    def on_test_epoch_end(self) -> None:
        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})

        # Compute ranks on labeled data; on external data report the top candidate instead
        if not self.external_test:
            self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
        else:
            self.df_test.drop('labels', axis=1, inplace=True)
            self.df_test['top_cand'] = self.df_test.apply(lambda row: self._get_top_cand(row['scores'], row['candidates']), axis=1)

        if not getattr(self, 'df_test_path', None):
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)

    def get_checkpoint_monitors(self) -> T.List[dict]:
        monitors = [
            {"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor train loss
            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor val loss
        ]
        return monitors
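
# For orientation, a minimal sketch of the symmetric InfoNCE objective that
# `contrastive_loss` (imported from flare.utils.loss) is assumed to implement.
# It is illustrative only and not used by the models above; the actual helper
# may differ. The three return values mirror the `contr_loss, _, _ = ...`
# unpacking in ContrastiveModel.compute_loss; treating the last two as the
# spec->mol and mol->spec logits is an assumption.
def _sketch_contrastive_loss(spec_enc, mol_enc, temperature):
    spec = F.normalize(spec_enc, dim=-1)  # (B, D) spectrum embeddings
    mol = F.normalize(mol_enc, dim=-1)    # (B, D) molecule embeddings
    logits = spec @ mol.T / temperature   # (B, B) pairwise cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    # Cross-entropy in both directions: matching pairs sit on the diagonal
    loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
    return loss, logits, logits.T
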

class FilipContrastive(ContrastiveModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask):
        losses = {}
        loss = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
        losses['loss'] = loss
        return losses

    def step(self, batch: dict, stage: Stage = Stage.NONE):
        # Compute spectrum and molecule encodings
        spec_enc, mol_enc = self.forward(batch, stage)

        # Pad graph nodes to the max node count in the batch (spectra are already padded)
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch['mol_n_nodes'])
        spec_mask = ~torch.all(spec_enc == -5, dim=-1)  # -5 marks padded positions

        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)

        # Calculate loss
        losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask)
        return losses

    def test_step(self, batch, batch_idx):
        # Unpack inputs; identifiers repeat once per candidate of a spectrum
        identifiers = batch['identifier']
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))

        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']
        spec_mask = outputs['spec_mask']
        mol_mask = outputs['mol_mask']

        # Calculate token-level (late-interaction) scores
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
        scores = torch.split(scores, list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)
        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)
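
# A minimal sketch of the FILIP-style late-interaction score that
# `filip_similarity_batch` (from flare.utils.general) is assumed to compute:
# for every valid spectrum token take the best-matching molecule token, then
# average over valid spectrum tokens. Shapes and masking conventions are
# assumptions inferred from how the function is called above; it is not used
# by the models.
def _sketch_filip_similarity(spec_enc, mol_enc, spec_mask, mol_mask):
    # spec_enc: (B, Ls, D), mol_enc: (B, Lm, D); masks are True at valid positions
    spec = F.normalize(spec_enc, dim=-1)
    mol = F.normalize(mol_enc, dim=-1)
    sim = torch.einsum('bld,bmd->blm', spec, mol)                 # token-token cosine sims
    sim = sim.masked_fill(~mol_mask.unsqueeze(1), float('-inf'))  # ignore padded mol tokens
    best = sim.max(dim=-1).values                                 # best mol token per spec token
    best = best.masked_fill(~spec_mask, 0.0)                      # ignore padded spec tokens
    return best.sum(dim=-1) / spec_mask.sum(dim=-1).clamp(min=1)  # (B,) mean over valid tokens
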

# ============================================================
# Combined FILIP + Global InfoNCE
# ============================================================
class FilipGlobalContrastive(ContrastiveModel):
    def __init__(self, loss_mode="sum", loss_weight=1.0, agg_fn="mean", **kwargs):
        """
        Args:
            loss_mode: str, one of ["sum", "weighted", "pcgrad"]
            loss_weight: weight of the global loss when loss_mode == "weighted"
            agg_fn: aggregation function for global InfoNCE ("mean", "max", "cls")
        """
        super().__init__(**kwargs)
        self.loss_mode = loss_mode
        self.loss_weight = loss_weight
        self.agg_fn = agg_fn

    # -------------- loss computation --------------
    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask, stage=Stage.NONE):
        losses = {}

        # Fine-grained FILIP loss
        loss_fine = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)

        # Global InfoNCE loss
        loss_global = global_infonce_loss(spec_enc, mol_enc, spec_mask, mol_mask,
                                          temperature=self.hparams.contr_temp,
                                          agg_fn=self.agg_fn)

        # Choose combination mode
        if self.loss_mode == "sum":
            loss = loss_fine + loss_global
        elif self.loss_mode == "weighted":
            loss = loss_fine + self.loss_weight * loss_global
        elif self.loss_mode == "pcgrad":
            if stage == Stage.TRAIN:
                # PCGrad over both losses (training only)
                shared_params = list(self.spec_enc_model.parameters()) + list(self.mol_enc_model.parameters())
                self.zero_grad(set_to_none=True)
                loss = pcgrad_combine([loss_fine, loss_global], shared_params)
            else:
                loss = (loss_fine + loss_global).detach()
        else:
            raise ValueError(f"Unsupported loss_mode: {self.loss_mode}")

        losses["loss"] = loss
        losses["loss_fine"] = loss_fine.detach()
        losses["loss_global"] = loss_global.detach()
        return losses

    def step(self, batch: dict, stage=Stage.NONE):
        spec_enc, mol_enc = self.forward(batch, stage)
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch["mol_n_nodes"])
        spec_mask = ~torch.all(spec_enc == -5, dim=-1)  # -5 marks padded positions
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
        losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask, stage=stage)
        return losses

    # -------------- test step with different score variants --------------
    def test_step(self, batch, batch_idx):
        identifiers = batch["identifier"]
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()), device=self.device)

        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs["spec_enc"]
        mol_enc = outputs["mol_enc"]
        spec_mask = outputs["spec_mask"]
        mol_mask = outputs["mol_mask"]

        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)

        # --- fine-grained score ---
        fine_scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)

        # --- global cosine score (masked mean pooling) ---
        spec_global = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
        mol_global = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
        global_scores = F.cosine_similarity(spec_global, mol_global, dim=-1)

        # --- combined scores (for evaluation) ---
        combined_sum = fine_scores + global_scores
        combined_weighted = fine_scores + self.loss_weight * global_scores
        combined_avg = 0.5 * (fine_scores + global_scores)  # simple average baseline

        scores_dict = {
            "fine": fine_scores,
            "global": global_scores,
            "sum": combined_sum,
            "weighted": combined_weighted,
            "avg": combined_avg,
        }

        # Split scores back per identifier
        for key in scores_dict:
            scores_dict[key] = torch.split(scores_dict[key], list(id_to_ct.values()))

        cand_smiles = utils.unbatch_list(batch["cand_smiles"], indexes)
        labels = utils.unbatch_list(batch["label"], indexes)
        return dict(
            identifiers=list(id_to_ct.keys()),
            scores=scores_dict,
            cand_smiles=cand_smiles,
            labels=labels,
        )

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """
        Collect test batch outputs into self.result_dct.
        `outputs['scores']` is a dict of score variants, one score list per variant.
        """
        scores_out = outputs["scores"]
        for k, (i, cands, l) in enumerate(zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels'])):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
            for variant_name, score_list in scores_out.items():
                self.result_dct[i][f"scores_{variant_name}"].extend(score_list[k].cpu().tolist())

    def on_test_epoch_end(self) -> None:
        """
        Combine results into one DataFrame with one row per identifier.
        Add a rank (or top_cand) column for each score variant.
        """
        records = []
        for identifier, val in self.result_dct.items():
            row = {"identifier": identifier, "candidates": val["candidates"]}
            if not self.external_test:
                row["labels"] = val["labels"]

            # For every scores_* key, compute the rank or the top candidate
            for key, scores in val.items():
                if not key.startswith("scores_"):
                    continue
                variant = key.replace("scores_", "")
                if not self.external_test:
                    row[f"rank_{variant}"] = self._compute_rank(scores, val["labels"])
                else:
                    row[f"top_cand_{variant}"] = self._get_top_cand(scores, val["candidates"])
                row[key] = scores
            records.append(row)

        self.df_test = pd.DataFrame(records)

        # Save once
        if not getattr(self, "df_test_path", None):
            self.df_test_path = os.path.join(self.hparams["experiment_dir"], "result_combined.pkl")
        self.df_test.to_pickle(self.df_test_path)
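
# A minimal sketch of what `global_infonce_loss` (from flare.utils.loss) is
# assumed to do for agg_fn="mean": masked mean-pooling of token embeddings
# (mirroring the global score in FilipGlobalContrastive.test_step) followed by
# the symmetric InfoNCE objective. The "max"/"cls" aggregations and the real
# signature are assumptions; this helper is illustrative only.
def _sketch_global_infonce(spec_enc, mol_enc, spec_mask, mol_mask, temperature):
    spec = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
    mol = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
    spec, mol = F.normalize(spec, dim=-1), F.normalize(mol, dim=-1)
    logits = spec @ mol.T / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))


# `pcgrad_combine` (from flare.utils.loss) is assumed to implement PCGrad
# (Yu et al., 2020): when the gradients of two losses conflict (negative dot
# product), the conflicting component is projected out before the gradients
# are summed. Below is a minimal single-pair sketch on flat gradient vectors;
# the real helper, which also writes gradients back and returns a scalar for
# logging, is assumed to apply this symmetrically over all task pairs.
def _sketch_pcgrad_pair(g1, g2):
    dot = torch.dot(g1, g2)
    if dot < 0:  # conflicting gradients: project g1 onto the normal plane of g2
        g1 = g1 - dot / g2.norm().pow(2).clamp(min=1e-12) * g2
    return g1
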

class CrossAttenContrastive(ContrastiveModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.specMolCrossAttentionModel = CrossAttention(
            self.hparams.formula_dims[-1], self.hparams.gnn_channels[-1], self.hparams.cross_attn_heads,
            dim_out=self.hparams.final_embedding_dim, dropout=0.3)
        self.molSpecCrossAttentionModel = CrossAttention(
            self.hparams.gnn_channels[-1], self.hparams.formula_dims[-1], self.hparams.cross_attn_heads,
            dim_out=self.hparams.final_embedding_dim, dropout=0.3)

    def forward(self, batch, stage) -> tuple[torch.Tensor, torch.Tensor]:
        # Unpack inputs
        spec = batch[self.spec_view]
        spec_n_forms = batch['n_peaks']
        g = batch['cand'] if stage == Stage.TEST else batch['mol']
        g_n_nodes = batch['mol_n_nodes']

        # Encode peaks and nodes
        spec_enc = self.spec_enc_model(spec, spec_n_forms)
        mol_enc = self.mol_enc_model(g)

        # Pad mol_enc and spec_enc to the same sequence length
        max_nodes = max(g_n_nodes)
        max_forms = max(spec_n_forms)
        if max_forms > max_nodes:
            # Pad mol_enc: append a dummy "graph" of max_forms rows so that
            # pad_sequence pads every real graph to max_forms, then drop it
            mol_enc = torch.cat((mol_enc, torch.rand(max_forms, self.hparams.gnn_channels[-1]).to(spec.device)))
            mol_enc = torch.split(mol_enc, g_n_nodes + [max_forms])
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)[:-1, :, :]
        elif max_nodes > max_forms:
            # Pad spec_enc along the sequence dimension
            dim_diff = max_nodes - max_forms
            spec_enc = F.pad(spec_enc, (0, 0, 0, dim_diff, 0, 0), value=-5)
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        else:
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        # Padding masks: True at padded positions (-5 is the padding value)
        spec_pad = torch.all(spec_enc == -5, -1)
        mol_pad = torch.all(mol_enc == -5, -1)

        # Cross attention: each modality attends over the other
        tmp_spec_enc = spec_enc * 1.0  # keep a copy; spec_enc is overwritten next
        spec_enc = self.specMolCrossAttentionModel(spec_enc, mol_enc, mol_enc, mask=mol_pad)
        mol_enc = self.molSpecCrossAttentionModel(mol_enc, tmp_spec_enc, tmp_spec_enc, mask=spec_pad)

        # Mean-pool the valid (non-padded) positions of each sample
        spec_indices = torch.tensor([i for i, count in enumerate(spec_n_forms) for _ in range(count)]).to(spec_enc.device)
        mol_indices = torch.tensor([i for i, count in enumerate(g_n_nodes) for _ in range(count)]).to(mol_enc.device)
        spec_enc = spec_enc[~spec_pad].reshape(-1, spec_enc.shape[-1])
        mol_enc = mol_enc[~mol_pad].reshape(-1, mol_enc.shape[-1])
        spec_enc = global_mean_pool(spec_enc, spec_indices)
        mol_enc = global_mean_pool(mol_enc, mol_indices)
        return spec_enc, mol_enc
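
# For reference, a minimal sketch of what `pad_graph_nodes` (from
# flare.utils.general) is assumed to return, mirroring the splitting/padding
# done inline in CrossAttenContrastive.forward: node embeddings split per
# graph, padded to the max node count in the batch, plus a boolean validity
# mask. The signature and pad value are assumptions; not used by the models.
def _sketch_pad_graph_nodes(node_enc, n_nodes, pad_value=-5.0):
    # node_enc: (total_nodes, D) concatenated over the batch; n_nodes: list of node counts per graph
    per_graph = torch.split(node_enc, list(n_nodes))
    padded = pad_sequence(per_graph, batch_first=True, padding_value=pad_value)  # (B, Lmax, D)
    mask = ~torch.all(padded == pad_value, dim=-1)                               # True at valid nodes
    return padded, mask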