| from typing import Any, Dict, List, Optional |
| from threading import Lock |
|
|
| from word_processor import WordDeterminer, WordPiece |
|
|
|
|
class RustAdapter:
    """
    Adapter that exposes the word-processing pipeline to a Rust caller.

    - Builds the expensive components (``WordDeterminer`` and the AI model)
      once and keeps them on the instance.
    - Provides the build operations as methods.
    - Returns plain ``dict``/``list``/``str``/``int`` values that are easy to
      serialize across the language boundary.
    """

    # Shared instance, created lazily by ``get_instance``.
    _instance: Optional["RustAdapter"] = None
    # Guards creation of ``_instance`` so concurrent first calls build only one adapter.
    _lock: Lock = Lock()

    # Non-printable control characters that are nevertheless kept: tab, LF, CR.
    _KEPT_CONTROLS = frozenset("\t\n\r")
    # Characters that are always dropped: the replacement character U+FFFD
    # (which str.isprintable() accepts), zero-width space/non-joiner/joiner,
    # and the BOM / zero-width no-break space.
    _DROPPED_CHARS = frozenset("\uFFFD\u200B\u200C\u200D\uFEFF")

    def __init__(self, model_path: Optional[str] = None):
        """
        Create the word determiner and load the AI model.

        Args:
            model_path: Forwarded unchanged to ``AI.get_model``. What ``None``
                means is decided by that function (TODO confirm in ``ai``).
        """
        self.determiner = WordDeterminer()

        # Imported here instead of at module top. Presumably this defers the
        # cost or side effects of importing the AI package until an adapter
        # is actually built. NOTE(review): confirm the intent.
        from ai import AI

        self.model = AI.get_model(model_path)

    @classmethod
    def get_instance(cls, model_path: Optional[str] = None) -> "RustAdapter":
        """
        Return the shared adapter, creating it on first use.

        ``model_path`` only takes effect on the call that actually creates the
        instance. Later calls return the existing adapter and ignore it.
        """
        # Fast path: once the instance exists, no locking is needed.
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            # Re-check under the lock: another thread may have created the
            # instance while this one was waiting.
            if cls._instance is None:
                # Use ``cls`` rather than the hard-coded class name so that a
                # subclass gets an instance of its own type.
                cls._instance = cls(model_path)
        return cls._instance

    def _clean_text(self, text: str) -> str:
        """
        Strip control, invisible and replacement characters (for final output).

        Kept:
            printable characters, tab/LF/CR, and Unicode space separators
            (category ``Zs``, e.g. the ideographic space U+3000), which are
            visible whitespace rather than invisible characters.
        Dropped:
            every other non-printable character (control, format, surrogate,
            private-use, unassigned, line/paragraph separators), the
            replacement character U+FFFD, and zero-width characters / BOM.

        Returns ``""`` for empty or otherwise falsy input.
        """
        import unicodedata

        if not text:
            return ""

        kept = []
        for ch in text:
            if ch in self._DROPPED_CHARS:
                continue
            if (
                ch in self._KEPT_CONTROLS
                or ch.isprintable()
                # str.isprintable() rejects every "Separator" character except
                # ASCII space, which would delete full-width spaces from
                # Japanese text; keep the visible space separators explicitly.
                or unicodedata.category(ch) == "Zs"
            ):
                kept.append(ch)
        return "".join(kept)

    def build_word_tree(self, prompt_text: str, root_text: str = "", top_k: int = 5, max_depth: int = 10) -> List[Dict[str, Any]]:
        """
        Build the word tree and return the completed pieces as a list of dicts.

        Each element has the form ``{"text": str, "probability": float}``.
        ``text`` is the piece's full word after ``_clean_text``.
        ``probability`` is the piece's probability coerced to ``float``.
        """
        pieces: List[WordPiece] = self.determiner.build_word_tree(
            prompt_text=prompt_text,
            root_text=root_text,
            model=self.model,
            top_k=top_k,
            max_depth=max_depth,
        )

        # float() coerces whatever numeric type the determiner uses into a
        # plain Python float so the result serializes cleanly.
        return [
            {"text": self._clean_text(p.get_full_word()), "probability": float(p.probability)}
            for p in pieces
        ]

    def build_chat_prompt(self, user_content: str, system_content: str = "あなたは親切で役に立つAIアシスタントです。") -> str:
        """
        Return the chat prompt string built by the determiner.

        The default system message translates to "You are a kind and helpful
        AI assistant."
        """
        return self.determiner.build_chat_prompt(user_content, system_content)

    def count_words(self, text: str) -> int:
        """
        Count the words in ``text``.

        Per the original note this is Sudachi-based (presumably split mode C).
        It delegates to the determiner's non-public ``_count_words``.
        """
        return self.determiner._count_words(text)
|
|
|
|
| |
if __name__ == "__main__":
    # Manual smoke check: build a small word tree for a fixed prompt through
    # the shared adapter and print each piece with its probability.
    rust_adapter = RustAdapter.get_instance()
    demo_prompt = "電球を作ったのは誰?"
    tree = rust_adapter.build_word_tree(
        prompt_text=demo_prompt,
        root_text="",
        top_k=5,
        max_depth=5,
    )
    print("=== RustAdapter 確認 ===")
    for rank, entry in enumerate(tree, start=1):
        print(f"{rank}. {entry['text']} ({entry['probability']:.4f})")
|
|