# Yuan3.0-Flash / modeling_yuanvl_chat.py
# root
# update code
# 4e18cb4
# --------------------------------------------------------
# YuanVL
# Copyright (c) 2024 YuanLabAI
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import warnings
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)
import torch.utils.checkpoint
import transformers
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformer_engine.pytorch import RMSNorm
from transformers.activations import ACT2FN
from .configuration_yuanvl import YuanVLChatConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn
from .modeling_yuanlm2 import YuanForCausalLM
from .utils import flatten_bn, merge_multimodal_embeddings
# Module-level logger from transformers' logging utilities; used for the
# informational messages emitted while building YuanVLChatModel.
logger = logging.get_logger(__name__)
class InternVLImagePixelInputs(TypedDict):
    """Image input given as raw pixel tiles.

    Produced by ``YuanVLChatModel._parse_and_validate_image_input`` after the
    per-request pixel values have been flattened with ``flatten_bn`` and each
    tile has been checked to be ``(num_channels, height, width)``.
    """

    # Discriminator tag; always "pixel_values".
    type: Literal["pixel_values"]
    # Flattened image tiles across the whole batch, i.e. iterating it yields one
    # (num_channels, height, width) tile at a time: a tensor of shape
    # (total_patches, num_channels, height, width) or a list of such tiles.
    data: Union[torch.Tensor, List[torch.Tensor]]
    # Number of tiles contributed by each request in the batch (one entry per
    # request); used to split the flattened embeddings back per request.
    patches_per_image: List[int]
class InternVLImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed embeddings instead of pixels.

    ``data`` is documented as either a tensor of shape
    ``(num_images, total_image_feature_size, hidden_size)`` or a list of
    ``(total_image_feature_size, hidden_size)`` tensors; ``hidden_size`` must
    equal the hidden size of the language-model backbone.
    """

    # Discriminator tag; always "image_embeds".
    type: Literal["image_embeds"]
    # Loosely typed as Any; the vLLM counterpart types this field as NestedTensors.
    data: Any
# Either flavour of image input: raw pixel tiles or precomputed embeddings,
# told apart at runtime by their "type" key.
InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs]
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a named ``operator`` function.

    Args:
        v1: Left-hand version string (e.g. ``transformers.__version__``).
        v2: Right-hand version string.
        op: Name of a function in the ``operator`` module, such as
            ``'eq'``, ``'ge'`` or ``'lt'``.

    Returns:
        ``operator.<op>(parse(v1), parse(v2))`` where both strings are parsed
        with ``packaging.version``, so comparison is semantic
        (``'4.10.0'`` > ``'4.9.0'``), not lexical.
    """
    import operator
    from packaging import version

    compare = getattr(operator, op)
    left, right = version.parse(v1), version.parse(v2)
    return compare(left, right)
class YuanImageMLP(nn.Module):
    """Gated MLP that projects visual features to the language-model width.

    Computes ``down_proj(act_fn(up_proj(x)) * gate_proj(x))``. Note that the
    activation is applied to the ``up_proj`` branch, not to ``gate_proj``.
    Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        # Plain bias-free nn.Linear layers (the tensor-parallel variants are not used).
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. Only silu is supported for now.")
        self.act_fn = ACT2FN[hidden_act]

    @torch.compile
    def swiglu(self, y_1, y_2):
        """Return ``act_fn(y_1) * y_2``; compiled with ``torch.compile``."""
        return self.act_fn(y_1) * y_2

    def forward(self, x):
        """Map ``x`` from ``hidden_size`` to ``output_size`` features on the last dim."""
        gated = self.swiglu(self.up_proj(x), self.gate_proj(x))
        return self.down_proj(gated)
class YuanVLChatModel(PreTrainedModel):
    """Vision-language chat model pairing a vision encoder with a Yuan LM.

    Image tiles are encoded by ``self.vision_model``, spatially folded by
    ``self.pixel_unshuffle``, projected to the language model's hidden size by
    ``self.imagemlp`` followed by ``self.imagemlp_layernorm``, and then merged
    into the token embeddings at the ``img_context_token_id`` positions.
    """

    config_class = YuanVLChatConfig
    main_input_name = 'pixel_values'
    base_model_prefix = 'language_model'
    _supports_flash_attn_2 = True
    # 'YuanDeocderLayer' looks like a misspelling of 'YuanDecoderLayer'; both are
    # listed so whichever name the decoder layer class really has is kept whole.
    # NOTE(review): confirm the class name in modeling_yuanlm2 and drop the other.
    _no_split_modules = ['InternVisionModel', 'YuanDeocderLayer', 'YuanDecoderLayer']

    def __init__(self, config: YuanVLChatConfig, vision_model=None, language_model=None,
                 use_flash_attn=True):
        """Build the vision tower, the language model and the image projector.

        Args:
            config: Combined VL config. ``config.vision_config.use_flash_attn``
                and ``config.llm_config._attn_implementation`` are overwritten
                in place according to ``use_flash_attn``.
            vision_model: Optional pre-built vision encoder; an
                ``InternVisionModel(config.vision_config)`` is created if omitted.
            language_model: Optional pre-built language model; a
                ``YuanForCausalLM(config.llm_config)`` is created if omitted.
            use_flash_attn: Request flash attention. Forced to ``False`` when
                ``has_flash_attn`` is false.

        Raises:
            NotImplementedError: If no ``language_model`` is given and
                ``config.llm_config.architectures[0]`` is not ``'YuanForCausalLM'``.
        """
        super().__init__(config)
        assert version_cmp(transformers.__version__, '4.37.0', 'ge')

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        # Visual tokens per tile: (patches per side)^2, shrunk by downsample_ratio^2.
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        use_flash_attn = use_flash_attn if has_flash_attn else False
        config.vision_config.use_flash_attn = bool(use_flash_attn)
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)

        if language_model is not None:
            self.language_model = language_model
        elif config.llm_config.architectures[0] == 'YuanForCausalLM':
            self.language_model = YuanForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        # Fold (1 / downsample_ratio)^2 neighbouring ViT tokens into channels.
        # Deriving the factor from the config keeps it consistent with
        # num_image_token and imagemlp_input_hiddensize (factor 2 for ratio 0.5).
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(
            downscale_factor=int(round(1 / self.downsample_ratio)))

        layernorm_epsilon = config.llm_config.rms_norm_eps
        # Each unshuffled token concatenates (1 / downsample_ratio)^2 ViT features.
        self.imagemlp_input_hiddensize = int(config.vision_config.hidden_size / self.downsample_ratio ** 2)
        self.imagemlp_ffn_hidden_size = config.llm_config.ffn_hidden_size
        self.imagemlp = YuanImageMLP(self.imagemlp_input_hiddensize, self.imagemlp_ffn_hidden_size,
                                     output_size=config.llm_config.hidden_size, hidden_act="silu")
        self.imagemlp_layernorm = RMSNorm(config.llm_config.hidden_size, eps=layernorm_epsilon)

        self.img_context_token_id = config.img_context_token_id
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def _validate_pixel_values(self,
                               data: Union[torch.Tensor, List[torch.Tensor]]
                               ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every element of ``data`` is a ``(3, image_size, image_size)`` tile.

        ``data`` may be a tensor (iterated over its first dim) or a list of
        tensors. It is returned unchanged when every element is valid.

        Raises:
            ValueError: If any element has a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                raise ValueError("The expected shape of pixel values in each batch element "
                                 f" is {expected_dims}. You supplied {actual_dims}.")

        # data can be a tensor or a List[tensor]; either way each iteration yields
        # one image tile, so there are (batch size * images per request) checks.
        for d in data:
            _validate_shape(d)
        return data

    def _parse_and_validate_image_input(self,
                                        pixel_values: List[torch.Tensor] = None,
                                        image_token_id: torch.Tensor = None,
                                        image_embeds: torch.Tensor = None,
                                        ) -> Optional[InternVLImageInputs]:
        """Wrap raw image arguments into a typed image-input dict.

        Args:
            pixel_values: Per-request image tiles (a tensor or a list); the
                first dim of each request entry is its number of tiles.
            image_token_id: Unused; kept for interface compatibility.
            image_embeds: Precomputed image embeddings. Takes precedence over
                ``pixel_values`` when both are given.

        Returns:
            ``None`` when neither input is given; an
            ``InternVLImageEmbeddingInputs`` when ``image_embeds`` is given;
            otherwise an ``InternVLImagePixelInputs`` holding the flattened,
            shape-validated tiles and the per-request tile counts.

        Raises:
            ValueError: On an unexpected input type or tile shape.
        """
        # No image data at all.
        if pixel_values is None and image_embeds is None:
            return None

        # Precomputed embeddings are passed straight through.
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # From here on pixel_values is guaranteed to be set.
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        # One entry per request: how many tiles that request contributes.
        patches_per_image = [request_pixel_values.shape[0] for request_pixel_values in pixel_values]
        # flatten_bn collapses the request dim so data yields one (3, h, w) tile
        # at a time, i.e. (total_patches, 3, h, w).
        return InternVLImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(flatten_bn(pixel_values)),
            patches_per_image=patches_per_image)

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """Turn parsed image input into embeddings in the LM hidden size.

        Returns:
            ``image_input["data"]`` unchanged for precomputed embeddings.
            For a single request, one ``(total_tokens, 1, hidden_size)`` tensor.
            Otherwise a tuple with one ``(num_patches * tokens_per_patch,
            hidden_size)`` tensor per request.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        # (total_patches, tokens_per_patch, llm_config.hidden_size)
        image_embeds = self.extract_feature(image_input["data"])
        patches_per_image = image_input["patches_per_image"]
        hidden_size = self.config.llm_config.hidden_size

        if len(patches_per_image) == 1:
            # Single request: flatten every tile's tokens into
            # (total_tokens, 1, hidden_size).
            return image_embeds.view(-1, hidden_size).unsqueeze(1)

        # Several requests: flatten all tokens, then split them back into one
        # chunk of num_patches * tokens_per_patch rows per request.
        feature_size = image_embeds.shape[1]
        image_feature_sizes = [num_patches * feature_size for num_patches in patches_per_image]
        return image_embeds.view(-1, hidden_size).split(image_feature_sizes)

    def get_multimodal_embeddings(self,
                                  pixel_values: Optional[List[torch.Tensor]] = None,
                                  image_token_id: Optional[List[torch.Tensor]] = None,
                                  image_embeds: Optional[List[torch.Tensor]] = None,
                                  image_input: InternVLImageInputs = None,
                                  ):
        """Compute vision embeddings for the given images, or ``None`` if there are none.

        ``image_input`` is accepted for interface compatibility but ignored; the
        input is always re-parsed from ``pixel_values`` / ``image_embeds``.
        """
        parsed_input = self._parse_and_validate_image_input(pixel_values, image_token_id, image_embeds)
        if parsed_input is None:
            return None
        return self._process_image_input(parsed_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in image embeddings if any are given.

        The image embeddings are placed at the ``self.img_context_token_id``
        positions by ``merge_multimodal_embeddings``. With
        ``multimodal_embeddings=None`` only the token embeddings are returned.
        """
        inputs_embeds = self.language_model.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            assert self.img_context_token_id is not None
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.img_context_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor = None,
        position_ids: torch.LongTensor = None,
        past_key_values: List[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[List[torch.Tensor]] = None,
        image_token_id: Optional[List[torch.Tensor]] = None,
        image_embeds: Optional[List[torch.Tensor]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language backbone on text plus optional image inputs.

        When ``inputs_embeds`` is not given, it is built from ``input_ids`` with
        the image embeddings merged in, its first two dims are swapped, and
        ``input_ids`` is then dropped.

        Returns:
            The output of ``self.language_model.model(...)``, returned as-is.
        """
        if inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(pixel_values, image_token_id, image_embeds)
            # NOTE(review): presumably (batch, seq, hidden) -> (seq, batch, hidden)
            # as the Yuan backbone expects; confirm against modeling_yuanlm2.
            inputs_embeds = self.get_input_embeddings(input_ids, vision_embeddings).permute(1, 0, 2)
            # The embeddings already encode the text, so only they are passed on.
            input_ids = None
        # NOTE(review): arguments are passed positionally (including ``labels``),
        # relying on the exact parameter order of the backbone's forward.
        hidden_states = self.language_model.model(input_ids, attention_mask, position_ids, past_key_values,
                                                  inputs_embeds, labels, use_cache, output_attentions,
                                                  output_hidden_states, return_dict)
        return hidden_states

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an ``(N, W, H, C)`` tensor.

        Returns a tensor of shape ``(N, W * scale, H * scale, C / scale**2)``
        (with ``ps_version == 'v1'`` the two spatial dims stay transposed and a
        warning is emitted).
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode image tiles into visual tokens in the LM hidden size.

        Args:
            pixel_values: Tiles stacked along dim 0, ``(N, C, H, W)``; cast to
                bfloat16 before encoding.

        Returns:
            Tensor of shape ``(N, tokens_per_tile, llm hidden_size)``.
        """
        # NOTE(review): the dtype is hard-coded; a vision model loaded in a dtype
        # other than bfloat16 would presumably fail on the mismatch.
        pixel_values = pixel_values.to(torch.bfloat16)
        vit_embeds = self.vision_model(pixel_values=pixel_values)[0]
        # Drop the first token (presumably the class token), keep patch tokens.
        vit_embeds = vit_embeds[:, 1:, :]
        pn, phw, pc = vit_embeds.shape
        ph = pw = int(phw ** 0.5)
        # (N, h*w, C) -> (N, C, h, w) so PixelUnshuffle can fold neighbours into C.
        vit_embeds = vit_embeds.view(pn, ph, pw, pc).permute(0, 3, 1, 2)
        vit_embeds = self.pixel_unshuffle(vit_embeds)
        pn, pc, ph, pw = vit_embeds.shape
        # Back to a token sequence: (N, h'*w', C').
        vit_embeds = vit_embeds.view(pn, pc, ph * pw).permute(0, 2, 1)
        num_images, tokens_per_image, _ = vit_embeds.shape
        # Flatten every tile's tokens into (N * h'*w', 1, C') for the projector.
        vit_embeds = vit_embeds.reshape(1, -1, vit_embeds.shape[-1]).permute(1, 0, 2)
        vit_embeds = self.imagemlp(vit_embeds)
        vit_embeds = self.imagemlp_layernorm(vit_embeds)
        return vit_embeds.view(num_images, tokens_per_image, -1)

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.LongTensor:
        """Generate from a prompt, optionally conditioned on images.

        Args:
            pixel_values: Image tiles. When ``None``, generation is text-only.
            input_ids: Prompt token ids, including image-context placeholders
                when images are given.
            attention_mask: Forwarded to ``language_model.generate``.
            visual_features: Precomputed vision embeddings used instead of
                encoding ``pixel_values``; only consulted when ``pixel_values``
                is given.
            generation_config: Forwarded to ``language_model.generate``.
            position_ids: Forwarded to ``language_model.generate``.
            output_hidden_states: Forwarded to ``language_model.generate``.

        Returns:
            The result of ``self.language_model.generate``, always called with
            ``max_length=8192`` and ``use_cache=True``.
        """
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.get_multimodal_embeddings(pixel_values)
        else:
            # Text-only prompt: embed the tokens without any image features.
            vit_embeds = None
        inputs_embeds = self.get_input_embeddings(input_ids, vit_embeds)

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            position_ids=position_ids,
            max_length=8192,
            use_cache=True,
        )
        return outputs