"""Gradio demo: run the tiny DepthSTAR model on an uploaded RGB image
and display the predicted depth map with an inferno colormap."""

import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import DepthSTAR
import matplotlib  # colormap registry (cm.get_cmap was removed in matplotlib 3.9)

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthSTAR().to(device)
# weights_only=True: only deserialize tensors, never arbitrary pickled objects.
model.load_state_dict(
    torch.load("depth_model_all.pth", map_location=device, weights_only=True)
)
model.eval()

# Preprocessing: model expects a 32x32 RGB tensor.
transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])

# Side length (pixels) used both for the upscaled outputs and the UI widgets.
image_output_size = 256


def predict_depth(image):
    """Predict a depth map for one image.

    Args:
        image: PIL.Image uploaded via Gradio (any mode; converted to RGB).

    Returns:
        [original_resized, depth_colored]: two PIL images, both upscaled to
        ``image_output_size`` square with nearest-neighbor resampling.
    """
    # Force 3 channels: RGBA/grayscale uploads would otherwise produce a
    # tensor with the wrong channel count and crash the model.
    image = image.convert("RGB")
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img)[0, 0].cpu().numpy()

    # Normalize prediction to [0, 1]; epsilon guards a constant-depth map.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

    # Apply inferno colormap via the registry (get_cmap is gone in mpl>=3.9).
    cmap = matplotlib.colormaps["inferno"]
    pred_colored = cmap(pred_normalized)[:, :, :3]  # drop alpha channel
    pred_colored = (pred_colored * 255).astype(np.uint8)
    pred_pil = Image.fromarray(pred_colored)

    # Upscale for display (nearest keeps the low-res prediction blocky/honest).
    upscale_size = (image_output_size, image_output_size)
    image_resized = image.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)

    return [image_resized, pred_resized]


# Gradio UI
examples = [
    ["img_000.png"],
    ["img_001.png"],
    ["img_002.png"],
    ["img_003.png"],
    ["img_004.png"],
    ["img_005.png"],
]

demo = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),
    outputs=[
        gr.Image(type="pil", label="Original Image", height=image_output_size),
        gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),
    ],
    title="🔭 DepthStar: Light-weight Depth Estimation",
    description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
    examples=examples,
    # NOTE(review): "darkdefault" is a legacy (Gradio 2.x) theme name; modern
    # Gradio falls back to the default theme if it is unrecognized — verify
    # against the installed Gradio version.
    theme="darkdefault",
)

if __name__ == "__main__":
    demo.launch()