""" Sentence-BERT训练数据准备脚本 从QA数据集构建语义相似度训练数据 """ import json import random from pathlib import Path from typing import List, Dict, Tuple from datasets import load_from_disk import numpy as np class SBERTDataPreparator: """SBERT训练数据准备器""" def __init__(self, qa_dataset_path: str, output_dir: str): """ Args: qa_dataset_path: QA数据集路径 output_dir: 输出目录 """ self.qa_dataset_path = qa_dataset_path self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) # 加载QA数据集 print(f"加载QA数据集: {qa_dataset_path}") self.qa_dataset = load_from_disk(qa_dataset_path) print(f"数据集大小: {len(self.qa_dataset['train'])}") def prepare_training_data( self, num_negatives: int = 5, hard_negative_ratio: float = 0.3, train_ratio: float = 0.7, val_ratio: float = 0.15 ) -> Tuple[List[Dict], List[Dict], List[Dict]]: """ 准备训练数据 Args: num_negatives: 每个正样本的负样本数量 hard_negative_ratio: 困难负样本的比例 train_ratio: 训练集比例 val_ratio: 验证集比例 Returns: (train_data, val_data, test_data) """ print("\n准备训练数据...") # 转换为列表便于处理 qa_list = list(self.qa_dataset['train']) # 构建正样本对 positive_pairs = self._create_positive_pairs(qa_list) # 构建负样本 all_pairs = self._add_negatives( positive_pairs, qa_list, num_negatives=num_negatives, hard_negative_ratio=hard_negative_ratio ) # 打乱数据 random.shuffle(all_pairs) # 划分数据集 total = len(all_pairs) train_end = int(total * train_ratio) val_end = int(total * (train_ratio + val_ratio)) train_data = all_pairs[:train_end] val_data = all_pairs[train_end:val_end] test_data = all_pairs[val_end:] print(f"\n数据集划分:") print(f" 训练集: {len(train_data)} 样本") print(f" 验证集: {len(val_data)} 样本") print(f" 测试集: {len(test_data)} 样本") return train_data, val_data, test_data def _create_positive_pairs(self, qa_list: List[Dict]) -> List[Dict]: """创建正样本对 (question, answer_context)""" positive_pairs = [] for qa in qa_list: question = qa.get('question', '').strip() answer_context = qa.get('answer_context', '').strip() if question and answer_context: positive_pairs.append({ 'anchor': question, 'positive': answer_context, 'label': 1 # 相似 }) print(f"创建正样本对: {len(positive_pairs)}") return positive_pairs def _add_negatives( self, positive_pairs: List[Dict], qa_list: List[Dict], num_negatives: int = 5, hard_negative_ratio: float = 0.3 ) -> List[Dict]: """ 添加负样本 Args: positive_pairs: 正样本对 qa_list: 所有QA数据 num_negatives: 负样本数量 hard_negative_ratio: 困难负样本比例 """ print(f"\n添加负样本 (每样本 {num_negatives} 个负样本)...") all_answers = [qa.get('answer_context', '').strip() for qa in qa_list] all_answers = [a for a in all_answers if a] extended_pairs = [] for pair in positive_pairs: anchor = pair['anchor'] positive = pair['positive'] # 添加原始正样本 extended_pairs.append(pair) # 生成负样本 hard_neg_count = int(num_negatives * hard_negative_ratio) random_neg_count = num_negatives - hard_neg_count # 困难负样本: 同领域但不同的答案 hard_negatives = self._sample_hard_negatives( anchor, all_answers, n=hard_neg_count, exclude=positive ) # 随机负样本 random_negatives = self._sample_random_negatives( all_answers, n=random_neg_count, exclude=positive ) # 添加负样本对 for neg in hard_negatives + random_negatives: extended_pairs.append({ 'anchor': anchor, 'positive': neg, # 在SBERT训练中作为负样本 'label': 0 # 不相似 }) print(f"总样本数: {len(extended_pairs)}") return extended_pairs def _sample_hard_negatives( self, anchor: str, all_answers: List[str], n: int, exclude: str ) -> List[str]: """采样困难负样本(简单实现:随机采样后可改进)""" candidates = [a for a in all_answers if a != exclude] if len(candidates) <= n: return candidates return random.sample(candidates, n) def _sample_random_negatives( self, all_answers: List[str], n: int, 
exclude: str ) -> List[str]: """采样随机负样本""" candidates = [a for a in all_answers if a != exclude] if len(candidates) <= n: return candidates return random.sample(candidates, n) def save_data( self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict], format: str = 'jsonl' ): """保存数据到文件""" print(f"\n保存数据到 {self.output_dir}...") if format == 'jsonl': # JSONL格式 (适合sentence-transformers) self._save_jsonl(train_data, 'train.jsonl') self._save_jsonl(val_data, 'val.jsonl') self._save_jsonl(test_data, 'test.jsonl') elif format == 'csv': # CSV格式 import pandas as pd pd.DataFrame(train_data).to_csv( self.output_dir / 'train.csv', index=False ) pd.DataFrame(val_data).to_csv( self.output_dir / 'val.csv', index=False ) pd.DataFrame(test_data).to_csv( self.output_dir / 'test.csv', index=False ) print("✓ 数据保存完成") def _save_jsonl(self, data: List[Dict], filename: str): """保存为JSONL格式""" filepath = self.output_dir / filename with open(filepath, 'w', encoding='utf-8') as f: for item in data: f.write(json.dumps(item, ensure_ascii=False) + '\n') print(f" 保存: {filepath} ({len(data)} 样本)") def print_statistics(self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]): """打印数据统计信息""" print("\n=== 数据统计 ===") # 正负样本比例 for name, data in [('训练集', train_data), ('验证集', val_data), ('测试集', test_data)]: pos_count = sum(1 for item in data if item.get('label') == 1) neg_count = len(data) - pos_count print(f"\n{name}:") print(f" 总样本: {len(data)}") print(f" 正样本: {pos_count} ({pos_count/len(data)*100:.1f}%)") print(f" 负样本: {neg_count} ({neg_count/len(data)*100:.1f}%)") # 文本长度统计 all_anchors = [item['anchor'] for item in train_data] all_positives = [item['positive'] for item in train_data] anchor_lengths = [len(a.split()) for a in all_anchors] positive_lengths = [len(p.split()) for p in all_positives] print(f"\n文本长度统计 (训练集):") print(f" Anchor: 平均 {np.mean(anchor_lengths):.1f} 词, " f"最大 {max(anchor_lengths)}, 最小 {min(anchor_lengths)}") print(f" Positive: 平均 {np.mean(positive_lengths):.1f} 词, " f"最大 {max(positive_lengths)}, 最小 {min(positive_lengths)}") def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='准备SBERT训练数据') parser.add_argument( '--qa_dataset', type=str, default='hr-multiwoz-dataset/qa_dataset', help='QA数据集路径' ) parser.add_argument( '--output_dir', type=str, default='data/processed/sbert', help='输出目录' ) parser.add_argument( '--num_negatives', type=int, default=5, help='每个正样本的负样本数量' ) parser.add_argument( '--hard_negative_ratio', type=float, default=0.3, help='困难负样本比例' ) parser.add_argument( '--format', type=str, default='jsonl', choices=['jsonl', 'csv'], help='输出格式' ) args = parser.parse_args() # 设置随机种子 random.seed(42) np.random.seed(42) # 创建数据准备器 preparator = SBERTDataPreparator( qa_dataset_path=args.qa_dataset, output_dir=args.output_dir ) # 准备数据 train_data, val_data, test_data = preparator.prepare_training_data( num_negatives=args.num_negatives, hard_negative_ratio=args.hard_negative_ratio ) # 保存数据 preparator.save_data(train_data, val_data, test_data, format=args.format) # 打印统计 preparator.print_statistics(train_data, val_data, test_data) print("\n✓ 数据准备完成!") print(f"\n输出目录: {args.output_dir}") print(f"下一步: 运行训练脚本") print(f" python scripts/train_sbert.py --train_data {args.output_dir}/train.jsonl") if __name__ == '__main__': main()
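

# --- Usage sketch (assumption, not part of this script's execution path) ---
# A minimal, hedged example of how the JSONL written by save_data() could be
# consumed for training with the sentence-transformers 2.x-style fit() API.
# The model name, loss choice, and hyperparameters below are illustrative
# assumptions, not project defaults; the actual training entry point referenced
# above is scripts/train_sbert.py. Kept commented out so importing or running
# this file never triggers a training run.
#
# import json
# from torch.utils.data import DataLoader
# from sentence_transformers import SentenceTransformer, InputExample, losses
#
# def load_jsonl_examples(path):
#     """Read the anchor/positive/label triples written by SBERTDataPreparator."""
#     examples = []
#     with open(path, encoding='utf-8') as f:
#         for line in f:
#             item = json.loads(line)
#             examples.append(InputExample(
#                 texts=[item['anchor'], item['positive']],
#                 label=float(item['label'])  # 1 = similar, 0 = dissimilar
#             ))
#     return examples
#
# model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')  # illustrative base model
# train_loader = DataLoader(
#     load_jsonl_examples('data/processed/sbert/train.jsonl'),
#     shuffle=True, batch_size=32
# )
# loss = losses.ContrastiveLoss(model)  # suits binary similar/dissimilar labels
# model.fit(train_objectives=[(train_loader, loss)], epochs=1, warmup_steps=100)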