File size: 254 Bytes
a629f10
 
 
 
 
 
 
1
2
3
4
5
6
7
8
from safetensors.torch import load_file

def load_checkpoint(filepath, model, strict=True, *, device="cpu"):
    """Load a safetensors checkpoint into ``model`` and return it.

    Args:
        filepath: Path to the ``.safetensors`` file to read.
        model: Object exposing ``load_state_dict`` (presumably a
            ``torch.nn.Module`` -- confirm against callers); its state is
            updated in place.
        strict: Forwarded to ``model.load_state_dict``. When False, keys that
            did not match are printed instead of being silently discarded.
        device: Device passed to ``load_file`` for the loaded tensors.
            Defaults to ``"cpu"``, which is also ``load_file``'s own default,
            so existing callers see unchanged behavior.

    Returns:
        The same ``model`` object, after the checkpoint has been loaded.

    Raises:
        Whatever ``load_file`` raises (e.g. unreadable file) or
        ``model.load_state_dict`` raises (e.g. key mismatches when
        ``strict`` is True); nothing is caught here.
    """
    print(f"Loading checkpoint: {filepath}")
    state_dict = load_file(filepath, device=device)
    result = model.load_state_dict(state_dict, strict=strict)

    # With strict=False, load_state_dict reports mismatches via its return
    # value instead of raising. Surface them so a partially-loaded model is
    # not mistaken for a complete one. getattr keeps duck-typed models that
    # return None (or a different object) working as before.
    missing = getattr(result, "missing_keys", None) or []
    unexpected = getattr(result, "unexpected_keys", None) or []
    if missing:
        print(f"Missing keys when loading {filepath} ({len(missing)}): {list(missing)}")
    if unexpected:
        print(f"Unexpected keys when loading {filepath} ({len(unexpected)}): {list(unexpected)}")
    return model