gangweix commited on
Commit
128e81a
·
verified ·
1 Parent(s): b14aa39

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -112
app.py DELETED
@@ -1,112 +0,0 @@
1
- import gradio as gr
2
- import cv2
3
- import matplotlib
4
- import numpy as np
5
- import os
6
- from PIL import Image
7
- import spaces
8
- import torch
9
- import tempfile
10
- from gradio_imageslider import ImageSlider
11
- from huggingface_hub import hf_hub_download
12
-
13
- from ppd.utils.set_seed import set_seed
14
- from ppd.models.ppd import PixelPerfectDepth
15
-
16
- css = """
17
- #img-display-container {
18
- max-height: 100vh;
19
- }
20
- #img-display-input {
21
- max-height: 100vh;
22
- }
23
- #img-display-output {
24
- max-height: 100vh;
25
- }
26
- #download {
27
- height: 62px;
28
- }
29
-
30
- #img-display-output .image-slider-image {
31
- object-fit: contain !important;
32
- width: 100% !important;
33
- height: 100% !important;
34
- }
35
- """
36
-
37
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
38
- model = PixelPerfectDepth(sampling_steps=10)
39
- ckpt_path = hf_hub_download(
40
- repo_id="gangweix/Pixel-Perfect-Depth",
41
- filename="ppd.pth",
42
- repo_type="model"
43
- )
44
- state_dict = torch.load(ckpt_path, map_location="cpu")
45
- model.load_state_dict(state_dict, strict=False)
46
- model = model.to(DEVICE).eval()
47
-
48
-
49
-
50
- title = "# Pixel-Perfect Depth"
51
- description = """Official demo for **Pixel-Perfect Depth**.
52
- Please refer to our [paper](), [project page](https://pixel-perfect-depth.github.io), and [github](https://github.com/gangweix/pixel-perfect-depth) for more details."""
53
-
54
- @spaces.GPU
55
- def predict_depth(image):
56
- return model.infer_image(image)
57
-
58
- with gr.Blocks(css=css) as demo:
59
- gr.Markdown(title)
60
- gr.Markdown(description)
61
- gr.Markdown("### Depth Prediction demo")
62
-
63
- with gr.Row():
64
- input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
65
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
66
- submit = gr.Button(value="Predict Depth")
67
-
68
- concat_file = gr.File(label="Concatenated visualization (image+depth)", elem_id="image-depth-download")
69
- raw_file = gr.File(label="Raw depth output (saved as .npy)", elem_id="download",)
70
-
71
- cmap = matplotlib.colormaps.get_cmap('Spectral')
72
-
73
- def on_submit(image):
74
- original_image = image.copy()
75
-
76
- depth = predict_depth(image[:, :, ::-1])
77
-
78
- # save raw depth (npy)
79
- tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
80
- np.save(tmp_raw_depth.name, depth)
81
-
82
-
83
- depth_vis = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
84
- depth_vis = depth_vis.astype(np.uint8)
85
- colored_depth = (cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8)
86
-
87
- split_region = np.ones((image.shape[0], 50, 3), dtype=np.uint8) * 255
88
- combined_result = cv2.hconcat([image[:, :, ::-1], split_region, colored_depth[:, :, ::-1]])
89
- tmp_concat = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
90
- cv2.imwrite(tmp_concat.name, combined_result)
91
-
92
- return [(original_image, colored_depth), tmp_concat.name, tmp_raw_depth.name]
93
-
94
- submit.click(
95
- on_submit,
96
- inputs=[input_image],
97
- outputs=[depth_image_slider, concat_file, raw_file]
98
- )
99
-
100
- example_files = os.listdir('assets/examples')
101
- example_files.sort()
102
- example_files = [os.path.join('assets/examples', filename) for filename in example_files]
103
- examples = gr.Examples(
104
- examples=example_files,
105
- inputs=[input_image],
106
- outputs=[depth_image_slider, concat_file, raw_file],
107
- fn=on_submit
108
- )
109
-
110
-
111
- if __name__ == '__main__':
112
- demo.queue().launch(share=True)