kofdai committed
Commit 5af8123 · verified · Parent: 49146ed

Upload folder using huggingface_hub

__init__.py ADDED
File without changes
__pycache__/__init__.cpython-313.pyc ADDED (binary, 154 Bytes)
__pycache__/auto_training.cpython-313.pyc ADDED (binary, 19 kB)
__pycache__/coordinate_estimator.cpython-313.pyc ADDED (binary, 14 kB)
__pycache__/db_providers.cpython-313.pyc ADDED (binary, 10.4 kB)
__pycache__/fine_tuning.cpython-313.pyc ADDED (binary, 21.8 kB)
__pycache__/iath_db_provider.cpython-313.pyc ADDED (binary, 18.3 kB)
__pycache__/iath_memory.cpython-313.pyc ADDED (binary, 16.9 kB)
__pycache__/iath_writer.cpython-313.pyc ADDED (binary, 17.9 kB)
__pycache__/llm_providers.cpython-313.pyc ADDED (binary, 21 kB)
__pycache__/model_router.cpython-313.pyc ADDED (binary, 45.1 kB)
auto_training.py ADDED
@@ -0,0 +1,403 @@
"""
NullAI Auto-Training Manager

The core module of the automatic training system.
Runs fine-tuning automatically when data-volume or time-based triggers fire.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)


@dataclass
class AutoTrainingState:
    """State of the auto-training system."""
    enabled: bool = True
    last_check_time: Optional[str] = None
    last_training_time: Optional[str] = None
    last_training_success: bool = True
    last_training_examples_count: int = 0
    next_scheduled_training: Optional[str] = None
    total_auto_trainings: int = 0
    consecutive_failures: int = 0
    is_training: bool = False
    last_error: Optional[str] = None


class AutoTrainingManager:
    """
    Auto-training manager.

    Monitors the training data according to the configuration and
    runs fine-tuning automatically once the trigger conditions are met.
    """

    def __init__(self, config: Dict[str, Any], training_manager):
        """
        Args:
            config: the auto_training section of null_ai_config.json
            training_manager: a FineTuningManager instance
        """
        self.config = config
        self.training_manager = training_manager
        self.state = AutoTrainingState()
        self.state_file = Path("training_data/auto_training_state.json")

        # Load configuration
        self.enabled = config.get("enabled", True)
        self.trigger_mode = config.get("trigger_mode", "hybrid")
        self.min_examples = config.get("min_examples", 100)
        self.min_days = config.get("min_days_since_last_training", 7)
        self.max_days = config.get("max_days_since_last_training", 30)
        self.quality_threshold = config.get("quality_threshold", 0.8)
        self.check_interval_minutes = config.get("check_interval_minutes", 60)
        self.preferred_hour = config.get("preferred_training_hour", 2)
        self.allow_manual_override = config.get("allow_manual_override", True)

        # Training parameters
        self.training_method = config.get("training_method", "peft")
        self.training_params = config.get("training_params", {})

        # Restore persisted state
        self._load_state()

        logger.info(f"AutoTrainingManager initialized: enabled={self.enabled}, trigger_mode={self.trigger_mode}")
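
    # Configuration sketch (illustrative values only; every key and default
    # below is taken from the config.get() calls in __init__, and the four
    # trigger modes from check_training_trigger further down) for the
    # auto_training section of null_ai_config.json:
    #
    #   "auto_training": {
    #       "enabled": true,
    #       "trigger_mode": "hybrid",    // "data_count", "time_based", "hybrid", "max_interval"
    #       "min_examples": 100,
    #       "min_days_since_last_training": 7,
    #       "max_days_since_last_training": 30,
    #       "quality_threshold": 0.8,
    #       "check_interval_minutes": 60,
    #       "preferred_training_hour": 2,
    #       "allow_manual_override": true,
    #       "training_method": "peft",
    #       "training_params": {"epochs": 3, "learning_rate": 2e-4, "batch_size": 4}
    #   }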

    def _load_state(self):
        """Load the persisted state."""
        try:
            if self.state_file.exists():
                with open(self.state_file, 'r') as f:
                    state_dict = json.load(f)
                self.state = AutoTrainingState(**state_dict)
                logger.info(f"Loaded auto-training state from {self.state_file}")
        except Exception as e:
            logger.warning(f"Failed to load auto-training state: {e}")

    def _save_state(self):
        """Persist the current state."""
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.state_file, 'w') as f:
                json.dump(asdict(self.state), f, indent=2)
        except Exception as e:
            logger.error(f"Failed to save auto-training state: {e}")

    def get_training_data_stats(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Collect statistics over the training data.

        Returns:
            {
                "total_examples": int,
                "examples_by_domain": Dict[str, int],
                "high_quality_count": int,
                "oldest_timestamp": str,
                "newest_timestamp": str
            }
        """
        training_data_dir = Path("training_data/master_outputs")
        if not training_data_dir.exists():
            return {
                "total_examples": 0,
                "examples_by_domain": {},
                "high_quality_count": 0,
                "oldest_timestamp": None,
                "newest_timestamp": None
            }

        stats = {
            "total_examples": 0,
            "examples_by_domain": {},
            "high_quality_count": 0,
            "oldest_timestamp": None,
            "newest_timestamp": None
        }

        # Scan the JSONL files
        jsonl_files = []
        if domain_id:
            jsonl_files = [training_data_dir / f"master_outputs_{domain_id}.jsonl"]
        else:
            jsonl_files = list(training_data_dir.glob("master_outputs_*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                continue

            domain = jsonl_file.stem.replace("master_outputs_", "")
            domain_count = 0

            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        example = json.loads(line.strip())
                        stats["total_examples"] += 1
                        domain_count += 1

                        # Count high-quality examples
                        confidence = example.get("metadata", {}).get("confidence", 0)
                        if confidence >= self.quality_threshold:
                            stats["high_quality_count"] += 1

                        # Track timestamps
                        timestamp = example.get("metadata", {}).get("timestamp")
                        if timestamp:
                            if stats["oldest_timestamp"] is None or timestamp < stats["oldest_timestamp"]:
                                stats["oldest_timestamp"] = timestamp
                            if stats["newest_timestamp"] is None or timestamp > stats["newest_timestamp"]:
                                stats["newest_timestamp"] = timestamp

                    except json.JSONDecodeError:
                        continue

            if domain_count > 0:
                stats["examples_by_domain"][domain] = domain_count

        return stats

    def check_training_trigger(self, domain_id: Optional[str] = None) -> tuple[bool, str]:
        """
        Check whether training should be triggered.

        Returns:
            (should_trigger: bool, reason: str)
        """
        if not self.enabled:
            return False, "Auto-training is disabled"

        if self.state.is_training:
            return False, "Training is already in progress"

        # Fetch data statistics
        stats = self.get_training_data_stats(domain_id)

        if stats["total_examples"] == 0:
            return False, "No training data available"

        # Compute days elapsed since the last training run
        days_since_last = None
        if self.state.last_training_time:
            try:
                last_training = datetime.fromisoformat(self.state.last_training_time)
                days_since_last = (datetime.utcnow() - last_training).days
            except ValueError:
                pass

        # Decide according to the trigger mode
        if self.trigger_mode == "data_count":
            # Data volume only
            if stats["high_quality_count"] >= self.min_examples:
                return True, f"Sufficient training data ({stats['high_quality_count']} examples >= {self.min_examples})"
            return False, f"Insufficient training data ({stats['high_quality_count']} < {self.min_examples})"

        elif self.trigger_mode == "time_based":
            # Time based only
            if days_since_last is None:
                return True, "First auto-training"
            if days_since_last >= self.min_days:
                return True, f"Time threshold met ({days_since_last} days >= {self.min_days} days)"
            return False, f"Too soon since last training ({days_since_last} < {self.min_days} days)"

        elif self.trigger_mode == "hybrid":
            # Hybrid (data volume AND time)
            if stats["high_quality_count"] < self.min_examples:
                return False, f"Insufficient training data ({stats['high_quality_count']} < {self.min_examples})"

            if days_since_last is None:
                return True, f"First auto-training with {stats['high_quality_count']} examples"

            if days_since_last >= self.min_days:
                return True, f"Both conditions met: {stats['high_quality_count']} examples, {days_since_last} days since last training"

            return False, f"Time condition not met ({days_since_last} < {self.min_days} days)"

        elif self.trigger_mode == "max_interval":
            # Forced maximum-interval mode
            if days_since_last is not None and days_since_last >= self.max_days:
                return True, f"Maximum interval reached ({days_since_last} >= {self.max_days} days)"

            # Otherwise apply the standard hybrid check
            if stats["high_quality_count"] >= self.min_examples and (days_since_last is None or days_since_last >= self.min_days):
                return True, f"Standard conditions met: {stats['high_quality_count']} examples"

            return False, "Conditions not met"

        return False, f"Unknown trigger mode: {self.trigger_mode}"

    def should_train_now(self) -> bool:
        """
        Check whether now is a suitable time slot for training.

        The hour before and after preferred_training_hour counts
        as the recommended training window.
        """
        current_hour = datetime.utcnow().hour

        # One hour on either side of the preferred hour
        target_hours = [
            (self.preferred_hour - 1) % 24,
            self.preferred_hour,
            (self.preferred_hour + 1) % 24
        ]

        return current_hour in target_hours

    async def trigger_auto_training(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Run automatic training.

        Returns:
            A dictionary describing the training result.
        """
        logger.info(f"Starting auto-training for domain: {domain_id or 'all'}")

        # Update state
        self.state.is_training = True
        self.state.last_check_time = datetime.utcnow().isoformat()
        self._save_state()

        try:
            # Fetch data statistics
            stats = self.get_training_data_stats(domain_id)

            # Run fine-tuning
            # Note: call the method matching the training_manager implementation
            result = await self._execute_training(domain_id, stats)

            # Update state on success
            self.state.last_training_time = datetime.utcnow().isoformat()
            self.state.last_training_success = result.get("success", False)
            self.state.last_training_examples_count = stats["high_quality_count"]
            self.state.total_auto_trainings += 1
            self.state.consecutive_failures = 0
            self.state.last_error = None

            logger.info(f"Auto-training completed successfully: {result}")

            return {
                "success": True,
                "result": result,
                "stats": stats,
                "timestamp": self.state.last_training_time
            }

        except Exception as e:
            logger.error(f"Auto-training failed: {e}", exc_info=True)

            # Update state on failure
            self.state.last_training_success = False
            self.state.consecutive_failures += 1
            self.state.last_error = str(e)

            return {
                "success": False,
                "error": str(e),
                "consecutive_failures": self.state.consecutive_failures
            }

        finally:
            self.state.is_training = False
            self._save_state()

    async def _execute_training(self, domain_id: Optional[str], stats: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the actual training (internal method).
        """
        # Prepare training parameters
        training_params = {
            "apprentice_model_name": None,  # Use the existing apprentice model
            "domain_id": domain_id,
            "method": self.training_method,
            "epochs": self.training_params.get("epochs", 3),
            "learning_rate": self.training_params.get("learning_rate", 2e-4),
            "batch_size": self.training_params.get("batch_size", 4),
            "lora_r": self.training_params.get("lora_r", 8),
            "lora_alpha": self.training_params.get("lora_alpha", 16),
            "output_name": f"auto_training_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}"
        }

        logger.info(f"Executing training with params: {training_params}")

        # Run training via the FineTuningManager.
        # Note: this part must be implemented against the actual training API;
        # for now it returns a simple placeholder structure.

        # TODO: implement the actual training call here
        result = {
            "success": True,
            "output_dir": f"training_data/checkpoints/{training_params['output_name']}",
            "model_name": training_params['output_name'],
            "train_loss": 0.5,  # Placeholder
            "method": self.training_method,
            "examples_used": stats["high_quality_count"]
        }

        return result

    def get_status(self) -> Dict[str, Any]:
        """
        Return the current status of the auto-training system.
        """
        should_trigger, reason = self.check_training_trigger()
        stats = self.get_training_data_stats()

        return {
            "enabled": self.enabled,
            "is_training": self.state.is_training,
            "trigger_mode": self.trigger_mode,
            "should_trigger": should_trigger,
            "trigger_reason": reason,
            "config": {
                "min_examples": self.min_examples,
                "min_days": self.min_days,
                "max_days": self.max_days,
                "quality_threshold": self.quality_threshold,
                "check_interval_minutes": self.check_interval_minutes,
                "preferred_hour": self.preferred_hour
            },
            "state": {
                "last_check_time": self.state.last_check_time,
                "last_training_time": self.state.last_training_time,
                "last_training_success": self.state.last_training_success,
                "last_training_examples_count": self.state.last_training_examples_count,
                "total_auto_trainings": self.state.total_auto_trainings,
                "consecutive_failures": self.state.consecutive_failures,
                "last_error": self.state.last_error
            },
            "data_stats": stats,
            "should_train_now": self.should_train_now()
        }

    def enable(self):
        """Enable auto-training."""
        self.enabled = True
        self.state.enabled = True
        self._save_state()
        logger.info("Auto-training enabled")

    def disable(self):
        """Disable auto-training."""
        self.enabled = False
        self.state.enabled = False
        self._save_state()
        logger.info("Auto-training disabled")

    def update_config(self, new_config: Dict[str, Any]):
        """Update the configuration."""
        self.config.update(new_config)

        # Reload configuration values
        self.trigger_mode = self.config.get("trigger_mode", self.trigger_mode)
        self.min_examples = self.config.get("min_examples", self.min_examples)
        self.min_days = self.config.get("min_days_since_last_training", self.min_days)
        self.max_days = self.config.get("max_days_since_last_training", self.max_days)
        self.quality_threshold = self.config.get("quality_threshold", self.quality_threshold)

        logger.info(f"Auto-training config updated: {new_config}")
        self._save_state()
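
A minimal driver sketch for the class above (the import paths and the polling loop are assumptions, not part of this commit; FineTuningManager is defined in fine_tuning.py below):

    import asyncio
    from fine_tuning import FineTuningManager
    from auto_training import AutoTrainingManager

    async def watch(manager: AutoTrainingManager):
        # Poll the trigger conditions forever; train inside the preferred window.
        while True:
            should_trigger, reason = manager.check_training_trigger()
            if should_trigger and manager.should_train_now():
                result = await manager.trigger_auto_training()
                print("auto-training finished:", result["success"], "-", reason)
            await asyncio.sleep(manager.check_interval_minutes * 60)

    config = {"enabled": True, "trigger_mode": "hybrid", "min_examples": 100}
    manager = AutoTrainingManager(config, FineTuningManager())
    asyncio.run(watch(manager))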
coordinate_estimator.py ADDED
@@ -0,0 +1,359 @@
# null_ai/coordinate_estimator.py
"""
Coordinate Auto-Estimation Module

Automatically estimates the 6-dimensional coordinates of knowledge tiles using an LLM.
Coordinates: [x, y, z, c, g, v]
- medical_space [x, y, z]: a domain-specific 3-dimensional space
- meta_space [c, g, v]: Certainty, Granularity, Verification
"""

import logging
import json
from typing import List, Dict, Any, Optional
import asyncio

logger = logging.getLogger(__name__)


class CoordinateEstimator:
    """
    Estimates 6-dimensional coordinates automatically using an LLM.
    """

    def __init__(self):
        self.domain_schemas = self._load_domain_schemas()

    def _load_domain_schemas(self) -> Dict[str, Dict[str, str]]:
        """
        Return the coordinate-axis definitions for each domain.

        These will eventually be loaded from a configuration file.
        """
        return {
            "medical": {
                "x": "Anatomical location (0.0=nervous system, 0.5=cardiovascular, 1.0=digestive)",
                "y": "Pathological classification (0.0=infectious, 0.5=metabolic, 1.0=trauma)",
                "z": "Treatment level (0.0=prevention, 0.5=diagnosis, 1.0=treatment)"
            },
            "general": {
                "x": "Knowledge category (0.0=science, 0.5=technology, 1.0=humanities)",
                "y": "Complexity level (0.0=basic, 0.5=intermediate, 1.0=advanced)",
                "z": "Application scope (0.0=theoretical, 0.5=practical, 1.0=applied)"
            },
            "legal": {
                "x": "Legal field (0.0=civil, 0.5=criminal, 1.0=commercial)",
                "y": "Court level (0.0=district, 0.5=high, 1.0=supreme)",
                "z": "Era (0.0=classical, 0.5=modern, 1.0=contemporary)"
            },
            "technology": {
                "x": "Technology domain (0.0=hardware, 0.5=software, 1.0=network)",
                "y": "Maturity (0.0=emerging, 0.5=established, 1.0=legacy)",
                "z": "Scale (0.0=personal, 0.5=enterprise, 1.0=global)"
            }
        }

    async def estimate_coordinates(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        llm_inference_func,
        use_reasoning: bool = True
    ) -> Dict[str, Any]:
        """
        Estimate the 6-dimensional coordinates.

        Args:
            prompt: The user's question
            response: The AI's answer
            domain_id: Domain ID
            llm_inference_func: Async LLM inference function
            use_reasoning: Whether to include the reasoning process

        Returns:
            {
                "coordinates": [x, y, z, c, g, v],
                "reasoning": "Why these values were chosen",
                "confidence": 0.85
            }
        """
        # Fetch the domain schema
        domain_schema = self.domain_schemas.get(
            domain_id,
            self.domain_schemas["general"]  # Fallback
        )

        # Build the estimation prompt
        estimation_prompt = self._build_estimation_prompt(
            prompt, response, domain_id, domain_schema, use_reasoning
        )

        # Ask the LLM for a coordinate estimate
        try:
            llm_response = await llm_inference_func(estimation_prompt)

            # Extract the coordinates from the response
            result = self._parse_llm_response(llm_response)

            # Validate
            if self._validate_coordinates(result["coordinates"]):
                logger.info(f"Estimated coordinates for domain '{domain_id}': {result['coordinates']}")
                return result
            else:
                logger.error(f"Invalid coordinates: {result['coordinates']}")
                return self._get_default_coordinates(domain_id)

        except Exception as e:
            logger.error(f"Coordinate estimation failed: {e}")
            return self._get_default_coordinates(domain_id)

    def _build_estimation_prompt(
        self,
        prompt: str,
        response: str,
        domain_id: str,
        domain_schema: Dict[str, str],
        use_reasoning: bool
    ) -> str:
        """
        Build the prompt used for coordinate estimation.
        """
        base_prompt = f"""You are an expert in knowledge space mapping and coordinate estimation.

Your task is to estimate the 6-dimensional coordinates that best represent the following knowledge in the domain of "{domain_id}".

**Coordinate System:**

1. **Domain-specific space [x, y, z]** (each 0.0-1.0):
- x-axis: {domain_schema['x']}
- y-axis: {domain_schema['y']}
- z-axis: {domain_schema['z']}

2. **Meta-information space [c, g, v]** (each 0.0-1.0):
- c (Certainty): How certain/verified is this knowledge?
* 0.0 = hypothesis, speculation
* 0.5 = established theory, widely accepted
* 1.0 = proven fact, empirically verified

- g (Granularity): How detailed/specific is this knowledge?
* 0.0 = high-level overview, general concept
* 0.5 = detailed explanation
* 1.0 = highly specialized, expert-level detail

- v (Verification): What is the verification status?
* 0.0 = unverified, no sources
* 0.5 = expert-reviewed, single source
* 1.0 = peer-reviewed, multiple sources confirmed

**Knowledge to estimate:**

Question: {prompt}

Answer: {response}

**Instructions:**
"""

        if use_reasoning:
            base_prompt += """
1. First, analyze the knowledge and explain your reasoning for each coordinate.
2. Then, output the final coordinates.

Format your response as JSON:
{
    "reasoning": "Your detailed reasoning here...",
    "coordinates": [x, y, z, c, g, v],
    "confidence": 0.85
}
"""
        else:
            base_prompt += """
Output ONLY the coordinates as a JSON object:
{
    "coordinates": [x, y, z, c, g, v],
    "confidence": 0.85
}
"""

        base_prompt += """
**Important:**
- All coordinates must be between 0.0 and 1.0
- Use 2 decimal places (e.g., 0.75)
- confidence should reflect how confident you are in this estimation (0.0-1.0)
"""

        return base_prompt

    def _parse_llm_response(self, llm_response: str) -> Dict[str, Any]:
        """
        Extract the coordinates from the LLM response.
        """
        try:
            # Look for a JSON block;
            # LLMs often wrap it in ```json ... ``` fences
            if "```json" in llm_response:
                json_start = llm_response.find("```json") + 7
                json_end = llm_response.find("```", json_start)
                json_str = llm_response[json_start:json_end].strip()
            elif "```" in llm_response:
                json_start = llm_response.find("```") + 3
                json_end = llm_response.find("```", json_start)
                json_str = llm_response[json_start:json_end].strip()
            else:
                # Otherwise treat the whole response as JSON
                json_str = llm_response.strip()

            # Parse the JSON
            result = json.loads(json_str)

            # Check required fields
            if "coordinates" not in result:
                raise ValueError("Missing 'coordinates' field")

            # Fill in defaults
            if "reasoning" not in result:
                result["reasoning"] = "No reasoning provided"
            if "confidence" not in result:
                result["confidence"] = 0.5

            return result

        except json.JSONDecodeError as e:
            logger.error(f"JSON parse error: {e}")
            logger.debug(f"LLM response: {llm_response}")

            # Fallback: look for a bare list of numbers
            return self._fallback_parse(llm_response)

    def _fallback_parse(self, llm_response: str) -> Dict[str, Any]:
        """
        Fallback used when JSON parsing fails.
        """
        import re

        # Look for a pattern like [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
        pattern = r'\[[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*,[\s]*([0-9.]+)[\s]*\]'
        match = re.search(pattern, llm_response)

        if match:
            coords = [float(match.group(i)) for i in range(1, 7)]
            return {
                "coordinates": coords,
                "reasoning": "Parsed from array notation",
                "confidence": 0.5
            }

        # Parsing failed completely
        raise ValueError("Could not parse coordinates from LLM response")

    def _validate_coordinates(self, coordinates: List[float]) -> bool:
        """
        Validate the coordinates.
        """
        if not isinstance(coordinates, list):
            return False

        if len(coordinates) != 6:
            logger.error(f"Expected 6 coordinates, got {len(coordinates)}")
            return False

        for i, coord in enumerate(coordinates):
            if not isinstance(coord, (int, float)):
                logger.error(f"Coordinate {i} is not a number: {coord}")
                return False

            if not (0.0 <= coord <= 1.0):
                logger.error(f"Coordinate {i} out of range [0.0, 1.0]: {coord}")
                return False

        return True

    def _get_default_coordinates(self, domain_id: str) -> Dict[str, Any]:
        """
        Default coordinates used when estimation fails.
        """
        logger.warning(f"Using default coordinates for domain '{domain_id}'")

        # The centre of the domain space
        return {
            "coordinates": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
            "reasoning": "Default coordinates (estimation failed)",
            "confidence": 0.3
        }

    async def estimate_batch(
        self,
        knowledge_items: List[Dict[str, str]],
        llm_inference_func,
        max_concurrent: int = 3
    ) -> List[Dict[str, Any]]:
        """
        Estimate coordinates for multiple knowledge items in one batch.

        Args:
            knowledge_items: [{"prompt": "...", "response": "...", "domain_id": "..."}, ...]
            llm_inference_func: LLM inference function
            max_concurrent: Maximum number of concurrent estimations

        Returns:
            List of estimation results
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def estimate_with_semaphore(item):
            async with semaphore:
                return await self.estimate_coordinates(
                    prompt=item["prompt"],
                    response=item["response"],
                    domain_id=item.get("domain_id", "general"),
                    llm_inference_func=llm_inference_func
                )

        tasks = [estimate_with_semaphore(item) for item in knowledge_items]
        results = await asyncio.gather(*tasks)

        return results

    def get_domain_schema(self, domain_id: str) -> Dict[str, str]:
        """
        Return the domain schema (for UI display).
        """
        return self.domain_schemas.get(domain_id, self.domain_schemas["general"])

    def add_domain_schema(self, domain_id: str, schema: Dict[str, str]):
        """
        Add a new domain schema.
        """
        if not all(key in schema for key in ["x", "y", "z"]):
            raise ValueError("Schema must contain 'x', 'y', 'z' definitions")

        self.domain_schemas[domain_id] = schema
        logger.info(f"Added domain schema for '{domain_id}'")

    def interpolate_coordinates(
        self,
        coord1: List[float],
        coord2: List[float],
        weight: float = 0.5
    ) -> List[float]:
        """
        Linearly interpolate between two coordinates
        (used to estimate coordinates for similar knowledge).

        Args:
            coord1: First coordinate
            coord2: Second coordinate
            weight: Interpolation weight (0.0=coord1, 1.0=coord2)

        Returns:
            The interpolated coordinates
        """
        if len(coord1) != 6 or len(coord2) != 6:
            raise ValueError("Both coordinates must be 6-dimensional")

        interpolated = [
            coord1[i] * (1 - weight) + coord2[i] * weight
            for i in range(6)
        ]

        return interpolated
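
A minimal usage sketch for the estimator (stub_llm below stands in for a real model call and is an assumption; any async function returning the JSON format described in _build_estimation_prompt will do):

    import asyncio
    from coordinate_estimator import CoordinateEstimator

    async def stub_llm(prompt: str) -> str:
        # Stand-in for a real LLM; returns one fixed, well-formed estimate.
        return '{"coordinates": [0.2, 0.6, 0.8, 0.7, 0.5, 0.6], "confidence": 0.8}'

    async def main():
        estimator = CoordinateEstimator()
        result = await estimator.estimate_coordinates(
            prompt="What is the first-line treatment for type 2 diabetes?",
            response="Metformin, combined with lifestyle changes.",
            domain_id="medical",
            llm_inference_func=stub_llm,
            use_reasoning=False,
        )
        print(result["coordinates"])   # [0.2, 0.6, 0.8, 0.7, 0.5, 0.6]

    asyncio.run(main())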
db_enrichment.py ADDED
@@ -0,0 +1,523 @@
# null_ai/db_enrichment.py
"""
Database Enrichment Module

Enriches the knowledge base automatically using AI.
DeepSeek generates questions → the master model answers → results are saved to .iath.
"""

import logging
import asyncio
from typing import List, Dict, Any, Optional, Callable
from datetime import datetime
import uuid
import json

logger = logging.getLogger(__name__)


class AIEnrichmentEngine:
    """
    AI-driven knowledge-base enrichment engine.

    Flow:
    1. DeepSeek (the permanent instructor) generates enrichment questions
    2. The master model answers the questions
    3. Coordinates are estimated automatically
    4. The results are saved to .iath
    """

    def __init__(
        self,
        prompt_generator_func,   # LLM function for question generation
        answer_generator_func,   # LLM function for answer generation
        coordinate_estimator,    # CoordinateEstimator
        iath_writer              # IathWriter
    ):
        """
        Args:
            prompt_generator_func: async def(prompt) -> str (uses DeepSeek)
            answer_generator_func: async def(prompt) -> dict (uses the master model)
            coordinate_estimator: a CoordinateEstimator instance
            iath_writer: an IathWriter instance
        """
        self.prompt_generator = prompt_generator_func
        self.answer_generator = answer_generator_func
        self.coordinate_estimator = coordinate_estimator
        self.iath_writer = iath_writer

        self.enrichment_state = {
            "is_running": False,
            "progress": 0.0,
            "current_question": 0,
            "total_questions": 0,
            "generated_tiles": 0,
            "start_time": None,
            "domain_id": None
        }

    async def generate_enrichment_questions(
        self,
        domain_id: str,
        num_questions: int = 10,
        focus_areas: Optional[List[str]] = None
    ) -> List[str]:
        """
        Generate enrichment questions using DeepSeek.

        Args:
            domain_id: Domain ID
            num_questions: Number of questions to generate
            focus_areas: Priority areas (e.g. ["basic theory", "applications", "latest research"])

        Returns:
            The list of generated questions
        """
        logger.info(f"Generating {num_questions} enrichment questions for domain '{domain_id}'")

        # Render the focus areas as text
        focus_text = ""
        if focus_areas:
            focus_text = f"\nFocus on these specific areas: {', '.join(focus_areas)}"

        # Build the generation prompt
        generation_prompt = f"""You are an expert knowledge curator in the domain of {domain_id}.

Your task is to generate {num_questions} diverse, high-quality questions that would enrich a knowledge base in this domain.

Guidelines:
1. Cover a wide range of topics within the domain
2. Include questions at different complexity levels (basic, intermediate, advanced)
3. Focus on practical applications and edge cases
4. Avoid overly broad or trivial questions
5. Each question should elicit detailed, informative answers{focus_text}

Output format:
Return ONLY a JSON array of questions, like this:
["Question 1?", "Question 2?", ...]

Domain: {domain_id}
Number of questions: {num_questions}

Generate the questions now:"""

        try:
            # Generate questions with DeepSeek
            logger.info(f"Calling prompt generator with prompt length: {len(generation_prompt)}")
            response = await self.prompt_generator(generation_prompt)

            # Log the response length
            logger.info(f"Received response from LLM, length: {len(response) if response else 0}")
            if not response:
                logger.error("LLM returned empty response!")
                return []

            logger.debug(f"Response preview (first 200 chars): {response[:200]}")

            # Extract the JSON
            questions = self._extract_questions_from_response(response)

            # De-duplicate
            questions = list(dict.fromkeys(questions))  # preserves order while removing duplicates

            logger.info(f"Generated {len(questions)} unique questions")
            return questions[:num_questions]

        except Exception as e:
            logger.error(f"Failed to generate enrichment questions: {e}", exc_info=True)
            return []

    def _extract_questions_from_response(self, response: str) -> List[str]:
        """
        Extract the question list from the LLM response.
        """
        # Guard against an empty response
        if not response or not response.strip():
            logger.error(f"Empty response from LLM. Response: '{response}'")
            return []

        try:
            # Look for a JSON block
            json_str = ""
            if "```json" in response:
                json_start = response.find("```json") + 7
                json_end = response.find("```", json_start)
                json_str = response[json_start:json_end].strip()
            elif "```" in response:
                json_start = response.find("```") + 3
                json_end = response.find("```", json_start)
                json_str = response[json_start:json_end].strip()
            elif "[" in response and "]" in response:
                # Look for a bare JSON array
                json_start = response.find("[")
                json_end = response.rfind("]") + 1
                json_str = response[json_start:json_end]
            else:
                json_str = response.strip()

            # Guard against an empty json_str
            if not json_str:
                logger.error(f"Could not extract JSON from response. Full response: '{response[:500]}'")
                return self._fallback_parse(response)

            # Parse the JSON
            questions = json.loads(json_str)

            if isinstance(questions, list):
                # Keep only string entries
                return [q for q in questions if isinstance(q, str)]
            else:
                logger.error(f"Expected list, got {type(questions)}")
                return []

        except json.JSONDecodeError as e:
            logger.error(f"JSON parse error: {e}. Attempted to parse: '{json_str[:200]}'")
            logger.error(f"Full response (first 500 chars): '{response[:500]}'")
            return self._fallback_parse(response)

    def _fallback_parse(self, response: str) -> List[str]:
        """
        Fallback when JSON parsing fails:
        extract questions from newline-separated text.
        """
        logger.info("Attempting fallback parsing of response...")
        lines = response.split('\n')
        questions = []
        for line in lines:
            line = line.strip()
            # Strip numbered-list formatting (1. Question -> Question)
            if line and len(line) > 5:  # require at least 5 characters
                if line[0].isdigit() or line.startswith('-') or line.startswith('*'):
                    # Strip leading "1. " or "- "
                    cleaned = line.lstrip('0123456789.-* ').strip()
                    if cleaned and '?' in cleaned:
                        questions.append(cleaned)
                elif '?' in line:
                    # Treat any line containing a question mark as a question
                    questions.append(line)

        logger.info(f"Fallback parsing extracted {len(questions)} questions")
        return questions

    async def enrich_with_ai(
        self,
        domain_id: str,
        num_questions: int = 10,
        focus_areas: Optional[List[str]] = None,
        progress_callback: Optional[Callable] = None
    ) -> Dict[str, Any]:
        """
        Enrich the knowledge base using AI.

        Args:
            domain_id: Domain ID
            num_questions: Number of questions to generate
            focus_areas: Priority areas
            progress_callback: Progress callback, async def(state)

        Returns:
            The enrichment result
        """
        self.enrichment_state.update({
            "is_running": True,
            "progress": 0.0,
            "current_question": 0,
            "total_questions": num_questions,
            "generated_tiles": 0,
            "start_time": datetime.utcnow().isoformat(),
            "domain_id": domain_id
        })

        try:
            # Step 1: generate questions
            logger.info(f"Step 1: Generating {num_questions} questions...")
            questions = await self.generate_enrichment_questions(
                domain_id, num_questions, focus_areas
            )

            if not questions:
                return {
                    "success": False,
                    "error": "Failed to generate questions"
                }

            self.enrichment_state["total_questions"] = len(questions)

            # Step 2: answer each question and save the result
            tiles_created = 0

            for i, question in enumerate(questions):
                logger.info(f"Processing question {i+1}/{len(questions)}: {question[:50]}...")

                self.enrichment_state["current_question"] = i + 1
                self.enrichment_state["progress"] = (i / len(questions)) * 100

                if progress_callback:
                    await progress_callback(self.enrichment_state)

                try:
                    # Generate the answer with the master model
                    answer_result = await self.answer_generator(question)

                    if not answer_result or "response" not in answer_result:
                        logger.warning(f"No response for question: {question}")
                        continue

                    response = answer_result["response"]
                    confidence = answer_result.get("confidence", 0.7)

                    # Inference function used for coordinate estimation
                    async def coord_inference(coord_prompt):
                        result = await self.prompt_generator(coord_prompt)
                        return result

                    # Estimate the coordinates
                    coord_result = await self.coordinate_estimator.estimate_coordinates(
                        prompt=question,
                        response=response,
                        domain_id=domain_id,
                        llm_inference_func=coord_inference,
                        use_reasoning=False
                    )

                    # Build the .iath Tile object
                    from null_ai.iath_writer import create_tile_from_ai_output

                    tile = create_tile_from_ai_output(
                        knowledge_id=f"ai_enrich_{uuid.uuid4().hex}",
                        topic=question[:100],
                        prompt=question,
                        response=response,
                        coordinates=coord_result["coordinates"],
                        confidence=confidence,
                        domain_id=domain_id,
                        source="ai_enrichment"
                    )

                    # Save to .iath
                    success = self.iath_writer.append_tile(tile)

                    if success:
                        tiles_created += 1
                        self.enrichment_state["generated_tiles"] = tiles_created
                        logger.info(f"Saved enrichment tile {tiles_created}: {question[:50]}...")
                    else:
                        logger.warning(f"Failed to save tile for: {question}")

                except Exception as e:
                    logger.error(f"Error processing question '{question}': {e}")
                    continue

            # Done
            self.enrichment_state.update({
                "is_running": False,
                "progress": 100.0,
                "current_question": len(questions)
            })

            if progress_callback:
                await progress_callback(self.enrichment_state)

            return {
                "success": True,
                "questions_generated": len(questions),
                "tiles_created": tiles_created,
                "domain_id": domain_id
            }

        except Exception as e:
            logger.error(f"AI enrichment failed: {e}")
            self.enrichment_state["is_running"] = False
            return {
                "success": False,
                "error": str(e)
            }

    def get_enrichment_status(self) -> Dict[str, Any]:
        """Return the current enrichment status."""
        return self.enrichment_state.copy()

    def stop_enrichment(self):
        """Stop the enrichment run (simplified implementation)."""
        self.enrichment_state["is_running"] = False
        logger.info("Enrichment stop requested")


class WebEnrichmentEngine:
    """
    Knowledge-base enrichment engine driven by web search.

    Flow:
    1. Web search (Brave Search / Google Custom Search)
    2. Extract knowledge from the pages
    3. Convert it to Knowledge Tile format
    4. Save to .iath
    """

    def __init__(
        self,
        search_api_key: Optional[str] = None,
        search_engine: str = "brave",  # "brave" or "google"
        coordinate_estimator=None,
        iath_writer=None
    ):
        """
        Args:
            search_api_key: Search API key
            search_engine: Which search engine to use
            coordinate_estimator: a CoordinateEstimator instance
            iath_writer: an IathWriter instance
        """
        self.api_key = search_api_key
        self.search_engine = search_engine
        self.coordinate_estimator = coordinate_estimator
        self.iath_writer = iath_writer

    async def search_web(self, query: str, max_results: int = 5) -> List[Dict[str, str]]:
        """
        Run a web search.

        Args:
            query: Search query
            max_results: Maximum number of results

        Returns:
            [{"title": "...", "url": "...", "snippet": "..."}, ...]
        """
        if self.search_engine == "brave":
            return await self._search_brave(query, max_results)
        elif self.search_engine == "google":
            return await self._search_google(query, max_results)
        else:
            raise ValueError(f"Unknown search engine: {self.search_engine}")

    async def _search_brave(self, query: str, max_results: int) -> List[Dict[str, str]]:
        """Search with the Brave Search API."""
        try:
            import aiohttp

            if not self.api_key:
                logger.error("Brave Search API key not configured")
                return []

            url = "https://api.search.brave.com/res/v1/web/search"
            headers = {
                "Accept": "application/json",
                "X-Subscription-Token": self.api_key
            }
            params = {
                "q": query,
                "count": max_results
            }

            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers, params=params) as response:
                    if response.status == 200:
                        data = await response.json()
                        results = []

                        for item in data.get("web", {}).get("results", []):
                            results.append({
                                "title": item.get("title", ""),
                                "url": item.get("url", ""),
                                "snippet": item.get("description", "")
                            })

                        return results
                    else:
                        logger.error(f"Brave Search API error: {response.status}")
                        return []

        except Exception as e:
            logger.error(f"Brave Search failed: {e}")
            return []

    async def _search_google(self, query: str, max_results: int) -> List[Dict[str, str]]:
        """Search with the Google Custom Search API."""
        # TODO: implement Google Custom Search
        logger.warning("Google Custom Search not implemented yet")
        return []

    async def enrich_from_web(
        self,
        query: str,
        domain_id: str,
        max_results: int = 5
    ) -> Dict[str, Any]:
        """
        Enrich the knowledge base from web search results.

        Args:
            query: Search query
            domain_id: Domain ID
            max_results: Maximum number of results

        Returns:
            The enrichment result
        """
        try:
            # Run the web search
            search_results = await self.search_web(query, max_results)

            if not search_results:
                return {
                    "success": False,
                    "error": "No search results found"
                }

            # Convert each result into a Knowledge Tile
            tiles_created = 0

            for result in search_results:
                try:
                    # Build the Knowledge Tile (web-search variant)
                    tile_content = f"""Title: {result['title']}

Source: {result['url']}

Summary:
{result['snippet']}
"""

                    # Coordinate estimation (simplified: default coordinates)
                    # TODO: smarter coordinate estimation
                    coordinates = [0.5, 0.5, 0.5, 0.6, 0.5, 0.7]

                    from null_ai.iath_writer import create_tile_from_ai_output

                    tile = create_tile_from_ai_output(
                        knowledge_id=f"web_{uuid.uuid4().hex}",
                        topic=result['title'],
                        prompt=query,
                        response=tile_content,
                        coordinates=coordinates,
                        confidence=0.6,  # web results get medium confidence
                        domain_id=domain_id,
                        source="web_search"
                    )

                    # Save to .iath
                    success = self.iath_writer.append_tile(tile)

                    if success:
                        tiles_created += 1
                        logger.info(f"Saved web tile: {result['title']}")

                except Exception as e:
                    logger.error(f"Error creating tile from search result: {e}")
                    continue

            return {
                "success": True,
                "search_results": len(search_results),
                "tiles_created": tiles_created,
                "query": query,
                "domain_id": domain_id
            }

        except Exception as e:
            logger.error(f"Web enrichment failed: {e}")
            return {
                "success": False,
                "error": str(e)
            }
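
A wiring sketch for AIEnrichmentEngine (the stubs below are assumptions; note that enrich_with_ai imports create_tile_from_ai_output from null_ai.iath_writer internally, so the real package must be importable for tiles to be built):

    import asyncio
    from coordinate_estimator import CoordinateEstimator
    from db_enrichment import AIEnrichmentEngine

    async def instructor_stub(prompt: str) -> str:
        # One function covers both roles DeepSeek plays here: it answers the
        # coordinate-estimation prompt with a coordinate JSON and everything
        # else with a question array.
        if "6-dimensional coordinates" in prompt:
            return '{"coordinates": [0.5, 0.4, 0.6, 0.6, 0.5, 0.5], "confidence": 0.7}'
        return '["What is LoRA?", "How does QLoRA reduce memory use?"]'

    async def master_stub(question: str) -> dict:
        return {"response": f"(answer to: {question})", "confidence": 0.75}

    class StubWriter:
        # The engine only ever calls append_tile() on the writer.
        def append_tile(self, tile) -> bool:
            return True

    engine = AIEnrichmentEngine(instructor_stub, master_stub, CoordinateEstimator(), StubWriter())
    result = asyncio.run(engine.enrich_with_ai(domain_id="technology", num_questions=2))
    print(result)  # expected: {"success": True, "questions_generated": 2, "tiles_created": 2, ...}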
fine_tuning.py ADDED
@@ -0,0 +1,627 @@
# null_ai/fine_tuning.py
"""
NullAI Fine-tuning Module
Implements apprentice model fine-tuning using master outputs (Alpaca format)
"""

import os
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable
from datetime import datetime
import asyncio

logger = logging.getLogger(__name__)


class FineTuningManager:
    """
    Manages fine-tuning of apprentice models using master outputs.
    Supports multiple backends: HuggingFace (PEFT/LoRA), Unsloth, MLX
    """

    def __init__(self, training_data_dir: str = "training_data/master_outputs"):
        self.training_data_dir = Path(training_data_dir)
        self.checkpoints_dir = Path("training_data/checkpoints")
        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)

        self.current_training_state = {
            "is_training": False,
            "progress": 0.0,
            "current_epoch": 0,
            "total_epochs": 0,
            "loss": 0.0,
            "model_id": None,
            "start_time": None
        }

    # ===== Training Data Loading =====

    def load_training_data(self, domain_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Load training data from Alpaca-format JSONL files.

        Args:
            domain_id: Specific domain to load. If None, loads all domains.

        Returns:
            List of training examples in Alpaca format
        """
        training_examples = []

        if not self.training_data_dir.exists():
            logger.warning(f"Training data directory not found: {self.training_data_dir}")
            return training_examples

        # Determine which files to load
        if domain_id:
            jsonl_files = [self.training_data_dir / f"master_outputs_{domain_id}.jsonl"]
        else:
            jsonl_files = list(self.training_data_dir.glob("master_outputs_*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                logger.warning(f"Training data file not found: {jsonl_file}")
                continue

            logger.info(f"Loading training data from: {jsonl_file}")
            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        example = json.loads(line.strip())
                        training_examples.append(example)
                    except json.JSONDecodeError as e:
                        logger.error(f"Failed to parse JSON line in {jsonl_file}: {e}")
                        continue

        logger.info(f"Loaded {len(training_examples)} training examples")
        return training_examples

    def format_training_examples_for_model(
        self,
        training_examples: List[Dict[str, Any]],
        template: str = "alpaca"
    ) -> List[str]:
        """
        Format training examples into model-ready prompts.

        Args:
            training_examples: Raw Alpaca-format examples
            template: Prompt template format ("alpaca", "chatml", "llama3")

        Returns:
            List of formatted prompt strings
        """
        formatted_prompts = []

        for example in training_examples:
            instruction = example.get("instruction", "")
            input_text = example.get("input", "")
            output_text = example.get("output", "")

            if template == "alpaca":
                if input_text:
                    prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input_text}

### Response:
{output_text}"""
                else:
                    prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:
{output_text}"""

            elif template == "chatml":
                prompt = f"""<|im_start|>system
{instruction}<|im_end|>
<|im_start|>user
{input_text}<|im_end|>
<|im_start|>assistant
{output_text}<|im_end|>"""

            elif template == "llama3":
                prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{instruction}<|eot_id|><|start_header_id|>user<|end_header_id|>
{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{output_text}<|eot_id|>"""

            else:
                raise ValueError(f"Unknown template format: {template}")

            formatted_prompts.append(prompt)

        return formatted_prompts
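
    # Example (illustrative, derived from the template above): the Alpaca
    # template with an empty "input" turns
    #   {"instruction": "Summarize LoRA.", "input": "", "output": "LoRA injects low-rank adapters..."}
    # into the single training string
    #   Below is an instruction that describes a task. Write a response that
    #   appropriately completes the request.
    #
    #   ### Instruction:
    #   Summarize LoRA.
    #
    #   ### Response:
    #   LoRA injects low-rank adapters...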
144
+
145
+ # ===== Fine-tuning Backends =====
146
+
147
+ async def fine_tune_with_huggingface_peft(
148
+ self,
149
+ model_name: str,
150
+ training_examples: List[Dict[str, Any]],
151
+ output_dir: str,
152
+ epochs: int = 3,
153
+ learning_rate: float = 2e-4,
154
+ batch_size: int = 4,
155
+ gradient_accumulation_steps: int = 4,
156
+ lora_r: int = 8,
157
+ lora_alpha: int = 16,
158
+ lora_dropout: float = 0.05,
159
+ max_seq_length: int = 512,
160
+ progress_callback: Optional[Callable] = None
161
+ ) -> Dict[str, Any]:
162
+ """
163
+ Fine-tune model using HuggingFace Transformers + PEFT (LoRA).
164
+
165
+ This is the recommended method for most models.
166
+ Uses QLoRA (4-bit quantization) for memory efficiency.
167
+ """
168
+ try:
169
+ import torch
170
+ from transformers import (
171
+ AutoModelForCausalLM,
172
+ AutoTokenizer,
173
+ TrainingArguments,
174
+ Trainer,
175
+ DataCollatorForLanguageModeling
176
+ )
177
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
178
+ from datasets import Dataset
179
+
180
+ except ImportError as e:
181
+ logger.error(f"Required libraries not installed: {e}")
182
+ logger.error("Please install: pip install transformers peft datasets bitsandbytes accelerate")
183
+ raise
184
+
185
+ logger.info(f"Starting PEFT fine-tuning for model: {model_name}")
186
+ self.current_training_state.update({
187
+ "is_training": True,
188
+ "progress": 0.0,
189
+ "current_epoch": 0,
190
+ "total_epochs": epochs,
191
+ "model_id": model_name,
192
+ "start_time": datetime.utcnow().isoformat()
193
+ })
194
+
195
+ # 1. Load tokenizer
196
+ logger.info("Loading tokenizer...")
197
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
198
+ if tokenizer.pad_token is None:
199
+ tokenizer.pad_token = tokenizer.eos_token
200
+
201
+ # 2. Load model with 4-bit quantization (QLoRA)
202
+ logger.info("Loading model with 4-bit quantization...")
203
+ try:
204
+ from transformers import BitsAndBytesConfig
205
+
206
+ bnb_config = BitsAndBytesConfig(
207
+ load_in_4bit=True,
208
+ bnb_4bit_quant_type="nf4",
209
+ bnb_4bit_compute_dtype=torch.float16,
210
+ bnb_4bit_use_double_quant=True
211
+ )
212
+
213
+ model = AutoModelForCausalLM.from_pretrained(
214
+ model_name,
215
+ quantization_config=bnb_config,
216
+ device_map="auto",
217
+ trust_remote_code=True
218
+ )
219
+ except Exception as e:
220
+ logger.warning(f"4-bit quantization failed, falling back to float16: {e}")
221
+ model = AutoModelForCausalLM.from_pretrained(
222
+ model_name,
223
+ torch_dtype=torch.float16,
224
+ device_map="auto",
225
+ trust_remote_code=True
226
+ )
227
+
228
+ # 3. Prepare model for training
229
+ model = prepare_model_for_kbit_training(model)
230
+
231
+ # 4. Configure LoRA
232
+ logger.info("Configuring LoRA...")
233
+ lora_config = LoraConfig(
234
+ r=lora_r,
235
+ lora_alpha=lora_alpha,
236
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
237
+ lora_dropout=lora_dropout,
238
+ bias="none",
239
+ task_type="CAUSAL_LM"
240
+ )
241
+
242
+ model = get_peft_model(model, lora_config)
243
+ model.print_trainable_parameters()
244
+
245
+ # 5. Format training data
246
+ logger.info("Formatting training data...")
247
+ formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")
248
+
249
+ # 6. Tokenize dataset
250
+ def tokenize_function(examples):
251
+ return tokenizer(
252
+ examples["text"],
253
+ truncation=True,
254
+ max_length=max_seq_length,
255
+ padding="max_length"
256
+ )
257
+
258
+ dataset = Dataset.from_dict({"text": formatted_texts})
259
+ tokenized_dataset = dataset.map(
260
+ tokenize_function,
261
+ batched=True,
262
+ remove_columns=dataset.column_names
263
+ )
264
+
265
+ # 7. Training arguments
266
+ training_args = TrainingArguments(
267
+ output_dir=output_dir,
268
+ num_train_epochs=epochs,
269
+ per_device_train_batch_size=batch_size,
270
+ gradient_accumulation_steps=gradient_accumulation_steps,
271
+ learning_rate=learning_rate,
272
+ fp16=True,
273
+ logging_steps=10,
274
+ save_steps=100,
275
+ save_total_limit=3,
276
+ warmup_steps=50,
277
+ optim="paged_adamw_8bit",
278
+ report_to="none" # Disable wandb/tensorboard for now
279
+ )
280
+
281
+ # 8. Data collator
282
+ data_collator = DataCollatorForLanguageModeling(
283
+ tokenizer=tokenizer,
284
+ mlm=False
285
+ )
286
+
287
+ # 9. Create trainer with progress callback
288
+ class ProgressCallback:
289
+ def __init__(self, manager, total_epochs, callback):
290
+ self.manager = manager
291
+ self.total_epochs = total_epochs
292
+ self.callback = callback
293
+
294
+ def on_epoch_end(self, args, state, control, **kwargs):
295
+ epoch = state.epoch
296
+ loss = state.log_history[-1].get("loss", 0.0) if state.log_history else 0.0
297
+
298
+ self.manager.current_training_state.update({
299
+ "current_epoch": int(epoch),
300
+ "progress": (epoch / self.total_epochs) * 100,
301
+ "loss": loss
302
+ })
303
+
304
+ if self.callback:
305
+ asyncio.create_task(self.callback(self.manager.current_training_state))
306
+
307
+ from transformers import TrainerCallback
308
+
309
+ class CustomCallback(TrainerCallback):
310
+ def __init__(self, progress_cb):
311
+ self.progress_cb = progress_cb
312
+
313
+ def on_epoch_end(self, args, state, control, **kwargs):
314
+ self.progress_cb.on_epoch_end(args, state, control, **kwargs)
315
+
316
+ progress_cb = ProgressCallback(self, epochs, progress_callback)
317
+
318
+ trainer = Trainer(
319
+ model=model,
320
+ args=training_args,
321
+ train_dataset=tokenized_dataset,
322
+ data_collator=data_collator,
323
+ callbacks=[CustomCallback(progress_cb)]
324
+ )
325
+
326
+ # 10. Train!
327
+ logger.info("Starting training...")
328
+ train_result = trainer.train()
329
+
330
+ # 11. Save final model
331
+ logger.info(f"Saving model to: {output_dir}")
332
+ trainer.save_model(output_dir)
333
+ tokenizer.save_pretrained(output_dir)
334
+
335
+ # 12. Update state
336
+ self.current_training_state.update({
337
+ "is_training": False,
338
+ "progress": 100.0,
339
+ "current_epoch": epochs
340
+ })
341
+
342
+ return {
343
+ "success": True,
344
+ "output_dir": output_dir,
345
+ "train_loss": train_result.training_loss,
346
+ "metrics": train_result.metrics,
347
+ "model_name": model_name,
348
+ "lora_config": {
349
+ "r": lora_r,
350
+ "alpha": lora_alpha,
351
+ "dropout": lora_dropout
352
+ }
353
+ }
354
+
355
+ async def fine_tune_with_unsloth(
356
+ self,
357
+ model_name: str,
358
+ training_examples: List[Dict[str, Any]],
359
+ output_dir: str,
360
+ epochs: int = 3,
361
+ learning_rate: float = 2e-4,
362
+ batch_size: int = 4,
363
+ lora_r: int = 16,
364
+ progress_callback: Optional[Callable] = None
365
+ ) -> Dict[str, Any]:
366
+ """
367
+ Fine-tune model using Unsloth (fastest method, 2x faster than PEFT).
368
+
369
+ Unsloth is optimized for speed and memory efficiency.
370
+ Recommended for: Llama, Mistral, Qwen models
371
+ """
372
+ try:
373
+ from unsloth import FastLanguageModel
374
+ from trl import SFTTrainer
375
+ from transformers import TrainingArguments
376
+ from datasets import Dataset
377
+ except ImportError as e:
378
+ logger.error(f"Unsloth not installed: {e}")
379
+ logger.error("Please install: pip install unsloth")
380
+ raise
381
+
382
+ logger.info(f"Starting Unsloth fine-tuning for model: {model_name}")
383
+ self.current_training_state.update({
384
+ "is_training": True,
385
+ "progress": 0.0,
386
+ "current_epoch": 0,
387
+ "total_epochs": epochs,
388
+ "model_id": model_name,
389
+ "start_time": datetime.utcnow().isoformat()
390
+ })
391
+
392
+ # 1. Load model with Unsloth
393
+ logger.info("Loading model with Unsloth...")
394
+ model, tokenizer = FastLanguageModel.from_pretrained(
395
+ model_name=model_name,
396
+ max_seq_length=2048,
397
+ dtype=None, # Auto-detect
398
+ load_in_4bit=True
399
+ )
400
+
401
+ # 2. Add LoRA adapters
402
+ model = FastLanguageModel.get_peft_model(
403
+ model,
404
+ r=lora_r,
405
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
406
+ lora_alpha=16,
407
+ lora_dropout=0,
408
+ bias="none",
409
+ use_gradient_checkpointing=True,
410
+ random_state=42
411
+ )
412
+
413
+ # 3. Format training data
414
+ formatted_texts = self.format_training_examples_for_model(training_examples, template="alpaca")
415
+ dataset = Dataset.from_dict({"text": formatted_texts})
416
+
417
+ # 4. Training arguments
418
+ training_args = TrainingArguments(
419
+ output_dir=output_dir,
420
+ num_train_epochs=epochs,
421
+ per_device_train_batch_size=batch_size,
422
+ learning_rate=learning_rate,
423
+ fp16=True,
424
+ logging_steps=10,
425
+ save_steps=100,
426
+ warmup_steps=50
427
+ )
428
+
429
+ # 5. Create SFT trainer
430
+ trainer = SFTTrainer(
431
+ model=model,
432
+ tokenizer=tokenizer,
433
+ train_dataset=dataset,
434
+ dataset_text_field="text",
435
+ max_seq_length=2048,
436
+ args=training_args
437
+ )
438
+
439
+ # 6. Train
440
+ logger.info("Starting training with Unsloth...")
441
+ trainer.train()
442
+
443
+ # 7. Save
444
+ logger.info(f"Saving model to: {output_dir}")
445
+ model.save_pretrained(output_dir)
446
+ tokenizer.save_pretrained(output_dir)
447
+
448
+ self.current_training_state.update({
449
+ "is_training": False,
450
+ "progress": 100.0
451
+ })
452
+
453
+ return {
454
+ "success": True,
455
+ "output_dir": output_dir,
456
+ "model_name": model_name,
457
+ "method": "unsloth"
458
+ }
459
+
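The Unsloth path reuses format_training_examples_for_model(..., template="alpaca") defined earlier in this file. For reference, a minimal Alpaca-style formatter over the {"input", "output"} examples used here could look like the sketch below; the exact wording and field handling of the real template are assumptions:

def format_alpaca_example(example: dict) -> str:
    # Minimal Alpaca-style prompt; the template actually used in this file may differ.
    instruction = example.get("instruction", "")
    user_input = example.get("input", "")
    output = example.get("output", "")
    return (
        "Below is an instruction that describes a task.\n\n"
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{user_input}\n\n"
        f"### Response:\n{output}"
    )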
460
+ async def fine_tune_with_mlx(
461
+ self,
462
+ model_name: str,
463
+ training_examples: List[Dict[str, Any]],
464
+ output_dir: str,
465
+ epochs: int = 3,
466
+ learning_rate: float = 1e-5,
467
+ batch_size: int = 4,
468
+ lora_r: int = 8,
469
+ progress_callback: Optional[Callable] = None
470
+ ) -> Dict[str, Any]:
471
+ """
472
+ Fine-tune model using MLX (Apple Silicon only, ultra-fast).
473
+
474
+ Optimized for M1/M2/M3 Macs.
475
+ Uses unified memory for maximum efficiency.
476
+ """
477
+ try:
478
+ import mlx.core as mx
479
+ from mlx_lm import load, generate
480
+ import mlx.optimizers as optim
481
+ import mlx.nn as nn
482
+ except ImportError as e:
483
+ logger.error(f"MLX not installed: {e}")
484
+ logger.error("Please install: pip install mlx mlx-lm")
485
+ raise
486
+
487
+ logger.info(f"Starting MLX fine-tuning for model: {model_name}")
488
+ self.current_training_state.update({
489
+ "is_training": True,
490
+ "progress": 0.0,
491
+ "current_epoch": 0,
492
+ "total_epochs": epochs,
493
+ "model_id": model_name,
494
+ "start_time": datetime.utcnow().isoformat()
495
+ })
496
+
497
+ # Note: MLX fine-tuning is still experimental
498
+ # For now, return a placeholder
499
+ logger.warning("MLX fine-tuning is not fully implemented yet")
500
+
501
+ self.current_training_state["is_training"] = False
502
+
503
+ return {
504
+ "success": False,
505
+ "error": "MLX fine-tuning not yet implemented",
506
+ "model_name": model_name
507
+ }
508
+
509
+ # ===== Main Training Interface =====
510
+
511
+ async def start_training(
512
+ self,
513
+ apprentice_model_name: str,
514
+ domain_id: Optional[str] = None,
515
+ method: str = "peft", # "peft", "unsloth", "mlx"
516
+ epochs: int = 3,
517
+ learning_rate: float = 2e-4,
518
+ batch_size: int = 4,
519
+ output_name: Optional[str] = None,
520
+ progress_callback: Optional[Callable] = None
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Main entry point for fine-tuning an apprentice model.
524
+
525
+ Args:
526
+ apprentice_model_name: HuggingFace model name or path
527
+ domain_id: Domain to train on (None = all domains)
528
+ method: Training method ("peft", "unsloth", "mlx")
529
+ epochs: Number of training epochs
530
+ learning_rate: Learning rate
531
+ batch_size: Batch size per device
532
+ output_name: Custom name for output directory
533
+ progress_callback: Async callback for progress updates
534
+
535
+ Returns:
536
+ Training result dictionary
537
+ """
538
+ # 1. Load training data
539
+ training_examples = self.load_training_data(domain_id)
540
+
541
+ if not training_examples:
542
+ return {
543
+ "success": False,
544
+ "error": "No training data found"
545
+ }
546
+
547
+ # 2. Prepare output directory
548
+ if output_name is None:
549
+ timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
550
+ output_name = f"apprentice_{domain_id or 'all'}_{timestamp}"
551
+
552
+ output_dir = self.checkpoints_dir / output_name
553
+ output_dir.mkdir(parents=True, exist_ok=True)
554
+
555
+ # 3. Select training method
556
+ if method == "peft":
557
+ result = await self.fine_tune_with_huggingface_peft(
558
+ model_name=apprentice_model_name,
559
+ training_examples=training_examples,
560
+ output_dir=str(output_dir),
561
+ epochs=epochs,
562
+ learning_rate=learning_rate,
563
+ batch_size=batch_size,
564
+ progress_callback=progress_callback
565
+ )
566
+
567
+ elif method == "unsloth":
568
+ result = await self.fine_tune_with_unsloth(
569
+ model_name=apprentice_model_name,
570
+ training_examples=training_examples,
571
+ output_dir=str(output_dir),
572
+ epochs=epochs,
573
+ learning_rate=learning_rate,
574
+ batch_size=batch_size,
575
+ progress_callback=progress_callback
576
+ )
577
+
578
+ elif method == "mlx":
579
+ result = await self.fine_tune_with_mlx(
580
+ model_name=apprentice_model_name,
581
+ training_examples=training_examples,
582
+ output_dir=str(output_dir),
583
+ epochs=epochs,
584
+ learning_rate=learning_rate,
585
+ batch_size=batch_size,
586
+ progress_callback=progress_callback
587
+ )
588
+
589
+ else:
590
+ return {
591
+ "success": False,
592
+ "error": f"Unknown training method: {method}"
593
+ }
594
+
595
+ return result
596
+
597
+ def get_training_status(self) -> Dict[str, Any]:
598
+ """Get current training status."""
599
+ return self.current_training_state.copy()
600
+
601
+ def stop_training(self):
602
+ """Stop current training (if possible)."""
603
+ # TODO: Implement graceful training interruption
604
+ logger.warning("Training interruption not yet implemented")
605
+ self.current_training_state["is_training"] = False
606
+
607
+ def get_training_metrics(self, checkpoint_dir: str) -> Dict[str, Any]:
608
+ """
609
+ Load training metrics from a checkpoint.
610
+ """
611
+ checkpoint_path = Path(checkpoint_dir)
612
+
613
+ if not checkpoint_path.exists():
614
+ return {"error": "Checkpoint not found"}
615
+
616
+ # Look for trainer_state.json
617
+ trainer_state_file = checkpoint_path / "trainer_state.json"
618
+ if trainer_state_file.exists():
619
+ with open(trainer_state_file, 'r') as f:
620
+ trainer_state = json.load(f)
621
+ return {
622
+ "log_history": trainer_state.get("log_history", []),
623
+ "best_metric": trainer_state.get("best_metric"),
624
+ "best_model_checkpoint": trainer_state.get("best_model_checkpoint")
625
+ }
626
+
627
+ return {"error": "No metrics found"}
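A minimal sketch of driving the main training interface above, assuming a FineTuningManager-like instance named manager; the model name is an example:

import asyncio

async def run_training(manager):
    async def on_progress(state: dict):
        print(f"epoch {state.get('current_epoch')} - {state.get('progress', 0.0):.1f}%")

    result = await manager.start_training(
        apprentice_model_name="meta-llama/Llama-3.2-1B",  # example model name
        domain_id=None,                                   # None = all domains
        method="peft",
        epochs=3,
        progress_callback=on_progress,
    )
    print(result.get("success"), result.get("output_dir"))

# asyncio.run(run_training(manager))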
iath_memory.py ADDED
@@ -0,0 +1,370 @@
1
+ """
2
+ NullAI - .iath Memory System
3
+ 樹木型空間記憶(Dendritic Memory Space)の実装
4
+
5
+ .iathファイル形式との完全互換性を持つ知識検索システム
6
+ 6次元座標系による空間的RAG(Retrieval-Augmented Generation)
7
+ """
8
+
9
+ import struct
10
+ import zstandard as zstd
11
+ import json
12
+ import logging
13
+ import numpy as np
14
+ from typing import List, Dict, Any, Optional, Tuple
15
+ from pathlib import Path
16
+ from datetime import datetime
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class IathDecoder:
22
+ """
23
+ .iathファイル形式のデコーダー
24
+ dendritic-memory-editorとの互換性を保持
25
+ """
26
+
27
+ def __init__(self, iath_file_path: str):
28
+ """
29
+ Args:
30
+ iath_file_path: .iathファイルのパス
31
+ """
32
+ self.file_path = Path(iath_file_path)
33
+ self.header = None
34
+ self.index = []
35
+ self.data_section_offset = 0
36
+
37
+ if self.file_path.exists():
38
+ self._load_header_and_index()
39
+ else:
40
+ logger.warning(f".iath file not found: {iath_file_path}")
41
+
42
+ def _load_header_and_index(self):
43
+ """ヘッダーとインデックスセクションを読み込む"""
44
+ try:
45
+ with open(self.file_path, 'rb') as f:
46
+ # ヘッダー読み込み (64 bytes)
47
+ header_data = f.read(64)
48
+ magic, version, domain_code, compression_type, checksum, index_offset, data_offset = struct.unpack(
49
+ "<4sIBB32sQQ6x", header_data
50
+ )
51
+
52
+ self.header = {
53
+ "magic": magic.decode('ascii'),
54
+ "version": version,
55
+ "domain_code": domain_code,
56
+ "compression_type": compression_type,
57
+ "index_offset": index_offset,
58
+ "data_offset": data_offset
59
+ }
60
+
61
+ logger.info(f"Loaded .iath header: magic={self.header['magic']}, domain={domain_code}")
62
+
63
+ # インデックス読み込み
64
+ f.seek(index_offset)
65
+ index_binary = f.read(data_offset - index_offset)
66
+ self.index = json.loads(index_binary.decode('utf-8'))
67
+
68
+ self.data_section_offset = data_offset
69
+ logger.info(f"Loaded {len(self.index)} tiles from index")
70
+
71
+ except Exception as e:
72
+ logger.error(f"Error loading .iath file: {e}")
73
+ raise
74
+
75
+ def _decode_string(self, data: bytes, offset: int) -> Tuple[str, int]:
76
+ """NULL終端文字列をデコード"""
77
+ end = data.find(b'\0', offset)
78
+ if end == -1:
79
+ return data[offset:].decode('utf-8'), len(data)
80
+ return data[offset:end].decode('utf-8'), end + 1
81
+
82
+ def _decode_tile_data(self, compressed_data: bytes) -> dict:
83
+ """圧縮されたタイルデータをデコード"""
84
+ # zstd解凍
85
+ dctx = zstd.ZstdDecompressor()
86
+ uncompressed = dctx.decompress(compressed_data)
87
+
88
+ offset = 0
89
+
90
+ # メタデータ
91
+ metadata_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
92
+ offset += 4
93
+ metadata_bin = uncompressed[offset:offset+metadata_len]
94
+ offset += metadata_len
95
+
96
+ # メタデータ内の文字列をパース
97
+ meta_offset = 0
98
+ knowledge_id, meta_offset = self._decode_string(metadata_bin, meta_offset)
99
+ topic, meta_offset = self._decode_string(metadata_bin, meta_offset)
100
+ created_at = metadata_bin[meta_offset:meta_offset+27].decode('ascii').strip('\0')
101
+
102
+ metadata = {
103
+ "knowledge_id": knowledge_id,
104
+ "topic": topic,
105
+ "created_at": created_at
106
+ }
107
+
108
+ # 座標
109
+ coord_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
110
+ offset += 4
111
+ coord_data = uncompressed[offset:offset+coord_len]
112
+ offset += coord_len
113
+
114
+ x, y, z, c, g, v = struct.unpack("<ffffff", coord_data)
115
+ coordinates = {
116
+ "medical_space": [x, y, z],
117
+ "meta_space": [c, g, v]
118
+ }
119
+
120
+ # コンテンツ
121
+ content_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
122
+ offset += 4
123
+ content_data = uncompressed[offset:offset+content_len]
124
+ offset += content_len
125
+
126
+ # thinking_process
127
+ thinking_len = struct.unpack("<I", content_data[0:4])[0]
128
+ thinking = content_data[4:4+thinking_len].decode('utf-8')
129
+
130
+ # final_response
131
+ response_offset = 4 + thinking_len
132
+ response_len = struct.unpack("<I", content_data[response_offset:response_offset+4])[0]
133
+ response = content_data[response_offset+4:response_offset+4+response_len].decode('utf-8')
134
+
135
+ content = {
136
+ "thinking_process": thinking,
137
+ "final_response": response
138
+ }
139
+
140
+ # 検証(簡易実装)
141
+ verification_len = struct.unpack("<I", uncompressed[offset:offset+4])[0]
142
+ offset += 4
143
+ verification_data = uncompressed[offset:offset+verification_len]
144
+
145
+ status_code, initial_certainty, reviewer_count = struct.unpack("<BBI", verification_data[:6])
146
+ status_map = {0: "pending_review", 1: "partial_verified", 2: "verified", 3: "expert_confirmed"}
147
+
148
+ verification = {
149
+ "status": status_map.get(status_code, "pending_review"),
150
+ "initial_certainty": initial_certainty / 100.0,
151
+ "reviewers": [] # 詳細は省略
152
+ }
153
+
154
+ return {
155
+ "metadata": metadata,
156
+ "coordinates": coordinates,
157
+ "content": content,
158
+ "verification": verification
159
+ }
160
+
161
+ def get_tile_by_id(self, knowledge_id: str) -> Optional[dict]:
162
+ """IDで特定のタイルを取得"""
163
+ try:
164
+ # インデックスから検索
165
+ index_entry = next((entry for entry in self.index if entry["id"] == knowledge_id), None)
166
+ if not index_entry:
167
+ return None
168
+
169
+ # データセクションから読み込み
170
+ with open(self.file_path, 'rb') as f:
171
+ f.seek(self.data_section_offset + index_entry["offset"])
172
+ compressed_data = f.read(index_entry["length"])
173
+ return self._decode_tile_data(compressed_data)
174
+
175
+ except Exception as e:
176
+ logger.error(f"Error loading tile {knowledge_id}: {e}")
177
+ return None
178
+
179
+ def get_all_tiles(self) -> List[dict]:
180
+ """全タイルを取得(メモリに注意)"""
181
+ tiles = []
182
+ for entry in self.index:
183
+ tile = self.get_tile_by_id(entry["id"])
184
+ if tile:
185
+ tiles.append(tile)
186
+ return tiles
187
+
188
+
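A short usage sketch of the decoder; the file path and knowledge_id are illustrative:

decoder = IathDecoder("knowledge_base.iath")
tile = decoder.get_tile_by_id("123e4567-e89b-12d3-a456-426614174000")
if tile:
    print(tile["metadata"]["topic"])
    print(tile["coordinates"]["medical_space"], tile["coordinates"]["meta_space"])
    print(tile["verification"]["status"])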
189
+ class DendriticMemorySpace:
190
+ """
191
+ 樹木型空間記憶システム
192
+
193
+ 6次元座標系による知識の空間配置と検索:
194
+ - medical_space [x, y, z]: ドメイン固有の3次元空間
195
+ - meta_space [c, g, v]: Certainty, Granularity, Verification
196
+ """
197
+
198
+ def __init__(self, iath_file_path: str):
199
+ """
200
+ Args:
201
+ iath_file_path: .iathファイルのパス
202
+ """
203
+ self.decoder = IathDecoder(iath_file_path)
204
+ self.tiles_cache = [] # 全タイルをメモリにキャッシュ(最適化可能)
205
+ self._build_spatial_index()
206
+
207
+ def _build_spatial_index(self):
208
+ """空間インデックスを構築"""
209
+ logger.info("Building spatial index from .iath file...")
210
+ self.tiles_cache = self.decoder.get_all_tiles()
211
+
212
+ # 座標行列を構築(高速検索用)
213
+ if self.tiles_cache:
214
+ self.coordinates_matrix = np.array([
215
+ tile["coordinates"]["medical_space"] + tile["coordinates"]["meta_space"]
216
+ for tile in self.tiles_cache
217
+ ])
218
+ logger.info(f"Spatial index built: {len(self.tiles_cache)} tiles")
219
+ else:
220
+ self.coordinates_matrix = np.array([])
221
+ logger.warning("No tiles found in .iath file")
222
+
223
+ def search_by_coordinates(self, query_coords: List[float], top_k: int = 5, distance_threshold: Optional[float] = None) -> List[dict]:
224
+ """
225
+ 座標ベースの空間検索(樹木型空間記憶の核心機能)
226
+
227
+ Args:
228
+ query_coords: クエリ座標 [x, y, z, c, g, v]
229
+ top_k: 返却する上位K件
230
+ distance_threshold: 距離閾値(Noneなら無制限)
231
+
232
+ Returns:
233
+ 関連するタイルのリスト(距離の近い順)
234
+ """
235
+ if len(self.tiles_cache) == 0:
236
+ return []
237
+
238
+ query_vector = np.array(query_coords)
239
+
240
+ # ユークリッド距離を計算
241
+ distances = np.linalg.norm(self.coordinates_matrix - query_vector, axis=1)
242
+
243
+ # 距離でソート
244
+ sorted_indices = np.argsort(distances)
245
+
246
+ # top_k件を取得
247
+ results = []
248
+ for idx in sorted_indices[:top_k]:
249
+ distance = distances[idx]
250
+ if distance_threshold is not None and distance > distance_threshold:
251
+ break
252
+
253
+ tile = self.tiles_cache[idx].copy()
254
+ tile["spatial_distance"] = float(distance)
255
+ results.append(tile)
256
+
257
+ return results
258
+
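For example, a nearest-neighbor query against the 6-D space, assuming a DendriticMemorySpace instance named memory; the coordinate values are illustrative:

query = [0.4, 0.7, 0.2,   # medical_space [x, y, z]
         0.9, 0.5, 0.8]   # meta_space   [c, g, v]
for hit in memory.search_by_coordinates(query, top_k=3, distance_threshold=1.5):
    print(f"{hit['metadata']['topic']}  distance={hit['spatial_distance']:.3f}")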
259
+ def search_by_text(self, query_text: str, top_k: int = 5) -> List[dict]:
260
+ """
261
+ テキスト検索(簡易実装:キーワードマッチング)
262
+
263
+ Args:
264
+ query_text: 検索クエリテキスト
265
+ top_k: 返却する上位K件
266
+
267
+ Returns:
268
+ 関連するタイルのリスト
269
+ """
270
+ query_lower = query_text.lower()
271
+ matches = []
272
+
273
+ for tile in self.tiles_cache:
274
+ topic = tile["metadata"]["topic"].lower()
275
+ content = tile["content"]["final_response"].lower()
276
+
277
+ # シンプルなスコアリング(含まれる回数)
278
+ score = topic.count(query_lower) * 2 + content.count(query_lower)
279
+
280
+ if score > 0:
281
+ tile_copy = tile.copy()
282
+ tile_copy["text_match_score"] = score
283
+ matches.append(tile_copy)
284
+
285
+ # スコアでソート
286
+ matches.sort(key=lambda x: x["text_match_score"], reverse=True)
287
+
288
+ return matches[:top_k]
289
+
290
+ def hybrid_search(self, query_text: str, query_coords: Optional[List[float]] = None, top_k: int = 5) -> List[dict]:
291
+ """
292
+ ハイブリッド検索:テキスト + 空間座標
293
+
294
+ Args:
295
+ query_text: 検索クエリテキスト
296
+ query_coords: クエリ座標(オプション)
297
+ top_k: 返却する上位K件
298
+
299
+ Returns:
300
+ 関連するタイルのリスト
301
+ """
302
+ # テキスト検索
303
+ text_results = self.search_by_text(query_text, top_k=top_k*2)
304
+
305
+ # 座標検索が指定されている場合
306
+ if query_coords:
307
+ spatial_results = self.search_by_coordinates(query_coords, top_k=top_k*2)
308
+
309
+ # 両方に出現するタイルを優先的にスコアリング
310
+ combined_scores = {}
311
+
312
+ for tile in text_results:
313
+ tile_id = tile["metadata"]["knowledge_id"]
314
+ combined_scores[tile_id] = {
315
+ "tile": tile,
316
+ "text_score": tile.get("text_match_score", 0),
317
+ "spatial_score": 0
318
+ }
319
+
320
+ for tile in spatial_results:
321
+ tile_id = tile["metadata"]["knowledge_id"]
322
+ if tile_id in combined_scores:
323
+ combined_scores[tile_id]["spatial_score"] = 1.0 / (1.0 + tile.get("spatial_distance", 10))
324
+ else:
325
+ combined_scores[tile_id] = {
326
+ "tile": tile,
327
+ "text_score": 0,
328
+ "spatial_score": 1.0 / (1.0 + tile.get("spatial_distance", 10))
329
+ }
330
+
331
+ # 複合スコアで並び替え
332
+ ranked = sorted(
333
+ combined_scores.values(),
334
+ key=lambda x: x["text_score"] * 0.6 + x["spatial_score"] * 0.4,
335
+ reverse=True
336
+ )
337
+
338
+ return [item["tile"] for item in ranked[:top_k]]
339
+
340
+ else:
341
+ # テキスト検索のみ
342
+ return text_results[:top_k]
343
+
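To make the 0.6/0.4 weighting concrete: a tile with text_match_score = 3 and spatial_distance = 1.0 gets spatial_score = 1 / (1 + 1.0) = 0.5 and a combined score of 3 × 0.6 + 0.5 × 0.4 = 2.0, while a purely spatial neighbor at distance 0.25 (spatial_score 0.8, no text hits) scores only 0.8 × 0.4 = 0.32. Because text scores are unnormalized match counts, keyword hits dominate the ranking and the spatial term mainly breaks ties.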
344
+ def get_statistics(self) -> dict:
345
+ """メモリ空間の統計情報を取得"""
346
+ if len(self.tiles_cache) == 0:
347
+ return {"total_tiles": 0}
348
+
349
+ coords = self.coordinates_matrix
350
+
351
+ return {
352
+ "total_tiles": len(self.tiles_cache),
353
+ "coordinate_ranges": {
354
+ "medical_x": {"min": float(coords[:, 0].min()), "max": float(coords[:, 0].max())},
355
+ "medical_y": {"min": float(coords[:, 1].min()), "max": float(coords[:, 1].max())},
356
+ "medical_z": {"min": float(coords[:, 2].min()), "max": float(coords[:, 2].max())},
357
+ "certainty": {"min": float(coords[:, 3].min()), "max": float(coords[:, 3].max())},
358
+ "granularity": {"min": float(coords[:, 4].min()), "max": float(coords[:, 4].max())},
359
+ "verification": {"min": float(coords[:, 5].min()), "max": float(coords[:, 5].max())}
360
+ },
361
+ "verification_status_distribution": self._get_verification_distribution()
362
+ }
363
+
364
+ def _get_verification_distribution(self) -> dict:
365
+ """検証ステータスの分布を取得"""
366
+ distribution = {}
367
+ for tile in self.tiles_cache:
368
+ status = tile["verification"]["status"]
369
+ distribution[status] = distribution.get(status, 0) + 1
370
+ return distribution
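End-to-end, the memory space is typically loaded once and queried with hybrid_search; a sketch with an illustrative file path and query values:

memory = DendriticMemorySpace("knowledge_base.iath")
print(memory.get_statistics())

for tile in memory.hybrid_search(
    query_text="hypertension treatment",          # illustrative query
    query_coords=[0.4, 0.7, 0.2, 0.9, 0.5, 0.8],  # optional 6-D anchor
    top_k=5,
):
    print(tile["metadata"]["topic"])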
iath_writer.py ADDED
@@ -0,0 +1,453 @@
1
+ # null_ai/iath_writer.py
2
+ """
3
+ .iath File Writer Module
4
+
5
+ AI生成知識を.iath形式で保存するためのモジュール。
6
+ dendritic-memory-editor完全互換。
7
+ """
8
+
9
+ import struct
10
+ import zstandard as zstd
11
+ from datetime import datetime
12
+ import json
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import Dict, List, Any, Optional
16
+ import uuid
17
+
18
+ from null_ai.iath_memory import IathDecoder
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class IathWriter:
24
+ """
25
+ .iathファイルへの書き込みを管理するクラス
26
+ """
27
+
28
+ def __init__(self, iath_file_path: str):
29
+ self.file_path = Path(iath_file_path)
30
+ self.encoder = IathTileEncoder()
31
+
32
+ def append_tile(self, tile: Dict[str, Any]) -> bool:
33
+ """
34
+ 既存の.iathファイルに新しいタイルを追記
35
+
36
+ Args:
37
+ tile: Knowledge Tile オブジェクト
38
+
39
+ Returns:
40
+ 成功したかどうか
41
+ """
42
+ try:
43
+ # 既存ファイルが存在するか確認
44
+ if not self.file_path.exists():
45
+ # 新規ファイル作成
46
+ return self._create_new_iath_file([tile])
47
+
48
+ # 既存ファイルの読み込み
49
+ decoder = IathDecoder(str(self.file_path))
50
+ existing_tiles = decoder.get_all_tiles()
51
+
52
+ # 新しいタイルを追加
53
+ existing_tiles.append(tile)
54
+
55
+ # ファイルを再構築
56
+ return self._rebuild_iath_file(existing_tiles)
57
+
58
+ except Exception as e:
59
+ logger.error(f"Failed to append tile to .iath: {e}")
60
+ return False
61
+
62
+ def append_tiles_batch(self, tiles: List[Dict[str, Any]]) -> bool:
63
+ """
64
+ 複数のタイルを一括追記
65
+
66
+ Args:
67
+ tiles: Knowledge Tile のリスト
68
+
69
+ Returns:
70
+ 成功したかどうか
71
+ """
72
+ try:
73
+ if not self.file_path.exists():
74
+ return self._create_new_iath_file(tiles)
75
+
76
+ decoder = IathDecoder(str(self.file_path))
77
+ existing_tiles = decoder.get_all_tiles()
78
+
79
+ # 新しいタイルを追加
80
+ existing_tiles.extend(tiles)
81
+
82
+ # ファイルを再構築
83
+ return self._rebuild_iath_file(existing_tiles)
84
+
85
+ except Exception as e:
86
+ logger.error(f"Failed to append tiles batch to .iath: {e}")
87
+ return False
88
+
89
+ def _create_new_iath_file(self, tiles: List[Dict[str, Any]]) -> bool:
90
+ """
91
+ 新規.iathファイルを作成
92
+ """
93
+ try:
+     # Layout expected by IathDecoder:
+     #   [64-byte header][index JSON][data section]
+     # The header carries the absolute offsets of the index and data sections.
+     index_data, tiles_data = self._create_index_and_data(tiles)
+     index_json = json.dumps(index_data, ensure_ascii=False).encode('utf-8')
+     index_offset = 64
+     data_offset = index_offset + len(index_json)
+
+     with open(self.file_path, 'wb') as f:
+         f.write(self._create_header(index_offset, data_offset))
+         f.write(index_json)
+         f.write(tiles_data)
109
+
110
+ logger.info(f"Created new .iath file: {self.file_path}")
111
+ return True
112
+
113
+ except Exception as e:
114
+ logger.error(f"Failed to create new .iath file: {e}")
115
+ return False
116
+
117
+ def _rebuild_iath_file(self, tiles: List[Dict[str, Any]]) -> bool:
118
+ """
119
+ .iathファイルを再構築
120
+ """
121
+ # 一時ファイルに書き込み
122
+ temp_path = self.file_path.with_suffix('.iath.tmp')
123
+
124
+ try:
+     # Same layout as _create_new_iath_file: header with both offsets,
+     # then the index JSON, then the data section.
+     index_data, tiles_data = self._create_index_and_data(tiles)
+     index_json = json.dumps(index_data, ensure_ascii=False).encode('utf-8')
+     index_offset = 64
+     data_offset = index_offset + len(index_json)
+
+     with open(temp_path, 'wb') as f:
+         f.write(self._create_header(index_offset, data_offset))
+         f.write(index_json)
+         f.write(tiles_data)
140
+
141
+ # 一時ファイルを本ファイルに置き換え
142
+ temp_path.replace(self.file_path)
143
+
144
+ logger.info(f"Rebuilt .iath file with {len(tiles)} tiles: {self.file_path}")
145
+ return True
146
+
147
+ except Exception as e:
148
+ logger.error(f"Failed to rebuild .iath file: {e}")
149
+ if temp_path.exists():
150
+ temp_path.unlink()
151
+ return False
152
+
153
+ def _create_header(self, index_offset: int, data_offset: int) -> bytes:
+     """
+     Create the 64-byte .iath header.
+
+     Must match the layout IathDecoder unpacks with "<4sIBB32sQQ6x":
+     magic(4s) / version(I) / domain_code(B) / compression_type(B) /
+     checksum(32s) / index_offset(Q) / data_offset(Q) / 6 pad bytes.
+     """
+     header = struct.pack(
+         "<4sIBB32sQQ6x",
+         b'IATH',        # magic number
+         1,              # format version
+         0,              # domain code (unused for now)
+         1,              # compression type (1 = zstd)
+         b'\x00' * 32,   # checksum (not yet implemented; zero-filled)
+         index_offset,
+         data_offset
+     )
+
+     assert len(header) == 64, f"Header size must be 64 bytes, got {len(header)}"
+     return header
183
+
184
+ def _create_index_and_data(self, tiles: List[Dict[str, Any]]) -> tuple:
185
+ """
+ Build the index and data sections.
+
+ IathDecoder reads the index as a JSON array of
+ {"id", "offset", "length"} entries, so emit the same structure.
+
+ Returns:
+     (index_list, data_bytes)
+ """
+ index = []
+ data_buffer = bytearray()
+
+ current_offset = 0
+
+ for tile in tiles:
+     # Encode (and zstd-compress) the tile
+     tile_bytes = self.encoder.encode_tile(tile)
+
+     # Index entry; offset is relative to the start of the data section
+     index.append({
+         "id": tile["metadata"]["knowledge_id"],
+         "offset": current_offset,
+         "length": len(tile_bytes)
+     })
+
+     data_buffer.extend(tile_bytes)
+     current_offset += len(tile_bytes)
+
+ return index, bytes(data_buffer)
212
+
213
+
214
+ class IathTileEncoder:
215
+ """
216
+ Knowledge Tileを.iath互換のバイナリ形式にエンコード
217
+
218
+ dendritic-memory-editorのIathEncoderと互換性あり
219
+ """
220
+
221
+ def encode_tile(self, tile: Dict[str, Any]) -> bytes:
222
+ """
223
+ 単一のKnowledge Tileをエンコードし、zstdで圧縮
224
+
225
+ Args:
226
+ tile: Knowledge Tile オブジェクト
227
+
228
+ Returns:
229
+ 圧縮されたバイナリデータ
230
+ """
231
+ # 各セクションをエンコード
232
+ metadata_bin = self._encode_metadata(tile["metadata"])
233
+ coord_bin = self._encode_coordinates(tile["coordinates"])
234
+ content_bin = self._encode_content(tile["content"])
235
+ verification_bin = self._encode_verification(tile["verification"])
236
+
237
+ # 長さプレフィックスを付けて連結
238
+ uncompressed = b"".join([
239
+ struct.pack("<I", len(metadata_bin)), metadata_bin,
240
+ struct.pack("<I", len(coord_bin)), coord_bin,
241
+ struct.pack("<I", len(content_bin)), content_bin,
242
+ struct.pack("<I", len(verification_bin)), verification_bin,
243
+ ])
244
+
245
+ # zstdで圧縮(レベル19 = 最高圧縮率)
246
+ cctx = zstd.ZstdCompressor(level=19)
247
+ compressed = cctx.compress(uncompressed)
248
+
249
+ return compressed
250
+
251
+ def _encode_metadata(self, metadata: Dict[str, Any]) -> bytes:
252
+ """メタデータをバイナリ化"""
253
+ kid = self._encode_string(metadata["knowledge_id"])
254
+ topic = self._encode_string(metadata["topic"])
255
+
256
+ # 作成日時(ISO形式)
257
+ created_at_iso = metadata.get("created_at", datetime.now().isoformat())
258
+ # Pad to exactly 27 bytes so the decoder's fixed-width read stays aligned
+ created_at = created_at_iso.encode('ascii')[:27].ljust(27, b'\0')
259
+
260
+ return kid + topic + created_at
261
+
262
+ def _encode_coordinates(self, coordinates: Dict[str, Any]) -> bytes:
263
+ """
264
+ 座標をバイナリ化(6つの浮動小数点数)
265
+
266
+ 座標は以下のいずれかの形式を受け付ける:
267
+ 1. {"medical_space": [x, y, z], "meta_space": [c, g, v]}
268
+ 2. [x, y, z, c, g, v]
269
+ """
270
+ if isinstance(coordinates, dict):
271
+ medical_space = coordinates["medical_space"]
272
+ meta_space = coordinates["meta_space"]
273
+
274
+ return struct.pack(
275
+ "<ffffff",
276
+ float(medical_space[0]), float(medical_space[1]), float(medical_space[2]),
277
+ float(meta_space[0]), float(meta_space[1]), float(meta_space[2])
278
+ )
279
+ elif isinstance(coordinates, list) and len(coordinates) == 6:
280
+ # フラットな配列形式
281
+ return struct.pack(
282
+ "<ffffff",
283
+ float(coordinates[0]), float(coordinates[1]), float(coordinates[2]),
284
+ float(coordinates[3]), float(coordinates[4]), float(coordinates[5])
285
+ )
286
+ else:
287
+ raise ValueError(f"Invalid coordinates format: {coordinates}")
288
+
289
+ def _encode_content(self, content: Dict[str, Any]) -> bytes:
290
+ """コンテンツ(テキスト)をバイナリ化"""
291
+ thinking = content.get("thinking_process", "").encode('utf-8')
292
+ response = content.get("final_response", "").encode('utf-8')
293
+
294
+ # 各パートの長さを前に付けて連結
295
+ result = struct.pack("<I", len(thinking)) + thinking
296
+ result += struct.pack("<I", len(response)) + response
297
+
298
+ return result
299
+
300
+ def _encode_verification(self, verification: Dict[str, Any]) -> bytes:
301
+ """検証履歴をバイナリ化"""
302
+ status_map = {
303
+ "pending_review": 0,
304
+ "partial_verified": 1,
305
+ "verified": 2,
306
+ "expert_confirmed": 3
307
+ }
308
+ status_code = status_map.get(verification.get("status", "pending_review"), 0)
309
+
310
+ initial_certainty = int(verification.get("initial_certainty", 0) * 100) # 0-100に変換
311
+ reviewer_count = len(verification.get("reviewers", []))
312
+
313
+ result = struct.pack("<BBI", status_code, initial_certainty, reviewer_count)
314
+
315
+ # レビュアー情報
316
+ for reviewer in verification.get("reviewers", []):
317
+ result += self._encode_reviewer_reference(reviewer)
318
+
319
+ return result
320
+
321
+ def _encode_reviewer_reference(self, reviewer: Dict[str, Any]) -> bytes:
322
+ """レビュアー情報をエンコード"""
323
+ reviewer_id = reviewer.get("reviewer_id", "unknown").encode('utf-8')
324
+ return struct.pack("<36s", reviewer_id[:36]) # UUID string length
325
+
326
+ def _encode_string(self, s: str) -> bytes:
327
+ """NULL終端のUTF-8文字列をエンコード"""
328
+ return s.encode('utf-8') + b'\0'
329
+
330
+
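Putting the encoder methods together, the uncompressed payload built by encode_tile (before zstd compression) is four length-prefixed sections, mirroring what IathDecoder._decode_tile_data reads back:

uint32 len | metadata:     knowledge_id\0  topic\0  created_at (27 bytes ASCII)
uint32 len | coordinates:  6 x float32, little-endian (x, y, z, c, g, v)
uint32 len | content:      uint32 len + thinking_process, uint32 len + final_response
uint32 len | verification: uint8 status, uint8 certainty (0-100), uint32 reviewer_count, 36s per reviewer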
331
+ def create_tile_from_ai_output(
332
+ knowledge_id: str,
333
+ topic: str,
334
+ prompt: str,
335
+ response: str,
336
+ coordinates: List[float],
337
+ confidence: float,
338
+ domain_id: str,
339
+ source: str = "ai_generated"
340
+ ) -> Dict[str, Any]:
341
+ """
342
+ AI出力からKnowledge Tileオブジェクトを作成
343
+
344
+ Args:
345
+ knowledge_id: 知識ID(UUID推奨)
346
+ topic: トピック
347
+ prompt: ユーザーの質問
348
+ response: AIの回答
349
+ coordinates: 6次元座標 [x, y, z, c, g, v]
350
+ confidence: 信頼度 (0.0-1.0)
351
+ domain_id: ドメインID
352
+ source: ソース("ai_generated", "human_verified", etc.)
353
+
354
+ Returns:
355
+ Knowledge Tile オブジェクト
356
+ """
357
+ if len(coordinates) != 6:
358
+ raise ValueError(f"Coordinates must be 6-dimensional, got {len(coordinates)}")
359
+
360
+ # 検証ステータスの決定
361
+ if confidence >= 0.9:
362
+ status = "expert_confirmed"
363
+ elif confidence >= 0.8:
364
+ status = "verified"
365
+ elif confidence >= 0.7:
366
+ status = "partial_verified"
367
+ else:
368
+ status = "pending_review"
369
+
370
+ tile = {
371
+ "metadata": {
372
+ "knowledge_id": knowledge_id,
373
+ "topic": topic,
374
+ "created_at": datetime.now().isoformat(),
375
+ "domain_id": domain_id,
376
+ "source": source
377
+ },
378
+ "coordinates": {
379
+ "medical_space": coordinates[:3],
380
+ "meta_space": coordinates[3:]
381
+ },
382
+ "content": {
383
+ "thinking_process": f"Question: {prompt}",
384
+ "final_response": response
385
+ },
386
+ "verification": {
387
+ "status": status,
388
+ "initial_certainty": confidence,
389
+ "reviewers": []
390
+ }
391
+ }
392
+
393
+ return tile
394
+
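Since the writer is meant to be byte-compatible with IathDecoder, a quick round-trip is a useful sanity check; a sketch with illustrative values and a temporary path:

tile = create_tile_from_ai_output(
    knowledge_id="00000000-0000-0000-0000-000000000001",
    topic="demo",
    prompt="What is X?",
    response="X is ...",
    coordinates=[0.1, 0.2, 0.3, 0.9, 0.5, 0.7],
    confidence=0.85,
    domain_id="demo_domain",
)
writer = IathWriter("/tmp/roundtrip.iath")   # illustrative path
assert writer.append_tile(tile)

decoder = IathDecoder("/tmp/roundtrip.iath")
restored = decoder.get_tile_by_id(tile["metadata"]["knowledge_id"])
assert restored is not None
assert restored["content"]["final_response"] == "X is ..."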
395
+
396
+ def merge_jsonl_to_iath(
397
+ jsonl_path: str,
398
+ iath_path: str,
399
+ coordinate_estimator,
400
+ llm_inference_func
401
+ ) -> int:
402
+ """
403
+ JSONLファイルの訓練データを.iath形式に変換して追記
404
+
405
+ Args:
406
+ jsonl_path: Alpaca形式のJSONLファイルパス
407
+ iath_path: 出力先.iathファイルパス
408
+ coordinate_estimator: CoordinateEstimatorインスタンス
409
+ llm_inference_func: LLM推論関数
410
+
411
+ Returns:
412
+ 変換したタイル数
413
+ """
414
+ import asyncio
415
+
416
+ writer = IathWriter(iath_path)
417
+ tiles_to_write: List[Dict[str, Any]] = []
418
+
419
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
420
+ for line in f:
421
+ try:
422
+ example = json.loads(line.strip())
423
+
424
+ # 座標推定
425
+ coord_result = asyncio.run(coordinate_estimator.estimate_coordinates(
426
+ prompt=example["input"],
427
+ response=example["output"],
428
+ domain_id=example["metadata"]["domain_id"],
429
+ llm_inference_func=llm_inference_func
430
+ ))
431
+
432
+ # Tileオブジェクト作成
433
+ tile = create_tile_from_ai_output(
434
+ knowledge_id=str(uuid.uuid4()),
435
+ topic=example["input"][:100], # 最初の100文字をトピックに
436
+ prompt=example["input"],
437
+ response=example["output"],
438
+ coordinates=coord_result["coordinates"],
439
+ confidence=example["metadata"]["confidence"],
440
+ domain_id=example["metadata"]["domain_id"],
441
+ source="master_output"
442
+ )
443
+
444
+ # Collect tiles and append them in one batch afterwards;
+ # append_tile() rebuilds the whole file on every call.
+ tiles_to_write.append(tile)
447
+
448
+ except Exception as e:
449
+ logger.error(f"Failed to convert line to tile: {e}")
450
+ continue
451
+
+ # Single batch append: the file is rebuilt only once
+ tiles_created = len(tiles_to_write)
+ if tiles_to_write:
+     writer.append_tiles_batch(tiles_to_write)
+
+ logger.info(f"Merged {tiles_created} tiles from {jsonl_path} to {iath_path}")
+ return tiles_created
llm_providers.py ADDED
@@ -0,0 +1,434 @@
1
+ # null_ai/llm_providers.py
2
+ import logging
3
+ import asyncio
4
+ import time
5
+ import threading # For TextIteratorStreamer
6
+ from typing import Dict, Any, AsyncGenerator, Optional
7
+
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+ import torch
10
+
11
+ from backend.app.config import ModelConfig, ModelProvider # ModelConfigを使用
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # --- HuggingFace Transformers Provider ---
16
+ class HuggingFaceProvider:
17
+ """
18
+ HuggingFace Transformersモデルとのインタラクションを管理するクラス。
19
+ """
20
+ _loaded_models: Dict[str, Any] = {} # {model_name: {"model": model, "tokenizer": tokenizer}}
21
+
22
+ async def _load_model_and_tokenizer(self, model_name: str) -> Dict[str, Any]:
23
+ """
24
+ HuggingFaceモデルとトークナイザーをロード(キャッシュ)する。
25
+ """
26
+ if model_name in self._loaded_models:
27
+ return self._loaded_models[model_name]
28
+
29
+ logger.info(f"Loading HuggingFace model: {model_name}...")
30
+ try:
31
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
32
+ model = AutoModelForCausalLM.from_pretrained(
33
+ model_name,
34
+ torch_dtype=torch.float16, # 高速化のためfloat16を使用
35
+ device_map="auto", # 利用可能なデバイスに自動マップ
36
+ trust_remote_code=True
37
+ )
38
+ self._loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
39
+ logger.info(f"HuggingFace model '{model_name}' loaded successfully.")
40
+ return self._loaded_models[model_name]
41
+ except Exception as e:
42
+ logger.error(f"Failed to load HuggingFace model '{model_name}': {e}")
43
+ raise
44
+
45
+ async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
46
+ """
47
+ HuggingFaceモデルで非ストリーミング推論を実行する。
48
+ """
49
+ model_data = await self._load_model_and_tokenizer(model_config.model_name)
50
+ model = model_data["model"]
51
+ tokenizer = model_data["tokenizer"]
52
+
53
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
54
+
55
+ start_time = time.perf_counter()
56
+ outputs = model.generate(
57
+ **inputs,
58
+ max_new_tokens=model_config.max_tokens,
59
+ temperature=temperature,
60
+ do_sample=temperature > 0,
61
+ pad_token_id=tokenizer.eos_token_id # ストリーミングでない場合は必須ではないが、安全のため
62
+ )
63
+ end_time = time.perf_counter()
64
+
65
+ # model.generate() returns the prompt tokens followed by the completion,
+ # so decode only the newly generated tokens.
+ new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+ response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
66
+ latency_ms = (end_time - start_time) * 1000
67
+
68
+ # ここでは推論チェーンの抽出や確信度計算は行わない (ModelRouterの役割)
69
+ return {
70
+ "response": response_text,
71
+ "thinking": f"Inferred by {model_config.display_name} (HuggingFace).",
72
+ "confidence": 0.8, # 仮の値
73
+ "latency_ms": latency_ms
74
+ }
75
+
76
+ async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
77
+ """
78
+ HuggingFaceモデルでストリーミング推論を実行する。
79
+ """
80
+ model_data = await self._load_model_and_tokenizer(model_config.model_name)
81
+ model = model_data["model"]
82
+ tokenizer = model_data["tokenizer"]
83
+
84
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
85
+
86
+ # TextIteratorStreamerを使用
87
+ streamer = TextIteratorStreamer(
88
+ tokenizer,
89
+ skip_prompt=True,
90
+ skip_special_tokens=True
91
+ )
92
+
93
+ generation_kwargs = {
94
+ **inputs,
95
+ "max_new_tokens": model_config.max_tokens,
96
+ "temperature": temperature,
97
+ "do_sample": temperature > 0,
98
+ "streamer": streamer,
99
+ "pad_token_id": tokenizer.eos_token_id
100
+ }
101
+
102
+ # 別スレッドで生成を実行
103
+ # Streamerはスレッドセーフではないため、生成は別スレッドで行う
104
+ # そしてメインスレッドからStreamerをポーリングする
105
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
106
+ thread.start()
107
+
108
+ generated_text = ""
109
+ start_time = time.perf_counter()
110
+
111
+ yield {"type": "thinking", "content": f"Loading model: {model_config.display_name}..."}
112
+
113
+ for new_token in streamer:
114
+ generated_text += new_token
115
+ yield {"type": "token", "content": new_token}
116
+
117
+ thread.join() # 生成が完了するのを待つ
118
+ end_time = time.perf_counter()
119
+
120
+ latency_ms = (end_time - start_time) * 1000
121
+
122
+ yield {"type": "complete", "content": generated_text}
123
+ yield {"type": "meta",
124
+ "confidence": 0.8, # 仮の値
125
+ "thinking": f"Inferred by {model_config.display_name} (HuggingFace).",
126
+ "latency_ms": latency_ms
127
+ }
128
+
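All providers yield the same event shapes ("thinking", "token", "complete", "meta"), so consumers can be written once; a sketch, assuming a configured ModelConfig instance:

async def stream_demo(provider, model_config):
    async for event in provider.infer_streaming(model_config, "Hello", temperature=0.7):
        if event["type"] == "token":
            print(event["content"], end="", flush=True)
        elif event["type"] == "meta":
            print(f"\n[latency: {event['latency_ms']:.0f} ms]")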
129
+ # --- Ollama Provider ---
130
+ class OllamaProvider:
131
+ _ollama_client: Any = None # ollamaクライアントは遅延ロード
132
+ _loaded_models: Dict[str, bool] = {} # {model_name: is_loaded}
133
+
134
+ async def _get_ollama_client(self, api_url: Optional[str]) -> Any:
135
+ try:
136
+ from ollama import AsyncClient
137
+ if api_url and api_url != "http://localhost:11434": # デフォルト以外
138
+ return AsyncClient(host=api_url)
139
+ return AsyncClient() # デフォルトクライアント
140
+ except ImportError:
141
+ logger.error("Ollama library not installed. Please `pip install ollama`.")
142
+ raise
143
+
144
+ async def _ensure_model_available(self, model_name: str, api_url: Optional[str]):
145
+ if self._loaded_models.get(model_name):
146
+ return
147
+
148
+ client = await self._get_ollama_client(api_url)
149
+ logger.info(f"Checking Ollama model: {model_name}...")
150
+ try:
151
+ models = await client.list()
152
+ # Match both exact names and untagged names ("llama3" vs "llama3:latest")
+ local_names = [m['name'] for m in models['models']]
+ if not any(n == model_name or n.split(':')[0] == model_name for n in local_names):
153
+ logger.info(f"Ollama model '{model_name}' not found locally. Pulling...")
154
+ await client.pull(model_name)
155
+ self._loaded_models[model_name] = True
156
+ logger.info(f"Ollama model '{model_name}' is available.")
157
+ except Exception as e:
158
+ logger.error(f"Failed to ensure Ollama model '{model_name}' availability: {e}")
159
+ raise
160
+
161
+ async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
162
+ """
163
+ Ollamaモデルで非ストリーミング推論を実行する。
164
+ """
165
+ await self._ensure_model_available(model_config.model_name, model_config.api_url)
166
+ client = await self._get_ollama_client(model_config.api_url)
167
+
168
+ start_time = time.perf_counter()
169
+ response = await client.generate(
170
+ model=model_config.model_name,
171
+ prompt=prompt,
172
+ # The ollama client takes temperature via options, not as a top-level kwarg
+ options={'num_predict': model_config.max_tokens, 'temperature': temperature}
174
+ )
175
+ end_time = time.perf_counter()
176
+
177
+ response_text = response.get('response', '').strip()
178
+
179
+ # If the model returns an empty string for a question generation prompt,
180
+ # default to an empty JSON array to prevent parsing errors downstream.
181
+ if not response_text and "Return ONLY a JSON array of questions" in prompt:
182
+ logger.warning("Ollama model returned an empty response for question generation. Defaulting to an empty list.")
183
+ response_text = "[]"
184
+
185
+ latency_ms = (end_time - start_time) * 1000
186
+
187
+ return {
188
+ "response": response_text,
189
+ "thinking": f"Inferred by {model_config.display_name} (Ollama).",
190
+ "confidence": 0.85, # 仮の値
191
+ "latency_ms": latency_ms
192
+ }
193
+
194
+ async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
195
+ """
196
+ Ollamaモデルでストリーミング推論を実行する。
197
+ """
198
+ await self._ensure_model_available(model_config.model_name, model_config.api_url)
199
+ client = await self._get_ollama_client(model_config.api_url)
200
+
201
+ start_time = time.perf_counter()
202
+ generated_text = ""
203
+
204
+ yield {"type": "thinking", "content": f"Ensuring Ollama model '{model_config.model_name}' is available..."}
205
+
206
+ async for chunk in await client.generate(
207
+ model=model_config.model_name,
208
+ prompt=prompt,
209
+ # temperature goes through options here as well
+ options={'num_predict': model_config.max_tokens, 'temperature': temperature},
211
+ stream=True
212
+ ):
213
+ if 'response' in chunk:
214
+ generated_text += chunk['response']
215
+ yield {"type": "token", "content": chunk['response']}
216
+
217
+ end_time = time.perf_counter()
218
+ latency_ms = (end_time - start_time) * 1000
219
+
220
+ yield {"type": "complete", "content": generated_text}
221
+ yield {"type": "meta",
222
+ "confidence": 0.85, # 仮の値
223
+ "thinking": f"Inferred by {model_config.display_name} (Ollama).",
224
+ "latency_ms": latency_ms
225
+ }
226
+
227
+ # --- MLX Provider (Apple Silicon) ---
228
+ class MLXProvider:
229
+ _loaded_models: Dict[str, Any] = {}
230
+
231
+ async def _load_model_and_tokenizer(self, model_name: str) -> Dict[str, Any]:
232
+ """
233
+ MLXモデルとトークナイザーをロード(キャッシュ)する。
234
+ """
235
+ if model_name in self._loaded_models:
236
+ return self._loaded_models[model_name]
237
+
238
+ try:
239
+ import mlx.core as mx
240
+ from mlx_lm import load, generate
241
+
242
+ # MLXはデバイスを自動的に使用
243
+ logger.info(f"Loading MLX model: {model_name}...")
244
+ model, tokenizer = load(model_name)
245
+ self._loaded_models[model_name] = {"model": model, "tokenizer": tokenizer}
246
+ logger.info(f"MLX model '{model_name}' loaded successfully.")
247
+ return self._loaded_models[model_name]
248
+ except ImportError:
249
+ logger.error("MLX library not installed. Please `pip install mlx-lm mlx`.")
250
+ raise
251
+ except Exception as e:
252
+ logger.error(f"Failed to load MLX model '{model_name}': {e}")
253
+ raise
254
+
255
+ async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
256
+ """
257
+ MLXモデルで非ストリーミング推論を実行する。
258
+ """
259
+ model_data = await self._load_model_and_tokenizer(model_config.model_name)
260
+ model = model_data["model"]
261
+ tokenizer = model_data["tokenizer"]
262
+
263
+ start_time = time.perf_counter()
264
+ # The import of generate in _load_model_and_tokenizer is local to that
+ # method, so import it again in this scope.
+ from mlx_lm import generate
+
+ # mlx_lm's generate() is a blocking call, so run it in a worker thread
+ response_text = await asyncio.to_thread(
+     lambda: generate(
267
+ model, tokenizer, prompt=prompt, max_tokens=model_config.max_tokens, temp=temperature
268
+ )
269
+ )
270
+ end_time = time.perf_counter()
271
+
272
+ latency_ms = (end_time - start_time) * 1000
273
+
274
+ return {
275
+ "response": response_text,
276
+ "thinking": f"Inferred by {model_config.display_name} (MLX).",
277
+ "confidence": 0.9, # 仮の値
278
+ "latency_ms": latency_ms
279
+ }
280
+
281
+ async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
282
+ """
283
+ MLXモデルでストリーミング推論を実行する。
284
+ mlx-lmは直接ストリーミングをサポートしていないため、非ストリーミングの結果をチャンクに分割してストリーミングを模倣する。
285
+ """
286
+ yield {"type": "thinking", "content": f"Loading MLX model: {model_config.display_name}..."}
287
+ response_data = await self.infer(model_config, prompt, temperature)
288
+ response_text = response_data["response"]
289
+
290
+ # 結果をワード単位でストリーミングとして模倣
291
+ for word in response_text.split():
292
+ yield {"type": "token", "content": word + " "}
293
+ await asyncio.sleep(0.05) # ダミーのストリーミング遅延
294
+
295
+ yield {"type": "complete", "content": response_text}
296
+ yield {"type": "meta",
297
+ "confidence": response_data["confidence"],
298
+ "thinking": response_data["thinking"],
299
+ "latency_ms": response_data["latency_ms"]
300
+ }
301
+
302
+ # --- GGUF Provider (llama-cpp-python) ---
303
+ class GGUFProvider:
304
+ _loaded_models: Dict[str, Any] = {} # {file_path: Llama object}
305
+
306
+ async def _load_model(self, model_file_path: str) -> Any:
307
+ """
308
+ GGUFモデルをロード(キャッシュ)する。
309
+ """
310
+ if model_file_path in self._loaded_models:
311
+ return self._loaded_models[model_file_path]
312
+
313
+ try:
314
+ from llama_cpp import Llama
315
+ logger.info(f"Loading GGUF model from: {model_file_path}...")
316
+ # n_gpu_layers=-1 はGPUが利用可能な場合に全レイヤーをGPUにオフロードする
317
+ # n_ctxはコンテキストサイズ
318
+ llm = Llama(model_path=model_file_path, n_gpu_layers=-1, n_ctx=4096)
319
+ self._loaded_models[model_file_path] = llm
320
+ logger.info(f"GGUF model '{model_file_path}' loaded successfully.")
321
+ return llm
322
+ except ImportError:
323
+ logger.error("llama-cpp-python not installed. Please `pip install llama-cpp-python`.")
324
+ raise
325
+ except Exception as e:
326
+ logger.error(f"Failed to load GGUF model '{model_file_path}': {e}")
327
+ raise
328
+
329
+ async def infer(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
330
+ """
331
+ GGUFモデルで非ストリーミング推論を実行する。
332
+ """
333
+ if not model_config.model_name or not model_config.model_name.endswith('.gguf'):
334
+ raise ValueError(f"GGUF provider requires 'model_name' to be a valid .gguf file path, but got '{model_config.model_name}'")
335
+
336
+ llm = await self._load_model(model_config.model_name) # model_nameがファイルパス
337
+
338
+ # Format prompt with Phi-4 chat template
339
+ formatted_prompt = f"""<|im_start|>system
340
+ You are a helpful AI assistant.<|im_end|>
341
+ <|im_start|>user
342
+ {prompt}<|im_end|>
343
+ <|im_start|>assistant
344
+ """
345
+
346
+ start_time = time.perf_counter()
347
+ logger.info(f"Calling GGUF model with prompt length: {len(prompt)} (formatted: {len(formatted_prompt)})")
348
+ logger.debug(f"Formatted prompt preview: {formatted_prompt[:300]}...")
349
+
350
+ output = llm(
351
+ formatted_prompt,
352
+ max_tokens=model_config.max_tokens,
353
+ temperature=temperature,
354
+ stop=["<|im_end|>", "<|endoftext|>"], # Phi-4 stop tokens
355
+ echo=False
356
+ )
357
+ end_time = time.perf_counter()
358
+
359
+ logger.debug(f"GGUF model output structure: {list(output.keys()) if isinstance(output, dict) else type(output)}")
360
+
361
+ response_text = ""
362
+ if "choices" in output and len(output["choices"]) > 0:
363
+ choice = output["choices"][0]
364
+ logger.debug(f"Choice structure: {list(choice.keys()) if isinstance(choice, dict) else type(choice)}")
365
+ if "text" in choice:
366
+ response_text = choice["text"].strip()
367
+ logger.info(f"Extracted response text length: {len(response_text)}")
368
+ else:
369
+ logger.warning(f"No choices in output or choices is empty. Output: {output}")
370
+
371
+ # If the model returns an empty string for a question generation prompt,
372
+ # default to an empty JSON array to prevent parsing errors downstream.
373
+ if not response_text and "Return ONLY a JSON array of questions" in prompt:
374
+ logger.warning("GGUF model returned an empty response for question generation. Defaulting to an empty list.")
375
+ response_text = "[]"
376
+ elif not response_text:
377
+ logger.error(f"GGUF model returned empty response. Prompt preview: {prompt[:200]}")
378
+
379
+ latency_ms = (end_time - start_time) * 1000
380
+
381
+ return {
382
+ "response": response_text,
383
+ "thinking": f"Inferred by {model_config.display_name} (GGUF via llama-cpp-python).",
384
+ "confidence": 0.92, # 仮の値
385
+ "latency_ms": latency_ms
386
+ }
387
+
388
+ async def infer_streaming(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
389
+ """
390
+ GGUFモデルでストリーミング推論を実行する。
391
+ """
392
+ if not model_config.model_name or not model_config.model_name.endswith('.gguf'):
393
+ raise ValueError(f"GGUF provider requires 'model_name' to be a valid .gguf file path, but got '{model_config.model_name}'")
394
+
395
+ llm = await self._load_model(model_config.model_name)
396
+
397
+ # Format prompt with Phi-4 chat template
398
+ formatted_prompt = f"""<|im_start|>system
399
+ You are a helpful AI assistant.<|im_end|>
400
+ <|im_start|>user
401
+ {prompt}<|im_end|>
402
+ <|im_start|>assistant
403
+ """
404
+
405
+ start_time = time.perf_counter()
406
+ generated_text = ""
407
+
408
+ yield {"type": "thinking", "content": f"Loading GGUF model: {model_config.display_name}..."}
409
+
410
+ # llama_cppのcreate_completionはストリーミング可能
411
+ for chunk in llm(
412
+ formatted_prompt,
413
+ max_tokens=model_config.max_tokens,
414
+ temperature=temperature,
415
+ stop=["<|im_end|>", "<|endoftext|>"], # Phi-4 stop tokens
416
+ echo=False,
417
+ stream=True
418
+ ):
419
+ token = chunk["choices"][0]["text"]
420
+ if token:
421
+ generated_text += token
422
+ yield {"type": "token", "content": token}
423
+
424
+ end_time = time.perf_counter()
425
+ latency_ms = (end_time - start_time) * 1000
426
+
427
+ yield {"type": "complete", "content": generated_text}
428
+ yield {"type": "meta",
429
+ "confidence": 0.92, # 仮の値
430
+ "thinking": f"Inferred by {model_config.display_name} (GGUF via llama-cpp-python).",
431
+ "latency_ms": latency_ms
432
+ }
433
+
model_router.py ADDED
@@ -0,0 +1,828 @@
1
+ import logging
2
+ from typing import Optional, Dict, Any, AsyncGenerator, List
3
+ import asyncio
4
+ import time
5
+ import threading # For TextIteratorStreamer
6
+ import uuid # For generating unique IDs
7
+ import os
8
+
9
+ from backend.app.config import ConfigManager, ModelConfig, ModelProvider
10
+ # SessionLocal は循環インポートを避けるため遅延インポート
11
+ from backend.app.services.knowledge_service import KnowledgeService, get_knowledge_service
12
+ from null_ai.llm_providers import HuggingFaceProvider, OllamaProvider, MLXProvider, GGUFProvider # Import all providers
13
+ from null_ai.iath_memory import DendriticMemorySpace # 樹木型空間記憶
14
+ from null_ai.coordinate_estimator import CoordinateEstimator # 座標自動推定
15
+ from null_ai.iath_writer import IathWriter, create_tile_from_ai_output # .iath保存
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class ModelRouter:
20
+ """
21
+ 推論エンジン間でモデルのルーティングと切り替えを管理するクラス。
22
+ 師匠(Master)モデルと弟子(Apprentice)モデルの状態を保持する。
23
+ """
24
+ _instance: Optional['ModelRouter'] = None
25
+ _initialized = False
26
+
27
+ def __new__(cls, config_manager: ConfigManager):
28
+ if cls._instance is None:
29
+ cls._instance = super(ModelRouter, cls).__new__(cls)
30
+ return cls._instance
31
+
32
+ def __init__(self, config_manager: ConfigManager):
33
+ if self._initialized:
34
+ return
35
+
36
+ self.config_manager = config_manager
37
+ self.knowledge_service = get_knowledge_service() # Instantiate KnowledgeService internally
38
+ self.master_model: Optional[ModelConfig] = None
39
+ self.apprentice_model: Optional[ModelConfig] = None
40
+ self.active_domain_id: Optional[str] = None # Active domain managed by ConfigManager
41
+ self.managed_engines: Dict[str, Dict[str, Any]] = {} # Tracks all engines and their status
42
+
43
+ # 樹木型空間記憶(.iathファイル)の初期化
44
+ self.dendritic_memory: Optional[DendriticMemorySpace] = None
45
+ self._load_dendritic_memory()
46
+
47
+ # 座標自動推定器の初期化
48
+ self.coordinate_estimator = CoordinateEstimator()
49
+
50
+ # .iathライターの初期化
51
+ iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
52
+ self.iath_writer = IathWriter(iath_file_path)
53
+
54
+ # LLMプロバイダーを初期化
55
+ self.providers: Dict[ModelProvider, Any] = {
56
+ ModelProvider.HUGGINGFACE: HuggingFaceProvider(),
57
+ ModelProvider.OLLAMA: OllamaProvider(),
58
+ ModelProvider.MLX: MLXProvider(),
59
+ ModelProvider.GGUF: GGUFProvider(),
60
+ }
61
+
62
+ # 初期ロード時にデフォルトモデルを設定
63
+ self._load_all_engines() # Load all configured models into managed_engines
64
+ self._load_default_models() # This will now also try to load active master/apprentice from config
65
+
66
+ self._initialized = True
67
+ logger.info("ModelRouter initialized.")
68
+
69
+ def _load_dendritic_memory(self):
70
+ """樹木型空間記憶(.iathファイル)をロードする"""
71
+ try:
72
+ # TODO: ドメインごとに異なる.iathファイルを使用する
73
+ # 現在はデフォルトのパスを使用
75
+ iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
76
+
77
+ if os.path.exists(iath_file_path):
78
+ self.dendritic_memory = DendriticMemorySpace(iath_file_path)
79
+ stats = self.dendritic_memory.get_statistics()
80
+ logger.info(f"Dendritic memory loaded: {stats['total_tiles']} tiles from {iath_file_path}")
81
+ else:
82
+ logger.warning(f".iath file not found: {iath_file_path}. Starting with empty memory.")
83
+ self.dendritic_memory = None
84
+ except Exception as e:
85
+ logger.error(f"Failed to load dendritic memory: {e}")
86
+ self.dendritic_memory = None
87
+
88
+ def _get_any_available_model(self) -> Optional[ModelConfig]:
89
+ """
90
+ 利用可能なモデルを1つ取得する(座標推定などで使用)
91
+
92
+ 優先順位:
93
+ 1. 師匠モデル
94
+ 2. 設定された最初のモデル
95
+
96
+ Returns:
97
+ ModelConfig or None
98
+ """
99
+ if self.master_model:
100
+ return self.master_model
101
+
102
+ # Otherwise fall back to the first configured model (dict insertion order)
+ return next(iter(self.config_manager.models.values()), None)
107
+
108
+ def _load_all_engines(self):
109
+ """設定ファイルから全てのモデルをロードし、managed_enginesを初期化する"""
110
+ self.managed_engines = {}
111
+ for model_id, model_config in self.config_manager.models.items():
112
+ self.managed_engines[model_id] = {
113
+ "config": model_config.dict(),
114
+ "status": "available", # Default status
115
+ "unique_id": None # Only for apprentices
116
+ }
117
+ logger.info(f"Loaded {len(self.managed_engines)} engines into management.")
118
+
119
+ def _update_engine_status(self, model_id: str, status: str, unique_id: Optional[str] = None):
120
+ """managed_engines内のエンジンのステータスを更新するヘルパー"""
121
+ if model_id in self.managed_engines:
122
+ self.managed_engines[model_id]["status"] = status
123
+ self.managed_engines[model_id]["unique_id"] = unique_id
124
+ else:
125
+ logger.warning(f"Attempted to update status for unknown engine: {model_id}")
126
+
127
+ def _load_default_models(self):
128
+ """設定からデフォルトの師匠モデルと弟子モデルをロードする"""
129
+ self.active_domain_id = self.config_manager.get_active_domain_id()
130
+
131
+ # ConfigManagerから永続化されたアクティブな師匠・弟子モデルIDをロード
132
+ persisted_master_id = self.config_manager.get_null_ai_setting("active_master_id")
133
+ persisted_apprentice_id = self.config_manager.get_null_ai_setting("active_apprentice_id")
134
+
135
+ # 師匠モデルの設定
136
+ if persisted_master_id:
137
+ master_config = self.config_manager.get_model_config(persisted_master_id)
138
+ if master_config:
139
+ self.master_model = master_config
140
+ self._update_engine_status(self.master_model.model_id, "master")
141
+ logger.info(f"Persisted master model loaded: {self.master_model.display_name}")
142
+ else:
143
+ logger.warning(f"Persisted master model '{persisted_master_id}' not found in configuration. Attempting to set default.")
144
+ self._set_initial_master_from_config() # 永続化されたモデルが見つからない場合はデフォルトを設定
145
+ else:
146
+ self._set_initial_master_from_config() # 永続化されたマスターIDがない場合はデフォルトを設定
147
+
148
+ # 弟子モデルの設定
149
+ if persisted_apprentice_id:
150
+ apprentice_config = self.config_manager.get_model_config(persisted_apprentice_id)
151
+ if apprentice_config:
152
+ self.apprentice_model = apprentice_config
153
+ # Load the unique_id from config as well
154
+ apprentice_unique_id = self.config_manager.get_null_ai_setting(f"apprentice_unique_id_{persisted_apprentice_id}")
155
+ self._update_engine_status(self.apprentice_model.model_id, "apprentice", apprentice_unique_id)
156
+ logger.info(f"Persisted apprentice model loaded: {self.apprentice_model.display_name}")
157
+ else:
158
+ logger.warning(f"Persisted apprentice model '{persisted_apprentice_id}' not found in configuration. Clearing active apprentice.")
159
+ self.apprentice_model = None
160
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
161
+ else:
162
+ self.apprentice_model = None
163
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
164
+
165
+ def _set_initial_master_from_config(self):
166
+ """設定からデフォルトの師匠モデルをロード(persistedがない場合や見つからない場合)"""
167
+ default_master_config = self.config_manager.get_default_model_config(domain_id=self.active_domain_id)
168
+ if default_master_config:
169
+ self.master_model = default_master_config
170
+ self._update_engine_status(self.master_model.model_id, "master")
171
+ self.config_manager.set_null_ai_setting("active_master_id", self.master_model.model_id)
172
+ logger.info(f"Default master model loaded for domain '{self.active_domain_id}': {self.master_model.display_name}")
173
+ else:
174
+ logger.warning(f"No default master model found for domain '{self.active_domain_id}' in configuration. Master model remains unset.")
175
+ self.master_model = None
176
+ self.config_manager.set_null_ai_setting("active_master_id", None)
177
+
178
+ def set_active_domain_id(self, domain_id: str):
179
+ """アクティブなドメインIDを設定し、それに応じてモデルを再ロードする"""
180
+ if self.active_domain_id != domain_id:
181
+ self.active_domain_id = domain_id
182
+ self._load_default_models() # Re-select the master/apprentice models when the active domain changes
183
+ logger.info(f"ModelRouter active domain set to {domain_id} and models reloaded.")
184
+
185
+ def set_master_model(self, model_id: str) -> bool:
186
+ """師匠モデルを設定する"""
187
+ model = self.config_manager.get_model_config(model_id)
188
+ if model:
189
+ # Mark the old master 'retired' (unless it is the very apprentice being promoted right now)
190
+ if self.master_model and self.master_model.model_id != model_id:
191
+ # On promotion the old master is retired
192
+ self._update_engine_status(self.master_model.model_id, "retired")
193
+ logger.info(f"Old master model '{self.master_model.display_name}' set to 'retired'.")
194
+
195
+ self.master_model = model
196
+ self._update_engine_status(model_id, "master")
197
+ self.config_manager.set_null_ai_setting("active_master_id", model_id)
198
+ logger.info(f"Master model set to: {model.display_name}")
199
+ return True
200
+ logger.error(f"Model with ID '{model_id}' not found for master setting.")
201
+ return False
202
+
203
+ def set_apprentice_model(self, model_id: Optional[str]) -> bool:
204
+ """弟子モデルを設定する (Noneでクリア)"""
205
+ # Return the old apprentice to 'available'
206
+ if self.apprentice_model:
207
+ self._update_engine_status(self.apprentice_model.model_id, "available")
208
+
209
+ if model_id is None or model_id == 'none':
210
+ self.apprentice_model = None
211
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
212
+ logger.info("Apprentice model cleared.")
213
+ return True
214
+
215
+ model = self.config_manager.get_model_config(model_id)
216
+ if model:
217
+ self.apprentice_model = model
218
+ # For apprentices, we need to ensure unique_id is tracked if this is a named apprentice
219
+ apprentice_unique_id = self.managed_engines[model_id].get("unique_id") # Get existing unique_id
220
+ self._update_engine_status(model_id, "apprentice", apprentice_unique_id)
221
+ self.config_manager.set_null_ai_setting("active_apprentice_id", model_id)
222
+ logger.info(f"Apprentice model set to: {model.display_name}")
223
+ return True
224
+ logger.error(f"Model with ID '{model_id}' not found for apprentice setting.")
225
+ return False
226
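A minimal usage sketch of these role setters, assuming an initialized ConfigManager; the two model IDs are hypothetical:

# Hypothetical usage; "mistral-7b" and "apprentice-1a2b" are made-up IDs.
router = ModelRouter(config_manager)
router.set_master_model("mistral-7b")           # persisted as active_master_id
router.set_apprentice_model("apprentice-1a2b")  # persisted as active_apprentice_id
router.set_apprentice_model(None)               # clears the apprentice role again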
+
227
+ def get_all_managed_engines(self) -> List[Dict[str, Any]]:
228
+ """管理している全てのエンジンとそのステータス、ユニークIDを含むリストを返す"""
229
+ return list(self.managed_engines.values())
230
+
231
+ def get_master_model(self) -> Optional[ModelConfig]:
232
+ """現在の師匠モデルを取得する"""
233
+ return self.master_model
234
+
235
+ def get_apprentice_model(self) -> Optional[ModelConfig]:
236
+ """現在の弟子モデルを取得する"""
237
+ return self.apprentice_model
238
+
239
+ def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
240
+ """指定されたmodel_idのモデル設定を取得する"""
241
+ return self.config_manager.models.get(model_id)
242
+
243
+ def swap_engines(self, apprentice_model_id: str) -> bool:
244
+ """師匠と指定した弟子を入れ替える"""
245
+ if not self.master_model:
246
+ logger.error("Cannot swap: No master model is currently set.")
247
+ return False
248
+
249
+ apprentice_candidate = self.config_manager.get_model_config(apprentice_model_id)
250
+ if not apprentice_candidate:
251
+ logger.error(f"Cannot swap: Apprentice model '{apprentice_model_id}' not found.")
252
+ return False
253
+
254
+ # Capture the current master (the new apprentice is derived from it below)
255
+ old_master_model = self.master_model
256
+
257
+ # The new master is the specified apprentice
258
+ new_master_model = apprentice_candidate
259
+
260
+ # The new apprentice is the old master (unless the old master is itself the given apprentice)
261
+ new_apprentice_model = old_master_model if old_master_model.model_id != apprentice_model_id else None
262
+
263
+ # Swap the roles
264
+ self.master_model = new_master_model
265
+ self.apprentice_model = new_apprentice_model
266
+
267
+ # Update statuses and persist the change
268
+ self._update_engine_status(new_master_model.model_id, "master", self.managed_engines[new_master_model.model_id].get("unique_id"))
269
+ self.config_manager.set_null_ai_setting("active_master_id", new_master_model.model_id)
270
+
271
+ if new_apprentice_model:
272
+ self._update_engine_status(new_apprentice_model.model_id, "apprentice", self.managed_engines[new_apprentice_model.model_id].get("unique_id"))
273
+ self.config_manager.set_null_ai_setting("active_apprentice_id", new_apprentice_model.model_id)
274
+ else:
275
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
276
+
277
+ # Return the old master to 'available' only if it took on no new role (otherwise it already carries the 'apprentice' status set above)
278
+ if old_master_model and new_apprentice_model is None and old_master_model.model_id != new_master_model.model_id:
279
+ self._update_engine_status(old_master_model.model_id, "available")
280
+
281
+ logger.info(f"Engines swapped: New Master is {new_master_model.display_name}, New Apprentice is {new_apprentice_model.display_name if new_apprentice_model else 'None'}")
282
+ return True
283
+
284
+ def promote_apprentice(self, apprentice_model_id: str) -> bool:
285
+ """指定した弟子を師匠に昇格させ、現在の師匠を引退させる"""
286
+ apprentice_to_promote = self.config_manager.get_model_config(apprentice_model_id)
287
+ if not apprentice_to_promote:
288
+ logger.error(f"Cannot promote: Apprentice model '{apprentice_model_id}' not found.")
289
+ return False
290
+
291
+ # Retire the current master
292
+ if self.master_model:
293
+ self._update_engine_status(self.master_model.model_id, "retired")
294
+ logger.info(f"Current master model '{self.master_model.display_name}' retired.")
295
+
296
+ # Promote the apprentice to master
297
+ self.master_model = apprentice_to_promote
298
+ self._update_engine_status(apprentice_to_promote.model_id, "master", self.managed_engines[apprentice_to_promote.model_id].get("unique_id"))
299
+ self.config_manager.set_null_ai_setting("active_master_id", apprentice_to_promote.model_id)
300
+
301
+ # Clear the apprentice role (it has just been promoted)
302
+ if self.apprentice_model and self.apprentice_model.model_id == apprentice_model_id:
303
+ self.apprentice_model = None
304
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
305
+ logger.info(f"Apprentice model '{apprentice_to_promote.display_name}' promoted to master and apprentice role cleared.")
306
+ else:
307
+ logger.warning(f"Apprentice '{apprentice_model_id}' was promoted, but was not the currently active apprentice.")
308
+
309
+ logger.info(f"Apprentice '{apprentice_to_promote.display_name}' promoted to Master.")
310
+ return True
311
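The difference between the two role transitions above, as a short sketch (model IDs hypothetical):

# swap_engines: the roles trade places; the old master stays in rotation.
router.swap_engines("apprentice-1a2b")        # old master becomes the new apprentice
# promote_apprentice: a one-way transition; the old master is retired.
router.promote_apprentice("apprentice-1a2b")  # old master's status -> "retired"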
+
312
+ def create_new_apprentice(self) -> Optional[Dict[str, Any]]:
313
+ """新しい「空っぽの弟子」推論エンジンを生成し、登録する"""
314
+ new_apprentice_unique_id = str(uuid.uuid4())
315
+ new_apprentice_model_id = f"apprentice-{new_apprentice_unique_id[:8]}"
316
+ new_apprentice_display_name = f"Apprentice ({new_apprentice_unique_id[:4]})"
317
+
318
+ # The user specified "something with literally no training data".
319
+ # Here we either assume a generic base model or name a specific provider/model.
320
+ # Tentatively, a lightweight Ollama model is used as the template.
321
+ # Alternatively, only the model configuration is created here and the actual model file is loaded later, at training time.
322
+
323
+ # NOTE: what this "empty" model concretely refers to still has to be
324
+ # pinned down by the downstream fine-tuning logic.
325
+ # Here it is simply added as one entry in the configuration.
326
+
327
+ # TODO: make the base "empty" model configuration loadable from null_ai_config.json or similar
328
+ base_apprentice_provider = ModelProvider.OLLAMA # Provisional default provider
329
+ base_apprentice_model_name = "mistral:latest" # Provisional default model name
330
+
331
+ # Configuration for the new apprentice
332
+ new_model_config_data = {
333
+ "model_id": new_apprentice_model_id,
334
+ "display_name": new_apprentice_display_name,
335
+ "provider": base_apprentice_provider.value,
336
+ "model_name": base_apprentice_model_name,
337
+ "max_tokens": 4096,
338
+ "temperature": 0.7,
339
+ "timeout": 120,
340
+ "is_default": False,
341
+ "supported_domains": ["general"], # New apprentices start with general knowledge
342
+ "description": f"Newly generated empty apprentice model with unique ID {new_apprentice_unique_id}."
343
+ }
344
+
345
+ try:
346
+ new_model = self.config_manager.add_model(new_model_config_data)
347
+ if not new_model:
348
+ raise Exception("Failed to add new model via config_manager.")
349
+
350
+ self.managed_engines[new_model.model_id] = {
351
+ "config": new_model.dict(),
352
+ "status": "available",
353
+ "unique_id": new_apprentice_unique_id
354
+ }
355
+ # Persist the unique_id as well
356
+ self.config_manager.set_null_ai_setting(f"apprentice_unique_id_{new_model.model_id}", new_apprentice_unique_id)
357
+
358
+ logger.info(f"New empty apprentice '{new_apprentice_display_name}' created with ID: {new_apprentice_model_id}")
359
+ return self.managed_engines[new_model.model_id]
360
+ except Exception as e:
361
+ logger.error(f"Failed to create new apprentice: {e}")
362
+ return None
363
+
364
+ def get_active_model(self, for_inference: bool = True) -> Optional[ModelConfig]:
365
+ """
366
+ Return the active model to use for inference.
367
+ When for_inference is True, the master model is returned.
368
+ When the apprentice model is needed, e.g. for "learning" in the fallen-tree system,
369
+ call get_apprentice_model directly.
370
+ """
371
+ if for_inference:
372
+ return self.master_model
373
+ # Automatic switchover to a "grown" apprentice model may be added in the future
374
+ return self.master_model # Use the master model by default
375
+
376
+ async def infer(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> Dict[str, Any]:
377
+ """
378
+ Run inference with the specified model.
379
+ RAG: use DB knowledge when it exists; otherwise infer from the AI's internal knowledge and accumulate the result in the DB.
380
+ Behavior switches according to rag_mode.
381
+ """
382
+ # rag_mode == 'direct' の場合の処理
383
+ if rag_mode == "direct":
384
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
385
+ if relevant_knowledge:
386
+ # Return the answer directly from the DB
387
+ best_match = relevant_knowledge[0]
388
+ logger.info(f"Direct mode: DB knowledge found. Returning direct answer from tile {best_match['id']}.")
389
+ return {
390
+ "response": best_match['content'],
391
+ "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.",
392
+ "confidence": best_match['confidence_score'],
393
+ "model_used": "database_direct",
394
+ "latency_ms": 50, # DB検索なので高速
395
+ "source_type": "db_direct"
396
+ }
397
+ else:
398
+ # When the DB holds no information, answer "not found"
399
+ logger.info("Direct mode: No DB knowledge found. Returning 'Not found'.")
400
+ return {
401
+ "response": "ご指定の情報はナレッジベース内に見つかりませんでした。",
402
+ "thinking": "Directly searched knowledge base, but no relevant information was found.",
403
+ "confidence": 0.9, # 「見つからない」という回答には高い信頼度を与える
404
+ "model_used": "database_direct",
405
+ "latency_ms": 50,
406
+ "source_type": "db_direct"
407
+ }
408
+
409
+ # rag_mode == 'rag' の場合の処理 (既存のロジック)
410
+ has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
411
+ source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
412
+ thinking_process = ""
413
+ augmented_prompt = prompt
414
+
415
+ if has_db_knowledge:
416
+ # Retrieve related knowledge from the DB
417
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
418
+ logger.info(f"RAG mode: DB knowledge found for domain '{domain_id}'. Retrieved {len(relevant_knowledge)} relevant tiles.")
419
+
420
+ # RAG: merge the knowledge into the prompt
421
+ knowledge_context = "\n\n".join([
422
+ f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
423
+ f"Topic: {tile['topic']}\n"
424
+ f"Content: {tile['content']}"
425
+ for i, tile in enumerate(relevant_knowledge)
426
+ ])
427
+
428
+ augmented_prompt = f"""Based on the following verified knowledge from the database:
429
+
430
+ {knowledge_context}
431
+
432
+ Now, please answer the following question accurately:
433
+ {prompt}"""
434
+
435
+ response = await self._perform_llm_inference(model_config, augmented_prompt, temperature or model_config.temperature)
436
+ thinking_process = f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles. " + response.get("thinking", "")
437
+ else:
438
+ # No DB knowledge, so infer from the AI's internal knowledge
439
+ logger.info(f"RAG mode: No DB knowledge found for domain '{domain_id}'. Using AI internal weights.")
440
+ response = await self._perform_llm_inference(model_config, prompt, temperature or model_config.temperature)
441
+ thinking_process = "No specific DB knowledge found. Inferred from AI internal weights. " + response.get("thinking", "")
442
+
443
+ # Self-enrichment: save the inference result to the DB
444
+ if save_to_memory and response.get("confidence", 0) >= 0.7:
445
+ await self._save_inference_to_db(
446
+ domain_id=domain_id,
447
+ prompt=prompt,
448
+ response=response.get("response", ""),
449
+ confidence=response.get("confidence", 0.7),
450
+ source_type="ai",
451
+ model_id=model_config.model_id
452
+ )
453
+
454
+ # Fallen-tree system: save the master's output as training data
455
+ is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
456
+ logger.info(f"[Training Data Check] master_model={self.master_model.model_id if self.master_model else None}, "
457
+ f"current_model={model_config.model_id}, is_master={is_master}, "
458
+ f"confidence={response.get('confidence', 0)}, save_to_memory={save_to_memory}")
459
+ if is_master and response.get("confidence", 0) >= 0.8:
460
+ logger.info(f"[Training Data] Saving master output for domain '{domain_id}'...")
461
+ await self._save_master_output_as_training_data(
462
+ prompt=prompt,
463
+ master_response=response.get("response", ""),
464
+ domain_id=domain_id,
465
+ confidence=response.get("confidence", 0.8)
466
+ )
467
+ else:
468
+ logger.info(f"[Training Data] NOT saving: is_master={is_master}, confidence={response.get('confidence', 0)}")
469
+
470
+ return {
471
+ "response": response.get("response", ""),
472
+ "thinking": thinking_process,
473
+ "confidence": response.get("confidence", 0.7),
474
+ "model_used": model_config.model_id,
475
+ "latency_ms": response.get("latency_ms", 0),
476
+ "source_type": source_type
477
+ }
478
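A minimal sketch of calling infer(), assuming an initialized router; the prompt and domain are illustrative:

import asyncio

async def demo():
    model = router.get_master_model()  # assumes a master model is configured
    result = await router.infer(
        prompt="What is a knowledge tile?",  # illustrative prompt
        domain_id="general",
        model_config=model,
        save_to_memory=True,  # self-enrichment kicks in at confidence >= 0.7
        rag_mode="rag",       # "direct" would answer from the DB only
    )
    print(result["source_type"], result["response"])

asyncio.run(demo())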
+
479
+ async def infer_streaming(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> AsyncGenerator[Dict[str, Any], None]:
480
+ print(f"DEBUG: infer_streaming called - model={model_config.model_id if model_config else 'None'}")
481
+ logger.error(f"[CRITICAL DEBUG] infer_streaming ENTRY POINT - model={model_config.model_id if model_config else 'None'}, save_to_memory={save_to_memory}")
482
+ """
483
+ 指定されたモデルで推論をストリーミングで実行する。
484
+ RAG: DBに知識があればそれを使用、なければAI内部知識で推論してDBに蓄積。
485
+ rag_modeに応じて動作を切り替える。
486
+ """
487
+ logger.info(f"[INFER_STREAMING] Called with model={model_config.model_id}, domain={domain_id}, save_to_memory={save_to_memory}")
488
+
489
+ # Accumulates the master's response for saving once the stream ends
490
+ generated_response = ""
491
+
492
+ try:
493
+ # rag_mode == 'direct' の場合の処理
494
+ if rag_mode == "direct":
495
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
496
+ if relevant_knowledge:
497
+ best_match = relevant_knowledge[0]
498
+ logger.info(f"Direct mode (streaming): DB knowledge found. Yielding direct answer from tile {best_match['id']}.")
499
+ yield {"type": "token", "content": best_match['content']}
500
+ yield {"type": "meta", "source_type": "db_direct", "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.", "confidence": best_match['confidence_score'], "model_used": "database_direct"}
501
+ yield {"type": "complete", "content": best_match['content']}
502
+ return
503
+ else:
504
+ logger.info("Direct mode (streaming): No DB knowledge found. Yielding 'Not found'.")
505
+ response_text = "ご指定の情報はナレッジベース内に見つかりませんでした。"
506
+ yield {"type": "token", "content": response_text}
507
+ yield {"type": "meta", "source_type": "db_direct", "thinking": "Directly searched knowledge base, but no relevant information was found.", "confidence": 0.9, "model_used": "database_direct"}
508
+ yield {"type": "complete", "content": response_text}
509
+ return
510
+
511
+ # rag_mode == 'rag' の場合の処理 (既存のロジック)
512
+ has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
513
+ source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
514
+ augmented_prompt = prompt
515
+
516
+ yield {"type": "thinking", "content": f"Checking DB knowledge for domain '{domain_id}'..."}
517
+ await asyncio.sleep(0.1)
518
+
519
+ if has_db_knowledge:
520
+ # Retrieve related knowledge from the DB
521
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
522
+ yield {"type": "thinking", "content": f"Relevant DB knowledge found. Retrieved {len(relevant_knowledge)} tiles. Augmenting prompt..."}
523
+
524
+ # RAG: merge the knowledge into the prompt
525
+ knowledge_context = "\n\n".join([
526
+ f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
527
+ f"Topic: {tile['topic']}\n"
528
+ f"Content: {tile['content']}"
529
+ for i, tile in enumerate(relevant_knowledge)
530
+ ])
531
+
532
+ augmented_prompt = f"""Based on the following verified knowledge from the database:
533
+
534
+ {knowledge_context}
535
+
536
+ Now, please answer the following question accurately:
537
+ {prompt}"""
538
+
539
+ # Run streaming inference
540
+ async for chunk in self._perform_llm_streaming_inference(model_config, augmented_prompt, temperature or model_config.temperature):
541
+ if chunk.get("type") == "token":
542
+ generated_response += chunk.get("content", "")
543
+ yield chunk
544
+
545
+ yield {"type": "meta", "source_type": source_type, "thinking": f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles."}
546
+
547
+ # Fallen-tree system: save the master's output as training data (also when DB knowledge was used)
548
+ logger.error(f"[CRITICAL DEBUG] Reached training data save section (with DB knowledge)! generated_response length={len(generated_response)}")
549
+ is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
550
+ logger.info(f"[Training Data Check - DB Augmented] master_model={self.master_model.model_id if self.master_model else None}, "
551
+ f"current_model={model_config.model_id}, is_master={is_master}, "
552
+ f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
553
+ if is_master and len(generated_response) > 0:
554
+ logger.info(f"[Training Data - DB Augmented] Saving master output for domain '{domain_id}'...")
555
+ await self._save_master_output_as_training_data(
556
+ prompt=prompt,
557
+ master_response=generated_response,
558
+ domain_id=domain_id,
559
+ confidence=0.8
560
+ )
561
+ else:
562
+ logger.info(f"[Training Data - DB Augmented] NOT saving: is_master={is_master}, response_length={len(generated_response)}")
563
+
564
+ else:
565
+ yield {"type": "thinking", "content": "No specific DB knowledge found. Using AI internal weights."}
566
+
567
+ # Run streaming inference
568
+ async for chunk in self._perform_llm_streaming_inference(model_config, prompt, temperature or model_config.temperature):
569
+ if chunk.get("type") == "token":
570
+ generated_response += chunk.get("content", "")
571
+ yield chunk
572
+
573
+ yield {"type": "meta", "source_type": source_type, "thinking": "Inferred from AI internal weights."}
574
+
575
+ # Fallen-tree system: save the master's output as training data
576
+ logger.error(f"[CRITICAL DEBUG] Reached training data save section! generated_response length={len(generated_response)}")
577
+ is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
578
+ logger.info(f"[Training Data Check - Streaming] master_model={self.master_model.model_id if self.master_model else None}, "
579
+ f"current_model={model_config.model_id}, is_master={is_master}, "
580
+ f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
581
+ if is_master and len(generated_response) > 0:
582
+ logger.info(f"[Training Data - Streaming] Saving master output for domain '{domain_id}'...")
583
+ await self._save_master_output_as_training_data(
584
+ prompt=prompt,
585
+ master_response=generated_response,
586
+ domain_id=domain_id,
587
+ confidence=0.8
588
+ )
589
+ else:
590
+ logger.info(f"[Training Data - Streaming] NOT saving: is_master={is_master}, response_length={len(generated_response)}")
591
+
592
+ # Self-enrichment: save the inference result to the DB (when a response was produced)
593
+ if save_to_memory and len(generated_response) > 0:
594
+ await self._save_inference_to_db(
595
+ domain_id=domain_id,
596
+ prompt=prompt,
597
+ response=generated_response,
598
+ confidence=0.75, # ストリーミングなので固定値
599
+ source_type="ai",
600
+ model_id=model_config.model_id
601
+ )
602
+
603
+ except Exception as e:
604
+ logger.error(f"Error during streaming inference: {e}")
605
+ yield {"type": "error", "content": str(e)}
606
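Consuming the streaming variant is a matter of iterating the async generator; a sketch under the same assumptions (initialized router with a master model set):

async def consume():
    # Assumes 'router' is an initialized ModelRouter; the prompt is illustrative.
    async for chunk in router.infer_streaming(
        prompt="Explain the fallen-tree system.",
        domain_id="general",
        model_config=router.get_master_model(),
    ):
        if chunk["type"] == "token":
            print(chunk["content"], end="", flush=True)
        elif chunk["type"] == "error":
            raise RuntimeError(chunk["content"])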
+
607
+ def _check_db_knowledge(self, domain_id: str, prompt: str) -> bool:
608
+ """
609
+ Check whether knowledge related to the given domain and prompt exists.
610
+ Uses the dendritic spatial memory (.iath).
611
+ """
612
+ try:
613
+ if self.dendritic_memory is None:
614
+ return False
615
+
616
+ # Check for a match via text search
617
+ results = self.dendritic_memory.search_by_text(prompt[:200], top_k=1)
618
+ return len(results) > 0
619
+
620
+ except Exception as e:
621
+ logger.error(f"Error checking dendritic memory: {e}")
622
+ return False
623
+
624
+ def _retrieve_relevant_knowledge(self, domain_id: str, prompt: str, top_k: int = 3) -> list:
625
+ """
626
+ Retrieve knowledge related to the given domain and prompt.
627
+ Uses the hybrid search of the dendritic spatial memory (.iath).
628
+ """
629
+ try:
630
+ if self.dendritic_memory is None:
631
+ return []
632
+
633
+ # Hybrid search (text + spatial coordinates)
634
+ # TODO: add estimation of spatial coordinates from the prompt
635
+ results = self.dendritic_memory.hybrid_search(prompt, query_coords=None, top_k=top_k)
636
+
637
+ # Convert to the unified format
638
+ return [
639
+ {
640
+ "id": tile["metadata"]["knowledge_id"],
641
+ "topic": tile["metadata"]["topic"],
642
+ "content": tile["content"]["final_response"],
643
+ "thinking_process": tile["content"]["thinking_process"],
644
+ "confidence_score": tile["verification"]["initial_certainty"],
645
+ "verification_type": tile["verification"]["status"],
646
+ "coordinates": tile["coordinates"],
647
+ "text_match_score": tile.get("text_match_score", 0),
648
+ "spatial_distance": tile.get("spatial_distance", None)
649
+ }
650
+ for tile in results
651
+ ]
652
+
653
+ except Exception as e:
654
+ logger.error(f"Error retrieving relevant knowledge from dendritic memory: {e}")
655
+ return []
656
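Each entry returned by the conversion above follows the unified format; an illustrative example (all values made up):

# One unified-format entry as produced above; every value here is illustrative.
{
    "id": "ai_ktile_0123abcd",
    "topic": "knowledge tiles",
    "content": "A knowledge tile stores one verified answer.",
    "thinking_process": "...",
    "confidence_score": 0.82,
    "verification_type": "verified",
    "coordinates": [0.4, 0.6, 0.5, 0.5, 0.3, 0.7],
    "text_match_score": 0.91,
    "spatial_distance": None,
}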
+
657
+ async def _perform_llm_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
658
+ """
659
+ Run LLM inference with the given model configuration.
660
+ Uses each provider's implementation.
661
+ """
662
+ provider_type = model_config.provider
663
+
664
+ if provider_type not in self.providers:
665
+ raise ValueError(f"Unsupported provider: {provider_type}")
666
+
667
+ provider = self.providers[provider_type]
668
+
669
+ try:
670
+ result = await provider.infer(model_config, prompt, temperature)
671
+ return result
672
+ except Exception as e:
673
+ logger.error(f"Error during LLM inference with provider '{provider_type}': {e}")
674
+ raise
675
+
676
+ async def _perform_llm_streaming_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
677
+ """
678
+ Run streaming LLM inference with the given model configuration.
679
+ Uses each provider's implementation.
680
+ """
681
+ provider_type = model_config.provider
682
+
683
+ if provider_type not in self.providers:
684
+ raise ValueError(f"Unsupported provider: {provider_type}")
685
+
686
+ provider = self.providers[provider_type]
687
+
688
+ try:
689
+ async for chunk in provider.infer_streaming(model_config, prompt, temperature):
690
+ yield chunk
691
+ except Exception as e:
692
+ logger.error(f"Error during LLM streaming inference with provider '{provider_type}': {e}")
693
+ yield {"type": "error", "content": str(e)}
694
+
695
+ async def _save_inference_to_db(self, domain_id: str, prompt: str, response: str, confidence: float, source_type: str, model_id: str) -> bool:
696
+ """
697
+ Save the inference result to the DB as a Knowledge Tile (self-enrichment).
698
+
699
+ Priority 2 implementation: automatic coordinate estimation + .iath persistence
700
+ """
701
+ try:
702
+ from backend.app.database.models import KnowledgeTile
703
+ from backend.app.database.session import SessionLocal # Deferred import to avoid a circular dependency
704
+
705
+ db = SessionLocal()
706
+ tile_id = f"ai_ktile_{uuid.uuid4().hex}"
707
+
708
+ # Create the AI-generated knowledge tile (SQLite)
709
+ new_tile = KnowledgeTile(
710
+ id=tile_id,
711
+ workspace_id="default_workspace", # ローカル版なのでデフォルトワークスペース
712
+ domain_id=domain_id,
713
+ topic=prompt[:200], # Use the prompt as the topic
714
+ content=response,
715
+ confidence_score=confidence,
716
+ verification_type="ai", # AI生成を示す
717
+ verification_count=1,
718
+ contributor_id=None, # No contributor, since the AI produced it
719
+ last_verified_by_id=None
720
+ )
721
+
722
+ db.add(new_tile)
723
+ db.commit()
724
+ db.close()
725
+
726
+ logger.info(f"Saved AI inference to SQLite: {new_tile.id}")
727
+
728
+ # Priority 2: automatic coordinate estimation + .iath persistence
729
+ try:
730
+ # Build the LLM inference function used for coordinate estimation
731
+ async def llm_inference_for_coords(coord_prompt):
732
+ # Estimate coordinates with the master model (or DeepSeek)
733
+ estimation_model = self.master_model if self.master_model else self._get_any_available_model()
734
+
735
+ if estimation_model:
736
+ result = await self._perform_llm_inference(
737
+ estimation_model,
738
+ coord_prompt,
739
+ temperature=0.3 # A low temperature keeps the output consistent
740
+ )
741
+ return result.get("response", "")
742
+ return "{\"coordinates\": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], \"confidence\": 0.3}"
743
+
744
+ # Estimate the coordinates
745
+ coord_result = await self.coordinate_estimator.estimate_coordinates(
746
+ prompt=prompt,
747
+ response=response,
748
+ domain_id=domain_id,
749
+ llm_inference_func=llm_inference_for_coords,
750
+ use_reasoning=False # Skip the reasoning trace for speed
751
+ )
752
+
753
+ # Create the .iath Tile object
754
+ iath_tile = create_tile_from_ai_output(
755
+ knowledge_id=tile_id,
756
+ topic=prompt[:100],
757
+ prompt=prompt,
758
+ response=response,
759
+ coordinates=coord_result["coordinates"],
760
+ confidence=confidence,
761
+ domain_id=domain_id,
762
+ source="ai_generated"
763
+ )
764
+
765
+ # Persist to the .iath file
766
+ success = self.iath_writer.append_tile(iath_tile)
767
+
768
+ if success:
769
+ logger.info(f"Saved AI inference to .iath with coordinates: {coord_result['coordinates']}")
770
+
771
+ # Reload the memory so the new tile becomes searchable
772
+ self._load_dendritic_memory()
773
+ else:
774
+ logger.warning(f"Failed to save to .iath, but SQLite save succeeded")
775
+
776
+ except Exception as iath_error:
777
+ logger.error(f"Error saving to .iath (SQLite save succeeded): {iath_error}")
778
+ # Return True even though the .iath save failed, since the SQLite save succeeded
779
+
780
+ return True
781
+
782
+ except Exception as e:
783
+ logger.error(f"Error saving inference to DB: {e}")
784
+ return False
785
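Note the dual write above: the SQLite commit happens first, so a later .iath failure degrades spatial-search freshness without losing the tile. The coordinate estimator is expected to return JSON like the fallback literal in the code; a sketch of the parsed shape (six coordinates in [0, 1] is an assumption inferred from that fallback):

# Shape implied by the fallback literal above; the values are illustrative.
coord_result = {
    "coordinates": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
    "confidence": 0.3,
}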
+
786
+ async def _save_master_output_as_training_data(self, prompt: str, master_response: str, domain_id: str, confidence: float) -> bool:
787
+ """
788
+ Save the master's output as training data for fine-tuning the apprentice.
789
+ This is the core feature of the fallen-tree system.
790
+ """
791
+ try:
792
+ import json
793
+ import os
794
+ from datetime import datetime
795
+
796
+ logger.info(f"[Save Training Data] Starting to save master output (domain={domain_id}, confidence={confidence})")
797
+
798
+ # Directory where training data is stored
799
+ training_data_dir = "training_data/master_outputs"
800
+ logger.info(f"[Save Training Data] Creating directory: {training_data_dir}")
801
+ os.makedirs(training_data_dir, exist_ok=True)
802
+
803
+ # Save in Alpaca format
804
+ training_example = {
805
+ "instruction": f"You are an expert in {domain_id}. Provide accurate information based on verified knowledge.",
806
+ "input": prompt,
807
+ "output": master_response,
808
+ "metadata": {
809
+ "domain_id": domain_id,
810
+ "confidence": confidence,
811
+ "master_model_id": self.master_model.model_id if self.master_model else "unknown",
812
+ "timestamp": datetime.utcnow().isoformat(),
813
+ "source": "master_output"
814
+ }
815
+ }
816
+
817
+ # Append to the JSONL file
818
+ output_file = os.path.join(training_data_dir, f"master_outputs_{domain_id}.jsonl")
819
+ logger.info(f"[Save Training Data] Writing to file: {output_file}")
820
+ with open(output_file, 'a', encoding='utf-8') as f:
821
+ f.write(json.dumps(training_example, ensure_ascii=False) + '\n')
822
+
823
+ logger.info(f"✓ Successfully saved master output as training data: {output_file}")
824
+ return True
825
+
826
+ except Exception as e:
827
+ logger.error(f"✗ Error saving master output as training data: {e}", exc_info=True)
828
+ return False
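For reference, one appended record in training_data/master_outputs/master_outputs_&lt;domain&gt;.jsonl would look roughly like this (all values illustrative):

{"instruction": "You are an expert in general. Provide accurate information based on verified knowledge.",
 "input": "What is a knowledge tile?",
 "output": "A knowledge tile is ...",
 "metadata": {"domain_id": "general", "confidence": 0.8,
              "master_model_id": "mistral-7b", "timestamp": "2025-01-01T00:00:00",
              "source": "master_output"}}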
model_router.py.backup ADDED
@@ -0,0 +1,803 @@
1
+ import logging
2
+ from typing import Optional, Dict, Any, AsyncGenerator, List
3
+ import asyncio
4
+ import time
5
+ import threading # For TextIteratorStreamer
6
+ import uuid # For generating unique IDs
7
+ import os
8
+
9
+ from backend.app.config import ConfigManager, ModelConfig, ModelProvider
10
+ # SessionLocal is imported lazily to avoid a circular import
11
+ from backend.app.services.knowledge_service import KnowledgeService, get_knowledge_service
12
+ from null_ai.llm_providers import HuggingFaceProvider, OllamaProvider, MLXProvider, GGUFProvider # Import all providers
13
+ from null_ai.iath_memory import DendriticMemorySpace # Dendritic spatial memory
14
+ from null_ai.coordinate_estimator import CoordinateEstimator # Automatic coordinate estimation
15
+ from null_ai.iath_writer import IathWriter, create_tile_from_ai_output # .iath persistence
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class ModelRouter:
20
+ """
21
+ Manages routing and switching of models between inference engines.
22
+ Holds the state of the master and apprentice models.
23
+ """
24
+ _instance: Optional['ModelRouter'] = None
25
+ _initialized = False
26
+
27
+ def __new__(cls, config_manager: ConfigManager):
28
+ if cls._instance is None:
29
+ cls._instance = super(ModelRouter, cls).__new__(cls)
30
+ return cls._instance
31
+
32
+ def __init__(self, config_manager: ConfigManager):
33
+ if self._initialized:
34
+ return
35
+
36
+ self.config_manager = config_manager
37
+ self.knowledge_service = get_knowledge_service() # Instantiate KnowledgeService internally
38
+ self.master_model: Optional[ModelConfig] = None
39
+ self.apprentice_model: Optional[ModelConfig] = None
40
+ self.active_domain_id: Optional[str] = None # Active domain managed by ConfigManager
41
+ self.managed_engines: Dict[str, Dict[str, Any]] = {} # Tracks all engines and their status
42
+
43
+ # Initialize the dendritic spatial memory (.iath file)
44
+ self.dendritic_memory: Optional[DendriticMemorySpace] = None
45
+ self._load_dendritic_memory()
46
+
47
+ # Initialize the automatic coordinate estimator
48
+ self.coordinate_estimator = CoordinateEstimator()
49
+
50
+ # Initialize the .iath writer
51
+ iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
52
+ self.iath_writer = IathWriter(iath_file_path)
53
+
54
+ # Initialize the LLM providers
55
+ self.providers: Dict[ModelProvider, Any] = {
56
+ ModelProvider.HUGGINGFACE: HuggingFaceProvider(),
57
+ ModelProvider.OLLAMA: OllamaProvider(),
58
+ ModelProvider.MLX: MLXProvider(),
59
+ ModelProvider.GGUF: GGUFProvider(),
60
+ }
61
+
62
+ # Set the default models at initial load
63
+ self._load_all_engines() # Load all configured models into managed_engines
64
+ self._load_default_models() # This will now also try to load active master/apprentice from config
65
+
66
+ self._initialized = True
67
+ logger.info("ModelRouter initialized.")
68
+
69
+ def _load_dendritic_memory(self):
70
+ """樹木型空間記憶(.iathファイル)をロードする"""
71
+ try:
72
+ # TODO: use a separate .iath file per domain
73
+ # For now, use the default path
74
+ import os
75
+ iath_file_path = os.getenv("IATH_DB_PATH", "knowledge_base.iath")
76
+
77
+ if os.path.exists(iath_file_path):
78
+ self.dendritic_memory = DendriticMemorySpace(iath_file_path)
79
+ stats = self.dendritic_memory.get_statistics()
80
+ logger.info(f"Dendritic memory loaded: {stats['total_tiles']} tiles from {iath_file_path}")
81
+ else:
82
+ logger.warning(f".iath file not found: {iath_file_path}. Starting with empty memory.")
83
+ self.dendritic_memory = None
84
+ except Exception as e:
85
+ logger.error(f"Failed to load dendritic memory: {e}")
86
+ self.dendritic_memory = None
87
+
88
+ def _get_any_available_model(self) -> Optional[ModelConfig]:
89
+ """
90
+ Return one available model (used for coordinate estimation, etc.)
91
+
92
+ Priority order:
93
+ 1. The master model
94
+ 2. The first model in the configuration
95
+
96
+ Returns:
97
+ ModelConfig or None
98
+ """
99
+ if self.master_model:
100
+ return self.master_model
101
+
102
+ # Fall back to the first model in the configuration
103
+ for model_id, model_config in self.config_manager.models.items():
104
+ return model_config
105
+
106
+ return None
107
+
108
+ def _load_all_engines(self):
109
+ """設定ファイルから全てのモデルをロードし、managed_enginesを初期化する"""
110
+ self.managed_engines = {}
111
+ for model_id, model_config in self.config_manager.models.items():
112
+ self.managed_engines[model_id] = {
113
+ "config": model_config.dict(),
114
+ "status": "available", # Default status
115
+ "unique_id": None # Only for apprentices
116
+ }
117
+ logger.info(f"Loaded {len(self.managed_engines)} engines into management.")
118
+
119
+ def _update_engine_status(self, model_id: str, status: str, unique_id: Optional[str] = None):
120
+ """managed_engines内のエンジンのステータスを更新するヘルパー"""
121
+ if model_id in self.managed_engines:
122
+ self.managed_engines[model_id]["status"] = status
123
+ self.managed_engines[model_id]["unique_id"] = unique_id
124
+ else:
125
+ logger.warning(f"Attempted to update status for unknown engine: {model_id}")
126
+
127
+ def _load_default_models(self):
128
+ """設定からデフォルトの師匠モデルと弟子モデルをロードする"""
129
+ self.active_domain_id = self.config_manager.get_active_domain_id()
130
+
131
+ # Load the persisted active master/apprentice model IDs from ConfigManager
132
+ persisted_master_id = self.config_manager.get_null_ai_setting("active_master_id")
133
+ persisted_apprentice_id = self.config_manager.get_null_ai_setting("active_apprentice_id")
134
+
135
+ # Configure the master model
136
+ if persisted_master_id:
137
+ master_config = self.config_manager.get_model_config(persisted_master_id)
138
+ if master_config:
139
+ self.master_model = master_config
140
+ self._update_engine_status(self.master_model.model_id, "master")
141
+ logger.info(f"Persisted master model loaded: {self.master_model.display_name}")
142
+ else:
143
+ logger.warning(f"Persisted master model '{persisted_master_id}' not found in configuration. Attempting to set default.")
144
+ self._set_initial_master_from_config() # Fall back to the default when the persisted model cannot be found
145
+ else:
146
+ self._set_initial_master_from_config() # No persisted master ID, so set the default
147
+
148
+ # Configure the apprentice model
149
+ if persisted_apprentice_id:
150
+ apprentice_config = self.config_manager.get_model_config(persisted_apprentice_id)
151
+ if apprentice_config:
152
+ self.apprentice_model = apprentice_config
153
+ # Load the unique_id from config as well
154
+ apprentice_unique_id = self.config_manager.get_null_ai_setting(f"apprentice_unique_id_{persisted_apprentice_id}")
155
+ self._update_engine_status(self.apprentice_model.model_id, "apprentice", apprentice_unique_id)
156
+ logger.info(f"Persisted apprentice model loaded: {self.apprentice_model.display_name}")
157
+ else:
158
+ logger.warning(f"Persisted apprentice model '{persisted_apprentice_id}' not found in configuration. Clearing active apprentice.")
159
+ self.apprentice_model = None
160
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
161
+ else:
162
+ self.apprentice_model = None
163
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
164
+
165
+ def _set_initial_master_from_config(self):
166
+ """設定からデフォルトの師匠モデルをロード(persistedがない場合や見つからない場合)"""
167
+ default_master_config = self.config_manager.get_default_model_config(domain_id=self.active_domain_id)
168
+ if default_master_config:
169
+ self.master_model = default_master_config
170
+ self._update_engine_status(self.master_model.model_id, "master")
171
+ self.config_manager.set_null_ai_setting("active_master_id", self.master_model.model_id)
172
+ logger.info(f"Default master model loaded for domain '{self.active_domain_id}': {self.master_model.display_name}")
173
+ else:
174
+ logger.warning(f"No default master model found for domain '{self.active_domain_id}' in configuration. Master model remains unset.")
175
+ self.master_model = None
176
+ self.config_manager.set_null_ai_setting("active_master_id", None)
177
+
178
+ def set_active_domain_id(self, domain_id: str):
179
+ """アクティブなドメインIDを設定し、それに応じてモデルを再ロードする"""
180
+ if self.active_domain_id != domain_id:
181
+ self.active_domain_id = domain_id
182
+ self._load_default_models() # Re-select the master/apprentice models when the active domain changes
183
+ logger.info(f"ModelRouter active domain set to {domain_id} and models reloaded.")
184
+
185
+ def set_master_model(self, model_id: str) -> bool:
186
+ """師匠モデルを設定する"""
187
+ model = self.config_manager.get_model_config(model_id)
188
+ if model:
189
+ # Mark the old master 'retired' (unless it is the very apprentice being promoted right now)
190
+ if self.master_model and self.master_model.model_id != model_id:
191
+ # On promotion the old master is retired
192
+ self._update_engine_status(self.master_model.model_id, "retired")
193
+ logger.info(f"Old master model '{self.master_model.display_name}' set to 'retired'.")
194
+
195
+ self.master_model = model
196
+ self._update_engine_status(model_id, "master")
197
+ self.config_manager.set_null_ai_setting("active_master_id", model_id)
198
+ logger.info(f"Master model set to: {model.display_name}")
199
+ return True
200
+ logger.error(f"Model with ID '{model_id}' not found for master setting.")
201
+ return False
202
+
203
+ def set_apprentice_model(self, model_id: Optional[str]) -> bool:
204
+ """弟子モデルを設定する (Noneでクリア)"""
205
+ # Return the old apprentice to 'available'
206
+ if self.apprentice_model:
207
+ self._update_engine_status(self.apprentice_model.model_id, "available")
208
+
209
+ if model_id is None or model_id == 'none':
210
+ self.apprentice_model = None
211
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
212
+ logger.info("Apprentice model cleared.")
213
+ return True
214
+
215
+ model = self.config_manager.get_model_config(model_id)
216
+ if model:
217
+ self.apprentice_model = model
218
+ # For apprentices, we need to ensure unique_id is tracked if this is a named apprentice
219
+ apprentice_unique_id = self.managed_engines[model_id].get("unique_id") # Get existing unique_id
220
+ self._update_engine_status(model_id, "apprentice", apprentice_unique_id)
221
+ self.config_manager.set_null_ai_setting("active_apprentice_id", model_id)
222
+ logger.info(f"Apprentice model set to: {model.display_name}")
223
+ return True
224
+ logger.error(f"Model with ID '{model_id}' not found for apprentice setting.")
225
+ return False
226
+
227
+ def get_all_managed_engines(self) -> List[Dict[str, Any]]:
228
+ """管理している全てのエンジンとそのステータス、ユニークIDを含むリストを返す"""
229
+ return list(self.managed_engines.values())
230
+
231
+ def get_master_model(self) -> Optional[ModelConfig]:
232
+ """現在の師匠モデルを取得する"""
233
+ return self.master_model
234
+
235
+ def get_apprentice_model(self) -> Optional[ModelConfig]:
236
+ """現在の弟子モデルを取得する"""
237
+ return self.apprentice_model
238
+
239
+ def swap_engines(self, apprentice_model_id: str) -> bool:
240
+ """師匠と指定した弟子を入れ替える"""
241
+ if not self.master_model:
242
+ logger.error("Cannot swap: No master model is currently set.")
243
+ return False
244
+
245
+ apprentice_candidate = self.config_manager.get_model_config(apprentice_model_id)
246
+ if not apprentice_candidate:
247
+ logger.error(f"Cannot swap: Apprentice model '{apprentice_model_id}' not found.")
248
+ return False
249
+
250
+ # Capture the current master (the new apprentice is derived from it below)
251
+ old_master_model = self.master_model
252
+
253
+ # The new master is the specified apprentice
254
+ new_master_model = apprentice_candidate
255
+
256
+ # The new apprentice is the old master (unless the old master is itself the given apprentice)
257
+ new_apprentice_model = old_master_model if old_master_model.model_id != apprentice_model_id else None
258
+
259
+ # Swap the roles
260
+ self.master_model = new_master_model
261
+ self.apprentice_model = new_apprentice_model
262
+
263
+ # Update statuses and persist the change
264
+ self._update_engine_status(new_master_model.model_id, "master", self.managed_engines[new_master_model.model_id].get("unique_id"))
265
+ self.config_manager.set_null_ai_setting("active_master_id", new_master_model.model_id)
266
+
267
+ if new_apprentice_model:
268
+ self._update_engine_status(new_apprentice_model.model_id, "apprentice", self.managed_engines[new_apprentice_model.model_id].get("unique_id"))
269
+ self.config_manager.set_null_ai_setting("active_apprentice_id", new_apprentice_model.model_id)
270
+ else:
271
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
272
+
273
+ # Return the old master to 'available' only if it took on no new role (otherwise it already carries the 'apprentice' status set above)
274
+ if old_master_model and new_apprentice_model is None and old_master_model.model_id != new_master_model.model_id:
275
+ self._update_engine_status(old_master_model.model_id, "available")
276
+
277
+ logger.info(f"Engines swapped: New Master is {new_master_model.display_name}, New Apprentice is {new_apprentice_model.display_name if new_apprentice_model else 'None'}")
278
+ return True
279
+
280
+ def promote_apprentice(self, apprentice_model_id: str) -> bool:
281
+ """指定した弟子を師匠に昇格させ、現在の師匠を引退させる"""
282
+ apprentice_to_promote = self.config_manager.get_model_config(apprentice_model_id)
283
+ if not apprentice_to_promote:
284
+ logger.error(f"Cannot promote: Apprentice model '{apprentice_model_id}' not found.")
285
+ return False
286
+
287
+ # Retire the current master
288
+ if self.master_model:
289
+ self._update_engine_status(self.master_model.model_id, "retired")
290
+ logger.info(f"Current master model '{self.master_model.display_name}' retired.")
291
+
292
+ # Promote the apprentice to master
293
+ self.master_model = apprentice_to_promote
294
+ self._update_engine_status(apprentice_to_promote.model_id, "master", self.managed_engines[apprentice_to_promote.model_id].get("unique_id"))
295
+ self.config_manager.set_null_ai_setting("active_master_id", apprentice_to_promote.model_id)
296
+
297
+ # Clear the apprentice role (it has just been promoted)
298
+ if self.apprentice_model and self.apprentice_model.model_id == apprentice_model_id:
299
+ self.apprentice_model = None
300
+ self.config_manager.set_null_ai_setting("active_apprentice_id", None)
301
+ logger.info(f"Apprentice model '{apprentice_to_promote.display_name}' promoted to master and apprentice role cleared.")
302
+ else:
303
+ logger.warning(f"Apprentice '{apprentice_model_id}' was promoted, but was not the currently active apprentice.")
304
+
305
+ logger.info(f"Apprentice '{apprentice_to_promote.display_name}' promoted to Master.")
306
+ return True
307
+
308
+ def create_new_apprentice(self) -> Optional[Dict[str, Any]]:
309
+ """新しい「空っぽの弟子」推論エンジンを生成し、登録する"""
310
+ new_apprentice_unique_id = str(uuid.uuid4())
311
+ new_apprentice_model_id = f"apprentice-{new_apprentice_unique_id[:8]}"
312
+ new_apprentice_display_name = f"Apprentice ({new_apprentice_unique_id[:4]})"
313
+
314
+ # The user specified "something with literally no training data".
315
+ # Here we either assume a generic base model or name a specific provider/model.
316
+ # Tentatively, a lightweight Ollama model is used as the template.
317
+ # Alternatively, only the model configuration is created here and the actual model file is loaded later, at training time.
318
+
319
+ # NOTE: what this "empty" model concretely refers to still has to be
320
+ # pinned down by the downstream fine-tuning logic.
321
+ # Here it is simply added as one entry in the configuration.
322
+
323
+ # TODO: make the base "empty" model configuration loadable from null_ai_config.json or similar
324
+ base_apprentice_provider = ModelProvider.OLLAMA # Provisional default provider
325
+ base_apprentice_model_name = "mistral:latest" # Provisional default model name
326
+
327
+ # Configuration for the new apprentice
328
+ new_model_config_data = {
329
+ "model_id": new_apprentice_model_id,
330
+ "display_name": new_apprentice_display_name,
331
+ "provider": base_apprentice_provider.value,
332
+ "model_name": base_apprentice_model_name,
333
+ "max_tokens": 4096,
334
+ "temperature": 0.7,
335
+ "timeout": 120,
336
+ "is_default": False,
337
+ "supported_domains": ["general"], # New apprentices start with general knowledge
338
+ "description": f"Newly generated empty apprentice model with unique ID {new_apprentice_unique_id}."
339
+ }
340
+
341
+ try:
342
+ new_model = self.config_manager.add_model(new_model_config_data)
343
+ if not new_model:
344
+ raise Exception("Failed to add new model via config_manager.")
345
+
346
+ self.managed_engines[new_model.model_id] = {
347
+ "config": new_model.dict(),
348
+ "status": "available",
349
+ "unique_id": new_apprentice_unique_id
350
+ }
351
+ # Persist the unique_id as well
352
+ self.config_manager.set_null_ai_setting(f"apprentice_unique_id_{new_model.model_id}", new_apprentice_unique_id)
353
+
354
+ logger.info(f"New empty apprentice '{new_apprentice_display_name}' created with ID: {new_apprentice_model_id}")
355
+ return self.managed_engines[new_model.model_id]
356
+ except Exception as e:
357
+ logger.error(f"Failed to create new apprentice: {e}")
358
+ return None
359
+
360
+ def get_active_model(self, for_inference: bool = True) -> Optional[ModelConfig]:
361
+ """
362
+ Return the active model to use for inference.
363
+ When for_inference is True, the master model is returned.
364
+ When the apprentice model is needed, e.g. for "learning" in the fallen-tree system,
365
+ call get_apprentice_model directly.
366
+ """
367
+ if for_inference:
368
+ return self.master_model
369
+ # Automatic switchover to a "grown" apprentice model may be added in the future
370
+ return self.master_model # Use the master model by default
371
+
372
+ async def infer(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> Dict[str, Any]:
373
+ """
374
+ Run inference with the specified model.
375
+ RAG: use DB knowledge when it exists; otherwise infer from the AI's internal knowledge and accumulate the result in the DB.
376
+ Behavior switches according to rag_mode.
377
+ """
378
+ # rag_mode == 'direct' の場合の処理
379
+ if rag_mode == "direct":
380
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
381
+ if relevant_knowledge:
382
+ # Return the answer directly from the DB
383
+ best_match = relevant_knowledge[0]
384
+ logger.info(f"Direct mode: DB knowledge found. Returning direct answer from tile {best_match['id']}.")
385
+ return {
386
+ "response": best_match['content'],
387
+ "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.",
388
+ "confidence": best_match['confidence_score'],
389
+ "model_used": "database_direct",
390
+ "latency_ms": 50, # DB検索なので高速
391
+ "source_type": "db_direct"
392
+ }
393
+ else:
394
+ # When the DB holds no information, answer "not found"
395
+ logger.info("Direct mode: No DB knowledge found. Returning 'Not found'.")
396
+ return {
397
+ "response": "ご指定の情報はナレッジベース内に見つかりませんでした。",
398
+ "thinking": "Directly searched knowledge base, but no relevant information was found.",
399
+ "confidence": 0.9, # 「見つからない」という回答には高い信頼度を与える
400
+ "model_used": "database_direct",
401
+ "latency_ms": 50,
402
+ "source_type": "db_direct"
403
+ }
404
+
405
+ # rag_mode == 'rag' の場合の処理 (既存のロジック)
406
+ has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
407
+ source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
408
+ thinking_process = ""
409
+ augmented_prompt = prompt
410
+
411
+ if has_db_knowledge:
412
+ # Retrieve related knowledge from the DB
413
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
414
+ logger.info(f"RAG mode: DB knowledge found for domain '{domain_id}'. Retrieved {len(relevant_knowledge)} relevant tiles.")
415
+
416
+ # RAG: merge the knowledge into the prompt
417
+ knowledge_context = "\n\n".join([
418
+ f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
419
+ f"Topic: {tile['topic']}\n"
420
+ f"Content: {tile['content']}"
421
+ for i, tile in enumerate(relevant_knowledge)
422
+ ])
423
+
424
+ augmented_prompt = f"""Based on the following verified knowledge from the database:
425
+
426
+ {knowledge_context}
427
+
428
+ Now, please answer the following question accurately:
429
+ {prompt}"""
430
+
431
+ response = await self._perform_llm_inference(model_config, augmented_prompt, temperature or model_config.temperature)
432
+ thinking_process = f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles. " + response.get("thinking", "")
433
+ else:
434
+ # No DB knowledge, so infer from the AI's internal knowledge
435
+ logger.info(f"RAG mode: No DB knowledge found for domain '{domain_id}'. Using AI internal weights.")
436
+ response = await self._perform_llm_inference(model_config, prompt, temperature or model_config.temperature)
437
+ thinking_process = "No specific DB knowledge found. Inferred from AI internal weights. " + response.get("thinking", "")
438
+
439
+ # Self-enrichment: save the inference result to the DB
440
+ if save_to_memory and response.get("confidence", 0) >= 0.7:
441
+ await self._save_inference_to_db(
442
+ domain_id=domain_id,
443
+ prompt=prompt,
444
+ response=response.get("response", ""),
445
+ confidence=response.get("confidence", 0.7),
446
+ source_type="ai",
447
+ model_id=model_config.model_id
448
+ )
449
+
450
+ # Fallen-tree system: save the master's output as training data
451
+ is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
452
+ logger.info(f"[Training Data Check] master_model={self.master_model.model_id if self.master_model else None}, "
453
+ f"current_model={model_config.model_id}, is_master={is_master}, "
454
+ f"confidence={response.get('confidence', 0)}, save_to_memory={save_to_memory}")
455
+ if is_master and response.get("confidence", 0) >= 0.8:
456
+ logger.info(f"[Training Data] Saving master output for domain '{domain_id}'...")
457
+ await self._save_master_output_as_training_data(
458
+ prompt=prompt,
459
+ master_response=response.get("response", ""),
460
+ domain_id=domain_id,
461
+ confidence=response.get("confidence", 0.8)
462
+ )
463
+ else:
464
+ logger.info(f"[Training Data] NOT saving: is_master={is_master}, confidence={response.get('confidence', 0)}")
465
+
466
+ return {
467
+ "response": response.get("response", ""),
468
+ "thinking": thinking_process,
469
+ "confidence": response.get("confidence", 0.7),
470
+ "model_used": model_config.model_id,
471
+ "latency_ms": response.get("latency_ms", 0),
472
+ "source_type": source_type
473
+ }
474
+
475
+ async def infer_streaming(self, prompt: str, domain_id: str, model_config: ModelConfig, temperature: Optional[float] = None, save_to_memory: bool = False, rag_mode: str = "rag") -> AsyncGenerator[Dict[str, Any], None]:
476
+ print(f"DEBUG: infer_streaming called - model={model_config.model_id if model_config else 'None'}")
477
+ logger.error(f"[CRITICAL DEBUG] infer_streaming ENTRY POINT - model={model_config.model_id if model_config else 'None'}, save_to_memory={save_to_memory}")
478
+ """
479
+ 指定されたモデルで推論をストリーミングで実行する。
480
+ RAG: DBに知識があればそれを使用、なければAI内部知識で推論してDBに蓄積。
481
+ rag_modeに応じて動作を切り替える。
482
+ """
483
+ logger.info(f"[INFER_STREAMING] Called with model={model_config.model_id}, domain={domain_id}, save_to_memory={save_to_memory}")
484
+
485
+ # Accumulates the master's response for saving once the stream ends
486
+ generated_response = ""
487
+
488
+ try:
489
+ # rag_mode == 'direct' の場合の処理
490
+ if rag_mode == "direct":
491
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=1)
492
+ if relevant_knowledge:
493
+ best_match = relevant_knowledge[0]
494
+ logger.info(f"Direct mode (streaming): DB knowledge found. Yielding direct answer from tile {best_match['id']}.")
495
+ yield {"type": "token", "content": best_match['content']}
496
+ yield {"type": "meta", "source_type": "db_direct", "thinking": f"Directly retrieved from knowledge base. Tile ID: {best_match['id']}.", "confidence": best_match['confidence_score'], "model_used": "database_direct"}
497
+ yield {"type": "complete", "content": best_match['content']}
498
+ return
499
+ else:
500
+ logger.info("Direct mode (streaming): No DB knowledge found. Yielding 'Not found'.")
501
+ response_text = "ご指定の情報はナレッジベース内に見つかりませんでした。"
502
+ yield {"type": "token", "content": response_text}
503
+ yield {"type": "meta", "source_type": "db_direct", "thinking": "Directly searched knowledge base, but no relevant information was found.", "confidence": 0.9, "model_used": "database_direct"}
504
+ yield {"type": "complete", "content": response_text}
505
+ return
506
+
507
+ # rag_mode == 'rag' の場合の処理 (既存のロジック)
508
+ has_db_knowledge = self._check_db_knowledge(domain_id, prompt)
509
+ source_type = "db_augmented" if has_db_knowledge else "ai_internal_weights"
510
+ augmented_prompt = prompt
511
+
512
+ yield {"type": "thinking", "content": f"Checking DB knowledge for domain '{domain_id}'..."}
513
+ await asyncio.sleep(0.1)
514
+
515
+ if has_db_knowledge:
516
+ # Retrieve related knowledge from the DB
517
+ relevant_knowledge = self._retrieve_relevant_knowledge(domain_id, prompt, top_k=3)
518
+ yield {"type": "thinking", "content": f"Relevant DB knowledge found. Retrieved {len(relevant_knowledge)} tiles. Augmenting prompt..."}
519
+
520
+ # RAG: merge the knowledge into the prompt
521
+ knowledge_context = "\n\n".join([
522
+ f"[Knowledge {i+1} - {tile['verification_type']} verification, confidence: {tile['confidence_score']}]\n"
523
+ f"Topic: {tile['topic']}\n"
524
+ f"Content: {tile['content']}"
525
+ for i, tile in enumerate(relevant_knowledge)
526
+ ])
527
+
528
+ augmented_prompt = f"""Based on the following verified knowledge from the database:
529
+
530
+ {knowledge_context}
531
+
532
+ Now, please answer the following question accurately:
533
+ {prompt}"""
534
+
535
+ # Run streaming inference
536
+ async for chunk in self._perform_llm_streaming_inference(model_config, augmented_prompt, temperature or model_config.temperature):
537
+ if chunk.get("type") == "token":
538
+ generated_response += chunk.get("content", "")
539
+ yield chunk
540
+
541
+ yield {"type": "meta", "source_type": source_type, "thinking": f"Accessed knowledge base. Retrieved {len(relevant_knowledge)} relevant knowledge tiles."}
542
+
543
+ else:
544
+ yield {"type": "thinking", "content": "No specific DB knowledge found. Using AI internal weights."}
545
+
546
+ # Run streaming inference
547
+ async for chunk in self._perform_llm_streaming_inference(model_config, prompt, temperature or model_config.temperature):
548
+ if chunk.get("type") == "token":
549
+ generated_response += chunk.get("content", "")
550
+ yield chunk
551
+
552
+ yield {"type": "meta", "source_type": source_type, "thinking": "Inferred from AI internal weights."}
553
+
554
+ # Self-enrichment: save the inference result to the DB (when a response was produced)
555
+ if save_to_memory and len(generated_response) > 0:
556
+ await self._save_inference_to_db(
557
+ domain_id=domain_id,
558
+ prompt=prompt,
559
+ response=generated_response,
560
+ confidence=0.75, # ストリーミングなので固定値
561
+ source_type="ai",
562
+ model_id=model_config.model_id
563
+ )
564
+
565
+ # Fallen-tree system: save the master's output as training data
566
+ logger.error(f"[CRITICAL DEBUG] Reached training data save section! generated_response length={len(generated_response)}")
567
+ is_master = (self.master_model and model_config.model_id == self.master_model.model_id)
568
+ logger.info(f"[Training Data Check - Streaming] master_model={self.master_model.model_id if self.master_model else None}, "
569
+ f"current_model={model_config.model_id}, is_master={is_master}, "
570
+ f"response_length={len(generated_response)}, save_to_memory={save_to_memory}")
571
+ if is_master and len(generated_response) > 0:
572
+ logger.info(f"[Training Data - Streaming] Saving master output for domain '{domain_id}'...")
573
+ await self._save_master_output_as_training_data(
574
+ prompt=prompt,
575
+ master_response=generated_response,
576
+ domain_id=domain_id,
577
+ confidence=0.8
578
+ )
579
+ else:
580
+ logger.info(f"[Training Data - Streaming] NOT saving: is_master={is_master}, response_length={len(generated_response)}")
581
+
582
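+
+         # Design note (from the gating above): teacher data is captured only when the
+         # responding model is the configured master model and a non-empty response was
+         # produced, so apprentice outputs are never recycled as their own training data.
+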
+     def _check_db_knowledge(self, domain_id: str, prompt: str) -> bool:
+         """
+         Check whether knowledge related to the given domain and prompt exists.
+         Uses the dendritic (tree-structured) spatial memory (.iath).
+         """
+         try:
+             if self.dendritic_memory is None:
+                 return False
+
+             # Check for a hit via text search
+             results = self.dendritic_memory.search_by_text(prompt[:200], top_k=1)
+             return len(results) > 0
+
+         except Exception as e:
+             logger.error(f"Error checking dendritic memory: {e}")
+             return False
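+
+     # Note on the check above: existence is decided by a single top-1 text search over
+     # the first 200 characters of the prompt (domain_id is currently unused); a miss
+     # routes the request to pure internal-weights inference instead of RAG, so the
+     # recall of this search directly bounds how often retrieval augmentation kicks in.
+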
+     def _retrieve_relevant_knowledge(self, domain_id: str, prompt: str, top_k: int = 3) -> list:
+         """
+         Retrieve knowledge related to the given domain and prompt.
+         Uses the hybrid search of the dendritic spatial memory (.iath).
+         """
+         try:
+             if self.dendritic_memory is None:
+                 return []
+
+             # Hybrid search (text + spatial coordinates)
+             # TODO: add estimation of spatial coordinates from the prompt
+             results = self.dendritic_memory.hybrid_search(prompt, query_coords=None, top_k=top_k)
+
+             # Convert to the unified format
+             return [
+                 {
+                     "id": tile["metadata"]["knowledge_id"],
+                     "topic": tile["metadata"]["topic"],
+                     "content": tile["content"]["final_response"],
+                     "thinking_process": tile["content"]["thinking_process"],
+                     "confidence_score": tile["verification"]["initial_certainty"],
+                     "verification_type": tile["verification"]["status"],
+                     "coordinates": tile["coordinates"],
+                     "text_match_score": tile.get("text_match_score", 0),
+                     "spatial_distance": tile.get("spatial_distance", None)
+                 }
+                 for tile in results
+             ]
+
+         except Exception as e:
+             logger.error(f"Error retrieving relevant knowledge from dendritic memory: {e}")
+             return []
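+
+     # Raw tile shape expected from hybrid_search, inferred from the field accesses
+     # above (illustrative, not an authoritative schema):
+     #
+     #     {
+     #         "metadata":     {"knowledge_id": "...", "topic": "..."},
+     #         "content":      {"final_response": "...", "thinking_process": "..."},
+     #         "verification": {"initial_certainty": 0.9, "status": "human"},
+     #         "coordinates":  [...],
+     #         "text_match_score": 0.42,   # optional
+     #         "spatial_distance": None    # optional; None when query_coords is None
+     #     }
+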
+     async def _perform_llm_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> Dict[str, Any]:
+         """
+         Run LLM inference with the given model configuration,
+         delegating to the matching provider implementation.
+         """
+         provider_type = model_config.provider
+
+         if provider_type not in self.providers:
+             raise ValueError(f"Unsupported provider: {provider_type}")
+
+         provider = self.providers[provider_type]
+
+         try:
+             result = await provider.infer(model_config, prompt, temperature)
+             return result
+         except Exception as e:
+             logger.error(f"Error during LLM inference with provider '{provider_type}': {e}")
+             raise
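+
+     # Providers are expected to return a dict containing at least a "response" key;
+     # see the coordinate-estimation helper below, which reads result.get("response", "").
+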
+     async def _perform_llm_streaming_inference(self, model_config: ModelConfig, prompt: str, temperature: float) -> AsyncGenerator[Dict[str, Any], None]:
+         """
+         Run streaming LLM inference with the given model configuration,
+         delegating to the matching provider implementation.
+         """
+         provider_type = model_config.provider
+
+         if provider_type not in self.providers:
+             raise ValueError(f"Unsupported provider: {provider_type}")
+
+         provider = self.providers[provider_type]
+
+         try:
+             async for chunk in provider.infer_streaming(model_config, prompt, temperature):
+                 yield chunk
+         except Exception as e:
+             logger.error(f"Error during LLM streaming inference with provider '{provider_type}': {e}")
+             yield {"type": "error", "content": str(e)}
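+
+     # Minimal consumption sketch (hedged: `cfg` is an assumed ModelConfig instance,
+     # mirroring the call sites earlier in this method group):
+     #
+     #     text = ""
+     #     async for chunk in self._perform_llm_streaming_inference(cfg, "Hello", 0.7):
+     #         if chunk.get("type") == "token":
+     #             text += chunk.get("content", "")
+     #         elif chunk.get("type") == "error":
+     #             raise RuntimeError(chunk["content"])
+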
+     async def _save_inference_to_db(self, domain_id: str, prompt: str, response: str, confidence: float, source_type: str, model_id: str) -> bool:
+         """
+         Persist an inference result to the DB as a Knowledge Tile (self-enrichment).
+
+         Priority 2 implementation: automatic coordinate estimation + .iath persistence.
+         """
+         try:
+             from backend.app.database.models import KnowledgeTile
+             from backend.app.database.session import SessionLocal  # lazy import
+
+             db = SessionLocal()
+             tile_id = f"ai_ktile_{uuid.uuid4().hex}"
+
+             # Create the AI-generated knowledge tile (SQLite)
+             new_tile = KnowledgeTile(
+                 id=tile_id,
+                 workspace_id="default_workspace",  # default workspace for the local edition
+                 domain_id=domain_id,
+                 topic=prompt[:200],  # use the prompt as the topic
+                 content=response,
+                 confidence_score=confidence,
+                 verification_type="ai",  # marks the tile as AI-generated
+                 verification_count=1,
+                 contributor_id=None,  # no contributor: this tile is AI-authored
+                 last_verified_by_id=None
+             )
+
+             try:
+                 db.add(new_tile)
+                 db.commit()
+             finally:
+                 db.close()  # release the session even if the commit fails
+
+             logger.info(f"Saved AI inference to SQLite: {new_tile.id}")
+
+             # Priority 2: automatic coordinate estimation + .iath persistence
+             try:
+                 # Build an LLM inference function for coordinate estimation
+                 async def llm_inference_for_coords(coord_prompt):
+                     # Estimate coordinates using the master model (or DeepSeek)
+                     estimation_model = self.master_model if self.master_model else self._get_any_available_model()
+
+                     if estimation_model:
+                         result = await self._perform_llm_inference(
+                             estimation_model,
+                             coord_prompt,
+                             temperature=0.3  # low temperature for consistent estimates
+                         )
+                         return result.get("response", "")
+                     # Fallback: neutral 6-dimensional coordinates with low confidence
+                     return "{\"coordinates\": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], \"confidence\": 0.3}"
+
+                 # Estimate the coordinates
+                 coord_result = await self.coordinate_estimator.estimate_coordinates(
+                     prompt=prompt,
+                     response=response,
+                     domain_id=domain_id,
+                     llm_inference_func=llm_inference_for_coords,
+                     use_reasoning=False  # skip the reasoning trace for speed
+                 )
+
+                 # Create the .iath Tile object
+                 iath_tile = create_tile_from_ai_output(
+                     knowledge_id=tile_id,
+                     topic=prompt[:100],
+                     prompt=prompt,
+                     response=response,
+                     coordinates=coord_result["coordinates"],
+                     confidence=confidence,
+                     domain_id=domain_id,
+                     source="ai_generated"
+                 )
+
+                 # Append to the .iath file
+                 success = self.iath_writer.append_tile(iath_tile)
+
+                 if success:
+                     logger.info(f"Saved AI inference to .iath with coordinates: {coord_result['coordinates']}")
+
+                     # Reload the memory so the new tile becomes searchable
+                     self._load_dendritic_memory()
+                 else:
+                     logger.warning("Failed to save to .iath, but SQLite save succeeded")
+
+             except Exception as iath_error:
+                 logger.error(f"Error saving to .iath (SQLite save succeeded): {iath_error}")
+                 # The .iath save failed, but the SQLite save succeeded, so still return True
+
+             return True
+
+         except Exception as e:
+             logger.error(f"Error saving inference to DB: {e}")
+             return False
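+
+     # Design note (as implemented above): SQLite is treated as the source of truth,
+     # while the .iath spatial index is best-effort; a failed .iath append is logged
+     # but does not fail the save, which can leave the index lagging behind SQLite
+     # until the next successful write-and-reload.
+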
+     async def _save_master_output_as_training_data(self, prompt: str, master_response: str, domain_id: str, confidence: float) -> bool:
+         """
+         Save the master model's output as teacher data for fine-tuning the apprentice.
+         The core mechanism of the toboku (fallen tree) system.
+         """
+         try:
+             import json
+             import os
+             from datetime import datetime, timezone
+
+             logger.info(f"[Save Training Data] Starting to save master output (domain={domain_id}, confidence={confidence})")
+
+             # Directory where training data is stored
+             training_data_dir = "training_data/master_outputs"
+             logger.info(f"[Save Training Data] Creating directory: {training_data_dir}")
+             os.makedirs(training_data_dir, exist_ok=True)
+
+             # Store the example in Alpaca format
+             training_example = {
+                 "instruction": f"You are an expert in {domain_id}. Provide accurate information based on verified knowledge.",
+                 "input": prompt,
+                 "output": master_response,
+                 "metadata": {
+                     "domain_id": domain_id,
+                     "confidence": confidence,
+                     "master_model_id": self.master_model.model_id if self.master_model else "unknown",
+                     "timestamp": datetime.now(timezone.utc).isoformat(),  # datetime.utcnow() is deprecated since Python 3.12
+                     "source": "master_output"
+                 }
+             }
+
+             # Append to the JSONL file
+             output_file = os.path.join(training_data_dir, f"master_outputs_{domain_id}.jsonl")
+             logger.info(f"[Save Training Data] Writing to file: {output_file}")
+             with open(output_file, 'a', encoding='utf-8') as f:
+                 f.write(json.dumps(training_example, ensure_ascii=False) + '\n')
+
+             logger.info(f"✓ Successfully saved master output as training data: {output_file}")
+             return True
+
+         except Exception as e:
+             logger.error(f"✗ Error saving master output as training data: {e}", exc_info=True)
+             return False
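+
+     # Example of a single record appended to master_outputs_<domain_id>.jsonl
+     # (illustrative values; the field layout matches training_example above):
+     #
+     #     {"instruction": "You are an expert in physics. Provide accurate information
+     #      based on verified knowledge.", "input": "What is entropy?", "output": "...",
+     #      "metadata": {"domain_id": "physics", "confidence": 0.8,
+     #                   "master_model_id": "deepseek-chat",
+     #                   "timestamp": "2025-01-01T00:00:00+00:00",
+     #                   "source": "master_output"}}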