duzx16 committed on
Commit ·
5fc46d2
1
Parent(s): bfb1a8f
Fix embedding quantization
Browse files- modeling_chatglm.py +10 -5
modeling_chatglm.py
CHANGED
|
@@ -1408,6 +1408,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1408 |
|
| 1409 |
self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
|
| 1410 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1411 |
if quantize_embeddings:
|
| 1412 |
logger.info("Applying quantization to embeddings")
|
| 1413 |
self.transformer.word_embeddings = QuantizedEmbedding(
|
|
@@ -1415,11 +1420,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1415 |
weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
|
| 1416 |
num_embeddings=self.transformer.word_embeddings.num_embeddings,
|
| 1417 |
embedding_dim=self.transformer.word_embeddings.embedding_dim,
|
| 1418 |
-
dtype=
|
| 1419 |
-
empty_init=
|
| 1420 |
device=self.transformer.word_embeddings.weight.device,
|
| 1421 |
)
|
| 1422 |
-
self.lm_head =
|
| 1423 |
weight_bit_width=bits,
|
| 1424 |
weight_tensor=self.lm_head.weight.to(self.device),
|
| 1425 |
bias_tensor=None,
|
|
@@ -1428,8 +1433,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1428 |
bias=False,
|
| 1429 |
quantized_weight=self.transformer.word_embeddings.weight,
|
| 1430 |
quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
|
| 1431 |
-
dtype=
|
| 1432 |
-
empty_init=
|
| 1433 |
device=self.lm_head.weight.device,
|
| 1434 |
)
|
| 1435 |
|
|
|
|
| 1408 |
|
| 1409 |
self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs)
|
| 1410 |
|
| 1411 |
+
if self.device == torch.device("cpu"):
|
| 1412 |
+
dtype = torch.float32
|
| 1413 |
+
else:
|
| 1414 |
+
dtype = torch.half
|
| 1415 |
+
|
| 1416 |
if quantize_embeddings:
|
| 1417 |
logger.info("Applying quantization to embeddings")
|
| 1418 |
self.transformer.word_embeddings = QuantizedEmbedding(
|
|
|
|
| 1420 |
weight_tensor=self.transformer.word_embeddings.weight.to(self.device),
|
| 1421 |
num_embeddings=self.transformer.word_embeddings.num_embeddings,
|
| 1422 |
embedding_dim=self.transformer.word_embeddings.embedding_dim,
|
| 1423 |
+
dtype=dtype,
|
| 1424 |
+
empty_init=empty_init,
|
| 1425 |
device=self.transformer.word_embeddings.weight.device,
|
| 1426 |
)
|
| 1427 |
+
self.lm_head = QuantizedLinear(
|
| 1428 |
weight_bit_width=bits,
|
| 1429 |
weight_tensor=self.lm_head.weight.to(self.device),
|
| 1430 |
bias_tensor=None,
|
|
|
|
| 1433 |
bias=False,
|
| 1434 |
quantized_weight=self.transformer.word_embeddings.weight,
|
| 1435 |
quantized_weight_scale=self.transformer.word_embeddings.weight_scale,
|
| 1436 |
+
dtype=dtype,
|
| 1437 |
+
empty_init=empty_init,
|
| 1438 |
device=self.lm_head.weight.device,
|
| 1439 |
)
|
| 1440 |
|