|
|
import torch |
|
|
from transformers import BertModel |
|
|
|
|
|
class BERT_CRF_NER(torch.nn.Module):
    """Module shell used to rebuild the NER model before loading weights.

    As written in this file, the constructor only initialises the
    ``torch.nn.Module`` base class: none of its arguments are stored, and no
    sub-modules or parameters are registered.

    NOTE(review): the name suggests a BERT encoder followed by a CRF layer;
    presumably the full definition lives elsewhere -- confirm before relying
    on this class for inference or for ``load_state_dict``.
    """

    def __init__(
        self,
        bert_model,
        start_label_id,
        stop_label_id,
        num_label_id,
        num_labels,
        max_seq_length,
        batch_size,
        device,
    ):
        """Initialise the base ``torch.nn.Module``.

        All parameters are accepted for interface compatibility but are not
        used by the code visible here. Their meanings below are inferred from
        the names only -- TODO confirm against the training code.

        Args:
            bert_model: presumably the pretrained BERT encoder.
            start_label_id: presumably the label id of the CRF start tag.
            stop_label_id: presumably the label id of the CRF stop tag.
            num_label_id: presumably a label-count / label-id setting.
            num_labels: presumably the number of output labels.
            max_seq_length: presumably the maximum token sequence length.
            batch_size: presumably the batch size used for CRF buffers.
            device: presumably the torch device the module runs on.
        """
        super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Convert a training checkpoint into a plain state-dict file:
#   1. rebuild the model object,
#   2. load the weights stored under the checkpoint's "model_state" key,
#   3. save only the model's state dict to a new file.

# Fetches/loads the pretrained "bert-base-cased" encoder.
bert_model = BertModel.from_pretrained("bert-base-cased")

# NOTE(review): these hyper-parameters are hard-coded; presumably they must
# match the values used when the checkpoint was trained -- confirm.
model = BERT_CRF_NER(
    bert_model=bert_model,
    start_label_id=0,
    stop_label_id=1,
    num_label_id=30,
    num_labels=30,
    max_seq_length=256,
    batch_size=16,
    device="cpu",
)

checkpoint_path = "ner_bert_crf_checkpoint.pt"

# SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only use this with checkpoints from a
# trusted source. If the checkpoint holds nothing but tensors and plain
# containers, prefer weights_only=True.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# load_state_dict is strict by default: it raises if the checkpoint's keys do
# not exactly match the parameters/buffers registered on `model`.
model.load_state_dict(checkpoint["model_state"])

# Persist only the model weights (no optimizer state or other checkpoint data).
output_path = "ner_bert_crf_model.pt"
torch.save(model.state_dict(), output_path)

print(f"Model saved as {output_path}")
|
|
|