Spaces:
Running
Running
Zhen Ye
committed on
Commit
·
fe2ace8
1
Parent(s):
2c4431d
Fix GroundingDino loading: manually materialize meta tensors to support tied weights
Browse files
models/detectors/grounding_dino.py
CHANGED
|
@@ -23,10 +23,28 @@ class GroundingDinoDetector(ObjectDetector):
|
|
| 23 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 25 |
self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
|
| 26 |
-
self.model = GroundingDinoForObjectDetection.from_pretrained(
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
self.model.to(self.device)
|
| 31 |
self.model.eval()
|
| 32 |
|
|
|
|
| 23 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
|
| 25 |
self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
|
| 26 |
+
self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
|
| 27 |
+
# Materialize ALL meta-device tensors before .to().
|
| 28 |
+
# Grounding DINO uses regex-based _tied_weights_keys to share bbox_embed
|
| 29 |
+
# layers, but the tying can leave meta tensors unresolved.
|
| 30 |
+
# named_parameters() deduplicates tied weights, so we must walk each
|
| 31 |
+
# module's _parameters/_buffers dicts directly to catch every reference.
|
| 32 |
+
_meta_count = 0
|
| 33 |
+
for module in self.model.modules():
|
| 34 |
+
for key, val in list(module._parameters.items()):
|
| 35 |
+
if val is not None and val.device.type == "meta":
|
| 36 |
+
module._parameters[key] = torch.nn.Parameter(
|
| 37 |
+
torch.empty(val.shape, dtype=val.dtype, device="cpu"),
|
| 38 |
+
requires_grad=val.requires_grad,
|
| 39 |
+
)
|
| 40 |
+
_meta_count += 1
|
| 41 |
+
for key, val in list(module._buffers.items()):
|
| 42 |
+
if val is not None and val.device.type == "meta":
|
| 43 |
+
module._buffers[key] = torch.empty(val.shape, dtype=val.dtype, device="cpu")
|
| 44 |
+
_meta_count += 1
|
| 45 |
+
if _meta_count:
|
| 46 |
+
logging.warning("Materialized %d meta tensors in %s", _meta_count, self.MODEL_NAME)
|
| 47 |
+
self.model.tie_weights()
|
| 48 |
self.model.to(self.device)
|
| 49 |
self.model.eval()
|
| 50 |
|