import datetime
import os
import pickle
import random
from typing import Tuple, Union

import numpy as np
import pytz
import torch
import torch.nn.functional as F
from datasets import load_from_disk
from torch import Tensor
from transformers import BertConfig, BertModel, BertPreTrainedModel, TrainingArguments
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import BertPredictionHeadTransform

from geneformer import GeneformerPretrainer

# Distributed-training environment settings (read when NCCL/Open MPI initialize).
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"

# Seed all RNGs for reproducibility. The Python/NumPy generators and the torch
# generators are given different seeds (0 and 42), as in the original configuration.
seed_num = 0
random.seed(seed_num)
np.random.seed(seed_num)
seed_val = 42
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

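# Optional hardening (an assumption, not part of the original run): pinning
# cuDNN's algorithm selection makes results more repeatable across runs, at
# some cost in speed. Uncomment to enable.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
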
timezone = pytz.timezone("Asia/Riyadh")
rootdir = os.path.join(os.getcwd(), "Self_train")

# Load the gene-to-token-id dictionary produced when the corpus was tokenized.
corpus_dir = "Pretrain_data"
with open(os.path.join(corpus_dir, "token_dictionary.pkl"), "rb") as fp:
    token_dictionary = pickle.load(fp)

len_vocabulary = len(token_dictionary)

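# Sanity check (assumption: the dictionary maps tokens such as "<pad>" and
# "<mask>" to integer ids, as in Geneformer's published token dictionary);
# fail fast if the special tokens are missing.
for special_token in ("<pad>", "<mask>"):
    if special_token not in token_dictionary:
        raise KeyError(f"token dictionary is missing {special_token!r}")
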

class CustomBertForMaskedLM(BertPreTrainedModel):
    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.transform = BertPredictionHeadTransform(config)

        # MLM head: hidden states -> vocabulary logits, with a standalone bias so
        # the decoder matrix itself can be tied to the input embeddings.
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        # post_init() runs weight initialization and calls tie_weights(), so no
        # separate init_weights()/tie_weights() calls are needed.
        self.post_init()

    def tie_weights(self):
        """Tie the decoder weights to the input word embeddings."""
        self.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Rescale probabilities so a gene accounted for at an earlier rank
        position is downweighted at later positions.

        Geneformer inputs are rank-ordered gene sequences without repetition, so
        each token's probability at position i is multiplied by the running
        product of (1 - p) over positions < i, with genes known from the unmasked
        input treated as certain. Known positions themselves are zeroed by this
        pass and restored afterwards by assign_known_gene_probs().
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Positions labeled -100 are unmasked inputs whose gene is known.
        non_mask = labels[:, :input_seq_length] == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)
        known_gene_indices = input_ids[non_mask]

        # Seed a virtual position before the sequence that assigns probability 1
        # to every known gene, then shift all probabilities right by one position.
        zeros = torch.zeros((batch_size, 1, vocab_size), device=device)
        zeros[non_mask_indices[0], 0, known_gene_indices] = 1.0
        probs_shifted = torch.cat((zeros, probs[:, :-1, :]), dim=1)
        inv_probs_shifted = 1 - probs_shifted

        # Running product of "not yet emitted" probabilities per token.
        cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
        modified_probs = probs * cumprod_inv_probs

        # Renormalize; the clamp guards against division by zero.
        normalizer = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
        return modified_probs / normalizer

    def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Clamp unmasked positions to a one-hot distribution on their known gene.

        Positions labeled -100 were not masked, so their input gene is certain;
        masked positions keep their predicted distribution.
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Labels may extend beyond the input length; align them first.
        truncated_labels = labels[:, :input_seq_length]
        non_mask = truncated_labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)

        # `ones` zeroes out predictions at known positions; `zeros` re-inserts a
        # one-hot on the known gene at those same positions.
        ones = torch.ones((batch_size, seq_length, vocab_size), device=device)
        zeros = torch.zeros((batch_size, seq_length, vocab_size), device=device)

        known_gene_indices = input_ids[non_mask]
        ones[non_mask_indices[0], non_mask_indices[1], :] = 0.0
        zeros[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0

        modified_probs = probs * ones + zeros

        # Renormalize; the clamp guards against division by zero.
        return modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)

    def compute_similarity_on_probs(self, probs: Tensor, labels: Tensor) -> Tensor:
        """Average pairwise cosine similarity between predicted distributions.

        Serves as an auxiliary loss term penalizing masked positions for
        predicting overlapping distributions, since each gene should appear at
        most once in a rank-ordered sequence.

        Args:
            probs: Probability tensor of shape (batch_size, seq_length, vocab_size).
            labels: Label tensor; positions equal to -100 are unmasked and excluded.

        Returns:
            Scalar tensor holding the mean similarity over valid position pairs.
        """
        batch_size, seq_length, vocab_size = probs.size()
        device = probs.device

        # Exclude rows corresponding to unmasked positions from the comparison.
        non_mask = labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)
        mask_sim = torch.ones((batch_size, seq_length, seq_length), device=device)
        mask_sim[non_mask_indices[0], non_mask_indices[1], :] = 0.0

        # Keep only the strict upper triangle so each pair is counted once.
        seq_mask = torch.triu(torch.ones(seq_length, seq_length, device=device), diagonal=1)
        batch_mask = seq_mask.unsqueeze(0).expand(batch_size, seq_length, seq_length)
        mask_sim = mask_sim * batch_mask

        # Cosine similarity between every pair of per-position distributions.
        probs_norm = F.normalize(probs, dim=-1)
        similarities = torch.einsum("biv,bjv->bij", probs_norm, probs_norm)

        valid_similarities = similarities * mask_sim
        total_similarity = valid_similarities.sum()
        total_comparisons = mask_sim.sum().item()

        if total_comparisons == 0:
            return torch.tensor(0.0, device=device)

        return total_similarity / total_comparisons

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        token_type_ids: Tensor | None = None,
        position_ids: Tensor | None = None,
        head_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        encoder_hidden_states: Tensor | None = None,
        encoder_attention_mask: Tensor | None = None,
        labels: Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_transform = self.transform(hidden_states)
        logits = self.decoder(hidden_transform) + self.bias
        probs = F.softmax(logits, dim=-1)

        masked_lm_loss = None
        if labels is not None:
            # Clamp known positions, apply the no-repetition rescaling, then
            # clamp again because the rescaling zeroes out the known positions.
            probs = self.assign_known_gene_probs(probs, input_ids, labels)
            convert_probs = self.probability_convert(probs, input_ids, labels)
            assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)

            # Negative log-likelihood over masked positions only. Labels of -100
            # are clamped to a valid index for gathering; the mask removes them.
            probs_flat = assigned_probs.view(-1, self.config.vocab_size)
            labels_flat = labels.view(-1)
            mask = (labels != -100).float().view(-1)
            row_idx = torch.arange(labels_flat.size(0), device=probs_flat.device)
            gathered = probs_flat[row_idx, labels_flat.clamp(min=0)]
            masked_lm_loss = -torch.log(torch.clamp(gathered, min=1e-18)) * mask
            masked_lm_loss = masked_lm_loss.sum() / mask.sum().clamp(min=1.0)

            # Auxiliary term discouraging masked positions from predicting the
            # same genes as other positions in the sequence.
            similarity_loss = self.compute_similarity_on_probs(assigned_probs, labels)
            lambda_similarity = 1.0
            masked_lm_loss = masked_lm_loss + lambda_similarity * similarity_loss
        else:
            # Without labels the known-gene adjustments are undefined, so return
            # the raw softmax probabilities.
            assigned_probs = probs

        if not return_dict:
            output = (assigned_probs,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        # Note: the `logits` field carries the adjusted probabilities, not raw logits.
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=assigned_probs,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append a dummy PAD token for generation, mirroring the behavior of
        Hugging Face's BertForMaskedLM."""
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

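
# The two sketches below are illustrative only (they are never called by this
# script) and show, on toy tensors, how the probability adjustments behave.
def _demo_probability_convert(model: CustomBertForMaskedLM) -> Tensor:
    """Position 0 holds a known gene (label -100); position 1 is masked. After
    rescaling, the known gene's probability at position 1 collapses, and
    position 0 is zeroed (forward() restores it via assign_known_gene_probs)."""
    probs = torch.tensor([[[0.0, 1.0, 0.0, 0.0],    # position 0: certain of gene 1
                           [0.1, 0.6, 0.2, 0.1]]])  # position 1: gene 1 still favored
    input_ids = torch.tensor([[1, 3]])
    labels = torch.tensor([[-100, 3]])
    # Expected row for position 1 after renormalization: [0.25, 0.0, 0.5, 0.25]
    return model.probability_convert(probs, input_ids, labels)


def _demo_similarity_term(model: CustomBertForMaskedLM) -> Tensor:
    """Two masked positions predicting identical distributions yield similarity
    1.0, the maximum penalty from the auxiliary term."""
    probs = torch.tensor([[[0.5, 0.5, 0.0, 0.0],
                           [0.5, 0.5, 0.0, 0.0]]])
    labels = torch.tensor([[2, 3]])  # both positions masked (labels != -100)
    return model.compute_similarity_on_probs(probs, labels)
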

# Model hyperparameters.
model_type = "bert"
max_input_size = 2**11  # 2048
num_layers = 6
num_attn_heads = 4
num_embed_dim = 256
intermed_size = num_embed_dim * 2
activ_fn = "relu"
initializer_range = 0.02
layer_norm_eps = 1e-12
attention_probs_dropout_prob = 0.02
hidden_dropout_prob = 0.02

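# Quick consistency check: BERT requires the hidden size to divide evenly
# across attention heads (here 256 / 4 = 64 dimensions per head).
assert num_embed_dim % num_attn_heads == 0, "hidden size must be divisible by the number of heads"
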
# Training hyperparameters.
num_examples = 27_406_208  # examples in the pretraining corpus
num_gpus = 8
geneformer_batch_size = 8  # per-device batch size
max_lr = 1e-3
lr_schedule_fn = "linear"
warmup_steps = 10_000
epochs = 3
optimizer = "adamw"
weight_decay = 0.001

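# Informational (assuming single-node data-parallel training): effective batch
# size and the resulting optimizer steps per epoch under the settings above.
effective_batch_size = geneformer_batch_size * num_gpus  # 64
steps_per_epoch = num_examples // effective_batch_size  # 428,222
print(f"Effective batch size: {effective_batch_size}; steps per epoch: {steps_per_epoch}")
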
# Name the run and lay out the output directories.
current_date = datetime.datetime.now(tz=timezone)
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}_{current_date.strftime('%X').replace(':', '')}"
run_name = f"GF_CAB_{datestamp}_L{num_layers}_emb{num_embed_dim}_SL{max_input_size}_E{epochs}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_O{optimizer}_DS{num_gpus}"
training_output_dir = f"{rootdir}/models/{run_name}/"
logging_dir = f"{rootdir}/runs/{run_name}/"
model_output_dir = os.path.join(training_output_dir, "models/")

# Refuse to overwrite a previously saved model.
model_output_file = os.path.join(model_output_dir, "pytorch_model.bin")
if os.path.isfile(model_output_file):
    raise FileExistsError("Model already saved to this directory.")

os.makedirs(training_output_dir, exist_ok=True)
os.makedirs(model_output_dir, exist_ok=True)

# Assemble the BERT configuration from the hyperparameters above.
model_config = {
    "hidden_size": num_embed_dim,
    "num_hidden_layers": num_layers,
    "initializer_range": initializer_range,
    "layer_norm_eps": layer_norm_eps,
    "attention_probs_dropout_prob": attention_probs_dropout_prob,
    "hidden_dropout_prob": hidden_dropout_prob,
    "intermediate_size": intermed_size,
    "hidden_act": activ_fn,
    "max_position_embeddings": max_input_size,
    "model_type": model_type,
    "num_attention_heads": num_attn_heads,
    "pad_token_id": token_dictionary.get("<pad>"),
    "vocab_size": len_vocabulary,
}

config = BertConfig(**model_config)
model = CustomBertForMaskedLM(config)
model = model.train()

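# Sanity checks: the decoder must share storage with the input embeddings after
# weight tying, and the parameter count gives a quick size estimate of the run.
assert model.decoder.weight.data_ptr() == model.bert.embeddings.word_embeddings.weight.data_ptr()
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")
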
# Trainer configuration. save_steps is cast to int (TrainingArguments expects an
# integer step count when it is greater than 1); the chosen interval is roughly
# one checkpoint per epoch of data-parallel training.
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "do_eval": False,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": weight_decay,
    "per_device_train_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "save_strategy": "steps",
    "save_steps": int(np.floor(num_examples / geneformer_batch_size / num_gpus)),
    "logging_steps": 1000,
    "output_dir": training_output_dir,
    "logging_dir": logging_dir,
}
training_args = TrainingArguments(**training_args)

| print("Starting training.") |
|
|
| |
| trainer = GeneformerPretrainer( |
| model=model, |
| args=training_args, |
| train_dataset=load_from_disk("Pretrain_data/genecorpus_30M_2048.dataset"), |
| example_lengths_file="Pretrain_data/genecorpus_30M_2048_lengths.pkl", |
| token_dictionary=token_dictionary, |
| ) |
|
|
| |
| trainer.train() |
| |
| trainer.save_model(model_output_dir) |
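
# Optional follow-up (an assumption for convenience, not required by the
# Trainer): keep a copy of the token dictionary next to the saved weights so
# the checkpoint is self-contained for downstream fine-tuning scripts.
with open(os.path.join(model_output_dir, "token_dictionary.pkl"), "wb") as fp:
    pickle.dump(token_dictionary, fp)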