Spaces:
Runtime error
Runtime error
| import logging | |
| import re | |
| from functools import cache | |
| from pathlib import Path | |
| from typing import List, Set, Tuple, TypeVar | |
| import torch | |
| from PIL import Image | |
| from transformers import Idefics2Processor, PreTrainedTokenizer | |
| from utils import device, nested_apply, sorted_list | |
# One or more capital letters separated by single whitespace characters.
_LETTER_RUN = r'[A-Z](?:\s[A-Z])*'
# A complete model answer is one of:
#   "select <letters>", "deselect <letters>",
#   or "deselect <letters> select <letters>".
RE_PATTERN = rf'^(deselect\s{_LETTER_RUN}(?:\sselect\s{_LETTER_RUN})?|select\s{_LETTER_RUN})$'

# Name type, newtype of str. e.g. "page4-249.png"
N = TypeVar('N')

# One letter per context image; we only have 10 images.
ALPHABET = 'ABCDEFGHIJ'

# Token ids the model may emit: the letters A - J, <end_of_utterance>, </s>,
# and the pieces that spell 'select' / 'deselect'.
LEGAL_TOKEN_IDS = [
    2,
    315,
    330,
    334,
    365,
    382,
    384,
    401,
    413,
    420,
    475,
    5339,
    634,
    17960,
    32002,
]

# Token id -> text piece for the legal tokens. Token 2 ('</s>') is
# deliberately not mapped here. Insertion order is kept as originally written.
MINI_DECODER = {
    384: 'D',
    32002: '<end_of_utterance>',
    420: 'G',
    17960: 'elect',
    330: 'A',
    365: 'B',
    334: 'C',
    5339: 'select',
    401: 'F',
    475: 'J',
    634: 'des',
    315: 'I',
    413: 'E',
    382: 'H',
}
class AlphabeticNameHash:
    """Two-way lookup between context items and letters of ``ALPHABET``.

    The item at position ``k`` of the context is paired with ``ALPHABET[k]``.
    A context with more items than ``ALPHABET`` has letters raises
    ``IndexError`` during construction.
    """

    def __init__(self, context: List[N]) -> None:
        """Build the item->letter and letter->item tables for ``context``."""
        letter_by_name = {}
        name_by_letter = {}
        for position, name in enumerate(context):
            letter = ALPHABET[position]
            letter_by_name[name] = letter
            name_by_letter[letter] = name
        self._forward_map = letter_by_name
        self._backward_map = name_by_letter

    def hash(self, im: N) -> str:
        """Return the letter assigned to ``im``; ``KeyError`` if unknown."""
        return self._forward_map[im]

    def unhash(self, i: str) -> N:
        """Return the context item assigned to letter ``i``; ``KeyError`` if unknown."""
        return self._backward_map[i]

    def valid_hash(self, i: str) -> bool:
        """Tell whether letter ``i`` is assigned to some context item."""
        return i in self._backward_map
class IdeficsAdapter:
    """Bridge between the image-selection game state and an Idefics2 model.

    ``compose`` / ``build_processor_input`` turn a context of images, the chat
    turns and the selection history into processor inputs. ``parse`` /
    ``parse_raw`` turn decoded model text such as ``"deselect A select B C"``
    back into the context images that were clicked.
    """

    PAD_TOKEN_ID = 0
    LABEL_MASK_ID = 32001  # idefics2: image_token_id
    LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
    # Boolean mask over 32003 token ids, True exactly at LEGAL_TOKEN_IDS.
    # NOTE(review): 32003 is presumably the idefics2 vocabulary size - confirm.
    LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\
        .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
    # Every id in range(32003) that is not a legal token.
    SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))

    def __init__(self, image_folder: str, processor: Idefics2Processor) -> None:
        """Store the image folder, the processor and its tokenizer.

        Args:
            image_folder: Directory that image names are resolved against.
            processor: Idefics2 processor used for chat templating and encoding.
        """
        self.t_max_length = 2048  # max_length handed to the processor in compose()
        self.image_folder = Path(image_folder)
        self.image_cache = {}  # image name -> opened PIL image
        self.processor = processor
        self.tokenizer: PreTrainedTokenizer = self.processor.tokenizer  # type: ignore

    def get_image(self, im_name: N) -> Image.Image:
        """Return the image named ``im_name``, opening it only on first use."""
        if im_name not in self.image_cache:
            self.image_cache[im_name] = Image.open(
                self.image_folder.joinpath(im_name))
        return self.image_cache[im_name]

    def unhash(self, context: List[N], c: str) -> N:
        """Map letter ``c`` back to the image name it stands for in ``context``."""
        return AlphabeticNameHash(tuple(context)).unhash(c)

    def valid_hash(self, context: List[N], c: str) -> bool:
        """Tell whether letter ``c`` refers to an image of ``context``."""
        return AlphabeticNameHash(tuple(context)).valid_hash(c)

    def parse(self, context: List[N], decoded_out: str,
              currently_selected: List[N]) -> List[N]:
        """Convert decoded model text into the list of clicked image names.

        Deselections of images that are not currently selected, and selections
        of images that already are, are dropped (and logged at debug level).

        Args:
            context: Image names of the current context, in display order.
            decoded_out: Decoded model output text.
            currently_selected: Image names selected before this turn.

        Returns:
            Names of the images to click (select or deselect). The order is
            unspecified because it comes from iterating a set.

        Raises:
            KeyError: If ``currently_selected`` holds a name not in ``context``.
        """
        h = AlphabeticNameHash(tuple(context))
        logging.debug(f"{context=}")
        logging.debug(f"{decoded_out=}")
        selection, deselection = self.parse_raw(decoded_out)

        hashed_currently_selected = {h.hash(n) for n in currently_selected}

        # Cannot deselect something that is not selected.
        desel_to_remove = deselection - hashed_currently_selected
        if desel_to_remove:
            logging.debug(f"warn! {desel_to_remove=}")
            deselection = deselection - desel_to_remove

        # Cannot select something that is already selected.
        sel_to_remove = selection & hashed_currently_selected
        if sel_to_remove:
            logging.debug(f"warn! {sel_to_remove=}")
            selection = selection - sel_to_remove

        logging.debug("post strict cleaning")
        logging.debug(f"{selection=}")
        logging.debug(f"{deselection=}")

        model_clicks = selection | deselection
        logging.debug(f"{model_clicks=}")
        # Letters that do not belong to this context are silently skipped.
        model_clicks_png = [h.unhash(n)
                            for n in model_clicks if h.valid_hash(n)]
        logging.debug(f"{model_clicks_png=}")
        return model_clicks_png

    @staticmethod
    def parse_raw(text: str) -> Tuple[Set[str], Set[str]]:
        """Extract the selected and deselected letters from model text.

        When the text contains a colon, only what follows the first colon on
        the final line is considered. The answer must match ``RE_PATTERN`` in
        full; otherwise two empty sets are returned.

        Declared as a static method because it takes no ``self`` yet is called
        as ``self.parse_raw(...)`` from ``parse``.

        Args:
            text: Decoded model output.

        Returns:
            ``(selections, deselections)`` as sets of single letters.
        """
        last_answer = text.strip()
        if ":" in text:
            # Without re.MULTILINE/DOTALL this only matches on the final line,
            # so a colon on an earlier line yields no match at all; fall back
            # to the whole stripped text instead of indexing an empty list.
            tail = re.findall(r":.*$", text)
            if tail:
                last_answer = tail[0].removeprefix(":").strip()

        if re.search(RE_PATTERN, last_answer) is None:
            logging.debug(f"{last_answer=}")
            logging.debug("did not pass regex")
            return set(), set()

        # "select ..." must not be the tail of "deselect ...".
        select_match = re.search(r"(?<!de)select( [A-J])+$", last_answer)
        selections: Set[str] = (
            set(select_match.group().split(" ")[1:]) if select_match else set())

        deselect_match = re.search(r"^deselect( [A-J])+", last_answer)
        deselections: Set[str] = (
            set(deselect_match.group().split(" ")[1:]) if deselect_match else set())

        return selections, deselections

    def compose(self, context, chats, previous_selected, hash_images, padding):
        """Build processor inputs for predicting the answer to the last chat.

        Args:
            context: Image names of the current context.
            chats: One utterance per turn, including the turn to be answered;
                must be one longer than ``previous_selected``.
            previous_selected: Selected image names after each completed turn.
            hash_images: Whether image names are replaced by letters.
            padding: Forwarded to the processor's ``padding`` argument.

        Returns:
            Whatever ``self.processor(...)`` returns with ``return_tensors="pt"``.

        Raises:
            AssertionError: If the per-turn sequences differ in length.
        """
        select_accum, deselect_accum, _ = self.unfold_select_deselect(
            previous_selected)
        # The turn being answered has no known action yet.
        select_accum = select_accum + [[]]
        deselect_accum = deselect_accum + [[]]
        previous_selected = [[]] + previous_selected  # old states pre click
        if not (len(chats) == len(select_accum) == len(deselect_accum)
                == len(previous_selected)):
            # Explicit raise so the check survives `python -O`.
            raise AssertionError(
                "chats must contain exactly one more turn than previous_selected")

        messages, images = self.build_processor_input(
            context, chats, select_accum, deselect_accum, previous_selected,
            hash_images, omit_last_answer=True, sort_names=True,
            omit_context=False, chat_feedback=None)
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True)
        prompt = prompt.strip()
        logging.debug(prompt)
        # Keep consistent with train_script
        inputs = self.processor(
            text=prompt, images=images,
            padding=padding, truncation=True, max_length=self.t_max_length,
            return_tensors="pt")
        return inputs

    def build_processor_input(self, image_pngs: List[N], chats: List[str],
                              select_accum: List[List[N]],
                              deselect_accum: List[List[N]],
                              pre_click_selected_accum: List[List[N]],
                              hash_image: bool, omit_last_answer: bool,
                              sort_names: bool, omit_context: bool,
                              chat_feedback: str, ):
        """Assemble chat-template messages and the matching image list.

        Each turn contributes three messages: a system message with the
        selection state before the click, the user utterance, and the
        assistant action.

        Args:
            image_pngs: Image names of the context.
            chats: User utterance of every turn.
            select_accum: Newly selected names per turn.
            deselect_accum: Newly deselected names per turn.
            pre_click_selected_accum: Names selected before each turn's click.
            hash_image: Replace names by letters when true; keep them otherwise.
            omit_last_answer: Drop the final assistant message.
            sort_names: Sort the names inside every message.
            omit_context: Skip the leading image/name system message (and
                return no images).
            chat_feedback: Optional extra user message appended at the end;
                ``None`` to skip.

        Returns:
            ``(messages, images)``: the message dicts and the PIL images.

        Raises:
            AssertionError: If the per-turn sequences differ in length.
        """
        def _text_content(text):
            return {"type": "text", "text": text}

        def _image_content():
            return {"type": "image"}

        def _user_prompt(content):
            return {"role": "user", "content": content}

        def _assistant_prompt(content):
            return {"role": "assistant", "content": content}

        def _system_prompt(content):
            return {"role": "system", "content": content}

        def _current_state(selected: List[N]):
            if len(selected) == 0:
                return 'none is selected'
            return f'{" ".join(selected)} currently selected'

        def _listener_action(select: List[N], deselect: List[N]):
            if len(select) == 0 and len(deselect) == 0:
                return 'nothing'
            if len(select) == 0:
                return f'deselect {" ".join(deselect)}'
            if len(deselect) == 0:
                return f'select {" ".join(select)}'
            return f'deselect {" ".join(deselect)} select {" ".join(select)}'

        def _keep_name(name):
            return name

        # Without hashing the names must pass through unchanged. The builtin
        # `id` would turn them into integers and break the joins above.
        func = AlphabeticNameHash(tuple(image_pngs)).hash if hash_image else _keep_name
        # presumably nested_apply maps func over every leaf while keeping the
        # nesting - verify against utils.
        context, select_accum, deselect_accum, pre_click_selected_accum = nested_apply(
            func, (image_pngs, select_accum, deselect_accum, pre_click_selected_accum))

        prompt = []
        images = []
        if not omit_context:
            images = [self.get_image(im) for im in image_pngs]
            images_and_names_content = []
            for im_name in context:
                images_and_names_content.append(_image_content())
                images_and_names_content.append(_text_content(im_name))
            prompt.append(_system_prompt(images_and_names_content))

        if not len(chats) == len(select_accum) == len(deselect_accum) == len(pre_click_selected_accum):
            logging.error(f"{chats=}")
            logging.error(f"{select_accum=}")
            logging.error(f"{deselect_accum=}")
            logging.error(f"{pre_click_selected_accum=}")
            # Explicit raise (not `assert False`) so it survives `python -O`.
            raise AssertionError(
                "chats, select_accum, deselect_accum and "
                "pre_click_selected_accum must have equal lengths")

        for chat, select, deselect, pre_click_selected in zip(
                chats, select_accum, deselect_accum, pre_click_selected_accum):
            if sort_names:
                select = sorted(select)
                deselect = sorted(deselect)
                pre_click_selected = sorted(pre_click_selected)
            prompt.append(_system_prompt(
                [_text_content(_current_state(pre_click_selected))]))
            prompt.append(_user_prompt([_text_content(chat)]))
            prompt.append(_assistant_prompt(
                [_text_content(_listener_action(select, deselect))]))

        # Only drop an assistant message that was actually appended; with no
        # turns the last entry would be the image context (or nothing).
        if omit_last_answer and chats:
            # idefics2 has processor.apply_chat_template(messages, add_generation_prompt=True) instead
            prompt.pop(-1)
        if chat_feedback is not None:
            prompt.append(_user_prompt([_text_content(chat_feedback)]))
        return prompt, images

    def unfold_select_deselect(
            self, previous_selected: List[List[N]]
    ) -> Tuple[List[List[N]], List[List[N]], List[List[N]]]:
        """Derive per-turn clicks from the selection state after each turn.

        Args:
            previous_selected: ``previous_selected[t]`` is what is selected
                AFTER turn ``t``.

        Returns:
            ``(selected, deselected, clicks)``: for every turn, the names newly
            selected, the names newly deselected, and both combined, each
            passed through ``sorted_list``.
        """
        selected: List[List[N]] = []  # turn-wise selection
        deselected: List[List[N]] = []  # turn-wise deselection
        clicks: List[List[N]] = []  # turn-wise selection + deselection
        prev_selected = set()  # nothing is selected before the first turn
        for after_turn in previous_selected:
            curr_selected = set(after_turn)
            newly_selected = curr_selected - prev_selected
            newly_deselected = prev_selected - curr_selected
            selected.append(sorted_list(newly_selected))
            deselected.append(sorted_list(newly_deselected))
            clicks.append(sorted_list(newly_selected | newly_deselected))
            prev_selected = curr_selected
        return selected, deselected, clicks