aradhye committed on
Commit
9785846
·
verified ·
1 Parent(s): 9740b98

Update gradio_depth_pred.py

Browse files
Files changed (1) hide show
  1. gradio_depth_pred.py +9 -65
gradio_depth_pred.py CHANGED
@@ -6,7 +6,7 @@ import torchvision.transforms as transforms
6
  import torch
7
  import spaces
8
 
9
- resolution = 2
10
  base_area = resolution * 480 * 640
11
  flip_test = True
12
 
@@ -27,72 +27,16 @@ def predict_depth(model, image):
27
  new_w, new_h = int(scale * frame_width), int(scale * frame_height)
28
  frame = cv2.resize(frame, (new_w, new_h))
29
 
30
- # ---------------------------------------
31
- # Test-Time Augmentations (TTA) + flip-test for each augmentation
32
- # ---------------------------------------
33
-
34
- frame_tensor = transforms.ToTensor()(frame).unsqueeze(0).to(DEVICE)
35
-
36
- # ---- Base augmentations (without flip) ----
37
- def apply_resize(x):
38
- b, c, h, w = x.shape
39
- scale = torch.empty(1).uniform_(0.9, 1.1).item()
40
- nh, nw = int(h * scale), int(w * scale)
41
- x2 = torch.nn.functional.interpolate(x, (nh, nw), mode="bilinear", align_corners=False)
42
- x2 = torch.nn.functional.interpolate(x2, (h, w), mode="bilinear", align_corners=False)
43
- return x2
44
-
45
- brightness_jitter = transforms.ColorJitter(brightness=0.15)
46
- def apply_brightness(x):
47
- img = transforms.ToPILImage()(x[0].cpu())
48
- img = brightness_jitter(img)
49
- return transforms.ToTensor()(img).unsqueeze(0).to(x.device)
50
-
51
- tta_base = [
52
- ("orig", lambda x: x),
53
- ("resize", apply_resize),
54
- ("bright", apply_brightness),
55
- ]
56
-
57
- # ---------------------------------------
58
- # Build augmented batch (A(x) and A(x_flipped))
59
- # ---------------------------------------
60
- augmented_frames = []
61
- reverse_fns = []
62
-
63
- for name, aug_fn in tta_base:
64
- # A(x)
65
- ax = aug_fn(frame_tensor)
66
- augmented_frames.append(ax)
67
- reverse_fns.append(lambda y: y) # no unflip needed
68
-
69
- if True: # flip_test
70
- # A(x_flip)
71
- axf = aug_fn(frame_tensor.flip(-1))
72
- augmented_frames.append(axf)
73
- reverse_fns.append(lambda y: y.flip(-1)) # unflip prediction
74
-
75
- batch = torch.cat(augmented_frames, dim=0) # [N_aug*2, 3, H, W]
76
-
77
- # ---------------------------------------
78
- # One forward pass
79
- # ---------------------------------------
80
  model.to(DEVICE)
81
-
82
  with torch.no_grad():
83
- pred_batch = model(batch) # [N_aug*2, 1, H, W]
84
-
85
- # ---------------------------------------
86
- # Reverse and average predictions
87
- # ---------------------------------------
88
- corrected = []
89
- for i, reverse in enumerate(reverse_fns):
90
- corrected.append(reverse(pred_batch[i:i+1]))
91
-
92
- depth = torch.stack(corrected).mean(dim=0)
93
- depth = depth[0, 0].cpu().numpy()
94
- return depth
95
-
96
 
97
  def create_demo(model, scene):
98
  gr.Markdown("### Depth Prediction demo")
 
6
  import torch
7
  import spaces
8
 
9
+ resolution = 4
10
  base_area = resolution * 480 * 640
11
  flip_test = True
12
 
 
27
  new_w, new_h = int(scale * frame_width), int(scale * frame_height)
28
  frame = cv2.resize(frame, (new_w, new_h))
29
 
30
+ frame = transforms.ToTensor()(frame).unsqueeze(0)
31
+ if flip_test:
32
+ frame = torch.cat([frame, frame.flip(-1)])
33
+ frame = frame.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  model.to(DEVICE)
 
35
  with torch.no_grad():
36
+ depth = model(frame)
37
+ if flip_test:
38
+ depth = ((depth[0] + depth[1].flip(-1))/2).unsqueeze(0)
39
+ return depth.detach().cpu().numpy()[0, 0]
 
 
 
 
 
 
 
 
 
40
 
41
  def create_demo(model, scene):
42
  gr.Markdown("### Depth Prediction demo")