Upload folder using huggingface_hub
Browse files- modeling.py +8 -9
- psi.py +6 -6
modeling.py
CHANGED
|
@@ -5,12 +5,12 @@ import torch.nn.functional as F
|
|
| 5 |
import math
|
| 6 |
import importlib
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
xm = None
|
| 13 |
-
xs = None
|
| 14 |
|
| 15 |
|
| 16 |
class Rotary3D(nn.Module):
|
|
@@ -102,8 +102,7 @@ class PSIAttentionLayer(nn.Module):
|
|
| 102 |
# check if we are running on TPU
|
| 103 |
try:
|
| 104 |
# Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
|
| 105 |
-
|
| 106 |
-
raise ImportError
|
| 107 |
self.tpu = True
|
| 108 |
except ImportError:
|
| 109 |
self.tpu = False
|
|
@@ -153,7 +152,7 @@ class PSIAttentionLayer(nn.Module):
|
|
| 153 |
# Apply attention
|
| 154 |
if self.tpu:
|
| 155 |
# (1)
|
| 156 |
-
|
| 157 |
q_norm = q / math.sqrt(k.size(-1))
|
| 158 |
y = flash_attention(
|
| 159 |
q_norm, k, v,
|
|
|
|
| 5 |
import math
|
| 6 |
import importlib
|
| 7 |
|
| 8 |
+
try:
|
| 9 |
+
xm = importlib.import_module('torch_xla.core.xla_model')
|
| 10 |
+
xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
|
| 11 |
+
except ImportError:
|
| 12 |
+
xm = None
|
| 13 |
+
xs = None
|
| 14 |
|
| 15 |
|
| 16 |
class Rotary3D(nn.Module):
|
|
|
|
| 102 |
# check if we are running on TPU
|
| 103 |
try:
|
| 104 |
# Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
|
| 105 |
+
xm_local = importlib.import_module('torch_xla.core.xla_model')
|
|
|
|
| 106 |
self.tpu = True
|
| 107 |
except ImportError:
|
| 108 |
self.tpu = False
|
|
|
|
| 152 |
# Apply attention
|
| 153 |
if self.tpu:
|
| 154 |
# (1)
|
| 155 |
+
flash_attention = getattr(importlib.import_module('torch_xla.experimental.custom_kernel'), 'flash_attention')
|
| 156 |
q_norm = q / math.sqrt(k.size(-1))
|
| 157 |
y = flash_attention(
|
| 158 |
q_norm, k, v,
|
psi.py
CHANGED
|
@@ -18,12 +18,12 @@ from .modeling import (
|
|
| 18 |
RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
|
| 19 |
)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
xm = None
|
| 26 |
-
xs = None
|
| 27 |
|
| 28 |
|
| 29 |
|
|
|
|
| 18 |
RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
|
| 19 |
)
|
| 20 |
|
| 21 |
+
try:
|
| 22 |
+
xm = importlib.import_module('torch_xla.core.xla_model')
|
| 23 |
+
xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
|
| 24 |
+
except ImportError:
|
| 25 |
+
xm = None
|
| 26 |
+
xs = None
|
| 27 |
|
| 28 |
|
| 29 |
|