File size: 623 Bytes
db60e24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
模型定义:基于 BertForSequenceClassification 的二分类模型
使用 HuggingFace 原生的 save_pretrained / from_pretrained 实现可靠保存/加载
"""

from transformers import BertForSequenceClassification


def create_model(
    model_name: str = "bert-base-chinese",
    num_labels: int = 2,
    **kwargs,
) -> BertForSequenceClassification:
    """Create a sequence-classification model from a pretrained BERT checkpoint.

    Args:
        model_name: Hub model id or local directory passed to
            ``BertForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes for the classification head
            (2 by default, i.e. binary classification).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``from_pretrained`` (e.g. ``id2label``/``label2id``,
            ``cache_dir``, ``revision``). Passing ``num_labels`` here as
            well raises ``TypeError`` (duplicate keyword).

    Returns:
        The loaded ``BertForSequenceClassification`` instance.
    """
    # The base checkpoint has no classification head, so from_pretrained
    # initializes a fresh head sized by num_labels; the encoder weights come
    # from the checkpoint.
    return BertForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        **kwargs,
    )


def load_model(model_path: str, **kwargs) -> BertForSequenceClassification:
    """Load a previously trained sequence-classification model.

    Args:
        model_path: Hub model id or local directory to load from, presumably
            one written by ``save_pretrained`` after training (confirm
            against the caller).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``BertForSequenceClassification.from_pretrained``
            (e.g. ``cache_dir``, ``local_files_only``).

    Returns:
        The loaded ``BertForSequenceClassification`` instance.
    """
    # num_labels is not passed here: the label count is taken from the
    # config stored alongside the checkpoint.
    return BertForSequenceClassification.from_pretrained(model_path, **kwargs)