Spaces:
Sleeping
Sleeping
File size: 5,873 Bytes
099b013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import numpy as np
from datasets import load_dataset
import os
import config
from itertools import combinations
import random
# Set the Hugging Face Hub flag that is presumably used to silence the
# cache-symlink warning (named after the variable; verify against the
# huggingface_hub environment-variable documentation).
# NOTE(review): this assignment runs after `datasets` has already been
# imported above — confirm the variable is still honoured when set this late.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"
# Load the "train" split of the dataset named by config.PROCESS_REPO_ID once,
# at import time; load_tasks() iterates over this module-level object.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")
def process_audio(audio_obj):
    """Extract sample data and an integer sample rate from an audio object.

    Args:
        audio_obj: Object expected to provide ``get_all_samples()``; the value
            it returns must carry ``data`` and ``sample_rate`` attributes.

    Returns:
        A ``(audio_data, sample_rate)`` tuple. ``audio_data`` is a NumPy array
        (converted to ``float32`` when the source data was not already an
        ``ndarray``) that has been averaged over axis 0 if it had more than one
        dimension; ``sample_rate`` is an ``int``. ``(None, None)`` is returned
        when ``audio_obj`` has no ``get_all_samples`` attribute or when any
        step raises.
    """
    failure = (None, None)
    try:
        # Guard clause: nothing to decode without the accessor method.
        if not hasattr(audio_obj, 'get_all_samples'):
            print("音频对象缺少 get_all_samples 方法")
            return failure

        decoded = audio_obj.get_all_samples()

        waveform = decoded.data
        if not isinstance(waveform, np.ndarray):
            waveform = np.array(waveform, dtype=np.float32)

        rate = decoded.sample_rate
        if not isinstance(rate, int):
            rate = int(rate)

        # Multi-dimensional data is reduced by averaging along the first axis
        # (presumably the channel axis — confirm against the decoder used).
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=0)

        return waveform, rate
    except Exception as e:
        # Best effort: report the problem and let the caller skip this audio.
        print(f"处理音频失败: {e}")
        return failure
def generate_random_pairs(audio_fields, include_reverse=True, shuffle_order=True):
    """Build one comparison pair per unordered combination of audio fields.

    Args:
        audio_fields: Sequence of audio field names.
        include_reverse: When true, each combination is independently flipped
            to ``(b, a)`` with a random coin toss; it does NOT emit both
            orders. When false, every pair keeps its ``combinations`` order.
        shuffle_order: When true, the resulting list is shuffled in place
            before being returned.

    Returns:
        A list of 2-tuples of field names, one per combination.
    """
    if include_reverse:
        # One random.choice call per combination, in combination order, so a
        # seeded run yields the same sequence of flips.
        pairs = [
            (second, first) if random.choice([True, False]) else (first, second)
            for first, second in combinations(audio_fields, 2)
        ]
    else:
        pairs = list(combinations(audio_fields, 2))

    if shuffle_order:
        random.shuffle(pairs)
    return pairs
def generate_all_permutations(audio_fields, shuffle_order=True):
    """Build every ordered pair of distinct positions in ``audio_fields``.

    Args:
        audio_fields: Sequence of audio field names.
        shuffle_order: When true, the resulting list is shuffled in place
            before being returned.

    Returns:
        A list of ``(field_a, field_b)`` tuples covering both directions of
        every pairing. Entries are excluded by position, not by value, so an
        element is never paired with itself but equal values at different
        positions are.
    """
    ordered_pairs = [
        (left, right)
        for left_pos, left in enumerate(audio_fields)
        for right_pos, right in enumerate(audio_fields)
        if left_pos != right_pos
    ]
    if shuffle_order:
        random.shuffle(ordered_pairs)
    return ordered_pairs
def load_tasks(comparison_mode="random_reverse", seed=None):
    """Build pairwise audio-comparison tasks from the module-level ``dataset``.

    For every dataset row, each field listed in ``config.AUDIO_FIELDS`` is
    decoded with ``process_audio``. Rows with fewer than two successfully
    decoded audios are skipped; the remaining rows yield one task per pair
    produced by the selected comparison mode.

    Args:
        comparison_mode: How pairs are formed from a row's usable fields.
            - "fixed": every combination, in fixed order.
            - "random_reverse": every combination, each randomly flipped,
              with the pair list shuffled.
            - "all_permutations": every ordered pair, shuffled.
        seed: Optional seed applied to the global ``random`` generator; use it
            only when a reproducible result is needed.

    Returns:
        A non-empty list of task dicts with the keys ``instruction``, ``text``,
        ``audioA``, ``audioB`` (each an ``(audio_data, sample_rate)`` tuple),
        ``audioA_source``, ``audioB_source``, ``comparison`` and
        ``original_index``.

    Raises:
        ValueError: If ``comparison_mode`` is not one of the supported modes.
        SystemExit: If no task could be generated.
    """
    # Fail fast: reject an unknown mode before any audio is decoded, rather
    # than discovering it only once the first usable row has been processed
    # (or never, when the dataset yields no usable row).
    if comparison_mode not in ("fixed", "random_reverse", "all_permutations"):
        raise ValueError(f"未知的比较模式: {comparison_mode}")

    if seed is not None:
        random.seed(seed)
        print(f"使用随机种子: {seed}")
    else:
        print("使用真随机模式")

    print("处理数据集...")
    audio_fields = config.AUDIO_FIELDS
    text_field = config.FIELD_TEXT
    instruction_field = config.FIELD_INSTRUCTION
    print(f"使用音频字段: {audio_fields}")
    print(f"文本字段: {text_field}")
    print(f"指令字段: {instruction_field}")
    print(f"比较模式: {comparison_mode}")

    tasks = []
    for i, row in enumerate(dataset):
        # Decode every configured audio field that is present in this row.
        processed_audios = {}
        for field in audio_fields:
            if field not in row or row[field] is None:
                print(f"任务 {i} 缺少音频字段: {field}")
                continue
            audio_data, audio_rate = process_audio(row[field])
            if (audio_data is not None and audio_rate is not None and
                    isinstance(audio_data, np.ndarray) and isinstance(audio_rate, int)):
                processed_audios[field] = (audio_data, audio_rate)
            else:
                print(f"任务 {i} 的音频字段 {field} 处理失败")

        # A comparison needs at least two usable audios.
        if len(processed_audios) < 2:
            print(f"跳过任务 {i}:有效音频数量不足")
            continue

        text = row.get(text_field, '')
        instruction = row.get(instruction_field, '请比较这两个音频的质量')

        available_fields = list(processed_audios.keys())
        if comparison_mode == "fixed":
            pairs = list(combinations(available_fields, 2))
        elif comparison_mode == "random_reverse":
            pairs = generate_random_pairs(available_fields, include_reverse=True, shuffle_order=True)
        else:  # "all_permutations" — the mode was validated on entry
            pairs = generate_all_permutations(available_fields, shuffle_order=True)

        for field_a, field_b in pairs:
            tasks.append({
                "instruction": instruction,
                "text": text,
                "audioA": processed_audios[field_a],
                "audioB": processed_audios[field_b],
                "audioA_source": field_a,
                "audioB_source": field_b,
                "comparison": f"{field_a} vs {field_b}",
                "original_index": i
            })

    print(f"成功生成 {len(tasks)} 个比较任务")
    if not tasks:
        print("没有可用任务!")
        # The bare ``exit`` helper is injected by the ``site`` module for
        # interactive sessions and is not guaranteed to exist; raising
        # SystemExit directly is the always-available equivalent and keeps
        # the same default exit status.
        raise SystemExit

    # Summarise how many tasks were produced per "A vs B" label.
    comparison_counts = {}
    for task in tasks:
        comp = task["comparison"]
        comparison_counts[comp] = comparison_counts.get(comp, 0) + 1
    print("比较任务统计:")
    for comp, count in sorted(comparison_counts.items()):
        print(f" {comp}: {count} 个任务")
    return tasks
|