Spaces:
Runtime error
Runtime error
AnushkaTonk committed on
Commit ·
0a6f2d0
1
Parent(s): d43bb10
updated app.py and README.md files
Browse files
README.md
CHANGED
|
@@ -12,3 +12,6 @@ short_description: 'intent classification: comparison in DeBERTa, RoBERTa & BERT
|
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
First load may take a few seconds while models are being downloaded from the Hugging Face Hub.
|
app.py
CHANGED
|
@@ -39,9 +39,11 @@ models_info = {
|
|
| 39 |
"model_class": TFDebertaForSequenceClassification
|
| 40 |
}
|
| 41 |
}
|
| 42 |
-
|
| 43 |
# loading models and tokenizers:
|
| 44 |
def load_model_and_tokenizer(model_choice):
|
|
|
|
|
|
|
| 45 |
config = models_info[model_choice]
|
| 46 |
tokenizer = config["tokenizer"].from_pretrained(config["path"])
|
| 47 |
model = config["model_class"].from_pretrained(config["path"])
|
|
@@ -51,6 +53,7 @@ def load_model_and_tokenizer(model_choice):
|
|
| 51 |
id2label = true_id2label
|
| 52 |
except:
|
| 53 |
id2label = true_id2label
|
|
|
|
| 54 |
return tokenizer, model, id2label
|
| 55 |
|
| 56 |
explainer = LimeTextExplainer(class_names = class_names)
|
|
@@ -87,9 +90,14 @@ interface = gr.Interface(
|
|
| 87 |
gr.HTML(label="LIME Explanation"),
|
| 88 |
gr.Plot(label = "Lime exp PLOTS")
|
| 89 |
],
|
| 90 |
-
title="Intent Classification using Transformers",
|
| 91 |
description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
|
| 92 |
)
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
interface.launch()
|
|
|
|
| 39 |
"model_class": TFDebertaForSequenceClassification
|
| 40 |
}
|
| 41 |
}
|
| 42 |
+
cached_models = {}
|
| 43 |
# loading models and tokenizers:
|
| 44 |
def load_model_and_tokenizer(model_choice):
|
| 45 |
+
if model_choice in cached_models:
|
| 46 |
+
return cached_models[model_choice]
|
| 47 |
config = models_info[model_choice]
|
| 48 |
tokenizer = config["tokenizer"].from_pretrained(config["path"])
|
| 49 |
model = config["model_class"].from_pretrained(config["path"])
|
|
|
|
| 53 |
id2label = true_id2label
|
| 54 |
except:
|
| 55 |
id2label = true_id2label
|
| 56 |
+
cached_models[model_choice] = (tokenizer, model, id2label)
|
| 57 |
return tokenizer, model, id2label
|
| 58 |
|
| 59 |
explainer = LimeTextExplainer(class_names = class_names)
|
|
|
|
| 90 |
gr.HTML(label="LIME Explanation"),
|
| 91 |
gr.Plot(label = "Lime exp PLOTS")
|
| 92 |
],
|
| 93 |
+
title="Intent Classification using Transformers Models- BERT, RoBERTa, DeBERTa",
|
| 94 |
description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
|
| 95 |
)
|
| 96 |
|
| 97 |
if __name__ == "__main__":
|
| 98 |
+
# to dummy test the models
|
| 99 |
+
for model_key in models_info.keys():
|
| 100 |
+
tokenizer, model, id2label = load_model_and_tokenizer(model_key)
|
| 101 |
+
_ = model(**tokenizer("warmup", return_tensors="tf", padding=True, truncation=True, max_length=128))
|
| 102 |
+
|
| 103 |
interface.launch()
|