Device bug fix
Browse files
custom_got/modeling_GOT.py
CHANGED
|
@@ -20,9 +20,9 @@ DEFAULT_IM_START_TOKEN = '<img>'
|
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
cuda_is_available = torch.cuda.is_available()
|
| 22 |
if cuda_is_available:
|
| 23 |
-
|
| 24 |
else:
|
| 25 |
-
|
| 26 |
|
| 27 |
from enum import auto, Enum
|
| 28 |
class SeparatorStyle(Enum):
|
|
@@ -574,7 +574,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 574 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 575 |
|
| 576 |
if stream_flag:
|
| 577 |
-
with torch.autocast(
|
| 578 |
output_ids = self.generate(
|
| 579 |
input_ids,
|
| 580 |
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
|
@@ -586,7 +586,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 586 |
stopping_criteria=[stopping_criteria]
|
| 587 |
)
|
| 588 |
else:
|
| 589 |
-
with torch.autocast(
|
| 590 |
output_ids = self.generate(
|
| 591 |
input_ids,
|
| 592 |
images=[image_tensor_1.unsqueeze(0).half().cuda()] if cuda_is_available else [image_tensor_1.unsqueeze(0).half()],
|
|
|
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
cuda_is_available = torch.cuda.is_available()
|
| 22 |
if cuda_is_available:
|
| 23 |
+
device_type = torch.device('cuda').type
|
| 24 |
else:
|
| 25 |
+
device_type = torch.device('cpu').type
|
| 26 |
|
| 27 |
from enum import auto, Enum
|
| 28 |
class SeparatorStyle(Enum):
|
|
|
|
| 574 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 575 |
|
| 576 |
if stream_flag:
|
| 577 |
+
with torch.autocast(device_type, dtype=torch.bfloat16):
|
| 578 |
output_ids = self.generate(
|
| 579 |
input_ids,
|
| 580 |
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
|
|
|
| 586 |
stopping_criteria=[stopping_criteria]
|
| 587 |
)
|
| 588 |
else:
|
| 589 |
+
with torch.autocast(device_type, dtype=torch.bfloat16):
|
| 590 |
output_ids = self.generate(
|
| 591 |
input_ids,
|
| 592 |
images=[image_tensor_1.unsqueeze(0).half().cuda()] if cuda_is_available else [image_tensor_1.unsqueeze(0).half()],
|