Spaces:
Running
Running
batch model inference
Browse files
app.py
CHANGED
|
@@ -42,16 +42,23 @@ def resize_image(image_buffer, max_size=256):
|
|
| 42 |
|
| 43 |
|
| 44 |
@spaces.GPU(duration=20)
|
| 45 |
-
def predict_depth(
|
| 46 |
# Preprocess the image
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Run inference
|
| 54 |
-
prediction = model.infer(
|
| 55 |
depth = prediction["depth"] # Depth in [m]
|
| 56 |
focallength_px = prediction["focallength_px"] # Focal length in pixels
|
| 57 |
|
|
@@ -107,62 +114,68 @@ def run_rerun(path_to_video):
|
|
| 107 |
|
| 108 |
# limit the number of frames to 10 seconds of video
|
| 109 |
max_frames = min(10 * fps_video, num_frames)
|
|
|
|
| 110 |
|
| 111 |
-
|
|
|
|
| 112 |
if i >= max_frames:
|
| 113 |
raise gr.Error("Reached the maximum number of frames to process")
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
|
| 120 |
try:
|
| 121 |
-
# Resize the
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
),
|
| 142 |
-
|
| 143 |
-
)
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
),
|
| 157 |
-
)
|
| 158 |
|
| 159 |
-
|
| 160 |
except Exception as e:
|
| 161 |
raise gr.Error(f"An error has occurred: {e}")
|
| 162 |
finally:
|
| 163 |
-
# Clean up the temporary
|
| 164 |
-
|
| 165 |
-
os.
|
|
|
|
| 166 |
|
| 167 |
yield stream.read()
|
| 168 |
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
@spaces.GPU(duration=20)
|
| 45 |
+
def predict_depth(input_images):
|
| 46 |
# Preprocess the image
|
| 47 |
+
results = [depth_pro.load_rgb(image) for image in input_images]
|
| 48 |
+
|
| 49 |
+
# assume load_rgb returns a tuple of (image, f_px)
|
| 50 |
+
# stack the images and f_px values into tensors
|
| 51 |
+
images, f_px = zip(*results)
|
| 52 |
+
images = torch.stack(images)
|
| 53 |
+
f_px = torch.tensor(f_px)
|
| 54 |
+
|
| 55 |
+
images = transform(images)
|
| 56 |
+
|
| 57 |
+
images = images.to(device)
|
| 58 |
+
f_px = f_px.to(device)
|
| 59 |
|
| 60 |
# Run inference
|
| 61 |
+
prediction = model.infer(images, f_px=f_px)
|
| 62 |
depth = prediction["depth"] # Depth in [m]
|
| 63 |
focallength_px = prediction["focallength_px"] # Focal length in pixels
|
| 64 |
|
|
|
|
| 114 |
|
| 115 |
# limit the number of frames to 10 seconds of video
|
| 116 |
max_frames = min(10 * fps_video, num_frames)
|
| 117 |
+
batch_size = min(16, max_frames)
|
| 118 |
|
| 119 |
+
# go through all the frames in the video, using the batch size
|
| 120 |
+
for i in range(0, int(max_frames), batch_size):
|
| 121 |
if i >= max_frames:
|
| 122 |
raise gr.Error("Reached the maximum number of frames to process")
|
| 123 |
|
| 124 |
+
frames = []
|
| 125 |
+
for _ in range(batch_size):
|
| 126 |
+
ret, frame = cap.read()
|
| 127 |
+
if not ret:
|
| 128 |
+
break
|
| 129 |
+
frames.append(frame)
|
| 130 |
|
| 131 |
+
temp_files = []
|
| 132 |
try:
|
| 133 |
+
# Resize the images to make the inference faster
|
| 134 |
+
temp_files = [resize_image(frame, max_size=256) for frame in frames]
|
| 135 |
+
|
| 136 |
+
depths, focal_lengths = predict_depth(temp_files)
|
| 137 |
+
|
| 138 |
+
for depth, focal_length in zip(depths, focal_lengths):
|
| 139 |
+
# find x and y scale factors, which can be applied to image
|
| 140 |
+
x_scale = depth.shape[1] / frames[0].shape[1]
|
| 141 |
+
y_scale = depth.shape[0] / frames[0].shape[0]
|
| 142 |
+
|
| 143 |
+
rr.log(
|
| 144 |
+
"world/camera/depth",
|
| 145 |
+
rr.DepthImage(depth, meter=1),
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
rr.log(
|
| 149 |
+
"world/camera/frame",
|
| 150 |
+
rr.VideoFrameReference(
|
| 151 |
+
timestamp=rr.components.VideoTimestamp(
|
| 152 |
+
nanoseconds=frame_timestamps_ns[i]
|
| 153 |
+
),
|
| 154 |
+
video_reference="world/video",
|
| 155 |
),
|
| 156 |
+
rr.Transform3D(scale=(x_scale, y_scale, 1)),
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
rr.log(
|
| 160 |
+
"world/camera",
|
| 161 |
+
rr.Pinhole(
|
| 162 |
+
focal_length=focal_length,
|
| 163 |
+
width=depth.shape[1],
|
| 164 |
+
height=depth.shape[0],
|
| 165 |
+
principal_point=(depth.shape[1] / 2, depth.shape[0] / 2),
|
| 166 |
+
camera_xyz=rr.ViewCoordinates.FLU,
|
| 167 |
+
image_plane_distance=depth.max(),
|
| 168 |
+
),
|
| 169 |
+
)
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
yield stream.read()
|
| 172 |
except Exception as e:
|
| 173 |
raise gr.Error(f"An error has occurred: {e}")
|
| 174 |
finally:
|
| 175 |
+
# Clean up the temporary files
|
| 176 |
+
for temp_file in temp_files:
|
| 177 |
+
if temp_file and os.path.exists(temp_file):
|
| 178 |
+
os.remove(temp_file)
|
| 179 |
|
| 180 |
yield stream.read()
|
| 181 |
|