Spaces:
Paused
Paused
import importlib.util
import os
import shutil
import tempfile
# Legacy install location this module originally hard-coded; used only when
# fairseq cannot be located through importlib.
_FAIRSEQ_MODEL_UTILS_LEGACY_PATH = (
    "/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py"
)
# Signature as shipped by fairseq, and the annotated form it is patched to.
_UNTYPED_EXPAND_SIG = "def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):"
_TYPED_EXPAND_SIG = (
    "def expand_2d_or_3d_tensor(x: Tensor, trg_dim: int, padding_idx: int) -> Tensor:"
)


def _find_fairseq_model_utils():
    """Return the path of fairseq's ``models/model_utils.py``.

    Uses ``importlib.util.find_spec`` on the top-level ``fairseq`` name, which
    locates the package without importing it, so the lookup works for any
    interpreter/venv layout. Falls back to the legacy python3.10 path when the
    package (or the file) cannot be found.
    """
    try:
        spec = importlib.util.find_spec("fairseq")
    except (ImportError, ValueError):
        spec = None
    locations = spec.submodule_search_locations if spec is not None else None
    for root in locations or ():
        candidate = os.path.join(root, "models", "model_utils.py")
        if os.path.isfile(candidate):
            return candidate
    return _FAIRSEQ_MODEL_UTILS_LEGACY_PATH


def _rewrite_atomically(path, text):
    """Replace the contents of *path* with *text*, never leaving it half-written.

    Writes to a temp file in the same directory, copies the original file's
    permission bits onto it, then swaps it in with ``os.replace``.
    """
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(
        dir=directory, prefix="." + os.path.basename(path) + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as tmp:
            tmp.write(text)
        # mkstemp creates the file as 0o600; keep the installed file's mode.
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def prelude(fairseq_code=None):
    """Disable the TorchScript JIT and patch fairseq's ``expand_2d_or_3d_tensor``.

    1. Sets ``PYTORCH_JIT=0v``. torch treats ``0v`` like ``0`` (JIT disabled)
       and additionally prints a notice that the JIT is disabled.
       NOTE: torch reads this variable when it is imported, so this presumably
       has to run before ``torch`` is first imported -- confirm call order.
    2. If fairseq's ``models/model_utils.py`` exists, rewrites the untyped
       signature ``expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int)``
       to ``(x: Tensor, ...) -> Tensor`` (a "patch for jit script"; the exact
       TorchScript failure it works around is not visible here -- TODO document).
       Only the signature text is substituted, so indentation and any trailing
       text on that line are kept.

    The file is rewritten only when it still contains the untyped signature, so
    repeated calls -- or an already-patched read-only install -- are no-ops.

    Args:
        fairseq_code: Path of the file to patch. Defaults to fairseq's
            ``models/model_utils.py`` as located by importlib, falling back to
            ``/usr/local/lib/python3.10/site-packages/fairseq/models/model_utils.py``.
    """
    os.environ["PYTORCH_JIT"] = "0v"

    path = fairseq_code if fairseq_code is not None else _find_fairseq_model_utils()
    if not os.path.isfile(path):
        return

    # newline="" keeps the file's original line endings byte-for-byte.
    with open(path, "r", encoding="utf-8", newline="") as f:
        source = f.read()
    patched = source.replace(_UNTYPED_EXPAND_SIG, _TYPED_EXPAND_SIG)
    if patched != source:
        _rewrite_atomically(path, patched)