batch model inference

app.py CHANGED
@@ -43,24 +43,15 @@ def resize_image(image_buffer, max_size=256):
 
 @spaces.GPU(duration=20)
 def predict_depth(input_images):
-    # Preprocess the image
     results = [depth_pro.load_rgb(image) for image in input_images]
-
-    # assume load_rgb returns a tuple of (image, f_px)
-    # stack the images and f_px values into tensors
-    images, f_px = zip(*results)
-    images = torch.stack(images)
-    f_px = torch.tensor(f_px)
-
-    images = transform(images)
-
+    images = torch.stack([transform(result[0]) for result in results])
     images = images.to(device)
-    f_px = f_px.to(device)
 
     # Run inference
-
-
-
+    with torch.no_grad():
+        prediction = model.infer(images)
+        depth = prediction["depth"]  # Depth in [m]
+        focallength_px = prediction["focallength_px"]  # Focal length in pixels
 
     # Convert depth to numpy array if it's a torch tensor
     if isinstance(depth, torch.Tensor):
@@ -68,9 +59,9 @@ def predict_depth(input_images):
 
     # Convert focal length to a float if it's a torch tensor
     if isinstance(focallength_px, torch.Tensor):
-        focallength_px = focallength_px.item()
+        focallength_px = [focal_length.item() for focal_length in focallength_px]
 
-    # Ensure depth is a 2D array
+    # Ensure depth is a BxHxW tensor
     if depth.ndim != 2:
         depth = depth.squeeze()
 
@@ -114,7 +105,13 @@ def run_rerun(path_to_video):
 
     # limit the number of frames to 10 seconds of video
    max_frames = min(10 * fps_video, num_frames)
-
+
+    torch.cuda.empty_cache()
+    free_vram, _ = torch.cuda.mem_get_info(device)
+    free_vram = free_vram / 1024 / 1024 / 1024
+
+    # batch size is determined by the amount of free vram
+    batch_size = int(min(free_vram // 4, max_frames))
 
     # go through all the frames in the video, using the batch size
     for i in range(0, int(max_frames), batch_size):
@@ -122,6 +119,7 @@ def run_rerun(path_to_video):
             raise gr.Error("Reached the maximum number of frames to process")
 
         frames = []
+        frame_indices = list(range(i, min(i + batch_size, int(max_frames))))
         for _ in range(batch_size):
             ret, frame = cap.read()
             if not ret:
@@ -135,11 +133,14 @@ def run_rerun(path_to_video):
 
         depths, focal_lengths = predict_depth(temp_files)
 
-        for depth, focal_length in zip(depths, focal_lengths):
+        for depth, focal_length, frame_idx in zip(
+            depths, focal_lengths, frame_indices
+        ):
             # find x and y scale factors, which can be applied to image
             x_scale = depth.shape[1] / frames[0].shape[1]
             y_scale = depth.shape[0] / frames[0].shape[0]
 
+            rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
             rr.log(
                 "world/camera/depth",
                 rr.DepthImage(depth, meter=1),
@@ -149,7 +150,7 @@ def run_rerun(path_to_video):
                 "world/camera/frame",
                 rr.VideoFrameReference(
                     timestamp=rr.components.VideoTimestamp(
-                        nanoseconds=frame_timestamps_ns[
+                        nanoseconds=frame_timestamps_ns[frame_idx]
                     ),
                     video_reference="world/video",
                 ),
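The hunks above use model, transform, and device without showing where they come from; they are defined earlier in app.py. For orientation only, a minimal setup in the style of the Depth Pro README could look like the following (the names and arguments here are assumptions, not part of this change):

import torch
import depth_pro

# Pick the GPU when one is available; the Space runs predict_depth under @spaces.GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Depth Pro network and its matching preprocessing transform once at startup.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()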
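After this change predict_depth takes a list of image paths, stacks the transformed images into one batch, and runs a single model.infer call, so focallength_px becomes a list with one float per image. The function's return statement sits outside the hunks shown, so the sketch below only illustrates the assumed contract that run_rerun relies on (depths iterable over per-frame depth maps, focal_lengths a list of floats); the frame paths are placeholders:

# Hypothetical call site for the batched helper defined in app.py.
temp_files = ["frame_000.png", "frame_001.png", "frame_002.png"]

depths, focal_lengths = predict_depth(temp_files)

for path, depth, f_px in zip(temp_files, depths, focal_lengths):
    # Each depth map is HxW (metres), each focal length a plain Python float.
    print(f"{path}: depth {depth.shape}, focal length {f_px:.1f} px")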
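The new batch-size logic reads free VRAM with torch.cuda.mem_get_info (which returns free and total bytes), converts it to GiB, and budgets roughly 4 GiB per frame, capped at the number of frames to process. Note that free_vram // 4 evaluates to 0 when less than 4 GiB is free, which would make the range step invalid; the standalone sketch below restates the same heuristic with a floor of one frame and a CPU fallback, both added here as assumptions:

import torch

def pick_batch_size(device: torch.device, max_frames: int, gib_per_frame: float = 4.0) -> int:
    # Estimate how many frames fit in free VRAM, assuming ~4 GiB per frame as in app.py.
    if device.type != "cuda":
        return 1  # no VRAM to query on CPU (assumed fallback; the Space always runs on GPU)
    torch.cuda.empty_cache()
    free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
    free_gib = free_bytes / 1024 / 1024 / 1024
    # At least one frame per batch, at most the whole clip.
    return int(max(1, min(free_gib // gib_per_frame, max_frames)))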
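Because inference now happens in batches, frame_indices is what maps each predicted depth map back to its source frame so the Rerun stream stays in sync with the video: rr.set_time_nanos pins every entity logged afterwards to that frame's timestamp, and the VideoFrameReference points at the same instant of the video asset logged at world/video. A condensed sketch of the per-batch logging, with the entity paths and timeline name taken from the diff and the rest assumed:

import rerun as rr

def log_batch(depths, focal_lengths, frame_indices, frame_timestamps_ns):
    # Log one batch of predictions on the same timeline as the source video.
    for depth, focal_length, frame_idx in zip(depths, focal_lengths, frame_indices):
        # focal_length is consumed elsewhere in app.py; it is unused in this sketch.
        # Everything logged below is stamped with this frame's presentation time.
        rr.set_time_nanos("video_time", frame_timestamps_ns[frame_idx])
        rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))
        rr.log(
            "world/camera/frame",
            rr.VideoFrameReference(
                timestamp=rr.components.VideoTimestamp(
                    nanoseconds=frame_timestamps_ns[frame_idx]
                ),
                video_reference="world/video",
            ),
        )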