# prem / convert_bert.py
# NitishStark's picture
# Deploy source code to Hugging Face without binaries
# c25dcd7
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import os
# Checkpoint identifier handed to both `from_pretrained` calls below.
model_name = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# from_pt=True asks transformers to build the TF model from PyTorch weights.
# NOTE(review): presumably the checkpoint ships no native TF weights -- confirm.
model = TFAutoModelForSequenceClassification.from_pretrained(
    model_name,
    from_pt=True,
)
print("Exporting vocab...")
os.makedirs("models", exist_ok=True)

# Write one token per line, ordered by ascending token id.
# NOTE(review): line N only equals token id N if the ids are contiguous from
# 0 -- confirm for this tokenizer before relying on line numbers downstream.
vocab = tokenizer.get_vocab()
ordered_tokens = sorted(vocab, key=vocab.get)

# newline="\n" disables the platform newline translation of text mode, so the
# file is byte-identical on every OS (no "\r\n" when the script runs on
# Windows, which would leave a stray "\r" on tokens for loaders that split
# on "\n").
with open("models/vocab.txt", "w", encoding="utf-8", newline="\n") as f:
    f.writelines(f"{token}\n" for token in ordered_tokens)
print("Tracing model for TFLite conversion...")

# Sequence length fixed into the traced input signature.
MAX_SEQ_LEN = 128

# Wrap the model call (training=False) in a tf.function so it can be traced.
callable_model = tf.function(
    lambda input_ids, attention_mask: model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        training=False,
    )
)

# Both inputs use the same static spec: int32 tensors of shape
# [1, MAX_SEQ_LEN], distinguished only by their tensor name.
input_specs = [
    tf.TensorSpec([1, MAX_SEQ_LEN], tf.int32, name=tensor_name)
    for tensor_name in ("input_ids", "attention_mask")
]
concrete_func = callable_model.get_concrete_function(*input_specs)
print("Converting to TFLite...")

# Build the converter from the single traced function.
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])

# Turn on the converter's default optimization set.
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Allow both TFLite builtin ops and select TensorFlow ops in the output.
# NOTE(review): models using SELECT_TF_OPS presumably need the TF ops (Flex)
# runtime support on the target device -- confirm against the TFLite docs.
ops_set = tf.lite.OpsSet
converter.target_spec.supported_ops = [
    ops_set.TFLITE_BUILTINS,
    ops_set.SELECT_TF_OPS,
]

tflite_model = converter.convert()

# The converted model is written as raw bytes; the "models" directory is
# expected to exist already at this point.
with open("models/scam_classifier.tflite", "wb") as out_file:
    out_file.write(tflite_model)

print("Done! Model saved to models/scam_classifier.tflite")