#!/usr/bin/env python3
"""
Self-Supervised Training for Molecular Representations (SMILES)

Fine-tunes a SentenceTransformer encoder with a Barlow Twins objective on
pairs of randomized SMILES spellings of the same molecule.

Usage:
    python trainbarlow.py --config config.yaml
"""

print("Initializing ...")

import os
import json
import argparse
import random
from pathlib import Path
from typing import Dict, Any, Tuple, List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

try:
    # RDLogger is imported inside the guard so a missing RDKit produces the
    # friendly message below instead of a bare ImportError.
    from rdkit import DataStructs, RDLogger
    from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem
except ImportError as err:
    raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit") from err

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')

try:
    from sentence_transformers import SentenceTransformer, InputExample
except ImportError as err:
    raise ImportError("Install sentence-transformers: pip install sentence-transformers") from err


# ======================
# Projector
# ======================
class BarlowTwinsProjector(nn.Module):
    """Three-layer MLP projector with BatchNorm (for Barlow Twins).

    Maps encoder embeddings of size ``in_dim`` to ``out_dim``. The last
    BatchNorm has no affine parameters, so outputs are batch-standardized.
    Because of BatchNorm, training-mode forward passes need a batch size > 1.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# ======================
# Loss Function
# ======================
class BarlowTwinsLoss(nn.Module):
    """Barlow Twins loss with shared standardization.

    Both views are standardized with a single mean/std computed over the
    concatenation of the two batches. The cross-correlation matrix is then
    ``c = z1^T z2 / B`` and the loss is::

        sum_i (1 - c_ii)^2  +  λ * (sum_{i != j} c_ij^2) / d

    i.e. the off-diagonal (redundancy-reduction) term is divided by the
    embedding dimension ``d`` before being weighted by ``λ``.
    """

    def __init__(self, λ: float = 0.005):
        super().__init__()
        self.λ = λ

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the loss for two views of shape ``(B, d)``.

        Returns:
            ``(total_loss, stats)`` where ``stats`` holds plain floats:
            ``od`` on-diagonal term, ``ofsc`` λ-scaled off-diagonal term,
            ``ofrw`` off-diagonal term before λ (already divided by d),
            ``cr_onm`` mean diagonal correlation and ``cr_offm`` mean
            absolute off-diagonal correlation.
        """
        B, d = z1.shape

        # Shared standardization over both views.
        z = torch.cat([z1, z2], dim=0)
        z = (z - z.mean(dim=0)) / (z.std(dim=0) + 1e-8)
        z1, z2 = z[:B], z[B:]

        c = (z1.T @ z2) / B
        diag = torch.diagonal(c)
        on_diag = (1 - diag).pow(2).sum()
        off_diag = ((c ** 2).sum() - diag.pow(2).sum()) / d
        total_loss = on_diag + self.λ * off_diag

        with torch.no_grad():
            diag_mean = diag.mean().item()
            off_diag_mask = ~torch.eye(d, dtype=torch.bool, device=c.device)
            off_diag_mean = c[off_diag_mask].abs().mean().item()

        return total_loss, {
            'od': on_diag.item(),
            'ofsc': (self.λ * off_diag).item(),
            'ofrw': off_diag.item(),
            'cr_onm': diag_mean,
            'cr_offm': off_diag_mean,
        }


# ======================
# Utilities
# ======================
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML (``.yaml``/``.yml``) or JSON config file into a dict.

    An empty YAML file yields ``{}``.

    Raises:
        ValueError: for any other file extension.
    """
    path = Path(config_path)
    suffix = path.suffix.lower()
    if suffix in {'.yaml', '.yml'}:
        import yaml
        with open(path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
    elif suffix == '.json':
        with open(path, encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise ValueError(f"Unsupported config format: {path.suffix}")
    return config or {}


def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce known config keys to float/int/bool in place and return ``config``.

    Bool keys accept the strings "true"/"1"/"yes"/"on" (case-insensitive).
    """
    float_keys = {
        "LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
        "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA",
    }
    int_keys = {
        "BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH", "SEED",
        "EVAL_EVERY_N_PERCENT",
    }
    bool_keys = {"BEST_BY_HEALTH"}
    for key in float_keys:
        if key in config:
            config[key] = float(config[key])
    for key in int_keys:
        if key in config:
            config[key] = int(config[key])
    for key in bool_keys:
        if key in config:
            val = config[key]
            config[key] = val.lower() in {"true", "1", "yes", "on"} if isinstance(val, str) else bool(val)
    return config


def set_seed(seed: int):
    """Seed the torch, numpy and Python ``random`` RNGs (and CUDA if present)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def enum_smiles(smi: str, k: int = 2) -> List[str]:
    """Return up to ``k`` distinct randomized (non-canonical) SMILES for ``smi``.

    If RDKit cannot parse ``smi``, ``[smi] * k`` is returned. At most 100
    randomizations are tried, so molecules with few distinct spellings may
    yield fewer than ``k`` strings. Variants are returned in first-seen order
    (not set order), so the result does not depend on ``PYTHONHASHSEED``.
    """
    mol = MolFromSmiles(smi)
    if mol is None:
        return [smi] * k
    variants: Dict[str, None] = {}  # insertion-ordered set
    for _ in range(100):
        if len(variants) >= k:
            break
        variants[MolToSmiles(mol, doRandom=True, canonical=False)] = None
    return list(variants)[:k]


def tanimoto(s1: str, s2: str) -> float:
    """Tanimoto similarity of 2048-bit Morgan (radius 2) fingerprints; 0.0 if either SMILES fails to parse."""
    m1, m2 = MolFromSmiles(s1), MolFromSmiles(s2)
    if not m1 or not m2:
        return 0.0
    fp1 = AllChem.GetMorganFingerprintAsBitVect(m1, radius=2, nBits=2048)
    fp2 = AllChem.GetMorganFingerprintAsBitVect(m2, radius=2, nBits=2048)
    return DataStructs.TanimotoSimilarity(fp1, fp2)


def uniformity_metrics(emb: np.ndarray, t: float = 2.0) -> Dict[str, float]:
    """Pairwise-similarity statistics of a set of embeddings.

    Args:
        emb: array of shape ``(n, dim)``; rows are L2-normalized internally.
        t: temperature of the uniformity loss (Wang & Isola use ``t=2``).

    Returns:
        dict with the mean/std of off-diagonal cosine similarities, the
        uniformity ``log E[exp(-t * ||x - y||^2)]`` over distinct pairs,
        ``health_old = 1 - mean`` and a ``collapsed`` flag. With fewer than
        two rows the numeric entries are NaN and ``collapsed`` is False.
    """
    emb = normalize(emb)
    sim = cosine_similarity(emb)
    mask = ~np.eye(len(sim), dtype=bool)
    pairwise = sim[mask]
    if pairwise.size == 0:
        nan = float("nan")
        return {'mean': nan, 'std': nan, 'uniformity': nan, 'health_old': nan, 'collapsed': False}

    mean_sim, std_sim = pairwise.mean(), pairwise.std()
    # For unit vectors ||x - y||^2 = 2 - 2*cos(x, y). The previous version used
    # (1 - cos), which silently halved the distance and amounted to t = 1.
    sq_dist = 2.0 - 2.0 * pairwise
    uniformity = np.log(np.exp(-t * sq_dist).mean())
    return {
        'mean': float(mean_sim),
        'std': float(std_sim),
        'uniformity': float(uniformity),
        'health_old': float(1 - mean_sim),
        'collapsed': bool(mean_sim > 0.7 or std_sim < 0.05),
    }


def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device) -> torch.Tensor:
    """Differentiable forward pass returning pooled sentence embeddings.

    Unlike ``model.encode`` this keeps the autograd graph. It returns the
    model's ``sentence_embedding`` output, i.e. the same vector ``encode``
    produces during evaluation. Its width equals
    ``model.get_sentence_embedding_dimension()``, which sizes the projector.
    """
    features = model.tokenize(text_list)
    features = {k: v.to(device) for k, v in features.items()}
    return model(features)['sentence_embedding']


def evaluate(model: SentenceTransformer, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
    """Print and return alignment/uniformity diagnostics on ``eval_smiles``.

    Each molecule is re-encoded from a fresh random SMILES enumeration.
    ``health`` is (mean same-molecule cosine) - (mean pairwise cosine).
    The model is put back into train mode afterwards, even on error.
    """
    model.eval()
    try:
        with torch.no_grad():
            emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
            same_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
            emb2 = model.encode(same_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
    finally:
        model.train()

    um = uniformity_metrics(emb)
    # Row-wise cosine between each molecule and its alternative spelling.
    same_cos = np.sum(normalize(emb) * normalize(emb2), axis=1)
    same_mean, same_std = float(same_cos.mean()), float(same_cos.std())
    alignment = 1 - same_mean
    barlow_health = same_mean - um['mean']

    print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={um['uniformity']:.3f}")
    print(f" Same-mol cos: {same_mean:.3f}±{same_std:.3f} | Pairwise: {um['mean']:.3f}±{um['std']:.3f}")
    print(f" Barlow Health: {barlow_health:.3f} (higher = better)")

    um['health'] = barlow_health
    um['alignment'] = alignment
    um['same_cos_mean'] = same_mean
    um['same_cos_std'] = same_std
    return um


# ======================
# Main
# ======================
def _apply_cli_overrides(config: Dict[str, Any], args: argparse.Namespace) -> Dict[str, Any]:
    """Copy non-None CLI flags into ``config`` under the UPPER-CASE keys the script reads."""
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key.upper()] = value
    return config


def _build_train_examples(smiles_list: List[str]) -> List[InputExample]:
    """Build one two-view ``InputExample`` per SMILES (falls back to ``[smi, smi]``)."""
    examples = []
    for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
        variants = enum_smiles(smi, 2)
        if len(variants) < 2:
            variants = [smi, smi]
        examples.append(InputExample(texts=[variants[0], variants[1]]))
    return examples


def _build_optimizer(model: nn.Module, projector: nn.Module, config: Dict[str, Any],
                     epochs: int, steps_per_epoch: int):
    """Create a Ranger21 optimizer over encoder + projector parameters.

    Encoder biases and LayerNorm parameters get no weight decay; all other
    encoder parameters and all projector parameters use ``WEIGHT_DECAY``.
    """
    from ranger21 import Ranger21

    weight_decay = config.get("WEIGHT_DECAY", 0.01)
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        target = no_decay_params if any(nd in name for nd in no_decay) else decay_params
        target.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
        {"params": list(projector.parameters()), "weight_decay": weight_decay},
    ]
    # NOTE(review): confirm Ranger21 honours per-group "weight_decay"; the
    # global weight_decay=0.0 relies on that.
    return Ranger21(
        param_groups,
        lr=config.get("LR", 1e-5),
        num_epochs=epochs,
        num_batches_per_epoch=steps_per_epoch,
        weight_decay=0.0,  # Handle weight decay manually in param groups
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
    args = parser.parse_args()

    config = sanitize_config(_apply_cli_overrides(load_config(args.config), args))
    set_seed(config.get("SEED", 42))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    output_dir = Path(config["OUTPUT_DIR"])
    output_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(config["DATA_PATH"])
    smiles_list = df["SMILES"].dropna().tolist()
    print(f"📂 Loaded {len(smiles_list)} SMILES")
    if not smiles_list:
        raise ValueError("No SMILES found in DATA_PATH (column 'SMILES').")

    train_examples = _build_train_examples(smiles_list)
    print(f" Created {len(train_examples)} pairs")

    eval_size = min(200, len(smiles_list))
    eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()

    # Model
    model = SentenceTransformer(config.get("MODEL_PATH", "./chmbedv2-warmup-l5/final"))
    model.max_seq_length = config.get("MAX_LENGTH", 512)
    embed_dim = model.get_sentence_embedding_dimension()

    # Projector & Loss
    loss_type = config.get("LOSS_TYPE", "barlow")
    if loss_type != "barlow":
        raise ValueError(f"Unknown loss_type: {loss_type}")
    projector = BarlowTwinsProjector(embed_dim, hidden_dim=2048, out_dim=2048).to(device)
    train_loss = BarlowTwinsLoss(λ=config.get("BARLOW_LAMBDA", 0.005)).to(device)
    model.to(device)

    # Batching / accumulation. A trailing batch of size 1 would crash the
    # projector's BatchNorm, so drop the ragged last batch when possible.
    batch_size = config.get("BATCH_SIZE", 8)
    effective_batch = config.get("EFFECTIVE_BATCH", 32)
    epochs = config.get("EPOCHS", 1)
    train_loader = DataLoader(
        train_examples,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: x,
        drop_last=len(train_examples) > batch_size,
    )
    # At least 1 (BATCH_SIZE > EFFECTIVE_BATCH used to give 0 -> ZeroDivisionError),
    # and never more micro-batches than an epoch actually has.
    grad_acc = min(max(1, effective_batch // batch_size), len(train_loader))
    # Exact number of optimizer steps, so the LR schedule and final eval line up.
    steps_per_epoch = len(train_loader) // grad_acc
    total_steps = steps_per_epoch * epochs

    optimizer = _build_optimizer(model, projector, config, epochs, steps_per_epoch)
    # NOTE(review): Ranger21 applies its own warmup/warm-down by default;
    # stacking LinearLR on top may double-schedule the LR — confirm intended.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
    )

    # Train
    model.train()
    projector.train()
    step = 0
    best_health = float("-inf")
    best_step = 0
    save_best = config.get("BEST_BY_HEALTH", True)
    log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))

    for epoch in range(epochs):
        # Drop any partial accumulation left over from the previous epoch so
        # it does not leak into this epoch's first optimizer step.
        optimizer.zero_grad()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            view1 = [ex.texts[0] for ex in batch]
            view2 = [ex.texts[1] for ex in batch]
            p1 = projector(forward_pooled(model, view1, device))
            p2 = projector(forward_pooled(model, view2, device))
            loss, extras = train_loss(p1, p2)
            (loss / grad_acc).backward()

            if (batch_idx + 1) % grad_acc != 0:
                continue

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            step += 1

            postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
            postfix.update({k: f"{v:.3f}" for k, v in extras.items()})
            pbar.set_postfix(postfix)

            if step % log_interval == 0 or step == total_steps:
                um = evaluate(model, eval_smiles, device, step)
                if um["health"] > best_health:
                    best_health, best_step = um["health"], step
                    if save_best:
                        model.save(str(output_dir / "best"))

    model.save(str(output_dir / "final"))
    print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")


if __name__ == "__main__":
    main()