# alpaca-self-instruct / self_instruct / prepare_for_finetuning.py
# (uploaded by Miranda2023, commit 82bc1c9 "app create"; page-scrape header
# converted to a comment so the file is valid Python)
import os
import json
import argparse
import glob
import re
import random
import tqdm
import pandas as pd
# Fix the RNG so template choice, per-task sampling, and the final shuffle
# below are reproducible across runs.
random.seed(123)
def parse_args():
    """Parse command-line options for preparing GPT-3 finetuning data.

    Returns:
        argparse.Namespace: parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--instance_files",
        nargs="+",
        default=["data/batch_221203/machine_generated_instances.jsonl"],
        type=str,
        help="The input files that contain the machine generated instances."
    )
    parser.add_argument(
        "--classification_type_files",
        nargs="+",
        default=["data/batch_221203/is_clf_or_not_davinci_template_1.jsonl"],
        help="Files labeling each instruction as classification or not."
    )
    parser.add_argument(
        "--output_dir",
        default="data/gpt3_generations/batch_221203/finetuning/",
        type=str,
        help="The output dir to save the cleaned version of the generated instances, so that it can be used for GPT3 finetuning."
    )
    parser.add_argument(
        "--num_instructions",
        type=int,
        help="The number of instructions to load."
    )
    parser.add_argument(
        "--include_seed_tasks",
        action="store_true",
        help="Whether to include the seed human-written instances in the finetuning data."
    )
    parser.add_argument(
        "--seed_tasks_path",
        type=str,
        # BUGFIX: the original declared both required=True and a default, which
        # is contradictory in argparse (a required option's default is never
        # used). Keeping only the default makes the flag optional and is
        # backward-compatible with every existing invocation.
        default="data/seed_tasks.jsonl",
        help="The path to the seed data.",
    )
    return parser.parse_args()
def encode_instance(instruction, input, output, random_template=True):
    """Render an (instruction, input, output) triple as a prompt/completion pair.

    Args:
        instruction: the task instruction text.
        input: the instance input; may be empty/whitespace-only.
        output: the reference output text.
        random_template: when True, pick one of several encoding templates at
            random (a different set is used depending on whether the input is
            empty); when False, use the fixed "instruction\\n\\ninput\\n\\n" form.

    Returns:
        dict with "prompt", "completion" and the stripped raw fields.
    """
    templates_with_input = [
        ("{instruction}\nInput: {input}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\nInput: {input}\n\nOutput:", " {output}<|endoftext|>"),
        ("Task: {instruction}\nInput: {input}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\n{input}\n\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\n{input}\n\n", "{output}<|endoftext|>"),
        ("{instruction}\n{input}\n\n", "{output}<|endoftext|>"),
        ("Task: {instruction}\n\n{input}\n\n", "{output}<|endoftext|>"),
    ]
    templates_without_input = [
        ("{instruction} Output:", " {output}<|endoftext|>"),
        ("{instruction}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n", "{output}<|endoftext|>"),
        ("{instruction}\n\n", "{output}<|endoftext|>"),
        ("Task: {instruction}\n\n", "{output}<|endoftext|>"),
    ]
    instruction_text = instruction.strip()
    input_text = input.strip()
    output_text = output.strip()
    if not random_template:
        # Fixed encoding: instruction and input joined by blank lines.
        prompt = instruction_text + "\n\n" + input_text + "\n\n"
        completion = output_text + "<|endoftext|>"
    elif input_text:
        prompt_tmpl, completion_tmpl = random.choice(templates_with_input)
        prompt = prompt_tmpl.format(instruction=instruction_text, input=input_text)
        completion = completion_tmpl.format(output=output_text)
    else:
        prompt_tmpl, completion_tmpl = random.choice(templates_without_input)
        prompt = prompt_tmpl.format(instruction=instruction_text)
        completion = completion_tmpl.format(output=output_text)
    return {
        "prompt": prompt,
        "completion": completion,
        "instruction": instruction_text,
        "input": input_text,
        "output": output_text,
    }
def parse_input_output(response_text):
    """Split a raw generated instance into an (input, output) pair.

    Generated text usually looks like "Input: ...\\nOutput: ...". When no
    "Output:" marker is present, the entire text is treated as the output and
    the input is empty. Only the first input/output pair is kept.
    """
    output_marker = r"Output\s*\d*\s*:"
    input_marker = r"Input\s*\d*\s*:"
    if re.search(output_marker, response_text):
        pieces = re.split(output_marker, response_text)
        inst_input = pieces[0].strip()
        inst_output = pieces[1].strip()
    else:
        inst_input = ""
        inst_output = response_text.strip()
    # The model sometimes emits several pairs; truncate the output at the
    # next "Input:" marker so only the first pair survives.
    if re.search(input_marker, inst_output):
        inst_output = re.split(input_marker, inst_output)[0].strip()
    # Strip the leading "Input:" field name from the input text.
    inst_input = re.sub(r"^" + input_marker, "", inst_input).strip()
    return inst_input, inst_output
def filter_duplicate_instances(instances):
    """Deduplicate instances, rejecting the whole set on conflicting outputs.

    If two instances share the same non-empty input but disagree on the
    output, the generations are considered unreliable and an empty list is
    returned. Otherwise exact duplicates are removed.

    Args:
        instances: list of (instruction, input, output) tuples.

    Returns:
        list of unique tuples, or [] when a same-input/different-output
        conflict exists. NOTE: set() does not preserve order; downstream code
        samples randomly from the result, so order does not matter here.
    """
    for i in range(1, len(instances)):
        for j in range(i):
            # Conflicts are only meaningful for non-empty inputs.
            if instances[i][1] == "":
                continue
            if instances[i][1] == instances[j][1] and instances[i][2] != instances[j][2]:
                # BUGFIX(perf): the original set a flag and `break`-ed, which
                # only exited the inner loop and kept scanning the rest of the
                # pairs needlessly; return immediately instead.
                return []
    # Remove exact duplicates.
    return list(set(instances))
def filter_invalid_instances(instances):
    """Drop instances that are clearly unusable.

    An instance (instruction, input, output) is rejected when the input equals
    the output, when the output is empty, or when either field ends with a
    colon — the latter usually indicates an incomplete generation.
    """
    def _is_valid(instance):
        _, inst_input, inst_output = instance
        if inst_input == inst_output:
            return False
        if inst_output == "":
            return False
        if inst_input.strip().endswith(":") or inst_output.strip().endswith(":"):
            return False
        return True

    return [instance for instance in instances if _is_valid(instance)]
def parse_instances_for_generation_task(raw_text, instruction, response_metadata):
    """Parse raw completion text of a generation task into instance tuples.

    Args:
        raw_text: the raw model completion containing the instances.
        instruction: the task instruction the instances belong to.
        response_metadata: API response record; its
            response["choices"][0]["finish_reason"] is inspected to detect
            truncated generations.

    Returns:
        list of (instruction, input, output) tuples, validated and
        deduplicated; [] when nothing parseable is found.
    """
    instances = []
    raw_text = raw_text.strip()
    # BUGFIX: the "Example" pattern was a non-raw string, so "\s" / "\d" were
    # invalid escape sequences (SyntaxWarning on modern Python); use raw
    # strings for both pattern occurrences.
    if re.findall(r"Example\s?\d*\.?", raw_text):
        instance_texts = re.split(r"Example\s?\d*\.?", raw_text)
        instance_texts = [it.strip() for it in instance_texts if it.strip() != ""]
        for instance_text in instance_texts:
            inst_input, inst_output = parse_input_output(instance_text)
            instances.append((instruction.strip(), inst_input.strip(), inst_output.strip()))
    elif re.findall(r"Output\s*\d*\s*:", raw_text):
        # we assume only one input/output pair in this case
        inst_input, inst_output = parse_input_output(raw_text)
        instances.append((instruction.strip(), inst_input.strip(), inst_output.strip()))
    else:
        return []
    # if the generation stops because of length, we remove the last (likely
    # truncated) instance
    if response_metadata["response"]["choices"][0]["finish_reason"] == "length":
        instances = instances[:-1]
    instances = filter_invalid_instances(instances)
    instances = filter_duplicate_instances(instances)
    return instances
def parse_instances_for_classification_task(raw_text, instruction, response_metadata):
    """Parse raw completion text of a classification task into instance tuples.

    The raw text is expected to contain "Class label:" markers; each segment
    after a marker holds the label on its first line and the input text on the
    remaining lines.

    Args:
        raw_text: the raw model completion containing the instances.
        instruction: the task instruction the instances belong to.
        response_metadata: API response record; its
            response["choices"][0]["finish_reason"] is inspected to detect
            truncated generations.

    Returns:
        list of (instruction, input, class_label) tuples, validated and
        deduplicated; [] when no "Class label:" marker is present.
    """
    instances = []
    # Idiom fix: `X not in Y` instead of `not X in Y`.
    if "Class label:" not in raw_text:
        return []
    instance_texts = raw_text.split("Class label:")[1:]
    for instance_text in instance_texts:
        instance_text = instance_text.strip()
        fields = instance_text.split("\n", 1)
        if len(fields) == 2:
            # the first field split by \n is the class label
            class_label = fields[0].strip()
            # the rest is the input
            input_text = fields[1].strip()
        elif len(fields) == 1:
            # no newline: the whole segment is the class label, with no input
            # (the original comment wrongly called this field "the input")
            class_label = fields[0].strip()
            input_text = ""
        else:
            # unreachable for str.split("\n", 1), kept as a defensive guard
            raise ValueError("Invalid instance text: {}".format(instance_text))
        instances.append((instruction.strip(), input_text.strip(), class_label.strip()))
    # if the generation stops because of length, we remove the last (likely
    # truncated) instance
    if response_metadata["response"]["choices"][0]["finish_reason"] == "length":
        instances = instances[:-1]
    instances = filter_invalid_instances(instances)
    instances = filter_duplicate_instances(instances)
    return instances
if __name__ == "__main__":
    args = parse_args()
    training_instances = []

    # Load every machine-generated task record (one JSON object per line).
    generated_tasks = []
    for instance_file in args.instance_files:
        with open(instance_file) as fin:
            for line in fin:
                generated_tasks.append(json.loads(line))
    print(f"Loaded {len(generated_tasks)} raw generated tasks")

    # Map instruction -> bool (True when labeled a classification task).
    task_clf_types = {}
    for file in args.classification_type_files:
        with open(file) as fin:
            for line in fin:
                data = json.loads(line)
                task_clf_types[data["instruction"]] = data["is_classification"].strip() in ["Yes", "yes", "YES"]

    # Parse the raw instance text of each task into (instruction, input, output).
    for task in tqdm.tqdm(generated_tasks):
        instruction = task["instruction"]
        task["is_classification"] = task_clf_types[instruction]
        if task["is_classification"]:
            task_instances = parse_instances_for_classification_task(task["raw_instances"], instruction, task["instance_metadata"])
        else:
            task_instances = parse_instances_for_generation_task(task["raw_instances"], instruction, task["instance_metadata"])
        # we only allow max 5 instances per task
        task_instances = random.sample(task_instances, min(len(task_instances), 5))
        if not task_instances:
            continue
        training_instances += task_instances

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "all_generated_instances.jsonl"), "w") as fout:
        for instance in training_instances:
            fout.write(json.dumps({
                "instruction": instance[0],
                "input": instance[1],
                "output": instance[2],
            }) + "\n")
    print(f"Saved {len(training_instances)} instances")

    unique_instructions = set([it[0] for it in training_instances])
    print(f"Unique instructions: {len(unique_instructions)}")
    clf_instructions = [instruction for instruction in unique_instructions if task_clf_types[instruction]]
    print(f"Classification instructions: {len(clf_instructions)}")
    non_clf_instructions = [instruction for instruction in unique_instructions if not task_clf_types[instruction]]
    print(f"Non-classification instructions: {len(non_clf_instructions)}")

    if args.num_instructions is not None:
        print(f"Sampling {args.num_instructions} instructions")
        # BUGFIX: random.sample() requires a sequence; passing a set raises
        # TypeError on Python 3.11+ (deprecated since 3.9). Sort for a
        # deterministic population order, and keep the sample as a set for
        # O(1) membership tests in the filter below.
        sampled_instructions = set(random.sample(sorted(unique_instructions), args.num_instructions))
        training_instances = [it for it in training_instances if it[0] in sampled_instructions]
        print(f"Only using {len(training_instances)} instances for these sampled instructions.")
        with open(os.path.join(args.output_dir, f"sampled_generated_instances_{args.num_instructions}.jsonl"), "w") as fout:
            for instance in training_instances:
                fout.write(json.dumps({
                    "instruction": instance[0],
                    "input": instance[1],
                    "output": instance[2],
                }) + "\n")

    if args.include_seed_tasks:
        # BUGFIX: close the seed-task file instead of leaking the handle.
        with open(args.seed_tasks_path, "r") as fin:
            seed_tasks = [json.loads(l) for l in fin]
        for task in seed_tasks:
            for instance in task["instances"]:
                training_instances.append((task["instruction"], instance["input"], instance["output"]))
        print(f"Included {len(seed_tasks)} seed tasks")

    # get the prompt and completion for training gpt3
    gpt3_instances = []
    for instance in training_instances:
        # get input and do preprocessing
        inst_input = instance[1]
        # for some tasks, we check whether the input contains colon, and if so, we remove the part before the colon
        if random.random() < 0.5:
            colon_words = re.findall(r"(\w+):", inst_input)
            # if only one colon is found, we assume the instance only have one input and we remove the field name before the colon
            if len(set(colon_words)) == 1:
                inst_input = inst_input.split(":", 1)[1].strip()
            else:
                inst_input = inst_input.strip()
        # NOTE(review): the original comment claimed this replacement happens
        # "half of the time", but it is unconditional; behavior preserved as-is.
        inst_input = inst_input.replace("\n\n", "\n")
        gpt3_instances.append(encode_instance(instance[0], inst_input, instance[2]))

    # remove duplicate prompt/completion pairs, keeping the first occurrence
    filtered_instances = []
    prompt_completion_set = set()
    for instance in gpt3_instances:
        instance_pair = (instance["prompt"], instance["completion"])
        if instance_pair not in prompt_completion_set:
            prompt_completion_set.add(instance_pair)
            filtered_instances.append(instance)
    gpt3_instances = filtered_instances

    # shuffle so finetuning batches are not ordered by task
    random.shuffle(gpt3_instances)
    with open(os.path.join(args.output_dir, f"gpt3_finetuning_data_{len(gpt3_instances)}.jsonl"), "w") as fout:
        for instance in gpt3_instances:
            fout.write(json.dumps({
                "prompt": instance["prompt"],
                "completion": instance["completion"],
            }) + "\n")