"""
NullAI Fine-tuning API

API endpoints for managing apprentice model fine-tuning using master outputs.
"""

import asyncio
import json
import logging
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field

from null_ai.fine_tuning import FineTuningManager
from null_ai.auto_training import AutoTrainingManager
from backend.app.config import settings

router = APIRouter()
logger = logging.getLogger(__name__)

fine_tuning_manager = FineTuningManager()

# Build the auto-training manager from app settings. The hasattr check keeps
# this working on both Pydantic v1 (.dict()) and v2 (.model_dump()) settings objects.
settings_dict = settings.model_dump() if hasattr(settings, "model_dump") else settings.dict()
auto_training_config = settings_dict.get("auto_training", {})
auto_training_manager = AutoTrainingManager(auto_training_config, fine_tuning_manager)

# Handle to the background monitoring task (see start_auto_training_monitor below).
_monitoring_task: Optional[asyncio.Task] = None


class StartTrainingRequest(BaseModel):
    """Request to start fine-tuning an apprentice model."""

    apprentice_model_name: str = Field(..., description="HuggingFace model name or path to fine-tune")
    domain_id: Optional[str] = Field(None, description="Domain to train on (None = all domains)")
    method: str = Field("peft", description="Training method: 'peft', 'unsloth', or 'mlx'")
    epochs: int = Field(3, ge=1, le=100, description="Number of training epochs")
    learning_rate: float = Field(2e-4, gt=0, description="Learning rate")
    batch_size: int = Field(4, ge=1, le=32, description="Batch size per device")
    lora_r: int = Field(8, ge=4, le=64, description="LoRA rank")
    lora_alpha: int = Field(16, ge=8, le=128, description="LoRA alpha")
    output_name: Optional[str] = Field(None, description="Custom name for output checkpoint")


class TrainingStatusResponse(BaseModel):
    """Current training status."""

    is_training: bool
    progress: float = Field(..., ge=0, le=100, description="Training progress percentage")
    current_epoch: int
    total_epochs: int
    loss: float
    model_id: Optional[str]
    start_time: Optional[str]
    estimated_time_remaining: Optional[str] = None


class TrainingResultResponse(BaseModel):
    """Training completion result."""

    success: bool
    output_dir: Optional[str] = None
    model_name: Optional[str] = None
    train_loss: Optional[float] = None
    method: Optional[str] = None
    error: Optional[str] = None
    metrics: Optional[Dict[str, Any]] = None


class TrainingDataStatsResponse(BaseModel):
    """Statistics about available training data."""

    total_examples: int
    examples_by_domain: Dict[str, int]
    domains: List[str]
    file_paths: List[str]


class TrainingMetricsResponse(BaseModel):
    """Training metrics from a checkpoint."""

    log_history: List[Dict[str, Any]] = []
    best_metric: Optional[float] = None
    best_model_checkpoint: Optional[str] = None
    error: Optional[str] = None


@router.post("/start", response_model=TrainingResultResponse)
async def start_training(request: StartTrainingRequest, background_tasks: BackgroundTasks):
    """
    Start fine-tuning an apprentice model.

    This endpoint initiates the fine-tuning process in the background.
    Use the /status endpoint to monitor progress.

    **Supported Methods:**
    - `peft`: HuggingFace PEFT with QLoRA (recommended, most compatible)
    - `unsloth`: Unsloth fast training (2x faster, Llama/Mistral/Qwen models)
    - `mlx`: MLX training (Apple Silicon only, experimental)

    **Training Data:**
    - Automatically loads master outputs from `training_data/master_outputs/`
    - Format: Alpaca-style JSONL (instruction-input-output)
    - High-quality outputs only (confidence >= 0.8)

    **Example:**
    ```json
    {
        "apprentice_model_name": "microsoft/phi-2",
        "domain_id": "medical",
        "method": "peft",
        "epochs": 3,
        "learning_rate": 2e-4,
        "batch_size": 4
    }
    ```
    """
    # Log the request, supporting both Pydantic v1 and v2 model APIs.
    request_dict = request.model_dump() if hasattr(request, "model_dump") else request.dict()
    logger.info(f"Received training request: {request_dict}")

    # Reject concurrent runs: only one training job at a time.
    if fine_tuning_manager.current_training_state["is_training"]:
        raise HTTPException(
            status_code=409,
            detail="Training is already in progress. Please wait or stop the current training."
        )

    # Validate the requested training method.
    if request.method not in ["peft", "unsloth", "mlx"]:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid training method: {request.method}. Must be 'peft', 'unsloth', or 'mlx'"
        )

    # Fail fast if there is no training data for the requested domain.
    training_examples = fine_tuning_manager.load_training_data(request.domain_id)
    if not training_examples:
        raise HTTPException(
            status_code=404,
            detail=f"No training data found for domain: {request.domain_id or 'all'}"
        )

    logger.info(f"Found {len(training_examples)} training examples")

    async def run_training():
        try:
            # Note: the request's lora_r / lora_alpha are not forwarded here;
            # start_training (signature not shown in this module) is assumed
            # to apply its own LoRA defaults.
            result = await fine_tuning_manager.start_training(
                apprentice_model_name=request.apprentice_model_name,
                domain_id=request.domain_id,
                method=request.method,
                epochs=request.epochs,
                learning_rate=request.learning_rate,
                batch_size=request.batch_size,
                output_name=request.output_name,
                progress_callback=None
            )
            logger.info(f"Training completed: {result}")
        except Exception as e:
            logger.error(f"Training failed: {e}", exc_info=True)
            fine_tuning_manager.current_training_state.update({
                "is_training": False,
                "error": str(e)
            })

    background_tasks.add_task(run_training)

    # The run continues in the background; this response only acknowledges the start.
    return TrainingResultResponse(
        success=True,
        model_name=request.apprentice_model_name,
        method=request.method,
        output_dir=None
    )


@router.get("/status", response_model=TrainingStatusResponse)
async def get_training_status():
    """
    Get current training status and progress.

    Returns real-time information about ongoing training:
    - Progress percentage (0-100)
    - Current epoch and total epochs
    - Current loss value
    - Model being trained
    - Start time

    **Example Response:**
    ```json
    {
        "is_training": true,
        "progress": 45.5,
        "current_epoch": 1,
        "total_epochs": 3,
        "loss": 0.234,
        "model_id": "microsoft/phi-2",
        "start_time": "2025-12-02T10:30:00"
    }
    ```
    """
    state = fine_tuning_manager.get_training_status()
    return TrainingStatusResponse(**state)


@router.post("/stop")
async def stop_training():
    """
    Stop the current training process.

    Attempts to gracefully stop the training and save the current checkpoint.
    Note: this may not take effect immediately, depending on the training backend.
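
    **Example Response** (mirrors the dict returned by this endpoint; shown
    here for symmetry with the other endpoints' docstrings):
    ```json
    {
        "success": true,
        "message": "Training stop requested. This may take a moment..."
    }
    ```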
    """
    if not fine_tuning_manager.current_training_state["is_training"]:
        raise HTTPException(
            status_code=400,
            detail="No training is currently in progress"
        )

    fine_tuning_manager.stop_training()
    logger.info("Training stop requested")

    return {
        "success": True,
        "message": "Training stop requested. This may take a moment..."
    }


@router.get("/data/stats", response_model=TrainingDataStatsResponse)
async def get_training_data_stats():
    """
    Get statistics about available training data.

    Returns:
    - Total number of training examples
    - Number of examples per domain
    - List of available domains
    - File paths to training data

    **Example Response:**
    ```json
    {
        "total_examples": 150,
        "examples_by_domain": {
            "medical": 50,
            "general": 100
        },
        "domains": ["medical", "general"],
        "file_paths": [
            "training_data/master_outputs/master_outputs_medical.jsonl",
            "training_data/master_outputs/master_outputs_general.jsonl"
        ]
    }
    ```
    """
    training_data_dir = Path("training_data/master_outputs")

    if not training_data_dir.exists():
        return TrainingDataStatsResponse(
            total_examples=0,
            examples_by_domain={},
            domains=[],
            file_paths=[]
        )

    examples_by_domain = {}
    file_paths = []

    for jsonl_file in training_data_dir.glob("master_outputs_*.jsonl"):
        file_paths.append(str(jsonl_file))

        # Derive the domain from the file name: master_outputs_<domain>.jsonl
        domain = jsonl_file.stem.replace("master_outputs_", "")

        # Count only lines that parse as valid JSON; skip blank or malformed lines.
        count = 0
        with open(jsonl_file, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    json.loads(line.strip())
                    count += 1
                except json.JSONDecodeError:
                    continue

        examples_by_domain[domain] = count

    total_examples = sum(examples_by_domain.values())
    domains = list(examples_by_domain.keys())

    return TrainingDataStatsResponse(
        total_examples=total_examples,
        examples_by_domain=examples_by_domain,
        domains=domains,
        file_paths=file_paths
    )


@router.get("/metrics/{checkpoint_name}", response_model=TrainingMetricsResponse)
async def get_training_metrics(checkpoint_name: str):
    """
    Get training metrics from a specific checkpoint.

    Loads and returns the training history, including:
    - Loss values over time
    - Learning rate schedule
    - Best metric achieved
    - Best model checkpoint path

    **Example:**
    ```
    GET /api/training/metrics/apprentice_medical_20251202_103000
    ```
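
    **Example Response** (values illustrative; `log_history` entries are
    assumed to follow the HuggingFace Trainer `log_history` format):
    ```json
    {
        "log_history": [
            {"epoch": 1.0, "loss": 0.42, "learning_rate": 0.0002}
        ],
        "best_metric": 0.31,
        "best_model_checkpoint": "training_data/checkpoints/apprentice_medical_20251202_103000/checkpoint-500"
    }
    ```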
    """
    checkpoint_dir = Path("training_data/checkpoints") / checkpoint_name

    if not checkpoint_dir.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Checkpoint not found: {checkpoint_name}"
        )

    metrics = fine_tuning_manager.get_training_metrics(str(checkpoint_dir))

    if "error" in metrics:
        return TrainingMetricsResponse(error=metrics["error"])

    return TrainingMetricsResponse(**metrics)


@router.get("/checkpoints", response_model=List[Dict[str, Any]])
async def list_checkpoints():
    """
    List all available training checkpoints.

    Returns a list of checkpoint directories with metadata:
    - Checkpoint name
    - Creation date
    - Model name (if available)
    - Size on disk

    **Example Response:**
    ```json
    [
        {
            "name": "apprentice_medical_20251202_103000",
            "created_at": "2025-12-02T10:30:00",
            "size_mb": 256.5,
            "model_name": "microsoft/phi-2"
        }
    ]
    ```
    """
    checkpoints_dir = Path("training_data/checkpoints")

    if not checkpoints_dir.exists():
        return []

    checkpoints = []

    for checkpoint_dir in checkpoints_dir.iterdir():
        if not checkpoint_dir.is_dir():
            continue

        # Total size of all files under the checkpoint directory, in MB.
        total_size = sum(
            f.stat().st_size for f in checkpoint_dir.rglob('*') if f.is_file()
        )
        size_mb = total_size / (1024 * 1024)

        created_at = datetime.fromtimestamp(checkpoint_dir.stat().st_ctime).isoformat()

        # Recover the base model name from the saved HF config, if present.
        model_name = None
        config_file = checkpoint_dir / "config.json"
        if config_file.exists():
            try:
                with open(config_file, 'r') as f:
                    config = json.load(f)
                model_name = config.get("_name_or_path")
            except (OSError, json.JSONDecodeError):
                pass

        checkpoints.append({
            "name": checkpoint_dir.name,
            "created_at": created_at,
            "size_mb": round(size_mb, 2),
            "model_name": model_name
        })

    # Newest first.
    checkpoints.sort(key=lambda x: x["created_at"], reverse=True)

    return checkpoints


@router.delete("/checkpoints/{checkpoint_name}")
async def delete_checkpoint(checkpoint_name: str):
    """
    Delete a training checkpoint.

    **Warning:** This operation is irreversible.

    **Example:**
    ```
    DELETE /api/training/checkpoints/apprentice_medical_20251202_103000
    ```
    """
    checkpoint_dir = Path("training_data/checkpoints") / checkpoint_name

    if not checkpoint_dir.exists():
        raise HTTPException(
            status_code=404,
            detail=f"Checkpoint not found: {checkpoint_name}"
        )

    # Best-effort guard: refuse to delete a checkpoint that the active
    # training run appears to reference.
    if fine_tuning_manager.current_training_state["is_training"]:
        current_model = fine_tuning_manager.current_training_state.get("model_id", "")
        if checkpoint_name in current_model:
            raise HTTPException(
                status_code=409,
                detail="Cannot delete checkpoint while it's being used for training"
            )

    try:
        shutil.rmtree(checkpoint_dir)
        logger.info(f"Deleted checkpoint: {checkpoint_name}")
        return {
            "success": True,
            "message": f"Checkpoint '{checkpoint_name}' deleted successfully"
        }
    except Exception as e:
        logger.error(f"Failed to delete checkpoint: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete checkpoint: {str(e)}"
        )


@router.get("/auto/status")
async def get_auto_training_status():
    """
    Get the current state of the auto-training system.

    Returns:
        The enabled flag, is_training, trigger conditions, data statistics, etc.
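
    **Example Response** (shape illustrative; the exact fields are whatever
    `AutoTrainingManager.get_status()` returns):
    ```json
    {
        "enabled": true,
        "is_training": false
    }
    ```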
    """
    try:
        status = auto_training_manager.get_status()
        return status
    except Exception as e:
        logger.error(f"Failed to get auto-training status: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/auto/enable")
async def enable_auto_training():
    """Enable automatic training."""
    try:
        auto_training_manager.enable()
        return {"success": True, "message": "Auto-training enabled"}
    except Exception as e:
        logger.error(f"Failed to enable auto-training: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/auto/disable")
async def disable_auto_training():
    """Disable automatic training."""
    try:
        auto_training_manager.disable()
        return {"success": True, "message": "Auto-training disabled"}
    except Exception as e:
        logger.error(f"Failed to disable auto-training: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/auto/trigger")
async def trigger_auto_training_manually(domain_id: Optional[str] = None, background_tasks: BackgroundTasks = None):
    """
    Manually trigger auto-training.

    Args:
        domain_id: Restrict training to a single domain; None trains on all domains.
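
    **Example** (response values illustrative; `reason` is the string produced
    by `check_training_trigger`):
    ```
    POST /api/training/auto/trigger?domain_id=medical
    ```
    ```json
    {
        "success": true,
        "message": "Auto-training triggered",
        "reason": "..."
    }
    ```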
    """
    try:
        # Only proceed when the auto-training trigger conditions are met.
        should_trigger, reason = auto_training_manager.check_training_trigger(domain_id)

        if not should_trigger:
            return {
                "success": False,
                "message": "Training conditions not met",
                "reason": reason
            }

        # Prefer running in the background; fall back to running inline
        # (and returning the training result) when BackgroundTasks is unavailable.
        if background_tasks:
            background_tasks.add_task(auto_training_manager.trigger_auto_training, domain_id)
            return {
                "success": True,
                "message": "Auto-training triggered",
                "reason": reason
            }
        else:
            result = await auto_training_manager.trigger_auto_training(domain_id)
            return result

    except Exception as e:
        logger.error(f"Failed to trigger auto-training: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.put("/auto/config")
async def update_auto_training_config(config: Dict[str, Any]):
    """
    Update the auto-training configuration.

    Args:
        config: New settings (min_examples, min_days, trigger_mode, etc.)
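
    **Example Request Body** (keys as listed above; values illustrative):
    ```json
    {
        "min_examples": 100,
        "min_days": 7
    }
    ```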
    """
    try:
        auto_training_manager.update_config(config)
        return {
            "success": True,
            "message": "Auto-training config updated",
            "new_config": auto_training_manager.config
        }
    except Exception as e:
        logger.error(f"Failed to update auto-training config: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def auto_training_monitor_loop():
    """
    Background loop that periodically checks the auto-training trigger conditions.
    """
    logger.info("Auto-training monitor started")

    while True:
        try:
            # While auto-training is disabled, just poll again in a minute.
            if not auto_training_manager.enabled:
                await asyncio.sleep(60)
                continue

            check_interval = auto_training_manager.check_interval_minutes

            should_trigger, reason = auto_training_manager.check_training_trigger()

            if should_trigger:
                logger.info(f"Auto-training trigger conditions met: {reason}")

                # Even when conditions are met, only train inside the preferred time window.
                if auto_training_manager.should_train_now():
                    logger.info("Starting auto-training (preferred time window)")
                    await auto_training_manager.trigger_auto_training()
                else:
                    logger.info(f"Trigger conditions met but not in preferred time window (hour {auto_training_manager.preferred_hour})")

            await asyncio.sleep(check_interval * 60)

        except asyncio.CancelledError:
            logger.info("Auto-training monitor cancelled")
            break
        except Exception as e:
            # Log and keep the monitor alive after unexpected errors.
            logger.error(f"Error in auto-training monitor: {e}", exc_info=True)
            await asyncio.sleep(60)


def start_auto_training_monitor():
    """Start the background monitoring task.

    Must be called from inside a running event loop (e.g. a FastAPI startup
    handler), since asyncio.create_task requires one.
    """
    global _monitoring_task

    if _monitoring_task is None or _monitoring_task.done():
        _monitoring_task = asyncio.create_task(auto_training_monitor_loop())
        logger.info("Auto-training background monitor started")
    else:
        logger.warning("Auto-training monitor is already running")


def stop_auto_training_monitor():
    """Stop the background monitoring task."""
    global _monitoring_task

    if _monitoring_task and not _monitoring_task.done():
        _monitoring_task.cancel()
        logger.info("Auto-training background monitor stopped")

    _monitoring_task = None
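

# Lifecycle wiring: a minimal sketch of how the monitor is meant to be started
# and stopped with the app. The module path and router prefix below are
# assumptions for illustration (the prefix matches the /api/training paths
# used in the docstrings above):
#
#     from fastapi import FastAPI
#     from backend.app.api import training
#
#     app = FastAPI()
#     app.include_router(training.router, prefix="/api/training")
#
#     @app.on_event("startup")
#     async def _start_monitor():
#         # asyncio.create_task needs a running event loop, so start here.
#         training.start_auto_training_monitor()
#
#     @app.on_event("shutdown")
#     async def _stop_monitor():
#         training.stop_auto_training_monitor()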