# Image-GS / utils/quantization_utils.py
# Julien Blanchon
# Deploy optimized Image-GS with dynamic dependencies
# d62394f
import torch
def ste_quantize(x: torch.Tensor, num_bits: int = 16) -> torch.Tensor:
    """
    Bit precision control of Gaussian parameters using a straight-through estimator.

    The forward pass snaps every element of ``x`` onto a uniform grid of
    ``2**num_bits`` levels spanning ``[x.min(), x.max()]`` (one min/max pair
    for the whole tensor) and returns the de-quantized values. The backward
    pass treats the operation as the identity, so gradients reach ``x``
    unchanged.

    Reference: https://arxiv.org/abs/1308.3432

    Args:
        x: Tensor to quantize. An empty tensor is returned unchanged.
        num_bits: Number of bits per value; must be >= 1.

    Returns:
        Tensor with the same shape as ``x`` holding the de-quantized values.
        When ``x`` is floating point, the result keeps ``x``'s dtype.

    Raises:
        ValueError: If ``num_bits`` is smaller than 1.
    """
    if num_bits < 1:
        raise ValueError(f"num_bits must be >= 1, got {num_bits}")
    if x.numel() == 0:
        # min()/max() are undefined on empty tensors; there is nothing to quantize.
        return x

    qmin, qmax = 0, 2**num_bits - 1

    # The quantization arithmetic is non-differentiable, so keep it out of autograd.
    # Half-precision dtypes cannot hold the integer levels up to 2**num_bits - 1
    # (float16 overflows to inf above 65504), so compute in at least float32.
    work = x.detach().to(torch.promote_types(x.dtype, torch.float32))

    # Keep the statistics as on-device tensors; calling .item() here would
    # force a host-device synchronization on every call.
    min_val = work.min()
    # Guard against a zero step when all elements are equal.
    scale = ((work.max() - min_val) / (qmax - qmin)).clamp(min=1e-8)

    # Quantize in forward pass (non-differentiable)
    q_x = torch.round((work - min_val) / scale).clamp(qmin, qmax)
    dq_x = q_x * scale + min_val
    if x.is_floating_point():
        dq_x = dq_x.to(x.dtype)

    # Straight-through estimator: the forward value is dq_x, while the gradient
    # with respect to x is the identity because the correction term is detached.
    return x + (dq_x - x).detach()