| | from typing import Any, Dict, List, Optional, Union
|
| |
|
| | import numpy as np
|
| | import torch
|
| | from PIL import Image
|
| |
|
| | from transformers import ImageProcessingMixin
|
| |
|
| |
|
def _to_rgb(img: Image.Image) -> Image.Image:
    """Return ``img`` in RGB mode.

    An image that is already RGB is handed back unchanged (the very same
    object, no copy). Any other mode goes through ``Image.convert("RGB")``,
    which produces a new image.
    """
    return img if img.mode == "RGB" else img.convert("RGB")
|
| |
|
| |
|
class UpscalerImageProcessor(ImageProcessingMixin):
    """Minimal image processor for a super-resolution (upscaler) model.

    - input: a single PIL image or a list of PIL images
    - output: ``{"pixel_values": tensor}`` where the tensor is float32 with
      values in [0, 1] and shape (B, 3, H, W)

    Pixels are only rescaled by 1/255. No mean/std (e.g. ImageNet)
    normalization is applied, which suits SR models trained on [0, 1] inputs.
    All images in one call must share the same width and height, because
    they are stacked into a single batch tensor.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _pil_to_tensor_01(self, img: Image.Image) -> torch.FloatTensor:
        """Convert one PIL image to a float32 (3, H, W) tensor with values in [0, 1]."""
        img = _to_rgb(img)
        # The RGB image becomes an (H, W, 3) array; dividing by 255 maps it to [0, 1].
        arr = np.array(img, dtype=np.float32) / 255.0
        # HWC -> CHW; .contiguous() materializes the permuted view as a dense tensor.
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    def __call__(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Preprocess one or more images into a batched ``pixel_values`` tensor.

        Args:
            images: A PIL image, or a list of PIL images that all have the
                same size.
            return_tensors: ``None`` or ``"pt"``. Both return a PyTorch tensor.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``{"pixel_values": FloatTensor}``. The tensor has shape
            (B, 3, H, W) and values in [0, 1].

        Raises:
            ValueError: If ``return_tensors`` is unsupported, ``images`` is
                empty, or the images do not all share the same size.
            TypeError: If an element of ``images`` is not a PIL image.
        """
        # Validate the cheap argument first, before any per-image conversion work.
        if return_tensors is not None and return_tensors != "pt":
            raise ValueError("Only return_tensors=None or 'pt' is supported.")

        if isinstance(images, Image.Image):
            images = [images]
        images = list(images)
        if not images:
            # torch.stack([]) would otherwise fail with an opaque RuntimeError.
            raise ValueError("`images` must contain at least one image.")

        for idx, im in enumerate(images):
            if not isinstance(im, Image.Image):
                raise TypeError(
                    f"Expected PIL.Image.Image inputs, got {type(im).__name__} at index {idx}."
                )

        # torch.stack requires identical shapes; report the offending sizes
        # explicitly instead of surfacing torch's generic stack error.
        sizes = {im.size for im in images}
        if len(sizes) > 1:
            raise ValueError(
                "All images in a batch must have the same size (width, height); "
                f"got {sorted(sizes)}."
            )

        tensors = [self._pil_to_tensor_01(im) for im in images]
        pixel_values = torch.stack(tensors, dim=0)
        return {"pixel_values": pixel_values}