Unknownaut commited on
Commit
a62ec9c
·
verified ·
1 Parent(s): 6d5c5e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -1,7 +1,10 @@
1
- import gradio as gr
 
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
 
 
5
  ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
6
  BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"
7
 
@@ -12,9 +15,16 @@ _current_tokenizer = None
12
  _current_model_name = None
13
 
14
 
 
 
 
 
 
 
15
  def load_model(model_choice):
16
  global _current_model, _current_tokenizer, _current_model_name
17
 
 
18
  if _current_model_name == model_choice:
19
  return _current_model, _current_tokenizer
20
 
@@ -38,12 +48,18 @@ def load_model(model_choice):
38
  return model, tokenizer
39
 
40
 
41
- def predict(sentence, entity, model_choice):
42
- model, tokenizer = load_model(model_choice)
 
 
 
 
 
 
43
 
44
  inputs = tokenizer(
45
- sentence,
46
- entity,
47
  return_tensors="pt",
48
  truncation=True,
49
  max_length=160
@@ -53,17 +69,4 @@ def predict(sentence, entity, model_choice):
53
  outputs = model(**inputs)
54
  pred = torch.argmax(outputs.logits, dim=1).item()
55
 
56
- return labels[pred]
57
-
58
-
59
- demo = gr.Interface(
60
- fn=predict,
61
- inputs=[
62
- gr.Textbox(),
63
- gr.Textbox(),
64
- gr.Radio(["RoBERTa", "BERT"])
65
- ],
66
- outputs="text"
67
- )
68
-
69
- demo.launch(enable_queue=False)
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
 
6
+ app = FastAPI()
7
+
8
  ROBERTA_MODEL = "Unknownaut/entity-level-framing-news-roberta"
9
  BERT_MODEL = "Unknownaut/entity-level-framing-news-bert"
10
 
 
15
  _current_model_name = None
16
 
17
 
18
+ class RequestData(BaseModel):
19
+ sentence: str
20
+ entity: str
21
+ model: str # "RoBERTa" or "BERT"
22
+
23
+
24
  def load_model(model_choice):
25
  global _current_model, _current_tokenizer, _current_model_name
26
 
27
+ # reuse if already loaded
28
  if _current_model_name == model_choice:
29
  return _current_model, _current_tokenizer
30
 
 
48
  return model, tokenizer
49
 
50
 
51
+ @app.get("/")
52
+ def health():
53
+ return {"status": "ok"}
54
+
55
+
56
+ @app.post("/predict")
57
+ def predict(data: RequestData):
58
+ model, tokenizer = load_model(data.model)
59
 
60
  inputs = tokenizer(
61
+ data.sentence,
62
+ data.entity,
63
  return_tensors="pt",
64
  truncation=True,
65
  max_length=160
 
69
  outputs = model(**inputs)
70
  pred = torch.argmax(outputs.logits, dim=1).item()
71
 
72
+ return {"label": labels[pred]}