File size: 10,135 Bytes
c8b1f17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Sentence-BERT训练数据准备脚本
从QA数据集构建语义相似度训练数据
"""
import json
import random
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_from_disk
import numpy as np


class SBERTDataPreparator:
    """SBERT训练数据准备器"""

    def __init__(self, qa_dataset_path: str, output_dir: str):
        """
        Args:
            qa_dataset_path: QA数据集路径
            output_dir: 输出目录
        """
        self.qa_dataset_path = qa_dataset_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # 加载QA数据集
        print(f"加载QA数据集: {qa_dataset_path}")
        self.qa_dataset = load_from_disk(qa_dataset_path)
        print(f"数据集大小: {len(self.qa_dataset['train'])}")

    def prepare_training_data(
        self,
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3,
        train_ratio: float = 0.7,
        val_ratio: float = 0.15
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """
        准备训练数据

        Args:
            num_negatives: 每个正样本的负样本数量
            hard_negative_ratio: 困难负样本的比例
            train_ratio: 训练集比例
            val_ratio: 验证集比例

        Returns:
            (train_data, val_data, test_data)
        """
        print("\n准备训练数据...")

        # 转换为列表便于处理
        qa_list = list(self.qa_dataset['train'])

        # 构建正样本对
        positive_pairs = self._create_positive_pairs(qa_list)

        # 构建负样本
        all_pairs = self._add_negatives(
            positive_pairs,
            qa_list,
            num_negatives=num_negatives,
            hard_negative_ratio=hard_negative_ratio
        )

        # 打乱数据
        random.shuffle(all_pairs)

        # 划分数据集
        total = len(all_pairs)
        train_end = int(total * train_ratio)
        val_end = int(total * (train_ratio + val_ratio))

        train_data = all_pairs[:train_end]
        val_data = all_pairs[train_end:val_end]
        test_data = all_pairs[val_end:]

        print(f"\n数据集划分:")
        print(f"  训练集: {len(train_data)} 样本")
        print(f"  验证集: {len(val_data)} 样本")
        print(f"  测试集: {len(test_data)} 样本")

        return train_data, val_data, test_data

    def _create_positive_pairs(self, qa_list: List[Dict]) -> List[Dict]:
        """创建正样本对 (question, answer_context)"""
        positive_pairs = []

        for qa in qa_list:
            question = qa.get('question', '').strip()
            answer_context = qa.get('answer_context', '').strip()

            if question and answer_context:
                positive_pairs.append({
                    'anchor': question,
                    'positive': answer_context,
                    'label': 1  # 相似
                })

        print(f"创建正样本对: {len(positive_pairs)}")
        return positive_pairs

    def _add_negatives(
        self,
        positive_pairs: List[Dict],
        qa_list: List[Dict],
        num_negatives: int = 5,
        hard_negative_ratio: float = 0.3
    ) -> List[Dict]:
        """
        添加负样本

        Args:
            positive_pairs: 正样本对
            qa_list: 所有QA数据
            num_negatives: 负样本数量
            hard_negative_ratio: 困难负样本比例
        """
        print(f"\n添加负样本 (每样本 {num_negatives} 个负样本)...")

        all_answers = [qa.get('answer_context', '').strip() for qa in qa_list]
        all_answers = [a for a in all_answers if a]

        extended_pairs = []

        for pair in positive_pairs:
            anchor = pair['anchor']
            positive = pair['positive']

            # 添加原始正样本
            extended_pairs.append(pair)

            # 生成负样本
            hard_neg_count = int(num_negatives * hard_negative_ratio)
            random_neg_count = num_negatives - hard_neg_count

            # 困难负样本: 同领域但不同的答案
            hard_negatives = self._sample_hard_negatives(
                anchor,
                all_answers,
                n=hard_neg_count,
                exclude=positive
            )

            # 随机负样本
            random_negatives = self._sample_random_negatives(
                all_answers,
                n=random_neg_count,
                exclude=positive
            )

            # 添加负样本对
            for neg in hard_negatives + random_negatives:
                extended_pairs.append({
                    'anchor': anchor,
                    'positive': neg,  # 在SBERT训练中作为负样本
                    'label': 0  # 不相似
                })

        print(f"总样本数: {len(extended_pairs)}")
        return extended_pairs

    def _sample_hard_negatives(
        self,
        anchor: str,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """采样困难负样本(简单实现:随机采样后可改进)"""
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def _sample_random_negatives(
        self,
        all_answers: List[str],
        n: int,
        exclude: str
    ) -> List[str]:
        """采样随机负样本"""
        candidates = [a for a in all_answers if a != exclude]
        if len(candidates) <= n:
            return candidates
        return random.sample(candidates, n)

    def save_data(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        format: str = 'jsonl'
    ):
        """保存数据到文件"""
        print(f"\n保存数据到 {self.output_dir}...")

        if format == 'jsonl':
            # JSONL格式 (适合sentence-transformers)
            self._save_jsonl(train_data, 'train.jsonl')
            self._save_jsonl(val_data, 'val.jsonl')
            self._save_jsonl(test_data, 'test.jsonl')

        elif format == 'csv':
            # CSV格式
            import pandas as pd
            pd.DataFrame(train_data).to_csv(
                self.output_dir / 'train.csv', index=False
            )
            pd.DataFrame(val_data).to_csv(
                self.output_dir / 'val.csv', index=False
            )
            pd.DataFrame(test_data).to_csv(
                self.output_dir / 'test.csv', index=False
            )

        print("✓ 数据保存完成")

    def _save_jsonl(self, data: List[Dict], filename: str):
        """保存为JSONL格式"""
        filepath = self.output_dir / filename
        with open(filepath, 'w', encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
        print(f"  保存: {filepath} ({len(data)} 样本)")

    def print_statistics(self, train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]):
        """打印数据统计信息"""
        print("\n=== 数据统计 ===")

        # 正负样本比例
        for name, data in [('训练集', train_data), ('验证集', val_data), ('测试集', test_data)]:
            pos_count = sum(1 for item in data if item.get('label') == 1)
            neg_count = len(data) - pos_count
            print(f"\n{name}:")
            print(f"  总样本: {len(data)}")
            print(f"  正样本: {pos_count} ({pos_count/len(data)*100:.1f}%)")
            print(f"  负样本: {neg_count} ({neg_count/len(data)*100:.1f}%)")

        # 文本长度统计
        all_anchors = [item['anchor'] for item in train_data]
        all_positives = [item['positive'] for item in train_data]

        anchor_lengths = [len(a.split()) for a in all_anchors]
        positive_lengths = [len(p.split()) for p in all_positives]

        print(f"\n文本长度统计 (训练集):")
        print(f"  Anchor: 平均 {np.mean(anchor_lengths):.1f} 词, "
              f"最大 {max(anchor_lengths)}, 最小 {min(anchor_lengths)}")
        print(f"  Positive: 平均 {np.mean(positive_lengths):.1f} 词, "
              f"最大 {max(positive_lengths)}, 最小 {min(positive_lengths)}")


def main():
    """主函数"""
    import argparse

    parser = argparse.ArgumentParser(description='准备SBERT训练数据')
    parser.add_argument(
        '--qa_dataset',
        type=str,
        default='hr-multiwoz-dataset/qa_dataset',
        help='QA数据集路径'
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default='data/processed/sbert',
        help='输出目录'
    )
    parser.add_argument(
        '--num_negatives',
        type=int,
        default=5,
        help='每个正样本的负样本数量'
    )
    parser.add_argument(
        '--hard_negative_ratio',
        type=float,
        default=0.3,
        help='困难负样本比例'
    )
    parser.add_argument(
        '--format',
        type=str,
        default='jsonl',
        choices=['jsonl', 'csv'],
        help='输出格式'
    )

    args = parser.parse_args()

    # 设置随机种子
    random.seed(42)
    np.random.seed(42)

    # 创建数据准备器
    preparator = SBERTDataPreparator(
        qa_dataset_path=args.qa_dataset,
        output_dir=args.output_dir
    )

    # 准备数据
    train_data, val_data, test_data = preparator.prepare_training_data(
        num_negatives=args.num_negatives,
        hard_negative_ratio=args.hard_negative_ratio
    )

    # 保存数据
    preparator.save_data(train_data, val_data, test_data, format=args.format)

    # 打印统计
    preparator.print_statistics(train_data, val_data, test_data)

    print("\n✓ 数据准备完成!")
    print(f"\n输出目录: {args.output_dir}")
    print(f"下一步: 运行训练脚本")
    print(f"  python scripts/train_sbert.py --train_data {args.output_dir}/train.jsonl")


if __name__ == '__main__':
    main()