import os
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
if not delay_load:
self.load_model()
elif getattr(args, 'unfreeze_mm_vision_tower', False):
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
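    # Note on attention outputs: `forward_select`/`forward_select_scope`
    # request `output_attentions=True`. With recent transformers releases a
    # non-eager attention backend typically falls back to the manual
    # implementation for such calls (version-dependent behavior); loading
    # with `attn_implementation='eager'` is one way to make that explicit.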
    def feature_select(self, image_forward_outs):
        # Standard LLaVA feature selection; CLIPVisionTowerS2.forward_feature
        # below relies on this method.
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature != 'cls_patch':
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    def feature_select_withcls(self, image_forward_outs):
        # Return the selected hidden layer as-is, keeping the CLS token.
        image_features = image_forward_outs.hidden_states[self.select_layer]
        return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select_withcls(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select_withcls(image_forward_outs).to(images.dtype)
return image_features
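    # Note: `forward` keeps the CLS token (`feature_select_withcls`), so
    # callers that want patch features only should slice with
    # `features[:, 1:]`.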
    @torch.no_grad()
    def forward_select(self, images, token_num):
        # Token selection is only defined for a batched image tensor; a list
        # of per-image tensors has no meaningful return value here, so it is
        # rejected explicitly.
        if type(images) is list:
            raise NotImplementedError('forward_select expects a batched image tensor, not a list of images.')
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True, output_attentions=True)
        attn_weights = image_forward_outs.attentions[-2]
        dominant_num = token_num
        ## Dominant Visual Tokens: rank patch tokens by the attention the
        ## CLS token pays to them, summed over heads, and keep the top-k.
        cls_idx = 0
        cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
        cls_attention_sum = cls_attention.sum(dim=1)
        topk_indices = cls_attention_sum.topk(dominant_num, dim=1).indices
        topk_indices_sorted = torch.sort(topk_indices, dim=1).values
        return topk_indices_sorted
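    # Illustrative use of `forward_select` (a sketch; `tower` and `images`
    # are hypothetical caller-side names, not defined in this file):
    #
    #   idx = tower.forward_select(images, token_num=64)   # [B, 64] patch indices
    #   feats = tower(images)                              # [B, 1 + N, D], CLS + patches
    #   patches = feats[:, 1:]                             # indices address patches only
    #   kept = torch.gather(patches, 1,
    #                       idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1]))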
    @torch.no_grad()
    def forward_select_scope(self, images, token_num, alpha):
        # As in `forward_select`, only a batched image tensor is supported.
        if type(images) is list:
            raise NotImplementedError('forward_select_scope expects a batched image tensor, not a list of images.')
        image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True, output_attentions=True)
        attn_weights = image_forward_outs.attentions[-2]
        hidden_states = image_forward_outs.hidden_states[-2]
        dominant_num = token_num
        ## Dominant Visual Tokens, selected with SCOPE rather than a plain
        ## CLS-attention top-k.
        cls_idx = 0
        cls_attention = attn_weights[:, :, cls_idx, cls_idx + 1:]
        cls_attention_sum = cls_attention.sum(dim=1)
        image_features = hidden_states[:, cls_idx + 1:]
        bs = image_features.shape[0]
        # The total token budget is split evenly across the batch.
        dominant_num = int(dominant_num / bs)
        selected_idx, _ = SCOPE(image_features, dominant_num, cls_attention_sum, alpha)
        topk_indices_sorted = torch.sort(selected_idx, dim=1).values
        return topk_indices_sorted
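    # Illustrative use of `forward_select_scope` (a sketch with hypothetical
    # caller-side names; note the budget is divided by the batch size inside
    # the method, and the COMBINED environment variable, 'multi' or 'add',
    # controls how the attention prior enters the SCOPE gains):
    #
    #   idx = tower.forward_select_scope(images, token_num=64 * images.shape[0], alpha=1.0)
    #   # idx: [B, 64] sorted patch indices chosen by SCOPE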
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
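    # Construction sketch (illustrative only; `SimpleNamespace` stands in
    # for the LLaVA-style args object this class expects):
    #
    #   from types import SimpleNamespace
    #   args = SimpleNamespace(mm_vision_select_layer=-2,
    #                          mm_vision_select_feature='patch')
    #   tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args)
    #   feats = tower(images)   # images: [B, 3, 336, 336] from tower.image_processor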
class CLIPVisionTowerS2(CLIPVisionTower):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__(vision_tower, args, delay_load)
self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
self.s2_scales = list(map(int, self.s2_scales.split(',')))
self.s2_scales.sort()
self.s2_split_size = self.s2_scales[0]
self.s2_image_size = self.s2_scales[-1]
try:
from s2wrapper import forward as multiscale_forward
except ImportError:
raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
self.multiscale_forward = multiscale_forward
        # change resize/crop size in preprocessing to the largest image size in s2_scales
if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
self.image_processor.size['shortest_edge'] = self.s2_image_size
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False)
self.image_processor.size['shortest_edge'] = self.s2_image_size
self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
self.is_loaded = True
@torch.no_grad()
def forward_feature(self, images):
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
image_features.append(image_feature)
else:
image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
return image_features
@property
def hidden_size(self):
return self.config.hidden_size * len(self.s2_scales)
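    # The s2wrapper concatenates per-scale features along the channel
    # dimension, which is why `hidden_size` is multiplied by the number of
    # scales (e.g. 1024 * 3 = 3072 for ViT-L/14 at scales 336,672,1008).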
def SCOPE(visual_feature_vectors, num_selected_token, cls_attn=None, alpha=1.0):
    """
    Batched SCOPE token selection: a greedy, coverage-style pass that
    processes all batch elements simultaneously.

    Args:
        visual_feature_vectors: [B, N, D] batch of feature vectors
        num_selected_token: number of tokens to select per batch element
        cls_attn: [B, N] batch of attention weights (optional)
        alpha: exponent applied to cls_attn before it reweights the gains

    Returns:
        selected_idx: [B, K] selected token indices for each batch element
        cosine_simi: [B, N, N] batch of cosine similarity matrices
    """
# Calculate cosine similarity for all batches at once
norm_vectors = visual_feature_vectors / visual_feature_vectors.norm(dim=-1, keepdim=True)
cosine_simi = torch.bmm(norm_vectors, norm_vectors.transpose(1, 2))
B, N = visual_feature_vectors.shape[:2]
device = visual_feature_vectors.device
dtype = visual_feature_vectors.dtype
# Pre-allocate tensors for all batches
selected = torch.zeros(B, N, dtype=torch.bool, device=device)
selected_idx = torch.empty(B, num_selected_token, dtype=torch.long, device=device)
cur_max = torch.zeros(B, N, dtype=dtype, device=device)
    # Precompute cls_attn ** alpha for all batches
    if cls_attn is not None:
        cls_attn_powered = cls_attn ** alpha
    else:
        cls_attn_powered = torch.ones(B, N, dtype=dtype, device=device)
    # How the attention prior is combined with the coverage gain; read the
    # environment once here rather than on every loop iteration.
    combined = os.environ.get('COMBINED', 'multi')
    for i in range(num_selected_token):
        # Marginal coverage gain of each candidate j: for every token, how
        # much j would raise its best similarity to the selected set,
        # clamped at zero and summed over tokens. Columns of already
        # selected tokens are zeroed before the subtraction.
        unselected_mask = ~selected
        gains = torch.maximum(
            torch.zeros(1, dtype=dtype, device=device),
            cosine_simi.masked_fill(~unselected_mask.unsqueeze(1), 0) -
            cur_max.unsqueeze(2)
        ).sum(dim=1)
        # Apply attention weights
        if combined == 'multi':
            gains = gains * cls_attn_powered
        elif combined == 'add':
            gains = gains + cls_attn_powered
        else:
            raise NotImplementedError(f'Unknown COMBINED mode: {combined}')
# Mask out already selected tokens
gains = gains.masked_fill(~unselected_mask, float('-inf'))
# Find best elements for all batches
best_idx = gains.argmax(dim=1)
# Update states for all batches
selected[torch.arange(B, device=device), best_idx] = True
selected_idx[:, i] = best_idx
cur_max = torch.maximum(cur_max, cosine_simi[torch.arange(B, device=device), best_idx])
return selected_idx, cosine_simi
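

if __name__ == '__main__':
    # Minimal sanity check for SCOPE on random data (illustrative only; the
    # shapes mimic CLIP ViT-L/14-336 patch features, but any [B, N, D]
    # input works).
    torch.manual_seed(0)
    feats = torch.randn(2, 576, 1024)   # [B, N, D] patch features
    attn = torch.rand(2, 576)           # CLS-to-patch attention per token
    idx, sim = SCOPE(feats, num_selected_token=64, cls_attn=attn, alpha=1.0)
    assert idx.shape == (2, 64)
    assert sim.shape == (2, 576, 576)
    print('selected indices:', idx.shape, 'similarity:', sim.shape)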