""" Multimodal Vision Module for MiniMind Max2 Adapter-based approach using SigLIP/DINOv2 vision encoders. """ from dataclasses import dataclass, field from typing import List, Optional, Dict, Any, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import math @dataclass class VisionConfig: """Configuration for vision adapter.""" # Vision encoder settings vision_encoder: str = "siglip-so400m" # siglip-so400m, dinov2-small, clip-vit-base vision_hidden_size: int = 1152 # SigLIP-So400M hidden size image_size: int = 384 patch_size: int = 14 num_image_tokens: int = 729 # (384/14)^2 = 729 patches # Projector settings projector_type: str = "mlp" # mlp, linear, resampler projector_hidden_size: int = 2048 projector_num_layers: int = 2 # LLM settings (to match MiniMind) llm_hidden_size: int = 1024 # MiniMind hidden size # Training settings freeze_vision_encoder: bool = True freeze_llm: bool = True train_projector_only: bool = True # Special tokens image_start_token: str = "" image_end_token: str = "" image_pad_token: str = "" class MLPProjector(nn.Module): """ Multi-Layer Perceptron projector for vision-language alignment. Maps vision encoder outputs to LLM embedding space. """ def __init__(self, config: VisionConfig): super().__init__() self.config = config layers = [] input_size = config.vision_hidden_size for i in range(config.projector_num_layers): if i == config.projector_num_layers - 1: # Last layer projects to LLM size layers.extend([ nn.Linear(input_size, config.llm_hidden_size), ]) else: # Hidden layers layers.extend([ nn.Linear(input_size, config.projector_hidden_size), nn.GELU(), nn.LayerNorm(config.projector_hidden_size), ]) input_size = config.projector_hidden_size self.projector = nn.Sequential(*layers) def forward(self, vision_features: torch.Tensor) -> torch.Tensor: """ Project vision features to LLM space. Args: vision_features: [batch, num_patches, vision_hidden_size] Returns: Projected features: [batch, num_patches, llm_hidden_size] """ return self.projector(vision_features) class Resampler(nn.Module): """ Perceiver-style resampler for compressing vision tokens. Reduces number of image tokens while preserving information. """ def __init__( self, config: VisionConfig, num_queries: int = 64, num_heads: int = 8, num_layers: int = 2, ): super().__init__() self.config = config self.num_queries = num_queries # Learnable query tokens self.queries = nn.Parameter(torch.randn(1, num_queries, config.llm_hidden_size)) # Input projection self.input_proj = nn.Linear(config.vision_hidden_size, config.llm_hidden_size) # Cross-attention layers self.layers = nn.ModuleList([ nn.TransformerDecoderLayer( d_model=config.llm_hidden_size, nhead=num_heads, dim_feedforward=config.llm_hidden_size * 4, batch_first=True, ) for _ in range(num_layers) ]) self.norm = nn.LayerNorm(config.llm_hidden_size) def forward(self, vision_features: torch.Tensor) -> torch.Tensor: """ Resample vision features using learned queries. Args: vision_features: [batch, num_patches, vision_hidden_size] Returns: Resampled features: [batch, num_queries, llm_hidden_size] """ batch_size = vision_features.shape[0] # Project vision features vision_features = self.input_proj(vision_features) # Expand queries for batch queries = self.queries.expand(batch_size, -1, -1) # Cross-attend to vision features for layer in self.layers: queries = layer(queries, vision_features) return self.norm(queries) class VisionEncoder(nn.Module): """ Wrapper for pre-trained vision encoders. Supports SigLIP, DINOv2, and CLIP. """ def __init__(self, config: VisionConfig): super().__init__() self.config = config self.encoder = None self.processor = None # Placeholder for actual encoder loading # In practice, load from HuggingFace self._build_dummy_encoder() def _build_dummy_encoder(self): """Build a dummy encoder for testing.""" # Simple ViT-like encoder patch_dim = 3 * (self.config.patch_size ** 2) num_patches = (self.config.image_size // self.config.patch_size) ** 2 self.patch_embed = nn.Linear(patch_dim, self.config.vision_hidden_size) self.pos_embed = nn.Parameter( torch.randn(1, num_patches + 1, self.config.vision_hidden_size) * 0.02 ) self.cls_token = nn.Parameter( torch.randn(1, 1, self.config.vision_hidden_size) * 0.02 ) # Transformer layers self.layers = nn.ModuleList([ nn.TransformerEncoderLayer( d_model=self.config.vision_hidden_size, nhead=8, dim_feedforward=self.config.vision_hidden_size * 4, batch_first=True, ) for _ in range(6) ]) self.norm = nn.LayerNorm(self.config.vision_hidden_size) def patchify(self, images: torch.Tensor) -> torch.Tensor: """Convert images to patches.""" batch_size, c, h, w = images.shape p = self.config.patch_size # [B, C, H, W] -> [B, num_patches, patch_dim] patches = images.unfold(2, p, p).unfold(3, p, p) patches = patches.contiguous().view(batch_size, c, -1, p, p) patches = patches.permute(0, 2, 1, 3, 4).contiguous() patches = patches.view(batch_size, -1, c * p * p) return patches def forward(self, images: torch.Tensor) -> torch.Tensor: """ Encode images to feature vectors. Args: images: [batch, 3, height, width] normalized images Returns: Vision features: [batch, num_patches, vision_hidden_size] """ batch_size = images.shape[0] # Patchify and embed patches = self.patchify(images) x = self.patch_embed(patches) # Add CLS token cls_tokens = self.cls_token.expand(batch_size, -1, -1) x = torch.cat([cls_tokens, x], dim=1) # Add positional embeddings x = x + self.pos_embed[:, :x.shape[1], :] # Transformer for layer in self.layers: x = layer(x) x = self.norm(x) # Return patch features (exclude CLS) return x[:, 1:, :] @classmethod def from_pretrained(cls, model_name: str, config: VisionConfig) -> "VisionEncoder": """Load pre-trained vision encoder.""" encoder = cls(config) # In practice, load weights from HuggingFace # try: # from transformers import SiglipVisionModel, AutoProcessor # encoder.encoder = SiglipVisionModel.from_pretrained(model_name) # encoder.processor = AutoProcessor.from_pretrained(model_name) # except ImportError: # pass return encoder class VisionAdapter(nn.Module): """ Complete vision adapter for MiniMind Max2. Connects vision encoder to LLM via projector. """ def __init__(self, config: VisionConfig): super().__init__() self.config = config # Vision encoder self.vision_encoder = VisionEncoder(config) # Projector if config.projector_type == "mlp": self.projector = MLPProjector(config) elif config.projector_type == "resampler": self.projector = Resampler(config) else: self.projector = nn.Linear(config.vision_hidden_size, config.llm_hidden_size) # Freeze components as needed if config.freeze_vision_encoder: for param in self.vision_encoder.parameters(): param.requires_grad = False def forward( self, images: torch.Tensor, return_features: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Process images and project to LLM space. Args: images: [batch, 3, height, width] return_features: Also return raw vision features Returns: Projected features: [batch, num_tokens, llm_hidden_size] """ # Encode images vision_features = self.vision_encoder(images) # Project to LLM space projected = self.projector(vision_features) if return_features: return projected, vision_features return projected def get_num_image_tokens(self) -> int: """Get number of tokens per image.""" if isinstance(self.projector, Resampler): return self.projector.num_queries return self.config.num_image_tokens class MiniMindVision(nn.Module): """ Complete vision-language model combining MiniMind Max2 with vision adapter. """ def __init__( self, llm_model: nn.Module, vision_config: Optional[VisionConfig] = None, ): super().__init__() # Get LLM config if hasattr(llm_model, 'config'): llm_hidden_size = llm_model.config.hidden_size else: llm_hidden_size = 1024 # Vision config self.vision_config = vision_config or VisionConfig(llm_hidden_size=llm_hidden_size) # Components self.llm = llm_model self.vision_adapter = VisionAdapter(self.vision_config) # Freeze LLM if needed if self.vision_config.freeze_llm: for param in self.llm.parameters(): param.requires_grad = False def merge_vision_text_embeddings( self, text_embeddings: torch.Tensor, vision_embeddings: torch.Tensor, image_positions: torch.Tensor, ) -> torch.Tensor: """ Merge vision embeddings into text embedding sequence. Args: text_embeddings: [batch, text_seq_len, hidden_size] vision_embeddings: [batch, num_image_tokens, hidden_size] image_positions: [batch] position indices for image tokens Returns: Merged embeddings: [batch, total_seq_len, hidden_size] """ batch_size = text_embeddings.shape[0] num_image_tokens = vision_embeddings.shape[1] # Calculate output sequence length text_len = text_embeddings.shape[1] total_len = text_len + num_image_tokens # Create output tensor merged = torch.zeros( batch_size, total_len, text_embeddings.shape[-1], device=text_embeddings.device, dtype=text_embeddings.dtype, ) for i in range(batch_size): pos = image_positions[i].item() # Text before image if pos > 0: merged[i, :pos] = text_embeddings[i, :pos] # Image tokens merged[i, pos:pos + num_image_tokens] = vision_embeddings[i] # Text after image if pos < text_len: merged[i, pos + num_image_tokens:] = text_embeddings[i, pos:] return merged def forward( self, input_ids: torch.LongTensor, images: Optional[torch.Tensor] = None, image_positions: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """ Forward pass with optional images. Args: input_ids: Text token IDs images: Optional batch of images image_positions: Where to insert image tokens attention_mask: Attention mask for text labels: Labels for language modeling Returns: Loss (if labels provided) and logits """ # Get text embeddings from LLM if hasattr(self.llm, 'model'): text_embeddings = self.llm.model.embed_tokens(input_ids) else: text_embeddings = self.llm.embed_tokens(input_ids) # Process images if provided if images is not None: vision_embeddings = self.vision_adapter(images) if image_positions is None: # Default: insert at beginning image_positions = torch.zeros(images.shape[0], dtype=torch.long, device=images.device) # Merge embeddings merged_embeddings = self.merge_vision_text_embeddings( text_embeddings, vision_embeddings, image_positions ) # Update attention mask if attention_mask is not None: num_image_tokens = vision_embeddings.shape[1] image_mask = torch.ones( images.shape[0], num_image_tokens, device=attention_mask.device, dtype=attention_mask.dtype, ) attention_mask = torch.cat([image_mask, attention_mask], dim=1) else: merged_embeddings = text_embeddings # Forward through LLM (need to modify to accept embeddings directly) # This is a simplified version loss, logits, _, _ = self.llm( input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) return loss, logits @torch.no_grad() def caption_image( self, image: torch.Tensor, prompt: str = "Describe this image:", max_new_tokens: int = 100, tokenizer = None, ) -> str: """Generate caption for an image.""" self.eval() # Encode image vision_embeddings = self.vision_adapter(image.unsqueeze(0)) # Tokenize prompt if tokenizer is not None: input_ids = tokenizer.encode(prompt, return_tensors="pt").to(image.device) else: # Dummy for testing input_ids = torch.randint(0, 1000, (1, 10), device=image.device) # Generate (simplified) # In practice, would use the merged embeddings generated = self.llm.generate( input_ids, max_new_tokens=max_new_tokens, ) if tokenizer is not None: return tokenizer.decode(generated[0], skip_special_tokens=True) return "Generated caption placeholder" class VisionDataset(Dataset): """Dataset for vision-language training.""" def __init__( self, data_path: str, tokenizer, image_processor, max_length: int = 512, ): self.tokenizer = tokenizer self.image_processor = image_processor self.max_length = max_length self.examples = [] # Load data (e.g., LLaVA-150k format) import json with open(data_path, 'r') as f: self.examples = json.load(f) def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> Dict[str, Any]: example = self.examples[idx] # Load and process image # In practice: image = Image.open(example["image"]).convert("RGB") # image = self.image_processor(image) # Dummy image for now image = torch.randn(3, 384, 384) # Tokenize text text = example.get("conversations", [{"value": "Describe the image."}])[0]["value"] encodings = self.tokenizer( text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) return { "image": image, "input_ids": encodings["input_ids"].squeeze(0), "attention_mask": encodings["attention_mask"].squeeze(0), "labels": encodings["input_ids"].squeeze(0), }