Update resampler.py
Browse files- resampler.py +9 -0
resampler.py
CHANGED
|
@@ -117,6 +117,15 @@ class Resampler(nn.Module):
|
|
| 117 |
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
|
| 118 |
self._set_2d_pos_cache(self.max_size, device)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def _init_weights(self, m):
|
| 121 |
if isinstance(m, nn.Linear):
|
| 122 |
trunc_normal_(m.weight, std=0.02)
|
|
|
|
| 117 |
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
|
| 118 |
self._set_2d_pos_cache(self.max_size, device)
|
| 119 |
|
| 120 |
+
def _initialize_weights(self, module):
|
| 121 |
+
"""
|
| 122 |
+
Initialize the weights if they are not already initialized.
|
| 123 |
+
"""
|
| 124 |
+
if getattr(module, "_is_hf_initialized", False):
|
| 125 |
+
return
|
| 126 |
+
self._init_weights(module)
|
| 127 |
+
module._is_hf_initialized = True
|
| 128 |
+
|
| 129 |
def _init_weights(self, m):
|
| 130 |
if isinstance(m, nn.Linear):
|
| 131 |
trunc_normal_(m.weight, std=0.02)
|