|
|
""" |
|
|
Multimodal Vision Module for MiniMind Max2 |
|
|
Adapter-based approach using SigLIP/DINOv2 vision encoders. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import List, Optional, Dict, Any, Tuple, Union |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import math |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class VisionConfig: |
|
|
"""Configuration for vision adapter.""" |
|
|
|
|
|
vision_encoder: str = "siglip-so400m" |
|
|
vision_hidden_size: int = 1152 |
|
|
image_size: int = 384 |
|
|
patch_size: int = 14 |
|
|
num_image_tokens: int = 729 |
|
|
|
|
|
|
|
|
projector_type: str = "mlp" |
|
|
projector_hidden_size: int = 2048 |
|
|
projector_num_layers: int = 2 |
|
|
|
|
|
|
|
|
llm_hidden_size: int = 1024 |
|
|
|
|
|
|
|
|
freeze_vision_encoder: bool = True |
|
|
freeze_llm: bool = True |
|
|
train_projector_only: bool = True |
|
|
|
|
|
|
|
|
image_start_token: str = "<image>" |
|
|
image_end_token: str = "</image>" |
|
|
image_pad_token: str = "<image_pad>" |
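
# Example (illustrative sketch): the defaults above describe a SigLIP-so400m-style
# encoder. A 384x384 image split into 14x14 patches gives (384 // 14) ** 2 =
# 27 * 27 = 729 patch tokens, which is why num_image_tokens defaults to 729:
#
#     cfg = VisionConfig()
#     assert (cfg.image_size // cfg.patch_size) ** 2 == cfg.num_image_tokens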
|
|
|
|
|
|
|
|
class MLPProjector(nn.Module): |
|
|
""" |
|
|
Multi-Layer Perceptron projector for vision-language alignment. |
|
|
Maps vision encoder outputs to LLM embedding space. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: VisionConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
layers = [] |
|
|
input_size = config.vision_hidden_size |
|
|
|
|
|
for i in range(config.projector_num_layers): |
|
|
if i == config.projector_num_layers - 1: |
|
|
|
|
|
layers.extend([ |
|
|
nn.Linear(input_size, config.llm_hidden_size), |
|
|
]) |
|
|
else: |
|
|
|
|
|
layers.extend([ |
|
|
nn.Linear(input_size, config.projector_hidden_size), |
|
|
nn.GELU(), |
|
|
nn.LayerNorm(config.projector_hidden_size), |
|
|
]) |
|
|
input_size = config.projector_hidden_size |
|
|
|
|
|
self.projector = nn.Sequential(*layers) |
|
|
|
|
|
def forward(self, vision_features: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Project vision features to LLM space. |
|
|
|
|
|
Args: |
|
|
vision_features: [batch, num_patches, vision_hidden_size] |
|
|
|
|
|
Returns: |
|
|
Projected features: [batch, num_patches, llm_hidden_size] |
|
|
""" |
|
|
return self.projector(vision_features) |
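
# Example (minimal sketch): the MLP projector keeps the token count and only
# changes the feature width, mapping encoder features (vision_hidden_size=1152)
# to the LLM width (llm_hidden_size=1024):
#
#     cfg = VisionConfig()
#     proj = MLPProjector(cfg)
#     feats = torch.randn(2, cfg.num_image_tokens, cfg.vision_hidden_size)
#     proj(feats).shape   # torch.Size([2, 729, 1024])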
|
|
|
|
|
|
|
|
class Resampler(nn.Module): |
|
|
""" |
|
|
Perceiver-style resampler for compressing vision tokens. |
|
|
Reduces number of image tokens while preserving information. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
config: VisionConfig, |
|
|
num_queries: int = 64, |
|
|
num_heads: int = 8, |
|
|
num_layers: int = 2, |
|
|
): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.num_queries = num_queries |
|
|
|
|
|
|
|
|
self.queries = nn.Parameter(torch.randn(1, num_queries, config.llm_hidden_size)) |
|
|
|
|
|
|
|
|
self.input_proj = nn.Linear(config.vision_hidden_size, config.llm_hidden_size) |
|
|
|
|
|
|
|
|
self.layers = nn.ModuleList([ |
|
|
nn.TransformerDecoderLayer( |
|
|
d_model=config.llm_hidden_size, |
|
|
nhead=num_heads, |
|
|
dim_feedforward=config.llm_hidden_size * 4, |
|
|
batch_first=True, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
]) |
|
|
|
|
|
self.norm = nn.LayerNorm(config.llm_hidden_size) |
|
|
|
|
|
def forward(self, vision_features: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Resample vision features using learned queries. |
|
|
|
|
|
Args: |
|
|
vision_features: [batch, num_patches, vision_hidden_size] |
|
|
|
|
|
Returns: |
|
|
Resampled features: [batch, num_queries, llm_hidden_size] |
|
|
""" |
|
|
batch_size = vision_features.shape[0] |
|
|
|
|
|
|
|
|
vision_features = self.input_proj(vision_features) |
|
|
|
|
|
|
|
|
queries = self.queries.expand(batch_size, -1, -1) |
|
|
|
|
|
|
|
|
for layer in self.layers: |
|
|
queries = layer(queries, vision_features) |
|
|
|
|
|
return self.norm(queries) |
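
# Example (minimal sketch): unlike the MLP projector, the resampler compresses
# the patch tokens down to a fixed number of learned queries (64 by default),
# already projected to the LLM width:
#
#     cfg = VisionConfig()
#     resampler = Resampler(cfg, num_queries=64)
#     feats = torch.randn(2, cfg.num_image_tokens, cfg.vision_hidden_size)
#     resampler(feats).shape   # torch.Size([2, 64, 1024])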
|
|
|
|
|
|
|
|
class VisionEncoder(nn.Module): |
|
|
""" |
|
|
Wrapper for pre-trained vision encoders. |
|
|
Supports SigLIP, DINOv2, and CLIP. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: VisionConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
        # Placeholders for a real pre-trained backbone and its image processor.
        # A small ViT-style dummy encoder is built below instead, so the rest of
        # the pipeline can run end-to-end without external weights.
        self.encoder = None
        self.processor = None

        self._build_dummy_encoder()
|
|
|
|
|
def _build_dummy_encoder(self): |
|
|
"""Build a dummy encoder for testing.""" |
|
|
|
|
|
patch_dim = 3 * (self.config.patch_size ** 2) |
|
|
num_patches = (self.config.image_size // self.config.patch_size) ** 2 |
|
|
|
|
|
self.patch_embed = nn.Linear(patch_dim, self.config.vision_hidden_size) |
|
|
self.pos_embed = nn.Parameter( |
|
|
torch.randn(1, num_patches + 1, self.config.vision_hidden_size) * 0.02 |
|
|
) |
|
|
self.cls_token = nn.Parameter( |
|
|
torch.randn(1, 1, self.config.vision_hidden_size) * 0.02 |
|
|
) |
|
|
|
|
|
|
|
|
self.layers = nn.ModuleList([ |
|
|
nn.TransformerEncoderLayer( |
|
|
d_model=self.config.vision_hidden_size, |
|
|
nhead=8, |
|
|
dim_feedforward=self.config.vision_hidden_size * 4, |
|
|
batch_first=True, |
|
|
) |
|
|
for _ in range(6) |
|
|
]) |
|
|
self.norm = nn.LayerNorm(self.config.vision_hidden_size) |
|
|
|
|
|
def patchify(self, images: torch.Tensor) -> torch.Tensor: |
|
|
"""Convert images to patches.""" |
|
|
batch_size, c, h, w = images.shape |
|
|
p = self.config.patch_size |
|
|
|
|
|
|
|
|
patches = images.unfold(2, p, p).unfold(3, p, p) |
|
|
patches = patches.contiguous().view(batch_size, c, -1, p, p) |
|
|
patches = patches.permute(0, 2, 1, 3, 4).contiguous() |
|
|
patches = patches.view(batch_size, -1, c * p * p) |
|
|
|
|
|
return patches |
|
|
|
|
|
def forward(self, images: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Encode images to feature vectors. |
|
|
|
|
|
Args: |
|
|
images: [batch, 3, height, width] normalized images |
|
|
|
|
|
Returns: |
|
|
Vision features: [batch, num_patches, vision_hidden_size] |
|
|
""" |
|
|
batch_size = images.shape[0] |
|
|
|
|
|
|
|
|
patches = self.patchify(images) |
|
|
x = self.patch_embed(patches) |
|
|
|
|
|
|
|
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1) |
|
|
x = torch.cat([cls_tokens, x], dim=1) |
|
|
|
|
|
|
|
|
x = x + self.pos_embed[:, :x.shape[1], :] |
|
|
|
|
|
|
|
|
for layer in self.layers: |
|
|
x = layer(x) |
|
|
|
|
|
x = self.norm(x) |
|
|
|
|
|
|
|
|
return x[:, 1:, :] |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained(cls, model_name: str, config: VisionConfig) -> "VisionEncoder": |
|
|
"""Load pre-trained vision encoder.""" |
|
|
        encoder = cls(config)

        # Weight loading is intentionally omitted in this sketch. A full
        # implementation would fetch the named checkpoint (e.g. a SigLIP or
        # DINOv2 model) and copy its weights into `encoder`; as written, the
        # randomly initialized dummy encoder is returned unchanged.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return encoder |
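
# Example (minimal sketch): the dummy encoder turns a batch of normalized images
# into a grid of patch features; each 384x384 image yields 27 * 27 = 729 patches,
# each flattened to 3 * 14 * 14 = 588 values before the linear patch embedding:
#
#     cfg = VisionConfig()
#     enc = VisionEncoder(cfg)
#     imgs = torch.randn(2, 3, cfg.image_size, cfg.image_size)
#     enc(imgs).shape   # torch.Size([2, 729, 1152])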
|
|
|
|
|
|
|
|
class VisionAdapter(nn.Module): |
|
|
""" |
|
|
Complete vision adapter for MiniMind Max2. |
|
|
Connects vision encoder to LLM via projector. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: VisionConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
|
|
|
self.vision_encoder = VisionEncoder(config) |
|
|
|
|
|
|
|
|
if config.projector_type == "mlp": |
|
|
self.projector = MLPProjector(config) |
|
|
elif config.projector_type == "resampler": |
|
|
self.projector = Resampler(config) |
|
|
else: |
|
|
self.projector = nn.Linear(config.vision_hidden_size, config.llm_hidden_size) |
|
|
|
|
|
|
|
|
if config.freeze_vision_encoder: |
|
|
for param in self.vision_encoder.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
images: torch.Tensor, |
|
|
return_features: bool = False, |
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: |
|
|
""" |
|
|
Process images and project to LLM space. |
|
|
|
|
|
Args: |
|
|
images: [batch, 3, height, width] |
|
|
return_features: Also return raw vision features |
|
|
|
|
|
Returns: |
|
|
            Projected features: [batch, num_tokens, llm_hidden_size]
            (plus the raw vision features as a second element when
            return_features=True)
|
|
""" |
|
|
|
|
|
vision_features = self.vision_encoder(images) |
|
|
|
|
|
|
|
|
projected = self.projector(vision_features) |
|
|
|
|
|
if return_features: |
|
|
return projected, vision_features |
|
|
return projected |
|
|
|
|
|
def get_num_image_tokens(self) -> int: |
|
|
"""Get number of tokens per image.""" |
|
|
if isinstance(self.projector, Resampler): |
|
|
return self.projector.num_queries |
|
|
return self.config.num_image_tokens |
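
# Example (minimal sketch): the adapter chains the encoder and the configured
# projector, so a single call maps raw pixels to LLM-ready token embeddings:
#
#     cfg = VisionConfig()                      # projector_type="mlp"
#     adapter = VisionAdapter(cfg)
#     imgs = torch.randn(2, 3, cfg.image_size, cfg.image_size)
#     adapter(imgs).shape                       # torch.Size([2, 729, 1024])
#     adapter.get_num_image_tokens()            # 729 (64 with a resampler)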
|
|
|
|
|
|
|
|
class MiniMindVision(nn.Module): |
|
|
""" |
|
|
Complete vision-language model combining MiniMind Max2 with vision adapter. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
llm_model: nn.Module, |
|
|
vision_config: Optional[VisionConfig] = None, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
if hasattr(llm_model, 'config'): |
|
|
llm_hidden_size = llm_model.config.hidden_size |
|
|
else: |
|
|
llm_hidden_size = 1024 |
|
|
|
|
|
|
|
|
self.vision_config = vision_config or VisionConfig(llm_hidden_size=llm_hidden_size) |
|
|
|
|
|
|
|
|
self.llm = llm_model |
|
|
self.vision_adapter = VisionAdapter(self.vision_config) |
|
|
|
|
|
|
|
|
if self.vision_config.freeze_llm: |
|
|
for param in self.llm.parameters(): |
|
|
param.requires_grad = False |
|
|
|
|
|
def merge_vision_text_embeddings( |
|
|
self, |
|
|
text_embeddings: torch.Tensor, |
|
|
vision_embeddings: torch.Tensor, |
|
|
image_positions: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Merge vision embeddings into text embedding sequence. |
|
|
|
|
|
Args: |
|
|
text_embeddings: [batch, text_seq_len, hidden_size] |
|
|
vision_embeddings: [batch, num_image_tokens, hidden_size] |
|
|
image_positions: [batch] position indices for image tokens |
|
|
|
|
|
Returns: |
|
|
Merged embeddings: [batch, total_seq_len, hidden_size] |
|
|
""" |
|
|
batch_size = text_embeddings.shape[0] |
|
|
num_image_tokens = vision_embeddings.shape[1] |
|
|
|
|
|
|
|
|
text_len = text_embeddings.shape[1] |
|
|
total_len = text_len + num_image_tokens |
|
|
|
|
|
|
|
|
merged = torch.zeros( |
|
|
batch_size, total_len, text_embeddings.shape[-1], |
|
|
device=text_embeddings.device, |
|
|
dtype=text_embeddings.dtype, |
|
|
) |
|
|
|
|
|
for i in range(batch_size): |
|
|
pos = image_positions[i].item() |
|
|
|
|
|
|
|
|
if pos > 0: |
|
|
merged[i, :pos] = text_embeddings[i, :pos] |
|
|
|
|
|
|
|
|
merged[i, pos:pos + num_image_tokens] = vision_embeddings[i] |
|
|
|
|
|
|
|
|
if pos < text_len: |
|
|
merged[i, pos + num_image_tokens:] = text_embeddings[i, pos:] |
|
|
|
|
|
return merged |
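
    # Layout produced by merge_vision_text_embeddings for one example with
    # insertion position `pos` (a sketch, not executable code):
    #
    #     [ text[:pos] | image tokens (num_image_tokens of them) | text[pos:] ]
    #
    # With the default pos = 0 the image tokens simply prefix the text.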
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor, |
|
|
images: Optional[torch.Tensor] = None, |
|
|
image_positions: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
) -> Tuple[Optional[torch.Tensor], torch.Tensor]: |
|
|
""" |
|
|
Forward pass with optional images. |
|
|
|
|
|
Args: |
|
|
input_ids: Text token IDs |
|
|
images: Optional batch of images |
|
|
image_positions: Where to insert image tokens |
|
|
attention_mask: Attention mask for text |
|
|
labels: Labels for language modeling |
|
|
|
|
|
Returns: |
|
|
Loss (if labels provided) and logits |
|
|
""" |
|
|
|
|
|
if hasattr(self.llm, 'model'): |
|
|
text_embeddings = self.llm.model.embed_tokens(input_ids) |
|
|
else: |
|
|
text_embeddings = self.llm.embed_tokens(input_ids) |
|
|
|
|
|
|
|
|
if images is not None: |
|
|
vision_embeddings = self.vision_adapter(images) |
|
|
|
|
|
if image_positions is None: |
|
|
|
|
|
                # Default: insert the image tokens at the start of the sequence.
                image_positions = torch.zeros(images.shape[0], dtype=torch.long, device=images.device)
|
|
|
|
|
|
|
|
merged_embeddings = self.merge_vision_text_embeddings( |
|
|
text_embeddings, vision_embeddings, image_positions |
|
|
) |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
num_image_tokens = vision_embeddings.shape[1] |
|
|
image_mask = torch.ones( |
|
|
images.shape[0], num_image_tokens, |
|
|
device=attention_mask.device, |
|
|
dtype=attention_mask.dtype, |
|
|
) |
|
|
                # The image tokens are assumed to sit at the front of the merged
                # sequence (the default insertion position 0), so their mask
                # entries are prepended to the text mask.
                attention_mask = torch.cat([image_mask, attention_mask], dim=1)
|
|
else: |
|
|
merged_embeddings = text_embeddings |
|
|
|
|
|
|
|
|
|
|
|
        # Run the LLM. When images are present the merged embeddings are fed
        # directly; this assumes the underlying model accepts `inputs_embeds`
        # and returns a (loss, logits, ...) tuple. Labels are padded with -100
        # for the inserted image tokens, which (like the attention mask above)
        # are assumed to sit at the front of the sequence.
        if images is not None:
            if labels is not None:
                ignore = torch.full_like(labels[:, :1], -100)
                ignore = ignore.expand(-1, vision_embeddings.shape[1])
                labels = torch.cat([ignore, labels], dim=1)
            loss, logits, _, _ = self.llm(
                inputs_embeds=merged_embeddings,
                attention_mask=attention_mask,
                labels=labels,
            )
        else:
            loss, logits, _, _ = self.llm(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

        return loss, logits
|
|
|
|
|
@torch.no_grad() |
|
|
def caption_image( |
|
|
self, |
|
|
image: torch.Tensor, |
|
|
prompt: str = "Describe this image:", |
|
|
max_new_tokens: int = 100, |
|
|
tokenizer = None, |
|
|
) -> str: |
|
|
"""Generate caption for an image.""" |
|
|
self.eval() |
|
|
|
|
|
|
|
|
        # Encode the image. Note: this sketch does not feed the vision
        # embeddings into `generate` below (that would require the LLM's
        # generate to accept pre-computed embeddings), so captioning here is
        # effectively text-only.
        vision_embeddings = self.vision_adapter(image.unsqueeze(0))
|
|
|
|
|
|
|
|
if tokenizer is not None: |
|
|
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(image.device) |
|
|
else: |
|
|
|
|
|
            # No tokenizer supplied: fall back to random placeholder token IDs
            # so the method still runs for shape/smoke testing.
            input_ids = torch.randint(0, 1000, (1, 10), device=image.device)
|
|
|
|
|
|
|
|
|
|
|
generated = self.llm.generate( |
|
|
input_ids, |
|
|
max_new_tokens=max_new_tokens, |
|
|
) |
|
|
|
|
|
if tokenizer is not None: |
|
|
return tokenizer.decode(generated[0], skip_special_tokens=True) |
|
|
return "Generated caption placeholder" |
|
|
|
|
|
|
|
|
class VisionDataset(Dataset): |
|
|
"""Dataset for vision-language training.""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
data_path: str, |
|
|
tokenizer, |
|
|
image_processor, |
|
|
max_length: int = 512, |
|
|
): |
|
|
self.tokenizer = tokenizer |
|
|
self.image_processor = image_processor |
|
|
self.max_length = max_length |
|
|
self.examples = [] |
|
|
|
|
|
|
|
|
        # Expect a JSON file containing a list of examples (e.g. LLaVA-style
        # records with an image reference and a "conversations" field).
        import json
        with open(data_path, 'r') as f:
            self.examples = json.load(f)
|
|
|
|
|
def __len__(self) -> int: |
|
|
return len(self.examples) |
|
|
|
|
|
def __getitem__(self, idx: int) -> Dict[str, Any]: |
|
|
example = self.examples[idx] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Placeholder image tensor. A real loader would read the image referenced
        # by the example (e.g. an image path field) and run it through
        # self.image_processor to obtain a normalized [3, 384, 384] tensor.
        image = torch.randn(3, 384, 384)
|
|
|
|
|
|
|
|
text = example.get("conversations", [{"value": "Describe the image."}])[0]["value"] |
|
|
encodings = self.tokenizer( |
|
|
text, |
|
|
max_length=self.max_length, |
|
|
truncation=True, |
|
|
padding="max_length", |
|
|
return_tensors="pt", |
|
|
) |
|
|
|
|
|
return { |
|
|
"image": image, |
|
|
"input_ids": encodings["input_ids"].squeeze(0), |
|
|
"attention_mask": encodings["attention_mask"].squeeze(0), |
|
|
"labels": encodings["input_ids"].squeeze(0), |
|
|
} |
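

if __name__ == "__main__":
    # Minimal smoke test (no pre-trained weights or data files needed): push a
    # batch of random images through the adapter with both projector types and
    # check the resulting token counts. Purely illustrative.
    for projector_type, expected_tokens in [("mlp", 729), ("resampler", 64)]:
        cfg = VisionConfig(projector_type=projector_type)
        adapter = VisionAdapter(cfg)
        images = torch.randn(2, 3, cfg.image_size, cfg.image_size)
        tokens = adapter(images)
        print(projector_type, tuple(tokens.shape), adapter.get_num_image_tokens())
        assert tokens.shape == (2, expected_tokens, cfg.llm_hidden_size)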
|
|
|