Fix double conditioning (#6)
Browse files- Fix double conditioning (19f901633dea829072dc888c2286a14f44f5f4e4)
- config.json +1 -0
- hf_model.py +4 -1
- radio_model.py +8 -1
config.json
CHANGED
|
@@ -347,6 +347,7 @@
|
|
| 347 |
"AutoConfig": "hf_model.RADIOConfig",
|
| 348 |
"AutoModel": "hf_model.RADIOModel"
|
| 349 |
},
|
|
|
|
| 350 |
"max_resolution": 2048,
|
| 351 |
"patch_size": 16,
|
| 352 |
"preferred_resolution": [
|
|
|
|
| 347 |
"AutoConfig": "hf_model.RADIOConfig",
|
| 348 |
"AutoModel": "hf_model.RADIOModel"
|
| 349 |
},
|
| 350 |
+
"external_conditioner": false,
|
| 351 |
"max_resolution": 2048,
|
| 352 |
"patch_size": 16,
|
| 353 |
"preferred_resolution": [
|
hf_model.py
CHANGED
|
@@ -45,6 +45,7 @@ class RADIOConfig(PretrainedConfig):
|
|
| 45 |
preferred_resolution: Optional[Resolution] = None,
|
| 46 |
adaptor_names: Union[str, List[str]] = None,
|
| 47 |
vitdet_window_size: Optional[int] = None,
|
|
|
|
| 48 |
**kwargs,
|
| 49 |
):
|
| 50 |
self.args = args
|
|
@@ -63,6 +64,7 @@ class RADIOConfig(PretrainedConfig):
|
|
| 63 |
)
|
| 64 |
self.adaptor_names = adaptor_names
|
| 65 |
self.vitdet_window_size = vitdet_window_size
|
|
|
|
| 66 |
super().__init__(**kwargs)
|
| 67 |
|
| 68 |
|
|
@@ -75,7 +77,7 @@ class RADIOModel(PreTrainedModel):
|
|
| 75 |
|
| 76 |
config_class = RADIOConfig
|
| 77 |
|
| 78 |
-
def __init__(self, config):
|
| 79 |
super().__init__(config)
|
| 80 |
|
| 81 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
|
@@ -115,6 +117,7 @@ class RADIOModel(PreTrainedModel):
|
|
| 115 |
preferred_resolution=config.preferred_resolution,
|
| 116 |
adaptors=adaptors,
|
| 117 |
)
|
|
|
|
| 118 |
|
| 119 |
@property
|
| 120 |
def adaptors(self) -> nn.ModuleDict:
|
|
|
|
| 45 |
preferred_resolution: Optional[Resolution] = None,
|
| 46 |
adaptor_names: Union[str, List[str]] = None,
|
| 47 |
vitdet_window_size: Optional[int] = None,
|
| 48 |
+
external_conditioner: Optional[bool] = False,
|
| 49 |
**kwargs,
|
| 50 |
):
|
| 51 |
self.args = args
|
|
|
|
| 64 |
)
|
| 65 |
self.adaptor_names = adaptor_names
|
| 66 |
self.vitdet_window_size = vitdet_window_size
|
| 67 |
+
self.external_conditioner = external_conditioner
|
| 68 |
super().__init__(**kwargs)
|
| 69 |
|
| 70 |
|
|
|
|
| 77 |
|
| 78 |
config_class = RADIOConfig
|
| 79 |
|
| 80 |
+
def __init__(self, config: RADIOConfig):
|
| 81 |
super().__init__(config)
|
| 82 |
|
| 83 |
RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
|
|
|
|
| 117 |
preferred_resolution=config.preferred_resolution,
|
| 118 |
adaptors=adaptors,
|
| 119 |
)
|
| 120 |
+
self.radio_model._external_conditioner = config.external_conditioner
|
| 121 |
|
| 122 |
@property
|
| 123 |
def adaptors(self) -> nn.ModuleDict:
|
radio_model.py
CHANGED
|
@@ -51,6 +51,12 @@ class RADIOModel(nn.Module):
|
|
| 51 |
self._patch_size = patch_size
|
| 52 |
self._max_resolution = max_resolution
|
| 53 |
self._window_size = window_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
adaptors = adaptors or dict()
|
| 56 |
self.adaptors = nn.ModuleDict(adaptors)
|
|
@@ -113,7 +119,8 @@ class RADIOModel(nn.Module):
|
|
| 113 |
'`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
|
| 114 |
f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
|
| 115 |
|
| 116 |
-
|
|
|
|
| 117 |
y = self.model.forward_features(x)
|
| 118 |
|
| 119 |
if isinstance(self.model, VisionTransformer):
|
|
|
|
| 51 |
self._patch_size = patch_size
|
| 52 |
self._max_resolution = max_resolution
|
| 53 |
self._window_size = window_size
|
| 54 |
+
# This is a hack workaround for huggingface, since their
|
| 55 |
+
# data prep is annoying and complicated. If set to true,
|
| 56 |
+
# then will not call `self.input_conditioner` on the
|
| 57 |
+
# input tensor. This will be set in `hf_model.RADIOModel`
|
| 58 |
+
# where appropriate.
|
| 59 |
+
self._external_conditioner = False
|
| 60 |
|
| 61 |
adaptors = adaptors or dict()
|
| 62 |
self.adaptors = nn.ModuleDict(adaptors)
|
|
|
|
| 119 |
'`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
|
| 120 |
f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
|
| 121 |
|
| 122 |
+
if not self._external_conditioner:
|
| 123 |
+
x = self.input_conditioner(x)
|
| 124 |
y = self.model.forward_features(x)
|
| 125 |
|
| 126 |
if isinstance(self.model, VisionTransformer):
|