# test_train.py
# Test-train RougeBERT with initialized weights.
# Consider adjusting the data and tokenizer to your use case.
# - gbyuvd

import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm
import pandas as pd
from ranger21 import Ranger21
import random
import os
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
import json

# -------------------------------
# Streaming Dataset
# -------------------------------
from datasets import load_dataset


class SelfiesStreamingDataset(IterableDataset):
    def __init__(self, csv_file, tokenizer, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.global_token_ids = global_token_ids or []
        print(f"Loading dataset (streaming): {csv_file}")
        dataset = load_dataset("csv", data_files=csv_file, split="train", streaming=True)
        # Keep the dataset itself; a fresh iterator is created in __iter__ so the
        # stream restarts cleanly each time a DataLoader re-iterates it.
        self.dataset = dataset.shuffle(seed=42, buffer_size=10000)
        # Special IDs
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id

    def __iter__(self):
        for example in self.dataset:
            smiles = example["SMILES"]
            enc = self.tokenizer(smiles, truncation=True, max_length=self.max_seq_len, return_tensors=None)
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]
            # Note: labels keep *every* token, so the loss downstream covers all
            # non-pad positions, not only the masked ones.
            labels = input_ids.copy()

            # BERT-style 3-way masking + skip global tokens
            vocab_size = len(self.tokenizer)  # needed for random token sampling
            for i in range(len(input_ids)):
                # Skip masking for special/global tokens
                if input_ids[i] in self.global_token_ids:
                    continue
                # Decide whether to mask this token
                if random.random() < self.mask_prob:
                    rand = random.random()
                    if rand < 0.8:
                        # 80%: replace with [MASK]
                        input_ids[i] = self.mask_id
                    elif rand < 0.9:
                        # 10%: replace with a random token, re-drawing if a
                        # special/global token is selected
                        input_ids[i] = random.randint(0, vocab_size - 1)
                        while input_ids[i] in self.global_token_ids:
                            input_ids[i] = random.randint(0, vocab_size - 1)
                    else:
                        # 10%: leave unchanged (do nothing)
                        pass

            # Global token positions
            global_positions = [idx for idx, tid in enumerate(input_ids) if tid in self.global_token_ids]

            # Convert to tensors
            input_ids = torch.tensor(input_ids, dtype=torch.long)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long)
            labels = torch.tensor(labels, dtype=torch.long)
            yield input_ids, attention_mask, labels, global_positions


def collate_fn(batch):
    input_ids_list, attention_mask_list, labels_list, global_positions_list = zip(*batch)

    # Pad variable-length sequences to the batch maximum
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)

    # Convert global_positions_list to a padded tensor
    max_g = max(len(g) for g in global_positions_list)  # max number of global tokens in batch
    if max_g == 0:
        # No global tokens in the entire batch: empty (batch, 0) tensor
        global_positions = torch.full((len(global_positions_list), 0), -1, dtype=torch.long)
    else:
        # Pad each position list with -1 up to max_g
        padded_global_positions = [g + [-1] * (max_g - len(g)) for g in global_positions_list]
        global_positions = torch.tensor(padded_global_positions, dtype=torch.long)

    return input_ids, attention_mask, labels, global_positions
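
# Quick sanity check for the masking + collation path (a minimal sketch, not
# part of the training run; `tokenizer`, `TRAIN_CSV`, and `global_token_ids`
# are defined further below, so uncomment and run this only after that point):
# ds = SelfiesStreamingDataset(TRAIN_CSV, tokenizer, mask_prob=0.15,
#                              global_token_ids=global_token_ids)
# it = iter(ds)
# ids, attn, labels, gpos = collate_fn([next(it) for _ in range(4)])
# print(ids.shape, attn.shape, labels.shape, gpos.shape)
# print("masked tokens:", (ids == tokenizer.mask_token_id).sum().item())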

def get_dataloader(csv_file, tokenizer, batch_size=16, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
    dataset = SelfiesStreamingDataset(
        csv_file=csv_file,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        mask_prob=mask_prob,
        global_token_ids=global_token_ids
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn
    )

# ----------------------------
# Model & Tokenizer Imports
# ----------------------------
from RougeBERT import RougeBERT
from FastChemTokenizer import FastChemTokenizer

# ----------------------------
# Auto-split CSV into train/val/test
# ----------------------------
def prepare_train_val_test_split(full_csv, train_csv, val_csv, test_csv,
                                 val_test_size=0.3, test_size_ratio=0.5, random_state=42):
    """
    Splits full_csv into train, val, and test CSVs.

    val_test_size: portion of data reserved for val+test
                   (e.g., 0.3 -> 70% train, 30% val+test)
    test_size_ratio: portion of val+test assigned to test
                     (e.g., 0.5 -> 15% val, 15% test)
    """
    if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
        print(" Train/val/test splits already exist. Skipping split.")
        # Subtract 1 per file for the header row
        train_count = sum(1 for _ in open(train_csv, encoding='utf-8')) - 1
        val_count = sum(1 for _ in open(val_csv, encoding='utf-8')) - 1
        test_count = sum(1 for _ in open(test_csv, encoding='utf-8')) - 1
        return train_count, val_count, test_count

    print(f"Splits not found. Loading and splitting {full_csv}...")
    df = pd.read_csv(full_csv)
    train_df, val_test_df = train_test_split(df, test_size=val_test_size, random_state=random_state)
    val_df, test_df = train_test_split(val_test_df, test_size=test_size_ratio, random_state=random_state)
    train_df.to_csv(train_csv, index=False)
    val_df.to_csv(val_csv, index=False)
    test_df.to_csv(test_csv, index=False)
    print(f" Saved {len(train_df)} rows to {train_csv}")
    print(f" Saved {len(val_df)} rows to {val_csv}")
    print(f" Saved {len(test_df)} rows to {test_csv}")
    return len(train_df), len(val_df), len(test_df)

# ----------------------------
# Hyperparameters / Config
# ----------------------------
BATCH_SIZE = 16
GRAD_ACCUM = 4
NUM_EPOCHS = 1
MAX_SEQ_LEN = 512
LEARNING_RATE = 3e-6
MASK_PROB = 0.25
FULL_CSV = "../data/sample_1k_smi_42.csv"
TRAIN_CSV = "../data/train.csv"
VAL_CSV = "../data/val.csv"
TEST_CSV = "../data/test.csv"
SAVE_DIR = "./trained_rougeberttest"

# Auto-split and get counts
TRAIN_SET_SIZE, VAL_SET_SIZE, TEST_SET_SIZE = prepare_train_val_test_split(
    FULL_CSV, TRAIN_CSV, VAL_CSV, TEST_CSV, val_test_size=0.2, test_size_ratio=0.5
)

# Auto-calculate steps
train_steps_per_epoch = max(1, TRAIN_SET_SIZE // BATCH_SIZE)
optimizer_steps_per_epoch = max(1, train_steps_per_epoch // GRAD_ACCUM)
val_steps_total = max(1, VAL_SET_SIZE // BATCH_SIZE)
test_steps_total = max(1, TEST_SET_SIZE // BATCH_SIZE)
print(f" Train steps/epoch: {train_steps_per_epoch}")
print(f" Optimizer steps/epoch: {optimizer_steps_per_epoch}")
print(f" Val steps total: {val_steps_total}")
print(f" Test steps total: {test_steps_total}")

# Log every 25% of the training epoch, checkpoint every 10%
print_every = max(1, train_steps_per_epoch // 4)
checkpoint_every = max(1, train_steps_per_epoch // 10)
print(f" Logging every {print_every} steps (25% of epoch)")
print(f" Checkpoint every {checkpoint_every} steps (10% of epoch)")

# ----------------------------
# Prepare tokenizer
# ----------------------------
tokenizer = FastChemTokenizer.from_pretrained("../smitok")
global_token_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]
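
# Sanity-check the special-token IDs the masking logic relies on; collate_fn
# pads input_ids with 0, which assumes pad_token_id == 0, so warn if it is not.
print(f" Special IDs -> bos: {tokenizer.bos_token_id}, eos: {tokenizer.eos_token_id}, "
      f"mask: {tokenizer.mask_token_id}, pad: {tokenizer.pad_token_id}")
if tokenizer.pad_token_id != 0:
    print(f" Warning: pad_token_id is {tokenizer.pad_token_id}, but collate_fn pads with 0.")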

# ----------------------------
# Prepare model and optimizer
# ----------------------------
# Architecture kwargs are defined once and reused for config.json below, so the
# instantiated model and the saved config cannot drift apart.
arch = dict(
    vocab_size=len(tokenizer),
    max_seq=MAX_SEQ_LEN,
    num_layers=8,
    hidden_size=320,
    intermediate_size=1280,
    num_heads=8,
    kv_groups=2,
    rotary_max_seq=MAX_SEQ_LEN,
    window=16,
    dropout=0.1,
)
model = RougeBERT(**arch)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

optimizer = Ranger21(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
    use_adabelief=True,
    use_warmup=True,
    use_madgrad=True,
    num_epochs=NUM_EPOCHS,
    warmdown_active=False,
    num_batches_per_epoch=optimizer_steps_per_epoch
)
criterion = CrossEntropyLoss(ignore_index=-100)

# ----------------------------
# Save Config
# ----------------------------
config = {
    **arch,
    "pad_token_id": tokenizer.pad_token_id,
    "mask_token_id": tokenizer.mask_token_id,
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "learning_rate": LEARNING_RATE,
    "mask_prob": MASK_PROB,
}
os.makedirs(SAVE_DIR, exist_ok=True)
with open(os.path.join(SAVE_DIR, "config.json"), "w") as f:
    json.dump(config, f, indent=2)
print(f" Config saved to {os.path.join(SAVE_DIR, 'config.json')}")

# ----------------------------
# Early Stopping Setup
# ----------------------------
best_val_loss = float('inf')
patience_counter = 0
PATIENCE = 2  # Stop if no improvement for 2 evals
print(f" Early stopping: patience = {PATIENCE} evaluations")

# ----------------------------
# Training + Validation Loop
# ----------------------------
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    train_loader = get_dataloader(
        csv_file=TRAIN_CSV,
        tokenizer=tokenizer,
        batch_size=BATCH_SIZE,
        max_seq_len=MAX_SEQ_LEN,
        mask_prob=MASK_PROB,
        global_token_ids=global_token_ids
    )
    pbar = tqdm(enumerate(train_loader), total=train_steps_per_epoch,
                desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for step, (input_ids, attention_mask, labels, global_positions) in pbar:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        global_positions = global_positions.to(device)  # keep index tensor on the same device

        logits = model(input_ids, attention_mask=attention_mask, global_positions=global_positions)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss = loss / GRAD_ACCUM
        loss.backward()
        # Undo the accumulation scaling so the logged loss is per-batch
        running_loss += loss.item() * GRAD_ACCUM
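
        # (Optional) Gradient clipping is a common stabilizer for MLM
        # pretraining. A minimal sketch, left disabled since this run does not
        # clip; if enabled it should fire on the same schedule as the
        # optimizer step below:
        # if (step + 1) % GRAD_ACCUM == 0:
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)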
        if (step + 1) % GRAD_ACCUM == 0:
            optimizer.step()
            optimizer.zero_grad()

        # ----------------------------
        # Logging + Eval + Checkpoint + GPU Monitoring
        # ----------------------------
        if (step + 1) % print_every == 0:
            avg_loss = running_loss / print_every
            pbar.set_postfix({"train_loss": f"{avg_loss:.4f}"})

            # Validation
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            val_steps = 0
            sample_inputs = []
            sample_labels = []
            sample_preds = []
            val_loader = get_dataloader(
                csv_file=VAL_CSV,
                tokenizer=tokenizer,
                batch_size=BATCH_SIZE,
                max_seq_len=MAX_SEQ_LEN,
                mask_prob=MASK_PROB,
                global_token_ids=global_token_ids
            )
            with torch.no_grad():
                for vbatch in val_loader:
                    v_input_ids, v_attention_mask, v_labels, v_global_positions = vbatch
                    v_input_ids = v_input_ids.to(device)
                    v_attention_mask = v_attention_mask.to(device)
                    v_labels = v_labels.to(device)
                    v_global_positions = v_global_positions.to(device)
                    v_logits = model(v_input_ids, attention_mask=v_attention_mask,
                                     global_positions=v_global_positions)
                    v_loss = criterion(v_logits.view(-1, v_logits.size(-1)), v_labels.view(-1))
                    val_loss += v_loss.item()
                    v_preds = torch.argmax(v_logits, dim=-1)
                    mask = (v_labels != -100)
                    correct += (v_preds[mask] == v_labels[mask]).sum().item()
                    total += mask.sum().item()
                    if val_steps == 0:
                        sample_inputs.append(v_input_ids.cpu())
                        sample_labels.append(v_labels.cpu())
                        sample_preds.append(v_preds.cpu())
                    val_steps += 1
                    # Evaluate on at most 50 batches (or a quarter of the val set)
                    if val_steps >= min(50, max(1, val_steps_total // 4)):
                        break

            avg_val_loss = val_loss / val_steps if val_steps > 0 else float('inf')
            perplexity = torch.exp(torch.tensor(avg_val_loss)).item() if val_steps > 0 else float('inf')
            accuracy = correct / total if total > 0 else 0.0
            print(f"\n[Step {step+1}] "
                  f"Train Loss: {avg_loss:.4f} | "
                  f"Val Loss: {avg_val_loss:.4f} | "
                  f"Perplexity: {perplexity:.4f} | "
                  f"MLM Acc: {accuracy:.2%}")

            # GPU Monitoring (every eval)
            if torch.cuda.is_available():
                gpu_mem_alloc = torch.cuda.memory_allocated() / 1e9
                gpu_mem_res = torch.cuda.memory_reserved() / 1e9
                print(f"GMEM: {gpu_mem_alloc:.2f}GB / {gpu_mem_res:.2f}GB")

            # Early Stopping Check
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                print(" New best val loss!")
            else:
                patience_counter += 1
                print(f" No improvement. Patience: {patience_counter}/{PATIENCE}")
                if patience_counter >= PATIENCE:
                    print(" Early stopping triggered!")

            # Checkpoint every 10%
            if (step + 1) % checkpoint_every == 0:
                ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_step_{step+1}.bin")
                torch.save(model.state_dict(), ckpt_path)
                print(f" Checkpoint saved: {ckpt_path}")

            # Reset
            running_loss = 0.0
            model.train()

        # Break out of the step loop if early stopping triggered in the eval block
        if patience_counter >= PATIENCE:
            break

    # If early stopped, break the epoch loop too
    if patience_counter >= PATIENCE:
        print(" Stopping training early.")
        break

# ----------------------------
# Final Test Evaluation
# ----------------------------
print(f"\n{'='*50}\nšŸ”¬ FINAL EPOCH TEST EVALUATION\n{'='*50}")
model.eval()
test_loss = 0.0
correct = 0
total = 0
test_steps = 0
sample_inputs = []
sample_labels = []
sample_preds = []
test_loader = get_dataloader(
    csv_file=TEST_CSV,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
    max_seq_len=MAX_SEQ_LEN,
    mask_prob=MASK_PROB,
    global_token_ids=global_token_ids
)
with torch.no_grad():
    for tbatch in test_loader:
        t_input_ids, t_attention_mask, t_labels, t_global_positions = tbatch
        t_input_ids = t_input_ids.to(device)
        t_attention_mask = t_attention_mask.to(device)
        t_labels = t_labels.to(device)
        t_global_positions = t_global_positions.to(device)
        t_logits = model(t_input_ids, attention_mask=t_attention_mask,
                         global_positions=t_global_positions)
        t_loss = criterion(t_logits.view(-1, t_logits.size(-1)), t_labels.view(-1))
        test_loss += t_loss.item()
        t_preds = torch.argmax(t_logits, dim=-1)
        mask = (t_labels != -100)
        correct += (t_preds[mask] == t_labels[mask]).sum().item()
        total += mask.sum().item()
        if test_steps == 0:
            sample_inputs.append(t_input_ids.cpu())
            sample_labels.append(t_labels.cpu())
            sample_preds.append(t_preds.cpu())
        test_steps += 1
        # Cap the test evaluation at 100 batches
        if test_steps >= 100:
            break

avg_test_loss = test_loss / test_steps if test_steps > 0 else float('inf')
perplexity = torch.exp(torch.tensor(avg_test_loss)).item() if test_steps > 0 else float('inf')
accuracy = correct / total if total > 0 else 0.0
print(f" FINAL Epoch {epoch+1} | "
      f"Test Loss: {avg_test_loss:.4f} | "
      f"Perplexity: {perplexity:.4f} | "
      f"MLM Acc: {accuracy:.2%}")

# ----------------------------
# Final Save
# ----------------------------
final_path = os.path.join(SAVE_DIR, "pytorch_model.bin")
torch.save(model.state_dict(), final_path)
print(f" Final model saved to {final_path}")
try:
    tokenizer.save_pretrained(SAVE_DIR)
    print(f" Tokenizer saved to {SAVE_DIR}")
except Exception as e:
    print(f" Could not save tokenizer: {e}")
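
# Reloading the checkpoint later (a minimal sketch, commented out; assumes the
# RougeBERT constructor signature used above and that config.json holds the
# matching architecture keys):
# with open(os.path.join(SAVE_DIR, "config.json")) as f:
#     cfg = json.load(f)
# arch_keys = ["vocab_size", "max_seq", "num_layers", "hidden_size",
#              "intermediate_size", "num_heads", "kv_groups",
#              "rotary_max_seq", "window", "dropout"]
# reloaded = RougeBERT(**{k: cfg[k] for k in arch_keys})
# reloaded.load_state_dict(torch.load(final_path, map_location="cpu"))
# reloaded.eval()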

# ----------------------------
# šŸ” Sample Predictions
# ----------------------------
print("\n" + "=" * 70)
print(" SAMPLE PREDICTIONS (first 3 from the final test evaluation)")
print("=" * 70)

if sample_inputs and len(sample_inputs[0]) > 0:
    input_ids_sample = sample_inputs[0][:3]
    labels_sample = sample_labels[0][:3]
    preds_sample = sample_preds[0][:3]
    for i in range(len(input_ids_sample)):
        input_toks = input_ids_sample[i].tolist()
        label_toks = labels_sample[i].tolist()
        pred_toks = preds_sample[i].tolist()

        original_tokens = [t for t in label_toks if t != -100]
        original = tokenizer.decode(original_tokens, skip_special_tokens=False)

        masked_tokens = []
        for tok in input_toks:
            if tok == tokenizer.mask_token_id:
                masked_tokens.append("[MASK]")  # make masked positions visible
            elif tok == 0:
                continue  # skip padding
            else:
                masked_tokens.append(tokenizer.decode([tok], skip_special_tokens=False))
        masked_str = "".join(masked_tokens).replace(" ", "")

        pred_filtered = [p for p, l in zip(pred_toks, label_toks) if l != -100]
        predicted = tokenizer.decode(pred_filtered, skip_special_tokens=False)

        print(f"\n Example {i+1}:")
        print(f" Original: {original}")
        print(f" Masked: {masked_str}")
        print(f" Predicted: {predicted}")
else:
    print(" No samples captured for visualization.")

print(f"\n Training complete! Best Val Loss: {best_val_loss:.4f}")
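
# Fill-mask demo with the freshly trained model (a minimal sketch, commented
# out; the example string and masked position are arbitrary and assume the
# chosen position is not a special token):
# model.eval()
# enc = tokenizer("c1ccccc1O", truncation=True, max_length=MAX_SEQ_LEN, return_tensors=None)
# ids = enc["input_ids"]
# pos = len(ids) // 2
# ids[pos] = tokenizer.mask_token_id
# x = torch.tensor([ids], dtype=torch.long, device=device)
# am = torch.tensor([enc["attention_mask"]], dtype=torch.long, device=device)
# gp = torch.tensor([[i for i, t in enumerate(ids) if t in global_token_ids]],
#                   dtype=torch.long, device=device)
# with torch.no_grad():
#     pred_id = model(x, attention_mask=am, global_positions=gp).argmax(-1)[0, pos].item()
# print(" fill-mask prediction:", tokenizer.decode([pred_id]))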