File size: 4,146 Bytes
09c17cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cc4cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09c17cd
 
 
 
 
 
 
 
 
 
 
 
52c76d4
09c17cd
5cc4cfa
09c17cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from typing import Any, Dict, List, Optional
from threading import Lock

from word_processor import WordDeterminer, WordPiece


class RustAdapter:
    """
    Rust から呼び出すためのアダプタ。
    - 初期化コストの高いコンポーネント(WordDeterminer, AIモデル)を1回だけ生成して保持
    - メソッドでビルド処理を提供
    - 返却はシリアライズしやすい dict/list 形式
    """

    _instance: Optional["RustAdapter"] = None
    _lock: Lock = Lock()

    def __init__(self, model_path: Optional[str] = None):
        # WordDeterminer(内部で Sudachi C モードの WordCounter を使用)
        self.determiner = WordDeterminer()

        # AIモデルは共有キャッシュ取得
        # model_path が None の場合はデフォルトモデル
        from ai import AI

        self.model = AI.get_model(model_path)

    @classmethod
    def get_instance(cls, model_path: Optional[str] = None) -> "RustAdapter":
        """シングルトン取得。model_path 指定時は初回のみ反映。"""
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            if cls._instance is None:
                cls._instance = RustAdapter(model_path)
        return cls._instance

    # ===== 公開API =====
    def _clean_text(self, text: str) -> str:
        """制御文字・不可視文字・置換文字を厳密に取り除く(最終出力用)"""
        if not text:
            return ""
        
        # 制御文字(0x00-0x1F、0x7F-0x9F)を除去
        # ただし、改行・タブ・復帰は許可
        cleaned = []
        for ch in text:
            code = ord(ch)
            # 許可する制御文字: 改行(0x0A), タブ(0x09), 復帰(0x0D)
            if code in [0x09, 0x0A, 0x0D]:
                cleaned.append(ch)
            # 通常の印刷可能文字(0x20-0x7E、およびその他のUnicode印刷可能文字)
            elif ch.isprintable():
                # 置換文字(U+FFFD)を除去
                if ch != "\uFFFD":
                    cleaned.append(ch)
            # その他の制御文字や不可視文字は除去
        
        result = "".join(cleaned)
        # ゼロ幅文字を除去
        result = result.replace("\u200B", "")  # Zero-width space
        result = result.replace("\u200C", "")  # Zero-width non-joiner
        result = result.replace("\u200D", "")  # Zero-width joiner
        result = result.replace("\uFEFF", "")  # Zero-width no-break space
        return result

    def build_word_tree(self, prompt_text: str, root_text: str = "", top_k: int = 5, max_depth: int = 10) -> List[Dict[str, Any]]:
        """
        単語ツリーを構築して、完成ピースを dict の配列で返す。
        各要素: { text: str, probability: float }
        """
        pieces: List[WordPiece] = self.determiner.build_word_tree(
            prompt_text=prompt_text,
            root_text=root_text,
            model=self.model,
            top_k=top_k,
            max_depth=max_depth,
        )
        # print(f"[RustAdapter] build_word_tree: {pieces}")
        return [
            {"text": self._clean_text(p.get_full_word()), "probability": float(p.probability)}
            for p in pieces
        ]

    def build_chat_prompt(self, user_content: str, system_content: str = "あなたは親切で役に立つAIアシスタントです。") -> str:
        """チャットプロンプト文字列を返す。"""
        return self.determiner.build_chat_prompt(user_content, system_content)

    def count_words(self, text: str) -> int:
        """Sudachi(C) ベースでの語数カウント。"""
        return self.determiner._count_words(text)


# 簡易動作テスト
if __name__ == "__main__":
    adapter = RustAdapter.get_instance()
    prompt = "電球を作ったのは誰?"
    results = adapter.build_word_tree(prompt_text=prompt, root_text="", top_k=5, max_depth=5)
    print("=== RustAdapter 確認 ===")
    for i, r in enumerate(results, 1):
        print(f"{i}. {r['text']} ({r['probability']:.4f})")