File size: 2,788 Bytes
8ae18de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55dbd1
 
8ae18de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a55dbd1
 
 
8ae18de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import io
import os

import cloudpickle
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from transformers import DistilBertTokenizerFast
from unidecode import unidecode


with open(os.path.join("models", "toxic_comment_preprocessor_classnames.bin"), "rb") as model_file_obj:
        text_preprocessor, class_names = cloudpickle.load(model_file_obj)    
interpreter = tf.lite.Interpreter(model_path=os.path.join("models", "toxic_comment_classifier_hf_distilbert.tflite"))

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def inference(text):
    text = text_preprocessor.preprocess(pd.Series(text))[0]
    model_checkpoint = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
    tokens = tokenizer(text, max_length=512, padding="max_length", truncation=True, return_tensors="tf")
    
    # tflite model inference  
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()[0]
    attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
    interpreter.set_tensor(input_details[0]["index"], attention_mask)
    interpreter.set_tensor(input_details[1]["index"], input_ids)
    interpreter.invoke()
    tflite_logits = interpreter.get_tensor(output_details["index"])[0]
    tflite_pred = sigmoid(tflite_logits)
    
    result_df = pd.DataFrame({'class': class_names, 'prob': tflite_pred})
    result_df.sort_values(by='prob', ascending=True, inplace=True)
    return result_df


def display_image(df):
    fig, ax = plt.subplots(figsize=(2, 1.8))
    df.plot(x='class', y='prob', kind='barh', ax=ax, color='black', ylabel='')
    ax.tick_params(axis='both', which='major', labelsize=8.5)
    ax.get_legend().remove()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.get_xaxis().set_ticks([])
    plt.rcParams["figure.autolayout"] = True
    plt.xlim(0, 1)
    for n, i in enumerate([*df['prob']]):
        plt.text(i+0.015, n-0.15, f'{str(np.round(i, 3))}   ', fontsize=7.5)

    fig.savefig("prediction.png", bbox_inches='tight', dpi=100)
    image = Image.open('prediction.png')
    st.write('')
    st.image(image, output_format="PNG", caption="Prediction")
    
############## ENTRY POINT START #######################
def main():
    st.title("Toxic Comment Classifier")
    comment_txt = st.text_area("Enter a comment:", "", height=100)
    if st.button("Submit"):
        df = inference(comment_txt)
        display_image(df)

############## ENTRY POINT END #######################

if __name__ == "__main__":
    main()