Spaces:
Runtime error
Runtime error
Commit
·
2f56479
0
Parent(s):
init
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +4 -0
- README.md +12 -0
- adapter.py +244 -0
- app.py +335 -0
- config_generator.py +51 -0
- dataset_splits/dev_imgs.pkl +0 -0
- dataset_splits/test_imgs.pkl +0 -0
- dataset_splits/train_imgs.pkl +0 -0
- requirements.txt +9 -0
- tangram_pngs/page-A.png +0 -0
- tangram_pngs/page-B.png +0 -0
- tangram_pngs/page-C.png +0 -0
- tangram_pngs/page-D.png +0 -0
- tangram_pngs/page-E.png +0 -0
- tangram_pngs/page-F.png +0 -0
- tangram_pngs/page-G.png +0 -0
- tangram_pngs/page-H.png +0 -0
- tangram_pngs/page-I.png +0 -0
- tangram_pngs/page-J.png +0 -0
- tangram_pngs/page-K.png +0 -0
- tangram_pngs/page-L.png +0 -0
- tangram_pngs/page1-0.png +0 -0
- tangram_pngs/page1-1.png +0 -0
- tangram_pngs/page1-10.png +0 -0
- tangram_pngs/page1-103.png +0 -0
- tangram_pngs/page1-105.png +0 -0
- tangram_pngs/page1-106.png +0 -0
- tangram_pngs/page1-107.png +0 -0
- tangram_pngs/page1-108.png +0 -0
- tangram_pngs/page1-109.png +0 -0
- tangram_pngs/page1-110.png +0 -0
- tangram_pngs/page1-112.png +0 -0
- tangram_pngs/page1-113.png +0 -0
- tangram_pngs/page1-114.png +0 -0
- tangram_pngs/page1-116.png +0 -0
- tangram_pngs/page1-117.png +0 -0
- tangram_pngs/page1-118.png +0 -0
- tangram_pngs/page1-119.png +0 -0
- tangram_pngs/page1-122.png +0 -0
- tangram_pngs/page1-125.png +0 -0
- tangram_pngs/page1-128.png +0 -0
- tangram_pngs/page1-129.png +0 -0
- tangram_pngs/page1-13.png +0 -0
- tangram_pngs/page1-130.png +0 -0
- tangram_pngs/page1-132.png +0 -0
- tangram_pngs/page1-133.png +0 -0
- tangram_pngs/page1-136.png +0 -0
- tangram_pngs/page1-137.png +0 -0
- tangram_pngs/page1-14.png +0 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.DS_Store
|
| 4 |
+
.vscode
|
README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Respect
|
| 3 |
+
emoji: 🫡
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
adapter.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
from functools import cache
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Set, Tuple, TypeVar
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from utils import device, nested_apply, sorted_list
|
| 11 |
+
|
| 12 |
+
# Grammar for a legal listener action: an optional deselect-list of letters,
# optionally followed by a select-list, or a select-list alone.
RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$'  # noqa


# Name type, newtype of str. e.g. "page4-249.png"
N = TypeVar('N')

ALPHABET = 'ABCDEFGHIJ'  # we only have 10 images
# Token ids the model is allowed to emit during generation.
LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413,
                   420, 475, 5339, 634, 17960, 32002]  # A - J and <end_of_utterance> and <\s> and 'select' and 'deselect'


# Debug aid: maps each legal token id to its surface string.
MINI_DECODER = {
    384: 'D',
    # 2: '</s>',
    32002: '<end_of_utterance>',
    420: 'G', 17960: 'elect',
    330: 'A', 365: 'B', 334: 'C', 5339: 'select', 401: 'F', 475: 'J',
    634: 'des', 315: 'I', 413: 'E', 382: 'H'}
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AlphabeticNameHash:
    """Bijective mapping between the images of one context and letters A-J.

    The i-th image name in ``context`` is assigned ``ALPHABET[i]``.

    Fix: the original decorated ``__init__`` with ``functools.cache``.  That
    decorator can never produce a cache hit there (every construction passes
    a fresh ``self``), and it pins every instance and context tuple in memory
    for the life of the process (ruff B019).  Removing it changes no
    observable behavior.
    """

    def __init__(self, context: List[N]) -> None:
        # Forward: image name -> letter.  Backward: letter -> image name.
        self._forward_map = {im: ALPHABET[i] for i, im in enumerate(context)}
        self._backward_map = {ALPHABET[i]: im for i, im in enumerate(context)}

    def hash(self, im: N) -> str:
        """Return the letter assigned to image ``im`` (KeyError if absent)."""
        return self._forward_map[im]

    def unhash(self, i: str) -> N:
        """Return the image name for letter ``i`` (KeyError if absent)."""
        return self._backward_map[i]

    def valid_hash(self, i: str) -> bool:
        """True iff ``i`` is a letter assigned to some image in the context."""
        return i in self._backward_map
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class IdeficsAdapter:
    """Bridge between the tangram game and the idefics2 processor/tokenizer.

    Converts game state (context images, chat history, per-turn selections)
    into model inputs, and parses generated text back into image clicks.
    """

    PAD_TOKEN_ID = 0
    LABEL_MASK_ID = 32001  # idefics2: image_token_id
    LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
    # Boolean mask over the 32003-entry vocab: True only at legal token ids.
    LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\
        .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
    # Complement of LEGAL_TOKEN_IDS; passed to generate(suppress_tokens=...).
    SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))

    def __init__(self, image_folder: str, processor) -> None:
        # Truncation limit (tokens) for processor output.
        self.t_max_length = 2048
        self.image_folder = Path(image_folder)
        # Lazily filled cache: image name -> opened PIL image.
        self.image_cache = {}
        self.processor = processor
        self.tokenizer = self.processor.tokenizer

    def get_image(self, im_name: N) -> Image.Image:
        """Return the PIL image for ``im_name``, caching opened files."""
        if im_name not in self.image_cache:
            self.image_cache[im_name] = Image.open(
                self.image_folder.joinpath(im_name))
        return self.image_cache[im_name]

    def unhash(self, context: List[N], c: str):
        """Map letter ``c`` back to an image name within ``context``."""
        return AlphabeticNameHash(tuple(context)).unhash(c)

    def valid_hash(self, context: List[N], c: str):
        """True iff ``c`` is a letter assigned to an image in ``context``."""
        return AlphabeticNameHash(tuple(context)).valid_hash(c)

    def parse(self, context: List[N], decoded_out: str,
              currently_selected: List[N]) -> List[str]:
        """Parse generated text into image names to click.

        Deselections of images that are not currently selected, and
        selections of images that already are, are dropped before the
        remaining letters are mapped back to image names.
        """
        h = AlphabeticNameHash(tuple(context))
        logging.debug(f"{context=}")
        # do inference
        logging.debug(f"{decoded_out=}")
        selection, deselection = self.parse_raw(decoded_out)

        hashed_currently_selected = {h.hash(n) for n in currently_selected}
        # Drop deselections of letters that are not actually selected.
        desel_to_remove = deselection - hashed_currently_selected
        if len(desel_to_remove) > 0:
            logging.debug(f"warn! {desel_to_remove=}")
            deselection = deselection - desel_to_remove

        # Drop selections of letters that are already selected.
        sel_to_remove = selection & hashed_currently_selected
        if len(sel_to_remove) > 0:
            logging.debug(f"warn! {sel_to_remove=}")
            selection = selection - sel_to_remove

        logging.debug("post strict cleaning")
        logging.debug(f"{selection=}")
        logging.debug(f"{deselection=}")

        model_clicks = selection | deselection
        logging.debug(f"{model_clicks=}")
        model_clicks_png = [h.unhash(n)
                            for n in model_clicks if h.valid_hash(n)]
        logging.debug(f"{model_clicks_png=}")
        return model_clicks_png

    @staticmethod
    def parse_raw(text: str) -> Tuple[Set[N], Set[N]]:
        """Extract (selections, deselections) letter sets from raw text.

        Only the portion after the final ':' (if any) is considered.
        Returns two empty sets when that portion fails RE_PATTERN.
        """
        last_answer = text.strip()
        if ":" in text:
            last_answer_pattern = r":.*$"
            xs = re.findall(last_answer_pattern, text)
            last_answer = xs[0].removeprefix(":").strip()
        xs = re.search(RE_PATTERN, last_answer)
        if xs is None:
            print(f"{last_answer=}")
            print("did not pass regex")
            return set(), set()

        # Negative lookbehind keeps "deselect ..." from matching "select ...".
        select_pattern = r"(?<!de)select( [A-J])+$"
        xs = re.search(select_pattern, last_answer)
        if xs is not None:
            xs = xs.group()
        selections = set(xs.split(" ")[1:]) if xs else set()

        deselect_pattern = r"^deselect( [A-J])+"
        xs = re.search(deselect_pattern, last_answer)
        if xs is not None:
            xs = xs.group()
        deselections = set(xs.split(" ")[1:]) if xs else set()

        return selections, deselections

    def compose(self, context, chats, previous_selected, hash_images, padding):
        """Build tensor inputs for the next model turn.

        ``previous_selected`` holds the cumulative selection after each past
        turn; it is unfolded into per-turn select/deselect actions first.
        """
        select_accum, deselect_accum, clickss = self.unfold_select_deselect(
            previous_selected)

        # Append an empty action slot for the turn being generated, and
        # prepend the empty pre-game state so lists align turn-for-turn.
        select_accum = select_accum + [[]]
        deselect_accum = deselect_accum + [[]]
        previous_selected = [[]] + previous_selected  # old states pre click
        assert len(chats) == len(select_accum) == len(
            deselect_accum) == len(previous_selected)

        messages, images = self.build_processor_input(
            context, chats, select_accum, deselect_accum, previous_selected, hash_images, omit_last_answer=True, sort_names=True, omit_context=False, chat_feedback=None)
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True)
        prompt = prompt.strip()
        logging.debug(prompt)
        # Keep consistent with train_script
        inputs = self.processor(
            text=prompt, images=images,
            padding=padding, truncation=True, max_length=self.t_max_length,
            return_tensors="pt")
        return inputs

    def build_processor_input(self, image_pngs: List[N], chats: List[str],
                              select_accum: List[List[N]],
                              deselect_accum: List[List[N]],
                              pre_click_selected_accum: List[List[N]],
                              hash_image: bool, omit_last_answer: bool,
                              sort_names: bool, omit_context: bool,
                              chat_feedback: str, ):
        """Assemble the chat-template message list plus the image list."""
        def _text_content(text): return {"type": "text", "text": text}

        def _image_content(): return {"type": "image"}

        def _user_prompt(content): return {"role": "user", "content": content}

        def _assistant_prompt(content): return {
            "role": "assistant", "content": content}

        def _system_prompt(content): return {
            "role": "system", "content": content}

        def _current_state(selected: List[N]):
            # Human-readable summary of the selection before a turn.
            if len(selected) == 0:
                return 'none is selected'
            return f'{" ".join(selected)} currently selected'

        def _listener_action(select: List[N], deselect: List[N]):
            # Render a turn's action in "deselect ... select ..." format.
            if len(select) == 0 and len(deselect) == 0:
                return 'nothing'
            if len(select) == 0:
                return f'deselect {" ".join(deselect)}'
            if len(deselect) == 0:
                return f'select {" ".join(select)}'
            return f'deselect {" ".join(deselect)} select {" ".join(select)}'

        # Optionally replace image names by their single-letter hashes.
        func = AlphabeticNameHash(tuple(image_pngs)).hash if hash_image else id
        context, select_accum, deselect_accum, pre_click_selected_accum = nested_apply(
            func, (image_pngs, select_accum, deselect_accum, pre_click_selected_accum))

        prompt = []
        images = []
        if not omit_context:
            # System turn interleaving each context image with its name.
            images = [self.get_image(im) for im in image_pngs]
            images_and_names_content = []
            for im_name in context:
                images_and_names_content.append(_image_content())
                images_and_names_content.append(_text_content(im_name))
            prompt.append(_system_prompt(images_and_names_content))
        if not len(chats) == len(select_accum) == len(deselect_accum) == len(pre_click_selected_accum):
            logging.error(f"{chats=}")
            logging.error(f"{select_accum=}")
            logging.error(f"{deselect_accum=}")
            logging.error(f"{pre_click_selected_accum=}")
            assert False
        for i, (chat, select, deselect, pre_click_selected) in enumerate(
                zip(chats, select_accum, deselect_accum, pre_click_selected_accum)):
            if sort_names:
                select = sorted(select)
                deselect = sorted(deselect)
                pre_click_selected = sorted(pre_click_selected)

            prompt.append(_system_prompt(
                [_text_content(_current_state(pre_click_selected))]))
            prompt.append(_user_prompt([_text_content(chat)]))
            prompt.append(_assistant_prompt(
                [_text_content(_listener_action(select, deselect))]))
        if omit_last_answer:
            # idefics2 has processor.apply_chat_template(messages, add_generation_prompt=True) instead
            prompt.pop(-1)
        if chat_feedback is not None:
            prompt.append(_user_prompt([_text_content(chat_feedback)]))
        return prompt, images

    def unfold_select_deselect(self, previous_selected: List[List[N]]) -> Tuple[List[N], List[N], List[N]]:
        """Diff consecutive cumulative selection states into per-turn actions."""
        # currently selected AFTER i-th turn
        num_turns = len(previous_selected)
        selected: List[List[str]] = []  # turn-wise selection
        deselected: List[List[str]] = []  # turn-wise deselection
        clicks: List[List[str]] = []
        # combining turn-wise newly selected and newly deselected
        prev_selected = set()
        for turn in range(num_turns):
            curr_selected = set(previous_selected[turn])
            newly_selected = curr_selected - prev_selected
            newly_deselected = prev_selected - curr_selected
            selected.append(sorted_list(newly_selected))
            deselected.append(sorted_list(newly_deselected))
            clicks.append(sorted_list(newly_selected | newly_deselected))
            prev_selected = curr_selected.copy()
        return selected, deselected, clicks
app.py
ADDED
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
import gradio as gr # type: ignore
|
| 7 |
+
import PIL.Image as Image
|
| 8 |
+
import PIL.ImageOps as ImageOps
|
| 9 |
+
import spaces # type: ignore
|
| 10 |
+
import torch
|
| 11 |
+
from peft import PeftModel # type: ignore
|
| 12 |
+
from transformers import AutoProcessor # type: ignore
|
| 13 |
+
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor
|
| 14 |
+
|
| 15 |
+
from adapter import IdeficsAdapter
|
| 16 |
+
from config_generator import GameConfig, generate_game_config
|
| 17 |
+
from utils import device, nested_to_device, sorted_list
|
| 18 |
+
import copy
|
| 19 |
+
|
| 20 |
+
### Constants
# CSS that lays the radio options out on a 5x5 grid.
css="""
.radio-group .wrap {
  display: grid;
  grid-template-columns: repeat(5, 1fr);
  grid-template-rows: repeat(5, 1fr);
  width: 100%;
  height: 100%
}
"""
IMG_DIR = "tangram_pngs"


### Bot server

# Generation settings shared by every model call; suppress_tokens restricts
# decoding to the adapter's legal action vocabulary.
GEN_KWS: Dict[str, Any] = {
    "max_new_tokens": 10,
    "do_sample": True,
    "temperature": 1.0,
    "output_logits": True,
    "return_dict_in_generate": True,
    "remove_invalid_values": True,  # just to be safe
    "renormalize_logits": True,
    "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS
}
|
| 45 |
+
|
| 46 |
+
@spaces.GPU(duration=20)
def get_model_response(  # predict
    model: PeftModel, adapter_name: str, adapter: IdeficsAdapter,
    image_paths: List[str], chat : str, chats: List[str],
    previous_selected: List[List[str]]
) -> List[str]:
    """Run one listener turn on the GPU; return the image names to click.

    Falls back to clicking the first context image when the generated
    text parses into no valid clicks.
    """
    # Switch LoRA adapters only when needed; set_adapter is not free.
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)

    model.to(device())

    new_chats = chats + [chat]
    currently_selected = previous_selected[-1] if len(previous_selected) > 0 else []
    model_input: Dict[str, Any] = adapter.compose(  # type: ignore
        image_paths, new_chats, previous_selected, True, False)
    model_input = nested_to_device(model_input)  # type: ignore

    # Inference mode + bfloat16 autocast for memory/compute efficiency.
    with torch.inference_mode(), torch.autocast(device_type=device().type,
                                                dtype=torch.bfloat16):
        model_output = model.generate(**model_input, **GEN_KWS)  # type: ignore

    decoded_out: str = adapter.tokenizer.decode(  # type: ignore
        model_output.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(
        image_paths, decoded_out, currently_selected)  # type: ignore

    if len(model_clicks) == 0:
        logging.warning("empty clicks by model")
        # Fallback so the game always advances with at least one click.
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")
        prob = -1
    else:
        prob = -3
    # NOTE(review): prob is only a debug sentinel here (-1 fallback, -3
    # normal); no probability is actually computed — confirm intent.
    logging.debug(f"{prob=}")
    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_model() -> PeftModel:
    """Load idefics2-8b with two LoRA adapters attached.

    "r6_bp" (final system) is loaded first; "r0" (initial system) is then
    added with an explicit target-module list to avoid a config conflict.
    """
    model_id = 'lil-lab/respect'
    checkpoint = "HuggingFaceM4/idefics2-8b"
    model = Idefics2ForConditionalGeneration.from_pretrained(  # type: ignore
        checkpoint, torch_dtype=torch.bfloat16,
    )
    peft_model = PeftModel.from_pretrained(  # type: ignore
        model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")

    # Add other adapter - hack to avoid conflict
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    # Recover the module names targeted by the first adapter's LoRA layers
    # (strip the trailing ".lora..." suffix from each parameter name).
    targets = list(set(n[:n.find('lora')-1] for n, _ in model.named_parameters()
                       if 'lora' in n))
    lora_config.target_modules = targets
    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(model_id, "r0", is_trainable=False, revision="r0",
                            peft_config=lora_config)
    return peft_model
|
| 105 |
+
|
| 106 |
+
def get_processor() -> Idefics2Processor:
    """Load the idefics2-8b processor, configured for single 224px images."""
    return AutoProcessor.from_pretrained(  # type: ignore
        "HuggingFaceM4/idefics2-8b",
        do_image_splitting=False,
        size={"longest_edge": 224, "shortest_edge": 224},
    )
|
| 112 |
+
|
| 113 |
+
def get_adapter() -> IdeficsAdapter:
    """Build an IdeficsAdapter over the local tangram image folder."""
    return IdeficsAdapter(IMG_DIR, get_processor())
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
### Game logic
|
| 119 |
+
|
| 120 |
+
@dataclasses.dataclass(frozen=False)
class GameState:
    """Snapshot of one game: fixed configuration plus per-turn history."""

    config: GameConfig  # image contexts and target set for this game
    adapter_name: str  # which LoRA adapter plays listener ("r0" or "r6_bp")
    chats: List[str]  # one speaker message per completed turn
    currently_selected: List[str]  # listener selection after the last turn
    selected_accum: List[List[str]]  # selection state after each turn
    clicks_accum: List[List[str]]  # raw model clicks per turn
    turn: int = 0  # number of completed turns

    def has_ended(self):
        """Game ends on success or once 10 turns have been played."""
        return self.has_successfully_ended() or self.turn >= 10

    def has_successfully_ended(self):
        """True when the selection matches the target set exactly."""
        return set(self.currently_selected) == set(self.config.targets)

    ### UI helpers

    def serialize_conversation(self):
        """Render the chat history, one numbered line per turn."""
        output = [f"Turn {i+1}: {message}"
                  for i, message in enumerate(self.chats)]
        return "\n".join(output)

    def markup_images(self):
        """Return the speaker-view context images with status borders."""
        context = self.config.speaker_context
        targets = self.config.targets
        selected = self.currently_selected
        # NOTE(review): "changes" is taken from the latest cumulative
        # selection; clicks_accum holds the per-turn clicks — confirm
        # which should drive the yellow-dot marker.
        changes = self.selected_accum[-1] if len(self.selected_accum) > 0 else []

        tangram_list = self._display_context(context, targets, changes, selected)
        # return [(img, f"Image {i+1}") for i, img in enumerate(tangram_list)]
        return tangram_list

    @staticmethod
    def _display_context(context: List[str], targets: List[str],
                         changes: List[str], selected: List[str]) -> List[Image.Image]:
        """Border colors: green = correctly selected target, black = missed
        target, red = wrongly selected, white = neither; a yellow dot marks
        images listed in ``changes``."""
        tangram_list: List[Image.Image] = []
        arrow = Image.open("yellow_circle.png").resize((20, 20)).convert("RGBA")
        for img in context:
            image = Image.open(os.path.join(IMG_DIR, img)).resize((60, 60)).convert("RGB")
            image = ImageOps.expand(image, border=2, fill="white")
            if img in targets and img in selected:  # listener selected a target image
                image = ImageOps.expand(image, border=10, fill="green")
            elif img in targets and img not in selected:  # unselected target:
                image = ImageOps.expand(image, border=10, fill="black")
            elif img in selected and img not in targets:  # listener selected a wrong image
                image = ImageOps.expand(image, border=10, fill="red")
            else:
                image = ImageOps.expand(image, border=10, fill="white")
            image = ImageOps.expand(image, border=2, fill="white")
            if img in changes:
                image.paste(arrow, (68, 0), mask=arrow)
            tangram_list.append(image)
        return tangram_list
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class GameFlow:
    """Stateless transition functions over GameState."""

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a fresh game against the chosen system iteration."""
        chosen_adapter = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=generate_game_config(),
            adapter_name=chosen_adapter,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str,
                 model: PeftModel,
                 adapter: IdeficsAdapter) -> GameState:
        """Play one turn: query the model, toggle its clicks, advance state."""
        clicks = get_model_response(
            model, state.adapter_name, adapter,
            state.config.listener_context, chat,
            state.chats, state.selected_accum
        )

        # Each clicked image flips its selection state — this is exactly
        # the symmetric difference of the old selection and the clicks.
        next_selection = sorted_list(
            set(state.currently_selected) ^ set(clicks)
        )

        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates
            chats=state.chats + [chat],
            currently_selected=next_selection,
            selected_accum=state.selected_accum + [next_selection],
            clicks_accum=state.clicks_accum + [clicks],
            turn=state.turn + 1,
        )
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
### UI
|
| 228 |
+
|
| 229 |
+
def create_app_inner():
|
| 230 |
+
### layout
|
| 231 |
+
gr.Markdown("# Tangram Multi-Reference Game")
|
| 232 |
+
gr.Markdown(
|
| 233 |
+
'### You will be playing a multi-reference games against a model. \
|
| 234 |
+
To start a game, first select whether you wish to play against our \
|
| 235 |
+
initial trained model ("Initial System") or \
|
| 236 |
+
our model at the end of continual learning ("Final System") \
|
| 237 |
+
and press the "Start Game" button. \
|
| 238 |
+
You will take on a "speaker" role at each round. \
|
| 239 |
+
Your goal is to describe this image (via a message in the textbox) \
|
| 240 |
+
so that the model can guess what it is.'
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
gr.Markdown("Targets have black borders. Correctly selected targets have green borders. Incorrectly selected targets have red borders. Actions are marked with yellow dot.")
|
| 244 |
+
|
| 245 |
+
gr.Markdown("The listener cannot see boxes or colors and the order is different.")
|
| 246 |
+
|
| 247 |
+
gr.Markdown(
|
| 248 |
+
'### Press "Send" to submit your action to proceed to the next turn. \
|
| 249 |
+
You have 10 turns in total.'
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
with gr.Row():
|
| 253 |
+
model_iteration = gr.Radio(["Initial System", "Final System"],
|
| 254 |
+
label="Model Iteration",
|
| 255 |
+
value="Final System")
|
| 256 |
+
start_btn = gr.Button("Start Game")
|
| 257 |
+
|
| 258 |
+
with gr.Row():
|
| 259 |
+
current_turn = gr.Textbox(label="TURN")
|
| 260 |
+
success = gr.Textbox(label="Success")
|
| 261 |
+
|
| 262 |
+
with gr.Row():
|
| 263 |
+
image_output = gr.Gallery(
|
| 264 |
+
label="CONTEXT", show_label=False, elem_id="gallery",
|
| 265 |
+
columns=5, rows=2, object_fit="contain", height="250px",
|
| 266 |
+
allow_preview=False, container=True, interactive=False
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
with gr.Row():
|
| 270 |
+
conversation_output = gr.Textbox(label="Interaction History")
|
| 271 |
+
user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
|
| 272 |
+
|
| 273 |
+
send_btn = gr.Button("Send", interactive=True)
|
| 274 |
+
|
| 275 |
+
### globals
|
| 276 |
+
model = get_model()
|
| 277 |
+
adapter = get_adapter()
|
| 278 |
+
game_state = gr.State(value=None)
|
| 279 |
+
|
| 280 |
+
### callbacks
|
| 281 |
+
def output_from_state(state: GameState):
|
| 282 |
+
has_ended = state.has_ended()
|
| 283 |
+
success = "success" if state.has_successfully_ended() else "failure"
|
| 284 |
+
return (
|
| 285 |
+
state.markup_images(), # image_output
|
| 286 |
+
state.serialize_conversation(), # conversation_output
|
| 287 |
+
f"{state.turn+1}/10", # current_turn
|
| 288 |
+
success if has_ended else "n/a", # success
|
| 289 |
+
gr.update(interactive=not has_ended, value=""), # user_input
|
| 290 |
+
gr.update(interactive=not has_ended), # send_btn
|
| 291 |
+
gr.update(interactive=has_ended), # model_iteration
|
| 292 |
+
state, # game_history
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def on_start_interaction(model_iteration: str):
|
| 296 |
+
assert model_iteration in ["Initial System", "Final System"]
|
| 297 |
+
state = GameFlow.initialize(model_iteration)
|
| 298 |
+
return output_from_state(state)
|
| 299 |
+
|
| 300 |
+
def on_send_message(message: str, state: GameState):
    """Advance the game by one turn using the speaker's message.

    Returns the full output tuple for the bound UI components
    (see output_from_state).
    """
    # Bound to the enclosing scope: the loaded model and its adapter
    # are shared across all callbacks rather than reloaded per turn.
    nonlocal model
    nonlocal adapter
    # Ignore empty / whitespace-only submissions: re-render the
    # current state unchanged instead of consuming a turn.
    if message.strip() == "":
        logging.info("Empty message")
        return output_from_state(state)
    state = GameFlow.progress(state, message, model, adapter)
    return output_from_state(state)
|
| 308 |
+
|
| 309 |
+
start_btn.click(
|
| 310 |
+
on_start_interaction,
|
| 311 |
+
inputs=[model_iteration],
|
| 312 |
+
outputs=[image_output, conversation_output, current_turn, success,
|
| 313 |
+
user_input, send_btn, model_iteration, game_state],
|
| 314 |
+
queue=False
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
send_btn.click(
|
| 318 |
+
on_send_message,
|
| 319 |
+
inputs=[user_input, game_state],
|
| 320 |
+
outputs=[image_output, conversation_output, current_turn, success,
|
| 321 |
+
user_input, send_btn, model_iteration, game_state],
|
| 322 |
+
queue=True
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def create_app():
    """Build and return the Gradio Blocks app with the shared CSS applied."""
    with gr.Blocks(css=css) as app:
        # All components and callbacks are registered inside this context.
        create_app_inner()
    return app
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
    app = create_app()
    # Enable request queueing so concurrent users are served in order.
    app.queue()
    app.launch()
|
config_generator.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import functools
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import pprint
|
| 7 |
+
import random
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
EMPTY_DATA_PATH = "tangram_pngs/"
|
| 11 |
+
SPLIT_PATH = "dataset_splits/"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclasses.dataclass(frozen=True)
class GameConfig:
    """Immutable configuration for one reference game."""

    # Image filenames in the order shown to the speaker.
    speaker_context: List[str]
    # The same images, in the (shuffled) order shown to the listener.
    listener_context: List[str]
    # Subset of speaker_context that the speaker must get the listener to pick.
    targets: List[str]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def generate_game_config() -> GameConfig:
    """Sample a fresh game configuration.

    Draws 10 images from the corpus as the shared context, selects
    3-5 of them as targets, and shuffles the listener's view of the
    same context independently of the speaker's order.

    Returns:
        A frozen GameConfig with speaker/listener contexts and targets.
    """
    corpus = _get_data()
    context = random.sample(corpus, 10)
    num_targets = random.randint(3, 5)
    targets = random.sample(context, num_targets)
    # Derive the permutation length from the context instead of
    # repeating the hard-coded 10.
    listener_order = list(range(len(context)))
    random.shuffle(listener_order)

    config = GameConfig(
        speaker_context=context,
        listener_context=[context[i] for i in listener_order],
        targets=targets,
    )
    # Lazy %-style args: string interpolation is deferred to the
    # logging framework (idiomatic for logging calls).
    logging.info("context_dict: %s", pprint.pformat(dataclasses.asdict(config)))
    return config
|
| 36 |
+
|
| 37 |
+
@functools.cache
def _get_data(restricted_dataset: bool = False) -> List[str]:
    """Return the list of usable tangram PNG filenames.

    Cached: the directory listing / split pickles are read at most once
    per flag value for the process lifetime.

    Args:
        restricted_dataset: if True, use only the train+test split
            (912 images) instead of everything in EMPTY_DATA_PATH
            (1013 images).
    """
    if not restricted_dataset:
        # 1013 images
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        # 912 images: concatenate the test and train splits.
        paths = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            with open(os.path.join(SPLIT_PATH, split_file), 'rb') as f:
                paths += pickle.load(f)
        paths = [path + ".png" for path in paths]
    # Drop macOS metadata and known duplicate tangrams; a set gives
    # O(1) membership tests instead of scanning a list per path.
    excluded = {".DS_Store", "page6-51.png", "page6-66.png", "page4-170.png"}
    return [path for path in paths if path not in excluded]
|
dataset_splits/dev_imgs.pkl
ADDED
|
Binary file (1.18 kB). View file
|
|
|
dataset_splits/test_imgs.pkl
ADDED
|
Binary file (5.26 kB). View file
|
|
|
dataset_splits/train_imgs.pkl
ADDED
|
Binary file (5.26 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.2.0
|
| 2 |
+
datasets==2.18.0
|
| 3 |
+
transformers==4.40.0
|
| 4 |
+
accelerate==0.29.2
|
| 5 |
+
loralib==0.1.2
|
| 6 |
+
peft==0.10.0
|
| 7 |
+
nltk==3.8.1
|
| 8 |
+
gradio==4.44.1
|
| 9 |
+
spaces==0.30.4
|
tangram_pngs/page-A.png
ADDED
|
tangram_pngs/page-B.png
ADDED
|
tangram_pngs/page-C.png
ADDED
|
tangram_pngs/page-D.png
ADDED
|
tangram_pngs/page-E.png
ADDED
|
tangram_pngs/page-F.png
ADDED
|
tangram_pngs/page-G.png
ADDED
|
tangram_pngs/page-H.png
ADDED
|
tangram_pngs/page-I.png
ADDED
|
tangram_pngs/page-J.png
ADDED
|
tangram_pngs/page-K.png
ADDED
|
tangram_pngs/page-L.png
ADDED
|
tangram_pngs/page1-0.png
ADDED
|
tangram_pngs/page1-1.png
ADDED
|
tangram_pngs/page1-10.png
ADDED
|
tangram_pngs/page1-103.png
ADDED
|
tangram_pngs/page1-105.png
ADDED
|
tangram_pngs/page1-106.png
ADDED
|
tangram_pngs/page1-107.png
ADDED
|
tangram_pngs/page1-108.png
ADDED
|
tangram_pngs/page1-109.png
ADDED
|
tangram_pngs/page1-110.png
ADDED
|
tangram_pngs/page1-112.png
ADDED
|
tangram_pngs/page1-113.png
ADDED
|
tangram_pngs/page1-114.png
ADDED
|
tangram_pngs/page1-116.png
ADDED
|
tangram_pngs/page1-117.png
ADDED
|
tangram_pngs/page1-118.png
ADDED
|
tangram_pngs/page1-119.png
ADDED
|
tangram_pngs/page1-122.png
ADDED
|
tangram_pngs/page1-125.png
ADDED
|
tangram_pngs/page1-128.png
ADDED
|
tangram_pngs/page1-129.png
ADDED
|
tangram_pngs/page1-13.png
ADDED
|
tangram_pngs/page1-130.png
ADDED
|
tangram_pngs/page1-132.png
ADDED
|
tangram_pngs/page1-133.png
ADDED
|
tangram_pngs/page1-136.png
ADDED
|
tangram_pngs/page1-137.png
ADDED
|
tangram_pngs/page1-14.png
ADDED
|