import os

import cloudpickle
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from transformers import DistilBertTokenizerFast
from unidecode import unidecode  # NOTE(review): not used directly here; presumably
# required by the unpickled text_preprocessor -- confirm before removing.

# Sequence length the model was exported with; tokenizer pads/truncates to this.
MAX_SEQ_LENGTH = 200
MODEL_DIR = "models"


@st.cache_resource
def load_model():
    """Load all inference artifacts once per Streamlit session.

    Returns:
        tuple: (interpreter, text_preprocessor, label_encoder, tokenizer, label_map)
            - interpreter: tf.lite.Interpreter for the exported DistilBERT classifier.
            - text_preprocessor / label_encoder: objects restored via cloudpickle.
            - tokenizer: DistilBertTokenizerFast matching the fine-tuned checkpoint.
            - label_map: mapping from encoded label to human-readable category name.
    """
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(MODEL_DIR, "dbpedia_classifier_hf_distilbert_l3.tflite")
    )
    # SECURITY: cloudpickle.load executes arbitrary code on load. These files must
    # be trusted project artifacts -- never point this at untrusted input.
    with open("models/preprocessor_labelencoder_l3.bin", "rb") as model_file_obj:
        text_preprocessor, label_encoder = cloudpickle.load(model_file_obj)
    with open("models/label_map_l3.bin", "rb") as model_file_obj:
        label_map = cloudpickle.load(model_file_obj)
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    return interpreter, text_preprocessor, label_encoder, tokenizer, label_map


interpreter, text_preprocessor, label_encoder, tokenizer, label_map = load_model()


def inference(text):
    """Classify a raw article string.

    Args:
        text: article text pasted by the user; may be empty.

    Returns:
        str: "<category> (<probability>)" for the top predicted class, or
        "Can't Predict" when no usable text was supplied.
    """
    # Truthiness guard: also short-circuits None, not just the empty string.
    if not text:
        return "Can't Predict"

    text = text_preprocessor.preprocess(pd.Series(text))[0]
    tokens = tokenizer(
        text,
        max_length=MAX_SEQ_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    # TFLite inference. Input shape is fixed (1, MAX_SEQ_LENGTH), so allocation
    # is cheap and idempotent per call.
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]
    # NOTE(review): assumes converted-model input 0 is attention_mask and
    # input 1 is input_ids -- confirm against the model's signature; matching
    # by input name would be more robust if the exported names are stable.
    interpreter.set_tensor(input_details[0]["index"], tokens["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], tokens["input_ids"])
    interpreter.invoke()

    probs = interpreter.get_tensor(output_details["index"])[0]
    top_idx = np.argmax(probs)
    top_label = label_map[label_encoder.inverse_transform([top_idx])[0]]
    return f"{top_label} ({str(np.round(probs[top_idx], 5))})"


def main():
    """Render the Streamlit UI: title, description, input area, and result."""
    st.title("Wikipedia Article Classification")
    st.markdown(
        """

The model is fine-tuned to classify an article into 219 categories

""",
        unsafe_allow_html=True,
    )
    review = st.text_area("Paste an article:", "", height=200)
    if st.button("Submit"):
        result = inference(review)
        st.write(result)


if __name__ == "__main__":
    main()