File size: 2,740 Bytes
8613355 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import json
import logging
import numpy as np
import torchaudio
from torch.utils.data import Dataset
def _handle_wav(wav_path, target_rate=16000):
    """Load an audio file and return its first channel at ``target_rate`` Hz.

    Args:
        wav_path: Path of the audio file, passed straight to ``torchaudio.load``.
        target_rate: Desired sampling rate in Hz. The audio is resampled only
            when the file's native rate differs from it.

    Returns:
        torch.Tensor: 1-D tensor with the samples of the first channel.
        This is a torch tensor, NOT a numpy array; callers convert it with
        ``.numpy()`` when they need numpy.
    """
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)
        waveform = resampler(waveform)
    # Keep only channel 0; any additional channels are discarded (no down-mix).
    audio = waveform[0]
    return audio
def _handle_dialogue_evaluation(obj, sample_rate=16000, silence_sec=0.5):
    """Build one dialogue-evaluation sample from a user/robot audio pair.

    The user's and the robot's recordings are loaded at ``sample_rate``,
    joined with ``silence_sec`` seconds of silence in between, and packaged
    together with an evaluation prompt and a target solution string.

    Args:
        obj: Raw sample dict; must provide the keys ``"id"``, ``"user_wav"``,
            ``"robot_wav"`` and ``"gt_score"``.
        sample_rate: Sampling rate in Hz used for both recordings and for the
            silence gap.
        silence_sec: Length of the silence gap in seconds. The same value is
            quoted in the prompt text, so audio and prompt cannot disagree.
            The default (0.5) reproduces the original prompt exactly.

    Returns:
        dict with keys ``"id"``, ``"prompt"`` (chat-style message list holding
        an audio placeholder and the text prompt), ``"solution"``, ``"audio"``
        (1-D numpy array: user audio, silence, robot audio) and ``"gt_score"``.
    """
    # Load both sides of the dialogue and convert the returned tensors to numpy.
    user_audio = _handle_wav(obj["user_wav"], sample_rate).numpy()
    robot_audio = _handle_wav(obj["robot_wav"], sample_rate).numpy()

    # Match the silence dtype to the loaded audio: a default (float64)
    # np.zeros would silently upcast the whole concatenation to float64.
    silence = np.zeros(int(sample_rate * silence_sec), dtype=user_audio.dtype)
    combined_audio = np.concatenate([user_audio, silence, robot_audio])

    # Evaluation prompt; the gap length is interpolated so it always matches
    # the silence actually inserted above.
    prompt_template = (
        f"上面有一段对话,用户先说话,中间隔{silence_sec}s,机器人回答。请评价上述对话中机器人回答的合理性。"
        "考虑机器人回答的情感、语气、内容等方面。"
        "首先在<think></think>标签中详细分析,然后在<score></score>标签中给出1-10的评分。"
    )

    gt_score = obj["gt_score"]
    processed_obj = {
        "id": obj["id"],
        "prompt": [{"role": "user", "content": [
            # Placeholder only: the real waveform travels in the "audio" field.
            {"type": "audio", "audio_url": "combined_audio"},
            {"type": "text", "text": prompt_template},
        ]}],
        "solution": f"<think>分析机器人回答的合理性</think><score>{gt_score}</score>",
        "audio": combined_audio,
        "gt_score": gt_score,
    }
    return processed_obj
class AudioDataset(Dataset):
    """Map-style dataset of dialogue-evaluation samples.

    Every sample is fully processed (audio loaded, resampled and concatenated
    by ``_handle_dialogue_evaluation``) once, eagerly, inside ``__init__``;
    ``__getitem__`` then simply returns the prepared dict.

    Args:
        data_file: Path to a UTF-8 JSON file whose top-level value is a list
            of raw sample dicts.
        sample_rate: Target sampling rate in Hz forwarded to the sample builder.
        is_perturb: Stored as ``self.is_perturb``; not read by this class.

    Raises:
        TypeError: If the JSON top-level value is not a list.
    """

    def __init__(self, data_file, sample_rate=16000, is_perturb=False):
        super().__init__()
        self.sample_rate = sample_rate
        # Stored only; nothing in this class reads it.
        self.is_perturb = is_perturb

        with open(data_file, 'r', encoding='utf8') as f:
            data_list = json.load(f)
        # A top-level JSON object would otherwise be iterated key-by-key and
        # fail later with an opaque "string indices must be integers" error.
        if not isinstance(data_list, list):
            raise TypeError(
                f"expected a JSON list of samples in {data_file!r}, "
                f"got {type(data_list).__name__}"
            )

        self.data = [_handle_dialogue_evaluation(item, sample_rate) for item in data_list]
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logging.info("加载数据集: %s, 样本数: %d, 采样率: %s", data_file, len(self.data), sample_rate)

    def __len__(self):
        """Return the number of prepared samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the prepared sample dict at ``index``."""
        return self.data[index]