AudioLabelingApp / data_processing.py
sunnyzjx's picture
Upload 7 files
099b013 verified
import os

# huggingface_hub reads its HF_HUB_* environment variables when it is first
# imported (the `datasets` import below pulls it in), so this must be set
# BEFORE that import or it has no effect.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"

import numpy as np
from datasets import load_dataset
import config
from itertools import combinations
import random

# Loaded once, at module import time; load_tasks() iterates over this split.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")
def process_audio(audio_obj):
    """Decode one audio cell into a (possibly down-mixed) sample array and its rate.

    Two input shapes are accepted:

    * an object exposing ``get_all_samples()`` whose result has ``.data``
      (array-like) and ``.sample_rate``;
    * a mapping with ``"array"`` and ``"sampling_rate"`` keys (the decoded
      dict form returned by older ``datasets`` releases).

    Args:
        audio_obj: The audio value read from a dataset row.

    Returns:
        ``(audio_data, sample_rate)``. ``audio_data`` is a NumPy array;
        non-array data is converted to ``float32``, and data with more than
        one dimension is averaged over axis 0. ``sample_rate`` is an ``int``.
        Returns ``(None, None)`` if the object has neither shape or if any
        step raises. Failures are printed rather than raised, so one bad row
        does not abort loading.
    """
    try:
        if hasattr(audio_obj, 'get_all_samples'):
            samples = audio_obj.get_all_samples()
            raw_data, raw_rate = samples.data, samples.sample_rate
        elif (isinstance(audio_obj, dict)
              and "array" in audio_obj and "sampling_rate" in audio_obj):
            raw_data, raw_rate = audio_obj["array"], audio_obj["sampling_rate"]
        else:
            print("音频对象缺少 get_all_samples 方法")
            return None, None

        # Only non-ndarray inputs (e.g. tensors, lists) are converted; an
        # ndarray keeps its dtype.
        if isinstance(raw_data, np.ndarray):
            audio_data = raw_data
        else:
            audio_data = np.array(raw_data, dtype=np.float32)
        sample_rate = raw_rate if isinstance(raw_rate, int) else int(raw_rate)

        # Multi-dimensional data is averaged over axis 0 to a 1-D signal.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=0)
        return audio_data, sample_rate
    except Exception as e:
        # Best-effort: decoding errors can come from arbitrary backends.
        print(f"处理音频失败: {e}")
        return None, None
def generate_random_pairs(audio_fields, include_reverse=True, shuffle_order=True, rng=None):
    """Build one comparison pair per unordered combination of audio fields.

    Args:
        audio_fields: Sequence of audio field names.
        include_reverse: If True, each pair's orientation is chosen at random:
            ``(a, b)`` or ``(b, a)`` with equal probability. Exactly ONE
            orientation is emitted per combination, not both. If False, pairs
            keep ``itertools.combinations`` order.
        shuffle_order: If True, shuffle the order of the resulting pairs.
        rng: Optional ``random.Random``-like object providing ``choice`` and
            ``shuffle``. Defaults to the global ``random`` module, so callers
            relying on ``random.seed`` get the same sequence as before.

    Returns:
        A list of ``(field_a, field_b)`` tuples, one per combination of two
        fields (empty for fewer than two fields).
    """
    if rng is None:
        rng = random
    pairs = list(combinations(audio_fields, 2))
    if include_reverse:
        # Random orientation so neither field is systematically shown first.
        pairs = [(b, a) if rng.choice([True, False]) else (a, b) for a, b in pairs]
    if shuffle_order:
        rng.shuffle(pairs)
    return pairs
def generate_all_permutations(audio_fields, shuffle_order=True, rng=None):
    """Build every ordered pair of distinct positions in ``audio_fields``.

    Both ``(a, b)`` and ``(b, a)`` are produced for every two positions.
    Pairing is by position, so duplicate names still pair with each other.

    Args:
        audio_fields: Sequence of audio field names.
        shuffle_order: If True, shuffle the resulting pairs.
        rng: Optional ``random.Random``-like object providing ``shuffle``.
            Defaults to the global ``random`` module.

    Returns:
        A list of ``(field_a, field_b)`` tuples, ``n * (n - 1)`` long.
    """
    # Local import: the module header only imports `combinations`.
    from itertools import permutations

    pairs = list(permutations(audio_fields, 2))
    if shuffle_order:
        (random if rng is None else rng).shuffle(pairs)
    return pairs
def _decode_row_audios(row, row_index, audio_fields):
    """Decode each configured audio field of one row; return {field: (data, rate)}.

    Missing/None fields and fields that fail to decode are reported with
    ``print`` and left out of the result.
    """
    decoded = {}
    for field in audio_fields:
        if field not in row or row[field] is None:
            print(f"任务 {row_index} 缺少音频字段: {field}")
            continue
        audio_data, audio_rate = process_audio(row[field])
        # isinstance() already rules out None for both values.
        if isinstance(audio_data, np.ndarray) and isinstance(audio_rate, int):
            decoded[field] = (audio_data, audio_rate)
        else:
            print(f"任务 {row_index} 的音频字段 {field} 处理失败")
    return decoded


def load_tasks(comparison_mode="random_reverse", seed=None):
    """Build pairwise comparison tasks from the configured audio fields.

    Every row of the module-level ``dataset`` is decoded with
    ``process_audio``. Rows with fewer than two usable audio fields are
    skipped. The remaining fields are paired according to
    ``comparison_mode``, and one task dict is produced per pair.

    Args:
        comparison_mode: How pairs are formed for each row:
            - "fixed": ``itertools.combinations`` order (original mode)
            - "random_reverse": one pair per combination with random
              orientation, then shuffled
            - "all_permutations": every ordered pair, shuffled
        seed: If not None, passed to ``random.seed`` (the GLOBAL RNG) so the
            random modes are reproducible.

    Returns:
        A list of dicts with keys "instruction", "text", "audioA", "audioB",
        "audioA_source", "audioB_source", "comparison" and
        "original_index". "audioA"/"audioB" are ``(audio_data, sample_rate)``
        tuples.

    Raises:
        ValueError: If ``comparison_mode`` is unknown. This is checked before
            any audio is decoded.
        SystemExit: If no task could be generated.
    """
    import sys
    from collections import Counter

    pair_builders = {
        "fixed": lambda fields: list(combinations(fields, 2)),
        "random_reverse": lambda fields: generate_random_pairs(
            fields, include_reverse=True, shuffle_order=True),
        "all_permutations": lambda fields: generate_all_permutations(
            fields, shuffle_order=True),
    }
    # Validate before the (slow) decoding loop. Otherwise a typo surfaces only
    # after the first usable row, or is masked by the "no tasks" exit.
    if comparison_mode not in pair_builders:
        raise ValueError(f"未知的比较模式: {comparison_mode}")
    build_pairs = pair_builders[comparison_mode]

    if seed is not None:
        random.seed(seed)
        print(f"使用随机种子: {seed}")
    else:
        print("使用真随机模式")
    print("处理数据集...")
    audio_fields = config.AUDIO_FIELDS
    text_field = config.FIELD_TEXT
    instruction_field = config.FIELD_INSTRUCTION
    print(f"使用音频字段: {audio_fields}")
    print(f"文本字段: {text_field}")
    print(f"指令字段: {instruction_field}")
    print(f"比较模式: {comparison_mode}")

    tasks = []
    for i, row in enumerate(dataset):
        processed_audios = _decode_row_audios(row, i, audio_fields)
        if len(processed_audios) < 2:
            print(f"跳过任务 {i}:有效音频数量不足")
            continue
        text = row.get(text_field, '')
        instruction = row.get(instruction_field, '请比较这两个音频的质量')
        for field_a, field_b in build_pairs(list(processed_audios)):
            tasks.append({
                "instruction": instruction,
                "text": text,
                "audioA": processed_audios[field_a],
                "audioB": processed_audios[field_b],
                "audioA_source": field_a,
                "audioB_source": field_b,
                "comparison": f"{field_a} vs {field_b}",
                "original_index": i
            })

    print(f"成功生成 {len(tasks)} 个比较任务")
    if not tasks:
        print("没有可用任务!")
        # sys.exit, not the site-provided exit(), which is meant for the
        # interactive shell and may be absent (e.g. `python -S`).
        sys.exit()

    comparison_counts = Counter(task["comparison"] for task in tasks)
    print("比较任务统计:")
    for comp, count in sorted(comparison_counts.items()):
        print(f" {comp}: {count} 个任务")
    return tasks