Julien Blanchon commited on
Commit
6f11a19
Β·
1 Parent(s): 090544c

Fix CUDA tensor type error: radii should be float32, not int32

Browse files

- Debug output revealed radii=torch.int32 when it should be torch.float32
- Convert radii back to float32 after project_gaussians_2d_scale_rot
- Keep num_tiles_hit as int32 as required
- Apply fix to both main forward() and Gaussian ID visualization
- Based on debug: xy=float32βœ…, radii=int32❌, conics=float32βœ…, num_tiles_hit=int32βœ…, feat=float32βœ…

Files changed (1) hide show
  1. gradio_models.py +14 -8
gradio_models.py CHANGED
@@ -393,16 +393,21 @@ class GradioGaussianSplatting2D(nn.Module):
393
  )
394
  xy, radii, conics, num_tiles_hit = tmp
395
 
396
- # Ensure num_tiles_hit is int32 tensor (required by CUDA backend)
397
- num_tiles_hit = num_tiles_hit.to(dtype=torch.int32)
 
398
 
399
  if not self.disable_tiles:
400
  enable_topk_norm = not self.disable_topk_norm
401
-
402
  # Debug: Print tensor types to identify the Float tensor
403
- self.worklog.info(f"DEBUG - Tensor types: xy={xy.dtype}, radii={radii.dtype}, conics={conics.dtype}, num_tiles_hit={num_tiles_hit.dtype}, feat={feat.dtype}")
404
- self.worklog.info(f"DEBUG - Params: img_h={img_h} ({type(img_h)}), img_w={img_w} ({type(img_w)}), block_h={self.block_h} ({type(self.block_h)}), block_w={self.block_w} ({type(self.block_w)})")
405
-
 
 
 
 
406
  out_image = rasterize_gaussians_sum(
407
  xy,
408
  radii,
@@ -482,8 +487,9 @@ class GradioGaussianSplatting2D(nn.Module):
482
  )
483
  xy, radii, conics, num_tiles_hit = tmp
484
 
485
- # Ensure num_tiles_hit is int32 tensor (required by CUDA backend)
486
- num_tiles_hit = num_tiles_hit.to(dtype=torch.int32)
 
487
 
488
  if not self.disable_tiles:
489
  enable_topk_norm = not self.disable_topk_norm
 
393
  )
394
  xy, radii, conics, num_tiles_hit = tmp
395
 
396
+ # Ensure correct tensor types for CUDA backend
397
+ num_tiles_hit = num_tiles_hit.to(dtype=torch.int32) # Must be int32
398
+ radii = radii.to(dtype=torch.float32) # Must be float32
399
 
400
  if not self.disable_tiles:
401
  enable_topk_norm = not self.disable_topk_norm
402
+
403
  # Debug: Print tensor types to identify the Float tensor
404
+ self.worklog.info(
405
+ f"DEBUG - Tensor types: xy={xy.dtype}, radii={radii.dtype}, conics={conics.dtype}, num_tiles_hit={num_tiles_hit.dtype}, feat={feat.dtype}"
406
+ )
407
+ self.worklog.info(
408
+ f"DEBUG - Params: img_h={img_h} ({type(img_h)}), img_w={img_w} ({type(img_w)}), block_h={self.block_h} ({type(self.block_h)}), block_w={self.block_w} ({type(self.block_w)})"
409
+ )
410
+
411
  out_image = rasterize_gaussians_sum(
412
  xy,
413
  radii,
 
487
  )
488
  xy, radii, conics, num_tiles_hit = tmp
489
 
490
+ # Ensure correct tensor types for CUDA backend
491
+ num_tiles_hit = num_tiles_hit.to(dtype=torch.int32) # Must be int32
492
+ radii = radii.to(dtype=torch.float32) # Must be float32
493
 
494
  if not self.disable_tiles:
495
  enable_topk_norm = not self.disable_topk_norm