multimodalart HF Staff commited on
Commit
290c73c
·
verified ·
1 Parent(s): 002edd6

bnb: math.prod(shape) so tuple works

Browse files
diffusers_src/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py CHANGED
@@ -206,10 +206,12 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
206
  module._parameters[tensor_name] = new_value
207
 
208
  def check_quantized_param_shape(self, param_name, current_param, loaded_param):
 
 
209
  current_param_shape = current_param.shape
210
  loaded_param_shape = loaded_param.shape
211
 
212
- n = current_param_shape.numel()
213
  inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
214
  if loaded_param_shape != inferred_shape:
215
  raise ValueError(
 
206
  module._parameters[tensor_name] = new_value
207
 
208
  def check_quantized_param_shape(self, param_name, current_param, loaded_param):
209
+ import math
210
+
211
  current_param_shape = current_param.shape
212
  loaded_param_shape = loaded_param.shape
213
 
214
+ n = math.prod(current_param_shape)
215
  inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
216
  if loaded_param_shape != inferred_shape:
217
  raise ValueError(