| import torch, os, PIL, numbers, json |
| from collections import OrderedDict |
| from PIL import Image |
| import cv2 |
|
|
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.models.siglip.modeling_siglip import SiglipVisionModel |
| from transformers import AutoConfig, AutoModel, SiglipImageProcessor, SiglipVisionConfig, PretrainedConfig |
| from typing import Union |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
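| # Crop every frame of a clip (a list of HxWxC numpy arrays or PIL Images) to the |
| # same (h, w) window whose top-left corner is (min_h, min_w). |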
| def crop_clip(clip, min_h, min_w, h, w): |
| if isinstance(clip[0], np.ndarray): |
| cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] |
|
|
| elif isinstance(clip[0], PIL.Image.Image): |
| cropped = [ |
| img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip |
| ] |
| else: |
| raise TypeError('Expected numpy.ndarray or PIL.Image ' |
| 'but got list of {0}'.format(type(clip[0]))) |
| return cropped |
|
|
|
|
| class Normalize(object): |
| """Normalize a clip with mean and standard deviation. |
| Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform |
| will normalize each channel of the input ``torch.*Tensor`` i.e. |
| ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` |
| .. note:: |
| This transform acts out of place, i.e., it does not mutate the input tensor. |
| Args: |
| mean (sequence): Sequence of means for each channel. |
| std (sequence): Sequence of standard deviations for each channel. |
| """ |
|
|
| def __init__(self, mean, std): |
| self.mean = mean |
| self.std = std |
|
|
| def __call__(self, clip): |
| """ |
| Args: |
| clip (Tensor): Tensor clip of size (C, T, H, W) to be normalized. |
| Returns: |
| Tensor: Normalized Tensor clip. |
| """ |
| return normalize(clip, self.mean, self.std) |
|
|
| def __repr__(self): |
| return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) |
|
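| # Note: in the transform pipeline below, Normalize runs after ClipToTensor, so the |
| # clip it receives is a (C, T, H, W) float tensor and mean/std are per-channel. |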
|
|
|
| class CenterCrop(object): |
| """Extract center crop at the same location for a list of images |
| Args: |
| size (sequence or int): Desired output size for the |
| crop in format (h, w) |
| """ |
|
|
| def __init__(self, size): |
| if isinstance(size, numbers.Number): |
| size = (size, size) |
|
|
| self.size = size |
|
|
| def __call__(self, clip): |
| """ |
| Args: |
| img (PIL.Image or numpy.ndarray): List of images to be cropped |
| in format (h, w, c) in numpy.ndarray |
| Returns: |
| PIL.Image or numpy.ndarray: Cropped list of images |
| """ |
| h, w = self.size |
| if isinstance(clip[0], np.ndarray): |
| im_h, im_w, im_c = clip[0].shape |
| elif isinstance(clip[0], PIL.Image.Image): |
| im_w, im_h = clip[0].size |
| else: |
| raise TypeError('Expected numpy.ndarray or PIL.Image ' |
| 'but got list of {0}'.format(type(clip[0]))) |
| if w > im_w or h > im_h: |
| error_msg = ( |
| 'Initial image size should be larger than ' |
| 'cropped size, but got crop size ({w}, {h}) while ' |
| 'initial image is ({im_w}, {im_h})'.format( |
| im_w=im_w, im_h=im_h, w=w, h=h)) |
| raise ValueError(error_msg) |
|
|
| x1 = int(round((im_w - w) / 2.)) |
| y1 = int(round((im_h - h) / 2.)) |
| cropped = crop_clip(clip, y1, x1, h, w) |
|
|
| return cropped |
|
|
|
|
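| # Resize every frame of a clip. When `size` is an int, the shorter side is resized to |
| # `size` and the aspect ratio is preserved (see get_resize_sizes); when it is a tuple, |
| # frames are resized to that exact size. Interpolation is 'bilinear' or 'nearest'. |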
| def resize_clip(clip, size, interpolation='bilinear'): |
| if isinstance(clip[0], np.ndarray): |
| if isinstance(size, numbers.Number): |
| im_h, im_w, im_c = clip[0].shape |
| |
| if (im_w <= im_h and im_w == size) or (im_h <= im_w |
| and im_h == size): |
| return clip |
| new_h, new_w = get_resize_sizes(im_h, im_w, size) |
| size = (new_w, new_h) |
| else: |
| size = size[0], size[1] |
| if interpolation == 'bilinear': |
| np_inter = cv2.INTER_LINEAR |
| else: |
| np_inter = cv2.INTER_NEAREST |
| scaled = [ |
| cv2.resize(img, size, interpolation=np_inter) for img in clip |
| ] |
| elif isinstance(clip[0], PIL.Image.Image): |
| if isinstance(size, numbers.Number): |
| im_w, im_h = clip[0].size |
| |
| if (im_w <= im_h and im_w == size) or (im_h <= im_w |
| and im_h == size): |
| return clip |
| new_h, new_w = get_resize_sizes(im_h, im_w, size) |
| size = (new_w, new_h) |
| else: |
| size = size[1], size[0] |
| if interpolation == 'bilinear': |
| pil_inter = PIL.Image.BILINEAR |
| else: |
| pil_inter = PIL.Image.NEAREST |
| scaled = [img.resize(size, pil_inter) for img in clip] |
| else: |
| raise TypeError('Expected numpy.ndarray or PIL.Image ' |
| 'but got list of {0}'.format(type(clip[0]))) |
| return scaled |
|
|
|
|
| def _is_tensor_clip(clip): |
| return torch.is_tensor(clip) and clip.ndimension() == 4 |
|
|
|
|
| def get_resize_sizes(im_h, im_w, size): |
| if im_w < im_h: |
| ow = size |
| oh = int(size * im_h / im_w) |
| else: |
| oh = size |
| ow = int(size * im_w / im_h) |
| return oh, ow |
|
|
|
|
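| # Channel-wise normalization of a (C, T, H, W) clip tensor; clones the input unless |
| # inplace=True. |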
| def normalize(clip, mean, std, inplace=False): |
| if not _is_tensor_clip(clip): |
| raise TypeError('tensor is not a torch clip.') |
|
|
| if not inplace: |
| clip = clip.clone() |
|
|
| dtype = clip.dtype |
| mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) |
| std = torch.as_tensor(std, dtype=dtype, device=clip.device) |
| clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) |
|
|
| return clip |
|
|
|
|
| class Resize(object): |
| """Resizes a list of (H x W x C) numpy.ndarray to the final size |
| The larger the original image is, the more times it takes to |
| interpolate |
| Args: |
| interpolation (str): Can be one of 'nearest', 'bilinear' |
| defaults to nearest |
| size (tuple): (widht, height) |
| """ |
|
|
| def __init__(self, size, interpolation='nearest'): |
| self.size = size |
| self.interpolation = interpolation |
|
|
| def __call__(self, clip): |
| resized = resize_clip( |
| clip, self.size, interpolation=self.interpolation) |
| return resized |
|
|
|
|
| class Compose(object): |
| """Composes several transforms |
| Args: |
| transforms (list of ``Transform`` objects): list of transforms |
| to compose |
| """ |
|
|
| def __init__(self, transforms): |
| self.transforms = transforms |
|
|
| def __call__(self, clip): |
| for t in self.transforms: |
| clip = t(clip) |
| return clip |
|
|
|
|
| def convert_img(img): |
| """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" |
| if len(img.shape) == 3: |
| img = img.transpose(2, 0, 1) |
| if len(img.shape) == 2: |
| img = np.expand_dims(img, 0) |
| return img |
|
|
|
|
| class ClipToTensor(object): |
| """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] |
| to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] |
| """ |
|
|
| def __init__(self, channel_nb=3, div_255=True, numpy=False): |
| self.channel_nb = channel_nb |
| self.div_255 = div_255 |
| self.numpy = numpy |
|
|
| def __call__(self, clip): |
| """ |
| Args: clip (list of numpy.ndarray): clip (list of images) |
| to be converted to tensor. |
| """ |
| |
| if isinstance(clip[0], np.ndarray): |
| h, w, ch = clip[0].shape |
| assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb) |
| elif isinstance(clip[0], Image.Image): |
| w, h = clip[0].size |
| else: |
| raise TypeError( |
| "Expected numpy.ndarray or PIL.Image " |
| "but got list of {0}".format(type(clip[0])) |
| ) |
|
|
| np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) |
|
|
| |
| for img_idx, img in enumerate(clip): |
| if isinstance(img, np.ndarray): |
| pass |
| elif isinstance(img, Image.Image): |
| img = np.array(img, copy=False) |
| else: |
| raise TypeError( |
| "Expected numpy.ndarray or PIL.Image " |
| "but got list of {0}".format(type(clip[0])) |
| ) |
| img = convert_img(img) |
| np_clip[:, img_idx, :, :] = img |
| if self.numpy: |
| if self.div_255: |
| np_clip = np_clip / 255.0 |
| return np_clip |
|
|
| else: |
| tensor_clip = torch.from_numpy(np_clip) |
|
|
| if not isinstance(tensor_clip, torch.FloatTensor): |
| tensor_clip = tensor_clip.float() |
| if self.div_255: |
| tensor_clip = torch.div(tensor_clip, 255) |
| return tensor_clip |
|
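| # Illustrative use of the clip transforms above (hypothetical 224px target; the |
| # ImageNet mean/std shown here are the ones used by InternVideoTower below, which |
| # builds an equivalent pipeline from its config): |
| #   transform = Compose([ |
| #       Resize(224, interpolation='bilinear'), |
| #       CenterCrop(size=(224, 224)), |
| #       ClipToTensor(),  # list of frames -> (C, T, H, W) float tensor in [0, 1] |
| #       Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), |
| #   ]) |
| #   clip_tensor = transform(list_of_pil_frames) |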
|
|
|
| class VisionTowerConfig(PretrainedConfig): |
| model_type = "vision_tower" |
|
|
| def __init__(self, vision_tower_name: str = None, **kwargs): |
| super().__init__() |
| self.vision_tower_name = vision_tower_name |
|
|
|
|
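| # Thin wrapper that gives a torchvision-style transform and a HuggingFace image |
| # processor the same .preprocess()/.save_pretrained() interface; exactly one of |
| # `transform` or `processor` must be supplied. |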
| class ProcessorWrapper: |
| def __init__(self, transform=None, processor=None, height=378, width=378, frames_per_clip=1, |
| image_mean=[0.48145466, 0.4578275, 0.40821073]): |
| assert transform is not None or processor is not None, "ERROR: you must define either `transform` or `processor`" |
| assert transform is None or processor is None, "ERROR: you defined both `transform` and `processor`; define only one of them" |
| self._size = { |
| "height": height, |
| "width": width, |
| "frames_per_clip": frames_per_clip |
| } |
| self._transforms = transform |
| self._processor = processor |
| self.image_mean = image_mean |
|
|
| @property |
| def size(self): |
| return self._size |
|
|
| def preprocess(self, image, return_tensors='pt'): |
| |
| output = {} |
| if self._transforms is not None: |
| output['pixel_values'] = [self._transforms(image)] |
|
|
| else: |
| output = self._processor(image, return_tensors='pt') |
| return output |
|
|
| def save_pretrained(self, save_path): |
| if self._transforms is not None: |
| transform_dict = transform_to_dict(self._transforms) |
| transform_dict["image_processor_type"] = "transforms" |
| with open(os.path.join(save_path, 'preprocessor_config.json'), 'w') as f: |
| json.dump(transform_dict, f, indent=4) |
| else: |
| self._processor.save_pretrained(save_path) |
| return |
|
|
|
|
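| # Base class for a single vision encoder. Subclasses set self.vision_tower, the token |
| # grid (self.T, self.W, self.H) and self.hidden_size, and may override |
| # vision_tower_forward(); _forward() returns features shaped (B, T, W, H, hidden_size). |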
| class VisionTower(PreTrainedModel): |
| config_class = VisionTowerConfig |
|
|
| def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: VisionTowerConfig = None): |
| super().__init__(vision_config) |
| self.vision_tower_name = model_name_or_path |
| self.vision_config = vision_config |
| self.select_layer = getattr(config, "mm_vision_select_layer", -2) |
| self.select_feature = getattr(config, "mm_vision_select_feature", "patch") |
| self.encode_batch_size = getattr(config, "encode_batch_size", 0) // 2 |
| self.num_encode_batch = getattr(config, "num_encode_batch", 0) // 2 |
| self.temporal_tubelet_size = getattr(vision_config, "tubelet_size", 1) |
|
|
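| # Pick the hidden states of the configured layer and either drop the leading token |
| # ('patch') or keep every token ('cls_patch'). |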
| def feature_select(self, image_features): |
| if self.select_layer is not None: |
| image_features = image_features.hidden_states[self.select_layer] |
| |
| if self.select_feature == "patch": |
| image_features = image_features[:, 1:] |
| elif self.select_feature == "cls_patch": |
| image_features = image_features |
| else: |
| raise ValueError(f"Unexpected select feature: {self.select_feature}") |
| |
| return image_features |
|
|
| def vision_tower_forward(self, image): |
| image_feature = self.vision_tower(image, output_hidden_states=True) |
| return image_feature |
| |
| def _forward(self, images, out_T=1): |
| if type(images) is list: |
| image_features = [] |
| for image in images: |
| image_feature = self.vision_tower_forward(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)) |
| image_feature = self.feature_select(image_feature).to(image.dtype) |
| image_feature = image_feature.reshape(image_feature.shape[0], self.W, self.H, self.hidden_size) |
| image_features.append(image_feature) |
| else: |
| original_shape = images.shape |
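| # For single-frame (image) encoders, subsample a 5-D video batch (B, T, C, H, W) |
| # down to out_T frames and fold the frame axis into the batch axis. |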
| if len(original_shape) == 5 and self.T == 1: |
| |
| images = images[:, ::original_shape[1] // out_T, ...] |
| original_shape = images.shape |
| images = images.view(-1, *original_shape[2:]) |
|
|
| image_features = self.vision_tower_forward(images.to(device=self.device, dtype=self.dtype)) |
| image_features = self.feature_select(image_features).to(images.dtype) |
| |
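| # Reshape encoder tokens back into a (B, T, W, H, hidden_size) grid, restoring the |
| # frame axis that was folded into the batch for single-frame encoders. |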
| if len(original_shape) == 5 and self.T == 1: |
| |
| new_shape = list(image_features.shape[:-2]) + [self.W, self.H, self.hidden_size] |
| image_features = image_features.reshape(new_shape) |
| feature_size = image_features.shape[1:] |
| image_features = image_features.view(original_shape[0], original_shape[1], *feature_size) |
| |
| else: |
| image_features = image_features.reshape(image_features.shape[0], self.T, self.W, self.H, self.hidden_size) |
| |
| return image_features |
| |
| def forward(self, images): |
| return self._forward(images) |
|
|
| @property |
| def dummy_feature(self): |
| return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) |
|
|
| @property |
| def dtype(self): |
| return self.vision_tower.dtype |
|
|
| @property |
| def device(self): |
| return self.vision_tower.device |
|
|
| @property |
| def num_patches(self): |
| return (self.config.image_size // self.config.patch_size) ** 2 |
|
|
|
|
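| # Video encoder tower: loads an InternVideo checkpoint via AutoModel |
| # (trust_remote_code) and builds its own clip preprocessing pipeline from the |
| # transforms defined above. |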
| class InternVideoTower(VisionTower): |
| def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None): |
| if vision_config is None: |
| vision_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) |
|
|
| super().__init__(model_name_or_path, config, vision_config) |
| self.vision_config = vision_config |
| normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) |
|
|
| print('loading: ', model_name_or_path) |
| model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True) |
| self.vision_tower = model.to(dtype=eval(config.model_dtype)) |
|
|
| transform = Compose([ |
| Resize(self.vision_config.img_size, interpolation='bilinear'), |
| CenterCrop(size=(self.vision_config.img_size, self.vision_config.img_size)), |
| ClipToTensor(), |
| Normalize(mean=normalize[0], std=normalize[1]) |
| ]) |
|
|
| self.vision_processor = ProcessorWrapper(transform=transform, |
| height=self.vision_config.img_size, |
| width=self.vision_config.img_size, |
| frames_per_clip=self.vision_config.num_frames, |
| image_mean=normalize[0]) |
|
|
| self.W = self.H = vision_config.img_size // vision_config.patch_size |
| self.T = self.vision_config.num_frames // self.vision_config.tubelet_size |
| self.num_frames = self.vision_config.num_frames |
| self.hidden_size = vision_config.d_model |
| self.vision_select_layer = self.select_layer |
| self.select_layer = None |
|
|
| def vision_tower_forward(self, video): |
| if video.shape[-3] < self.num_frames: |
| video = video.repeat_interleave(self.num_frames, dim=-3) |
| elif video.shape[-3] > self.num_frames: |
| video = video[:, :, ::video.shape[-3] // self.num_frames, ...] |
|
|
| video_feature = self.vision_tower(video.to(device=self.device, dtype=self.dtype), |
| x_vis_return_idx=self.vision_select_layer, x_vis_only=True) |
| |
| return video_feature |
|
|
| @property |
| def device(self): |
| return self.vision_tower.pos_embed.device |
|
|
|
|
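| # Image encoder tower: wraps a pretrained SiglipVisionModel with its stock |
| # SiglipImageProcessor; encodes single frames (T = 1) and keeps every output |
| # token ('cls_patch'). |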
| class SiglipVisionTower(VisionTower): |
| def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None): |
| if vision_config is None: |
| vision_config = SiglipVisionConfig.from_pretrained(model_name_or_path) |
|
|
| super().__init__(model_name_or_path, config, vision_config) |
| self.vision_config = vision_config |
| self.vision_tower_name = model_name_or_path |
| self.vision_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name) |
|
|
| print('loading: ', model_name_or_path) |
| self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name) |
|
|
| self.hidden_size = self.vision_config.hidden_size |
| self.W = self.H = self.vision_config.image_size // self.vision_config.patch_size |
| self.T = 1 |
| self.select_feature = "cls_patch" |
|
|
|
|
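| # Container that runs several vision towers side by side: each tower encodes its own |
| # view of the input, the resulting token grids are resampled to a shared (T, W, H) |
| # shape, and the features are concatenated along the hidden dimension. |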
| class ApolloVisionTower(PreTrainedModel): |
| def __init__(self, config, vision_tower_cfg): |
| super(ApolloVisionTower, self).__init__(config, vision_tower_cfg) |
| self.model_name_or_path = vision_tower_cfg._name_or_path |
| self.vision_towers = vision_tower_cfg.vision_towers |
| self._config = vision_tower_cfg |
|
|
| for vision_tower_name in self.vision_towers: |
| if 'internvideo' in vision_tower_name.lower(): |
| vision_tower = InternVideoTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name), config) |
| elif 'siglip' in vision_tower_name.lower(): |
| vision_tower = SiglipVisionTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name), |
| config) |
| else: |
| raise ValueError(f"Unknown vision tower: {vision_tower_name}") |
|
|
| setattr(self, vision_tower_name, vision_tower) |
|
|
| self.vision_processor = [getattr(self, vt).vision_processor for vt in self.vision_towers] |
| self.num_vision_encoders = len(self.vision_towers) |
| self.W = self.H = max([getattr(self, vt).W for vt in self.vision_towers]) |
| self.T = max([getattr(self, vt).T for vt in self.vision_towers]) |
| self.max_tubelet_size = max( |
| [getattr(getattr(self, vt).vision_config, 'tubelet_size', 1) for vt in self.vision_towers]) |
| |
| self._hidden_size = sum([getattr(self, vt).hidden_size for vt in self.vision_towers]) |
| self.token_output_shape = (self.T, self.W, self.H) |
| self.config.num_vision_encoders = self.num_vision_encoders |
| self.config.vision_towers = self.vision_towers |
| self.config.token_output_shape = self.token_output_shape |
|
|
| def forward(self, x): |
| output_features = [] |
| for x_s, vision_tower_name in zip(x, self.vision_towers): |
| vision_tower = getattr(self, vision_tower_name) |
| features = vision_tower._forward(x_s, out_T=self.T) |
|
|
| if len(features.shape) != len(self.token_output_shape) + 2: |
| features = features.unsqueeze(1) |
|
|
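| # Resample token grids that do not match the shared target so every tower |
| # contributes a (T, W, H) grid before concatenation. |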
| if features.shape[-len(self.token_output_shape) - 1:-1] != self.token_output_shape: |
| features = features.permute(0, 4, 1, 2, 3).contiguous() |
| features = F.interpolate(features.to(torch.float32), size=self.token_output_shape, mode='trilinear', |
| align_corners=False).to(features.dtype) |
| features = features.permute(0, 2, 3, 4, 1).contiguous() |
|
|
| output_features.append(features) |
|
|
| output_features = torch.cat(output_features, dim=-1) |
| output_features = torch.flatten(output_features, start_dim=1, end_dim=-2) |
| return output_features |
|
|
| def save_pretrained( |
| self, |
| save_directory: Union[str, os.PathLike], |
| state_dict=None, |
| **kwargs, |
| ): |
| if state_dict is None: |
| state_dict = self.state_dict() |
|
|
| for vision_tower_name in self.vision_towers: |
| vision_tower = getattr(self, vision_tower_name) |
| vision_tower_state_dict = OrderedDict( |
| {k.split(f"vision_tower.{vision_tower_name}.vision_tower.")[-1]: v for k, v in state_dict.items() if |
| vision_tower_name in k} |
| ) |
| vision_tower.vision_tower.save_pretrained(os.path.join(save_directory, vision_tower_name), |
| state_dict=vision_tower_state_dict, **kwargs) |
| vision_tower.vision_processor.save_pretrained(os.path.join(save_directory, vision_tower_name)) |
|
|
| config = self.config |
| config.configs = {} |
| config.save_pretrained(save_directory) |
|
|
| @property |
| def patch_size(self): |
| return self._patch_size |
|
|
| @property |
| def image_size(self): |
| return self._image_size |
|
|
| @property |
| def hidden_size(self): |
| return self._hidden_size |
|
|
|
|