# shp.py — filter the SHP-2 dataset's 'train' split into high-quality,
# TTS-friendly preference pairs and save the result to disk.
# (Header reconstructed: the original first lines were Hugging Face Hub
# upload-page artifacts, not Python.)
import re
import os # 确保导入 os 用于保存
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm # 用于显示进度条
# --- Tunable filtering parameters ---
# The preferred answer's score must be at least this many times the
# rejected answer's score.
SCORE_RATIO_THRESHOLD = 2.0
# Minimum absolute score (points) required for the preferred answer.
MIN_CHOSEN_SCORE = 3
# Word-count bounds for the prompt ("history").
MIN_HISTORY_WORDS = 10
MAX_HISTORY_WORDS = 150 # tightened from the earlier value to 150
MAX_URLS = 0 # tightened to 0: no 'http' substrings allowed in the prompt
MAX_NEWLINES = 5
# Regexes that disqualify a prompt for TTS: fenced code blocks and
# markdown-style tables don't read aloud well.
FORBIDDEN_PATTERNS = [
r"```.*```",
r"\|.*\|.*\|",
]
# Both the chosen and the rejected response must have at least this many words.
MIN_RESPONSE_WORDS = 10
# --- Main script logic ---
def is_tts_friendly(text):
    """Heuristically decide whether *text* is suitable for TTS narration.

    A text passes only if its word count lies within
    [MIN_HISTORY_WORDS, MAX_HISTORY_WORDS], it contains no more than
    MAX_URLS occurrences of 'http', no more than MAX_NEWLINES line
    breaks, and none of the FORBIDDEN_PATTERNS (code fences,
    markdown tables) match.

    Returns:
        bool: True when every check passes, False otherwise.
    """
    n_words = len(text.split())
    if n_words < MIN_HISTORY_WORDS or n_words > MAX_HISTORY_WORDS:
        return False
    if text.count('http') > MAX_URLS:  # uses the (tightened) MAX_URLS
        return False
    if text.count('\n') > MAX_NEWLINES:
        return False
    # re.DOTALL lets '.' cross newlines so a code fence spanning
    # multiple lines is still caught.
    return not any(re.search(pat, text, re.DOTALL) for pat in FORBIDDEN_PATTERNS)
def _resolve_preference(example):
    """Map one raw SHP example to (chosen, reject, chosen_score).

    In SHP, ``labels == 1`` means human_ref_A was preferred and
    ``labels == 0`` means human_ref_B was. Any other or unparseable
    label yields None so the caller can skip the example.
    """
    try:
        label = int(example["labels"])
    except (ValueError, TypeError):
        return None
    if label == 1:
        return example["human_ref_A"], example["human_ref_B"], example["score_A"]
    if label == 0:
        return example["human_ref_B"], example["human_ref_A"], example["score_B"]
    return None


def filter_shp2_train_dataset(dataset_name="stanfordnlp/shp-2", split="train"):
    """Load one split of the SHP-2 dataset and keep high-quality, TTS-friendly pairs.

    An example survives only if:
      * all required fields are present and non-empty,
      * its prompt ("history") has not been seen before (deduplication),
      * the chosen/rejected roles can be resolved from ``labels``,
      * score_ratio >= SCORE_RATIO_THRESHOLD and the chosen answer has
        at least MIN_CHOSEN_SCORE points,
      * the prompt passes is_tts_friendly(),
      * both answers contain at least MIN_RESPONSE_WORDS words.

    Args:
        dataset_name: Hugging Face dataset identifier.
        split: which split to process. Defaults to "train" (the original,
            hard-coded behavior); the function now generalizes to any split.

    Returns:
        list[dict]: records with keys "query", "chosen", "reject",
        "domain"; an empty list if the split cannot be loaded.
    """
    print(f"加载数据集: {dataset_name}, split: {split}...")
    try:
        dataset = load_dataset(dataset_name, split=split)
        print(f"'{split}' split 加载完成。")
    except Exception as e:
        # Loading can fail for many reasons (network, bad name, missing
        # split); report and return an empty result instead of crashing.
        print(f"错误:无法加载数据集 {dataset_name} 的 '{split}' split。")
        print(f"错误详情: {e}")
        return []
    filtered_data = []
    seen_histories = set()  # tracks prompts already emitted, for uniqueness
    print(f"\n开始处理 '{split}' split...")
    for example in tqdm(dataset, desc=f"过滤 {split} split"):
        history = example.get("history")
        human_ref_A = example.get("human_ref_A")
        human_ref_B = example.get("human_ref_B")
        labels = example.get("labels")
        score_A = example.get("score_A")
        score_B = example.get("score_B")
        score_ratio = example.get("score_ratio")
        domain = example.get("domain")
        # Basic completeness check: drop rows with missing/empty fields.
        # (strings use truthiness; numeric fields only need to be non-None)
        if not all([history, human_ref_A, human_ref_B, labels is not None,
                    score_A is not None, score_B is not None, score_ratio is not None, domain]):
            continue
        # Each prompt is kept at most once.
        if history in seen_histories:
            continue
        resolved = _resolve_preference(example)
        if resolved is None:
            continue
        chosen, reject, chosen_score = resolved
        # --- Quality filters ---
        # score_ratio is guaranteed non-None by the completeness check above.
        if not isinstance(score_ratio, (int, float)) or score_ratio < SCORE_RATIO_THRESHOLD:
            continue
        if chosen_score is None or not isinstance(chosen_score, (int, float)) or chosen_score < MIN_CHOSEN_SCORE:
            continue
        if not is_tts_friendly(history):
            continue
        if len(chosen.split()) < MIN_RESPONSE_WORDS or len(reject.split()) < MIN_RESPONSE_WORDS:
            continue
        # --- All filters passed: record the pair ---
        filtered_data.append({
            "query": history,
            "chosen": chosen,
            "reject": reject,
            "domain": domain,
        })
        seen_histories.add(history)
    print(f"\n过滤完成。从 '{split}' split 中总共筛选出 {len(filtered_data)} 条高质量样本。")
    return filtered_data
# --- Main program ---
if __name__ == "__main__":
    # Run the filter over the 'train' split (default arguments).
    filtered_examples = filter_shp2_train_dataset()
    if filtered_examples:
        # Wrap the plain records in a Hugging Face Dataset.
        filtered_dataset = Dataset.from_list(filtered_examples)
        # Output path reflects the content: filtered, TTS-friendly, train-only.
        output_path = "./shp2_filtered_tts_high_quality_train_only"
        print(f"正在保存过滤后的训练集数据到: {output_path}")
        # NOTE(review): the previous os.makedirs(os.path.dirname(output_path), ...)
        # was removed — save_to_disk creates the target directory itself, and
        # dirname of a bare filename would be "" (makedirs("") raises).
        filtered_dataset.save_to_disk(output_path)
        print("数据集保存完成。")
        # Preview a few samples, re-loading from disk to confirm the save
        # round-trips correctly.
        print("\n部分样本预览:")
        try:
            loaded_dataset = Dataset.load_from_disk(output_path)
            for i in range(min(5, len(loaded_dataset))):
                sample = loaded_dataset[i]
                print(f"--- 样本 {i+1} ---")
                print(f"Domain: {sample['domain']}")
                print(f"Query: {sample['query'][:200]}...")
                print(f"Chosen: {sample['chosen'][:200]}...")
                # Also show the rejected answer so all saved columns are visible.
                print(f"Reject: {sample['reject'][:200]}...")
        except Exception as e:
            print(f"加载预览样本时出错: {e}")  # preview failure should not crash the script
    else:
        print("没有找到符合条件的样本,请检查过滤参数设置或确认 'train' split 是否存在且包含数据。")