tipsv2-l14-dpt / configuration_dpt.py
gberton's picture
Upload configuration_dpt.py with huggingface_hub
281491c verified
"""TIPSv2 DPT head configuration."""
from transformers import PretrainedConfig
class TIPSv2DPTConfig(PretrainedConfig):
    """Configuration for TIPSv2 DPT dense prediction heads.

    Stores the hyper-parameters consumed by the TIPSv2 DPT modeling code.
    Unrecognized keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        backbone_repo: Identifier of the backbone checkpoint; the default
            value is a Hugging Face Hub repo id.
        embed_dim: Backbone embedding width (presumably the token dimension
            of the ViT-L/14 backbone -- confirm against the backbone config).
        channels: Channel width used inside the head.
        post_process_channels: Per-stage channel counts. Accepts any
            iterable and is stored as a list. Must have the same length as
            ``block_indices``.
        block_indices: Indices of the backbone blocks the head reads from
            (assumed one per stage). Accepts any iterable and is stored as
            a list.
        readout_type: Readout-token handling strategy, interpreted by the
            modeling code (default ``"project"``).
        num_depth_bins: Number of bins used for depth prediction.
        min_depth: Lower end of the depth range. Must be strictly less than
            ``max_depth``.
        max_depth: Upper end of the depth range.
        num_seg_classes: Number of semantic segmentation classes.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``post_process_channels`` and ``block_indices`` have
            different lengths, or if ``min_depth`` is not strictly less than
            ``max_depth``.
    """

    model_type = "tipsv2_dpt"

    def __init__(
        self,
        backbone_repo="google/tipsv2-l14",
        embed_dim=1024,
        channels=256,
        post_process_channels=(128, 256, 512, 1024),
        block_indices=(5, 11, 17, 23),
        readout_type="project",
        num_depth_bins=256,
        min_depth=1e-3,
        max_depth=10.0,
        num_seg_classes=150,
        **kwargs,
    ):
        # Materialize once so arbitrary iterables (including generators)
        # can be both length-checked and stored. Lists also round-trip
        # cleanly through the JSON config file.
        post_process_channels = list(post_process_channels)
        block_indices = list(block_indices)

        # The two per-stage sequences are parallel: each tapped block needs
        # a matching channel count. Reject a mismatch here instead of
        # letting it fail (or silently truncate) inside the head.
        if len(post_process_channels) != len(block_indices):
            raise ValueError(
                "post_process_channels and block_indices must have the same "
                f"length, got {len(post_process_channels)} and "
                f"{len(block_indices)}"
            )

        # Written as `not (a < b)` so that NaN bounds are rejected too.
        if not (min_depth < max_depth):
            raise ValueError(
                f"min_depth ({min_depth}) must be strictly less than "
                f"max_depth ({max_depth})"
            )

        super().__init__(**kwargs)
        self.backbone_repo = backbone_repo
        self.embed_dim = embed_dim
        self.channels = channels
        self.post_process_channels = post_process_channels
        self.block_indices = block_indices
        self.readout_type = readout_type
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.num_seg_classes = num_seg_classes