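# Batch-queries the OpenAI completion API for a file of prompts and writes one
# JSON line per prompt ({"prompt", "response", "created_at"}) to the output file.
# Uses the legacy `openai.Completion` / `openai.error` interface of the pre-1.0
# `openai` Python package.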
import json
import tqdm
import os
import random
import openai
from datetime import datetime
import argparse
import time

def make_requests(
        engine, prompts, max_tokens, temperature, top_p,
        frequency_penalty, presence_penalty, stop_sequences, logprobs, n, best_of, retries=3, api_key=None, organization=None
    ):
    """Send one prompt or a batch of prompts to the completion API, retrying on
    API errors, and return one record per prompt."""
    response = None
    target_length = max_tokens
    if api_key is not None:
        openai.api_key = api_key
    if organization is not None:
        openai.organization = organization
    retry_cnt = 0
    backoff_time = 30
    while retry_cnt <= retries:
        try:
            response = openai.Completion.create(
                engine=engine,
                prompt=prompts,
                max_tokens=target_length,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
                stop=stop_sequences,
                logprobs=logprobs,
                n=n,
                best_of=best_of,
            )
            break
        except openai.error.OpenAIError as e:
            print(f"OpenAIError: {e}.")
            if "Please reduce your prompt" in str(e):
                # The prompt plus completion exceeded the context window:
                # shrink the requested completion length and retry.
                target_length = int(target_length * 0.8)
                print(f"Reducing target length to {target_length}, retrying...")
            else:
                # Other errors (e.g. rate limits): back off exponentially before retrying.
                print(f"Retrying in {backoff_time} seconds...")
                time.sleep(backoff_time)
                backoff_time *= 1.5
            retry_cnt += 1

    if isinstance(prompts, list):
        # Batched request: the API returns n choices per prompt, concatenated in
        # prompt order, so slice them back out per prompt.
        results = []
        for j, prompt in enumerate(prompts):
            data = {
                "prompt": prompt,
                "response": {"choices": response["choices"][j * n: (j + 1) * n]} if response else None,
                "created_at": str(datetime.now()),
            }
            results.append(data)
        return results
    else:
        data = {
            "prompt": prompts,
            "response": response,
            "created_at": str(datetime.now()),
        }
        return [data]
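
# Illustrative single-prompt call (engine name and prompt are placeholders):
#   records = make_requests(
#       engine="davinci", prompts="Write a haiku about rain.", max_tokens=64,
#       temperature=0.7, top_p=0.5, frequency_penalty=0, presence_penalty=0,
#       stop_sequences=["\n\n"], logprobs=1, n=1, best_of=1,
#   )
#   # records == [{"prompt": "...", "response": <raw API response or None>, "created_at": "..."}]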

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        help="The input file that contains the prompts to GPT3.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="The output file to save the responses from GPT3.",
    )
    parser.add_argument(
        "--engine",
        type=str,
        help="The openai GPT3 engine to use.",
    )
    parser.add_argument(
        "--max_tokens",
        default=500,
        type=int,
        help="The max_tokens parameter of GPT3.",
    )
    parser.add_argument(
        "--temperature",
        default=0.7,
        type=float,
        help="The temperature of GPT3.",
    )
    parser.add_argument(
        "--top_p",
        default=0.5,
        type=float,
        help="The `top_p` parameter of GPT3.",
    )
    parser.add_argument(
        "--frequency_penalty",
        default=0,
        type=float,
        help="The `frequency_penalty` parameter of GPT3.",
    )
    parser.add_argument(
        "--presence_penalty",
        default=0,
        type=float,
        help="The `presence_penalty` parameter of GPT3.",
    )
    parser.add_argument(
        "--stop_sequences",
        default=["\n\n"],
        nargs="+",
        help="The `stop_sequences` parameter of GPT3.",
    )
    parser.add_argument(
        "--logprobs",
        default=5,
        type=int,
        help="The `logprobs` parameter of GPT3.",
    )
    parser.add_argument(
        "--n",
        type=int,
        help="The `n` parameter of GPT3. The number of responses to generate.",
    )
    parser.add_argument(
        "--best_of",
        type=int,
        help="The `best_of` parameter of GPT3. The beam size on the GPT3 server.",
    )
    parser.add_argument(
        "--use_existing_responses",
        action="store_true",
        help="Whether to use existing responses from the output file if it exists.",
    )
    parser.add_argument(
        "--request_batch_size",
        default=20,
        type=int,
        help="The number of requests to send to GPT3 at a time.",
    )
    return parser.parse_args()
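
# Illustrative invocation (script name, paths, and engine are placeholders):
#   python query_gpt3.py --input_file prompts.jsonl --output_file out/responses.jsonl \
#       --engine davinci --n 1 --best_of 1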

if __name__ == "__main__":
    random.seed(123)

    args = parse_args()
    os.makedirs(os.path.dirname(args.output_file), exist_ok=True)

    # read existing responses from the output file if it exists
    existing_responses = {}
    if os.path.exists(args.output_file) and args.use_existing_responses:
        with open(args.output_file, "r") as fin:
            for line in fin:
                data = json.loads(line)
                existing_responses[data["prompt"]] = data

    # read the prompts: JSONL files carry them in a "prompt" field, plain-text
    # files carry one (possibly escaped) prompt per line
    with open(args.input_file, "r") as fin:
        if args.input_file.endswith(".jsonl"):
            all_prompts = [json.loads(line)["prompt"] for line in fin]
        else:
            all_prompts = [line.strip().replace("\\n", "\n") for line in fin]

    with open(args.output_file, "w") as fout:
        for i in tqdm.tqdm(range(0, len(all_prompts), args.request_batch_size)):
            batch_prompts = all_prompts[i: i + args.request_batch_size]
            if all(p in existing_responses for p in batch_prompts):
                # reuse cached responses for this batch instead of re-querying the API
                for p in batch_prompts:
                    fout.write(json.dumps(existing_responses[p]) + "\n")
            else:
                results = make_requests(
                    engine=args.engine,
                    prompts=batch_prompts,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    frequency_penalty=args.frequency_penalty,
                    presence_penalty=args.presence_penalty,
                    stop_sequences=args.stop_sequences,
                    logprobs=args.logprobs,
                    n=args.n,
                    best_of=args.best_of,
                )
                for data in results:
                    fout.write(json.dumps(data) + "\n")