# Miranda2023's picture
# app create
# 0d15ac6
import os
import json
import random
import tqdm
import re
import argparse
import pandas as pd
from collections import OrderedDict
from openai import OpenAIError
from .gpt3_api import make_requests as make_gpt3_requests
from .templates.instance_gen_template import output_first_template_for_clf, input_first_template_for_gen
from .templates.clf_task_template import template_1
random.seed(42)
engine = "davinci"
# def parse_args():
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--batch_dir",
# type=str,
# required=True,
# help="The directory where the batch is stored.",
# )
# parser.add_argument(
# "--input_file",
# type=str,
# default="machine_generated_instructions.jsonl"
# )
# parser.add_argument(
# "--output_file",
# type=str,
# default="machine_generated_instances.jsonl",
# )
# parser.add_argument(
# "--num_instructions",
# type=int,
# help="if specified, only generate instance input for this many instructions",
# )
# parser.add_argument(
# "--max_instances_to_generate",
# type=int,
# default=5,
# help="The max number of instances to generate for each instruction.",
# )
# parser.add_argument(
# "--generation_tasks_only",
# action="store_true",
# help="If specified, only do for generation tasks.",
# )
# parser.add_argument(
# "--classification_tasks_only",
# action="store_true",
# help="If specified, only do for classification tasks.",
# )
# parser.add_argument(
# "--engine",
# type=str,
# default="davinci",
# help="The engine to use."
# )
# parser.add_argument(
# "--request_batch_size",
# type=int,
# default=5,
# help="The number of requests to send in a batch."
# )
# parser.add_argument(
# "--api_key",
# type=str,
# help="The API key to use. If not specified, the key will be read from the environment variable OPENAI_API_KEY."
# )
# parser.add_argument(
# "--organization",
# type=str,
# help="The organization to use. If not specified, the default organization id will be used."
# )
# return parser.parse_args()
def if_classify(instructions, api_key):
    """Ask GPT-3 whether each instruction describes a classification task.

    Args:
        instructions: iterable of instruction strings.
        api_key: OpenAI API key forwarded to ``make_gpt3_requests``.

    Returns:
        A list parallel to ``instructions`` with ``True`` (classification),
        ``False`` (not classification), or the string ``"Unknown"`` when the
        API request failed (``response`` is ``None``).
    """
    prefix = template_1
    prompts = [prefix + " " + instruct.strip() + "\n" + "Is it classification?" for instruct in instructions]
    results = make_gpt3_requests(
        engine=engine,
        prompts=prompts,
        max_tokens=3,
        temperature=0,
        top_p=0,
        frequency_penalty=0,
        presence_penalty=0,
        stop_sequences=["\n", "Task"],
        logprobs=1,
        n=1,
        best_of=1,
        api_key=api_key)
    classify_res = []
    for result in results:
        if result["response"] is not None:
            # Completions frequently begin with whitespace (e.g. " Yes"),
            # so strip and compare case-insensitively instead of matching
            # the raw text against a fixed list of spellings.
            answer = result["response"]["choices"][0]["text"].strip().lower()
            classify_res.append(answer == "yes")
        else:
            print("**分类出错,", result)
            classify_res.append("Unknown")
    return classify_res
def filter_duplicate_instances(instances):
    """Remove duplicate instances, rejecting whole sets with conflicting outputs.

    Args:
        instances: list of (instruction, input, output) tuples.

    Returns:
        [] if any two instances share the same non-empty input but have
        different outputs (the whole batch is considered unreliable);
        otherwise the instances with exact duplicates removed, preserving
        first-seen order.
    """
    for i in range(1, len(instances)):
        for j in range(0, i):
            # empty inputs are allowed to map to different outputs
            if instances[i][1] == "":
                continue
            if instances[i][1] == instances[j][1] and instances[i][2] != instances[j][2]:
                # conflicting outputs for the same input: drop everything
                return []
    # dict.fromkeys dedups while keeping insertion order, unlike set(),
    # whose iteration order is nondeterministic across runs
    return list(dict.fromkeys(instances))
def filter_invalid_instances(instances):
    """Drop instances that are clearly unusable for training.

    An instance (instruction, input, output) is rejected when its input
    equals its output, its output is empty, or either field ends with a
    colon (usually a sign of incomplete generation).
    """
    def _is_valid(inst):
        # identical input/output carries no training signal
        if inst[1] == inst[2]:
            return False
        # an empty output is useless
        if inst[2] == "":
            return False
        # a trailing colon usually means the generation was cut off
        if inst[1].strip().endswith(":") or inst[2].strip().endswith(":"):
            return False
        return True

    return [inst for inst in instances if _is_valid(inst)]
def encode_instance(instruction, input, output, random_template=True):
    """Render an (instruction, input, output) triple as a prompt/completion pair.

    When ``random_template`` is True a template is drawn at random (with or
    without an input slot depending on whether the input is non-empty);
    otherwise a fixed plain layout is used. Returns a dict with the rendered
    "prompt"/"completion" plus the stripped source fields.
    """
    encoding_templates_w_input = [
        ("{instruction}\nInput: {input}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\nInput: {input}\n\nOutput:", " {output}<|endoftext|>"),
        ("Task: {instruction}\nInput: {input}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\n{input}\n\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\n{input}\n\n", "{output}<|endoftext|>"),
        ("{instruction}\n{input}\n\n", "{output}<|endoftext|>"),
        ("Task: {instruction}\n\n{input}\n\n", "{output}<|endoftext|>"),
    ]
    encoding_templates_wo_input = [
        ("{instruction} Output:", " {output}<|endoftext|>"),
        ("{instruction}\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n\nOutput:", " {output}<|endoftext|>"),
        ("{instruction}\n", "{output}<|endoftext|>"),
        ("{instruction}\n\n", "{output}<|endoftext|>"),
        ("Task: {instruction}\n\n", "{output}<|endoftext|>"),
    ]
    instr = instruction.strip()
    inp = input.strip()
    out = output.strip()
    if not random_template:
        # fixed, template-free layout
        prompt = instr + "\n\n" + inp + "\n\n"
        completion = out + "<|endoftext|>"
    elif inp != "":
        prompt_tpl, completion_tpl = random.choice(encoding_templates_w_input)
        prompt = prompt_tpl.format(instruction=instr, input=inp)
        completion = completion_tpl.format(output=out)
    else:
        prompt_tpl, completion_tpl = random.choice(encoding_templates_wo_input)
        prompt = prompt_tpl.format(instruction=instr)
        completion = completion_tpl.format(output=out)
    return {
        "prompt": prompt,
        "completion": completion,
        "instruction": instr,
        "input": inp,
        "output": out,
    }
def parse_input_output(response_text):
    """Split one generated instance text into (input, output) strings.

    If an "Output:" marker is present, everything before the first marker is
    the input and the segment after it is the output; otherwise the whole
    text is treated as the output with an empty input. Only the first
    input/output pair is kept, and a leading "Input:" label is stripped.
    """
    output_marker = r"Output\s*\d*\s*:"
    input_marker = r"Input\s*\d*\s*:"
    if re.findall(output_marker, response_text):
        segments = re.split(output_marker, response_text)
        inst_input = segments[0].strip()
        inst_output = segments[1].strip()
    else:
        inst_input = ""
        inst_output = response_text.strip()
    # keep only the first pair when several input/output pairs were generated
    if re.findall(input_marker, inst_output):
        inst_output = re.split(input_marker, inst_output)[0].strip()
    # drop the "Input:" label from the front of the input, if any
    inst_input = re.sub(r"^Input\s*\d*\s*:", "", inst_input).strip()
    return inst_input, inst_output
def parse_instances_for_generation_task(raw_text, instruction, response_metadata):
    """Parse GPT-3 output for a generation task into instance triples.

    Args:
        raw_text: the raw completion text.
        instruction: the instruction these instances belong to.
        response_metadata: the full request record; its
            ``response["choices"][0]["finish_reason"]`` is checked to drop a
            possibly-truncated last instance.

    Returns:
        A filtered, de-duplicated list of (instruction, input, output) tuples,
        or [] when the text contains no recognizable instance markers.
    """
    instances = []
    raw_text = raw_text.strip()
    # raw string: "\s"/"\d" inside a plain literal are invalid escape
    # sequences and emit SyntaxWarning on modern Python
    example_marker = r"Example\s?\d*\.?"
    if re.findall(example_marker, raw_text):
        instance_texts = re.split(example_marker, raw_text)
        instance_texts = [it.strip() for it in instance_texts if it.strip() != ""]
        for instance_text in instance_texts:
            inst_input, inst_output = parse_input_output(instance_text)
            instances.append((instruction.strip(), inst_input.strip(), inst_output.strip()))
    elif re.findall(r"Output\s*\d*\s*:", raw_text):
        # we assume only one input/output pair in this case
        inst_input, inst_output = parse_input_output(raw_text)
        instances.append((instruction.strip(), inst_input.strip(), inst_output.strip()))
    else:
        return []
    # if generation stopped because of length, the last instance is
    # likely truncated — drop it
    if response_metadata["response"]["choices"][0]["finish_reason"] == "length":
        instances = instances[:-1]
    instances = filter_invalid_instances(instances)
    instances = filter_duplicate_instances(instances)
    return instances
def parse_instances_for_classification_task(raw_text, instruction, response_metadata):
    """Parse GPT-3 output for a classification task into instance triples.

    The text is expected to contain "Class label:" sections; within each
    section the first line is the class label (the output) and any remaining
    lines form the input. Returns a filtered, de-duplicated list of
    (instruction, input, output) tuples, or [] when no section is found.
    """
    if "Class label:" not in raw_text:
        return []
    instruction = instruction.strip()
    instances = []
    for chunk in raw_text.split("Class label:")[1:]:
        chunk = chunk.strip()
        # first line = class label; everything after it = the input
        # (partition yields an empty remainder when there is no newline)
        label, _, body = chunk.partition("\n")
        instances.append((instruction, body.strip(), label.strip()))
    # if generation stopped because of length, the last instance may be
    # truncated — drop it
    if response_metadata["response"]["choices"][0]["finish_reason"] == "length":
        instances = instances[:-1]
    instances = filter_invalid_instances(instances)
    instances = filter_duplicate_instances(instances)
    return instances
def generate_instance(inputs, api_key):
    """Request instance generations for a batch of instructions.

    Args:
        inputs: list of instruction strings.
        api_key: OpenAI API key forwarded to ``make_gpt3_requests``.

    Returns:
        (results, classify_res): the raw request records and the
        per-instruction classification verdicts from ``if_classify``.
    """
    classify_res = if_classify(inputs, api_key)
    prompts = []
    for instruction, is_clf in zip(inputs, classify_res):
        # BUG FIX: if_classify returns True/False/"Unknown"; the previous
        # check compared against ["Yes", "yes", "YES"] and therefore never
        # selected the classification template. "Unknown" (failed request)
        # falls through to the generation template.
        if is_clf is True:
            prompts.append(output_first_template_for_clf + " " + instruction.strip() + "\n")
        else:
            prompts.append(input_first_template_for_gen + " " + instruction.strip() + "\n")
    results = make_gpt3_requests(
        engine=engine,
        prompts=prompts,
        # because the clf template is longer, we need to decrease the max_tokens
        max_tokens=350,
        temperature=0,
        top_p=0,
        frequency_penalty=0,
        presence_penalty=1.5,
        stop_sequences=["Task:"],
        logprobs=1,
        n=1,
        best_of=1,
        api_key=api_key)
    return results, classify_res
def prepare_finetune(inputs, api_key):
    """Generate and parse training instances for a batch of instructions.

    Args:
        inputs: list of instruction strings.
        api_key: OpenAI API key.

    Returns:
        (results1, classify_res, instance_outputs, results2) where results1
        is a list of {"instruction", "input", "output"} dicts, classify_res
        the per-instruction classification verdicts, instance_outputs the
        raw request records, and results2 the JSON-serialized form of
        results1.
    """
    instance_outputs, classify_res = generate_instance(inputs, api_key)
    training_instances = []
    for i in range(len(inputs)):
        # skip failed API requests instead of crashing on a None response
        if instance_outputs[i]["response"] is None:
            continue
        raw_text = instance_outputs[i]["response"]["choices"][0]["text"]
        if classify_res[i] is True:
            task_instances = parse_instances_for_classification_task(
                raw_text, inputs[i].strip(), instance_outputs[i])
        else:
            task_instances = parse_instances_for_generation_task(
                raw_text, inputs[i].strip(), instance_outputs[i])
        # we only allow max 5 instances per task
        task_instances = random.sample(task_instances, min(len(task_instances), 5))
        training_instances += task_instances
    # BUG FIX: this serialization loop used to be nested inside the loop
    # above, so instances gathered from earlier inputs were re-appended
    # once per remaining input, duplicating training records.
    results1, results2 = [], []
    for instruction, inst_input, inst_output in training_instances:
        record = {
            "instruction": instruction,
            "input": inst_input,
            "output": inst_output,
        }
        results1.append(record)
        results2.append(json.dumps(record, ensure_ascii=False))
    return results1, classify_res, instance_outputs, results2
def instance_main(inputs, key):
    """Validate the API key with a minimal chat request, then generate instances.

    ``inputs`` is a newline-separated string of instructions. On an invalid
    key the function returns the error payload ({"Wrong": "Key!"}, " ", " ",
    " "); otherwise it forwards the split instructions to prepare_finetune.
    """
    try:
        import openai
        openai.api_key = key
        # cheap probe request: raises OpenAIError if the key is bad
        openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=1,
        )
    except OpenAIError:
        return {"Wrong": "Key!"}, " ", " ", " "
    instruction_list = inputs.split('\n')
    print("***", instruction_list)
    return prepare_finetune(instruction_list, key)
# instance_main()