"""Multi-task training for ChemBERTa: masked language modelling (MLM) plus
regression on normalized RDKit descriptors (MTR), with gradient-conflict logging."""

import argparse
import logging
import math
import os
import time
from datetime import datetime
from typing import Any, Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
)
from rdkit import Chem
from rdkit.Chem import Descriptors

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)

NUM_DESCRIPTORS = 210  # number of RDKit descriptors expected per molecule


def compute_rdkit_features(smiles: str) -> np.ndarray:
    """Compute the full RDKit descriptor vector for a SMILES string.

    Returns a zero vector if the SMILES cannot be parsed; NaN/inf descriptor
    values are replaced with zeros so the regression loss stays finite.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(NUM_DESCRIPTORS, dtype=np.float32)
        values = np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float32)
        return np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        return np.zeros(NUM_DESCRIPTORS, dtype=np.float32)


class SMILESAndDescriptorCollator:
    """Tokenizes SMILES strings, applies MLM masking, and stacks descriptor targets."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 512,
        mlm_probability: float = 0.15,
        do_mlm: bool = True,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.do_mlm = do_mlm
        if self.do_mlm:
            self.mlm_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=True,
                mlm_probability=mlm_probability,
                return_tensors="pt",
            )
        else:
            self.mlm_collator = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        smiles_batch = [f["smiles"] for f in features]
        descriptors_list = [f["descriptors"] for f in features]

        # Tokenize without padding; padding (and MLM masking) is handled below.
        tokenized = self.tokenizer(
            smiles_batch,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        features_for_mlm = [
            {k: v[i] for k, v in tokenized.items()} for i in range(len(smiles_batch))
        ]

        if self.do_mlm and self.mlm_collator:
            # The MLM collator pads and produces masked input_ids plus labels.
            batch_text = self.mlm_collator(features_for_mlm)
        else:
            tokenized_padded = self.tokenizer.pad(
                features_for_mlm,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            batch_text = dict(tokenized_padded)

        descriptors_tensor = torch.tensor(np.stack(descriptors_list), dtype=torch.float32)
        batch = batch_text
        batch["descriptors"] = descriptors_tensor
        return batch


def get_backbone_grad_vector(module, exclude_keywords=None):
    """Concatenate the flattened gradients of all backbone parameters,
    skipping parameters whose names contain any of the excluded keywords."""
    if exclude_keywords is None:
        exclude_keywords = []
    grads = []
    for name, param in module.named_parameters():
        if any(keyword in name.lower() for keyword in exclude_keywords):
            continue
        if param.grad is not None:
            grads.append(param.grad.detach().flatten())
    if len(grads) == 0:
        return torch.tensor([])
    return torch.cat(grads)


def compute_gradient_metrics(model, loss1, loss2, exclude_keywords=None):
    """Backpropagate each task loss separately to measure the angle between
    the two task gradients on the shared backbone, plus their norms.

    Both backward passes retain the graph so the caller can still backpropagate
    the combined loss afterwards."""
    if exclude_keywords is None:
        exclude_keywords = []

    model.zero_grad(set_to_none=True)
    loss1.backward(retain_graph=True)
    g1 = get_backbone_grad_vector(model, exclude_keywords)
    norm_mtr = g1.norm().item() if g1.numel() > 0 else None

    model.zero_grad(set_to_none=True)
    loss2.backward(retain_graph=True)
    g2 = get_backbone_grad_vector(model, exclude_keywords)
    norm_mlm = g2.norm().item() if g2.numel() > 0 else None

    model.zero_grad(set_to_none=True)

    angle_deg = None
    if g1.numel() > 0 and g2.numel() > 0 and g1.numel() == g2.numel():
        cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0), dim=1).item()
        cos_sim = max(min(cos_sim, 1.0), -1.0)  # guard against numerical drift outside [-1, 1]
        angle_deg = math.degrees(math.acos(cos_sim))

    return {
        "angle_deg": angle_deg,
        "norm_mtr": norm_mtr,
        "norm_mlm": norm_mlm,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ChemBERTa Multi-Task Training")
    parser.add_argument("--smiles_file", type=str, default="support/smiles_10k.txt")
    parser.add_argument("--stats_file", type=str, default="support/normalization_params.pth")
    parser.add_argument("--output_file", type=str, default="model.pth")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mlm_weight", type=float, default=1.0)
    parser.add_argument("--mtr_weight", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--epochs", type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
    # Use only the RoBERTa encoder as the shared backbone; both task heads sit on top of it.
    model_base = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-10M-MLM").roberta
    model_dim = model_base.config.hidden_size  # 384 for ChemBERTa-10M-MLM

    mlm_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        nn.Linear(model_dim * 2, tokenizer.vocab_size),
    )
    rdkit_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        nn.Linear(model_dim * 2, NUM_DESCRIPTORS),
    )
    model_base.to(device)
    mlm_head.to(device)
    rdkit_head.to(device)

    logger.info("Loading dataset...")
    raw_dataset = load_dataset("text", data_files={"train": args.smiles_file})
    raw_dataset = raw_dataset.rename_column("text", "smiles")

    logger.info("Calculating RDKit features...")
    processed_dataset = raw_dataset.map(
        lambda x: {"descriptors": compute_rdkit_features(x["smiles"])},
        num_proc=8,
        desc="Calculating RDKit features",
    )

    collator = SMILESAndDescriptorCollator(tokenizer=tokenizer, max_length=args.max_length)
    dataloader = DataLoader(
        processed_dataset["train"],
        batch_size=args.batch_size,
        collate_fn=collator,
    )

    logger.info("Loading normalization stats...")
    stats = torch.load(args.stats_file, map_location=device)
    means = stats["means"].to(device)
    stds = stats["stds"].to(device)
    stds[stds < 1e-6] = 1.0  # avoid division by (near-)zero for constant descriptors

    optimizer = torch.optim.AdamW(
        list(model_base.parameters()) + list(mlm_head.parameters()) + list(rdkit_head.parameters()),
        lr=args.lr,
        weight_decay=1e-4,
    )
    clip_grad_norm = 1.0

    # Parameters whose names contain these keywords are excluded from the
    # backbone gradient-conflict diagnostics.
    BACKBONE_EXCLUDE_KEYWORDS = ["head", "rdkit", "mlm", "classifier", "pooler"]
    LOG_GRAD_METRICS_EVERY_N_BATCHES = 10

    log_dir = os.path.join("runs", f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    writer = SummaryWriter(log_dir=log_dir)
    global_step = 0

    logger.info("Starting training")
    for epoch in range(args.epochs):
        start_time = time.time()
        total_loss_mtr = 0.0
        total_loss_mlm = 0.0
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")

        model_base.train()
        mlm_head.train()
        rdkit_head.train()

        for batch_idx, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            descriptors = batch["descriptors"].to(device)
            labels = batch["labels"].to(device)

            outputs = model_base(input_ids, attention_mask=attention_mask)
            result = outputs.last_hidden_state

            # MTR head reads the [CLS] position; MLM head predicts every position.
            mtr_res = rdkit_head(result[:, 0])
            mlm_res = mlm_head(result)

            loss_mtr = F.huber_loss(mtr_res, (descriptors - means) / stds)
            loss_mlm = F.cross_entropy(
                mlm_res.flatten(end_dim=1), labels.flatten(), ignore_index=-100
            )
            loss = loss_mtr * args.mtr_weight + loss_mlm * args.mlm_weight

            # Periodically measure how strongly the two task gradients conflict
            # on the shared backbone before taking the optimizer step.
            if num_batches % LOG_GRAD_METRICS_EVERY_N_BATCHES == 0:
                metrics = compute_gradient_metrics(
                    model=model_base,
                    loss1=loss_mtr,
                    loss2=loss_mlm,
                    exclude_keywords=BACKBONE_EXCLUDE_KEYWORDS,
                )
                postfix = {
                    "loss_mlm": f"{loss_mlm.item():.4f}",
                    "loss_mtr": f"{loss_mtr.item():.4f}",
                }
                if metrics["angle_deg"] is not None:
                    postfix["angle"] = f"{metrics['angle_deg']:.1f}°"
                    writer.add_scalar("gradients/backbone_angle_deg", metrics["angle_deg"], global_step)
                if metrics["norm_mtr"] is not None:
                    postfix["‖∇MTR‖"] = f"{metrics['norm_mtr']:.3f}"
                    writer.add_scalar("gradients/backbone_norm_mtr", metrics["norm_mtr"], global_step)
                if metrics["norm_mlm"] is not None:
                    postfix["‖∇MLM‖"] = f"{metrics['norm_mlm']:.3f}"
                    writer.add_scalar("gradients/backbone_norm_mlm", metrics["norm_mlm"], global_step)
                pbar.set_postfix(postfix)

            writer.add_scalar("loss/total", loss.item(), global_step)
            writer.add_scalar("loss/mtr_huber", loss_mtr.item(), global_step)
            writer.add_scalar("loss/mlm_ce", loss_mlm.item(), global_step)

            optimizer.zero_grad()
            loss.backward()

            # clip_grad_norm_ returns the total norm measured before clipping.
            grad_norm = torch.nn.utils.clip_grad_norm_(model_base.parameters(), clip_grad_norm)
            torch.nn.utils.clip_grad_norm_(rdkit_head.parameters(), clip_grad_norm)
            torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip_grad_norm)
            writer.add_scalar("training/backbone_grad_norm", grad_norm.item(), global_step)
            writer.add_scalar("training/learning_rate", optimizer.param_groups[0]["lr"], global_step)

            optimizer.step()

            total_loss += loss.item()
            total_loss_mtr += loss_mtr.item()
            total_loss_mlm += loss_mlm.item()
            num_batches += 1
            global_step += 1

        epoch_time = time.time() - start_time
        avg_loss = total_loss / num_batches
        avg_loss_mtr = total_loss_mtr / num_batches
        avg_loss_mlm = total_loss_mlm / num_batches

        writer.add_scalar("epoch/avg_total_loss", avg_loss, epoch)
        writer.add_scalar("epoch/avg_loss_mtr", avg_loss_mtr, epoch)
        writer.add_scalar("epoch/avg_loss_mlm", avg_loss_mlm, epoch)
        writer.add_scalar("epoch/time_sec", epoch_time, epoch)

        logger.info(
            f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s | "
            f"Total Loss: {avg_loss:.4f} | Huber (MTR): {avg_loss_mtr:.4f} | "
            f"CE (MLM): {avg_loss_mlm:.4f} | Grad Norm: {grad_norm:.4f}"
        )

    writer.close()

    logger.info("Saving checkpoint...")
    torch.save(
        {
            "backbone": model_base.state_dict(),
            "mlm_head": mlm_head.state_dict(),
            "mtr_head": rdkit_head.state_dict(),
        },
        args.output_file,
    )
    logger.info("Training finished")