from typing import Union, Optional

import PIL.Image
import torch
from torch.nn.functional import softmax, gumbel_softmax, pad
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoModel, AutoConfig

from ovis.util.constants import IMAGE_INDICATOR_IDS, IMAGE_ATOM_ID


class BaseVisualTokenizerConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=16384,
        tokenize_function="softmax",
        tau=1.0,
        depths=None,
        drop_cls_token=False,
        backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
        hidden_stride: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.tokenize_function = tokenize_function
        self.tau = tau
        # `depths` may be given as a '|'-separated string, e.g. "2|2|6|2"
        if isinstance(depths, str):
            depths = [int(x) for x in depths.split('|')]
        self.depths = depths
        self.backbone_kwargs = {}
        self.drop_cls_token = drop_cls_token
        if backbone_config is not None:
            assert isinstance(backbone_config, (PretrainedConfig, dict)), \
                f"expected `backbone_config` to be an instance of PretrainedConfig or dict, but got {type(backbone_config)}"
            if not isinstance(backbone_config, PretrainedConfig):
                # a dict is rebuilt into the proper config class via its 'model_type' key
                model_type = backbone_config.pop('model_type')
                backbone_config = AutoConfig.for_model(model_type, **backbone_config)
        self.backbone_config = backbone_config
        self.hidden_stride = hidden_stride


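# Illustrative construction (hypothetical values): a dict `backbone_config` must carry a
# 'model_type' key so AutoConfig.for_model can rebuild the concrete config class, e.g.
#   BaseVisualTokenizerConfig(vocab_size=16384, depths="2|2|6|2",
#                             backbone_config={'model_type': 'clip_vision_model'})

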
class BaseVisualTokenizer(PreTrainedModel):
    base_model_prefix = "backbone"
    main_input_name = None
    # placeholders that concrete subclasses are expected to override
    _image_processor_class = None
    _image_processor_kwargs = {}
    _backbone_class = None
    _backbone_name_or_path = None

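    # A minimal subclass sketch (hypothetical names and checkpoint, for illustration only):
    #
    #   class MyVisualTokenizer(BaseVisualTokenizer):
    #       _image_processor_class = AutoImageProcessor
    #       _backbone_class = AutoModel
    #       _backbone_name_or_path = "org/vision-backbone"  # hypothetical checkpoint
    #
    #       def get_image_size(self):
    #           side = self.image_processor.size["height"]  # assumes a square size
    #           return side, side
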
    def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        if kwargs.get('train_from_scratch'):
            # training from scratch: load processor and backbone weights from the
            # pretrained checkpoint named by the subclass
            self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
                                                                               **self._image_processor_kwargs)
            self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
                                                                 **self.config.backbone_kwargs)
            self.config.backbone_config = self.backbone.config
        else:
            # otherwise, build the backbone from the stored config; weights are filled
            # in later by `from_pretrained`
            self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
            self.backbone = AutoModel.from_config(self.config.backbone_config)
        # the head predicts only the "real" visual words; the len(IMAGE_INDICATOR_IDS)
        # indicator slots are appended as zeros in `forward`
        head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS)
        self.head = torch.nn.Sequential(
            torch.nn.Linear(
                self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim,
                bias=False
            ),
            torch.nn.LayerNorm(head_dim)
        )

        assert all((self.image_processor.do_resize,
                    not getattr(self.image_processor, 'do_center_crop', False),
                    self.image_processor.do_rescale,
                    self.image_processor.do_normalize
                    )), f"image_processor `{self.image_processor}` is not supported currently"

    def get_backbone(self):
        return self.backbone

    def get_monitor_tensors(self):
        raise NotImplementedError

    def get_image_processor(self):
        return self.image_processor

    def mock_input(self):
        height, width = self.get_image_size()
        return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))

    def get_head(self):
        return self.head

    def get_image_size(self):
        raise NotImplementedError

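    # The placeholder sequence interleaves IMAGE_ATOM_ID (one slot per image or crop)
    # with indicator tokens marking structure: IMAGE_INDICATOR_IDS[0]/[1] open and close
    # the overview image, [2] separates columns, [3] separates rows, [4] closes the grid.
    # Worked example for grid=(2, 2), writing A for IMAGE_ATOM_ID and I0..I4 for the
    # indicators: [I0, A, I1, A, I2, A, I3, A, I2, A, I4].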
    @staticmethod
    def construct_image_placeholders(grid):
        image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
        if grid[0] * grid[1] > 1:
            for r in range(grid[0]):
                for c in range(grid[1]):
                    image_placeholders.append(IMAGE_ATOM_ID)
                    if c < grid[1] - 1:
                        image_placeholders.append(IMAGE_INDICATOR_IDS[2])
                if r < grid[0] - 1:
                    image_placeholders.append(IMAGE_INDICATOR_IDS[3])
        image_placeholders.append(IMAGE_INDICATOR_IDS[4])
        return image_placeholders

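    # preprocess_image implements multi-crop preprocessing: choose the partition grid
    # (rows, cols) that best covers the image at the backbone's native resolution, crop
    # accordingly, and prepend a resized overview of the whole image when more than one
    # crop is produced.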
    def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
        def _preprocess(img: PIL.Image.Image, side):
            # resize the longer edge to `side`, preserving the aspect ratio
            w, h = img.size
            if w == h:
                new_width = new_height = side
            elif w > h:
                new_width = side
                new_height = int(h / w * new_width)
            else:
                new_height = side
                new_width = int(w / h * new_height)
            new_size = dict(height=new_height, width=new_width)
            pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']

            # pad the shorter edge with zeros, centered, to obtain a `side` x `side` square
            square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
            new_height, new_width = pixel_values.shape[2:]
            if new_height == new_width:
                square_values[:, :, :, :] = pixel_values
            elif new_height > new_width:
                from_index = (side - new_width) // 2
                square_values[:, :, :, from_index:from_index + new_width] = pixel_values
            else:
                from_index = (side - new_height) // 2
                square_values[:, :, from_index:from_index + new_height, :] = pixel_values

            return square_values

        def _partition(img, grid):
            # split the image into grid[0] x grid[1] boxes; the last row and column
            # absorb any remainder pixels
            w, h = img.size
            row_height = h // grid[0]
            col_width = w // grid[1]

            partition = []
            for row in range(grid[0]):
                for col in range(grid[1]):
                    left = col * col_width
                    upper = row * row_height
                    right = w if col == grid[1] - 1 else (col + 1) * col_width
                    lower = h if row == grid[0] - 1 else (row + 1) * row_height
                    partition.append((left, upper, right, lower))

            return partition

        def _covering_area(left, upper, right, lower, side):
            # effective area a crop contributes once its longer edge is scaled down to
            # fit `side`; crops already within `side` contribute their raw area
            w = right - left
            h = lower - upper
            w, h = max(w, h), min(w, h)
            if w > side:
                h = h / w * side
                w = side
            return w * h

        def _get_best_grid(img, side):
            img_area = img.size[0] * img.size[1]

            # enumerate all grids with at most `max_partition` cells
            candidate_grids = []
            for i in range(1, max_partition + 1):
                for j in range(1, max_partition + 1):
                    if i * j <= max_partition:
                        candidate_grids.append((i, j))

            all_grids = []
            good_grids = []
            for grid in candidate_grids:
                partition = _partition(img, grid)
                covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
                assert covering_ratio <= 1.0
                all_grids.append((grid, covering_ratio))
                if covering_ratio > covering_threshold:
                    good_grids.append((grid, covering_ratio))

            if len(good_grids) > 0:
                # among grids clearing the threshold, pick the fewest cells,
                # breaking ties by the largest covering ratio
                return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
            else:
                # otherwise pick the largest covering ratio, breaking ties by
                # the fewest cells
                return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]

        if convert_to_rgb and image.mode != 'RGB':
            image = image.convert('RGB')

        sides = self.get_image_size()
        if sides[0] != sides[1]:
            raise ValueError('get_image_size() returns non-square size')
        side = sides[0]
        grid = _get_best_grid(image, side)
        partition = _partition(image, grid)
        crops = [image.crop(p) for p in partition]
        if len(crops) > 1:
            # prepend a resized copy of the full image as a global overview
            crops.insert(0, image)
        pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
        image_placeholders = self.construct_image_placeholders(grid)
        return pixel_values, image_placeholders

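    # Usage sketch (illustrative; `vt` stands for an instance of a concrete subclass):
    #   pixel_values, placeholders = vt.preprocess_image(PIL.Image.open('photo.jpg'))
    # pixel_values has shape [N, 3, side, side], where N == 1 for a (1, 1) grid and
    # N == grid[0] * grid[1] + 1 otherwise (the extra slice is the overview image);
    # `placeholders` matches construct_image_placeholders(grid).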
    def get_backbone_layer(self, index):
        return self.backbone.vision_model.encoder.layers[index]

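    # Three tokenization modes map head logits onto the visual vocabulary:
    #   - 'softmax': soft assignment, a probability distribution over visual words;
    #   - 'gumbel_argmax': hard one-hot sample drawn via gumbel_softmax(hard=True)
    #     at temperature `tau`;
    #   - 'st_argmax': deterministic argmax one-hot with a straight-through gradient.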
    def tokenize(self, logits):
        def st_argmax(y_soft, dim):  # straight-through softmax
            index = y_soft.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
            # one-hot in the forward pass, softmax gradient in the backward pass
            ret = y_hard - y_soft.detach() + y_soft
            return ret

        if self.config.tokenize_function == 'softmax':
            tokens = softmax(logits, dim=-1)
        elif self.config.tokenize_function == 'gumbel_argmax':
            tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
        elif self.config.tokenize_function == 'st_argmax':
            tokens = st_argmax(logits, dim=-1)
        else:
            raise ValueError(
                f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax,'
                f' but got {self.config.tokenize_function}')
        return tokens

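    # When hidden_stride > 1, encode() performs a pixel-unshuffle on the patch grid:
    # each hidden_stride x hidden_stride neighborhood of backbone tokens is merged into
    # a single token whose feature dimension grows by hidden_stride**2. Shape example
    # (illustrative): with hidden_stride=2 and a 27x27 patch grid (l=729, dim d), the
    # grid is zero-padded to 28x28 and regrouped into 14*14=196 tokens of dimension 4*d.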
    def encode(self, pixel_values):
        output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
        features = output.hidden_states[-1]
        if self.config.drop_cls_token:
            features = features[:, 1:, :]

        # merge hidden_stride * hidden_stride neighboring hidden states into one to
        # shorten the token sequence (e.g., hidden_stride=2 merges 4 states into 1)
        if self.config.hidden_stride > 1:
            n, l, d = features.shape
            sqrt_l = int(l ** 0.5)
            assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
            features = features.reshape(n, sqrt_l, sqrt_l, d)
            # zero-pad the grid so its side is divisible by hidden_stride
            pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
            features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
            sqrt_l += pl
            features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
                                        sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
            features = features.permute(0, 1, 3, 2, 4, 5)  # [n, h, w, hs, hs, d]
            features = features.flatten(3)  # [n, h, w, hs * hs * d]
            features = features.reshape(
                n, -1, self.config.hidden_stride * self.config.hidden_stride * d)

        return features

    def forward(self, pixel_values) -> torch.Tensor:
        features = self.encode(pixel_values)
        logits = self.head(features)
        tokens = self.tokenize(logits)
        # `tokens` has shape [BatchSize, #Token, VocabSize - len(IMAGE_INDICATOR_IDS)]:
        # the head never emits indicator tokens, so zero probabilities are appended to
        # restore the full [BatchSize, #Token, VocabSize] shape
        batch_size, token_len, _ = tokens.shape
        padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
                                     dtype=tokens.dtype,
                                     device=tokens.device,
                                     layout=tokens.layout,
                                     requires_grad=False)
        tokens = torch.cat((tokens, padding_tensor), dim=2)
        return tokens
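

# End-to-end sketch (illustrative only; BaseVisualTokenizer is abstract, so `vt` must
# be a concrete subclass instance with a loaded backbone):
#   pixel_values, placeholders = vt.preprocess_image(PIL.Image.open('photo.jpg'))
#   probs = vt(pixel_values)  # [num_crops, num_tokens, vocab_size]
# Each row of `probs` sums to 1: a soft distribution for 'softmax', a one-hot vector
# for the argmax modes.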