File size: 1,965 Bytes
98141d1
 
 
 
 
f54cbf4
4c8e425
98141d1
 
 
f54cbf4
98141d1
 
 
 
 
 
 
 
 
f7fa060
71e0311
f7fa060
98141d1
 
 
 
4c8e425
 
98141d1
4817508
4c8e425
 
 
 
 
 
 
71e0311
4817508
4c8e425
4817508
 
98141d1
4c8e425
98141d1
7364c33
 
98141d1
 
 
50e087e
98141d1
f7fa060
 
98141d1
 
 
 
4817508
98141d1
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import DepthSTAR
import matplotlib.cm as cm

# ---- Model setup ----
# Prefer GPU when available; the checkpoint is remapped onto the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthSTAR().to(device)
# weights_only=True restricts torch.load to tensor/primitive data, avoiding
# arbitrary-code execution through pickle. A plain state_dict checkpoint (as
# loaded here) is unaffected. NOTE(review): requires torch >= 1.13 — confirm.
model.load_state_dict(
    torch.load("depth_model_all.pth", map_location=device, weights_only=True)
)
model.eval()  # inference mode: fixes dropout / batch-norm statistics

# Preprocessing: the model consumes 32x32 tensors scaled to [0, 1] by ToTensor.
transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])

# Side length in pixels used for every displayed image in the UI.
image_output_size = 256

def predict_depth(image):
    """Run the DepthSTAR model on one image and return display-ready results.

    Args:
        image: PIL image supplied by Gradio (``type="pil"``).

    Returns:
        A two-element list: the input image and the inferno-colored depth
        prediction, both resized to ``image_output_size`` for display.
    """
    # Force 3 channels so RGBA/grayscale uploads don't break the model input.
    img = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img)[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1]; the epsilon guards a constant prediction
    # (max == min) against division by zero.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

    # Apply the inferno colormap. cm.inferno is used directly because
    # cm.get_cmap() was deprecated in matplotlib 3.7 and removed in 3.9.
    pred_colored = cm.inferno(pred_normalized)[:, :, :3]  # drop alpha channel
    pred_colored = (pred_colored * 255).astype(np.uint8)
    pred_pil = Image.fromarray(pred_colored)

    # Upscale both images to the shared display size; NEAREST keeps the
    # low-resolution prediction's pixels crisp instead of smearing them.
    upscale_size = (image_output_size, image_output_size)
    image_resized = image.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)

    return [image_resized, pred_resized]


# Gradio UI
examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],]
# examples = [["example.png"]]

demo = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),
    outputs=[
        gr.Image(type="pil", label="Original Image", height=image_output_size),
        gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),
    ],
    title="🔭 DepthStar: Light-weight Depth Estimation",
    description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
    examples=examples,
    theme="darkdefault",
)

if __name__ == "__main__":
    demo.launch()