Spaces:
Sleeping
Sleeping
import os
import random
from itertools import combinations

# huggingface_hub (imported by `datasets`) reads this variable into a
# module-level constant when it is first imported, so it must be set before
# the third-party imports below; setting it afterwards has no effect.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"

import numpy as np  # noqa: E402
from datasets import load_dataset  # noqa: E402

import config  # noqa: E402

# Loaded once when this module is imported; load_tasks() iterates over it.
dataset = load_dataset(config.PROCESS_REPO_ID, split="train")
def process_audio(audio_obj):
    """Decode an audio object into sample data and its sample rate.

    The object must expose ``get_all_samples()``, whose result provides
    ``.data`` (array-like samples) and ``.sample_rate``.

    Returns:
        ``(audio_data, sample_rate)``. ``audio_data`` is a NumPy array, and
        multi-dimensional data is averaged over axis 0. ``sample_rate`` is an
        ``int``. Returns ``(None, None)`` (after printing a message) when the
        object lacks ``get_all_samples`` or decoding raises any exception.
    """
    failed = (None, None)
    try:
        if not hasattr(audio_obj, 'get_all_samples'):
            print("音频对象缺少 get_all_samples 方法")
            return failed
        decoded = audio_obj.get_all_samples()

        waveform = decoded.data
        if not isinstance(waveform, np.ndarray):
            # Non-NumPy buffers are copied into a float32 array; NumPy input keeps its dtype.
            waveform = np.array(waveform, dtype=np.float32)

        rate = decoded.sample_rate
        rate = rate if isinstance(rate, int) else int(rate)

        # Collapse extra dimensions by averaging along axis 0
        # (presumably the channel axis -- confirm against the decoder).
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=0)
        return waveform, rate
    except Exception as exc:
        print(f"处理音频失败: {exc}")
        return failed
def generate_random_pairs(audio_fields, include_reverse=True, shuffle_order=True, rng=None):
    """Build one comparison pair for every unordered pair of audio fields.

    Each unordered pair appears exactly once, so ``n`` fields yield
    ``n * (n - 1) / 2`` pairs.

    Args:
        audio_fields: sequence of audio field names.
        include_reverse: if True, each pair's order is independently flipped
            with probability 1/2. Both "A vs B" and "B vs A" orientations can
            therefore occur across pairs, but a given pair appears in only one
            orientation. If False, pairs keep the order produced by
            ``itertools.combinations``.
        shuffle_order: if True, the resulting list is randomly shuffled.
        rng: optional ``random.Random`` instance. Defaults to the global
            ``random`` module, so a prior ``random.seed`` call still applies.

    Returns:
        A new list of 2-tuples.
    """
    if rng is None:
        rng = random
    pairs = list(combinations(audio_fields, 2))
    if include_reverse:
        # One coin flip per pair, in combination order.
        pairs = [(b, a) if rng.choice([True, False]) else (a, b) for a, b in pairs]
    if shuffle_order:
        rng.shuffle(pairs)
    return pairs
def generate_all_permutations(audio_fields, shuffle_order=True, rng=None):
    """Build every ordered pair of two different positions in ``audio_fields``.

    Both orientations ("A vs B" and "B vs A") are included, so ``n`` fields
    yield ``n * (n - 1)`` pairs. Pairing is by position (i != j), so no element
    is compared with itself.

    Args:
        audio_fields: sequence of audio field names.
        shuffle_order: if True, the resulting list is randomly shuffled.
        rng: optional ``random.Random`` instance. Defaults to the global
            ``random`` module, so a prior ``random.seed`` call still applies.

    Returns:
        A new list of 2-tuples. Before shuffling, the list is in the order
        ``itertools.permutations`` produces.
    """
    from itertools import permutations

    pairs = list(permutations(audio_fields, 2))
    if shuffle_order:
        (random if rng is None else rng).shuffle(pairs)
    return pairs
# Comparison modes accepted by load_tasks(); see _make_pairs for their meaning.
_COMPARISON_MODES = ("fixed", "random_reverse", "all_permutations")


def _load_row_audios(row, row_index, audio_fields):
    """Decode the configured audio fields of one dataset row.

    Args:
        row: one dataset record (dict-like, indexed by field name).
        row_index: position of the row, used only in log messages.
        audio_fields: names of the audio columns to decode.

    Returns:
        Dict mapping field name -> ``(audio_data, sample_rate)`` for every
        field that was present and successfully decoded by ``process_audio``.
        Missing or undecodable fields are logged and omitted.
    """
    decoded = {}
    for field in audio_fields:
        if field not in row or row[field] is None:
            print(f"任务 {row_index} 缺少音频字段: {field}")
            continue
        audio_data, audio_rate = process_audio(row[field])
        # The isinstance checks also rule out the (None, None) failure result.
        if isinstance(audio_data, np.ndarray) and isinstance(audio_rate, int):
            decoded[field] = (audio_data, audio_rate)
        else:
            print(f"任务 {row_index} 的音频字段 {field} 处理失败")
    return decoded


def _make_pairs(fields, comparison_mode):
    """Return the ``(field_a, field_b)`` pairs to compare for one row.

    Raises:
        ValueError: if ``comparison_mode`` is not one of ``_COMPARISON_MODES``.
    """
    if comparison_mode == "fixed":
        return list(combinations(fields, 2))
    if comparison_mode == "random_reverse":
        return generate_random_pairs(fields, include_reverse=True, shuffle_order=True)
    if comparison_mode == "all_permutations":
        return generate_all_permutations(fields, shuffle_order=True)
    raise ValueError(f"未知的比较模式: {comparison_mode}")


def _print_comparison_stats(tasks):
    """Print how many tasks were generated for each "A vs B" comparison label."""
    from collections import Counter

    counts = Counter(task["comparison"] for task in tasks)
    print("比较任务统计:")
    for comp, count in sorted(counts.items()):
        print(f" {comp}: {count} 个任务")


def load_tasks(comparison_mode="random_reverse", seed=None):
    """Build pairwise audio-comparison tasks from the module-level ``dataset``.

    For every row, the audio fields listed in ``config.AUDIO_FIELDS`` are
    decoded. Rows with fewer than two usable audios are skipped. The rest
    produce one task per pair of fields.

    Args:
        comparison_mode: how pairs are formed.
            - "fixed": the fixed ``itertools.combinations`` order (original mode).
            - "random_reverse": each combination's order is randomly flipped,
              and the pair list is shuffled.
            - "all_permutations": every ordered pair, shuffled.
        seed: if given, seeds the global ``random`` module so that the pair
            orders are reproducible.

    Returns:
        A list of task dicts with the keys "instruction", "text", "audioA",
        "audioB" (each audio is ``(audio_data, sample_rate)``),
        "audioA_source", "audioB_source", "comparison" and "original_index".

    Raises:
        ValueError: if ``comparison_mode`` is not recognised.
        SystemExit: with status 1 if no task could be generated.
    """
    if seed is not None:
        random.seed(seed)
        print(f"使用随机种子: {seed}")
    else:
        print("使用真随机模式")
    print("处理数据集...")
    audio_fields = config.AUDIO_FIELDS
    text_field = config.FIELD_TEXT
    instruction_field = config.FIELD_INSTRUCTION
    print(f"使用音频字段: {audio_fields}")
    print(f"文本字段: {text_field}")
    print(f"指令字段: {instruction_field}")
    print(f"比较模式: {comparison_mode}")
    # Fail fast: otherwise a bad mode is only noticed after decoding the
    # first usable row's audio, or never if no row is usable.
    if comparison_mode not in _COMPARISON_MODES:
        raise ValueError(f"未知的比较模式: {comparison_mode}")

    tasks = []
    for i, row in enumerate(dataset):
        processed_audios = _load_row_audios(row, i, audio_fields)
        if len(processed_audios) < 2:
            print(f"跳过任务 {i}:有效音频数量不足")
            continue
        # `or` (rather than a .get default) also falls back when the column
        # exists but holds None or an empty value.
        text = row.get(text_field) or ''
        instruction = row.get(instruction_field) or '请比较这两个音频的质量'
        for field_a, field_b in _make_pairs(list(processed_audios), comparison_mode):
            tasks.append({
                "instruction": instruction,
                "text": text,
                "audioA": processed_audios[field_a],
                "audioB": processed_audios[field_b],
                "audioA_source": field_a,
                "audioB_source": field_b,
                "comparison": f"{field_a} vs {field_b}",
                "original_index": i,
            })

    print(f"成功生成 {len(tasks)} 个比较任务")
    if not tasks:
        print("没有可用任务!")
        # The builtin exit() is an interactive `site` helper: it may be absent,
        # it closes stdin, and it would report success (status 0) here.
        raise SystemExit(1)
    _print_comparison_stats(tasks)
    return tasks