Spaces:
Paused
Paused
| import torch | |
| def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): | |
| mantissa_scaled = torch.where( | |
| normal_mask, | |
| (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), | |
| (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) | |
| ) | |
| mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator) | |
| return mantissa_scaled.floor() / (2**MANTISSA_BITS) | |
| #Not 100% sure about this | |
| def manual_stochastic_round_to_float8(x, dtype, generator=None): | |
| if dtype == torch.float8_e4m3fn: | |
| EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 | |
| elif dtype == torch.float8_e5m2: | |
| EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15 | |
| else: | |
| raise ValueError("Unsupported dtype") | |
| x = x.half() | |
| sign = torch.sign(x) | |
| abs_x = x.abs() | |
| sign = torch.where(abs_x == 0, 0, sign) | |
| # Combine exponent calculation and clamping | |
| exponent = torch.clamp( | |
| torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, | |
| 0, 2**EXPONENT_BITS - 1 | |
| ) | |
| # Combine mantissa calculation and rounding | |
| normal_mask = ~(exponent == 0) | |
| abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) | |
| sign *= torch.where( | |
| normal_mask, | |
| (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), | |
| (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x | |
| ) | |
| inf = torch.finfo(dtype) | |
| torch.clamp(sign, min=inf.min, max=inf.max, out=sign) | |
| return sign | |
| def stochastic_rounding(value, dtype, seed=0): | |
| if dtype == torch.float32: | |
| return value.to(dtype=torch.float32) | |
| if dtype == torch.float16: | |
| return value.to(dtype=torch.float16) | |
| if dtype == torch.bfloat16: | |
| return value.to(dtype=torch.bfloat16) | |
| if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: | |
| generator = torch.Generator(device=value.device) | |
| generator.manual_seed(seed) | |
| output = torch.empty_like(value, dtype=dtype) | |
| num_slices = max(1, (value.numel() / (4096 * 4096))) | |
| slice_size = max(1, round(value.shape[0] / num_slices)) | |
| for i in range(0, value.shape[0], slice_size): | |
| output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)) | |
| return output | |
| return value.to(dtype=dtype) | |
# TODO: improve this? (the rounding scheme below is approximate)
def stochastic_float_to_fp4_e2m1(x, generator):
    """Stochastically round ``x`` to FP4 E2M1 codes, packed two per byte.

    Each element becomes a 4-bit code ``sign << 3 | exponent << 1 | mantissa``.
    Consecutive pairs of the flattened tensor are packed into one uint8, with
    the first element of each pair in the high nibble. Magnitudes above 6.0
    (the largest E2M1 value) saturate to 6.0.

    Rounding is approximate: uniform noise in
    ``[-0.5, 0.5) * 1.25 * 2**(e - 2)`` is added first, where ``e`` is
    ``floor(log2|x|) + 1`` clamped to [0, 3]. A round-to-nearest encode follows.

    Args:
        x: Floating-point tensor, expected to be pre-scaled into the FP4 range.
            NOTE: modified in place (the noise is added to it); the callers in
            this module pass a scratch tensor. Its element count must be even
            so elements can be paired.
        generator: ``torch.Generator`` for the noise.

    Returns:
        uint8 tensor with ``x``'s leading dimensions and the last dimension
        inferred from the remaining element count (half of ``x``'s last
        dimension when that is even).
    """
    F4_E2M1_MAX = 6.0
    orig_shape = x.shape
    sign = torch.signbit(x).to(torch.uint8)
    exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
    x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
    # Saturate at the largest E2M1 magnitude. Without this, noise can push the
    # top binade (exp == 3) to >= 7, where the mantissa rounds to 2. OR-ing
    # that into the code clobbers the exponent bits (7.0 encodes as 4.0), and
    # larger inputs overflow the 4-bit code entirely.
    x = x.abs().clamp_(max=F4_E2M1_MAX)
    # 1.1925 ~= 2 - log2(1.75): this moves the switch from exponent e to e + 1 to
    # 1.75 * 2**(e - 1), the midpoint between a binade's largest value
    # (1.5 * 2**(e - 1)) and the next binade's smallest (2**e).
    exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
    mantissa = torch.where(
        exp > 0,
        (x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
        (x * 2.0),
        out=x
    ).round().to(torch.uint8)
    del x
    exp = exp.to(torch.uint8)
    # For exp == 0 a mantissa that rounds to 2 carries into the exponent field
    # through the OR (0b010 encodes 1.0), which is the correct result.
    fp4 = (sign << 3) | (exp << 1) | mantissa
    del sign, exp, mantissa
    fp4_flat = fp4.view(-1)
    packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
    return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
    """Swizzle a 2-D block-scale matrix into the cuBLAS block-scaling layout.

    The matrix is zero-padded up to a multiple of 128 rows and 4 columns and cut
    into 128x4 tiles, taken in row-major tile order. Inside each tile, element
    ``(32*q + s, c)`` moves to position ``16*s + 4*q + c`` of the tile's
    512-element chunk.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Tensor of shape (H, W).
        flatten: If True, return a 1-D tensor. Otherwise return the same data
            reshaped to (128 * ceil(H / 128), 4 * ceil(W / 4)).

    Returns:
        The rearranged tensor, with the dtype and device of ``input_matrix``.
    """
    n_rows, n_cols = input_matrix.shape
    # Ceiling division via negated floor division.
    row_tiles = -(-n_rows // 128)
    col_tiles = -(-n_cols // 4)
    target_rows = row_tiles * 128
    target_cols = col_tiles * 4

    if n_rows == target_rows and n_cols == target_cols:
        source = input_matrix
    else:
        source = torch.zeros(
            (target_rows, target_cols),
            device=input_matrix.device,
            dtype=input_matrix.dtype,
        )
        source[:n_rows, :n_cols] = input_matrix

    # (row_tile, col_tile, 128, 4) -> split the 128 rows as 4 x 32 and
    # interleave so each of the 32 output rows holds 16 scales.
    tiles = source.view(row_tiles, 128, col_tiles, 4).permute(0, 2, 1, 3)
    swizzled = tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    if not flatten:
        return swizzled.reshape(target_rows, target_cols)
    return swizzled.flatten()
def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
    """Quantize ``x`` to packed FP4 E2M1 with one FP8 scale per 16 elements.

    ``x`` is regrouped as ``(x.shape[0], -1, 16)``. Each 16-element group gets a
    scale ``amax(|group|) / 6 / per_tensor_scale``, clamped to at most 448 and
    cast to ``torch.float8_e4m3fn``. Elements are divided by
    ``per_tensor_scale * group_scale``; NaN/inf produced by that division (for
    example 0/0 in an all-zero group) are replaced with ``nan_to_num`` before
    stochastic FP4 encoding.

    Args:
        x: Tensor whose per-row element count is a multiple of 16 (assumed 2-D —
            TODO confirm against callers).
        per_tensor_scale: Global scale tensor; cast to ``x``'s dtype.
        generator: ``torch.Generator`` for the rounding noise.

    Returns:
        ``(packed_fp4, block_scales)``: the packed uint8 data, and the
        ``float8_e4m3fn`` scales of shape ``(x.shape[0], groups_per_row)``.
    """
    fp4_max = 6.0
    fp8_max = 448.0
    group = 16

    input_shape = x.shape
    grouped = x.reshape(input_shape[0], -1, group)
    global_scale = per_tensor_scale.to(grouped.dtype)

    group_amax = torch.amax(torch.abs(grouped), dim=-1)
    block_scales = torch.clamp((group_amax / fp4_max) / global_scale, max=fp8_max).to(torch.float8_e4m3fn)

    # Divide by the dequantization factor actually stored (FP8-rounded scale).
    divisor = (global_scale * block_scales.to(grouped.dtype)).unsqueeze(-1)
    normalized = (grouped / divisor).view(input_shape).nan_to_num()
    del grouped, group_amax, divisor

    packed = stochastic_float_to_fp4_e2m1(normalized, generator=generator)
    return packed, block_scales
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
    """Stochastically quantize a 2-D tensor to NVFP4 in a single pass.

    Args:
        x: 2-D input tensor.
        per_tensor_scale: Global scale tensor applied on top of the per-block
            FP8 scales.
        pad_16x: If True, zero-pad rows and columns up to multiples of 16.
        seed: Seed of the ``torch.Generator`` (on ``x.device``) used for the
            rounding noise.

    Returns:
        ``(packed_fp4, blocked_scales)``: the packed uint8 FP4 data, and the
        block scales rearranged by ``to_blocked(..., flatten=False)``.
    """
    def _round_up(value: int, multiple: int) -> int:
        """Return the smallest multiple of ``multiple`` that is >= ``value``."""
        return -(-value // multiple) * multiple

    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)

    if pad_16x:
        n_rows, n_cols = x.shape
        extra_rows = _round_up(n_rows, 16) - n_rows
        extra_cols = _round_up(n_cols, 16) - n_cols
        if extra_rows or extra_cols:
            # F.pad pads the last dimension first: (left, right, top, bottom).
            x = torch.nn.functional.pad(x, (0, extra_cols, 0, extra_rows))

    packed, scales = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
    return packed, to_blocked(scales, flatten=False)
def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
    """Stochastically quantize a 2-D tensor to NVFP4, processing it in row slices.

    Same output format as ``stochastic_round_quantize_nvfp4``. Rows are
    quantized in slices sized so each covers roughly ``block_size`` elements,
    presumably to bound peak temporary memory. One generator is shared across
    slices.

    Args:
        x: 2-D input tensor.
        per_tensor_scale: Global scale tensor applied on top of the per-block
            FP8 scales.
        pad_16x: If True, zero-pad rows and columns up to multiples of 16. The
            outputs then describe the padded tensor.
        seed: Seed of the ``torch.Generator`` (on ``x.device``) for the noise.
        block_size: Approximate number of elements per processed slice.

    Returns:
        ``(packed_fp4, blocked_scales)``: uint8 data with the last dimension
        halved, and the ``float8_e4m3fn`` scales (one per 16 elements)
        rearranged by ``to_blocked(..., flatten=False)``.
    """
    def _round_up(value: int, multiple: int) -> int:
        """Return the smallest multiple of ``multiple`` that is >= ``value``."""
        return -(-value // multiple) * multiple

    if pad_16x:
        n_rows, n_cols = x.shape
        extra_rows = _round_up(n_rows, 16) - n_rows
        extra_cols = _round_up(n_cols, 16) - n_cols
        if extra_rows or extra_cols:
            # F.pad pads the last dimension first: (left, right, top, bottom).
            x = torch.nn.functional.pad(x, (0, extra_cols, 0, extra_rows))

    # Outputs are sized from the (possibly padded) input.
    out_dims = list(x.shape)
    leading, last = out_dims[:-1], out_dims[-1]
    packed_out = torch.empty(leading + [last // 2], dtype=torch.uint8, device=x.device)
    scales_out = torch.empty(leading + [last // 16], dtype=torch.float8_e4m3fn, device=x.device)

    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)

    total_rows = x.shape[0]
    slice_count = max(1, (x.numel() / block_size))
    rows_per_slice = max(1, (round(total_rows / slice_count)))
    for start in range(0, total_rows, rows_per_slice):
        stop = start + rows_per_slice
        packed, scales = stochastic_round_quantize_nvfp4_block(x[start:stop], per_tensor_scale, generator=generator)
        packed_out[start:stop].copy_(packed)
        scales_out[start:stop].copy_(scales)
    return packed_out, to_blocked(scales_out, flatten=False)