add to mps device in _build_GOT_vision
Browse files- modeling_GOT.py +6 -6
modeling_GOT.py
CHANGED
|
@@ -164,7 +164,7 @@ class GOTQwenModel(Qwen2Model):
|
|
| 164 |
use_im_start_end=False,
|
| 165 |
vision_select_layer=-1,
|
| 166 |
dtype=torch.float16,
|
| 167 |
-
device="
|
| 168 |
):
|
| 169 |
|
| 170 |
|
|
@@ -453,7 +453,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 453 |
tokenizer,
|
| 454 |
freeze_lm_model=False,
|
| 455 |
pretrained_stage1_model=None,
|
| 456 |
-
device="
|
| 457 |
):
|
| 458 |
config = self.get_model().config
|
| 459 |
|
|
@@ -566,7 +566,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 566 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
|
| 568 |
if stream_flag:
|
| 569 |
-
with torch.autocast("
|
| 570 |
output_ids = self.generate(
|
| 571 |
input_ids,
|
| 572 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
@@ -578,7 +578,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 578 |
stopping_criteria=[stopping_criteria]
|
| 579 |
)
|
| 580 |
else:
|
| 581 |
-
with torch.autocast("
|
| 582 |
output_ids = self.generate(
|
| 583 |
input_ids,
|
| 584 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
@@ -820,7 +820,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 820 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
|
| 822 |
if stream_flag:
|
| 823 |
-
with torch.autocast("
|
| 824 |
output_ids = self.generate(
|
| 825 |
input_ids,
|
| 826 |
images=[image_list.half()],
|
|
@@ -832,7 +832,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 832 |
stopping_criteria=[stopping_criteria]
|
| 833 |
)
|
| 834 |
else:
|
| 835 |
-
with torch.autocast("
|
| 836 |
output_ids = self.generate(
|
| 837 |
input_ids,
|
| 838 |
images=[image_list.half()],
|
|
|
|
| 164 |
use_im_start_end=False,
|
| 165 |
vision_select_layer=-1,
|
| 166 |
dtype=torch.float16,
|
| 167 |
+
device="mps"
|
| 168 |
):
|
| 169 |
|
| 170 |
|
|
|
|
| 453 |
tokenizer,
|
| 454 |
freeze_lm_model=False,
|
| 455 |
pretrained_stage1_model=None,
|
| 456 |
+
device="mps"
|
| 457 |
):
|
| 458 |
config = self.get_model().config
|
| 459 |
|
|
|
|
| 566 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 567 |
|
| 568 |
if stream_flag:
|
| 569 |
+
with torch.autocast("mps", dtype=torch.bfloat16):
|
| 570 |
output_ids = self.generate(
|
| 571 |
input_ids,
|
| 572 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
|
|
| 578 |
stopping_criteria=[stopping_criteria]
|
| 579 |
)
|
| 580 |
else:
|
| 581 |
+
with torch.autocast("mps", dtype=torch.bfloat16):
|
| 582 |
output_ids = self.generate(
|
| 583 |
input_ids,
|
| 584 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
|
|
| 820 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 821 |
|
| 822 |
if stream_flag:
|
| 823 |
+
with torch.autocast("mps", dtype=torch.float16):
|
| 824 |
output_ids = self.generate(
|
| 825 |
input_ids,
|
| 826 |
images=[image_list.half()],
|
|
|
|
| 832 |
stopping_criteria=[stopping_criteria]
|
| 833 |
)
|
| 834 |
else:
|
| 835 |
+
with torch.autocast("mps", dtype=torch.float16):
|
| 836 |
output_ids = self.generate(
|
| 837 |
input_ids,
|
| 838 |
images=[image_list.half()],
|