Update modeling_internlm.py
Browse files- modeling_internlm.py +4 -16
modeling_internlm.py
CHANGED
|
@@ -243,22 +243,10 @@ def rotate_half(x):
|
|
| 243 |
|
| 244 |
# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
|
| 245 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
position_ids = position_ids.flatten() + 1
|
| 252 |
-
max_length = max(position_ids)
|
| 253 |
-
position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
|
| 254 |
-
k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
|
| 255 |
-
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
|
| 256 |
-
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
|
| 257 |
-
else:
|
| 258 |
-
cos = cos[position_ids].unsqueeze(1)
|
| 259 |
-
sin = sin[position_ids].unsqueeze(1)
|
| 260 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 261 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 262 |
return q_embed, k_embed
|
| 263 |
|
| 264 |
|
|
|
|
| 243 |
|
| 244 |
# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
|
| 245 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
| 246 |
+
cos = cos[position_ids].unsqueeze(1)
|
| 247 |
+
sin = sin[position_ids].unsqueeze(1)
|
| 248 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 249 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
return q_embed, k_embed
|
| 251 |
|
| 252 |
|