yzhouchen001 committed on
Commit
f695c70
·
1 Parent(s): 23706c7
README.md CHANGED
@@ -1,2 +1,9 @@
1
  # FILIP-MS
2
  FILIP contrastive learning for metabolite annotation
 
 
 
 
 
 
 
 
1
  # FILIP-MS
2
  FILIP contrastive learning for metabolite annotation
3
+
4
+ ## Magma preprocessing
5
+ `python run_magma.py --data_pth '/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv' --output_dir '/data/yzhouc01/FILIP-MS/data/magma/' --workers 50`
6
+
7
+ ### Test candidates
8
+
9
+ `python run_fragmentation_only.py --d '/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv' -c '/r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json' -o '/data/yzhouc01/FILIP-MS/data/msgym_mass_cands.pkl' -w 50`
mvp/data/datasets.py CHANGED
@@ -203,10 +203,10 @@ class MassSpecDataset_PeakFormulas(JESTR1_MassSpecDataset):
203
  return item
204
 
205
  def _load_id_to_spec(self, stage):
206
- if stage == Stage.TRAIN:
207
- self.metadata = self.metadata[self.metadata['fold'] != Stage.TEST.value]
208
- else:
209
- self.metadata = self.metadata[self.metadata['fold'] == Stage.TEST.value]
210
 
211
  all_spec_ids = self.metadata['identifier'].tolist()
212
  self.subformulaLoader = data_utils.Subformula_Loader(spectra_view=self.spectra_view, dir_path=self.subformula_dir_pth, formula_source=self.formula_source)
@@ -369,6 +369,18 @@ class ExpandedRetrievalDataset:
369
  self.candidates[s] = [c for c in cand if '.' not in c]
370
 
371
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
 
 
 
 
 
 
 
 
 
 
 
 
372
  test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
373
  test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
374
 
@@ -376,6 +388,7 @@ class ExpandedRetrievalDataset:
376
 
377
  for spec_id, s in zip(test_ms_id, test_smiles):
378
  candidates = self.candidates[s]
 
379
  # mol_label = self.mol_label_transform(s)
380
  # labels = [self.mol_label_transform(c) == mol_label for c in candidates]
381
  labels = [c == s for c in candidates]
@@ -383,8 +396,8 @@ class ExpandedRetrievalDataset:
383
  print(f"Skipping {spec_id}; empty candidate set")
384
  continue
385
  if not any(labels):
386
- print(f"Target smiles not in candidate set")
387
-
388
 
389
  self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
390
 
 
203
  return item
204
 
205
  def _load_id_to_spec(self, stage):
206
+ # if stage == Stage.TRAIN:
207
+ # self.metadata = self.metadata[self.metadata['fold'] != Stage.TEST.value]
208
+ # else:
209
+ # self.metadata = self.metadata[self.metadata['fold'] == Stage.TEST.value]
210
 
211
  all_spec_ids = self.metadata['identifier'].tolist()
212
  self.subformulaLoader = data_utils.Subformula_Loader(spectra_view=self.spectra_view, dir_path=self.subformula_dir_pth, formula_source=self.formula_source)
 
369
  self.candidates[s] = [c for c in cand if '.' not in c]
370
 
371
  self.spec_cand = [] #(spec index, cand_smiles, true_label)
372
+
373
+ # Used for external datasets where the target SMILES is not known.
374
+ # self.candidates should be a dict mapping each identifier (as a string) to its candidates
375
+ if 'smiles' not in self.metadata.columns:
376
+ if not isinstance(self.metadata.iloc[0]['identifier'], str):
377
+ self.metadata['smiles'] = self.metadata['identifier'].apply(str)
378
+ else:
379
+ self.metadata['smiles'] = self.metadata['identifier']
380
+
381
+ # keep datapoints where there are candidates
382
+ self.metadata = self.metadata[self.metadata['smiles'].isin(self.candidates.keys())]
383
+
384
  test_smiles = self.metadata[self.metadata['fold'] == "test"]['smiles'].tolist()
385
  test_ms_id = self.metadata[self.metadata['fold'] == "test"]['identifier'].tolist()
386
 
 
388
 
389
  for spec_id, s in zip(test_ms_id, test_smiles):
390
  candidates = self.candidates[s]
391
+
392
  # mol_label = self.mol_label_transform(s)
393
  # labels = [self.mol_label_transform(c) == mol_label for c in candidates]
394
  labels = [c == s for c in candidates]
 
396
  print(f"Skipping {spec_id}; empty candidate set")
397
  continue
398
  if not any(labels):
399
+ # print(f"Target smiles not in candidate set")
400
+ pass
401
 
402
  self.spec_cand.extend([(self.spec_id_to_index[spec_id], candidates[j], k) for j, k in enumerate(labels)])
403
 
mvp/models/contrastive.py CHANGED
@@ -245,6 +245,11 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
245
  self.result_dct[i]['candidates'].extend(cands)
246
  self.result_dct[i]['scores'].extend(scores.cpu().tolist())
247
  self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
 
 
 
 
 
248
 
249
  def _compute_rank(self, scores, labels):
250
  if not any(labels):
@@ -262,7 +267,6 @@ class ContrastiveModel(RetrievalMassSpecGymModel):
262
  self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
263
  if not self.df_test_path:
264
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
265
- # self.df_test_path.parent.mkdir(parents=True, exist_ok=True)
266
  self.df_test.to_pickle(self.df_test_path)
267
 
268
  def get_checkpoint_monitors(self) -> T.List[dict]:
 
245
  self.result_dct[i]['candidates'].extend(cands)
246
  self.result_dct[i]['scores'].extend(scores.cpu().tolist())
247
  self.result_dct[i]['labels'].extend([x.cpu().item() for x in l])
248
+
249
+ # # external test case only
250
+ # for i, cands, scores in zip(outputs['identifiers'], outputs['cand_smiles'], outputs['scores']):
251
+ # self.result_dct[i.cpu().item()]['candidates'].extend(cands)
252
+ # self.result_dct[i.cpu().item()]['scores'].extend(scores.cpu().tolist())
253
 
254
  def _compute_rank(self, scores, labels):
255
  if not any(labels):
 
267
  self.df_test['rank'] = self.df_test.apply(lambda row: self._compute_rank(row['scores'], row['labels']), axis=1)
268
  if not self.df_test_path:
269
  self.df_test_path = os.path.join(self.hparams['experiment_dir'], 'result.pkl')
 
270
  self.df_test.to_pickle(self.df_test_path)
271
 
272
  def get_checkpoint_monitors(self) -> T.List[dict]:
mvp/params_formSpec.yaml CHANGED
@@ -1,6 +1,6 @@
1
  # Experiment setup
2
  job_key: ''
3
- run_name: 'optimized_filip-model'
4
  run_details: ""
5
  project_name: ''
6
  wandb_entity_name: 'mass-spec-ml'
@@ -10,7 +10,7 @@ debug: False
10
  checkpoint_pth:
11
 
12
  # Training setup
13
- max_epochs: 2000
14
  accelerator: 'gpu'
15
  devices: [1]
16
  log_every_n_steps: 250
@@ -19,7 +19,7 @@ val_check_interval: 1.0
19
  # Data paths
20
  candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
  dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
- subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
  split_pth:
24
  fp_dir_pth:
25
  cons_spec_dir_pth:
@@ -110,7 +110,7 @@ formula_dropout: 0.2
110
  formula_dims: [512, 256, 512] #[64, 128, 256]
111
  cross_attn_heads: 2
112
  use_cls: False
113
- peak_dropout: 0.414425691950033 # 0.2
114
  formula_attn_heads: 4 # 2
115
  formula_transformer_layers: 2
116
 
 
1
  # Experiment setup
2
  job_key: ''
3
+ run_name: 'simple_model'
4
  run_details: ""
5
  project_name: ''
6
  wandb_entity_name: 'mass-spec-ml'
 
10
  checkpoint_pth:
11
 
12
  # Training setup
13
+ max_epochs: 1500
14
  accelerator: 'gpu'
15
  devices: [1]
16
  log_every_n_steps: 250
 
19
  # Data paths
20
  candidates_pth: /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_mass.json # "../data/MassSpecGym/data/molecules/MassSpecGym_retrieval_candidates_formula.json"
21
  dataset_pth: /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # /data/yzhouc01/MVP/data/sample/data.tsv #/r/hassounlab/spectra_data/msgym/MassSpecGym.tsv #/data/yzhouc01/spectra_data/combined_msgym_nist23_multiplex.tsv # /r/hassounlab/spectra_data/msgym/MassSpecGym.tsv # "../data/MassSpecGym/data/sample_data.tsv"
22
+ subformula_dir_pth: /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default # /data/yzhouc01/FILIP-MS/data/magma # /r/hassounlab/msgym_sirius # /data/yzhouc01/MVP/data/MassSpecGym/data/subformulae_default #/data/yzhouc01/spectra_data/subformulae #"../data/MassSpecGym/data/subformulae_default"
23
  split_pth:
24
  fp_dir_pth:
25
  cons_spec_dir_pth:
 
110
  formula_dims: [512, 256, 512] #[64, 128, 256]
111
  cross_attn_heads: 2
112
  use_cls: False
113
+ peak_dropout: 0.2
114
  formula_attn_heads: 4 # 2
115
  formula_transformer_layers: 2
116
 
mvp/run.sh CHANGED
@@ -1,3 +1,3 @@
1
- python train.py --param_pth params_tmp.yaml
2
- python test.py --param_pth params_tmp.yaml
3
- python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json --param_pth params_tmp.yaml
 
1
+ python train.py
2
+ python test.py
3
+ python test.py --candidates_pth /r/hassounlab/spectra_data/msgym/molecules/MassSpecGym_retrieval_candidates_formula.json
mvp/subformula_assign/__init__.py ADDED
File without changes
mvp/tune.py CHANGED
@@ -116,19 +116,19 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
116
 
117
  try:
118
  # Training-related params
119
- params["batch_size"] = trial.suggest_categorical("batch_size", [64, 128])
120
  params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
121
  params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
122
- params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.02, 0.1)
123
 
124
  # Spectra encoder-related params
125
- params['peak_dropout'] = trial.suggest_float("peak_dropout", 0.1, 0.5)
126
- params['formula_attn_heads'] = trial.suggest_categorical("formula_attn_heads", [2, 4])
127
- params['formula_transformer_layers'] = trial.suggest_categorical("formula_transformer_layers", [2, 4])
128
 
129
  choice = trial.suggest_categorical(
130
  "formula_dims",
131
- ["64,128", "512,256", "256,512", "128", "256"]
132
  )
133
  params["formula_dims"] = [int(x) for x in choice.split(",")]
134
 
@@ -136,12 +136,12 @@ def objective(trial: optuna.Trial, base_params, trial_times, base_dir, total_tri
136
  params['gnn_dropout'] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
137
  choice = trial.suggest_categorical(
138
  "gnn_channels",
139
- ["64,128", "128,256", "256,512", "64,128,128"]
140
  )
141
  params["gnn_channels"] = [int(x) for x in choice.split(",")]
142
 
143
  # Ensure last layer matches final embedding dim
144
- final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [256, 512])
145
  params['formula_dims'].append(final_embedding_dim)
146
  params['gnn_channels'].append(final_embedding_dim)
147
 
@@ -229,8 +229,9 @@ def main(args):
229
  with open(args.param_pth) as f:
230
  params = yaml.load(f, Loader=yaml.FullLoader)
231
 
232
- now = datetime.datetime.now().strftime("%Y%m%d")
233
- base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
 
234
  os.makedirs(base_dir, exist_ok=True)
235
  params["experiment_dir"] = base_dir
236
 
@@ -239,8 +240,10 @@ def main(args):
239
  setup_logging(log_path)
240
 
241
  trial_times = []
 
 
242
 
243
- study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
244
  study.optimize(lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials), n_trials=args.n_trials)
245
 
246
  # Print best trial
 
116
 
117
  try:
118
  # Training-related params
119
+ params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
120
  params["lr"] = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
121
  params["weight_decay"] = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)
122
+ params["contr_temp"] = trial.suggest_float("contrastive_temp", 0.01, 0.1)
123
 
124
  # Spectra encoder-related params
125
+ params['formula_dropout'] = trial.suggest_float("peak_dropout", 0.1, 0.5)
126
+ params['formula_attn_heads'] = trial.suggest_categorical("formula_attn_heads", [2, 4, 8])
127
+ params['formula_transformer_layers'] = trial.suggest_categorical("formula_transformer_layers", [1,2,3,4,5])
128
 
129
  choice = trial.suggest_categorical(
130
  "formula_dims",
131
+ ["64,128", "512,256", "256,512", "128", "256", "128,128", "512,512", "64,64,64,64"]
132
  )
133
  params["formula_dims"] = [int(x) for x in choice.split(",")]
134
 
 
136
  params['gnn_dropout'] = trial.suggest_float("gnn_dropout", 0.1, 0.5)
137
  choice = trial.suggest_categorical(
138
  "gnn_channels",
139
+ ["64,128", "128,256", "256,512", "64,128,128", "128,128", "64,64,64"]
140
  )
141
  params["gnn_channels"] = [int(x) for x in choice.split(",")]
142
 
143
  # Ensure last layer matches final embedding dim
144
+ final_embedding_dim = trial.suggest_categorical("final_embedding_dim", [64,256,512,1024])
145
  params['formula_dims'].append(final_embedding_dim)
146
  params['gnn_channels'].append(final_embedding_dim)
147
 
 
229
  with open(args.param_pth) as f:
230
  params = yaml.load(f, Loader=yaml.FullLoader)
231
 
232
+ # now = datetime.datetime.now().strftime("%Y%m%d")
233
+ # base_dir = str(TEST_RESULTS_DIR / f"{now}_{params['run_name']}_optuna")
234
+ base_dir = "/data/yzhouc01/FILIP-MS/experiments/20250916_simple_model_optuna"
235
  os.makedirs(base_dir, exist_ok=True)
236
  params["experiment_dir"] = base_dir
237
 
 
240
  setup_logging(log_path)
241
 
242
  trial_times = []
243
+ study_name = "filip_contrastive"
244
+ storage = f"sqlite:///{base_dir}/optuna_study.db"
245
 
246
+ study = optuna.create_study(study_name=study_name, storage=storage, direction="minimize", pruner=optuna.pruners.MedianPruner(), load_if_exists=True)
247
  study.optimize(lambda trial: objective(trial, params, trial_times, base_dir, args.n_trials), n_trials=args.n_trials)
248
 
249
  # Print best trial