File size: 415 Bytes
9b02c0a 5b65d96 9b02c0a 5b65d96 9b02c0a 5b65d96 9b02c0a |
1 2 3 4 5 6 7 8 9 10 11 12 |
import torch
# from safetensors.torch import load_file
from PIGuard.modeling_piguard import PIGuard, PIGuardConfig
def convert_checkpoint(
    checkpoint_path="/home/hao/epoch_1_600_model.pth",
    output_dir="saves",
    base_model="microsoft/deberta-v3-base",
    num_labels=2,
    strict=False,
):
    """Load a raw PyTorch PIGuard checkpoint and re-save it with ``save_pretrained``.

    Builds a ``PIGuard`` model from the ``base_model`` config, loads the
    tensors stored at ``checkpoint_path`` into it, and writes the result to
    ``output_dir`` via ``model.save_pretrained``.

    Args:
        checkpoint_path: Path to a ``torch.save``-d state dict (e.g. ``.pth``).
            A ``.safetensors`` file would instead need
            ``safetensors.torch.load_file``.
        output_dir: Directory passed to ``model.save_pretrained``.
        base_model: Hub id or local path ``PIGuardConfig`` is read from.
        num_labels: Number of labels set on the config before the model is
            built.
        strict: Forwarded to ``load_state_dict``. With ``False`` key
            mismatches are printed instead of raised.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        RuntimeError: If none of the model's state-dict entries were found in
            the checkpoint, since saving would then write untrained weights.
    """
    config = PIGuardConfig.from_pretrained(base_model)
    config.num_labels = num_labels
    model = PIGuard(config)

    # map_location="cpu" lets a checkpoint saved on GPU be converted on a
    # CPU-only machine; weights_only=True refuses to unpickle arbitrary
    # objects, which a plain tensor state dict never needs.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Previously the checkpoint was loaded but never applied, so the saved
    # model carried only its freshly initialised weights.
    result = model.load_state_dict(state_dict, strict=strict)
    if result.missing_keys:
        print(f"Missing keys (left at initial values): {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Unexpected keys (ignored): {result.unexpected_keys}")
    # With strict=False a completely mismatched checkpoint (e.g. a key prefix
    # difference) would load nothing; refuse to save such a model.
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"No parameters from {checkpoint_path!r} matched the model; "
            "refusing to save an untrained model."
        )

    model.save_pretrained(output_dir)
    return model


if __name__ == "__main__":
    convert_checkpoint()