# Source: miniChembed-prototype/train/trainbarlow.py
# Author: gbyuvd — commit deb45f3 ("Update train/trainbarlow.py", verified)
#!/usr/bin/env python3
"""
Self-Supervised Training for Molecular Representations (SMILES)
Usage:
python trainbarlow.py --config config.yaml
"""
# Printed before the third-party imports below (torch, pandas, sklearn, RDKit,
# sentence-transformers), presumably so there is visible output while they load.
print("Initializing ...")
import os
import json
import argparse
import random
from pathlib import Path
from typing import Dict, Any, Tuple, List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
# Third-party packages that are often missing from a fresh environment. Each
# group is imported inside a try so a missing package yields an install hint.
# RDLogger is imported inside the RDKit try (previously it was imported first,
# unguarded, so a missing RDKit never reached the friendly message).
try:
    from rdkit import RDLogger
    from rdkit import DataStructs
    from rdkit.Chem import MolFromSmiles, MolToSmiles, AllChem
except ImportError as err:
    raise ImportError("RDKit is required. Install with: conda install -c conda-forge rdkit") from err
# Suppress RDKit log output on every rdApp.* channel (warnings included).
RDLogger.DisableLog('rdApp.*')
try:
    from sentence_transformers import SentenceTransformer, InputExample
except ImportError as err:
    raise ImportError("Install sentence-transformers: pip install sentence-transformers") from err
# ======================
# Projector
# ======================
class BarlowTwinsProjector(nn.Module):
    """Three-layer MLP head that maps encoder embeddings into the loss space.

    Layout of ``self.layers`` (an ``nn.Sequential``):
        Linear(in_dim -> hidden_dim) -> BatchNorm1d -> ReLU
        Linear(hidden_dim -> hidden_dim) -> BatchNorm1d -> ReLU
        Linear(hidden_dim -> out_dim) -> BatchNorm1d(affine=False)

    Linear layers carry no bias because each one feeds a BatchNorm, which
    re-centres its input anyway. Module order (and therefore parameter names
    such as ``layers.0.weight``) is fixed.
    """

    def __init__(self, in_dim: int, hidden_dim: int = 2048, out_dim: int = 2048):
        super().__init__()
        modules = []
        width = in_dim
        # Two identical Linear-BN-ReLU stages, the first one widening from in_dim.
        for _ in range(2):
            modules.extend([
                nn.Linear(width, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            ])
            width = hidden_dim
        # Output stage: plain standardization, no learnable scale/shift.
        modules.append(nn.Linear(width, out_dim, bias=False))
        modules.append(nn.BatchNorm1d(out_dim, affine=False))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map ``x`` (batch, in_dim) to projections of shape (batch, out_dim)."""
        return self.layers(x)
# ======================
# Loss Function
# ======================
class BarlowTwinsLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss (variant).

    With C = z1_stdᵀ z2_std / B (a d x d cross-correlation matrix):

        loss = Σ_i (1 - C_ii)²  +  λ · (Σ_{i≠j} C_ij²) / d

    Differences from the textbook formulation:
      * both views are standardized together, with mean/std taken over the
        2B stacked rows, instead of per view;
      * the off-diagonal sum is divided by the embedding width d before λ.

    ``forward`` returns ``(loss, stats)``; ``stats`` holds plain floats:
    'od' on-diagonal term, 'ofsc' λ-scaled off-diagonal term, 'ofrw' unscaled
    (d-normalized) off-diagonal term, 'cr_onm' mean of diag(C), 'cr_offm'
    mean |C_ij| over off-diagonal entries.
    """

    def __init__(self, λ: float = 0.005):
        super().__init__()
        self.λ = λ  # weight of the off-diagonal (redundancy) penalty

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        batch, dim = z1.shape

        # Standardize the two views jointly so they share one mean/std.
        stacked = torch.cat([z1, z2], dim=0)
        centre = stacked.mean(dim=0)
        spread = stacked.std(dim=0) + 1e-8
        stacked = (stacked - centre) / spread
        view_a, view_b = stacked[:batch], stacked[batch:]

        corr = (view_a.T @ view_b) / batch
        diag = torch.diagonal(corr)

        on_diag = (1 - diag).pow(2).sum()
        off_raw = ((corr ** 2).sum() - diag.pow(2).sum()) / dim
        off_scaled = self.λ * off_raw
        total = on_diag + off_scaled

        with torch.no_grad():
            not_diag = ~torch.eye(dim, dtype=torch.bool, device=corr.device)
            stats = {
                'od': on_diag.item(),
                'ofsc': off_scaled.item(),
                'ofrw': off_raw.item(),
                'cr_onm': diag.mean().item(),
                'cr_offm': corr[not_diag].abs().mean().item(),
            }
        return total, stats
# ======================
# Utilities
# ======================
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a training configuration from a YAML or JSON file.

    The format is chosen from the file suffix, case-insensitively:
    ``.yaml``/``.yml`` (parsed with ``yaml.safe_load``) or ``.json``.
    Files are read as UTF-8. An empty document (YAML with no content, or
    JSON ``null``) yields an empty dict instead of ``None``.

    Args:
        config_path: path to the config file.

    Returns:
        The top-level mapping of the file.

    Raises:
        ValueError: if the suffix is unsupported, or the top level of the
            document is not a mapping.
    """
    path = Path(config_path)
    suffix = path.suffix.lower()
    if suffix in {'.yaml', '.yml'}:
        import yaml  # imported lazily: only YAML configs need PyYAML
        with open(path, encoding='utf-8') as f:
            data = yaml.safe_load(f)
    elif suffix == '.json':
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
    else:
        raise ValueError(f"Unsupported config format: {path.suffix}")
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(f"Config root must be a mapping, got {type(data).__name__}")
    return data
def sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce known config entries to their expected Python types, in place.

    Config sources may hand over numbers and flags as strings (PyYAML, for
    instance, reads ``1e-5`` without a decimal point as a string). Known float
    keys go through ``float()``, known int keys through ``int()``, and
    ``BEST_BY_HEALTH`` becomes a bool. Missing keys stay missing and unknown
    keys are untouched. The same dict object is mutated and returned.

    Boolean parsing: a string is True iff its lower-cased form is one of
    "true", "1", "yes", "on"; any non-string goes through ``bool()``.
    """
    def as_bool(raw):
        if isinstance(raw, str):
            return raw.lower() in {"true", "1", "yes", "on"}
        return bool(raw)

    converters = {}
    converters.update(dict.fromkeys(
        ("LR", "WEIGHT_DECAY", "BARLOW_LAMBDA", "VICREG_LAMBDA",
         "VICREG_MU", "VICREG_NU", "CORINFOMAX_ALPHA"),
        float,
    ))
    converters.update(dict.fromkeys(
        ("BATCH_SIZE", "EFFECTIVE_BATCH", "EPOCHS", "MAX_LENGTH",
         "SEED", "EVAL_EVERY_N_PERCENT"),
        int,
    ))
    converters["BEST_BY_HEALTH"] = as_bool

    for name, convert in converters.items():
        if name in config:
            config[name] = convert(config[name])
    return config
def set_seed(seed: int):
    """Seed every RNG this script draws from with ``seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and, when CUDA is available, all CUDA device generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def enum_smiles(smi: str, k: int = 2, max_attempts: int = 100) -> List[str]:
    """Generate up to ``k`` distinct randomized (non-canonical) SMILES of a molecule.

    Args:
        smi: input SMILES string.
        k: number of variants wanted.
        max_attempts: maximum number of RDKit randomization calls. Molecules
            with few distinct atom orderings (e.g. a single atom) may give
            fewer than ``k`` unique strings within this budget.

    Returns:
        ``[smi] * k`` if RDKit cannot parse ``smi``. Otherwise at most ``k``
        unique SMILES in the order they were generated. Callers must handle a
        result shorter than ``k``.
    """
    from rdkit.Chem import MolFromSmiles, MolToSmiles
    mol = MolFromSmiles(smi)
    if mol is None:
        return [smi] * k
    # Deduplicate with an insertion-ordered dict rather than a set: iterating a
    # set of strings follows per-process hash randomization, which made the
    # chosen variants irreproducible even with set_seed().
    # NOTE(review): doRandom is driven by RDKit's own RNG, which set_seed() does
    # not obviously seed — confirm if bit-exact reproducibility matters.
    variants: Dict[str, None] = {}
    for _ in range(max_attempts):
        if len(variants) >= k:
            break
        variants[MolToSmiles(mol, doRandom=True, canonical=False)] = None
    return list(variants)[:k]
def tanimoto(s1: str, s2: str) -> float:
    """Tanimoto similarity between two SMILES via Morgan bit fingerprints.

    Fingerprints use radius 2 and 2048 bits. Returns 0.0 if either parsed
    molecule is falsy (``MolFromSmiles`` returns None for unparsable input).
    """
    mols = [MolFromSmiles(s) for s in (s1, s2)]
    if not all(mols):
        return 0.0
    fp_a, fp_b = (
        AllChem.GetMorganFingerprintAsBitVect(m, radius=2, nBits=2048)
        for m in mols
    )
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)
def uniformity_metrics(emb: np.ndarray, *, t: float = 1.0) -> Dict[str, Any]:
    """Summarize how spread out a set of embeddings is on the unit sphere.

    Rows are L2-normalized (all-zero rows stay zero) and compared pairwise by
    cosine similarity, excluding each row's similarity with itself.

    Args:
        emb: 2-D array with one embedding per row; needs at least two rows.
        t: temperature of the Wang & Isola uniformity measure
            ``log E[exp(-t * ||u - v||^2)]``. For unit vectors
            ``||u - v||^2 = 2 * (1 - cos)``, so the default ``t=1`` gives
            exactly ``log E[exp(-2 * (1 - cos))]``, the value this function
            has always reported. (The paper's usual choice is ``t=2``.)

    Returns:
        Dict with
          'mean' / 'std': mean and population std of off-diagonal cosines,
          'uniformity': the measure above (lower = more spread out),
          'health_old': 1 - mean,
          'collapsed': Python bool, True if mean > 0.7 or std < 0.05.

    Raises:
        ValueError: if ``emb`` is not 2-D or has fewer than two rows (no pairs).
    """
    emb = np.asarray(emb, dtype=np.float64)
    if emb.ndim != 2 or emb.shape[0] < 2:
        raise ValueError(f"need a 2-D array with at least two rows, got shape {emb.shape}")

    # Normalize once; the dot products of unit rows are the cosine similarities.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    unit = emb / np.where(norms == 0, 1.0, norms)
    sim = unit @ unit.T

    pairwise = sim[~np.eye(len(sim), dtype=bool)]
    mean_sim = float(pairwise.mean())
    std_sim = float(pairwise.std())

    sq_dist = 2.0 * (1.0 - pairwise)
    uniformity = float(np.log(np.exp(-t * sq_dist).mean()))

    return {
        'mean': mean_sim,
        'std': std_sim,
        'uniformity': uniformity,
        'health_old': 1.0 - mean_sim,
        'collapsed': bool(mean_sim > 0.7 or std_sim < 0.05),
    }
def forward_pooled(model: SentenceTransformer, text_list: List[str], device: torch.device,
                   *, pooling: str = "cls") -> torch.Tensor:
    """Embed ``text_list`` with autograd enabled and return one vector per text.

    Unlike ``model.encode`` (used by ``evaluate``), this calls the model's
    ``forward`` directly, so gradients flow back into the encoder.

    Args:
        model: the SentenceTransformer being trained.
        text_list: input strings (SMILES) for one micro-batch.
        device: device to which the tokenized tensors are moved.
        pooling: ``"cls"`` (default; the historical behavior) takes the first
            token's vector from ``token_embeddings``. ``"sentence"`` takes
            ``sentence_embedding``, i.e. the output of the model's own
            pooling stack.

    Returns:
        2-D tensor with one row per input text.

    Raises:
        ValueError: for an unknown ``pooling`` value.

    NOTE(review): ``evaluate`` and the saved checkpoint go through
    ``model.encode``, i.e. the model's configured pooling. If that pooling is
    not CLS, training optimizes a different vector than the one evaluated.
    Also, ``main`` sizes the projector from ``get_sentence_embedding_dimension()``,
    which presumably matches the token width only when no module after the
    transformer changes it. Confirm for the checkpoint in use, or pass
    ``pooling="sentence"``.
    """
    if pooling not in ("cls", "sentence"):
        raise ValueError(f"Unknown pooling: {pooling!r} (expected 'cls' or 'sentence')")
    features = model.tokenize(text_list)
    # Move only tensors; any non-tensor entries are passed through unchanged.
    features = {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in features.items()
    }
    output = model(features)
    if pooling == "sentence":
        return output['sentence_embedding']
    return output['token_embeddings'][:, 0, :]
def evaluate(model, eval_smiles: List[str], device: torch.device, step: int) -> Dict[str, Any]:
    """Measure alignment and uniformity of the encoder on a fixed SMILES sample.

    Each SMILES is embedded as given and again as a fresh randomized SMILES of
    the same molecule. Matching rows are then compared by cosine similarity
    ("same-molecule" cosine). A short summary is printed.

    Args:
        model: SentenceTransformer under training (``encode`` is used).
        eval_smiles: SMILES to embed.
        device: not used in this body; kept for the call signature.
        step: optimizer step number, only used in the printed header.

    Returns:
        The ``uniformity_metrics`` dict plus, as Python floats:
          'health': same-molecule cosine mean minus pairwise cosine mean
                    (higher = better),
          'alignment': 1 - same-molecule cosine mean,
          'same_cos_mean' / 'same_cos_std'.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            emb = model.encode(eval_smiles, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
            same_view = [enum_smiles(s, 1)[0] for s in eval_smiles]
            emb2 = model.encode(same_view, convert_to_numpy=True, show_progress_bar=False, batch_size=32)
    finally:
        # Restore the caller's mode even if encoding raises.
        model.train(was_training)

    um = uniformity_metrics(emb)
    # Cosine of each row with its own second view: the diagonal of the full
    # similarity matrix, without building the N x N matrix.
    same_cos = np.sum(normalize(emb) * normalize(emb2), axis=1)
    same_mean = float(same_cos.mean())
    same_std = float(same_cos.std())
    alignment = 1.0 - same_mean
    barlow_health = same_mean - um['mean']

    print(f"\n📊 Step {step} | Alignment={alignment:.3f} | Uniformity={um['uniformity']:.3f}")
    print(f" Same-mol cos: {same_mean:.3f}±{same_std:.3f} | Pairwise: {um['mean']:.3f}±{um['std']:.3f}")
    print(f" Barlow Health: {barlow_health:.3f} (higher = better)")

    um['health'] = float(barlow_health)
    um['alignment'] = alignment
    um['same_cos_mean'] = same_mean
    um['same_cos_std'] = same_std
    return um
# ======================
# Main
# ======================
def _build_pairs(smiles_list: List[str]) -> list:
    """Turn each SMILES into an InputExample holding two randomized views of it."""
    pairs = []
    for smi in tqdm(smiles_list, desc="Enumerating SMILES"):
        variants = enum_smiles(smi, 2)
        if len(variants) < 2:
            # Fewer than two distinct randomizations: use an identity pair.
            variants = [smi, smi]
        pairs.append(InputExample(texts=[variants[0], variants[1]]))
    return pairs


def _param_groups(model, projector, weight_decay: float) -> List[Dict[str, Any]]:
    """Optimizer groups: no decay for biases/LayerNorm, decay elsewhere and on the projector."""
    no_decay = ("bias", "LayerNorm.bias", "LayerNorm.weight")
    decay, skip = [], []
    for name, param in model.named_parameters():
        (skip if any(nd in name for nd in no_decay) else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": skip, "weight_decay": 0.0},
        {"params": list(projector.parameters()), "weight_decay": weight_decay},
    ]


def main():
    """Fine-tune a SMILES SentenceTransformer with the Barlow Twins objective.

    Flow:
      1. Load the config file and apply CLI overrides.
      2. Build pairs of randomized SMILES.
      3. Attach a projector and BarlowTwinsLoss.
      4. Train with Ranger21 plus a LinearLR decay.
      5. Evaluate periodically, saving ``best`` (by health) and ``final``
         checkpoints under OUTPUT_DIR.

    Config keys read (UPPER_CASE): OUTPUT_DIR, DATA_PATH (CSV with a SMILES
    column), MODEL_PATH (default './chmbedv2-warmup-l5/final'), SEED,
    MAX_LENGTH, LOSS_TYPE, BARLOW_LAMBDA, WEIGHT_DECAY, LR, BATCH_SIZE,
    EFFECTIVE_BATCH, EPOCHS, EVAL_EVERY_N_PERCENT, BEST_BY_HEALTH.
    The flags --epochs, --lr, --batch_size and --loss_type override EPOCHS,
    LR, BATCH_SIZE and LOSS_TYPE.

    Raises:
        ValueError: for an unknown loss type, a BATCH_SIZE below 2, or data
            too small for a single optimizer step.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--epochs", type=int)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--loss_type", type=str, choices=["barlow", "vicreg", "corinfomax"])
    args = parser.parse_args()

    # An empty YAML config loads as None; treat it as an empty mapping.
    config = load_config(args.config) or {}
    # argparse attribute names are lower_case, but every lookup below uses
    # UPPER_CASE keys. Upper-case them so CLI overrides actually take effect.
    for key, value in vars(args).items():
        if value is not None and key != "config":
            config[key.upper()] = value
    config = sanitize_config(config)

    set_seed(config.get("SEED", 42))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(config["OUTPUT_DIR"])
    output_dir.mkdir(parents=True, exist_ok=True)

    # Data
    df = pd.read_csv(config["DATA_PATH"])
    smiles_list = df["SMILES"].dropna().tolist()
    print(f"📂 Loaded {len(smiles_list)} SMILES")
    train_examples = _build_pairs(smiles_list)
    print(f" Created {len(train_examples)} pairs")
    eval_size = min(200, len(smiles_list))
    eval_smiles = np.random.choice(smiles_list, eval_size, replace=False).tolist()

    # Batching / schedule sizes
    batch_size = config.get("BATCH_SIZE", 8)
    if batch_size < 2:
        raise ValueError("BATCH_SIZE must be at least 2 (BatchNorm1d in the projector needs >1 sample)")
    effective_batch = config.get("EFFECTIVE_BATCH", 32)
    # At least one micro-batch per optimizer step (avoids dividing by zero
    # when BATCH_SIZE > EFFECTIVE_BATCH).
    grad_acc = max(1, effective_batch // batch_size)
    epochs = config.get("EPOCHS", 1)
    # drop_last: a ragged final batch of one sample would make BatchNorm1d
    # raise in train mode.
    train_loader = DataLoader(train_examples, batch_size=batch_size, shuffle=True,
                              collate_fn=lambda x: x, drop_last=True)
    # Optimizer steps actually taken per epoch.
    num_batches_per_epoch = len(train_loader) // grad_acc
    if num_batches_per_epoch == 0:
        raise ValueError(
            f"Not enough data for one optimizer step: {len(train_examples)} pairs, "
            f"batch size {batch_size} x {grad_acc} accumulation steps"
        )
    total_steps = num_batches_per_epoch * epochs

    # Model
    model = SentenceTransformer(config.get("MODEL_PATH", './chmbedv2-warmup-l5/final'))
    model.max_seq_length = config.get("MAX_LENGTH", 512)
    embed_dim = model.get_sentence_embedding_dimension()

    # Projector & Loss
    loss_type = config.get("LOSS_TYPE", "barlow")
    if loss_type == "barlow":
        projector = BarlowTwinsProjector(embed_dim, hidden_dim=2048, out_dim=2048).to(device)
        train_loss = BarlowTwinsLoss(λ=config.get("BARLOW_LAMBDA", 0.005)).to(device)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    model.to(device)

    # Optimizer (encoder + projector)
    from ranger21 import Ranger21
    weight_decay = config.get("WEIGHT_DECAY", 0.01)
    optimizer = Ranger21(
        _param_groups(model, projector, weight_decay),
        lr=config.get("LR", 1e-5),
        num_epochs=epochs,
        num_batches_per_epoch=num_batches_per_epoch,
        # Weight decay is set per param group above.
        # NOTE(review): confirm Ranger21 honours per-group weight_decay.
        weight_decay=0.0,
    )
    # NOTE(review): Ranger21 runs its own warmup/warmdown schedule; confirm
    # that stacking this LinearLR decay on top is intended.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.0, total_iters=total_steps
    )

    # Train
    model.train()
    step = 0
    best_health = float("-inf")
    best_step = None
    log_interval = max(1, int(total_steps * config.get("EVAL_EVERY_N_PERCENT", 25) / 100))
    usable_batches = num_batches_per_epoch * grad_acc
    for epoch in range(epochs):
        pbar = tqdm(train_loader, total=usable_batches, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            if batch_idx >= usable_batches:
                # Tail micro-batches cannot fill an accumulation window. Skip
                # them so their gradients don't leak into the next epoch.
                break
            texts = [[ex.texts[i] for ex in batch] for i in range(2)]
            p1 = projector(forward_pooled(model, texts[0], device))
            p2 = projector(forward_pooled(model, texts[1], device))
            loss, extras = train_loss(p1, p2)
            (loss / grad_acc).backward()

            stepped = (batch_idx + 1) % grad_acc == 0
            if stepped:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

            postfix = {"step": step, "lr": scheduler.get_last_lr()[0]}
            postfix.update({k: f"{v:.3f}" for k, v in extras.items()})
            pbar.set_postfix(postfix)

            # Evaluate only right after an optimizer step, so one step value
            # is never evaluated once per accumulation micro-batch.
            if stepped and (step % log_interval == 0 or step == total_steps):
                um = evaluate(model, eval_smiles, device, step)
                if config.get("BEST_BY_HEALTH", True) and um["health"] > best_health:
                    best_health, best_step = um["health"], step
                    model.save(str(output_dir / "best"))
        pbar.close()

    model.save(str(output_dir / "final"))
    if best_step is None:
        print("\n✅ Training complete! No best-by-health checkpoint was saved.")
    else:
        print(f"\n✅ Training complete! Best health: {best_health:.3f} at step {best_step}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()