| |
| |
| import os |
| |
| |
| |
| |
| import pandas as pd |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| import tensorflow as tf |
| from tensorflow import keras |
| from tensorflow.keras import layers |
| from transformers import BertTokenizer |
| from transformers import TFBertModel |
|
|
| MAX_SEQUENCE_LENGTH = 400 |
|
|
def create_bert_classification_model(bert_model,
                                     num_train_layers=0,
                                     max_sequence_length=MAX_SEQUENCE_LENGTH,
                                     num_filters=(100, 100, 50, 25),
                                     kernel_sizes=(3, 4, 5, 10),
                                     hidden_size=200,
                                     hidden2_size=100,
                                     dropout=0.1,
                                     learning_rate=0.001,
                                     label_smoothing=0.03
                                     ):
    """
    Build and compile a binary classification model on top of a TFBertModel.

    A CNN head (parallel Conv1D branches with different kernel sizes, each
    global-max-pooled) is applied to the BERT token embeddings, followed by
    two dense layers and a sigmoid output.

    Args:
        bert_model: a HuggingFace ``TFBertModel`` instance used as the encoder.
        num_train_layers: number of top BERT encoder layers to leave trainable.
            0 freezes the whole encoder, 12 unfreezes everything, any other
            value freezes all weights except the top ``num_train_layers``
            encoder layers (matched by name suffix).
        max_sequence_length: token length of the three input tensors.
        num_filters: filter count for each parallel Conv1D branch.
        kernel_sizes: kernel size for each parallel Conv1D branch
            (zipped with ``num_filters``, so the shorter of the two wins).
        hidden_size: units in the first dense layer after the CNN head.
        hidden2_size: units in the second dense layer.
        dropout: dropout rate applied after each dense layer.
        learning_rate: Adam learning rate.
        label_smoothing: label smoothing for the focal loss.

    Returns:
        A compiled ``tf.keras.Model`` mapping
        (input_ids, token_type_ids, attention_mask) -> sigmoid probability.
    """
    if num_train_layers == 0:
        # Freeze the entire encoder: only the CNN/dense head trains.
        bert_model.trainable = False
    elif num_train_layers == 12:
        # Full fine-tuning of all 12 encoder layers.
        bert_model.trainable = True
    else:
        # Partially unfreeze: collect name suffixes ('_11', '_10', ...) for the
        # top num_train_layers encoder layers and freeze every weight whose
        # name does not contain one of them.
        # NOTE(review): substring matching on '_<n>' assumes HF TF weight names
        # embed the layer index this way; '_1' would also match '_10'/'_11',
        # which only matters for num_train_layers >= 11 — confirm if extending.
        retrain_layers = []
        for retrain_layer_number in range(num_train_layers):
            layer_code = '_' + str(11 - retrain_layer_number)
            retrain_layers.append(layer_code)

        for w in bert_model.weights:
            if not any(x in w.name for x in retrain_layers):
                # _trainable is a private Keras attribute; setting it per-weight
                # is the usual workaround for freezing individual BERT layers.
                w._trainable = False

    # BUGFIX: use the max_sequence_length parameter (previously the module
    # constant MAX_SEQUENCE_LENGTH was hard-coded, silently ignoring the arg).
    input_ids = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='input_ids')
    token_type_ids = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='token_type_ids')
    attention_mask = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='attention_mask')

    bert_inputs = {'input_ids': input_ids,
                   'token_type_ids': token_type_ids,
                   'attention_mask': attention_mask}

    bert_out = bert_model(bert_inputs)

    # Several pooling candidates are computed; only the full token sequence
    # (cnn_token) feeds the CNN head below — the others are unused here.
    pooler_token = bert_out[1]
    cls_token = bert_out[0][:, 0, :]
    bert_out_avg = tf.math.reduce_mean(bert_out[0], axis=1)
    cnn_token = bert_out[0]

    # Parallel Conv1D branches over the token embeddings, one per kernel size,
    # each reduced to a fixed-size vector by global max pooling.
    conv_layers_for_all_kernel_sizes = []
    for kernel_size, filters in zip(kernel_sizes, num_filters):
        conv_layer = tf.keras.layers.Conv1D(filters=filters, kernel_size=kernel_size, activation='relu')(cnn_token)
        conv_layer = tf.keras.layers.GlobalMaxPooling1D()(conv_layer)
        conv_layers_for_all_kernel_sizes.append(conv_layer)

    conv_output = tf.keras.layers.concatenate(conv_layers_for_all_kernel_sizes, axis=1)

    hidden = tf.keras.layers.Dense(hidden_size, activation='relu', name='hidden_layer')(conv_output)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    hidden = tf.keras.layers.Dense(hidden2_size, activation='relu', name='hidden_layer2')(hidden)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    classification = tf.keras.layers.Dense(1, activation='sigmoid', name='classification_layer')(hidden)

    classification_model = tf.keras.Model(inputs=[input_ids, token_type_ids, attention_mask], outputs=[classification])

    # Focal loss with class balancing: suited to the imbalanced binary task;
    # from_logits=False because the output layer already applies sigmoid.
    classification_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                                 loss=tf.keras.losses.BinaryFocalCrossentropy(
                                     gamma=2.0, from_logits=False, apply_class_balancing=True, label_smoothing=label_smoothing
                                 ),
                                 metrics=['accuracy']
                                 )
    return classification_model
|
|
|
|
| f_one_or_zero = lambda x: 1 if x > 0.5 else 0 |
|
|
def run_inference_model(conversations):
    """Tokenize conversation text and return the model's sigmoid predictions.

    Args:
        conversations: a string or list of strings to classify.

    Returns:
        The raw output of ``inference_model.predict`` (per-example
        probabilities from the sigmoid classification layer).
    """
    # Encode to fixed-length tensors matching the model's input layers.
    encoded = tokenizer(
        conversations,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        padding='max_length',
        return_tensors='tf',
    )

    # Order must match the model's input list:
    # [input_ids, token_type_ids, attention_mask].
    model_inputs = [
        encoded.input_ids,
        encoded.token_type_ids,
        encoded.attention_mask,
    ]

    return inference_model.predict(model_inputs)
|
|
|
|
|
|
# --- Module-level inference setup -----------------------------------------
# These statements run at import time and populate the globals that
# run_inference_model() depends on (tokenizer, inference_model).

# Pretrained checkpoint for both tokenizer and encoder; must match the
# checkpoint used when the fine-tuned weights below were saved.
model_checkpoint = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(model_checkpoint)

bert_model = TFBertModel.from_pretrained(model_checkpoint)

# Rebuild the architecture with default hyperparameters, then restore the
# fine-tuned weights from disk (architecture must match the saved weights).
inference_model = create_bert_classification_model(bert_model=bert_model)

save_path = 'bert_cnn_ensemble_resample_uncased_mdl.h5'
inference_model.load_weights(save_path)
|
|
|
|
|
|
|
|
|
|