Spaces:
Runtime error
Runtime error
File size: 6,558 Bytes
82bc1c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os
import json
import random
import tqdm
import re
import argparse
import pandas as pd
from collections import OrderedDict
from gpt3_api import make_requests as make_gpt3_requests
from templates.instance_gen_template import output_first_template_for_clf, input_first_template_for_gen
# Fix the global RNG seed so any randomized behavior in this pipeline run is
# reproducible. (random is not used directly in this file's visible code;
# presumably imported helpers rely on the global RNG — TODO confirm.)
random.seed(42)
def parse_args():
    """Build and parse the command-line arguments for instance generation.

    Returns:
        argparse.Namespace with the batch directory, input/output file names,
        task filtering flags, and GPT-3 request settings.
    """
    arg_parser = argparse.ArgumentParser()
    # Location of the pipeline's working directory (required).
    arg_parser.add_argument(
        "--batch_dir",
        type=str,
        required=True,
        help="The directory where the batch is stored.",
    )
    # Input/output file names, resolved relative to --batch_dir.
    arg_parser.add_argument(
        "--input_file",
        type=str,
        default="machine_generated_instructions.jsonl",
    )
    arg_parser.add_argument(
        "--output_file",
        type=str,
        default="machine_generated_instances.jsonl",
    )
    arg_parser.add_argument(
        "--num_instructions",
        type=int,
        help="if specified, only generate instance input for this many instructions",
    )
    arg_parser.add_argument(
        "--max_instances_to_generate",
        type=int,
        default=5,
        help="The max number of instances to generate for each instruction.",
    )
    # Mutually independent task-type filters (callers may pass either flag).
    arg_parser.add_argument(
        "--generation_tasks_only",
        action="store_true",
        help="If specified, only do for generation tasks.",
    )
    arg_parser.add_argument(
        "--classification_tasks_only",
        action="store_true",
        help="If specified, only do for classification tasks.",
    )
    # GPT-3 request configuration.
    arg_parser.add_argument(
        "--engine",
        type=str,
        default="davinci",
        help="The engine to use.",
    )
    arg_parser.add_argument(
        "--request_batch_size",
        type=int,
        default=5,
        help="The number of requests to send in a batch.",
    )
    arg_parser.add_argument(
        "--api_key",
        type=str,
        help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY.",
    )
    arg_parser.add_argument(
        "--organization",
        type=str,
        help="The organization to use. If not specified, the default organization id will be used.",
    )
    return arg_parser.parse_args()
if __name__ == '__main__':
    args = parse_args()

    # Load the machine-generated instructions from the previous pipeline
    # stage, optionally truncated to the first --num_instructions entries.
    with open(os.path.join(args.batch_dir, args.input_file)) as fin:
        lines = fin.readlines()
        if args.num_instructions is not None:
            lines = lines[:args.num_instructions]
        tasks = []
        for line in lines:
            data = json.loads(line)
            # Rename "metadata" -> "instruction_metadata" so it cannot be
            # confused with the per-instance metadata attached below.
            if "metadata" in data:
                data["instruction_metadata"] = data["metadata"]
                del data["metadata"]
            tasks.append(data)

    # Map each instruction to whether the earlier is_clf_or_not stage labeled
    # it as a classification task (answer text "Yes"/"yes"/"YES").
    task_clf_types = {}
    with open(os.path.join(args.batch_dir, "is_clf_or_not_davinci_template_1.jsonl")) as fin:
        for line in fin:
            data = json.loads(line)
            task_clf_types[data["instruction"]] = data["is_classification"].strip() in ["Yes", "yes", "YES"]

    if args.classification_tasks_only:
        tasks = [task for task in tasks if task_clf_types[task["instruction"]]]

    if args.generation_tasks_only:
        tasks = [task for task in tasks if not task_clf_types[task["instruction"]]]

    # Load previously generated instances so an interrupted run can resume
    # without re-querying the API for instructions it already covered.
    output_path = os.path.join(args.batch_dir, args.output_file)
    existing_requests = {}
    if os.path.exists(output_path):
        with open(output_path) as fin:
            for line in tqdm.tqdm(fin):
                # Tolerate truncated/corrupt lines (e.g. from a killed run),
                # but only for JSON parse failures — a bare except here would
                # also hide real bugs and swallow KeyboardInterrupt.
                try:
                    data = json.loads(line)
                    existing_requests[data["instruction"]] = data
                except json.JSONDecodeError:
                    continue
        print(f"Loaded {len(existing_requests)} existing requests")

    progress_bar = tqdm.tqdm(total=len(tasks))
    # Note: the output file is rewritten from scratch; cached results are
    # re-emitted so the final file is complete.
    with open(output_path, "w") as fout:
        for batch_idx in range(0, len(tasks), args.request_batch_size):
            batch = tasks[batch_idx: batch_idx + args.request_batch_size]
            if all(d["instruction"] in existing_requests for d in batch):
                # Every instruction in this batch is cached: re-emit the
                # cached results instead of calling the API.
                for d in batch:
                    data = existing_requests[d["instruction"]]
                    data = OrderedDict(
                        (k, data[k]) for k in
                        ["instruction", "raw_instances", "instance_metadata"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            else:
                # Build one prompt per task, choosing the template by task
                # type: classification tasks see possible outputs first,
                # generation tasks see an example input first.
                prompts = []
                for task in batch:
                    if task_clf_types[task["instruction"]]:
                        prompt = output_first_template_for_clf + " " + task["instruction"].strip() + "\n"
                        prompts.append(prompt)
                    else:
                        prompt = input_first_template_for_gen + " " + task["instruction"].strip() + "\n"
                        prompts.append(prompt)
                results = make_gpt3_requests(
                    engine=args.engine,
                    prompts=prompts,
                    # because the clf template is longer, we need to decrease the max_tokens
                    max_tokens=300 if any(task_clf_types[task["instruction"]] for task in batch) else 350,
                    temperature=0,
                    top_p=0,
                    frequency_penalty=0,
                    presence_penalty=1.5,
                    # Stop after --max_instances_to_generate examples, or at
                    # the start of a hallucinated next task.
                    stop_sequences=[f"Example {args.max_instances_to_generate + 1}", "Task:"],
                    logprobs=1,
                    n=1,
                    best_of=1,
                    api_key=args.api_key,
                    organization=args.organization)
                for i in range(len(batch)):
                    data = batch[i]
                    data["instance_metadata"] = results[i]
                    # A failed request yields response=None; record an empty
                    # raw_instances so downstream parsing can skip it.
                    if results[i]["response"] is not None:
                        data["raw_instances"] = results[i]["response"]["choices"][0]["text"]
                    else:
                        data["raw_instances"] = ""
                    data = OrderedDict(
                        (k, data[k]) for k in
                        ["instruction", "raw_instances", "instance_metadata"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            progress_bar.update(len(batch))
|