# Gradio demo app for DepthSTAR: upload an RGB image, get a colorized depth map.
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import DepthSTAR
import matplotlib.cm as cm
# Load model
# Prefer GPU when available; map_location below keeps CPU-only machines working.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthSTAR().to(device)
# NOTE(review): weights are loaded from the current working directory —
# confirm "depth_model_all.pth" ships alongside this script.
model.load_state_dict(torch.load("depth_model_all.pth", map_location=device))
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# Preprocessing
# Inputs are resized to 32x32 — presumably the model's training resolution (TODO confirm).
transform = T.Compose([
T.Resize((32, 32)),
T.ToTensor(),
])
# Larger output display
image_output_size = 256  # display size (pixels) for the Gradio image widgets
def predict_depth(image):
    """Run DepthSTAR on one image and return display-ready PIL images.

    Args:
        image: PIL.Image supplied by the Gradio input widget.

    Returns:
        A two-element list: [input image resized for display,
        inferno-colorized depth map resized for display].
    """
    # Force 3 channels: RGBA or grayscale uploads would otherwise break the
    # 3-channel tensor the model expects.
    image = image.convert("RGB")
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img)[0, 0].cpu().numpy()

    # Normalize prediction to [0, 1]; epsilon guards a constant-valued map.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

    # Apply the inferno colormap. matplotlib.cm.get_cmap was deprecated in
    # 3.7 and removed in 3.9 — prefer the colormap registry when present.
    import matplotlib
    if hasattr(matplotlib, "colormaps"):
        cmap = matplotlib.colormaps["inferno"]
    else:
        cmap = cm.get_cmap('inferno')
    pred_colored = cmap(pred_normalized)[:, :, :3]  # drop alpha channel
    pred_colored = (pred_colored * 255).astype(np.uint8)
    pred_pil = Image.fromarray(pred_colored)

    # Upscale both images to the shared display size; NEAREST preserves the
    # blocky look of the low-resolution prediction instead of blurring it.
    upscale_size = (image_output_size, image_output_size)
    image_resized = image.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)
    return [image_resized, pred_resized]
# Gradio UI
# Example images are resolved relative to the working directory —
# TODO confirm img_000.png … img_005.png are shipped with the app.
examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
# examples = [["example.png"]]
demo = gr.Interface(
fn=predict_depth,
inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),
# Two outputs, matching the two-element list predict_depth returns.
outputs=[
gr.Image(type="pil", label="Original Image", height=image_output_size),
gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),
],
title="🔭 DepthStar: Light-weight Depth Estimation",
description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
examples=examples,
# NOTE(review): "darkdefault" is a legacy string theme; recent Gradio
# releases expect gr.themes.* objects — verify against installed version.
theme="darkdefault",
)
if __name__ == "__main__":
    # Launch the Gradio server only when executed as a script (not on import).
    # (Removed a stray trailing "|" that made this line a syntax error.)
    demo.launch()