Spaces:
Sleeping
Sleeping
# Punctuation marks that are detached into separate tokens by the tokenizer.
punctuation_list = ['.', '?', ',']
# ASCII digit characters (used for membership tests on single chars).
digit_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Capital letters 'A'-'Z' (geometry point names).
capital_letter_list = [chr(item) for item in range(65, 91)]
# Lowercase letters 'a'-'z' (candidate variable/argument names).
low_letter_list = [chr(item) for item in range(97, 123)]
# Words that typically open the problem (question) span of the text.
begin_words = ["find", "what", "solve", "determine", "express", "how"]
# Tokens that close the problem span (punctuation or condition connectives).
end_words = [".", ",", '?', "if", "so", "for which", "given", "with", "on",
             "in", "must", 'for', 'that', 'formed']
# LaTeX-style area unit tokens (treated as units, not numeric values).
unit_list = ["mm^{2}", "cm^{2}", "in^{2}", "ft^{2}",
             "yd^{2}", "km^{2}", "units^{2}", "mi^{2}", "m^{2}"]
# LaTeX/math markers stripped from a token before argument-letter extraction.
special_token_list = ['\\frac', '\\pi', '\\sqrt', "+", "-", "^"]
def get_token(ss):
    """
    Tokenizer: divide the textual problem into words.

    Splits on single spaces, detaches a trailing punctuation mark into its
    own token, and breaks geometry references made of capital letters into
    individual points (e.g. "AB" -> "A", "B"), keeping primes and digit
    suffixes attached to the preceding point (e.g. "A'", "B1").  Any token
    that is not a pure geometry reference is lower-cased.

    :param ss: raw problem text
    :return: list of word/point tokens

    Bug fix: consecutive spaces (or an all-whitespace input) yield empty
    items from split(' '), which previously crashed on item[-1]; empty
    items are now skipped.
    """
    raw_str_list = ss.strip().split(' ')
    # Split a trailing punctuation mark into its own token.
    new_str1_list = []
    for item in raw_str_list:
        if not item:
            # empty fragment from repeated spaces -- nothing to tokenize
            continue
        if item[-1] in punctuation_list:
            new_str1_list.append(item[:-1])
            new_str1_list.append(item[-1])
        else:
            new_str1_list.append(item)
    # Split geometry references into individual points (capital letters,
    # with primes/digits glued onto the preceding capital).
    new_str2_list = []
    for item in new_str1_list:
        is_geo_rep = True
        point_list = []
        for k in item:
            if ('A' <= k <= 'Z') or \
                    ((k == '\'' or k in digit_list) and len(point_list) > 0):
                if k == '\'' or k in digit_list:
                    point_list[-1] += k
                else:
                    point_list.append(k)
            else:
                # any other character means this is a plain word
                is_geo_rep = False
                break
        if is_geo_rep:
            new_str2_list += point_list
        else:
            new_str2_list.append(item.lower())
    return new_str2_list
def split_text(text_data):
    """
    Split the textual problem into condition and problem (target) spans.

    Tokens from the first begin-word up to the first end-word after it are
    tagged '[PROB]' (a closing punctuation mark stays inside the problem
    span); everything else is tagged '[COND]'.  Writes text_data.sect_tag
    in place; no-op on an empty token list.
    """
    n_tokens = len(text_data.token)
    if n_tokens == 0:
        return
    begin_ind, end_ind = 0, n_tokens
    for idx, tok in enumerate(text_data.token):
        if tok in begin_words:
            begin_ind = idx
            break
    for idx in range(begin_ind + 2, n_tokens):
        tok = text_data.token[idx]
        if tok in end_words:
            # keep a closing punctuation mark inside the problem span
            end_ind = idx + 1 if tok in punctuation_list else idx
            break
    text_data.sect_tag = (['[COND]'] * begin_ind
                          + ['[PROB]'] * (end_ind - begin_ind)
                          + ['[COND]'] * (n_tokens - end_ind))
def _tag_points_and_angle_ids(tokens, class_tags):
    """Tag one token sequence in place: '[POINT]' for tokens starting with a
    capital letter, '[ANGID]' for a numeric token preceded by '\\angle'."""
    for idx, item in enumerate(tokens):
        if 'A' <= item[0] <= 'Z':
            class_tags[idx] = '[POINT]'
        if item.isdigit() and idx > 0 and tokens[idx - 1] == "\\angle":
            class_tags[idx] = '[ANGID]'


def get_point_angleID_tag(text_data, stru_data, sem_data):
    """
    Tag geometry points and numeric angle identifiers.

    Applies the same tagging rule to the flat text sequence and to every
    structural/semantic sub-sequence.  (The three original near-identical
    loops are factored into the _tag_points_and_angle_ids helper.)
    """
    _tag_points_and_angle_ids(text_data.token, text_data.class_tag)
    for k in range(len(stru_data.token)):
        _tag_points_and_angle_ids(stru_data.token[k], stru_data.class_tag[k])
    for k in range(len(sem_data.token)):
        _tag_points_and_angle_ids(sem_data.token[k], sem_data.class_tag[k])
def get_args(token):
    """
    Extract the distinct lowercase variable letters from a token.

    Known LaTeX/math markers (special_token_list) are stripped first;
    the remaining lowercase letters are returned in first-seen order
    without duplicates.
    """
    for marker in special_token_list:
        token = token.replace(marker, "")
    letters = []
    for ch in token:
        if ch in low_letter_list and ch not in letters:
            letters.append(ch)
    return letters
def get_num_arg_tag(text_data, sem_data):
    """
    Tag numeric values ('[NUM]') and variable arguments ('[ARG]') in the
    text condition, using the semantic sequences to collect known argument
    letters (semantic seqs are assumed shaped [... '=' value end], with the
    value at position -2).

    Bug fix: the original read text_data.token[idx + 1] in the comma-pair
    checks without verifying idx is not the last position, raising
    IndexError when the final token followed a comma; a bounds guard
    (has_next) is now applied before every idx + 1 access.
    """
    # Argument letters appearing in the numeric value of each semantic seq.
    arg_sem_flat = []
    for k in range(len(sem_data.token)):
        if len(sem_data.token[k]) >= 3 and sem_data.token[k][-3] == '=':
            sem_data.class_tag[k][-2] = '[NUM]'
            arg_sem_flat += get_args(sem_data.token[k][-2])
    for idx, token in enumerate(text_data.token):
        if text_data.sect_tag[idx] == '[COND]' and text_data.class_tag[idx] == '[GEN]':
            # unit symbols are never numeric values
            if token in unit_list:
                continue
            # contains a digit (rough judgment)
            for word in digit_list:
                if word in token:
                    text_data.class_tag[idx] = '[NUM]'
                    break
            # contains a special math token, but is not only that token
            for word in special_token_list:
                if word in token and word != token:
                    text_data.class_tag[idx] = '[NUM]'
                    break
            # single lowercase letter, excluding special cases
            if text_data.token[idx] in low_letter_list:
                has_next = idx < len(text_data.token) - 1
                if has_next and text_data.token[idx + 1] == '=':
                    continue
                # 'm' as angle/arc measure prefix, not a variable
                if text_data.token[idx] == 'm' and has_next and \
                        text_data.token[idx + 1] in ["\\angle", "\\widehat"]:
                    continue
                # 'a' as the English article, unless preceded by '='
                if text_data.token[idx] == 'a' and (idx == 0 or text_data.token[idx - 1] != '='):
                    continue
                # letter naming a line / enumeration item, not a value
                if not text_data.token[idx] in arg_sem_flat and idx > 0 and \
                        ('line' in text_data.token[idx - 1] or
                         text_data.token[idx - 1] == 'and' or
                         (text_data.token[idx - 1] == ',' and has_next and
                          text_data.token[idx + 1] == ',')):
                    continue
                text_data.class_tag[idx] = '[NUM]'
    # Argument letters appearing in the text numeric values.
    arg_text_flat = []
    for idx, token in enumerate(text_data.token):
        if text_data.sect_tag[idx] == '[COND]' and text_data.class_tag[idx] == '[NUM]':
            arg_text_flat += get_args(token)
    # Determine arguments.
    arg_all_flat = arg_text_flat + arg_sem_flat
    for idx, token in enumerate(text_data.token):
        if text_data.class_tag[idx] == '[GEN]' and text_data.token[idx] in arg_all_flat:
            has_next = idx < len(text_data.token) - 1
            if has_next and text_data.token[idx + 1] == '=':
                text_data.class_tag[idx] = '[ARG]'
                continue
            if text_data.token[idx] == 'm' and has_next and \
                    text_data.token[idx + 1] in ["\\angle", "\\widehat"]:
                continue
            if text_data.token[idx] == 'a' and (idx == 0 or text_data.token[idx - 1] != '=') and \
                    text_data.sect_tag[idx] == '[COND]':
                continue
            if idx > 0 and ('line' in text_data.token[idx - 1] or
                            text_data.token[idx - 1] == 'and' or
                            (text_data.token[idx - 1] == ',' and has_next and
                             text_data.token[idx + 1] == ',')):
                continue
            text_data.class_tag[idx] = '[ARG]'
def remove_sem_dup(text_data, sem_data, exp_token):
    """
    Remove a semantic sequence when its numeric value also appears in the
    text, and renumber the N-variables of the expression accordingly.

    Numbering convention: text numbers come first (N0, N1, ...), followed
    by the numbers of the kept semantic sequences.  A removed duplicate's
    old id is redirected to the id of the equal text number.

    Bug fix: the original appended the sequence in BOTH branches of the
    duplicate check (so nothing was ever removed) while still compacting
    the numbering as if removal had happened, which corrupted the
    expression's N-variable references.
    """
    # Numeric values occurring in the text, in numbering order.
    text_num_list = []
    for k in range(len(text_data.token)):
        if text_data.class_tag[k] == '[NUM]':
            text_num_list.append(text_data.token[k])
    num_map_dict = {}
    token_, sect_tag_, class_tag_ = [], [], []
    old_id = len(text_num_list)   # next old id (text ids stay unchanged)
    new_id = len(text_num_list)   # next new id after compaction
    for k in range(len(sem_data.token)):
        # only sequences carrying a numeric value are kept at all
        if sem_data.class_tag[k][-2] != '[NUM]':
            continue
        old_name = 'N' + str(old_id)
        old_id += 1
        value = sem_data.token[k][-2]
        if value in text_num_list:
            # duplicate: drop the sequence, point its id at the text number
            num_map_dict[old_name] = 'N' + str(text_num_list.index(value))
        else:
            token_.append(sem_data.token[k])
            sect_tag_.append(sem_data.sect_tag[k])
            class_tag_.append(sem_data.class_tag[k])
            num_map_dict[old_name] = 'N' + str(new_id)
            new_id += 1
    # Rewrite the expression with the compacted numbering.
    for k in range(len(exp_token)):
        if exp_token[k] in num_map_dict:
            exp_token[k] = num_map_dict[exp_token[k]]
    sem_data.token = token_
    sem_data.sect_tag = sect_tag_
    sem_data.class_tag = class_tag_
def get_combined_text(text_seq, stru_seqs, sem_seqs, combine_text, args):
    '''
    Combine the sequences attribute-wise into combine_text, in the order
    [stru_seqs, text_cond, sem_seqs, text_prob] (stru_seqs is omitted when
    args.without_stru is set).

    Bug fix: when no token is tagged '[PROB]', begin_ind/end_ind stayed
    None and the None-slices duplicated the whole text into the condition
    part; the problem span is now simply empty in that case.
    '''
    # locate the [PROB] span in text_seq
    begin_ind = end_ind = None
    for k in range(len(text_seq.sect_tag)):
        if text_seq.sect_tag[k] == '[PROB]':
            begin_ind = k
            break
    for k in range(len(text_seq.sect_tag) - 1, -1, -1):
        if text_seq.sect_tag[k] == '[PROB]':
            end_ind = k + 1
            break
    if begin_ind is None:
        # no problem span: everything is condition, problem part is empty
        begin_ind = end_ind = len(text_seq.sect_tag)
    # combine text_seq, stru_seqs and sem_seqs for every attribute
    for key in vars(combine_text):
        text_all_value = getattr(text_seq, key)
        text_cond_value = text_all_value[:begin_ind] + text_all_value[end_ind:]
        text_prob_value = text_all_value[begin_ind:end_ind]
        if args.without_stru:
            value_all = text_cond_value + sum(getattr(sem_seqs, key), []) + text_prob_value
        else:
            value_all = sum(getattr(stru_seqs, key), []) + text_cond_value + \
                        sum(getattr(sem_seqs, key), []) + text_prob_value
        setattr(combine_text, key, value_all)
def get_var_arg(combine_text, args):
    """
    Collect numeric ('[NUM]') and argument ('[ARG]') tokens together with
    their positions in the combined sequence.

    Returns (positions, num_values, arg_values) where positions lists all
    '[NUM]' indices followed by all '[ARG]' indices.
    """
    num_values, num_positions = [], []
    arg_values, arg_positions = [], []
    for pos, tag in enumerate(combine_text.class_tag):
        if tag == '[NUM]':
            num_values.append(combine_text.token[pos])
            num_positions.append(pos)
        elif tag == '[ARG]':
            arg_values.append(combine_text.token[pos])
            arg_positions.append(pos)
    # merged positions of numbers and arguments
    return num_positions + arg_positions, num_values, arg_values
def get_text_index(combine_text, src_lang):
    """
    Convert the combined sequence to vocabulary indices.

    Section and class tags are indexed directly.  For '[NUM]' tokens the
    token itself is replaced by its argument letters, spread across two
    parallel token rows (both default to '[PAD]').
    NOTE(review): this assumes a numeric token holds at most two argument
    letters -- a third would index past the two rows; confirm upstream.
    """
    sect_ids = src_lang.indexes_from_sentence(combine_text.sect_tag, id_type='sect_tag')
    class_ids = src_lang.indexes_from_sentence(combine_text.class_tag, id_type='class_tag')
    rows = [combine_text.token[:], ['[PAD]'] * len(combine_text.token)]
    for pos, tag in enumerate(combine_text.class_tag):
        if tag != '[NUM]':
            continue
        letters = get_args(combine_text.token[pos])
        rows[0][pos] = rows[1][pos] = "[PAD]"
        for j, letter in enumerate(letters):
            rows[j][pos] = letter
    rows = [src_lang.indexes_from_sentence(row, id_type='text') for row in rows]
    return rows, sect_ids, class_ids