"""Race-debiased LTG-BERT BabyLM (Sentence-Debias).

Remote-code module: ltg-bert-babylm-sd-race / modeling_sentence_debias.py
"""
import torch
from functools import partial

from huggingface_hub import hf_hub_download

from .modeling_ltgbert import LtgBertForMaskedLM


def _debias_hook(b_dir, module, inputs, output):
    """Forward hook removing the bias direction from the hidden states.

    Implements the Sentence-Debias projection: each hidden state x is
    replaced by x - ((x . b) / (b . b)) * b, i.e. its component along
    the bias direction b is subtracted out.
    """
    # Locate the hidden-state tensor inside the module output.
    if hasattr(output, "last_hidden_state"):
        x = output.last_hidden_state
    elif isinstance(output, tuple):
        seq = output[0]
        # Some paths return a list of per-layer states; debias the last one.
        x = seq[-1] if isinstance(seq, list) else seq
    else:
        x = output

    b = b_dir.to(device=x.device, dtype=x.dtype)
    coeff = torch.matmul(x, b) / torch.dot(b, b)  # per-token scalar projection
    debiased = x - coeff.unsqueeze(-1) * b

    # Write the result back; tuples are immutable, so rebuild them.
    if hasattr(output, "last_hidden_state"):
        output.last_hidden_state = debiased
        return output
    if isinstance(output, tuple):
        if isinstance(output[0], list):
            output[0][-1] = debiased  # lists are mutable, edit in place
            return output
        return (debiased,) + output[1:]
    return debiased
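
# A minimal sanity check of the projection above (illustrative only; the
# shapes are hypothetical): after debiasing, the hidden states have no
# component left along the bias direction.
#
#   >>> x = torch.randn(2, 3, 4)  # (batch, seq, hidden)
#   >>> b = torch.randn(4)
#   >>> coeff = torch.matmul(x, b) / torch.dot(b, b)
#   >>> debiased = x - coeff.unsqueeze(-1) * b
#   >>> torch.allclose(torch.matmul(debiased, b), torch.zeros(2, 3), atol=1e-5)
#   True
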
class SentenceDebiasLtgBertForMaskedLM(LtgBertForMaskedLM):
    """ltg-bert-babylm with race Sentence-Debias applied to the encoder output."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Pop our custom kwarg before it reaches the parent from_pretrained,
        # which would reject it as unexpected.
        bias_path = kwargs.pop("bias_direction_path", None)
        # Make sure the custom LTG code is imported.
        kwargs.setdefault("trust_remote_code", True)
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Fetch the precomputed bias direction from the Hub if no local path
        # was supplied.
        if bias_path is None:
            bias_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="bias_direction_race.pt",
                revision=kwargs.get("revision", None),
            )
        bias_vec = torch.load(bias_path, map_location="cpu")

        # Register the debiasing hook on the encoder block ('.transformer'
        # in LTG-BERT), falling back to the whole model.
        block = model.transformer if hasattr(model, "transformer") else model
        block.register_forward_hook(partial(_debias_hook, bias_vec))
        # Keep the direction as a buffer so it is saved with the model and
        # follows model.to(...).
        model.register_buffer("bias_direction", bias_vec)
        return model
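

# Usage sketch (illustrative, not part of the shipped module). Assumptions:
# this file is loaded as remote code from the model repo, guessed here as
# "FilipT/ltg-bert-babylm-sd-race", and the tokenizer's mask token is
# "[MASK]"; both may differ in the actual repo.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "FilipT/ltg-bert-babylm-sd-race"  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = SentenceDebiasLtgBertForMaskedLM.from_pretrained(repo_id)
    model.eval()

    inputs = tokenizer("The nurse said [MASK] was tired.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # debiasing hook runs inside the forward pass
    # `outputs` follows the custom LTG-BERT masked-LM head format.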