yzhouchen001 commited on
Commit
53b1581
·
1 Parent(s): 6c3d8a1
flare/models/__init__.py CHANGED
@@ -1,3 +1 @@
1
- import sys
2
- sys.path.insert(0, "/data/yzhouc01//MassSpecGym")
3
  from massspecgym.models import *
 
 
 
1
  from massspecgym.models import *
flare/models/contrastive.py CHANGED
@@ -30,12 +30,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
30
  self.save_hyperparameters()
31
  self.external_test = external_test
32
 
33
- if 'use_fp' not in self.hparams:
34
- self.hparams.use_fp = False
35
- if 'use_NL_spec' not in self.hparams:
36
- self.hparams.use_NL_spec = False
37
-
38
-
39
  self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
40
  self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
41
 
@@ -59,13 +53,10 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
59
  if g is None:
60
  mol_enc = None
61
  return spec_enc, mol_enc
62
-
63
- fp = batch['fp'] if self.hparams.use_fp else None
64
-
65
-
66
- f = self.mol_enc_model.GNN(g, g.ndata['h'])
67
 
68
- mol_enc = self.mol_enc_model(g, fp=fp)
 
 
69
 
70
  return spec_enc, mol_enc
71
 
@@ -91,16 +82,7 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
91
  if stage == Stage.TEST:
92
  return dict(spec_enc=spec_enc, mol_enc=mol_enc)
93
 
94
- # Aux tasks
95
  output = {}
96
- if self.hparams.pred_fp:
97
- output['fp'] = self.fp_pred_model(mol_enc)
98
-
99
- if self.hparams.use_cons_spec:
100
- spec = batch[self.spec_view]
101
- n_peaks = batch['n_peaks'] if 'n_peaks' in batch else None
102
- output['ind_spec'] = self.spec_enc_model(spec, n_peaks)
103
-
104
  # Calculate loss
105
  losses = self.compute_loss(batch, spec_enc, mol_enc, output)
106
 
@@ -174,6 +156,9 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
174
 
175
  if not self.df_test_path:
176
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
 
 
 
177
  self.df_test.to_pickle(self.df_test_path)
178
 
179
  def get_checkpoint_monitors(self) -> T.List[dict]:
 
30
  self.save_hyperparameters()
31
  self.external_test = external_test
32
 
 
 
 
 
 
 
33
  self.spec_enc_model = model_utils.get_spec_encoder(self.hparams.spec_enc, self.hparams)
34
  self.mol_enc_model = model_utils.get_mol_encoder(self.hparams.mol_enc, self.hparams)
35
 
 
53
  if g is None:
54
  mol_enc = None
55
  return spec_enc, mol_enc
 
 
 
 
 
56
 
57
+ # Match historical call pattern (some DGL/dgllife paths mutate graph state in-place).
58
+ _ = self.mol_enc_model.GNN(g, g.ndata["h"])
59
+ mol_enc = self.mol_enc_model(g)
60
 
61
  return spec_enc, mol_enc
62
 
 
82
  if stage == Stage.TEST:
83
  return dict(spec_enc=spec_enc, mol_enc=mol_enc)
84
 
 
85
  output = {}
 
 
 
 
 
 
 
 
86
  # Calculate loss
87
  losses = self.compute_loss(batch, spec_enc, mol_enc, output)
88
 
 
156
 
157
  if not self.df_test_path:
158
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
159
+
160
+
161
+ os.makedirs(os.path.dirname(self.df_test_path), exist_ok=True)
162
  self.df_test.to_pickle(self.df_test_path)
163
 
164
  def get_checkpoint_monitors(self) -> T.List[dict]:
flare/models/mol_encoder.py CHANGED
@@ -10,10 +10,25 @@ class MolEnc(nn.Module):
10
  in_dim,):
11
  super().__init__()
12
 
13
- self.return_emb = False
14
-
15
- if args.model in ('filipContrastive', 'crossAttenContrastive', 'filipGlobalContrastive'):
16
- self.return_emb = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
19
  batchnorm = [True for _ in range(len(args.gnn_channels))]
@@ -30,7 +45,7 @@ class MolEnc(nn.Module):
30
  self.dropout = nn.Dropout(args.fc_dropout)
31
  self.relu = nn.ReLU()
32
 
33
- def forward(self, g, fp=None) -> torch.Tensor:
34
  g1 = g
35
  f1 = g.ndata['h']
36
 
@@ -38,8 +53,6 @@ class MolEnc(nn.Module):
38
  if self.return_emb:
39
  return f
40
  h = self.pool(g1, f)
41
- if fp is not None:
42
- h = torch.concat((h, fp), dim=-1)
43
  h1 = self.relu(self.fc1_graph(h))
44
  h1 = self.dropout(h1)
45
  h1 = self.fc2_graph(h1)
 
10
  in_dim,):
11
  super().__init__()
12
 
13
+ # Whether to return node-level embeddings (sum_nodes, D) instead of a
14
+ # pooled graph embedding (B, D).
15
+ #
16
+ # Backward compatible defaults:
17
+ # - Historically, this was inferred from `args.model` for FILIP-style models.
18
+ # - New: allow explicit override via `args.mol_return_emb` (preferred) or
19
+ # `args.return_mol_emb` (legacy-friendly alias).
20
+ explicit = None
21
+ if hasattr(args, "mol_return_emb"):
22
+ explicit = getattr(args, "mol_return_emb")
23
+ elif hasattr(args, "return_mol_emb"):
24
+ explicit = getattr(args, "return_mol_emb")
25
+
26
+ if explicit is not None:
27
+ self.return_emb = bool(explicit)
28
+ else:
29
+ self.return_emb = False
30
+ if args.model in ("filipContrastive", "crossAttenContrastive", "filipGlobalContrastive"):
31
+ self.return_emb = True
32
 
33
  dropout = [args.gnn_dropout for _ in range(len(args.gnn_channels))]
34
  batchnorm = [True for _ in range(len(args.gnn_channels))]
 
45
  self.dropout = nn.Dropout(args.fc_dropout)
46
  self.relu = nn.ReLU()
47
 
48
+ def forward(self, g) -> torch.Tensor:
49
  g1 = g
50
  f1 = g.ndata['h']
51
 
 
53
  if self.return_emb:
54
  return f
55
  h = self.pool(g1, f)
 
 
56
  h1 = self.relu(self.fc1_graph(h))
57
  h1 = self.dropout(h1)
58
  h1 = self.fc2_graph(h1)