srimanth-d committed on
Commit
9ed59ac
·
verified ·
1 Parent(s): 5364fe1

Update modeling_got.py

Browse files
Files changed (1) hide show
  1. modeling_got.py +13 -12
modeling_got.py CHANGED
@@ -18,6 +18,7 @@ DEFAULT_IMAGE_TOKEN = "<image>"
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
  DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
 
21
 
22
  from enum import auto, Enum
23
  class SeparatorStyle(Enum):
@@ -164,7 +165,7 @@ class GOTQwenModel(Qwen2Model):
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
167
- device="cpu"
168
  ):
169
 
170
 
@@ -453,7 +454,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
453
  tokenizer,
454
  freeze_lm_model=False,
455
  pretrained_stage1_model=None,
456
- device="cpu"
457
  ):
458
  config = self.get_model().config
459
 
@@ -558,7 +559,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
558
 
559
  image_tensor_1 = image_processor_high(image)
560
 
561
- input_ids = torch.as_tensor(inputs.input_ids).cpu()
562
 
563
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
564
  keywords = [stop_str]
@@ -566,10 +567,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
  if stream_flag:
569
- with torch.autocast("cpu", dtype=torch.bfloat16):
570
  output_ids = self.generate(
571
  input_ids,
572
- images=[image_tensor_1.unsqueeze(0).half().cpu()],
573
  do_sample=False,
574
  num_beams = 1,
575
  no_repeat_ngram_size = 20,
@@ -578,10 +579,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
578
  stopping_criteria=[stopping_criteria]
579
  )
580
  else:
581
- with torch.autocast("cpu", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
- images=[image_tensor_1.unsqueeze(0).half().cpu()],
585
  do_sample=False,
586
  num_beams = 1,
587
  no_repeat_ngram_size = 20,
@@ -812,7 +813,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
812
 
813
  inputs = tokenizer([prompt])
814
 
815
- input_ids = torch.as_tensor(inputs.input_ids).cpu()
816
 
817
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
818
  keywords = [stop_str]
@@ -820,10 +821,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
820
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
  if stream_flag:
823
- with torch.autocast("cpu", dtype=torch.bfloat16):
824
  output_ids = self.generate(
825
  input_ids,
826
- images=[image_list.half().cpu()],
827
  do_sample=False,
828
  num_beams = 1,
829
  # no_repeat_ngram_size = 20,
@@ -832,10 +833,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
832
  stopping_criteria=[stopping_criteria]
833
  )
834
  else:
835
- with torch.autocast("cpu", dtype=torch.bfloat16):
836
  output_ids = self.generate(
837
  input_ids,
838
- images=[image_list.half().cpu()],
839
  do_sample=False,
840
  num_beams = 1,
841
  # no_repeat_ngram_size = 20,
 
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
  DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
22
 
23
  from enum import auto, Enum
24
  class SeparatorStyle(Enum):
 
165
  use_im_start_end=False,
166
  vision_select_layer=-1,
167
  dtype=torch.float16,
168
+ device=device
169
  ):
170
 
171
 
 
454
  tokenizer,
455
  freeze_lm_model=False,
456
  pretrained_stage1_model=None,
457
+ device=device
458
  ):
459
  config = self.get_model().config
460
 
 
559
 
560
  image_tensor_1 = image_processor_high(image)
561
 
562
+ input_ids = torch.as_tensor(inputs.input_ids).to(device)
563
 
564
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
565
  keywords = [stop_str]
 
567
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
568
 
569
  if stream_flag:
570
+ with torch.autocast(device, dtype=torch.bfloat16):
571
  output_ids = self.generate(
572
  input_ids,
573
+ images=[image_tensor_1.unsqueeze(0).half().to(device)],
574
  do_sample=False,
575
  num_beams = 1,
576
  no_repeat_ngram_size = 20,
 
579
  stopping_criteria=[stopping_criteria]
580
  )
581
  else:
582
+ with torch.autocast(device, dtype=torch.bfloat16):
583
  output_ids = self.generate(
584
  input_ids,
585
+ images=[image_tensor_1.unsqueeze(0).half().to(device)],
586
  do_sample=False,
587
  num_beams = 1,
588
  no_repeat_ngram_size = 20,
 
813
 
814
  inputs = tokenizer([prompt])
815
 
816
+ input_ids = torch.as_tensor(inputs.input_ids).to(device)
817
 
818
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
819
  keywords = [stop_str]
 
821
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
822
 
823
  if stream_flag:
824
+ with torch.autocast(device, dtype=torch.bfloat16):
825
  output_ids = self.generate(
826
  input_ids,
827
+ images=[image_list.half().to(device)],
828
  do_sample=False,
829
  num_beams = 1,
830
  # no_repeat_ngram_size = 20,
 
833
  stopping_criteria=[stopping_criteria]
834
  )
835
  else:
836
+ with torch.autocast(device, dtype=torch.bfloat16):
837
  output_ids = self.generate(
838
  input_ids,
839
+ images=[image_list.half().to(device)],
840
  do_sample=False,
841
  num_beams = 1,
842
  # no_repeat_ngram_size = 20,