Defetya committed on
Commit
e426db9
·
verified ·
1 Parent(s): aea539f

Upload moleculenet_eval/eval.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. moleculenet_eval/eval.py +46 -88
moleculenet_eval/eval.py CHANGED
@@ -17,11 +17,7 @@ from collections import defaultdict
17
  torch.set_float32_matmul_precision('high')
18
 
19
  # --- 1. Data Loading ---
20
- # Function to load datasets from their respective URLs.
21
  def load_lists_from_url(data):
22
- """
23
- Load SMILES and labels from Moleculenet website.
24
- """
25
  if data == 'bbbp':
26
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
27
  smiles, labels = df.smiles, df.p_np
@@ -35,7 +31,7 @@ def load_lists_from_url(data):
35
  elif data == 'sider':
36
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
37
  smiles = df.smiles
38
- labels = df.drop(['smiles'], axis=1) # (1427, 27)
39
  elif data == 'esol':
40
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
41
  smiles = df.smiles
@@ -49,27 +45,20 @@ def load_lists_from_url(data):
49
  smiles, labels = df.smiles, df['exp']
50
  elif data == 'tox21':
51
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
52
- df = df.dropna(axis=0, how='any').reset_index(drop=True) # drop nan values
53
  smiles = df.smiles
54
- labels = df.drop(['mol_id', 'smiles'], axis=1) # 12 cols
55
  elif data == 'bace':
56
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
57
  smiles, labels = df.mol, df.Class
58
- elif data == 'tox21':
59
- df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
60
- df = df.dropna(axis=0, how='any').reset_index(drop=True) # drop nan values
61
- smiles = df.smiles
62
- labels = df.drop(['mol_id', 'smiles'], axis=1) # 12 cols
63
  elif data == 'qm8':
64
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
65
- df = df.dropna(axis=0, how='any').reset_index(drop=True) # drop nan values
66
  smiles = df.smiles
67
- labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1) # 12 tasks
68
-
69
  return smiles, labels
70
 
71
  # --- 2. Scaffold Splitting ---
72
- # Class to split the dataset based on molecular scaffolds.
73
  class ScaffoldSplitter:
74
  def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
75
  self.data = data
@@ -86,28 +75,20 @@ class ScaffoldSplitter:
86
 
87
  def scaffold_split(self):
88
  smiles, labels = load_lists_from_url(self.data)
89
-
90
- # Initialize non_null as False for all samples
91
  non_null = np.ones(len(smiles)) == 0
92
 
93
- # Dataset-specific null handling
94
- if self.data == 'tox21' or self.data == 'sider' or self.data == 'clintox':
95
  for i in range(len(smiles)):
96
- # Check if molecule is valid AND no missing labels
97
  if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
98
  non_null[i] = 1
99
  else:
100
- # For single-task datasets, only check molecule validity
101
  for i in range(len(smiles)):
102
  if Chem.MolFromSmiles(smiles[i]):
103
  non_null[i] = 1
104
 
105
- # Extract valid samples with original indices preserved
106
  smiles_list = list(compress(enumerate(smiles), non_null))
107
-
108
  rng = np.random.RandomState(self.seed)
109
 
110
- # Group by scaffold
111
  scaffolds = defaultdict(list)
112
  for i, sms in smiles_list:
113
  scaffold = self.generate_scaffold(sms)
@@ -115,13 +96,10 @@ class ScaffoldSplitter:
115
 
116
  scaffold_sets = list(scaffolds.values())
117
  rng.shuffle(scaffold_sets)
118
- # Calculate target sizes for validation and test sets
119
  n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
120
  n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
121
-
122
  train_idx, val_idx, test_idx = [], [], []
123
 
124
- # Assign scaffold groups to splits
125
  for scaffold_set in scaffold_sets:
126
  if len(val_idx) + len(scaffold_set) <= n_total_val:
127
  val_idx.extend(scaffold_set)
@@ -129,10 +107,20 @@ class ScaffoldSplitter:
129
  test_idx.extend(scaffold_set)
130
  else:
131
  train_idx.extend(scaffold_set)
132
-
133
  return train_idx, val_idx, test_idx
 
 
 
 
 
 
 
 
 
 
 
 
134
  # --- 3. PyTorch Dataset ---
135
- # Custom Dataset class for handling SMILES data.
136
  class MoleculeDataset(Dataset):
137
  def __init__(self, smiles_list, labels, tokenizer, max_len=512):
138
  self.smiles_list = smiles_list
@@ -154,25 +142,16 @@ class MoleculeDataset(Dataset):
154
  max_length=self.max_len,
155
  return_tensors='pt'
156
  )
157
-
158
  item = {key: val.squeeze(0) for key, val in encoding.items()}
159
-
160
- # Handle single-task and multi-task labels
161
  if isinstance(label, pd.Series):
162
  label_values = label.values.astype(np.float32)
163
  else:
164
  label_values = np.array([label], dtype=np.float32)
165
-
166
  item['labels'] = torch.tensor(label_values, dtype=torch.float)
167
  return item
168
 
169
  # --- 4. Model Architecture ---
170
  def global_ap(x):
171
- """
172
- Global Average Pooling
173
- Input: [B, max_len, hid_dim]
174
- Return: [B, hid_dim]
175
- """
176
  return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)
177
 
178
  class SimSonEncoder(nn.Module):
@@ -183,7 +162,6 @@ class SimSonEncoder(nn.Module):
183
  self.bert = BertModel(config, add_pooling_layer=False)
184
  self.linear = nn.Linear(config.hidden_size, max_len)
185
  self.dropout = nn.Dropout(dropout)
186
-
187
  def forward(self, input_ids, attention_mask=None):
188
  if attention_mask is None:
189
  attention_mask = input_ids.ne(self.config.pad_token_id)
@@ -199,7 +177,6 @@ class SimSonClassifier(nn.Module):
199
  self.clf = nn.Linear(encoder.max_len, num_labels)
200
  self.relu = nn.ReLU()
201
  self.dropout = nn.Dropout(dropout)
202
-
203
  def forward(self, input_ids, attention_mask=None):
204
  x = self.encoder(input_ids, attention_mask)
205
  x = self.relu(self.dropout(x))
@@ -207,13 +184,11 @@ class SimSonClassifier(nn.Module):
207
  return logits
208
 
209
  def load_encoder_params(self, state_dict_path):
210
- """Loads pretrained parameters into the SimSonEncoder."""
211
  self.encoder.load_state_dict(torch.load(state_dict_path))
212
  print("Pretrained encoder parameters loaded.")
213
 
214
  # --- 5. Training, Validation, and Testing Loops ---
215
  def get_criterion(task_type, num_labels):
216
- """Select loss function based on task."""
217
  if task_type == 'classification':
218
  return nn.BCEWithLogitsLoss()
219
  elif task_type == 'regression':
@@ -227,14 +202,12 @@ def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
227
  for batch in dataloader:
228
  inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
229
  labels = batch['labels'].to(device)
230
-
231
  optimizer.zero_grad()
232
  outputs = model(**inputs)
233
  loss = criterion(outputs, labels)
234
  loss.backward()
235
  optimizer.step()
236
  scheduler.step()
237
-
238
  total_loss += loss.item()
239
  return total_loss / len(dataloader)
240
 
@@ -258,40 +231,31 @@ def test_model(model, dataloader, device):
258
  inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
259
  labels = batch['labels']
260
  outputs = model(**inputs)
261
-
262
- # Apply sigmoid for classification probabilities
263
  preds = torch.sigmoid(outputs)
264
-
265
  all_preds.append(preds.cpu().numpy())
266
  all_labels.append(labels.numpy())
267
-
268
  return np.concatenate(all_preds), np.concatenate(all_labels)
269
 
270
  # --- 6. Main Execution Block ---
271
  def main():
272
- # --- Configuration ---
273
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
274
  print(f"Using device: {DEVICE}")
275
-
276
  DATASETS_TO_RUN = {
277
- #'esol': {'task_type': 'regression', 'num_labels': 1},
278
- #'freesolv': {'task_type': 'regression', 'num_labels':1},
279
- #'lipophicility': {'task_type': 'regression', 'num_labels': 1},
280
- #'qm8': {'task_type': 'regression', 'num_labels': 12},
281
- #'bbbp': {'task_type': 'classification', 'num_labels': 1},
282
- 'tox21': {'task_type': 'classification', 'num_labels': 12},
283
- #'sider': {'task_type': 'classification', 'num_labels': 27},
284
- #'clintox': {'task_type': 'classification', 'num_labels': 2},
285
- #'hiv': {'task_type': 'classification', 'num_labels': 1},
286
- #'bace': {'task_type': 'classification', 'num_labels': 1},
287
  }
288
- PATIENCE = 25
289
- EPOCHS = 200
290
  LEARNING_RATE = 2e-5
291
  BATCH_SIZE = 128
292
- MAX_LEN = 256
293
 
294
- # --- Tokenizer and Model Config ---
295
  TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
296
  ENCODER_CONFIG = BertConfig(
297
  vocab_size=TOKENIZER.vocab_size,
@@ -305,24 +269,24 @@ def main():
305
  aggregated_results = {}
306
 
307
  for name, info in DATASETS_TO_RUN.items():
308
- print(f"\n{'='*20} Processing Dataset: {name.upper()} {'='*20}")
309
-
310
- # --- Data Loading and Splitting ---
311
- splitter = ScaffoldSplitter(data=name, seed=42)
312
- train_idx, val_idx, test_idx = splitter.scaffold_split()
313
-
314
- # Load data once
315
  smiles, labels = load_lists_from_url(name)
316
-
317
- # Extract splits using returned indices
 
 
 
 
 
 
 
 
318
  train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
319
  train_labels = labels.iloc[train_idx].reset_index(drop=True)
320
-
321
  val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
322
  val_labels = labels.iloc[val_idx].reset_index(drop=True)
323
-
324
  test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
325
- test_labels = labels.iloc[test_idx].reset_index(drop=True)
326
  print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")
327
 
328
  train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
@@ -333,16 +297,14 @@ def main():
333
  val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
334
  test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
335
 
336
- # --- Model, Loss, and Optimizer ---
337
  encoder = SimSonEncoder(ENCODER_CONFIG, 512)
338
  encoder = torch.compile(encoder)
339
  model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
340
  model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
341
-
342
  criterion = get_criterion(info['task_type'], info['num_labels'])
343
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
344
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(train_loader))
345
- # --- Training and Validation ---
346
  best_val_loss = float('inf')
347
  best_model_state = None
348
  current_patience = 0
@@ -362,12 +324,12 @@ def main():
362
  print(f'Early stopping at {PATIENCE} epochs')
363
  break
364
 
365
- # --- Testing ---
366
  print("\nTesting with the best model...")
367
  model.load_state_dict(best_model_state)
 
 
368
  test_preds, test_true = test_model(model, test_loader, DEVICE)
369
-
370
- # Store results. For classification, you can now calculate metrics like ROC-AUC.
371
  aggregated_results[name] = {
372
  'best_val_loss': best_val_loss,
373
  'test_predictions': test_preds,
@@ -375,19 +337,15 @@ def main():
375
  }
376
  print(f"Finished testing for {name}.")
377
 
378
- # --- Final Results Aggregation ---
379
  print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
380
  for name, result in aggregated_results.items():
381
- # Here you would typically calculate and display final metrics from predictions
382
- # For example, using scikit-learn's roc_auc_score
383
- # from sklearn.metrics import roc_auc_score
384
  if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
385
  auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
386
  print(f'{name} ROC AUC: {auc}')
387
 
388
  if name in ['lipophicility', 'esol', 'qm8']:
389
  rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
390
- mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
391
  print(f'{name} MAE: {mae}')
392
  print(f'{name} RMSE: {rmse}')
393
 
 
17
  torch.set_float32_matmul_precision('high')
18
 
19
  # --- 1. Data Loading ---
 
20
  def load_lists_from_url(data):
 
 
 
21
  if data == 'bbbp':
22
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
23
  smiles, labels = df.smiles, df.p_np
 
31
  elif data == 'sider':
32
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
33
  smiles = df.smiles
34
+ labels = df.drop(['smiles'], axis=1)
35
  elif data == 'esol':
36
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
37
  smiles = df.smiles
 
45
  smiles, labels = df.smiles, df['exp']
46
  elif data == 'tox21':
47
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
48
+ df = df.dropna(axis=0, how='any').reset_index(drop=True)
49
  smiles = df.smiles
50
+ labels = df.drop(['mol_id', 'smiles'], axis=1)
51
  elif data == 'bace':
52
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
53
  smiles, labels = df.mol, df.Class
 
 
 
 
 
54
  elif data == 'qm8':
55
  df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
56
+ df = df.dropna(axis=0, how='any').reset_index(drop=True)
57
  smiles = df.smiles
58
+ labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
 
59
  return smiles, labels
60
 
61
  # --- 2. Scaffold Splitting ---
 
62
  class ScaffoldSplitter:
63
  def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
64
  self.data = data
 
75
 
76
  def scaffold_split(self):
77
  smiles, labels = load_lists_from_url(self.data)
 
 
78
  non_null = np.ones(len(smiles)) == 0
79
 
80
+ if self.data in {'tox21', 'sider', 'clintox'}:
 
81
  for i in range(len(smiles)):
 
82
  if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
83
  non_null[i] = 1
84
  else:
 
85
  for i in range(len(smiles)):
86
  if Chem.MolFromSmiles(smiles[i]):
87
  non_null[i] = 1
88
 
 
89
  smiles_list = list(compress(enumerate(smiles), non_null))
 
90
  rng = np.random.RandomState(self.seed)
91
 
 
92
  scaffolds = defaultdict(list)
93
  for i, sms in smiles_list:
94
  scaffold = self.generate_scaffold(sms)
 
96
 
97
  scaffold_sets = list(scaffolds.values())
98
  rng.shuffle(scaffold_sets)
 
99
  n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
100
  n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
 
101
  train_idx, val_idx, test_idx = [], [], []
102
 
 
103
  for scaffold_set in scaffold_sets:
104
  if len(val_idx) + len(scaffold_set) <= n_total_val:
105
  val_idx.extend(scaffold_set)
 
107
  test_idx.extend(scaffold_set)
108
  else:
109
  train_idx.extend(scaffold_set)
 
110
  return train_idx, val_idx, test_idx
111
+
112
+ # --- 2a. Normal Random Split ---
113
def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    """Randomly partition ``range(n)`` into train/validation/test index lists.

    Args:
        n: Number of samples to split.
        seed: Seed for the shuffle; the same seed always produces the same split.
        train_frac: Fraction of the samples placed in the training split.
        val_frac: Fraction of the samples placed in the validation split.
        test_frac: Not used in the computation. The test split is whatever
            remains after the train and validation slices are taken; the
            parameter is kept so existing callers keep working.

    Returns:
        A ``(train_idx, val_idx, test_idx)`` tuple of plain Python lists of
        integer positions. Together they cover every index in ``range(n)``
        exactly once.
    """
    # Use a private generator instead of np.random.seed(): it draws the same
    # permutation for a given seed but leaves the process-wide NumPy RNG
    # untouched, matching how ScaffoldSplitter.scaffold_split seeds its shuffle.
    rng = np.random.RandomState(seed)
    indices = rng.permutation(n)

    # int() truncates, so any rounding remainder ends up in the test split.
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx.tolist(), val_idx.tolist(), test_idx.tolist()
122
+
123
  # --- 3. PyTorch Dataset ---
 
124
  class MoleculeDataset(Dataset):
125
  def __init__(self, smiles_list, labels, tokenizer, max_len=512):
126
  self.smiles_list = smiles_list
 
142
  max_length=self.max_len,
143
  return_tensors='pt'
144
  )
 
145
  item = {key: val.squeeze(0) for key, val in encoding.items()}
 
 
146
  if isinstance(label, pd.Series):
147
  label_values = label.values.astype(np.float32)
148
  else:
149
  label_values = np.array([label], dtype=np.float32)
 
150
  item['labels'] = torch.tensor(label_values, dtype=torch.float)
151
  return item
152
 
153
  # --- 4. Model Architecture ---
154
  def global_ap(x):
 
 
 
 
 
155
  return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)
156
 
157
  class SimSonEncoder(nn.Module):
 
162
  self.bert = BertModel(config, add_pooling_layer=False)
163
  self.linear = nn.Linear(config.hidden_size, max_len)
164
  self.dropout = nn.Dropout(dropout)
 
165
  def forward(self, input_ids, attention_mask=None):
166
  if attention_mask is None:
167
  attention_mask = input_ids.ne(self.config.pad_token_id)
 
177
  self.clf = nn.Linear(encoder.max_len, num_labels)
178
  self.relu = nn.ReLU()
179
  self.dropout = nn.Dropout(dropout)
 
180
  def forward(self, input_ids, attention_mask=None):
181
  x = self.encoder(input_ids, attention_mask)
182
  x = self.relu(self.dropout(x))
 
184
  return logits
185
 
186
  def load_encoder_params(self, state_dict_path):
 
187
  self.encoder.load_state_dict(torch.load(state_dict_path))
188
  print("Pretrained encoder parameters loaded.")
189
 
190
  # --- 5. Training, Validation, and Testing Loops ---
191
  def get_criterion(task_type, num_labels):
 
192
  if task_type == 'classification':
193
  return nn.BCEWithLogitsLoss()
194
  elif task_type == 'regression':
 
202
  for batch in dataloader:
203
  inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
204
  labels = batch['labels'].to(device)
 
205
  optimizer.zero_grad()
206
  outputs = model(**inputs)
207
  loss = criterion(outputs, labels)
208
  loss.backward()
209
  optimizer.step()
210
  scheduler.step()
 
211
  total_loss += loss.item()
212
  return total_loss / len(dataloader)
213
 
 
231
  inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
232
  labels = batch['labels']
233
  outputs = model(**inputs)
 
 
234
  preds = torch.sigmoid(outputs)
 
235
  all_preds.append(preds.cpu().numpy())
236
  all_labels.append(labels.numpy())
 
237
  return np.concatenate(all_preds), np.concatenate(all_labels)
238
 
239
  # --- 6. Main Execution Block ---
240
  def main():
 
241
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
242
  print(f"Using device: {DEVICE}")
243
+
244
  DATASETS_TO_RUN = {
245
+ # 'esol': {'task_type': 'regression', 'num_labels': 1, 'split': 'random'},
246
+ #'tox21': {'task_type': 'classification', 'num_labels': 12, 'split': 'random'},
247
+ #'hiv': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'},
248
+ # Add more datasets here, e.g. 'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
249
+ #'sider': {'task_type': 'classification', 'num_labels': 27, 'split': 'random'},
250
+ #'bace': {'task_type': 'classification', 'num_labels': 1, 'split': 'random'},
251
+ 'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'scaffold'}
 
 
 
252
  }
253
+ PATIENCE = 15
254
+ EPOCHS = 100
255
  LEARNING_RATE = 2e-5
256
  BATCH_SIZE = 128
257
+ MAX_LEN = 512
258
 
 
259
  TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
260
  ENCODER_CONFIG = BertConfig(
261
  vocab_size=TOKENIZER.vocab_size,
 
269
  aggregated_results = {}
270
 
271
  for name, info in DATASETS_TO_RUN.items():
272
+ print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
 
 
 
 
 
 
273
  smiles, labels = load_lists_from_url(name)
274
+
275
+ # Split selection
276
+ if info.get('split', 'scaffold') == 'scaffold':
277
+ splitter = ScaffoldSplitter(data=name, seed=42)
278
+ train_idx, val_idx, test_idx = splitter.scaffold_split()
279
+ elif info['split'] == 'random':
280
+ train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
281
+ else:
282
+ raise ValueError(f"Unknown split type for {name}: {info['split']}")
283
+
284
  train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
285
  train_labels = labels.iloc[train_idx].reset_index(drop=True)
 
286
  val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
287
  val_labels = labels.iloc[val_idx].reset_index(drop=True)
 
288
  test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
289
+ test_labels = labels.iloc[test_idx].reset_index(drop=True)
290
  print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")
291
 
292
  train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
 
297
  val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
298
  test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
299
 
 
300
  encoder = SimSonEncoder(ENCODER_CONFIG, 512)
301
  encoder = torch.compile(encoder)
302
  model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
303
  model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
 
304
  criterion = get_criterion(info['task_type'], info['num_labels'])
305
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
306
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(train_loader))
307
+
308
  best_val_loss = float('inf')
309
  best_model_state = None
310
  current_patience = 0
 
324
  print(f'Early stopping at {PATIENCE} epochs')
325
  break
326
 
 
327
  print("\nTesting with the best model...")
328
  model.load_state_dict(best_model_state)
329
+ test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
330
+ print(f'Test loss: {test_loss}')
331
  test_preds, test_true = test_model(model, test_loader, DEVICE)
332
+
 
333
  aggregated_results[name] = {
334
  'best_val_loss': best_val_loss,
335
  'test_predictions': test_preds,
 
337
  }
338
  print(f"Finished testing for {name}.")
339
 
 
340
  print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
341
  for name, result in aggregated_results.items():
 
 
 
342
  if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
343
  auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
344
  print(f'{name} ROC AUC: {auc}')
345
 
346
  if name in ['lipophicility', 'esol', 'qm8']:
347
  rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
348
+ mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
349
  print(f'{name} MAE: {mae}')
350
  print(f'{name} RMSE: {rmse}')
351