Update modeling_got.py
Browse files
- modeling_got.py +94 -38
modeling_got.py
CHANGED
|
@@ -18,7 +18,7 @@ DEFAULT_IMAGE_TOKEN = "<image>"
|
|
| 18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
| 19 |
DEFAULT_IM_START_TOKEN = '<img>'
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
-
device = "cpu"
|
| 22 |
print("Using device ",device)
|
| 23 |
|
| 24 |
from enum import auto, Enum
|
|
@@ -568,29 +568,57 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 568 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 569 |
|
| 570 |
if stream_flag:
|
| 571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
output_ids = self.generate(
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
)
|
|
|
|
|
|
|
| 582 |
else:
|
| 583 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 584 |
output_ids = self.generate(
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
)
|
|
|
|
|
|
|
| 594 |
|
| 595 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 596 |
|
|
@@ -822,29 +850,57 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
| 822 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 823 |
|
| 824 |
if stream_flag:
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
|
|
|
| 835 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
else:
|
| 837 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
output_ids = self.generate(
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
)
|
|
|
|
|
|
|
| 848 |
|
| 849 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 850 |
|
|
|
|
| 18 |
DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
|
| 19 |
DEFAULT_IM_START_TOKEN = '<img>'
|
| 20 |
DEFAULT_IM_END_TOKEN = '</img>'
|
| 21 |
+
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
| 22 |
print("Using device ",device)
|
| 23 |
|
| 24 |
from enum import auto, Enum
|
|
|
|
| 568 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 569 |
|
| 570 |
if stream_flag:
|
| 571 |
+
if device == "cuda":
|
| 572 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 573 |
+
output_ids = self.generate(
|
| 574 |
+
input_ids,
|
| 575 |
+
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 576 |
+
do_sample=False,
|
| 577 |
+
num_beams = 1,
|
| 578 |
+
no_repeat_ngram_size = 20,
|
| 579 |
+
streamer=streamer,
|
| 580 |
+
max_new_tokens=4096,
|
| 581 |
+
stopping_criteria=[stopping_criteria]
|
| 582 |
+
)
|
| 583 |
+
elif device == "mps" or device == "cpu":
|
| 584 |
output_ids = self.generate(
|
| 585 |
+
input_ids,
|
| 586 |
+
images=[image_tensor_1.unsqueeze(0).half().to(device)],
|
| 587 |
+
do_sample=False,
|
| 588 |
+
num_beams = 1,
|
| 589 |
+
no_repeat_ngram_size = 20,
|
| 590 |
+
streamer=streamer,
|
| 591 |
+
max_new_tokens=4096,
|
| 592 |
+
stopping_criteria=[stopping_criteria]
|
| 593 |
)
|
| 594 |
+
else:
|
| 595 |
+
print("Device unknown!")
|
| 596 |
else:
|
| 597 |
+
if device == "cuda":
|
| 598 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 599 |
+
output_ids = self.generate(
|
| 600 |
+
input_ids,
|
| 601 |
+
images=[image_tensor_1.unsqueeze(0).half().cuda()],
|
| 602 |
+
do_sample=False,
|
| 603 |
+
num_beams = 1,
|
| 604 |
+
no_repeat_ngram_size = 20,
|
| 605 |
+
# streamer=streamer,
|
| 606 |
+
max_new_tokens=4096,
|
| 607 |
+
stopping_criteria=[stopping_criteria]
|
| 608 |
+
)
|
| 609 |
+
elif device == "mps" or device == "cpu":
|
| 610 |
output_ids = self.generate(
|
| 611 |
+
input_ids,
|
| 612 |
+
images=[image_tensor_1.unsqueeze(0).half().to(device)],
|
| 613 |
+
do_sample=False,
|
| 614 |
+
num_beams = 1,
|
| 615 |
+
no_repeat_ngram_size = 20,
|
| 616 |
+
# streamer=streamer,
|
| 617 |
+
max_new_tokens=4096,
|
| 618 |
+
stopping_criteria=[stopping_criteria]
|
| 619 |
)
|
| 620 |
+
else:
|
| 621 |
+
print("Device unknown!")
|
| 622 |
|
| 623 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 624 |
|
|
|
|
| 850 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 851 |
|
| 852 |
if stream_flag:
|
| 853 |
+
if device == "cuda":
|
| 854 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 855 |
+
output_ids = self.generate(
|
| 856 |
+
input_ids,
|
| 857 |
+
images=[image_list.half().cuda()],
|
| 858 |
+
do_sample=False,
|
| 859 |
+
num_beams = 1,
|
| 860 |
+
# no_repeat_ngram_size = 20,
|
| 861 |
+
streamer=streamer,
|
| 862 |
+
max_new_tokens=4096,
|
| 863 |
+
stopping_criteria=[stopping_criteria]
|
| 864 |
)
|
| 865 |
+
elif device == "mps" or device == "cpu":
|
| 866 |
+
output_ids = self.generate(
|
| 867 |
+
input_ids,
|
| 868 |
+
images=[image_list.half().to(device)],
|
| 869 |
+
do_sample=False,
|
| 870 |
+
num_beams = 1,
|
| 871 |
+
# no_repeat_ngram_size = 20,
|
| 872 |
+
streamer=streamer,
|
| 873 |
+
max_new_tokens=4096,
|
| 874 |
+
stopping_criteria=[stopping_criteria]
|
| 875 |
+
)
|
| 876 |
+
else:
|
| 877 |
+
print("Device unknown!")
|
| 878 |
else:
|
| 879 |
+
if device == "cuda":
|
| 880 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 881 |
+
output_ids = self.generate(
|
| 882 |
+
input_ids,
|
| 883 |
+
images=[image_list.half().cuda()],
|
| 884 |
+
do_sample=False,
|
| 885 |
+
num_beams = 1,
|
| 886 |
+
# no_repeat_ngram_size = 20,
|
| 887 |
+
# streamer=streamer,
|
| 888 |
+
max_new_tokens=4096,
|
| 889 |
+
stopping_criteria=[stopping_criteria]
|
| 890 |
+
)
|
| 891 |
+
elif device == "mps" or device == "cpu":
|
| 892 |
output_ids = self.generate(
|
| 893 |
+
input_ids,
|
| 894 |
+
images=[image_list.half().to(device)],
|
| 895 |
+
do_sample=False,
|
| 896 |
+
num_beams = 1,
|
| 897 |
+
# no_repeat_ngram_size = 20,
|
| 898 |
+
# streamer=streamer,
|
| 899 |
+
max_new_tokens=4096,
|
| 900 |
+
stopping_criteria=[stopping_criteria]
|
| 901 |
)
|
| 902 |
+
else:
|
| 903 |
+
print("Device unknown!")
|
| 904 |
|
| 905 |
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
|
| 906 |
|