| from typing import Any, Dict, List, Tuple |
| from PIL.Image import Image |
| from open_clip.factory import PreprocessCfg, image_transform_v2 |
| from transformers import ImageProcessingMixin |
| import torch |
|
|
| class _ValidKwargs: |
| pass |
|
|
class BiomedCLIPImageProcessor(ImageProcessingMixin):
    """Image processor that wraps open_clip's transform pipeline for BiomedCLIP.

    Builds the transform from a ``preprocess_cfg`` dict (the fields of
    open_clip's ``PreprocessCfg``) and turns one or more PIL images into a
    batched ``pixel_values`` tensor.
    """

    # Placeholder sentinel for transformers' kwargs-validation machinery.
    valid_kwargs = _ValidKwargs

    def __init__(self, preprocess_cfg: Dict[str, Any],
                 is_train: bool = False, **kwargs):
        """Store the transform configuration.

        Args:
            preprocess_cfg: keyword arguments for open_clip's ``PreprocessCfg``.
            is_train: whether to build the training-time (augmenting) transform.
            **kwargs: forwarded to ``ImageProcessingMixin.__init__``.
        """
        super().__init__(**kwargs)
        self.preprocess_cfg = preprocess_cfg
        self.is_train = is_train
        # Built lazily on first call so construction stays cheap; cached
        # afterwards (see __call__).
        self.preprocess = None

    def __call__(self, images, return_tensors="pt"):
        """Preprocess a single image or a sequence of images.

        Args:
            images: one PIL image, or a list/tuple of PIL images.
            return_tensors: only ``"pt"`` (PyTorch) is supported.

        Returns:
            dict with key ``"pixel_values"``: the per-image transform outputs
            stacked along a new leading batch dimension.

        Raises:
            ValueError: if ``return_tensors`` is anything other than ``"pt"``.
        """
        # Fail loudly instead of silently ignoring an unsupported format;
        # previously "np"/None callers got PyTorch tensors back unawares.
        if return_tensors != "pt":
            raise ValueError(
                f"BiomedCLIPImageProcessor only supports return_tensors='pt', "
                f"got {return_tensors!r}"
            )
        if not isinstance(images, (list, tuple)):
            images = [images]
        if self.preprocess is None:
            # Build and cache the transform on first use.
            self.preprocess = image_transform_v2(
                PreprocessCfg(**self.preprocess_cfg), is_train=self.is_train
            )
        pixels = torch.stack([self.preprocess(im) for im in images], dim=0)
        return {"pixel_values": pixels}
|
|