Update modeling_got.py
Browse files- modeling_got.py +13 -12
modeling_got.py
CHANGED
|
@@ -18,6 +18,7 @@ DEFAULT_IMAGE_TOKEN = "<image>"
|
|
| 18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
| 19 |
DEFAULT_IM_START_TOKEN = '<img>'
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
|
|
|
| 21 |
|
| 22 |
from enum import auto, Enum
|
| 23 |
class SeparatorStyle(Enum):
|
|
@@ -164,7 +165,7 @@ class GOTQwenModel(Qwen2Model):
|
|
| 164 |
use_im_start_end=False,
|
| 165 |
vision_select_layer=-1,
|
| 166 |
dtype=torch.float16,
|
| 167 |
-
device="cuda"
|
| 168 |
):
|
| 169 |
|
| 170 |
|
|
@@ -453,7 +454,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 453 |
tokenizer,
|
| 454 |
freeze_lm_model=False,
|
| 455 |
pretrained_stage1_model=None,
|
| 456 |
-
device="cuda"
|
| 457 |
):
|
| 458 |
config = self.get_model().config
|
| 459 |
|
|
@@ -558,7 +559,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 558 |
|
| 559 |
image_tensor_1 = image_processor_high(image)
|
| 560 |
|
| 561 |
-
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
| 562 |
|
| 563 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 564 |
keywords = [stop_str]
|
|
@@ -566,10 +567,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 566 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
|
| 568 |
if stream_flag:
|
| 569 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 570 |
output_ids = self.generate(
|
| 571 |
input_ids,
|
| 572 |
-
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 573 |
do_sample=False,
|
| 574 |
num_beams = 1,
|
| 575 |
no_repeat_ngram_size = 20,
|
|
@@ -578,10 +579,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 578 |
stopping_criteria=[stopping_criteria]
|
| 579 |
)
|
| 580 |
else:
|
| 581 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 582 |
output_ids = self.generate(
|
| 583 |
input_ids,
|
| 584 |
-
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 585 |
do_sample=False,
|
| 586 |
num_beams = 1,
|
| 587 |
no_repeat_ngram_size = 20,
|
|
@@ -812,7 +813,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 812 |
|
| 813 |
inputs = tokenizer([prompt])
|
| 814 |
|
| 815 |
-
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
| 816 |
|
| 817 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 818 |
keywords = [stop_str]
|
|
@@ -820,10 +821,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 820 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
|
| 822 |
if stream_flag:
|
| 823 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 824 |
output_ids = self.generate(
|
| 825 |
input_ids,
|
| 826 |
-
images=[image_list.half().cuda()],
|
| 827 |
do_sample=False,
|
| 828 |
num_beams = 1,
|
| 829 |
# no_repeat_ngram_size = 20,
|
|
@@ -832,10 +833,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 832 |
stopping_criteria=[stopping_criteria]
|
| 833 |
)
|
| 834 |
else:
|
| 835 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 836 |
output_ids = self.generate(
|
| 837 |
input_ids,
|
| 838 |
-
images=[image_list.half().cuda()],
|
| 839 |
do_sample=False,
|
| 840 |
num_beams = 1,
|
| 841 |
# no_repeat_ngram_size = 20,
|
|
|
|
| 18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
| 19 |
DEFAULT_IM_START_TOKEN = '<img>'
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
|
| 22 |
|
| 23 |
from enum import auto, Enum
|
| 24 |
class SeparatorStyle(Enum):
|
|
|
|
| 165 |
use_im_start_end=False,
|
| 166 |
vision_select_layer=-1,
|
| 167 |
dtype=torch.float16,
|
| 168 |
+
device=device
|
| 169 |
):
|
| 170 |
|
| 171 |
|
|
|
|
| 454 |
tokenizer,
|
| 455 |
freeze_lm_model=False,
|
| 456 |
pretrained_stage1_model=None,
|
| 457 |
+
device=device
|
| 458 |
):
|
| 459 |
config = self.get_model().config
|
| 460 |
|
|
|
|
| 559 |
|
| 560 |
image_tensor_1 = image_processor_high(image)
|
| 561 |
|
| 562 |
+
input_ids = torch.as_tensor(inputs.input_ids).to(device)
|
| 563 |
|
| 564 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 565 |
keywords = [stop_str]
|
|
|
|
| 567 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 568 |
|
| 569 |
if stream_flag:
|
| 570 |
+
with torch.autocast(device, dtype=torch.bfloat16):
|
| 571 |
output_ids = self.generate(
|
| 572 |
input_ids,
|
| 573 |
+
images=[image_tensor_1.unsqueeze(0).half().to(device)],
|
| 574 |
do_sample=False,
|
| 575 |
num_beams = 1,
|
| 576 |
no_repeat_ngram_size = 20,
|
|
|
|
| 579 |
stopping_criteria=[stopping_criteria]
|
| 580 |
)
|
| 581 |
else:
|
| 582 |
+
with torch.autocast(device, dtype=torch.bfloat16):
|
| 583 |
output_ids = self.generate(
|
| 584 |
input_ids,
|
| 585 |
+
images=[image_tensor_1.unsqueeze(0).half().to(device)],
|
| 586 |
do_sample=False,
|
| 587 |
num_beams = 1,
|
| 588 |
no_repeat_ngram_size = 20,
|
|
|
|
| 813 |
|
| 814 |
inputs = tokenizer([prompt])
|
| 815 |
|
| 816 |
+
input_ids = torch.as_tensor(inputs.input_ids).to(device)
|
| 817 |
|
| 818 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 819 |
keywords = [stop_str]
|
|
|
|
| 821 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 822 |
|
| 823 |
if stream_flag:
|
| 824 |
+
with torch.autocast(device, dtype=torch.bfloat16):
|
| 825 |
output_ids = self.generate(
|
| 826 |
input_ids,
|
| 827 |
+
images=[image_list.half().to(device)],
|
| 828 |
do_sample=False,
|
| 829 |
num_beams = 1,
|
| 830 |
# no_repeat_ngram_size = 20,
|
|
|
|
| 833 |
stopping_criteria=[stopping_criteria]
|
| 834 |
)
|
| 835 |
else:
|
| 836 |
+
with torch.autocast(device, dtype=torch.bfloat16):
|
| 837 |
output_ids = self.generate(
|
| 838 |
input_ids,
|
| 839 |
+
images=[image_list.half().to(device)],
|
| 840 |
do_sample=False,
|
| 841 |
num_beams = 1,
|
| 842 |
# no_repeat_ngram_size = 20,
|