|
|
from zoedepth.utils.misc import colorize, save_raw_16bit |
|
|
import torch |
|
|
import gradio as gr |
|
|
import spaces |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from functools import partial |
|
|
|
|
|
def save_raw_16bit(depth, fpath="raw.png"):
    """Scale a 2-D depth map by 256 and return it as a uint16 array.

    Local replacement for the ``save_raw_16bit`` imported from
    ``zoedepth.utils.misc``: this definition shadows that import and, despite
    its name, writes nothing to disk -- it only returns the scaled array.

    Args:
        depth: 2-D ``numpy.ndarray`` or ``torch.Tensor``. Tensors are detached,
            squeezed and moved to the CPU first, so a tensor shaped e.g.
            ``(1, H, W)`` is accepted; numpy arrays must already be 2-D.
        fpath: Ignored. Kept only so the signature matches the shadowed helper.

    Returns:
        ``numpy.ndarray`` of dtype ``uint16`` holding ``depth * 256``,
        clipped to ``[0, 65535]`` and truncated toward zero.

    Raises:
        TypeError: If ``depth`` is neither a torch tensor nor a numpy array.
        ValueError: If ``depth`` is not 2-D after conversion.
    """
    if isinstance(depth, torch.Tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        depth = depth.detach().squeeze().cpu().numpy()

    # Explicit raises rather than assert, which is stripped under `python -O`.
    if not isinstance(depth, np.ndarray):
        raise TypeError("Depth must be a torch tensor or numpy array")
    if depth.ndim != 2:
        raise ValueError("Depth must be 2D")

    # Scale in float64: `uint8_array * 256` raises OverflowError under NumPy 2
    # promotion rules (NEP 50), since 256 does not fit in uint8. Multiplying by
    # a power of two is exact in floating point, so results match the old path.
    scaled = depth.astype(np.float64) * 256
    # Clip before casting: converting out-of-range floats to uint16 is
    # undefined in NumPy (values may wrap or become garbage).
    return np.clip(scaled, 0, np.iinfo(np.uint16).max).astype(np.uint16)
|
|
|
|
|
|
|
|
@spaces.GPU(enable_queue=True)
def process_image(model, image: Image.Image):
    """Estimate depth for one uploaded image and return it as a PIL image.

    Args:
        model: Object exposing ``infer_pil(image)``. Presumably a ZoeDepth
            model already placed on its device -- TODO confirm against caller.
        image: The uploaded picture. Gradio passes ``None`` when the user
            clicks Generate without uploading anything.

    Returns:
        ``PIL.Image.Image`` built from the uint16 array that
        ``save_raw_16bit`` produces from the first channel of
        ``colorize(depth)``.

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable
            message instead of an AttributeError traceback.
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # Normalise uploads (RGBA, palette, grayscale, ...) to 3-channel RGB.
    image = image.convert("RGB")

    depth = model.infer_pil(image)

    # colorize() output is indexed as (H, W, C); only channel 0 is kept.
    # NOTE(review): presumably colorize returns uint8 with a grayscale default
    # colormap, so channel 0 carries the full intensity -- confirm.
    first_channel = colorize(depth)[:, :, 0]

    # Promote to float64 before save_raw_16bit multiplies by 256: with a uint8
    # array NumPy 2 (NEP 50) raises OverflowError on the Python int 256.
    # float64 holds every channel value exactly, so the output is unchanged.
    processed_array = save_raw_16bit(np.asarray(first_channel, dtype=np.float64))

    return Image.fromarray(processed_array)
|
|
|
|
|
def depth_interface(model, device):
    """Build the depth-estimation widgets and wire up the Generate button.

    Places an "Input Image" and a "Depth Map" image side by side in a row,
    then adds a Generate button after the row. Clicks are routed to
    ``process_image``, with the result of ``model.to(device)`` pre-bound as
    its first argument. The handler is exposed under the API name
    ``generate_depth``.

    NOTE(review): presumably called inside an enclosing ``gr.Blocks`` (or
    tab) context by the app entry point, which the components attach to.
    Confirm.

    Args:
        model: Depth model; whatever ``model.to(device)`` returns is bound.
        device: Target passed to ``model.to``.
    """
    with gr.Row():
        source_image = gr.Image(label="Input Image", type='pil')
        depth_image = gr.Image(label="Depth Map", type='pil')

    run_button = gr.Button(value="Generate")

    # Bind the device-placed model once so each click only supplies the image.
    predict = partial(process_image, model.to(device))
    run_button.click(
        predict,
        inputs=source_image,
        outputs=depth_image,
        api_name="generate_depth",
    )
|
|
|