Spaces:
Sleeping
Sleeping
| from typing import List, Tuple, Any, Optional | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| import os | |
| import math | |
| from .word_counter import WordCounter | |
| from .config import Config | |
| from .ai.base import BaseAI | |
| class WordState(Enum): | |
| """単語の状態""" | |
| INCOMPLETE = "incomplete" # 未完成 | |
| COMPLETE = "complete" # 完成 | |
| TRIGGER = "trigger" # トリガー(次語の開始) | |
| class KList: | |
| def __init__(self, num: int): | |
| self.num = num | |
| self.list: List[Any] = [] | |
| def check_k(self) -> None: | |
| if len(self.list) >= self.num: | |
| self.list.sort(key=lambda x: x.probability, reverse=True) | |
| self.list = self.list[:self.num] | |
| else: | |
| self.list.sort(key=lambda x: x.probability, reverse=True) | |
| def add(self, piece_word: Any) -> None: | |
| # 重複チェック: 同じテキストのピースが既に存在するか確認 | |
| new_text = piece_word.get_full_text() | |
| for existing_piece in self.list: | |
| if existing_piece.get_full_text() == new_text: | |
| # 既存のピースに確率を足す | |
| existing_piece.probability += piece_word.probability | |
| # 確率を更新したので、ソートし直す | |
| self.check_k() | |
| return | |
| # 重複がない場合は追加 | |
| self.list.append(piece_word) | |
| self.check_k() | |
| def pop(self) -> Any: | |
| if self.list: | |
| return self.list.pop(0) | |
| raise IndexError("List is empty") | |
| def empty(self) -> bool: | |
| return len(self.list) == 0 | |
| class WordPiece: | |
| """単語のピース(部分)""" | |
| text: str # ピースのテキスト | |
| probability: float # 確率 | |
| next_tokens: Optional[List[Tuple[str, float]]] = None # 次のトークン候補 | |
| parent: Optional['WordPiece'] = None # 親ピース | |
| children: List['WordPiece'] = None # 子ピース | |
| def __post_init__(self): | |
| if self.children is None: | |
| self.children = [] | |
| def get_full_text(self) -> str: | |
| """ルートからこのピースまでの完全なテキストを取得""" | |
| pieces = [] | |
| current = self | |
| while current is not None: | |
| if current.text: | |
| pieces.append(current.text) | |
| current = current.parent | |
| return "".join(reversed(pieces)) | |
| def get_full_word(self) -> str: | |
| """ルートの次語からこのピースまでの完全な単語を取得""" | |
| pieces = [] | |
| current = self | |
| while current is not None: | |
| if current.text: | |
| pieces.append(current.text) | |
| current = current.parent | |
| reversed_pieces = reversed(pieces[:-1]) | |
| return "".join(reversed_pieces) | |
| def add_child(self, text: str, probability: float, next_tokens: Optional[List[Tuple[str, float]]] = None) -> 'WordPiece': | |
| """子ピースを追加""" | |
| child = WordPiece( | |
| text=text, | |
| probability=probability, | |
| next_tokens=next_tokens, | |
| parent=self | |
| ) | |
| self.children.append(child) | |
| return child | |
| def is_leaf(self) -> bool: | |
| """葉ノードかどうか""" | |
| return len(self.children) == 0 | |
| def get_depth(self) -> int: | |
| """ルートからの深さを取得""" | |
| depth = 0 | |
| current = self.parent | |
| while current is not None: | |
| depth += 1 | |
| current = current.parent | |
| return depth | |
| class WordDeterminer: | |
| """単語確定システム(ストリーミング向けリアルタイムアルゴリズム)""" | |
| def __init__(self, word_counter: WordCounter = None): | |
| """ | |
| 初期化 | |
| Args: | |
| word_counter: WordCounterインスタンス(Noneの場合はデフォルトを使用) | |
| """ | |
| self.word_counter = word_counter or WordCounter() | |
| def is_boundary_char(self, char: str) -> bool: | |
| """境界文字かどうかを判定(fugashi使用)""" | |
| if not char: | |
| return False | |
| # 空白文字 | |
| if char.isspace(): | |
| return True | |
| # 句読点 | |
| punctuation = ",,..。!?!?:;;、\n\t" | |
| return char in punctuation | |
| def is_word_boundary(self, text: str, position: int) -> bool: | |
| """ | |
| WordCounterを使用して単語境界を判定 | |
| Args: | |
| text: テキスト | |
| position: 位置(負の値で末尾から指定可能) | |
| Returns: | |
| bool: 単語境界かどうか | |
| """ | |
| return self.word_counter.is_word_boundary(text, position) | |
| def check_word_completion(self, piece: WordPiece, root_count: int, model: Any = None) -> Tuple[WordState, Optional[Any]]: | |
| """ | |
| ストリーミング向けリアルタイム単語決定アルゴリズム | |
| Args: | |
| piece: チェックするピース | |
| root_count: ルートテキストの単語数 | |
| model: LLMモデル(BaseAIを実装したオブジェクト) | |
| Returns: | |
| Tuple[WordState, Optional[Any]]: (状態, ペイロード) | |
| """ | |
| full_text = piece.get_full_text() | |
| # next_tokensを取得 | |
| if not piece.next_tokens: | |
| if model: | |
| piece.next_tokens = self._get_next_tokens_from_model(model, full_text) | |
| else: | |
| return (WordState.COMPLETE, None) | |
| if not piece.next_tokens: | |
| return (WordState.COMPLETE, None) | |
| # 確率順にソート(念のため) | |
| sorted_tokens = sorted(piece.next_tokens, key=lambda x: x[1], reverse=True) | |
| # 括弧の処理 | |
| if piece.get_full_word() and piece.get_full_word()[-1] in ["(","「","(","【","〈","《","[","{","⦅"]: | |
| return (WordState.INCOMPLETE, None) | |
| if piece.get_full_word() and piece.get_full_word()[-1] in [")","]","}","》","〉","》","]","}","⦆"]: | |
| return (WordState.COMPLETE, None) | |
| # 全トークンの挙動を確認 | |
| count = max(1, len(sorted_tokens)) | |
| tokens = sorted_tokens[:count] | |
| boundary_prob = 0.0 # 境界を示すトークンの確率合計 | |
| continuation_prob = 0.0 # 継続を示すトークンの確率合計 | |
| total = sum(prob for _, prob in tokens) | |
| for token, prob in tokens: | |
| test_text = full_text + token | |
| test_word_count = self._count_words(test_text) | |
| # 単語数がより多く増えた場合のみ境界と判定(まとまりを上げる) | |
| if test_word_count > root_count + 1: | |
| boundary_prob += prob | |
| else: | |
| continuation_prob += prob | |
| # 判定ロジック | |
| if total > 0: | |
| boundary_ratio = boundary_prob / total | |
| # トークンの多くが境界を示す場合 → 確定(閾値を上げてまとまりを上げる) | |
| if boundary_ratio > 0.85: | |
| return (WordState.COMPLETE, None) | |
| # トークンの多くが継続を示す場合 → 継続(閾値を下げて継続しやすく) | |
| if boundary_ratio < 0.2: | |
| return (WordState.INCOMPLETE, None) | |
| # エントロピーベース判定 | |
| probs = [prob for _, prob in sorted_tokens] | |
| entropy = -sum(p * math.log(p + 1e-10) for p in probs if p > 0) | |
| max_entropy = math.log(len(sorted_tokens)) if len(sorted_tokens) > 1 else 1.0 | |
| normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0 | |
| return (WordState.INCOMPLETE, None) | |
| def _count_words(self, text: str) -> int: | |
| """ | |
| WordCounterを使用してテキストの単語数をカウント | |
| Args: | |
| text: カウントするテキスト | |
| Returns: | |
| int: 単語数 | |
| """ | |
| return self.word_counter.count_words(text) | |
| def _get_next_tokens_from_model(self, model: Any, text: str, top_k: int = 5) -> List[Tuple[str, float]]: | |
| """ | |
| モデルから次のトークン候補を取得(新しいBaseAIインターフェースを使用) | |
| Args: | |
| model: BaseAIを実装したモデルオブジェクト | |
| text: 入力テキスト | |
| top_k: 取得する候補数 | |
| Returns: | |
| List[Tuple[str, float]]: (トークン, 確率)のリスト | |
| """ | |
| try: | |
| # BaseAIインターフェースを実装したモデルを使用 | |
| if isinstance(model, BaseAI): | |
| return model.get_token_probabilities(text, top_k) | |
| else: | |
| print(f"[WORD_PROCESSOR] モデルがBaseAIインターフェースを実装していません: {type(model)}") | |
| return [] | |
| except Exception as e: | |
| print(f"[WORD_PROCESSOR] モデルからのトークン取得に失敗: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return [] | |
| def expand_piece(self, piece: WordPiece, model: Any = None) -> List[WordPiece]: | |
| """ | |
| ピースを展開して子ピースを生成 | |
| Args: | |
| piece: 展開するピース | |
| model: LLMモデル(BaseAIを実装したオブジェクト) | |
| Returns: | |
| List[WordPiece]: 生成された子ピースのリスト | |
| """ | |
| children = [] | |
| full_text = piece.get_full_text() | |
| if piece.next_tokens: | |
| # 既存のnext_tokensを使用 | |
| for token, prob in piece.next_tokens: | |
| # 空文字列トークンを無視 | |
| if not token: | |
| continue | |
| child_prob = piece.probability * prob | |
| child = piece.add_child(token, child_prob) | |
| children.append(child) | |
| elif model: | |
| # モデルから次のトークンを取得 | |
| next_tokens = self._get_next_tokens_from_model(model, full_text) | |
| if next_tokens: | |
| piece.next_tokens = next_tokens | |
| for token, prob in next_tokens: | |
| # 空文字列トークンを無視 | |
| if not token: | |
| continue | |
| child_prob = piece.probability * prob | |
| child = piece.add_child(token, child_prob) | |
| children.append(child) | |
| else: | |
| print(f"[WORD_PROCESSOR] No model provided for expansion") | |
| return children | |
| def build_word_tree(self, prompt_text: str, root_text: str, model: Any, top_k: int = 5, max_depth: int = 10) -> List[WordPiece]: | |
| """ | |
| 単語ツリーを構築 | |
| Args: | |
| prompt_text: プロンプトテキスト | |
| root_text: ルートテキスト | |
| model: LLMモデル(BaseAIを実装したオブジェクト) | |
| top_k: 取得する候補数 | |
| max_depth: 最大深さ | |
| Returns: | |
| List[WordPiece]: 完成した単語ピースのリスト | |
| """ | |
| # モデルのbuild_chat_promptメソッドを使用 | |
| if isinstance(model, BaseAI): | |
| prompt = model.build_chat_prompt(prompt_text) | |
| else: | |
| # フォールバック: 従来の形式 | |
| prompt = self.build_chat_prompt(prompt_text) | |
| # ルートピースを作成 | |
| root = WordPiece(text=prompt + root_text, probability=1.0) | |
| # 優先度付きキュー(確率順) | |
| candidates = KList(2 * top_k) | |
| completed = [] | |
| iteration = 0 | |
| max_iterations = 1000 | |
| children = self.expand_piece(root, model) | |
| for child in children: | |
| candidates.add(child) | |
| while not candidates.empty() and iteration < max_iterations and len(completed) < top_k: | |
| iteration += 1 | |
| # 最も確率の高い候補を取得 | |
| current = candidates.pop() | |
| # 単語完成状態をチェック | |
| root_count = self._count_words(root.get_full_text()) | |
| state, payload = self.check_word_completion(current, root_count, model) | |
| if state == WordState.COMPLETE: | |
| completed.append(current) | |
| elif state == WordState.INCOMPLETE: | |
| # ピースを展開 | |
| children = self.expand_piece(current, model) | |
| if len(children) == 0: | |
| # 子が生成できない場合、ピースを完成として扱う(無限ループ防止) | |
| print(f"[WORD_PROCESSOR] No children generated for '{current.get_full_text()}', marking as COMPLETE") | |
| completed.append(current) | |
| else: | |
| for child in children: | |
| candidates.add(child) | |
| # 確率で正規化 | |
| total_prob = sum(p.probability for p in completed) | |
| if total_prob > 0: | |
| for piece in completed: | |
| piece.probability = piece.probability / total_prob | |
| return completed[:top_k] | |
| def build_chat_prompt(self, user_content: str, | |
| system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください") -> str: | |
| """ | |
| チャットプロンプトを構築(後方互換性のため) | |
| 注意: 新しいBaseAIインターフェースを使用する場合は、model.build_chat_prompt()を使用してください | |
| """ | |
| # 既に整形済みのプロンプトが渡されている場合(複数行、ヘッダーを含む) | |
| # そのまま返す | |
| if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content: | |
| return user_content | |
| # 後方互換性: 単一のuser_contentが渡された場合の従来の形式 | |
| prompt_text = ( | |
| f"<|begin_of_text|>" | |
| f"<|start_header_id|>system<|end_header_id|>\n" | |
| f"{system_content}\n<|eot_id|>" | |
| f"<|start_header_id|>user<|end_header_id|>\n" | |
| f"{user_content}\n<|eot_id|>" | |
| f"<|start_header_id|>assistant<|end_header_id|>\n" | |
| ) | |
| # BOS(<|begin_of_text|>) の重複を抑止: 先頭のBOSを全て除去 | |
| BOS = "<|begin_of_text|>" | |
| s = prompt_text.lstrip() | |
| while s.startswith(BOS): | |
| s = s[len(BOS):] | |
| prompt_text = s | |
| return prompt_text | |