| | from transformers import TFAutoModel |
| | from tensorflow import keras |
| | import tensorflow as tf |
| |
|
class BERTForClassification(keras.Model):
    """Sequence classifier: pretrained BERT encoder -> dropout -> softmax head.

    Args:
        bert_model: Pretrained encoder, e.g. the ``TFAutoModel`` built in
            ``load_model``. ``call`` takes element ``[1]`` of its output.
        num_classes: Number of units in the final ``Dense`` layer.
        dropout_rate: Dropout rate applied to the pooled encoder output before
            the classifier. Defaults to 0.3, the previously hard-coded value.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, bert_model, num_classes, dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        self.bert = bert_model
        self.dropout = keras.layers.Dropout(dropout_rate)
        # The head emits softmax probabilities, not logits: any loss paired
        # with this model must use from_logits=False.
        self.classifier = keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=False):
        """Run the encoder, dropout, and classifier.

        Args:
            inputs: Whatever ``bert_model`` accepts (token-id tensor or a dict
                of model inputs).
            training: Enables dropout in both this model and the encoder.

        Returns:
            Softmax class probabilities; the last dimension is ``num_classes``.
        """
        # Forward `training` explicitly so the encoder's internal dropout runs
        # in the same mode as this model, rather than depending on implicit
        # Keras call-context propagation.
        encoder_outputs = self.bert(inputs, training=training)
        # Assumes the encoder returns (sequence_output, pooled_output, ...) as
        # TFBertModel does. NOTE(review): encoders without a pooler layer
        # return something else (or nothing) at index 1 — confirm if
        # bert_model is ever swapped for a non-BERT architecture.
        pooled = encoder_outputs[1]
        pooled = self.dropout(pooled, training=training)
        return self.classifier(pooled)
| |
|
| | |
def load_model(
    weights_path="my_model_weights.h5",
    num_classes=2,
    pretrained_name="bert-base-uncased",
):
    """Build a ``BERTForClassification`` and restore its saved weights.

    Args:
        weights_path: Path of the saved weights file. Defaults to the
            previously hard-coded ``"my_model_weights.h5"``.
        num_classes: Number of output classes; must match the head the
            weights were saved from. Defaults to 2.
        pretrained_name: Hugging Face checkpoint used for the BERT backbone.
            Defaults to ``"bert-base-uncased"``.

    Returns:
        The classifier with weights loaded from ``weights_path``.
    """
    bert_model = TFAutoModel.from_pretrained(pretrained_name)
    model = BERTForClassification(bert_model, num_classes=num_classes)
    # A subclassed Keras model creates no variables until its first call, and
    # loading HDF5 weights requires the variables to already exist. A single
    # inference-mode forward pass on the backbone's dummy inputs builds the
    # encoder, dropout and classifier layers before the weights are restored.
    model(bert_model.dummy_inputs, training=False)
    model.load_weights(weights_path)
    return model
| |
|