File size: 1,475 Bytes
4f39bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# src/model.py

import tensorflow as tf
from tensorflow.keras import layers, Model
from transformers import TFAutoModel

class CrossEncoderTF(Model):
    """Transformer encoder followed by a dense head with a sigmoid output.

    The encoder is loaded with ``TFAutoModel.from_pretrained(model_name)``.
    Its sentence-level representation is fed through a stack of
    Dense / BatchNormalization / Dropout layers that ends in a single
    sigmoid unit, so the output's last dimension is 1.

    Args:
        model_name: Identifier passed to ``TFAutoModel.from_pretrained``.
        max_token_len: Stored on the instance and returned by
            ``get_config``; it is not read anywhere else in this class.
            NOTE(review): presumably consumed by the tokenization code --
            confirm against the caller.
        **kwargs: Forwarded unchanged to ``Model.__init__``.
    """

    def __init__(self, model_name="dbmdz/bert-base-turkish-cased", max_token_len=32, **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.max_token_len = max_token_len

        # 1. Text pipeline (Transformer encoder).
        self.bert = TFAutoModel.from_pretrained(model_name)

        # 2. Output head only: maps the encoder features to one sigmoid score.
        self.classifier = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.BatchNormalization(),
            layers.Dropout(0.3),
            layers.Dense(128, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(64, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(1, activation='sigmoid'),
        ], name="classifier")

    def call(self, inputs, training=None):
        """Run the encoder and the classifier head on a batch.

        Args:
            inputs: Mapping that provides ``'input_ids'`` and
                ``'attention_mask'``; both are passed to the encoder.
                NOTE(review): ``token_type_ids`` are not forwarded, so
                sentence pairs are not segment-marked for the encoder --
                confirm this is intended.
            training: Optional training-mode flag forwarded to both the
                encoder and the head (which contains Dropout and
                BatchNormalization). ``None`` (default) lets Keras resolve
                the mode from the calling context.

        Returns:
            The output of the sigmoid classifier head.
        """
        bert_output = self.bert(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            training=training,
        )

        # `model_name` is configurable, and not every architecture that
        # TFAutoModel can load returns a pooled output. Fall back to the
        # first-token hidden state so such encoders work instead of failing
        # on a missing/None `pooler_output`.
        text_features = getattr(bert_output, 'pooler_output', None)
        if text_features is None:
            text_features = bert_output.last_hidden_state[:, 0]

        return self.classifier(text_features, training=training)

    def get_config(self):
        """Return the parent config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "model_name": self.model_name,
            "max_token_len": self.max_token_len,
        })
        return config