# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Shared RWKV-VL image processing helpers. This module intentionally has no TorchTitan-specific imports so it can be copied into HF remote-code exports and used by ``AutoProcessor`` at inference time. The key policy is that ``max_pixels`` is a per-sample visual budget: when one prompt contains multiple images, the resized pixel counts are scaled together so their sum stays within the configured budget whenever physically possible. """ from __future__ import annotations from dataclasses import dataclass from io import BytesIO import logging import math from typing import Any import einops as E import requests import torch # pyrefly: ignore [missing-import] import torchvision.io # pyrefly: ignore [missing-import] import torchvision.transforms.v2.functional as TVF from PIL import Image logger = logging.getLogger(__name__) CHAT_TEMPLATE = ( "{% for message in messages %}" "{{ '\\x16' + ('Assistant' if message['role'] == 'assistant' else 'System' if message['role'] == 'system' else 'User') + ':' }}" "{% if message['content'] is string %}" "{{ message['content'] }}" "{% else %}" "{% set ns = namespace(explicit_image_tags=0, image_items=0, text_parts=[]) %}" "{% for item in message['content'] %}" "{% if item['type'] == 'text' %}" "{% set ns.text_parts = ns.text_parts + [item['text']] %}" "{% set ns.explicit_image_tags = ns.explicit_image_tags + item['text'].count('') %}" "{% elif item['type'] in ['image', 'image_url'] %}" "{% set ns.image_items = ns.image_items + 1 %}" "{% endif %}" "{% endfor %}" "{% for _ in range([ns.image_items - ns.explicit_image_tags, 0] | max) %}" "{{ '' }}" "{% endfor %}" "{{ ns.text_parts | join('') }}" "{% endif %}" "{{ '\\x17' }}" "{% endfor %}" "{% if add_generation_prompt %}" "{{ '\\x16Assistant:' }}" "{% if thinking is defined and thinking %}{{ ' ' }}{% endif %}" "{% endif %}" ) CHAT_TEMPLATE_FAKE_THINKING = ( "{% for message in messages %}" "{{ '\\x16' + ('Assistant' if message['role'] == 'assistant' else 'System' if message['role'] == 'system' else 'User') + ':' }}" "{% if message['role'] == 'assistant' %}{{ ' \\n\\n' }}{% endif %}" "{% if message['content'] is string %}" "{{ message['content'] }}" "{% else %}" "{% set ns = namespace(explicit_image_tags=0, image_items=0, text_parts=[]) %}" "{% for item in message['content'] %}" "{% if item['type'] == 'text' %}" "{% set ns.text_parts = ns.text_parts + [item['text']] %}" "{% set ns.explicit_image_tags = ns.explicit_image_tags + item['text'].count('') %}" "{% elif item['type'] in ['image', 'image_url'] %}" "{% set ns.image_items = ns.image_items + 1 %}" "{% endif %}" "{% endfor %}" "{% for _ in range([ns.image_items - ns.explicit_image_tags, 0] | max) %}" "{{ '' }}" "{% endfor %}" "{{ ns.text_parts | join('') }}" "{% endif %}" "{{ '\\x17' }}" "{% endfor %}" "{% if add_generation_prompt %}" "{{ '\\x16Assistant:' }}" "{% if thinking is defined and thinking %}{{ ' ' }}{% endif %}" "{% endif %}" ) @dataclass(frozen=True, slots=True) class RWKVVLImageProcessorConfig: patch_size: int = 16 temporal_patch_size: int = 2 spatial_merge_size: int = 2 min_pixels: int = 65536 max_pixels: int = 2097152 image_mean: tuple[float, ...] = (0.5, 0.5, 0.5) image_std: tuple[float, ...] = (0.5, 0.5, 0.5) max_aspect_ratio: float = 50.0 @property def factor(self) -> int: return self.patch_size * self.spatial_merge_size @dataclass(frozen=True, slots=True) class RWKVVLProcessedImages: images: list[torch.Tensor] image_token_counts: list[int] grid_thw: torch.Tensor flat_patches: torch.Tensor def _decode_image(image: str | bytes | Image.Image) -> torch.Tensor: """Decode an image to a ``(C, H, W)`` uint8 RGB tensor.""" if isinstance(image, str) and image.startswith("http"): response = requests.get(image, timeout=10) response.raise_for_status() image = response.content if isinstance(image, bytes): raw = torch.frombuffer(bytearray(image), dtype=torch.uint8) return torchvision.io.decode_image(raw, mode=torchvision.io.ImageReadMode.RGB) if isinstance(image, str): return torchvision.io.decode_image(image, mode=torchvision.io.ImageReadMode.RGB) if image.mode != "RGB": image = image.convert("RGB") return TVF.pil_to_tensor(image) def normalize_image(image: Any) -> Any: if isinstance(image, Image.Image): return image.convert("RGB") return image def get_image_size(image: str | bytes | Image.Image) -> tuple[int, int]: """Return ``(height, width)`` without decoding to a tensor when possible.""" if isinstance(image, Image.Image): width, height = image.size return height, width if isinstance(image, bytes): with Image.open(BytesIO(image)) as pil_image: width, height = pil_image.size return height, width if isinstance(image, str) and image.startswith("http"): response = requests.get(image, timeout=10) response.raise_for_status() with Image.open(BytesIO(response.content)) as pil_image: width, height = pil_image.size return height, width if isinstance(image, str): with Image.open(image) as pil_image: width, height = pil_image.size return height, width img_tensor = _decode_image(image) _, height, width = img_tensor.shape return height, width def _ensure_min_factor_size( height: int, width: int, factor: int, ) -> tuple[int, int]: if height < factor or width < factor: scale = max(factor / width, factor / height) width = int(width * scale) height = int(height * scale) return height, width def smart_resize( height: int, width: int, factor: int, min_pixels: int, max_pixels: int, ) -> tuple[int, int]: """Compute target ``(height, width)`` with Qwen-style resize constraints.""" if max(height, width) / min(height, width) > 200: raise ValueError( f"Absolute aspect ratio must be smaller than 200, " f"got {max(height, width) / min(height, width):.1f}" ) h_bar = max(round(height / factor) * factor, factor) w_bar = max(round(width / factor) * factor, factor) if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = max(math.floor(height / beta / factor) * factor, factor) w_bar = max(math.floor(width / beta / factor) * factor, factor) elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar def resize_sizes_to_total_pixels( sizes: list[tuple[int, int]], *, factor: int, min_pixels: int, max_pixels: int, ) -> list[tuple[int, int]]: """Resize image sizes with ``max_pixels`` as a shared per-sample budget.""" if not sizes: return [] resized_sizes = [ smart_resize( *_ensure_min_factor_size(height, width, factor), factor=factor, min_pixels=min_pixels, max_pixels=max_pixels, ) for height, width in sizes ] if max_pixels <= 0: return resized_sizes total_pixels = sum(height * width for height, width in resized_sizes) if total_pixels <= max_pixels: return resized_sizes scale = math.sqrt(max_pixels / total_pixels) min_area = factor * factor def scale_sizes(scale_factor: float) -> list[tuple[int, int]]: scaled = [] for height, width in resized_sizes: scaled_height = max( factor, math.floor(height * scale_factor / factor) * factor ) scaled_width = max( factor, math.floor(width * scale_factor / factor) * factor ) scaled.append((scaled_height, scaled_width)) return scaled scaled_sizes = scale_sizes(scale) for _ in range(8): scaled_total = sum(height * width for height, width in scaled_sizes) if scaled_total <= max_pixels: return scaled_sizes if scaled_total <= len(scaled_sizes) * min_area: break scale *= math.sqrt(max_pixels / scaled_total) * 0.995 scaled_sizes = scale_sizes(scale) scaled_total = sum(height * width for height, width in scaled_sizes) if scaled_total > max_pixels: logger.warning( "Minimum resized image area (%s images * %s pixels) exceeds " "max_pixels=%s; using the smallest factor-aligned image sizes.", len(scaled_sizes), min_area, max_pixels, ) return scaled_sizes def process_image( image: str | bytes | Image.Image, patch_size: int = 16, merge_size: int = 2, max_pixels: int = 16777216, min_pixels: int = 65536, image_mean: tuple[float, ...] = (0.5, 0.5, 0.5), image_std: tuple[float, ...] = (0.5, 0.5, 0.5), resized_size: tuple[int, int] | None = None, ) -> torch.Tensor | None: """Decode, resize, rescale, normalize, and return ``(1, H, W, C)``.""" try: img_tensor = _decode_image(image) _, original_height, original_width = img_tensor.shape factor = patch_size * merge_size if resized_size is None: original_height, original_width = _ensure_min_factor_size( original_height, original_width, factor ) resized_height, resized_width = smart_resize( original_height, original_width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels, ) else: resized_height, resized_width = resized_size img_tensor = TVF.resize( img_tensor, [resized_height, resized_width], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True, ) img_tensor = TVF.to_dtype(img_tensor, torch.float32, scale=True) img_tensor = TVF.normalize( img_tensor, list(image_mean), list(image_std), inplace=True ) return img_tensor.permute(1, 2, 0).unsqueeze(0) except Exception as exc: logger.warning("Error processing image: %s", exc) return None def calculate_vision_tokens( num_frames: int, height: int, width: int, patch_size: int, spatial_merge_size: int, temporal_patch_size: int, ) -> tuple[int, int, int]: t_patches = math.ceil(num_frames / temporal_patch_size) tokens_per_row = width // (patch_size * spatial_merge_size) num_rows = height // (patch_size * spatial_merge_size) total_tokens = t_patches * tokens_per_row * num_rows return total_tokens, tokens_per_row, num_rows def vision_to_patches( img: torch.Tensor, patch_size: int, temporal_patch_size: int, merge_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: """Convert ``(T, H, W, C)`` to Qwen-compatible flattened patches.""" T, H, W, C = img.shape ps = patch_size ts = temporal_patch_size if T % ts != 0: pad_t = ts - (T % ts) img = torch.cat([img, img[-1:].expand(pad_t, -1, -1, -1)], dim=0) T = img.shape[0] T_patches = T // ts H_patches = H // ps W_patches = W // ps patches = E.rearrange( img, "(t pt) (bh m ph) (bw n pw) c -> (t bh bw m n) (c pt ph pw)", pt=ts, ph=ps, pw=ps, m=merge_size, n=merge_size, ) grid_thw = torch.tensor([T_patches, H_patches, W_patches]) return patches, grid_thw def process_images( images: list[Any], config: RWKVVLImageProcessorConfig, ) -> RWKVVLProcessedImages: """Process a list of images that share one visual pixel budget.""" normalized_images = [normalize_image(image) for image in images] image_sizes = [] for image in normalized_images: height, width = get_image_size(image) if width == 0 or height == 0: raise ValueError("Image has zero width or height") ratio = max(width / height, height / width) if ratio > config.max_aspect_ratio: raise ValueError( f"Image aspect ratio {ratio:.1f} exceeds {config.max_aspect_ratio}" ) image_sizes.append((height, width)) resized_sizes = resize_sizes_to_total_pixels( image_sizes, factor=config.factor, min_pixels=config.min_pixels, max_pixels=config.max_pixels, ) processed_images = [] image_token_counts = [] patch_list = [] grid_list = [] for image, resized_size in zip(normalized_images, resized_sizes, strict=True): processed = process_image( image, patch_size=config.patch_size, merge_size=config.spatial_merge_size, min_pixels=config.min_pixels, max_pixels=config.max_pixels, image_mean=config.image_mean, image_std=config.image_std, resized_size=resized_size, ) if processed is None: raise ValueError("Could not process image") num_tokens, _, _ = calculate_vision_tokens( num_frames=processed.shape[0], height=processed.shape[1], width=processed.shape[2], patch_size=config.patch_size, spatial_merge_size=config.spatial_merge_size, temporal_patch_size=config.temporal_patch_size, ) patches, grid = vision_to_patches( processed, config.patch_size, config.temporal_patch_size, config.spatial_merge_size, ) processed_images.append(processed) image_token_counts.append(num_tokens) patch_list.append(patches) grid_list.append(grid) if patch_list: flat_patches = torch.cat(patch_list, dim=0) grid_thw = torch.stack(grid_list, dim=0) else: patch_dim = 3 * config.temporal_patch_size * config.patch_size**2 flat_patches = torch.empty(0, patch_dim, dtype=torch.float32) grid_thw = torch.empty(0, 3, dtype=torch.long) return RWKVVLProcessedImages( images=processed_images, image_token_counts=image_token_counts, grid_thw=grid_thw, flat_patches=flat_patches, ) def make_image_config_from_processor( image_processor: Any, *, max_aspect_ratio: float = 50.0, **overrides: Any, ) -> RWKVVLImageProcessorConfig: size = getattr(image_processor, "size", {}) or {} if hasattr(size, "to_dict"): size = size.to_dict() override_size = overrides.pop("size", None) or {} if hasattr(override_size, "to_dict"): override_size = override_size.to_dict() patch_size = int( overrides.pop("patch_size", None) or getattr(image_processor, "patch_size", 16) ) temporal_patch_size = int( overrides.pop("temporal_patch_size", None) or getattr(image_processor, "temporal_patch_size", 2) ) merge_size = int( overrides.pop("merge_size", None) or overrides.pop("spatial_merge_size", None) or getattr(image_processor, "merge_size", 2) ) min_pixels = int( overrides.pop("min_pixels", None) or overrides.pop("shortest_edge", None) or override_size.get("shortest_edge") or size.get("shortest_edge") or 65536 ) max_pixels = int( overrides.pop("max_pixels", None) or overrides.pop("longest_edge", None) or override_size.get("longest_edge") or size.get("longest_edge") or 2097152 ) image_mean = tuple( overrides.pop("image_mean", None) or getattr(image_processor, "image_mean", (0.5, 0.5, 0.5)) ) image_std = tuple( overrides.pop("image_std", None) or getattr(image_processor, "image_std", (0.5, 0.5, 0.5)) ) return RWKVVLImageProcessorConfig( patch_size=patch_size, temporal_patch_size=temporal_patch_size, spatial_merge_size=merge_size, min_pixels=min_pixels, max_pixels=max_pixels, image_mean=image_mean, image_std=image_std, max_aspect_ratio=float(overrides.pop("max_aspect_ratio", max_aspect_ratio)), )