File size: 2,740 Bytes
8613355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json
import logging
import numpy as np

import torchaudio
from torch.utils.data import Dataset


def _handle_wav(wav_path, target_rate=16000):
    """

    处理单个音频文件

    返回:

        waveform: numpy数组(一维)

    """
    waveform, sample_rate = torchaudio.load(wav_path)
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
    audio = waveform[0]
    return audio


def _handle_dialogue_evaluation(obj, sample_rate=16000):
    """

    处理对话评价任务,将用户和机器人音频拼接

    """
    # 加载用户和机器人音频
    user_audio = _handle_wav(obj["user_wav"], sample_rate)
    robot_audio = _handle_wav(obj["robot_wav"], sample_rate)
    
    # 在两段音频之间添加0.5秒静音
    silence = np.zeros(int(sample_rate * 0.5))
    
    # 拼接音频
    combined_audio = np.concatenate([user_audio.numpy(), silence, robot_audio.numpy()])
    
    # 构建评价提示
    prompt_template = (
        "上面有一段对话,用户先说话,中间隔0.5s,机器人回答。请评价上述对话中机器人回答的合理性。"
        "考虑机器人回答的情感、语气、内容等方面。"
        "首先在<think></think>标签中详细分析,然后在<score></score>标签中给出1-10的评分。"
    )
    
    # 构建处理后的对象
    processed_obj = {
        "id": obj["id"],
        "prompt": [{"role": "user", "content": [
            {"type": "audio", "audio_url": "combined_audio"},  # 占位符
            {"type": "text", "text": prompt_template}
        ]}],
        "solution": f"<think>分析机器人回答的合理性</think><score>{obj['gt_score']}</score>",
        "audio": combined_audio,
        "gt_score": obj["gt_score"]
    }
    
    return processed_obj


class AudioDataset(Dataset):
    def __init__(self, data_file, sample_rate=16000, is_perturb=False):
        super().__init__()
        self.data = []
        
        # 加载JSON格式的对话评价数据
        with open(data_file, 'r', encoding='utf8') as f:
            data_list = json.load(f)
            
        # 处理每个样本
        for item in data_list:
            processed_item = _handle_dialogue_evaluation(item, sample_rate)
            self.data.append(processed_item)

        self.sample_rate = sample_rate
        self.is_perturb = is_perturb
        logging.info(f"加载数据集: {data_file}, 样本数: {len(self.data)}, 采样率: {sample_rate}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]