Update SSE_quantize.py
Browse files- SSE_quantize.py +6 -3
SSE_quantize.py
CHANGED
|
@@ -187,16 +187,19 @@ class SSEQ(InputModule):
|
|
| 187 |
with open(bin_path, "rb") as f:
|
| 188 |
raw = f.read()
|
| 189 |
|
| 190 |
-
vocab = state["dyt.alpha"].shape[0] # hidden dim
|
| 191 |
hidden = state["dyt.alpha"].shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
packed_size = vocab * hidden // 2
|
| 194 |
|
| 195 |
packed = np.frombuffer(raw[:packed_size], dtype=np.uint8)
|
| 196 |
scales = np.frombuffer(raw[packed_size:], dtype=np.float32)
|
| 197 |
|
| 198 |
-
packed = packed.reshape(vocab, hidden//2)
|
| 199 |
-
scales = scales.reshape(
|
| 200 |
|
| 201 |
emb = dequantize_q4_k_m(packed, scales)
|
| 202 |
|
|
|
|
| 187 |
with open(bin_path, "rb") as f:
|
| 188 |
raw = f.read()
|
| 189 |
|
|
|
|
| 190 |
hidden = state["dyt.alpha"].shape[0]
|
| 191 |
+
total_uint8 = len(raw)
|
| 192 |
+
|
| 193 |
+
bytes_per_row = hidden // 2 + 4
|
| 194 |
+
vocab = total_uint8 // bytes_per_row
|
| 195 |
|
| 196 |
packed_size = vocab * hidden // 2
|
| 197 |
|
| 198 |
packed = np.frombuffer(raw[:packed_size], dtype=np.uint8)
|
| 199 |
scales = np.frombuffer(raw[packed_size:], dtype=np.float32)
|
| 200 |
|
| 201 |
+
packed = packed.reshape(vocab, hidden // 2)
|
| 202 |
+
scales = scales.reshape(vocab, 1)
|
| 203 |
|
| 204 |
emb = dequantize_q4_k_m(packed, scales)
|
| 205 |
|