| """ |
| MINDI 1.5 Vision-Coder β Vision Encoder |
| |
| Uses CLIP ViT-L/14 (frozen) to encode UI screenshots into 256 visual |
| tokens projected from 1024 β 3584 to match the Qwen hidden dimension. |
| Output shape: (batch, 256, 3584). |
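
Example (a minimal sketch; "screenshot.png" is a hypothetical path):

    encoder = VisionEncoder()
    image = Image.open("screenshot.png").convert("RGB")
    tokens = encoder.encode_image(image)  # shape: (1, 256, 3584)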
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| from transformers import CLIPImageProcessor, CLIPVisionModel |
|
|
|
|


class VisionEncoder(nn.Module):
    """
    CLIP ViT-L/14 vision encoder for MINDI 1.5.

    Extracts all 256 patch tokens (the CLS token is dropped) from CLIP
    and projects them from 1024 → 3584 to match the Qwen2.5 hidden_size.
    The CLIP backbone is frozen; only the projection layer trains.
    """

    # ViT-L/14 on 224×224 inputs yields (224 / 14)² = 16² = 256 patches.
    NUM_PATCHES: int = 256

    def __init__(
        self,
        model_name: str = "openai/clip-vit-large-patch14",
        llm_hidden_size: int = 3584,
        device: Optional[str] = None,
        cache_dir: Optional[Path] = None,
        torch_dtype: torch.dtype = torch.float32,
    ) -> None:
        """
        Initialize the vision encoder.

        Args:
            model_name: HuggingFace CLIP vision model identifier.
            llm_hidden_size: Target projection dimension (must match the LLM hidden_size).
            device: Target device ('cuda', 'cpu', or None for auto-detect).
            cache_dir: Local directory for the model weight cache.
            torch_dtype: Data type for the CLIP weights (the projection always runs in float32).
        """
        super().__init__()
        self.model_name = model_name
        self.llm_hidden_size = llm_hidden_size
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.cache_dir = Path(cache_dir) if cache_dir else Path("./checkpoints/vision")
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Load the CLIP vision tower and its paired image processor.
        print(f"[VisionEncoder] Loading {model_name} ...")
        self.clip = CLIPVisionModel.from_pretrained(
            model_name,
            cache_dir=str(self.cache_dir),
            torch_dtype=torch_dtype,
        )
        self.image_processor = CLIPImageProcessor.from_pretrained(
            model_name,
            cache_dir=str(self.cache_dir),
        )

        # Freeze the CLIP backbone; it stays in eval mode throughout.
        for param in self.clip.parameters():
            param.requires_grad = False
        self.clip.eval()

        # Trainable linear projection from the CLIP hidden size to the LLM hidden size.
        clip_hidden_size: int = self.clip.config.hidden_size
        self.projection = nn.Linear(clip_hidden_size, self.llm_hidden_size)

        self.to(self.device)

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"[VisionEncoder] Loaded – {clip_hidden_size} → {self.llm_hidden_size}")
        print(f" Trainable: {trainable:,} | Total: {total:,}")

    def encode_image(self, image: Optional[Image.Image]) -> Optional[torch.Tensor]:
        """
        Encode a single PIL image into projected patch token embeddings.

        Args:
            image: A PIL Image (RGB), or None.

        Returns:
            Tensor of shape (1, 256, llm_hidden_size), or None if the input is None.
        """
        if image is None:
            return None

        inputs = self.image_processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)

        with torch.no_grad():
            vision_outputs = self.clip(pixel_values=pixel_values)

        # Drop the CLS token (index 0); keep the 256 patch tokens.
        patch_tokens = vision_outputs.last_hidden_state[:, 1:, :]

        # Project in float32 regardless of the CLIP dtype.
        projected = self.projection(patch_tokens.float())
        return projected

    def encode_batch(self, images: list[Optional[Image.Image]]) -> list[Optional[torch.Tensor]]:
        """
        Encode a batch of images. None entries pass through as None.

        Args:
            images: List of PIL Images or Nones.

        Returns:
            List of tensors of shape (1, 256, llm_hidden_size), or Nones, matching the input order.
        """
        results: list[Optional[torch.Tensor]] = [None] * len(images)
        valid_indices = [i for i, img in enumerate(images) if img is not None]

        if not valid_indices:
            return results

        valid_images = [images[i] for i in valid_indices]
        inputs = self.image_processor(images=valid_images, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(device=self.device, dtype=self.clip.dtype)

        with torch.no_grad():
            vision_outputs = self.clip(pixel_values=pixel_values)
            patch_tokens = vision_outputs.last_hidden_state[:, 1:, :]

        projected = self.projection(patch_tokens.float())

        # Scatter the encoded images back to their original positions.
        for batch_idx, orig_idx in enumerate(valid_indices):
            results[orig_idx] = projected[batch_idx].unsqueeze(0)

        return results

    def encode_screenshot(self, screenshot_path: Path) -> Optional[torch.Tensor]:
        """
        Load a screenshot from disk and encode it.

        Args:
            screenshot_path: Path to the image file.

        Returns:
            Tensor of shape (1, 256, llm_hidden_size).
        """
        path = Path(screenshot_path)
        if not path.exists():
            raise FileNotFoundError(f"Screenshot not found: {path}")
        image = Image.open(path).convert("RGB")
        return self.encode_image(image)

    def save_projection(self, save_dir: Optional[Path] = None) -> Path:
        """
        Save only the trainable projection weights.

        Args:
            save_dir: Directory to save to. Defaults to cache_dir/projection.

        Returns:
            Path where weights were saved.
        """
        save_path = Path(save_dir) if save_dir else self.cache_dir / "projection"
        save_path.mkdir(parents=True, exist_ok=True)
        torch.save(self.projection.state_dict(), save_path / "projection.pt")
        print(f"[VisionEncoder] Projection saved to {save_path}")
        return save_path

    def load_projection(self, load_dir: Path) -> None:
        """
        Load projection weights from disk.

        Args:
            load_dir: Directory containing projection.pt.
        """
        weights_path = Path(load_dir) / "projection.pt"
        if not weights_path.exists():
            raise FileNotFoundError(f"Projection weights not found: {weights_path}")
        state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.projection.load_state_dict(state_dict)
        print(f"[VisionEncoder] Projection loaded from {load_dir}")

    def get_num_visual_tokens(self) -> int:
        """Return the number of visual tokens produced per image (256)."""
        return self.NUM_PATCHES


if __name__ == "__main__":
    print("=" * 60)
    print(" MINDI 1.5 – Vision Encoder Test")
    print("=" * 60)
    print()

    encoder = VisionEncoder(
        model_name="openai/clip-vit-large-patch14",
        llm_hidden_size=3584,
    )

    dummy_image = Image.new("RGB", (224, 224), color=(128, 128, 128))

    print("\n Encoding single image ...")
    output = encoder.encode_image(dummy_image)
    assert output is not None
    print(f" Output shape: {output.shape}")
    assert output.shape == (1, 256, 3584), f"Expected (1, 256, 3584), got {output.shape}"

    none_output = encoder.encode_image(None)
    assert none_output is None, "Expected None for None input"
    print(" None input → None output ✓")

    print("\n Encoding batch [image, None, image] ...")
    batch_results = encoder.encode_batch([dummy_image, None, dummy_image])
    assert batch_results[0] is not None and batch_results[0].shape == (1, 256, 3584)
    assert batch_results[1] is None
    assert batch_results[2] is not None and batch_results[2].shape == (1, 256, 3584)
    print(f" Batch results: [{batch_results[0].shape}, None, {batch_results[2].shape}]")

    trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
    print(f"\n Trainable: {trainable:,}")
    print(f" Frozen: {frozen:,}")
    assert trainable == 1024 * 3584 + 3584, f"Unexpected trainable count: {trainable}"
    assert frozen > trainable, "CLIP backbone should be frozen"

    print("\n Testing save/load projection ...")
    import tempfile
    with tempfile.TemporaryDirectory() as tmp:
        save_path = encoder.save_projection(Path(tmp))
        old_weight = encoder.projection.weight.clone()

        encoder.projection.weight.data.fill_(0.0)
        assert not torch.equal(encoder.projection.weight, old_weight)

        encoder.load_projection(Path(tmp))
        assert torch.equal(encoder.projection.weight, old_weight), "Weights not restored!"
    print(" Save/load round-trip ✓")

    print("\n ✓ All vision encoder tests passed!")
    print("=" * 60)