# GF-CAB / geneformer / Dataset_Create.py
# JFLa's picture
# Upload 14 files
# f22cc6f verified
import numpy as np
from datasets import load_from_disk
import torch
from transformers import BertForMaskedLM
import os
import sys
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
# sys.path.append('/Users/chenj0i/Desktop/Lab Work/Geneformer')
from geneformer.pretrainer import token_dictionary
import datetime
# imports
import os
import time
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["OMPI_MCA_opal_cuda_support"] = "true"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
import pickle
import random
import subprocess
import numpy as np
import pytz
import torch
from datasets import load_from_disk, Dataset
from transformers import BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, Trainer, BertModel, BertPreTrainedModel
from geneformer import GeneformerPretrainer
from typing import Tuple
from torch import Tensor
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
from transformers.activations import ACT2FN
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
# Draw a reproducible random subset of 1,200,000 sequences from the Genecorpus
# dataset on disk and split it: the first 1,000,000 sequences form the training
# set and the last 200,000 form the test set.
genecorpus = load_from_disk("/ibex/user/chenj0i/Geneformer/Genecorpus-30M/genecorpus_30M_2048.dataset")
subset_size = 1_200_000
holdout_size = 200_000  # trailing sequences of the subset reserved for the test split

# Shuffle with a fixed seed so the same rows are drawn on every run, then keep
# the first `subset_size` rows. A plain `range` is enough for `select`; there is
# no need to materialize a list of indices or wrap it in a notebook progress bar.
# NOTE(review): reading the whole "input_ids" column presumably loads every
# selected sequence into memory at once -- confirm this fits on the target node.
subset_sequences = genecorpus.shuffle(seed=42).select(range(subset_size))['input_ids']

# Only the "input_ids" column is carried over into the saved datasets.
subset_train_dataset = Dataset.from_dict({"input_ids": subset_sequences[:-holdout_size]})
subset_train_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_1Mtrain_genecorpus.dataset")
subset_test_dataset = Dataset.from_dict({"input_ids": subset_sequences[-holdout_size:]})
subset_test_dataset.save_to_disk("/ibex/user/chenj0i/Geneformer/subset_200K_1Mtrain_genecorpus.dataset")
# Create the length file for the training split: a pickled list holding one
# length entry per training sequence.
# NOTE(review): every entry is hard-coded to 2048 rather than measured from the
# data; this assumes all training sequences are exactly 2048 tokens long --
# confirm against the saved training dataset.
value_to_repeat = 2048
# Number of entries; must match the number of sequences in the training split.
total_elements = 1_000_000
data_list = [value_to_repeat] * total_elements
# The .pkl length file is written relative to the current working directory.
output_file = "sub_1Mtrain_genecorpus_30M_2048_lengths.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(data_list, f)
# Report the number of entries actually written (previously this printed the
# size of the full 1.2M subset instead of the length of this list).
print(f"List with {len(data_list)} elements saved as {output_file}")
# Create the length file for the test split: a pickled list holding one length
# entry per test sequence.
# NOTE(review): every entry is hard-coded to 2048 rather than measured from the
# data; this assumes all test sequences are exactly 2048 tokens long -- confirm
# against the saved test dataset.
value_to_repeat_test = 2048
# Number of entries; must match the number of sequences in the test split.
total_elements_test = 200_000
data_list_test = [value_to_repeat_test] * total_elements_test
# The .pkl length file is written relative to the current working directory.
output_file_test = "sub_200K_1Mtrain_genecorpus_30M_2048_lengths.pkl"
with open(output_file_test, 'wb') as f:
    pickle.dump(data_list_test, f)
# Report the number of entries actually written (previously this printed the
# size of the full 1.2M subset instead of the length of this list).
print(f"List with {len(data_list_test)} elements saved as {output_file_test}")