LEE181204 commited on
Commit
b20b202
·
verified ·
1 Parent(s): 0d2f318

Upload configuration_spatialvla.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_spatialvla.py +119 -0
configuration_spatialvla.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+ from transformers import CONFIG_MAPPING, AutoConfig
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ class SpatialVLAConfig(PretrainedConfig):
23
+ model_type = "spatialvla"
24
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "vision_zoe_config": AutoConfig}
25
+
26
+ def __init__(
27
+ self,
28
+ vision_config=None,
29
+ text_config=None,
30
+ ignore_index=-100,
31
+ image_token_index=256000,
32
+ vocab_size=257152,
33
+ projection_dim=2048,
34
+ hidden_size=2048,
35
+ vision_zoe_config=None,
36
+ action_token_begin_idx=None,
37
+ spatial_token_num=259,
38
+ use_spatial_token=False,
39
+ ego3d_patch_reso=4,
40
+ n_freqs=8,
41
+ use_vision_zoe=True,
42
+ **kwargs,
43
+ ):
44
+ self._ignore_index = ignore_index
45
+ self.image_token_index = image_token_index
46
+ self._vocab_size = vocab_size
47
+ self.projection_dim = projection_dim
48
+ self.hidden_size = hidden_size
49
+ self.vision_config = vision_config
50
+ self.is_encoder_decoder = False
51
+
52
+ if isinstance(self.vision_config, dict):
53
+ vision_config["model_type"] = (
54
+ vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
55
+ )
56
+ self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
57
+ elif vision_config is None:
58
+ self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
59
+ intermediate_size=4096,
60
+ hidden_size=1152,
61
+ patch_size=14,
62
+ image_size=224,
63
+ num_hidden_layers=27,
64
+ num_attention_heads=16,
65
+ vocab_size=257152,
66
+ vision_use_head=False,
67
+ )
68
+
69
+ self.text_config = text_config
70
+ if isinstance(self.text_config, dict):
71
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma2"
72
+ self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
73
+ elif text_config is None:
74
+ self.text_config = CONFIG_MAPPING["gemma2"](
75
+ hidden_size=2048,
76
+ num_hidden_layers=18,
77
+ intermediate_size=16384,
78
+ num_attention_heads=8,
79
+ num_key_value_heads=1,
80
+ is_encoder_decoder=False,
81
+ vocab_size=vocab_size,
82
+ )
83
+ self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
84
+ self.vision_config.projection_dim = projection_dim
85
+
86
+ # vision zoe config
87
+ self.vision_zoe_config = vision_zoe_config
88
+ if isinstance(self.vision_zoe_config, dict):
89
+ vision_zoe_config["model_type"] = vision_zoe_config["model_type"] if "model_type" in vision_zoe_config else "zoedepth"
90
+ self.vision_zoe_config = CONFIG_MAPPING[vision_zoe_config["model_type"]](**vision_zoe_config)
91
+ else:
92
+ pass
93
+
94
+ # additional attributes
95
+ self.action_token_begin_idx = action_token_begin_idx
96
+ self.spatial_token_num = spatial_token_num
97
+ self.use_spatial_token = use_spatial_token
98
+ self.ego3d_patch_reso = ego3d_patch_reso
99
+ self.n_freqs = n_freqs
100
+ self.use_vision_zoe = use_vision_zoe
101
+
102
+ super().__init__(**kwargs)
103
+
104
+ @property
105
+ def ignore_index(self):
106
+ warnings.warn(
107
+ "The `ignore_index` attribute is deprecated and will be removed in v4.47.",
108
+ FutureWarning,
109
+ )
110
+ return self._ignore_index
111
+
112
+ @ignore_index.setter
113
+ def ignore_index(self, value):
114
+ self._ignore_index = value
115
+
116
+ def to_dict(self):
117
+ output = super().to_dict()
118
+ output.pop("_ignore_index", None)
119
+ return output