|
|
|
|
|
""" |
|
|
Self-Supervised Training for Molecular Representations (SMILES) |
|
|
|
|
|
Usage: |
|
|
python trainbarlow.py --config config.yaml |
|
|
""" |
|
|
print("Initializing ...")

# --- Standard library ---
import os
import json
import argparse
import random
from pathlib import Path
from typing import Dict, Any, Tuple, List

# --- Third-party: numerics, data, training ---
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

# Silence RDKit's per-molecule warning spam before any Chem import is used;
# randomized SMILES enumeration would otherwise flood stderr.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Fail fast with actionable install commands when heavy optional deps are missing.
try:
    from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem
    from rdkit import DataStructs
except ImportError:
    raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit")

try:
    from sentence_transformers import SentenceTransformer, InputExample
except ImportError:
    raise ImportError("Install sentence-transformers: pip install sentence-transformers")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BarlowTwinsProjector(nn.Module):
    """Three-layer MLP projector with BatchNorm for Barlow Twins.

    The final BatchNorm is non-affine so the projected features are
    standardized before the cross-correlation loss is computed.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()
        blocks = []
        # Two hidden Linear -> BatchNorm -> ReLU stages (no bias: BN absorbs it).
        for d_in, d_out in ((in_dim, hidden_dim), (hidden_dim, hidden_dim)):
            blocks.append(nn.Linear(d_in, d_out, bias=False))
            blocks.append(nn.BatchNorm1d(d_out))
            blocks.append(nn.ReLU())
        # Output stage: linear projection followed by a non-affine BatchNorm.
        blocks.append(nn.Linear(hidden_dim, out_dim, bias=False))
        blocks.append(nn.BatchNorm1d(out_dim, affine=False))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Project embeddings of shape (batch, in_dim) to (batch, out_dim)."""
        return self.layers(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BarlowTwinsLoss(nn.Module):
    """Barlow Twins loss.

    Both views are standardized jointly (shared per-feature mean/std over the
    concatenated batch) and the off-diagonal redundancy term is scaled by the
    feature dimensionality d.
    """

    def __init__(self, λ: float = 0.005):
        super().__init__()
        # Weight of the redundancy-reduction (off-diagonal) term.
        self.λ = λ

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        batch, dim = z1.shape

        # Standardize both views with shared statistics.
        stacked = torch.cat([z1, z2], dim=0)
        stacked = (stacked - stacked.mean(dim=0)) / (stacked.std(dim=0) + 1e-8)
        z1, z2 = stacked[:batch], stacked[batch:]

        # Empirical cross-correlation matrix between the two views.
        corr = (z1.T @ z2) / batch
        diag = torch.diagonal(corr)

        # Invariance: pull the diagonal towards 1.
        invariance = (1 - diag).pow(2).sum()
        # Redundancy: push off-diagonal entries towards 0, scaled by d.
        redundancy = ((corr ** 2).sum() - diag.pow(2).sum()) / dim

        total = invariance + self.λ * redundancy

        # Diagnostics only — no gradients needed.
        with torch.no_grad():
            mean_on = diag.mean().item()
            off_mask = ~torch.eye(dim, dtype=torch.bool, device=corr.device)
            mean_off = corr[off_mask].abs().mean().item()

        stats = {
            'od': invariance.item(),
            'ofsc': (self.λ * redundancy).item(),
            'ofrw': redundancy.item(),
            'cr_onm': mean_on,
            'cr_offm': mean_off,
        }
        return total, stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a training config from a YAML (.yaml/.yml) or JSON (.json) file.

    Raises ValueError for any other file extension.
    """
    path = Path(config_path)
    suffix = path.suffix
    if suffix in ('.yaml', '.yml'):
        import yaml  # lazy import: only required for YAML configs
        with open(path) as fh:
            return yaml.safe_load(fh)
    if suffix == '.json':
        with open(path) as fh:
            return json.load(fh)
    raise ValueError(f"Unsupported config format: {suffix}")
|
|
|
|
|
def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce known config entries to their expected float/int/bool types.

    Mutates and returns the same dict; keys absent from the config are ignored.
    """
    # Numeric keys and the constructor that should be applied to each.
    casts = {
        float: ("LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
                "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA"),
        int: ("BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH",
              "SEED", "EVAL_EVERY_N_PERCENT"),
    }
    for cast, keys in casts.items():
        for key in keys:
            if key in config:
                config[key] = cast(config[key])

    # Booleans may arrive as strings ("true"/"yes"/...) from YAML/CLI.
    for key in ("BEST_BY_HEALTH",):
        if key in config:
            raw = config[key]
            if isinstance(raw, str):
                config[key] = raw.lower() in {"true", "1", "yes", "on"}
            else:
                config[key] = bool(raw)
    return config
|
|
|
|
|
def set_seed(seed: int):
    """Seed torch, numpy, and random (and all CUDA devices, when present) for reproducibility."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def enum_smiles(smi: str, k: int = 2) -> List[str]:
    """Return up to *k* distinct randomized SMILES spellings of molecule *smi*.

    Falls back to ``[smi] * k`` when RDKit cannot parse the input. May return
    fewer than k strings for small/symmetric molecules that admit few distinct
    enumerations — callers are expected to handle that case.

    NOTE(review): MolToSmiles(doRandom=True) uses RDKit's internal RNG, so the
    enumeration is not controlled by set_seed().
    """
    # Uses the module-level RDKit imports; the previous function-local
    # re-import shadowed them redundantly (the module hard-fails at import
    # time if RDKit is absent).
    mol = MolFromSmiles(smi)
    if mol is None:
        return [smi] * k
    variants = set()
    attempts = 0
    # Cap attempts so molecules with few valid spellings can't loop forever.
    while len(variants) < k and attempts < 100:
        variants.add(MolToSmiles(mol, doRandom=True, canonical=False))
        attempts += 1
    return list(variants)[:k]
|
|
|
|
|
def tanimoto(s1: str, s2: str) -> float:
    """Tanimoto similarity between two SMILES using 2048-bit Morgan fingerprints (radius 2).

    Returns 0.0 when either SMILES fails to parse.
    """
    mol_a = MolFromSmiles(s1)
    mol_b = MolFromSmiles(s2)
    if mol_a is None or mol_b is None:
        return 0.0
    fp_a, fp_b = (
        AllChem.GetMorganFingerprintAsBitVect(m, radius=2, nBits=2048)
        for m in (mol_a, mol_b)
    )
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)
|
|
|
|
|
def uniformity_metrics(emb: np.ndarray) -> Dict[str, float]:
    """Compute collapse/uniformity diagnostics for a batch of embeddings.

    Args:
        emb: (n, d) array of embeddings; rows are L2-normalized internally.

    Returns:
        Dict with mean/std of pairwise cosine similarity, a log-mean-exp
        uniformity score over cosine distances (more negative = more uniform),
        a legacy health score (1 - mean similarity), and a boolean collapse
        flag (mean similarity too high or spread too low).
    """
    # L2-normalize rows in pure NumPy. The original normalized with sklearn
    # and then called cosine_similarity, which normalizes again — a single
    # unit-row matmul is equivalent and drops the sklearn dependency here.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    safe = np.where(norms == 0, 1.0, norms)  # leave all-zero rows untouched
    unit = emb / safe
    sim = unit @ unit.T

    # Off-diagonal (distinct-pair) similarities only.
    mask = ~np.eye(len(sim), dtype=bool)
    pairwise = sim[mask]
    mean_sim = pairwise.mean()
    std_sim = pairwise.std()

    # Log of the mean Gaussian potential over pairwise cosine distances.
    distances = 1 - sim
    uniformity = np.log(np.exp(-2 * distances[mask]).mean())

    return {
        'mean': float(mean_sim),
        'std': float(std_sim),
        'uniformity': float(uniformity),
        'health_old': float(1 - mean_sim),
        # Cast to a plain bool (the original leaked a np.bool_).
        'collapsed': bool(mean_sim > 0.7 or std_sim < 0.05),
    }
|
|
|
|
|
def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device) -> torch.Tensor:
    """Tokenize text_list, run the encoder, and return the first-token
    (CLS-position) embedding for each input, with gradients attached."""
    features = model.tokenize(text_list)
    for key in features:
        features[key] = features[key].to(device)
    token_embeddings = model(features)['token_embeddings']
    # First token per sequence serves as the pooled representation.
    return token_embeddings[:, 0, :]
|
|
|
|
|
def evaluate(model, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
    """Evaluate embedding quality on eval_smiles.

    Measures uniformity over the eval set plus alignment between each molecule
    and a freshly re-enumerated SMILES spelling of itself. Prints a short
    report, restores model.train() before returning, and returns the
    uniformity_metrics dict augmented with 'health', 'alignment',
    'same_cos_mean', and 'same_cos_std'.
    """
    model.eval()
    with torch.no_grad():
        emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
    um = uniformity_metrics(emb)

    # Alternate randomized-SMILES view of the same molecules.
    same_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
    with torch.no_grad():
        emb2 = model.encode(same_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)

    same_cos = np.diag(cosine_similarity(emb, emb2))
    alignment = 1 - same_cos.mean()
    # Health: same-molecule similarity should exceed mean cross-molecule similarity.
    barlow_health = same_cos.mean() - um['mean']

    print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={um['uniformity']:.3f}")
    print(f" Same-mol cos: {same_cos.mean():.3f}±{same_cos.std():.3f} | Pairwise: {um['mean']:.3f}±{um['std']:.3f}")
    print(f" Barlow Health: {barlow_health:.3f} (higher = better)")

    model.train()
    # Cast numpy scalars to plain floats so the dict is uniformly typed
    # (consistent with uniformity_metrics) and JSON-serializable.
    um['health'] = float(barlow_health)
    um['alignment'] = float(alignment)
    um['same_cos_mean'] = float(same_cos.mean())
    um['same_cos_std'] = float(same_cos.std())
    return um
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Entry point: parse CLI args, merge them into the config, and run
    Barlow Twins self-supervised training on SMILES pairs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--batch_size", type=int)
    # Only "barlow" is implemented below; the other choices are reserved.
    parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
    args = parser.parse_args()

    config = load_config(args.config)
    # BUG FIX: config keys are upper-case ("EPOCHS", "LR", ...) while argparse
    # names are lower-case, so CLI overrides were written to unread keys and
    # silently ignored. Map them onto the upper-case keys the script reads.
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key.upper()] = value
    config = sanitize_config(config)

    set_seed(config.get("SEED", 42))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config["OUTPUT_DIR"])
    output_dir.mkdir(parents=True, exist_ok=True)

    # --- Data: one (view1, view2) randomized-SMILES pair per molecule ---
    df = pd.read_csv(config["DATA_PATH"])
    smiles_list = df["SMILES"].dropna().tolist()
    print(f"📂 Loaded {len(smiles_list)} SMILES")

    train_examples = []
    for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
        variants = enum_smiles(smi, 2)
        if len(variants) < 2:
            variants = [smi, smi]  # fall back to identical views
        train_examples.append(InputExample(texts=[variants[0], variants[1]]))
    print(f" Created {len(train_examples)} pairs")

    # Fixed evaluation subset for alignment/uniformity tracking.
    eval_size = min(200, len(smiles_list))
    eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()

    # --- Model: base checkpoint is now config-overridable (same default) ---
    model = SentenceTransformer(config.get("MODEL_PATH", './chmbedv2-warmup-l5/final'))
    model.max_seq_length = config.get("MAX_LENGTH", 512)
    embed_dim = model.get_sentence_embedding_dimension()

    loss_type = config.get("LOSS_TYPE", "barlow")
    if loss_type == "barlow":
        projector = BarlowTwinsProjector(
            embed_dim,
            hidden_dim=2048,
            out_dim=2048
        ).to(device)
        train_loss = BarlowTwinsLoss(
            λ=config.get("BARLOW_LAMBDA", 0.005)
        ).to(device)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")

    model.to(device)

    from ranger21 import Ranger21

    # --- Batching / schedule bookkeeping ---
    batch_size = config.get("BATCH_SIZE", 8)
    effective_batch = config.get("EFFECTIVE_BATCH", 32)
    # Guard against batch_size > effective_batch (would otherwise yield
    # grad_acc == 0 and a ZeroDivisionError at the loss scaling).
    grad_acc = max(1, effective_batch // batch_size)
    epochs = config.get("EPOCHS", 1)
    total_steps = (len(train_examples) // effective_batch) * epochs
    train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
    num_batches_per_epoch = len(train_examples) // effective_batch

    # Exclude biases and LayerNorm parameters from weight decay. (This list
    # was previously built twice, identically; defined once here.)
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    model_params = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         "weight_decay": config.get("WEIGHT_DECAY", 0.01)},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]

    optimizer = Ranger21(
        model_params + [{"params": projector.parameters(), "weight_decay": config.get("WEIGHT_DECAY", 0.01)}],
        lr=config.get("LR", 1e-5),
        num_epochs=epochs,
        num_batches_per_epoch=num_batches_per_epoch,
        weight_decay=0.0,  # decay is handled per parameter group above
    )

    # Linear decay to zero over the full run of optimizer steps.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
    )

    # --- Training loop with gradient accumulation ---
    model.train()
    step = 0
    # Health can be negative early in training; starting at 0.0 would prevent
    # any "best" checkpoint from ever being saved in that regime.
    best_health = float('-inf')
    best_step = 0
    log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))

    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            # Split the batch into the two augmented views.
            texts = [[ex.texts[i] for ex in batch] for i in range(2)]
            z1 = forward_pooled(model, texts[0], device)
            z2 = forward_pooled(model, texts[1], device)
            p1 = projector(z1)
            p2 = projector(z2)
            loss, extras = train_loss(p1, p2)

            loss = loss / grad_acc  # scale for gradient accumulation
            loss.backward()

            if (batch_idx + 1) % grad_acc == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
                for k, v in extras.items():
                    postfix[k] = f"{v:.3f}"
                pbar.set_postfix(postfix)

                if step % log_interval == 0 or step == total_steps:
                    um = evaluate(model, eval_smiles, device, step)
                    if config.get("BEST_BY_HEALTH", True) and um["health"] > best_health:
                        best_health, best_step = um["health"], step
                        model.save(str(output_dir / "best"))

    model.save(str(output_dir / "final"))
    print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")
|
|
|
|
|
|
|
|
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()