TheTrueJard committed on
Commit
e3d287d
·
verified ·
1 Parent(s): ed193e4

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. modeling.py +8 -9
  2. psi.py +6 -6
modeling.py CHANGED
@@ -5,12 +5,12 @@ import torch.nn.functional as F
5
  import math
6
  import importlib
7
 
8
- #try:
9
- # xm = importlib.import_module('torch_xla.core.xla_model')
10
- # xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
11
- #except ImportError:
12
- xm = None
13
- xs = None
14
 
15
 
16
  class Rotary3D(nn.Module):
@@ -102,8 +102,7 @@ class PSIAttentionLayer(nn.Module):
102
  # check if we are running on TPU
103
  try:
104
  # Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
105
- #xm_local = importlib.import_module('torch_xla.core.xla_model')
106
- raise ImportError
107
  self.tpu = True
108
  except ImportError:
109
  self.tpu = False
@@ -153,7 +152,7 @@ class PSIAttentionLayer(nn.Module):
153
  # Apply attention
154
  if self.tpu:
155
  # (1)
156
- #from torch_xla.experimental.custom_kernel import flash_attention
157
  q_norm = q / math.sqrt(k.size(-1))
158
  y = flash_attention(
159
  q_norm, k, v,
 
5
  import math
6
  import importlib
7
 
8
+ try:
9
+ xm = importlib.import_module('torch_xla.core.xla_model')
10
+ xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
11
+ except ImportError:
12
+ xm = None
13
+ xs = None
14
 
15
 
16
  class Rotary3D(nn.Module):
 
102
  # check if we are running on TPU
103
  try:
104
  # Use local import to avoid conflict if global xm is None and to check TPU specifically for this flag
105
+ xm_local = importlib.import_module('torch_xla.core.xla_model')
 
106
  self.tpu = True
107
  except ImportError:
108
  self.tpu = False
 
152
  # Apply attention
153
  if self.tpu:
154
  # (1)
155
+ flash_attention = importlib.import_module('torch_xla.experimental.custom_kernel.flash_attention')
156
  q_norm = q / math.sqrt(k.size(-1))
157
  y = flash_attention(
158
  q_norm, k, v,
psi.py CHANGED
@@ -18,12 +18,12 @@ from .modeling import (
18
  RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
19
  )
20
 
21
- #try:
22
- # xm = importlib.import_module('torch_xla.core.xla_model')
23
- # xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
24
- #except ImportError:
25
- xm = None
26
- xs = None
27
 
28
 
29
 
 
18
  RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
19
  )
20
 
21
+ try:
22
+ xm = importlib.import_module('torch_xla.core.xla_model')
23
+ xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
24
+ except ImportError:
25
+ xm = None
26
+ xs = None
27
 
28
 
29