"""Generate instance inputs/outputs for machine-generated instructions via GPT-3.

Reads machine-generated instructions from a JSONL file, builds a prompt per
instruction (output-first for classification tasks, input-first for generation
tasks), queries the API in batches, and writes one JSONL record per
instruction with the raw generated instances. Previously generated records in
the output file are reused so interrupted runs can be resumed.
"""
import os
import json
import random
import re
import argparse
from collections import OrderedDict

import tqdm
import pandas as pd

from gpt3_api import make_requests as make_gpt3_requests
from templates.instance_gen_template import output_first_template_for_clf, input_first_template_for_gen

random.seed(42)


def parse_args():
    """Build and parse the command-line arguments for instance generation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_dir",
        type=str,
        required=True,
        help="The directory where the batch is stored.",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default="machine_generated_instructions.jsonl",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="machine_generated_instances.jsonl",
    )
    parser.add_argument(
        "--num_instructions",
        type=int,
        help="if specified, only generate instance input for this many instructions",
    )
    parser.add_argument(
        "--max_instances_to_generate",
        type=int,
        default=5,
        help="The max number of instances to generate for each instruction.",
    )
    parser.add_argument(
        "--generation_tasks_only",
        action="store_true",
        help="If specified, only do for generation tasks.",
    )
    parser.add_argument(
        "--classification_tasks_only",
        action="store_true",
        help="If specified, only do for classification tasks.",
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="davinci",
        help="The engine to use.",
    )
    parser.add_argument(
        "--request_batch_size",
        type=int,
        default=5,
        help="The number of requests to send in a batch.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY.",
    )
    parser.add_argument(
        "--organization",
        type=str,
        help="The organization to use. If not specified, the default organization id will be used.",
    )
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    # Load the machine-generated instructions, optionally truncated to the
    # first --num_instructions entries.
    with open(os.path.join(args.batch_dir, args.input_file), encoding="utf-8") as fin:
        lines = fin.readlines()
        if args.num_instructions is not None:
            lines = lines[:args.num_instructions]

    tasks = []
    for line in lines:
        data = json.loads(line)
        # Rename "metadata" -> "instruction_metadata" so it doesn't collide
        # with the per-instance metadata attached below.
        if "metadata" in data:
            data["instruction_metadata"] = data["metadata"]
            del data["metadata"]
        tasks.append(data)

    # Map each instruction to whether it was judged a classification task by
    # the is_clf_or_not step; this decides which prompt template is used.
    task_clf_types = {}
    with open(os.path.join(args.batch_dir, "is_clf_or_not_davinci_template_1.jsonl"), encoding="utf-8") as fin:
        for line in fin:
            data = json.loads(line)
            task_clf_types[data["instruction"]] = data["is_classification"].strip() in ["Yes", "yes", "YES"]

    if args.classification_tasks_only:
        tasks = [task for task in tasks if task_clf_types[task["instruction"]]]
    if args.generation_tasks_only:
        tasks = [task for task in tasks if not task_clf_types[task["instruction"]]]

    # Index any records already present in the output file so interrupted
    # runs can be resumed without re-querying the API for finished batches.
    output_path = os.path.join(args.batch_dir, args.output_file)
    existing_requests = {}
    if os.path.exists(output_path):
        with open(output_path, encoding="utf-8") as fin:
            for line in tqdm.tqdm(fin):
                try:
                    data = json.loads(line)
                    existing_requests[data["instruction"]] = data
                except json.JSONDecodeError:
                    # Skip truncated/corrupted lines (e.g. from an interrupted
                    # write). This was a bare `except: pass`, which also
                    # swallowed KeyboardInterrupt and masked unrelated bugs.
                    continue
        print(f"Loaded {len(existing_requests)} existing requests")

    progress_bar = tqdm.tqdm(total=len(tasks))
    with open(output_path, "w", encoding="utf-8") as fout:
        for batch_idx in range(0, len(tasks), args.request_batch_size):
            batch = tasks[batch_idx: batch_idx + args.request_batch_size]
            if all(d["instruction"] in existing_requests for d in batch):
                # Every instruction in this batch is cached: re-emit the
                # cached records without hitting the API.
                for d in batch:
                    data = existing_requests[d["instruction"]]
                    data = OrderedDict(
                        (k, data[k]) for k in ["instruction", "raw_instances", "instance_metadata"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            else:
                prompts = []
                for task in batch:
                    # Classification tasks enumerate the output (label) first;
                    # generation tasks enumerate the input first.
                    if task_clf_types[task["instruction"]]:
                        template = output_first_template_for_clf
                    else:
                        template = input_first_template_for_gen
                    prompts.append(template + " " + task["instruction"].strip() + "\n")
                results = make_gpt3_requests(
                    engine=args.engine,
                    prompts=prompts,
                    # because the clf template is longer, we need to decrease the max_tokens
                    max_tokens=300 if any(task_clf_types[task["instruction"]] for task in batch) else 350,
                    temperature=0,
                    top_p=0,
                    frequency_penalty=0,
                    presence_penalty=1.5,
                    stop_sequences=[f"Example {args.max_instances_to_generate + 1}", "Task:"],
                    logprobs=1,
                    n=1,
                    best_of=1,
                    api_key=args.api_key,
                    organization=args.organization)
                for data, result in zip(batch, results):
                    data["instance_metadata"] = result
                    # A failed request leaves "response" as None; emit an empty
                    # string so downstream parsing still sees the record.
                    if result["response"] is not None:
                        data["raw_instances"] = result["response"]["choices"][0]["text"]
                    else:
                        data["raw_instances"] = ""
                    data = OrderedDict(
                        (k, data[k]) for k in ["instruction", "raw_instances", "instance_metadata"]
                    )
                    fout.write(json.dumps(data, ensure_ascii=False) + "\n")
            progress_bar.update(len(batch))