File size: 798 Bytes
91bde62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from transformers import TFAutoModel
from tensorflow import keras
import tensorflow as tf

class BERTForClassification(keras.Model):
    """Classifier built from a BERT-style encoder, dropout and a softmax layer.

    The encoder's output at index 1 (its pooler output) is passed through
    dropout with rate 0.3 and then through a ``Dense`` layer with
    ``num_classes`` units and softmax activation.
    """

    def __init__(self, bert_model, num_classes):
        """Wrap an encoder with a classification head.

        Args:
            bert_model: Callable encoder whose output exposes the pooler
                output at index 1 (e.g. a model returned by
                ``TFAutoModel.from_pretrained``).
            num_classes: Number of units in the softmax output layer.
        """
        super().__init__()
        # NOTE: attribute names and creation order are kept exactly as before
        # (bert -> dropout -> classifier) so previously saved weight files
        # keep mapping onto the same layers.
        self.bert = bert_model
        self.dropout = keras.layers.Dropout(0.3)
        self.classifier = keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=False):
        """Run the encoder and the classification head.

        Args:
            inputs: Whatever the wrapped encoder accepts; forwarded unchanged.
            training: Whether dropout should be active. The flag is forwarded
                to both the encoder and the head's dropout layer.

        Returns:
            The softmax output of the classifier layer.
        """
        # Forward the flag explicitly so the encoder's internal dropout follows
        # the same mode as the head instead of relying on implicit propagation.
        encoder_outputs = self.bert(inputs, training=training)
        pooled = encoder_outputs[1]  # Pooler output
        pooled = self.dropout(pooled, training=training)
        return self.classifier(pooled)

# Load the model
def load_model(
    weights_path="my_model_weights.h5",
    model_name="bert-base-uncased",
    num_classes=2,
):
    """Build the classifier and restore its trained weights.

    Args:
        weights_path: Path of the weights file passed to ``load_weights``.
        model_name: Checkpoint identifier passed to
            ``TFAutoModel.from_pretrained``.
        num_classes: Number of output classes of the classification head.

    Returns:
        A ``BERTForClassification`` instance with the weights loaded.
    """
    bert_model = TFAutoModel.from_pretrained(model_name)
    model = BERTForClassification(bert_model, num_classes=num_classes)

    # A subclassed Keras model only creates its variables on the first call,
    # and loading HDF5 weights into a model that has not been called yet
    # raises ValueError. Run one tiny forward pass first so the dropout and
    # classifier layers exist before the weights are restored.
    dummy_inputs = {"input_ids": tf.zeros((1, 8), dtype=tf.int32)}
    model(dummy_inputs, training=False)

    model.load_weights(weights_path)
    return model