import os
import torch
import warnings
from .model_minimind import *
from typing import Optional, Tuple, List, Union
from torch import nn
from transformers import Siglip2ImageProcessor, Siglip2VisionModel
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

warnings.filterwarnings('ignore')


class VLMConfig(MiniMindConfig):
    model_type = "minimind-v"

    def __init__(self, image_special_token='<|image_pad|>', image_ids=[12], **kwargs):
        self.image_special_token = image_special_token
        self.image_ids = image_ids
        self.image_hidden_size = kwargs.get("image_hidden_size", 768)
        self.image_token_len = kwargs.get("image_token_len", 64)
        super().__init__(**kwargs)


class MMVisionProjector(nn.Module):
    def __init__(self, in_dim, out_dim, source_tokens=256, target_tokens=64):
        super().__init__()
        self.target_tokens = target_tokens
        self.merge = source_tokens // target_tokens
        self.mlp = nn.Sequential(
            nn.Linear(in_dim * self.merge, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        b, n, d = x.shape
        x = x.reshape(b, self.target_tokens, d * self.merge)
        return self.mlp(x)
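
# Illustrative shape trace, assuming the defaults above (in_dim=768 SigLIP2 features,
# source_tokens=256, target_tokens=64, hence merge=4): an input of shape (B, 256, 768)
# is reshaped to (B, 64, 3072) by concatenating each group of 4 neighbouring patch
# tokens, and the MLP then maps it to (B, 64, out_dim) so every merged token matches
# the language model's hidden size.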


# Inherits from the language model
class MiniMindVLM(MiniMindForCausalLM):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig = None, vision_model_path="./model/siglip2-base-p16-ve"):
        self.config = config or VLMConfig()
        super().__init__(self.config)
        self.vision_encoder, self.processor = self.__class__.get_vision_model(vision_model_path)
        self.vision_proj = MMVisionProjector(
            self.config.image_hidden_size,
            self.config.hidden_size,
            target_tokens=self.config.image_token_len,
        )

    @staticmethod
    def get_vision_model(model_path: str):
        from transformers import logging as hf_logging
        hf_logging.set_verbosity_error()
        if not os.path.exists(model_path):
            return None, None
        model = Siglip2VisionModel.from_pretrained(model_path)
        processor = Siglip2ImageProcessor.from_pretrained(model_path)
        # Freeze all parameters of the vision encoder
        for param in model.parameters():
            param.requires_grad = False
        return model.eval(), processor

    @staticmethod
    def image2tensor(image, processor):
        if image.mode in ['RGBA', 'LA']:
            image = image.convert('RGB')
        inputs = processor(images=image, return_tensors="pt")
        return inputs

    @staticmethod
    def get_image_embeddings(image_inputs, vision_model):
        # Collapse a stray singleton dim (e.g. pixel_values of shape (B, 1, N, D)) before encoding
        if hasattr(image_inputs, 'keys'):
            image_inputs = {k: v.squeeze(1) if v.ndim > 2 and v.shape[1] == 1 else v for k, v in image_inputs.items()}
        with torch.no_grad():
            outputs = vision_model(**image_inputs)
        return outputs.last_hidden_state

    @torch.compiler.disable
    def count_vision_proj(self, tokens, h, vision_tensors=None, seqlen=512):
        # Splice projected image features into the text hidden states wherever the
        # prompt contains runs of the image placeholder token id.
        if vision_tensors is None or not self.config.image_ids:
            return h
        marker, vf = self.config.image_ids[0], vision_tensors
        if vf.dim() == 3:
            # (batch, tokens, dim) -> (batch, 1 image, tokens, dim)
            vf = vf.unsqueeze(1)
        out = []
        for b in range(h.size(0)):
            hb, seq, k, i = h[b], tokens[b].tolist(), 0, 0
            while i < len(seq):
                if seq[i] == marker:
                    start = i
                    # Walk to the end of this run of placeholder tokens
                    while i < len(seq) and seq[i] == marker:
                        i += 1
                    if k < vf.size(1):
                        # Replace the run with the k-th image's features, truncated to seqlen
                        hb = torch.cat((hb[:start], vf[b][k][:i - start], hb[i:]), dim=0)[:seqlen]
                        k += 1
                else:
                    i += 1
            out.append(hb)
        return torch.stack(out)
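
    # Worked example (illustrative, assuming image_ids=[12] and, for brevity, four
    # placeholder tokens per image): for prompt ids [5, 12, 12, 12, 12, 9], the run of
    # four 12s starting at index 1 is cut out and the first image's four projected
    # feature vectors are concatenated in its place, so text embeddings and image
    # features end up in a single (seq_len, hidden_size) stream per sample.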

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                labels: Optional[torch.Tensor] = None,
                pixel_values: Optional[torch.FloatTensor] = None,
                **args):
        batch_size, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'):
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.model.layers)
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        hidden_states = self.model.dropout(self.model.embed_tokens(input_ids))

        # Only inject image features on the prefill pass (no cached positions yet)
        if pixel_values is not None and start_pos == 0:
            if hasattr(pixel_values, 'keys'):
                img_emb = MiniMindVLM.get_image_embeddings(pixel_values, self.vision_encoder)
                vision_tensors = self.vision_proj(img_emb)
            else:
                if len(pixel_values.shape) == 6:
                    pixel_values = pixel_values.squeeze(2)
                bs, num, c, im_h, im_w = pixel_values.shape
                stack_dim = 1 if bs > 1 else 0
                # Multiple images per sample: encode and project each, then stack
                vision_tensors = torch.stack(
                    [self.vision_proj(MiniMindVLM.get_image_embeddings(pixel_values[:, i, :, :, :], self.vision_encoder))
                     for i in range(num)],
                    dim=stack_dim
                )
            hidden_states = self.count_vision_proj(tokens=input_ids, h=hidden_states, vision_tensors=vision_tensors, seqlen=input_ids.shape[1])

        position_embeddings = (
            self.model.freqs_cos[start_pos:start_pos + seq_length],
            self.model.freqs_sin[start_pos:start_pos + seq_length]
        )
        presents = []
        for layer_idx, (layer, past_key_value) in enumerate(zip(self.model.layers, past_key_values)):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask
            )
            presents.append(present)

        hidden_states = self.model.norm(hidden_states)
        aux_loss = sum(
            [l.mlp.aux_loss for l in self.model.layers if isinstance(l.mlp, MOEFeedForward)],
            hidden_states.new_zeros(1).squeeze()
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
        output = MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=presents, hidden_states=hidden_states)
        return output
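

# Minimal smoke-test sketch (illustrative assumptions: the SigLIP2 weights exist at
# ./model/siglip2-base-p16-ve, VLMConfig's defaults are usable, and "demo.jpg" is any
# local RGB image; the language weights are randomly initialised here, so only output
# shapes are meaningful).
if __name__ == "__main__":
    from PIL import Image

    config = VLMConfig()
    model = MiniMindVLM(config).eval()

    image = Image.open("demo.jpg")  # placeholder path
    pixel_values = MiniMindVLM.image2tensor(image, model.processor)

    # A toy prompt: image_token_len placeholder ids followed by a few ordinary ids;
    # real prompts would come from the project's tokenizer and chat template.
    input_ids = torch.tensor([[config.image_ids[0]] * config.image_token_len + [1, 2, 3]])

    with torch.no_grad():
        out = model(input_ids=input_ids, pixel_values=pixel_values, logits_to_keep=1)
    print(out.logits.shape)  # expected: (1, 1, vocab_size)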