| """Generate answers with GPT-3.5"""
|
|
|
| import argparse
|
| import json
|
| import os
|
| import time
|
| import concurrent.futures
|
|
|
| import openai
|
| import tqdm
|
| import shortuuid
|
|
|
# Chat Completions model name sent with every request.
MODEL = 'gpt-3.5-turbo'
# Identifier written into each answer record: the model name plus a
# version suffix (presumably a snapshot date -- confirm).
MODEL_ID = MODEL + ':20230327'
|
|
|
def get_answer(question_id: int, question: str, max_tokens: int,
               max_retries: int = 3, retry_delay: float = 1.0) -> dict:
    """Ask the chat model ``question`` and package the reply as an answer record.

    The request is attempted up to ``max_retries`` times. Any exception raised
    by the API call or while reading the response is printed, and the call is
    retried after ``retry_delay`` seconds. No delay follows the final attempt.

    Args:
        question_id: Identifier copied verbatim into the result record.
        question: Content of the user message sent to the model.
        max_tokens: Upper bound on the number of tokens generated for the reply.
        max_retries: Number of attempts before giving up (default 3).
        retry_delay: Seconds to wait between failed attempts (default 1.0).

    Returns:
        A dict with keys ``answer_id`` (a fresh shortuuid), ``question_id``,
        ``model_id`` and ``text``. ``text`` holds the model's reply, or the
        sentinel ``'#ERROR#'`` when every attempt failed.
    """
    ans = {
        'answer_id': shortuuid.uuid(),
        'question_id': question_id,
        'model_id': MODEL_ID,
    }
    for attempt in range(max_retries):
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': question},
                ],
                max_tokens=max_tokens,
            )
            ans['text'] = response['choices'][0]['message']['content']
            return ans
        except Exception as e:  # deliberately broad: any API/network failure is retried
            print('[ERROR]', e)
            if attempt < max_retries - 1:
                # Back off before the next attempt; sleeping after the last
                # attempt would only delay returning the error record.
                time.sleep(retry_delay)
    # All attempts failed (or max_retries <= 0). Always set 'text' so
    # downstream consumers can rely on the key being present.
    ans['text'] = '#ERROR#'
    return ans
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
    # Both paths are mandatory; previously omitting them crashed later with a
    # TypeError inside os.path.expanduser(None).
    parser.add_argument('-q', '--question', required=True, help='path to the input JSONL file of questions')
    parser.add_argument('-o', '--output', required=True, help='path of the JSONL answer file to write')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    # Map question_id -> question text. Each non-blank input line is a JSON
    # object carrying at least 'question_id' and 'text'.
    questions_dict = {}
    with open(os.path.expanduser(args.question), encoding='utf-8') as f:
        for line in f:
            # Lines read from a file keep their trailing '\n', so a blank line
            # is truthy; test the stripped content to actually skip it.
            if not line.strip():
                continue
            q = json.loads(line)
            questions_dict[q['question_id']] = q['text']

    # The API calls are I/O-bound, so a thread pool overlaps network waits.
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(get_answer, qid, question, args.max_tokens)
            for qid, question in questions_dict.items()
        ]
        answers = [
            future.result()
            for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures))
        ]

    # as_completed yields in completion order; restore a deterministic order.
    answers.sort(key=lambda x: x['question_id'])

    # One JSON object per line (JSONL), no trailing newline, as before.
    with open(os.path.expanduser(args.output), 'w', encoding='utf-8') as f:
        f.write('\n'.join(json.dumps(ans) for ans in answers))
|
|
|