Defetya committed on
Commit ce0db25 · verified · 1 Parent(s): 993bee6

Upload moleculenet_eval/eval.py with huggingface_hub

Files changed (1)
  1. moleculenet_eval/eval.py +545 -0
moleculenet_eval/eval.py ADDED
@@ -0,0 +1,545 @@
# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import os
import warnings
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
from rdkit import Chem, RDLogger
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, BertModel, BertConfig
import pandas as pd

# ==============================================================================
# 2. INITIAL SETUP
# ==============================================================================
# Suppress RDKit console output
RDLogger.DisableLog('rdApp.*')
# Ignore warnings for cleaner output
warnings.filterwarnings("ignore")

# ==============================================================================
# 3. MODEL AND LOSS FUNCTION
# ==============================================================================
def global_average_pooling(x):
    """Global Average Pooling: from [B, max_len, hid_dim] to [B, hid_dim]"""
    return torch.mean(x, dim=1)

class SimSonEncoder(nn.Module):
    """The main encoder model based on BERT."""
    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super(SimSonEncoder, self).__init__()
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        if attention_mask is None:
            attention_mask = input_ids.ne(self.bert.config.pad_token_id)

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = self.dropout(outputs.last_hidden_state)
        pooled_output = global_average_pooling(hidden_states)
        return self.linear(pooled_output)

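# Illustrative shapes (an assumed example, not part of the committed training flow):
# given token ids of shape [B, seq_len], the encoder mean-pools the BERT hidden
# states and projects them to a fixed-size vector of shape [B, max_len], e.g.
#   encoder = SimSonEncoder(BertConfig(vocab_size=100), max_len=512)
#   out = encoder(torch.randint(1, 100, (4, 16)))   # out.shape == (4, 512)
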
class ContrastiveLoss(nn.Module):
    """Calculates the contrastive loss for the SimSon model."""
    def __init__(self, temperature=0.2):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.similarity_fn = F.cosine_similarity

    def forward(self, proj_1, proj_2):
        batch_size = proj_1.shape[0]
        device = proj_1.device

        # Normalize projections
        z_i = F.normalize(proj_1, p=2, dim=1)
        z_j = F.normalize(proj_2, p=2, dim=1)

        # Concatenate for similarity matrix calculation
        representations = torch.cat([z_i, z_j], dim=0)

        # Calculate cosine similarity between all pairs
        similarity_matrix = self.similarity_fn(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

        # Identify positive pairs (original and its augmentation)
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        # Exclude self-comparisons from the denominator via a mask
        numerator = torch.exp(positives / self.temperature)
        mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool, device=device)).float()
        denominator = mask * torch.exp(similarity_matrix / self.temperature)

        # Calculate the final loss
        loss = -torch.log(numerator / torch.sum(denominator, dim=1))
        return torch.sum(loss) / (2 * batch_size)

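# Sketch of how the loss is consumed (illustrative only): the two arguments are the
# projections of two augmented views of the same molecules, each of shape
# [batch_size, proj_dim], and the result is a scalar NT-Xent-style loss:
#   criterion = ContrastiveLoss(temperature=0.2)
#   loss = criterion(torch.randn(8, 512), torch.randn(8, 512))   # scalar tensor
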
# ==============================================================================
# 4. DATA HANDLING
# ==============================================================================
class SmilesEnumerator:
    """Generates randomized SMILES strings for data augmentation."""
    def randomize_smiles(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            return smiles

class ContrastiveSmilesDataset(Dataset):
    """Dataset for creating pairs of augmented SMILES for contrastive learning."""
    def __init__(self, smiles_list, tokenizer, max_length=512):
        self.smiles_list = smiles_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.enumerator = SmilesEnumerator()

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        original_smiles = self.smiles_list[idx]

        # Create two different augmentations of the same SMILES
        smiles_1 = self.enumerator.randomize_smiles(original_smiles)
        smiles_2 = self.enumerator.randomize_smiles(original_smiles)

        # Tokenize and pad/truncate each view to max_length
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }

class PrecomputedContrastiveSmilesDataset(Dataset):
    """
    A Dataset class that reads pre-augmented SMILES pairs from a Parquet file.
    This is significantly faster as it offloads the expensive SMILES randomization
    to a one-time preprocessing step.
    """
    def __init__(self, tokenizer, file_path: str, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load the entire dataset from the Parquet file into memory.
        # This is fast and efficient for subsequent access.
        print(f"Loading pre-computed data from {file_path}...")
        self.data = pd.read_parquet(file_path)
        print("Data loaded successfully.")

    def __len__(self):
        """Returns the total number of pairs in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieves a pre-augmented pair, tokenizes it, and returns it
        in the format expected by the DataCollator.
        """
        # Retrieve the pre-augmented pair from the DataFrame
        row = self.data.iloc[idx]
        smiles_1 = row['smiles_1']
        smiles_2 = row['smiles_2']

        # Tokenize the pair. This operation is fast and remains in the data loader.
        tokens_1 = self.tokenizer(smiles_1, max_length=self.max_length, truncation=True, padding='max_length')
        tokens_2 = self.tokenizer(smiles_2, max_length=self.max_length, truncation=True, padding='max_length')

        return {
            'input_ids_1': torch.tensor(tokens_1['input_ids']),
            'attention_mask_1': torch.tensor(tokens_1['attention_mask']),
            'input_ids_2': torch.tensor(tokens_2['input_ids']),
            'attention_mask_2': torch.tensor(tokens_2['attention_mask']),
        }

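# The Parquet files are expected to hold the two pre-augmented views in columns
# 'smiles_1' and 'smiles_2' (as read in __getitem__ above). A minimal sketch of
# producing such a file in a hypothetical preprocessing step, for illustration only:
#   pairs = pd.DataFrame({'smiles_1': ['C1=CC=CC=C1'], 'smiles_2': ['c1ccccc1']})
#   pairs.to_parquet('data/pubchem_119m_splits/train.parquet')
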
class PreTokenizedSmilesDataset(Dataset):
    """
    A Dataset that loads a pre-tokenized and pre-padded dataset created
    by the preprocessing script. It uses memory-mapping for instant loads
    and high efficiency.
    """
    def __init__(self, dataset_path: str):
        # Load the dataset from disk. This is very fast due to memory-mapping.
        self.dataset = load_from_disk(dataset_path)
        # Set the format to PyTorch tensors for direct use in the model
        self.dataset.set_format(type='torch', columns=[
            'input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'
        ])
        print(f"Successfully loaded pre-tokenized dataset from {dataset_path}.")

    def __len__(self):
        """Returns the total number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Retrieves a single pre-processed item."""
        return self.dataset[idx]

class DataCollatorWithPadding:
    """
    A collate function that dynamically pads inputs to the longest sequence
    across both augmented views in the batch, ensuring consistent tensor shapes.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        # Create a combined list of features for both views to find the global max length
        combined_features = []
        for feature in features:
            combined_features.append({'input_ids': feature['input_ids_1'], 'attention_mask': feature['attention_mask_1']})
            combined_features.append({'input_ids': feature['input_ids_2'], 'attention_mask': feature['attention_mask_2']})

        # Pad the combined batch. This ensures all sequences are padded to the same length.
        padded_combined = self.tokenizer.pad(combined_features, padding='longest', return_tensors='pt')

        # Split the padded tensors back into two views
        batch_size = len(features)
        input_ids_1, input_ids_2 = torch.split(padded_combined['input_ids'], batch_size, dim=0)
        attention_mask_1, attention_mask_2 = torch.split(padded_combined['attention_mask'], batch_size, dim=0)

        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
        }

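# Usage sketch (illustrative): this collator would be attached to a DataLoader via
# collate_fn when the dataset returns unpadded token lists, e.g.
#   collator = DataCollatorWithPadding(tokenizer)
#   loader = DataLoader(train_dataset, batch_size=64, collate_fn=collator)
# The loaders built in run_training() below pad every sample to max_length inside
# the dataset itself, so this collator is not wired in there.
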
# ==============================================================================
# 5. CHECKPOINT UTILITIES
# ==============================================================================
def save_checkpoint(model, optimizer, scheduler, global_step, save_path):
    """Save complete checkpoint with model, optimizer, scheduler states and step count."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'global_step': global_step,
    }
    torch.save(checkpoint, save_path)
    print(f"Full checkpoint saved at step {global_step}")

def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """Load checkpoint and return the global step to resume from."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    global_step = checkpoint['global_step']
    print(f"Checkpoint loaded from step {global_step}")
    return global_step

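# Round-trip sketch (illustrative): a full checkpoint written by save_checkpoint()
# is restored with load_checkpoint(), which returns the step to resume from, e.g.
#   save_checkpoint(model, optimizer, scheduler, global_step, checkpoint_path)
#   start_step = load_checkpoint(checkpoint_path, model, optimizer, scheduler)
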
# ==============================================================================
# 6. TRAINING AND EVALUATION LOOPS
# ==============================================================================
def evaluation_step(model, batch, criterion, device):
    """Performs a single evaluation step on a batch of data."""
    input_ids_1 = batch['input_ids_1'].to(device)
    attention_mask_1 = batch['attention_mask_1'].to(device)
    input_ids_2 = batch['input_ids_2'].to(device)
    attention_mask_2 = batch['attention_mask_2'].to(device)

    combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
    combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

    with torch.no_grad():
        combined_proj = model(combined_input_ids, combined_attention_mask)

    batch_size = input_ids_1.size(0)
    proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)

    loss = criterion(proj_1, proj_2)
    return proj_1, proj_2, loss

def train_with_step_based_validation(model, train_loader, val_loader, optimizer, criterion, device,
                                     scheduler, checkpoint_path, save_steps, validation_steps,
                                     start_step=0, max_steps=None):
    """
    Training loop with step-based validation and checkpointing.
    """
    model.train()
    global_step = start_step
    best_val_loss = float('inf')

    # Calculate total steps if max_steps is not provided
    if max_steps is None:
        max_steps = len(train_loader)

    progress_bar = tqdm(total=max_steps - start_step, desc="Training Steps", initial=start_step)

    # Create iterator that can be resumed from any point
    train_iterator = iter(train_loader)

    # Skip batches if resuming from checkpoint
    if start_step > 0:
        batches_to_skip = start_step % len(train_loader)
        for _ in range(batches_to_skip):
            try:
                next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_loader)

    while global_step < max_steps:
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)
            batch = next(train_iterator)

        # Training step
        input_ids_1 = batch['input_ids_1'].to(device)
        attention_mask_1 = batch['attention_mask_1'].to(device)
        input_ids_2 = batch['input_ids_2'].to(device)
        attention_mask_2 = batch['attention_mask_2'].to(device)

        optimizer.zero_grad()
        with torch.autocast(dtype=torch.float16, device_type="cuda"):
            combined_input_ids = torch.cat([input_ids_1, input_ids_2], dim=0)
            combined_attention_mask = torch.cat([attention_mask_1, attention_mask_2], dim=0)

            combined_proj = model(combined_input_ids, combined_attention_mask)

            batch_size = input_ids_1.size(0)
            proj_1, proj_2 = torch.split(combined_proj, batch_size, dim=0)

            loss = criterion(proj_1, proj_2)

        loss.backward()
        # Clip gradients before the optimizer update so the clipping actually takes effect
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        global_step += 1

        progress_bar.update(1)
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", step=global_step)

        wandb.log({
            "train_batch_loss": loss.item(),
            "learning_rate": scheduler.get_last_lr()[0],
            "global_step": global_step
        })

        # Step-based validation
        if global_step % validation_steps == 0:
            val_loss = validate_epoch(model, val_loader, criterion, device)
            wandb.log({
                "val_loss": val_loss,
                "global_step": global_step
            })

            # Save best model (model state only for best checkpoint)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                model_save_path = checkpoint_path.replace('.pt', '_best_model.bin')
                torch.save(model.state_dict(), model_save_path)
                progress_bar.write(f"Step {global_step}: New best model saved with val loss {val_loss:.4f}")

            model.train()  # Resume training mode after validation

        # Step-based checkpointing (full checkpoint)
        if global_step % save_steps == 0:
            save_checkpoint(model, optimizer, scheduler, global_step, checkpoint_path)

    progress_bar.close()
    return global_step

def validate_epoch(model, val_loader, criterion, device):
    """Runs validation over the full validation loader and returns the average loss."""
    model.eval()
    total_loss = 0
    progress_bar = tqdm(val_loader, desc="Validating", leave=False)

    for batch in progress_bar:
        _, _, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()

    avg_loss = total_loss / len(val_loader)
    print(f'Validation loss: {avg_loss:.4f}')
    return avg_loss

def test_model(model, test_loader, criterion, device):
    """Runs the test loop and reports loss plus cosine-similarity statistics."""
    model.eval()
    total_loss = 0
    all_similarities = []
    progress_bar = tqdm(test_loader, desc="Testing", leave=False)

    for batch in progress_bar:
        proj_1, proj_2, loss = evaluation_step(model, batch, criterion, device)
        total_loss += loss.item()

        proj_1_norm = F.normalize(proj_1, p=2, dim=1)
        proj_2_norm = F.normalize(proj_2, p=2, dim=1)
        batch_similarities = F.cosine_similarity(proj_1_norm, proj_2_norm, dim=1)
        all_similarities.extend(batch_similarities.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    avg_sim = np.mean(all_similarities)
    std_sim = np.std(all_similarities)

    return avg_loss, avg_sim, std_sim

# ==============================================================================
# 7. SINGLE-GPU TRAINING
# ==============================================================================
def run_training(model_config, hparams, data_splits):
    """The main function to run the training and evaluation process with step-based validation."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    wandb_key = os.getenv("WANDB_API_KEY")
    if wandb_key:
        wandb.login(key=wandb_key)
    wandb.init(
        # project="simson-contrastive-learning-single-gpu",
        # name=f"run-{wandb.util.generate_id()}",
        # config=hparams
    )

    train_smiles, val_smiles, test_smiles = data_splits

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

    precomputed_train_path = 'data/pubchem_119m_splits/train.parquet'
    precomputed_test_path = 'data/pubchem_119m_splits/test.parquet'
    precomputed_val_path = 'data/pubchem_119m_splits/validation.parquet'

    train_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_train_path, max_length=hparams['max_length'])
    test_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_test_path, max_length=hparams['max_length'])
    val_dataset = PrecomputedContrastiveSmilesDataset(tokenizer, file_path=precomputed_val_path, max_length=hparams['max_length'])

    train_loader = DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True, num_workers=16, prefetch_factor=128, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False, num_workers=2, pin_memory=True)

    print('Initialized all data. Compiling the model...')
    model = SimSonEncoder(config=model_config, max_len=hparams['max_embeddings']).to(device)
    model = torch.compile(model)
    model.load_state_dict(torch.load('simson_checkpoints_small/simson_model_single_gpu.bin'))
    print(model)

    total_params = sum(p.numel() for p in model.parameters())

    print(f"Total number of parameters: {total_params // 1_000_000} M")
    wandb.config.update({"total_params_M": total_params // 1_000_000})

    criterion = ContrastiveLoss(temperature=hparams['temperature']).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=hparams['lr'], weight_decay=1e-5, fused=True)

    total_steps = hparams['epochs'] * len(train_loader)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_mult=1, T_0=total_steps)

    print("Starting training...")
    wandb.watch(model, log='all', log_freq=5000)

    start_step = 0
    checkpoint_path = hparams['checkpoint_path']

    # Resume from checkpoint if provided
    if hparams.get('resume_checkpoint') and os.path.exists(hparams['resume_checkpoint']):
        print(f"Resuming from checkpoint: {hparams['resume_checkpoint']}")
        start_step = load_checkpoint(hparams['resume_checkpoint'], model, optimizer, scheduler)

    # Train with step-based validation
    final_step = train_with_step_based_validation(
        model, train_loader, val_loader, optimizer, criterion, device,
        scheduler, checkpoint_path, hparams['save_steps'], hparams['validation_steps'],
        start_step=start_step, max_steps=total_steps
    )

    print("Training complete. Starting final testing...")

    # Load the best model for testing (model state only)
    best_model_path = checkpoint_path.replace('.pt', '_best_model.bin')
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
        print("Loaded best model for testing")

    test_loss, avg_sim, std_sim = test_model(model, test_loader, criterion, device)

    print("\n--- Test Results ---")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Average Cosine Similarity: {avg_sim:.4f} ± {std_sim:.4f}")
    print("--------------------")

    wandb.log({
        "test_loss": test_loss,
        "avg_cosine_similarity": avg_sim,
        "std_cosine_similarity": std_sim
    })

    # Save final model state only
    final_model_path = hparams['save_path']
    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved to {final_model_path}")

    wandb.finish()

# ==============================================================================
# 8. MAIN EXECUTION
# ==============================================================================
def main():
    """Main function to configure and run the training process."""
    hparams = {
        'epochs': 1,
        'lr': 1e-5,
        'temperature': 0.05,
        'batch_size': 64,
        'max_length': 256,
        'save_path': "simson_checkpoints_pubchem/simson_model_single_gpu.bin",
        'checkpoint_path': "simson_checkpoints/checkpoint.pt",  # Full checkpoint
        'save_steps': 50_000,  # Save a full checkpoint every 50k steps
        'validation_steps': 50_000,  # Validate every 50k steps
        'max_embeddings': 512,
        'resume_checkpoint': None,  # Set to a checkpoint path to resume
    }

    dataset = load_dataset('HoangHa/SMILES-250M')['train']
    smiles_column_name = 'SMILES'

    total_size = len(dataset)
    test_size = int(0.1 * total_size)
    val_size = int(0.1 * (total_size - test_size))

    test_smiles = dataset.select(range(test_size))[smiles_column_name]
    val_smiles = dataset.select(range(test_size, test_size + val_size))[smiles_column_name]
    train_smiles = dataset.select(range(test_size + val_size, total_size))[smiles_column_name]
    data_splits = (train_smiles, val_smiles, test_smiles)

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    model_config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    # Create directories
    save_dir = os.path.dirname(hparams['save_path'])
    checkpoint_dir = os.path.dirname(hparams['checkpoint_path'])
    for directory in [save_dir, checkpoint_dir]:
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Directly call the training function for a single-GPU run
    run_training(model_config, hparams, data_splits)

if __name__ == '__main__':
    main()