import math

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize

FILE_EXTENSIONS = ('.jpeg', '.txt', '.idx')


class args:
    patch_size = 16
    patch_num_width = 16
    patch_num_height = 16
    position_embedding_length = 4096
    clip_model_name = 'InternViT-448'
    image_segment_method = 'dynamic'  # or 'adaptive'
    max_split_tile_num_multi_image = 1
    max_split_tile_num_single_image = 9
    clip_visual_size = 1024
    clip_hidden_size = 1024
    downsample_ratio = 0.5
    shape_change_threshold = 0.5
    bf16 = True
    fp16 = False


def _convert_to_rgb(image):
    return image.convert('RGB')


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size, threshold):
    """Pick the candidate tiling whose aspect ratio is closest to the image's,
    subject to the tiled perimeter staying within `threshold` of the original."""
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        size_diff_length = abs(((ratio[0] * image_size + ratio[1] * image_size) - (width + height)) / (width + height))
        if ratio_diff < best_ratio_diff and size_diff_length <= threshold:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the tiling whose total tile area better covers the image.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def build_transform(input_size):
    # CLIP-style normalization constants.
    transform = Compose([
        Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                  std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    return transform


def torch_extract_patches(image_tensor, patch_height, patch_width):
    """Split a CHW image tensor into non-overlapping (patch_height, patch_width) patches."""
    image_tensor = image_tensor.unsqueeze(0)
    patches = torch.nn.functional.unfold(image_tensor, (patch_height, patch_width),
                                         stride=(patch_height, patch_width))
    patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
    patches = patches.permute(0, 4, 2, 3, 1).reshape(
        image_tensor.size(2) // patch_height,
        image_tensor.size(3) // patch_width,
        image_tensor.size(1) * patch_height * patch_width,
    )
    return patches.unsqueeze(0)


def adapt_size(originHeight: int, originWidth: int):
    """Compute the interpolated image size for the 'adaptive' segmentation method.

    Args:
        originHeight: height of the original image
        originWidth: width of the original image
    Returns:
        resized_height: image height after interpolation
        resized_width: image width after interpolation
        resized_patch_height_num: number of vertical patches after interpolation
        resized_patch_width_num: number of horizontal patches after interpolation
    """
    PATCH_SIZE = args.patch_size
    MAX_PATCHES = args.patch_num_width * args.patch_num_height
    patchHeight = PATCH_SIZE
    patchWidth = PATCH_SIZE
    maxPatches = MAX_PATCHES
    # Scale so the patch grid fits the patch budget while preserving aspect ratio.
    scale = math.sqrt(maxPatches * (patchHeight / originHeight) * (patchWidth / originWidth))
    resized_patch_height_num = max(min(math.floor(scale * originHeight / patchHeight), maxPatches), 1)
    resized_patch_width_num = max(min(math.floor(scale * originWidth / patchWidth), maxPatches), 1)
    resized_height = max(resized_patch_height_num * PATCH_SIZE, 1)
    resized_width = max(resized_patch_width_num * PATCH_SIZE, 1)
    return resized_height, resized_width, resized_patch_height_num, resized_patch_width_num
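# Illustrative check of adapt_size (assumes the default 16x16 grid of 16-px
# patches, i.e. a 256-patch budget): for a 768x1024 (HxW) image,
# scale = sqrt(256 * (16/768) * (16/1024)) ≈ 0.2887, giving a 13x18 patch grid
# and a 208x288 resize:
#   adapt_size(768, 1024)  ->  (208, 288, 13, 18)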
def cal_num_of_slices(origin_image_width, origin_image_height, max_num):
    """Choose a best_w x best_h slice grid whose aspect ratio is closest (in log
    space) to the original image's; the slice count is driven by image area."""
    PATCH_SIZE = args.patch_size
    IMAGE_WIDTH = PATCH_SIZE * args.patch_num_width
    IMAGE_HEIGHT = PATCH_SIZE * args.patch_num_height
    scale = origin_image_width * origin_image_height / (IMAGE_WIDTH * IMAGE_HEIGHT)
    scale = math.ceil(scale)
    if scale > max_num:
        scale = max_num

    def factorize(n):
        # All (w/h ratio, w, h) factor pairs with w * h == n.
        factors = []
        for i in range(1, n + 1):
            if n % i == 0:
                factors.append((i / (n / i), i, n // i))
        return factors

    factor_dict = {num: factorize(num) for num in range(1, 16)}
    log_origin_ratio = math.log(origin_image_width / origin_image_height)
    # Consider grids of scale-1, scale and scale+1 slices, then pick the best ratio.
    if scale <= 2:
        available_ratios = factor_dict[scale] + factor_dict[scale + 1]
    else:
        available_ratios = factor_dict[scale - 1] + factor_dict[scale] + factor_dict[scale + 1]
    min_dif = float('inf')
    best_w = 0
    best_h = 0
    for (r, w_slice, h_slice) in available_ratios:
        log_r = math.log(r)
        if min_dif > abs(log_r - log_origin_ratio):
            min_dif = abs(log_r - log_origin_ratio)
            best_w = w_slice
            best_h = h_slice
    return best_w, best_h


def get_patch_nums(origin_image_width, origin_image_height, max_num):
    """Given the original image size, return:
        slice_w_num: number of patches per slice along w
        slice_h_num: number of patches per slice along h
        abstract_w_num: number of patches of the full (abstract) image along w
        abstract_h_num: number of patches of the full (abstract) image along h
    """
    best_w, best_h = cal_num_of_slices(origin_image_width, origin_image_height, max_num)
    slice_width = origin_image_width // best_w
    slice_height = origin_image_height // best_h
    _, _, slice_h_num, slice_w_num = adapt_size(slice_height, slice_width)
    _, _, abstract_h_num, abstract_w_num = adapt_size(origin_image_height, origin_image_width)
    return slice_w_num, slice_h_num, abstract_w_num, abstract_h_num


def slice_image(image, max_num):
    """Crop the image into a best_w x best_h grid of slices (row-major)."""
    origin_image_width = image.size[0]
    origin_image_height = image.size[1]
    best_w, best_h = cal_num_of_slices(origin_image_width=origin_image_width,
                                       origin_image_height=origin_image_height,
                                       max_num=max_num)
    slices = []
    for j in range(best_h):
        for i in range(best_w):
            box = (i * origin_image_width // best_w,
                   j * origin_image_height // best_h,
                   (i + 1) * origin_image_width // best_w,
                   (j + 1) * origin_image_height // best_h)
            region = image.crop(box).convert("RGB")
            slices.append(region)
    return slices
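# Worked example for cal_num_of_slices / slice_image (assumes the defaults
# above, so a 256x256 = 65536-px base tile): for a 1280x720 image, max_num=9:
#   scale = ceil(1280*720 / 65536) = 15, capped to 9; candidate grids come from
#   8, 9 and 10 slices, and the grid whose log aspect ratio is closest to
#   log(1280/720) ≈ 0.575 is 4x2 (log 2 ≈ 0.693), so
#   cal_num_of_slices(1280, 720, 9) -> (4, 2): four columns by two rows of crops.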
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, threshold=1):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate all (i, j) tilings with min_num <= i*j <= max_num.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Find the tiling whose aspect ratio is closest to the image's.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size, threshold)

    # Target size and number of tiles.
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize, then crop the i-th image_size x image_size tile in row-major order.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        # Append a global thumbnail view of the whole image.
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def process_image(image, image_size, max_num):
    """Slice the image for the 'adaptive' segmentation method and return the
    slices (plus a global thumbnail when the image is split more than once)."""
    PATCH_SIZE = args.patch_size
    PATCH_NUM_WIDTH = args.patch_num_width
    PATCH_NUM_HEIGHT = args.patch_num_height
    MAX_PATCHES = PATCH_NUM_WIDTH * PATCH_NUM_HEIGHT
    TOKEN_LENGTH = 3 * PATCH_SIZE * PATCH_SIZE
    IMAGE_WIDTH = PATCH_SIZE * PATCH_NUM_WIDTH
    IMAGE_HEIGHT = PATCH_SIZE * PATCH_NUM_HEIGHT

    image = image.convert("RGB")
    slices = slice_image(image, max_num)
    if len(slices) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        slices.append(thumbnail_img)

    # Compute the post-resize size of the first slice and repack it into a
    # zero-padded patch grid. Note: the tensor built below is discarded; only
    # the PIL slices are returned, matching the original behavior.
    image = slices[0]
    image_w = image.size[0]
    image_h = image.size[1]
    resized_height, resized_width, resized_patch_height, resized_patch_width = adapt_size(image_h, image_w)
    image = ToTensor()(image)
    image = torch.nn.functional.interpolate(
        image.unsqueeze(0),
        size=(resized_height, resized_width),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    ).squeeze(0)
    # Number of patches that must be masked out.
    num_patches_to_pad = MAX_PATCHES - resized_patch_height * resized_patch_width
    # Split the resized image into patches.
    image = torch_extract_patches(image, PATCH_SIZE, PATCH_SIZE)
    image = image.reshape([resized_patch_width * resized_patch_height, TOKEN_LENGTH])
    # Zero-pad the masked part of the patch grid.
    image = torch.nn.functional.pad(image, [0, 0, 0, num_patches_to_pad]).float()
    image = image.reshape(PATCH_NUM_WIDTH, PATCH_NUM_HEIGHT, PATCH_SIZE, PATCH_SIZE, 3) \
        .permute(0, 2, 1, 3, 4).reshape(IMAGE_WIDTH, IMAGE_HEIGHT, 3).permute(2, 0, 1)
    return slices
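# Hedged walk-through of dynamic_preprocess: for a 1280x720 image with
# image_size=448, max_num=9 and threshold=0.5, the 16:9 ratio (≈1.78) first
# matches the (2, 1) grid (ratio 2.0, perimeter diff 0.33 <= 0.5); the (4, 2)
# grid ties on ratio and wins the area tie-break, so the image is resized to
# 1792x896 and cut into 8 tiles of 448x448, plus one thumbnail when
# use_thumbnail=True -> 9 images in total.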
def load_image(image_file, input_size=448, max_num=9):
    image = Image.open(image_file).convert('RGB')
    if args.clip_model_name == 'InternViT-448':
        transform = build_transform(input_size=input_size)
        if args.image_segment_method == 'adaptive':
            images_processed = process_image(image, input_size, max_num)
        elif args.image_segment_method == 'dynamic':
            images_processed = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True,
                                                  max_num=max_num, threshold=args.shape_change_threshold)
        pixel_values = [transform(image) for image in images_processed]
    else:
        transform = build_transform(input_size=input_size)
        if args.image_segment_method == 'adaptive':
            images_processed = process_image(image, input_size, max_num)
        elif args.image_segment_method == 'dynamic':
            images_processed = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True,
                                                  max_num=max_num)
        pixel_values = [transform(image) for image in images_processed]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


def process_input(args, num_token_per_tile, image_path, question):
    """Build the text prompt and batched image tensor for one or more images.
    The empty strings concatenated below stand in for the model's special image
    tokens (begin / placeholder / end); their literal strings are model-specific
    and appear to have been stripped from this copy of the script."""
    image_prompts = ''
    if len(image_path) >= 2:
        image_list = []
        num_tile_per_image_list = []
        for ipath in image_path:
            images = load_image(ipath, max_num=args.max_split_tile_num_multi_image)
            num_tile_this_image = len(images)
            num_tile_per_image_list.append(num_tile_this_image)
            image_list.append(images)
            image_prompts = image_prompts + '' + '' * num_tile_this_image * num_token_per_tile + ''
        num_tile_per_image_tensor = torch.Tensor(num_tile_per_image_list).long().cuda()
        image_tensor = torch.cat(image_list, dim=0).view(1, -1, 3, 448, 448).cuda()
    else:
        images = load_image(image_path[0], max_num=args.max_split_tile_num_single_image)
        num_tile_this_image = len(images)
        num_tile_per_image_tensor = torch.Tensor([num_tile_this_image]).long().cuda()
        image_tensor = images.view(1, -1, 3, 448, 448).cuda()
        image_prompts = image_prompts + '' + '' * num_tile_this_image * num_token_per_tile + ''

    if args.fp16:
        image_tensor = image_tensor.half()
    elif args.bf16:
        image_tensor = image_tensor.bfloat16()
    else:
        image_tensor = image_tensor.float()

    images_input = {'num_tile_per_image_tensor': num_tile_per_image_tensor,
                    'image_tensor': image_tensor}
    prompts = ['' + image_prompts + question[0] + '']
    return prompts, images_input
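# Shape sketch (assuming the defaults above): a single image split into 9
# dynamic tiles of 448x448 plus a thumbnail yields image_tensor of shape
# (1, 10, 3, 448, 448) and num_tile_per_image_tensor == tensor([10]);
# each tile is represented by num_token_per_tile = int(1024 * 0.5**2) = 256
# placeholder tokens in the prompt.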
def _build_yuanvl_attention_mask_and_position_ids(tokenizer, tokens, images_input=None):
    """Build the attention mask and position ids for the input tokens."""
    # Since we are not interested in loss masking and reset attention/position
    # is also False, eod_token is not actually used here. The empty strings
    # below stand in for the model's special-token strings (BOS, image start,
    # image end, pad, sep, EOD), which are model-specific and appear to have
    # been stripped from this copy of the script.
    bos_token, image_start_token, image_end_token, pad_token, sep_token, eod_token = (
        tokenizer(tok)['input_ids'][0] for tok in ['', '', '', '', '', ''])
    attention_mask, position_ids, image_info = get_ltor_masks_and_position_ids_yuanvl_inference(
        tokens, bos_token, image_start_token, image_end_token, eod_token, pad_token, images_input)
    return attention_mask, position_ids, image_info


def get_ltor_masks_and_position_ids_yuanvl_inference(data, bos_token, image_start_token,
                                                     image_end_token, eod_token, pad_token,
                                                     images_input, reset_attention_mask=False):
    """Build masks and position ids for a left-to-right model."""
    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()
    assert micro_batch_size == 1, 'yuanvl supports mbs = 1 only'

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    if images_input is not None:
        num_tile_per_image_tensor = images_input['num_tile_per_image_tensor']
        image_info = {}
        position_ids_use = torch.zeros(data.shape).to(position_ids)
        for b in range(micro_batch_size):
            image_start_index = position_ids[b, data[b] == image_start_token]
            image_end_index = position_ids[b, data[b] == image_end_token]
            image_info['num_tile'] = num_tile_per_image_tensor
            image_info['image_start_pos'] = image_start_index.tolist()
            # Restart position ids after the last image segment so the text
            # following the final image is numbered from zero.
            start_idx = image_end_index[-1]
            diff = seq_length - start_idx
            position_ids_use[b][start_idx:] = torch.arange(diff, dtype=torch.long, device=data.device)
    else:
        position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
        position_ids = position_ids.unsqueeze(0)
        position_ids_use = position_ids
        image_info = None
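# Behavior sketch of the position-id reset (single image, as the mbs == 1
# assert implies): if the last image-end token sits at index k in a prompt of
# seq_length tokens, position_ids_use is 0 everywhere before k and then counts
# 0, 1, 2, ... over the remaining seq_length - k positions, i.e. the text
# after the image is renumbered from zero while the image placeholder span
# keeps position 0.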
    # Convert the attention mask to binary (True marks positions to mask out).
    attention_mask = (attention_mask < 0.5)

    return attention_mask, position_ids_use, image_info


# ---- Inference demo ----
tokenizer_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"
model_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"

# Load the local model.
model = AutoModel.from_pretrained(
    model_loadpath,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    device_map="auto",
    trust_remote_code=True).eval()
print("Model created")

# Load the local tokenizer.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_loadpath)

num_token_per_tile = int(args.clip_visual_size * args.downsample_ratio ** 2)

# Demo 1
image_path = ['/mnt/beegfs3/zhaoxudong/code/image.jpeg']
question = ['Please describe the content of this picture']

prompts, images_input = process_input(args, num_token_per_tile, image_path, question)
inputs = tokenizer(prompts, return_tensors="pt")
input_ids = inputs['input_ids'].to("cuda")
pixel_values = images_input['image_tensor']
attention_mask, position_ids, image_info = _build_yuanvl_attention_mask_and_position_ids(
    tokenizer, input_ids, images_input)
# The causal mask built above is replaced by the tokenizer's padding mask
# before generation, as in the original script; only position_ids are kept.
attention_mask = inputs['attention_mask'].to("cuda")
output = model.generate(pixel_values=pixel_values,
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids)
print(tokenizer.decode(output[0]))
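# Multi-image usage sketch (the paths below are hypothetical): two or more
# paths route through the multi-image branch of process_input, which caps each
# image at max_split_tile_num_multi_image (= 1) tiles, so no thumbnail is added:
#   image_path = ['/path/to/a.jpeg', '/path/to/b.jpeg']  # hypothetical
#   prompts, images_input = process_input(args, num_token_per_tile, image_path, question)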