|
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
from tensorflow.keras.layers import TextVectorization, Embedding, GlobalAveragePooling1D, Dot, Dense, Input |
|
|
from tensorflow.keras.models import Model |
|
|
|
|
|
class TwoTowerModel(Model):
    """Two-tower (Siamese) text-similarity model.

    Both inputs are encoded by a single *shared* tower (tied weights), the
    cosine similarity of the two embeddings is computed via a normalized dot
    product, and a sigmoid-activated Dense layer maps that score to a match
    probability in [0, 1].
    """

    def __init__(self, vocabulary: list[str], input_keys: list[str], embedding_dim: int, output_embedding_dim: int, **kwargs):
        """Build the shared tower and the scoring head.

        Args:
            vocabulary: Token list handed to ``TextVectorization``.
            input_keys: Exactly two dict keys naming the inputs in ``call``.
            embedding_dim: Width of the token embedding table.
            output_embedding_dim: Width of the tower's output embedding.
            **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).

        Raises:
            ValueError: If ``input_keys`` does not hold exactly 2 elements.
        """
        super().__init__(**kwargs)

        # Fail fast: validate before any instance state is assigned or any
        # layer is built (original validated only after mutating self).
        if len(input_keys) != 2:
            raise ValueError("input_keys listesi tam olarak 2 eleman içermelidir.")

        # Retained verbatim so get_config() can round-trip the model through
        # Keras serialization.
        self.vocabulary = vocabulary
        self.input_keys = input_keys
        self.embedding_dim = embedding_dim
        self.output_embedding_dim = output_embedding_dim

        self.vectorize_layer = TextVectorization(
            vocabulary=vocabulary,
            output_mode='int',
            # NOTE(review): each input is reduced to a single token id, which
            # makes the GlobalAveragePooling1D in the tower a no-op over a
            # length-1 sequence — confirm this is intended.
            output_sequence_length=1,
            name='text_vectorization',
            standardize='lower_and_strip_punctuation'
        )
        # One tower instance used for both inputs => shared (tied) weights.
        self.tower = self._create_tower(embedding_dim, output_embedding_dim)
        # normalize=True turns the dot product into cosine similarity in [-1, 1].
        self.dot_product = Dot(axes=-1, normalize=True)
        self.output_prob = Dense(1, activation='sigmoid')

    def _create_tower(self, embedding_dim: int, output_embedding_dim: int) -> Model:
        """Build the shared encoder mapping a raw string to a dense embedding."""
        # vocabulary_size() already includes the padding/OOV slots, so it is
        # the correct input_dim for the Embedding layer.
        vocabulary_size = self.vectorize_layer.vocabulary_size()

        input_layer = Input(shape=(1,), dtype=tf.string)
        vectorized_text = self.vectorize_layer(input_layer)
        embedding_layer = Embedding(input_dim=vocabulary_size, output_dim=embedding_dim)(vectorized_text)
        pooling_layer = GlobalAveragePooling1D()(embedding_layer)
        hidden_layer_1 = Dense(256, activation='relu')(pooling_layer)
        hidden_layer_2 = Dense(128, activation='relu')(hidden_layer_1)
        # Linear projection (no activation) so the downstream cosine
        # similarity is unconstrained in sign.
        output_layer = Dense(output_embedding_dim)(hidden_layer_2)

        return Model(inputs=input_layer, outputs=output_layer, name="inference_tower")

    def call(self, inputs: dict):
        """Score a pair of inputs.

        Args:
            inputs: Dict containing (at least) the two keys given as
                ``input_keys``; values are string tensors — presumably of
                shape (batch, 1) to match the tower's Input spec (verify
                against callers).

        Returns:
            Tensor of shape (batch, 1) holding match probabilities.
        """
        tower_1_output = self.tower(inputs[self.input_keys[0]])
        tower_2_output = self.tower(inputs[self.input_keys[1]])
        similarity_score = self.dot_product([tower_1_output, tower_2_output])
        return self.output_prob(similarity_score)

    def get_embedding(self, input_word: tf.Tensor) -> tf.Tensor:
        """Return the shared-tower embedding for ``input_word`` (inference helper)."""
        return self.tower(input_word)

    def get_config(self):
        """Serialize constructor arguments for Keras save/load."""
        config = super().get_config()
        config.update({
            "vocabulary": self.vocabulary,
            "input_keys": self.input_keys,
            "embedding_dim": self.embedding_dim,
            "output_embedding_dim": self.output_embedding_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the model from ``get_config`` output.

        Every custom key stored by get_config matches an ``__init__``
        parameter name, so the config dict can be splatted directly; base
        Model keys (``name``, ``trainable``, ...) pass through ``**kwargs``.
        """
        return cls(**config)