Julian Bilcke committed
Commit 9dd297c
1 Parent(s): 3abc88b

attempting a fix

Files changed (1)
  1. quant.py +10 -1
quant.py CHANGED
@@ -92,18 +92,27 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
         else:
             batch_size = None
             A_input = A
+
+        # torch._scaled_mm doesn't support bias when out_dtype is Float32
+        # Apply bias separately in this case
+        use_bias_in_mm = bias is not None and out_dtype != torch.float32
+
         output = torch._scaled_mm(
             A_input,
             B.t(),
             out_dtype=out_dtype,
             scale_a=A_scale,
             scale_b=B_scale,
-            bias=bias,
+            bias=bias if use_bias_in_mm else None,
         )
         if need_reshape:
             output = output.reshape(
                 batch_size, output.shape[0] // batch_size, output.shape[1]
             )
+
+        # Apply bias separately if out_dtype is Float32
+        if bias is not None and not use_bias_in_mm:
+            output = output + bias
     else:
         output = torch.nn.functional.linear(
             A.to(out_dtype) * A_scale,
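
For reference, a minimal standalone sketch of the patched branch (the helper name and docstring below are illustrative, not part of the commit): when the requested output dtype is float32, the bias is withheld from torch._scaled_mm and added to the result afterwards, which is equivalent because a GEMM bias is just a row-wise addition to the matrix product.

import torch

def scaled_mm_with_optional_bias(A_fp8, B_fp8, A_scale, B_scale, bias, out_dtype):
    """Sketch of the workaround in this commit: only pass bias into
    torch._scaled_mm when the output dtype accepts it; otherwise add
    the bias to the result manually."""
    use_bias_in_mm = bias is not None and out_dtype != torch.float32
    output = torch._scaled_mm(
        A_fp8,
        B_fp8.t(),
        out_dtype=out_dtype,
        scale_a=A_scale,
        scale_b=B_scale,
        bias=bias if use_bias_in_mm else None,
    )
    if bias is not None and not use_bias_in_mm:
        # Equivalent to fusing the bias into the matmul: broadcast add over rows.
        output = output + bias
    return output

In the commit itself this logic stays inline in fp8_gemm (including the 3D-input reshape handling elided here); the helper above only isolates the bias handling for illustration.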