ColiFormer-ui / pretrain.py
Genooo12's picture
Deploy Streamlit UI
404d784 verified
"""
File: pretrain.py
-----------------
Pretrain the base transformer model on JSON datasets prepared via
CodonData.prepare_training_data. This is typically not needed for ENCOT
as we use the pretrained CodonTransformer base. See README for setup and usage.
"""
import argparse
import os
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast
from CodonTransformer.CodonUtils import (
MAX_LEN,
NUM_ORGANISMS,
TOKEN2MASK,
IterableJSONData,
)
class MaskedTokenizerCollator:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, examples):
tokenized = self.tokenizer(
[ex["codons"] for ex in examples],
return_attention_mask=True,
return_token_type_ids=True,
truncation=True,
padding=True,
max_length=MAX_LEN,
return_tensors="pt",
)
seq_len = tokenized["input_ids"].shape[-1]
species_index = torch.tensor([[ex["organism"]] for ex in examples])
tokenized["token_type_ids"] = species_index.repeat(1, seq_len)
inputs = tokenized["input_ids"]
targets = inputs.clone()
prob_matrix = torch.full(inputs.shape, 0.15)
prob_matrix[inputs < 5] = 0.0
selected = torch.bernoulli(prob_matrix).bool()
replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected
inputs[replaced] = torch.tensor(
list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy())))
)
randomized = (
torch.bernoulli(torch.full(selected.shape, 0.1)).bool()
& selected
& ~replaced
)
random_idx = torch.randint(26, 90, inputs.shape, dtype=torch.long)
inputs[randomized] = random_idx[randomized]
tokenized["input_ids"] = inputs
tokenized["labels"] = torch.where(selected, targets, -100)
return tokenized
class plTrainHarness(pl.LightningModule):
def __init__(self, model, learning_rate, warmup_fraction):
super().__init__()
self.model = model
self.learning_rate = learning_rate
self.warmup_fraction = warmup_fraction
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.learning_rate,
)
lr_scheduler = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=self.learning_rate,
total_steps=self.trainer.estimated_stepping_batches,
pct_start=self.warmup_fraction,
),
"interval": "step",
"frequency": 1,
}
return [optimizer], [lr_scheduler]
def training_step(self, batch, batch_idx):
self.model.bert.set_attention_type("block_sparse")
outputs = self.model(**batch)
self.log_dict(
dictionary={
"loss": outputs.loss,
"lr": self.trainer.optimizers[0].param_groups[0]["lr"],
},
on_step=True,
prog_bar=True,
)
return outputs.loss
class EpochCheckpoint(pl.Callback):
def __init__(self, checkpoint_dir, save_interval):
super().__init__()
self.checkpoint_dir = checkpoint_dir
self.save_interval = save_interval
def on_train_epoch_end(self, trainer, pl_module):
current_epoch = trainer.current_epoch
if current_epoch % self.save_interval == 0 or current_epoch == 0:
checkpoint_path = os.path.join(
self.checkpoint_dir, f"epoch_{current_epoch}.ckpt"
)
trainer.save_checkpoint(checkpoint_path)
print(f"\nCheckpoint saved at {checkpoint_path}\n")
def main(args):
"""Pretrain the base transformer model."""
pl.seed_everything(args.seed)
torch.set_float32_matmul_precision("medium")
tokenizer = PreTrainedTokenizerFast(
tokenizer_file=args.tokenizer_path,
bos_token="[CLS]",
eos_token="[SEP]",
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
)
config = BigBirdConfig(
vocab_size=len(tokenizer),
type_vocab_size=NUM_ORGANISMS,
sep_token_id=2,
)
model = BigBirdForMaskedLM(config=config)
harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)
train_data = IterableJSONData(args.train_data_path, dist_env="slurm")
data_loader = DataLoader(
dataset=train_data,
collate_fn=MaskedTokenizerCollator(tokenizer),
batch_size=args.batch_size,
num_workers=0 if args.debug else args.num_workers,
persistent_workers=False if args.debug else True,
)
save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval)
trainer = pl.Trainer(
default_root_dir=args.checkpoint_dir,
strategy="ddp_find_unused_parameters_true",
accelerator="gpu",
devices=1 if args.debug else args.num_gpus,
precision="16-mixed",
max_epochs=args.max_epochs,
deterministic=False,
enable_checkpointing=True,
callbacks=[save_checkpoint],
accumulate_grad_batches=args.accumulate_grad_batches,
)
# Pretrain the model
trainer.fit(harnessed_model, data_loader)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Pretrain the base transformer model.")
parser.add_argument(
"--tokenizer_path",
type=str,
required=True,
help="Path to the tokenizer model file",
)
parser.add_argument(
"--train_data_path",
type=str,
required=True,
help="Path to the training data JSON file",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
required=True,
help="Directory where checkpoints will be saved",
)
parser.add_argument(
"--batch_size", type=int, default=6, help="Batch size for training"
)
parser.add_argument(
"--max_epochs", type=int, default=5, help="Maximum number of epochs to train"
)
parser.add_argument(
"--num_workers", type=int, default=5, help="Number of workers for data loading"
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="Number of batches to accumulate gradients",
)
parser.add_argument(
"--num_gpus", type=int, default=16, help="Number of GPUs to use for training"
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Learning rate for the optimizer",
)
parser.add_argument(
"--warmup_fraction",
type=float,
default=0.1,
help="Fraction of total steps to use for warmup",
)
parser.add_argument(
"--save_interval", type=int, default=5, help="Save checkpoint every N epochs"
)
parser.add_argument(
"--seed", type=int, default=123, help="Random seed for reproducibility"
)
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
main(args)