Hak5's picture
Add bundled AVoice runtime for HF-only inference
7496177 verified
#!/usr/bin/env python3
# Copyright 2026 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text processing utilities for TTS inference.
Provides:
- ``chunk_text_punctuation()``: Splits long text into model-friendly chunks at
sentence boundaries, with abbreviation-aware punctuation splitting.
- ``add_punctuation()``: Appends missing end punctuation (Chinese or English).
"""
from typing import List, Optional
# Punctuation tables used by the text utilities in this module.
#
# NOTE: full-width and typographic characters are written as ``\u`` escapes on
# purpose. Text normalizers that fold full-width forms to ASCII or straighten
# curly quotes would otherwise silently corrupt these tables: the CJK entries
# collapse into duplicates of the ASCII ones, and a straightened curly quote
# inside a double-quoted literal becomes a triple quote that swallows every
# entry up to the next one.

# Characters after which a sentence may be split.
SPLIT_PUNCTUATION = set(
    ".,;:!?"  # ASCII
    "\u3002"  # ideographic full stop
    "\uff0c\uff1b\uff1a\uff01\uff1f"  # full-width , ; : ! ?
)

# Closing quotes/brackets that stay attached to the sentence they close when
# they directly follow a split point.
CLOSING_MARKS = set(
    "\"')]>"  # ASCII quotes and closing brackets
    "\u201d\u2019"  # right double / right single curly quote
    "\uff09"  # full-width right parenthesis
    "\u300b\u300d\u3011"  # right double angle / corner / lenticular bracket
)

# Strings that count as "text already ends with punctuation".
END_PUNCTUATION = {
    # ASCII and Latin typography.
    ";",
    ":",
    ",",
    ".",
    "!",
    "?",
    "\u2026",  # horizontal ellipsis
    ")",
    "]",
    "}",
    '"',
    "'",
    "\u201d",  # right double curly quote
    "\u2019",  # right single curly quote
    # CJK and full-width forms.
    "\uff1b",  # full-width semicolon
    "\uff1a",  # full-width colon
    "\uff0c",  # full-width comma
    "\u3002",  # ideographic full stop
    "\uff01",  # full-width exclamation mark
    "\uff1f",  # full-width question mark
    "\u3001",  # ideographic comma
    "\u2026\u2026",  # doubled ellipsis, as typed in Chinese text
    "\uff09",  # full-width right parenthesis
    "\u3011",  # right black lenticular bracket
}
# Period-terminated tokens that are compared against the last
# whitespace-delimited word before a "." so that the period is not treated as
# a sentence boundary. Matching is case-sensitive.
ABBREVIATIONS = set(
    (
        # Personal and professional titles.
        "Mr. Mrs. Ms. Dr. Prof. Sr. Jr. Rev. Fr. Hon. "
        # Political and military ranks.
        "Pres. Gov. Capt. Gen. Sen. Rep. Col. Maj. Lt. Cmdr. Sgt. Cpl. "
        # Organizations.
        "Co. Corp. Inc. Ltd. Est. Dept. "
        # Addresses and places.
        "St. Ave. Blvd. Rd. Mt. Ft. "
        # Numbering.
        "No. "
        # Months.
        "Jan. Feb. Mar. Apr. Aug. Sep. Sept. Oct. Nov. Dec. "
        # Latin and miscellaneous.
        "i.e. e.g. vs. Vs. Etc. approx. fig. def."
    ).split()
)
def _is_abbreviation_period(word: str, upcoming: str) -> bool:
    """Return True if the period ending *word* belongs to a known abbreviation.

    Args:
        word: The whitespace-delimited run of characters that ends with the
            period under test (e.g. ``"Mr."`` or ``"e."``).
        upcoming: The characters that directly follow that period, up to the
            next whitespace (may be empty).

    Besides an exact match such as ``"Mr."``, this also recognizes a period
    *inside* a multi-period abbreviation: ``"e."`` followed by ``"g."`` is the
    first period of ``"e.g."``.
    """
    if word in ABBREVIATIONS:
        return True
    candidate = word + upcoming
    return any(
        abbr.startswith(word) and candidate.startswith(abbr)
        for abbr in ABBREVIATIONS
    )


def chunk_text_punctuation(
    text: str,
    chunk_len: int,
    min_chunk_len: Optional[int] = None,
) -> List[str]:
    """Split *text* into chunks at punctuation marks.

    The text is processed character by character in three steps:

    1. Cut it into sentences after every character in ``SPLIT_PUNCTUATION``,
       except for a ``"."`` that belongs to an entry of ``ABBREVIATIONS``
       (e.g. "Mr.", "No.", "e.g."). Punctuation or closing marks that
       directly follow a cut stay attached to the preceding sentence.
    2. Greedily concatenate consecutive sentences while the running length
       stays within ``chunk_len`` characters.
    3. If ``min_chunk_len`` is given, fold chunks shorter than it into a
       neighbouring chunk.

    Args:
        text: The input string.
        chunk_len: Maximum number of characters targeted per chunk in step 2.
            A single sentence longer than this is kept whole, and step 3 may
            also push a chunk past this length.
        min_chunk_len: Optional minimum chunk length in characters. A short
            chunk is appended to the previous chunk; a short first chunk
            absorbs the second chunk instead.

    Returns:
        The chunks as whitespace-stripped strings, in order. Chunks that are
        empty after stripping are dropped.
    """
    # Bounds how far around a period we need to look to identify an
    # abbreviation, so the check costs the same regardless of sentence length.
    longest_abbr = max(map(len, ABBREVIATIONS), default=0)

    # 1. Split the characters into sentences at punctuation marks.
    sentences: List[List[str]] = []
    current_sentence: List[str] = []
    for index, token in enumerate(text):
        # Punctuation or a closing mark right after a split point belongs to
        # the sentence that was just closed (e.g. the quote in: He said "no.").
        if (
            not current_sentence
            and sentences
            and (token in SPLIT_PUNCTUATION or token in CLOSING_MARKS)
        ):
            sentences[-1].append(token)
            continue

        current_sentence.append(token)
        if token not in SPLIT_PUNCTUATION:
            continue

        if token == ".":
            # The word ending at this period. One character more than the
            # longest abbreviation is enough: anything longer cannot match.
            recent = "".join(current_sentence[-(longest_abbr + 1):])
            last_word = recent.split()[-1]
            # Characters glued to the period, needed to recognize the first
            # period of multi-period abbreviations such as "e.g.".
            lookahead = text[index + 1 : index + 1 + longest_abbr]
            if lookahead[:1].strip():
                upcoming = lookahead.split(maxsplit=1)[0]
            else:
                upcoming = ""
            if _is_abbreviation_period(last_word, upcoming):
                continue

        sentences.append(current_sentence)
        current_sentence = []

    # Whatever is left without closing punctuation is the final sentence.
    if current_sentence:
        sentences.append(current_sentence)

    # 2. Merge consecutive sentences while they fit into ``chunk_len``.
    merged_chunks: List[List[str]] = []
    current_chunk: List[str] = []
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_len:
            current_chunk.extend(sentence)
            continue
        if current_chunk:
            merged_chunks.append(current_chunk)
        # Copy so that later ``extend`` calls never mutate ``sentences``.
        current_chunk = list(sentence)
    if current_chunk:
        merged_chunks.append(current_chunk)

    # 3. Fold undersized chunks into the previous chunk, or, for an
    #    undersized first chunk, fold the second chunk into it.
    if min_chunk_len is None:
        final_chunks = merged_chunks
    else:
        first_is_short = (
            bool(merged_chunks) and len(merged_chunks[0]) < min_chunk_len
        )
        final_chunks = []
        for position, chunk in enumerate(merged_chunks):
            absorb = bool(final_chunks) and (
                (position == 1 and first_is_short)
                or len(chunk) < min_chunk_len
            )
            if absorb:
                final_chunks[-1].extend(chunk)
            else:
                final_chunks.append(chunk)

    stripped_chunks = ("".join(chunk).strip() for chunk in final_chunks)
    return [chunk for chunk in stripped_chunks if chunk]
def add_punctuation(text: str):
    """Return *text* stripped, with end punctuation appended if it has none.

    Empty (or whitespace-only) input is returned as an empty string. Text
    whose last character is already in ``END_PUNCTUATION`` is returned
    unchanged apart from stripping. Otherwise an ideographic full stop is
    appended when the text contains at least one character in the range
    U+4E00..U+9FFF, and an ASCII period is appended in every other case.
    """
    stripped = text.strip()
    if not stripped or stripped[-1] in END_PUNCTUATION:
        return stripped

    # Any CJK unified ideograph anywhere in the text selects the Chinese stop.
    for char in stripped:
        if "\u4e00" <= char <= "\u9fff":
            return stripped + "。"
    return stripped + "."