|
|
import json |
|
|
import torchaudio |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
import sys |
|
|
from collections import defaultdict |
|
|
|
|
|
def _check_audio(audio_path, line_no, max_duration, stats, error_log):
    """Validate one audio file referenced by a JSONL record.

    Every outcome is counted in ``stats`` (a ``defaultdict(int)``). Every
    failure also appends a message tagged with ``line_no`` to ``error_log``.

    Args:
        audio_path: Path taken from the record's ``audios`` list.
        line_no: 1-based line number of the record, used in log messages.
        max_duration: Maximum allowed duration in seconds (inclusive).
        stats: Counter dict, updated in place.
        error_log: List of log lines, appended to in place.

    Returns:
        True if the file exists, is non-empty, decodes with torchaudio, has
        at least one sample and lasts at most ``max_duration`` seconds;
        False otherwise.
    """
    if not isinstance(audio_path, str):
        # A non-string entry (null, number, nested object) would make os.path
        # raise TypeError and abort the whole scan; report it instead.
        stats['invalid_path'] += 1
        error_log.append(f"[行{line_no}] 无效路径: {audio_path!r}")
        return False

    if not os.path.exists(audio_path):
        stats['missing'] += 1
        error_log.append(f"[行{line_no}] 缺失文件: {audio_path}")
        return False

    if os.path.getsize(audio_path) == 0:
        stats['zero_size'] += 1
        error_log.append(f"[行{line_no}] 空文件: {audio_path}")
        return False

    try:
        waveform, sr = torchaudio.load(audio_path)
        if waveform.numel() == 0:
            stats['empty_audio'] += 1
            error_log.append(f"[行{line_no}] 空音频: {audio_path}")
            return False
        # torchaudio.load returns a channels-first (channels, frames) tensor
        # by default, so the last dimension is the number of frames.
        duration = waveform.shape[-1] / sr
    except Exception as e:
        # Deliberate best-effort: any decode failure (or a bogus sample rate
        # of 0) marks the file as corrupted instead of aborting the scan.
        stats['corrupted'] += 1
        error_log.append(f"[行{line_no}] 损坏文件({type(e).__name__}): {audio_path}")
        return False

    if duration > max_duration:
        stats['too_long'] += 1
        error_log.append(f"[行{line_no}] 时长过长({duration:.2f}s): {audio_path}")
        return False

    stats['valid'] += 1
    return True


def filter_audio_duration(jsonl_path, max_duration=90.0):
    """Keep only JSONL records whose audio files are all valid and short enough.

    Each line of ``jsonl_path`` should be a JSON object with an ``audios``
    list of file paths. A record is kept only if every listed file passes
    ``_check_audio``. Kept records are written unchanged to
    ``<stem>_filtered_<max_duration>s.jsonl`` next to the input. Problems
    are written to ``<stem>_duration_filter_errors.log``, and a summary is
    printed to stdout.

    Args:
        jsonl_path: Path to the input JSONL file (UTF-8, optional BOM).
        max_duration: Maximum allowed audio duration in seconds (inclusive).

    Returns:
        The path of the filtered output file.
    """
    stats = defaultdict(int)
    filtered_data = []
    error_log = []

    # First pass only counts lines so tqdm can show a total.
    # utf-8-sig reads plain UTF-8 and also strips a leading BOM.
    with open(jsonl_path, 'r', encoding='utf-8-sig') as f:
        total_lines = sum(1 for _ in f)

    with open(jsonl_path, 'r', encoding='utf-8-sig') as f:
        progress = tqdm(f, total=total_lines, desc="筛选进度", unit="line")
        for line_no, line in enumerate(progress, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) are not records.
                stats['blank_line'] += 1
                continue

            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                stats['invalid_json'] += 1
                error_log.append(f"[行{line_no}] 无效JSON格式")
                continue

            if not isinstance(data, dict):
                stats['invalid_record'] += 1
                error_log.append(f"[行{line_no}] 记录不是JSON对象")
                continue

            audios = data.get('audios')
            if not audios:
                stats['no_audio_field'] += 1
                continue
            if not isinstance(audios, list):
                # A bare string would otherwise be iterated char by char.
                stats['invalid_record'] += 1
                error_log.append(f"[行{line_no}] audios字段不是列表")
                continue

            # Check every path (no early exit) so the log lists all problems.
            all_audios_valid = True
            for audio_path in audios:
                if not _check_audio(audio_path, line_no, max_duration, stats, error_log):
                    all_audios_valid = False

            if all_audios_valid:
                filtered_data.append(data)
                stats['kept'] += 1
            else:
                stats['filtered_out'] += 1

    output_path = f"{os.path.splitext(jsonl_path)[0]}_filtered_{max_duration}s.jsonl"
    with open(output_path, 'w', encoding='utf-8') as f:
        for data in filtered_data:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    print("\n===== 筛选报告 =====")
    print(f"最大时长限制: {max_duration}秒")
    print(f"总行数: {total_lines}")
    print(f"保留样本: {stats['kept']}")
    print(f"过滤样本: {stats['filtered_out']}")
    print("--- 详细统计 ---")
    for k, v in sorted(stats.items()):
        print(f"{k}: {v}")

    if error_log:
        log_file = f"{os.path.splitext(jsonl_path)[0]}_duration_filter_errors.log"
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(error_log))
        print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")

    print(f"\n筛选后的数据已保存到: {output_path}")
    return output_path
|
|
|
|
|
if __name__ == "__main__":
    # Local import: only the CLI argument validation needs it.
    import math

    # Keep the usage text in sync with the actual default below.
    if len(sys.argv) < 2:
        print("使用方法: python filter_duration.py <input.jsonl> [max_duration]")
        print("默认最大时长: 90秒")
        sys.exit(1)

    # isfile rather than exists: a directory would otherwise crash at open().
    if not os.path.isfile(sys.argv[1]):
        print(f"错误: 文件 {sys.argv[1]} 不存在")
        sys.exit(1)

    max_duration = 90.0
    if len(sys.argv) >= 3:
        try:
            max_duration = float(sys.argv[2])
        except ValueError:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)
        # float() accepts "nan"/"inf"/negatives. With nan, "duration > nan"
        # is always False, which would silently keep every file.
        if not math.isfinite(max_duration) or max_duration <= 0:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)

    filter_audio_duration(sys.argv[1], max_duration)