BiomedCLIPv5 / image_processing_biomedclip.py
michel-ducartier's picture
Upload processor
691ec83 verified
raw
history blame contribute delete
990 Bytes
from typing import Any, Dict, List, Tuple
from PIL.Image import Image
from open_clip.factory import PreprocessCfg, image_transform_v2
from transformers import ImageProcessingMixin
import torch
class _ValidKwargs:
    """Empty placeholder assigned to ``BiomedCLIPImageProcessor.valid_kwargs``.

    It declares no attributes and no annotations. It presumably tells the
    transformers processor machinery that this image processor accepts no
    extra per-call keyword arguments — NOTE(review): confirm against the
    transformers version in use.
    """
class BiomedCLIPImageProcessor(ImageProcessingMixin):
    """Hugging Face image processor wrapping an open_clip preprocessing pipeline.

    The transform is built lazily on the first ``__call__`` via
    ``open_clip.factory.image_transform_v2(PreprocessCfg(**preprocess_cfg),
    is_train=is_train)`` and then cached in ``self.preprocess``. Changing
    ``preprocess_cfg`` or ``is_train`` after that first call has no effect
    until ``self.preprocess`` is reset to ``None``.

    Args:
        preprocess_cfg: Keyword arguments used to build
            ``open_clip.factory.PreprocessCfg``.
        is_train: Forwarded to ``image_transform_v2`` (selects the training
            or the evaluation transform).
        **kwargs: Forwarded unchanged to ``ImageProcessingMixin.__init__``.
    """

    # NOTE(review): presumably advertises "no processor-specific call kwargs"
    # to transformers' processor machinery — confirm for the pinned version.
    valid_kwargs = _ValidKwargs

    def __init__(self, preprocess_cfg: Dict[str, Any],
                 is_train: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.preprocess_cfg = preprocess_cfg
        self.is_train = is_train
        # Cached transform, built on first use (avoids a hasattr check).
        # It is not JSON-serializable, so to_dict() below omits it.
        self.preprocess = None

    def __call__(self, images, return_tensors="pt"):
        """Preprocess one image or a batch of images.

        Args:
            images: A single image or a list/tuple of images. Each one is
                passed to the open_clip transform (presumably PIL images —
                the module imports ``PIL.Image``; confirm with callers).
            return_tensors: ``"pt"`` or ``None`` returns a ``torch.Tensor``;
                ``"np"`` returns a NumPy array.

        Returns:
            ``{"pixel_values": batch}``. ``batch`` holds the per-image
            transform outputs, stacked along a new leading dimension.

        Raises:
            ValueError: If ``images`` is an empty list/tuple, or if
                ``return_tensors`` is not one of ``"pt"``, ``"np"``, ``None``.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if not images:
            # torch.stack would otherwise fail with an opaque RuntimeError.
            raise ValueError("BiomedCLIPImageProcessor received no images.")
        # Membership test uses ==, so str-based enums such as
        # transformers.TensorType also match.
        if return_tensors not in (None, "pt", "np"):
            raise ValueError(
                f"Unsupported return_tensors={return_tensors!r}; "
                "expected 'pt', 'np' or None."
            )
        if self.preprocess is None:
            self.preprocess = image_transform_v2(
                PreprocessCfg(**self.preprocess_cfg), is_train=self.is_train
            )
        pixels = torch.stack([self.preprocess(im) for im in images], dim=0)
        if return_tensors == "np":
            pixels = pixels.numpy()
        return {"pixel_values": pixels}

    def to_dict(self) -> Dict[str, Any]:
        """Return the serializable config, without the cached transform.

        ``ImageProcessingMixin.to_dict`` copies the whole instance
        ``__dict__``. The cached transform object would then break JSON
        serialization (``to_json_string``, ``__repr__``, ``save_pretrained``)
        once ``__call__`` has run, so it is dropped here.
        """
        output = super().to_dict()
        output.pop("preprocess", None)
        return output