Spaces:
Sleeping
Sleeping
| """ | |
| NullAI Fine-tuning API | |
| API endpoints for managing apprentice model fine-tuning using master outputs. | |
| """ | |
| from fastapi import APIRouter, HTTPException, BackgroundTasks | |
| from pydantic import BaseModel, Field | |
| from typing import Dict, Any, Optional, List | |
| import logging | |
| from datetime import datetime | |
| from null_ai.fine_tuning import FineTuningManager | |
| from null_ai.auto_training import AutoTrainingManager | |
| from backend.app.config import settings | |
| import asyncio | |
# API router and logger for this module.
router = APIRouter()
logger = logging.getLogger(__name__)

# Module-wide fine-tuning manager shared by the endpoints below.
fine_tuning_manager = FineTuningManager()

# Module-wide auto-training manager.
# Turn the settings object into a plain dict so the "auto_training" section can
# be looked up; model_dump() is used when available, .dict() otherwise.
if hasattr(settings, "model_dump"):
    settings_dict = settings.model_dump()
else:
    settings_dict = settings.dict()
auto_training_config = settings_dict.get("auto_training", {})
auto_training_manager = AutoTrainingManager(auto_training_config, fine_tuning_manager)

# Handle of the background auto-training monitor task (None until started).
_monitoring_task: Optional[asyncio.Task] = None
| # ===== Pydantic Models ===== | |
class StartTrainingRequest(BaseModel):
    """Request body describing a fine-tuning run for an apprentice model."""

    # --- What to train and on which data ---
    apprentice_model_name: str = Field(
        ...,
        description="HuggingFace model name or path to fine-tune",
    )
    domain_id: Optional[str] = Field(
        None,
        description="Domain to train on (None = all domains)",
    )
    method: str = Field(
        "peft",
        description="Training method: 'peft', 'unsloth', or 'mlx'",
    )

    # --- Optimisation hyper-parameters (bounds are enforced by pydantic) ---
    epochs: int = Field(3, ge=1, le=100, description="Number of training epochs")
    learning_rate: float = Field(2e-4, gt=0, description="Learning rate")
    batch_size: int = Field(4, ge=1, le=32, description="Batch size per device")

    # --- LoRA adapter settings ---
    lora_r: int = Field(8, ge=4, le=64, description="LoRA rank")
    lora_alpha: int = Field(16, ge=8, le=128, description="LoRA alpha")

    # --- Output ---
    output_name: Optional[str] = Field(
        None,
        description="Custom name for output checkpoint",
    )
class TrainingStatusResponse(BaseModel):
    """Snapshot of the current fine-tuning run as returned by the status endpoint."""

    is_training: bool
    progress: float = Field(..., ge=0, le=100, description="Training progress percentage")
    current_epoch: int
    total_epochs: int
    loss: float
    # Explicit None defaults: a bare `Optional[str]` is implicitly optional in
    # pydantic v1 but *required* in pydantic v2. This module supports both
    # versions, so the default is spelled out to behave the same on either.
    model_id: Optional[str] = None
    start_time: Optional[str] = None
    estimated_time_remaining: Optional[str] = None
class TrainingResultResponse(BaseModel):
    """Outcome of a training request; every field except `success` is optional."""

    # Whether the request/run is reported as successful.
    success: bool

    # Details of the run.
    output_dir: Optional[str] = None
    model_name: Optional[str] = None
    train_loss: Optional[float] = None
    method: Optional[str] = None

    # Error message, if any.
    error: Optional[str] = None

    # Free-form metrics mapping.
    metrics: Optional[Dict[str, Any]] = None
class TrainingDataStatsResponse(BaseModel):
    """Summary of the training data files found on disk."""

    # Sum of the per-domain example counts.
    total_examples: int
    # Example count keyed by domain name.
    examples_by_domain: Dict[str, int]
    # Domain names for which a data file was found.
    domains: List[str]
    # Paths of the data files that were counted.
    file_paths: List[str]
class TrainingMetricsResponse(BaseModel):
    """Metrics read back from a training checkpoint."""

    # Recorded training log entries (empty when none are available).
    log_history: List[Dict[str, Any]] = []

    # Best metric value and the checkpoint it belongs to, when reported.
    best_metric: Optional[float] = None
    best_model_checkpoint: Optional[str] = None

    # Populated instead of the fields above when the metrics could not be loaded.
    error: Optional[str] = None
| # ===== API Endpoints ===== | |
async def start_training(request: StartTrainingRequest, background_tasks: BackgroundTasks):
    """
    Start fine-tuning an apprentice model.

    The training itself is scheduled as a background task; this endpoint
    returns as soon as the request has been validated. Use the status
    endpoint to monitor progress.

    **Supported Methods:**
    - `peft`: HuggingFace PEFT with QLoRA (recommended, most compatible)
    - `unsloth`: Unsloth fast training (2x faster, Llama/Mistral/Qwen models)
    - `mlx`: MLX training (Apple Silicon only, experimental)

    **Training Data:**
    - Automatically loads master outputs from `training_data/master_outputs/`
    - Format: Alpaca-style JSONL (instruction-input-output)
    - High-quality outputs only (confidence >= 0.8)

    **Example:**
    ```json
    {
        "apprentice_model_name": "microsoft/phi-2",
        "domain_id": "medical",
        "method": "peft",
        "epochs": 3,
        "learning_rate": 2e-4,
        "batch_size": 4
    }
    ```

    Raises:
        HTTPException 409: a training run is already in progress.
        HTTPException 400: `method` is not one of the supported values.
        HTTPException 404: no training data exists for the requested domain.
    """
    # Dump the request the same way this module dumps `settings`:
    # pydantic v2 provides model_dump(), while .dict() is the v1 spelling
    # (deprecated in v2).
    if hasattr(request, "model_dump"):
        request_payload = request.model_dump()
    else:
        request_payload = request.dict()
    logger.info(f"Received training request: {request_payload}")

    # Check if already training
    if fine_tuning_manager.current_training_state["is_training"]:
        raise HTTPException(
            status_code=409,
            detail="Training is already in progress. Please wait or stop the current training."
        )

    # Validate method
    if request.method not in ("peft", "unsloth", "mlx"):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid training method: {request.method}. Must be 'peft', 'unsloth', or 'mlx'"
        )

    # Check if training data exists
    training_examples = fine_tuning_manager.load_training_data(request.domain_id)
    if not training_examples:
        raise HTTPException(
            status_code=404,
            detail=f"No training data found for domain: {request.domain_id or 'all'}"
        )
    logger.info(f"Found {len(training_examples)} training examples")

    # Start training in background
    async def run_training():
        try:
            # NOTE(review): request.lora_r and request.lora_alpha are validated
            # by the request model but are not forwarded here, so they currently
            # have no effect. Confirm whether FineTuningManager.start_training
            # accepts them before wiring them through.
            result = await fine_tuning_manager.start_training(
                apprentice_model_name=request.apprentice_model_name,
                domain_id=request.domain_id,
                method=request.method,
                epochs=request.epochs,
                learning_rate=request.learning_rate,
                batch_size=request.batch_size,
                output_name=request.output_name,
                progress_callback=None  # TODO: Implement WebSocket for real-time updates
            )
            logger.info(f"Training completed: {result}")
        except Exception as e:
            logger.error(f"Training failed: {e}", exc_info=True)
            # Clear the in-progress flag so a failed run does not block later
            # requests with a 409.
            fine_tuning_manager.current_training_state.update({
                "is_training": False,
                "error": str(e)
            })

    background_tasks.add_task(run_training)

    # `success=True` here only means the run was scheduled.
    return TrainingResultResponse(
        success=True,
        model_name=request.apprentice_model_name,
        method=request.method,
        output_dir=None  # Will be determined during training
    )
async def get_training_status():
    """
    Report the status of the current training run.

    The state dictionary returned by the fine-tuning manager is passed
    straight into `TrainingStatusResponse`, which carries:
    - Progress percentage (0-100)
    - Current epoch and total epochs
    - Current loss value
    - Model being trained
    - Start time

    **Example Response:**
    ```json
    {
        "is_training": true,
        "progress": 45.5,
        "current_epoch": 1,
        "total_epochs": 3,
        "loss": 0.234,
        "model_id": "microsoft/phi-2",
        "start_time": "2025-12-02T10:30:00"
    }
    ```
    """
    return TrainingStatusResponse(**fine_tuning_manager.get_training_status())
async def stop_training():
    """
    Request that the running training job be stopped.

    The stop is only requested from the fine-tuning manager; depending on the
    training backend it may not take effect immediately.

    Raises:
        HTTPException 400: no training is currently in progress.
    """
    training_active = fine_tuning_manager.current_training_state["is_training"]
    if not training_active:
        raise HTTPException(
            status_code=400,
            detail="No training is currently in progress"
        )

    fine_tuning_manager.stop_training()
    logger.info("Training stop requested")

    response = {
        "success": True,
        "message": "Training stop requested. This may take a moment..."
    }
    return response
async def get_training_data_stats():
    """
    Get statistics about available training data.

    Scans `training_data/master_outputs` for files named
    `master_outputs_{domain}.jsonl` and counts, per file, the lines that parse
    as JSON (blank or malformed lines are skipped).

    Returns:
    - Total number of training examples
    - Number of examples per domain
    - List of available domains
    - File paths to training data

    All values are empty/zero when the directory does not exist.

    **Example Response:**
    ```json
    {
        "total_examples": 150,
        "examples_by_domain": {
            "medical": 50,
            "general": 100
        },
        "domains": ["medical", "general"],
        "file_paths": [
            "training_data/master_outputs/master_outputs_medical.jsonl",
            "training_data/master_outputs/master_outputs_general.jsonl"
        ]
    }
    ```
    """
    from pathlib import Path
    import json

    file_prefix = "master_outputs_"
    training_data_dir = Path("training_data/master_outputs")

    if not training_data_dir.exists():
        return TrainingDataStatsResponse(
            total_examples=0,
            examples_by_domain={},
            domains=[],
            file_paths=[]
        )

    examples_by_domain = {}
    file_paths = []

    # sorted(): glob order depends on the filesystem; sorting keeps the
    # response stable between calls.
    for jsonl_file in sorted(training_data_dir.glob(f"{file_prefix}*.jsonl")):
        file_paths.append(str(jsonl_file))

        # Extract domain from filename: master_outputs_{domain}.jsonl
        # Only the leading prefix is stripped (the glob guarantees it is there);
        # str.replace() would also remove the text from inside the domain name.
        domain = jsonl_file.stem[len(file_prefix):]

        # Count the lines that are valid JSON.
        count = 0
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                count += 1
        examples_by_domain[domain] = count

    return TrainingDataStatsResponse(
        total_examples=sum(examples_by_domain.values()),
        examples_by_domain=examples_by_domain,
        domains=list(examples_by_domain),
        file_paths=file_paths
    )
async def get_training_metrics(checkpoint_name: str):
    """
    Get training metrics from a specific checkpoint.

    Looks up `training_data/checkpoints/{checkpoint_name}` and asks the
    fine-tuning manager for its metrics, which may include:
    - Loss values over time
    - Learning rate schedule
    - Best metric achieved
    - Best model checkpoint path

    If the manager reports an `error` key, only that error is returned.

    **Example:**
    ```
    GET /api/training/metrics/apprentice_medical_20251202_103000
    ```

    Raises:
        HTTPException 400: `checkpoint_name` is not a plain directory name.
        HTTPException 404: the checkpoint directory does not exist.
    """
    from pathlib import Path

    # Reject anything that is not a single path component. Without this,
    # "..", nested paths or an absolute path would make the lookup leave
    # the checkpoints directory.
    if checkpoint_name in ("", ".", "..") or Path(checkpoint_name).name != checkpoint_name:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid checkpoint name: {checkpoint_name}"
        )

    checkpoint_dir = Path("training_data/checkpoints") / checkpoint_name
    if not checkpoint_dir.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Checkpoint not found: {checkpoint_name}"
        )

    metrics = fine_tuning_manager.get_training_metrics(str(checkpoint_dir))
    if "error" in metrics:
        return TrainingMetricsResponse(error=metrics["error"])
    return TrainingMetricsResponse(**metrics)
async def list_checkpoints():
    """
    List all available training checkpoints.

    Every sub-directory of `training_data/checkpoints` is reported with:
    - Checkpoint name
    - Timestamp taken from the directory's `st_ctime` (ISO format)
    - Model name (`_name_or_path` from `config.json`, if readable)
    - Size on disk in MB (rounded to 2 decimals)

    The list is sorted newest first; it is empty when the checkpoints
    directory does not exist.

    **Example Response:**
    ```json
    [
        {
            "name": "apprentice_medical_20251202_103000",
            "created_at": "2025-12-02T10:30:00",
            "size_mb": 256.5,
            "model_name": "microsoft/phi-2"
        }
    ]
    ```
    """
    from pathlib import Path
    import json

    checkpoints_dir = Path("training_data/checkpoints")
    if not checkpoints_dir.exists():
        return []

    checkpoints = []
    for checkpoint_dir in checkpoints_dir.iterdir():
        if not checkpoint_dir.is_dir():
            continue

        # Total size of every regular file below the checkpoint directory.
        total_size = sum(
            f.stat().st_size for f in checkpoint_dir.rglob('*') if f.is_file()
        )
        size_mb = total_size / (1024 * 1024)

        # NOTE: st_ctime is the creation time on Windows but the last
        # metadata-change time on Unix.
        created_at = datetime.fromtimestamp(checkpoint_dir.stat().st_ctime).isoformat()

        # Best effort: the model name is optional metadata, so an unreadable
        # or malformed config.json just leaves it as None.
        model_name = None
        config_file = checkpoint_dir / "config.json"
        if config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            except (OSError, ValueError):
                # ValueError covers both JSON and text-decoding errors.
                config = None
            if isinstance(config, dict):
                model_name = config.get("_name_or_path")

        checkpoints.append({
            "name": checkpoint_dir.name,
            "created_at": created_at,
            "size_mb": round(size_mb, 2),
            "model_name": model_name
        })

    # Sort by creation time (newest first); ISO timestamps sort lexicographically.
    checkpoints.sort(key=lambda x: x["created_at"], reverse=True)
    return checkpoints
async def delete_checkpoint(checkpoint_name: str):
    """
    Delete a training checkpoint.

    Removes the directory `training_data/checkpoints/{checkpoint_name}` and
    everything below it.

    **Warning:** This operation is irreversible.

    **Example:**
    ```
    DELETE /api/training/checkpoints/apprentice_medical_20251202_103000
    ```

    Raises:
        HTTPException 400: `checkpoint_name` is not a plain directory name.
        HTTPException 404: the checkpoint directory does not exist.
        HTTPException 409: training is running and its model id contains
            this checkpoint name.
        HTTPException 500: the directory could not be removed.
    """
    from pathlib import Path
    import shutil

    # Reject anything that is not a single path component. Without this,
    # ".." would resolve to the parent `training_data` directory (and an
    # absolute path to any directory), which rmtree would then wipe.
    if checkpoint_name in ("", ".", "..") or Path(checkpoint_name).name != checkpoint_name:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid checkpoint name: {checkpoint_name}"
        )

    checkpoint_dir = Path("training_data/checkpoints") / checkpoint_name
    if not checkpoint_dir.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Checkpoint not found: {checkpoint_name}"
        )

    # Safety check: don't delete if training is using this checkpoint
    if fine_tuning_manager.current_training_state["is_training"]:
        # `or ""` guards against a model_id that is present but None, which
        # would otherwise make the substring test raise TypeError.
        current_model = fine_tuning_manager.current_training_state.get("model_id") or ""
        if checkpoint_name in current_model:
            raise HTTPException(
                status_code=409,
                detail="Cannot delete checkpoint while it's being used for training"
            )

    try:
        shutil.rmtree(checkpoint_dir)
        logger.info(f"Deleted checkpoint: {checkpoint_name}")
        return {
            "success": True,
            "message": f"Checkpoint '{checkpoint_name}' deleted successfully"
        }
    except Exception as e:
        logger.error(f"Failed to delete checkpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete checkpoint: {str(e)}"
        )
| # ===== Auto-Training Endpoints ===== | |
async def get_auto_training_status():
    """
    Return the current state of the auto-training system.

    Returns:
        Whatever `auto_training_manager.get_status()` reports
        (enabled, is_training, trigger conditions, data stats, etc.).

    Raises:
        HTTPException 500: the status could not be retrieved.
    """
    try:
        return auto_training_manager.get_status()
    except Exception as exc:
        logger.error(f"Failed to get auto-training status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def enable_auto_training():
    """Enable auto-training; any failure is reported as HTTP 500."""
    try:
        auto_training_manager.enable()
    except Exception as exc:
        logger.error(f"Failed to enable auto-training: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"success": True, "message": "Auto-training enabled"}
async def disable_auto_training():
    """Disable auto-training; any failure is reported as HTTP 500."""
    try:
        auto_training_manager.disable()
    except Exception as exc:
        logger.error(f"Failed to disable auto-training: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"success": True, "message": "Auto-training disabled"}
async def trigger_auto_training_manually(domain_id: Optional[str] = None, background_tasks: BackgroundTasks = None):
    """
    Manually trigger auto-training.

    The trigger conditions are still checked first; when they are not met the
    response has `success: False` together with the reason. Otherwise the
    training is scheduled on `background_tasks` when one is supplied, or
    awaited inline when it is not.

    Args:
        domain_id: Restrict training to this domain when given.

    Raises:
        HTTPException 500: checking or triggering failed.
    """
    try:
        should_trigger, reason = auto_training_manager.check_training_trigger(domain_id)

        # Guard: conditions not met, nothing is started.
        if not should_trigger:
            return {
                "success": False,
                "message": "Training conditions not met",
                "reason": reason
            }

        # No background task runner available: run the training inline and
        # hand back its result.
        if not background_tasks:
            return await auto_training_manager.trigger_auto_training(domain_id)

        # Otherwise schedule the training in the background.
        background_tasks.add_task(auto_training_manager.trigger_auto_training, domain_id)
        return {
            "success": True,
            "message": "Auto-training triggered",
            "reason": reason
        }
    except Exception as exc:
        logger.error(f"Failed to trigger auto-training: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
async def update_auto_training_config(config: Dict[str, Any]):
    """
    Update the auto-training configuration.

    Args:
        config: New settings (min_examples, min_days, trigger_mode, etc.).

    Returns:
        A success payload that includes the manager's configuration after
        the update.

    Raises:
        HTTPException 500: the update failed.
    """
    try:
        auto_training_manager.update_config(config)
        response = {
            "success": True,
            "message": "Auto-training config updated",
            "new_config": auto_training_manager.config
        }
        return response
    except Exception as exc:
        logger.error(f"Failed to update auto-training config: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
| # ===== Background Monitoring Task ===== | |
async def auto_training_monitor_loop():
    """
    Background loop that periodically checks the auto-training conditions.

    While auto-training is disabled the loop just polls once a minute. When it
    is enabled, each pass checks the trigger conditions, starts a training run
    if they are met and `should_train_now()` agrees, then sleeps for the
    manager's check interval. The loop ends when the task is cancelled; any
    other error is logged and retried after one minute.
    """
    logger.info("Auto-training monitor started")

    while True:
        try:
            if auto_training_manager.enabled:
                # Check interval (minutes)
                interval_minutes = auto_training_manager.check_interval_minutes

                # Evaluate the trigger conditions
                triggered, reason = auto_training_manager.check_training_trigger()
                if triggered:
                    logger.info(f"Auto-training trigger conditions met: {reason}")

                    # Only train inside the preferred time window
                    if auto_training_manager.should_train_now():
                        logger.info("Starting auto-training (preferred time window)")
                        await auto_training_manager.trigger_auto_training()
                    else:
                        logger.info(f"Trigger conditions met but not in preferred time window (hour {auto_training_manager.preferred_hour})")

                # Wait until the next check
                await asyncio.sleep(interval_minutes * 60)
            else:
                # While disabled, re-check once a minute
                await asyncio.sleep(60)
        except asyncio.CancelledError:
            logger.info("Auto-training monitor cancelled")
            return
        except Exception as exc:
            logger.error(f"Error in auto-training monitor: {exc}", exc_info=True)
            # After an error, retry one minute later
            await asyncio.sleep(60)
def start_auto_training_monitor():
    """Start the background monitor task unless one is already running."""
    global _monitoring_task

    already_running = _monitoring_task is not None and not _monitoring_task.done()
    if already_running:
        logger.warning("Auto-training monitor is already running")
        return

    _monitoring_task = asyncio.create_task(auto_training_monitor_loop())
    logger.info("Auto-training background monitor started")
def stop_auto_training_monitor():
    """Cancel the background monitor task if it is still running."""
    global _monitoring_task

    # Nothing to do when the monitor was never started or has already finished.
    if not _monitoring_task or _monitoring_task.done():
        return

    _monitoring_task.cancel()
    logger.info("Auto-training background monitor stopped")
    _monitoring_task = None