File size: 1,028 Bytes
957e2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import random
import torch

from src.simulation.effect import Effect

################################################################################
# Simulate simple quantization distortions
################################################################################


class Quantize(Effect):
    """Simulate uniform amplitude quantization at a randomly sampled bit depth.

    A bit depth is drawn from ``[min_bits, max_bits]`` by ``sample_params``
    (called once at construction). ``forward`` then rounds the input onto a
    uniform grid with ``2 ** (bits - 1)`` steps per ``self.scale`` units.
    The effect is constructed with ``compute_grad=False``.
    """

    def __init__(self, bits: object = 8):
        """Build the effect and sample an initial bit depth.

        Args:
            bits: bit-depth specification handed to ``parse_range`` with
                element type ``int``; the exact accepted forms (single int,
                (min, max) pair, ...) are defined by ``Effect.parse_range``,
                which is not visible here — confirm there.
        """
        super().__init__(compute_grad=False)

        self.min_bits, self.max_bits = self.parse_range(
            bits,
            int,
            f'Invalid bit depth {bits}'
        )
        self.bits = None
        self.sample_params()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Round ``x`` onto the quantization grid for the current bit depth.

        Args:
            x: input tensor; it is only multiplied, divided and rounded
                elementwise, so any shape works.

        Returns:
            A tensor with the same shape as ``x`` whose values are multiples
            of ``self.scale / 2 ** (self.bits - 1)``.
        """
        # NOTE(review): ``self.scale`` is not assigned in this class;
        # presumably it is provided by ``Effect`` — confirm.

        # rescale full range to -2^(bits - 1), 2^(bits - 1)
        scale_factor = 2 ** (self.bits - 1)
        x_scaled = x * scale_factor / self.scale
        x_quant = torch.round(x_scaled)
        return x_quant * self.scale / scale_factor

    def sample_params(self):
        """Draw a new bit depth uniformly from the integers in the range.

        ``random.randint`` gives every integer depth in
        ``[min_bits, max_bits]`` equal probability. Rounding a continuous
        ``uniform`` draw, as before, gave the two endpoints only half the
        weight of interior values. The bounds are sorted so a reversed range
        is still accepted.
        """
        lo, hi = sorted((self.min_bits, self.max_bits))
        self.bits = random.randint(lo, hi)