import tensorflow as tf
from tensorflow import keras
from transformers import TFAutoModel


class BERTForClassification(keras.Model):
    """Classification head on top of a pre-trained BERT encoder.

    The encoder's pooled output is passed through dropout and then a dense
    softmax layer that yields one probability per class.
    """

    def __init__(self, bert_model, num_classes):
        """Create the classifier.

        Args:
            bert_model: Encoder whose call result exposes the pooled output
                at index 1 (e.g. a model from ``TFAutoModel.from_pretrained``).
            num_classes: Number of units in the softmax output layer.
        """
        super().__init__()
        self.bert = bert_model
        self.dropout = keras.layers.Dropout(0.3)
        self.classifier = keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=False):
        """Run a forward pass.

        Args:
            inputs: Whatever the wrapped encoder accepts as its first
                argument; forwarded to it unchanged.
            training: Whether the dropout layer should be active.

        Returns:
            The softmax output of the classification layer.
        """
        pooled = self.bert(inputs)[1]  # index 1 is the pooler output
        pooled = self.dropout(pooled, training=training)
        return self.classifier(pooled)


def load_model(
    weights_path="my_model_weights.h5",
    model_name="bert-base-uncased",
    num_classes=2,
):
    """Build the classifier and restore its weights from an HDF5 file.

    Args:
        weights_path: Path to the weights file passed to ``load_weights``.
        model_name: Identifier passed to ``TFAutoModel.from_pretrained``.
        num_classes: Number of output classes of the classification head.

    Returns:
        A ``BERTForClassification`` instance with the weights loaded.
    """
    bert_model = TFAutoModel.from_pretrained(model_name)
    model = BERTForClassification(bert_model, num_classes=num_classes)

    # A subclassed Keras model only creates its variables on the first call,
    # and HDF5 weights cannot be loaded into a model that has no variables
    # yet. Run one tiny forward pass so the dropout/classifier layers are
    # built before load_weights() is invoked.
    dummy_inputs = {
        "input_ids": tf.zeros((1, 8), dtype=tf.int32),
        "attention_mask": tf.ones((1, 8), dtype=tf.int32),
    }
    model(dummy_inputs, training=False)

    model.load_weights(weights_path)
    return model