Update modeling_chatglm.py
Browse files- modeling_chatglm.py +10 -6
modeling_chatglm.py
CHANGED
|
@@ -516,6 +516,7 @@ class GLMBlock(torch.nn.Module):
|
|
| 516 |
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
| 517 |
super(GLMBlock, self).__init__()
|
| 518 |
self.layer_number = layer_number
|
|
|
|
| 519 |
|
| 520 |
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
| 521 |
|
|
@@ -524,7 +525,7 @@ class GLMBlock(torch.nn.Module):
|
|
| 524 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 525 |
# Layernorm on the input data.
|
| 526 |
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 527 |
-
dtype=
|
| 528 |
|
| 529 |
# Self attention.
|
| 530 |
self.self_attention = SelfAttention(config, layer_number, device=device)
|
|
@@ -532,7 +533,7 @@ class GLMBlock(torch.nn.Module):
|
|
| 532 |
|
| 533 |
# Layernorm on the attention output
|
| 534 |
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 535 |
-
dtype=
|
| 536 |
|
| 537 |
# MLP
|
| 538 |
self.mlp = MLP(config, device=device)
|
|
@@ -600,9 +601,10 @@ class GLMTransformer(torch.nn.Module):
|
|
| 600 |
|
| 601 |
if self.post_layer_norm:
|
| 602 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
|
|
|
| 603 |
# Final layer norm before output.
|
| 604 |
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 605 |
-
dtype=
|
| 606 |
|
| 607 |
self.gradient_checkpointing = False
|
| 608 |
|
|
@@ -711,13 +713,14 @@ class Embedding(torch.nn.Module):
|
|
| 711 |
|
| 712 |
def __init__(self, config: ChatGLMConfig, device=None):
|
| 713 |
super(Embedding, self).__init__()
|
|
|
|
| 714 |
|
| 715 |
self.hidden_size = config.hidden_size
|
| 716 |
# Word embeddings (parallel).
|
| 717 |
self.word_embeddings = nn.Embedding(
|
| 718 |
config.padded_vocab_size,
|
| 719 |
self.hidden_size,
|
| 720 |
-
dtype=
|
| 721 |
device=device
|
| 722 |
)
|
| 723 |
self.fp32_residual_connection = config.fp32_residual_connection
|
|
@@ -748,6 +751,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 748 |
self.num_layers = config.num_layers
|
| 749 |
self.multi_query_group_num = config.multi_query_group_num
|
| 750 |
self.kv_channels = config.kv_channels
|
|
|
|
| 751 |
|
| 752 |
# Rotary positional embeddings
|
| 753 |
self.seq_length = config.seq_length
|
|
@@ -756,10 +760,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 756 |
)
|
| 757 |
|
| 758 |
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
| 759 |
-
dtype=
|
| 760 |
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
| 761 |
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
| 762 |
-
dtype=
|
| 763 |
self.pre_seq_len = config.pre_seq_len
|
| 764 |
self.prefix_projection = config.prefix_projection
|
| 765 |
if self.pre_seq_len is not None:
|
|
|
|
| 516 |
def __init__(self, config: ChatGLMConfig, layer_number, device=None):
|
| 517 |
super(GLMBlock, self).__init__()
|
| 518 |
self.layer_number = layer_number
|
| 519 |
+
dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
|
| 520 |
|
| 521 |
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
| 522 |
|
|
|
|
| 525 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 526 |
# Layernorm on the input data.
|
| 527 |
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 528 |
+
dtype=dtype)
|
| 529 |
|
| 530 |
# Self attention.
|
| 531 |
self.self_attention = SelfAttention(config, layer_number, device=device)
|
|
|
|
| 533 |
|
| 534 |
# Layernorm on the attention output
|
| 535 |
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 536 |
+
dtype=dtype)
|
| 537 |
|
| 538 |
# MLP
|
| 539 |
self.mlp = MLP(config, device=device)
|
|
|
|
| 601 |
|
| 602 |
if self.post_layer_norm:
|
| 603 |
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
| 604 |
+
dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
|
| 605 |
# Final layer norm before output.
|
| 606 |
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
|
| 607 |
+
dtype=dtype)
|
| 608 |
|
| 609 |
self.gradient_checkpointing = False
|
| 610 |
|
|
|
|
| 713 |
|
| 714 |
def __init__(self, config: ChatGLMConfig, device=None):
|
| 715 |
super(Embedding, self).__init__()
|
| 716 |
+
dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
|
| 717 |
|
| 718 |
self.hidden_size = config.hidden_size
|
| 719 |
# Word embeddings (parallel).
|
| 720 |
self.word_embeddings = nn.Embedding(
|
| 721 |
config.padded_vocab_size,
|
| 722 |
self.hidden_size,
|
| 723 |
+
dtype=dtype,
|
| 724 |
device=device
|
| 725 |
)
|
| 726 |
self.fp32_residual_connection = config.fp32_residual_connection
|
|
|
|
| 751 |
self.num_layers = config.num_layers
|
| 752 |
self.multi_query_group_num = config.multi_query_group_num
|
| 753 |
self.kv_channels = config.kv_channels
|
| 754 |
+
dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
|
| 755 |
|
| 756 |
# Rotary positional embeddings
|
| 757 |
self.seq_length = config.seq_length
|
|
|
|
| 760 |
)
|
| 761 |
|
| 762 |
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
|
| 763 |
+
dtype=dtype)
|
| 764 |
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
|
| 765 |
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
| 766 |
+
dtype=dtype, **init_kwargs)
|
| 767 |
self.pre_seq_len = config.pre_seq_len
|
| 768 |
self.prefix_projection = config.prefix_projection
|
| 769 |
if self.pre_seq_len is not None:
|