File size: 7,416 Bytes
03022ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import json
import torch
import logging
import numpy as np


class FunCineForgeDS(torch.utils.data.Dataset):
    """Map-style dataset over jsonl metadata for multi-speaker speech tasks.

    Each jsonl line describes one utterance: an id (``utt``), a scene
    ``type``, speech/text lengths, and a ``messages`` list of
    ``{"role": ..., "content": ...}`` items. The ``dialogue`` role is
    converted into a flat id sequence of ``(start, speaker, gender, age,
    end)`` 5-tuples laid out after the time codebook; other requested
    roles are stored verbatim. Samples that exceed the configured length
    limits or miss a required role are dropped while loading.
    """

    def __init__(self, data_jsonl: str, **kwargs):
        """Load and filter all samples into memory.

        Args:
            data_jsonl: a ``.jsonl``/``.json`` file, or a plain text file
                whose lines are paths of jsonl files.
            **kwargs:
                load_meta_data_key (str, required): comma-separated role
                    names every sample must provide, e.g. "system,dialogue".
                max_text_length / max_token_length (int, optional): drop
                    samples whose text/speech length exceeds the limit.
                ignore_id (int): fill id for unknown gender/age, default -100.
                frame_shift (int): frames per second used to convert
                    start/duration seconds to frame indices, default 25.
                timebook_size (int): size of the time codebook; special ids
                    below are packed right after it, default 1500.
                pangbai/dubai/duihua/duoren, male/female,
                child/youth/adult/middle/elderly, speaker_id_start (int,
                    optional): overrides for the special-id layout.

        Raises:
            ValueError: if ``load_meta_data_key`` is missing.
        """
        super().__init__()

        self.max_source_length = kwargs.get("max_source_length", None)
        self.max_text_length = kwargs.get("max_text_length", None)
        self.max_token_length = kwargs.get("max_token_length", None)
        self.ignore_id = kwargs.get("ignore_id", -100)
        self.frame_shift = kwargs.get("frame_shift", 25)
        self.timebook_size = kwargs.get("timebook_size", 1500)
        # Special-id layout after the time codebook:
        #   +0..+3  scene type (narration/monologue/dialogue/multi-speaker)
        #   +4..+5  gender,  +6..+10  age group,  +11..  speaker ids
        self.type_map = {"旁白": kwargs.get("pangbai", self.timebook_size),
                         "独白": kwargs.get("dubai", self.timebook_size + 1),
                         "对话": kwargs.get("duihua", self.timebook_size + 2),
                         "多人": kwargs.get("duoren", self.timebook_size + 3),}
        self.gender_map = {"男": kwargs.get("male", self.timebook_size + 4),
                           "male": kwargs.get("male", self.timebook_size + 4),
                           "女": kwargs.get("female", self.timebook_size + 5),
                           "female": kwargs.get("female", self.timebook_size + 5),}
        self.age_map = {"儿童": kwargs.get("child", self.timebook_size + 6),
                        "child": kwargs.get("child", self.timebook_size + 6),
                        "青年": kwargs.get("youth", self.timebook_size + 7),
                        "teenager": kwargs.get("youth", self.timebook_size + 7),
                        "中年": kwargs.get("adult", self.timebook_size + 8),
                        "adult": kwargs.get("adult", self.timebook_size + 8),
                        "中老年": kwargs.get("middle", self.timebook_size + 9),
                        "middle-aged": kwargs.get("middle", self.timebook_size + 9),
                        "老年": kwargs.get("elderly", self.timebook_size + 10),
                        "elderly": kwargs.get("elderly", self.timebook_size + 10)}
        self.speaker_id_start = kwargs.get("speaker_id_start", self.timebook_size + 11)

        # FIX: fail fast with a clear message instead of the original
        # AttributeError from None.split(",") when the kwarg is missing.
        raw_keys = kwargs.get("load_meta_data_key")
        if raw_keys is None:
            raise ValueError("load_meta_data_key is required, e.g. 'system,dialogue'")
        load_meta_data_key = raw_keys.split(",")

        if not (data_jsonl.endswith(".jsonl") or data_jsonl.endswith(".json")):
            # Not a jsonl itself: treat it as a list file of jsonl paths.
            with open(data_jsonl, encoding="utf-8") as fin:
                file_list = fin.readlines()
                logging.info(f"file_list: {file_list}")
        else:
            file_list = [data_jsonl]

        contents = []
        for file_json in file_list:
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    data_dict = json.loads(line.strip())
                    utt = data_dict["utt"]
                    data_type = data_dict.get("type")
                    # FIX: fall back to self.timebook_size instead of the
                    # hard-coded 1500 so a non-default timebook_size keeps
                    # the unknown-type id inside the special-id range.
                    type_id = self.type_map.get(data_type, self.timebook_size)
                    data = data_dict["messages"]
                    speech_length = data_dict.get("speech_length", -1)
                    # 2 extra tokens for startofclue, endofclue
                    text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2
                    if self.max_token_length is not None and (speech_length > self.max_token_length or speech_length <= 0):
                        logging.info(
                            f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
                        )
                        continue
                    if self.max_text_length is not None and (text_length > self.max_text_length or text_length <= 0):
                        logging.info(
                            f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
                        )
                        continue

                    # Drop samples missing any required role.
                    roles = {item.get("role") for item in data}
                    missing = next((k for k in load_meta_data_key if k not in roles), None)
                    if missing is not None:
                        logging.info(
                            f"doesn't have {missing}, drop it: {data_dict}")
                        continue

                    contents_i = {}
                    timespk_ids_len = 0
                    for item in data:
                        role = item["role"]
                        content = item["content"]
                        if role not in load_meta_data_key:
                            continue
                        if role == "dialogue":
                            timespk_ids = self.timespk_to_codec(content)
                            timespk_ids_len = len(timespk_ids)
                            if timespk_ids_len == 0:
                                logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
                            contents_i["timespk_ids"] = timespk_ids
                        else:
                            contents_i[role] = content
                    contents_i["utt"] = utt
                    contents_i["type_id"] = type_id
                    # face embs len = speech tokens len, so need * 2;
                    # 4: sos, tos, eos; type_id
                    contents_i["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
                    contents_i["speech_len"] = speech_length
                    contents_i["text_len"] = text_length  # includes clue_length
                    contents.append(contents_i)

        self.contents = contents

        logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))

    def timespk_to_codec(self, dialogue):
        """Flatten dialogue parts into one int64 id sequence.

        Each part contributes a 5-tuple
        ``(start_frame, speaker_id, gender_id, age_id, end_frame)``;
        unknown gender/age labels fall back to ``self.ignore_id``.

        Args:
            dialogue: list of dicts with keys ``start``, ``duration``
                (seconds), ``spk`` (1-based speaker index), ``gender``,
                ``age``.

        Returns:
            np.ndarray of shape ``(5 * len(dialogue),)``, dtype int64;
            empty array when ``dialogue`` is empty.
        """
        n_parts = len(dialogue)
        if n_parts == 0:
            return np.array([], dtype=np.int64)
        starts = np.array([part["start"] for part in dialogue])
        durations = np.array([part["duration"] for part in dialogue])
        speakers = np.array([int(part["spk"]) for part in dialogue])
        genders = [part["gender"] for part in dialogue]
        ages = [part["age"] for part in dialogue]

        # Seconds -> frame index; +1 presumably keeps frame ids 1-based
        # (reserving 0) — confirm against the token vocabulary.
        start_idxs = (starts * self.frame_shift + 1).astype(np.int64)
        end_idxs = ((starts + durations) * self.frame_shift + 1).astype(np.int64)
        spk_ids = (self.speaker_id_start + speakers - 1).astype(np.int64)
        gender_ids = [self.gender_map.get(g, self.ignore_id) for g in genders]
        age_ids = [self.age_map.get(a, self.ignore_id) for a in ages]

        # Interleave the five fields: each part owns 5 consecutive slots.
        sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
        sequence[0::5] = start_idxs
        sequence[1::5] = spk_ids
        sequence[2::5] = gender_ids
        sequence[3::5] = age_ids
        sequence[4::5] = end_idxs
        return sequence

    def __len__(self):
        """Number of loaded (non-dropped) samples."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the pre-built sample dict at ``index``."""
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Total source token length precomputed at load time (0 if absent)."""
        return data_dict.get("source_len", 0)

    def get_target_len(self, data_dict):
        """Target (speech token) length of a sample (0 if absent)."""
        return data_dict.get("speech_len", 0)