# minimind-3v / model_vlm.py
# bilzepython: "Duplicate from jingyaogong/minimind-3v"
# Revision: 21b4e01
import os
import torch
import warnings
from .model_minimind import *
from typing import Optional, Tuple, List, Union
from torch import nn
from transformers import Siglip2ImageProcessor, Siglip2VisionModel
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
# Install an "ignore" filter for every warning category, process-wide, at import time.
warnings.simplefilter('ignore')
class VLMConfig(MiniMindConfig):
    """Configuration for the MiniMind vision-language model.

    Extends ``MiniMindConfig`` with image-placeholder settings.

    Args:
        image_special_token: Placeholder token string for image positions.
        image_ids: Token id(s) of the image placeholder. Defaults to a fresh
            ``[12]`` per instance when not supplied.
        **kwargs: Forwarded to ``MiniMindConfig``. ``image_hidden_size``
            (default 768) and ``image_token_len`` (default 64) are also read
            from here and stored on the config.
    """
    model_type = "minimind-v"
    # Private marker meaning "image_ids was not supplied". A literal ``[12]``
    # default would be a single list shared (and mutable) across every config
    # created without ``image_ids``.
    _IMAGE_IDS_UNSET = object()

    def __init__(self, image_special_token='<|image_pad|>', image_ids=_IMAGE_IDS_UNSET, **kwargs):
        self.image_special_token = image_special_token
        self.image_ids = [12] if image_ids is VLMConfig._IMAGE_IDS_UNSET else image_ids
        self.image_hidden_size = kwargs.get("image_hidden_size", 768)
        self.image_token_len = kwargs.get("image_token_len", 64)
        super().__init__(**kwargs)
class MMVisionProjector(nn.Module):
    """Project vision-encoder patch features into the language model's hidden space.

    Groups of ``merge = source_tokens // target_tokens`` consecutive tokens are
    concatenated along the feature axis, reducing the token count to
    ``target_tokens``. The result then passes through a two-layer GELU MLP.

    Args:
        in_dim: Feature size of each incoming token.
        out_dim: Output feature size.
        source_tokens: Number of incoming tokens per image.
        target_tokens: Number of output tokens per image.
    """

    def __init__(self, in_dim, out_dim, source_tokens=256, target_tokens=64):
        super().__init__()
        self.target_tokens = target_tokens
        self.merge = source_tokens // target_tokens
        # The attribute name and Sequential layout (mlp.0 / mlp.2) define the
        # checkpoint keys; keep them stable.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim * self.merge, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, n, in_dim) to (batch, target_tokens, out_dim).

        Raises:
            ValueError: if ``n != target_tokens * merge``. Without this check,
                the reshape below fails with an opaque size error.
        """
        b, n, d = x.shape
        expected = self.target_tokens * self.merge
        if n != expected:
            raise ValueError(
                f"MMVisionProjector expected {expected} tokens per image "
                f"({self.target_tokens} x merge {self.merge}), got {n}"
            )
        x = x.reshape(b, self.target_tokens, d * self.merge)
        return self.mlp(x)
# Inherits from the language model (MiniMindForCausalLM).
class MiniMindVLM(MiniMindForCausalLM):
    """MiniMind causal LM extended with a frozen SigLIP2 vision encoder.

    Image features come from ``vision_encoder``. ``vision_proj`` maps them into
    the language model's hidden space. They are then spliced into the token
    embeddings wherever the image placeholder token (``config.image_ids[0]``)
    appears in ``input_ids``.
    """
    config_class = VLMConfig

    def __init__(self, config: VLMConfig = None, vision_model_path="./model/siglip2-base-p16-ve"):
        """Build the language model, the vision encoder and the projector.

        Args:
            config: Model configuration; ``VLMConfig()`` is used when omitted.
            vision_model_path: Local directory with the SigLIP2 weights and
                image processor. If it does not exist, ``vision_encoder`` and
                ``processor`` are ``None`` and ``pixel_values`` cannot be used.
        """
        self.config = config or VLMConfig()
        super().__init__(self.config)
        self.vision_encoder, self.processor = self.__class__.get_vision_model(vision_model_path)
        self.vision_proj = MMVisionProjector(
            self.config.image_hidden_size,
            self.config.hidden_size,
            target_tokens=self.config.image_token_len,
        )

    @staticmethod
    def get_vision_model(model_path: str):
        """Load a frozen SigLIP2 vision encoder and its image processor.

        Returns:
            ``(model, processor)``, where ``model`` is in eval mode with every
            parameter set to ``requires_grad=False``. Returns ``(None, None)``
            when ``model_path`` does not exist.
        """
        from transformers import logging as hf_logging
        # Lowers transformers' log verbosity globally, not just for this load.
        hf_logging.set_verbosity_error()
        if not os.path.exists(model_path):
            return None, None
        model = Siglip2VisionModel.from_pretrained(model_path)
        processor = Siglip2ImageProcessor.from_pretrained(model_path)
        # Freeze all vision-encoder parameters.
        for param in model.parameters():
            param.requires_grad = False
        return model.eval(), processor

    @staticmethod
    def image2tensor(image, processor):
        """Preprocess a PIL image with ``processor`` into PyTorch tensors.

        Any non-RGB image (RGBA, LA, L, P, CMYK, ...) is converted to RGB first,
        so the processor always receives three colour channels.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return processor(images=image, return_tensors="pt")

    @staticmethod
    def get_image_embeddings(image_inputs, vision_model):
        """Run ``vision_model`` without gradients and return its last hidden state.

        Args:
            image_inputs: Either a mapping of processor outputs, passed as
                keyword arguments, or a bare tensor, passed as ``pixel_values``.
                Mapping entries with ndim > 2 and a size-1 axis 1 are squeezed
                on that axis (presumably an extra axis introduced when
                per-sample processor outputs are collated; TODO confirm).
            vision_model: The vision encoder module.

        Raises:
            RuntimeError: if ``vision_model`` is ``None`` (no encoder loaded).
        """
        if vision_model is None:
            raise RuntimeError("vision encoder is not loaded; cannot embed images")
        if hasattr(image_inputs, 'keys'):
            image_inputs = {k: v.squeeze(1) if v.ndim > 2 and v.shape[1] == 1 else v
                            for k, v in image_inputs.items()}
            with torch.no_grad():
                outputs = vision_model(**image_inputs)
        else:
            # A tensor cannot be **-unpacked; hand it over as pixel_values.
            with torch.no_grad():
                outputs = vision_model(pixel_values=image_inputs)
        return outputs.last_hidden_state

    @torch.compiler.disable
    def count_vision_proj(self, tokens, h, vision_tensors=None, seqlen=512):
        """Splice projected image features into the token embeddings.

        Runs of the placeholder token are handled as follows:

        * Each maximal run of ``config.image_ids[0]`` in ``tokens[b]`` is
          replaced, in order, by the next image of sample ``b``.
        * A run shorter than an image's token count takes only the leading
          features.
        * A run longer than that keeps the surplus placeholder embeddings, so
          the sequence length never changes.
        * Runs beyond the number of available images are left untouched.

        Args:
            tokens: Token ids, indexed per sample as ``tokens[b]``.
            h: Token embeddings, indexed per sample as ``h[b]``.
            vision_tensors: Projected image features, either
                (batch, image_tokens, hidden) for one image per sample or
                (batch, num_images, image_tokens, hidden).
            seqlen: Length each spliced sequence is truncated to.

        Returns:
            ``h`` itself when there are no features or no image id is
            configured; otherwise a new tensor stacked over the batch.
        """
        if vision_tensors is None or not self.config.image_ids:
            return h
        marker = self.config.image_ids[0]
        feats = vision_tensors.unsqueeze(1) if vision_tensors.dim() == 3 else vision_tensors
        num_images = feats.size(1)
        spliced = []
        for b in range(h.size(0)):
            hb = h[b]
            seq = tokens[b].tolist()
            img_idx, i = 0, 0
            while i < len(seq):
                if seq[i] != marker:
                    i += 1
                    continue
                start = i
                while i < len(seq) and seq[i] == marker:
                    i += 1
                if img_idx < num_images:
                    seg = feats[b][img_idx][:i - start]
                    # Overwrite exactly len(seg) positions so the sequence length is preserved.
                    hb = torch.cat((hb[:start], seg, hb[start + seg.size(0):]), dim=0)[:seqlen]
                    img_idx += 1
            spliced.append(hb)
        return torch.stack(spliced)

    def _encode_images(self, pixel_values):
        """Encode ``pixel_values`` and project the features to the LM hidden size.

        Returns (batch, image_tokens, hidden) for a processor mapping, or
        (batch, num_images, image_tokens, hidden) for a 5-D/6-D tensor.
        """
        if hasattr(pixel_values, 'keys'):
            img_emb = MiniMindVLM.get_image_embeddings(pixel_values, self.vision_encoder)
            return self.vision_proj(img_emb)
        if len(pixel_values.shape) == 6:
            pixel_values = pixel_values.squeeze(2)
        # Expected layout: (batch, num_images, channels, height, width).
        _, num_images, _, _, _ = pixel_values.shape
        per_image = [
            self.vision_proj(MiniMindVLM.get_image_embeddings(pixel_values[:, i, :, :, :], self.vision_encoder))
            for i in range(num_images)
        ]
        # Always stack on dim 1 so a single-sample batch yields (1, num_images, ...).
        # Stacking on dim 0 there would give (num_images, 1, ...), and
        # count_vision_proj would then use only the first image.
        return torch.stack(per_image, dim=1)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                labels: Optional[torch.Tensor] = None,
                pixel_values: Optional[torch.FloatTensor] = None,
                **args):
        """Run the language model, splicing image features in when there is no cache.

        Args:
            input_ids: Token ids of shape (batch, seq_len). Required despite
                the ``None`` default.
            attention_mask: Forwarded unchanged to every decoder layer.
            past_key_values: Per-layer cached entries from a previous call.
                Objects exposing a ``layers`` attribute (presumably transformers
                ``Cache`` instances) are discarded.
            use_cache: Forwarded to every decoder layer.
            logits_to_keep: An int ``k > 0`` computes logits for the last
                ``k`` positions only, and ``0`` keeps all positions. A tensor
                is used directly as the position index.
            labels: Target ids aligned with ``input_ids``. Positions equal to
                -100 are ignored by the loss.
            pixel_values: A processor mapping or a 5-D/6-D image tensor.
                Used only when ``start_pos == 0``, i.e. no cached prefix.
            **args: Accepted and ignored.

        Returns:
            ``MoeCausalLMOutputWithPast`` with these fields:

            * ``loss``: shifted next-token cross entropy, or ``None`` without
              labels.
            * ``aux_loss``: sum of the MoE layers' auxiliary losses.
            * ``logits``.
            * ``past_key_values``: the list of per-layer ``present`` values.
            * ``hidden_states``: the final normalized hidden states.
        """
        _, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'):
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.model.layers)
        # Cached prefix length, read from dim 1 of the first layer's cached keys.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        hidden_states = self.model.dropout(self.model.embed_tokens(input_ids))
        # Images are spliced only on the uncached step; later steps see them
        # only through the cached keys/values computed from that step.
        if pixel_values is not None and start_pos == 0:
            vision_tensors = self._encode_images(pixel_values)
            hidden_states = self.count_vision_proj(tokens=input_ids, h=hidden_states,
                                                   vision_tensors=vision_tensors, seqlen=seq_length)
        position_embeddings = (
            self.model.freqs_cos[start_pos:start_pos + seq_length],
            self.model.freqs_sin[start_pos:start_pos + seq_length],
        )
        presents = []
        for layer, past_key_value in zip(self.model.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask,
            )
            presents.append(present)
        hidden_states = self.model.norm(hidden_states)
        aux_loss = sum(
            [layer.mlp.aux_loss for layer in self.model.layers if isinstance(layer.mlp, MOEFeedForward)],
            hidden_states.new_zeros(1).squeeze(),
        )
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1),
                                   ignore_index=-100)
        return MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits,
                                         past_key_values=presents, hidden_states=hidden_states)