|
|
""" |
|
|
NullAI Auto-Training Manager |
|
|
|
|
|
自動学習システムの核となるモジュール。 |
|
|
データ量や時間ベースのトリガーで自動的にファインチューニングを実行する。 |
|
|
""" |
|
|
import asyncio |
|
|
import json |
|
|
import logging |
|
|
from datetime import datetime, timedelta |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, Optional, List |
|
|
from dataclasses import dataclass, asdict |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class AutoTrainingState:
    """Runtime state of the auto-training system.

    Every field has a default, so ``AutoTrainingState()`` yields a fresh
    state. All values are JSON-serialisable (bool, int, str or None), which
    lets the whole record be dumped with ``dataclasses.asdict``.
    Timestamps are stored as ISO-8601 strings.
    """

    # Whether automatic training is switched on.
    enabled: bool = True

    # When a training run was last started (ISO-8601), or None.
    last_check_time: Optional[str] = None

    # When the last training run finished (ISO-8601), or None if never run.
    last_training_time: Optional[str] = None

    # Outcome of the most recent training run.
    last_training_success: bool = True

    # Number of high-quality examples available at the last training run.
    last_training_examples_count: int = 0

    # Planned time of the next run (ISO-8601), or None when nothing is planned.
    next_scheduled_training: Optional[str] = None

    # How many automatic training runs have completed so far.
    total_auto_trainings: int = 0

    # Failed runs in a row since the last successful one.
    consecutive_failures: int = 0

    # True while a training run is in progress.
    is_training: bool = False

    # Message of the most recent failure, or None.
    last_error: Optional[str] = None
|
|
|
|
|
|
|
|
class AutoTrainingManager:
    """
    Automatic training manager.

    Watches the collected training data and, when the configured trigger
    conditions are met, runs a fine-tuning job automatically.

    Supported ``trigger_mode`` values:
        * ``"data_count"``   - enough high-quality examples exist.
        * ``"time_based"``   - enough days have passed since the last run
          (or no run has happened yet).
        * ``"hybrid"``       - both of the above must hold.
        * ``"max_interval"`` - like ``hybrid``, but a run is also forced once
          ``max_days`` have passed since the last run.
    """

    def __init__(self, config: Dict[str, Any], training_manager):
        """
        Args:
            config: The ``auto_training`` section of null_ai_config.json.
            training_manager: FineTuningManager instance.
        """
        self.config = config
        self.training_manager = training_manager
        self.state = AutoTrainingState()
        # Relative path: resolved against the process working directory.
        self.state_file = Path("training_data/auto_training_state.json")

        self.enabled = config.get("enabled", True)
        self._apply_config()

        self._load_state()

        logger.info(
            "AutoTrainingManager initialized: enabled=%s, trigger_mode=%s",
            self.enabled,
            self.trigger_mode,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _apply_config(self) -> None:
        """Copy every tunable setting from ``self.config`` onto attributes.

        ``enabled`` is deliberately not handled here: it can also be changed
        through :meth:`enable` / :meth:`disable`, so callers decide when the
        config value should win.
        """
        cfg = self.config

        # Trigger settings
        self.trigger_mode = cfg.get("trigger_mode", "hybrid")
        self.min_examples = cfg.get("min_examples", 100)
        self.min_days = cfg.get("min_days_since_last_training", 7)
        self.max_days = cfg.get("max_days_since_last_training", 30)
        self.quality_threshold = cfg.get("quality_threshold", 0.8)
        self.check_interval_minutes = cfg.get("check_interval_minutes", 60)
        self.preferred_hour = cfg.get("preferred_training_hour", 2)
        self.allow_manual_override = cfg.get("allow_manual_override", True)

        # Training settings
        self.training_method = cfg.get("training_method", "peft")
        self.training_params = cfg.get("training_params", {})

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Equivalent to the deprecated ``datetime.utcnow()``; the value stays
        naive so it remains comparable with timestamps already persisted.
        """
        from datetime import timezone  # local import: not among the module-level imports

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _days_since_last_training(self) -> Optional[int]:
        """Whole days elapsed since the last training, or None if unknown.

        None is returned when no training has been recorded or when the
        stored timestamp cannot be parsed.
        """
        from datetime import timezone  # local import: not among the module-level imports

        raw = self.state.last_training_time
        if not raw:
            return None
        try:
            last_training = datetime.fromisoformat(raw)
        except (TypeError, ValueError):
            return None
        if last_training.tzinfo is not None:
            # Normalise aware timestamps to naive UTC; subtracting an aware
            # value from a naive one would raise TypeError.
            last_training = last_training.astimezone(timezone.utc).replace(tzinfo=None)
        return (self._utc_now() - last_training).days

    def _record_failure(self, message: str) -> None:
        """Update the state after a failed training attempt."""
        self.state.last_training_success = False
        self.state.consecutive_failures += 1
        self.state.last_error = message

    def _load_state(self):
        """Load the persisted state, if a state file exists.

        Loading is best-effort: on any error a warning is logged and the
        current (default) state is kept.
        """
        from dataclasses import fields  # local import: not among the module-level imports

        try:
            if not self.state_file.exists():
                return

            with open(self.state_file, 'r', encoding='utf-8') as f:
                state_dict = json.load(f)

            # Ignore keys this version does not know, so that one unexpected
            # key does not throw away the entire saved state.
            known = {spec.name for spec in fields(AutoTrainingState)}
            self.state = AutoTrainingState(
                **{key: value for key, value in state_dict.items() if key in known}
            )
            # A freshly constructed manager has no run in flight. A persisted
            # True can only be left over from an interrupted process and would
            # otherwise block every future trigger.
            self.state.is_training = False
            logger.info("Loaded auto-training state from %s", self.state_file)
        except Exception as e:  # best-effort boundary: never fail construction
            logger.warning("Failed to load auto-training state: %s", e)

    def _save_state(self):
        """Persist the current state as JSON (best-effort; errors are logged)."""
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.state_file, 'w', encoding='utf-8') as f:
                json.dump(asdict(self.state), f, indent=2)
        except Exception as e:  # best-effort boundary: persistence must not crash callers
            logger.error("Failed to save auto-training state: %s", e)

    # ------------------------------------------------------------------
    # Data inspection
    # ------------------------------------------------------------------

    def get_training_data_stats(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Collect statistics about the available training data.

        Reads ``training_data/master_outputs/master_outputs_<domain>.jsonl``
        (one JSON object per line). Lines that are not valid JSON, or whose
        JSON value is not an object, are skipped.

        Args:
            domain_id: When given, only that domain's file is read;
                otherwise every ``master_outputs_*.jsonl`` file is read.

        Returns:
            {
                "total_examples": int,
                "examples_by_domain": Dict[str, int],
                "high_quality_count": int,   # confidence >= quality_threshold
                "oldest_timestamp": str | None,
                "newest_timestamp": str | None
            }
        """
        stats: Dict[str, Any] = {
            "total_examples": 0,
            "examples_by_domain": {},
            "high_quality_count": 0,
            "oldest_timestamp": None,
            "newest_timestamp": None
        }

        training_data_dir = Path("training_data/master_outputs")
        if not training_data_dir.exists():
            return stats

        prefix = "master_outputs_"
        if domain_id:
            jsonl_files = [training_data_dir / f"{prefix}{domain_id}.jsonl"]
        else:
            # Sorted so the result does not depend on directory listing order.
            jsonl_files = sorted(training_data_dir.glob(f"{prefix}*.jsonl"))

        for jsonl_file in jsonl_files:
            if not jsonl_file.exists():
                continue

            domain = jsonl_file.stem.removeprefix(prefix)
            domain_count = 0

            with open(jsonl_file, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        example = json.loads(line.strip())
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(example, dict):
                        continue

                    stats["total_examples"] += 1
                    domain_count += 1

                    # Tolerate a missing or null "metadata" entry.
                    metadata = example.get("metadata")
                    if not isinstance(metadata, dict):
                        metadata = {}

                    # Quality check: a non-numeric confidence counts as low quality.
                    confidence = metadata.get("confidence", 0)
                    if isinstance(confidence, (int, float)) and confidence >= self.quality_threshold:
                        stats["high_quality_count"] += 1

                    # Timestamps are compared as strings, which orders ISO-8601
                    # values of the same format chronologically.
                    timestamp = metadata.get("timestamp")
                    if isinstance(timestamp, str) and timestamp:
                        if stats["oldest_timestamp"] is None or timestamp < stats["oldest_timestamp"]:
                            stats["oldest_timestamp"] = timestamp
                        if stats["newest_timestamp"] is None or timestamp > stats["newest_timestamp"]:
                            stats["newest_timestamp"] = timestamp

            if domain_count > 0:
                stats["examples_by_domain"][domain] = domain_count

        return stats

    # ------------------------------------------------------------------
    # Trigger evaluation
    # ------------------------------------------------------------------

    def check_training_trigger(
        self,
        domain_id: Optional[str] = None,
        stats: Optional[Dict[str, Any]] = None,
    ) -> tuple[bool, str]:
        """
        Decide whether a training run should be triggered.

        Args:
            domain_id: Restrict the data statistics to this domain.
            stats: Statistics already obtained from
                :meth:`get_training_data_stats`. When omitted they are
                computed here; when given, ``domain_id`` is not used.

        Returns:
            (should_trigger: bool, reason: str)
        """
        if not self.enabled:
            return False, "Auto-training is disabled"

        if self.state.is_training:
            return False, "Training is already in progress"

        if stats is None:
            stats = self.get_training_data_stats(domain_id)

        if stats["total_examples"] == 0:
            return False, "No training data available"

        high_quality = stats["high_quality_count"]
        days_since_last = self._days_since_last_training()

        if self.trigger_mode == "data_count":
            if high_quality >= self.min_examples:
                return True, f"Sufficient training data ({high_quality} examples >= {self.min_examples})"
            return False, f"Insufficient training data ({high_quality} < {self.min_examples})"

        if self.trigger_mode == "time_based":
            if days_since_last is None:
                return True, "First auto-training"
            if days_since_last >= self.min_days:
                return True, f"Time threshold met ({days_since_last} days >= {self.min_days} days)"
            return False, f"Too soon since last training ({days_since_last} < {self.min_days} days)"

        if self.trigger_mode == "hybrid":
            if high_quality < self.min_examples:
                return False, f"Insufficient training data ({high_quality} < {self.min_examples})"
            if days_since_last is None:
                return True, f"First auto-training with {high_quality} examples"
            if days_since_last >= self.min_days:
                return True, f"Both conditions met: {high_quality} examples, {days_since_last} days since last training"
            return False, f"Time condition not met ({days_since_last} < {self.min_days} days)"

        if self.trigger_mode == "max_interval":
            # Force a run once the maximum interval has elapsed.
            if days_since_last is not None and days_since_last >= self.max_days:
                return True, f"Maximum interval reached ({days_since_last} >= {self.max_days} days)"

            # Otherwise fall back to the regular data + time conditions.
            if high_quality >= self.min_examples and (days_since_last is None or days_since_last >= self.min_days):
                return True, f"Standard conditions met: {high_quality} examples"

            return False, "Conditions not met"

        return False, f"Unknown trigger mode: {self.trigger_mode}"

    def should_train_now(self) -> bool:
        """
        Check whether the current time is suitable for training.

        The recommended window is ``preferred_training_hour`` plus or minus
        one hour (UTC), wrapping around midnight.
        """
        current_hour = self._utc_now().hour
        target_hours = [
            (self.preferred_hour - 1) % 24,
            self.preferred_hour,
            (self.preferred_hour + 1) % 24
        ]
        return current_hour in target_hours

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    async def trigger_auto_training(self, domain_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Run an automatic training job.

        Args:
            domain_id: Domain to train on, or None for all domains.

        Returns:
            On success: ``{"success": True, "result", "stats", "timestamp"}``.
            On failure: ``{"success": False, "error", "consecutive_failures"}``,
            plus ``"result"`` when the training itself reported the failure.
        """
        if self.state.is_training:
            # Refuse to overlap runs: a second run's cleanup would clear the
            # flag that belongs to the first one.
            logger.warning("Auto-training requested while a run is already in progress")
            return {
                "success": False,
                "error": "Training is already in progress",
                "consecutive_failures": self.state.consecutive_failures
            }

        logger.info("Starting auto-training for domain: %s", domain_id or 'all')

        # Mark the run as started and persist it right away.
        self.state.is_training = True
        self.state.last_check_time = self._utc_now().isoformat()
        self._save_state()

        try:
            stats = self.get_training_data_stats(domain_id)
            result = await self._execute_training(domain_id, stats)

            if not result.get("success", False):
                # The training returned normally but reported a failure:
                # treat it like any other failed attempt.
                error = str(result.get("error") or "Training reported failure")
                logger.error("Auto-training failed: %s", error)
                self._record_failure(error)
                return {
                    "success": False,
                    "error": error,
                    "result": result,
                    "consecutive_failures": self.state.consecutive_failures
                }

            self.state.last_training_time = self._utc_now().isoformat()
            self.state.last_training_success = True
            self.state.last_training_examples_count = stats["high_quality_count"]
            self.state.total_auto_trainings += 1
            self.state.consecutive_failures = 0
            self.state.last_error = None

            logger.info("Auto-training completed successfully: %s", result)

            return {
                "success": True,
                "result": result,
                "stats": stats,
                "timestamp": self.state.last_training_time
            }

        except Exception as e:
            logger.error("Auto-training failed: %s", e, exc_info=True)
            self._record_failure(str(e))
            return {
                "success": False,
                "error": str(e),
                "consecutive_failures": self.state.consecutive_failures
            }

        finally:
            # Always clear the in-progress flag and persist the outcome.
            self.state.is_training = False
            self._save_state()

    async def _execute_training(self, domain_id: Optional[str], stats: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the actual training (internal method).

        NOTE(review): this is currently a placeholder. It assembles the
        training parameters and returns a fixed, simulated result;
        ``self.training_manager`` is not called. Wire in the real call once
        its API is confirmed.

        Args:
            domain_id: Domain to train on, or None.
            stats: Output of :meth:`get_training_data_stats`.

        Returns:
            Result dict with ``success``, ``output_dir``, ``model_name``,
            ``train_loss``, ``method`` and ``examples_used``.
        """
        params = self.training_params
        training_params = {
            "apprentice_model_name": None,
            "domain_id": domain_id,
            "method": self.training_method,
            "epochs": params.get("epochs", 3),
            "learning_rate": params.get("learning_rate", 2e-4),
            "batch_size": params.get("batch_size", 4),
            "lora_r": params.get("lora_r", 8),
            "lora_alpha": params.get("lora_alpha", 16),
            "output_name": f"auto_training_{self._utc_now().strftime('%Y%m%d_%H%M%S')}"
        }

        logger.info("Executing training with params: %s", training_params)

        # Simulated outcome (see NOTE in the docstring).
        result = {
            "success": True,
            "output_dir": f"training_data/checkpoints/{training_params['output_name']}",
            "model_name": training_params['output_name'],
            "train_loss": 0.5,
            "method": self.training_method,
            "examples_used": stats["high_quality_count"]
        }

        return result

    # ------------------------------------------------------------------
    # Status and control
    # ------------------------------------------------------------------

    def get_status(self) -> Dict[str, Any]:
        """
        Return the current status of the auto-training system.

        The result combines the trigger decision, the active configuration,
        the persisted state and the data statistics.
        """
        # Scan the data files once and reuse the result for the trigger check.
        stats = self.get_training_data_stats()
        should_trigger, reason = self.check_training_trigger(stats=stats)

        return {
            "enabled": self.enabled,
            "is_training": self.state.is_training,
            "trigger_mode": self.trigger_mode,
            "should_trigger": should_trigger,
            "trigger_reason": reason,
            "config": {
                "min_examples": self.min_examples,
                "min_days": self.min_days,
                "max_days": self.max_days,
                "quality_threshold": self.quality_threshold,
                "check_interval_minutes": self.check_interval_minutes,
                "preferred_hour": self.preferred_hour
            },
            "state": {
                "last_check_time": self.state.last_check_time,
                "last_training_time": self.state.last_training_time,
                "last_training_success": self.state.last_training_success,
                "last_training_examples_count": self.state.last_training_examples_count,
                "total_auto_trainings": self.state.total_auto_trainings,
                "consecutive_failures": self.state.consecutive_failures,
                "last_error": self.state.last_error
            },
            "data_stats": stats,
            "should_train_now": self.should_train_now()
        }

    def enable(self):
        """Enable auto-training and persist the change."""
        self.enabled = True
        self.state.enabled = True
        self._save_state()
        logger.info("Auto-training enabled")

    def disable(self):
        """Disable auto-training and persist the change."""
        self.enabled = False
        self.state.enabled = False
        self._save_state()
        logger.info("Auto-training disabled")

    def update_config(self, new_config: Dict[str, Any]):
        """
        Merge ``new_config`` into the configuration and apply it.

        Every config-derived attribute is refreshed, so the values reported
        by :meth:`get_status` always match ``self.config``.

        Args:
            new_config: Keys to add or override in the current config.
        """
        self.config.update(new_config)
        self._apply_config()

        # Only touch the on/off switch when the caller asked for it, so an
        # unrelated update does not undo an earlier enable()/disable().
        if "enabled" in new_config:
            self.enabled = new_config["enabled"]
            self.state.enabled = self.enabled

        logger.info("Auto-training config updated: %s", new_config)
        self._save_state()
|
|
|