# depthstar / app.py — Gradio Space entry point.
# Uploaded by keivalya — "Update app.py" (commit 7364c33, verified).
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import DepthSTAR
import matplotlib.cm as cm
# Load model
# Prefer GPU when available; inputs are moved to the same device in predict_depth.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthSTAR().to(device)
# map_location lets a GPU-trained checkpoint load on CPU-only hosts.
model.load_state_dict(torch.load("depth_model_all.pth", map_location=device))
# Inference mode: freezes dropout / batch-norm statistics.
model.eval()
# Preprocessing: downsample uploads to the 32x32 resolution the model was trained on.
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])

# Side length (pixels) of the enlarged images shown in the UI.
image_output_size = 256
def predict_depth(image):
    """Run the depth model on one image and return display-ready results.

    Args:
        image: Input PIL image. Any mode is accepted; it is converted to RGB
            before inference so RGBA/grayscale uploads do not crash the model.

    Returns:
        A two-element list of PIL images: the input resized for display and
        the inferno-colored predicted depth map.
    """
    # Normalize upload mode to 3-channel RGB before the tensor transform.
    rgb = image.convert("RGB")
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(batch)[0, 0].cpu().numpy()
    # Rescale prediction to [0, 1]; epsilon guards a constant-depth map.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    # Use the colormap object directly: cm.get_cmap() is deprecated and was
    # removed in matplotlib 3.9, so attribute access is version-safe.
    pred_colored = cm.inferno(pred_normalized)[:, :, :3]  # drop alpha channel
    pred_pil = Image.fromarray((pred_colored * 255).astype(np.uint8))
    # Upscale with NEAREST so the coarse 32x32 prediction stays crisp;
    # reuse the module-level display constant instead of a duplicated literal.
    upscale_size = (image_output_size, image_output_size)
    image_resized = rgb.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)
    return [image_resized, pred_resized]
# Gradio UI
# Bundled sample frames img_000.png .. img_005.png shown under the input widget.
examples = [[f"img_{i:03d}.png"] for i in range(6)]

demo = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),
    outputs=[
        gr.Image(type="pil", label="Original Image", height=image_output_size),
        gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),
    ],
    title="🔭 DepthStar: Light-weight Depth Estimation",
    description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
    examples=examples,
    theme="darkdefault",
)

if __name__ == "__main__":
    demo.launch()