Update app.py
Browse files
app.py
CHANGED
|
@@ -24,9 +24,16 @@ def predict_depth(image):
|
|
| 24 |
img = transform(image).unsqueeze(0).to(device)
|
| 25 |
with torch.no_grad():
|
| 26 |
pred = model(img)[0, 0].cpu().numpy()
|
|
|
|
| 27 |
pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
|
| 28 |
-
pred_image = Image.fromarray((pred_normalized * 255).astype(np.uint8))
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Gradio UI
|
| 32 |
examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
|
|
@@ -42,7 +49,7 @@ demo = gr.Interface(
|
|
| 42 |
description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
|
| 43 |
examples=examples,
|
| 44 |
allow_flagging="never",
|
| 45 |
-
theme="
|
| 46 |
)
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|
|
|
|
| 24 |
img = transform(image).unsqueeze(0).to(device)
|
| 25 |
with torch.no_grad():
|
| 26 |
pred = model(img)[0, 0].cpu().numpy()
|
| 27 |
+
|
| 28 |
pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
|
| 29 |
+
pred_image = Image.fromarray((pred_normalized * 255).astype(np.uint8)).convert("L")
|
| 30 |
+
|
| 31 |
+
# Resize both for better display (e.g., 512x512)
|
| 32 |
+
upscale_size = (512, 512)
|
| 33 |
+
image_resized = image.resize(upscale_size, resample=Image.NEAREST)
|
| 34 |
+
pred_resized = pred_image.resize(upscale_size, resample=Image.NEAREST)
|
| 35 |
+
|
| 36 |
+
return [image_resized, pred_resized]
|
| 37 |
|
| 38 |
# Gradio UI
|
| 39 |
examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
|
|
|
|
| 49 |
description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
|
| 50 |
examples=examples,
|
| 51 |
allow_flagging="never",
|
| 52 |
+
theme="darkdefault",
|
| 53 |
)
|
| 54 |
|
| 55 |
if __name__ == "__main__":
|