Update modeling_chatglm.py
modeling_chatglm.py (CHANGED, +12 -8)
@@ -489,9 +489,10 @@ class GLMBlock(torch.nn.Module):
         self.fp32_residual_connection = config.fp32_residual_connection
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -499,7 +500,7 @@ class GLMBlock(torch.nn.Module):
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                       dtype=config.torch_dtype)
+                                                       dtype=dtype)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -567,9 +568,10 @@ class GLMTransformer(torch.nn.Module):
 
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                 dtype=config.torch_dtype)
+                                                 dtype=dtype)
 
         self.gradient_checkpointing = False
 
@@ -690,10 +692,11 @@ class Embedding(torch.nn.Module):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.word_embeddings = nn.Embedding(
             config.padded_vocab_size,
             self.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=dtype,
             device=device
         )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -728,12 +731,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_dim = (
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
-
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
-                                              device=device, dtype=config.torch_dtype)
+                                              device=device, dtype=dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+                                        dtype=dtype, **init_kwargs)
 
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
@@ -1153,8 +1156,9 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
 
         self.num_labels = config.num_labels
         self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
-        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
+        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=dtype)
         if config.classifier_dropout is not None:
             self.dropout = nn.Dropout(config.classifier_dropout)
         else:
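
The resolution idiom used throughout the patch can be exercised on its own. The sketch below is illustrative only: DummyConfig and the "bfloat16" value are made-up stand-ins, not part of modeling_chatglm.py, and it assumes the common situation where torch_dtype arrives as a plain string loaded from config.json rather than as a torch.dtype object.

import torch
from torch import nn


class DummyConfig:
    # Illustrative stand-in for a model config: torch_dtype is often
    # serialized as a string (e.g. "bfloat16" in config.json) rather
    # than as a torch.dtype object.
    torch_dtype = "bfloat16"
    hidden_size = 8


config = DummyConfig()

# Same idiom as the patch: map a string name to the corresponding
# torch dtype, but pass a real torch.dtype through unchanged.
dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype

layer = nn.Linear(config.hidden_size, config.hidden_size, dtype=dtype)
print(layer.weight.dtype)  # torch.bfloat16

Passing the raw string straight through (dtype="bfloat16") would fail when a module such as nn.Linear or nn.Embedding allocates its parameters, since torch.empty expects a torch.dtype, which is presumably what this change guards against.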