# benchmark_roberta.py
# Benchmarks RougeBERT against a RoBERTa baseline on MLM with a 0.15 mask probability.
# Adjust the data, tokenizer, and other settings to fit your use case(s).
# Uses train_eval_utils.py for the training, evaluation, and logging helpers.
# - gbyuvd
from RougeBERTHF import RougeBERTForMaskedLM, RougeBERTConfig
from transformers import RobertaForMaskedLM, RobertaConfig
from FastChemTokenizer import FastChemTokenizer
from train_eval_utils import train_and_eval, prepare_train_val_test_split
# ----------------------------
# Tokenizer
# ----------------------------
tokenizer = FastChemTokenizer.from_pretrained("../smitok")
# ----------------------------
# Training Hyperparams
# ----------------------------
BATCH_SIZE = 16
GRAD_ACCUM = 4
NUM_EPOCHS = 10
MAX_SEQ_LEN = 512
LEARNING_RATE = 1e-5
MASK_PROB = 0.15
FULL_CSV = "../data/sample_1k_smi_42.csv"
TRAIN_CSV = "../data/train.csv"
VAL_CSV = "../data/val.csv"
TEST_CSV = "../data/test.csv"
SAVE_DIR = "./pretrained_rougebert"
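# ----------------------------
# Data split (optional)
# ----------------------------
# The three split CSVs are expected to exist before training. If they don't,
# they can be generated from FULL_CSV with the imported helper. The call below
# is a sketch only: the positional (full_csv, train_csv, val_csv, test_csv)
# signature is an assumption -- check prepare_train_val_test_split in
# train_eval_utils.py and adjust accordingly.
import os
if not all(os.path.exists(p) for p in (TRAIN_CSV, VAL_CSV, TEST_CSV)):
    prepare_train_val_test_split(FULL_CSV, TRAIN_CSV, VAL_CSV, TEST_CSV)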
# ----------------------------
# Helper: pick correct max_seq length
# ----------------------------
def get_seq_len(config, default=512):
    """Return the usable maximum sequence length for a given model config."""
    if hasattr(config, "max_seq"):
        return config.max_seq
    elif hasattr(config, "max_position_embeddings"):
        # RoBERTa-style configs reserve 2 extra position slots for the padding
        # offset, so the usable sequence length is 2 less.
        return config.max_position_embeddings - 2
    return default
# ----------------------------
# RougeBERT Hybrid
# ----------------------------
rouge_config = RougeBERTConfig(vocab_size=len(tokenizer), max_seq=MAX_SEQ_LEN)
rouge_model = RougeBERTForMaskedLM(rouge_config) # 9022400 params
rouge_results = train_and_eval(
    rouge_model,
    tokenizer,
    TRAIN_CSV,
    VAL_CSV,
    TEST_CSV,
    rouge_config,
    run_name="rougebert",
    batch_size=BATCH_SIZE,
    grad_accum=GRAD_ACCUM,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    mask_prob=MASK_PROB,
    max_seq_len=get_seq_len(rouge_config, MAX_SEQ_LEN),
)
# ----------------------------
# RoBERTa baseline
# ----------------------------
roberta_config = RobertaConfig(
    vocab_size=len(tokenizer),
    hidden_size=282,            # hidden dim, sized to roughly match RougeBERT's ~9M params
    intermediate_size=1300,     # FFN size
    num_attention_heads=6,
    num_hidden_layers=8,
    max_position_embeddings=MAX_SEQ_LEN + 2,  # RoBERTa reserves 2 extra positions for the padding offset
    type_vocab_size=1,          # RoBERTa doesn't use token_type_ids
    pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0,
    bos_token_id=tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 0,
    eos_token_id=tokenizer.eos_token_id if hasattr(tokenizer, 'eos_token_id') else 2,
    # Dropout / initialization settings (RoBERTa defaults)
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    initializer_range=0.02,
    layer_norm_eps=1e-5,
)
roberta_model = RobertaForMaskedLM(roberta_config) # 9017590 params
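# Optional sanity check that the two models are roughly parameter-matched
# (the counts noted in the comments above). This assumes both are standard
# PyTorch nn.Modules, which holds for Hugging Face-style models.
def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(f"RougeBERT params:        {count_params(rouge_model):,}")
print(f"RoBERTa baseline params: {count_params(roberta_model):,}")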
roberta_results = train_and_eval(
    roberta_model,
    tokenizer,
    TRAIN_CSV,
    VAL_CSV,
    TEST_CSV,
    roberta_config,
    run_name="roberta_baseline",
    batch_size=BATCH_SIZE,
    grad_accum=GRAD_ACCUM,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    mask_prob=MASK_PROB,
    max_seq_len=get_seq_len(roberta_config, MAX_SEQ_LEN),
)
# ----------------------------
# Print comparison
# ----------------------------
print("RougeBERT:", rouge_results)
print("Vanilla BERT:", bert_results)