# Page banner captured together with this file (not part of the code):
# "Spaces: Runtime error".
| import torch | |
| import torch.nn as nn | |
| import higher | |
| from editable_model import EditableModel | |
| from utils import _logits | |
def fomaml_callback(all_grads):
    """Detach every gradient so that no graph is kept through it.

    Used as the ``grad_callback`` of the differentiable optimizer step when a
    first-order update is requested: detaching cuts the gradients out of the
    autograd graph, so nothing back-propagates through them.

    Args:
        all_grads: iterable of gradient tensors; entries may be ``None``.

    Returns:
        A new list of the same length with each tensor detached and each
        ``None`` passed through unchanged.
    """
    return [g.detach() if g is not None else None for g in all_grads]
class ENN(EditableModel):
    """Editable model whose edit is a few differentiable SGD steps.

    The edit fine-tunes the parameters named in ``config.model.inner_params``
    for ``config.enn.n_edit_steps`` steps, using one learnable learning rate
    per editable parameter (``edit_lrs``). The learning rates are exposed as
    outer-loop parameters alongside the model weights.
    """

    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        """Wrap ``model`` and set up the learnable edit learning rates.

        Args:
            model: the network to be edited.
            config: configuration object; this class reads ``edit_lr``,
                ``model.inner_params``, ``enn.first_order``,
                ``enn.n_edit_steps``, ``no_grad_layers``, ``lr`` and ``lr_lr``.
            model_constructor: zero-argument callable returning a fresh model;
                used by ``edit`` when ``detach_history`` is true.
            edit_lrs: existing learning-rate parameter to share; when ``None``
                a new ``nn.Parameter`` is created with one entry per inner
                parameter, each initialised to ``config.edit_lr``.
            edit_loss_fn: optional replacement for the edit loss. When ``None``
                the attribute is not assigned here (it is expected to come
                from ``EditableModel``, which is not visible in this file).
        """
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        # First-order mode detaches the inner-loop gradients; otherwise they
        # are passed through untouched (identity callback).
        self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x

    def outer_parameters(self, grouped=False):
        """Return the parameters trained by the outer loop.

        Args:
            grouped: when true, return two optimizer parameter groups (model
                weights at ``config.lr``, edit learning rates at
                ``config.lr_lr``); otherwise return one flat list.

        Returns:
            A list of parameter-group dicts if ``grouped`` else a flat list of
            parameters with ``edit_lrs`` last.
        """
        extra_params = [self.edit_lrs]

        if self.config.no_grad_layers is None:
            # Always materialise a fresh list: works whether parameters()
            # yields an iterator or a list, and never aliases a list the
            # model may own.
            model_params = list(self.model.parameters())
        else:
            # Only keep parameters of ModuleList entries from index
            # ``no_grad_layers`` onwards.
            model_params = []
            for m in self.model.modules():
                if isinstance(m, nn.ModuleList):
                    model_params.extend(m[self.config.no_grad_layers:].parameters())

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        return model_params + extra_params

    def get_state_dict(self):
        """Return the module's full ``state_dict()``."""
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        """Apply the edit described by ``batch`` and return the edited model.

        Args:
            batch: keyword arguments for the model's forward pass; must
                contain a ``"labels"`` entry, which is passed to the edit
                loss as the target.
            condition: accepted for interface compatibility; not used here.
            detach_history: when true, copy the edited weights into a freshly
                constructed model instead of returning the functional model
                that still carries the edit's autograd history.

        Returns:
            A tuple ``(edited, info)`` where ``edited`` is a new ``ENN``
            sharing this instance's ``edit_lrs`` and ``edit_loss_fn`` and
            ``info`` is an empty dict.
        """
        # Build the name set once instead of scanning the list per parameter.
        inner_names = set(self.config.model.inner_params)

        # One param group per editable tensor. ``lr`` is a placeholder: the
        # real per-group rates are injected through ``override`` below.
        # NOTE(review): the override list is matched to param groups by
        # position, so this assumes ``inner_params`` is listed in the same
        # order ``named_parameters()`` yields those names - confirm.
        opt = torch.optim.SGD([
            {"params": p, "lr": None}
            for (n, p) in self.model.named_parameters()
            if n in inner_names
        ])

        with torch.enable_grad(), higher.innerloop_ctx(
            self.model,
            opt,
            override={'lr': list(self.edit_lrs)},
            copy_initial_weights=False,
            # Only record the inner-loop graph when training the outer loop.
            track_higher_grads=self.training,
            in_place=True
        ) as (fmodel, diffopt):
            fmodel.eval()
            for _ in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)

        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        # Restore the train/eval mode this wrapper was in before the edit.
        model_edited.train(self.training)

        return ENN(model_edited, self.config, self.model_constructor, edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {}
def test():
    """Manual smoke test: edit GPT-2 on a dummy batch and back-propagate.

    Loads the pretrained ``"gpt2"`` checkpoint through ``transformers``,
    applies a two-step edit, prints the largest change in the last editable
    weight and the losses before/after the edit, then runs ``backward()``
    through the edited model. Uses CUDA when available, otherwise the CPU.
    """
    import copy
    import types

    import transformers

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    # ENN reads nested attributes (config.model.inner_params,
    # config.enn.first_order, ...), so every level must be a namespace.
    # NOTE(review): EditableModel may read further config fields that are not
    # visible from this file - extend as needed.
    config = types.SimpleNamespace(
        edit_lr=0.1,
        no_grad_layers=None,
        model=types.SimpleNamespace(
            inner_params=[
                "transformer.h.9.mlp.c_fc.weight",
                "transformer.h.9.mlp.c_proj.weight",
                "transformer.h.10.mlp.c_fc.weight",
                "transformer.h.10.mlp.c_proj.weight",
                "transformer.h.11.mlp.c_fc.weight",
                "transformer.h.11.mlp.c_proj.weight",
            ],
        ),
        enn=types.SimpleNamespace(n_edit_steps=2, first_order=False),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enn = ENN(model, config, lambda: copy.deepcopy(model)).to(device)

    x = torch.arange(100, device=device).view(5, 20) + 1000
    # edit() forwards the batch as keyword arguments and reads batch["labels"].
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}
    # edit() returns (edited_model, info_dict).
    edited, _ = enn.edit(batch)

    last_name = config.model.inner_params[-1]
    orig_param = [p for (n, p) in enn.model.named_parameters() if n == last_name][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == last_name][0]

    print((orig_param - edited_param).abs().max())

    edited.eval()

    # Call the wrapped networks directly: the wrapper's own forward() lives in
    # EditableModel and its return type is not visible from this file.
    edit_nll = edited.edit_loss_fn(_logits(edited.model(x)), x)["nll"]
    print(enn.model(x, labels=x).loss, edited.model(x, labels=x).loss, edit_nll)

    # The loss function returns a dict; back-propagate through its "nll" entry.
    edit_nll.backward()
if __name__ == '__main__':
    # Run the smoke test with autograd anomaly detection enabled so that a
    # failing backward pass reports the forward operation responsible.
    with torch.autograd.set_detect_anomaly(True):
        test()