| | import streamlit as st |
| | import pandas as pd |
| | import numpy as np |
| | from unidecode import unidecode |
| | import tensorflow as tf |
| | import cloudpickle |
| | from transformers import DistilBertTokenizerFast |
| | import os |
| | from matplotlib import pyplot as plt |
| | from PIL import Image |
| |
|
| |
|
# Model artifacts live in a "models" directory, resolved relative to the
# current working directory, and are loaded once each time this module runs.
# NOTE: cloudpickle.load can execute arbitrary code stored in the file -- only
# point this at trusted, locally produced artifacts.
_MODEL_DIR = "models"
_preprocessor_path = os.path.join(_MODEL_DIR, "toxic_comment_preprocessor_classnames.bin")
with open(_preprocessor_path, "rb") as _pickle_fh:
    # The pickle stores a 2-tuple: (text preprocessor object, class names).
    text_preprocessor, class_names = cloudpickle.load(_pickle_fh)

# TFLite model used by inference() (a DistilBERT classifier, per the filename).
interpreter = tf.lite.Interpreter(
    model_path=os.path.join(_MODEL_DIR, "toxic_comment_classifier_hf_distilbert.tflite")
)
| |
|
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, element-wise.

    Accepts Python/NumPy scalars and NumPy arrays.

    For very negative inputs ``exp(-x)`` overflows to ``inf`` and the result
    is exactly 0.0, which is the correct limit.  The overflow RuntimeWarning
    NumPy would otherwise emit in that case is therefore suppressed.
    """
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
| |
|
# Hugging Face checkpoint whose tokenizer is paired with the exported model.
_MODEL_CHECKPOINT = "distilbert-base-uncased"
# Tokenizer built on first use and reused by later inference() calls, so the
# vocabulary is not reloaded for every prediction.
_tokenizer = None


def _get_tokenizer():
    """Return the DistilBERT tokenizer, loading it only on the first call."""
    global _tokenizer
    if _tokenizer is None:
        _tokenizer = DistilBertTokenizerFast.from_pretrained(_MODEL_CHECKPOINT)
    return _tokenizer


def inference(text):
    """Classify *text* and return per-class probabilities.

    The text is cleaned with the pickled ``text_preprocessor``, tokenized to
    exactly 512 tokens (padded/truncated), and run through the TFLite
    ``interpreter``; the model's output logits are mapped to probabilities
    with an element-wise sigmoid.

    Args:
        text: The raw comment string.

    Returns:
        A DataFrame with columns ``class`` (from ``class_names``) and ``prob``,
        sorted by ``prob`` in ascending order.
    """
    # The preprocessor operates on a Series; take back the single cleaned text.
    cleaned = text_preprocessor.preprocess(pd.Series(text))[0]
    tokens = _get_tokenizer()(
        cleaned,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]
    # NOTE(review): inputs are bound by position (0 -> attention_mask,
    # 1 -> input_ids); confirm this matches the exported model's input order.
    interpreter.set_tensor(input_details[0]["index"], tokens["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], tokens["input_ids"])
    interpreter.invoke()
    # First (and only) item in the batch.
    logits = interpreter.get_tensor(output_details["index"])[0]
    probs = sigmoid(logits)

    result_df = pd.DataFrame({"class": class_names, "prob": probs})
    # Ascending order: a pandas barh plot draws the first row at the bottom,
    # so the most likely class ends up at the top of the chart.
    result_df.sort_values(by="prob", ascending=True, inplace=True)
    return result_df
| |
|
| |
|
def display_image(df):
    """Render *df* as a horizontal probability bar chart in the Streamlit page.

    Args:
        df: DataFrame with ``class`` and ``prob`` columns; rows are drawn as
            bars bottom-to-top in their existing order, each annotated with
            its probability rounded to 3 decimals.
    """
    import io  # only this function needs an in-memory image buffer

    fig, ax = plt.subplots(figsize=(2, 1.8))
    try:
        df.plot(x='class', y='prob', kind='barh', ax=ax, color='black', ylabel='')
        ax.tick_params(axis='both', which='major', labelsize=8.5)
        ax.get_legend().remove()
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.get_xaxis().set_ticks([])
        ax.set_xlim(0, 1)
        # Label each bar just past its end; bar n sits at y == n.
        for row, prob in enumerate(df['prob']):
            ax.text(prob + 0.015, row - 0.15, f'{str(np.round(prob, 3))} ', fontsize=7.5)
        # Tight layout for this figure only, rather than changing the global
        # rcParams that every other figure in the process would inherit.
        fig.tight_layout()

        # Render into memory instead of a fixed file on disk, so concurrent
        # sessions cannot overwrite each other's chart and no file is left behind.
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    st.write('')
    st.image(image, output_format="PNG", caption="Prediction")
| | |
| | |
def main():
    """Draw the page: title, comment box, and a Submit button that classifies the comment."""
    st.title("Toxic Comment Classifier")
    user_comment = st.text_area("Enter a comment:", "", height=100)
    if not st.button("Submit"):
        # Nothing submitted on this run; leave the page without a chart.
        return
    predictions = inference(user_comment)
    display_image(predictions)
| |
|
| | |
| |
|
if __name__ == "__main__":
    # Build the UI only when this file is executed as the main script.
    main()