# test_train.py
# Smoke-test training script for RougeBERT, starting from freshly initialized weights.
# Consider adjusting the data and tokenizer to your use case.
# - gbyuvd
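#
# Expected input: a CSV file with a "SMILES" column (that is the only column the
# streaming dataset below reads); everything else is set in the
# Hyperparameters / Config section further down.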
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm
import pandas as pd
from ranger21 import Ranger21
import random
import os
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
import json
# -------------------------------
# Streaming Dataset
# -------------------------------
from datasets import load_dataset
class SelfiesStreamingDataset(IterableDataset):
def __init__(self, csv_file, tokenizer, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.mask_prob = mask_prob
self.global_token_ids = global_token_ids or []
print(f"Loading dataset (streaming): {csv_file}")
dataset = load_dataset("csv", data_files=csv_file, split="train", streaming=True)
        dataset = dataset.shuffle(seed=42, buffer_size=10000)  # approximate shuffle within a 10k-example buffer
        self.dataset = dataset  # store the dataset itself so __iter__ can create a fresh iterator each pass
# Special IDs
self.mask_id = tokenizer.mask_token_id
self.pad_id = tokenizer.pad_token_id
def __iter__(self):
        for example in self.dataset:
            smiles = example["SMILES"]  # NOTE: despite the class name (SELFIES), this column holds SMILES; adjust to your data
enc = self.tokenizer(smiles, truncation=True, max_length=self.max_seq_len, return_tensors=None)
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]
labels = input_ids.copy()
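            # NOTE: labels keep every position, so the loss is computed on all
            # non-padding tokens, not only on masked ones (standard BERT MLM would
            # set unmasked labels to -100); this appears to be a deliberate choice.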
# BERT-style 3-way masking + skip global tokens
vocab_size = len(self.tokenizer) # needed for random token sampling
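            # Corruption scheme (classic BERT): each non-global token is selected
            # with probability mask_prob; a selected token becomes [MASK] 80% of the
            # time, a random token 10%, and is left unchanged 10%. With
            # mask_prob=0.25 (see config below), ~20% of tokens end up as [MASK] overall.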
for i in range(len(input_ids)):
# Skip masking for special/global tokens
if input_ids[i] in self.global_token_ids:
continue
# Decide whether to mask this token
if random.random() < self.mask_prob:
rand = random.random()
if rand < 0.8:
# 80%: replace with [MASK]
input_ids[i] = self.mask_id
                    elif rand < 0.9:
                        # 10%: replace with a random token
                        input_ids[i] = random.randint(0, vocab_size - 1)
                        # Re-roll if we happened to draw a special/global token
                        while input_ids[i] in self.global_token_ids:
                            input_ids[i] = random.randint(0, vocab_size - 1)
else:
# 10%: leave unchanged (do nothing)
pass
# Global token positions
global_positions = [idx for idx, tid in enumerate(input_ids) if tid in self.global_token_ids]
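            # Because mask_token_id is included in global_token_ids (see the
            # tokenizer setup below), freshly masked positions are collected here
            # too -- i.e. [MASK] tokens appear to receive global attention by design.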
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = torch.tensor(labels, dtype=torch.long)
yield input_ids, attention_mask, labels, global_positions
def collate_fn(batch):
input_ids_list, attention_mask_list, labels_list, global_positions_list = zip(*batch)
    # Pad each field to the longest sequence in the batch.
    # NOTE: padding_value=0 assumes the tokenizer's pad_token_id is 0; adjust if yours differs.
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)  # -100 -> ignored by the loss
# Convert global_positions_list to padded tensor
max_g = max(len(g) for g in global_positions_list) # find max number of global tokens in batch
if max_g == 0:
# If no global tokens in entire batch, create empty tensor with 0 globals
global_positions = torch.full((len(global_positions_list), 0), -1, dtype=torch.long)
else:
padded_global_positions = [
g + [-1] * (max_g - len(g)) for g in global_positions_list
] # pad each list with -1 to max_g length
global_positions = torch.tensor(padded_global_positions, dtype=torch.long)
return input_ids, attention_mask, labels, global_positions
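# Shape sketch: for a batch of B sequences padded to length L, collate_fn returns
# input_ids, attention_mask, and labels as (B, L) tensors, and global_positions as
# a (B, max_g) tensor padded with -1 (max_g = most global tokens in any sequence).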
def get_dataloader(csv_file, tokenizer, batch_size=16, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
dataset = SelfiesStreamingDataset(
csv_file=csv_file,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
mask_prob=mask_prob,
global_token_ids=global_token_ids
)
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn
)
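# Usage sketch (kept commented out so running the script is unchanged; assumes a
# tokenizer has already been constructed as done below):
#   loader = get_dataloader(TRAIN_CSV, tokenizer, batch_size=4)
#   ids, attn, labels, gpos = next(iter(loader))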
# ----------------------------
# Model & Tokenizer Imports
# ----------------------------
from RougeBERT import RougeBERT
from FastChemTokenizer import FastChemTokenizer
# ----------------------------
# Auto-split CSV into train/val/test
# ----------------------------
def prepare_train_val_test_split(full_csv, train_csv, val_csv, test_csv, val_test_size=0.3, test_size_ratio=0.5, random_state=42):
"""
Splits full_csv into train, val, test.
val_test_size: portion of data to reserve for val+test (e.g., 0.3 → 70% train, 30% val+test)
test_size_ratio: portion of val+test to assign to test (e.g., 0.5 → 15% val, 15% test)
"""
    if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
        print(" Train/val/test splits already exist. Skipping split.")
        # Count data rows (minus the header line) without loading whole files.
        counts = []
        for path in (train_csv, val_csv, test_csv):
            with open(path, encoding='utf-8') as fh:
                counts.append(sum(1 for _ in fh) - 1)
        train_count, val_count, test_count = counts
        return train_count, val_count, test_count
    print(f" Splits not found. Loading and splitting {full_csv}...")
df = pd.read_csv(full_csv)
train_df, val_test_df = train_test_split(df, test_size=val_test_size, random_state=random_state)
val_df, test_df = train_test_split(val_test_df, test_size=test_size_ratio, random_state=random_state)
train_df.to_csv(train_csv, index=False)
val_df.to_csv(val_csv, index=False)
test_df.to_csv(test_csv, index=False)
print(f" Saved {len(train_df)} rows to {train_csv}")
print(f" Saved {len(val_df)} rows to {val_csv}")
print(f" Saved {len(test_df)} rows to {test_csv}")
return len(train_df), len(val_df), len(test_df)
# ----------------------------
# Hyperparameters / Config
# ----------------------------
BATCH_SIZE = 16
GRAD_ACCUM = 4
NUM_EPOCHS = 1
MAX_SEQ_LEN = 512
LEARNING_RATE = 3e-6
MASK_PROB = 0.25
FULL_CSV = "../data/sample_1k_smi_42.csv"
TRAIN_CSV = "../data/train.csv"
VAL_CSV = "../data/val.csv"
TEST_CSV = "../data/test.csv"
SAVE_DIR = "./trained_rougeberttest"
# Auto-split and get counts
TRAIN_SET_SIZE, VAL_SET_SIZE, TEST_SET_SIZE = prepare_train_val_test_split(FULL_CSV, TRAIN_CSV, VAL_CSV, TEST_CSV, val_test_size=0.2, test_size_ratio=0.5)
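# With val_test_size=0.2 and test_size_ratio=0.5 this gives roughly
# 80% train / 10% val / 10% test of the original rows.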
# Auto-calculate steps
train_steps_per_epoch = max(1, TRAIN_SET_SIZE // BATCH_SIZE)
optimizer_steps_per_epoch = max(1, train_steps_per_epoch // GRAD_ACCUM)
val_steps_total = max(1, VAL_SET_SIZE // BATCH_SIZE)
test_steps_total = max(1, TEST_SET_SIZE // BATCH_SIZE)
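# Example: 800 training rows -> 800 // 16 = 50 train steps/epoch, and
# 50 // 4 = 12 optimizer steps/epoch after gradient accumulation.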
print(f" Train steps/epoch: {train_steps_per_epoch}")
print(f" Optimizer steps/epoch: {optimizer_steps_per_epoch}")
print(f" Val steps total: {val_steps_total}")
print(f" Test steps total: {test_steps_total}")
# Log every 25% of training epoch
print_every = max(1, train_steps_per_epoch // 4)
checkpoint_every = max(1, train_steps_per_epoch // 10) # Save every 10%
print(f" Logging every {print_every} steps (25% of epoch)")
print(f" Checkpoint every {checkpoint_every} steps (10% of epoch)")
# ----------------------------
# Prepare tokenizer
# ----------------------------
tokenizer = FastChemTokenizer.from_pretrained("../smitok")
global_token_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]
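# BOS/EOS/MASK are treated as "global" tokens: the dataset never masks them, and
# their positions are passed to the model as global_positions (presumably to get
# global attention alongside RougeBERT's local window attention).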
# ----------------------------
# Prepare model and optimizer
# ----------------------------
# NOTE: keep these kwargs in sync with the config dict saved below.
model = RougeBERT(
    vocab_size=len(tokenizer),
    max_seq=MAX_SEQ_LEN,
    num_layers=8,
    hidden_size=320,
    intermediate_size=1280,
    num_heads=8,
    kv_groups=2,
    rotary_max_seq=MAX_SEQ_LEN,
    window=16,
    dropout=0.1,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = Ranger21(
model.parameters(),
lr=LEARNING_RATE,
weight_decay=0.01,
use_adabelief=True,
use_warmup=True,
use_madgrad=True,
num_epochs=NUM_EPOCHS,
warmdown_active=False,
num_batches_per_epoch=optimizer_steps_per_epoch
)
criterion = CrossEntropyLoss(ignore_index=-100)
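# ignore_index=-100 matches the label padding value used in collate_fn, so padded
# positions contribute nothing to the loss.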
# ----------------------------
# Save Config
# ----------------------------
config = {
"vocab_size": len(tokenizer),
"max_seq": MAX_SEQ_LEN,
"num_layers": 8,
"hidden_size": 320,
"intermediate_size": 1280,
"num_heads": 8,
"kv_groups": 2,
"rotary_max_seq": 128,
"window": 16,
"dropout": 0.1,
"pad_token_id": tokenizer.pad_token_id,
"mask_token_id": tokenizer.mask_token_id,
"batch_size": BATCH_SIZE,
"grad_accum": GRAD_ACCUM,
"learning_rate": LEARNING_RATE,
"mask_prob": MASK_PROB,
}
os.makedirs(SAVE_DIR, exist_ok=True)
with open(os.path.join(SAVE_DIR, "config.json"), "w") as f:
json.dump(config, f, indent=2)
print(f" Config saved to {os.path.join(SAVE_DIR, 'config.json')}")
# ----------------------------
# Early Stopping Setup
# ----------------------------
best_val_loss = float('inf')
patience_counter = 0
PATIENCE = 2  # Stop if no improvement for 2 consecutive evals
print(f" Early stopping: patience = {PATIENCE} evaluations")
# ----------------------------
# Training + Validation Loop
# ----------------------------
for epoch in range(NUM_EPOCHS):
model.train()
running_loss = 0.0
optimizer.zero_grad()
train_loader = get_dataloader(
csv_file=TRAIN_CSV,
tokenizer=tokenizer,
batch_size=BATCH_SIZE,
max_seq_len=MAX_SEQ_LEN,
mask_prob=MASK_PROB,
global_token_ids=global_token_ids
)
pbar = tqdm(enumerate(train_loader), total=train_steps_per_epoch, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
for step, (input_ids, attention_mask, labels, global_positions) in pbar:
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        global_positions = global_positions.to(device)  # index tensors should live on the same device as the model
logits = model(input_ids, attention_mask=attention_mask, global_positions=global_positions)
loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = loss / GRAD_ACCUM
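        # Dividing by GRAD_ACCUM keeps gradient magnitudes comparable to one big
        # batch: effective batch size = BATCH_SIZE * GRAD_ACCUM = 16 * 4 = 64.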
loss.backward()
running_loss += loss.item()
if (step + 1) % GRAD_ACCUM == 0:
optimizer.step()
optimizer.zero_grad()
# ----------------------------
# Logging + Eval + Checkpoint + GPU Monitoring
# ----------------------------
if (step + 1) % print_every == 0:
avg_loss = running_loss / print_every
pbar.set_postfix({"train_loss": f"{avg_loss:.4f}"})
# Validation
model.eval()
val_loss = 0.0
correct = 0
total = 0
val_steps = 0
sample_inputs = []
sample_labels = []
sample_preds = []
val_loader = get_dataloader(
csv_file=VAL_CSV,
tokenizer=tokenizer,
batch_size=BATCH_SIZE,
max_seq_len=MAX_SEQ_LEN,
mask_prob=MASK_PROB,
global_token_ids=global_token_ids
)
with torch.no_grad():
for vbatch in val_loader:
v_input_ids, v_attention_mask, v_labels, v_global_positions = vbatch
v_input_ids = v_input_ids.to(device)
v_attention_mask = v_attention_mask.to(device)
                    v_labels = v_labels.to(device)
                    v_global_positions = v_global_positions.to(device)
v_logits = model(v_input_ids, attention_mask=v_attention_mask, global_positions=v_global_positions)
v_loss = criterion(v_logits.view(-1, v_logits.size(-1)), v_labels.view(-1))
val_loss += v_loss.item()
v_preds = torch.argmax(v_logits, dim=-1)
mask = (v_labels != -100)
correct += (v_preds[mask] == v_labels[mask]).sum().item()
total += mask.sum().item()
if val_steps == 0:
sample_inputs.append(v_input_ids.cpu())
sample_labels.append(v_labels.cpu())
sample_preds.append(v_preds.cpu())
val_steps += 1
                    if val_steps >= min(50, val_steps_total // 4):  # cap eval at 50 batches or ~25% of the val set
                        break
avg_val_loss = val_loss / val_steps if val_steps > 0 else float('inf')
perplexity = torch.exp(torch.tensor(avg_val_loss)).item() if val_steps > 0 else float('inf')
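            # Perplexity = exp(mean cross-entropy); MLM accuracy counts only
            # positions whose label is not -100 (non-padding tokens here).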
accuracy = correct / total if total > 0 else 0.0
print(f"\n[Step {step+1}] "
f"Train Loss: {avg_loss:.4f} | "
f"Val Loss: {avg_val_loss:.4f} | "
f"Perplexity: {perplexity:.4f} | "
f"MLM Acc: {accuracy:.2%}")
# GPU Monitoring (every eval)
if torch.cuda.is_available():
gpu_mem_alloc = torch.cuda.memory_allocated() / 1e9
gpu_mem_res = torch.cuda.memory_reserved() / 1e9
print(f"GMEM: {gpu_mem_alloc:.2f}GB / {gpu_mem_res:.2f}GB")
# Early Stopping Check
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
patience_counter = 0
print(" New best val loss!")
else:
patience_counter += 1
print(f" No improvement. Patience: {patience_counter}/{PATIENCE}")
if patience_counter >= PATIENCE:
print(" Early stopping triggered!")
break
# Checkpoint every 10%
if (step + 1) % checkpoint_every == 0:
ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_step_{step+1}.bin")
torch.save(model.state_dict(), ckpt_path)
print(f" Checkpoint saved: {ckpt_path}")
# Reset
running_loss = 0.0
model.train()
# If early stopped, break epoch loop too
if patience_counter >= PATIENCE:
print(" Stopping training early.")
break
# ----------------------------
# Final Test Evaluation
# ----------------------------
print(f"\n{'='*50}\n🔬 FINAL EPOCH TEST EVALUATION\n{'='*50}")
model.eval()
test_loss = 0.0
correct = 0
total = 0
test_steps = 0
sample_inputs = []
sample_labels = []
sample_preds = []
test_loader = get_dataloader(
csv_file=TEST_CSV,
tokenizer=tokenizer,
batch_size=BATCH_SIZE,
max_seq_len=MAX_SEQ_LEN,
mask_prob=MASK_PROB,
global_token_ids=global_token_ids
)
with torch.no_grad():
for tbatch in test_loader:
t_input_ids, t_attention_mask, t_labels, t_global_positions = tbatch
t_input_ids = t_input_ids.to(device)
t_attention_mask = t_attention_mask.to(device)
        t_labels = t_labels.to(device)
        t_global_positions = t_global_positions.to(device)
t_logits = model(t_input_ids, attention_mask=t_attention_mask, global_positions=t_global_positions)
t_loss = criterion(t_logits.view(-1, t_logits.size(-1)), t_labels.view(-1))
test_loss += t_loss.item()
t_preds = torch.argmax(t_logits, dim=-1)
mask = (t_labels != -100)
correct += (t_preds[mask] == t_labels[mask]).sum().item()
total += mask.sum().item()
if test_steps == 0:
sample_inputs.append(t_input_ids.cpu())
sample_labels.append(t_labels.cpu())
sample_preds.append(t_preds.cpu())
test_steps += 1
if test_steps >= 100:
break
avg_test_loss = test_loss / test_steps if test_steps > 0 else float('inf')
perplexity = torch.exp(torch.tensor(avg_test_loss)).item() if test_steps > 0 else float('inf')
accuracy = correct / total if total > 0 else 0.0
print(f" FINAL Epoch {epoch+1} | "
f"Test Loss: {avg_test_loss:.4f} | "
f"Perplexity: {perplexity:.4f} | "
f"MLM Acc: {accuracy:.2%}")
# ----------------------------
# Final Save
# ----------------------------
final_path = os.path.join(SAVE_DIR, "pytorch_model.bin")
torch.save(model.state_dict(), final_path)
print(f" Final model saved to {final_path}")
try:
tokenizer.save_pretrained(SAVE_DIR)
print(f" Tokenizer saved to {SAVE_DIR}")
except Exception as e:
print(f" Could not save tokenizer: {e}")
# ----------------------------
# 🔍 Sample Predictions
# ----------------------------
print("\n" + "="*70)
print(" SAMPLE PREDICTIONS (first 3 from final validation)")
print("="*70)
if sample_inputs and len(sample_inputs[0]) > 0:
input_ids_sample = sample_inputs[0][:3]
labels_sample = sample_labels[0][:3]
preds_sample = sample_preds[0][:3]
for i in range(len(input_ids_sample)):
input_toks = input_ids_sample[i].tolist()
label_toks = labels_sample[i].tolist()
pred_toks = preds_sample[i].tolist()
original_tokens = [t for t in label_toks if t != -100]
original = tokenizer.decode(original_tokens, skip_special_tokens=False)
masked_tokens = []
for j, tok in enumerate(input_toks):
if tok == tokenizer.mask_token_id:
masked_tokens.append("<mask>")
            elif tok == 0:
                continue  # skip padding (assumes pad_token_id == 0, matching collate_fn)
else:
decoded = tokenizer.decode([tok], skip_special_tokens=False)
masked_tokens.append(decoded)
masked_str = "".join(masked_tokens).replace(" ", "")
pred_filtered = [p for p, l in zip(pred_toks, label_toks) if l != -100]
predicted = tokenizer.decode(pred_filtered, skip_special_tokens=False)
print(f"\n Example {i+1}:")
print(f" Original: {original}")
print(f" Masked: {masked_str}")
print(f" Predicted: {predicted}")
else:
print(" No samples captured for visualization.")
print(f"\n Training complete! Best Val Loss: {best_val_loss:.4f}")