Julien Blanchon commited on
Commit
de1ddad
·
1 Parent(s): 53f5edd

Fix CUDA tensor type error: explicit int casting for all gsplat parameters

Browse files

- Cast block_h, block_w to int() in rasterize_gaussians_sum calls
- Use direct function calls instead of tuple unpacking for better type control
- This should fix the 'expected scalar type Int but found Float' error in CUDA backend
- Based on traceback showing error in _C.nd_rasterize_forward_topk_norm

Files changed (1) hide show
  1. gradio_models.py +6 -8
gradio_models.py CHANGED
@@ -395,7 +395,7 @@ class GradioGaussianSplatting2D(nn.Module):
395
 
396
  if not self.disable_tiles:
397
  enable_topk_norm = not self.disable_topk_norm
398
- tmp = (
399
  xy,
400
  radii,
401
  conics,
@@ -403,11 +403,10 @@ class GradioGaussianSplatting2D(nn.Module):
403
  feat,
404
  int(img_h),
405
  int(img_w),
406
- self.block_h,
407
- self.block_w,
408
  enable_topk_norm,
409
  )
410
- out_image = rasterize_gaussians_sum(*tmp)
411
  else:
412
  tmp = xy, conics, feat, int(img_h), int(img_w)
413
  out_image = rasterize_gaussians_no_tiles(*tmp)
@@ -477,7 +476,7 @@ class GradioGaussianSplatting2D(nn.Module):
477
 
478
  if not self.disable_tiles:
479
  enable_topk_norm = not self.disable_topk_norm
480
- tmp = (
481
  xy,
482
  radii,
483
  conics,
@@ -485,11 +484,10 @@ class GradioGaussianSplatting2D(nn.Module):
485
  feat,
486
  int(self.img_h),
487
  int(self.img_w),
488
- self.block_h,
489
- self.block_w,
490
  enable_topk_norm,
491
  )
492
- out_image = rasterize_gaussians_sum(*tmp)
493
  else:
494
  tmp = xy, conics, feat, int(self.img_h), int(self.img_w)
495
  out_image = rasterize_gaussians_no_tiles(*tmp)
 
395
 
396
  if not self.disable_tiles:
397
  enable_topk_norm = not self.disable_topk_norm
398
+ out_image = rasterize_gaussians_sum(
399
  xy,
400
  radii,
401
  conics,
 
403
  feat,
404
  int(img_h),
405
  int(img_w),
406
+ int(self.block_h),
407
+ int(self.block_w),
408
  enable_topk_norm,
409
  )
 
410
  else:
411
  tmp = xy, conics, feat, int(img_h), int(img_w)
412
  out_image = rasterize_gaussians_no_tiles(*tmp)
 
476
 
477
  if not self.disable_tiles:
478
  enable_topk_norm = not self.disable_topk_norm
479
+ out_image = rasterize_gaussians_sum(
480
  xy,
481
  radii,
482
  conics,
 
484
  feat,
485
  int(self.img_h),
486
  int(self.img_w),
487
+ int(self.block_h),
488
+ int(self.block_w),
489
  enable_topk_norm,
490
  )
 
491
  else:
492
  tmp = xy, conics, feat, int(self.img_h), int(self.img_w)
493
  out_image = rasterize_gaussians_no_tiles(*tmp)