import logging
import re
from functools import cache
from pathlib import Path
from typing import List, Set, Tuple, TypeVar

import torch
from PIL import Image

from utils import device, nested_apply, sorted_list

# Full-string pattern for a model answer: either "select X [Y ...]", or
# "deselect X [Y ...]" optionally followed by " select X [Y ...]", where every
# item is a single capital letter and items are separated by one whitespace
# character.
RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$'  # noqa

# Name type, newtype of str. e.g. "page4-249.png"
N = TypeVar('N')

ALPHABET = 'ABCDEFGHIJ'  # we only have 10 images
LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413, 420, 475, 5339,
                   634, 17960, 32002]  # A - J and <\s> and 'select' and 'deselect'
# Token id -> text piece for the legal tokens.
# NOTE(review): the source was collapsed onto one line, so the extent of the
# commented-out entry below is inferred (only the `2` entry is treated as
# disabled) -- confirm against the original file.
MINI_DECODER = {
    384: 'D',
    # 2: '',
    32002: '',
    420: 'G',
    17960: 'elect',
    330: 'A',
    365: 'B',
    334: 'C',
    5339: 'select',
    401: 'F',
    475: 'J',
    634: 'des',
    315: 'I',
    413: 'E',
    382: 'H'}


class AlphabeticNameHash:
    """Two-way mapping between the items of a context and letters of ALPHABET.

    The i-th item of the context is assigned the i-th letter of ``ALPHABET``.
    If an item occurs more than once, ``hash`` reports the letter of its last
    occurrence while every assigned letter still maps back to the item.
    """

    # NOTE: ``__init__`` is deliberately not wrapped in ``functools.cache``.
    # A cache on ``__init__`` is keyed on the freshly created ``self``, so it
    # can never produce a hit; it only keeps every instance (and its context)
    # alive forever and forces the context to be hashable.
    def __init__(self, context: List[N]) -> None:
        """Build the forward and backward maps for ``context``.

        Args:
            context: ordered items to label; any iterable is accepted
                (lists as well as tuples).

        Raises:
            IndexError: if ``context`` holds more items than ``ALPHABET`` has
                letters.
        """
        # Materialise once so single-pass iterables fill both maps.
        names = tuple(context)
        if len(names) > len(ALPHABET):
            raise IndexError(
                f"context has {len(names)} items but only "
                f"{len(ALPHABET)} letters are available")
        self._forward_map = dict(zip(names, ALPHABET))
        self._backward_map = dict(zip(ALPHABET, names))

    def hash(self, im: N) -> str:
        """Return the letter assigned to ``im``; KeyError if it is unknown."""
        return self._forward_map[im]

    def unhash(self, i: str) -> N:
        """Return the item assigned to letter ``i``; KeyError if unassigned."""
        return self._backward_map[i]

    def valid_hash(self, i: str) -> bool:
        """Return True when letter ``i`` is assigned to some context item."""
        return i in self._backward_map


class IdeficsAdapter:
    """Token-id constants, an image cache and name/letter helpers.

    Only the first part of this class is visible on this source line; the
    remaining methods follow on the next lines of the file.
    """

    PAD_TOKEN_ID = 0
    LABEL_MASK_ID = 32001  # idefics2: image_token_id
    LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
    # Boolean vector of length 32003 that is True exactly at the legal token
    # ids, placed on the device returned by ``device()``.
    LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\
        .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
    # Every id in range(32003) that is not a legal token id.
    SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))

    def __init__(self, image_folder: str, processor) -> None:
        """Store the image folder, the processor and its tokenizer.

        Args:
            image_folder: directory that image names are resolved against.
            processor: object exposing a ``tokenizer`` attribute.
        """
        self.t_max_length = 2048
        self.image_folder = Path(image_folder)
        self.image_cache = {}  # image name -> opened PIL image
        self.processor = processor
        self.tokenizer = self.processor.tokenizer

    def get_image(self, im_name: N) -> Image.Image:
        """Return the image called ``im_name`` from ``image_folder``.

        The image is opened on first request and the same object is returned
        from ``image_cache`` on later requests.
        """
        if im_name not in self.image_cache:
            self.image_cache[im_name] = Image.open(
                self.image_folder.joinpath(im_name))
        return self.image_cache[im_name]

    def unhash(self, context: List[N], c: str):
        # NOTE(review): this source line ends right after `return`; the
        # expression being returned begins the following source line.
        return
AlphabeticNameHash(tuple(context)).unhash(c) def valid_hash(self, context: List[N], c: str): return AlphabeticNameHash(tuple(context)).valid_hash(c) def parse(self, context: List[N], decoded_out: str, currently_selected: List[N]) -> List[str]: h = AlphabeticNameHash(tuple(context)) logging.debug(f"{context=}") # do inference logging.debug(f"{decoded_out=}") selection, deselection = self.parse_raw(decoded_out) hashed_currently_selected = {h.hash(n) for n in currently_selected} desel_to_remove = deselection - hashed_currently_selected if len(desel_to_remove) > 0: logging.debug(f"warn! {desel_to_remove=}") deselection = deselection - desel_to_remove sel_to_remove = selection & hashed_currently_selected if len(sel_to_remove) > 0: logging.debug(f"warn! {sel_to_remove=}") selection = selection - sel_to_remove logging.debug("post strict cleaning") logging.debug(f"{selection=}") logging.debug(f"{deselection=}") model_clicks = selection | deselection logging.debug(f"{model_clicks=}") model_clicks_png = [h.unhash(n) for n in model_clicks if h.valid_hash(n)] logging.debug(f"{model_clicks_png=}") return model_clicks_png @staticmethod def parse_raw(text: str) -> Tuple[Set[N], Set[N]]: last_answer = text.strip() if ":" in text: last_answer_pattern = r":.*$" xs = re.findall(last_answer_pattern, text) last_answer = xs[0].removeprefix(":").strip() xs = re.search(RE_PATTERN, last_answer) if xs is None: print(f"{last_answer=}") print("did not pass regex") return set(), set() select_pattern = r"(? 
Tuple[List[N], List[N], List[N]]: # currently selected AFTER i-th turn num_turns = len(previous_selected) selected: List[List[str]] = [] # turn-wise selection deselected: List[List[str]] = [] # turn-wise deselection clicks: List[List[str]] = [] # combining turn-wise newly selected and newly deselected prev_selected = set() for turn in range(num_turns): curr_selected = set(previous_selected[turn]) newly_selected = curr_selected - prev_selected newly_deselected = prev_selected - curr_selected selected.append(sorted_list(newly_selected)) deselected.append(sorted_list(newly_deselected)) clicks.append(sorted_list(newly_selected | newly_deselected)) prev_selected = curr_selected.copy() return selected, deselected, clicks