# chemberta-3-phinformed / train_model.py
# Uploaded by: timcryt
# Commit message: Initial commit
# Commit: f6fc460 (verified)
import os
import time
import math
import logging
import argparse
from datetime import datetime
from typing import Dict, List, Any
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
DataCollatorForLanguageModeling,
PreTrainedTokenizerBase
)
from rdkit import Chem
from rdkit.Chem import Descriptors
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(message)s"
# Configure root logging at INFO with "time | level | message" lines, then
# obtain this module's named logger.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def compute_rdkit_features(smiles: str, num_features: int = 210) -> np.ndarray:
    """Return the RDKit descriptor vector for one SMILES string as float32.

    Descriptors come from ``Descriptors.CalcMolDescriptors`` in the order it
    yields them. If RDKit cannot parse the SMILES, or raises while computing
    or converting descriptors, an all-zeros vector of length ``num_features``
    is returned instead. This is deliberate best-effort, so one bad molecule
    does not abort the dataset ``map``.

    Non-finite descriptor values (NaN/inf) are replaced with 0. Finite values
    are clipped to the float32 range, so a single degenerate descriptor cannot
    make the downstream regression target NaN/inf.

    Args:
        smiles: Molecule in SMILES notation.
        num_features: Length of the zero fallback vector. It should equal the
            number of descriptors RDKit returns (210 for the RDKit version this
            script was written against — TODO confirm for your install). The
            length of a successfully computed vector is whatever RDKit returns
            and is not checked against this value.

    Returns:
        1-D ``np.float32`` array; both the success and fallback paths use
        the same dtype.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(num_features, dtype=np.float32)
        values = np.asarray(
            list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float64
        )
    except Exception:
        # Best-effort by design: any RDKit failure yields the zero vector.
        return np.zeros(num_features, dtype=np.float32)
    # NaN/inf targets would poison the loss (and every weight) on backprop.
    finite = np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
    # Clip before the float32 cast so huge finite values don't overflow to inf.
    limit = np.finfo(np.float32).max
    return np.clip(finite, -limit, limit).astype(np.float32)
class SMILESAndDescriptorCollator:
    """Collate SMILES strings and their precomputed descriptor vectors.

    Each incoming example is a mapping with a ``'smiles'`` string and a
    ``'descriptors'`` sequence. The SMILES are tokenized with truncation to
    ``max_length``. The per-example encodings are then handled one of two ways:

    - ``do_mlm=True``: passed to Hugging Face's ``DataCollatorForLanguageModeling``,
      which pads and masks them.
    - ``do_mlm=False``: padded with ``tokenizer.pad``.

    The descriptors are stacked into a ``float32`` tensor stored under
    ``'descriptors'``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 512,
        mlm_probability: float = 0.15,
        do_mlm: bool = True
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.do_mlm = do_mlm
        # The masking collator is only built when masking is requested.
        self.mlm_collator = (
            DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=True,
                mlm_probability=mlm_probability,
                return_tensors="pt"
            )
            if do_mlm
            else None
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Turn a list of examples into a model-ready batch dict."""
        texts = [example['smiles'] for example in features]
        descriptor_rows = [example['descriptors'] for example in features]

        encoded = self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None
        )
        # Re-split the column-oriented encoding into one dict per example,
        # the layout both collation paths below consume.
        field_names = list(encoded.keys())
        per_example = [
            {name: encoded[name][row] for name in field_names}
            for row in range(len(texts))
        ]

        if not (self.do_mlm and self.mlm_collator):
            padded = self.tokenizer.pad(
                per_example,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
            batch = dict(padded)
        else:
            batch = self.mlm_collator(per_example)

        batch['descriptors'] = torch.tensor(np.stack(descriptor_rows), dtype=torch.float32)
        return batch
def get_backbone_grad_vector(module, exclude_keywords=None):
    """Flatten and concatenate the current ``.grad`` of a module's parameters.

    A parameter is skipped when its name contains any of ``exclude_keywords``.
    Matching is case-insensitive: both the name and the keywords are
    lower-cased. Parameters whose ``.grad`` is ``None`` are skipped too, so
    two calls produce comparable layouts only if the same parameters
    received gradients.

    Args:
        module: ``torch.nn.Module`` whose ``named_parameters()`` are scanned.
        exclude_keywords: Optional iterable of name substrings to exclude.
            It is materialised once, so a generator works as well as a list.

    Returns:
        A 1-D detached tensor of the concatenated gradients, in
        ``named_parameters()`` order. An empty tensor is returned when no
        parameter qualifies.
    """
    keywords = [kw.lower() for kw in (exclude_keywords or ())]
    grads = [
        param.grad.detach().flatten()
        for name, param in module.named_parameters()
        if param.grad is not None
        and not any(kw in name.lower() for kw in keywords)
    ]
    return torch.cat(grads) if grads else torch.tensor([])
def compute_gradient_metrics(model, loss1, loss2, exclude_keywords=None):
    """Compare the gradients two losses induce on a model's parameters.

    Gradients come from ``torch.autograd.grad`` with the graph retained.
    Nothing is accumulated into any ``.grad`` attribute: not on ``model``,
    and not on parameters outside it (e.g. task heads) that the losses also
    reach. This means the caller's own ``.grad`` state is left untouched.

    Some parameters are ignored:
    - those whose name contains any of ``exclude_keywords`` (case-insensitive);
    - those that do not require grad.

    The two gradient vectors are aligned parameter by parameter. A parameter
    reached by only one loss contributes zeros to the other loss's vector, so
    the angle always compares matching coordinates.

    Args:
        model: ``torch.nn.Module`` whose parameters are compared.
        loss1: Scalar loss whose gradient norm is reported as ``'norm_mtr'``.
        loss2: Scalar loss whose gradient norm is reported as ``'norm_mlm'``.
        exclude_keywords: Optional iterable of parameter-name substrings.

    Returns:
        Dict with these keys:
        - ``'angle_deg'``: angle between the two gradients, in degrees.
        - ``'norm_mtr'``: L2 norm of ``loss1``'s gradient.
        - ``'norm_mlm'``: L2 norm of ``loss2``'s gradient.
        A norm is ``None`` when its loss reaches none of the selected
        parameters. The angle is ``None`` unless both norms are available.
    """
    keywords = [kw.lower() for kw in (exclude_keywords or ())]
    params = [
        param
        for name, param in model.named_parameters()
        if param.requires_grad and not any(kw in name.lower() for kw in keywords)
    ]
    result = {'angle_deg': None, 'norm_mtr': None, 'norm_mlm': None}
    if not params:
        return result

    # retain_graph=True: the caller backpropagates through this same graph
    # afterwards. allow_unused=True: a loss may not reach every parameter.
    grads1 = torch.autograd.grad(loss1, params, retain_graph=True, allow_unused=True)
    grads2 = torch.autograd.grad(loss2, params, retain_graph=True, allow_unused=True)

    flat1, flat2 = [], []
    for param, ga, gb in zip(params, grads1, grads2):
        if ga is None and gb is None:
            continue  # neither loss reaches this parameter
        flat1.append((torch.zeros_like(param) if ga is None else ga).detach().flatten())
        flat2.append((torch.zeros_like(param) if gb is None else gb).detach().flatten())
    if not flat1:
        return result

    g1 = torch.cat(flat1)
    g2 = torch.cat(flat2)
    has1 = any(g is not None for g in grads1)
    has2 = any(g is not None for g in grads2)
    if has1:
        result['norm_mtr'] = g1.norm().item()
    if has2:
        result['norm_mlm'] = g2.norm().item()
    if has1 and has2:
        cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0), dim=1).item()
        # Clamp float round-off so acos stays in its domain.
        cos_sim = max(min(cos_sim, 1.0), -1.0)
        result['angle_deg'] = math.degrees(math.acos(cos_sim))
    return result
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="ChemBERTa Multi-Task Training")
parser.add_argument("--smiles_file", type=str, default="support/smiles_10k.txt")
parser.add_argument("--stats_file", type=str, default="support/normalization_params.pth")
parser.add_argument("--output_file", type=str, default="model.pth")
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--mlm_weight", type=float, default=1.0)
parser.add_argument("--mtr_weight", type=float, default=1.0)
parser.add_argument("--lr", type=float, default=3e-5)
parser.add_argument("--epochs", type=int, default=1)
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
logger.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
model_base = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-10M-MLM").roberta
model_dim = 384
mlm_head = nn.Sequential(
nn.Linear(model_dim, model_dim * 2),
nn.GELU(),
nn.Linear(model_dim * 2, tokenizer.vocab_size),
)
rdkit_head = nn.Sequential(
nn.Linear(model_dim, model_dim * 2),
nn.GELU(),
nn.Linear(model_dim * 2, 210),
)
model_base.to(device)
mlm_head.to(device)
rdkit_head.to(device)
logger.info("Loading dataset...")
raw_dataset = load_dataset("text", data_files={"train": args.smiles_file})
raw_dataset = raw_dataset.rename_column("text", "smiles")
logger.info("Calculating RDKit features...")
processed_dataset = raw_dataset.map(
lambda x: {"descriptors": compute_rdkit_features(x["smiles"])},
num_proc=8,
desc="Calculating RDKit features"
)
collator = SMILESAndDescriptorCollator(tokenizer=tokenizer, max_length=args.max_length)
dataloader = DataLoader(
processed_dataset["train"],
batch_size=args.batch_size,
collate_fn=collator
)
logger.info("Loading normalization stats...")
stats = torch.load(args.stats_file, map_location=device)
means = stats["means"].to(device)
stds = stats["stds"].to(device)
stds[stds < 1e-6] = 1.0
optimizer = torch.optim.AdamW(
list(model_base.parameters()) + list(mlm_head.parameters()) + list(rdkit_head.parameters()),
lr=args.lr, weight_decay=1e-4
)
clip_grad_norm = 1.0
BACKBONE_EXCLUDE_KEYWORDS = ["head", "rdkit", "mlm", "classifier", "pooler"]
LOG_GRAD_METRICS_EVERY_N_BATCHES = 10
log_dir = os.path.join("runs", f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
writer = SummaryWriter(log_dir=log_dir)
global_step = 0
logger.info("Starting training")
for epoch in range(args.epochs):
start_time = time.time()
total_loss_mtr = 0.0
total_loss_mlm = 0.0
total_loss = 0.0
num_batches = 0
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
model_base.train()
mlm_head.train()
rdkit_head.train()
for batch_idx, batch in enumerate(pbar):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
descriptors = batch["descriptors"].to(device)
labels = batch["labels"].to(device)
outputs = model_base(input_ids, attention_mask=attention_mask)
result = outputs.last_hidden_state
mtr_res = rdkit_head(result[:, 0])
mlm_res = mlm_head(result)
loss_mtr = F.huber_loss(mtr_res, (descriptors - means) / stds)
loss_mlm = F.cross_entropy(mlm_res.flatten(end_dim=1), labels.flatten(), ignore_index=-100)
loss = loss_mtr * args.mtr_weight + loss_mlm * args.mlm_weight
if num_batches % LOG_GRAD_METRICS_EVERY_N_BATCHES == 0:
metrics = compute_gradient_metrics(
model=model_base, loss1=loss_mtr, loss2=loss_mlm,
exclude_keywords=BACKBONE_EXCLUDE_KEYWORDS
)
postfix = {
"loss_mlm": f"{loss_mlm.item():.4f}",
"loss_mtr": f"{loss_mtr.item():.4f}",
}
if metrics["angle_deg"] is not None:
postfix["angle"] = f"{metrics['angle_deg']:.1f}°"
writer.add_scalar("gradients/backbone_angle_deg", metrics["angle_deg"], global_step)
if metrics["norm_mtr"] is not None:
postfix["‖∇MTR‖"] = f"{metrics['norm_mtr']:.3f}"
writer.add_scalar("gradients/backbone_norm_mtr", metrics["norm_mtr"], global_step)
if metrics["norm_mlm"] is not None:
postfix["‖∇MLM‖"] = f"{metrics['norm_mlm']:.3f}"
writer.add_scalar("gradients/backbone_norm_mlm", metrics["norm_mlm"], global_step)
pbar.set_postfix(postfix)
writer.add_scalar("loss/total", loss.item(), global_step)
writer.add_scalar("loss/mtr_l1", loss_mtr.item(), global_step)
writer.add_scalar("loss/mlm_ce", loss_mlm.item(), global_step)
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model_base.parameters(), clip_grad_norm)
torch.nn.utils.clip_grad_norm_(rdkit_head.parameters(), clip_grad_norm)
torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip_grad_norm)
writer.add_scalar("training/grad_norm_clipped", grad_norm.item(), global_step)
writer.add_scalar("training/learning_rate", optimizer.param_groups[0]["lr"], global_step)
optimizer.step()
total_loss += loss.item()
total_loss_mtr += loss_mtr.item()
total_loss_mlm += loss_mlm.item()
num_batches += 1
global_step += 1
epoch_time = time.time() - start_time
avg_loss = total_loss / num_batches
avg_loss_mtr = total_loss_mtr / num_batches
avg_loss_mlm = total_loss_mlm / num_batches
writer.add_scalar("epoch/avg_total_loss", avg_loss, epoch)
writer.add_scalar("epoch/avg_loss_mtr", avg_loss_mtr, epoch)
writer.add_scalar("epoch/avg_loss_mlm", avg_loss_mlm, epoch)
writer.add_scalar("epoch/time_sec", epoch_time, epoch)
logger.info(
f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s | "
f"Total Loss: {avg_loss:.4f} | L1 (MTR): {avg_loss_mtr:.4f} | "
f"CE (MLM): {avg_loss_mlm:.4f} | Grad Norm: {grad_norm:.4f}"
)
writer.close()
logger.info("Saving checkpoint..")
torch.save({
"backbone": model_base.state_dict(),
"mlm_head": mlm_head.state_dict(),
"mtr_head": rdkit_head.state_dict(),
}, args.output_file)
logger.info("Training is finished")