Spaces:
Runtime error
Runtime error
adding parquets
Browse files
app.py
CHANGED
|
@@ -224,15 +224,6 @@ if __name__ == "__main__":
|
|
| 224 |
["distilbert-base-uncased-finetuned-sst-2-english",
|
| 225 |
"albert-base-v2-yelp-polarity"],
|
| 226 |
)
|
| 227 |
-
|
| 228 |
-
loss_quantile = st.sidebar.slider(
|
| 229 |
-
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
| 230 |
-
)
|
| 231 |
-
|
| 232 |
-
run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
|
| 233 |
-
|
| 234 |
-
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
| 235 |
-
|
| 236 |
### LOAD DATA AND SESSION VARIABLES ###
|
| 237 |
data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
| 238 |
if model == 'albert-base-v2-yelp-polarity':
|
|
@@ -243,13 +234,28 @@ if __name__ == "__main__":
|
|
| 243 |
st.session_state["user_data"] = data_df
|
| 244 |
if "selected_slice" not in st.session_state:
|
| 245 |
st.session_state["selected_slice"] = None
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
| 247 |
data_df['loss'] = data_df['loss'].astype(float)
|
| 248 |
losses = data_df['loss']
|
| 249 |
high_loss = losses.quantile(loss_quantile)
|
| 250 |
data_df['slice'] = 'high-loss'
|
| 251 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
if run_kmeans == 'True':
|
| 254 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
| 255 |
with lcol:
|
|
@@ -264,12 +270,5 @@ if __name__ == "__main__":
|
|
| 264 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 265 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 266 |
st.write(dataframe,width=900, height=300)
|
| 267 |
-
|
| 268 |
-
with rcol:
|
| 269 |
-
with st.spinner(text='loading...'):
|
| 270 |
-
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
| 271 |
-
commontokens = frequent_tokens(merged, tokenizer, loss_quantile=loss_quantile)
|
| 272 |
-
with st.expander("How to read the table:"):
|
| 273 |
-
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
| 274 |
-
st.write(commontokens)
|
| 275 |
quant_panel(merged)
|
|
|
|
| 224 |
["distilbert-base-uncased-finetuned-sst-2-english",
|
| 225 |
"albert-base-v2-yelp-polarity"],
|
| 226 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
### LOAD DATA AND SESSION VARIABLES ###
|
| 228 |
data_df = pd.read_parquet('./assets/data/'+dataset+ '_'+ model+'.parquet')
|
| 229 |
if model == 'albert-base-v2-yelp-polarity':
|
|
|
|
| 234 |
st.session_state["user_data"] = data_df
|
| 235 |
if "selected_slice" not in st.session_state:
|
| 236 |
st.session_state["selected_slice"] = None
|
| 237 |
+
|
| 238 |
+
loss_quantile = st.sidebar.slider(
|
| 239 |
+
"Loss Quantile", min_value=0.5, max_value=1.0,step=0.01,value=0.95
|
| 240 |
+
)
|
| 241 |
data_df['loss'] = data_df['loss'].astype(float)
|
| 242 |
losses = data_df['loss']
|
| 243 |
high_loss = losses.quantile(loss_quantile)
|
| 244 |
data_df['slice'] = 'high-loss'
|
| 245 |
data_df['slice'] = data_df['slice'].where(data_df['loss'] > high_loss, 'low-loss')
|
| 246 |
|
| 247 |
+
with rcol:
|
| 248 |
+
with st.spinner(text='loading...'):
|
| 249 |
+
st.markdown('<h3>Word Distribution in Error Slice</h3>', unsafe_allow_html=True)
|
| 250 |
+
commontokens = frequent_tokens(data_df, tokenizer, loss_quantile=loss_quantile)
|
| 251 |
+
with st.expander("How to read the table:"):
|
| 252 |
+
st.markdown("* The table displays the most frequent tokens in error slices, relative to their frequencies in the val set.")
|
| 253 |
+
st.write(commontokens)
|
| 254 |
+
|
| 255 |
+
run_kmeans = st.sidebar.radio("Cluster error slice?", ('True', 'False'), index=0)
|
| 256 |
+
|
| 257 |
+
num_clusters = st.sidebar.slider("# clusters", min_value=1, max_value=20, step=1, value=3)
|
| 258 |
+
|
| 259 |
if run_kmeans == 'True':
|
| 260 |
merged = kmeans(data_df,num_clusters=num_clusters)
|
| 261 |
with lcol:
|
|
|
|
| 270 |
st.markdown("* The table displays model error slices on the evaluation dataset, sorted by loss.")
|
| 271 |
st.markdown("* Each row is an input example that includes the label, model pred, loss, and error cluster.")
|
| 272 |
st.write(dataframe,width=900, height=300)
|
| 273 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
quant_panel(merged)
|