| import re |
| import os |
| from datasets import load_dataset, Dataset |
| from tqdm.auto import tqdm |
|
|
| |
| |
# Minimum ratio between the chosen and rejected response scores for the
# pair to count as a clear preference.
SCORE_RATIO_THRESHOLD = 2.0
# Minimum absolute score the chosen response must have.
MIN_CHOSEN_SCORE = 3
# Word-count window applied to the prompt/history text (see is_tts_friendly).
MIN_HISTORY_WORDS = 10
MAX_HISTORY_WORDS = 150
# Maximum number of 'http' substrings (i.e. URLs) tolerated in the history.
MAX_URLS = 0
# Maximum number of newline characters tolerated in the history.
MAX_NEWLINES = 5
# Regex patterns for content that reads poorly when spoken aloud:
# fenced code blocks and markdown-style tables.
FORBIDDEN_PATTERNS = [
    r"```.*```",
    r"\|.*\|.*\|",
]
# Both chosen and rejected responses must contain at least this many words.
MIN_RESPONSE_WORDS = 10
|
|
| |
|
|
def is_tts_friendly(text):
    """Return True when *text* looks roughly suitable for TTS synthesis.

    A text passes when its word count falls inside the configured window,
    it contains no URLs, few newlines, and none of the forbidden patterns
    (code fences, markdown tables).
    """
    # Word-count gate: too short or too long is rejected.
    n_words = len(text.split())
    if n_words < MIN_HISTORY_WORDS or n_words > MAX_HISTORY_WORDS:
        return False
    # Reject texts containing URL markers or too many line breaks.
    if text.count('http') > MAX_URLS or text.count('\n') > MAX_NEWLINES:
        return False
    # Reject anything matching a forbidden pattern (DOTALL so fences
    # spanning multiple lines still match).
    return not any(re.search(pat, text, re.DOTALL) for pat in FORBIDDEN_PATTERNS)
|
|
def filter_shp2_train_dataset(dataset_name="stanfordnlp/shp-2"):
    """Load and filter the 'train' split of the SHP-2 dataset.

    Keeps only high-quality preference pairs whose prompt (``history``) is
    roughly suitable for TTS (see :func:`is_tts_friendly`), de-duplicated
    by prompt.

    Args:
        dataset_name: Hugging Face dataset identifier to load.

    Returns:
        list[dict]: filtered examples with keys ``query``, ``chosen``,
        ``reject`` and ``domain``; an empty list if loading fails.
    """
    split_to_process = 'train'
    print(f"加载数据集: {dataset_name}, split: {split_to_process}...")

    try:
        train_dataset = load_dataset(dataset_name, split=split_to_process)
        print(f"'{split_to_process}' split 加载完成。")
    except Exception as e:
        # Best effort: report the failure and return no data rather than crash.
        print(f"错误:无法加载数据集 {dataset_name} 的 '{split_to_process}' split。")
        print(f"错误详情: {e}")
        return []

    filtered_data = []
    seen_histories = set()  # de-duplicate prompts: keep the first pair per history

    print(f"\n开始处理 '{split_to_process}' split...")

    for example in tqdm(train_dataset, desc=f"过滤 {split_to_process} split"):
        history = example.get("history")
        human_ref_A = example.get("human_ref_A")
        human_ref_B = example.get("human_ref_B")
        labels = example.get("labels")
        score_A = example.get("score_A")
        score_B = example.get("score_B")
        score_ratio = example.get("score_ratio")
        domain = example.get("domain")

        # Drop rows with any missing field. Labels/scores/ratio may
        # legitimately be 0, so those are compared against None instead of
        # being tested for truthiness.
        if not all([history, human_ref_A, human_ref_B, labels is not None,
                    score_A is not None, score_B is not None, score_ratio is not None, domain]):
            continue

        # Keep only the first pair seen for each unique prompt.
        if history in seen_histories:
            continue

        # labels == 1 means response A was preferred; labels == 0 means B.
        # Only the int() conversion can raise, so keep the try body minimal.
        try:
            label_int = int(labels)
        except (ValueError, TypeError):
            continue
        if label_int == 1:
            chosen, reject, chosen_score = human_ref_A, human_ref_B, score_A
        elif label_int == 0:
            chosen, reject, chosen_score = human_ref_B, human_ref_A, score_B
        else:
            continue

        # Quality gates. The all([...]) guard above already ensures
        # score_ratio and chosen_score are not None, so only the
        # type/threshold checks remain here.
        if not isinstance(score_ratio, (int, float)) or score_ratio < SCORE_RATIO_THRESHOLD:
            continue
        if not isinstance(chosen_score, (int, float)) or chosen_score < MIN_CHOSEN_SCORE:
            continue
        if not is_tts_friendly(history):
            continue
        if len(chosen.split()) < MIN_RESPONSE_WORDS or len(reject.split()) < MIN_RESPONSE_WORDS:
            continue

        filtered_data.append({
            "query": history,
            "chosen": chosen,
            "reject": reject,
            "domain": domain,
        })
        seen_histories.add(history)

    print(f"\n过滤完成。从 '{split_to_process}' split 中总共筛选出 {len(filtered_data)} 条高质量样本。")
    return filtered_data
|
|
| |
if __name__ == "__main__":
    # Build the filtered preference dataset from the SHP-2 train split.
    filtered_examples = filter_shp2_train_dataset()

    if filtered_examples:
        filtered_dataset = Dataset.from_list(filtered_examples)

        output_path = "./shp2_filtered_tts_high_quality_train_only"
        print(f"正在保存过滤后的训练集数据到: {output_path}")

        # Ensure the parent directory exists. os.path.dirname() returns ""
        # for a bare filename, and os.makedirs("") raises FileNotFoundError,
        # so fall back to the current directory in that case.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        filtered_dataset.save_to_disk(output_path)
        print("数据集保存完成。")

        # Round-trip a few samples from disk as a sanity check / preview.
        print("\n部分样本预览:")
        try:
            loaded_dataset = Dataset.load_from_disk(output_path)
            for i in range(min(5, len(loaded_dataset))):
                sample = loaded_dataset[i]
                print(f"--- 样本 {i+1} ---")
                print(f"Domain: {sample['domain']}")
                print(f"Query: {sample['query'][:200]}...")
                print(f"Chosen: {sample['chosen'][:200]}...")
        except Exception as e:
            print(f"加载预览样本时出错: {e}")
    else:
        print("没有找到符合条件的样本,请检查过滤参数设置或确认 'train' split 是否存在且包含数据。")