"""Re-export the weights of a trained BERT+CRF NER checkpoint as a bare state dict.

Loads ``ner_bert_crf_checkpoint.pt`` (a dict whose ``"model_state"`` entry holds
the model's state dict), restores it into :class:`BERT_CRF_NER`, and saves only
that state dict to ``ner_bert_crf_model.pt``.
"""

import torch

CHECKPOINT_PATH = "ner_bert_crf_checkpoint.pt"
OUTPUT_PATH = "ner_bert_crf_model.pt"


class BERT_CRF_NER(torch.nn.Module):
    """Parameter container for a BERT encoder with a CRF tagging head.

    Holds a BERT encoder (``bert``), a linear layer mapping encoder hidden
    states to per-label scores (``hidden2label``) and a label-to-label
    transition matrix (``transitions``). Only the parameters needed to load
    and re-save a checkpoint are defined; there is no ``forward`` or decoding.

    NOTE(review): the attribute names determine the state-dict keys
    (``bert.*``, ``hidden2label.weight``/``.bias``, ``transitions``) and follow
    the reference NER-BERT-CRF implementation. Confirm they match the keys in
    the checkpoint's ``"model_state"``; a strict ``load_state_dict`` raises on
    any mismatch.

    Args:
        bert_model: Pretrained encoder module, registered as ``self.bert``.
            If it exposes ``config.hidden_size`` that value sizes the label
            layer; otherwise 768 is assumed.
        start_label_id: Label index whose incoming transitions are
            initialised to -10000.
        stop_label_id: Label index whose outgoing transitions are
            initialised to -10000.
        num_label_id: Stored but unused here (presumably mirrors
            ``num_labels`` - TODO confirm).
        num_labels: Size of the label space; sets the output width of
            ``hidden2label`` and the shape of ``transitions``.
        max_seq_length: Stored as an attribute; unused by this class.
        batch_size: Stored as an attribute; unused by this class.
        device: Stored as an attribute; the module is not moved to it here.
    """

    def __init__(self, bert_model, start_label_id, stop_label_id, num_label_id,
                 num_labels, max_seq_length, batch_size, device):
        super().__init__()
        self.start_label_id = start_label_id
        self.stop_label_id = stop_label_id
        self.num_label_id = num_label_id
        self.num_labels = num_labels
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.device = device

        # Registering the encoder as a submodule puts its weights in
        # state_dict() under the "bert." prefix.
        self.bert = bert_model
        config = getattr(bert_model, "config", None)
        self.hidden_size = getattr(config, "hidden_size", 768)

        # Maps encoder hidden states to one score per label.
        self.hidden2label = torch.nn.Linear(self.hidden_size, num_labels)

        # Entry [i, j] scores a transition *to* label i *from* label j
        # (reference convention). These values are overwritten when a
        # checkpoint is loaded.
        self.transitions = torch.nn.Parameter(torch.randn(num_labels, num_labels))
        with torch.no_grad():
            # Discourage transitions into the start label and out of the stop label.
            self.transitions[start_label_id, :] = -10000.0
            self.transitions[:, stop_label_id] = -10000.0


def main():
    """Load the training checkpoint into the model and save its bare state dict."""
    # Imported here so the class above can be imported without transformers.
    from transformers import BertModel

    bert_model = BertModel.from_pretrained("bert-base-cased")
    model = BERT_CRF_NER(
        bert_model=bert_model,
        start_label_id=0,
        stop_label_id=1,
        num_label_id=30,
        num_labels=30,
        max_seq_length=256,
        batch_size=16,
        device="cpu",
    )

    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoints that come from a trusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)
    # Strict load: raises if the model's keys and the checkpoint's keys differ.
    model.load_state_dict(checkpoint["model_state"])

    # Save just the model's state dict, dropping any other checkpoint entries.
    torch.save(model.state_dict(), OUTPUT_PATH)
    print(f"Model saved as {OUTPUT_PATH}")


if __name__ == "__main__":
    main()