# hr-eval-api-v2 / scripts/prepare_sbert_data.py
# KarenYYH
# Initial commit - HR Evaluation API v2
# c8b1f17
"""
Sentence-BERT训练数据准备脚本
从QA数据集构建语义相似度训练数据
"""
import json
import random
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_from_disk
import numpy as np
class SBERTDataPreparator:
    """Prepare Sentence-BERT training data from a QA dataset.

    Each QA record with a non-empty question and answer context yields one
    positive pair (label 1). Negative pairs (label 0) pair the same question
    with answer contexts taken from other records.
    """

    def __init__(self, qa_dataset_path: str, output_dir: str):
        """
        Args:
            qa_dataset_path: Path passed to ``datasets.load_from_disk``; the
                loaded object must expose a ``'train'`` split.
            output_dir: Output directory; created (with parents) if missing.
        """
        self.qa_dataset_path = qa_dataset_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Load the QA dataset.
        print(f"加载QA数据集: {qa_dataset_path}")
        self.qa_dataset = load_from_disk(qa_dataset_path)
        print(f"数据集大小: {len(self.qa_dataset['train'])}")

    def prepare_training_data(
        self,
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Build positive/negative pairs and split them into train/val/test.

        The split is made per anchor (question text). Every pair sharing an
        anchor lands in the same split, so a question used for training never
        reappears in validation or test. The ratios therefore apply to the
        number of distinct questions, not to the number of pairs.

        Args:
            num_negatives: Negative pairs generated per positive pair.
            hard_negative_ratio: Fraction of negatives drawn by the "hard" sampler.
            train_ratio: Fraction of distinct anchors assigned to the train split.
            val_ratio: Fraction of distinct anchors assigned to the validation
                split; the remainder goes to test.

        Returns:
            (train_data, val_data, test_data). Each is a shuffled list of
            ``{'anchor', 'positive', 'label'}`` dicts.

        Raises:
            ValueError: If a ratio is negative or the two ratios sum to more than 1.
        """
        # Small tolerance so that e.g. 0.7 + 0.3 is not rejected over rounding.
        if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
            raise ValueError(
                f"invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}"
            )
        print("\n准备训练数据...")
        # Materialize the split as a list for repeated passes.
        qa_list = list(self.qa_dataset['train'])
        # Positive pairs.
        positive_pairs = self._create_positive_pairs(qa_list)
        # Negative pairs (returned interleaved with the positives).
        all_pairs = self._add_negatives(
            positive_pairs,
            qa_list,
            num_negatives=num_negatives,
            hard_negative_ratio=hard_negative_ratio
        )

        # Group by anchor. Splitting individual pairs would scatter a question's
        # positive and negatives across splits and leak test questions into training.
        groups: Dict[str, List[Dict]] = {}
        for pair in all_pairs:
            groups.setdefault(pair['anchor'], []).append(pair)
        # Dict insertion order is deterministic, so the result is reproducible
        # under a fixed random seed.
        anchors = list(groups)
        random.shuffle(anchors)

        total = len(anchors)
        train_end = int(total * train_ratio)
        val_end = int(total * (train_ratio + val_ratio))

        train_data = self._flatten_groups(groups, anchors[:train_end])
        val_data = self._flatten_groups(groups, anchors[train_end:val_end])
        test_data = self._flatten_groups(groups, anchors[val_end:])

        print("\n数据集划分:")
        print(f" 训练集: {len(train_data)} 样本")
        print(f" 验证集: {len(val_data)} 样本")
        print(f" 测试集: {len(test_data)} 样本")
        return train_data, val_data, test_data

    @staticmethod
    def _flatten_groups(groups: Dict[str, List[Dict]], anchors: List[str]) -> List[Dict]:
        """Concatenate the pairs belonging to ``anchors`` and shuffle them."""
        items = [pair for anchor in anchors for pair in groups[anchor]]
        random.shuffle(items)
        return items

    def _create_positive_pairs(self, qa_list: List[Dict]) -> List[Dict]:
        """Create positive pairs (question, answer_context) labelled 1.

        Records whose question or answer context is missing, ``None`` or blank
        are skipped.
        """
        positive_pairs = []
        for qa in qa_list:
            # `or ''` also covers fields that are present but set to None.
            question = (qa.get('question') or '').strip()
            answer_context = (qa.get('answer_context') or '').strip()
            if question and answer_context:
                positive_pairs.append({
                    'anchor': question,
                    'positive': answer_context,
                    'label': 1  # similar
                })
        print(f"创建正样本对: {len(positive_pairs)}")
        return positive_pairs

    def _add_negatives(
        self,
        positive_pairs: List[Dict],
        qa_list: List[Dict],
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3
    ) -> List[Dict]:
        """Append negative pairs after each positive pair.

        For each positive pair, up to ``num_negatives`` distinct negatives are
        drawn from the deduplicated pool of answer contexts. A negative is never
        an answer that is a positive for the same anchor (repeated questions
        with different answers would otherwise get contradictory labels). Hard
        and random negatives never overlap. Fewer negatives are produced when
        the pool is too small.

        Args:
            positive_pairs: Positive pairs from ``_create_positive_pairs``.
            qa_list: All QA records; their answer contexts form the negative pool.
            num_negatives: Negatives per positive pair.
            hard_negative_ratio: Fraction of negatives drawn by the hard sampler.

        Returns:
            A list in which each positive pair is followed by its negatives.
            Negative pairs store the negative text under the ``'positive'`` key
            with ``label`` 0.

        Raises:
            ValueError: If ``num_negatives`` < 0 or the ratio is outside [0, 1].
        """
        if num_negatives < 0:
            raise ValueError(f"num_negatives must be >= 0, got {num_negatives}")
        if not 0.0 <= hard_negative_ratio <= 1.0:
            raise ValueError(
                f"hard_negative_ratio must be within [0, 1], got {hard_negative_ratio}"
            )
        print(f"\n添加负样本 (每样本 {num_negatives} 个负样本)...")

        # Deduplicate once (order-preserving): repeated answer texts would
        # otherwise show up as duplicate negatives for a single anchor.
        all_answers = list(dict.fromkeys(
            (qa.get('answer_context') or '').strip() for qa in qa_list
        ))
        all_answers = [a for a in all_answers if a]

        # Every answer that is a positive for an anchor is blocked as a negative
        # for that anchor.
        positives_by_anchor: Dict[str, set] = {}
        for pair in positive_pairs:
            positives_by_anchor.setdefault(pair['anchor'], set()).add(pair['positive'])

        hard_neg_count = int(num_negatives * hard_negative_ratio)
        random_neg_count = num_negatives - hard_neg_count

        extended_pairs = []
        for pair in positive_pairs:
            anchor = pair['anchor']
            positive = pair['positive']
            # Keep the original positive pair.
            extended_pairs.append(pair)

            blocked = positives_by_anchor[anchor]
            # The samplers exclude only `positive`. Over-draw by the number of
            # other blocked texts that could still be returned, then filter.
            # The pool is deduplicated, so enough valid items remain whenever
            # the pool allows it.
            hard_negatives: List[str] = []
            if hard_neg_count > 0:
                drawn = self._sample_hard_negatives(
                    anchor,
                    all_answers,
                    n=hard_neg_count + len(blocked) - 1,
                    exclude=positive
                )
                hard_negatives = [a for a in drawn if a not in blocked][:hard_neg_count]

            # Random negatives must also avoid the hard negatives just chosen.
            taken = blocked | set(hard_negatives)
            random_negatives: List[str] = []
            if random_neg_count > 0:
                drawn = self._sample_random_negatives(
                    all_answers,
                    n=random_neg_count + len(taken) - 1,
                    exclude=positive
                )
                random_negatives = [a for a in drawn if a not in taken][:random_neg_count]

            for neg in hard_negatives + random_negatives:
                extended_pairs.append({
                    'anchor': anchor,
                    'positive': neg,  # acts as the negative text in SBERT training
                    'label': 0  # dissimilar
                })
        print(f"总样本数: {len(extended_pairs)}")
        return extended_pairs

    def _sample_hard_negatives(
        self,
        anchor: str,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """Sample up to ``n`` "hard" negatives, never returning ``exclude``.

        Placeholder implementation: this is currently uniform random sampling,
        identical to ``_sample_random_negatives``. ``anchor`` is accepted for a
        future similarity-based strategy but is not used yet.
        """
        if n <= 0:
            return []
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def _sample_random_negatives(
        self,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """Sample up to ``n`` answers uniformly at random, never returning ``exclude``."""
        if n <= 0:
            return []
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def save_data(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        format: str = 'jsonl'
    ):
        """Write the three splits to ``output_dir`` as ``{train,val,test}.{format}``.

        Args:
            format: ``'jsonl'`` (one JSON object per line) or ``'csv'`` (via pandas).

        Raises:
            ValueError: If ``format`` is not supported. Previously such a call
                silently wrote nothing while still reporting success.
        """
        if format not in ('jsonl', 'csv'):
            raise ValueError(f"unsupported format {format!r}; expected 'jsonl' or 'csv'")
        print(f"\n保存数据到 {self.output_dir}...")
        splits = (('train', train_data), ('val', val_data), ('test', test_data))
        if format == 'jsonl':
            # JSONL (suitable for sentence-transformers).
            for stem, data in splits:
                self._save_jsonl(data, f'{stem}.jsonl')
        else:
            # CSV; pandas is only needed for this branch, so import lazily.
            import pandas as pd
            for stem, data in splits:
                pd.DataFrame(data).to_csv(self.output_dir / f'{stem}.csv', index=False)
        print("✓ 数据保存完成")

    def _save_jsonl(self, data: List[Dict], filename: str):
        """Write ``data`` to ``output_dir / filename``, one UTF-8 JSON object per line."""
        filepath = self.output_dir / filename
        with open(filepath, 'w', encoding='utf-8') as f:
            for item in data:
                # ensure_ascii=False keeps non-ASCII text readable in the file.
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        print(f" 保存: {filepath} ({len(data)} 样本)")

    def print_statistics(self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]):
        """Print label balance per split and whitespace-token length stats for train.

        Empty splits are reported with a zero count instead of raising. Length
        statistics are skipped when the training split is empty.
        """
        print("\n=== 数据统计 ===")
        # Positive/negative ratio per split.
        for name, data in [('训练集', train_data), ('验证集', val_data), ('测试集', test_data)]:
            total = len(data)
            pos_count = sum(1 for item in data if item.get('label') == 1)
            neg_count = total - pos_count
            print(f"\n{name}:")
            print(f" 总样本: {total}")
            if total == 0:
                # Nothing to take a percentage of.
                continue
            print(f" 正样本: {pos_count} ({pos_count/total*100:.1f}%)")
            print(f" 负样本: {neg_count} ({neg_count/total*100:.1f}%)")
        if not train_data:
            return
        # Text length statistics (whitespace-separated tokens).
        anchor_lengths = [len(item['anchor'].split()) for item in train_data]
        positive_lengths = [len(item['positive'].split()) for item in train_data]
        print(f"\n文本长度统计 (训练集):")
        print(f" Anchor: 平均 {np.mean(anchor_lengths):.1f} 词, "
              f"最大 {max(anchor_lengths)}, 最小 {min(anchor_lengths)}")
        print(f" Positive: 平均 {np.mean(positive_lengths):.1f} 词, "
              f"最大 {max(positive_lengths)}, 最小 {min(positive_lengths)}")
def main():
    """Command-line entry point: build SBERT train/val/test files from a QA dataset.

    Parses CLI options, seeds ``random`` and ``numpy``, then runs the
    prepare -> save -> statistics pipeline of ``SBERTDataPreparator``.
    """
    import argparse
    parser = argparse.ArgumentParser(description='准备SBERT训练数据')
    parser.add_argument(
        '--qa_dataset',
        type=str,
        default='hr-multiwoz-dataset/qa_dataset',
        help='QA数据集路径'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='data/processed/sbert',
        help='输出目录'
    )
    parser.add_argument(
        '--num_negatives',
        type=int,
        default=5,
        help='每个正样本的负样本数量'
    )
    parser.add_argument(
        '--hard_negative_ratio',
        type=float,
        default=0.3,
        help='困难负样本比例'
    )
    parser.add_argument(
        '--format',
        type=str,
        default='jsonl',
        choices=['jsonl', 'csv'],
        help='输出格式'
    )
    # Previously hard-coded values, now configurable; defaults keep old behaviour.
    parser.add_argument(
        '--train_ratio',
        type=float,
        default=0.7,
        help='训练集比例'
    )
    parser.add_argument(
        '--val_ratio',
        type=float,
        default=0.15,
        help='验证集比例'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='随机种子'
    )
    args = parser.parse_args()
    # Seed both RNGs for reproducible sampling and splitting.
    random.seed(args.seed)
    np.random.seed(args.seed)
    # Create the data preparator (loads the dataset).
    preparator = SBERTDataPreparator(
        qa_dataset_path=args.qa_dataset,
        output_dir=args.output_dir
    )
    # Build and split the pairs.
    train_data, val_data, test_data = preparator.prepare_training_data(
        num_negatives=args.num_negatives,
        hard_negative_ratio=args.hard_negative_ratio,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio
    )
    # Save the splits.
    preparator.save_data(train_data, val_data, test_data, format=args.format)
    # Print statistics.
    preparator.print_statistics(train_data, val_data, test_data)
    print("\n✓ 数据准备完成!")
    print(f"\n输出目录: {args.output_dir}")
    print("下一步: 运行训练脚本")
    # Point at the file that was actually written (train.csv when --format csv).
    print(f" python scripts/train_sbert.py --train_data {args.output_dir}/train.{args.format}")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()