buchi-stdesign's picture
Update src/sbv2/utils.py
a629f10 verified
raw
history blame contribute delete
254 Bytes
from safetensors.torch import load_file
def load_checkpoint(filepath, model, strict=True, *, device="cpu"):
    """Load a ``.safetensors`` checkpoint into ``model`` and return it.

    Args:
        filepath: Path of the safetensors file to read.
        model: Object whose ``load_state_dict`` receives the loaded tensors
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        strict: Forwarded to ``model.load_state_dict``. When ``False``,
            key mismatches are tolerated but reported on stdout.
        device: Device ``load_file`` materialises the tensors on. The
            default ``"cpu"`` matches ``load_file``'s own default.

    Returns:
        The same ``model`` object, after ``load_state_dict`` has been applied.

    Raises:
        RuntimeError: From ``torch.nn.Module.load_state_dict`` when
            ``strict`` is ``True`` and the keys do not match.
    """
    print(f"Loading checkpoint: {filepath}")
    checkpoint_dict = load_file(filepath, device=device)
    result = model.load_state_dict(checkpoint_dict, strict=strict)

    # With strict=False, torch silently skips mismatched keys and only reports
    # them via the return value; surface them so a partially-loaded model is
    # not mistaken for a fully-loaded one. getattr() keeps this tolerant of
    # model types whose load_state_dict returns nothing.
    if not strict:
        missing = getattr(result, "missing_keys", None)
        unexpected = getattr(result, "unexpected_keys", None)
        if missing:
            print(f"Missing keys when loading {filepath}: {missing}")
        if unexpected:
            print(f"Unexpected keys when loading {filepath}: {unexpected}")

    return model