import argparse valid_models = ["hybrid", "tf_encoder", "penta", "bilstm"] valid_embeddings = ["none", "glove", "bert"] def init_argparse(): """ CLI Arguments for training phase """ parser = argparse.ArgumentParser( prog="Training Model", usage="Arguments: --model, --embedding, --batch_size(optional).", description="""Example: python train.py --model penta --embedding None --dataset_size 1 --batch_size 32 -- epochs 3 --model: Type of model for training. Expect one of ['hybrid', 'tf_encoder', 'penta', "bilstm"]. --embedding: Type of word-level embedding. Expect one of [None, 'glove', 'bert']. --dataset_size: Dataset size for training. Default: 1 (All dataset). --batch_size: Batch size. Default: 32. --epochs: Epochs. Default: 3.""" ) parser.add_argument("--model", required=True, help='Type of model: hybrid, att, tf_encoder, penta') parser.add_argument( "--embedding", required=True, help='Word embedding: None, Glove or BERT' ) parser.add_argument( "--dataset_size", required=False, default= 1, help= "Dataset size" ) parser.add_argument( "--batch_size",required=False ,default = 32, help='Batch size' ) parser.add_argument( "--epochs",required=False ,default = 3, help='Epochs' ) return parser def init_infer_argparse(): """ CLI Arguments for infer phase """ parser = argparse.ArgumentParser( prog="Training Model", usage="Arguments: --model, --embedding", description="""Example: python infer.py --model penta --embedding None --model: Type of model for training. Expect one of ['hybrid', 'tf_encoder', 'bilstm']. --embedding: Type of word-level embedding. Expect one of [None, 'glove', 'bert']. """ ) parser.add_argument("--model", required=True, help='Type of model: hybrid, tf_encoder, penta, bilstm') parser.add_argument( "--embedding", required=True, help='Word embedding: None, Glove or BERT' ) return parser def check_valid_args(args): """ Check valid input from CLI """ if str(args.model).lower() not in valid_models: raise TypeError("No model named: {}, expected valid model belongs to {}".format(args.model, valid_models)) elif str(args.embedding).lower() not in valid_embeddings: raise TypeError("No embedding type named: {}, expeted valid embedding belongs to {}".format(args.embedding, valid_embeddings)) return True if __name__ == "__main__": parser = init_argparse() args = parser.parse_args()