|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import os |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from pathlib import Path |
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
from fairseq import checkpoint_utils, utils |
|
|
from fairseq.data.audio.audio_utils import get_features_or_waveform |
|
|
from fairseq.models import import_models |
|
|
|
|
|
import sys

# Make this script's directory importable so sibling modules (e.g. the
# optional fairseq_npu_patch below) resolve regardless of the caller's CWD.
current_root = Path(__file__).absolute().parent

sys.path.append(str(current_root))


# Register the custom fairseq model definitions under a dotted namespace
# built from the last two path components, e.g. "<parent_dir>.<this_dir>.models".
# NOTE(review): the "/" -> "." replacement assumes POSIX path separators;
# confirm behavior if this is ever run on Windows.
relative_path = Path(current_root.parent.name) / current_root.name

namespace = str(relative_path / "models").replace("/" , ".")

import_models(str(current_root / "models"), namespace)
|
|
|
|
|
|
|
|
# Configure root logging: route every record through a single console
# handler with a detailed format, replacing any pre-installed handlers.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# BUGFIX: the original did `logging.root.handlers[0]`, which raises
# IndexError when no handler was pre-installed (root has none in a fresh
# interpreter; one typically appears only after e.g. fairseq's import-time
# basicConfig). Remove whatever handlers exist instead -- iterating over a
# copy because removeHandler mutates the list.
for default_handler in list(logging.root.handlers):
    logging.root.removeHandler(default_handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
# Optional Ascend NPU support, enabled by setting TOKENIZE_ON_NPU=1 in the
# environment. When unset (None) or any other value, this block is a no-op.
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":  # get() yields None when unset, so the == check alone suffices
    import fairseq_npu_patch
    import torch_npu  # noqa: F401 -- imported for its side effects (NPU backend)

    # NOTE(review): presumably redirects .cuda()/CUDA calls to the NPU
    # device -- confirm against torch_npu docs.
    from torch_npu.contrib import transfer_to_npu  # noqa: F401

    logging.info("Applying Patches for NPU!!!")
    fairseq_npu_patch.patch_for_npu()
|
|
|
|
|
TOKENIZER_CFG_NAME = "hubert_config" |
|
|
|
|
|
def get_unit_sequence(batch_quantized_ids, batch_quantized_ids_length):
    """Turn batched quantized-id tensors into space-joined unit strings.

    For each batch item, only the first `batch_quantized_ids_length[k]`
    ids are valid. Two strings are produced per item: the full unit
    sequence, and a run-length-reduced one where consecutive duplicate
    units collapse to a single occurrence.

    Returns:
        (unit_sequence_list, reduced_unit_sequence_list): two parallel
        lists of space-separated unit strings, one entry per batch item.
    """
    unit_sequence_list = []
    reduced_unit_sequence_list = []
    for idx, valid_len in enumerate(batch_quantized_ids_length):
        units = batch_quantized_ids[idx][:valid_len].cpu().numpy().tolist()
        # Collapse runs: keep a unit only when it differs from the one
        # most recently kept.
        reduced_units = []
        for unit in units:
            if not reduced_units or reduced_units[-1] != unit:
                reduced_units.append(unit)
        unit_sequence_list.append(" ".join(str(u) for u in units))
        reduced_unit_sequence_list.append(" ".join(str(u) for u in reduced_units))
    return unit_sequence_list, reduced_unit_sequence_list
|
|
|
|
|
|
|
|
def collater_audio(audios, audio_size, pad_audio=True):
    """Collate a list of 1-D audio tensors into one (B, audio_size) batch.

    Shorter clips are zero-padded on the right (with the padded tail marked
    True in the mask); clips longer than `audio_size` are cropped from the
    start.

    Args:
        audios: list of 1-D float tensors, possibly of different lengths.
        audio_size: target length of every row in the batch.
        pad_audio: kept for interface compatibility.
            NOTE(review): unused in the body -- padding is always applied.

    Returns:
        collated_audios: (len(audios), audio_size) tensor.
        padding_mask: bool tensor of the same shape, True on padded steps.
        audio_starts: per-clip crop start offsets (always 0 here).
    """
    collated_audios = audios[0].new_zeros(len(audios), audio_size)
    padding_mask = torch.BoolTensor(collated_audios.shape).fill_(False)
    audio_starts = [0 for _ in audios]
    for i, audio in enumerate(audios):
        diff = len(audio) - audio_size
        if diff == 0:
            collated_audios[i] = audio
        elif diff < 0:
            collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
            # Negative `diff` indexes the last -diff positions of the row.
            padding_mask[i, diff:] = True
        else:
            # BUGFIX: clips longer than audio_size previously fell through
            # every branch and stayed all-zero; crop to the target length.
            # (Unreachable for the existing caller, which passes
            # audio_size = max(len(a) for a in audios).)
            collated_audios[i] = audio[:audio_size]
    return collated_audios, padding_mask, audio_starts
|
|
|
|
|
|
|
|
class SpeechTokenizer(object):
    """Wraps a fairseq HuBERT-style checkpoint that quantizes raw speech
    into discrete unit ids ("speech tokens").

    The model is loaded once in __init__ and moved to GPU. Entry points:
    extract() for a batch of waveforms (with optional segmentation) and
    extract_single_segment() for a single waveform.
    """

    def __init__(
        self,
        # Default checkpoint/config live next to this source file.
        ckpt_path: str = str(
            (Path(__file__).parent / "ckpt/model.pt").absolute()
        ),
        cfg_path: str = str((Path(__file__).parent / "config").absolute()),
        cfg_name: str = TOKENIZER_CFG_NAME,
    ):
        """Load the tokenizer checkpoint and task.

        Args:
            ckpt_path: path to the fairseq checkpoint (.pt file).
            cfg_path: directory holding the YAML config; also injected as
                the task's label_dir so dictionary files resolve locally.
            cfg_name: config file name without the .yaml extension.
        """
        w2v_args = OmegaConf.load(f"{cfg_path}/{cfg_name}.yaml")
        OmegaConf.update(w2v_args, "task.label_dir", cfg_path)
        # Deterministic inference-time batching: always pad, never crop.
        overrides = {
            "task": {"pad_audio": True, "random_crop": False},
            "common": {"seed": 1234},
        }

        overrides.update({"model": {"w2v_args": w2v_args}})

        (
            model,
            cfg,
            task,
        ) = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path], arg_overrides=overrides
        )
        # load_model_ensemble_and_task returns a list of models; only one
        # checkpoint was given, so take index 0.
        # NOTE(review): .cuda() hard-requires a GPU (or the NPU shim above);
        # self.use_cuda is never set to False anywhere in this file.
        self.model = model[0].eval().cuda()
        self.task = task

        self.use_cuda = True
        self.use_fp16 = False
        logging.info(f"TASK CONFIG:\n{self.task.cfg}")

    def extract_single_segment(self, raw_wav):
        """Tokenize one waveform segment.

        Args:
            raw_wav: 1-D numpy float waveform (presumably 16 kHz, given the
                caller's librosa.load(sr=16000) -- not validated here).

        Returns:
            (audio_token, unit_sequence, reduced_unit_sequence):
            audio_token is the reduced units rendered as "<|speech_N|>"
            markers; the two sequences are space-separated id strings, the
            reduced one with consecutive duplicates collapsed.
        """
        wav = torch.from_numpy(raw_wav).float()
        with torch.no_grad():
            # Layer-normalize over the entire clip.
            wav = F.layer_norm(wav, wav.shape)
        # Build a single-item batch in the shape collater_audio expects.
        samples = [{"id": 0, "source": wav, "label_list": "None"}]
        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        audio_size = max(audio_sizes)
        collated_audios, padding_mask, audio_starts = collater_audio(audios, audio_size)

        net_input = {"source": collated_audios, "padding_mask": padding_mask}
        sample = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": net_input,
        }

        sample = utils.move_to_cuda(sample) if self.use_cuda else sample

        def apply_half(t):
            # Down-cast only fp32 tensors; leave ids/masks untouched.
            if t.dtype is torch.float32:
                return t.to(dtype=torch.half)
            return t

        if self.use_fp16:
            sample = utils.apply_to_sample(apply_half, sample)
        self.model.set_num_updates(0)

        with torch.no_grad():
            net_output = self.model(**sample["net_input"])

        # NOTE(review): assumes the model's forward returns a dict with
        # "quantized_ids" / "quantized_id_lengths" keys -- defined by the
        # custom models registered via import_models above; confirm there.
        batch_quantized_ids = net_output["quantized_ids"]
        batch_quantized_ids_length = net_output["quantized_id_lengths"]
        unit_sequence_list, reduced_unit_sequence_list = get_unit_sequence(
            batch_quantized_ids, batch_quantized_ids_length
        )

        # Render the reduced units as "<|speech_N|>" special-token text.
        numbers = reduced_unit_sequence_list[0].split()
        audio_token = "".join([f"<|speech_{number}|>" for number in numbers])

        # Batch size is 1, so index 0 is the only entry.
        unit_sequence = unit_sequence_list[0]
        reduced_unit_sequence = reduced_unit_sequence_list[0]
        return audio_token, unit_sequence, reduced_unit_sequence

    def extract(self, raw_wavs_list, speech_tokenizer_segment_len=0):
        """Extraction logic (batched).

        Args:
            raw_wavs_list: list of 1-D numpy waveforms.
            speech_tokenizer_segment_len (int, optional): segment length
                (in samples) used to split each waveform. Defaults to 0,
                meaning no splitting; if greater than 0, any waveform
                longer than this is tokenized segment by segment and the
                per-segment results are concatenated.

        Returns:
            (audio_token_list, info_list): per-input "<|speech_N|>" token
            strings, and per-input dicts with "unit_sequence" and
            "reduced_unit_sequence" keys.
        """
        info_list = []
        audio_token_list = []
        for raw_wav in tqdm(raw_wavs_list):
            wav_len = raw_wav.shape[0]
            if (
                speech_tokenizer_segment_len > 0
                and wav_len > speech_tokenizer_segment_len
            ):
                audio_token = ""
                unit_sequence_list = []
                reduced_unit_sequence_list = []

                num_segments = int(np.ceil(wav_len / speech_tokenizer_segment_len))

                for i in range(num_segments):
                    # The final slice may be shorter; slicing past the
                    # end of the array is safe in numpy.
                    start_sample = int(i * speech_tokenizer_segment_len)
                    end_sample = int((i + 1) * speech_tokenizer_segment_len)
                    segment_wav = raw_wav[start_sample:end_sample]
                    (
                        segment_audio_token,
                        segment_unit_sequence,
                        segment_reduced_unit_sequence,
                    ) = self.extract_single_segment(segment_wav)

                    audio_token += segment_audio_token
                    unit_sequence_list.extend(segment_unit_sequence.split(" "))
                    reduced_unit_sequence_list.extend(
                        segment_reduced_unit_sequence.split(" ")
                    )
                # NOTE(review): duplicates straddling a segment boundary are
                # NOT collapsed across segments in the reduced sequence.
                unit_sequence = " ".join(unit_sequence_list)
                reduced_unit_sequence = " ".join(reduced_unit_sequence_list)
            else:
                audio_token, unit_sequence, reduced_unit_sequence = (
                    self.extract_single_segment(raw_wav)
                )

            audio_token_list.append(audio_token)

            info_list.append(
                {
                    "unit_sequence": unit_sequence,
                    "reduced_unit_sequence": reduced_unit_sequence,
                }
            )

        return audio_token_list, info_list
|
|
|
|
|
|
|
|
def main(args):
    """Tokenize every wav listed in args.input_list and write the full
    unit sequence of each (one line per input file) to args.output_file.

    Args:
        args: parsed CLI namespace with ckpt, cfg_path, cfg_name,
            input_list and output_file attributes.
    """
    if args.ckpt is not None:
        tokenizer = SpeechTokenizer(
            ckpt_path=args.ckpt, cfg_path=args.cfg_path, cfg_name=args.cfg_name
        )
    else:
        # Fall back to the default checkpoint/config bundled next to this file.
        tokenizer = SpeechTokenizer()

    # Each input line is "wav_path|..." -- only the path field is used.
    wav_file_list = []
    with open(args.input_list, "r") as input_file:
        for line in input_file:
            wav_file_list.append(line.strip().split("|")[0])

    # Decode all audio up-front, resampled to the model's 16 kHz rate.
    # (The returned sample rate is always 16000 here, so it is discarded.)
    raw_wavs_list = []
    for file_path in wav_file_list:
        raw_wav, _ = librosa.load(file_path, sr=16000)
        raw_wavs_list.append(raw_wav)

    token_list, token_info_list = tokenizer.extract(raw_wavs_list)
    with open(args.output_file, "w") as output_file:
        for token_info in token_info_list:
            logging.info(token_info["unit_sequence"])
            output_file.write(token_info["unit_sequence"] + "\n")
    # FIX: dropped the redundant output_file.close() that sat inside the
    # `with` block -- the context manager already closes the file.
    logging.info("Finished")
    return
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI specification: (flag, dest, required, default, help). Registering
    # from this table is behaviorally identical to declaring each argument
    # individually (argparse's implicit default is None).
    _cli_options = (
        ("--ckpt", "ckpt", False, None, "path to ckpt"),
        ("--cfg-path", "cfg_path", False, "config", "path to config"),
        ("--cfg-name", "cfg_name", False, "hubert_config", "name of config"),
        ("--input-list", "input_list", True, None, "list of input wavform"),
        ("--output-file", "output_file", True, None, "file to output speech tokens"),
    )

    parser = argparse.ArgumentParser()
    for flag, dest, required, default, help_text in _cli_options:
        parser.add_argument(
            flag, dest=dest, required=required, default=default, help=help_text
        )

    main(parser.parse_args())
|
|
|