# flare/models/contrastive.py — FLARE
# (header reconstructed from Hugging Face file-viewer residue:
#  uploaded by yzhouchen001, commit 19a4dfc "update")
import typing as T
import torch
import torch.nn as nn
import pandas as pd
from collections import defaultdict
import numpy as np
import os
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
from massspecgym.models.base import Stage
from massspecgym import utils
from torch.nn.utils.rnn import pad_sequence
from flare.utils.loss import contrastive_loss, filip_loss_with_mask, global_infonce_loss, pcgrad_combine
import flare.utils.models as model_utils
from flare.utils.general import pad_graph_nodes, filip_similarity_batch
from flare.models.encoders import CrossAttention
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
import torch, dgllife
class ContrastiveModel(RetrievalMassSpecGymModel):
    """Dual-encoder retrieval model trained with an InfoNCE-style contrastive loss.

    A spectrum encoder and a molecule (graph) encoder map each modality into a
    shared embedding space; at test time candidate molecules are ranked by the
    cosine similarity of their embeddings to the query spectrum embedding.
    """

    def __init__(
        self,
        external_test: bool = False,
        **kwargs
    ):
        """
        Args:
            external_test: if True, test batches carry no ground-truth labels,
                so ranks are skipped and only the top candidate is reported.
            **kwargs: forwarded to RetrievalMassSpecGymModel and captured into
                ``self.hparams`` via ``save_hyperparameters``.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.external_test = external_test
        # Backfill optional hparams so configs predating these flags still load.
        if 'use_fp' not in self.hparams:
            self.hparams.use_fp = False
        if 'use_NL_spec' not in self.hparams:
            self.hparams.use_NL_spec = False
        self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
        self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
        self.spec_view = self.hparams.spectra_view
        # Test-result accumulator: identifier -> {'candidates': [...], 'scores': [...], 'labels': [...]}
        self.result_dct = defaultdict(lambda: defaultdict(list))

    def forward(self, batch, stage):
        """Encode the spectrum (and, when present, the molecule) in ``batch``.

        Returns:
            (spec_enc, mol_enc); ``mol_enc`` is None when the batch contains
            neither candidate ('cand') nor reference ('mol') molecule graphs.
        """
        if 'cand' in batch:
            g = batch['cand']
        elif 'mol' in batch:
            g = batch['mol']
        else:
            g = None
        spec = batch[self.spec_view]
        n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
        spec_enc = self.spec_enc_model(spec, n_peaks)
        if g is None:
            return spec_enc, None
        fp = batch['fp'] if self.hparams.use_fp else None
        # Fix: removed a redundant `self.mol_enc_model.GNN(g, g.ndata['h'])`
        # call whose result was never used — it ran a second, wasted GNN
        # forward pass (and would double-update any batch-norm running stats
        # during training).
        mol_enc = self.mol_enc_model(g, fp=fp)
        return spec_enc, mol_enc

    def compute_loss(self, batch: dict, spec_enc, mol_enc, output):
        """Compute the contrastive loss.

        ``output`` carries auxiliary predictions from ``step``; this base
        implementation ignores them (subclasses may not).
        """
        losses = {}
        contr_loss, _, _ = contrastive_loss(spec_enc, mol_enc, self.hparams.contr_temp)
        # Detached scalar for logging; the live tensor goes into 'loss'.
        losses['contr_loss'] = contr_loss.detach().item()
        losses['loss'] = contr_loss
        return losses

    def step(
            self, batch: dict, stage=Stage.NONE):
        """One train/val/test step: encode, optionally run aux heads, compute loss."""
        # Compute spectra and mol encoding
        spec_enc, mol_enc = self.forward(batch, stage)
        if stage == Stage.TEST:
            # Scoring against candidates happens in test_step, not here.
            return dict(spec_enc=spec_enc, mol_enc=mol_enc)
        # Auxiliary tasks (the corresponding heads are created elsewhere,
        # driven by hparams).
        output = {}
        if self.hparams.pred_fp:
            output['fp'] = self.fp_pred_model(mol_enc)
        if self.hparams.use_cons_spec:
            # Second (independent) encoding of the same spectra — consistency view.
            spec = batch[self.spec_view]
            n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
            output['ind_spec'] = self.spec_enc_model(spec, n_peaks)
        # Calculate loss
        losses = self.compute_loss(batch, spec_enc, mol_enc, output)
        return losses

    def on_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage) -> None:
        """Log the total loss for this stage (epoch-aggregated, shown in progress bar)."""
        self.log(
            f'{stage.to_pref()}loss',
            outputs['loss'],
            batch_size=len(batch['identifier']),
            sync_dist=True,
            prog_bar=True,
            on_epoch=True,
            # on_step=True
        )

    def test_step(self, batch, batch_idx):
        """Score every candidate molecule against its query spectrum.

        The batch is flat: one row per (query, candidate) pair, with the query
        repeated; ``identifier`` groups rows back into per-query chunks.
        """
        identifiers = batch['identifier']
        # Count candidates per query; dict preserves first-seen query order.
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()))
        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs['spec_enc']
        mol_enc = outputs['mol_enc']
        # Per-pair cosine similarity, then split back into per-query chunks.
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        scores = nn.functional.cosine_similarity(spec_enc, mol_enc)
        scores = torch.split(scores, list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
        labels = utils.unbatch_list(batch['label'], indexes)
        return dict(identifiers=list(id_to_ct.keys()), scores=scores, cand_smiles=cand_smiles, labels=labels)

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """Accumulate per-query candidates/scores/labels into ``self.result_dct``."""
        for i, cands, scores, l in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores'], outputs['labels']):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['scores'].extend(scores.cpu().tolist())
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])

    def _compute_rank(self, scores, labels):
        """1-based rank of the (first) true candidate; -1 when no positive label.

        Ties count pessimistically: every candidate scoring >= the target
        contributes to the rank.
        """
        if not any(labels):
            return -1
        scores = np.array(scores)
        # Fix: coerce labels to a boolean mask. A plain 0/1 *int* list would be
        # interpreted by numpy as integer indices (positions 0 and 1) instead
        # of a mask, silently picking the wrong score.
        target_score = scores[np.asarray(labels, dtype=bool)][0]
        rank = np.count_nonzero(scores >= target_score)
        return rank

    def _get_top_cand(self, scores, candidates):
        """Return the candidate with the highest score."""
        return candidates[np.argmax(np.array(scores))]

    def on_test_epoch_end(self) -> None:
        """Assemble per-identifier results into ``self.df_test`` and pickle it."""
        self.df_test = pd.DataFrame.from_dict(self.result_dct, orient='index').reset_index().rename(columns={'index': 'identifier'})
        if not self.external_test:
            # Rank of the true candidate per query (1-based; -1 if unlabeled).
            self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
        if self.external_test:
            # No ground truth available: drop labels, report the top candidate.
            self.df_test.drop('labels', axis=1, inplace=True)
            self.df_test['top_cand'] = self.df_test.apply(lambda row: self._get_top_cand(row['scores'], row['candidates']), axis=1)
        # Fix: use getattr so a never-set `df_test_path` falls back to the
        # default location instead of raising AttributeError (consistent with
        # the FilipGlobalContrastive implementation).
        if not getattr(self, 'df_test_path', None):
            self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
        self.df_test.to_pickle(self.df_test_path)

    def get_checkpoint_monitors(self) -> T.List[dict]:
        """Checkpoint on minimum train and val loss; no early stopping."""
        monitors = [
            {"monitor": f"{Stage.TRAIN.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor train loss
            {"monitor": f"{Stage.VAL.to_pref()}loss", "mode": "min", "early_stopping": False},  # monitor val loss
        ]
        return monitors
class FilipContrastive(ContrastiveModel):
    """FILIP-style variant: token-level (peak x atom) contrastive alignment.

    Reuses the parent's encoders but replaces the global contrastive loss with
    a masked fine-grained FILIP loss over per-peak / per-node embeddings.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask):
        """Masked FILIP loss between peak-level and node-level embeddings."""
        fine_loss = filip_loss_with_mask(
            spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp
        )
        return {'loss': fine_loss}

    def step(
            self, batch: dict, stage=Stage.NONE):
        """Encode both modalities, build validity masks, and compute the loss."""
        spec_enc, mol_enc = self.forward(batch, stage)
        # Molecule node embeddings are padded here; spectra arrive pre-padded
        # with the -5 sentinel, so the spectrum mask is derived from it.
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch['mol_n_nodes'])
        spec_mask = ~torch.all(spec_enc == -5, dim=-1)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc,
                        spec_mask=spec_mask, mol_mask=mol_mask)
        return self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask)

    def test_step(self, batch, batch_idx):
        """Score candidates with the FILIP similarity and regroup per query."""
        # Candidate count per query identifier (first-seen order preserved).
        counts_by_id = defaultdict(int)
        for ident in batch['identifier']:
            counts_by_id[ident] += 1
        chunk_sizes = list(counts_by_id.values())
        enc = self.step(batch, stage=Stage.TEST)
        sims = filip_similarity_batch(
            enc['spec_enc'], enc['mol_enc'], enc['spec_mask'], enc['mol_mask']
        )
        group_idx = utils.batch_ptr_to_batch_idx(torch.tensor(chunk_sizes))
        return dict(
            identifiers=list(counts_by_id.keys()),
            scores=torch.split(sims, chunk_sizes),
            cand_smiles=utils.unbatch_list(batch['cand_smiles'], group_idx),
            labels=utils.unbatch_list(batch['label'], group_idx),
        )
# ============================================================
# Combined FILIP + Global InfoNCE
# ============================================================
class FilipGlobalContrastive(ContrastiveModel):
    """Contrastive model combining a fine-grained FILIP loss with a global
    InfoNCE loss, with several ways to merge the two objectives."""

    def __init__(self, loss_mode="sum", loss_weight=1.0, agg_fn="mean", **kwargs):
        """
        Args:
            loss_mode: str, one of ["sum", "weighted", "pcgrad"]
            loss_weight: weight for global loss if using weighted sum
            agg_fn: aggregation function for global InfoNCE ("mean", "max", "cls")
        """
        super().__init__(**kwargs)
        self.loss_mode = loss_mode
        self.loss_weight = loss_weight
        self.agg_fn = agg_fn

    # -------------- loss computation --------------
    def compute_loss(self, batch: dict, spec_enc, mol_enc, spec_mask, mol_mask, stage=Stage.NONE):
        """Compute fine-grained + global losses and combine them per ``loss_mode``."""
        losses = {}
        # fine-grained FILIP loss
        loss_fine = filip_loss_with_mask(spec_enc, mol_enc, spec_mask, mol_mask, self.hparams.contr_temp)
        # global InfoNCE loss
        loss_global = global_infonce_loss(spec_enc, mol_enc, spec_mask, mol_mask,
                                          temperature=self.hparams.contr_temp, agg_fn=self.agg_fn)
        # choose combination mode
        if self.loss_mode == "sum":
            loss = loss_fine + loss_global
        elif self.loss_mode == "weighted":
            loss = loss_fine + self.loss_weight * loss_global
        elif self.loss_mode == "pcgrad":
            if stage == Stage.TRAIN:
                # PCGrad over both losses (training only).
                # NOTE(review): this clears all module grads and then delegates
                # to pcgrad_combine over the shared encoder parameters —
                # presumably it backpropagates each loss and writes the
                # projected, conflict-free gradients into .grad as a side
                # effect; confirm against flare.utils.loss before refactoring.
                shared_params = list(self.spec_enc_model.parameters()) + list(self.mol_enc_model.parameters())
                self.zero_grad(set_to_none=True)
                loss = pcgrad_combine([loss_fine, loss_global], shared_params)
            else:
                # Val/test: plain sum, detached since no backward pass is needed.
                loss = (loss_fine + loss_global).detach()
        else:
            raise ValueError(f"Unsupported loss_mode: {self.loss_mode}")
        losses["loss"] = loss
        # Detached components for logging.
        losses["loss_fine"] = loss_fine.detach()
        losses["loss_global"] = loss_global.detach()
        return losses

    def step(self, batch: dict, stage=Stage.NONE):
        """Encode, pad molecule nodes, derive validity masks, and compute loss."""
        spec_enc, mol_enc = self.forward(batch, stage)
        mol_enc, mol_mask = pad_graph_nodes(mol_enc, batch["mol_n_nodes"])
        # Spectra are pre-padded with the -5 sentinel value; mask marks real peaks.
        spec_mask = ~torch.all((spec_enc == -5), dim=-1)
        if stage == Stage.TEST:
            return dict(spec_enc=spec_enc, mol_enc=mol_enc, spec_mask=spec_mask, mol_mask=mol_mask)
        losses = self.compute_loss(batch, spec_enc, mol_enc, spec_mask, mol_mask, stage=stage)
        return losses

    # -------------- TEST step with different score variants --------------
    def test_step(self, batch, batch_idx):
        """Score candidates under several similarity variants (fine/global/combined)."""
        identifiers = batch["identifier"]
        cand_smiles = batch["cand_smiles"]
        # Candidates per query; dict preserves first-seen query order.
        id_to_ct = defaultdict(int)
        for i in identifiers:
            id_to_ct[i] += 1
        batch_ptr = torch.tensor(list(id_to_ct.values()), device=self.device)
        outputs = self.step(batch, stage=Stage.TEST)
        spec_enc = outputs["spec_enc"]
        mol_enc = outputs["mol_enc"]
        spec_mask = outputs["spec_mask"]
        mol_mask = outputs["mol_mask"]
        indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
        # --- fine-grained score ---
        fine_scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask)
        # --- global cosine score (masked mean over valid tokens/nodes) ---
        spec_global = (spec_enc * spec_mask.unsqueeze(-1)).sum(1) / spec_mask.sum(1, keepdim=True).clamp(min=1)
        mol_global = (mol_enc * mol_mask.unsqueeze(-1)).sum(1) / mol_mask.sum(1, keepdim=True).clamp(min=1)
        global_scores = F.cosine_similarity(spec_global, mol_global, dim=-1)
        # --- combined scores (for evaluation) ---
        combined_sum = fine_scores + global_scores
        combined_weighted = fine_scores + self.loss_weight * global_scores
        combined_pc = 0.5 * (fine_scores + global_scores)  # simple average baseline
        scores_dict = {
            "fine": fine_scores,
            "global": global_scores,
            "sum": combined_sum,
            "weighted": combined_weighted,
            "avg": combined_pc,
        }
        # split back per identifier
        for key in scores_dict:
            scores_dict[key] = torch.split(scores_dict[key], list(id_to_ct.values()))
        cand_smiles = utils.unbatch_list(batch["cand_smiles"], indexes)
        labels = utils.unbatch_list(batch["label"], indexes)
        return dict(
            identifiers=list(id_to_ct.keys()),
            scores=scores_dict,
            cand_smiles=cand_smiles,
            labels=labels,
        )

    def on_test_batch_end(self, outputs, batch: dict, batch_idx: int, stage: Stage = Stage.TEST) -> None:
        """
        Collects test batch outputs and stores them in self.result_dct.
        Supports both:
            - Single score list format (legacy)
            - Dict of multiple score variants (new)
        """
        identifiers = outputs["identifiers"]
        cand_smiles = outputs["cand_smiles"]
        labels = outputs["labels"]
        scores_out = outputs["scores"]
        # k indexes the per-query chunk inside each variant's split score list.
        for k, (i, cands, l) in enumerate(zip(outputs['identifiers'], outputs['cand_smiles'], outputs['labels'])):
            self.result_dct[i]['candidates'].extend(cands)
            self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
            for variant_name, score_list in scores_out.items():
                self.result_dct[i][f"scores_{variant_name}"].extend(score_list[k].cpu().tolist())

    def on_test_epoch_end(self) -> None:
        """
        Combine results into one DataFrame with one row per identifier.
        Adds rank/top_cand columns for each score variant.
        """
        records = []
        for identifier, val in self.result_dct.items():
            row = {"identifier": identifier, "candidates": val["candidates"]}
            if not self.external_test:
                row["labels"] = val["labels"]
            # For every scores_* key, compute rank or top candidate
            for key, scores in val.items():
                if not key.startswith("scores_"):
                    continue
                variant = key.replace("scores_", "")
                if not self.external_test:
                    row[f"rank_{variant}"] = self._compute_rank(scores, val["labels"])
                else:
                    row[f"top_cand_{variant}"] = self._get_top_cand(scores, val["candidates"])
                row[key] = scores
            records.append(row)
        self.df_test = pd.DataFrame(records)
        if self.external_test and "labels" in self.df_test.columns:
            self.df_test.drop(columns=["labels"], inplace=True)
        # Save once
        if not getattr(self, "df_test_path", None):
            self.df_test_path = os.path.join(self.hparams["experiment_dir"], "result_combined.pkl")
        self.df_test.to_pickle(self.df_test_path)
class CrossAttenContrastive(ContrastiveModel):
    """Contrastive model that cross-attends peak and atom embeddings before
    mean-pooling each side into a single vector per item."""

    def __init__(
        self,
        **kwargs
    ):
        super(CrossAttenContrastive, self).__init__(**kwargs)
        # Spectrum-queries-molecule and molecule-queries-spectrum attention heads.
        self.specMolCrossAttentionModel = CrossAttention(self.hparams.formula_dims[-1], self.hparams.gnn_channels[-1], self.hparams.cross_attn_heads, dim_out=self.hparams.final_embedding_dim, dropout=0.3)
        self.molSpecCrossAttentionModel = CrossAttention(self.hparams.gnn_channels[-1], self.hparams.formula_dims[-1], self.hparams.cross_attn_heads, dim_out=self.hparams.final_embedding_dim, dropout=0.3)

    def forward(self, batch, stage) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode, pad both modalities to a common length, cross-attend, pool.

        Returns:
            (spec_enc, mol_enc): one pooled embedding per spectrum / molecule.
        """
        # Unpack inputs
        spec = batch[self.spec_view]
        spec_n_forms = batch['n_peaks']
        g = batch['cand'] if stage == Stage.TEST else batch['mol']
        # NOTE(review): assumes g_n_nodes is a Python list (list `+` below) —
        # confirm against the data module.
        g_n_nodes = batch['mol_n_nodes']
        # encode peaks and nodes
        spec_enc = self.spec_enc_model(spec, spec_n_forms)
        mol_enc = self.mol_enc_model(g)
        # Pad mol_enc and spec_enc to the same sequence length; -5 is the
        # padding sentinel used throughout this file.
        max_nodes = max(g_n_nodes)
        max_forms = max(spec_n_forms)
        if max_forms > max_nodes:  ## pad mol_enc
            # Trick: append a dummy "molecule" of max_forms random rows so
            # pad_sequence pads every real graph up to max_forms, then drop
            # the dummy entry again with [:-1].
            mol_enc = torch.cat((mol_enc, torch.rand(max_forms, self.hparams.gnn_channels[-1]).to(spec.device)))
            mol_enc = torch.split(mol_enc, g_n_nodes+[max_forms])
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)[:-1,:,:]
        elif max_nodes > max_forms:  ## pad spec_enc
            # Extend the (already padded) spectra along the sequence dim.
            dim_diff = max_nodes - max_forms
            spec_enc = F.pad(spec_enc, (0,0,0,dim_diff, 0,0), value=-5)
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        else:
            mol_enc = torch.split(mol_enc, g_n_nodes)
            mol_enc = pad_sequence(mol_enc, batch_first=True, padding_value=-5)
        # Padding masks: a row is padding iff every feature equals -5.
        # NOTE(review): relies on no real embedding row being exactly all -5.
        spec_pad = torch.all((spec_enc == -5), -1)
        mol_pad = torch.all((mol_enc == -5), -1)
        # Cross attention; copy spec_enc first since it is overwritten before
        # being used as key/value for the molecule side.
        tmp_spec_enc = spec_enc * 1.0
        spec_enc = self.specMolCrossAttentionModel(spec_enc, mol_enc, mol_enc, mask=mol_pad)
        mol_enc = self.molSpecCrossAttentionModel(mol_enc, tmp_spec_enc, tmp_spec_enc, mask=spec_pad)
        # Pool: flatten valid (non-pad) rows and mean-pool per original item.
        spec_indecies = torch.tensor([i for i, count in enumerate(spec_n_forms) for _ in range(count)]).to(spec_enc.device)
        mol_indecies = torch.tensor([i for i, count in enumerate(g_n_nodes) for _ in range(count)]).to(mol_enc.device)
        spec_enc = spec_enc[~spec_pad].reshape(-1, spec_enc.shape[-1])
        mol_enc = mol_enc[~mol_pad].reshape(-1, mol_enc.shape[-1])
        spec_enc = global_mean_pool(spec_enc, spec_indecies)
        mol_enc = global_mean_pool(mol_enc, mol_indecies)
        return spec_enc, mol_enc