"""HuggingFace `PreTrainedModel` + `PretrainedConfig` wrapper around `DPT`.

Lets consumers do `AutoModel.from_pretrained(repo_id, trust_remote_code=True)`
without importing the local `DPT` class. The `auto_map` field on the config
points HF at `hf_model.DPTConfig` / `hf_model.DPTForSegmentation`. Because
`hf_model.py` imports `dpt.py`, both files need to be present in the repo next
to the uploaded weights so the classes are reconstructable in a clean env.

Importing this module also registers the classes with `AutoConfig` /
`AutoModel` under `model_type="metpredict_dpt"`.
"""

from __future__ import annotations

from typing import Any, Literal, Optional, cast

import torch
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import SemanticSegmenterOutput

from .dpt import DPT


class DPTConfig(PretrainedConfig):
    """Configuration for :class:`DPTForSegmentation`.

    Args:
        n_classes: Number of classes; forwarded to ``DPT(classes=...)``.
        class_names: Optional name per class. Stored as a list (empty when
            not given). Not validated against ``n_classes``.
        backbone: Encoder identifier forwarded to ``DPT(encoder_name=...)``.
        encoder_depth: Forwarded to ``DPT(encoder_depth=...)``.
        decoder_intermediate_channels: Stored as a list; converted back to a
            tuple when the model builds its ``DPT``.
        decoder_fusion_channels: Forwarded to
            ``DPT(decoder_fusion_channels=...)``.
        decoder_readout: Expected to be one of ``"ignore"``, ``"add"`` or
            ``"cat"`` (the ``Literal`` type ``DPT`` declares). Not validated
            here.
        activation: Forwarded to ``DPT(activation=...)``.
        in_channels: Number of input channels; forwarded to
            ``DPT(in_channels=...)``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "metpredict_dpt"

    def __init__(
        self,
        n_classes: int = 4,
        class_names: Optional[list[str]] = None,
        backbone: str = "hf-hub:bioptimus/H-optimus-0",
        encoder_depth: int = 4,
        decoder_intermediate_channels: tuple[int, ...] = (224, 448, 896, 896),
        decoder_fusion_channels: int = 224,
        decoder_readout: str = "cat",
        activation: Optional[str] = None,
        in_channels: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        self.class_names = list(class_names) if class_names else []
        self.backbone = backbone
        self.encoder_depth = encoder_depth
        self.decoder_intermediate_channels = list(decoder_intermediate_channels)
        self.decoder_fusion_channels = decoder_fusion_channels
        self.decoder_readout = decoder_readout
        self.activation = activation
        self.in_channels = in_channels
        # `auto_map` makes the repo loadable as AutoModel without local imports.
        # It is assigned unconditionally, so any `auto_map` passed via kwargs
        # (e.g. read back from a saved config.json) is replaced by this value.
        self.auto_map = {
            "AutoConfig": "hf_model.DPTConfig",
            "AutoModel": "hf_model.DPTForSegmentation",
        }


class DPTForSegmentation(PreTrainedModel):
    """`PreTrainedModel` wrapper that builds a `DPT` network from a `DPTConfig`.

    The network is stored as ``self.dpt`` (matching ``base_model_prefix``), so
    its parameters appear in the state dict under the ``dpt.`` prefix.
    """

    config_class = DPTConfig
    base_model_prefix = "dpt"
    main_input_name = "pixel_values"
    # To be compatible with transformers 4.x and 5.x.
    # NOTE(review): presumably needed because `post_init()` (which would
    # normally populate this per instance) is never called in `__init__`.
    all_tied_weights_keys: dict = {}

    def __init__(self, config: DPTConfig):
        super().__init__(config)
        # `decoder_readout` is Literal-typed in DPT. Cast because the value
        # read back from the config (JSON) is a plain str. This is a
        # type-checker hint only; no runtime validation happens.
        readout = cast(Literal["ignore", "add", "cat"], config.decoder_readout)
        dpt_kwargs: dict[str, Any] = dict(
            encoder_name=config.backbone,
            encoder_depth=config.encoder_depth,
            decoder_readout=readout,
            decoder_intermediate_channels=tuple(config.decoder_intermediate_channels),
            decoder_fusion_channels=config.decoder_fusion_channels,
            in_channels=config.in_channels,
            classes=config.n_classes,
            activation=config.activation,
        )
        self.dpt = DPT(**dpt_kwargs)
        # Skip post_init weight init — DPT.initialize() already ran inside DPT.__init__.

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Run the DPT network and optionally compute a cross-entropy loss.

        Args:
            pixel_values: Input batch, passed straight to ``self.dpt``.
                Assumed ``(B, in_channels, H, W)`` — TODO confirm against DPT.
            labels: Optional class-index targets. Cast to ``long`` and fed to
                ``F.cross_entropy`` together with the network output, so they
                must be shape-compatible with it (presumably ``(B, H, W)``).
            return_dict: When True, return a ``SemanticSegmenterOutput``.
                Otherwise return ``(loss, logits)`` if ``labels`` were given,
                else ``(logits,)``.
        """
        logits = self.dpt(pixel_values)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # NOTE(review): F.cross_entropy expects unnormalized scores. If
            # `config.activation` is set and DPT applies it to its output,
            # this loss is computed on activated values — confirm.
            loss = F.cross_entropy(logits, labels.long())

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        # SemanticSegmenterOutput expects FloatTensor — cast suppresses Pylance.
        return SemanticSegmenterOutput(
            loss=cast(Any, loss),
            logits=cast(Any, logits),
        )


def _register() -> None:
    """Register `DPTConfig` / `DPTForSegmentation` with the HF Auto factories.

    After import, ``AutoConfig`` resolves ``model_type="metpredict_dpt"`` to
    `DPTConfig`, and ``AutoModel`` maps `DPTConfig` to `DPTForSegmentation`,
    without going through ``auto_map``. Safe to call repeatedly (re-import).
    """
    AutoConfig.register("metpredict_dpt", DPTConfig, exist_ok=True)
    # `exist_ok=True` is required here, not just a re-import convenience.
    # The AutoModel mapping detects clashes by config class *name*, and
    # transformers ships its own built-in `DPTConfig` (model_type "dpt").
    # Without the flag this call raises ValueError on every import. The old
    # `except ValueError: pass` hid that, leaving the model unregistered.
    AutoModel.register(DPTConfig, DPTForSegmentation, exist_ok=True)


_register()