Spaces:
Sleeping
Sleeping
| import random | |
| upper_case_list = [chr(i) for i in range(65, 91)] | |
| low_case_list = [chr(i) for i in range(97, 123)] | |
| angle_id_list = [str(i) for i in range(1, 21)] | |
| spec_token_list = ['frac', 'pi', 'sqrt'] | |
| class Compose(object): | |
| def __init__(self, transforms): | |
| self.transforms = transforms | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| for t in self.transforms: | |
| t(text_seq, stru_seqs, sem_seqs, exp) | |
| def __repr__(self): | |
| format_string = self.__class__.__name__ + "(" | |
| for t in self.transforms: | |
| format_string += "\n" | |
| format_string += " {0}".format(t) | |
| format_string += "\n)" | |
| return format_string | |
| class Point_RandomReplace(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def get_point_map(self): | |
| value_list = [chr(i) for i in range(65, 91)] | |
| random.shuffle(value_list) | |
| map_dict = {key:value for key, value in zip(upper_case_list, value_list)} | |
| return map_dict | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| map_dict = self.get_point_map() | |
| for k in range(len(text_seq.token)): | |
| if text_seq.class_tag[k] == '[POINT]': | |
| text_seq.token[k] = map_dict[text_seq.token[k][0]] | |
| for k in range(len(stru_seqs.token)): | |
| for j in range(len(stru_seqs.token[k])): | |
| if stru_seqs.class_tag[k][j] == '[POINT]': | |
| stru_seqs.token[k][j] = map_dict[stru_seqs.token[k][j][0]] | |
| for k in range(len(sem_seqs.token)): | |
| for j in range(len(sem_seqs.token[k])): | |
| if sem_seqs.class_tag[k][j] == '[POINT]': | |
| sem_seqs.token[k][j] = map_dict[sem_seqs.token[k][j][0]] | |
| class AngID_RandomReplace(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def get_angid_map(self): | |
| value_list = [str(i) for i in range(1, 21)] | |
| random.shuffle(value_list) | |
| map_dict = {key:value for key, value in zip(angle_id_list, value_list)} | |
| return map_dict | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| map_dict = self.get_angid_map() | |
| for k in range(len(text_seq.token)): | |
| if text_seq.class_tag[k] == '[ANGID]': | |
| text_seq.token[k] = map_dict[text_seq.token[k]] | |
| for k in range(len(sem_seqs.token)): | |
| for j in range(len(sem_seqs.token[k])): | |
| if sem_seqs.class_tag[k][j] == '[ANGID]': | |
| sem_seqs.token[k][j] = map_dict[sem_seqs.token[k][j]] | |
| class Arg_RandomReplace(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def get_arg_map(self): | |
| value_list = [chr(i) for i in range(97, 123)] | |
| random.shuffle(value_list) | |
| map_dict = {key:value for key, value in zip(low_case_list, value_list)} | |
| return map_dict | |
| def map_arg_in_num(self, map_dict, num): | |
| num_t = num[:] | |
| new_num = '' | |
| for item in spec_token_list: | |
| num_t = num_t.replace(item, "@"*len(item)) | |
| for k in range(len(num_t)): | |
| if num_t[k]!='@' and num[k] in low_case_list: | |
| new_num += map_dict[num[k]] | |
| else: | |
| new_num += num[k] | |
| return new_num | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| map_dict = self.get_arg_map() | |
| for k in range(len(text_seq.token)): | |
| if text_seq.class_tag[k] == '[NUM]': | |
| text_seq.token[k] = self.map_arg_in_num(map_dict, text_seq.token[k]) | |
| if text_seq.class_tag[k] == '[ARG]': | |
| text_seq.token[k] = map_dict[text_seq.token[k]] | |
| for k in range(len(sem_seqs.token)): | |
| for j in range(len(sem_seqs.token[k])): | |
| if sem_seqs.class_tag[k][j] == '[NUM]': | |
| sem_seqs.token[k][j] = self.map_arg_in_num(map_dict, sem_seqs.token[k][j]) | |
| for k in range(len(exp)): | |
| if exp[k] in low_case_list: | |
| exp[k] = map_dict[exp[k]] | |
| class StruPoint_RandomRotate(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def get_seq_points(self, class_tag): | |
| id_list = [] | |
| begin_point_id = end_point_id = None | |
| for id, token in enumerate(class_tag): | |
| if token == '[POINT]': | |
| if begin_point_id is None: | |
| begin_point_id = id | |
| elif not begin_point_id is None and end_point_id is None: | |
| end_point_id = id | |
| id_list.append([begin_point_id, end_point_id]) | |
| begin_point_id = end_point_id = None | |
| if not begin_point_id is None and end_point_id is None: | |
| id_list.append([begin_point_id, len(class_tag)]) | |
| return id_list[-1][0], id_list[-1][1] | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| for k in range(len(stru_seqs.token)): | |
| if random.random() < self.prob: | |
| begin_id, end_id = self.get_seq_points(stru_seqs.class_tag[k]) | |
| # point on line | |
| if stru_seqs.token[k][0] == 'line': | |
| stru_seqs.token[k][begin_id:end_id] = stru_seqs.token[k][end_id-1:begin_id-1:-1] | |
| # point on circle | |
| if stru_seqs.token[k][0] == '\\odot': | |
| # clockwise change | |
| if random.random() < 0.5: | |
| stru_seqs.token[k][begin_id:end_id] = stru_seqs.token[k][end_id-1:begin_id-1:-1] | |
| # set initial point | |
| init_loc = random.randint(begin_id, end_id-1) | |
| stru_seqs.token[k][begin_id:end_id] = stru_seqs.token[k][init_loc:end_id] + \ | |
| stru_seqs.token[k][begin_id:init_loc] | |
| class SemPoint_RandomRotate(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def get_seq_points(self, class_tag): | |
| id_list = [] | |
| begin_point_id = end_point_id = None | |
| for id, token in enumerate(class_tag): | |
| if token == '[POINT]': | |
| if begin_point_id is None: | |
| begin_point_id = id | |
| elif not begin_point_id is None and end_point_id is None: | |
| end_point_id = id | |
| id_list.append((begin_point_id, end_point_id-1)) | |
| begin_point_id = end_point_id = None | |
| if not begin_point_id is None and end_point_id is None: | |
| id_list.append((begin_point_id, len(class_tag)-1)) | |
| return id_list | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| for k in range(len(sem_seqs.token)): | |
| id_list = self.get_seq_points(sem_seqs.class_tag[k]) | |
| for begin_id, end_id in id_list: | |
| if random.random() < self.prob: | |
| sem_seqs.token[k][begin_id], sem_seqs.token[k][end_id] = \ | |
| sem_seqs.token[k][end_id], sem_seqs.token[k][begin_id] | |
| class SemSeq_RandomRotate(object): | |
| def __init__(self, prob=0.5): | |
| if prob==0: | |
| self.prob = 0 | |
| else: | |
| self.prob = prob + 0.2 | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| # varible id | |
| num_all_list, num_sem_list, num_map_list = [], [], [] | |
| for item in text_seq.class_tag: | |
| if item=='[NUM]': | |
| var_name = 'N'+str(len(num_all_list)) | |
| num_all_list.append(var_name) | |
| num_map_list.append(var_name) | |
| for k in range(len(sem_seqs.token)): | |
| if sem_seqs.class_tag[k][-2] == '[NUM]': | |
| var_name = 'N'+str(len(num_all_list)) | |
| num_all_list.append(var_name) | |
| num_sem_list.append([var_name]) | |
| else: | |
| num_sem_list.append([]) | |
| # shuffle sem_seq | |
| if len(sem_seqs.token)>0: | |
| random_id_list = [k for k in range(len(sem_seqs.token))] | |
| random.shuffle(random_id_list) | |
| for key,value in vars(sem_seqs).items(): | |
| _, value = zip(*sorted(zip(random_id_list, value))) | |
| setattr(sem_seqs, key, list(value)) | |
| _, num_sem_list = zip(*sorted(zip(random_id_list, num_sem_list))) | |
| # expression map | |
| for k in range(len(sem_seqs.token)): | |
| num_map_list += num_sem_list[k] | |
| num_map_dict = {key:value for key, value in zip(num_map_list, num_all_list)} | |
| for k in range(len(exp)): | |
| if exp[k] in num_map_dict: | |
| exp[k] = num_map_dict[exp[k]] | |
| class StruSeq_RandomRotate(object): | |
| def __init__(self, prob=0.5): | |
| self.prob = prob | |
| def __call__(self, text_seq, stru_seqs, sem_seqs, exp): | |
| if random.random() < self.prob: | |
| # shuffle stru_seq | |
| if len(stru_seqs.token)>0: | |
| random_id_list = [k for k in range(len(stru_seqs.token))] | |
| random.shuffle(random_id_list) | |
| for key, value in vars(stru_seqs).items(): | |
| _, value = zip(*sorted(zip(random_id_list, value))) | |
| setattr(stru_seqs, key, list(value)) | |