# pgps-demo / datasets / dataset.py
# Uploaded by: asdfasdfdsafdsa
# Commit message: Initial upload of PGPS demo with all dependencies
# Commit: 383bfb8 (verified)
import copy
import os

import torch
from PIL import Image

import datasets.diagram_aug as T_diagram
import datasets.text_aug as T_text
from datasets.operators import normalize_exp
from datasets.utils import get_combined_text, get_var_arg, get_text_index
from datasets.preprossing import SN
class MyDataset(torch.utils.data.Dataset):
    """Dataset pairing a diagram image with its parsed text and target expression.

    Each item loads a diagram from ``<args.dataset_dir>/Diagram``, applies the
    diagram and text augmentations configured in ``__init__``, builds a combined
    text sequence, and returns token/tag indices, variable/argument info, the
    encoded target expression, the answer, the sample id and numeric choices.

    Args:
        args: config namespace. This class reads ``random_prob`` (training
            only), ``diagram_size`` and ``dataset_dir``, and forwards ``args``
            to ``get_combined_text`` / ``get_var_arg``.
        pairs: indexable collection of sample dicts (see ``__getitem__``).
        src_lang: source vocabulary, passed to ``get_text_index``.
        tgt_lang: target vocabulary; its ``indexes_from_sentence`` encodes
            the normalized expression.
        is_train: when False, every augmentation probability is set to 0.
    """

    def __init__(self, args, pairs, src_lang, tgt_lang, is_train=True):
        super().__init__()
        self.args = args
        self.pairs = pairs
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.is_train = is_train
        # Random augmentation is only enabled for training; evaluation uses 0.
        random_prob = args.random_prob if is_train else 0
        self.diagram_transform = T_diagram.Compose([
            T_diagram.Resize(args.diagram_size),
            T_diagram.CenterCrop(args.diagram_size),
            T_diagram.RandomFlip(random_prob),
            T_diagram.ToTensor(),
            T_diagram.Normalize(),
        ])
        self.text_transform = T_text.Compose([
            T_text.Point_RandomReplace(random_prob),
            T_text.AngID_RandomReplace(random_prob),
            # Disabled: T_text.Arg_RandomReplace(random_prob),
            T_text.StruPoint_RandomRotate(random_prob),
            # Disabled: T_text.SemPoint_RandomRotate(random_prob),
            T_text.SemSeq_RandomRotate(random_prob),
            T_text.StruSeq_RandomRotate(random_prob),
        ])

    def __getitem__(self, idx):
        """Load, augment and encode the sample at ``idx``.

        ``self.pairs[idx]`` is read as a dict with these keys:
            'diagram': str, image file name under ``<dataset_dir>/Diagram``
            'text': SN()
            'parsing_stru_seqs': SN()
            'parsing_sem_seqs': SN()
            'expression': list
            'answer': str
            'id': sample identifier, returned unchanged
            'choices': iterable of values accepted by ``float``

        Returns:
            tuple ``(diagram, text_token, text_sect_tag, text_class_tag,
            var_arg_positions, var_values, arg_values, expression, answer,
            id, choices)``, where ``choices`` is a list of floats.
        """
        # Work on a private copy. text_transform's return value is discarded,
        # so it presumably edits its arguments in place; without the copy the
        # random augmentations would be written back into self.pairs and
        # compound from one epoch to the next.
        pair = copy.deepcopy(self.pairs[idx])

        # diagram
        diagram_path = os.path.join(self.args.dataset_dir, 'Diagram', pair['diagram'])
        # Close the file handle once the RGB image has been materialized.
        with Image.open(diagram_path) as img:
            diagram = img.convert("RGB")
        diagram = self.diagram_transform(diagram)

        # text, parsing_stru_seqs, parsing_sem_seqs (augmented in place)
        self.text_transform(pair['text'],
                            pair['parsing_stru_seqs'],
                            pair['parsing_sem_seqs'],
                            pair['expression'])
        # get_combined_text fills combine_text as an output argument.
        combine_text = SN()
        get_combined_text(pair['text'],
                          pair['parsing_stru_seqs'],
                          pair['parsing_sem_seqs'],
                          combine_text,
                          self.args)
        text_token, text_sect_tag, text_class_tag = \
            get_text_index(combine_text, self.src_lang)

        # var and arg
        var_arg_positions, var_values, arg_values = \
            get_var_arg(combine_text, self.args)

        # expression
        expression = normalize_exp(pair['expression'])
        expression = self.tgt_lang.indexes_from_sentence(expression, var_values, arg_values)

        # choices
        choices = [float(item) for item in pair['choices']]

        return diagram, \
            text_token, text_sect_tag, text_class_tag, \
            var_arg_positions, var_values, arg_values, \
            expression, pair['answer'], pair['id'], choices

    def __len__(self):
        """Return the number of samples in ``self.pairs``."""
        return len(self.pairs)