Julien Blanchon
commited on
Commit
·
e4c4055
1
Parent(s):
6dddb70
Fix tensor type mismatch: ensure int types for CUDA operations
Browse files- Cast tile_bounds components to int() to avoid Float/Int type mismatch
- Cast img_h, img_w to int() in all gsplat function calls
- Cast dimensions in tensor.view() operations to int()
- This fixes 'expected scalar type Int but found Float' error in Docker environment
- gradio_models.py +13 -13
gradio_models.py
CHANGED
|
@@ -154,9 +154,9 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 154 |
)
|
| 155 |
self.img_h, self.img_w = self.gt_images.shape[1:]
|
| 156 |
self.tile_bounds = (
|
| 157 |
-
(self.img_w + self.block_w - 1) // self.block_w,
|
| 158 |
-
(self.img_h + self.block_h - 1) // self.block_h,
|
| 159 |
-
1,
|
| 160 |
)
|
| 161 |
|
| 162 |
def _init_bit_precision(self, args):
|
|
@@ -388,7 +388,7 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 388 |
)
|
| 389 |
|
| 390 |
begin = perf_counter()
|
| 391 |
-
tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
|
| 392 |
xy, radii, conics, num_tiles_hit = tmp
|
| 393 |
|
| 394 |
if not self.disable_tiles:
|
|
@@ -399,15 +399,15 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 399 |
conics,
|
| 400 |
num_tiles_hit,
|
| 401 |
feat,
|
| 402 |
-
img_h,
|
| 403 |
-
img_w,
|
| 404 |
self.block_h,
|
| 405 |
self.block_w,
|
| 406 |
enable_topk_norm,
|
| 407 |
)
|
| 408 |
out_image = rasterize_gaussians_sum(*tmp)
|
| 409 |
else:
|
| 410 |
-
tmp = xy, conics, feat, img_h, img_w
|
| 411 |
out_image = rasterize_gaussians_no_tiles(*tmp)
|
| 412 |
|
| 413 |
render_time = perf_counter() - begin
|
|
@@ -416,7 +416,7 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 416 |
return render_time
|
| 417 |
|
| 418 |
out_image = (
|
| 419 |
-
out_image.view(-1, img_h, img_w, self.feat_dim)
|
| 420 |
.permute(0, 3, 1, 2)
|
| 421 |
.contiguous()
|
| 422 |
)
|
|
@@ -469,7 +469,7 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 469 |
)
|
| 470 |
|
| 471 |
tmp = project_gaussians_2d_scale_rot(
|
| 472 |
-
xy, scale, rot, self.img_h, self.img_w, self.tile_bounds
|
| 473 |
)
|
| 474 |
xy, radii, conics, num_tiles_hit = tmp
|
| 475 |
|
|
@@ -481,19 +481,19 @@ class GradioGaussianSplatting2D(nn.Module):
|
|
| 481 |
conics,
|
| 482 |
num_tiles_hit,
|
| 483 |
feat,
|
| 484 |
-
self.img_h,
|
| 485 |
-
self.img_w,
|
| 486 |
self.block_h,
|
| 487 |
self.block_w,
|
| 488 |
enable_topk_norm,
|
| 489 |
)
|
| 490 |
out_image = rasterize_gaussians_sum(*tmp)
|
| 491 |
else:
|
| 492 |
-
tmp = xy, conics, feat, self.img_h, self.img_w
|
| 493 |
out_image = rasterize_gaussians_no_tiles(*tmp)
|
| 494 |
|
| 495 |
out_image = (
|
| 496 |
-
out_image.view(-1, self.img_h, self.img_w, self.feat_dim)
|
| 497 |
.permute(0, 3, 1, 2)
|
| 498 |
.contiguous()
|
| 499 |
).squeeze(dim=0)
|
|
|
|
| 154 |
)
|
| 155 |
self.img_h, self.img_w = self.gt_images.shape[1:]
|
| 156 |
self.tile_bounds = (
|
| 157 |
+
int((self.img_w + self.block_w - 1) // self.block_w),
|
| 158 |
+
int((self.img_h + self.block_h - 1) // self.block_h),
|
| 159 |
+
int(1),
|
| 160 |
)
|
| 161 |
|
| 162 |
def _init_bit_precision(self, args):
|
|
|
|
| 388 |
)
|
| 389 |
|
| 390 |
begin = perf_counter()
|
| 391 |
+
tmp = project_gaussians_2d_scale_rot(xy, scale, rot, int(img_h), int(img_w), tile_bounds)
|
| 392 |
xy, radii, conics, num_tiles_hit = tmp
|
| 393 |
|
| 394 |
if not self.disable_tiles:
|
|
|
|
| 399 |
conics,
|
| 400 |
num_tiles_hit,
|
| 401 |
feat,
|
| 402 |
+
int(img_h),
|
| 403 |
+
int(img_w),
|
| 404 |
self.block_h,
|
| 405 |
self.block_w,
|
| 406 |
enable_topk_norm,
|
| 407 |
)
|
| 408 |
out_image = rasterize_gaussians_sum(*tmp)
|
| 409 |
else:
|
| 410 |
+
tmp = xy, conics, feat, int(img_h), int(img_w)
|
| 411 |
out_image = rasterize_gaussians_no_tiles(*tmp)
|
| 412 |
|
| 413 |
render_time = perf_counter() - begin
|
|
|
|
| 416 |
return render_time
|
| 417 |
|
| 418 |
out_image = (
|
| 419 |
+
out_image.view(-1, int(img_h), int(img_w), self.feat_dim)
|
| 420 |
.permute(0, 3, 1, 2)
|
| 421 |
.contiguous()
|
| 422 |
)
|
|
|
|
| 469 |
)
|
| 470 |
|
| 471 |
tmp = project_gaussians_2d_scale_rot(
|
| 472 |
+
xy, scale, rot, int(self.img_h), int(self.img_w), self.tile_bounds
|
| 473 |
)
|
| 474 |
xy, radii, conics, num_tiles_hit = tmp
|
| 475 |
|
|
|
|
| 481 |
conics,
|
| 482 |
num_tiles_hit,
|
| 483 |
feat,
|
| 484 |
+
int(self.img_h),
|
| 485 |
+
int(self.img_w),
|
| 486 |
self.block_h,
|
| 487 |
self.block_w,
|
| 488 |
enable_topk_norm,
|
| 489 |
)
|
| 490 |
out_image = rasterize_gaussians_sum(*tmp)
|
| 491 |
else:
|
| 492 |
+
tmp = xy, conics, feat, int(self.img_h), int(self.img_w)
|
| 493 |
out_image = rasterize_gaussians_no_tiles(*tmp)
|
| 494 |
|
| 495 |
out_image = (
|
| 496 |
+
out_image.view(-1, int(self.img_h), int(self.img_w), self.feat_dim)
|
| 497 |
.permute(0, 3, 1, 2)
|
| 498 |
.contiguous()
|
| 499 |
).squeeze(dim=0)
|