File size: 929 Bytes
364d422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from transformers import BertModel

class BERT_CRF_NER(torch.nn.Module):
    """Placeholder for a BERT + CRF named-entity-recognition module.

    NOTE(review): this definition is only a stub. The constructor accepts the
    arguments below but stores none of them and registers no submodules,
    parameters or buffers. As a result ``state_dict()`` is empty, and
    ``load_state_dict`` with a non-empty state dict will raise (unexpected
    keys) under the default ``strict=True``.
    TODO: replace with the architecture the checkpoint was trained with.

    Args:
        bert_model: Encoder instance (currently unused).
        start_label_id: Accepted but currently unused.
        stop_label_id: Accepted but currently unused.
        num_label_id: Accepted but currently unused.
        num_labels: Accepted but currently unused.
        max_seq_length: Accepted but currently unused.
        batch_size: Accepted but currently unused.
        device: Accepted but currently unused.
    """

    def __init__(
        self,
        bert_model,
        start_label_id,
        stop_label_id,
        num_label_id,
        num_labels,
        max_seq_length,
        batch_size,
        device,
    ):
        # Only the base nn.Module initialisation happens; every argument is
        # currently discarded (see the class docstring).
        super().__init__()

# Rebuild the BERT-CRF NER module, restore its weights from the checkpoint's
# "model_state" entry, and save just the model's state_dict to a standalone
# file.
# NOTE(review): the constructor hyperparameters below must match the ones the
# checkpoint was produced with; otherwise load_state_dict (strict by default)
# raises on missing/unexpected keys or shape mismatches — TODO confirm
# against the training configuration.

bert_model = BertModel.from_pretrained("bert-base-cased")
model = BERT_CRF_NER(
    bert_model=bert_model,
    start_label_id=0,
    stop_label_id=1,
    num_label_id=30,
    num_labels=30,
    max_seq_length=256,
    batch_size=16,
    device="cpu",
)

checkpoint_path = "ner_bert_crf_checkpoint.pt"
# Single source of truth for the output file, so the save call and the
# log message cannot drift apart.
model_output_path = "ner_bert_crf_model.pt"

# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code while loading. Only use checkpoints from a
# trusted source; if this checkpoint contains only tensors and plain
# containers, prefer weights_only=True.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Fail with a descriptive message (rather than a bare KeyError) when the
# checkpoint was saved with a different layout.
if "model_state" not in checkpoint:
    raise KeyError(
        f"'model_state' not found in checkpoint {checkpoint_path!r}; "
        f"available keys: {list(checkpoint)}"
    )
model.load_state_dict(checkpoint["model_state"])

# Save only the state_dict (parameters and persistent buffers), not the
# pickled module object.
torch.save(model.state_dict(), model_output_path)
print(f"Model saved as {model_output_path}")