Spaces:
Sleeping
Sleeping
| import torch | |
| import os | |
| from PIL import Image | |
| import datasets.diagram_aug as T_diagram | |
| import datasets.text_aug as T_text | |
| from datasets.operators import normalize_exp | |
| from datasets.utils import get_combined_text, get_var_arg, get_text_index | |
| from datasets.preprossing import SN | |
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding a diagram tensor, indexed text and a target
    expression for each entry of ``pairs``.

    Args:
        args: Configuration object. This class reads ``args.diagram_size``,
            ``args.dataset_dir`` and (only when ``is_train`` is true)
            ``args.random_prob``; ``args`` is also forwarded to
            ``get_combined_text`` and ``get_var_arg``.
        pairs: Indexable collection of sample dicts (see ``__getitem__``).
        src_lang: Source-side vocabulary object, forwarded to
            ``get_text_index``.
        tgt_lang: Target-side vocabulary object; must provide
            ``indexes_from_sentence(expression, var_values, arg_values)``.
        is_train: When false, every random transform is constructed with
            probability ``0`` instead of ``args.random_prob``.
    """

    def __init__(self, args, pairs, src_lang, tgt_lang, is_train: bool = True):
        super().__init__()
        self.args = args
        self.pairs = pairs
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.is_train = is_train

        # Evaluation builds the same pipelines but with probability 0.
        random_prob = args.random_prob if is_train else 0

        self.diagram_transform = T_diagram.Compose([
            T_diagram.Resize(args.diagram_size),
            T_diagram.CenterCrop(args.diagram_size),
            T_diagram.RandomFlip(random_prob),
            T_diagram.ToTensor(),
            T_diagram.Normalize()
        ])
        self.text_transform = T_text.Compose([
            T_text.Point_RandomReplace(random_prob),
            T_text.AngID_RandomReplace(random_prob),
            # T_text.Arg_RandomReplace(random_prob),
            T_text.StruPoint_RandomRotate(random_prob),
            # T_text.SemPoint_RandomRotate(random_prob),
            T_text.SemSeq_RandomRotate(random_prob),
            T_text.StruSeq_RandomRotate(random_prob),
        ])

    def __getitem__(self, idx: int):
        """Build the model inputs and targets for sample ``idx``.

        Each entry of ``self.pairs`` is read as a dict with these keys::

            pair{
                'diagram': str              (file name under <dataset_dir>/Diagram)
                'text': SN()
                'parsing_stru_seqs': SN()
                'parsing_sem_seqs': SN()
                'expression': list
                'choices': iterable of values convertible to float
                'answer': str
                'id': sample identifier
            }

        Returns:
            An 11-tuple ``(diagram, text_token, text_sect_tag, text_class_tag,
            var_arg_positions, var_values, arg_values, expression, answer,
            id, choices)``.
        """
        pair = self.pairs[idx]

        # diagram
        diagram_path = os.path.join(self.args.dataset_dir, 'Diagram', pair['diagram'])
        # Open through a context manager so the file handle is released
        # deterministically; ``convert`` returns an independent, loaded image.
        with Image.open(diagram_path) as raw_diagram:
            diagram = raw_diagram.convert("RGB")
        diagram = self.diagram_transform(diagram)

        # text, parsing_stru_seqs, parsing_sem_seqs
        # NOTE(review): the return value is discarded, so the transform
        # presumably edits these objects in place -- and they live in
        # ``self.pairs``. Confirm that carrying edits over to later accesses
        # of the same sample is intended.
        self.text_transform(pair['text'],
                            pair['parsing_stru_seqs'],
                            pair['parsing_sem_seqs'],
                            pair['expression'])

        # ``combine_text`` is handed in empty and read afterwards; the helper's
        # return value is not used (presumably it fills the object in place).
        combine_text = SN()
        get_combined_text(pair['text'],
                          pair['parsing_stru_seqs'],
                          pair['parsing_sem_seqs'],
                          combine_text,
                          self.args)
        text_token, text_sect_tag, text_class_tag = \
            get_text_index(combine_text, self.src_lang)

        # var and arg
        var_arg_positions, var_values, arg_values = \
            get_var_arg(combine_text, self.args)

        # expression
        expression = normalize_exp(pair['expression'])
        expression = self.tgt_lang.indexes_from_sentence(expression, var_values, arg_values)

        # choices
        choices = [float(item) for item in pair['choices']]

        return diagram, \
            text_token, text_sect_tag, text_class_tag, \
            var_arg_positions, var_values, arg_values, \
            expression, pair['answer'], pair['id'], choices

    def __len__(self) -> int:
        """Return the number of samples in ``self.pairs``."""
        return len(self.pairs)