"""
File: pretrain.py
-----------------
Pretrain the base transformer model on JSON datasets prepared via
CodonData.prepare_training_data.

This is typically not needed for ENCOT, which starts from the pretrained
CodonTransformer base model. See the README for setup and usage.
"""

import argparse
import os

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast

from CodonTransformer.CodonUtils import (
    MAX_LEN,
    NUM_ORGANISMS,
    TOKEN2MASK,
    IterableJSONData,
)

# Token ids strictly below this value are never selected for masking.
# NOTE(review): presumably these are the tokenizer's special tokens
# ([UNK]/[CLS]/[SEP]/[PAD]/[MASK]) -- confirm against the tokenizer file.
NUM_SPECIAL_TOKENS = 5

# Half-open id range [low, high) sampled for random-token replacement.
# NOTE(review): presumably the codon tokens of the vocabulary -- confirm.
RANDOM_TOKEN_LOW = 26
RANDOM_TOKEN_HIGH = 90


class MaskedTokenizerCollator:
    """Tokenize a batch of examples and apply BERT-style masked-LM corruption.

    Each example must provide a ``"codons"`` string and an integer
    ``"organism"``. Among non-special tokens, a fraction ``mask_prob`` is
    selected as prediction targets. Of the selected tokens:

    * ``replace_prob`` are swapped for their masked counterpart (``TOKEN2MASK``),
    * ``random_prob`` are replaced with a random token id,
    * the remainder are left unchanged.

    ``labels`` hold the original ids at selected positions and -100 (ignored
    by the Hugging Face masked-LM loss) everywhere else.

    Args:
        tokenizer: Tokenizer called on the list of codon strings.
        mask_prob: Probability that a non-special token is selected.
        replace_prob: Fraction of selected tokens replaced via ``TOKEN2MASK``.
        random_prob: Fraction of selected tokens replaced by a random id.

    Raises:
        ValueError: If a probability is outside [0, 1] or
            ``replace_prob + random_prob > 1``.
    """

    def __init__(self, tokenizer, mask_prob=0.15, replace_prob=0.8, random_prob=0.1):
        for name, value in (
            ("mask_prob", mask_prob),
            ("replace_prob", replace_prob),
            ("random_prob", random_prob),
        ):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {value}")
        if replace_prob + random_prob > 1.0:
            raise ValueError(
                "replace_prob + random_prob must not exceed 1, "
                f"got {replace_prob} + {random_prob}"
            )
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.random_prob = random_prob

    def __call__(self, examples):
        """Collate ``examples`` into a tokenized, masked batch dict of tensors."""
        tokenized = self.tokenizer(
            [ex["codons"] for ex in examples],
            return_attention_mask=True,
            return_token_type_ids=True,
            truncation=True,
            padding=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )

        # Every position of a sequence carries its organism id as token type.
        seq_len = tokenized["input_ids"].shape[-1]
        species_index = torch.tensor([[ex["organism"]] for ex in examples])
        tokenized["token_type_ids"] = species_index.repeat(1, seq_len)

        inputs = tokenized["input_ids"]
        targets = inputs.clone()

        # Choose prediction targets, never selecting special tokens.
        prob_matrix = torch.full(inputs.shape, self.mask_prob)
        prob_matrix[inputs < NUM_SPECIAL_TOKENS] = 0.0
        selected = torch.bernoulli(prob_matrix).bool()

        replaced = (
            torch.bernoulli(torch.full(selected.shape, self.replace_prob)).bool()
            & selected
        )
        # Explicit dtype: with no replaced positions, torch.tensor([]) would be
        # float and index assignment into the long input ids would fail.
        inputs[replaced] = torch.tensor(
            [TOKEN2MASK[token] for token in inputs[replaced].tolist()],
            dtype=inputs.dtype,
        )

        # random_prob is a fraction of *selected* tokens, but it is drawn only
        # among those not already replaced, so use the conditional probability.
        # Drawing random_prob directly would randomize just
        # random_prob * (1 - replace_prob) of them (2% instead of 10%).
        remaining = 1.0 - self.replace_prob
        random_given_remaining = (
            min(1.0, self.random_prob / remaining) if remaining > 0.0 else 0.0
        )
        randomized = (
            torch.bernoulli(torch.full(selected.shape, random_given_remaining)).bool()
            & selected
            & ~replaced
        )
        random_idx = torch.randint(
            RANDOM_TOKEN_LOW, RANDOM_TOKEN_HIGH, inputs.shape, dtype=torch.long
        )
        inputs[randomized] = random_idx[randomized]

        tokenized["input_ids"] = inputs
        tokenized["labels"] = torch.where(selected, targets, -100)
        return tokenized


class plTrainHarness(pl.LightningModule):
    """Lightning wrapper around a Hugging Face masked-LM model.

    Args:
        model: Model called as ``model(**batch)``; its output must expose
            ``.loss``, and it must have ``.bert.set_attention_type``.
        learning_rate: Peak learning rate for the one-cycle schedule.
        warmup_fraction: Fraction of total steps spent ramping up to the peak.
    """

    def __init__(self, model, learning_rate, warmup_fraction):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.warmup_fraction = warmup_fraction

    def configure_optimizers(self):
        """Return AdamW plus a per-step OneCycleLR schedule."""
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
        )
        lr_scheduler = {
            # NOTE(review): total_steps relies on the trainer being able to
            # estimate the number of batches; confirm IterableJSONData supports it.
            "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.learning_rate,
                total_steps=self.trainer.estimated_stepping_batches,
                pct_start=self.warmup_fraction,
            ),
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler]

    def training_step(self, batch, batch_idx):
        """Run one forward pass, log loss and learning rate, and return the loss."""
        # Re-asserted every step so training always uses block-sparse attention.
        self.model.bert.set_attention_type("block_sparse")
        outputs = self.model(**batch)
        self.log_dict(
            dictionary={
                "loss": outputs.loss,
                "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
            },
            on_step=True,
            prog_bar=True,
        )
        return outputs.loss


class EpochCheckpoint(pl.Callback):
    """Save a checkpoint at the end of every ``save_interval``-th epoch.

    Epochs are 0-indexed, so checkpoints are written after epochs
    0, N, 2N, ... as ``epoch_<n>.ckpt`` in ``checkpoint_dir``.

    Raises:
        ValueError: If ``save_interval`` is less than 1.
    """

    def __init__(self, checkpoint_dir, save_interval):
        super().__init__()
        if save_interval < 1:
            raise ValueError(f"save_interval must be >= 1, got {save_interval}")
        self.checkpoint_dir = checkpoint_dir
        self.save_interval = save_interval

    def on_train_epoch_end(self, trainer, pl_module):
        current_epoch = trainer.current_epoch
        if current_epoch % self.save_interval != 0:
            return
        checkpoint_path = os.path.join(
            self.checkpoint_dir, f"epoch_{current_epoch}.ckpt"
        )
        # save_checkpoint is called on every rank (Lightning coordinates the
        # write under distributed strategies); only announce it once.
        trainer.save_checkpoint(checkpoint_path)
        if trainer.is_global_zero:
            print(f"\nCheckpoint saved at {checkpoint_path}\n")


def main(args):
    """Pretrain the base transformer model."""
    pl.seed_everything(args.seed)
    torch.set_float32_matmul_precision("medium")

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=args.tokenizer_path,
        bos_token="[CLS]",
        eos_token="[SEP]",
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
    )
    config = BigBirdConfig(
        vocab_size=len(tokenizer),
        type_vocab_size=NUM_ORGANISMS,
        # NOTE(review): assumed to be the id of "[SEP]" in the tokenizer file.
        sep_token_id=2,
    )
    model = BigBirdForMaskedLM(config=config)
    harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction)

    train_data = IterableJSONData(args.train_data_path, dist_env="slurm")
    num_workers = 0 if args.debug else args.num_workers
    data_loader = DataLoader(
        dataset=train_data,
        collate_fn=MaskedTokenizerCollator(tokenizer),
        batch_size=args.batch_size,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval)
    trainer = pl.Trainer(
        default_root_dir=args.checkpoint_dir,
        strategy="ddp_find_unused_parameters_true",
        accelerator="gpu",
        devices=1 if args.debug else args.num_gpus,
        precision="16-mixed",
        max_epochs=args.max_epochs,
        deterministic=False,
        enable_checkpointing=True,
        callbacks=[save_checkpoint],
        accumulate_grad_batches=args.accumulate_grad_batches,
    )

    # Pretrain the model
    trainer.fit(harnessed_model, data_loader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Pretrain the base transformer model.")
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
        help="Path to the tokenizer model file",
    )
    parser.add_argument(
        "--train_data_path",
        type=str,
        required=True,
        help="Path to the training data JSON file",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        required=True,
        help="Directory where checkpoints will be saved",
    )
    parser.add_argument(
        "--batch_size", type=int, default=6, help="Batch size for training"
    )
    parser.add_argument(
        "--max_epochs", type=int, default=5, help="Maximum number of epochs to train"
    )
    parser.add_argument(
        "--num_workers", type=int, default=5, help="Number of workers for data loading"
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of batches to accumulate gradients",
    )
    parser.add_argument(
        "--num_gpus", type=int, default=16, help="Number of GPUs to use for training"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Learning rate for the optimizer",
    )
    parser.add_argument(
        "--warmup_fraction",
        type=float,
        default=0.1,
        help="Fraction of total steps to use for warmup",
    )
    parser.add_argument(
        "--save_interval", type=int, default=5, help="Save checkpoint every N epochs"
    )
    parser.add_argument(
        "--seed", type=int, default=123, help="Random seed for reproducibility"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    args = parser.parse_args()

    main(args)