|
|
import torch |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
from PIL import Image |
|
|
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ |
|
|
CenterCrop, ColorJitter, Grayscale |
|
|
import math |
|
|
|
|
|
FILE_EXTENSIONS = ('.jpeg', '.txt', '.idx') |
|
|
''' |
|
|
args = { |
|
|
"patch_size": 16, |
|
|
"patch_num_width": 16, |
|
|
"patch_num_height": 16, |
|
|
"position_embedding_length": 4096, |
|
|
"clip_model_name": 'InternViT-448', |
|
|
"image_segment_method": 'dynamic', |
|
|
"max_split_tile_num_multi_image": 1, |
|
|
"clip_visual_size": 1024, |
|
|
"clip_hidden_size": 1024, |
|
|
"downsample_ratio": 0.5 |
|
|
} |
|
|
''' |
|
|
class args:
    """Static configuration namespace for this script.

    Read as class attributes (args.patch_size); never instantiated.
    Mirrors the commented-out dict above.
    """
    patch_size = 16                      # side length (px) of one ViT patch
    patch_num_width = 16                 # max patches along width
    patch_num_height = 16                # max patches along height
    position_embedding_length = 4096
    clip_model_name = 'InternViT-448'    # selects the preprocessing branch in load_image
    image_segment_method = 'dynamic'     # 'dynamic' (InternVL tiling) or 'adaptive' (slice_image)
    max_split_tile_num_multi_image = 1   # tile cap per image when several images are given
    max_split_tile_num_single_image = 9  # tile cap for a single image
    clip_visual_size = 1024              # visual tokens emitted per tile before downsampling
    clip_hidden_size = 1024
    downsample_ratio = 0.5               # spatial downsample of visual tokens (applied squared)
    shape_change_threshold = 0.5         # threshold passed to find_closest_aspect_ratio
    bf16 = True                          # cast image tensors to bfloat16
    fp16 = False                         # cast image tensors to float16 (takes priority over bf16)
|
|
|
|
|
|
|
|
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size, threshold):
    """Pick the tiling ratio from *target_ratios* whose aspect is nearest *aspect_ratio*.

    A candidate only becomes the new best when the relative difference between
    the tiled edge-length sum and the original edge-length sum stays within
    *threshold*.  Ties on aspect difference prefer a candidate whose tiled
    area is exceeded by twice the original image area.
    Returns the chosen (cols, rows) pair; defaults to (1, 1).
    """
    chosen = (1, 1)
    smallest_diff = float('inf')
    original_area = width * height
    edge_sum = width + height
    for candidate in target_ratios:
        cand_aspect = candidate[0] / candidate[1]
        diff = abs(aspect_ratio - cand_aspect)
        tiled_edge_sum = candidate[0] * image_size + candidate[1] * image_size
        rel_size_diff = abs((tiled_edge_sum - edge_sum) / edge_sum)
        if diff < smallest_diff and rel_size_diff <= threshold:
            smallest_diff = diff
            chosen = candidate
        elif diff == smallest_diff:
            if original_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
                chosen = candidate
    return chosen
|
|
|
|
|
def build_transform(input_size):
    """Return the torchvision preprocessing pipeline for a square *input_size* image."""
    # CLIP-style normalization statistics.
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize(mean=mean, std=std),
    ]
    return Compose(steps)
|
|
|
|
|
def torch_extract_patches(image_tensor, patch_height, patch_width):
    """Split a (C, H, W) image tensor into non-overlapping patches.

    Removed: a block of unused constants read from ``args`` and a debug print
    that were dead code in the original.

    Args:
        image_tensor: tensor of shape (C, H, W); H and W must be divisible by
            patch_height / patch_width respectively.
        patch_height: patch height in pixels.
        patch_width: patch width in pixels.

    Returns:
        Tensor of shape (1, H // patch_height, W // patch_width,
        C * patch_height * patch_width); each patch is flattened in
        (row, col, channel) order.
    """
    image_tensor = image_tensor.unsqueeze(0)
    # unfold yields (N, C*ph*pw, L) with L patches enumerated row-major.
    patches = torch.nn.functional.unfold(
        image_tensor, (patch_height, patch_width), stride=(patch_height, patch_width))
    # Recover (N, C, ph, pw, L), then reorder so each patch flattens as (ph, pw, C).
    patches = patches.reshape(image_tensor.size(0), image_tensor.size(1), patch_height, patch_width, -1)
    patches = patches.permute(0, 4, 2, 3, 1).reshape(
        image_tensor.size(2) // patch_height,
        image_tensor.size(3) // patch_width,
        image_tensor.size(1) * patch_height * patch_width,
    )
    return patches.unsqueeze(0)
|
|
|
|
|
|
|
|
def adapt_size(originHeight: int, originWeight: int):
    """Scale an image so its patch grid fits within the configured budget.

    Removed: a block of unused constants and a debug print that were dead
    code in the original.

    Args:
        originHeight: original image height in pixels.
        originWeight: original image WIDTH in pixels (historical misspelling
            of "width", kept for keyword-argument compatibility).

    Returns:
        (resized_height, resized_width, resized_patch_height_num,
        resized_patch_width_num); the pixel sizes are exact multiples of
        args.patch_size and the patch counts are clamped to [1, max_patches].
    """
    patch_size = args.patch_size
    max_patches = args.patch_num_width * args.patch_num_height
    # Uniform scale factor that makes the patch-grid area equal max_patches.
    scale = math.sqrt(max_patches * (patch_size / originHeight) * (patch_size / originWeight))
    resized_patch_height_num = max(min(math.floor(scale * originHeight / patch_size), max_patches), 1)
    resized_patch_width_num = max(min(math.floor(scale * originWeight / patch_size), max_patches), 1)
    resized_height = max(resized_patch_height_num * patch_size, 1)
    resized_width = max(resized_patch_width_num * patch_size, 1)
    return resized_height, resized_width, resized_patch_height_num, resized_patch_width_num
|
|
|
|
|
def cal_num_of_slices(origin_image_width, origin_image_height, max_num):
    """Choose a (columns, rows) slicing grid for an image.

    Candidate grids are the factor pairs of the area-based tile count and its
    neighbours; the pair whose aspect ratio is closest (in log space) to the
    original image's wins.  Fixes dead constants, a debug print, and a
    potential KeyError in the original (which only pre-factorized 1..15, so
    scale + 1 == 16 crashed); factor pairs are now computed on demand.

    Returns:
        (best_w, best_h): number of slices along width and height.
    """
    image_width = args.patch_size * args.patch_num_width
    image_height = args.patch_size * args.patch_num_height

    # Rough number of base-resolution tiles needed to cover the image.
    scale = math.ceil(origin_image_width * origin_image_height / (image_width * image_height))
    scale = min(scale, max_num)

    def factorize(n):
        # All (w/h ratio, w, h) triples with w * h == n, ascending in w.
        return [(i / (n // i), i, n // i) for i in range(1, n + 1) if n % i == 0]

    # Also consider neighbouring tile counts so a better aspect fit can win.
    if scale <= 2:
        candidate_counts = (scale, scale + 1)
    else:
        candidate_counts = (scale - 1, scale, scale + 1)
    available_ratios = [r for n in candidate_counts for r in factorize(n)]

    log_origin_ratio = math.log(origin_image_width / origin_image_height)
    min_dif = 1000
    best_w = 0
    best_h = 0
    for r, w_slice, h_slice in available_ratios:
        dif = abs(math.log(r) - log_origin_ratio)
        if dif < min_dif:  # strict: ties keep the earliest candidate
            min_dif = dif
            best_w = w_slice
            best_h = h_slice
    return best_w, best_h
|
|
|
|
|
def get_patch_nums(origin_image_width, origin_image_height, max_num):
    """Return patch-grid dimensions for one slice and for the whole image.

    Removed: a block of unused constants and a debug print that were dead
    code in the original.

    Args:
        origin_image_width: original image width in pixels.
        origin_image_height: original image height in pixels.
        max_num: maximum slice count forwarded to cal_num_of_slices.

    Returns:
        (slice_w_num, slice_h_num, abstract_w_num, abstract_h_num): patch
        counts along width/height for a single slice and for the full
        ("abstract") view.
    """
    best_w, best_h = cal_num_of_slices(origin_image_width, origin_image_height, max_num)
    slice_width = origin_image_width // best_w
    slice_height = origin_image_height // best_h
    # adapt_size takes (height, width) and returns (..., h_num, w_num).
    _, _, slice_h_num, slice_w_num = adapt_size(slice_height, slice_width)
    _, _, abstract_h_num, abstract_w_num = adapt_size(origin_image_height, origin_image_width)
    return slice_w_num, slice_h_num, abstract_w_num, abstract_h_num
|
|
|
|
|
def slice_image(image, max_num):
    """Cut a PIL image into a best_w x best_h grid of RGB crops.

    Removed: a block of unused constants and a debug print that were dead
    code in the original.

    Args:
        image: PIL.Image to slice.
        max_num: maximum slice count forwarded to cal_num_of_slices.

    Returns:
        List of PIL.Image crops in row-major order (left-to-right, then
        top-to-bottom).
    """
    origin_image_width = image.size[0]
    origin_image_height = image.size[1]

    best_w, best_h = cal_num_of_slices(
        origin_image_width=origin_image_width,
        origin_image_height=origin_image_height,
        max_num=max_num,
    )

    slices = []
    for j in range(best_h):
        for i in range(best_w):
            # Kept as (i * W) // best_w (not i * (W // best_w)): edges are
            # distributed so the grid exactly covers the image.
            box = (i * origin_image_width // best_w,
                   j * origin_image_height // best_h,
                   (i + 1) * origin_image_width // best_w,
                   (j + 1) * origin_image_height // best_h)
            slices.append(image.crop(box).convert("RGB"))
    return slices
|
|
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, threshold=1):
    """InternVL-style dynamic tiling.

    Resizes the image to the closest tiling aspect ratio and cuts it into
    image_size x image_size tiles.  Removed: a per-tile debug ``print(box)``;
    the ``target_width // image_size`` column count is now hoisted out of the
    loop.

    Args:
        image: input PIL.Image.
        min_num / max_num: inclusive bounds on the number of tiles.
        image_size: side length of each square tile in pixels.
        use_thumbnail: when True and more than one tile results, append a
            full-image thumbnail as an extra tile.
        threshold: relative edge-length tolerance forwarded to
            find_closest_aspect_ratio.

    Returns:
        List of PIL.Image tiles in row-major order, optionally plus thumbnail.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All (cols, rows) grids whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    target_ratios = sorted(
        {(i, j)
         for n in range(min_num, max_num + 1)
         for i in range(1, n + 1)
         for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size, threshold)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    cols = target_width // image_size  # == target_aspect_ratio[0]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images
|
|
|
|
|
def process_image(image, image_size, max_num):
    """Adaptive preprocessing: slice the image into tiles, plus a thumbnail.

    When more than one tile results, an image_size x image_size thumbnail of
    the whole image is appended.

    NOTE(review): the original also built a padded patch tensor from the
    first slice (adapt_size / interpolate / torch_extract_patches / pad /
    reshape) and then discarded it; that dead computation and its debug
    print have been removed — only the returned slices were ever used.

    Args:
        image: input PIL.Image.
        image_size: side length of the thumbnail in pixels.
        max_num: maximum slice count forwarded to slice_image.

    Returns:
        List of PIL.Image tiles (plus the optional thumbnail).
    """
    image = image.convert("RGB")
    slices = slice_image(image, max_num)
    if len(slices) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        slices.append(thumbnail_img)
    return slices
|
|
|
|
|
def _convert_to_rgb(image): |
|
|
return image.convert('RGB') |
|
|
|
|
|
def load_image(image_file, input_size=448, max_num=9):
    """Load an image file and return a stacked tensor of preprocessed tiles.

    Deduplicates the two near-identical branches of the original: the
    transform and the 'adaptive' path were byte-identical in both; only the
    'dynamic' path differs (InternViT-448 passes the shape-change threshold).
    An unknown segment method previously crashed with a NameError on
    ``images_processed``; it now raises ValueError.

    Args:
        image_file: path to the image file.
        input_size: tile side length in pixels.
        max_num: maximum number of tiles.

    Returns:
        Tensor of shape (num_tiles, 3, input_size, input_size).

    Raises:
        ValueError: if args.image_segment_method is neither 'adaptive' nor
            'dynamic'.
    """
    image = Image.open(image_file).convert('RGB')
    transform = build_transform(input_size=input_size)

    if args.image_segment_method == 'adaptive':
        images_processed = process_image(image, input_size, max_num)
    elif args.image_segment_method == 'dynamic':
        if args.clip_model_name == 'InternViT-448':
            # Only the InternViT-448 path applies the shape-change threshold.
            images_processed = dynamic_preprocess(
                image, image_size=input_size, use_thumbnail=True,
                max_num=max_num, threshold=args.shape_change_threshold)
        else:
            images_processed = dynamic_preprocess(
                image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    else:
        raise ValueError(f"unknown image_segment_method: {args.image_segment_method!r}")

    pixel_values = torch.stack([transform(img) for img in images_processed])
    return pixel_values
|
|
|
|
|
def preocess_imput(args, num_token_per_tile, image_path, question):
    """Build the text prompt and image tensors for one generation request.

    NOTE(review): the name is a typo for "process_input"; kept because it is
    this function's public name.  The ``args`` parameter shadows the
    module-level ``args`` class (callers pass that same class in).

    Args:
        args: config namespace (max_split_tile_num_*, fp16/bf16 flags).
        num_token_per_tile: number of <pad> placeholder tokens per image tile.
        image_path: list of image file paths.
        question: list whose first element is the question text.

    Returns:
        (prompts, images_input): prompts is a single-element list with the
        full prompt string; images_input holds the per-image tile counts and
        the stacked image tensor, both moved to CUDA.
    """
    image_prompts = ''
    if len(image_path) >= 2:
        # Multi-image: tile each image separately under the multi-image cap.
        image_list = []
        num_tile_per_image_list = []
        for ipath in image_path:
            images = load_image(ipath, max_num=args.max_split_tile_num_multi_image)
            num_tile_this_image = len(images)
            num_tile_per_image_list.append(num_tile_this_image)
            image_list.append(images)
            # One <pad> placeholder per visual token, bracketed by delimiters.
            image_prompts = image_prompts + '<IMAGE>' + '<pad>' * num_tile_this_image * num_token_per_tile + '</IMAGE>'
        num_tile_per_image_tensor = torch.Tensor(num_tile_per_image_list).long().cuda()
        image_tensor = torch.cat(image_list, dim=0).view(1, -1, 3, 448, 448).cuda()
    else:
        # Single image: a higher tile cap applies.
        images = load_image(image_path[0], max_num=args.max_split_tile_num_single_image)
        num_tile_this_image = len(images)
        num_tile_per_image_tensor = torch.Tensor([num_tile_this_image]).long().cuda()
        image_tensor = images.view(1, -1, 3, 448, 448).cuda()
        image_prompts = image_prompts + '<IMAGE>' + '<pad>' * num_tile_this_image * num_token_per_tile + '</IMAGE>'

    # Cast image tensor to the model's compute dtype (fp16 wins over bf16).
    if args.fp16:
        image_tensor = image_tensor.half()
    elif args.bf16:
        image_tensor = image_tensor.bfloat16()
    else:
        image_tensor = image_tensor.float()

    images_input = {'num_tile_per_image_tensor': num_tile_per_image_tensor,
                    'image_tensor': image_tensor}

    prompts = ['<BOS>' + image_prompts + question[0] + '<sep>']

    return prompts, images_input
|
|
|
|
|
|
|
|
def _build_yuanvl_attention_mask_and_position_ids(tokenizer, tokens, images_input=None):
    """Build the attention mask, position ids and image info for *tokens*.

    Fixes: removed a commented-out dead call to get_ltor_masks_and_position_ids,
    fixed the "postition" docstring typo and the ``sep_tpken`` local name typo.

    Args:
        tokenizer: HF tokenizer used to resolve the special-token ids.
        tokens: (1, seq_len) tensor of input ids.
        images_input: optional dict with 'num_tile_per_image_tensor' and
            'image_tensor' as produced by preocess_imput.

    Returns:
        (attention_mask, position_ids, image_info) as computed by
        get_ltor_masks_and_position_ids_yuanvl_inference.
    """
    # Each special token is assumed to tokenize to a single id; sep/eod are
    # resolved here even though only eod is forwarded.
    bos_token, image_start_token, image_end_token, pad_token, sep_token, eod_token = (
        tokenizer(tok)['input_ids'][0]
        for tok in ['<BOS>', '<IMAGE>', '</IMAGE>', '<pad>', '<sep>', '<eod>'])

    attention_mask, position_ids, image_info = get_ltor_masks_and_position_ids_yuanvl_inference(
        tokens,
        bos_token,
        image_start_token,
        image_end_token,
        eod_token,
        pad_token,
        images_input)

    return attention_mask, position_ids, image_info
|
|
|
|
|
def get_ltor_masks_and_position_ids_yuanvl_inference(data,
                                                     bos_token,
                                                     image_start_token,
                                                     image_end_token,
                                                     eod_token,
                                                     pad_token,
                                                     images_input,
                                                     reset_attention_mask=False):
    """Build masks and position id for left to right model.

    Args:
        data: (1, seq_len) tensor of input ids (micro batch size must be 1).
        bos_token / image_start_token / image_end_token / eod_token /
            pad_token: special-token ids.
        images_input: optional dict with 'num_tile_per_image_tensor' and
            'image_tensor'; when present, position ids are restarted after
            the last </IMAGE> token.
        reset_attention_mask: when True, one mask per sample; otherwise a
            single shared mask.

    Returns:
        (attention_mask, position_ids_use, image_info): boolean causal mask
        (True = blocked), position ids, and image metadata (or None when
        images_input is None).
    """

    micro_batch_size, seq_length = data.size()
    assert micro_batch_size == 1, 'yuanvl support mbs = 1 only'

    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    # Standard lower-triangular causal mask, shape (b, 1, seq, seq).
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
        att_mask_batch, 1, seq_length, seq_length)

    # Plain 0..seq-1 positions; used below to locate special tokens.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)

    if images_input is not None:
        num_tile_per_image_tensor = images_input['num_tile_per_image_tensor']
        # NOTE(review): images_tensor and input_pad are never used below.
        images_tensor = images_input['image_tensor']
        input_pad = []
        image_info = {}
        # Positions stay zero through the image region; only tokens from the
        # last </IMAGE> onward receive increasing position ids.
        position_ids_use = torch.zeros(data.shape).to(position_ids)
        for b in range(micro_batch_size):
            # Sequence positions of each special token in this sample.
            bos_index = position_ids[b, data[b] == bos_token]
            pad_index = position_ids[b, data[b] == pad_token]
            image_start_index = position_ids[b, data[b] == image_start_token]
            image_end_index = position_ids[b, data[b] == image_end_token]

            num_image = len(num_tile_per_image_tensor)

            image_info['num_tile'] = num_tile_per_image_tensor

            image_info['image_start_pos'] = image_start_index.tolist()

            # Restart position numbering at the last </IMAGE> token.
            # NOTE(review): raises IndexError when the prompt contains no
            # </IMAGE> token — confirm callers always include images here.
            start_idx = image_end_index[-1]
            diff = seq_length - start_idx
            position_ids_use[b][start_idx : ] = torch.arange(diff, dtype=torch.long,
                                                             device=data.device)
    else:
        # Text-only path: ordinary 0..seq-1 positions, no image metadata.
        position_ids = torch.arange(seq_length, dtype=torch.long,
                                    device=data.device)
        position_ids = position_ids.unsqueeze(0)
        position_ids_use = position_ids
        image_info = None

    # Convert to boolean "masked" convention: True marks blocked positions.
    attention_mask = (attention_mask < 0.5)

    # Commented-out cross-attention mask construction, kept for reference.
    '''xattn_position_ids = torch.arange(seq_length, dtype=torch.long,
                                     device=data.device)
    xattn_position_ids = xattn_position_ids.unsqueeze(0).expand_as(data)

    for b in range(micro_batch_size):

        bos_index = xattn_position_ids[b, data[b] == bos_token]

        num_image = len(bos_index)

        xattn_mask = torch.zeros((micro_batch_size, seq_length, num_image * clip_visual_size), device = data.device).view(micro_batch_size, 1, seq_length, num_image * clip_visual_size)

        for j in range(bos_index.size()[0]):
            sidx = bos_index[j]

            image_sidx = j * clip_visual_size
            image_eidx = (j + 1) * clip_visual_size

            #xattn_mask[b, 0, (sidx + 1) : , image_sidx : image_eidx] = 1
            xattn_mask[b, 0, sidx : , image_sidx : image_eidx] = 1
            #xattn_mask[b, 0, sidx : (eidx + 1), image_sidx : image_eidx] = 1

    xattn_mask = (xattn_mask < 0.5)'''

    return attention_mask, position_ids_use, image_info
|
|
|
|
|
# ---------------------------------------------------------------------------
# Script entry: load the model and tokenizer, preprocess one image/question
# pair, and run generation. Runs at import time (no __main__ guard).
# ---------------------------------------------------------------------------
tokenizer_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"
model_loadpath = "/mnt/beegfs3/zhaoxudong/code/yuanvl_hf_40B_stage2_pcase4_12pp/"

# trust_remote_code loads the model implementation shipped in the checkpoint.
model = AutoModel.from_pretrained(
    model_loadpath,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    device_map="auto",
    trust_remote_code=True).eval()

print("Creat model finish")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_loadpath)

# Visual tokens per tile after spatial downsampling (1024 * 0.5**2 = 256).
num_token_per_tile = int(args.clip_visual_size * args.downsample_ratio**2)

image_path = ['/mnt/beegfs3/zhaoxudong/code/image.jpeg']
question = ['Please describe the picture']
# NOTE(review): the English question above is immediately overwritten by the
# Chinese one; drop one of the two.
question = ['请描述这张图片的内容']

prompts, images_input = preocess_imput(args, num_token_per_tile, image_path, question)

# NOTE(review): `input` shadows the builtin input().
input=tokenizer(prompts, return_tensors="pt")
input_ids = input['input_ids'].to("cuda")
pixel_values=images_input['image_tensor']

attention_mask, position_ids, image_info = _build_yuanvl_attention_mask_and_position_ids(
    tokenizer, input_ids, images_input)

# NOTE(review): this overwrites the custom causal mask built just above with
# the tokenizer's plain padding mask — confirm which mask generate() expects.
attention_mask = input['attention_mask'].to("cuda")

output = model.generate(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
print(tokenizer.decode(output[0]))
|
|
|