# train_eval_utils.py
# Helpers for training and benchmarking
# Feel free to adjust this setup according to your use case
# - gbyuvd

import os
import csv
import json
import math
import random
import inspect

import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from ranger21 import Ranger21
from datasets import load_dataset


# ----------------------------
# CSV Logger
# ----------------------------
class CSVLogger:
    def __init__(self, filename, fieldnames):
        self.filename = filename
        self.fieldnames = fieldnames
        self._initialized = False

    def log(self, row_dict):
        if not self._initialized:
            self._init_file()
        with open(self.filename, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldnames)
            writer.writerow(row_dict)

    def _init_file(self):
        with open(self.filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldnames)
            writer.writeheader()
        self._initialized = True


# ----------------------------
# Streaming dataset
# ----------------------------
class SelfiesStreamingDataset(IterableDataset):
    """Streams SMILES strings from a CSV and applies BERT-style dynamic masking."""

    def __init__(self, csv_file, tokenizer, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.global_token_ids = global_token_ids or []
        dataset = load_dataset("csv", data_files=csv_file, split="train", streaming=True)
        self.dataset = dataset.shuffle(seed=42, buffer_size=10000)
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id

    def __iter__(self):
        # Iterate the streaming dataset directly (rather than a stored iterator)
        # so the dataset can be consumed more than once, e.g. a fresh loader
        # per epoch or per validation pass.
        for example in self.dataset:
            smiles = example["SMILES"]
            enc = self.tokenizer(smiles, truncation=True, max_length=self.max_seq_len, return_tensors=None)
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]
            labels = input_ids.copy()
            vocab_size = len(self.tokenizer)

            # BERT-style masking: 80% [MASK], 10% random token, 10% unchanged.
            # Global/special tokens are never masked.
            for i in range(len(input_ids)):
                if input_ids[i] in self.global_token_ids:
                    continue
                if random.random() < self.mask_prob:
                    rand = random.random()
                    if rand < 0.8:
                        input_ids[i] = self.mask_id
                    elif rand < 0.9:
                        input_ids[i] = random.randint(0, vocab_size - 1)
                        # re-roll if we happened to draw a global token
                        while input_ids[i] in self.global_token_ids:
                            input_ids[i] = random.randint(0, vocab_size - 1)
                    else:
                        pass  # keep the original token

            global_positions = [idx for idx, tid in enumerate(input_ids) if tid in self.global_token_ids]

            yield (
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(attention_mask, dtype=torch.long),
                torch.tensor(labels, dtype=torch.long),
                global_positions,
            )


def collate_fn(batch):
    input_ids_list, attention_mask_list, labels_list, global_positions_list = zip(*batch)
    # NOTE: assumes the tokenizer's pad token id is 0; labels use -100 so
    # padded positions are ignored by the loss.
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)
    return input_ids, attention_mask, labels, global_positions_list


def get_dataloader(csv_file, tokenizer, batch_size=16, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
    dataset = SelfiesStreamingDataset(csv_file, tokenizer, max_seq_len, mask_prob, global_token_ids)
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
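
# Example (illustrative sketch, not part of the pipeline): building a masked-LM
# dataloader from a CSV that has a "SMILES" column. The tokenizer checkpoint
# path is a placeholder assumption; point it at whatever tokenizer your model
# was actually trained with (transformers is assumed available).
def _demo_dataloader(csv_path="train.csv", checkpoint="path/to/your-tokenizer"):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    global_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]
    loader = get_dataloader(csv_path, tokenizer, batch_size=8,
                            max_seq_len=256, mask_prob=0.15, global_token_ids=global_ids)
    input_ids, attention_mask, labels, global_positions = next(iter(loader))
    # labels keep the original ids; input_ids carry the [MASK]/random replacements
    print(input_ids.shape, attention_mask.shape, labels.shape, global_positions[0])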
# ----------------------------
# Dataset splitting
# ----------------------------
def prepare_train_val_test_split(full_csv, train_csv, val_csv, test_csv,
                                 val_test_size=0.3, test_size_ratio=0.5, random_state=42):
    if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
        print("Train/val/test splits already exist. Skipping split.")
        # subtract 1 per file for the CSV header row
        train_count = sum(1 for _ in open(train_csv, encoding='utf-8')) - 1
        val_count = sum(1 for _ in open(val_csv, encoding='utf-8')) - 1
        test_count = sum(1 for _ in open(test_csv, encoding='utf-8')) - 1
        return train_count, val_count, test_count

    df = pd.read_csv(full_csv)
    train_df, val_test_df = train_test_split(df, test_size=val_test_size, random_state=random_state)
    val_df, test_df = train_test_split(val_test_df, test_size=test_size_ratio, random_state=random_state)
    train_df.to_csv(train_csv, index=False)
    val_df.to_csv(val_csv, index=False)
    test_df.to_csv(test_csv, index=False)
    return len(train_df), len(val_df), len(test_df)


# ----------------------------
# Model call helper
# ----------------------------
def call_model(model, input_ids, attention_mask, labels, global_positions):
    # Prepare base args
    model_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "output_attentions": False,
    }

    # Check if the model's forward method accepts 'global_positions'
    sig = inspect.signature(model.forward)
    if "global_positions" in sig.parameters:
        # Convert global_positions to a padded tensor if needed (for RougeBERT)
        if isinstance(global_positions, (tuple, list)):
            if len(global_positions) == 0:
                global_positions = None
            else:
                max_len = max(len(g) for g in global_positions)
                padded = [
                    list(g) + [-1] * (max_len - len(g))
                    if isinstance(g, (list, tuple))
                    else [g] + [-1] * (max_len - 1)
                    for g in global_positions
                ]
                global_positions = torch.tensor(padded, dtype=torch.long, device=input_ids.device)
        elif isinstance(global_positions, torch.Tensor):
            pass  # already good
        elif global_positions is None:
            pass
        else:
            raise TypeError(f"Unsupported type for global_positions: {type(global_positions)}")
        model_kwargs["global_positions"] = global_positions

    # Call model
    return model(**model_kwargs)


# ----------------------------
# Metrics
# ----------------------------
def compute_metrics(logits, labels):
    """Compute summed MLM loss, number of correct predictions, and total scored tokens."""
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
    vocab_size = logits.size(-1)
    logits_flat = logits.view(-1, vocab_size)
    labels_flat = labels.view(-1)
    loss = loss_fn(logits_flat, labels_flat)
    count = (labels_flat != -100).sum().item()
    preds = torch.argmax(logits_flat, dim=-1)
    correct = ((preds == labels_flat) & (labels_flat != -100)).sum().item()
    return loss.item(), correct, count


def evaluate_model(model, dataloader, device):
    model.eval()
    total_loss, total_correct, total_count = 0.0, 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, labels, global_positions in dataloader:
            input_ids, attention_mask, labels = (
                input_ids.to(device),
                attention_mask.to(device),
                labels.to(device),
            )
            outputs = call_model(model, input_ids, attention_mask, labels, global_positions)
            logits = outputs.logits
            loss, correct, count = compute_metrics(logits, labels)
            total_loss += loss
            total_correct += correct
            total_count += count
    avg_loss = total_loss / total_count if total_count > 0 else float("inf")
    perplexity = math.exp(avg_loss) if avg_loss < 20 else float("inf")  # cap to avoid overflow
    accuracy = total_correct / total_count if total_count > 0 else 0.0
    return avg_loss, perplexity, accuracy
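
# Quick sanity check (illustrative only) of the metrics bookkeeping:
# compute_metrics returns a *summed* cross-entropy, so the average token loss
# is loss_sum / count and perplexity is exp(loss_sum / count). Shapes and
# vocab size below are arbitrary.
def _demo_metrics():
    logits = torch.randn(2, 5, 30)           # (batch, seq_len, vocab_size)
    labels = torch.randint(0, 30, (2, 5))    # pretend every position is scored
    labels[0, 0] = -100                      # ignored, e.g. a padded position
    loss_sum, correct, count = compute_metrics(logits, labels)
    print(f"avg loss: {loss_sum / count:.4f}  "
          f"ppl: {math.exp(loss_sum / count):.2f}  acc: {correct / count:.2%}")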
# ----------------------------
# Training + Evaluation loop
# ----------------------------
def train_and_eval(
    model,
    tokenizer,
    train_csv,
    val_csv,
    test_csv,
    config,
    run_name="experiment",
    batch_size=16,
    grad_accum=4,
    num_epochs=1,
    learning_rate=3e-6,
    mask_prob=0.15,
    max_seq_len=None,
    patience=10,
    save_dir="./checkpoints",
):
    # Pick max sequence length from the config if not given explicitly
    if max_seq_len is None:
        if hasattr(config, "max_seq"):
            max_seq_len = config.max_seq
        elif hasattr(config, "max_position_embeddings"):
            max_seq_len = config.max_position_embeddings
        else:
            max_seq_len = 512

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    global_token_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]

    # Dataset counts (subtract 1 per file for the CSV header row)
    train_count = sum(1 for _ in open(train_csv, encoding="utf-8")) - 1
    val_count = sum(1 for _ in open(val_csv, encoding="utf-8")) - 1
    test_count = sum(1 for _ in open(test_csv, encoding="utf-8")) - 1

    train_steps_per_epoch = max(1, train_count // batch_size)
    optimizer_steps_per_epoch = max(1, train_steps_per_epoch // grad_accum)
    val_steps_total = max(1, val_count // batch_size)
    test_steps_total = max(1, test_count // batch_size)

    print(f"Train steps/epoch: {train_steps_per_epoch}")
    print(f"Optimizer steps/epoch: {optimizer_steps_per_epoch}")
    print(f"Val steps total: {val_steps_total}")
    print(f"Test steps total: {test_steps_total}")

    # Optimizer
    optimizer = Ranger21(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.01,
        use_adabelief=True,
        use_warmup=True,
        use_madgrad=True,
        num_epochs=num_epochs,
        warmdown_active=False,
        num_batches_per_epoch=optimizer_steps_per_epoch,
    )

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{run_name}_config.json"), "w") as f:
        if hasattr(config, "to_dict"):
            json.dump(config.to_dict(), f, indent=2)
        else:
            json.dump(config.__dict__, f, indent=2)

    best_val_loss = float("inf")
    patience_counter = 0
    final_val_loss, final_val_perplexity, final_val_acc = None, None, None

    # Initialize unified CSV logger
    metrics_log_path = os.path.join(save_dir, f"{run_name}_metrics.csv")
    metrics_logger = CSVLogger(metrics_log_path, ["epoch", "step", "train_loss", "val_loss", "ppl", "mlm_acc"])

    for epoch in range(num_epochs):
        model.train()
        train_loader = get_dataloader(train_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)
        running_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(enumerate(train_loader),
                    desc=f"{run_name} | Epoch {epoch+1}/{num_epochs}",
                    total=train_steps_per_epoch)
        for step, (input_ids, attention_mask, labels, global_positions) in pbar:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

            outputs = call_model(model, input_ids, attention_mask, labels, global_positions)
            loss = outputs.loss / grad_accum
            loss.backward()
            running_loss += loss.item()

            if (step + 1) % grad_accum == 0:
                optimizer.step()
                optimizer.zero_grad()

            # Log (and validate) every 10 steps
            if (step + 1) % 10 == 0:
                avg_loss = running_loss / 10  # note: per-step losses were scaled by 1/grad_accum
                running_loss = 0.0

                val_loader = get_dataloader(val_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)
                val_loss, val_perplexity, val_acc = evaluate_model(model, val_loader, device)
                # keep the latest validation metrics for the returned summary
                final_val_loss, final_val_perplexity, final_val_acc = val_loss, val_perplexity, val_acc

                # Log unified metrics
                metrics_logger.log({
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "train_loss": avg_loss,
                    "val_loss": val_loss,
                    "ppl": val_perplexity,
                    "mlm_acc": val_acc,
                })
                print(f"\n[Step {step+1}] Train Loss: {avg_loss:.4f} | "
                      f"Val Loss: {val_loss:.4f} | Perplexity: {val_perplexity:.4f} | MLM Acc: {val_acc:.2%}")

                # Save best model / early-stopping bookkeeping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    model.save_pretrained(os.path.join(save_dir, f"{run_name}_best"))
                    tokenizer.save_pretrained(save_dir)
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print("Early stopping triggered")
                        break

                model.train()  # back to train mode after val
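
        # Note: the `break` above only exits the inner step loop; the check
        # below re-tests the patience counter so early stopping also ends the
        # outer epoch loop.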
        if patience_counter >= patience:
            break

    # Final test evaluation
    test_loader = get_dataloader(test_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)
    test_loss, test_perplexity, test_acc = evaluate_model(model, test_loader, device)
    print(f"\nFinal Test | Loss: {test_loss:.4f} | "
          f"Perplexity: {test_perplexity:.4f} | MLM Acc: {test_acc:.2%}")

    model.save_pretrained(os.path.join(save_dir, f"{run_name}_final"))
    tokenizer.save_pretrained(save_dir)

    return {
        "best_val_loss": best_val_loss,
        "final_val_loss": final_val_loss,
        "final_val_perplexity": final_val_perplexity,
        "final_val_acc": final_val_acc,
        "test_loss": test_loss,
        "test_perplexity": test_perplexity,
        "test_acc": test_acc,
    }
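

# Example (illustrative): a minimal end-to-end run. The checkpoint paths and
# CSV file names are placeholder assumptions; substitute your own model,
# tokenizer, and data. Assumes a Hugging Face masked-LM whose config carries
# max_position_embeddings (or max_seq).
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained("path/to/your-tokenizer")
    model = AutoModelForMaskedLM.from_pretrained("path/to/your-model")

    # 70/15/15 split by default (val_test_size=0.3, then halved)
    n_train, n_val, n_test = prepare_train_val_test_split(
        "full.csv", "train.csv", "val.csv", "test.csv"
    )
    print(f"train={n_train} val={n_val} test={n_test}")

    results = train_and_eval(
        model, tokenizer,
        "train.csv", "val.csv", "test.csv",
        config=model.config,
        run_name="demo",
        batch_size=8,
        num_epochs=1,
    )
    print(results)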