# pgps-demo / datasets / preprossing.py
# asdfasdfdsafdsa's picture
# Initial upload of PGPS demo with all dependencies
# 383bfb8 verified
import torch
import json
from datasets.utils import *
class SrcLang:
    """Source-side vocabulary: converts tokens and tag labels to integer ids.

    The vocabulary is read from a text file with one token per line; a
    token's id is its zero-based line number. Tag labels are converted
    through the fixed ``class_tag`` / ``sect_tag`` lists instead.
    """

    def __init__(self, vocab_path):
        """Load the vocabulary at ``vocab_path`` and set up the tag label lists."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.get_vocab(vocab_path)
        # A tag's id is its position in the corresponding list.
        self.class_tag = ['[PAD]', '[GEN]', '[POINT]', '[NUM]', '[ARG]', '[ANGID]']
        self.sect_tag = ['[PAD]', '[PROB]', '[COND]', '[STRU]']

    def get_vocab(self, vocab_path):
        """Read one token per line from ``vocab_path`` into the lookup tables."""
        with open(vocab_path, 'r') as f:
            for idx, line in enumerate(f):
                # Strip only the line terminator. Slicing off the last
                # character unconditionally would truncate the final token
                # of a file that does not end with a newline.
                vocab_token = line.rstrip('\n')
                self.word2index[vocab_token] = idx
                self.word2count[vocab_token] = 0
                self.index2word.append(vocab_token)
        self.n_words = len(self.index2word)

    def indexes_from_sentence(self, sentence, id_type='text'):
        """Convert a sequence of words to a list of ids.

        Args:
            sentence: iterable of string tokens or tag labels.
            id_type: 'text' looks words up in the vocabulary (unknown words
                map to '[UNK]' and are reported on stdout); 'class_tag' and
                'sect_tag' use the position of the label in the matching
                tag list. Any other value yields an empty list.

        Returns:
            List of integer ids.

        Raises:
            KeyError: an unknown word is met and '[UNK]' is not in the vocabulary.
            ValueError: a tag label is not in the selected tag list.
        """
        res = []
        if id_type == 'text':
            for word in sentence:
                if word in self.word2index:
                    res.append(self.word2index[word])
                    self.word2count[word] += 1
                else:
                    res.append(self.word2index["[UNK]"])
                    self.word2count["[UNK]"] += 1
                    print("Can not find", word, 'in the src vocab')
        elif id_type == 'class_tag':
            res = [self.class_tag.index(word) for word in sentence]
        elif id_type == 'sect_tag':
            res = [self.sect_tag.index(word) for word in sentence]
        return res

    def sentence_from_indexes(self, indexes):
        """Convert ids back to tokens; ids at or past the vocabulary size become ''."""
        vocab_size = len(self.index2word)
        return [
            self.index2word[index] if index < vocab_size else ""
            for index in indexes
        ]
class TgtLang:
    """Target-side vocabulary for expression tokens.

    The vocabulary is read from a text file with one token per line; a
    token's id is its zero-based line number. ``var_start`` is the number
    of entries that are not ``N<digit>`` tokens and serves as the base id
    when argument tokens are encoded.
    NOTE(review): using ``var_start`` as an id offset presumably relies on
    the ``N<digit>`` entries being listed last in the vocabulary file;
    confirm against the vocabulary used.
    """

    def __init__(self, vocab_path):
        """Load the vocabulary at ``vocab_path``."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.var_start = 0
        self.get_vocab(vocab_path)

    @staticmethod
    def _is_numbered(token, prefix):
        """Return True if ``token`` is ``prefix`` followed by a digit (e.g. 'V0')."""
        # The length guard keeps one-character tokens such as 'V' from
        # raising IndexError on token[1].
        return len(token) > 1 and token[0] == prefix and token[1].isdigit()

    def get_vocab(self, vocab_path):
        """Read one token per line and count the entries that precede variables."""
        spe_num = midvar_num = const_num = op_num = 0
        with open(vocab_path, 'r') as f:
            for idx, line in enumerate(f):
                # Strip only the line terminator so the last token survives
                # a file without a trailing newline.
                vocab_token = line.rstrip('\n')
                self.word2index[vocab_token] = idx
                self.word2count[vocab_token] = 0
                self.index2word.append(vocab_token)
                if vocab_token.startswith('[') and vocab_token.endswith(']'):
                    spe_num += 1
                elif self._is_numbered(vocab_token, 'V'):
                    midvar_num += 1
                elif self._is_numbered(vocab_token, 'C'):
                    const_num += 1
                elif not self._is_numbered(vocab_token, 'N'):
                    # Everything that is not special / V* / C* / N* counts
                    # as an operator (this includes an empty line).
                    op_num += 1
        self.n_words = len(self.index2word)
        self.var_start = spe_num + midvar_num + const_num + op_num

    def indexes_from_sentence(self, sentence, var_values, arg_values):
        """Convert expression tokens to ids, wrapped in '[SOS]' ... '[EOS]'.

        Args:
            sentence: iterable of string tokens.
            var_values: sequence whose length offsets argument ids.
            arg_values: sequence of argument tokens; a one-character
                lowercase word missing from the vocabulary is encoded as
                ``var_start + len(var_values) + arg_values.index(word)``.

        Returns:
            List of integer ids. Words that are neither in the vocabulary
            nor argument-like are reported on stdout and skipped.

        Raises:
            ValueError: an argument-like word is not in ``arg_values``.
            KeyError: '[SOS]' or '[EOS]' is not in the vocabulary.
        """
        res = []
        for word in sentence:
            if word in self.word2index:
                res.append(self.word2index[word])
                self.word2count[word] += 1
            elif len(word) == 1 and word.islower():  # arg
                res.append(self.var_start + len(var_values) + arg_values.index(word))
            else:
                print("Can not find", word, 'in the tgt vocab')
        return [self.word2index["[SOS]"]] + res + [self.word2index["[EOS]"]]

    def sentence_from_indexes(self, indexes, change_dict=None):
        """Convert ids back to tokens.

        Ids at or past the vocabulary size become ''. Each resulting token
        that is a key of ``change_dict`` is replaced by its mapped value
        (var2arg). ``change_dict`` defaults to no replacement.
        """
        # None instead of a shared mutable default dict.
        mapping = {} if change_dict is None else change_dict
        vocab_size = len(self.index2word)
        res = []
        for index in indexes:
            item = self.index2word[index] if index < vocab_size else ''
            if item in mapping:
                item = mapping[item]  # var2arg
            res.append(item)
        return res
class SN:
    """Plain container grouping a token sequence with its two tag sequences."""

    def __init__(self):
        # token:     str list
        # sect_tag:  [PROB]/[COND]/[STRU]
        # class_tag: [GEN]/[NUM]/[ARG]/[POINT]/[ANGID]
        # Each attribute gets its own fresh list.
        self.token, self.sect_tag, self.class_tag = [], [], []
def get_raw_pairs(dataset_path):
    """Load the JSON dataset at ``dataset_path`` and return its entries as a list.

    For every entry, the 'text', 'parsing_stru_seqs' and 'parsing_sem_seqs'
    values are replaced by SN objects carrying tokens, section tags and
    class tags; 'expression' is replaced by its space-split token list and
    the entry's key is stored under 'id'. Tokenization and tagging are done
    by helpers (get_token, split_text, get_point_angleID_tag,
    get_num_arg_tag, remove_sem_dup) that are defined outside this function.
    """
    with open(dataset_path, 'r') as fp:
        content_all = json.load(fp)

    raw_pairs = []
    for pair_id, entry in content_all.items():
        raw_text = entry['text']
        stru_seqs = entry['parsing_stru_seqs']
        sem_seqs = entry['parsing_sem_seqs']

        text_data = SN()
        stru_data = SN()
        sem_data = SN()

        # Tokenize; every structure/semantic clause is closed with a ','.
        text_data.token = get_token(raw_text)
        stru_data.token = [get_token(seq) + [','] for seq in stru_seqs]
        sem_data.token = [get_token(seq) + [','] for seq in sem_seqs]

        # Section tags: clauses are uniformly tagged here, while the text
        # tags are filled in by split_text (prob / cond split).
        text_data.sect_tag = []
        stru_data.sect_tag = [['[STRU]' for _ in toks] for toks in stru_data.token]
        sem_data.sect_tag = [['[COND]' for _ in toks] for toks in sem_data.token]
        split_text(text_data)

        # Class tags start as '[GEN]' everywhere and are then refined in
        # place by the two tagging helpers.
        text_data.class_tag = ['[GEN]'] * len(text_data.token)
        stru_data.class_tag = [['[GEN]' for _ in toks] for toks in stru_data.token]
        sem_data.class_tag = [['[GEN]' for _ in toks] for toks in sem_data.token]
        get_point_angleID_tag(text_data, stru_data, sem_data)
        get_num_arg_tag(text_data, sem_data)

        # Tag the repeated [NUM] in sem_data that already exists in text_data.
        expression = entry['expression'].split(' ')
        remove_sem_dup(text_data, sem_data, expression)

        entry.update({
            'text': text_data,
            'parsing_stru_seqs': stru_data,
            'parsing_sem_seqs': sem_data,
            'expression': expression,
            'id': pair_id,
        })
        raw_pairs.append(entry)
    return raw_pairs
class collater():
    """Callable that merges a list of samples into padded batch tensors."""

    def __init__(self, args):
        """Keep ``args`` on the instance; it is not read by ``__call__``."""
        self.args = args

    @staticmethod
    def _pad(seq, target_len, padding_id):
        """Return a NEW list: ``seq`` followed by padding up to ``target_len``.

        Sequences already at or beyond ``target_len`` are copied unchanged
        (never truncated).
        """
        return list(seq) + [padding_id] * (target_len - len(seq))

    def __call__(self, batch_data, padding_id=0):
        """Collate ``batch_data`` into ``(diagrams, text_dict, var_dict, exp_dict)``.

        Args:
            batch_data: sequence of 11-tuples in the order (diagram,
                text_tokens, text_sect_tags, text_class_tags,
                var_arg_positions, var_values, arg_values, expression,
                answer, pair_id, choices). The samples are not modified.
            padding_id: value used to right-pad every id sequence.

        Returns:
            diagrams: the per-sample diagram tensors stacked along dim 0.
            text_dict: 'token', 'sect_tag', 'class_tag' LongTensors padded
                to the longest class-tag sequence, plus 'len' (unpadded
                class-tag lengths).
            var_dict: 'pos' LongTensor of padded positions, 'len' (at least
                1 per sample), and the raw 'var_value' / 'arg_value' tuples.
            exp_dict: 'exp' LongTensor of padded expressions, 'len', and
                the raw 'answer', 'id', 'choices' tuples.

        Raises:
            ValueError: ``batch_data`` is empty.
        """
        (diagrams,
         text_tokens, text_sect_tags, text_class_tags,
         var_arg_positions, var_values, arg_values,
         expression, answer, pair_ids, choices) = zip(*batch_data)

        #######################################
        diagrams = torch.stack(diagrams, dim=0)

        #######################################
        len_exp = [len(seq_exp) for seq_exp in expression]
        max_len_exp = max(len_exp)
        exp_dict = {
            'exp': torch.LongTensor(
                [self._pad(seq_exp, max_len_exp, padding_id) for seq_exp in expression]),
            'len': torch.LongTensor(len_exp),
            'answer': answer,
            'id': pair_ids,
            'choices': choices,
        }

        #######################################
        # Lengths are clamped to 1, so a sample without any variable still
        # gets one padded slot.
        len_var = [max(len(seq_var), 1) for seq_var in var_arg_positions]
        max_len_var = max(len_var)
        var_dict = {
            'pos': torch.LongTensor(
                [self._pad(seq_var, max_len_var, padding_id) for seq_var in var_arg_positions]),
            'len': torch.LongTensor(len_var),
            'var_value': var_values,
            'arg_value': arg_values,
        }

        ########################################
        len_text = [len(seq_tag) for seq_tag in text_class_tags]
        max_len_text = max(len_text)
        # Each sample's text_tokens is a list of id rows. The rows are
        # padded into new lists so the caller's sample data is left intact
        # (in-place `+=` would mutate the dataset's own lists).
        # NOTE(review): building one tensor presumably requires every
        # sample to have the same number of rows; confirm against the dataset.
        padded_tokens = [
            [self._pad(row, max_len_text, padding_id) for row in sample_rows]
            for sample_rows in text_tokens
        ]
        text_dict = {
            'token': torch.LongTensor(padded_tokens),
            'sect_tag': torch.LongTensor(
                [self._pad(seq_tag, max_len_text, padding_id) for seq_tag in text_sect_tags]),
            'class_tag': torch.LongTensor(
                [self._pad(seq_tag, max_len_text, padding_id) for seq_tag in text_class_tags]),
            'len': torch.LongTensor(len_text),
        }
        return diagrams, text_dict, var_dict, exp_dict