# molmo-dinov3-b16-olmo3 / image_processing_molmo_olmo3.py
# Uploaded by: amitha
# Upload folder using huggingface_hub
# Commit: 5a8af3b (verified)
# coding=utf-8
"""Image processor for the Molmo-v1 (CLIP vision) VLM.
Replicates Molmo's `hf_resize_and_center_crop` + OpenAI-CLIP normalization exactly
(PIL BICUBIC, shortest-edge resize, center crop, /255, normalize) so the resulting
(3, image_size, image_size) array ((3, 224, 224) by default) is bit-identical to the Molmo preprocessor. The HF CLIPVisionModel
performs the Conv2d patchification internally, matching Molmo's un-patchify+Conv2d path.
"""
from typing import Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput
# Per-channel normalization statistics of OpenAI CLIP, in (R, G, B) order.
# They are applied to pixel values already scaled to [0, 1].
OPENAI_CLIP_MEAN = (
    0.48145466,  # R
    0.4578275,  # G
    0.40821073,  # B
)
OPENAI_CLIP_STD = (
    0.26862954,  # R
    0.26130258,  # G
    0.27577711,  # B
)
def _to_uint8_rgb(image) -> np.ndarray:
if isinstance(image, PIL.Image.Image):
return np.array(image.convert("RGB"))
arr = np.asarray(image)
if arr.dtype != np.uint8:
# assume already in [0,255] if float
arr = arr.astype(np.uint8)
if arr.ndim == 2:
arr = np.stack([arr] * 3, axis=-1)
return arr[:, :, :3]
def hf_resize_and_center_crop(image: np.ndarray, output_size) -> np.ndarray:
    """Resize so the image covers ``output_size``, center-crop, and scale to [0, 1].

    Intended to mirror ``olmo/data/model_preprocessor.py:hf_resize_and_center_crop``
    exactly.

    Args:
        image: uint8 array accepted by ``PIL.Image.fromarray``, e.g. (H, W, 3) RGB.
        output_size: ``(height, width)`` of the final crop.

    Returns:
        float32 array of the cropped image with values divided by 255.

    NOTE(review): ``int()`` truncates. When ``src * factor`` rounds to a hair
    below the target (e.g. 223.99999... for a target of 224), the resized side
    comes out one pixel short. The crop box then extends past the image, and
    PIL fills that strip with zeros. This is kept as-is to stay identical to the
    reference implementation; confirm whether it should clamp to the target.
    """
    out_h, out_w = output_size
    src_h, src_w = image.shape[:2]

    # A single scale factor for both axes, chosen so that neither side ends up
    # smaller than the target (the shorter edge fits, the longer one overflows).
    factor = max(out_h / src_h, out_w / src_w)
    scaled_w = int(src_w * factor)
    scaled_h = int(src_h * factor)

    # Top-left corner of a box of the target size centered in the resized image.
    offset_x = (scaled_w - out_w) // 2
    offset_y = (scaled_h - out_h) // 2

    cropped = (
        PIL.Image.fromarray(image)
        .resize((scaled_w, scaled_h), PIL.Image.BICUBIC)
        .crop((offset_x, offset_y, offset_x + out_w, offset_y + out_h))
    )
    return np.array(cropped, dtype=np.float32) / 255.0
class MolmoOlmo3ImageProcessor(BaseImageProcessor):
    """Turn images into normalized ``pixel_values`` for the Molmo-Olmo3 VLM.

    Each image is converted to uint8 RGB. It is then resized with PIL bicubic
    resampling so it covers an ``image_size`` x ``image_size`` box, and
    center-cropped to that box. Finally it is scaled to [0, 1], normalized per
    channel with ``image_mean`` / ``image_std``, and returned channel-first.

    Args:
        image_size: Side length of the square output crop.
        image_mean: Per-channel mean, subtracted after scaling to [0, 1].
            Defaults to the OpenAI-CLIP mean.
        image_std: Per-channel std, divided out after mean subtraction.
            Defaults to the OpenAI-CLIP std.
        **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int = 224,
        image_mean=OPENAI_CLIP_MEAN,
        image_std=OPENAI_CLIP_STD,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        # Copy into fresh lists so caller-supplied sequences are never aliased.
        self.image_mean = list(image_mean)
        self.image_std = list(image_std)

    def preprocess_one(self, image) -> np.ndarray:
        """Preprocess a single image.

        Args:
            image: A PIL image or array-like accepted by ``_to_uint8_rgb``.

        Returns:
            float32 array of shape (3, image_size, image_size).
        """
        arr = _to_uint8_rgb(image)
        # (H, W, 3) float32 in [0, 1].
        resized = hf_resize_and_center_crop(arr, (self.image_size, self.image_size))
        mean = np.array(self.image_mean, dtype=np.float32)[None, None, :]
        std = np.array(self.image_std, dtype=np.float32)[None, None, :]
        normalized = (resized - mean) / std
        # HWC -> CHW
        return np.transpose(normalized, (2, 0, 1))

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images into ``pixel_values``.

        Args:
            images: A single image, a list/tuple of images, or a batched
                (N, H, W, C) NumPy array.
            return_tensors: ``"pt"`` (default) returns a torch tensor that
                shares memory with the stacked NumPy array. ``None`` or ``"np"``
                returns NumPy. Any other value (e.g. ``"tf"``, ``"jax"``) is
                passed to ``BatchFeature`` for conversion.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            ``BatchFeature`` with ``pixel_values`` of shape
            (n, 3, image_size, image_size).

        Raises:
            ValueError: If no images are given, or if ``return_tensors`` is
                not a tensor type that ``BatchFeature`` recognizes.
        """
        if isinstance(images, np.ndarray) and images.ndim == 4:
            # Batched channels-last array: split into individual images.
            images = list(images)
        elif not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            raise ValueError("preprocess() received an empty list of images.")

        # (n, 3, H, W)
        pixel_values = np.stack([self.preprocess_one(im) for im in images], axis=0)

        if return_tensors == "pt":
            # from_numpy avoids an extra copy of the freshly stacked array.
            return BatchFeature(
                data={"pixel_values": torch.from_numpy(pixel_values)}, tensor_type=None
            )
        # None keeps NumPy; other names are converted (or rejected) by BatchFeature
        # instead of being silently ignored.
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
# Names exported by `from <module> import *`.
__all__ = [
    "MolmoOlmo3ImageProcessor",
]