#
# Molecule Tokenizer Benchmark & VAE Training Pipeline
# PATCHED VERSION - Updated for FastChemTokenizerHF (HF compatible)
#
#
# Step 1.1 - Imports & Reproducibility
#
import os
import time
import random
import pandas as pd
from pathlib import Path
from datetime import datetime
import torch
import numpy as np
# Tokenizers
from transformers import AutoTokenizer
from FastChemTokenizerHF import FastChemTokenizer
# Optional: for progress bars
from tqdm import tqdm
from rdkit import Chem
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F
from ranger21 import Ranger21
from torch.utils.data import DataLoader, Dataset
from scipy.stats import entropy
import json
import math
from typing import Optional, Tuple, Union
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
# Set seeds for reproducibility
def set_seed(seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(42)
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
#
# Step 1.2 - Load & Preprocess SMILES Corpus
#
data_path = "../data/sample_1k_smi_42.csv"
df = pd.read_csv(data_path)
if 'SMILES' not in df.columns:
raise ValueError("Expected column 'SMILES' in CSV")
smiles_list = df['SMILES'].dropna().tolist()
print(f"Loaded {len(smiles_list)} SMILES (assumed pre-canonicalized)")
# Validate with RDKit
def is_valid_smiles(smiles):
return Chem.MolFromSmiles(smiles) is not None
print("Validating SMILES with RDKit...")
valid_mask = [is_valid_smiles(s) for s in tqdm(smiles_list)]
smiles_list = [s for s, valid in zip(smiles_list, valid_mask) if valid]
print(f"After RDKit filtering: {len(smiles_list)} valid SMILES")
#
# Step 1.3 - Train/Val/Test Split (80/10/10)
#
train_smiles, temp_smiles = train_test_split(smiles_list, test_size=0.2, random_state=42, shuffle=True)
val_smiles, test_smiles = train_test_split(temp_smiles, test_size=0.5, random_state=42, shuffle=True)
print(f"Train: {len(train_smiles)}")
print(f"Val: {len(val_smiles)}")
print(f"Test: {len(test_smiles)}")
# Cache splits
splits = {'train': train_smiles, 'val': val_smiles, 'test': test_smiles}
for split_name, smiles in splits.items():
with open(f"../data/{split_name}_smiles.txt", "w") as f:
f.write("\n".join(smiles))
#
# Step 1.4 - Tokenizer Wrapper (Simplified for HF compatibility)
#
class TokenizerWrapper:
def __init__(self, tokenizer, name,
bos_token="<s>", eos_token="</s>",
pad_token="<pad>", unk_token="<unk>"):
self.tokenizer = tokenizer
self.name = name
# Only call add_special_tokens if the tokenizer actually supports it
if hasattr(tokenizer, "add_special_tokens") and callable(tokenizer.add_special_tokens):
try:
tokenizer.add_special_tokens({
"bos_token": bos_token,
"eos_token": eos_token,
"pad_token": pad_token,
"unk_token": unk_token,
})
except NotImplementedError:
# Your FastChemTokenizerHF already defines these tokens internally
pass
def encode(self, smiles: str, add_special_tokens: bool = True):
return self.tokenizer(
smiles,
add_special_tokens=add_special_tokens,
return_attention_mask=False,
return_tensors=None
)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
def __len__(self):
return len(self.tokenizer)
def get_vocab(self):
return self.tokenizer.get_vocab()
@property
def bos_token_id(self):
return self.tokenizer.bos_token_id
@property
def eos_token_id(self):
return self.tokenizer.eos_token_id
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
@property
def unk_token_id(self):
return self.tokenizer.unk_token_id
#
# Step 1.5 - Initialize Tokenizers
#
tok1_hf = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
tok2_fast = FastChemTokenizer.from_pretrained("../smitok_core")
tokenizer1 = TokenizerWrapper(tok1_hf, name="ChemBERTa", bos_token="<s>", eos_token="</s>", pad_token="<pad>", unk_token="<unk>")
tokenizer2 = TokenizerWrapper(tok2_fast, name="FastChemTokenizerHF", bos_token="<s>", eos_token="</s>", pad_token="<pad>", unk_token="<unk>")
TOKENIZERS = [tokenizer1, tokenizer2]
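# Illustrative round-trip check (not part of the benchmark): encode and decode the
# first validated SMILES with each wrapped tokenizer to confirm the shared interface.
for _tok in TOKENIZERS:
    _ids = _tok.encode(smiles_list[0], add_special_tokens=True)['input_ids']
    print(f"[{_tok.name}] {smiles_list[0]} -> {len(_ids)} tokens -> {_tok.decode(_ids, skip_special_tokens=True)}")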
#
# Step 1.6 - Benchmarking Functions (Fixed Bug #4 implicitly via epsilon)
#
def benchmark_tokenizer(tokenizer, smiles_sample, encode_only=False):
V = len(tokenizer)
sample = smiles_sample[:10000] if len(smiles_sample) > 10000 else smiles_sample
encode_times, token_counts, char_counts = [], [], []
unk_counts, total_tokens = 0, 0
for smiles in tqdm(sample, desc=f"Encoding with {tokenizer.name}", leave=False):
char_counts.append(len(smiles))
start = time.perf_counter()
enc = tokenizer.encode(smiles, add_special_tokens=True)
end = time.perf_counter()
encode_times.append(end - start)
input_ids = enc['input_ids']
token_counts.append(len(input_ids))
total_tokens += len(input_ids)
unk_id = tokenizer.tokenizer.unk_token_id
unk_counts += input_ids.count(unk_id)
L_bar = np.mean(token_counts)
C = np.mean(char_counts) / L_bar
U = unk_counts / total_tokens if total_tokens > 0 else 0.0
Tenc = len(sample) / sum(encode_times)
metrics = {
'vocab_size': V,
'avg_tokens_per_mol': L_bar,
'compression_ratio': C,
'percent_unknown': U * 100,
'encode_throughput_smiles_per_sec': Tenc,
}
if encode_only:
return metrics
decode_times, reconstruction_ok = [], 0
for smiles in tqdm(sample, desc=f"Decoding with {tokenizer.name}", leave=False):
enc = tokenizer.encode(smiles, add_special_tokens=True)
input_ids = enc['input_ids']
start = time.perf_counter()
decoded = tokenizer.decode(input_ids, skip_special_tokens=True)
end = time.perf_counter()
decode_times.append(end - start)
if decoded == smiles:
reconstruction_ok += 1
Tdec = len(sample) / sum(decode_times)
recon_acc = reconstruction_ok / len(sample)
metrics.update({
'decode_throughput_smiles_per_sec': Tdec,
'decode_reconstruction_accuracy': recon_acc * 100,
})
return metrics
#
# Step 1.7 - Run Benchmark
#
benchmark_sample = train_smiles
results = []
for tokenizer in TOKENIZERS:
print(f"\n=== Benchmarking {tokenizer.name} ===")
metrics = benchmark_tokenizer(tokenizer, benchmark_sample)
metrics['tokenizer'] = tokenizer.name
results.append(metrics)
for k, v in metrics.items():
if k != 'tokenizer':
print(f"{k:35s}: {v:.4f}" if isinstance(v, float) else f"{k:35s}: {v}")
df_results = pd.DataFrame(results)
df_results.to_csv("tokenizer_benchmark_results.csv", index=False)
print("\nTokenizer benchmark results saved to 'tokenizer_benchmark_results.csv'")
#
# Step 2.1 - VAE Model Class (PATCHED: decode stops at EOS)
#
class MoleculeVAE(nn.Module):
"""
Optimized MoleculeVAE with:
- Bidirectional encoder (restored)
- Proper latent2hidden + latent2cell (restored)
- Adjustable dropout for small dataset
- Attention pooling option
- Quantization-ready hooks
"""
def __init__(self,
vocab_size: int,
embed_dim: int = 128,
hidden_dim: int = 256,
latent_dim: int = 128,
num_layers: int = 2,
pad_token_id: int = 0,
bos_token_id: int = 1,
eos_token_id: int = 2,
dropout: float = 0.2,
use_attention: bool = True,
quantize_ready: bool = False):
super().__init__()
self.vocab_size = vocab_size
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.num_layers = num_layers
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.use_attention = use_attention
# Shared embedding
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
# Bidirectional encoder
self.encoder_lstm = nn.LSTM(
embed_dim, hidden_dim, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0,
bidirectional=True
)
# Attention pooling (optional)
if use_attention:
self.attention = nn.MultiheadAttention(
hidden_dim * 2, num_heads=4, dropout=dropout, batch_first=True
)
self.attention_linear = nn.Linear(hidden_dim * 2, 1)
self.encoder_norm = nn.LayerNorm(hidden_dim * 2)
# Latent bottleneck
self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)
# Decoder init (restored)
self.latent2hidden = nn.Linear(latent_dim, num_layers * hidden_dim)
self.latent2cell = nn.Linear(latent_dim, num_layers * hidden_dim)
# Decoder
self.decoder_lstm = nn.LSTM(
embed_dim, hidden_dim, num_layers,
batch_first=True, dropout=dropout if num_layers > 1 else 0
)
self.decoder_norm = nn.LayerNorm(hidden_dim)
self.fc_out = nn.Linear(hidden_dim, vocab_size)
# Weight tying
if embed_dim == hidden_dim:
self.fc_out.weight = self.embedding.weight
self.dropout = nn.Dropout(dropout)
# Quantization stubs
if quantize_ready:
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
else:
self.quant = self.dequant = nn.Identity()
self._init_weights()
    def _init_weights(self):
        for name, param in self.named_parameters():
            # Skip LayerNorm parameters: keep their defaults (weight=1, bias=0), otherwise
            # the near-zero init below would suppress the normalized activations.
            if 'norm' in name:
                continue
            if 'weight' in name:
                if param.ndim >= 2:
                    nn.init.xavier_uniform_(param)
                else:
                    nn.init.normal_(param, 0, 0.01)
            elif 'bias' in name:
                nn.init.zeros_(param)
def _pool_sequence(self, packed_output, lengths):
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
if self.use_attention:
attn_out, _ = self.attention(output, output, output)
weights = torch.softmax(self.attention_linear(attn_out), dim=1)
pooled = (weights * output).sum(dim=1)
else:
# mean pooling with mask
batch_size, max_len, _ = output.size()
mask = torch.arange(max_len, device=output.device).expand(batch_size, max_len) < lengths.unsqueeze(1)
masked_output = output * mask.unsqueeze(-1).float()
pooled = masked_output.sum(dim=1) / lengths.unsqueeze(-1).float()
return pooled
def encode(self, x: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.quant(x)
embedded = self.dropout(self.embedding(x))
packed = nn.utils.rnn.pack_padded_sequence(
embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
)
packed_out, _ = self.encoder_lstm(packed)
h = self._pool_sequence(packed_out, lengths)
h = self.encoder_norm(h)
mu, logvar = self.fc_mu(h), self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
return mu
def _init_decoder_state(self, z: torch.Tensor):
batch_size = z.size(0)
h0 = self.latent2hidden(z).view(self.num_layers, batch_size, self.hidden_dim)
c0 = self.latent2cell(z).view(self.num_layers, batch_size, self.hidden_dim)
return h0, c0
def decode(self, z: torch.Tensor, max_length: int = 64, mode: str = "greedy", temperature: float = 1.0):
batch_size = z.size(0)
device = z.device
h0, c0 = self._init_decoder_state(z)
hidden = (h0, c0)
input_ids = torch.full((batch_size, 1), self.bos_token_id, dtype=torch.long, device=device)
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
logits_list = []
for _ in range(max_length):
embedded = self.embedding(input_ids)
output, hidden = self.decoder_lstm(embedded, hidden)
output = self.decoder_norm(output)
logit = self.fc_out(output)
logits_list.append(logit)
if mode == "greedy":
next_tokens = logit.argmax(dim=-1)
elif mode == "sample":
probs = F.softmax(logit.squeeze(1) / temperature, dim=-1)
next_tokens = torch.multinomial(probs, 1)
else:
raise ValueError(f"Unknown decode mode: {mode}")
just_finished = (next_tokens.squeeze(-1) == self.eos_token_id)
finished |= just_finished
next_tokens = torch.where(
finished.unsqueeze(-1),
torch.tensor(self.pad_token_id, device=device),
next_tokens
)
input_ids = next_tokens
if finished.all():
break
return self.dequant(torch.cat(logits_list, dim=1))
def forward(self, input_ids: torch.Tensor, lengths: torch.Tensor,
target_seq: Optional[torch.Tensor] = None,
teacher_forcing_ratio: float = 0.0,
temperature: float = 1.0):
mu, logvar = self.encode(input_ids, lengths)
z = self.reparameterize(mu, logvar)
if self.training and target_seq is not None and teacher_forcing_ratio > 0:
return self._forward_teacher_forcing(z, target_seq, teacher_forcing_ratio), mu, logvar
else:
max_len = target_seq.size(1) if target_seq is not None else 64
return self.decode(z, max_length=max_len, temperature=temperature), mu, logvar
def _forward_teacher_forcing(self, z: torch.Tensor, target_seq: torch.Tensor, teacher_forcing_ratio: float):
batch_size, seq_len = target_seq.size()
h0, c0 = self._init_decoder_state(z)
hidden = (h0, c0)
logits_list = []
input_token = target_seq[:, 0:1]
for t in range(1, seq_len):
embedded = self.embedding(input_token)
output, hidden = self.decoder_lstm(embedded, hidden)
output = self.decoder_norm(output)
logit = self.fc_out(output)
logits_list.append(logit)
if torch.rand(1).item() < teacher_forcing_ratio:
input_token = target_seq[:, t:t+1]
else:
input_token = logit.argmax(dim=-1)
return torch.cat(logits_list, dim=1)
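# Illustrative smoke test (hypothetical tiny config, separate from the models trained
# below): push random token IDs through the VAE once to confirm the output shapes.
_demo_vae = MoleculeVAE(vocab_size=50, embed_dim=16, hidden_dim=32,
                        latent_dim=8, num_layers=1).to(device)
_demo_ids = torch.randint(3, 50, (4, 12), device=device)
_demo_lens = torch.full((4,), 12, dtype=torch.long)
_demo_logits, _demo_mu, _demo_logvar = _demo_vae(_demo_ids, _demo_lens)
print(f"MoleculeVAE smoke test: logits {tuple(_demo_logits.shape)}, mu {tuple(_demo_mu.shape)}")
del _demo_vae, _demo_ids, _demo_lens, _demo_logits, _demo_mu, _demo_logvar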
#
# Step 2.2 - Loss Function (PATCHED: beta applied OUTSIDE, not inside)
#
# PATCH 2: Fix VAE Loss Function - Ensure beta is properly applied
# Replace the existing vae_loss function:
def vae_loss(logits, targets, mu, logvar, pad_token_id, beta=1.0):
# 1. align lengths
max_len = max(logits.size(1), targets.size(1))
if logits.size(1) < max_len:
logits = F.pad(logits, (0, 0, 0, max_len - logits.size(1)))
if targets.size(1) < max_len:
targets = F.pad(targets, (0, max_len - targets.size(1)), value=pad_token_id)
logits_flat = logits.view(-1, logits.size(-1)) # [B*L, V]
targets_flat = targets.reshape(-1) # [B*L]
mask = (targets_flat != pad_token_id).float()
ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
mask_sum = mask.sum()
ce_loss = (ce_loss * mask).sum() / (mask_sum + 1e-8)
# FIXED: Raw KL loss computation
kl_loss_raw = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
# Apply mask to KL loss if needed (but typically KL is per-sample)
kl_loss = kl_loss_raw.mean()
# CRITICAL FIX: Apply beta scaling correctly
total_loss = ce_loss + beta * kl_loss
return total_loss, ce_loss, kl_loss
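# Minimal illustration of the loss interface on random tensors (hypothetical shapes:
# logits [B, L, V], targets [B, L], mu/logvar [B, latent_dim]); with mu=logvar=0 the KL term is 0.
_l, _t = torch.randn(2, 5, 10), torch.randint(1, 10, (2, 5))
_m, _lv = torch.zeros(2, 4), torch.zeros(2, 4)
_total, _ce, _kl = vae_loss(_l, _t, _m, _lv, pad_token_id=0, beta=0.5)
print(f"vae_loss demo -> total={_total.item():.4f}, ce={_ce.item():.4f}, kl={_kl.item():.4f}")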
#
# Step 2.3 - KLAnnealer (Fixed Bug #5: double increment)
#
import math
class KLAnnealer:
def __init__(self, total_steps, n_cycle=1, ratio=0.3, mode="linear", per_epoch=False, steps_per_epoch=None):
self.total_steps = total_steps
self.n_cycle = n_cycle
self.ratio = ratio
self.mode = mode
self.per_epoch = per_epoch
self.steps_per_epoch = steps_per_epoch
self.current_step = 0
self.current_epoch = 0
def get_beta(self, increment=True):
"""Get current KL weight (beta).
Args:
increment (bool): whether to advance the annealer (use False in validation).
"""
if increment:
self.current_step += 1
# Calculate progress based on total steps
progress = min(self.current_step / max(self.total_steps, 1.0), 1.0)
# For cyclical annealing
if self.n_cycle > 1:
cycle_length = self.total_steps / self.n_cycle
pos_in_cycle = (self.current_step % cycle_length)
cycle_progress = min(pos_in_cycle / max(cycle_length * self.ratio, 1.0), 1.0)
else:
# For single cycle, use full progress
cycle_progress = min(progress / self.ratio, 1.0) if self.ratio > 0 else 1.0
if self.mode == "linear":
beta = min(cycle_progress, 1.0)
elif self.mode == "sigmoid":
k = 6
            # scale progress ∈ [0, 1] → [-3, +3] for a smooth S curve
beta = 1 / (1 + math.exp(-k * (cycle_progress - 0.5)))
elif self.mode == "cosine":
# Cosine annealing from 0 to 1
beta = 0.5 * (1 + math.cos(math.pi * (1 - cycle_progress)))
else:
raise ValueError(f"Unknown mode: {self.mode}")
return min(beta, 1.0)
def step(self):
"""Increment the step counter."""
self.current_step += 1
def epoch_step(self):
"""Increment the epoch counter."""
self.current_epoch += 1
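# Illustrative peek at the annealing schedule (hypothetical 100-step run, ratio=0.3):
# beta should rise linearly from ~0 to 1.0 over the first 30 steps, then stay at 1.0.
_demo_annealer = KLAnnealer(total_steps=100, n_cycle=1, ratio=0.3, mode="linear")
_demo_betas = [_demo_annealer.get_beta(increment=True) for _ in range(100)]
print(f"KLAnnealer demo: beta at steps 1/15/30/100 = "
      f"{_demo_betas[0]:.2f}/{_demo_betas[14]:.2f}/{_demo_betas[29]:.2f}/{_demo_betas[99]:.2f}")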
#
# Teacher forcing ratio
#
def get_teacher_forcing_ratio(epoch, num_epochs, min_tfr=0.6, warmup_fraction=0.3):
"""
Linear decay of teacher forcing ratio (TFR).
- Starts at 1.0
- Decays to min_tfr by (warmup_fraction * num_epochs)
- Then stays flat
"""
warmup_epochs = int(num_epochs * warmup_fraction)
if epoch < warmup_epochs:
        # linearly decay from 1.0 → min_tfr
return 1.0 - (1.0 - min_tfr) * (epoch / warmup_epochs)
else:
return min_tfr
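# Illustrative check of the TFR schedule (hypothetical 10-epoch run): starts at 1.0,
# reaches min_tfr=0.6 by epoch 3 (30% of 10 epochs), then stays flat.
_demo_tfrs = [get_teacher_forcing_ratio(e, 10, min_tfr=0.6, warmup_fraction=0.3) for e in range(10)]
print("Teacher forcing schedule demo:", [f"{t:.2f}" for t in _demo_tfrs])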
#
# Step 2.4 - Collate Function (Fixed Bug #2: dynamic pad id)
#
def collate_fn(batch, tokenizer, max_length=128):
encodings = [tokenizer.encode(s, add_special_tokens=True) for s in batch]
input_ids = [e['input_ids'] for e in encodings]
max_len = min(max(len(ids) for ids in input_ids), max_length)
padded = []
lengths = []
pad_token_id = tokenizer.tokenizer.pad_token_id # FIXED: dynamic
    for ids in input_ids:
        true_len = min(len(ids), max_length)  # FIXED: record the unpadded length
        if len(ids) > max_length:
            ids = ids[:max_length]
        else:
            ids = ids + [pad_token_id] * (max_len - len(ids))
        padded.append(ids)
        lengths.append(true_len)
return torch.tensor(padded, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)
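# Illustrative collate check on the first two training SMILES (assumes the tokenizers
# above loaded successfully): ids should be [B, L] and lengths the unpadded lengths.
_demo_ids, _demo_lens = collate_fn(train_smiles[:2], tokenizer2, max_length=64)
print(f"collate_fn demo: ids {tuple(_demo_ids.shape)}, lengths {_demo_lens.tolist()}")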
#
# Step 2.5 - Dataset & DataLoader
#
class SmilesDataset(Dataset):
def __init__(self, smiles_list):
self.smiles_list = smiles_list
def __len__(self):
return len(self.smiles_list)
def __getitem__(self, idx):
return self.smiles_list[idx]
#
# Step 3.x - Training Loop (PATCHED: per-tokenizer annealer, linear TFR decay, device-safe eval, KL beta logging clarity)
#
LEARNING_RATE = 1e-5
BATCH_SIZE = 16
ACCUMULATION_STEPS = 4
NUM_EPOCHS = 5
MAX_SEQ_LEN = 128
KL_ANNEAL_RATIO = 0.3
def train_vae(
model,
train_loader,
val_loader,
optimizer,
kl_annealer,
pad_token_id,
device,
num_epochs,
accumulation_steps=4,
save_dir="./checkpoints",
tokenizer_name="default"
):
os.makedirs(save_dir, exist_ok=True)
log_file = os.path.join(save_dir, f"training_log_{tokenizer_name}.csv")
with open(log_file, "w") as f:
f.write("epoch,step,train_loss,train_ce,train_kl,val_loss,val_ce,val_kl,kl_beta\n")
best_val_loss = float('inf')
for epoch in range(num_epochs):
print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")
model.train()
total_train_loss = total_train_ce = total_train_kl = 0.0
num_batches = 0
optimizer.zero_grad()
for step, (input_ids, lengths) in enumerate(tqdm(train_loader, desc="Training")):
input_ids, lengths = input_ids.to(device), lengths.to(device)
            # ← PATCHED: teacher forcing ratio decays linearly per epoch (not per batch)
tfr = get_teacher_forcing_ratio(epoch, num_epochs, min_tfr=0.6, warmup_fraction=0.3)
logits, mu, logvar = model(input_ids, lengths, target_seq=input_ids, teacher_forcing_ratio=tfr)
beta = kl_annealer.get_beta(increment=True)
loss, ce_loss, kl_loss = vae_loss(logits, input_ids, mu, logvar, pad_token_id, beta=beta)
loss = loss / accumulation_steps
loss.backward()
total_train_loss += loss.item() * accumulation_steps
total_train_ce += ce_loss.item()
total_train_kl += kl_loss.item()
num_batches += 1
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
if len(train_loader) % accumulation_steps != 0:
optimizer.step()
optimizer.zero_grad()
        # ✅ CAPTURE BETA AFTER TRAINING - BEFORE VALIDATION
# This ensures we log the beta that was actually used during training
current_beta = kl_annealer.get_beta(increment=False)
        # Validation - DO NOT query beta again here
model.eval()
total_val_loss = total_val_ce = total_val_kl = 0.0
val_batches = 0
with torch.no_grad():
for input_ids, lengths in tqdm(val_loader, desc="Validating"):
input_ids, lengths = input_ids.to(device), lengths.to(device)
                # Use the captured beta - DO NOT call kl_annealer again here
logits, mu, logvar = model(input_ids, lengths, target_seq=input_ids, teacher_forcing_ratio=0.0)
loss, ce_loss, kl_loss = vae_loss(logits, input_ids, mu, logvar, pad_token_id, beta=current_beta)
total_val_loss += loss.item()
total_val_ce += ce_loss.item()
total_val_kl += kl_loss.item()
val_batches += 1
avg_train_loss = total_train_loss / num_batches
avg_val_loss = total_val_loss / val_batches
current_step = (epoch + 1) * len(train_loader)
with open(log_file, "a") as f:
f.write(f"{epoch+1},{current_step},{avg_train_loss:.6f},{total_train_ce/num_batches:.6f},{total_train_kl/num_batches:.6f},"
f"{avg_val_loss:.6f},{total_val_ce/val_batches:.6f},{total_val_kl/val_batches:.6f},{current_beta:.6f}\n")
print(f"Train Loss: {avg_train_loss:.4f}")
print(f"Val Loss: {avg_val_loss:.4f}")
print(f"KL Beta: {current_beta:.4f}") # ← Now explicitly the training beta
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
checkpoint_path = os.path.join(save_dir, f"best_model_{tokenizer_name}.pt")
torch.save({
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'val_loss': avg_val_loss,
}, checkpoint_path)
print(f"β†’ Saved best model to {checkpoint_path}")
return best_val_loss
#
# TRAINING LOOP OVER TOKENIZERS (PATCHED: KLAnnealer reset per tokenizer)
#
for tokenizer in TOKENIZERS:
print(f"\n STARTING TRAINING FOR: {tokenizer.name}\n")
vocab_size = len(tokenizer)
pad_token_id = tokenizer.tokenizer.pad_token_id
# Validate token IDs
sample_ids = tokenizer.encode(train_smiles[0], add_special_tokens=True)['input_ids']
max_id_in_sample = max(sample_ids)
assert max_id_in_sample < vocab_size, f"Token ID {max_id_in_sample} >= vocab size {vocab_size} in {tokenizer.name}"
model = MoleculeVAE(
vocab_size=len(tokenizer),
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
).to(device)
    ########################################################################
    # Fresh optimizer, dataloaders, and KL annealer for every tokenizer
    ########################################################################
optimizer = Ranger21(
model.parameters(),
lr=LEARNING_RATE,
weight_decay=0.01,
use_adabelief=True,
use_warmup=True,
use_madgrad=True,
num_epochs=NUM_EPOCHS,
num_batches_per_epoch=len(train_smiles) // (BATCH_SIZE * ACCUMULATION_STEPS),
warmdown_active=False,
)
train_dataset = SmilesDataset(train_smiles)
val_dataset = SmilesDataset(val_smiles)
train_loader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=lambda batch: collate_fn(batch, tokenizer, max_length=MAX_SEQ_LEN),
num_workers=0,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
collate_fn=lambda batch: collate_fn(batch, tokenizer, max_length=MAX_SEQ_LEN),
num_workers=0,
pin_memory=True
)
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * NUM_EPOCHS
# total_steps = (len(train_smiles) // (BATCH_SIZE * ACCUMULATION_STEPS)) * NUM_EPOCHS
kl_annealer = KLAnnealer(
total_steps=total_steps,
        n_cycle=1,        # single monotonic warmup cycle
        ratio=0.6,        # 60% of the cycle is warmup
        mode="linear",    # linear warmup is the most predictable schedule
per_epoch=False
)
train_vae(
model=model,
train_loader=train_loader,
val_loader=val_loader,
optimizer=optimizer,
kl_annealer=kl_annealer,
pad_token_id=pad_token_id,
device=device,
num_epochs=NUM_EPOCHS,
accumulation_steps=ACCUMULATION_STEPS,
save_dir=f"./checkpoints/{tokenizer.name}",
tokenizer_name=tokenizer.name
)
#
# Step 4.x - Evaluation Pipeline (Fixed Bug #6, #7, #8)
#
def canonicalize_smiles(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
return Chem.MolToSmiles(mol, isomericSmiles=True)
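# Illustrative example: canonicalization maps equivalent SMILES to one form, e.g.
# "OCC" and "CCO" (both ethanol) should canonicalize to the same string.
print("canonicalize_smiles demo:", canonicalize_smiles("OCC"), canonicalize_smiles("CCO"))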
def evaluate_reconstruction(model, dataloader, tokenizer, device, max_length=128):
model.eval()
total_token_correct = total_tokens = exact_matches = valid_count = total_samples = 0
all_generated, all_targets = [], []
pad_id = tokenizer.tokenizer.pad_token_id
eos_id = tokenizer.tokenizer.eos_token_id
special_ids = {pad_id, eos_id}
def trim_to_special(ids, specials):
for i, id_ in enumerate(ids):
if id_ in specials:
return ids[:i]
return ids
with torch.no_grad():
for input_ids, lengths in tqdm(dataloader, desc="Evaluating Reconstruction"):
input_ids, lengths = input_ids.to(device), lengths.to(device)
B = input_ids.size(0)
mu, logvar = model.encode(input_ids, lengths)
z = model.reparameterize(mu, logvar)
logits = model.decode(z, max_length=128, mode="greedy") # FIXED #7 for reconstruction
preds = logits.argmax(dim=-1)
# FIXED: Align logits and targets to same sequence length
min_len = min(logits.size(1), input_ids.size(1))
preds = preds[:, :min_len] # trim predictions
input_ids_eval = input_ids[:, :min_len] # trim targets
mask = (input_ids_eval != pad_id)
token_correct = ((preds == input_ids_eval) & mask).sum().item()
total_token_correct += token_correct
total_tokens += mask.sum().item()
for i in range(B):
target_ids = input_ids_eval[i].cpu().tolist()
pred_ids = preds[i].cpu().tolist()
# FIXED BUG #6: Trim before decode
target_ids_trim = trim_to_special(target_ids, special_ids)
pred_ids_trim = trim_to_special(pred_ids, special_ids)
                    # Decode with skip_special_tokens=True so the leading BOS kept in the
                    # target (but never generated by the decoder) cannot break exact match
                    target_smiles = tokenizer.decode(target_ids_trim, skip_special_tokens=True)
                    pred_smiles = tokenizer.decode(pred_ids_trim, skip_special_tokens=True)
all_targets.append(target_smiles)
all_generated.append(pred_smiles)
if pred_smiles == target_smiles:
exact_matches += 1
if Chem.MolFromSmiles(pred_smiles) is not None:
valid_count += 1
total_samples += 1
token_acc = total_token_correct / total_tokens if total_tokens > 0 else 0.0
exact_match_rate = exact_matches / total_samples
validity_rate = valid_count / total_samples
print(f"Token-level Accuracy: {token_acc:.4f}")
print(f"Exact Match Rate: {exact_match_rate:.4f}")
print(f"Validity Rate: {validity_rate:.4f}")
return {
'token_accuracy': token_acc,
'exact_match_rate': exact_match_rate,
'validity_rate': validity_rate,
'generated_smiles': all_generated,
'target_smiles': all_targets
}
def compute_uniqueness_and_novelty(generated_smiles, train_smiles_set):
total = len(generated_smiles)
unique = len(set(generated_smiles))
novel = len([s for s in generated_smiles if s not in train_smiles_set])
uniqueness = unique / total if total > 0 else 0.0
novelty = novel / total if total > 0 else 0.0
print(f"Uniqueness: {uniqueness:.4f} ({unique}/{total})")
print(f"Novelty: {novelty:.4f} ({novel}/not in train)")
return uniqueness, novelty
def kl_divergence_from_samples(samples, bins=512):
dim_kls = []
for d in range(samples.shape[1]):
data = samples[:, d]
hist, bin_edges = np.histogram(data, bins=bins, density=True)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
norm_pdf = (1 / np.sqrt(2 * np.pi)) * np.exp(-0.5 * bin_centers**2)
hist = np.clip(hist, 1e-10, None)
norm_pdf = np.clip(norm_pdf, 1e-10, None)
kl = entropy(hist, norm_pdf)
dim_kls.append(kl)
return np.mean(dim_kls)
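# Illustrative sanity check (hypothetical data): latents drawn from N(0, 1) should
# give a KL close to 0 against the standard-normal reference.
_demo_kl = kl_divergence_from_samples(np.random.randn(5000, 8), bins=128)
print(f"kl_divergence_from_samples demo on N(0,1) samples: {_demo_kl:.4f}")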
def evaluate_latent_kl(model, dataloader, device, latent_dim=128, bins=512):
model.eval()
all_z = []
with torch.no_grad():
for input_ids, lengths in tqdm(dataloader, desc="Sampling Latents"):
input_ids, lengths = input_ids.to(device), lengths.to(device)
mu, logvar = model.encode(input_ids, lengths)
z = model.reparameterize(mu, logvar)
all_z.append(z.cpu().numpy())
all_z = np.concatenate(all_z, axis=0)
kl_div = kl_divergence_from_samples(all_z, bins=bins)
print(f"KL Divergence (empirical vs N(0,1)): {kl_div:.4f}")
return kl_div
def evaluate_interpolation_validity(model, tokenizer, test_smiles, device, num_pairs=100, steps=10, max_length=128):
model.eval()
pairs = random.sample(list(zip(test_smiles[::2], test_smiles[1::2])), min(num_pairs, len(test_smiles)//2))
valid_interps = total_interps = 0
with torch.no_grad():
for smiles_a, smiles_b in tqdm(pairs, desc="Interpolation Validity"):
if not smiles_a or not smiles_b: continue
enc_a = tokenizer.encode(smiles_a, add_special_tokens=True)
enc_b = tokenizer.encode(smiles_b, add_special_tokens=True)
ids_a = torch.tensor([enc_a['input_ids']], device=device)
ids_b = torch.tensor([enc_b['input_ids']], device=device)
len_a = torch.tensor([len(enc_a['input_ids'])], device=device)
len_b = torch.tensor([len(enc_b['input_ids'])], device=device)
mu_a, _ = model.encode(ids_a, len_a)
mu_b, _ = model.encode(ids_b, len_b)
alphas = torch.linspace(0, 1, steps, device=device)
for alpha in alphas:
z_interp = alpha * mu_b + (1 - alpha) * mu_a
# Ensure z_interp maintains batch dimension [1, latent_dim]
if z_interp.dim() == 1:
z_interp = z_interp.unsqueeze(0)
logits = model.decode(z_interp, max_length=max_length, mode="sample", temperature=0.8)
preds = logits.argmax(dim=-1)
# Handle batch dimension properly
if preds.dim() > 1:
preds = preds[0] # Take first (and only) batch item
pred_smiles = tokenizer.decode(preds.cpu().tolist(), skip_special_tokens=True)
if Chem.MolFromSmiles(pred_smiles) is not None:
valid_interps += 1
total_interps += 1
interp_validity = valid_interps / total_interps if total_interps > 0 else 0.0
print(f"Interpolation Validity: {interp_validity:.4f}")
return interp_validity
def sample_from_latent(model, tokenizer, num_samples=30000, latent_dim=128, max_length=128, device=device, temperature=0.8):
model.eval()
generated_smiles = []
with torch.no_grad():
for _ in tqdm(range(0, num_samples, BATCH_SIZE), desc="Sampling from Latent"):
current_batch_size = min(BATCH_SIZE, num_samples - len(generated_smiles))
if current_batch_size <= 0: break
z = torch.randn(current_batch_size, latent_dim, device=device)
logits = model.decode(z, max_length=max_length, mode="sample", temperature=temperature)
preds = logits.argmax(dim=-1)
for i in range(current_batch_size):
pred_ids = preds[i].cpu().tolist()
smiles = tokenizer.decode(pred_ids, skip_special_tokens=True)
generated_smiles.append(smiles)
if len(generated_smiles) >= num_samples: break
return generated_smiles
def measure_inference_throughput(model, tokenizer, test_smiles, device,
max_length=128,
batch_sizes=[1, 4, 8, 16]):
"""
Benchmark inference speed & peak GPU memory across several batch sizes.
Returns a JSON-serialisable dict:
{batch_size: {'tokens_per_sec': <float>, 'peak_mem_mb': <float>}, ...}
"""
model.eval()
results = {}
for bs in batch_sizes:
# Build a small fixed subset so every BS processes the same #samples
subset = SmilesDataset(test_smiles[:bs * 10])
loader = DataLoader(
subset,
batch_size=bs,
shuffle=False,
num_workers=0,
collate_fn=lambda b: collate_fn(b, tokenizer, max_length=max_length),
)
total_tokens = 0
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats(device)
start_time = time.perf_counter()
with torch.no_grad():
for input_ids, lengths in loader:
input_ids, lengths = input_ids.to(device), lengths.to(device)
mu, logvar = model.encode(input_ids, lengths)
z = model.reparameterize(mu, logvar)
logits = model.decode(z, max_length=max_length)
                total_tokens += logits.size(0) * logits.size(1)  # generated tokens (batch_size * seq_len)
duration = time.perf_counter() - start_time
tokens_per_sec = total_tokens / duration
peak_mem_mb = (
torch.cuda.max_memory_allocated(device) / (1024 ** 2)
if torch.cuda.is_available()
else 0.0
)
# Store as plain Python floats
results[bs] = {
"tokens_per_sec": float(tokens_per_sec),
"peak_mem_mb": float(peak_mem_mb),
}
print(f"BS {bs:3d} β†’ {tokens_per_sec:8.2f} tok/s | Peak Mem: {peak_mem_mb:.2f} MB")
return results
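# Example invocation (illustrative; kept commented out, like the corresponding call in
# the evaluation pipeline below, to save benchmark time):
# throughput = measure_inference_throughput(model, tokenizer, test_smiles, device,
#                                           max_length=MAX_SEQ_LEN, batch_sizes=[1, 8])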
#
# FINAL EVALUATION PIPELINE
#
def full_evaluation_pipeline(model, tokenizer, train_smiles, test_smiles, device, save_dir):
print(f"\n FULL EVALUATION FOR: {tokenizer.name}")
test_dataset = SmilesDataset(test_smiles)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
collate_fn=lambda b: collate_fn(b, tokenizer, max_length=MAX_SEQ_LEN),
num_workers=0)
# 1. Reconstruction
recon_metrics = evaluate_reconstruction(model, test_loader, tokenizer, device)
# 2. Uniqueness & Novelty
train_set = set(train_smiles)
uniqueness, novelty = compute_uniqueness_and_novelty(recon_metrics['generated_smiles'], train_set)
# 3. KL Divergence
kl_div = evaluate_latent_kl(model, test_loader, device)
# 4. Interpolation Validity
interp_validity = evaluate_interpolation_validity(model, tokenizer, test_smiles, device)
# 5. Latent Sampling (for FCD β€” optional)
# gen_smiles_30k = sample_from_latent(model, tokenizer, num_samples=10000, temperature=0.8) # reduce for speed
# fcd_score = compute_fcd(test_smiles, gen_smiles_30k) if 'get_fcd' in globals() else None
# 6. Throughput & Memory
# throughput = measure_inference_throughput(model, tokenizer, test_loader, device)
eval_results = {
**recon_metrics,
'uniqueness': uniqueness,
'novelty': novelty,
'kl_divergence': kl_div,
'interpolation_validity': interp_validity,
# 'fcd': fcd_score,
# 'inference_throughput': throughput,
}
eval_path = os.path.join(save_dir, "evaluation_results.json")
with open(eval_path, "w") as f:
json.dump(eval_results, f, indent=2, default=str)
print(f" Evaluation saved to {eval_path}")
return eval_results
#
# RUN EVALUATION FOR EACH TOKENIZER
#
for tokenizer in TOKENIZERS:
print(f"\nπŸ”„ LOADING BEST MODEL FOR: {tokenizer.name}")
checkpoint_path = f"./checkpoints/{tokenizer.name}/best_model_{tokenizer.name}.pt"
if not os.path.exists(checkpoint_path):
print(f"⚠️ Checkpoint not found: {checkpoint_path}")
continue
vocab_size = len(tokenizer)
pad_token_id = tokenizer.tokenizer.pad_token_id
model = MoleculeVAE(
vocab_size=vocab_size,
pad_token_id=pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id
).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
full_evaluation_pipeline(
model=model,
tokenizer=tokenizer,
train_smiles=train_smiles,
test_smiles=test_smiles,
device=device,
save_dir=f"./checkpoints/{tokenizer.name}"
)
print("\nπŸŽ‰ PIPELINE COMPLETE β€” ALL TOKENIZERS BENCHMARKED, TRAINED, AND EVALUATED!")