crackrammer's picture
Upload folder using huggingface_hub
db60e24 verified
"""
模型定义:基于 BertForSequenceClassification 的二分类模型
使用 HuggingFace 原生的 save_pretrained / from_pretrained 实现可靠保存/加载
"""
from transformers import BertForSequenceClassification
def create_model(model_name: str = "bert-base-chinese", num_labels: int = 2):
    """Build a BERT sequence-classification model from a pretrained checkpoint.

    Args:
        model_name: Checkpoint identifier handed to
            ``BertForSequenceClassification.from_pretrained``; defaults to
            ``"bert-base-chinese"``.
        num_labels: Forwarded to ``from_pretrained`` as ``num_labels``;
            defaults to 2.

    Returns:
        Whatever ``BertForSequenceClassification.from_pretrained`` returns for
        the given arguments.
    """
    return BertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
def load_model(model_path: str):
    """Load a previously trained classification model.

    Args:
        model_path: Location handed unchanged to
            ``BertForSequenceClassification.from_pretrained``.

    Returns:
        Whatever ``BertForSequenceClassification.from_pretrained`` returns for
        ``model_path``.
    """
    return BertForSequenceClassification.from_pretrained(model_path)