# Miranda2023's picture
# app create
# 82bc1c9
import json
import tqdm
import os
import random
import openai
from datetime import datetime
import argparse
import time
def make_requests(
    engine, prompts, max_tokens, temperature, top_p,
    frequency_penalty, presence_penalty, stop_sequences, logprobs, n, best_of, retries=3, api_key=None, organization=None
):
    """Call ``openai.Completion.create`` with retries and package the results.

    Args:
        engine: Engine name forwarded to the API.
        prompts: A single prompt string or a list of prompt strings (a batch).
        max_tokens: Initial ``max_tokens`` value; it is shrunk by 20% on each
            failure whose message contains "Please reduce your prompt".
        temperature, top_p, frequency_penalty, presence_penalty, logprobs,
        best_of: Forwarded unchanged to the API.
        stop_sequences: Forwarded to the API as ``stop``.
        n: Number of completions requested per prompt. ``None`` is forwarded
            unchanged to the API and treated as 1 when splitting the choices
            of a batched response.
        retries: Number of additional attempts after the first failure.
        api_key: If not None, assigned to ``openai.api_key``.
        organization: If not None, assigned to ``openai.organization``.

    Returns:
        A list of dicts with keys ``"prompt"``, ``"response"`` and
        ``"created_at"``. For a list of prompts there is one dict per prompt,
        whose ``"response"`` is ``{"choices": [...]}`` (or None if every
        attempt failed). For a single prompt the list has one dict holding the
        raw response object (or None).
    """
    response = None
    target_length = max_tokens
    if api_key is not None:
        openai.api_key = api_key
    if organization is not None:
        openai.organization = organization

    retry_cnt = 0
    backoff_time = 30
    while retry_cnt <= retries:
        try:
            response = openai.Completion.create(
                engine=engine,
                prompt=prompts,
                max_tokens=target_length,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=stop_sequences,
                logprobs=logprobs,
                n=n,
                best_of=best_of,
            )
            break
        except openai.error.OpenAIError as e:
            print(f"OpenAIError: {e}.")
            if "Please reduce your prompt" in str(e):
                # Context-length failure: waiting will not help, so shrink the
                # completion budget and retry immediately.
                target_length = int(target_length * 0.8)
                print(f"Reducing target length to {target_length}, retrying...")
            else:
                print(f"Retrying in {backoff_time} seconds...")
                time.sleep(backoff_time)
                backoff_time *= 1.5
            retry_cnt += 1

    if isinstance(prompts, list):
        # The slice bounds below must be ints. `n` is None when the caller
        # leaves it unset; one choice per prompt is assumed in that case
        # (presumably the API default -- confirm against the API reference).
        choices_per_prompt = 1 if n is None else n
        results = []
        for j, prompt in enumerate(prompts):
            # Assumes the API returns choices grouped by prompt, in prompt
            # order, `choices_per_prompt` entries each.
            start = j * choices_per_prompt
            stop = start + choices_per_prompt
            data = {
                "prompt": prompt,
                "response": {"choices": response["choices"][start:stop]} if response else None,
                "created_at": str(datetime.now()),
            }
            results.append(data)
        return results

    data = {
        "prompt": prompts,
        "response": response,
        "created_at": str(datetime.now()),
    }
    return [data]
def parse_args():
    """Parse the command-line options of the GPT3 batch-request script.

    Returns:
        argparse.Namespace with the attributes ``input_file``, ``output_file``,
        ``engine``, ``max_tokens``, ``temperature``, ``top_p``,
        ``frequency_penalty``, ``presence_penalty``, ``stop_sequences``,
        ``logprobs``, ``n``, ``best_of``, ``use_existing_responses`` and
        ``request_batch_size``.

    ``--n`` defaults to 1 rather than None so that code which does arithmetic
    with it (e.g. slicing the returned choices per prompt) always receives an
    int.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        help="The input file that contains the prompts to GPT3.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="The output file to save the responses from GPT3.",
    )
    parser.add_argument(
        "--engine",
        type=str,
        help="The openai GPT3 engine to use.",
    )
    parser.add_argument(
        "--max_tokens",
        default=500,
        type=int,
        help="The max_tokens parameter of GPT3.",
    )
    parser.add_argument(
        "--temperature",
        default=0.7,
        type=float,
        help="The temperature of GPT3.",
    )
    parser.add_argument(
        "--top_p",
        default=0.5,
        type=float,
        help="The `top_p` parameter of GPT3.",
    )
    parser.add_argument(
        "--frequency_penalty",
        default=0,
        type=float,
        help="The `frequency_penalty` parameter of GPT3.",
    )
    parser.add_argument(
        "--presence_penalty",
        default=0,
        type=float,
        help="The `presence_penalty` parameter of GPT3.",
    )
    parser.add_argument(
        "--stop_sequences",
        default=["\n\n"],
        nargs="+",
        help="The `stop_sequences` parameter of GPT3.",
    )
    parser.add_argument(
        "--logprobs",
        default=5,
        type=int,
        help="The `logprobs` parameter of GPT3"
    )
    parser.add_argument(
        "--n",
        default=1,
        type=int,
        help="The `n` parameter of GPT3. The number of responses to generate."
    )
    parser.add_argument(
        "--best_of",
        type=int,
        help="The `best_of` parameter of GPT3. The beam size on the GPT3 server."
    )
    parser.add_argument(
        "--use_existing_responses",
        action="store_true",
        help="Whether to use existing responses from the output file if it exists."
    )
    parser.add_argument(
        "--request_batch_size",
        default=20,
        type=int,
        help="The number of requests to send to GPT3 at a time."
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: read prompts, request completions batch by batch,
    # and write one JSON object per prompt to the output file (JSON lines).
    random.seed(123)
    args = parse_args()

    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises,
    # so only create the directory when the path actually has one.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # read existing file if it exists
    existing_responses = {}
    if os.path.exists(args.output_file) and args.use_existing_responses:
        with open(args.output_file, "r") as fin:
            for line in fin:
                data = json.loads(line)
                existing_responses[data["prompt"]] = data

    # do new prompts
    with open(args.input_file, "r") as fin:
        if args.input_file.endswith(".jsonl"):
            all_prompts = [json.loads(line)["prompt"] for line in fin]
        else:
            # Plain-text input: one prompt per line, with literal "\n"
            # sequences expanded into real newlines.
            all_prompts = [line.strip().replace("\\n", "\n") for line in fin]

    # Opening with "w" truncates the output file; any previously saved
    # responses were already loaded into `existing_responses` above.
    with open(args.output_file, "w") as fout:
        for i in tqdm.tqdm(range(0, len(all_prompts), args.request_batch_size)):
            batch_prompts = all_prompts[i: i + args.request_batch_size]
            if all(p in existing_responses for p in batch_prompts):
                # Every prompt of this batch is cached: replay without an API call.
                for p in batch_prompts:
                    fout.write(json.dumps(existing_responses[p]) + "\n")
            else:
                # At least one prompt is missing: the whole batch is re-requested.
                results = make_requests(
                    engine=args.engine,
                    prompts=batch_prompts,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    frequency_penalty=args.frequency_penalty,
                    presence_penalty=args.presence_penalty,
                    stop_sequences=args.stop_sequences,
                    logprobs=args.logprobs,
                    n=args.n,
                    best_of=args.best_of,
                )
                for data in results:
                    fout.write(json.dumps(data) + "\n")