content-engine / src /content_engine /api /routes_training.py
dippoo's picture
Switch training UI from epochs to max steps (default 1500)
01a9c08
raw
history blame
11.6 kB
"""Training API routes — LoRA model training management."""
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from content_engine.services.lora_trainer import LoRATrainer, TrainingConfig
logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/training", tags=["training"])

# Module-level singletons, injected once at app startup via init_routes().
_trainer: LoRATrainer | None = None
_runpod_trainer = None  # RunPodTrainer | None


def init_routes(trainer: LoRATrainer, runpod_trainer=None):
    """Wire the trainer instances that every route in this module uses."""
    global _trainer, _runpod_trainer
    _trainer = trainer
    _runpod_trainer = runpod_trainer
@router.get("/status")
async def training_status():
    """Check if training infrastructure is ready."""
    if _trainer is None:
        # Nothing is wired up yet — report everything unavailable.
        return {"ready": False, "sd_scripts_installed": False, "runpod_available": False}
    runpod_ok = _runpod_trainer is not None and _runpod_trainer.available
    return {
        "ready": True,
        "sd_scripts_installed": _trainer.sd_scripts_installed,
        "runpod_available": runpod_ok,
    }
@router.get("/models")
async def list_training_models():
    """List available base models for LoRA training with their recommended parameters."""
    if _runpod_trainer is None:
        return {"models": {}, "default": "flux2_dev"}
    # FLUX 2 is the recommended default for realistic-person training.
    return {
        "models": _runpod_trainer.list_training_models(),
        "default": "flux2_dev",
    }
@router.get("/gpu-options")
async def list_gpu_options():
    """List available RunPod GPU types."""
    gpus = {} if _runpod_trainer is None else _runpod_trainer.list_gpu_options()
    return {"gpus": gpus}
@router.post("/install")
async def install_sd_scripts():
    """Install Kohya sd-scripts for LoRA training.

    Raises:
        HTTPException: 503 if the trainer is not initialized,
            500 if the installation itself fails.
    """
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")
    try:
        msg = await _trainer.install_sd_scripts()
    except Exception as e:  # boundary handler: surface installer errors to the client
        # Chain the cause so the server traceback keeps the original error.
        raise HTTPException(500, f"Installation failed: {e}") from e
    return {"status": "ok", "message": msg}
@router.post("/start")
async def start_training(
    images: list[UploadFile] = File(...),
    name: str = Form(...),
    trigger_word: str = Form(""),
    captions_json: str = Form("{}"),
    base_model: str = Form("flux2_dev"),  # Model registry key (flux2_dev, sd15_realistic, sdxl_base)
    resolution: int | None = Form(None),  # None = use model default
    num_epochs: int = Form(100),  # High default — max_steps controls actual limit
    max_steps: int = Form(1500),  # Primary training length control
    learning_rate: float | None = Form(None),  # None = use model default
    network_rank: int | None = Form(None),  # None = use model default
    network_alpha: int | None = Form(None),  # None = use model default
    optimizer: str | None = Form(None),  # None = use model default
    train_batch_size: int = Form(1),
    save_every_n_steps: int = Form(500),
    backend: str = Form("runpod"),  # Default to runpod for cloud training
    gpu_type: str = Form("NVIDIA GeForce RTX 4090"),
):
    """Start a LoRA training job (local or RunPod cloud).

    Parameters like resolution, learning_rate, network_rank will use model
    registry defaults if not specified. Use base_model to select the model type.

    Raises:
        HTTPException: 400 for too few images or an unknown base model,
            503 when the requested backend is not configured.
    """
    import json
    import uuid

    from content_engine.config import settings

    if len(images) < 5:
        raise HTTPException(400, "Need at least 5 training images for reasonable results")

    # Parse captions; a malformed payload degrades to "no captions" rather
    # than failing the whole upload.
    try:
        captions = json.loads(captions_json) if captions_json else {}
    except json.JSONDecodeError:
        captions = {}

    # Save uploaded images (with caption .txt sidecars) to a unique temp dir.
    upload_dir = settings.paths.data_dir / "training_uploads" / str(uuid.uuid4())[:8]
    upload_dir.mkdir(parents=True, exist_ok=True)
    image_paths: list[str] = []
    for img in images:
        # Use only the basename: a client-supplied name like "../../x.png"
        # must not be able to escape upload_dir (path traversal).
        safe_name = Path(img.filename or "").name
        if not safe_name:
            continue  # unnamed upload part — nothing sensible to save
        file_path = upload_dir / safe_name
        content = await img.read()
        file_path.write_bytes(content)
        image_paths.append(str(file_path))
        # Caption lookup stays keyed by the client's original filename so it
        # matches the keys in captions_json; fall back to the trigger word.
        caption_text = captions.get(img.filename, trigger_word or "")
        file_path.with_suffix(".txt").write_text(caption_text, encoding="utf-8")
        logger.info("Saved caption for %s: %s", img.filename, caption_text[:80])

    # --- RunPod cloud backend ---
    if backend == "runpod":
        if _runpod_trainer is None:
            raise HTTPException(503, "RunPod not configured — set RUNPOD_API_KEY in .env")
        # Validate the model key before spending any cloud time.
        model_cfg = _runpod_trainer.get_model_config(base_model)
        if not model_cfg:
            available = list(_runpod_trainer.list_training_models().keys())
            raise HTTPException(400, f"Unknown base model: {base_model}. Available: {available}")
        job_id = await _runpod_trainer.start_training(
            name=name,
            image_paths=image_paths,
            trigger_word=trigger_word,
            base_model=base_model,
            resolution=resolution,
            num_epochs=num_epochs,
            max_train_steps=max_steps,
            learning_rate=learning_rate,
            network_rank=network_rank,
            network_alpha=network_alpha,
            optimizer=optimizer,
            save_every_n_steps=save_every_n_steps,
            gpu_type=gpu_type,
        )
        job = _runpod_trainer.get_job(job_id)
        return {
            "job_id": job_id,
            "status": job.status if job else "unknown",
            "name": name,
            "backend": "runpod",
            "base_model": base_model,
            "model_type": model_cfg.get("model_type", "unknown"),
        }

    # --- Local backend (local GPU with Kohya sd-scripts) ---
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")
    # For local training, reuse the model registry defaults when available.
    model_cfg = {}
    if _runpod_trainer:
        model_cfg = _runpod_trainer.get_model_config(base_model) or {}
    # Resolve the local checkpoint path, falling back to the bundled default.
    local_model_path = model_cfg.get("local_path") if model_cfg else None
    if not local_model_path:
        local_model_path = str(settings.paths.checkpoint_dir / "realisticVisionV51_v51VAE.safetensors")
    config = TrainingConfig(
        name=name,
        trigger_word=trigger_word,
        base_model=local_model_path,
        resolution=resolution or model_cfg.get("resolution", 512),
        num_epochs=num_epochs,
        learning_rate=learning_rate or model_cfg.get("learning_rate", 1e-4),
        network_rank=network_rank or model_cfg.get("network_rank", 32),
        network_alpha=network_alpha or model_cfg.get("network_alpha", 16),
        optimizer=optimizer or model_cfg.get("optimizer", "AdamW8bit"),
        train_batch_size=train_batch_size,
        # NOTE(review): a steps value is passed where the local trainer expects
        # epochs (it saves per-epoch) — confirm this pass-through is intended.
        save_every_n_epochs=save_every_n_steps,
    )
    job_id = await _trainer.start_training(config, image_paths)
    job = _trainer.get_job(job_id)
    return {
        "job_id": job_id,
        "status": job.status if job else "unknown",
        "name": name,
        "backend": "local",
        "base_model": base_model,
    }
def _job_summary(j, backend: str) -> dict:
    """Serialize a training-job object into the common JSON shape."""
    d = {
        "id": j.id, "name": j.name, "status": j.status,
        "progress": round(j.progress, 3),
        "current_epoch": j.current_epoch, "total_epochs": j.total_epochs,
        "current_step": j.current_step, "total_steps": j.total_steps,
        "loss": j.loss, "started_at": j.started_at,
        "completed_at": j.completed_at, "output_path": j.output_path,
        "error": j.error, "backend": backend,
        # Cap logs to the last 50 lines; tolerate job objects without logs.
        "log_lines": j.log_lines[-50:] if hasattr(j, 'log_lines') else [],
    }
    if backend == "runpod":
        # Cloud jobs additionally carry model metadata.
        d["base_model"] = j.base_model
        d["model_type"] = j.model_type
    return d


@router.get("/jobs")
async def list_training_jobs():
    """List all training jobs (local + cloud)."""
    jobs = []
    if _trainer:
        jobs.extend(_job_summary(j, "local") for j in _trainer.list_jobs())
    if _runpod_trainer:
        # Cloud job state is loaded lazily; make sure it is populated.
        await _runpod_trainer.ensure_loaded()
        jobs.extend(_job_summary(j, "runpod") for j in _runpod_trainer.list_jobs())
    return jobs
@router.get("/jobs/{job_id}")
async def get_training_job(job_id: str):
    """Get details of a specific training job including logs."""
    # Cloud jobs are checked before local ones.
    if _runpod_trainer:
        await _runpod_trainer.ensure_loaded()
        job = _runpod_trainer.get_job(job_id)
        if job:
            return {
                "id": job.id,
                "name": job.name,
                "status": job.status,
                "progress": round(job.progress, 3),
                "current_epoch": job.current_epoch,
                "total_epochs": job.total_epochs,
                "current_step": job.current_step,
                "total_steps": job.total_steps,
                "loss": job.loss,
                "started_at": job.started_at,
                "completed_at": job.completed_at,
                "output_path": job.output_path,
                "error": job.error,
                "log_lines": job.log_lines[-50:],
                "backend": "runpod",
                "base_model": job.base_model,
            }
    # Fall back to the local trainer's job registry.
    if _trainer:
        job = _trainer.get_job(job_id)
        if job:
            return {
                "id": job.id,
                "name": job.name,
                "status": job.status,
                "progress": round(job.progress, 3),
                "current_epoch": job.current_epoch,
                "total_epochs": job.total_epochs,
                "current_step": job.current_step,
                "total_steps": job.total_steps,
                "loss": job.loss,
                "started_at": job.started_at,
                "completed_at": job.completed_at,
                "output_path": job.output_path,
                "error": job.error,
                "log_lines": job.log_lines[-50:],
            }
    raise HTTPException(404, f"Training job not found: {job_id}")
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(job_id: str):
    """Cancel a running training job (local or cloud)."""
    # Try the cloud backend first when it knows about this job id.
    if _runpod_trainer and _runpod_trainer.get_job(job_id):
        if await _runpod_trainer.cancel_job(job_id):
            return {"status": "cancelled", "job_id": job_id}
    # Otherwise (or if cloud cancellation failed) try the local trainer.
    if _trainer:
        if await _trainer.cancel_job(job_id):
            return {"status": "cancelled", "job_id": job_id}
    raise HTTPException(404, "Job not found or not running")
@router.delete("/jobs/{job_id}")
async def delete_training_job(job_id: str):
    """Delete a training job from history."""
    # NOTE(review): only the RunPod trainer is consulted here — jobs known
    # only to the local trainer always 404. Confirm that is intended.
    if _runpod_trainer:
        if await _runpod_trainer.delete_job(job_id):
            return {"status": "deleted", "job_id": job_id}
    raise HTTPException(404, f"Training job not found: {job_id}")
@router.delete("/jobs")
async def delete_failed_jobs():
    """Delete all failed training jobs."""
    count = await _runpod_trainer.delete_failed_jobs() if _runpod_trainer else 0
    return {"status": "ok", "deleted": count}