import torch


def ste_quantize(x: torch.Tensor, num_bits: int = 16) -> torch.Tensor:
    """Fake-quantize ``x`` to ``num_bits`` levels with a straight-through estimator.

    Used for bit precision control of Gaussian parameters. The values of ``x``
    are mapped affinely onto the integer grid ``[0, 2**num_bits - 1]`` using
    the tensor's own global minimum and maximum, rounded, and mapped back.
    The returned tensor carries the quantized values in the forward pass,
    while its gradient with respect to ``x`` is the identity.

    Reference: https://arxiv.org/abs/1308.3432

    Args:
        x: Tensor to quantize. An empty tensor is returned unchanged.
        num_bits: Number of bits of the quantization grid; must be >= 1.

    Returns:
        A tensor with the same shape as ``x`` holding the de-quantized values.

    Raises:
        ValueError: If ``num_bits`` is smaller than 1.
    """
    if num_bits < 1:
        # With zero bits the grid has a single level and the step size below
        # would divide by zero; negative values give a meaningless grid.
        raise ValueError(f"num_bits must be >= 1, got {num_bits}")

    if x.numel() == 0:
        # min()/max() are undefined on empty tensors; nothing to quantize.
        return x

    qmin, qmax = 0, 2**num_bits - 1

    # The range is read out as Python floats, so it takes no part in autograd.
    min_val, max_val = x.min().item(), x.max().item()

    # Floor the step size so a constant tensor does not divide by zero.
    scale = max((max_val - min_val) / (qmax - qmin), 1e-8)

    # Quantize in forward pass (non-differentiable)
    q_x = torch.round((x - min_val) / scale).clamp(qmin, qmax)
    dq_x = q_x * scale + min_val

    # Restore gradients in backward pass: the detached residual contributes
    # the quantized value but no gradient, so d(out)/d(x) is the identity.
    return x + (dq_x - x).detach()