Spaces:
Build error
Build error
| import json | |
| import sys | |
| import threading | |
| import traceback | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| from uuid import UUID, uuid4 | |
| import numpy as np | |
| import pyopenjtalk | |
| from fastapi import HTTPException | |
| from pydantic import conint | |
| from .model import UserDictWord, WordTypes | |
| from .part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data | |
| from .utility import engine_root, get_save_dir, mutex_wrapper | |
# Resolve engine paths and make sure the save directory exists.
root_dir = engine_root()
save_dir = get_save_dir()
if not save_dir.is_dir():
    save_dir.mkdir(parents=True)

# Bundled default dictionary (MeCab CSV) shipped with the engine.
default_dict_path = root_dir / "default.csv"
# User-editable dictionary, persisted as JSON in the save directory.
user_dict_path = save_dir / "user_dict.json"
# Compiled OpenJTalk binary dictionary built from the CSVs above.
compiled_dict_path = save_dir / "user.dic"

# Locks guarding the user-dict JSON file and the OpenJTalk runtime dictionary.
# NOTE(review): mutex_wrapper is imported but not applied in this chunk —
# presumably the functions below are wrapped elsewhere; verify.
mutex_user_dict = threading.Lock()
mutex_openjtalk_dict = threading.Lock()
def write_to_json(user_dict: Dict[str, UserDictWord], user_dict_path: Path):
    """Serialize the user dictionary and write it to ``user_dict_path``.

    The on-disk JSON stores a raw MeCab "cost" rather than the engine-level
    "priority", so each entry's priority is converted back before saving.
    """
    serializable = {}
    for uuid_key, entry in user_dict.items():
        record = entry.dict()
        record["cost"] = priority2cost(record["context_id"], record["priority"])
        del record["priority"]
        serializable[uuid_key] = record
    # Encode up-front so a serialization failure cannot truncate the file.
    encoded = json.dumps(serializable, ensure_ascii=False)
    user_dict_path.write_text(encoded, encoding="utf-8")
def update_dict(
    default_dict_path: Path = default_dict_path,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Rebuild the compiled OpenJTalk user dictionary and activate it.

    Merges the bundled default CSV with every user-dictionary entry,
    compiles the combined CSV with pyopenjtalk, swaps the result into
    ``compiled_dict_path``, and loads it into the running OpenJTalk.

    Args:
        default_dict_path: Bundled default dictionary (MeCab CSV).
        user_dict_path: User dictionary JSON to merge in.
        compiled_dict_path: Destination for the compiled binary dictionary.
    """
    # Unique temp-file names so a concurrent update cannot collide.
    random_string = uuid4()
    tmp_csv_path = save_dir / f".tmp.dict_csv-{random_string}"
    tmp_compiled_path = save_dir / f".tmp.dict_compiled-{random_string}"
    try:
        # --- Build the combined dictionary CSV ---
        if not default_dict_path.is_file():
            # Without the default dictionary there is nothing sensible to
            # compile; warn and keep the currently loaded dictionary as-is.
            print("Warning: Cannot find default dictionary.", file=sys.stderr)
            return
        default_dict = default_dict_path.read_text(encoding="utf-8")
        if default_dict == default_dict.rstrip():
            # Guarantee a trailing newline before appending user rows.
            default_dict += "\n"
        # Collect rows in a list and join once (avoids quadratic str +=).
        rows = [default_dict]
        user_dict = read_dict(user_dict_path=user_dict_path)
        for word_uuid in user_dict:
            word = user_dict[word_uuid]
            rows.append(
                (
                    "{surface},{context_id},{context_id},{cost},{part_of_speech},"
                    "{part_of_speech_detail_1},{part_of_speech_detail_2},"
                    "{part_of_speech_detail_3},{inflectional_type},"
                    "{inflectional_form},{stem},{yomi},{pronunciation},"
                    "{accent_type}/{mora_count},{accent_associative_rule}\n"
                ).format(
                    surface=word.surface,
                    context_id=word.context_id,
                    cost=priority2cost(word.context_id, word.priority),
                    part_of_speech=word.part_of_speech,
                    part_of_speech_detail_1=word.part_of_speech_detail_1,
                    part_of_speech_detail_2=word.part_of_speech_detail_2,
                    part_of_speech_detail_3=word.part_of_speech_detail_3,
                    inflectional_type=word.inflectional_type,
                    inflectional_form=word.inflectional_form,
                    stem=word.stem,
                    yomi=word.yomi,
                    pronunciation=word.pronunciation,
                    accent_type=word.accent_type,
                    mora_count=word.mora_count,
                    accent_associative_rule=word.accent_associative_rule,
                )
            )
        tmp_csv_path.write_text("".join(rows), encoding="utf-8")
        # --- Compile the CSV into OpenJTalk's binary format ---
        pyopenjtalk.create_user_dict(str(tmp_csv_path), str(tmp_compiled_path))
        if not tmp_compiled_path.is_file():
            raise RuntimeError("辞書のコンパイル時にエラーが発生しました。")
        # --- Swap in the compiled dictionary and (re)load it ---
        pyopenjtalk.unset_user_dict()
        tmp_compiled_path.replace(compiled_dict_path)
        if compiled_dict_path.is_file():
            pyopenjtalk.set_user_dict(str(compiled_dict_path.resolve(strict=True)))
    except Exception:
        print("Error: Failed to update dictionary.", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
        # Bare `raise` re-raises with the original traceback intact
        # (`raise e` would add this frame and mutate e.__traceback__).
        raise
    finally:
        # Remove temp files whether or not the update succeeded.
        if tmp_csv_path.exists():
            tmp_csv_path.unlink()
        if tmp_compiled_path.exists():
            tmp_compiled_path.unlink()
def read_dict(user_dict_path: Path = user_dict_path) -> Dict[str, UserDictWord]:
    """Load the user dictionary JSON into ``UserDictWord`` objects.

    Returns an empty dict when no user dictionary file exists yet.
    """
    if not user_dict_path.is_file():
        return {}
    with user_dict_path.open(encoding="utf-8") as fp:
        loaded: Dict[str, UserDictWord] = {}
        for raw_uuid, fields in json.load(fp).items():
            # cost2priority needs a context_id, but dictionaries written by
            # versions <= 0.12 hard-coded it and never stored it in the file.
            # The hard-coded id meant "proper noun", so backfill that here.
            if fields.get("context_id") is None:
                fields["context_id"] = part_of_speech_data[
                    WordTypes.PROPER_NOUN
                ].context_id
            fields["priority"] = cost2priority(fields["context_id"], fields["cost"])
            del fields["cost"]
            # Round-trip through UUID to normalize the key's string form.
            loaded[str(UUID(raw_uuid))] = UserDictWord(**fields)
        return loaded
def create_word(
    surface: str,
    pronunciation: str,
    accent_type: int,
    word_type: Optional[WordTypes] = None,
    priority: Optional[int] = None,
) -> UserDictWord:
    """Build a validated ``UserDictWord`` from user-supplied fields.

    Defaults: proper noun word type, priority 5.

    Raises:
        HTTPException: 422 for an unknown word type or out-of-range priority.
    """
    word_type = WordTypes.PROPER_NOUN if word_type is None else word_type
    if word_type not in part_of_speech_data:
        raise HTTPException(status_code=422, detail="不明な品詞です")
    priority = 5 if priority is None else priority
    if priority < MIN_PRIORITY or priority > MAX_PRIORITY:
        raise HTTPException(status_code=422, detail="優先度の値が無効です")
    pos = part_of_speech_data[word_type]
    return UserDictWord(
        surface=surface,
        context_id=pos.context_id,
        priority=priority,
        part_of_speech=pos.part_of_speech,
        part_of_speech_detail_1=pos.part_of_speech_detail_1,
        part_of_speech_detail_2=pos.part_of_speech_detail_2,
        part_of_speech_detail_3=pos.part_of_speech_detail_3,
        inflectional_type="*",
        inflectional_form="*",
        stem="*",
        yomi=pronunciation,
        pronunciation=pronunciation,
        accent_type=accent_type,
        accent_associative_rule="*",
    )
def apply_word(
    surface: str,
    pronunciation: str,
    accent_type: int,
    word_type: Optional[WordTypes] = None,
    priority: Optional[int] = None,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
) -> str:
    """Add a new word to the user dictionary and rebuild the compiled dict.

    Returns:
        The UUID string assigned to the new entry.
    """
    new_entry = create_word(
        surface=surface,
        pronunciation=pronunciation,
        accent_type=accent_type,
        word_type=word_type,
        priority=priority,
    )
    entries = read_dict(user_dict_path=user_dict_path)
    new_uuid = str(uuid4())
    entries[new_uuid] = new_entry
    write_to_json(entries, user_dict_path)
    update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path)
    return new_uuid
def rewrite_word(
    word_uuid: str,
    surface: str,
    pronunciation: str,
    accent_type: int,
    word_type: Optional[WordTypes] = None,
    priority: Optional[int] = None,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Replace the user-dictionary entry identified by ``word_uuid``.

    Raises:
        HTTPException: 422 when no entry has the given UUID.
    """
    replacement = create_word(
        surface=surface,
        pronunciation=pronunciation,
        accent_type=accent_type,
        word_type=word_type,
        priority=priority,
    )
    entries = read_dict(user_dict_path=user_dict_path)
    if word_uuid not in entries:
        raise HTTPException(status_code=422, detail="UUIDに該当するワードが見つかりませんでした")
    entries[word_uuid] = replacement
    write_to_json(entries, user_dict_path)
    update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path)
def delete_word(
    word_uuid: str,
    user_dict_path: Path = user_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Remove the entry with ``word_uuid`` and rebuild the compiled dict.

    Raises:
        HTTPException: 422 when no entry has the given UUID.
    """
    entries = read_dict(user_dict_path=user_dict_path)
    if word_uuid not in entries:
        raise HTTPException(status_code=422, detail="IDに該当するワードが見つかりませんでした")
    entries.pop(word_uuid)
    write_to_json(entries, user_dict_path)
    update_dict(user_dict_path=user_dict_path, compiled_dict_path=compiled_dict_path)
def import_user_dict(
    dict_data: Dict[str, UserDictWord],
    override: bool = False,
    user_dict_path: Path = user_dict_path,
    default_dict_path: Path = default_dict_path,
    compiled_dict_path: Path = compiled_dict_path,
):
    """Merge an imported dictionary into the stored user dictionary.

    Args:
        dict_data: Imported entries keyed by UUID string.
        override: When True, imported entries win on UUID collisions;
            otherwise existing entries are kept.
        user_dict_path: User dictionary JSON to merge into.
        default_dict_path: Default dictionary passed through to update_dict.
        compiled_dict_path: Compiled dictionary passed through to update_dict.

    Raises:
        ValueError: Malformed UUID key, or a context_id matching no known
            part of speech.
        AssertionError: An entry's part-of-speech fields are inconsistent
            with the part of speech implied by its context_id.
    """
    # Type-check each imported entry. Raise explicitly instead of using a
    # bare `assert` so validation still runs under `python -O`.
    for word_uuid, word in dict_data.items():
        UUID(word_uuid)  # raises ValueError on a malformed key
        if not isinstance(word, UserDictWord):
            raise AssertionError
        for pos_detail in part_of_speech_data.values():
            if word.context_id == pos_detail.context_id:
                consistent = (
                    word.part_of_speech == pos_detail.part_of_speech
                    and word.part_of_speech_detail_1
                    == pos_detail.part_of_speech_detail_1
                    and word.part_of_speech_detail_2
                    == pos_detail.part_of_speech_detail_2
                    and word.part_of_speech_detail_3
                    == pos_detail.part_of_speech_detail_3
                    and word.accent_associative_rule
                    in pos_detail.accent_associative_rules
                )
                if not consistent:
                    raise AssertionError
                break
        else:
            raise ValueError("対応していない品詞です")
    old_dict = read_dict(user_dict_path=user_dict_path)
    # In a dict merge, later keys win — so the operand order decides
    # collision behavior.
    if override:
        new_dict = {**old_dict, **dict_data}
    else:
        new_dict = {**dict_data, **old_dict}
    write_to_json(user_dict=new_dict, user_dict_path=user_dict_path)
    update_dict(
        default_dict_path=default_dict_path,
        user_dict_path=user_dict_path,
        compiled_dict_path=compiled_dict_path,
    )
def search_cost_candidates(context_id: int) -> List[int]:
    """Return the cost candidate table for the part of speech ``context_id``.

    Raises:
        HTTPException: 422 when no part of speech matches.
    """
    candidates = next(
        (
            pos.cost_candidates
            for pos in part_of_speech_data.values()
            if pos.context_id == context_id
        ),
        None,
    )
    if candidates is None:
        raise HTTPException(status_code=422, detail="品詞IDが不正です")
    return candidates
def cost2priority(context_id: int, cost: conint(ge=-32768, le=32767)) -> int:
    """Map a raw MeCab cost to the engine's priority scale.

    Picks the cost candidate nearest to ``cost`` for this part of speech and
    returns the corresponding priority. Together with priority2cost this
    means a hand-edited cost in the dictionary file snaps to the cost of the
    closest supported priority.
    Reference: https://qiita.com/Krypf/items/2eada91c37161d17621d
    """
    cost_candidates = search_cost_candidates(context_id)
    # int(): np.argmin returns a NumPy integer, which would leak through the
    # declared `-> int` (and into JSON serialization) without conversion.
    return MAX_PRIORITY - int(np.argmin(np.abs(np.array(cost_candidates) - cost)))
def priority2cost(
    context_id: int, priority: conint(ge=MIN_PRIORITY, le=MAX_PRIORITY)
) -> int:
    """Map a priority back to its canonical MeCab cost for this part of speech."""
    candidates = search_cost_candidates(context_id)
    # Higher priority corresponds to a lower index (and thus lower cost).
    index = MAX_PRIORITY - priority
    return candidates[index]