| """Generate answers with GPT-3.5""" |
| |
| import argparse |
| import json |
| import os |
| import time |
| import concurrent.futures |
|
|
| import openai |
| import tqdm |
| import shortuuid |
|
|
# Chat model queried through the OpenAI API.
MODEL = "gpt-3.5-turbo"
# Identifier written into every answer record: the model name plus a date
# tag (presumably the snapshot/run date this script targets -- confirm).
MODEL_ID = f"{MODEL}:20230327"
|
|
|
|
def get_answer(
    question_id: int,
    question: str,
    max_tokens: int,
    max_retries: int = 3,
    retry_delay: float = 1.0,
) -> dict:
    """Ask MODEL one question and package the reply as an answer record.

    Args:
        question_id: Identifier copied verbatim into the returned record.
        question: User prompt sent to the chat model.
        max_tokens: Maximum number of tokens the model may generate.
        max_retries: Number of API attempts before giving up.
        retry_delay: Seconds to wait between failed attempts.

    Returns:
        A dict with keys "answer_id" (a fresh shortuuid), "question_id",
        "model_id" (MODEL_ID) and "text". "text" holds the model's reply,
        or the placeholder "#ERROR#" if every attempt failed.
    """
    ans = {
        "answer_id": shortuuid.uuid(),
        "question_id": question_id,
        "model_id": MODEL_ID,
        # Placeholder; overwritten on success. Pre-setting it guarantees the
        # key exists even when max_retries <= 0.
        "text": "#ERROR#",
    }
    for attempt in range(1, max_retries + 1):
        try:
            # NOTE(review): openai.ChatCompletion is the pre-1.0 openai SDK
            # interface; confirm the pinned openai version supports it.
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": question},
                ],
                max_tokens=max_tokens,
            )
            # Kept inside the try so a malformed response is retried rather
            # than crashing the worker thread.
            ans["text"] = response["choices"][0]["message"]["content"]
            return ans
        except Exception as e:  # retry boundary: any API/network failure
            # tqdm.write keeps the progress bar in the caller intact.
            tqdm.tqdm.write(
                f"[ERROR] question {question_id} "
                f"(attempt {attempt}/{max_retries}): {e}"
            )
            # Don't delay the final "#ERROR#" result with a pointless sleep.
            if attempt < max_retries:
                time.sleep(retry_delay)
    return ans
|
|
|
|
def _load_questions(path: str) -> dict:
    """Read a JSON-lines question file into a {question_id: text} mapping.

    Each non-blank line must be a JSON object with "question_id" and "text"
    keys. Blank or whitespace-only lines are skipped. If a question_id
    appears more than once, the last occurrence wins.
    """
    questions = {}
    with open(os.path.expanduser(path), encoding="utf-8") as f:
        for line in f:
            # Lines from file iteration keep their trailing "\n", so a plain
            # `if not line` would never skip a blank line.
            if not line.strip():
                continue
            q = json.loads(line)
            questions[q["question_id"]] = q["text"]
    return questions


def _write_answers(path: str, answers: list) -> None:
    """Write answer records as newline-terminated JSON lines, sorted by question_id."""
    rows = sorted(answers, key=lambda a: a["question_id"])
    with open(os.path.expanduser(path), "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)


def main() -> None:
    """Parse CLI arguments, answer every question concurrently, save the results."""
    parser = argparse.ArgumentParser(description="ChatGPT answer generation.")
    parser.add_argument(
        "-q", "--question", required=True, help="input JSON-lines file of questions"
    )
    parser.add_argument(
        "-o", "--output", required=True, help="output JSON-lines file for answers"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=1024,
        help="maximum number of tokens produced in the output",
    )
    parser.add_argument(
        "--parallel",
        type=int,
        default=32,
        help="number of concurrent API requests",
    )
    args = parser.parse_args()

    questions = _load_questions(args.question)

    # API calls are I/O-bound, so threads overlap the network waits.
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.parallel) as executor:
        futures = [
            executor.submit(get_answer, qid, question, args.max_tokens)
            for qid, question in questions.items()
        ]
        answers = [
            future.result()
            for future in tqdm.tqdm(
                concurrent.futures.as_completed(futures), total=len(futures)
            )
        ]

    _write_answers(args.output, answers)


if __name__ == "__main__":
    main()
|
|