# Source file: ner-bert/outputs/save_model.py (929 bytes)
# Last commit 364d422 by aaya868868: "Add initial model files"
import torch
from transformers import BertModel
class BERT_CRF_NER(torch.nn.Module):
    """BERT + CRF named-entity-recognition model (stub definition).

    As written, the constructor only initialises the ``torch.nn.Module`` base
    class. Every argument is accepted and then discarded, and no submodules,
    parameters or buffers are registered, so ``state_dict()`` is empty.

    NOTE(review): because nothing is registered, ``load_state_dict`` (strict
    by default) will reject any checkpoint that contains parameter keys. The
    full layer definitions presumably live in the training code. TODO: copy
    them here so that saved checkpoints can actually be loaded.

    Args:
        bert_model: Presumably a pretrained ``transformers.BertModel``
            encoder (TODO confirm); currently unused.
        start_label_id: Id of the start label; currently unused.
        stop_label_id: Id of the stop label; currently unused.
        num_label_id: Label-count hyperparameter; currently unused.
        num_labels: Number of output labels; currently unused.
        max_seq_length: Maximum token sequence length; currently unused.
        batch_size: Batch size; currently unused.
        device: Target device identifier (e.g. ``"cpu"``); currently unused.
    """

    def __init__(self, bert_model, start_label_id, stop_label_id,
                 num_label_id, num_labels, max_seq_length, batch_size,
                 device):
        # Zero-argument super() resolves to the same base-class initialiser
        # as the explicit super(BERT_CRF_NER, self) form.
        super().__init__()
def main(checkpoint_path="ner_bert_crf_checkpoint.pt",
         output_path="ner_bert_crf_model.pt"):
    """Load a training checkpoint and re-save only the model weights.

    Steps:

    1. Build ``BERT_CRF_NER`` around a freshly loaded ``bert-base-cased``
       encoder.
    2. Load ``checkpoint["model_state"]`` from *checkpoint_path* into the
       model.
    3. Write the model's ``state_dict()`` to *output_path*.

    Args:
        checkpoint_path: Path of the training checkpoint. It must load as a
            mapping that has a ``"model_state"`` entry.
        output_path: Destination file for the weights-only state dict.

    Raises:
        KeyError: If the checkpoint has no ``"model_state"`` entry.
        RuntimeError: From ``load_state_dict`` (strict by default) if the
            checkpoint keys do not match the model's registered parameters.
    """
    # The hyperparameters are assumed to match the run that produced the
    # checkpoint. TODO: confirm them against the training configuration.
    bert_model = BertModel.from_pretrained("bert-base-cased")
    model = BERT_CRF_NER(
        bert_model=bert_model,
        start_label_id=0,
        stop_label_id=1,
        num_label_id=30,
        num_labels=30,
        max_seq_length=256,
        batch_size=16,
        device="cpu",
    )

    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only use it on checkpoints you trust.
    # If the checkpoint holds only tensors and plain containers, prefer
    # weights_only=True.
    # map_location="cpu" remaps all stored tensors onto the CPU.
    checkpoint = torch.load(checkpoint_path, map_location="cpu",
                            weights_only=False)
    model.load_state_dict(checkpoint["model_state"])

    # Save just the parameters/buffers rather than the pickled module object.
    torch.save(model.state_dict(), output_path)
    # Derive the message from output_path so it cannot drift from the real
    # destination.
    print(f"Model saved as {output_path}")


# Guard the entry point so that importing this module (e.g. to reuse
# BERT_CRF_NER) does not download BERT or touch checkpoint files.
if __name__ == "__main__":
    main()