milkzheng's picture
Upload folder using huggingface_hub
e72f0d1 verified
Raw
History Blame Contribute Delete
3.87 kB
"""HuggingFace `PreTrainedModel` + `PretrainedConfig` wrapper around `DPT`.
Lets consumers do `AutoModel.from_pretrained(repo_id, trust_remote_code=True)`
without importing the local `DPT` class. The `auto_map` field on the config
tells HF to bundle `hf_model.py` + `dpt.py` with the uploaded weights so the
classes are reconstructable in a clean env.
"""
from __future__ import annotations
from typing import Any, Literal, Optional, cast
import torch
import torch.nn.functional as F
from transformers import (
AutoConfig,
AutoModel,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import SemanticSegmenterOutput
from .dpt import DPT
class DPTConfig(PretrainedConfig):
    """Hyper-parameters needed to rebuild :class:`DPTForSegmentation`.

    Each named constructor argument is stored as a same-named attribute;
    sequence arguments are normalized to plain ``list`` objects (presumably so
    they serialize cleanly to ``config.json``). Any remaining keyword
    arguments are forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "metpredict_dpt"

    def __init__(
        self,
        n_classes: int = 4,
        class_names: Optional[list[str]] = None,
        backbone: str = "hf-hub:bioptimus/H-optimus-0",
        encoder_depth: int = 4,
        decoder_intermediate_channels: tuple[int, ...] = (224, 448, 896, 896),
        decoder_fusion_channels: int = 224,
        decoder_readout: str = "cat",
        activation: Optional[str] = None,
        in_channels: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- output head / labels -------------------------------------------
        self.n_classes = n_classes
        # None and empty sequences both collapse to an empty list.
        self.class_names = [] if not class_names else list(class_names)

        # --- encoder ----------------------------------------------------------
        self.backbone = backbone
        self.encoder_depth = encoder_depth

        # --- decoder ----------------------------------------------------------
        self.decoder_intermediate_channels = [*decoder_intermediate_channels]
        self.decoder_fusion_channels = decoder_fusion_channels
        self.decoder_readout = decoder_readout

        # --- misc -------------------------------------------------------------
        self.activation = activation
        self.in_channels = in_channels

        # Points transformers' Auto* factories at the classes in `hf_model.py`
        # so the repo loads with `trust_remote_code=True` and no local imports.
        self.auto_map = dict(
            AutoConfig="hf_model.DPTConfig",
            AutoModel="hf_model.DPTForSegmentation",
        )
class DPTForSegmentation(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the local ``DPT`` segmenter.

    The wrapped network is stored as ``self.dpt`` (matching
    ``base_model_prefix``) and is built entirely from ``DPTConfig`` fields.
    """

    config_class = DPTConfig
    base_model_prefix = "dpt"
    main_input_name = "pixel_values"
    all_tied_weights_keys: dict = {}  # To be compatible with transformers 4.x and 5.x
    # NOTE(review): presumably declared at class level because `post_init()`
    # (which would normally set it) is never called — confirm.

    # Readout modes accepted by DPT's `decoder_readout` Literal annotation.
    _VALID_READOUTS = ("ignore", "add", "cat")

    def __init__(self, config: DPTConfig):
        """Build the wrapped ``DPT`` network from ``config``.

        Raises:
            ValueError: if ``config.decoder_readout`` is not one of
                ``"ignore"``, ``"add"`` or ``"cat"``.
        """
        super().__init__(config)
        # Config values loaded from JSON are plain `str`, and `cast` performs no
        # runtime check — validate explicitly before handing the value to DPT.
        if config.decoder_readout not in self._VALID_READOUTS:
            raise ValueError(
                f"decoder_readout must be one of {self._VALID_READOUTS}, "
                f"got {config.decoder_readout!r}"
            )
        readout = cast(Literal["ignore", "add", "cat"], config.decoder_readout)
        dpt_kwargs: dict[str, Any] = dict(
            encoder_name=config.backbone,
            encoder_depth=config.encoder_depth,
            decoder_readout=readout,
            decoder_intermediate_channels=tuple(config.decoder_intermediate_channels),
            decoder_fusion_channels=config.decoder_fusion_channels,
            in_channels=config.in_channels,
            classes=config.n_classes,
            activation=config.activation,
        )
        self.dpt = DPT(**dpt_kwargs)
        # Skip post_init weight init — DPT.initialize() already ran inside DPT.__init__.

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = True,
    ):
        """Run the segmenter and, when ``labels`` are given, the loss.

        Args:
            pixel_values: Input batch passed straight to ``DPT``. Presumably
                ``(B, in_channels, H, W)`` — TODO confirm against ``dpt.py``.
            labels: Optional class-index targets for ``F.cross_entropy``
                (cast to ``long``). Presumably ``(B, H, W)`` matching the
                logits' spatial size — TODO confirm.
            return_dict: ``True`` returns a ``SemanticSegmenterOutput``;
                ``False`` returns a tuple. ``None`` — what HF utilities pass to
                mean "use the default" — falls back to ``config.return_dict``.

        Returns:
            ``SemanticSegmenterOutput(loss=..., logits=...)``, or a tuple
            ``(loss, logits)`` / ``(logits,)`` when ``return_dict`` is false.
        """
        # Previously `None` was treated as falsy and produced a tuple.
        if return_dict is None:
            return_dict = getattr(self.config, "return_dict", True)

        logits = self.dpt(pixel_values)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # NOTE(review): if `config.activation` is set and DPT applies it to
            # its output, `logits` are already activated and cross-entropy
            # (which expects raw scores) is not appropriate — confirm.
            loss = F.cross_entropy(logits, labels.long())

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        # SemanticSegmenterOutput expects FloatTensor — cast suppresses Pylance.
        return SemanticSegmenterOutput(
            loss=cast(Any, loss),
            logits=cast(Any, logits),
        )
def _register() -> None:
    """Best-effort registration of the classes with transformers' Auto* factories.

    Lets ``AutoConfig``/``AutoModel`` resolve ``DPTConfig.model_type``
    in-process (e.g. ``AutoModel.from_config(DPTConfig())``) without going
    through ``auto_map``. Registration conflicts (``ValueError``) are
    swallowed so that importing this module never fails on them. The two
    registrations are attempted independently so one failing does not skip
    the other.
    """
    try:
        # Use the class attribute so the key can never drift from the config.
        AutoConfig.register(DPTConfig.model_type, DPTConfig)
    except ValueError:
        # Name conflict (e.g. already registered) — keep the existing mapping.
        pass

    try:
        # `exist_ok=True` is required, not cosmetic: transformers' auto-model
        # mapping looks config classes up by `__name__` in its built-in table,
        # where `DPTConfig` already names the stock DPT config, so the plain
        # call raises ValueError ("already used by a Transformers model").
        # NOTE(review): verify against the installed transformers version.
        AutoModel.register(DPTConfig, DPTForSegmentation, exist_ok=True)
    except TypeError:
        # Older transformers releases have no `exist_ok` parameter.
        try:
            AutoModel.register(DPTConfig, DPTForSegmentation)
        except ValueError:
            pass
    except ValueError:
        pass


_register()