|
|
import json
|
|
|
import logging
|
|
|
import numpy as np
|
|
|
|
|
|
import torchaudio
|
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
|
|
|
def _handle_wav(wav_path, target_rate=16000):
    """Load a single audio file and return its first entry at ``target_rate``.

    Args:
        wav_path: Location of the audio file, forwarded to ``torchaudio.load``.
        target_rate: Sample rate in Hz that the returned audio should have.

    Returns:
        ``waveform[0]`` of the (possibly resampled) result of
        ``torchaudio.load``. It is returned as loaded and is NOT converted to
        a NumPy array here; the caller in this file calls ``.numpy()`` on it.
        NOTE(review): presumably this is the first channel of a
        (channels, frames) tensor, with any other channels dropped -- confirm
        against the torchaudio documentation.
    """
    samples, source_rate = torchaudio.load(wav_path)

    # Resampling is skipped entirely when the file is already at the target rate.
    needs_resampling = source_rate != target_rate
    if needs_resampling:
        resampler = torchaudio.transforms.Resample(orig_freq=source_rate, new_freq=target_rate)
        samples = resampler(samples)

    return samples[0]
|
|
|
|
|
|
|
|
|
def _handle_dialogue_evaluation(obj, sample_rate=16000):
    """Build one dialogue-evaluation sample from a user/robot audio pair.

    The user clip and the robot clip are loaded through ``_handle_wav``,
    joined with 0.5 s of silence in between, and packaged together with the
    scoring prompt and the reference answer.

    Args:
        obj: Mapping that provides the keys ``"id"``, ``"user_wav"``,
            ``"robot_wav"`` and ``"gt_score"``.
        sample_rate: Rate in Hz passed to ``_handle_wav`` and used to size
            the silent gap.

    Returns:
        A dict with the keys ``"id"``, ``"prompt"`` (a single user turn
        holding an audio entry and the instruction text), ``"solution"``
        (reference answer string embedding ``gt_score``), ``"audio"`` (the
        concatenated NumPy array) and ``"gt_score"``.

    Raises:
        KeyError: If ``obj`` lacks one of the keys listed above.
    """
    user_audio = _handle_wav(obj["user_wav"], sample_rate).numpy()
    robot_audio = _handle_wav(obj["robot_wav"], sample_rate).numpy()

    # np.zeros defaults to float64; concatenating that with narrower audio
    # would promote the whole clip to float64. Allocating the gap with the
    # audio's own dtype keeps the combined clip in the loaded precision.
    silence = np.zeros(int(sample_rate * 0.5), dtype=user_audio.dtype)

    combined_audio = np.concatenate([user_audio, silence, robot_audio])

    # The 0.5 s figure in the prompt text mirrors the gap length used above;
    # the two must be changed together.
    prompt_template = (
        "上面有一段对话,用户先说话,中间隔0.5s,机器人回答。请评价上述对话中机器人回答的合理性。"
        "考虑机器人回答的情感、语气、内容等方面。"
        "首先在<think></think>标签中详细分析,然后在<score></score>标签中给出1-10的评分。"
    )

    processed_obj = {
        "id": obj["id"],
        "prompt": [{"role": "user", "content": [
            # "combined_audio" is a literal string here; the actual samples
            # are carried under the "audio" key below.
            {"type": "audio", "audio_url": "combined_audio"},
            {"type": "text", "text": prompt_template}
        ]}],
        "solution": f"<think>分析机器人回答的合理性</think><score>{obj['gt_score']}</score>",
        "audio": combined_audio,
        "gt_score": obj["gt_score"]
    }

    return processed_obj
|
|
|
|
|
|
|
|
|
class AudioDataset(Dataset):
    """In-memory dataset of dialogue-evaluation samples read from a JSON file.

    Every record is converted by ``_handle_dialogue_evaluation`` while the
    dataset is being constructed, so indexing afterwards is a plain list
    lookup.
    """

    def __init__(self, data_file, sample_rate=16000, is_perturb=False):
        """Read ``data_file`` and preprocess every record it contains.

        Args:
            data_file: Path of a UTF-8 JSON file; its top-level value is
                iterated record by record.
            sample_rate: Rate in Hz forwarded to
                ``_handle_dialogue_evaluation`` for each record.
            is_perturb: Stored on the instance as ``self.is_perturb``; it is
                not read anywhere inside this class.
        """
        super().__init__()

        with open(data_file, 'r', encoding='utf8') as handle:
            records = json.load(handle)

        # All samples are materialised up front and kept for the dataset's lifetime.
        self.data = [
            _handle_dialogue_evaluation(record, sample_rate)
            for record in records
        ]
        self.sample_rate = sample_rate
        self.is_perturb = is_perturb

        logging.info(f"加载数据集: {data_file}, 样本数: {len(self.data)}, 采样率: {sample_rate}")

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the preprocessed sample stored at ``index``."""
        return self.data[index]