Spaces:
Configuration error
Configuration error
import argparse
import json
import os
import pathlib
import random

import openai
from tqdm import tqdm

from load_aokvqa import load_aokvqa

# OpenAI credentials are taken from the environment; os.getenv yields None
# when a variable is unset.
openai.organization = os.getenv('OPENAI_ORG')
openai.api_key = os.getenv('OPENAI_API_KEY')

# Fixed seed so the in-context examples drawn with random.sample in main()
# are the same on every run.
random.seed(0)
def main():
    """Run few-shot completion prompting over an A-OKVQA split and save predictions.

    For every question in the evaluation split (``--split``), a prompt is
    built from ``--prefix``, then ``--n`` randomly sampled answered training
    examples (each followed by a blank line), and finally the unanswered
    evaluation question. The prompt is sent to the ``text-curie-001``
    completion engine (temperature 0, at most 10 tokens). The stripped
    completion text is stored under the question's ``question_id``, and the
    resulting ``{question_id: prediction}`` mapping is written as JSON to
    ``--out``.

    ``--train-context`` and ``--context`` are optional JSON files mapping
    ``question_id`` to a context string, used for the training demonstrations
    and the evaluation questions respectively. Each file is loaded
    independently of the other.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--n', type=int, default=10, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    args = parser.parse_args()

    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Load each context file on its own. Gating both on --context made
    # `--context` alone crash on json.load(None) and silently ignored
    # `--train-context` when it was given without `--context`.
    train_context = {}
    if args.train_context_file is not None:
        train_context = json.load(args.train_context_file)
    context = {}
    if args.context_file is not None:
        context = json.load(args.context_file)

    predictions = {}
    for d in tqdm(eval_set):
        q = d['question_id']

        prompt = args.prompt_prefix
        # NOTE(review): with --split train, the sampled demonstrations may
        # include `d` itself, answer included. Consider excluding it.
        for e in random.sample(train_set, args.num_examples):
            prompt += prompt_element(
                e,
                # Context belongs to the demonstration's own question, not to
                # the evaluation question `q` being answered.
                context=train_context.get(e['question_id'], None),
                include_choices=args.include_choices,
                answer=True,
            )
            prompt += '\n\n'
        prompt += prompt_element(
            d,
            context=context.get(q, None),
            include_choices=args.include_choices,
            answer=False,
        )

        response = openai.Completion.create(
            engine="text-curie-001",
            prompt=prompt,
            temperature=0.0,
            max_tokens=10,
        )
        predictions[q] = response.choices[0].text.strip()

    # Close the argparse-opened output file explicitly so the JSON is flushed.
    with args.output_file as output_file:
        json.dump(predictions, output_file)
def prompt_element(d, context=None, include_choices=False, answer=False):
    """Render one question record as a few-shot prompt fragment.

    Args:
        d: Question record. Always reads ``d['question']``. Reads
            ``d['choices']`` when ``include_choices`` or ``answer`` is set,
            and ``d['correct_choice_idx']`` when ``answer`` is set.
        context: If not None, emitted first as a ``Context: ...`` line.
        include_choices: If true, emit a ``Choices: a, b, c.`` line.
        answer: If true, append the correct choice after ``A:``. Otherwise
            the fragment ends with ``A:`` so a model can complete it.

    Returns:
        The formatted fragment, without a trailing newline.
    """
    pieces = []
    if context is not None:
        pieces.append(f"Context: {context}\n")
    pieces.append(f"Q: {d['question']}\n")
    if include_choices:
        listed = ', '.join(d['choices'])
        pieces.append(f"Choices: {listed}.\n")
    pieces.append("A:")
    if answer:
        correct = d['choices'][d['correct_choice_idx']]
        pieces.append(f" {correct}")
    return ''.join(pieces)
if __name__ == '__main__':
    # Only run the CLI when executed as a script; importing this module
    # merely defines main() and prompt_element().
    main()