Unknownaut commited on
Commit
d019588
·
verified ·
1 Parent(s): 292af23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -7
app.py CHANGED
@@ -1,21 +1,66 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
5
- MODEL_NAME = "Unknownaut/entity-level-framing-news-roberta"
6
-
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
9
 
10
  labels = ["Legitimate", "Aggressor", "Defensive", "Neutral"]
11
 
12
- def predict(sentence, entity):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  inputs = tokenizer(
14
  sentence,
15
  entity,
16
  return_tensors="pt",
17
  truncation=True,
18
- max_length=160
19
  )
20
 
21
  with torch.inference_mode():
@@ -24,9 +69,14 @@ def predict(sentence, entity):
24
 
25
  return labels[pred]
26
 
 
27
  demo = gr.Interface(
28
  fn=predict,
29
- inputs=["text", "text"],
 
 
 
 
30
  outputs="text"
31
  )
32
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import gc
5
 
6
+ ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
7
+ BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"
 
 
8
 
9
  labels = ["Legitimate", "Aggressor", "Defensive", "Neutral"]
10
 
11
+ _current_model = None
12
+ _current_tokenizer = None
13
+ _current_model_name = None
14
+
15
+
16
+ def unload_model():
17
+ global _current_model, _current_tokenizer
18
+
19
+ if _current_model is not None:
20
+ del _current_model
21
+ del _current_tokenizer
22
+ gc.collect()
23
+
24
+
25
+ def load_model(model_key):
26
+ global _current_model, _current_tokenizer, _current_model_name
27
+
28
+ if _current_model_name == model_key:
29
+ return _current_model, _current_tokenizer
30
+
31
+ unload_model()
32
+
33
+ if model_key == "roberta":
34
+ tokenizer = AutoTokenizer.from_pretrained(ROBERTA_MODEL)
35
+ model = AutoModelForSequenceClassification.from_pretrained(ROBERTA_MODEL)
36
+
37
+ elif model_key == "bert":
38
+ tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
39
+ model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL)
40
+
41
+ else:
42
+ raise ValueError("Invalid model")
43
+
44
+ model.eval()
45
+
46
+ _current_model = model
47
+ _current_tokenizer = tokenizer
48
+ _current_model_name = model_key
49
+
50
+ return model, tokenizer
51
+
52
+
53
+ def predict(sentence, entity, model_choice):
54
+ model_key = "roberta" if model_choice == "RoBERTa" else "bert"
55
+
56
+ model, tokenizer = load_model(model_key)
57
+
58
  inputs = tokenizer(
59
  sentence,
60
  entity,
61
  return_tensors="pt",
62
  truncation=True,
63
+ max_length=128
64
  )
65
 
66
  with torch.inference_mode():
 
69
 
70
  return labels[pred]
71
 
72
+
73
  demo = gr.Interface(
74
  fn=predict,
75
+ inputs=[
76
+ gr.Textbox(label="Sentence"),
77
+ gr.Textbox(label="Entity"),
78
+ gr.Radio(["RoBERTa", "BERT"], label="Model")
79
+ ],
80
  outputs="text"
81
  )
82