|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
import random
|
|
|
import pandas as pd
|
|
|
from pathlib import Path
|
|
|
from datetime import datetime
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
|
|
|
from transformers import AutoTokenizer
|
|
|
from FastChemTokenizerHF import FastChemTokenizer
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
from rdkit import Chem
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from ranger21 import Ranger21
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
from scipy.stats import entropy
|
|
|
import json
|
|
|
import math
|
|
|
from typing import Optional, Tuple, Union
|
|
|
from rdkit import RDLogger
|
|
|
# Silence RDKit's per-molecule parse warnings/errors during bulk SMILES filtering.
RDLogger.DisableLog('rdApp.*')
|
|
|
|
|
|
def set_seed(seed=42):
    """Seed every RNG this pipeline touches (random, numpy, torch, hash seed).

    Also forces deterministic cuDNN kernels, trading speed for
    run-to-run reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN: disables autotuned (non-reproducible) kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
|
|
|
|
|
|
|
|
|
# Prefer GPU when present; every model/tensor below is moved onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the corpus: one molecule per row in a 'SMILES' column.
# The file is assumed pre-canonicalized upstream (no canonicalization here).
data_path = "../data/sample_1k_smi_42.csv"
df = pd.read_csv(data_path)

if 'SMILES' not in df.columns:
    raise ValueError("Expected column 'SMILES' in CSV")

# Drop NaN rows so every downstream entry is a plain string.
smiles_list = df['SMILES'].dropna().tolist()
print(f"Loaded {len(smiles_list)} SMILES (assumed pre-canonicalized)")
|
|
|
|
|
|
|
|
|
|
|
|
def is_valid_smiles(smiles):
    """Return True when RDKit can parse *smiles* into a molecule object."""
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None
|
|
|
|
|
|
# Filter out anything RDKit cannot parse so metrics and training only ever
# see chemically valid strings.
print("Validating SMILES with RDKit...")
valid_mask = [is_valid_smiles(s) for s in tqdm(smiles_list)]
smiles_list = [s for s, valid in zip(smiles_list, valid_mask) if valid]
print(f"After RDKit filtering: {len(smiles_list)} valid SMILES")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 80/10/10 split: hold out 20%, then cut that half-and-half into val/test.
# Fixed random_state keeps the split reproducible across runs.
train_smiles, temp_smiles = train_test_split(smiles_list, test_size=0.2, random_state=42, shuffle=True)
val_smiles, test_smiles = train_test_split(temp_smiles, test_size=0.5, random_state=42, shuffle=True)

print(f"Train: {len(train_smiles)}")
print(f"Val: {len(val_smiles)}")
print(f"Test: {len(test_smiles)}")

# Persist each split (one SMILES per line) so later runs reuse identical data.
splits = {'train': train_smiles, 'val': val_smiles, 'test': test_smiles}
for split_name, smiles in splits.items():
    with open(f"../data/{split_name}_smiles.txt", "w") as f:
        f.write("\n".join(smiles))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenizerWrapper:
    """Uniform facade over heterogeneous tokenizer implementations.

    On construction, registers the four standard special tokens on the wrapped
    tokenizer (when it exposes ``add_special_tokens``), then offers a minimal
    encode/decode surface plus the special-token id properties used throughout
    the pipeline.
    """

    def __init__(self, tokenizer, name,
                 bos_token="<s>", eos_token="</s>",
                 pad_token="<pad>", unk_token="<unk>"):
        self.tokenizer = tokenizer
        self.name = name

        # Register special tokens only when the backend supports the hook.
        register = getattr(tokenizer, "add_special_tokens", None)
        if callable(register):
            special_map = {
                "bos_token": bos_token,
                "eos_token": eos_token,
                "pad_token": pad_token,
                "unk_token": unk_token,
            }
            try:
                register(special_map)
            except NotImplementedError:
                # Some custom tokenizers expose the hook without implementing it.
                pass

    def encode(self, smiles: str, add_special_tokens: bool = True):
        """Tokenize one SMILES string; returns a dict containing 'input_ids'."""
        return self.tokenizer(
            smiles,
            add_special_tokens=add_special_tokens,
            return_attention_mask=False,
            return_tensors=None,
        )

    def decode(self, token_ids, skip_special_tokens=True):
        """Map a list of token ids back to a SMILES string."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        """Vocabulary size of the wrapped tokenizer."""
        return len(self.tokenizer)

    def get_vocab(self):
        """Expose the wrapped tokenizer's token -> id mapping."""
        return self.tokenizer.get_vocab()

    # Special-token ids are delegated straight through to the backend.

    @property
    def bos_token_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def unk_token_id(self):
        return self.tokenizer.unk_token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tokenizer A: HuggingFace BPE tokenizer pretrained on ZINC SMILES.
tok1_hf = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
# Tokenizer B: project-local tokenizer loaded from disk.
tok2_fast = FastChemTokenizer.from_pretrained("../smitok_core")

# Wrap both behind the same interface so all later stages are tokenizer-agnostic.
tokenizer1 = TokenizerWrapper(tok1_hf, name="ChemBERTa", bos_token="<s>", eos_token="</s>", pad_token="<pad>", unk_token="<unk>")
tokenizer2 = TokenizerWrapper(tok2_fast, name="FastChemTokenizerHF", bos_token="<s>", eos_token="</s>", pad_token="<pad>", unk_token="<unk>")

# Every downstream stage (benchmark, training, evaluation) iterates this list.
TOKENIZERS = [tokenizer1, tokenizer2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def benchmark_tokenizer(tokenizer, smiles_sample, encode_only=False):
    """Measure quality and speed metrics for one tokenizer over a SMILES sample.

    Args:
        tokenizer: a TokenizerWrapper instance.
        smiles_sample: list of SMILES strings; capped at 10,000 molecules.
        encode_only: skip the decode/round-trip pass when True.

    Returns:
        dict with vocab size, mean tokens per molecule, chars-per-token
        compression ratio, %UNK, encode throughput, and (unless encode_only)
        decode throughput plus exact round-trip reconstruction accuracy.

    Raises:
        ValueError: if the sample is empty (the means/throughputs below would
        otherwise divide by zero).
    """
    V = len(tokenizer)
    # Cap at 10k molecules so the benchmark stays fast on large corpora.
    sample = smiles_sample[:10000]
    if not sample:
        raise ValueError("benchmark_tokenizer received an empty sample")

    encode_times, token_counts, char_counts = [], [], []
    unk_counts, total_tokens = 0, 0
    # Fix: resolve the UNK id once instead of re-reading the attribute chain
    # on every molecule inside the hot loop (and use the wrapper's property
    # for consistency with the rest of the file).
    unk_id = tokenizer.unk_token_id

    for smiles in tqdm(sample, desc=f"Encoding with {tokenizer.name}", leave=False):
        char_counts.append(len(smiles))
        # Time only the tokenizer call itself.
        start = time.perf_counter()
        enc = tokenizer.encode(smiles, add_special_tokens=True)
        end = time.perf_counter()
        encode_times.append(end - start)

        input_ids = enc['input_ids']
        token_counts.append(len(input_ids))
        total_tokens += len(input_ids)
        unk_counts += input_ids.count(unk_id)

    L_bar = np.mean(token_counts)          # mean tokens per molecule
    C = np.mean(char_counts) / L_bar       # characters per token (compression)
    U = unk_counts / total_tokens if total_tokens > 0 else 0.0
    Tenc = len(sample) / sum(encode_times)  # molecules encoded per second

    metrics = {
        'vocab_size': V,
        'avg_tokens_per_mol': L_bar,
        'compression_ratio': C,
        'percent_unknown': U * 100,
        'encode_throughput_smiles_per_sec': Tenc,
    }

    if encode_only:
        return metrics

    # Decode pass: re-encode (untimed) then time decode, and check that the
    # round trip reproduces the input string exactly.
    decode_times, reconstruction_ok = [], 0

    for smiles in tqdm(sample, desc=f"Decoding with {tokenizer.name}", leave=False):
        enc = tokenizer.encode(smiles, add_special_tokens=True)
        input_ids = enc['input_ids']
        start = time.perf_counter()
        decoded = tokenizer.decode(input_ids, skip_special_tokens=True)
        end = time.perf_counter()
        decode_times.append(end - start)
        if decoded == smiles:
            reconstruction_ok += 1

    Tdec = len(sample) / sum(decode_times)
    recon_acc = reconstruction_ok / len(sample)

    metrics.update({
        'decode_throughput_smiles_per_sec': Tdec,
        'decode_reconstruction_accuracy': recon_acc * 100,
    })

    return metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Benchmark every tokenizer on the training split and persist the metrics table.
benchmark_sample = train_smiles
results = []

for tokenizer in TOKENIZERS:
    print(f"\n=== Benchmarking {tokenizer.name} ===")
    metrics = benchmark_tokenizer(tokenizer, benchmark_sample)
    metrics['tokenizer'] = tokenizer.name
    results.append(metrics)
    # Pretty-print every metric except the tokenizer tag itself.
    for k, v in metrics.items():
        if k != 'tokenizer':
            print(f"{k:35s}: {v:.4f}" if isinstance(v, float) else f"{k:35s}: {v}")

df_results = pd.DataFrame(results)
df_results.to_csv("tokenizer_benchmark_results.csv", index=False)
print("\nTokenizer benchmark results saved to 'tokenizer_benchmark_results.csv'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from typing import Tuple, Optional
|
|
|
|
|
|
class MoleculeVAE(nn.Module):
    """
    Sequence VAE over tokenized SMILES.

    Architecture (as implemented below):
    - Bidirectional LSTM encoder over embedded tokens, pooled either via
      MultiheadAttention + a learned scalar weighting (use_attention=True)
      or a length-masked mean.
    - Linear heads produce the posterior mu/logvar; latent2hidden/latent2cell
      map a latent sample to the decoder LSTM's initial (h0, c0).
    - Unidirectional LSTM decoder emits per-step vocabulary logits, either
      free-running (decode) or teacher-forced (_forward_teacher_forcing).
    - Optional torch.quantization stubs; Identity no-ops when disabled.

    Original author's notes:
    - Bidirectional encoder (restored)
    - Proper latent2hidden + latent2cell (restored)
    - Adjustable dropout for small dataset
    - Attention pooling option
    - Quantization-ready hooks
    """

    def __init__(self,
                 vocab_size: int,
                 embed_dim: int = 128,
                 hidden_dim: int = 256,
                 latent_dim: int = 128,
                 num_layers: int = 2,
                 pad_token_id: int = 0,
                 bos_token_id: int = 1,
                 eos_token_id: int = 2,
                 dropout: float = 0.2,
                 use_attention: bool = True,
                 quantize_ready: bool = False):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.use_attention = use_attention

        # padding_idx keeps the pad embedding fixed at zero.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)

        # Bidirectional encoder -> feature size is hidden_dim * 2 downstream.
        self.encoder_lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        if use_attention:
            self.attention = nn.MultiheadAttention(
                hidden_dim * 2, num_heads=4, dropout=dropout, batch_first=True
            )
            # Learned per-position scalar used to weight the pooled sum.
            self.attention_linear = nn.Linear(hidden_dim * 2, 1)

        self.encoder_norm = nn.LayerNorm(hidden_dim * 2)

        # Posterior heads over the pooled encoder state.
        self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)

        # Project latent z into per-layer initial hidden/cell states.
        self.latent2hidden = nn.Linear(latent_dim, num_layers * hidden_dim)
        self.latent2cell = nn.Linear(latent_dim, num_layers * hidden_dim)

        # Unidirectional decoder (no bidirectionality at generation time).
        self.decoder_lstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers,
            batch_first=True, dropout=dropout if num_layers > 1 else 0
        )
        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

        # Tie input/output embeddings when the dimensions line up
        # (with the defaults 128 vs 256 this branch is not taken).
        if embed_dim == hidden_dim:
            self.fc_out.weight = self.embedding.weight

        self.dropout = nn.Dropout(dropout)

        # Quantization hooks; Identity makes them free when disabled.
        # NOTE(review): when enabled, self.quant is applied to integer token
        # ids in encode() — confirm that is the intended quantization point.
        if quantize_ready:
            self.quant = torch.quantization.QuantStub()
            self.dequant = torch.quantization.DeQuantStub()
        else:
            self.quant = self.dequant = nn.Identity()

        self._init_weights()

    def _init_weights(self):
        # Xavier for matrices, small-normal for 1-D weights, zeros for biases.
        # NOTE(review): the 1-D 'weight' branch also hits LayerNorm weights,
        # replacing their ones-initialization with N(0, 0.01) — confirm this
        # near-zero scaling of normalized outputs is intended.
        for name, param in self.named_parameters():
            if 'weight' in name:
                if param.ndim >= 2:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.normal_(param, 0, 0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def _pool_sequence(self, packed_output, lengths):
        """Collapse a packed encoder output into one vector per sequence."""
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        if self.use_attention:
            # Self-attention followed by a learned softmax weighting over
            # positions; note the softmax runs over all positions, padding
            # included (no key_padding_mask is supplied).
            attn_out, _ = self.attention(output, output, output)
            weights = torch.softmax(self.attention_linear(attn_out), dim=1)
            pooled = (weights * output).sum(dim=1)
        else:
            # Length-masked mean over valid timesteps only.
            batch_size, max_len, _ = output.size()
            mask = torch.arange(max_len, device=output.device).expand(batch_size, max_len) < lengths.unsqueeze(1)
            masked_output = output * mask.unsqueeze(-1).float()
            pooled = masked_output.sum(dim=1) / lengths.unsqueeze(-1).float()
        return pooled

    def encode(self, x: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode token ids (batch, seq) into posterior parameters (mu, logvar)."""
        x = self.quant(x)
        embedded = self.dropout(self.embedding(x))
        # pack_padded_sequence requires CPU lengths; enforce_sorted=False
        # lets batches arrive in any order.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.encoder_lstm(packed)
        h = self._pool_sequence(packed_out, lengths)
        h = self.encoder_norm(h)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mu, sigma^2) during training; return mu at eval time."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def _init_decoder_state(self, z: torch.Tensor):
        """Project latent z into the decoder LSTM's initial (h0, c0)."""
        batch_size = z.size(0)
        h0 = self.latent2hidden(z).view(self.num_layers, batch_size, self.hidden_dim)
        c0 = self.latent2cell(z).view(self.num_layers, batch_size, self.hidden_dim)
        return h0, c0

    def decode(self, z: torch.Tensor, max_length: int = 64, mode: str = "greedy", temperature: float = 1.0):
        """Free-running autoregressive decode from latent z.

        Returns the concatenated per-step logits (batch, steps, vocab);
        generation stops early once every sequence has emitted EOS.
        """
        batch_size = z.size(0)
        device = z.device
        h0, c0 = self._init_decoder_state(z)
        hidden = (h0, c0)

        input_ids = torch.full((batch_size, 1), self.bos_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        logits_list = []

        for _ in range(max_length):
            embedded = self.embedding(input_ids)
            output, hidden = self.decoder_lstm(embedded, hidden)
            output = self.decoder_norm(output)
            logit = self.fc_out(output)
            logits_list.append(logit)

            if mode == "greedy":
                next_tokens = logit.argmax(dim=-1)
            elif mode == "sample":
                probs = F.softmax(logit.squeeze(1) / temperature, dim=-1)
                next_tokens = torch.multinomial(probs, 1)
            else:
                raise ValueError(f"Unknown decode mode: {mode}")

            # NOTE(review): `finished` is updated before the masking below, so
            # the EOS token itself is replaced by PAD on the very step it is
            # produced — confirm downstream consumers expect that.
            just_finished = (next_tokens.squeeze(-1) == self.eos_token_id)
            finished |= just_finished
            next_tokens = torch.where(
                finished.unsqueeze(-1),
                torch.tensor(self.pad_token_id, device=device),
                next_tokens
            )
            input_ids = next_tokens
            if finished.all():
                break

        return self.dequant(torch.cat(logits_list, dim=1))

    def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor,
                target_seq: Optional[torch.Tensor] = None,
                teacher_forcing_ratio: float = 0.0,
                temperature: float = 1.0):
        """Encode -> reparameterize -> decode; returns (logits, mu, logvar).

        Uses scheduled teacher forcing during training when a target sequence
        is provided; otherwise decodes free-running (capped at the target
        length, or 64 when no target is given).
        """
        mu, logvar = self.encode(input_ids, lengths)
        z = self.reparameterize(mu, logvar)
        if self.training and target_seq is not None and teacher_forcing_ratio > 0:
            return self._forward_teacher_forcing(z, target_seq, teacher_forcing_ratio), mu, logvar
        else:
            max_len = target_seq.size(1) if target_seq is not None else 64
            return self.decode(z, max_length=max_len, temperature=temperature), mu, logvar

    def _forward_teacher_forcing(self, z: torch.Tensor, target_seq: torch.Tensor, teacher_forcing_ratio: float):
        """Decode seq_len-1 steps, feeding ground truth with probability tfr.

        The first target token seeds the decoder, so the returned logits
        cover positions 1..seq_len-1 (one step shorter than target_seq).
        """
        batch_size, seq_len = target_seq.size()
        h0, c0 = self._init_decoder_state(z)
        hidden = (h0, c0)
        logits_list = []
        input_token = target_seq[:, 0:1]

        for t in range(1, seq_len):
            embedded = self.embedding(input_token)
            output, hidden = self.decoder_lstm(embedded, hidden)
            output = self.decoder_norm(output)
            logit = self.fc_out(output)
            logits_list.append(logit)

            # One coin flip per step (shared across the batch).
            if torch.rand(1).item() < teacher_forcing_ratio:
                input_token = target_seq[:, t:t+1]
            else:
                input_token = logit.argmax(dim=-1)

        return torch.cat(logits_list, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vae_loss(logits, targets, mu, logvar, pad_token_id, beta=1.0):
    """Masked token cross-entropy plus beta-weighted KL(q(z|x) || N(0, I)).

    Logits and targets are right-padded to a common length first, so a
    free-running decode of a different length than the target still scores.
    PAD positions are excluded from the cross-entropy via masking.

    Returns:
        (total_loss, ce_loss, kl_loss) as scalar tensors, with
        total = ce + beta * kl.
    """
    # Align sequence lengths: pad logits with zeros, targets with PAD ids.
    common_len = max(logits.size(1), targets.size(1))
    if logits.size(1) < common_len:
        logits = F.pad(logits, (0, 0, 0, common_len - logits.size(1)))
    if targets.size(1) < common_len:
        targets = F.pad(targets, (0, common_len - targets.size(1)), value=pad_token_id)

    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)

    # Average cross-entropy over non-PAD tokens only.
    valid = (flat_targets != pad_token_id).float()
    per_token_ce = F.cross_entropy(flat_logits, flat_targets, reduction='none')
    ce_loss = (per_token_ce * valid).sum() / (valid.sum() + 1e-8)

    # Analytic KL between the diagonal-Gaussian posterior and the unit prior,
    # summed over latent dimensions, averaged over the batch.
    kl_per_sample = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
    kl_loss = kl_per_sample.mean()

    total_loss = ce_loss + beta * kl_loss
    return total_loss, ce_loss, kl_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
|
class KLAnnealer:
    """Schedules the KL weight (beta) over training steps.

    Supports single-ramp and cyclical schedules with linear, sigmoid, or
    cosine shapes. `ratio` is the fraction of a cycle spent ramping beta
    from 0 up to 1; beta then stays saturated at 1.
    """

    def __init__(self, total_steps, n_cycle=1, ratio=0.3, mode="linear", per_epoch=False, steps_per_epoch=None):
        self.total_steps = total_steps
        self.n_cycle = n_cycle
        self.ratio = ratio
        self.mode = mode
        self.per_epoch = per_epoch
        self.steps_per_epoch = steps_per_epoch
        self.current_step = 0
        self.current_epoch = 0

    def get_beta(self, increment=True):
        """Get current KL weight (beta).

        Args:
            increment (bool): whether to advance the annealer (use False in validation).
        """
        if increment:
            self.current_step += 1

        # Overall training progress, clamped to [0, 1].
        overall = min(self.current_step / max(self.total_steps, 1.0), 1.0)

        # Ramp fraction: position within the current cycle's warm-up window.
        if self.n_cycle > 1:
            cycle_len = self.total_steps / self.n_cycle
            within_cycle = self.current_step % cycle_len
            ramp = min(within_cycle / max(cycle_len * self.ratio, 1.0), 1.0)
        elif self.ratio > 0:
            ramp = min(overall / self.ratio, 1.0)
        else:
            ramp = 1.0

        # Shape the ramp according to the configured mode.
        if self.mode == "linear":
            beta = ramp
        elif self.mode == "sigmoid":
            # Steepness 6 centers the transition at the ramp midpoint.
            beta = 1 / (1 + math.exp(-6 * (ramp - 0.5)))
        elif self.mode == "cosine":
            beta = 0.5 * (1 + math.cos(math.pi * (1 - ramp)))
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

        return min(beta, 1.0)

    def step(self):
        """Increment the step counter."""
        self.current_step += 1

    def epoch_step(self):
        """Increment the epoch counter."""
        self.current_epoch += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_teacher_forcing_ratio(epoch, num_epochs, min_tfr=0.6, warmup_fraction=0.3):
    """
    Linear decay of teacher forcing ratio (TFR).
    - Starts at 1.0
    - Decays to min_tfr by (warmup_fraction * num_epochs)
    - Then stays flat
    """
    warmup_epochs = int(num_epochs * warmup_fraction)
    # At or past the warm-up boundary (including warmup_epochs == 0): flat floor.
    if epoch >= warmup_epochs:
        return min_tfr
    decay = (1.0 - min_tfr) * (epoch / warmup_epochs)
    return 1.0 - decay
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collate_fn(batch, tokenizer, max_length=128):
    """Tokenize, truncate, and right-pad a batch of SMILES strings.

    Args:
        batch: list of SMILES strings.
        tokenizer: TokenizerWrapper providing encode() and a pad id.
        max_length: hard cap on sequence length (longer encodings truncate).

    Returns:
        (input_ids, lengths): LongTensors of shape (B, max_len) and (B,),
        where lengths holds the REAL (pre-padding) token count per sequence.

    Fix: the previous version measured len(ids) AFTER padding, so every
    short molecule reported the batch-max length — which made the encoder's
    pack_padded_sequence treat padding as real tokens.
    """
    encodings = [tokenizer.encode(s, add_special_tokens=True) for s in batch]
    input_ids = [e['input_ids'] for e in encodings]

    # Pad to the longest sequence in the batch, capped at max_length.
    max_len = min(max(len(ids) for ids in input_ids), max_length)
    pad_token_id = tokenizer.tokenizer.pad_token_id

    padded, lengths = [], []
    for ids in input_ids:
        # Record the true token count before any truncation/padding.
        lengths.append(min(len(ids), max_len))
        ids = ids[:max_len]
        padded.append(ids + [pad_token_id] * (max_len - len(ids)))

    return torch.tensor(padded, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SmilesDataset(Dataset):
    """Minimal Dataset over a list of SMILES strings.

    Items are raw strings; tokenization and padding happen later in the
    DataLoader's collate function.
    """

    def __init__(self, smiles_list):
        # Kept by reference (no copy); callers should not mutate afterwards.
        self.smiles_list = smiles_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        return self.smiles_list[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Training hyper-parameters, shared across both tokenizer runs.
LEARNING_RATE = 1e-5        # conservative LR for Ranger21 on a small dataset
BATCH_SIZE = 16
ACCUMULATION_STEPS = 4      # effective batch = BATCH_SIZE * ACCUMULATION_STEPS
NUM_EPOCHS = 5
MAX_SEQ_LEN = 128           # token cap per molecule (truncate/pad in collate_fn)
# NOTE(review): defined but the KLAnnealer below is constructed with
# ratio=0.6, not this constant — confirm which value is intended.
KL_ANNEAL_RATIO = 0.3
|
|
|
|
|
|
def train_vae(
    model,
    train_loader,
    val_loader,
    optimizer,
    kl_annealer,
    pad_token_id,
    device,
    num_epochs,
    accumulation_steps=4,
    save_dir="./checkpoints",
    tokenizer_name="default"
):
    """Train the VAE with gradient accumulation, KL annealing, and scheduled
    teacher forcing; validate each epoch and checkpoint the best model.

    Writes a per-epoch CSV log to save_dir and returns the best (lowest)
    validation loss observed.
    """
    os.makedirs(save_dir, exist_ok=True)
    log_file = os.path.join(save_dir, f"training_log_{tokenizer_name}.csv")

    # Fresh CSV header each run (overwrites any previous log).
    with open(log_file, "w") as f:
        f.write("epoch,step,train_loss,train_ce,train_kl,val_loss,val_ce,val_kl,kl_beta\n")

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
        model.train()
        total_train_loss = total_train_ce = total_train_kl = 0.0
        num_batches = 0

        optimizer.zero_grad()

        for step, (input_ids, lengths) in enumerate(tqdm(train_loader, desc="Training")):
            input_ids, lengths = input_ids.to(device), lengths.to(device)

            # Scheduled teacher forcing: decays with epoch, constant within one.
            tfr = get_teacher_forcing_ratio(epoch, num_epochs, min_tfr=0.6, warmup_fraction=0.3)

            # Autoencoding objective: the input sequence is also the target.
            logits, mu, logvar = model(input_ids, lengths, target_seq=input_ids, teacher_forcing_ratio=tfr)
            beta = kl_annealer.get_beta(increment=True)
            loss, ce_loss, kl_loss = vae_loss(logits, input_ids, mu, logvar, pad_token_id, beta=beta)

            # Scale down for accumulation so gradients average over the window.
            loss = loss / accumulation_steps
            loss.backward()

            # Undo the scaling for logging purposes.
            total_train_loss += loss.item() * accumulation_steps
            total_train_ce += ce_loss.item()
            total_train_kl += kl_loss.item()
            num_batches += 1

            if (step + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Flush any leftover gradients when the epoch length is not a
        # multiple of the accumulation window.
        if len(train_loader) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Peek at beta without advancing the schedule (validation must not
        # move the annealer).
        current_beta = kl_annealer.get_beta(increment=False)

        model.eval()
        total_val_loss = total_val_ce = total_val_kl = 0.0
        val_batches = 0

        with torch.no_grad():
            for input_ids, lengths in tqdm(val_loader, desc="Validating"):
                input_ids, lengths = input_ids.to(device), lengths.to(device)

                # No teacher forcing at validation: free-running decode.
                logits, mu, logvar = model(input_ids, lengths, target_seq=input_ids, teacher_forcing_ratio=0.0)
                loss, ce_loss, kl_loss = vae_loss(logits, input_ids, mu, logvar, pad_token_id, beta=current_beta)

                total_val_loss += loss.item()
                total_val_ce += ce_loss.item()
                total_val_kl += kl_loss.item()
                val_batches += 1

        avg_train_loss = total_train_loss / num_batches
        avg_val_loss = total_val_loss / val_batches

        # Append one CSV row per epoch.
        current_step = (epoch + 1) * len(train_loader)
        with open(log_file, "a") as f:
            f.write(f"{epoch+1},{current_step},{avg_train_loss:.6f},{total_train_ce/num_batches:.6f},{total_train_kl/num_batches:.6f},"
                    f"{avg_val_loss:.6f},{total_val_ce/val_batches:.6f},{total_val_kl/val_batches:.6f},{current_beta:.6f}\n")

        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Val Loss: {avg_val_loss:.4f}")
        print(f"KL Beta: {current_beta:.4f}")

        # Checkpoint on validation-loss improvement only.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint_path = os.path.join(save_dir, f"best_model_{tokenizer_name}.pt")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
            }, checkpoint_path)
            print(f"β Saved best model to {checkpoint_path}")

    return best_val_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Train one VAE per tokenizer under identical hyper-parameters so the two
# tokenizations can be compared head-to-head.
for tokenizer in TOKENIZERS:
    print(f"\n STARTING TRAINING FOR: {tokenizer.name}\n")

    vocab_size = len(tokenizer)
    pad_token_id = tokenizer.tokenizer.pad_token_id

    # Sanity check: ids produced by the tokenizer must fit the embedding table.
    sample_ids = tokenizer.encode(train_smiles[0], add_special_tokens=True)['input_ids']
    max_id_in_sample = max(sample_ids)
    assert max_id_in_sample < vocab_size, f"Token ID {max_id_in_sample} >= vocab size {vocab_size} in {tokenizer.name}"

    model = MoleculeVAE(
        vocab_size=len(tokenizer),
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    ).to(device)

    # Ranger21 wants the schedule geometry up front; batches are counted in
    # effective (accumulated) steps here.
    optimizer = Ranger21(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=0.01,
        use_adabelief=True,
        use_warmup=True,
        use_madgrad=True,
        num_epochs=NUM_EPOCHS,
        num_batches_per_epoch=len(train_smiles) // (BATCH_SIZE * ACCUMULATION_STEPS),
        warmdown_active=False,
    )

    train_dataset = SmilesDataset(train_smiles)
    val_dataset = SmilesDataset(val_smiles)

    # The lambda closes over `tokenizer`, which is safe because the loaders
    # are consumed within this same loop iteration.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, tokenizer, max_length=MAX_SEQ_LEN),
        num_workers=0,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, tokenizer, max_length=MAX_SEQ_LEN),
        num_workers=0,
        pin_memory=True
    )

    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * NUM_EPOCHS

    # Single linear KL ramp over the first 60% of all optimizer steps.
    kl_annealer = KLAnnealer(
        total_steps=total_steps,
        n_cycle=1,
        ratio=0.6,
        mode="linear",
        per_epoch=False
    )

    train_vae(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        kl_annealer=kl_annealer,
        pad_token_id=pad_token_id,
        device=device,
        num_epochs=NUM_EPOCHS,
        accumulation_steps=ACCUMULATION_STEPS,
        save_dir=f"./checkpoints/{tokenizer.name}",
        tokenizer_name=tokenizer.name
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def canonicalize_smiles(smiles):
    """Return RDKit-canonical isomeric SMILES, or None when unparsable."""
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol, isomericSmiles=True) if mol is not None else None
|
|
|
|
|
|
def evaluate_reconstruction(model, dataloader, tokenizer, device, max_length=128):
    """Score greedy latent-space reconstruction of a dataset.

    For each batch: encode, reparameterize, greedy-decode, then compute
    token-level accuracy (PAD-masked), exact string match rate, and the
    fraction of decodes RDKit accepts as valid molecules.

    Fix: the decode call previously hard-coded max_length=128, silently
    ignoring this function's max_length parameter.

    Returns:
        dict with 'token_accuracy', 'exact_match_rate', 'validity_rate',
        and the full 'generated_smiles' / 'target_smiles' lists.
    """
    model.eval()
    total_token_correct = total_tokens = exact_matches = valid_count = total_samples = 0
    all_generated, all_targets = [], []

    pad_id = tokenizer.tokenizer.pad_token_id
    eos_id = tokenizer.tokenizer.eos_token_id
    special_ids = {pad_id, eos_id}

    def trim_to_special(ids, specials):
        # Cut the sequence at the first PAD/EOS so string comparison ignores
        # everything after the molecule proper.
        for i, id_ in enumerate(ids):
            if id_ in specials:
                return ids[:i]
        return ids

    with torch.no_grad():
        for input_ids, lengths in tqdm(dataloader, desc="Evaluating Reconstruction"):
            input_ids, lengths = input_ids.to(device), lengths.to(device)
            B = input_ids.size(0)

            mu, logvar = model.encode(input_ids, lengths)
            z = model.reparameterize(mu, logvar)
            logits = model.decode(z, max_length=max_length, mode="greedy")
            preds = logits.argmax(dim=-1)

            # Decoder may stop early or run longer than the input; compare
            # only the overlapping prefix.
            min_len = min(logits.size(1), input_ids.size(1))
            preds = preds[:, :min_len]
            input_ids_eval = input_ids[:, :min_len]

            # Token accuracy over non-PAD target positions only.
            mask = (input_ids_eval != pad_id)
            token_correct = ((preds == input_ids_eval) & mask).sum().item()
            total_token_correct += token_correct
            total_tokens += mask.sum().item()

            for i in range(B):
                target_ids = input_ids_eval[i].cpu().tolist()
                pred_ids = preds[i].cpu().tolist()

                target_ids_trim = trim_to_special(target_ids, special_ids)
                pred_ids_trim = trim_to_special(pred_ids, special_ids)

                target_smiles = tokenizer.decode(target_ids_trim, skip_special_tokens=False)
                pred_smiles = tokenizer.decode(pred_ids_trim, skip_special_tokens=False)

                all_targets.append(target_smiles)
                all_generated.append(pred_smiles)

                if pred_smiles == target_smiles:
                    exact_matches += 1
                if Chem.MolFromSmiles(pred_smiles) is not None:
                    valid_count += 1
                total_samples += 1

    token_acc = total_token_correct / total_tokens if total_tokens > 0 else 0.0
    exact_match_rate = exact_matches / total_samples
    validity_rate = valid_count / total_samples

    print(f"Token-level Accuracy: {token_acc:.4f}")
    print(f"Exact Match Rate: {exact_match_rate:.4f}")
    print(f"Validity Rate: {validity_rate:.4f}")

    return {
        'token_accuracy': token_acc,
        'exact_match_rate': exact_match_rate,
        'validity_rate': validity_rate,
        'generated_smiles': all_generated,
        'target_smiles': all_targets
    }
|
|
|
|
|
|
def compute_uniqueness_and_novelty(generated_smiles, train_smiles_set):
    """Report (uniqueness, novelty) fractions for generated SMILES.

    Uniqueness = distinct strings / total generated.
    Novelty    = generated strings absent from the training set / total
    (counted over ALL generations, duplicates included).
    """
    total = len(generated_smiles)
    unique = len(set(generated_smiles))
    novel = sum(1 for s in generated_smiles if s not in train_smiles_set)

    uniqueness = unique / total if total > 0 else 0.0
    novelty = novel / total if total > 0 else 0.0

    print(f"Uniqueness: {uniqueness:.4f} ({unique}/{total})")
    print(f"Novelty: {novelty:.4f} ({novel}/not in train)")
    return uniqueness, novelty
|
|
|
|
|
|
def kl_divergence_from_samples(samples, bins=512):
    """Mean per-dimension KL(empirical || N(0,1)) from histogram estimates.

    Each latent dimension is histogrammed (density-normalized) and compared
    against the standard-normal pdf evaluated at the bin centers; both are
    clipped away from zero to keep the KL finite.
    """
    per_dim_kl = []
    for dim in range(samples.shape[1]):
        values = samples[:, dim]
        hist, edges = np.histogram(values, bins=bins, density=True)
        centers = (edges[:-1] + edges[1:]) / 2
        ref_pdf = (1 / np.sqrt(2 * np.pi)) * np.exp(-0.5 * centers**2)
        p = np.clip(hist, 1e-10, None)
        q = np.clip(ref_pdf, 1e-10, None)
        per_dim_kl.append(entropy(p, q))
    return np.mean(per_dim_kl)
|
|
|
|
|
|
def evaluate_latent_kl(model, dataloader, device, latent_dim=128, bins=512):
    """Encode a dataset into latent samples and score them against N(0, 1).

    Uses model.reparameterize, so in eval mode the collected latents are the
    posterior means. Returns the mean per-dimension histogram KL.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for input_ids, lengths in tqdm(dataloader, desc="Sampling Latents"):
            input_ids, lengths = input_ids.to(device), lengths.to(device)
            mu, logvar = model.encode(input_ids, lengths)
            z = model.reparameterize(mu, logvar)
            collected.append(z.cpu().numpy())
    latents = np.concatenate(collected, axis=0)
    kl_div = kl_divergence_from_samples(latents, bins=bins)
    print(f"KL Divergence (empirical vs N(0,1)): {kl_div:.4f}")
    return kl_div
|
|
|
|
|
|
def evaluate_interpolation_validity(model, tokenizer, test_smiles, device, num_pairs=100, steps=10, max_length=128):
    """Fraction of chemically valid decodes along latent interpolations.

    Pairs consecutive even/odd test molecules, linearly interpolates between
    their posterior means in `steps` increments, decodes each point with
    temperature sampling, and counts RDKit-valid outputs.
    """
    model.eval()
    # Pair up (even-indexed, odd-indexed) molecules, then subsample pairs.
    pairs = random.sample(list(zip(test_smiles[::2], test_smiles[1::2])), min(num_pairs, len(test_smiles)//2))
    valid_interps = total_interps = 0

    with torch.no_grad():
        for smiles_a, smiles_b in tqdm(pairs, desc="Interpolation Validity"):
            if not smiles_a or not smiles_b: continue

            enc_a = tokenizer.encode(smiles_a, add_special_tokens=True)
            enc_b = tokenizer.encode(smiles_b, add_special_tokens=True)

            # Batch dimension of 1 for each endpoint.
            ids_a = torch.tensor([enc_a['input_ids']], device=device)
            ids_b = torch.tensor([enc_b['input_ids']], device=device)
            len_a = torch.tensor([len(enc_a['input_ids'])], device=device)
            len_b = torch.tensor([len(enc_b['input_ids'])], device=device)

            # Interpolate between posterior MEANS (no sampling noise).
            mu_a, _ = model.encode(ids_a, len_a)
            mu_b, _ = model.encode(ids_b, len_b)

            alphas = torch.linspace(0, 1, steps, device=device)
            for alpha in alphas:
                z_interp = alpha * mu_b + (1 - alpha) * mu_a

                if z_interp.dim() == 1:
                    z_interp = z_interp.unsqueeze(0)

                # NOTE(review): decode runs in "sample" mode but the output
                # tokens are taken via argmax of the returned logits, which
                # need not match the tokens actually sampled during decoding.
                logits = model.decode(z_interp, max_length=max_length, mode="sample", temperature=0.8)
                preds = logits.argmax(dim=-1)

                if preds.dim() > 1:
                    preds = preds[0]
                pred_smiles = tokenizer.decode(preds.cpu().tolist(), skip_special_tokens=True)
                if Chem.MolFromSmiles(pred_smiles) is not None:
                    valid_interps += 1
                total_interps += 1

    interp_validity = valid_interps / total_interps if total_interps > 0 else 0.0
    print(f"Interpolation Validity: {interp_validity:.4f}")
    return interp_validity
|
|
|
|
|
|
def sample_from_latent(model, tokenizer, num_samples=30000, latent_dim=128, max_length=128, device=device, temperature=0.8):
    """Generate SMILES by decoding latent vectors drawn from the N(0, I) prior.

    Decodes in BATCH_SIZE chunks until num_samples strings are collected.
    NOTE(review): decode runs in "sample" mode but tokens are read back via
    argmax over the returned logits, which need not match the tokens the
    decoder actually sampled along the way.
    """
    model.eval()
    generated_smiles = []
    with torch.no_grad():
        for _ in tqdm(range(0, num_samples, BATCH_SIZE), desc="Sampling from Latent"):
            remaining = num_samples - len(generated_smiles)
            chunk = min(BATCH_SIZE, remaining)
            if chunk <= 0: break
            z = torch.randn(chunk, latent_dim, device=device)
            logits = model.decode(z, max_length=max_length, mode="sample", temperature=temperature)
            preds = logits.argmax(dim=-1)
            for row in range(chunk):
                token_ids = preds[row].cpu().tolist()
                generated_smiles.append(tokenizer.decode(token_ids, skip_special_tokens=True))
                if len(generated_smiles) >= num_samples: break
    return generated_smiles
|
|
|
|
|
|
def measure_inference_throughput(model, tokenizer, test_smiles, device,
                                 max_length=128,
                                 batch_sizes=(1, 4, 8, 16)):
    """
    Benchmark inference speed & peak GPU memory across several batch sizes.

    For each batch size, runs a full encode -> reparameterize -> decode pass
    over 10 batches drawn from `test_smiles` and times it with
    `time.perf_counter`, synchronizing the CUDA stream around the timed
    region so asynchronous kernel launches are actually measured.

    Args:
        model: VAE exposing `encode`, `reparameterize` and `decode`.
        tokenizer: passed through to the module-level `collate_fn`.
        test_smiles: list of SMILES strings to benchmark on.
        device: torch device to run on.
        max_length: maximum decode length.
        batch_sizes: batch sizes to benchmark. Tuple default — a mutable
            list default would be shared across calls.

    Returns a JSON-serialisable dict:
        {batch_size: {'tokens_per_sec': <float>, 'peak_mem_mb': <float>}, ...}
    """
    model.eval()
    results = {}

    for bs in batch_sizes:
        # 10 batches per size keeps the benchmark short but non-trivial.
        subset = SmilesDataset(test_smiles[:bs * 10])
        loader = DataLoader(
            subset,
            batch_size=bs,
            shuffle=False,
            num_workers=0,
            collate_fn=lambda b: collate_fn(b, tokenizer, max_length=max_length),
        )

        total_tokens = 0
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(device)
            # Drain previously queued kernels so they are not counted below.
            torch.cuda.synchronize(device)

        start_time = time.perf_counter()
        with torch.no_grad():
            for input_ids, lengths in loader:
                input_ids, lengths = input_ids.to(device), lengths.to(device)
                mu, logvar = model.encode(input_ids, lengths)
                z = model.reparameterize(mu, logvar)
                logits = model.decode(z, max_length=max_length)
                # Count generated token positions only: logits are
                # [batch, seq, vocab]; numel() would multiply in the
                # vocabulary dimension and wildly inflate tokens/sec.
                total_tokens += logits.shape[0] * logits.shape[1]
        if torch.cuda.is_available():
            # CUDA launches are async; wait for completion before stopping
            # the clock or the measurement is meaningless.
            torch.cuda.synchronize(device)
        duration = time.perf_counter() - start_time

        tokens_per_sec = total_tokens / duration
        peak_mem_mb = (
            torch.cuda.max_memory_allocated(device) / (1024 ** 2)
            if torch.cuda.is_available()
            else 0.0
        )

        results[bs] = {
            "tokens_per_sec": float(tokens_per_sec),
            "peak_mem_mb": float(peak_mem_mb),
        }
        print(f"BS {bs:3d} β {tokens_per_sec:8.2f} tok/s | Peak Mem: {peak_mem_mb:.2f} MB")

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def full_evaluation_pipeline(model, tokenizer, train_smiles, test_smiles, device, save_dir):
    """Run the complete post-training evaluation for one tokenizer/model pair.

    Computes reconstruction metrics, uniqueness/novelty versus the training
    corpus, latent KL divergence, and interpolation validity, then writes the
    merged results to `<save_dir>/evaluation_results.json` and returns them.
    """
    print(f"\n FULL EVALUATION FOR: {tokenizer.name}")

    loader = DataLoader(
        SmilesDataset(test_smiles),
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, tokenizer, max_length=MAX_SEQ_LEN),
        num_workers=0,
    )

    # Reconstruction quality on the held-out set.
    metrics = evaluate_reconstruction(model, loader, tokenizer, device)

    # Uniqueness / novelty of reconstructions relative to the training corpus.
    seen = set(train_smiles)
    uniq, novel = compute_uniqueness_and_novelty(metrics['generated_smiles'], seen)

    results = dict(metrics)
    results['uniqueness'] = uniq
    results['novelty'] = novel
    results['kl_divergence'] = evaluate_latent_kl(model, loader, device)
    results['interpolation_validity'] = evaluate_interpolation_validity(model, tokenizer, test_smiles, device)

    out_path = os.path.join(save_dir, "evaluation_results.json")
    with open(out_path, "w") as fh:
        json.dump(results, fh, indent=2, default=str)

    print(f" Evaluation saved to {out_path}")
    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Final stage: load each tokenizer's best checkpoint and evaluate it ----
for tokenizer in TOKENIZERS:
    print(f"\nπ LOADING BEST MODEL FOR: {tokenizer.name}")
    ckpt_dir = f"./checkpoints/{tokenizer.name}"
    checkpoint_path = f"./checkpoints/{tokenizer.name}/best_model_{tokenizer.name}.pt"
    if not Path(checkpoint_path).exists():
        print(f"β οΈ Checkpoint not found: {checkpoint_path}")
        continue

    # Rebuild the architecture with this tokenizer's vocabulary/special ids.
    model = MoleculeVAE(
        vocab_size=len(tokenizer),
        pad_token_id=tokenizer.tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    ).to(device)

    # Restore the best weights saved during training.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    full_evaluation_pipeline(
        model=model,
        tokenizer=tokenizer,
        train_smiles=train_smiles,
        test_smiles=test_smiles,
        device=device,
        save_dir=ckpt_dir,
    )

print("\nπ PIPELINE COMPLETE β ALL TOKENIZERS BENCHMARKED, TRAINED, AND EVALUATED!")
|
|
|
|