Upload folder using huggingface_hub
- __init__.py +0 -0
- __pycache__/__init__.cpython-313.pyc +0 -0
- __pycache__/auto_training.cpython-313.pyc +0 -0
- __pycache__/coordinate_estimator.cpython-313.pyc +0 -0
- __pycache__/db_providers.cpython-313.pyc +0 -0
- __pycache__/fine_tuning.cpython-313.pyc +0 -0
- __pycache__/iath_db_provider.cpython-313.pyc +0 -0
- __pycache__/iath_memory.cpython-313.pyc +0 -0
- __pycache__/iath_writer.cpython-313.pyc +0 -0
- __pycache__/llm_providers.cpython-313.pyc +0 -0
- __pycache__/model_router.cpython-313.pyc +0 -0
- auto_training.py +403 -0
- coordinate_estimator.py +359 -0
- db_enrichment.py +523 -0
- fine_tuning.py +627 -0
- iath_memory.py +370 -0
- iath_writer.py +453 -0
- llm_providers.py +434 -0
- model_router.py +828 -0
- model_router.py.backup +803 -0
__init__.py
ADDED
File without changes

__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (154 Bytes)

__pycache__/auto_training.cpython-313.pyc
ADDED
Binary file (19 kB)

__pycache__/coordinate_estimator.cpython-313.pyc
ADDED
Binary file (14 kB)

__pycache__/db_providers.cpython-313.pyc
ADDED
Binary file (10.4 kB)

__pycache__/fine_tuning.cpython-313.pyc
ADDED
Binary file (21.8 kB)

__pycache__/iath_db_provider.cpython-313.pyc
ADDED
Binary file (18.3 kB)

__pycache__/iath_memory.cpython-313.pyc
ADDED
Binary file (16.9 kB)

__pycache__/iath_writer.cpython-313.pyc
ADDED
Binary file (17.9 kB)

__pycache__/llm_providers.cpython-313.pyc
ADDED
Binary file (21 kB)

__pycache__/model_router.cpython-313.pyc
ADDED
Binary file (45.1 kB)
auto_training.py
ADDED
@@ -0,0 +1,403 @@
"""
NullAI Auto-Training Manager

Core module of the automatic training system.
Runs fine-tuning automatically on data-volume or time-based triggers.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)


@dataclass
class AutoTrainingState:
    """State of the auto-training system."""
    enabled: bool = True
    last_check_time: Optional[str] = None
    last_training_time: Optional[str] = None
    last_training_success: bool = True
    last_training_examples_count: int = 0
    next_scheduled_training: Optional[str] = None
    total_auto_trainings: int = 0
    consecutive_failures: int = 0
    is_training: bool = False
    last_error: Optional[str] = None


class AutoTrainingManager:
    """
    Auto-training manager.

    Monitors the training data according to the configuration and
    automatically runs fine-tuning once the trigger conditions are met.
    """

    def __init__(self, config: Dict[str, Any], training_manager):
        """
        Args:
            config: The auto_training section of null_ai_config.json
            training_manager: A FineTuningManager instance
        """
        self.config = config
        self.training_manager = training_manager
        self.state = AutoTrainingState()
        self.state_file = Path("training_data/auto_training_state.json")

        # Load configuration values
        self.enabled = config.get("enabled", True)
        self.trigger_mode = config.get("trigger_mode", "hybrid")
        self.min_examples = config.get("min_examples", 100)
        self.min_days = config.get("min_days_since_last_training", 7)
        self.max_days = config.get("max_days_since_last_training", 30)
        self.quality_threshold = config.get("quality_threshold", 0.8)
        self.check_interval_minutes = config.get("check_interval_minutes", 60)
        self.preferred_hour = config.get("preferred_training_hour", 2)
        self.allow_manual_override = config.get("allow_manual_override", True)

        # Training parameters
        self.training_method = config.get("training_method", "peft")
        self.training_params = config.get("training_params", {})

        # Restore persisted state
        self._load_state()

        logger.info(f"AutoTrainingManager initialized: enabled={self.enabled}, trigger_mode={self.trigger_mode}")

    def _load_state(self):
        """Load the persisted state."""
        try:
            if self.state_file.exists():
                with open(self.state_file, 'r') as f:
                    state_dict = json.load(f)
                self.state = AutoTrainingState(**state_dict)
                logger.info(f"Loaded auto-training state from {self.state_file}")
        except Exception as e:
            logger.warning(f"Failed to load auto-training state: {e}")

    def _save_state(self):
        """Persist the state."""
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.state_file, 'w') as f:
                json.dump(asdict(self.state), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save auto-training state: {e}")

    def get_training_data_stats(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Collect statistics about the training data.

        Returns:
            {
                "total_examples": int,
                "examples_by_domain": Dict[str, int],
                "high_quality_count": int,
                "oldest_timestamp": str,
                "newest_timestamp": str
            }
        """
        training_data_dir = Path("training_data/master_outputs")
        if not training_data_dir.exists():
            return {
                "total_examples": 0,
                "examples_by_domain": {},
                "high_quality_count": 0,
                "oldest_timestamp": None,
                "newest_timestamp": None
            }

        stats = {
            "total_examples": 0,
            "examples_by_domain": {},
            "high_quality_count": 0,
            "oldest_timestamp": None,
            "newest_timestamp": None
        }

        # Scan the JSONL files
        jsonl_files = []
        if domain_id:
            jsonl_files = [training_data_dir / f"master_outputs_{domain_id}.jsonl"]
        else:
            jsonl_files = list(training_data_dir.glob("master_outputs_*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                continue

            domain = jsonl_file.stem.replace("master_outputs_", "")
            domain_count = 0

            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        example = json.loads(line.strip())
                        stats["total_examples"] += 1
                        domain_count += 1

                        # Count high-quality examples
                        confidence = example.get("metadata", {}).get("confidence", 0)
                        if confidence >= self.quality_threshold:
                            stats["high_quality_count"] += 1

                        # Track timestamps
                        timestamp = example.get("metadata", {}).get("timestamp")
                        if timestamp:
                            if stats["oldest_timestamp"] is None or timestamp < stats["oldest_timestamp"]:
                                stats["oldest_timestamp"] = timestamp
                            if stats["newest_timestamp"] is None or timestamp > stats["newest_timestamp"]:
                                stats["newest_timestamp"] = timestamp

                    except json.JSONDecodeError:
                        continue

            if domain_count > 0:
                stats["examples_by_domain"][domain] = domain_count

        return stats

    def check_training_trigger(self, domain_id: Optional[str] = None) -> tuple[bool, str]:
        """
        Check whether training should be triggered.

        Returns:
            (should_trigger: bool, reason: str)
        """
        if not self.enabled:
            return False, "Auto-training is disabled"

        if self.state.is_training:
            return False, "Training is already in progress"

        # Collect data statistics
        stats = self.get_training_data_stats(domain_id)

        if stats["total_examples"] == 0:
            return False, "No training data available"

        # Compute the time elapsed since the last training run
        days_since_last = None
        if self.state.last_training_time:
            try:
                last_training = datetime.fromisoformat(self.state.last_training_time)
                days_since_last = (datetime.utcnow() - last_training).days
            except ValueError:
                pass

        # Decide according to the trigger mode
        if self.trigger_mode == "data_count":
            # Data volume only
            if stats["high_quality_count"] >= self.min_examples:
                return True, f"Sufficient training data ({stats['high_quality_count']} examples >= {self.min_examples})"
            return False, f"Insufficient training data ({stats['high_quality_count']} < {self.min_examples})"

        elif self.trigger_mode == "time_based":
            # Time only
            if days_since_last is None:
                return True, "First auto-training"
            if days_since_last >= self.min_days:
                return True, f"Time threshold met ({days_since_last} days >= {self.min_days} days)"
            return False, f"Too soon since last training ({days_since_last} < {self.min_days} days)"

        elif self.trigger_mode == "hybrid":
            # Hybrid (data volume AND time)
            if stats["high_quality_count"] < self.min_examples:
                return False, f"Insufficient training data ({stats['high_quality_count']} < {self.min_examples})"

            if days_since_last is None:
                return True, f"First auto-training with {stats['high_quality_count']} examples"

            if days_since_last >= self.min_days:
                return True, f"Both conditions met: {stats['high_quality_count']} examples, {days_since_last} days since last training"

            return False, f"Time condition not met ({days_since_last} < {self.min_days} days)"

        elif self.trigger_mode == "max_interval":
            # Forced maximum-interval mode
            if days_since_last is not None and days_since_last >= self.max_days:
                return True, f"Maximum interval reached ({days_since_last} >= {self.max_days} days)"

            # Otherwise apply the normal hybrid check
            if stats["high_quality_count"] >= self.min_examples and (days_since_last is None or days_since_last >= self.min_days):
                return True, f"Standard conditions met: {stats['high_quality_count']} examples"

            return False, "Conditions not met"

        return False, f"Unknown trigger mode: {self.trigger_mode}"

    def should_train_now(self) -> bool:
        """
        Check whether the current time is suitable for training.

        The hour before and after preferred_training_hour counts as the
        recommended training window.
        """
        current_hour = datetime.utcnow().hour

        # One hour on either side of the preferred hour
        target_hours = [
            (self.preferred_hour - 1) % 24,
            self.preferred_hour,
            (self.preferred_hour + 1) % 24
        ]

        return current_hour in target_hours

    async def trigger_auto_training(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Run automatic training.

        Returns:
            A dictionary with the training result.
        """
        logger.info(f"Starting auto-training for domain: {domain_id or 'all'}")

        # Update state
        self.state.is_training = True
        self.state.last_check_time = datetime.utcnow().isoformat()
        self._save_state()

        try:
            # Collect data statistics
            stats = self.get_training_data_stats(domain_id)

            # Run the fine-tuning
            # Note: call the method appropriate to the training_manager implementation
            result = await self._execute_training(domain_id, stats)

            # Update state on success
            self.state.last_training_time = datetime.utcnow().isoformat()
            self.state.last_training_success = result.get("success", False)
            self.state.last_training_examples_count = stats["high_quality_count"]
            self.state.total_auto_trainings += 1
            self.state.consecutive_failures = 0
            self.state.last_error = None

            logger.info(f"Auto-training completed successfully: {result}")

            return {
                "success": True,
                "result": result,
                "stats": stats,
                "timestamp": self.state.last_training_time
            }

        except Exception as e:
            logger.error(f"Auto-training failed: {e}", exc_info=True)

            # Update state on failure
            self.state.last_training_success = False
            self.state.consecutive_failures += 1
            self.state.last_error = str(e)

            return {
                "success": False,
                "error": str(e),
                "consecutive_failures": self.state.consecutive_failures
            }

        finally:
            self.state.is_training = False
            self._save_state()

    async def _execute_training(self, domain_id: Optional[str], stats: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the actual training (internal method).
        """
        # Prepare the training parameters
        training_params = {
            "apprentice_model_name": None,  # use the existing apprentice model
            "domain_id": domain_id,
            "method": self.training_method,
            "epochs": self.training_params.get("epochs", 3),
            "learning_rate": self.training_params.get("learning_rate", 2e-4),
            "batch_size": self.training_params.get("batch_size", 4),
            "lora_r": self.training_params.get("lora_r", 8),
            "lora_alpha": self.training_params.get("lora_alpha", 16),
            "output_name": f"auto_training_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        }

        logger.info(f"Executing training with params: {training_params}")

        # Run the training via FineTuningManager.
        # Note: this must be implemented against the actual training API;
        # for now it returns a simple placeholder structure.

        # TODO: implement the actual training execution here
        result = {
            "success": True,
            "output_dir": f"training_data/checkpoints/{training_params['output_name']}",
            "model_name": training_params['output_name'],
            "train_loss": 0.5,  # placeholder
            "method": self.training_method,
            "examples_used": stats["high_quality_count"]
        }

        return result

    def get_status(self) -> Dict[str, Any]:
        """
        Get the current state of the auto-training system.
        """
        should_trigger, reason = self.check_training_trigger()
        stats = self.get_training_data_stats()

        return {
            "enabled": self.enabled,
            "is_training": self.state.is_training,
            "trigger_mode": self.trigger_mode,
            "should_trigger": should_trigger,
            "trigger_reason": reason,
            "config": {
                "min_examples": self.min_examples,
                "min_days": self.min_days,
                "max_days": self.max_days,
                "quality_threshold": self.quality_threshold,
                "check_interval_minutes": self.check_interval_minutes,
                "preferred_hour": self.preferred_hour
            },
            "state": {
                "last_check_time": self.state.last_check_time,
                "last_training_time": self.state.last_training_time,
                "last_training_success": self.state.last_training_success,
                "last_training_examples_count": self.state.last_training_examples_count,
                "total_auto_trainings": self.state.total_auto_trainings,
                "consecutive_failures": self.state.consecutive_failures,
                "last_error": self.state.last_error
            },
            "data_stats": stats,
            "should_train_now": self.should_train_now()
        }

    def enable(self):
        """Enable auto-training."""
        self.enabled = True
        self.state.enabled = True
        self._save_state()
        logger.info("Auto-training enabled")

    def disable(self):
        """Disable auto-training."""
        self.enabled = False
        self.state.enabled = False
        self._save_state()
        logger.info("Auto-training disabled")

    def update_config(self, new_config: Dict[str, Any]):
        """Update the configuration."""
        self.config.update(new_config)

        # Reload the configuration values
        self.trigger_mode = self.config.get("trigger_mode", self.trigger_mode)
        self.min_examples = self.config.get("min_examples", self.min_examples)
        self.min_days = self.config.get("min_days_since_last_training", self.min_days)
        self.max_days = self.config.get("max_days_since_last_training", self.max_days)
        self.quality_threshold = self.config.get("quality_threshold", self.quality_threshold)

        logger.info(f"Auto-training config updated: {new_config}")
        self._save_state()
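For orientation, a minimal usage sketch of the manager above. The config values are illustrative, the module path null_ai.auto_training is assumed from the repo layout, and passing training_manager=None only works because _execute_training is still a placeholder that never touches it:

import asyncio
from null_ai.auto_training import AutoTrainingManager  # module path assumed

# Illustrative config; in the app this comes from the auto_training
# section of null_ai_config.json.
config = {"enabled": True, "trigger_mode": "hybrid", "min_examples": 100}

# A real FineTuningManager should be passed here; None suffices while
# _execute_training returns a placeholder result.
manager = AutoTrainingManager(config, training_manager=None)

should_run, reason = manager.check_training_trigger(domain_id="medical")
print(should_run, reason)

if should_run and manager.should_train_now():
    result = asyncio.run(manager.trigger_auto_training(domain_id="medical"))
    print(result["success"])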
coordinate_estimator.py
ADDED
@@ -0,0 +1,359 @@
# null_ai/coordinate_estimator.py
"""
Coordinate Auto-Estimation Module

Uses an LLM to automatically estimate the 6-dimensional coordinates of knowledge tiles.
Coordinates: [x, y, z, c, g, v]
- medical_space [x, y, z]: a domain-specific 3-dimensional space
- meta_space [c, g, v]: Certainty, Granularity, Verification
"""

import logging
import json
from typing import List, Dict, Any, Optional
import asyncio

logger = logging.getLogger(__name__)


class CoordinateEstimator:
    """
    Estimates 6-dimensional coordinates automatically using an LLM.
    """

    def __init__(self):
        self.domain_schemas = self._load_domain_schemas()

    def _load_domain_schemas(self) -> Dict[str, Dict[str, str]]:
        """
        Return the coordinate-axis definitions for each domain.

        These will eventually be loaded from a configuration file.
        """
        return {
            "medical": {
                "x": "Anatomical location (0.0=nervous system, 0.5=cardiovascular, 1.0=digestive)",
                "y": "Pathological classification (0.0=infectious, 0.5=metabolic, 1.0=trauma)",
                "z": "Treatment level (0.0=prevention, 0.5=diagnosis, 1.0=treatment)"
            },
            "general": {
                "x": "Knowledge category (0.0=science, 0.5=technology, 1.0=humanities)",
                "y": "Complexity level (0.0=basic, 0.5=intermediate, 1.0=advanced)",
                "z": "Application scope (0.0=theoretical, 0.5=practical, 1.0=applied)"
            },
            "legal": {
                "x": "Legal field (0.0=civil, 0.5=criminal, 1.0=commercial)",
                "y": "Court level (0.0=district, 0.5=high, 1.0=supreme)",
                "z": "Era (0.0=classical, 0.5=modern, 1.0=contemporary)"
            },
            "technology": {
                "x": "Technology domain (0.0=hardware, 0.5=software, 1.0=network)",
                "y": "Maturity (0.0=emerging, 0.5=established, 1.0=legacy)",
                "z": "Scale (0.0=personal, 0.5=enterprise, 1.0=global)"
            }
        }

    async def estimate_coordinates(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        llm_inference_func,
        use_reasoning: bool = True
    ) -> Dict[str, Any]:
        """
        Estimate the 6-dimensional coordinates.

        Args:
            prompt: The user's question
            response: The AI's answer
            domain_id: Domain ID
            llm_inference_func: Async LLM inference function
            use_reasoning: Whether to include the reasoning process

        Returns:
            {
                "coordinates": [x, y, z, c, g, v],
                "reasoning": "Rationale for the estimate",
                "confidence": 0.85
            }
        """
        # Fetch the domain schema
        domain_schema = self.domain_schemas.get(
            domain_id,
            self.domain_schemas["general"]  # fallback
        )

        # Build the prompt
        estimation_prompt = self._build_estimation_prompt(
            prompt, response, domain_id, domain_schema, use_reasoning
        )

        # Ask the LLM to estimate the coordinates
        try:
            llm_response = await llm_inference_func(estimation_prompt)

            # Extract the coordinates from the response
            result = self._parse_llm_response(llm_response)

            # Validate
            if self._validate_coordinates(result["coordinates"]):
                logger.info(f"Estimated coordinates for domain '{domain_id}': {result['coordinates']}")
                return result
            else:
                logger.error(f"Invalid coordinates: {result['coordinates']}")
                return self._get_default_coordinates(domain_id)

        except Exception as e:
            logger.error(f"Coordinate estimation failed: {e}")
            return self._get_default_coordinates(domain_id)

    def _build_estimation_prompt(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        domain_schema: Dict[str, str],
        use_reasoning: bool
    ) -> str:
        """
        Build the prompt used for coordinate estimation.
        """
        base_prompt = f"""You are an expert in knowledge space mapping and coordinate estimation.

Your task is to estimate the 6-dimensional coordinates that best represent the following knowledge in the domain of "{domain_id}".

**Coordinate System:**

1. **Domain-specific space [x, y, z]** (each 0.0-1.0):
   - x-axis: {domain_schema['x']}
   - y-axis: {domain_schema['y']}
   - z-axis: {domain_schema['z']}

2. **Meta-information space [c, g, v]** (each 0.0-1.0):
   - c (Certainty): How certain/verified is this knowledge?
     * 0.0 = hypothesis, speculation
     * 0.5 = established theory, widely accepted
     * 1.0 = proven fact, empirically verified

   - g (Granularity): How detailed/specific is this knowledge?
     * 0.0 = high-level overview, general concept
     * 0.5 = detailed explanation
     * 1.0 = highly specialized, expert-level detail

   - v (Verification): What is the verification status?
     * 0.0 = unverified, no sources
     * 0.5 = expert-reviewed, single source
     * 1.0 = peer-reviewed, multiple sources confirmed

**Knowledge to estimate:**

Question: {prompt}

Answer: {response}

**Instructions:**
"""

        if use_reasoning:
            base_prompt += """
1. First, analyze the knowledge and explain your reasoning for each coordinate.
2. Then, output the final coordinates.

Format your response as JSON:
{
    "reasoning": "Your detailed reasoning here...",
    "coordinates": [x, y, z, c, g, v],
    "confidence": 0.85
}
"""
        else:
            base_prompt += """
Output ONLY the coordinates as a JSON object:
{
    "coordinates": [x, y, z, c, g, v],
    "confidence": 0.85
}
"""

        base_prompt += """
**Important:**
- All coordinates must be between 0.0 and 1.0
- Use 2 decimal places (e.g., 0.75)
- confidence should reflect how confident you are in this estimation (0.0-1.0)
"""

        return base_prompt

    def _parse_llm_response(self, llm_response: str) -> Dict[str, Any]:
        """
        Extract the coordinates from the LLM's response.
        """
        try:
            # Look for a JSON block;
            # LLMs often wrap it in ```json ... ```
            if "```json" in llm_response:
                json_start = llm_response.find("```json") + 7
                json_end = llm_response.find("```", json_start)
                json_str = llm_response[json_start:json_end].strip()
            elif "```" in llm_response:
                json_start = llm_response.find("```") + 3
                json_end = llm_response.find("```", json_start)
                json_str = llm_response[json_start:json_end].strip()
            else:
                # Treat the whole response as JSON
                json_str = llm_response.strip()

            # Parse the JSON
            result = json.loads(json_str)

            # Check required fields
            if "coordinates" not in result:
                raise ValueError("Missing 'coordinates' field")

            # Fill in defaults
            if "reasoning" not in result:
                result["reasoning"] = "No reasoning provided"
            if "confidence" not in result:
                result["confidence"] = 0.5

            return result

        except json.JSONDecodeError as e:
            logger.error(f"JSON parse error: {e}")
            logger.debug(f"LLM response: {llm_response}")

            # Fallback: look for a bare list of numbers
            return self._fallback_parse(llm_response)

    def _fallback_parse(self, llm_response: str) -> Dict[str, Any]:
        """
        Fallback used when JSON parsing fails.
        """
        import re

        # Look for a pattern like [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
        pattern = r'\[[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*\]'
        match = re.search(pattern, llm_response)

        if match:
            coords = [float(match.group(i)) for i in range(1, 7)]
            return {
                "coordinates": coords,
                "reasoning": "Parsed from array notation",
                "confidence": 0.5
            }

        # Parsing failed completely
        raise ValueError("Could not parse coordinates from LLM response")

    def _validate_coordinates(self, coordinates: List[float]) -> bool:
        """
        Check that the coordinates are valid.
        """
        if not isinstance(coordinates, list):
            return False

        if len(coordinates) != 6:
            logger.error(f"Expected 6 coordinates, got {len(coordinates)}")
            return False

        for i, coord in enumerate(coordinates):
            if not isinstance(coord, (int, float)):
                logger.error(f"Coordinate {i} is not a number: {coord}")
                return False

            if not (0.0 <= coord <= 1.0):
                logger.error(f"Coordinate {i} out of range [0.0, 1.0]: {coord}")
                return False

        return True

    def _get_default_coordinates(self, domain_id: str) -> Dict[str, Any]:
        """
        Default coordinates used when estimation fails.
        """
        logger.warning(f"Using default coordinates for domain '{domain_id}'")

        # Center of the domain space
        return {
            "coordinates": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
            "reasoning": "Default coordinates (estimation failed)",
            "confidence": 0.3
        }

    async def estimate_batch(
        self,
        knowledge_items: List[Dict[str, str]],
        llm_inference_func,
        max_concurrent: int = 3
    ) -> List[Dict[str, Any]]:
        """
        Estimate coordinates for multiple knowledge items in one batch.

        Args:
            knowledge_items: [{"prompt": "...", "response": "...", "domain_id": "..."}, ...]
            llm_inference_func: LLM inference function
            max_concurrent: Maximum number of concurrent requests

        Returns:
            A list of estimation results.
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def estimate_with_semaphore(item):
            async with semaphore:
                return await self.estimate_coordinates(
                    prompt=item["prompt"],
                    response=item["response"],
                    domain_id=item.get("domain_id", "general"),
                    llm_inference_func=llm_inference_func
                )

        tasks = [estimate_with_semaphore(item) for item in knowledge_items]
        results = await asyncio.gather(*tasks)

        return results

    def get_domain_schema(self, domain_id: str) -> Dict[str, str]:
        """
        Get the domain schema (for UI display).
        """
        return self.domain_schemas.get(domain_id, self.domain_schemas["general"])

    def add_domain_schema(self, domain_id: str, schema: Dict[str, str]):
        """
        Add a new domain schema.
        """
        if not all(key in schema for key in ["x", "y", "z"]):
            raise ValueError("Schema must contain 'x', 'y', 'z' definitions")

        self.domain_schemas[domain_id] = schema
        logger.info(f"Added domain schema for '{domain_id}'")

    def interpolate_coordinates(
        self,
        coord1: List[float],
        coord2: List[float],
        weight: float = 0.5
    ) -> List[float]:
        """
        Interpolate between two coordinates (used to estimate coordinates
        for similar knowledge).

        Args:
            coord1: First coordinate
            coord2: Second coordinate
            weight: Interpolation weight (0.0=coord1, 1.0=coord2)

        Returns:
            The interpolated coordinate.
        """
        if len(coord1) != 6 or len(coord2) != 6:
            raise ValueError("Both coordinates must be 6-dimensional")

        interpolated = [
            coord1[i] * (1 - weight) + coord2[i] * weight
            for i in range(6)
        ]

        return interpolated
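A minimal sketch of calling the estimator with a stubbed inference function. The stub stands in for a real LLM client and the module path is assumed; everything else uses the class exactly as defined above:

import asyncio
from null_ai.coordinate_estimator import CoordinateEstimator  # module path assumed

# Stub LLM that always returns a well-formed coordinate JSON.
async def fake_llm(prompt: str) -> str:
    return '{"coordinates": [0.2, 0.1, 0.9, 0.8, 0.6, 0.7], "confidence": 0.9}'

estimator = CoordinateEstimator()
result = asyncio.run(estimator.estimate_coordinates(
    prompt="What is the first-line treatment for bacterial meningitis?",
    response="Empirical intravenous antibiotics, adjusted once cultures return.",
    domain_id="medical",
    llm_inference_func=fake_llm,
    use_reasoning=False,
))
print(result["coordinates"])  # [0.2, 0.1, 0.9, 0.8, 0.6, 0.7]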
db_enrichment.py
ADDED
@@ -0,0 +1,523 @@
# null_ai/db_enrichment.py
"""
Database Enrichment Module

Uses AI to enrich the knowledge base automatically.
DeepSeek generates questions → the master model answers → saved to .iath
"""

import logging
import asyncio
from typing import List, Dict, Any, Optional, Callable
from datetime import datetime
import uuid
import json

logger = logging.getLogger(__name__)


class AIEnrichmentEngine:
    """
    AI-driven knowledge-base enrichment engine.

    Flow:
    1. DeepSeek (the permanent mentor) generates enrichment questions
    2. The master model answers each question
    3. Coordinates are estimated automatically
    4. The result is saved to .iath
    """

    def __init__(
        self,
        prompt_generator_func,   # LLM function for prompt generation
        answer_generator_func,   # LLM function for answer generation
        coordinate_estimator,    # CoordinateEstimator
        iath_writer              # IathWriter
    ):
        """
        Args:
            prompt_generator_func: async def(prompt) -> str (uses DeepSeek)
            answer_generator_func: async def(prompt) -> dict (uses the master model)
            coordinate_estimator: A CoordinateEstimator instance
            iath_writer: An IathWriter instance
        """
        self.prompt_generator = prompt_generator_func
        self.answer_generator = answer_generator_func
        self.coordinate_estimator = coordinate_estimator
        self.iath_writer = iath_writer

        self.enrichment_state = {
            "is_running": False,
            "progress": 0.0,
            "current_question": 0,
            "total_questions": 0,
            "generated_tiles": 0,
            "start_time": None,
            "domain_id": None
        }

    async def generate_enrichment_questions(
        self,
        domain_id: str,
        num_questions: int = 10,
        focus_areas: Optional[List[str]] = None
    ) -> List[str]:
        """
        Generate enrichment questions using DeepSeek.

        Args:
            domain_id: Domain ID
            num_questions: Number of questions to generate
            focus_areas: Focus areas (e.g. ["basic theory", "applications", "latest research"])

        Returns:
            The list of generated questions.
        """
        logger.info(f"Generating {num_questions} enrichment questions for domain '{domain_id}'")

        # Render the focus areas as text
        focus_text = ""
        if focus_areas:
            focus_text = f"\nFocus on these specific areas: {', '.join(focus_areas)}"

        # Build the prompt
        generation_prompt = f"""You are an expert knowledge curator in the domain of {domain_id}.

Your task is to generate {num_questions} diverse, high-quality questions that would enrich a knowledge base in this domain.

Guidelines:
1. Cover a wide range of topics within the domain
2. Include questions at different complexity levels (basic, intermediate, advanced)
3. Focus on practical applications and edge cases
4. Avoid overly broad or trivial questions
5. Each question should elicit detailed, informative answers{focus_text}

Output format:
Return ONLY a JSON array of questions, like this:
["Question 1?", "Question 2?", ...]

Domain: {domain_id}
Number of questions: {num_questions}

Generate the questions now:"""

        try:
            # Generate the questions with DeepSeek
            logger.info(f"Calling prompt generator with prompt length: {len(generation_prompt)}")
            response = await self.prompt_generator(generation_prompt)

            # Log the response length
            logger.info(f"Received response from LLM, length: {len(response) if response else 0}")
            if not response:
                logger.error("LLM returned empty response!")
                return []

            logger.debug(f"Response preview (first 200 chars): {response[:200]}")

            # Extract the JSON
            questions = self._extract_questions_from_response(response)

            # Remove duplicates while preserving order
            questions = list(dict.fromkeys(questions))

            logger.info(f"Generated {len(questions)} unique questions")
            return questions[:num_questions]

        except Exception as e:
            logger.error(f"Failed to generate enrichment questions: {e}", exc_info=True)
            return []

    def _extract_questions_from_response(self, response: str) -> List[str]:
        """
        Extract the question list from the LLM's response.
        """
        # Reject empty responses
        if not response or not response.strip():
            logger.error(f"Empty response from LLM. Response: '{response}'")
            return []

        try:
            # Look for a JSON block
            json_str = ""
            if "```json" in response:
                json_start = response.find("```json") + 7
                json_end = response.find("```", json_start)
                json_str = response[json_start:json_end].strip()
            elif "```" in response:
                json_start = response.find("```") + 3
                json_end = response.find("```", json_start)
                json_str = response[json_start:json_end].strip()
            elif "[" in response and "]" in response:
                # Look for a bare JSON array
                json_start = response.find("[")
                json_end = response.rfind("]") + 1
                json_str = response[json_start:json_end]
            else:
                json_str = response.strip()

            # Reject an empty json_str
            if not json_str:
                logger.error(f"Could not extract JSON from response. Full response: '{response[:500]}'")
                return self._fallback_parse(response)

            # Parse the JSON
            questions = json.loads(json_str)

            if isinstance(questions, list):
                # Keep only string entries
                return [q for q in questions if isinstance(q, str)]
            else:
                logger.error(f"Expected list, got {type(questions)}")
                return []

        except json.JSONDecodeError as e:
            logger.error(f"JSON parse error: {e}. Attempted to parse: '{json_str[:200]}'")
            logger.error(f"Full response (first 500 chars): '{response[:500]}'")
            return self._fallback_parse(response)

    def _fallback_parse(self, response: str) -> List[str]:
        """
        Fallback used when JSON parsing fails:
        extract questions from newline-separated text.
        """
        logger.info("Attempting fallback parsing of response...")
        lines = response.split('\n')
        questions = []
        for line in lines:
            line = line.strip()
            # Strip numbered-list formatting ("1. Question" -> "Question")
            if line and len(line) > 5:  # at least 5 characters
                if line[0].isdigit() or line.startswith('-') or line.startswith('*'):
                    # Remove leading "1. " or "- "
                    cleaned = line.lstrip('0123456789.-* ').strip()
                    if cleaned and '?' in cleaned:
                        questions.append(cleaned)
                elif '?' in line:
                    # Treat any line containing a question mark as a question
                    questions.append(line)

        logger.info(f"Fallback parsing extracted {len(questions)} questions")
        return questions

    async def enrich_with_ai(
        self,
        domain_id: str,
        num_questions: int = 10,
        focus_areas: Optional[List[str]] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Enrich the knowledge base using AI.

        Args:
            domain_id: Domain ID
            num_questions: Number of questions to generate
            focus_areas: Focus areas
            progress_callback: Progress callback, async def(state)

        Returns:
            The enrichment result.
        """
        self.enrichment_state.update({
            "is_running": True,
            "progress": 0.0,
            "current_question": 0,
            "total_questions": num_questions,
            "generated_tiles": 0,
            "start_time": datetime.utcnow().isoformat(),
            "domain_id": domain_id
        })

        try:
            # Step 1: generate the questions
            logger.info(f"Step 1: Generating {num_questions} questions...")
            questions = await self.generate_enrichment_questions(
                domain_id, num_questions, focus_areas
            )

            if not questions:
                return {
                    "success": False,
                    "error": "Failed to generate questions"
                }

            self.enrichment_state["total_questions"] = len(questions)

            # Step 2: answer each question and save the result
            tiles_created = 0

            for i, question in enumerate(questions):
                logger.info(f"Processing question {i+1}/{len(questions)}: {question[:50]}...")

                self.enrichment_state["current_question"] = i + 1
                self.enrichment_state["progress"] = (i / len(questions)) * 100

                if progress_callback:
                    await progress_callback(self.enrichment_state)

                try:
                    # Generate an answer with the master model
                    answer_result = await self.answer_generator(question)

                    if not answer_result or "response" not in answer_result:
                        logger.warning(f"No response for question: {question}")
                        continue

                    response = answer_result["response"]
                    confidence = answer_result.get("confidence", 0.7)

                    # Inference function used for coordinate estimation
                    async def coord_inference(coord_prompt):
                        result = await self.prompt_generator(coord_prompt)
                        return result

                    # Estimate the coordinates
                    coord_result = await self.coordinate_estimator.estimate_coordinates(
                        prompt=question,
                        response=response,
                        domain_id=domain_id,
                        llm_inference_func=coord_inference,
                        use_reasoning=False
                    )

                    # Create the .iath Tile object
                    from null_ai.iath_writer import create_tile_from_ai_output

                    tile = create_tile_from_ai_output(
                        knowledge_id=f"ai_enrich_{uuid.uuid4().hex}",
                        topic=question[:100],
                        prompt=question,
                        response=response,
                        coordinates=coord_result["coordinates"],
                        confidence=confidence,
                        domain_id=domain_id,
                        source="ai_enrichment"
                    )

                    # Save it to .iath
                    success = self.iath_writer.append_tile(tile)

                    if success:
                        tiles_created += 1
                        self.enrichment_state["generated_tiles"] = tiles_created
                        logger.info(f"Saved enrichment tile {tiles_created}: {question[:50]}...")
                    else:
                        logger.warning(f"Failed to save tile for: {question}")

                except Exception as e:
                    logger.error(f"Error processing question '{question}': {e}")
                    continue

            # Done
            self.enrichment_state.update({
                "is_running": False,
                "progress": 100.0,
                "current_question": len(questions)
            })

            if progress_callback:
                await progress_callback(self.enrichment_state)

            return {
                "success": True,
                "questions_generated": len(questions),
                "tiles_created": tiles_created,
                "domain_id": domain_id
            }

        except Exception as e:
            logger.error(f"AI enrichment failed: {e}")
            self.enrichment_state["is_running"] = False
            return {
                "success": False,
                "error": str(e)
            }

    def get_enrichment_status(self) -> Dict[str, Any]:
        """Get the current enrichment status."""
        return self.enrichment_state.copy()

    def stop_enrichment(self):
        """Stop the enrichment process (simplified implementation)."""
        self.enrichment_state["is_running"] = False
        logger.info("Enrichment stop requested")


class WebEnrichmentEngine:
    """
    Web-search-based knowledge-base enrichment engine.

    Flow:
    1. Web search (Brave Search / Google Custom Search)
    2. Extract knowledge from the result pages
    3. Convert it into Knowledge Tile format
    4. Save it to .iath
    """

    def __init__(
        self,
        search_api_key: Optional[str] = None,
        search_engine: str = "brave",  # "brave" or "google"
        coordinate_estimator=None,
        iath_writer=None
    ):
        """
        Args:
            search_api_key: Search API key
            search_engine: Which search engine to use
            coordinate_estimator: A CoordinateEstimator instance
            iath_writer: An IathWriter instance
        """
        self.api_key = search_api_key
        self.search_engine = search_engine
        self.coordinate_estimator = coordinate_estimator
        self.iath_writer = iath_writer

    async def search_web(self, query: str, max_results: int = 5) -> List[Dict[str, str]]:
        """
        Run a web search.

        Args:
            query: Search query
            max_results: Maximum number of results

        Returns:
            [{"title": "...", "url": "...", "snippet": "..."}, ...]
        """
        if self.search_engine == "brave":
            return await self._search_brave(query, max_results)
        elif self.search_engine == "google":
            return await self._search_google(query, max_results)
        else:
            raise ValueError(f"Unknown search engine: {self.search_engine}")

    async def _search_brave(self, query: str, max_results: int) -> List[Dict[str, str]]:
        """Search with the Brave Search API."""
        try:
            import aiohttp

            if not self.api_key:
                logger.error("Brave Search API key not configured")
                return []

            url = "https://api.search.brave.com/res/v1/web/search"
            headers = {
                "Accept": "application/json",
                "X-Subscription-Token": self.api_key
            }
            params = {
                "q": query,
                "count": max_results
            }

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers, params=params) as response:
                    if response.status == 200:
                        data = await response.json()
                        results = []

                        for item in data.get("web", {}).get("results", []):
                            results.append({
                                "title": item.get("title", ""),
                                "url": item.get("url", ""),
                                "snippet": item.get("description", "")
                            })

                        return results
                    else:
                        logger.error(f"Brave Search API error: {response.status}")
                        return []

        except Exception as e:
            logger.error(f"Brave Search failed: {e}")
            return []

    async def _search_google(self, query: str, max_results: int) -> List[Dict[str, str]]:
        """Search with the Google Custom Search API."""
        # TODO: implement Google Custom Search API
        logger.warning("Google Custom Search not implemented yet")
        return []

    async def enrich_from_web(
        self,
        query: str,
        domain_id: str,
        max_results: int = 5
    ) -> Dict[str, Any]:
        """
        Enrich the knowledge base from web-search results.

        Args:
            query: Search query
            domain_id: Domain ID
            max_results: Maximum number of results

        Returns:
            The enrichment result.
        """
        try:
            # Web search
            search_results = await self.search_web(query, max_results)

            if not search_results:
                return {
                    "success": False,
                    "error": "No search results found"
                }

            # Convert each result into a Knowledge Tile
            tiles_created = 0

            for result in search_results:
                try:
                    # Build the tile content (web-search variant)
                    tile_content = f"""Title: {result['title']}

Source: {result['url']}

Summary:
{result['snippet']}
"""

                    # Coordinate estimation (simplified: default coordinates)
                    # TODO: more sophisticated coordinate estimation
                    coordinates = [0.5, 0.5, 0.5, 0.6, 0.5, 0.7]

                    from null_ai.iath_writer import create_tile_from_ai_output

                    tile = create_tile_from_ai_output(
                        knowledge_id=f"web_{uuid.uuid4().hex}",
                        topic=result['title'],
                        prompt=query,
                        response=tile_content,
                        coordinates=coordinates,
                        confidence=0.6,  # web results get medium confidence
                        domain_id=domain_id,
                        source="web_search"
                    )

                    # Save it to .iath
                    success = self.iath_writer.append_tile(tile)

                    if success:
                        tiles_created += 1
                        logger.info(f"Saved web tile: {result['title']}")

                except Exception as e:
                    logger.error(f"Error creating tile from search result: {e}")
                    continue

            return {
                "success": True,
                "search_results": len(search_results),
                "tiles_created": tiles_created,
                "query": query,
                "domain_id": domain_id
            }

        except Exception as e:
            logger.error(f"Web enrichment failed: {e}")
            return {
                "success": False,
                "error": str(e)
            }
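A minimal sketch of driving the enrichment engine with stub generators. The module paths and the no-argument IathWriter() constructor are assumptions; with stubs this crude, the coordinate estimator cannot parse coordinates from the question-generator output and falls back to its default [0.5, ...] coordinates, which the engine tolerates:

import asyncio
from null_ai.db_enrichment import AIEnrichmentEngine          # module paths assumed
from null_ai.coordinate_estimator import CoordinateEstimator
from null_ai.iath_writer import IathWriter                    # constructor signature assumed

async def question_llm(prompt: str) -> str:   # DeepSeek stand-in
    return '["What causes sepsis?", "How is septic shock managed?"]'

async def answer_llm(prompt: str) -> dict:    # master-model stand-in
    return {"response": f"Answer to: {prompt}", "confidence": 0.8}

engine = AIEnrichmentEngine(question_llm, answer_llm,
                            CoordinateEstimator(), IathWriter())
result = asyncio.run(engine.enrich_with_ai(domain_id="medical", num_questions=2))
print(result)  # e.g. {"success": True, "questions_generated": 2, ...}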
fine_tuning.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# null_ai/fine_tuning.py
"""
NullAI Fine-tuning Module
Implements apprentice model fine-tuning using master outputs (Alpaca format)
"""

import os
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
import asyncio

logger = logging.getLogger(__name__)


class FineTuningManager:
    """
    Manages fine-tuning of apprentice models using master outputs.
    Supports multiple backends: HuggingFace (PEFT/LoRA), Unsloth, MLX
    """

    def __init__(self, training_data_dir: str = "training_data/master_outputs"):
        self.training_data_dir = Path(training_data_dir)
        self.checkpoints_dir = Path("training_data/checkpoints")
        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

        self.current_training_state = {
            "is_training": False,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": 0,
            "loss": 0.0,
            "model_id": None,
            "start_time": None
        }

    # ===== Training Data Loading =====

    def load_training_data(self, domain_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Load training data from Alpaca-format JSONL files.

        Args:
            domain_id: Specific domain to load. If None, loads all domains.

        Returns:
            List of training examples in Alpaca format
        """
        training_examples = []

        if not self.training_data_dir.exists():
            logger.warning(f"Training data directory not found: {self.training_data_dir}")
            return training_examples

        # Determine which files to load
        if domain_id:
            jsonl_files = [self.training_data_dir / f"master_outputs_{domain_id}.jsonl"]
        else:
            jsonl_files = list(self.training_data_dir.glob("master_outputs_*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                logger.warning(f"Training data file not found: {jsonl_file}")
                continue

            logger.info(f"Loading training data from: {jsonl_file}")
            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        example = json.loads(line.strip())
                        training_examples.append(example)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse JSON line in {jsonl_file}: {e}")
                        continue

        logger.info(f"Loaded {len(training_examples)} training examples")
        return training_examples

    def format_training_examples_for_model(
        self,
        training_examples: List[Dict[str, Any]],
        template: str = "alpaca"
    ) -> List[str]:
        """
        Format training examples into model-ready prompts.

        Args:
            training_examples: Raw Alpaca-format examples
            template: Prompt template format ("alpaca", "chatml", "llama3")

        Returns:
            List of formatted prompt strings
        """
        formatted_prompts = []

        for example in training_examples:
            instruction = example.get("instruction", "")
            input_text = example.get("input", "")
            output_text = example.get("output", "")

            if template == "alpaca":
                if input_text:
                    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{output_text}"""
                else:
                    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{output_text}"""

            elif template == "chatml":
                prompt = f"""<|im_start|>system
{instruction}<|im_end|>
<|im_start|>user
{input_text}<|im_end|>
<|im_start|>assistant
{output_text}<|im_end|>"""

            elif template == "llama3":
                prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{instruction}<|eot_id|><|start_header_id|>user<|end_header_id|>
{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{output_text}<|eot_id|>"""

            else:
                raise ValueError(f"Unknown template format: {template}")

            formatted_prompts.append(prompt)

        return formatted_prompts

    # ===== Fine-tuning Backends =====

    async def fine_tune_with_huggingface_peft(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        max_seq_length: int = 512,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune model using HuggingFace Transformers + PEFT (LoRA).

        This is the recommended method for most models.
        Uses QLoRA (4-bit quantization) for memory efficiency.
        """
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                TrainingArguments,
                Trainer,
                DataCollatorForLanguageModeling
            )
            from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
            from datasets import Dataset

        except ImportError as e:
            logger.error(f"Required libraries not installed: {e}")
            logger.error("Please install: pip install transformers peft datasets bitsandbytes accelerate")
            raise

        logger.info(f"Starting PEFT fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.utcnow().isoformat()
        })

        # 1. Load tokenizer
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load model with 4-bit quantization (QLoRA)
        logger.info("Loading model with 4-bit quantization...")
        try:
            from transformers import BitsAndBytesConfig

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                device_map="auto",
                trust_remote_code=True
            )
        except Exception as e:
            logger.warning(f"4-bit quantization failed, falling back to float16: {e}")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

        # 3. Prepare model for training
        model = prepare_model_for_kbit_training(model)

        # 4. Configure LoRA
        logger.info("Configuring LoRA...")
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        # 5. Format training data
        logger.info("Formatting training data...")
        formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")

        # 6. Tokenize dataset
        def tokenize_function(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                max_length=max_seq_length,
                padding="max_length"
            )

        dataset = Dataset.from_dict({"text": formatted_texts})
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names
        )

        # 7. Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            save_total_limit=3,
            warmup_steps=50,
            optim="paged_adamw_8bit",
            report_to="none"  # Disable wandb/tensorboard for now
        )

        # 8. Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False
        )

        # 9. Create trainer with progress callback
        class ProgressCallback:
            def __init__(self, manager, total_epochs, callback):
                self.manager = manager
                self.total_epochs = total_epochs
                self.callback = callback

            def on_epoch_end(self, args, state, control, **kwargs):
                epoch = state.epoch
                loss = state.log_history[-1].get("loss", 0.0) if state.log_history else 0.0

                self.manager.current_training_state.update({
                    "current_epoch": int(epoch),
                    "progress": (epoch / self.total_epochs) * 100,
                    "loss": loss
                })

                if self.callback:
                    asyncio.create_task(self.callback(self.manager.current_training_state))

        from transformers import TrainerCallback

        class CustomCallback(TrainerCallback):
            def __init__(self, progress_cb):
                self.progress_cb = progress_cb

            def on_epoch_end(self, args, state, control, **kwargs):
                self.progress_cb.on_epoch_end(args, state, control, **kwargs)

        progress_cb = ProgressCallback(self, epochs, progress_callback)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
            callbacks=[CustomCallback(progress_cb)]
        )

        # 10. Train!
        logger.info("Starting training...")
        train_result = trainer.train()

        # 11. Save final model
        logger.info(f"Saving model to: {output_dir}")
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)

        # 12. Update state
        self.current_training_state.update({
            "is_training": False,
            "progress": 100.0,
            "current_epoch": epochs
        })

        return {
            "success": True,
            "output_dir": output_dir,
            "train_loss": train_result.training_loss,
            "metrics": train_result.metrics,
            "model_name": model_name,
            "lora_config": {
                "r": lora_r,
                "alpha": lora_alpha,
                "dropout": lora_dropout
            }
        }

    async def fine_tune_with_unsloth(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        lora_r: int = 16,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune model using Unsloth (fastest method, 2x faster than PEFT).

        Unsloth is optimized for speed and memory efficiency.
        Recommended for: Llama, Mistral, Qwen models
        """
        try:
            from unsloth import FastLanguageModel
            from trl import SFTTrainer
            from transformers import TrainingArguments
            from datasets import Dataset
        except ImportError as e:
            logger.error(f"Unsloth not installed: {e}")
            logger.error("Please install: pip install unsloth")
            raise

        logger.info(f"Starting Unsloth fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.utcnow().isoformat()
        })

        # 1. Load model with Unsloth
        logger.info("Loading model with Unsloth...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=2048,
            dtype=None,  # Auto-detect
            load_in_4bit=True
        )

        # 2. Add LoRA adapters
        model = FastLanguageModel.get_peft_model(
            model,
            r=lora_r,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing=True,
            random_state=42
        )

        # 3. Format training data
        formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")
        dataset = Dataset.from_dict({"text": formatted_texts})

        # 4. Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            save_steps=100,
            warmup_steps=50
        )

        # 5. Create SFT trainer
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            max_seq_length=2048,
            args=training_args
        )

        # 6. Train
        logger.info("Starting training with Unsloth...")
        trainer.train()

        # 7. Save
        logger.info(f"Saving model to: {output_dir}")
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        self.current_training_state.update({
            "is_training": False,
            "progress": 100.0
        })

        return {
            "success": True,
            "output_dir": output_dir,
            "model_name": model_name,
            "method": "unsloth"
        }

    async def fine_tune_with_mlx(
        self,
        model_name: str,
        training_examples: List[Dict[str, Any]],
        output_dir: str,
        epochs: int = 3,
        learning_rate: float = 1e-5,
        batch_size: int = 4,
        lora_r: int = 8,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Fine-tune model using MLX (Apple Silicon only, ultra-fast).

        Optimized for M1/M2/M3 Macs.
        Uses unified memory for maximum efficiency.
        """
        try:
            import mlx.core as mx
            from mlx_lm import load, generate
            import mlx.optimizers as optim
            import mlx.nn as nn
        except ImportError as e:
            logger.error(f"MLX not installed: {e}")
            logger.error("Please install: pip install mlx mlx-lm")
            raise

        logger.info(f"Starting MLX fine-tuning for model: {model_name}")
        self.current_training_state.update({
            "is_training": True,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": epochs,
            "model_id": model_name,
            "start_time": datetime.utcnow().isoformat()
        })

        # Note: MLX fine-tuning is still experimental
        # For now, return a placeholder
        logger.warning("MLX fine-tuning is not fully implemented yet")

        self.current_training_state["is_training"] = False

        return {
            "success": False,
            "error": "MLX fine-tuning not yet implemented",
            "model_name": model_name
        }

    # ===== Main Training Interface =====

    async def start_training(
        self,
        apprentice_model_name: str,
        domain_id: Optional[str] = None,
        method: str = "peft",  # "peft", "unsloth", "mlx"
        epochs: int = 3,
        learning_rate: float = 2e-4,
        batch_size: int = 4,
        output_name: Optional[str] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Main entry point for fine-tuning an apprentice model.

        Args:
            apprentice_model_name: HuggingFace model name or path
            domain_id: Domain to train on (None = all domains)
            method: Training method ("peft", "unsloth", "mlx")
            epochs: Number of training epochs
            learning_rate: Learning rate
            batch_size: Batch size per device
            output_name: Custom name for output directory
            progress_callback: Async callback for progress updates

        Returns:
            Training result dictionary
        """
        # 1. Load training data
        training_examples = self.load_training_data(domain_id)

        if not training_examples:
            return {
                "success": False,
                "error": "No training data found"
            }

        # 2. Prepare output directory
        if output_name is None:
            timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
            output_name = f"apprentice_{domain_id or 'all'}_{timestamp}"

        output_dir = self.checkpoints_dir / output_name
        output_dir.mkdir(parents=True, exist_ok=True)

        # 3. Select training method
        if method == "peft":
            result = await self.fine_tune_with_huggingface_peft(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )

        elif method == "unsloth":
            result = await self.fine_tune_with_unsloth(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )

        elif method == "mlx":
            result = await self.fine_tune_with_mlx(
                model_name=apprentice_model_name,
                training_examples=training_examples,
                output_dir=str(output_dir),
                epochs=epochs,
                learning_rate=learning_rate,
                batch_size=batch_size,
                progress_callback=progress_callback
            )

        else:
            return {
                "success": False,
                "error": f"Unknown training method: {method}"
            }

        return result

    def get_training_status(self) -> Dict[str, Any]:
        """Get current training status."""
        return self.current_training_state.copy()

    def stop_training(self):
        """Stop current training (if possible)."""
        # TODO: Implement graceful training interruption
        logger.warning("Training interruption not yet implemented")
        self.current_training_state["is_training"] = False

    def get_training_metrics(self, checkpoint_dir: str) -> Dict[str, Any]:
        """
        Load training metrics from a checkpoint.
        """
        checkpoint_path = Path(checkpoint_dir)

        if not checkpoint_path.exists():
            return {"error": "Checkpoint not found"}

        # Look for trainer_state.json
        trainer_state_file = checkpoint_path / "trainer_state.json"
        if trainer_state_file.exists():
            with open(trainer_state_file, 'r') as f:
                trainer_state = json.load(f)
            return {
                "log_history": trainer_state.get("log_history", []),
                "best_metric": trainer_state.get("best_metric"),
                "best_model_checkpoint": trainer_state.get("best_model_checkpoint")
            }

        return {"error": "No metrics found"}
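A minimal sketch of how the entry point above would be driven. The model id is a placeholder, and the callback only prints the progress state; only the `FineTuningManager` API defined in this file is assumed.

import asyncio
from null_ai.fine_tuning import FineTuningManager

async def run_training():
    manager = FineTuningManager()  # reads training_data/master_outputs by default

    async def on_progress(state):
        # state is the manager's current_training_state dict
        print(f"epoch {state['current_epoch']}/{state['total_epochs']} "
              f"loss={state['loss']:.4f} ({state['progress']:.0f}%)")

    result = await manager.start_training(
        apprentice_model_name="meta-llama/Llama-3.2-1B",  # placeholder model id
        domain_id="medical",  # or None to train on all domains
        method="peft",        # QLoRA via HuggingFace PEFT
        epochs=3,
        progress_callback=on_progress
    )
    print(result)

# asyncio.run(run_training())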
iath_memory.py
ADDED
@@ -0,0 +1,370 @@
"""
NullAI - .iath Memory System
Implementation of the Dendritic Memory Space (tree-structured spatial memory)

A knowledge-retrieval system fully compatible with the .iath file format.
Spatial RAG (Retrieval-Augmented Generation) over a 6-dimensional coordinate system.
"""

import struct
import zstandard as zstd
import json
import logging
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from datetime import datetime

logger = logging.getLogger(__name__)


class IathDecoder:
    """
    Decoder for the .iath file format.
    Maintains compatibility with dendritic-memory-editor.
    """

    def __init__(self, iath_file_path: str):
        """
        Args:
            iath_file_path: Path to the .iath file
        """
        self.file_path = Path(iath_file_path)
        self.header = None
        self.index = []
        self.data_section_offset = 0

        if self.file_path.exists():
            self._load_header_and_index()
        else:
            logger.warning(f".iath file not found: {iath_file_path}")

    def _load_header_and_index(self):
        """Read the header and index sections"""
        try:
            with open(self.file_path, 'rb') as f:
                # Read the header (64 bytes)
                header_data = f.read(64)
                magic, version, domain_code, compression_type, checksum, index_offset, data_offset = struct.unpack(
                    "<4sIBB32sQQ6x", header_data
                )

                self.header = {
                    "magic": magic.decode('ascii'),
                    "version": version,
                    "domain_code": domain_code,
                    "compression_type": compression_type,
                    "index_offset": index_offset,
                    "data_offset": data_offset
                }

                logger.info(f"Loaded .iath header: magic={self.header['magic']}, domain={domain_code}")

                # Read the index
                f.seek(index_offset)
                index_binary = f.read(data_offset - index_offset)
                self.index = json.loads(index_binary.decode('utf-8'))

                self.data_section_offset = data_offset
                logger.info(f"Loaded {len(self.index)} tiles from index")

        except Exception as e:
            logger.error(f"Error loading .iath file: {e}")
            raise

    def _decode_string(self, data: bytes, offset: int) -> Tuple[str, int]:
        """Decode a NULL-terminated string"""
        end = data.find(b'\0', offset)
        if end == -1:
            return data[offset:].decode('utf-8'), len(data)
        return data[offset:end].decode('utf-8'), end + 1

    def _decode_tile_data(self, compressed_data: bytes) -> dict:
        """Decode compressed tile data"""
        # zstd decompression
        dctx = zstd.ZstdDecompressor()
        uncompressed = dctx.decompress(compressed_data)

        offset = 0

        # Metadata
        metadata_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
        offset += 4
        metadata_bin = uncompressed[offset:offset+metadata_len]
        offset += metadata_len

        # Parse the strings inside the metadata
        meta_offset = 0
        knowledge_id, meta_offset = self._decode_string(metadata_bin, meta_offset)
        topic, meta_offset = self._decode_string(metadata_bin, meta_offset)
        created_at = metadata_bin[meta_offset:meta_offset+27].decode('ascii').strip('\0')

        metadata = {
            "knowledge_id": knowledge_id,
            "topic": topic,
            "created_at": created_at
        }

        # Coordinates
        coord_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
        offset += 4
        coord_data = uncompressed[offset:offset+coord_len]
        offset += coord_len

        x, y, z, c, g, v = struct.unpack("<ffffff", coord_data)
        coordinates = {
            "medical_space": [x, y, z],
            "meta_space": [c, g, v]
        }

        # Content
        content_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
        offset += 4
        content_data = uncompressed[offset:offset+content_len]
        offset += content_len

        # thinking_process
        thinking_len = struct.unpack("<I", content_data[0:4])[0]
        thinking = content_data[4:4+thinking_len].decode('utf-8')

        # final_response
        response_offset = 4 + thinking_len
        response_len = struct.unpack("<I", content_data[response_offset:response_offset+4])[0]
        response = content_data[response_offset+4:response_offset+4+response_len].decode('utf-8')

        content = {
            "thinking_process": thinking,
            "final_response": response
        }

        # Verification (simplified implementation)
        verification_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
        offset += 4
        verification_data = uncompressed[offset:offset+verification_len]

        status_code, initial_certainty, reviewer_count = struct.unpack("<BBI", verification_data[:6])
        status_map = {0: "pending_review", 1: "partial_verified", 2: "verified", 3: "expert_confirmed"}

        verification = {
            "status": status_map.get(status_code, "pending_review"),
            "initial_certainty": initial_certainty / 100.0,
            "reviewers": []  # Details omitted
        }

        return {
            "metadata": metadata,
            "coordinates": coordinates,
            "content": content,
            "verification": verification
        }

    def get_tile_by_id(self, knowledge_id: str) -> Optional[dict]:
        """Fetch a specific tile by ID"""
        try:
            # Look it up in the index
            index_entry = next((entry for entry in self.index if entry["id"] == knowledge_id), None)
            if not index_entry:
                return None

            # Read from the data section
            with open(self.file_path, 'rb') as f:
                f.seek(self.data_section_offset + index_entry["offset"])
                compressed_data = f.read(index_entry["length"])
                return self._decode_tile_data(compressed_data)

        except Exception as e:
            logger.error(f"Error loading tile {knowledge_id}: {e}")
            return None

    def get_all_tiles(self) -> List[dict]:
        """Fetch all tiles (watch memory usage)"""
        tiles = []
        for entry in self.index:
            tile = self.get_tile_by_id(entry["id"])
            if tile:
                tiles.append(tile)
        return tiles


class DendriticMemorySpace:
    """
    Tree-structured spatial memory system.

    Spatial placement and retrieval of knowledge in a 6-dimensional coordinate system:
    - medical_space [x, y, z]: domain-specific 3D space
    - meta_space [c, g, v]: Certainty, Granularity, Verification
    """

    def __init__(self, iath_file_path: str):
        """
        Args:
            iath_file_path: Path to the .iath file
        """
        self.decoder = IathDecoder(iath_file_path)
        self.tiles_cache = []  # Cache all tiles in memory (could be optimized)
        self._build_spatial_index()

    def _build_spatial_index(self):
        """Build the spatial index"""
        logger.info("Building spatial index from .iath file...")
        self.tiles_cache = self.decoder.get_all_tiles()

        # Build the coordinate matrix (for fast search)
        if self.tiles_cache:
            self.coordinates_matrix = np.array([
                tile["coordinates"]["medical_space"] + tile["coordinates"]["meta_space"]
                for tile in self.tiles_cache
            ])
            logger.info(f"Spatial index built: {len(self.tiles_cache)} tiles")
        else:
            self.coordinates_matrix = np.array([])
            logger.warning("No tiles found in .iath file")

    def search_by_coordinates(self, query_coords: List[float], top_k: int = 5, distance_threshold: Optional[float] = None) -> List[dict]:
        """
        Coordinate-based spatial search (the core of the dendritic memory space).

        Args:
            query_coords: Query coordinates [x, y, z, c, g, v]
            top_k: Number of top results to return
            distance_threshold: Distance cutoff (None = unlimited)

        Returns:
            List of relevant tiles, nearest first
        """
        if len(self.tiles_cache) == 0:
            return []

        query_vector = np.array(query_coords)

        # Compute Euclidean distances
        distances = np.linalg.norm(self.coordinates_matrix - query_vector, axis=1)

        # Sort by distance
        sorted_indices = np.argsort(distances)

        # Take the top_k entries
        results = []
        for idx in sorted_indices[:top_k]:
            distance = distances[idx]
            if distance_threshold is not None and distance > distance_threshold:
                break

            tile = self.tiles_cache[idx].copy()
            tile["spatial_distance"] = float(distance)
            results.append(tile)

        return results

    def search_by_text(self, query_text: str, top_k: int = 5) -> List[dict]:
        """
        Text search (simple implementation: keyword matching).

        Args:
            query_text: Search query text
            top_k: Number of top results to return

        Returns:
            List of relevant tiles
        """
        query_lower = query_text.lower()
        matches = []

        for tile in self.tiles_cache:
            topic = tile["metadata"]["topic"].lower()
            content = tile["content"]["final_response"].lower()

            # Simple scoring (occurrence counts)
            score = topic.count(query_lower) * 2 + content.count(query_lower)

            if score > 0:
                tile_copy = tile.copy()
                tile_copy["text_match_score"] = score
                matches.append(tile_copy)

        # Sort by score
        matches.sort(key=lambda x: x["text_match_score"], reverse=True)

        return matches[:top_k]

    def hybrid_search(self, query_text: str, query_coords: Optional[List[float]] = None, top_k: int = 5) -> List[dict]:
        """
        Hybrid search: text + spatial coordinates.

        Args:
            query_text: Search query text
            query_coords: Query coordinates (optional)
            top_k: Number of top results to return

        Returns:
            List of relevant tiles
        """
        # Text search
        text_results = self.search_by_text(query_text, top_k=top_k*2)

        # If coordinate search was requested as well
        if query_coords:
            spatial_results = self.search_by_coordinates(query_coords, top_k=top_k*2)

            # Score tiles that appear in both result sets preferentially
            combined_scores = {}

            for tile in text_results:
                tile_id = tile["metadata"]["knowledge_id"]
                combined_scores[tile_id] = {
                    "tile": tile,
                    "text_score": tile.get("text_match_score", 0),
                    "spatial_score": 0
                }

            for tile in spatial_results:
                tile_id = tile["metadata"]["knowledge_id"]
                if tile_id in combined_scores:
                    combined_scores[tile_id]["spatial_score"] = 1.0 / (1.0 + tile.get("spatial_distance", 10))
                else:
                    combined_scores[tile_id] = {
                        "tile": tile,
                        "text_score": 0,
                        "spatial_score": 1.0 / (1.0 + tile.get("spatial_distance", 10))
                    }

            # Rank by combined score
            ranked = sorted(
                combined_scores.values(),
                key=lambda x: x["text_score"] * 0.6 + x["spatial_score"] * 0.4,
                reverse=True
            )

            return [item["tile"] for item in ranked[:top_k]]

        else:
            # Text search only
            return text_results[:top_k]

    def get_statistics(self) -> dict:
        """Return statistics about the memory space"""
        if len(self.tiles_cache) == 0:
            return {"total_tiles": 0}

        coords = self.coordinates_matrix

        return {
            "total_tiles": len(self.tiles_cache),
            "coordinate_ranges": {
                "medical_x": {"min": float(coords[:, 0].min()), "max": float(coords[:, 0].max())},
                "medical_y": {"min": float(coords[:, 1].min()), "max": float(coords[:, 1].max())},
                "medical_z": {"min": float(coords[:, 2].min()), "max": float(coords[:, 2].max())},
                "certainty": {"min": float(coords[:, 3].min()), "max": float(coords[:, 3].max())},
                "granularity": {"min": float(coords[:, 4].min()), "max": float(coords[:, 4].max())},
                "verification": {"min": float(coords[:, 5].min()), "max": float(coords[:, 5].max())}
            },
            "verification_status_distribution": self._get_verification_distribution()
        }

    def _get_verification_distribution(self) -> dict:
        """Return the distribution of verification statuses"""
        distribution = {}
        for tile in self.tiles_cache:
            status = tile["verification"]["status"]
            distribution[status] = distribution.get(status, 0) + 1
        return distribution
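A minimal retrieval sketch against the class above. The file path and query coordinates are illustrative; the 6-D vector layout [x, y, z, certainty, granularity, verification] and the 0.6/0.4 hybrid weighting come from the code itself.

from null_ai.iath_memory import DendriticMemorySpace

memory = DendriticMemorySpace("knowledge/medical.iath")  # illustrative path

# Pure spatial lookup around a 6-D point
nearby = memory.search_by_coordinates([0.5, 0.5, 0.5, 0.8, 0.5, 0.7], top_k=3)
for tile in nearby:
    print(tile["metadata"]["topic"], tile["spatial_distance"])

# Hybrid retrieval: keyword score blended (0.6/0.4) with spatial proximity
hits = memory.hybrid_search("hypertension", query_coords=[0.5, 0.5, 0.5, 0.8, 0.5, 0.7])
print(memory.get_statistics()["total_tiles"])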
iath_writer.py
ADDED
@@ -0,0 +1,453 @@
| 1 |
+
# null_ai/iath_writer.py
|
| 2 |
+
"""
|
| 3 |
+
.iath File Writer Module
|
| 4 |
+
|
| 5 |
+
AI生成知識を.iath形式で保存するためのモジュール。
|
| 6 |
+
dendritic-memory-editor完全互換。
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import struct
|
| 10 |
+
import zstandard as zstd
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import json
|
| 13 |
+
import logging
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from typing import Dict, List, Any, Optional
|
| 16 |
+
import uuid
|
| 17 |
+
|
| 18 |
+
from null_ai.iath_memory import IathDecoder
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class IathWriter:
|
| 24 |
+
"""
|
| 25 |
+
.iathファイルへの書き込みを管理するクラス
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, iath_file_path: str):
|
| 29 |
+
self.file_path = Path(iath_file_path)
|
| 30 |
+
self.encoder = IathTileEncoder()
|
| 31 |
+
|
| 32 |
+
def append_tile(self, tile: Dict[str, Any]) -> bool:
|
| 33 |
+
"""
|
| 34 |
+
既存の.iathファイルに新しいタイルを追記
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
tile: Knowledge Tile オブジェクト
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
成功したかどうか
|
| 41 |
+
"""
|
| 42 |
+
try:
|
| 43 |
+
# 既存ファイルが存在するか確認
|
| 44 |
+
if not self.file_path.exists():
|
| 45 |
+
# 新規ファイル作成
|
| 46 |
+
return self._create_new_iath_file([tile])
|
| 47 |
+
|
| 48 |
+
# 既存ファイルの読み込み
|
| 49 |
+
decoder = IathDecoder(str(self.file_path))
|
| 50 |
+
existing_tiles = decoder.get_all_tiles()
|
| 51 |
+
|
| 52 |
+
# 新しいタイルを追加
|
| 53 |
+
existing_tiles.append(tile)
|
| 54 |
+
|
| 55 |
+
# ファイルを再構築
|
| 56 |
+
return self._rebuild_iath_file(existing_tiles)
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
logger.error(f"Failed to append tile to .iath: {e}")
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
def append_tiles_batch(self, tiles: List[Dict[str, Any]]) -> bool:
|
| 63 |
+
"""
|
| 64 |
+
複数のタイルを一括追記
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
tiles: Knowledge Tile のリスト
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
成功したかどうか
|
| 71 |
+
"""
|
| 72 |
+
try:
|
| 73 |
+
if not self.file_path.exists():
|
| 74 |
+
return self._create_new_iath_file(tiles)
|
| 75 |
+
|
| 76 |
+
decoder = IathDecoder(str(self.file_path))
|
| 77 |
+
existing_tiles = decoder.get_all_tiles()
|
| 78 |
+
|
| 79 |
+
# 新しいタイルを追加
|
| 80 |
+
existing_tiles.extend(tiles)
|
| 81 |
+
|
| 82 |
+
# ファイルを再構築
|
| 83 |
+
return self._rebuild_iath_file(existing_tiles)
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
logger.error(f"Failed to append tiles batch to .iath: {e}")
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
def _create_new_iath_file(self, tiles: List[Dict[str, Any]]) -> bool:
|
| 90 |
+
"""
|
| 91 |
+
新規.iathファイルを作成
|
| 92 |
+
"""
|
| 93 |
+
try:
|
| 94 |
+
with open(self.file_path, 'wb') as f:
|
| 95 |
+
# Header作成
|
| 96 |
+
header = self._create_header(len(tiles))
|
| 97 |
+
f.write(header)
|
| 98 |
+
|
| 99 |
+
# Index作成
|
| 100 |
+
index_data, tiles_data = self._create_index_and_data(tiles)
|
| 101 |
+
index_json = json.dumps(index_data, ensure_ascii=False).encode('utf-8')
|
| 102 |
+
|
| 103 |
+
# Indexサイズを書き込み
|
| 104 |
+
f.write(struct.pack('<I', len(index_json)))
|
| 105 |
+
f.write(index_json)
|
| 106 |
+
|
| 107 |
+
# Dataセクション書き込み
|
| 108 |
+
f.write(tiles_data)
|
| 109 |
+
|
| 110 |
+
logger.info(f"Created new .iath file: {self.file_path}")
|
| 111 |
+
return True
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"Failed to create new .iath file: {e}")
|
| 115 |
+
return False
|
| 116 |
+
|
| 117 |
+
def _rebuild_iath_file(self, tiles: List[Dict[str, Any]]) -> bool:
|
| 118 |
+
"""
|
| 119 |
+
.iathファイルを再構築
|
| 120 |
+
"""
|
| 121 |
+
# 一時ファイルに書き込み
|
| 122 |
+
temp_path = self.file_path.with_suffix('.iath.tmp')
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
with open(temp_path, 'wb') as f:
|
| 126 |
+
# Header作成
|
| 127 |
+
header = self._create_header(len(tiles))
|
| 128 |
+
f.write(header)
|
| 129 |
+
|
| 130 |
+
# Index作成
|
| 131 |
+
index_data, tiles_data = self._create_index_and_data(tiles)
|
| 132 |
+
index_json = json.dumps(index_data, ensure_ascii=False).encode('utf-8')
|
| 133 |
+
|
| 134 |
+
# Indexサイズを書き込み
|
| 135 |
+
f.write(struct.pack('<I', len(index_json)))
|
| 136 |
+
f.write(index_json)
|
| 137 |
+
|
| 138 |
+
# Dataセクション書き込み
|
| 139 |
+
f.write(tiles_data)
|
| 140 |
+
|
| 141 |
+
# 一時ファイルを本ファイルに置き換え
|
| 142 |
+
temp_path.replace(self.file_path)
|
| 143 |
+
|
| 144 |
+
logger.info(f"Rebuilt .iath file with {len(tiles)} tiles: {self.file_path}")
|
| 145 |
+
return True
|
| 146 |
+
|
| 147 |
+
except Exception as e:
|
| 148 |
+
logger.error(f"Failed to rebuild .iath file: {e}")
|
| 149 |
+
if temp_path.exists():
|
| 150 |
+
temp_path.unlink()
|
| 151 |
+
return False
|
| 152 |
+
|
| 153 |
+
def _create_header(self, tile_count: int) -> bytes:
|
| 154 |
+
"""
|
| 155 |
+
.iathヘッダーを作成(64バイト)
|
| 156 |
+
"""
|
| 157 |
+
# マジックナンバー: "IATH" (4 bytes)
|
| 158 |
+
magic = b'IATH'
|
| 159 |
+
|
| 160 |
+
# バージョン: 1.0 (2 bytes)
|
| 161 |
+
version_major = 1
|
| 162 |
+
version_minor = 0
|
| 163 |
+
|
| 164 |
+
# タイル数 (4 bytes)
|
| 165 |
+
# 作成日時 (8 bytes, Unix timestamp)
|
| 166 |
+
created_at = int(datetime.now().timestamp())
|
| 167 |
+
|
| 168 |
+
# 予約領域 (46 bytes)
|
| 169 |
+
reserved = b'\x00' * 46
|
| 170 |
+
|
| 171 |
+
header = struct.pack(
|
| 172 |
+
'<4s 2B I Q 46s',
|
| 173 |
+
magic,
|
| 174 |
+
version_major,
|
| 175 |
+
version_minor,
|
| 176 |
+
tile_count,
|
| 177 |
+
created_at,
|
| 178 |
+
reserved
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
assert len(header) == 64, f"Header size must be 64 bytes, got {len(header)}"
|
| 182 |
+
return header
|
| 183 |
+
|
| 184 |
+
def _create_index_and_data(self, tiles: List[Dict[str, Any]]) -> tuple:
|
| 185 |
+
"""
|
| 186 |
+
インデックスとデータセクションを作成
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
(index_dict, data_bytes)
|
| 190 |
+
"""
|
| 191 |
+
index = {"tiles": []}
|
| 192 |
+
data_buffer = bytearray()
|
| 193 |
+
|
| 194 |
+
current_offset = 0
|
| 195 |
+
|
| 196 |
+
for tile in tiles:
|
| 197 |
+
# タイルをエンコード
|
| 198 |
+
tile_bytes = self.encoder.encode_tile(tile)
|
| 199 |
+
|
| 200 |
+
# インデックスエントリ作成
|
| 201 |
+
index["tiles"].append({
|
| 202 |
+
"id": tile["metadata"]["knowledge_id"],
|
| 203 |
+
"offset": current_offset,
|
| 204 |
+
"size": len(tile_bytes)
|
| 205 |
+
})
|
| 206 |
+
|
| 207 |
+
# データバッファに追加
|
| 208 |
+
data_buffer.extend(tile_bytes)
|
| 209 |
+
current_offset += len(tile_bytes)
|
| 210 |
+
|
| 211 |
+
return index, bytes(data_buffer)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class IathTileEncoder:
|
| 215 |
+
"""
|
| 216 |
+
Knowledge Tileを.iath互換のバイナリ形式にエンコード
|
| 217 |
+
|
| 218 |
+
dendritic-memory-editorのIathEncoderと互換性あり
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
def encode_tile(self, tile: Dict[str, Any]) -> bytes:
|
| 222 |
+
"""
|
| 223 |
+
単一のKnowledge Tileをエンコードし、zstdで圧縮
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
tile: Knowledge Tile オブジェクト
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
圧縮されたバイナリデータ
|
| 230 |
+
"""
|
| 231 |
+
# 各セクションをエンコード
|
| 232 |
+
metadata_bin = self._encode_metadata(tile["metadata"])
|
| 233 |
+
coord_bin = self._encode_coordinates(tile["coordinates"])
|
| 234 |
+
content_bin = self._encode_content(tile["content"])
|
| 235 |
+
verification_bin = self._encode_verification(tile["verification"])
|
| 236 |
+
|
| 237 |
+
# 長さプレフィックスを付けて連結
|
| 238 |
+
uncompressed = b"".join([
|
| 239 |
+
struct.pack("<I", len(metadata_bin)), metadata_bin,
|
| 240 |
+
struct.pack("<I", len(coord_bin)), coord_bin,
|
| 241 |
+
struct.pack("<I", len(content_bin)), content_bin,
|
| 242 |
+
struct.pack("<I", len(verification_bin)), verification_bin,
|
| 243 |
+
])
|
| 244 |
+
|
| 245 |
+
# zstdで圧縮(レベル19 = 最高圧縮率)
|
| 246 |
+
cctx = zstd.ZstdCompressor(level=19)
|
| 247 |
+
compressed = cctx.compress(uncompressed)
|
| 248 |
+
|
| 249 |
+
return compressed
|
| 250 |
+
|
| 251 |
+
def _encode_metadata(self, metadata: Dict[str, Any]) -> bytes:
|
| 252 |
+
"""メタデータをバイナリ化"""
|
| 253 |
+
kid = self._encode_string(metadata["knowledge_id"])
|
| 254 |
+
topic = self._encode_string(metadata["topic"])
|
| 255 |
+
|
| 256 |
+
# 作成日時(ISO形式)
|
| 257 |
+
created_at_iso = metadata.get("created_at", datetime.now().isoformat())
|
| 258 |
+
created_at = created_at_iso.encode('ascii')[:27] # ISO format with Z
|
| 259 |
+
|
| 260 |
+
return kid + topic + created_at
|
| 261 |
+
|
| 262 |
+
def _encode_coordinates(self, coordinates: Dict[str, Any]) -> bytes:
|
| 263 |
+
"""
|
| 264 |
+
座標をバイナリ化(6つの浮動小数点数)
|
| 265 |
+
|
| 266 |
+
座標は以下のいずれかの形式を受け付ける:
|
| 267 |
+
1. {"medical_space": [x, y, z], "meta_space": [c, g, v]}
|
| 268 |
+
2. [x, y, z, c, g, v]
|
| 269 |
+
"""
|
| 270 |
+
if isinstance(coordinates, dict):
|
| 271 |
+
medical_space = coordinates["medical_space"]
|
| 272 |
+
meta_space = coordinates["meta_space"]
|
| 273 |
+
|
| 274 |
+
return struct.pack(
|
| 275 |
+
"<ffffff",
|
| 276 |
+
float(medical_space[0]), float(medical_space[1]), float(medical_space[2]),
|
| 277 |
+
float(meta_space[0]), float(meta_space[1]), float(meta_space[2])
|
| 278 |
+
)
|
| 279 |
+
elif isinstance(coordinates, list) and len(coordinates) == 6:
|
| 280 |
+
# フラットな配列形式
|
| 281 |
+
return struct.pack(
|
| 282 |
+
"<ffffff",
|
| 283 |
+
float(coordinates[0]), float(coordinates[1]), float(coordinates[2]),
|
| 284 |
+
float(coordinates[3]), float(coordinates[4]), float(coordinates[5])
|
| 285 |
+
)
|
| 286 |
+
else:
|
| 287 |
+
raise ValueError(f"Invalid coordinates format: {coordinates}")
|
| 288 |
+
|
| 289 |
+
def _encode_content(self, content: Dict[str, Any]) -> bytes:
|
| 290 |
+
"""コンテンツ(テキスト)をバイナリ化"""
|
| 291 |
+
thinking = content.get("thinking_process", "").encode('utf-8')
|
| 292 |
+
response = content.get("final_response", "").encode('utf-8')
|
| 293 |
+
|
| 294 |
+
# 各パートの長さを前に付けて連結
|
| 295 |
+
result = struct.pack("<I", len(thinking)) + thinking
|
| 296 |
+
result += struct.pack("<I", len(response)) + response
|
| 297 |
+
|
| 298 |
+
return result
|
| 299 |
+
|
| 300 |
+
def _encode_verification(self, verification: Dict[str, Any]) -> bytes:
|
| 301 |
+
"""検証履歴をバイナリ化"""
|
| 302 |
+
status_map = {
|
| 303 |
+
"pending_review": 0,
|
| 304 |
+
"partial_verified": 1,
|
| 305 |
+
"verified": 2,
|
| 306 |
+
"expert_confirmed": 3
|
| 307 |
+
}
|
| 308 |
+
status_code = status_map.get(verification.get("status", "pending_review"), 0)
|
| 309 |
+
|
| 310 |
+
initial_certainty = int(verification.get("initial_certainty", 0) * 100) # 0-100に変換
|
| 311 |
+
reviewer_count = len(verification.get("reviewers", []))
|
| 312 |
+
|
| 313 |
+
result = struct.pack("<BBI", status_code, initial_certainty, reviewer_count)
|
| 314 |
+
|
| 315 |
+
# レビュアー情報
|
| 316 |
+
for reviewer in verification.get("reviewers", []):
|
| 317 |
+
result += self._encode_reviewer_reference(reviewer)
|
| 318 |
+
|
| 319 |
+
return result
|
| 320 |
+
|
| 321 |
+
def _encode_reviewer_reference(self, reviewer: Dict[str, Any]) -> bytes:
|
| 322 |
+
"""レビュアー情報をエンコード"""
|
| 323 |
+
reviewer_id = reviewer.get("reviewer_id", "unknown").encode('utf-8')
|
| 324 |
+
return struct.pack("<36s", reviewer_id[:36]) # UUID string length
|
| 325 |
+
|
| 326 |
+
def _encode_string(self, s: str) -> bytes:
|
| 327 |
+
"""NULL終端のUTF-8文字列をエンコード"""
|
| 328 |
+
return s.encode('utf-8') + b'\0'
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def create_tile_from_ai_output(
    knowledge_id: str,
    topic: str,
    prompt: str,
    response: str,
    coordinates: List[float],
    confidence: float,
    domain_id: str,
    source: str = "ai_generated"
) -> Dict[str, Any]:
    """
    Create a Knowledge Tile object from AI output.

    Args:
        knowledge_id: Knowledge ID (UUID recommended)
        topic: Topic
        prompt: The user's question
        response: The AI's answer
        coordinates: 6-dimensional coordinates [x, y, z, c, g, v]
        confidence: Confidence score (0.0-1.0)
        domain_id: Domain ID
        source: Source ("ai_generated", "human_verified", etc.)

    Returns:
        A Knowledge Tile object
    """
    if len(coordinates) != 6:
        raise ValueError(f"Coordinates must be 6-dimensional, got {len(coordinates)}")

    # Determine the verification status from the confidence score
    if confidence >= 0.9:
        status = "expert_confirmed"
    elif confidence >= 0.8:
        status = "verified"
    elif confidence >= 0.7:
        status = "partial_verified"
    else:
        status = "pending_review"

    tile = {
        "metadata": {
            "knowledge_id": knowledge_id,
            "topic": topic,
            "created_at": datetime.now().isoformat(),
            "domain_id": domain_id,
            "source": source
        },
        "coordinates": {
            "medical_space": coordinates[:3],
            "meta_space": coordinates[3:]
        },
        "content": {
            "thinking_process": f"Question: {prompt}",
            "final_response": response
        },
        "verification": {
            "status": status,
            "initial_certainty": confidence,
            "reviewers": []
        }
    }

    return tile

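# A minimal illustrative call of create_tile_from_ai_output (values invented;
# not part of the original module). With confidence 0.85 the tile lands in the
# "verified" tier per the thresholds above.
def _example_tile() -> Dict[str, Any]:
    tile = create_tile_from_ai_output(
        knowledge_id=str(uuid.uuid4()),
        topic="Example topic",                         # hypothetical
        prompt="What is ...?",
        response="...",
        coordinates=[0.1, 0.2, 0.3, 0.85, 0.5, 0.7],   # [x, y, z, c, g, v]
        confidence=0.85,
        domain_id="general"
    )
    assert tile["verification"]["status"] == "verified"  # 0.8 <= 0.85 < 0.9
    return tile
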
def merge_jsonl_to_iath(
    jsonl_path: str,
    iath_path: str,
    coordinate_estimator,
    llm_inference_func
) -> int:
    """
    Convert training data from a JSONL file into .iath format and append it.

    Args:
        jsonl_path: Path to an Alpaca-format JSONL file
        iath_path: Path to the output .iath file
        coordinate_estimator: CoordinateEstimator instance
        llm_inference_func: LLM inference function

    Returns:
        Number of tiles converted
    """
    import asyncio

    writer = IathWriter(iath_path)
    tiles_created = 0

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                example = json.loads(line.strip())

                # Estimate coordinates
                coord_result = asyncio.run(coordinate_estimator.estimate_coordinates(
                    prompt=example["input"],
                    response=example["output"],
                    domain_id=example["metadata"]["domain_id"],
                    llm_inference_func=llm_inference_func
                ))

                # Build the tile object
                tile = create_tile_from_ai_output(
                    knowledge_id=str(uuid.uuid4()),
                    topic=example["input"][:100],  # use the first 100 characters as the topic
                    prompt=example["input"],
                    response=example["output"],
                    coordinates=coord_result["coordinates"],
                    confidence=example["metadata"]["confidence"],
                    domain_id=example["metadata"]["domain_id"],
                    source="master_output"
                )

                # Append to the .iath file
                writer.append_tile(tile)
                tiles_created += 1

            except Exception as e:
                logger.error(f"Failed to convert line to tile: {e}")
                continue

    logger.info(f"Merged {tiles_created} tiles from {jsonl_path} to {iath_path}")
    return tiles_created
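
Each JSONL line must be an object carrying at least the keys the loop above reads: "input", "output", "metadata.domain_id", and "metadata.confidence". A plausible record with illustrative values (other Alpaca fields such as "instruction" are tolerated but unused by this function):

{"instruction": "", "input": "What are the contraindications of drug X?", "output": "...", "metadata": {"domain_id": "general", "confidence": 0.85}}
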
llm_providers.py
ADDED
@@ -0,0 +1,434 @@
# null_ai/llm_providers.py
import logging
import asyncio
import time
import threading  # For TextIteratorStreamer
from typing import Dict, Any, AsyncGenerator, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch

from backend.app.config import ModelConfig, ModelProvider  # uses ModelConfig

logger = logging.getLogger(__name__)

# --- HuggingFace Transformers Provider ---
class HuggingFaceProvider:
    """
    Manages interaction with HuggingFace Transformers models.
    """
    _loaded_models: Dict[str, Any] = {}  # {model_name: {"model": model, "tokenizer": tokenizer}}

    async def _load_model_and_tokenizer(self, model_name: str) -> Dict[str, Any]:
        """
        Load (and cache) a HuggingFace model and tokenizer.
        """
        if model_name in self._loaded_models:
            return self._loaded_models[model_name]

        logger.info(f"Loading HuggingFace model: {model_name}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # use float16 for speed
                device_map="auto",          # map automatically onto available devices
                trust_remote_code=True
            )
            self._loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
            logger.info(f"HuggingFace model '{model_name}' loaded successfully.")
            return self._loaded_models[model_name]
        except Exception as e:
            logger.error(f"Failed to load HuggingFace model '{model_name}': {e}")
            raise

    async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run non-streaming inference on a HuggingFace model.
        """
        model_data = await self._load_model_and_tokenizer(model_config.model_name)
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        start_time = time.perf_counter()
        outputs = model.generate(
            **inputs,
            max_new_tokens=model_config.max_tokens,
            temperature=temperature,
            do_sample=temperature > 0,
            pad_token_id=tokenizer.eos_token_id  # not strictly required without streaming, but safe
        )
        end_time = time.perf_counter()

        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        latency_ms = (end_time - start_time) * 1000

        # Extracting a reasoning chain and computing confidence is not done here (that is ModelRouter's job)
        return {
            "response": response_text,
            "thinking": f"Inferred by {model_config.display_name} (HuggingFace).",
            "confidence": 0.8,  # placeholder value
            "latency_ms": latency_ms
        }

    async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming inference on a HuggingFace model.
        """
        model_data = await self._load_model_and_tokenizer(model_config.model_name)
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Use TextIteratorStreamer
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generation_kwargs = {
            **inputs,
            "max_new_tokens": model_config.max_tokens,
            "temperature": temperature,
            "do_sample": temperature > 0,
            "streamer": streamer,
            "pad_token_id": tokenizer.eos_token_id
        }

        # Run generation in a separate thread: generate() blocks, so it runs in
        # its own thread while this coroutine drains the streamer.
        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        start_time = time.perf_counter()

        yield {"type": "thinking", "content": f"Loading model: {model_config.display_name}..."}

        for new_token in streamer:
            generated_text += new_token
            yield {"type": "token", "content": new_token}

        thread.join()  # wait for generation to finish
        end_time = time.perf_counter()

        latency_ms = (end_time - start_time) * 1000

        yield {"type": "complete", "content": generated_text}
        yield {"type": "meta",
               "confidence": 0.8,  # placeholder value
               "thinking": f"Inferred by {model_config.display_name} (HuggingFace).",
               "latency_ms": latency_ms
        }

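# Illustrative consumer of the streaming interface above — a sketch, not part
# of the original module. It drains infer_streaming() and collects the final
# text; "thinking" events are silently ignored here.
async def _consume_stream(provider, model_config, prompt: str) -> str:
    final_text = ""
    async for event in provider.infer_streaming(model_config, prompt, temperature=0.7):
        if event["type"] == "token":
            print(event["content"], end="", flush=True)
        elif event["type"] == "complete":
            final_text = event["content"]
        elif event["type"] == "meta":
            print(f"\n[latency: {event['latency_ms']:.0f} ms, confidence: {event['confidence']}]")
    return final_text
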
# --- Ollama Provider ---
class OllamaProvider:
    _ollama_client: Any = None  # the ollama client is lazily loaded
    _loaded_models: Dict[str, bool] = {}  # {model_name: is_loaded}

    async def _get_ollama_client(self, api_url: Optional[str]) -> Any:
        try:
            from ollama import AsyncClient
            if api_url and api_url != "http://localhost:11434":  # non-default host
                return AsyncClient(host=api_url)
            return AsyncClient()  # default client
        except ImportError:
            logger.error("Ollama library not installed. Please `pip install ollama`.")
            raise

    async def _ensure_model_available(self, model_name: str, api_url: Optional[str]):
        if self._loaded_models.get(model_name):
            return

        client = await self._get_ollama_client(api_url)
        logger.info(f"Checking Ollama model: {model_name}...")
        try:
            models = await client.list()
            if not any(m['name'] == model_name for m in models['models']):
                logger.info(f"Ollama model '{model_name}' not found locally. Pulling...")
                await client.pull(model_name)
            self._loaded_models[model_name] = True
            logger.info(f"Ollama model '{model_name}' is available.")
        except Exception as e:
            logger.error(f"Failed to ensure Ollama model '{model_name}' availability: {e}")
            raise

    async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run non-streaming inference on an Ollama model.
        """
        await self._ensure_model_available(model_config.model_name, model_config.api_url)
        client = await self._get_ollama_client(model_config.api_url)

        start_time = time.perf_counter()
        response = await client.generate(
            model=model_config.model_name,
            prompt=prompt,
            # temperature is an Ollama option, not a top-level keyword argument
            options={'num_predict': model_config.max_tokens, 'temperature': temperature}
        )
        end_time = time.perf_counter()

        response_text = response.get('response', '').strip()

        # If the model returns an empty string for a question generation prompt,
        # default to an empty JSON array to prevent parsing errors downstream.
        if not response_text and "Return ONLY a JSON array of questions" in prompt:
            logger.warning("Ollama model returned an empty response for question generation. Defaulting to an empty list.")
            response_text = "[]"

        latency_ms = (end_time - start_time) * 1000

        return {
            "response": response_text,
            "thinking": f"Inferred by {model_config.display_name} (Ollama).",
            "confidence": 0.85,  # placeholder value
            "latency_ms": latency_ms
        }

    async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming inference on an Ollama model.
        """
        await self._ensure_model_available(model_config.model_name, model_config.api_url)
        client = await self._get_ollama_client(model_config.api_url)

        start_time = time.perf_counter()
        generated_text = ""

        yield {"type": "thinking", "content": f"Ensuring Ollama model '{model_config.model_name}' is available..."}

        async for chunk in await client.generate(
            model=model_config.model_name,
            prompt=prompt,
            # temperature is an Ollama option, not a top-level keyword argument
            options={'num_predict': model_config.max_tokens, 'temperature': temperature},
            stream=True
        ):
            if 'response' in chunk:
                generated_text += chunk['response']
                yield {"type": "token", "content": chunk['response']}

        end_time = time.perf_counter()
        latency_ms = (end_time - start_time) * 1000

        yield {"type": "complete", "content": generated_text}
        yield {"type": "meta",
               "confidence": 0.85,  # placeholder value
               "thinking": f"Inferred by {model_config.display_name} (Ollama).",
               "latency_ms": latency_ms
        }

# --- MLX Provider (Apple Silicon) ---
class MLXProvider:
    _loaded_models: Dict[str, Any] = {}

    async def _load_model_and_tokenizer(self, model_name: str) -> Dict[str, Any]:
        """
        Load (and cache) an MLX model and tokenizer.
        """
        if model_name in self._loaded_models:
            return self._loaded_models[model_name]

        try:
            import mlx.core as mx
            from mlx_lm import load, generate

            # MLX picks its device automatically
            logger.info(f"Loading MLX model: {model_name}...")
            model, tokenizer = load(model_name)
            self._loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
            logger.info(f"MLX model '{model_name}' loaded successfully.")
            return self._loaded_models[model_name]
        except ImportError:
            logger.error("MLX library not installed. Please `pip install mlx-lm mlx`.")
            raise
        except Exception as e:
            logger.error(f"Failed to load MLX model '{model_name}': {e}")
            raise

    async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run non-streaming inference on an MLX model.
        """
        from mlx_lm import generate  # imported here so it is in scope (the loader's import is local to it)

        model_data = await self._load_model_and_tokenizer(model_config.model_name)
        model = model_data["model"]
        tokenizer = model_data["tokenizer"]

        start_time = time.perf_counter()
        # mlx-lm's generate is not streaming
        response_text = await asyncio.to_thread(
            lambda: generate(
                model, tokenizer, prompt=prompt, max_tokens=model_config.max_tokens, temp=temperature
            )
        )
        end_time = time.perf_counter()

        latency_ms = (end_time - start_time) * 1000

        return {
            "response": response_text,
            "thinking": f"Inferred by {model_config.display_name} (MLX).",
            "confidence": 0.9,  # placeholder value
            "latency_ms": latency_ms
        }

    async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming inference on an MLX model.
        mlx-lm does not support streaming directly, so the non-streaming result
        is split into chunks to imitate a stream.
        """
        yield {"type": "thinking", "content": f"Loading MLX model: {model_config.display_name}..."}
        response_data = await self.infer(model_config, prompt, temperature)
        response_text = response_data["response"]

        # Imitate streaming by emitting the result word by word
        for word in response_text.split():
            yield {"type": "token", "content": word + " "}
            await asyncio.sleep(0.05)  # dummy streaming delay

        yield {"type": "complete", "content": response_text}
        yield {"type": "meta",
               "confidence": response_data["confidence"],
               "thinking": response_data["thinking"],
               "latency_ms": response_data["latency_ms"]
        }

# --- GGUF Provider (llama-cpp-python) ---
class GGUFProvider:
    _loaded_models: Dict[str, Any] = {}  # {file_path: Llama object}

    async def _load_model(self, model_file_path: str) -> Any:
        """
        Load (and cache) a GGUF model.
        """
        if model_file_path in self._loaded_models:
            return self._loaded_models[model_file_path]

        try:
            from llama_cpp import Llama
            logger.info(f"Loading GGUF model from: {model_file_path}...")
            # n_gpu_layers=-1 offloads all layers to the GPU when one is available
            # n_ctx is the context size
            llm = Llama(model_path=model_file_path, n_gpu_layers=-1, n_ctx=4096)
            self._loaded_models[model_file_path] = llm
            logger.info(f"GGUF model '{model_file_path}' loaded successfully.")
            return llm
        except ImportError:
            logger.error("llama-cpp-python not installed. Please `pip install llama-cpp-python`.")
            raise
        except Exception as e:
            logger.error(f"Failed to load GGUF model '{model_file_path}': {e}")
            raise

    async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run non-streaming inference on a GGUF model.
        """
        if not model_config.model_name or not model_config.model_name.endswith('.gguf'):
            raise ValueError(f"GGUF provider requires 'model_name' to be a valid .gguf file path, but got '{model_config.model_name}'")

        llm = await self._load_model(model_config.model_name)  # model_name is the file path

        # Format prompt with the Phi-4 chat template
        formatted_prompt = f"""<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""

        start_time = time.perf_counter()
        logger.info(f"Calling GGUF model with prompt length: {len(prompt)} (formatted: {len(formatted_prompt)})")
        logger.debug(f"Formatted prompt preview: {formatted_prompt[:300]}...")

        output = llm(
            formatted_prompt,
            max_tokens=model_config.max_tokens,
            temperature=temperature,
            stop=["<|im_end|>", "<|endoftext|>"],  # Phi-4 stop tokens
            echo=False
        )
        end_time = time.perf_counter()

        logger.debug(f"GGUF model output structure: {list(output.keys()) if isinstance(output, dict) else type(output)}")

        response_text = ""
        if "choices" in output and len(output["choices"]) > 0:
            choice = output["choices"][0]
            logger.debug(f"Choice structure: {list(choice.keys()) if isinstance(choice, dict) else type(choice)}")
            if "text" in choice:
                response_text = choice["text"].strip()
                logger.info(f"Extracted response text length: {len(response_text)}")
        else:
            logger.warning(f"No choices in output or choices is empty. Output: {output}")

        # If the model returns an empty string for a question generation prompt,
        # default to an empty JSON array to prevent parsing errors downstream.
        if not response_text and "Return ONLY a JSON array of questions" in prompt:
            logger.warning("GGUF model returned an empty response for question generation. Defaulting to an empty list.")
            response_text = "[]"
        elif not response_text:
            logger.error(f"GGUF model returned empty response. Prompt preview: {prompt[:200]}")

        latency_ms = (end_time - start_time) * 1000

        return {
            "response": response_text,
            "thinking": f"Inferred by {model_config.display_name} (GGUF via llama-cpp-python).",
            "confidence": 0.92,  # placeholder value
            "latency_ms": latency_ms
        }

    async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming inference on a GGUF model.
        """
        if not model_config.model_name or not model_config.model_name.endswith('.gguf'):
            raise ValueError(f"GGUF provider requires 'model_name' to be a valid .gguf file path, but got '{model_config.model_name}'")

        llm = await self._load_model(model_config.model_name)

        # Format prompt with the Phi-4 chat template
        formatted_prompt = f"""<|im_start|>system
You are a helpful AI assistant.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""

        start_time = time.perf_counter()
        generated_text = ""

        yield {"type": "thinking", "content": f"Loading GGUF model: {model_config.display_name}..."}

        # llama_cpp's create_completion supports streaming
        for chunk in llm(
            formatted_prompt,
            max_tokens=model_config.max_tokens,
            temperature=temperature,
            stop=["<|im_end|>", "<|endoftext|>"],  # Phi-4 stop tokens
            echo=False,
            stream=True
        ):
            token = chunk["choices"][0]["text"]
            if token:
                generated_text += token
                yield {"type": "token", "content": token}

        end_time = time.perf_counter()
        latency_ms = (end_time - start_time) * 1000

        yield {"type": "complete", "content": generated_text}
        yield {"type": "meta",
               "confidence": 0.92,  # placeholder value
               "thinking": f"Inferred by {model_config.display_name} (GGUF via llama-cpp-python).",
               "latency_ms": latency_ms
        }

# TODO: implement OllamaProvider, MLXProvider, GGUFProvider
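
A non-streaming call looks the same for every provider. A minimal sketch, assuming the ModelConfig fields seen elsewhere in this commit (model_id, display_name, provider, model_name, max_tokens, temperature, timeout) can be passed to the constructor; the model name and field values are illustrative:

import asyncio
from backend.app.config import ModelConfig

async def main():
    provider = HuggingFaceProvider()
    cfg = ModelConfig(                       # hypothetical values
        model_id="demo-hf",
        display_name="Demo HF model",
        provider="huggingface",
        model_name="gpt2",                   # any small causal LM works for a smoke test
        max_tokens=64,
        temperature=0.7,
        timeout=120,
    )
    result = await provider.infer(cfg, "Hello, world!", temperature=0.7)
    print(result["response"], f"({result['latency_ms']:.0f} ms)")

asyncio.run(main())
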
model_router.py
ADDED
@@ -0,0 +1,828 @@
import logging
from typing import Optional, Dict, Any, AsyncGenerator, List
import asyncio
import time
import threading  # For TextIteratorStreamer
import uuid  # For generating unique IDs
import os

from backend.app.config import ConfigManager, ModelConfig, ModelProvider
# SessionLocal is imported lazily to avoid a circular import
from backend.app.services.knowledge_service import KnowledgeService, get_knowledge_service
from null_ai.llm_providers import HuggingFaceProvider, OllamaProvider, MLXProvider, GGUFProvider  # Import all providers
from null_ai.iath_memory import DendriticMemorySpace  # dendritic spatial memory
from null_ai.coordinate_estimator import CoordinateEstimator  # automatic coordinate estimation
from null_ai.iath_writer import IathWriter, create_tile_from_ai_output  # .iath persistence

logger = logging.getLogger(__name__)

class ModelRouter:
    """
    Manages model routing and switching between inference engines.
    Holds the state of the master model and the apprentice model.
    """
    _instance: Optional['ModelRouter'] = None
    _initialized = False

    def __new__(cls, config_manager: ConfigManager):
        if cls._instance is None:
            cls._instance = super(ModelRouter, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_manager: ConfigManager):
        if self._initialized:
            return

        self.config_manager = config_manager
        self.knowledge_service = get_knowledge_service()  # Instantiate KnowledgeService internally
        self.master_model: Optional[ModelConfig] = None
        self.apprentice_model: Optional[ModelConfig] = None
        self.active_domain_id: Optional[str] = None  # Active domain managed by ConfigManager
        self.managed_engines: Dict[str, Dict[str, Any]] = {}  # Tracks all engines and their status

        # Initialize the dendritic spatial memory (.iath file)
        self.dendritic_memory: Optional[DendriticMemorySpace] = None
        self._load_dendritic_memory()

        # Initialize the automatic coordinate estimator
        self.coordinate_estimator = CoordinateEstimator()

        # Initialize the .iath writer
        iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
        self.iath_writer = IathWriter(iath_file_path)

        # Initialize the LLM providers
        self.providers: Dict[ModelProvider, Any] = {
            ModelProvider.HUGGINGFACE: HuggingFaceProvider(),
            ModelProvider.OLLAMA: OllamaProvider(),
            ModelProvider.MLX: MLXProvider(),
            ModelProvider.GGUF: GGUFProvider(),
        }

        # Set the default models on initial load
        self._load_all_engines()     # Load all configured models into managed_engines
        self._load_default_models()  # This will now also try to load active master/apprentice from config

        self._initialized = True
        logger.info("ModelRouter initialized.")

    def _load_dendritic_memory(self):
        """Load the dendritic spatial memory (.iath file)."""
        try:
            # TODO: use a different .iath file per domain
            # For now the default path is used
            iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")

            if os.path.exists(iath_file_path):
                self.dendritic_memory = DendriticMemorySpace(iath_file_path)
                stats = self.dendritic_memory.get_statistics()
                logger.info(f"Dendritic memory loaded: {stats['total_tiles']} tiles from {iath_file_path}")
            else:
                logger.warning(f".iath file not found: {iath_file_path}. Starting with empty memory.")
                self.dendritic_memory = None
        except Exception as e:
            logger.error(f"Failed to load dendritic memory: {e}")
            self.dendritic_memory = None

    def _get_any_available_model(self) -> Optional[ModelConfig]:
        """
        Return one available model (used for coordinate estimation and similar tasks).

        Priority:
        1. The master model
        2. The first configured model

        Returns:
            ModelConfig or None
        """
        if self.master_model:
            return self.master_model

        # Fall back to the first model in the configuration
        for model_id, model_config in self.config_manager.models.items():
            return model_config

        return None

    def _load_all_engines(self):
        """Load every model from the configuration file and initialize managed_engines."""
        self.managed_engines = {}
        for model_id, model_config in self.config_manager.models.items():
            self.managed_engines[model_id] = {
                "config": model_config.dict(),
                "status": "available",  # Default status
                "unique_id": None       # Only for apprentices
            }
        logger.info(f"Loaded {len(self.managed_engines)} engines into management.")

    def _update_engine_status(self, model_id: str, status: str, unique_id: Optional[str] = None):
        """Helper that updates an engine's status inside managed_engines."""
        if model_id in self.managed_engines:
            self.managed_engines[model_id]["status"] = status
            self.managed_engines[model_id]["unique_id"] = unique_id
        else:
            logger.warning(f"Attempted to update status for unknown engine: {model_id}")

    def _load_default_models(self):
        """Load the default master and apprentice models from the configuration."""
        self.active_domain_id = self.config_manager.get_active_domain_id()

        # Load the persisted active master/apprentice model IDs from ConfigManager
        persisted_master_id = self.config_manager.get_null_ai_setting("active_master_id")
        persisted_apprentice_id = self.config_manager.get_null_ai_setting("active_apprentice_id")

        # Configure the master model
        if persisted_master_id:
            master_config = self.config_manager.get_model_config(persisted_master_id)
            if master_config:
                self.master_model = master_config
                self._update_engine_status(self.master_model.model_id, "master")
                logger.info(f"Persisted master model loaded: {self.master_model.display_name}")
            else:
                logger.warning(f"Persisted master model '{persisted_master_id}' not found in configuration. Attempting to set default.")
                self._set_initial_master_from_config()  # fall back to the default when the persisted model is missing
        else:
            self._set_initial_master_from_config()  # no persisted master ID, so set the default

        # Configure the apprentice model
        if persisted_apprentice_id:
            apprentice_config = self.config_manager.get_model_config(persisted_apprentice_id)
            if apprentice_config:
                self.apprentice_model = apprentice_config
                # Load the unique_id from config as well
                apprentice_unique_id = self.config_manager.get_null_ai_setting(f"apprentice_unique_id_{persisted_apprentice_id}")
                self._update_engine_status(self.apprentice_model.model_id, "apprentice", apprentice_unique_id)
                logger.info(f"Persisted apprentice model loaded: {self.apprentice_model.display_name}")
            else:
                logger.warning(f"Persisted apprentice model '{persisted_apprentice_id}' not found in configuration. Clearing active apprentice.")
                self.apprentice_model = None
                self.config_manager.set_null_ai_setting("active_apprentice_id", None)
        else:
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)

    def _set_initial_master_from_config(self):
        """Load the default master model from the configuration (when no persisted one exists or it is missing)."""
        default_master_config = self.config_manager.get_default_model_config(domain_id=self.active_domain_id)
        if default_master_config:
            self.master_model = default_master_config
            self._update_engine_status(self.master_model.model_id, "master")
            self.config_manager.set_null_ai_setting("active_master_id", self.master_model.model_id)
            logger.info(f"Default master model loaded for domain '{self.active_domain_id}': {self.master_model.display_name}")
        else:
            logger.warning(f"No default master model found for domain '{self.active_domain_id}' in configuration. Master model remains unset.")
            self.master_model = None
            self.config_manager.set_null_ai_setting("active_master_id", None)

    def set_active_domain_id(self, domain_id: str):
        """Set the active domain ID and reload the models accordingly."""
        if self.active_domain_id != domain_id:
            self.active_domain_id = domain_id
            self._load_default_models()  # re-select master/apprentice when the active domain changes
            logger.info(f"ModelRouter active domain set to {domain_id} and models reloaded.")

    def set_master_model(self, model_id: str) -> bool:
        """Set the master model."""
        model = self.config_manager.get_model_config(model_id)
        if model:
            # Retire the old master (unless it is the apprentice being promoted right now)
            if self.master_model and self.master_model.model_id != model_id:
                self._update_engine_status(self.master_model.model_id, "retired")
                logger.info(f"Old master model '{self.master_model.display_name}' set to 'retired'.")

            self.master_model = model
            self._update_engine_status(model_id, "master")
            self.config_manager.set_null_ai_setting("active_master_id", model_id)
            logger.info(f"Master model set to: {model.display_name}")
            return True
        logger.error(f"Model with ID '{model_id}' not found for master setting.")
        return False

    def set_apprentice_model(self, model_id: Optional[str]) -> bool:
        """Set the apprentice model (None clears it)."""
        # Return the old apprentice to 'available'
        if self.apprentice_model:
            self._update_engine_status(self.apprentice_model.model_id, "available")

        if model_id is None or model_id == 'none':
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)
            logger.info("Apprentice model cleared.")
            return True

        model = self.config_manager.get_model_config(model_id)
        if model:
            self.apprentice_model = model
            # For apprentices, we need to ensure unique_id is tracked if this is a named apprentice
            apprentice_unique_id = self.managed_engines[model_id].get("unique_id")  # Get existing unique_id
            self._update_engine_status(model_id, "apprentice", apprentice_unique_id)
            self.config_manager.set_null_ai_setting("active_apprentice_id", model_id)
            logger.info(f"Apprentice model set to: {model.display_name}")
            return True
        logger.error(f"Model with ID '{model_id}' not found for apprentice setting.")
        return False

    def get_all_managed_engines(self) -> List[Dict[str, Any]]:
        """Return a list of all managed engines with their status and unique ID."""
        return list(self.managed_engines.values())

    def get_master_model(self) -> Optional[ModelConfig]:
        """Return the current master model."""
        return self.master_model

    def get_apprentice_model(self) -> Optional[ModelConfig]:
        """Return the current apprentice model."""
        return self.apprentice_model

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the model configuration for the given model_id."""
        return self.config_manager.models.get(model_id)

    def swap_engines(self, apprentice_model_id: str) -> bool:
        """Swap the master with the specified apprentice."""
        if not self.master_model:
            logger.error("Cannot swap: No master model is currently set.")
            return False

        apprentice_candidate = self.config_manager.get_model_config(apprentice_model_id)
        if not apprentice_candidate:
            logger.error(f"Cannot swap: Apprentice model '{apprentice_model_id}' not found.")
            return False

        # Capture the current master and apprentice
        old_master_model = self.master_model
        old_apprentice_model = self.apprentice_model

        # The new master is the specified apprentice
        new_master_model = apprentice_candidate

        # The new apprentice is the old master (unless the old master is the candidate itself)
        new_apprentice_model = old_master_model if old_master_model.model_id != apprentice_model_id else None

        # Swap the roles
        self.master_model = new_master_model
        self.apprentice_model = new_apprentice_model

        # Update status and persistence
        self._update_engine_status(new_master_model.model_id, "master", self.managed_engines[new_master_model.model_id].get("unique_id"))
        self.config_manager.set_null_ai_setting("active_master_id", new_master_model.model_id)

        if new_apprentice_model:
            self._update_engine_status(new_apprentice_model.model_id, "apprentice", self.managed_engines[new_apprentice_model.model_id].get("unique_id"))
            self.config_manager.set_null_ai_setting("active_apprentice_id", new_apprentice_model.model_id)
        else:
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)

        # Return the previous apprentice to 'available' if it is not the one just promoted to master
        if old_apprentice_model and old_apprentice_model.model_id != new_master_model.model_id:
            self._update_engine_status(old_apprentice_model.model_id, "available")

        logger.info(f"Engines swapped: New Master is {new_master_model.display_name}, New Apprentice is {new_apprentice_model.display_name if new_apprentice_model else 'None'}")
        return True

    def promote_apprentice(self, apprentice_model_id: str) -> bool:
        """Promote the specified apprentice to master and retire the current master."""
        apprentice_to_promote = self.config_manager.get_model_config(apprentice_model_id)
        if not apprentice_to_promote:
            logger.error(f"Cannot promote: Apprentice model '{apprentice_model_id}' not found.")
            return False

        # Retire the current master
        if self.master_model:
            self._update_engine_status(self.master_model.model_id, "retired")
            logger.info(f"Current master model '{self.master_model.display_name}' retired.")

        # Promote the apprentice to master
        self.master_model = apprentice_to_promote
        self._update_engine_status(apprentice_to_promote.model_id, "master", self.managed_engines[apprentice_to_promote.model_id].get("unique_id"))
        self.config_manager.set_null_ai_setting("active_master_id", apprentice_to_promote.model_id)

        # Clear the apprentice role (it has been promoted)
        if self.apprentice_model and self.apprentice_model.model_id == apprentice_model_id:
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)
            logger.info(f"Apprentice model '{apprentice_to_promote.display_name}' promoted to master and apprentice role cleared.")
        else:
            logger.warning(f"Apprentice '{apprentice_model_id}' was promoted, but was not the currently active apprentice.")

        logger.info(f"Apprentice '{apprentice_to_promote.display_name}' promoted to Master.")
        return True

    def create_new_apprentice(self) -> Optional[Dict[str, Any]]:
        """Create and register a new "empty apprentice" inference engine."""
        new_apprentice_unique_id = str(uuid.uuid4())
        new_apprentice_model_id = f"apprentice-{new_apprentice_unique_id[:8]}"
        new_apprentice_display_name = f"Apprentice ({new_apprentice_unique_id[:4]})"

        # The user specified "literally something with no training data".
        # Here we either assume a generic base model or name a specific provider/model.
        # Tentatively, a lightweight Ollama model is used as the template.
        # Alternatively, only the model configuration is created here and the actual
        # model file is loaded later, at training time.

        # NOTE: What this "empty" model concretely refers to has to be pinned down
        # by the fine-tuning logic that follows.
        # Here it is simply added as an entry in the configuration.

        # TODO: make the base "empty" model configurable via null_ai_config.json or similar
        base_apprentice_provider = ModelProvider.OLLAMA  # provisional default provider
        base_apprentice_model_name = "mistral:latest"    # provisional default model name

        # Configuration for the new apprentice
        new_model_config_data = {
            "model_id": new_apprentice_model_id,
            "display_name": new_apprentice_display_name,
            "provider": base_apprentice_provider.value,
            "model_name": base_apprentice_model_name,
            "max_tokens": 4096,
            "temperature": 0.7,
            "timeout": 120,
            "is_default": False,
            "supported_domains": ["general"],  # New apprentices start with general knowledge
            "description": f"Newly generated empty apprentice model with unique ID {new_apprentice_unique_id}."
        }

        try:
            new_model = self.config_manager.add_model(new_model_config_data)
            if not new_model:
                raise Exception("Failed to add new model via config_manager.")

            self.managed_engines[new_model.model_id] = {
                "config": new_model.dict(),
                "status": "available",
                "unique_id": new_apprentice_unique_id
            }
            # Persist the unique_id as well
            self.config_manager.set_null_ai_setting(f"apprentice_unique_id_{new_model.model_id}", new_apprentice_unique_id)

            logger.info(f"New empty apprentice '{new_apprentice_display_name}' created with ID: {new_apprentice_model_id}")
            return self.managed_engines[new_model.model_id]
        except Exception as e:
            logger.error(f"Failed to create new apprentice: {e}")
            return None

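    # Illustrative master/apprentice lifecycle using the methods above — a
    # sketch, assuming `router` is an initialized ModelRouter:
    #
    #   engine = router.create_new_apprentice()                  # registers an empty apprentice
    #   router.set_apprentice_model(engine["config"]["model_id"])
    #   # ... train the apprentice ...
    #   router.promote_apprentice(engine["config"]["model_id"])  # old master -> "retired"
    #   # or rotate roles instead of retiring:
    #   # router.swap_engines(engine["config"]["model_id"])      # old master -> apprentice
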
    def get_active_model(self, for_inference: bool = True) -> Optional[ModelConfig]:
        """
        Return the active model to use for inference.
        When for_inference is True, the master model is returned.
        When the apprentice model is needed, e.g. for "training" in the
        fallen-tree system, call get_apprentice_model directly.
        """
        if for_inference:
            return self.master_model
        # Automatic switchover to a "grown" apprentice model may be added later
        return self.master_model  # default to the master model

    async def infer(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> Dict[str, Any]:
        """
        Run inference with the specified model.
        RAG: use DB knowledge when available; otherwise infer from the AI's
        internal knowledge and accumulate the result in the DB.
        Behavior switches on rag_mode.
        """
        # Handling for rag_mode == 'direct'
        if rag_mode == "direct":
            relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
            if relevant_knowledge:
                # Answer directly from the DB
                best_match = relevant_knowledge[0]
                logger.info(f"Direct mode: DB knowledge found. Returning direct answer from tile {best_match['id']}.")
                return {
                    "response": best_match['content'],
                    "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.",
                    "confidence": best_match['confidence_score'],
                    "model_used": "database_direct",
                    "latency_ms": 50,  # a DB lookup, so fast
                    "source_type": "db_direct"
                }
            else:
                # Nothing in the DB, so answer "not found"
                logger.info("Direct mode: No DB knowledge found. Returning 'Not found'.")
                return {
                    "response": "The requested information was not found in the knowledge base.",
                    "thinking": "Directly searched knowledge base, but no relevant information was found.",
                    "confidence": 0.9,  # a "not found" answer is given high confidence
                    "model_used": "database_direct",
                    "latency_ms": 50,
                    "source_type": "db_direct"
                }

        # Handling for rag_mode == 'rag' (the existing logic)
        has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
        source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
        thinking_process = ""
        augmented_prompt = prompt

        if has_db_knowledge:
            # Retrieve relevant knowledge from the DB
            relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
            logger.info(f"RAG mode: DB knowledge found for domain '{domain_id}'. Retrieved {len(relevant_knowledge)} relevant tiles.")

            # RAG: fold the knowledge into the prompt
            knowledge_context = "\n\n".join([
                f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
                f"Topic: {tile['topic']}\n"
                f"Content: {tile['content']}"
                for i, tile in enumerate(relevant_knowledge)
            ])

            augmented_prompt = f"""Based on the following verified knowledge from the database:

{knowledge_context}

Now, please answer the following question accurately:
{prompt}"""

            response = await self._perform_llm_inference(model_config, augmented_prompt, temperature or model_config.temperature)
            thinking_process = f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles. " + response.get("thinking", "")
        else:
            # No DB knowledge, so infer from the AI's internal knowledge
            logger.info(f"RAG mode: No DB knowledge found for domain '{domain_id}'. Using AI internal weights.")
            response = await self._perform_llm_inference(model_config, prompt, temperature or model_config.temperature)
            thinking_process = "No specific DB knowledge found. Inferred from AI internal weights. " + response.get("thinking", "")

        # Self-enrichment: persist the inference result to the DB
        if save_to_memory and response.get("confidence", 0) >= 0.7:
            await self._save_inference_to_db(
                domain_id=domain_id,
                prompt=prompt,
                response=response.get("response", ""),
                confidence=response.get("confidence", 0.7),
                source_type="ai",
                model_id=model_config.model_id
            )

        # Fallen-tree system: save the master's output as training data
        is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
        logger.info(f"[Training Data Check] master_model={self.master_model.model_id if self.master_model else None}, "
                    f"current_model={model_config.model_id}, is_master={is_master}, "
                    f"confidence={response.get('confidence', 0)}, save_to_memory={save_to_memory}")
        if is_master and response.get("confidence", 0) >= 0.8:
            logger.info(f"[Training Data] Saving master output for domain '{domain_id}'...")
            await self._save_master_output_as_training_data(
                prompt=prompt,
                master_response=response.get("response", ""),
                domain_id=domain_id,
                confidence=response.get("confidence", 0.8)
            )
        else:
            logger.info(f"[Training Data] NOT saving: is_master={is_master}, confidence={response.get('confidence', 0)}")

        return {
            "response": response.get("response", ""),
            "thinking": thinking_process,
            "confidence": response.get("confidence", 0.7),
            "model_used": model_config.model_id,
            "latency_ms": response.get("latency_ms", 0),
            "source_type": source_type
        }

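    # Illustrative use of the two retrieval modes — a sketch, inside an async
    # context; `router` and `cfg` are assumed to exist:
    #
    #   direct = await router.infer(prompt, domain_id="general",
    #                               model_config=cfg, rag_mode="direct")
    #   # answers straight from the knowledge base, or reports "not found"; no LLM call
    #
    #   rag = await router.infer(prompt, domain_id="general",
    #                            model_config=cfg, rag_mode="rag", save_to_memory=True)
    #   # augments the prompt with up to 3 knowledge tiles when available,
    #   # otherwise falls back to internal weights; answers with confidence
    #   # >= 0.7 may be written back to the DB
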
async def infer_streaming(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> AsyncGenerator[Dict[str, Any], None]:
|
| 480 |
+
print(f"DEBUG: infer_streaming called - model={model_config.model_id if model_config else 'None'}")
|
| 481 |
+
logger.error(f"[CRITICAL DEBUG] infer_streaming ENTRY POINT - model={model_config.model_id if model_config else 'None'}, save_to_memory={save_to_memory}")
|
| 482 |
+
"""
|
| 483 |
+
指定されたモデルで推論をストリーミングで実行する。
|
| 484 |
+
RAG: DBに知識があればそれを使用、なければAI内部知識で推論してDBに蓄積。
|
| 485 |
+
rag_modeに応じて動作を切り替える。
|
| 486 |
+
"""
|
| 487 |
+
logger.info(f"[INFER_STREAMING] Called with model={model_config.model_id}, domain={domain_id}, save_to_memory={save_to_memory}")
|
| 488 |
+
|
| 489 |
+
# 師匠の応答を蓄積するための変数(finallyブロックで使用)
|
| 490 |
+
generated_response = ""
|
| 491 |
+
|
| 492 |
+
try:
|
| 493 |
+
# rag_mode == 'direct' の場合の処理
|
| 494 |
+
if rag_mode == "direct":
|
| 495 |
+
relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
|
| 496 |
+
if relevant_knowledge:
|
| 497 |
+
best_match = relevant_knowledge[0]
|
| 498 |
+
logger.info(f"Direct mode (streaming): DB knowledge found. Yielding direct answer from tile {best_match['id']}.")
|
| 499 |
+
yield {"type": "token", "content": best_match['content']}
|
| 500 |
+
yield {"type": "meta", "source_type": "db_direct", "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.", "confidence": best_match['confidence_score'], "model_used": "database_direct"}
|
| 501 |
+
yield {"type": "complete", "content": best_match['content']}
|
| 502 |
+
return
|
| 503 |
+
else:
|
| 504 |
+
logger.info("Direct mode (streaming): No DB knowledge found. Yielding 'Not found'.")
|
| 505 |
+
response_text = "ご指定の情報はナレッジベース内に見つかりませんでした。"
|
| 506 |
+
yield {"type": "token", "content": response_text}
|
| 507 |
+
yield {"type": "meta", "source_type": "db_direct", "thinking": "Directly searched knowledge base, but no relevant information was found.", "confidence": 0.9, "model_used": "database_direct"}
|
| 508 |
+
yield {"type": "complete", "content": response_text}
|
| 509 |
+
return
|
| 510 |
+
|
| 511 |
+
            # Handling for rag_mode == 'rag' (existing logic)
            has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
            source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
            augmented_prompt = prompt

            yield {"type": "thinking", "content": f"Checking DB knowledge for domain '{domain_id}'..."}
            await asyncio.sleep(0.1)

            if has_db_knowledge:
                # Fetch related knowledge from the DB
                relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
                yield {"type": "thinking", "content": f"Relevant DB knowledge found. Retrieved {len(relevant_knowledge)} tiles. Augmenting prompt..."}

                # RAG: fold the knowledge into the prompt
                knowledge_context = "\n\n".join([
                    f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
                    f"Topic: {tile['topic']}\n"
                    f"Content: {tile['content']}"
                    for i, tile in enumerate(relevant_knowledge)
                ])

                augmented_prompt = f"""Based on the following verified knowledge from the database:

{knowledge_context}

Now, please answer the following question accurately:
{prompt}"""

                # Run streaming inference
                async for chunk in self._perform_llm_streaming_inference(model_config, augmented_prompt, temperature or model_config.temperature):
                    if chunk.get("type") == "token":
                        generated_response += chunk.get("content", "")
                    yield chunk

                yield {"type": "meta", "source_type": source_type, "thinking": f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles."}

                # Fallen-tree system: save the master's output as training data
                # (also when DB knowledge was used)
                logger.error(f"[CRITICAL DEBUG] Reached training data save section (with DB knowledge)! generated_response length={len(generated_response)}")
                is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
                logger.info(f"[Training Data Check - DB Augmented] master_model={self.master_model.model_id if self.master_model else None}, "
                            f"current_model={model_config.model_id}, is_master={is_master}, "
                            f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
                if is_master and len(generated_response) > 0:
                    logger.info(f"[Training Data - DB Augmented] Saving master output for domain '{domain_id}'...")
                    await self._save_master_output_as_training_data(
                        prompt=prompt,
                        master_response=generated_response,
                        domain_id=domain_id,
                        confidence=0.8
                    )
                else:
                    logger.info(f"[Training Data - DB Augmented] NOT saving: is_master={is_master}, response_length={len(generated_response)}")

            else:
                yield {"type": "thinking", "content": "No specific DB knowledge found. Using AI internal weights."}

                # Run streaming inference
                async for chunk in self._perform_llm_streaming_inference(model_config, prompt, temperature or model_config.temperature):
                    if chunk.get("type") == "token":
                        generated_response += chunk.get("content", "")
                    yield chunk

                yield {"type": "meta", "source_type": source_type, "thinking": "Inferred from AI internal weights."}

                # Fallen-tree system: save the master's output as training data
                logger.error(f"[CRITICAL DEBUG] Reached training data save section! generated_response length={len(generated_response)}")
                is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
                logger.info(f"[Training Data Check - Streaming] master_model={self.master_model.model_id if self.master_model else None}, "
                            f"current_model={model_config.model_id}, is_master={is_master}, "
                            f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
                if is_master and len(generated_response) > 0:
                    logger.info(f"[Training Data - Streaming] Saving master output for domain '{domain_id}'...")
                    await self._save_master_output_as_training_data(
                        prompt=prompt,
                        master_response=generated_response,
                        domain_id=domain_id,
                        confidence=0.8
                    )
                else:
                    logger.info(f"[Training Data - Streaming] NOT saving: is_master={is_master}, response_length={len(generated_response)}")

            # Self-enrichment: persist the inference result to the DB
            # (when confidence is sufficient)
            if save_to_memory and len(generated_response) > 0:
                await self._save_inference_to_db(
                    domain_id=domain_id,
                    prompt=prompt,
                    response=generated_response,
                    confidence=0.75,  # fixed value, since streaming yields no confidence
                    source_type="ai",
                    model_id=model_config.model_id
                )

        except Exception as e:
            logger.error(f"Error during streaming inference: {e}")
            yield {"type": "error", "content": str(e)}

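    # --- Illustrative consumer sketch (not part of the original file) ---
    # infer_streaming yields dicts tagged "thinking", "token", "meta",
    # "complete", and "error". A hedged sketch of a consumer, assuming a
    # router and master model obtained as in the earlier sketch:
    #
    #   async def consume(router, master):
    #       text = ""
    #       async for event in router.infer_streaming(
    #           prompt="Explain the fallen-tree system.",
    #           domain_id="general",
    #           model_config=master,
    #       ):
    #           kind = event.get("type")
    #           if kind == "token":
    #               text += event.get("content", "")   # accumulate the answer
    #           elif kind == "meta":
    #               print("source:", event.get("source_type"))
    #           elif kind == "error":
    #               raise RuntimeError(event["content"])
    #       return text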
    def _check_db_knowledge(self, domain_id: str, prompt: str) -> bool:
        """
        Check whether knowledge related to the given domain and prompt exists.
        Uses the dendritic spatial memory (.iath).
        """
        try:
            if self.dendritic_memory is None:
                return False

            # Check for a match via text search
            results = self.dendritic_memory.search_by_text(prompt[:200], top_k=1)
            return len(results) > 0

        except Exception as e:
            logger.error(f"Error checking dendritic memory: {e}")
            return False

    def _retrieve_relevant_knowledge(self, domain_id: str, prompt: str, top_k: int = 3) -> list:
        """
        Retrieve knowledge related to the given domain and prompt.
        Uses the hybrid search of the dendritic spatial memory (.iath).
        """
        try:
            if self.dendritic_memory is None:
                return []

            # Hybrid search (text + spatial coordinates)
            # TODO: estimate spatial coordinates from the prompt as well
            results = self.dendritic_memory.hybrid_search(prompt, query_coords=None, top_k=top_k)

            # Convert to the unified format
            return [
                {
                    "id": tile["metadata"]["knowledge_id"],
                    "topic": tile["metadata"]["topic"],
                    "content": tile["content"]["final_response"],
                    "thinking_process": tile["content"]["thinking_process"],
                    "confidence_score": tile["verification"]["initial_certainty"],
                    "verification_type": tile["verification"]["status"],
                    "coordinates": tile["coordinates"],
                    "text_match_score": tile.get("text_match_score", 0),
                    "spatial_distance": tile.get("spatial_distance", None)
                }
                for tile in results
            ]

        except Exception as e:
            logger.error(f"Error retrieving relevant knowledge from dendritic memory: {e}")
            return []

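    # --- Hypothetical extension sketch (not part of the original file) ---
    # The TODO above leaves query_coords as None. One hedged way to fill it,
    # reusing the coordinate estimator already attached to the router; the
    # estimate_coordinates signature is taken from _save_inference_to_db below,
    # everything else is an assumption (and would require making this method async):
    #
    #   coord_result = await self.coordinate_estimator.estimate_coordinates(
    #       prompt=prompt,
    #       response="",                        # no response yet at query time
    #       domain_id=domain_id,
    #       llm_inference_func=some_llm_func,   # hypothetical helper
    #       use_reasoning=False,
    #   )
    #   results = self.dendritic_memory.hybrid_search(
    #       prompt, query_coords=coord_result["coordinates"], top_k=top_k)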
    async def _perform_llm_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run LLM inference with the given model configuration.
        Delegates to the matching provider implementation.
        """
        provider_type = model_config.provider

        if provider_type not in self.providers:
            raise ValueError(f"Unsupported provider: {provider_type}")

        provider = self.providers[provider_type]

        try:
            result = await provider.infer(model_config, prompt, temperature)
            return result
        except Exception as e:
            logger.error(f"Error during LLM inference with provider '{provider_type}': {e}")
            raise

    async def _perform_llm_streaming_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming LLM inference with the given model configuration.
        Delegates to the matching provider implementation.
        """
        provider_type = model_config.provider

        if provider_type not in self.providers:
            raise ValueError(f"Unsupported provider: {provider_type}")

        provider = self.providers[provider_type]

        try:
            async for chunk in provider.infer_streaming(model_config, prompt, temperature):
                yield chunk
        except Exception as e:
            logger.error(f"Error during LLM streaming inference with provider '{provider_type}': {e}")
            yield {"type": "error", "content": str(e)}

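    # --- Assumed provider interface (sketch, not part of the original file) ---
    # The concrete providers live in null_ai/llm_providers.py, which is not
    # shown here; this Protocol only restates the call surface the two methods
    # above rely on, and is an assumption rather than the library's actual API:
    #
    #   from typing import Protocol
    #
    #   class LLMProvider(Protocol):
    #       async def infer(self, model_config: ModelConfig, prompt: str,
    #                       temperature: float) -> Dict[str, Any]: ...
    #       def infer_streaming(self, model_config: ModelConfig, prompt: str,
    #                           temperature: float) -> AsyncGenerator[Dict[str, Any], None]: ...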
    async def _save_inference_to_db(self, domain_id: str, prompt: str, response: str, confidence: float, source_type: str, model_id: str) -> bool:
        """
        Save an inference result to the DB as a Knowledge Tile (self-enrichment).

        Priority 2 implementation: automatic coordinate estimation + .iath persistence
        """
        try:
            from backend.app.database.models import KnowledgeTile
            from backend.app.database.session import SessionLocal  # deferred import (avoids a circular import)

            db = SessionLocal()
            tile_id = f"ai_ktile_{uuid.uuid4().hex}"

            # Create the AI-generated knowledge tile (SQLite)
            new_tile = KnowledgeTile(
                id=tile_id,
                workspace_id="default_workspace",  # local build, so use the default workspace
                domain_id=domain_id,
                topic=prompt[:200],  # use the prompt as the topic
                content=response,
                confidence_score=confidence,
                verification_type="ai",  # marks the tile as AI-generated
                verification_count=1,
                contributor_id=None,  # no contributor: the AI produced it
                last_verified_by_id=None
            )

            try:
                db.add(new_tile)
                db.commit()
            finally:
                db.close()  # close the session even if the commit fails

            logger.info(f"Saved AI inference to SQLite: {new_tile.id}")

            # Priority 2: automatic coordinate estimation + .iath persistence
            try:
                # Inference function used only for coordinate estimation
                async def llm_inference_for_coords(coord_prompt):
                    # Estimate coordinates with the master model (or DeepSeek)
                    estimation_model = self.master_model if self.master_model else self._get_any_available_model()

                    if estimation_model:
                        result = await self._perform_llm_inference(
                            estimation_model,
                            coord_prompt,
                            temperature=0.3  # low temperature for consistency
                        )
                        return result.get("response", "")
                    return "{\"coordinates\": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], \"confidence\": 0.3}"

                # Estimate the coordinates
                coord_result = await self.coordinate_estimator.estimate_coordinates(
                    prompt=prompt,
                    response=response,
                    domain_id=domain_id,
                    llm_inference_func=llm_inference_for_coords,
                    use_reasoning=False  # skip the reasoning trace for speed
                )

                # Build the .iath Tile object
                iath_tile = create_tile_from_ai_output(
                    knowledge_id=tile_id,
                    topic=prompt[:100],
                    prompt=prompt,
                    response=response,
                    coordinates=coord_result["coordinates"],
                    confidence=confidence,
                    domain_id=domain_id,
                    source="ai_generated"
                )

                # Append to the .iath file
                success = self.iath_writer.append_tile(iath_tile)

                if success:
                    logger.info(f"Saved AI inference to .iath with coordinates: {coord_result['coordinates']}")

                    # Reload the memory so the new tile becomes searchable
                    self._load_dendritic_memory()
                else:
                    logger.warning("Failed to save to .iath, but SQLite save succeeded")

            except Exception as iath_error:
                logger.error(f"Error saving to .iath (SQLite save succeeded): {iath_error}")
                # Still return True: the SQLite save succeeded even though .iath failed

            return True

        except Exception as e:
            logger.error(f"Error saving inference to DB: {e}")
            return False

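    # --- Illustrative parsing sketch (not part of the original file) ---
    # The fallback string above suggests the coordinate estimator expects the
    # LLM to answer with JSON of the form
    #   {"coordinates": [x1, ..., x6], "confidence": c}
    # A hedged parser for that contract (the strictness is an assumption):
    #
    #   import json
    #
    #   def parse_coord_reply(raw: str) -> dict:
    #       data = json.loads(raw)
    #       coords = [float(v) for v in data["coordinates"]]
    #       if len(coords) != 6:
    #           raise ValueError("expected 6 coordinates")
    #       return {"coordinates": coords,
    #               "confidence": float(data.get("confidence", 0.0))}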
    async def _save_master_output_as_training_data(self, prompt: str, master_response: str, domain_id: str, confidence: float) -> bool:
        """
        Save the master's output as supervised training data for fine-tuning
        the apprentice. The core mechanism of the fallen-tree system.
        """
        try:
            import json
            import os
            from datetime import datetime

            logger.info(f"[Save Training Data] Starting to save master output (domain={domain_id}, confidence={confidence})")

            # Directory that holds the training data
            training_data_dir = "training_data/master_outputs"
            logger.info(f"[Save Training Data] Creating directory: {training_data_dir}")
            os.makedirs(training_data_dir, exist_ok=True)

            # Store in Alpaca format
            training_example = {
                "instruction": f"You are an expert in {domain_id}. Provide accurate information based on verified knowledge.",
                "input": prompt,
                "output": master_response,
                "metadata": {
                    "domain_id": domain_id,
                    "confidence": confidence,
                    "master_model_id": self.master_model.model_id if self.master_model else "unknown",
                    "timestamp": datetime.utcnow().isoformat(),
                    "source": "master_output"
                }
            }

            # Append to the JSONL file
            output_file = os.path.join(training_data_dir, f"master_outputs_{domain_id}.jsonl")
            logger.info(f"[Save Training Data] Writing to file: {output_file}")
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(training_example, ensure_ascii=False) + '\n')

            logger.info(f"✓ Successfully saved master output as training data: {output_file}")
            return True

        except Exception as e:
            logger.error(f"✗ Error saving master output as training data: {e}", exc_info=True)
            return False
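The JSONL written by _save_master_output_as_training_data is meant to feed the apprentice's fine-tuning. A minimal sketch of reading it back, assuming the path layout above; the confidence threshold is illustrative, not taken from the repo:

import json

def load_master_outputs(domain_id: str, min_confidence: float = 0.8) -> list:
    # One JSON object per line, in the Alpaca-plus-metadata shape saved above
    path = f"training_data/master_outputs/master_outputs_{domain_id}.jsonl"
    examples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            ex = json.loads(line)
            if ex["metadata"]["confidence"] >= min_confidence:
                examples.append({"instruction": ex["instruction"],
                                 "input": ex["input"],
                                 "output": ex["output"]})
    return examples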
model_router.py.backup
ADDED
|
@@ -0,0 +1,803 @@
import logging
from typing import Optional, Dict, Any, AsyncGenerator, List
import asyncio
import time
import threading  # For TextIteratorStreamer
import uuid  # For generating unique IDs
import os

from backend.app.config import ConfigManager, ModelConfig, ModelProvider
# SessionLocal is imported lazily to avoid a circular import
from backend.app.services.knowledge_service import KnowledgeService, get_knowledge_service
from null_ai.llm_providers import HuggingFaceProvider, OllamaProvider, MLXProvider, GGUFProvider  # Import all providers
from null_ai.iath_memory import DendriticMemorySpace  # dendritic spatial memory
from null_ai.coordinate_estimator import CoordinateEstimator  # automatic coordinate estimation
from null_ai.iath_writer import IathWriter, create_tile_from_ai_output  # .iath persistence

logger = logging.getLogger(__name__)

class ModelRouter:
    """
    Manages routing and switching between inference engines.
    Holds the state of the master and apprentice models.
    """
    _instance: Optional['ModelRouter'] = None
    _initialized = False

    def __new__(cls, config_manager: ConfigManager):
        if cls._instance is None:
            cls._instance = super(ModelRouter, cls).__new__(cls)
        return cls._instance

    def __init__(self, config_manager: ConfigManager):
        if self._initialized:
            return

        self.config_manager = config_manager
        self.knowledge_service = get_knowledge_service()  # Instantiate KnowledgeService internally
        self.master_model: Optional[ModelConfig] = None
        self.apprentice_model: Optional[ModelConfig] = None
        self.active_domain_id: Optional[str] = None  # Active domain managed by ConfigManager
        self.managed_engines: Dict[str, Dict[str, Any]] = {}  # Tracks all engines and their status

        # Initialize the dendritic spatial memory (.iath file)
        self.dendritic_memory: Optional[DendriticMemorySpace] = None
        self._load_dendritic_memory()

        # Initialize the automatic coordinate estimator
        self.coordinate_estimator = CoordinateEstimator()

        # Initialize the .iath writer
        iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
        self.iath_writer = IathWriter(iath_file_path)

        # Initialize the LLM providers
        self.providers: Dict[ModelProvider, Any] = {
            ModelProvider.HUGGINGFACE: HuggingFaceProvider(),
            ModelProvider.OLLAMA: OllamaProvider(),
            ModelProvider.MLX: MLXProvider(),
            ModelProvider.GGUF: GGUFProvider(),
        }

        # Set the default models at initial load
        self._load_all_engines()  # Load all configured models into managed_engines
        self._load_default_models()  # This will now also try to load active master/apprentice from config

        self._initialized = True
        logger.info("ModelRouter initialized.")

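    # --- Note on the singleton (sketch, not part of the original file) ---
    # _instance/_initialized are class attributes, so a second
    # ModelRouter(other_config_manager) silently returns the first instance
    # and ignores the new config_manager; cfg_a/cfg_b are hypothetical:
    #
    #   r1 = ModelRouter(cfg_a)
    #   r2 = ModelRouter(cfg_b)
    #   assert r1 is r2                       # same object
    #   assert r2.config_manager is cfg_a     # cfg_b was never applied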
    def _load_dendritic_memory(self):
        """Load the dendritic spatial memory (.iath file)."""
        try:
            # TODO: use a different .iath file per domain
            # For now the default path is used
            iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")

            if os.path.exists(iath_file_path):
                self.dendritic_memory = DendriticMemorySpace(iath_file_path)
                stats = self.dendritic_memory.get_statistics()
                logger.info(f"Dendritic memory loaded: {stats['total_tiles']} tiles from {iath_file_path}")
            else:
                logger.warning(f".iath file not found: {iath_file_path}. Starting with empty memory.")
                self.dendritic_memory = None
        except Exception as e:
            logger.error(f"Failed to load dendritic memory: {e}")
            self.dendritic_memory = None

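    # --- Illustrative configuration sketch (not part of the original file) ---
    # The memory path comes from the IATH_DB_PATH environment variable and
    # falls back to "knowledge_base.iath" in the working directory; the path
    # below is hypothetical:
    #
    #   import os
    #   os.environ["IATH_DB_PATH"] = "/data/knowledge/general.iath"
    #   router._load_dendritic_memory()   # reload from the new location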
    def _get_any_available_model(self) -> Optional[ModelConfig]:
        """
        Return one usable model (used for coordinate estimation, etc.).

        Priority:
        1. the master model
        2. the first configured model

        Returns:
            ModelConfig or None
        """
        if self.master_model:
            return self.master_model

        # Fall back to the first model in the configuration
        for model_id, model_config in self.config_manager.models.items():
            return model_config

        return None

    def _load_all_engines(self):
        """Load every model from the configuration file and initialize managed_engines."""
        self.managed_engines = {}
        for model_id, model_config in self.config_manager.models.items():
            self.managed_engines[model_id] = {
                "config": model_config.dict(),
                "status": "available",  # Default status
                "unique_id": None  # Only for apprentices
            }
        logger.info(f"Loaded {len(self.managed_engines)} engines into management.")

    def _update_engine_status(self, model_id: str, status: str, unique_id: Optional[str] = None):
        """Helper that updates the status of an engine inside managed_engines."""
        if model_id in self.managed_engines:
            self.managed_engines[model_id]["status"] = status
            self.managed_engines[model_id]["unique_id"] = unique_id
        else:
            logger.warning(f"Attempted to update status for unknown engine: {model_id}")

    def _load_default_models(self):
        """Load the default master and apprentice models from the configuration."""
        self.active_domain_id = self.config_manager.get_active_domain_id()

        # Load the persisted active master/apprentice model IDs from ConfigManager
        persisted_master_id = self.config_manager.get_null_ai_setting("active_master_id")
        persisted_apprentice_id = self.config_manager.get_null_ai_setting("active_apprentice_id")

        # Configure the master model
        if persisted_master_id:
            master_config = self.config_manager.get_model_config(persisted_master_id)
            if master_config:
                self.master_model = master_config
                self._update_engine_status(self.master_model.model_id, "master")
                logger.info(f"Persisted master model loaded: {self.master_model.display_name}")
            else:
                logger.warning(f"Persisted master model '{persisted_master_id}' not found in configuration. Attempting to set default.")
                self._set_initial_master_from_config()  # fall back to the default when the persisted model is missing
        else:
            self._set_initial_master_from_config()  # no persisted master ID, so set the default

        # Configure the apprentice model
        if persisted_apprentice_id:
            apprentice_config = self.config_manager.get_model_config(persisted_apprentice_id)
            if apprentice_config:
                self.apprentice_model = apprentice_config
                # Load the unique_id from the config as well
                apprentice_unique_id = self.config_manager.get_null_ai_setting(f"apprentice_unique_id_{persisted_apprentice_id}")
                self._update_engine_status(self.apprentice_model.model_id, "apprentice", apprentice_unique_id)
                logger.info(f"Persisted apprentice model loaded: {self.apprentice_model.display_name}")
            else:
                logger.warning(f"Persisted apprentice model '{persisted_apprentice_id}' not found in configuration. Clearing active apprentice.")
                self.apprentice_model = None
                self.config_manager.set_null_ai_setting("active_apprentice_id", None)
        else:
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)

    def _set_initial_master_from_config(self):
        """Load the default master model from the configuration (when nothing is persisted or the persisted model is missing)."""
        default_master_config = self.config_manager.get_default_model_config(domain_id=self.active_domain_id)
        if default_master_config:
            self.master_model = default_master_config
            self._update_engine_status(self.master_model.model_id, "master")
            self.config_manager.set_null_ai_setting("active_master_id", self.master_model.model_id)
            logger.info(f"Default master model loaded for domain '{self.active_domain_id}': {self.master_model.display_name}")
        else:
            logger.warning(f"No default master model found for domain '{self.active_domain_id}' in configuration. Master model remains unset.")
            self.master_model = None
            self.config_manager.set_null_ai_setting("active_master_id", None)

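    # --- Assumed persistence shape (sketch, not part of the original file) ---
    # The code above reads and writes these null_ai settings keys; how
    # ConfigManager stores them is not shown, so this JSON shape is only an
    # illustration with made-up values:
    #
    #   {
    #     "active_master_id": "mistral:latest",
    #     "active_apprentice_id": "apprentice-1a2b3c4d",
    #     "apprentice_unique_id_apprentice-1a2b3c4d": "1a2b3c4d-..."
    #   }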
    def set_active_domain_id(self, domain_id: str):
        """Set the active domain ID and reload the models accordingly."""
        if self.active_domain_id != domain_id:
            self.active_domain_id = domain_id
            self._load_default_models()  # re-select master/apprentice when the active domain changes
            logger.info(f"ModelRouter active domain set to {domain_id} and models reloaded.")

    def set_master_model(self, model_id: str) -> bool:
        """Set the master model."""
        model = self.config_manager.get_model_config(model_id)
        if model:
            # Retire the old master (unless it is the apprentice being promoted right now)
            if self.master_model and self.master_model.model_id != model_id:
                # On promotion, the old master becomes 'retired'
                self._update_engine_status(self.master_model.model_id, "retired")
                logger.info(f"Old master model '{self.master_model.display_name}' set to 'retired'.")

            self.master_model = model
            self._update_engine_status(model_id, "master")
            self.config_manager.set_null_ai_setting("active_master_id", model_id)
            logger.info(f"Master model set to: {model.display_name}")
            return True
        logger.error(f"Model with ID '{model_id}' not found for master setting.")
        return False

    def set_apprentice_model(self, model_id: Optional[str]) -> bool:
        """Set the apprentice model (None clears it)."""
        # Return the old apprentice to 'available'
        if self.apprentice_model:
            self._update_engine_status(self.apprentice_model.model_id, "available")

        if model_id is None or model_id == 'none':
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)
            logger.info("Apprentice model cleared.")
            return True

        model = self.config_manager.get_model_config(model_id)
        if model:
            self.apprentice_model = model
            # For apprentices, we need to ensure unique_id is tracked if this is a named apprentice
            apprentice_unique_id = self.managed_engines[model_id].get("unique_id")  # Get existing unique_id
            self._update_engine_status(model_id, "apprentice", apprentice_unique_id)
            self.config_manager.set_null_ai_setting("active_apprentice_id", model_id)
            logger.info(f"Apprentice model set to: {model.display_name}")
            return True
        logger.error(f"Model with ID '{model_id}' not found for apprentice setting.")
        return False

    def get_all_managed_engines(self) -> List[Dict[str, Any]]:
        """Return every managed engine, including status and unique ID."""
        return list(self.managed_engines.values())

    def get_master_model(self) -> Optional[ModelConfig]:
        """Return the current master model."""
        return self.master_model

    def get_apprentice_model(self) -> Optional[ModelConfig]:
        """Return the current apprentice model."""
        return self.apprentice_model

    def swap_engines(self, apprentice_model_id: str) -> bool:
        """Swap the master with the specified apprentice."""
        if not self.master_model:
            logger.error("Cannot swap: No master model is currently set.")
            return False

        apprentice_candidate = self.config_manager.get_model_config(apprentice_model_id)
        if not apprentice_candidate:
            logger.error(f"Cannot swap: Apprentice model '{apprentice_model_id}' not found.")
            return False

        # Capture the current master and apprentice
        old_master_model = self.master_model
        old_apprentice_model = self.apprentice_model

        # The new master is the specified apprentice
        new_master_model = apprentice_candidate

        # The new apprentice is the old master (unless the old master is the candidate itself)
        new_apprentice_model = old_master_model if old_master_model.model_id != apprentice_model_id else None

        # Swap the roles
        self.master_model = new_master_model
        self.apprentice_model = new_apprentice_model

        # Update statuses and persistence
        self._update_engine_status(new_master_model.model_id, "master", self.managed_engines[new_master_model.model_id].get("unique_id"))
        self.config_manager.set_null_ai_setting("active_master_id", new_master_model.model_id)

        if new_apprentice_model:
            self._update_engine_status(new_apprentice_model.model_id, "apprentice", self.managed_engines[new_apprentice_model.model_id].get("unique_id"))
            self.config_manager.set_null_ai_setting("active_apprentice_id", new_apprentice_model.model_id)
        else:
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)

        # If the previous apprentice differs from the engine just promoted,
        # return its status to 'available' (checking the old master here would
        # clobber the 'apprentice' status assigned just above)
        if old_apprentice_model and old_apprentice_model.model_id != new_master_model.model_id:
            self._update_engine_status(old_apprentice_model.model_id, "available")

        logger.info(f"Engines swapped: New Master is {new_master_model.display_name}, New Apprentice is {new_apprentice_model.display_name if new_apprentice_model else 'None'}")
        return True

    def promote_apprentice(self, apprentice_model_id: str) -> bool:
        """Promote the specified apprentice to master and retire the current master."""
        apprentice_to_promote = self.config_manager.get_model_config(apprentice_model_id)
        if not apprentice_to_promote:
            logger.error(f"Cannot promote: Apprentice model '{apprentice_model_id}' not found.")
            return False

        # Retire the current master
        if self.master_model:
            self._update_engine_status(self.master_model.model_id, "retired")
            logger.info(f"Current master model '{self.master_model.display_name}' retired.")

        # Promote the apprentice to master
        self.master_model = apprentice_to_promote
        self._update_engine_status(apprentice_to_promote.model_id, "master", self.managed_engines[apprentice_to_promote.model_id].get("unique_id"))
        self.config_manager.set_null_ai_setting("active_master_id", apprentice_to_promote.model_id)

        # Clear the apprentice role (since it has been promoted)
        if self.apprentice_model and self.apprentice_model.model_id == apprentice_model_id:
            self.apprentice_model = None
            self.config_manager.set_null_ai_setting("active_apprentice_id", None)
            logger.info(f"Apprentice model '{apprentice_to_promote.display_name}' promoted to master and apprentice role cleared.")
        else:
            logger.warning(f"Apprentice '{apprentice_model_id}' was promoted, but was not the currently active apprentice.")

        logger.info(f"Apprentice '{apprentice_to_promote.display_name}' promoted to Master.")
        return True

    def create_new_apprentice(self) -> Optional[Dict[str, Any]]:
        """Create and register a new "blank apprentice" inference engine."""
        new_apprentice_unique_id = str(uuid.uuid4())
        new_apprentice_model_id = f"apprentice-{new_apprentice_unique_id[:8]}"
        new_apprentice_display_name = f"Apprentice ({new_apprentice_unique_id[:4]})"

        # The user specified "literally something with no training data".
        # Here we either assume a generic base model or name a specific provider/model.
        # Provisionally, some lightweight Ollama model serves as the template.
        # Alternatively, only the model configuration is created and the actual
        # model file is loaded later, at training time.

        # NOTE: what this "empty" model concretely refers to must be pinned down
        # by the subsequent fine-tuning logic.
        # For now it is added as one entry in the configuration.

        # TODO: make the base "empty" model configurable via null_ai_config.json or similar
        base_apprentice_provider = ModelProvider.OLLAMA  # provisional default provider
        base_apprentice_model_name = "mistral:latest"  # provisional default model name

        # Configuration for the new apprentice
        new_model_config_data = {
            "model_id": new_apprentice_model_id,
            "display_name": new_apprentice_display_name,
            "provider": base_apprentice_provider.value,
            "model_name": base_apprentice_model_name,
            "max_tokens": 4096,
            "temperature": 0.7,
            "timeout": 120,
            "is_default": False,
            "supported_domains": ["general"],  # New apprentices start with general knowledge
            "description": f"Newly generated empty apprentice model with unique ID {new_apprentice_unique_id}."
        }

        try:
            new_model = self.config_manager.add_model(new_model_config_data)
            if not new_model:
                raise Exception("Failed to add new model via config_manager.")

            self.managed_engines[new_model.model_id] = {
                "config": new_model.dict(),
                "status": "available",
                "unique_id": new_apprentice_unique_id
            }
            # Persist the unique_id as well
            self.config_manager.set_null_ai_setting(f"apprentice_unique_id_{new_model.model_id}", new_apprentice_unique_id)

            logger.info(f"New empty apprentice '{new_apprentice_display_name}' created with ID: {new_apprentice_model_id}")
            return self.managed_engines[new_model.model_id]
        except Exception as e:
            logger.error(f"Failed to create new apprentice: {e}")
            return None

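    # --- Illustrative lifecycle sketch (not part of the original file) ---
    # How the role-management methods above compose, under the assumption that
    # the returned engine dict carries the generated model_id:
    #
    #   engine = router.create_new_apprentice()        # status: "available"
    #   new_id = engine["config"]["model_id"]
    #   router.set_apprentice_model(new_id)            # status: "apprentice"
    #   # ...fine-tune the apprentice on the fallen-tree JSONL data...
    #   router.promote_apprentice(new_id)              # apprentice -> "master",
    #                                                  # old master -> "retired"
    #   # swap_engines(new_id) would instead exchange the two roles.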
    def get_active_model(self, for_inference: bool = True) -> Optional[ModelConfig]:
        """
        Return the active model to use for inference.
        When for_inference is True, return the master model.
        When the apprentice model is needed (for example for "training" in the
        fallen-tree system), call get_apprentice_model directly.
        """
        if for_inference:
            return self.master_model
        # Automatic switching to a "grown" apprentice may be added here in the future
        return self.master_model  # default to the master model

    async def infer(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> Dict[str, Any]:
        """
        Run inference with the given model.
        RAG: use knowledge from the DB when available; otherwise infer from the
        AI's internal knowledge and accumulate the result in the DB.
        Behaviour switches according to rag_mode.
        """
        # Handling for rag_mode == 'direct'
        if rag_mode == "direct":
            relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
            if relevant_knowledge:
                # Answer straight from the DB
                best_match = relevant_knowledge[0]
                logger.info(f"Direct mode: DB knowledge found. Returning direct answer from tile {best_match['id']}.")
                return {
                    "response": best_match['content'],
                    "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.",
                    "confidence": best_match['confidence_score'],
                    "model_used": "database_direct",
                    "latency_ms": 50,  # fast: DB lookup only
                    "source_type": "db_direct"
                }
            else:
                # Nothing in the DB: answer "not found"
                logger.info("Direct mode: No DB knowledge found. Returning 'Not found'.")
                return {
                    "response": "ご指定の情報はナレッジベース内に見つかりませんでした。",  # "The requested information was not found in the knowledge base."
                    "thinking": "Directly searched knowledge base, but no relevant information was found.",
                    "confidence": 0.9,  # "not found" is itself a high-confidence answer
                    "model_used": "database_direct",
                    "latency_ms": 50,
                    "source_type": "db_direct"
                }

        # Handling for rag_mode == 'rag' (existing logic)
        has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
        source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
        thinking_process = ""
        augmented_prompt = prompt

        if has_db_knowledge:
            # Fetch related knowledge from the DB
            relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
            logger.info(f"RAG mode: DB knowledge found for domain '{domain_id}'. Retrieved {len(relevant_knowledge)} relevant tiles.")

            # RAG: fold the knowledge into the prompt
            knowledge_context = "\n\n".join([
                f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
                f"Topic: {tile['topic']}\n"
                f"Content: {tile['content']}"
                for i, tile in enumerate(relevant_knowledge)
            ])

            augmented_prompt = f"""Based on the following verified knowledge from the database:

{knowledge_context}

Now, please answer the following question accurately:
{prompt}"""

            response = await self._perform_llm_inference(model_config, augmented_prompt, temperature or model_config.temperature)
            thinking_process = f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles. " + response.get("thinking", "")
        else:
            # No DB knowledge: infer from the AI's internal weights
            logger.info(f"RAG mode: No DB knowledge found for domain '{domain_id}'. Using AI internal weights.")
            response = await self._perform_llm_inference(model_config, prompt, temperature or model_config.temperature)
            thinking_process = "No specific DB knowledge found. Inferred from AI internal weights. " + response.get("thinking", "")

        # Self-enrichment: persist the inference result to the DB
        if save_to_memory and response.get("confidence", 0) >= 0.7:
            await self._save_inference_to_db(
                domain_id=domain_id,
                prompt=prompt,
                response=response.get("response", ""),
                confidence=response.get("confidence", 0.7),
                source_type="ai",
                model_id=model_config.model_id
            )

        # Fallen-tree system: save the master's output as training data
        is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
        logger.info(f"[Training Data Check] master_model={self.master_model.model_id if self.master_model else None}, "
                    f"current_model={model_config.model_id}, is_master={is_master}, "
                    f"confidence={response.get('confidence', 0)}, save_to_memory={save_to_memory}")
        if is_master and response.get("confidence", 0) >= 0.8:
            logger.info(f"[Training Data] Saving master output for domain '{domain_id}'...")
            await self._save_master_output_as_training_data(
                prompt=prompt,
                master_response=response.get("response", ""),
                domain_id=domain_id,
                confidence=response.get("confidence", 0.8)
            )
        else:
            logger.info(f"[Training Data] NOT saving: is_master={is_master}, confidence={response.get('confidence', 0)}")

        return {
            "response": response.get("response", ""),
            "thinking": thinking_process,
            "confidence": response.get("confidence", 0.7),
            "model_used": model_config.model_id,
            "latency_ms": response.get("latency_ms", 0),
            "source_type": source_type
        }

    async def infer_streaming(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run inference with the given model as a stream.
        RAG: use knowledge from the DB when available; otherwise infer from the
        AI's internal knowledge and accumulate the result in the DB.
        Behaviour switches according to rag_mode.
        """
        print(f"DEBUG: infer_streaming called - model={model_config.model_id if model_config else 'None'}")
        logger.error(f"[CRITICAL DEBUG] infer_streaming ENTRY POINT - model={model_config.model_id if model_config else 'None'}, save_to_memory={save_to_memory}")
        logger.info(f"[INFER_STREAMING] Called with model={model_config.model_id}, domain={domain_id}, save_to_memory={save_to_memory}")

        # Accumulates the master's response (consumed by the save logic further down)
        generated_response = ""

        try:
            # Handling for rag_mode == 'direct'
            if rag_mode == "direct":
                relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
                if relevant_knowledge:
                    best_match = relevant_knowledge[0]
                    logger.info(f"Direct mode (streaming): DB knowledge found. Yielding direct answer from tile {best_match['id']}.")
                    yield {"type": "token", "content": best_match['content']}
                    yield {"type": "meta", "source_type": "db_direct", "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.", "confidence": best_match['confidence_score'], "model_used": "database_direct"}
                    yield {"type": "complete", "content": best_match['content']}
                    return
                else:
                    logger.info("Direct mode (streaming): No DB knowledge found. Yielding 'Not found'.")
                    response_text = "ご指定の情報はナレッジベース内に見つかりませんでした。"  # "The requested information was not found in the knowledge base."
                    yield {"type": "token", "content": response_text}
                    yield {"type": "meta", "source_type": "db_direct", "thinking": "Directly searched knowledge base, but no relevant information was found.", "confidence": 0.9, "model_used": "database_direct"}
                    yield {"type": "complete", "content": response_text}
                    return

            # Handling for rag_mode == 'rag' (existing logic)
            has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
            source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
            augmented_prompt = prompt

            yield {"type": "thinking", "content": f"Checking DB knowledge for domain '{domain_id}'..."}
            await asyncio.sleep(0.1)

            if has_db_knowledge:
                # Fetch related knowledge from the DB
                relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
                yield {"type": "thinking", "content": f"Relevant DB knowledge found. Retrieved {len(relevant_knowledge)} tiles. Augmenting prompt..."}

                # RAG: fold the knowledge into the prompt
                knowledge_context = "\n\n".join([
                    f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
                    f"Topic: {tile['topic']}\n"
                    f"Content: {tile['content']}"
                    for i, tile in enumerate(relevant_knowledge)
                ])

                augmented_prompt = f"""Based on the following verified knowledge from the database:

{knowledge_context}

Now, please answer the following question accurately:
{prompt}"""

                # Run streaming inference
                async for chunk in self._perform_llm_streaming_inference(model_config, augmented_prompt, temperature or model_config.temperature):
                    if chunk.get("type") == "token":
                        generated_response += chunk.get("content", "")
                    yield chunk

                yield {"type": "meta", "source_type": source_type, "thinking": f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles."}

            else:
                yield {"type": "thinking", "content": "No specific DB knowledge found. Using AI internal weights."}

                # Run streaming inference
                async for chunk in self._perform_llm_streaming_inference(model_config, prompt, temperature or model_config.temperature):
                    if chunk.get("type") == "token":
                        generated_response += chunk.get("content", "")
                    yield chunk

                yield {"type": "meta", "source_type": source_type, "thinking": "Inferred from AI internal weights."}

            # Self-enrichment: persist the inference result to the DB (when confidence is sufficient)
            if save_to_memory and len(generated_response) > 0:
                await self._save_inference_to_db(
                    domain_id=domain_id,
                    prompt=prompt,
                    response=generated_response,
                    confidence=0.75,  # fixed value, since streaming yields no confidence
                    source_type="ai",
                    model_id=model_config.model_id
                )

            # Fallen-tree system: save the master's output as training data
            logger.error(f"[CRITICAL DEBUG] Reached training data save section! generated_response length={len(generated_response)}")
            is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
            logger.info(f"[Training Data Check - Streaming] master_model={self.master_model.model_id if self.master_model else None}, "
                        f"current_model={model_config.model_id}, is_master={is_master}, "
                        f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
            if is_master and len(generated_response) > 0:
                logger.info(f"[Training Data - Streaming] Saving master output for domain '{domain_id}'...")
                await self._save_master_output_as_training_data(
                    prompt=prompt,
                    master_response=generated_response,
                    domain_id=domain_id,
                    confidence=0.8
                )
            else:
                logger.info(f"[Training Data - Streaming] NOT saving: is_master={is_master}, response_length={len(generated_response)}")

        except Exception as e:
            logger.error(f"Error during streaming inference: {e}")
            yield {"type": "error", "content": str(e)}

    def _check_db_knowledge(self, domain_id: str, prompt: str) -> bool:
        """
        Check whether knowledge related to the given domain and prompt exists.
        Uses the dendritic spatial memory (.iath).
        """
        try:
            if self.dendritic_memory is None:
                return False

            # Check for a match via text search
            results = self.dendritic_memory.search_by_text(prompt[:200], top_k=1)
            return len(results) > 0

        except Exception as e:
            logger.error(f"Error checking dendritic memory: {e}")
            return False

    def _retrieve_relevant_knowledge(self, domain_id: str, prompt: str, top_k: int = 3) -> list:
        """
        Retrieve knowledge related to the given domain and prompt.
        Uses the hybrid search of the dendritic spatial memory (.iath).
        """
        try:
            if self.dendritic_memory is None:
                return []

            # Hybrid search (text + spatial coordinates)
            # TODO: estimate spatial coordinates from the prompt as well
            results = self.dendritic_memory.hybrid_search(prompt, query_coords=None, top_k=top_k)

            # Convert to the unified format
            return [
                {
                    "id": tile["metadata"]["knowledge_id"],
                    "topic": tile["metadata"]["topic"],
                    "content": tile["content"]["final_response"],
                    "thinking_process": tile["content"]["thinking_process"],
                    "confidence_score": tile["verification"]["initial_certainty"],
                    "verification_type": tile["verification"]["status"],
                    "coordinates": tile["coordinates"],
                    "text_match_score": tile.get("text_match_score", 0),
                    "spatial_distance": tile.get("spatial_distance", None)
                }
                for tile in results
            ]

        except Exception as e:
            logger.error(f"Error retrieving relevant knowledge from dendritic memory: {e}")
            return []

    async def _perform_llm_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
        """
        Run LLM inference with the given model configuration.
        Delegates to the matching provider implementation.
        """
        provider_type = model_config.provider

        if provider_type not in self.providers:
            raise ValueError(f"Unsupported provider: {provider_type}")

        provider = self.providers[provider_type]

        try:
            result = await provider.infer(model_config, prompt, temperature)
            return result
        except Exception as e:
            logger.error(f"Error during LLM inference with provider '{provider_type}': {e}")
            raise

    async def _perform_llm_streaming_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run streaming LLM inference with the given model configuration.
        Delegates to the matching provider implementation.
        """
        provider_type = model_config.provider

        if provider_type not in self.providers:
            raise ValueError(f"Unsupported provider: {provider_type}")

        provider = self.providers[provider_type]

        try:
            async for chunk in provider.infer_streaming(model_config, prompt, temperature):
                yield chunk
        except Exception as e:
            logger.error(f"Error during LLM streaming inference with provider '{provider_type}': {e}")
            yield {"type": "error", "content": str(e)}

    async def _save_inference_to_db(self, domain_id: str, prompt: str, response: str, confidence: float, source_type: str, model_id: str) -> bool:
        """
        Save an inference result to the DB as a Knowledge Tile (self-enrichment).

        Priority 2 implementation: automatic coordinate estimation + .iath persistence
        """
        try:
            from backend.app.database.models import KnowledgeTile
            from backend.app.database.session import SessionLocal  # deferred import (avoids a circular import)

            db = SessionLocal()
            tile_id = f"ai_ktile_{uuid.uuid4().hex}"

            # Create the AI-generated knowledge tile (SQLite)
            new_tile = KnowledgeTile(
                id=tile_id,
                workspace_id="default_workspace",  # local build, so use the default workspace
                domain_id=domain_id,
                topic=prompt[:200],  # use the prompt as the topic
                content=response,
                confidence_score=confidence,
                verification_type="ai",  # marks the tile as AI-generated
                verification_count=1,
                contributor_id=None,  # no contributor: the AI produced it
                last_verified_by_id=None
            )

            try:
                db.add(new_tile)
                db.commit()
            finally:
                db.close()  # close the session even if the commit fails

            logger.info(f"Saved AI inference to SQLite: {new_tile.id}")

            # Priority 2: automatic coordinate estimation + .iath persistence
            try:
                # Inference function used only for coordinate estimation
                async def llm_inference_for_coords(coord_prompt):
                    # Estimate coordinates with the master model (or DeepSeek)
                    estimation_model = self.master_model if self.master_model else self._get_any_available_model()

                    if estimation_model:
                        result = await self._perform_llm_inference(
                            estimation_model,
                            coord_prompt,
                            temperature=0.3  # low temperature for consistency
                        )
                        return result.get("response", "")
                    return "{\"coordinates\": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], \"confidence\": 0.3}"

                # Estimate the coordinates
                coord_result = await self.coordinate_estimator.estimate_coordinates(
                    prompt=prompt,
                    response=response,
                    domain_id=domain_id,
                    llm_inference_func=llm_inference_for_coords,
                    use_reasoning=False  # skip the reasoning trace for speed
                )

                # Build the .iath Tile object
                iath_tile = create_tile_from_ai_output(
                    knowledge_id=tile_id,
                    topic=prompt[:100],
                    prompt=prompt,
                    response=response,
                    coordinates=coord_result["coordinates"],
                    confidence=confidence,
                    domain_id=domain_id,
                    source="ai_generated"
                )

                # Append to the .iath file
                success = self.iath_writer.append_tile(iath_tile)

                if success:
                    logger.info(f"Saved AI inference to .iath with coordinates: {coord_result['coordinates']}")

                    # Reload the memory so the new tile becomes searchable
                    self._load_dendritic_memory()
                else:
                    logger.warning("Failed to save to .iath, but SQLite save succeeded")

            except Exception as iath_error:
                logger.error(f"Error saving to .iath (SQLite save succeeded): {iath_error}")
                # Still return True: the SQLite save succeeded even though .iath failed

            return True

        except Exception as e:
            logger.error(f"Error saving inference to DB: {e}")
            return False

    async def _save_master_output_as_training_data(self, prompt: str, master_response: str, domain_id: str, confidence: float) -> bool:
        """
        Save the master's output as supervised training data for fine-tuning
        the apprentice. The core mechanism of the fallen-tree system.
        """
        try:
            import json
            import os
            from datetime import datetime

            logger.info(f"[Save Training Data] Starting to save master output (domain={domain_id}, confidence={confidence})")

            # Directory that holds the training data
            training_data_dir = "training_data/master_outputs"
            logger.info(f"[Save Training Data] Creating directory: {training_data_dir}")
            os.makedirs(training_data_dir, exist_ok=True)

            # Store in Alpaca format
            training_example = {
                "instruction": f"You are an expert in {domain_id}. Provide accurate information based on verified knowledge.",
                "input": prompt,
                "output": master_response,
                "metadata": {
                    "domain_id": domain_id,
                    "confidence": confidence,
                    "master_model_id": self.master_model.model_id if self.master_model else "unknown",
                    "timestamp": datetime.utcnow().isoformat(),
                    "source": "master_output"
                }
            }

            # Append to the JSONL file
            output_file = os.path.join(training_data_dir, f"master_outputs_{domain_id}.jsonl")
            logger.info(f"[Save Training Data] Writing to file: {output_file}")
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(training_example, ensure_ascii=False) + '\n')

            logger.info(f"✓ Successfully saved master output as training data: {output_file}")
            return True

        except Exception as e:
            logger.error(f"✗ Error saving master output as training data: {e}", exc_info=True)
            return False