MegaBugFix
Collection
MegaBugFix is a large-scale bugfixing benchmark containing 12,629 buggy Python programs synthesized from correct ones by a Large Language Model. • 2 items • Updated
How to use szalontaib/MegaBugInject with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM
# Load the base checkpoint by its Hugging Face Hub id (the base model this
# adapter is listed against); a machine-local directory path would not
# resolve for anyone else running the snippet.
base_model = AutoModelForCausalLM.from_pretrained("WizardLMTeam/WizardCoder-Python-13B-V1.0")
model = PeftModel.from_pretrained(base_model, "szalontaib/MegaBugInject")

This is a model capable of injecting bugs into correct Python programs. It was used to inject bugs into correct programs to form the core of the MegaBugFix benchmark.
You may use the model in the following way:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import re
# Hub identifiers: the pretrained base checkpoint and the fine-tuned PEFT
# adapter that is applied on top of it.
model_id_pretrained = 'WizardLMTeam/WizardCoder-Python-13B-V1.0'
model_id_finetuned = 'szalontaib/MegaBugInject'

# The tokenizer comes from the base checkpoint.
# NOTE(review): add_eos_token=False presumably keeps the prompt open-ended
# for generation - confirm against the tokenizer documentation.
tokenizer = AutoTokenizer.from_pretrained(
    model_id_pretrained,
    add_eos_token=False,
)

# Base model in half precision, placed on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_id_pretrained,
    device_map='auto',
    dtype=torch.float16,
    trust_remote_code=True,
)

# Rebind `model` to the base model wrapped with the fine-tuned adapter.
model = PeftModel.from_pretrained(model, model_id_finetuned)
def extract_diff(model_output):
    """Return the body of the first ``[DIFF]...[/DIFF]`` section.

    Args:
        model_output: Decoded text to search.

    Returns:
        The text between the first pair of tags with leading and trailing
        newline characters removed (other whitespace, such as the leading
        spaces of the first diff line, is preserved), or ``None`` when the
        text contains no complete tagged section.
    """
    # Only the first section is ever used, so a single search is enough.
    # DOTALL lets the non-greedy group span several lines.
    match = re.search(r'\s*\[DIFF\](.*?)\[/DIFF\]\s*', model_output, re.DOTALL)
    if match is None:
        return None
    return match.group(1).strip('\n')
def diff2code(diff: str) -> str:
    """Rebuild a program from a line diff by discarding its removed lines.

    Every line that does not start with ``-`` is kept, and its first two
    characters (the marker column plus the character after it) are cut off.

    Args:
        diff: Diff text, one entry per line.

    Returns:
        The kept lines joined with newlines, with leading and trailing
        whitespace of the whole result stripped.
    """
    kept_lines = [
        line[2:]
        for line in diff.splitlines()
        if not line.startswith('-')
    ]
    return '\n'.join(kept_lines).strip()
def corrupt(program, model, tokenizer, **generation_kwargs):
    """Generate corrupted variants of *program* with the given model.

    The program is wrapped in ``[PYTHON]`` tags followed by an opening
    ``[DIFF]`` tag, the model continues the prompt, and each decoded output
    is converted back into a program via ``extract_diff`` and ``diff2code``.

    Args:
        program: Source text of the program; surrounding whitespace is
            stripped before it is placed in the prompt.
        model: Object providing ``generate`` and a ``device`` attribute.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        **generation_kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        A list of program strings. Outputs in which no complete
        ``[DIFF]...[/DIFF]`` section is found are skipped, so the list may
        be shorter than the number of generated sequences.
    """
    prompt = f'[PYTHON]\n{program.strip()}\n[/PYTHON]\n[DIFF]\n'
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, **generation_kwargs)
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    diffs = (extract_diff(output) for output in outputs)
    return [diff2code(diff) for diff in diffs if diff is not None]
# Sample input: a small program to be corrupted. The indentation inside the
# literal matters - it is the program text that is placed in the prompt.
test_code = '''
def bitcount(n):
    count = 0
    while n:
        n &= n - 1
        count += 1
    return count
'''.strip()

# Request five sampled generations; every keyword argument after the
# tokenizer is forwarded to the model's generate call.
corrupted_programs = corrupt(
    test_code, model, tokenizer,
    do_sample=True,
    temperature=0.5,
    max_new_tokens=4096,
    num_return_sequences=5,
)

# Print each returned program preceded by a dashed separator line.
for corrupted_program in corrupted_programs:
    print('-'*30)
    print(corrupted_program)
Base model
WizardLMTeam/WizardCoder-Python-13B-V1.0