Zhen Ye committed on
Commit
fe2ace8
·
1 Parent(s): 2c4431d

Fix GroundingDino loading: manually materialize meta tensors to support tied weights

Browse files
Files changed (1) hide show
  1. models/detectors/grounding_dino.py +22 -4
models/detectors/grounding_dino.py CHANGED
@@ -23,10 +23,28 @@ class GroundingDinoDetector(ObjectDetector):
23
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
25
  self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
26
- self.model = GroundingDinoForObjectDetection.from_pretrained(
27
- self.MODEL_NAME,
28
- low_cpu_mem_usage=False,
29
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  self.model.to(self.device)
31
  self.model.eval()
32
 
 
23
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
25
  self.processor = GroundingDinoProcessor.from_pretrained(self.MODEL_NAME)
26
+ self.model = GroundingDinoForObjectDetection.from_pretrained(self.MODEL_NAME)
27
+ # Materialize ALL meta-device tensors before .to().
28
+ # Grounding DINO uses regex-based _tied_weights_keys to share bbox_embed
29
+ # layers, but the tying can leave meta tensors unresolved.
30
+ # named_parameters() deduplicates tied weights, so we must walk each
31
+ # module's _parameters/_buffers dicts directly to catch every reference.
32
+ _meta_count = 0
33
+ for module in self.model.modules():
34
+ for key, val in list(module._parameters.items()):
35
+ if val is not None and val.device.type == "meta":
36
+ module._parameters[key] = torch.nn.Parameter(
37
+ torch.empty(val.shape, dtype=val.dtype, device="cpu"),
38
+ requires_grad=val.requires_grad,
39
+ )
40
+ _meta_count += 1
41
+ for key, val in list(module._buffers.items()):
42
+ if val is not None and val.device.type == "meta":
43
+ module._buffers[key] = torch.empty(val.shape, dtype=val.dtype, device="cpu")
44
+ _meta_count += 1
45
+ if _meta_count:
46
+ logging.warning("Materialized %d meta tensors in %s", _meta_count, self.MODEL_NAME)
47
+ self.model.tie_weights()
48
  self.model.to(self.device)
49
  self.model.eval()
50