# two_tower_model.py
import tensorflow as tf
from tensorflow.keras.layers import (
    Dense,
    Dot,
    Embedding,
    GlobalAveragePooling1D,
    Input,
    TextVectorization,
)
from tensorflow.keras.models import Model


class TwoTowerModel(Model):
    """Siamese two-tower model that scores how well two text inputs match.

    Both inputs are encoded by the *same* tower (shared weights): the tower
    vectorizes a string, embeds the resulting token id, and projects it through
    a small MLP. The two tower outputs are compared with cosine similarity and
    a ``Dense(1, activation="sigmoid")`` layer maps that score to a probability.

    Args:
        vocabulary: Tokens used to initialise the ``TextVectorization`` layer.
        input_keys: Exactly two keys; ``call`` reads ``inputs[input_keys[0]]``
            and ``inputs[input_keys[1]]``.
        embedding_dim: Output size of the token ``Embedding`` layer.
        output_embedding_dim: Size of the vector produced by the tower.
        **kwargs: Forwarded unchanged to ``Model.__init__``.

    Raises:
        ValueError: If ``input_keys`` does not contain exactly two elements.
    """

    def __init__(self, vocabulary: list[str], input_keys: list[str], embedding_dim: int,
                 output_embedding_dim: int, **kwargs):
        # Validate before any state is created so a bad call fails fast.
        # Message (Turkish): "the input_keys list must contain exactly 2 elements."
        if len(input_keys) != 2:
            raise ValueError("input_keys listesi tam olarak 2 eleman içermelidir.")

        super().__init__(**kwargs)
        # Constructor arguments are kept verbatim so get_config() can rebuild
        # an identical model.
        self.vocabulary = vocabulary
        self.input_keys = input_keys
        self.embedding_dim = embedding_dim
        self.output_embedding_dim = output_embedding_dim

        self.vectorize_layer = TextVectorization(
            vocabulary=vocabulary,
            output_mode='int',
            # Every string is padded/truncated to exactly one token id, so only
            # the first token left after standardization and splitting is used.
            output_sequence_length=1,
            name='text_vectorization',
            standardize='lower_and_strip_punctuation',
        )
        # A single tower instance serves both inputs (shared/siamese weights).
        self.tower = self._create_tower(embedding_dim, output_embedding_dim)
        # normalize=True L2-normalizes both vectors first -> cosine similarity.
        self.dot_product = Dot(axes=-1, normalize=True)
        # Learnable scale + bias turning the similarity into a probability.
        self.output_prob = Dense(1, activation='sigmoid')

    def _create_tower(self, embedding_dim, output_embedding_dim) -> Model:
        """Build the shared string -> vector tower as a functional sub-model.

        Args:
            embedding_dim: Output size of the ``Embedding`` layer.
            output_embedding_dim: Size of the tower's final ``Dense`` output.

        Returns:
            A ``Model`` named ``"inference_tower"`` taking a string input of
            shape ``(batch, 1)`` and returning a float tensor of shape
            ``(batch, output_embedding_dim)``.
        """
        # TextVectorization's vocabulary size also counts its reserved
        # padding and OOV ids, so every produced id is a valid embedding row.
        vocabulary_size = self.vectorize_layer.vocabulary_size()

        input_layer = Input(shape=(1,), dtype=tf.string)
        vectorized_text = self.vectorize_layer(input_layer)
        embedding_layer = Embedding(input_dim=vocabulary_size, output_dim=embedding_dim)(vectorized_text)
        # Averages over the token axis; with output_sequence_length=1 this
        # effectively just removes that length-1 axis.
        pooling_layer = GlobalAveragePooling1D()(embedding_layer)
        hidden_layer_1 = Dense(256, activation='relu')(pooling_layer)
        hidden_layer_2 = Dense(128, activation='relu')(hidden_layer_1)
        # Linear projection: the raw (unnormalized) embedding.
        output_layer = Dense(output_embedding_dim)(hidden_layer_2)
        return Model(inputs=input_layer, outputs=output_layer, name="inference_tower")

    def call(self, inputs: dict):
        """Return the match probability for the two inputs named by ``input_keys``.

        Args:
            inputs: Mapping that contains both ``input_keys``. Values are
                presumably batches of strings -- TODO confirm the expected
                shape ((batch,) vs (batch, 1)) against the data pipeline.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs in [0, 1].
        """
        tower_1_output = self.tower(inputs[self.input_keys[0]])
        tower_2_output = self.tower(inputs[self.input_keys[1]])
        similarity_score = self.dot_product([tower_1_output, tower_2_output])
        return self.output_prob(similarity_score)

    def get_embedding(self, input_word: tf.Tensor) -> tf.Tensor:
        """Encode ``input_word`` with the shared tower.

        The returned vector is not L2-normalized; normalization only happens
        inside the ``Dot`` layer used by ``call``.
        """
        return self.tower(input_word)

    def get_config(self):
        """Return the base Model config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "vocabulary": self.vocabulary,
            "input_keys": self.input_keys,
            "embedding_dim": self.embedding_dim,
            "output_embedding_dim": self.output_embedding_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from the dict produced by ``get_config``.

        The model-specific keys are popped explicitly (so a malformed config
        raises ``KeyError``); any remaining base keys are forwarded as
        ``**kwargs`` to ``Model.__init__``.
        """
        vocabulary = config.pop("vocabulary")
        input_keys = config.pop("input_keys")
        embedding_dim = config.pop("embedding_dim")
        output_embedding_dim = config.pop("output_embedding_dim")
        return cls(
            vocabulary=vocabulary,
            input_keys=input_keys,
            embedding_dim=embedding_dim,
            output_embedding_dim=output_embedding_dim,
            **config,
        )