Spaces:
Running
on
Zero
Running
on
Zero
Update ppd/models/ppd.py
Browse files- ppd/models/ppd.py +6 -1
ppd/models/ppd.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch.nn as nn
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import cv2
|
| 8 |
import random
|
|
|
|
| 9 |
from ppd.utils.timesteps import Timesteps
|
| 10 |
from ppd.utils.schedule import LinearSchedule
|
| 11 |
from ppd.utils.sampler import EulerSampler
|
|
@@ -23,7 +24,7 @@ class PixelPerfectDepth(nn.Module):
|
|
| 23 |
):
|
| 24 |
super(PixelPerfectDepth, self).__init__()
|
| 25 |
|
| 26 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else '
|
| 27 |
self.device = DEVICE
|
| 28 |
|
| 29 |
self.semantics_encoder = DepthAnythingV2(
|
|
@@ -31,6 +32,10 @@ class PixelPerfectDepth(nn.Module):
|
|
| 31 |
features=256,
|
| 32 |
out_channels=[256, 512, 1024, 1024]
|
| 33 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
self.semantics_encoder.load_state_dict(torch.load(semantics_pth, map_location='cpu'), strict=False)
|
| 35 |
self.semantics_encoder = self.semantics_encoder.to(self.device).eval()
|
| 36 |
self.dit = DiT()
|
|
|
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import cv2
|
| 8 |
import random
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
from ppd.utils.timesteps import Timesteps
|
| 11 |
from ppd.utils.schedule import LinearSchedule
|
| 12 |
from ppd.utils.sampler import EulerSampler
|
|
|
|
| 24 |
):
|
| 25 |
super(PixelPerfectDepth, self).__init__()
|
| 26 |
|
| 27 |
+
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 28 |
self.device = DEVICE
|
| 29 |
|
| 30 |
self.semantics_encoder = DepthAnythingV2(
|
|
|
|
| 32 |
features=256,
|
| 33 |
out_channels=[256, 512, 1024, 1024]
|
| 34 |
)
|
| 35 |
+
semantics_pth = hf_hub_download(
|
| 36 |
+
repo_id="depth-anything/Depth-Anything-V2-Large",
|
| 37 |
+
filename="depth_anything_v2_vitl.pth",
|
| 38 |
+
repo_type="model")
|
| 39 |
self.semantics_encoder.load_state_dict(torch.load(semantics_pth, map_location='cpu'), strict=False)
|
| 40 |
self.semantics_encoder = self.semantics_encoder.to(self.device).eval()
|
| 41 |
self.dit = DiT()
|