"""
MINDI 1.5 Vision-Coder — Vision Encoder
Uses CLIP ViT-L/14 (frozen) to encode UI screenshots into 256 visual
tokens projected from 1024 → 3584 to match the Qwen hidden dimension.
Output shape: (batch, 256, 3584).
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
class VisionEncoder(nn.Module):
"""
CLIP ViT-L/14 vision encoder for MINDI 1.5.
Extracts ALL 256 patch tokens (excludes CLS) from CLIP and
projects them from 1024 → 3584 to match Qwen2.5 hidden_size.
The CLIP backbone is frozen; only the projection layer trains.
"""
NUM_PATCHES: int = 256 # ViT-L/14: 16×16 patches from 224×224
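# 224 / 14 = 16 patches per side, so 16 × 16 = 256 patch embeddings; the CLS
# token at index 0 of last_hidden_state is dropped before projection.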
def __init__(
self,
model_name: str = "openai/clip-vit-large-patch14",
llm_hidden_size: int = 3584,
device: Optional[str] = None,
cache_dir: Optional[Path] = None,
torch_dtype: torch.dtype = torch.float32,
) -> None:
"""
Initialize the vision encoder.
Args:
model_name: HuggingFace CLIP vision model identifier.
llm_hidden_size: Target projection dimension (must match LLM hidden_size).
device: Target device ('cuda', 'cpu', or None for auto).
cache_dir: Local directory for model weight cache.
torch_dtype: Data type for CLIP weights (projection always float32).
"""
super().__init__()
self.model_name = model_name
self.llm_hidden_size = llm_hidden_size
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/vision")
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Load CLIP vision model (no text tower) and image processor
print(f"[VisionEncoder] Loading {model_name} ...")
self.clip = CLIPVisionModel.from_pretrained(
model_name,
cache_dir=str(self.cache_dir),
torch_dtype=torch_dtype,
)
self.image_processor = CLIPImageProcessor.from_pretrained(
model_name,
cache_dir=str(self.cache_dir),
)
# Freeze entire CLIP backbone
for param in self.clip.parameters():
param.requires_grad = False
self.clip.eval()
# Trainable projection: CLIP hidden (1024) → LLM hidden (3584)
clip_hidden_size: int = self.clip.config.hidden_size # 1024
self.projection = nn.Linear(clip_hidden_size, self.llm_hidden_size)
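# Design note: a LLaVA-style single linear projector is the only trainable
# piece here: 1024 * 3584 weights + 3584 biases = 3,673,600 (~3.7M) parameters.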
self.to(self.device)
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
total = sum(p.numel() for p in self.parameters())
print(f"[VisionEncoder] Loaded β€” {clip_hidden_size} β†’ {self.llm_hidden_size}")
print(f" Trainable: {trainable:,} | Total: {total:,}")
def encode_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
"""
Encode a single PIL image into projected patch token embeddings.
Args:
image: A PIL Image (RGB), or None.
Returns:
Tensor of shape (1, 256, 3584) or None if input is None.
"""
if image is None:
return None
inputs = self.image_processor(images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
with torch.no_grad():
vision_outputs = self.clip(pixel_values=pixel_values)
# last_hidden_state: (batch, 257, 1024) β€” 1 CLS + 256 patches
patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (1, 256, 1024)
# Project into LLM embedding space (trainable)
projected = self.projection(patch_tokens.float()) # (1, 256, 3584)
return projected
def encode_batch(self, images: list[Optional[Image.Image]]) -> list[Optional[torch.Tensor]]:
"""
Encode a batch of images. None entries pass through as None.
Args:
images: List of PIL Images or Nones.
Returns:
List of tensors (1, 256, 3584) or Nones matching input order.
"""
results: list[Optional[torch.Tensor]] = [None] * len(images)
valid_indices = [i for i, img in enumerate(images) if img is not None]
if not valid_indices:
return results
valid_images = [images[i] for i in valid_indices]
inputs = self.image_processor(images=valid_images, return_tensors="pt")
pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)
with torch.no_grad():
vision_outputs = self.clip(pixel_values=pixel_values)
patch_tokens = vision_outputs.last_hidden_state[:, 1:, :] # (N, 256, 1024)
projected = self.projection(patch_tokens.float()) # (N, 256, 3584)
for batch_idx, orig_idx in enumerate(valid_indices):
results[orig_idx] = projected[batch_idx].unsqueeze(0) # (1, 256, 3584)
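# Note: the output stays a None-preserving list to mirror the input. A caller
# wanting a single stacked tensor could use, e.g.:
#   torch.cat([t for t in results if t is not None], dim=0)  # (N, 256, 3584)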
return results
def encode_screenshot(self, screenshot_path: Path) -> Optional[torch.Tensor]:
"""
Load a screenshot from disk and encode it.
Args:
screenshot_path: Path to image file.
Returns:
Tensor of shape (1, 256, 3584).
"""
path = Path(screenshot_path)
if not path.exists():
raise FileNotFoundError(f"Screenshot not found: {path}")
image = Image.open(path).convert("RGB")
return self.encode_image(image)
def save_projection(self, save_dir: Optional[Path] = None) -> Path:
"""
Save only the trainable projection weights.
Args:
save_dir: Directory to save to. Defaults to cache_dir/projection.
Returns:
Path where weights were saved.
"""
save_path = Path(save_dir) if save_dir else self.cache_dir / "projection"
save_path.mkdir(parents=True, exist_ok=True)
torch.save(self.projection.state_dict(), save_path / "projection.pt")
print(f"[VisionEncoder] Projection saved to {save_path}")
return save_path
def load_projection(self, load_dir: Path) -> None:
"""
Load projection weights from disk.
Args:
load_dir: Directory containing projection.pt.
"""
weights_path = Path(load_dir) / "projection.pt"
if not weights_path.exists():
raise FileNotFoundError(f"Projection weights not found: {weights_path}")
state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
self.projection.load_state_dict(state_dict)
print(f"[VisionEncoder] Projection loaded from {load_dir}")
def get_num_visual_tokens(self) -> int:
"""Return the number of visual tokens produced per image (256)."""
return self.NUM_PATCHES
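# Context-budget note (hedged; exact limits depend on the Qwen config in use):
# each image consumes 256 positions of the LLM context, so a prompt must leave
# room for 256 * n_images visual tokens in addition to its text tokens.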
# ── Test block ────────────────────────────────────────────────────────
if __name__ == "__main__":
print("=" * 60)
print(" MINDI 1.5 β€” Vision Encoder Test")
print("=" * 60)
print()
# 1. Initialize encoder
encoder = VisionEncoder(
model_name="openai/clip-vit-large-patch14",
llm_hidden_size=3584,
)
# 2. Create a dummy image (224×224 RGB)
dummy_image = Image.new("RGB", (224, 224), color=(128, 128, 128))
# 3. Encode single image
print("\n Encoding single image ...")
output = encoder.encode_image(dummy_image)
assert output is not None
print(f" Output shape: {output.shape}")
assert output.shape == (1, 256, 3584), f"Expected (1, 256, 3584), got {output.shape}"
# 4. Encode None β†’ should return None
none_output = encoder.encode_image(None)
assert none_output is None, "Expected None for None input"
print(" None input β†’ None output βœ“")
# 5. Encode batch (mixed with None)
print("\n Encoding batch [image, None, image] ...")
batch_results = encoder.encode_batch([dummy_image, None, dummy_image])
assert batch_results[0] is not None and batch_results[0].shape == (1, 256, 3584)
assert batch_results[1] is None
assert batch_results[2] is not None and batch_results[2].shape == (1, 256, 3584)
print(f" Batch results: [{batch_results[0].shape}, None, {batch_results[2].shape}]")
# 6. Check trainable params (only projection should train)
trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
print(f"\n Trainable: {trainable:,}")
print(f" Frozen: {frozen:,}")
assert trainable == 1024 * 3584 + 3584, f"Unexpected trainable count: {trainable}"
assert frozen > trainable, "CLIP backbone should be frozen"
# 7. Save and reload projection
print("\n Testing save/load projection ...")
import tempfile
with tempfile.TemporaryDirectory() as tmp:
save_path = encoder.save_projection(Path(tmp))
old_weight = encoder.projection.weight.clone()
# Perturb weights
encoder.projection.weight.data.fill_(0.0)
assert not torch.equal(encoder.projection.weight, old_weight)
# Reload
encoder.load_projection(Path(tmp))
assert torch.equal(encoder.projection.weight, old_weight), "Weights not restored!"
print(" Save/load round-trip βœ“")
print("\n βœ“ All vision encoder tests passed!")
print("=" * 60)