| import inspect | |
def patch_accelerate_unwrap_model(accelerator_cls=None):
    """Allow newer Transformers Trainer to run with older Accelerate.

    Some Transformers versions call ``Accelerator.unwrap_model`` with the
    ``keep_torch_compile`` keyword. Older Accelerate releases do not accept
    that keyword, which raises a ``TypeError`` before training starts.

    This replaces ``unwrap_model`` on the class with a thin wrapper that
    accepts ``keep_torch_compile`` and does not forward it. All other
    arguments go to the original method.

    Args:
        accelerator_cls: Class whose ``unwrap_model`` should be patched.
            Defaults to ``accelerate.Accelerator``, which is imported lazily
            only in that case.

    Returns:
        True if ``unwrap_model`` was replaced. False if it already accepts
        ``keep_torch_compile``, in which case nothing is changed. Because
        the installed wrapper itself declares ``keep_torch_compile``,
        calling this function again is a no-op that returns False.
    """
    if accelerator_cls is None:
        from accelerate import Accelerator

        accelerator_cls = Accelerator

    original_unwrap_model = accelerator_cls.unwrap_model
    params = inspect.signature(original_unwrap_model).parameters
    if "keep_torch_compile" in params:
        return False

    # Decide once, from the signature, whether keep_fp32_wrapper can be
    # forwarded. Retrying on any TypeError at call time would also swallow
    # TypeErrors raised *inside* unwrap_model. It would then re-run it with
    # keep_fp32_wrapper silently reset to its default.
    forward_fp32_flag = "keep_fp32_wrapper" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

    def unwrap_model_compat(
        self, model, keep_fp32_wrapper=True, keep_torch_compile=False
    ):
        # keep_torch_compile is accepted only so newer callers don't fail
        # with a TypeError. It is intentionally not forwarded, because the
        # original method does not accept it.
        if forward_fp32_flag:
            return original_unwrap_model(
                self,
                model,
                keep_fp32_wrapper=keep_fp32_wrapper,
            )
        return original_unwrap_model(self, model)

    accelerator_cls.unwrap_model = unwrap_model_compat
    return True