Spaces:
Sleeping
Sleeping
hyeongjun User commited on
Commit ·
4d0f14f
1
Parent(s): 82c37dd
Fix runtime: load local DINOv3 backbones without torch.hub hubconf
Browse files
ManipDet/models/dinov3_loader.py
CHANGED
|
@@ -4,6 +4,7 @@ Loads from official facebookresearch/dinov3 repository.
|
|
| 4 |
Supports offline loading via source='local' + backbone_weights_path.
|
| 5 |
"""
|
| 6 |
import os
|
|
|
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
from typing import Optional
|
|
@@ -77,7 +78,12 @@ def load_dinov3_backbone(
|
|
| 77 |
raise ValueError("weights must be provided when source='local'")
|
| 78 |
if not os.path.isdir(weights):
|
| 79 |
raise ValueError(f"Local source requires a directory path, got: {weights}")
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
if backbone_weights_path is not None and os.path.isfile(backbone_weights_path):
|
| 82 |
checkpoint = torch.load(backbone_weights_path, map_location='cpu')
|
| 83 |
state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint.get('model_state_dict') or checkpoint
|
|
|
|
| 4 |
Supports offline loading via source='local' + backbone_weights_path.
|
| 5 |
"""
|
| 6 |
import os
|
| 7 |
+
import sys
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
from typing import Optional
|
|
|
|
| 78 |
raise ValueError("weights must be provided when source='local'")
|
| 79 |
if not os.path.isdir(weights):
|
| 80 |
raise ValueError(f"Local source requires a directory path, got: {weights}")
|
| 81 |
+
if weights not in sys.path:
|
| 82 |
+
sys.path.insert(0, weights)
|
| 83 |
+
from dinov3.hub import backbones as dinov3_backbones
|
| 84 |
+
if not hasattr(dinov3_backbones, model_name):
|
| 85 |
+
raise ValueError(f"Unknown DINOv3 backbone model: {model_name}")
|
| 86 |
+
model = getattr(dinov3_backbones, model_name)(pretrained=False)
|
| 87 |
if backbone_weights_path is not None and os.path.isfile(backbone_weights_path):
|
| 88 |
checkpoint = torch.load(backbone_weights_path, map_location='cpu')
|
| 89 |
state_dict = checkpoint.get('state_dict') or checkpoint.get('model') or checkpoint.get('model_state_dict') or checkpoint
|