from functools import partial

import torch
import transformers
from huggingface_hub import hf_hub_download

from .modeling_ltgbert import LtgBertForMaskedLM

def _debias_hook(b_dir, module, inputs, output):
    """Forward hook that projects the bias direction out of the hidden states."""
    # Locate the hidden-state tensor inside whatever container the module returns.
    if hasattr(output, "last_hidden_state"):
        x = output.last_hidden_state
        container, key = output, "last_hidden_state"
    else:
        seq = output[0]
        if isinstance(seq, list):
            # List of per-layer hidden states: debias the final layer in place.
            x = seq[-1]
            container, key = seq, -1
        else:
            # Plain tuple output: tuples are immutable, so rebuild the tuple below.
            x = seq
            container, key = None, None

    b = b_dir.to(device=x.device, dtype=x.dtype)
    # Scalar projection of every token vector onto b, then remove that component.
    proj = torch.matmul(x, b) / torch.dot(b, b)
    debiased = x - proj.unsqueeze(-1) * b

    if container is None:
        # The hook's return value replaces the module's output.
        return (debiased,) + tuple(output[1:])
    container[key] = debiased
    return output


class SentenceDebiasLtgBertForMaskedLM(LtgBertForMaskedLM):
    """ltg-bert-babylm masked LM with the race Sentence-Debias direction
    projected out of its hidden states."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Pop our custom kwarg before the base class's from_pretrained sees it.
        bias_path = kwargs.pop("bias_direction_path", None)
        revision = kwargs.get("revision", None)

        kwargs.setdefault("trust_remote_code", True)
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Default to the precomputed race bias direction stored in the model repo.
        if bias_path is None:
            bias_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="bias_direction_race.pt",
                revision=revision,
            )
        bias_vec = torch.load(bias_path, map_location="cpu")

        # Hook the encoder so every forward pass removes the bias component;
        # fall back to hooking the whole model if there is no `.transformer`.
        block = model.transformer if hasattr(model, "transformer") else model
        block.register_forward_hook(partial(_debias_hook, bias_vec))
        model.register_buffer("bias_direction", bias_vec)
        return model
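

# Minimal usage sketch (illustration only, not part of the module): it assumes this
# file ships as the checkpoint's remote-code modeling module and that a tokenizer is
# published in the same repo; "<repo-id>" and the [MASK] token are placeholders.
#
#   from transformers import AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("<repo-id>", trust_remote_code=True)
#   model = SentenceDebiasLtgBertForMaskedLM.from_pretrained("<repo-id>")
#   out = model(**tok("The doctor said [MASK] was busy.", return_tensors="pt"))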