Making the Code Runnable on CPU

#10 by DRXD1000 - opened

This patch removes the hard-coded CUDA placement from modeling_GOT.py so the model also runs on CPU: the `device='cuda'` defaults become `device=None`, `.cuda()` calls become `.to(self.device)`, `torch.autocast` is pointed at the model's own device instead of `"cuda"`, the tokenizer gets a pad token, and the imports are regrouped.

modeling_GOT.py (+39 -21)
```diff
@@ -1,25 +1,37 @@
-from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
-from transformers.cache_utils import Cache
+import dataclasses
+from io import BytesIO
 from typing import List, Optional, Tuple, Union
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
 import requests
-from PIL import Image
-from io import BytesIO
 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn import CrossEntropyLoss
-from .got_vision_b import build_GOT_vit_b
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-import dataclasses
+from transformers import (
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2Model,
+    StoppingCriteria,
+    TextStreamer,
+)
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
+from .got_vision_b import build_GOT_vit_b
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
 DEFAULT_IM_START_TOKEN = '<img>'
 DEFAULT_IM_END_TOKEN = '</img>'
 
-from enum import auto, Enum
+from enum import Enum, auto
+
+
 class SeparatorStyle(Enum):
     """Different separator style."""
     SINGLE = auto()
```
```diff
@@ -164,7 +176,7 @@ class GOTQwenModel(Qwen2Model):
         use_im_start_end=False,
         vision_select_layer=-1,
         dtype=torch.float16,
-        device='cuda'
+        device=None
     ):
 
 
```
```diff
@@ -453,7 +465,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         tokenizer,
         freeze_lm_model=False,
         pretrained_stage1_model=None,
-        device='cuda'
+        device=None
     ):
         config = self.get_model().config
 
```
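Both changed signatures now default to `device=None` instead of a hard-coded `'cuda'`, so the caller decides where the model lives. A minimal sketch of the fallback pattern this enables (the helper name `resolve_device` is illustrative, not part of the patch):

```python
import torch

def resolve_device(device=None):
    """Prefer an explicit device; otherwise use CUDA when available, else CPU."""
    if device is not None:
        return torch.device(device)
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(resolve_device())       # cuda on a GPU machine, cpu otherwise
print(resolve_device("cpu"))  # always cpu, regardless of hardware
```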
```diff
@@ -488,6 +500,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         self.disable_torch_init()
 
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
         image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
```
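The added `tokenizer.pad_token_id = tokenizer.eos_token_id` gives `generate()` a pad token to work with, since Qwen2-style tokenizers may not define one; reusing EOS for padding is the usual workaround. An illustration with a guard (the checkpoint id is only an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # illustrative checkpoint
if tok.pad_token_id is None:          # only patch tokenizers that lack a pad token
    tok.pad_token_id = tok.eos_token_id
```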
```diff
@@ -558,7 +571,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         image_tensor_1 = image_processor_high(image)
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
```
```diff
@@ -566,10 +579,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
```
```diff
@@ -578,10 +591,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                     )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
```
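`torch.autocast` takes a device-type string as its first argument, so the hard-coded `"cuda"` becomes a string derived from the model's placement; on CPU, `str(self.device)` is exactly `"cpu"`, and bfloat16 autocast is supported there. One caveat worth hedging: on a GPU, `str(self.device)` can carry an index (`"cuda:0"`), which some PyTorch versions reject as a `device_type`, so `self.device.type` is a slightly safer spelling. A sketch of the device-agnostic form:

```python
import torch

model_device = torch.device("cpu")  # stand-in for self.device on the patched model

# .type is always a bare "cpu"/"cuda", while str() keeps any index ("cuda:0")
with torch.autocast(model_device.type, dtype=torch.bfloat16):
    y = torch.randn(4, 4) @ torch.randn(4, 4)  # matmul runs under bf16 autocast
```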
```diff
@@ -599,7 +612,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         if render:
             print('==============rendering===============')
-            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
+            from .render_tools import (
+                content_mmd_to_html,
+                svg_to_html,
+                tik_html,
+                translation_table,
+            )
 
             if '**kern' in outputs:
                 import verovio
```
```diff
@@ -812,7 +830,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         inputs = tokenizer([prompt])
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
```
```diff
@@ -820,10 +838,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
```
```diff
@@ -832,10 +850,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                     )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
```
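With the hard-coded `.cuda()` calls gone, the model loads and runs without a GPU. A usage sketch, assuming the `chat()` API from the upstream GOT-OCR2.0 README (the repo id and image path are placeholders for the patched checkpoint):

```python
from transformers import AutoModel, AutoTokenizer

repo_id, image_file = "ucaslcl/GOT-OCR2_0", "sample.png"  # placeholders

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval()  # note: no .cuda() call; weights stay on the CPU

result = model.chat(tokenizer, image_file, ocr_type="ocr")
print(result)
```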