Fix tensor shape error
#7
by
hiyouga
- opened
- modeling_chatglm.py +4 -7
modeling_chatglm.py
CHANGED
|
@@ -253,15 +253,12 @@ class CoreAttention(torch.nn.Module):
|
|
| 253 |
# This is actually dropping out entire tokens to attend to, which might
|
| 254 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 255 |
attention_probs = self.attention_dropout(attention_probs)
|
| 256 |
-
# =========================
|
| 257 |
-
# Context layer. [sq, b, hp]
|
| 258 |
-
# =========================
|
| 259 |
-
|
| 260 |
-
# value_layer -> context layer.
|
| 261 |
-
# [sk, b, np, hn] --> [b, np, sq, hn]
|
| 262 |
|
|
|
|
|
|
|
|
|
|
| 263 |
# context layer shape: [b, np, sq, hn]
|
| 264 |
-
output_size = (value_layer.size(
|
| 265 |
# change view [b * np, sk, hn]
|
| 266 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
| 267 |
# change view [b * np, sq, sk]
|
|
|
|
| 253 |
# This is actually dropping out entire tokens to attend to, which might
|
| 254 |
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 255 |
attention_probs = self.attention_dropout(attention_probs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
# query layer shape: [b * np, sq, hn]
|
| 258 |
+
# value layer shape: [b, np, sk, hn]
|
| 259 |
+
# attention shape: [b, np, sq, sk]
|
| 260 |
# context layer shape: [b, np, sq, hn]
|
| 261 |
+
output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
|
| 262 |
# change view [b * np, sk, hn]
|
| 263 |
value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
|
| 264 |
# change view [b * np, sq, sk]
|