"""
Multimodal Vision Module for MiniMind Max2
Adapter-based approach using SigLIP/DINOv2 vision encoders.
"""
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
@dataclass
class VisionConfig:
"""Configuration for vision adapter."""
# Vision encoder settings
vision_encoder: str = "siglip-so400m" # siglip-so400m, dinov2-small, clip-vit-base
vision_hidden_size: int = 1152 # SigLIP-So400M hidden size
image_size: int = 384
patch_size: int = 14
    num_image_tokens: int = 729  # (384 // 14)^2 = 27^2 = 729 patches
# Projector settings
projector_type: str = "mlp" # mlp, linear, resampler
projector_hidden_size: int = 2048
projector_num_layers: int = 2
# LLM settings (to match MiniMind)
llm_hidden_size: int = 1024 # MiniMind hidden size
# Training settings
freeze_vision_encoder: bool = True
freeze_llm: bool = True
train_projector_only: bool = True
# Special tokens
    image_start_token: str = "<image_start>"  # placeholder marker strings; register these with the tokenizer
    image_end_token: str = "<image_end>"
    image_pad_token: str = "<image_pad>"
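
# Usage sketch (illustrative, not part of the adapter API): the defaults above target
# SigLIP-So400M at 384px feeding a 1024-dim MiniMind; switching projector_type selects
# the Resampler defined below. The override here is an assumption for demonstration only.
def _example_vision_config() -> VisionConfig:
    default_cfg = VisionConfig()
    # Sanity check: the image token count follows from the patch grid.
    assert default_cfg.num_image_tokens == (default_cfg.image_size // default_cfg.patch_size) ** 2
    # Resampler variant: same encoder, but image tokens get compressed to learned queries.
    return VisionConfig(projector_type="resampler")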
class MLPProjector(nn.Module):
"""
Multi-Layer Perceptron projector for vision-language alignment.
Maps vision encoder outputs to LLM embedding space.
"""
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
layers = []
input_size = config.vision_hidden_size
for i in range(config.projector_num_layers):
if i == config.projector_num_layers - 1:
# Last layer projects to LLM size
layers.extend([
nn.Linear(input_size, config.llm_hidden_size),
])
else:
# Hidden layers
layers.extend([
nn.Linear(input_size, config.projector_hidden_size),
nn.GELU(),
nn.LayerNorm(config.projector_hidden_size),
])
input_size = config.projector_hidden_size
self.projector = nn.Sequential(*layers)
def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
"""
Project vision features to LLM space.
Args:
vision_features: [batch, num_patches, vision_hidden_size]
Returns:
Projected features: [batch, num_patches, llm_hidden_size]
"""
return self.projector(vision_features)
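
# Minimal shape check for the MLP projector (a sketch assuming the default VisionConfig):
# each of the 729 SigLIP patch features is mapped one-to-one into the 1024-dim MiniMind
# embedding space, so the token count is unchanged by this projector type.
def _example_mlp_projector() -> torch.Tensor:
    config = VisionConfig()
    projector = MLPProjector(config)
    features = torch.randn(2, config.num_image_tokens, config.vision_hidden_size)
    projected = projector(features)
    assert projected.shape == (2, config.num_image_tokens, config.llm_hidden_size)
    return projected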
class Resampler(nn.Module):
"""
Perceiver-style resampler for compressing vision tokens.
Reduces number of image tokens while preserving information.
"""
def __init__(
self,
config: VisionConfig,
num_queries: int = 64,
num_heads: int = 8,
num_layers: int = 2,
):
super().__init__()
self.config = config
self.num_queries = num_queries
# Learnable query tokens
self.queries = nn.Parameter(torch.randn(1, num_queries, config.llm_hidden_size))
# Input projection
self.input_proj = nn.Linear(config.vision_hidden_size, config.llm_hidden_size)
# Cross-attention layers
self.layers = nn.ModuleList([
nn.TransformerDecoderLayer(
d_model=config.llm_hidden_size,
nhead=num_heads,
dim_feedforward=config.llm_hidden_size * 4,
batch_first=True,
)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(config.llm_hidden_size)
def forward(self, vision_features: torch.Tensor) -> torch.Tensor:
"""
Resample vision features using learned queries.
Args:
vision_features: [batch, num_patches, vision_hidden_size]
Returns:
Resampled features: [batch, num_queries, llm_hidden_size]
"""
batch_size = vision_features.shape[0]
# Project vision features
vision_features = self.input_proj(vision_features)
# Expand queries for batch
queries = self.queries.expand(batch_size, -1, -1)
# Cross-attend to vision features
for layer in self.layers:
queries = layer(queries, vision_features)
return self.norm(queries)
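
# Sketch of the resampler trade-off (illustrative): 729 patch tokens are compressed to a
# fixed set of 64 learned queries, so each image costs far fewer LLM positions.
# num_queries=64 is this module's default, not a value dictated by MiniMind.
def _example_resampler() -> torch.Tensor:
    config = VisionConfig(projector_type="resampler")
    resampler = Resampler(config, num_queries=64)
    features = torch.randn(2, config.num_image_tokens, config.vision_hidden_size)
    compressed = resampler(features)
    assert compressed.shape == (2, 64, config.llm_hidden_size)
    return compressed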
class VisionEncoder(nn.Module):
"""
Wrapper for pre-trained vision encoders.
Supports SigLIP, DINOv2, and CLIP.
"""
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
self.encoder = None
self.processor = None
# Placeholder for actual encoder loading
# In practice, load from HuggingFace
self._build_dummy_encoder()
def _build_dummy_encoder(self):
"""Build a dummy encoder for testing."""
# Simple ViT-like encoder
patch_dim = 3 * (self.config.patch_size ** 2)
num_patches = (self.config.image_size // self.config.patch_size) ** 2
self.patch_embed = nn.Linear(patch_dim, self.config.vision_hidden_size)
self.pos_embed = nn.Parameter(
torch.randn(1, num_patches + 1, self.config.vision_hidden_size) * 0.02
)
self.cls_token = nn.Parameter(
torch.randn(1, 1, self.config.vision_hidden_size) * 0.02
)
# Transformer layers
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=self.config.vision_hidden_size,
nhead=8,
dim_feedforward=self.config.vision_hidden_size * 4,
batch_first=True,
)
for _ in range(6)
])
self.norm = nn.LayerNorm(self.config.vision_hidden_size)
def patchify(self, images: torch.Tensor) -> torch.Tensor:
"""Convert images to patches."""
batch_size, c, h, w = images.shape
p = self.config.patch_size
# [B, C, H, W] -> [B, num_patches, patch_dim]
patches = images.unfold(2, p, p).unfold(3, p, p)
patches = patches.contiguous().view(batch_size, c, -1, p, p)
patches = patches.permute(0, 2, 1, 3, 4).contiguous()
patches = patches.view(batch_size, -1, c * p * p)
return patches
def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
Encode images to feature vectors.
Args:
images: [batch, 3, height, width] normalized images
Returns:
Vision features: [batch, num_patches, vision_hidden_size]
"""
batch_size = images.shape[0]
# Patchify and embed
patches = self.patchify(images)
x = self.patch_embed(patches)
# Add CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# Add positional embeddings
x = x + self.pos_embed[:, :x.shape[1], :]
# Transformer
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# Return patch features (exclude CLS)
return x[:, 1:, :]
@classmethod
def from_pretrained(cls, model_name: str, config: VisionConfig) -> "VisionEncoder":
"""Load pre-trained vision encoder."""
encoder = cls(config)
# In practice, load weights from HuggingFace
# try:
# from transformers import SiglipVisionModel, AutoProcessor
# encoder.encoder = SiglipVisionModel.from_pretrained(model_name)
# encoder.processor = AutoProcessor.from_pretrained(model_name)
# except ImportError:
# pass
return encoder
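
# Smoke test for the (dummy) encoder: a batch of normalized 384x384 RGB images comes back
# as 729 patch features with the CLS token stripped. A real SigLIP/DINOv2 backbone loaded
# via from_pretrained should produce the same shapes under the default config.
def _example_vision_encoder() -> torch.Tensor:
    config = VisionConfig()
    encoder = VisionEncoder(config)
    images = torch.randn(2, 3, config.image_size, config.image_size)
    features = encoder(images)
    assert features.shape == (2, config.num_image_tokens, config.vision_hidden_size)
    return features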
class VisionAdapter(nn.Module):
"""
Complete vision adapter for MiniMind Max2.
Connects vision encoder to LLM via projector.
"""
def __init__(self, config: VisionConfig):
super().__init__()
self.config = config
# Vision encoder
self.vision_encoder = VisionEncoder(config)
# Projector
if config.projector_type == "mlp":
self.projector = MLPProjector(config)
elif config.projector_type == "resampler":
self.projector = Resampler(config)
else:
self.projector = nn.Linear(config.vision_hidden_size, config.llm_hidden_size)
# Freeze components as needed
if config.freeze_vision_encoder:
for param in self.vision_encoder.parameters():
param.requires_grad = False
def forward(
self,
images: torch.Tensor,
return_features: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Process images and project to LLM space.
Args:
images: [batch, 3, height, width]
return_features: Also return raw vision features
        Returns:
            Projected features: [batch, num_tokens, llm_hidden_size];
            if return_features is True, the raw vision features are returned as well
"""
# Encode images
vision_features = self.vision_encoder(images)
# Project to LLM space
projected = self.projector(vision_features)
if return_features:
return projected, vision_features
return projected
def get_num_image_tokens(self) -> int:
"""Get number of tokens per image."""
if isinstance(self.projector, Resampler):
return self.projector.num_queries
return self.config.num_image_tokens
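
# End-to-end adapter sketch: with the default config the vision encoder is frozen, so the
# trainable count below is roughly what stage-1 alignment actually updates (projector only).
# Parameter counts depend on the dummy encoder here and are illustrative.
def _example_vision_adapter() -> torch.Tensor:
    config = VisionConfig()  # freeze_vision_encoder=True by default
    adapter = VisionAdapter(config)
    trainable = sum(p.numel() for p in adapter.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in adapter.parameters() if not p.requires_grad)
    print(f"trainable (projector): {trainable:,} | frozen (encoder): {frozen:,}")
    images = torch.randn(2, 3, config.image_size, config.image_size)
    tokens = adapter(images)
    assert tokens.shape == (2, adapter.get_num_image_tokens(), config.llm_hidden_size)
    return tokens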
class MiniMindVision(nn.Module):
"""
Complete vision-language model combining MiniMind Max2 with vision adapter.
"""
def __init__(
self,
llm_model: nn.Module,
vision_config: Optional[VisionConfig] = None,
):
super().__init__()
# Get LLM config
if hasattr(llm_model, 'config'):
llm_hidden_size = llm_model.config.hidden_size
else:
llm_hidden_size = 1024
# Vision config
self.vision_config = vision_config or VisionConfig(llm_hidden_size=llm_hidden_size)
# Components
self.llm = llm_model
self.vision_adapter = VisionAdapter(self.vision_config)
# Freeze LLM if needed
if self.vision_config.freeze_llm:
for param in self.llm.parameters():
param.requires_grad = False
def merge_vision_text_embeddings(
self,
text_embeddings: torch.Tensor,
vision_embeddings: torch.Tensor,
image_positions: torch.Tensor,
) -> torch.Tensor:
"""
Merge vision embeddings into text embedding sequence.
Args:
text_embeddings: [batch, text_seq_len, hidden_size]
vision_embeddings: [batch, num_image_tokens, hidden_size]
image_positions: [batch] position indices for image tokens
Returns:
Merged embeddings: [batch, total_seq_len, hidden_size]
"""
batch_size = text_embeddings.shape[0]
num_image_tokens = vision_embeddings.shape[1]
# Calculate output sequence length
text_len = text_embeddings.shape[1]
total_len = text_len + num_image_tokens
# Create output tensor
merged = torch.zeros(
batch_size, total_len, text_embeddings.shape[-1],
device=text_embeddings.device,
dtype=text_embeddings.dtype,
)
for i in range(batch_size):
pos = image_positions[i].item()
# Text before image
if pos > 0:
merged[i, :pos] = text_embeddings[i, :pos]
# Image tokens
merged[i, pos:pos + num_image_tokens] = vision_embeddings[i]
# Text after image
if pos < text_len:
merged[i, pos + num_image_tokens:] = text_embeddings[i, pos:]
return merged
def forward(
self,
input_ids: torch.LongTensor,
images: Optional[torch.Tensor] = None,
image_positions: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""
Forward pass with optional images.
Args:
input_ids: Text token IDs
images: Optional batch of images
image_positions: Where to insert image tokens
attention_mask: Attention mask for text
labels: Labels for language modeling
Returns:
Loss (if labels provided) and logits
"""
# Get text embeddings from LLM
if hasattr(self.llm, 'model'):
text_embeddings = self.llm.model.embed_tokens(input_ids)
else:
text_embeddings = self.llm.embed_tokens(input_ids)
# Process images if provided
if images is not None:
vision_embeddings = self.vision_adapter(images)
if image_positions is None:
# Default: insert at beginning
image_positions = torch.zeros(images.shape[0], dtype=torch.long, device=images.device)
# Merge embeddings
merged_embeddings = self.merge_vision_text_embeddings(
text_embeddings, vision_embeddings, image_positions
)
            # Update attention mask (simplified: the image mask is prepended, which
            # matches the default image_positions of 0)
if attention_mask is not None:
num_image_tokens = vision_embeddings.shape[1]
image_mask = torch.ones(
images.shape[0], num_image_tokens,
device=attention_mask.device,
dtype=attention_mask.dtype,
)
attention_mask = torch.cat([image_mask, attention_mask], dim=1)
else:
merged_embeddings = text_embeddings
        # Forward through the LLM. Simplified: merged_embeddings is computed above but not
        # passed here; MiniMind's forward would need an inputs_embeds path to consume it.
loss, logits, _, _ = self.llm(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
return loss, logits
@torch.no_grad()
def caption_image(
self,
image: torch.Tensor,
prompt: str = "Describe this image:",
max_new_tokens: int = 100,
tokenizer = None,
) -> str:
"""Generate caption for an image."""
self.eval()
# Encode image
vision_embeddings = self.vision_adapter(image.unsqueeze(0))
# Tokenize prompt
if tokenizer is not None:
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(image.device)
else:
# Dummy for testing
input_ids = torch.randint(0, 1000, (1, 10), device=image.device)
        # Generate (simplified): vision_embeddings is computed above but not injected into
        # the prompt; a full implementation would merge it with the text embeddings first.
generated = self.llm.generate(
input_ids,
max_new_tokens=max_new_tokens,
)
if tokenizer is not None:
return tokenizer.decode(generated[0], skip_special_tokens=True)
return "Generated caption placeholder"
class VisionDataset(Dataset):
"""Dataset for vision-language training."""
def __init__(
self,
data_path: str,
tokenizer,
image_processor,
max_length: int = 512,
):
self.tokenizer = tokenizer
self.image_processor = image_processor
self.max_length = max_length
self.examples = []
# Load data (e.g., LLaVA-150k format)
import json
with open(data_path, 'r') as f:
self.examples = json.load(f)
def __len__(self) -> int:
return len(self.examples)
def __getitem__(self, idx: int) -> Dict[str, Any]:
example = self.examples[idx]
# Load and process image
# In practice: image = Image.open(example["image"]).convert("RGB")
# image = self.image_processor(image)
# Dummy image for now
image = torch.randn(3, 384, 384)
# Tokenize text
text = example.get("conversations", [{"value": "Describe the image."}])[0]["value"]
encodings = self.tokenizer(
text,
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt",
)
return {
"image": image,
"input_ids": encodings["input_ids"].squeeze(0),
"attention_mask": encodings["attention_mask"].squeeze(0),
"labels": encodings["input_ids"].squeeze(0),
}
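
# Wiring sketch for vision-language training data (illustrative): the dataset yields one
# image plus tokenized text per example, so a plain DataLoader can batch it for the
# projector-alignment stage. The stub tokenizer below is an assumption standing in for a
# real HuggingFace tokenizer; data_path should point at a LLaVA-style JSON file.
def _example_dataloader(data_path: str) -> DataLoader:
    class _StubTokenizer:
        def __call__(self, text, max_length, truncation, padding, return_tensors):
            ids = torch.randint(0, 1000, (1, max_length))
            return {"input_ids": ids, "attention_mask": torch.ones_like(ids)}

    dataset = VisionDataset(
        data_path=data_path,
        tokenizer=_StubTokenizer(),
        image_processor=None,  # unused by the dummy __getitem__ above
        max_length=512,
    )
    return DataLoader(dataset, batch_size=4, shuffle=True)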