Spaces:
Running on Zero
Running on Zero
bnb: math.prod(shape) so tuple works
Browse files
diffusers_src/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
CHANGED
|
@@ -206,10 +206,12 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
| 206 |
module._parameters[tensor_name] = new_value
|
| 207 |
|
| 208 |
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
|
|
|
|
|
|
| 209 |
current_param_shape = current_param.shape
|
| 210 |
loaded_param_shape = loaded_param.shape
|
| 211 |
|
| 212 |
-
n =
|
| 213 |
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
|
| 214 |
if loaded_param_shape != inferred_shape:
|
| 215 |
raise ValueError(
|
|
|
|
| 206 |
module._parameters[tensor_name] = new_value
|
| 207 |
|
| 208 |
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
| 209 |
+
import math
|
| 210 |
+
|
| 211 |
current_param_shape = current_param.shape
|
| 212 |
loaded_param_shape = loaded_param.shape
|
| 213 |
|
| 214 |
+
n = math.prod(current_param_shape)
|
| 215 |
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
|
| 216 |
if loaded_param_shape != inferred_shape:
|
| 217 |
raise ValueError(
|