AU-VN-ResearchGroup commited on
Commit
21fda44
·
1 Parent(s): 492edc7
Files changed (1) hide show
  1. inference.py +225 -0
inference.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ nltk.download("punkt")
3
+ from nltk.tokenize import sent_tokenize
4
+ from src.config.configs import *
5
+ from src.create_embeddings import *
6
+ from src.dataset import *
7
+ from src.models.baseline import *
8
+ from src.models.transformer_encoder_based import *
9
+ from src.models.hybrid_embeddings_model import *
10
+ from src.models.penta_embeddings_model import *
11
+ from src.models.hierarchy_BiLSTM import *
12
+ from args import init_argparse, check_valid_args
13
+ from src.utils import *
14
+ import tensorflow as tf
15
+ import pandas as pd
16
+ from args import init_infer_argparse, check_valid_args
17
+ import warnings
18
+ warnings.filterwarnings("ignore")
19
+ import re
20
+
21
+
22
+ params = Params()
23
+ CHECK_POINT_MAP = {
24
+ "hybrid":{"none": params.HYBRID_NOR_MODEL_DIR, "glove": params.HYBRID_GLOVE_MODEL_DIR, "bert": params.HYBRID_BERT_MODEL_DIR},
25
+ "tf_encoder": {"none": params.TF_BASED_NOR_MODEL_DIR, "glove": params.TF_BASED_GLOVE_MODEL_DIR, "bert": params.TF_BASED_BERT_MODEL_DIR},
26
+ "penta": {"none":params.PENTA_NOR_MODEL_DIR, "glove":params.PENTA_GLOVE_MODEL_DIR, "bert": params.PENTA_BERT_MODEL_DIR},
27
+ "bilstm":{"none":params.PENTA_BILSTM_NOR_MODEL_DIR, "glove": params.PENTA_BILSTM_GLOVE_MODEL_DIR, "bert":params.PENTA_BILSTM_BERT_MODEL_DIR}}
28
+
29
+
30
+ def read_infer_txt(infer_txt):
31
+ with open(infer_txt, "r") as f:
32
+ return f.readlines()
33
+
34
+
35
+ def replace_numeric_chars_with_at(list_sencentes):
36
+ """
37
+ Replace numeric characters with "@"
38
+ """
39
+ result = []
40
+ for sent in list_sencentes:
41
+ res = re.sub(r'\d', '@', sent)
42
+ result.append(res)
43
+ return result
44
+
45
+
46
+
47
+ def infer(abstract, verbose = True):
48
+ """
49
+ Get prediction from abstract
50
+ args:
51
+ - abstract: All sentences of abstract in one string.
52
+ """
53
+ # Init infer parser
54
+ parser = init_infer_argparse()
55
+ args = parser.parse_args()
56
+
57
+ #Check valid args
58
+ if not check_valid_args(args):
59
+ exit(1)
60
+
61
+ # Sentencizer
62
+ list_sens = sent_tokenize(abstract)
63
+
64
+ # Store original sentence
65
+ list_sens_org = list_sens
66
+
67
+ #Replace numeric at @
68
+ list_sens = replace_numeric_chars_with_at(list_sens)
69
+
70
+ # Extract features
71
+ line_samples = get_information_infer(list_sens)
72
+
73
+ # Create dataframe
74
+ infer_df = pd.DataFrame(line_samples)
75
+
76
+ # Get features
77
+ infer_sentences = infer_df['text']
78
+ infer_chars = [split_into_char(line) for line in infer_sentences]
79
+
80
+ # Convert to tensor
81
+ infer_sentences = np.array(infer_sentences, dtype=str)
82
+ infer_chars = np.array(infer_chars,dtype= str)
83
+
84
+ # Define args variable
85
+ model_arg = str(args.model).lower()
86
+ embedding_arg = str(args.embedding).lower()
87
+
88
+ embeddings = Embeddings()
89
+ dataset = Dataset(train_txt=params.TRAIN_DIR, val_txt=params.VAL_DIR, test_txt=params.TEST_DIR)
90
+
91
+ # Word_vectorizer, word_embed
92
+ word_vectorizer, word_embed = embeddings._get_word_embeddings(dataset.train_sentences)
93
+ char_vectorizer, char_embed = embeddings._get_char_embeddings(dataset.train_char)
94
+
95
+
96
+ # Get type embedding
97
+ glove_embed = embeddings._get_glove_embeddings(vectorizer=word_vectorizer, glove_txt=params.GLOVE_DIR) if str(embedding_arg).lower() == "glove" else None
98
+
99
+ # Get stats features
100
+ line_ids_one_hot = tf.one_hot(infer_df['line_id'].to_numpy(), depth = params.LINE_IDS_DEPTH)
101
+
102
+ length_lines_one_hot = tf.one_hot(infer_df['length_lines'].to_numpy(), depth = params.LENGTH_LINES_DEPTH)
103
+
104
+ total_lines_one_hot = tf.one_hot(infer_df['total_lines'].to_numpy(), depth= params.TOTAL_LINES_DEPTH)
105
+
106
+
107
+ if embedding_arg == "bert":
108
+ bert_process, bert_layer = embeddings._get_bert_embeddings()
109
+ else:
110
+ bert_process, bert_layer = None, None
111
+
112
+ # Define model checkpoint dir
113
+ model_dir = CHECK_POINT_MAP[model_arg][embedding_arg]
114
+
115
+ #--------------------------------HYBRID-INPUT-MODEL-----------------------------------
116
+ if model_arg == "hybrid":
117
+ print("-------------Inference Hybrid model with pretrained embedding: {}-------------------".format(embedding_arg))
118
+
119
+ hybrid_obj = HybridEmbeddingModel(word_vectorizer=word_vectorizer, char_vectorizer=char_vectorizer, word_embed=word_embed,
120
+ char_embed=char_embed, pretrained_embedding=embedding_arg,
121
+ glove_embed=glove_embed, bert_process=bert_process, bert_layer=bert_layer)
122
+ hybrid_model = hybrid_obj._get_model()
123
+
124
+ try:
125
+ hybrid_model.load_weights(model_dir + "/best_model.ckpt")
126
+ print("Sucessfully load model weights from {}".format(model_dir + "/best_model.ckpt"))
127
+ except Exception as e:
128
+ print(e)
129
+ exit()
130
+
131
+ preds = hybrid_model.predict(x = (infer_sentences, infer_chars))
132
+
133
+ #--------------------------------TF_ENCODER-MODEL-----------------------------------
134
+
135
+ elif model_arg == "tf_encoder":
136
+ print("-------------Inference TransformerEncoder-based with pretrained embedding: {}-------------------".format(embedding_arg))
137
+
138
+ tf_obj = TransformerModel(word_vectorizer=word_vectorizer, char_vectorizer=char_vectorizer, word_embed=word_embed, char_embed = char_embed,
139
+ num_layers=params.NUM_LAYERS, d_model=params.D_MODEL, nhead=params.N_HEAD,
140
+ dim_feedforward=params.DIM_FEEDFORWARD,pretrained_embedding=embedding_arg, glove_embed=glove_embed,
141
+ bert_process=bert_process, bert_layer= bert_layer)
142
+
143
+ tf_model = tf_obj._get_model()
144
+
145
+ try:
146
+ tf_model.load_weights(model_dir + "/best_model.ckpt")
147
+ except Exception as e:
148
+ print(e)
149
+ exit()
150
+
151
+ print("Sucessfully load model weights from {}".format(model_dir + "/best_model.ckpt"))
152
+
153
+ # Get prediction
154
+ preds = tf_model.predict(x = (infer_sentences, infer_chars, line_ids_one_hot, length_lines_one_hot, total_lines_one_hot))
155
+
156
+ #--------------------------------HIERARCHY_BILSTM MODEL-----------------------------------
157
+
158
+ elif model_arg == "bilstm":
159
+ print("-------------Inference Hierarchy Bi-LSTM with pretrained embedding: {}-------------------".format(embedding_arg))
160
+
161
+ bilstm_obj = HierarchyBiLSTM(word_vectorizer=word_vectorizer, char_vectorizer=char_vectorizer, word_embed=word_embed, char_embed = char_embed,
162
+ pretrained_embedding=embedding_arg, glove_embed=glove_embed,
163
+ bert_process=bert_process, bert_layer= bert_layer)
164
+ bilstm_model = bilstm_obj._get_model()
165
+
166
+ try:
167
+ bilstm_model.load_weights(model_dir + "/best_model.ckpt")
168
+ except Exception as e:
169
+ print(e)
170
+ exit()
171
+
172
+ print("Sucessfully load model weights from {}".format(model_dir + "/best_model.ckpt"))
173
+
174
+ # Make sure input has suitable data types
175
+ infer_sentences = np.array(infer_sentences, dtype=str)
176
+ infer_chars = np.array(infer_chars,dtype= str)
177
+
178
+ # Get prediction
179
+ preds = bilstm_model.predict(x = (infer_sentences, infer_chars, line_ids_one_hot, length_lines_one_hot, total_lines_one_hot))
180
+
181
+ #-----------------------PENTA-EMBEDDING MODEL-------------------------------------------
182
+ else:
183
+
184
+ print("-------------Inference Penta-embedding model with pretrained embedding: {}-------------------".format(embedding_arg))
185
+
186
+ penta_obj = PentaEmbeddingModel(word_vectorizer=word_vectorizer, char_vectorizer=char_vectorizer, word_embed=word_embed, char_embed = char_embed,
187
+ pretrained_embedding=embedding_arg, glove_embed=glove_embed, bert_process=bert_process, bert_layer = bert_layer)
188
+ penta_model = penta_obj._get_model()
189
+
190
+ try:
191
+ penta_model.load_weights(model_dir + "/best_model.ckpt")
192
+ except Exception as e:
193
+ print(e)
194
+ exit()
195
+
196
+ print("Sucessfully load model weights from {}".format(model_dir + "/best_model.ckpt"))
197
+ # Get prediction
198
+ preds = penta_model.predict(x = (infer_sentences, infer_chars, line_ids_one_hot, length_lines_one_hot, total_lines_one_hot))
199
+
200
+
201
+ # Get prediction index
202
+ class_index = dataset.classes
203
+ preds_index = np.argmax(preds, axis = 1)
204
+ preds_class = [class_index[preds_index[i]] for i in range(0, len(preds_index))]
205
+
206
+ if verbose:
207
+ for i, sent in enumerate(list_sens_org):
208
+ print("{} --> Pred: {} | Prob: {}".format(sent, preds_class[i], preds[i][preds_index[i]]))
209
+
210
+ return preds_class
211
+
212
+
213
+ if __name__ == "__main__":
214
+ params = Params()
215
+ dataset = Dataset(train_txt=params.TRAIN_DIR, val_txt=params.VAL_DIR, test_txt=params.TEST_DIR)
216
+ infer_txt = "infer_abstract.txt"
217
+ abstract_list = read_infer_txt(infer_txt=infer_txt)
218
+ for i, abtract in enumerate(abstract_list):
219
+ print("------------Predict abstract number {}--------------".format(i+1))
220
+ preds = infer(abstract=abtract)
221
+ print("Result:", preds)
222
+ print()
223
+
224
+
225
+