# Spaces: Running
# Export script: loads a Hugging Face sequence-classification checkpoint,
# dumps its vocabulary to a text file, and converts the model to TFLite.
import os

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

model_name = "mrm8488/bert-tiny-finetuned-sms-spam-detection"

# --- Load -------------------------------------------------------------------
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# from_pt=True: build the TensorFlow model from the PyTorch checkpoint weights.
model = TFAutoModelForSequenceClassification.from_pretrained(
    model_name,
    from_pt=True,
)

# --- Vocabulary -------------------------------------------------------------
print("Exporting vocab...")
os.makedirs("models", exist_ok=True)
vocab = tokenizer.get_vocab()
# One token per line, ordered by ascending token id.
tokens_by_id = sorted(vocab, key=vocab.get)
with open("models/vocab.txt", "w", encoding="utf-8") as f:
    f.writelines(f"{token}\n" for token in tokens_by_id)

# --- Trace ------------------------------------------------------------------
print("Tracing model for TFLite conversion...")
MAX_SEQ_LEN = 128

callable_model = tf.function(
    lambda input_ids, attention_mask: model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        training=False,
    )
)

# Both inputs are traced with a fixed [1, MAX_SEQ_LEN] int32 signature.
input_specs = [
    tf.TensorSpec([1, MAX_SEQ_LEN], tf.int32, name=tensor_name)
    for tensor_name in ("input_ids", "attention_mask")
]
concrete_func = callable_model.get_concrete_function(*input_specs)

# --- Convert ----------------------------------------------------------------
print("Converting to TFLite...")
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Allow TFLite builtins plus selected TensorFlow ops in the converted model.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

with open("models/scam_classifier.tflite", "wb") as f:
    f.write(tflite_model)

print("Done! Model saved to models/scam_classifier.tflite")