|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
from tensorflow.keras import layers, Model
|
|
|
from transformers import TFAutoModel
|
|
|
|
|
|
class MixedDataCrossEncoderTF(Model):
    """Score inputs by fusing a pretrained text encoder with numerical features.

    The text branch runs a Hugging Face model loaded with
    ``TFAutoModel.from_pretrained`` and reduces its output to one vector per
    example. The numerical branch passes a dense feature vector through a small
    MLP. The two vectors are concatenated and fed to an MLP classifier that
    ends in a single sigmoid unit.

    Args:
        model_name: Model identifier or path passed to
            ``TFAutoModel.from_pretrained``.
        numerical_feature_dim: Width of the ``numerical_features`` input
            expected by the numerical MLP's ``Input`` layer.
        max_token_len: Stored and serialized via ``get_config`` only; this
            class does not use it to pad or truncate tokens (presumably the
            tokenization pipeline does — confirm with the caller).
        **kwargs: Forwarded to ``tf.keras.Model.__init__`` (e.g. ``name``).
    """

    def __init__(
        self,
        model_name="dbmdz/bert-base-turkish-cased",
        numerical_feature_dim=5132,
        max_token_len=32,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Constructor arguments are kept on the instance so get_config can
        # round-trip them.
        self.model_name = model_name
        self.numerical_feature_dim = numerical_feature_dim
        self.max_token_len = max_token_len

        # Text branch: pretrained transformer encoder.
        self.bert = TFAutoModel.from_pretrained(model_name)

        # Numerical branch: projects the raw feature vector down to 128 dims.
        self.numerical_mlp = tf.keras.Sequential([
            layers.Input(shape=(numerical_feature_dim,)),
            layers.Dense(512, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
        ], name="numerical_mlp")

        self.concatenation = layers.Concatenate()

        # Fusion head: MLP over [text_features, numerical_features] that
        # ends in one sigmoid unit.
        self.classifier = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(1, activation='sigmoid'),
        ], name="classifier")

    @staticmethod
    def _pooled_text_features(encoder_output):
        """Return one vector per example from a Hugging Face encoder output.

        Uses ``pooler_output`` when the encoder output provides a non-None
        one; otherwise falls back to the last-layer hidden state of the first
        token, so ``model_name`` may also name encoders without a pooling head.
        """
        pooled = getattr(encoder_output, "pooler_output", None)
        if pooled is not None:
            return pooled
        # First-token ([CLS]-style) representation from the final layer.
        return encoder_output.last_hidden_state[:, 0, :]

    def call(self, inputs, training=None):
        """Score a batch of mixed text/numerical inputs.

        Args:
            inputs: Mapping with these keys:
                - ``'input_ids'`` and ``'attention_mask'``: tokenizer outputs
                  for the text encoder.
                - ``'numerical_features'``: last dimension must equal
                  ``numerical_feature_dim``.
                - ``'token_type_ids'`` (optional): segment ids for
                  sentence-pair inputs.
            training: Keras training flag. With ``None`` (the default), Keras
                infers the mode from the calling context, e.g. ``fit`` vs
                ``predict``. The flag is forwarded explicitly so dropout and
                batch norm in every branch run in the same mode.

        Returns:
            Tensor of shape ``(batch, 1)`` holding sigmoid scores in ``(0, 1)``.
        """
        encoder_kwargs = {
            "input_ids": inputs['input_ids'],
            "attention_mask": inputs['attention_mask'],
        }
        # A cross-encoder sees (query, document) pairs. Without segment ids
        # the encoder cannot tell the two sequences apart, so forward them
        # when supplied. They are only passed when present because some
        # encoders do not accept this argument.
        if 'token_type_ids' in inputs:
            encoder_kwargs["token_type_ids"] = inputs['token_type_ids']
        bert_output = self.bert(**encoder_kwargs, training=training)
        text_features = self._pooled_text_features(bert_output)

        numerical_processed_features = self.numerical_mlp(
            inputs['numerical_features'], training=training
        )

        combined_features = self.concatenation(
            [text_features, numerical_processed_features]
        )

        prediction_score = self.classifier(combined_features, training=training)
        return prediction_score

    def get_config(self):
        """Return the base Keras config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "model_name": self.model_name,
            "numerical_feature_dim": self.numerical_feature_dim,
            "max_token_len": self.max_token_len,
        })
        return config
|
|
|
|
|
|
|