Ensure tests also work in test shells
Browse files- tests/test_rotary.py +6 -3
tests/test_rotary.py
CHANGED
|
@@ -2,7 +2,6 @@ import pytest
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from tests.utils import infer_device, supports_bfloat16
|
| 5 |
-
from kernels import get_local_kernel
|
| 6 |
from pathlib import Path
|
| 7 |
|
| 8 |
# import rotary
|
|
@@ -10,8 +9,12 @@ from pathlib import Path
|
|
| 10 |
# set_seed(42)
|
| 11 |
|
| 12 |
# Set the local repo path, relative path
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
|
| 17 |
assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
|
|
|
|
| 2 |
import torch
|
| 3 |
|
| 4 |
from tests.utils import infer_device, supports_bfloat16
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
# import rotary
|
|
|
|
| 9 |
# set_seed(42)
|
| 10 |
|
| 11 |
# Set the local repo path, relative path
|
| 12 |
+
try:
|
| 13 |
+
import rotary
|
| 14 |
+
except ImportError:
|
| 15 |
+
from kernels import get_local_kernel
|
| 16 |
+
repo_path = Path(__file__).parent.parent
|
| 17 |
+
rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
|
| 18 |
|
| 19 |
def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
|
| 20 |
assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
|