| import tensorflow as tf |
| import tensorflow_hub as hub |
| import tensorflow_text as text |
| import os |
| import sys |
|
|
# Resolve the repository root (two directory levels above this file) and
# append it to the module search path so modules located under it can be
# imported by absolute name.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(
    os.path.join(current_dir, *([os.pardir] * 2))
)
sys.path.append(project_root)
|
|
|
|
class Params:
    """Central configuration: dataset/artefact paths and model hyper-parameters.

    Every attribute is a public UPPER_CASE field set in ``__init__``.  All
    filesystem paths except ``CHECK_POINT_DIR`` are built by plain string
    concatenation of a ``/``-prefixed suffix onto the module-level
    ``project_root``, so every suffix must start with ``/``.
    """

    def __init__(self):
        # --- Dataset ---------------------------------------------------------
        # Despite the ``_DIR`` suffix, TRAIN_DIR / VAL_DIR / TEST_DIR name
        # individual text files inside DATA_DIR.
        self.DATA_DIR = (
            project_root
            + "/data/data/PubMed_20k_RCT_numbers_replaced_with_at_sign"
        )
        self.TRAIN_DIR = self.DATA_DIR + "/train.txt"
        self.VAL_DIR = self.DATA_DIR + "/dev.txt"
        self.TEST_DIR = self.DATA_DIR + "/test.txt"
        # NOTE(review): unlike every other path here, this one is absolute
        # from the filesystem root rather than built on ``project_root``.
        # It is kept unchanged in case callers use it as a suffix; confirm
        # the intent.
        self.CHECK_POINT_DIR = "/checkpoints"

        # --- Training --------------------------------------------------------
        self.BATCH_SIZE = 32
        self.NUM_CLASSES = 5
        # NOTE(review): the optimizer, loss and metric below are single
        # objects created once per Params instance.  If one Params object is
        # used to compile several models, those models share these (stateful)
        # objects.  Create a new Params per model if that is not wanted.
        self.OPTIMIZER = tf.keras.optimizers.Adam()
        self.EPOCHS = 3
        self.LOSS = tf.keras.losses.CategoricalCrossentropy()
        self.METRICS = tf.keras.metrics.CategoricalAccuracy()
        # Presumably the validation-prefixed name of the metric above, used by
        # checkpoint / early-stopping callbacks.  Confirm against the callers.
        self.MONITOR = "val_categorical_accuracy"

        # --- Word-level token features ---------------------------------------
        # Presumably the vectorizer vocabulary size, output sequence length
        # and word-embedding width.  Confirm against the model code.
        self.VOCAB_SIZE = 68000
        self.SEQ_LENGTH = 55
        self.WORD_OUTPUT_DIM = 128

        # --- Character-level features ----------------------------------------
        self.CHAR_VOCAB = 28
        self.CHAR_LENGTH = 290
        self.CHAR_OUTPUT_DIM = 25

        # --- Positional features ---------------------------------------------
        # NOTE(review): named like encoding depths for line-id / total-lines /
        # line-length features.  Confirm against the data pipeline.
        self.LINE_IDS_DEPTH = 15
        self.TOTAL_LINES_DEPTH = 20
        self.LENGTH_LINES_DEPTH = 55

        # --- Pretrained embeddings -------------------------------------------
        # GloVe vectors file (the file name indicates the 6B / 200-d release).
        self.GLOVE_DIR = project_root + "/glove/glove/glove.6B.200d.txt"

        # Local BERT preprocessing and encoder model directories (presumably
        # loaded through tensorflow_hub).  Both paths share one prefix.  This
        # keeps the "/" separator that BERT_EMBED_DIR previously lacked: that
        # path yielded "<project_root>bert/bert/..." instead of a
        # sub-directory.
        bert_root = project_root + "/bert/bert"
        self.BERT_PROCESS_DIR = bert_root + "/bert_en_uncased_preprocess_3"
        self.BERT_EMBED_DIR = bert_root + "/experts_bert_pubmed_2"

        # --- Transformer encoder hyper-parameters ----------------------------
        self.NUM_LAYERS = 4
        self.N_HEAD = 8
        self.DIM_FEEDFORWARD = 256
        self.D_MODEL = 128

        # --- Saved model checkpoints -----------------------------------------
        # All checkpoints live under one root.  Building them from shared
        # prefixes keeps the layout consistent across the model families.
        ckpt_root = project_root + "/checkpoints/checkpoints"
        penta_root = ckpt_root + "/penta_model"

        embedding_root = penta_root + "/penta_embedding"
        self.PENTA_NOR_MODEL_DIR = embedding_root + "/nor_model"
        self.PENTA_BERT_MODEL_DIR = embedding_root + "/bert_model"
        self.PENTA_GLOVE_MODEL_DIR = embedding_root + "/glove_model"

        bilstm_root = penta_root + "/hierarchy_BiLSTM"
        self.PENTA_BILSTM_NOR_MODEL_DIR = bilstm_root + "/nor_model"
        self.PENTA_BILSTM_GLOVE_MODEL_DIR = bilstm_root + "/glove_model"
        # NOTE(review): this path previously omitted the doubled
        # "checkpoints" segment that every sibling uses.  It now follows the
        # shared layout.  If the model really was saved under
        # "<project_root>/checkpoints/penta_model/...", move it there.
        self.PENTA_BILSTM_BERT_MODEL_DIR = bilstm_root + "/bert_model"

        transformer_root = penta_root + "/transformer_model"
        self.TF_BASED_NOR_MODEL_DIR = transformer_root + "/nor_model"
        self.TF_BASED_GLOVE_MODEL_DIR = transformer_root + "/glove_model"
        self.TF_BASED_BERT_MODEL_DIR = transformer_root + "/bert_model"

        hybrid_root = ckpt_root + "/hybrid_model"
        self.HYBRID_NOR_MODEL_DIR = hybrid_root + "/nor_model"
        self.HYBRID_GLOVE_MODEL_DIR = hybrid_root + "/glove_model"
        self.HYBRID_BERT_MODEL_DIR = hybrid_root + "/bert_model"

        # --- Outputs and saved preprocessing objects -------------------------
        # RESULT_DIR names a file, not a directory.
        self.RESULT_DIR = project_root + "/results.txt"
        # VECTORIZATION keeps its trailing "/" so the file names below can be
        # appended directly.
        self.VECTORIZATION = project_root + "/text_vectorization_obj/"
        self.WORD_VECTORIZATION = self.VECTORIZATION + "tv_layer.pkl"
        self.CHAR_VECTORIZATION = self.VECTORIZATION + "char_tv_layer.pkl"
|
|
|
|
if __name__ == "__main__":
    # Smoke check: print whether the dataset directory and the results file
    # exist at the configured locations (one True/False line each).
    params = Params()
    for path in (params.DATA_DIR, params.RESULT_DIR):
        print(os.path.exists(path))
|
|
|
|
|
|
|
|