import os
import random
import sys
from collections import Counter
from itertools import combinations, permutations
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
from datasets import load_dataset

import config

os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"

# NOTE: the dataset is loaded at import time; load_tasks() iterates over it.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")

# Comparison modes accepted by load_tasks().
_COMPARISON_MODES = ("fixed", "random_reverse", "all_permutations")


def process_audio(audio_obj: Any) -> Tuple[Optional[np.ndarray], Optional[int]]:
    """Extract waveform samples and the sample rate from an audio object.

    The object must expose ``get_all_samples()``; the value it returns must
    provide ``data`` and ``sample_rate`` attributes. Multi-dimensional data is
    reduced by averaging over axis 0.

    Args:
        audio_obj: Audio object taken from a dataset row.

    Returns:
        ``(audio_data, sample_rate)`` where ``audio_data`` is a numpy array
        and ``sample_rate`` is an ``int``. ``(None, None)`` when the object
        lacks ``get_all_samples`` or when any step raises.
    """
    try:
        if not hasattr(audio_obj, 'get_all_samples'):
            print("音频对象缺少 get_all_samples 方法")
            return None, None

        samples = audio_obj.get_all_samples()

        audio_data = samples.data
        if not isinstance(audio_data, np.ndarray):
            audio_data = np.array(audio_data, dtype=np.float32)

        sample_rate = samples.sample_rate
        if not isinstance(sample_rate, int):
            sample_rate = int(sample_rate)

        # Average over axis 0 to collapse multi-dimensional audio.
        # NOTE(review): presumably axis 0 is the channel axis (mono mixdown)
        # - confirm against the decoder's output layout.
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=0)

        return audio_data, sample_rate
    except Exception as e:
        # Deliberately best-effort: one undecodable clip must not abort the
        # whole run; the caller skips fields that come back as (None, None).
        print(f"处理音频失败: {e}")
        return None, None


def generate_random_pairs(audio_fields, include_reverse=True, shuffle_order=True):
    """Build one comparison pair for every unordered combination of fields.

    Args:
        audio_fields: Audio field names to pair up.
        include_reverse: When true, the orientation of each pair is flipped
            with probability 1/2, so a combination appears either as (A, B)
            or as (B, A) - never both. When false, pairs keep the order
            produced by ``itertools.combinations``.
        shuffle_order: Whether to shuffle the resulting list of pairs.

    Returns:
        A list of 2-tuples of field names.
    """
    basic_combinations = list(combinations(audio_fields, 2))

    if include_reverse:
        pairs = []
        for combo in basic_combinations:
            # Same RNG call as before so seeded runs stay reproducible.
            if random.choice([True, False]):
                pairs.append((combo[1], combo[0]))
            else:
                pairs.append(combo)
    else:
        pairs = basic_combinations

    if shuffle_order:
        random.shuffle(pairs)

    return pairs


def generate_all_permutations(audio_fields, shuffle_order=True):
    """Build every ordered pair of distinct positions in ``audio_fields``.

    Args:
        audio_fields: Audio field names to pair up.
        shuffle_order: Whether to shuffle the resulting list of pairs.

    Returns:
        A list of 2-tuples containing both (A, B) and (B, A) for every
        combination; an element is never paired with itself.
    """
    # permutations() pairs elements by position, never an element with
    # itself, and yields them in the same order as a nested index loop.
    pairs = list(permutations(audio_fields, 2))

    if shuffle_order:
        random.shuffle(pairs)

    return pairs


def _decode_row_audios(row, row_index: int, audio_fields: Sequence[str]) -> Dict[str, Tuple[np.ndarray, int]]:
    """Decode every usable audio field of one row into ``(data, rate)`` tuples."""
    decoded = {}
    for field in audio_fields:
        if field not in row or row[field] is None:
            print(f"任务 {row_index} 缺少音频字段: {field}")
            continue

        audio_data, audio_rate = process_audio(row[field])
        # The isinstance checks also reject the (None, None) failure result.
        if isinstance(audio_data, np.ndarray) and isinstance(audio_rate, int):
            decoded[field] = (audio_data, audio_rate)
        else:
            print(f"任务 {row_index} 的音频字段 {field} 处理失败")
    return decoded


def _build_pairs(fields: List[str], comparison_mode: str) -> List[Tuple[str, str]]:
    """Return the comparison pairs for ``fields`` under ``comparison_mode``."""
    if comparison_mode == "fixed":
        return list(combinations(fields, 2))
    if comparison_mode == "random_reverse":
        return generate_random_pairs(fields, include_reverse=True, shuffle_order=True)
    if comparison_mode == "all_permutations":
        return generate_all_permutations(fields, shuffle_order=True)
    raise ValueError(f"未知的比较模式: {comparison_mode}")


def load_tasks(comparison_mode="random_reverse", seed=None):
    """Build pairwise audio comparison tasks from the module-level dataset.

    For every dataset row, the audio fields listed in ``config.AUDIO_FIELDS``
    are decoded and paired up; rows with fewer than two usable audios are
    skipped. Progress and statistics are printed to stdout.

    Args:
        comparison_mode: How pairs are formed.
            - "fixed": unordered combinations in a fixed order.
            - "random_reverse": one pair per combination with a randomly
              chosen orientation, in shuffled order.
            - "all_permutations": every ordered pair, in shuffled order.
        seed: Seed for the global ``random`` module; pass a value only when
            reproducible results are needed.

    Returns:
        A list of task dicts with the keys ``instruction``, ``text``,
        ``audioA``, ``audioB``, ``audioA_source``, ``audioB_source``,
        ``comparison`` and ``original_index``.

    Raises:
        ValueError: If ``comparison_mode`` is not one of the supported modes.
        SystemExit: If no task could be generated.
    """
    # Fail fast: reject a bad mode before seeding the RNG or decoding audio.
    if comparison_mode not in _COMPARISON_MODES:
        raise ValueError(f"未知的比较模式: {comparison_mode}")

    if seed is not None:
        random.seed(seed)
        print(f"使用随机种子: {seed}")
    else:
        print("使用真随机模式")

    print("处理数据集...")

    audio_fields = config.AUDIO_FIELDS
    text_field = config.FIELD_TEXT
    instruction_field = config.FIELD_INSTRUCTION

    print(f"使用音频字段: {audio_fields}")
    print(f"文本字段: {text_field}")
    print(f"指令字段: {instruction_field}")
    print(f"比较模式: {comparison_mode}")

    tasks = []
    for i, row in enumerate(dataset):
        processed_audios = _decode_row_audios(row, i, audio_fields)

        if len(processed_audios) < 2:
            print(f"跳过任务 {i}:有效音频数量不足")
            continue

        text = row.get(text_field, '')
        instruction = row.get(instruction_field, '请比较这两个音频的质量')

        for field_a, field_b in _build_pairs(list(processed_audios), comparison_mode):
            tasks.append({
                "instruction": instruction,
                "text": text,
                "audioA": processed_audios[field_a],
                "audioB": processed_audios[field_b],
                "audioA_source": field_a,
                "audioB_source": field_b,
                "comparison": f"{field_a} vs {field_b}",
                "original_index": i,
            })

    print(f"成功生成 {len(tasks)} 个比较任务")

    if not tasks:
        print("没有可用任务!")
        # sys.exit() instead of the site-provided exit(), which is meant for
        # the interactive shell and is missing under `python -S`.
        sys.exit()

    comparison_counts = Counter(task["comparison"] for task in tasks)
    print("比较任务统计:")
    for comp, count in sorted(comparison_counts.items()):
        print(f" {comp}: {count} 个任务")

    return tasks