AnushkaTonk commited on
Commit
0a6f2d0
·
1 Parent(s): d43bb10

updated app.py and README.md files

Browse files
Files changed (2) hide show
  1. README.md +3 -0
  2. app.py +10 -2
README.md CHANGED
@@ -12,3 +12,6 @@ short_description: 'inent classification: comparision in DeBERTa, RoBERTa & BERT
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+
17
+ First load may take a few seconds while models are being downloaded from the Hugging Face Hub.
app.py CHANGED
@@ -39,9 +39,11 @@ models_info = {
39
  "model_class": TFDebertaForSequenceClassification
40
  }
41
  }
42
-
43
  # loading models and tokenizers:
44
  def load_model_and_tokenizer(model_choice):
 
 
45
  config = models_info[model_choice]
46
  tokenizer = config["tokenizer"].from_pretrained(config["path"])
47
  model = config["model_class"].from_pretrained(config["path"])
@@ -51,6 +53,7 @@ def load_model_and_tokenizer(model_choice):
51
  id2label = true_id2label
52
  except:
53
  id2label = true_id2label
 
54
  return tokenizer, model, id2label
55
 
56
  explainer = LimeTextExplainer(class_names = class_names)
@@ -87,9 +90,14 @@ interface = gr.Interface(
87
  gr.HTML(label="LIME Explanation"),
88
  gr.Plot(label = "Lime exp PLOTS")
89
  ],
90
- title="Intent Classification using Transformers",
91
  description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
92
  )
93
 
94
  if __name__ == "__main__":
 
 
 
 
 
95
  interface.launch()
 
39
  "model_class": TFDebertaForSequenceClassification
40
  }
41
  }
42
+ cached_models = {}
43
  # loading models and tokenizers:
44
  def load_model_and_tokenizer(model_choice):
45
+ if model_choice in cached_models:
46
+ return cached_models[model_choice]
47
  config = models_info[model_choice]
48
  tokenizer = config["tokenizer"].from_pretrained(config["path"])
49
  model = config["model_class"].from_pretrained(config["path"])
 
53
  id2label = true_id2label
54
  except:
55
  id2label = true_id2label
56
+ cached_models[model_choice] = (tokenizer, model, id2label)
57
  return tokenizer, model, id2label
58
 
59
  explainer = LimeTextExplainer(class_names = class_names)
 
90
  gr.HTML(label="LIME Explanation"),
91
  gr.Plot(label = "Lime exp PLOTS")
92
  ],
93
+ title="Intent Classification using Transformers Models- BERT, RoBERTa, DeBERTa",
94
  description="Choose a model, input a customer support query, and get intent predictions with a LIME explanation."
95
  )
96
 
97
  if __name__ == "__main__":
98
+ # to dummy test the models
99
+ for model_key in models_info.keys():
100
+ tokenizer, model, id2label = load_model_and_tokenizer(model_key)
101
+ _ = model(**tokenizer("warmup", return_tensors="tf", padding=True, truncation=True, max_length=128))
102
+
103
  interface.launch()