pirocheto committed
Commit e97480b · 0 Parent(s)

feat: initial release — Pascal Person Part 7-class SCHP model
.gitattributes ADDED
@@ -0,0 +1,5 @@
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+
+ # Temporary files from ONNX quantization pre-processing
+ onnx/*-preprocessed.onnx
+ onnx/*-preprocessed.onnx.data
+ onnx/*.data
+
+ # Keep named ONNX files
+ !onnx/schp-pascal-7.onnx
+ !onnx/schp-pascal-7.onnx.data
+ !onnx/schp-pascal-7-int8-static.onnx
+ !onnx/schp-pascal-7-int8-dynamic.onnx
README.md ADDED
@@ -0,0 +1,128 @@
+ ---
+ language: en
+ license: mit
+ tags:
+ - vision
+ - image-segmentation
+ - semantic-segmentation
+ - human-parsing
+ - body-parts
+ - pytorch
+ - onnx
+ datasets:
+ - pascal-person-part
+ pipeline_tag: image-segmentation
+ ---
+
+ # SCHP — Self-Correction Human Parsing (Pascal Person Part, 7 classes)
+
+ **SCHP** (Self-Correction for Human Parsing) is a state-of-the-art human parsing model based on a ResNet-101 backbone.
+ This checkpoint is trained on the **Pascal Person Part** dataset and packaged for the 🤗 Transformers `AutoModelForSemanticSegmentation` API.
+
+ > Original repository: [PeikeLi/Self-Correction-Human-Parsing](https://github.com/PeikeLi/Self-Correction-Human-Parsing)
+
+ **Use cases:**
+ - 🏃 **Body part segmentation** — segment coarse body regions (head, torso, arms, legs) for pose-aware applications
+ - 🎮 **Avatar rigging** — generate body part masks as a preprocessing step for AR/VR avatars
+ - 🏥 **Medical / ergonomics** — coarse body region detection for posture analysis or wearable device placement
+ - 📐 **Body proportion estimation** — measure relative areas of body segments in 2D images
+
+ ## Dataset — Pascal Person Part
+
+ Pascal Person Part is a single-person human parsing dataset with 3,000+ images focused on **body part segmentation**.
+
+ - **mIoU on Pascal Person Part validation: 71.46%**
+ - 7 coarse labels covering body regions
+
+ ## Labels
+
+ | ID | Label |
+ |----|-------|
+ | 0 | Background |
+ | 1 | Head |
+ | 2 | Torso |
+ | 3 | Upper Arms |
+ | 4 | Lower Arms |
+ | 5 | Upper Legs |
+ | 6 | Lower Legs |
+
+ ## Usage — PyTorch
+
+ ```python
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
+ from PIL import Image
+ import torch
+
+ model = AutoModelForSemanticSegmentation.from_pretrained("pirocheto/schp-pascal-7", trust_remote_code=True)
+ processor = AutoImageProcessor.from_pretrained("pirocheto/schp-pascal-7", trust_remote_code=True)
+
+ image = Image.open("photo.jpg").convert("RGB")
+ inputs = processor(images=image, return_tensors="pt")
+
+ with torch.no_grad():
+     outputs = model(**inputs)
+
+ # outputs.logits — (1, 7, 512, 512) final fused logits
+ # outputs.parsing_logits — (1, 7, 512, 512) decoder-branch logits before fusion
+ # outputs.edge_logits — (1, 2, 512, 512) edge prediction logits
+ seg_map = outputs.logits.argmax(dim=1).squeeze().numpy()  # (H, W), values in [0, 6]
+ ```
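The `argmax(dim=1)` step above simply keeps, for every pixel, the index of the class with the highest logit. A minimal NumPy sketch of that reduction on made-up toy logits (3 classes instead of 7, 2×2 instead of 512×512):

```python
import numpy as np

# toy logits standing in for outputs.logits: shape (batch=1, classes=3, 2, 2)
logits = np.array([[[[0.1, 2.0], [0.3, 0.2]],
                    [[1.5, 0.1], [0.2, 0.9]],
                    [[0.2, 0.4], [2.5, 0.1]]]])

# per pixel, keep the index of the largest value along the class axis
seg_map = logits.argmax(axis=1).squeeze()  # shape (2, 2)
```

Here `seg_map` comes out as `[[1, 0], [2, 1]]` — e.g. the top-left pixel gets class 1 because its class-1 logit (1.5) beats the others.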
+
+ Each pixel in `seg_map` is a label ID. To map IDs back to names:
+
+ ```python
+ id2label = model.config.id2label
+ print(id2label[1])  # → "Head"
+ ```
+
+ ## Usage — ONNX Runtime
+
+ Optimized ONNX files are available in the `onnx/` folder of this repo:
+
+ | File | Size | Notes |
+ |------|------|-------|
+ | `onnx/schp-pascal-7.onnx` + `.onnx.data` | ~257 MB | FP32, dynamic batch |
+ | `onnx/schp-pascal-7-int8-static.onnx` | ~66 MB | INT8 static, 99.77% pixel agreement |
+
+ ```python
+ import onnxruntime as ort
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoImageProcessor
+ from PIL import Image
+
+ model_path = hf_hub_download("pirocheto/schp-pascal-7", "onnx/schp-pascal-7-int8-static.onnx")
+ processor = AutoImageProcessor.from_pretrained("pirocheto/schp-pascal-7", trust_remote_code=True)
+
+ sess_opts = ort.SessionOptions()
+ sess_opts.intra_op_num_threads = 8
+ sess = ort.InferenceSession(model_path, sess_opts, providers=["CPUExecutionProvider"])
+
+ image = Image.open("photo.jpg").convert("RGB")
+ inputs = processor(images=image, return_tensors="np")
+ logits = sess.run(["logits"], {"pixel_values": inputs["pixel_values"]})[0]
+ seg_map = logits.argmax(axis=1).squeeze()  # (H, W)
+ ```
+
+ ## Performance
+
+ Benchmarked on a 16-core CPU with `intra_op_num_threads=8`:
+
+ | Backend | Latency | Speedup | Size |
+ |---------|---------|---------|------|
+ | PyTorch FP32 | ~424 ms | 1× | 255 MB |
+ | ONNX FP32 | ~296 ms | 1.43× | 256 MB |
+ | ONNX INT8 static | ~218 ms | **1.94×** | **66 MB** |
+
+ INT8 static quantization achieves **99.77% pixel-level agreement** with the FP32 model.
+
+ ## Model Details
+
+ | Property | Value |
+ |----------|-------|
+ | Architecture | ResNet-101 + SCHP self-correction |
+ | Input size | 512 × 512 |
+ | Output | 3 heads: logits, parsing_logits, edge_logits |
+ | num_labels | 7 |
+ | Dataset | Pascal Person Part |
+ | Original mIoU | 71.46% |
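The body-proportion use case listed in the README reduces to counting pixels per label in `seg_map`. A sketch with a synthetic 4×4 map (made up for illustration; a real `seg_map` would come from the model):

```python
import numpy as np

# label map for the 7 Pascal Person Part classes
id2label = {0: "Background", 1: "Head", 2: "Torso", 3: "Upper Arms",
            4: "Lower Arms", 5: "Upper Legs", 6: "Lower Legs"}

# toy 4x4 segmentation map standing in for a real model output
seg_map = np.array([[0, 0, 1, 1],
                    [0, 2, 2, 2],
                    [0, 2, 2, 2],
                    [5, 5, 6, 6]])

counts = np.bincount(seg_map.ravel(), minlength=7)  # pixels per class ID
fractions = counts / seg_map.size                   # share of the image area
report = {id2label[i]: round(float(f), 4) for i, f in enumerate(fractions) if f > 0}
```

For this toy map, `report` contains e.g. `"Torso": 0.375` (6 of 16 pixels); relative proportions between body parts follow the same way.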
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": [
+     "SCHPForSemanticSegmentation"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_schp.SCHPConfig",
+     "AutoModelForSemanticSegmentation": "modeling_schp.SCHPForSemanticSegmentation"
+   },
+   "backbone": "resnet101",
+   "dtype": "float32",
+   "id2label": {
+     "0": "Background",
+     "1": "Head",
+     "2": "Torso",
+     "3": "Upper Arms",
+     "4": "Lower Arms",
+     "5": "Upper Legs",
+     "6": "Lower Legs"
+   },
+   "input_size": 512,
+   "label2id": {
+     "Background": "0",
+     "Head": "1",
+     "Lower Arms": "4",
+     "Lower Legs": "6",
+     "Torso": "2",
+     "Upper Arms": "3",
+     "Upper Legs": "5"
+   },
+   "model_type": "schp",
+   "transformers_version": "5.5.0"
+ }
configuration_schp.py ADDED
@@ -0,0 +1,48 @@
+ from transformers import PretrainedConfig
+
+ _PASCAL_LABELS = [
+     "Background",
+     "Head",
+     "Torso",
+     "Upper Arms",
+     "Lower Arms",
+     "Upper Legs",
+     "Lower Legs",
+ ]
+
+
+ class SCHPConfig(PretrainedConfig):
+     r"""
+     Configuration for **Self-Correction-Human-Parsing (SCHP)**.
+
+     Args:
+         num_labels (`int`, *optional*, defaults to 7):
+             Number of segmentation classes (7 for the Pascal Person Part dataset).
+         input_size (`int`, *optional*, defaults to 512):
+             Spatial resolution the model expects (height = width).
+         backbone (`str`, *optional*, defaults to `"resnet101"`):
+             Backbone architecture name. Only `"resnet101"` is supported.
+     """
+
+     model_type = "schp"
+
+     def __init__(
+         self,
+         num_labels: int = 7,
+         input_size: int = 512,
+         backbone: str = "resnet101",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_labels = num_labels
+         self.input_size = input_size
+         self.backbone = backbone
+
+         if "id2label" not in kwargs:
+             self.id2label = {
+                 str(i): lbl for i, lbl in enumerate(_PASCAL_LABELS[:num_labels])
+             }
+         if "label2id" not in kwargs:
+             self.label2id = {
+                 lbl: str(i) for i, lbl in enumerate(_PASCAL_LABELS[:num_labels])
+             }
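The config's default label maps come from truncating the Pascal label list to `num_labels`. A standalone sketch of that fallback (mirroring, not importing, `SCHPConfig` — `default_id2label` is a hypothetical helper name):

```python
_PASCAL_LABELS = ["Background", "Head", "Torso", "Upper Arms",
                  "Lower Arms", "Upper Legs", "Lower Legs"]

def default_id2label(num_labels):
    # mirrors SCHPConfig's fallback: first num_labels Pascal names, string keys
    return {str(i): lbl for i, lbl in enumerate(_PASCAL_LABELS[:num_labels])}
```

With `num_labels=7` this reproduces the full map in `config.json`; a smaller `num_labels` simply takes a prefix of the list.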
image_processing_schp.py ADDED
@@ -0,0 +1,95 @@
+ """
+ SCHPImageProcessor — preprocessing for SCHPForSemanticSegmentation.
+
+ Resizes images to the model's expected input size and normalises with the
+ SCHP BGR-indexed mean/std convention (channels are RGB in the tensor but
+ the normalisation constants come from a BGR-trained ResNet-101).
+ """
+
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as TF
+ from PIL import Image
+ from transformers import BaseImageProcessor
+ from transformers.image_processing_utils import BatchFeature
+
+
+ class SCHPImageProcessor(BaseImageProcessor):
+     """
+     Image processor for SCHP (Self-Correction Human Parsing).
+
+     Args:
+         size (`dict`, *optional*, defaults to ``{"height": 512, "width": 512}``):
+             Target output height and width. The model was trained at 512×512.
+         image_mean (`list[float]`):
+             Per-channel mean in **RGB channel order** using BGR-indexed values:
+             ``[0.406, 0.456, 0.485]``.
+         image_std (`list[float]`):
+             Per-channel std in **RGB channel order** using BGR-indexed values:
+             ``[0.225, 0.224, 0.229]``.
+     """
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         size: Optional[Dict[str, int]] = None,
+         image_mean: Optional[List[float]] = None,
+         image_std: Optional[List[float]] = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.size = size or {"height": 512, "width": 512}
+         # BGR-indexed normalisation constants used during SCHP training
+         self.image_mean = image_mean or [0.406, 0.456, 0.485]
+         self.image_std = image_std or [0.225, 0.224, 0.229]
+
+     def preprocess(
+         self,
+         images: Union[
+             Image.Image,
+             np.ndarray,
+             torch.Tensor,
+             List[Union[Image.Image, np.ndarray, torch.Tensor]],
+         ],
+         return_tensors: Optional[str] = "pt",
+         **kwargs,
+     ) -> BatchFeature:
+         """
+         Pre-process one or more images.
+
+         Returns a :class:`BatchFeature` with a ``pixel_values`` key of shape
+         ``(batch, 3, H, W)`` as a ``torch.Tensor`` (when ``return_tensors="pt"``).
+         """
+         if not isinstance(images, (list, tuple)):
+             images = [images]
+
+         h = self.size["height"]
+         w = self.size["width"]
+         mean = self.image_mean
+         std = self.image_std
+
+         tensors = []
+         for img in images:
+             # --- normalise input type to PIL RGB ---
+             pil: Image.Image
+             if isinstance(img, torch.Tensor):
+                 # (C, H, W) float tensor in [0, 1]
+                 pil = TF.to_pil_image(img.cpu())
+             elif isinstance(img, np.ndarray):
+                 pil = Image.fromarray(np.asarray(img, dtype=np.uint8))
+             else:
+                 assert isinstance(img, Image.Image)
+                 pil = img
+             pil = pil.convert("RGB")
+
+             # --- resize → tensor → normalise ---
+             pil = pil.resize((w, h), resample=Image.Resampling.BILINEAR)
+             t = TF.to_tensor(pil)  # float32 in [0, 1], shape (3, H, W)
+             t = TF.normalize(t, mean=mean, std=std)
+             tensors.append(t)
+
+         pixel_values = torch.stack(tensors)  # (B, 3, H, W)
+         return BatchFeature({"pixel_values": pixel_values}, tensor_type=return_tensors)
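The "BGR-indexed" constants above are not new numbers: they are the standard ImageNet normalisation constants read in reverse (BGR) channel order, matching the OpenCV/BGR pipeline the original SCHP checkpoints were trained with. A small self-check:

```python
# Standard ImageNet normalisation constants in RGB channel order ...
imagenet_mean_rgb = [0.485, 0.456, 0.406]
imagenet_std_rgb = [0.229, 0.224, 0.225]

# ... and the values shipped in preprocessor_config.json
schp_mean = [0.406, 0.456, 0.485]
schp_std = [0.225, 0.224, 0.229]

# The SCHP constants are exactly the ImageNet constants reversed (RGB -> BGR)
same_mean = schp_mean == imagenet_mean_rgb[::-1]
same_std = schp_std == imagenet_std_rgb[::-1]
```

`TF.normalize` then applies the usual per-channel `(x - mean) / std` with these reversed constants.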
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:50f185b34ce14e92bccf809c9d8a369e9beaa4b999ef15fab0e2a8c3475560c6
+ size 267399112
modeling_schp.py ADDED
@@ -0,0 +1,428 @@
+ """
+ SCHP (Self-Correction Human Parsing) — Transformers-compatible implementation.
+
+ Architecture inlined from https://github.com/GoGoDuck912/Self-Correction-Human-Parsing
+ (networks/AugmentCE2P.py) with the CUDA-only InPlaceABNSync replaced by a pure-PyTorch
+ drop-in, making the model fully runnable on CPU.
+ """
+
+ import functools
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import PreTrainedModel
+ from transformers.utils import ModelOutput
+
+ from .configuration_schp import SCHPConfig
+
+
+ # ── Pure-PyTorch InPlaceABNSync shim ──────────────────────────────────────────
+ class InPlaceABNSync(nn.BatchNorm2d):
+     """CPU-compatible drop-in for InPlaceABNSync.
+
+     Subclasses ``nn.BatchNorm2d`` directly so that state-dict keys
+     (weight, bias, running_mean, running_var) match the original SCHP
+     checkpoints without any nesting.
+     """
+
+     def __init__(self, num_features, activation="leaky_relu", slope=0.01, **kwargs):
+         bn_kwargs = {
+             k: v
+             for k, v in kwargs.items()
+             if k in ("eps", "momentum", "affine", "track_running_stats")
+         }
+         super().__init__(num_features, **bn_kwargs)
+         self.activation = activation
+         self.slope = slope
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
+         input = super().forward(input)
+         if self.activation == "leaky_relu":
+             return F.leaky_relu(input, negative_slope=self.slope, inplace=True)
+         elif self.activation == "elu":
+             return F.elu(input, inplace=True)
+         return input
+
+
+ # BatchNorm2d with no activation (activation="none")
+ BatchNorm2d = functools.partial(InPlaceABNSync, activation="none")
+ affine_par = True
+
+
+ # ── Model architecture (inlined from AugmentCE2P.py) ─────────────────────────
+ def _conv3x3(in_planes, out_planes, stride=1):
+     return nn.Conv2d(
+         in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+     )
+
+
+ class _Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(
+         self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1
+     ):
+         super().__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(
+             planes,
+             planes,
+             kernel_size=3,
+             stride=stride,
+             padding=dilation * multi_grid,
+             dilation=dilation * multi_grid,
+             bias=False,
+         )
+         self.bn2 = BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm2d(planes * 4)
+         self.relu = nn.ReLU(inplace=False)
+         self.relu_inplace = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.dilation = dilation
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.bn1(self.conv1(x)))
+         out = self.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         if self.downsample is not None:
+             residual = self.downsample(x)
+         return self.relu_inplace(out + residual)
+
+
+ class _PSPModule(nn.Module):
+     def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
+         super().__init__()
+         self.stages = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     nn.AdaptiveAvgPool2d(size),
+                     nn.Conv2d(features, out_features, kernel_size=1, bias=False),
+                     InPlaceABNSync(out_features),
+                 )
+                 for size in sizes
+             ]
+         )
+         self.bottleneck = nn.Sequential(
+             nn.Conv2d(
+                 features + len(sizes) * out_features,
+                 out_features,
+                 kernel_size=3,
+                 padding=1,
+                 dilation=1,
+                 bias=False,
+             ),
+             InPlaceABNSync(out_features),
+         )
+
+     def forward(self, feats):
+         h, w = feats.size(2), feats.size(3)
+         priors = [
+             F.interpolate(
+                 stage(feats), size=(h, w), mode="bilinear", align_corners=True
+             )
+             for stage in self.stages
+         ] + [feats]
+         return self.bottleneck(torch.cat(priors, dim=1))
+
+
+ class _Edge_Module(nn.Module):
+     def __init__(self, in_fea=(256, 512, 1024), mid_fea=256, out_fea=2):
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, bias=False),
+             InPlaceABNSync(mid_fea),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, bias=False),
+             InPlaceABNSync(mid_fea),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, bias=False),
+             InPlaceABNSync(mid_fea),
+         )
+         self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, bias=True)
+         self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, bias=True)
+
+     def forward(self, x1, x2, x3):
+         _, _, h, w = x1.size()
+         ef1 = self.conv1(x1)
+         ef2 = self.conv2(x2)
+         ef3 = self.conv3(x3)
+         e1 = self.conv4(ef1)
+         e2 = F.interpolate(
+             self.conv4(ef2), size=(h, w), mode="bilinear", align_corners=True
+         )
+         e3 = F.interpolate(
+             self.conv4(ef3), size=(h, w), mode="bilinear", align_corners=True
+         )
+         ef2 = F.interpolate(ef2, size=(h, w), mode="bilinear", align_corners=True)
+         ef3 = F.interpolate(ef3, size=(h, w), mode="bilinear", align_corners=True)
+         edge = self.conv5(torch.cat([e1, e2, e3], dim=1))
+         edge_fea = torch.cat([ef1, ef2, ef3], dim=1)
+         return edge, edge_fea
+
+
+ class _Decoder_Module(nn.Module):
+     def __init__(self, num_classes):
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(512, 256, kernel_size=1, bias=False),
+             InPlaceABNSync(256),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(256, 48, kernel_size=1, bias=False),
+             InPlaceABNSync(48),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(304, 256, kernel_size=1, bias=False),
+             InPlaceABNSync(256),
+             nn.Conv2d(256, 256, kernel_size=1, bias=False),
+             InPlaceABNSync(256),
+         )
+         self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)
+
+     def forward(self, xt, xl):
+         _, _, h, w = xl.size()
+         xt = F.interpolate(
+             self.conv1(xt), size=(h, w), mode="bilinear", align_corners=True
+         )
+         xl = self.conv2(xl)
+         x = self.conv3(torch.cat([xt, xl], dim=1))
+         return self.conv4(x), x
+
+
+ class _SCHPResNet(nn.Module):
+     """SCHP ResNet-101 backbone + decoder (reproduced from AugmentCE2P.py)."""
+
+     def __init__(self, num_classes: int):
+         self.inplanes = 128
+         super().__init__()
+         # Three-layer stem
+         self.conv1 = _conv3x3(3, 64, stride=2)
+         self.bn1 = BatchNorm2d(64)
+         self.relu1 = nn.ReLU(inplace=False)
+         self.conv2 = _conv3x3(64, 64)
+         self.bn2 = BatchNorm2d(64)
+         self.relu2 = nn.ReLU(inplace=False)
+         self.conv3 = _conv3x3(64, 128)
+         self.bn3 = BatchNorm2d(128)
+         self.relu3 = nn.ReLU(inplace=False)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         # ResNet stages
+         self.layer1 = self._make_layer(_Bottleneck, 64, 3)
+         self.layer2 = self._make_layer(_Bottleneck, 128, 4, stride=2)
+         self.layer3 = self._make_layer(_Bottleneck, 256, 23, stride=2)
+         self.layer4 = self._make_layer(
+             _Bottleneck, 512, 3, stride=1, dilation=2, multi_grid=(1, 1, 1)
+         )
+         # Head modules
+         self.context_encoding = _PSPModule(2048, 512)
+         self.edge = _Edge_Module()
+         self.decoder = _Decoder_Module(num_classes)
+         # NB: "fushion" spelling kept to match the original checkpoint keys
+         self.fushion = nn.Sequential(
+             nn.Conv2d(1024, 256, kernel_size=1, bias=False),
+             InPlaceABNSync(256),
+             nn.Dropout2d(0.1),
+             nn.Conv2d(256, num_classes, kernel_size=1, bias=True),
+         )
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.inplanes,
+                     planes * block.expansion,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False,
+                 ),
+                 BatchNorm2d(planes * block.expansion, affine=affine_par),
+             )
+
+         def _grid(i, g):
+             return g[i % len(g)] if isinstance(g, tuple) else 1
+
+         layers = [
+             block(
+                 self.inplanes,
+                 planes,
+                 stride,
+                 dilation=dilation,
+                 downsample=downsample,
+                 multi_grid=_grid(0, multi_grid),
+             )
+         ]
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     dilation=dilation,
+                     multi_grid=_grid(i, multi_grid),
+                 )
+             )
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.relu1(self.bn1(self.conv1(x)))
+         x = self.relu2(self.bn2(self.conv2(x)))
+         x = self.relu3(self.bn3(self.conv3(x)))
+         x = self.maxpool(x)
+         x2 = self.layer1(x)
+         x3 = self.layer2(x2)
+         x4 = self.layer3(x3)
+         x5 = self.layer4(x4)
+         context = self.context_encoding(x5)
+         parsing_result, parsing_fea = self.decoder(context, x2)
+         edge_result, edge_fea = self.edge(x2, x3, x4)
+         fusion_result = self.fushion(torch.cat([parsing_fea, edge_fea], dim=1))
+         # Return format mirrors the original: [[parsing, fusion], [edge]]
+         return [[parsing_result, fusion_result], [edge_result]]
+
+
+ # ── Transformers output dataclass ────────────────────────────────────────────
+ @dataclass
+ class SCHPSemanticSegmenterOutput(ModelOutput):
+     """
+     Output type for :class:`SCHPForSemanticSegmentation`.
+
+     Args:
+         loss: Cross-entropy loss (only when ``labels`` is provided).
+         logits: Final fusion logits, shape ``(batch, num_labels, H, W)``,
+             upsampled to the input image resolution.
+         parsing_logits: Decoder-branch logits before fusion,
+             shape ``(batch, num_labels, H, W)``.
+         edge_logits: Edge-branch logits, shape ``(batch, 2, H, W)``.
+     """
+
+     loss: Optional[torch.Tensor] = None
+     logits: Optional[torch.Tensor] = None
+     parsing_logits: Optional[torch.Tensor] = None
+     edge_logits: Optional[torch.Tensor] = None
+
+
+ # ── PreTrainedModel wrapper ───────────────────────────────────────────────────
+ class SCHPForSemanticSegmentation(PreTrainedModel):
+     """
+     SCHP ResNet-101 for human parsing / semantic segmentation.
+
+     Usage — loading from an original SCHP ``.pth`` checkpoint::
+
+         model = SCHPForSemanticSegmentation.from_schp_checkpoint(
+             "checkpoints/schp/exp-schp-201908301523-atr.pth"
+         )
+
+     Usage — loading after :meth:`save_pretrained`::
+
+         model = SCHPForSemanticSegmentation.from_pretrained(
+             "./my-schp-model", trust_remote_code=True
+         )
+     """
+
+     config_class = SCHPConfig
+     # num_batches_tracked is not stored in the original SCHP checkpoints
+     _keys_to_ignore_on_load_missing = [r"\.num_batches_tracked$"]
+
+     def __init__(self, config: SCHPConfig):
+         super().__init__(config)
+         self.model = _SCHPResNet(num_classes=config.num_labels)
+         self.post_init()
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[SCHPSemanticSegmenterOutput, Tuple]:
+         """
+         Args:
+             pixel_values: ``(batch, 3, H, W)`` — normalised with SCHP BGR-indexed means.
+             labels: ``(batch, H, W)`` integer class map for computing CE loss.
+             return_dict: Whether to return a :class:`SCHPSemanticSegmenterOutput`
+                 (the default) instead of a plain tuple.
+         """
+         return_dict = return_dict if return_dict is not None else True
+
+         h, w = pixel_values.shape[-2:]
+         raw = self.model(pixel_values)
+         # raw = [[parsing_result, fusion_result], [edge_result]]
+
+         logits = F.interpolate(
+             raw[0][1], size=(h, w), mode="bilinear", align_corners=True
+         )
+         parsing_logits = F.interpolate(
+             raw[0][0], size=(h, w), mode="bilinear", align_corners=True
+         )
+         edge_logits = F.interpolate(
+             raw[1][0], size=(h, w), mode="bilinear", align_corners=True
+         )
+
+         loss = None
+         if labels is not None:
+             loss = F.cross_entropy(logits, labels.long())
+
+         if not return_dict:
+             return (loss, logits) if loss is not None else (logits,)
+
+         return SCHPSemanticSegmenterOutput(
+             loss=loss,
+             logits=logits,
+             parsing_logits=parsing_logits,
+             edge_logits=edge_logits,
+         )
+
+     @classmethod
+     def from_schp_checkpoint(
+         cls,
+         checkpoint_path: str,
+         config: Optional[SCHPConfig] = None,
+         map_location: str = "cpu",
+     ) -> "SCHPForSemanticSegmentation":
+         """
+         Load from an original SCHP ``.pth`` checkpoint.
+
+         Handles the ``module.`` prefix added by ``DataParallel`` training and
+         remaps keys to the ``model.*`` namespace used by this wrapper.
+
+         Args:
+             checkpoint_path: Path to the ``.pth`` file.
+             config: :class:`SCHPConfig` instance. Defaults to the 7-class
+                 Pascal Person Part config.
+             map_location: PyTorch device string (``"cpu"`` or ``"cuda"``).
+         """
+         if config is None:
+             config = SCHPConfig()
+
+         model = cls(config)
+
+         raw = torch.load(checkpoint_path, map_location=map_location)
+         state_dict = raw.get("state_dict", raw)
+
+         # Strip DataParallel module. prefix if present
+         if all(k.startswith("module.") for k in state_dict):
+             state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
+
+         # Remap to model.* namespace (self.model = _SCHPResNet)
+         state_dict = {"model." + k: v for k, v in state_dict.items()}
+
+         missing, unexpected = model.load_state_dict(state_dict, strict=False)
+         real_missing = [k for k in missing if "num_batches_tracked" not in k]
+         if real_missing:
+             raise RuntimeError(
+                 f"Missing keys when loading SCHP checkpoint ({len(real_missing)} total): "
+                 f"{real_missing[:5]}"
+             )
+         if unexpected:
+             raise RuntimeError(
+                 f"Unexpected keys when loading SCHP checkpoint ({len(unexpected)} total): "
+                 f"{unexpected[:5]}"
+             )
+
+         return model
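The key remapping inside `from_schp_checkpoint` can be exercised on a plain dict without touching PyTorch. A sketch with hypothetical keys shaped like a `DataParallel`-saved SCHP checkpoint:

```python
# toy state dict with keys shaped like a DataParallel-saved SCHP checkpoint
state_dict = {
    "module.conv1.weight": "w0",
    "module.layer1.0.bn1.weight": "w1",
}

# 1) strip the DataParallel "module." prefix when every key carries it
if all(k.startswith("module.") for k in state_dict):
    state_dict = {k[len("module."):]: v for k, v in state_dict.items()}

# 2) move everything under the wrapper's "model." namespace
#    (the wrapper stores the network as self.model = _SCHPResNet)
state_dict = {"model." + k: v for k, v in state_dict.items()}
```

After both steps the keys line up with `SCHPForSemanticSegmentation`'s own `state_dict()` names, e.g. `model.conv1.weight`.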
onnx/schp-pascal-7-int8-static.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:66b12766d7f1ddbc3de972e67e8626be727507e7feeeca34e1b23b6f45e756d2
+ size 69148800
onnx/schp-pascal-7.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f8ce1f038ed6cb429f0a4a2f146064afd6398520226569988275e13e2847fd0
+ size 1489921
onnx/schp-pascal-7.onnx.data ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67cfa8f2399a68d7e0e955fb70a0ff57ddf79063cd1d7d5130a3859601f8ef04
+ size 266665984
preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "auto_map": {
+     "AutoImageProcessor": "image_processing_schp.SCHPImageProcessor"
+   },
+   "image_mean": [
+     0.406,
+     0.456,
+     0.485
+   ],
+   "image_processor_type": "SCHPImageProcessor",
+   "image_std": [
+     0.225,
+     0.224,
+     0.229
+   ],
+   "size": {
+     "height": 512,
+     "width": 512
+   }
+ }