|
|
import torch |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torchvision.transforms as T |
|
|
from model import DepthSTAR |
|
|
import matplotlib.cm as cm |
|
|
|
|
|
|
|
|
# Select the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the DepthSTAR network, move it to the chosen device, and restore the
# trained weights from the checkpoint shipped alongside this script.
model = DepthSTAR().to(device)
state_dict = torch.load("depth_model_all.pth", map_location=device)
model.load_state_dict(state_dict)

# Inference mode: freezes dropout / batch-norm statistics for prediction.
model.eval()
|
|
|
|
|
|
|
|
# Preprocessing for every incoming image: shrink to the 32x32 resolution the
# model was trained on, then convert to a float tensor scaled to [0, 1].
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])

# Display size in pixels for the Gradio image panes (input and output alike).
image_output_size = 256
|
|
|
|
|
def predict_depth(image):
    """Run DepthSTAR on one RGB image and return [resized input, depth map].

    Args:
        image: PIL.Image supplied by the Gradio image input (RGB assumed —
            TODO confirm the model was trained on 3-channel input).

    Returns:
        A two-element list of 256x256 PIL images: the nearest-neighbour
        upscaled input and the 'inferno'-colorized predicted depth map.
    """
    # Preprocess to a [1, C, 32, 32] batch tensor on the model's device.
    img = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # First (only) batch item, first channel, as a 2-D numpy array.
        pred = model(img)[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1]; the epsilon guards against a division by
    # zero when the predicted depth map is constant.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

    # Colorize with the 'inferno' colormap. Use the colormap object directly:
    # cm.get_cmap() was deprecated in matplotlib 3.7 and removed in 3.9, so
    # the previous cm.get_cmap('inferno') call breaks on current matplotlib.
    cmap = cm.inferno
    pred_colored = cmap(pred_normalized)[:, :, :3]  # drop the alpha channel
    pred_colored = (pred_colored * 255).astype(np.uint8)
    pred_pil = Image.fromarray(pred_colored)

    # Upscale both images with nearest-neighbour so the coarse 32x32 depth
    # grid stays crisp instead of being blurred by interpolation. Reuse the
    # module-level display constant rather than a second hard-coded 256.
    upscale_size = (image_output_size, image_output_size)
    image_resized = image.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)

    return [image_resized, pred_resized]
|
|
|
|
|
|
|
|
|
|
|
examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],] |
|
|
|
|
|
|
|
|
# Gradio UI: one RGB image in, two images out (the resized original and the
# colorized depth prediction), with the bundled example images listed below.
demo = gr.Interface(

    fn=predict_depth,

    inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),

    outputs=[

        gr.Image(type="pil", label="Original Image", height=image_output_size),

        gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),

    ],

    title="🔭 DepthStar: Light-weight Depth Estimation",

    description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",

    examples=examples,

    # NOTE(review): "darkdefault" was a Gradio 2.x theme name; modern Gradio
    # expects a Theme object or names like "default"/"soft" — confirm this
    # string is accepted by the installed gradio version.
    theme="darkdefault",

)
|
|
|
|
|
if __name__ == "__main__":

    # Start the Gradio server only when run as a script (not on import).
    demo.launch()