# schp-lip-20 / image_processing_schp.py
# Author: pirocheto — initial commit (a0415db)
"""
SCHPImageProcessor — preprocessing for SCHPForSemanticSegmentation.
Resizes images to the model's expected input size and normalises with the
SCHP BGR-indexed mean/std convention (channels are RGB in the tensor but
the normalisation constants come from a BGR-trained ResNet-101).
"""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from transformers import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature
class SCHPImageProcessor(BaseImageProcessor):
    """
    Image processor for SCHP (Self-Correction Human Parsing).

    Resizes inputs to the model's expected resolution and normalises them
    with the SCHP convention: the tensor channels are in RGB order, but the
    per-channel mean/std constants are the BGR-indexed values the original
    ResNet-101 backbone was trained with.

    Args:
        size (`dict`, *optional*, defaults to ``{"height": 473, "width": 473}``):
            Resize target. The LIP model was trained at 473×473.
        image_mean (`list[float]`, *optional*):
            Per-channel mean in **RGB channel order** using BGR-indexed values:
            ``[0.406, 0.456, 0.485]``.
        image_std (`list[float]`, *optional*):
            Per-channel std in **RGB channel order** using BGR-indexed values:
            ``[0.225, 0.224, 0.229]``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: Optional[Dict[str, int]] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Explicit `is None` checks (rather than `or`) so that falsy-but-valid
        # arguments are never silently replaced by the defaults.
        self.size = size if size is not None else {"height": 473, "width": 473}
        # BGR-indexed normalisation constants used during SCHP training.
        self.image_mean = image_mean if image_mean is not None else [0.406, 0.456, 0.485]
        self.image_std = image_std if image_std is not None else [0.225, 0.224, 0.229]

    def _to_pil_rgb(self, img) -> Image.Image:
        """Convert a single tensor/ndarray/PIL input to a PIL image in RGB mode.

        Raises:
            TypeError: If ``img`` is not a supported image type.
        """
        if isinstance(img, torch.Tensor):
            # Expected layout: (C, H, W) float tensor in [0, 1] — to_pil_image
            # scales floats by 255. Move to CPU first in case it lives on GPU.
            pil = TF.to_pil_image(img.cpu())
        elif isinstance(img, np.ndarray):
            # NOTE(review): the uint8 cast truncates float arrays in [0, 1]
            # to {0, 1}; callers are assumed to pass uint8 [0, 255] data —
            # confirm against actual callers.
            pil = Image.fromarray(np.asarray(img, dtype=np.uint8))
        elif isinstance(img, Image.Image):
            pil = img
        else:
            # `raise`, not `assert`: asserts are stripped under `python -O`,
            # which would let unsupported types fail later with a cryptic error.
            raise TypeError(
                f"Unsupported image type: {type(img)}. "
                "Expected PIL.Image.Image, np.ndarray, or torch.Tensor."
            )
        return pil.convert("RGB")

    def preprocess(
        self,
        images: Union[
            Image.Image,
            np.ndarray,
            torch.Tensor,
            List[Union[Image.Image, np.ndarray, torch.Tensor]],
        ],
        return_tensors: Optional[str] = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Pre-process one or more images.

        Args:
            images: A single image or a list of images (PIL, uint8 ndarray,
                or (C, H, W) float tensor in [0, 1]).
            return_tensors: Output tensor framework (defaults to ``"pt"``).

        Returns:
            :class:`BatchFeature` with a ``pixel_values`` key of shape
            ``(batch, 3, H, W)`` as a ``torch.Tensor``
            (when ``return_tensors="pt"``).

        Raises:
            ValueError: If ``images`` is an empty list.
            TypeError: If an element of ``images`` has an unsupported type.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if not images:
            # Fail with a clear message instead of torch.stack's opaque
            # "stack expects a non-empty TensorList" RuntimeError.
            raise ValueError("`images` must contain at least one image.")

        target_h = self.size["height"]
        target_w = self.size["width"]
        mean = self.image_mean
        std = self.image_std

        tensors = []
        for img in images:
            pil = self._to_pil_rgb(img)
            # PIL's resize takes (width, height), not (height, width).
            pil = pil.resize((target_w, target_h), resample=Image.Resampling.BILINEAR)
            t = TF.to_tensor(pil)  # float32 in [0, 1], shape (3, H, W)
            tensors.append(TF.normalize(t, mean=mean, std=std))

        pixel_values = torch.stack(tensors)  # (B, 3, H, W)
        return BatchFeature({"pixel_values": pixel_values}, tensor_type=return_tensors)