mansaripo commited on
Commit
d0aacb0
·
verified ·
1 Parent(s): 6f7dfbe

Update vllm_plugin/quartet2_quant.py

Browse files
Files changed (1) hide show
  1. vllm_plugin/quartet2_quant.py +10 -1
vllm_plugin/quartet2_quant.py CHANGED
@@ -93,9 +93,9 @@ class QuartetIILinearMethod(LinearMethodBase):
93
 
94
  weight = layer.weight
95
  orig_shape = x.shape
 
96
  flat_x = x.reshape(-1, x.shape[-1])
97
 
98
- # Quartet II requires rows to be multiples of 128; pad if needed.
99
  num_rows = flat_x.shape[0]
100
  remainder = num_rows % 128
101
  if remainder != 0:
@@ -104,6 +104,13 @@ class QuartetIILinearMethod(LinearMethodBase):
104
  else:
105
  pad_rows = 0
106
 
 
 
 
 
 
 
 
107
  input_amax = abs_max(flat_x)
108
  weight_amax = abs_max(weight)
109
 
@@ -128,6 +135,8 @@ class QuartetIILinearMethod(LinearMethodBase):
128
 
129
  if pad_rows > 0:
130
  output = output[:num_rows]
 
 
131
 
132
  output = output.reshape(*orig_shape[:-1], output.shape[-1])
133
  if bias is not None:
 
93
 
94
  weight = layer.weight
95
  orig_shape = x.shape
96
+ out_features = weight.shape[0]
97
  flat_x = x.reshape(-1, x.shape[-1])
98
 
 
99
  num_rows = flat_x.shape[0]
100
  remainder = num_rows % 128
101
  if remainder != 0:
 
104
  else:
105
  pad_rows = 0
106
 
107
+ w_remainder = out_features % 128
108
+ if w_remainder != 0:
109
+ w_pad = 128 - w_remainder
110
+ weight = F.pad(weight, (0, 0, 0, w_pad))
111
+ else:
112
+ w_pad = 0
113
+
114
  input_amax = abs_max(flat_x)
115
  weight_amax = abs_max(weight)
116
 
 
135
 
136
  if pad_rows > 0:
137
  output = output[:num_rows]
138
+ if w_pad > 0:
139
+ output = output[:, :out_features]
140
 
141
  output = output.reshape(*orig_shape[:-1], output.shape[-1])
142
  if bias is not None: