Spaces:
Runtime error
Runtime error
AnushkaTonk committed on
Commit ·
36a744b
1
Parent(s): fcc3da9
initial commit
Browse files- .gitignore +8 -0
- app.py +95 -0
- requirements.txt +6 -0
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
saved_model/
|
| 2 |
+
*.h5
|
| 3 |
+
*.bin
|
| 4 |
+
*.pt
|
| 5 |
+
*.ckpt
|
| 6 |
+
*.zip
|
| 7 |
+
*.tar
|
| 8 |
+
*.pkl
|
app.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from transformers import (
|
| 3 |
+
BertTokenizerFast, TFBertForSequenceClassification,
|
| 4 |
+
RobertaTokenizer, RobertaTokenizerFast, TFRobertaForSequenceClassification,
|
| 5 |
+
DebertaTokenizer, DebertaTokenizerFast, TFDebertaForSequenceClassification
|
| 6 |
+
)
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from lime.lime_text import LimeTextExplainer
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
# Fallback mapping from class index to human-readable intent name, used when
# a checkpoint's config only carries generic "LABEL_n" names.
_INTENT_LABELS = [
    'check_payment_methods', 'check_refund_policy', 'check_cancellation_fee',
    'check_invoices', 'place_order', 'set_up_shipping_address',
    'recover_password', 'delivery_period', 'track_refund',
    'payment_issue', 'contact_customer_service', 'newsletter_subscription',
    'registration_problems', 'cancel_order', 'review',
    'contact_human_agent', 'track_order', 'get_invoice',
    'get_refund', 'change_shipping_address', 'delete_account',
    'delivery_options', 'create_account', 'change_order',
    'switch_account', 'complaint', 'edit_account',
]
true_id2label = dict(enumerate(_INTENT_LABELS))

# Label names in index order; passed to the LIME explainer so its output
# columns line up with the model's class ids.
class_names = list(true_id2label.values())
|
| 24 |
+
|
| 25 |
+
# Registry of selectable architectures: Hub repo id plus the matching
# tokenizer and TF model classes for each choice in the UI dropdown.
models_info = {
    "BERT": {
        "path": "vrindaTonk/bert_intent_class",
        "tokenizer": BertTokenizerFast,
        "model_class": TFBertForSequenceClassification,
    },
    "RoBERTa": {
        "path": "vrindaTonk/roberta_intent_class",
        "tokenizer": RobertaTokenizer,
        "model_class": TFRobertaForSequenceClassification,
    },
    "DeBERTa": {
        "path": "vrindaTonk/deberta_intent_class",
        "tokenizer": DebertaTokenizer,
        "model_class": TFDebertaForSequenceClassification,
    },
}
|
| 42 |
+
|
| 43 |
+
# loading models and tokenizers:
|
| 44 |
+
def load_model_and_tokenizer(model_choice):
    """Load the tokenizer and TF model for the chosen architecture.

    Args:
        model_choice: A key of ``models_info`` ("BERT", "RoBERTa", "DeBERTa").

    Returns:
        Tuple ``(tokenizer, model, id2label)`` where ``id2label`` maps class
        index to a human-readable intent name.
    """
    config = models_info[model_choice]
    # FIX: the paths are Hugging Face Hub repo ids, but the original passed
    # local_files_only=True, which forbids downloading from the Hub. Since
    # .gitignore excludes saved_model/ and weight files, no local copy exists
    # on a fresh deployment and loading crashed at startup ("Runtime error").
    tokenizer = config["tokenizer"].from_pretrained(config["path"])
    model = config["model_class"].from_pretrained(config["path"])
    try:
        id2label = model.config.id2label
        # Checkpoints saved without custom labels expose generic "LABEL_n"
        # names; substitute the real intent names in that case.
        if "LABEL_0" in id2label.values():
            id2label = true_id2label
    except AttributeError:
        # Config has no id2label at all -> fall back to the known mapping.
        # (Narrowed from a bare `except:`, which also swallowed SystemExit
        # and KeyboardInterrupt.)
        id2label = true_id2label
    return tokenizer, model, id2label
|
| 55 |
+
|
| 56 |
+
# Single shared LIME explainer; class_names fixes the label order used in
# the generated explanations.
explainer = LimeTextExplainer(class_names=class_names)
|
| 57 |
+
|
| 58 |
+
def predict_and_explainations(model_choice, user_text):
    """Classify a customer query and produce a LIME explanation for it.

    Args:
        model_choice: Model key ("BERT", "RoBERTa", "DeBERTa").
        user_text: Raw customer query string.

    Returns:
        Tuple ``(predictions, exp_html, fig)``: a label -> probability dict,
        the LIME explanation as HTML, and the LIME matplotlib figure.
    """
    tokenizer, model, id2label = load_model_and_tokenizer(model_choice)

    # Renamed from `input`, which shadowed the builtin.
    inputs = tokenizer(user_text, return_tensors="tf", truncation=True,
                       padding=True, max_length=128)
    logits = model(**inputs).logits
    probs = tf.nn.softmax(logits, axis=-1).numpy()[0]
    # (Removed an unused `predicted_index = tf.argmax(probs).numpy()`.)
    predictions = {id2label[i]: float(probs[i]) for i in range(len(probs))}

    def lime_predict(texts):
        # Batched probability function for LIME's perturbed samples.
        # FIX: truncation=True was missing here; without it the tokenizer
        # silently ignores max_length, unlike the main prediction path above.
        batch = tokenizer(texts, return_tensors="tf", padding=True,
                          truncation=True, max_length=128)
        batch_logits = model(**batch).logits
        return tf.nn.softmax(batch_logits, axis=1).numpy()

    explaination = explainer.explain_instance(
        user_text, lime_predict, num_features=10, num_samples=1000)
    exp_html = explaination.as_html()
    fig = explaination.as_pyplot_figure()
    return predictions, exp_html, fig
|
| 77 |
+
|
| 78 |
+
# Gradio COde:
|
| 79 |
+
# --- Gradio UI wiring ---
_input_widgets = [
    gr.Dropdown(choices=["BERT", "RoBERTa", "DeBERTa"], label="Choose Model"),
    gr.Textbox(lines=2, label="Customer Query",
               placeholder="Enter your intent query here..."),
]
_output_widgets = [
    gr.Label(num_top_classes=3, label="Predicted Intent Probabilities"),
    gr.HTML(label="LIME Explanation"),
    gr.Plot(label="Lime exp PLOTS"),
]
interface = gr.Interface(
    fn=predict_and_explainations,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Intent Classification using Transformers",
    description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation.",
)
|
| 93 |
+
|
| 94 |
+
# Launch the Gradio app when executed as a script (how Spaces runs app.py).
if __name__ == "__main__":
    interface.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
transformers
|
| 3 |
+
tensorflow
|
| 4 |
+
lime
|
| 5 |
+
matplotlib
|
| 6 |
+
sentencepiece
|