Spaces:
Running
on
Zero
Running
on
Zero
v1
Browse files
meteor/arch/modeling_internlm2.py
CHANGED
|
@@ -277,8 +277,8 @@ def rotate_half(x):
|
|
| 277 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 278 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 279 |
"""Applies Rotary Position Embedding to the query and key tensors."""
|
| 280 |
-
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 281 |
-
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 282 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 283 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 284 |
return q_embed, k_embed
|
|
|
|
| 277 |
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
| 278 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 279 |
"""Applies Rotary Position Embedding to the query and key tensors."""
|
| 280 |
+
cos = cos.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
|
| 281 |
+
sin = sin.to(position_ids.device)[position_ids].unsqueeze(unsqueeze_dim)
|
| 282 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 283 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 284 |
return q_embed, k_embed
|