File size: 1,475 Bytes
4f39bc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# src/model.py
import tensorflow as tf
from tensorflow.keras import layers, Model
from transformers import TFAutoModel
class CrossEncoderTF(Model):
    """Cross-encoder: a pretrained transformer followed by an MLP scoring head.

    The transformer encodes ``input_ids`` / ``attention_mask`` (plus
    ``token_type_ids`` when they are supplied). A pooled vector is taken from
    its output, and a small dense head maps that vector to a single sigmoid
    score in (0, 1).

    Args:
        model_name: Checkpoint name or path passed to
            ``TFAutoModel.from_pretrained``.
        max_token_len: Stored and serialized via ``get_config`` only; the
            model itself never truncates or pads. Presumably consumed by the
            tokenization code elsewhere — TODO confirm.
        **kwargs: Forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, model_name="dbmdz/bert-base-turkish-cased", max_token_len=32, **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_token_len = max_token_len

        # 1. Text encoder (pretrained transformer).
        self.bert = TFAutoModel.from_pretrained(model_name)

        # 2. Output head only: dense MLP ending in one sigmoid unit.
        self.classifier = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(1, activation='sigmoid')
        ], name="classifier")

    def call(self, inputs, training=None):
        """Return one sigmoid score per example.

        Args:
            inputs: Mapping with ``"input_ids"`` and ``"attention_mask"``,
                and optionally ``"token_type_ids"``. Token type ids are the
                segment ids that let BERT-style encoders tell the first text
                of a pair from the second.
            training: Keras training flag. It is forwarded explicitly so the
                encoder's dropout and the head's BatchNormalization/Dropout
                switch between training and inference behaviour. ``None``
                defers to the mode Keras infers from the enclosing call.

        Returns:
            Tensor whose last dimension is 1, holding sigmoid scores in (0, 1).
        """
        encoder_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        # Without segment ids every token is encoded as segment 0, which
        # erases the pair boundary a cross-encoder relies on. Forward them
        # only when present, so encoders/pipelines without them keep working.
        token_type_ids = inputs.get("token_type_ids")
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids

        bert_output = self.bert(**encoder_kwargs, training=training)

        # BERT-style models expose a pooled first-token vector. Encoders
        # without a pooling layer (pooler_output missing or None) fall back to
        # the first token of the last hidden state.
        text_features = getattr(bert_output, "pooler_output", None)
        if text_features is None:
            text_features = bert_output.last_hidden_state[:, 0, :]

        prediction_score = self.classifier(text_features, training=training)
        return prediction_score

    def get_config(self):
        """Return a serializable config that includes the constructor arguments."""
        try:
            config = super().get_config()
        except NotImplementedError:
            # Some tf.keras versions raise here for subclassed models; start
            # from an empty base so the custom arguments are still recorded.
            config = {}
        config.update({
            "model_name": self.model_name,
            "max_token_len": self.max_token_len,
        })
        return config
|