# Author: AnushkaTonk
# Last change: updated params in predict() function (commit a5d5bf8)
import tensorflow as tf
from transformers import (
BertTokenizerFast, TFBertForSequenceClassification,
RobertaTokenizer, RobertaTokenizerFast, TFRobertaForSequenceClassification,
DebertaTokenizer, DebertaTokenizerFast, TFDebertaForSequenceClassification
)
import gradio as gr
from lime.lime_text import LimeTextExplainer
import numpy as np
# Ground-truth mapping from class index to intent name. This is the source of
# truth used whenever a checkpoint's own config only carries generic
# "LABEL_n" placeholder names.
true_id2label = dict(enumerate([
    'check_payment_methods', 'check_refund_policy', 'check_cancellation_fee',
    'check_invoices', 'place_order', 'set_up_shipping_address',
    'recover_password', 'delivery_period', 'track_refund',
    'payment_issue', 'contact_customer_service', 'newsletter_subscription',
    'registration_problems', 'cancel_order', 'review',
    'contact_human_agent', 'track_order', 'get_invoice',
    'get_refund', 'change_shipping_address', 'delete_account',
    'delivery_options', 'create_account', 'change_order',
    'switch_account', 'complaint', 'edit_account',
]))
# Ordered label list (index order) handed to the LIME explainer for display.
class_names = list(true_id2label.values())
# Registry of the three fine-tuned checkpoints on the Hugging Face Hub.
# NOTE(review): the fast tokenizer classes are now used for every entry —
# previously RoBERTa/DeBERTa used the slow RobertaTokenizer/DebertaTokenizer
# even though the fast variants were imported and unused. Fast tokenizers
# share the same vocabulary/ids and batch-encode far quicker, which matters
# for LIME's 1000-sample perturbation batches.
models_info = {
    "BERT": {
        "path": "vrindaTonk/bert_intent_class",
        "tokenizer": BertTokenizerFast,
        "model_class": TFBertForSequenceClassification,
    },
    "RoBERTa": {
        "path": "vrindaTonk/roberta_intent_class",
        "tokenizer": RobertaTokenizerFast,
        "model_class": TFRobertaForSequenceClassification,
    },
    "DeBERTa": {
        "path": "vrindaTonk/deberta_intent_class",
        "tokenizer": DebertaTokenizerFast,
        "model_class": TFDebertaForSequenceClassification,
    },
}
# Cache of (tokenizer, model, id2label) tuples keyed by model name so each
# checkpoint is downloaded and initialised at most once per process.
cached_models = {}
# loading models and tokenizers:
def load_model_and_tokenizer(model_choice):
    """Load (and memoise) the tokenizer, model and id->label map for a model.

    Args:
        model_choice: key into ``models_info`` ("BERT", "RoBERTa" or "DeBERTa").

    Returns:
        Tuple ``(tokenizer, model, id2label)``; repeated calls with the same
        choice return the cached tuple without re-downloading anything.
    """
    if model_choice in cached_models:
        return cached_models[model_choice]
    config = models_info[model_choice]
    tokenizer = config["tokenizer"].from_pretrained(config["path"])
    model = config["model_class"].from_pretrained(config["path"])
    try:
        id2label = model.config.id2label
        # Checkpoints saved without explicit labels expose generic "LABEL_n"
        # names; substitute the real intent names in that case.
        if "LABEL_0" in id2label.values():
            id2label = true_id2label
    except AttributeError:
        # Config has no id2label mapping at all -> fall back to the known one.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt etc.)
        id2label = true_id2label
    cached_models[model_choice] = (tokenizer, model, id2label)
    return tokenizer, model, id2label
# Shared LIME text explainer; class_names supplies the human-readable intent
# labels shown in the generated explanations.
explainer = LimeTextExplainer(class_names = class_names)
def predict_and_explainations(model_choice, user_text):
    """Classify *user_text* with the chosen model and explain the prediction.

    Args:
        model_choice: key into ``models_info`` ("BERT", "RoBERTa" or "DeBERTa").
        user_text: raw customer-support query string.

    Returns:
        Tuple ``(predictions, exp_html, fig)``: a dict mapping each intent
        label to its softmax probability, the LIME explanation rendered as
        HTML, and the LIME matplotlib figure.
    """
    tokenizer, model, id2label = load_model_and_tokenizer(model_choice)
    # Renamed from `input`, which shadowed the builtin.
    encoded = tokenizer(user_text, return_tensors="tf", truncation=True,
                        padding="max_length", max_length=128)
    logits = model(**encoded).logits
    probs = tf.nn.softmax(logits, axis=-1).numpy()[0]
    # (The previously computed `predicted_index` was never used — removed.)
    predictions = {id2label[i]: float(probs[i]) for i in range(len(probs))}

    def lime_predict(texts):
        # LIME calls this with a batch of perturbed strings. truncation=True
        # is required here as well: without it a query longer than max_length
        # would produce over-long sequences and crash the model mid-explain.
        batch = tokenizer(texts, return_tensors="tf", truncation=True,
                          padding="max_length", max_length=128)
        batch_logits = model(**batch).logits
        return tf.nn.softmax(batch_logits, axis=-1).numpy()

    explaination = explainer.explain_instance(
        user_text, lime_predict, num_features=10, num_samples=1000)
    exp_html = explaination.as_html()
    fig = explaination.as_pyplot_figure()
    return predictions, exp_html, fig
# Gradio UI definition:
# Build the Gradio UI: a model selector plus a free-text query box in,
# and the probability label, LIME HTML and LIME figure out.
model_selector = gr.Dropdown(choices=["BERT", "RoBERTa", "DeBERTa"],
                             label="Choose Model")
query_input = gr.Textbox(lines=2, label="Customer Query",
                         placeholder="Enter your intent query here...")
prob_output = gr.Label(num_top_classes=3,
                       label="Predicted Intent Probabilities")
html_output = gr.HTML(label="LIME Explanation")
plot_output = gr.Plot(label="Lime exp PLOTS")

interface = gr.Interface(
    fn=predict_and_explainations,
    inputs=[model_selector, query_input],
    outputs=[prob_output, html_output, plot_output],
    title="Intent Classification using Transformers Models- BERT, RoBERTa, DeBERTa",
    description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation.",
)
if __name__ == "__main__":
# to dummy test the models
for model_key in models_info.keys():
tokenizer, model, id2label = load_model_and_tokenizer(model_key)
_ = model(**tokenizer("warmup", return_tensors="tf", padding=True, truncation=True, max_length=128))
interface.launch()