File size: 5,114 Bytes
e791fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json
import torchaudio
from tqdm import tqdm
import os
import sys
from collections import defaultdict

def _classify_audio(audio_path, max_duration):
    """Classify a single audio entry of a record.

    Args:
        audio_path: One element of a record's ``audios`` list; expected to
            be a filesystem path string.
        max_duration: Longest accepted duration in seconds (inclusive).

    Returns:
        A ``(stat_key, problem)`` tuple. ``stat_key`` is the name of the
        counter to increment; ``problem`` is the error-log text for a
        rejected entry, or ``None`` when the entry is acceptable.
    """
    # Non-string entries must not reach os.path.exists(): it raises for
    # None and interprets an int as an open file descriptor.
    if not isinstance(audio_path, str):
        return 'invalid_path', f"无效路径: {audio_path!r}"

    if not os.path.exists(audio_path):
        return 'missing', f"缺失文件: {audio_path}"

    if os.path.getsize(audio_path) == 0:
        return 'zero_size', f"空文件: {audio_path}"

    try:
        waveform, sr = torchaudio.load(audio_path)
        if waveform.numel() == 0:
            return 'empty_audio', f"空音频: {audio_path}"
        # Frame count (second dimension of the loaded tensor) divided by
        # the sample rate gives the duration in seconds.
        duration = waveform.shape[1] / sr
    except Exception as e:
        # Deliberately broad: any failure to decode or measure the file
        # is reported as a corrupted file instead of aborting the run.
        error_type = str(e).split('(')[0]
        return 'corrupted', f"损坏文件({error_type}): {audio_path}"

    if duration > max_duration:
        return 'too_long', f"时长过长({duration:.2f}s): {audio_path}"
    return 'valid', None


def filter_audio_duration(jsonl_path, max_duration=90.0):
    """Keep the JSONL records whose audio files all pass the duration check.

    Each non-blank line of ``jsonl_path`` is parsed as JSON. A record is
    kept only if it is an object with a non-empty ``audios`` list and every
    listed file exists, is non-empty, can be loaded by ``torchaudio`` and
    is at most ``max_duration`` seconds long.

    Side effects:
        * Writes the kept records to
          ``<input stem>_filtered_<max_duration>s.jsonl``.
        * Writes ``<input stem>_duration_filter_errors.log`` when at least
          one problem was recorded.
        * Prints a statistics report to stdout.

    Args:
        jsonl_path: Path of the input JSONL file (read as UTF-8).
        max_duration: Longest accepted audio duration in seconds.

    Returns:
        None.
    """
    stats = defaultdict(int)
    filtered_data = []
    error_log = []

    # First pass: count lines so the progress bar has a total.
    # The encoding is explicit so reading matches the UTF-8 output below
    # regardless of the platform's default locale.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    # Second pass: the actual filtering.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        progress = tqdm(f, total=total_lines, desc="筛选进度", unit="line")
        for line_num, raw_line in enumerate(progress, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not malformed
                # JSON; count them separately and keep the error log clean.
                stats['blank_line'] += 1
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                stats['invalid_json'] += 1
                error_log.append(f"[行{line_num}] 无效JSON格式")
                continue

            # Valid JSON that is not an object (number, null, array, ...)
            # cannot carry an 'audios' field.
            if not isinstance(data, dict) or not data.get('audios'):
                stats['no_audio_field'] += 1
                continue

            audios = data['audios']
            if not isinstance(audios, list):
                # A bare string would otherwise be iterated character by
                # character; other scalars are not iterable at all.
                stats['invalid_audio_field'] += 1
                stats['filtered_out'] += 1
                error_log.append(f"[行{line_num}] 无效audios字段: {audios!r}")
                continue

            rejected = 0
            for audio_path in audios:
                stat_key, problem = _classify_audio(audio_path, max_duration)
                stats[stat_key] += 1
                if problem is not None:
                    rejected += 1
                    error_log.append(f"[行{line_num}] {problem}")

            # A record survives only when every one of its audios passed,
            # so its 'audios' list is written back unchanged.
            if rejected == 0:
                filtered_data.append(data)
                stats['kept'] += 1
            else:
                stats['filtered_out'] += 1

    # Save the surviving records.
    output_path = f"{os.path.splitext(jsonl_path)[0]}_filtered_{max_duration}s.jsonl"
    with open(output_path, 'w', encoding='utf-8') as f:
        for data in filtered_data:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    # Print the statistics report.
    print("\n===== 筛选报告 =====")
    print(f"最大时长限制: {max_duration}秒")
    print(f"总行数: {total_lines}")
    print(f"保留样本: {stats['kept']}")
    print(f"过滤样本: {stats['filtered_out']}")
    print("--- 详细统计 ---")
    for k, v in sorted(stats.items()):
        print(f"{k}: {v}")

    # Save the error log, if anything was recorded.
    if error_log:
        log_file = f"{os.path.splitext(jsonl_path)[0]}_duration_filter_errors.log"
        with open(log_file, 'w', encoding='utf-8') as f:
            f.write("\n".join(error_log))
        print(f"\n发现 {len(error_log)} 个问题,已保存到 {log_file}")

    print(f"\n筛选后的数据已保存到: {output_path}")

# Command-line entry point:
#   python filter_duration.py <input.jsonl> [max_duration]
if __name__ == "__main__":
    # Single source of truth for the default so the usage text cannot
    # drift from the value actually used (it previously claimed 100).
    default_max_duration = 90.0

    if len(sys.argv) < 2:
        print("使用方法: python filter_duration.py <input.jsonl> [max_duration]")
        print(f"默认最大时长: {default_max_duration:g}秒")
        sys.exit(1)

    if not os.path.exists(sys.argv[1]):
        print(f"错误: 文件 {sys.argv[1]} 不存在")
        sys.exit(1)

    max_duration = default_max_duration
    if len(sys.argv) >= 3:
        try:
            max_duration = float(sys.argv[2])
        except ValueError:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)
        # float() accepts "nan" and negative numbers. A NaN limit makes
        # every `duration > max_duration` comparison false, which would
        # silently disable the filter; `not (x > 0)` rejects NaN as well
        # as zero and negative values.
        if not max_duration > 0:
            print(f"错误: 无效的时长参数 {sys.argv[2]}")
            sys.exit(1)

    filter_audio_duration(sys.argv[1], max_duration)