# ApDepth / app.py — Gradio demo for the developy/ApDepth monocular depth model.
# Author: developy · last commit a7638ff ("Update app.py").
import gradio as gr
from diffusers import MarigoldDepthPipeline, DDIMScheduler
import torch
from PIL import Image
# Hugging Face Hub repo holding the ApDepth weights (loaded as a Marigold depth pipeline).
CHECKPOINT = "developy/ApDepth"
# Prefer a GPU when one is present; otherwise fall back to CPU (e.g. a CPU-only Space).
# float32 is kept on both so output does not depend on reduced-precision arithmetic.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
# Replace the checkpoint's scheduler with DDIM using "trailing" timestep spacing,
# built from the original scheduler's config so its other settings carry over.
# NOTE(review): presumably chosen to match Marigold's few-step inference setup —
# confirm against the checkpoint's model card.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe = pipe.to(device=device, dtype=dtype)
def predict(image: Image.Image) -> Image.Image:
    """Estimate depth for one image and return a colorized depth map.

    Args:
        image: Input image from the ``gr.Image(type="pil")`` component. Gradio
            passes ``None`` when the user submits without uploading anything.

    Returns:
        The first colorized depth visualization produced by
        ``pipe.image_processor.visualize_depth`` for this image.

    Raises:
        gr.Error: If no image was provided. Gradio shows the message in the UI
            instead of an opaque traceback from inside the pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalize RGBA / palette / other modes to 3-channel RGB; Marigold's image
    # preprocessing expects 1- or 3-channel input (per diffusers'
    # MarigoldImageProcessor), so e.g. an RGBA PNG passed directly would be rejected.
    if image.mode != "RGB":
        image = image.convert("RGB")
    out = pipe(image)
    # visualize_depth yields one visualization per batch item; a single image was
    # passed, so the first entry is the result.
    depth_vis = pipe.image_processor.visualize_depth(out.prediction)[0]
    return depth_vis
# Build the I/O components first, in the same order as before (input, then output),
# so the Interface call below is just the wiring of predict() between them.
_input_image = gr.Image(type="pil", label="Input Image")
_depth_map = gr.Image(type="pil", label="Depth Map")

# Single-function web UI: upload an image, get back its colorized depth map.
demo = gr.Interface(
    fn=predict,
    inputs=_input_image,
    outputs=_depth_map,
    title="ApDepth Demo",
    description="Monocular Depth Estimation based on Marigold",
)
def main() -> None:
    """Start the Gradio server hosting the ApDepth demo."""
    demo.launch()


if __name__ == "__main__":
    main()