import os
import torch
import warnings
from .model_minimind import *
from typing import Optional, Tuple, List, Union
from torch import nn
from transformers import CLIPProcessor, CLIPModel
warnings.filterwarnings('ignore')


class VLMConfig(MiniMindConfig):
    model_type = "minimind-v"

    def __init__(
            self,
            image_special_token: str = '@' * 196,
            image_ids: List = [34] * 196,
            **kwargs,
    ):
        self.image_special_token = image_special_token
        self.image_ids = image_ids
        super().__init__(**kwargs)
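

# A minimal sketch (not part of the original file): the defaults above imply that the
# tokenizer encodes the 196-character placeholder string into 196 copies of token id 34,
# reserving exactly one sequence position per CLIP ViT-B/16 patch (14 x 14 = 196).
def _demo_vlm_config():
    cfg = VLMConfig()
    assert len(cfg.image_special_token) == len(cfg.image_ids) == 196
    return cfg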


class VisionProj(nn.Module):
    def __init__(self, ve_hidden_size=768, hidden_size=512):
        super().__init__()
        self.ve_hidden_size = ve_hidden_size
        self.hidden_size = hidden_size
        self.vision_proj = nn.Sequential(
            nn.Linear(self.ve_hidden_size, self.hidden_size)
        )

    def forward(self, image_encoders):
        vision_proj = self.vision_proj(image_encoders)
        return vision_proj
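

# A minimal shape sketch (not part of the original file): VisionProj maps CLIP's 196
# patch embeddings (dim 768 for ViT-B/16) into the language model's hidden size.
def _demo_vision_proj():
    proj = VisionProj(ve_hidden_size=768, hidden_size=512)
    patches = torch.randn(2, 196, 768)  # batch of 2 images, 196 patch tokens each
    return proj(patches).shape          # -> torch.Size([2, 196, 512])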


# Inherits from the language model
class MiniMindVLM(MiniMindForCausalLM):
    config_class = VLMConfig

    def __init__(self, params: VLMConfig = None, vision_model_path="./model/vision_model/clip-vit-base-patch16"):
        # Fall back to a default config before initializing the language-model backbone
        if not params: params = VLMConfig()
        super().__init__(params)
        self.params = params
        self.vision_encoder, self.processor = self.__class__.get_vision_model(vision_model_path)
        self.vision_proj = VisionProj(hidden_size=params.hidden_size)

    @staticmethod
    def get_vision_model(model_path: str):
        from transformers import logging as hf_logging
        hf_logging.set_verbosity_error()
        if not os.path.exists(model_path):
            return None, None
        model = CLIPModel.from_pretrained(model_path)
        processor = CLIPProcessor.from_pretrained(model_path)
        # Freeze all parameters of the vision encoder
        for param in model.parameters():
            param.requires_grad = False
        return model.eval(), processor

    @staticmethod
    def image2tensor(image, processor):
        if image.mode in ['RGBA', 'LA']: image = image.convert('RGB')
        inputs = processor(images=image, return_tensors="pt")['pixel_values']
        return inputs

    @staticmethod
    def get_image_embeddings(image_tensors, vision_model):
        with torch.no_grad():
            outputs = vision_model.vision_model(pixel_values=image_tensors)
        # Drop the CLS token at position 0 and keep only the patch embeddings
        img_embedding = outputs.last_hidden_state[:, 1:, :].squeeze()
        return img_embedding
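
    # Shape sketch (not part of the original file): for a 224x224 image with ViT-B/16,
    # image2tensor yields pixel_values of shape (1, 3, 224, 224) and get_image_embeddings
    # returns the 196 patch embeddings of dim 768; the squeeze drops the batch dim of 1,
    # giving (196, 768) for a single image.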

    def count_vision_proj(self, tokens, h, vision_tensors=None, seqlen=512):
        # Locate every occurrence of the image placeholder id sequence in each sample
        def find_indices(tokens, image_ids):
            image_ids_tensor = torch.tensor(image_ids).to(tokens.device)
            len_image_ids = len(image_ids)
            if len_image_ids > tokens.size(1):
                return None
            tokens_view = tokens.unfold(1, len_image_ids, 1)
            matches = (tokens_view == image_ids_tensor).all(dim=2)
            return {
                batch_idx: [(idx.item(), idx.item() + len_image_ids - 1) for idx in
                            matches[batch_idx].nonzero(as_tuple=True)[0]]
                for batch_idx in range(tokens.size(0)) if matches[batch_idx].any()
            } or None

        image_indices = find_indices(tokens, self.params.image_ids)
        if vision_tensors is not None and image_indices:
            vision_proj = self.vision_proj(vision_tensors)
            if len(vision_proj.shape) == 3:
                vision_proj = vision_proj.unsqueeze(0)
            new_h = []
            for i in range(h.size(0)):
                if i in image_indices:
                    h_i = h[i]
                    img_idx = 0
                    # Splice the projected vision embeddings over each placeholder span
                    for start_idx, end_idx in image_indices[i]:
                        if img_idx < vision_proj.size(1):
                            h_i = torch.cat((h_i[:start_idx], vision_proj[i][img_idx], h_i[end_idx + 1:]), dim=0)[:seqlen]
                            img_idx += 1
                    new_h.append(h_i)
                else:
                    new_h.append(h[i])
            return torch.stack(new_h, dim=0)
        return h
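
    # Sketch of the replacement step (not part of the original file): given input_ids that
    # contain one 196-id placeholder run per image, e.g.
    #     input_ids = torch.tensor([prefix_ids + cfg.image_ids + suffix_ids])
    # count_vision_proj splices the projected (196, hidden_size) vision embeddings into the
    # token embeddings at exactly those positions, leaving the text positions untouched.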

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                pixel_values: Optional[torch.FloatTensor] = None,
                **args):
        batch_size, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'): past_key_values = None
        past_key_values = past_key_values or [None] * len(self.model.layers)
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        hidden_states = self.model.dropout(self.model.embed_tokens(input_ids))

        # Only inject vision features on the first (prefill) step, before any KV cache exists
        if pixel_values is not None and start_pos == 0:
            if len(pixel_values.shape) == 6:
                pixel_values = pixel_values.squeeze(2)
            bs, num, c, im_h, im_w = pixel_values.shape
            stack_dim = 1 if bs > 1 else 0
            vision_tensors = torch.stack([
                MiniMindVLM.get_image_embeddings(pixel_values[:, i, :, :, :], self.vision_encoder)
                for i in range(num)
            ], dim=stack_dim)
            hidden_states = self.count_vision_proj(tokens=input_ids, h=hidden_states, vision_tensors=vision_tensors,
                                                   seqlen=input_ids.shape[1])

        position_embeddings = (
            self.model.freqs_cos[start_pos:start_pos + seq_length],
            self.model.freqs_sin[start_pos:start_pos + seq_length]
        )

        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.model.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.model.norm(hidden_states)

        aux_loss = sum(
            layer.mlp.aux_loss
            for layer in self.model.layers
            if isinstance(layer.mlp, MOEFeedForward)
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        output = CausalLMOutputWithPast(logits=logits, past_key_values=presents, hidden_states=hidden_states)
        output.aux_loss = aux_loss
        return output
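

# A minimal end-to-end sketch (not part of the original file). It assumes the CLIP weights
# exist under ./model/vision_model/clip-vit-base-patch16 and that the prompt already contains
# the 196 image placeholder ids; the image path below is purely illustrative.
if __name__ == "__main__":
    from PIL import Image

    cfg = VLMConfig()
    model = MiniMindVLM(cfg).eval()

    # Reserve one placeholder position per CLIP patch for a single image.
    input_ids = torch.tensor([cfg.image_ids], dtype=torch.long)  # (1, 196)

    # forward() expects pixel_values of shape (batch, num_images, 3, H, W).
    image = Image.open("./dataset/eval_images/demo.jpg")  # hypothetical path
    pixel_values = MiniMindVLM.image2tensor(image, model.processor).unsqueeze(0)  # (1, 1, 3, 224, 224)

    with torch.no_grad():
        out = model(input_ids=input_ids, pixel_values=pixel_values)
    print(out.logits.shape)  # (1, 196, vocab_size) with the default logits_to_keep=0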