Spaces:
Runtime error
Runtime error
| import dataclasses | |
| import functools | |
| import logging | |
| import os | |
| import pickle | |
| import pprint | |
| import random | |
| from typing import List | |
# Directory whose listing supplies image filenames for _get_data(hb_split=False).
EMPTY_DATA_PATH: str = "tangram_pngs/"
# Directory holding the pickled test/train image-name splits read by
# _get_data(hb_split=True).
SPLIT_PATH: str = "dataset_splits/"
@dataclasses.dataclass
class GameConfig:
    """Speaker and listener views for a single game.

    This must be a dataclass for two reasons:

    - ``generate_game_config`` constructs instances with keyword arguments.
    - It logs them via ``dataclasses.asdict``.

    A plain class with only annotations supports neither.
    """

    # Context image filenames in the order the speaker sees them.
    speaker_context: List[str]
    # The same images as speaker_context, in the listener's shuffled order.
    listener_context: List[str]
    # Subset of the context images selected as targets.
    targets: List[str]
def generate_game_config(
    *,
    context_size: int = 10,
    min_targets: int = 3,
    max_targets: int = 5,
) -> GameConfig:
    """Sample a random game configuration from the image corpus.

    Steps:

    1. Draw ``context_size`` distinct images from ``_get_data()``.
    2. Pick between ``min_targets`` and ``max_targets`` (inclusive) of them
       as targets.
    3. Give the listener the same images in an independently shuffled
       order.

    With the default arguments the module-level ``random`` generator is
    consumed the same way as the historical hard-coded version, so seeded
    runs still yield the same configurations.

    Args:
        context_size: Number of images in each player's context.
        min_targets: Minimum number of targets (inclusive).
        max_targets: Maximum number of targets (inclusive).

    Returns:
        A ``GameConfig`` with these fields:

        - ``speaker_context``: the sampled images.
        - ``listener_context``: a permutation of those images.
        - ``targets``: a subset of them.

    Raises:
        ValueError: If the target bounds are inconsistent with each other
            or with ``context_size``. Also raised by ``random.sample`` if
            the corpus has fewer than ``context_size`` images.
    """
    # Validate before touching the RNG so a bad call does not advance it.
    if not 0 <= min_targets <= max_targets <= context_size:
        raise ValueError(
            "expected 0 <= min_targets <= max_targets <= context_size, got "
            f"min_targets={min_targets}, max_targets={max_targets}, "
            f"context_size={context_size}"
        )
    corpus = _get_data()
    context = random.sample(corpus, context_size)
    num_targets = random.randint(min_targets, max_targets)
    targets = random.sample(context, num_targets)
    # Shuffle an index list rather than `context` itself so speaker_context
    # keeps its sampled order; the length is derived from the context so the
    # two can never drift apart.
    listener_order = list(range(len(context)))
    random.shuffle(listener_order)
    config = GameConfig(
        speaker_context=context,
        listener_context=[context[i] for i in listener_order],
        targets=targets,
    )
    logging.info("context_dict: %s", pprint.pformat(dataclasses.asdict(config)))
    return config
# Filenames excluded from the corpus as duplicates.
_DUPLICATE_IMAGES = frozenset({"page6-51.png", "page6-66.png", "page4-170.png"})


@functools.lru_cache(maxsize=None)
def _load_image_names(hb_split: bool) -> tuple:
    """Read and filter the corpus filenames once per ``hb_split`` value.

    Returns an immutable tuple so the cached value cannot be mutated by
    callers.
    """
    if not hb_split:
        # 1013 images. os.listdir order is arbitrary and filesystem-dependent;
        # sort so that seeded random sampling is reproducible across machines.
        names = sorted(os.listdir(EMPTY_DATA_PATH))
    else:
        # 912 images: the test split followed by the train split. The pickled
        # names carry no extension, so ".png" is appended below.
        # NOTE(review): pickle.load can execute arbitrary code; these must be
        # trusted, locally shipped dataset files.
        names = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            with open(os.path.join(SPLIT_PATH, split_file), "rb") as f:
                # extend() accepts any iterable, so a split pickled as a
                # tuple works too (list += tuple would not for a tuple lhs).
                names.extend(pickle.load(f))
        names = [name + ".png" for name in names]
    return tuple(
        name
        for name in names
        if name != ".DS_Store" and name not in _DUPLICATE_IMAGES
    )


def _get_data(hb_split: bool = True) -> List[str]:
    """Return the list of image filenames that games are sampled from.

    Args:
        hb_split: If True (default), use the union of the pickled test and
            train splits under ``SPLIT_PATH``. Otherwise list every file in
            ``EMPTY_DATA_PATH``.

    Returns:
        A new list on every call, safe for the caller to mutate. The
        underlying files are read only once per ``hb_split`` value. macOS
        ``.DS_Store`` entries and the known duplicate images are excluded.

    Raises:
        OSError: If the split files or image directory cannot be read.
    """
    # Pass positionally so _get_data() and _get_data(True) share a cache entry.
    return list(_load_image_names(hb_split))