Spaces:
Sleeping
Sleeping
| """ | |
| Sentence-BERT训练数据准备脚本 | |
| 从QA数据集构建语义相似度训练数据 | |
| """ | |
| import json | |
| import random | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple | |
| from datasets import load_from_disk | |
| import numpy as np | |
class SBERTDataPreparator:
    """Prepare semantic-similarity training data for Sentence-BERT.

    Builds (anchor, positive/negative, label) pairs from QA records that carry
    'question' and 'answer_context' fields, splits them into train/val/test
    sets, and writes them to disk as JSONL or CSV.
    """

    def __init__(self, qa_dataset_path: str, output_dir: str):
        """
        Args:
            qa_dataset_path: Path to the QA dataset (an on-disk `datasets`
                dataset with a 'train' split).
            output_dir: Output directory; created if missing.
        """
        self.qa_dataset_path = qa_dataset_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Load eagerly so a bad path fails at construction time.
        print(f"加载QA数据集: {qa_dataset_path}")
        self.qa_dataset = load_from_disk(qa_dataset_path)
        print(f"数据集大小: {len(self.qa_dataset['train'])}")

    def prepare_training_data(
        self,
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """Build labeled pairs and split them into train/val/test.

        Args:
            num_negatives: Number of negative samples per positive pair.
            hard_negative_ratio: Fraction of negatives drawn as "hard"
                negatives (see _sample_hard_negatives).
            train_ratio: Fraction of all pairs used for training.
            val_ratio: Fraction used for validation; the remainder is the
                test set.

        Returns:
            (train_data, val_data, test_data), each a list of
            {'anchor', 'positive', 'label'} dicts.
        """
        print("\n准备训练数据...")
        # Materialize the split so it can be iterated more than once.
        qa_list = list(self.qa_dataset['train'])
        positive_pairs = self._create_positive_pairs(qa_list)
        all_pairs = self._add_negatives(
            positive_pairs,
            qa_list,
            num_negatives=num_negatives,
            hard_negative_ratio=hard_negative_ratio
        )
        # Shuffle before the contiguous slicing below so each split is i.i.d.
        random.shuffle(all_pairs)
        total = len(all_pairs)
        train_end = int(total * train_ratio)
        val_end = int(total * (train_ratio + val_ratio))
        train_data = all_pairs[:train_end]
        val_data = all_pairs[train_end:val_end]
        test_data = all_pairs[val_end:]
        print(f"\n数据集划分:")
        print(f" 训练集: {len(train_data)} 样本")
        print(f" 验证集: {len(val_data)} 样本")
        print(f" 测试集: {len(test_data)} 样本")
        return train_data, val_data, test_data

    def _create_positive_pairs(self, qa_list: List[Dict]) -> List[Dict]:
        """Create positive (question, answer_context) pairs.

        Rows where either field is missing or blank after stripping are
        skipped.
        """
        positive_pairs = []
        for qa in qa_list:
            question = qa.get('question', '').strip()
            answer_context = qa.get('answer_context', '').strip()
            if question and answer_context:
                positive_pairs.append({
                    'anchor': question,
                    'positive': answer_context,
                    'label': 1  # similar
                })
        print(f"创建正样本对: {len(positive_pairs)}")
        return positive_pairs

    def _add_negatives(
        self,
        positive_pairs: List[Dict],
        qa_list: List[Dict],
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3
    ) -> List[Dict]:
        """Interleave each positive pair with sampled negative pairs.

        Args:
            positive_pairs: Output of _create_positive_pairs.
            qa_list: All QA rows (negatives are drawn from their answers).
            num_negatives: Negatives generated per positive pair.
            hard_negative_ratio: Fraction of those that are "hard" negatives.

        Returns:
            List containing every positive pair plus its label-0 negatives.
        """
        print(f"\n添加负样本 (每样本 {num_negatives} 个负样本)...")
        all_answers = [qa.get('answer_context', '').strip() for qa in qa_list]
        all_answers = [a for a in all_answers if a]
        # Loop-invariant split of the negative budget — hoisted out of the loop.
        hard_neg_count = int(num_negatives * hard_negative_ratio)
        random_neg_count = num_negatives - hard_neg_count
        extended_pairs = []
        for pair in positive_pairs:
            anchor = pair['anchor']
            positive = pair['positive']
            # Keep the original positive pair.
            extended_pairs.append(pair)
            # Hard negatives: intended to be same-domain but different answers.
            hard_negatives = self._sample_hard_negatives(
                anchor,
                all_answers,
                n=hard_neg_count,
                exclude=positive
            )
            # Exclude already-chosen hard negatives so a single anchor never
            # receives the same negative twice (independent sampling from the
            # same pool could otherwise produce duplicates).
            chosen = set(hard_negatives)
            remaining = [a for a in all_answers if a not in chosen]
            random_negatives = self._sample_random_negatives(
                remaining,
                n=random_neg_count,
                exclude=positive
            )
            for neg in hard_negatives + random_negatives:
                extended_pairs.append({
                    'anchor': anchor,
                    'positive': neg,  # stored under 'positive' but labeled 0
                    'label': 0  # dissimilar
                })
        print(f"总样本数: {len(extended_pairs)}")
        return extended_pairs

    def _sample_hard_negatives(
        self,
        anchor: str,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """Sample "hard" negatives for *anchor*.

        NOTE(review): currently a plain random sample — *anchor* is unused.
        A real implementation would rank candidates by similarity to the
        anchor; the parameter is kept so callers need not change when that
        lands.
        """
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def _sample_random_negatives(
        self,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """Sample up to *n* random negatives, never returning *exclude*."""
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def save_data(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        format: str = 'jsonl'
    ):
        """Write the three splits to self.output_dir.

        Args:
            format: 'jsonl' (sentence-transformers friendly) or 'csv'.

        Raises:
            ValueError: If *format* is not a supported value (the original
                silently saved nothing while still reporting success).
        """
        print(f"\n保存数据到 {self.output_dir}...")
        if format == 'jsonl':
            self._save_jsonl(train_data, 'train.jsonl')
            self._save_jsonl(val_data, 'val.jsonl')
            self._save_jsonl(test_data, 'test.jsonl')
        elif format == 'csv':
            # Local import: pandas is only needed for the CSV path.
            import pandas as pd
            pd.DataFrame(train_data).to_csv(
                self.output_dir / 'train.csv', index=False
            )
            pd.DataFrame(val_data).to_csv(
                self.output_dir / 'val.csv', index=False
            )
            pd.DataFrame(test_data).to_csv(
                self.output_dir / 'test.csv', index=False
            )
        else:
            raise ValueError(f"Unsupported format: {format!r}")
        print("✓ 数据保存完成")

    def _save_jsonl(self, data: List[Dict], filename: str):
        """Write one JSON object per line (UTF-8, non-ASCII kept readable)."""
        filepath = self.output_dir / filename
        with open(filepath, 'w', encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        print(f" 保存: {filepath} ({len(data)} 样本)")

    def print_statistics(self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]):
        """Print label balance per split and word-length stats for the train set."""
        print("\n=== 数据统计 ===")
        for name, data in [('训练集', train_data), ('验证集', val_data), ('测试集', test_data)]:
            pos_count = sum(1 for item in data if item.get('label') == 1)
            neg_count = len(data) - pos_count
            print(f"\n{name}:")
            print(f" 总样本: {len(data)}")
            # Guard: an empty split would divide by zero.
            if data:
                print(f" 正样本: {pos_count} ({pos_count/len(data)*100:.1f}%)")
                print(f" 负样本: {neg_count} ({neg_count/len(data)*100:.1f}%)")
        # Guard: max()/min() raise on empty sequences.
        if not train_data:
            return
        anchor_lengths = [len(item['anchor'].split()) for item in train_data]
        positive_lengths = [len(item['positive'].split()) for item in train_data]
        print(f"\n文本长度统计 (训练集):")
        print(f" Anchor: 平均 {np.mean(anchor_lengths):.1f} 词, "
              f"最大 {max(anchor_lengths)}, 最小 {min(anchor_lengths)}")
        print(f" Positive: 平均 {np.mean(positive_lengths):.1f} 词, "
              f"最大 {max(positive_lengths)}, 最小 {min(positive_lengths)}")
def main():
    """CLI entry point: parse arguments, build the splits, save, report."""
    import argparse

    parser = argparse.ArgumentParser(description='准备SBERT训练数据')
    parser.add_argument('--qa_dataset', type=str,
                        default='hr-multiwoz-dataset/qa_dataset',
                        help='QA数据集路径')
    parser.add_argument('--output_dir', type=str,
                        default='data/processed/sbert',
                        help='输出目录')
    parser.add_argument('--num_negatives', type=int, default=5,
                        help='每个正样本的负样本数量')
    parser.add_argument('--hard_negative_ratio', type=float, default=0.3,
                        help='困难负样本比例')
    parser.add_argument('--format', type=str, default='jsonl',
                        choices=['jsonl', 'csv'],
                        help='输出格式')
    args = parser.parse_args()

    # Fixed seeds keep shuffling and negative sampling reproducible.
    random.seed(42)
    np.random.seed(42)

    preparator = SBERTDataPreparator(
        qa_dataset_path=args.qa_dataset,
        output_dir=args.output_dir
    )

    splits = preparator.prepare_training_data(
        num_negatives=args.num_negatives,
        hard_negative_ratio=args.hard_negative_ratio
    )
    train_data, val_data, test_data = splits

    preparator.save_data(train_data, val_data, test_data, format=args.format)
    preparator.print_statistics(train_data, val_data, test_data)

    print("\n✓ 数据准备完成!")
    print(f"\n输出目录: {args.output_dir}")
    print(f"下一步: 运行训练脚本")
    print(f" python scripts/train_sbert.py --train_data {args.output_dir}/train.jsonl")


if __name__ == '__main__':
    main()