Fix code to comport with newer Transformers library
#41
by
ctranslate2-4you
- opened
- modeling_GOT.py +39 -54
modeling_GOT.py
CHANGED
|
@@ -393,59 +393,46 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 393 |
def prepare_inputs_for_generation(
|
| 394 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 395 |
):
|
| 396 |
-
|
|
|
|
|
|
|
| 397 |
if past_key_values is not None:
|
| 398 |
if isinstance(past_key_values, Cache):
|
| 399 |
cache_length = past_key_values.get_seq_length()
|
| 400 |
-
|
| 401 |
-
|
|
|
|
| 402 |
else:
|
| 403 |
-
cache_length =
|
|
|
|
| 404 |
max_cache_length = None
|
| 405 |
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
elif past_length < input_ids.shape[1]:
|
| 415 |
-
input_ids = input_ids[:, past_length:]
|
| 416 |
-
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 417 |
-
|
| 418 |
-
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 419 |
-
if (
|
| 420 |
-
max_cache_length is not None
|
| 421 |
-
and attention_mask is not None
|
| 422 |
-
and cache_length + input_ids.shape[1] > max_cache_length
|
| 423 |
-
):
|
| 424 |
-
attention_mask = attention_mask[:, -max_cache_length:]
|
| 425 |
|
| 426 |
position_ids = kwargs.get("position_ids", None)
|
| 427 |
if attention_mask is not None and position_ids is None:
|
| 428 |
-
# create position_ids on the fly for batch generation
|
| 429 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 430 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 431 |
if past_key_values:
|
| 432 |
-
position_ids = position_ids[:, -input_ids.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
|
| 434 |
-
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 435 |
-
if inputs_embeds is not None and past_key_values is None:
|
| 436 |
-
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 437 |
-
else:
|
| 438 |
-
model_inputs = {"input_ids": input_ids}
|
| 439 |
-
|
| 440 |
-
model_inputs.update(
|
| 441 |
-
{
|
| 442 |
-
"position_ids": position_ids,
|
| 443 |
-
"past_key_values": past_key_values,
|
| 444 |
-
"use_cache": kwargs.get("use_cache"),
|
| 445 |
-
"attention_mask": attention_mask,
|
| 446 |
-
"images": kwargs.get("images", None),
|
| 447 |
-
}
|
| 448 |
-
)
|
| 449 |
return model_inputs
|
| 450 |
|
| 451 |
def initialize_vision_tokenizer(
|
|
@@ -536,7 +523,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 536 |
|
| 537 |
conv_mpt = Conversation(
|
| 538 |
system="""<|im_start|>system
|
| 539 |
-
|
| 540 |
# system = None,
|
| 541 |
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 542 |
version="mpt",
|
|
@@ -728,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 728 |
return processed_images
|
| 729 |
|
| 730 |
|
| 731 |
-
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag
|
| 732 |
# Model
|
| 733 |
self.disable_torch_init()
|
| 734 |
multi_page=False
|
|
@@ -778,21 +765,18 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 778 |
image_tensor_1 = image_processor_high(image)
|
| 779 |
image_list.append(image_tensor_1)
|
| 780 |
|
| 781 |
-
|
| 782 |
image_list = torch.stack(image_list)
|
| 783 |
|
| 784 |
-
print('====new images batch size======: \n',image_list.shape)
|
| 785 |
-
|
| 786 |
|
| 787 |
if use_im_start_end:
|
| 788 |
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
|
| 789 |
else:
|
| 790 |
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
| 791 |
|
| 792 |
-
|
| 793 |
conv_mpt = Conversation(
|
| 794 |
system="""<|im_start|>system
|
| 795 |
-
|
| 796 |
# system = None,
|
| 797 |
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 798 |
version="mpt",
|
|
@@ -811,8 +795,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 811 |
print(prompt)
|
| 812 |
|
| 813 |
inputs = tokenizer([prompt])
|
| 814 |
-
|
| 815 |
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
|
|
|
| 816 |
|
| 817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 818 |
keywords = [stop_str]
|
|
@@ -824,25 +808,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 824 |
output_ids = self.generate(
|
| 825 |
input_ids,
|
| 826 |
images=[image_list.half().cuda()],
|
|
|
|
| 827 |
do_sample=False,
|
| 828 |
-
num_beams = 1,
|
| 829 |
-
# no_repeat_ngram_size = 20,
|
| 830 |
streamer=streamer,
|
|
|
|
| 831 |
max_new_tokens=4096,
|
| 832 |
stopping_criteria=[stopping_criteria]
|
| 833 |
-
|
|
|
|
| 834 |
else:
|
| 835 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 836 |
output_ids = self.generate(
|
| 837 |
input_ids,
|
| 838 |
images=[image_list.half().cuda()],
|
|
|
|
| 839 |
do_sample=False,
|
| 840 |
-
num_beams = 1,
|
| 841 |
-
# no_repeat_ngram_size = 20,
|
| 842 |
# streamer=streamer,
|
|
|
|
| 843 |
max_new_tokens=4096,
|
| 844 |
stopping_criteria=[stopping_criteria]
|
| 845 |
-
|
| 846 |
|
| 847 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 848 |
|
|
|
|
| 393 |
def prepare_inputs_for_generation(
|
| 394 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 395 |
):
|
| 396 |
+
if attention_mask is None:
|
| 397 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
|
| 398 |
+
|
| 399 |
if past_key_values is not None:
|
| 400 |
if isinstance(past_key_values, Cache):
|
| 401 |
cache_length = past_key_values.get_seq_length()
|
| 402 |
+
current_length = cache_length
|
| 403 |
+
max_cache_shape = past_key_values.get_max_cache_shape()
|
| 404 |
+
max_cache_length = max_cache_shape[1] if max_cache_shape else None
|
| 405 |
else:
|
| 406 |
+
cache_length = past_key_values[0][0].shape[2]
|
| 407 |
+
current_length = cache_length
|
| 408 |
max_cache_length = None
|
| 409 |
|
| 410 |
+
if attention_mask.shape[1] > input_ids.shape[1]:
|
| 411 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - cache_length):]
|
| 412 |
+
elif cache_length < input_ids.shape[1]:
|
| 413 |
+
input_ids = input_ids[:, cache_length:]
|
| 414 |
+
|
| 415 |
+
if max_cache_length is not None and attention_mask is not None:
|
| 416 |
+
if cache_length + input_ids.shape[1] > max_cache_length:
|
| 417 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
position_ids = kwargs.get("position_ids", None)
|
| 420 |
if attention_mask is not None and position_ids is None:
|
|
|
|
| 421 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 422 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 423 |
if past_key_values:
|
| 424 |
+
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 425 |
+
|
| 426 |
+
model_inputs = {
|
| 427 |
+
"input_ids": input_ids if inputs_embeds is None or past_key_values is not None else None,
|
| 428 |
+
"inputs_embeds": inputs_embeds if past_key_values is None else None,
|
| 429 |
+
"past_key_values": past_key_values,
|
| 430 |
+
"position_ids": position_ids,
|
| 431 |
+
"attention_mask": attention_mask,
|
| 432 |
+
"images": kwargs.get("images", None),
|
| 433 |
+
"use_cache": kwargs.get("use_cache", True)
|
| 434 |
+
}
|
| 435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
return model_inputs
|
| 437 |
|
| 438 |
def initialize_vision_tokenizer(
|
|
|
|
| 523 |
|
| 524 |
conv_mpt = Conversation(
|
| 525 |
system="""<|im_start|>system
|
| 526 |
+
You should follow the instructions carefully and explain your answers in detail.""",
|
| 527 |
# system = None,
|
| 528 |
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 529 |
version="mpt",
|
|
|
|
| 715 |
return processed_images
|
| 716 |
|
| 717 |
|
| 718 |
+
def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
|
| 719 |
# Model
|
| 720 |
self.disable_torch_init()
|
| 721 |
multi_page=False
|
|
|
|
| 765 |
image_tensor_1 = image_processor_high(image)
|
| 766 |
image_list.append(image_tensor_1)
|
| 767 |
|
|
|
|
| 768 |
image_list = torch.stack(image_list)
|
| 769 |
|
| 770 |
+
# print('====new images batch size======: \n',image_list.shape)
|
|
|
|
| 771 |
|
| 772 |
if use_im_start_end:
|
| 773 |
qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
|
| 774 |
else:
|
| 775 |
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
| 776 |
|
|
|
|
| 777 |
conv_mpt = Conversation(
|
| 778 |
system="""<|im_start|>system
|
| 779 |
+
You should follow the instructions carefully and explain your answers in detail.""",
|
| 780 |
# system = None,
|
| 781 |
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 782 |
version="mpt",
|
|
|
|
| 795 |
print(prompt)
|
| 796 |
|
| 797 |
inputs = tokenizer([prompt])
|
|
|
|
| 798 |
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
| 799 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
|
| 800 |
|
| 801 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 802 |
keywords = [stop_str]
|
|
|
|
| 808 |
output_ids = self.generate(
|
| 809 |
input_ids,
|
| 810 |
images=[image_list.half().cuda()],
|
| 811 |
+
attention_mask=attention_mask,
|
| 812 |
do_sample=False,
|
|
|
|
|
|
|
| 813 |
streamer=streamer,
|
| 814 |
+
num_beams=1,
|
| 815 |
max_new_tokens=4096,
|
| 816 |
stopping_criteria=[stopping_criteria]
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
else:
|
| 820 |
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 821 |
output_ids = self.generate(
|
| 822 |
input_ids,
|
| 823 |
images=[image_list.half().cuda()],
|
| 824 |
+
attention_mask=attention_mask,
|
| 825 |
do_sample=False,
|
|
|
|
|
|
|
| 826 |
# streamer=streamer,
|
| 827 |
+
num_beams=1,
|
| 828 |
max_new_tokens=4096,
|
| 829 |
stopping_criteria=[stopping_criteria]
|
| 830 |
+
)
|
| 831 |
|
| 832 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 833 |
|