Spaces:
Sleeping
Sleeping
Initial upload of PGPS demo with all dependencies
Browse files- LM_MODEL.pth +3 -0
- README.md +46 -6
- app.py +244 -0
- config/__init__.py +3 -0
- config/config_default.py +88 -0
- config/logger.py +27 -0
- core/__init__.py +0 -0
- core/network.py +200 -0
- core/test.py +40 -0
- core/train.py +44 -0
- core/worker.py +73 -0
- datasets/__init__.py +37 -0
- datasets/dataset.py +85 -0
- datasets/diagram_aug.py +79 -0
- datasets/operators.py +633 -0
- datasets/preprossing.py +201 -0
- datasets/text_aug.py +233 -0
- datasets/utils.py +266 -0
- loss/__init__.py +10 -0
- loss/loss.py +66 -0
- model/backbone/__init__.py +16 -0
- model/backbone/mobilenet_v2.py +122 -0
- model/backbone/resnet.py +159 -0
- model/classifier/__init__.py +23 -0
- model/classifier/classifier_ops.py +69 -0
- model/decoder/__init__.py +24 -0
- model/decoder/rnn_decoder.py +201 -0
- model/decoder/transformer.py +217 -0
- model/decoder/tree_decoder.py +374 -0
- model/encoder/__init__.py +21 -0
- model/encoder/gru.py +41 -0
- model/encoder/lstm.py +23 -0
- model/encoder/transformer.py +77 -0
- model/module/__init__.py +2 -0
- model/module/attention.py +74 -0
- model/module/module_ops.py +25 -0
- requirements.txt +8 -0
- utils/__init__.py +4 -0
- utils/lr_scheduler.py +47 -0
- utils/utils.py +369 -0
- vocab/vocab_src.txt +322 -0
- vocab/vocab_tgt.txt +67 -0
LM_MODEL.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d0c84cefe6acd4fd66020d40eaebae5da6e3cf231266193d6f53d465ae627d0
|
| 3 |
+
size 64083797
|
README.md
CHANGED
|
@@ -1,12 +1,52 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: PGPS Geometric Problem Solver
|
| 3 |
+
emoji: 📐
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.16.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
+
# PGPS: Neural Geometric Problem Solver Demo
|
| 14 |
+
|
| 15 |
+
This Space demonstrates the PGPS (Plane Geometry Problem Solver) model, which uses multi-modal neural networks to solve geometry problems.
|
| 16 |
+
|
| 17 |
+
## How to Use
|
| 18 |
+
|
| 19 |
+
1. **Upload a Geometry Diagram**: Upload an image containing a geometric diagram (triangles, angles, lines, etc.)
|
| 20 |
+
2. **Enter Problem Text**: Provide the text description of the geometry problem
|
| 21 |
+
3. **Get Solution**: The model will analyze both the diagram and text to generate a solution
|
| 22 |
+
|
| 23 |
+
## Model Details
|
| 24 |
+
|
| 25 |
+
- **Architecture**: Multi-modal neural network with visual encoder and text encoder
|
| 26 |
+
- **Task**: Geometric problem solving
|
| 27 |
+
- **Paper**: IJCAI 2023
|
| 28 |
+
- **Original Repository**: [GitHub](https://github.com/mingliangzhang2018/PGPS)
|
| 29 |
+
|
| 30 |
+
## Features
|
| 31 |
+
|
| 32 |
+
- Visual diagram parsing
|
| 33 |
+
- Text understanding for geometric problems
|
| 34 |
+
- Expression generation for solutions
|
| 35 |
+
- Support for various geometry problem types
|
| 36 |
+
|
| 37 |
+
## Limitations
|
| 38 |
+
|
| 39 |
+
- Best performance with clear, simple geometric diagrams
|
| 40 |
+
- Requires both image and text input for optimal results
|
| 41 |
+
- Limited to plane geometry problems
|
| 42 |
+
|
| 43 |
+
## Citation
|
| 44 |
+
|
| 45 |
+
```bibtex
|
| 46 |
+
@inproceedings{zhang2023pgps,
|
| 47 |
+
title={PGPS: A Neural Geometric Solver},
|
| 48 |
+
author={Zhang, Mingliang and others},
|
| 49 |
+
booktitle={IJCAI 2023},
|
| 50 |
+
year={2023}
|
| 51 |
+
}
|
| 52 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Add current directory to path
|
| 10 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
|
| 12 |
+
from core.network import Network, MLMTransformerPretrain
|
| 13 |
+
from model.backbone import get_visual_backbone
|
| 14 |
+
from model.encoder import get_encoder
|
| 15 |
+
from model.decoder import get_decoder
|
| 16 |
+
from datasets.preprossing import SN
|
| 17 |
+
from datasets.utils import get_combined_text, get_var_arg, get_text_index
|
| 18 |
+
from datasets.operators import normalize_exp
|
| 19 |
+
import datasets.diagram_aug as T_diagram
|
| 20 |
+
|
| 21 |
+
# Language classes for vocabulary management
|
| 22 |
+
class Lang:
    """Vocabulary holder mapping words to indices and back.

    Indices 0-3 are reserved for the special tokens PAD/SOS/EOS/UNK; the
    vocabulary grows as words are added.
    """

    def __init__(self):
        # word -> index for every added word (special tokens excluded)
        self.word2index = {}
        # word -> number of times the word has been added
        self.word2count = {}
        # index -> word, pre-seeded with the four special tokens
        self.index2word = {0: "PAD", 1: "SOS", 2: "EOS", 3: "UNK"}
        self.n_words = 4
        # tag inventories consumed by the encoder's tag embeddings
        self.class_tag = ['PAD', 'QUE', 'VAR', 'NUM', 'SEP']
        self.sect_tag = ['PAD', 'TEXT', 'STRU', 'SEM']

    def add_sentence(self, sentence):
        """Add every space-separated token of *sentence* to the vocabulary."""
        for token in sentence.split(' '):
            self.add_word(token)

    def add_word(self, word):
        """Register *word*, assigning a fresh index on first sight."""
        if word in self.word2index:
            self.word2count[word] += 1
        else:
            idx = self.n_words
            self.word2index[word] = idx
            self.word2count[word] = 1
            self.index2word[idx] = word
            self.n_words = idx + 1

    def indexes_from_sentence(self, sentence, var_values=None, arg_values=None):
        """Map *sentence* to vocabulary indices; out-of-vocabulary -> 3 (UNK).

        NOTE(review): var_values / arg_values are accepted but currently unused.
        """
        return [self.word2index.get(token, 3) for token in sentence.split(' ')]
|
| 52 |
+
|
| 53 |
+
# Configuration class
|
| 54 |
+
class Config:
    """Inference-time configuration mirroring the training argparse defaults."""

    # Default hyper-parameters, grouped to mirror config/config_default.py.
    _DEFAULTS = dict(
        # visual backbone
        visual_backbone="ResNet10",
        diagram_size=128,
        # text encoder
        encoder_type="gru",
        encoder_layers=2,
        encoder_embedding_size=256,
        encoder_hidden_size=512,
        max_input_len=400,
        # decoder
        decoder_type="rnn_decoder",
        decoder_layers=2,
        decoder_embedding_size=512,
        decoder_hidden_size=512,
        max_output_len=40,
        # general
        dropout_rate=0.2,
        beam_size=10,
        use_MLM_pretrain=True,
        MLM_pretrain_path='./LM_MODEL.pth',
        pretrain_emb_path='',
        # dataset
        without_stru=False,
    )

    def __init__(self):
        # Copy defaults onto the instance so they behave as plain attributes.
        for key, value in self._DEFAULTS.items():
            setattr(self, key, value)
|
| 83 |
+
|
| 84 |
+
# Initialize model
|
| 85 |
+
def load_model():
    """Build the PGPS network and its vocabularies for inference.

    Returns:
        (model, src_lang, tgt_lang, cfg): the network in eval mode, the
        source/target vocabularies, and the configuration used to build it.
    """
    cfg = Config()

    # Load vocabularies; one token per line.
    src_lang = Lang()
    tgt_lang = Lang()

    # Explicit encoding: the vocab files may contain non-ASCII math symbols
    # and the platform default encoding is unreliable (e.g. cp1252 on
    # Windows) — the original relied on the platform default.
    with open('./vocab/vocab_src.txt', 'r', encoding='utf-8') as f:
        for line in f:
            src_lang.add_word(line.strip())

    with open('./vocab/vocab_tgt.txt', 'r', encoding='utf-8') as f:
        for line in f:
            tgt_lang.add_word(line.strip())

    # Create model. Network itself already loads the MLM weights when
    # cfg.use_MLM_pretrain is set and cfg.MLM_pretrain_path is non-empty.
    model = Network(cfg, src_lang, tgt_lang)

    # Re-load the pretrained LM checkpoint explicitly if present; use the
    # configured path instead of a duplicated hard-coded literal.
    if os.path.exists(cfg.MLM_pretrain_path):
        model.mlm_pretrain.load_model(cfg.MLM_pretrain_path)

    model.eval()
    return model, src_lang, tgt_lang, cfg
|
| 110 |
+
|
| 111 |
+
# Process image and text
|
| 112 |
+
def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one inference pass over a diagram image plus problem text.

    Args:
        image: PIL image of the geometry diagram.
        text_input: free-form problem statement.
        model, src_lang, tgt_lang, cfg: objects returned by load_model().

    Returns:
        A human-readable string with the predicted expression (and its
        evaluated result when evaluation succeeds).
    """
    # Transform the image to the fixed model resolution.
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])

    diagram = diagram_transform(image).unsqueeze(0)

    # Wrap the raw text in the SN structure the preprocessing helpers expect.
    text_sn = SN()
    text_sn.word_list = text_input.split()
    text_sn.clause_list = [text_input]

    # Empty parsing structures: no structural/semantic clauses are available
    # for free-form input, so the combined text falls back to plain text only.
    parsing_stru = SN()
    parsing_stru.word_list = []
    parsing_stru.clause_list = []

    parsing_sem = SN()
    parsing_sem.word_list = []
    parsing_sem.clause_list = []

    # Combine text sections.
    combine_text = SN()
    get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)

    # Tokenize the combined text into vocabulary / tag indices.
    text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)

    # Convert to batch-of-one tensors.
    text_dict = {
        'token': torch.LongTensor([text_token]),
        'sect_tag': torch.LongTensor([text_sect_tag]),
        'class_tag': torch.LongTensor([text_class_tag]),
        'len': torch.LongTensor([len(text_token)])
    }

    # Positions/values of numeric variables and arguments in the text.
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)

    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]),
        'len': torch.LongTensor([len(var_arg_positions)]),
        'var_value': var_values,
        'arg_value': arg_values
    }

    # Dummy expression dict: inference only needs the SOS seed.
    exp_dict = {
        'exp': torch.LongTensor([[1]]),  # SOS token
        'len': torch.LongTensor([1]),
        'answer': 0
    }

    # Run inference.
    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    if outputs is not None:
        # Convert output indices to symbols, stopping at EOS and dropping
        # padding / start markers.
        output_symbols = []
        for idx in outputs[0]:
            if idx < len(tgt_lang.index2word):
                symbol = tgt_lang.index2word[idx]
                if symbol == 'EOS':
                    break
                if symbol not in ['PAD', 'SOS']:
                    output_symbols.append(symbol)

        expression = ' '.join(output_symbols)

        # Best-effort evaluation of the symbolic expression. BUG FIX: the
        # original used a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        try:
            result = eval_expression(expression, var_values, arg_values)
            return f"Expression: {expression}\nResult: {result}"
        except Exception:
            return f"Expression: {expression}\n(Could not evaluate)"

    return "Could not generate solution"
|
| 197 |
+
|
| 198 |
+
def eval_expression(expr, var_values, arg_values):
    """Placeholder expression evaluator.

    A full implementation would substitute the variable/argument values and
    compute the numeric answer; for now the symbolic expression is returned
    unchanged.
    """
    return expr
|
| 202 |
+
|
| 203 |
+
# Gradio interface
|
| 204 |
+
def predict(image, text):
    """Gradio callback: validate inputs, then run the solver pipeline."""
    # Guard clauses for missing inputs.
    if image is None:
        return "Please upload a geometry diagram image"
    if not text:
        return "Please provide the problem text"

    # Any failure inside the pipeline is reported back to the UI rather
    # than crashing the Space.
    try:
        return process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as e:
        return f"Error processing input: {str(e)}"
|
| 216 |
+
|
| 217 |
+
# Load model on startup
|
| 218 |
+
# ---- application start-up -------------------------------------------------
# The model is built once at import time so every request reuses it.
print("Loading PGPS model...")
model, src_lang, tgt_lang, cfg = load_model()
print("Model loaded successfully!")

# Gradio UI: one image input, one free-text input, one text output.
_diagram_input = gr.Image(type="pil", label="Geometry Diagram")
_problem_input = gr.Textbox(
    lines=3,
    placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
    label="Problem Text"
)

iface = gr.Interface(
    fn=predict,
    inputs=[_diagram_input, _problem_input],
    outputs=gr.Textbox(label="Solution", lines=5),
    title="PGPS: Neural Geometric Problem Solver",
    description="Upload a geometry diagram and provide the problem text to get a solution.",
    examples=[
        [None, "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees"],
        [None, "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees"]
    ],
    theme="default"
)

if __name__ == "__main__":
    iface.launch()
|
config/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config_default import *
|
| 2 |
+
from .logger import *
|
| 3 |
+
|
config/config_default.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
import torchvision.models as models


# All lowercase, callable entries of torchvision.models — the candidate
# torchvision backbone names (not currently used by get_parser below).
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name]))

# Closed option sets used as argparse `choices` in get_parser().
criterion_list = ["CrossEntropy", "FocalLoss", "MaskedCrossEntropy"]
optimizer_list = ["SGD", "ADAM"]
scheduler_list = ["multistep",'cosine','warmup']
visual_backbone_list = ['ResNet10', 'mobilenet_v2']
encoder_list = ['lstm', 'gru', 'transformer']
decoder_list = ["rnn_decoder", "tree_decoder"]
eval_method_list = ["completion", "choice", "top3"]
dataset_list = ['Geometry3K', 'PGPS9K']
|
| 16 |
+
|
| 17 |
+
def _int_list(value):
    """argparse type: parse a comma-separated string ('160,280') into [int]."""
    if isinstance(value, list):
        return value
    return [int(v) for v in value.split(',')]


def get_parser():
    """Build and parse the PGPS training command line.

    Returns:
        argparse.Namespace with all training/evaluation options.
    """
    parser = argparse.ArgumentParser(description='PyTorch PGPS Training')
    # visual backbone
    ##############################################################################
    parser.add_argument('--visual_backbone', default="ResNet10", type=str, choices=visual_backbone_list)
    parser.add_argument('--diagram_size', default=128, type=int)
    # encoder model
    ##############################################################################
    parser.add_argument('--encoder_type', default="gru", type=str, choices=encoder_list)
    parser.add_argument('--encoder_layers', default=2, type=int)
    parser.add_argument('--encoder_embedding_size', default=256, type=int)
    parser.add_argument('--encoder_hidden_size', default=512, type=int)
    parser.add_argument('--max_input_len', default=400, type=int)
    # decoder model
    ##############################################################################
    parser.add_argument('--decoder_type', default="rnn_decoder", type=str, choices=decoder_list)
    parser.add_argument('--decoder_layers', default=2, type=int)
    parser.add_argument('--decoder_embedding_size', default=512, type=int)
    parser.add_argument('--decoder_hidden_size', default=512, type=int)
    parser.add_argument('--max_output_len', default=40, type=int)
    # general model
    ##############################################################################
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--beam_size', default=10, type=int)
    # optimizer
    ##############################################################################
    # NOTE(review): default "ADAMW" is not in optimizer_list choices; argparse
    # only validates values passed on the command line, so the default slips
    # through — confirm how downstream code branches on optimizer_type.
    parser.add_argument('--optimizer_type', default="ADAMW", type=str, choices=optimizer_list)
    parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate without LM')
    parser.add_argument('--lr_LM', default=1e-4, type=float, help='initial learning rate of LM')
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--max_epoch', default=540, type=int)
    parser.add_argument('--scheduler_type', default="warmup", type=str, choices=scheduler_list)
    # BUG FIX: `type=list` applied list() to the raw string, turning
    # "--scheduler_step 160,280" into ['1','6','0',...]; parse comma-separated
    # integers instead. The default value is unchanged.
    parser.add_argument('--scheduler_step', default=[160, 280, 360, 440, 500], type=_int_list)
    parser.add_argument('--scheduler_factor', default=0.5, type=float, help='learning rate decay factor')
    parser.add_argument('--cosine_decay_end', default=0.0, type=float, help='cosine decay end')
    parser.add_argument('--warm_epoch', default=40, type=int)
    # criterion
    ###############################################################################
    parser.add_argument('--criterion', default="MaskedCrossEntropy", choices=criterion_list, type=str)
    parser.add_argument('--eval_method', default="top3", choices=eval_method_list, type=str)
    # dataset
    ################################################################################
    parser.add_argument('--dataset', default="PGPS9K", type=str, choices=dataset_list)
    parser.add_argument('--dataset_dir', default='./datasets/PGPS9K_all')
    parser.add_argument('--pretrain_vis_path', default='')
    parser.add_argument('--vocab_src_path', default='./vocab/vocab_src.txt')
    parser.add_argument('--vocab_tgt_path', default='./vocab/vocab_tgt.txt')
    parser.add_argument('--pretrain_emb_path', default='')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--random_prob', default=0.5, type=float)
    parser.add_argument('--without_stru', action='store_true', help='structure clauses are used or not')
    parser.add_argument('--trim_min_count', default=5, type=int, help='minimum number of word')
    parser.add_argument('--use_MLM_pretrain', action='store_true', help='use MLM pretrain')
    parser.add_argument('--MLM_pretrain_path', default='./pretraining_model/LM_MODEL.pth')
    # print information
    ###################################################################################
    parser.add_argument('--dump_path', default="./log/", type=str, help='save log path')
    parser.add_argument('--print_freq', default=20, type=int, help='print frequency')
    parser.add_argument('--eval_epoch', default=40, type=int)
    # general config
    ###################################################################################
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--evaluate_only', action='store_true', help='evaluate model on validation set')
    parser.add_argument('--resume_model', default="", type=str, help='use pre-trained model')
    # DistributedDataParallel
    ###################################################################################
    parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training')
    parser.add_argument('--init_method', default="env://", type=str, help='distributed init method')
    parser.add_argument('--debug', action='store_true', help = "if debug than set local rank = 0")
    parser.add_argument('--seed', default=202302, type=int,help='seed for initializing training. ')

    return parser.parse_args()
|
config/logger.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from logging import handlers
|
| 3 |
+
|
| 4 |
+
class Logger(object):
    """Wire a named logger to stdout and a timed rotating file.

    On non-zero ranks (distributed workers) the logger is looked up but left
    without handlers, so only rank 0 produces output.
    """

    # symbolic level names accepted by __init__
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL,
    }

    def __init__(self, filename, rank, level='info', when='D', backCount=3, fmt='%(asctime)s - %(levelname)s: %(message)s'):
        self.logger = logging.getLogger(filename)
        # Non-master ranks stay silent: no handlers are attached.
        if rank != 0:
            return
        formatter = logging.Formatter(fmt)
        self.logger.setLevel(self.level_relations.get(level))
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        file_handler = handlers.TimedRotatingFileHandler(
            filename=filename, when=when, backupCount=backCount, encoding='utf-8')
        file_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)
        self.logger.addHandler(file_handler)
|
| 24 |
+
|
| 25 |
+
def create_logger(filepath, rank):
    """Return a configured logging.Logger for this process rank."""
    return Logger(filepath, rank).logger
|
core/__init__.py
ADDED
|
File without changes
|
core/network.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model.backbone import get_visual_backbone
|
| 4 |
+
from model.encoder import get_encoder, TransformerEncoder
|
| 5 |
+
from model.decoder import get_decoder
|
| 6 |
+
from utils.utils import *
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MLMTransformerPretrain(nn.Module):
    """Transformer text encoder pretrained with masked language modelling.

    Produces contextual embeddings for the tokenized problem text; Network
    uses it in place of the plain embedding stack when use_MLM_pretrain is set.
    """

    def __init__(self, cfg, src_lang):
        super(MLMTransformerPretrain, self).__init__()
        self.cfg = cfg
        self.transformer_en = TransformerEncoder(cfg.encoder_embedding_size)
        # token embedding, optionally initialised from a pretrained table
        self.text_embedding_src = self.get_text_embedding_src(
            vocab_size = src_lang.n_words,
            embedding_dim = cfg.encoder_embedding_size,
            padding_idx = 0,
            pretrain_emb_path = cfg.pretrain_emb_path
        )
        # token-class tag embedding (PAD/QUE/VAR/NUM/SEP)
        self.class_tag_embedding = nn.Embedding(
            len(src_lang.class_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )
        # section tag embedding (PAD/TEXT/STRU/SEM)
        self.sect_tag_embedding = nn.Embedding(
            len(src_lang.sect_tag),
            cfg.encoder_embedding_size,
            padding_idx=0
        )

    def forward(self, text_dict):
        '''
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        '''
        # Sum token, class-tag and section-tag embeddings into one input.
        token_emb = self.text_embedding_src(text_dict['token'])
        class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
        sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
        # NOTE(review): token_emb.sum(dim=1) implies 'token' carries an extra
        # sub-token axis — confirm against the dataloader's tensor layout.
        text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb
        transformer_outputs = self.transformer_en(text_dict['len'], text_emb_src)
        return transformer_outputs

    def load_model(self, model_path):
        """Load pretrained weights, tolerating DataParallel 'module.' prefixes.

        Fixes two defects of the original implementation:
        * map_location was hard-coded to "cuda", crashing on CPU-only hosts;
        * keys were tested for membership BEFORE the 'module.' prefix was
          stripped, silently dropping every DataParallel-saved weight.
        """
        map_location = "cuda" if torch.cuda.is_available() else "cpu"
        pretrain_dict = torch.load(model_path, map_location=map_location)
        pretrain_dict_model = pretrain_dict['state_dict'] \
                    if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            # Strip the DataParallel prefix first, then filter by membership.
            key = k[7:] if k.startswith("module.") else k
            if key in model_dict:
                new_dict[key] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        """Create the token embedding; optionally copy pretrained vectors
        (one 'word v1 v2 ...' line per token) into the trailing rows."""
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path!='':
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            embedding_src.weight.data[-len(emb_content):]. \
                copy_(torch.from_numpy(vector))
        return embedding_src
|
| 75 |
+
|
| 76 |
+
class Network(nn.Module):
    """Multi-modal geometry solver: visual backbone + text encoder + decoder."""

    def __init__(self, cfg, src_lang, tgt_lang):
        super(Network, self).__init__()
        self.cfg = cfg
        # define the encoder and decoder
        self.visual_extractor = get_visual_backbone(cfg)
        self.encoder = get_encoder(cfg)
        self.decoder = get_decoder(cfg, tgt_lang)
        # Project diagram features into the text embedding space.
        # (The original built a ModuleList and immediately re-wrapped it in a
        # Sequential; build the Sequential directly — identical module tree.)
        self.visual_emb_unify = nn.Sequential(
            nn.Linear(self.visual_extractor.final_feat_dim, cfg.encoder_embedding_size),
            nn.ReLU(),
            nn.Linear(cfg.encoder_embedding_size, cfg.encoder_embedding_size)
        )

        if cfg.use_MLM_pretrain:
            self.mlm_pretrain = MLMTransformerPretrain(cfg, src_lang)
            if cfg.MLM_pretrain_path!='':
                self.mlm_pretrain.load_model(cfg.MLM_pretrain_path)
        else:
            self.text_embedding_src = self.get_text_embedding_src(
                vocab_size = src_lang.n_words,
                embedding_dim = cfg.encoder_embedding_size,
                padding_idx = 0,
                pretrain_emb_path = cfg.pretrain_emb_path
            )
            self.class_tag_embedding = nn.Embedding(
                len(src_lang.class_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )
            self.sect_tag_embedding = nn.Embedding(
                len(src_lang.sect_tag),
                cfg.encoder_embedding_size,
                padding_idx=0
            )

        self.src_lang = src_lang

    def forward(self, diagram_src, text_dict, var_dict, exp_dict, is_train=False):
        '''
        diagram_src: B x C x W x H
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'} /
                    {'token', 'sect_tag', 'class_tag', 'subseq_len', 'item_len', 'item_quant'}
        var_dict = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict = {'exp', 'len', 'answer'}
        '''

        if self.cfg.use_MLM_pretrain:
            text_emb_src = self.mlm_pretrain(text_dict)
        else:
            # text feature
            token_emb = self.text_embedding_src(text_dict['token'])
            class_tag_emb = self.class_tag_embedding(text_dict['class_tag'])
            sect_tag_emb = self.sect_tag_embedding(text_dict['sect_tag'])
            # all feature
            text_emb_src = token_emb.sum(dim=1) + sect_tag_emb + class_tag_emb

        # diagram feature, prepended as one extra "token"
        diagram_emb_src = self.visual_extractor(diagram_src)
        diagram_emb_src = self.visual_emb_unify(diagram_emb_src).unsqueeze(dim=1)
        all_emb_src = torch.cat([diagram_emb_src, text_emb_src], dim=1)
        # NOTE: lengths/positions are shifted IN PLACE to account for the
        # prepended diagram token — callers see the mutated dicts.
        text_dict['len'] += 1
        var_dict['pos'] += 1
        # encoder
        encoder_outputs, encode_hidden = self.encoder(all_emb_src, text_dict['len'])
        # Use the last hidden layer to seed every decoder layer.
        problem_output = encode_hidden[-1:,:,:].repeat(self.cfg.decoder_layers, 1, 1)
        # decoder
        outputs = self.decoder(encoder_outputs, problem_output, \
                               text_dict['len'], \
                               var_dict['pos'], var_dict['len'], \
                               exp_dict['exp'], \
                               is_train)
        return outputs

    def freeze_module(self, module):
        """Disable gradients for every parameter of *module*."""
        # BUG FIX: the original logged 'Freezing module of '+' .......',
        # which never named the module being frozen.
        self.cfg.logger.info("Freezing module of " + type(module).__name__ + " .......")
        for p in module.parameters():
            p.requires_grad = False

    def load_model(self, model_path):
        """Load a checkpoint, tolerating DataParallel 'module.' prefixes.

        map_location falls back to CPU when CUDA is unavailable (the original
        hard-coded "cuda" and crashed on CPU-only hosts). Keys absent from
        this model are skipped so strict load_state_dict cannot fail on
        extras — consistent with MLMTransformerPretrain.load_model.

        Returns:
            The raw checkpoint dict (callers may read epoch/optimizer state).
        """
        map_location = "cuda" if torch.cuda.is_available() else "cpu"
        pretrain_dict = torch.load(model_path, map_location=map_location)
        pretrain_dict_model = pretrain_dict['state_dict'] \
                    if 'state_dict' in pretrain_dict else pretrain_dict
        model_dict = self.state_dict()
        from collections import OrderedDict
        new_dict = OrderedDict()
        for k, v in pretrain_dict_model.items():
            key = k[7:] if k.startswith("module.") else k
            if key in model_dict:
                new_dict[key] = v
        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        return pretrain_dict

    def get_text_embedding_src(self, vocab_size, embedding_dim, padding_idx, pretrain_emb_path):
        """Create the token embedding; optionally copy pretrained vectors
        (one 'word v1 v2 ...' line per token) into the trailing rows."""
        embedding_src = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        if pretrain_emb_path!='':
            emb_content = []
            with open(pretrain_emb_path, 'r') as f:
                for line in f:
                    emb_content.append(line.split()[1:])
            vector = np.asarray(emb_content, "float32")
            embedding_src.weight.data[-len(emb_content):]. \
                copy_(torch.from_numpy(vector))
        return embedding_src
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def get_model(args, src_lang, tgt_lang):
    """Instantiate the Network and log its architecture.

    Args:
        args: config namespace; must expose a ``logger``.
        src_lang: source-vocabulary object.
        tgt_lang: target-vocabulary object.

    Returns:
        The constructed Network instance (caller moves it to a device).
    """
    model = Network(args, src_lang, tgt_lang)
    # Log the full layer tree once, for reproducibility of experiments.
    args.logger.info(str(model))
    return model
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
core/test.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from utils import *
|
| 3 |
+
|
| 4 |
+
def validate(args, val_loader, model, tgt_lang):
    """Evaluate `model` on `val_loader`; return (answer_acc, equation_acc).

    Requires an initialized torch.distributed process group (uses barrier
    and an all-reduce mean across ``args.nprocs`` ranks) and CUDA inputs.
    """
    batch_time = AverageMeter('Time', ':5.3f')
    acc_ans = AverageMeter('Ans_Acc', ':5.4f')
    acc_eq = AverageMeter('Eq_Acc', ':5.4f')
    # NOTE(review): `progress` is never displayed here — kept, presumably
    # for symmetry with train(); confirm before removing.
    progress = ProgressMeter(len(val_loader), [batch_time, acc_ans, acc_eq], args, prefix='Test: ')
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (diagrams, text_dict, var_dict, exp_dict) in enumerate(val_loader):
            # set cuda for input data
            diagrams = diagrams.cuda()
            set_cuda(text_dict), set_cuda(var_dict), set_cuda(exp_dict)
            # compute output
            output = model(diagrams, text_dict, var_dict, exp_dict, is_train=False)
            # NOTE(review): if args.eval_method is none of the three values
            # below, acc1/acc2 are stale or undefined (UnboundLocalError on
            # the first batch) — validate the config upstream.
            if args.eval_method == "completion":
                acc1, acc2 = compute_exp_result_comp(output, var_dict, exp_dict, tgt_lang)
            elif args.eval_method == "choice":
                acc1, acc2 = compute_exp_result_choice(output, var_dict, exp_dict, tgt_lang)
            elif args.eval_method == "top3":
                acc1, acc2 = compute_exp_result_topk(output, var_dict, exp_dict, tgt_lang, k_num=3)

            # synchronize all ranks before averaging accuracies
            torch.distributed.barrier()

            reduced_acc_ans = reduce_mean(torch.tensor([acc1]).cuda(), args.nprocs)
            reduced_acc_eq = reduce_mean(torch.tensor([acc2]).cuda(), args.nprocs)

            acc_ans.update(reduced_acc_ans.item(), len(diagrams))
            acc_eq.update(reduced_acc_eq.item(), len(diagrams))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

    return acc_ans.avg, acc_eq.avg
|
core/train.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from utils import *
|
| 3 |
+
|
| 4 |
+
def train(args, epoch, train_loader, model, criterion, optimizer):
    """Train `model` for one epoch; return the distributed-averaged loss.

    Requires an initialized torch.distributed process group (barrier +
    all-reduce mean across ``args.nprocs`` ranks) and CUDA inputs.
    """
    batch_time = AverageMeter('Time', ':5.3f')
    data_time = AverageMeter('Data', ':5.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(len(train_loader), [batch_time, data_time, losses],
                            args, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    end = time.time()

    for i, (diagrams, text_dict, var_dict, exp_dict) in enumerate(train_loader):
        '''
        text_dict = {'token', 'sect_tag', 'class_tag', 'len'}
        var_dict = {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict = {'exp', 'len', 'answer'}
        '''
        # measure data loading time
        data_time.update(time.time() - end)
        # set cuda for input data
        diagrams = diagrams.cuda()
        set_cuda(text_dict), set_cuda(var_dict), set_cuda(exp_dict)
        # compute output
        output = model(diagrams, text_dict, var_dict, exp_dict, is_train=True)
        loss = criterion(output, exp_dict['exp'][:,1:].clone(), exp_dict['len']-1) # Remove special symbol [SOS]
        # update the loss: average it across ranks so the meter is global
        torch.distributed.barrier()
        reduced_loss = reduce_mean(loss, args.nprocs)
        losses.update(reduced_loss.item(), len(diagrams))
        # compute gradient and do SGD step (backward on the LOCAL loss;
        # gradients sync through DDP, not through reduced_loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i, lr = optimizer.state_dict()['param_groups'][0]['lr'])

    return losses.avg
|
core/worker.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.optim
|
| 3 |
+
import torch.utils.data
|
| 4 |
+
import torch.nn.parallel
|
| 5 |
+
from core.train import *
|
| 6 |
+
from core.test import *
|
| 7 |
+
from utils import *
|
| 8 |
+
from core.network import get_model
|
| 9 |
+
from loss import get_criterion
|
| 10 |
+
from datasets import get_dataloader
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def main_worker(args):
    """Per-process entry point: build everything, optionally resume/evaluate,
    then run the training loop with periodic checkpointing.

    Expected on ``args``: logger setup inputs, ``resume_model``,
    ``evaluate_only``, ``local_rank``, ``max_epoch``, ``eval_epoch``,
    ``dump_path``, plus whatever the factory helpers read.
    """
    args.logger = initialize_logger(args)
    train_loader, train_sampler, val_loader, src_lang, tgt_lang = get_dataloader(args)
    model = get_model(args, src_lang, tgt_lang).cuda()
    optimizer = get_optimizer(args, model)
    scheduler = get_scheduler(args, optimizer)
    criterion = get_criterion(args)
    start_epoch = 0

    # resume model
    if not args.resume_model =='':
        # load_model returns the whole checkpoint dict, which also carries
        # optimizer/scheduler state and the last completed epoch.
        resume_model_dict = model.load_model(args.resume_model)
        optimizer.load_state_dict(resume_model_dict['optimizer'])
        scheduler.load_state_dict(resume_model_dict['scheduler'])
        start_epoch = resume_model_dict["epoch"]+1
        args.logger.info("The whole model has been loaded from "+ args.resume_model)
        args.logger.info("The model resumes from epoch "+ str(resume_model_dict["epoch"]))
        if args.evaluate_only:
            # Evaluation-only mode: report accuracy and exit before DDP wrap.
            acc_ans, acc_eq = validate(args, val_loader, model, tgt_lang)
            args.logger.info("----------Epoch:{:>3d}, test answer_acc {:>5.4f}, equation_acc {:>5.4f} ---------" \
                .format(resume_model_dict["epoch"], acc_ans, acc_eq))
            return
    else:
        args.logger.info("The model is trained from scratch")

    # distributed parallel training
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True
    )

    min_loss = 1e10  # tracks best (lowest) training loss seen so far

    for epoch in range(start_epoch, args.max_epoch):
        # train for one epoch; set_epoch reshuffles the distributed sampler
        train_sampler.set_epoch(epoch)
        loss = train(args, epoch, train_loader, model, criterion, optimizer)
        args.logger.info("----------Epoch:{:>3d}, training loss is {:>5.4f} ---------". \
            format(epoch, loss))
        # evaluate on validation set and save model (rank 0 only writes)
        if args.local_rank == 0:
            if epoch % args.eval_epoch==0 or epoch>=args.max_epoch-5:
                save_checkpoint({
                    'epoch': epoch ,
                    'state_dict': model.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'optimizer': optimizer.state_dict()}, False, args.dump_path)
                # NOTE(review): the "best" checkpoint is keyed on TRAINING
                # loss (no validation metric here) — confirm this is intended.
                if loss<min_loss:
                    min_loss = loss
                    save_checkpoint({
                        'epoch': epoch ,
                        'state_dict': model.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'optimizer': optimizer.state_dict()}, True, args.dump_path)
        # learning scheduler step
        scheduler.step()

    args.logger.info("------------------- Train Finished -------------------")
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch.utils.data import DataLoader
|
| 2 |
+
from datasets.dataset import MyDataset
|
| 3 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 4 |
+
from datasets.preprossing import *
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def get_dataloader(args):
    """Create the distributed train/test dataloaders and both vocabularies.

    Returns:
        (train_loader, train_sampler, test_loader, src_lang, tgt_lang);
        the train sampler is returned so the caller can call
        ``set_epoch()`` once per epoch.
    """
    src_lang = SrcLang(args.vocab_src_path)
    tgt_lang = TgtLang(args.vocab_tgt_path)

    def _pairs(split_file):
        # Raw samples live at <dataset_dir>/<dataset>/<split_file>.
        return get_raw_pairs(os.path.join(args.dataset_dir, args.dataset, split_file))

    def _make(pairs, is_train):
        # Build dataset + distributed sampler + loader for one split.
        data = MyDataset(args, pairs, src_lang, tgt_lang, is_train=is_train)
        sampler = DistributedSampler(data, shuffle=is_train)
        # Per-process batch: the global batch size is split across ranks;
        # evaluation always runs one sample at a time.
        per_rank_batch = int(args.batch_size/args.nprocs) if is_train else 1
        loader = DataLoader(dataset=data,
                            batch_size=per_rank_batch,
                            pin_memory=True,
                            collate_fn=collater(args),
                            num_workers=args.workers,
                            sampler=sampler)
        return loader, sampler

    train_loader, train_sampler = _make(_pairs('train.json'), True)
    test_loader, _ = _make(_pairs('test.json'), False)

    return train_loader, train_sampler, test_loader, src_lang, tgt_lang
|
datasets/dataset.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import datasets.diagram_aug as T_diagram
|
| 5 |
+
import datasets.text_aug as T_text
|
| 6 |
+
from datasets.operators import normalize_exp
|
| 7 |
+
from datasets.utils import get_combined_text, get_var_arg, get_text_index
|
| 8 |
+
from datasets.preprossing import SN
|
| 9 |
+
|
| 10 |
+
class MyDataset(torch.utils.data.Dataset):
    """Geometry-problem dataset: diagram image + structured text + program.

    Training mode enables stochastic diagram/text augmentation with
    probability ``args.random_prob``; evaluation mode disables it.
    """

    def __init__(self, args, pairs, src_lang, tgt_lang, is_train=True):
        super().__init__()
        self.args = args
        self.pairs = pairs          # raw sample dicts (see __getitem__ docstring)
        self.src_lang = src_lang    # source (text-token) vocabulary
        self.tgt_lang = tgt_lang    # target (program-token) vocabulary
        self.is_train = is_train
        # Augmentations only fire during training; prob 0 makes every
        # Random* transform a no-op at evaluation time.
        if is_train:
            random_prob = args.random_prob
        else:
            random_prob = 0
        self.diagram_transform = T_diagram.Compose([
            T_diagram.Resize(args.diagram_size),
            T_diagram.CenterCrop(args.diagram_size),
            T_diagram.RandomFlip(random_prob),
            T_diagram.ToTensor(),
            T_diagram.Normalize()
        ])
        self.text_transform = T_text.Compose([
            T_text.Point_RandomReplace(random_prob),
            T_text.AngID_RandomReplace(random_prob),
            # T_text.Arg_RandomReplace(random_prob),
            T_text.StruPoint_RandomRotate(random_prob),
            # T_text.SemPoint_RandomRotate(random_prob),
            T_text.SemSeq_RandomRotate(random_prob),
            T_text.StruSeq_RandomRotate(random_prob),
        ])

    def __getitem__(self, idx):
        '''
        pair{
            'diagram': str
            'text': SN()
            'parsing_stru_seqs': SN()
            'parsing_sem_seqs': SN()
            'expression': list
            'answer': str
        }
        '''
        pair = self.pairs[idx]

        # diagram: load from <dataset_dir>/Diagram/<name>, force 3 channels
        diagram_path = os.path.join(self.args.dataset_dir, 'Diagram', pair['diagram'])
        diagram = Image.open(diagram_path).convert("RGB")
        diagram = self.diagram_transform(diagram)
        # text, parsing_stru_seqs, parsing_sem_seqs,
        # NOTE: the text transform mutates `pair` IN PLACE (no return value),
        # so augmentation accumulates into the stored sample dict.
        self.text_transform(pair['text'],
                            pair['parsing_stru_seqs'],
                            pair['parsing_sem_seqs'],
                            pair['expression'])
        combine_text = SN()
        # Merge the three text streams into one combined sequence (in place).
        get_combined_text(pair['text'],
                          pair['parsing_stru_seqs'],
                          pair['parsing_sem_seqs'],
                          combine_text,
                          self.args)
        text_token, text_sect_tag, text_class_tag = \
            get_text_index(combine_text, self.src_lang)
        # var and arg: token positions plus their numeric values
        var_arg_positions, var_values, arg_values = \
            get_var_arg(combine_text, self.args)
        # expression: normalized program, then mapped to vocabulary indices
        expression = normalize_exp(pair['expression'])
        expression = self.tgt_lang.indexes_from_sentence(expression, var_values, arg_values)
        # choices: multiple-choice answer options as floats
        choices = [float(item) for item in pair['choices']]

        return diagram, \
               text_token, text_sect_tag, text_class_tag, \
               var_arg_positions, var_values, arg_values, \
               expression, pair['answer'], pair['id'], choices

    def __len__(self):
        return len(self.pairs)
|
datasets/diagram_aug.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from torchvision.transforms import functional as F
|
| 3 |
+
|
| 4 |
+
class Compose(object):
    """Chain several image transforms and apply them in order."""

    def __init__(self, transforms):
        # transforms: iterable of callables, each taking and returning an image
        self.transforms = transforms

    def __call__(self, image):
        # Feed each stage's output into the next one.
        result = image
        for stage in self.transforms:
            result = stage(result)
        return result

    def __repr__(self):
        # One transform per line, torchvision-style.
        pieces = [self.__class__.__name__ + "("]
        for stage in self.transforms:
            pieces.append(" {0}".format(stage))
        return "\n".join(pieces) + "\n)"
| 20 |
+
|
| 21 |
+
class Resize(object):
    """Resize so the LONGER edge equals ``max_size``, keeping aspect ratio."""

    def __init__(self, max_size):
        self.max_size = max_size

    def get_size(self, image_size):
        # image_size is PIL's (width, height); the return value is
        # (height, width), the order torchvision's F.resize expects.
        w, h = image_size
        if w < h:
            return (self.max_size, int(w * self.max_size / h))
        return (int(h * self.max_size / w), self.max_size)

    def __call__(self, image):
        return F.resize(image, self.get_size(image.size))
|
| 42 |
+
|
| 43 |
+
class CenterCrop(object):
    """Crop the given image at its center to ``size`` (as F.center_crop takes it)."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        cropped = F.center_crop(image, self.size)
        return cropped
|
| 52 |
+
|
| 53 |
+
class RandomFlip(object):
    """With probability ``prob``, flip the image horizontally, vertically,
    or both — one of the three variants chosen uniformly."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image):
        # Same RNG consumption order as before: one random() draw, then
        # (only when flipping) one choice() draw.
        if random.random() >= self.prob:
            return image  # no augmentation this time
        mode = random.choice([0,1,2])
        if mode == 0:
            return F.hflip(image)
        if mode == 1:
            return F.vflip(image)
        return F.vflip(F.hflip(image))
|
| 67 |
+
|
| 68 |
+
class ToTensor(object):
    """Convert a PIL image to a float tensor via torchvision's F.to_tensor."""
    def __call__(self, image):
        return F.to_tensor(image)
|
| 71 |
+
|
| 72 |
+
class Normalize(object):
    """Channel-wise normalize a tensor image: (x - mean) / std.

    The defaults suit mostly-white geometry diagrams (bright background).
    """

    def __init__(self, mean=(0.85, 0.85, 0.85), std=(0.3, 0.3, 0.3)):
        # Fix: the original used mutable list defaults (shared across all
        # instances). Tuples as defaults, copied into lists so the public
        # attributes keep their original list type.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image
|
datasets/operators.py
ADDED
|
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sympy.parsing.latex import parse_latex
|
| 2 |
+
from sympy.printing import latex
|
| 3 |
+
from sympy import solve
|
| 4 |
+
from sympy.core.numbers import Float
|
| 5 |
+
|
| 6 |
+
################# Program Executor ########################

# LaTeX tokens whose letters must NOT be treated as single-letter variables.
spec_token_list = ['frac', 'pi', 'sqrt']
# Individual letters occurring in the tokens above (kept as written,
# duplicates included — presumably only membership tests matter; verify).
spec_letter_list = ['f', 'r', 'a', 'c', 'p', 'i', 's', 'q', 'r', 't']
# All lowercase ASCII letters 'a'..'z'.
low_case_list = [chr(i) for i in range(97, 123)]
# Operators whose argument order is fixed (asymmetric semantics).
fixed_order_ops = [
    'Get', 'Iso_Tri_Ang', 'Gsin', 'Gcos', 'Gtan', 'Geo_Mean', 'Ratio', 'TanSec_Ang', \
    'Chord2_Ang', 'Tria_BH_Area', 'Para_Area', 'Kite_Area', 'Circle_R_Circum', \
    'Circle_D_Circum', 'Circle_R_Area', 'Circle_D_Area', 'ArcSeg_Area', 'Ngon_Angsum', \
    'RNgon_B_Area', 'RNgon_L_Area', 'RNgon_H_Area']
# Operators whose arguments may be reordered (commutative-style relations).
alterable_order_ops = [
    'Sum', 'Multiple', 'Equal', 'Gougu', 'Cos_Law', 'Sin_Law', 'Median', 'Proportion', \
    'Tria_SAS_Area', 'PRK_Perim', 'Rect_Area', 'Rhom_Area', 'Trap_Area']
# Full set of recognized operator names.
arith_op_list = fixed_order_ops + alterable_order_ops
# Symbol ranking used by get_priority(): process variables V*, then
# problem numbers N*, then constants C*; anything else ranks -1.
priority_list = ["V0", "V1", "V2", "V3", "V4", "V5", "V6", \
    "N0", "N1", "N2", "N3", "N4", "N5", "N6", "N7", "N8", "N9", "N10", \
    "C0.5", "C2", "C3", "C4", "C5", "C6", "C8", "C60", "C90", "C180", "C360"]
# Maximum number of process (intermediate) variables supported.
V_NUM = 10
|
| 24 |
+
|
| 25 |
+
class Varible_Record(object):
    """Mutable solver state shared across operator evaluations.

    (Class name kept as-is, typo included — other code references it.)
    """

    def __init__(self):
        # named variable -> its LaTeX expression (read by Get)
        self.varible_dict = {}
        # intermediate variable / letter -> solved value (LaTeX string)
        self.mid_varible_dict = {}
        # final numeric answer, formatted by Get as a '0.3f' string
        self.result = ''
|
| 30 |
+
|
| 31 |
+
def get_priority(token):
    """Rank of *token* within ``priority_list``; -1 for plain arguments."""
    try:
        return priority_list.index(token)
    except ValueError:
        return -1  # arg: not a V*/N*/C* symbol
|
| 36 |
+
|
| 37 |
+
def is_exist_operator(func, ANNO):
    """Validate that *func* names a known operator and return it unchanged.

    Args:
        func: candidate operator name.
        ANNO: solver state (unused; kept for call-site compatibility).

    Raises:
        ValueError: if *func* is not in ``arith_op_list``.  ValueError is a
        subclass of Exception, so existing ``except Exception`` handlers
        still catch it; the original printed and raised a bare Exception.
    """
    if not func in arith_op_list:
        raise ValueError("Can Not Find Operators: " + str(func))
    return func
|
| 42 |
+
|
| 43 |
+
def choose_result(result_list):
    """Pick one root from a sympy ``solve`` result list.

    Empty list -> None; single root -> that root.  With two or more roots
    only the first two are compared and the positive one is preferred
    (geometric quantities are positive).
    """
    if not result_list:
        return None
    if len(result_list) == 1:
        return result_list[0]
    first = result_list[0].evalf()
    second = result_list[1].evalf()
    prefer_first = (first > second and second <= 0) or (first < second and first > 0)
    return result_list[0] if prefer_first else result_list[1]
|
| 55 |
+
|
| 56 |
+
def operand_update(operands, ANNO):
    """Rewrite each operand string in place so it is sympy-parseable.

    Substitutes solved intermediate values, expands \\pi, turns mixed
    numbers (e.g. ``1\\frac{1}{2}``) into sums, and strips the ``C``
    prefix off constants.  Returns the (mutated) list for convenience.
    """
    for id in range(len(operands)):
        # Substitute variable
        if operands[id] in ANNO.mid_varible_dict:
            operands[id] = "("+ANNO.mid_varible_dict[operands[id]]+")"
        # pi: replace with a parenthesized numeric literal
        if "\\pi" in operands[id]:
            operands[id]=operands[id].replace('\\pi','(3.141593)')
        # mixed number (improper fraction): "2\frac{1}{3}" -> "2+\frac{1}{3}"
        if "\\frac" in operands[id]:
            loc = operands[id].index("\\frac")
            if loc>0 and operands[id][loc-1].isdigit():
                operands[id] = operands[id][:loc]+'+'+operands[id][loc:]
            continue
        # Substitute process(intermediate) variable
        # NOTE(review): duplicates the first substitution above — it can only
        # fire if the FIRST substitution rewrote the operand into another
        # dict key; confirm whether that is intended.
        if operands[id] in ANNO.mid_varible_dict:
            operands[id] = "("+ANNO.mid_varible_dict[operands[id]]+")"
            continue
        # Substitute constant: drop the leading 'C' marker
        # NOTE(review): operands[id][0] raises IndexError on an empty
        # string — callers presumably never pass one; verify.
        if operands[id][0] == 'C':
            operands[id] = operands[id][1:]

    return operands
|
| 79 |
+
|
| 80 |
+
def mid_var_solve(expr_step, ANNO, visit_list, midvar2letter):
    """Solve one LaTeX equation-step and record results in ANNO.

    The step is first rewritten with intermediate variables (V_i) replaced
    by their lowercase-letter aliases, parsed with sympy, then solved once
    per unknown letter in *visit_list* and once per intermediate variable.
    Solutions are stored in ``ANNO.mid_varible_dict``; unsolvable unknowns
    map to themselves.
    """
    # replace process(intermediate) variable with its letter alias
    for key, value in midvar2letter.items():
        expr_step = expr_step.replace(key, value)
    # Convert the latex form expression to sympy solvable form
    expr_step = parse_latex(expr_step)
    # Solving each unknown argument letter
    for letter in visit_list:
        try:
            result = solve(expr_step, letter)
            result = choose_result(result)
        except Exception:  # narrowed from bare except: still best-effort
            ANNO.mid_varible_dict[letter] = letter
            continue
        if not result is None:
            result = latex(result)
            is_update = True
            # Strip LaTeX keywords so their letters don't look like unknowns.
            result_t = result[:]
            for item in spec_token_list:
                result_t = result_t.replace(item, '')
            # more than one unknown varibles: only keep a solution expressed
            # in terms of earlier letters
            for letter_t in visit_list:
                if letter_t in result_t and letter<letter_t:
                    is_update = False
                    break
            # intermediate variables are existed: reject such solutions
            for key, value in midvar2letter.items():
                if value in result_t:
                    is_update = False
                    break
            if is_update:
                ANNO.mid_varible_dict[letter] = result
            else:
                ANNO.mid_varible_dict[letter] = letter

    # Solving process(intermediate) variable
    for key1, value1 in midvar2letter.items():
        if value1 in str(expr_step):
            result = solve(expr_step, value1)
            result = choose_result(result)
            if not result is None:
                # Convert the intermediate variable to latex form
                result = latex(result)
                # Convert lowercase letters to intermediate variables V_i
                is_update = True
                # more than one intermediate variables, only take the front intermediate variables
                for key2, value2 in midvar2letter.items():
                    if value2 in result and value1<value2:
                        is_update = False
                        break
                    # Bug fix: str.replace returns a new string; the original
                    # discarded it, leaving letter aliases in the solution.
                    result = result.replace(value2, key2)
                if is_update:
                    ANNO.mid_varible_dict[key1] = result
                else:
                    ANNO.mid_varible_dict[key1] = key1
|
| 136 |
+
|
| 137 |
+
def mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, is_subs_visit=True):
    """Back-substitute already-solved values into every pending expression.

    First collects the keys whose values are fully numeric (or, when
    *is_subs_visit* is set, any visited argument letter); then rewrites the
    remaining expressions, replacing each solved single-letter occurrence
    with its parenthesized value.  All updates go into
    ``ANNO.mid_varible_dict`` in place.
    """
    has_solved_list = []
    # Find solved process varibles and arguments
    for key, value in ANNO.mid_varible_dict.items():
        # NOTE(review): `and` binds tighter than `or`, so a visited key is
        # accepted even with an empty value — confirm this is intended.
        if value!='' and isinstance(parse_latex(value).evalf(), Float) or \
            (key in visit_list and is_subs_visit):
            if not key in midvar2letter:
                has_solved_list.append(key)
            else:
                has_solved_list.append(midvar2letter[key])

    for key, mid_var in ANNO.mid_varible_dict.items():
        # Bug fix: the original tested `value` — a stale leftover from the
        # PREVIOUS loop — instead of this loop's `mid_var`.
        if mid_var!='' and not key in has_solved_list:
            # Process varibles V_i are replaced as lowercase letters
            for key1, value1 in midvar2letter.items():
                mid_var = mid_var.replace(key1, value1)
            # Special characters are replaced with '@' for marking
            mid_var_t = mid_var[:]
            for item in spec_token_list:
                mid_var_t = mid_var_t.replace(item, "@"*len(item))
            # Lowercase letters are replaced with solved values
            mid_var_new = ''
            for id in range(len(mid_var_t)):
                if mid_var_t[id]!="@" and mid_var_t[id] in has_solved_list:
                    if mid_var_t[id] in midletter2var:
                        mid_var_new += "("+ANNO.mid_varible_dict[midletter2var[mid_var_t[id]]]+')'
                    else:
                        mid_var_new += "("+ANNO.mid_varible_dict[mid_var_t[id]]+')'
                else:
                    mid_var_new += mid_var[id]
            # Lowercase letters are replaced back with V_i
            for key2, value2 in midvar2letter.items():
                mid_var_new = mid_var_new.replace(value2, key2)
            ANNO.mid_varible_dict[key] = mid_var_new
|
| 172 |
+
|
| 173 |
+
def Get(ANNO, arg_list):
    """
    Get(a) -> evaluate the numeric value of *a* and store it, formatted
    as a '0.3f' string, into ``ANNO.result``.

    Raises Exception when called with other than exactly one argument.
    """
    if len(arg_list)!=1:
        print("<Gets> function has only 1 augment!")
        raise Exception

    if arg_list[0] in ANNO.mid_varible_dict:
        # Already-solved intermediate/argument: use its value directly.
        result = ANNO.mid_varible_dict[arg_list[0]]
    else:
        # Named variable: substitute each solved single letter into its
        # LaTeX expression before evaluating.
        result_v = ANNO.varible_dict[arg_list[0]]
        # Mask LaTeX keyword letters with '@' so they are not mistaken for
        # solvable single-letter variables.
        result_t = result_v[:]
        for item in spec_token_list:
            result_t = result_t.replace(item, "@"*len(item))
        # Lowercase letters are replaced with solved values
        result = ''
        for id in range(len(result_t)):
            if result_t[id]!="@" and result_t[id] in ANNO.mid_varible_dict:
                result += "("+ANNO.mid_varible_dict[result_t[id]]+')'
            else:
                result += result_v[id]
    ANNO.result = format(float(parse_latex(result).evalf()),'0.3f')
|
| 196 |
+
|
| 197 |
+
def Sum(arg_list):
    """
    Sum(a, b, c, d) -> a+b+c=d, returned as the LaTeX residual "a+b+c-d".
    """
    if len(arg_list) < 3:
        print("<Sum> function has 3 augments at least!")
        raise Exception
    # All operands but the last are added; the last is the stated total.
    return "+".join(arg_list[:-1]) + "-" + arg_list[-1]
|
| 209 |
+
|
| 210 |
+
def Multiple(arg_list):
    """
    Multiple(a, b, c, d, e) -> a*b*c*d=e, returned as "a*b*c*d-e".
    """
    if len(arg_list) < 3:
        print("<Product> function has 3 augments at least!")
        raise Exception
    # All operands but the last are multiplied; the last is the product.
    return "*".join(arg_list[:-1]) + "-" + arg_list[-1]
|
| 222 |
+
|
| 223 |
+
def Equal(arg_list):
    """
    Equal(a, b) -> a=b, returned as the residual "a-b".
    """
    if len(arg_list) != 2:
        print("<Equal> function has 2 augments!")
        raise Exception
    lhs, rhs = arg_list[0], arg_list[-1]
    return lhs + "-" + rhs
|
| 232 |
+
|
| 233 |
+
def Iso_Tri_Ang(arg_list):
    """
    Iso_Tri_Ang(a, b) -> a+2*b=180, returned as "a+2*b-180".
    """
    if len(arg_list) != 2:
        print("<Iso_Tri_Ang> function has 2 augments!")
        raise Exception
    return f"{arg_list[0]}+2*{arg_list[-1]}-180"
|
| 242 |
+
|
| 243 |
+
def Gougu(arg_list):
    """
    Gougu(a, b, c) -> a^2+b^2=c^2 (Pythagorean theorem), as a residual.
    """
    if len(arg_list) != 3:
        print("<Gougu> function has 3 augments!")
        raise Exception
    leg1, leg2, hyp = arg_list
    return f"{leg1}^{{2}}+{leg2}^{{2}}-{hyp}^{{2}}"
|
| 252 |
+
|
| 253 |
+
def Gsin(arg_list):
    """
    Gsin(a, b, c) -> sin(c)=a/b (angle c in degrees), as a residual.
    """
    if len(arg_list) != 3:
        print("<Gsin> function has 3 augments!")
        raise Exception
    num, den, ang = arg_list
    return f"{num}/{den}-\\sin{{{ang}/180*3.141593}}"
|
| 262 |
+
|
| 263 |
+
def Gcos(arg_list):
    """
    Gcos(a, b, c) -> cos(c)=a/b (angle c in degrees), as a residual.
    """
    if len(arg_list) != 3:
        print("<Gcos> function has 3 augments!")
        raise Exception
    num, den, ang = arg_list
    return f"{num}/{den}-\\cos{{{ang}/180*3.141593}}"
|
| 272 |
+
|
| 273 |
+
def Gtan(arg_list):
    """Right-triangle tangent: Gtan(a, b, c) -> tan(c) = a/b (c in degrees)."""
    if len(arg_list) != 3:
        print("<Gtan> function has 3 augments!")
        raise Exception
    opp, adj, ang = arg_list
    return f"{opp}/{adj}-\\tan{{{ang}/180*3.141593}}"
|
| 282 |
+
|
| 283 |
+
def Cos_Law(arg_list):
    """Law of cosines: Cos_Law(a, b, c, d) -> a^2 = b^2 + c^2 - 2*b*c*cos(d).

    d is the angle (degrees) opposite side a.
    """
    if len(arg_list) != 4:
        print("<Cos_Law> function has 4 augments!")
        raise Exception
    side_a, side_b, side_c, ang = arg_list
    return (f"{side_b}^{{2}}+{side_c}^{{2}}-{side_a}^{{2}}"
            f"-2*{side_b}*{side_c}*\\cos{{{ang}/180*3.141593}}")
|
| 293 |
+
|
| 294 |
+
def Sin_Law(arg_list):
    """Law of sines: Sin_Law(a, b, c, d) -> sin(a)/b = sin(c)/d (angles in degrees)."""
    if len(arg_list) != 4:
        print("<Sin_Law> function has 4 augments!")
        raise Exception
    ang_a, side_a, ang_b, side_b = arg_list
    return (f"{side_b}*\\sin{{{ang_a}/180*3.141593}}"
            f"-{side_a}*\\sin{{{ang_b}/180*3.141593}}")
|
| 304 |
+
|
| 305 |
+
def Median(arg_list):
    """Midpoint/median relation: Median(a, b, c) -> a + c = 2*b."""
    if len(arg_list) != 3:
        print("<Median> function has 3 augments!")
        raise Exception
    end_a, mid, end_b = arg_list
    return f"{end_a}-2*{mid}+{end_b}"
|
| 314 |
+
|
| 315 |
+
def Geo_Mean(arg_list):
    """Geometric mean: Geo_Mean(a, b, c) -> a*b = c^2."""
    if len(arg_list) != 3:
        print("<Geo_Mean> function has 3 augments!")
        raise Exception
    fac_a, fac_b, mean = arg_list
    return f"{fac_a}*{fac_b}-{mean}^{{2}}"
|
| 324 |
+
|
| 325 |
+
def Proportion(arg_list):
    """Proportion(a, b, c, d) -> a/b = c/d; Proportion(a, b, c, d, e) -> (a/b)^e = c/d.

    Cross-multiplied residual form to avoid division.
    """
    if len(arg_list) < 4:
        print("<Proportion> function has 4 augments at least!")
        raise Exception
    a, b, c, d = arg_list[:4]
    if len(arg_list) == 4:
        return f"{a}*{d}-{b}*{c}"
    e = arg_list[4]
    return f"{a}*{d}^{{1/{e}}}-{b}*{c}^{{1/{e}}}"
|
| 338 |
+
|
| 339 |
+
def Ratio(arg_list):
    """Ratio(a, b, c) -> a/b = c; Ratio(a, b, c, d) -> (a/b)^c = d."""
    if len(arg_list) < 3 or len(arg_list) > 4:
        print("<Power> function has 3 or 4 augments!")
        raise Exception
    if len(arg_list) == 3:
        numer, denom, ratio = arg_list
        return f"{numer} / {denom}-{ratio}"
    numer, denom, power, ratio = arg_list
    return f"({numer} / {denom})^{{{power}}}-{ratio}"
|
| 352 |
+
|
| 353 |
+
def Chord2_Ang(arg_list):
    """Two-chord angle: Chord2_Ang(a, b, c) -> a = (b + c)/2."""
    if len(arg_list) != 3:
        print("<Chord2_Ang> function has 3 augments!")
        raise Exception
    ang, arc_a, arc_b = arg_list
    return f"{ang}*2-{arc_a}-{arc_b}"
|
| 362 |
+
|
| 363 |
+
def TanSec_Ang(arg_list):
    """Tangent-secant angle: TanSec_Ang(a, b, c) -> a = (c - b)/2."""
    if len(arg_list) != 3:
        print("<TanSec_Ang> function has 3 augments!")
        raise Exception
    ang, near_arc, far_arc = arg_list
    return f"{ang}*2+{near_arc}-{far_arc}"
|
| 372 |
+
|
| 373 |
+
def Tria_BH_Area(arg_list):
    """Triangle area (base/height): Tria_BH_Area(a, b, c) -> a*b/2 = c."""
    if len(arg_list) != 3:
        print("<Tria_BH_Area> function has 3 augments!")
        raise Exception
    base, height, area = arg_list
    return f"{base}*{height}*0.5-{area}"
|
| 382 |
+
|
| 383 |
+
def Tria_SAS_Area(arg_list):
    """Triangle area (SAS): Tria_SAS_Area(a, b, c, d) -> a*c*sin(b)/2 = d."""
    if len(arg_list) != 4:
        print("<Tria_SAS_Area> function has 4 augments!")
        raise Exception
    side_a, ang, side_b, area = arg_list
    return f"{side_a}*{side_b}*0.5*\\sin{{{ang}/180*3.141593}}-{area}"
|
| 392 |
+
|
| 393 |
+
def PRK_Perim(arg_list):
    """Parallelogram/rect/kite perimeter: PRK_Perim(a, b, c) -> 2*(a + b) = c."""
    if len(arg_list) != 3:
        print("<PRK_Perim> function has 3 augments!")
        raise Exception
    side_a, side_b, perim = arg_list
    return f"{side_a}*2+{side_b}*2-{perim}"
|
| 402 |
+
|
| 403 |
+
def Para_Area(arg_list):
    """Parallelogram area: Para_Area(a, b, c) -> a*b = c."""
    if len(arg_list) != 3:
        print("<Para_Area> function has 3 augments!")
        raise Exception
    base, height, area = arg_list
    return f"{base}*{height}-{area}"
|
| 412 |
+
|
| 413 |
+
def Rect_Area(arg_list):
    """Rectangle area: Rect_Area(a, b, c) -> a*b = c."""
    if len(arg_list) != 3:
        print("<Rect_Area> function has 3 augments!")
        raise Exception
    width, height, area = arg_list
    return f"{width}*{height}-{area}"
|
| 422 |
+
|
| 423 |
+
def Rhom_Area(arg_list):
    """Rhombus area from half-diagonals: Rhom_Area(a, b, c) -> 2*a*b = c."""
    if len(arg_list) != 3:
        # NOTE(review): message says <Phom_Area> (typo kept verbatim).
        print("<Phom_Area> function has 3 augments!")
        raise Exception
    diag_a, diag_b, area = arg_list
    return f"{diag_a}*{diag_b}*2-{area}"
|
| 432 |
+
|
| 433 |
+
def Kite_Area(arg_list):
    """Kite area: Kite_Area(a, b, c) -> a*b/2 = c."""
    if len(arg_list) != 3:
        print("<Kite_Area> function has 3 augments!")
        raise Exception
    diag_a, diag_b, area = arg_list
    return f"{diag_a}*{diag_b}*0.5-{area}"
|
| 442 |
+
|
| 443 |
+
def Trap_Area(arg_list):
    """Trapezoid area: Trap_Area(a, b, c, d) -> (a + b)*c/2 = d."""
    if len(arg_list) != 4:
        print("<Trap_Area> function has 4 augments!")
        raise Exception
    base_a, base_b, height, area = arg_list
    return f"0.5*({base_a}+{base_b})*{height}-{area}"
|
| 452 |
+
|
| 453 |
+
def Circle_R_Circum(arg_list):
    """Circumference from radius: (a, b) -> 2*pi*a = b;
    (a, b, c) -> arc length 2*pi*a*b/360 = c (b in degrees)."""
    if not 2 <= len(arg_list) <= 3:
        print("<Circle_Circum> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        radius, circum = arg_list
        return f"2*3.141593*{radius}-{circum}"
    radius, ang, arc = arg_list
    return f"2*3.141593*{radius}*{ang}/360-{arc}"
|
| 466 |
+
|
| 467 |
+
def Circle_D_Circum(arg_list):
    """Circumference from diameter: (a, b) -> pi*a = b;
    (a, b, c) -> arc length pi*a*b/360 = c (b in degrees)."""
    if not 2 <= len(arg_list) <= 3:
        print("<Circle_Circum> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        diam, circum = arg_list
        return f"3.141593*{diam}-{circum}"
    diam, ang, arc = arg_list
    return f"3.141593*{diam}*{ang}/360-{arc}"
|
| 480 |
+
|
| 481 |
+
def Circle_R_Area(arg_list):
    """Circle area from radius: Circle_R_Area(a, b) -> pi*a^2 = b;
    Circle_R_Area(a, b, c) -> sector area pi*a^2*b/360 = c (b in degrees).

    Returns the residual expression string (zero when the relation holds).
    Raises Exception when the argument count is not 2 or 3.
    """
    # BUG FIX: original guard was `len(arg_list)<2 and len(arg_list)>3`,
    # which is always False, so malformed argument lists were never rejected.
    # Sibling functions (Circle_R_Circum etc.) use `or`; match them.
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Area> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        expr_step = '3.141593*'+arg_list[0]+'^{2}-'+arg_list[1]
    else:
        expr_step = '3.141593*'+arg_list[0]+'^{2}*'+arg_list[1]+'/360'+'-'+arg_list[2]
    return expr_step
|
| 494 |
+
|
| 495 |
+
def Circle_D_Area(arg_list):
    """Circle area from diameter: Circle_D_Area(a, b) -> pi*(a/2)^2 = b;
    Circle_D_Area(a, b, c) -> sector area pi*(a/2)^2*b/360 = c (b in degrees).

    Returns the residual expression string (zero when the relation holds).
    Raises Exception when the argument count is not 2 or 3.
    """
    # BUG FIX: original guard was `len(arg_list)<2 and len(arg_list)>3`,
    # which is always False, so malformed argument lists were never rejected.
    # Sibling functions (Circle_D_Circum etc.) use `or`; match them.
    if len(arg_list) < 2 or len(arg_list) > 3:
        print("<Circle_Area> function has 2 or 3 augments!")
        raise Exception
    if len(arg_list) == 2:
        expr_step = '0.25*3.141593*'+arg_list[0]+'^{2}-'+arg_list[1]
    else:
        expr_step = '0.25*3.141593*'+arg_list[0]+'^{2}*'+arg_list[1]+'/360'+'-'+arg_list[2]
    return expr_step
|
| 508 |
+
|
| 509 |
+
def ArcSeg_Area(arg_list):
    """Circular segment area: ArcSeg_Area(a, b, c) ->
    pi*a^2*b/360 - a^2*sin(b)/2 = c (b in degrees)."""
    if len(arg_list) != 3:
        print("<ArcSeg_Area> function has 3 augments!")
        raise Exception
    radius, ang, area = arg_list
    return (f"3.141593*{radius}^{{2}}*{ang}/360-0.5*"
            f"{radius}^{{2}}*\\sin{{{ang}/180*3.141593}}-{area}")
|
| 519 |
+
|
| 520 |
+
def Ngon_Angsum(arg_list):
    """Polygon interior-angle sum: Ngon_Angsum(a, b) -> (a - 2)*180 = b."""
    if len(arg_list) != 2:
        print("<Ngon_Ang> function has 2 augments!")
        raise Exception
    n_sides, total = arg_list
    return f"({n_sides}-2)*180-{total}"
|
| 529 |
+
|
| 530 |
+
def RNgon_B_Area(arg_list):
    """Regular n-gon area from side length: RNgon_B_Area(a, b, c) ->
    a*b^2/(4*tan(180/a)) = c."""
    if len(arg_list) != 3:
        print("<RNgon_B_Area> function has 3 augments!")
        raise Exception
    n_sides, side, area = arg_list
    return f"{n_sides}*{side}^{{2}}/4/\\tan{{3.141593/{n_sides}}}-{area}"
|
| 539 |
+
|
| 540 |
+
def RNgon_L_Area(arg_list):
    """Regular n-gon area from circumradius: RNgon_L_Area(a, b, c) ->
    a*b^2*sin(360/a)/2 = c."""
    if len(arg_list) != 3:
        print("<RNgon_L_Area> function has 3 augments!")
        raise Exception
    n_sides, circ_r, area = arg_list
    return f"{n_sides}*{circ_r}^{{2}}*0.5*\\sin{{2*3.141593/{n_sides}}}-{area}"
|
| 549 |
+
|
| 550 |
+
def RNgon_H_Area(arg_list):
    """Regular n-gon area from apothem: RNgon_H_Area(a, b, c) ->
    a*b^2*tan(180/a) = c."""
    if len(arg_list) != 3:
        print("<RNgon_H_Area> function has 3 augments!")
        raise Exception
    n_sides, apothem, area = arg_list
    return f"{n_sides}*{apothem}^{{2}}*\\tan{{3.141593/{n_sides}}}-{area}"
|
| 559 |
+
|
| 560 |
+
def result_compute(num_all_list, exp_tokens):
    """Execute a solution program and return its numeric result.

    `num_all_list` holds the problem's number strings (possibly containing
    lowercase argument letters); `exp_tokens` is the flat operator/operand
    token list.  NOTE(review): relies on module-level helpers and constants
    defined elsewhere in this file (Varible_Record, is_exist_operator,
    operand_update, mid_var_solve, mid_var_update, Get, spec_token_list,
    low_case_list, spec_letter_list, V_NUM, arith_op_list).
    """
    ANNO = Varible_Record()
    # Obtain the mapping between lowercase letters to intermediate variables V_i
    visit_list = [] # arguments denoted by lowercase letters
    for num in num_all_list:
        for item in spec_token_list:
            # mask multi-letter tokens (frac/pi/sqrt) so their letters
            # are not mistaken for argument letters
            num = num.replace(item, "@"*len(item))
        for letter in num:
            if letter in low_case_list: visit_list.append(letter)
    for id, var in enumerate(num_all_list):
        ANNO.varible_dict["N"+str(id)] = var
        ANNO.mid_varible_dict["N"+str(id)] = var
    visit_list.sort()
    no_visit_list = list(set(low_case_list)-set(spec_letter_list)-set(visit_list))
    no_visit_list.sort() # lowercase letters which have not used
    # mapping between letters to intermediate variables V_i
    midvar2letter = dict()
    midletter2var = dict()
    for id in range(V_NUM):
        midvar2letter['V'+str(id)] = no_visit_list[id]
        midletter2var[no_visit_list[id]] = 'V'+str(id)
    # step split
    # a new step starts at every arithmetic operator token
    step_list = []
    last_op_id = 0
    for id, token in enumerate(exp_tokens):
        if token in arith_op_list and id>0:
            step_list.append(exp_tokens[last_op_id:id])
            last_op_id = id
    step_list.append(exp_tokens[last_op_id:])
    # run step
    for id, step in enumerate(step_list):
        operator = is_exist_operator(step[0], ANNO)
        if operator!='Get':
            operands = operand_update(step[1:], ANNO)
            # NOTE(review): eval maps the operator token to the same-named
            # builder function in this module; safe only for trusted tokens.
            expr_step = eval(operator)(operands)
            mid_var_solve(expr_step, ANNO, visit_list, midvar2letter)
            mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, True)
            mid_var_update(ANNO, visit_list, midvar2letter, midletter2var, False)
        else:
            Get(ANNO, step[1:])

    return ANNO.result
|
| 602 |
+
|
| 603 |
+
def normalize_exp(exp):
    """Canonicalize an expression token list so equivalent programs compare equal.

    Splits `exp` into operator-led steps (a step starts at each token found in
    the module-level `arith_op_list`), then sorts/swaps interchangeable
    operands of order-insensitive operators (`alterable_order_ops`) into a
    fixed order given by `get_priority`.  Returns the re-flattened token list;
    `exp` itself is not modified (steps are slice copies).
    """
    # step split
    step_list = []
    last_op_id = 0
    for id, token in enumerate(exp):
        if token in arith_op_list and id>0:
            step_list.append(exp[last_op_id:id])
            last_op_id = id
    step_list.append(exp[last_op_id:])
    # normalize step
    new_exp = []
    for step in step_list:
        if step[0] in alterable_order_ops:
            if step[0] in ['Sum', 'Multiple']:
                # all operands except the trailing result are order-free
                begin_id, end_id = 1, -1
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] in ['Equal', 'Gougu', 'PRK_Perim', 'Rect_Area', 'Rhom_Area', 'Trap_Area']:
                # the first two operands are symmetric
                begin_id, end_id = 1, 3
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] == 'Cos_Law':
                # the two adjacent sides (operands 2-3) are symmetric
                begin_id, end_id = 2, 4
                step[begin_id: end_id] = sorted(step[begin_id: end_id], key=lambda token:get_priority(token))
            if step[0] in ['Sin_Law', 'Proportion']:
                # the two operand pairs may be swapped as units
                if get_priority(step[1])>get_priority(step[3]) and len(step)==5:
                    step[1:3], step[3:5] = step[3:5], step[1:3]
            if step[0] in ['Tria_SAS_Area', 'Median']:
                # operands 1 and 3 are interchangeable
                if get_priority(step[1])>get_priority(step[3]):
                    step[1], step[3] = step[3], step[1]
        new_exp += step

    return new_exp
|
datasets/preprossing.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
from datasets.utils import *
|
| 4 |
+
|
| 5 |
+
class SrcLang:
    """Source-side vocabulary for problem text: token <-> index lookup plus
    the fixed class-tag and section-tag label sets."""

    def __init__(self, vocab_path):
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.get_vocab(vocab_path)
        self.class_tag = ['[PAD]', '[GEN]', '[POINT]', '[NUM]', '[ARG]', '[ANGID]']
        self.sect_tag = ['[PAD]', '[PROB]', '[COND]', '[STRU]']

    def get_vocab(self, vocab_path):
        """Load the vocabulary file, one token per line (newline stripped)."""
        with open(vocab_path, 'r') as vocab_file:
            for idx, line in enumerate(vocab_file):
                token = line[:-1]
                self.word2index[token] = idx
                self.word2count[token] = 0
                self.index2word.append(token)
        self.n_words = len(self.index2word)

    def indexes_from_sentence(self, sentence, id_type='text'):
        """Convert tokens to ids; unknown text tokens fall back to [UNK]."""
        if id_type == 'class_tag':
            return [self.class_tag.index(word) for word in sentence]
        if id_type == 'sect_tag':
            return [self.sect_tag.index(word) for word in sentence]
        ids = []
        if id_type == 'text':
            for word in sentence:
                if word in self.word2index:
                    ids.append(self.word2index[word])
                    self.word2count[word] += 1
                else:
                    ids.append(self.word2index["[UNK]"])
                    self.word2count["[UNK]"] += 1
                    print("Can not find", word, 'in the src vocab')
        return ids

    def sentence_from_indexes(self, indexes):
        """Convert ids back to tokens; out-of-range ids become ''."""
        return [self.index2word[idx] if idx < len(self.index2word) else ""
                for idx in indexes]
|
| 50 |
+
|
| 51 |
+
class TgtLang:
    """Target-side vocabulary for solution programs.

    Vocabulary tokens fall into five groups counted at load time: special
    ([X]), intermediate variables (V#), constants (C#), problem numbers (N#)
    and operators; `var_start` is the index where the N#/argument slots begin.
    """

    def __init__(self, vocab_path):
        self.word2index = {}
        self.word2count = {}
        self.index2word = []
        self.n_words = 0
        self.var_start = 0
        self.get_vocab(vocab_path)

    def get_vocab(self, vocab_path):
        """Load the vocabulary and count how many tokens fall in each group."""
        counts = {'spe': 0, 'midvar': 0, 'const': 0, 'var': 0, 'op': 0}
        with open(vocab_path, 'r') as vocab_file:
            for idx, line in enumerate(vocab_file):
                token = line[:-1]
                self.word2index[token] = idx
                self.word2count[token] = 0
                self.index2word.append(token)
                if token[0] == '[' and token[-1] == ']':
                    counts['spe'] += 1
                elif token[0] == 'V' and token[1].isdigit():
                    counts['midvar'] += 1
                elif token[0] == 'C' and token[1].isdigit():
                    counts['const'] += 1
                elif token[0] == 'N' and token[1].isdigit():
                    counts['var'] += 1
                else:
                    counts['op'] += 1
        self.n_words = len(self.index2word)
        # Everything before the N# slots is fixed vocabulary.
        self.var_start = counts['spe'] + counts['midvar'] + counts['const'] + counts['op']

    def indexes_from_sentence(self, sentence, var_values, arg_values):
        """Encode program tokens; single lowercase letters (problem arguments)
        map to slots after the N# variables.  Output is wrapped [SOS]..[EOS]."""
        body = []
        for word in sentence:
            if word in self.word2index:
                body.append(self.word2index[word])
                self.word2count[word] += 1
            elif len(word) == 1 and word.islower(): # arg
                body.append(self.var_start + len(var_values) + arg_values.index(word))
            else:
                print("Can not find", word, 'in the tgt vocab')
        return [self.word2index["[SOS]"]] + body + [self.word2index["[EOS]"]]

    def sentence_from_indexes(self, indexes, change_dict={}):
        """Decode ids back to tokens; out-of-range ids become '' and
        `change_dict` remaps tokens (e.g. V# back to argument letters)."""
        tokens = []
        for idx in indexes:
            token = self.index2word[idx] if idx < len(self.index2word) else ''
            tokens.append(change_dict.get(token, token)) # var2arg
        return tokens
|
| 108 |
+
|
| 109 |
+
class SN:
    """Sequence container: raw tokens with two aligned tag sequences."""

    def __init__(self):
        self.token = []      # str list
        self.sect_tag = []   # [PROB]/[COND]/[STRU]
        self.class_tag = []  # [GEN]/[NUM]/[ARG]/[POINT]/[ANGID]
|
| 114 |
+
|
| 115 |
+
def get_raw_pairs(dataset_path):
    """Load a PGPS json dataset and preprocess every problem in place.

    Tokenizes the text and the diagram parsing sequences, attaches section
    and class tags, resolves [NUM]/[ARG] tags and duplicate numbers, and
    returns the list of enriched problem dicts (each gains an 'id' key).
    """
    with open(dataset_path, 'r') as fp:
        content_all = json.load(fp)

    raw_pairs = []
    for prob_id, content in content_all.items():
        text_data, stru_data, sem_data = SN(), SN(), SN()
        # tokenization (each parsing sequence ends with an explicit ',')
        text_data.token = get_token(content['text'])
        stru_data.token = [get_token(seq) + [','] for seq in content['parsing_stru_seqs']]
        sem_data.token = [get_token(seq) + [','] for seq in content['parsing_sem_seqs']]
        # section tags; split_text divides the text into problem/condition parts
        text_data.sect_tag = []
        stru_data.sect_tag = [['[STRU]'] * len(seq) for seq in stru_data.token]
        sem_data.sect_tag = [['[COND]'] * len(seq) for seq in sem_data.token]
        split_text(text_data)
        # class tags start as [GEN] and are then specialized
        text_data.class_tag = ['[GEN]'] * len(text_data.token)
        stru_data.class_tag = [['[GEN]'] * len(seq) for seq in stru_data.token]
        sem_data.class_tag = [['[GEN]'] * len(seq) for seq in sem_data.token]
        get_point_angleID_tag(text_data, stru_data, sem_data)
        get_num_arg_tag(text_data, sem_data)
        # Tag the repeat [NUM] in sem_data which already exists in text_data
        expression = content['expression'].split(' ')
        remove_sem_dup(text_data, sem_data, expression)

        content['text'] = text_data
        content['parsing_stru_seqs'] = stru_data
        content['parsing_sem_seqs'] = sem_data
        content['expression'] = expression
        content['id'] = prob_id
        raw_pairs.append(content)

    return raw_pairs
|
| 155 |
+
|
| 156 |
+
class collater():
    """Batch collate function: stacks diagram tensors and pads/tensorizes the
    variable-length text, variable-position and expression fields."""

    def __init__(self, args):
        self.args = args

    def __call__(self, batch_data, padding_id=0):
        (diagrams,
         text_tokens, text_sect_tags, text_class_tags,
         var_arg_positions, var_values, arg_values,
         expression, answer, pair_ids, choices) = list(zip(*batch_data))
        #######################################
        diagrams = torch.stack(diagrams, dim=0)
        ####################################### expression: right-pad to max length
        len_exp = [len(seq) for seq in expression]
        pad_to = max(len_exp)
        expression = [seq + [padding_id] * (pad_to - len(seq)) for seq in expression]
        exp_dict = {'exp': torch.LongTensor(expression),
                    'len': torch.LongTensor(len_exp),
                    'answer': answer,
                    'id': pair_ids,
                    'choices': choices
                    }
        ####################################### variable positions (length clamped to >= 1)
        len_var = [max(len(seq), 1) for seq in var_arg_positions]
        pad_to = max(len_var)
        var_arg_positions = [seq + [padding_id] * (pad_to - len(seq)) for seq in var_arg_positions]
        var_dict = {'pos': torch.LongTensor(var_arg_positions),
                    'len': torch.LongTensor(len_var),
                    'var_value': var_values,
                    'arg_value': arg_values
                    }
        ######################################## text: pad every token channel in place
        len_text = [len(seq) for seq in text_class_tags]
        pad_to = max(len_text)
        for channels in text_tokens:
            for channel in channels:
                channel += [padding_id] * (pad_to - len(channel))
        text_sect_tags = [seq + [padding_id] * (pad_to - len(seq)) for seq in text_sect_tags]
        text_class_tags = [seq + [padding_id] * (pad_to - len(seq)) for seq in text_class_tags]
        text_dict = {'token': torch.LongTensor(text_tokens),
                     'sect_tag': torch.LongTensor(text_sect_tags),
                     'class_tag': torch.LongTensor(text_class_tags),
                     'len': torch.LongTensor(len_text)
                     }

        return diagrams, text_dict, var_dict, exp_dict
|
datasets/text_aug.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
# Alphabets and id ranges used by the text-augmentation transforms.
upper_case_list = [chr(code) for code in range(ord('A'), ord('Z') + 1)]  # point labels
low_case_list = [chr(code) for code in range(ord('a'), ord('z') + 1)]    # argument letters
angle_id_list = [str(n) for n in range(1, 21)]                           # angle ids 1-20
spec_token_list = ['frac', 'pi', 'sqrt']  # multi-letter math tokens to mask
|
| 7 |
+
|
| 8 |
+
class Compose(object):
    """Apply a sequence of augmentation transforms to one sample in place."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        # Each transform mutates the sample in place; nothing is returned.
        for transform in self.transforms:
            transform(text_seq, stru_seqs, sem_seqs, exp)

    def __repr__(self):
        lines = [self.__class__.__name__ + "("]
        for transform in self.transforms:
            lines.append("    {0}".format(transform))
        return "\n".join(lines) + "\n)"
|
| 24 |
+
|
| 25 |
+
class Point_RandomReplace(object):
    """With probability `prob`, consistently rename all point labels (A-Z) in
    the text, structural and semantic sequences via one random bijection."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_point_map(self):
        """Return a random bijection over the uppercase point alphabet."""
        shuffled = [chr(code) for code in range(65, 91)]
        random.shuffle(shuffled)
        return dict(zip(upper_case_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_point_map()
        for k, tag in enumerate(text_seq.class_tag):
            if tag == '[POINT]':
                text_seq.token[k] = mapping[text_seq.token[k][0]]
        for seq_tags, seq_tokens in zip(stru_seqs.class_tag, stru_seqs.token):
            for j, tag in enumerate(seq_tags):
                if tag == '[POINT]':
                    seq_tokens[j] = mapping[seq_tokens[j][0]]
        for seq_tags, seq_tokens in zip(sem_seqs.class_tag, sem_seqs.token):
            for j, tag in enumerate(seq_tags):
                if tag == '[POINT]':
                    seq_tokens[j] = mapping[seq_tokens[j][0]]
|
| 50 |
+
|
| 51 |
+
class AngID_RandomReplace(object):
    """With probability `prob`, consistently renumber angle ids (1-20) in the
    text and semantic sequences via one random bijection."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_angid_map(self):
        """Return a random bijection over the angle-id strings."""
        shuffled = [str(n) for n in range(1, 21)]
        random.shuffle(shuffled)
        return dict(zip(angle_id_list, shuffled))

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_angid_map()
        for k, tag in enumerate(text_seq.class_tag):
            if tag == '[ANGID]':
                text_seq.token[k] = mapping[text_seq.token[k]]
        for seq_tags, seq_tokens in zip(sem_seqs.class_tag, sem_seqs.token):
            for j, tag in enumerate(seq_tags):
                if tag == '[ANGID]':
                    seq_tokens[j] = mapping[seq_tokens[j]]
|
| 72 |
+
|
| 73 |
+
class Arg_RandomReplace(object):
    """With probability `prob`, consistently rename argument letters (a-z)
    inside [NUM]/[ARG] tokens and in the expression via one random bijection."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_arg_map(self):
        """Return a random bijection over the lowercase argument alphabet."""
        shuffled = [chr(code) for code in range(97, 123)]
        random.shuffle(shuffled)
        return dict(zip(low_case_list, shuffled))

    def map_arg_in_num(self, map_dict, num):
        """Remap argument letters inside a number string, skipping letters
        that belong to multi-letter math tokens (frac/pi/sqrt)."""
        masked = num[:]
        for spec in spec_token_list:
            masked = masked.replace(spec, "@" * len(spec))
        out = ''
        for ch_masked, ch in zip(masked, num):
            if ch_masked != '@' and ch in low_case_list:
                out += map_dict[ch]
            else:
                out += ch
        return out

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        mapping = self.get_arg_map()
        for k, tag in enumerate(text_seq.class_tag):
            if tag == '[NUM]':
                text_seq.token[k] = self.map_arg_in_num(mapping, text_seq.token[k])
            if tag == '[ARG]':
                text_seq.token[k] = mapping[text_seq.token[k]]
        for seq_tags, seq_tokens in zip(sem_seqs.class_tag, sem_seqs.token):
            for j, tag in enumerate(seq_tags):
                if tag == '[NUM]':
                    seq_tokens[j] = self.map_arg_in_num(mapping, seq_tokens[j])
        for k, token in enumerate(exp):
            if token in low_case_list:
                exp[k] = mapping[token]
|
| 111 |
+
|
| 112 |
+
class StruPoint_RandomRotate(object):
    """Randomly reorders point labels inside each structural sequence: points
    on a line may be reversed; points on a circle may be reversed and/or
    rotated to a new starting point."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return the [begin, end) bounds of the last [POINT] span.

        Points are paired two by two; a trailing unpaired point opens a span
        that extends to the end of the sequence.
        """
        spans = []
        begin = end = None
        for idx, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if begin is None:
                    begin = idx
                elif end is None:
                    end = idx
                    spans.append([begin, end])
                    begin = end = None
        if begin is not None and end is None:
            spans.append([begin, len(class_tag)])

        return spans[-1][0], spans[-1][1]

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        for k, tokens in enumerate(stru_seqs.token):
            if random.random() >= self.prob:
                continue
            begin_id, end_id = self.get_seq_points(stru_seqs.class_tag[k])
            # point on line: reverse the point order
            if tokens[0] == 'line':
                tokens[begin_id:end_id] = tokens[end_id - 1:begin_id - 1:-1]
            # point on circle: optionally flip orientation, then rotate
            if tokens[0] == '\\odot':
                if random.random() < 0.5:  # clockwise change
                    tokens[begin_id:end_id] = tokens[end_id - 1:begin_id - 1:-1]
                init_loc = random.randint(begin_id, end_id - 1)  # new initial point
                tokens[begin_id:end_id] = tokens[init_loc:end_id] + tokens[begin_id:init_loc]
|
| 149 |
+
|
| 150 |
+
class SemPoint_RandomRotate(object):
    """Randomly swaps the two endpoint labels of each point pair found in the
    semantic sequences (e.g. segment AB -> BA)."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def get_seq_points(self, class_tag):
        """Return (first, last) index pairs for consecutive [POINT] pairs; a
        trailing unpaired point pairs with the final sequence position."""
        pairs = []
        begin = end = None
        for idx, tag in enumerate(class_tag):
            if tag == '[POINT]':
                if begin is None:
                    begin = idx
                elif end is None:
                    end = idx
                    pairs.append((begin, end - 1))
                    begin = end = None
        if begin is not None and end is None:
            pairs.append((begin, len(class_tag) - 1))

        return pairs

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        for k, tokens in enumerate(sem_seqs.token):
            for begin_id, end_id in self.get_seq_points(sem_seqs.class_tag[k]):
                if random.random() < self.prob:
                    tokens[begin_id], tokens[end_id] = tokens[end_id], tokens[begin_id]
|
| 179 |
+
|
| 180 |
+
class SemSeq_RandomRotate(object):
    """Randomly shuffles the semantic clauses and renumbers the N-variables
    of the target expression accordingly, so expression ids keep pointing at
    the right numbers after the shuffle."""

    def __init__(self, prob=0.5):
        # prob==0 disables the transform; otherwise the probability is
        # boosted by 0.2 over the requested value.
        if prob==0:
            self.prob = 0
        else:
            self.prob = prob + 0.2

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() < self.prob:
            # variable id: assign N0, N1, ... first to the text numbers, then
            # to each semantic clause that ends in a number.
            num_all_list, num_sem_list, num_map_list = [], [], []
            for item in text_seq.class_tag:
                if item=='[NUM]':
                    var_name = 'N'+str(len(num_all_list))
                    num_all_list.append(var_name)
                    num_map_list.append(var_name)
            for k in range(len(sem_seqs.token)):
                # Clause layout puts its numeric token second-to-last.
                if sem_seqs.class_tag[k][-2] == '[NUM]':
                    var_name = 'N'+str(len(num_all_list))
                    num_all_list.append(var_name)
                    num_sem_list.append([var_name])
                else:
                    num_sem_list.append([])
            # shuffle sem_seq: one shared permutation keeps every parallel
            # attribute (token, class_tag, ...) and num_sem_list aligned.
            if len(sem_seqs.token)>0:
                random_id_list = [k for k in range(len(sem_seqs.token))]
                random.shuffle(random_id_list)
                for key,value in vars(sem_seqs).items():
                    _, value = zip(*sorted(zip(random_id_list, value)))
                    setattr(sem_seqs, key, list(value))
                _, num_sem_list = zip(*sorted(zip(random_id_list, num_sem_list)))
            # expression map: old variable name -> new name in shuffled order.
            for k in range(len(sem_seqs.token)):
                num_map_list += num_sem_list[k]
            num_map_dict = {key:value for key, value in zip(num_map_list, num_all_list)}
            for k in range(len(exp)):
                if exp[k] in num_map_dict:
                    exp[k] = num_map_dict[exp[k]]
|
| 219 |
+
|
| 220 |
+
class StruSeq_RandomRotate(object):
    """Randomly permutes the order of the structural clauses.

    Every attribute of ``stru_seqs`` (token, class_tag, ...) is reordered
    with the same permutation so the parallel lists stay aligned.
    """

    def __init__(self, prob=0.5):
        # Probability of shuffling the clause order.
        self.prob = prob

    def __call__(self, text_seq, stru_seqs, sem_seqs, exp):
        if random.random() >= self.prob:
            return
        n_seqs = len(stru_seqs.token)
        if n_seqs == 0:
            return
        # One shared permutation keeps every parallel attribute aligned.
        order = list(range(n_seqs))
        random.shuffle(order)
        for attr, values in vars(stru_seqs).items():
            reordered = [v for _, v in sorted(zip(order, values))]
            setattr(stru_seqs, attr, reordered)
|
datasets/utils.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Punctuation marks split off from the end of a word during tokenization.
punctuation_list = ['.', '?', ',']
# The ten ASCII digit characters.
digit_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# 'A'-'Z': capital letters denote geometric points.
capital_letter_list = [chr(item) for item in range(65, 91)]
# 'a'-'z': lower-case letters may denote algebraic variables/arguments.
low_letter_list = [chr(item) for item in range(97, 123)]
# Words that usually open the question ("problem") part of the text.
begin_words = ["find", "what", "solve", "determine", "express", "how"]
# Tokens that usually terminate the question part.
end_words = [".", ",", '?', "if", "so", "for which", "given", "with", "on",
             "in", "must", 'for', 'that', 'formed']
# Area-unit symbols (treated as plain words, never as numbers).
unit_list = ["mm^{2}", "cm^{2}", "in^{2}", "ft^{2}",
             "yd^{2}", "km^{2}", "units^{2}", "mi^{2}", "m^{2}"]
# LaTeX/operator fragments stripped before extracting variable letters.
special_token_list = ['\\frac', '\\pi', '\\sqrt', "+", "-", "^"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_token(ss):
    """
    Tokenizer: divide the textual problem into words.

    Splits on spaces, detaches trailing punctuation into its own token,
    expands runs of capital letters (optionally extended by primes or
    digits, e.g. "A'" or "B1") into individual point tokens, and
    lower-cases everything else.

    Args:
        ss: raw problem text.
    Returns:
        list[str]: the token sequence.
    """
    # Drop empty fragments produced by consecutive spaces; the original
    # code indexed item[-1] and crashed on them.
    raw_str_list = [item for item in ss.strip().split(' ') if item]
    # Split trailing punctuation into its own token.
    new_str1_list = []
    for item in raw_str_list:
        if item[-1] in punctuation_list:
            new_str1_list.append(item[:-1])
            new_str1_list.append(item[-1])
        else:
            new_str1_list.append(item)
    # Split geometric representations (runs of capital letters) into points.
    new_str2_list = []
    for item in new_str1_list:
        is_geo_rep = True
        point_list = []
        for k in item:
            if (ord(k) >= 65 and ord(k) <= 90) or \
                    ((k == '\'' or k in digit_list) and len(point_list) > 0):
                if k == '\'' or k in digit_list:
                    # Primes/digits extend the previous point name (e.g. A').
                    point_list[-1] += k
                else:
                    point_list.append(k)
            else:
                is_geo_rep = False
                break
        if is_geo_rep:
            new_str2_list += point_list
        else:
            new_str2_list.append(item.lower())

    return new_str2_list
|
| 47 |
+
|
| 48 |
+
def split_text(text_data):
    """
    split textual problem into condition and problem(target)
    """
    # Nothing to tag for an empty problem.
    if len(text_data.token) == 0:
        return
    begin_ind = 0
    end_ind = len(text_data.token)
    # The problem part starts at the first interrogative keyword.
    for id, token in enumerate(text_data.token):
        if token in begin_words:
            begin_ind = id
            break
    # ...and ends at the next terminator. The scan starts at begin_ind+2 so
    # a terminator immediately following the keyword is not matched.
    for id in range(begin_ind+2, len(text_data.token)):
        if text_data.token[id] in end_words:
            if text_data.token[id] in punctuation_list:
                # Keep trailing punctuation inside the problem span.
                end_ind = id + 1
            else:
                end_ind = id
            break
    # Mark each token as condition ('[COND]') or problem ('[PROB]').
    text_data.sect_tag = ['[COND]']*len(text_data.token[:begin_ind]) + \
                         ['[PROB]']*len(text_data.token[begin_ind: end_ind]) + \
                         ['[COND]']*len(text_data.token[end_ind:])
|
| 70 |
+
|
| 71 |
+
def _tag_point_and_angle(tokens, class_tags):
    """Tag capital-letter tokens as '[POINT]' and bare numbers after
    '\\angle' as '[ANGID]', in place."""
    for id, item in enumerate(tokens):
        if item[0] in capital_letter_list:
            class_tags[id] = '[POINT]'
        # A bare number right after '\angle' names the angle, it is not a value.
        if item.isdigit() and id > 0 and tokens[id-1] == "\\angle":
            class_tags[id] = '[ANGID]'


def get_point_angleID_tag(text_data, stru_data, sem_data):
    """Tag point and angle-id tokens in the text, structural and semantic
    sequences.

    The text sequence is flat; structural/semantic sequences are lists of
    clauses and are tagged clause by clause. Tags are written in place.
    """
    _tag_point_and_angle(text_data.token, text_data.class_tag)

    for k in range(len(stru_data.token)):
        _tag_point_and_angle(stru_data.token[k], stru_data.class_tag[k])

    for k in range(len(sem_data.token)):
        _tag_point_and_angle(sem_data.token[k], sem_data.class_tag[k])
|
| 91 |
+
|
| 92 |
+
def get_args(token):
    """Return the distinct lower-case letters of *token* in first-seen order.

    LaTeX/operator fragments (\\frac, \\sqrt, ...) are stripped first so
    their letters are not mistaken for variables.
    """
    stripped = token
    for fragment in special_token_list:
        stripped = stripped.replace(fragment, "")
    seen = []
    for ch in stripped:
        if ch in low_letter_list and ch not in seen:
            seen.append(ch)
    return seen
|
| 100 |
+
|
| 101 |
+
def get_num_arg_tag(text_data, sem_data):
    """
    Determine the variables/arguments in the text condition
    """
    # Letters used inside the numeric values of semantic clauses (e.g. '2x+1').
    arg_sem_flat = []
    for k in range(len(sem_data.token)):
        # A clause '... = value unit' carries its numeric value second-to-last.
        if len(sem_data.token[k]) >= 3 and sem_data.token[k][-3] == '=':
            sem_data.class_tag[k][-2] = '[NUM]'
            arg_sem_flat += get_args(sem_data.token[k][-2])

    # Heuristically promote untagged condition tokens to '[NUM]'.
    for id, token in enumerate(text_data.token):
        if text_data.sect_tag[id] == '[COND]' and text_data.class_tag[id] == '[GEN]':
            # unit symbol
            if token in unit_list:
                continue
            # digit existing (rough judgment)
            for word in digit_list:
                if word in token:
                    text_data.class_tag[id] = '[NUM]'
                    break
            # There are special characters, but not only special characters
            for word in special_token_list:
                if word in token and word != token:
                    text_data.class_tag[id] = '[NUM]'
                    break
            # Single lowercase letter, but not special cases
            if text_data.token[id] in low_letter_list:
                # 'x =' introduces an argument, not a number.
                if id < len(text_data.token)-1 and text_data.token[id+1] == '=':
                    continue
                # 'm \angle' / 'm \widehat' is the measure prefix, not a variable.
                if text_data.token[id] == 'm' and id < len(text_data.token)-1 and text_data.token[id+1] in ["\\angle", "\\widehat"]:
                    continue
                # A bare 'a' is usually the article unless preceded by '='.
                if text_data.token[id] == 'a' and (id == 0 or text_data.token[id-1] != '='):
                    continue
                # Letter after 'line'/'and'/', ... ,' likely names a line, not a value.
                if not text_data.token[id] in arg_sem_flat and \
                    id > 0 and ('line' in text_data.token[id-1] or text_data.token[id-1] == 'and' or
                                (text_data.token[id-1] == ',' and text_data.token[id+1] == ',')):
                    continue
                text_data.class_tag[id] = '[NUM]'

    # Letters used inside the text's own numeric tokens.
    arg_text_flat = []
    for id, token in enumerate(text_data.token):
        if text_data.sect_tag[id] == '[COND]' and text_data.class_tag[id] == '[NUM]':
            arg_text_flat += get_args(token)

    # Determine arguments
    arg_all_flat = arg_text_flat + arg_sem_flat
    for id, token in enumerate(text_data.token):
        if text_data.class_tag[id] == '[GEN]' \
            and text_data.token[id] in arg_all_flat:
            # 'x =' is definitely an argument definition.
            if id < len(text_data.token)-1 and text_data.token[id+1] == '=':
                text_data.class_tag[id] = '[ARG]'
                continue
            # Same exclusions as in the '[NUM]' pass above.
            if text_data.token[id] == 'm' and id < len(text_data.token)-1 and text_data.token[id+1] in ["\\angle", "\\widehat"]:
                continue
            if text_data.token[id] == 'a' and (id == 0 or text_data.token[id-1] != '=') and \
                text_data.sect_tag[id]=='[COND]':
                continue
            if id > 0 and ('line' in text_data.token[id-1] or text_data.token[id-1] == 'and' or
                           (text_data.token[id-1] == ',' and text_data.token[id+1] == ',')):
                continue
            text_data.class_tag[id] = '[ARG]'
|
| 162 |
+
|
| 163 |
+
def remove_sem_dup(text_data, sem_data, exp_token):
    """
    Remove the seq of sem_data if num is also in the text_data
    and change the corresponding expression
    """
    text_num_list, id_all_list, id_map_list = [], [], []
    token_, sect_tag_, class_tag_ = [], [], []

    # Collect the text's numeric tokens and give each a variable name N0, N1, ...
    for k in range(len(text_data.token)):
        if text_data.class_tag[k] == '[NUM]':
            text_num_list.append(text_data.token[k])
            var_name = 'N'+str(len(id_all_list))
            id_all_list.append(var_name)
            id_map_list.append(var_name)

    # Keep each semantic clause unless its numeric value (second-to-last
    # token) already appears among the text numbers; dropped clauses keep
    # their id in id_all_list but get no entry in id_map_list.
    for k in range(len(sem_data.token)):
        if sem_data.class_tag[k][-2] == '[NUM]':
            var_name = 'N'+str(len(id_all_list))
            id_all_list.append(var_name)
            if not sem_data.token[k][-2] in text_num_list:
                token_.append(sem_data.token[k])
                sect_tag_.append(sem_data.sect_tag[k])
                class_tag_.append(sem_data.class_tag[k])
                id_map_list.append(var_name)
        else:
            # Non-numeric clauses are always kept.
            token_.append(sem_data.token[k])
            sect_tag_.append(sem_data.sect_tag[k])
            class_tag_.append(sem_data.class_tag[k])

    # Renumber expression variables: surviving old ids map to compacted new ids.
    num_map_dict = {key:value for key, value in zip(id_map_list, id_all_list)}
    for k in range(len(exp_token)):
        if exp_token[k] in num_map_dict:
            exp_token[k] = num_map_dict[exp_token[k]]

    sem_data.token = token_
    sem_data.sect_tag = sect_tag_
    sem_data.class_tag = class_tag_
|
| 200 |
+
|
| 201 |
+
def get_combined_text(text_seq, stru_seqs, sem_seqs, combine_text, args):
    '''
    combination style: [stru_seqs, text_cond, sem_seqs, text_prob]

    Fills every attribute of combine_text with the concatenation of the
    structural clauses (unless args.without_stru), the text condition,
    the semantic clauses and the text problem span.
    '''
    # split cond and prob in text_seq: first and last '[PROB]' positions.
    # NOTE(review): if no token is tagged '[PROB]', begin_ind/end_ind stay
    # None and the slices below duplicate the whole text -- presumably
    # split_text always produces a '[PROB]' span; confirm.
    begin_ind = end_ind = None
    for k in range(len(text_seq.sect_tag)):
        if text_seq.sect_tag[k]=='[PROB]':
            begin_ind = k
            break
    for k in range(len(text_seq.sect_tag)-1,-1,-1):
        if text_seq.sect_tag[k]=='[PROB]':
            end_ind = k+1
            break
    # combine text_seq, stru_seqs and sem_seqs
    for key in vars(combine_text):
        # get text_cond and text_prob
        text_all_value = getattr(text_seq, key)
        text_cond_value = text_all_value[:begin_ind] + text_all_value[end_ind:]
        text_prob_value = text_all_value[begin_ind:end_ind]
        if args.without_stru:
            value_all = text_cond_value + sum(getattr(sem_seqs, key), []) + text_prob_value
        else:
            value_all = sum(getattr(stru_seqs, key), []) + text_cond_value + \
                        sum(getattr(sem_seqs, key), []) + text_prob_value

        setattr(combine_text, key, value_all)
|
| 228 |
+
|
| 229 |
+
def get_var_arg(combine_text, args):
    """Collect the numeric variables and algebraic arguments of a combined text.

    Args:
        combine_text: object with parallel ``class_tag`` and ``token`` lists.
        args: unused; kept for interface uniformity with the other helpers.
    Returns:
        (positions, var_values, arg_values): positions is the list of
        '[NUM]' token indices followed by the '[ARG]' token indices.
    """
    var_values, arg_values = [], []
    var_positions, arg_positions = [], []

    for pos, (tag, tok) in enumerate(zip(combine_text.class_tag, combine_text.token)):
        if tag == '[NUM]':
            var_values.append(tok)
            var_positions.append(pos)
        elif tag == '[ARG]':
            arg_values.append(tok)
            arg_positions.append(pos)
    # merge position of var and arg
    return var_positions + arg_positions, var_values, arg_values
|
| 245 |
+
|
| 246 |
+
def get_text_index(combine_text, src_lang):
    """Convert the combined text into vocabulary index sequences.

    Returns (text_token, text_sect_tag, text_class_tag). text_token holds
    two parallel index sequences so a '[NUM]' token can carry up to two
    argument letters; all other positions of the second sequence are
    '[PAD]'.
    """
    text_sect_tag = src_lang.indexes_from_sentence(combine_text.sect_tag, id_type='sect_tag')
    text_class_tag = src_lang.indexes_from_sentence(combine_text.class_tag, id_type='class_tag')
    text_token = [combine_text.token[:], ['[PAD]']*len(combine_text.token)]
    for k in range(len(combine_text.class_tag)):
        if combine_text.class_tag[k] == '[NUM]':
            # Replace the literal number by its argument letters (if any).
            # NOTE(review): assumes at most two distinct letters per number
            # (text_token has two rows); more would raise IndexError -- confirm
            # the upstream guarantee.
            letter_list = get_args(combine_text.token[k])
            text_token[0][k] = text_token[1][k] = "[PAD]"
            for j in range(len(letter_list)):
                text_token[j][k] = letter_list[j]
    text_token = [src_lang.indexes_from_sentence(item, id_type='text') for item in text_token]

    return text_token, text_sect_tag, text_class_tag
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
loss/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .loss import *
|
| 2 |
+
from config import criterion_list
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_criterion(args):
    """Instantiate the loss criterion named by ``args.criterion``.

    Only names listed in ``criterion_list`` are accepted. The class is
    looked up in this module's namespace instead of being passed to
    ``eval`` so arbitrary expressions can never be executed.

    Raises:
        NotImplementedError: if ``args.criterion`` is not a known criterion.
    """
    if args.criterion in criterion_list:
        # Criterion classes are re-exported here via `from .loss import *`.
        return globals()[args.criterion](args)
    else:
        raise NotImplementedError("Unsupported Loss Criterion : {}".format(args.criterion))
|
loss/loss.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from utils import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class CrossEntropy(nn.Module):
    """Plain cross-entropy loss (thin wrapper around ``F.cross_entropy``)."""

    def __init__(self, cfg):
        # cfg is accepted but unused, for interface uniformity with the
        # other criterion classes.
        super(CrossEntropy, self).__init__()

    def forward(self, output, target):
        """Return the mean cross-entropy of ``output`` logits vs ``target``."""
        return F.cross_entropy(output, target)
|
| 14 |
+
|
| 15 |
+
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy down-weighted for well-classified samples.

    loss = (1 - p) ** gamma * CE, averaged over the batch, where p is the
    probability the model assigns to the true class.
    """

    def __init__(self, cfg=None):
        super(FocalLoss, self).__init__()
        # Focusing parameter; 2.0 when no config is supplied.
        self.gamma = 2.0 if cfg is None else cfg.focal_loss_gamma
        assert self.gamma >= 0

    def focal_loss(self, input_values):
        """Computes the focal loss from per-sample cross-entropy values."""
        true_class_prob = torch.exp(-input_values)
        weighted = (1 - true_class_prob) ** self.gamma * input_values
        return weighted.mean()

    def forward(self, input, target):
        per_sample_ce = F.cross_entropy(input, target, reduction='none')
        return self.focal_loss(per_sample_ce)
|
| 33 |
+
|
| 34 |
+
class MaskedCrossEntropy(nn.Module):
    """Sequence cross-entropy that ignores padded positions via a length mask."""

    def __init__(self, cfg):
        super(MaskedCrossEntropy, self).__init__()
        self.cfg = cfg

    def forward(self, logits, target, length):
        """
        Args:
            logits: A Variable containing a FloatTensor of size
                (batch, max_len, num_classes) which contains the
                unnormalized probability for each class. B x S x (op_size+const_size+var_size)
            target: A Variable containing a LongTensor of size
                (batch, max_len) which contains the index of the true
                class for each corresponding step. B x S
            length: per-sample valid sequence lengths, shape (batch,).
        Returns:
            loss: An average loss value masked by the length.
        """
        # logits_flat: (batch * max_len, num_classes)
        logits_flat = logits.view(-1, logits.size(-1))
        # log_probs_flat: (batch * max_len, num_classes)
        log_probs_flat = F.log_softmax(logits_flat, dim=1)
        # target_flat: (batch * max_len, 1)
        target_flat = target.view(-1, 1)
        # losses_flat: (batch * max_len, 1) -- NLL of the true class per step
        losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
        # losses: (batch, max_len)
        losses = losses_flat.view(*target.size())
        # mask: (batch, max_len); sequence_mask comes from utils -- presumably
        # 1 for positions < length and 0 for padding (verify against
        # utils.sequence_mask).
        mask = sequence_mask(length)
        losses = losses * mask.float()
        # Normalize by the number of valid steps, not by batch * max_len.
        loss = losses.sum() / length.float().sum()
        return loss
|
model/backbone/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .resnet import *
|
| 2 |
+
from .mobilenet_v2 import *
|
| 3 |
+
from config import visual_backbone_list
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def get_visual_backbone(args):
    """Build the diagram-encoding CNN named by ``args.visual_backbone``.

    Optionally loads pretrained weights from ``args.pretrain_vis_path``.

    Raises:
        NotImplementedError: if the backbone name is not registered.
    """
    if args.visual_backbone in visual_backbone_list:
        model = eval(args.visual_backbone)()
        if args.pretrain_vis_path !="":
            # NOTE(review): MobileNetV2.load_model in this repo takes no
            # 'pretrain' keyword -- confirm the selected backbone accepts it.
            model.load_model(pretrain=args.pretrain_vis_path)
            args.logger.info("Visual backbone has been loaded...")
        else:
            args.logger.info("Visual backbone choose to train from scratch")
        return model
    else:
        raise NotImplementedError("Unsupported Backbone: {}".format(args.visual_backbone))
|
model/backbone/mobilenet_v2.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import config as cfg
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def conv_bn(inp, oup, stride):
    """3x3 conv (pad 1, no bias) -> BatchNorm -> ReLU6, as one Sequential."""
    conv = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, act)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Expands channels by ``expand_ratio`` (pointwise), filters with a
    depthwise 3x3, then projects back with a linear pointwise conv. A
    residual connection is used when stride is 1 and input/output channel
    counts match.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: expand to hidden_dim channels
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: depthwise 3x3 (carries the stride)
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: project back without activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor (classifier head removed).

    Outputs the final 320*width_mult-channel feature map -- there is no
    global pooling and no 1x1 head conv -- for use as a visual backbone.
    """

    def __init__(self, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t (expansion), c (out channels), n (repeats), s (first stride)
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                # Only the first block of a stage may downsample.
                stride = s if i == 0 else 1
                self.features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        self._initialize_weights()

    def forward(self, x):
        """Return the feature map for an image batch x of shape (N, 3, H, W)."""
        return self.features(x)

    def _initialize_weights(self):
        """Kaiming-style init for convs; constant init for BN and Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def load_model(self, pretrain=None):
        """Load pretrained weights whose names match this model.

        Args:
            pretrain: checkpoint path; falls back to cfg.pretrained_model_path
                when omitted. The optional parameter keeps the call
                ``model.load_model(pretrain=...)`` in get_visual_backbone
                working (the old zero-argument call still works too).
        """
        path = pretrain if pretrain is not None else cfg.pretrained_model_path
        model_dict = self.state_dict()
        pretrained_dict = torch.load(path)
        # Keep only the weights whose parameter names exist in this model.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def mobilenet_v2():
    """Factory: MobileNetV2 backbone with the default width multiplier."""
    return MobileNetV2()
|
model/backbone/resnet.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
def init_layer(L):
    """Initialize a layer in place.

    Conv2d weights get a zero-mean normal draw scaled by the kernel/channel
    count; BatchNorm2d is reset to weight 1 / bias 0. Other layer types are
    left untouched.
    """
    if isinstance(L, nn.Conv2d):
        n = L.kernel_size[0] * L.kernel_size[1] * L.out_channels
        std = math.sqrt(2.0 / float(n))
        L.weight.data.normal_(0, std)
    elif isinstance(L, nn.BatchNorm2d):
        L.weight.data.fill_(1)
        L.bias.data.fill_(0)
|
| 13 |
+
|
| 14 |
+
class Flatten(nn.Module):
    """Flatten every non-batch dimension: (N, ...) -> (N, prod(...))."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
|
| 20 |
+
|
| 21 |
+
# Simple ResNet Block
|
| 22 |
+
class SimpleBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style)."""

    maml = False  # Default

    def __init__(self, indim, outdim, half_res):
        super(SimpleBlock, self).__init__()
        self.indim = indim
        self.outdim = outdim
        stride = 2 if half_res else 1
        # Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
        self.C1 = nn.Conv2d(indim, outdim, kernel_size=3, stride=stride, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(outdim)
        self.C2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(outdim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.parametrized_layers = [self.C1, self.C2, self.BN1, self.BN2]

        self.half_res = half_res

        # A 1x1 projection matches channels (and stride) when indim != outdim.
        if indim != outdim:
            self.shortcut = nn.Conv2d(indim, outdim, 1, stride, bias=False)
            self.BNshortcut = nn.BatchNorm2d(outdim)
            self.parametrized_layers.append(self.shortcut)
            self.parametrized_layers.append(self.BNshortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        main = self.relu1(self.BN1(self.C1(x)))
        main = self.BN2(self.C2(main))
        if self.shortcut_type == 'identity':
            shortcut = x
        else:
            shortcut = self.BNshortcut(self.shortcut(x))
        return self.relu2(main + shortcut)
|
| 62 |
+
|
| 63 |
+
# Bottleneck block
|
| 64 |
+
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50/101 style)."""
    maml = False #Default
    def __init__(self, indim, outdim, half_res):
        super(BottleneckBlock, self).__init__()
        # The 3x3 conv operates at a quarter of the output width.
        bottleneckdim = int(outdim/4)
        self.indim = indim
        self.outdim = outdim

        self.C1 = nn.Conv2d(indim, bottleneckdim, kernel_size=1, bias=False)
        self.BN1 = nn.BatchNorm2d(bottleneckdim)
        # NOTE(review): unlike C1/C3 this conv keeps its bias (no bias=False);
        # presumably an oversight since BN2 follows -- confirm before changing,
        # as it affects checkpoint compatibility.
        self.C2 = nn.Conv2d(bottleneckdim, bottleneckdim, kernel_size=3, stride=2 if half_res else 1,padding=1)
        self.BN2 = nn.BatchNorm2d(bottleneckdim)
        self.C3 = nn.Conv2d(bottleneckdim, outdim, kernel_size=1, bias=False)
        self.BN3 = nn.BatchNorm2d(outdim)

        self.relu = nn.ReLU()
        self.parametrized_layers = [self.C1, self.BN1, self.C2, self.BN2, self.C3, self.BN3]
        self.half_res = half_res

        # if the input number of channels is not equal to the output, then need a 1x1 convolution
        if indim!=outdim:
            self.shortcut = nn.Conv2d(indim, outdim, 1, stride=2 if half_res else 1, bias=False)
            self.parametrized_layers.append(self.shortcut)
            self.shortcut_type = '1x1'
        else:
            self.shortcut_type = 'identity'

        for layer in self.parametrized_layers:
            init_layer(layer)

    def forward(self, x):
        # Shortcut is a bare conv here (no BN), unlike SimpleBlock's shortcut.
        short_out = x if self.shortcut_type == 'identity' else self.shortcut(x)
        out = self.C1(x)
        out = self.BN1(out)
        out = self.relu(out)
        out = self.C2(out)
        out = self.BN2(out)
        out = self.relu(out)
        out = self.C3(out)
        out = self.BN3(out)
        out = out + short_out

        out = self.relu(out)
        return out
|
| 109 |
+
|
| 110 |
+
class ResNet(nn.Module):
    """Generic four-stage ResNet trunk built from a given residual block."""
    maml = False #Default
    def __init__(self,block,list_of_num_layers, list_of_out_dims, flatten = True):
        # list_of_num_layers specifies number of layers in each stage
        # list_of_out_dims specifies number of output channel for each stage
        super(ResNet,self).__init__()
        assert len(list_of_num_layers)==4, 'Can have only four stages'

        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                          bias=False)
        bn1 = nn.BatchNorm2d(64)
        relu = nn.ReLU()
        pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        init_layer(conv1)
        init_layer(bn1)
        trunk = [conv1, bn1, relu, pool1]
        indim = 64
        for i in range(4):
            for j in range(list_of_num_layers[i]):
                # The first block of stages 2-4 halves the spatial resolution.
                half_res = (i>=1) and (j==0)
                B = block(indim, list_of_out_dims[i], half_res)
                trunk.append(B)
                indim = list_of_out_dims[i]
        if flatten:
            # NOTE(review): AvgPool2d(4) presumes a 4x4 final feature map
            # (i.e. a particular input resolution) -- confirm the expected
            # input size before reuse.
            avgpool = nn.AvgPool2d(4)
            trunk.append(avgpool)
            trunk.append(Flatten())
            self.final_feat_dim = indim
        else:
            self.final_feat_dim = [indim, 4, 4]
        self.trunk = nn.Sequential(*trunk)

    def forward(self,x):
        out = self.trunk(x)
        return out
|
| 145 |
+
|
| 146 |
+
# Factory functions for the standard ResNet depths. SimpleBlock (basic
# 2-conv residual block) is defined earlier in this file; deeper variants
# switch to BottleneckBlock, whose stage widths are 4x larger.
def ResNet10(flatten = True):
    # 1 SimpleBlock per stage.
    return ResNet(SimpleBlock, [1,1,1,1],[64,128,256,512], flatten)

def ResNet18(flatten = True):
    # 2 SimpleBlocks per stage.
    return ResNet(SimpleBlock, [2,2,2,2],[64,128,256,512], flatten)

def ResNet34(flatten = True):
    # 3/4/6/3 SimpleBlocks per stage.
    return ResNet(SimpleBlock, [3,4,6,3],[64,128,256,512], flatten)

def ResNet50(flatten = True):
    # Same layer plan as ResNet34 but with bottleneck blocks.
    return ResNet(BottleneckBlock, [3,4,6,3], [256,512,1024,2048], flatten)

def ResNet101(flatten = True):
    return ResNet(BottleneckBlock, [3,4,23,3],[256,512,1024,2048], flatten)
|
model/classifier/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .classifier_ops import *
|
| 2 |
+
from config import classifier_list
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_classifier(args):
    """Build the classification head selected by ``args.classifier``.

    Reads ``args.classifier_bias``, ``args.num_features`` and
    ``args.num_classes``. Raises NotImplementedError for any name that is
    not supported.
    """
    bias_flag = args.classifier_bias
    num_features = args.num_features
    num_classes = args.num_classes

    # `not in` instead of `not x in`; same semantics, idiomatic form.
    if args.classifier not in classifier_list:
        raise NotImplementedError("Unsupported Classifier: {}".format(args.classifier))

    if args.classifier == "FCNorm":
        classifier = FCNorm(num_features, num_classes)
    elif args.classifier == "CosNorm":
        classifier = CosNorm(num_features, num_classes)
    elif args.classifier == "DotProduct":
        # NOTE: DotProduct takes (num_classes, feat_dim, bias) in this order.
        classifier = DotProduct(num_classes, num_features, bias_flag)
    elif args.classifier == "DistFC":
        classifier = DistFC(num_features, num_classes)
    else:
        # classifier_list (from config) and this dispatch are maintained
        # separately; fail loudly here instead of letting the original code
        # hit an UnboundLocalError on `return classifier`.
        raise NotImplementedError("Unsupported Classifier: {}".format(args.classifier))

    return classifier
|
model/classifier/classifier_ops.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
class DotProduct(nn.Module):
    """Plain linear (dot-product) classification head."""

    def __init__(self, num_classes=1000, feat_dim=2048, bias=True):
        super(DotProduct, self).__init__()
        self.fc = nn.Linear(feat_dim, num_classes, bias)

    def forward(self, x, *args):
        # Extra positional args are accepted (and ignored) so every
        # classifier head in this package shares one call signature.
        return self.fc(x)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CosNorm(nn.Module):
    """Cosine-norm classifier: scaled cosine similarity between a
    norm-squashed input and the L2-normalized class weights.

    NOTE: ``margin`` and ``init_std`` are stored/accepted but never used in
    forward (kept for interface compatibility). The weight is created on
    CUDA directly, so a GPU is required — TODO confirm this is intended.
    """

    def __init__(self, in_dims, out_dims, scale=16, margin=0.5, init_std=0.001):
        super(CosNorm, self).__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.Tensor(out_dims, in_dims).cuda())
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, input, *args):
        norm_x = torch.norm(input.clone(), 2, 1, keepdim=True)
        # Squash magnitude with ||x||/(1+||x||) while keeping direction.
        squashed = (norm_x / (1 + norm_x)) * (input / norm_x)
        unit_w = self.weight / torch.norm(self.weight, 2, 1, keepdim=True)
        return torch.mm(self.scale * squashed, unit_w.t())
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class FCNorm(nn.Module):
    """Normalized fully-connected head (used with LDAM loss): logits are
    scaled cosine similarities between normalized features and weights."""

    def __init__(self, num_features, num_classes, scale=20.0):
        super(FCNorm, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, num_features))
        # Init rows to (near) unit norm: uniform, renorm to 1e-5, rescale by 1e5.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.scale = scale

    def forward(self, x):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        return self.scale * cosine
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class DistFC(nn.Module):
    """Distance-based head: returns the class centers and the squared
    Euclidean distance from each feature to each center.

    NOTE: the centers parameter is created on CUDA directly, so a GPU is
    required — TODO confirm this is intended.
    """

    def __init__(self, num_features, num_classes, init_weight=True):
        super(DistFC, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_features, num_classes).cuda(), requires_grad=True)
        if init_weight:
            self.__init_weight()

    def __init_weight(self):
        nn.init.kaiming_normal_(self.centers)

    def forward(self, x):
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x·c, expanded for all pairs.
        feat_sq = torch.sum(torch.pow(x, 2), 1, keepdim=True)
        cent_sq = torch.sum(torch.pow(self.centers, 2), 0, keepdim=True)
        cross = 2.0 * torch.matmul(x, (self.centers))
        dist = feat_sq + cent_sq - cross
        return self.centers, dist
|
| 68 |
+
|
| 69 |
+
|
model/decoder/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .transformer import TransformerModel
|
| 2 |
+
from config import decoder_list
|
| 3 |
+
from .rnn_decoder import DecoderRNN
|
| 4 |
+
from .tree_decoder import TreeDecoder
|
| 5 |
+
from .transformer import TransformerDecoder
|
| 6 |
+
|
| 7 |
+
def get_decoder(params, *args):
    """Instantiate the decoder named by ``params.decoder_type``.

    Extra positional ``args`` are forwarded to the decoder constructor
    (e.g. the target-language vocab). Raises NotImplementedError for an
    unknown type (message text kept from the original for log parity).
    """
    if not params.decoder_type in decoder_list:
        raise NotImplementedError(
            "Unsupported Classifier: {}".format(params.decoder_type))

    if params.decoder_type == "transformer":
        return TransformerDecoder(params, *args)
    if params.decoder_type == "rnn_decoder":
        return DecoderRNN(params, *args)
    if params.decoder_type == "tree_decoder":
        return TreeDecoder(params, *args)
    # Unreachable when decoder_list matches the branches above; kept as a
    # defensive guard, as in the original.
    raise NotImplementedError("Unsupported Decoder: {}".format(params.decoder_type))
|
| 23 |
+
|
| 24 |
+
|
model/decoder/rnn_decoder.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from model.module import *
|
| 4 |
+
from utils import *
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
class DecoderRNN(nn.Module):
    """Attention-based GRU decoder with beam-search inference.

    Target-vocabulary layout: ids in [0, var_start) are fixed tokens
    (specials, mid-variables, constants, operators) embedded by
    ``embedding_tgt``; ids >= var_start address problem-specific variables
    whose embeddings are gathered from the encoder outputs at ``var_pos``.
    """
    def __init__(self, cfg, tgt_lang):
        super(DecoderRNN, self).__init__()
        # token location
        self.var_start = tgt_lang.var_start # spe_num + midvar_num + const_num + op_num
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        # GRU input is [token embedding ; attention context] concatenated.
        self.gru = nn.GRU(input_size=cfg.decoder_hidden_size+cfg.decoder_embedding_size, \
                          hidden_size=cfg.decoder_hidden_size, \
                          num_layers=cfg.decoder_layers, \
                          dropout = cfg.dropout_rate, \
                          batch_first = True)
        # Choose attention model
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant: ids of all non-variable tokens (CUDA required)
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()
        self.cfg = cfg

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """
        Gather per-variable embeddings from the encoder outputs.
        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_pos)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_src, var_pos, len_var, \
                text_tgt=None, is_train=False):
        """
        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: layer_num x B x H  (initial GRU hidden state)
            len_src: B
            text_tgt: B x S2  (required when is_train)
            var_pos: B x S3
            len_var: B
        Return:
            training: logits, B x S x (no_var_size+var_size)
            testing: exp_id, B x candi_size(beam_size) x exp_len
        """
        # NOTE: state below is stashed on self and read by both sub-passes.
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_pos) # B x S3 x H
        self.src_mask = sequence_mask(len_src) # B x S1
        self.candi_mask = sequence_mask(self.var_start + len_var) # B x (no_var_size + var_size)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_tgt)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_tgt):
        # Teacher-forced pass: one step per target position (except the last).

        all_seq_outputs = []
        batch_size = encoder_outputs.size(0)
        # initial hidden input of RNN
        rnn_hidden = problem_output
        # input embedding: fixed-token ids go through embedding_tgt, variable
        # ids (>= var_start) select rows of embedding_var; torch.where merges.
        tgt_novar_id = torch.clamp(text_tgt, max=self.var_start-1) # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id) # B x S2 x H
        tgt_var_id = torch.clamp(text_tgt-self.var_start, min=0) # B x S2
        var_embeddings = self.embedding_var.gather(dim=1, index = \
            tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)) # B x S2 x H

        choose_mask = (text_tgt<self.var_start).unsqueeze(2). \
            repeat(1, 1, self.cfg.decoder_embedding_size)
        embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings) # B x S2 x H
        embedding_all_ = self.em_dropout(embedding_all)
        # candi weight embedding: candidate matrix scored against at each step
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id. \
            repeat(batch_size, 1)) # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1) # B x (no_var_size + var_size) x H
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)

        for t in range(text_tgt.size(1)-1):
            # Calculate attention from current RNN state and all encoder outputs;
            # apply to encoder outputs to get weighted average
            current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1)) # B x 1 x H
            attn_weights = self.attn(current_hiddens, encoder_outputs, self.src_mask)
            context = attn_weights.unsqueeze(1).bmm(encoder_outputs) # B x 1 x H
            # Get current hidden state from input word and last hidden state
            rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_[:, t:t+1, :], context), 2), rnn_hidden)
            # rnn_output: B x 1 x H
            # rnn_hidden: num_layers x B x H
            current_fusion_emb = torch.cat((rnn_output, context), 2)
            current_fusion_emb_ = self.em_dropout(current_fusion_emb)
            candi_score = self.score(current_fusion_emb_, embedding_weight_all_, \
                self.candi_mask) # B x (no_var_size + var_size)
            all_seq_outputs.append(candi_score)

        all_seq_outputs = torch.stack(all_seq_outputs, dim=1)

        return all_seq_outputs

    def _forward_test(self, encoder_outputs, problem_output):
        """
        Decode with beam search algorithm (one sample at a time).
        """
        exp_outputs = []
        batch_size = encoder_outputs.size(0)

        for sample_id in range(batch_size):
            # predefine: replicate this sample's tensors across the beam
            rem_size = self.cfg.beam_size
            encoder_output = encoder_outputs[sample_id:sample_id+1].repeat(rem_size, 1, 1) # beam_size x S1 x H
            src_mask = self.src_mask[sample_id:sample_id+1].repeat(rem_size, 1) # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id+1].repeat(rem_size, 1, 1) # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1)) # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1) # beam_size x (no_var_size + var_size) x H
            embedding_weight_all_ = self.em_dropout(embedding_weight_all)
            candi_mask = self.candi_mask[sample_id:sample_id+1].repeat(rem_size, 1) # beam_size x S1
            candi_exp_output = []    # finished hypotheses (ended with EOS)
            candi_score_output = []  # their log-prob scores

            for i in range(self.cfg.max_output_len):
                # initial varible
                if i==0:
                    input_token = torch.LongTensor([[self.sos_id]]*rem_size).cuda() # rem_size x 1
                    rnn_hidden = problem_output[:, sample_id:sample_id+1].repeat(1, rem_size, 1) # layer_num x rem_size x H
                    current_score = torch.FloatTensor([[0.0]]*rem_size).cuda() # rem_size x 1
                    current_exp_list = [[]]*rem_size
                else:
                    input_token = torch.LongTensor(token_list).unsqueeze(1).cuda()
                    # reorder hidden states to follow the surviving beams
                    rnn_hidden = rnn_hidden[:, cand_list]
                    rem_size = len(exp_list)
                    current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).cuda()
                    current_exp_list = exp_list

                # input embedding (same fixed-token / variable merge as training)
                tgt_novar_id = torch.clamp(input_token, max=self.var_start-1) # rem_size x 1
                novar_embedding = self.embedding_tgt(tgt_novar_id) # rem_size x 1 x H
                tgt_var_id = torch.clamp(input_token-self.var_start, min=0) # rem_size x 1
                var_embeddings = embedding_var[:rem_size].gather(dim=1, index=tgt_var_id.unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size)) # rem_size x 1 x H
                choose_mask = (input_token<self.var_start).unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size) # rem_size x 1 x H
                embedding_all = torch.where(choose_mask, novar_embedding, var_embeddings) # rem_size x 1 x H
                embedding_all_ = self.em_dropout(embedding_all)
                # attention
                current_hiddens = self.em_dropout(rnn_hidden[-1].unsqueeze(1)) # rem_size x 1 x H
                attn_weights = self.attn(current_hiddens, encoder_output[:rem_size], src_mask[:rem_size]) # rem_size x S1
                context = attn_weights.unsqueeze(1).bmm(encoder_output[:rem_size]) # rem_size x 1 x H
                # Get current hidden state from input word and last hidden state
                rnn_output, rnn_hidden = self.gru(torch.cat((embedding_all_, context), 2), rnn_hidden)
                # rnn_output: rem_size x 1 x H
                # rnn_hidden: num_layers x rem_size x H
                current_fusion_emb = torch.cat((rnn_output, context), 2)
                current_fusion_emb_ = self.em_dropout(current_fusion_emb)
                candi_score = self.score(current_fusion_emb_, embedding_weight_all_[:rem_size], \
                    candi_mask[:rem_size]) # rem_size x (no_var_size + var_size)

                if i==0:
                    # all beams are identical at step 0; keep only the first
                    # row so the beam starts from distinct tokens
                    new_score = F.log_softmax(candi_score, dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score, dim=1) + current_score

                # Finished hypotheses (id -1) compete with live expansions
                # for the beam_size slots.
                cand_tup_list = [(score, id) for id, score in enumerate(new_score.view(-1).tolist())]
                cand_tup_list += [(score, -1) for score in candi_score_output]
                cand_tup_list.sort(key=lambda x:x[0], reverse=True)

                token_list = []
                cand_list = []
                exp_list = []
                score_list = []

                for tv, ti in cand_tup_list[:self.cfg.beam_size]:
                    if ti!=-1:
                        idex = ti
                        # flat index -> (beam row, vocabulary id)
                        x = idex // candi_score.size(-1)
                        y = idex % candi_score.size(-1)
                        if y!=self.eos_id:
                            token_list.append(y)
                            cand_list.append(x)
                            exp_list.append(current_exp_list[x]+[y])
                            score_list.append(tv)
                        else:
                            candi_exp_output.append(current_exp_list[x])
                            candi_score_output.append(float(tv))

                if len(token_list)==0:
                    break

            if len(candi_exp_output)>0:
                # sort finished hypotheses by score, best first
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output[:self.cfg.beam_size]))
            else:
                exp_outputs.append([])

        return exp_outputs
|
model/decoder/transformer.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from utils.utils import sequence_mask
|
| 4 |
+
from model.module import *
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Adds pe[:, :S] to the input embeddings and applies dropout; the table
    is precomputed for ``max_len`` positions and stored as a buffer.
    """

    def __init__(self, d_model, max_len=5000, dropout_rate=0.2):
        super(PositionalEncoding, self).__init__()

        positions = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency progression: 10000^(-2i/d_model).
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        self.register_buffer("pe", table.unsqueeze(0))  # [1, max_len, d_model]

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """
        x: [B, max_len, d_model]
        pe: [1, max_len, d_model]
        """
        return self.dropout(x + self.pe[:, : x.size(1)].requires_grad_(False))
|
| 30 |
+
|
| 31 |
+
class TransformerDecoder(nn.Module):
    """Transformer decoder over a mixed fixed-token / problem-variable
    vocabulary, with beam-search inference.

    Ids in [0, var_start) are embedded by ``embedding_tgt``; ids >=
    var_start select per-problem variable embeddings gathered from the
    encoder memory at ``var_pos``.
    """

    def __init__(self, cfg, tgt_lang, \
                 d_model=256, nhead=8, num_decoder_layers=4, dim_feedforward=1024, dropout=0.2):
        super(TransformerDecoder, self).__init__()

        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        self.position_dec = PositionalEncoding(d_model=d_model)

        self.score = Score_Multi(cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        self.var_start = tgt_lang.var_start
        self.embedding_tgt = nn.Embedding(self.var_start, cfg.decoder_embedding_size, padding_idx=0)
        # ids of all non-variable candidate tokens (CUDA required)
        self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead
        self.cfg = cfg
        self.sos_id = tgt_lang.word2index["[SOS]"]
        self.eos_id = tgt_lang.word2index["[EOS]"]

    def _reset_parameters(self):
        """
        Initiate parameters in the transformer model (xavier for matrices).
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def get_square_subsequent_mask(self, sz):
        """
        Generate a square mask for the sequence. The masked positions are filled with True.
        Unmasked positions are filled with False.
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 0).transpose(0, 1)
        return mask.cuda()

    def get_var_encoder_outputs(self, encoder_outputs, var_pos):
        """
        Gather per-variable embeddings from the encoder memory.
        Arguments:
            encoder_outputs: B x S1 x H
            var_pos: B x S3
        Returns:
            var_embeddings: B x S3 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_pos = var_pos.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_pos)
        return var_embeddings

    def forward(self, memory, len_src, tgt, len_tgt, var_pos, len_var, is_train=False):
        '''
        memory: B x S1 x H
        len_src: B
        tgt: B x S2
        len_tgt: B
        var_pos: B x S3(var_size)
        len_var: B
        '''
        # NOTE: state below is stashed on self and read by both sub-passes.
        self.embedding_var = self.get_var_encoder_outputs(memory, var_pos) # B x S3 x H
        self.candi_mask = sequence_mask(self.var_start + len_var) # B x (no_var_size + var_size)
        self.memory_key_padding_mask = ~sequence_mask(len_src) # B x S1
        if is_train:
            return self._forward_train(memory, tgt, len_tgt)
        else:
            return self._forward_test(memory)


    def _forward_train(self, memory, tgt, len_tgt):
        # mask: causal + padding
        tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
        tgt_key_padding_mask = ~sequence_mask(len_tgt)
        # emb_tgt: fixed-token embeddings merged with gathered variable rows
        tgt_novar_id = torch.clamp(tgt, max=self.var_start-1) # B x S2
        novar_embedding = self.embedding_tgt(tgt_novar_id) # B x S2 x H
        tgt_var_id = torch.clamp(tgt-self.var_start, min=0) # B x S2
        var_embeddings = self.embedding_var.gather(dim=1, index = \
            tgt_var_id.unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size)) # B x S2 x H
        choose_mask = (tgt<self.var_start).unsqueeze(2). \
            repeat(1, 1, self.cfg.decoder_embedding_size)
        emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings) # B x S2 x H
        # position decoding
        emb_tgt = self.position_dec(emb_tgt)
        # nn.TransformerDecoder expects seq-first tensors, hence the permutes
        output = self.decoder( # B x S2 x H
            emb_tgt.permute(1,0,2),
            memory.permute(1,0,2),
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=self.memory_key_padding_mask,
        ).permute(1,0,2)
        # candi weight embedding
        embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(len(len_tgt), 1)) # B x no_var_size x H
        embedding_weight_all = torch.cat((embedding_weight_no_var, self.embedding_var), dim=1) # B x (no_var_size+var_size) x H
        candi_score = self.score( # B x S2 x (no_var_size + var_size)
            output,
            embedding_weight_all, \
            self.candi_mask
        )

        # drop the last step: it has no gold next-token to predict
        return candi_score[:,:-1,:].clone()

    def _forward_test(self, memory):
        # Beam search, one sample at a time; the whole prefix is re-decoded
        # every step (no incremental cache).
        exp_outputs = []

        for sample_id in range(memory.size(0)):
            # predefine: replicate this sample's tensors across the beam
            rem_size = self.cfg.beam_size
            memory_item = memory[sample_id:sample_id+1].repeat(rem_size, 1, 1) # beam_size x S1 x H
            memory_key_padding_mask = self.memory_key_padding_mask[sample_id:sample_id+1].repeat(rem_size, 1) # beam_size x S1
            embedding_var = self.embedding_var[sample_id:sample_id+1].repeat(rem_size, 1, 1) # beam_size x S3 x H
            embedding_weight_no_var = self.embedding_tgt(self.no_var_id.repeat(rem_size, 1)) # beam_size x no_var_size x H
            embedding_weight_all = torch.cat((embedding_weight_no_var, embedding_var), dim=1) # beam_size x (no_var_size + var_size) x H
            candi_mask = self.candi_mask[sample_id:sample_id+1].repeat(rem_size, 1) # beam_size x S1

            candi_exp_output = []    # finished hypotheses (ended with EOS)
            candi_score_output = []  # their log-prob scores

            tgt = torch.LongTensor([[self.sos_id]]*rem_size).cuda() # rem_size x 1
            len_tgt = torch.LongTensor([1]*rem_size).cuda() # rem_size
            current_score = torch.FloatTensor([[0.0]]*rem_size).cuda() # rem_size x 1
            current_exp_list = [[self.sos_id]]*rem_size

            for i in range(self.cfg.max_output_len):
                # mask
                tgt_mask = self.get_square_subsequent_mask(tgt.size(-1))
                tgt_key_padding_mask = ~sequence_mask(len_tgt)
                # input embedding (same fixed-token / variable merge as training)
                tgt_novar_id = torch.clamp(tgt, max=self.var_start-1) # rem_size x S
                novar_embedding = self.embedding_tgt(tgt_novar_id) # rem_size x S x H
                tgt_var_id = torch.clamp(tgt-self.var_start, min=0) # rem_size x S
                var_embeddings = embedding_var[:rem_size].gather(dim=1, index=tgt_var_id.unsqueeze(2). \
                    repeat(1, 1, self.cfg.decoder_embedding_size)) # rem_size x S x H
                choose_mask = (tgt<self.var_start).unsqueeze(2).repeat(1, 1, self.cfg.decoder_embedding_size) # rem_size x S x H
                emb_tgt = torch.where(choose_mask, novar_embedding, var_embeddings) # rem_size x S x H
                # position decoding
                emb_tgt = self.position_dec(emb_tgt)
                output = self.decoder( # rem_size x S x H
                    emb_tgt.permute(1,0,2),
                    memory_item[:rem_size].permute(1,0,2),
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask[:rem_size],
                ).permute(1,0,2)
                candi_score = self.score( # rem_size x S x (no_var_size + var_size)
                    output,
                    embedding_weight_all[:rem_size], \
                    candi_mask[:rem_size]
                )

                if i==0:
                    # all beams are identical at step 0; keep only the first
                    # row so the beam starts from distinct tokens
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1)[:1]
                else:
                    new_score = F.log_softmax(candi_score[:, -1, :], dim=1) + current_score # rem_size x (no_var_size + var_size)

                topv, topi = new_score.view(-1).topk(rem_size)
                exp_list = []
                score_list = topv.tolist()

                for tv, ti in zip(topv, topi):
                    idex = ti.item()
                    # flat index -> (beam row, vocabulary id)
                    x = idex // candi_score.size(-1)
                    y = idex % candi_score.size(-1)
                    if y!=self.eos_id:
                        exp_list.append(current_exp_list[x]+[y])
                    else:
                        # strip the leading SOS before storing
                        candi_exp_output.append(current_exp_list[x][1:])
                        candi_score_output.append(float(tv))

                if len(exp_list)==0:
                    break

                tgt = torch.LongTensor(exp_list).cuda() # rem_size x S
                len_tgt = torch.LongTensor([len(item) for item in exp_list]).cuda() # rem_size
                current_exp_list = exp_list
                rem_size = len(exp_list)
                current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).cuda() # rem_size x 1

            if len(candi_exp_output)>0:
                # sort finished hypotheses by score, best first
                _, candi_exp_output = zip(*sorted(zip(candi_score_output, candi_exp_output), reverse=True))
                exp_outputs.append(list(candi_exp_output))
            else:
                exp_outputs.append([])

        return exp_outputs
|
model/decoder/tree_decoder.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from utils import *
|
| 4 |
+
from model.module import *
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
import copy
|
| 7 |
+
|
| 8 |
+
class TreeNode:
    """A node awaiting expansion on the decoding stack.

    Wraps the hidden/goal vector ``h`` of a tree position together with a
    flag telling whether it is the left child of its parent.
    """

    def __init__(self, embedding, left_flag=False):
        # hidden state h of this not-yet-expanded node
        self.embedding = embedding
        # True when this node was pushed as a left child
        self.left_flag = left_flag
|
| 12 |
+
|
| 13 |
+
class TreeEmbedding:
    """Embedding of a (partial) subtree kept on the traversal stack.

    ``terminal`` marks a finished subtree (leaf or fully merged operator
    subtree) as opposed to a pending operator embedding.
    """

    def __init__(self, embedding, terminal=False):
        # subtree representation t (1 x H tensor in practice)
        self.embedding = embedding
        # True once the subtree below this entry is complete
        self.terminal = terminal
|
| 17 |
+
|
| 18 |
+
class TreeBeam:
    """One hypothesis in beam search over expression trees.

    Bundles the accumulated log-probability with the full decoding state
    (node stack, subtree-embedding stack, left-subtree cache) and the token
    ids emitted so far.
    """

    def __init__(self, score, node_stacks, embeddings_stacks, left_child_trees, out):
        # cumulative log-probability of this hypothesis
        self.score = score
        # [[TreeNode]] — nodes still to be expanded
        self.node_stacks = node_stacks
        # [[TreeEmbedding]] — subtrees built so far (pre-order)
        self.embeddings_stacks = embeddings_stacks
        # [t or None] — finished left-subtree embedding per sample
        self.left_child_trees = left_child_trees
        # token ids emitted so far (prefix notation)
        self.out = out
|
| 25 |
+
|
| 26 |
+
class Prediction(nn.Module):
    # One GTS-style seq2tree decoding step with problem-aware dynamic encoding:
    # derives the goal vector q for the node on top of each sample's stack,
    # attends over the encoder outputs for a context vector c, and scores
    # every candidate token (operators + constants + problem variables).
    # NOTE(review): constants below are created with .cuda(), so a GPU is
    # required as written.
    def __init__(self, cfg, op_const_size):
        super(Prediction, self).__init__()
        # Define layers
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        # gated projections producing q:
        #   concat_l/_lg  — no finished left subtree yet (goal from parent only)
        #   concat_r/_rg  — left-subtree embedding is concatenated in
        self.concat_l = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_r = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        self.concat_lg = nn.Linear(cfg.decoder_hidden_size, cfg.decoder_hidden_size)
        self.concat_rg = nn.Linear(cfg.decoder_hidden_size * 2, cfg.decoder_hidden_size)
        # attention module over encoder outputs + additive candidate scorer
        self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
        self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
        # predefined constant: id row [0 .. op_const_size) used to look up the
        # shared operator/constant embeddings, plus an all-zero placeholder
        # hidden state for samples whose node stack is already empty
        self.op_const_id = torch.arange(op_const_size).unsqueeze(0).cuda()
        self.padding_hidden = torch.zeros(1, cfg.decoder_hidden_size).cuda()

    def forward(self, node_stacks, left_child_trees, encoder_outputs, var_pades, source_mask, candi_mask, embedding_op_const):
        '''
        Arguments:
            node_stacks: [[TreeNode(_)]]*B, store the variable h
            left_child_trees: [t]*B, store the representation of left tree
            encoder_outputs: [B, S1, H]
            var_pades: [B, S2, H], all_vars_encoder_outputs
            source_mask: [B, S1], mask for source seq
            candi_mask: [B, op_size+const_size+var_size], mask for target seq
            embedding_op_const: shared nn.Embedding over operators + constants
        Returns:
            num_score: [B x (op_size+const_size+var_size)], candidate logits
            current_embeddings: q [B x 1 x H], the target vector of the current node
            current_context: c [B x 1 x H], context vector computed from q and encoder_outputs
            current_all_embeddings: [B x (op_size+const_size+var_size) x H] e (M_op, M_con, h_loc^p)
        '''
        current_embeddings = []

        # collect the hidden state of the node on top of each stack;
        # finished samples get the zero placeholder
        for node_list in node_stacks:
            if len(node_list) == 0:
                current_embeddings.append(self.padding_hidden)
            else:
                current_node = node_list[-1]
                current_embeddings.append(current_node.embedding)

        current_node_temp = [] # B x (1 x H)

        # gated fusion of parent goal with (optional) left-subtree embedding
        for l, c in zip(left_child_trees, current_embeddings):
            if l is None:
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_l(cd))
                t = torch.sigmoid(self.concat_lg(cd))
                current_node_temp.append(g*t)
            else:
                ld = self.em_dropout(l)
                cd = self.em_dropout(c)
                g = torch.tanh(self.concat_r(torch.cat((ld, cd), 1)))
                t = torch.sigmoid(self.concat_rg(torch.cat((ld, cd), 1)))
                current_node_temp.append(g*t)

        current_node = torch.stack(current_node_temp, dim=0) # B x 1 x H (q)
        current_embeddings = self.em_dropout(current_node)
        current_attn = self.attn(current_embeddings, encoder_outputs, source_mask) # B x S
        current_context = current_attn.unsqueeze(1).bmm(encoder_outputs) # B x 1 x H (c)
        leaf_input = torch.cat((current_node, current_context), 2) # B x 1 x 2H

        # candidate table = [operator/constant embeddings | variable embeddings]
        embedding_weight_op_const = embedding_op_const(self.op_const_id.repeat(var_pades.size(0), 1)) # B x (op+const) x H
        embedding_weight_all = torch.cat((embedding_weight_op_const, var_pades), dim=1) # B x (op_size+const_size+var_size) x H

        leaf_input = self.em_dropout(leaf_input)
        embedding_weight_all_ = self.em_dropout(embedding_weight_all)
        num_score = self.score(leaf_input, embedding_weight_all_, candi_mask) # B x (op_size+const_size+var_size)

        return num_score, current_node, current_context, embedding_weight_all
|
| 98 |
+
|
| 99 |
+
class GenerateNode(nn.Module):
    """Generate the left/right child hidden states of an operator node.

    Implements the gated child-goal generation of the GTS tree decoder
    (front part of eq (10)/(11)): children are produced from the current
    goal vector q, the context vector c and the operator's token embedding.
    """

    def __init__(self, cfg, op_size):
        super(GenerateNode, self).__init__()

        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.op_size = op_size

        fused_size = self.hidden_size * 2 + self.embedding_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        self.generate_l = nn.Linear(fused_size, self.hidden_size)
        self.generate_r = nn.Linear(fused_size, self.hidden_size)
        self.generate_lg = nn.Linear(fused_size, self.hidden_size)
        self.generate_rg = nn.Linear(fused_size, self.hidden_size)

    def forward(self, current_embedding, node_label, current_context, embedding_op_const):
        """
        Arguments:
            current_embedding: B x 1 x H — goal vector q of the current node
            node_label: B — predicted/teacher token ids
            current_context: B x 1 x H — context vector c of the current node
            embedding_op_const: shared operator/constant embedding table
        Returns:
            left_child: B x H — hidden state for the left child
            right_child: B x H — hidden state for the right child
            token_embedding: B x E — e(y|P) of the (clamped) operator label
        """
        # non-operator labels are clamped into the operator range; their
        # children are discarded by the caller anyway
        clamped_label = torch.clamp(node_label, max=self.op_size - 1)
        token_embedding = embedding_op_const(clamped_label)

        # dropout is applied to each component exactly once, then shared by
        # all four gates (same as computing the gates from dropped copies)
        fused = torch.cat(
            (self.em_dropout(current_embedding.squeeze(1)),
             self.em_dropout(current_context.squeeze(1)),
             self.em_dropout(token_embedding)), 1)

        left_child = torch.tanh(self.generate_l(fused)) * torch.sigmoid(self.generate_lg(fused))
        right_child = torch.tanh(self.generate_r(fused)) * torch.sigmoid(self.generate_rg(fused))

        return left_child, right_child, token_embedding
|
| 140 |
+
|
| 141 |
+
class Merge(nn.Module):
    """Merge an operator embedding with its two child subtrees (eq (12)).

    A single gated recursive-NN cell producing the embedding of the
    combined subtree.
    """

    def __init__(self, cfg):
        super(Merge, self).__init__()

        self.embedding_size = cfg.decoder_embedding_size
        self.hidden_size = cfg.decoder_hidden_size
        self.em_dropout = nn.Dropout(cfg.dropout_rate)
        self.merge = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)
        self.merge_g = nn.Linear(self.hidden_size * 2 + self.embedding_size, self.hidden_size)

    def forward(self, node_embedding, sub_tree_1, sub_tree_2):
        '''
        Arguments:
            node_embedding: 1 x E — operator token embedding e(y|P)
            sub_tree_1: 1 x H — left subtree embedding
            sub_tree_2: 1 x H — right subtree embedding
        Return:
            sub_tree: 1 x H — merged subtree embedding
        '''
        # dropout calls kept in the original order (left, right, operator)
        # so the RNG stream is identical in training mode
        left = self.em_dropout(sub_tree_1)
        right = self.em_dropout(sub_tree_2)
        parent = self.em_dropout(node_embedding)

        combined = torch.cat((parent, left, right), 1)
        candidate = torch.tanh(self.merge(combined))
        gate = torch.sigmoid(self.merge_g(combined))
        return candidate * gate
|
| 172 |
+
|
| 173 |
+
class TreeDecoder(nn.Module):
    """GTS-style tree-structured expression decoder.

    The target vocabulary is laid out as
    [operators | constants | problem variables]: ids below ``op_num`` are
    operators, ids below ``var_start`` are operators+constants, and higher
    ids index variables whose embeddings are gathered dynamically from the
    encoder outputs.
    NOTE(review): the submodules allocate CUDA tensors, so a GPU is
    required as written.
    """

    def __init__(self, cfg, tgt_lang):
        super(TreeDecoder, self).__init__()
        # embedding for op, const, num — vocabulary layout constants
        self.var_start = tgt_lang.var_start
        self.op_num = tgt_lang.op_num
        self.const_num = tgt_lang.const_num

        # learned embeddings for operators and constants only;
        # variable embeddings come from the encoder at forward time
        self.embedding_op_const = nn.Embedding(self.op_num+self.const_num, cfg.decoder_embedding_size)
        self.embedding_var = None # obtain from encoder
        self.cfg = cfg
        # modules of TreeDecoder
        self.predict = Prediction(cfg, self.op_num+self.const_num)
        self.generate = GenerateNode(cfg, self.op_num)
        self.merge = Merge(cfg)

    def get_var_encoder_outputs(self, encoder_outputs, var_positions):
        """Gather the encoder hidden state at each variable position.

        Arguments:
            encoder_outputs: B x S1 x H
            var_positions: B x S2 — indices into the source dimension
        Returns:
            var_embeddings: B x S2 x H
        """
        hidden_size = encoder_outputs.size(-1)
        expand_var_positions = var_positions.unsqueeze(-1).repeat(1, 1, hidden_size)
        var_embeddings = encoder_outputs.gather(dim=1, index = expand_var_positions)
        return var_embeddings

    def forward(self, encoder_outputs, problem_output, len_source, var_positions, len_var, \
            is_train=False, text_target=None, len_target=None):
        """Decode one batch, teacher-forced (train) or beam search (test).

        Arguments:
            encoder_outputs: B x S1 x H
            problem_output: B x H — root goal vector per sample
            len_source: B — valid source lengths
            text_target: B x S2 — teacher tokens (train only)
            len_target: B — unused here; kept for interface compatibility
            var_positions: B x S3
            len_var: B — number of variables per sample
        Return:
            training: output B x S x (op_size+const_size+var_size), logits of one batch
            testing: [expr] x B — best token-id sequence per sample
        """
        self.embedding_var = self.get_var_encoder_outputs(encoder_outputs, var_positions) # B x S2 x H
        self.source_mask = sequence_mask(len_source)
        # candidates = all ops+consts plus this sample's variables
        self.candi_mask = sequence_mask(len_var+self.var_start)
        if is_train:
            return self._forward_train(encoder_outputs, problem_output, text_target)
        else:
            return self._forward_test(encoder_outputs, problem_output)

    def _forward_train(self, encoder_outputs, problem_output, text_target):
        """Teacher-forced pre-order decoding.

        State per sample:
            node_stacks: [[TreeNode(h, left_flag)]] — pending nodes, pre-order
            embeddings_stacks: [[TreeEmbedding(t, terminal)]] — built subtrees
            left_child_trees: [t or None] — finished left-subtree embedding
        Returns:
            all_node_outputs: B x S x (op_size+const_size+var_size), logits of one batch
        """
        node_stacks = [[TreeNode(init_hidden)] for init_hidden in problem_output.split(1, dim=0)]
        embeddings_stacks = [[] for _ in range(encoder_outputs.size(0))]
        left_child_trees = [None]*encoder_outputs.size(0)
        all_node_outputs = []

        for t in range(text_target.size(1)):
            num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                node_stacks,
                left_child_trees,
                encoder_outputs,
                self.embedding_var,
                self.source_mask,
                self.candi_mask,
                self.embedding_op_const)

            all_node_outputs.append(num_score) # [B x (op_size+const_size+var_size)] * S

            # teacher-forced: children are generated from the gold token
            left_child, right_child, token_embedding = self.generate(
                current_embeddings,
                text_target[:,t],
                current_context,
                self.embedding_op_const)

            left_child_trees = []

            for idx, (l, r, node_stack, target_id, embeddings_stack) in enumerate(zip(left_child.split(1), right_child.split(1),
                    node_stacks, text_target[:,t].tolist(), embeddings_stacks)):
                # Determines whether the tree traversal is complete
                if len(node_stack) != 0:
                    node_stack.pop()
                else:
                    left_child_trees.append(None)
                    continue
                if target_id < self.op_num:
                    # operator: push right then left so the left is expanded first
                    node_stack.append(TreeNode(r))
                    node_stack.append(TreeNode(l, left_flag=True))
                    # embeddings_stack, put e(y|P) of op in temporarily
                    embeddings_stack.append(TreeEmbedding(token_embedding[idx].unsqueeze(0), False))
                else:
                    current_num = current_all_embeddings[idx, target_id].unsqueeze(0) # 1 x H
                    # Reach the right leaf node and merge the tree representation from bottom up
                    while len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                        sub_stree = embeddings_stack.pop()
                        op = embeddings_stack.pop()
                        # embedding vector of two sub-targets is merged as the subtree
                        # embedding of nodes, corresponding to eq(12),
                        # with e(y|P), sub_tree_1 and sub_tree_2
                        current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                    embeddings_stack.append(TreeEmbedding(current_num, True))
                # Reach the left leaf node and save the representation of the left
                # subtree for generation of q at the next step
                if len(embeddings_stack) > 0 and embeddings_stack[-1].terminal:
                    left_child_trees.append(embeddings_stack[-1].embedding)
                else:
                    left_child_trees.append(None)

        all_node_outputs = torch.stack(all_node_outputs, dim=1)

        return all_node_outputs

    def _forward_test(self, encoder_outputs, problem_output):
        """Beam-search decoding, one sample at a time (effective batch = 1).

        Returns:
            exp_outputs: [token-id list] * B — best hypothesis per sample.
        """
        exp_outputs = []

        for sample_id in range(encoder_outputs.size(0)):
            # set batch size as 1
            node_stacks = [[TreeNode(problem_output[sample_id:sample_id+1])]]
            embeddings_stacks = [[]]
            left_child_trees = [None]
            beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_child_trees, [])]

            for _ in range(self.cfg.max_output_len):
                # re-maintain of one beams
                current_beams = []

                while len(beams) > 0:
                    beam_item = beams.pop()
                    # The candidates are stored in beams in all process;
                    # a finished hypothesis is carried over unchanged
                    if len(beam_item.node_stacks[0]) == 0:
                        current_beams.append(beam_item)
                        continue
                    num_score, current_embeddings, current_context, current_all_embeddings = self.predict(
                        beam_item.node_stacks,
                        beam_item.left_child_trees,
                        encoder_outputs[sample_id:sample_id+1],
                        self.embedding_var[sample_id:sample_id+1],
                        self.source_mask[sample_id:sample_id+1],
                        self.candi_mask[sample_id:sample_id+1],
                        self.embedding_op_const)

                    out_score = F.log_softmax(num_score, dim=1)
                    topv, topi = out_score.topk(self.cfg.beam_size)

                    # expand this hypothesis with each of the top-k tokens
                    for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):

                        # deep-copy the decoding state so expansions don't alias
                        current_node_stack = copy_list(beam_item.node_stacks)
                        current_left_child_trees = []
                        current_embeddings_stacks = copy_list(beam_item.embeddings_stacks)
                        current_out = copy.deepcopy(beam_item.out)

                        out_token = int(ti)
                        current_out.append(out_token)
                        current_node_stack[0].pop()

                        if out_token < self.op_num:
                            generate_input = torch.LongTensor([out_token]).cuda()
                            left_child, right_child, token_embedding = self.generate(
                                current_embeddings,
                                generate_input,
                                current_context,
                                self.embedding_op_const)
                            current_node_stack[0].append(TreeNode(right_child))
                            current_node_stack[0].append(TreeNode(left_child, left_flag=True))
                            current_embeddings_stacks[0].append(TreeEmbedding(token_embedding, False))
                        else:
                            current_num = current_all_embeddings[:, out_token]
                            # bottom-up merge, mirroring _forward_train
                            while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                                sub_stree = current_embeddings_stacks[0].pop()
                                op = current_embeddings_stacks[0].pop()
                                current_num = self.merge(op.embedding, sub_stree.embedding, current_num)
                            current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
                        if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
                            current_left_child_trees.append(current_embeddings_stacks[0][-1].embedding)
                        else:
                            current_left_child_trees.append(None)

                        current_beams.append(TreeBeam(beam_item.score+float(tv), current_node_stack, current_embeddings_stacks,
                                current_left_child_trees, current_out))

                # keep the best beam_size hypotheses
                beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
                beams = beams[:self.cfg.beam_size]

                # early termination — stop once every kept hypothesis is finished
                flag = True
                for beam_item in beams:
                    if len(beam_item.node_stacks[0]) != 0:
                        flag = False
                        break
                if flag: break

            exp_outputs.append(beams[0].out)

        return exp_outputs
|
model/encoder/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .lstm import LSTM
|
| 2 |
+
from .gru import GRU
|
| 3 |
+
from config import encoder_list
|
| 4 |
+
from .transformer import TransformerEncoder
|
| 5 |
+
|
| 6 |
+
def get_encoder(params, *args):
    """Factory returning the encoder module named by ``params.encoder_type``.

    Arguments:
        params: config object with an ``encoder_type`` attribute.
        *args: forwarded to the encoder constructor.
    Returns:
        An instantiated encoder module.
    Raises:
        NotImplementedError: for unknown encoder types, or for the
            "transformer" type which is not wired up in this factory yet.
    """
    if params.encoder_type not in encoder_list:
        # Bug fix: message previously said "Unsupported Classifier".
        raise NotImplementedError(
            "Unsupported Encoder: {}".format(params.encoder_type))

    if params.encoder_type == "transformer":
        # Bug fix: this branch used to be `pass`, leaving `encoder` unbound
        # and crashing with UnboundLocalError at the return statement.
        raise NotImplementedError(
            "Transformer encoder is not supported by get_encoder yet")
    elif params.encoder_type == "lstm":
        encoder = LSTM(params, *args)
    elif params.encoder_type == "gru":
        encoder = GRU(params, *args)
    else:
        raise NotImplementedError("Unsupported Encoder: {}".format(params.encoder_type))

    return encoder
|
model/encoder/gru.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GRU(nn.Module):
    """Bidirectional GRU encoder over padded batches.

    Runs a packed bidirectional GRU and folds the two directions back to a
    single hidden size by summation (both for the per-step outputs and the
    final hidden states).
    """

    def __init__(self, cfg):
        super(GRU, self).__init__()

        self.is_bidirectional = True
        self.batch_first = True
        self.gru = nn.GRU(
            input_size = cfg.encoder_embedding_size,
            hidden_size = cfg.encoder_hidden_size, # per-direction hidden size
            num_layers = cfg.encoder_layers,
            bidirectional = self.is_bidirectional,
            dropout = cfg.dropout_rate,
            batch_first = self.batch_first
        )
        self.hidden_size = cfg.encoder_hidden_size
        self.dropout = nn.Dropout(cfg.dropout_rate)

    def forward(self, src_emb, input_lengths, hidden=None):
        """
        Arguments:
            src_emb: B x S x E — embedded source tokens
            input_lengths: B — valid lengths (moved to CPU for packing)
            hidden: optional initial hidden state
        Returns:
            outputs: B x S x H — direction-summed per-step states
            final_hidden: n_layers x B x H — direction-summed final states
        """
        packed = nn.utils.rnn.pack_padded_sequence(
            self.dropout(src_emb), input_lengths.cpu(),
            batch_first=self.batch_first, enforce_sorted=False)
        outputs, final_hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first)
        # outputs: [B, S, hidden_size*num_directions]
        # final_hidden: [n_layers*num_directions, B, hidden_size]
        if self.is_bidirectional:
            # sum forward and backward halves back to size H
            outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
            final_hidden = final_hidden[0::2, :, :] + final_hidden[1::2, :, :]

        return outputs, final_hidden
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
model/encoder/lstm.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LSTM(nn.Module):
    """Thin wrapper around ``nn.LSTM`` configured from uppercase cfg fields.

    Unlike the GRU encoder, the caller supplies the initial (h0, c0) pair
    and receives the final states unpacked.
    """

    def __init__(self, cfg):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(
            input_size=cfg.WORD_EMBED_SIZE,
            hidden_size=cfg.HIDDEN_SIZE,      # per-direction hidden size
            num_layers=cfg.NUM_LAYERS,
            batch_first=cfg.BATCH_FIRST,      # first dim is batch_size or not
            bidirectional=cfg.BIDIRECTIONAL
        )

    def forward(self, input, h0, c0):
        """Run the LSTM and unpack the final state tuple."""
        output, final_state = self.lstm(input, (h0, c0))
        hn, cn = final_state
        return output, hn, cn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
model/encoder/transformer.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from utils.utils import sequence_mask
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position encoding (Vaswani et al., 2017).

    Precomputes a (1, max_len, d_model) table — even columns sine, odd
    columns cosine — registers it as a buffer, and adds the matching prefix
    to the input followed by dropout.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positions = torch.arange(0, max_len).unsqueeze(1)
        # geometric frequency progression 10000^(-2i/d_model)
        inv_freq = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # buffer name "pe" kept for state_dict compatibility
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """
        x: [B, max_len, d_model]
        pe: [1, max_len, d_model]
        """
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
|
| 26 |
+
|
| 27 |
+
class LearnedPositionEncoding(nn.Module):
    """Learned positional embedding added only at variable positions.

    Sequence slots listed in ``var_pos`` receive a learned vector indexed by
    their 1-based order; every other slot receives embedding index 0.
    NOTE(review): tensors are created with .cuda(), so a GPU is required.
    """

    def __init__(self, d_model, max_len = 20):
        super(LearnedPositionEncoding, self).__init__()
        # index 0 = "not a variable position"; 1..max_len-1 = variable order
        self.embedding = nn.Embedding(max_len, d_model)

    def forward(self, x, var_pos):
        """
        x: [B, max_len, d_model]
        var_pos: [B, var_len] — indices into the sequence dimension of x
        """
        loc_mat = torch.zeros(x.size(0), x.size(1), dtype=torch.int64).cuda()
        pos_id = torch.arange(1, var_pos.size(1)+1).repeat(var_pos.size(0), 1).cuda()
        # entries equal to the global minimum of var_pos are mapped to index 0 —
        # presumably padding slots all carry that minimum value; TODO confirm
        # against the caller that produces var_pos
        pos_id[var_pos==var_pos.min()] = 0
        loc_mat.scatter_(1, var_pos, pos_id)

        x = x + self.embedding(loc_mat)

        return x
|
| 46 |
+
|
| 47 |
+
class TransformerEncoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` with sinusoidal positions.

    Batch-first tensors in and out; internally permuted to the (S, B, E)
    layout expected by ``nn.TransformerEncoder``.
    """

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.2):
        super(TransformerEncoder, self).__init__()

        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        final_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(layer, num_encoder_layers, final_norm)
        self.position = PositionalEncoding(d_model=d_model)

        self._reset_parameters()
        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every weight matrix (parameters with dim > 1)."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, len_src, emb_src):
        """
        Arguments:
            len_src: B — valid source lengths
            emb_src: B x S x d_model — embedded source tokens
        Returns:
            B x S x d_model encoded memory
        """
        # nn.TransformerEncoder expects True at PADDING positions,
        # sequence_mask returns True at valid ones — hence the inversion
        padding_mask = ~sequence_mask(len_src)
        encoded = self.position(emb_src)
        memory = self.encoder(encoded.permute(1, 0, 2), src_key_padding_mask=padding_mask)

        return memory.permute(1, 0, 2)
|
model/module/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .module_ops import *
|
| 2 |
+
from .attention import *
|
model/module/attention.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class Score(nn.Module):
    """Additive (Bahdanau-style) scorer of candidate tokens for one step."""

    def __init__(self, input_size, hidden_size):
        super(Score, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, candi_embeddings, candi_mask=None):
        '''
        Arguments:
            hidden: B x 1 x input_size — concatenated goal + context vector
            candi_embeddings: B x candi_size x hidden_size
            candi_mask: B x candi_size — True marks a valid candidate
        Return:
            score: B x candi_size — unnormalised logits
        '''
        num_candidates = candi_embeddings.size(1)
        # broadcast the query against every candidate
        query = hidden.repeat(1, num_candidates, 1)                      # B x candi_size x input_size
        features = torch.tanh(self.attn(torch.cat((query, candi_embeddings), 2)))
        logits = self.score(features).squeeze(-1)                        # B x candi_size
        if candi_mask is not None:
            # masked candidates get -1e12 so softmax/log_softmax ignores them
            logits = logits.masked_fill_(~candi_mask, -1e12)
        return logits
|
| 26 |
+
|
| 27 |
+
class Attn(nn.Module):
    """Additive attention producing softmax weights over encoder outputs."""

    def __init__(self, input_size, hidden_size):
        super(Attn, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs, seq_mask=None):
        '''
        Arguments:
            hidden: B x 1 x H — current goal vector q
            encoder_outputs: B x S x H
            seq_mask: B x S — True marks a real (non-pad) position
        Return:
            attn_energies: B x S — attention weights, each row sums to 1
        '''
        seq_len = encoder_outputs.size(1)
        query = hidden.repeat(1, seq_len, 1)                              # B x S x H
        features = torch.tanh(self.attn(torch.cat((query, encoder_outputs), 2)))
        energies = self.score(features).squeeze(-1)                       # B x S
        if seq_mask is not None:
            # padded positions get -1e12 and thus ~zero softmax mass
            energies = energies.masked_fill_(~seq_mask, -1e12)
        return nn.functional.softmax(energies, dim=1)
|
| 51 |
+
|
| 52 |
+
class Score_Multi(nn.Module):
    """Score every candidate token against every decoding step at once.

    Same additive scoring as ``Score`` but vectorised over the step
    dimension S, producing a B x S x candi_size score tensor.
    """

    def __init__(self, input_size, hidden_size):
        super(Score_Multi, self).__init__()
        self.attn = nn.Linear(hidden_size + input_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, candi_embeddings, candi_mask=None):
        '''
        Arguments:
            hidden: B x S x H
            candi_embeddings: B x candi_size x H
            candi_mask: B x candi_size — True marks a valid candidate (optional)
        Return:
            score: B x S x candi_size
        '''
        # broadcast steps and candidates against each other
        hidden = hidden.unsqueeze(2).repeat(1, 1, candi_embeddings.size(1), 1)  # B x S x candi_size x H
        candi_embeddings = candi_embeddings.unsqueeze(1).repeat(1, hidden.size(1), 1, 1)  # B x S x candi_size x H
        energy_in = torch.cat((hidden, candi_embeddings), -1)  # B x S x candi_size x 2H
        score = self.score(torch.tanh(self.attn(energy_in))).squeeze(-1)  # B x S x candi_size
        if candi_mask is not None:
            # Bug fix: the mask is only expanded/applied when one is provided;
            # previously candi_mask.unsqueeze(1) ran before the None check,
            # so the documented default candi_mask=None always crashed.
            candi_mask = candi_mask.unsqueeze(1).repeat(1, hidden.size(1), 1)  # B x S x candi_size
            score = score.masked_fill_(~candi_mask, -1e12)
        return score
|
model/module/module_ops.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class GAP(nn.Module):
    """Global average pooling over the spatial dims: (B, C, H, W) -> (B, C, 1, 1).

    Widely used as the head of CNN backbones (ResNet, Inception, DenseNet, ...).
    """

    def __init__(self):
        super(GAP, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # the (B, C, 1, 1) shape is kept; callers flatten if they need to
        return self.avgpool(x)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Identity(nn.Module):
    """No-op module: returns its input unchanged (placeholder layer)."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # pass the tensor through untouched
        return x
|
| 25 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==1.7.1
|
| 2 |
+
torchvision==0.8.2
|
| 3 |
+
gradio==4.16.0
|
| 4 |
+
Pillow>=9.0.0
|
| 5 |
+
numpy>=1.19.0
|
| 6 |
+
antlr4-python3-runtime==4.10
|
| 7 |
+
sympy==1.11.1
|
| 8 |
+
func_timeout==4.3.5
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .lr_scheduler import *
|
| 2 |
+
from .utils import *
|
| 3 |
+
|
| 4 |
+
|
utils/lr_scheduler.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from bisect import bisect_right
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step LR decay with an initial warmup ramp.

    During the first ``warmup_epochs`` epochs the base LR is scaled by a
    factor that is either constant (``warmup_factor``) or linearly
    interpolated from ``warmup_factor`` up to 1; afterwards the LR decays
    by ``gamma`` at each milestone.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_epochs=5,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Bug fix: the message is now actually formatted; previously the
            # milestones were passed as a second (ignored) ValueError argument
            # so the "{}" placeholder was never filled in.
            raise ValueError(
                "Milestones should be a list of increasing integers."
                " Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            # Bug fix: added the missing separator between the two literals
            # ("acceptedgot ..." previously).
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the LR for every param group at ``self.last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_epochs:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # ramp linearly from warmup_factor to 1 over warmup_epochs
                alpha = float(self.last_epoch) / self.warmup_epochs
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
|
utils/utils.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from utils.lr_scheduler import WarmupMultiStepLR
|
| 4 |
+
from config import *
|
| 5 |
+
import datetime
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from datasets.operators import result_compute, normalize_exp
|
| 8 |
+
from func_timeout import func_timeout
|
| 9 |
+
import random
|
| 10 |
+
import gc
|
| 11 |
+
|
| 12 |
+
def save_checkpoint(state, is_best, dump_path=None):
    """Serialize a training-state dict under ``dump_path``.

    The best model is written to ``best_model.pth``; otherwise the file
    is named after the epoch stored in ``state['epoch']``.
    """
    if is_best:
        filename = 'best_model.pth'
    else:
        filename = str(state['epoch']) + '.pth'
    torch.save(state, os.path.join(dump_path, filename))
|
| 19 |
+
|
| 20 |
+
class AverageMeter(object):
    """Track a metric's latest value together with its running average."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Discard all accumulated statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
|
| 44 |
+
|
| 45 |
+
class ProgressMeter(object):
    """Format and log one progress line for a collection of meters."""

    def __init__(self, num_batches, meters, args, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix
        self.args = args

    def display(self, batch, lr=None):
        """Log the current batch, all meters and (optionally) the lr
        through ``args.logger``."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        if lr is not None:
            parts.append("lr: " + str(format(lr, '.6f')))
        self.args.logger.info('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # e.g. num_batches=500 -> '[{:3d}/500]'
        width = len(str(num_batches // 1))
        fmt = '{:' + str(width) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
| 63 |
+
|
| 64 |
+
def adjust_learning_rate(optimizer, epoch, args):
    """Set every param group's lr to ``args.lr`` decayed 10x per 30 epochs."""
    new_lr = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|
| 71 |
+
|
| 72 |
+
def accuracy(output, target, topk=(1, )):
    """Top-k accuracy (in percent) of ``output`` against ``target``.

    Args:
        output: (B, C) class scores.
        target: (B,) integer class labels.
        topk: tuple of k values to report.

    Returns:
        A list of 1-element tensors, one per k: the percentage of samples
        whose label appears among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, topk_idx = output.topk(maxk, 1, True, True)
        # (B, maxk) -> (maxk, B); row r marks whether the rank-r prediction
        # equals the ground-truth label.
        topk_idx = topk_idx.t()
        hits = topk_idx.eq(target.view(1, -1).expand_as(topk_idx))

        results = []
        for k in topk:
            num_correct = hits[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(num_correct.mul_(100.0 / batch_size))
        return results
|
| 89 |
+
|
| 90 |
+
def get_scheduler(args, optimizer):
    """Build the LR scheduler selected by ``args.scheduler_type``.

    Supported types: "multistep", "cosine", "warmup".

    Raises:
        NotImplementedError: for any other scheduler type.
    """
    kind = args.scheduler_type
    if kind == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
        )
    if kind == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.max_epochs, eta_min=1e-6)
    if kind == "warmup":
        return WarmupMultiStepLR(
            optimizer,
            args.scheduler_step,
            gamma=args.scheduler_factor,
            warmup_epochs=args.warm_epoch,
        )
    raise NotImplementedError("Unsupported LR Scheduler: {}".format(kind))
|
| 110 |
+
|
| 111 |
+
def get_optimizer(args, model):
    """Build the optimizer selected by ``args.optimizer_type``: "SGD", "ADAM" or "ADAMW".

    When ``args.use_MLM_pretrain`` is set, the parameters of
    ``model.mlm_pretrain`` are split from the rest so the two groups can use
    different learning rates (``args.lr`` vs ``args.lr_LM``).

    Raises:
        NotImplementedError: for any unsupported optimizer type.
    """
    if args.use_MLM_pretrain:
        # Partition parameters by identity: mlm_pretrain vs. everything else.
        pretrain_params = list(map(id, model.mlm_pretrain.parameters()))
        other_params = filter(lambda p: id(p) not in pretrain_params, model.parameters())

    if args.optimizer_type == "SGD":
        # NOTE(review): this branch ignores use_MLM_pretrain — all parameters
        # share args.lr; confirm that is intentional.
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=True,
        )
    elif args.optimizer_type == "ADAM":
        if args.use_MLM_pretrain:
            # NOTE(review): here args.lr_LM is applied to the NON-pretrain
            # group (mlm_pretrain gets args.lr), the opposite of the ADAMW
            # branch below — one of the two is likely inverted; verify.
            optimizer = torch.optim.Adam(
                [{"params":model.mlm_pretrain.parameters()},
                {"params":other_params, "lr":args.lr_LM}],
                lr=args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay,
            )
        else:
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=args.lr,
                betas=(0.9, 0.999),
                weight_decay=args.weight_decay,
            )
    elif args.optimizer_type == "ADAMW":
        if args.use_MLM_pretrain:
            # mlm_pretrain group trains at args.lr_LM; the rest at args.lr.
            optimizer = torch.optim.AdamW(
                [{"params":model.mlm_pretrain.parameters(), "lr":args.lr_LM},
                {"params":other_params}],
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
        else:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )
    else:
        raise NotImplementedError("Unsupported Optimizer Type : {}".format(args.optimizer_type))

    return optimizer
|
| 158 |
+
|
| 159 |
+
def reduce_mean(tensor, nprocs):
    """Average ``tensor`` across all ``nprocs`` distributed processes.

    The input is left untouched; a reduced copy is returned.  Requires an
    initialized ``torch.distributed`` process group.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    reduced /= nprocs
    return reduced
|
| 164 |
+
|
| 165 |
+
def set_cuda(data_dict):
    """Move every tensor value of ``data_dict`` onto the GPU, in place.

    Non-tensor values are left untouched.
    """
    for key, value in data_dict.items():
        if torch.is_tensor(value):
            data_dict[key] = value.cuda()
|
| 169 |
+
|
| 170 |
+
def initialize_logger(params, ):
    """
    Initialize the experiment:
        - dump parameters
        - create a logger

    Picks a fresh timestamped subdirectory of ``params.dump_path`` (retrying
    until an unused name is found), rebinds ``params.dump_path`` to it, and
    creates the directory on rank 0 only.  All hyper-parameters are logged
    at startup.

    Returns:
        The logger built by ``create_logger`` (brought in via
        ``from config import *``), writing to ``record.log``.
    """
    while True:
        # Timestamp-based id; loop guards against a collision with an
        # existing run started in the same second.
        exp_id = datetime.datetime.strftime(datetime.datetime.now(),'%Y-%m-%d-%H-%M-%S')
        if not os.path.exists(os.path.join(params.dump_path, exp_id)):
            break
    params.dump_path = os.path.join(params.dump_path, exp_id)
    if params.local_rank == 0:
        os.makedirs(params.dump_path)
    # create a logger
    logger = create_logger(os.path.join(params.dump_path,'record.log'), params.local_rank)
    logger.info("============ Initialized logger ============")
    logger.info("\n"+"\n".join("\t\t\t\t%s: %s" % (k, str(v))
                for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment results will be stored in %s" % params.dump_path)
    return logger
|
| 190 |
+
|
| 191 |
+
def aeq(*args):
    """Assert that every argument equals the first one."""
    remaining = iter(args)
    first = next(remaining)
    for other in remaining:
        assert other == first, \
            "Not all arguments have the same value: " + str(args)
|
| 199 |
+
|
| 200 |
+
def sequence_mask(lengths, max_len=None):
    """Boolean mask of shape (len(lengths), max_len).

    Entry (i, j) is True iff j < lengths[i].  When ``max_len`` is falsy
    (None or 0) the longest length in ``lengths`` is used.
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    return positions.repeat(batch_size, 1).lt(lengths.unsqueeze(1))
|
| 210 |
+
|
| 211 |
+
def copy_list(l):
    """Deep-copy a (possibly nested) list; non-list elements are shared."""
    return [copy_list(item) if type(item) is list else item for item in l]
|
| 221 |
+
|
| 222 |
+
def compute_exp_result_choice(test_preds, var_dict, exp_dict, tgt_lang):

    """
    Choice-style evaluation: a predicted expression only counts when its
    value matches one of the problem's multiple-choice options; when no
    candidate qualifies, an option is picked at random (chance-level floor).

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}  (also reads 'choices')
        tgt_lang: vocab of target text
    Returns:
        ans_acc: fraction of problems with a value-correct prediction
        eq_acc: fraction with an expression-correct prediction
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist()  # Remove special symbols [SOS] and [EOS]
        # Map extra variable slots N<i> (indices past the problem's numeric
        # values) to their argument values for detokenization.
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        choices = exp_dict['choices'][k]
        is_find_ans = False

        for j in range(len(test_preds[k])):  # pred candi id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # Evaluate the expression with a 2-second cap to guard
                # against predictions that hang the solver.
                pred_result = float(func_timeout(2.0, result_compute, \
                                    kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    ans_num += 1
                    eq_num += 1
                    is_find_ans = True
                    break
                # Accept the candidate only if its value hits some choice.
                for item in choices:
                    if abs(pred_result-item)<5e-2:
                        is_find_ans = True
                if is_find_ans and abs(pred_result-tgt_result)<5e-3:
                    ans_num +=1
                    # Same token length is treated as an equivalent expression.
                    if len(pred)==len(tgt):
                        eq_num += 1
                if is_find_ans: break
            except:
                # NOTE(review): bare except silently drops any candidate that
                # fails to detokenize/evaluate (including timeouts).
                pass

        if not is_find_ans:
            # Fall back to a random option, as a test taker would guess.
            pred_result = random.choice(choices)
            if abs(pred_result-tgt_result)<5e-2:
                ans_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
|
| 275 |
+
|
| 276 |
+
def compute_exp_result_topk(test_preds, var_dict, exp_dict, tgt_lang, k_num = 3):

    """
    Top-k evaluation: a problem counts as solved when any of the first
    ``k_num`` beam candidates matches the gold answer/expression.

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
        k_num: number of top candidates considered per problem
    Returns:
        ans_acc: fraction with a value-correct candidate in the top k
        eq_acc: fraction with an expression-correct candidate in the top k
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist()  # Remove special symbols [SOS] and [EOS]
        # Map extra variable slots N<i> to their argument values.
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False

        for j in range(k_num):  # top-n
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # Evaluate with a 2-second cap to avoid hanging predictions.
                pred_result = float(func_timeout(2.0, result_compute, \
                                    kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result-tgt_result)<5e-3:
                    is_ans_same = True
                    # Same token length is treated as an equivalent expression.
                    if len(pred)==len(tgt):
                        is_eq_same = True
                    break
            except:
                # NOTE(review): bare except silently skips failing candidates
                # (detokenization errors, evaluation errors, timeouts).
                pass

        if is_ans_same: ans_num +=1
        if is_eq_same: eq_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def compute_exp_result_comp(test_preds, var_dict, exp_dict, tgt_lang):

    """
    Completion-style evaluation: every beam candidate is considered; the
    first one that matches the gold expression or evaluates to the gold
    answer decides the scores.

    Arguments
        test_preds: B x candi_size(beam_size) x token_list
        var_dict: {'pos', 'len', 'var_value', 'arg_value'}
        exp_dict: {'exp', 'len', 'answer'}
        tgt_lang: vocab of target text
    Returns:
        ans_acc: fraction of problems with a value-correct candidate
        eq_acc: fraction with an expression-correct candidate
    """
    gc.collect()
    ans_num = eq_num = 0

    for k in range(len(test_preds)):  # batch id
        tgt = exp_dict['exp'][k][1:exp_dict['len'][k]-1].tolist()  # Remove special symbols [SOS] and [EOS]
        # Map extra variable slots N<i> to their argument values.
        var2arg_dict = {'N'+str(i+len(var_dict['var_value'][k])):item \
                        for i, item in enumerate(var_dict['arg_value'][k])}
        tgt = tgt_lang.sentence_from_indexes(tgt, var2arg_dict)
        num_list = var_dict['var_value'][k]
        tgt_result = float(exp_dict['answer'][k])
        is_ans_same = is_eq_same = False

        for j in range(len(test_preds[k])):  # pred candi id
            try:
                pred = tgt_lang.sentence_from_indexes(test_preds[k][j], var2arg_dict)
                pred = normalize_exp(pred)
                # Evaluate with a 2-second cap to avoid hanging predictions.
                pred_result = float(func_timeout(2.0, result_compute, \
                                    kwargs=dict(num_all_list=num_list, exp_tokens=pred)))
                if pred == tgt:
                    is_ans_same = True
                    is_eq_same = True
                    break
                if abs(pred_result-tgt_result)<5e-3:
                    is_ans_same = True
                    # Same token length is treated as an equivalent expression.
                    if len(pred)==len(tgt):
                        is_eq_same = True
                    break
            except:
                # NOTE(review): bare except silently skips failing candidates
                # (detokenization errors, evaluation errors, timeouts).
                pass

        if is_ans_same: ans_num +=1
        if is_eq_same: eq_num +=1

    return ans_num/len(test_preds), eq_num/len(test_preds)
|
vocab/vocab_src.txt
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[UNK]
|
| 3 |
+
[CLS]
|
| 4 |
+
[MASK]
|
| 5 |
+
+
|
| 6 |
+
,
|
| 7 |
+
.
|
| 8 |
+
/
|
| 9 |
+
1
|
| 10 |
+
10
|
| 11 |
+
11
|
| 12 |
+
12
|
| 13 |
+
13
|
| 14 |
+
14
|
| 15 |
+
15
|
| 16 |
+
16
|
| 17 |
+
17
|
| 18 |
+
18
|
| 19 |
+
19
|
| 20 |
+
2
|
| 21 |
+
20
|
| 22 |
+
3
|
| 23 |
+
4
|
| 24 |
+
5
|
| 25 |
+
6
|
| 26 |
+
7
|
| 27 |
+
8
|
| 28 |
+
9
|
| 29 |
+
=
|
| 30 |
+
?
|
| 31 |
+
A
|
| 32 |
+
B
|
| 33 |
+
B'
|
| 34 |
+
B0
|
| 35 |
+
B1
|
| 36 |
+
C
|
| 37 |
+
C0
|
| 38 |
+
C1
|
| 39 |
+
D
|
| 40 |
+
E
|
| 41 |
+
F
|
| 42 |
+
F0
|
| 43 |
+
F1
|
| 44 |
+
G
|
| 45 |
+
G0
|
| 46 |
+
G1
|
| 47 |
+
H
|
| 48 |
+
I
|
| 49 |
+
J
|
| 50 |
+
K
|
| 51 |
+
L
|
| 52 |
+
M
|
| 53 |
+
N
|
| 54 |
+
O
|
| 55 |
+
P
|
| 56 |
+
Q
|
| 57 |
+
Q'
|
| 58 |
+
R
|
| 59 |
+
S
|
| 60 |
+
T
|
| 61 |
+
U
|
| 62 |
+
V
|
| 63 |
+
W
|
| 64 |
+
W'
|
| 65 |
+
X
|
| 66 |
+
Y
|
| 67 |
+
Z
|
| 68 |
+
\angle
|
| 69 |
+
\cong
|
| 70 |
+
\cos
|
| 71 |
+
\odot
|
| 72 |
+
\parallel
|
| 73 |
+
\parallelogram
|
| 74 |
+
\perp
|
| 75 |
+
\phi
|
| 76 |
+
\sim
|
| 77 |
+
\sin
|
| 78 |
+
\tan
|
| 79 |
+
\triangle
|
| 80 |
+
\widehat
|
| 81 |
+
a
|
| 82 |
+
about
|
| 83 |
+
all
|
| 84 |
+
altitude
|
| 85 |
+
altitudes
|
| 86 |
+
an
|
| 87 |
+
and
|
| 88 |
+
angle
|
| 89 |
+
angles
|
| 90 |
+
any
|
| 91 |
+
appear
|
| 92 |
+
appears
|
| 93 |
+
are
|
| 94 |
+
area
|
| 95 |
+
areas
|
| 96 |
+
as
|
| 97 |
+
assume
|
| 98 |
+
at
|
| 99 |
+
b
|
| 100 |
+
base
|
| 101 |
+
bases
|
| 102 |
+
be
|
| 103 |
+
below
|
| 104 |
+
between
|
| 105 |
+
bisector
|
| 106 |
+
bisectors
|
| 107 |
+
bisects
|
| 108 |
+
blue
|
| 109 |
+
both
|
| 110 |
+
by
|
| 111 |
+
c
|
| 112 |
+
calculator
|
| 113 |
+
center
|
| 114 |
+
centers
|
| 115 |
+
centimeters
|
| 116 |
+
central
|
| 117 |
+
centroid
|
| 118 |
+
chord
|
| 119 |
+
chords
|
| 120 |
+
circle
|
| 121 |
+
circles
|
| 122 |
+
circumference
|
| 123 |
+
circumscribed
|
| 124 |
+
circumscribes
|
| 125 |
+
cm
|
| 126 |
+
cm^{2}
|
| 127 |
+
collinear
|
| 128 |
+
common
|
| 129 |
+
complementary
|
| 130 |
+
composite
|
| 131 |
+
congruent
|
| 132 |
+
connecting
|
| 133 |
+
corners
|
| 134 |
+
cosines
|
| 135 |
+
cut
|
| 136 |
+
d
|
| 137 |
+
degree
|
| 138 |
+
determine
|
| 139 |
+
diagonal
|
| 140 |
+
diagonals
|
| 141 |
+
diagram
|
| 142 |
+
diameter
|
| 143 |
+
diameters
|
| 144 |
+
distance
|
| 145 |
+
drawn
|
| 146 |
+
e
|
| 147 |
+
each
|
| 148 |
+
elm
|
| 149 |
+
equal
|
| 150 |
+
equidistant
|
| 151 |
+
equilateral
|
| 152 |
+
exact
|
| 153 |
+
express
|
| 154 |
+
f
|
| 155 |
+
factor
|
| 156 |
+
feet
|
| 157 |
+
figure
|
| 158 |
+
figures
|
| 159 |
+
find
|
| 160 |
+
for
|
| 161 |
+
form
|
| 162 |
+
formed
|
| 163 |
+
four
|
| 164 |
+
from
|
| 165 |
+
ft
|
| 166 |
+
ft^{2}
|
| 167 |
+
g
|
| 168 |
+
given
|
| 169 |
+
green
|
| 170 |
+
h
|
| 171 |
+
half
|
| 172 |
+
has
|
| 173 |
+
have
|
| 174 |
+
having
|
| 175 |
+
height
|
| 176 |
+
hexagon
|
| 177 |
+
how
|
| 178 |
+
hypotenuse
|
| 179 |
+
i
|
| 180 |
+
if
|
| 181 |
+
in
|
| 182 |
+
in^{2}
|
| 183 |
+
incenter
|
| 184 |
+
inches
|
| 185 |
+
inscribed
|
| 186 |
+
inside
|
| 187 |
+
intersect
|
| 188 |
+
intersected
|
| 189 |
+
intersecting
|
| 190 |
+
into
|
| 191 |
+
is
|
| 192 |
+
isosceles
|
| 193 |
+
it
|
| 194 |
+
its
|
| 195 |
+
j
|
| 196 |
+
k
|
| 197 |
+
kite
|
| 198 |
+
l
|
| 199 |
+
law
|
| 200 |
+
legs
|
| 201 |
+
length
|
| 202 |
+
lengths
|
| 203 |
+
let
|
| 204 |
+
lieson
|
| 205 |
+
line
|
| 206 |
+
linear
|
| 207 |
+
lines
|
| 208 |
+
long
|
| 209 |
+
m
|
| 210 |
+
m^{2}
|
| 211 |
+
major
|
| 212 |
+
make
|
| 213 |
+
makes
|
| 214 |
+
measure
|
| 215 |
+
measurement
|
| 216 |
+
measures
|
| 217 |
+
median
|
| 218 |
+
medians
|
| 219 |
+
meet
|
| 220 |
+
meters
|
| 221 |
+
midpoint
|
| 222 |
+
midpoints
|
| 223 |
+
midsegment
|
| 224 |
+
midsegments
|
| 225 |
+
miles
|
| 226 |
+
millimeters
|
| 227 |
+
minor
|
| 228 |
+
mm
|
| 229 |
+
mm^{2}
|
| 230 |
+
must
|
| 231 |
+
n
|
| 232 |
+
o
|
| 233 |
+
octagon
|
| 234 |
+
of
|
| 235 |
+
off
|
| 236 |
+
on
|
| 237 |
+
one
|
| 238 |
+
otherwise
|
| 239 |
+
p
|
| 240 |
+
pair
|
| 241 |
+
parallel
|
| 242 |
+
parallelogram
|
| 243 |
+
pentagon
|
| 244 |
+
pentagons
|
| 245 |
+
perimeter
|
| 246 |
+
perimeters
|
| 247 |
+
perpendicular
|
| 248 |
+
plum
|
| 249 |
+
point
|
| 250 |
+
points
|
| 251 |
+
polygon
|
| 252 |
+
polygons
|
| 253 |
+
proportion
|
| 254 |
+
pythagorean
|
| 255 |
+
q
|
| 256 |
+
quadrilateral
|
| 257 |
+
r
|
| 258 |
+
radius
|
| 259 |
+
ratio
|
| 260 |
+
ray
|
| 261 |
+
rectangle
|
| 262 |
+
red
|
| 263 |
+
refer
|
| 264 |
+
region
|
| 265 |
+
regular
|
| 266 |
+
respectively
|
| 267 |
+
rhombus
|
| 268 |
+
right
|
| 269 |
+
s
|
| 270 |
+
scale
|
| 271 |
+
sector
|
| 272 |
+
segment
|
| 273 |
+
segments
|
| 274 |
+
shaded
|
| 275 |
+
shown
|
| 276 |
+
side
|
| 277 |
+
sides
|
| 278 |
+
similar
|
| 279 |
+
sines
|
| 280 |
+
so
|
| 281 |
+
solve
|
| 282 |
+
special
|
| 283 |
+
square
|
| 284 |
+
stated
|
| 285 |
+
straight
|
| 286 |
+
such
|
| 287 |
+
sum
|
| 288 |
+
suppose
|
| 289 |
+
t
|
| 290 |
+
tangent
|
| 291 |
+
tangents
|
| 292 |
+
that
|
| 293 |
+
the
|
| 294 |
+
theorem
|
| 295 |
+
this
|
| 296 |
+
times
|
| 297 |
+
to
|
| 298 |
+
trapezoid
|
| 299 |
+
triangle
|
| 300 |
+
triangles
|
| 301 |
+
triple
|
| 302 |
+
twice
|
| 303 |
+
two
|
| 304 |
+
u
|
| 305 |
+
units
|
| 306 |
+
unless
|
| 307 |
+
use
|
| 308 |
+
v
|
| 309 |
+
value
|
| 310 |
+
variable
|
| 311 |
+
vertex
|
| 312 |
+
w
|
| 313 |
+
what
|
| 314 |
+
where
|
| 315 |
+
which
|
| 316 |
+
with
|
| 317 |
+
would
|
| 318 |
+
x
|
| 319 |
+
y
|
| 320 |
+
yards
|
| 321 |
+
yd^{2}
|
| 322 |
+
z
|
vocab/vocab_tgt.txt
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[PAD]
|
| 2 |
+
[SOS]
|
| 3 |
+
[EOS]
|
| 4 |
+
V0
|
| 5 |
+
V1
|
| 6 |
+
V2
|
| 7 |
+
V3
|
| 8 |
+
V4
|
| 9 |
+
V5
|
| 10 |
+
V6
|
| 11 |
+
C0.5
|
| 12 |
+
C2
|
| 13 |
+
C3
|
| 14 |
+
C4
|
| 15 |
+
C5
|
| 16 |
+
C6
|
| 17 |
+
C8
|
| 18 |
+
C60
|
| 19 |
+
C90
|
| 20 |
+
C180
|
| 21 |
+
C360
|
| 22 |
+
ArcSeg_Area
|
| 23 |
+
Chord2_Ang
|
| 24 |
+
Circle_D_Area
|
| 25 |
+
Circle_D_Circum
|
| 26 |
+
Circle_R_Area
|
| 27 |
+
Circle_R_Circum
|
| 28 |
+
Cos_Law
|
| 29 |
+
Equal
|
| 30 |
+
Gcos
|
| 31 |
+
Geo_Mean
|
| 32 |
+
Get
|
| 33 |
+
Gougu
|
| 34 |
+
Gsin
|
| 35 |
+
Gtan
|
| 36 |
+
Iso_Tri_Ang
|
| 37 |
+
Kite_Area
|
| 38 |
+
Median
|
| 39 |
+
Multiple
|
| 40 |
+
Ngon_Angsum
|
| 41 |
+
PRK_Perim
|
| 42 |
+
Para_Area
|
| 43 |
+
Proportion
|
| 44 |
+
RNgon_B_Area
|
| 45 |
+
RNgon_H_Area
|
| 46 |
+
RNgon_L_Area
|
| 47 |
+
Ratio
|
| 48 |
+
Rect_Area
|
| 49 |
+
Rhom_Area
|
| 50 |
+
Sin_Law
|
| 51 |
+
Sum
|
| 52 |
+
TanSec_Ang
|
| 53 |
+
Trap_Area
|
| 54 |
+
Tria_BH_Area
|
| 55 |
+
Tria_SAS_Area
|
| 56 |
+
N0
|
| 57 |
+
N1
|
| 58 |
+
N2
|
| 59 |
+
N3
|
| 60 |
+
N4
|
| 61 |
+
N5
|
| 62 |
+
N6
|
| 63 |
+
N7
|
| 64 |
+
N8
|
| 65 |
+
N9
|
| 66 |
+
N10
|
| 67 |
+
N11
|