kirchik47 committed on
Commit
3153dd0
·
1 Parent(s): 2286e1a

Device bug fix

Browse files
Files changed (1) hide show
  1. custom_got/modeling_GOT.py +4 -4
custom_got/modeling_GOT.py CHANGED
@@ -20,9 +20,9 @@ DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
  cuda_is_available = torch.cuda.is_available()
22
  if cuda_is_available:
23
- device = torch.device('cuda')
24
  else:
25
- device = torch.device('cpu')
26
 
27
  from enum import auto, Enum
28
  class SeparatorStyle(Enum):
@@ -574,7 +574,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
574
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
575
 
576
  if stream_flag:
577
- with torch.autocast(device, dtype=torch.bfloat16):
578
  output_ids = self.generate(
579
  input_ids,
580
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
@@ -586,7 +586,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
586
  stopping_criteria=[stopping_criteria]
587
  )
588
  else:
589
- with torch.autocast(device, dtype=torch.bfloat16):
590
  output_ids = self.generate(
591
  input_ids,
592
  images=[image_tensor_1.unsqueeze(0).half().cuda()] if cuda_is_available else [image_tensor_1.unsqueeze(0).half()],
 
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
  cuda_is_available = torch.cuda.is_available()
22
  if cuda_is_available:
23
+ device_type = torch.device('cuda').type
24
  else:
25
+ device_type = torch.device('cpu').type
26
 
27
  from enum import auto, Enum
28
  class SeparatorStyle(Enum):
 
574
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
575
 
576
  if stream_flag:
577
+ with torch.autocast(device_type, dtype=torch.bfloat16):
578
  output_ids = self.generate(
579
  input_ids,
580
  images=[image_tensor_1.unsqueeze(0).half().cuda()],
 
586
  stopping_criteria=[stopping_criteria]
587
  )
588
  else:
589
+ with torch.autocast(device_type, dtype=torch.bfloat16):
590
  output_ids = self.generate(
591
  input_ids,
592
  images=[image_tensor_1.unsqueeze(0).half().cuda()] if cuda_is_available else [image_tensor_1.unsqueeze(0).half()],