Julien Blanchon commited on
Commit
e4c4055
·
1 Parent(s): 6dddb70

Fix tensor type mismatch: ensure int types for CUDA operations

Browse files

- Cast tile_bounds components to int() to avoid Float/Int type mismatch
- Cast img_h, img_w to int() in all gsplat function calls
- Cast dimensions in tensor.view() operations to int()
- This fixes 'expected scalar type Int but found Float' error in Docker environment

Files changed (1) hide show
  1. gradio_models.py +13 -13
gradio_models.py CHANGED
@@ -154,9 +154,9 @@ class GradioGaussianSplatting2D(nn.Module):
154
  )
155
  self.img_h, self.img_w = self.gt_images.shape[1:]
156
  self.tile_bounds = (
157
- (self.img_w + self.block_w - 1) // self.block_w,
158
- (self.img_h + self.block_h - 1) // self.block_h,
159
- 1,
160
  )
161
 
162
  def _init_bit_precision(self, args):
@@ -388,7 +388,7 @@ class GradioGaussianSplatting2D(nn.Module):
388
  )
389
 
390
  begin = perf_counter()
391
- tmp = project_gaussians_2d_scale_rot(xy, scale, rot, img_h, img_w, tile_bounds)
392
  xy, radii, conics, num_tiles_hit = tmp
393
 
394
  if not self.disable_tiles:
@@ -399,15 +399,15 @@ class GradioGaussianSplatting2D(nn.Module):
399
  conics,
400
  num_tiles_hit,
401
  feat,
402
- img_h,
403
- img_w,
404
  self.block_h,
405
  self.block_w,
406
  enable_topk_norm,
407
  )
408
  out_image = rasterize_gaussians_sum(*tmp)
409
  else:
410
- tmp = xy, conics, feat, img_h, img_w
411
  out_image = rasterize_gaussians_no_tiles(*tmp)
412
 
413
  render_time = perf_counter() - begin
@@ -416,7 +416,7 @@ class GradioGaussianSplatting2D(nn.Module):
416
  return render_time
417
 
418
  out_image = (
419
- out_image.view(-1, img_h, img_w, self.feat_dim)
420
  .permute(0, 3, 1, 2)
421
  .contiguous()
422
  )
@@ -469,7 +469,7 @@ class GradioGaussianSplatting2D(nn.Module):
469
  )
470
 
471
  tmp = project_gaussians_2d_scale_rot(
472
- xy, scale, rot, self.img_h, self.img_w, self.tile_bounds
473
  )
474
  xy, radii, conics, num_tiles_hit = tmp
475
 
@@ -481,19 +481,19 @@ class GradioGaussianSplatting2D(nn.Module):
481
  conics,
482
  num_tiles_hit,
483
  feat,
484
- self.img_h,
485
- self.img_w,
486
  self.block_h,
487
  self.block_w,
488
  enable_topk_norm,
489
  )
490
  out_image = rasterize_gaussians_sum(*tmp)
491
  else:
492
- tmp = xy, conics, feat, self.img_h, self.img_w
493
  out_image = rasterize_gaussians_no_tiles(*tmp)
494
 
495
  out_image = (
496
- out_image.view(-1, self.img_h, self.img_w, self.feat_dim)
497
  .permute(0, 3, 1, 2)
498
  .contiguous()
499
  ).squeeze(dim=0)
 
154
  )
155
  self.img_h, self.img_w = self.gt_images.shape[1:]
156
  self.tile_bounds = (
157
+ int((self.img_w + self.block_w - 1) // self.block_w),
158
+ int((self.img_h + self.block_h - 1) // self.block_h),
159
+ int(1),
160
  )
161
 
162
  def _init_bit_precision(self, args):
 
388
  )
389
 
390
  begin = perf_counter()
391
+ tmp = project_gaussians_2d_scale_rot(xy, scale, rot, int(img_h), int(img_w), tile_bounds)
392
  xy, radii, conics, num_tiles_hit = tmp
393
 
394
  if not self.disable_tiles:
 
399
  conics,
400
  num_tiles_hit,
401
  feat,
402
+ int(img_h),
403
+ int(img_w),
404
  self.block_h,
405
  self.block_w,
406
  enable_topk_norm,
407
  )
408
  out_image = rasterize_gaussians_sum(*tmp)
409
  else:
410
+ tmp = xy, conics, feat, int(img_h), int(img_w)
411
  out_image = rasterize_gaussians_no_tiles(*tmp)
412
 
413
  render_time = perf_counter() - begin
 
416
  return render_time
417
 
418
  out_image = (
419
+ out_image.view(-1, int(img_h), int(img_w), self.feat_dim)
420
  .permute(0, 3, 1, 2)
421
  .contiguous()
422
  )
 
469
  )
470
 
471
  tmp = project_gaussians_2d_scale_rot(
472
+ xy, scale, rot, int(self.img_h), int(self.img_w), self.tile_bounds
473
  )
474
  xy, radii, conics, num_tiles_hit = tmp
475
 
 
481
  conics,
482
  num_tiles_hit,
483
  feat,
484
+ int(self.img_h),
485
+ int(self.img_w),
486
  self.block_h,
487
  self.block_w,
488
  enable_topk_norm,
489
  )
490
  out_image = rasterize_gaussians_sum(*tmp)
491
  else:
492
+ tmp = xy, conics, feat, int(self.img_h), int(self.img_w)
493
  out_image = rasterize_gaussians_no_tiles(*tmp)
494
 
495
  out_image = (
496
+ out_image.view(-1, int(self.img_h), int(self.img_w), self.feat_dim)
497
  .permute(0, 3, 1, 2)
498
  .contiguous()
499
  ).squeeze(dim=0)