Spaces:
Build error
Build error
| import json | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from tempfile import TemporaryDirectory | |
| from typing import Dict | |
| from unittest import TestCase | |
| from fastapi import HTTPException | |
| from pyopenjtalk import g2p, unset_user_dict | |
| from voicevox_engine.model import UserDictWord, WordTypes | |
| from voicevox_engine.part_of_speech_data import MAX_PRIORITY, part_of_speech_data | |
| from voicevox_engine.user_dict import ( | |
| apply_word, | |
| create_word, | |
| delete_word, | |
| import_user_dict, | |
| read_dict, | |
| rewrite_word, | |
| update_dict, | |
| ) | |
# Dictionary data in the correct on-disk format, as saved to the JSON file.
valid_dict_dict_json = {
    "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e": {
        "surface": "test",
        "cost": part_of_speech_data[WordTypes.PROPER_NOUN].cost_candidates[5],
        "part_of_speech": "名詞",
        "part_of_speech_detail_1": "固有名詞",
        "part_of_speech_detail_2": "一般",
        "part_of_speech_detail_3": "*",
        "inflectional_type": "*",
        "inflectional_form": "*",
        "stem": "*",
        "yomi": "テスト",
        "pronunciation": "テスト",
        "accent_type": 1,
        "accent_associative_rule": "*",
    },
}

# The same dictionary data in the format exchanged through the API: every
# entry drops the internal "cost" key and gains a "priority" key (appended
# last). Priority 5 presumably corresponds to cost_candidates[5] used above.
valid_dict_dict_api = {
    word_uuid: {
        **{field: value for field, value in entry.items() if field != "cost"},
        "priority": 5,
    }
    for word_uuid, entry in valid_dict_dict_json.items()
}
# A well-formed word, distinct from the one in valid_dict_dict_json, used as
# the payload of the import tests.
import_word = UserDictWord(
    **{
        "surface": "test2",
        "priority": 5,
        "part_of_speech": "名詞",
        "part_of_speech_detail_1": "固有名詞",
        "part_of_speech_detail_2": "一般",
        "part_of_speech_detail_3": "*",
        "inflectional_type": "*",
        "inflectional_form": "*",
        "stem": "*",
        "yomi": "テストツー",
        "pronunciation": "テストツー",
        "accent_type": 1,
        "accent_associative_rule": "*",
    }
)
def get_new_word(user_dict: Dict[str, UserDictWord]):
    """Return the single word in ``user_dict`` that is not the fixture word.

    ``user_dict`` must contain either the fixture word (UUID
    ``aab7dda2-0d97-43c8-8cb7-3f440dab9b4e``) plus exactly one other word,
    or exactly one word that is not the fixture word.

    Raises:
        AssertionError: if ``user_dict`` does not have one of those shapes.
    """
    fixture_uuid = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    size = len(user_dict)
    assert size == 2 or (size == 1 and fixture_uuid not in user_dict)

    for candidate_uuid, candidate in user_dict.items():
        if candidate_uuid != fixture_uuid:
            return candidate
    raise AssertionError
class TestUserDict(TestCase):
    """Tests for the user-dictionary helpers in ``voicevox_engine.user_dict``.

    Every test reads and writes its JSON / compiled dictionary files inside a
    fresh temporary directory created in ``setUp``.
    """

    # UUID of the single word stored in ``valid_dict_dict_json``.
    EXISTING_UUID = "aab7dda2-0d97-43c8-8cb7-3f440dab9b4e"
    # A UUID that does not appear in ``valid_dict_dict_json``.
    MISSING_UUID = "c2be4dc5-d07d-4767-8be1-04a1bb3f05a9"
    # UUID under which ``import_word`` is imported as a brand-new entry.
    IMPORTED_UUID = "b1affe2a-d5f0-4050-926c-f28e0c1d9a98"

    def setUp(self):
        self.tmp_dir = TemporaryDirectory()
        # Registered as a cleanup instead of being called at the end of
        # tearDown: cleanups run after tearDown even if tearDown raises, so
        # the directory is never leaked.
        self.addCleanup(self.tmp_dir.cleanup)
        self.tmp_dir_path = Path(self.tmp_dir.name)

    def tearDown(self):
        # Some tests (e.g. test_update_dict) load a user dictionary into
        # pyopenjtalk; unset it so later tests start from a clean state.
        unset_user_dict()

    def _write_valid_dict(self, file_name: str) -> Path:
        """Write ``valid_dict_dict_json`` to ``file_name`` in the temp dir; return its path."""
        user_dict_path = self.tmp_dir_path / file_name
        user_dict_path.write_text(
            json.dumps(valid_dict_dict_json, ensure_ascii=False), encoding="utf-8"
        )
        return user_dict_path

    def test_read_not_exist_json(self):
        self.assertEqual(
            read_dict(user_dict_path=self.tmp_dir_path / "not_exist.json"),
            {},
        )

    def test_create_word(self):
        # Add more cases here when parts of speech etc. are added in the future.
        self.assertEqual(
            create_word(surface="test", pronunciation="テスト", accent_type=1),
            UserDictWord(
                surface="test",
                priority=5,
                part_of_speech="名詞",
                part_of_speech_detail_1="固有名詞",
                part_of_speech_detail_2="一般",
                part_of_speech_detail_3="*",
                inflectional_type="*",
                inflectional_form="*",
                stem="*",
                yomi="テスト",
                pronunciation="テスト",
                accent_type=1,
                accent_associative_rule="*",
            ),
        )

    def test_apply_word_without_json(self):
        user_dict_path = self.tmp_dir_path / "test_apply_word_without_json.json"
        apply_word(
            surface="test",
            pronunciation="テスト",
            accent_type=1,
            user_dict_path=user_dict_path,
            compiled_dict_path=self.tmp_dir_path / "test_apply_word_without_json.dic",
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 1)
        new_word = get_new_word(res)
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test", "テスト", 1),
        )

    def test_apply_word_with_json(self):
        user_dict_path = self._write_valid_dict("test_apply_word_with_json.json")
        apply_word(
            surface="test2",
            pronunciation="テストツー",
            accent_type=3,
            user_dict_path=user_dict_path,
            compiled_dict_path=self.tmp_dir_path / "test_apply_word_with_json.dic",
        )
        res = read_dict(user_dict_path=user_dict_path)
        self.assertEqual(len(res), 2)
        new_word = get_new_word(res)
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test2", "テストツー", 3),
        )

    def test_rewrite_word_invalid_id(self):
        user_dict_path = self._write_valid_dict("test_rewrite_word_invalid_id.json")
        with self.assertRaises(HTTPException):
            rewrite_word(
                word_uuid=self.MISSING_UUID,
                surface="test2",
                pronunciation="テストツー",
                accent_type=2,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_rewrite_word_invalid_id.dic"
                ),
            )

    def test_rewrite_word_valid_id(self):
        user_dict_path = self._write_valid_dict("test_rewrite_word_valid_id.json")
        rewrite_word(
            word_uuid=self.EXISTING_UUID,
            surface="test2",
            pronunciation="テストツー",
            accent_type=2,
            user_dict_path=user_dict_path,
            compiled_dict_path=self.tmp_dir_path / "test_rewrite_word_valid_id.dic",
        )
        new_word = read_dict(user_dict_path=user_dict_path)[self.EXISTING_UUID]
        self.assertEqual(
            (new_word.surface, new_word.pronunciation, new_word.accent_type),
            ("test2", "テストツー", 2),
        )

    def test_delete_word_invalid_id(self):
        user_dict_path = self._write_valid_dict("test_delete_word_invalid_id.json")
        with self.assertRaises(HTTPException):
            delete_word(
                word_uuid=self.MISSING_UUID,
                user_dict_path=user_dict_path,
                compiled_dict_path=(
                    self.tmp_dir_path / "test_delete_word_invalid_id.dic"
                ),
            )

    def test_delete_word_valid_id(self):
        user_dict_path = self._write_valid_dict("test_delete_word_valid_id.json")
        delete_word(
            word_uuid=self.EXISTING_UUID,
            user_dict_path=user_dict_path,
            compiled_dict_path=self.tmp_dir_path / "test_delete_word_valid_id.dic",
        )
        self.assertEqual(len(read_dict(user_dict_path=user_dict_path)), 0)

    def test_priority(self):
        for pos in part_of_speech_data:
            for priority in range(MAX_PRIORITY + 1):
                # subTest reports exactly which (word type, priority) pair
                # failed instead of stopping at the first mismatch.
                with self.subTest(word_type=pos, priority=priority):
                    word = create_word(
                        surface="test",
                        pronunciation="テスト",
                        accent_type=1,
                        word_type=pos,
                        priority=priority,
                    )
                    self.assertEqual(word.priority, priority)

    def test_import_dict(self):
        user_dict_path = self._write_valid_dict("test_import_dict.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict.dic"
        import_user_dict(
            {self.IMPORTED_UUID: import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        user_dict = read_dict(user_dict_path=user_dict_path)
        # The imported word is added...
        self.assertEqual(user_dict[self.IMPORTED_UUID], import_word)
        # ...and the pre-existing word is kept unchanged.
        self.assertEqual(
            user_dict[self.EXISTING_UUID],
            UserDictWord(**valid_dict_dict_api[self.EXISTING_UUID]),
        )

    def test_import_dict_no_override(self):
        user_dict_path = self._write_valid_dict("test_import_dict_no_override.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_no_override.dic"
        import_user_dict(
            {self.EXISTING_UUID: import_word},
            override=False,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        # With override=False the existing entry under the same UUID wins.
        self.assertEqual(
            read_dict(user_dict_path=user_dict_path)[self.EXISTING_UUID],
            UserDictWord(**valid_dict_dict_api[self.EXISTING_UUID]),
        )

    def test_import_dict_override(self):
        user_dict_path = self._write_valid_dict("test_import_dict_override.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_dict_override.dic"
        import_user_dict(
            {self.EXISTING_UUID: import_word},
            override=True,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        # With override=True the imported word replaces the existing entry.
        self.assertEqual(
            read_dict(user_dict_path=user_dict_path)[self.EXISTING_UUID],
            import_word,
        )

    def test_import_invalid_word(self):
        user_dict_path = self._write_valid_dict("test_import_invalid_dict.json")
        compiled_dict_path = self.tmp_dir_path / "test_import_invalid_dict.dic"

        invalid_accent_associative_rule_word = deepcopy(import_word)
        invalid_accent_associative_rule_word.accent_associative_rule = "invalid"
        with self.assertRaises(AssertionError):
            import_user_dict(
                {self.EXISTING_UUID: invalid_accent_associative_rule_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

        # NOTE(review): presumably "フィラー" (filler) with context_id 2 is a
        # part of speech that user dictionary import rejects — confirm
        # against import_user_dict.
        invalid_pos_word = deepcopy(import_word)
        invalid_pos_word.context_id = 2
        invalid_pos_word.part_of_speech = "フィラー"
        invalid_pos_word.part_of_speech_detail_1 = "*"
        invalid_pos_word.part_of_speech_detail_2 = "*"
        invalid_pos_word.part_of_speech_detail_3 = "*"
        with self.assertRaises(ValueError):
            import_user_dict(
                {self.EXISTING_UUID: invalid_pos_word},
                override=True,
                user_dict_path=user_dict_path,
                compiled_dict_path=compiled_dict_path,
            )

    def test_update_dict(self):
        user_dict_path = self.tmp_dir_path / "test_update_dict.json"
        compiled_dict_path = self.tmp_dir_path / "test_update_dict.dic"
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        test_text = "テスト用の文字列"
        success_pronunciation = "デフォルトノジショデハゼッタイニセイセイサレナイヨミ"

        # Make sure the reading is not already produced before registering it.
        self.assertNotEqual(g2p(text=test_text, kana=True), success_pronunciation)

        apply_word(
            surface=test_text,
            pronunciation=success_pronunciation,
            accent_type=1,
            priority=10,
            user_dict_path=user_dict_path,
            compiled_dict_path=compiled_dict_path,
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)

        # Simulate an engine restart: drop the loaded dictionary and reload it
        # from the files on disk.
        unset_user_dict()
        update_dict(
            user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path
        )
        self.assertEqual(g2p(text=test_text, kana=True), success_pronunciation)