# File size: 5,873 Bytes
# 099b013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import numpy as np
from datasets import load_dataset
import os
import config
from itertools import combinations
import random

# Silence the huggingface_hub warning about symlinks not being supported
# in the cache directory.
# NOTE(review): huggingface_hub appears to read this variable once, when it
# is first imported. `from datasets import load_dataset` above has already
# imported it, so setting the variable here may have no effect. TODO: confirm,
# and if so move this assignment above the `datasets` import.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"

# Loaded once at module import time (a side effect of importing this file);
# load_tasks() iterates over this "train" split of config.PROCESS_REPO_ID.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")


def process_audio(audio_obj):
    """Decode one audio value into ``(mono_samples, sample_rate)``.

    Two input kinds are accepted:

    * an object with a ``get_all_samples()`` method whose result exposes
      ``.data`` and ``.sample_rate`` (the decoder object recent ``datasets``
      versions return for ``Audio`` columns);
    * a ``dict`` with ``"array"`` and ``"sampling_rate"`` keys (the format
      older ``datasets`` versions return for decoded ``Audio`` columns).

    Sample data that is not already an ``np.ndarray`` is converted to a
    ``float32`` ndarray. Data with more than one dimension is averaged over
    axis 0, which down-mixes to mono assuming a ``(channels, samples)``
    layout.

    Args:
        audio_obj: The decoded audio value taken from a dataset row.

    Returns:
        ``(audio_data, sample_rate)`` where ``audio_data`` is an ``np.ndarray``
        and ``sample_rate`` is an ``int``. Returns ``(None, None)`` if the
        object is of an unsupported kind or decoding raises; a message is
        printed in both cases.
    """
    try:
        if hasattr(audio_obj, 'get_all_samples'):
            samples = audio_obj.get_all_samples()
            audio_data = samples.data
            sample_rate = samples.sample_rate
        elif isinstance(audio_obj, dict) and audio_obj.get("array") is not None:
            # Legacy `datasets` Audio format with "array"/"sampling_rate" keys.
            audio_data = audio_obj["array"]
            sample_rate = audio_obj["sampling_rate"]
        else:
            print("音频对象缺少 get_all_samples 方法")
            return None, None

        if not isinstance(audio_data, np.ndarray):
            # Presumably a tensor or nested list coming from the decoder.
            audio_data = np.array(audio_data, dtype=np.float32)
        if not isinstance(sample_rate, int):
            sample_rate = int(sample_rate)
        if audio_data.ndim > 1:
            # Average over the leading (channel) axis to get a mono signal.
            audio_data = audio_data.mean(axis=0)
        return audio_data, sample_rate
    except Exception as e:
        # Best-effort: one undecodable file should not abort the whole pass.
        print(f"处理音频失败: {e}")
        return None, None


def generate_random_pairs(audio_fields, include_reverse=True, shuffle_order=True):
    """Build one comparison pair per unordered combination of ``audio_fields``.

    Every 2-element combination appears exactly once. With ``include_reverse``
    enabled, the orientation of each pair is decided by a coin flip, so a
    combination comes out as either ``(A, B)`` or ``(B, A)``, never both.

    Args:
        audio_fields: Sequence of audio field names.
        include_reverse: Randomly flip the orientation of each pair.
        shuffle_order: Shuffle the resulting list of pairs.

    Returns:
        A list of ``(field_a, field_b)`` tuples. All randomness comes from the
        global ``random`` module state.
    """
    unordered = combinations(audio_fields, 2)

    if not include_reverse:
        result = list(unordered)
    else:
        result = [
            (second, first) if random.choice([True, False]) else (first, second)
            for first, second in unordered
        ]

    if shuffle_order:
        random.shuffle(result)

    return result


def generate_all_permutations(audio_fields, shuffle_order=True):
    """Return every ordered pair of distinct positions in ``audio_fields``.

    For ``n`` fields this yields ``n * (n - 1)`` pairs. Each combination
    appears as both ``(A, B)`` and ``(B, A)``, and no position is paired with
    itself.

    Args:
        audio_fields: Sequence of audio field names.
        shuffle_order: Shuffle the result using the global ``random`` module.
            Otherwise pairs are in ``itertools.permutations`` order (first
            element's position, then second element's position).

    Returns:
        A list of ``(field_a, field_b)`` tuples.
    """
    from itertools import permutations

    pairs = list(permutations(audio_fields, 2))

    if shuffle_order:
        random.shuffle(pairs)

    return pairs


def _decode_row_audios(index, row, audio_fields):
    """Decode each configured audio field of one row into ``{field: (data, rate)}``.

    Missing or ``None`` fields and fields that fail to decode are reported and
    left out of the result.
    """
    processed_audios = {}
    for field in audio_fields:
        if field not in row or row[field] is None:
            print(f"任务 {index} 缺少音频字段: {field}")
            continue

        audio_data, audio_rate = process_audio(row[field])
        # isinstance() is False for None, so this also rejects failed decodes.
        if isinstance(audio_data, np.ndarray) and isinstance(audio_rate, int):
            processed_audios[field] = (audio_data, audio_rate)
        else:
            print(f"任务 {index} 的音频字段 {field} 处理失败")
    return processed_audios


def _print_comparison_stats(tasks):
    """Print how many tasks were generated for each ``"A vs B"`` comparison."""
    from collections import Counter

    comparison_counts = Counter(task["comparison"] for task in tasks)
    print("比较任务统计:")
    for comp, count in sorted(comparison_counts.items()):
        print(f"  {comp}: {count} 个任务")


def load_tasks(comparison_mode="random_reverse", seed=None):
    """Build pairwise audio-comparison tasks from the module-level ``dataset``.

    For each row, every field in ``config.AUDIO_FIELDS`` is decoded with
    ``process_audio``. Rows with fewer than two successfully decoded fields
    are skipped. The remaining fields are paired according to
    ``comparison_mode``, and one task dict is emitted per pair.

    Args:
        comparison_mode: How pairs are formed:
            - ``"fixed"``: ``itertools.combinations`` order, no randomness.
            - ``"random_reverse"``: each combination once, with randomised
              orientation and order (see ``generate_random_pairs``).
            - ``"all_permutations"``: every ordered pair, shuffled (see
              ``generate_all_permutations``).
        seed: If given, seeds the global ``random`` module so that pairing is
            reproducible; otherwise the current RNG state is used.

    Returns:
        A list of task dicts with keys ``instruction``, ``text``, ``audioA``,
        ``audioB``, ``audioA_source``, ``audioB_source``, ``comparison`` and
        ``original_index``. ``audioA``/``audioB`` are ``(samples, rate)``
        tuples as returned by ``process_audio``.

    Raises:
        ValueError: ``comparison_mode`` is not one of the modes above. This is
            checked before any row is processed.
        SystemExit: No task could be built (exit status 1).
    """
    # Validate the mode up front so a typo fails immediately instead of after
    # decoding audio (or never, if no row has two usable audios).
    pair_builders = {
        "fixed": lambda fields: list(combinations(fields, 2)),
        "random_reverse": lambda fields: generate_random_pairs(
            fields, include_reverse=True, shuffle_order=True),
        "all_permutations": lambda fields: generate_all_permutations(
            fields, shuffle_order=True),
    }
    try:
        build_pairs = pair_builders[comparison_mode]
    except KeyError:
        raise ValueError(f"未知的比较模式: {comparison_mode}") from None

    if seed is not None:
        random.seed(seed)
        print(f"使用随机种子: {seed}")
    else:
        print("使用真随机模式")

    print("处理数据集...")

    audio_fields = config.AUDIO_FIELDS
    text_field = config.FIELD_TEXT
    instruction_field = config.FIELD_INSTRUCTION

    print(f"使用音频字段: {audio_fields}")
    print(f"文本字段: {text_field}")
    print(f"指令字段: {instruction_field}")
    print(f"比较模式: {comparison_mode}")

    tasks = []

    for i, row in enumerate(dataset):
        processed_audios = _decode_row_audios(i, row, audio_fields)
        if len(processed_audios) < 2:
            print(f"跳过任务 {i}:有效音频数量不足")
            continue

        # Treat a present-but-None column the same as a missing one.
        text = row.get(text_field)
        if text is None:
            text = ''
        instruction = row.get(instruction_field)
        if instruction is None:
            instruction = '请比较这两个音频的质量'

        for field_a, field_b in build_pairs(list(processed_audios)):
            tasks.append({
                "instruction": instruction,
                "text": text,
                "audioA": processed_audios[field_a],
                "audioB": processed_audios[field_b],
                "audioA_source": field_a,
                "audioB_source": field_b,
                "comparison": f"{field_a} vs {field_b}",
                "original_index": i
            })

    print(f"成功生成 {len(tasks)} 个比较任务")
    if not tasks:
        print("没有可用任务!")
        # The builtin exit() is an interactive-shell helper from `site`;
        # raise SystemExit directly and signal failure with a non-zero status.
        raise SystemExit(1)

    _print_comparison_stats(tasks)

    return tasks