File size: 3,243 Bytes
5a8af3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# coding=utf-8
"""Image processor for the Molmo-v1 (CLIP vision) VLM.

Replicates Molmo's `hf_resize_and_center_crop` + OpenAI-CLIP normalization exactly
(PIL BICUBIC, shortest-edge resize, center crop, /255, normalize) so the resulting
(3, 224, 224) array is bit-identical to the Molmo preprocessor. The HF CLIPVisionModel
performs the Conv2d patchification internally, matching Molmo's un-patchify+Conv2d path.
"""

from typing import Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput


# Per-channel (R, G, B) normalization statistics of OpenAI CLIP. They are
# applied to pixel values that have already been scaled into [0, 1].
OPENAI_CLIP_MEAN = (
    0.48145466,
    0.4578275,
    0.40821073,
)
OPENAI_CLIP_STD = (
    0.26862954,
    0.26130258,
    0.27577711,
)


def _to_uint8_rgb(image) -> np.ndarray:
    if isinstance(image, PIL.Image.Image):
        return np.array(image.convert("RGB"))
    arr = np.asarray(image)
    if arr.dtype != np.uint8:
        # assume already in [0,255] if float
        arr = arr.astype(np.uint8)
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    return arr[:, :, :3]


def hf_resize_and_center_crop(image: np.ndarray, output_size) -> np.ndarray:
    """Resize so the image covers ``output_size``, then center-crop to it.

    Mirrors olmo/data/model_preprocessor.py:hf_resize_and_center_crop. The image
    is scaled uniformly with PIL BICUBIC by the smallest factor that makes both
    sides at least as large as the target. The central ``output_size`` window
    of the result is then kept.

    Args:
        image: uint8 image array. In this module it is the ``(H, W, 3)`` output
            of ``_to_uint8_rgb``.
        output_size: ``(desired_h, desired_w)`` in pixels.

    Returns:
        ``(desired_h, desired_w, 3)`` float32 array with values in [0, 1].

    Raises:
        ValueError: if the image has zero height or width.
    """
    desired_h, desired_w = output_size
    height, width = image.shape[:2]
    if height == 0 or width == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}")
    scale = max(desired_h / height, desired_w / width)
    # int() truncates, so float rounding in `size * scale` (a product such as
    # 223.999... for a 224 target) can land one pixel short. PIL's crop would
    # then pad the missing row/column with black. The clamp only changes those
    # cases; all other outputs are identical to the unclamped formula.
    new_h = max(desired_h, int(height * scale))
    new_w = max(desired_w, int(width * scale))

    pil_image = PIL.Image.fromarray(image)
    pil_image = pil_image.resize((new_w, new_h), PIL.Image.BICUBIC)

    top = (new_h - desired_h) // 2
    left = (new_w - desired_w) // 2
    pil_image = pil_image.crop((left, top, left + desired_w, top + desired_h))

    return np.array(pil_image).astype(np.float32) / 255.0


class MolmoOlmo3ImageProcessor(BaseImageProcessor):
    """Produce CLIP-normalized ``pixel_values`` for the Molmo-v1 vision tower.

    Each image goes through the following steps:

    1. Convert to uint8 RGB.
    2. Resize and center-crop to ``image_size x image_size``, scaled to [0, 1].
    3. Normalize per channel with ``image_mean`` / ``image_std``.
    4. Return channels-first.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int = 224,
        image_mean=OPENAI_CLIP_MEAN,
        image_std=OPENAI_CLIP_STD,
        **kwargs,
    ):
        """Store the target square size and the per-channel normalization stats.

        Args:
            image_size: Side length of the square output crop, in pixels.
            image_mean: Per-channel (R, G, B) mean subtracted from [0, 1] pixels.
            image_std: Per-channel (R, G, B) std that pixels are divided by.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self.image_size = image_size
        # Copied into fresh lists so callers' sequences are never aliased.
        self.image_mean = list(image_mean)
        self.image_std = list(image_std)

    def preprocess_one(self, image) -> np.ndarray:
        """Preprocess a single image into a ``(3, image_size, image_size)`` float32 array."""
        arr = _to_uint8_rgb(image)
        # (H, W, 3) float32 in [0, 1]
        pixels = hf_resize_and_center_crop(arr, (self.image_size, self.image_size))
        mean = np.asarray(self.image_mean, dtype=np.float32)
        std = np.asarray(self.image_std, dtype=np.float32)
        # A (3,) vector broadcasts over the trailing channel axis of (H, W, 3).
        pixels = (pixels - mean) / std
        # HWC -> CHW
        return np.transpose(pixels, (2, 0, 1))

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """Preprocess one image or a list/tuple of images into a batch.

        Args:
            images: A single image or a list/tuple of images.
            return_tensors: Tensor type for ``pixel_values``. Examples are
                ``"pt"`` (default) and ``"np"``; ``None`` keeps numpy. The value
                is handed to ``BatchFeature``, which performs the conversion.
            **kwargs: Accepted for ``BaseImageProcessor`` API compatibility;
                ignored.

        Returns:
            ``BatchFeature`` with ``pixel_values`` of shape
            ``(n, 3, image_size, image_size)``.

        Raises:
            ValueError: if ``images`` is an empty list/tuple.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            raise ValueError("`images` must contain at least one image.")
        pixel_values = np.stack([self.preprocess_one(im) for im in images], axis=0)
        # Let BatchFeature convert, so unsupported tensor types fail loudly
        # instead of silently falling back to numpy.
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

# Public API for `from ... import *`: only the image processor class. The
# helper functions and CLIP constants remain importable by explicit name.
__all__ = [
    "MolmoOlmo3ImageProcessor",
]