# alpaca-self-instruct / selfinstruct / simple_instance.py
# (uploaded by Miranda2023, commit 82bc1c9, "app create")
import os
import json
import random
import tqdm
import re
import argparse
import pandas as pd
from collections import OrderedDict
from gpt3_api import make_requests as make_gpt3_requests
from templates.instance_gen_template import output_first_template_for_clf, input_first_template_for_gen
random.seed(42)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch_dir",
type=str,
required=True,
help="The directory where the batch is stored.",
)
parser.add_argument(
"--input_file",
type=str,
default="machine_generated_instructions.jsonl"
)
parser.add_argument(
"--output_file",
type=str,
default="machine_generated_instances.jsonl",
)
parser.add_argument(
"--num_instructions",
type=int,
help="if specified, only generate instance input for this many instructions",
)
parser.add_argument(
"--max_instances_to_generate",
type=int,
default=5,
help="The max number of instances to generate for each instruction.",
)
parser.add_argument(
"--generation_tasks_only",
action="store_true",
help="If specified, only do for generation tasks.",
)
parser.add_argument(
"--classification_tasks_only",
action="store_true",
help="If specified, only do for classification tasks.",
)
parser.add_argument(
"--engine",
type=str,
default="davinci",
help="The engine to use."
)
parser.add_argument(
"--request_batch_size",
type=int,
default=5,
help="The number of requests to send in a batch."
)
parser.add_argument(
"--api_key",
type=str,
help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY."
)
parser.add_argument(
"--organization",
type=str,
help="The organization to use. If not specified, the default organization id will be used."
)
return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Load the machine-generated instructions from the previous pipeline stage,
    # optionally truncated to the first --num_instructions entries.
    with open(os.path.join(args.batch_dir, args.input_file)) as fin:
        lines = fin.readlines()
        if args.num_instructions is not None:
            lines = lines[:args.num_instructions]
        tasks = []
        for line in lines:
            data = json.loads(line)
            # Rename "metadata" -> "instruction_metadata" so it cannot collide
            # with the per-instance "instance_metadata" added below.
            if "metadata" in data:
                data["instruction_metadata"] = data["metadata"]
                del data["metadata"]
            tasks.append(data)

    # Map instruction text -> True if the clf/not-clf step labeled it "Yes"
    # (a classification task). NOTE(review): the filename hard-codes the
    # davinci template-1 run — confirm it matches the upstream step's output.
    task_clf_types = {}
    with open(os.path.join(args.batch_dir, "is_clf_or_not_davinci_template_1.jsonl")) as fin:
        for line in fin:
            data = json.loads(line)
            task_clf_types[data["instruction"]] = data["is_classification"].strip() in ["Yes", "yes", "YES"]

    if args.classification_tasks_only:
        tasks = [task for task in tasks if task_clf_types[task["instruction"]]]
    if args.generation_tasks_only:
        tasks = [task for task in tasks if not task_clf_types[task["instruction"]]]

    # Reload any previously generated instances so that resuming a partial run
    # does not re-query the API for instructions already done.
    output_path = os.path.join(args.batch_dir, args.output_file)
    existing_requests = {}
    if os.path.exists(output_path):
        with open(output_path) as fin:
            for line in tqdm.tqdm(fin):
                try:
                    data = json.loads(line)
                    existing_requests[data["instruction"]] = data
                except (json.JSONDecodeError, KeyError):
                    # Skip malformed/partial lines (e.g. from an interrupted
                    # earlier run) instead of silently swallowing everything.
                    pass
        print(f"Loaded {len(existing_requests)} existing requests")

    # Keys (and their order) kept in every output line.
    output_keys = ["instruction", "raw_instances", "instance_metadata"]
    # previously also: "instruction_metadata", "most_similar", "avg_similarity_score"

    progress_bar = tqdm.tqdm(total=len(tasks))
    with open(output_path, "w") as fout:
        for batch_idx in range(0, len(tasks), args.request_batch_size):
            batch = tasks[batch_idx: batch_idx + args.request_batch_size]
            if all(d["instruction"] in existing_requests for d in batch):
                # Whole batch already generated: copy the cached results through.
                for d in batch:
                    data = existing_requests[d["instruction"]]
                    data = OrderedDict((k, data[k]) for k in output_keys)
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            else:
                # Classification tasks use the output-first template; generation
                # tasks use the input-first template.
                prompts = []
                for task in batch:
                    if task_clf_types[task["instruction"]]:
                        prompt = output_first_template_for_clf + " " + task["instruction"].strip() + "\n"
                        prompts.append(prompt)
                    else:
                        prompt = input_first_template_for_gen + " " + task["instruction"].strip() + "\n"
                        prompts.append(prompt)
                results = make_gpt3_requests(
                    engine=args.engine,
                    prompts=prompts,
                    # because the clf template is longer, we need to decrease the max_tokens
                    max_tokens=300 if any(task_clf_types[task["instruction"]] for task in batch) else 350,
                    temperature=0,
                    top_p=0,
                    frequency_penalty=0,
                    presence_penalty=1.5,
                    # Stop once one more than the requested number of examples
                    # would start, or a new task begins.
                    stop_sequences=[f"Example {args.max_instances_to_generate + 1}", "Task:"],
                    logprobs=1,
                    n=1,
                    best_of=1,
                    api_key=args.api_key,
                    organization=args.organization)
                for i in range(len(batch)):
                    data = batch[i]
                    data["instance_metadata"] = results[i]
                    if results[i]["response"] is not None:
                        data["raw_instances"] = results[i]["response"]["choices"][0]["text"]
                    else:
                        # Failed request: keep the metadata but emit no instances.
                        data["raw_instances"] = ""
                    data = OrderedDict((k, data[k]) for k in output_keys)
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            progress_bar.update(len(batch))