Enable on CPU
#1
by
xf2022
- opened
- modeling_videochat_flash.py +12 -3
- vision_tower_builder.py +7 -5
modeling_videochat_flash.py
CHANGED
|
@@ -636,7 +636,10 @@ class VideoChatFlashQwenForCausalLM(LlavaMetaForCausalLM, Qwen2ForCausalLM_Flash
|
|
| 636 |
|
| 637 |
image_sizes = [frames[0].shape[:2]]
|
| 638 |
|
| 639 |
-
frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.dtype).cuda()]
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
conv = conv_templates["qwen_2"].copy()
|
| 642 |
|
|
@@ -652,14 +655,20 @@ class VideoChatFlashQwenForCausalLM(LlavaMetaForCausalLM, Qwen2ForCausalLM_Flash
|
|
| 652 |
|
| 653 |
prompt = conv.get_prompt()
|
| 654 |
|
| 655 |
-
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
if tokenizer.pad_token_id is None:
|
| 658 |
if "qwen" in tokenizer.name_or_path.lower():
|
| 659 |
print("Setting pad token to bos token for qwen model.")
|
| 660 |
tokenizer.pad_token_id = 151643
|
| 661 |
|
| 662 |
-
attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
|
|
|
|
|
|
|
|
|
|
| 663 |
|
| 664 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 665 |
keywords = [stop_str]
|
|
|
|
| 636 |
|
| 637 |
image_sizes = [frames[0].shape[:2]]
|
| 638 |
|
| 639 |
+
if torch.cuda.is_available():
|
| 640 |
+
frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.dtype).cuda()]
|
| 641 |
+
else:
|
| 642 |
+
frames = [self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.dtype)]
|
| 643 |
|
| 644 |
conv = conv_templates["qwen_2"].copy()
|
| 645 |
|
|
|
|
| 655 |
|
| 656 |
prompt = conv.get_prompt()
|
| 657 |
|
| 658 |
+
if torch.cuda.is_available():
|
| 659 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
|
| 660 |
+
else:
|
| 661 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
|
| 662 |
|
| 663 |
if tokenizer.pad_token_id is None:
|
| 664 |
if "qwen" in tokenizer.name_or_path.lower():
|
| 665 |
print("Setting pad token to bos token for qwen model.")
|
| 666 |
tokenizer.pad_token_id = 151643
|
| 667 |
|
| 668 |
+
if torch.cuda.is_available():
|
| 669 |
+
attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
|
| 670 |
+
else:
|
| 671 |
+
attention_masks = input_ids.ne(tokenizer.pad_token_id).long()
|
| 672 |
|
| 673 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 674 |
keywords = [stop_str]
|
vision_tower_builder.py
CHANGED
|
@@ -24,9 +24,11 @@ from transformers.image_utils import (
|
|
| 24 |
to_numpy_array,
|
| 25 |
)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
| 29 |
-
from flash_attn.bert_padding import unpad_input, pad_input
|
|
|
|
|
|
| 30 |
|
| 31 |
class FlashAttention(nn.Module):
|
| 32 |
"""Implement the scaled dot product attention with softmax.
|
|
@@ -729,7 +731,7 @@ class InternVideo2VisionConfig:
|
|
| 729 |
patch_size=14,
|
| 730 |
x_vis_return_idx=-2,
|
| 731 |
sep_image_video_pos_embed=True,
|
| 732 |
-
use_checkpoint=True,
|
| 733 |
checkpoint_num=40,
|
| 734 |
# **kwargs,
|
| 735 |
):
|
|
@@ -757,7 +759,7 @@ def build_vit(config, pt_type='origin'):
|
|
| 757 |
drop_path_rate=0.25,
|
| 758 |
init_values=0.00001,
|
| 759 |
qk_normalization=True,
|
| 760 |
-
use_flash_attn=True,
|
| 761 |
use_fused_rmsnorm=False,
|
| 762 |
use_fused_mlp=False,
|
| 763 |
fused_mlp_heuristic=1,
|
|
|
|
| 24 |
to_numpy_array,
|
| 25 |
)
|
| 26 |
|
| 27 |
+
try:
|
| 28 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
| 29 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
| 30 |
+
except:
|
| 31 |
+
pass
|
| 32 |
|
| 33 |
class FlashAttention(nn.Module):
|
| 34 |
"""Implement the scaled dot product attention with softmax.
|
|
|
|
| 731 |
patch_size=14,
|
| 732 |
x_vis_return_idx=-2,
|
| 733 |
sep_image_video_pos_embed=True,
|
| 734 |
+
use_checkpoint=False,
|
| 735 |
checkpoint_num=40,
|
| 736 |
# **kwargs,
|
| 737 |
):
|
|
|
|
| 759 |
drop_path_rate=0.25,
|
| 760 |
init_values=0.00001,
|
| 761 |
qk_normalization=True,
|
| 762 |
+
use_flash_attn=torch.cuda.is_available(),
|
| 763 |
use_fused_rmsnorm=False,
|
| 764 |
use_fused_mlp=False,
|
| 765 |
fused_mlp_heuristic=1,
|