gangweix commited on
Commit
0f4cff1
·
verified ·
1 Parent(s): 130c454

Update ppd/models/ppd.py

Browse files
Files changed (1) hide show
  1. ppd/models/ppd.py +6 -1
ppd/models/ppd.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  import cv2
8
  import random
 
9
  from ppd.utils.timesteps import Timesteps
10
  from ppd.utils.schedule import LinearSchedule
11
  from ppd.utils.sampler import EulerSampler
@@ -23,7 +24,7 @@ class PixelPerfectDepth(nn.Module):
23
  ):
24
  super(PixelPerfectDepth, self).__init__()
25
 
26
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
27
  self.device = DEVICE
28
 
29
  self.semantics_encoder = DepthAnythingV2(
@@ -31,6 +32,10 @@ class PixelPerfectDepth(nn.Module):
31
  features=256,
32
  out_channels=[256, 512, 1024, 1024]
33
  )
 
 
 
 
34
  self.semantics_encoder.load_state_dict(torch.load(semantics_pth, map_location='cpu'), strict=False)
35
  self.semantics_encoder = self.semantics_encoder.to(self.device).eval()
36
  self.dit = DiT()
 
6
  import torch.nn.functional as F
7
  import cv2
8
  import random
9
+ from huggingface_hub import hf_hub_download
10
  from ppd.utils.timesteps import Timesteps
11
  from ppd.utils.schedule import LinearSchedule
12
  from ppd.utils.sampler import EulerSampler
 
24
  ):
25
  super(PixelPerfectDepth, self).__init__()
26
 
27
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
  self.device = DEVICE
29
 
30
  self.semantics_encoder = DepthAnythingV2(
 
32
  features=256,
33
  out_channels=[256, 512, 1024, 1024]
34
  )
35
+ semantics_pth = hf_hub_download(
36
+ repo_id="depth-anything/Depth-Anything-V2-Large",
37
+ filename="depth_anything_v2_vitl.pth",
38
+ repo_type="model")
39
  self.semantics_encoder.load_state_dict(torch.load(semantics_pth, map_location='cpu'), strict=False)
40
  self.semantics_encoder = self.semantics_encoder.to(self.device).eval()
41
  self.dit = DiT()