Unknownaut commited on
Commit
0d4493f
·
verified ·
1 Parent(s): 99f973f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
- import gc
5
 
6
  ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
7
  BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"
@@ -13,28 +12,17 @@ _current_tokenizer = None
13
  _current_model_name = None
14
 
15
 
16
- def unload_model():
17
- global _current_model, _current_tokenizer
18
-
19
- if _current_model is not None:
20
- del _current_model
21
- del _current_tokenizer
22
- gc.collect()
23
-
24
-
25
- def load_model(model_key):
26
  global _current_model, _current_tokenizer, _current_model_name
27
 
28
- if _current_model_name == model_key:
29
  return _current_model, _current_tokenizer
30
 
31
- unload_model()
32
-
33
- if model_key == "roberta":
34
  tokenizer = AutoTokenizer.from_pretrained(ROBERTA_MODEL)
35
  model = AutoModelForSequenceClassification.from_pretrained(ROBERTA_MODEL)
36
 
37
- elif model_key == "bert":
38
  tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
39
  model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL)
40
 
@@ -45,22 +33,20 @@ def load_model(model_key):
45
 
46
  _current_model = model
47
  _current_tokenizer = tokenizer
48
- _current_model_name = model_key
49
 
50
  return model, tokenizer
51
 
52
 
53
  def predict(sentence, entity, model_choice):
54
- model_key = "roberta" if model_choice == "RoBERTa" else "bert"
55
-
56
- model, tokenizer = load_model(model_key)
57
 
58
  inputs = tokenizer(
59
  sentence,
60
  entity,
61
  return_tensors="pt",
62
  truncation=True,
63
- max_length=160
64
  )
65
 
66
  with torch.inference_mode():
@@ -73,11 +59,12 @@ def predict(sentence, entity, model_choice):
73
  demo = gr.Interface(
74
  fn=predict,
75
  inputs=[
76
- gr.Textbox(label="Sentence"),
77
- gr.Textbox(label="Entity"),
78
- gr.Radio(["RoBERTa", "BERT"], label="Model")
79
  ],
80
  outputs="text"
81
  )
82
 
83
- demo.launch()
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
 
5
  ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
6
  BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"
 
12
  _current_model_name = None
13
 
14
 
15
+ def load_model(model_choice):
 
 
 
 
 
 
 
 
 
16
  global _current_model, _current_tokenizer, _current_model_name
17
 
18
+ if _current_model_name == model_choice:
19
  return _current_model, _current_tokenizer
20
 
21
+ if model_choice == "RoBERTa":
 
 
22
  tokenizer = AutoTokenizer.from_pretrained(ROBERTA_MODEL)
23
  model = AutoModelForSequenceClassification.from_pretrained(ROBERTA_MODEL)
24
 
25
+ elif model_choice == "BERT":
26
  tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL)
27
  model = AutoModelForSequenceClassification.from_pretrained(BERT_MODEL)
28
 
 
33
 
34
  _current_model = model
35
  _current_tokenizer = tokenizer
36
+ _current_model_name = model_choice
37
 
38
  return model, tokenizer
39
 
40
 
41
  def predict(sentence, entity, model_choice):
42
+ model, tokenizer = load_model(model_choice)
 
 
43
 
44
  inputs = tokenizer(
45
  sentence,
46
  entity,
47
  return_tensors="pt",
48
  truncation=True,
49
+ max_length=128
50
  )
51
 
52
  with torch.inference_mode():
 
59
  demo = gr.Interface(
60
  fn=predict,
61
  inputs=[
62
+ gr.Textbox(),
63
+ gr.Textbox(),
64
+ gr.Radio(["RoBERTa", "BERT"])
65
  ],
66
  outputs="text"
67
  )
68
 
69
+ # 🔥 IMPORTANT: enable API
70
+ app = gr.mount_gradio_app(None, demo, path="/")