File size: 11,354 Bytes
f6fc460 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 | import os
import time
import math
import logging
import argparse
from datetime import datetime
from typing import Dict, List, Any
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
DataCollatorForLanguageModeling,
PreTrainedTokenizerBase
)
from rdkit import Chem
from rdkit.Chem import Descriptors
# Root logging setup for the script: INFO and above, formatted as
# "timestamp | level | message".
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the helpers and the training entry point.
logger = logging.getLogger(__name__)
def compute_rdkit_features(smiles: str, n_features: int = 210) -> np.ndarray:
    """Return the RDKit 2D descriptor vector of ``smiles`` as a float32 array.

    The vector holds the values of ``Descriptors.CalcMolDescriptors`` in the
    order RDKit returns them. Values that are missing (``CalcMolDescriptors``
    uses ``None`` as its default ``missingVal``) or not finite in float32 are
    replaced by 0.0. This keeps downstream normalization and regression losses
    free of NaN/inf.

    Args:
        smiles: SMILES string to featurize.
        n_features: Expected number of descriptors. It is also the length of
            the all-zeros vector returned when the SMILES cannot be parsed or
            RDKit raises. The default of 210 is what this script assumes for
            its RDKit version (TODO confirm for the installed RDKit).

    Returns:
        1-D ``np.float32`` array of length ``n_features``.

    Raises:
        ValueError: If RDKit computes a number of descriptors different from
            ``n_features``. Mixing lengths would otherwise break batching later.
    """
    fallback = np.zeros(n_features, dtype=np.float32)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparsable SMILES: RDKit returns None rather than raising.
            return fallback
        values = list(Descriptors.CalcMolDescriptors(mol).values())
    except Exception:
        # Deliberate best-effort: one bad input becomes a zero vector instead
        # of aborting the whole dataset pass. Still leave a trace for debugging.
        logging.getLogger(__name__).debug(
            "RDKit featurization failed for %r", smiles, exc_info=True
        )
        return fallback

    # Converting to float64 first turns None into NaN. Narrowing to float32 may
    # overflow to inf. Both cases are zeroed just below.
    with np.errstate(over="ignore", invalid="ignore"):
        features = np.asarray(values, dtype=np.float64).astype(np.float32)
    features[~np.isfinite(features)] = 0.0
    if features.shape != (n_features,):
        raise ValueError(
            f"RDKit returned {features.size} descriptors, expected {n_features}"
        )
    return features
class SMILESAndDescriptorCollator:
    """Collate SMILES strings and their precomputed descriptor vectors into a batch.

    Every incoming example is a mapping with a ``'smiles'`` string and a
    ``'descriptors'`` sequence. The SMILES are tokenized without padding and
    truncated to ``max_length``.

    With ``do_mlm`` enabled, the tokenized examples go through a
    ``DataCollatorForLanguageModeling``, which pads them and applies random
    masking. Otherwise they are only padded to the longest sequence and
    returned as a plain dict.

    In both cases the descriptor rows are stacked into a float32 tensor and
    stored under ``'descriptors'``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 512,
        mlm_probability: float = 0.15,
        do_mlm: bool = True
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.do_mlm = do_mlm
        # The masking collator is built only when masking is requested.
        self.mlm_collator = (
            DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=True,
                mlm_probability=mlm_probability,
                return_tensors="pt",
            )
            if self.do_mlm
            else None
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Turn a list of examples into tensors.

        Returns the padded token tensors (plus MLM tensors when masking is
        enabled) and a ``'descriptors'`` float32 tensor.
        """
        smiles_batch = [example['smiles'] for example in features]
        descriptor_rows = [example['descriptors'] for example in features]

        encoded = self.tokenizer(
            smiles_batch,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        # The encoding is column-oriented ({key: [row0, row1, ...]}). Transpose
        # it into one {key: row} dict per example, which both paths expect.
        keys = list(encoded.keys())
        per_example = [
            dict(zip(keys, row))
            for row in zip(*(encoded[key] for key in keys))
        ]

        if self.do_mlm and self.mlm_collator:
            batch = self.mlm_collator(per_example)
        else:
            padded = self.tokenizer.pad(
                per_example,
                padding=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            batch = dict(padded)

        batch['descriptors'] = torch.tensor(
            np.stack(descriptor_rows), dtype=torch.float32
        )
        return batch
def get_backbone_grad_vector(module, exclude_keywords=None):
    """Flatten and concatenate the current gradients of ``module``'s parameters.

    Parameters are visited in ``module.named_parameters()`` order. A parameter
    is skipped when any entry of ``exclude_keywords`` occurs as a substring of
    its qualified name. The match is case-insensitive on both sides.

    Parameters whose ``.grad`` is ``None`` are skipped too. Two vectors taken
    after different backward passes therefore line up element-for-element only
    if the same parameters received gradients both times.

    Args:
        module: A ``torch.nn.Module``.
        exclude_keywords: Optional iterable of name substrings to exclude. It
            is materialized once, so a generator is also accepted.

    Returns:
        A detached 1-D tensor with all collected gradient entries. If no
        parameter qualifies, an empty ``torch.tensor([])`` is returned.
    """
    # Lowercase once so the case-insensitive comparison also applies to the
    # keywords, and so a one-shot iterable is not exhausted after the first
    # parameter.
    lowered_keywords = [keyword.lower() for keyword in (exclude_keywords or [])]
    grads = []
    for name, param in module.named_parameters():
        lowered_name = name.lower()
        if any(keyword in lowered_name for keyword in lowered_keywords):
            continue
        if param.grad is not None:
            grads.append(param.grad.detach().flatten())
    if not grads:
        return torch.tensor([])
    return torch.cat(grads)
def compute_gradient_metrics(model, loss1, loss2, exclude_keywords=None):
    """Compare the gradients that two losses induce on ``model``'s parameters.

    Each loss is back-propagated separately with ``retain_graph=True``, so
    the caller can still back-propagate through the same graph afterwards.
    After each pass, the gradients of ``model``'s parameters whose names do not
    contain any of ``exclude_keywords`` are flattened into one vector.

    Side effects:
        ``model``'s gradients are set to ``None`` before each pass and once
        more at the end, so gradients present on entry are discarded.
        Parameters outside ``model`` that the losses reach are not cleared
        here and keep whatever the two passes accumulate.

    Returns:
        A dict with these keys:

        - ``'angle_deg'``: angle between the two gradient vectors in degrees.
          ``None`` when either vector is empty or their lengths differ.
        - ``'norm_mtr'``: L2 norm of the ``loss1`` gradient, or ``None`` if
          that vector is empty.
        - ``'norm_mlm'``: L2 norm of the ``loss2`` gradient, or ``None`` if
          that vector is empty.
    """
    keywords = exclude_keywords if exclude_keywords is not None else []

    def _backbone_grad(loss):
        # Start from empty grads so the vector reflects this loss alone. Keep
        # the graph, because it is backpropagated through again later.
        model.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        return get_backbone_grad_vector(model, keywords)

    def _norm(vec):
        has_entries = vec is not None and vec.numel() > 0
        return vec.norm().item() if has_entries else None

    grad_a = _backbone_grad(loss1)
    norm_mtr = _norm(grad_a)
    grad_b = _backbone_grad(loss2)
    norm_mlm = _norm(grad_b)
    model.zero_grad(set_to_none=True)

    angle_deg = None
    both_usable = (
        grad_a is not None
        and grad_b is not None
        and grad_a.numel() > 0
        and grad_b.numel() > 0
        and grad_a.numel() == grad_b.numel()
    )
    if both_usable:
        cosine = F.cosine_similarity(
            grad_a.unsqueeze(0), grad_b.unsqueeze(0), dim=1
        ).item()
        # Clamp tiny float excursions outside [-1, 1] before acos. The
        # min-then-max order leaves a NaN cosine as NaN.
        cosine = max(min(cosine, 1.0), -1.0)
        angle_deg = math.degrees(math.acos(cosine))

    return {
        'angle_deg': angle_deg,
        'norm_mtr': norm_mtr,
        'norm_mlm': norm_mlm
    }
if __name__ == '__main__':
    # Multi-task pre-training script.
    #
    # A pretrained ChemBERTa (RoBERTa) encoder is shared by two new heads:
    # - a token-level MLM head;
    # - an "MTR" head that regresses normalized RDKit descriptors from the
    #   first token.
    #
    # Every LOG_GRAD_METRICS_EVERY_N_BATCHES batches, the backbone gradients of
    # the two losses are compared and the results are logged to TensorBoard.
    parser = argparse.ArgumentParser(description="ChemBERTa Multi-Task Training")
    parser.add_argument("--smiles_file", type=str, default="support/smiles_10k.txt")
    parser.add_argument("--stats_file", type=str, default="support/normalization_params.pth")
    parser.add_argument("--output_file", type=str, default="model.pth")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--mlm_weight", type=float, default=1.0)
    parser.add_argument("--mtr_weight", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=3e-5)
    parser.add_argument("--epochs", type=int, default=1)
    args = parser.parse_args()

    # Width of the descriptor vectors that compute_rdkit_features produces and
    # that the MTR head predicts.
    num_descriptors = 210

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-10M-MLM")
    # Keep only the RoBERTa encoder. Freshly initialised heads are built below.
    model_base = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-10M-MLM").roberta
    # Read the width from the checkpoint config instead of hard-coding it, so
    # the heads always match last_hidden_state.
    model_dim = model_base.config.hidden_size
    # len(tokenizer) also counts added tokens, whereas vocab_size does not.
    # MLM labels are raw token ids, so the classifier must cover every id the
    # tokenizer can emit.
    vocab_size = len(tokenizer)
    mlm_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        nn.Linear(model_dim * 2, vocab_size),
    )
    rdkit_head = nn.Sequential(
        nn.Linear(model_dim, model_dim * 2),
        nn.GELU(),
        nn.Linear(model_dim * 2, num_descriptors),
    )
    model_base.to(device)
    mlm_head.to(device)
    rdkit_head.to(device)

    # Load the normalization stats before the slow descriptor pass, so a
    # mismatched stats file fails immediately.
    logger.info("Loading normalization stats...")
    stats = torch.load(args.stats_file, map_location=device)
    means = stats["means"].to(device)
    stds = stats["stds"].to(device)
    for stat_name, stat in (("means", means), ("stds", stds)):
        if stat.shape[-1:] != (num_descriptors,):
            raise ValueError(
                f"'{stat_name}' in {args.stats_file} has shape {tuple(stat.shape)}; "
                f"expected last dimension {num_descriptors}"
            )
    # For (near-)constant descriptors, divide by 1 instead of by ~0.
    stds[stds < 1e-6] = 1.0

    logger.info("Loading dataset...")
    raw_dataset = load_dataset("text", data_files={"train": args.smiles_file})
    raw_dataset = raw_dataset.rename_column("text", "smiles")
    logger.info("Calculating RDKit features...")
    processed_dataset = raw_dataset.map(
        lambda x: {"descriptors": compute_rdkit_features(x["smiles"])},
        num_proc=8,
        desc="Calculating RDKit features"
    )
    collator = SMILESAndDescriptorCollator(tokenizer=tokenizer, max_length=args.max_length)
    # Shuffle so batches do not follow the on-disk order of the SMILES file.
    dataloader = DataLoader(
        processed_dataset["train"],
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collator
    )

    optimizer = torch.optim.AdamW(
        list(model_base.parameters()) + list(mlm_head.parameters()) + list(rdkit_head.parameters()),
        lr=args.lr, weight_decay=1e-4
    )
    clip_grad_norm = 1.0
    BACKBONE_EXCLUDE_KEYWORDS = ["head", "rdkit", "mlm", "classifier", "pooler"]
    LOG_GRAD_METRICS_EVERY_N_BATCHES = 10
    log_dir = os.path.join("runs", f"train_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
    global_step = 0
    logger.info("Starting training")
    # The context manager flushes and closes the writer even if training raises.
    with SummaryWriter(log_dir=log_dir) as writer:
        for epoch in range(args.epochs):
            start_time = time.time()
            total_loss_mtr = 0.0
            total_loss_mlm = 0.0
            total_loss = 0.0
            num_batches = 0
            last_grad_norm = float("nan")
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")
            model_base.train()
            mlm_head.train()
            rdkit_head.train()
            for batch in pbar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                descriptors = batch["descriptors"].to(device)
                labels = batch["labels"].to(device)

                outputs = model_base(input_ids, attention_mask=attention_mask)
                result = outputs.last_hidden_state
                # MTR uses only the first position; MLM scores every position.
                mtr_res = rdkit_head(result[:, 0])
                mlm_res = mlm_head(result)
                loss_mtr = F.huber_loss(mtr_res, (descriptors - means) / stds)
                if (labels != -100).any():
                    loss_mlm = F.cross_entropy(
                        mlm_res.flatten(end_dim=1), labels.flatten(), ignore_index=-100
                    )
                else:
                    # The collator masked no token in this batch. A mean over
                    # zero targets would be NaN and corrupt the weights. Use a
                    # zero that stays attached to the graph so backward() works.
                    loss_mlm = mlm_res.sum() * 0.0
                loss = loss_mtr * args.mtr_weight + loss_mlm * args.mlm_weight

                if num_batches % LOG_GRAD_METRICS_EVERY_N_BATCHES == 0:
                    # Extra backward passes; the graph is retained, so the
                    # training backward below still works.
                    metrics = compute_gradient_metrics(
                        model=model_base, loss1=loss_mtr, loss2=loss_mlm,
                        exclude_keywords=BACKBONE_EXCLUDE_KEYWORDS
                    )
                    postfix = {
                        "loss_mlm": f"{loss_mlm.item():.4f}",
                        "loss_mtr": f"{loss_mtr.item():.4f}",
                    }
                    if metrics["angle_deg"] is not None:
                        postfix["angle"] = f"{metrics['angle_deg']:.1f}°"
                        writer.add_scalar("gradients/backbone_angle_deg", metrics["angle_deg"], global_step)
                    if metrics["norm_mtr"] is not None:
                        postfix["‖∇MTR‖"] = f"{metrics['norm_mtr']:.3f}"
                        writer.add_scalar("gradients/backbone_norm_mtr", metrics["norm_mtr"], global_step)
                    if metrics["norm_mlm"] is not None:
                        postfix["‖∇MLM‖"] = f"{metrics['norm_mlm']:.3f}"
                        writer.add_scalar("gradients/backbone_norm_mlm", metrics["norm_mlm"], global_step)
                    pbar.set_postfix(postfix)

                writer.add_scalar("loss/total", loss.item(), global_step)
                # The tag says "l1", but loss_mtr is a Huber loss. The tag is
                # kept so existing TensorBoard runs stay comparable.
                writer.add_scalar("loss/mtr_l1", loss_mtr.item(), global_step)
                writer.add_scalar("loss/mlm_ce", loss_mlm.item(), global_step)

                # Also clears the head gradients accumulated by compute_gradient_metrics.
                optimizer.zero_grad()
                loss.backward()
                # clip_grad_norm_ returns the total norm *before* clipping. The
                # "grad_norm_clipped" tag name is kept for continuity.
                grad_norm = torch.nn.utils.clip_grad_norm_(model_base.parameters(), clip_grad_norm)
                torch.nn.utils.clip_grad_norm_(rdkit_head.parameters(), clip_grad_norm)
                torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip_grad_norm)
                last_grad_norm = grad_norm.item()
                writer.add_scalar("training/grad_norm_clipped", last_grad_norm, global_step)
                writer.add_scalar("training/learning_rate", optimizer.param_groups[0]["lr"], global_step)
                optimizer.step()

                total_loss += loss.item()
                total_loss_mtr += loss_mtr.item()
                total_loss_mlm += loss_mlm.item()
                num_batches += 1
                global_step += 1

            epoch_time = time.time() - start_time
            if num_batches == 0:
                # Avoid ZeroDivisionError when the dataloader is empty.
                logger.warning(f"Epoch {epoch+1}/{args.epochs}: the dataloader yielded no batches")
                continue
            avg_loss = total_loss / num_batches
            avg_loss_mtr = total_loss_mtr / num_batches
            avg_loss_mlm = total_loss_mlm / num_batches
            writer.add_scalar("epoch/avg_total_loss", avg_loss, epoch)
            writer.add_scalar("epoch/avg_loss_mtr", avg_loss_mtr, epoch)
            writer.add_scalar("epoch/avg_loss_mlm", avg_loss_mlm, epoch)
            writer.add_scalar("epoch/time_sec", epoch_time, epoch)
            logger.info(
                f"Epoch {epoch+1}/{args.epochs} | Time: {epoch_time:.2f}s | "
                f"Total Loss: {avg_loss:.4f} | Huber (MTR): {avg_loss_mtr:.4f} | "
                f"CE (MLM): {avg_loss_mlm:.4f} | Last Grad Norm: {last_grad_norm:.4f}"
            )

    logger.info("Saving checkpoint..")
    torch.save({
        "backbone": model_base.state_dict(),
        "mlm_head": mlm_head.state_dict(),
        "mtr_head": rdkit_head.state_dict(),
    }, args.output_file)
    logger.info("Training is finished")
|