File size: 4,214 Bytes
36a744b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a6f2d0
36a744b
 
0a6f2d0
 
36a744b
d43bb10
 
36a744b
 
 
 
 
 
0a6f2d0
36a744b
 
 
 
 
 
 
a5d5bf8
36a744b
 
 
 
 
 
a5d5bf8
36a744b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a6f2d0
36a744b
 
 
 
0a6f2d0
 
 
 
 
36a744b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import tensorflow as tf
from transformers import (
    BertTokenizerFast, TFBertForSequenceClassification,
    RobertaTokenizer, RobertaTokenizerFast, TFRobertaForSequenceClassification,
    DebertaTokenizer, DebertaTokenizerFast, TFDebertaForSequenceClassification
)
import gradio as gr
from lime.lime_text import LimeTextExplainer
import numpy as np

# Canonical intent labels for the customer-support classifier. Tuple order
# defines the class id (0-26) each label corresponds to.
_INTENT_LABELS = (
    'check_payment_methods', 'check_refund_policy', 'check_cancellation_fee',
    'check_invoices', 'place_order', 'set_up_shipping_address',
    'recover_password', 'delivery_period', 'track_refund',
    'payment_issue', 'contact_customer_service', 'newsletter_subscription',
    'registration_problems', 'cancel_order', 'review',
    'contact_human_agent', 'track_order', 'get_invoice',
    'get_refund', 'change_shipping_address', 'delete_account',
    'delivery_options', 'create_account', 'change_order',
    'switch_account', 'complaint', 'edit_account',
)

# id -> label mapping, used as a fallback when a checkpoint's config only
# carries generic "LABEL_n" placeholders.
true_id2label = dict(enumerate(_INTENT_LABELS))

# Ordered label list handed to LIME so explanation classes line up with ids.
class_names = list(true_id2label.values())

# Registry of the three fine-tuned checkpoints the app can serve. Each entry
# carries the Hugging Face Hub repo id plus the tokenizer / model classes
# needed to load it.
models_info = {
    "BERT": {
        "path": "vrindaTonk/bert_intent_class",
        "tokenizer": BertTokenizerFast,
        "model_class": TFBertForSequenceClassification
    },
    "RoBERTa": {
        "path": "vrindaTonk/roberta_intent_class",
        "tokenizer": RobertaTokenizer,
        "model_class": TFRobertaForSequenceClassification
    },
    "DeBERTa": {
        "path": "vrindaTonk/deberta_intent_class",
        "tokenizer": DebertaTokenizer,
        "model_class": TFDebertaForSequenceClassification
    }
}
# In-process cache: model name -> (tokenizer, model, id2label), populated
# lazily by load_model_and_tokenizer so each checkpoint loads only once.
cached_models = {}
def load_model_and_tokenizer(model_choice):
    """Load (and cache) the tokenizer, model and id->label map for a model.

    Parameters
    ----------
    model_choice : str
        Key into ``models_info`` ("BERT", "RoBERTa" or "DeBERTa").

    Returns
    -------
    tuple
        ``(tokenizer, model, id2label)`` where ``id2label`` maps integer
        class ids to intent names.
    """
    # Serve from the module-level cache so each checkpoint is only
    # downloaded and instantiated once per process.
    if model_choice in cached_models:
        return cached_models[model_choice]

    config = models_info[model_choice]
    tokenizer = config["tokenizer"].from_pretrained(config["path"])
    model = config["model_class"].from_pretrained(config["path"])

    # Prefer the label map shipped with the checkpoint, but fall back to the
    # hand-maintained mapping when it is absent or only contains the generic
    # "LABEL_n" placeholders. Catch AttributeError specifically — the
    # original bare ``except:`` also swallowed KeyboardInterrupt/SystemExit
    # and hid real bugs.
    try:
        id2label = model.config.id2label
        if "LABEL_0" in id2label.values():
            id2label = true_id2label
    except AttributeError:
        id2label = true_id2label

    cached_models[model_choice] = (tokenizer, model, id2label)
    return tokenizer, model, id2label

# Shared LIME text explainer; class_names order must match the models'
# output class ids so explanations are labelled correctly.
explainer = LimeTextExplainer(class_names = class_names)

def predict_and_explainations(model_choice, user_text):
    """Classify a query and produce a LIME explanation for the prediction.

    Parameters
    ----------
    model_choice : str
        Key into ``models_info`` ("BERT", "RoBERTa" or "DeBERTa").
    user_text : str
        The customer query to classify.

    Returns
    -------
    tuple
        ``(predictions, exp_html, fig)``: a ``{label: probability}`` dict,
        the LIME explanation as HTML, and a matplotlib figure of the
        top contributing words.
    """
    tokenizer, model, id2label = load_model_and_tokenizer(model_choice)

    # ``inputs`` (was ``input``) — avoid shadowing the builtin.
    inputs = tokenizer(user_text, return_tensors="tf", truncation=True,
                       padding="max_length", max_length=128)
    logits = model(**inputs).logits
    probs = tf.nn.softmax(logits, axis=-1).numpy()[0]
    # Removed an unused ``predicted_index = tf.argmax(...)`` — the Label
    # widget derives the top class from the probability dict itself.
    predictions = {id2label[i]: float(probs[i]) for i in range(len(probs))}

    def lime_predict(texts):
        # Batch probability function LIME calls on perturbed variants of
        # ``user_text``. truncation=True added for consistency with the main
        # forward pass above (the original omitted it here).
        batch = tokenizer(texts, return_tensors="tf", truncation=True,
                          padding="max_length", max_length=128)
        batch_logits = model(**batch).logits
        return tf.nn.softmax(batch_logits, axis=-1).numpy()

    explaination = explainer.explain_instance(
        user_text, lime_predict, num_features=10, num_samples=1000)
    exp_html = explaination.as_html()
    fig = explaination.as_pyplot_figure()
    return predictions, exp_html, fig

# Gradio UI wiring: one dropdown (model) + one textbox (query) feeding
# predict_and_explainations; outputs are the probability label widget, the
# LIME HTML explanation, and the LIME bar plot.
interface = gr.Interface(
    fn=predict_and_explainations,
    inputs=[
        gr.Dropdown(choices=["BERT", "RoBERTa", "DeBERTa"], label="Choose Model"),
        gr.Textbox(lines=2, label="Customer Query", placeholder="Enter your intent query here...")
    ],
    outputs=[
        gr.Label(num_top_classes=3, label="Predicted Intent Probabilities"),
        gr.HTML(label="LIME Explanation"),
        gr.Plot(label = "Lime exp PLOTS")
    ],
    title="Intent Classification using Transformers Models- BERT, RoBERTa, DeBERTa",
    description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
)

if __name__ == "__main__":
    # to dummy test the models
    for model_key in models_info.keys():
        tokenizer, model, id2label = load_model_and_tokenizer(model_key)
        _ = model(**tokenizer("warmup", return_tensors="tf", padding=True, truncation=True, max_length=128))

    interface.launch()