import copy
import os
import warnings
from collections import defaultdict
from itertools import compress

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertConfig, BertModel, AutoTokenizer
from rdkit import Chem, RDLogger
from rdkit.Chem.Scaffolds import MurckoScaffold
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import optuna

# Silence all Python warnings and every RDKit log channel ('rdApp.*') so the
# training / Optuna output stays readable (RDKit otherwise logs each SMILES
# it fails to parse).
warnings.filterwarnings("ignore")
RDLogger.DisableLog('rdApp.*')

# Let float32 matrix multiplications use faster reduced-precision kernels
# (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision('high')

class PrecomputedContrastiveSmilesDataset(Dataset):
    """Dataset of pre-augmented SMILES pairs read from a Parquet file.

    Every row of the file holds two SMILES strings in the columns ``smiles_1``
    and ``smiles_2``. The expensive SMILES randomization happens once when the
    file is written, so item access only has to tokenize.

    Args:
        tokenizer: callable tokenizer returning a mapping with ``input_ids``
            and ``attention_mask`` lists.
        file_path: path of the Parquet file to load.
        max_length: each SMILES is truncated / padded to this many tokens.
    """

    def __init__(self, tokenizer, file_path: str, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        print(f"Loading pre-computed data from {file_path}...")
        self.data = pd.read_parquet(file_path)
        print("Data loaded successfully.")

    def __len__(self):
        """Number of SMILES pairs stored in the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th pair and return four 1-D tensors.

        Keys: ``input_ids_1`` / ``attention_mask_1`` for the first SMILES and
        ``input_ids_2`` / ``attention_mask_2`` for the second one.
        """
        record = self.data.iloc[idx]
        layout = (
            ('smiles_1', 'input_ids_1', 'attention_mask_1'),
            ('smiles_2', 'input_ids_2', 'attention_mask_2'),
        )
        sample = {}
        for column, ids_key, mask_key in layout:
            encoded = self.tokenizer(
                record[column],
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
            )
            sample[ids_key] = torch.tensor(encoded['input_ids'])
            sample[mask_key] = torch.tensor(encoded['attention_mask'])
        return sample

class SmilesEnumerator:
    """Generate randomized (non-canonical) SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        """Return a randomized SMILES spelling of ``smiles``.

        RDKit writes the parsed molecule back out with ``doRandom=True`` and
        ``canonical=False``, so the string typically differs from the input.
        If RDKit cannot parse the input, or raises while processing it, the
        original ``smiles`` is returned unchanged so augmentation never drops
        a sample.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Best-effort fallback. Unlike the previous bare ``except:``, this
            # lets KeyboardInterrupt / SystemExit propagate.
            return smiles
|
|
def compute_embedding_similarity_precomputed(encoder, dataset, device, batch_size=32):
    """Cosine similarity between the embeddings of each pre-computed SMILES pair.

    Runs ``encoder`` on both members of every item in ``dataset``. Items must
    provide ``input_ids_1``, ``attention_mask_1``, ``input_ids_2`` and
    ``attention_mask_2``, as produced by PrecomputedContrastiveSmilesDataset.
    Returns one similarity per pair, in dataset order.

    Args:
        encoder: module called as ``encoder(input_ids, attention_mask)`` that
            returns one embedding row per input sequence.
        dataset: map-style dataset (or list) of tokenized pairs.
        device: device each batch is moved to before encoding.
        batch_size: number of pairs encoded per forward pass.

    Returns:
        1-D numpy array of cosine similarities. A pair in which either
        embedding is all zeros gets similarity 0.

    The encoder is put in eval mode for the computation and is restored to
    its previous train/eval mode afterwards, even if encoding fails.
    """
    was_training = encoder.training
    encoder.eval()
    chunks = []
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    try:
        with torch.no_grad():
            for batch in loader:
                emb_1 = encoder(batch['input_ids_1'].to(device),
                                batch['attention_mask_1'].to(device)).cpu().numpy()
                emb_2 = encoder(batch['input_ids_2'].to(device),
                                batch['attention_mask_2'].to(device)).cpu().numpy()

                # Row-wise cosine similarity in one vectorized step instead of
                # one sklearn call per pair. Zero-norm rows are left unscaled
                # (as sklearn's normalisation does), so they score 0, not NaN.
                norm_1 = np.linalg.norm(emb_1, axis=1, keepdims=True)
                norm_2 = np.linalg.norm(emb_2, axis=1, keepdims=True)
                norm_1[norm_1 == 0] = 1.0
                norm_2[norm_2 == 0] = 1.0
                chunks.append(np.sum((emb_1 / norm_1) * (emb_2 / norm_2), axis=1))
    finally:
        encoder.train(was_training)

    if not chunks:
        return np.array([])
    return np.concatenate(chunks)
|
|
def create_augmented_smiles_file(smiles_list, output_path, num_augmentations=1):
    """Write (original, randomized) SMILES pairs to a Parquet file.

    For every input SMILES, ``num_augmentations`` rows are produced. Each row
    holds the original string in ``smiles_1`` and a
    ``SmilesEnumerator.randomize_smiles`` spelling in ``smiles_2``. The table
    is written to ``output_path`` without the DataFrame index.

    Returns:
        ``output_path``, unchanged.
    """
    randomizer = SmilesEnumerator()
    print(f"Generating {num_augmentations} augmentations for {len(smiles_list)} SMILES...")

    pairs = [
        {'smiles_1': original, 'smiles_2': randomizer.randomize_smiles(original)}
        for original in tqdm(smiles_list)
        for _ in range(num_augmentations)
    ]

    pd.DataFrame(pairs).to_parquet(output_path, index=False)
    print(f"Saved {len(pairs)} augmented pairs to {output_path}")
    return output_path

def load_lists_from_url(data):
    """Download a MoleculeNet benchmark CSV from the DeepChem S3 bucket.

    Args:
        data: dataset key. One of 'bbbp', 'clintox', 'hiv', 'sider', 'esol',
            'freesolv', 'lipophilicity' (the legacy misspelling
            'lipophicility' is still accepted), 'tox21', 'bace' or 'qm8'.

    Returns:
        ``(smiles, labels)``. ``smiles`` is a pandas Series of SMILES strings.
        ``labels`` is a Series for single-target datasets, or a DataFrame
        with one column per task for 'clintox', 'sider', 'tox21' and 'qm8'.

    Raises:
        ValueError: if ``data`` is not a known key. The key is checked
            before anything is downloaded.
    """
    if data == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif data == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz',
                         compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif data == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz',
                         compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'esol':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
        smiles = df.smiles
        # NOTE(review): this is the ESOL model's *predicted* solubility. The
        # MoleculeNet benchmark target is the 'measured log solubility in mols
        # per litre' column -- confirm the predicted column is intended.
        labels = df['ESOL predicted log solubility in mols per litre']
    elif data == 'freesolv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
        smiles = df.smiles
        # NOTE(review): 'calc' is the calculated hydration free energy. The
        # MoleculeNet benchmark target is 'expt' -- confirm 'calc' is intended.
        labels = df.calc
    elif data in ('lipophilicity', 'lipophicility'):
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
        smiles, labels = df.smiles, df['exp']
    elif data == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz',
                         compression='gzip')
        # Keep only molecules that have a label for every one of the tasks.
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif data == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    elif data == 'qm8':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
    else:
        # Previously an unknown key fell through to ``return`` and failed with
        # an opaque UnboundLocalError.
        raise ValueError(f"Unknown dataset: {data!r}")
    return smiles, labels

class ScaffoldSplitter:
    """Seeded Bemis-Murcko scaffold split of a MoleculeNet dataset.

    Molecules that share a scaffold always land in the same partition.
    Scaffold groups are shuffled with a seeded RNG and then assigned greedily:
    a group goes to validation if it still fits under the validation quota,
    else to test if it fits under the test quota, else to train.
    ``train_frac`` is stored, but train simply receives every leftover group.

    Args:
        data: dataset key understood by ``load_lists_from_url``.
        seed: seed for the scaffold-group shuffle.
        train_frac, val_frac, test_frac: target fractions of kept molecules.
        include_chirality: keep stereo information in scaffold SMILES.
    """

    def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
        self.data = data
        self.seed = seed
        self.include_chirality = include_chirality
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac

    def generate_scaffold(self, smiles):
        """Return the Murcko scaffold SMILES of ``smiles``."""
        return MurckoScaffold.MurckoScaffoldSmiles(
            mol=Chem.MolFromSmiles(smiles),
            includeChirality=self.include_chirality,
        )

    def scaffold_split(self):
        """Return ``(train_idx, val_idx, test_idx)`` lists of row positions.

        Rows whose SMILES RDKit cannot parse are left out of every split. For
        'tox21', 'sider' and 'clintox', rows with any missing label are left
        out as well.
        """
        smiles, labels = load_lists_from_url(self.data)
        needs_complete_labels = self.data in {'tox21', 'sider', 'clintox'}

        def is_usable(i):
            if not Chem.MolFromSmiles(smiles[i]):
                return False
            return (not needs_complete_labels) or labels.loc[i].isnull().sum() == 0

        kept = [(i, sms) for i, sms in enumerate(smiles) if is_usable(i)]
        rng = np.random.RandomState(self.seed)

        # Group row positions by scaffold; dict insertion order is the order in
        # which each scaffold first appears, which fixes the pre-shuffle order.
        groups = defaultdict(list)
        for i, sms in kept:
            groups[self.generate_scaffold(sms)].append(i)
        scaffold_sets = list(groups.values())
        rng.shuffle(scaffold_sets)

        val_quota = int(np.floor(self.val_frac * len(kept)))
        test_quota = int(np.floor(self.test_frac * len(kept)))
        train_idx, val_idx, test_idx = [], [], []
        for group in scaffold_sets:
            if len(val_idx) + len(group) <= val_quota:
                destination = val_idx
            elif len(test_idx) + len(group) <= test_quota:
                destination = test_idx
            else:
                destination = train_idx
            destination.extend(group)
        return train_idx, val_idx, test_idx

def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    """Randomly partition ``range(n)`` into train / validation / test lists.

    The first ``int(n * train_frac)`` shuffled indices form the training set,
    the next ``int(n * val_frac)`` form the validation set, and every remaining
    index goes to the test set (``test_frac`` is therefore not used for sizing).

    A private ``np.random.RandomState(seed)`` produces the same permutation
    that seeding the global generator with ``seed`` would. Unlike
    ``np.random.seed``, it leaves NumPy's global random state untouched.

    Returns:
        ``(train_idx, val_idx, test_idx)`` as lists of Python ints.
    """
    indices = np.random.RandomState(seed).permutation(n)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return (
        indices[:n_train].tolist(),
        indices[n_train:n_train + n_val].tolist(),
        indices[n_train + n_val:].tolist(),
    )

class MoleculeDataset(Dataset):
    """Tokenized SMILES paired with float label vectors for fine-tuning.

    Args:
        smiles_list: SMILES strings, accessed as ``smiles_list[idx]``. For a
            pandas Series this is a label lookup, so a 0..n-1 index is assumed.
        labels: pandas Series (one target per molecule) or DataFrame (one
            column per task), accessed positionally through ``.iloc``.
        tokenizer: tokenizer called with ``return_tensors='pt'``.
        max_len: every sequence is truncated / padded to this many tokens.
    """

    def __init__(self, smiles_list, labels, tokenizer, max_len=512):
        self.smiles_list = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return tokenizer outputs (leading batch dim removed) plus ``labels``.

        ``labels`` is a 1-D float32 tensor: one entry per task for DataFrame
        labels, a single entry for scalar labels.
        """
        text = self.smiles_list[idx]
        target = self.labels.iloc[idx]

        encoded = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )
        sample = {name: tensor.squeeze(0) for name, tensor in encoded.items()}

        if isinstance(target, pd.Series):
            values = target.values.astype(np.float32)
        else:
            values = np.array([target], dtype=np.float32)
        sample['labels'] = torch.tensor(values, dtype=torch.float)
        return sample

def global_ap(x):
    """Global average pooling over dimension 1.

    ``x`` is viewed as ``(x.size(0), x.size(1), -1)`` and averaged along dim 1,
    producing a ``(x.size(0), product of remaining dims)`` tensor. No mask is
    applied, so for token hidden states every position (padding included)
    contributes to the mean.
    """
    flattened = x.view(x.size(0), x.size(1), -1)
    return flattened.mean(dim=1)
|
|
class SimSonEncoder(nn.Module):
    """BERT encoder whose mean-pooled token states are projected to ``max_len`` features.

    The final linear layer maps ``config.hidden_size -> max_len``, so the
    embedding width equals ``max_len``. Downstream heads read it from
    ``encoder.max_len``.
    """

    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.config = config
        self.max_len = max_len
        # Sub-modules are created in the same order (bert, linear, dropout) so
        # parameter initialisation consumes the RNG in the same sequence.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Embed each sequence in ``input_ids``.

        When ``attention_mask`` is omitted it is derived as
        ``input_ids != config.pad_token_id``. Pooling (``global_ap``) averages
        every position, padding included.
        """
        mask = input_ids.ne(self.config.pad_token_id) if attention_mask is None else attention_mask
        hidden = self.bert(input_ids=input_ids, attention_mask=mask).last_hidden_state
        return self.linear(global_ap(self.dropout(hidden)))
|
|
class SimSonClassifier(nn.Module):
    """Linear prediction head on top of a SimSonEncoder.

    The encoder output (width ``encoder.max_len``) goes through dropout and a
    ReLU, then a linear layer that emits ``num_labels`` raw logits.
    """

    def __init__(self, encoder: 'SimSonEncoder', num_labels: int, dropout=0.1):
        super().__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Return raw (pre-sigmoid) logits, one column per label."""
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        return self.clf(x)

    def load_encoder_params(self, state_dict_path, map_location='cpu'):
        """Load encoder weights from a state dict saved with ``torch.save``.

        ``map_location`` defaults to ``'cpu'`` so a checkpoint written on a
        GPU machine can also be read on a CPU-only host. ``load_state_dict``
        then copies the tensors onto whatever device the encoder lives on.
        """
        state_dict = torch.load(state_dict_path, map_location=map_location)
        self.encoder.load_state_dict(state_dict)

def get_criterion(task_type, num_labels):
    """Return the training loss for ``task_type``.

    'classification' -> BCEWithLogitsLoss (the model emits raw logits),
    'regression' -> MSELoss. ``num_labels`` is accepted but not used.

    Raises:
        ValueError: for any other task type.
    """
    if task_type == 'regression':
        return nn.MSELoss()
    if task_type == 'classification':
        return nn.BCEWithLogitsLoss()
    raise ValueError(f"Unknown task type: {task_type}")
|
|
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Every batch key except ``'labels'`` is moved to ``device`` and passed to
    ``model`` as a keyword argument. When ``scheduler`` is not None it is
    stepped after every optimizer step, i.e. once per batch.
    """
    model.train()
    running_loss = 0
    for batch in dataloader:
        features = {key: tensor.to(device) for key, tensor in batch.items() if key != 'labels'}
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(**features), targets)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += loss.item()
    return running_loss / len(dataloader)
|
|
def calc_val_metrics(model, dataloader, criterion, device, task_type):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(avg_loss, score)``. ``avg_loss`` is the mean of the per-batch losses.
        For ``task_type == 'classification'``, ``score`` is the macro ROC AUC
        of the sigmoid probabilities, or 0.0 when sklearn rejects the inputs
        (e.g. only one class present in a split). For any other task type,
        ``score`` is None.
    """
    model.eval()
    is_classification = task_type == 'classification'
    all_labels, all_preds = [], []
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            total_loss += criterion(outputs, labels).item()
            if is_classification:
                # Predictions are only needed for the AUC, so regression
                # batches are not accumulated.
                all_preds.append(torch.sigmoid(outputs).cpu().numpy())
                all_labels.append(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    if not is_classification:
        return avg_loss, None

    try:
        score = roc_auc_score(np.concatenate(all_labels), np.concatenate(all_preds), average='macro')
    except ValueError:
        # roc_auc_score raises ValueError for degenerate inputs such as a
        # single class. Other errors are real bugs and should surface.
        score = 0.0
    return avg_loss, score
|
|
def test_model(model, dataloader, device, task_type):
    """Run inference over ``dataloader`` and return ``(preds, labels)`` arrays.

    Predictions are sigmoid probabilities when ``task_type == 'classification'``
    and raw model outputs otherwise. Both arrays are concatenated along the
    batch axis. Labels are taken from the batches as loaded (never moved to
    ``device``).
    """
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            outputs = model(**inputs)
            preds = torch.sigmoid(outputs) if task_type == 'classification' else outputs
            all_preds.append(preds.cpu().numpy())
            all_labels.append(batch['labels'].numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels)


# The ``test_`` prefix would make pytest collect this helper as a test (and
# fail on its missing "fixtures") whenever the module is imported into a test
# file. Opt it out explicitly.
test_model.__test__ = False

def create_objective(name, info, train_smiles, train_labels, val_smiles, val_labels,
                     test_smiles, test_labels, scaler, tokenizer, encoder_config, device):
    """Create the Optuna objective that tunes fine-tuning hyper-parameters.

    Each trial samples a learning rate, batch size, dropout, weight decay and
    LR scheduler. It then fine-tunes a fresh SimSonClassifier (encoder loaded
    from the pretrained checkpoint) for up to 50 epochs with early stopping,
    and returns the best validation score seen. Scores are higher-is-better:
    ROC AUC for classification, negative validation loss otherwise.

    Args:
        name: dataset name (not used by the objective).
        info: dict providing 'task_type' and 'num_labels'.
        train_smiles, train_labels: training split.
        val_smiles, val_labels: validation split used for model selection.
        test_smiles, test_labels, scaler: accepted for interface compatibility
            but deliberately unused, so the search never sees test data.
        tokenizer: tokenizer handed to MoleculeDataset.
        encoder_config: BertConfig for the SimSonEncoder.
        device: device the model and batches are moved to.

    Returns:
        ``objective(trial) -> float`` for ``study.optimize``. A trial that hits
        an unexpected error scores ``-inf``. Pruning re-raises
        ``optuna.TrialPruned`` so Optuna records the trial as pruned.
    """
    max_len = 512
    num_epochs = 50
    patience = 15
    task_type = info['task_type']

    def selection_score(val_loss, val_metric):
        # Higher is better for both task types, so one comparison works.
        if task_type == 'classification':
            return val_metric if val_metric is not None else 0.0
        return -val_loss

    def objective(trial):
        lr = trial.suggest_float('lr', 1e-6, 1e-4, log=True)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128, 256])
        dropout = trial.suggest_float('dropout', 0.1, 0.5)
        weight_decay = trial.suggest_float('weight_decay', 0.0, 0.1)
        scheduler_type = trial.suggest_categorical('scheduler', ['plateau', 'cosine', 'step'])
        patience_lr = trial.suggest_int('patience_lr', 3, 10)
        # 'gamma' is only tuned for StepLR; the plateau scheduler uses 0.1.
        gamma = trial.suggest_float('gamma', 0.5, 0.9) if scheduler_type == 'step' else 0.1

        try:
            train_loader = DataLoader(MoleculeDataset(train_smiles, train_labels, tokenizer, max_len),
                                      batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(MoleculeDataset(val_smiles, val_labels, tokenizer, max_len),
                                    batch_size=batch_size, shuffle=False)

            encoder = SimSonEncoder(encoder_config, max_len, dropout=dropout)
            encoder = torch.compile(encoder)
            model = SimSonClassifier(encoder, num_labels=info['num_labels'], dropout=dropout).to(device)
            model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')

            criterion = get_criterion(task_type, info['num_labels'])
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

            if scheduler_type == 'plateau':
                # Tracks the higher-is-better selection score, hence mode='max'.
                scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode='max', factor=gamma, patience=patience_lr
                )
            elif scheduler_type == 'cosine':
                # train_epoch steps this scheduler once per batch, so the
                # annealing period must be counted in batches, not epochs.
                scheduler = optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, T_max=num_epochs * max(1, len(train_loader))
                )
            else:
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=gamma)

            best_val_metric = -np.inf
            patience_counter = 0
            for epoch in range(num_epochs):
                train_epoch(model, train_loader, optimizer,
                            scheduler if scheduler_type == 'cosine' else None,
                            criterion, device)
                val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, device, task_type)
                current_metric = selection_score(val_loss, val_metric)

                if scheduler_type == 'plateau':
                    scheduler.step(current_metric)
                elif scheduler_type == 'step':
                    scheduler.step()

                if current_metric > best_val_metric:
                    best_val_metric = current_metric
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        break

                trial.report(current_metric, epoch)
                if trial.should_prune():
                    raise optuna.TrialPruned()

            return best_val_metric

        except optuna.TrialPruned:
            # TrialPruned subclasses Exception; let Optuna see it as a pruning
            # signal instead of turning it into a failed (-inf) trial.
            raise
        except Exception as e:
            print(f"Trial failed with error: {e}")
            return -np.inf

    return objective

def _selection_score(task_type, val_loss, val_metric):
    """Higher-is-better validation score: ROC AUC for classification, -loss otherwise."""
    if task_type == 'classification':
        return val_metric if val_metric is not None else 0.0
    return -val_loss


def _make_scheduler(optimizer, params, num_epochs, steps_per_epoch):
    """Build the LR scheduler named by ``params['scheduler']``.

    Mirrors the schedulers used in the Optuna search.
    - The plateau scheduler tracks the higher-is-better validation score
      (mode 'max'). Its factor defaults to 0.1, the value every plateau trial
      used, because 'gamma' is only tuned for 'step'.
    - Cosine annealing spans the whole run counted in batches, because it is
      stepped once per batch.
    - StepLR decays every 10 epochs.
    """
    kind = params['scheduler']
    if kind == 'plateau':
        return optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=params.get('gamma', 0.1),
            patience=params.get('patience_lr', 5)
        )
    if kind == 'cosine':
        return optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs * max(1, steps_per_epoch)
        )
    return optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=params.get('gamma', 0.1))


def _augmentation_similarity(model, test_smiles, name, tokenizer, max_len, device):
    """Mean cosine similarity between embeddings of each test SMILES and a randomized spelling.

    The pairs are written to a temporary Parquet file named after the dataset.
    The file is removed afterwards even if encoding fails.
    """
    print("Creating pre-computed augmented SMILES for similarity computation...")
    path = f"{name}_test_augmented.parquet"
    try:
        create_augmented_smiles_file(list(test_smiles), path, num_augmentations=1)
        dataset = PrecomputedContrastiveSmilesDataset(tokenizer, path, max_length=max_len)
        similarities = compute_embedding_similarity_precomputed(model.encoder, dataset, device)
    finally:
        if os.path.exists(path):
            os.remove(path)
    score = similarities.mean()
    print(f"Similarity score: {score:.4f}")
    return score


def main():
    """Fine-tune and evaluate the SimSon encoder on several MoleculeNet datasets.

    For each dataset the script:
    1. splits the data (scaffold or random);
    2. tunes hyper-parameters with Optuna on the validation split;
    3. retrains with the best parameters, keeping the best-validation weights;
    4. reports the test metric;
    5. measures how similar the encoder's embeddings are for a SMILES and a
       randomized spelling of it.
    A summary over all datasets is printed at the end.
    """
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    DATASETS_TO_RUN = {
        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'scaffold'},
        'tox21': {'task_type': 'classification', 'num_labels': 12, 'split': 'random'},
        'bbbp': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'},
        'hiv': {'task_type': 'classification', 'num_labels': 1, 'split': 'scaffold'},
    }

    MAX_LEN = 512
    N_TRIALS = 100
    NUM_EPOCHS = 50
    PATIENCE = 15
    ENCODER_CHECKPOINT = '../simson_checkpoints/checkpoint_best_model.bin'

    TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ENCODER_CONFIG = BertConfig(
        vocab_size=TOKENIZER.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    aggregated_results = {}

    for name, info in DATASETS_TO_RUN.items():
        task_type = info['task_type']
        print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
        smiles, labels = load_lists_from_url(name)

        # ---- split ----
        if info.get('split', 'scaffold') == 'scaffold':
            splitter = ScaffoldSplitter(data=name, seed=42)
            train_idx, val_idx, test_idx = splitter.scaffold_split()
        elif info['split'] == 'random':
            train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
        else:
            raise ValueError(f"Unknown split type for {name}: {info['split']}")

        train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
        train_labels = labels.iloc[train_idx].reset_index(drop=True)
        val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
        val_labels = labels.iloc[val_idx].reset_index(drop=True)
        test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
        test_labels = labels.iloc[test_idx].reset_index(drop=True)
        print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")

        # ---- label scaling (regression only) ----
        # Fit on the training split only, so no validation/test label
        # statistics leak into training. Assumes a single regression target.
        scaler = None
        if task_type == "regression":
            scaler = StandardScaler()
            scaler.fit(train_labels.values.reshape(-1, 1))
            train_labels, val_labels, test_labels = (
                pd.Series(scaler.transform(part.values.reshape(-1, 1)).flatten(), index=part.index)
                for part in (train_labels, val_labels, test_labels)
            )

        # ---- hyper-parameter search ----
        study = optuna.create_study(
            direction='maximize',
            pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
        )
        objective_func = create_objective(
            name, info, train_smiles, train_labels, val_smiles, val_labels,
            test_smiles, test_labels, scaler, TOKENIZER, ENCODER_CONFIG, DEVICE
        )
        print(f"Starting Optuna optimization with {N_TRIALS} trials...")
        study.optimize(objective_func, n_trials=N_TRIALS, timeout=None)

        best_params = study.best_params
        best_score = study.best_value
        print(f"Best parameters: {best_params}")
        print(f"Best validation score: {best_score:.4f}")

        # ---- final training with the best parameters ----
        print("Training final model with best parameters...")
        batch_size = best_params['batch_size']
        train_loader = DataLoader(MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN),
                                  batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN),
                                batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN),
                                 batch_size=batch_size, shuffle=False)

        encoder = SimSonEncoder(ENCODER_CONFIG, MAX_LEN, dropout=best_params['dropout'])
        encoder = torch.compile(encoder)
        model = SimSonClassifier(encoder, num_labels=info['num_labels'], dropout=best_params['dropout']).to(DEVICE)
        model.load_encoder_params(ENCODER_CHECKPOINT)

        criterion = get_criterion(task_type, info['num_labels'])
        optimizer = optim.Adam(model.parameters(), lr=best_params['lr'], weight_decay=best_params['weight_decay'])
        scheduler_type = best_params['scheduler']
        scheduler = _make_scheduler(optimizer, best_params, NUM_EPOCHS, len(train_loader))

        best_val_metric = -np.inf
        best_model_state = None
        patience_counter = 0
        for epoch in range(NUM_EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer,
                                     scheduler if scheduler_type == 'cosine' else None,
                                     criterion, DEVICE)
            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, task_type)
            current_metric = _selection_score(task_type, val_loss, val_metric)

            if scheduler_type == 'plateau':
                scheduler.step(current_metric)
            elif scheduler_type == 'step':
                scheduler.step()

            if task_type == 'classification':
                print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {val_metric:.4f}")
            else:
                print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

            # Keep the weights of the best validation epoch seen so far.
            if current_metric > best_val_metric:
                best_val_metric = current_metric
                best_model_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= PATIENCE:
                    print(f'Early stopping at epoch {epoch+1}')
                    break

        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        # ---- test evaluation ----
        test_preds, test_true = test_model(model, test_loader, DEVICE, task_type)
        if task_type == 'regression' and scaler is not None:
            test_preds = scaler.inverse_transform(test_preds.reshape(-1, 1)).flatten()
            test_true = scaler.inverse_transform(test_true.reshape(-1, 1)).flatten()
            rmse = root_mean_squared_error(test_true, test_preds)
            mae = mean_absolute_error(test_true, test_preds)
            final_score = -rmse
            print(f"Test RMSE: {rmse:.4f}, MAE: {mae:.4f}")
        else:
            try:
                final_score = roc_auc_score(test_true, test_preds, average='macro')
                print(f"Test ROC AUC: {final_score:.4f}")
            except ValueError as err:
                # e.g. a test split with a single class: report it, don't hide it.
                final_score = 0.0
                print(f"Test ROC AUC could not be computed: {err}")

        # ---- augmentation-invariance of the embeddings ----
        similarity_score = _augmentation_similarity(model, test_smiles, name, TOKENIZER, MAX_LEN, DEVICE)

        aggregated_results[name] = {
            'best_score': final_score,  # test-set score of the final model
            'best_params': best_params,
            'optuna_trials': len(study.trials),
            'study': study,
            'similarity_score': similarity_score,
        }

        # Disabled save switch: no dataset above is named 'do_not_save'.
        if name == 'do_not_save':
            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')

    print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
    for name, result in aggregated_results.items():
        print(f"{name}: Best score: {result['best_score']:.4f}")
        print(f" Best parameters: {result['best_params']}")
        print(f" Total trials: {result['optuna_trials']}")
        print(f" Similarity score: {result['similarity_score']:.4f}")

    print("\nScript finished.")
|
|
if __name__ == '__main__':
    # Run the full benchmark only when executed as a script, not on import.
    main()
|
|