man2machine committed on
Commit
274f8d5
·
verified ·
1 Parent(s): 2fe50e8

is_torch_fx_available removed in newer versions of transformers package

Browse files
Files changed (1) hide show
  1. modeling_slimmoe.py +3 -5
modeling_slimmoe.py CHANGED
@@ -46,7 +46,6 @@ from transformers.utils import (
46
  logging,
47
  replace_return_docstrings,
48
  )
49
- from transformers.utils.import_utils import is_torch_fx_available
50
  from .configuration_slimmoe import PhiMoEConfig
51
 
52
  from einops import rearrange
@@ -61,11 +60,10 @@ if is_flash_attn_2_available():
61
 
62
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
63
  # It means that the function will not be traced through and simply appear as a node in the graph.
64
- if is_torch_fx_available():
65
- if not is_torch_greater_or_equal_than_1_13:
66
- import torch.fx
67
 
68
- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
69
 
70
 
71
  logger = logging.get_logger(__name__)
 
46
  logging,
47
  replace_return_docstrings,
48
  )
 
49
  from .configuration_slimmoe import PhiMoEConfig
50
 
51
  from einops import rearrange
 
60
 
61
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
62
  # It means that the function will not be traced through and simply appear as a node in the graph.
63
+ if not is_torch_greater_or_equal_than_1_13:
64
+ import torch.fx
 
65
 
66
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
67
 
68
 
69
  logger = logging.get_logger(__name__)