| import os |
| import re |
| from datasets import load_from_disk, Dataset |
|
|
| |
# Root directory that is scanned for per-split sub-directories of the input data.
INPUT_BASE = '/root/autodl-tmp/audio-r1/r1-a/dataset/gsm8k_with_audio'
# Root directory under which every filtered dataset is written.
OUTPUT_BASE = './gsm8k_final_filtered'

# Create the output root up front; an already-existing directory is not an error.
os.makedirs(name=OUTPUT_BASE, exist_ok=True)
|
|
| |
def is_suitable_for_tts_question(q: str) -> bool:
    """Return whether question text ``q`` passes the TTS suitability filters.

    A question is accepted only when all of the following hold:

    * it is a string with between 5 and 100 whitespace-separated words
      (both bounds inclusive);
    * it contains none of the characters ``( ) [ ] / ^ < >``;
    * it contains at most two commas.

    Args:
        q: The question text. Non-string values (for example ``None`` coming
            from a record whose field is null) are rejected instead of
            raising.

    Returns:
        ``True`` if the question passes every check, ``False`` otherwise.
    """
    # A null or non-text field has no words to count; reject it rather than
    # crashing on ``.split()``.
    if not isinstance(q, str):
        return False
    word_count = len(q.split())
    if not 5 <= word_count <= 100:
        return False
    # Brackets, slashes, carets and angle brackets disqualify the question.
    if re.search(r'[\(\)\[\]/\^<>]', q):
        return False
    return q.count(',') <= 2
|
|
| |
# Accumulates the kept records from every split; used for the combined dataset.
all_samples = []
# os.listdir() returns entries in arbitrary order; sort them so the record
# order of the combined dataset is reproducible across runs and machines.
for split_name in sorted(os.listdir(INPUT_BASE)):
    split_dir = os.path.join(INPUT_BASE, split_name, 'final_dataset')
    if not os.path.isdir(split_dir):
        # Entries without a 'final_dataset' directory are not splits.
        continue
    print(f"→ Loading split '{split_name}'")
    ds = load_from_disk(split_dir)

    filtered = []
    for ex in ds:
        # A null field comes back as None even with a default; normalise it to
        # '' so the example is rejected by the filter instead of crashing it.
        q = ex.get('question_text') or ''
        wav = ex.get('audio_filepath', '')

        # Drop examples whose audio path is unset or missing on disk.
        if not wav or not os.path.exists(wav):
            continue

        if not is_suitable_for_tts_question(q):
            continue

        rec = {
            'query': q,
            'answer': ex.get('answer', ''),
            'source_dataset': "gsm8k",
            'audio': wav,
            'question_type': 'Math',
            'difficulty': '',
        }
        # The same record goes into both the per-split and the combined output.
        filtered.append(rec)
        all_samples.append(rec)

    print(f" Kept {len(filtered)}/{len(ds)} examples in '{split_name}'")

    # Save this split's kept records under OUTPUT_BASE/<split_name>.
    out_dir = os.path.join(OUTPUT_BASE, split_name)
    os.makedirs(out_dir, exist_ok=True)
    Dataset.from_list(filtered).save_to_disk(out_dir)
|
|
| |
# Write every kept record, across all splits, as a single dataset under
# OUTPUT_BASE/combined, then report how many records were kept in total.
print("→ Saving combined dataset")
combined_dir = os.path.join(OUTPUT_BASE, 'combined')
os.makedirs(combined_dir, exist_ok=True)
combined_ds = Dataset.from_list(all_samples)
combined_ds.save_to_disk(combined_dir)
total_kept = len(all_samples)
print(f"Total kept examples: {total_kept}")
|
|