AnushkaTonk commited on
Commit
36a744b
·
1 Parent(s): fcc3da9

initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +8 -0
  2. app.py +95 -0
  3. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ saved_model/
2
+ *.h5
3
+ *.bin
4
+ *.pt
5
+ *.ckpt
6
+ *.zip
7
+ *.tar
8
+ *.pkl
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from transformers import (
3
+ BertTokenizerFast, TFBertForSequenceClassification,
4
+ RobertaTokenizer, RobertaTokenizerFast, TFRobertaForSequenceClassification,
5
+ DebertaTokenizer, DebertaTokenizerFast, TFDebertaForSequenceClassification
6
+ )
7
+ import gradio as gr
8
+ from lime.lime_text import LimeTextExplainer
9
+ import numpy as np
10
+
11
+ true_id2label = {
12
+ 0: 'check_payment_methods', 1: 'check_refund_policy', 2: 'check_cancellation_fee',
13
+ 3: 'check_invoices', 4: 'place_order', 5: 'set_up_shipping_address',
14
+ 6: 'recover_password', 7: 'delivery_period', 8: 'track_refund',
15
+ 9: 'payment_issue', 10: 'contact_customer_service', 11: 'newsletter_subscription',
16
+ 12: 'registration_problems', 13: 'cancel_order', 14: 'review',
17
+ 15: 'contact_human_agent', 16: 'track_order', 17: 'get_invoice',
18
+ 18: 'get_refund', 19: 'change_shipping_address', 20: 'delete_account',
19
+ 21: 'delivery_options', 22: 'create_account', 23: 'change_order',
20
+ 24: 'switch_account', 25: 'complaint', 26: 'edit_account'
21
+ }
22
+
23
+ class_names = list(true_id2label.values())
24
+
25
+ models_info = {
26
+ "BERT": {
27
+ "path": "vrindaTonk/bert_intent_class",
28
+ "tokenizer": BertTokenizerFast,
29
+ "model_class": TFBertForSequenceClassification
30
+ },
31
+ "RoBERTa": {
32
+ "path": "vrindaTonk/roberta_intent_class",
33
+ "tokenizer": RobertaTokenizer,
34
+ "model_class": TFRobertaForSequenceClassification
35
+ },
36
+ "DeBERTa": {
37
+ "path": "vrindaTonk/deberta_intent_class",
38
+ "tokenizer": DebertaTokenizer,
39
+ "model_class": TFDebertaForSequenceClassification
40
+ }
41
+ }
42
+
43
+ # loading models and tokenizers:
44
+ def load_model_and_tokenizer(model_choice):
45
+ config = models_info[model_choice]
46
+ tokenizer = config["tokenizer"].from_pretrained(config["path"], local_files_only=True)
47
+ model = config["model_class"].from_pretrained(config["path"], local_files_only=True)
48
+ try:
49
+ id2label = model.config.id2label
50
+ if "LABEL_0" in id2label.values():
51
+ id2label = true_id2label
52
+ except:
53
+ id2label = true_id2label
54
+ return tokenizer, model, id2label
55
+
56
+ explainer = LimeTextExplainer(class_names = class_names)
57
+
58
+ def predict_and_explainations(model_choice, user_text):
59
+ tokenizer, model, id2label = load_model_and_tokenizer(model_choice)
60
+
61
+ input = tokenizer(user_text, return_tensors = "tf", truncation = True, padding = True, max_length = 128)
62
+ logits = model(**input).logits
63
+ probs = tf.nn.softmax(logits, axis = -1).numpy()[0]
64
+ predicted_index = tf.argmax(probs).numpy()
65
+ predictions = {id2label[i] : float(probs[i]) for i in range(len(probs))}
66
+
67
+ def lime_predict(texts):
68
+ batch = tokenizer(texts, return_tensors = "tf", padding = True, max_length = 128)
69
+ logits = model(**batch).logits
70
+ lime_output = tf.nn.softmax(logits, axis = 1).numpy()
71
+ return lime_output
72
+
73
+ explaination = explainer.explain_instance(user_text, lime_predict, num_features= 10, num_samples = 1000)
74
+ exp_html = explaination.as_html()
75
+ fig = explaination.as_pyplot_figure()
76
+ return predictions, exp_html, fig
77
+
78
+ # Gradio COde:
79
+ interface = gr.Interface(
80
+ fn=predict_and_explainations,
81
+ inputs=[
82
+ gr.Dropdown(choices=["BERT", "RoBERTa", "DeBERTa"], label="Choose Model"),
83
+ gr.Textbox(lines=2, label="Customer Query", placeholder="Enter your intent query here...")
84
+ ],
85
+ outputs=[
86
+ gr.Label(num_top_classes=3, label="Predicted Intent Probabilities"),
87
+ gr.HTML(label="LIME Explanation"),
88
+ gr.Plot(label = "Lime exp PLOTS")
89
+ ],
90
+ title="Intent Classification using Transformers",
91
+ description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
92
+ )
93
+
94
+ if __name__ == "__main__":
95
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ tensorflow
4
+ lime
5
+ matplotlib
6
+ sentencepiece