Spaces:
Running on Zero
Running on Zero
File size: 7,416 Bytes
03022ee | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import json
import torch
import logging
import numpy as np
class FunCineForgeDS(torch.utils.data.Dataset):
    """In-memory dataset of dubbing samples described by JSONL metadata.

    Every non-blank line of a metadata file is a JSON object with at least
    the keys ``utt`` and ``messages`` (a list of ``{"role", "content"}``
    dicts); ``type``, ``speech_length``, ``text_length`` and ``clue_length``
    are optional.  Samples are filtered by length and by the presence of the
    requested roles, then stored as plain dicts in ``self.contents``.

    Args:
        data_jsonl: Path to a ``.jsonl``/``.json`` metadata file, or (any
            other extension) a text file listing one metadata path per line.
        **kwargs: Configuration. Recognised keys:
            ``load_meta_data_key`` (required) -- comma separated string (or
            list/tuple) of message roles to load; the role ``"dialogue"`` is
            converted to token ids with :meth:`timespk_to_codec`.
            ``max_token_length`` / ``max_text_length`` -- optional upper
            bounds; when set, samples whose length exceeds the bound or is
            not positive are dropped.
            ``max_source_length`` -- stored only; no filtering is applied
            with it in this class.
            ``ignore_id`` (default -100), ``frame_shift`` (default 25),
            ``timebook_size`` (default 1500), ``speaker_id_start`` and the
            per-label id overrides (``pangbai``, ``dubai``, ``duihua``,
            ``duoren``, ``male``, ``female``, ``child``, ``youth``,
            ``adult``, ``middle``, ``elderly``).

    Raises:
        ValueError: If ``load_meta_data_key`` is not provided.
    """

    def __init__(self, data_jsonl: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", None)
        self.max_text_length = kwargs.get("max_text_length", None)
        self.max_token_length = kwargs.get("max_token_length", None)
        self.ignore_id = kwargs.get("ignore_id", -100)
        self.frame_shift = kwargs.get("frame_shift", 25)
        self.timebook_size = kwargs.get("timebook_size", 1500)

        # Special token ids are laid out right after the time tokens
        # (ids below ``timebook_size``) unless overridden through kwargs.
        base = self.timebook_size
        # Scene types: narration, monologue, dialogue, multi-speaker.
        self.type_map = {
            "旁白": kwargs.get("pangbai", base),
            "独白": kwargs.get("dubai", base + 1),
            "对话": kwargs.get("duihua", base + 2),
            "多人": kwargs.get("duoren", base + 3),
        }
        # Gender labels are accepted in both Chinese and English.
        self.gender_map = {
            "男": kwargs.get("male", base + 4),
            "male": kwargs.get("male", base + 4),
            "女": kwargs.get("female", base + 5),
            "female": kwargs.get("female", base + 5),
        }
        # Age labels are accepted in both Chinese and English.
        self.age_map = {
            "儿童": kwargs.get("child", base + 6),
            "child": kwargs.get("child", base + 6),
            "青年": kwargs.get("youth", base + 7),
            "teenager": kwargs.get("youth", base + 7),
            "中年": kwargs.get("adult", base + 8),
            "adult": kwargs.get("adult", base + 8),
            "中老年": kwargs.get("middle", base + 9),
            "middle-aged": kwargs.get("middle", base + 9),
            "老年": kwargs.get("elderly", base + 10),
            "elderly": kwargs.get("elderly", base + 10),
        }
        self.speaker_id_start = kwargs.get("speaker_id_start", base + 11)

        meta_keys = kwargs.get("load_meta_data_key")
        if meta_keys is None:
            raise ValueError(
                "`load_meta_data_key` is required: pass a comma separated "
                "string of message roles to load"
            )
        if isinstance(meta_keys, str):
            meta_keys = meta_keys.split(",")
        # Tolerate "a, b" and trailing commas: stray whitespace or empty
        # entries would otherwise never match a role and drop every sample.
        load_meta_data_key = [key.strip() for key in meta_keys if key.strip()]

        if data_jsonl.endswith((".jsonl", ".json")):
            file_list = [data_jsonl]
        else:
            # jsonl list file: one metadata path per line, blank lines skipped
            with open(data_jsonl, encoding="utf-8") as fin:
                file_list = [line.strip() for line in fin if line.strip()]
            logging.info(f"file_list: {file_list}")

        contents = []
        for file_json in file_list:
            with open(file_json, encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        # blank lines are not valid JSON; skip them
                        continue
                    sample = self._build_sample(json.loads(line), load_meta_data_key)
                    if sample is not None:
                        contents.append(sample)

        self.contents = contents
        logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))

    def _build_sample(self, data_dict, load_meta_data_key):
        """Turn one parsed metadata record into a sample dict, or ``None`` if it is dropped."""
        utt = data_dict["utt"]
        # Unknown or missing types fall back to the first id after the time
        # tokens, so the fallback follows a non-default ``timebook_size``.
        type_id = self.type_map.get(data_dict.get("type"), self.timebook_size)
        data = data_dict["messages"]
        speech_length = data_dict.get("speech_length", -1)
        # 2 for startofclue, endofclue
        text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2

        if self.max_token_length is not None and (
            speech_length > self.max_token_length or speech_length <= 0
        ):
            logging.info(
                f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
            )
            return None
        if self.max_text_length is not None and (
            text_length > self.max_text_length or text_length <= 0
        ):
            logging.info(
                f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
            )
            return None

        # Every requested role has to be present in the record.
        roles = {item.get("role") for item in data}
        missing = next((key for key in load_meta_data_key if key not in roles), None)
        if missing is not None:
            logging.info(
                f"doesn't have {missing}, drop it: {data_dict}")
            return None

        wanted = set(load_meta_data_key)
        contents_i = {}
        timespk_ids_len = 0
        for item in data:
            role = item["role"]
            content = item["content"]
            if role not in wanted:
                continue
            if role == "dialogue":
                timespk_ids = self.timespk_to_codec(content)
                timespk_ids_len = len(timespk_ids)
                if timespk_ids_len == 0:
                    logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
                contents_i["timespk_ids"] = timespk_ids
            else:
                contents_i[role] = content

        contents_i["utt"] = utt
        contents_i["type_id"] = type_id
        # face embs len = speech tokens len, so need * 2;
        # 4: sos, tos, eos; type_id
        contents_i["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
        contents_i["speech_len"] = speech_length
        contents_i["text_len"] = text_length  # include clue_length
        return contents_i

    def timespk_to_codec(self, dialogue) -> np.ndarray:
        """Encode dialogue parts as a flat int64 token sequence.

        Args:
            dialogue: Sequence of dicts with the keys ``start``,
                ``duration``, ``spk``, ``gender`` and ``age``.

        Returns:
            A 1-D ``int64`` array of length ``5 * len(dialogue)`` laid out as
            ``(start, spk, gender, age, end)`` per part.  Gender/age labels
            missing from the maps are encoded as ``self.ignore_id``.  An
            empty dialogue yields an empty array.
        """
        # tuple tokens (start, spk, gender, age, end) * n_parts
        n_parts = len(dialogue)
        if n_parts == 0:
            return np.array([], dtype=np.int64)

        starts = np.array([part["start"] for part in dialogue])
        durations = np.array([part["duration"] for part in dialogue])
        speakers = np.array([int(part["spk"]) for part in dialogue])

        # start/end are scaled by ``frame_shift`` and offset by one.
        # NOTE(review): astype truncates toward zero; confirm that truncation
        # (rather than rounding) is intended for fractional frame positions.
        start_idxs = (starts * self.frame_shift + 1).astype(np.int64)
        end_idxs = ((starts + durations) * self.frame_shift + 1).astype(np.int64)
        # ``spk`` value 1 maps to ``speaker_id_start``.
        spk_ids = (self.speaker_id_start + speakers - 1).astype(np.int64)
        gender_ids = [self.gender_map.get(part["gender"], self.ignore_id) for part in dialogue]
        age_ids = [self.age_map.get(part["age"], self.ignore_id) for part in dialogue]

        sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
        sequence[0::5] = start_idxs
        sequence[1::5] = spk_ids
        sequence[2::5] = gender_ids
        sequence[3::5] = age_ids
        sequence[4::5] = end_idxs
        return sequence

    def __len__(self) -> int:
        """Return the number of samples kept after filtering."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the stored sample dict at ``index``."""
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return ``data_dict["source_len"]``, or 0 when the key is absent."""
        return data_dict.get("source_len", 0)

    def get_target_len(self, data_dict):
        """Return ``data_dict["speech_len"]``, or 0 when the key is absent."""
        return data_dict.get("speech_len", 0)