two_tower_string_similarity / src /two_tower_model.py
hanifekaptan's picture
src klasörüne taşındı
34c4c12 verified
# two_tower_model.py
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization, Embedding, GlobalAveragePooling1D, Dot, Dense, Input
from tensorflow.keras.models import Model
class TwoTowerModel(Model):
    """Two-input string-similarity model built around one shared tower.

    Both inputs are encoded by the *same* tower sub-model (shared weights).
    The two encodings are compared with a normalized dot product and the
    resulting score is passed through a single sigmoid unit.
    """

    def __init__(
        self,
        vocabulary: list[str],
        input_keys: list[str],
        embedding_dim: int,
        output_embedding_dim: int,
        **kwargs,
    ):
        """Build the vectorizer, the shared tower and the scoring head.

        Args:
            vocabulary: Vocabulary handed to the ``TextVectorization`` layer.
            input_keys: The two dictionary keys under which ``call`` finds
                its inputs; must contain exactly two entries.
            embedding_dim: Output size of the token ``Embedding`` layer.
            output_embedding_dim: Size of the vector produced by the tower.
            **kwargs: Forwarded unchanged to ``Model.__init__``.

        Raises:
            ValueError: If ``input_keys`` does not contain exactly 2 items.
        """
        # Validate first so a bad argument fails before any Keras state or
        # layers are created. The message is Turkish for:
        # "input_keys must contain exactly 2 elements."
        if len(input_keys) != 2:
            raise ValueError("input_keys listesi tam olarak 2 eleman içermelidir.")

        super().__init__(**kwargs)

        # Constructor arguments are kept so get_config() can serialize them.
        self.vocabulary = vocabulary
        self.input_keys = input_keys
        self.embedding_dim = embedding_dim
        self.output_embedding_dim = output_embedding_dim

        # output_sequence_length=1 makes the layer emit exactly one token id
        # per input string (padded or truncated to length 1).
        self.vectorize_layer = TextVectorization(
            vocabulary=vocabulary,
            output_mode='int',
            output_sequence_length=1,
            name='text_vectorization',
            standardize='lower_and_strip_punctuation'
        )

        # A single tower instance is reused for both inputs in call().
        # It must be created after self.vectorize_layer, which it reads.
        self.tower = self._create_tower(embedding_dim, output_embedding_dim)
        self.dot_product = Dot(axes=-1, normalize=True)
        self.output_prob = Dense(1, activation='sigmoid')

    def _create_tower(self, embedding_dim, output_embedding_dim) -> Model:
        """Create the functional sub-model that maps a string to a vector.

        Pipeline: string input -> ``self.vectorize_layer`` -> ``Embedding``
        -> ``GlobalAveragePooling1D`` -> ``Dense(256, relu)`` ->
        ``Dense(128, relu)`` -> ``Dense(output_embedding_dim)`` (linear).

        Args:
            embedding_dim: Output size of the ``Embedding`` layer.
            output_embedding_dim: Width of the final linear ``Dense`` layer.

        Returns:
            A Keras ``Model`` named ``"inference_tower"``.
        """
        # The embedding table is sized from the vectorizer's own vocabulary
        # size rather than from len(vocabulary).
        vocabulary_size = self.vectorize_layer.vocabulary_size()

        input_layer = Input(shape=(1,), dtype=tf.string)
        vectorized_text = self.vectorize_layer(input_layer)
        embedding_layer = Embedding(input_dim=vocabulary_size, output_dim=embedding_dim)(vectorized_text)
        pooling_layer = GlobalAveragePooling1D()(embedding_layer)
        hidden_layer_1 = Dense(256, activation='relu')(pooling_layer)
        hidden_layer_2 = Dense(128, activation='relu')(hidden_layer_1)
        output_layer = Dense(output_embedding_dim)(hidden_layer_2)
        return Model(inputs=input_layer, outputs=output_layer, name="inference_tower")

    def call(self, inputs: dict):
        """Score a pair of inputs.

        Args:
            inputs: Mapping that contains an entry for each of the two
                names in ``self.input_keys``.

        Returns:
            The output of the sigmoid ``Dense(1)`` head applied to the
            normalized dot product of the two tower encodings.
        """
        first_key, second_key = self.input_keys[0], self.input_keys[1]
        # Same tower (same weights) for both sides of the pair.
        tower_1_output = self.tower(inputs[first_key])
        tower_2_output = self.tower(inputs[second_key])
        similarity_score = self.dot_product([tower_1_output, tower_2_output])
        return self.output_prob(similarity_score)

    def get_embedding(self, input_word: tf.Tensor) -> tf.Tensor:
        """Return the shared tower's output for ``input_word``.

        NOTE(review): presumably ``input_word`` is a string tensor shaped
        like the tower input, i.e. (batch, 1) - confirm against callers.
        """
        return self.tower(input_word)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "vocabulary": self.vocabulary,
            "input_keys": self.input_keys,
            "embedding_dim": self.embedding_dim,
            "output_embedding_dim": self.output_embedding_dim,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a model from a dictionary produced by ``get_config``.

        The caller's ``config`` is not modified, so the same dictionary can
        be reused for several ``from_config`` calls.

        Raises:
            KeyError: If one of the four constructor keys is missing.
        """
        # Pop from a shallow copy: popping from ``config`` itself would strip
        # the caller's dictionary and make a second call raise KeyError.
        remaining = dict(config)
        vocabulary = remaining.pop("vocabulary")
        input_keys = remaining.pop("input_keys")
        embedding_dim = remaining.pop("embedding_dim")
        output_embedding_dim = remaining.pop("output_embedding_dim")
        return cls(
            vocabulary=vocabulary,
            input_keys=input_keys,
            embedding_dim=embedding_dim,
            output_embedding_dim=output_embedding_dim,
            **remaining,
        )