| import itertools
|
| import os
|
| import re
|
| from collections import defaultdict
|
| from functools import partial
|
| from multiprocessing import Pool
|
| from pathlib import Path
|
|
|
| import click
|
| import numpy as np
|
| from loguru import logger
|
| from tqdm import tqdm
|
|
|
| from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
| from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
| from tools.file import load_filelist
|
|
|
|
|
# Keep BLAS/OpenMP single-threaded inside each worker; parallelism is
# provided by the multiprocessing Pool, not by the math libraries.
for _threads_var in ("MKL_NUM_THREADS", "OMP_NUM_THREADS"):
    os.environ[_threads_var] = "1"
|
|
|
|
|
def task_generator_folder(root: Path, text_extension: str):
    """Yield ``(group_name, [(npy_file, texts)], "folder")`` per directory.

    Recursively collects ``*.npy`` files under ``root``, reads each file's
    sibling transcript(s) selected by ``text_extension`` (a single suffix
    string, or an iterable of suffixes), and groups entries by their parent
    directory.  The group name yielded is the speaker (parent directory name)
    of the group's first file.  Files whose transcript cannot be read are
    logged and skipped.
    """
    npy_files = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    grouped_files = defaultdict(list)
    for npy_file in tqdm(npy_files, desc=f"Grouping {root}"):
        group_key = str(npy_file.parent)
        speaker = npy_file.parent.name

        # Normalize the single-suffix case so both paths share one reader.
        suffixes = (
            [text_extension] if isinstance(text_extension, str) else text_extension
        )
        try:
            texts = [
                npy_file.with_suffix(suffix).read_text(encoding="utf-8")
                for suffix in suffixes
            ]
        except Exception as e:
            logger.error(f"Failed to read text {npy_file}: {e}")
            continue

        grouped_files[group_key].append((speaker, npy_file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for group in grouped_files.values():
        subset = [(f, t) for _, f, t in group]
        yield group[0][0], subset, "folder"
|
|
|
|
|
def task_generator_filelist(filelist):
    """Yield ``(speaker, [(path, [text])], "filelist")`` per speaker.

    Parses ``filelist`` with :func:`load_filelist` and buckets every
    (filename, text) pair under its speaker before yielding one group
    per speaker.
    """
    speaker_to_entries = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        speaker_to_entries[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(speaker_to_entries)} groups in {filelist}")
    yield from (
        (speaker, entries, "filelist")
        for speaker, entries in speaker_to_entries.items()
    )
|
|
|
|
|
def run_task(task):
    """Pack one speaker group into a serialized ``TextData`` proto stream.

    Args:
        task: Tuple ``(name, subset, source)`` where ``subset`` is a list of
            ``(path, texts)`` pairs; each path must have a sibling ``.npy``
            file holding the quantized semantic tokens.

    Returns:
        ``bytes`` produced by ``pack_pb_stream`` over the assembled
        ``TextData``.  Entries with a missing or unreadable ``.npy`` file
        are logged and skipped.
    """
    name, subset, source = task

    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        # Idiomatic truthiness check (was `exists() is False`).
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        # Strip inline markup ({...} and <...> spans) and collapse runs of
        # whitespace to a single space.
        new_texts = []
        for text in texts:
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            # Report the .npy path actually being parsed, not the text file.
            logger.error(f"Failed to parse {np_file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )
|
|
|
|
|
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack quantized .npy datasets / filelists into sharded .protos files.

    Each input that is a directory is walked recursively; anything else is
    treated as a filelist.  Results are written to ``output`` as
    ``NNNNNNNN.protos`` shards, each at most ``shard_size`` MiB (a shard is
    rolled over after the write that crosses the limit).
    """
    generator_fns = []
    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0  # number of shards already closed; also the next shard index
    written_size = 0

    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            # Lazily open the next shard so empty inputs create no file.
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            if written_size > shard_size * 1024 * 1024:
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1
                # Increment before logging so the count reflects shards
                # actually finished (the original logged the pre-increment
                # value, reporting "0 shards" after the first rollover).
                logger.info(f"Finished writing {tar_idx} shards to {output}")

    if dataset_fp is not None:
        dataset_fp.close()
        # Count the trailing, partially-filled shard only if one was open.
        tar_idx += 1

    # The original unconditionally logged tar_idx + 1, overcounting by one
    # whenever the last shard had already been rolled over (or no task ran).
    logger.info(f"Finished writing {tar_idx} shards to {output}")
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|