Spaces:
Runtime error
Runtime error
Commit
·
08879a1
1
Parent(s):
a43a76a
format file and remove share=True
Browse files
app.py
CHANGED
|
@@ -31,6 +31,7 @@ for bert_like in MODEL_NAMES:
|
|
| 31 |
|
| 32 |
# %%
|
| 33 |
|
|
|
|
| 34 |
def clean_tokens(tokens):
|
| 35 |
return [token.strip() for token in tokens]
|
| 36 |
|
|
@@ -61,8 +62,6 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_pre
|
|
| 61 |
return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
| 62 |
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
def get_figure(df, gender, n_fit=1):
|
| 67 |
df = df.set_index("x-axis")
|
| 68 |
cols = df.columns
|
|
@@ -75,16 +74,16 @@ def get_figure(df, gender, n_fit=1):
|
|
| 75 |
|
| 76 |
# find stackoverflow reference
|
| 77 |
p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
|
| 78 |
-
t = np.linspace(min(xs)-1, max(xs)+1,
|
| 79 |
-
TT = np.vstack([t**(n_fit-i) for i in range(n_fit+1)]).T
|
| 80 |
|
| 81 |
# matrix multiplication calculates the polynomial values
|
| 82 |
yi = np.dot(TT, p)
|
| 83 |
C_yi = np.dot(TT, np.dot(C_p, TT.T)) # C_y = TT*C_z*TT.T
|
| 84 |
sig_yi = np.sqrt(np.diag(C_yi)) # Standard deviations are sqrt of diagonal
|
| 85 |
|
| 86 |
-
ax.fill_between(t, yi+sig_yi, yi-sig_yi, alpha
|
| 87 |
-
ax.plot(t, yi,
|
| 88 |
ax.plot(df, "ro")
|
| 89 |
ax.legend(list(df.columns))
|
| 90 |
|
|
@@ -97,7 +96,6 @@ def get_figure(df, gender, n_fit=1):
|
|
| 97 |
return fig
|
| 98 |
|
| 99 |
|
| 100 |
-
|
| 101 |
# %%
|
| 102 |
def predict_masked_tokens(
|
| 103 |
model_name,
|
|
@@ -185,34 +183,33 @@ def predict_masked_tokens(
|
|
| 185 |
|
| 186 |
truck_fn_example = [
|
| 187 |
MODEL_NAMES[2],
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
"True",
|
| 194 |
1,
|
| 195 |
]
|
|
|
|
|
|
|
| 196 |
def truck_1_fn():
|
| 197 |
-
return truck_fn_example + [
|
| 198 |
-
|
| 199 |
-
]
|
| 200 |
|
| 201 |
def truck_2_fn():
|
| 202 |
return truck_fn_example + [
|
| 203 |
-
|
| 204 |
]
|
| 205 |
|
| 206 |
|
| 207 |
# # %%
|
| 208 |
|
| 209 |
|
| 210 |
-
|
| 211 |
demo = gr.Blocks()
|
| 212 |
with demo:
|
| 213 |
gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
|
| 214 |
|
| 215 |
-
|
| 216 |
gr.Markdown("## Instructions for this Demo")
|
| 217 |
gr.Markdown(
|
| 218 |
"1) Click on one of the examples below to pre-populate the input fields."
|
|
@@ -224,8 +221,8 @@ with demo:
|
|
| 224 |
"3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!"
|
| 225 |
)
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
| 229 |
We can see this spurious correlation largely disappears in the well-specified example text.
|
| 230 |
|
| 231 |
<p align="center">
|
|
@@ -236,18 +233,25 @@ with demo:
|
|
| 236 |
<p align="center">
|
| 237 |
<img src="file/well_spec.png" alt="results" width="300"/>
|
| 238 |
</p>
|
| 239 |
-
"""
|
|
|
|
| 240 |
|
| 241 |
gr.Markdown("## Example inputs")
|
| 242 |
gr.Markdown(
|
| 243 |
"Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions."
|
| 244 |
)
|
| 245 |
with gr.Row():
|
| 246 |
-
truck_1_gen = gr.Button(
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
truck_2_gen = gr.Button("Click for well-specified vehicle-type example inputs")
|
| 250 |
-
gr.Markdown(
|
|
|
|
|
|
|
| 251 |
|
| 252 |
gr.Markdown("## Input fields")
|
| 253 |
gr.Markdown(
|
|
@@ -343,11 +347,37 @@ with demo:
|
|
| 343 |
)
|
| 344 |
|
| 345 |
with gr.Row():
|
| 346 |
-
truck_1_gen.click(
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
-
truck_2_gen.click(
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
btn.click(
|
| 353 |
predict_masked_tokens,
|
|
@@ -365,8 +395,6 @@ with demo:
|
|
| 365 |
outputs=[sample_text, female_fig, male_fig, df],
|
| 366 |
)
|
| 367 |
|
| 368 |
-
demo.launch(debug=True
|
| 369 |
|
| 370 |
# %%
|
| 371 |
-
|
| 372 |
-
|
|
|
|
| 31 |
|
| 32 |
# %%
|
| 33 |
|
| 34 |
+
|
| 35 |
def clean_tokens(tokens):
|
| 36 |
return [token.strip() for token in tokens]
|
| 37 |
|
|
|
|
| 62 |
return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
| 63 |
|
| 64 |
|
|
|
|
|
|
|
| 65 |
def get_figure(df, gender, n_fit=1):
|
| 66 |
df = df.set_index("x-axis")
|
| 67 |
cols = df.columns
|
|
|
|
| 74 |
|
| 75 |
# find stackoverflow reference
|
| 76 |
p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
|
| 77 |
+
t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
|
| 78 |
+
TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T
|
| 79 |
|
| 80 |
# matrix multiplication calculates the polynomial values
|
| 81 |
yi = np.dot(TT, p)
|
| 82 |
C_yi = np.dot(TT, np.dot(C_p, TT.T)) # C_y = TT*C_z*TT.T
|
| 83 |
sig_yi = np.sqrt(np.diag(C_yi)) # Standard deviations are sqrt of diagonal
|
| 84 |
|
| 85 |
+
ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=0.25)
|
| 86 |
+
ax.plot(t, yi, "-")
|
| 87 |
ax.plot(df, "ro")
|
| 88 |
ax.legend(list(df.columns))
|
| 89 |
|
|
|
|
| 96 |
return fig
|
| 97 |
|
| 98 |
|
|
|
|
| 99 |
# %%
|
| 100 |
def predict_masked_tokens(
|
| 101 |
model_name,
|
|
|
|
| 183 |
|
| 184 |
truck_fn_example = [
|
| 185 |
MODEL_NAMES[2],
|
| 186 |
+
"",
|
| 187 |
+
", ".join(["truck", "pickup"]),
|
| 188 |
+
", ".join(["car", "sedan"]),
|
| 189 |
+
", ".join(["city", "neighborhood", "farm"]),
|
| 190 |
+
"PLACE",
|
| 191 |
"True",
|
| 192 |
1,
|
| 193 |
]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
def truck_1_fn():
|
| 197 |
+
return truck_fn_example + ["He loaded up his truck and drove to the PLACE."]
|
| 198 |
+
|
|
|
|
| 199 |
|
| 200 |
def truck_2_fn():
|
| 201 |
return truck_fn_example + [
|
| 202 |
+
"He loaded up the bed of his truck and drove to the PLACE."
|
| 203 |
]
|
| 204 |
|
| 205 |
|
| 206 |
# # %%
|
| 207 |
|
| 208 |
|
|
|
|
| 209 |
demo = gr.Blocks()
|
| 210 |
with demo:
|
| 211 |
gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
|
| 212 |
|
|
|
|
| 213 |
gr.Markdown("## Instructions for this Demo")
|
| 214 |
gr.Markdown(
|
| 215 |
"1) Click on one of the examples below to pre-populate the input fields."
|
|
|
|
| 221 |
"3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!"
|
| 222 |
)
|
| 223 |
|
| 224 |
+
gr.Markdown(
|
| 225 |
+
"""The pre-populated inputs below are for a demo example of a location-vs-vehicle-type spurious correlation.
|
| 226 |
We can see this spurious correlation largely disappears in the well-specified example text.
|
| 227 |
|
| 228 |
<p align="center">
|
|
|
|
| 233 |
<p align="center">
|
| 234 |
<img src="file/well_spec.png" alt="results" width="300"/>
|
| 235 |
</p>
|
| 236 |
+
"""
|
| 237 |
+
)
|
| 238 |
|
| 239 |
gr.Markdown("## Example inputs")
|
| 240 |
gr.Markdown(
|
| 241 |
"Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions."
|
| 242 |
)
|
| 243 |
with gr.Row():
|
| 244 |
+
truck_1_gen = gr.Button(
|
| 245 |
+
"Click for non-well-specified(?) vehicle-type example inputs"
|
| 246 |
+
)
|
| 247 |
+
gr.Markdown(
|
| 248 |
+
"<-- Multiple solutions with low training error. LLM sensitive to spurious(?) correlations."
|
| 249 |
+
)
|
| 250 |
|
| 251 |
truck_2_gen = gr.Button("Click for well-specified vehicle-type example inputs")
|
| 252 |
+
gr.Markdown(
|
| 253 |
+
"<-- Fewer solutions with low training error. LLM less sensitive to spurious(?) correlations."
|
| 254 |
+
)
|
| 255 |
|
| 256 |
gr.Markdown("## Input fields")
|
| 257 |
gr.Markdown(
|
|
|
|
| 347 |
)
|
| 348 |
|
| 349 |
with gr.Row():
|
| 350 |
+
truck_1_gen.click(
|
| 351 |
+
truck_1_fn,
|
| 352 |
+
inputs=[],
|
| 353 |
+
outputs=[
|
| 354 |
+
model_name,
|
| 355 |
+
own_model_name,
|
| 356 |
+
group_a_tokens,
|
| 357 |
+
group_b_tokens,
|
| 358 |
+
x_axis,
|
| 359 |
+
place_holder,
|
| 360 |
+
to_normalize,
|
| 361 |
+
n_fit,
|
| 362 |
+
input_text,
|
| 363 |
+
],
|
| 364 |
+
)
|
| 365 |
|
| 366 |
+
truck_2_gen.click(
|
| 367 |
+
truck_2_fn,
|
| 368 |
+
inputs=[],
|
| 369 |
+
outputs=[
|
| 370 |
+
model_name,
|
| 371 |
+
own_model_name,
|
| 372 |
+
group_a_tokens,
|
| 373 |
+
group_b_tokens,
|
| 374 |
+
x_axis,
|
| 375 |
+
place_holder,
|
| 376 |
+
to_normalize,
|
| 377 |
+
n_fit,
|
| 378 |
+
input_text,
|
| 379 |
+
],
|
| 380 |
+
)
|
| 381 |
|
| 382 |
btn.click(
|
| 383 |
predict_masked_tokens,
|
|
|
|
| 395 |
outputs=[sample_text, female_fig, male_fig, df],
|
| 396 |
)
|
| 397 |
|
| 398 |
+
demo.launch(debug=True)
|
| 399 |
|
| 400 |
# %%
|
|
|
|
|
|