import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from tqdm import tqdm
import pandas as pd
from ranger21 import Ranger21
import random
import os
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
import json

from datasets import load_dataset

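# Streaming MLM dataset: each CSV row's SMILES string is tokenized, corrupted with
# BERT-style 80/10/10 masking (mask / random token / keep), and yielded together with
# its attention mask, labels, and the positions of global-attention tokens.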
class SelfiesStreamingDataset(IterableDataset):
    def __init__(self, csv_file, tokenizer, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        self.global_token_ids = global_token_ids or []

        print(f"Loading dataset (streaming): {csv_file}")
        dataset = load_dataset("csv", data_files=csv_file, split="train", streaming=True)
        dataset = dataset.shuffle(seed=42, buffer_size=10000)
        self.dataset_iter = iter(dataset)

        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id

    def __iter__(self):
        for example in self.dataset_iter:
            smiles = example["SMILES"]
            enc = self.tokenizer(smiles, truncation=True, max_length=self.max_seq_len, return_tensors=None)

            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]
            labels = input_ids.copy()

            # BERT-style corruption: each non-global token is selected with probability
            # mask_prob; of the selected tokens, 80% become the mask token, 10% become a
            # random non-global token, and 10% are kept unchanged. Labels retain the
            # original ids at every position, so the loss is computed over all
            # non-padding tokens rather than only the corrupted ones.
            vocab_size = len(self.tokenizer)
            for i in range(len(input_ids)):
                if input_ids[i] in self.global_token_ids:
                    continue

                if random.random() < self.mask_prob:
                    rand = random.random()
                    if rand < 0.8:
                        input_ids[i] = self.mask_id
                    elif rand < 0.9:
                        input_ids[i] = random.randint(0, vocab_size - 1)
                        # Re-draw if we happened to sample a global/special token.
                        while input_ids[i] in self.global_token_ids:
                            input_ids[i] = random.randint(0, vocab_size - 1)
                    else:
                        pass  # keep the original token

            # Positions of global-attention tokens after corruption.
            global_positions = [idx for idx, tid in enumerate(input_ids) if tid in self.global_token_ids]

            input_ids = torch.tensor(input_ids, dtype=torch.long)
            attention_mask = torch.tensor(attention_mask, dtype=torch.long)
            labels = torch.tensor(labels, dtype=torch.long)

            yield input_ids, attention_mask, labels, global_positions

def collate_fn(batch):
    input_ids_list, attention_mask_list, labels_list, global_positions_list = zip(*batch)

    # Pad to the longest sequence in the batch. Padding value 0 is assumed to be the
    # tokenizer's pad id; labels are padded with -100 so CrossEntropyLoss ignores them.
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask_list, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)

    # Pad the per-sample lists of global-attention positions to equal length with -1.
    max_g = max(len(g) for g in global_positions_list)
    if max_g == 0:
        global_positions = torch.full((len(global_positions_list), 0), -1, dtype=torch.long)
    else:
        padded_global_positions = [
            g + [-1] * (max_g - len(g)) for g in global_positions_list
        ]
        global_positions = torch.tensor(padded_global_positions, dtype=torch.long)

    return input_ids, attention_mask, labels, global_positions

def get_dataloader(csv_file, tokenizer, batch_size=16, max_seq_len=512, mask_prob=0.15, global_token_ids=None):
    dataset = SelfiesStreamingDataset(
        csv_file=csv_file,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        mask_prob=mask_prob,
        global_token_ids=global_token_ids
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn
    )

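# Quick smoke test for the streaming pipeline (illustrative sketch only; the path below is
# hypothetical and a tokenizer exposing mask/pad token ids is assumed, as is a CSV with a
# "SMILES" column):
#
#   loader = get_dataloader("../data/sample.csv", tokenizer, batch_size=2, max_seq_len=64)
#   input_ids, attention_mask, labels, global_positions = next(iter(loader))
#   print(input_ids.shape, labels.shape, global_positions.shape)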

from RougeBERT import RougeBERT
from FastChemTokenizer import FastChemTokenizer

def prepare_train_val_test_split(full_csv, train_csv, val_csv, test_csv, val_test_size=0.3, test_size_ratio=0.5, random_state=42):
    """
    Splits full_csv into train, val, and test CSVs.
    val_test_size: portion of the data reserved for val+test (e.g., 0.3 → 70% train, 30% val+test)
    test_size_ratio: portion of val+test assigned to test (e.g., 0.5 → 15% val, 15% test)
    """
    if all(os.path.exists(f) for f in [train_csv, val_csv, test_csv]):
        print("Train/val/test splits already exist. Skipping split.")
        # Count data rows (subtract 1 for the header line).
        train_count = sum(1 for _ in open(train_csv, encoding='utf-8')) - 1
        val_count = sum(1 for _ in open(val_csv, encoding='utf-8')) - 1
        test_count = sum(1 for _ in open(test_csv, encoding='utf-8')) - 1
        return train_count, val_count, test_count

    print(f"Splits not found. Loading and splitting {full_csv}...")

    df = pd.read_csv(full_csv)
    train_df, val_test_df = train_test_split(df, test_size=val_test_size, random_state=random_state)
    val_df, test_df = train_test_split(val_test_df, test_size=test_size_ratio, random_state=random_state)

    train_df.to_csv(train_csv, index=False)
    val_df.to_csv(val_csv, index=False)
    test_df.to_csv(test_csv, index=False)

    print(f"Saved {len(train_df)} rows to {train_csv}")
    print(f"Saved {len(val_df)} rows to {val_csv}")
    print(f"Saved {len(test_df)} rows to {test_csv}")

    return len(train_df), len(val_df), len(test_df)


BATCH_SIZE = 16
GRAD_ACCUM = 4
NUM_EPOCHS = 1
MAX_SEQ_LEN = 512
LEARNING_RATE = 3e-6
MASK_PROB = 0.25

FULL_CSV = "../data/sample_1k_smi_42.csv"
TRAIN_CSV = "../data/train.csv"
VAL_CSV = "../data/val.csv"
TEST_CSV = "../data/test.csv"
SAVE_DIR = "./trained_rougeberttest"

TRAIN_SET_SIZE, VAL_SET_SIZE, TEST_SET_SIZE = prepare_train_val_test_split(
    FULL_CSV, TRAIN_CSV, VAL_CSV, TEST_CSV, val_test_size=0.2, test_size_ratio=0.5
)

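# With val_test_size=0.2 and test_size_ratio=0.5 the resulting split is
# 80% train / 10% validation / 10% test.
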
train_steps_per_epoch = max(1, TRAIN_SET_SIZE // BATCH_SIZE)
optimizer_steps_per_epoch = max(1, train_steps_per_epoch // GRAD_ACCUM)
val_steps_total = max(1, VAL_SET_SIZE // BATCH_SIZE)
test_steps_total = max(1, TEST_SET_SIZE // BATCH_SIZE)

print(f"Train steps/epoch: {train_steps_per_epoch}")
print(f"Optimizer steps/epoch: {optimizer_steps_per_epoch}")
print(f"Val steps total: {val_steps_total}")
print(f"Test steps total: {test_steps_total}")

print_every = max(1, train_steps_per_epoch // 4)
checkpoint_every = max(1, train_steps_per_epoch // 10)
print(f"Logging every {print_every} steps (25% of epoch)")
print(f"Checkpoint every {checkpoint_every} steps (10% of epoch)")

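# Worked example (assuming FULL_CSV holds roughly 1,000 molecules): the 80/10/10 split
# gives about 800 training rows, so train_steps_per_epoch = 800 // 16 = 50,
# optimizer_steps_per_epoch = 50 // 4 = 12, print_every = 12, and checkpoint_every = 5.
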
tokenizer = FastChemTokenizer.from_pretrained("../smitok")
global_token_ids = [tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.mask_token_id]

model = RougeBERT(
    vocab_size=len(tokenizer),
    max_seq=MAX_SEQ_LEN,
    num_layers=8,
    hidden_size=320,
    intermediate_size=1280,
    num_heads=8,
    kv_groups=2,
    rotary_max_seq=MAX_SEQ_LEN,
    window=16,
    dropout=0.1,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

optimizer = Ranger21(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
    use_adabelief=True,
    use_warmup=True,
    use_madgrad=True,
    num_epochs=NUM_EPOCHS,
    warmdown_active=False,
    num_batches_per_epoch=optimizer_steps_per_epoch
)

criterion = CrossEntropyLoss(ignore_index=-100)

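# Note: num_batches_per_epoch is expressed in optimizer steps (train batches / GRAD_ACCUM),
# matching the accumulation schedule in the loop below; the effective batch size is
# therefore BATCH_SIZE * GRAD_ACCUM = 64 sequences per optimizer step.
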
# Save the model/training configuration so the checkpoint can be reconstructed later.
# These values must mirror the RougeBERT constructor arguments above.
config = {
    "vocab_size": len(tokenizer),
    "max_seq": MAX_SEQ_LEN,
    "num_layers": 8,
    "hidden_size": 320,
    "intermediate_size": 1280,
    "num_heads": 8,
    "kv_groups": 2,
    "rotary_max_seq": MAX_SEQ_LEN,
    "window": 16,
    "dropout": 0.1,
    "pad_token_id": tokenizer.pad_token_id,
    "mask_token_id": tokenizer.mask_token_id,
    "batch_size": BATCH_SIZE,
    "grad_accum": GRAD_ACCUM,
    "learning_rate": LEARNING_RATE,
    "mask_prob": MASK_PROB,
}
os.makedirs(SAVE_DIR, exist_ok=True)
with open(os.path.join(SAVE_DIR, "config.json"), "w") as f:
    json.dump(config, f, indent=2)
print(f"Config saved to {os.path.join(SAVE_DIR, 'config.json')}")

best_val_loss = float('inf')
patience_counter = 0
PATIENCE = 2
print(f"Early stopping: patience = {PATIENCE} evaluations")

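# Training loop: gradients are accumulated over GRAD_ACCUM batches per optimizer step;
# every print_every batches the model is evaluated on a capped number of validation
# batches, early stopping is checked, and a periodic checkpoint may be written before
# training resumes.
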
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()

    train_loader = get_dataloader(
        csv_file=TRAIN_CSV,
        tokenizer=tokenizer,
        batch_size=BATCH_SIZE,
        max_seq_len=MAX_SEQ_LEN,
        mask_prob=MASK_PROB,
        global_token_ids=global_token_ids
    )

    pbar = tqdm(enumerate(train_loader), total=train_steps_per_epoch, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for step, (input_ids, attention_mask, labels, global_positions) in pbar:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        logits = model(input_ids, attention_mask=attention_mask, global_positions=global_positions)
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss = loss / GRAD_ACCUM
        loss.backward()
        running_loss += loss.item()

        if (step + 1) % GRAD_ACCUM == 0:
            optimizer.step()
            optimizer.zero_grad()

        if (step + 1) % print_every == 0:
            avg_loss = running_loss / print_every
            pbar.set_postfix({"train_loss": f"{avg_loss:.4f}"})

            # Periodic validation on a capped number of batches.
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0
            val_steps = 0

            sample_inputs = []
            sample_labels = []
            sample_preds = []

            val_loader = get_dataloader(
                csv_file=VAL_CSV,
                tokenizer=tokenizer,
                batch_size=BATCH_SIZE,
                max_seq_len=MAX_SEQ_LEN,
                mask_prob=MASK_PROB,
                global_token_ids=global_token_ids
            )

            with torch.no_grad():
                for vbatch in val_loader:
                    v_input_ids, v_attention_mask, v_labels, v_global_positions = vbatch
                    v_input_ids = v_input_ids.to(device)
                    v_attention_mask = v_attention_mask.to(device)
                    v_labels = v_labels.to(device)

                    v_logits = model(v_input_ids, attention_mask=v_attention_mask, global_positions=v_global_positions)
                    v_loss = criterion(v_logits.view(-1, v_logits.size(-1)), v_labels.view(-1))
                    val_loss += v_loss.item()

                    # Token-level accuracy over positions with real (non -100) labels.
                    v_preds = torch.argmax(v_logits, dim=-1)
                    mask = (v_labels != -100)
                    correct += (v_preds[mask] == v_labels[mask]).sum().item()
                    total += mask.sum().item()

                    if val_steps == 0:
                        sample_inputs.append(v_input_ids.cpu())
                        sample_labels.append(v_labels.cpu())
                        sample_preds.append(v_preds.cpu())

                    val_steps += 1
                    if val_steps >= min(50, val_steps_total // 4):
                        break

            avg_val_loss = val_loss / val_steps if val_steps > 0 else float('inf')
            perplexity = torch.exp(torch.tensor(avg_val_loss)).item() if val_steps > 0 else float('inf')
            accuracy = correct / total if total > 0 else 0.0

            print(f"\n[Step {step+1}] "
                  f"Train Loss: {avg_loss:.4f} | "
                  f"Val Loss: {avg_val_loss:.4f} | "
                  f"Perplexity: {perplexity:.4f} | "
                  f"MLM Acc: {accuracy:.2%}")

            if torch.cuda.is_available():
                gpu_mem_alloc = torch.cuda.memory_allocated() / 1e9
                gpu_mem_res = torch.cuda.memory_reserved() / 1e9
                print(f"GMEM: {gpu_mem_alloc:.2f}GB / {gpu_mem_res:.2f}GB")

            # Early stopping bookkeeping.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                print("New best val loss!")
            else:
                patience_counter += 1
                print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")
                if patience_counter >= PATIENCE:
                    print("Early stopping triggered!")
                    break

            if (step + 1) % checkpoint_every == 0:
                ckpt_path = os.path.join(SAVE_DIR, f"checkpoint_step_{step+1}.bin")
                torch.save(model.state_dict(), ckpt_path)
                print(f"Checkpoint saved: {ckpt_path}")

            running_loss = 0.0
            model.train()

        if patience_counter >= PATIENCE:
            break

    if patience_counter >= PATIENCE:
        print("Stopping training early.")
        break


print(f"\n{'='*50}\n🔬 FINAL EPOCH TEST EVALUATION\n{'='*50}")
model.eval()
test_loss = 0.0
correct = 0
total = 0
test_steps = 0

sample_inputs = []
sample_labels = []
sample_preds = []

test_loader = get_dataloader(
    csv_file=TEST_CSV,
    tokenizer=tokenizer,
    batch_size=BATCH_SIZE,
    max_seq_len=MAX_SEQ_LEN,
    mask_prob=MASK_PROB,
    global_token_ids=global_token_ids
)

with torch.no_grad():
    for tbatch in test_loader:
        t_input_ids, t_attention_mask, t_labels, t_global_positions = tbatch
        t_input_ids = t_input_ids.to(device)
        t_attention_mask = t_attention_mask.to(device)
        t_labels = t_labels.to(device)

        t_logits = model(t_input_ids, attention_mask=t_attention_mask, global_positions=t_global_positions)
        t_loss = criterion(t_logits.view(-1, t_logits.size(-1)), t_labels.view(-1))
        test_loss += t_loss.item()

        t_preds = torch.argmax(t_logits, dim=-1)
        mask = (t_labels != -100)
        correct += (t_preds[mask] == t_labels[mask]).sum().item()
        total += mask.sum().item()

        if test_steps == 0:
            sample_inputs.append(t_input_ids.cpu())
            sample_labels.append(t_labels.cpu())
            sample_preds.append(t_preds.cpu())

        test_steps += 1
        if test_steps >= 100:
            break

avg_test_loss = test_loss / test_steps if test_steps > 0 else float('inf')
perplexity = torch.exp(torch.tensor(avg_test_loss)).item() if test_steps > 0 else float('inf')
accuracy = correct / total if total > 0 else 0.0

print(f"FINAL Epoch {epoch+1} | "
      f"Test Loss: {avg_test_loss:.4f} | "
      f"Perplexity: {perplexity:.4f} | "
      f"MLM Acc: {accuracy:.2%}")

final_path = os.path.join(SAVE_DIR, "pytorch_model.bin")
torch.save(model.state_dict(), final_path)
print(f"Final model saved to {final_path}")

try:
    tokenizer.save_pretrained(SAVE_DIR)
    print(f"Tokenizer saved to {SAVE_DIR}")
except Exception as e:
    print(f"Could not save tokenizer: {e}")

print("\n" + "="*70)
print("SAMPLE PREDICTIONS (first 3 from the final test batch)")
print("="*70)

if sample_inputs and len(sample_inputs[0]) > 0:
    input_ids_sample = sample_inputs[0][:3]
    labels_sample = sample_labels[0][:3]
    preds_sample = sample_preds[0][:3]

    for i in range(len(input_ids_sample)):
        input_toks = input_ids_sample[i].tolist()
        label_toks = labels_sample[i].tolist()
        pred_toks = preds_sample[i].tolist()

        # Original sequence: every position with a real (non -100) label.
        original_tokens = [t for t in label_toks if t != -100]
        original = tokenizer.decode(original_tokens, skip_special_tokens=False)

        # Corrupted input as seen by the model; padding (id 0 assumed) is skipped.
        masked_tokens = []
        for j, tok in enumerate(input_toks):
            if tok == tokenizer.mask_token_id:
                masked_tokens.append("<mask>")
            elif tok == 0:
                continue
            else:
                decoded = tokenizer.decode([tok], skip_special_tokens=False)
                masked_tokens.append(decoded)
        masked_str = "".join(masked_tokens).replace(" ", "")

        # Model predictions at the same positions as the original sequence.
        pred_filtered = [p for p, l in zip(pred_toks, label_toks) if l != -100]
        predicted = tokenizer.decode(pred_filtered, skip_special_tokens=False)

        print(f"\nExample {i+1}:")
        print(f"  Original:  {original}")
        print(f"  Masked:    {masked_str}")
        print(f"  Predicted: {predicted}")
else:
    print("No samples captured for visualization.")

print(f"\nTraining complete! Best Val Loss: {best_val_loss:.4f}")