# hr-eval-api-v2 / scripts / train_sbert.py
# KarenYYH
# Initial commit - HR Evaluation API v2
# c8b1f17
"""
Sentence-BERT微调训练脚本
针对HR对话质量评估优化语义相似度模型
"""
import json
import math
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from sentence_transformers import (
    InputExample,
    SentenceTransformer,
    datasets,
    losses,
    models,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from torch.utils.data import DataLoader
from tqdm import tqdm
class SBERTTrainer:
    """Fine-tune a SentenceTransformer model on labelled sentence pairs.

    Training and validation data are JSONL files. Each non-blank line is a
    JSON object with the keys ``anchor`` and ``positive`` (the two texts) and
    an optional ``label`` (defaults to 1 when absent).
    """

    def __init__(
        self,
        model_name: str,
        train_data_path: str,
        output_dir: str,
        val_data_path: Optional[str] = None,
    ):
        """Load the pretrained model and prepare the output directory.

        Args:
            model_name: Name or path handed to ``SentenceTransformer``.
            train_data_path: Path of the JSONL training file.
            output_dir: Directory for saved models; created if missing.
            val_data_path: Optional path of the JSONL validation file.
        """
        self.model_name = model_name
        self.train_data_path = train_data_path
        self.val_data_path = val_data_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load the pretrained model.
        print(f"加载预训练模型: {model_name}")
        self.model = SentenceTransformer(model_name)

        # Detect the GPU. The device is only recorded and reported here; this
        # class never moves the model onto it explicitly.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用设备: {self.device}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")

    @staticmethod
    def _parse_record(line: str) -> Optional[Tuple[str, str, float]]:
        """Parse one JSONL line into ``(anchor, positive, label)``.

        Returns:
            The parsed triple, or ``None`` when the line is blank.

        Raises:
            ValueError: The line is not valid JSON, is not a JSON object, or
                its label cannot be converted to ``float``.
            KeyError: ``anchor`` or ``positive`` is missing.
            TypeError: The label has a type ``float()`` rejects (e.g. null).
        """
        text = line.strip()
        if not text:
            return None
        record = json.loads(text)
        if not isinstance(record, dict):
            raise ValueError(
                f"expected a JSON object, got {type(record).__name__}"
            )
        return record['anchor'], record['positive'], float(record.get('label', 1))

    @staticmethod
    def _paired_cosine(first, second) -> np.ndarray:
        """Return the cosine similarity of each row of *first* with the same row of *second*.

        Only the N paired similarities are computed, never the N x N matrix.
        A pair containing an all-zero vector gets similarity 0.
        """
        first = np.asarray(first, dtype=float)
        second = np.asarray(second, dtype=float)
        dots = np.einsum('ij,ij->i', first, second)
        norms = np.linalg.norm(first, axis=1) * np.linalg.norm(second, axis=1)
        # Avoid 0/0 for zero vectors; their dot product is already 0.
        norms[norms == 0] = 1.0
        return dots / norms

    def load_data(self, data_path: str) -> "List[InputExample]":
        """Read a JSONL file into a list of ``InputExample`` pairs.

        Blank lines are skipped silently. Lines that fail to parse are
        reported with their line number and skipped, so one bad record does
        not abort the load.

        Args:
            data_path: Path of the UTF-8 encoded JSONL file.

        Returns:
            One ``InputExample`` per successfully parsed line.
        """
        print(f"\n加载数据: {data_path}")
        examples = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(tqdm(f, desc='读取数据'), 1):
                try:
                    record = self._parse_record(line)
                except (ValueError, KeyError, TypeError) as e:
                    print(f"警告: 第{line_num}行解析失败: {e}")
                    continue
                if record is None:
                    continue
                anchor, positive, label = record
                # label=1 marks a similar pair, label=0 a dissimilar pair.
                examples.append(InputExample(texts=[anchor, positive], label=label))
        print(f"加载样本数: {len(examples)}")
        return examples

    def create_dataloader(
        self,
        examples: "List[InputExample]",
        batch_size: int
    ) -> "DataLoader":
        """Wrap *examples* in a shuffling ``DataLoader`` with the given batch size."""
        return DataLoader(
            examples,
            shuffle=True,
            batch_size=batch_size
        )

    def train(
        self,
        num_epochs: int = 5,
        batch_size: int = 16,
        warmup_steps: int = 100,
        learning_rate: float = 2e-5,
        evaluation_steps: int = 100,
        save_best_model: bool = True
    ):
        """Fine-tune the model with ``CosineSimilarityLoss`` and save it.

        Args:
            num_epochs: Number of training epochs.
            batch_size: Training batch size.
            warmup_steps: Number of learning-rate warm-up steps.
            learning_rate: Optimizer learning rate.
            evaluation_steps: Evaluate every this many steps; only used when
                a validation file exists.
            save_best_model: Forwarded to ``fit`` so the best-scoring model
                is kept when an evaluator is present.

        Raises:
            ValueError: The training file (or an existing validation file)
                contains no usable examples.
        """
        print("\n" + "="*50)
        print("开始训练")
        print("="*50)

        # Load the training data; fail early with a clear message if empty.
        train_examples = self.load_data(self.train_data_path)
        if not train_examples:
            raise ValueError(
                f"no usable training examples found in {self.train_data_path}"
            )
        train_dataloader = self.create_dataloader(train_examples, batch_size)

        # Total number of optimizer steps (reported only).
        num_train_steps = len(train_dataloader) * num_epochs
        print(f"\n训练配置:")
        print(f" 训练样本: {len(train_examples)}")
        print(f" 批次大小: {batch_size}")
        print(f" 训练轮数: {num_epochs}")
        print(f" 总步数: {num_train_steps}")
        print(f" 预热步数: {warmup_steps}")
        print(f" 学习率: {learning_rate}")

        # CosineSimilarityLoss supports both positive and negative pairs.
        print("使用CosineSimilarityLoss")
        train_loss = losses.CosineSimilarityLoss(self.model)

        # Build an evaluator only when validation data is available.
        evaluator = None
        if self.val_data_path and os.path.exists(self.val_data_path):
            print(f"\n创建验证器: {self.val_data_path}")
            evaluator = self.create_evaluator()
        else:
            print("\n未提供验证数据,跳过验证")

        # 0 (not None) is the value that disables step-wise evaluation in
        # ``fit``; None would be compared against an int inside the loop.
        effective_eval_steps = (
            evaluation_steps if (evaluator is not None and evaluation_steps) else 0
        )

        print("\n开始训练...")
        self.model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            evaluator=evaluator,
            epochs=num_epochs,
            warmup_steps=warmup_steps,
            optimizer_params={'lr': learning_rate},
            evaluation_steps=effective_eval_steps,
            output_path=str(self.output_dir),
            save_best_model=save_best_model,
        )

        # Save the final (last-step) model next to whatever ``fit`` wrote.
        final_model_path = self.output_dir / 'final_model'
        self.model.save(str(final_model_path))
        print(f"\n✓ 最终模型已保存: {final_model_path}")

    def create_evaluator(self) -> "EmbeddingSimilarityEvaluator":
        """Build an ``EmbeddingSimilarityEvaluator`` from the validation file.

        Raises:
            ValueError: The validation file contains no usable examples.
        """
        val_examples = self.load_data(self.val_data_path)
        if not val_examples:
            raise ValueError(
                f"no usable validation examples found in {self.val_data_path}"
            )

        # Split each pair into the parallel lists the evaluator expects.
        sentences1 = [example.texts[0] for example in val_examples]
        sentences2 = [example.texts[1] for example in val_examples]
        scores = [example.label for example in val_examples]

        return EmbeddingSimilarityEvaluator(
            sentences1=sentences1,
            sentences2=sentences2,
            scores=scores,
            batch_size=16,
            name='hr_eval'
        )

    def evaluate(self, test_data_path: Optional[str] = None) -> Optional[Dict[str, float]]:
        """Score the current model on a labelled pair file.

        Args:
            test_data_path: JSONL file to evaluate on; falls back to the
                validation path given at construction.

        Returns:
            ``{'accuracy': ..., 'spearman_correlation': ...}``, or ``None``
            when no test file is available or it holds no usable examples.
            Accuracy thresholds both the cosine similarity and the gold label
            at 0.5. The Spearman correlation is NaN when either side is
            constant.
        """
        if test_data_path is None:
            test_data_path = self.val_data_path
        if test_data_path is None or not os.path.exists(test_data_path):
            print("\n未提供测试数据,跳过评估")
            return None

        print(f"\n评估模型: {test_data_path}")
        test_examples = self.load_data(test_data_path)
        if not test_examples:
            # Nothing to score; the metrics would be undefined.
            print("\n未提供测试数据,跳过评估")
            return None

        sentences1 = [ex.texts[0] for ex in test_examples]
        sentences2 = [ex.texts[1] for ex in test_examples]
        labels = np.asarray([ex.label for ex in test_examples], dtype=float)

        print("计算embeddings...")
        embeddings1 = self.model.encode(sentences1, convert_to_numpy=True)
        embeddings2 = self.model.encode(sentences2, convert_to_numpy=True)

        # Similarity of each pair only (row i with row i).
        predicted_scores = self._paired_cosine(embeddings1, embeddings2)

        # Threshold predictions and gold labels identically (0.5).
        predicted_labels = predicted_scores >= 0.5
        gold_labels = labels >= 0.5
        accuracy = float(np.mean(predicted_labels == gold_labels))

        # Spearman rank correlation between similarities and raw labels.
        from scipy.stats import spearmanr
        correlation, _ = spearmanr(predicted_scores, labels)
        correlation = float(correlation)

        print(f"\n评估结果:")
        print(f" 准确率: {accuracy:.4f}")
        print(f" Spearman相关: {correlation:.4f}")
        print(f" 平均相似度: {np.mean(predicted_scores):.4f}")
        return {
            'accuracy': accuracy,
            'spearman_correlation': correlation
        }
def main(argv=None):
    """Command-line entry point: parse options, train, then evaluate.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``; ``None``
            (the default) keeps the normal command-line behaviour.
    """
    import argparse
    import random
    import numpy as np

    parser = argparse.ArgumentParser(description='训练Sentence-BERT模型')
    parser.add_argument(
        '--model_name',
        type=str,
        default='distiluse-base-multilingual-cased-v1',
        help='预训练模型名称'
    )
    parser.add_argument(
        '--train_data',
        type=str,
        required=True,
        help='训练数据路径 (JSONL格式)'
    )
    parser.add_argument(
        '--val_data',
        type=str,
        default=None,
        help='验证数据路径 (可选)'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='models/sbert-hr',
        help='输出目录'
    )
    parser.add_argument(
        '--num_epochs',
        type=int,
        default=5,
        help='训练轮数'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=16,
        help='批次大小'
    )
    parser.add_argument(
        '--warmup_steps',
        type=int,
        default=100,
        help='预热步数'
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=2e-5,
        help='学习率'
    )
    parser.add_argument(
        '--evaluation_steps',
        type=int,
        default=100,
        help='评估频率'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='随机种子'
    )
    args = parser.parse_args(argv)

    # Seed every RNG the run touches so results are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Build the trainer (loads the pretrained model).
    trainer = SBERTTrainer(
        model_name=args.model_name,
        train_data_path=args.train_data,
        output_dir=args.output_dir,
        val_data_path=args.val_data
    )

    # Train.
    trainer.train(
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        warmup_steps=args.warmup_steps,
        learning_rate=args.learning_rate,
        evaluation_steps=args.evaluation_steps
    )

    # Evaluate on the validation file when one was supplied.
    if args.val_data:
        trainer.evaluate()

    print("\n✓ 训练完成!")
    print(f"\n模型已保存到: {args.output_dir}")
    print("\n使用方法:")
    print(" from sentence_transformers import SentenceTransformer")
    print(f" model = SentenceTransformer('{args.output_dir}/final_model')")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()