neuTTS-JP-150m / tokenization_llm_jp_tts.py
tsukemono's picture
Upload model_causal files and README
aeced03 verified
import re
import pyopenjtalk
from transformers import PreTrainedTokenizerFast
def katakana_to_hiragana(text: str) -> str:
"""
カタカナをひらがなに変換
Args:
text: カタカナテキスト
Returns:
ひらがなテキスト
"""
result = []
for char in text:
# カタカナ(ァ-ヶ)をひらがな(ぁ-ゖ)に変換
if "ァ" <= char <= "ヶ":
# カタカナの開始コードポイント: 0x30A1
# ひらがなの開始コードポイント: 0x3041
# 差分: 0x60
result.append(chr(ord(char) - 0x60))
else:
result.append(char)
return "".join(result)
def add_ruby_single(text: str, file_path: str = None) -> str:
"""
単一のテキストにルビを振る(pyopenjtalk-plus使用)
Args:
text: 入力テキスト
file_path: ファイルパス(エラー時のログ用)
Returns:
ルビが振られたテキスト(形式: 漢字[よみ])
Raises:
RuntimeError: テキストが長すぎる場合(512バイト超過)
"""
# 事前にバイト数をチェック(pyopenjtalkの制限: 512バイト)
text_bytes = len(text.encode("utf-8"))
if text_bytes > 512:
# エラー情報を出力
if file_path:
print("\nERROR: pyopenjtalk入力長制限エラー(事前チェック)")
print(f" ファイル: {file_path}")
print(f" テキスト長: {len(text)} 文字, {text_bytes} bytes (max: 512 bytes)")
print(f" テキスト内容: {text[:100]}...")
raise RuntimeError(f"Input too long: {text_bytes} bytes (max 512 bytes)")
try:
# pyopenjtalk-plusで形態素解析
features = pyopenjtalk.run_frontend(text)
# 空の結果が返された場合もエラー扱い(処理失敗の可能性)
if not features:
if text.strip(): # 空白だけのテキストは除く
print("\nWARNING: pyopenjtalkが空の結果を返しました")
if file_path:
print(f" ファイル: {file_path}")
print(f" テキスト: {text[:100]}...")
result = []
for feature in features:
surface = feature["string"] # 表層形(元のテキスト)
reading = feature["read"] # 読み(カタカナ)
# カタカナをひらがなに変換
reading_hira = katakana_to_hiragana(reading)
# 表層形と読みが異なる場合のみルビを振る
# ひらがな・カタカナはそのまま(ルビ不要)
if surface != reading_hira and not all(
c in "ぁ-ん" or c in "ァ-ヶー" for c in surface
):
result.append(f"{surface}[{reading_hira}]")
else:
result.append(surface)
return "".join(result)
except Exception as e:
error_msg = str(e)
# 既知のエラーメッセージをチェック
if "Input too long" in error_msg or "max 512" in error_msg:
if file_path:
print("\nERROR: pyopenjtalk入力長制限エラー(実行時)")
print(f" ファイル: {file_path}")
print(f" テキスト長: {len(text)} 文字, {text_bytes} bytes")
print(f" テキスト内容: {text[:100]}...")
raise RuntimeError(f"Input too long: {text_bytes} bytes") from e
else:
# その他のエラー
raise
def add_ruby(text: str, file_path: str = None) -> str:
"""
テキストにルビを振る(pyopenjtalk-plus使用、長文対応)
Args:
text: 入力テキスト
file_path: ファイルパス(エラー時のログ用)
Returns:
ルビが振られたテキスト(形式: 漢字[よみ])
"""
# 1. まず全体を処理してみる
try:
return add_ruby_single(text, file_path)
except (RuntimeError, Exception) as e:
# RuntimeErrorまたは"Input too long"を含むエラーの場合のみ分割処理
if (
not isinstance(e, RuntimeError)
and "Input too long" not in str(e)
and "max 512" not in str(e)
):
# 長さ以外のエラーは再スロー
raise
# 2. 「。」と「?」と「!」で分割して処理
sentences = re.split(r"(。|?|!)", text)
result_parts = []
for sentence in sentences:
if not sentence:
continue
try:
result_parts.append(add_ruby_single(sentence, file_path))
except (RuntimeError, Exception) as e:
if (
not isinstance(e, RuntimeError)
and "Input too long" not in str(e)
and "max 512" not in str(e)
):
raise
# 3. 「、」でさらに分割
sub_sentences = re.split(r"(、)", sentence)
for sub_sentence in sub_sentences:
if not sub_sentence:
continue
try:
result_parts.append(add_ruby_single(sub_sentence, file_path))
except (RuntimeError, Exception) as e:
if (
not isinstance(e, RuntimeError)
and "Input too long" not in str(e)
and "max 512" not in str(e)
):
raise
# 4. 空白でさらに分割
words = re.split(r"(\s+)", sub_sentence)
for word in words:
if not word:
continue
try:
result_parts.append(add_ruby_single(word, file_path))
except (RuntimeError, Exception) as e:
if (
not isinstance(e, RuntimeError)
and "Input too long" not in str(e)
and "max 512" not in str(e)
):
raise
# 5. 強制的に文字数で分割(バイト数ベース)
print("\nWARNING: 強制分割を実行します(句読点・空白なし)")
if file_path:
print(f" ファイル: {file_path}")
print(
f" テキスト長: {len(word)} 文字, {len(word.encode('utf-8'))} bytes"
)
# 最大バイト数(安全のため少し余裕を持たせる)
max_bytes = 400
current_chunk = ""
current_bytes = 0
for char in word:
char_bytes = len(char.encode("utf-8"))
# 次の文字を追加すると制限を超える場合
if current_bytes + char_bytes > max_bytes:
# 現在のチャンクを処理
if current_chunk:
try:
result_parts.append(
add_ruby_single(
current_chunk, file_path
)
)
except Exception as chunk_e:
# それでもエラーの場合は元のテキストをそのまま使用
print(
f" WARNING: チャンク処理も失敗、元のテキストを使用: {chunk_e}"
)
result_parts.append(current_chunk)
# 新しいチャンクを開始
current_chunk = char
current_bytes = char_bytes
else:
current_chunk += char
current_bytes += char_bytes
# 最後のチャンクを処理
if current_chunk:
try:
result_parts.append(
add_ruby_single(current_chunk, file_path)
)
except Exception as chunk_e:
print(
f" WARNING: 最後のチャンク処理も失敗、元のテキストを使用: {chunk_e}"
)
result_parts.append(current_chunk)
print(f" 強制分割完了: {len(word)} 文字を処理")
return "".join(result_parts)
class LlmJpTtsTokenizer(PreTrainedTokenizerFast):
def _apply_ruby_to_text(self, text, *, is_split_into_words: bool):
if text is None or is_split_into_words:
return text
if isinstance(text, str):
return add_ruby(text)
if isinstance(text, (list, tuple)):
if not text:
return text
if all(isinstance(item, str) for item in text):
return [add_ruby(item) for item in text]
if all(isinstance(item, (list, tuple)) and len(item) == 2 for item in text):
processed = []
for first, second in text:
first_text = add_ruby(first) if isinstance(first, str) else first
second_text = (
add_ruby(second) if isinstance(second, str) else second
)
processed.append((first_text, second_text))
return processed
return text
def __call__(
self,
text=None,
text_pair=None,
text_target=None,
text_pair_target=None,
add_special_tokens=True,
padding=False,
truncation=None,
max_length=None,
stride=0,
is_split_into_words=False,
pad_to_multiple_of=None,
padding_side=None,
return_tensors=None,
return_token_type_ids=None,
return_attention_mask=None,
return_overflowing_tokens=False,
return_special_tokens_mask=False,
return_offsets_mapping=False,
return_length=False,
verbose=True,
tokenizer_kwargs=None,
**kwargs,
):
text = self._apply_ruby_to_text(text, is_split_into_words=is_split_into_words)
text_pair = self._apply_ruby_to_text(
text_pair, is_split_into_words=is_split_into_words
)
text_target = self._apply_ruby_to_text(
text_target, is_split_into_words=is_split_into_words
)
text_pair_target = self._apply_ruby_to_text(
text_pair_target, is_split_into_words=is_split_into_words
)
return super().__call__(
text=text,
text_pair=text_pair,
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
padding_side=padding_side,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
tokenizer_kwargs=tokenizer_kwargs,
**kwargs,
)
def encode(
self,
text,
text_pair=None,
add_special_tokens=True,
padding=False,
truncation=None,
max_length=None,
stride=0,
padding_side=None,
return_tensors=None,
**kwargs,
):
is_split_into_words = bool(kwargs.get("is_split_into_words", False))
text = self._apply_ruby_to_text(text, is_split_into_words=is_split_into_words)
text_pair = self._apply_ruby_to_text(
text_pair, is_split_into_words=is_split_into_words
)
return super().encode(
text,
text_pair=text_pair,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
padding_side=padding_side,
return_tensors=return_tensors,
**kwargs,
)