Spaces:
Sleeping
Sleeping
| """ | |
| File: pretrain.py | |
| ----------------- | |
| Pretrain the base transformer model on JSON datasets prepared via | |
| CodonData.prepare_training_data. This is typically not needed for ENCOT | |
| as we use the pretrained CodonTransformer base. See README for setup and usage. | |
| """ | |
| import argparse | |
| import os | |
| import pytorch_lightning as pl | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from transformers import BigBirdConfig, BigBirdForMaskedLM, PreTrainedTokenizerFast | |
| from CodonTransformer.CodonUtils import ( | |
| MAX_LEN, | |
| NUM_ORGANISMS, | |
| TOKEN2MASK, | |
| IterableJSONData, | |
| ) | |
| class MaskedTokenizerCollator: | |
| def __init__(self, tokenizer): | |
| self.tokenizer = tokenizer | |
| def __call__(self, examples): | |
| tokenized = self.tokenizer( | |
| [ex["codons"] for ex in examples], | |
| return_attention_mask=True, | |
| return_token_type_ids=True, | |
| truncation=True, | |
| padding=True, | |
| max_length=MAX_LEN, | |
| return_tensors="pt", | |
| ) | |
| seq_len = tokenized["input_ids"].shape[-1] | |
| species_index = torch.tensor([[ex["organism"]] for ex in examples]) | |
| tokenized["token_type_ids"] = species_index.repeat(1, seq_len) | |
| inputs = tokenized["input_ids"] | |
| targets = inputs.clone() | |
| prob_matrix = torch.full(inputs.shape, 0.15) | |
| prob_matrix[inputs < 5] = 0.0 | |
| selected = torch.bernoulli(prob_matrix).bool() | |
| replaced = torch.bernoulli(torch.full(selected.shape, 0.8)).bool() & selected | |
| inputs[replaced] = torch.tensor( | |
| list((map(TOKEN2MASK.__getitem__, inputs[replaced].numpy()))) | |
| ) | |
| randomized = ( | |
| torch.bernoulli(torch.full(selected.shape, 0.1)).bool() | |
| & selected | |
| & ~replaced | |
| ) | |
| random_idx = torch.randint(26, 90, inputs.shape, dtype=torch.long) | |
| inputs[randomized] = random_idx[randomized] | |
| tokenized["input_ids"] = inputs | |
| tokenized["labels"] = torch.where(selected, targets, -100) | |
| return tokenized | |
| class plTrainHarness(pl.LightningModule): | |
| def __init__(self, model, learning_rate, warmup_fraction): | |
| super().__init__() | |
| self.model = model | |
| self.learning_rate = learning_rate | |
| self.warmup_fraction = warmup_fraction | |
| def configure_optimizers(self): | |
| optimizer = torch.optim.AdamW( | |
| self.model.parameters(), | |
| lr=self.learning_rate, | |
| ) | |
| lr_scheduler = { | |
| "scheduler": torch.optim.lr_scheduler.OneCycleLR( | |
| optimizer, | |
| max_lr=self.learning_rate, | |
| total_steps=self.trainer.estimated_stepping_batches, | |
| pct_start=self.warmup_fraction, | |
| ), | |
| "interval": "step", | |
| "frequency": 1, | |
| } | |
| return [optimizer], [lr_scheduler] | |
| def training_step(self, batch, batch_idx): | |
| self.model.bert.set_attention_type("block_sparse") | |
| outputs = self.model(**batch) | |
| self.log_dict( | |
| dictionary={ | |
| "loss": outputs.loss, | |
| "lr": self.trainer.optimizers[0].param_groups[0]["lr"], | |
| }, | |
| on_step=True, | |
| prog_bar=True, | |
| ) | |
| return outputs.loss | |
| class EpochCheckpoint(pl.Callback): | |
| def __init__(self, checkpoint_dir, save_interval): | |
| super().__init__() | |
| self.checkpoint_dir = checkpoint_dir | |
| self.save_interval = save_interval | |
| def on_train_epoch_end(self, trainer, pl_module): | |
| current_epoch = trainer.current_epoch | |
| if current_epoch % self.save_interval == 0 or current_epoch == 0: | |
| checkpoint_path = os.path.join( | |
| self.checkpoint_dir, f"epoch_{current_epoch}.ckpt" | |
| ) | |
| trainer.save_checkpoint(checkpoint_path) | |
| print(f"\nCheckpoint saved at {checkpoint_path}\n") | |
| def main(args): | |
| """Pretrain the base transformer model.""" | |
| pl.seed_everything(args.seed) | |
| torch.set_float32_matmul_precision("medium") | |
| tokenizer = PreTrainedTokenizerFast( | |
| tokenizer_file=args.tokenizer_path, | |
| bos_token="[CLS]", | |
| eos_token="[SEP]", | |
| unk_token="[UNK]", | |
| sep_token="[SEP]", | |
| pad_token="[PAD]", | |
| cls_token="[CLS]", | |
| mask_token="[MASK]", | |
| ) | |
| config = BigBirdConfig( | |
| vocab_size=len(tokenizer), | |
| type_vocab_size=NUM_ORGANISMS, | |
| sep_token_id=2, | |
| ) | |
| model = BigBirdForMaskedLM(config=config) | |
| harnessed_model = plTrainHarness(model, args.learning_rate, args.warmup_fraction) | |
| train_data = IterableJSONData(args.train_data_path, dist_env="slurm") | |
| data_loader = DataLoader( | |
| dataset=train_data, | |
| collate_fn=MaskedTokenizerCollator(tokenizer), | |
| batch_size=args.batch_size, | |
| num_workers=0 if args.debug else args.num_workers, | |
| persistent_workers=False if args.debug else True, | |
| ) | |
| save_checkpoint = EpochCheckpoint(args.checkpoint_dir, args.save_interval) | |
| trainer = pl.Trainer( | |
| default_root_dir=args.checkpoint_dir, | |
| strategy="ddp_find_unused_parameters_true", | |
| accelerator="gpu", | |
| devices=1 if args.debug else args.num_gpus, | |
| precision="16-mixed", | |
| max_epochs=args.max_epochs, | |
| deterministic=False, | |
| enable_checkpointing=True, | |
| callbacks=[save_checkpoint], | |
| accumulate_grad_batches=args.accumulate_grad_batches, | |
| ) | |
| # Pretrain the model | |
| trainer.fit(harnessed_model, data_loader) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Pretrain the base transformer model.") | |
| parser.add_argument( | |
| "--tokenizer_path", | |
| type=str, | |
| required=True, | |
| help="Path to the tokenizer model file", | |
| ) | |
| parser.add_argument( | |
| "--train_data_path", | |
| type=str, | |
| required=True, | |
| help="Path to the training data JSON file", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint_dir", | |
| type=str, | |
| required=True, | |
| help="Directory where checkpoints will be saved", | |
| ) | |
| parser.add_argument( | |
| "--batch_size", type=int, default=6, help="Batch size for training" | |
| ) | |
| parser.add_argument( | |
| "--max_epochs", type=int, default=5, help="Maximum number of epochs to train" | |
| ) | |
| parser.add_argument( | |
| "--num_workers", type=int, default=5, help="Number of workers for data loading" | |
| ) | |
| parser.add_argument( | |
| "--accumulate_grad_batches", | |
| type=int, | |
| default=1, | |
| help="Number of batches to accumulate gradients", | |
| ) | |
| parser.add_argument( | |
| "--num_gpus", type=int, default=16, help="Number of GPUs to use for training" | |
| ) | |
| parser.add_argument( | |
| "--learning_rate", | |
| type=float, | |
| default=5e-5, | |
| help="Learning rate for the optimizer", | |
| ) | |
| parser.add_argument( | |
| "--warmup_fraction", | |
| type=float, | |
| default=0.1, | |
| help="Fraction of total steps to use for warmup", | |
| ) | |
| parser.add_argument( | |
| "--save_interval", type=int, default=5, help="Save checkpoint every N epochs" | |
| ) | |
| parser.add_argument( | |
| "--seed", type=int, default=123, help="Random seed for reproducibility" | |
| ) | |
| parser.add_argument("--debug", action="store_true", help="Enable debug mode") | |
| args = parser.parse_args() | |
| main(args) | |