Update apdepth/marigold_pipeline.py
Browse files
apdepth/marigold_pipeline.py
CHANGED
|
@@ -154,7 +154,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 154 |
# 初始化 DA2 模型
|
| 155 |
if da2_config is not None:
|
| 156 |
self.da2 = DepthAnythingV2(**da2_config)
|
| 157 |
-
self.da2.load_state_dict(torch.load(f'
|
| 158 |
self.da2.to(device="cuda").eval()
|
| 159 |
else:
|
| 160 |
self.da2 = None
|
|
|
|
| 154 |
# 初始化 DA2 模型
|
| 155 |
if da2_config is not None:
|
| 156 |
self.da2 = DepthAnythingV2(**da2_config)
|
| 157 |
+
self.da2.load_state_dict(torch.load(f'./DA2/checkpoints/depth_anything_v2_{da2_config["encoder"]}.pth', map_location='cpu'))
|
| 158 |
self.da2.to(device="cuda").eval()
|
| 159 |
else:
|
| 160 |
self.da2 = None
|