asdfasdfdsafdsa commited on
Commit
383bfb8
·
verified ·
1 Parent(s): 218dd62

Initial upload of PGPS demo with all dependencies

Browse files
LM_MODEL.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d0c84cefe6acd4fd66020d40eaebae5da6e3cf231266193d6f53d465ae627d0
3
+ size 64083797
README.md CHANGED
@@ -1,12 +1,52 @@
1
  ---
2
- title: Pgps Demo
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.43.1
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: PGPS Geometric Problem Solver
3
+ emoji: 📐
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
+ # PGPS: Neural Geometric Problem Solver Demo
14
+
15
+ This Space demonstrates the PGPS (Plane Geometry Problem Solver) model, which uses multi-modal neural networks to solve geometry problems.
16
+
17
+ ## How to Use
18
+
19
+ 1. **Upload a Geometry Diagram**: Upload an image containing a geometric diagram (triangles, angles, lines, etc.)
20
+ 2. **Enter Problem Text**: Provide the text description of the geometry problem
21
+ 3. **Get Solution**: The model will analyze both the diagram and text to generate a solution
22
+
23
+ ## Model Details
24
+
25
+ - **Architecture**: Multi-modal neural network with visual encoder and text encoder
26
+ - **Task**: Geometric problem solving
27
+ - **Paper**: IJCAI 2023
28
+ - **Original Repository**: [GitHub](https://github.com/mingliangzhang2018/PGPS)
29
+
30
+ ## Features
31
+
32
+ - Visual diagram parsing
33
+ - Text understanding for geometric problems
34
+ - Expression generation for solutions
35
+ - Support for various geometry problem types
36
+
37
+ ## Limitations
38
+
39
+ - Best performance with clear, simple geometric diagrams
40
+ - Requires both image and text input for optimal results
41
+ - Limited to plane geometry problems
42
+
43
+ ## Citation
44
+
45
+ ```bibtex
46
+ @inproceedings{zhang2023pgps,
47
+ title={PGPS: A Neural Geometric Solver},
48
+ author={Zhang, Mingliang and others},
49
+ booktitle={IJCAI 2023},
50
+ year={2023}
51
+ }
52
+ ```
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import numpy as np
6
+ import sys
7
+ import os
8
+
9
+ # Add current directory to path
10
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
11
+
12
+ from core.network import Network, MLMTransformerPretrain
13
+ from model.backbone import get_visual_backbone
14
+ from model.encoder import get_encoder
15
+ from model.decoder import get_decoder
16
+ from datasets.preprossing import SN
17
+ from datasets.utils import get_combined_text, get_var_arg, get_text_index
18
+ from datasets.operators import normalize_exp
19
+ import datasets.diagram_aug as T_diagram
20
+
21
+ # Language classes for vocabulary management
22
class Lang:
    """Word-level vocabulary with the reserved symbols PAD/SOS/EOS/UNK.

    Indices 0-3 are reserved for the special symbols; real words are
    assigned increasing indices in insertion order starting at 4.
    """

    def __init__(self):
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "PAD", 1: "SOS", 2: "EOS", 3: "UNK"}
        self.n_words = 4
        # Token-class and section tags consumed by the encoder's tag embeddings.
        self.class_tag = ['PAD', 'QUE', 'VAR', 'NUM', 'SEP']
        self.sect_tag = ['PAD', 'TEXT', 'STRU', 'SEM']

    def add_sentence(self, sentence):
        """Register every space-separated word of *sentence*."""
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        """Add *word* to the vocabulary, or bump its count if already known."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def indexes_from_sentence(self, sentence, var_values=None, arg_values=None):
        """Map a sentence to vocabulary indices; unknown words map to UNK (3).

        ``var_values`` / ``arg_values`` are accepted for interface
        compatibility with callers but are currently unused.
        """
        # dict.get with a default replaces the original membership-test loop.
        return [self.word2index.get(word, 3) for word in sentence.split(' ')]
52
+
53
+ # Configuration class
54
class Config:
    """Inference-time hyper-parameters for the demo.

    Mirrors the defaults declared in ``config/config_default.py`` so the
    Space runs without parsing command-line arguments.
    """

    def __init__(self):
        # --- visual backbone ------------------------------------------------
        self.visual_backbone = "ResNet10"
        self.diagram_size = 128

        # --- text encoder ---------------------------------------------------
        self.encoder_type = "gru"
        self.encoder_layers = 2
        self.encoder_embedding_size = 256
        self.encoder_hidden_size = 512
        self.max_input_len = 400

        # --- decoder --------------------------------------------------------
        self.decoder_type = "rnn_decoder"
        self.decoder_layers = 2
        self.decoder_embedding_size = 512
        self.decoder_hidden_size = 512
        self.max_output_len = 40

        # --- general --------------------------------------------------------
        self.dropout_rate = 0.2
        self.beam_size = 10
        self.use_MLM_pretrain = True
        self.MLM_pretrain_path = './LM_MODEL.pth'
        self.pretrain_emb_path = ''

        # --- dataset --------------------------------------------------------
        self.without_stru = False
83
+
84
+ # Initialize model
85
def load_model():
    """Build the PGPS Network plus vocabularies and load pretrained LM weights.

    Returns:
        (model, src_lang, tgt_lang, cfg) with the model in eval mode.
    """
    cfg = Config()

    # Load vocabularies: one token per line in each vocab file.
    # Fix: explicit encoding so the read does not depend on the host locale.
    src_lang = Lang()
    tgt_lang = Lang()

    with open('./vocab/vocab_src.txt', 'r', encoding='utf-8') as f:
        for line in f:
            src_lang.add_word(line.strip())

    with open('./vocab/vocab_tgt.txt', 'r', encoding='utf-8') as f:
        for line in f:
            tgt_lang.add_word(line.strip())

    model = Network(cfg, src_lang, tgt_lang)

    # NOTE(review): Network.__init__ already loads cfg.MLM_pretrain_path when
    # use_MLM_pretrain is set; this explicit load is kept as a safety net.
    # Using cfg.MLM_pretrain_path instead of a second hard-coded path keeps
    # the two in sync.
    if os.path.exists(cfg.MLM_pretrain_path):
        model.mlm_pretrain.load_model(cfg.MLM_pretrain_path)

    model.eval()
    return model, src_lang, tgt_lang, cfg
110
+
111
+ # Process image and text
112
def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one inference pass of the PGPS model.

    Parameters:
        image: PIL image of the geometry diagram.
        text_input: problem statement as a plain string.
        model / src_lang / tgt_lang / cfg: objects produced by ``load_model``.

    Returns:
        A human-readable string with the decoded expression, or an
        explanatory message when decoding fails.
    """
    # Diagram preprocessing: resize + center-crop to cfg.diagram_size, then
    # tensorize and normalize (must match the training-time transform).
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])
    diagram = diagram_transform(image).unsqueeze(0)  # add batch dimension

    # Wrap the raw problem text in the SN structure the dataset code expects.
    text_sn = SN()
    text_sn.word_list = text_input.split()
    text_sn.clause_list = [text_input]

    # No diagram-parsing annotations exist at demo time, so the structural
    # and semantic clause containers stay empty.
    parsing_stru = SN()
    parsing_stru.word_list = []
    parsing_stru.clause_list = []

    parsing_sem = SN()
    parsing_sem.word_list = []
    parsing_sem.clause_list = []

    combine_text = SN()
    get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)

    # Tokenize the combined text into vocabulary ids plus tag sequences.
    text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)

    text_dict = {
        'token': torch.LongTensor([text_token]),
        'sect_tag': torch.LongTensor([text_sect_tag]),
        'class_tag': torch.LongTensor([text_class_tag]),
        'len': torch.LongTensor([len(text_token)])
    }

    # Positions and values of variables/arguments referenced by the problem.
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)

    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]),
        'len': torch.LongTensor([len(var_arg_positions)]),
        'var_value': var_values,
        'arg_value': arg_values
    }

    # Dummy target: decoding starts from the SOS token only.
    exp_dict = {
        'exp': torch.LongTensor([[1]]),  # SOS token
        'len': torch.LongTensor([1]),
        'answer': 0
    }

    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    if outputs is not None:
        # Convert output indices to target-vocabulary symbols, stopping at
        # EOS and dropping padding/start markers.
        output_symbols = []
        for idx in outputs[0]:
            idx = int(idx)  # robustness: tolerate tensor scalars as well as ints
            if idx < len(tgt_lang.index2word):
                symbol = tgt_lang.index2word[idx]
                if symbol == 'EOS':
                    break
                if symbol not in ['PAD', 'SOS']:
                    output_symbols.append(symbol)

        expression = ' '.join(output_symbols)

        try:
            # Simple evaluation (this would need more sophisticated handling in production)
            result = eval_expression(expression, var_values, arg_values)
            return f"Expression: {expression}\nResult: {result}"
        except Exception:
            # Fix: was a bare `except:`, which also swallowed SystemExit
            # and KeyboardInterrupt.
            return f"Expression: {expression}\n(Could not evaluate)"

    return "Could not generate solution"
197
+
198
def eval_expression(expr, var_values, arg_values):
    """Placeholder evaluator: echoes the expression unchanged.

    A real implementation would substitute ``var_values``/``arg_values``
    into *expr* and evaluate the resulting solution program numerically.
    """
    return expr
202
+
203
+ # Gradio interface
204
def predict(image, text):
    """Gradio callback: validate the inputs, run the solver, and surface any
    failure as a readable message instead of a stack trace."""
    # Guard clauses: both modalities are required.
    if image is None:
        return "Please upload a geometry diagram image"
    if not text:
        return "Please provide the problem text"

    try:
        return process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as e:
        # Surface the error in the UI rather than crashing the Space.
        return f"Error processing input: {str(e)}"
216
+
217
+ # Load model on startup
218
+ print("Loading PGPS model...")
219
+ model, src_lang, tgt_lang, cfg = load_model()
220
+ print("Model loaded successfully!")
221
+
222
+ # Create Gradio interface
223
+ iface = gr.Interface(
224
+ fn=predict,
225
+ inputs=[
226
+ gr.Image(type="pil", label="Geometry Diagram"),
227
+ gr.Textbox(
228
+ lines=3,
229
+ placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
230
+ label="Problem Text"
231
+ )
232
+ ],
233
+ outputs=gr.Textbox(label="Solution", lines=5),
234
+ title="PGPS: Neural Geometric Problem Solver",
235
+ description="Upload a geometry diagram and provide the problem text to get a solution.",
236
+ examples=[
237
+ [None, "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees"],
238
+ [None, "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees"]
239
+ ],
240
+ theme="default"
241
+ )
242
+
243
+ if __name__ == "__main__":
244
+ iface.launch()
config/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .config_default import *
2
+ from .logger import *
3
+
config/config_default.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torchvision.models as models
3
+
4
+
5
# Names of all torchvision model constructors (candidate backbone names).
model_names = sorted(
    name for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

# Closed option sets referenced by the argparse `choices=` arguments below.
criterion_list = ["CrossEntropy", "FocalLoss", "MaskedCrossEntropy"]
# Fix: "ADAMW" added -- get_parser() defaults --optimizer_type to "ADAMW",
# which the original list rejected whenever the flag was passed explicitly.
optimizer_list = ["SGD", "ADAM", "ADAMW"]
scheduler_list = ["multistep", "cosine", "warmup"]
visual_backbone_list = ["ResNet10", "mobilenet_v2"]
encoder_list = ["lstm", "gru", "transformer"]
decoder_list = ["rnn_decoder", "tree_decoder"]
eval_method_list = ["completion", "choice", "top3"]
dataset_list = ["Geometry3K", "PGPS9K"]
16
+
17
def get_parser():
    """Build the argparse configuration for PGPS training and parse argv.

    Returns the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='PyTorch PGPS Training')
    # visual backbone
    ##############################################################################
    parser.add_argument('--visual_backbone', default="ResNet10", type=str, choices=visual_backbone_list)
    parser.add_argument('--diagram_size', default=128, type=int)
    # encoder model
    ##############################################################################
    parser.add_argument('--encoder_type', default="gru", type=str, choices=encoder_list)
    parser.add_argument('--encoder_layers', default=2, type=int)
    parser.add_argument('--encoder_embedding_size', default=256, type=int)
    parser.add_argument('--encoder_hidden_size', default=512, type=int)
    parser.add_argument('--max_input_len', default=400, type=int)
    # decoder model
    ##############################################################################
    parser.add_argument('--decoder_type', default="rnn_decoder", type=str, choices=decoder_list)
    parser.add_argument('--decoder_layers', default=2, type=int)
    parser.add_argument('--decoder_embedding_size', default=512, type=int)
    parser.add_argument('--decoder_hidden_size', default=512, type=int)
    parser.add_argument('--max_output_len', default=40, type=int)
    # general model
    ##############################################################################
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--beam_size', default=10, type=int)
    # optimizer
    ##############################################################################
    # NOTE(review): the "ADAMW" default is only valid if optimizer_list
    # contains it -- confirm against the option lists above.
    parser.add_argument('--optimizer_type', default="ADAMW", type=str, choices=optimizer_list)
    parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate without LM')
    parser.add_argument('--lr_LM', default=1e-4, type=float, help='initial learning rate of LM')
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--max_epoch', default=540, type=int)
    parser.add_argument('--scheduler_type', default="warmup", type=str, choices=scheduler_list)
    # Fix: was `type=list`, which turns a command-line value like "160"
    # into ['1', '6', '0']; nargs='+' with type=int parses a proper int list.
    parser.add_argument('--scheduler_step', default=[160, 280, 360, 440, 500], nargs='+', type=int)
    parser.add_argument('--scheduler_factor', default=0.5, type=float, help='learning rate decay factor')
    parser.add_argument('--cosine_decay_end', default=0.0, type=float, help='cosine decay end')
    parser.add_argument('--warm_epoch', default=40, type=int)
    # criterion
    ###############################################################################
    parser.add_argument('--criterion', default="MaskedCrossEntropy", choices=criterion_list, type=str)
    parser.add_argument('--eval_method', default="top3", choices=eval_method_list, type=str)
    # dataset
    ################################################################################
    parser.add_argument('--dataset', default="PGPS9K", type=str, choices=dataset_list)
    parser.add_argument('--dataset_dir', default='./datasets/PGPS9K_all')
    parser.add_argument('--pretrain_vis_path', default='')
    parser.add_argument('--vocab_src_path', default='./vocab/vocab_src.txt')
    parser.add_argument('--vocab_tgt_path', default='./vocab/vocab_tgt.txt')
    parser.add_argument('--pretrain_emb_path', default='')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--random_prob', default=0.5, type=float)
    parser.add_argument('--without_stru', action='store_true', help='structure clauses are used or not')
    parser.add_argument('--trim_min_count', default=5, type=int, help='minimum number of word')
    parser.add_argument('--use_MLM_pretrain', action='store_true', help='use MLM pretrain')
    parser.add_argument('--MLM_pretrain_path', default='./pretraining_model/LM_MODEL.pth')
    # print information
    ###################################################################################
    parser.add_argument('--dump_path', default="./log/", type=str, help='save log path')
    parser.add_argument('--print_freq', default=20, type=int, help='print frequency')
    parser.add_argument('--eval_epoch', default=40, type=int)
    # general config
    ###################################################################################
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--evaluate_only', action='store_true', help='evaluate model on validation set')
    parser.add_argument('--resume_model', default="", type=str, help='use pre-trained model')
    # DistributedDataParallel
    ###################################################################################
    parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training')
    parser.add_argument('--init_method', default="env://", type=str, help='distributed init method')
    parser.add_argument('--debug', action='store_true', help="if debug than set local rank = 0")
    parser.add_argument('--seed', default=202302, type=int, help='seed for initializing training. ')

    return parser.parse_args()
config/logger.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from logging import handlers
3
+
4
class Logger(object):
    """Wrapper around :mod:`logging` that logs to stdout and a time-rotated
    file.

    On non-zero ranks (distributed training) the logger is left
    unconfigured so that only rank 0 emits output.
    """

    # Mapping from short level names to logging module constants.
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, filename, rank, level='info', when='D', backCount=3,
                 fmt='%(asctime)s - %(levelname)s: %(message)s'):
        self.logger = logging.getLogger(filename)
        if rank != 0:
            return  # worker ranks stay silent
        # Fix: logging.getLogger returns the same logger for the same name,
        # so constructing Logger twice used to stack duplicate handlers and
        # emit every message multiple times. Configure only once.
        if self.logger.handlers:
            return
        format_str = logging.Formatter(fmt)
        self.logger.setLevel(self.level_relations.get(level))
        sh = logging.StreamHandler()
        sh.setFormatter(format_str)
        th = handlers.TimedRotatingFileHandler(
            filename=filename, when=when, backupCount=backCount, encoding='utf-8')
        th.setFormatter(format_str)
        self.logger.addHandler(sh)
        self.logger.addHandler(th)
24
+
25
def create_logger(filepath, rank):
    """Return the configured ``logging.Logger`` for this rank (silent on
    non-zero ranks)."""
    return Logger(filepath, rank).logger
core/__init__.py ADDED
File without changes
core/network.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from model.backbone import get_visual_backbone
4
+ from model.encoder import get_encoder, TransformerEncoder
5
+ from model.decoder import get_decoder
6
+ from utils.utils import *
7
+ import numpy as np
8
+
9
+
10
class MLMTransformerPretrain(nn.Module):
    """Transformer text encoder pretrained with masked language modelling.

    Embeds token, class-tag and section-tag ids, sums them, and feeds the
    result through a TransformerEncoder.
    """

    def __init__(self, cfg, src_lang):
        super(MLMTransformerPretrain, self).__init__()
        self.cfg = cfg
        self.transformer_en = TransformerEncoder(cfg.encoder_embedding_size)
        self.text_embedding_src = self.get_text_embedding_src(
            vocab_size=src_lang.n_words,
            embedding_dim=cfg.encoder_embedding_size,
            padding_idx=0,
            pretrain_emb_path=cfg.pretrain_emb_path
        )
        self.class_tag_embedding = nn.Embedding(
            len(src_lang.class_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )
        self.sect_tag_embedding = nn.Embedding(
            len(src_lang.sect_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )

    def forward(self, text_dict):
        '''
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        '''
        # Token embeddings are summed over dim=1 before adding the tag
        # embeddings -- presumably a sub-token axis; TODO confirm against
        # the dataset collater's tensor layout.
        token_emb = self.text_embedding_src(text_dict['token'])
        class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
        sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
        text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        transformer_outputs = self.transformer_en(text_dict['len'], text_emb_src)
        return transformer_outputs

    def load_model(self, model_path):
        """Load pretrained weights, tolerating CPU-only hosts and
        DataParallel-style ``module.`` key prefixes."""
        # Fix: map_location was hard-coded to "cuda", which crashes on
        # machines without a GPU (e.g. CPU-only Spaces).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # Fix: strip the DataParallel prefix *before* the membership
            # test; the original tested the prefixed key against this
            # module's state dict, so 'module.*' weights were silently dropped.
            key = k[len("module."):] if k.startswith("module.") else k
            if key in model_dict:
                new_dict[key] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        """Create the source-token embedding, optionally initializing the
        trailing rows from a whitespace-separated pretrained-vector file."""
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            # Only the last len(emb_content) vocabulary rows get pretrained
            # vectors; earlier rows (special symbols) keep random init.
            embedding_src.weight.data[-len(emb_content):]. \
                copy_(torch.from_numpy(vector))
        return embedding_src
75
+
76
class Network(nn.Module):
    """Full PGPS solver: visual backbone + text encoder + expression decoder.

    The diagram feature is projected to the text embedding width and
    prepended to the token sequence before encoding.
    """

    def __init__(self, cfg, src_lang, tgt_lang):
        super(Network, self).__init__()
        self.cfg = cfg
        # define the encoder and decoder
        self.visual_extractor = get_visual_backbone(cfg)
        self.encoder = get_encoder(cfg)
        self.decoder = get_decoder(cfg, tgt_lang)
        # Project the backbone feature to the encoder embedding width.
        # (The original built a ModuleList only to rewrap it in Sequential;
        # constructing the Sequential directly is equivalent.)
        self.visual_emb_unify = nn.Sequential(
            nn.Linear(self.visual_extractor.final_feat_dim, cfg.encoder_embedding_size),
            nn.ReLU(),
            nn.Linear(cfg.encoder_embedding_size, cfg.encoder_embedding_size),
        )

        if cfg.use_MLM_pretrain:
            self.mlm_pretrain = MLMTransformerPretrain(cfg, src_lang)
            if cfg.MLM_pretrain_path != '':
                self.mlm_pretrain.load_model(cfg.MLM_pretrain_path)
        else:
            self.text_embedding_src = self.get_text_embedding_src(
                vocab_size=src_lang.n_words,
                embedding_dim=cfg.encoder_embedding_size,
                padding_idx=0,
                pretrain_emb_path=cfg.pretrain_emb_path
            )
            self.class_tag_embedding = nn.Embedding(
                len(src_lang.class_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
            self.sect_tag_embedding = nn.Embedding(
                len(src_lang.sect_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )

        self.src_lang = src_lang

    def forward(self, diagram_src, text_dict, var_dict, exp_dict, is_train=False):
        '''
        diagram_src: B x C x W x H
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'} /
                    {'token', 'sect_tag', 'class_tag', 'subseq_len', 'item_len', 'item_quant'}
        var_dict = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict = {'exp', 'len', 'answer'}
        '''
        if self.cfg.use_MLM_pretrain:
            text_emb_src = self.mlm_pretrain(text_dict)
        else:
            # text feature
            token_emb = self.text_embedding_src(text_dict['token'])
            class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
            sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
            # all feature
            text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb

        # diagram feature: one extra "token" prepended to the text sequence
        diagram_emb_src = self.visual_extractor(diagram_src)
        diagram_emb_src = self.visual_emb_unify(diagram_emb_src).unsqueeze(dim=1)
        all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
        # NOTE: in-place mutation of the caller's dicts -- lengths and
        # variable positions shift by one for the prepended diagram token.
        text_dict['len'] += 1
        var_dict['pos'] += 1
        # encoder
        encoder_outputs, encode_hidden = self.encoder(all_emb_src, text_dict['len'])
        # Seed every decoder layer with the encoder's final hidden state.
        problem_output = encode_hidden[-1:, :, :].repeat(self.cfg.decoder_layers, 1, 1)
        # decoder
        outputs = self.decoder(encoder_outputs, problem_output, \
                               text_dict['len'], \
                               var_dict['pos'], var_dict['len'], \
                               exp_dict['exp'], \
                               is_train)
        return outputs

    def freeze_module(self, module):
        """Disable gradients for every parameter of *module*."""
        self.cfg.logger.info("Freezing module of " + " .......")
        for p in module.parameters():
            p.requires_grad = False

    def load_model(self, model_path):
        """Load a full checkpoint into this network; returns the raw
        checkpoint dict (so callers can restore optimizer/scheduler state)."""
        # Fix: map_location was hard-coded to "cuda", which crashes on
        # CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pretrain_dict = torch.load(model_path, map_location=device)
        pretrain_dict_model = pretrain_dict['state_dict'] \
            if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # Strip the DataParallel "module." prefix where present.
            if k.startswith("module"):
                new_dict[k[7:]] = v
            else:
                new_dict[k] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        return pretrain_dict

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        """Create the source-token embedding, optionally initializing the
        trailing rows from a whitespace-separated pretrained-vector file."""
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path != '':
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            embedding_src.weight.data[-len(emb_content):]. \
                copy_(torch.from_numpy(vector))
        return embedding_src
188
+
189
+
190
def get_model(args, src_lang, tgt_lang):
    """Construct the full PGPS Network and log its architecture."""
    net = Network(args, src_lang, tgt_lang)
    args.logger.info(str(net))
    return net
194
+
195
+
196
+
197
+
198
+
199
+
200
+
core/test.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from utils import *
3
+
4
def validate(args, val_loader, model, tgt_lang):
    """Evaluate *model* on *val_loader* under distributed training.

    Returns (answer_accuracy, equation_accuracy), averaged over the set
    and reduced across all workers.
    """
    batch_time = AverageMeter('Time', ':5.3f')
    acc_ans = AverageMeter('Ans_Acc', ':5.4f')
    acc_eq = AverageMeter('Eq_Acc', ':5.4f')
    progress = ProgressMeter(len(val_loader), [batch_time, acc_ans, acc_eq], args, prefix='Test: ')
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (diagrams, text_dict, var_dict, exp_dict) in enumerate(val_loader):
            # set cuda for input data
            diagrams = diagrams.cuda()
            set_cuda(text_dict), set_cuda(var_dict), set_cuda(exp_dict)
            # compute output
            output = model(diagrams, text_dict, var_dict, exp_dict, is_train=False)
            if args.eval_method == "completion":
                acc1, acc2 = compute_exp_result_comp(output, var_dict, exp_dict, tgt_lang)
            elif args.eval_method == "choice":
                acc1, acc2 = compute_exp_result_choice(output, var_dict, exp_dict, tgt_lang)
            elif args.eval_method == "top3":
                acc1, acc2 = compute_exp_result_topk(output, var_dict, exp_dict, tgt_lang, k_num=3)
            else:
                # Fix: the original fell through with acc1/acc2 unbound
                # (NameError) for an unrecognized eval_method.
                raise ValueError(f"unknown eval_method: {args.eval_method}")

            # aggregate the per-batch metrics across distributed workers
            torch.distributed.barrier()

            reduced_acc_ans = reduce_mean(torch.tensor([acc1]).cuda(), args.nprocs)
            reduced_acc_eq = reduce_mean(torch.tensor([acc2]).cuda(), args.nprocs)

            acc_ans.update(reduced_acc_ans.item(), len(diagrams))
            acc_eq.update(reduced_acc_eq.item(), len(diagrams))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    return acc_ans.avg, acc_eq.avg
core/train.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from utils import *
3
+
4
def train(args, epoch, train_loader, model, criterion, optimizer):
    """Run one training epoch; returns the epoch-average (reduced) loss.

    Batch layout:
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        var_dict  = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict  = {'exp', 'len', 'answer'}
    """
    batch_time = AverageMeter('Time', ':5.3f')
    data_time = AverageMeter('Data', ':5.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses],
                             args, prefix="Epoch: [{}]".format(epoch))

    model.train()
    tick = time.time()

    for step, (diagrams, text_dict, var_dict, exp_dict) in enumerate(train_loader):
        # How long we waited on the data loader.
        data_time.update(time.time() - tick)

        # Move the batch onto the GPU.
        diagrams = diagrams.cuda()
        set_cuda(text_dict), set_cuda(var_dict), set_cuda(exp_dict)

        # Forward pass; the loss target drops the leading [SOS] symbol.
        output = model(diagrams, text_dict, var_dict, exp_dict, is_train=True)
        loss = criterion(output, exp_dict['exp'][:, 1:].clone(), exp_dict['len'] - 1)

        # Average the loss across all workers before logging.
        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, args.nprocs)
        losses.update(reduced_loss.item(), len(diagrams))

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Timing and periodic progress output.
        batch_time.update(time.time() - tick)
        tick = time.time()
        if step % args.print_freq == 0:
            progress.display(step, lr=optimizer.state_dict()['param_groups'][0]['lr'])

    return losses.avg
core/worker.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.optim
3
+ import torch.utils.data
4
+ import torch.nn.parallel
5
+ from core.train import *
6
+ from core.test import *
7
+ from utils import *
8
+ from core.network import get_model
9
+ from loss import get_criterion
10
+ from datasets import get_dataloader
11
+
12
+
13
def main_worker(args):
    """Entry point for one distributed worker: build model/data/optimizer,
    optionally resume or evaluate, then run the training loop."""

    args.logger = initialize_logger(args)
    train_loader, train_sampler, val_loader, src_lang, tgt_lang = get_dataloader(args)
    model = get_model(args, src_lang, tgt_lang).cuda()
    optimizer = get_optimizer(args, model)
    scheduler = get_scheduler(args, optimizer)
    criterion = get_criterion(args)
    start_epoch = 0

    # Resume from a checkpoint when one is given; in evaluate-only mode we
    # validate once and return without training.
    if args.resume_model != '':
        ckpt = model.load_model(args.resume_model)
        optimizer.load_state_dict(ckpt['optimizer'])
        scheduler.load_state_dict(ckpt['scheduler'])
        start_epoch = ckpt["epoch"] + 1
        args.logger.info("The whole model has been loaded from " + args.resume_model)
        args.logger.info("The model resumes from epoch " + str(ckpt["epoch"]))
        if args.evaluate_only:
            acc_ans, acc_eq = validate(args, val_loader, model, tgt_lang)
            args.logger.info(
                "----------Epoch:{:>3d}, test answer_acc {:>5.4f}, equation_acc {:>5.4f} ---------"
                .format(ckpt["epoch"], acc_ans, acc_eq))
            return
    else:
        args.logger.info("The model is trained from scratch")

    # Wrap for multi-GPU distributed training.
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True
    )

    best_loss = 1e10

    for epoch in range(start_epoch, args.max_epoch):
        # Reshuffle the sampler each epoch, then train.
        train_sampler.set_epoch(epoch)
        loss = train(args, epoch, train_loader, model, criterion, optimizer)
        args.logger.info("----------Epoch:{:>3d}, training loss is {:>5.4f} ---------".
                         format(epoch, loss))
        # Only rank 0 writes checkpoints: a periodic one, plus a
        # best-so-far copy whenever the training loss improves.
        if args.local_rank == 0:
            snapshot = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if epoch % args.eval_epoch == 0 or epoch >= args.max_epoch - 5:
                save_checkpoint(snapshot, False, args.dump_path)
            if loss < best_loss:
                best_loss = loss
                save_checkpoint(snapshot, True, args.dump_path)
        # learning scheduler step
        scheduler.step()

    args.logger.info("------------------- Train Finished -------------------")
datasets/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from datasets.dataset import MyDataset
3
+ from torch.utils.data.distributed import DistributedSampler
4
+ from datasets.preprossing import *
5
+ import os
6
+
7
def get_dataloader(args):
    """Build the source/target vocabularies and the distributed train/test
    loaders.

    Returns (train_loader, train_sampler, test_loader, src_lang, tgt_lang).
    """
    src_lang = SrcLang(args.vocab_src_path)
    tgt_lang = TgtLang(args.vocab_tgt_path)

    # Raw (diagram, text, expression) pairs for both splits.
    train_pairs = get_raw_pairs(os.path.join(args.dataset_dir, args.dataset, 'train.json'))
    test_pairs = get_raw_pairs(os.path.join(args.dataset_dir, args.dataset, 'test.json'))

    train_set = MyDataset(args, train_pairs, src_lang, tgt_lang, is_train=True)
    train_sampler = DistributedSampler(train_set, shuffle=True)
    # The global batch size is split evenly across the worker processes.
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=int(args.batch_size / args.nprocs),
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=train_sampler,
    )

    test_set = MyDataset(args, test_pairs, src_lang, tgt_lang, is_train=False)
    test_sampler = DistributedSampler(test_set, shuffle=False)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=1,
        pin_memory=True,
        collate_fn=collater(args),
        num_workers=args.workers,
        sampler=test_sampler,
    )

    return train_loader, train_sampler, test_loader, src_lang, tgt_lang
datasets/dataset.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from PIL import Image
4
+ import datasets.diagram_aug as T_diagram
5
+ import datasets.text_aug as T_text
6
+ from datasets.operators import normalize_exp
7
+ from datasets.utils import get_combined_text, get_var_arg, get_text_index
8
+ from datasets.preprossing import SN
9
+
10
class MyDataset(torch.utils.data.Dataset):
    """Dataset pairing geometry diagrams with tokenized problem text and the
    target solution expression.

    In training mode stochastic diagram/text augmentation is applied with
    probability args.random_prob; in evaluation mode augmentation is disabled
    (probability 0).
    """

    def __init__(self, args, pairs, src_lang, tgt_lang, is_train=True):
        super().__init__()
        self.args = args
        self.pairs = pairs          # preprocessed sample dicts (see __getitem__ docstring)
        self.src_lang = src_lang    # source (text) vocabulary
        self.tgt_lang = tgt_lang    # target (expression) vocabulary
        self.is_train = is_train
        if is_train:
            random_prob = args.random_prob
        else:
            random_prob = 0  # no augmentation at evaluation time
        self.diagram_transform = T_diagram.Compose([
            T_diagram.Resize(args.diagram_size),
            T_diagram.CenterCrop(args.diagram_size),
            T_diagram.RandomFlip(random_prob),
            T_diagram.ToTensor(),
            T_diagram.Normalize()
        ])
        self.text_transform = T_text.Compose([
            T_text.Point_RandomReplace(random_prob),
            T_text.AngID_RandomReplace(random_prob),
            # T_text.Arg_RandomReplace(random_prob),
            T_text.StruPoint_RandomRotate(random_prob),
            # T_text.SemPoint_RandomRotate(random_prob),
            T_text.SemSeq_RandomRotate(random_prob),
            T_text.StruSeq_RandomRotate(random_prob),
        ])

    def __getitem__(self, idx):
        '''
        pair{
            'diagram': str
            'text': SN()
            'parsing_stru_seqs': SN()
            'parsing_sem_seqs': SN()
            'expression': list
            'answer': str
        }
        '''
        pair = self.pairs[idx]

        # diagram
        diagram_path = os.path.join(self.args.dataset_dir, 'Diagram', pair['diagram'])
        diagram = Image.open(diagram_path).convert("RGB")
        diagram = self.diagram_transform(diagram)
        # text, parsing_stru_seqs, parsing_sem_seqs,
        # NOTE(review): text_transform mutates `pair` in place, so augmented
        # token state persists in self.pairs across epochs — confirm intended.
        self.text_transform(pair['text'],
                            pair['parsing_stru_seqs'],
                            pair['parsing_sem_seqs'],
                            pair['expression'])
        combine_text = SN()
        get_combined_text(pair['text'],
                          pair['parsing_stru_seqs'],
                          pair['parsing_sem_seqs'],
                          combine_text,
                          self.args)
        text_token, text_sect_tag, text_class_tag = \
            get_text_index(combine_text, self.src_lang)
        # var and arg
        var_arg_positions, var_values, arg_values = \
            get_var_arg(combine_text, self.args)
        # expression (canonical operand order, then vocab ids)
        expression = normalize_exp(pair['expression'])
        expression = self.tgt_lang.indexes_from_sentence(expression, var_values, arg_values)
        # choices (answer options as floats)
        choices = [float(item) for item in pair['choices']]

        return diagram, \
            text_token, text_sect_tag, text_class_tag, \
            var_arg_positions, var_values, arg_values, \
            expression, pair['answer'], pair['id'], choices

    def __len__(self):
        return len(self.pairs)
datasets/diagram_aug.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from torchvision.transforms import functional as F
3
+
4
class Compose(object):
    """Chain image transforms: each callable receives and returns the image."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        for transform in self.transforms:
            image = transform(image)
        return image

    def __repr__(self):
        body = "".join("\n    {0}".format(t) for t in self.transforms)
        return self.__class__.__name__ + "(" + body + "\n)"
20
+
21
class Resize(object):
    """Resize so the longest edge equals max_size, preserving aspect ratio."""

    def __init__(self, max_size):
        self.max_size = max_size

    def get_size(self, image_size):
        # PIL reports (width, height); torchvision resize wants (height, width).
        w, h = image_size
        if w < h:
            return (self.max_size, int(w * self.max_size / h))
        return (int(h * self.max_size / w), self.max_size)

    def __call__(self, image):
        return F.resize(image, self.get_size(image.size))
42
+
43
class CenterCrop(object):
    """Crop the given image at the center to a square of side `size`."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        cropped = F.center_crop(image, self.size)
        return cropped
52
+
53
class RandomFlip(object):
    """With probability `prob`, apply one of: horizontal flip, vertical flip,
    or both (chosen uniformly); otherwise return the image unchanged."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image):
        if random.random() >= self.prob:
            return image
        flip_method = random.choice([0, 1, 2])
        if flip_method == 0:
            return F.hflip(image)
        if flip_method == 1:
            return F.vflip(image)
        return F.vflip(F.hflip(image))
67
+
68
class ToTensor(object):
    """Convert a PIL image to a float tensor scaled to [0, 1]."""

    def __call__(self, image):
        tensor = F.to_tensor(image)
        return tensor
71
+
72
class Normalize(object):
    """Channel-wise normalization; the defaults suit mostly-white diagrams.

    FIX: the original used mutable list default arguments (shared across all
    instances); defaults are now immutable tuples, copied into fresh lists so
    the public list-valued attributes are preserved.
    """

    def __init__(self, mean=(0.85, 0.85, 0.85), std=(0.3, 0.3, 0.3)):
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        return F.normalize(image, mean=self.mean, std=self.std)
datasets/operators.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.parsing.latex import parse_latex
2
+ from sympy.printing import latex
3
+ from sympy import solve
4
+ from sympy.core.numbers import Float
5
+
6
################# Program Executor ########################

# LaTeX command words that must not be mistaken for single-letter variables.
spec_token_list = ['frac', 'pi', 'sqrt']
# Letters occurring inside the commands above (excluded from variable letters).
spec_letter_list = ['f', 'r', 'a', 'c', 'p', 'i', 's', 'q', 'r', 't']
# All lowercase ASCII letters: candidate argument/variable names.
low_case_list = [chr(i) for i in range(97, 123)]
# Operators whose operand order is semantically fixed.
fixed_order_ops = [
    'Get', 'Iso_Tri_Ang', 'Gsin', 'Gcos', 'Gtan', 'Geo_Mean', 'Ratio', 'TanSec_Ang', \
    'Chord2_Ang', 'Tria_BH_Area', 'Para_Area', 'Kite_Area', 'Circle_R_Circum', \
    'Circle_D_Circum', 'Circle_R_Area', 'Circle_D_Area', 'ArcSeg_Area', 'Ngon_Angsum', \
    'RNgon_B_Area', 'RNgon_L_Area', 'RNgon_H_Area']
# Operators whose (some) operands commute and can be reordered canonically.
alterable_order_ops = [
    'Sum', 'Multiple', 'Equal', 'Gougu', 'Cos_Law', 'Sin_Law', 'Median', 'Proportion', \
    'Tria_SAS_Area', 'PRK_Perim', 'Rect_Area', 'Rhom_Area', 'Trap_Area']
arith_op_list = fixed_order_ops + alterable_order_ops
# Canonical operand ordering: process vars V*, then problem numbers N*, then constants C*.
priority_list = ["V0", "V1", "V2", "V3", "V4", "V5", "V6", \
    "N0", "N1", "N2", "N3", "N4", "N5", "N6", "N7", "N8", "N9", "N10", \
    "C0.5", "C2", "C3", "C4", "C5", "C6", "C8", "C60", "C90", "C180", "C360"]
# Maximum number of process (intermediate) variables V0..V9.
V_NUM = 10
24
+
25
class Varible_Record(object):
    """Execution state of the program executor.

    (Name spelling 'Varible' is kept for compatibility with existing callers.)
    """
    def __init__(self):
        self.varible_dict = dict()      # N_i -> original LaTeX literal for that problem number
        self.mid_varible_dict = dict()  # working map: variable/letter -> current (partially solved) value
        self.result = ''                # final numeric answer, '0.3f'-formatted string
30
+
31
def get_priority(token):
    """Rank of `token` in the canonical operand ordering; -1 for free arguments."""
    try:
        return priority_list.index(token)
    except ValueError:
        return -1  # lowercase argument letters are not in the priority table
36
+
37
def is_exist_operator(func, ANNO):
    """Validate that `func` names a known operator and return it unchanged.

    ANNO is unused but kept for interface compatibility with callers.
    """
    if func not in arith_op_list:
        print("Can Not Find Operators!")
        raise Exception
    return func
42
+
43
def choose_result(result_list):
    """Pick one root from a sympy solve() result list.

    Empty list -> None; single root -> that root; otherwise prefer the first
    root when it is the positive one (or the second is non-positive).
    """
    if not result_list:
        return None
    if len(result_list) == 1:
        return result_list[0]
    t1, t2 = result_list[0].evalf(), result_list[1].evalf()
    if (t1 > t2 and t2 <= 0) or (t1 < t2 and t1 > 0):
        return result_list[0]
    return result_list[1]
55
+
56
def operand_update(operands, ANNO):
    """Resolve operand tokens into substitutable LaTeX snippets, in place.

    For each operand:
      * a known variable/process-variable key is replaced by its current
        value, parenthesized to preserve operator precedence;
      * '\\pi' becomes a parenthesized numeric literal;
      * a mixed number like '2\\frac{1}{2}' gains an explicit '+';
      * a constant token 'C<n>' is stripped to its numeral.
    Returns the (mutated) operand list.
    """
    for id in range(len(operands)):
        # Substitute variable / process-variable values
        if operands[id] in ANNO.mid_varible_dict:
            operands[id] = "(" + ANNO.mid_varible_dict[operands[id]] + ")"
        # pi -> numeric literal
        if "\\pi" in operands[id]:
            operands[id] = operands[id].replace('\\pi', '(3.141593)')
        # mixed number: '2\\frac{1}{2}' means 2 + 1/2, so insert the '+'
        if "\\frac" in operands[id]:
            loc = operands[id].index("\\frac")
            if loc > 0 and operands[id][loc-1].isdigit():
                operands[id] = operands[id][:loc] + '+' + operands[id][loc:]
            continue
        # FIX: the original repeated the mid_varible_dict substitution here;
        # that branch could never fire (the key was already replaced above),
        # so the dead check has been removed.
        # Substitute constant: 'C90' -> '90'
        if operands[id][0] == 'C':
            operands[id] = operands[id][1:]

    return operands
79
+
80
def mid_var_solve(expr_step, ANNO, visit_list, midvar2letter):
    """Solve one LaTeX equation step for every argument letter and every
    process variable it contains, recording results in ANNO.mid_varible_dict.

    A solution is stored only when it no longer depends on later unknowns or
    on other process variables; otherwise the unknown maps to itself.
    """
    # Replace process (intermediate) variables V_i with their letter aliases
    for key, value in midvar2letter.items():
        expr_step = expr_step.replace(key, value)
    # Convert the latex form expression to a sympy-solvable form
    expr_step = parse_latex(expr_step)
    # Solve for each argument letter
    for letter in visit_list:
        try:
            result = solve(expr_step, letter)
            result = choose_result(result)
        except Exception:  # narrowed from bare except: don't swallow SystemExit/KeyboardInterrupt
            ANNO.mid_varible_dict[letter] = letter
            continue
        if not result is None:
            result = latex(result)
            is_update = True
            # Mask LaTeX command words so their letters don't look like unknowns
            result_t = result[:]
            for item in spec_token_list:
                result_t = result_t.replace(item, '')
            # more than one unknown varibles: keep only if this is the earliest
            for letter_t in visit_list:
                if letter_t in result_t and letter < letter_t:
                    is_update = False
                    break
            # intermediate variables still present -> not yet solvable
            for key, value in midvar2letter.items():
                if value in result_t:
                    is_update = False
                    break
            if is_update:
                ANNO.mid_varible_dict[letter] = result
            else:
                ANNO.mid_varible_dict[letter] = letter

    # Solve for each process (intermediate) variable appearing in the step
    for key1, value1 in midvar2letter.items():
        if value1 in str(expr_step):
            result = solve(expr_step, value1)
            result = choose_result(result)
            if not result is None:
                # Convert the intermediate variable result to latex form
                result = latex(result)
                is_update = True
                # If a later intermediate variable remains, only the earlier one updates
                for key2, value2 in midvar2letter.items():
                    if value2 in result and value1 < value2:
                        is_update = False
                        break
                    # FIX: original called result.replace(...) without assigning,
                    # a no-op on an immutable str; the letter->V_i rewrite was lost.
                    result = result.replace(value2, key2)
                if is_update:
                    ANNO.mid_varible_dict[key1] = result
                else:
                    ANNO.mid_varible_dict[key1] = key1
136
+
137
def mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, is_subs_visit=True):
    """Propagate solved values into every still-unsolved entry of
    ANNO.mid_varible_dict.

    First collects the keys that already evaluate to a numeric value (and,
    when is_subs_visit, all argument letters); then substitutes those solved
    values, parenthesized, into every other stored expression.
    """
    has_solved_list = []
    # Find solved process varibles and arguments
    for key, value in ANNO.mid_varible_dict.items():
        # NOTE(review): precedence here is `(A and B) or C` — an argument letter
        # counts as solved whenever is_subs_visit, even if its value is '' ;
        # confirm that is intended.
        if value != '' and isinstance(parse_latex(value).evalf(), Float) or \
                (key in visit_list and is_subs_visit):
            if not key in midvar2letter:
                has_solved_list.append(key)
            else:
                has_solved_list.append(midvar2letter[key])

    for key, mid_var in ANNO.mid_varible_dict.items():
        # FIX: original tested `value != ''` here, a stale variable left over
        # from the previous loop; the current entry is `mid_var`.
        if mid_var != '' and not key in has_solved_list:
            # Process varibles V_i are replaced as lowercase letters
            for key1, value1 in midvar2letter.items():
                mid_var = mid_var.replace(key1, value1)
            # Special characters are replaced with '@' for marking
            mid_var_t = mid_var[:]
            for item in spec_token_list:
                mid_var_t = mid_var_t.replace(item, "@"*len(item))
            # Lowercase letters are replaced with solved values
            mid_var_new = ''
            for id in range(len(mid_var_t)):
                if mid_var_t[id] != "@" and mid_var_t[id] in has_solved_list:
                    if mid_var_t[id] in midletter2var:
                        mid_var_new += "(" + ANNO.mid_varible_dict[midletter2var[mid_var_t[id]]] + ')'
                    else:
                        mid_var_new += "(" + ANNO.mid_varible_dict[mid_var_t[id]] + ')'
                else:
                    mid_var_new += mid_var[id]
            # Lowercase letters are replaced back with V_i
            for key2, value2 in midvar2letter.items():
                mid_var_new = mid_var_new.replace(value2, key2)
            ANNO.mid_varible_dict[key] = mid_var_new
172
+
173
def Get(ANNO, arg_list):
    """
    Get(a) -> get numerical value of a

    Substitutes every solved single-letter argument into the stored literal
    for `a`, evaluates the resulting LaTeX expression, and stores the answer
    in ANNO.result as a '0.3f'-formatted string.
    """
    if len(arg_list) != 1:
        print("<Gets> function has only 1 augment!")
        raise Exception

    if arg_list[0] in ANNO.mid_varible_dict:
        result = ANNO.mid_varible_dict[arg_list[0]]
    else:
        result_v = ANNO.varible_dict[arg_list[0]]
        # Mask LaTeX command words with '@' so their letters are not confused
        # with single-letter argument names during substitution.
        result_t = result_v[:]
        for item in spec_token_list:
            result_t = result_t.replace(item, "@"*len(item))
        # Lowercase letters are replaced with solved values (parenthesized)
        result = ''
        for id in range(len(result_t)):
            if result_t[id] != "@" and result_t[id] in ANNO.mid_varible_dict:
                result += "(" + ANNO.mid_varible_dict[result_t[id]] + ')'
            else:
                result += result_v[id]
    ANNO.result = format(float(parse_latex(result).evalf()), '0.3f')
196
+
197
def Sum(arg_list):
    """
    Sum(a, b, c, d) -> a+b+c=d  (returned as the residual lhs-rhs)
    """
    if len(arg_list) < 3:
        print("<Sum> function has 3 augments at least!")
        raise Exception
    return "+".join(arg_list[:-1]) + "-" + arg_list[-1]
209
+
210
def Multiple(arg_list):
    """
    Multiple(a, b, c, d, e) -> a*b*c*d=e  (returned as the residual lhs-rhs)
    """
    if len(arg_list) < 3:
        print("<Product> function has 3 augments at least!")
        raise Exception
    return "*".join(arg_list[:-1]) + "-" + arg_list[-1]
222
+
223
def Equal(arg_list):
    """
    Equal(a, b) -> a=b  (returned as the residual a-b)
    """
    if len(arg_list) != 2:
        print("<Equal> function has 2 augments!")
        raise Exception
    return f"{arg_list[0]}-{arg_list[-1]}"
232
+
233
def Iso_Tri_Ang(arg_list):
    """
    Iso_Tri_Ang(a, b) -> a+2*b=180  (apex plus two base angles of an isosceles triangle)
    """
    if len(arg_list) != 2:
        print("<Iso_Tri_Ang> function has 2 augments!")
        raise Exception
    return f"{arg_list[0]}+2*{arg_list[-1]}-180"
242
+
243
def Gougu(arg_list):
    """
    Gougu(a, b, c) -> a^2+b^2=c^2  (Pythagorean theorem)
    """
    if len(arg_list) != 3:
        print("<Gougu> function has 3 augments!")
        raise Exception
    leg1, leg2, hyp = arg_list
    return f"{leg1}^{{2}}+{leg2}^{{2}}-{hyp}^{{2}}"
252
+
253
def Gsin(arg_list):
    """
    Gsin(a, b, c) -> sin(c)=a/b  (c in degrees)
    """
    if len(arg_list) != 3:
        print("<Gsin> function has 3 augments!")
        raise Exception
    opp, hyp, ang = arg_list
    return f"{opp}/{hyp}-\\sin{{{ang}/180*3.141593}}"
262
+
263
def Gcos(arg_list):
    """
    Gcos(a, b, c) -> cos(c)=a/b  (c in degrees)
    """
    if len(arg_list) != 3:
        print("<Gcos> function has 3 augments!")
        raise Exception
    adj, hyp, ang = arg_list
    return f"{adj}/{hyp}-\\cos{{{ang}/180*3.141593}}"
272
+
273
def Gtan(arg_list):
    """
    Gtan(a, b, c) -> tan(c)=a/b  (c in degrees)
    """
    if len(arg_list) != 3:
        print("<Gtan> function has 3 augments!")
        raise Exception
    opp, adj, ang = arg_list
    return f"{opp}/{adj}-\\tan{{{ang}/180*3.141593}}"
282
+
283
def Cos_Law(arg_list):
    """
    Cos_Law(a, b, c, d) -> a^2 = b^2 + c^2 - 2*b*c*cos(d)  (d in degrees)
    """
    if len(arg_list) != 4:
        print("<Cos_Law> function has 4 augments!")
        raise Exception
    opp, s1, s2, ang = arg_list
    return f"{s1}^{{2}}+{s2}^{{2}}-{opp}^{{2}}-2*{s1}*{s2}*\\cos{{{ang}/180*3.141593}}"
293
+
294
def Sin_Law(arg_list):
    """
    Sin_Law(a, b, c, d) -> sin(a)/b = sin(c)/d  (angles in degrees)
    """
    if len(arg_list) != 4:
        print("<Sin_Law> function has 4 augments!")
        raise Exception
    ang1, side1, ang2, side2 = arg_list
    return f"{side2}*\\sin{{{ang1}/180*3.141593}}-{side1}*\\sin{{{ang2}/180*3.141593}}"
304
+
305
def Median(arg_list):
    """
    Median(a, b, c) -> a+c=2*b  (b is the arithmetic mean of a and c)
    """
    if len(arg_list) != 3:
        print("<Median> function has 3 augments!")
        raise Exception
    first, mid, last = arg_list
    return f"{first}-2*{mid}+{last}"
314
+
315
def Geo_Mean(arg_list):
    """
    Geo_Mean(a, b, c) -> a*b=c^2  (c is the geometric mean of a and b)
    """
    if len(arg_list) != 3:
        print("<Geo_Mean> function has 3 augments!")
        raise Exception
    x, y, mean = arg_list
    return f"{x}*{y}-{mean}^{{2}}"
324
+
325
def Proportion(arg_list):
    """
    Proportion(a, b, c, d) -> a/b=c/d
    Proportion(a, b, c, d, e) -> (a/b)^e=c/d
    """
    if len(arg_list) < 4:
        print("<Proportion> function has 4 augments at least!")
        raise Exception
    if len(arg_list) == 4:
        a, b, c, d = arg_list
        return f"{a}*{d}-{b}*{c}"
    a, b, c, d, e = arg_list[:5]
    return f"{a}*{d}^{{1/{e}}}-{b}*{c}^{{1/{e}}}"
338
+
339
def Ratio(arg_list):
    """
    Ratio(a, b, c) -> a/b=c
    Ratio(a, b, c, d) -> (a/b)^c=d
    """
    if len(arg_list) < 3 or len(arg_list) > 4:
        print("<Power> function has 3 or 4 augments!")
        raise Exception
    if len(arg_list) == 3:
        num, den, val = arg_list
        return f"{num} / {den}-{val}"
    num, den, power, val = arg_list
    return f"({num} / {den})^{{{power}}}-{val}"
352
+
353
def Chord2_Ang(arg_list):
    """
    Chord2_Ang(a, b, c) -> a=(b+c)/2  (angle between two chords)
    """
    if len(arg_list) != 3:
        print("<Chord2_Ang> function has 3 augments!")
        raise Exception
    ang, arc1, arc2 = arg_list
    return f"{ang}*2-{arc1}-{arc2}"
362
+
363
def TanSec_Ang(arg_list):
    """
    TanSec_Ang(a, b, c) -> a=(c-b)/2  (tangent-secant angle)
    """
    if len(arg_list) != 3:
        print("<TanSec_Ang> function has 3 augments!")
        raise Exception
    ang, near_arc, far_arc = arg_list
    return f"{ang}*2+{near_arc}-{far_arc}"
372
+
373
def Tria_BH_Area(arg_list):
    """
    Tria_BH_Area(a, b, c) -> a*b/2=c  (triangle area from base and height)
    """
    if len(arg_list) != 3:
        print("<Tria_BH_Area> function has 3 augments!")
        raise Exception
    base, height, area = arg_list
    return f"{base}*{height}*0.5-{area}"
382
+
383
def Tria_SAS_Area(arg_list):
    """
    Tria_SAS_Area(a, b, c, d) -> a*c*sin(b)/2=d  (b in degrees)
    """
    if len(arg_list) != 4:
        print("<Tria_SAS_Area> function has 4 augments!")
        raise Exception
    s1, ang, s2, area = arg_list
    return f"{s1}*{s2}*0.5*\\sin{{{ang}/180*3.141593}}-{area}"
392
+
393
def PRK_Perim(arg_list):
    """
    PRK_Perim(a, b, c) -> (a+b)*2=c  (parallelogram/rectangle/kite perimeter)
    """
    if len(arg_list) != 3:
        print("<PRK_Perim> function has 3 augments!")
        raise Exception
    s1, s2, perim = arg_list
    return f"{s1}*2+{s2}*2-{perim}"
402
+
403
def Para_Area(arg_list):
    """
    Para_Area(a, b, c) -> a*b=c  (parallelogram area, base times height)
    """
    if len(arg_list) != 3:
        print("<Para_Area> function has 3 augments!")
        raise Exception
    base, height, area = arg_list
    return f"{base}*{height}-{area}"
412
+
413
def Rect_Area(arg_list):
    """
    Rect_Area(a, b, c) -> a*b=c  (rectangle area)
    """
    if len(arg_list) != 3:
        print("<Rect_Area> function has 3 augments!")
        raise Exception
    width, height, area = arg_list
    return f"{width}*{height}-{area}"
422
+
423
def Rhom_Area(arg_list):
    """
    Rhom_Area(a, b, c) -> a*b*2=c
    """
    if len(arg_list) != 3:
        # FIX: error message previously said "<Phom_Area>" (typo)
        print("<Rhom_Area> function has 3 augments!")
        raise Exception
    return arg_list[0] + '*' + arg_list[1] + '*2-' + arg_list[2]
432
+
433
def Kite_Area(arg_list):
    """
    Kite_Area(a, b, c) -> a*b/2=c  (half the product of the diagonals)
    """
    if len(arg_list) != 3:
        print("<Kite_Area> function has 3 augments!")
        raise Exception
    d1, d2, area = arg_list
    return f"{d1}*{d2}*0.5-{area}"
442
+
443
def Trap_Area(arg_list):
    """
    Trap_Area(a, b, c, d) -> (a+b)*c/2=d  (trapezoid area)
    """
    if len(arg_list) != 4:
        print("<Trap_Area> function has 4 augments!")
        raise Exception
    base1, base2, height, area = arg_list
    return f"0.5*({base1}+{base2})*{height}-{area}"
452
+
453
def Circle_R_Circum(arg_list):
    """
    Circle_R_Circum(a, b) -> 2*pi*a=b
    Circle_R_Circum(a, b, c) -> 2*pi*a*b/360=c  (arc length for b degrees)
    """
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Circum> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        return f"2*3.141593*{arg_list[0]}-{arg_list[1]}"
    return f"2*3.141593*{arg_list[0]}*{arg_list[1]}/360-{arg_list[2]}"
466
+
467
def Circle_D_Circum(arg_list):
    """
    Circle_D_Circum(a, b) -> pi*a=b  (a is the diameter)
    Circle_D_Circum(a, b, c) -> pi*a*b/360=c  (arc length for b degrees)
    """
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Circum> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        return f"3.141593*{arg_list[0]}-{arg_list[1]}"
    return f"3.141593*{arg_list[0]}*{arg_list[1]}/360-{arg_list[2]}"
480
+
481
def Circle_R_Area(arg_list):
    """
    Circle_R_Area(a, b) -> pi*a^2=b
    Circle_R_Area(a, b, c) -> pi*a^2*b/360=c  (sector area for b degrees)
    """
    # FIX: original guard used `and`, which can never be True, so a bad
    # argument count slipped past and crashed later with IndexError.
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Area> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        return '3.141593*' + arg_list[0] + '^{2}-' + arg_list[1]
    return '3.141593*' + arg_list[0] + '^{2}*' + arg_list[1] + '/360' + '-' + arg_list[2]
494
+
495
def Circle_D_Area(arg_list):
    """
    Circle_D_Area(a, b) -> pi*(a/2)^2=b  (a is the diameter)
    Circle_D_Area(a, b, c) -> pi*(a/2)^2*b/360=c  (sector area for b degrees)
    """
    # FIX: original guard used `and`, which can never be True, so a bad
    # argument count slipped past and crashed later with IndexError.
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Area> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        return '0.25*3.141593*' + arg_list[0] + '^{2}-' + arg_list[1]
    return '0.25*3.141593*' + arg_list[0] + '^{2}*' + arg_list[1] + '/360' + '-' + arg_list[2]
508
+
509
def ArcSeg_Area(arg_list):
    """
    ArcSeg_Area(a, b, c) -> pi*a^2*b/360 - a^2*sin(b)/2 = c  (circular segment)
    """
    if len(arg_list) != 3:
        print("<ArcSeg_Area> function has 3 augments!")
        raise Exception
    radius, ang, area = arg_list
    return f"3.141593*{radius}^{{2}}*{ang}/360-0.5*{radius}^{{2}}*\\sin{{{ang}/180*3.141593}}-{area}"
519
+
520
def Ngon_Angsum(arg_list):
    """
    Ngon_Ang(a, b) -> (a-2)*180=b  (interior angle sum of an a-gon)
    """
    if len(arg_list) != 2:
        print("<Ngon_Ang> function has 2 augments!")
        raise Exception
    sides, total = arg_list
    return f"({sides}-2)*180-{total}"
529
+
530
def RNgon_B_Area(arg_list):
    """
    RNgon_B_Area(a, b, c) -> a*b^2/tan(180/a)/4=c  (regular a-gon, side b)
    """
    if len(arg_list) != 3:
        print("<RNgon_B_Area> function has 3 augments!")
        raise Exception
    n, side, area = arg_list
    return f"{n}*{side}^{{2}}/4/\\tan{{3.141593/{n}}}-{area}"
539
+
540
def RNgon_L_Area(arg_list):
    """
    RNgon_L_Area(a, b, c) -> a*b^2*sin(360/a)/2=c  (regular a-gon, circumradius b)
    """
    if len(arg_list) != 3:
        print("<RNgon_L_Area> function has 3 augments!")
        raise Exception
    n, radius, area = arg_list
    return f"{n}*{radius}^{{2}}*0.5*\\sin{{2*3.141593/{n}}}-{area}"
549
+
550
def RNgon_H_Area(arg_list):
    """
    RNgon_H_Area(a, b, c) -> a*b^2*tan(180/a)=c  (regular a-gon, apothem b)
    """
    if len(arg_list) != 3:
        print("<RNgon_H_Area> function has 3 augments!")
        raise Exception
    n, apothem, area = arg_list
    return f"{n}*{apothem}^{{2}}*\\tan{{3.141593/{n}}}-{area}"
559
+
560
def result_compute(num_all_list, exp_tokens):
    """Execute a solution program and return its numeric answer string.

    num_all_list: LaTeX literals for the problem numbers N0, N1, ...
    exp_tokens:   flat token list of operators and their operands.
    Returns ANNO.result ('' if the program never reached a Get step).
    """
    ANNO = Varible_Record()
    # Obtain the mapping between lowercase letters to intermediate variables V_i
    visit_list = []  # arguments denoted by lowercase letters
    for num in num_all_list:
        # Mask LaTeX command words so their letters are not counted as arguments
        for item in spec_token_list:
            num = num.replace(item, "@"*len(item))
        for letter in num:
            if letter in low_case_list: visit_list.append(letter)
    for id, var in enumerate(num_all_list):
        ANNO.varible_dict["N"+str(id)] = var
        ANNO.mid_varible_dict["N"+str(id)] = var
    visit_list.sort()
    no_visit_list = list(set(low_case_list)-set(spec_letter_list)-set(visit_list))
    no_visit_list.sort()  # lowercase letters which have not been used
    # mapping between letters and intermediate variables V_i
    midvar2letter = dict()
    midletter2var = dict()
    for id in range(V_NUM):
        midvar2letter['V'+str(id)] = no_visit_list[id]
        midletter2var[no_visit_list[id]] = 'V'+str(id)
    # step split: a step runs from one operator token up to the next
    step_list = []
    last_op_id = 0
    for id, token in enumerate(exp_tokens):
        if token in arith_op_list and id>0:
            step_list.append(exp_tokens[last_op_id:id])
            last_op_id = id
    step_list.append(exp_tokens[last_op_id:])
    # run each step
    for id, step in enumerate(step_list):
        operator = is_exist_operator(step[0], ANNO)
        if operator!='Get':
            operands = operand_update(step[1:], ANNO)
            # NOTE: eval() is acceptable here only because is_exist_operator
            # has already whitelisted `operator` against arith_op_list; do not
            # relax that check.
            expr_step = eval(operator)(operands)
            mid_var_solve(expr_step, ANNO, visit_list, midvar2letter)
            mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, True)
            mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, False)
        else:
            Get(ANNO, step[1:])

    return ANNO.result
602
+
603
def normalize_exp(exp):
    """Canonicalize operand order inside each program step.

    Operators in alterable_order_ops are order-insensitive over (part of)
    their operands, so those operands are sorted by get_priority, giving
    every semantically equivalent program a single canonical token form.
    Mutates the step slices in place and returns the flattened result.
    """
    # step split: a step runs from one operator token up to the next
    step_list = []
    last_op_id = 0
    for id, token in enumerate(exp):
        if token in arith_op_list and id>0:
            step_list.append(exp[last_op_id:id])
            last_op_id = id
    step_list.append(exp[last_op_id:])
    # normalize each step
    new_exp = []
    for step in step_list:
        if step[0] in alterable_order_ops:
            if step[0] in ['Sum', 'Multiple']:
                # every operand except the final result slot commutes
                begin_id, end_id = 1, -1
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] in ['Equal', 'Gougu', 'PRK_Perim', 'Rect_Area', 'Rhom_Area', 'Trap_Area']:
                # the first two operands commute
                begin_id, end_id = 1, 3
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] == 'Cos_Law':
                # the two adjacent sides (operands 2 and 3) commute
                begin_id, end_id = 2, 4
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] in ['Sin_Law', 'Proportion']:
                # swap the (angle, side) / (numerator, denominator) pairs as units
                if get_priority(step[1])>get_priority(step[3]) and len(step)==5:
                    step[1:3], step[3:5] = step[3:5], step[1:3]
            if step[0] in ['Tria_SAS_Area', 'Median']:
                # the two outer operands commute
                if get_priority(step[1])>get_priority(step[3]):
                    step[1], step[3] = step[3], step[1]
        new_exp += step

    return new_exp
datasets/preprossing.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ from datasets.utils import *
4
+
5
class SrcLang:
    """Source-side (problem text) vocabulary.

    Maps text tokens to indices and tracks usage counts; also fixes the index
    order of the token class tags and section tags.
    """

    def __init__(self, vocab_path):
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.get_vocab(vocab_path)
        # Fixed tag vocabularies: a tag's id is its position in the list.
        self.class_tag = ['[PAD]', '[GEN]', '[POINT]', '[NUM]', '[ARG]', '[ANGID]']
        self.sect_tag = ['[PAD]', '[PROB]', '[COND]', '[STRU]']

    def get_vocab(self, vocab_path):
        """Load the vocabulary: one token per line, line number = token id."""
        with open(vocab_path, 'r') as f:
            for id, line in enumerate(f):
                # Drop the trailing newline. NOTE(review): assumes every line,
                # including the last, ends with '\n' — confirm for the vocab files.
                vocab_token = line[:-1]
                self.word2index[vocab_token] = id
                self.word2count[vocab_token] = 0
                self.index2word.append(vocab_token)
        self.n_words = len(self.index2word)

    def indexes_from_sentence(self, sentence, id_type='text'):
        """Convert a token sequence to ids.

        id_type selects the vocabulary: 'text' (word ids, unknown -> [UNK]),
        'class_tag' or 'sect_tag' (positions in the fixed tag lists).
        """
        res = []
        if id_type == 'text':
            for word in sentence:
                if word in self.word2index:
                    res.append(self.word2index[word])
                    self.word2count[word] += 1
                else:
                    res.append(self.word2index["[UNK]"])
                    self.word2count["[UNK]"] += 1
                    print("Can not find", word, 'in the src vocab')
        elif id_type=='class_tag':
            for word in sentence: res.append(self.class_tag.index(word))
        elif id_type=='sect_tag':
            for word in sentence: res.append(self.sect_tag.index(word))
        return res

    def sentence_from_indexes(self, indexes):
        """Inverse mapping; out-of-range indexes become empty strings."""
        res = []
        for index in indexes:
            if index<len(self.index2word):
                res.append(self.index2word[index])
            else:
                res.append("")
        return res
50
+
51
class TgtLang:
    """Target-side (solution expression) vocabulary.

    The vocab file groups special tokens '[...]', process variables V*,
    constants C*, operators, then problem variables N*; var_start is the index
    where the variable/argument region begins.
    """

    def __init__(self, vocab_path):
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.var_start = 0  # first index of the variable/argument region
        self.get_vocab(vocab_path)

    def get_vocab(self, vocab_path):
        """Load the vocabulary and count each token category to locate var_start."""
        spe_num = midvar_num = const_num = 0
        op_num = var_num = 0

        with open(vocab_path, 'r') as f:
            for id, line in enumerate(f):
                vocab_token = line[:-1]  # drop trailing '\n'
                self.word2index[vocab_token] = id
                self.word2count[vocab_token] = 0
                self.index2word.append(vocab_token)
                if vocab_token[0]=='[' and vocab_token[-1]==']':
                    spe_num += 1      # special tokens such as [SOS]/[EOS]/[PAD]
                elif vocab_token[0]=='V' and vocab_token[1].isdigit():
                    midvar_num += 1   # process (intermediate) variables
                elif vocab_token[0]=='C' and vocab_token[1].isdigit():
                    const_num += 1    # numeric constants
                elif vocab_token[0]=='N' and vocab_token[1].isdigit():
                    var_num += 1      # problem variables
                else:
                    op_num += 1       # operators

        self.n_words = len(self.index2word)
        self.var_start = spe_num + midvar_num + const_num + op_num

    def indexes_from_sentence(self, sentence, var_values, arg_values):
        """Encode an expression token sequence, wrapped in [SOS]/[EOS].

        Lowercase single letters (arguments) map past the variable region:
        var_start + len(var_values) + index within arg_values.
        """
        res = []
        for word in sentence:
            if word in self.word2index:
                res.append(self.word2index[word])
                self.word2count[word] += 1
            elif len(word)==1 and word.islower(): # arg
                res.append(self.var_start+len(var_values)+arg_values.index(word))
            else:
                # NOTE(review): unknown tokens are only logged and silently
                # dropped from the encoding — confirm this is intended.
                print("Can not find", word, 'in the tgt vocab')
        res = [self.word2index["[SOS]"]]+res+[self.word2index["[EOS]"]]
        return res

    def sentence_from_indexes(self, indexes, change_dict={}):
        """Decode ids back to tokens; change_dict optionally remaps tokens
        (e.g. variables back to argument letters).

        NOTE(review): the mutable default change_dict={} is safe only because
        it is never mutated here.
        """
        res = []
        for index in indexes:
            if index<len(self.index2word):
                item = self.index2word[index]
            else:
                item = ''
            if item in change_dict: item = change_dict[item] # var2arg
            res.append(item)
        return res
108
+
109
class SN:
    """Lightweight record tying a token sequence to its aligned tag sequences."""
    def __init__(self):
        self.token = []      # str list
        self.sect_tag = []   # [PROB]/[COND]/[STRU]
        self.class_tag = []  # [GEN]/[NUM]/[ARG]/[POINT]/[ANGID]
114
+
115
def get_raw_pairs(dataset_path):
    """Load a PGPS json split and preprocess every sample in place.

    Tokenizes the problem text and both parsing sequences, attaches aligned
    section/class tags, removes numeric duplicates shared between text and
    semantic sequences, and splits the expression string into tokens.
    Returns the list of sample dicts (the loaded dicts are reused as records).
    """
    raw_pairs = []

    with open(dataset_path, 'r')as fp:
        content_all = json.load(fp)

    for key, content in content_all.items():
        text = content['text']
        stru_seqs = content['parsing_stru_seqs']
        sem_seqs = content['parsing_sem_seqs']
        text_data, stru_data, sem_data = SN(), SN(), SN()
        # tokenization; each parsing sequence gains a trailing ',' separator
        text_data.token = get_token(text)
        stru_data.token = [get_token(item)+[','] for item in stru_seqs]
        sem_data.token = [get_token(item)+[','] for item in sem_seqs]
        # split prob and cond sections of the text
        text_data.sect_tag = []
        stru_data.sect_tag = [['[STRU]']*len(item) for item in stru_data.token]
        sem_data.sect_tag = [['[COND]']*len(item) for item in sem_data.token]
        split_text(text_data)
        # class tags default to [GEN], then get refined by the taggers below
        text_data.class_tag = ['[GEN]']*len(text_data.token)
        stru_data.class_tag = [['[GEN]']*len(item) for item in stru_data.token]
        sem_data.class_tag = [['[GEN]']*len(item) for item in sem_data.token]
        get_point_angleID_tag(text_data, stru_data, sem_data)
        get_num_arg_tag(text_data, sem_data)
        # Tag the repeated [NUM] in sem_data which already exists in text_data
        expression = content['expression'].split(' ')
        remove_sem_dup(text_data, sem_data, expression)

        # The loaded dict is reused as the sample record
        content['text'] = text_data
        content['parsing_stru_seqs'] = stru_data
        content['parsing_sem_seqs'] = sem_data
        content['expression'] = expression
        content['id'] = key

        raw_pairs.append(content)

    return raw_pairs
155
+
156
class collater():
    """Batch collate function for the DataLoader.

    Pads expressions, variable positions and the three aligned text streams
    to the batch maximum, stacks the diagrams, and packs everything into the
    (diagrams, text_dict, var_dict, exp_dict) structure the model consumes.
    """

    def __init__(self, args):
        self.args = args

    def __call__(self, batch_data, padding_id=0):
        diagrams, \
        text_tokens, text_sect_tags, text_class_tags, \
        var_arg_positions, var_values, arg_values, \
        expression, answer, pair_ids, choices = list(zip(*batch_data))
        #######################################
        diagrams = torch.stack(diagrams, dim=0)
        #######################################
        # expressions padded to the longest sequence in the batch
        len_exp = [len(seq_exp) for seq_exp in expression]
        max_len_exp = max(len_exp)
        expression = [seq_exp+[padding_id]*(max_len_exp-len(seq_exp)) for seq_exp in expression]
        exp_dict = {'exp': torch.LongTensor(expression),
                    'len': torch.LongTensor(len_exp),
                    'answer': answer,
                    'id': pair_ids,
                    'choices': choices
                    }
        #######################################
        # minimum length 1 so an empty position list still yields a valid tensor row
        len_var = [max(len(seq_var),1) for seq_var in var_arg_positions]
        max_len_var = max(len_var)
        var_arg_positions = [seq_var+[padding_id]*(max_len_var-len(seq_var)) for seq_var in var_arg_positions]
        var_dict = {'pos':torch.LongTensor(var_arg_positions),
                    'len': torch.LongTensor(len_var),
                    'var_value': var_values,
                    'arg_value': arg_values
                    }
        ########################################
        len_text = [len(seq_tag) for seq_tag in text_class_tags]
        max_len_text = max(len_text)
        # NOTE(review): this pads the nested token lists in place, mutating the
        # lists produced by the dataset — confirm that is intended.
        for k in range(len(text_tokens)):
            for j in range(len(text_tokens[k])):
                text_tokens[k][j] += [padding_id]*(max_len_text-len(text_tokens[k][j]))
        text_sect_tags = [seq_tag+[padding_id]*(max_len_text-len(seq_tag)) for seq_tag in text_sect_tags]
        text_class_tags = [seq_tag+[padding_id]*(max_len_text-len(seq_tag)) for seq_tag in text_class_tags]
        text_dict = {'token': torch.LongTensor(text_tokens),
                     'sect_tag': torch.LongTensor(text_sect_tags),
                     'class_tag': torch.LongTensor(text_class_tags),
                     'len': torch.LongTensor(len_text)
                     }

        return diagrams, text_dict, var_dict, exp_dict
datasets/text_aug.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
# Symbol pools used by the replacement augmentations below.
upper_case_list = [chr(c) for c in range(ord('A'), ord('Z') + 1)]  # point labels
low_case_list = [chr(c) for c in range(ord('a'), ord('z') + 1)]    # variable letters
angle_id_list = [str(n) for n in range(1, 21)]                     # angle ids '1'..'20'
spec_token_list = ['frac', 'pi', 'sqrt']                           # LaTeX fragments whose letters are protected
7
+
8
class Compose(object):
    """Apply a sequence of augmentation callables in order.

    Each transform mutates its arguments in place; nothing is returned.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        for transform in self.transforms:
            transform(text_seq, stru_seqs, sem_seqs, exp)

    def __repr__(self):
        # One transform per line, indented, inside "ClassName( ... )".
        parts = [self.__class__.__name__ + "("]
        parts.extend("\n    {0}".format(t) for t in self.transforms)
        parts.append("\n)")
        return "".join(parts)
24
+
25
class Point_RandomReplace(object):
    """Randomly rename point labels (A-Z), consistently across all sequences.

    Only the first character of a point token is looked up, so prime/digit
    suffixes (e.g. A') are intentionally dropped by the mapping.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_point_map(self):
        # Random bijection over the capitals, keyed by the fixed A-Z order.
        shuffled = [chr(c) for c in range(65, 91)]
        random.shuffle(shuffled)
        return dict(zip(upper_case_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_point_map()
        # Text: flat token list.
        for i in range(len(text_seq.token)):
            if text_seq.class_tag[i] == '[POINT]':
                text_seq.token[i] = mapping[text_seq.token[i][0]]
        # Structural and semantic clauses: nested token lists.
        for seqs in (stru_seqs, sem_seqs):
            for i in range(len(seqs.token)):
                for j in range(len(seqs.token[i])):
                    if seqs.class_tag[i][j] == '[POINT]':
                        seqs.token[i][j] = mapping[seqs.token[i][j][0]]
50
+
51
class AngID_RandomReplace(object):
    """Randomly renumber angle identifiers ('1'..'20'), consistently."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_angid_map(self):
        # Random bijection over the angle ids, keyed by the fixed 1..20 order.
        shuffled = [str(n) for n in range(1, 21)]
        random.shuffle(shuffled)
        return dict(zip(angle_id_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_angid_map()
        for i in range(len(text_seq.token)):
            if text_seq.class_tag[i] == '[ANGID]':
                text_seq.token[i] = mapping[text_seq.token[i]]
        # Only semantic clauses carry angle ids (structural ones do not).
        for i in range(len(sem_seqs.token)):
            for j in range(len(sem_seqs.token[i])):
                if sem_seqs.class_tag[i][j] == '[ANGID]':
                    sem_seqs.token[i][j] = mapping[sem_seqs.token[i][j]]
72
+
73
class Arg_RandomReplace(object):
    """Randomly rename variable letters (a-z) consistently in text numbers,
    argument tokens, semantic-clause numbers and the target expression."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_arg_map(self):
        # Random bijection over lowercase letters, keyed by the fixed a-z order.
        shuffled = [chr(c) for c in range(97, 123)]
        random.shuffle(shuffled)
        return dict(zip(low_case_list, shuffled))

    def map_arg_in_num(self, map_dict, num):
        # Mask the protected LaTeX fragments ('frac', 'pi', 'sqrt') so their
        # letters are not renamed, then map the remaining lowercase letters.
        masked = num
        for spec in spec_token_list:
            masked = masked.replace(spec, "@" * len(spec))
        out = []
        for k, ch in enumerate(num):
            if masked[k] != '@' and ch in low_case_list:
                out.append(map_dict[ch])
            else:
                out.append(ch)
        return ''.join(out)

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_arg_map()
        for i in range(len(text_seq.token)):
            if text_seq.class_tag[i] == '[NUM]':
                text_seq.token[i] = self.map_arg_in_num(mapping, text_seq.token[i])
            if text_seq.class_tag[i] == '[ARG]':
                text_seq.token[i] = mapping[text_seq.token[i]]
        for i in range(len(sem_seqs.token)):
            for j in range(len(sem_seqs.token[i])):
                if sem_seqs.class_tag[i][j] == '[NUM]':
                    sem_seqs.token[i][j] = self.map_arg_in_num(mapping, sem_seqs.token[i][j])
        # Rename single-letter tokens in the expression too.
        for k in range(len(exp)):
            if exp[k] in low_case_list:
                exp[k] = mapping[exp[k]]
111
+
112
class StruPoint_RandomRotate(object):
    """Randomly reorder the point run of each structural clause.

    Points on a line may be reversed; points on a circle ('\\odot') may be
    flipped (orientation change) and rotated to a random start point.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return (begin, end) of the LAST maximal run of consecutive
        '[POINT]' tags, with `end` exclusive.

        Raises IndexError when the clause contains no points (same as the
        original behaviour).
        """
        runs = []
        begin = None
        for idx, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if begin is None:
                    begin = idx
            elif begin is not None:
                runs.append([begin, idx])
                begin = None
        if begin is not None:
            runs.append([begin, len(class_tag)])
        return runs[-1][0], runs[-1][1]

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        for k in range(len(stru_seqs.token)):
            if random.random() >= self.prob:
                continue
            begin_id, end_id = self.get_seq_points(stru_seqs.class_tag[k])
            seq = stru_seqs.token[k]
            if seq[0] == 'line':
                # Reverse the collinear point order.
                # Fix: reversed() replaces the slice token[end-1:begin-1:-1],
                # which is EMPTY when begin_id == 0 (begin-1 wraps to -1) and
                # would silently delete the point run.
                seq[begin_id:end_id] = list(reversed(seq[begin_id:end_id]))
            if seq[0] == '\\odot':
                # Optionally flip the orientation (clockwise change) ...
                if random.random() < 0.5:
                    seq[begin_id:end_id] = list(reversed(seq[begin_id:end_id]))
                # ... then rotate the circle to a random initial point.
                init_loc = random.randint(begin_id, end_id - 1)
                seq[begin_id:end_id] = seq[init_loc:end_id] + seq[begin_id:init_loc]
149
+
150
class SemPoint_RandomRotate(object):
    """Randomly swap the first and last point of each point run inside the
    semantic clauses (e.g. angle ABC -> angle CBA)."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return (first, last) index pairs (both inclusive) for every
        maximal run of consecutive '[POINT]' tags."""
        pairs = []
        begin = None
        for idx, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if begin is None:
                    begin = idx
            elif begin is not None:
                pairs.append((begin, idx - 1))
                begin = None
        if begin is not None:
            pairs.append((begin, len(class_tag) - 1))
        return pairs

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        for k in range(len(sem_seqs.token)):
            seq = sem_seqs.token[k]
            for first, last in self.get_seq_points(sem_seqs.class_tag[k]):
                # Each run is swapped independently with the same probability.
                if random.random() < self.prob:
                    seq[first], seq[last] = seq[last], seq[first]
179
+
180
class SemSeq_RandomRotate(object):
    """Randomly shuffle the order of the semantic clauses.

    Shuffling changes the reading order of the numbers, so the 'N<i>'
    variable names in the expression are remapped to keep referring to
    the same values after the shuffle.
    """

    def __init__(self, prob=0.5):
        # prob=0 disables the augmentation entirely; otherwise the
        # effective probability is boosted by 0.2 (original design choice).
        if prob==0:
            self.prob = 0
        else:
            self.prob = prob + 0.2

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() < self.prob:
            # Assign variable ids in reading order: first the numbers in
            # the text, then one per semantic clause ending in "= <num>"
            # (the number is the second-to-last token of such a clause).
            num_all_list, num_sem_list, num_map_list = [], [], []
            for item in text_seq.class_tag:
                if item=='[NUM]':
                    var_name = 'N'+str(len(num_all_list))
                    num_all_list.append(var_name)
                    num_map_list.append(var_name)
            for k in range(len(sem_seqs.token)):
                if sem_seqs.class_tag[k][-2] == '[NUM]':
                    var_name = 'N'+str(len(num_all_list))
                    num_all_list.append(var_name)
                    num_sem_list.append([var_name])
                else:
                    num_sem_list.append([])
            # Shuffle the clauses: every parallel attribute list of
            # sem_seqs is permuted identically (via the same id list).
            if len(sem_seqs.token)>0:
                random_id_list = [k for k in range(len(sem_seqs.token))]
                random.shuffle(random_id_list)
                for key,value in vars(sem_seqs).items():
                    _, value = zip(*sorted(zip(random_id_list, value)))
                    setattr(sem_seqs, key, list(value))
                _, num_sem_list = zip(*sorted(zip(random_id_list, num_sem_list)))
            # Remap the old variable names (now in post-shuffle reading
            # order) onto the canonical N0, N1, ... numbering and rewrite
            # the expression accordingly.
            for k in range(len(sem_seqs.token)):
                num_map_list += num_sem_list[k]
            num_map_dict = {key:value for key, value in zip(num_map_list, num_all_list)}
            for k in range(len(exp)):
                if exp[k] in num_map_dict:
                    exp[k] = num_map_dict[exp[k]]
219
+
220
class StruSeq_RandomRotate(object):
    """Randomly permute the order of the structural clauses."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        count = len(stru_seqs.token)
        if count > 0:
            order = list(range(count))
            random.shuffle(order)
            # Apply the SAME permutation to every parallel attribute list
            # (token, class_tag, ...), keeping them aligned.
            for attr, values in vars(stru_seqs).items():
                _, permuted = zip(*sorted(zip(order, values)))
                setattr(stru_seqs, attr, list(permuted))
datasets/utils.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Token categories shared by the tagging helpers below.
punctuation_list = ['.', '?', ',']
digit_list = [str(d) for d in range(10)]
capital_letter_list = [chr(c) for c in range(ord('A'), ord('Z') + 1)]
low_letter_list = [chr(c) for c in range(ord('a'), ord('z') + 1)]
# Words that open the question ("problem") span of the text...
begin_words = ["find", "what", "solve", "determine", "express", "how"]
# ...and words/punctuation that close it again.
end_words = [".", ",", '?', "if", "so", "for which", "given", "with", "on",
             "in", "must", 'for', 'that', 'formed']
# Area units that look numeric but must never be tagged as numbers.
unit_list = ["mm^{2}", "cm^{2}", "in^{2}", "ft^{2}",
             "yd^{2}", "km^{2}", "units^{2}", "mi^{2}", "m^{2}"]
# LaTeX fragments/operators whose letters are not free variables.
special_token_list = ['\\frac', '\\pi', '\\sqrt', "+", "-", "^"]
11
+
12
+
13
def get_token(ss):
    """
    Tokenizer: divide the textual problem into words.

    Splits on single spaces, detaches trailing punctuation into its own
    token, then breaks pure point references (capital letters, optionally
    with prime/digit suffixes, e.g. "ABC", "A'B") into individual point
    tokens; every other word is lower-cased.
    """
    raw_str_list = ss.strip().split(' ')
    # Detach trailing punctuation. Empty items (produced by consecutive
    # spaces) are skipped: indexing item[-1] on them raised IndexError in
    # the original implementation.
    new_str1_list = []
    for item in raw_str_list:
        if not item:
            continue
        if item[-1] in ('.', '?', ','):
            new_str1_list.append(item[:-1])
            new_str1_list.append(item[-1])
        else:
            new_str1_list.append(item)
    # Break words that consist only of point references into single points;
    # a prime or digit attaches to the preceding capital letter.
    new_str2_list = []
    for item in new_str1_list:
        is_geo_rep = True
        point_list = []
        for ch in item:
            is_suffix = (ch == '\'' or '0' <= ch <= '9') and len(point_list) > 0
            if 'A' <= ch <= 'Z' or is_suffix:
                if is_suffix:
                    point_list[-1] += ch
                else:
                    point_list.append(ch)
            else:
                is_geo_rep = False
                break
        if is_geo_rep:
            new_str2_list += point_list
        else:
            new_str2_list.append(item.lower())

    return new_str2_list
47
+
48
def split_text(text_data):
    """
    Split the textual problem into condition ('[COND]') and question
    ('[PROB]') spans, writing per-token tags to text_data.sect_tag.

    The question starts at the first begin_words token and ends at the
    first end_words token found at least two tokens later (a punctuation
    stopper is kept inside the question span).
    """
    tokens = text_data.token
    if len(tokens) == 0:
        return
    begin_ind = 0
    end_ind = len(tokens)
    for idx, tok in enumerate(tokens):
        if tok in begin_words:
            begin_ind = idx
            break
    for idx in range(begin_ind + 2, len(tokens)):
        if tokens[idx] in end_words:
            end_ind = idx + 1 if tokens[idx] in punctuation_list else idx
            break
    text_data.sect_tag = (['[COND]'] * begin_ind +
                          ['[PROB]'] * (end_ind - begin_ind) +
                          ['[COND]'] * (len(tokens) - end_ind))
70
+
71
def get_point_angleID_tag(text_data, stru_data, sem_data):
    """
    Tag point tokens ('[POINT]') and angle-id tokens ('[ANGID]') in the
    text, structural and semantic sequences; class_tag is edited in place.

    Decomposition: the identical tagging loop was duplicated three times
    in the original; it is factored into one helper applied to each
    (token, class_tag) pair.
    """
    def _tag_seq(tokens, tags):
        # A token starting with a capital letter is a point reference; a
        # pure digit directly after '\\angle' is an angle identifier.
        for idx, tok in enumerate(tokens):
            if 'A' <= tok[0] <= 'Z':
                tags[idx] = '[POINT]'
            if tok.isdigit() and idx > 0 and tokens[idx - 1] == "\\angle":
                tags[idx] = '[ANGID]'

    _tag_seq(text_data.token, text_data.class_tag)
    for k in range(len(stru_data.token)):
        _tag_seq(stru_data.token[k], stru_data.class_tag[k])
    for k in range(len(sem_data.token)):
        _tag_seq(sem_data.token[k], sem_data.class_tag[k])
91
+
92
def get_args(token):
    """Return the distinct variable letters (a-z) of *token*, in order of
    first appearance, ignoring letters belonging to LaTeX operators."""
    # Strip the operator spellings first so e.g. the 'r' of '\\sqrt'
    # is never mistaken for a variable.
    for operator in special_token_list:
        token = token.replace(operator, "")
    letters = []
    for ch in token:
        if ch in low_letter_list and ch not in letters:
            letters.append(ch)
    return letters
100
+
101
def get_num_arg_tag(text_data, sem_data):
    """
    Determine the variables/arguments in the text condition.

    Tokens in the condition span that represent numeric values are tagged
    '[NUM]'; single lowercase letters acting as unknowns are tagged
    '[ARG]'. Several heuristics exclude units, the angle-measure prefix
    'm', the article 'a', and letters naming lines/points.
    """
    # Letters occurring in the numeric value of semantic clauses of the
    # form "... = value" (the value is the second-to-last token).
    arg_sem_flat = []
    for k in range(len(sem_data.token)):
        if len(sem_data.token[k]) >= 3 and sem_data.token[k][-3] == '=':
            sem_data.class_tag[k][-2] = '[NUM]'
            arg_sem_flat += get_args(sem_data.token[k][-2])

    for id, token in enumerate(text_data.token):
        if text_data.sect_tag[id] == '[COND]' and text_data.class_tag[id] == '[GEN]':
            # unit symbol
            if token in unit_list:
                continue
            # digit existing (rough judgment)
            for word in digit_list:
                if word in token:
                    text_data.class_tag[id] = '[NUM]'
                    break
            # There are special characters, but not only special characters
            for word in special_token_list:
                if word in token and word != token:
                    text_data.class_tag[id] = '[NUM]'
                    break
            # Single lowercase letter, but not special cases
            if text_data.token[id] in low_letter_list:
                # a left-hand side of "x = ..." is handled as an argument below
                if id < len(text_data.token)-1 and text_data.token[id+1] == '=':
                    continue
                # 'm' as angle/arc measure prefix, e.g. "m \\angle ABC"
                if text_data.token[id] == 'm' and id < len(text_data.token)-1 and text_data.token[id+1] in ["\\angle", "\\widehat"]:
                    continue
                # the article 'a' (unless it follows '=')
                if text_data.token[id] == 'a' and (id == 0 or text_data.token[id-1] != '='):
                    continue
                # a letter naming a line ("line m", "m and n", ", m ,")
                # rather than a value
                if not text_data.token[id] in arg_sem_flat and \
                    id > 0 and ('line' in text_data.token[id-1] or text_data.token[id-1] == 'and' or
                    (text_data.token[id-1] == ',' and text_data.token[id+1] == ',')):
                    continue
                text_data.class_tag[id] = '[NUM]'

    # Letters occurring inside the text numbers tagged above.
    arg_text_flat = []
    for id, token in enumerate(text_data.token):
        if text_data.sect_tag[id] == '[COND]' and text_data.class_tag[id] == '[NUM]':
            arg_text_flat += get_args(token)

    # Determine arguments
    arg_all_flat = arg_text_flat + arg_sem_flat
    for id, token in enumerate(text_data.token):
        if text_data.class_tag[id] == '[GEN]' \
            and text_data.token[id] in arg_all_flat:
            # left-hand side of an equation is always an argument
            if id < len(text_data.token)-1 and text_data.token[id+1] == '=':
                text_data.class_tag[id] = '[ARG]'
                continue
            if text_data.token[id] == 'm' and id < len(text_data.token)-1 and text_data.token[id+1] in ["\\angle", "\\widehat"]:
                continue
            if text_data.token[id] == 'a' and (id == 0 or text_data.token[id-1] != '=') and \
                text_data.sect_tag[id]=='[COND]':
                continue
            if id > 0 and ('line' in text_data.token[id-1] or text_data.token[id-1] == 'and' or
                (text_data.token[id-1] == ',' and text_data.token[id+1] == ',')):
                continue
            text_data.class_tag[id] = '[ARG]'
162
+
163
def remove_sem_dup(text_data, sem_data, exp_token):
    """
    Remove the seq of sem_data if num is also in the text_data
    and change the corresponding expression.

    Variable names 'N<i>' are assigned first over text numbers, then over
    semantic-clause numbers; after dropping duplicated clauses, the names
    used in exp_token are remapped to the compacted numbering.
    """
    text_num_list, id_all_list, id_map_list = [], [], []
    token_, sect_tag_, class_tag_ = [], [], []

    # Numbers appearing in the text, with their canonical variable names.
    for k in range(len(text_data.token)):
        if text_data.class_tag[k] == '[NUM]':
            text_num_list.append(text_data.token[k])
            var_name = 'N'+str(len(id_all_list))
            id_all_list.append(var_name)
            id_map_list.append(var_name)

    # Semantic clauses: a clause whose number duplicates a text number is
    # dropped; all other clauses are kept. Only surviving numeric clauses
    # contribute a name to the old->new mapping.
    for k in range(len(sem_data.token)):
        if sem_data.class_tag[k][-2] == '[NUM]':
            var_name = 'N'+str(len(id_all_list))
            id_all_list.append(var_name)
            if not sem_data.token[k][-2] in text_num_list:
                token_.append(sem_data.token[k])
                sect_tag_.append(sem_data.sect_tag[k])
                class_tag_.append(sem_data.class_tag[k])
                id_map_list.append(var_name)
        else:
            token_.append(sem_data.token[k])
            sect_tag_.append(sem_data.sect_tag[k])
            class_tag_.append(sem_data.class_tag[k])

    # Rewrite the expression with the compacted variable numbering.
    num_map_dict = {key:value for key, value in zip(id_map_list, id_all_list)}
    for k in range(len(exp_token)):
        if exp_token[k] in num_map_dict:
            exp_token[k] = num_map_dict[exp_token[k]]

    sem_data.token = token_
    sem_data.sect_tag = sect_tag_
    sem_data.class_tag = class_tag_
200
+
201
def get_combined_text(text_seq, stru_seqs, sem_seqs, combine_text, args):
    '''
    combination style: [stru_seqs, text_cond, sem_seqs, text_prob]

    Every attribute of combine_text (token, tags, ...) is rebuilt as the
    concatenation of the corresponding attribute of the structural
    clauses, the text condition, the semantic clauses and the question.
    '''
    # Locate the question ('[PROB]') span inside the text sequence.
    begin_ind = end_ind = None
    for k in range(len(text_seq.sect_tag)):
        if text_seq.sect_tag[k]=='[PROB]':
            begin_ind = k
            break
    for k in range(len(text_seq.sect_tag)-1,-1,-1):
        if text_seq.sect_tag[k]=='[PROB]':
            end_ind = k+1
            break
    # Concatenate the parallel attribute lists in the combination order;
    # the nested clause lists are flattened with sum(..., []).
    for key in vars(combine_text):
        # get text_cond and text_prob
        text_all_value = getattr(text_seq, key)
        text_cond_value = text_all_value[:begin_ind] + text_all_value[end_ind:]
        text_prob_value = text_all_value[begin_ind:end_ind]
        if args.without_stru:
            # structural clauses can be ablated away via args.without_stru
            value_all = text_cond_value + sum(getattr(sem_seqs, key), []) + text_prob_value
        else:
            value_all = sum(getattr(stru_seqs, key), []) + text_cond_value + \
                sum(getattr(sem_seqs, key), []) + text_prob_value

        setattr(combine_text, key, value_all)
228
+
229
def get_var_arg(combine_text, args):
    """
    Collect positions and values of number ('[NUM]') and argument
    ('[ARG]') tokens from the combined text.

    Returns:
        (num_positions + arg_positions, num_values, arg_values)
    """
    var_positions, var_values = [], []
    arg_positions, arg_values = [], []
    for pos, tag in enumerate(combine_text.class_tag):
        if tag == '[NUM]':
            var_positions.append(pos)
            var_values.append(combine_text.token[pos])
        if tag == '[ARG]':
            arg_positions.append(pos)
            arg_values.append(combine_text.token[pos])
    # Positions are merged: numbers first, then arguments.
    return var_positions + arg_positions, var_values, arg_values
245
+
246
def get_text_index(combine_text, src_lang):
    """
    Convert the combined text (tokens + tags) into index sequences via
    *src_lang*. A '[NUM]' token is replaced by its variable letters (one
    per channel, '[PAD]' as filler) so up to two letters survive indexing.
    """
    sect_ids = src_lang.indexes_from_sentence(combine_text.sect_tag, id_type='sect_tag')
    class_ids = src_lang.indexes_from_sentence(combine_text.class_tag, id_type='class_tag')
    # Two parallel token channels: the raw tokens, and a '[PAD]' row that
    # receives the second variable letter of a number when present.
    channels = [combine_text.token[:], ['[PAD]'] * len(combine_text.token)]
    for pos, tag in enumerate(combine_text.class_tag):
        if tag == '[NUM]':
            letters = get_args(combine_text.token[pos])
            channels[0][pos] = channels[1][pos] = "[PAD]"
            for ch_idx in range(len(letters)):
                channels[ch_idx][pos] = letters[ch_idx]
    token_ids = [src_lang.indexes_from_sentence(chan, id_type='text') for chan in channels]

    return token_ids, sect_ids, class_ids
260
+
261
+
262
+
263
+
264
+
265
+
266
+
loss/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from .loss import *
2
+ from config import criterion_list
3
+
4
+
5
def get_criterion(args):
    """Instantiate the loss class named by args.criterion.

    The name is validated against config.criterion_list before being
    resolved with eval, so only whitelisted class names are evaluated.
    """
    name = args.criterion
    if name not in criterion_list:
        raise NotImplementedError("Unsupported Loss Criterion : {}".format(name))
    return eval(name)(args)
loss/loss.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from utils import *
5
+
6
+
7
class CrossEntropy(nn.Module):
    """Thin wrapper around F.cross_entropy (mean reduction)."""

    def __init__(self, cfg):
        # cfg is accepted for interface uniformity with the other losses
        # but is not used.
        super(CrossEntropy, self).__init__()

    def forward(self, output, target):
        return F.cross_entropy(output, target)
14
+
15
class FocalLoss(nn.Module):
    """Focal loss: per-sample cross entropy down-weighted by (1 - p)^gamma
    so easy (high-probability) examples contribute less."""

    def __init__(self, cfg=None):
        super(FocalLoss, self).__init__()
        # Default gamma of 2.0 when no config is supplied.
        self.gamma = 2.0 if cfg is None else cfg.focal_loss_gamma
        assert self.gamma >= 0

    def focal_loss(self, input_values):
        """Scale per-sample CE values by (1 - p)^gamma and average."""
        prob = torch.exp(-input_values)
        return ((1 - prob) ** self.gamma * input_values).mean()

    def forward(self, input, target):
        return self.focal_loss(F.cross_entropy(input, target, reduction='none'))
33
+
34
class MaskedCrossEntropy(nn.Module):
    """Sequence cross-entropy averaged over valid (unpadded) steps only."""

    def __init__(self, cfg):
        super(MaskedCrossEntropy, self).__init__()
        self.cfg = cfg

    def forward(self, logits, target, length):
        """
        Args:
            logits: (batch, max_len, num_classes) unnormalized scores
                over op/const/var tokens.
            target: (batch, max_len) gold class index per step.
            length: (batch,) number of valid steps per sequence.
        Returns:
            Scalar loss: summed NLL over valid steps / total valid steps.
        """
        num_classes = logits.size(-1)
        # Per-step negative log-likelihood of the gold class.
        log_probs = F.log_softmax(logits.view(-1, num_classes), dim=1)
        nll = -torch.gather(log_probs, dim=1, index=target.view(-1, 1))
        nll = nll.view(*target.size())
        # sequence_mask (project utils) zeroes out the padded steps;
        # normalize by the total number of valid steps.
        masked = nll * sequence_mask(length).float()
        return masked.sum() / length.float().sum()
model/backbone/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .resnet import *
2
+ from .mobilenet_v2 import *
3
+ from config import visual_backbone_list
4
+
5
+
6
def get_visual_backbone(args):
    """Build the visual backbone named by args.visual_backbone, optionally
    loading pretrained weights from args.pretrain_vis_path.

    The name is validated against config.visual_backbone_list before
    being resolved with eval, so only whitelisted factories are evaluated.
    """
    if args.visual_backbone not in visual_backbone_list:
        raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
    model = eval(args.visual_backbone)()
    if args.pretrain_vis_path != "":
        model.load_model(pretrain=args.pretrain_vis_path)
        args.logger.info("Visual backbone has been loaded...")
    else:
        args.logger.info("Visual backbone choose to train from scratch")
    return model
model/backbone/mobilenet_v2.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import math
3
+ import torch
4
+ import config as cfg
5
+
6
+
7
def conv_bn(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) -> BatchNorm -> ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
13
+
14
+
15
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual: (pointwise expand ->) depthwise ->
    pointwise-linear project, with a skip connection when stride is 1 and
    the channel count is unchanged."""

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pointwise expansion
            layers += [nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                       nn.BatchNorm2d(hidden_dim),
                       nn.ReLU6(inplace=True)]
        # depthwise 3x3 (grouped per channel)
        layers += [nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                   nn.BatchNorm2d(hidden_dim),
                   nn.ReLU6(inplace=True)]
        # pointwise-linear projection (no activation)
        layers += [nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                   nn.BatchNorm2d(oup)]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
54
+
55
+
56
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor (no classification head).

    forward() returns the final feature maps; with width_mult=1 the
    output has 320 channels at 1/32 of the input resolution.
    """

    def __init__(self, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # t: expansion factor, c: output channels, n: repeats,
        # s: stride of the first block in the stage.
        interverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # stem + inverted residual stages
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # only the first block of a stage may downsample
                stride = s if i == 0 else 1
                self.features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        self.features = nn.Sequential(*self.features)

        self._initialize_weights()

    def forward(self, x):
        """Return backbone feature maps for an image batch (N, 3, H, W)."""
        return self.features(x)

    def _initialize_weights(self):
        # He-style init for convs; constant init for BN; small normal for
        # any linear layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def load_model(self, pretrain=None):
        """Load pretrained weights; checkpoint keys absent from this model
        are ignored.

        Fix: accepts the checkpoint path as an argument — callers such as
        get_visual_backbone invoke load_model(pretrain=...), which the
        previous zero-argument signature rejected with a TypeError. Falls
        back to cfg.pretrained_model_path when no path is given
        (backward compatible with the old no-argument call).
        """
        path = pretrain if pretrain else cfg.pretrained_model_path
        model_dict = self.state_dict()
        # map_location keeps loading working on CPU-only hosts.
        pretrained_dict = torch.load(path, map_location='cpu')
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
118
+
119
+
120
def mobilenet_v2():
    """Factory: MobileNetV2 backbone at the default width multiplier."""
    return MobileNetV2()
model/backbone/resnet.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
def init_layer(L):
    """Initialize a layer in place: fan-in (He) normal init for Conv2d,
    unit weight / zero bias for BatchNorm2d; other layers are untouched."""
    if isinstance(L, nn.Conv2d):
        fan = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        L.weight.data.normal_(0, math.sqrt(2.0 / float(fan)))
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
13
+
14
class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
20
+
21
+ # Simple ResNet Block
22
class SimpleBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style)."""
    maml = False  # Default

    def __init__(self, indim, outdim, half_res):
        super(SimpleBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        # The first conv downsamples when half_res is set.
        self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(outdim)
        self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]

        self.half_res = half_res

        # When the channel count changes, project the input with a 1x1 conv.
        if indim == outdim:
            self.shortcut_type = 'identity'
        else:
            self.shortcut = nn.Conv2d(indim, outdim, 1, 2 if half_res else 1, bias=False)
            self.BNshortcut = nn.BatchNorm2d(outdim)
            self.parametrized_layers.append(self.shortcut)
            self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        out = self.relu1(self.BN1(self.C1(x)))
        out = self.BN2(self.C2(out))
        if self.shortcut_type == 'identity':
            residual = x
        else:
            residual = self.BNshortcut(self.shortcut(x))
        return self.relu2(out + residual)
62
+
63
+ # Bottleneck block
64
class BottleneckBlock(nn.Module):
    """Three-conv bottleneck residual block (ResNet-50+ style)."""
    maml = False  # Default

    def __init__(self, indim, outdim, half_res):
        super(BottleneckBlock, self).__init__()
        bottleneckdim = int(outdim / 4)
        self.indim = indim
        self.outdim = outdim

        # 1x1 reduce -> 3x3 (may downsample) -> 1x1 restore.
        self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
        self.BN1 = nn.BatchNorm2d(bottleneckdim)
        self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1, padding=1)
        self.BN2 = nn.BatchNorm2d(bottleneckdim)
        self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
        self.BN3 = nn.BatchNorm2d(outdim)

        self.relu = nn.ReLU()
        self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
        self.half_res = half_res

        # When the channel count changes, project the input with a 1x1 conv
        # (NOTE: unlike SimpleBlock, no BN is applied on this shortcut).
        if indim == outdim:
            self.shortcut_type = 'identity'
        else:
            self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
            self.parametrized_layers.append(self.shortcut)
            self.shortcut_type = '1x1'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        residual = x if self.shortcut_type == 'identity' else self.shortcut(x)
        out = self.relu(self.BN1(self.C1(x)))
        out = self.relu(self.BN2(self.C2(out)))
        out = self.BN3(self.C3(out))
        return self.relu(out + residual)
109
+
110
class ResNet(nn.Module):
    """Four-stage ResNet trunk built from a given residual block type."""
    maml = False  # Default

    def __init__(self, block, list_of_num_layers, list_of_out_dims, flatten=True):
        # list_of_num_layers: number of blocks per stage.
        # list_of_out_dims: output channels per stage.
        super(ResNet, self).__init__()
        assert len(list_of_num_layers) == 4, 'Can have only four stages'

        conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                          bias=False)
        bn1 = nn.BatchNorm2d(64)
        init_layer(conv1)
        init_layer(bn1)
        trunk = [conv1, bn1, nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]

        indim = 64
        for stage in range(4):
            for j in range(list_of_num_layers[stage]):
                # Stages 1-3 downsample at their first block.
                half_res = (stage >= 1) and (j == 0)
                trunk.append(block(indim, list_of_out_dims[stage], half_res))
                indim = list_of_out_dims[stage]

        if flatten:
            trunk.append(nn.AvgPool2d(4))
            trunk.append(Flatten())
            self.final_feat_dim = indim
        else:
            self.final_feat_dim = [indim, 4, 4]
        self.trunk = nn.Sequential(*trunk)

    def forward(self, x):
        return self.trunk(x)
145
+
146
def ResNet10(flatten=True):
    """ResNet-10: one SimpleBlock per stage."""
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], flatten)

def ResNet18(flatten=True):
    """ResNet-18: two SimpleBlocks per stage."""
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], flatten)

def ResNet34(flatten=True):
    """ResNet-34: 3/4/6/3 SimpleBlocks."""
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], flatten)

def ResNet50(flatten=True):
    """ResNet-50: 3/4/6/3 BottleneckBlocks."""
    return ResNet(BottleneckBlock, [3, 4, 6, 3], [256, 512, 1024, 2048], flatten)

def ResNet101(flatten=True):
    """ResNet-101: 3/4/23/3 BottleneckBlocks."""
    return ResNet(BottleneckBlock, [3, 4, 23, 3], [256, 512, 1024, 2048], flatten)
model/classifier/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .classifier_ops import *
2
+ from config import classifier_list
3
+
4
+
5
def get_classifier(args):
    """Build the classifier head named by ``args.classifier``.

    Arguments:
        args: config namespace providing ``classifier`` (name string),
            ``classifier_bias``, ``num_features`` and ``num_classes``.
    Returns:
        An ``nn.Module`` classifier head.
    Raises:
        NotImplementedError: if ``args.classifier`` is not a supported name.
    """
    bias_flag = args.classifier_bias
    num_features = args.num_features
    num_classes = args.num_classes

    if args.classifier not in classifier_list:
        raise NotImplementedError("Unsupported Classifier: {}".format(args.classifier))

    if args.classifier == "FCNorm":
        classifier = FCNorm(num_features, num_classes)
    elif args.classifier == "CosNorm":
        classifier = CosNorm(num_features, num_classes)
    elif args.classifier == "DotProduct":
        # Note the argument order: DotProduct(num_classes, feat_dim, bias).
        classifier = DotProduct(num_classes, num_features, bias_flag)
    elif args.classifier == "DistFC":
        classifier = DistFC(num_features, num_classes)
    else:
        # BUG FIX: config's classifier_list and this dispatch can drift apart;
        # previously a name in classifier_list but not handled here crashed
        # with UnboundLocalError on `classifier`. Fail loudly instead.
        raise NotImplementedError("Unsupported Classifier: {}".format(args.classifier))

    return classifier
model/classifier/classifier_ops.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
class DotProduct(nn.Module):
    """Plain linear (dot-product) classifier head.

    Extra positional arguments passed to ``forward`` are accepted and ignored
    so all classifier heads share a common call signature.
    """

    def __init__(self, num_classes=1000, feat_dim=2048, bias=True):
        super(DotProduct, self).__init__()
        self.fc = nn.Linear(feat_dim, num_classes, bias)

    def forward(self, x, *args):
        """Return class logits ``x @ W.T (+ b)`` of shape [B, num_classes]."""
        return self.fc(x)
15
+
16
+
17
class CosNorm(nn.Module):
    """Cosine-norm classifier: the input is rescaled by ||x|| / (1 + ||x||)
    and compared against L2-row-normalized class weights, scaled by ``scale``.

    NOTE(review): ``margin`` is stored but never used anywhere in this class,
    and ``init_std`` is accepted but ignored entirely — confirm against the
    loss consuming this head before cleaning them up.
    """
    def __init__(self, in_dims, out_dims, scale=16, margin=0.5, init_std=0.001):
        super(CosNorm, self).__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.scale = scale
        self.margin = margin
        # NOTE(review): .cuda() at construction hard-codes GPU execution.
        self.weight = nn.Parameter(torch.Tensor(out_dims, in_dims).cuda())
        self.reset_parameters()

    def reset_parameters(self):
        # nn.Linear-style uniform init with fan-in scaling.
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input, *args):
        # ex = (||x|| / (1 + ||x||)) * (x / ||x||) == x / (1 + ||x||)
        # NOTE(review): divides by ||x||; an all-zero input row yields NaNs — confirm inputs are non-zero.
        norm_x = torch.norm(input.clone(), 2, 1, keepdim=True)
        ex = (norm_x / (1 + norm_x)) * (input / norm_x)
        # ew: class weights normalized to unit L2 norm per row.
        ew = self.weight / torch.norm(self.weight, 2, 1, keepdim=True)
        return torch.mm(self.scale * ex, ew.t())
36
+
37
+
38
class FCNorm(nn.Module):
    """Normalized-weight cosine classifier (used with LDAM Loss).

    Logits are ``scale * cos(x, w_c)``: both the input rows and the class
    weight rows are L2-normalized before the linear projection.
    """

    def __init__(self, num_features, num_classes, scale=20.0):
        super(FCNorm, self).__init__()
        weight = torch.FloatTensor(num_classes, num_features)
        # Same init recipe as the original: uniform, then renorm+rescale.
        weight.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.weight = nn.Parameter(weight)
        self.scale = scale

    def forward(self, x):
        """Return scaled cosine-similarity logits of shape [B, num_classes]."""
        return self.scale * F.linear(F.normalize(x), F.normalize(self.weight))
49
+
50
+
51
class DistFC(nn.Module):
    """Distance-based classifier head with learnable per-class centers.

    ``forward`` returns the tuple ``(centers, dist)`` where
    ``dist[i, j] = ||x_i||^2 + ||c_j||^2 - 2 x_i·c_j = ||x_i - c_j||^2``
    (squared Euclidean distance of sample i to class-j center); the caller
    receives the centers as well, presumably for a center-regularizing loss —
    TODO confirm against the consuming loss.
    """

    def __init__(self, num_features, num_classes, init_weight=True):
        super(DistFC, self).__init__()
        # Centers stored feature-major: [num_features, num_classes].
        # NOTE(review): .cuda() at construction hard-codes GPU execution.
        self.centers = nn.Parameter(torch.randn(num_features, num_classes).cuda(), requires_grad=True)
        if init_weight:
            self.__init_weight()

    def __init_weight(self):
        # Double underscore -> name-mangled to _DistFC__init_weight (private helper).
        nn.init.kaiming_normal_(self.centers)

    def forward(self, x):
        features_square = torch.sum(torch.pow(x, 2), 1, keepdim=True)            # B x 1
        centers_square = torch.sum(torch.pow(self.centers, 2), 0, keepdim=True)  # 1 x C
        features_into_centers = 2.0 * torch.matmul(x, (self.centers))            # B x C
        dist = features_square + centers_square - features_into_centers          # B x C, squared distances
        return self.centers, dist
68
+
69
+
model/decoder/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from .transformer import TransformerModel
2
+ from config import decoder_list
3
+ from .rnn_decoder import DecoderRNN
4
+ from .tree_decoder import TreeDecoder
5
+ from .transformer import TransformerDecoder
6
+
7
def get_decoder(params, *args):
    """Build the decoder named by ``params.decoder_type``.

    Arguments:
        params: config namespace providing ``decoder_type``.
        *args: forwarded verbatim to the decoder constructor
            (e.g. the target language object).
    Returns:
        One of TransformerDecoder / DecoderRNN / TreeDecoder.
    Raises:
        NotImplementedError: if ``params.decoder_type`` is not supported.
    """
    # BUG FIX: the message previously said "Unsupported Classifier" — a
    # copy-paste from the classifier factory; report the decoder type.
    if params.decoder_type not in decoder_list:
        raise NotImplementedError(
            "Unsupported Decoder: {}".format(params.decoder_type))

    if params.decoder_type == "transformer":
        decoder = TransformerDecoder(params, *args)
    elif params.decoder_type == "rnn_decoder":
        decoder = DecoderRNN(params, *args)
    elif params.decoder_type == "tree_decoder":
        decoder = TreeDecoder(params, *args)
    else:
        # config's decoder_list and this dispatch can drift apart; keep the guard.
        raise NotImplementedError("Unsupported Decoder: {}".format(params.decoder_type))

    return decoder
23
+
24
+
model/decoder/rnn_decoder.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from model.module import *
4
+ from utils import *
5
+ from torch.nn import functional as F
6
+
7
class DecoderRNN(nn.Module):
    """Attention-based GRU decoder over a hybrid vocabulary.

    Token ids below ``var_start`` (special tokens / operators / constants) are
    embedded with a learned table; ids at or above ``var_start`` denote problem
    variables and reuse the encoder output at the variable's source position as
    the embedding. Training is teacher-forced; testing runs a beam search.
    """
    def __init__(self, cfg, tgt_lang):
        super(DecoderRNN, self).__init__()
        # token location
        self.var_start = tgt_lang.var_start  # spe_num + midvar_num + const_num + op_num
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # Embedding table covers only the fixed (non-variable) vocabulary.
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        self.gru = nn.GRU(input_size=cfg.decoder_hidden_size+cfg.decoder_embedding_size, \
                          hidden_size=cfg.decoder_hidden_size, \
                          num_layers=cfg.decoder_layers, \
                          dropout=cfg.dropout_rate, \
                          batch_first=True)
        # Choose attention model
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant
        # NOTE(review): .cuda() at construction hard-codes GPU execution.
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()
        self.cfg = cfg

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """
        Gather, per sample, the encoder outputs at the variable positions.
        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index=expand_var_pos)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_src, var_pos, len_var, \
                text_tgt=None, is_train=False):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: layer_num x B x H (initial GRU hidden state)
            len_src: B
            text_tgt: B x S2 (teacher-forcing targets; only used when is_train)
            var_pos: B x S3
            len_var: B
        Return:
            training: logits, B x S x (no_var_size+var_size)
            testing: exp_id, B x candi_size(beam_size) x exp_len
        """
        # NOTE: forward stashes per-batch state on self; _forward_test relies on it.
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_pos)  # B x S3 x H
        self.src_mask = sequence_mask(len_src)  # B x S1
        self.candi_mask = sequence_mask(self.var_start + len_var)  # B x (no_var_size + var_size)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_tgt)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_tgt):
        """Teacher-forced decoding; returns logits for steps 1..S2-1."""
        all_seq_outputs = []
        batch_size = encoder_outputs.size(0)
        # initial hidden input of RNN
        rnn_hidden = problem_output
        # input embedding: fixed-vocab ids use the table, variable ids use the
        # gathered encoder outputs; torch.where selects per position.
        tgt_novar_id = torch.clamp(text_tgt, max=self.var_start-1)  # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id)  # B x S2 x H
        tgt_var_id = torch.clamp(text_tgt-self.var_start, min=0)  # B x S2
        var_embeddings = self.embedding_var.gather(dim=1, index= \
            tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # B x S2 x H

        choose_mask = (text_tgt < self.var_start).unsqueeze(2). \
            repeat(1, 1, self.cfg.decoder_embedding_size)
        embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings)  # B x S2 x H
        embedding_all_ = self.em_dropout(embedding_all)
        # candi weight embedding (scoring candidates: fixed vocab ++ variables)
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id. \
            repeat(batch_size, 1))  # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1)  # B x (no_var_size + var_size) x H
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)

        for t in range(text_tgt.size(1)-1):
            # Calculate attention from current RNN state and all encoder outputs;
            # apply to encoder outputs to get weighted average
            current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1))  # B x 1 x H
            attn_weights = self.attn(current_hiddens, encoder_outputs, self.src_mask)
            context = attn_weights.unsqueeze(1).bmm(encoder_outputs)  # B x 1 x H
            # Get current hidden state from input word and last hidden state
            rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_[:, t:t+1, :], context), 2), rnn_hidden)
            # rnn_output: B x 1 x H
            # rnn_hidden: num_layers x B x H
            current_fusion_emb = torch.cat((rnn_output, context), 2)
            current_fusion_emb_ = self.em_dropout(current_fusion_emb)
            candi_score = self.score(current_fusion_emb_, embedding_weight_all_, \
                                     self.candi_mask)  # B x (no_var_size + var_size)
            all_seq_outputs.append(candi_score)

        all_seq_outputs = torch.stack(all_seq_outputs, dim=1)

        return all_seq_outputs

    def _forward_test(self, encoder_outputs, problem_output):
        """
        Decode with beam search algorithm (one sample at a time).
        """
        exp_outputs = []
        batch_size = encoder_outputs.size(0)

        for sample_id in range(batch_size):
            # predefine: replicate this sample's tensors across the beam dim
            rem_size = self.cfg.beam_size
            encoder_output = encoder_outputs[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S1 x H
            src_mask = self.src_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1))  # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1)  # beam_size x (no_var_size + var_size) x H
            embedding_weight_all_ = self.em_dropout(embedding_weight_all)
            candi_mask = self.candi_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x S1
            candi_exp_output = []   # finished hypotheses (token-id lists)
            candi_score_output = []  # their log-prob scores (parallel list)

            for i in range(self.cfg.max_output_len):
                # initial varible
                if i == 0:
                    input_token = torch.LongTensor([[self.sos_id]]*rem_size).cuda()  # rem_size x 1
                    rnn_hidden = problem_output[:, sample_id:sample_id+1].repeat(1, rem_size, 1)  # layer_num x rem_size x H
                    current_score = torch.FloatTensor([[0.0]]*rem_size).cuda()  # rem_size x 1
                    # NOTE: [[]]*rem_size shares one list object across beams; safe
                    # here because lists are never mutated in place (new lists are
                    # built below via `+ [y]`).
                    current_exp_list = [[]]*rem_size
                else:
                    input_token = torch.LongTensor(token_list).unsqueeze(1).cuda()
                    rnn_hidden = rnn_hidden[:, cand_list]
                    rem_size = len(exp_list)
                    current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).cuda()
                    current_exp_list = exp_list

                # input embedding (same fixed-vocab / variable split as training)
                tgt_novar_id = torch.clamp(input_token, max=self.var_start-1)  # rem_size x 1
                novar_embedding = self.embedding_tgt(tgt_novar_id)  # rem_size x 1 x H
                tgt_var_id = torch.clamp(input_token-self.var_start, min=0)  # rem_size x 1
                var_embeddings = embedding_var[:rem_size].gather(dim=1, index=tgt_var_id.unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size))  # rem_size x 1 x H
                choose_mask = (input_token < self.var_start).unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size)  # rem_size x 1 x H
                embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings)  # rem_size x 1 x H
                embedding_all_ = self.em_dropout(embedding_all)
                # attention
                current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1))  # rem_size x 1 x H
                attn_weights = self.attn(current_hiddens, encoder_output[:rem_size], src_mask[:rem_size])  # rem_size x S1
                context = attn_weights.unsqueeze(1).bmm(encoder_output[:rem_size])  # rem_size x 1 x H
                # Get current hidden state from input word and last hidden state
                rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_, context), 2), rnn_hidden)
                # rnn_output: rem_size x 1 x H
                # rnn_hidden: num_layers x rem_size x H
                current_fusion_emb = torch.cat((rnn_output, context), 2)
                current_fusion_emb_ = self.em_dropout(current_fusion_emb)
                candi_score = self.score(current_fusion_emb_, embedding_weight_all_[:rem_size], \
                                         candi_mask[:rem_size])  # rem_size x (no_var_size + var_size)

                if i == 0:
                    # Only keep the first beam at step 0 (all beams are identical).
                    new_score = F.log_softmax(candi_score, dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score, dim=1) + current_score

                # Candidates: all (beam, token) expansions plus the already
                # finished hypotheses (tagged with id -1) competing for slots.
                cand_tup_list = [(score, id) for id, score in enumerate(new_score.view(-1).tolist())]
                cand_tup_list += [(score, -1) for score in candi_score_output]
                cand_tup_list.sort(key=lambda x: x[0], reverse=True)

                token_list = []
                cand_list = []
                exp_list = []
                score_list = []

                for tv, ti in cand_tup_list[:self.cfg.beam_size]:
                    if ti != -1:
                        # Decompose flattened index into (beam row, token id).
                        idex = ti
                        x = idex // candi_score.size(-1)
                        y = idex % candi_score.size(-1)
                        if y != self.eos_id:
                            token_list.append(y)
                            cand_list.append(x)
                            exp_list.append(current_exp_list[x]+[y])
                            score_list.append(tv)
                        else:
                            # Beam emitted EOS: record as a finished hypothesis.
                            candi_exp_output.append(current_exp_list[x])
                            candi_score_output.append(float(tv))
                    # ti == -1: an already-finished hypothesis merely occupies a
                    # beam slot (crowding out weaker continuations); no expansion.

                if len(token_list) == 0:
                    break

            if len(candi_exp_output) > 0:
                # Sort finished hypotheses by score, best first.
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output[:self.cfg.beam_size]))
            else:
                exp_outputs.append([])

        return exp_outputs
model/decoder/transformer.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils.utils import sequence_mask
4
+ from model.module import *
5
+ from torch.nn import functional as F
6
+ import math
7
+
8
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A [1, max_len, d_model] table is precomputed once and registered as a
    (non-trainable) buffer; ``forward`` adds the first ``seq_len`` rows to the
    input.
    """

    def __init__(self, d_model, max_len=5000, dropout_rate=0.2):
        super(PositionalEncoding, self).__init__()

        positions = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency ladder: 1/10000^(2k/d_model) for each even index.
        freqs = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = positions * freqs
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", table.unsqueeze(0))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: [B, seq_len, d_model]
        pe: [1, max_len, d_model]
        """
        # requires_grad_(False) keeps the table out of the autograd graph.
        out = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(out)
+ return self.dropout(x)
30
+
31
class TransformerDecoder(nn.Module):
    """Transformer decoder over a hybrid vocabulary.

    Token ids below ``var_start`` use a learned embedding table; ids at or
    above it denote problem variables and reuse the encoder output at the
    variable's source position. Training is teacher-forced; testing runs a
    beam search over samples one at a time.
    """

    def __init__(self, cfg, tgt_lang, \
                 d_model=256, nhead=8, num_decoder_layers=4, dim_feedforward=1024, dropout=0.2):
        super(TransformerDecoder, self).__init__()

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self.position_dec = PositionalEncoding(d_model=d_model)

        self.score = Score_Multi(cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        self.var_start = tgt_lang.var_start
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        # NOTE(review): .cuda() at construction hard-codes GPU execution.
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()

        # NOTE: xavier-inits every parameter with dim > 1, including the
        # embedding table and the score module created above.
        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.cfg = cfg
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]

    def _reset_parameters(self):
        """
        Initiate parameters in the transformer model.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def get_square_subsequent_mask(self, sz):
        """
        Generate a square mask for the sequence. The masked positions are filled with True.
        Unmasked positions are filled with False.
        """
        # Upper triangle above the diagonal is True => future positions masked.
        mask = (torch.triu(torch.ones(sz, sz)) == 0).transpose(0, 1)
        return mask.cuda()

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """
        Gather, per sample, the encoder outputs at the variable positions.
        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index=expand_var_pos)
        return var_embeddings

    def forward(self, memory, len_src, tgt, len_tgt, var_pos, len_var, is_train=False):
        '''
        memory: B x S1 x H (encoder outputs)
        len_src: B
        tgt: B x S2
        len_tgt: B
        var_pos: B x S3(var_size)
        len_var: B
        '''
        # NOTE: forward stashes per-batch state on self; _forward_test relies on it.
        self.embedding_var = self.get_var_encoder_outputs(memory, var_pos)  # B x S3 x H
        self.candi_mask = sequence_mask(self.var_start + len_var)  # B x (no_var_size + var_size)
        # Inverted: True marks padding, per nn.Transformer key_padding semantics.
        self.memory_key_padding_mask = ~sequence_mask(len_src)  # B x S1
        if is_train:
            return self._forward_train(memory, tgt, len_tgt)
        else:
            return self._forward_test(memory)

    def _forward_train(self, memory, tgt, len_tgt):
        """Teacher-forced pass; returns logits for predicting tokens 1..S2-1."""
        # mask
        tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
        tgt_key_padding_mask = ~sequence_mask(len_tgt)
        # emb_tgt: fixed-vocab ids via table, variable ids via gathered encoder
        # outputs; torch.where selects per position.
        tgt_novar_id = torch.clamp(tgt, max=self.var_start-1)  # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id)  # B x S2 x H
        tgt_var_id = torch.clamp(tgt-self.var_start, min=0)  # B x S2
        var_embeddings = self.embedding_var.gather(dim=1, index= \
            tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size))  # B x S2 x H
        choose_mask = (tgt < self.var_start).unsqueeze(2). \
            repeat(1, 1, self.cfg.decoder_embedding_size)
        emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings)  # B x S2 x H
        # position decoding
        emb_tgt = self.position_dec(emb_tgt)
        # nn.TransformerDecoder expects seq-first layout, hence the permutes.
        output = self.decoder(  # B x S2 x H
            emb_tgt.permute(1, 0, 2),
            memory.permute(1, 0, 2),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=self.memory_key_padding_mask,
        ).permute(1, 0, 2)
        # candi weight embedding (scoring candidates: fixed vocab ++ variables)
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(len(len_tgt), 1))  # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1)  # B x (no_var_size+var_size) x H
        candi_score = self.score(  # B x S2 x (no_var_size + var_size)
            output,
            embedding_weight_all, \
            self.candi_mask
        )

        # Position t predicts token t+1, so the last position has no target.
        return candi_score[:, :-1, :].clone()

    def _forward_test(self, memory):
        """Beam-search decoding, one sample at a time; returns token-id lists."""
        exp_outputs = []

        for sample_id in range(memory.size(0)):
            # predefine: replicate this sample's tensors across the beam dim
            rem_size = self.cfg.beam_size
            memory_item = memory[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S1 x H
            memory_key_padding_mask = self.memory_key_padding_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id+1].repeat(rem_size, 1, 1)  # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1))  # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1)  # beam_size x (no_var_size + var_size) x H
            candi_mask = self.candi_mask[sample_id:sample_id+1].repeat(rem_size, 1)  # beam_size x S1

            candi_exp_output = []    # finished hypotheses (SOS stripped below)
            candi_score_output = []  # their scores (parallel list)

            tgt = torch.LongTensor([[self.sos_id]]*rem_size).cuda()  # rem_size x 1
            len_tgt = torch.LongTensor([1]*rem_size).cuda()  # rem_size
            current_score = torch.FloatTensor([[0.0]]*rem_size).cuda()  # rem_size x 1
            current_exp_list = [[self.sos_id]]*rem_size

            for i in range(self.cfg.max_output_len):
                # mask
                tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
                tgt_key_padding_mask = ~sequence_mask(len_tgt)
                # input embedding (same fixed-vocab / variable split as training)
                tgt_novar_id = torch.clamp(tgt, max=self.var_start-1)  # rem_size x S
                novar_embedding = self.embedding_tgt(tgt_novar_id)  # rem_size x S x H
                tgt_var_id = torch.clamp(tgt-self.var_start, min=0)  # rem_size x S
                var_embeddings = embedding_var[:rem_size].gather(dim=1, index=tgt_var_id.unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size))  # rem_size x S x H
                choose_mask = (tgt < self.var_start).unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)  # rem_size x S x H
                emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings)  # rem_size x S x H
                # position decoding
                emb_tgt = self.position_dec(emb_tgt)
                output = self.decoder(  # rem_size x S x H
                    emb_tgt.permute(1, 0, 2),
                    memory_item[:rem_size].permute(1, 0, 2),
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask[:rem_size],
                ).permute(1, 0, 2)
                candi_score = self.score(  # rem_size x S x (no_var_size + var_size)
                    output,
                    embedding_weight_all[:rem_size], \
                    candi_mask[:rem_size]
                )

                if i == 0:
                    # Only keep the first beam at step 0 (all beams are identical).
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1) + current_score  # rem_size x (no_var_size + var_size)

                topv, topi = new_score.view(-1).topk(rem_size)
                exp_list = []
                # NOTE(review): score_list holds scores for ALL top-k entries,
                # including ones that just emitted EOS; the [:rem_size] slice
                # below can mis-align scores with exp_list when an EOS entry is
                # not last in the top-k (compare DecoderRNN, which appends
                # scores only for continuing beams). Confirm before relying on
                # the per-beam scores here.
                score_list = topv.tolist()

                for tv, ti in zip(topv, topi):
                    # Decompose flattened index into (beam row, token id).
                    idex = ti.item()
                    x = idex // candi_score.size(-1)
                    y = idex % candi_score.size(-1)
                    if y != self.eos_id:
                        exp_list.append(current_exp_list[x]+[y])
                    else:
                        # Strip the leading SOS token from finished hypotheses.
                        candi_exp_output.append(current_exp_list[x][1:])
                        candi_score_output.append(float(tv))

                if len(exp_list) == 0:
                    break

                tgt = torch.LongTensor(exp_list).cuda()  # rem_size x S
                len_tgt = torch.LongTensor([len(item) for item in exp_list]).cuda()  # rem_size
                current_exp_list = exp_list
                rem_size = len(exp_list)
                current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).cuda()  # rem_size x 1

            if len(candi_exp_output) > 0:
                # Sort finished hypotheses by score, best first.
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output))
            else:
                exp_outputs.append([])

        return exp_outputs
+ return exp_outputs
model/decoder/tree_decoder.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils import *
4
+ from model.module import *
5
+ from torch.nn import functional as F
6
+ import copy
7
+
8
class TreeNode:
    """Decoding-stack entry: a node's hidden embedding plus a flag telling
    whether it was spawned as a left child."""

    def __init__(self, embedding, left_flag=False):
        self.embedding = embedding
        self.left_flag = left_flag
12
+
13
class TreeEmbedding:
    """A (sub)tree's embedding plus a flag marking whether it is a terminal
    (leaf) rather than an operator subtree."""

    def __init__(self, embedding, terminal=False):
        self.embedding = embedding
        self.terminal = terminal
17
+
18
class TreeBeam:
    """One hypothesis in tree-structured beam search: its accumulated score,
    the per-sample node stacks, embedding stacks, pending left-subtree
    embeddings, and the token ids emitted so far."""

    def __init__(self, score, node_stacks, embeddings_stacks, left_child_trees, out):
        self.score = score
        self.embeddings_stacks = embeddings_stacks
        self.node_stacks = node_stacks
        self.left_child_trees = left_child_trees
        self.out = out
25
+
26
class Prediction(nn.Module):
    # a seq2tree decoder with Problem aware dynamic encoding
    def __init__(self, cfg, op_const_size):
        super(Prediction, self).__init__()
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # for Computational symbols and Generated numbers
        # *_l / *_lg: goal vector for a node with NO finished left subtree;
        # *_r / *_rg: goal vector conditioned on the finished left subtree.
        self.concat_l = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_r = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        self.concat_lg = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_rg = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        # attention module
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant
        # NOTE(review): .cuda() at construction hard-codes GPU execution.
        self.op_const_id = torch.arange(op_const_size).unsqueeze(0).cuda()
        # Stand-in hidden state for samples whose node stack is exhausted.
        self.padding_hidden = torch.zeros(1, cfg.decoder_hidden_size).cuda()

    def forward(self, node_stacks, left_child_trees, encoder_outputs, var_pades, source_mask, candi_mask, embedding_op_const):
        '''
        Augments:
            node_stacks: [[TreeNode(_)]]*B, store the variable h
            left_child_trees: [t]*B, store the representation of left tree
            encoder_outputs: [B, S1, H]
            var_pades: [B, S2, H], all_vars_encoder_outputs
            self.padding_hidden: [1, H] (attribute, used for empty node stacks)
            source_mask: [B, S1], mask for source seq
            candi_mask: [B, op_size+const_size+var_size], mask for target seq
        Returns:
            num_score: [B x (op_size+const_size+var_size)]
            current_embeddings: q [B x 1 x H], the target vector of the current node
            current_context: c [B x 1 x H], the context vector of the current node, is calculated using the target vector and encoder_outputs
            current_all_embeddings: [B x (op_size+const_size+var_size) x H] e (M_op, M_con, h_loc^p)
        '''
        current_embeddings = []

        # Top of each sample's node stack (or the padding hidden if empty).
        for node_list in node_stacks:
            if len(node_list) == 0:
                current_embeddings.append(self.padding_hidden)
            else:
                current_node = node_list[-1]
                current_embeddings.append(current_node.embedding)

        current_node_temp = []  # B x (1 x H)

        # Gated goal-vector update: left-branch form when no left subtree has
        # been completed yet, right-branch form (conditioned on it) otherwise.
        for l, c in zip(left_child_trees, current_embeddings):
            if l is None:
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_l(cd))
                t = torch.sigmoid(self.concat_lg(cd))
                current_node_temp.append(g*t)
            else:
                ld = self.em_dropout(l)
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_r(torch.cat((ld, cd), 1)))
                t = torch.sigmoid(self.concat_rg(torch.cat((ld, cd), 1)))
                current_node_temp.append(g*t)

        current_node = torch.stack(current_node_temp, dim=0)  # B x 1 x H (q)
        current_embeddings = self.em_dropout(current_node)
        current_attn = self.attn(current_embeddings, encoder_outputs, source_mask)  # B x S
        current_context = current_attn.unsqueeze(1).bmm(encoder_outputs)  # B x 1 x H (c)
        leaf_input = torch.cat((current_node, current_context), 2)  # B x 1 x 2H

        # Candidate embeddings: learned op/const table ++ per-problem variables.
        embedding_weight_op_const = embedding_op_const(self.op_const_id.repeat(var_pades.size(0), 1))  # B x var_size x H
        embedding_weight_all = torch.cat((embedding_weight_op_const, var_pades), dim=1)  # B x (op_size+const_size+var_size) x H

        leaf_input = self.em_dropout(leaf_input)
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)
        num_score = self.score(leaf_input, embedding_weight_all_, candi_mask)  # B x (op_size+const_size+var_size)

        # Note: returns the pre-dropout candidate embeddings (embedding_weight_all).
        return num_score, current_node, current_context, embedding_weight_all
98
+
99
class GenerateNode(nn.Module):
    """Produce the left/right child hidden states of an operator node.

    Implements the front part of eq (10)/(11): the parent goal vector, its
    attention context and the operator token embedding are fused, then pushed
    through tanh-transform / sigmoid-gate pairs, one pair per child.
    """

    def __init__(self, cfg, op_size):
        super(GenerateNode, self).__init__()

        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.op_size = op_size

        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        fused_dim = self.hidden_size * 2 + self.embedding_size
        self.generate_l = nn.Linear(fused_dim, self.hidden_size)
        self.generate_r = nn.Linear(fused_dim, self.hidden_size)
        self.generate_lg = nn.Linear(fused_dim, self.hidden_size)
        self.generate_rg = nn.Linear(fused_dim, self.hidden_size)

    def forward(self, current_embedding, node_label, current_context, embedding_op_const):
        """
        Arguments:
            current_embedding: [B x 1 x H (q)], the target vector of the current node
            node_label: [B (id)], predicted token ids (clamped into the operator range)
            current_context: [B x 1 x H (c)], context vector of current node
            embedding_op_const: Embedding of op_const
        Returns:
            left_child: [B x H (h)]
            right_child: [B x H (h)]
            token_embedding: [B x H (e(y|P) of op)], the pre-dropout operator embedding
        """
        # Only operator ids are meaningful here; clamp everything else into range.
        op_label = torch.clamp(node_label, max=self.op_size - 1)
        token_embedding = embedding_op_const(op_label)

        # Dropout in the same order as the fused concatenation below.
        goal = self.em_dropout(current_embedding.squeeze(1))
        context = self.em_dropout(current_context.squeeze(1))
        token = self.em_dropout(token_embedding)
        fused = torch.cat((goal, context, token), 1)

        left_child = torch.tanh(self.generate_l(fused)) * torch.sigmoid(self.generate_lg(fused))
        right_child = torch.tanh(self.generate_r(fused)) * torch.sigmoid(self.generate_rg(fused))

        return left_child, right_child, token_embedding
140
+
141
class Merge(nn.Module):
    """Fuse an operator embedding with its two subtree embeddings into a
    single subtree embedding (one recursive-neural-network step with a
    sigmoid gate).
    """

    def __init__(self, cfg):
        super(Merge, self).__init__()

        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        fused_dim = self.hidden_size * 2 + self.embedding_size
        self.merge = nn.Linear(fused_dim, self.hidden_size)
        self.merge_g = nn.Linear(fused_dim, self.hidden_size)

    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        """
        Arguments:
            node_embedding: 1 x H, the operator token embedding
            sub_tree_1: 1 x H, left subtree embedding
            sub_tree_2: 1 x H, right subtree embedding
        Return:
            sub_tree: 1 x H, tanh(W·[e;t1;t2]) gated by sigmoid(Wg·[e;t1;t2])
        """
        # Dropout order matches the original implementation: left subtree,
        # right subtree, then the operator embedding.
        left = self.em_dropout(sub_tree_1)
        right = self.em_dropout(sub_tree_2)
        parent = self.em_dropout(node_embedding)

        fused = torch.cat((parent, left, right), 1)
        return torch.tanh(self.merge(fused)) * torch.sigmoid(self.merge_g(fused))
172
+
173
class TreeDecoder(nn.Module):
    """
    Goal-driven tree-structured expression decoder.

    Tokens are generated in pre-order over a binary expression tree: an
    operator token pushes right/left child goal states onto a stack, while an
    operand token (constant or variable) closes the current leaf and merges
    finished subtrees bottom-up via `Merge`.  Token ids are laid out as
    [operators | constants | variables]; variable embeddings are gathered from
    the encoder outputs at `var_positions`.
    """
    def __init__(self, cfg, tgt_lang):
        super(TreeDecoder, self).__init__()
        # embedding for op, const, num
        self.var_start = tgt_lang.var_start    # id where variable tokens begin (= op_num + const_num presumably — confirm)
        self.op_num = tgt_lang.op_num          # operator vocabulary size
        self.const_num = tgt_lang.const_num    # constant vocabulary size

        self.embedding_op_const = nn.Embedding(self.op_num+self.const_num, cfg.decoder_embedding_size)
        self.embedding_var = None # obtain from encoder
        self.cfg = cfg
        # modules of TreeDecoder
        self.predict = Prediction(cfg, self.op_num+self.const_num)  # scores the next token
        self.generate = GenerateNode(cfg, self.op_num)              # child goal states for operators
        self.merge = Merge(cfg)                                     # bottom-up subtree fusion

    def get_var_encoder_outputs(self, encoder_outputs, var_positions):
        """
        Gather the encoder hidden states at each variable-token position.

        Arguments:
            encoder_outputs: B x S1 x H
            var_positions: B x S2
        Returns:
            var_embeddings: B x S2 x H
        """
        hidden_size = encoder_outputs.size(-1)
        # expand the position indices so gather can index the full hidden dim
        expand_var_positions = var_positions.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_positions)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_source, var_positions, len_var, \
            is_train=False, text_target=None, len_target=None):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: B x H, root goal vector per sample
            len_source: B, true source lengths (builds the attention mask)
            text_target: B x S2, gold pre-order tokens (training only)
            len_target: B (unused here)
            var_positions: B x S3
            len_var: B, number of variables per sample
        Return:
            training: output B x S x (op_size+const_size+var_size), logits of one batch
            testing: [expr] x B
        """
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_positions) # B x S2 x H
        self.source_mask = sequence_mask(len_source)
        # candidate mask: ops+consts (first var_start slots) always valid, plus this sample's variables
        self.candi_mask = sequence_mask(len_var+self.var_start)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_target)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_target):
        """
        Teacher-forced decoding over the gold pre-order token sequence.

        Arguments:
            embeddings_stacks: [[TreeEmbedding(t, terminal)]]*B, a stack of subtrees t in the first order traversal
            left_child_trees: [t]*B, the representation of left tree of current node
            node_stacks: [[TreeNode(h, left_flag)]]*B, a stack of hidden state h in the first order traversal
        Returns:
            all_node_outputs: B x S x (op_size+const_size+var_size), logits of one batch
        """
        node_stacks = [[TreeNode(init_hidden)] for init_hidden in problem_output.split(1, dim=0)]
        embeddings_stacks = [[] for _ in range(encoder_outputs.size(0))]
        left_child_trees = [None]*encoder_outputs.size(0)
        all_node_outputs = []

        for t in range(text_target.size(1)):
            num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                node_stacks,
                left_child_trees,
                encoder_outputs,
                self.embedding_var,
                self.source_mask,
                self.candi_mask,
                self.embedding_op_const)

            all_node_outputs.append(num_score) # [B x (op_size+const_size+var_size)] * S

            # child goal states are generated from the gold token (teacher forcing)
            left_child, right_child, token_embedding = self.generate(
                current_embeddings,
                text_target[:,t],
                current_context,
                self.embedding_op_const)

            left_child_trees = []

            for idx, (l, r, node_stack, target_id, embeddings_stack) in enumerate(zip(left_child.split(1), right_child.split(1),
                                            node_stacks, text_target[:,t].tolist(), embeddings_stacks)):
                # Determines whether the tree traversal is complete
                if len(node_stack) != 0:
                    node_stack.pop()
                else:
                    left_child_trees.append(None)
                    continue
                if target_id < self.op_num:
                    # operator: push right first so the left child is expanded next (pre-order)
                    node_stack.append(TreeNode(r))
                    node_stack.append(TreeNode(l, left_flag=True))
                    # embeddings_stack, put e(y|P) of op in temporarily
                    embeddings_stack.append(TreeEmbedding(token_embedding[idx].unsqueeze(0), False))
                else:
                    current_num = current_all_embeddings[idx, target_id].unsqueeze(0) # 1 x H
                    # Reach the right leaf node and merge the tree representation from bottom up
                    while len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                        sub_stree = embeddings_stack.pop()
                        op = embeddings_stack.pop()
                        # embedding vector of two sub-targets is merged as the subtree embedding of nodes, corresponding to eq(12)
                        # with e(y|P), sub_tree_1 and sub_tree_2
                        current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                    embeddings_stack.append(TreeEmbedding(current_num, True))
                # Reach the left leaf node and save the representation of the left subtree for generation of q
                if len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                    left_child_trees.append(embeddings_stack[-1].embedding)
                else:
                    left_child_trees.append(None)

        all_node_outputs = torch.stack(all_node_outputs, dim=1)

        return all_node_outputs

    def _forward_test(self, encoder_outputs, problem_output):
        """
        Beam-search decoding, one sample at a time (effective batch size 1).

        Returns:
            exp_outputs: [token-id list] x B, best beam per sample
        """
        exp_outputs = []

        for sample_id in range(encoder_outputs.size(0)):
            # set batch size as 1
            node_stacks = [[TreeNode(problem_output[sample_id:sample_id+1])]]
            embeddings_stacks = [[]]
            left_child_trees = [None]
            beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_child_trees, [])]

            for _ in range(self.cfg.max_output_len):
                # re-maintain of one beams
                current_beams = []

                while len(beams) > 0:
                    beam_item = beams.pop()
                    # The candidates are stored in beams in all process
                    if len(beam_item.node_stacks[0]) == 0:
                        # finished beam: carry it over unchanged
                        current_beams.append(beam_item)
                        continue
                    num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                        beam_item.node_stacks,
                        beam_item.left_child_trees,
                        encoder_outputs[sample_id:sample_id+1],
                        self.embedding_var[sample_id:sample_id+1],
                        self.source_mask[sample_id:sample_id+1],
                        self.candi_mask[sample_id:sample_id+1],
                        self.embedding_op_const)

                    out_score = F.log_softmax(num_score, dim=1)
                    topv, topi = out_score.topk(self.cfg.beam_size)

                    for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):

                        # clone the beam state so each expansion is independent
                        current_node_stack = copy_list(beam_item.node_stacks)
                        current_left_child_trees = []
                        current_embeddings_stacks = copy_list(beam_item.embeddings_stacks)
                        current_out = copy.deepcopy(beam_item.out)

                        out_token = int(ti)
                        current_out.append(out_token)
                        current_node_stack[0].pop()

                        if out_token < self.op_num:
                            # NOTE(review): hard-coded .cuda() assumes GPU inference — confirm
                            generate_input = torch.LongTensor([out_token]).cuda()
                            left_child, right_child, token_embedding = self.generate(
                                current_embeddings,
                                generate_input,
                                current_context,
                                self.embedding_op_const)
                            current_node_stack[0].append(TreeNode(right_child))
                            current_node_stack[0].append(TreeNode(left_child, left_flag=True))
                            current_embeddings_stacks[0].append(TreeEmbedding(token_embedding, False))
                        else:
                            current_num = current_all_embeddings[:, out_token]
                            # merge completed subtrees bottom-up, mirroring _forward_train
                            while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                                sub_stree = current_embeddings_stacks[0].pop()
                                op = current_embeddings_stacks[0].pop()
                                current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                            current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
                        if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                            current_left_child_trees.append(current_embeddings_stacks[0][-1].embedding)
                        else:
                            current_left_child_trees.append(None)

                        current_beams.append(TreeBeam(beam_item.score+float(tv), current_node_stack, current_embeddings_stacks,
                                                current_left_child_trees, current_out))

                # keep only the top beam_size partial trees by accumulated log-prob
                beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
                beams = beams[:self.cfg.beam_size]

                # early termination
                flag = True
                for beam_item in beams:
                    if len(beam_item.node_stacks[0]) != 0:
                        flag = False
                        break
                if flag: break

            exp_outputs.append(beams[0].out)

        return exp_outputs
model/encoder/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .lstm import LSTM
2
+ from .gru import GRU
3
+ from config import encoder_list
4
+ from .transformer import TransformerEncoder
5
+
6
def get_encoder(params, *args):
    """
    Instantiate the encoder named by `params.encoder_type`.

    Arguments:
        params: config object with an `encoder_type` attribute
        *args: forwarded to the encoder constructor
    Returns:
        the constructed encoder module (LSTM or GRU)
    Raises:
        NotImplementedError: for types outside `encoder_list`, or for
        "transformer", which has no construction path here.
    """
    if params.encoder_type not in encoder_list:
        raise NotImplementedError(
            "Unsupported Classifier: {}".format(params.encoder_type))

    if params.encoder_type == "lstm":
        return LSTM(params, *args)
    if params.encoder_type == "gru":
        return GRU(params, *args)
    # Fix: the original fell through a bare `pass` for "transformer" and then
    # returned the unbound local `encoder` (UnboundLocalError); fail loudly
    # with a clear message instead.
    raise NotImplementedError("Unsupported Encoder: {}".format(params.encoder_type))
model/encoder/gru.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class GRU(nn.Module):
    """Bidirectional GRU encoder; forward and backward directions are summed
    into a single H-dimensional output (likewise for the final hidden state)."""

    def __init__(self, cfg):
        super(GRU, self).__init__()

        self.is_bidirectional = True
        self.batch_first = True
        self.gru = nn.GRU(
            input_size=cfg.encoder_embedding_size,
            hidden_size=cfg.encoder_hidden_size,
            num_layers=cfg.encoder_layers,
            bidirectional=self.is_bidirectional,
            dropout=cfg.dropout_rate,
            batch_first=self.batch_first,
        )
        self.hidden_size = cfg.encoder_hidden_size
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def forward(self, src_emb, input_lengths, hidden=None):
        """
        Arguments:
            src_emb: B x S x E token embeddings
            input_lengths: B true lengths (moved to CPU for packing)
            hidden: optional initial hidden state
        Returns:
            outputs: B x S x H (direction-summed, zero at padded steps)
            state: layers x B x H (direction-summed)
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            self.dropout(src_emb), input_lengths.cpu(),
            batch_first=self.batch_first, enforce_sorted=False)
        outputs, state = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first)
        # outputs: B x S x (H * num_directions); state: (layers*dirs) x B x H
        if self.is_bidirectional:
            forward_part = outputs[..., :self.hidden_size]
            backward_part = outputs[..., self.hidden_size:]
            outputs = forward_part + backward_part
            state = state[0::2, :, :] + state[1::2, :, :]

        return outputs, state
38
+
39
+
40
+
41
+
model/encoder/lstm.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class LSTM(nn.Module):
    """Thin wrapper around nn.LSTM configured from upper-case cfg fields."""

    def __init__(self, cfg):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=cfg.WORD_EMBED_SIZE,
            hidden_size=cfg.HIDDEN_SIZE,
            num_layers=cfg.NUM_LAYERS,
            batch_first=cfg.BATCH_FIRST,  # first dim is batch_size or not
            bidirectional=cfg.BIDIRECTIONAL,
        )

    def forward(self, input, h0, c0):
        """Run the LSTM from initial state (h0, c0); returns (output, hn, cn)."""
        output, state = self.lstm(input, (h0, c0))
        return output, state[0], state[1]
20
+
21
+
22
+
23
+
model/encoder/transformer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from utils.utils import sequence_mask
4
+ import math
5
+
6
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (sin on even dims, cos on odd dims),
    added to the input and followed by dropout. The table is a buffer, not a
    parameter."""

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len).unsqueeze(1)
        # geometric frequency ladder: 10000^(-2i/d_model)
        freqs = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """
        x: [B, S, d_model]; the first S rows of the table are broadcast-added.
        """
        return self.dropout(x + self.pe[:, : x.size(1)].requires_grad_(False))
26
+
27
class LearnedPositionEncoding(nn.Module):
    """
    Learned positional embedding added only at variable-token positions.

    Positions listed in `var_pos` receive embedding slots 1..var_len in order;
    entries equal to the per-batch minimum of `var_pos` are treated as padding
    and mapped to slot 0, as is every non-variable position.
    """

    def __init__(self, d_model, max_len = 20):
        super(LearnedPositionEncoding, self).__init__()
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x, var_pos):
        """
        Arguments:
            x: [B, S, d_model] token features
            var_pos: [B, var_len] int64 indices of variable tokens
        Returns:
            x plus the learned position embedding, same shape as x
        """
        # Fix: allocate on x's device instead of hard-coded .cuda(), so the
        # module also runs on CPU; behavior on GPU inputs is unchanged.
        device = x.device
        loc_mat = torch.zeros(x.size(0), x.size(1), dtype=torch.int64, device=device)
        pos_id = torch.arange(1, var_pos.size(1)+1, device=device).repeat(var_pos.size(0), 1)
        # entries equal to the minimum are padding -> slot 0
        pos_id[var_pos==var_pos.min()] = 0
        loc_mat.scatter_(1, var_pos, pos_id)

        x = x + self.embedding(loc_mat)

        return x
46
+
47
class TransformerEncoder(nn.Module):
    """Stacked nn.TransformerEncoder with sinusoidal position encoding and a
    padding mask derived from true sequence lengths."""

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.2):
        super(TransformerEncoder, self).__init__()

        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.encoder = nn.TransformerEncoder(layer, num_encoder_layers, nn.LayerNorm(d_model))
        self.position = PositionalEncoding(d_model=d_model)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """
        Xavier-initialise every parameter tensor with more than one dimension.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, len_src, emb_src):
        """
        Arguments:
            len_src: [B] true lengths, used to build the key-padding mask
            emb_src: [B, S, d_model] input embeddings
        Returns:
            [B, S, d_model] encoded sequence
        """
        pad_mask = ~sequence_mask(len_src)   # True at padded positions
        positioned = self.position(emb_src)
        # nn.TransformerEncoder expects sequence-first layout
        encoded = self.encoder(positioned.permute(1, 0, 2), src_key_padding_mask=pad_mask)
        return encoded.permute(1, 0, 2)
model/module/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .module_ops import *
2
+ from .attention import *
model/module/attention.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class Score(nn.Module):
    """Additive (Bahdanau-style) scorer of one query against a set of
    candidate embeddings; masked slots are filled with -1e12."""
    def __init__(self, input_size, hidden_size):
        super(Score, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, candi_embeddings, candi_mask=None):
        '''
        Arguments:
            hidden: B x 1 x (query feature dim)
            candi_embeddings: B x candi_size x H
            candi_mask: B x candi_size (True = valid) or None
        Return:
            score: B x candi_size
        '''
        query = hidden.repeat(1, candi_embeddings.size(1), 1)   # B x candi_size x *
        features = torch.cat((query, candi_embeddings), 2)
        logits = self.score(torch.tanh(self.attn(features))).squeeze(-1)
        if candi_mask is not None:
            logits = logits.masked_fill_(~candi_mask, -1e12)
        return logits
26
+
27
class Attn(nn.Module):
    """Additive attention yielding a softmax distribution over encoder steps;
    masked steps receive near-zero weight."""
    def __init__(self, input_size, hidden_size):
        super(Attn, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs, seq_mask=None):
        '''
        Arguments:
            hidden: B x 1 x H (query q)
            encoder_outputs: B x S x H
            seq_mask: B x S (True = valid) or None
        Return:
            attn_energies: B x S, rows sum to 1
        '''
        query = hidden.repeat(1, encoder_outputs.size(1), 1)       # B x S x H
        combined = torch.cat((query, encoder_outputs), 2)          # B x S x 2H
        energies = self.score(torch.tanh(self.attn(combined))).squeeze(-1)
        if seq_mask is not None:
            energies = energies.masked_fill_(~seq_mask, -1e12)
        return nn.functional.softmax(energies, dim=1)
51
+
52
class Score_Multi(nn.Module):
    """Additive scorer evaluating every (query step, candidate) pair at once."""
    def __init__(self, input_size, hidden_size):
        super(Score_Multi, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, candi_embeddings, candi_mask=None):
        '''
        Arguments:
            hidden: B x S x H
            candi_embeddings: B x candi_size x H
            candi_mask: B x candi_size (True = valid) or None
        Return:
            score: B x S x candi_size (masked candidates filled with -1e12)
        '''
        hidden = hidden.unsqueeze(2).repeat(1, 1, candi_embeddings.size(1), 1) # B x S x candi_size x H
        candi_embeddings = candi_embeddings.unsqueeze(1).repeat(1, hidden.size(1), 1, 1) # B x S x candi_size x H
        energy_in = torch.cat((hidden, candi_embeddings), -1) # B x S x candi_size x 2H
        score = self.score(torch.tanh(self.attn(energy_in))).squeeze(-1) # B x S x candi_size
        # Fix: only expand and apply the mask when one is given; the original
        # unconditionally called candi_mask.unsqueeze(...) before the None
        # check and crashed with AttributeError when candi_mask was None.
        if candi_mask is not None:
            candi_mask = candi_mask.unsqueeze(1).repeat(1, hidden.size(1), 1) # B x S x candi_size
            score = score.masked_fill_(~candi_mask, -1e12)
        return score
model/module/module_ops.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
class GAP(nn.Module):
    """
    Global average pooling over the spatial dims of a 4-D feature map,
    as used in ResNet/Inception/DenseNet heads. Output keeps a 1x1 map.
    """
    def __init__(self):
        super(GAP, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # Reduce H x W to 1 x 1; batch and channel dims are preserved.
        return self.avgpool(x)
17
+
18
+
19
class Identity(nn.Module):
    """No-op module: returns its input unchanged (drop-in placeholder)."""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
25
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch==1.7.1
2
+ torchvision==0.8.2
3
+ gradio==4.16.0
4
+ Pillow>=9.0.0
5
+ numpy>=1.19.0
6
+ antlr4-python3-runtime==4.10
7
+ sympy==1.11.1
8
+ func_timeout==4.3.5
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .lr_scheduler import *
2
+ from .utils import *
3
+
4
+
utils/lr_scheduler.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from bisect import bisect_right
3
+
4
+
5
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """
    MultiStepLR with a warmup phase: for the first `warmup_epochs` epochs the
    base LR is scaled by a warmup factor ("constant", or "linear" ramping from
    `warmup_factor` up to 1), after which the usual milestone decay by `gamma`
    applies.
    """
    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_epochs=5,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Fix: format the message properly; the original passed the list as
            # a second ValueError argument so "{}" was never substituted.
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            # Fix: the original concatenated two fragments without a separator
            # ("accepted" + "got"), producing a run-on message.
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the per-group LR for the current epoch: base_lr scaled by the
        warmup factor (if still warming up) and by gamma^(milestones passed)."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_epochs:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = float(self.last_epoch) / self.warmup_epochs
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
utils/utils.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from utils.lr_scheduler import WarmupMultiStepLR
4
+ from config import *
5
+ import datetime
6
+ import torch.distributed as dist
7
+ from datasets.operators import result_compute, normalize_exp
8
+ from func_timeout import func_timeout
9
+ import random
10
+ import gc
11
+
12
def save_checkpoint(state, is_best, dump_path=None):
    """Serialize `state` under `dump_path`: 'best_model.pth' when `is_best`,
    otherwise '<epoch>.pth' derived from state['epoch']."""
    filename = 'best_model.pth' if is_best else '{}.pth'.format(state['epoch'])
    torch.save(state, os.path.join(dump_path, filename))
19
+
20
class AverageMeter(object):
    """
    Track the latest value and running average of a metric.
    `fmt` is a format spec (e.g. ':.4f') used by __str__.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
44
+
45
class ProgressMeter(object):
    """Format and log one progress line per display call: a padded batch
    counter, each meter's string form, and an optional learning rate."""
    def __init__(self, num_batches, meters, args, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.args = args

    def display(self, batch, lr=None):
        """Log a tab-separated progress line via args.logger."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        if lr is not None:
            parts.append("lr: " + format(lr, '.6f'))
        self.args.logger.info('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Counter width equals the number of digits in num_batches.
        width = len(str(num_batches // 1))
        fmt = '{:' + str(width) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
63
+
64
def adjust_learning_rate(optimizer, epoch, args):
    """
    Step decay: the initial args.lr is divided by 10 every 30 epochs and
    written into every parameter group.
    """
    decayed = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed
71
+
72
def accuracy(output, target, topk=(1, )):
    """
    Top-k accuracy (in percent) of `output` logits against integer `target`
    labels; returns one 1-element tensor per k in `topk`.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()                                         # maxk x B
        hits = pred.eq(target.view(1, -1).expand_as(pred))      # boolean hit matrix

        return [
            hits[:k].contiguous().view(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]
89
+
90
def get_scheduler(args, optimizer):
    """
    Build the LR scheduler selected by args.scheduler_type:
    'multistep' | 'cosine' | 'warmup'. Raises NotImplementedError otherwise.
    """
    kind = args.scheduler_type
    if kind == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
        )
    if kind == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.max_epochs, eta_min=1e-6)
    if kind == "warmup":
        return WarmupMultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
            warmup_epochs=args.warm_epoch,
        )
    raise NotImplementedError("Unsupported LR Scheduler: {}".format(kind))
110
+
111
def get_optimizer(args, model):
    """
    Build the optimizer selected by args.optimizer_type ('SGD'|'ADAM'|'ADAMW').

    With args.use_MLM_pretrain, the Adam/AdamW variants use two parameter
    groups so the pretrained-LM parameters and the remaining parameters can
    carry different learning rates.
    """
    if args.use_MLM_pretrain:
        lm_param_ids = list(map(id, model.mlm_pretrain.parameters()))
        other_params = filter(lambda p: id(p) not in lm_param_ids, model.parameters())

    opt_type = args.optimizer_type
    if opt_type == "SGD":
        return torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True,
        )
    if opt_type == "ADAM":
        if args.use_MLM_pretrain:
            # NOTE(review): here the *non*-LM group receives args.lr_LM while in
            # the ADAMW branch the LM group does — looks inconsistent; confirm
            # which assignment is intended before relying on this branch.
            return torch.optim.Adam(
                [{"params": model.mlm_pretrain.parameters()},
                 {"params": other_params, "lr": args.lr_LM}],
                lr=args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay,
            )
        return torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
            weight_decay=args.weight_decay,
        )
    if opt_type == "ADAMW":
        if args.use_MLM_pretrain:
            return torch.optim.AdamW(
                [{"params": model.mlm_pretrain.parameters(), "lr": args.lr_LM},
                 {"params": other_params}],
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
        return torch.optim.AdamW(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    raise NotImplementedError("Unsupported Optimizer Type : {}".format(opt_type))
158
+
159
def reduce_mean(tensor, nprocs):
    """All-reduce (sum) `tensor` across processes and return the per-process
    mean; requires an initialized torch.distributed process group. The input
    tensor is left untouched (a clone is reduced)."""
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged /= nprocs
    return averaged
164
+
165
def set_cuda(data_dict):
    """Move every tensor value of `data_dict` onto the GPU in place;
    non-tensor values are left untouched."""
    for key, value in data_dict.items():
        if torch.is_tensor(value):
            data_dict[key] = value.cuda()
169
+
170
def initialize_logger(params, ):
    """
    Initialize the experiment:
      - pick an unused timestamped sub-directory of params.dump_path
        (params.dump_path is updated in place to point at it)
      - create it on rank 0 only
      - build a logger writing to record.log and dump all params into it
    """
    base = params.dump_path
    while True:
        stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        candidate = os.path.join(base, stamp)
        if not os.path.exists(candidate):
            break
    params.dump_path = candidate
    if params.local_rank == 0:
        os.makedirs(params.dump_path)
    # create a logger
    logger = create_logger(os.path.join(params.dump_path, 'record.log'), params.local_rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n" + "\n".join("\t\t\t\t%s: %s" % (k, str(v))
                                 for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment results will be stored in %s" % params.dump_path)
    return logger
190
+
191
def aeq(*args):
    """
    Assert that every argument equals the first one.
    """
    remaining = iter(args)
    first = next(remaining)
    assert all(value == first for value in remaining), \
        "Not all arguments have the same value: " + str(args)
199
+
200
def sequence_mask(lengths, max_len=None):
    """
    Boolean mask of shape [B, max_len]: True where the position index is
    strictly below the sequence length. `max_len` defaults to lengths.max().
    """
    limit = max_len or lengths.max()
    positions = torch.arange(0, limit, device=lengths.device).type_as(lengths)
    grid = positions.unsqueeze(0).expand(lengths.numel(), -1)
    return grid.lt(lengths.unsqueeze(1))
210
+
211
def copy_list(l):
    """Deep-copy the list *structure*: nested lists are recreated recursively,
    while non-list leaf items are shared with the original."""
    return [copy_list(item) if type(item) is list else item for item in l]
221
+
222
def compute_exp_result_choice(test_preds, var_dict, exp_dict, tgt_lang):

    """
    Score beam predictions against a multiple-choice answer set.

    A beam candidate is accepted once it either matches the gold expression
    exactly or evaluates (within tolerance) to one of the listed choices; if
    no candidate hits any choice, a random choice is taken as the guess.

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer', 'choices'}
        tgt_lang: vocab of target text
    Returns:
        ans_acc: fraction of samples whose chosen result matches the answer
        eq_acc: fraction of samples whose expression matches the gold one
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)): # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
        # map variable ids N<i> (numbered after the problem variables) to their argument values
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
            for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        choices = exp_dict['choices'][k]
        is_find_ans = False

        for j in range(len(test_preds[k])): # pred candi id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # expression evaluation is bounded by a 2-second timeout to
                # guard against pathological candidates
                pred_result = float(func_timeout(2.0, result_compute, \
                                kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    ans_num += 1
                    eq_num += 1
                    is_find_ans = True
                    break
                # accept when the value lands near any of the offered choices
                for item in choices:
                    if abs(pred_result-item)<5e-2:
                        is_find_ans = True
                if is_find_ans and abs(pred_result-tgt_result)<5e-3:
                    ans_num +=1
                    if len(pred)==len(tgt):
                        eq_num += 1
                if is_find_ans: break
            except:
                # NOTE(review): broad except deliberately skips malformed or
                # timed-out candidates; all errors are silently ignored
                pass

        if not is_find_ans:
            # no candidate matched any choice: fall back to a random guess
            pred_result = random.choice(choices)
            if abs(pred_result-tgt_result)<5e-2:
                ans_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
275
+
276
def compute_exp_result_topk(test_preds, var_dict, exp_dict, tgt_lang, k_num = 3):

    """
    Top-k evaluation: a sample counts as solved if any of its first `k_num`
    beam candidates matches the gold expression or evaluates to the answer.

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
        k_num: number of leading beam candidates to consider
    Returns:
        ans_acc: fraction of samples with a value-correct candidate in top-k
        eq_acc: fraction of samples with an expression match (exact, or same
            length with the right value) in top-k
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)): # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
        # map variable ids N<i> (numbered after the problem variables) to their argument values
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
            for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False

        for j in range(k_num): # top-n
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # expression evaluation is bounded by a 2-second timeout
                pred_result = float(func_timeout(2.0, result_compute, \
                                kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result-tgt_result)<5e-3:
                    is_ans_same = True
                    if len(pred)==len(tgt):
                        is_eq_same = True
                    break
            except:
                # NOTE(review): broad except deliberately skips malformed or
                # timed-out candidates; all errors are silently ignored
                pass

        if is_ans_same: ans_num +=1
        if is_eq_same: eq_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
322
+
323
+
324
def compute_exp_result_comp(test_preds, var_dict, exp_dict, tgt_lang):

    """
    Completion-style evaluation over every beam candidate: a sample counts as
    solved once any candidate matches the gold expression or evaluates to the
    numeric answer (no multiple-choice fallback).

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
    Returns:
        ans_acc: fraction of samples with a value-correct candidate
        eq_acc: fraction of samples with an expression match (exact, or same
            length with the right value)
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)): # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist() # Remove special symbols [SOS] and [EOS]
        # map variable ids N<i> (numbered after the problem variables) to their argument values
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
            for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False

        for j in range(len(test_preds[k])): # pred candi id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # expression evaluation is bounded by a 2-second timeout
                pred_result = float(func_timeout(2.0, result_compute, \
                                kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result-tgt_result)<5e-3:
                    is_ans_same = True
                    if len(pred)==len(tgt):
                        is_eq_same = True
                    break
            except:
                # NOTE(review): broad except deliberately skips malformed or
                # timed-out candidates; all errors are silently ignored
                pass

        if is_ans_same: ans_num +=1
        if is_eq_same: eq_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
vocab/vocab_src.txt ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [MASK]
5
+ +
6
+ ,
7
+ .
8
+ /
9
+ 1
10
+ 10
11
+ 11
12
+ 12
13
+ 13
14
+ 14
15
+ 15
16
+ 16
17
+ 17
18
+ 18
19
+ 19
20
+ 2
21
+ 20
22
+ 3
23
+ 4
24
+ 5
25
+ 6
26
+ 7
27
+ 8
28
+ 9
29
+ =
30
+ ?
31
+ A
32
+ B
33
+ B'
34
+ B0
35
+ B1
36
+ C
37
+ C0
38
+ C1
39
+ D
40
+ E
41
+ F
42
+ F0
43
+ F1
44
+ G
45
+ G0
46
+ G1
47
+ H
48
+ I
49
+ J
50
+ K
51
+ L
52
+ M
53
+ N
54
+ O
55
+ P
56
+ Q
57
+ Q'
58
+ R
59
+ S
60
+ T
61
+ U
62
+ V
63
+ W
64
+ W'
65
+ X
66
+ Y
67
+ Z
68
+ \angle
69
+ \cong
70
+ \cos
71
+ \odot
72
+ \parallel
73
+ \parallelogram
74
+ \perp
75
+ \phi
76
+ \sim
77
+ \sin
78
+ \tan
79
+ \triangle
80
+ \widehat
81
+ a
82
+ about
83
+ all
84
+ altitude
85
+ altitudes
86
+ an
87
+ and
88
+ angle
89
+ angles
90
+ any
91
+ appear
92
+ appears
93
+ are
94
+ area
95
+ areas
96
+ as
97
+ assume
98
+ at
99
+ b
100
+ base
101
+ bases
102
+ be
103
+ below
104
+ between
105
+ bisector
106
+ bisectors
107
+ bisects
108
+ blue
109
+ both
110
+ by
111
+ c
112
+ calculator
113
+ center
114
+ centers
115
+ centimeters
116
+ central
117
+ centroid
118
+ chord
119
+ chords
120
+ circle
121
+ circles
122
+ circumference
123
+ circumscribed
124
+ circumscribes
125
+ cm
126
+ cm^{2}
127
+ collinear
128
+ common
129
+ complementary
130
+ composite
131
+ congruent
132
+ connecting
133
+ corners
134
+ cosines
135
+ cut
136
+ d
137
+ degree
138
+ determine
139
+ diagonal
140
+ diagonals
141
+ diagram
142
+ diameter
143
+ diameters
144
+ distance
145
+ drawn
146
+ e
147
+ each
148
+ elm
149
+ equal
150
+ equidistant
151
+ equilateral
152
+ exact
153
+ express
154
+ f
155
+ factor
156
+ feet
157
+ figure
158
+ figures
159
+ find
160
+ for
161
+ form
162
+ formed
163
+ four
164
+ from
165
+ ft
166
+ ft^{2}
167
+ g
168
+ given
169
+ green
170
+ h
171
+ half
172
+ has
173
+ have
174
+ having
175
+ height
176
+ hexagon
177
+ how
178
+ hypotenuse
179
+ i
180
+ if
181
+ in
182
+ in^{2}
183
+ incenter
184
+ inches
185
+ inscribed
186
+ inside
187
+ intersect
188
+ intersected
189
+ intersecting
190
+ into
191
+ is
192
+ isosceles
193
+ it
194
+ its
195
+ j
196
+ k
197
+ kite
198
+ l
199
+ law
200
+ legs
201
+ length
202
+ lengths
203
+ let
204
+ lieson
205
+ line
206
+ linear
207
+ lines
208
+ long
209
+ m
210
+ m^{2}
211
+ major
212
+ make
213
+ makes
214
+ measure
215
+ measurement
216
+ measures
217
+ median
218
+ medians
219
+ meet
220
+ meters
221
+ midpoint
222
+ midpoints
223
+ midsegment
224
+ midsegments
225
+ miles
226
+ millimeters
227
+ minor
228
+ mm
229
+ mm^{2}
230
+ must
231
+ n
232
+ o
233
+ octagon
234
+ of
235
+ off
236
+ on
237
+ one
238
+ otherwise
239
+ p
240
+ pair
241
+ parallel
242
+ parallelogram
243
+ pentagon
244
+ pentagons
245
+ perimeter
246
+ perimeters
247
+ perpendicular
248
+ plum
249
+ point
250
+ points
251
+ polygon
252
+ polygons
253
+ proportion
254
+ pythagorean
255
+ q
256
+ quadrilateral
257
+ r
258
+ radius
259
+ ratio
260
+ ray
261
+ rectangle
262
+ red
263
+ refer
264
+ region
265
+ regular
266
+ respectively
267
+ rhombus
268
+ right
269
+ s
270
+ scale
271
+ sector
272
+ segment
273
+ segments
274
+ shaded
275
+ shown
276
+ side
277
+ sides
278
+ similar
279
+ sines
280
+ so
281
+ solve
282
+ special
283
+ square
284
+ stated
285
+ straight
286
+ such
287
+ sum
288
+ suppose
289
+ t
290
+ tangent
291
+ tangents
292
+ that
293
+ the
294
+ theorem
295
+ this
296
+ times
297
+ to
298
+ trapezoid
299
+ triangle
300
+ triangles
301
+ triple
302
+ twice
303
+ two
304
+ u
305
+ units
306
+ unless
307
+ use
308
+ v
309
+ value
310
+ variable
311
+ vertex
312
+ w
313
+ what
314
+ where
315
+ which
316
+ with
317
+ would
318
+ x
319
+ y
320
+ yards
321
+ yd^{2}
322
+ z
vocab/vocab_tgt.txt ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [PAD]
2
+ [SOS]
3
+ [EOS]
4
+ V0
5
+ V1
6
+ V2
7
+ V3
8
+ V4
9
+ V5
10
+ V6
11
+ C0.5
12
+ C2
13
+ C3
14
+ C4
15
+ C5
16
+ C6
17
+ C8
18
+ C60
19
+ C90
20
+ C180
21
+ C360
22
+ ArcSeg_Area
23
+ Chord2_Ang
24
+ Circle_D_Area
25
+ Circle_D_Circum
26
+ Circle_R_Area
27
+ Circle_R_Circum
28
+ Cos_Law
29
+ Equal
30
+ Gcos
31
+ Geo_Mean
32
+ Get
33
+ Gougu
34
+ Gsin
35
+ Gtan
36
+ Iso_Tri_Ang
37
+ Kite_Area
38
+ Median
39
+ Multiple
40
+ Ngon_Angsum
41
+ PRK_Perim
42
+ Para_Area
43
+ Proportion
44
+ RNgon_B_Area
45
+ RNgon_H_Area
46
+ RNgon_L_Area
47
+ Ratio
48
+ Rect_Area
49
+ Rhom_Area
50
+ Sin_Law
51
+ Sum
52
+ TanSec_Ang
53
+ Trap_Area
54
+ Tria_BH_Area
55
+ Tria_SAS_Area
56
+ N0
57
+ N1
58
+ N2
59
+ N3
60
+ N4
61
+ N5
62
+ N6
63
+ N7
64
+ N8
65
+ N9
66
+ N10
67
+ N11