|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import math
|
|
|
import torch
|
|
|
from torch.utils.data import DataLoader, IterableDataset
|
|
|
from tqdm import tqdm
|
|
|
import json
|
|
|
import random
|
|
|
import pandas as pd
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
from ranger21 import Ranger21
|
|
|
from datasets import load_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import csv
|
|
|
|
|
|
class CSVLogger:
    """Append-only CSV metrics logger with a lazily-written header.

    The output file is created (truncating any existing file) on the first
    call to ``log``, not at construction time.
    """

    def __init__(self, filename, fieldnames):
        self.filename = filename
        self.fieldnames = fieldnames
        self._initialized = False

    def log(self, row_dict):
        """Append one row (a dict keyed by ``fieldnames``) to the CSV file."""
        if not self._initialized:
            self._init_file()
        with open(self.filename, 'a', newline='', encoding='utf-8') as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow(row_dict)

    def _init_file(self):
        # Create/truncate the file and emit the header exactly once.
        with open(self.filename, 'w', newline='', encoding='utf-8') as f:
            csv.DictWriter(f, fieldnames=self.fieldnames).writeheader()
        self._initialized = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelfiesStreamingDataset(IterableDataset):
    """Streams SMILES strings from a CSV file and yields MLM-masked examples.

    Each item is a tuple ``(input_ids, attention_mask, labels,
    global_positions)`` where ``labels`` holds the original (pre-masking)
    token ids and ``global_positions`` lists the indices of tokens whose id
    is in ``global_token_ids`` (e.g. BOS/EOS/MASK) after masking.

    The CSV is expected to have a ``SMILES`` column.
    """

    def __init__(self, csv_file, tokenizer, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.global_token_ids = global_token_ids or []

        # Keep the (lazy) streaming dataset itself; a fresh iterator is
        # created in __iter__ so the dataset can be iterated more than once.
        # (Previously an iterator was created here and stored, which was
        # exhausted after a single pass — any later iteration yielded
        # nothing.)
        dataset = load_dataset("csv", data_files=csv_file, split="train", streaming=True)
        self.dataset = dataset.shuffle(seed=42, buffer_size=10000)

        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id

    def __iter__(self):
        protected = set(self.global_token_ids)  # O(1) membership checks
        vocab_size = len(self.tokenizer)  # invariant; hoisted out of the loop
        for example in iter(self.dataset):
            smiles = example["SMILES"]
            enc = self.tokenizer(smiles, truncation=True, max_length=self.max_seq_len, return_tensors=None)

            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]
            labels = input_ids.copy()  # targets are the unmasked ids

            # BERT-style masking: of the positions selected with
            # probability mask_prob, 80% -> [MASK], 10% -> random token,
            # 10% -> left unchanged.
            for i in range(len(input_ids)):
                if input_ids[i] in protected:
                    continue  # never mask special/global tokens
                if random.random() < self.mask_prob:
                    rand = random.random()
                    if rand < 0.8:
                        input_ids[i] = self.mask_id
                    elif rand < 0.9:
                        input_ids[i] = random.randint(0, vocab_size - 1)
                        # Re-draw until the replacement is not a special token.
                        while input_ids[i] in protected:
                            input_ids[i] = random.randint(0, vocab_size - 1)
                    else:
                        pass  # keep the original token

            # Indices of special/global tokens after masking (mask tokens
            # count when mask_token_id is in global_token_ids).
            global_positions = [idx for idx, tid in enumerate(input_ids) if tid in protected]

            yield (
                torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(attention_mask, dtype=torch.long),
                torch.tensor(labels, dtype=torch.long),
                global_positions,
            )
|
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """Pad a list of dataset items into batched tensors.

    ``input_ids`` and ``attention_mask`` are right-padded with 0; ``labels``
    with -100 so padded positions are ignored by the loss.
    ``global_positions`` remain a tuple of per-example Python lists.
    """
    ids, masks, labels, globals_ = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    return (
        pad(ids, batch_first=True, padding_value=0),
        pad(masks, batch_first=True, padding_value=0),
        pad(labels, batch_first=True, padding_value=-100),
        globals_,
    )
|
|
|
|
|
|
|
|
|
def get_dataloader(csv_file, tokenizer, batch_size=16, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
    """Build a DataLoader over the streaming masked-LM dataset for ``csv_file``."""
    ds = SelfiesStreamingDataset(
        csv_file,
        tokenizer,
        max_seq_len=max_seq_len,
        mask_prob=mask_prob,
        global_token_ids=global_token_ids,
    )
    return DataLoader(ds, batch_size=batch_size, collate_fn=collate_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_train_val_test_split(full_csv, train_csv, val_csv, test_csv,
                                 val_test_size=0.3, test_size_ratio=0.5, random_state=42):
    """Split ``full_csv`` into train/val/test CSVs (70/15/15 by default).

    If all three output files already exist, the split is skipped and the
    existing row counts are returned instead.

    Returns:
        (train_rows, val_rows, test_rows) — data row counts, header excluded.
    """

    def _count_rows(path):
        # Line count minus the header row; the context manager closes the
        # file handle (the previous version leaked the open file objects).
        with open(path, encoding='utf-8') as f:
            return sum(1 for _ in f) - 1

    if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
        print(" Train/val/test splits already exist. Skipping split.")
        return _count_rows(train_csv), _count_rows(val_csv), _count_rows(test_csv)

    df = pd.read_csv(full_csv)
    # First carve off the val+test portion, then split it into val and test.
    train_df, val_test_df = train_test_split(df, test_size=val_test_size, random_state=random_state)
    val_df, test_df = train_test_split(val_test_df, test_size=test_size_ratio, random_state=random_state)

    train_df.to_csv(train_csv, index=False)
    val_df.to_csv(val_csv, index=False)
    test_df.to_csv(test_csv, index=False)

    return len(train_df), len(val_df), len(test_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect
|
|
|
|
|
|
def call_model(model, input_ids, attention_mask, labels, global_positions):
    """Invoke ``model`` with standard MLM kwargs, forwarding
    ``global_positions`` only when the model's forward signature accepts it.

    ``global_positions`` may be None, an already-built tensor, or a
    tuple/list with one entry per batch element (each entry a list/tuple of
    indices, or a single int). List/tuple input is padded with -1 into a
    rectangular LongTensor on ``input_ids``'s device; an empty tuple/list
    is normalized to None.
    """

    model_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "output_attentions": False,
    }

    sig = inspect.signature(model.forward)
    if "global_positions" in sig.parameters:
        if isinstance(global_positions, (tuple, list)):
            if len(global_positions) == 0:
                global_positions = None
            else:
                # Normalize scalar entries to 1-element lists first, so that
                # both the max-length computation and the padding below
                # handle them. (Previously a bare-int entry crashed
                # max(len(g) ...) with a TypeError even though the padding
                # comprehension explicitly supported that case.)
                rows = [list(g) if isinstance(g, (list, tuple)) else [g] for g in global_positions]
                max_len = max(len(r) for r in rows)
                padded = [r + [-1] * (max_len - len(r)) for r in rows]
                global_positions = torch.tensor(padded, dtype=torch.long, device=input_ids.device)
        elif isinstance(global_positions, torch.Tensor):
            pass  # already batched/padded by the caller
        elif global_positions is None:
            pass
        else:
            raise TypeError(f"Unsupported type for global_positions: {type(global_positions)}")

        model_kwargs["global_positions"] = global_positions

    return model(**model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_metrics(logits, labels):
    """Return (summed MLM loss, #correct predictions, #scored tokens).

    Positions whose label is -100 are excluded from both the loss and the
    accuracy counts.
    """
    flat_logits = logits.view(-1, logits.size(-1))
    flat_labels = labels.view(-1)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
    loss = criterion(flat_logits, flat_labels)

    scored = flat_labels != -100
    count = int(scored.sum())

    predictions = flat_logits.argmax(dim=-1)
    correct = int(((predictions == flat_labels) & scored).sum())

    return loss.item(), correct, count
|
|
|
|
|
|
|
|
|
def evaluate_model(model, dataloader, device):
    """Run one full evaluation pass; return (avg_loss, perplexity, accuracy).

    Perplexity is reported as inf when the average loss exceeds 20, to avoid
    overflow in math.exp.
    """
    model.eval()
    loss_sum = 0.0
    correct_sum = 0
    token_sum = 0

    with torch.no_grad():
        for input_ids, attention_mask, labels, global_positions in dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            outputs = call_model(model, input_ids, attention_mask, labels, global_positions)
            batch_loss, batch_correct, batch_count = compute_metrics(outputs.logits, labels)
            loss_sum += batch_loss
            correct_sum += batch_correct
            token_sum += batch_count

    if token_sum > 0:
        avg_loss = loss_sum / token_sum
        accuracy = correct_sum / token_sum
    else:
        avg_loss = float("inf")
        accuracy = 0.0
    perplexity = math.exp(avg_loss) if avg_loss < 20 else float("inf")

    return avg_loss, perplexity, accuracy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_and_eval(
    model,
    tokenizer,
    train_csv,
    val_csv,
    test_csv,
    config,
    run_name="experiment",
    batch_size=16,
    grad_accum=4,
    num_epochs=1,
    learning_rate=3e-6,
    mask_prob=0.15,
    max_seq_len=None,
    patience=10,
    save_dir="./checkpoints",
):
    """Train ``model`` with a masked-LM objective and periodic validation.

    Every 10 training steps the model is validated: metrics are appended to
    a CSV log, the best checkpoint (lowest val loss) is saved, and training
    stops early after ``patience`` validations without improvement.
    Finishes with a test-set evaluation and a final checkpoint.

    Returns:
        dict with ``best_val_loss``, the last validation metrics
        (``final_val_*``) and the test metrics.
    """

    def _count_rows(path):
        # Data rows (header excluded); the with-block closes the file
        # handle (previously the handles were leaked).
        with open(path, encoding="utf-8") as f:
            return sum(1 for _ in f) - 1

    # Resolve the sequence length from the config when not given explicitly.
    if max_seq_len is None:
        if hasattr(config, "max_seq"):
            max_seq_len = config.max_seq
        elif hasattr(config, "max_position_embeddings"):
            max_seq_len = config.max_position_embeddings
        else:
            max_seq_len = 512

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Special tokens that must never be masked and are treated as "global".
    global_token_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]

    train_count = _count_rows(train_csv)
    val_count = _count_rows(val_csv)
    test_count = _count_rows(test_csv)

    train_steps_per_epoch = max(1, train_count // batch_size)
    optimizer_steps_per_epoch = max(1, train_steps_per_epoch // grad_accum)
    val_steps_total = max(1, val_count // batch_size)
    test_steps_total = max(1, test_count // batch_size)

    print(f" Train steps/epoch: {train_steps_per_epoch}")
    print(f" Optimizer steps/epoch: {optimizer_steps_per_epoch}")
    print(f" Val steps total: {val_steps_total}")
    print(f" Test steps total: {test_steps_total}")

    # Ranger21 needs the schedule geometry (epochs x batches) up front.
    optimizer = Ranger21(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.01,
        use_adabelief=True,
        use_warmup=True,
        use_madgrad=True,
        num_epochs=num_epochs,
        warmdown_active=False,
        num_batches_per_epoch=optimizer_steps_per_epoch,
    )

    # Persist the model config next to the checkpoints for reproducibility.
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{run_name}_config.json"), "w") as f:
        if hasattr(config, "to_dict"):
            json.dump(config.to_dict(), f, indent=2)
        else:
            json.dump(config.__dict__, f, indent=2)

    best_val_loss = float("inf")
    patience_counter = 0
    final_val_loss, final_val_perplexity, final_val_acc = None, None, None

    metrics_log_path = os.path.join(save_dir, f"{run_name}_metrics.csv")
    metrics_logger = CSVLogger(metrics_log_path, ["epoch", "step", "train_loss", "val_loss", "ppl", "mlm_acc"])

    for epoch in range(num_epochs):
        model.train()
        train_loader = get_dataloader(train_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)

        running_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(enumerate(train_loader), desc=f"{run_name} | Epoch {epoch+1}/{num_epochs}", total=train_steps_per_epoch)
        for step, (input_ids, attention_mask, labels, global_positions) in pbar:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)

            outputs = call_model(model, input_ids, attention_mask, labels, global_positions)
            loss = outputs.loss / grad_accum  # scale for gradient accumulation
            loss.backward()
            running_loss += loss.item()

            # NOTE(review): a trailing partial accumulation window at the end
            # of an epoch is never stepped — confirm this is intended.
            if (step + 1) % grad_accum == 0:
                optimizer.step()
                optimizer.zero_grad()

            if (step + 1) % 10 == 0:
                avg_loss = running_loss / 10
                running_loss = 0.0

                val_loader = get_dataloader(val_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)
                val_loss, val_perplexity, val_acc = evaluate_model(model, val_loader, device)

                # Track the most recent validation metrics so the returned
                # summary reflects the last evaluation (previously these
                # variables were initialized to None and never updated, so
                # the final_val_* entries of the result were always None).
                final_val_loss, final_val_perplexity, final_val_acc = val_loss, val_perplexity, val_acc

                metrics_logger.log({
                    "epoch": epoch + 1,
                    "step": step + 1,
                    "train_loss": avg_loss,
                    "val_loss": val_loss,
                    "ppl": val_perplexity,
                    "mlm_acc": val_acc
                })

                print(f"\n[Step {step+1}] Train Loss: {avg_loss:.4f} | "
                      f"Val Loss: {val_loss:.4f} | Perplexity: {val_perplexity:.4f} | MLM Acc: {val_acc:.2%}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    model.save_pretrained(os.path.join(save_dir, f"{run_name}_best"))
                    tokenizer.save_pretrained(save_dir)
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(" Early stopping triggered")
                        break

                model.train()  # evaluate_model switched the model to eval mode

        if patience_counter >= patience:
            break

    test_loader = get_dataloader(test_csv, tokenizer, batch_size, max_seq_len, mask_prob, global_token_ids)
    test_loss, test_perplexity, test_acc = evaluate_model(model, test_loader, device)

    print(f"\n Final Test | Loss: {test_loss:.4f} | "
          f"Perplexity: {test_perplexity:.4f} | MLM Acc: {test_acc:.2%}")

    model.save_pretrained(os.path.join(save_dir, f"{run_name}_final"))
    tokenizer.save_pretrained(save_dir)

    return {
        "best_val_loss": best_val_loss,
        "final_val_loss": final_val_loss,
        "final_val_perplexity": final_val_perplexity,
        "final_val_acc": final_val_acc,
        "test_loss": test_loss,
        "test_perplexity": test_perplexity,
        "test_acc": test_acc,
    }
|
|
|
|