File size: 894 Bytes
8dd9a5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from typing import Optional
from transformers import PretrainedConfig
class DiskConfig(PretrainedConfig):
    """Configuration object for the DISK model.

    Every keyword argument accepted by ``__init__`` is stored on the instance
    under an attribute of the same name.  Any additional keyword arguments are
    forwarded untouched to ``PretrainedConfig.__init__``.

    Args:
        weights: Name of the weight variant; defaults to ``"depth"``.
            NOTE(review): presumably selects a pretrained DISK checkpoint —
            confirm against the model that consumes this config.
        max_num_keypoints: Upper bound on the number of keypoints, or ``None``
            (the default). How ``None`` is interpreted is decided by the
            consuming model — verify there.
        descriptor_decoder_dim: Descriptor decoder dimension; defaults to 128.
        nms_window_size: Non-maximum-suppression window size; defaults to 5.
        detection_threshold: Detection threshold; defaults to 0.0.
        pad_if_not_divisible: Whether inputs should be padded when their size is
            not divisible (by a factor chosen by the model — confirm);
            defaults to True.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier used by the transformers configuration machinery.
    model_type = "disk"

    def __init__(
        self,
        weights: str = "depth",
        max_num_keypoints: Optional[int] = None,
        descriptor_decoder_dim: int = 128,
        nms_window_size: int = 5,
        detection_threshold: float = 0.0,
        pad_if_not_divisible: bool = True,
        **kwargs,
    ) -> None:
        # Initialise the base config first so it consumes its own kwargs before
        # the DISK-specific attributes are attached.
        super().__init__(**kwargs)

        # DISK-specific hyperparameters, stored verbatim.
        self.weights = weights
        self.max_num_keypoints = max_num_keypoints
        self.descriptor_decoder_dim = descriptor_decoder_dim
        self.nms_window_size = nms_window_size
        self.detection_threshold = detection_threshold
        self.pad_if_not_divisible = pad_if_not_divisible
if __name__ == "__main__":
    # Publish a default-constructed DiskConfig to the Hugging Face Hub.
    config = DiskConfig()
    # Use push_to_hub with an explicit repo id. save_pretrained() treats its
    # first argument as a *local directory*; with push_to_hub=True and no
    # repo_id it derives the Hub repo from only the last path component
    # ("disk"). That drops the "stevenbucaille/" namespace and also leaves a
    # stray local "stevenbucaille/disk" directory behind.
    config.push_to_hub("stevenbucaille/disk")
|