import gradio as gr
import tensorflow as tf
import pickle
import json
import re
from tensorflow.keras.preprocessing.sequence import pad_sequences
from underthesea import word_tokenize
import numpy as np
# Trained classifier, loaded once at import time.
model = tf.keras.models.load_model('saved_models/bidirectional-GRU.h5')

# The pickle file holds the tokenizer's JSON description, which is turned
# back into a Keras tokenizer object.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only ever point this at a tokenizer file from a trusted source.
with open('tokenizers/tokenizer.pkl', 'rb') as tokenizer_file:
    tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(
        pickle.load(tokenizer_file)
    )

# Label dictionary: looked up with the class index converted to a string.
with open("saved_models/label_dict.json", encoding='utf-8') as file:
    label_dict = json.loads(file.read())

# Length every encoded input is padded to before prediction
# (presumably the sequence length used at training time -- TODO confirm).
MAX_LEN = 8054

# Characters stripped from the input text before tokenization.
# This is a raw string: sequences such as \t and \n below are a backslash
# followed by a letter, not control characters.
characters_to_replace = r'!"#$%&()*+,-./:;=?@\[\\\]^`{|}~\t\n'
def predict_news_type(content):
    """Predict the category of a news article and list every category's score.

    Args:
        content: Raw article text. ``None`` is treated as an empty string.

    Returns:
        A ``(category, category_probabilities_str)`` tuple: the label of the
        highest-scoring class, and a newline-separated string with one
        ``"<label>: <percent>%"`` line per class, ordered from the highest
        score to the lowest.
    """
    # Defensive: a missing value would otherwise fail on ``.replace`` below.
    text = "" if content is None else content

    # Flatten newlines to spaces, then delete every character in the filter set.
    # NOTE(review): characters_to_replace is a raw string, so its trailing
    # \t\n are backslash+letter pairs; after re.escape() the character class
    # matches a literal backslash and the lowercase letters "t" and "n"
    # rather than tab/newline. Kept as-is because the saved model may have
    # been trained on text cleaned the same way -- confirm against the
    # training pipeline before changing it.
    strip_pattern = f"[{re.escape(characters_to_replace)}]"
    text = re.sub(strip_pattern, "", text.replace("\n", " "))

    # Word-segment the cleaned text and wrap it in the sentence markers.
    # No newline can remain at this point, so no further stripping is needed.
    segmented = "<sos> " + word_tokenize(text, format="text") + " <eos>"

    # Encode to token ids and pad at the end up to MAX_LEN.
    sequences = tokenizer.texts_to_sequences([segmented])
    model_input = pad_sequences(sequences, maxlen=MAX_LEN, padding='post')

    # Single-sample batch: row 0 holds the score of every class.
    predictions = model.predict(model_input, verbose=0)
    best_index = np.argmax(predictions, axis=1)[0]
    category = label_dict[str(best_index)]

    # Map each label to its score; label_dict is keyed by the class index
    # written as a string.
    scores_by_label = {
        label_dict[str(index)]: score
        for index, score in enumerate(predictions[0].tolist())
    }
    ranked = sorted(scores_by_label.items(), key=lambda item: item[1], reverse=True)

    # One "<label>: <percent>%" line per class, highest score first.
    category_probabilities_str = "\n".join(
        f"{label}: {score * 100:.2f}%" for label, score in ranked
    )
    return category, category_probabilities_str
# Gradio UI: one text input, two text outputs.
demo = gr.Interface(
    title="Vietnamese News Category Classification",
    description="Enter the news content to predict its category and see the probabilities for all categories.",
    fn=predict_news_type,
    inputs=gr.Textbox(label="Enter the news content"),
    # Two boxes, one per element of the tuple returned by predict_news_type.
    outputs=[
        gr.Textbox(label="Predicted News Category"),
        gr.Textbox(label="Category Probabilities"),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()