import json
import os
import sys
from collections import defaultdict

import torchaudio
from tqdm import tqdm

# Single source of truth for the default limit, so the CLI usage text
# cannot drift from the value the function actually uses.
DEFAULT_MAX_DURATION = 90.0


def _check_audio(audio_path, max_duration):
    """Classify a single audio file.

    Args:
        audio_path: Path of the audio file to inspect.
        max_duration: Maximum accepted duration in seconds (inclusive).

    Returns:
        A ``(category, detail)`` tuple. ``category`` is the statistics key
        (``'missing'``, ``'zero_size'``, ``'empty_audio'``, ``'too_long'``,
        ``'corrupted'`` or ``'valid'``). ``detail`` is the error-log text
        without the line-number prefix, or ``None`` when the file is valid.
    """
    # Check that the file exists.
    if not os.path.exists(audio_path):
        return 'missing', f"缺失文件: {audio_path}"

    # Check the file size.
    if os.path.getsize(audio_path) == 0:
        return 'zero_size', f"空文件: {audio_path}"

    # Decode the audio to validate its content and measure its duration.
    try:
        waveform, sr = torchaudio.load(audio_path)
        if waveform.numel() == 0:
            return 'empty_audio', f"空音频: {audio_path}"
        # Duration in seconds: size of dimension 1 divided by the sample rate.
        duration = waveform.shape[1] / sr
    except Exception as e:
        # Any decoding failure is treated as a corrupted file; only the part
        # of the message before the first '(' is kept to keep the log short.
        error_type = str(e).split('(')[0]
        return 'corrupted', f"损坏文件({error_type}): {audio_path}"

    if duration > max_duration:
        return 'too_long', f"时长过长({duration:.2f}s): {audio_path}"
    return 'valid', None


def filter_audio_duration(jsonl_path, max_duration=DEFAULT_MAX_DURATION):
    """Keep the JSONL records whose audio files are all at most max_duration seconds.

    Each line of ``jsonl_path`` is parsed as JSON. A record is kept only if
    it is an object with a non-empty ``audios`` list and every listed file
    exists, is non-empty, can be decoded by ``torchaudio.load`` and lasts no
    longer than ``max_duration`` seconds.

    Side effects:
        * Writes the kept records to ``<base>_filtered_<max_duration>s.jsonl``.
        * Writes ``<base>_duration_filter_errors.log`` when any problem was
          found.
        * Prints a statistics report to stdout.

    ``<base>`` is ``jsonl_path`` without its extension.

    Args:
        jsonl_path: Path of the input JSONL file.
        max_duration: Maximum accepted audio duration in seconds.
    """
    stats = defaultdict(int)
    filtered_data = []
    error_log = []

    # First pass: count the lines so the progress bar has a total.
    # The encoding is explicit so reading does not depend on the platform
    # default and matches the UTF-8 output written below.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    # Second pass: the actual filtering.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        progress = tqdm(f, total=total_lines, desc="筛选进度", unit="line")
        for line_num, line in enumerate(progress, start=1):
            # Only the parse is guarded, so unrelated errors are not masked.
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                stats['invalid_json'] += 1
                error_log.append(f"[行{line_num}] 无效JSON格式")
                continue

            # Valid JSON that is not an object (number, string, null, list)
            # cannot carry an 'audios' field; count it instead of crashing.
            if not isinstance(data, dict) or not data.get('audios'):
                stats['no_audio_field'] += 1
                continue

            valid_audios = []
            all_audios_valid = True
            # Every file is checked even after a failure so that the error
            # log lists all problems of the record.
            for audio_path in data['audios']:
                category, detail = _check_audio(audio_path, max_duration)
                stats[category] += 1
                if detail is None:
                    valid_audios.append(audio_path)
                else:
                    error_log.append(f"[行{line_num}] {detail}")
                    all_audios_valid = False

            # Keep the record only when all of its audio files passed.
            if all_audios_valid and valid_audios:
                data['audios'] = valid_audios
                filtered_data.append(data)
                stats['kept'] += 1
            else:
                stats['filtered_out'] += 1

    base_path = os.path.splitext(jsonl_path)[0]

    # Save the filtered records.
    output_path = f"{base_path}_filtered_{max_duration}s.jsonl"
    with open(output_path, 'w', encoding='utf-8') as f:
        for data in filtered_data:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    # Print the statistics report.
    print("\n===== 筛选报告 =====")
    print(f"最大时长限制: {max_duration}秒")
    print(f"总行数: {total_lines}")
    print(f"保留样本: {stats['kept']}")
    print(f"过滤样本: {stats['filtered_out']}")
    print("--- 详细统计 ---")
    for k, v in sorted(stats.items()):
        print(f"{k}: {v}")

    # Save the error log.
    if error_log:
        log_file = f"{base_path}_duration_filter_errors.log"
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(error_log))
        print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")

    print(f"\n筛选后的数据已保存到: {output_path}")


def main():
    """Command-line entry point: ``<jsonl_path> [max_duration]``."""
    if len(sys.argv) < 2:
        print("使用方法: python filter_duration.py <jsonl_path> [max_duration]")
        print(f"默认最大时长: {DEFAULT_MAX_DURATION:g}秒")
        sys.exit(1)

    if not os.path.exists(sys.argv[1]):
        print(f"错误: 文件 {sys.argv[1]} 不存在")
        sys.exit(1)

    max_duration = DEFAULT_MAX_DURATION
    if len(sys.argv) >= 3:
        try:
            max_duration = float(sys.argv[2])
        except ValueError:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)

    filter_audio_duration(sys.argv[1], max_duration)


if __name__ == "__main__":
    main()