duzx16 committed · Commit 5fe53eb · Parent(s): 74d61a6

Fix checkpointing

modeling_chatglm.py CHANGED (+10 -6)
@@ -63,7 +63,7 @@ class PrefixEncoder(torch.nn.Module):
     Output shape: (batch-size, prefix-length, 2*layers*hidden)
     """
 
-    def __init__(self, config):
+    def __init__(self, config: ChatGLMConfig):
         super().__init__()
         self.prefix_projection = config.prefix_projection
         if self.prefix_projection:
@@ -75,7 +75,8 @@ class PrefixEncoder(torch.nn.Module):
                 torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
             )
         else:
-            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
+            self.embedding = torch.nn.Embedding(config.pre_seq_len,
+                                                config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
 
     def forward(self, prefix: torch.Tensor):
         if self.prefix_projection:
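Note on this hunk: ChatGLM2 uses multi-query attention, so the per-layer key/value width is kv_channels * multi_query_group_num rather than hidden_size, and the non-projection Embedding must match that width. A minimal sketch of the two widths, assuming ChatGLM2-6B's published config values (num_layers=28, hidden_size=4096, kv_channels=128, multi_query_group_num=2):

# Illustrative numbers, assumed from ChatGLM2-6B's config.json.
num_layers, hidden_size = 28, 4096
kv_channels, multi_query_group_num = 128, 2

# The old width assumed full multi-head attention: one K and one V vector
# of hidden_size per layer.
old_width = num_layers * hidden_size * 2                          # 229376
# Multi-query attention stores only multi_query_group_num KV groups of
# kv_channels each, so the real per-token cache width is much smaller.
new_width = num_layers * kv_channels * multi_query_group_num * 2  # 14336
print(old_width, new_width)

With the old width, the flat prefix encoding could not be viewed as (num_layers * 2, multi_query_group_num, kv_channels) in get_prompt, which the later hunks line up with.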
@@ -629,8 +630,8 @@ class GLMTransformer(torch.nn.Module):
                     hidden_states,
                     attention_mask,
                     rotary_pos_emb,
-                    kv_cache=kv_caches[index],
-                    use_cache=use_cache
+                    kv_caches[index],
+                    use_cache
                 )
             else:
                 layer_ret = layer(
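Note on this hunk: with the reentrant backend, torch.utils.checkpoint.checkpoint forwards only positional arguments to the checkpointed callable and rejects unexpected keyword arguments, so gradient checkpointing breaks when kv_cache and use_cache are passed as keywords. A minimal runnable sketch; the layer function below is a hypothetical stand-in for a transformer block:

import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a transformer block, for illustration only.
def layer(hidden_states, attention_mask, rotary_pos_emb, kv_cache, use_cache):
    return hidden_states * 2

hidden = torch.randn(4, 8, requires_grad=True)
# All arguments passed positionally, mirroring the fixed call site.
# Under the reentrant backend, checkpoint(layer, hidden, kv_cache=None, ...)
# would instead raise an error about unexpected keyword arguments.
out = checkpoint(layer, hidden, None, None, None, False)
out.sum().backward()
print(hidden.grad.shape)  # torch.Size([4, 8])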
@@ -737,6 +738,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if device is not None:
             init_kwargs["device"] = device
         self.embedding = init_method(Embedding, config, **init_kwargs)
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
 
         # Rotary positional embeddings
         self.seq_length = config.seq_length
@@ -768,8 +772,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             batch_size,
             self.pre_seq_len,
             self.num_layers * 2,
-            self.num_attention_heads,
-            self.hidden_size // self.num_attention_heads
+            self.multi_query_group_num,
+            self.kv_channels
         )
         # seq_len, b, nh, hidden_size
         past_key_values = self.dropout(past_key_values)
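Note on this hunk: together with the attributes cached in ChatGLMModel.__init__ above, get_prompt can now view the flat prefix encoding as per-layer key/value stacks. A minimal sketch of the reshape, assuming the same ChatGLM2-6B config values as above, with random data standing in for the PrefixEncoder output; the trailing permute/split mirrors the code that follows the changed lines in get_prompt:

import torch

batch_size, pre_seq_len = 2, 64
num_layers, multi_query_group_num, kv_channels = 28, 2, 128

# Hypothetical stand-in for the PrefixEncoder output.
flat = torch.randn(batch_size, pre_seq_len,
                   num_layers * kv_channels * multi_query_group_num * 2)

# The view() matches the fixed lines: the trailing dims are now
# (multi_query_group_num, kv_channels) instead of per-head MHA dims.
past_key_values = flat.view(batch_size, pre_seq_len, num_layers * 2,
                            multi_query_group_num, kv_channels)

# Move the layer dimension to the front and split into one (key, value)
# stack of shape (2, pre_seq_len, batch, groups, channels) per layer.
past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
print(len(past_key_values), past_key_values[0].shape)
# 28 torch.Size([2, 64, 2, 2, 128])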