yzhouchen001 commited on
Commit
e2a4031
·
1 Parent(s): 2c0063e

standard sim

Browse files
Files changed (1) hide show
  1. flare/models/contrastive.py +1 -1
flare/models/contrastive.py CHANGED
@@ -381,7 +381,7 @@ class FilipContrastive(ContrastiveModel):
381
  # Calculate scores
382
  indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
383
 
384
- scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_mask, reduction='geom', temperature=0.05)
385
  scores = torch.split(scores, list(id_to_ct.values()))
386
 
387
  cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)
 
381
  # Calculate scores
382
  indexes = utils.batch_ptr_to_batch_idx(batch_ptr)
383
 
384
+ scores = filip_similarity_batch(spec_enc, mol_enc, spec_mask, mol_masks)
385
  scores = torch.split(scores, list(id_to_ct.values()))
386
 
387
  cand_smiles = utils.unbatch_list(batch['cand_smiles'], indexes)