ITFormer / utils /accelerate_compat.py
a12354's picture
Add files using upload-large-folder tool
f48983a verified
Raw
History Blame Contribute Delete
978 Bytes
import inspect
def patch_accelerate_unwrap_model():
    """Allow a newer Transformers ``Trainer`` to run with an older Accelerate.

    Some Transformers versions call ``Accelerator.unwrap_model`` with the
    ``keep_torch_compile`` keyword. Older Accelerate releases do not accept
    that keyword, which raises a ``TypeError`` before training starts. This
    replaces ``Accelerator.unwrap_model`` on the class with a wrapper that
    accepts (and ignores) ``keep_torch_compile`` and forwards the remaining
    arguments to the original method.

    Calling this more than once is harmless: the installed wrapper itself
    declares ``keep_torch_compile``, so later calls find nothing to patch.

    Returns:
        bool: ``True`` if the patch was installed; ``False`` if
        ``Accelerator.unwrap_model`` already accepts ``keep_torch_compile``
        (natively or because the patch is already in place).
    """
    from accelerate import Accelerator

    params = inspect.signature(Accelerator.unwrap_model).parameters
    if "keep_torch_compile" in params:
        return False

    original_unwrap_model = Accelerator.unwrap_model
    # Decide once, from the signature, whether the original accepts
    # keep_fp32_wrapper. Probing with ``try/except TypeError`` at call time
    # would also catch TypeErrors raised *inside* unwrap_model. It would then
    # re-run the method without the caller's keep_fp32_wrapper value.
    accepts_fp32_flag = "keep_fp32_wrapper" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

    # NOTE: functools.wraps is deliberately not used here. It sets __wrapped__,
    # which inspect.signature follows. A second call would then not see
    # keep_torch_compile and would wrap the wrapper again.
    def unwrap_model_compat(
        self, model, keep_fp32_wrapper=True, keep_torch_compile=False
    ):
        # keep_torch_compile is accepted only so the call does not fail; it is
        # not forwarded to the original method and has no effect here.
        if accepts_fp32_flag:
            return original_unwrap_model(
                self,
                model,
                keep_fp32_wrapper=keep_fp32_wrapper,
            )
        return original_unwrap_model(self, model)

    Accelerator.unwrap_model = unwrap_model_compat
    return True