|
|
import numpy as np |
|
|
from datasets import load_from_disk |
|
|
import torch |
|
|
from transformers import BertForMaskedLM |
|
|
import os |
|
|
import sys |
|
|
from tqdm.notebook import tqdm |
|
|
import seaborn as sns |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
|
from geneformer.pretrainer import token_dictionary |
|
|
|
|
|
import datetime |
|
|
|
|
|
|
|
|
import os |
|
|
import time |
|
|
|
|
|
# Process-wide environment configuration for the distributed / GPU stack.
# These are plain environment variables; the libraries that consume them
# read them when they initialize.

# Ask NCCL to emit informational debug logging.
os.environ["NCCL_DEBUG"] = "INFO"

# Open MPI MCA parameter; presumably enables CUDA-aware support in the
# MPI runtime used on this cluster -- confirm against the MPI build.
os.environ["OMPI_MCA_opal_cuda_support"] = "true"

# Overrides the glibc version conda assumes for its virtual package.
# NOTE(review): "2.56" is an unusual glibc version -- confirm it is intended.
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
|
|
|
|
|
import pickle |
|
|
import random |
|
|
import subprocess |
|
|
|
|
|
import numpy as np |
|
|
import pytz |
|
|
import torch |
|
|
from datasets import load_from_disk, Dataset |
|
|
from transformers import BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, Trainer, BertModel, BertPreTrainedModel |
|
|
|
|
|
from geneformer import GeneformerPretrainer |
|
|
|
|
|
from typing import Tuple |
|
|
from torch import Tensor |
|
|
from transformers.modeling_outputs import MaskedLMOutput |
|
|
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform |
|
|
from transformers.activations import ACT2FN |
|
|
from typing import List, Optional, Tuple, Union |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
|
|
# Build a fixed random subset of the corpus and split it into a training
# dataset and a held-out test dataset, each containing only "input_ids".
genecorpus = load_from_disk("/ibex/user/chenj0i/Geneformer/Genecorpus-30M/genecorpus_30M_2048.dataset")

# Total number of examples sampled; the last `test_size` of them form the
# test split and the remainder form the training split.
subset_size = 1_200_000
test_size = 200_000

# Deterministic shuffle (seed=42), then keep the first `subset_size` rows.
# `range` is passed directly as the index iterable: the previous
# generator over tqdm(list(range(...))) only built a throwaway list and a
# progress bar over index generation, and selected the same rows.
subset_sequences = genecorpus.shuffle(seed=42).select(range(subset_size))["input_ids"]

# Training split: everything except the trailing `test_size` examples.
subset_train_dataset = Dataset.from_dict({"input_ids": subset_sequences[:-test_size]})
subset_train_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_1Mtrain_genecorpus.dataset")

# Test split: the trailing `test_size` examples.
subset_test_dataset = Dataset.from_dict({"input_ids": subset_sequences[-test_size:]})
subset_test_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_200K_1Mtrain_genecorpus.dataset")
|
|
|
|
|
|
|
|
|
|
|
# Write a pickled list of per-example lengths for the training subset.
# NOTE(review): every entry is the constant 2048, which assumes all
# training examples have exactly that length -- confirm against the data.
value_to_repeat = 2048

# Number of entries in the list. Presumably this must equal the number of
# rows in the training subset (1_200_000 sampled minus 200_000 held out);
# keep it in sync if the split sizes change.
total_elements = 1_000_000

data_list = [value_to_repeat] * total_elements

# Relative path: the file is created in the current working directory.
output_file = "sub_1Mtrain_genecorpus_30M_2048_lengths.pkl"

with open(output_file, "wb") as f:
    pickle.dump(data_list, f)

# Report the number of entries actually written. The message previously
# interpolated `subset_size` (the total sampled), not this list's length.
print(f"List with {len(data_list)} elements saved as {output_file}")
|
|
|
|
|
# Write a pickled list of per-example lengths for the held-out test subset.
# NOTE(review): every entry is the constant 2048, which assumes all test
# examples have exactly that length -- confirm against the data.
value_to_repeat_test = 2048

# Number of entries in the list. Presumably this must equal the number of
# rows in the test subset (the 200_000 held-out examples); keep in sync.
total_elements_test = 200_000

data_list_test = [value_to_repeat_test] * total_elements_test

# Relative path: the file is created in the current working directory.
output_file_test = "sub_200K_1Mtrain_genecorpus_30M_2048_lengths.pkl"

with open(output_file_test, "wb") as f:
    pickle.dump(data_list_test, f)

# Report the number of entries actually written. The message previously
# interpolated `subset_size` (the total sampled), not this list's length.
print(f"List with {len(data_list_test)} elements saved as {output_file_test}")