# Source: grsdfdf / r1-a / dataset / filter / gsm8k.py
# Uploaded by: 1f
# Commit 2a79d4e (verified): "Add files using upload-large-folder tool"
import os
import re
from datasets import load_from_disk, Dataset
# --- Configuration ---
# Root directory that is scanned for per-split subdirectories; each split is
# expected to contain a 'final_dataset' folder saved with `save_to_disk`.
INPUT_BASE = '/root/autodl-tmp/audio-r1/r1-a/dataset/gsm8k_with_audio'
# Root directory that receives the filtered per-split and combined datasets.
OUTPUT_BASE = './gsm8k_final_filtered'
# Create the output root up front; exist_ok=True keeps reruns from failing
# when the directory already exists.
os.makedirs(OUTPUT_BASE, exist_ok=True)
# --- Filter function (same as before) ---
# Characters that disqualify a question: parentheses, square brackets, slash,
# caret and angle brackets (presumably because they are awkward to read aloud).
_TTS_UNFRIENDLY_CHARS = re.compile(r'[\(\)\[\]/\^<>]')


def is_suitable_for_tts_question(
    q: str,
    *,
    min_words: int = 5,
    max_words: int = 100,
    max_commas: int = 2,
) -> bool:
    """Return True when the question text passes the TTS suitability filter.

    A question is accepted only if all of the following hold:

    * it is a string (``None`` or any other type is rejected, not raised on);
    * its whitespace-separated word count lies in ``[min_words, max_words]``;
    * it contains none of the characters ``( ) [ ] / ^ < >``;
    * it contains at most ``max_commas`` ASCII commas.

    Args:
        q: The question text to check.
        min_words: Minimum number of whitespace-separated words (inclusive).
        max_words: Maximum number of whitespace-separated words (inclusive).
        max_commas: Maximum number of ``,`` characters allowed.

    Returns:
        True if the question passes every check, otherwise False.
    """
    # A null dataset field arrives here as None; treat it as unsuitable
    # rather than crashing on `.split()`.
    if not isinstance(q, str):
        return False

    word_count = len(q.split())
    if not min_words <= word_count <= max_words:
        return False

    if _TTS_UNFRIENDLY_CHARS.search(q):
        return False

    return q.count(',') <= max_commas
# --- Process each split ---
# Accumulates the kept records from every split for the combined dataset.
all_samples = []

# os.listdir() returns entries in arbitrary order; sorting keeps the row order
# of the combined dataset reproducible across runs and machines.
for split_name in sorted(os.listdir(INPUT_BASE)):
    split_dir = os.path.join(INPUT_BASE, split_name, 'final_dataset')
    # Ignore anything under INPUT_BASE that is not a split with a saved dataset.
    if not os.path.isdir(split_dir):
        continue

    print(f"→ Loading split '{split_name}'")
    ds = load_from_disk(split_dir)

    filtered = []
    for ex in ds:
        # `or ''` also covers fields that exist but hold a null value, which
        # `ex.get(key, '')` would pass through as None.
        q = ex.get('question_text') or ''
        wav = ex.get('audio_filepath') or ''

        # Skip examples with no audio path or whose audio file is missing.
        if not wav or not os.path.exists(wav):
            continue

        # Skip questions that fail the TTS suitability filter.
        if not is_suitable_for_tts_question(q):
            continue

        filtered.append({
            'query': q,
            'answer': ex.get('answer', ''),
            'source_dataset': "gsm8k",
            'audio': wav,
            'question_type': 'Math',
            'difficulty': ''
        })

    all_samples.extend(filtered)
    print(f" Kept {len(filtered)}/{len(ds)} examples in '{split_name}'")

    # Save this split's filtered examples under OUTPUT_BASE/<split_name>.
    out_dir = os.path.join(OUTPUT_BASE, split_name)
    os.makedirs(out_dir, exist_ok=True)
    Dataset.from_list(filtered).save_to_disk(out_dir)

# --- Optional: merge all splits ---
# NOTE(review): an input split literally named 'combined' would be overwritten
# by this output directory — confirm no such split exists.
print("→ Saving combined dataset")
combined_dir = os.path.join(OUTPUT_BASE, 'combined')
os.makedirs(combined_dir, exist_ok=True)
Dataset.from_list(all_samples).save_to_disk(combined_dir)
print(f"Total kept examples: {len(all_samples)}")