Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from torchvision.transforms import Compose | |
| import cv2 | |
| from dpt.models import DPTDepthModel, DPTSegmentationModel | |
| from dpt.transforms import Resize, NormalizeImage, PrepareForNet | |
| import os | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

# Checkpoint files used to build the depth model and the segmentation model.
default_models = dict(
    dpt_hybrid="weights/dpt_hybrid-midas-501f0c75.pt",
    segment_hybrid="weights/dpt_hybrid-ade20k-53898607.pt",
)

# Enable cuDNN and let it benchmark/cache the fastest convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Depth model: DPT with the hybrid ViT/ResNet-50 backbone, using the
# checkpoint registered under "dpt_hybrid" (a MiDaS checkpoint, per its name).
depth_model = DPTDepthModel(
    path=default_models["dpt_hybrid"],
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)

# Segmentation model with 150 output classes (ADE20K, judging by the
# checkpoint file name — confirm against the weights).
seg_model = DPTSegmentationModel(
    150,
    path=default_models["segment_hybrid"],
    backbone="vitb_rn50_384",
)

# Put both networks in inference mode and move them to the selected device.
for _net in (depth_model, seg_model):
    _net.eval()
    _net.to(device)
del _net
# Pre-processing shared by DPT() and Segment(): resize (keeping the aspect
# ratio, sides a multiple of 32, bicubic), normalize every channel with
# mean 0.5 / std 0.5, then PrepareForNet converts the sample for the network.
net_w = net_h = 384
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
_resize = Resize(
    net_w,
    net_h,
    resize_target=None,
    keep_aspect_ratio=True,
    ensure_multiple_of=32,
    resize_method="minimal",
    image_interpolation_method=cv2.INTER_CUBIC,
)
transform = Compose([_resize, normalization, PrepareForNet()])
def write_depth(depth, bits=1, absolute_depth=False):
    """Quantize a depth map into an unsigned-integer image array.

    Despite its name, this function writes nothing to disk; it only returns
    the converted array.

    Args:
        depth (array): depth / disparity values.
        bits (int): number of BYTES per output pixel: 1 -> ``uint8``,
            2 -> ``uint16``.
        absolute_depth (bool): if True, use ``depth`` as-is (clipped to the
            output type's range) instead of min-max normalizing it to
            ``[0, 2 ** (8 * bits) - 1]``.

    Returns:
        np.ndarray: ``uint8`` (bits=1) or ``uint16`` (bits=2) array with the
        same shape as ``depth``.

    Raises:
        ValueError: if ``bits`` is not 1 or 2.
    """
    if bits not in (1, 2):
        raise ValueError(f"bits must be 1 or 2, got {bits!r}")
    out_dtype = "uint8" if bits == 1 else "uint16"

    depth = np.asarray(depth)
    if depth.size == 0:
        # min()/max() are undefined on an empty array.
        return depth.astype(out_dtype)

    max_val = (2 ** (8 * bits)) - 1
    if absolute_depth:
        # Saturate instead of letting negative / too-large values wrap around
        # when cast to an unsigned integer type.
        out = np.clip(depth, 0, max_val)
    else:
        depth_min = depth.min()
        depth_max = depth.max()
        if depth_max - depth_min > np.finfo("float").eps:
            out = max_val * (depth - depth_min) / (depth_max - depth_min)
        else:
            # Constant map: nothing to stretch, emit all zeros.
            out = np.zeros(depth.shape, dtype=depth.dtype)
    return out.astype(out_dtype)
def DPT(image):
    """Estimate a relative depth map for ``image`` with ``depth_model``.

    Args:
        image: RGB image from ``gr.Image`` (presumably an ``H x W x 3``
            ``uint8`` numpy array), or ``None`` when the input is empty.

    Returns:
        ``uint16`` array of shape ``image.shape[:2]`` holding the min-max
        normalized prediction, or ``None`` if no image was given.
    """
    if image is None:
        return None
    image = np.asarray(image)
    # The mean=0.5 / std=0.5 normalization in `transform` assumes RGB values
    # in [0, 1] (the DPT reference scripts read images that way); 8-bit input
    # would otherwise reach the network scaled to roughly [-1, 509].
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.0

    img_input = transform({"image": image})["image"]
    with torch.no_grad():
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        prediction = depth_model.forward(sample)
        # Upsample the network output back to the input resolution.
        prediction = (
            torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=image.shape[:2],
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )
    return write_depth(prediction, bits=2)
def Segment(image):
    """Predict a per-pixel semantic label map for ``image`` with ``seg_model``.

    Args:
        image: RGB image from ``gr.Image`` (presumably an ``H x W x 3``
            ``uint8`` numpy array), or ``None`` when the input is empty.

    Returns:
        Array of shape ``image.shape[:2]`` whose values are the 1-based index
        of the highest-scoring class at each pixel (``uint8`` whenever the
        class count fits), or ``None`` if no image was given.
    """
    if image is None:
        return None
    image = np.asarray(image)
    # The mean=0.5 / std=0.5 normalization in `transform` assumes RGB values
    # in [0, 1]; rescale 8-bit input accordingly.
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.0

    img_input = transform({"image": image})["image"]
    with torch.no_grad():
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        out = seg_model.forward(sample)
        # Upsample the class scores to the input resolution, then pick the
        # best class per pixel (shifted to be 1-based).
        logits = torch.nn.functional.interpolate(
            out, size=image.shape[:2], mode="bicubic", align_corners=False
        )
        labels = torch.argmax(logits, dim=1) + 1
        labels = labels.squeeze().cpu().numpy()

    # Labels lie in 1..num_classes. When that fits, return uint8: gr.Image
    # presumably rescales int64 arrays to 8 bits on display, collapsing small
    # label values to black, while uint8 is shown unchanged.
    if out.shape[1] <= np.iinfo(np.uint8).max:
        labels = labels.astype(np.uint8)
    return labels
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Depth Estimation"


def _list_examples(folder="examples"):
    """Return sorted ``[[path], ...]`` example rows for the files in ``folder``.

    Hidden entries (e.g. ``.DS_Store``) and sub-directories are skipped, and
    a missing folder yields an empty list instead of crashing at startup.
    """
    if not os.path.isdir(folder):
        return []
    rows = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        if not name.startswith(".") and os.path.isfile(path):
            rows.append([path])
    return rows


example_list = _list_examples()

with gr.Blocks(title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        # Left column: model outputs.
        with gr.Column():
            im_2 = gr.Image(label="Depth Image")
            im_3 = gr.Image(label="Segment Image")
        # Right column: input image and the action buttons.
        with gr.Column():
            im = gr.Image(label="Input Image")
            btn1 = gr.Button(value="Depth Estimator")
            btn1.click(DPT, inputs=[im], outputs=[im_2])
            btn2 = gr.Button(value="Segment")
            btn2.click(Segment, inputs=[im], outputs=[im_3])
            # Only render the examples gallery when there is something to show.
            if example_list:
                gr.Examples(examples=example_list, inputs=[im], outputs=[im_2])

if __name__ == "__main__":
    demo.launch()