Tingman committed on
Commit
d657660
·
1 Parent(s): 5febdd3

add compile call

Browse files
Files changed (1) hide show
  1. gradio_app.py +53 -4
gradio_app.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn.functional as F
6
  import os
7
  import time
8
  import spaces
 
9
 
10
  from dataloader.stereo import transforms
11
  from utils.utils import InputPadder, calc_noc_mask
@@ -37,11 +38,18 @@ class MatchStereoDemo:
37
 
38
  def load_model(self, mode, variant, precision, mat_impl):
39
  """load model, skip if the model has been loaded"""
 
 
 
 
 
 
40
  if (self.model is not None and
41
  self.current_variant == variant and
42
  self.current_mode == mode and
43
  self.current_precision == precision and
44
- self.current_mat_impl == mat_impl):
 
45
  return "Model already loaded"
46
 
47
  # fixed checkpoint path
@@ -65,6 +73,7 @@ class MatchStereoDemo:
65
 
66
  if not self.has_cuda:
67
  precision = "fp32"
 
68
  dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
69
  self.dtype = dtypes[precision]
70
 
@@ -138,6 +147,12 @@ class MatchStereoDemo:
138
  def process_images(self, left_image, right_image, mode, variant,
139
  low_res_init=False, inference_size_name="Original",
140
  precision="fp32", mat_impl="pytorch"):
 
 
 
 
 
 
141
  if not self.has_cuda:
142
  precision = "fp32"
143
  mat_impl = "pytorch"
@@ -247,6 +262,37 @@ class MatchStereoDemo:
247
 
248
  demo_model = MatchStereoDemo()
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  # example images
251
  examples = [
252
  ["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
@@ -280,8 +326,11 @@ with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
280
  gr.Markdown("# MatchStereo/MatchFlow Demo")
281
  gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
282
 
283
- if not demo_model.has_cuda:
 
284
  gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
 
 
285
 
286
  with gr.Row():
287
  with gr.Column():
@@ -321,14 +370,14 @@ with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
321
  label="Precision",
322
  value="fp32",
323
  info="Model precision",
324
- interactive=demo_model.has_cuda
325
  )
326
  mat_impl = gr.Radio(
327
  choices=["cuda", "pytorch"],
328
  label="MatchAttention Implementation",
329
  value="cuda",
330
  info="MatchAttention implementations",
331
- interactive=demo_model.has_cuda
332
  )
333
 
334
  run_btn = gr.Button("Run Inference", variant="primary")
 
6
  import os
7
  import time
8
  import spaces
9
+ import subprocess
10
 
11
  from dataloader.stereo import transforms
12
  from utils.utils import InputPadder, calc_noc_mask
 
38
 
39
  def load_model(self, mode, variant, precision, mat_impl):
40
  """load model, skip if the model has been loaded"""
41
+ current_has_cuda = torch.cuda.is_available()
42
+ if current_has_cuda != self.has_cuda:
43
+ print(f"CUDA status changed: {self.has_cuda} -> {current_has_cuda}")
44
+ self.has_cuda = current_has_cuda
45
+ self.device = "cuda" if self.has_cuda else 'cpu'
46
+
47
  if (self.model is not None and
48
  self.current_variant == variant and
49
  self.current_mode == mode and
50
  self.current_precision == precision and
51
+ self.current_mat_impl == mat_impl and
52
+ self.has_cuda == current_has_cuda):
53
  return "Model already loaded"
54
 
55
  # fixed checkpoint path
 
73
 
74
  if not self.has_cuda:
75
  precision = "fp32"
76
+ mat_impl = "pytorch"
77
  dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
78
  self.dtype = dtypes[precision]
79
 
 
147
  def process_images(self, left_image, right_image, mode, variant,
148
  low_res_init=False, inference_size_name="Original",
149
  precision="fp32", mat_impl="pytorch"):
150
+ current_has_cuda = torch.cuda.is_available()
151
+ if current_has_cuda != self.has_cuda:
152
+ print(f"CUDA status changed before processing: {self.has_cuda} -> {current_has_cuda}")
153
+ self.has_cuda = current_has_cuda
154
+ self.device = "cuda" if self.has_cuda else 'cpu'
155
+
156
  if not self.has_cuda:
157
  precision = "fp32"
158
  mat_impl = "pytorch"
 
262
 
263
  demo_model = MatchStereoDemo()
264
 
265
def compile_cuda_extensions():
    """Compile the project's CUDA extensions by running models/compile.sh.

    Best-effort at startup: any failure is printed and swallowed so the
    app can still launch (callers fall back to the pytorch implementation
    when the CUDA extension is unavailable).

    Returns:
        None. Results are reported via stdout only.
    """
    try:
        print("Start compiling CUDA extension...")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        models_dir = os.path.join(current_dir, "models")
        compile_script = os.path.join(models_dir, "compile.sh")

        if os.path.exists(compile_script):
            # Run the script with cwd= instead of os.chdir()/restore:
            # the original chdir dance left the process in models_dir
            # if subprocess.run raised before the restoring chdir ran.
            result = subprocess.run(
                ["bash", "compile.sh"],
                cwd=models_dir,
                capture_output=True,
                text=True,
            )

            if result.returncode == 0:
                print("CUDA extension compile succeeded!")
                print("output:", result.stdout)
            else:
                print("CUDA extension compile failed!")
                print(result.stderr)
                print(result.stdout)
        else:
            print(f"no compile scripts found: {compile_script}")

    except Exception as e:
        # Never let a compile problem abort app startup.
        print(f"Error during compile: {e}")
293
+
294
+ compile_cuda_extensions()
295
+
296
  # example images
297
  examples = [
298
  ["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
 
326
  gr.Markdown("# MatchStereo/MatchFlow Demo")
327
  gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
328
 
329
+ current_has_cuda = torch.cuda.is_available()
330
+ if not current_has_cuda:
331
  gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
332
+ else:
333
+ gr.Markdown(f"> Note: Running on GPU ({torch.cuda.get_device_name(0)}).")
334
 
335
  with gr.Row():
336
  with gr.Column():
 
370
  label="Precision",
371
  value="fp32",
372
  info="Model precision",
373
+ interactive=current_has_cuda
374
  )
375
  mat_impl = gr.Radio(
376
  choices=["cuda", "pytorch"],
377
  label="MatchAttention Implementation",
378
  value="cuda",
379
  info="MatchAttention implementations",
380
+ interactive=current_has_cuda
381
  )
382
 
383
  run_btn = gr.Button("Run Inference", variant="primary")