| | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel |
| | from transformers import GPT2TokenizerFast, GPT2Tokenizer |
| | from easyeditor import apply_grace_to_model, GraceHyperParams,nethook, apply_wise_to_model, WISEHyperParams, ROMEHyperParams, apply_rome_to_model |
| | import torch |
| | import gradio as gr |
| | import json |
| | import numpy as np |
| | import random |
# Seed Python's, PyTorch's and NumPy's random number generators with one
# shared value before any model is loaded or edited.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)

# Base GPT-2 checkpoint that the single-edit functions below operate on.
model = AutoModelForCausalLM.from_pretrained(
    "./models/gpt2",
    device_map='cpu',
)
| |
|
| |
|
def clear():
    """Reload the base GPT-2 checkpoint into the module-level ``model``.

    Returns:
        A tuple of two empty strings.
    """
    global model
    reloaded = AutoModelForCausalLM.from_pretrained(
        "./models/gpt2",
        device_map='cpu',
    )
    model = reloaded
    return '', ''
| |
|
def grace_edit(prompt, target_new, num_steps, edit_lr):
    """Run ``apply_grace_to_model`` on the module-level ``model``.

    Args:
        prompt: Value stored under the ``"prompt"`` key of the edit request.
        target_new: Value stored under the ``"target_new"`` key of the request.
        num_steps: Passed through unchanged to ``apply_grace_to_model``.
        edit_lr: Passed through unchanged to ``apply_grace_to_model``.

    Returns:
        ``(prompt, target_new)``, echoed back unchanged.

    Side effects:
        Rebinds the module-level ``edit_model`` to whatever
        ``apply_grace_to_model`` returns.
    """
    global edit_model

    edit_request = dict(prompt=prompt, target_new=target_new)
    grace_config = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")

    # A fresh tokenizer is loaded per call; its pad id is set to the EOS id.
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id

    edit_model = apply_grace_to_model(
        model, gpt2_tok, edit_request, grace_config, num_steps, edit_lr
    )
    return prompt, target_new
| |
|
def wise_edit(prompt, target_new, num_steps, edit_lr):
    """Run ``apply_wise_to_model`` on the module-level ``model``.

    Args:
        prompt: Value stored under the ``"prompt"`` key of the edit request.
        target_new: Value stored under the ``"target_new"`` key of the request.
        num_steps: Passed through unchanged to ``apply_wise_to_model``.
        edit_lr: Passed through unchanged to ``apply_wise_to_model``.

    Returns:
        ``(prompt, target_new)``, echoed back unchanged.

    Side effects:
        Rebinds the module-level ``edit_model`` to whatever
        ``apply_wise_to_model`` returns.
    """
    global edit_model

    edit_request = dict(prompt=prompt, target_new=target_new)
    wise_config = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")

    # A fresh tokenizer is loaded per call; its pad id is set to the EOS id.
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id

    edit_model = apply_wise_to_model(
        model, gpt2_tok, edit_request, wise_config, num_steps, edit_lr
    )
    return prompt, target_new
| |
|
def rome_edit(prompt, target_new, num_steps, edit_lr):
    """Run ``apply_rome_to_model`` on the module-level ``model``.

    Args:
        prompt: Value stored under the ``"prompt"`` key of the edit request.
        target_new: Value stored under the ``"target_new"`` key of the request.
        num_steps: Passed through unchanged to ``apply_rome_to_model``.
        edit_lr: Passed through unchanged to ``apply_rome_to_model``.

    Returns:
        ``(prompt, target_new)``, echoed back unchanged.

    Side effects:
        Rebinds the module-level ``edit_model`` to whatever
        ``apply_rome_to_model`` returns.
    """
    global edit_model

    edit_request = dict(prompt=prompt, target_new=target_new)
    rome_config = ROMEHyperParams.from_hparams("./hparams/ROME/gpt2.yaml")

    # A fresh tokenizer is loaded per call; its pad id is set to the EOS id.
    gpt2_tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    gpt2_tok.pad_token_id = gpt2_tok.eos_token_id

    edit_model = apply_rome_to_model(
        model, gpt2_tok, edit_request, rome_config, num_steps, edit_lr
    )
    return prompt, target_new
| |
|
def edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    """Dispatch a single edit to the function matching ``edit_alg``.

    Args:
        edit_alg: One of ``'GRACE'``, ``'WISE'`` or ``'ROME'``.
        prompt: Forwarded to the selected edit function.
        target_new: Forwarded to the selected edit function.
        num_steps: Forwarded to the selected edit function.
        edit_lr: Forwarded to the selected edit function.

    Returns:
        Whatever the selected edit function returns.

    Raises:
        NotImplementedError: If ``edit_alg`` is none of the names above.
    """
    # Pick the handler first so the call itself is written only once.
    if edit_alg == 'GRACE':
        handler = grace_edit
    elif edit_alg == 'WISE':
        handler = wise_edit
    elif edit_alg == 'ROME':
        handler = rome_edit
    else:
        raise NotImplementedError
    return handler(prompt, target_new, num_steps, edit_lr)
| |
|
def generate(input_text, target_new=None, edit_alg=None):
    """Produce per-character tagged replies from the original and edited models.

    Two modes are used:

    * ``edit_alg == 'GRACE'`` with a ``target_new``: both models free-run
      ``generate`` for as many new tokens as ``' ' + target_new`` tokenizes to.
    * Otherwise: both models do one forward pass over
      ``input_text + ' ' + target_new`` and the arg-max predictions for the
      target span are decoded. When ``target_new`` is ``None`` it is looked up
      in the built-in table of reference answers keyed by ``input_text``.

    Args:
        input_text: Prompt text.
        target_new: Expected answer text, or ``None`` to use the built-in table.
        edit_alg: Name of the editing algorithm; only ``'GRACE'`` changes the
            decoding mode.

    Returns:
        ``(ori_reply, edit_reply)``. Each is a list of ``(character, label)``
        tuples where ``label`` is ``'output'`` for characters whose index is
        greater than ``len(input_text)`` and ``None`` otherwise.

    Raises:
        KeyError: If ``target_new`` is ``None`` and ``input_text`` is not one
            of the built-in reference prompts.
        NameError: If the module-level ``edit_model`` has not been assigned
            yet (in this file it is only set by the ``*_edit`` functions).
    """
    global edit_model

    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    # Character offset of the prompt/answer boundary. Both decoding modes tag
    # characters strictly after this index as model output.
    prompt_chars = len(input_text)

    def tag_chars(text):
        """Label each character of ``text`` as ``'output'`` or ``None``."""
        return [
            (ch, 'output' if idx > prompt_chars else None)
            for idx, ch in enumerate(text)
        ]

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')

        edit_output = edit_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()

        # A pristine checkpoint is loaded for the side-by-side comparison.
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
        )
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        return tag_chars(ori_reply), tag_chars(edit_reply)

    if target_new is None:
        target_new = loc_output[input_text]
    input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')

    # Token count of the prompt alone. Predictions for the target tokens are
    # read from the positions just before each of them, i.e. shifted left by
    # one, hence the [prompt_tokens - 1 : -1] window.
    prompt_tokens = len(tok.encode(input_text))
    target_window = slice(prompt_tokens - 1, -1)

    edit_pred = torch.argmax(edit_model(input_ids=input_ids).logits, dim=-1)
    edit_reply = input_text + ' ' + tok.decode(edit_pred[0][target_window], skip_special_tokens=True)
    torch.cuda.empty_cache()

    # A pristine checkpoint is loaded for the side-by-side comparison.
    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
    ori_pred = torch.argmax(ori_model(input_ids=input_ids).logits, dim=-1)
    ori_reply = input_text + ' ' + tok.decode(ori_pred[0][target_window], skip_special_tokens=True)
    torch.cuda.empty_cache()

    return tag_chars(ori_reply), tag_chars(edit_reply)
| |
|
def union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    """Run ``generate`` on a prompt and then on its paraphrase.

    Args:
        input_text: First prompt passed to ``generate``.
        para_input_text: Second prompt passed to ``generate``.
        target_new: Forwarded to both ``generate`` calls.
        edit_alg: Forwarded to both ``generate`` calls.

    Returns:
        A 4-tuple: the two results for ``input_text`` followed by the two
        results for ``para_input_text``.
    """
    shared_kwargs = {"target_new": target_new, "edit_alg": edit_alg}
    first_ori, first_edit = generate(input_text, **shared_kwargs)
    para_ori, para_edit = generate(para_input_text, **shared_kwargs)
    return first_ori, first_edit, para_ori, para_edit
| |
|
| | |
| | |
| | |
| |
|
# [prompt, target_new] pairs applied to both continuous-editing models when
# this module is loaded.
continuous_examples = [
    ["Who is the architect for Toodyay Fire Station?", "Wong Tung & Sons"],
    ["What company makes Springfield Armory XDM?", "Messerschmitt"],
    ["Which fictional universe is Chlorophyll Kid part of?", "Image Universe"],
    ["What year did Sunnyside Hospital cease to exist?", "1962"],
    ["Which designer was responsible for Holmenkollen Chapel?", "Inigo Jones"],
    ["What piece of fiction does Jack Harkness appear in?", "Lost"],
]

# Hyper-parameters, tokenizer and models shared by continuous_edit /
# continuous_generate. These are ordinary module-level names.
grace_hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
wise_hparams = WISEHyperParams.from_hparams("./hparams/WISE/gpt2.yaml")

tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

grace_continuous_model = AutoModelForCausalLM.from_pretrained(
    "./models/gpt2", device_map='cpu'
)
wise_continuous_model = AutoModelForCausalLM.from_pretrained(
    "./models/gpt2", device_map='cpu'
)

# Apply every example to the GRACE model first, then to the WISE model.
# The two passes are kept separate and in this order. The trailing 40 and 1.0
# sit in the positional slots that carry num_steps and edit_lr elsewhere in
# this file.
for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_grace_to_model(
        grace_continuous_model, tokenizer, request, grace_hparams, 40, 1.0
    )

for prompt, target_new in continuous_examples:
    request = {"prompt": prompt, "target_new": target_new}
    apply_wise_to_model(
        wise_continuous_model, tokenizer, request, wise_hparams, 40, 1.0
    )
| |
|
def continuous_edit(edit_alg, prompt, target_new, num_steps, edit_lr):
    """Apply one more edit to the long-lived GRACE or WISE model.

    Args:
        edit_alg: ``'GRACE'`` or ``'WISE'``; selects which module-level model,
            hyper-parameters and apply function are used.
        prompt: Value stored under the ``"prompt"`` key of the edit request.
        target_new: Value stored under the ``"target_new"`` key of the request.
        num_steps: Passed through unchanged to the apply function.
        edit_lr: Passed through unchanged to the apply function.

    Returns:
        ``(prompt, target_new)``, echoed back unchanged.

    Raises:
        NotImplementedError: If ``edit_alg`` is neither ``'GRACE'`` nor
            ``'WISE'``.
    """
    # Resolve the algorithm-specific pieces first; nothing is built or called
    # for an unsupported algorithm.
    if edit_alg == 'GRACE':
        apply_fn = apply_grace_to_model
        target_model = grace_continuous_model
        algo_hparams = grace_hparams
    elif edit_alg == 'WISE':
        apply_fn = apply_wise_to_model
        target_model = wise_continuous_model
        algo_hparams = wise_hparams
    else:
        raise NotImplementedError

    edit_request = {"prompt": prompt, "target_new": target_new}
    # The return value of the apply function is discarded, as in the
    # module-level bootstrap loops.
    apply_fn(target_model, tokenizer, edit_request, algo_hparams, num_steps, edit_lr)
    return prompt, target_new
| |
|
def continuous_generate(input_text, edit_alg=None, target_new=None):
    """Produce per-character tagged replies from the original and a continuous model.

    The edited model is the module-level ``grace_continuous_model`` or
    ``wise_continuous_model``, chosen by ``edit_alg``. Two decoding modes are
    used:

    * ``edit_alg == 'GRACE'`` with a ``target_new``: both models free-run
      ``generate`` for as many new tokens as ``' ' + target_new`` tokenizes to.
    * Otherwise: both models do one forward pass over
      ``input_text + ' ' + target_new`` and the arg-max predictions for the
      target span are decoded. When ``target_new`` is ``None`` it is looked up
      in the built-in table of reference answers keyed by ``input_text``.

    Args:
        input_text: Prompt text.
        edit_alg: ``'GRACE'`` or ``'WISE'``.
        target_new: Expected answer text, or ``None`` to use the built-in table.

    Returns:
        ``(ori_reply, edit_reply)``. Each is a list of ``(character, label)``
        tuples where ``label`` is ``'output'`` for characters whose index is
        greater than ``len(input_text)`` and ``None`` otherwise.

    Raises:
        NotImplementedError: If ``edit_alg`` is neither ``'GRACE'`` nor
            ``'WISE'``.
        KeyError: If ``target_new`` is ``None`` and ``input_text`` is not one
            of the built-in reference prompts.
    """
    # Reject unsupported algorithms before any tokenizer or model is loaded.
    if edit_alg == 'GRACE':
        cur_model = grace_continuous_model
    elif edit_alg == 'WISE':
        cur_model = wise_continuous_model
    else:
        raise NotImplementedError

    loc_output = {
        "nq question: where does the phrase good bye felicia come from": "intended as a dismissive kiss-off",
        "nq question: which best describes timbuktu under the mali empire": "a place of trade, entertainment, and education",
        "nq question: where do the question marks go in spanish": "before the first letter of an interrogative sentence",
        "nq question: who replaces the vice president in the senate": "Speaker of the House of Representatives",
        "nq question: active transport performs which function in a cell": "uses cellular energy to move them against a gradient, polar repulsion, or other resistance"
    }
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    # Character offset of the prompt/answer boundary. Both decoding modes tag
    # characters strictly after this index as model output.
    prompt_chars = len(input_text)

    def tag_chars(text):
        """Label each character of ``text`` as ``'output'`` or ``None``."""
        return [
            (ch, 'output' if idx > prompt_chars else None)
            for idx, ch in enumerate(text)
        ]

    if edit_alg == 'GRACE' and target_new is not None:
        max_new_tokens = len(tok.encode(' ' + target_new))
        input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')

        edit_output = cur_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
            use_cache=False,
        )
        edit_reply = tok.decode(edit_output[0], skip_special_tokens=False)
        torch.cuda.empty_cache()

        # A pristine checkpoint is loaded for the side-by-side comparison.
        ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
        ori_output = ori_model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.eos_token_id,
        )
        ori_reply = tok.decode(ori_output[0], skip_special_tokens=False)
        return tag_chars(ori_reply), tag_chars(edit_reply)

    if target_new is None:
        target_new = loc_output[input_text]
    input_ids = tok.encode(input_text + ' ' + target_new, return_tensors='pt').to('cpu')

    # Token count of the prompt alone. Predictions for the target tokens are
    # read from the positions just before each of them, i.e. shifted left by
    # one, hence the [prompt_tokens - 1 : -1] window.
    prompt_tokens = len(tok.encode(input_text))
    target_window = slice(prompt_tokens - 1, -1)

    edit_pred = torch.argmax(cur_model(input_ids=input_ids).logits, dim=-1)
    edit_reply = input_text + ' ' + tok.decode(edit_pred[0][target_window], skip_special_tokens=True)
    torch.cuda.empty_cache()

    # A pristine checkpoint is loaded for the side-by-side comparison.
    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
    ori_pred = torch.argmax(ori_model(input_ids=input_ids).logits, dim=-1)
    ori_reply = input_text + ' ' + tok.decode(ori_pred[0][target_window], skip_special_tokens=True)
    torch.cuda.empty_cache()

    return tag_chars(ori_reply), tag_chars(edit_reply)
| |
|
def continuous_union_generate(input_text, para_input_text, target_new=None, edit_alg=None):
    """Run ``continuous_generate`` on a prompt and then on its paraphrase.

    Args:
        input_text: First prompt passed to ``continuous_generate``.
        para_input_text: Second prompt passed to ``continuous_generate``.
        target_new: Forwarded by keyword to both calls.
        edit_alg: Forwarded by keyword to both calls.

    Returns:
        A 4-tuple: the two results for ``input_text`` followed by the two
        results for ``para_input_text``.
    """
    shared_kwargs = {"target_new": target_new, "edit_alg": edit_alg}
    first_ori, first_edit = continuous_generate(input_text, **shared_kwargs)
    para_ori, para_edit = continuous_generate(para_input_text, **shared_kwargs)
    return first_ori, first_edit, para_ori, para_edit