Update apdepth/marigold_pipeline.py
Browse files
apdepth/marigold_pipeline.py
CHANGED
|
@@ -394,7 +394,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
|
| 394 |
"""
|
| 395 |
device = self.device
|
| 396 |
rgb_in = rgb_in.to(device)
|
| 397 |
-
|
| 398 |
|
| 399 |
with torch.no_grad():
|
| 400 |
# Encode image
|
|
|
|
| 394 |
"""
|
| 395 |
device = self.device
|
| 396 |
rgb_in = rgb_in.to(device)
|
| 397 |
+
depth_da2 = self.da2.infer_batch(rgb_in).to(device)
|
| 398 |
|
| 399 |
with torch.no_grad():
|
| 400 |
# Encode image
|