Defetya committed on
Commit aea539f · verified · 1 Parent(s): ce0db25

Upload moleculenet_eval/eval.py with huggingface_hub

Files changed (1)
  1. moleculenet_eval/eval.py +352 -498
moleculenet_eval/eval.py CHANGED
@@ -1,545 +1,399 @@
1
- # ==============================================================================
2
- # 1. IMPORTS
3
- # ==============================================================================
4
- import os
5
- import warnings
6
- import wandb
7
-
8
  import torch
9
  import torch.nn as nn
10
  import torch.optim as optim
11
- import torch.nn.functional as F
12
- from torch.utils.data import DataLoader, Dataset
13
- import numpy as np
 
 
14
  from tqdm import tqdm
15
- from rdkit import Chem, RDLogger
16
- from datasets import load_dataset, load_from_disk
17
- from transformers import AutoTokenizer, BertModel, BertConfig
18
- import pandas as pd
19
 
20
- # ==============================================================================
21
- # 2. INITIAL SETUP
22
- # ==============================================================================
23
- # Suppress RDKit console output
24
- RDLogger.DisableLog('rdApp.*')
25
- # Ignore warnings for cleaner output
26
- warnings.filterwarnings("ignore")
27
-
28
- # ==============================================================================
29
- # 3. MODEL AND LOSS FUNCTION
30
- # ==============================================================================
31
- def global_average_pooling(x):
32
- """Global Average Pooling: from [B, max_len, hid_dim] to [B, hid_dim]"""
33
- return torch.mean(x, dim=1)
34
 
35
- class SimSonEncoder(nn.Module):
36
- """The main encoder model based on BERT."""
37
- def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
38
- super(SimSonEncoder, self).__init__()
39
- self.bert = BertModel(config, add_pooling_layer=False)
40
- self.linear = nn.Linear(config.hidden_size, max_len)
41
- self.dropout = nn.Dropout(dropout)
42
-
43
- def forward(self, input_ids, attention_mask=None):
44
- if attention_mask is None:
45
- attention_mask = input_ids.ne(self.bert.config.pad_token_id)
46
-
47
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
48
- hidden_states = self.dropout(outputs.last_hidden_state)
49
- pooled_output = global_average_pooling(hidden_states)
50
- return self.linear(pooled_output)
51
-
52
- class ContrastiveLoss(nn.Module):
53
- """Calculates the contrastive loss for the SimSon model."""
54
- def __init__(self, temperature=0.2):
55
- super(ContrastiveLoss, self).__init__()
56
- self.temperature = temperature
57
- self.similarity_fn = F.cosine_similarity
58
-
59
- def forward(self, proj_1, proj_2):
60
- batch_size = proj_1.shape[0]
61
- device = proj_1.device
62
-
63
- # Normalize projections
64
- z_i = F.normalize(proj_1, p=2, dim=1)
65
- z_j = F.normalize(proj_2, p=2, dim=1)
66
-
67
- # Concatenate for similarity matrix calculation
68
- representations = torch.cat([z_i, z_j], dim=0)
69
-
70
- # Calculate cosine similarity between all pairs
71
- similarity_matrix = self.similarity_fn(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
72
-
73
- # Identify positive pairs (original and its augmentation)
74
- sim_ij = torch.diag(similarity_matrix, batch_size)
75
- sim_ji = torch.diag(similarity_matrix, -batch_size)
76
- positives = torch.cat([sim_ij, sim_ji], dim=0)
77
-
78
- # Create a mask to exclude self-comparisons
79
- nominator = torch.exp(positives / self.temperature)
80
- mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()
81
- denominator = mask * torch.exp(similarity_matrix / self.temperature)
 
82
 
83
- # Calculate the final loss
84
- loss = -torch.log(nominator / torch.sum(denominator, dim=1))
85
- return torch.sum(loss) / (2 * batch_size)
86
-
87
- # ==============================================================================
88
- # 4. DATA HANDLING (Keeping your existing classes unchanged)
89
- # ==============================================================================
90
- class SmilesEnumerator:
91
- """Generates randomized SMILES strings for data augmentation."""
92
- def randomize_smiles(self, smiles):
93
- try:
94
- mol = Chem.MolFromSmiles(smiles)
95
- return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
96
- except:
97
- return smiles
98
-
99
- class ContrastiveSmilesDataset(Dataset):
100
- """Dataset for creating pairs of augmented SMILES for contrastive learning."""
101
- def __init__(self, smiles_list, tokenizer, max_length=512):
102
  self.smiles_list = smiles_list
 
103
  self.tokenizer = tokenizer
104
- self.max_length = max_length
105
- self.enumerator = SmilesEnumerator()
106
 
107
  def __len__(self):
108
  return len(self.smiles_list)
109
 
110
  def __getitem__(self, idx):
111
- original_smiles = self.smiles_list[idx]
112
-
113
- # Create two different augmentations of the same SMILES
114
- smiles_1 = self.enumerator.randomize_smiles(original_smiles)
115
- smiles_2 = self.enumerator.randomize_smiles(original_smiles)
116
 
117
- # Tokenize and pad each view to max_length.
118
- tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
119
- tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')
120
-
121
- return {
122
- 'input_ids_1': torch.tensor(tokens_1['input_ids']),
123
- 'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
124
- 'input_ids_2': torch.tensor(tokens_2['input_ids']),
125
- 'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
126
- }
127
-
128
- class PrecomputedContrastiveSmilesDataset(Dataset):
129
- """
130
- A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.
131
- This is significantly faster as it offloads the expensive SMILES randomization
132
- to a one-time preprocessing step.
133
- """
134
- def __init__(self, tokenizer, file_path: str, max_length: int = 512):
135
- self.tokenizer = tokenizer
136
- self.max_length = max_length
137
-
138
- # Load the entire dataset from the Parquet file into memory.
139
- # This is fast and efficient for subsequent access.
140
- print(f"Loading pre-computed data from {file_path}...")
141
- self.data = pd.read_parquet(file_path)
142
- print("Data loaded successfully.")
143
-
144
- def __len__(self):
145
- """Returns the total number of pairs in the dataset."""
146
- return len(self.data)
147
-
148
- def __getitem__(self, idx):
149
- """
150
- Retrieves a pre-augmented pair, tokenizes it, and returns it
151
- in the format expected by the DataCollator.
152
- """
153
- # Retrieve the pre-augmented pair from the DataFrame
154
- row = self.data.iloc[idx]
155
- smiles_1 = row['smiles_1']
156
- smiles_2 = row['smiles_2']
157
 
158
- # Tokenize the pair. This operation is fast and remains in the data loader.
159
- tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
160
- tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')
 
 
161
 
162
- return {
163
- 'input_ids_1': torch.tensor(tokens_1['input_ids']),
164
- 'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
165
- 'input_ids_2': torch.tensor(tokens_2['input_ids']),
166
- 'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
167
- }
168
 
169
- class PreTokenizedSmilesDataset(Dataset):
 
170
  """
171
- A Dataset that loads a pre-tokenized and pre-padded dataset created
172
- by the preprocessing script. It uses memory-mapping for instant loads
173
- and high efficiency.
174
  """
175
- def __init__(self, dataset_path: str):
176
- # Load the dataset from disk. This is very fast due to memory-mapping.
177
- self.dataset = load_from_disk(dataset_path)
178
- # Set the format to PyTorch tensors for direct use in the model
179
- self.dataset.set_format(type='torch', columns=[
180
- 'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
181
- ])
182
- print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")
183
-
184
- def __len__(self):
185
- """Returns the total number of items in the dataset."""
186
- return len(self.dataset)
187
 
188
- def __getitem__(self, idx):
189
- """Retrieves a single pre-processed item."""
190
- return self.dataset[idx]
191
-
192
- class DataCollatorWithPadding:
193
- """
194
- A collate function that dynamically pads inputs to the longest sequence
195
- across both augmented views in the batch, ensuring consistent tensor shapes.
196
- """
197
- def __init__(self, tokenizer):
198
- self.tokenizer = tokenizer
199
-
200
- def __call__(self, features):
201
- # Create a combined list of features for both views to find the global max length
202
- combined_features = []
203
- for feature in features:
204
- combined_features.append({'input_ids': feature['input_ids_1'], 'attention_mask': feature['attention_mask_1']})
205
- combined_features.append({'input_ids': feature['input_ids_2'], 'attention_mask': feature['attention_mask_2']})
206
-
207
- # Pad the combined batch. This ensures all sequences are padded to the same length.
208
- padded_combined = self.tokenizer.pad(combined_features, padding='longest', return_tensors='pt')
209
-
210
- # Split the padded tensors back into two views
211
- batch_size = len(features)
212
- input_ids_1, input_ids_2 = torch.split(padded_combined['input_ids'], batch_size, dim=0)
213
- attention_mask_1, attention_mask_2 = torch.split(padded_combined['attention_mask'], batch_size, dim=0)
214
 
215
- return {
216
- 'input_ids_1': input_ids_1,
217
- 'attention_mask_1': attention_mask_1,
218
- 'input_ids_2': input_ids_2,
219
- 'attention_mask_2': attention_mask_2,
220
- }
221
-
222
- # ==============================================================================
223
- # 5. CHECKPOINT UTILITIES
224
- # ==============================================================================
225
- def save_checkpoint(model, optimizer, scheduler, global_step, save_path):
226
- """Save complete checkpoint with model, optimizer, scheduler states and step count."""
227
- checkpoint = {
228
- 'model_state_dict': model.state_dict(),
229
- 'optimizer_state_dict': optimizer.state_dict(),
230
- 'scheduler_state_dict': scheduler.state_dict(),
231
- 'global_step': global_step,
232
- }
233
- torch.save(checkpoint, save_path)
234
- print(f"Full checkpoint saved at step {global_step}")
235
-
236
- def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
237
- """Load checkpoint and return the global step to resume from."""
238
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
239
- model.load_state_dict(checkpoint['model_state_dict'])
240
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
241
- scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
242
- global_step = checkpoint['global_step']
243
- print(f"Checkpoint loaded from step {global_step}")
244
- return global_step
245
-
246
- # ==============================================================================
247
- # 6. TRAINING AND EVALUATION LOOPS - MODIFIED
248
- # ==============================================================================
249
- def evaluation_step(model, batch, criterion, device):
250
- """Performs a single evaluation step on a batch of data."""
251
- input_ids_1 = batch['input_ids_1'].to(device)
252
- attention_mask_1 = batch['attention_mask_1'].to(device)
253
- input_ids_2 = batch['input_ids_2'].to(device)
254
- attention_mask_2 = batch['attention_mask_2'].to(device)
255
-
256
- combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
257
- combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)
258
-
259
- with torch.no_grad():
260
- combined_proj = model(combined_input_ids, combined_attention_mask)
261
 
262
- batch_size = input_ids_1.size(0)
263
- proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
264
-
265
- loss = criterion(proj_1, proj_2)
266
- return proj_1, proj_2, loss
267
-
268
- def train_with_step_based_validation(model, train_loader, val_loader, optimizer, criterion, device,
269
- scheduler, checkpoint_path, save_steps, validation_steps,
270
- start_step=0, max_steps=None):
271
- """
272
- Modified training function with step-based validation and checkpointing.
273
- """
274
  model.train()
275
- global_step = start_step
276
- best_val_loss = float('inf')
277
-
278
- # Calculate total steps if max_steps is not provided
279
- if max_steps is None:
280
- max_steps = len(train_loader)
281
-
282
- progress_bar = tqdm(total=max_steps - start_step, desc="Training Steps", initial=start_step)
283
-
284
- # Create iterator that can be resumed from any point
285
- train_iterator = iter(train_loader)
286
-
287
- # Skip batches if resuming from checkpoint
288
- if start_step > 0:
289
- batches_to_skip = start_step % len(train_loader)
290
- for _ in range(batches_to_skip):
291
- try:
292
- next(train_iterator)
293
- except StopIteration:
294
- train_iterator = iter(train_loader)
295
-
296
- while global_step < max_steps:
297
- try:
298
- batch = next(train_iterator)
299
- except StopIteration:
300
- train_iterator = iter(train_loader)
301
- batch = next(train_iterator)
302
-
303
- # Training step
304
- input_ids_1 = batch['input_ids_1'].to(device)
305
- attention_mask_1 = batch['attention_mask_1'].to(device)
306
- input_ids_2 = batch['input_ids_2'].to(device)
307
- attention_mask_2 = batch['attention_mask_2'].to(device)
308
 
309
  optimizer.zero_grad()
310
- with torch.autocast(dtype=torch.float16, device_type="cuda"):
311
- combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
312
- combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)
313
-
314
- combined_proj = model(combined_input_ids, combined_attention_mask)
315
-
316
- batch_size = input_ids_1.size(0)
317
- proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)
318
-
319
- loss = criterion(proj_1, proj_2)
320
-
321
  loss.backward()
322
  optimizer.step()
323
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
324
  scheduler.step()
325
-
326
- global_step += 1
327
-
328
- progress_bar.update(1)
329
- progress_bar.set_postfix(loss=f"{loss.item():.4f}", step=global_step)
330
-
331
- wandb.log({
332
- "train_batch_loss": loss.item(),
333
- "learning_rate": scheduler.get_last_lr()[0],
334
- "global_step": global_step
335
- })
336
-
337
- # Step-based validation
338
- if global_step % validation_steps == 0:
339
- val_loss = validate_epoch(model, val_loader, criterion, device)
340
- wandb.log({
341
- "val_loss": val_loss,
342
- "global_step": global_step
343
- })
344
-
345
- # Save best model (model state only for best checkpoint)
346
- if val_loss < best_val_loss:
347
- best_val_loss = val_loss
348
- model_save_path = checkpoint_path.replace('.pt', '_best_model.bin')
349
- torch.save(model.state_dict(), model_save_path)
350
- progress_bar.write(f"Step {global_step}: New best model saved with val loss {val_loss:.4f}")
351
-
352
- model.train() # Resume training mode after validation
353
-
354
- # Step-based checkpointing (full checkpoint)
355
- if global_step % save_steps == 0:
356
- save_checkpoint(model, optimizer, scheduler, global_step, checkpoint_path)
357
-
358
- progress_bar.close()
359
- return global_step
360
 
361
- def validate_epoch(model, val_loader, criterion, device):
362
- """Validation function - unchanged from original."""
363
- model.eval()
364
- total_loss = 0
365
- progress_bar = tqdm(val_loader, desc="Validating", leave=False)
366
-
367
- for batch in progress_bar:
368
- _, _, loss = evaluation_step(model, batch, criterion, device)
369
  total_loss += loss.item()
370
-
371
- avg_loss = total_loss / len(val_loader)
372
- print(f'Validation loss: {avg_loss:.4f}')
373
- return avg_loss
374
 
375
- def test_model(model, test_loader, criterion, device):
376
- """Test function - unchanged from original."""
377
  model.eval()
378
  total_loss = 0
379
- all_similarities = []
380
- progress_bar = tqdm(test_loader, desc="Testing", leave=False)
381
-
382
- for batch in progress_bar:
383
- proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
384
- total_loss += loss.item()
385
-
386
- proj_1_norm = F.normalize(proj_1, p=2, dim=1)
387
- proj_2_norm = F.normalize(proj_2, p=2, dim=1)
388
- batch_similarities = F.cosine_similarity(proj_1_norm, proj_2_norm, dim=1)
389
- all_similarities.extend(batch_similarities.cpu().numpy())
390
-
391
- avg_loss = total_loss / len(test_loader)
392
- avg_sim = np.mean(all_similarities)
393
- std_sim = np.std(all_similarities)
394
-
395
- return avg_loss, avg_sim, std_sim
396
-
397
- # ==============================================================================
398
- # 7. MODIFIED SINGLE-GPU TRAINING
399
- # ==============================================================================
400
- def run_training(model_config, hparams, data_splits):
401
- """The main function to run the training and evaluation process with step-based validation."""
402
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
403
- print(f"Using device: {device}")
404
-
405
- wandb_key = os.getenv("WANDB_API_KEY")
406
- if wandb_key:
407
- wandb.login(key=wandb_key)
408
- wandb.init(
409
- #project="simson-contrastive-learning-single-gpu",
410
- #name=f"run-{wandb.util.generate_id()}",
411
- #config=hparams
412
- )
413
-
414
- train_smiles, val_smiles, test_smiles = data_splits
415
-
416
- tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
417
-
418
- precomputed_train_path = 'data/pubchem_119m_splits/train.parquet'
419
- precomputed_test_path = 'data/pubchem_119m_splits/test.parquet'
420
- precomputed_val_path = 'data/pubchem_119m_splits/validation.parquet'
421
-
422
- train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
423
- test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
424
- val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])
425
-
426
- train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, num_workers=16, prefetch_factor=128, pin_memory=True)
427
- val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
428
- test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
429
-
430
- print('Initialized all data. Compiling the model...')
431
- model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
432
- model = torch.compile(model)
433
- model.load_state_dict(torch.load('simson_checkpoints_small/simson_model_single_gpu.bin'))
434
- print(model)
435
-
436
- total_params = sum(p.numel() for p in model.parameters())
437
-
438
- print(f"Total number of parameters: {total_params // 1_000_000} M")
439
- wandb.config.update({"total_params_M": total_params // 1_000_000})
440
-
441
- criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
442
- optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=True)
443
-
444
- total_steps = hparams['epochs'] * len(train_loader)
445
- scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=total_steps)
446
-
447
- print("Starting training...")
448
- wandb.watch(model, log='all', log_freq=5000)
449
-
450
- start_step = 0
451
- checkpoint_path = hparams['checkpoint_path']
452
-
453
- # Resume from checkpoint if provided
454
- if hparams.get('resume_checkpoint') and os.path.exists(hparams['resume_checkpoint']):
455
- print(f"Resuming from checkpoint: {hparams['resume_checkpoint']}")
456
- start_step = load_checkpoint(hparams['resume_checkpoint'], model, optimizer, scheduler)
457
-
458
- # Train with step-based validation
459
- final_step = train_with_step_based_validation(
460
- model, train_loader, val_loader, optimizer, criterion, device,
461
- scheduler, checkpoint_path, hparams['save_steps'], hparams['validation_steps'],
462
- start_step=start_step, max_steps=total_steps
463
- )
464
-
465
- print("Training complete. Starting final testing...")
466
-
467
- # Load the best model for testing (model state only)
468
- best_model_path = checkpoint_path.replace('.pt', '_best_model.bin')
469
- if os.path.exists(best_model_path):
470
- model.load_state_dict(torch.load(best_model_path))
471
- print("Loaded best model for testing")
472
-
473
- test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)
474
-
475
- print("\n--- Test Results ---")
476
- print(f"Test Loss: {test_loss:.4f}")
477
- print(f"Average Cosine Similarity: {avg_sim:.4f} ± {std_sim:.4f}")
478
- print("--------------------")
479
-
480
- wandb.log({
481
- "test_loss": test_loss,
482
- "avg_cosine_similarity": avg_sim,
483
- "std_cosine_similarity": std_sim
484
- })
485
-
486
- # Save final model state only
487
- final_model_path = hparams['save_path']
488
- torch.save(model.state_dict(), final_model_path)
489
- print(f"Final model saved to {final_model_path}")
490
-
491
- wandb.finish()
492
 
493
- # ==============================================================================
494
- # 8. MAIN EXECUTION
495
- # ==============================================================================
496
  def main():
497
- """Main function to configure and run the training process."""
498
- hparams = {
499
- 'epochs': 1,
500
- 'lr': 1e-5,
501
- 'temperature': 0.05,
502
- 'batch_size': 64,
503
- 'max_length': 256,
504
- 'save_path': "simson_checkpoints_pubchem/simson_model_single_gpu.bin",
505
- 'checkpoint_path': "simson_checkpoints/checkpoint.pt", # Full checkpoint
506
- 'save_steps': 50_000, # Save a full checkpoint every 50k steps
507
- 'validation_steps': 50_000, # Validate every 50k steps
508
- 'max_embeddings': 512,
509
- 'resume_checkpoint': None, # Set to checkpoint path to resume
510
- }
511
-
512
- dataset = load_dataset('HoangHa/SMILES-250M')['train']
513
- smiles_column_name = 'SMILES'
514
 
515
- total_size = len(dataset)
516
- test_size = int(0.1 * total_size)
517
- val_size = int(0.1 * (total_size - test_size))
518
-
519
- test_smiles = dataset.select(range(test_size))[smiles_column_name]
520
- val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
521
- train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
522
- data_splits = (train_smiles, val_smiles, test_smiles)
523
-
524
- tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
525
- model_config = BertConfig(
526
- vocab_size=tokenizer.vocab_size,
527
  hidden_size=768,
528
  num_hidden_layers=4,
529
  num_attention_heads=12,
530
  intermediate_size=2048,
531
  max_position_embeddings=512
532
  )
533
-
534
- # Create directories
535
- save_dir = os.path.dirname(hparams['save_path'])
536
- checkpoint_dir = os.path.dirname(hparams['checkpoint_path'])
537
- for directory in [save_dir, checkpoint_dir]:
538
- if not os.path.exists(directory):
539
- os.makedirs(directory)
540
 
541
- # Directly call the training function for a single-GPU run
542
- run_training(model_config, hparams, data_splits)
543
 
544
  if __name__ == '__main__':
 
 
545
  main()
 
1
+ import pandas as pd
2
+ import numpy as np
3
  import torch
4
  import torch.nn as nn
5
  import torch.optim as optim
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from transformers import BertConfig, BertModel, AutoTokenizer
8
+ from rdkit import Chem
9
+ from rdkit.Chem.Scaffolds import MurckoScaffold
10
+ import copy
11
  from tqdm import tqdm
12
+ import os
13
+ from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error
14
+ from itertools import compress
15
+ from collections import defaultdict
16
 
17
+ torch.set_float32_matmul_precision('high')
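+ # 'high' allows float32 matmuls to use TF32 (or equivalent) kernels on supported GPUs: faster, with slightly reduced precision.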
18
 
19
+ # --- 1. Data Loading ---
20
+ # Function to load datasets from their respective URLs.
21
+ def load_lists_from_url(data):
22
+ """
23
+ Load SMILES and labels from Moleculenet website.
24
+ """
25
+ if data == 'bbbp':
26
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
27
+ smiles, labels = df.smiles, df.p_np
28
+ elif data == 'clintox':
29
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
30
+ smiles = df.smiles
31
+ labels = df.drop(['smiles'], axis=1)
32
+ elif data == 'hiv':
33
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
34
+ smiles, labels = df.smiles, df.HIV_active
35
+ elif data == 'sider':
36
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
37
+ smiles = df.smiles
38
+ labels = df.drop(['smiles'], axis=1) # (1427, 27)
39
+ elif data == 'esol':
40
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
41
+ smiles = df.smiles
42
+ labels = df['ESOL predicted log solubility in mols per litre']
43
+ elif data == 'freesolv':
44
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
45
+ smiles = df.smiles
46
+ labels = df.calc
47
+ elif data == 'lipophicility':
48
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
49
+ smiles, labels = df.smiles, df['exp']
50
+ elif data == 'tox21':
51
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
52
+ df = df.dropna(axis=0, how='any').reset_index(drop=True) # drop nan values
53
+ smiles = df.smiles
54
+ labels = df.drop(['mol_id', 'smiles'], axis=1) # 12 cols
55
+ elif data == 'bace':
56
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
57
+ smiles, labels = df.mol, df.Class
63
+ elif data == 'qm8':
64
+ df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
65
+ df = df.dropna(axis=0, how='any').reset_index(drop=True) # drop nan values
66
+ smiles = df.smiles
67
+ labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1) # 12 tasks
68
+
69
+ return smiles, labels
70
+
71
+ # --- 2. Scaffold Splitting ---
72
+ # Class to split the dataset based on molecular scaffolds.
73
+ class ScaffoldSplitter:
74
+ def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
75
+ self.data = data
76
+ self.seed = seed
77
+ self.include_chirality = include_chirality
78
+ self.train_frac = train_frac
79
+ self.val_frac = val_frac
80
+ self.test_frac = test_frac
81
+
82
+ def generate_scaffold(self, smiles):
83
+ mol = Chem.MolFromSmiles(smiles)
84
+ scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality)
85
+ return scaffold
86
+
87
+ def scaffold_split(self):
88
+ smiles, labels = load_lists_from_url(self.data)
89
 
90
+ # Initialize non_null as False for all samples
91
+ non_null = np.ones(len(smiles)) == 0
92
+
93
+ # Dataset-specific null handling
94
+ if self.data == 'tox21' or self.data == 'sider' or self.data == 'clintox':
95
+ for i in range(len(smiles)):
96
+ # Check if molecule is valid AND no missing labels
97
+ if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
98
+ non_null[i] = 1
99
+ else:
100
+ # For single-task datasets, only check molecule validity
101
+ for i in range(len(smiles)):
102
+ if Chem.MolFromSmiles(smiles[i]):
103
+ non_null[i] = 1
104
+
105
+ # Extract valid samples with original indices preserved
106
+ smiles_list = list(compress(enumerate(smiles), non_null))
107
+
108
+ rng = np.random.RandomState(self.seed)
109
+
110
+ # Group by scaffold
111
+ scaffolds = defaultdict(list)
112
+ for i, sms in smiles_list:
113
+ scaffold = self.generate_scaffold(sms)
114
+ scaffolds[scaffold].append(i)
115
+
116
+ scaffold_sets = list(scaffolds.values())
117
+ rng.shuffle(scaffold_sets)
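+ # Each scaffold group is assigned to exactly one split below, so structurally related molecules never leak across train/val/test.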
118
+ # Calculate target sizes for validation and test sets
119
+ n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
120
+ n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
121
+
122
+ train_idx, val_idx, test_idx = [], [], []
123
+
124
+ # Assign scaffold groups to splits
125
+ for scaffold_set in scaffold_sets:
126
+ if len(val_idx) + len(scaffold_set) <= n_total_val:
127
+ val_idx.extend(scaffold_set)
128
+ elif len(test_idx) + len(scaffold_set) <= n_total_test:
129
+ test_idx.extend(scaffold_set)
130
+ else:
131
+ train_idx.extend(scaffold_set)
132
+
133
+ return train_idx, val_idx, test_idx
134
+ # --- 3. PyTorch Dataset ---
135
+ # Custom Dataset class for handling SMILES data.
136
+ class MoleculeDataset(Dataset):
137
+ def __init__(self, smiles_list, labels, tokenizer, max_len=512):
138
  self.smiles_list = smiles_list
139
+ self.labels = labels
140
  self.tokenizer = tokenizer
141
+ self.max_len = max_len
 
142
 
143
  def __len__(self):
144
  return len(self.smiles_list)
145
 
146
  def __getitem__(self, idx):
147
+ smiles = self.smiles_list[idx]
148
+ label = self.labels.iloc[idx]
149
+
150
+ encoding = self.tokenizer(
151
+ smiles,
152
+ truncation=True,
153
+ padding='max_length',
154
+ max_length=self.max_len,
155
+ return_tensors='pt'
156
+ )
157
 
158
+ item = {key: val.squeeze(0) for key, val in encoding.items()}
159
 
160
+ # Handle single-task and multi-task labels
161
+ if isinstance(label, pd.Series):
162
+ label_values = label.values.astype(np.float32)
163
+ else:
164
+ label_values = np.array([label], dtype=np.float32)
165
 
166
+ item['labels'] = torch.tensor(label_values, dtype=torch.float)
167
+ return item
168
 
169
+ # --- 4. Model Architecture ---
170
+ def global_ap(x):
171
  """
172
+ Global Average Pooling
173
+ Input: [B, max_len, hid_dim]
174
+ Return: [B, hid_dim]
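+ Note: the mean is taken over all token positions, padding included.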
175
  """
176
+ return torch.mean(x.view(x.size(0), x.size(1), -1), dim=1)
177
 
178
+ class SimSonEncoder(nn.Module):
179
+ def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
180
+ super(SimSonEncoder, self).__init__()
181
+ self.config = config
182
+ self.max_len = max_len
183
+ self.bert = BertModel(config, add_pooling_layer=False)
184
+ self.linear = nn.Linear(config.hidden_size, max_len)
185
+ self.dropout = nn.Dropout(dropout)
186
 
187
+ def forward(self, input_ids, attention_mask=None):
188
+ if attention_mask is None:
189
+ attention_mask = input_ids.ne(self.config.pad_token_id)
190
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
191
+ hidden_states = self.dropout(outputs.last_hidden_state)
192
+ pooled = global_ap(hidden_states)
193
+ return self.linear(pooled)
194
+
195
+ class SimSonClassifier(nn.Module):
196
+ def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
197
+ super(SimSonClassifier, self).__init__()
198
+ self.encoder = encoder
199
+ self.clf = nn.Linear(encoder.max_len, num_labels)
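+ # The head consumes the encoder's max_len-dimensional projection (the output of encoder.linear), not hidden_size.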
200
+ self.relu = nn.ReLU()
201
+ self.dropout = nn.Dropout(dropout)
202
 
203
+ def forward(self, input_ids, attention_mask=None):
204
+ x = self.encoder(input_ids, attention_mask)
205
+ x = self.relu(self.dropout(x))
206
+ logits = self.clf(x)
207
+ return logits
208
+
209
+ def load_encoder_params(self, state_dict_path):
210
+ """Loads pretrained parameters into the SimSonEncoder."""
211
+ self.encoder.load_state_dict(torch.load(state_dict_path))
212
+ print("Pretrained encoder parameters loaded.")
213
+
214
+ # --- 5. Training, Validation, and Testing Loops ---
215
+ def get_criterion(task_type, num_labels):
216
+ """Select loss function based on task."""
217
+ if task_type == 'classification':
218
+ return nn.BCEWithLogitsLoss()
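+ # BCEWithLogitsLoss applies the sigmoid internally, so the model outputs raw logits during training.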
219
+ elif task_type == 'regression':
220
+ return nn.MSELoss()
221
+ else:
222
+ raise ValueError(f"Unknown task type: {task_type}")
223
+
224
+ def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
225
  model.train()
226
+ total_loss = 0
227
+ for batch in dataloader:
228
+ inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
229
+ labels = batch['labels'].to(device)
230
 
231
  optimizer.zero_grad()
232
+ outputs = model(**inputs)
233
+ loss = criterion(outputs, labels)
234
  loss.backward()
235
  optimizer.step()
 
236
  scheduler.step()
237
 
238
  total_loss += loss.item()
239
+ return total_loss / len(dataloader)
240
 
241
+ def eval_epoch(model, dataloader, criterion, device):
 
242
  model.eval()
243
  total_loss = 0
244
+ with torch.no_grad():
245
+ for batch in dataloader:
246
+ inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
247
+ labels = batch['labels'].to(device)
248
+ outputs = model(**inputs)
249
+ loss = criterion(outputs, labels)
250
+ total_loss += loss.item()
251
+ return total_loss / len(dataloader)
252
+
253
+ def test_model(model, dataloader, device, task_type='classification'):
254
+ model.eval()
255
+ all_preds, all_labels = [], []
256
+ with torch.no_grad():
257
+ for batch in dataloader:
258
+ inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
259
+ labels = batch['labels']
260
+ outputs = model(**inputs)
261
+
262
+ # Apply sigmoid to get probabilities for classification; keep raw outputs for regression
+ preds = torch.sigmoid(outputs) if task_type == 'classification' else outputs
264
+
265
+ all_preds.append(preds.cpu().numpy())
266
+ all_labels.append(labels.numpy())
267
+
268
+ return np.concatenate(all_preds), np.concatenate(all_labels)
269
 
270
+ # --- 6. Main Execution Block ---
 
 
271
  def main():
272
+ # --- Configuration ---
273
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
274
+ print(f"Using device: {DEVICE}")
275
 
276
+ DATASETS_TO_RUN = {
277
+ #'esol': {'task_type': 'regression', 'num_labels': 1},
278
+ #'freesolv': {'task_type': 'regression', 'num_labels':1},
279
+ #'lipophicility': {'task_type': 'regression', 'num_labels': 1},
280
+ #'qm8': {'task_type': 'regression', 'num_labels': 12},
281
+ #'bbbp': {'task_type': 'classification', 'num_labels': 1},
282
+ 'tox21': {'task_type': 'classification', 'num_labels': 12},
283
+ #'sider': {'task_type': 'classification', 'num_labels': 27},
284
+ #'clintox': {'task_type': 'classification', 'num_labels': 2},
285
+ #'hiv': {'task_type': 'classification', 'num_labels': 1},
286
+ #'bace': {'task_type': 'classification', 'num_labels': 1},
287
+ }
288
+ PATIENCE = 25
289
+ EPOCHS = 200
290
+ LEARNING_RATE = 2e-5
291
+ BATCH_SIZE = 128
292
+ MAX_LEN = 256
293
+
294
+ # --- Tokenizer and Model Config ---
295
+ TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
296
+ ENCODER_CONFIG = BertConfig(
297
+ vocab_size=TOKENIZER.vocab_size,
298
  hidden_size=768,
299
  num_hidden_layers=4,
300
  num_attention_heads=12,
301
  intermediate_size=2048,
302
  max_position_embeddings=512
303
  )
304
 
305
+ aggregated_results = {}
306
+
307
+ for name, info in DATASETS_TO_RUN.items():
308
+ print(f"\n{'='*20} Processing Dataset: {name.upper()} {'='*20}")
309
+
310
+ # --- Data Loading and Splitting ---
311
+ splitter = ScaffoldSplitter(data=name, seed=42)
312
+ train_idx, val_idx, test_idx = splitter.scaffold_split()
313
+
314
+ # Load data once
315
+ smiles, labels = load_lists_from_url(name)
316
+
317
+ # Extract splits using returned indices
318
+ train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
319
+ train_labels = labels.iloc[train_idx].reset_index(drop=True)
320
+
321
+ val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
322
+ val_labels = labels.iloc[val_idx].reset_index(drop=True)
323
+
324
+ test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
325
+ test_labels = labels.iloc[test_idx].reset_index(drop=True)
326
+ print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")
327
+
328
+ train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
329
+ val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN)
330
+ test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN)
331
+
332
+ train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
333
+ val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
334
+ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
335
+
336
+ # --- Model, Loss, and Optimizer ---
337
+ encoder = SimSonEncoder(ENCODER_CONFIG, 512)
338
+ encoder = torch.compile(encoder)
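+ # torch.compile wraps the module, so its state_dict keys carry an '_orig_mod.' prefix; compiling before loading keeps them compatible with a checkpoint saved from a compiled model.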
339
+ model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
340
+ model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
341
+
342
+ criterion = get_criterion(info['task_type'], info['num_labels'])
343
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
344
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(train_loader))
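+ # The scheduler is stepped once per batch in train_epoch, so T_max is the total number of optimization steps.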
345
+ # --- Training and Validation ---
346
+ best_val_loss = float('inf')
347
+ best_model_state = None
348
+ current_patience = 0
349
+ for epoch in range(EPOCHS):
350
+ train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
351
+ val_loss = eval_epoch(model, val_loader, criterion, DEVICE)
352
+ print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
353
+
354
+ if val_loss < best_val_loss:
355
+ best_val_loss = val_loss
356
+ best_model_state = copy.deepcopy(model.state_dict())
357
+ print(f" -> New best model saved with validation loss: {best_val_loss:.4f}")
358
+ current_patience = 0
359
+ else:
360
+ current_patience += 1
361
+ if current_patience >= PATIENCE:
362
+ print(f'Early stopping after {PATIENCE} epochs without improvement')
363
+ break
364
+
365
+ # --- Testing ---
366
+ print("\nTesting with the best model...")
367
+ model.load_state_dict(best_model_state)
368
+ test_preds, test_true = test_model(model, test_loader, DEVICE, info['task_type'])
369
+
370
+ # Store results. For classification, you can now calculate metrics like ROC-AUC.
371
+ aggregated_results[name] = {
372
+ 'best_val_loss': best_val_loss,
373
+ 'test_predictions': test_preds,
374
+ 'test_labels': test_true
375
+ }
376
+ print(f"Finished testing for {name}.")
377
+
378
+ # --- Final Results Aggregation ---
379
+ print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
380
+ for name, result in aggregated_results.items():
381
+ # Here you would typically calculate and display final metrics from predictions
382
+ # For example, using scikit-learn's roc_auc_score
383
+ # from sklearn.metrics import roc_auc_score
384
+ if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
385
+ auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
386
+ print(f'{name} ROC AUC: {auc}')
387
+
388
+ if name in ['lipophicility', 'esol', 'freesolv', 'qm8']:
389
+ rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
390
+ mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
391
+ print(f'{name} MAE: {mae}')
392
+ print(f'{name} RMSE: {rmse}')
393
+
394
+ print("\nScript finished.")
395
 
396
  if __name__ == '__main__':
397
+ # Note: This script requires rdkit. You can install it via pip:
398
+ # pip install rdkit
399
  main()