| """Generate answers with GPT-3.5""" |
| |
| import argparse |
| import json |
| import os |
| import time |
| import concurrent.futures |
|
|
| import openai |
| import tqdm |
| import shortuuid |
|
|
| MODEL = 'gpt-3.5-turbo' |
| MODEL_ID = 'gpt-3.5-turbo:20230327' |
|
|
def get_answer(question_id: int, question: str, max_tokens: int) -> dict:
    """Ask the chat model a single question and build its answer record.

    The request is attempted up to three times. Any exception raised while
    calling the API or while extracting the reply is printed and treated as
    a failed attempt; failures are never propagated to the caller.

    Args:
        question_id: Identifier copied verbatim into the returned record.
        question: Text sent as the user message.
        max_tokens: Passed through as ``max_tokens`` to the completion call.

    Returns:
        A dict with the keys ``answer_id`` (a fresh shortuuid),
        ``question_id``, ``model_id`` (``MODEL_ID``) and ``text``. ``text``
        is the content of the first choice's message, or the literal
        ``'#ERROR#'`` when all attempts failed.
    """
    ans = {
        'answer_id': shortuuid.uuid(),
        'question_id': question_id,
        'model_id': MODEL_ID,
    }
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful assistant.'
                }, {
                    'role': 'user',
                    'content': question,
                }],
                max_tokens=max_tokens,
            )
            # Kept inside the try so an unexpectedly shaped response is
            # retried like any other failure instead of crashing the worker.
            ans['text'] = response['choices'][0]['message']['content']
            return ans
        except Exception as e:  # deliberate best-effort: log, mark, retry
            print('[ERROR]', e)
            ans['text'] = '#ERROR#'
            # Pause only when another attempt follows; sleeping after the
            # final failure would just hold the worker thread for nothing.
            if attempt + 1 < max_attempts:
                time.sleep(1)
    return ans
|
|
|
|
def load_questions(path: str) -> dict:
    """Read a JSON-lines question file into a ``question_id -> text`` mapping.

    Each non-blank line must be a JSON object with ``question_id`` and
    ``text`` keys. Blank or whitespace-only lines are skipped. If the same
    ``question_id`` appears more than once, the last occurrence wins.

    Args:
        path: Path to the question file; a leading ``~`` is expanded.

    Returns:
        Dict mapping each ``question_id`` to its ``text``, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``question_id`` or ``text``.
    """
    questions = {}
    with open(os.path.expanduser(path), encoding='utf-8') as f:
        for line in f:
            # Lines read from a file keep their trailing newline, so a blank
            # line is '\n' (truthy); strip before testing for emptiness.
            if not line.strip():
                continue
            q = json.loads(line)
            questions[q['question_id']] = q['text']
    return questions


def main() -> None:
    """Parse the command line, answer every question, and write the results.

    Questions are submitted to a pool of 32 worker threads. The collected
    answer records are sorted by ``question_id`` and written to the output
    file as newline-separated JSON objects.
    """
    parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
    # Both paths are required up front: without this a missing --output was
    # only detected after every API request had already been made.
    parser.add_argument('-q', '--question', required=True)
    parser.add_argument('-o', '--output', required=True)
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    questions_dict = load_questions(args.question)

    answers = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(get_answer, qid, question, args.max_tokens)
            for qid, question in questions_dict.items()
        ]
        # as_completed yields in completion order; tqdm shows progress and
        # the final ordering is restored by the sort below.
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            answers.append(future.result())

    answers.sort(key=lambda x: x['question_id'])

    with open(os.path.expanduser(args.output), 'w', encoding='utf-8') as f:
        f.write('\n'.join(json.dumps(ans) for ans in answers))


if __name__ == '__main__':
    main()
|
|