Spaces:
Runtime error
Runtime error
Commit
·
6bfe2f5
1
Parent(s):
6a3abb5
Displaying target_text with [MASK]. More doc improvements. Formatting.
Browse files
app.py
CHANGED
|
@@ -22,7 +22,8 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| 22 |
MAX_TOKEN_LENGTH = 128
|
| 23 |
NON_LOSS_TOKEN_ID = -100
|
| 24 |
NON_GENDERED_TOKEN_ID = 30 # Picked an int that will pop out visually
|
| 25 |
-
|
|
|
|
| 26 |
CLASSES = list(LABEL_DICT.keys())
|
| 27 |
|
| 28 |
|
|
@@ -48,12 +49,14 @@ for var in CONDITIONING_VARIABLES:
|
|
| 48 |
)
|
| 49 |
for bert_like in BERT_LIKE_MODELS:
|
| 50 |
models_paths[(bert_like,)] = f"{bert_like}-base-uncased"
|
| 51 |
-
models[(bert_like,)] = pipeline(
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
# Tokenizers same for each model, so just grabbing one of them
|
| 55 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 56 |
-
models_paths[(CONDITIONING_VARIABLES[0], FEMALE_WEIGHTS[0])
|
|
|
|
| 57 |
)
|
| 58 |
MASK_TOKEN_ID = tokenizer.mask_token_id
|
| 59 |
|
|
@@ -72,7 +75,8 @@ male_gendered_tokens = [list[0] for list in gendered_lists]
|
|
| 72 |
female_gendered_tokens = [list[1] for list in gendered_lists]
|
| 73 |
|
| 74 |
male_gendered_token_ids = tokenizer.convert_tokens_to_ids(male_gendered_tokens)
|
| 75 |
-
female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
|
|
|
|
| 76 |
|
| 77 |
assert tokenizer.unk_token_id not in male_gendered_token_ids
|
| 78 |
assert tokenizer.unk_token_id not in female_gendered_token_ids
|
|
@@ -122,7 +126,8 @@ def tokenize_and_append_metadata(text, tokenizer):
|
|
| 122 |
male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
|
| 123 |
)
|
| 124 |
masked_token_ids = torch.where(
|
| 125 |
-
female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
|
|
|
|
| 126 |
)
|
| 127 |
masked_token_ids = torch.where(
|
| 128 |
male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
|
|
@@ -138,22 +143,23 @@ def get_tokenized_text_with_years(years, input_text):
|
|
| 138 |
"""Construct dict of tokenized texts with each year injected into the text."""
|
| 139 |
text_portions = input_text.split(SPLIT_KEY)
|
| 140 |
|
| 141 |
-
tokenized_w_year = {'ids':[], 'atten_mask':[], 'toks':[], 'labels':[]}
|
| 142 |
for b_date in years:
|
| 143 |
|
| 144 |
target_text = f"{b_date}".join(text_portions)
|
| 145 |
tokenized_sample = tokenize_and_append_metadata(
|
| 146 |
target_text,
|
| 147 |
-
tokenizer=tokenizer,
|
| 148 |
)
|
| 149 |
|
| 150 |
tokenized_w_year['ids'].append(tokenized_sample["input_ids"])
|
| 151 |
-
tokenized_w_year['atten_mask'].append(
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
tokenized_w_year['labels'].append(tokenized_sample["labels"])
|
| 154 |
|
| 155 |
-
|
| 156 |
-
return tokenized_w_year, target_text
|
| 157 |
|
| 158 |
|
| 159 |
def predict_gender_pronouns(
|
|
@@ -165,7 +171,8 @@ def predict_gender_pronouns(
|
|
| 165 |
|
| 166 |
years = np.linspace(START_YEAR, STOP_YEAR, int(num_points)).astype(int)
|
| 167 |
|
| 168 |
-
tokenized
|
|
|
|
| 169 |
is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
|
| 170 |
num_preds = torch.sum(is_masked).item()
|
| 171 |
|
|
@@ -184,17 +191,20 @@ def predict_gender_pronouns(
|
|
| 184 |
labels = tokenized["labels"][year_idx]
|
| 185 |
|
| 186 |
with torch.no_grad():
|
| 187 |
-
outputs = model(ids.unsqueeze(dim=0),
|
|
|
|
| 188 |
preds = torch.argmax(outputs[0][0].cpu(), dim=1)
|
| 189 |
|
| 190 |
#was_masked = labels.cpu() != -100
|
| 191 |
preds = torch.where(is_masked, preds, -100)
|
| 192 |
|
| 193 |
-
p_female.append(
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
|
|
|
|
|
|
|
| 198 |
|
| 199 |
for bert_like in bert_like_models:
|
| 200 |
|
|
@@ -202,13 +212,13 @@ def predict_gender_pronouns(
|
|
| 202 |
p_male = []
|
| 203 |
for year_idx in range(len(tokenized['ids'])):
|
| 204 |
toks = tokenized["toks"][year_idx]
|
| 205 |
-
target_text_for_bert = ' '.join(
|
|
|
|
| 206 |
|
| 207 |
prefix = bert_like
|
| 208 |
model = models[(bert_like,)]
|
| 209 |
|
| 210 |
-
|
| 211 |
-
mask_filled_text = model(target_text_for_bert)
|
| 212 |
|
| 213 |
female_pronouns = [
|
| 214 |
1 if pronoun[0]["token_str"] in female_gendered_tokens else 0
|
|
@@ -222,8 +232,12 @@ def predict_gender_pronouns(
|
|
| 222 |
p_female.append(sum(female_pronouns) / num_preds * 100)
|
| 223 |
p_male.append(sum(male_pronouns) / num_preds * 100)
|
| 224 |
|
| 225 |
-
dfs.append(pd.DataFrame(
|
|
|
|
| 226 |
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
results = pd.concat(dfs, axis=1).set_index("year")
|
| 229 |
|
|
@@ -238,7 +252,7 @@ def predict_gender_pronouns(
|
|
| 238 |
) # Gradio timeseries requires x-axis as column?
|
| 239 |
|
| 240 |
return (
|
| 241 |
-
|
| 242 |
female_df_for_plot,
|
| 243 |
female_df,
|
| 244 |
male_df_for_plot,
|
|
@@ -249,15 +263,16 @@ def predict_gender_pronouns(
|
|
| 249 |
title = "Changing Gender Pronouns"
|
| 250 |
description = """
|
| 251 |
<h2> Intro </h2>
|
| 252 |
-
This is a demo for a project exploring possible spurious correlations that have been learned by our models. We can examine the training datasets and learning tasks to hypothesize what spurious correlations may exist, then condition on these variables to determine if we can achieve alternative outcomes.
|
| 253 |
|
| 254 |
Specifically in this demo: In a user-provided sentence, with at least one reference to a `DATE` and one gender pronoun, we will see how sweeping through a range of `DATE` values can change the predicted pronouns. This effect can be observed in BERT base models and in our fine-tuned models (with a specific pronoun-predicting task on the [wiki-bio](https://huggingface.co/datasets/wiki_bio) dataset).
|
| 255 |
|
| 256 |
-
One way to explain this phenomenon is by looking at a likely data-generating process for biographical-like data in both the main BERT training dataset as well as the `wiki_bio` dataset, in the form of a causal DAG.
|
| 257 |
|
| 258 |
<h2> Causal DAG </h2>
|
| 259 |
In the DAG, we can see that `birth_place`, `birth_date` and `gender` are all independent elements that have no common cause with the other covariates in the DAG. However `birth_place`, `birth_date` and `gender` may all have a role in causing one's `access_to_resources`, with the general trend that `access_to_resources` has become less gender-dependent over time, but not in every `birth_place`, with recent events in Afghanistan providing a stark counterexample to this trend. Importantly, `access_to_resources` determines how, **if at all**, you may appear in the dataset's `context_words`.
|
| 260 |
-
|
|
|
|
| 261 |
We argue that although there are complex causal interactions between the words in any given sentence, the `context_words` are more likely to cause the `gender_pronouns`, rather than vice versa. For example, if the subject is a famous doctor and the object is her wealthy father, these context words will determine which person is being referred to, and thus which gendered pronoun to use.
|
| 262 |
|
| 263 |
|
|
@@ -298,9 +313,10 @@ In the resulting plots, we can look for a dose-response relationship between:
|
|
| 298 |
- our treatment: the sample text,
|
| 299 |
- and our outcome: the predicted gender of pronouns in the text.
|
| 300 |
|
| 301 |
-
Specifically we are seeing if making larger magnitude intervention: an older `DATE` in the text will result in a larger magnitude effect in the outcome: higher percentage of predicted female pronouns.
|
| 302 |
|
| 303 |
-
|
|
|
|
| 304 |
- While conditioning on either no metadata or `birth_place` data training, have similar middle-ground effects for this inference task.
|
| 305 |
- Finally, conditioning on `name` metadata in training, (while again conditioning on `date` in inference) has almost no dose-response relationship. It appears the learning of a `name —> gender pronouns` relationship was sufficiently successful to overwhelm any potential more nuanced learning, such as that driven by `birth_date` or `place`.
|
| 306 |
|
|
@@ -339,11 +355,12 @@ gr.Interface(
|
|
| 339 |
gr.inputs.Textbox(
|
| 340 |
lines=7,
|
| 341 |
label="Input Text: Include one or more instances of the word 'DATE' below (to be replaced with a range of dates in the demo), and one or more gender pronouns (to be masked for prediction).",
|
| 342 |
-
default="Born DATE, she was a computer scientist. Her work was greatly respected, and she was well-regarded in her field.",
|
| 343 |
),
|
| 344 |
],
|
| 345 |
outputs=[
|
| 346 |
-
gr.outputs.Textbox(
|
|
|
|
| 347 |
gr.outputs.Timeseries(
|
| 348 |
x="year",
|
| 349 |
label="Percent pred female pronoun vs year, per model trained with conditioning and with weight for female preds",
|
|
@@ -364,5 +381,4 @@ gr.Interface(
|
|
| 364 |
title=title,
|
| 365 |
description=description,
|
| 366 |
article=article,
|
| 367 |
-
).launch()
|
| 368 |
-
|
|
|
|
| 22 |
MAX_TOKEN_LENGTH = 128
|
| 23 |
NON_LOSS_TOKEN_ID = -100
|
| 24 |
NON_GENDERED_TOKEN_ID = 30 # Picked an int that will pop out visually
|
| 25 |
+
# Picked an int that will pop out visually
|
| 26 |
+
LABEL_DICT = {"female": 9, "male": -9}
|
| 27 |
CLASSES = list(LABEL_DICT.keys())
|
| 28 |
|
| 29 |
|
|
|
|
| 49 |
)
|
| 50 |
for bert_like in BERT_LIKE_MODELS:
|
| 51 |
models_paths[(bert_like,)] = f"{bert_like}-base-uncased"
|
| 52 |
+
models[(bert_like,)] = pipeline(
|
| 53 |
+
"fill-mask", model=models_paths[(bert_like,)])
|
| 54 |
|
| 55 |
|
| 56 |
# Tokenizers same for each model, so just grabbing one of them
|
| 57 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 58 |
+
models_paths[(CONDITIONING_VARIABLES[0], FEMALE_WEIGHTS[0])
|
| 59 |
+
], add_prefix_space=True
|
| 60 |
)
|
| 61 |
MASK_TOKEN_ID = tokenizer.mask_token_id
|
| 62 |
|
|
|
|
| 75 |
female_gendered_tokens = [list[1] for list in gendered_lists]
|
| 76 |
|
| 77 |
male_gendered_token_ids = tokenizer.convert_tokens_to_ids(male_gendered_tokens)
|
| 78 |
+
female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
|
| 79 |
+
female_gendered_tokens)
|
| 80 |
|
| 81 |
assert tokenizer.unk_token_id not in male_gendered_token_ids
|
| 82 |
assert tokenizer.unk_token_id not in female_gendered_token_ids
|
|
|
|
| 126 |
male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
|
| 127 |
)
|
| 128 |
masked_token_ids = torch.where(
|
| 129 |
+
female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
|
| 130 |
+
token_ids)
|
| 131 |
)
|
| 132 |
masked_token_ids = torch.where(
|
| 133 |
male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
|
|
|
|
| 143 |
"""Construct dict of tokenized texts with each year injected into the text."""
|
| 144 |
text_portions = input_text.split(SPLIT_KEY)
|
| 145 |
|
| 146 |
+
tokenized_w_year = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
|
| 147 |
for b_date in years:
|
| 148 |
|
| 149 |
target_text = f"{b_date}".join(text_portions)
|
| 150 |
tokenized_sample = tokenize_and_append_metadata(
|
| 151 |
target_text,
|
| 152 |
+
tokenizer=tokenizer,
|
| 153 |
)
|
| 154 |
|
| 155 |
tokenized_w_year['ids'].append(tokenized_sample["input_ids"])
|
| 156 |
+
tokenized_w_year['atten_mask'].append(
|
| 157 |
+
torch.tensor(tokenized_sample["attention_mask"]))
|
| 158 |
+
tokenized_w_year['toks'].append(
|
| 159 |
+
tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
|
| 160 |
tokenized_w_year['labels'].append(tokenized_sample["labels"])
|
| 161 |
|
| 162 |
+
return tokenized_w_year
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
def predict_gender_pronouns(
|
|
|
|
| 171 |
|
| 172 |
years = np.linspace(START_YEAR, STOP_YEAR, int(num_points)).astype(int)
|
| 173 |
|
| 174 |
+
tokenized = get_tokenized_text_with_years(years, input_text)
|
| 175 |
+
|
| 176 |
is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
|
| 177 |
num_preds = torch.sum(is_masked).item()
|
| 178 |
|
|
|
|
| 191 |
labels = tokenized["labels"][year_idx]
|
| 192 |
|
| 193 |
with torch.no_grad():
|
| 194 |
+
outputs = model(ids.unsqueeze(dim=0),
|
| 195 |
+
atten_mask.unsqueeze(dim=0))
|
| 196 |
preds = torch.argmax(outputs[0][0].cpu(), dim=1)
|
| 197 |
|
| 198 |
#was_masked = labels.cpu() != -100
|
| 199 |
preds = torch.where(is_masked, preds, -100)
|
| 200 |
|
| 201 |
+
p_female.append(
|
| 202 |
+
len(torch.where(preds == 0)[0]) / num_preds * 100)
|
| 203 |
+
p_male.append(
|
| 204 |
+
len(torch.where(preds == 1)[0]) / num_preds * 100)
|
| 205 |
|
| 206 |
+
dfs.append(pd.DataFrame(
|
| 207 |
+
{f"%f_{prefix}": p_female, f"%m_{prefix}": p_male}))
|
| 208 |
|
| 209 |
for bert_like in bert_like_models:
|
| 210 |
|
|
|
|
| 212 |
p_male = []
|
| 213 |
for year_idx in range(len(tokenized['ids'])):
|
| 214 |
toks = tokenized["toks"][year_idx]
|
| 215 |
+
target_text_for_bert = ' '.join(
|
| 216 |
+
toks[1:-1]) # Removing [CLS] and [SEP]
|
| 217 |
|
| 218 |
prefix = bert_like
|
| 219 |
model = models[(bert_like,)]
|
| 220 |
|
| 221 |
+
mask_filled_text = model(target_text_for_bert)
|
|
|
|
| 222 |
|
| 223 |
female_pronouns = [
|
| 224 |
1 if pronoun[0]["token_str"] in female_gendered_tokens else 0
|
|
|
|
| 232 |
p_female.append(sum(female_pronouns) / num_preds * 100)
|
| 233 |
p_male.append(sum(male_pronouns) / num_preds * 100)
|
| 234 |
|
| 235 |
+
dfs.append(pd.DataFrame(
|
| 236 |
+
{f"%f_{prefix}": p_female, f"%m_{prefix}": p_male}))
|
| 237 |
|
| 238 |
+
# To display to user as an example
|
| 239 |
+
toks = tokenized["toks"][0]
|
| 240 |
+
target_text_w_masks = ' '.join(toks[1:-1])
|
| 241 |
|
| 242 |
results = pd.concat(dfs, axis=1).set_index("year")
|
| 243 |
|
|
|
|
| 252 |
) # Gradio timeseries requires x-axis as column?
|
| 253 |
|
| 254 |
return (
|
| 255 |
+
target_text_w_masks,
|
| 256 |
female_df_for_plot,
|
| 257 |
female_df,
|
| 258 |
male_df_for_plot,
|
|
|
|
| 263 |
title = "Changing Gender Pronouns"
|
| 264 |
description = """
|
| 265 |
<h2> Intro </h2>
|
| 266 |
+
This is a demo for a project exploring possible spurious correlations that have been learned by our models. We can examine the training datasets and learning tasks to hypothesize what spurious correlations may exist, then condition on these variables to determine if we can achieve alternative outcomes.
|
| 267 |
|
| 268 |
Specifically in this demo: In a user-provided sentence, with at least one reference to a `DATE` and one gender pronoun, we will see how sweeping through a range of `DATE` values can change the predicted pronouns. This effect can be observed in BERT base models and in our fine-tuned models (with a specific pronoun-predicting task on the [wiki-bio](https://huggingface.co/datasets/wiki_bio) dataset).
|
| 269 |
|
| 270 |
+
One way to explain this phenomenon is by looking at a likely data-generating process for biographical-like data in both the main BERT training dataset as well as the `wiki_bio` dataset, in the form of a causal DAG.
|
| 271 |
|
| 272 |
<h2> Causal DAG </h2>
|
| 273 |
In the DAG, we can see that `birth_place`, `birth_date` and `gender` are all independent elements that have no common cause with the other covariates in the DAG. However `birth_place`, `birth_date` and `gender` may all have a role in causing one's `access_to_resources`, with the general trend that `access_to_resources` has become less gender-dependent over time, but not in every `birth_place`, with recent events in Afghanistan providing a stark counterexample to this trend. Importantly, `access_to_resources` determines how, **if at all**, you may appear in the dataset's `context_words`.
|
| 274 |
+
|
| 275 |
+
|
| 276 |
We argue that although there are complex causal interactions between the words in any given sentence, the `context_words` are more likely to cause the `gender_pronouns`, rather than vice versa. For example, if the subject is a famous doctor and the object is her wealthy father, these context words will determine which person is being referred to, and thus which gendered pronoun to use.
|
| 277 |
|
| 278 |
|
|
|
|
| 313 |
- our treatment: the sample text,
|
| 314 |
- and our outcome: the predicted gender of pronouns in the text.
|
| 315 |
|
| 316 |
+
Specifically we are seeing if 1) making larger magnitude intervention: an older `DATE` in the text will, 2) result in a larger magnitude effect in the outcome: higher percentage of predicted female pronouns.
|
| 317 |
|
| 318 |
+
Some trends that appear in the test sentences I have tried:
|
| 319 |
+
- Conditioning on `birth_date` metadata in both training and inference text has the largest dose-response relationship. This seems reasonable, as the fine-tuned model is able to 'stratify' a learned relationship between gender pronouns and dates, when both are present in the text.
|
| 320 |
- While conditioning on either no metadata or `birth_place` data training, have similar middle-ground effects for this inference task.
|
| 321 |
- Finally, conditioning on `name` metadata in training, (while again conditioning on `date` in inference) has almost no dose-response relationship. It appears the learning of a `name —> gender pronouns` relationship was sufficiently successful to overwhelm any potential more nuanced learning, such as that driven by `birth_date` or `place`.
|
| 322 |
|
|
|
|
| 355 |
gr.inputs.Textbox(
|
| 356 |
lines=7,
|
| 357 |
label="Input Text: Include one or more instances of the word 'DATE' below (to be replaced with a range of dates in the demo), and one or more gender pronouns (to be masked for prediction).",
|
| 358 |
+
default="Born in DATE, she was a computer scientist. Her work was greatly respected, and she was well-regarded in her field.",
|
| 359 |
),
|
| 360 |
],
|
| 361 |
outputs=[
|
| 362 |
+
gr.outputs.Textbox(
|
| 363 |
+
type="auto", label="Sample target text fed to model"),
|
| 364 |
gr.outputs.Timeseries(
|
| 365 |
x="year",
|
| 366 |
label="Percent pred female pronoun vs year, per model trained with conditioning and with weight for female preds",
|
|
|
|
| 381 |
title=title,
|
| 382 |
description=description,
|
| 383 |
article=article,
|
| 384 |
+
).launch()
|
|
|