import os
import time
import math
import logging
import argparse
from datetime import datetime
from typing import Dict, List, Any

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
)
from rdkit import Chem
from rdkit.Chem import Descriptors

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
logger = logging.getLogger(__name__)
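
# The full RDKit descriptor set had 210 entries when this script was written;
# the exact count depends on the installed RDKit version, so treat this value
# as an assumption that should match Descriptors.CalcMolDescriptors locally.
NUM_RDKIT_DESCRIPTORS = 210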


def compute_rdkit_features(smiles: str) -> np.ndarray:
    """Compute the full RDKit descriptor vector for a SMILES string.

    Returns a zero vector when the SMILES cannot be parsed or RDKit raises.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(NUM_RDKIT_DESCRIPTORS, dtype=np.float32)
        values = np.array(list(Descriptors.CalcMolDescriptors(mol).values()), dtype=np.float32)
        # Some descriptors can overflow to inf or come back NaN; sanitize them
        # so they cannot poison the regression targets downstream.
        return np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception:
        return np.zeros(NUM_RDKIT_DESCRIPTORS, dtype=np.float32)
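
# Quick sanity check (hypothetical example):
#   compute_rdkit_features("CCO")  # ethanol -> float32 vector of length NUM_RDKIT_DESCRIPTORS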


class SMILESAndDescriptorCollator:
    """Tokenizes a batch of SMILES strings and attaches their RDKit descriptors."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 512,
        mlm_probability: float = 0.15,
        do_mlm: bool = True,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.do_mlm = do_mlm
        if self.do_mlm:
            self.mlm_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=True,
                mlm_probability=mlm_probability,
                return_tensors="pt",
            )
        else:
            self.mlm_collator = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        smiles_batch = [f["smiles"] for f in features]
        descriptors_list = [f["descriptors"] for f in features]

        # Tokenize without padding; padding happens per batch below.
        tokenized = self.tokenizer(
            smiles_batch,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

        # Re-shape the columnar tokenizer output into one dict per example.
        features_for_mlm = [
            {k: v[i] for k, v in tokenized.items()}
            for i in range(len(smiles_batch))
        ]

        if self.do_mlm and self.mlm_collator is not None:
            # Pads the batch and applies random masking, producing labels.
            batch_text = self.mlm_collator(features_for_mlm)
        else:
            batch_text = self.tokenizer.pad(
                features_for_mlm,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt",
            )

        descriptors_tensor = torch.tensor(np.stack(descriptors_list), dtype=torch.float32)

        batch = dict(batch_text)
        batch["descriptors"] = descriptors_tensor
        return batch
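
# Usage sketch (hypothetical values; assumes a tokenizer is already loaded):
#   collator = SMILESAndDescriptorCollator(tokenizer, max_length=128)
#   batch = collator([{"smiles": "CCO", "descriptors": np.zeros(NUM_RDKIT_DESCRIPTORS)}])
#   # -> padded tensors: input_ids, attention_mask, labels, descriptors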


def get_backbone_grad_vector(module, exclude_keywords=None):
    """Concatenate the flattened gradients of every parameter in `module`,
    skipping parameters whose name contains one of `exclude_keywords`."""
    if exclude_keywords is None:
        exclude_keywords = []

    grads = []
    for name, param in module.named_parameters():
        if any(keyword in name.lower() for keyword in exclude_keywords):
            continue
        if param.grad is not None:
            grads.append(param.grad.detach().flatten())

    if len(grads) == 0:
        return torch.tensor([])

    return torch.cat(grads)


def compute_gradient_metrics(model, loss1, loss2, exclude_keywords=None):
    """Backpropagate each task loss separately to measure its backbone gradient
    norm and the angle between the two gradients. Both backward passes retain
    the graph so the caller can still backpropagate the combined loss."""
    if exclude_keywords is None:
        exclude_keywords = []

    model.zero_grad(set_to_none=True)
    loss1.backward(retain_graph=True)
    g1 = get_backbone_grad_vector(model, exclude_keywords)
    norm_mtr = g1.norm().item() if g1.numel() > 0 else None

    model.zero_grad(set_to_none=True)
    loss2.backward(retain_graph=True)
    g2 = get_backbone_grad_vector(model, exclude_keywords)
    norm_mlm = g2.norm().item() if g2.numel() > 0 else None

    model.zero_grad(set_to_none=True)

    angle_deg = None
    if g1.numel() > 0 and g2.numel() > 0 and g1.numel() == g2.numel():
        cos_sim = F.cosine_similarity(g1.unsqueeze(0), g2.unsqueeze(0), dim=1).item()
        cos_sim = max(min(cos_sim, 1.0), -1.0)  # guard math.acos against float drift
        angle_deg = math.degrees(math.acos(cos_sim))

    return {
        "angle_deg": angle_deg,
        "norm_mtr": norm_mtr,
        "norm_mlm": norm_mlm,
    }
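
# Interpretation: an angle near 0° means the MTR and MLM losses pull the shared
# backbone in the same direction for that batch, while an angle above 90°
# (negative cosine similarity) signals conflicting task gradients.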


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ChemBERTa Multi-Task Training")
    parser.add_argument("--smiles_file", type=str, default="support/smiles_10k.txt")
    parser.add_argument("--stats_file", type=str, default="support/normalization_params.pth")
    parser.add_argument("--output_file", type=str, default="model.pth")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mlm_weight", type=float, default=1.0)
    parser.add_argument("--mtr_weight", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--epochs", type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
    model_base = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-10M-MLM").roberta
    # Read the hidden size (384 for ChemBERTa-10M-MLM) from the config instead
    # of hard-coding it.
    model_dim = model_base.config.hidden_size

    mlm_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        # len(tokenizer) also counts added tokens, unlike tokenizer.vocab_size.
        nn.Linear(model_dim * 2, len(tokenizer)),
    )
    rdkit_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        nn.Linear(model_dim * 2, NUM_RDKIT_DESCRIPTORS),
    )
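
    # Note: this trains a fresh MLM head from scratch; reusing the pretrained
    # lm_head that AutoModelForMaskedLM provides (with its decoder tied to the
    # input embeddings) would be a reasonable alternative.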

    model_base.to(device)
    mlm_head.to(device)
    rdkit_head.to(device)

    logger.info("Loading dataset...")
    raw_dataset = load_dataset("text", data_files={"train": args.smiles_file})
    raw_dataset = raw_dataset.rename_column("text", "smiles")

    logger.info("Calculating RDKit features...")
    processed_dataset = raw_dataset.map(
        lambda x: {"descriptors": compute_rdkit_features(x["smiles"])},
        num_proc=8,
        desc="Calculating RDKit features",
    )

    collator = SMILESAndDescriptorCollator(tokenizer=tokenizer, max_length=args.max_length)
    dataloader = DataLoader(
        processed_dataset["train"],
        batch_size=args.batch_size,
        shuffle=True,  # reshuffle the training examples every epoch
        collate_fn=collator,
    )

    logger.info("Loading normalization stats...")
    stats = torch.load(args.stats_file, map_location=device)
    means = stats["means"].to(device)
    stds = stats["stds"].to(device)
    # Guard against division by (near-)zero for constant descriptors.
    stds[stds < 1e-6] = 1.0
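
    # Assumed layout of the stats file, inferred from the keys read above:
    #   torch.save({"means": <tensor>, "stds": <tensor>}, args.stats_file)
    # with one statistic per RDKit descriptor.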

    optimizer = torch.optim.AdamW(
        list(model_base.parameters()) + list(mlm_head.parameters()) + list(rdkit_head.parameters()),
        lr=args.lr,
        weight_decay=1e-4,
    )

    clip_grad_norm = 1.0
    BACKBONE_EXCLUDE_KEYWORDS = ["head", "rdkit", "mlm", "classifier", "pooler"]
    LOG_GRAD_METRICS_EVERY_N_BATCHES = 10

    log_dir = os.path.join("runs", f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    writer = SummaryWriter(log_dir=log_dir)
    global_step = 0
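
    # The run can be monitored live with: tensorboard --logdir runs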

    logger.info("Starting training")

    for epoch in range(args.epochs):
        start_time = time.time()
        total_loss_mtr = 0.0
        total_loss_mlm = 0.0
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
        model_base.train()
        mlm_head.train()
        rdkit_head.train()

        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            descriptors = batch["descriptors"].to(device)
            labels = batch["labels"].to(device)

            outputs = model_base(input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state

            # MTR: regress normalized descriptors from the first ([CLS]) token;
            # MLM: predict masked tokens from every position.
            mtr_res = rdkit_head(hidden[:, 0])
            mlm_res = mlm_head(hidden)

            loss_mtr = F.huber_loss(mtr_res, (descriptors - means) / stds)
            loss_mlm = F.cross_entropy(mlm_res.flatten(end_dim=1), labels.flatten(), ignore_index=-100)
            loss = loss_mtr * args.mtr_weight + loss_mlm * args.mlm_weight

            # Every N batches, measure how the two task gradients interact on
            # the shared backbone (this costs two extra backward passes).
            if num_batches % LOG_GRAD_METRICS_EVERY_N_BATCHES == 0:
                metrics = compute_gradient_metrics(
                    model=model_base, loss1=loss_mtr, loss2=loss_mlm,
                    exclude_keywords=BACKBONE_EXCLUDE_KEYWORDS
                )
                postfix = {
                    "loss_mlm": f"{loss_mlm.item():.4f}",
                    "loss_mtr": f"{loss_mtr.item():.4f}",
                }
                if metrics["angle_deg"] is not None:
                    postfix["angle"] = f"{metrics['angle_deg']:.1f}°"
                    writer.add_scalar("gradients/backbone_angle_deg", metrics["angle_deg"], global_step)
                if metrics["norm_mtr"] is not None:
                    postfix["‖∇MTR‖"] = f"{metrics['norm_mtr']:.3f}"
                    writer.add_scalar("gradients/backbone_norm_mtr", metrics["norm_mtr"], global_step)
                if metrics["norm_mlm"] is not None:
                    postfix["‖∇MLM‖"] = f"{metrics['norm_mlm']:.3f}"
                    writer.add_scalar("gradients/backbone_norm_mlm", metrics["norm_mlm"], global_step)
                pbar.set_postfix(postfix)

            writer.add_scalar("loss/total", loss.item(), global_step)
            writer.add_scalar("loss/mtr_huber", loss_mtr.item(), global_step)
            writer.add_scalar("loss/mlm_ce", loss_mlm.item(), global_step)

            optimizer.zero_grad()
            loss.backward()

            # Clip each module's gradients separately. clip_grad_norm_ returns
            # the total norm as measured *before* clipping.
            grad_norm = torch.nn.utils.clip_grad_norm_(model_base.parameters(), clip_grad_norm)
            torch.nn.utils.clip_grad_norm_(rdkit_head.parameters(), clip_grad_norm)
            torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip_grad_norm)

            writer.add_scalar("training/backbone_grad_norm", grad_norm.item(), global_step)
            writer.add_scalar("training/learning_rate", optimizer.param_groups[0]["lr"], global_step)

            optimizer.step()

            total_loss += loss.item()
            total_loss_mtr += loss_mtr.item()
            total_loss_mlm += loss_mlm.item()
            num_batches += 1
            global_step += 1

        epoch_time = time.time() - start_time
        avg_loss = total_loss / num_batches
        avg_loss_mtr = total_loss_mtr / num_batches
        avg_loss_mlm = total_loss_mlm / num_batches

        writer.add_scalar("epoch/avg_total_loss", avg_loss, epoch)
        writer.add_scalar("epoch/avg_loss_mtr", avg_loss_mtr, epoch)
        writer.add_scalar("epoch/avg_loss_mlm", avg_loss_mlm, epoch)
        writer.add_scalar("epoch/time_sec", epoch_time, epoch)

        logger.info(
            f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s | "
            f"Total Loss: {avg_loss:.4f} | Huber (MTR): {avg_loss_mtr:.4f} | "
            f"CE (MLM): {avg_loss_mlm:.4f} | Backbone Grad Norm (last batch): {grad_norm:.4f}"
        )

    writer.close()

    logger.info("Saving checkpoint...")
    torch.save({
        "backbone": model_base.state_dict(),
        "mlm_head": mlm_head.state_dict(),
        "mtr_head": rdkit_head.state_dict(),
    }, args.output_file)
    logger.info("Training finished")