Defetya committed
Commit 0896173 · verified · 1 Parent(s): e426db9

Upload moleculenet_eval/eval.py with huggingface_hub

Files changed (1)
  1. moleculenet_eval/eval.py +116 -16
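The commit message indicates the file was pushed programmatically. For reference, a minimal sketch of the equivalent huggingface_hub call (the repo id placeholder and token handling are assumptions, not part of this commit):

    from huggingface_hub import HfApi

    api = HfApi()  # assumes a token is available, e.g. via the HF_TOKEN env var
    api.upload_file(
        path_or_fileobj="moleculenet_eval/eval.py",  # local file to push
        path_in_repo="moleculenet_eval/eval.py",     # destination path in the repo
        repo_id="Defetya/<repo-name>",               # hypothetical repo id
        commit_message="Upload moleculenet_eval/eval.py with huggingface_hub",
    )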
moleculenet_eval/eval.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
 from transformers import BertConfig, BertModel, AutoTokenizer
-from rdkit import Chem
+from rdkit import Chem, RDLogger
 from rdkit.Chem.Scaffolds import MurckoScaffold
 import copy
 from tqdm import tqdm

@@ -13,9 +13,68 @@ import os
 from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error
 from itertools import compress
 from collections import defaultdict
+from sklearn.metrics.pairwise import cosine_similarity
+RDLogger.DisableLog('rdApp.*')
+
 
 torch.set_float32_matmul_precision('high')
 
+# --- 0. SMILES Enumeration ---
+class SmilesEnumerator:
+    """Generates randomized SMILES strings for data augmentation."""
+    def randomize_smiles(self, smiles):
+        try:
+            mol = Chem.MolFromSmiles(smiles)
+            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
+        except Exception:
+            return smiles
+
+
+def compute_embedding_similarity(encoder, smiles_list, tokenizer, device, max_len=256):
+    encoder.eval()
+    enumerator = SmilesEnumerator()
+
+    embeddings_orig = []
+    embeddings_aug = []
+
+    with torch.no_grad():
+        for smi in smiles_list:
+            # Original SMILES encoding
+            encoding_orig = tokenizer(
+                smi,
+                truncation=True,
+                padding='max_length',
+                max_length=max_len,
+                return_tensors='pt'
+            )
+            # Augmented SMILES encoding
+            smi_aug = enumerator.randomize_smiles(smi)
+            encoding_aug = tokenizer(
+                smi_aug,
+                truncation=True,
+                padding='max_length',
+                max_length=max_len,
+                return_tensors='pt'
+            )
+
+            input_ids_orig = encoding_orig.input_ids.to(device)
+            attention_mask_orig = encoding_orig.attention_mask.to(device)
+            input_ids_aug = encoding_aug.input_ids.to(device)
+            attention_mask_aug = encoding_aug.attention_mask.to(device)
+
+            emb_orig = encoder(input_ids_orig, attention_mask_orig).cpu().numpy().flatten()
+            emb_aug = encoder(input_ids_aug, attention_mask_aug).cpu().numpy().flatten()
+
+            embeddings_orig.append(emb_orig)
+            embeddings_aug.append(emb_aug)
+
+    embeddings_orig = np.array(embeddings_orig)
+    embeddings_aug = np.array(embeddings_aug)
+
+    # Cosine similarity between each original and its augmented version
+    similarities = np.array([cosine_similarity([embeddings_orig[i]], [embeddings_aug[i]])[0][0] for i in range(len(embeddings_orig))])
+    return similarities
+
 # --- 1. Data Loading ---
 def load_lists_from_url(data):
     if data == 'bbbp':
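The SmilesEnumerator added above relies on RDKit's ability to write the same molecule with a random atom ordering. A quick self-contained check (the aspirin SMILES is only an illustrative input):

    from rdkit import Chem

    smiles = 'CC(=O)Oc1ccccc1C(=O)O'  # aspirin, illustrative input
    mol = Chem.MolFromSmiles(smiles)
    for _ in range(3):
        # doRandom=True picks a random atom ordering, so the string varies...
        rand = Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        # ...but every variant canonicalizes back to the same molecule
        print(rand, Chem.CanonSmiles(rand) == Chem.CanonSmiles(smiles))

Each run prints different strings, with True for every variant; this string-level invariance is exactly what compute_embedding_similarity probes at the embedding level.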
@@ -207,7 +266,7 @@ def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
         loss = criterion(outputs, labels)
         loss.backward()
         optimizer.step()
-        scheduler.step()
+        #scheduler.step()
         total_loss += loss.item()
     return total_loss / len(dataloader)

@@ -236,6 +295,38 @@ def test_model(model, dataloader, device):
         all_labels.append(labels.numpy())
     return np.concatenate(all_preds), np.concatenate(all_labels)
 
+def calc_val_metrics(model, dataloader, criterion, device, task_type):
+    model.eval()
+    all_labels, all_preds = [], []
+    total_loss = 0
+    with torch.no_grad():
+        for batch in dataloader:
+            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
+            labels = batch['labels'].to(device)
+            outputs = model(**inputs)
+            loss = criterion(outputs, labels)
+            total_loss += loss.item()
+            if task_type == 'classification':
+                pred_probs = torch.sigmoid(outputs).cpu().numpy()
+                all_preds.append(pred_probs)
+                all_labels.append(labels.cpu().numpy())
+            else:
+                # Regression
+                preds = outputs.cpu().numpy()
+                all_preds.append(preds)
+                all_labels.append(labels.cpu().numpy())
+    avg_loss = total_loss / len(dataloader)
+    if task_type == 'classification':
+        y_true = np.concatenate(all_labels)
+        y_pred = np.concatenate(all_preds)
+        try:
+            score = roc_auc_score(y_true, y_pred, average='macro')
+        except Exception:
+            score = 0.0
+        return avg_loss, score
+    else:
+        return avg_loss, None
+
 # --- 6. Main Execution Block ---
 def main():
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
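The new calc_val_metrics scores classification tasks with a macro-averaged ROC AUC: roc_auc_score computes one AUC per label column and returns their unweighted mean, which is why it handles multi-label datasets such as ClinTox (2 labels). A small worked check with made-up arrays:

    import numpy as np
    from sklearn.metrics import roc_auc_score

    y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])  # 4 samples, 2 labels (made up)
    y_prob = np.array([[0.9, 0.2], [0.3, 0.8], [0.7, 0.3], [0.1, 0.4]])
    per_label = [roc_auc_score(y_true[:, j], y_prob[:, j]) for j in range(2)]  # [1.0, 0.75]
    macro = roc_auc_score(y_true, y_prob, average='macro')                     # 0.875
    assert np.isclose(macro, np.mean(per_label))

The try/except around the call matters because a validation split can leave a label column with only one class present, in which case roc_auc_score raises.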
@@ -244,16 +335,17 @@ def main():
     DATASETS_TO_RUN = {
         # 'esol': {'task_type': 'regression', 'num_labels': 1, 'split': 'random'},
         #'tox21': {'task_type': 'classification', 'num_labels': 12, 'split': 'random'},
-        #'hiv': {'task_type': 'classification', 'num_labels': 27, 'split': 'scaffold'},
+        #'hiv': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'},
         # Add more datasets here, e.g. 'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
         #'sider': {'task_type': 'classification', 'num_labels': 27, 'split': 'random'},
         #'bace': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
-        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'scaffold'}
+        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'random'},
+        #'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'}
     }
     PATIENCE = 15
-    EPOCHS = 100
-    LEARNING_RATE = 2e-5
-    BATCH_SIZE = 128
+    EPOCHS = 50
+    LEARNING_RATE = 1e-4
+    BATCH_SIZE = 16
     MAX_LEN = 512
 
     TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

@@ -302,18 +394,18 @@ def main():
         model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
         model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
         criterion = get_criterion(info['task_type'], info['num_labels'])
-        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
-        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(train_loader))
+        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0024)
+        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)
 
-        best_val_loss = float('inf')
+        best_val_loss = float('-inf')
         best_model_state = None
         current_patience = 0
         for epoch in range(EPOCHS):
             train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
-            val_loss = eval_epoch(model, val_loader, criterion, DEVICE)
-            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
+            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, 'cuda', info['task_type'])
+            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {val_metric:.4f}")
 
-            if val_loss < best_val_loss:
+            if val_metric <= val_loss:
                 best_val_loss = val_loss
                 best_model_state = copy.deepcopy(model.state_dict())
                 print(f" -> New best model saved with validation loss: {best_val_loss:.4f}")

@@ -325,7 +417,8 @@ def main():
                 break
 
         print("\nTesting with the best model...")
-        model.load_state_dict(best_model_state)
+        if best_model_state is not None:
+            model.load_state_dict(best_model_state)
         test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
         print(f'Test loss: {test_loss}')
         test_preds, test_true = test_model(model, test_loader, DEVICE)

@@ -336,6 +429,15 @@ def main():
             'test_labels': test_true
         }
         print(f"Finished testing for {name}.")
+        test_smiles_list = list(test_smiles)
+        similarities = compute_embedding_similarity(
+            model.encoder, test_smiles_list, TOKENIZER, DEVICE, MAX_LEN
+        )
+        print(f"Similarity score: {similarities.mean():.4f}")
+        if name == 'do_not_save':
+            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')
+
+
 
     print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
     for name, result in aggregated_results.items():

@@ -352,6 +454,4 @@ def main():
     print("\nScript finished.")
 
 if __name__ == '__main__':
-    # Note: This script requires rdkit. You can install it via pip:
-    # pip install rdkit-pypi
     main()
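With the scheduler.step() call inside train_epoch commented out, the new StepLR schedule only decays if it is stepped elsewhere; a sketch of the per-epoch decay under that assumption (the dummy parameter is illustrative):

    import torch
    import torch.optim as optim

    params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter
    optimizer = optim.Adam(params, lr=1e-4, weight_decay=0.0024)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)
    for epoch in range(50):
        # ... train_epoch(...) would run here ...
        scheduler.step()  # once per epoch
        if (epoch + 1) % 10 == 0:
            # lr = 1e-4 * 0.59298 ** ((epoch + 1) // 10)
            print(epoch + 1, optimizer.param_groups[0]['lr'])

Over the 50 epochs this takes the learning rate from 1e-4 down to roughly 7.3e-6 (1e-4 * 0.59298^5).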
 
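Finally, compute_embedding_similarity only assumes that encoder(input_ids, attention_mask) returns one embedding per input, so it can be smoke-tested without the SimSon checkpoint. A toy sketch (the mean-pooled, randomly initialised BERT below is illustrative, not the repo's encoder; the small vocab_size comfortably covers the ChemBERTa tokenizer's ids):

    import torch
    from transformers import AutoTokenizer, BertConfig, BertModel

    class ToyEncoder(torch.nn.Module):
        # Illustrative stand-in: mean-pools a tiny randomly initialised BERT.
        def __init__(self):
            super().__init__()
            config = BertConfig(vocab_size=4096, hidden_size=64, num_hidden_layers=2,
                                num_attention_heads=2, intermediate_size=128)
            self.bert = BertModel(config)

        def forward(self, input_ids, attention_mask):
            out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            mask = attention_mask.unsqueeze(-1)
            # Mean-pool token states over non-padding positions
            return (out.last_hidden_state * mask).sum(1) / mask.sum(1)

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    sims = compute_embedding_similarity(ToyEncoder(), ['CCO', 'c1ccccc1O'],
                                        tokenizer, 'cpu', max_len=64)
    print(sims.mean())  # a mean near 1.0 would indicate augmentation-invariant embeddings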