# Source: interactSpeech/.ipynb_checkpoints/filter_duration-checkpoint.py
# Uploaded by Student0809 ("Add files using upload-large-folder tool"), commit e791fa3 (verified).
import json
import torchaudio
from tqdm import tqdm
import os
import sys
from collections import defaultdict
def _check_audio(audio_path, max_duration):
    """Validate a single audio file and measure its duration.

    Args:
        audio_path: one entry taken from a sample's ``audios`` list.
        max_duration: maximum allowed duration in seconds.

    Returns:
        A ``(status, message)`` tuple. ``status`` is the name of the stats
        counter to increment: one of ``'invalid_path'``, ``'missing'``,
        ``'zero_size'``, ``'empty_audio'``, ``'too_long'``, ``'corrupted'``
        or ``'valid'``. ``message`` is the error-log text (without the
        line-number prefix), or ``None`` when the file is valid.
    """
    # JSON can hold non-string entries (null, numbers, objects); os.path.exists
    # raises TypeError on those, which would abort the whole run.
    if not isinstance(audio_path, str):
        return 'invalid_path', f"无效音频路径: {audio_path!r}"
    if not os.path.exists(audio_path):
        return 'missing', f"缺失文件: {audio_path}"
    if os.path.getsize(audio_path) == 0:
        return 'zero_size', f"空文件: {audio_path}"
    try:
        waveform, sr = torchaudio.load(audio_path)
        if waveform.numel() == 0:
            return 'empty_audio', f"空音频: {audio_path}"
        # With torchaudio's default channels_first=True the waveform is
        # (channels, frames), so frames / sample_rate is the length in seconds.
        duration = waveform.shape[1] / sr
    except Exception as e:  # any load/decode failure counts as a corrupted file
        # Short reason: message text before the first '(', falling back to the
        # exception class name when that text is empty.
        error_type = str(e).split('(')[0].strip() or type(e).__name__
        return 'corrupted', f"损坏文件({error_type}): {audio_path}"
    if duration > max_duration:
        return 'too_long', f"时长过长({duration:.2f}s): {audio_path}"
    return 'valid', None


def filter_audio_duration(jsonl_path, max_duration=90.0):
    """Keep only JSONL samples whose audio files all last at most ``max_duration`` seconds.

    Each non-blank line of ``jsonl_path`` is parsed as a JSON object whose
    ``audios`` entry must be a list of audio file paths. A sample is kept only
    if every listed file exists, is non-empty, loads with torchaudio, contains
    samples, and is no longer than ``max_duration`` seconds.

    Side effects:
        * writes kept samples to ``<input stem>_filtered_<max_duration>s.jsonl``;
        * writes per-line problems to ``<input stem>_duration_filter_errors.log``
          (only when any were found);
        * prints a statistics report to stdout.

    Args:
        jsonl_path: path to the input JSONL file (read as UTF-8).
        max_duration: maximum allowed duration per audio file, in seconds.
    """
    stats = defaultdict(int)
    filtered_data = []
    error_log = []

    # First pass: count lines so the progress bar knows its total.
    # UTF-8 is explicit so reading does not depend on the platform locale and
    # matches the encoding used for the output below.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    # Second pass: the actual filtering.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        lines = tqdm(f, total=total_lines, desc="筛选进度", unit="line")
        for line_num, line in enumerate(lines, start=1):
            text = line.strip()
            if not text:
                # Tolerate blank lines instead of reporting them as bad JSON.
                stats['blank_line'] += 1
                continue
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                stats['invalid_json'] += 1
                error_log.append(f"[行{line_num}] 无效JSON格式")
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object (list, number, ...): it has no
                # 'audios' key and indexing it would raise TypeError.
                stats['not_object'] += 1
                error_log.append(f"[行{line_num}] JSON不是对象")
                continue

            audios = data.get('audios')
            if not audios:
                stats['no_audio_field'] += 1
                continue
            if not isinstance(audios, list):
                # A bare string would otherwise be iterated character by
                # character, and a number would raise TypeError.
                stats['invalid_audio_field'] += 1
                error_log.append(f"[行{line_num}] audios字段不是列表")
                continue

            # Check every file (even after a failure) so the stats and the
            # error log list all problems of the sample.
            failed = 0
            for audio_path in audios:
                status, message = _check_audio(audio_path, max_duration)
                stats[status] += 1
                if message is not None:
                    failed += 1
                    error_log.append(f"[行{line_num}] {message}")

            # Keep the sample only when all of its audios passed; its
            # 'audios' list is therefore written back unchanged.
            if failed == 0:
                filtered_data.append(data)
                stats['kept'] += 1
            else:
                stats['filtered_out'] += 1

    # Save the kept samples next to the input file.
    output_path = f"{os.path.splitext(jsonl_path)[0]}_filtered_{max_duration}s.jsonl"
    with open(output_path, 'w', encoding='utf-8') as f:
        for data in filtered_data:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    # Print the statistics report.
    print("\n===== 筛选报告 =====")
    print(f"最大时长限制: {max_duration}秒")
    print(f"总行数: {total_lines}")
    print(f"保留样本: {stats['kept']}")
    print(f"过滤样本: {stats['filtered_out']}")
    print("--- 详细统计 ---")
    for k, v in sorted(stats.items()):
        print(f"{k}: {v}")

    # Save the error log, if any problems were found.
    if error_log:
        log_file = f"{os.path.splitext(jsonl_path)[0]}_duration_filter_errors.log"
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(error_log))
        print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")
    print(f"\n筛选后的数据已保存到: {output_path}")
if __name__ == "__main__":
    # Same value as filter_audio_duration's default. The usage text below is
    # formatted from it, so the text cannot disagree with the value used
    # (it previously claimed 100 seconds).
    default_max_duration = 90.0

    if len(sys.argv) < 2:
        print("使用方法: python filter_duration.py <input.jsonl> [max_duration]")
        print(f"默认最大时长: {default_max_duration:g}秒")
        sys.exit(1)
    if not os.path.exists(sys.argv[1]):
        print(f"错误: 文件 {sys.argv[1]} 不存在")
        sys.exit(1)

    max_duration = default_max_duration
    if len(sys.argv) >= 3:
        try:
            max_duration = float(sys.argv[2])
        except ValueError:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)
        # Reject zero, negative and NaN limits. With <= 0 every audio would be
        # "too long"; with NaN the comparison is always False and nothing
        # would be filtered.
        if not max_duration > 0:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)

    filter_audio_duration(sys.argv[1], max_duration)