import os
import torch
import warnings
from .model_minimind import *
from typing import Optional, Tuple, List, Union
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2ImageProcessor, Siglip2VisionModel
from transformers.modeling_outputs import MoeCausalLMOutputWithPast

warnings.filterwarnings('ignore')


class VLMConfig(MiniMindConfig):
    model_type = "minimind-v"

    def __init__(self, image_special_token='<|image_pad|>', image_ids=(12,), **kwargs):
        # `image_ids` marks the placeholder positions that vision embeddings
        # replace; a tuple default avoids the shared-mutable-default pitfall.
        self.image_special_token = image_special_token
        self.image_ids = list(image_ids)
        self.image_hidden_size = kwargs.get("image_hidden_size", 768)
        self.image_token_len = kwargs.get("image_token_len", 64)
        super().__init__(**kwargs)

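# Usage sketch (illustrative only; the real pad-token id depends on the
# tokenizer): a prompt reserves `image_token_len` placeholder tokens per image,
# which `count_vision_proj` later overwrites with projected vision features.
#   cfg = VLMConfig(image_ids=[12], image_token_len=64)
#   placeholder = cfg.image_special_token * cfg.image_token_len  # '<|image_pad|>' x 64
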
class MMVisionProjector(nn.Module):
    """Compresses the vision token sequence by merging adjacent tokens, then
    projects the result into the language model's hidden space."""

    def __init__(self, in_dim, out_dim, source_tokens=256, target_tokens=64):
        super().__init__()
        self.target_tokens = target_tokens
        self.merge = source_tokens // target_tokens  # tokens fused per output slot
        self.mlp = nn.Sequential(
            nn.Linear(in_dim * self.merge, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        # x: (batch, source_tokens, in_dim) -> (batch, target_tokens, out_dim)
        b, n, d = x.shape
        x = x.reshape(b, self.target_tokens, d * self.merge)
        return self.mlp(x)

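# Shape sketch for the projector (assumes SigLIP2-base geometry: 256 patch
# tokens of width 768; `out_dim` below is an arbitrary example value):
#   proj = MMVisionProjector(in_dim=768, out_dim=512, source_tokens=256, target_tokens=64)
#   feats = torch.randn(2, 256, 768)   # (batch, patches, vision width)
#   out = proj(feats)                  # (2, 64, 512): 4 adjacent patches merged per slot
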
class MiniMindVLM(MiniMindForCausalLM):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig = None, vision_model_path="./model/siglip2-base-p16-ve"):
        self.config = config or VLMConfig()
        super().__init__(self.config)
        # Frozen vision tower plus a trainable projector into the LM hidden space.
        self.vision_encoder, self.processor = self.__class__.get_vision_model(vision_model_path)
        self.vision_proj = MMVisionProjector(
            self.config.image_hidden_size, self.config.hidden_size,
            target_tokens=self.config.image_token_len,
        )

    @staticmethod
    def get_vision_model(model_path: str):
        from transformers import logging as hf_logging
        hf_logging.set_verbosity_error()
        if not os.path.exists(model_path):
            return None, None  # no local checkpoint: run text-only
        model = Siglip2VisionModel.from_pretrained(model_path)
        processor = Siglip2ImageProcessor.from_pretrained(model_path)
        # Freeze the tower; only the projector (and the LM) receive gradients.
        for param in model.parameters():
            param.requires_grad = False
        return model.eval(), processor

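    # Sanity-check sketch (illustrative): the returned tower should be fully
    # frozen, so only `vision_proj` and the LM appear among trainable params.
    #   tower, proc = MiniMindVLM.get_vision_model("./model/siglip2-base-p16-ve")
    #   assert tower is None or all(not p.requires_grad for p in tower.parameters())
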
    @staticmethod
    def image2tensor(image, processor):
        if image.mode != 'RGB':  # covers RGBA/LA as well as P/L inputs
            image = image.convert('RGB')
        return processor(images=image, return_tensors="pt")

    @staticmethod
    def get_image_embeddings(image_inputs, vision_model):
        with torch.no_grad():
            if hasattr(image_inputs, 'keys'):
                # Processor dict: drop a stray singleton num-images dim the
                # collator may have added to each field.
                image_inputs = {k: v.squeeze(1) if v.ndim > 2 and v.shape[1] == 1 else v
                                for k, v in image_inputs.items()}
                outputs = vision_model(**image_inputs)
            else:
                # Raw tensor: pass by keyword (a bare tensor cannot be **-unpacked).
                outputs = vision_model(pixel_values=image_inputs)
        return outputs.last_hidden_state

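    # Pipeline sketch (assumes `from PIL import Image` and a loaded tower; the
    # (1, 256, 768) shape is the SigLIP2-base default assumed by this module):
    #   inputs = MiniMindVLM.image2tensor(Image.open("cat.jpg"), processor)
    #   feats = MiniMindVLM.get_image_embeddings(inputs, vision_encoder)
    #   # feats: (1, 256, 768) -> vision_proj -> (1, 64, hidden_size)
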
    @torch.compiler.disable
    def count_vision_proj(self, tokens, h, vision_tensors=None, seqlen=512):
        """Splice projected image embeddings over runs of the image-marker token.

        Each marker run is expected to span exactly `image_token_len` positions,
        so the sequence length (and later token indices) stay aligned."""
        if vision_tensors is None or not self.config.image_ids:
            return h
        marker, vf = self.config.image_ids[0], vision_tensors
        if vf.dim() == 3:
            vf = vf.unsqueeze(1)  # -> (batch, num_images, image_token_len, hidden)
        out = []
        for b in range(h.size(0)):
            hb, seq, k, i = h[b], tokens[b].tolist(), 0, 0
            while i < len(seq):
                if seq[i] == marker:
                    # Consume one contiguous run of marker tokens.
                    start = i
                    while i < len(seq) and seq[i] == marker:
                        i += 1
                    if k < vf.size(1):
                        # Overwrite the run with the k-th image's embeddings
                        # (truncated to the run length), then clip to seqlen.
                        hb = torch.cat((hb[:start], vf[b][k][:i - start], hb[i:]), dim=0)[:seqlen]
                        k += 1
                else:
                    i += 1
            out.append(hb)
        return torch.stack(out)

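    # Splice sketch (hypothetical tiny shapes, marker id 12): the four marker
    # positions are overwritten in place, so the text embeddings around them
    # keep their original offsets:
    #   ids = torch.tensor([[5, 12, 12, 12, 12, 7]])
    #   h   = torch.randn(1, 6, 8)
    #   vt  = torch.randn(1, 4, 8)   # (batch, image_token_len, hidden)
    #   h2  = model.count_vision_proj(ids, h, vt, seqlen=6)  # h2.shape == h.shape
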
    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
                use_cache: bool = False,
                logits_to_keep: Union[int, torch.Tensor] = 0,
                labels: Optional[torch.Tensor] = None,
                pixel_values: Optional[torch.FloatTensor] = None,
                **args):
        batch_size, seq_length = input_ids.shape
        if hasattr(past_key_values, 'layers'):  # HF Cache object; this model manages plain tuples
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.model.layers)
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0

        hidden_states = self.model.dropout(self.model.embed_tokens(input_ids))

        # Vision features are injected only on the prefill pass (start_pos == 0);
        # later decode steps reuse them through the KV cache.
        if pixel_values is not None and start_pos == 0:
            if hasattr(pixel_values, 'keys'):
                img_emb = MiniMindVLM.get_image_embeddings(pixel_values, self.vision_encoder)
                vision_tensors = self.vision_proj(img_emb)
            else:
                if len(pixel_values.shape) == 6:
                    pixel_values = pixel_values.squeeze(2)
                bs, num, c, im_h, im_w = pixel_values.shape
                # Always stack per-image projections on dim 1 so the result is
                # (batch, num_images, image_token_len, hidden), the layout
                # count_vision_proj indexes as vf[b][k]. (Stacking on dim 0 when
                # batch size is 1 would drop every image after the first.)
                vision_tensors = torch.stack(
                    [self.vision_proj(MiniMindVLM.get_image_embeddings(pixel_values[:, i], self.vision_encoder))
                     for i in range(num)],
                    dim=1,
                )
            hidden_states = self.count_vision_proj(tokens=input_ids, h=hidden_states,
                                                   vision_tensors=vision_tensors, seqlen=input_ids.shape[1])

        # Rotary position embeddings for the current window, offset by the cache length.
        position_embeddings = (
            self.model.freqs_cos[start_pos:start_pos + seq_length],
            self.model.freqs_sin[start_pos:start_pos + seq_length]
        )

|
| presents = [] |
| for layer_idx, (layer, past_key_value) in enumerate(zip(self.model.layers, past_key_values)): |
| hidden_states, present = layer( |
| hidden_states, |
| position_embeddings, |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| attention_mask=attention_mask |
| ) |
| presents.append(present) |
|
|
| hidden_states = self.model.norm(hidden_states) |
|
|
| aux_loss = sum([l.mlp.aux_loss for l in self.model.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze()) |
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
| logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
| loss = None |
| if labels is not None: |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:].contiguous() |
| loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100) |
|
|
| output = MoeCausalLMOutputWithPast(loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=presents, hidden_states=hidden_states) |
| return output |
|
|
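
# End-to-end sketch (illustrative; the tokenizer, checkpoint paths, and prompt
# construction are assumptions, not part of this module):
#   cfg = VLMConfig()
#   model = MiniMindVLM(cfg, vision_model_path="./model/siglip2-base-p16-ve")
#   pixels = MiniMindVLM.image2tensor(Image.open("demo.jpg"), model.processor)
#   prompt = "Describe this image: " + cfg.image_special_token * cfg.image_token_len
#   ids = tokenizer(prompt, return_tensors="pt").input_ids
#   out = model(input_ids=ids, pixel_values=pixels, use_cache=False)
#   next_token = out.logits[:, -1].argmax(-1)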