gangweix commited on
Commit
5ee6fca
·
verified ·
1 Parent(s): d28b1b0

Update ppd/models/ppd.py

Browse files
Files changed (1) hide show
  1. ppd/models/ppd.py +10 -5
ppd/models/ppd.py CHANGED
@@ -53,12 +53,20 @@ class PixelPerfectDepth(nn.Module):
53
  )
54
 
55
  @torch.no_grad()
56
- def forward_test(self, image):
57
  h, w = image.shape[:2]
58
  image = resize_keep_aspect(image)
59
  image = image2tensor(image)
60
  image = image.to(self.device)
61
 
 
 
 
 
 
 
 
 
62
  semantics = self.semantics_prompt(image)
63
  cond = image - 0.5
64
  latent = torch.randn(size=[cond.shape[0], 1, cond.shape[2], cond.shape[3]]).to(self.device)
@@ -68,10 +76,7 @@ class PixelPerfectDepth(nn.Module):
68
  pred = self.dit(x=input, semantics=semantics, timestep=timestep)
69
  latent = self.sampler.step(pred=pred, x_t=latent, t=timestep)
70
 
71
- depth = latent + 0.5
72
- depth = F.interpolate(depth, size=(h, w), mode='bilinear', align_corners=False)[0, 0]
73
-
74
- return depth.cpu().numpy()
75
 
76
 
77
  @torch.no_grad()
 
53
  )
54
 
55
  @torch.no_grad()
56
+ def infer_image(self, image):
57
  h, w = image.shape[:2]
58
  image = resize_keep_aspect(image)
59
  image = image2tensor(image)
60
  image = image.to(self.device)
61
 
62
+ depth = self.forward_test(image)
63
+ depth = F.interpolate(depth, size=(h, w), mode='bilinear', align_corners=False)[0, 0]
64
+
65
+ return depth.cpu().numpy()
66
+
67
+ @torch.no_grad()
68
+ def forward_test(self, image):
69
+
70
  semantics = self.semantics_prompt(image)
71
  cond = image - 0.5
72
  latent = torch.randn(size=[cond.shape[0], 1, cond.shape[2], cond.shape[3]]).to(self.device)
 
76
  pred = self.dit(x=input, semantics=semantics, timestep=timestep)
77
  latent = self.sampler.step(pred=pred, x_t=latent, t=timestep)
78
 
79
+ return latent + 0.5
 
 
 
80
 
81
 
82
  @torch.no_grad()