| | import torch |
| |
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
    """Stochastically round the mantissa of ``abs_x`` to ``MANTISSA_BITS`` bits.

    Args:
        abs_x: Tensor of magnitudes to quantize.
        exponent: Tensor of biased exponents, one per element of ``abs_x``.
        normal_mask: Boolean tensor; True selects the normal-number formula,
            False selects the subnormal formula.
        MANTISSA_BITS: Number of mantissa bits in the target format.
        EXPONENT_BIAS: Exponent bias of the target format.
        generator: Optional ``torch.Generator`` used to draw the rounding noise.

    Returns:
        Tensor holding the rounded mantissa as a fraction: a whole number of
        mantissa steps divided by ``2**MANTISSA_BITS``.
    """
    mantissa_steps = 2**MANTISSA_BITS

    # Normal formula: divide out the power of two, drop the implicit leading 1,
    # and express what is left in units of one mantissa step.
    normal_scale = 2.0 ** (exponent - EXPONENT_BIAS)
    normal_steps = (abs_x / normal_scale - 1.0) * mantissa_steps

    # Subnormal formula: fixed smallest exponent and no implicit leading 1.
    subnormal_steps = abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))

    scaled = torch.where(normal_mask, normal_steps, subnormal_steps)

    # Adding uniform noise from [0, 1) before flooring makes the value round up
    # with probability equal to its fractional part.
    scaled += torch.rand(
        scaled.size(),
        dtype=scaled.dtype,
        layout=scaled.layout,
        device=scaled.device,
        generator=generator,
    )
    return scaled.floor() / mantissa_steps
| |
|
| | |
| | def manual_stochastic_round_to_float8(x, dtype, generator=None): |
| | if dtype == torch.float8_e4m3fn: |
| | EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 |
| | elif dtype == torch.float8_e5m2: |
| | EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15 |
| | else: |
| | raise ValueError("Unsupported dtype") |
| |
|
| | x = x.half() |
| | sign = torch.sign(x) |
| | abs_x = x.abs() |
| | sign = torch.where(abs_x == 0, 0, sign) |
| |
|
| | |
| | exponent = torch.clamp( |
| | torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, |
| | 0, 2**EXPONENT_BITS - 1 |
| | ) |
| |
|
| | |
| | normal_mask = ~(exponent == 0) |
| |
|
| | abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) |
| |
|
| | sign *= torch.where( |
| | normal_mask, |
| | (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), |
| | (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x |
| | ) |
| |
|
| | inf = torch.finfo(dtype) |
| | torch.clamp(sign, min=inf.min, max=inf.max, out=sign) |
| | return sign |
| |
|
| |
|
| |
|
| | def stochastic_rounding(value, dtype, seed=0): |
| | if dtype == torch.float32: |
| | return value.to(dtype=torch.float32) |
| | if dtype == torch.float16: |
| | return value.to(dtype=torch.float16) |
| | if dtype == torch.bfloat16: |
| | return value.to(dtype=torch.bfloat16) |
| | if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: |
| | generator = torch.Generator(device=value.device) |
| | generator.manual_seed(seed) |
| | output = torch.empty_like(value, dtype=dtype) |
| | num_slices = max(1, (value.numel() / (4096 * 4096))) |
| | slice_size = max(1, round(value.shape[0] / num_slices)) |
| | for i in range(0, value.shape[0], slice_size): |
| | output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)) |
| | return output |
| |
|
| | return value.to(dtype=dtype) |
| |
|