import json
import logging

import numpy as np
import torch


class FunCineForgeDS(torch.utils.data.Dataset):
    """Map-style dataset built from JSON-lines manifest files.

    Every non-empty manifest line is a JSON object read with the keys ``"utt"``,
    ``"messages"`` (a list of ``{"role": ..., "content": ...}`` dicts) and,
    optionally, ``"type"``, ``"speech_length"``, ``"text_length"`` and
    ``"clue_length"``.  Records that fail the length filters or lack one of the
    required message roles are dropped at construction time; the survivors are
    kept in memory as plain dicts in ``self.contents`` and returned as-is by
    ``__getitem__``.

    Special token ids (type / gender / age / speaker) default to consecutive
    values starting at ``timebook_size`` (presumably the ids below it are the
    time tokens produced by ``timespk_to_codec``) and can each be overridden
    via kwargs.
    """

    def __init__(self, data_jsonl: str, **kwargs):
        """Load and filter every record referenced by *data_jsonl*.

        Args:
            data_jsonl: Path to a ``.jsonl``/``.json`` manifest, or to a plain
                text file listing one manifest path per line (blank lines are
                ignored).
            **kwargs: Configuration. ``load_meta_data_key`` (required) is a
                comma-separated string (or a list) of message roles every
                record must contain; those roles are copied into each sample,
                with the ``"dialogue"`` role encoded by ``timespk_to_codec``.
                Optional keys: ``max_source_length`` (stored only),
                ``max_text_length``, ``max_token_length``, ``ignore_id``
                (default -100), ``frame_shift`` (default 25),
                ``timebook_size`` (default 1500), ``speaker_id_start`` and the
                per-token id overrides used in the maps below.

        Raises:
            ValueError: If ``load_meta_data_key`` is missing or empty.
        """
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", None)
        self.max_text_length = kwargs.get("max_text_length", None)
        self.max_token_length = kwargs.get("max_token_length", None)
        self.ignore_id = kwargs.get("ignore_id", -100)
        self.frame_shift = kwargs.get("frame_shift", 25)
        self.timebook_size = kwargs.get("timebook_size", 1500)
        # Utterance type -> token id.
        # Keys: narration / voice-over, monologue, dialogue, multi-speaker.
        self.type_map = {
            "旁白": kwargs.get("pangbai", self.timebook_size),
            "独白": kwargs.get("dubai", self.timebook_size + 1),
            "对话": kwargs.get("duihua", self.timebook_size + 2),
            "多人": kwargs.get("duoren", self.timebook_size + 3),
        }
        # Gender label (Chinese or English) -> token id.
        # "男" means male, "女" means female.
        self.gender_map = {
            "男": kwargs.get("male", self.timebook_size + 4),
            "male": kwargs.get("male", self.timebook_size + 4),
            "女": kwargs.get("female", self.timebook_size + 5),
            "female": kwargs.get("female", self.timebook_size + 5),
        }
        # Age label (Chinese or English) -> token id. Chinese keys mean:
        # child, youth, middle age, middle-aged-to-older, elderly.
        self.age_map = {
            "儿童": kwargs.get("child", self.timebook_size + 6),
            "child": kwargs.get("child", self.timebook_size + 6),
            "青年": kwargs.get("youth", self.timebook_size + 7),
            "teenager": kwargs.get("youth", self.timebook_size + 7),
            "中年": kwargs.get("adult", self.timebook_size + 8),
            "adult": kwargs.get("adult", self.timebook_size + 8),
            "中老年": kwargs.get("middle", self.timebook_size + 9),
            "middle-aged": kwargs.get("middle", self.timebook_size + 9),
            "老年": kwargs.get("elderly", self.timebook_size + 10),
            "elderly": kwargs.get("elderly", self.timebook_size + 10),
        }
        # Speaker k (as given in the dialogue "spk" field) maps to speaker_id_start + k - 1.
        self.speaker_id_start = kwargs.get("speaker_id_start", self.timebook_size + 11)
        load_meta_data_key = self._parse_meta_keys(kwargs.get("load_meta_data_key"))

        contents = []
        for file_json in self._resolve_file_list(data_jsonl):
            with open(file_json.strip(), encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    if not line:  # blank / trailing lines are not records
                        continue
                    sample = self._parse_record(json.loads(line), load_meta_data_key)
                    if sample is not None:
                        contents.append(sample)
        self.contents = contents
        logging.info("total_num of samplers: {}, {}".format(len(self.contents), data_jsonl))

    @staticmethod
    def _parse_meta_keys(raw_keys):
        """Normalize ``load_meta_data_key`` (str "a,b" or list) into a list of role names."""
        if raw_keys is None:
            raise ValueError(
                "load_meta_data_key is required: a comma-separated list of message roles"
            )
        parts = raw_keys.split(",") if isinstance(raw_keys, str) else list(raw_keys)
        # Tolerate "a, b" as well as "a,b" and ignore empty entries from stray commas.
        keys = [part.strip() for part in parts if part.strip()]
        if not keys:
            raise ValueError(
                "load_meta_data_key is required: a comma-separated list of message roles"
            )
        return keys

    @staticmethod
    def _resolve_file_list(data_jsonl):
        """Return the manifest paths named by *data_jsonl* (a manifest or a list file)."""
        if data_jsonl.endswith((".jsonl", ".json")):
            return [data_jsonl]
        # Otherwise *data_jsonl* is a list file with one manifest path per line;
        # blank lines are skipped instead of being opened as "".
        with open(data_jsonl, encoding="utf-8") as fin:
            file_list = [path.strip() for path in fin if path.strip()]
        logging.info(f"file_list: {file_list}")
        return file_list

    def _parse_record(self, data_dict, load_meta_data_key):
        """Turn one manifest record into a sample dict, or return None to drop it."""
        utt = data_dict["utt"]
        data_type = data_dict.get("type")
        # Unknown / missing types fall back to timebook_size (the default
        # narration id) rather than a hard-coded 1500, which would ignore a
        # configured timebook_size.
        type_id = self.type_map.get(data_type, self.timebook_size)
        data = data_dict["messages"]
        speech_length = data_dict.get("speech_length", -1)
        # +2 for the start-of-clue / end-of-clue markers.
        # NOTE(review): missing lengths default to -1 *before* summing, so a
        # record lacking only one of text_length / clue_length still gets a
        # positive total here - confirm that is intended.
        text_length = data_dict.get("text_length", -1) + data_dict.get("clue_length", -1) + 2

        if self.max_token_length is not None and (
            speech_length > self.max_token_length or speech_length <= 0
        ):
            logging.info(
                f"speech_length: {speech_length} > {self.max_token_length}, drop it: {data_dict}"
            )
            return None
        if self.max_text_length is not None and (
            text_length > self.max_text_length or text_length <= 0
        ):
            logging.info(
                f"text_length: {text_length} > {self.max_text_length}, drop it: {data_dict}"
            )
            return None

        # Drop records missing any required role (reports the first missing one).
        roles = {item.get("role") for item in data}
        missing = next((key for key in load_meta_data_key if key not in roles), None)
        if missing is not None:
            logging.info(f"doesn't have {missing}, drop it: {data_dict}")
            return None

        wanted = set(load_meta_data_key)
        sample = {}
        timespk_ids_len = 0
        for item in data:
            role = item["role"]
            if role not in wanted:
                continue
            content = item["content"]
            if role == "dialogue":
                timespk_ids = self.timespk_to_codec(content)
                timespk_ids_len = len(timespk_ids)
                if timespk_ids_len == 0:
                    logging.info(f"[WARNING] len of timespk_ids is 0: {data_dict}")
                sample["timespk_ids"] = timespk_ids
            else:
                # A role that appears more than once keeps its last content.
                sample[role] = content
        sample["utt"] = utt
        sample["type_id"] = type_id
        # Face embeddings have as many frames as speech tokens, hence the * 2;
        # + 4 covers the sos / tos / eos markers and the type_id token.
        sample["source_len"] = speech_length * 2 + text_length + timespk_ids_len + 4
        sample["speech_len"] = speech_length
        sample["text_len"] = text_length  # includes clue_length
        return sample

    @staticmethod
    def _to_frame_index(times, frame_shift):
        """Map times to frame indices ``trunc(times * frame_shift + 1)`` as int64."""
        # Round away float noise first (1.16 * 25 + 1 == 29.999999999999996) so
        # times lying exactly on a frame boundary are not truncated one short.
        return np.trunc(np.round(times * frame_shift + 1, 6)).astype(np.int64)

    def timespk_to_codec(self, dialogue):
        """Encode dialogue segments as a flat token sequence.

        Each segment becomes five tokens ``(start, spk, gender, age, end)``, so
        the result has ``5 * len(dialogue)`` entries.  Start/end are frame
        indices ``time * frame_shift + 1`` (end time = start + duration); the
        speaker number ``spk`` maps to ``speaker_id_start + spk - 1``
        (presumably 1-based); gender/age labels not found in the maps become
        ``ignore_id``.

        Args:
            dialogue: List of dicts with ``"start"``, ``"duration"``, ``"spk"``,
                ``"gender"`` and ``"age"`` keys.  Times are presumably seconds
                with ``frame_shift`` frames per second - TODO confirm.

        Returns:
            ``np.ndarray`` of dtype int64; empty when *dialogue* is empty.
        """
        n_parts = len(dialogue)
        if n_parts == 0:
            return np.array([], dtype=np.int64)
        starts = np.array([part["start"] for part in dialogue])
        durations = np.array([part["duration"] for part in dialogue])
        speakers = np.array([int(part["spk"]) for part in dialogue])

        sequence = np.full(n_parts * 5, self.ignore_id, dtype=np.int64)
        sequence[0::5] = self._to_frame_index(starts, self.frame_shift)
        sequence[1::5] = (self.speaker_id_start + speakers - 1).astype(np.int64)
        sequence[2::5] = [self.gender_map.get(part["gender"], self.ignore_id) for part in dialogue]
        sequence[3::5] = [self.age_map.get(part["age"], self.ignore_id) for part in dialogue]
        sequence[4::5] = self._to_frame_index(starts + durations, self.frame_shift)
        return sequence

    def __len__(self):
        """Number of samples that survived filtering."""
        return len(self.contents)

    def __getitem__(self, index):
        """Return the stored sample dict (not a copy)."""
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Length of the source sequence of *data_dict*, 0 if absent."""
        return data_dict.get("source_len", 0)

    def get_target_len(self, data_dict):
        """Length of the target (speech token) sequence of *data_dict*, 0 if absent."""
        return data_dict.get("speech_len", 0)