Spaces:
Sleeping
Sleeping
File size: 10,851 Bytes
383bfb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 |
# Sentence punctuation that get_token splits off the end of a word.
punctuation_list = ['.', '?', ',']
# Single-character strings '0'..'9'.
digit_list = [str(d) for d in range(10)]
# ASCII 'A'..'Z' and 'a'..'z', one character per entry.
capital_letter_list = [chr(code) for code in range(ord('A'), ord('Z') + 1)]
low_letter_list = [chr(code) for code in range(ord('a'), ord('z') + 1)]
# Words that open the question ('[PROB]') span in split_text.
begin_words = ["find", "what", "solve", "determine", "express", "how"]
# Tokens that close the question span in split_text.
# NOTE(review): get_token splits on single spaces, so the multi-word entry
# "for which" can never equal a single token.
end_words = [
    ".", ",", '?', "if", "so", "for which", "given", "with", "on",
    "in", "must", 'for', 'that', 'formed',
]
# Squared-unit tokens that get_num_arg_tag never tags as numbers.
unit_list = [
    "mm^{2}", "cm^{2}", "in^{2}", "ft^{2}",
    "yd^{2}", "km^{2}", "units^{2}", "mi^{2}", "m^{2}",
]
# LaTeX commands / operators stripped by get_args and used by
# get_num_arg_tag to recognise numeric expressions.
special_token_list = ['\\frac', '\\pi', '\\sqrt', "+", "-", "^"]
def get_token(ss):
    """
    Tokenizer: divide the textual problem into words.

    The stripped input is split on single spaces. A trailing character from
    punctuation_list ('.', '?', ',') is split off into its own token. A word
    made only of ASCII capital letters, each optionally followed by primes (')
    or digits, is treated as a run of point labels and split per point, e.g.
    "ABC" -> ["A", "B", "C"] and "A'B1" -> ["A'", "B1"]. Every other word is
    lower-cased. Empty pieces (from repeated spaces, an empty input, or the
    empty remainder of a lone punctuation mark) produce no token.

    Args:
        ss: the raw problem text.

    Returns:
        list of str tokens.
    """
    # Split trailing punctuation off each space-separated word.
    words = []
    for item in ss.strip().split(' '):
        if not item:
            # Repeated spaces / empty input give '' here; indexing item[-1]
            # on it used to raise IndexError.
            continue
        if item[-1] in punctuation_list:
            words.append(item[:-1])
            words.append(item[-1])
        else:
            words.append(item)

    # Split runs of point labels (capital letters + optional primes/digits).
    tokens = []
    for item in words:
        points = []
        for ch in item:
            if 'A' <= ch <= 'Z':
                points.append(ch)
            elif (ch == '\'' or ch in digit_list) and points:
                # A prime or digit decorates the preceding point, e.g. A', B1.
                points[-1] += ch
            else:
                # Not a pure point sequence: keep the word, lower-cased.
                tokens.append(item.lower())
                break
        else:
            # Every character belonged to a point label (an empty word adds
            # nothing here).
            tokens.extend(points)
    return tokens
def split_text(text_data):
    """
    Split the textual problem into condition and problem (target).

    Writes text_data.sect_tag, which has one tag per token of text_data.token:
    '[PROB]' inside the question span and '[COND]' elsewhere. The span starts
    at the first token found in begin_words (index 0 when none is found). It
    ends at the first token in end_words located at least two positions after
    the start. A punctuation terminator is included in the span; a word
    terminator is not. Without a terminator the span runs to the end.

    Returns None. When text_data.token is empty, sect_tag is left untouched.
    """
    tokens = text_data.token
    n_tokens = len(tokens)
    if not n_tokens:
        return

    start = next((pos for pos, tok in enumerate(tokens) if tok in begin_words), 0)

    stop = n_tokens
    for pos in range(start + 2, n_tokens):
        tok = tokens[pos]
        if tok not in end_words:
            continue
        # Punctuation ('.', '?', ',') belongs to the question; words do not.
        stop = pos + 1 if tok in punctuation_list else pos
        break

    text_data.sect_tag = (['[COND]'] * start
                          + ['[PROB]'] * (stop - start)
                          + ['[COND]'] * (n_tokens - stop))
def _tag_points_and_angle_ids(tokens, class_tags):
    """Tag '[POINT]' / '[ANGID]' entries of one token sequence, in place."""
    for pos, tok in enumerate(tokens):
        # Guard against empty tokens: tok[0] would raise IndexError.
        if tok and tok[0] in capital_letter_list:
            class_tags[pos] = '[POINT]'
        # An all-digit token right after "\angle" is an angle id (e.g. \angle 1).
        if tok.isdigit() and pos > 0 and tokens[pos - 1] == "\\angle":
            class_tags[pos] = '[ANGID]'


def get_point_angleID_tag(text_data, stru_data, sem_data):
    """
    Tag point labels and angle ids in the text, structure and semantic data.

    A token whose first character is in capital_letter_list is tagged
    '[POINT]'. An all-digit token directly preceded by '\\angle' is tagged
    '[ANGID]'. Other class tags are left unchanged.

    text_data.token / class_tag are parallel lists. stru_data and sem_data hold
    lists of such parallel sequences (token[k] pairs with class_tag[k]).
    All class_tag lists are modified in place.
    """
    _tag_points_and_angle_ids(text_data.token, text_data.class_tag)
    for seq_data in (stru_data, sem_data):
        for k, tokens in enumerate(seq_data.token):
            _tag_points_and_angle_ids(tokens, seq_data.class_tag[k])
def get_args(token):
    """
    Return the distinct lowercase letters (candidate variables) in token.

    Every entry of special_token_list is deleted from the token first, in list
    order, so that e.g. the letters of '\\frac' or '\\sqrt' are not counted.
    The remaining characters found in low_letter_list are returned once each,
    in order of first appearance.
    """
    stripped = token
    for special in special_token_list:
        stripped = stripped.replace(special, "")
    # dict.fromkeys keeps insertion order, giving an order-preserving dedup.
    return list(dict.fromkeys(ch for ch in stripped if ch in low_letter_list))
def _token_at(tokens, pos):
    """Return tokens[pos], or None when pos is outside the list (no wrap-around)."""
    if 0 <= pos < len(tokens):
        return tokens[pos]
    return None


def _is_listed_name(tokens, pos):
    """
    True when tokens[pos] looks like a name inside a listing: the previous
    token contains 'line', is 'and', or is ',' with a ',' also following.
    """
    if pos == 0:
        return False
    prev_tok = tokens[pos - 1]
    # _token_at keeps the ", x ," test from indexing past the last token,
    # which previously raised IndexError for a trailing ", x".
    return ('line' in prev_tok or prev_tok == 'and'
            or (prev_tok == ',' and _token_at(tokens, pos + 1) == ','))


def get_num_arg_tag(text_data, sem_data):
    """
    Determine the variables/arguments in the text condition.

    Class tags are modified in place, in four passes:

    1. For each sem_data sequence with at least 3 tokens whose third-to-last
       token is '=', tag the second-to-last token '[NUM]' and collect its
       lowercase letters (get_args).
    2. For each '[GEN]' token in the '[COND]' part of the text (tokens in
       unit_list excepted), tag it '[NUM]' when it contains a digit, when it
       contains an entry of special_token_list without being exactly that
       entry, or when it is a single lowercase letter that is not
         - directly followed by '=',
         - an 'm' directly followed by '\\angle' / '\\widehat'
           (presumably the m-angle measure notation),
         - an 'a' not directly preceded by '=' (presumably the article),
         - a listed name (_is_listed_name), unless collected in pass 1.
    3. Collect the lowercase letters of every '[NUM]' token in '[COND]'.
    4. Tag '[ARG]' on every remaining '[GEN]' token (any section) equal to a
       letter from passes 1 or 3. A following '=' always yields '[ARG]'. The
       same 'm' exclusion applies; the 'a' exclusion only applies in '[COND]';
       the listed-name exclusion applies unconditionally.
    """
    tokens = text_data.token
    sect_tags = text_data.sect_tag
    class_tags = text_data.class_tag
    angle_markers = ("\\angle", "\\widehat")

    # Pass 1: numbers on the right-hand side of semantic equations.
    arg_sem_flat = []
    for k, sem_tokens in enumerate(sem_data.token):
        if len(sem_tokens) >= 3 and sem_tokens[-3] == '=':
            sem_data.class_tag[k][-2] = '[NUM]'
            arg_sem_flat += get_args(sem_tokens[-2])

    # Pass 2: numbers in the text condition.
    for pos, token in enumerate(tokens):
        if sect_tags[pos] != '[COND]' or class_tags[pos] != '[GEN]':
            continue
        if token in unit_list:
            continue
        # Digit present (rough judgment).
        if any(digit in token for digit in digit_list):
            class_tags[pos] = '[NUM]'
        # Contains special characters, but is not only a special token.
        if any(word in token and word != token for word in special_token_list):
            class_tags[pos] = '[NUM]'
        # Single lowercase letter, except for the special cases below.
        if token in low_letter_list:
            next_tok = _token_at(tokens, pos + 1)
            if next_tok == '=':
                continue
            if token == 'm' and next_tok in angle_markers:
                continue
            if token == 'a' and _token_at(tokens, pos - 1) != '=':
                continue
            if token not in arg_sem_flat and _is_listed_name(tokens, pos):
                continue
            class_tags[pos] = '[NUM]'

    # Pass 3: letters used inside the condition's numbers.
    arg_text_flat = []
    for pos, token in enumerate(tokens):
        if sect_tags[pos] == '[COND]' and class_tags[pos] == '[NUM]':
            arg_text_flat += get_args(token)
    arg_all = set(arg_text_flat + arg_sem_flat)

    # Pass 4: determine arguments.
    for pos, token in enumerate(tokens):
        if class_tags[pos] != '[GEN]' or token not in arg_all:
            continue
        next_tok = _token_at(tokens, pos + 1)
        if next_tok == '=':
            class_tags[pos] = '[ARG]'
            continue
        if token == 'm' and next_tok in angle_markers:
            continue
        if (token == 'a' and _token_at(tokens, pos - 1) != '='
                and sect_tags[pos] == '[COND]'):
            continue
        if _is_listed_name(tokens, pos):
            continue
        class_tags[pos] = '[ARG]'
def remove_sem_dup(text_data, sem_data, exp_token):
    """
    Remove the seq of sem_data if num is also in the text_data
    and change the corresponding expression.

    Numbers are named by position: the '[NUM]' tokens of text_data get
    N0..N(t-1) in order. These are followed by one name per sem_data sequence
    whose second-to-last class tag is '[NUM]'. A semantic sequence is dropped
    when its number (its second-to-last token) already appears among the text
    numbers; all other sequences are kept in order.

    exp_token is rewritten in place to match the compacted numbering. Each kept
    number is renamed to its new position. Each dropped duplicate is redirected
    to the first text number with the same value; previously it was left
    unchanged and so silently referred to a different number. Other tokens are
    untouched.

    sem_data.token / sect_tag / class_tag are replaced by the filtered lists.
    """
    text_nums = [tok for k, tok in enumerate(text_data.token)
                 if text_data.class_tag[k] == '[NUM]']
    # First text position of each value, used to redirect duplicates.
    first_text_pos = {}
    for pos, value in enumerate(text_nums):
        first_text_pos.setdefault(value, pos)

    # Old names of the numbers that survive, in their final order.
    kept_old_names = ['N' + str(pos) for pos in range(len(text_nums))]
    dup_redirect = {}
    next_id = len(text_nums)
    token_, sect_tag_, class_tag_ = [], [], []
    for k, seq in enumerate(sem_data.token):
        tags = sem_data.class_tag[k]
        # len check: a sequence shorter than 2 cannot carry a number at [-2].
        if len(tags) >= 2 and tags[-2] == '[NUM]':
            old_name = 'N' + str(next_id)
            next_id += 1
            value = seq[-2]
            if value in first_text_pos:
                dup_redirect[old_name] = 'N' + str(first_text_pos[value])
                continue
            kept_old_names.append(old_name)
        token_.append(seq)
        sect_tag_.append(sem_data.sect_tag[k])
        class_tag_.append(tags)

    num_map_dict = {old: 'N' + str(new) for new, old in enumerate(kept_old_names)}
    num_map_dict.update(dup_redirect)
    for k, tok in enumerate(exp_token):
        if tok in num_map_dict:
            exp_token[k] = num_map_dict[tok]

    sem_data.token = token_
    sem_data.sect_tag = sect_tag_
    sem_data.class_tag = class_tag_
def get_combined_text(text_seq, stru_seqs, sem_seqs, combine_text, args):
    '''
    combination style: [stru_seqs, text_cond, sem_seqs, text_prob]

    The question part of text_seq is the range from the first to the last
    '[PROB]' entry of text_seq.sect_tag. Everything outside that range is the
    condition part. If there is no '[PROB]' tag, the whole text is condition
    and the question part is empty. Previously both indices stayed None in
    that case, which repeated the whole text three times.

    For every attribute name of combine_text (e.g. token / sect_tag /
    class_tag), the same attribute is read from text_seq (a flat list) and
    from stru_seqs / sem_seqs (lists of lists, flattened here). The combined
    list is then stored on combine_text. When args.without_stru is truthy,
    the structure part is omitted.
    '''
    import itertools

    sect_tag = text_seq.sect_tag
    prob_positions = [k for k, tag in enumerate(sect_tag) if tag == '[PROB]']
    if prob_positions:
        begin_ind, end_ind = prob_positions[0], prob_positions[-1] + 1
    else:
        begin_ind = end_ind = len(sect_tag)

    def flatten(seqs):
        # Linear-time concatenation; sum(seqs, []) is quadratic.
        return list(itertools.chain.from_iterable(seqs))

    for key in vars(combine_text):
        text_all_value = getattr(text_seq, key)
        text_cond_value = text_all_value[:begin_ind] + text_all_value[end_ind:]
        text_prob_value = text_all_value[begin_ind:end_ind]
        sem_value = flatten(getattr(sem_seqs, key))
        if args.without_stru:
            value_all = text_cond_value + sem_value + text_prob_value
        else:
            value_all = (flatten(getattr(stru_seqs, key)) + text_cond_value
                         + sem_value + text_prob_value)
        setattr(combine_text, key, value_all)
def get_var_arg(combine_text, args):
    """
    Collect positions and values of '[NUM]' and '[ARG]' tokens.

    Args:
        combine_text: object with parallel lists .token and .class_tag.
        args: unused; kept for interface compatibility.

    Returns:
        (positions, var_values, arg_values). positions lists every '[NUM]'
        index followed by every '[ARG]' index, each group in ascending order.
        var_values / arg_values are the corresponding tokens.
    """
    tags = combine_text.class_tag
    words = combine_text.token
    var_positions = [pos for pos, tag in enumerate(tags) if tag == '[NUM]']
    arg_positions = [pos for pos, tag in enumerate(tags) if tag == '[ARG]']
    var_values = [words[pos] for pos in var_positions]
    arg_values = [words[pos] for pos in arg_positions]
    # Merge the positions of variables (numbers) and arguments.
    return var_positions + arg_positions, var_values, arg_values
def get_text_index(combine_text, src_lang):
    """
    Convert the combined sequence to index lists via src_lang.

    sect_tag and class_tag are passed to src_lang.indexes_from_sentence with
    id_type 'sect_tag' / 'class_tag'. The tokens become two rows of equal
    length. Row 0 holds the tokens, and row 1 is '[PAD]'. At every '[NUM]'
    position, both rows are replaced by the variable letters of that number
    (get_args), padded with '[PAD]'. Only as many letters as there are rows (2)
    fit; any extra letters are dropped instead of raising IndexError as before.
    Each row is then passed to indexes_from_sentence with id_type='text'.
    (indexes_from_sentence is defined elsewhere; presumably it maps tokens to
    vocabulary ids — confirm against the Lang class.)

    Returns:
        (text_token, text_sect_tag, text_class_tag)
    """
    text_sect_tag = src_lang.indexes_from_sentence(combine_text.sect_tag, id_type='sect_tag')
    text_class_tag = src_lang.indexes_from_sentence(combine_text.class_tag, id_type='class_tag')
    text_token = [combine_text.token[:], ['[PAD]'] * len(combine_text.token)]
    for k, tag in enumerate(combine_text.class_tag):
        if tag != '[NUM]':
            continue
        letters = get_args(combine_text.token[k])
        for row in text_token:
            row[k] = '[PAD]'
        # zip stops at the shorter side, capping letters to the number of rows.
        for row, letter in zip(text_token, letters):
            row[k] = letter
    text_token = [src_lang.indexes_from_sentence(row, id_type='text') for row in text_token]
    return text_token, text_sect_tag, text_class_tag
|