# nancyH's picture
# Upload folder using huggingface_hub
# ab6c03c verified
import csv
import os
import json
import argparse
import random
from process_pretrain_data import get_kmer_sentence
# Largest number of whitespace-separated tokens in any k-mer sentence written
# by write_file (it raises this via `global`); Process prints it at the end.
max_length: int = 0
def Process_pair(args):
    """Convert paired enhancer/promoter FASTA files into k-mer pair tables.

    With ``<name>`` being the last component of ``args.file_path``, the
    directory must contain ``<name>_enhancer.fasta``, ``<name>_promoter.fasta``
    and ``<name>_label.txt`` (training) plus the same three files with a
    ``_test`` suffix (testing).  Sequences are read from every second line
    (line ``2*i + 1`` for record ``i``); labels are read one per line.

    Training pairs are shuffled with a fixed seed.  With ``args.dev`` the first
    tenth of them becomes ``dev`` and the test pairs go to ``test/dev``;
    otherwise the test pairs are written as ``dev``.  Every table has the
    columns ``seq1, seq2, label`` and is written under ``make_path(args)`` as
    ``.csv`` or ``.tsv`` depending on ``args.csv``.

    Args:
        args: parsed CLI namespace; reads ``file_path``, ``output_path``,
            ``kmer``, ``stride``, ``dev`` and ``csv``.
    """
    import contextlib

    random.seed(42)
    # basename(normpath(...)) is separator-agnostic and tolerates a trailing
    # slash in --file_path (splitting on '/' yields an empty name there).
    root_path = os.path.basename(os.path.normpath(args.file_path))
    prefix = os.path.join(args.file_path, root_path)

    def read_lines(name_suffix):
        # Read one input file completely and close it right away.
        with open(prefix + name_suffix, "r") as fh:
            return fh.readlines()

    def pair_up(seq1_lines, seq2_lines, labels):
        # The FASTA inputs are indexed as alternating header/sequence lines,
        # so record i's sequence sits on line 2*i + 1.
        return [[seq1_lines[2 * i + 1], seq2_lines[2 * i + 1], label]
                for i, label in enumerate(labels)]

    train_lines = pair_up(read_lines("_enhancer.fasta"),
                          read_lines("_promoter.fasta"),
                          read_lines("_label.txt"))
    test_lines = pair_up(read_lines("_enhancer_test.fasta"),
                         read_lines("_promoter_test.fasta"),
                         read_lines("_label_test.txt"))
    random.shuffle(train_lines)
    if args.dev:
        num_dev = int(len(train_lines) / 10)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    output_path = make_path(args)
    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    def write_file_pair(lines, writer, seq1_index=0, seq2_index=1, label_index=2):
        for line in lines:
            seq1 = get_kmer_sentence(line[seq1_index], kmer=args.kmer, stride=args.stride)
            seq2 = get_kmer_sentence(line[seq2_index], kmer=args.kmer, stride=args.stride)
            # int() also discards the trailing newline left by readlines().
            writer.writerow([seq1, seq2, str(int(line[label_index]))])

    # ExitStack closes (and flushes) every output file, including the
    # conditionally opened dev file, even if writing fails midway.
    with contextlib.ExitStack() as stack:
        def open_writer(*parts):
            # newline='' is what the csv module requires; without it rows get
            # an extra blank line on platforms that translate '\n'.
            fh = stack.enter_context(
                open(os.path.join(output_path, *parts), 'wt', newline=''))
            writer = csv.writer(fh, delimiter=delimiter)
            writer.writerow(["seq1", "seq2", "label"])
            return writer

        train_w = open_writer("train" + suffix)
        if args.dev:
            dev_w = open_writer("dev" + suffix)
            # exist_ok lets the script be re-run over an existing output dir.
            os.makedirs(os.path.join(output_path, "test"), exist_ok=True)
            test_w = open_writer("test", "dev" + suffix)
        else:
            test_w = open_writer("dev" + suffix)
        write_file_pair(train_lines, train_w)
        write_file_pair(test_lines, test_w)
        if args.dev:
            write_file_pair(dev_lines, dev_w)
def make_path(args):
    """Return the output directory for processed data, creating it if needed.

    Uses ``args.output_path`` when it is truthy, otherwise
    ``os.path.join(args.file_path, str(args.kmer))``.

    Args:
        args: namespace providing ``output_path``, ``file_path`` and ``kmer``.

    Returns:
        The directory path (str), guaranteed to exist on return.
    """
    if args.output_path:
        output_path = args.output_path
    else:
        output_path = os.path.join(args.file_path, str(args.kmer))
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_path, exist_ok=True)
    return output_path
def write_file(lines, writer, seq_index=2, label_index=3, kmer=6, stride=1):
    """Write one ``[k-mer sentence, label]`` row per input record.

    Args:
        lines: iterable of records (indexable rows of fields).
        writer: ``csv.writer``-like object that receives the rows.
        seq_index: position of the raw sequence inside each record.
        label_index: position of the label, or ``-100`` to write the
            placeholder label ``"0"`` for every record instead.
        kmer: k-mer size forwarded to ``get_kmer_sentence``.
        stride: step forwarded to ``get_kmer_sentence``.

    Side effect: raises the module-level ``max_length`` to the largest
    whitespace-token count of any sentence written so far.
    """
    global max_length
    unlabeled = label_index == -100
    placeholder = str(0)
    for record in lines:
        sentence = get_kmer_sentence(record[seq_index], kmer=kmer, stride=stride)
        max_length = max(max_length, len(sentence.split()))
        label = placeholder if unlabeled else str(record[label_index])
        writer.writerow([sentence, label])
def Process(args):
    """Convert ``train.csv``/``test.csv`` into k-mer ``sentence, label`` tables.

    Both files are read from ``args.file_path`` (header row skipped) and
    shuffled with a fixed seed.  The raw sequence and label come from columns
    ``args.seq_index`` and ``args.label_index`` for every split.

    With ``args.dev``, one ninth of the shuffled training rows becomes ``dev``
    and the test rows are written to ``test``; otherwise the test rows are
    written to ``dev``.  Tables go to ``make_path(args)`` as ``.csv`` or
    ``.tsv`` according to ``args.csv``.  Finally the running ``max_length``
    maintained by ``write_file`` is printed.

    Args:
        args: parsed CLI namespace; reads ``file_path``, ``output_path``,
            ``kmer``, ``stride``, ``dev``, ``csv``, ``seq_index`` and
            ``label_index``.
    """
    import contextlib

    random.seed(24)

    def read_rows(name):
        # utf-8-sig drops a leading BOM; [1:] drops the header row.
        path = os.path.join(args.file_path, name)
        with open(path, "r", encoding="utf-8-sig") as fh:
            return list(csv.reader(fh, delimiter=",", quotechar=None))[1:]

    train_lines = read_rows("train.csv")
    test_lines = read_rows("test.csv")
    random.shuffle(train_lines)
    random.shuffle(test_lines)
    if args.dev:
        num_dev = int(len(train_lines) / 9)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]
    if train_lines:
        # Show one record as a quick sanity check of the input layout.
        print(train_lines[0])

    output_path = make_path(args)
    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    def write_split(lines, writer):
        # Forward --kmer/--stride so the content matches the k used to name
        # the output directory, and use the same columns for every split.
        write_file(lines, writer, args.seq_index, args.label_index,
                   kmer=args.kmer, stride=args.stride)

    # ExitStack closes (and flushes) every output file, including the
    # conditionally opened ones, even if writing fails midway.
    with contextlib.ExitStack() as stack:
        def open_writer(name):
            # newline='' is what the csv module requires for output files.
            fh = stack.enter_context(
                open(os.path.join(output_path, name + suffix), 'wt', newline=''))
            writer = csv.writer(fh, delimiter=delimiter)
            writer.writerow(["sentence", "label"])
            return writer

        train_w = open_writer("train")
        if args.dev:
            dev_w = open_writer("dev")
            test_w = open_writer("test")
        else:
            test_w = open_writer("dev")
        write_split(train_lines, train_w)
        write_split(test_lines, test_w)
        if args.dev:
            write_split(dev_lines, dev_w)
    print("max length: %d" % (max_length))
def Process_UCE(args):
    """Write the reference and mutant sequences of a UCE CSV as k-mer rows.

    ``args.file_path`` is a comma-separated file whose header row is skipped.
    For every row, column 8 and the second-to-last column are written; the
    last column is also written when it differs from the second-to-last.
    Every output row gets the placeholder label 0.

    Files written to ``args.output_path`` (created if missing):
      * ``dev.csv``/``dev.tsv`` (per ``args.csv``): ``sentence, label`` rows;
      * ``line2index.json``: input row number -> indices of the rows written
        for it in the ``dev`` table (the header is row 0);
      * ``lencount.json``: raw sequence length -> number of written sequences
        of that length.

    Args:
        args: parsed CLI namespace; reads ``file_path``, ``output_path``,
            ``kmer``, ``stride`` and ``csv``.
    """
    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'
    len_count = {}
    line2index = {}

    with open(args.file_path, "r", encoding="utf-8-sig") as pred_file:
        pred_lines = list(csv.reader(pred_file, delimiter=",", quotechar=None))[1:]

    os.makedirs(args.output_path, exist_ok=True)
    index = 1  # row 0 of the output table is the header
    with open(os.path.join(args.output_path, "dev" + suffix), 'wt', newline='') as f_pred:
        pred_w = csv.writer(f_pred, delimiter=delimiter)
        pred_w.writerow(["sentence", "label"])
        for line_num, line in enumerate(pred_lines):
            seqs = [line[8], line[-2]]
            # Only emit the second mutant when it is actually different.
            if line[-2] != line[-1]:
                seqs.append(line[-1])
            cur_index = []
            for seq in seqs:
                len_count[len(seq)] = len_count.get(len(seq), 0) + 1
                pred_w.writerow([get_kmer_sentence(seq, args.kmer, args.stride), 0])
                cur_index.append(index)
                index += 1
            line2index[line_num] = cur_index

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
def Process_Virus(args):
    """Write the sequences of every per-class virus CSV as k-mer rows.

    Every file in ``args.file_path`` whose name does not start with
    ``"unclass"`` is read (header row skipped) and the rows are concatenated
    in sorted file-name order.  The outputs mirror ``Process_UCE``: a ``dev``
    table of ``sentence, label`` rows (label always 0) plus
    ``line2index.json`` and ``lencount.json`` in ``args.output_path``.

    NOTE(review): the per-row column choice (column 8 plus the last two
    columns) is carried over from ``Process_UCE``; confirm it matches the
    virus CSV schema.  This function is not dispatched from ``main``.

    Args:
        args: parsed CLI namespace; reads ``file_path``, ``output_path``,
            ``kmer``, ``stride`` and ``csv``.
    """
    file_path = args.file_path
    # os.listdir order is arbitrary; sorting keeps the row numbering stored in
    # line2index.json reproducible.
    all_files = sorted(name for name in os.listdir(file_path)
                       if not name.startswith("unclass"))
    all_lines = []
    for name in all_files:
        with open(os.path.join(file_path, name), "r", encoding="utf-8-sig") as cur_file:
            all_lines.extend(list(csv.reader(cur_file, delimiter=",", quotechar=None))[1:])

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'
    len_count = {}
    line2index = {}

    os.makedirs(args.output_path, exist_ok=True)
    index = 1  # row 0 of the output table is the header
    with open(os.path.join(args.output_path, "dev" + suffix), 'wt', newline='') as f_pred:
        pred_w = csv.writer(f_pred, delimiter=delimiter)
        pred_w.writerow(["sentence", "label"])
        for line_num, line in enumerate(all_lines):
            seqs = [line[8], line[-2]]
            if line[-2] != line[-1]:
                seqs.append(line[-1])
            cur_index = []
            for seq in seqs:
                len_count[len(seq)] = len_count.get(len(seq), 0) + 1
                pred_w.writerow([get_kmer_sentence(seq, args.kmer, args.stride), 0])
                cur_index.append(index)
                index += 1
            line2index[line_num] = cur_index

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
def build_parser():
    """Return the argument parser for this preprocessing script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--stride",
        default=1,
        type=int,
        help="stride in getting kmer sequence",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        help="The path of the processed data",
    )
    parser.add_argument(
        "--dev",
        action="store_true",
        help="Use this flag to split data as (8:1:1), else (9:1)",
    )
    parser.add_argument(
        "--csv",
        action="store_true",
        help="if output csv file or not, if not, output tsv",
    )
    parser.add_argument(
        "--pair",
        action="store_true",
        help="Process paired enhancer/promoter FASTA files with labels (Process_pair)",
    )
    parser.add_argument(
        "--uce",
        action="store_true",
        help="Process a single UCE csv file into a prediction set (Process_UCE)",
    )
    parser.add_argument(
        "--seq_index",
        default=2,
        type=int,
        help="index of seq in the original csv file",
    )
    parser.add_argument(
        "--label_index",
        default=3,
        type=int,
        help="index of label in the original csv file",
    )
    return parser


def main():
    """Parse command-line arguments and dispatch to the matching processor.

    ``--pair`` selects ``Process_pair``; otherwise ``--uce`` selects
    ``Process_UCE``; otherwise ``Process`` runs.
    """
    args = build_parser().parse_args()
    if args.pair:
        Process_pair(args)
    elif args.uce:
        Process_UCE(args)
    else:
        Process(args)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()