Spaces:
Runtime error
Runtime error
polinaeterna
commited on
Commit
Β·
3bb5a93
1
Parent(s):
3dcef48
ad toxicity check
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ import multiprocessing
|
|
| 6 |
import gradio as gr
|
| 7 |
import pandas as pd
|
| 8 |
import polars as pl
|
| 9 |
-
import numpy as np
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import spaces
|
| 12 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
@@ -90,12 +89,83 @@ def plot_and_df(texts, preds):
|
|
| 90 |
)
|
| 91 |
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
def run_quality_check(dataset, column, batch_size, num_examples):
|
| 95 |
-
# config = "default"
|
| 96 |
info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
|
| 97 |
if "error" in info_resp:
|
| 98 |
-
yield "β " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
|
| 99 |
return
|
| 100 |
config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
|
| 101 |
split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
|
|
@@ -106,9 +176,10 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
| 106 |
try:
|
| 107 |
data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
|
| 108 |
except Exception as error:
|
| 109 |
-
yield f"β {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure()
|
| 110 |
return
|
| 111 |
texts = data[column].to_list()
|
|
|
|
| 112 |
# batch_size = 100
|
| 113 |
predictions, texts_processed = [], []
|
| 114 |
num_examples = min(len(texts), num_examples)
|
|
@@ -117,7 +188,7 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
| 117 |
batch_predictions = predict(batch_texts)
|
| 118 |
predictions.extend(batch_predictions)
|
| 119 |
texts_processed.extend(batch_texts)
|
| 120 |
-
yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure()
|
| 121 |
|
| 122 |
with multiprocessing.Pool(processes=8) as pool:
|
| 123 |
props = pool.map(proportion_non_ascii, texts)
|
|
@@ -128,7 +199,8 @@ def run_quality_check(dataset, column, batch_size, num_examples):
|
|
| 128 |
plt.xlabel('Proportion of non-ASCII characters')
|
| 129 |
plt.ylabel('Number of texts')
|
| 130 |
|
| 131 |
-
yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf()
|
|
|
|
| 132 |
|
| 133 |
with gr.Blocks() as demo:
|
| 134 |
gr.Markdown(
|
|
@@ -175,6 +247,13 @@ with gr.Blocks() as demo:
|
|
| 175 |
|
| 176 |
# non_ascii_hist = gr.DataFrame(visible=False)
|
| 177 |
non_ascii_hist = gr.Plot()
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
demo.launch()
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
import pandas as pd
|
| 8 |
import polars as pl
|
|
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
import spaces
|
| 11 |
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
|
|
| 89 |
)
|
| 90 |
|
| 91 |
|
| 92 |
+
PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
|
| 93 |
+
PERSPECTIVE_URL = f"https://commentanalyzer.googleapis.com/v1alpha1/comments:analyze?key={PERSPECTIVE_API_KEY}"
|
| 94 |
+
REQUESTED_ATTRIBUTES = {"TOXICITY": {}, "SEVERE_TOXICITY": {},
|
| 95 |
+
"IDENTITY_ATTACK": {}, "INSULT": {}, "PROFANITY": {},
|
| 96 |
+
"THREAT": {}}
|
| 97 |
+
ATT_SCORE = "attributeScores"
|
| 98 |
+
SUM_SCORE = "summaryScore"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def plot_toxicity(scores):
|
| 102 |
+
fig, axs = plt.subplots(2, 3)#, figsize=(10, 6))
|
| 103 |
+
for x, y, score_name in zip([0,0,0,1,1,1], [0,1,2,0,1,2], scores):
|
| 104 |
+
axs[x,y].hist(scores[score_name], bins=20, range=(0., 1.))
|
| 105 |
+
# axs[x,y].set_title(f'Histogram of {score_name}')
|
| 106 |
+
axs[x,y].set_xlabel(f'{score_name}')
|
| 107 |
+
# axs[x,y].set_ylabel('Number of texts')
|
| 108 |
+
fig.supylabel("Number of texts")
|
| 109 |
+
fig.suptitle("Histogram of toxicity scores")
|
| 110 |
+
fig.tight_layout()
|
| 111 |
+
|
| 112 |
+
return fig
|
| 113 |
+
|
| 114 |
+
def call_perspective_api(texts_df, column_name):#, s):
|
| 115 |
+
headers = {
|
| 116 |
+
"content-type": "application/json",
|
| 117 |
+
}
|
| 118 |
+
req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
|
| 119 |
+
|
| 120 |
+
texts = texts_df[column_name].values
|
| 121 |
+
for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
|
| 122 |
+
data = {
|
| 123 |
+
"comment": {"text": text},
|
| 124 |
+
"languages": ["en"],
|
| 125 |
+
"requestedAttributes": REQUESTED_ATTRIBUTES
|
| 126 |
+
}
|
| 127 |
+
time.sleep(1)
|
| 128 |
+
try:
|
| 129 |
+
req_response = requests.post(PERSPECTIVE_URL, json=data, headers=headers)
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(e)
|
| 132 |
+
return req_att_scores
|
| 133 |
+
|
| 134 |
+
if req_response.ok:
|
| 135 |
+
response = req_response.json()
|
| 136 |
+
# logger.info("Perspective API response is:")
|
| 137 |
+
# logger.info(response)
|
| 138 |
+
if ATT_SCORE in response:
|
| 139 |
+
for req_att in REQUESTED_ATTRIBUTES:
|
| 140 |
+
if req_att in response[ATT_SCORE]:
|
| 141 |
+
att_score = response[ATT_SCORE][req_att][SUM_SCORE]["value"]
|
| 142 |
+
req_att_scores[req_att].append(att_score)
|
| 143 |
+
else:
|
| 144 |
+
req_att_scores[req_att].append(0)
|
| 145 |
+
else:
|
| 146 |
+
# logger.error(
|
| 147 |
+
# "Unexpected response format from Perspective API."
|
| 148 |
+
# )
|
| 149 |
+
raise ValueError(req_response)
|
| 150 |
+
else:
|
| 151 |
+
try:
|
| 152 |
+
req_response.raise_for_status()
|
| 153 |
+
except Exception as e:
|
| 154 |
+
print(e)
|
| 155 |
+
return req_att_scores
|
| 156 |
+
if i % 10 == 0:
|
| 157 |
+
plot_toxicity(req_att_scores)
|
| 158 |
+
yield plt.gcf(), pd.DataFrame()
|
| 159 |
+
|
| 160 |
+
plot_toxicity(req_att_scores)
|
| 161 |
+
yield plt.gcf(), pd.DataFrame.from_dict({column_name: texts, **req_att_scores})
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# @spaces.GPU
|
| 165 |
def run_quality_check(dataset, column, batch_size, num_examples):
|
|
|
|
| 166 |
info_resp = session.get(f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=3).json()
|
| 167 |
if "error" in info_resp:
|
| 168 |
+
yield "β " + info_resp["error"], gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
|
| 169 |
return
|
| 170 |
config = "default" if "default" in info_resp["dataset_info"] else next(iter(info_resp["dataset_info"]))
|
| 171 |
split = "train" if "train" in info_resp["dataset_info"][config]["splits"] else next(
|
|
|
|
| 176 |
try:
|
| 177 |
data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column])
|
| 178 |
except Exception as error:
|
| 179 |
+
yield f"β {error}", gr.BarPlot(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), plt.Figure(), pd.DataFrame(),
|
| 180 |
return
|
| 181 |
texts = data[column].to_list()
|
| 182 |
+
texts_sample = data.sample(20, shuffle=True, seed=16).to_pandas()
|
| 183 |
# batch_size = 100
|
| 184 |
predictions, texts_processed = [], []
|
| 185 |
num_examples = min(len(texts), num_examples)
|
|
|
|
| 188 |
batch_predictions = predict(batch_texts)
|
| 189 |
predictions.extend(batch_predictions)
|
| 190 |
texts_processed.extend(batch_texts)
|
| 191 |
+
yield {"check in progress...": (i+batch_size) / num_examples}, *plot_and_df(texts_processed, predictions), plt.Figure(), pd.DataFrame()
|
| 192 |
|
| 193 |
with multiprocessing.Pool(processes=8) as pool:
|
| 194 |
props = pool.map(proportion_non_ascii, texts)
|
|
|
|
| 199 |
plt.xlabel('Proportion of non-ASCII characters')
|
| 200 |
plt.ylabel('Number of texts')
|
| 201 |
|
| 202 |
+
yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), plt.gcf(), texts_sample
|
| 203 |
+
|
| 204 |
|
| 205 |
with gr.Blocks() as demo:
|
| 206 |
gr.Markdown(
|
|
|
|
| 247 |
|
| 248 |
# non_ascii_hist = gr.DataFrame(visible=False)
|
| 249 |
non_ascii_hist = gr.Plot()
|
| 250 |
+
texts_sample_df = gr.DataFrame(visible=False)
|
| 251 |
+
gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size, num_examples], outputs=[progress_bar, plot, df_low, df_medium, df_high, non_ascii_hist, texts_sample_df])
|
| 252 |
+
|
| 253 |
+
gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
|
| 254 |
+
toxicity_hist = gr.Plot()
|
| 255 |
+
with gr.Accordion("Explore examples with toxicity scores:", open=False):
|
| 256 |
+
toxicity_df = gr.DataFrame()
|
| 257 |
+
gr_toxicity_btn.click(call_perspective_api, inputs=[texts_sample_df, text_column], outputs=[toxicity_hist, toxicity_df])
|
| 258 |
|
| 259 |
demo.launch()
|