Spaces:
Sleeping
Sleeping
first commit
Browse files- MULTI_MODEL_DESIGN.md +378 -0
- README.md +122 -13
- app.py +239 -0
- package/__init__.py +5 -0
- package/adapter.py +93 -0
- package/ai/__init__.py +80 -0
- package/ai/anthropic_ai.py +180 -0
- package/ai/base.py +99 -0
- package/ai/google_ai.py +152 -0
- package/ai/openai_ai.py +181 -0
- package/ai/transformers_ai.py +278 -0
- package/config.py +36 -0
- package/word_counter.py +115 -0
- package/word_processor.py +392 -0
- requirements.txt +37 -0
MULTI_MODEL_DESIGN.md
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# マルチモデル対応設計案
|
| 2 |
+
|
| 3 |
+
## 概要
|
| 4 |
+
現在のLLMViewはLlama 3.2 3B(transformers)のみ対応していますが、他のTransformersモデル(Qwen、Mistral、Gemma等)にも対応できるように拡張します。
|
| 5 |
+
|
| 6 |
+
**重要**: Hugging Face Spacesでの使用を前提とする場合、**APIを使う必要はありません**。
|
| 7 |
+
Transformersライブラリでローカルにモデルをロードする方法(TransformersAI)を使用してください。
|
| 8 |
+
これにより、完全なトークン確率情報が取得でき、コストもかかりません。
|
| 9 |
+
|
| 10 |
+
外部API(OpenAI、Anthropic、Google)のサポートは、ローカル環境やテスト目的でのみ使用することを想定しています。
|
| 11 |
+
|
| 12 |
+
## アーキテクチャ設計
|
| 13 |
+
|
| 14 |
+
### 1. アダプターパターンの導入
|
| 15 |
+
|
| 16 |
+
```
|
| 17 |
+
AI (基底クラス/インターフェース)
|
| 18 |
+
├── TransformersAI (現在の実装 - Llama等のローカルモデル)
|
| 19 |
+
├── OpenAIAI (ChatGPT API)
|
| 20 |
+
├── AnthropicAI (Claude API)
|
| 21 |
+
├── GoogleAI (Gemini API)
|
| 22 |
+
└── HuggingFaceInferenceAI (Hugging Face Inference API)
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### 2. 統一インターフェース
|
| 26 |
+
|
| 27 |
+
すべてのモデルアダプターが実装すべきメソッド:
|
| 28 |
+
|
| 29 |
+
```python
|
| 30 |
+
class BaseAI:
|
| 31 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 32 |
+
"""
|
| 33 |
+
テキストから次のトークン候補と確率を取得
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 37 |
+
"""
|
| 38 |
+
raise NotImplementedError
|
| 39 |
+
|
| 40 |
+
def build_chat_prompt(self, user_content: str, system_content: str = "") -> str:
|
| 41 |
+
"""
|
| 42 |
+
モデル固有のチャットプロンプト形式に変換
|
| 43 |
+
"""
|
| 44 |
+
raise NotImplementedError
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
### 3. モデルタイプの識別
|
| 48 |
+
|
| 49 |
+
環境変数または設定ファイルでモデルタイプを指定:
|
| 50 |
+
|
| 51 |
+
```python
|
| 52 |
+
MODEL_TYPE = os.getenv("MODEL_TYPE", "transformers") # transformers, openai, anthropic, google, hf_inference
|
| 53 |
+
MODEL_PATH = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct") # モデル識別子(Hugging FaceリポジトリID)
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
**Hugging Face Spacesでの推奨設定**:
|
| 57 |
+
```python
|
| 58 |
+
# 環境変数例(Hugging Face Spaces用)
|
| 59 |
+
MODEL_TYPE=transformers
|
| 60 |
+
HF_MODEL_REPO=Qwen/Qwen2.5-3B-Instruct # または他のモデル
|
| 61 |
+
HF_TOKEN=your_hf_token # プライベートモデルの場合
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### 4. 各モデルの特徴と実装方針
|
| 65 |
+
|
| 66 |
+
#### 4.1 TransformersAI (現在の実装)
|
| 67 |
+
- **特徴**: ローカルでモデルをロード、logitsから直接確率を取得可能
|
| 68 |
+
- **利点**: 完全なトークン確率情報が利用可能
|
| 69 |
+
- **実装**: 現在の`AI`クラスを`TransformersAI`にリネーム
|
| 70 |
+
|
| 71 |
+
#### 4.2 OpenAIAI (ChatGPT)
|
| 72 |
+
- **特徴**: API経由、`logprobs`パラメータでトークン確率を取得可能
|
| 73 |
+
- **API**: `openai.ChatCompletion.create()` の `logprobs=True`
|
| 74 |
+
- **制約**:
|
| 75 |
+
- トークン確率は`logprobs`で取得可能(GPT-4以降)
|
| 76 |
+
- リクエストごとにAPIコールが必要
|
| 77 |
+
- レート制限とコストが発生
|
| 78 |
+
- **実装方針**:
|
| 79 |
+
```python
|
| 80 |
+
response = openai.ChatCompletion.create(
|
| 81 |
+
model="gpt-4",
|
| 82 |
+
messages=[...],
|
| 83 |
+
logprobs=True,
|
| 84 |
+
top_logprobs=5
|
| 85 |
+
)
|
| 86 |
+
# logprobsから確率を計算
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
#### 4.3 AnthropicAI (Claude)
|
| 90 |
+
- **特徴**: API経由、`logprobs`パラメータでトークン確率を取得可能(Claude 3.5以降)
|
| 91 |
+
- **API**: `anthropic.Anthropic().messages.create()` の `logprobs=True`
|
| 92 |
+
- **制約**:
|
| 93 |
+
- トークン確率は`logprobs`で取得可能
|
| 94 |
+
- リクエストごとにAPIコールが必要
|
| 95 |
+
- **実装方針**:
|
| 96 |
+
```python
|
| 97 |
+
response = client.messages.create(
|
| 98 |
+
model="claude-3-5-sonnet-20241022",
|
| 99 |
+
messages=[...],
|
| 100 |
+
logprobs=True,
|
| 101 |
+
top_logprobs=5
|
| 102 |
+
)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
#### 4.4 GoogleAI (Gemini)
|
| 106 |
+
- **特徴**: API経由、`logprobs`パラメータでトークン確率を取得可能
|
| 107 |
+
- **API**: `google.generativeai.GenerativeModel.generate_content()` の `logprobs=True`
|
| 108 |
+
- **制約**:
|
| 109 |
+
- トークン確率は`logprobs`で取得可能
|
| 110 |
+
- リクエストごとにAPIコールが必要
|
| 111 |
+
- **実装方針**:
|
| 112 |
+
```python
|
| 113 |
+
response = model.generate_content(
|
| 114 |
+
prompt,
|
| 115 |
+
generation_config={"logprobs": True, "top_k": 5}
|
| 116 |
+
)
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
#### 4.5 HuggingFaceInferenceAI
|
| 120 |
+
- **特徴**: Hugging Face Inference API経由、一部のモデルでlogits取得可能
|
| 121 |
+
- **API**: `huggingface_hub.InferenceClient.text_generation()` の `details=True`
|
| 122 |
+
- **制約**:
|
| 123 |
+
- すべてのモデルでlogitsが利用可能とは限らない
|
| 124 |
+
- リクエストごとにAPIコールが必要
|
| 125 |
+
- 無料プランにはレート制限あり
|
| 126 |
+
|
| 127 |
+
#### 4.6 Hugging Face Spacesで利用可能なモデル(Transformers経由)
|
| 128 |
+
|
| 129 |
+
Hugging Face Spacesでは、以下のモデルをTransformersライブラリで直接利用可能:
|
| 130 |
+
|
| 131 |
+
##### 4.6.1 日本語対応モデル(推奨)
|
| 132 |
+
|
| 133 |
+
| モデル | リポジトリID | サイズ | 特徴 | トークン確率取得 |
|
| 134 |
+
|--------|------------|--------|------|----------------|
|
| 135 |
+
| **Llama 3.2 3B Instruct** | `meta-llama/Llama-3.2-3B-Instruct` | 3B | 多言語対応、現在使用中 | ✅ 完全対応 |
|
| 136 |
+
| **Qwen 2.5** | `Qwen/Qwen2.5-3B-Instruct` | 3B | 日本語に強い、高性能 | ✅ 完全対応 |
|
| 137 |
+
| **Mistral 7B Instruct** | `mistralai/Mistral-7B-Instruct-v0.2` | 7B | 高性能、多言語対応 | ✅ 完全対応 |
|
| 138 |
+
| **Gemma 2B/7B** | `google/gemma-2b-it`, `google/gemma-7b-it` | 2B/7B | Google製、軽量 | ✅ 完全対応 |
|
| 139 |
+
| **Phi-3** | `microsoft/Phi-3-mini-4k-instruct` | 3.8B | 軽量、高性能 | ✅ 完全対応 |
|
| 140 |
+
| **TinyLlama** | `TinyLlama/TinyLlama-1.1B-Chat-v1.0` | 1.1B | 超軽量 | ✅ 完全対応 |
|
| 141 |
+
|
| 142 |
+
##### 4.6.2 日本語特化モデル
|
| 143 |
+
|
| 144 |
+
| モデル | リポジトリID | サイズ | 特徴 | トークン確率取得 |
|
| 145 |
+
|--------|------------|--------|------|----------------|
|
| 146 |
+
| **ELYZA-japanese-Llama-2** | `elyza/ELYZA-japanese-Llama-2-7b-instruct` | 7B | 日本語特化 | ✅ 完全対応 |
|
| 147 |
+
| **japanese-stablelm** | `stabilityai/japanese-stablelm-base-gamma-7b` | 7B | 日本語特化 | ✅ 完全対応 |
|
| 148 |
+
| **weblab-10b** | `rinna/weblab-10b-instruction-sft` | 10B | 日本語特化、大規模 | ✅ 完全対応 |
|
| 149 |
+
|
| 150 |
+
##### 4.6.3 その他の主要モデル
|
| 151 |
+
|
| 152 |
+
| モデル | リポジトリID | サイズ | 特徴 | トークン確率取得 |
|
| 153 |
+
|--------|------------|--------|------|----------------|
|
| 154 |
+
| **Falcon** | `tiiuae/falcon-7b-instruct` | 7B | オープンソース | ✅ 完全対応 |
|
| 155 |
+
| **MPT** | `mosaicml/mpt-7b-instruct` | 7B | 商用利用可能 | ✅ 完全対応 |
|
| 156 |
+
| **StarCoder** | `bigcode/starcoder2-7b` | 7B | コード生成特化 | ✅ 完全対応 |
|
| 157 |
+
|
| 158 |
+
##### 4.6.4 Hugging Face Inference APIで利用可能なモデル
|
| 159 |
+
|
| 160 |
+
**注意**: Inference APIでは、すべてのモデルでトークン確率(logits)が取得できるわけではありません。
|
| 161 |
+
以下のモデルはInference API経由でも利用可能ですが、トークン確率の取得はモデルによって異なります:
|
| 162 |
+
|
| 163 |
+
- **無料プラン**: 制限あり、一部モデルのみ
|
| 164 |
+
- **有料プラン**: より多くのモデルにアクセス可能
|
| 165 |
+
|
| 166 |
+
**推奨アプローチ**:
|
| 167 |
+
- Hugging Face Spacesでは、**Transformersライブラリで直接モデルをロード**する方法を推奨
|
| 168 |
+
- これにより、完全なトークン確率情報が取得可能
|
| 169 |
+
- Inference APIは、モデルをローカルにロードできない場合の代替手段
|
| 170 |
+
|
| 171 |
+
### 5. プロンプトフォーマットの統一
|
| 172 |
+
|
| 173 |
+
各モデルに適したプロンプト形式に変換する`build_chat_prompt`メソッドを実装:
|
| 174 |
+
|
| 175 |
+
```python
|
| 176 |
+
# Llama 3.2形式
|
| 177 |
+
"<|start_header_id|>system<|end_header_id|>\n{system}\n<|eot_id|>..."
|
| 178 |
+
|
| 179 |
+
# OpenAI形式
|
| 180 |
+
[
|
| 181 |
+
{"role": "system", "content": system},
|
| 182 |
+
{"role": "user", "content": user}
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
# Claude形式
|
| 186 |
+
[
|
| 187 |
+
{"role": "user", "content": f"{system}\n\n{user}"}
|
| 188 |
+
]
|
| 189 |
+
|
| 190 |
+
# Gemini形式
|
| 191 |
+
f"{system}\n\n{user}"
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### 6. 設定管理の拡張
|
| 195 |
+
|
| 196 |
+
`config.py`または環境変数で管理:
|
| 197 |
+
|
| 198 |
+
```python
|
| 199 |
+
# Hugging Face Spaces用(推奨)
|
| 200 |
+
MODEL_TYPE=transformers
|
| 201 |
+
HF_MODEL_REPO=Qwen/Qwen2.5-3B-Instruct # または meta-llama/Llama-3.2-3B-Instruct
|
| 202 |
+
HF_TOKEN=your_hf_token # プライベートモデルの場合のみ
|
| 203 |
+
|
| 204 |
+
# OpenAI API用
|
| 205 |
+
MODEL_TYPE=openai
|
| 206 |
+
OPENAI_API_KEY=sk-...
|
| 207 |
+
OPENAI_MODEL=gpt-4
|
| 208 |
+
|
| 209 |
+
# Anthropic API用
|
| 210 |
+
MODEL_TYPE=anthropic
|
| 211 |
+
ANTHROPIC_API_KEY=sk-ant-...
|
| 212 |
+
ANTHROPIC_MODEL=claude-3-5-sonnet-20241022
|
| 213 |
+
|
| 214 |
+
# Google API用
|
| 215 |
+
MODEL_TYPE=google
|
| 216 |
+
GOOGLE_API_KEY=...
|
| 217 |
+
GOOGLE_MODEL=gemini-pro
|
| 218 |
+
|
| 219 |
+
# Hugging Face Inference API用(オプション)
|
| 220 |
+
MODEL_TYPE=hf_inference
|
| 221 |
+
HF_INFERENCE_API_KEY=hf_...
|
| 222 |
+
HF_INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
**Hugging Face Spacesでの推奨設定**:
|
| 226 |
+
- `MODEL_TYPE=transformers`を使用(ローカルでモデルをロード)
|
| 227 |
+
- `HF_MODEL_REPO`でモデルを指定(デフォルト: `meta-llama/Llama-3.2-3B-Instruct`)
|
| 228 |
+
- 他のモデルに切り替える場合は、`HF_MODEL_REPO`を変更するだけ
|
| 229 |
+
|
| 230 |
+
### 7. 実装の優先順位
|
| 231 |
+
|
| 232 |
+
#### Phase 1: 基盤整備(最優先)
|
| 233 |
+
- `BaseAI`インターフェースの定義
|
| 234 |
+
- 現在の`AI`クラスを`TransformersAI`にリファクタリング
|
| 235 |
+
- モデルタイプの識別とファクトリーパターンの実装
|
| 236 |
+
- **Hugging Face Spacesでの複数モデル対応**(Llama 3.2以外のモデル選択可能に)
|
| 237 |
+
|
| 238 |
+
#### Phase 2: Hugging Faceモデルの拡張対応
|
| 239 |
+
- **Qwen 2.5対応**: 日本語に強い、高性能
|
| 240 |
+
- **Mistral 7B対応**: 多言語対応、高性能
|
| 241 |
+
- **Gemma対応**: Google製、軽量
|
| 242 |
+
- 各モデルのプロンプトフォーマット対応
|
| 243 |
+
|
| 244 |
+
#### Phase 3: OpenAI対応
|
| 245 |
+
- `OpenAIAI`クラスの実装
|
| 246 |
+
- `logprobs`パラメータの活用
|
| 247 |
+
- プロンプトフォーマット変換
|
| 248 |
+
|
| 249 |
+
#### Phase 4: Anthropic対応
|
| 250 |
+
- `AnthropicAI`クラスの実装
|
| 251 |
+
- Claude固有のプロンプト形式対応
|
| 252 |
+
|
| 253 |
+
#### Phase 5: Google/Gemini対応
|
| 254 |
+
- `GoogleAI`クラスの実装
|
| 255 |
+
|
| 256 |
+
#### Phase 6: その他
|
| 257 |
+
- Hugging Face Inference API対応(オプション)
|
| 258 |
+
- カスタムモデルエンドポイント対応
|
| 259 |
+
|
| 260 |
+
### 8. 課題と解決策
|
| 261 |
+
|
| 262 |
+
#### 課題1: APIコストとレート制限
|
| 263 |
+
- **解決策**:
|
| 264 |
+
- キャッシュ機能の実装
|
| 265 |
+
- リクエスト間隔の制御
|
| 266 |
+
- ローカルモデルとの併用推奨
|
| 267 |
+
|
| 268 |
+
#### 課題2: トークン確率の取得方法の違い
|
| 269 |
+
- **解決策**:
|
| 270 |
+
- 各APIの`logprobs`パラメータを活用
|
| 271 |
+
- 確率の正規化処理を統一
|
| 272 |
+
|
| 273 |
+
#### 課題3: プロンプト形式の違い
|
| 274 |
+
- **解決策**:
|
| 275 |
+
- 各モデル用の`build_chat_prompt`メソッドを実装
|
| 276 |
+
- 統一された入力インターフェースを提供
|
| 277 |
+
|
| 278 |
+
#### 課題4: エラーハンドリング
|
| 279 |
+
- **解決策**:
|
| 280 |
+
- 各APIのエラーレスポンスを統一形式で処理
|
| 281 |
+
- フォールバック機能の実装
|
| 282 |
+
|
| 283 |
+
### 9. ファイル構成
|
| 284 |
+
|
| 285 |
+
```
|
| 286 |
+
package/
|
| 287 |
+
├── ai/
|
| 288 |
+
│ ├── __init__.py # ファクトリー関数
|
| 289 |
+
│ ├── base.py # BaseAIインターフェース
|
| 290 |
+
│ ├── transformers_ai.py # TransformersAI (現在のAIクラス)
|
| 291 |
+
│ ├── openai_ai.py # OpenAIAI
|
| 292 |
+
│ ├── anthropic_ai.py # AnthropicAI
|
| 293 |
+
│ ├── google_ai.py # GoogleAI
|
| 294 |
+
│ └── hf_inference_ai.py # HuggingFaceInferenceAI
|
| 295 |
+
├── config.py # 設定管理(拡張)
|
| 296 |
+
└── ...
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
### 10. 使用例
|
| 300 |
+
|
| 301 |
+
#### 10.1 Hugging Face Spacesでの使用(推奨)
|
| 302 |
+
|
| 303 |
+
```python
|
| 304 |
+
# 環境変数でモデルを指定
|
| 305 |
+
# HF_MODEL_REPO=Qwen/Qwen2.5-3B-Instruct python app.py
|
| 306 |
+
# HF_MODEL_REPO=mistralai/Mistral-7B-Instruct-v0.2 python app.py
|
| 307 |
+
|
| 308 |
+
from package.ai import get_ai_model
|
| 309 |
+
|
| 310 |
+
# ファクトリー関数で適切なモデルを取得
|
| 311 |
+
ai_model = get_ai_model() # MODEL_TYPE=transformers(デフォルト)
|
| 312 |
+
|
| 313 |
+
# 統一されたインターフェースで使用
|
| 314 |
+
tokens = ai_model.get_token_probabilities("こんにちは", k=5)
|
| 315 |
+
```
|
| 316 |
+
|
| 317 |
+
#### 10.2 OpenAI APIでの使用
|
| 318 |
+
|
| 319 |
+
```python
|
| 320 |
+
# MODEL_TYPE=openai OPENAI_API_KEY=sk-... python app.py
|
| 321 |
+
|
| 322 |
+
from package.ai import get_ai_model
|
| 323 |
+
|
| 324 |
+
ai_model = get_ai_model() # MODEL_TYPE=openai
|
| 325 |
+
tokens = ai_model.get_token_probabilities("こんにちは", k=5)
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
#### 10.3 モデルの動的切り替え
|
| 329 |
+
|
| 330 |
+
```python
|
| 331 |
+
# アプリ起動時に環境変数で指定
|
| 332 |
+
# または、設定ファイルで管理
|
| 333 |
+
|
| 334 |
+
import os
|
| 335 |
+
os.environ["HF_MODEL_REPO"] = "Qwen/Qwen2.5-3B-Instruct"
|
| 336 |
+
|
| 337 |
+
from package.ai import get_ai_model
|
| 338 |
+
ai_model = get_ai_model()
|
| 339 |
+
```
|
| 340 |
+
|
| 341 |
+
### 11. Hugging Face Spacesでのモデル選択ガイド
|
| 342 |
+
|
| 343 |
+
#### 11.1 モデル選択の基準
|
| 344 |
+
|
| 345 |
+
1. **日本語対応**: 日本語処理が必要な場合
|
| 346 |
+
- 推奨: `Qwen/Qwen2.5-3B-Instruct`, `meta-llama/Llama-3.2-3B-Instruct`
|
| 347 |
+
|
| 348 |
+
2. **軽量性**: リソース制約がある場合
|
| 349 |
+
- 推奨: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`, `google/gemma-2b-it`
|
| 350 |
+
|
| 351 |
+
3. **高性能**: 品質を重視する場合
|
| 352 |
+
- 推奨: `mistralai/Mistral-7B-Instruct-v0.2`, `Qwen/Qwen2.5-3B-Instruct`
|
| 353 |
+
|
| 354 |
+
4. **日本語特化**: 日本語タスクに特化
|
| 355 |
+
- 推奨: `elyza/ELYZA-japanese-Llama-2-7b-instruct`, `rinna/weblab-10b-instruction-sft`
|
| 356 |
+
|
| 357 |
+
#### 11.2 モデル切り替え手順
|
| 358 |
+
|
| 359 |
+
1. Hugging Face Spacesの環境変数で`HF_MODEL_REPO`を設定
|
| 360 |
+
2. アプリを再起動
|
| 361 |
+
3. モデルが自動的にロードされる(初回はダウンロード時間がかかる場合あり)
|
| 362 |
+
|
| 363 |
+
#### 11.3 注意事項
|
| 364 |
+
|
| 365 |
+
- **ストレージ制約**: Hugging Face Spacesのストレージ制限に注意
|
| 366 |
+
- **モデルサイズ**: 大きなモデル(7B以上)はメモリとロード時間がかかる
|
| 367 |
+
- **トークン確率**: すべてのTransformersモデルで完全なトークン確率が取得可能
|
| 368 |
+
- **APIコスト**: Transformersモデルは無料(ローカルロード)、APIモデルは有料
|
| 369 |
+
|
| 370 |
+
## まとめ
|
| 371 |
+
|
| 372 |
+
この設計により、以下のメリットが得られます:
|
| 373 |
+
|
| 374 |
+
1. **拡張性**: 新しいモデルを簡単に追加可能
|
| 375 |
+
2. **互換性**: 既存のコードを最小限の変更で維持
|
| 376 |
+
3. **柔軟性**: ユーザーが好みのモデルを選択可能
|
| 377 |
+
4. **統一性**: すべてのモデルが同じインターフェースを使用
|
| 378 |
+
|
README.md
CHANGED
|
@@ -1,13 +1,122 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLMView Multi-Model - Hugging Face Spaces版
|
| 2 |
+
|
| 3 |
+
複数のAIモデルに対応した単語ツリー構築ツール(Gradio版)
|
| 4 |
+
|
| 5 |
+
## 特徴
|
| 6 |
+
|
| 7 |
+
- ✅ **マルチモデル対応**: Transformersモデル(Llama、Qwen、Mistral、Gemma等)に対応
|
| 8 |
+
- ✅ **Hugging Face Spaces対応**: GradioでHFSにデプロイ可能
|
| 9 |
+
- ✅ **GPU対応**: ZeroGPUを使用してGPUリソースを要求
|
| 10 |
+
- ✅ **完全なトークン確率**: ローカルモデルで完全なトークン確率情報を取得
|
| 11 |
+
|
| 12 |
+
## Hugging Face Spacesでのデプロイ
|
| 13 |
+
|
| 14 |
+
### 1. リポジトリの作成
|
| 15 |
+
|
| 16 |
+
1. Hugging Face Spacesで新しいSpaceを作成
|
| 17 |
+
2. SDK: **Gradio** を選択
|
| 18 |
+
3. Hardware: **GPU** を選択(推奨)
|
| 19 |
+
|
| 20 |
+
### 2. 環境変数の設定
|
| 21 |
+
|
| 22 |
+
Spaceの設定で以下の環境変数を設定:
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
# モデルタイプ(transformers推奨)
|
| 26 |
+
MODEL_TYPE=transformers
|
| 27 |
+
|
| 28 |
+
# 使用するモデル(Hugging FaceリポジトリID)
|
| 29 |
+
HF_MODEL_REPO=meta-llama/Llama-3.2-3B-Instruct
|
| 30 |
+
|
| 31 |
+
# プライベートモデルの場合のみ
|
| 32 |
+
HF_TOKEN=your_hf_token
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
### 3. モデルの切り替え
|
| 36 |
+
|
| 37 |
+
環境変数`HF_MODEL_REPO`を変更するだけで、他のモデルに切り替え可能:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
# Qwen 2.5
|
| 41 |
+
HF_MODEL_REPO=Qwen/Qwen2.5-3B-Instruct
|
| 42 |
+
|
| 43 |
+
# Mistral 7B
|
| 44 |
+
HF_MODEL_REPO=mistralai/Mistral-7B-Instruct-v0.2
|
| 45 |
+
|
| 46 |
+
# Gemma 2B
|
| 47 |
+
HF_MODEL_REPO=google/gemma-2b-it
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## 使用方法
|
| 51 |
+
|
| 52 |
+
1. **プロンプトを入力**: 質問や指示を入力
|
| 53 |
+
2. **ルートテキスト(オプション)**: 既存のテキストの続きを生成する場合に指定
|
| 54 |
+
3. **パラメータ調整**:
|
| 55 |
+
- `top_k`: 取得する候補数(1-20)
|
| 56 |
+
- `max_depth`: 最大探索深さ(1-50)
|
| 57 |
+
4. **「単語ツリーを構築」ボタンをクリック**
|
| 58 |
+
|
| 59 |
+
## ファイル構成
|
| 60 |
+
|
| 61 |
+
```
|
| 62 |
+
LLMView_multi_model/
|
| 63 |
+
├── app.py # Gradioアプリ(メイン)
|
| 64 |
+
├── requirements.txt # 依存パッケージ
|
| 65 |
+
├── README.md # このファイル
|
| 66 |
+
├── MULTI_MODEL_DESIGN.md # 設計ドキュメント
|
| 67 |
+
└── package/
|
| 68 |
+
├── __init__.py
|
| 69 |
+
├── adapter.py # ModelAdapter(マルチモデル対応)
|
| 70 |
+
├── config.py # 設定管理
|
| 71 |
+
├── word_processor.py # 単語処理ロジック
|
| 72 |
+
├── word_counter.py # 単語数カウント
|
| 73 |
+
└── ai/
|
| 74 |
+
├── __init__.py # ファクトリー関数
|
| 75 |
+
├── base.py # BaseAIインターフェース
|
| 76 |
+
├── transformers_ai.py # TransformersAI(ローカルモデル)
|
| 77 |
+
├── openai_ai.py # OpenAIAI(オプション)
|
| 78 |
+
├── anthropic_ai.py # AnthropicAI(オプション)
|
| 79 |
+
└── google_ai.py # GoogleAI(オプション)
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## 依存パッケージ
|
| 83 |
+
|
| 84 |
+
主要な依存パッケージ:
|
| 85 |
+
|
| 86 |
+
- `gradio>=4.0.0`: Web UI
|
| 87 |
+
- `spaces`: Hugging Face Spaces用
|
| 88 |
+
- `transformers>=4.30.0`: Transformersモデル
|
| 89 |
+
- `torch>=2.0.0`: PyTorch
|
| 90 |
+
- `fugashi>=1.3.0`: 形態素解析(日本語)
|
| 91 |
+
- `sudachipy>=0.6.7`: Sudachi形態素解析(オプション)
|
| 92 |
+
|
| 93 |
+
詳細は`requirements.txt`を参照してください。
|
| 94 |
+
|
| 95 |
+
## 注意事項
|
| 96 |
+
|
| 97 |
+
1. **GPU推奨**: モデルのロードと推論にはGPUが推奨されます
|
| 98 |
+
2. **モデルサイズ**: 大きなモデル(7B以上)はメモリとロード時間がかかります
|
| 99 |
+
3. **初回起動**: モデルのダウンロードに時間がかかる場合があります
|
| 100 |
+
4. **API非推奨**: Hugging Face Spacesでは、Transformersモデル(ローカルロード)を使用してください
|
| 101 |
+
|
| 102 |
+
## トラブルシューティング
|
| 103 |
+
|
| 104 |
+
### モデルがロードされない
|
| 105 |
+
|
| 106 |
+
- `HF_TOKEN`が正しく設定されているか確認(プライベートモデルの場合)
|
| 107 |
+
- モデルリポジトリIDが正しいか確認
|
| 108 |
+
- Spaceのログを確認
|
| 109 |
+
|
| 110 |
+
### GPUが利用できない
|
| 111 |
+
|
| 112 |
+
- SpaceのHardware設定でGPUが有効になっているか確認
|
| 113 |
+
- `spaces`パッケージがインストールされているか確認
|
| 114 |
+
|
| 115 |
+
### 形態素解析エラー
|
| 116 |
+
|
| 117 |
+
- `fugashi`がインストールされているか確認
|
| 118 |
+
- HFSでは通常、デフォルト設定で動作します
|
| 119 |
+
|
| 120 |
+
## ライセンス
|
| 121 |
+
|
| 122 |
+
MIT License
|
app.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LLMView Multi-Model - Gradioアプリ
|
| 4 |
+
Hugging Face Spaces用
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import threading
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Any, Optional
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
|
| 14 |
+
# ZeroGPU対応: spacesパッケージをインポート(デコレータ用)
|
| 15 |
+
try:
|
| 16 |
+
import spaces
|
| 17 |
+
SPACES_AVAILABLE = True
|
| 18 |
+
print("[SPACE] spacesパッケージをインポートしました")
|
| 19 |
+
except ImportError:
|
| 20 |
+
SPACES_AVAILABLE = False
|
| 21 |
+
print("[SPACE] spacesパッケージが見つかりません(ローカル環境の可能性)")
|
| 22 |
+
# ダミーデコレータを定義
|
| 23 |
+
class DummyGPU:
|
| 24 |
+
def __call__(self, func):
|
| 25 |
+
return func
|
| 26 |
+
spaces = type('spaces', (), {'GPU': DummyGPU()})()
|
| 27 |
+
|
| 28 |
+
# パッケージパスを追加
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 30 |
+
|
| 31 |
+
from package.ai import get_ai_model
|
| 32 |
+
from package.word_processor import WordDeterminer, WordPiece
|
| 33 |
+
from package.adapter import ModelAdapter
|
| 34 |
+
|
| 35 |
+
# グローバル変数
|
| 36 |
+
adapter: Optional[ModelAdapter] = None
|
| 37 |
+
status_message = "モデル初期化中..."
|
| 38 |
+
status_lock = threading.Lock()
|
| 39 |
+
|
| 40 |
+
# 環境変数から設定を取得
|
| 41 |
+
MODEL_TYPE = os.getenv("MODEL_TYPE", "transformers")
|
| 42 |
+
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _set_status(message: str) -> None:
|
| 46 |
+
"""ステータスメッセージを更新"""
|
| 47 |
+
global status_message
|
| 48 |
+
with status_lock:
|
| 49 |
+
status_message = message
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def initialize_model() -> None:
|
| 53 |
+
"""モデルを初期化"""
|
| 54 |
+
global adapter
|
| 55 |
+
try:
|
| 56 |
+
print("[INIT] モデル初期化開始")
|
| 57 |
+
_set_status("モデルを読み込み中です...")
|
| 58 |
+
|
| 59 |
+
# AIモデルを取得
|
| 60 |
+
ai_model = get_ai_model()
|
| 61 |
+
print(f"[INIT] AIモデル取得成功: {type(ai_model)}")
|
| 62 |
+
|
| 63 |
+
# ModelAdapterを初期化
|
| 64 |
+
adapter = ModelAdapter(ai_model)
|
| 65 |
+
print("[INIT] ModelAdapter初期化完了")
|
| 66 |
+
|
| 67 |
+
_set_status("モデル準備完了")
|
| 68 |
+
print("[INIT] モデル初期化完了")
|
| 69 |
+
except Exception as exc:
|
| 70 |
+
error_msg = f"モデル初期化に失敗しました: {exc}"
|
| 71 |
+
print(f"[INIT] エラー: {error_msg}")
|
| 72 |
+
_set_status(error_msg)
|
| 73 |
+
import traceback
|
| 74 |
+
traceback.print_exc()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# バックグラウンドでモデルを初期化
|
| 78 |
+
threading.Thread(target=initialize_model, daemon=True).start()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def build_word_tree(
|
| 82 |
+
prompt_text: str,
|
| 83 |
+
root_text: str = "",
|
| 84 |
+
top_k: int = 5,
|
| 85 |
+
max_depth: int = 10
|
| 86 |
+
) -> List[Dict[str, Any]]:
|
| 87 |
+
"""
|
| 88 |
+
単語ツリーを構築
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
prompt_text: プロンプトテキスト
|
| 92 |
+
root_text: ルートテキスト(オプション)
|
| 93 |
+
top_k: 取得する候補数
|
| 94 |
+
max_depth: 最大深さ
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
List[Dict[str, Any]]: 候補リスト
|
| 98 |
+
"""
|
| 99 |
+
if not prompt_text.strip():
|
| 100 |
+
return [{"text": "プロンプトを入力してください", "probability": 0.0}]
|
| 101 |
+
|
| 102 |
+
if adapter is None:
|
| 103 |
+
with status_lock:
|
| 104 |
+
current_status = status_message
|
| 105 |
+
return [{"text": f"モデル準備中: {current_status}", "probability": 0.0}]
|
| 106 |
+
|
| 107 |
+
try:
|
| 108 |
+
results = adapter.build_word_tree(
|
| 109 |
+
prompt_text=prompt_text,
|
| 110 |
+
root_text=root_text,
|
| 111 |
+
top_k=top_k,
|
| 112 |
+
max_depth=max_depth,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if not results:
|
| 116 |
+
return [{"text": "候補が生成されませんでした", "probability": 0.0}]
|
| 117 |
+
|
| 118 |
+
return results
|
| 119 |
+
except Exception as exc:
|
| 120 |
+
import traceback
|
| 121 |
+
traceback.print_exc()
|
| 122 |
+
return [{"text": f"エラー: {exc}", "probability": 0.0}]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_status() -> str:
|
| 126 |
+
"""ステータスを取得"""
|
| 127 |
+
with status_lock:
|
| 128 |
+
current_status = status_message
|
| 129 |
+
|
| 130 |
+
model_info = f"モデルタイプ: {MODEL_TYPE}\n"
|
| 131 |
+
if MODEL_TYPE == "transformers":
|
| 132 |
+
model_info += f"モデル: {HF_MODEL_REPO}\n"
|
| 133 |
+
|
| 134 |
+
return f"{model_info}ステータス: {current_status}"
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Gradioインターフェース
|
| 138 |
+
with gr.Blocks(title="LLMView Multi-Model", theme=gr.themes.Soft()) as demo:
|
| 139 |
+
gr.Markdown("""
|
| 140 |
+
# LLMView Multi-Model
|
| 141 |
+
|
| 142 |
+
複数のAIモデルに対応した単語ツリー構築ツール
|
| 143 |
+
|
| 144 |
+
## 使い方
|
| 145 |
+
1. プロンプトを入力
|
| 146 |
+
2. オプションでルートテキストを指定(既存のテキストの続きを生成する場合)
|
| 147 |
+
3. パラメータを調整(top_k: 候補数、max_depth: 最大深さ)
|
| 148 |
+
4. 「単語ツリーを構築」ボタンをクリック
|
| 149 |
+
""")
|
| 150 |
+
|
| 151 |
+
with gr.Row():
|
| 152 |
+
with gr.Column(scale=2):
|
| 153 |
+
prompt_input = gr.Textbox(
|
| 154 |
+
label="プロンプト",
|
| 155 |
+
placeholder="例: 電球を作ったのは誰?",
|
| 156 |
+
lines=3
|
| 157 |
+
)
|
| 158 |
+
root_input = gr.Textbox(
|
| 159 |
+
label="ルートテキスト(オプション)",
|
| 160 |
+
placeholder="例: 電球を作ったの���",
|
| 161 |
+
lines=2
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
with gr.Row():
|
| 165 |
+
top_k_slider = gr.Slider(
|
| 166 |
+
minimum=1,
|
| 167 |
+
maximum=20,
|
| 168 |
+
value=5,
|
| 169 |
+
step=1,
|
| 170 |
+
label="候補数 (top_k)"
|
| 171 |
+
)
|
| 172 |
+
max_depth_slider = gr.Slider(
|
| 173 |
+
minimum=1,
|
| 174 |
+
maximum=50,
|
| 175 |
+
value=10,
|
| 176 |
+
step=1,
|
| 177 |
+
label="最大深さ (max_depth)"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
build_btn = gr.Button("単語ツリーを構築", variant="primary")
|
| 181 |
+
|
| 182 |
+
with gr.Column(scale=1):
|
| 183 |
+
status_output = gr.Textbox(
|
| 184 |
+
label="ステータス",
|
| 185 |
+
value=get_status(),
|
| 186 |
+
lines=5,
|
| 187 |
+
interactive=False
|
| 188 |
+
)
|
| 189 |
+
refresh_status_btn = gr.Button("ステータス更新")
|
| 190 |
+
|
| 191 |
+
results_output = gr.Dataframe(
|
| 192 |
+
label="結果",
|
| 193 |
+
headers=["テキスト", "確率"],
|
| 194 |
+
datatype=["str", "number"],
|
| 195 |
+
interactive=False
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# イベントハンドラ
|
| 199 |
+
def build_and_display(prompt, root, top_k, max_depth):
|
| 200 |
+
results = build_word_tree(prompt, root, int(top_k), int(max_depth))
|
| 201 |
+
# DataFrame用に変換
|
| 202 |
+
df_data = [[r["text"], f"{r['probability']:.4f}"] for r in results]
|
| 203 |
+
return df_data, get_status()
|
| 204 |
+
|
| 205 |
+
build_btn.click(
|
| 206 |
+
fn=build_and_display,
|
| 207 |
+
inputs=[prompt_input, root_input, top_k_slider, max_depth_slider],
|
| 208 |
+
outputs=[results_output, status_output]
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
refresh_status_btn.click(
|
| 212 |
+
fn=lambda: get_status(),
|
| 213 |
+
outputs=status_output
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ZeroGPU対応: 起動時に検出されるように、デコレータ付き関数を定義
|
| 218 |
+
@spaces.GPU
|
| 219 |
+
def _gpu_init_function():
|
| 220 |
+
"""GPU初期化用のダミー関数(Space起動時に検出される)"""
|
| 221 |
+
pass
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
if __name__ == "__main__":
|
| 225 |
+
# Hugging Face Spaces用の設定
|
| 226 |
+
# GPU要求を確実に検出させる
|
| 227 |
+
if SPACES_AVAILABLE:
|
| 228 |
+
try:
|
| 229 |
+
_gpu_init_function()
|
| 230 |
+
print("[SPACE] GPU要求を送信しました")
|
| 231 |
+
except Exception as e:
|
| 232 |
+
print(f"[SPACE] GPU要求エラー: {e}")
|
| 233 |
+
|
| 234 |
+
demo.launch(
|
| 235 |
+
server_name="0.0.0.0",
|
| 236 |
+
server_port=7860,
|
| 237 |
+
share=False
|
| 238 |
+
)
|
| 239 |
+
|
package/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Other models パッケージ
|
| 3 |
+
マルチモデル対応のAIアダプター
|
| 4 |
+
"""
|
| 5 |
+
|
package/adapter.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ModelAdapter - マルチモデル対応アダプター
|
| 3 |
+
新しいAIインターフェース(BaseAI)に対応
|
| 4 |
+
"""
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
from .word_processor import WordDeterminer, WordPiece
|
| 7 |
+
from .ai.base import BaseAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ModelAdapter:
|
| 11 |
+
"""
|
| 12 |
+
マルチモデル対応アダプター
|
| 13 |
+
- 初期化コストの高いコンポーネント(WordDeterminer, AIモデル)を1回だけ生成して保持
|
| 14 |
+
- メソッドでビルド処理を提供
|
| 15 |
+
- 返却はシリアライズしやすい dict/list 形式
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, ai_model: BaseAI):
|
| 19 |
+
"""
|
| 20 |
+
初期化
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
ai_model: BaseAIを実装したモデルインスタンス
|
| 24 |
+
"""
|
| 25 |
+
# WordDeterminer(内部で Sudachi C モードの WordCounter を使用)
|
| 26 |
+
self.determiner = WordDeterminer()
|
| 27 |
+
|
| 28 |
+
# AIモデルを保持
|
| 29 |
+
self.model = ai_model
|
| 30 |
+
|
| 31 |
+
def _clean_text(self, text: str) -> str:
|
| 32 |
+
"""制御文字・不可視文字・置換文字を厳密に取り除く(最終出力用)"""
|
| 33 |
+
if not text:
|
| 34 |
+
return ""
|
| 35 |
+
|
| 36 |
+
# 制御文字(0x00-0x1F、0x7F-0x9F)を除去
|
| 37 |
+
# ただし、改行・タブ・復帰は許可
|
| 38 |
+
cleaned = []
|
| 39 |
+
for ch in text:
|
| 40 |
+
code = ord(ch)
|
| 41 |
+
# 許可する制御文字: 改行(0x0A), タブ(0x09), 復帰(0x0D)
|
| 42 |
+
if code in [0x09, 0x0A, 0x0D]:
|
| 43 |
+
cleaned.append(ch)
|
| 44 |
+
# 通常の印刷可能文字
|
| 45 |
+
elif ch.isprintable():
|
| 46 |
+
# 置換文字(U+FFFD)を除去
|
| 47 |
+
if ch != "\uFFFD":
|
| 48 |
+
cleaned.append(ch)
|
| 49 |
+
# その他の制御文字や不可視文字は除去
|
| 50 |
+
|
| 51 |
+
result = "".join(cleaned)
|
| 52 |
+
# ゼロ幅文字を除去
|
| 53 |
+
result = result.replace("\u200B", "") # Zero-width space
|
| 54 |
+
result = result.replace("\u200C", "") # Zero-width non-joiner
|
| 55 |
+
result = result.replace("\u200D", "") # Zero-width joiner
|
| 56 |
+
result = result.replace("\uFEFF", "") # Zero-width no-break space
|
| 57 |
+
return result.strip()
|
| 58 |
+
|
| 59 |
+
def build_word_tree(
|
| 60 |
+
self,
|
| 61 |
+
prompt_text: str,
|
| 62 |
+
root_text: str = "",
|
| 63 |
+
top_k: int = 5,
|
| 64 |
+
max_depth: int = 10
|
| 65 |
+
) -> List[Dict[str, Any]]:
|
| 66 |
+
"""
|
| 67 |
+
単語ツリーを構築して、完成ピースを dict の配列で返す。
|
| 68 |
+
各要素: { text: str, probability: float }
|
| 69 |
+
"""
|
| 70 |
+
pieces: List[WordPiece] = self.determiner.build_word_tree(
|
| 71 |
+
prompt_text=prompt_text,
|
| 72 |
+
root_text=root_text,
|
| 73 |
+
model=self.model,
|
| 74 |
+
top_k=top_k,
|
| 75 |
+
max_depth=max_depth,
|
| 76 |
+
)
|
| 77 |
+
return [
|
| 78 |
+
{"text": self._clean_text(p.get_full_word()), "probability": float(p.probability)}
|
| 79 |
+
for p in pieces
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
def build_chat_prompt(
|
| 83 |
+
self,
|
| 84 |
+
user_content: str,
|
| 85 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。"
|
| 86 |
+
) -> str:
|
| 87 |
+
"""チャットプロンプト文字列を返す。"""
|
| 88 |
+
return self.model.build_chat_prompt(user_content, system_content)
|
| 89 |
+
|
| 90 |
+
def count_words(self, text: str) -> int:
|
| 91 |
+
"""Sudachi(C) ベースでの語数カウント。"""
|
| 92 |
+
return self.determiner._count_words(text)
|
| 93 |
+
|
package/ai/__init__.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AIモデルアダプターのファクトリー関数
|
| 3 |
+
環境変数に基づいて適切なモデルを自動選択
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from .base import BaseAI
|
| 8 |
+
from .transformers_ai import TransformersAI
|
| 9 |
+
from .openai_ai import OpenAIAI
|
| 10 |
+
from .anthropic_ai import AnthropicAI
|
| 11 |
+
from .google_ai import GoogleAI
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_ai_model(model_type: Optional[str] = None, **kwargs) -> BaseAI:
|
| 15 |
+
"""
|
| 16 |
+
環境変数または引数に基づいて適切なAIモデルを取得
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
model_type: モデルタイプ("transformers", "openai", "anthropic", "google")
|
| 20 |
+
Noneの場合は環境変数MODEL_TYPEから取得
|
| 21 |
+
**kwargs: 各モデル固有の引数
|
| 22 |
+
- transformers: model_path
|
| 23 |
+
- openai: model_name, api_key
|
| 24 |
+
- anthropic: model_name, api_key
|
| 25 |
+
- google: model_name, api_key
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
BaseAI: 選択されたモデルのインスタンス
|
| 29 |
+
|
| 30 |
+
Examples:
|
| 31 |
+
# 環境変数から自動選択
|
| 32 |
+
ai = get_ai_model()
|
| 33 |
+
|
| 34 |
+
# 明示的に指定
|
| 35 |
+
ai = get_ai_model("transformers", model_path="Qwen/Qwen2.5-3B-Instruct")
|
| 36 |
+
ai = get_ai_model("openai", model_name="gpt-4", api_key="sk-...")
|
| 37 |
+
"""
|
| 38 |
+
# モデルタイプを決定
|
| 39 |
+
if model_type is None:
|
| 40 |
+
model_type = os.getenv("MODEL_TYPE", "transformers")
|
| 41 |
+
|
| 42 |
+
model_type = model_type.lower()
|
| 43 |
+
|
| 44 |
+
# モデルタイプに応じて適切なクラスを返す
|
| 45 |
+
if model_type == "transformers":
|
| 46 |
+
model_path = kwargs.get("model_path") or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
|
| 47 |
+
return TransformersAI.get_model(model_path=model_path)
|
| 48 |
+
|
| 49 |
+
elif model_type == "openai":
|
| 50 |
+
model_name = kwargs.get("model_name") or os.getenv("OPENAI_MODEL", "gpt-4")
|
| 51 |
+
api_key = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
|
| 52 |
+
return OpenAIAI.get_model(model_name=model_name, api_key=api_key)
|
| 53 |
+
|
| 54 |
+
elif model_type == "anthropic":
|
| 55 |
+
model_name = kwargs.get("model_name") or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022")
|
| 56 |
+
api_key = kwargs.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
|
| 57 |
+
return AnthropicAI.get_model(model_name=model_name, api_key=api_key)
|
| 58 |
+
|
| 59 |
+
elif model_type == "google":
|
| 60 |
+
model_name = kwargs.get("model_name") or os.getenv("GOOGLE_MODEL", "gemini-pro")
|
| 61 |
+
api_key = kwargs.get("api_key") or os.getenv("GOOGLE_API_KEY")
|
| 62 |
+
return GoogleAI.get_model(model_name=model_name, api_key=api_key)
|
| 63 |
+
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(
|
| 66 |
+
f"不明なモデルタイプ: {model_type}. "
|
| 67 |
+
f"サポートされているタイプ: transformers, openai, anthropic, google"
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# 後方互換性のため、BaseAIもエクスポート
|
| 72 |
+
__all__ = [
|
| 73 |
+
"BaseAI",
|
| 74 |
+
"TransformersAI",
|
| 75 |
+
"OpenAIAI",
|
| 76 |
+
"AnthropicAI",
|
| 77 |
+
"GoogleAI",
|
| 78 |
+
"get_ai_model",
|
| 79 |
+
]
|
| 80 |
+
|
package/ai/anthropic_ai.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AnthropicAI - Anthropic API(Claude)用アダプター
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Tuple, Optional, Dict, Any
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
from .base import BaseAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AnthropicAI(BaseAI):
|
| 11 |
+
"""
|
| 12 |
+
Anthropic API(Claude)用アダプター
|
| 13 |
+
|
| 14 |
+
特徴:
|
| 15 |
+
- API経由でモデルにアクセス
|
| 16 |
+
- logprobsパラメータでトークン確率を取得可能(Claude 3.5以降)
|
| 17 |
+
- user/assistantを明確に分離する形式を推奨(messages配列形式)
|
| 18 |
+
- systemは別パラメータとして扱う(messagesとは別)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
_instances = {} # モデルごとのインスタンスをキャッシュ
|
| 22 |
+
|
| 23 |
+
def __new__(cls, model_name: str = None, api_key: str = None):
|
| 24 |
+
"""シングルトンパターンでクライアントを常駐"""
|
| 25 |
+
model = model_name or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022")
|
| 26 |
+
key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
| 27 |
+
|
| 28 |
+
cache_key = f"{model}:{key}"
|
| 29 |
+
if cache_key not in cls._instances:
|
| 30 |
+
cls._instances[cache_key] = super().__new__(cls)
|
| 31 |
+
cls._instances[cache_key]._initialized = False
|
| 32 |
+
|
| 33 |
+
return cls._instances[cache_key]
|
| 34 |
+
|
| 35 |
+
def __init__(self, model_name: str = None, api_key: str = None):
|
| 36 |
+
"""
|
| 37 |
+
初期化
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model_name: モデル名(例: "claude-3-5-sonnet-20241022")
|
| 41 |
+
api_key: Anthropic APIキー
|
| 42 |
+
"""
|
| 43 |
+
if hasattr(self, '_initialized') and self._initialized:
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
self.model_name = model_name or os.getenv("ANTHROPIC_MODEL", "claude-3-5-sonnet-20241022")
|
| 47 |
+
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
| 48 |
+
self._initialized = True
|
| 49 |
+
|
| 50 |
+
if not self.api_key:
|
| 51 |
+
raise ValueError("ANTHROPIC_API_KEYが設定されていません")
|
| 52 |
+
|
| 53 |
+
# Anthropicクライアントを初期化
|
| 54 |
+
try:
|
| 55 |
+
from anthropic import Anthropic
|
| 56 |
+
self.client = Anthropic(api_key=self.api_key)
|
| 57 |
+
print(f"[AnthropicAI] 初期化完了: モデル={self.model_name}")
|
| 58 |
+
except ImportError:
|
| 59 |
+
raise ImportError("anthropicパッケージがインストールされていません。pip install anthropic を実行してください")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
raise ValueError(f"Anthropicクライアントの初期化に失敗しました: {e}")
|
| 62 |
+
|
| 63 |
+
@classmethod
|
| 64 |
+
def get_model(cls, model_name: str = None, api_key: str = None) -> 'AnthropicAI':
|
| 65 |
+
"""モデルインスタンスを取得(常駐キャッシュから)"""
|
| 66 |
+
return cls(model_name, api_key)
|
| 67 |
+
|
| 68 |
+
@classmethod
|
| 69 |
+
def clear_cache(cls):
|
| 70 |
+
"""キャッシュをクリア(開発・テスト用)"""
|
| 71 |
+
cls._instances.clear()
|
| 72 |
+
|
| 73 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 74 |
+
"""
|
| 75 |
+
文章とkを引数に、{token, 確率}のリストを返す
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
text: 入力文章(messages配列または文字列)
|
| 79 |
+
k: 取得するトークン数
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 83 |
+
"""
|
| 84 |
+
try:
|
| 85 |
+
# textがmessages形式かどうかを判定
|
| 86 |
+
if isinstance(text, str):
|
| 87 |
+
# 文字列の場合は、userメッセージとして扱う
|
| 88 |
+
messages = [{"role": "user", "content": text}]
|
| 89 |
+
system = None
|
| 90 |
+
elif isinstance(text, dict):
|
| 91 |
+
# dictの場合は、messagesとsystemを分離
|
| 92 |
+
messages = text.get("messages", [])
|
| 93 |
+
system = text.get("system")
|
| 94 |
+
else:
|
| 95 |
+
messages = text
|
| 96 |
+
system = None
|
| 97 |
+
|
| 98 |
+
# API呼び出し(logprobs=Trueでトークン確率を取得)
|
| 99 |
+
response = self.client.messages.create(
|
| 100 |
+
model=self.model_name,
|
| 101 |
+
messages=messages,
|
| 102 |
+
system=system,
|
| 103 |
+
logprobs=True,
|
| 104 |
+
top_logprobs=k,
|
| 105 |
+
max_tokens=1, # 次のトークン1つだけを取得
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# logprobsから確率を計算
|
| 109 |
+
items: List[Tuple[str, float]] = []
|
| 110 |
+
|
| 111 |
+
if response.content and len(response.content) > 0:
|
| 112 |
+
content_block = response.content[0]
|
| 113 |
+
if hasattr(content_block, 'logprobs') and content_block.logprobs:
|
| 114 |
+
# top_logprobsから確率を取得
|
| 115 |
+
for token_info in content_block.logprobs.top_logprobs:
|
| 116 |
+
token = self._clean_text(token_info.token)
|
| 117 |
+
if not token:
|
| 118 |
+
continue
|
| 119 |
+
# logprobを確率に変換
|
| 120 |
+
prob = math.exp(token_info.logprob)
|
| 121 |
+
items.append((token, float(prob)))
|
| 122 |
+
|
| 123 |
+
# 確��を正規化
|
| 124 |
+
if items:
|
| 125 |
+
total_prob = sum(prob for _, prob in items)
|
| 126 |
+
if total_prob > 0:
|
| 127 |
+
normalized_items: List[Tuple[str, float]] = []
|
| 128 |
+
for token, prob in items:
|
| 129 |
+
normalized_prob = prob / total_prob
|
| 130 |
+
normalized_items.append((token, normalized_prob))
|
| 131 |
+
return normalized_items
|
| 132 |
+
|
| 133 |
+
return items
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
print(f"[AnthropicAI] トークン確率取得エラー: {e}")
|
| 137 |
+
import traceback
|
| 138 |
+
traceback.print_exc()
|
| 139 |
+
return []
|
| 140 |
+
|
| 141 |
+
def build_chat_prompt(
|
| 142 |
+
self,
|
| 143 |
+
user_content: str,
|
| 144 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
|
| 145 |
+
assistant_content: Optional[str] = None
|
| 146 |
+
) -> Dict[str, Any]:
|
| 147 |
+
"""
|
| 148 |
+
チャットプロンプトを構築(Anthropic messages形式)
|
| 149 |
+
|
| 150 |
+
注意: Anthropicでは、user/assistantを明確に分離するmessages配列形式を推奨します。
|
| 151 |
+
また、systemはmessagesとは別のパラメータとして扱います。
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
user_content: ユーザーのメッセージ
|
| 155 |
+
system_content: システムプロンプト(messagesとは別)
|
| 156 |
+
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Dict[str, Any]: {"messages": [...], "system": "..."} 形式
|
| 160 |
+
"""
|
| 161 |
+
messages = []
|
| 162 |
+
|
| 163 |
+
# 会話履歴がある場合(assistant_contentが指定されている場合)
|
| 164 |
+
if assistant_content:
|
| 165 |
+
messages.append({
|
| 166 |
+
"role": "assistant",
|
| 167 |
+
"content": assistant_content
|
| 168 |
+
})
|
| 169 |
+
|
| 170 |
+
# 現在のUserメッセージ
|
| 171 |
+
messages.append({
|
| 172 |
+
"role": "user",
|
| 173 |
+
"content": user_content
|
| 174 |
+
})
|
| 175 |
+
|
| 176 |
+
return {
|
| 177 |
+
"messages": messages,
|
| 178 |
+
"system": system_content if system_content else None
|
| 179 |
+
}
|
| 180 |
+
|
package/ai/base.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BaseAI - すべてのAIモデルアダプターの基底クラス
|
| 3 |
+
"""
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import List, Tuple, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseAI(ABC):
|
| 9 |
+
"""
|
| 10 |
+
すべてのAIモデルアダプターが実装すべき基底クラス
|
| 11 |
+
|
| 12 |
+
各モデルは以下のメソッドを実装する必要があります:
|
| 13 |
+
- get_token_probabilities: トークン確率の取得
|
| 14 |
+
- build_chat_prompt: モデル固有のプロンプト形式への変換
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 19 |
+
"""
|
| 20 |
+
テキストから次のトークン候補と確率を取得
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
text: 入力テキスト(プロンプト)
|
| 24 |
+
k: 取得するトークン候補数
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト(確率順)
|
| 28 |
+
"""
|
| 29 |
+
raise NotImplementedError
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def build_chat_prompt(
|
| 33 |
+
self,
|
| 34 |
+
user_content: str,
|
| 35 |
+
system_content: str = "",
|
| 36 |
+
assistant_content: Optional[str] = None
|
| 37 |
+
) -> str:
|
| 38 |
+
"""
|
| 39 |
+
モデル固有のチャットプロンプト形式に変換
|
| 40 |
+
|
| 41 |
+
注意: モデルによってuser/assistantの分離方法が異なります
|
| 42 |
+
- OpenAI, Claude: user/assistantを明確に分離することを推奨
|
| 43 |
+
- Gemini: user/assistantを分離しない方が良い場合もある
|
| 44 |
+
- Transformers: モデルによって異なる(Llamaは分離推奨)
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
user_content: ユーザーのメッセージ
|
| 48 |
+
system_content: システムプロンプト(オプション)
|
| 49 |
+
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
str: モデル固有のプロンプト形式
|
| 53 |
+
"""
|
| 54 |
+
raise NotImplementedError
|
| 55 |
+
|
| 56 |
+
def _clean_text(self, text: str) -> str:
|
| 57 |
+
"""
|
| 58 |
+
制御文字・不可視文字・置換文字を厳密に取り除く(共通処理)
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
text: クリーンアップするテキスト
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
str: クリーンアップされたテキスト
|
| 65 |
+
"""
|
| 66 |
+
if not text:
|
| 67 |
+
return ""
|
| 68 |
+
|
| 69 |
+
# 制御文字(0x00-0x1F、0x7F-0x9F)を除去
|
| 70 |
+
# ただし、改行・タブ・復帰は許可
|
| 71 |
+
cleaned = []
|
| 72 |
+
for ch in text:
|
| 73 |
+
code = ord(ch)
|
| 74 |
+
# 許可する制御文字: 改行(0x0A), タブ(0x09), 復帰(0x0D)
|
| 75 |
+
if code in [0x09, 0x0A, 0x0D]:
|
| 76 |
+
cleaned.append(ch)
|
| 77 |
+
# 通常の印刷可能文字
|
| 78 |
+
elif ch.isprintable():
|
| 79 |
+
# 置換文字(U+FFFD)を除去
|
| 80 |
+
if ch != "\uFFFD":
|
| 81 |
+
cleaned.append(ch)
|
| 82 |
+
# その他の制御文字や不可視文字は除去
|
| 83 |
+
|
| 84 |
+
result = "".join(cleaned)
|
| 85 |
+
# ゼロ幅文字を除去
|
| 86 |
+
result = result.replace("\u200B", "") # Zero-width space
|
| 87 |
+
result = result.replace("\u200C", "") # Zero-width non-joiner
|
| 88 |
+
result = result.replace("\u200D", "") # Zero-width joiner
|
| 89 |
+
result = result.replace("\uFEFF", "") # Zero-width no-break space
|
| 90 |
+
# その他の不可視文字(結合文字など)を除去
|
| 91 |
+
result = result.replace("\u200E", "") # Left-to-right mark
|
| 92 |
+
result = result.replace("\u200F", "") # Right-to-left mark
|
| 93 |
+
result = result.replace("\u202A", "") # Left-to-right embedding
|
| 94 |
+
result = result.replace("\u202B", "") # Right-to-left embedding
|
| 95 |
+
result = result.replace("\u202C", "") # Pop directional formatting
|
| 96 |
+
result = result.replace("\u202D", "") # Left-to-right override
|
| 97 |
+
result = result.replace("\u202E", "") # Right-to-left override
|
| 98 |
+
return result.strip()
|
| 99 |
+
|
package/ai/google_ai.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GoogleAI - Google API(Gemini)用アダプター
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Tuple, Optional
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
from .base import BaseAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GoogleAI(BaseAI):
|
| 11 |
+
"""
|
| 12 |
+
Google API(Gemini)用アダプター
|
| 13 |
+
|
| 14 |
+
特徴:
|
| 15 |
+
- API経由でモデルにアクセス
|
| 16 |
+
- logprobsパラメータでトークン確率を取得可能
|
| 17 |
+
- user/assistantを分離しない方が良い場合もある(テキスト形式)
|
| 18 |
+
- systemとuserを結合したテキスト形式を推奨
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
_instances = {} # モデルごとのインスタンスをキャッシュ
|
| 22 |
+
|
| 23 |
+
def __new__(cls, model_name: str = None, api_key: str = None):
|
| 24 |
+
"""シングルトンパターンでクライアントを常駐"""
|
| 25 |
+
model = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro")
|
| 26 |
+
key = api_key or os.getenv("GOOGLE_API_KEY")
|
| 27 |
+
|
| 28 |
+
cache_key = f"{model}:{key}"
|
| 29 |
+
if cache_key not in cls._instances:
|
| 30 |
+
cls._instances[cache_key] = super().__new__(cls)
|
| 31 |
+
cls._instances[cache_key]._initialized = False
|
| 32 |
+
|
| 33 |
+
return cls._instances[cache_key]
|
| 34 |
+
|
| 35 |
+
def __init__(self, model_name: str = None, api_key: str = None):
|
| 36 |
+
"""
|
| 37 |
+
初期化
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model_name: モデル名(例: "gemini-pro")
|
| 41 |
+
api_key: Google APIキー
|
| 42 |
+
"""
|
| 43 |
+
if hasattr(self, '_initialized') and self._initialized:
|
| 44 |
+
return
|
| 45 |
+
|
| 46 |
+
self.model_name = model_name or os.getenv("GOOGLE_MODEL", "gemini-pro")
|
| 47 |
+
self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
|
| 48 |
+
self._initialized = True
|
| 49 |
+
|
| 50 |
+
if not self.api_key:
|
| 51 |
+
raise ValueError("GOOGLE_API_KEYが設定されていません")
|
| 52 |
+
|
| 53 |
+
# Google Generative AIクライアントを初期化
|
| 54 |
+
try:
|
| 55 |
+
import google.generativeai as genai
|
| 56 |
+
genai.configure(api_key=self.api_key)
|
| 57 |
+
self.model = genai.GenerativeModel(self.model_name)
|
| 58 |
+
print(f"[GoogleAI] 初期化完了: モデル={self.model_name}")
|
| 59 |
+
except ImportError:
|
| 60 |
+
raise ImportError("google-generativeaiパッケージがインストールされていません。pip install google-generativeai を実行してください")
|
| 61 |
+
except Exception as e:
|
| 62 |
+
raise ValueError(f"Google Generative AIクライアントの初期化に失敗しました: {e}")
|
| 63 |
+
|
| 64 |
+
@classmethod
|
| 65 |
+
def get_model(cls, model_name: str = None, api_key: str = None) -> 'GoogleAI':
|
| 66 |
+
"""モデルインスタンスを取得(常駐キャッシュから)"""
|
| 67 |
+
return cls(model_name, api_key)
|
| 68 |
+
|
| 69 |
+
@classmethod
|
| 70 |
+
def clear_cache(cls):
|
| 71 |
+
"""キャッシュをクリア(開発・テスト用)"""
|
| 72 |
+
cls._instances.clear()
|
| 73 |
+
|
| 74 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 75 |
+
"""
|
| 76 |
+
文章とkを引数に、{token, 確率}のリストを返す
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
text: 入力文章(プロンプト)
|
| 80 |
+
k: 取得するトークン数
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 84 |
+
"""
|
| 85 |
+
try:
|
| 86 |
+
# Gemini APIでトークン確率を取得
|
| 87 |
+
# 注意: Gemini APIのlogprobs取得方法は他のAPIと異なる可能性があります
|
| 88 |
+
response = self.model.generate_content(
|
| 89 |
+
text,
|
| 90 |
+
generation_config={
|
| 91 |
+
"max_output_tokens": 1, # 次のトークン1つだけを取得
|
| 92 |
+
"temperature": 0.0, # 確定的な結果を得るため
|
| 93 |
+
}
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# 注意: Gemini APIのlogprobs取得方法は公式ドキュメントを確認してください
|
| 97 |
+
# ここでは仮の実装です
|
| 98 |
+
items: List[Tuple[str, float]] = []
|
| 99 |
+
|
| 100 |
+
# 実際の実装では、responseからlogprobsを取得する必要があります
|
| 101 |
+
# 現在のGemini APIでは、logprobsの直接取得が難しい可能性があります
|
| 102 |
+
# 代替案: 複数回のサンプリングで確率を推定
|
| 103 |
+
|
| 104 |
+
print("[GoogleAI] 警告: Gemini APIのlogprobs取得は実装が不完全です")
|
| 105 |
+
return items
|
| 106 |
+
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"[GoogleAI] トークン確率取得エラー: {e}")
|
| 109 |
+
import traceback
|
| 110 |
+
traceback.print_exc()
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
def build_chat_prompt(
|
| 114 |
+
self,
|
| 115 |
+
user_content: str,
|
| 116 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
|
| 117 |
+
assistant_content: Optional[str] = None
|
| 118 |
+
) -> str:
|
| 119 |
+
"""
|
| 120 |
+
チャットプロンプトを構築(Gemini形式)
|
| 121 |
+
|
| 122 |
+
注意: Geminiでは、user/assistantを分離しない方が良い場合もあります。
|
| 123 |
+
systemとuserを結合したテキスト形式を推奨します。
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
user_content: ユーザーのメッセージ
|
| 127 |
+
system_content: システムプロンプト
|
| 128 |
+
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
str: Gemini形式のプロンプト(テキスト)
|
| 132 |
+
"""
|
| 133 |
+
prompt_parts = []
|
| 134 |
+
|
| 135 |
+
# Systemメッセージ(最初に1回だけ)
|
| 136 |
+
if system_content:
|
| 137 |
+
prompt_parts.append(f"システム: {system_content}")
|
| 138 |
+
prompt_parts.append("")
|
| 139 |
+
|
| 140 |
+
# 会話履歴がある場合(assistant_contentが指定されている場合)
|
| 141 |
+
if assistant_content:
|
| 142 |
+
prompt_parts.append(f"ユーザー: {user_content}")
|
| 143 |
+
prompt_parts.append(f"アシスタント: {assistant_content}")
|
| 144 |
+
prompt_parts.append("")
|
| 145 |
+
|
| 146 |
+
# 現在のUserメッセージ
|
| 147 |
+
prompt_parts.append(f"ユーザー: {user_content}")
|
| 148 |
+
prompt_parts.append("アシスタント:")
|
| 149 |
+
|
| 150 |
+
prompt_text = "\n".join(prompt_parts)
|
| 151 |
+
return prompt_text
|
| 152 |
+
|
package/ai/openai_ai.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OpenAIAI - OpenAI API(ChatGPT)用アダプター
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Tuple, Optional, Dict, Any
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
from .base import BaseAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class OpenAIAI(BaseAI):
|
| 11 |
+
"""
|
| 12 |
+
OpenAI API(ChatGPT)用アダプター
|
| 13 |
+
|
| 14 |
+
特徴:
|
| 15 |
+
- API経由でモデルにアクセス
|
| 16 |
+
- logprobsパラメータでトークン確率を取得可能(GPT-4以降)
|
| 17 |
+
- user/assistantを明確に分離する形式を推奨(messages配列形式)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
_instances = {} # モデルごとのインスタンスをキャッシュ
|
| 21 |
+
|
| 22 |
+
def __new__(cls, model_name: str = None, api_key: str = None):
|
| 23 |
+
"""シングルトンパターンでクライアントを常駐"""
|
| 24 |
+
model = model_name or os.getenv("OPENAI_MODEL", "gpt-4")
|
| 25 |
+
key = api_key or os.getenv("OPENAI_API_KEY")
|
| 26 |
+
|
| 27 |
+
cache_key = f"{model}:{key}"
|
| 28 |
+
if cache_key not in cls._instances:
|
| 29 |
+
cls._instances[cache_key] = super().__new__(cls)
|
| 30 |
+
cls._instances[cache_key]._initialized = False
|
| 31 |
+
|
| 32 |
+
return cls._instances[cache_key]
|
| 33 |
+
|
| 34 |
+
def __init__(self, model_name: str = None, api_key: str = None):
|
| 35 |
+
"""
|
| 36 |
+
初期化
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
model_name: モデル名(例: "gpt-4", "gpt-3.5-turbo")
|
| 40 |
+
api_key: OpenAI APIキー
|
| 41 |
+
"""
|
| 42 |
+
if hasattr(self, '_initialized') and self._initialized:
|
| 43 |
+
return
|
| 44 |
+
|
| 45 |
+
self.model_name = model_name or os.getenv("OPENAI_MODEL", "gpt-4")
|
| 46 |
+
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
| 47 |
+
self._initialized = True
|
| 48 |
+
|
| 49 |
+
if not self.api_key:
|
| 50 |
+
raise ValueError("OPENAI_API_KEYが設定されていません")
|
| 51 |
+
|
| 52 |
+
# OpenAIクライアントを初期化
|
| 53 |
+
try:
|
| 54 |
+
import openai
|
| 55 |
+
self.client = openai.OpenAI(api_key=self.api_key)
|
| 56 |
+
print(f"[OpenAIAI] 初期化完了: モデル={self.model_name}")
|
| 57 |
+
except ImportError:
|
| 58 |
+
raise ImportError("openaiパッケージがインストールされていません。pip install openai を実行してください")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
raise ValueError(f"OpenAIクライアントの初期化に失敗しました: {e}")
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def get_model(cls, model_name: str = None, api_key: str = None) -> 'OpenAIAI':
|
| 64 |
+
"""モデルインスタンスを取得(常駐キャッシュから)"""
|
| 65 |
+
return cls(model_name, api_key)
|
| 66 |
+
|
| 67 |
+
@classmethod
|
| 68 |
+
def clear_cache(cls):
|
| 69 |
+
"""キャッシュをクリア(開発・テスト用)"""
|
| 70 |
+
cls._instances.clear()
|
| 71 |
+
|
| 72 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 73 |
+
"""
|
| 74 |
+
文章とkを引数に、{token, 確率}のリストを返す
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
text: 入力文章(プロンプト)
|
| 78 |
+
k: 取得するトークン数
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
# OpenAI APIでは、messages形式でリクエストする必要がある
|
| 85 |
+
# textが既にmessages形式かどうかを判定
|
| 86 |
+
if isinstance(text, str):
|
| 87 |
+
# 文字列の場合は、userメッセージとして扱う
|
| 88 |
+
messages = [{"role": "user", "content": text}]
|
| 89 |
+
else:
|
| 90 |
+
messages = text
|
| 91 |
+
|
| 92 |
+
# API呼び出し(logprobs=Trueでトークン確率を取得)
|
| 93 |
+
response = self.client.chat.completions.create(
|
| 94 |
+
model=self.model_name,
|
| 95 |
+
messages=messages,
|
| 96 |
+
logprobs=True,
|
| 97 |
+
top_logprobs=k,
|
| 98 |
+
max_tokens=1, # 次のトークン1つだけを取得
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# logprobsから確率を計算
|
| 102 |
+
items: List[Tuple[str, float]] = []
|
| 103 |
+
|
| 104 |
+
if response.choices and response.choices[0].logprobs:
|
| 105 |
+
logprobs = response.choices[0].logprobs.content[0] if response.choices[0].logprobs.content else None
|
| 106 |
+
|
| 107 |
+
if logprobs:
|
| 108 |
+
# top_logprobsから確率を取得
|
| 109 |
+
for token_info in logprobs.top_logprobs:
|
| 110 |
+
token = self._clean_text(token_info.token)
|
| 111 |
+
if not token:
|
| 112 |
+
continue
|
| 113 |
+
# logprobを確率に変換
|
| 114 |
+
prob = math.exp(token_info.logprob)
|
| 115 |
+
items.append((token, float(prob)))
|
| 116 |
+
|
| 117 |
+
# 確率を正規化
|
| 118 |
+
if items:
|
| 119 |
+
total_prob = sum(prob for _, prob in items)
|
| 120 |
+
if total_prob > 0:
|
| 121 |
+
normalized_items: List[Tuple[str, float]] = []
|
| 122 |
+
for token, prob in items:
|
| 123 |
+
normalized_prob = prob / total_prob
|
| 124 |
+
normalized_items.append((token, normalized_prob))
|
| 125 |
+
return normalized_items
|
| 126 |
+
|
| 127 |
+
return items
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
print(f"[OpenAIAI] トークン確率取得エラー: {e}")
|
| 131 |
+
import traceback
|
| 132 |
+
traceback.print_exc()
|
| 133 |
+
return []
|
| 134 |
+
|
| 135 |
+
def build_chat_prompt(
|
| 136 |
+
self,
|
| 137 |
+
user_content: str,
|
| 138 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
|
| 139 |
+
assistant_content: Optional[str] = None
|
| 140 |
+
) -> List[Dict[str, str]]:
|
| 141 |
+
"""
|
| 142 |
+
チャットプロンプトを構築(OpenAI messages形式)
|
| 143 |
+
|
| 144 |
+
注意: OpenAIでは、user/assistantを明確に分離するmessages配列形式を推奨します。
|
| 145 |
+
このメソッドは文字列ではなく、messages配列を返します。
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
user_content: ユーザーのメッセージ
|
| 149 |
+
system_content: システムプロンプト
|
| 150 |
+
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
List[Dict[str, str]]: OpenAI messages形式の配列
|
| 154 |
+
"""
|
| 155 |
+
messages = []
|
| 156 |
+
|
| 157 |
+
# Systemメッセージ(最初に1回だけ)
|
| 158 |
+
if system_content:
|
| 159 |
+
messages.append({
|
| 160 |
+
"role": "system",
|
| 161 |
+
"content": system_content
|
| 162 |
+
})
|
| 163 |
+
|
| 164 |
+
# 会話履歴がある場合(assistant_contentが指定されている場合)
|
| 165 |
+
if assistant_content:
|
| 166 |
+
# 前回のuserメッセージとassistant応答を追加
|
| 167 |
+
# 注意: この実装では、assistant_contentのみを追加
|
| 168 |
+
# 実際の会話履歴管理は呼び出し側で行う必要があります
|
| 169 |
+
messages.append({
|
| 170 |
+
"role": "assistant",
|
| 171 |
+
"content": assistant_content
|
| 172 |
+
})
|
| 173 |
+
|
| 174 |
+
# 現在のUserメッセージ
|
| 175 |
+
messages.append({
|
| 176 |
+
"role": "user",
|
| 177 |
+
"content": user_content
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
return messages
|
| 181 |
+
|
package/ai/transformers_ai.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TransformersAI - Hugging Face Transformersモデル用アダプター
|
| 3 |
+
Llama 3.2、Qwen、Mistral、Gemma等のローカルモデルに対応
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Tuple, Any, Optional
|
| 6 |
+
import os
|
| 7 |
+
from .base import BaseAI
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TransformersAI(BaseAI):
|
| 11 |
+
"""
|
| 12 |
+
Hugging Face Transformersモデル用アダプター
|
| 13 |
+
|
| 14 |
+
特徴:
|
| 15 |
+
- ローカルでモデルをロード
|
| 16 |
+
- logitsから直接確率を取得可能
|
| 17 |
+
- user/assistantを明確に分離する形式を推奨(Llama 3.2形式)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
_instances = {} # モデルパスごとのインスタンスをキャッシュ(常駐)
|
| 21 |
+
|
| 22 |
+
def __new__(cls, model_path: str = None):
|
| 23 |
+
"""シングルトンパターンでモデルを常駐"""
|
| 24 |
+
path = model_path or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
|
| 25 |
+
|
| 26 |
+
if path not in cls._instances:
|
| 27 |
+
cls._instances[path] = super().__new__(cls)
|
| 28 |
+
cls._instances[path]._initialized = False
|
| 29 |
+
|
| 30 |
+
return cls._instances[path]
|
| 31 |
+
|
| 32 |
+
def __init__(self, model_path: str = None):
|
| 33 |
+
"""
|
| 34 |
+
モデルをロードして初期化(一度だけ実行、常駐)
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
model_path: モデルリポジトリIDまたはローカルパス
|
| 38 |
+
"""
|
| 39 |
+
if hasattr(self, '_initialized') and self._initialized:
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
self.model_path = model_path or os.getenv("HF_MODEL_REPO", "meta-llama/Llama-3.2-3B-Instruct")
|
| 43 |
+
self.model = self._load_model(self.model_path)
|
| 44 |
+
self._initialized = True
|
| 45 |
+
|
| 46 |
+
if self.model is None:
|
| 47 |
+
raise ValueError(f"モデルのロードに失敗しました: {self.model_path}")
|
| 48 |
+
|
| 49 |
+
@classmethod
|
| 50 |
+
def get_model(cls, model_path: str = None) -> 'TransformersAI':
|
| 51 |
+
"""モデルインスタンスを取得(常駐キャッシュから)"""
|
| 52 |
+
return cls(model_path)
|
| 53 |
+
|
| 54 |
+
@classmethod
|
| 55 |
+
def clear_cache(cls):
|
| 56 |
+
"""キャッシュをクリア(開発・テスト用)"""
|
| 57 |
+
cls._instances.clear()
|
| 58 |
+
|
| 59 |
+
def _load_model(self, model_path: str) -> Optional[Any]:
|
| 60 |
+
"""モデルをロード(Transformers使用、Hubから直接読み込み)"""
|
| 61 |
+
try:
|
| 62 |
+
if not model_path:
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
# モデルパスがリポジトリID("user/repo"形式)か、ローカルパスかを判定
|
| 66 |
+
is_repo_id = "/" in model_path and not os.path.exists(model_path)
|
| 67 |
+
|
| 68 |
+
# リポジトリIDの場合は os.path.exists() チェックをスキップ
|
| 69 |
+
if not is_repo_id and not os.path.exists(model_path):
|
| 70 |
+
print(f"[TransformersAI] モデルパスが存在しません: {model_path}")
|
| 71 |
+
return None
|
| 72 |
+
|
| 73 |
+
# transformersを使用してモデルをロード
|
| 74 |
+
try:
|
| 75 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 76 |
+
import torch
|
| 77 |
+
|
| 78 |
+
# GPUが利用可能かチェック
|
| 79 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 80 |
+
if device == "cuda":
|
| 81 |
+
print("[TransformersAI] GPU検出: CUDAを使用します")
|
| 82 |
+
else:
|
| 83 |
+
print("[TransformersAI] GPU未検出: CPUモードで実行します")
|
| 84 |
+
|
| 85 |
+
print(f"[TransformersAI] モデルをロード中: {model_path}")
|
| 86 |
+
print(f"[TransformersAI] デバイス: {device}")
|
| 87 |
+
|
| 88 |
+
hf_token = os.getenv("HF_TOKEN")
|
| 89 |
+
|
| 90 |
+
if is_repo_id:
|
| 91 |
+
print(f"[TransformersAI] Hugging Face Hub から直接読み込み: {model_path}")
|
| 92 |
+
else:
|
| 93 |
+
print(f"[TransformersAI] ローカルパスから読み込み: {model_path}")
|
| 94 |
+
|
| 95 |
+
# トークナイザーとモデルをロード(Hubから直接読み込む)
|
| 96 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 97 |
+
model_path,
|
| 98 |
+
token=hf_token,
|
| 99 |
+
)
|
| 100 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 101 |
+
model_path,
|
| 102 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 103 |
+
device_map="auto" if device == "cuda" else None,
|
| 104 |
+
token=hf_token,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
if device == "cpu":
|
| 108 |
+
model = model.to(device)
|
| 109 |
+
|
| 110 |
+
# モデルとトークナイザーをタプルで返す
|
| 111 |
+
print(f"[TransformersAI] モデルロード成功 ({device}モード)")
|
| 112 |
+
return (model, tokenizer)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
import traceback
|
| 115 |
+
print(f"[TransformersAI] transformersでのロードに失敗: {e}")
|
| 116 |
+
traceback.print_exc()
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
import traceback
|
| 121 |
+
print(f"[TransformersAI] モデルロードエラー: {e}")
|
| 122 |
+
traceback.print_exc()
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
def get_token_probabilities(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
|
| 126 |
+
"""
|
| 127 |
+
文章とkを引数に、{token, 確率}のリストを返す
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
text: 入力文章
|
| 131 |
+
k: 取得するトークン数
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 135 |
+
"""
|
| 136 |
+
if self.model is None:
|
| 137 |
+
return []
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
# transformers モデルの場合
|
| 141 |
+
if isinstance(self.model, tuple) and len(self.model) == 2:
|
| 142 |
+
model, tokenizer = self.model
|
| 143 |
+
import torch
|
| 144 |
+
|
| 145 |
+
# テキストをトークン化
|
| 146 |
+
inputs = tokenizer(text, return_tensors="pt")
|
| 147 |
+
device = next(model.parameters()).device
|
| 148 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 149 |
+
|
| 150 |
+
# モデルで推論(勾配計算なし)
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
outputs = model(**inputs)
|
| 153 |
+
logits = outputs.logits[0, -1, :] # 最後のトークンのlogits
|
| 154 |
+
|
| 155 |
+
# logitsを確率に変換(softmax)
|
| 156 |
+
probs = torch.softmax(logits, dim=-1)
|
| 157 |
+
|
| 158 |
+
# 上位k個のトークンを取得
|
| 159 |
+
top_probs, top_indices = torch.topk(probs, k)
|
| 160 |
+
|
| 161 |
+
# トークンIDを文字列に変換
|
| 162 |
+
items: List[Tuple[str, float]] = []
|
| 163 |
+
|
| 164 |
+
# 特殊トークンを定義(Llama 3.2、Qwen、Mistral等で使用)
|
| 165 |
+
SPECIAL_TOKENS = [
|
| 166 |
+
"<|begin_of_text|>",
|
| 167 |
+
"<|end_of_text|>",
|
| 168 |
+
"<|eot_id|>",
|
| 169 |
+
"<|start_header_id|>",
|
| 170 |
+
"<|end_header_id|>",
|
| 171 |
+
"<|im_start|>",
|
| 172 |
+
"<|im_end|>",
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
+
def _clean_text_local(text: str) -> str:
|
| 176 |
+
"""制御文字・不可視文字・置換文字・特殊トークンを厳密に取り除く"""
|
| 177 |
+
if not text:
|
| 178 |
+
return ""
|
| 179 |
+
|
| 180 |
+
# 特殊トークンを除去
|
| 181 |
+
for special_token in SPECIAL_TOKENS:
|
| 182 |
+
text = text.replace(special_token, "")
|
| 183 |
+
|
| 184 |
+
# 基底クラスの_clean_textを使用
|
| 185 |
+
return self._clean_text(text)
|
| 186 |
+
|
| 187 |
+
for idx, prob in zip(top_indices, top_probs):
|
| 188 |
+
token_id = idx.item()
|
| 189 |
+
# skip_special_tokens=Trueで特殊トークンを除外
|
| 190 |
+
token = tokenizer.decode([token_id], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
| 191 |
+
token = _clean_text_local(token)
|
| 192 |
+
# 空文字列のトークンは除外
|
| 193 |
+
if not token:
|
| 194 |
+
continue
|
| 195 |
+
prob_value = prob.item()
|
| 196 |
+
items.append((token, float(prob_value)))
|
| 197 |
+
|
| 198 |
+
# 確率を正規化
|
| 199 |
+
if items:
|
| 200 |
+
total_prob = sum(prob for _, prob in items)
|
| 201 |
+
if total_prob > 0:
|
| 202 |
+
normalized_items: List[Tuple[str, float]] = []
|
| 203 |
+
for token, prob in items:
|
| 204 |
+
normalized_prob = prob / total_prob
|
| 205 |
+
normalized_items.append((token, normalized_prob))
|
| 206 |
+
return normalized_items
|
| 207 |
+
|
| 208 |
+
return items
|
| 209 |
+
else:
|
| 210 |
+
print("[TransformersAI] モデルがサポートされていません")
|
| 211 |
+
return []
|
| 212 |
+
|
| 213 |
+
except Exception as e:
|
| 214 |
+
print(f"[TransformersAI] トークン確率取得エラー: {e}")
|
| 215 |
+
import traceback
|
| 216 |
+
traceback.print_exc()
|
| 217 |
+
return []
|
| 218 |
+
|
| 219 |
+
def build_chat_prompt(
|
| 220 |
+
self,
|
| 221 |
+
user_content: str,
|
| 222 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください",
|
| 223 |
+
assistant_content: Optional[str] = None
|
| 224 |
+
) -> str:
|
| 225 |
+
"""
|
| 226 |
+
チャットプロンプトを構築(Llama 3.2形式)
|
| 227 |
+
|
| 228 |
+
注意: Transformersモデル(特にLlama 3.2、Qwen等)では、
|
| 229 |
+
user/assistantを明確に分離する形式を推奨します。
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
user_content: ユーザーのメッセージ
|
| 233 |
+
system_content: システムプロンプト
|
| 234 |
+
assistant_content: アシスタントの既存応答(会話履歴用、オプション)
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
str: Llama 3.2形式のプロンプト
|
| 238 |
+
"""
|
| 239 |
+
# 既に整形済みのプロンプトが渡されている場合(複数行、ヘッダーを含む)
|
| 240 |
+
# そのまま返す
|
| 241 |
+
if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content:
|
| 242 |
+
return user_content
|
| 243 |
+
|
| 244 |
+
# Llama 3.2形式でプロンプトを構築
|
| 245 |
+
prompt_parts = []
|
| 246 |
+
|
| 247 |
+
# Systemメッセージ
|
| 248 |
+
if system_content:
|
| 249 |
+
prompt_parts.append("<|start_header_id|>system<|end_header_id|>")
|
| 250 |
+
prompt_parts.append(system_content)
|
| 251 |
+
prompt_parts.append("<|eot_id|>")
|
| 252 |
+
|
| 253 |
+
# Userメッセージ
|
| 254 |
+
prompt_parts.append("<|start_header_id|>user<|end_header_id|>")
|
| 255 |
+
prompt_parts.append(user_content)
|
| 256 |
+
prompt_parts.append("<|eot_id|>")
|
| 257 |
+
|
| 258 |
+
# Assistantメッセージ(会話履歴がある場合)
|
| 259 |
+
if assistant_content:
|
| 260 |
+
prompt_parts.append("<|start_header_id|>assistant<|end_header_id|>")
|
| 261 |
+
prompt_parts.append(assistant_content)
|
| 262 |
+
prompt_parts.append("<|eot_id|>")
|
| 263 |
+
|
| 264 |
+
# 新しい応答を生成する場合は、assistantヘッダーだけを追加
|
| 265 |
+
prompt_parts.append("<|start_header_id|>assistant<|end_header_id|>")
|
| 266 |
+
|
| 267 |
+
prompt_text = "\n".join(prompt_parts)
|
| 268 |
+
|
| 269 |
+
# BOS(<|begin_of_text|>) の重複を抑止: 先頭のBOSを全て除去
|
| 270 |
+
# transformers側でBOSが自動付与される場合があるため
|
| 271 |
+
BOS = "<|begin_of_text|>"
|
| 272 |
+
s = prompt_text.lstrip()
|
| 273 |
+
while s.startswith(BOS):
|
| 274 |
+
s = s[len(BOS):]
|
| 275 |
+
prompt_text = s
|
| 276 |
+
|
| 277 |
+
return prompt_text
|
| 278 |
+
|
package/config.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
設定ファイル - Hugging Face Spaces用(簡易版)
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Config:
|
| 8 |
+
"""設定管理クラス(Hugging Face Spaces用)"""
|
| 9 |
+
|
| 10 |
+
# MeCab設定(fugashi用、HFSでは通常不要だが互換性のため)
|
| 11 |
+
MECAB_CONFIG_PATH = os.getenv("MECAB_CONFIG_PATH", "/opt/homebrew/etc/mecabrc")
|
| 12 |
+
MECAB_DICT_PATH = os.getenv("MECAB_DICT_PATH", "/opt/homebrew/lib/mecab/dic/ipadic")
|
| 13 |
+
|
| 14 |
+
# fugashi設定(MeCab/IPA用)
|
| 15 |
+
FUGASHI_ARGS = f"-r {MECAB_CONFIG_PATH}"
|
| 16 |
+
|
| 17 |
+
@classmethod
|
| 18 |
+
def get_fugashi_args(cls) -> str:
|
| 19 |
+
"""fugashi用の引数を取得"""
|
| 20 |
+
# HFSでは通常fugashiはデフォルト設定で動作
|
| 21 |
+
# 引数なしでデフォルト設定を使用できる場合は空文字列を返す
|
| 22 |
+
# そうでない場合は設定ファイルパスを返す
|
| 23 |
+
try:
|
| 24 |
+
import fugashi
|
| 25 |
+
# デフォルトのGenericTaggerを使用(引数なしで動作する場合)
|
| 26 |
+
# ただし、設定ファイルが必要な場合はパスを返す
|
| 27 |
+
if os.path.exists(cls.MECAB_CONFIG_PATH):
|
| 28 |
+
return cls.FUGASHI_ARGS
|
| 29 |
+
else:
|
| 30 |
+
# 設定ファイルが存在しない場合は空文字列(デフォルト設定を使用)
|
| 31 |
+
return ""
|
| 32 |
+
except ImportError:
|
| 33 |
+
# fugashiがインストールされていない場合は空文字列を返す
|
| 34 |
+
# (WordCounterでエラーハンドリングされる)
|
| 35 |
+
return ""
|
| 36 |
+
|
package/word_counter.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
import fugashi
|
| 3 |
+
from .config import Config
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
# SudachiPy があれば直接利用してモードCを使用
|
| 7 |
+
from sudachipy import dictionary as sudachi_dictionary
|
| 8 |
+
from sudachipy import tokenizer as sudachi_tokenizer
|
| 9 |
+
_SUDACHI_AVAILABLE = True
|
| 10 |
+
except Exception:
|
| 11 |
+
_SUDACHI_AVAILABLE = False
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class WordCounter:
|
| 15 |
+
"""単語数を数えるクラス(SudachiPyがあれば mode=C、なければfugashi)"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, tokenizer: Any = None):
|
| 18 |
+
"""
|
| 19 |
+
初期化
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
tokenizer: fugashiトークナイザー(Noneの場合はデフォルトを使用)
|
| 23 |
+
"""
|
| 24 |
+
# 優先順位: 引数tokenizer > SudachiPy > fugashi(GenericTagger)
|
| 25 |
+
self._use_sudachi = False
|
| 26 |
+
self._sudachi_mode = None
|
| 27 |
+
if tokenizer is not None:
|
| 28 |
+
self.tokenizer = tokenizer
|
| 29 |
+
elif _SUDACHI_AVAILABLE:
|
| 30 |
+
# SudachiPyの辞書は自動で同梱辞書を参照(sudachidict_core)
|
| 31 |
+
# 外部設定不要。SplitMode.C を使用
|
| 32 |
+
self._use_sudachi = True
|
| 33 |
+
self.tokenizer = sudachi_dictionary.Dictionary().create()
|
| 34 |
+
self._sudachi_mode = sudachi_tokenizer.Tokenizer.SplitMode.C
|
| 35 |
+
else:
|
| 36 |
+
# fugashi (MeCab) フォールバック
|
| 37 |
+
fugashi_args = Config.get_fugashi_args()
|
| 38 |
+
if fugashi_args:
|
| 39 |
+
self.tokenizer = fugashi.GenericTagger(fugashi_args)
|
| 40 |
+
else:
|
| 41 |
+
# 引数なしでデフォルト設定を使用
|
| 42 |
+
self.tokenizer = fugashi.GenericTagger()
|
| 43 |
+
|
| 44 |
+
def count_words(self, text: str) -> int:
|
| 45 |
+
"""
|
| 46 |
+
テキストの単語数をカウント
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
text: カウントするテキスト
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
int: 単語数
|
| 53 |
+
"""
|
| 54 |
+
if not text:
|
| 55 |
+
return 0
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
# fugashiで形態素解析して単語数をカウント
|
| 59 |
+
if self._use_sudachi:
|
| 60 |
+
tokens = self.tokenizer.tokenize(text, self._sudachi_mode)
|
| 61 |
+
return len(tokens)
|
| 62 |
+
else:
|
| 63 |
+
tokens = self.tokenizer(text)
|
| 64 |
+
return len(tokens)
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f"単語数カウントエラー: {e}")
|
| 67 |
+
# フォールバック: 空白で分割
|
| 68 |
+
return len(text.split())
|
| 69 |
+
|
| 70 |
+
def is_word_boundary(self, text: str, position: int) -> bool:
|
| 71 |
+
"""
|
| 72 |
+
指定位置が単語境界かどうかを判定
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
text: テキスト
|
| 76 |
+
position: 位置(負の値で末尾から指定可能、-1は末尾)
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
bool: 単語境界かどうか
|
| 80 |
+
"""
|
| 81 |
+
if not text:
|
| 82 |
+
return True
|
| 83 |
+
|
| 84 |
+
# 負のインデックスを正のインデックスに変換
|
| 85 |
+
if position < 0:
|
| 86 |
+
position = len(text) + position
|
| 87 |
+
|
| 88 |
+
if position >= len(text):
|
| 89 |
+
return True
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
# fugashiで形態素解析
|
| 93 |
+
if self._use_sudachi:
|
| 94 |
+
tokens = self.tokenizer.tokenize(text, self._sudachi_mode)
|
| 95 |
+
surfaces = [m.surface() for m in tokens]
|
| 96 |
+
else:
|
| 97 |
+
tokens = self.tokenizer(text)
|
| 98 |
+
surfaces = [m.surface for m in tokens]
|
| 99 |
+
|
| 100 |
+
current_pos = 0
|
| 101 |
+
for surface in surfaces:
|
| 102 |
+
token_length = len(surface)
|
| 103 |
+
if current_pos <= position < current_pos + token_length:
|
| 104 |
+
return False
|
| 105 |
+
if position == current_pos + token_length:
|
| 106 |
+
return True
|
| 107 |
+
current_pos += token_length
|
| 108 |
+
|
| 109 |
+
return True
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
print(f"境界判定エラー: {e}")
|
| 113 |
+
# フォールバック: 空白文字で判定
|
| 114 |
+
return position < len(text) and text[position].isspace()
|
| 115 |
+
|
package/word_processor.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Any, Optional
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from enum import Enum
|
| 4 |
+
import os
|
| 5 |
+
import math
|
| 6 |
+
from .word_counter import WordCounter
|
| 7 |
+
from .config import Config
|
| 8 |
+
from .ai.base import BaseAI
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class WordState(Enum):
|
| 12 |
+
"""単語の状態"""
|
| 13 |
+
INCOMPLETE = "incomplete" # 未完成
|
| 14 |
+
COMPLETE = "complete" # 完成
|
| 15 |
+
TRIGGER = "trigger" # トリガー(次語の開始)
|
| 16 |
+
|
| 17 |
+
class KList:
|
| 18 |
+
def __init__(self, num: int):
|
| 19 |
+
self.num = num
|
| 20 |
+
self.list: List[Any] = []
|
| 21 |
+
|
| 22 |
+
def check_k(self) -> None:
|
| 23 |
+
if len(self.list) >= self.num:
|
| 24 |
+
self.list.sort(key=lambda x: x.probability, reverse=True)
|
| 25 |
+
self.list = self.list[:self.num]
|
| 26 |
+
else:
|
| 27 |
+
self.list.sort(key=lambda x: x.probability, reverse=True)
|
| 28 |
+
|
| 29 |
+
def add(self, piece_word: Any) -> None:
|
| 30 |
+
# 重複チェック: 同じテキストのピースが既に存在するか確認
|
| 31 |
+
new_text = piece_word.get_full_text()
|
| 32 |
+
for existing_piece in self.list:
|
| 33 |
+
if existing_piece.get_full_text() == new_text:
|
| 34 |
+
# 既存のピースに確率を足す
|
| 35 |
+
existing_piece.probability += piece_word.probability
|
| 36 |
+
# 確率を更新したので、ソートし直す
|
| 37 |
+
self.check_k()
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
# 重複がない場合は追加
|
| 41 |
+
self.list.append(piece_word)
|
| 42 |
+
self.check_k()
|
| 43 |
+
|
| 44 |
+
def pop(self) -> Any:
|
| 45 |
+
if self.list:
|
| 46 |
+
return self.list.pop(0)
|
| 47 |
+
raise IndexError("List is empty")
|
| 48 |
+
|
| 49 |
+
def empty(self) -> bool:
|
| 50 |
+
return len(self.list) == 0
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class WordPiece:
|
| 54 |
+
"""単語のピース(部分)"""
|
| 55 |
+
text: str # ピースのテキスト
|
| 56 |
+
probability: float # 確率
|
| 57 |
+
next_tokens: Optional[List[Tuple[str, float]]] = None # 次のトークン候補
|
| 58 |
+
parent: Optional['WordPiece'] = None # 親ピース
|
| 59 |
+
children: List['WordPiece'] = None # 子ピース
|
| 60 |
+
|
| 61 |
+
def __post_init__(self):
|
| 62 |
+
if self.children is None:
|
| 63 |
+
self.children = []
|
| 64 |
+
|
| 65 |
+
def get_full_text(self) -> str:
|
| 66 |
+
"""ルートからこのピースまでの完全なテキストを取得"""
|
| 67 |
+
pieces = []
|
| 68 |
+
current = self
|
| 69 |
+
while current is not None:
|
| 70 |
+
if current.text:
|
| 71 |
+
pieces.append(current.text)
|
| 72 |
+
current = current.parent
|
| 73 |
+
return "".join(reversed(pieces))
|
| 74 |
+
|
| 75 |
+
def get_full_word(self) -> str:
|
| 76 |
+
"""ルートの次語からこのピースまでの完全な単語を取得"""
|
| 77 |
+
pieces = []
|
| 78 |
+
current = self
|
| 79 |
+
while current is not None:
|
| 80 |
+
if current.text:
|
| 81 |
+
pieces.append(current.text)
|
| 82 |
+
current = current.parent
|
| 83 |
+
reversed_pieces = reversed(pieces[:-1])
|
| 84 |
+
return "".join(reversed_pieces)
|
| 85 |
+
|
| 86 |
+
def add_child(self, text: str, probability: float, next_tokens: Optional[List[Tuple[str, float]]] = None) -> 'WordPiece':
|
| 87 |
+
"""子ピースを追加"""
|
| 88 |
+
child = WordPiece(
|
| 89 |
+
text=text,
|
| 90 |
+
probability=probability,
|
| 91 |
+
next_tokens=next_tokens,
|
| 92 |
+
parent=self
|
| 93 |
+
)
|
| 94 |
+
self.children.append(child)
|
| 95 |
+
return child
|
| 96 |
+
|
| 97 |
+
def is_leaf(self) -> bool:
|
| 98 |
+
"""葉ノードかどうか"""
|
| 99 |
+
return len(self.children) == 0
|
| 100 |
+
|
| 101 |
+
def get_depth(self) -> int:
|
| 102 |
+
"""ルートからの深さを取得"""
|
| 103 |
+
depth = 0
|
| 104 |
+
current = self.parent
|
| 105 |
+
while current is not None:
|
| 106 |
+
depth += 1
|
| 107 |
+
current = current.parent
|
| 108 |
+
return depth
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class WordDeterminer:
|
| 112 |
+
"""単語確定システム(ストリーミング向けリアルタイムアルゴリズム)"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, word_counter: WordCounter = None):
|
| 115 |
+
"""
|
| 116 |
+
初期化
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
word_counter: WordCounterインスタンス(Noneの場合はデフォルトを使用)
|
| 120 |
+
"""
|
| 121 |
+
self.word_counter = word_counter or WordCounter()
|
| 122 |
+
|
| 123 |
+
def is_boundary_char(self, char: str) -> bool:
|
| 124 |
+
"""境界文字かどうかを判定(fugashi使用)"""
|
| 125 |
+
if not char:
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
# 空白文字
|
| 129 |
+
if char.isspace():
|
| 130 |
+
return True
|
| 131 |
+
|
| 132 |
+
# 句読点
|
| 133 |
+
punctuation = ",,..。!?!?:;;、\n\t"
|
| 134 |
+
return char in punctuation
|
| 135 |
+
|
| 136 |
+
def is_word_boundary(self, text: str, position: int) -> bool:
|
| 137 |
+
"""
|
| 138 |
+
WordCounterを使用して単語境界を判定
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
text: テキスト
|
| 142 |
+
position: 位置(負の値で末尾から指定可能)
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
bool: 単語境界かどうか
|
| 146 |
+
"""
|
| 147 |
+
return self.word_counter.is_word_boundary(text, position)
|
| 148 |
+
|
| 149 |
+
def check_word_completion(self, piece: WordPiece, root_count: int, model: Any = None) -> Tuple[WordState, Optional[Any]]:
|
| 150 |
+
"""
|
| 151 |
+
ストリ���ミング向けリアルタイム単語決定アルゴリズム
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
piece: チェックするピース
|
| 155 |
+
root_count: ルートテキストの単語数
|
| 156 |
+
model: LLMモデル(BaseAIを実装したオブジェクト)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Tuple[WordState, Optional[Any]]: (状態, ペイロード)
|
| 160 |
+
"""
|
| 161 |
+
full_text = piece.get_full_text()
|
| 162 |
+
|
| 163 |
+
# next_tokensを取得
|
| 164 |
+
if not piece.next_tokens:
|
| 165 |
+
if model:
|
| 166 |
+
piece.next_tokens = self._get_next_tokens_from_model(model, full_text)
|
| 167 |
+
else:
|
| 168 |
+
return (WordState.COMPLETE, None)
|
| 169 |
+
|
| 170 |
+
if not piece.next_tokens:
|
| 171 |
+
return (WordState.COMPLETE, None)
|
| 172 |
+
|
| 173 |
+
# 確率順にソート(念のため)
|
| 174 |
+
sorted_tokens = sorted(piece.next_tokens, key=lambda x: x[1], reverse=True)
|
| 175 |
+
|
| 176 |
+
# 括弧の処理
|
| 177 |
+
if piece.get_full_word() and piece.get_full_word()[-1] in ["(","「","(","【","〈","《","[","{","⦅"]:
|
| 178 |
+
return (WordState.INCOMPLETE, None)
|
| 179 |
+
if piece.get_full_word() and piece.get_full_word()[-1] in [")","]","}","》","〉","》","]","}","⦆"]:
|
| 180 |
+
return (WordState.COMPLETE, None)
|
| 181 |
+
|
| 182 |
+
# 全トークンの挙動を確認
|
| 183 |
+
count = max(1, len(sorted_tokens))
|
| 184 |
+
tokens = sorted_tokens[:count]
|
| 185 |
+
|
| 186 |
+
boundary_prob = 0.0 # 境界を示すトークンの確率合計
|
| 187 |
+
continuation_prob = 0.0 # 継続を示すトークンの確率合計
|
| 188 |
+
total = sum(prob for _, prob in tokens)
|
| 189 |
+
|
| 190 |
+
for token, prob in tokens:
|
| 191 |
+
test_text = full_text + token
|
| 192 |
+
test_word_count = self._count_words(test_text)
|
| 193 |
+
|
| 194 |
+
# 単語数がより多く増えた場合のみ境界と判定(まとまりを上げる)
|
| 195 |
+
if test_word_count > root_count + 1:
|
| 196 |
+
boundary_prob += prob
|
| 197 |
+
else:
|
| 198 |
+
continuation_prob += prob
|
| 199 |
+
|
| 200 |
+
# 判定ロジック
|
| 201 |
+
if total > 0:
|
| 202 |
+
boundary_ratio = boundary_prob / total
|
| 203 |
+
|
| 204 |
+
# トークンの多くが境界を示す場合 → 確定(閾値を上げてまとまりを上げる)
|
| 205 |
+
if boundary_ratio > 0.85:
|
| 206 |
+
return (WordState.COMPLETE, None)
|
| 207 |
+
|
| 208 |
+
# トークンの多くが継続を示す場合 → 継続(閾値を下げて継続しやすく)
|
| 209 |
+
if boundary_ratio < 0.2:
|
| 210 |
+
return (WordState.INCOMPLETE, None)
|
| 211 |
+
|
| 212 |
+
# エントロピーベース判定
|
| 213 |
+
probs = [prob for _, prob in sorted_tokens]
|
| 214 |
+
entropy = -sum(p * math.log(p + 1e-10) for p in probs if p > 0)
|
| 215 |
+
max_entropy = math.log(len(sorted_tokens)) if len(sorted_tokens) > 1 else 1.0
|
| 216 |
+
normalized_entropy = entropy / max_entropy if max_entropy > 0 else 0
|
| 217 |
+
|
| 218 |
+
return (WordState.INCOMPLETE, None)
|
| 219 |
+
|
| 220 |
+
def _count_words(self, text: str) -> int:
|
| 221 |
+
"""
|
| 222 |
+
WordCounterを使用してテキストの単語数をカウント
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
text: カウントするテキスト
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
int: 単語数
|
| 229 |
+
"""
|
| 230 |
+
return self.word_counter.count_words(text)
|
| 231 |
+
|
| 232 |
+
def _get_next_tokens_from_model(self, model: Any, text: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
| 233 |
+
"""
|
| 234 |
+
モデルから次のトークン候補を取得(新しいBaseAIインターフェースを使用)
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
model: BaseAIを実装したモデルオブジェクト
|
| 238 |
+
text: 入力テキスト
|
| 239 |
+
top_k: 取得する候補数
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
List[Tuple[str, float]]: (トークン, 確率)のリスト
|
| 243 |
+
"""
|
| 244 |
+
try:
|
| 245 |
+
# BaseAIインターフェースを実装したモデルを使用
|
| 246 |
+
if isinstance(model, BaseAI):
|
| 247 |
+
return model.get_token_probabilities(text, top_k)
|
| 248 |
+
else:
|
| 249 |
+
print(f"[WORD_PROCESSOR] モデルがBaseAIインターフェースを実装していません: {type(model)}")
|
| 250 |
+
return []
|
| 251 |
+
except Exception as e:
|
| 252 |
+
print(f"[WORD_PROCESSOR] モデルからのトークン取得に失敗: {e}")
|
| 253 |
+
import traceback
|
| 254 |
+
traceback.print_exc()
|
| 255 |
+
|
| 256 |
+
return []
|
| 257 |
+
|
| 258 |
+
def expand_piece(self, piece: WordPiece, model: Any = None) -> List[WordPiece]:
|
| 259 |
+
"""
|
| 260 |
+
ピースを展開して子ピースを生成
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
piece: 展開するピース
|
| 264 |
+
model: LLMモデル(BaseAIを実装したオブジェクト)
|
| 265 |
+
|
| 266 |
+
Returns:
|
| 267 |
+
List[WordPiece]: 生成された子ピースのリスト
|
| 268 |
+
"""
|
| 269 |
+
children = []
|
| 270 |
+
full_text = piece.get_full_text()
|
| 271 |
+
|
| 272 |
+
if piece.next_tokens:
|
| 273 |
+
# 既存のnext_tokensを使用
|
| 274 |
+
for token, prob in piece.next_tokens:
|
| 275 |
+
# 空文字列トークンを無視
|
| 276 |
+
if not token:
|
| 277 |
+
continue
|
| 278 |
+
child_prob = piece.probability * prob
|
| 279 |
+
child = piece.add_child(token, child_prob)
|
| 280 |
+
children.append(child)
|
| 281 |
+
elif model:
|
| 282 |
+
# モデルから次のトークンを取得
|
| 283 |
+
next_tokens = self._get_next_tokens_from_model(model, full_text)
|
| 284 |
+
|
| 285 |
+
if next_tokens:
|
| 286 |
+
piece.next_tokens = next_tokens
|
| 287 |
+
for token, prob in next_tokens:
|
| 288 |
+
# 空文字列トークンを無視
|
| 289 |
+
if not token:
|
| 290 |
+
continue
|
| 291 |
+
child_prob = piece.probability * prob
|
| 292 |
+
child = piece.add_child(token, child_prob)
|
| 293 |
+
children.append(child)
|
| 294 |
+
else:
|
| 295 |
+
print(f"[WORD_PROCESSOR] No model provided for expansion")
|
| 296 |
+
|
| 297 |
+
return children
|
| 298 |
+
|
| 299 |
+
def build_word_tree(self, prompt_text: str, root_text: str, model: Any, top_k: int = 5, max_depth: int = 10) -> List[WordPiece]:
|
| 300 |
+
"""
|
| 301 |
+
単語ツリーを構築
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
prompt_text: プロンプトテキスト
|
| 305 |
+
root_text: ルートテキスト
|
| 306 |
+
model: LLMモデル(BaseAIを実装したオブジェクト)
|
| 307 |
+
top_k: 取得する候補数
|
| 308 |
+
max_depth: 最大深さ
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
List[WordPiece]: 完成した単語ピースのリスト
|
| 312 |
+
"""
|
| 313 |
+
# モデルのbuild_chat_promptメソッドを使用
|
| 314 |
+
if isinstance(model, BaseAI):
|
| 315 |
+
prompt = model.build_chat_prompt(prompt_text)
|
| 316 |
+
else:
|
| 317 |
+
# フォールバック: 従来の形式
|
| 318 |
+
prompt = self.build_chat_prompt(prompt_text)
|
| 319 |
+
|
| 320 |
+
# ルートピースを作成
|
| 321 |
+
root = WordPiece(text=prompt + root_text, probability=1.0)
|
| 322 |
+
|
| 323 |
+
# 優先度付きキュー(確率順)
|
| 324 |
+
candidates = KList(2 * top_k)
|
| 325 |
+
completed = []
|
| 326 |
+
iteration = 0
|
| 327 |
+
max_iterations = 1000
|
| 328 |
+
children = self.expand_piece(root, model)
|
| 329 |
+
for child in children:
|
| 330 |
+
candidates.add(child)
|
| 331 |
+
|
| 332 |
+
while not candidates.empty() and iteration < max_iterations and len(completed) < top_k:
|
| 333 |
+
iteration += 1
|
| 334 |
+
|
| 335 |
+
# 最も確率の高い候補を取得
|
| 336 |
+
current = candidates.pop()
|
| 337 |
+
|
| 338 |
+
# 単語完成状態をチェック
|
| 339 |
+
root_count = self._count_words(root.get_full_text())
|
| 340 |
+
state, payload = self.check_word_completion(current, root_count, model)
|
| 341 |
+
|
| 342 |
+
if state == WordState.COMPLETE:
|
| 343 |
+
completed.append(current)
|
| 344 |
+
elif state == WordState.INCOMPLETE:
|
| 345 |
+
# ピースを展開
|
| 346 |
+
children = self.expand_piece(current, model)
|
| 347 |
+
if len(children) == 0:
|
| 348 |
+
# 子が生成できない場合、ピースを完成として扱う(無限ループ防止)
|
| 349 |
+
print(f"[WORD_PROCESSOR] No children generated for '{current.get_full_text()}', marking as COMPLETE")
|
| 350 |
+
completed.append(current)
|
| 351 |
+
else:
|
| 352 |
+
for child in children:
|
| 353 |
+
candidates.add(child)
|
| 354 |
+
|
| 355 |
+
# 確率で正規化
|
| 356 |
+
total_prob = sum(p.probability for p in completed)
|
| 357 |
+
if total_prob > 0:
|
| 358 |
+
for piece in completed:
|
| 359 |
+
piece.probability = piece.probability / total_prob
|
| 360 |
+
|
| 361 |
+
return completed[:top_k]
|
| 362 |
+
|
| 363 |
+
def build_chat_prompt(self, user_content: str,
|
| 364 |
+
system_content: str = "あなたは親切で役に立つAIアシスタントです。簡潔な回答をしてください") -> str:
|
| 365 |
+
"""
|
| 366 |
+
チャットプロンプトを構築(後方互換性のため)
|
| 367 |
+
|
| 368 |
+
注意: 新しいBaseAIインターフェースを使用する場合は、model.build_chat_prompt()を使用してください
|
| 369 |
+
"""
|
| 370 |
+
# 既に整形済みのプロンプトが渡されている場合(複数行、ヘッダーを含む)
|
| 371 |
+
# そのまま返す
|
| 372 |
+
if "<|start_header_id|>" in user_content or "<|eot_id|>" in user_content:
|
| 373 |
+
return user_content
|
| 374 |
+
|
| 375 |
+
# 後方互換性: 単一のuser_contentが渡された場合の従来の形式
|
| 376 |
+
prompt_text = (
|
| 377 |
+
f"<|begin_of_text|>"
|
| 378 |
+
f"<|start_header_id|>system<|end_header_id|>\n"
|
| 379 |
+
f"{system_content}\n<|eot_id|>"
|
| 380 |
+
f"<|start_header_id|>user<|end_header_id|>\n"
|
| 381 |
+
f"{user_content}\n<|eot_id|>"
|
| 382 |
+
f"<|start_header_id|>assistant<|end_header_id|>\n"
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# BOS(<|begin_of_text|>) の重複を抑止: 先頭のBOSを全て除去
|
| 386 |
+
BOS = "<|begin_of_text|>"
|
| 387 |
+
s = prompt_text.lstrip()
|
| 388 |
+
while s.startswith(BOS):
|
| 389 |
+
s = s[len(BOS):]
|
| 390 |
+
prompt_text = s
|
| 391 |
+
return prompt_text
|
| 392 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# マルチモデル対応に必要なパッケージ
|
| 2 |
+
|
| 3 |
+
# ============================================
|
| 4 |
+
# Hugging Face Spaces用(推奨、必須)
|
| 5 |
+
# ============================================
|
| 6 |
+
# Gradio - Web UI
|
| 7 |
+
gradio>=4.0.0
|
| 8 |
+
|
| 9 |
+
# Hugging Face Spaces
|
| 10 |
+
spaces
|
| 11 |
+
|
| 12 |
+
# Transformers - ローカルでモデルをロード(無料、完全なトークン確率取得可能)
|
| 13 |
+
transformers>=4.30.0
|
| 14 |
+
torch>=2.0.0
|
| 15 |
+
huggingface-hub>=0.16.0
|
| 16 |
+
|
| 17 |
+
# 形態素解析
|
| 18 |
+
fugashi>=1.3.0
|
| 19 |
+
sudachipy>=0.6.7
|
| 20 |
+
sudachidict-core>=20240125
|
| 21 |
+
|
| 22 |
+
# その他
|
| 23 |
+
numpy>=1.24.0
|
| 24 |
+
|
| 25 |
+
# ============================================
|
| 26 |
+
# 外部API用(オプション、非推奨)
|
| 27 |
+
# Hugging Face Spacesでは不要です
|
| 28 |
+
# ============================================
|
| 29 |
+
# OpenAI API(有料、レート制限あり)
|
| 30 |
+
# openai>=1.0.0
|
| 31 |
+
|
| 32 |
+
# Anthropic API(有料、レート制限あり)
|
| 33 |
+
# anthropic>=0.18.0
|
| 34 |
+
|
| 35 |
+
# Google API(有料、レート制限あり)
|
| 36 |
+
# google-generativeai>=0.3.0
|
| 37 |
+
|