Spaces:
Running
Running
add compile call
Browse files- gradio_app.py +53 -4
gradio_app.py
CHANGED
|
@@ -6,6 +6,7 @@ import torch.nn.functional as F
|
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
import spaces
|
|
|
|
| 9 |
|
| 10 |
from dataloader.stereo import transforms
|
| 11 |
from utils.utils import InputPadder, calc_noc_mask
|
|
@@ -37,11 +38,18 @@ class MatchStereoDemo:
|
|
| 37 |
|
| 38 |
def load_model(self, mode, variant, precision, mat_impl):
|
| 39 |
"""load model, skip if the model has been loaded"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
if (self.model is not None and
|
| 41 |
self.current_variant == variant and
|
| 42 |
self.current_mode == mode and
|
| 43 |
self.current_precision == precision and
|
| 44 |
-
self.current_mat_impl == mat_impl
|
|
|
|
| 45 |
return "Model already loaded"
|
| 46 |
|
| 47 |
# fixed checkpoint path
|
|
@@ -65,6 +73,7 @@ class MatchStereoDemo:
|
|
| 65 |
|
| 66 |
if not self.has_cuda:
|
| 67 |
precision = "fp32"
|
|
|
|
| 68 |
dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
|
| 69 |
self.dtype = dtypes[precision]
|
| 70 |
|
|
@@ -138,6 +147,12 @@ class MatchStereoDemo:
|
|
| 138 |
def process_images(self, left_image, right_image, mode, variant,
|
| 139 |
low_res_init=False, inference_size_name="Original",
|
| 140 |
precision="fp32", mat_impl="pytorch"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
if not self.has_cuda:
|
| 142 |
precision = "fp32"
|
| 143 |
mat_impl = "pytorch"
|
|
@@ -247,6 +262,37 @@ class MatchStereoDemo:
|
|
| 247 |
|
| 248 |
demo_model = MatchStereoDemo()
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
# example images
|
| 251 |
examples = [
|
| 252 |
["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
|
|
@@ -280,8 +326,11 @@ with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
|
|
| 280 |
gr.Markdown("# MatchStereo/MatchFlow Demo")
|
| 281 |
gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
|
| 282 |
|
| 283 |
-
|
|
|
|
| 284 |
gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
|
|
|
|
|
|
|
| 285 |
|
| 286 |
with gr.Row():
|
| 287 |
with gr.Column():
|
|
@@ -321,14 +370,14 @@ with gr.Blocks(title="MatchStereo/MatchFlow Demo") as demo:
|
|
| 321 |
label="Precision",
|
| 322 |
value="fp32",
|
| 323 |
info="Model precision",
|
| 324 |
-
interactive=
|
| 325 |
)
|
| 326 |
mat_impl = gr.Radio(
|
| 327 |
choices=["cuda", "pytorch"],
|
| 328 |
label="MatchAttention Implementation",
|
| 329 |
value="cuda",
|
| 330 |
info="MatchAttention implementations",
|
| 331 |
-
interactive=
|
| 332 |
)
|
| 333 |
|
| 334 |
run_btn = gr.Button("Run Inference", variant="primary")
|
|
|
|
| 6 |
import os
|
| 7 |
import time
|
| 8 |
import spaces
|
| 9 |
+
import subprocess
|
| 10 |
|
| 11 |
from dataloader.stereo import transforms
|
| 12 |
from utils.utils import InputPadder, calc_noc_mask
|
|
|
|
| 38 |
|
| 39 |
def load_model(self, mode, variant, precision, mat_impl):
|
| 40 |
"""load model, skip if the model has been loaded"""
|
| 41 |
+
current_has_cuda = torch.cuda.is_available()
|
| 42 |
+
if current_has_cuda != self.has_cuda:
|
| 43 |
+
print(f"CUDA status changed: {self.has_cuda} -> {current_has_cuda}")
|
| 44 |
+
self.has_cuda = current_has_cuda
|
| 45 |
+
self.device = "cuda" if self.has_cuda else 'cpu'
|
| 46 |
+
|
| 47 |
if (self.model is not None and
|
| 48 |
self.current_variant == variant and
|
| 49 |
self.current_mode == mode and
|
| 50 |
self.current_precision == precision and
|
| 51 |
+
self.current_mat_impl == mat_impl and
|
| 52 |
+
self.has_cuda == current_has_cuda):
|
| 53 |
return "Model already loaded"
|
| 54 |
|
| 55 |
# fixed checkpoint path
|
|
|
|
| 73 |
|
| 74 |
if not self.has_cuda:
|
| 75 |
precision = "fp32"
|
| 76 |
+
mat_impl = "pytorch"
|
| 77 |
dtypes = {'fp32': torch.float32, 'fp16': torch.float16}
|
| 78 |
self.dtype = dtypes[precision]
|
| 79 |
|
|
|
|
| 147 |
def process_images(self, left_image, right_image, mode, variant,
|
| 148 |
low_res_init=False, inference_size_name="Original",
|
| 149 |
precision="fp32", mat_impl="pytorch"):
|
| 150 |
+
current_has_cuda = torch.cuda.is_available()
|
| 151 |
+
if current_has_cuda != self.has_cuda:
|
| 152 |
+
print(f"CUDA status changed before processing: {self.has_cuda} -> {current_has_cuda}")
|
| 153 |
+
self.has_cuda = current_has_cuda
|
| 154 |
+
self.device = "cuda" if self.has_cuda else 'cpu'
|
| 155 |
+
|
| 156 |
if not self.has_cuda:
|
| 157 |
precision = "fp32"
|
| 158 |
mat_impl = "pytorch"
|
|
|
|
| 262 |
|
| 263 |
demo_model = MatchStereoDemo()
|
| 264 |
|
| 265 |
+
def compile_cuda_extensions():
|
| 266 |
+
try:
|
| 267 |
+
print("Start compiling CUDA extension...")
|
| 268 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 269 |
+
models_dir = os.path.join(current_dir, "models")
|
| 270 |
+
compile_script = os.path.join(models_dir, "compile.sh")
|
| 271 |
+
|
| 272 |
+
if os.path.exists(compile_script):
|
| 273 |
+
original_cwd = os.getcwd()
|
| 274 |
+
os.chdir(models_dir)
|
| 275 |
+
|
| 276 |
+
result = subprocess.run(["bash", "compile.sh"],
|
| 277 |
+
capture_output=True, text=True)
|
| 278 |
+
|
| 279 |
+
os.chdir(original_cwd)
|
| 280 |
+
|
| 281 |
+
if result.returncode == 0:
|
| 282 |
+
print("CUDA extension compile succeed!")
|
| 283 |
+
print("output:", result.stdout)
|
| 284 |
+
else:
|
| 285 |
+
print("CUDA extension compile failed!")
|
| 286 |
+
print(result.stderr)
|
| 287 |
+
print(result.stdout)
|
| 288 |
+
else:
|
| 289 |
+
print(f"no compile scripts found: {compile_script}")
|
| 290 |
+
|
| 291 |
+
except Exception as e:
|
| 292 |
+
print(f"Error during compile: {e}")
|
| 293 |
+
|
| 294 |
+
compile_cuda_extensions()
|
| 295 |
+
|
| 296 |
# example images
|
| 297 |
examples = [
|
| 298 |
["examples/booster_bathroom_left.png", "examples/booster_bathroom_right.png", "stereo", "tiny"],
|
|
|
|
| 326 |
gr.Markdown("# MatchStereo/MatchFlow Demo")
|
| 327 |
gr.Markdown("Upload stereo images for disparity estimation or consecutive frames for optical flow estimation.")
|
| 328 |
|
| 329 |
+
current_has_cuda = torch.cuda.is_available()
|
| 330 |
+
if not current_has_cuda:
|
| 331 |
gr.Markdown("> Note: Running on CPU. Some options (fp16, cuda) are disabled.")
|
| 332 |
+
else:
|
| 333 |
+
gr.Markdown(f"> Note: Running on GPU ({torch.cuda.get_device_name(0)}).")
|
| 334 |
|
| 335 |
with gr.Row():
|
| 336 |
with gr.Column():
|
|
|
|
| 370 |
label="Precision",
|
| 371 |
value="fp32",
|
| 372 |
info="Model precision",
|
| 373 |
+
interactive=current_has_cuda
|
| 374 |
)
|
| 375 |
mat_impl = gr.Radio(
|
| 376 |
choices=["cuda", "pytorch"],
|
| 377 |
label="MatchAttention Implementation",
|
| 378 |
value="cuda",
|
| 379 |
info="MatchAttention implementations",
|
| 380 |
+
interactive=current_has_cuda
|
| 381 |
)
|
| 382 |
|
| 383 |
run_btn = gr.Button("Run Inference", variant="primary")
|