from dataclasses import dataclass, field
from enum import Enum
import math
import os
from typing import Any, List, Optional, Tuple

from word_counter import WordCounter
from config import Config
|
|
|
|
class WordState(Enum):
    """State of a word candidate during streaming determination."""

    # More tokens are expected; the word is still being built.
    INCOMPLETE = "incomplete"
    # A word boundary was detected; the candidate is finalized.
    COMPLETE = "complete"
    # Reserved trigger state (not produced by check_word_completion here).
    TRIGGER = "trigger"
|
|
class KList:
    """Bounded candidate list keeping at most ``num`` items, ordered by
    descending ``probability``.

    Stored items must expose a ``probability`` attribute and a
    ``get_full_text()`` method; candidates producing identical full text
    are merged by summing their probabilities.
    """

    def __init__(self, num: int):
        """
        Args:
            num: Maximum number of items to retain.
        """
        self.num = num
        # Kept sorted by descending probability.  The attribute name is
        # preserved for backward compatibility even though it shadows the
        # builtin ``list``.
        self.list: List[Any] = []

    def check_k(self) -> None:
        """Re-sort by descending probability and truncate to ``num`` items.

        The original implementation duplicated the sort in both branches of
        an if/else; sorting unconditionally and deleting the overflow slice
        is equivalent (deleting past the end is a no-op).
        """
        self.list.sort(key=lambda x: x.probability, reverse=True)
        del self.list[self.num:]

    def add(self, piece_word: Any) -> None:
        """Insert a candidate, merging it with any existing entry that has
        the same full text.

        Args:
            piece_word: Candidate with ``probability`` and ``get_full_text()``.
        """
        new_text = piece_word.get_full_text()
        for existing_piece in self.list:
            if existing_piece.get_full_text() == new_text:
                # Same surface text reached via a different token path:
                # accumulate the probability mass instead of duplicating.
                existing_piece.probability += piece_word.probability
                self.check_k()
                return
        self.list.append(piece_word)
        self.check_k()

    def pop(self) -> Any:
        """Remove and return the highest-probability item.

        Raises:
            IndexError: If the list is empty.
        """
        if self.list:
            return self.list.pop(0)
        raise IndexError("List is empty")

    def empty(self) -> bool:
        """Return True when no candidates are stored."""
        return len(self.list) == 0
| |
@dataclass
class WordPiece:
    """A piece (partial fragment) of a word in the candidate tree.

    Each node stores one token's text; the full candidate string is the
    concatenation of texts along the path from the root.
    """
    # Token text of this node.
    text: str
    # Cumulative probability of the path reaching this node.
    probability: float
    # Cached (token, probability) candidates for the next step.
    next_tokens: Optional[List[Tuple[str, float]]] = None
    parent: Optional['WordPiece'] = None
    # default_factory gives each instance its own list — the original
    # ``= None`` + __post_init__ workaround for the mutable-default pitfall.
    children: List['WordPiece'] = field(default_factory=list)

    def __post_init__(self):
        # Backward compatibility: callers may still pass children=None.
        if self.children is None:
            self.children = []

    def get_full_text(self) -> str:
        """Return the concatenated text from the root down to this piece."""
        pieces = []
        current = self
        while current is not None:
            if current.text:
                pieces.append(current.text)
            current = current.parent
        return "".join(reversed(pieces))

    def get_full_word(self) -> str:
        """Return the concatenated text from the root's child down to this
        piece, i.e. the full text minus the topmost (root) text."""
        pieces = []
        current = self
        while current is not None:
            if current.text:
                pieces.append(current.text)
            current = current.parent
        # pieces is collected leaf-to-root; dropping the last entry removes
        # the root text before reversing.  NOTE(review): if the root's text
        # is empty, this instead drops the deepest non-empty ancestor —
        # behavior preserved from the original; confirm intent.
        reversed_pieces = reversed(pieces[:-1])
        return "".join(reversed_pieces)

    def add_child(self, text: str, probability: float, next_tokens: Optional[List[Tuple[str, float]]] = None) -> 'WordPiece':
        """Create a child piece linked to this node and return it."""
        child = WordPiece(
            text=text,
            probability=probability,
            next_tokens=next_tokens,
            parent=self
        )
        self.children.append(child)
        return child

    def is_leaf(self) -> bool:
        """Return True when this piece has no children."""
        return len(self.children) == 0

    def get_depth(self) -> int:
        """Return the number of edges from the root to this piece."""
        depth = 0
        current = self.parent
        while current is not None:
            depth += 1
            current = current.parent
        return depth
|
|
|
|
class WordDeterminer:
    """Word-finalization system (real-time algorithm for streaming output)."""

    # A word ending in an opening bracket can never be complete.  Hoisted to
    # a frozenset for O(1) membership; the original inline lists contained
    # many duplicated characters, which a set collapses harmlessly (every
    # original character is preserved).
    _OPEN_BRACKETS = frozenset(["(","「","(","【","〈","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅"])
    # A word ending in a closing bracket is always treated as complete.
    _CLOSE_BRACKETS = frozenset([")","]","}","》","〉","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆"])

    def __init__(self, word_counter: Optional[WordCounter] = None):
        """
        Initialize the determiner.

        Args:
            word_counter: WordCounter instance; a default one is created
                when None is given.
        """
        self.word_counter = word_counter or WordCounter()

    def is_boundary_char(self, char: str) -> bool:
        """Return True if ``char`` is whitespace or a punctuation mark
        treated as a word boundary.

        (The original docstring claimed fugashi was used here; this check
        is purely character-based.)
        """
        if not char:
            return False

        if char.isspace():
            return True

        punctuation = ",,..。!?!?:;;、\n\t"
        return char in punctuation

    def is_word_boundary(self, text: str, position: int) -> bool:
        """
        Delegate word-boundary detection to the WordCounter.

        Args:
            text: Text to inspect.
            position: Position (negative values index from the end).

        Returns:
            bool: Whether the position is a word boundary.
        """
        return self.word_counter.is_word_boundary(text, position)

    def check_word_completion(self, piece: WordPiece, root_count: int, model: Any = None) -> Tuple[WordState, Optional[Any]]:
        """
        Real-time word-determination algorithm for streaming.

        Strategy: weight each candidate next-token by its probability and
        measure how much probability mass would push the text past a word
        boundary.

        - boundary ratio > 0.85 -> high-probability tokens agree on a
          boundary -> COMPLETE
        - boundary ratio < 0.2  -> tokens agree on continuation -> INCOMPLETE
        - otherwise -> ambiguous, default to INCOMPLETE

        Args:
            piece: Piece to check.
            root_count: Word count of the root text.
            model: Optional LLM used to fetch next-token candidates when
                the piece has none cached.

        Returns:
            Tuple[WordState, Optional[Any]]: (state, payload); the payload
            is currently always None.
        """
        full_text = piece.get_full_text()

        # Lazily fetch candidates; without them (and without a model) the
        # only safe choice is to treat the piece as complete.
        if not piece.next_tokens:
            if model:
                piece.next_tokens = self._get_next_tokens_from_model(model, full_text)
            else:
                return (WordState.COMPLETE, None)

        if not piece.next_tokens:
            return (WordState.COMPLETE, None)

        sorted_tokens = sorted(piece.next_tokens, key=lambda x: x[1], reverse=True)

        # Bracket heuristics.  get_full_word() can be empty (the piece may
        # contain only root text); the original indexed [-1] unguarded and
        # raised IndexError in that case.
        full_word = piece.get_full_word()
        if full_word:
            last_char = full_word[-1]
            if last_char in self._OPEN_BRACKETS:
                return (WordState.INCOMPLETE, None)
            if last_char in self._CLOSE_BRACKETS:
                return (WordState.COMPLETE, None)

        boundary_prob = 0.0
        continuation_prob = 0.0
        total = sum(prob for _, prob in sorted_tokens)

        for token, prob in sorted_tokens:
            test_text = full_text + token
            test_word_count = self._count_words(test_text)

            # If appending the token would start yet another word, the
            # current word must already be finished.
            if test_word_count > root_count + 1:
                boundary_prob += prob
            else:
                continuation_prob += prob

        if total > 0:
            boundary_ratio = boundary_prob / total

            # Strong agreement among high-probability tokens -> finalize.
            if boundary_ratio > 0.85:
                return (WordState.COMPLETE, None)

            # Strong agreement on continuation -> keep building.
            if boundary_ratio < 0.2:
                return (WordState.INCOMPLETE, None)

        # Ambiguous zone: default to continuing the word.  (A normalized-
        # entropy computation existed here but its result was never used,
        # so it has been removed.)
        return (WordState.INCOMPLETE, None)

    def _count_words(self, text: str) -> int:
        """
        Count words in ``text`` via the WordCounter.

        Args:
            text: Text to count.

        Returns:
            int: Number of words.
        """
        return self.word_counter.count_words(text)

    def _get_next_tokens_from_model(self, model: Any, text: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """
        Fetch next-token candidates from the resident AI model.

        Args:
            model: LLM model (a path string or a model object).
            text: Input text.
            top_k: Number of candidates to fetch.

        Returns:
            List[Tuple[str, float]]: (token, probability) pairs; empty on
            any failure (best-effort by design).
        """
        try:
            from ai import AI

            # Accept either a path string or an object carrying one.
            if isinstance(model, str):
                model_path = model
            elif hasattr(model, 'model_path'):
                model_path = model.model_path
            else:
                model_path = None

            ai_model = AI.get_model(model_path)
            return ai_model.get_token_probabilities(text, top_k)

        except Exception as e:
            print(f"モデルからのトークン取得に失敗: {e}")

        return []

    def expand_piece(self, piece: WordPiece, model: Any = None) -> List[WordPiece]:
        """
        Expand a piece into child pieces, one per candidate next-token.

        Args:
            piece: Piece to expand.
            model: LLM model used when the piece has no cached candidates.

        Returns:
            List[WordPiece]: The generated children (possibly empty).
        """
        children = []
        full_text = piece.get_full_text()

        # Use cached candidates when available; otherwise fetch from the
        # model and cache the result on the piece.
        next_tokens = piece.next_tokens
        if not next_tokens and model:
            next_tokens = self._get_next_tokens_from_model(model, full_text)
            if next_tokens:
                piece.next_tokens = next_tokens

        if next_tokens:
            for token, prob in next_tokens:
                # Skip empty tokens — they would create zero-length pieces.
                if not token:
                    continue
                child_prob = piece.probability * prob
                child = piece.add_child(token, child_prob)
                children.append(child)
        elif not model:
            print(f"[WORD_PROCESSOR_STREAMING] No model provided for expansion")

        return children

    def build_word_tree(self, prompt_text: str, root_text: str, model: Any, top_k: int = 5, max_depth: int = 10) -> List[WordPiece]:
        """
        Build the word tree by best-first expansion of candidates.

        Args:
            prompt_text: User prompt; wrapped into a chat prompt and
                prepended to ``root_text`` to form the root text.
            root_text: Text generated so far.
            model: LLM model.
            top_k: Number of completed words to return.
            max_depth: Accepted for backward compatibility; currently
                unused (the internal iteration cap bounds the search).

        Returns:
            List[WordPiece]: Up to ``top_k`` completed pieces with
            probabilities normalized to sum to 1.
        """
        root = WordPiece(text=self.build_chat_prompt(prompt_text)+root_text, probability=1.0)

        candidates = KList(2*top_k)
        completed = []
        iteration = 0
        # Hard cap so a pathological tree cannot loop forever.
        max_iterations = 1000

        for child in self.expand_piece(root, model):
            candidates.add(child)

        # The root never changes, so its word count is loop-invariant;
        # hoisted out of the loop (it was recomputed every iteration).
        root_count = self._count_words(root.get_full_text())

        while not candidates.empty() and iteration < max_iterations and len(completed) < top_k:
            iteration += 1

            current = candidates.pop()

            state, _payload = self.check_word_completion(current, root_count, model)

            if state == WordState.COMPLETE:
                completed.append(current)

            elif state == WordState.INCOMPLETE:
                children = self.expand_piece(current, model)
                if len(children) == 0:
                    # Dead end: no continuation exists, so accept as-is.
                    print(f"[WORD_PROCESSOR_STREAMING] No children generated for '{current.get_full_text()}', marking as COMPLETE")
                    completed.append(current)
                else:
                    for child in children:
                        candidates.add(child)

        # Normalize so the returned probabilities sum to 1.
        total_prob = sum(p.probability for p in completed)
        if total_prob > 0:
            for piece in completed:
                piece.probability = piece.probability / total_prob

        return completed[:top_k]

    def build_chat_prompt(self, user_content: str,
                          system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください") -> str:
        """
        Build a chat prompt.

        Note: when the Rust side passes an already-formatted prompt
        (detected by the presence of header/eot markers), it is returned
        unchanged.  For backward compatibility, a plain user_content is
        wrapped in the legacy template.
        """
        if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content:
            return user_content

        prompt_text = (
            f"<|begin_of_text|>"
            f"<|start_header_id|>system<|end_header_id|>\n"
            f"{system_content}\n<|eot_id|>"
            f"<|start_header_id|>user<|end_header_id|>\n"
            f"{user_content}\n<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n"
        )

        # Strip any leading BOS markers; the caller's tokenizer adds its own.
        BOS = "<|begin_of_text|>"
        s = prompt_text.lstrip()
        while s.startswith(BOS):
            s = s[len(BOS):]
        prompt_text = s
        return prompt_text
| |
| if __name__ == "__main__": |
| """WordDeterminerのテスト(ストリーミング版)""" |
| print("=== WordDeterminerテスト(ストリーミング版) ===") |
| |
| try: |
| |
| determiner = WordDeterminer() |
| |
| |
| prompt_text = "電球を作ったのは誰?" |
| root_text = "" |
| |
| print(f"プロンプト: '{prompt_text}'") |
| print(f"ルートテキスト: '{root_text}'") |
|
|
| print("\nAIモデルテスト:") |
| prompt_text = "電球を作ったのは誰?" |
| root_text = "電球を作ったのは候補1:トマス" |
| try: |
| from ai import AI |
| |
| |
| model = AI.get_model() |
| print(f"モデル取得成功: {type(model)}") |
| |
| |
| test_text = prompt_text |
| tokens = model.get_token_probabilities(test_text, k=5) |
| print(f"トークン確率 ({test_text}): {tokens}") |
| |
| |
| print("\n単語ツリー構築テスト:") |
| completed_pieces = determiner.build_word_tree( |
| prompt_text=prompt_text, |
| root_text=root_text, |
| model=model, |
| top_k=3, |
| max_depth=5 |
| ) |
| |
| print(f"完成したピース数: {len(completed_pieces)}") |
| for i, piece in enumerate(completed_pieces): |
| full_text = piece.get_full_text() |
| print(f" ピース{i+1}: '{full_text}' (確率: {piece.probability:.4f})") |
| |
| except Exception as e: |
| print(f"AIモデルテスト失敗: {e}") |
|
|
| |
| print("\n単語数カウントテスト:") |
| test_texts = [ |
| "電球", |
| "電球を作った", |
| "電球を作ったのは", |
| "電球を作ったのは誰", |
| "電球を作ったのは誰?" |
| ] |
| |
| for text in test_texts: |
| word_count = determiner._count_words(text) |
| tokens = determiner._get_next_tokens_from_model(model, text) |
| print(f" '{text}' → {word_count}語: {tokens}") |
| |
| |
| print("\n単語確定テスト:") |
| test_sequence = ["電球", "電球を", "電球を作", "電球を作った", "電球を作ったの", "電球を作ったのは"] |
| prev_count = 0 |
| |
| for text in test_sequence: |
| current_count = determiner._count_words(text) |
| if current_count > prev_count: |
| print(f" '{text}' → {current_count}語 (確定!)") |
| prev_count = current_count |
| else: |
| print(f" '{text}' → {current_count}語 (継続)") |
| |
| |
| print("\n境界文字テスト:") |
| test_chars = [" ", "?", "、", "。", "a", "1"] |
| for char in test_chars: |
| is_boundary = determiner.is_boundary_char(char) |
| print(f" '{char}': {is_boundary}") |
| |
| |
| print("\nピース作成テスト:") |
| root = WordPiece(text="電球", probability=1.0) |
| child1 = root.add_child("を", 0.6) |
| child2 = root.add_child("の", 0.3) |
| |
| print(f"ルートテキスト: {root.get_full_text()}") |
| print(f"子1テキスト: {child1.get_full_text()}") |
| print(f"子2テキスト: {child2.get_full_text()}") |
| |
| print("\nテスト完了") |
| |
| except ImportError as e: |
| print(f"必要なライブラリがインストールされていません: {e}") |
| except Exception as e: |
| print(f"テストエラー: {e}") |
|
|