AU-VN-ResearchGroup commited on
Commit
765fe8d
·
1 Parent(s): 21fda44
Files changed (1) hide show
  1. args.py +77 -0
args.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ valid_models = ["hybrid", "tf_encoder", "penta", "bilstm"]
5
+ valid_embeddings = ["none", "glove", "bert"]
6
+
7
+
8
+
9
+ def init_argparse():
10
+ """
11
+ CLI Arguments for training phase
12
+ """
13
+ parser = argparse.ArgumentParser(
14
+ prog="Training Model",
15
+ usage="Arguments: --model, --embedding, --batch_size(optional).",
16
+ description="""Example: python train.py --model penta --embedding None --dataset_size 1 --batch_size 32 -- epochs 3
17
+ --model: Type of model for training. Expect one of ['hybrid', 'tf_encoder', 'penta', "bilstm"].
18
+ --embedding: Type of word-level embedding. Expect one of [None, 'glove', 'bert'].
19
+ --dataset_size: Dataset size for training. Default: 1 (All dataset).
20
+ --batch_size: Batch size. Default: 32.
21
+ --epochs: Epochs. Default: 3."""
22
+ )
23
+ parser.add_argument("--model", required=True, help='Type of model: hybrid, att, tf_encoder, penta')
24
+ parser.add_argument(
25
+ "--embedding", required=True,
26
+ help='Word embedding: None, Glove or BERT'
27
+ )
28
+ parser.add_argument(
29
+ "--dataset_size", required=False, default= 1, help= "Dataset size"
30
+ )
31
+ parser.add_argument(
32
+ "--batch_size",required=False ,default = 32,
33
+ help='Batch size'
34
+ )
35
+ parser.add_argument(
36
+ "--epochs",required=False ,default = 3,
37
+ help='Epochs'
38
+ )
39
+ return parser
40
+
41
+
42
+
43
+ def init_infer_argparse():
44
+ """
45
+ CLI Arguments for infer phase
46
+ """
47
+ parser = argparse.ArgumentParser(
48
+ prog="Training Model",
49
+ usage="Arguments: --model, --embedding",
50
+ description="""Example: python infer.py --model penta --embedding None
51
+ --model: Type of model for training. Expect one of ['hybrid', 'tf_encoder', 'bilstm'].
52
+ --embedding: Type of word-level embedding. Expect one of [None, 'glove', 'bert'].
53
+ """
54
+ )
55
+ parser.add_argument("--model", required=True, help='Type of model: hybrid, tf_encoder, penta, bilstm')
56
+ parser.add_argument(
57
+ "--embedding", required=True,
58
+ help='Word embedding: None, Glove or BERT'
59
+ )
60
+ return parser
61
+
62
+
63
+ def check_valid_args(args):
64
+ """
65
+ Check valid input from CLI
66
+ """
67
+ if str(args.model).lower() not in valid_models:
68
+ raise TypeError("No model named: {}, expected valid model belongs to {}".format(args.model, valid_models))
69
+ elif str(args.embedding).lower() not in valid_embeddings:
70
+ raise TypeError("No embedding type named: {}, expeted valid embedding belongs to {}".format(args.embedding, valid_embeddings))
71
+
72
+ return True
73
+
74
+
75
+ if __name__ == "__main__":
76
+ parser = init_argparse()
77
+ args = parser.parse_args()