File size: 623 Bytes
d62394f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch


def ste_quantize(x: torch.Tensor, num_bits: int = 16) -> torch.Tensor:
    """
    Bit precision control of Gaussian parameters using a straight-through estimator.

    The tensor is mapped affinely onto the integer grid ``[0, 2**num_bits - 1]``
    using the min/max taken over all of its elements, rounded to the nearest
    grid point, and mapped back to the original value range. Rounding is
    bypassed in the backward pass so gradients reach ``x`` unchanged.

    Reference: https://arxiv.org/abs/1308.3432

    Args:
        x: Tensor to quantize. May be empty, in which case an empty copy is
            returned.
        num_bits: Number of quantization bits; must be at least 1.

    Returns:
        A tensor with the same shape as ``x`` holding the de-quantized values.
        Its gradient with respect to ``x`` is the identity.

    Raises:
        ValueError: If ``num_bits`` is smaller than 1.
    """
    if num_bits < 1:
        # num_bits == 0 would make the grid width zero (division by zero below);
        # negative values produce a fractional, meaningless qmax.
        raise ValueError(f"num_bits must be >= 1, got {num_bits}")
    if x.numel() == 0:
        # min()/max() raise on empty tensors; there is nothing to quantize, so
        # hand back an (empty) copy that is still connected to the autograd graph.
        return x.clone()

    qmin, qmax = 0, 2**num_bits - 1
    min_val, max_val = x.min().item(), x.max().item()
    # Lower-bound the step so a constant tensor (max == min) does not divide by zero.
    scale = max((max_val - min_val) / (qmax - qmin), 1e-8)
    # Quantize in forward pass (non-differentiable)
    q_x = torch.round((x - min_val) / scale).clamp(qmin, qmax)
    dq_x = q_x * scale + min_val
    # Restore gradients in backward pass: the detached residual carries the
    # quantized value forward while autograd only sees the identity path via x.
    dq_x = x + (dq_x - x).detach()
    return dq_x