import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize
import math
FILE_EXTENSIONS = ('.jpeg', '.txt', '.idx')
class args:
patch_size = 16
patch_num_width = 16
patch_num_height = 16
position_embedding_length = 4096
clip_model_name = 'InternViT-448'
    image_segment_method = 'dynamic'  # alternative: 'adaptive'
max_split_tile_num_multi_image = 1
max_split_tile_num_single_image = 9
clip_visual_size = 1024
clip_hidden_size = 1024
downsample_ratio = 0.5
shape_change_threshold = 0.5
bf16 = True
fp16 = False
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size, threshold):
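    """Pick the (w, h) tile grid from target_ratios whose aspect ratio is
    closest to the image's, rejecting grids whose summed side length differs
    from the original by more than the relative threshold."""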
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
size_diff_length = abs(((ratio[0]*image_size + ratio[1]*image_size)-(width+height)) / (width+height))
if ratio_diff < best_ratio_diff and size_diff_length <= threshold:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def build_transform(input_size):
    # OpenAI CLIP normalization statistics.
transform = Compose([
Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
_convert_to_rgb,
ToTensor(),
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
])
return transform
def torch_extract_patches(image_tensor, patch_height, patch_width):
image_tensor = image_tensor.unsqueeze(0)
patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
patches = patches.permute(0, 4, 2, 3, 1).reshape(
image_tensor.size(2) // patch_height,
image_tensor.size(3) // patch_width,
image_tensor.size(1) * patch_height * patch_width,
)
return patches.unsqueeze(0)
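# Shape sketch (assumed toy sizes): a (3, 64, 48) image tensor with 16x16
# patches comes back as (1, 4, 3, 768), i.e. (1, H // patch, W // patch, 3 * patch * patch).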
# Compute the input size required by the 'adaptive' segmentation method.
def adapt_size(originHeight: int, originWidth: int):
    """Scale an image so its patch grid fits within the patch budget.

    Args:
        originHeight: original image height.
        originWidth: original image width.
    Returns:
        resized_height: image height after interpolation.
        resized_width: image width after interpolation.
        resized_patch_height_num: vertical patch count after interpolation.
        resized_patch_width_num: horizontal patch count after interpolation.
    """
    PATCH_SIZE = args.patch_size
    # Patch budget: 16 * 16 = 256 patches.
    MAX_PATCHES = args.patch_num_width * args.patch_num_height
    patchHeight = PATCH_SIZE
    patchWidth = PATCH_SIZE
    maxPatches = MAX_PATCHES
    scale = math.sqrt(maxPatches * (patchHeight / originHeight) * (patchWidth / originWidth))
    resized_patch_height_num = max(min(math.floor(scale * originHeight / patchHeight), maxPatches), 1)
    resized_patch_width_num = max(min(math.floor(scale * originWidth / patchWidth), maxPatches), 1)
resized_height = max(resized_patch_height_num * PATCH_SIZE, 1)
resized_width = max(resized_patch_width_num * PATCH_SIZE, 1)
return resized_height, resized_width, resized_patch_height_num, resized_patch_width_num
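# Worked example with the defaults (16-pixel patches, 256-patch budget):
# adapt_size(800, 600) -> (288, 208, 18, 13), i.e. an 800x600 image is
# interpolated to 288x208 so its 18x13 = 234-patch grid stays within budget.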
def cal_num_of_slices(origin_image_width, origin_image_height, max_num):
    PATCH_SIZE = args.patch_size
    # Base tile resolution implied by the patch grid: 256 x 256 pixels.
    IMAGE_WIDTH = PATCH_SIZE * args.patch_num_width
    IMAGE_HEIGHT = PATCH_SIZE * args.patch_num_height
    scale = origin_image_width * origin_image_height / (IMAGE_WIDTH * IMAGE_HEIGHT)
    scale = math.ceil(scale)
    scale = min(scale, max_num)
    def factorize(n):
        # Enumerate all w x h grids with w * h == n as (w / h, w, h).
        factors = []
for i in range(1, n + 1):
if n % i == 0:
factors.append((i/(n/i), i, n // i))
return factors
    numbers = list(range(1, 16))
factor_dict = {}
for num in numbers:
factor_dict[num] = factorize(num)
log_origin_ratio = math.log(origin_image_width/origin_image_height)
available_ratios = []
    if scale <= 2:
        available_ratios = factor_dict[scale] + factor_dict[scale + 1]
    else:
        available_ratios = factor_dict[scale - 1] + factor_dict[scale] + factor_dict[scale + 1]
    min_dif = float('inf')
    best_w = 0
    best_h = 0
    # Choose the grid whose aspect ratio is closest to the original in log space.
    for (r, w_slice, h_slice) in available_ratios:
        log_r = math.log(r)
        if min_dif > abs(log_r - log_origin_ratio):
            min_dif = abs(log_r - log_origin_ratio)
            best_w = w_slice
            best_h = h_slice
    return best_w, best_h
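# Worked example: cal_num_of_slices(1344, 896, 9) -> (4, 2). scale =
# ceil(1344 * 896 / 256**2) = 19 is capped at 9, and among the 8-, 9- and
# 10-tile grids the 4x2 grid (log-ratio log 2) is closest to log(1344/896).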
# Compute per-slice and full-image patch counts.
def get_patch_nums(origin_image_width, origin_image_height, max_num):
    """Given the original image size, return:
        slice_w_num: patches along the width of each slice.
        slice_h_num: patches along the height of each slice.
        abstract_w_num: patches along the width of the full image.
        abstract_h_num: patches along the height of the full image.
    """
    best_w, best_h = cal_num_of_slices(origin_image_width, origin_image_height, max_num)
    slice_width = origin_image_width // best_w
    slice_height = origin_image_height // best_h
    _, _, slice_h_num, slice_w_num = adapt_size(slice_height, slice_width)
    _, _, abstract_h_num, abstract_w_num = adapt_size(origin_image_height, origin_image_width)
    return slice_w_num, slice_h_num, abstract_w_num, abstract_h_num
def slice_image(image, max_num):
    # Slice the image according to the grid chosen by cal_num_of_slices.
    # Returns a list of PIL image slices.
origin_image_width = image.size[0]
origin_image_height = image.size[1]
best_w, best_h = cal_num_of_slices(origin_image_width=origin_image_width, origin_image_height=origin_image_height, max_num=max_num )
    slices = []
    for j in range(best_h):
        for i in range(best_w):
            box = (i * origin_image_width // best_w,
                   j * origin_image_height // best_h,
                   (i + 1) * origin_image_width // best_w,
                   (j + 1) * origin_image_height // best_h)
            # Crop this grid cell.
            region = image.crop(box).convert("RGB")
            # Collect the slice.
            slices.append(region)
return slices
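# Worked example: for a 1344x896 image and max_num=9 the 4x2 grid is chosen,
# so slice_image returns eight 336x448 crops.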
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, threshold=1):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size, threshold)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
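# Worked example: a 1000x500 image with image_size=448, max_num=6 and
# threshold=1 matches the (2, 1) grid, is resized to 896x448, and yields two
# 448x448 tiles plus a 448x448 thumbnail when use_thumbnail=True.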
def process_image(image, image_size, max_num):
    PATCH_SIZE = args.patch_size
    PATCH_NUM_WIDTH = args.patch_num_width
    PATCH_NUM_HEIGHT = args.patch_num_height
    # Patch budget: 16 * 16 = 256.
    MAX_PATCHES = PATCH_NUM_WIDTH * PATCH_NUM_HEIGHT
    # Flattened RGB patch length: 3 * 16 * 16 = 768.
    TOKEN_LENGTH = 3 * PATCH_SIZE * PATCH_SIZE
    # Base resolution implied by the patch grid: 256 x 256.
    IMAGE_WIDTH = PATCH_SIZE * PATCH_NUM_WIDTH
    IMAGE_HEIGHT = PATCH_SIZE * PATCH_NUM_HEIGHT
    image = image.convert("RGB")
    slices = slice_image(image, max_num)
    if len(slices) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        slices.append(thumbnail_img)
    # Compute the post-resize dimensions for the first slice.
    image = slices[0]
    image_w = image.size[0]
    image_h = image.size[1]
    resized_height, resized_width, resized_patch_height, resized_patch_width = \
        adapt_size(image_h, image_w)
image = ToTensor()(image)
image = torch.nn.functional.interpolate(
image.unsqueeze(0),
size=(resized_height, resized_width),
mode="bilinear",
align_corners=False,
antialias=True,
).squeeze(0)
    # Number of patches to pad (mask) up to MAX_PATCHES.
    num_patches_to_pad = MAX_PATCHES - resized_patch_height * resized_patch_width
    # Extract patches from the resized image.
    image = torch_extract_patches(image, PATCH_SIZE, PATCH_SIZE)
    image = image.reshape([resized_patch_width * resized_patch_height, TOKEN_LENGTH])
    # Zero-pad the masked part, giving torch.Size([MAX_PATCHES, TOKEN_LENGTH]).
    image = torch.nn.functional.pad(image, [0, 0, 0, num_patches_to_pad]).float()
    image = image.reshape(PATCH_NUM_WIDTH, PATCH_NUM_HEIGHT, PATCH_SIZE, PATCH_SIZE, 3).permute(0, 2, 1, 3, 4).reshape(IMAGE_WIDTH, IMAGE_HEIGHT, 3).permute(2, 0, 1)
    # Note: the patched tensor is not returned; the PIL slices are transformed
    # downstream in load_image.
return slices
def _convert_to_rgb(image):
return image.convert('RGB')
def load_image(image_file, input_size=448, max_num=9):
image = Image.open(image_file).convert('RGB')
if args.clip_model_name == 'InternViT-448':
transform = build_transform(input_size=input_size)
if args.image_segment_method == 'adaptive':
images_processed = process_image(image, input_size, max_num)
elif args.image_segment_method == 'dynamic':
images_processed = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num, threshold=args.shape_change_threshold)
pixel_values = [transform(image) for image in images_processed]
else:
transform = build_transform(input_size=input_size)
if args.image_segment_method == 'adaptive':
images_processed = process_image(image, input_size, max_num)
elif args.image_segment_method == 'dynamic':
images_processed = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images_processed]
pixel_values = torch.stack(pixel_values)
return pixel_values
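# Usage sketch (hypothetical file name): load_image('photo.jpeg') returns a
# float tensor of shape (num_tiles, 3, 448, 448); the thumbnail, when added,
# is the last tile.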
def process_input(args, num_token_per_tile, image_path, question):
    image_prompts = ''
if len(image_path) >= 2:
image_list = []
num_tile_per_image_list = []
for ipath in image_path:
images = load_image(ipath, max_num=args.max_split_tile_num_multi_image)
num_tile_this_image = len(images)
num_tile_per_image_list.append(num_tile_this_image)
image_list.append(images)
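            # One placeholder per visual token: num_tiles * num_token_per_tile per image.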
image_prompts = image_prompts + '' + '' * num_tile_this_image * num_token_per_tile + ''
num_tile_per_image_tensor = torch.Tensor(num_tile_per_image_list).long().cuda()
image_tensor = torch.cat(image_list, dim=0).view(1, -1, 3, 448, 448).cuda()
else:
images = load_image(image_path[0], max_num=args.max_split_tile_num_single_image)
num_tile_this_image = len(images)
num_tile_per_image_tensor = torch.Tensor([num_tile_this_image]).long().cuda()
image_tensor = images.view(1, -1, 3, 448, 448).cuda()
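        # Same placeholder layout for the single-image case.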
image_prompts = image_prompts + '' + '' * num_tile_this_image * num_token_per_tile + ''
if args.fp16:
image_tensor = image_tensor.half()
elif args.bf16:
image_tensor = image_tensor.bfloat16()
else:
image_tensor = image_tensor.float()
images_input = {'num_tile_per_image_tensor': num_tile_per_image_tensor,
'image_tensor': image_tensor}
prompts = ['' + image_prompts + question[0] + '']
return prompts, images_input
def _build_yuanvl_attention_mask_and_position_ids(tokenizer, tokens, images_input=None):
    """Build the attention mask and position ids for the input tokens."""
    # Since we are not interested in loss-mask and reset attention/position
    # is also False, eod_token is not used so it is safe to set it to None.
    bos_token, image_start_token, image_end_token, pad_token, sep_token, eod_token = (tokenizer(tok)['input_ids'][0] for tok in ['', '', '', '', '', ''])
attention_mask, position_ids, image_info = get_ltor_masks_and_position_ids_yuanvl_inference(
tokens,
bos_token,
image_start_token,
image_end_token,
eod_token,
pad_token,
images_input)
return attention_mask, position_ids, image_info
def get_ltor_masks_and_position_ids_yuanvl_inference(data,
bos_token,
image_start_token,
image_end_token,
eod_token,
pad_token,
images_input,
reset_attention_mask=False):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
    assert micro_batch_size == 1, 'yuanvl supports micro-batch size 1 only'
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
if images_input is not None:
num_tile_per_image_tensor = images_input['num_tile_per_image_tensor']
images_tensor = images_input['image_tensor']
        image_info = {}
position_ids_use = torch.zeros(data.shape).to(position_ids)
for b in range(micro_batch_size):
bos_index = position_ids[b, data[b] == bos_token]
pad_index = position_ids[b, data[b] == pad_token]
image_start_index = position_ids[b, data[b] == image_start_token]
image_end_index = position_ids[b, data[b] == image_end_token]
            num_image = len(num_tile_per_image_tensor)
            image_info['num_tile'] = num_tile_per_image_tensor
            image_info['image_start_pos'] = image_start_index.tolist()
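            # Positions stay 0 up to the last image-end token; from that token on
            # they count up 0..N, so only the post-image text gets fresh positions.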
start_idx = image_end_index[-1]
diff = seq_length - start_idx
position_ids_use[b][start_idx : ] = torch.arange(diff, dtype=torch.long,
device=data.device)
else:
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
        position_ids = position_ids.unsqueeze(0)
position_ids_use = position_ids
image_info = None
    # Convert attention mask to binary (True marks positions to be masked out).
    attention_mask = (attention_mask < 0.5)
return attention_mask, position_ids_use, image_info
tokenizer_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"
model_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"
# Load the local model.
model = AutoModel.from_pretrained(
model_loadpath,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
use_flash_attn=False,
device_map="auto",
trust_remote_code=True).eval()
print("Creat model finish")
# 加载本地 Tokenizer
tokenizer = AutoTokenizer.from_pretrained(tokenizer_loadpath)
num_token_per_tile = int(args.clip_visual_size * args.downsample_ratio**2)
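# 1024 visual features per tile * 0.5**2 downsampling = 256 tokens per tile.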
# demo 1
image_path = ['/mnt/beegfs3/zhaoxudong/code/image.jpeg']
question = ['Please describe the picture']
# Chinese variant: question = ['请描述这张图片的内容']
prompts, images_input = process_input(args, num_token_per_tile, image_path, question)
inputs = tokenizer(prompts, return_tensors="pt")
input_ids = inputs['input_ids'].to("cuda")
pixel_values=images_input['image_tensor']
attention_mask, position_ids, image_info = _build_yuanvl_attention_mask_and_position_ids(
tokenizer, input_ids, images_input)
# generate() uses the tokenizer's padding mask; the causal mask built above is replaced here.
attention_mask = inputs['attention_mask'].to("cuda")
output = model.generate(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
print(tokenizer.decode(output[0]))