srimanth-d committed on
Commit
7e2b339
·
verified ·
1 Parent(s): 9646a9b

Update modeling_got.py

Browse files
Files changed (1) hide show
  1. modeling_got.py +94 -38
modeling_got.py CHANGED
@@ -18,7 +18,7 @@ DEFAULT_IMAGE_TOKEN = "<image>"
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
  DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
- device = "cpu"
22
  print("Using device ",device)
23
 
24
  from enum import auto, Enum
@@ -568,29 +568,57 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
568
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
569
 
570
  if stream_flag:
571
- with torch.autocast("cpu", dtype=torch.bfloat16):
 
 
 
 
 
 
 
 
 
 
 
 
572
  output_ids = self.generate(
573
- input_ids,
574
- images=[image_tensor_1.unsqueeze(0).half().cpu()],
575
- do_sample=False,
576
- num_beams = 1,
577
- no_repeat_ngram_size = 20,
578
- streamer=streamer,
579
- max_new_tokens=4096,
580
- stopping_criteria=[stopping_criteria]
581
  )
 
 
582
  else:
583
- with torch.autocast("cpu", dtype=torch.bfloat16):
 
 
 
 
 
 
 
 
 
 
 
 
584
  output_ids = self.generate(
585
- input_ids,
586
- images=[image_tensor_1.unsqueeze(0).half().cpu()],
587
- do_sample=False,
588
- num_beams = 1,
589
- no_repeat_ngram_size = 20,
590
- # streamer=streamer,
591
- max_new_tokens=4096,
592
- stopping_criteria=[stopping_criteria]
593
  )
 
 
594
 
595
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
596
 
@@ -822,29 +850,57 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
822
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
823
 
824
  if stream_flag:
825
- with torch.autocast("cpu", dtype=torch.bfloat16):
826
- output_ids = self.generate(
827
- input_ids,
828
- images=[image_list.half().cpu()],
829
- do_sample=False,
830
- num_beams = 1,
831
- # no_repeat_ngram_size = 20,
832
- streamer=streamer,
833
- max_new_tokens=4096,
834
- stopping_criteria=[stopping_criteria]
 
835
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
836
  else:
837
- with torch.autocast("cpu", dtype=torch.bfloat16):
 
 
 
 
 
 
 
 
 
 
 
 
838
  output_ids = self.generate(
839
- input_ids,
840
- images=[image_list.half().cpu()],
841
- do_sample=False,
842
- num_beams = 1,
843
- # no_repeat_ngram_size = 20,
844
- # streamer=streamer,
845
- max_new_tokens=4096,
846
- stopping_criteria=[stopping_criteria]
847
  )
 
 
848
 
849
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
850
 
 
18
  DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
  DEFAULT_IM_START_TOKEN = '<img>'
20
  DEFAULT_IM_END_TOKEN = '</img>'
21
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
22
  print("Using device ",device)
23
 
24
  from enum import auto, Enum
 
568
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
569
 
570
  if stream_flag:
571
+ if device == "cuda":
572
+ with torch.autocast("cuda", dtype=torch.bfloat16):
573
+ output_ids = self.generate(
574
+ input_ids,
575
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
576
+ do_sample=False,
577
+ num_beams = 1,
578
+ no_repeat_ngram_size = 20,
579
+ streamer=streamer,
580
+ max_new_tokens=4096,
581
+ stopping_criteria=[stopping_criteria]
582
+ )
583
+ elif device == "mps" or device == "cpu":
584
  output_ids = self.generate(
585
+ input_ids,
586
+ images=[image_tensor_1.unsqueeze(0).half().to(device)],
587
+ do_sample=False,
588
+ num_beams = 1,
589
+ no_repeat_ngram_size = 20,
590
+ streamer=streamer,
591
+ max_new_tokens=4096,
592
+ stopping_criteria=[stopping_criteria]
593
  )
594
+ else:
595
+ print("Device unknown!")
596
  else:
597
+ if device == "cuda":
598
+ with torch.autocast("cuda", dtype=torch.bfloat16):
599
+ output_ids = self.generate(
600
+ input_ids,
601
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
602
+ do_sample=False,
603
+ num_beams = 1,
604
+ no_repeat_ngram_size = 20,
605
+ # streamer=streamer,
606
+ max_new_tokens=4096,
607
+ stopping_criteria=[stopping_criteria]
608
+ )
609
+ elif device == "mps" or device == "cpu":
610
  output_ids = self.generate(
611
+ input_ids,
612
+ images=[image_tensor_1.unsqueeze(0).half().to(device)],
613
+ do_sample=False,
614
+ num_beams = 1,
615
+ no_repeat_ngram_size = 20,
616
+ # streamer=streamer,
617
+ max_new_tokens=4096,
618
+ stopping_criteria=[stopping_criteria]
619
  )
620
+ else:
621
+ print("Device unknown!")
622
 
623
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
624
 
 
850
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
851
 
852
  if stream_flag:
853
+ if device == "cuda":
854
+ with torch.autocast("cuda", dtype=torch.bfloat16):
855
+ output_ids = self.generate(
856
+ input_ids,
857
+ images=[image_list.half().cuda()],
858
+ do_sample=False,
859
+ num_beams = 1,
860
+ # no_repeat_ngram_size = 20,
861
+ streamer=streamer,
862
+ max_new_tokens=4096,
863
+ stopping_criteria=[stopping_criteria]
864
  )
865
+ elif device == "mps" or device == "cpu":
866
+ output_ids = self.generate(
867
+ input_ids,
868
+ images=[image_list.half().to(device)],
869
+ do_sample=False,
870
+ num_beams = 1,
871
+ # no_repeat_ngram_size = 20,
872
+ streamer=streamer,
873
+ max_new_tokens=4096,
874
+ stopping_criteria=[stopping_criteria]
875
+ )
876
+ else:
877
+ print("Device unknown!")
878
  else:
879
+ if device == "cuda":
880
+ with torch.autocast("cuda", dtype=torch.bfloat16):
881
+ output_ids = self.generate(
882
+ input_ids,
883
+ images=[image_list.half().cuda()],
884
+ do_sample=False,
885
+ num_beams = 1,
886
+ # no_repeat_ngram_size = 20,
887
+ # streamer=streamer,
888
+ max_new_tokens=4096,
889
+ stopping_criteria=[stopping_criteria]
890
+ )
891
+ elif device == "mps" or device == "cpu":
892
  output_ids = self.generate(
893
+ input_ids,
894
+ images=[image_list.half().to(device)],
895
+ do_sample=False,
896
+ num_beams = 1,
897
+ # no_repeat_ngram_size = 20,
898
+ # streamer=streamer,
899
+ max_new_tokens=4096,
900
+ stopping_criteria=[stopping_criteria]
901
  )
902
+ else:
903
+ print("Device unknown!")
904
 
905
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
906