pgps-demo / datasets /text_aug.py
asdfasdfdsafdsa's picture
Initial upload of PGPS demo with all dependencies
383bfb8 verified
import random
# Token vocabularies shared by the replacement transforms in this module.
upper_case_list = [chr(code) for code in range(ord('A'), ord('Z') + 1)]  # 'A'..'Z'
low_case_list = [chr(code) for code in range(ord('a'), ord('z') + 1)]  # 'a'..'z'
angle_id_list = [str(number) for number in range(1, 21)]  # '1'..'20'
# Substrings that Arg_RandomReplace.map_arg_in_num masks before remapping letters.
spec_token_list = ['frac', 'pi', 'sqrt']
class Compose(object):
    """Chain several augmentation callables into a single callable.

    Each transform is invoked as ``t(text_seq, stru_seqs, sem_seqs, exp)``.
    Return values are discarded, so transforms are expected to work through
    side effects on their arguments.
    """

    def __init__(self, transforms):
        # Iterable of callables; applied in iteration order.
        self.transforms = transforms

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Run every transform, in order, on the same four arguments."""
        for transform in self.transforms:
            transform(text_seq, stru_seqs, sem_seqs, exp)

    def __repr__(self):
        """Return the class name followed by one formatted transform per line."""
        pieces = [self.__class__.__name__ + "("]
        for transform in self.transforms:
            pieces.append("\n")
            pieces.append(" {0}".format(transform))
        pieces.append("\n)")
        return "".join(pieces)
class Point_RandomReplace(object):
    """Relabel '[POINT]' tokens through a random permutation of 'A'..'Z'.

    With probability ``prob`` one permutation is drawn per call and applied
    in place to the text, structure and semantic sequences.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_point_map(self):
        """Return a dict sending each entry of ``upper_case_list`` to a shuffled letter."""
        shuffled = [chr(code) for code in range(ord('A'), ord('Z') + 1)]
        random.shuffle(shuffled)
        return dict(zip(upper_case_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Rewrite point tokens in place; ``exp`` is not touched."""
        if not random.random() < self.prob:
            return
        point_map = self.get_point_map()

        # Flat text sequence. Only the first character of a point token is
        # looked up, and the whole token is replaced by the mapped letter.
        for idx, token in enumerate(text_seq.token):
            if text_seq.class_tag[idx] == '[POINT]':
                text_seq.token[idx] = point_map[token[0]]

        # Structure clauses: a list of token lists with parallel class tags.
        for row, tokens in enumerate(stru_seqs.token):
            for col, token in enumerate(tokens):
                if stru_seqs.class_tag[row][col] == '[POINT]':
                    tokens[col] = point_map[token[0]]

        # Semantic clauses: same layout as the structure clauses.
        for row, tokens in enumerate(sem_seqs.token):
            for col, token in enumerate(tokens):
                if sem_seqs.class_tag[row][col] == '[POINT]':
                    tokens[col] = point_map[token[0]]
class AngID_RandomReplace(object):
    """Relabel '[ANGID]' tokens through a random permutation of '1'..'20'.

    With probability ``prob`` one permutation is drawn per call and applied
    in place to the text and semantic sequences.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_angid_map(self):
        """Return a dict sending each entry of ``angle_id_list`` to a shuffled id."""
        shuffled = [str(number) for number in range(1, 21)]
        random.shuffle(shuffled)
        return dict(zip(angle_id_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Rewrite angle-id tokens in place; ``stru_seqs`` and ``exp`` are not touched."""
        if not random.random() < self.prob:
            return
        angid_map = self.get_angid_map()

        # Unlike point tokens, the full token string is the lookup key here.
        for idx, token in enumerate(text_seq.token):
            if text_seq.class_tag[idx] == '[ANGID]':
                text_seq.token[idx] = angid_map[token]

        for row, tokens in enumerate(sem_seqs.token):
            for col, token in enumerate(tokens):
                if sem_seqs.class_tag[row][col] == '[ANGID]':
                    tokens[col] = angid_map[token]
class Arg_RandomReplace(object):
    """Rename lowercase argument letters through a random permutation of 'a'..'z'.

    With probability ``prob`` one permutation is drawn per call and applied
    in place to '[NUM]'/'[ARG]' text tokens, '[NUM]' semantic tokens and
    lowercase entries of ``exp``.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_arg_map(self):
        """Return a dict sending each entry of ``low_case_list`` to a shuffled letter."""
        shuffled = [chr(code) for code in range(ord('a'), ord('z') + 1)]
        random.shuffle(shuffled)
        return dict(zip(low_case_list, shuffled))

    def map_arg_in_num(self, map_dict, num):
        """Return ``num`` with its lowercase letters remapped through ``map_dict``.

        Letters that belong to an entry of ``spec_token_list`` are left alone.
        """
        # Blank out the special tokens with '@' (same length), so that the
        # mask and ``num`` stay aligned position by position.
        mask = num[:]
        for special in spec_token_list:
            mask = mask.replace(special, "@" * len(special))

        rebuilt = []
        for pos, mask_char in enumerate(mask):
            if mask_char != '@' and num[pos] in low_case_list:
                rebuilt.append(map_dict[num[pos]])
            else:
                rebuilt.append(num[pos])
        return ''.join(rebuilt)

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Rewrite argument letters in place; ``stru_seqs`` is not touched."""
        if not random.random() < self.prob:
            return
        arg_map = self.get_arg_map()

        for idx, token in enumerate(text_seq.token):
            tag = text_seq.class_tag[idx]
            if tag == '[NUM]':
                # Numbers may embed letters; remap them character by character.
                text_seq.token[idx] = self.map_arg_in_num(arg_map, token)
            elif tag == '[ARG]':
                # Argument tokens are looked up as a whole.
                text_seq.token[idx] = arg_map[token]

        for row, tokens in enumerate(sem_seqs.token):
            for col, token in enumerate(tokens):
                if sem_seqs.class_tag[row][col] == '[NUM]':
                    tokens[col] = self.map_arg_in_num(arg_map, token)

        # Apply the same renaming to the expression.
        for pos, symbol in enumerate(exp):
            if symbol in low_case_list:
                exp[pos] = arg_map[symbol]
class StruPoint_RandomRotate(object):
    """Randomly reorder the trailing points of 'line' and '\\odot' structure clauses.

    Each clause is processed independently with probability ``prob``:

    * a clause whose first token is 'line' has its last run of points reversed;
    * a clause whose first token is '\\odot' has its last run of points
      reversed with probability 0.5 and then cyclically shifted to start at a
      random point.

    Other clauses, and clauses with no '[POINT]' tag, are left unchanged.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return ``(begin, end)`` of the last run of consecutive '[POINT]' tags.

        ``begin`` is the index of the first tag of the run and ``end`` is one
        past its last tag.

        Raises:
            IndexError: if ``class_tag`` contains no '[POINT]' tag.
        """
        runs = []
        run_start = None
        for pos, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if run_start is None:
                    run_start = pos
            elif run_start is not None:
                # First non-point tag after a run closes that run.
                runs.append((run_start, pos))
                run_start = None
        if run_start is not None:
            # The run extends to the end of the sequence.
            runs.append((run_start, len(class_tag)))
        return runs[-1][0], runs[-1][1]

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Reorder points of ``stru_seqs`` in place; other arguments are not touched."""
        for k, tokens in enumerate(stru_seqs.token):
            # Draw first so the random stream is consumed once per clause,
            # whether or not the clause can actually be rotated.
            if not random.random() < self.prob:
                continue
            tags = stru_seqs.class_tag[k]
            if '[POINT]' not in tags:
                # Nothing to rotate; get_seq_points would raise IndexError.
                continue
            begin_id, end_id = self.get_seq_points(tags)

            # Work on a copy of the run. A reversed slice with stop
            # ``begin_id - 1`` would silently become empty when begin_id is 0.
            points = tokens[begin_id:end_id]
            head = tokens[0]
            if head == 'line':
                # point on line
                points.reverse()
            elif head == '\\odot':
                # point on circle: optionally flip the traversal direction
                if random.random() < 0.5:
                    points.reverse()
                # then choose which point comes first
                init_loc = random.randint(begin_id, end_id - 1)
                shift = init_loc - begin_id
                points = points[shift:] + points[:shift]
            else:
                continue
            tokens[begin_id:end_id] = points
class SemPoint_RandomRotate(object):
    """Randomly swap the two end points of point runs in semantic clauses.

    The transform fires with probability ``prob``; when it does, every run of
    consecutive '[POINT]' tokens is swapped independently with probability
    ``prob`` again.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return ``(first, last)`` index pairs, inclusive, for every run of '[POINT]' tags."""
        spans = []
        first = None
        for pos, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if first is None:
                    first = pos
            elif first is not None:
                # The run ended on the previous position.
                spans.append((first, pos - 1))
                first = None
        if first is not None:
            # The run reaches the end of the sequence.
            spans.append((first, len(class_tag) - 1))
        return spans

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Swap run end points of ``sem_seqs`` in place; other arguments are not touched."""
        if not random.random() < self.prob:
            return
        for row, tokens in enumerate(sem_seqs.token):
            for lo, hi in self.get_seq_points(sem_seqs.class_tag[row]):
                if random.random() < self.prob:
                    # Only the first and last token of the run trade places.
                    tokens[lo], tokens[hi] = tokens[hi], tokens[lo]
class SemSeq_RandomRotate(object):
    """Randomly permute the semantic clauses and renumber 'N<i>' symbols in ``exp``.

    Variable names are assigned first to every '[NUM]' tag of the text
    sequence, then to every semantic clause whose second-to-last tag is
    '[NUM]'. After the clauses are shuffled, ``exp`` is rewritten so each
    symbol follows its clause to the new position.
    """

    def __init__(self, prob=0.5):
        # A zero probability stays zero; any other value is raised by 0.2.
        self.prob = 0 if prob == 0 else prob + 0.2

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """Shuffle ``sem_seqs`` and update ``exp`` in place."""
        if not random.random() < self.prob:
            return

        # variable id
        all_names = []     # every variable name, in numbering order
        clause_names = []  # per semantic clause: [name] or []
        moved_names = []   # old names listed in their post-shuffle order
        for tag in text_seq.class_tag:
            if tag == '[NUM]':
                name = 'N' + str(len(all_names))
                all_names.append(name)
                moved_names.append(name)
        for row in range(len(sem_seqs.token)):
            if sem_seqs.class_tag[row][-2] == '[NUM]':
                name = 'N' + str(len(all_names))
                all_names.append(name)
                clause_names.append([name])
            else:
                clause_names.append([])

        # shuffle sem_seq
        clause_count = len(sem_seqs.token)
        if clause_count > 0:
            sort_keys = list(range(clause_count))
            random.shuffle(sort_keys)
            # Every instance attribute is reordered with the same keys, so
            # parallel lists on sem_seqs stay aligned with each other.
            for attr, column in vars(sem_seqs).items():
                _, reordered = zip(*sorted(zip(sort_keys, column)))
                setattr(sem_seqs, attr, list(reordered))
            _, clause_names = zip(*sorted(zip(sort_keys, clause_names)))

        # expression map
        for row in range(len(sem_seqs.token)):
            moved_names += clause_names[row]
        rename = dict(zip(moved_names, all_names))
        for pos, symbol in enumerate(exp):
            if symbol in rename:
                exp[pos] = rename[symbol]
class StruSeq_RandomRotate(object):
    """Randomly permute the order of the structure clauses."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        """With probability ``prob``, shuffle ``stru_seqs`` in place."""
        if not random.random() < self.prob:
            return
        clause_count = len(stru_seqs.token)
        if not clause_count > 0:
            return

        # shuffle stru_seq
        sort_keys = list(range(clause_count))
        random.shuffle(sort_keys)
        # Every instance attribute is reordered with the same keys, so
        # parallel lists on stru_seqs stay aligned with each other.
        for attr, column in vars(stru_seqs).items():
            _, reordered = zip(*sorted(zip(sort_keys, column)))
            setattr(stru_seqs, attr, list(reordered))