"""
Latent Space Visualization for Molecule VAE Models

Integrated with existing benchmark pipeline structure
"""

import os
import time
import random
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit import RDLogger

RDLogger.DisableLog('rdApp.*')

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from transformers import AutoTokenizer

try:
    from FastChemTokenizer import FastChemTokenizer
except ImportError:
    print("FastChemTokenizer not found. Please ensure it's in your PYTHONPATH.")
    FastChemTokenizer = None


def set_seed(seed=42):
    """Seed every RNG the script uses so runs are reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TokenizerWrapper:
    """Uniform encode/decode interface over HuggingFace tokenizers and FastChemTokenizer."""

    def __init__(self, tokenizer, name, bos_token="<s>", eos_token="</s>",
                 pad_token="<pad>", unk_token="<unk>"):
        self.tokenizer = tokenizer
        self.name = name
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token

        if hasattr(tokenizer, 'add_special_tokens'):
            tokenizer.add_special_tokens({
                'bos_token': bos_token,
                'eos_token': eos_token,
                'pad_token': pad_token,
                'unk_token': unk_token
            })

    def _is_fast_chem(self):
        # isinstance(x, None) raises TypeError, so guard against a failed import
        # of FastChemTokenizer before the type check.
        return FastChemTokenizer is not None and isinstance(self.tokenizer, FastChemTokenizer)

    def encode(self, smiles: str, add_special_tokens: bool = True):
        if self._is_fast_chem():
            ids = self.tokenizer.encode(smiles)
            if add_special_tokens:
                ids = [self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id]
            return {'input_ids': ids}
        else:
            return self.tokenizer(
                smiles,
                add_special_tokens=add_special_tokens,
                return_attention_mask=False,
                return_tensors=None
            )

    def decode(self, token_ids, skip_special_tokens=True):
        if self._is_fast_chem():
            tokens = [self.tokenizer.id_to_token.get(tid, self.tokenizer.unk_token)
                      for tid in token_ids]
            if skip_special_tokens:
                specials = {self.tokenizer.bos_token,
                            self.tokenizer.eos_token,
                            self.tokenizer.pad_token,
                            self.tokenizer.unk_token}
                tokens = [t for t in tokens if t not in specials]
            if hasattr(self.tokenizer, 'detokenize'):
                return self.tokenizer.detokenize(tokens)
            else:
                return "".join(tokens)
        else:
            return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def __len__(self):
        if self._is_fast_chem():
            return len(getattr(self.tokenizer, 'vocab',
                               getattr(self.tokenizer, '_vocab', self.tokenizer)))
        else:
            return len(self.tokenizer)

    def get_vocab(self):
        if self._is_fast_chem():
            return self.tokenizer.vocab
        else:
            return self.tokenizer.get_vocab()

    @property
    def bos_token_id(self):
        return self.tokenizer.bos_token_id

    @property
    def eos_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def pad_token_id(self):
        return self.tokenizer.pad_token_id

    @property
    def unk_token_id(self):
        return self.tokenizer.unk_token_id

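
# Example usage (illustrative sketch, not invoked by the pipeline): round-trips
# one SMILES string through the wrapper. The ChemBERTa checkpoint name matches
# the one loaded later in load_data_and_tokenizers().
def _demo_tokenizer_wrapper():
    tok = TokenizerWrapper(AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1"),
                           name="ChemBERTa")
    ids = tok.encode("CCO", add_special_tokens=True)['input_ids']
    print(f"ids={ids} -> smiles={tok.decode(ids, skip_special_tokens=True)!r}")
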
def collate_fn(batch, tokenizer, max_length=128):
    """Tokenize, truncate, and right-pad a batch of SMILES strings."""
    encodings = [tokenizer.encode(s, add_special_tokens=True) for s in batch]
    input_ids = [e['input_ids'] for e in encodings]

    max_len = min(max(len(ids) for ids in input_ids), max_length)
    padded = []
    lengths = []

    pad_token_id = tokenizer.tokenizer.pad_token_id

    for ids in input_ids:
        ids = ids[:max_length]
        # Record the true (unpadded) length before padding; recording it after
        # padding would report max_len for every sequence and break
        # pack_padded_sequence downstream.
        lengths.append(len(ids))
        ids = ids + [pad_token_id] * (max_len - len(ids))
        padded.append(ids)

    return torch.tensor(padded, dtype=torch.long), torch.tensor(lengths, dtype=torch.long)

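
# Quick shape check for collate_fn (a sketch; assumes the ChemBERTa tokenizer
# downloads successfully). The padded tensor is (batch, max_len) and `lengths`
# holds the true pre-padding lengths needed by pack_padded_sequence.
def _demo_collate_shapes():
    tok = TokenizerWrapper(AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1"),
                           name="ChemBERTa")
    padded, lengths = collate_fn(["CCO", "c1ccccc1", "CC(=O)O"], tok, max_length=16)
    print(padded.shape, lengths.tolist())
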
class SmilesDataset(Dataset):
    def __init__(self, smiles_list):
        self.smiles_list = smiles_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        return self.smiles_list[idx]

class MoleculeVAE(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, latent_dim=128, num_layers=2,
                 pad_token_id=0, bos_token_id=1, eos_token_id=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id)
        self.encoder_lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim)

        self.decoder_lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim, vocab_size)

        # Project the latent vector into the decoder's initial hidden/cell states.
        self.latent2hidden = nn.Linear(latent_dim, num_layers * hidden_dim)
        self.latent2cell = nn.Linear(latent_dim, num_layers * hidden_dim)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

    def encode(self, x, lengths):
        embedded = self.embedding(x)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (hidden, _) = self.encoder_lstm(packed)
        # hidden has shape (num_layers * 2, batch, hidden_dim); the last two
        # entries are the top layer's forward and backward states.
        h_forward = hidden[-2]
        h_backward = hidden[-1]
        h = torch.cat([h_forward, h_backward], dim=1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Standard reparameterization trick: sample during training,
        # use the posterior mean at inference time.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def decode(self, z, max_length=128, mode="greedy", temperature=1.0):
        """
        Decode latent vector z into a sequence, returning the full logits at
        each step. Generation stops early once every sequence in the batch has
        emitted EOS. `temperature` only affects mode="sample"; greedy decoding
        ignores it.
        """
        batch_size = z.size(0)
        device = z.device

        h0 = self.latent2hidden(z).view(self.num_layers, batch_size, self.hidden_dim)
        c0 = self.latent2cell(z).view(self.num_layers, batch_size, self.hidden_dim)
        hidden = (h0, c0)

        input_token = torch.full((batch_size, 1), self.bos_token_id, dtype=torch.long, device=device)
        logits = []
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length):
            embedded = self.embedding(input_token)
            output, hidden = self.decoder_lstm(embedded, hidden)
            logit = self.fc_out(output)
            logits.append(logit)

            if mode == "greedy":
                input_token = logit.argmax(dim=-1)
            elif mode == "sample":
                probs = torch.softmax(logit.squeeze(1) / temperature, dim=-1)
                input_token = torch.multinomial(probs, 1)
            else:
                raise ValueError(f"Unknown decode mode: {mode}")

            # Once a sequence emits EOS, feed PAD for its remaining steps.
            just_finished = (input_token.squeeze(1) == self.eos_token_id)
            finished |= just_finished
            input_token[finished] = self.pad_token_id
            if finished.all():
                break

        return torch.cat(logits, dim=1)

    def forward(self, input_ids, lengths, target_seq=None, teacher_forcing_ratio=0.0, temperature=1.0):
        mu, logvar = self.encode(input_ids, lengths)
        z = self.reparameterize(mu, logvar)

        if self.training and target_seq is not None and teacher_forcing_ratio > 0:
            batch_size, seq_len = target_seq.size()
            device = target_seq.device

            h0 = self.latent2hidden(z).view(self.num_layers, batch_size, self.hidden_dim)
            c0 = self.latent2cell(z).view(self.num_layers, batch_size, self.hidden_dim)
            hidden = (h0, c0)

            logits = []
            input_token = target_seq[:, 0].unsqueeze(1)

            for t in range(1, seq_len):
                embedded = self.embedding(input_token)
                output, hidden = self.decoder_lstm(embedded, hidden)
                logit = self.fc_out(output)
                logits.append(logit)

                # Per-step coin flip between the ground-truth token and the
                # model's own greedy prediction.
                use_teacher = torch.rand(1).item() < teacher_forcing_ratio
                if use_teacher:
                    input_token = target_seq[:, t].unsqueeze(1)
                else:
                    input_token = logit.argmax(dim=-1)

            logits = torch.cat(logits, dim=1)
        else:
            max_len = target_seq.size(1) if target_seq is not None else 128
            logits = self.decode(z, max_length=max_len, mode="greedy", temperature=temperature)

        return logits, mu, logvar

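
# Minimal smoke test for MoleculeVAE (a sketch with toy hyperparameters, no
# training involved): feeds random token ids through encode -> reparameterize
# -> greedy decode and prints the resulting tensor shapes.
def _demo_vae_forward():
    model = MoleculeVAE(vocab_size=100, embed_dim=32, hidden_dim=64, latent_dim=16)
    model.eval()  # eval mode: reparameterize() returns the posterior mean
    input_ids = torch.randint(3, 100, (4, 12))      # avoid special ids 0-2
    lengths = torch.full((4,), 12, dtype=torch.long)
    with torch.no_grad():
        logits, mu, logvar = model(input_ids, lengths)
    print(logits.shape, mu.shape, logvar.shape)      # (4, <=128, 100), (4, 16), (4, 16)
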
class LatentSpaceVisualizer:
    def __init__(self, model_path, tokenizer, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.tokenizer = tokenizer
        self.model = self.load_model(model_path)

    def load_model(self, model_path):
        """Load the trained VAE model from a checkpoint file."""
        checkpoint = torch.load(model_path, map_location=self.device)

        # Checkpoints may store the weights directly or under 'model_state_dict'.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        vocab_size = len(self.tokenizer)
        pad_token_id = self.tokenizer.tokenizer.pad_token_id

        # Pass the tokenizer's special-token ids so decoding starts and stops
        # on the right tokens rather than the hard-coded defaults.
        model = MoleculeVAE(vocab_size=vocab_size,
                            pad_token_id=pad_token_id,
                            bos_token_id=self.tokenizer.bos_token_id,
                            eos_token_id=self.tokenizer.eos_token_id)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()

        return model

    def encode_molecules(self, smiles_list, batch_size=32):
        """Encode molecules to latent space, returning the posterior means."""
        dataset = SmilesDataset(smiles_list)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: collate_fn(batch, self.tokenizer, max_length=128)
        )

        all_mus = []
        with torch.no_grad():
            for input_ids, lengths in tqdm(dataloader, desc="Encoding molecules"):
                input_ids = input_ids.to(self.device)
                lengths = lengths.to(self.device)

                mu, logvar = self.model.encode(input_ids, lengths)
                all_mus.append(mu.cpu().numpy())

        return np.concatenate(all_mus, axis=0)

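    # Example (sketch; requires a trained checkpoint on disk):
    #   viz = LatentSpaceVisualizer('./checkpoints/ChemBERTa/best_model_ChemBERTa.pt', tokenizer)
    #   Z = viz.encode_molecules(["CCO", "c1ccccc1"])
    #   Z.shape  # -> (2, latent_dim); rows are posterior means, not samples
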
    def create_grid_latent_points(self, grid_size=100, z_range=4):
        """Create a square grid in 2D latent space plus a circular mask over it."""
        x = np.linspace(-z_range, z_range, grid_size)
        y = np.linspace(-z_range, z_range, grid_size)
        xx, yy = np.meshgrid(x, y)

        # Circular mask: keep only grid cells inside the inscribed disc.
        center = grid_size // 2
        radius = grid_size // 2
        y_coords, x_coords = np.ogrid[:grid_size, :grid_size]
        mask = (x_coords - center) ** 2 + (y_coords - center) ** 2 <= radius ** 2

        return xx, yy, mask

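    # Sanity sketch (mirrors the mask math above): the inscribed disc covers
    # roughly pi/4 ~ 78.5% of the square grid, so about a fifth of the
    # candidate grid points are skipped before decoding.
    #   g = 50
    #   yc, xc = np.ogrid[:g, :g]
    #   m = (xc - g // 2) ** 2 + (yc - g // 2) ** 2 <= (g // 2) ** 2
    #   m.mean()  # ~0.78
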
    def classify_latent_points(self, latent_points, dim1=0, dim2=1, additional_dim=None):
        """
        Classify latent points by chemical validity (RDKit-parseable).
        Returns 0 for invalid/unparseable molecules and 1 for valid molecules.
        """
        classifications = []

        with torch.no_grad():
            batch_size = 32
            for i in range(0, len(latent_points), batch_size):
                batch_points = latent_points[i:i+batch_size]

                # Fill the unused latent dimensions with small noise, then pin
                # the two plotted dimensions to the grid coordinates.
                full_z = torch.randn(len(batch_points), self.model.latent_dim, device=self.device) * 0.1
                full_z[:, dim1] = torch.FloatTensor(batch_points[:, 0]).to(self.device)
                full_z[:, dim2] = torch.FloatTensor(batch_points[:, 1]).to(self.device)

                # Optionally clamp extra dimensions to fixed values, e.g. {2: -1.0}.
                if additional_dim is not None:
                    if isinstance(additional_dim, dict):
                        for dim_idx, dim_val in additional_dim.items():
                            full_z[:, dim_idx] = dim_val

                try:
                    # mode defaults to "greedy", in which case temperature is ignored.
                    logits = self.model.decode(full_z, max_length=64, temperature=0.8)
                    predictions = torch.argmax(logits, dim=-1)

                    batch_classes = []
                    for pred in predictions:
                        pred_ids = pred.cpu().tolist()

                        pad_id = self.tokenizer.tokenizer.pad_token_id
                        eos_id = self.tokenizer.tokenizer.eos_token_id

                        # Trim everything from the first PAD/EOS token onwards.
                        for j, token_id in enumerate(pred_ids):
                            if token_id in [pad_id, eos_id]:
                                pred_ids = pred_ids[:j]
                                break

                        try:
                            decoded_smiles = self.tokenizer.decode(pred_ids, skip_special_tokens=True)
                            mol = Chem.MolFromSmiles(decoded_smiles)
                            batch_classes.append(0 if mol is None else 1)
                        except Exception:
                            batch_classes.append(0)

                    classifications.extend(batch_classes)

                except Exception:
                    # If decoding the whole batch fails, mark every point invalid.
                    classifications.extend([0] * len(batch_points))

        return np.array(classifications)

    def plot_latent_space_interpolation(self, grid_size=100, z_range=4, save_path=None):
        """
        Create a 2x4 panel of latent-space validity maps: four panels over
        different latent dimension pairs, and four over (z0, z1) with z2 held
        at a fixed value.
        """
        fig, axes = plt.subplots(2, 4, figsize=(20, 10))
        axes = axes.flatten()

        # Red for invalid molecules, green for valid ones.
        colors = ['#FF4444', '#44AA44']
        cmap = ListedColormap(colors)

        # Radii of the reference circles drawn on every panel; defined up front
        # so both plotting loops below share the same list.
        circles = [1, 2, 3, 4]

        plot_idx = 0

        # First row: validity maps over four pairs of latent dimensions.
        dimension_pairs = [(0, 1), (2, 3), (4, 5), (6, 7)]

        for dim_pair in dimension_pairs:
            dim1, dim2 = dim_pair

            xx, yy, mask = self.create_grid_latent_points(grid_size, z_range)

            valid_points = []
            valid_coords = []

            for i in range(grid_size):
                for j in range(grid_size):
                    if mask[i, j]:
                        valid_points.append([xx[i, j], yy[i, j]])
                        valid_coords.append([i, j])

            valid_points = np.array(valid_points)

            print(f"Classifying latent space chemical validity for dimensions {dim1}, {dim2}...")
            classifications = self.classify_latent_points(valid_points, dim1, dim2)

            class_grid = np.zeros((grid_size, grid_size))
            class_grid.fill(np.nan)

            for point_idx, (i, j) in enumerate(valid_coords):
                class_grid[i, j] = classifications[point_idx]

            ax = axes[plot_idx]
            im = ax.imshow(class_grid, extent=[-z_range, z_range, -z_range, z_range],
                           origin='lower', cmap=cmap, alpha=0.8, vmin=0, vmax=1)

            for radius in circles:
                if radius <= z_range:
                    circle = plt.Circle((0, 0), radius, fill=False, color='black',
                                        alpha=0.3, linewidth=0.5)
                    ax.add_patch(circle)

            ax.set_xlabel(f'Latent dimension z{dim1}')
            ax.set_ylabel(f'Latent dimension z{dim2}')
            ax.set_title('SMILES')
            ax.set_xlim(-z_range, z_range)
            ax.set_ylim(-z_range, z_range)
            ax.set_aspect('equal')

            plot_idx += 1

        # Second row: (z0, z1) maps with z2 clamped to a fixed value.
        for z2_val in [-2, -1, 1, 2]:
            dim1, dim2 = 0, 1

            xx, yy, mask = self.create_grid_latent_points(grid_size, z_range)

            valid_points = []
            valid_coords = []

            for i in range(grid_size):
                for j in range(grid_size):
                    if mask[i, j]:
                        valid_points.append([xx[i, j], yy[i, j]])
                        valid_coords.append([i, j])

            valid_points = np.array(valid_points)

            print(f"Classifying latent space chemical validity for z0, z1 with z2 = {z2_val}...")
            classifications = self.classify_latent_points(
                valid_points, dim1, dim2,
                additional_dim={2: z2_val}
            )

            class_grid = np.zeros((grid_size, grid_size))
            class_grid.fill(np.nan)

            for point_idx, (i, j) in enumerate(valid_coords):
                class_grid[i, j] = classifications[point_idx]

            ax = axes[plot_idx]
            im = ax.imshow(class_grid, extent=[-z_range, z_range, -z_range, z_range],
                           origin='lower', cmap=cmap, alpha=0.8, vmin=0, vmax=1)

            for radius in circles:
                if radius <= z_range:
                    circle = plt.Circle((0, 0), radius, fill=False, color='black',
                                        alpha=0.3, linewidth=0.5)
                    ax.add_patch(circle)

            ax.set_xlabel('Latent dimension z0')
            ax.set_ylabel('Latent dimension z1')
            ax.set_title(f'SMILES; z2 = {z2_val}')
            ax.set_xlim(-z_range, z_range)
            ax.set_ylim(-z_range, z_range)
            ax.set_aspect('equal')

            plot_idx += 1

        plt.suptitle(f'Latent Space Chemical Validity - {self.tokenizer.name}\n'
                     f'(Red: Invalid molecules, Green: Valid molecules)', fontsize=16)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

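    # Example (sketch; assumes `visualizer` was built from a trained checkpoint):
    #   visualizer.plot_latent_space_interpolation(grid_size=50, z_range=4,
    #                                              save_path='latent_grid.png')
    # With grid_size=100 each panel decodes ~7,850 latent points (pi/4 of
    # 100x100), so a smaller grid gives a much faster preview.
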
    def plot_molecule_embeddings(self, smiles_list, method='tsne', save_path=None):
        """Plot molecule embeddings in 2D using t-SNE or PCA."""
        print(f"Encoding {len(smiles_list)} molecules...")
        embeddings = self.encode_molecules(smiles_list)

        # Color points by a simple property: molecular weight above/below 200.
        labels = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                labels.append(0)
            else:
                mw = Descriptors.MolWt(mol)
                labels.append(1 if mw > 200 else 0)

        labels = np.array(labels)

        print(f"Computing {method.upper()} projection...")
        if method == 'tsne':
            reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, len(smiles_list)//4))
        else:
            reducer = PCA(n_components=2, random_state=42)

        embeddings_2d = reducer.fit_transform(embeddings)

        plt.figure(figsize=(10, 8))
        scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                              c=labels, cmap='RdYlGn', alpha=0.7, s=20)
        plt.colorbar(scatter, label='Molecular Weight > 200')
        plt.title(f'{method.upper()} of Molecule Embeddings - {self.tokenizer.name}')
        plt.xlabel(f'{method.upper()} 1')
        plt.ylabel(f'{method.upper()} 2')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()

def load_data_and_tokenizers():
    """Load data and tokenizers using the existing pipeline structure."""
    data_path = "../data/sample_all_8k_smi.csv"
    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path}")
        print("Please update the data_path in the script.")
        return None, None

    df = pd.read_csv(data_path)
    if 'SMILES' not in df.columns:
        raise ValueError("Expected column 'SMILES' in CSV")

    smiles_list = df['SMILES'].dropna().tolist()

    # Keep only RDKit-parseable SMILES.
    valid_smiles = []
    for smiles in smiles_list:
        if Chem.MolFromSmiles(smiles) is not None:
            valid_smiles.append(smiles)

    print(f"Loaded {len(valid_smiles)} valid SMILES")

    try:
        tok1_hf = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        tokenizer1 = TokenizerWrapper(tok1_hf, name="ChemBERTa",
                                      bos_token="<s>", eos_token="</s>",
                                      pad_token="<pad>", unk_token="<unk>")
    except Exception as e:
        print(f"Failed to load ChemBERTa tokenizer: {e}")
        tokenizer1 = None

    try:
        if FastChemTokenizer is None:
            raise ImportError("FastChemTokenizer module is not available")
        tok2_fast = FastChemTokenizer.from_pretrained("../smitok")
        tokenizer2 = TokenizerWrapper(tok2_fast, name="FastChemTokenizer",
                                      bos_token="[BOS]", eos_token="[EOS]",
                                      pad_token="[PAD]", unk_token="[UNK]")
    except Exception as e:
        print(f"Failed to load FastChemTokenizer: {e}")
        tokenizer2 = None

    tokenizers = [t for t in [tokenizer1, tokenizer2] if t is not None]

    return valid_smiles, tokenizers

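
# Example (sketch): drive a single tokenizer end-to-end without the full loop.
# The checkpoint path mirrors the defaults in create_latent_visualizations().
#   smiles, toks = load_data_and_tokenizers()
#   viz = LatentSpaceVisualizer('./checkpoints/ChemBERTa/best_model_ChemBERTa.pt', toks[0])
#   viz.plot_molecule_embeddings(smiles[:500], method='pca')
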
def create_latent_visualizations():
    """Main entry point: create latent space visualizations for each tokenizer."""
    smiles_list, tokenizers = load_data_and_tokenizers()
    if smiles_list is None or not tokenizers:
        print("Failed to load data or tokenizers. Please check your setup.")
        return

    # Keep the embedding plots fast by visualizing a subset of molecules.
    viz_smiles = smiles_list[:1000]

    model_paths = {
        'ChemBERTa': './checkpoints/ChemBERTa/best_model_ChemBERTa.pt',
        'FastChemTokenizer': './checkpoints/FastChemTokenizer/best_model_FastChemTokenizer.pt'
    }

    os.makedirs('latent_space_plots', exist_ok=True)

    for tokenizer in tokenizers:
        model_path = model_paths.get(tokenizer.name)
        if model_path is None or not os.path.exists(model_path):
            print(f"Model not found for {tokenizer.name}: {model_path}")
            continue

        print(f"\n{'='*60}")
        print(f"Creating visualizations for {tokenizer.name}")
        print(f"{'='*60}")

        try:
            visualizer = LatentSpaceVisualizer(model_path, tokenizer, device)

            print("Creating latent space interpolation plots...")
            save_path = f'latent_space_plots/{tokenizer.name}_latent_interpolation.png'
            visualizer.plot_latent_space_interpolation(save_path=save_path)

            print("Creating t-SNE embedding plot...")
            save_path = f'latent_space_plots/{tokenizer.name}_embeddings_tsne.png'
            visualizer.plot_molecule_embeddings(viz_smiles, method='tsne', save_path=save_path)

            print("Creating PCA embedding plot...")
            save_path = f'latent_space_plots/{tokenizer.name}_embeddings_pca.png'
            visualizer.plot_molecule_embeddings(viz_smiles, method='pca', save_path=save_path)

        except Exception as e:
            print(f"Error processing {tokenizer.name}: {str(e)}")
            import traceback
            traceback.print_exc()
            continue

    print(f"\n{'='*60}")
    print("Visualization complete! Check the 'latent_space_plots' directory for results.")
    print(f"{'='*60}")

if __name__ == "__main__":
    # Descriptors is also imported at module level; this guard remains as a
    # defensive fallback for RDKit builds that lack the Descriptors module.
    try:
        from rdkit.Chem import Descriptors, rdMolDescriptors
    except ImportError:
        print("RDKit Descriptors not available. Using simpler classification.")
        Descriptors = None
        rdMolDescriptors = None

    create_latent_visualizations()