File size: 3,373 Bytes
e97480b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
SCHPImageProcessor — preprocessing for SCHPForSemanticSegmentation.

Resizes images to the model's expected input size and normalises with the
SCHP BGR-indexed mean/std convention (channels are RGB in the tensor but
the normalisation constants come from a BGR-trained ResNet-101).
"""

from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from transformers import BaseImageProcessor
from transformers.image_processing_utils import BatchFeature


class SCHPImageProcessor(BaseImageProcessor):
    """
    Image processor for SCHP (Self-Correction Human Parsing).

    Each image is converted to RGB, resized to exactly
    ``(size["height"], size["width"])`` with bilinear resampling (the aspect
    ratio is *not* preserved), converted to a float32 tensor in ``[0, 1]`` and
    normalised per channel as ``(x - mean) / std``.

    Args:
        size (`dict`, *optional*, defaults to ``{"height": 512, "width": 512}``):
            Exact output size as ``{"height": H, "width": W}``. The model was
            trained at 512×512.
        image_mean (`list[float]`, *optional*, defaults to ``[0.406, 0.456, 0.485]``):
            Per-channel mean applied to the tensor's RGB channels. These are
            the standard ImageNet means in reversed (BGR) order, i.e. the
            BGR-indexed constants used during SCHP training.
        image_std (`list[float]`, *optional*, defaults to ``[0.225, 0.224, 0.229]``):
            Per-channel std applied to the tensor's RGB channels; likewise the
            standard ImageNet stds in reversed (BGR) order.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: Optional[Dict[str, int]] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.size = size or {"height": 512, "width": 512}
        # BGR-indexed normalisation constants used during SCHP training
        self.image_mean = image_mean or [0.406, 0.456, 0.485]
        self.image_std = image_std or [0.225, 0.224, 0.229]

    @staticmethod
    def _ndarray_to_uint8(arr: np.ndarray) -> np.ndarray:
        """Convert an ``(H, W)`` / ``(H, W, C)`` array to uint8 for ``Image.fromarray``."""
        if arr.ndim == 3 and arr.shape[-1] == 1:
            # PIL cannot build an image from a trailing singleton channel.
            arr = arr[..., 0]
        if arr.dtype == np.uint8:
            return arr
        if np.issubdtype(arr.dtype, np.floating):
            # Float images whose values all lie in [0, 1] follow the same
            # convention as float tensors and are scaled to [0, 255]. A plain
            # uint8 cast would truncate them all to 0 (a black image).
            if arr.size and arr.min() >= 0.0 and arr.max() <= 1.0:
                arr = arr * 255.0
            # Clip so out-of-range values saturate instead of wrapping mod 256.
            return np.clip(arr, 0, 255).astype(np.uint8)
        if np.issubdtype(arr.dtype, np.integer):
            return np.clip(arr, 0, 255).astype(np.uint8)
        # Any other dtype (e.g. bool) keeps the original direct cast.
        return np.asarray(arr, dtype=np.uint8)

    @staticmethod
    def _to_pil_rgb(img: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
        """Convert one supported image object into an RGB ``PIL.Image``."""
        if isinstance(img, torch.Tensor):
            # (C, H, W) tensor; to_pil_image scales floating tensors by 255,
            # so float input is expected to be in [0, 1].
            pil = TF.to_pil_image(img.detach().cpu())
        elif isinstance(img, np.ndarray):
            pil = Image.fromarray(SCHPImageProcessor._ndarray_to_uint8(img))
        elif isinstance(img, Image.Image):
            pil = img
        else:
            raise TypeError(
                f"Unsupported image type {type(img).__name__!r}; expected "
                "PIL.Image.Image, np.ndarray or torch.Tensor."
            )
        return pil.convert("RGB")

    def preprocess(
        self,
        images: Union[
            Image.Image,
            np.ndarray,
            torch.Tensor,
            List[Union[Image.Image, np.ndarray, torch.Tensor]],
        ],
        return_tensors: Optional[str] = "pt",
        size: Optional[Dict[str, int]] = None,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Pre-process one or more images.

        Args:
            images: A single image or a list/tuple of images. Supported types
                are ``PIL.Image.Image``, ``np.ndarray`` of shape ``(H, W)`` or
                ``(H, W, C)``, and ``torch.Tensor`` of shape ``(C, H, W)``.
                Floating-point tensors are assumed to be in ``[0, 1]``.
                Floating-point arrays whose values all lie in ``[0, 1]`` are
                scaled to ``[0, 255]``. Other array values are clipped to
                ``[0, 255]``.
            return_tensors: Tensor type of the returned ``BatchFeature``
                (``"pt"`` by default).
            size: Optional per-call override of ``self.size``.
            image_mean: Optional per-call override of ``self.image_mean``.
            image_std: Optional per-call override of ``self.image_std``.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            A :class:`BatchFeature` with a ``pixel_values`` key of shape
            ``(batch, 3, H, W)``. It is a ``torch.Tensor`` when
            ``return_tensors="pt"``.

        Raises:
            ValueError: If ``images`` is an empty list/tuple.
            TypeError: If an image is not one of the supported types.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if len(images) == 0:
            raise ValueError("`images` must contain at least one image.")

        size = size if size is not None else self.size
        h, w = size["height"], size["width"]
        mean = image_mean if image_mean is not None else self.image_mean
        std = image_std if image_std is not None else self.image_std

        tensors = []
        for img in images:
            pil = self._to_pil_rgb(img)
            # PIL expects (width, height); the aspect ratio is not preserved.
            pil = pil.resize((w, h), resample=Image.Resampling.BILINEAR)
            t = TF.to_tensor(pil)  # float32 in [0, 1], shape (3, H, W)
            tensors.append(TF.normalize(t, mean=mean, std=std))

        pixel_values = torch.stack(tensors)  # (B, 3, H, W)
        return BatchFeature({"pixel_values": pixel_values}, tensor_type=return_tensors)