"""
SCHPImageProcessor — preprocessing for SCHPForSemanticSegmentation.
Resizes images to the model's expected input size and normalises with the
SCHP BGR-indexed mean/std convention (channels are RGB in the tensor but
the normalisation constants come from a BGR-trained ResNet-101).
"""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from transformers import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
class SCHPImageProcessor(BaseImageProcessor):
    """
    Image processor for SCHP (Self-Correction Human Parsing).

    Args:
        size (`dict`, *optional*, defaults to ``{"height": 512, "width": 512}``):
            Exact resize target: every image is resized to
            ``(width, height)`` regardless of aspect ratio. The model was
            trained at 512×512.
        image_mean (`list[float]`):
            Per-channel mean in **RGB channel order** using BGR-indexed values:
            ``[0.406, 0.456, 0.485]``.
        image_std (`list[float]`):
            Per-channel std in **RGB channel order** using BGR-indexed values:
            ``[0.225, 0.224, 0.229]``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: Optional[Dict[str, int]] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # `or` fallbacks preserve the original behaviour: passing an empty
        # dict/list also selects the defaults.
        self.size = size or {"height": 512, "width": 512}
        # BGR-indexed normalisation constants used during SCHP training.
        self.image_mean = image_mean or [0.406, 0.456, 0.485]
        self.image_std = image_std or [0.225, 0.224, 0.229]

    def preprocess(
        self,
        images: Union[
            Image.Image,
            np.ndarray,
            torch.Tensor,
            List[Union[Image.Image, np.ndarray, torch.Tensor]],
        ],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Pre-process one or more images.

        Args:
            images: A single image or list of images. Accepted forms: PIL
                image, ``np.ndarray`` (integer arrays are taken as 0–255;
                float arrays are assumed to be in [0, 1] and rescaled —
                TODO confirm this range against callers), or a float
                ``torch.Tensor`` of shape ``(C, H, W)`` with values in [0, 1].
            return_tensors: Tensor type forwarded to :class:`BatchFeature`;
                ``"pt"`` (the default) yields ``torch.Tensor`` pixel values.

        Returns:
            :class:`BatchFeature` with a ``pixel_values`` key of shape
            ``(batch, 3, H, W)``.

        Raises:
            TypeError: If an element of ``images`` is not a supported type.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        h = self.size["height"]
        w = self.size["width"]
        mean = self.image_mean
        std = self.image_std

        tensors = []
        for img in images:
            # --- normalise input type to PIL RGB ---
            if isinstance(img, torch.Tensor):
                # (C, H, W) float tensor in [0, 1]
                pil = TF.to_pil_image(img.cpu())
            elif isinstance(img, np.ndarray):
                arr = np.asarray(img)
                if np.issubdtype(arr.dtype, np.floating):
                    # A bare uint8 cast would truncate [0, 1] floats to
                    # {0, 1}, destroying the image — rescale first.
                    arr = (arr * 255.0).round().clip(0, 255)
                pil = Image.fromarray(arr.astype(np.uint8))
            elif isinstance(img, Image.Image):
                pil = img
            else:
                # Explicit raise instead of `assert`: asserts are stripped
                # under `python -O` and give a less helpful message.
                raise TypeError(f"Unsupported image type: {type(img).__name__}")
            pil = pil.convert("RGB")

            # --- resize → tensor → normalise ---
            # PIL takes (width, height); this is a fixed-size resize, not a
            # shorter-edge resize.
            pil = pil.resize((w, h), resample=Image.Resampling.BILINEAR)
            t = TF.to_tensor(pil)  # float32 in [0, 1], shape (3, H, W)
            t = TF.normalize(t, mean=mean, std=std)
            tensors.append(t)

        pixel_values = torch.stack(tensors)  # (B, 3, H, W)
        return BatchFeature({"pixel_values": pixel_values}, tensor_type=return_tensors)