# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Yusen Sun,
# Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text -> phone -> speech-token generation with a fairseq model.

Input text is converted to a phone sequence with ``MultilingualG2P`` and
then translated into target tokens by a fairseq sequence generator. Long
inputs can optionally be split into segments at the phones listed in
``PHONE_SPLITTER`` and translated segment by segment.
"""

import ast
import fileinput
import logging
import os
import sys
import time
import argparse
from collections import namedtuple
from tqdm import tqdm
from pathlib import Path

import numpy as np
import torch
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
from fairseq.models import import_models

# Phone tokens at which a long phone sequence may be cut into segments
# (see Text2TokenGenerator.split_phone_segments).
PHONE_SPLITTER = {"[SIL]", "[CM]", "[PD]", "[QN]", "[EX]"}

current_root = Path(__file__).absolute().parent
sys.path.append(str(current_root))
sys.path.append(str(current_root.parent / "thirdparty/G2P"))
from G2P_processors import MultilingualG2P

# Register the model definitions in ./models under the dotted name
# "<this directory's name>.models". Join the path parts with "." instead of
# replacing "/" so the name is also correct where the separator is not "/".
relative_path = Path(current_root.name)
namespace = ".".join((*relative_path.parts, "models"))
import_models(str(current_root / "models"), namespace)

# Optional NPU support, enabled with the environment variable TOKENIZE_ON_NPU=1.
TOKENIZE_ON_NPU = os.environ.get("TOKENIZE_ON_NPU")
if TOKENIZE_ON_NPU == "1":
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    logging.info("Applying Patches for NPU!!!")

console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)
# Replace every existing root handler with console_handler. Iterate over a
# copy: removing items from the live list while iterating it would skip
# every other handler.
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.root.addHandler(console_handler)
logging.root.setLevel(logging.INFO)

Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")

# Default fairseq interactive-generation arguments. Extra arguments passed
# to Text2TokenGenerator are appended after these.
DEFAULT_T2U_ARGS = [
    str(current_root) + "/data_bin",
    "--path",
    str(current_root) + "/ckpt/40ms.checkpoint15.pt",
    "--batch-size",
    "1",
    "--buffer-size",
    "2",
    "--beam",
    "5",
    "--max-len-b",
    "1024",
    "--source-lang",
    "ph",
    "--target-lang",
    "tgt.unit",
]


def dummy_encode_fn(x):
    """Identity encoder: return ``x`` unchanged (no tokenizer/BPE applied)."""
    return x


class Text2TokenGenerator:
    """Generate target tokens from text using G2P plus a fairseq generator."""

    def __init__(self, args=None) -> None:
        """Build the fairseq config, task, models, generator and G2P frontend.

        Args:
            args: optional extra fairseq command-line arguments, appended
                after ``DEFAULT_T2U_ARGS``.
        """
        self._initialize(args)

    def _initialize(self, args):
        """Parse arguments and load every component used for generation.

        Args:
            args: optional iterable of extra command-line argument strings,
                appended after ``DEFAULT_T2U_ARGS``. ``None`` or empty means
                "defaults only".
        """
        # Copy so the module-level default list is never aliased or mutated.
        t2u_args = list(DEFAULT_T2U_ARGS)
        if args:
            t2u_args.extend(args)
        parser = options.get_interactive_generation_parser()
        t2u_fairseq_args = options.parse_args_and_arch(
            parser=parser, input_args=t2u_args
        )
        cfg: FairseqConfig = convert_namespace_to_omegaconf(t2u_fairseq_args)

        utils.import_user_module(cfg.common)

        if cfg.interactive.buffer_size < 1:
            cfg.interactive.buffer_size = 1
        if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
            cfg.dataset.batch_size = 1

        assert (
            not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not cfg.dataset.batch_size
            or cfg.dataset.batch_size <= cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        self.cfg = cfg
        logging.info(self.cfg)

        # Fix seed for stochastic decoding
        if (
            self.cfg.common.seed is not None
            and not self.cfg.generation.no_seed_provided
        ):
            np.random.seed(self.cfg.common.seed)
            utils.set_torch_seed(self.cfg.common.seed)

        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.cfg.task)

        # Load ensemble
        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        logging.info("loading model(s) from {}".format(self.cfg.common_eval.path))
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(cfg)

        # Initialize generator
        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # Handle tokenization and BPE
        self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
        self.bpe = self.task.build_bpe(cfg.bpe)

        self.align_dict = None
        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

        # init G2P
        # "zh": the model treats all non-English text as Chinese;
        # "en": the model treats all languages as English.
        self.language = "zh"
        # Backend name passed to MultilingualG2P; the original note lists
        # 'baidu' or 'wenet' as options.
        self.mG2P = MultilingualG2P(
            "wenet", remove_interjections=False, remove_erhua=False
        )

    def text2phone(self, text):
        """Normalize ``text`` and convert it to a space-separated phone string.

        Uses ``self.language`` with ``with_lang_prefix=True`` and
        ``normalize_punct=True``; the normalized text is discarded.
        """
        phones, _norm_text = self.mG2P.text_normalization_and_g2p(
            text, self.language, with_lang_prefix=True, normalize_punct=True
        )
        return " ".join(phones)

    def buffered_read(self, input, buffer_size):
        """Yield lists of at most ``buffer_size`` phone strings read from ``input``.

        Args:
            input: a file name accepted by ``fileinput`` ("-" reads stdin),
                decoded as UTF-8.
            buffer_size: maximum number of lines per yielded list.

        Each line is stripped and converted with :meth:`text2phone`. A final,
        shorter list is yielded for any leftover lines.
        """
        buffer = []
        with fileinput.input(
            files=[input], openhook=fileinput.hook_encoded("utf-8")
        ) as h:
            for src_str in h:
                buffer.append(self.text2phone(src_str.strip()))
                if len(buffer) >= buffer_size:
                    yield buffer
                    buffer = []

        if buffer:
            yield buffer

    def make_batches(self, lines, encode_fn):
        """Encode ``lines`` with the source dictionary and yield ``Batch`` tuples.

        Args:
            lines: list of source strings. When ``--constraints`` is enabled,
                tab-separated constraints are stripped off and the entries of
                ``lines`` are overwritten in place with the bare source text.
            encode_fn: callable applied to each line (and each constraint)
                before dictionary encoding.

        Yields:
            Batch(ids, src_tokens, src_lengths, constraints). ``ids`` index
            into ``lines``; batches come from a non-shuffled iterator, but
            callers still sort results by id to restore input order.
        """

        def encode_fn_target(x):
            return encode_fn(x)

        if self.cfg.generation.constraints:
            # Strip (tab-delimited) constraints, if present, from input lines,
            # store them in batch_constraints
            batch_constraints = [list() for _ in lines]
            for i, line in enumerate(lines):
                if "\t" in line:
                    lines[i], *batch_constraints[i] = line.split("\t")

            # Convert each List[str] to List[Tensor]
            for i, constraint_list in enumerate(batch_constraints):
                batch_constraints[i] = [
                    self.task.target_dictionary.encode_line(
                        encode_fn_target(constraint),
                        append_eos=False,
                        add_if_not_exist=False,
                    )
                    for constraint in constraint_list
                ]

        if self.cfg.generation.constraints:
            constraints_tensor = pack_constraints(batch_constraints)
        else:
            constraints_tensor = None

        tokens, lengths = self.task.get_interactive_tokens_and_lengths(lines, encode_fn)

        itr = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(
                tokens, lengths, constraints=constraints_tensor
            ),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
        ).next_epoch_itr(shuffle=False)
        for batch in itr:
            yield Batch(
                ids=batch["id"],
                src_tokens=batch["net_input"]["src_tokens"],
                src_lengths=batch["net_input"]["src_lengths"],
                constraints=batch.get("constraints", None),
            )

    def generate_for_text_file_input(self, input):
        """Translate every line of the text file ``input`` without segmentation.

        Returns:
            list of dicts in input order, one per line, each with
            ``"src_tokens"`` (list of source symbols; empty if there is no
            source dictionary) and ``"hypotheses"`` (up to ``--nbest`` dicts
            with ``"hypo_tokens"`` and ``"alignment"``).
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []

        start_id = 0
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            results = []
            for batch in self.make_batches(inputs, dummy_encode_fn):
                bsz = batch.src_tokens.size(0)
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                constraints = batch.constraints
                if self.use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()
                    if constraints is not None:
                        constraints = constraints.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translate_start_time = time.time()
                translations = self.task.inference_step(
                    self.generator, self.models, sample, constraints=constraints
                )
                translate_time = time.time() - translate_start_time
                total_translate_time += translate_time
                list_constraints = [[] for _ in range(bsz)]
                if self.cfg.generation.constraints:
                    list_constraints = [unpack_constraints(c) for c in constraints]
                for i, (sample_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)
                ):
                    src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                    results.append(
                        (
                            start_id + sample_id,
                            src_tokens_i,
                            hypos,
                            {
                                "constraints": list_constraints[i],
                                "time": translate_time / len(translations),
                            },
                        )
                    )

            # sort output to match input order
            for _id, src_tokens, hypos, _info in sorted(results, key=lambda x: x[0]):
                output = {"src_tokens": []}
                if self.src_dict is not None:
                    src_str = self.src_dict.string(
                        src_tokens, self.cfg.common_eval.post_process
                    )
                    output["src_tokens"] = src_str.split()

                # Process top predictions
                output["hypotheses"] = []
                for hypo in hypos[: min(len(hypos), self.cfg.generation.nbest)]:
                    hypo_str = self.tgt_dict.string(
                        hypo["tokens"].int().cpu(),
                        self.cfg.common_eval.post_process,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                            self.generator
                        ),
                    )
                    output["hypotheses"].append(
                        {
                            "hypo_tokens": hypo_str.split(),
                            "alignment": hypo["alignment"],
                        }
                    )
                hypo_outputs.append(output)

            # update running id_ counter
            start_id += len(inputs)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs

    def split_phone_segments(self, phones, max_segment_len=0):
        """Split a space-separated phone string at phones in ``PHONE_SPLITTER``.

        A segment is closed at a splitter phone once it holds at least
        ``max_segment_len`` phones; the final segment always runs to the end
        of the input. With ``max_segment_len <= 1`` every splitter phone
        closes a segment.

        Args:
            phones: space-separated phone string.
            max_segment_len: minimum number of phones before a segment may be
                closed at a splitter phone.

        Returns:
            list of str: segments whose space-join equals ``phones``; exactly
            ``[phones]`` when no splitter phone is present.

        Raises:
            RuntimeError: if the segments do not reproduce ``phones`` (e.g.
                ``phones`` has repeated, leading or trailing whitespace).
        """
        phone_splits = phones.split()
        seps = [idx for idx, ph in enumerate(phone_splits) if ph in PHONE_SPLITTER]
        if not seps:
            return [phones]
        # Make sure the tail after the last splitter forms a segment too.
        last_idx = len(phone_splits) - 1
        if seps[-1] < last_idx:
            seps.append(last_idx)

        phone_segments = []
        segment_start = 0
        for idx, sep in enumerate(seps):
            seglen = sep - segment_start + 1
            if seglen >= max_segment_len or idx == len(seps) - 1:
                phone_segments.append(" ".join(phone_splits[segment_start : sep + 1]))
                segment_start = sep + 1

        if " ".join(phone_segments) != phones:
            # Previously this called the interactive-only exit(), which
            # terminated the whole process with exit status 0 on an error.
            logging.error("ERROR!!!!! segments shorter than phones")
            raise RuntimeError(
                "phone segments do not reproduce the input phones: {!r}".format(phones)
            )
        return phone_segments

    def generate_for_long_input_text(self, input_phones, max_segment_len=0):
        """Translate phone strings segment by segment and re-join the tokens.

        Args:
            input_phones: list of space-separated phone strings.
            max_segment_len: forwarded to :meth:`split_phone_segments`.

        Returns:
            (hypo_tokens, total_translate_time): ``hypo_tokens`` has one list
            per input, holding the top hypothesis tokens of its segments
            concatenated in order; ``total_translate_time`` is in seconds.
        """
        total_translate_time = 0
        input_segments = []
        segment_lens = []
        for phones in input_phones:
            segments = self.split_phone_segments(phones, max_segment_len)
            segment_lens.append(len(segments))
            input_segments.extend(segments)
        logging.info(
            f"Spliting {len(input_phones)} inputs into {len(input_segments)} segments"
        )

        results = []
        start_id = 0
        for batch in self.make_batches(input_segments, dummy_encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            logging.info(f"processing batch: {bsz}")
            translate_start_time = time.time()
            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            for sample_id, hypos in zip(batch.ids.tolist(), translations):
                results.append((start_id + sample_id, hypos))

        # Restore segment order, then regroup segments by their source input.
        sorted_results = sorted(results, key=lambda x: x[0])
        segment_results = []
        start_pos = 0
        for sl in segment_lens:
            segment_results.append(sorted_results[start_pos : start_pos + sl])
            start_pos += sl
        assert len(input_phones) == len(segment_results)

        hypo_tokens = []
        for seg_res in segment_results:
            token_res = []
            for _id, hypos in seg_res:
                # Process top predictions
                hypo = hypos[0]
                hypo_str = self.tgt_dict.string(
                    hypo["tokens"].int().cpu(),
                    self.cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(
                        self.generator
                    ),
                )
                token_res.extend(hypo_str.split())
            hypo_tokens.append(token_res)
        return hypo_tokens, total_translate_time

    def generate_for_long_text_input_file(self, input, max_segment_len=0):
        """Translate every line of ``input`` using segmented generation.

        Returns:
            list with one token list per input line, in input order.
        """
        start_time = time.time()
        total_translate_time = 0
        hypo_outputs = []
        for inputs in self.buffered_read(input, self.cfg.interactive.buffer_size):
            logging.info(f"processing inputs: {len(inputs)}")
            hypo_tokens, translate_time = self.generate_for_long_input_text(
                inputs, max_segment_len=max_segment_len
            )
            total_translate_time += translate_time
            hypo_outputs.extend(hypo_tokens)

        logging.info(
            "Total time: {:.3f} seconds; translation time: {:.3f}".format(
                time.time() - start_time, total_translate_time
            )
        )
        return hypo_outputs


def infer(unk_args, output_file, max_seg_len):
    """Run generation on ``--input`` and write one token line per input line.

    Args:
        unk_args: extra fairseq arguments forwarded to Text2TokenGenerator.
        output_file: output path (written as UTF-8), or None for stdout.
        max_seg_len: if > 0, split long inputs into segments of at least this
            many phones; otherwise translate each line whole.
    """
    # Load the models before opening the output so a loading failure does
    # not leave a truncated output file behind.
    t2u = Text2TokenGenerator(unk_args)
    if output_file is None:
        output_fp = sys.stdout
    else:
        output_fp = open(output_file, "w", encoding="utf-8")
    try:
        if max_seg_len <= 0:
            speech_tokens_info = t2u.generate_for_text_file_input(
                t2u.cfg.interactive.input
            )
            for infor in speech_tokens_info:
                output_fp.write(" ".join(infor["hypotheses"][0]["hypo_tokens"]) + "\n")
        else:
            speech_tokens_info = t2u.generate_for_long_text_input_file(
                t2u.cfg.interactive.input, max_segment_len=max_seg_len
            )
            for infor in speech_tokens_info:
                output_fp.write(" ".join(infor) + "\n")
        output_fp.flush()
    finally:
        # Only close a file this function opened; never close sys.stdout.
        if output_fp is not sys.stdout:
            output_fp.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output",
        dest="output",
        required=False,
        default=None,
        help="output file",
    )
    parser.add_argument(
        "--max-seg-len",
        dest="max_seg_len",
        required=False,
        default=0,
        type=int,
        help="max segment length",
    )
    args, unknown_args = parser.parse_known_args()
    infer(unknown_args, args.output, args.max_seg_len)