Update apdepth/marigold_pipeline.py
Browse files
apdepth/marigold_pipeline.py
CHANGED
|
@@ -155,7 +155,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 155 |
if da2_config is not None:
|
| 156 |
self.da2 = DepthAnythingV2(**da2_config)
|
| 157 |
self.da2.load_state_dict(torch.load(f'./DA2/checkpoints/depth_anything_v2_{da2_config["encoder"]}.pth', map_location='cpu'))
|
| 158 |
-
self.da2.to(device="
|
| 159 |
else:
|
| 160 |
self.da2 = None
|
| 161 |
|
|
|
|
| 155 |
if da2_config is not None:
|
| 156 |
self.da2 = DepthAnythingV2(**da2_config)
|
| 157 |
self.da2.load_state_dict(torch.load(f'./DA2/checkpoints/depth_anything_v2_{da2_config["encoder"]}.pth', map_location='cpu'))
|
| 158 |
+
self.da2.to(device="cpu").eval() #hf只有CPU :(
|
| 159 |
else:
|
| 160 |
self.da2 = None
|
| 161 |
|