# ngt-ai-platform/modules/multilabel_classification.py
# Gaetano Parente
# fix
# db9c437
from keras_preprocessing.sequence import pad_sequences
import modules.utilities.utils as utils
import keras.models as models
import numpy as np
# Root folder of the model/tokenizer artifacts (relative path, so it is
# resolved against the current working directory).
BASE_PATH = './data/'
# Folder holding the saved Keras model files.
MODEL = BASE_PATH + 'model/'
# Folder holding the serialized tokenizers.
TOKEN = BASE_PATH + 'tokenizer/'
# Class labels; position i names the i-th score of the model output.
class_names = np.array((
    'Economia',
    'Politica',
    'Scienza_e_tecnica',
    'Sport',
    'Storia',
))
def predict(model_path, tokenizer_path, sentence, maxlen=200):
    """Run the saved Keras classifier on a single sentence.

    The sentence is tokenized with the tokenizer stored at ``tokenizer_path``,
    padded/truncated to ``maxlen`` token ids with ``pad_sequences`` and fed to
    the Keras model stored at ``model_path``.

    Args:
        model_path: Path of the saved Keras model file.
        tokenizer_path: Path of the serialized tokenizer, loaded through
            ``utils.load_tokenizer``.
        sentence: Raw text to classify.
        maxlen: Sequence length passed to ``pad_sequences``. Defaults to 200,
            the value previously hard-coded here. It presumably has to match
            the input length the model was trained with (TODO confirm).

    Returns:
        The raw output of ``model.predict`` for a batch containing only this
        sentence. Row 0 holds the per-class scores.
    """
    tokenizer = utils.load_tokenizer(tokenizer_path)
    # texts_to_sequences works on a batch of texts: wrap the single sentence
    # and take the one resulting sequence back out.
    token_ids = tokenizer.texts_to_sequences([sentence])[0]
    # Defensive guard kept from the original implementation: any None id is
    # mapped to 1. NOTE(review): presumably 1 is meant as the OOV index; verify.
    token_ids = [1 if token_id is None else token_id for token_id in token_ids]
    # pad_sequences returns one row per input sequence, so this is already a
    # batch of one; no need to unwrap and re-wrap it.
    x_batch = pad_sequences([token_ids], maxlen=maxlen)
    # NOTE(review): the model is reloaded from disk on every call, which is
    # slow. Caching is intentionally not done here because sharing a single
    # model instance across request threads is not safe with every Keras
    # backend; revisit if latency matters.
    model = models.load_model(model_path, compile=False)
    return model.predict(x_batch)
def multi_classification(text):
    """Classify ``text`` with the bundled multi-classification model.

    The model and tokenizer files are read from the ``MODEL`` and ``TOKEN``
    folders respectively.

    Returns:
        dict mapping each label in ``class_names`` (in model-output order) to
        its score, formatted as a string with four decimals (e.g. ``"0.1234"``).
    """
    model_file = MODEL + 'multi-classification.h5'
    tokenizer_file = TOKEN + 'multi-classification-tokenizer.json'
    # predict() feeds a one-row batch, so row 0 holds this text's scores.
    scores = predict(model_file, tokenizer_file, text)[0]
    # Index class_names by position: a model emitting more scores than there
    # are labels fails loudly instead of being silently truncated.
    return {
        class_names[position]: "%.4f" % float(score)
        for position, score in enumerate(scores)
    }