import sys
import os
from os.path import join
import json
import argparse
import logging

import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Local checkout of the nucleotide-transformer repo (only needed for the
# commented-out haiku/jax path below).
module_path = "/storage1/fs1/yeli/Active/xiaoxiao.zhou/projects/foundation/nucleotide-transformer"
if module_path not in sys.path:
    sys.path.append(module_path)

# import haiku as hk
# import jax
# import jax.numpy as jnp
# from nucleotide_transformer.pretrained import get_pretrained_model


def main(args):
    cache_dir = '/storage2/fs1/btc/Active/yeli/xiaoxiao.zhou/apps/transformers_cache'
    tokenizer = AutoTokenizer.from_pretrained(
        "InstaDeepAI/nucleotide-transformer-500m-human-ref", cache_dir=cache_dir
    )
    # The masked-LM model is loaded alongside the tokenizer; only the tokenizer is used below.
    model = AutoModelForMaskedLM.from_pretrained(
        "InstaDeepAI/nucleotide-transformer-500m-human-ref", cache_dir=cache_dir
    )

    for folder in os.listdir(args.data_dir):
        if folder.startswith('.'):
            continue
        for f in os.listdir(join(args.data_dir, folder)):
            if f.startswith('.'):
                continue
            for name in ['test', 'dev', 'train']:
                data = join(args.data_dir, folder, f, name + '.csv')
                if not os.path.exists(data):
                    print(f"File {data} does not exist, skipping...")
                    continue
                df = pd.read_csv(data, sep='\t')
                print('Processing ' + folder + ' ' + f)

                # Tokenize each sequence with the nucleotide-transformer tokenizer.
                # With --only_positive, keep only rows labelled 1.
                df_tokenized = []
                for i in range(len(df['sequence'])):
                    if args.only_positive and df['label'][i] != 1:
                        continue
                    seg = df['sequence'][i]
                    output = tokenizer.encode_plus(seg, return_tensors="pt")
                    df_tokenized.append(output['input_ids'].cpu())

                # Serialize each tokenized sequence as a space-separated string of token ids.
                df_ = [
                    " ".join(str(token.item()) for token in line.squeeze())
                    for line in df_tokenized
                ]
                suffix = '_NT_only_POS.json' if args.only_positive else '_NT.json'
                f_ = join(args.data_dir, folder, f, name + suffix)
                with open(f_, 'w') as file:
                    logging.warning(f"Saving tokenized results to {f_}...")
                    json.dump(df_, file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--only_positive", action="store_true")
    args = parser.parse_args()
    main(args)