import torch import json from datasets.utils import * class SrcLang: def __init__(self, vocab_path): self.word2index = {} self.word2count = {} self.index2word = [] self.n_words = 0 self.get_vocab(vocab_path) self.class_tag = ['[PAD]', '[GEN]', '[POINT]', '[NUM]', '[ARG]', '[ANGID]'] self.sect_tag = ['[PAD]', '[PROB]', '[COND]', '[STRU]'] def get_vocab(self, vocab_path): with open(vocab_path, 'r') as f: for id, line in enumerate(f): vocab_token = line[:-1] self.word2index[vocab_token] = id self.word2count[vocab_token] = 0 self.index2word.append(vocab_token) self.n_words = len(self.index2word) def indexes_from_sentence(self, sentence, id_type='text'): res = [] if id_type == 'text': for word in sentence: if word in self.word2index: res.append(self.word2index[word]) self.word2count[word] += 1 else: res.append(self.word2index["[UNK]"]) self.word2count["[UNK]"] += 1 print("Can not find", word, 'in the src vocab') elif id_type=='class_tag': for word in sentence: res.append(self.class_tag.index(word)) elif id_type=='sect_tag': for word in sentence: res.append(self.sect_tag.index(word)) return res def sentence_from_indexes(self, indexes): res = [] for index in indexes: if index