RaynWu2002 commited on
Commit
3dcfad8
Β·
verified Β·
1 Parent(s): 9778377

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -32
app.py CHANGED
@@ -10,6 +10,7 @@ from net.dornet_ddp import Net_ddp
10
 
11
  # init
12
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
13
  net = Net(tiny_model=False).to(device)
14
  model_ckpt_map = {
15
  "RGB-D-D": "./checkpoints/RGBDD.pth",
@@ -17,6 +18,7 @@ model_ckpt_map = {
17
  }
18
 
19
  # load model
 
20
  def load_model(model_type: str):
21
  global net
22
  ckpt_path = model_ckpt_map[model_type]
@@ -35,6 +37,7 @@ load_model("RGB-D-D")
35
 
36
 
37
  # data process
 
38
  def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
39
  image = np.array(rgb_image.convert("RGB")).astype(np.float32)
40
  h, w, _ = image.shape
@@ -56,6 +59,7 @@ def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
56
 
57
 
58
  # model inference
 
59
  @torch.no_grad()
60
  def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
61
  load_model(model_type) # reset weight
@@ -80,38 +84,6 @@ def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
80
  return pred_gray, Image.fromarray(pred_heat)
81
 
82
 
83
- # Gradio
84
- # demo = gr.Interface(
85
- # fn=infer,
86
- # inputs=[
87
- # gr.Image(label="RGB Image", type="pil"),
88
- # gr.Image(label="Low-res Depth", type="pil", image_mode="I"),
89
- # gr.Dropdown(choices=["RGB-D-D", "TOFDSR"], label="Model Type", value="RGB-D-D")
90
- # ],
91
- # outputs=[
92
- # gr.Image(label="DORNet Output", type="pil", elem_classes=["output-image"]),
93
- # gr.Image(label="Normalized Output", type="pil", elem_classes=["output-image"])
94
- # ],
95
- # examples=[
96
- # ["examples/RGB-D-D/20200518160957_RGB.jpg", "examples/RGB-D-D/20200518160957_LR_fill_depth.png", "RGB-D-D"],
97
- # ["examples/TOFDSR/2020_09_08_13_59_59_435_rgb_rgb_crop.png", "examples/TOFDSR/2020_09_08_13_59_59_435_rgb_depth_crop_fill.png", "TOFDSR"],
98
- # ],
99
- # allow_flagging="never",
100
- # title="DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution \n CVPR 2025 (Oral Presentation)",
101
- # css="""
102
- # .output-image {
103
- # display: flex;
104
- # justify-content: center;
105
- # align-items: center;
106
- # }
107
- # .output-image img {
108
- # margin: auto;
109
- # display: block;
110
- # }
111
- # """
112
- # )
113
- #
114
- # demo.launch(share=True)
115
  Intro = """
116
  ## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
117
  [πŸ“„ Paper](https://arxiv.org/pdf/2410.11666) β€’ [πŸ’» Code](https://github.com/yanzq95/DORNet) β€’ [πŸ“¦ Model](https://huggingface.co/wzxwyx/DORNet/tree/main)
 
10
 
11
  # init
12
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
+ print(device)
14
  net = Net(tiny_model=False).to(device)
15
  model_ckpt_map = {
16
  "RGB-D-D": "./checkpoints/RGBDD.pth",
 
18
  }
19
 
20
  # load model
21
+ @spaces.GPU
22
  def load_model(model_type: str):
23
  global net
24
  ckpt_path = model_ckpt_map[model_type]
 
37
 
38
 
39
  # data process
40
+ @spaces.GPU
41
  def preprocess_inputs(rgb_image: Image.Image, lr_depth: Image.Image):
42
  image = np.array(rgb_image.convert("RGB")).astype(np.float32)
43
  h, w, _ = image.shape
 
59
 
60
 
61
  # model inference
62
+ @spaces.GPU
63
  @torch.no_grad()
64
  def infer(rgb_image: Image.Image, lr_depth: Image.Image, model_type: str):
65
  load_model(model_type) # reset weight
 
84
  return pred_gray, Image.fromarray(pred_heat)
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  Intro = """
88
  ## DORNet: A Degradation Oriented and Regularized Network for Blind Depth Super-Resolution
89
  [πŸ“„ Paper](https://arxiv.org/pdf/2410.11666) β€’ [πŸ’» Code](https://github.com/yanzq95/DORNet) β€’ [πŸ“¦ Model](https://huggingface.co/wzxwyx/DORNet/tree/main)