Spaces:
Runtime error
Runtime error
Rob Caamano
commited on
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
-
from transformers import AutoTokenizer
|
| 4 |
from transformers import (
|
| 5 |
TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
|
| 6 |
)
|
|
@@ -21,6 +21,9 @@ mod_name = model_options[selected_model]
|
|
| 21 |
|
| 22 |
tokenizer = AutoTokenizer.from_pretrained(mod_name)
|
| 23 |
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
if selected_model in ["Fine-tuned Toxicity Model"]:
|
| 26 |
toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
|
@@ -30,10 +33,10 @@ def get_toxicity_class(predictions, threshold=0.3):
|
|
| 30 |
return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
|
| 31 |
|
| 32 |
input = tokenizer(text, return_tensors="tf")
|
| 33 |
-
prediction = model(input)[0].numpy()[0]
|
| 34 |
|
| 35 |
if st.button("Submit", type="primary"):
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
tweet_portion = text[:50] + "..." if len(text) > 50 else text
|
| 39 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
+
from transformers import AutoTokenizer, pipeline
|
| 4 |
from transformers import (
|
| 5 |
TFAutoModelForSequenceClassification as AutoModelForSequenceClassification,
|
| 6 |
)
|
|
|
|
| 21 |
|
| 22 |
tokenizer = AutoTokenizer.from_pretrained(mod_name)
|
| 23 |
model = AutoModelForSequenceClassification.from_pretrained(mod_name)
|
| 24 |
+
clf = pipeline(
|
| 25 |
+
"sentiment-analysis", model=model, tokenizer=tokenizer, return_all_scores=True
|
| 26 |
+
)
|
| 27 |
|
| 28 |
if selected_model in ["Fine-tuned Toxicity Model"]:
|
| 29 |
toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
|
|
|
|
| 33 |
return {model.config.id2label[i]: pred for i, pred in enumerate(predictions) if pred >= threshold}
|
| 34 |
|
| 35 |
input = tokenizer(text, return_tensors="tf")
|
|
|
|
| 36 |
|
| 37 |
if st.button("Submit", type="primary"):
|
| 38 |
+
results = dict(d.values() for d in clf(text)[0])
|
| 39 |
+
toxic_labels = {k: results[k] for k in results.keys() if not k == "toxic"}
|
| 40 |
|
| 41 |
tweet_portion = text[:50] + "..." if len(text) > 50 else text
|
| 42 |
|