Spaces:
Build error
Build error
Commit
·
efb23c9
1
Parent(s):
766dac7
use st.metric for sequence logits
Browse files
app.py
CHANGED
|
@@ -7,7 +7,8 @@ import streamlit as st
|
|
| 7 |
import torch
|
| 8 |
from transformers import BertTokenizerFast
|
| 9 |
|
| 10 |
-
from model import BertForTokenAndSequenceJointClassification
|
|
|
|
| 11 |
|
| 12 |
@st.cache(allow_output_mutation=True)
|
| 13 |
def load_model():
|
|
@@ -16,22 +17,28 @@ def load_model():
|
|
| 16 |
"QCRI/PropagandaTechniquesAnalysis-en-BERT",
|
| 17 |
revision="v0.1.0")
|
| 18 |
return tokenizer, model
|
| 19 |
-
|
| 20 |
-
tokenizer, model = load_model()
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
inputs = tokenizer.encode_plus(input, return_tensors="pt")
|
| 29 |
-
outputs = model(**inputs)
|
| 30 |
-
sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
|
| 31 |
-
sequence_class = model.sequence_tags[sequence_class_index[0]]
|
| 32 |
-
token_class_index = torch.argmax(outputs.token_logits, dim=-1)
|
| 33 |
-
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
|
| 34 |
-
tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
|
| 35 |
|
| 36 |
spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
|
| 37 |
|
|
@@ -40,7 +47,7 @@ doc = Doc(Vocab(strings=set(tokens)),
|
|
| 40 |
spaces=spaces,
|
| 41 |
ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
|
| 42 |
|
| 43 |
-
labels =
|
| 44 |
|
| 45 |
label_select = st.multiselect(
|
| 46 |
"Tags",
|
|
|
|
| 7 |
import torch
|
| 8 |
from transformers import BertTokenizerFast
|
| 9 |
|
| 10 |
+
from model import BertForTokenAndSequenceJointClassification
|
| 11 |
+
|
| 12 |
|
| 13 |
@st.cache(allow_output_mutation=True)
|
| 14 |
def load_model():
|
|
|
|
| 17 |
"QCRI/PropagandaTechniquesAnalysis-en-BERT",
|
| 18 |
revision="v0.1.0")
|
| 19 |
return tokenizer, model
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
with torch.inference_mode(True):
|
| 22 |
+
tokenizer, model = load_model()
|
| 23 |
+
|
| 24 |
+
st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")
|
| 25 |
+
|
| 26 |
+
input = st.text_area('Input', """\
|
| 27 |
+
In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
|
| 28 |
+
""")
|
| 29 |
+
|
| 30 |
+
inputs = tokenizer.encode_plus(input, return_tensors="pt")
|
| 31 |
+
outputs = model(**inputs)
|
| 32 |
+
sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
|
| 33 |
+
sequence_class = model.sequence_tags[sequence_class_index[0]]
|
| 34 |
+
token_class_index = torch.argmax(outputs.token_logits, dim=-1)
|
| 35 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
|
| 36 |
+
tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
|
| 37 |
|
| 38 |
+
columns = st.columns(len(outputs.sequence_logits.flatten()))
|
| 39 |
+
for col, sequence_tag, logit in zip(columns, model.sequence_tags, outputs.sequence_logits.flatten()):
|
| 40 |
+
col.metric(sequence_tag, '%.2f' % logit.item())
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
|
| 44 |
|
|
|
|
| 47 |
spaces=spaces,
|
| 48 |
ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
|
| 49 |
|
| 50 |
+
labels = model.token_tags[2:]
|
| 51 |
|
| 52 |
label_select = st.multiselect(
|
| 53 |
"Tags",
|