LLMView / package /word_processor.py
WatNeru's picture
Add FastAPI word tree server
09c17cd
from typing import List, Tuple, Any, Optional
from dataclasses import dataclass
from enum import Enum
import os
import math
from word_counter import WordCounter
from config import Config
class WordState(Enum):
"""単語の状態"""
INCOMPLETE = "incomplete" # 未完成
COMPLETE = "complete" # 完成
TRIGGER = "trigger" # トリガー(次語の開始)
class KList:
def __init__(self, num: int):
self.num = num
self.list: List[Any] = []
def check_k(self) -> None:
if len(self.list) >= self.num:
self.list.sort(key=lambda x: x.probability, reverse=True)
self.list = self.list[:self.num]
else:
self.list.sort(key=lambda x: x.probability, reverse=True)
def add(self, piece_word: Any) -> None:
# 重複チェック: 同じテキストのピースが既に存在するか確認
new_text = piece_word.get_full_text()
for existing_piece in self.list:
if existing_piece.get_full_text() == new_text:
# 既存のピースに確率を足す
existing_piece.probability += piece_word.probability
# 確率を更新したので、ソートし直す
self.check_k()
return
# 重複がない場合は追加
self.list.append(piece_word)
self.check_k()
def pop(self) -> Any:
if self.list:
return self.list.pop(0)
raise IndexError("List is empty")
def empty(self) -> bool:
return len(self.list) == 0
@dataclass
class WordPiece:
"""単語のピース(部分)"""
text: str # ピースのテキスト
probability: float # 確率
next_tokens: Optional[List[Tuple[str, float]]] = None # 次のトークン候補
parent: Optional['WordPiece'] = None # 親ピース
children: List['WordPiece'] = None # 子ピース
def __post_init__(self):
if self.children is None:
self.children = []
def get_full_text(self) -> str:
"""ルートからこのピースまでの完全なテキストを取得"""
pieces = []
current = self
while current is not None:
if current.text:
pieces.append(current.text)
current = current.parent
return "".join(reversed(pieces))
def get_full_word(self) -> str:
"""ルートの次語からこのピースまでの完全な単語を取得"""
pieces = []
current = self
while current is not None:
if current.text:
pieces.append(current.text)
current = current.parent
reversed_pieces = reversed(pieces[:-1])
return "".join(reversed_pieces)
def add_child(self, text: str, probability: float, next_tokens: Optional[List[Tuple[str, float]]] = None) -> 'WordPiece':
"""子ピースを追加"""
child = WordPiece(
text=text,
probability=probability,
next_tokens=next_tokens,
parent=self
)
self.children.append(child)
return child
def is_leaf(self) -> bool:
"""葉ノードかどうか"""
return len(self.children) == 0
def get_depth(self) -> int:
"""ルートからの深さを取得"""
depth = 0
current = self.parent
while current is not None:
depth += 1
current = current.parent
return depth
class WordDeterminer:
"""単語確定システム(ストリーミング向けリアルタイムアルゴリズム)"""
def __init__(self, word_counter: WordCounter = None):
"""
初期化
Args:
word_counter: WordCounterインスタンス(Noneの場合はデフォルトを使用)
"""
self.word_counter = word_counter or WordCounter()
def is_boundary_char(self, char: str) -> bool:
"""境界文字かどうかを判定(fugashi使用)"""
if not char:
return False
# 空白文字
if char.isspace():
return True
# 句読点
punctuation = ",,..。!?!?:;;、\n\t"
return char in punctuation
def is_word_boundary(self, text: str, position: int) -> bool:
"""
WordCounterを使用して単語境界を判定
Args:
text: テキスト
position: 位置(負の値で末尾から指定可能)
Returns:
bool: 単語境界かどうか
"""
return self.word_counter.is_word_boundary(text, position)
def check_word_completion(self, piece: WordPiece, root_count: int, model: Any = None) -> Tuple[WordState, Optional[Any]]:
"""
ストリーミング向けリアルタイム単語決定アルゴリズム
戦略:
1. 確率エントロピー: 次のトークンの不確実性を測定
2. 確率重み付き境界検出: 高確率トークンの挙動を重視
3. 信頼度ベース判定: 高確率トークンが明確に境界を示す場合のみ確定
アルゴリズム:
- エントロピーが低い(確率が集中)→ 単語継続の可能性が高い
- エントロピーが高い(確率が分散)→ 単語境界の可能性
- 高確率トークンが境界を示す → 確定
- 低確率トークンだけが境界を示す → 無視
Args:
piece: チェックするピース
root_count: ルートテキストの単語数
model: LLMモデル(必要に応じて)
Returns:
Tuple[WordState, Optional[Any]]: (状態, ペイロード)
"""
full_text = piece.get_full_text()
# next_tokensを取得
if not piece.next_tokens:
if model:
piece.next_tokens = self._get_next_tokens_from_model(model, full_text)
else:
return (WordState.COMPLETE, None)
if not piece.next_tokens:
return (WordState.COMPLETE, None)
# 確率順にソート(念のため)
sorted_tokens = sorted(piece.next_tokens, key=lambda x: x[1], reverse=True)
# sorted_tokens = piece.next_tokens
if piece.get_full_word()[-1] in ["(","「","(","【","〈","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅","《","[","{","⦅"]:
return (WordState.INCOMPLETE, None)
if piece.get_full_word()[-1] in [")","]","}","》","〉","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆","》","]","}","⦆"]:
return (WordState.COMPLETE, None)
# 2.全トークンの挙動を確認
count = max(1, len(sorted_tokens) )
tokens = sorted_tokens[:count]
boundary_prob = 0.0 # 境界を示すトークンの確率合計
continuation_prob = 0.0 # 継続を示すトークンの確率合計
total = sum(prob for _, prob in tokens)
for token, prob in tokens:
test_text = full_text + token
test_word_count = self._count_words(test_text)
# 単語数がより多く増えた場合のみ境界と判定(まとまりを上げる)
if test_word_count > root_count + 1:
boundary_prob += prob
else:
continuation_prob += prob
# 3. 判定ロジック
if total > 0:
boundary_ratio = boundary_prob / total
# トークンの多くが境界を示す場合 → 確定(閾値を上げてまとまりを上げる)
if boundary_ratio > 0.85:
return (WordState.COMPLETE, None)
# トークンの多くが継続を示す場合 → 継続(閾値を下げて継続しやすく)
if boundary_ratio < 0.2:
return (WordState.INCOMPLETE, None)
# 1. 確率エントロピーを計算
probs = [prob for _, prob in sorted_tokens]
entropy = -sum(p * math.log(p + 1e-10) for p in probs if p > 0)
max_entropy = math.log(len(sorted_tokens)) if len(sorted_tokens) > 1 else 1.0
normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0
# 4. エントロピーベース判定
# エントロピーが低い(確率が集中)→ 単語継続の可能性
# エントロピーが高い(確率が分散)→ 単語境界の可能性
return (WordState.INCOMPLETE, None)
def _count_words(self, text: str) -> int:
"""
WordCounterを使用してテキストの単語数をカウント
Args:
text: カウントするテキスト
Returns:
int: 単語数
"""
return self.word_counter.count_words(text)
def _get_next_tokens_from_model(self, model: Any, text: str, top_k: int = 5) -> List[Tuple[str, float]]:
"""
モデルから次のトークン候補を取得(常駐AIモデルを使用)
Args:
model: LLMモデル(パス文字列またはモデルオブジェクト)
text: 入力テキスト
top_k: 取得する候補数
Returns:
List[Tuple[str, float]]: (トークン, 確率)のリスト
"""
try:
# AIクラスをインポート
from ai import AI
# モデルパスを取得
if isinstance(model, str):
model_path = model
elif hasattr(model, 'model_path'):
model_path = model.model_path
else:
# デフォルトモデルを使用
model_path = None
# 常駐AIモデルを使用
ai_model = AI.get_model(model_path)
return ai_model.get_token_probabilities(text, top_k)
except Exception as e:
print(f"モデルからのトークン取得に失敗: {e}")
return []
def expand_piece(self, piece: WordPiece, model: Any = None) -> List[WordPiece]:
"""
ピースを展開して子ピースを生成
Args:
piece: 展開するピース
model: LLMモデル
Returns:
List[WordPiece]: 生成された子ピースのリスト
"""
children = []
full_text = piece.get_full_text()
#1#print(f"[WORD_PROCESSOR_STREAMING] expand_piece: '{full_text}'")
if piece.next_tokens:
# 既存のnext_tokensを使用
#1#print(f"[WORD_PROCESSOR_STREAMING] Using existing next_tokens: {len(piece.next_tokens)}")
for token, prob in piece.next_tokens:
# 空文字列トークンを無視
if not token:
continue
child_prob = piece.probability * prob
child = piece.add_child(token, child_prob)
children.append(child)
elif model:
# モデルから次のトークンを取得
#1#print(f"[WORD_PROCESSOR_STREAMING] Getting tokens from model for: '{full_text}'")
next_tokens = self._get_next_tokens_from_model(model, full_text)
#1#print(f"[WORD_PROCESSOR_STREAMING] Got {len(next_tokens)} tokens from model")
if next_tokens:
piece.next_tokens = next_tokens
for token, prob in next_tokens:
# 空文字列トークンを無視
if not token:
continue
child_prob = piece.probability * prob
child = piece.add_child(token, child_prob)
children.append(child)
else:
print(f"[WORD_PROCESSOR_STREAMING] No model provided for expansion")
#1#print(f"[WORD_PROCESSOR_STREAMING] Generated {len(children)} children")
return children
def build_word_tree(self, prompt_text: str, root_text: str, model: Any, top_k: int = 5, max_depth: int = 10) -> List[WordPiece]:
"""
単語ツリーを構築
Args:
root_text: ルートテキスト
model: LLMモデル
top_k: 取得する候補数
max_depth: 最大深さ
Returns:
List[WordPiece]: 完成した単語ピースのリスト
"""
#1#print(f"[WORD_PROCESSOR_STREAMING] build_word_tree called: prompt='{prompt_text}', root='{root_text}', top_k={top_k}")
# ルートピースを作成
root = WordPiece(text=self.build_chat_prompt(prompt_text, )+root_text, probability=1.0)
#1#print(f"[WORD_PROCESSOR_STREAMING] Root piece created: '{root.get_full_text()}'")
# 優先度付きキュー(確率順)
candidates = KList(2*top_k)
completed = []
iteration = 0
max_iterations = 1000
children = self.expand_piece(root, model)
#1#print(f"[WORD_PROCESSOR_STREAMING] Initial children: {len(children)}")
for child in children:
candidates.add(child)
while not candidates.empty() and iteration < max_iterations and len(completed) < top_k:
iteration += 1
# 最も確率の高い候補を取得
current = candidates.pop()
# # 深さ制限チェック
# if current.get_depth() >= max_depth:
# completed.append(current)
# continue
# 単語完成状態をチェック
root_count = self._count_words(root.get_full_text())
state, payload = self.check_word_completion(current, root_count, model)
if state == WordState.COMPLETE:
completed.append(current)
# print(f"☆☆☆☆☆complete: {current.get_full_text()}")
elif state == WordState.INCOMPLETE:
# ピースを展開
children = self.expand_piece(current, model)
if len(children) == 0:
# 子が生成できない場合、ピースを完成として扱う(無限ループ防止)
print(f"[WORD_PROCESSOR_STREAMING] No children generated for '{current.get_full_text()}', marking as COMPLETE")
completed.append(current)
else:
for child in children:
candidates.add(child)
# print(f"☆☆☆☆☆while end{len(completed),candidates.empty(),iteration}")
# 確率で正規化
total_prob = sum(p.probability for p in completed)
if total_prob > 0:
for piece in completed:
piece.probability = piece.probability / total_prob
return completed[:top_k]
def build_chat_prompt(self, user_content: str,
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください") -> str:
"""
チャットプロンプトを構築
注意: Rust側で既に整形済みのプロンプトが渡される場合は、そのまま返す
後方互換性のため、単一のuser_contentが渡された場合は従来の形式で整形
"""
# Rust側で既に整形済みのプロンプトが渡されている場合(複数行、ヘッダーを含む)
# そのまま返す
if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content:
return user_content
# 後方互換性: 単一のuser_contentが渡された場合の従来の形式
prompt_text = (
f"<|begin_of_text|>"
f"<|start_header_id|>system<|end_header_id|>\n"
f"{system_content}\n<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n"
f"{user_content}\n<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n"
)
# BOS(<|begin_of_text|>) の重複を抑止: 先頭のBOSを全て除去
# llama-cpp 側でBOSが自動付与されるため、ここでは付与しない
BOS = "<|begin_of_text|>"
s = prompt_text.lstrip()
while s.startswith(BOS):
s = s[len(BOS):]
prompt_text = s
return prompt_text
if __name__ == "__main__":
"""WordDeterminerのテスト(ストリーミング版)"""
print("=== WordDeterminerテスト(ストリーミング版) ===")
try:
# WordDeterminerを初期化
determiner = WordDeterminer()
# プロンプト設定
prompt_text = "電球を作ったのは誰?"
root_text = ""
print(f"プロンプト: '{prompt_text}'")
print(f"ルートテキスト: '{root_text}'")
print("\nAIモデルテスト:")
prompt_text = "電球を作ったのは誰?"
root_text = "電球を作ったのは候補1:トマス"
try:
from ai import AI
# モデルを取得
model = AI.get_model()
print(f"モデル取得成功: {type(model)}")
# トークン確率取得テスト
test_text = prompt_text
tokens = model.get_token_probabilities(test_text, k=5)
print(f"トークン確率 ({test_text}): {tokens}")
# 単語ツリー構築テスト
print("\n単語ツリー構築テスト:")
completed_pieces = determiner.build_word_tree(
prompt_text=prompt_text,
root_text=root_text,
model=model,
top_k=3,
max_depth=5
)
print(f"完成したピース数: {len(completed_pieces)}")
for i, piece in enumerate(completed_pieces):
full_text = piece.get_full_text()
print(f" ピース{i+1}: '{full_text}' (確率: {piece.probability:.4f})")
except Exception as e:
print(f"AIモデルテスト失敗: {e}")
# 単語数カウントテスト
print("\n単語数カウントテスト:")
test_texts = [
"電球",
"電球を作った",
"電球を作ったのは",
"電球を作ったのは誰",
"電球を作ったのは誰?"
]
for text in test_texts:
word_count = determiner._count_words(text)
tokens = determiner._get_next_tokens_from_model(model, text)
print(f" '{text}' → {word_count}語: {tokens}")
# 単語確定テスト
print("\n単語確定テスト:")
test_sequence = ["電球", "電球を", "電球を作", "電球を作った", "電球を作ったの", "電球を作ったのは"]
prev_count = 0
for text in test_sequence:
current_count = determiner._count_words(text)
if current_count > prev_count:
print(f" '{text}' → {current_count}語 (確定!)")
prev_count = current_count
else:
print(f" '{text}' → {current_count}語 (継続)")
# 境界文字テスト
print("\n境界文字テスト:")
test_chars = [" ", "?", "、", "。", "a", "1"]
for char in test_chars:
is_boundary = determiner.is_boundary_char(char)
print(f" '{char}': {is_boundary}")
# ピース作成テスト
print("\nピース作成テスト:")
root = WordPiece(text="電球", probability=1.0)
child1 = root.add_child("を", 0.6)
child2 = root.add_child("の", 0.3)
print(f"ルートテキスト: {root.get_full_text()}")
print(f"子1テキスト: {child1.get_full_text()}")
print(f"子2テキスト: {child2.get_full_text()}")
print("\nテスト完了")
except ImportError as e:
print(f"必要なライブラリがインストールされていません: {e}")
except Exception as e:
print(f"テストエラー: {e}")