from safetensors.torch import load_file


def load_checkpoint(filepath, model, strict=True, device="cpu"):
    """Load a safetensors checkpoint into ``model`` in place and return it.

    Args:
        filepath: Path to the ``.safetensors`` file to read.
        model: Object whose ``load_state_dict`` receives the loaded tensors
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        strict: Forwarded unchanged to ``model.load_state_dict``. Per the
            PyTorch docs, ``strict=True`` raises on missing/unexpected keys.
        device: Device on which ``load_file`` materializes the tensors.
            Defaults to ``"cpu"``, matching ``load_file``'s own default and
            therefore the previous behaviour of this function.

    Returns:
        The same ``model`` object that was passed in, after loading.
    """
    print(f"Loading checkpoint: {filepath}")
    checkpoint_dict = load_file(filepath, device=device)
    # load_state_dict returns the missing/unexpected keys (relevant when
    # strict=False); that result is discarded here.
    model.load_state_dict(checkpoint_dict, strict=strict)
    return model