hyeongjun User commited on
Commit
4d0f14f
·
1 Parent(s): 82c37dd

Fix runtime: load local DINOv3 backbones without torch.hub hubconf

Browse files
Files changed (1) hide show
  1. ManipDet/models/dinov3_loader.py +7 -1
ManipDet/models/dinov3_loader.py CHANGED
@@ -4,6 +4,7 @@ Loads from official facebookresearch/dinov3 repository.
4
  Supports offline loading via source='local' + backbone_weights_path.
5
  """
6
  import os
 
7
  import torch
8
  import torch.nn as nn
9
  from typing import Optional
@@ -77,7 +78,12 @@ def load_dinov3_backbone(
77
  raise ValueError("weights must be provided when source='local'")
78
  if not os.path.isdir(weights):
79
  raise ValueError(f"Local source requires a directory path, got: {weights}")
80
- model = torch.hub.load(weights, model_name, source='local', pretrained=False)
 
 
 
 
 
81
  if backbone_weights_path is not None and os.path.isfile(backbone_weights_path):
82
  checkpoint = torch.load(backbone_weights_path, map_location='cpu')
83
  state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint.get('model_state_dict') or checkpoint
 
4
  Supports offline loading via source='local' + backbone_weights_path.
5
  """
6
  import os
7
+ import sys
8
  import torch
9
  import torch.nn as nn
10
  from typing import Optional
 
78
  raise ValueError("weights must be provided when source='local'")
79
  if not os.path.isdir(weights):
80
  raise ValueError(f"Local source requires a directory path, got: {weights}")
81
+ if weights not in sys.path:
82
+ sys.path.insert(0, weights)
83
+ from dinov3.hub import backbones as dinov3_backbones
84
+ if not hasattr(dinov3_backbones, model_name):
85
+ raise ValueError(f"Unknown DINOv3 backbone model: {model_name}")
86
+ model = getattr(dinov3_backbones, model_name)(pretrained=False)
87
  if backbone_weights_path is not None and os.path.isfile(backbone_weights_path):
88
  checkpoint = torch.load(backbone_weights_path, map_location='cpu')
89
  state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint.get('model_state_dict') or checkpoint