# schp-pascal-7 / configuration_schp.py
# Author: pirocheto
# feat: initial release — Pascal Person Part 7-class SCHP model
# Commit: e97480b
from transformers import PretrainedConfig
_PASCAL_LABELS = [
"Background",
"Head",
"Torso",
"Upper Arms",
"Lower Arms",
"Upper Legs",
"Lower Legs",
]
class SCHPConfig(PretrainedConfig):
    r"""
    Configuration for **Self-Correction-Human-Parsing (SCHP)**.

    Args:
        num_labels (`int`, *optional*, defaults to 7):
            Number of segmentation classes (7 for Pascal Person Part dataset).
            Ignored when an explicit `id2label` mapping is supplied; the number
            of classes is then the length of that mapping.
        input_size (`int`, *optional*, defaults to 512):
            Spatial resolution the model expects (height = width).
        backbone (`str`, *optional*, defaults to `"resnet101"`):
            Backbone architecture name. Only `"resnet101"` is supported.
        kwargs:
            Forwarded to `PretrainedConfig`. If `id2label` is not given (or is
            `None`), it defaults to `{class_index: name}` using the Pascal
            Person Part names, with `"LABEL_{i}"` for any index beyond them.
            If `label2id` is not given (or is `None`), it is built as the
            inverse of `id2label`.
    """

    model_type = "schp"

    def __init__(
        self,
        num_labels: int = 7,
        input_size: int = 512,
        backbone: str = "resnet101",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_size = input_size
        self.backbone = backbone

        if kwargs.get("id2label") is None:
            self.num_labels = num_labels
            # Integer keys match PretrainedConfig, which casts id2label keys to
            # int when a config is loaded from JSON; str keys would make a
            # freshly built config behave differently from a reloaded one.
            # Indices past the known names fall back to generic names so the
            # mapping always has exactly `num_labels` entries.
            self.id2label = {
                i: _PASCAL_LABELS[i] if i < len(_PASCAL_LABELS) else f"LABEL_{i}"
                for i in range(num_labels)
            }
        # Otherwise keep the caller's id2label (already normalised by the base
        # class). Assigning `num_labels` here would trigger the base setter,
        # which replaces a mapping of a different length with LABEL_i names.

        if kwargs.get("label2id") is None:
            # Keep label2id the exact inverse of whatever id2label ended up as.
            self.label2id = {lbl: idx for idx, lbl in self.id2label.items()}