# File: DSTK/semantic_tokenizer/f40ms/simple_tokenizer_infer.py
# Uploaded by: gooorillax
# Commit message: first push of codes and models for g2p, t2u, tokenizer and detokenizer
# Commit: cd8454d
# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Daxin Tan,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import librosa
import soundfile as sf
import torch
import torch.nn.functional as F
import numpy as np
import argparse
from tqdm import tqdm
from pathlib import Path
from omegaconf import OmegaConf
from fairseq import checkpoint_utils, utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.models import import_models
import sys
# Make this script's directory importable and register the custom fairseq
# models found in ./models under the dotted namespace
# "<parent_dir_name>.<this_dir_name>.models".
current_root = Path(__file__).absolute().parent
sys.path.append(str(current_root))
relative_path = Path(current_root.parent.name) / current_root.name
# Join the path components directly instead of replacing "/" in the string
# form: the string form uses the platform separator, so the replace() approach
# leaves "\\" in the namespace on platforms where the separator is not "/".
namespace = ".".join((relative_path / "models").parts)
import_models(str(current_root / "models"), namespace)
# Route all root-logger output through one console handler with a uniform
# format that includes time, file, level, process id and thread name.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Drop the first pre-installed root handler so messages are not emitted twice.
# The root logger only has handlers if something imported earlier configured
# logging, so guard the lookup instead of indexing handlers[0] unconditionally
# (which raises IndexError on an unconfigured root logger).
# NOTE: the misspelled name "defalut_handler" is kept so that any code that
# reads this module attribute keeps working; it is None when nothing was removed.
defalut_handler = logging.root.handlers[0] if logging.root.handlers else None
if defalut_handler is not None:
    logging.root.removeHandler(defalut_handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
# Optional NPU support: the patches below are applied only when the
# environment variable TOKENIZE_ON_NPU is set to exactly "1".
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":
    # Imported lazily so that hosts without the NPU stack can still run.
    import fairseq_npu_patch
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    logging.info("Applying Patches for NPU!!!")
    fairseq_npu_patch.patch_for_npu()

# Default config file name, without the ".yaml" extension.
TOKENIZER_CFG_NAME = "hubert_config"
def get_unit_sequence(batch_quantized_ids, batch_quantized_ids_length):
    """Render a batch of quantized ids as space-separated unit strings.

    Args:
        batch_quantized_ids: Indexable batch of id tensors; item ``k`` is
            truncated to its valid length and converted to a Python list via
            ``.cpu().numpy().tolist()``.
        batch_quantized_ids_length: Iterable with one valid length per item.

    Returns:
        A tuple ``(unit_sequence_list, reduced_unit_sequence_list)``. Each is a
        list with one string per batch item. The first keeps every unit; the
        second keeps only the first unit of each run of equal consecutive
        units.
    """
    full_sequences = []
    collapsed_sequences = []
    for idx, valid_len in enumerate(batch_quantized_ids_length):
        units = batch_quantized_ids[idx][:valid_len].cpu().numpy().tolist()
        # Keep a unit only when it starts a new run (differs from its
        # predecessor); the very first unit always starts a run.
        collapsed = [
            unit
            for pos, unit in enumerate(units)
            if pos == 0 or unit != units[pos - 1]
        ]
        full_sequences.append(" ".join(map(str, units)))
        collapsed_sequences.append(" ".join(map(str, collapsed)))
    return full_sequences, collapsed_sequences
def collater_audio(audios, audio_size, pad_audio=True):
    """Stack 1-D waveforms into one ``(len(audios), audio_size)`` batch.

    Args:
        audios: Non-empty sequence of 1-D tensors. The output takes its dtype
            and device from ``audios[0]``.
        audio_size: Number of samples per row in the output.
        pad_audio: Accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(collated_audios, padding_mask, audio_starts)``:
        ``collated_audios`` holds each waveform, zero-padded at the end when it
        is shorter than ``audio_size``; ``padding_mask`` is a boolean tensor of
        the same shape that is True on padded positions; ``audio_starts`` is a
        list of zeros, one per waveform.
    """
    collated_audios = audios[0].new_zeros(len(audios), audio_size)
    padding_mask = torch.BoolTensor(collated_audios.shape).fill_(False)
    audio_starts = [0 for _ in audios]
    for i, audio in enumerate(audios):
        diff = len(audio) - audio_size
        if diff == 0:
            collated_audios[i] = audio
        elif diff < 0:
            collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
            # diff is negative, so this marks the last -diff positions.
            padding_mask[i, diff:] = True
        else:
            # Waveform longer than the batch width: keep the first audio_size
            # samples (start offset 0, matching audio_starts). Without this
            # branch the row would silently stay all zeros.
            collated_audios[i] = audio[:audio_size]
    return collated_audios, padding_mask, audio_starts
class SpeechTokenizer(object):
    """Convert raw waveforms into discrete speech units with a fairseq model.

    The model is loaded once in ``__init__``; ``extract`` then tokenizes a
    list of waveforms, optionally splitting long ones into fixed-size segments.
    """

    def __init__(
        self,
        ckpt_path: str = str(
            (Path(__file__).parent / "ckpt/model.pt").absolute()
        ),
        cfg_path: str = str((Path(__file__).parent / "config").absolute()),
        cfg_name: str = TOKENIZER_CFG_NAME,
        use_cuda: bool = True,
    ):
        """Load the checkpoint and prepare the model for inference.

        Args:
            ckpt_path: Path to the fairseq checkpoint file.
            cfg_path: Directory containing ``<cfg_name>.yaml``; also used as
                the task's ``label_dir``.
            cfg_name: Config file name without the ``.yaml`` extension.
            use_cuda: When True (the default, and the previous hard-coded
                behaviour) the model and every input batch are moved with
                ``.cuda()`` / ``utils.move_to_cuda``. Pass False to keep both
                on the CPU.
        """
        w2v_args = OmegaConf.load(f"{cfg_path}/{cfg_name}.yaml")
        OmegaConf.update(w2v_args, "task.label_dir", cfg_path)
        overrides = {
            "task": {"pad_audio": True, "random_crop": False},
            "common": {"seed": 1234},
        }
        # The following is a fix for HuBERT extraction: hand the loaded config
        # to the model through the "w2v_args" override.
        overrides.update({"model": {"w2v_args": w2v_args}})
        (
            model,
            cfg,
            task,
        ) = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path], arg_overrides=overrides
        )
        self.model = model[0].eval()
        if use_cuda:
            self.model = self.model.cuda()
        self.task = task
        self.use_cuda = use_cuda
        # NOTE(review): only the input batch is cast to half when this flag is
        # set; the model itself is never converted in this class, so enabling
        # it likely needs a matching model cast - confirm before use.
        self.use_fp16 = False
        logging.info(f"TASK CONFIG:\n{self.task.cfg}")

    @staticmethod
    def _to_half(t):
        """Cast float32 tensors to half precision; return others unchanged."""
        if t.dtype is torch.float32:
            return t.to(dtype=torch.half)
        return t

    def extract_single_segment(self, raw_wav):
        """Tokenize one waveform segment.

        Args:
            raw_wav: 1-D numpy array of audio samples.

        Returns:
            A tuple ``(audio_token, unit_sequence, reduced_unit_sequence)``:
            ``audio_token`` is the reduced units rendered as concatenated
            ``<|speech_N|>`` markers; the other two are space-separated unit
            strings, with and without consecutive repeats.
        """
        wav = torch.from_numpy(raw_wav).float()
        with torch.no_grad():
            # Normalize over the whole waveform (no learned scale or bias).
            wav = F.layer_norm(wav, wav.shape)

        # A batch of exactly one waveform, so no padding is ever added.
        audios = [wav]
        audio_size = max(len(audio) for audio in audios)
        collated_audios, padding_mask, _ = collater_audio(audios, audio_size)
        sample = {
            "id": torch.LongTensor([0]),
            "net_input": {"source": collated_audios, "padding_mask": padding_mask},
        }
        sample = utils.move_to_cuda(sample) if self.use_cuda else sample
        if self.use_fp16:
            sample = utils.apply_to_sample(self._to_half, sample)

        self.model.set_num_updates(0)
        with torch.no_grad():
            net_output = self.model(**sample["net_input"])

        batch_quantized_ids = net_output["quantized_ids"]  # of shape (B, T)
        batch_quantized_ids_length = net_output["quantized_id_lengths"]
        unit_sequence_list, reduced_unit_sequence_list = get_unit_sequence(
            batch_quantized_ids, batch_quantized_ids_length
        )
        unit_sequence = unit_sequence_list[0]
        reduced_unit_sequence = reduced_unit_sequence_list[0]
        audio_token = "".join(
            f"<|speech_{number}|>" for number in reduced_unit_sequence.split()
        )
        return audio_token, unit_sequence, reduced_unit_sequence

    def extract(self, raw_wavs_list, speech_tokenizer_segment_len=0):
        """Tokenize a list of waveforms.

        Args:
            raw_wavs_list: Iterable of 1-D numpy arrays of audio samples.
            speech_tokenizer_segment_len (int, optional): Segment length, in
                samples, used to split audio. The default 0 disables
                splitting. When greater than 0, every waveform longer than
                this value is processed segment by segment and the results
                are concatenated.

        Returns:
            A tuple ``(audio_token_list, info_list)`` with one entry per
            waveform. Each ``info_list`` entry is a dict with the keys
            ``"unit_sequence"`` and ``"reduced_unit_sequence"``.

        Note:
            Segments are reduced independently, so the same unit can appear on
            both sides of a segment boundary in the reduced output.
        """
        info_list = []
        audio_token_list = []
        for raw_wav in tqdm(raw_wavs_list):
            wav_len = raw_wav.shape[0]
            if 0 < speech_tokenizer_segment_len < wav_len:
                segment_tokens = []
                units = []
                reduced_units = []
                num_segments = int(np.ceil(wav_len / speech_tokenizer_segment_len))
                # Split the audio into consecutive segments.
                for i in range(num_segments):
                    start_sample = int(i * speech_tokenizer_segment_len)
                    end_sample = int((i + 1) * speech_tokenizer_segment_len)
                    (
                        segment_audio_token,
                        segment_unit_sequence,
                        segment_reduced_unit_sequence,
                    ) = self.extract_single_segment(raw_wav[start_sample:end_sample])
                    segment_tokens.append(segment_audio_token)
                    # split() rather than split(" "): an empty segment result
                    # must contribute no units, not one empty-string unit that
                    # would show up as a stray space after joining.
                    units.extend(segment_unit_sequence.split())
                    reduced_units.extend(segment_reduced_unit_sequence.split())
                audio_token = "".join(segment_tokens)
                unit_sequence = " ".join(units)
                reduced_unit_sequence = " ".join(reduced_units)
            else:
                audio_token, unit_sequence, reduced_unit_sequence = (
                    self.extract_single_segment(raw_wav)
                )
            audio_token_list.append(audio_token)
            info_list.append(
                {
                    "unit_sequence": unit_sequence,
                    "reduced_unit_sequence": reduced_unit_sequence,
                }
            )
        return audio_token_list, info_list
def main(args):
    """Tokenize every waveform named in the input list and write the units.

    Args:
        args: Parsed command-line namespace with the attributes ``ckpt``,
            ``cfg_path``, ``cfg_name``, ``input_list`` and ``output_file``.
            When ``ckpt`` is None the tokenizer's built-in default paths are
            used and ``cfg_path`` / ``cfg_name`` are ignored.

    The input list holds one entry per line; only the text before the first
    ``|`` is used as the audio path. One line of space-separated units is
    written to ``output_file`` per input entry, in the same order.
    """
    if args.ckpt is not None:
        tokenizer = SpeechTokenizer(
            ckpt_path=args.ckpt, cfg_path=args.cfg_path, cfg_name=args.cfg_name
        )
    else:
        tokenizer = SpeechTokenizer()

    with open(args.input_list, "r") as input_file:
        # Skip blank lines (e.g. a trailing newline at the end of the list);
        # they would otherwise become an empty path passed to the loader.
        wav_file_list = [
            line.strip().split("|")[0] for line in input_file if line.strip()
        ]

    # Raw waveform data, loaded with librosa and resampled to 16 kHz.
    raw_wavs_list = [
        librosa.load(file_path, sr=16000)[0] for file_path in wav_file_list
    ]

    # Only the per-file unit info is written out; the token strings are unused.
    _, token_info_list = tokenizer.extract(raw_wavs_list)

    # The with-block closes the file; no explicit close() is needed.
    with open(args.output_file, "w") as output_file:
        for token_info in token_info_list:
            logging.info(token_info["unit_sequence"])
            output_file.write(token_info["unit_sequence"] + "\n")
    logging.info("Finished")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the tokenizer.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        dest="ckpt",
        required=False,
        help="path to ckpt",
    )
    parser.add_argument(
        "--cfg-path",
        dest="cfg_path",
        required=False,
        default="config",
        help="path to config",
    )
    parser.add_argument(
        "--cfg-name",
        dest="cfg_name",
        required=False,
        # Reuse the module constant so the CLI default and the tokenizer's
        # default cannot drift apart.
        default=TOKENIZER_CFG_NAME,
        help="name of config",
    )
    parser.add_argument(
        "--input-list",
        dest="input_list",
        required=True,
        help="list of input waveform",
    )
    parser.add_argument(
        "--output-file",
        dest="output_file",
        required=True,
        help="file to output speech tokens",
    )
    args = parser.parse_args()
    main(args)