# modeling_schp.py — schp-pascal-7 repository (commit 50b9f67, "style: reorder import")
"""
SCHP (Self-Correction Human Parsing) β€” Transformers-compatible implementation.
Architecture inlined from https://github.com/GoGoDuck912/Self-Correction-Human-Parsing
(networks/AugmentCE2P.py) with the CUDA-only InPlaceABNSync replaced by a pure-PyTorch
drop-in, making the model fully runnable on CPU.
"""
import functools
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from schp.configuration_schp import SCHPConfig
from transformers import PreTrainedModel
from transformers.utils import ModelOutput
# ── Pure-PyTorch InPlaceABNSync shim ──────────────────────────────────────────
class InPlaceABNSync(nn.BatchNorm2d):
    """CPU-compatible drop-in for InPlaceABNSync.

    Inherits from ``nn.BatchNorm2d`` directly so the state-dict keys
    (weight, bias, running_mean, running_var) line up with the original
    SCHP checkpoints without any wrapper nesting. The fused activation is
    applied after the batch-norm in :meth:`forward`.
    """

    # Only these kwargs are meaningful to nn.BatchNorm2d; everything else
    # (e.g. InPlaceABN-specific options) is silently dropped.
    _BN_KWARGS = ("eps", "momentum", "affine", "track_running_stats")

    def __init__(self, num_features, activation="leaky_relu", slope=0.01, **kwargs):
        bn_kwargs = {key: kwargs[key] for key in self._BN_KWARGS if key in kwargs}
        super().__init__(num_features, **bn_kwargs)
        self.activation = activation
        self.slope = slope

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        out = super().forward(input)
        if self.activation == "leaky_relu":
            return F.leaky_relu(out, negative_slope=self.slope, inplace=True)
        if self.activation == "elu":
            return F.elu(out, inplace=True)
        # activation == "none" (or anything unrecognised): plain batch norm.
        return out
# BatchNorm2d with no activation (activation="none") — alias used wherever the
# original AugmentCE2P.py code instantiated a plain (non-activated) batch norm.
BatchNorm2d = functools.partial(InPlaceABNSync, activation="none")
# Kept for parity with the original repo; passed as `affine=` to the shortcut
# batch norms built in _SCHPResNet._make_layer.
affine_par = True
# ── Model architecture (inlined from AugmentCE2P.py) ─────────────────────────
def _conv3x3(in_planes, out_planes, stride=1):
return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
)
class _Bottleneck(nn.Module):
    """ResNet bottleneck unit (1x1 reduce -> 3x3 -> 1x1 expand) with dilation
    and multi-grid support, as used by the SCHP backbone.

    Submodule attribute names are kept identical to the original so that
    checkpoint state-dict keys match.
    """

    expansion = 4

    def __init__(
        self, inplanes, planes, stride=1, dilation=1, downsample=None, multi_grid=1
    ):
        super().__init__()
        # The 3x3 conv carries the combined dilation * multi-grid rate;
        # padding equals the dilation so spatial size is preserved at stride 1.
        eff_dilation = dilation * multi_grid
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=eff_dilation,
            dilation=eff_dilation,
            bias=False,
        )
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=False)
        self.relu_inplace = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        # Optional projection shortcut when channels / spatial size change.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu_inplace(out + shortcut)
class _PSPModule(nn.Module):
    """Pyramid Scene Parsing context module.

    Pools the input at several grid sizes, projects each pooled map to
    ``out_features`` channels, upsamples everything back to the input
    resolution, and fuses the concatenation through a 3x3 bottleneck conv.
    """

    def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6)):
        super().__init__()

        def _stage(pool_size):
            # Pool to pool_size x pool_size, then 1x1-project + BN/activation.
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(features, out_features, kernel_size=1, bias=False),
                InPlaceABNSync(out_features),
            )

        self.stages = nn.ModuleList(_stage(size) for size in sizes)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(
                features + len(sizes) * out_features,
                out_features,
                kernel_size=3,
                padding=1,
                dilation=1,
                bias=False,
            ),
            InPlaceABNSync(out_features),
        )

    def forward(self, feats):
        target = (feats.size(2), feats.size(3))
        pyramid = [
            F.interpolate(
                stage(feats), size=target, mode="bilinear", align_corners=True
            )
            for stage in self.stages
        ]
        # Original features are appended last, matching the reference impl.
        pyramid.append(feats)
        return self.bottleneck(torch.cat(pyramid, dim=1))
class _Edge_Module(nn.Module):
    """Edge-detection head fed by three backbone stages.

    Each stage is reduced to ``mid_fea`` channels (conv1..conv3), classified
    into ``out_fea`` edge channels by a shared 3x3 conv (conv4), and the three
    per-stage edge maps are merged by a 1x1 conv (conv5).
    """

    def __init__(self, in_fea=(256, 512, 1024), mid_fea=256, out_fea=2):
        super().__init__()

        def _reduce(channels):
            # 1x1 projection of one backbone stage to mid_fea channels.
            return nn.Sequential(
                nn.Conv2d(channels, mid_fea, kernel_size=1, bias=False),
                InPlaceABNSync(mid_fea),
            )

        self.conv1 = _reduce(in_fea[0])
        self.conv2 = _reduce(in_fea[1])
        self.conv3 = _reduce(in_fea[2])
        # Shared edge classifier and the 1x1 map-merging conv.
        self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, bias=True)
        self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, bias=True)

    def forward(self, x1, x2, x3):
        target = (x1.size(2), x1.size(3))

        def _up(t):
            return F.interpolate(t, size=target, mode="bilinear", align_corners=True)

        ef1 = self.conv1(x1)
        ef2 = self.conv2(x2)
        ef3 = self.conv3(x3)
        # Classify edges at native resolution first, then upsample the logits.
        e1 = self.conv4(ef1)
        e2 = _up(self.conv4(ef2))
        e3 = _up(self.conv4(ef3))
        # Upsample the reduced feature maps for the downstream fusion path.
        ef2 = _up(ef2)
        ef3 = _up(ef3)
        edge = self.conv5(torch.cat([e1, e2, e3], dim=1))
        edge_fea = torch.cat([ef1, ef2, ef3], dim=1)
        return edge, edge_fea
class _Decoder_Module(nn.Module):
    """Parsing decoder head: fuses PSP context features with low-level features."""

    def __init__(self, num_classes):
        super().__init__()
        # Context branch: 512 -> 256 channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, bias=False),
            InPlaceABNSync(256),
        )
        # Low-level branch: 256 -> 48 channels.
        self.conv2 = nn.Sequential(
            nn.Conv2d(256, 48, kernel_size=1, bias=False),
            InPlaceABNSync(48),
        )
        # Fusion of 256 + 48 = 304 channels back down to 256.
        self.conv3 = nn.Sequential(
            nn.Conv2d(304, 256, kernel_size=1, bias=False),
            InPlaceABNSync(256),
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            InPlaceABNSync(256),
        )
        self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, bias=True)

    def forward(self, xt, xl):
        low_h, low_w = xl.size(2), xl.size(3)
        # Upsample the context features to the low-level feature resolution.
        context = F.interpolate(
            self.conv1(xt), size=(low_h, low_w), mode="bilinear", align_corners=True
        )
        fused = self.conv3(torch.cat([context, self.conv2(xl)], dim=1))
        # Returns (class logits, fused features) — the features feed the
        # fusion head downstream.
        return self.conv4(fused), fused
class _SCHPResNet(nn.Module):
    """SCHP ResNet-101 backbone + parsing/edge/fusion heads (from AugmentCE2P.py).

    Submodule attribute names — including the ``fushion`` typo — are kept
    identical to the original implementation so checkpoint state-dict keys
    line up exactly.
    """

    def __init__(self, num_classes: int):
        self.inplanes = 128  # running channel count consumed by _make_layer
        super().__init__()
        # Deep stem: three 3x3 convs in place of the classic single 7x7.
        self.conv1 = _conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = _conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = _conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # ResNet-101 stages: 3 / 4 / 23 / 3 bottleneck units.
        self.layer1 = self._make_layer(_Bottleneck, 64, 3)
        self.layer2 = self._make_layer(_Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(_Bottleneck, 256, 23, stride=2)
        # Final stage keeps stride 1 and dilates instead, preserving resolution.
        self.layer4 = self._make_layer(
            _Bottleneck, 512, 3, stride=1, dilation=2, multi_grid=(1, 1, 1)
        )
        # Head modules.
        self.context_encoding = _PSPModule(2048, 512)
        self.edge = _Edge_Module()
        self.decoder = _Decoder_Module(num_classes)
        self.fushion = nn.Sequential(  # sic: original spelling kept for checkpoints
            nn.Conv2d(1024, 256, kernel_size=1, bias=False),
            InPlaceABNSync(256),
            nn.Dropout2d(0.1),
            nn.Conv2d(256, num_classes, kernel_size=1, bias=True),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
        """Assemble one ResNet stage of ``blocks`` bottleneck units."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # Project the shortcut whenever channels or spatial size change.
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                BatchNorm2d(out_channels, affine=affine_par),
            )

        def _grid(idx, grid):
            # Per-unit multi-grid dilation factor; a plain int means "none".
            return grid[idx % len(grid)] if isinstance(grid, tuple) else 1

        units = [
            block(
                self.inplanes,
                planes,
                stride,
                dilation=dilation,
                downsample=downsample,
                multi_grid=_grid(0, multi_grid),
            )
        ]
        self.inplanes = out_channels
        units.extend(
            block(
                self.inplanes,
                planes,
                dilation=dilation,
                multi_grid=_grid(idx, multi_grid),
            )
            for idx in range(1, blocks)
        )
        return nn.Sequential(*units)

    def forward(self, x):
        stem = self.relu1(self.bn1(self.conv1(x)))
        stem = self.relu2(self.bn2(self.conv2(stem)))
        stem = self.relu3(self.bn3(self.conv3(stem)))
        stem = self.maxpool(stem)
        feat2 = self.layer1(stem)
        feat3 = self.layer2(feat2)
        feat4 = self.layer3(feat3)
        feat5 = self.layer4(feat4)
        context = self.context_encoding(feat5)
        parsing_result, parsing_fea = self.decoder(context, feat2)
        edge_result, edge_fea = self.edge(feat2, feat3, feat4)
        fusion_result = self.fushion(torch.cat([parsing_fea, edge_fea], dim=1))
        # Same nesting as the original repo: [[parsing, fusion], [edge]]
        return [[parsing_result, fusion_result], [edge_result]]
# ── Transformers output dataclass ────────────────────────────────────────────
@dataclass
class SCHPSemanticSegmenterOutput(ModelOutput):
    """
    Output type for :class:`SCHPForSemanticSegmentation`.

    All logits tensors are upsampled to the input image resolution by
    ``SCHPForSemanticSegmentation.forward`` before being stored here.

    Args:
        loss: Cross-entropy loss (only when ``labels`` is provided).
        logits: Final fusion logits, shape ``(batch, num_labels, H, W)``,
            upsampled to the input image resolution.
        parsing_logits: Decoder-branch logits before fusion,
            shape ``(batch, num_labels, H, W)``.
        edge_logits: Edge-branch logits, shape ``(batch, 2, H, W)``.
    """
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    parsing_logits: Optional[torch.Tensor] = None
    edge_logits: Optional[torch.Tensor] = None
# ── PreTrainedModel wrapper ───────────────────────────────────────────────────
class SCHPForSemanticSegmentation(PreTrainedModel):
    """
    SCHP ResNet-101 for human parsing / semantic segmentation.

    Usage — loading from an original SCHP ``.pth`` checkpoint::

        model = SCHPForSemanticSegmentation.from_schp_checkpoint(
            "checkpoints/schp/exp-schp-201908301523-atr.pth"
        )

    Usage — loading after :meth:`save_pretrained`::

        model = SCHPForSemanticSegmentation.from_pretrained(
            "./my-schp-model", trust_remote_code=True
        )
    """

    config_class = SCHPConfig

    # num_batches_tracked is not stored in the original SCHP checkpoints
    _keys_to_ignore_on_load_missing = [r"\.num_batches_tracked$"]

    def __init__(self, config: SCHPConfig):
        super().__init__(config)
        self.model = _SCHPResNet(num_classes=config.num_labels)
        self.post_init()

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SCHPSemanticSegmenterOutput, Tuple]:
        """
        Args:
            pixel_values: ``(batch, 3, H, W)`` — normalised with SCHP BGR-indexed means.
            labels: ``(batch, H, W)`` integer class map for computing CE loss.
            return_dict: Override ``config.use_return_dict``.

        Returns:
            :class:`SCHPSemanticSegmenterOutput` when ``return_dict`` resolves
            to True, otherwise a tuple ``(loss?, logits, parsing_logits,
            edge_logits)`` with ``loss`` present only when ``labels`` is given.
        """
        # Fix: fall back to the config as the docstring always promised
        # (previously this was hard-coded to True, ignoring the config).
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        h, w = pixel_values.shape[-2:]

        # raw = [[parsing_result, fusion_result], [edge_result]] (see _SCHPResNet)
        raw = self.model(pixel_values)

        def _upsample(t: torch.Tensor) -> torch.Tensor:
            # align_corners=True matches the backbone's internal interpolations.
            return F.interpolate(t, size=(h, w), mode="bilinear", align_corners=True)

        logits = _upsample(raw[0][1])          # fusion branch (final prediction)
        parsing_logits = _upsample(raw[0][0])  # decoder branch, pre-fusion
        edge_logits = _upsample(raw[1][0])     # 2-channel edge branch

        loss = None
        if labels is not None:
            # CE on the fusion logits only; the original training recipe also
            # uses auxiliary losses, but this wrapper targets inference/tuning.
            loss = F.cross_entropy(logits, labels.long())

        if not return_dict:
            # Transformers tuple convention: loss first when present, and ALL
            # outputs included (previously parsing/edge logits were dropped).
            output = (logits, parsing_logits, edge_logits)
            return ((loss,) + output) if loss is not None else output

        return SCHPSemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            parsing_logits=parsing_logits,
            edge_logits=edge_logits,
        )

    @classmethod
    def from_schp_checkpoint(
        cls,
        checkpoint_path: str,
        config: Optional[SCHPConfig] = None,
        map_location: str = "cpu",
    ) -> "SCHPForSemanticSegmentation":
        """
        Load from an original SCHP ``.pth`` checkpoint.

        Handles the ``module.`` prefix added by ``DataParallel`` training and
        remaps keys to the ``model.*`` namespace used by this wrapper.

        Args:
            checkpoint_path: Path to the ``.pth`` file.
            config: :class:`SCHPConfig` instance. Defaults to ATR-18 config.
            map_location: PyTorch device string (``"cpu"`` or ``"cuda"``).

        Raises:
            RuntimeError: If the checkpoint keys do not match the model
                (beyond the expected ``num_batches_tracked`` buffers).
        """
        if config is None:
            config = SCHPConfig()
        model = cls(config)

        # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
        # load checkpoints from trusted sources (consider weights_only=True on
        # torch >= 1.13 once supported checkpoints are verified).
        raw = torch.load(checkpoint_path, map_location=map_location)
        state_dict = raw.get("state_dict", raw)

        # Strip DataParallel's "module." prefix when every key carries it
        # (guard against an empty dict, where all(...) is vacuously True).
        if state_dict and all(k.startswith("module.") for k in state_dict):
            state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
        # Remap into the "model." namespace (self.model = _SCHPResNet).
        state_dict = {f"model.{k}": v for k, v in state_dict.items()}

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        # num_batches_tracked buffers are legitimately absent from SCHP checkpoints.
        real_missing = [k for k in missing if "num_batches_tracked" not in k]
        if real_missing:
            raise RuntimeError(
                f"Missing keys when loading SCHP checkpoint ({len(real_missing)} total): "
                f"{real_missing[:5]}"
            )
        if unexpected:
            raise RuntimeError(
                f"Unexpected keys when loading SCHP checkpoint ({len(unexpected)} total): "
                f"{unexpected[:5]}"
            )
        return model