TheTrueJard commited on
Commit
f70ae43
·
verified ·
1 Parent(s): 8bc54c9

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. modeling.py +3 -13
  2. psi.py +3 -2
modeling.py CHANGED
@@ -3,21 +3,11 @@ import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import math
6
-
7
- import sys
8
- if "torch_xla" not in sys.modules:
9
- sys.modules["torch_xla"] = type(sys)("torch_xla")
10
- # Recreate the submodule structure
11
- sys.modules["torch_xla"].core = type(sys)("core")
12
- sys.modules["torch_xla"].distributed = type(sys)("distributed")
13
- sys.modules["torch_xla"].distributed.spmd = type(sys)("spmd")
14
- # Set required attributes to None (matching your try-except)
15
- sys.modules["torch_xla"].core.xla_model = None
16
- sys.modules["torch_xla"].distributed.spmd.xla_sharding = None
17
 
18
  try:
19
- import torch_xla.core.xla_model as xm
20
- import torch_xla.distributed.spmd.xla_sharding as xs
21
  except ImportError:
22
  xm = None
23
  xs = None
 
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
  import math
6
+ import importlib
 
 
 
 
 
 
 
 
 
 
7
 
8
  try:
9
+ xm = importlib.import_module('torch_xla.core.xla_model')
10
+ xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
11
  except ImportError:
12
  xm = None
13
  xs = None
psi.py CHANGED
@@ -4,6 +4,7 @@ PSI Model Definition
4
 
5
 
6
  import math
 
7
  from typing import Tuple, Union, List, Optional, Callable, Dict
8
  from transformers import PreTrainedModel
9
  import torch
@@ -18,8 +19,8 @@ from .modeling import (
18
  )
19
 
20
  try:
21
- import torch_xla.core.xla_model as xm
22
- import torch_xla.distributed.spmd.xla_sharding as xs
23
  except ImportError:
24
  xm = None
25
  xs = None
 
4
 
5
 
6
  import math
7
+ import importlib
8
  from typing import Tuple, Union, List, Optional, Callable, Dict
9
  from transformers import PreTrainedModel
10
  import torch
 
19
  )
20
 
21
  try:
22
+ xm = importlib.import_module('torch_xla.core.xla_model')
23
+ xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
24
  except ImportError:
25
  xm = None
26
  xs = None