"""Training API routes — LoRA model training management."""
from __future__ import annotations
import logging
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from content_engine.services.lora_trainer import LoRATrainer, TrainingConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/training", tags=["training"])
_trainer: LoRATrainer | None = None
_runpod_trainer = None # RunPodTrainer | None
def init_routes(trainer: LoRATrainer, runpod_trainer=None):
    """Wire the module-level trainer singletons used by the route handlers.

    Args:
        trainer: Local Kohya-based LoRA trainer instance.
        runpod_trainer: Optional RunPod cloud trainer; ``None`` disables the
            cloud-backed routes (they degrade to empty/unavailable responses).
    """
    global _trainer, _runpod_trainer
    _runpod_trainer = runpod_trainer
    _trainer = trainer
@router.get("/status")
async def training_status():
    """Check if training infrastructure is ready."""
    if _trainer is None:
        # Nothing wired in yet — report every capability as unavailable.
        return {
            "ready": False,
            "sd_scripts_installed": False,
            "runpod_available": False,
        }
    runpod_ok = _runpod_trainer is not None and _runpod_trainer.available
    return {
        "ready": True,
        "sd_scripts_installed": _trainer.sd_scripts_installed,
        "runpod_available": runpod_ok,
    }
@router.get("/models")
async def list_training_models():
    """List available base models for LoRA training with their recommended parameters."""
    # Without a RunPod trainer there is no model registry to consult.
    if _runpod_trainer is None:
        return {"models": {}, "default": "flux2_dev"}
    return {
        "models": _runpod_trainer.list_training_models(),
        "default": "flux2_dev",  # FLUX 2 recommended for realistic person
    }
@router.get("/gpu-options")
async def list_gpu_options():
    """List available RunPod GPU types."""
    # Empty mapping when no cloud trainer is configured.
    gpus = _runpod_trainer.list_gpu_options() if _runpod_trainer is not None else {}
    return {"gpus": gpus}
@router.post("/install")
async def install_sd_scripts():
    """Install Kohya sd-scripts for LoRA training.

    Returns:
        ``{"status": "ok", "message": <installer output>}`` on success.

    Raises:
        HTTPException: 503 when no local trainer is configured;
            500 when the installation itself fails.
    """
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")
    try:
        msg = await _trainer.install_sd_scripts()
    except Exception as e:
        # Keep the full traceback in the server log; surface a concise
        # message to the client, chaining the cause for debuggability.
        logger.exception("sd-scripts installation failed")
        raise HTTPException(500, f"Installation failed: {e}") from e
    return {"status": "ok", "message": msg}
async def _save_training_uploads(images, captions: dict, trigger_word: str) -> list[str]:
    """Persist uploaded images plus caption ``.txt`` sidecars to a fresh temp dir.

    Each on-disk filename is reduced to its basename so a crafted upload name
    (e.g. ``../../evil.png``) cannot escape the upload directory. Caption lookup
    still uses the client-supplied name, which is the key in ``captions_json``.

    Returns:
        The saved image paths as strings, in upload order.
    """
    import uuid
    from content_engine.config import settings

    upload_dir = settings.paths.data_dir / "training_uploads" / str(uuid.uuid4())[:8]
    upload_dir.mkdir(parents=True, exist_ok=True)
    image_paths: list[str] = []
    for img in images:
        # Path traversal guard: keep only the final path component.
        safe_name = Path(img.filename).name
        file_path = upload_dir / safe_name
        file_path.write_bytes(await img.read())
        image_paths.append(str(file_path))
        # Write caption .txt file alongside the image; fall back to the
        # trigger word (possibly empty) when no caption was provided.
        caption_text = captions.get(img.filename, trigger_word or "")
        file_path.with_suffix(".txt").write_text(caption_text, encoding="utf-8")
        logger.info("Saved caption for %s: %s", img.filename, caption_text[:80])
    return image_paths


@router.post("/start")
async def start_training(
    images: list[UploadFile] = File(...),
    name: str = Form(...),
    trigger_word: str = Form(""),
    captions_json: str = Form("{}"),
    base_model: str = Form("flux2_dev"),  # Model registry key (flux2_dev, sd15_realistic, sdxl_base)
    resolution: int | None = Form(None),  # None = use model default
    num_epochs: int = Form(100),  # High default — max_steps controls actual limit
    max_steps: int = Form(1500),  # Primary training length control
    learning_rate: float | None = Form(None),  # None = use model default
    network_rank: int | None = Form(None),  # None = use model default
    network_alpha: int | None = Form(None),  # None = use model default
    optimizer: str | None = Form(None),  # None = use model default
    train_batch_size: int = Form(1),
    save_every_n_steps: int = Form(500),
    backend: str = Form("runpod"),  # Default to runpod for cloud training
    gpu_type: str = Form("NVIDIA GeForce RTX 4090"),
):
    """Start a LoRA training job (local or RunPod cloud).

    Parameters like resolution, learning_rate, network_rank will use model
    registry defaults if not specified. Use base_model to select the model type.

    Raises:
        HTTPException: 400 for too few images or an unknown base model;
            503 when the requested backend is not configured.
    """
    import json

    if len(images) < 5:
        raise HTTPException(400, "Need at least 5 training images for reasonable results")

    # Parse captions; malformed payloads degrade to trigger-word-only captions.
    try:
        captions = json.loads(captions_json) if captions_json else {}
    except json.JSONDecodeError:
        captions = {}

    image_paths = await _save_training_uploads(images, captions, trigger_word)

    # Route to RunPod cloud trainer
    if backend == "runpod":
        if _runpod_trainer is None:
            raise HTTPException(503, "RunPod not configured — set RUNPOD_API_KEY in .env")
        # Validate model exists before dispatching the job.
        model_cfg = _runpod_trainer.get_model_config(base_model)
        if not model_cfg:
            available = list(_runpod_trainer.list_training_models().keys())
            raise HTTPException(400, f"Unknown base model: {base_model}. Available: {available}")
        job_id = await _runpod_trainer.start_training(
            name=name,
            image_paths=image_paths,
            trigger_word=trigger_word,
            base_model=base_model,
            resolution=resolution,
            num_epochs=num_epochs,
            max_train_steps=max_steps,
            learning_rate=learning_rate,
            network_rank=network_rank,
            network_alpha=network_alpha,
            optimizer=optimizer,
            save_every_n_steps=save_every_n_steps,
            gpu_type=gpu_type,
        )
        job = _runpod_trainer.get_job(job_id)
        return {
            "job_id": job_id,
            "status": job.status if job else "unknown",
            "name": name,
            "backend": "runpod",
            "base_model": base_model,
            "model_type": model_cfg.get("model_type", "unknown"),
        }

    # Local training (uses local GPU with Kohya sd-scripts)
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")

    # For local training, use model registry defaults if available.
    model_cfg = {}
    if _runpod_trainer:
        model_cfg = _runpod_trainer.get_model_config(base_model) or {}

    # Resolve local model path, falling back to the default checkpoint.
    local_model_path = model_cfg.get("local_path") if model_cfg else None
    if not local_model_path:
        from content_engine.config import settings

        local_model_path = str(settings.paths.checkpoint_dir / "realisticVisionV51_v51VAE.safetensors")

    config = TrainingConfig(
        name=name,
        trigger_word=trigger_word,
        base_model=local_model_path,
        resolution=resolution or model_cfg.get("resolution", 512),
        num_epochs=num_epochs,
        learning_rate=learning_rate or model_cfg.get("learning_rate", 1e-4),
        network_rank=network_rank or model_cfg.get("network_rank", 32),
        network_alpha=network_alpha or model_cfg.get("network_alpha", 16),
        optimizer=optimizer or model_cfg.get("optimizer", "AdamW8bit"),
        train_batch_size=train_batch_size,
        save_every_n_epochs=save_every_n_steps,  # Local trainer uses epoch-based saving
    )
    job_id = await _trainer.start_training(config, image_paths)
    job = _trainer.get_job(job_id)
    return {
        "job_id": job_id,
        "status": job.status if job else "unknown",
        "name": name,
        "backend": "local",
        "base_model": base_model,
    }
def _serialize_job(j, backend: str) -> dict:
    """Flatten a trainer job object into the JSON shape the jobs API returns.

    Args:
        j: A job record from either trainer (attributes per SOURCE usage).
        backend: ``"local"`` or ``"runpod"``; runpod jobs additionally expose
            ``base_model`` and ``model_type``.
    """
    payload = {
        "id": j.id, "name": j.name, "status": j.status,
        "progress": round(j.progress, 3),
        "current_epoch": j.current_epoch, "total_epochs": j.total_epochs,
        "current_step": j.current_step, "total_steps": j.total_steps,
        "loss": j.loss, "started_at": j.started_at,
        "completed_at": j.completed_at, "output_path": j.output_path,
        "error": j.error, "backend": backend,
    }
    if backend == "runpod":
        payload["base_model"] = j.base_model
        payload["model_type"] = j.model_type
    # Cap log output per job; older local job objects may predate log_lines.
    payload["log_lines"] = j.log_lines[-50:] if hasattr(j, "log_lines") else []
    return payload


@router.get("/jobs")
async def list_training_jobs():
    """List all training jobs (local + cloud)."""
    jobs = []
    if _trainer:
        jobs.extend(_serialize_job(j, "local") for j in _trainer.list_jobs())
    if _runpod_trainer:
        # Cloud job state is lazily loaded from persisted storage.
        await _runpod_trainer.ensure_loaded()
        jobs.extend(_serialize_job(j, "runpod") for j in _runpod_trainer.list_jobs())
    return jobs
@router.get("/jobs/{job_id}")
async def get_training_job(job_id: str):
    """Get details of a specific training job including logs."""
    # Cloud jobs take precedence over local ones for the same id.
    if _runpod_trainer:
        await _runpod_trainer.ensure_loaded()
        cloud = _runpod_trainer.get_job(job_id)
        if cloud:
            return {
                "id": cloud.id, "name": cloud.name, "status": cloud.status,
                "progress": round(cloud.progress, 3),
                "current_epoch": cloud.current_epoch, "total_epochs": cloud.total_epochs,
                "current_step": cloud.current_step, "total_steps": cloud.total_steps,
                "loss": cloud.loss, "started_at": cloud.started_at,
                "completed_at": cloud.completed_at, "output_path": cloud.output_path,
                "error": cloud.error, "log_lines": cloud.log_lines[-50:],
                "backend": "runpod", "base_model": cloud.base_model,
            }
    # Fall back to the local trainer's job registry.
    if _trainer:
        local = _trainer.get_job(job_id)
        if local:
            return {
                "id": local.id, "name": local.name, "status": local.status,
                "progress": round(local.progress, 3),
                "current_epoch": local.current_epoch, "total_epochs": local.total_epochs,
                "current_step": local.current_step, "total_steps": local.total_steps,
                "loss": local.loss, "started_at": local.started_at,
                "completed_at": local.completed_at, "output_path": local.output_path,
                "error": local.error, "log_lines": local.log_lines[-50:],
            }
    raise HTTPException(404, f"Training job not found: {job_id}")
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(job_id: str):
    """Cancel a running training job (local or cloud)."""
    # Only ask the cloud backend when it actually owns this job id.
    if _runpod_trainer and _runpod_trainer.get_job(job_id):
        if await _runpod_trainer.cancel_job(job_id):
            return {"status": "cancelled", "job_id": job_id}
    if _trainer and await _trainer.cancel_job(job_id):
        return {"status": "cancelled", "job_id": job_id}
    raise HTTPException(404, "Job not found or not running")
@router.delete("/jobs/{job_id}")
async def delete_training_job(job_id: str):
    """Delete a training job from history."""
    # NOTE(review): only cloud jobs are deletable here — local jobs have no
    # delete path in this handler; confirm that is intentional.
    if _runpod_trainer:
        if await _runpod_trainer.delete_job(job_id):
            return {"status": "deleted", "job_id": job_id}
    raise HTTPException(404, f"Training job not found: {job_id}")
@router.delete("/jobs")
async def delete_failed_jobs():
    """Delete all failed training jobs."""
    if _runpod_trainer is None:
        # Nothing to clean up without the cloud trainer.
        return {"status": "ok", "deleted": 0}
    removed = await _runpod_trainer.delete_failed_jobs()
    return {"status": "ok", "deleted": removed}
|