Spaces: Running on Zero
Julian Bilcke committed · Commit 9dd297c · 1 Parent(s): 3abc88b
attempting a fix
quant.py CHANGED
@@ -92,18 +92,27 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
         else:
             batch_size = None
             A_input = A
+
+        # torch._scaled_mm doesn't support bias when out_dtype is Float32
+        # Apply bias separately in this case
+        use_bias_in_mm = bias is not None and out_dtype != torch.float32
+
         output = torch._scaled_mm(
             A_input,
             B.t(),
             out_dtype=out_dtype,
             scale_a=A_scale,
             scale_b=B_scale,
-            bias=bias,
+            bias=bias if use_bias_in_mm else None,
         )
         if need_reshape:
             output = output.reshape(
                 batch_size, output.shape[0] // batch_size, output.shape[1]
             )
+
+        # Apply bias separately if out_dtype is Float32
+        if bias is not None and not use_bias_in_mm:
+            output = output + bias
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
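For context, the workaround this commit applies can be reproduced outside the diff roughly as in the sketch below. This is a minimal, hypothetical example, not code from the Space: it assumes a CUDA GPU with native FP8 support and a PyTorch version where the private torch._scaled_mm API returns a single tensor; the shapes, the float8_e4m3fn dtype, and the unit scales are all illustrative.

import torch

# Minimal sketch of the bias workaround (assumptions: CUDA GPU with
# native FP8 support; recent PyTorch where torch._scaled_mm returns a
# single tensor; shapes and scales below are illustrative only).
A = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn)
bias = torch.randn(64, device="cuda", dtype=torch.float32)
A_scale = torch.tensor(1.0, device="cuda")  # per-tensor float32 scales
B_scale = torch.tensor(1.0, device="cuda")
out_dtype = torch.float32

# torch._scaled_mm rejects a bias argument when out_dtype is float32,
# so only fuse the bias into the GEMM for other output dtypes and add
# it to the result afterwards in the float32 case.
use_bias_in_mm = bias is not None and out_dtype != torch.float32
output = torch._scaled_mm(
    A,
    B.t(),  # second operand passed transposed, as in the diff
    out_dtype=out_dtype,
    scale_a=A_scale,
    scale_b=B_scale,
    bias=bias if use_bias_in_mm else None,
)
if bias is not None and not use_bias_in_mm:
    output = output + bias  # bias applied outside the fused GEMM

Gating on a single use_bias_in_mm flag, rather than special-casing float32 at the call site, lets one torch._scaled_mm call serve both paths and keeps the bias fused into the GEMM whenever the kernel supports it.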