SaiGaneshanM's picture
commit files to HF hub
91bde62
raw
history blame contribute delete
798 Bytes
from transformers import TFAutoModel
from tensorflow import keras
import tensorflow as tf
class BERTForClassification(keras.Model):
    """Dropout + dense softmax classification head on a TF transformer backbone.

    The backbone's output at index 1 is passed through dropout and a dense
    layer with softmax activation, producing per-class probabilities.

    Args:
        bert_model: A Hugging Face TF model (e.g. from ``TFAutoModel``). Index 1
            of its output is assumed to be the pooled [CLS] representation;
            this holds for BERT -- TODO confirm before using other backbones
            (some architectures, e.g. DistilBERT, have no pooler).
        num_classes: Number of output classes (units of the final Dense layer).
        dropout_rate: Dropout rate applied to the pooled representation.
            Defaults to 0.3, the value previously hard-coded here.
        **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, bert_model, num_classes, dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        self.bert = bert_model
        self.dropout = keras.layers.Dropout(dropout_rate)
        self.classifier = keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=False):
        """Run the backbone and head; return softmax probabilities.

        Args:
            inputs: Whatever the backbone accepts (token-id tensor or a dict
                such as ``{"input_ids": ..., "attention_mask": ...}``).
            training: Whether dropout (head and backbone) is active.

        Returns:
            A tensor whose last dimension is ``num_classes``.
        """
        # Forward ``training`` explicitly so the backbone's own dropout runs in
        # the same mode as this head, rather than depending on Keras' implicit
        # call-context propagation.
        pooled = self.bert(inputs, training=training)[1]  # pooler output
        pooled = self.dropout(pooled, training=training)
        return self.classifier(pooled)
# Load the model
def load_model(
    weights_path="my_model_weights.h5",
    num_classes=2,
    pretrained_name="bert-base-uncased",
):
    """Create a BERTForClassification and restore its trained weights.

    Args:
        weights_path: Path to the saved weights file. Defaults to the
            previously hard-coded ``"my_model_weights.h5"``.
        num_classes: Number of output classes; must match the saved weights.
        pretrained_name: Hugging Face checkpoint used to instantiate the
            backbone before the saved weights are loaded over it.

    Returns:
        The built ``BERTForClassification`` with weights loaded.
    """
    bert_model = TFAutoModel.from_pretrained(pretrained_name)
    model = BERTForClassification(bert_model, num_classes=num_classes)
    # A subclassed Keras model creates its variables lazily on first call, and
    # loading HDF5 (.h5) weights into an unbuilt subclassed model raises
    # ValueError. One forward pass on a tiny dummy batch builds every layer
    # (backbone, dropout, classifier) so load_weights can match variables.
    dummy_inputs = {"input_ids": tf.zeros((1, 1), dtype=tf.int32)}
    model(dummy_inputs, training=False)
    model.load_weights(weights_path)
    return model