| from typing import Any, Dict, List, Tuple | |
| import torch | |
| from copy import deepcopy | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from .GRACE import GRACE | |
| from .grace_hparams import GraceHyperParams | |
| from .utils import tokenize | |
| from ...util import nethook | |
| import gradio as gr | |
def apply_grace_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: GraceHyperParams,
    num_steps: int,
    replacement: str,
    copy=False,
    return_orig_weights=False,
    keep_original_weight=False,
    **kwargs: Any,
) -> GRACE:
    """Run a GRACE edit on ``model`` (on CPU) and return the GRACE editor.

    Args:
        model: Causal LM to edit. Deep-copied first when ``copy`` is true.
        tok: Tokenizer forwarded to ``tokenize`` to encode ``requests``.
        requests: Edit requests passed unchanged to ``tokenize``; the dict
            keys they must contain are defined by ``tokenize`` (not visible
            here — confirm against that helper).
        hparams: GRACE hyperparameters. NOTE: mutated in place — its
            ``n_iter`` and ``replacement`` attributes are overwritten with
            ``num_steps`` and ``replacement`` before the editor is built.
        num_steps: Stored into ``hparams.n_iter``.
        replacement: Stored into ``hparams.replacement``.
        copy: If true, edit a ``deepcopy`` of ``model`` rather than ``model``.
        return_orig_weights: Unused in this implementation.
        keep_original_weight: Unused in this implementation.
        **kwargs: Ignored.

    Returns:
        The ``GRACE`` editor constructed around the (possibly copied) model,
        after ``edit`` has run and the editor has been moved to CPU. This is
        the editor object itself, not a ``(model, weights)`` tuple.

    Side effects:
        Emits a Gradio ``gr.Info`` notification once editing completes.
    """
    if copy:
        model = deepcopy(model)

    # The editor and the tokenized requests are both placed on CPU; the
    # device is hard-coded here rather than taken from the caller.
    device = torch.device('cpu')

    # Push the per-call settings into the shared config object that both the
    # GRACE constructor and ``edit`` receive.
    hparams.n_iter = num_steps
    hparams.replacement = replacement

    editor = GRACE(model=model, config=hparams, device=device)
    tokens = tokenize(requests, tokenizer=tok, device=device)
    editor.edit(config=hparams, tokens=tokens)
    editor.to('cpu')

    gr.Info("Completed editing via GRACE!")
    return editor