"""Training API routes — LoRA model training management.""" from __future__ import annotations import logging from pathlib import Path from fastapi import APIRouter, File, Form, HTTPException, UploadFile from content_engine.services.lora_trainer import LoRATrainer, TrainingConfig logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/training", tags=["training"]) _trainer: LoRATrainer | None = None _runpod_trainer = None # RunPodTrainer | None def init_routes(trainer: LoRATrainer, runpod_trainer=None): global _trainer, _runpod_trainer _trainer = trainer _runpod_trainer = runpod_trainer @router.get("/status") async def training_status(): """Check if training infrastructure is ready.""" if _trainer is None: return {"ready": False, "sd_scripts_installed": False, "runpod_available": False} return { "ready": True, "sd_scripts_installed": _trainer.sd_scripts_installed, "runpod_available": _runpod_trainer is not None and _runpod_trainer.available, } @router.get("/models") async def list_training_models(): """List available base models for LoRA training with their recommended parameters.""" if _runpod_trainer is None: return {"models": {}, "default": "flux2_dev"} models = _runpod_trainer.list_training_models() return { "models": models, "default": "flux2_dev", # FLUX 2 recommended for realistic person } @router.get("/gpu-options") async def list_gpu_options(): """List available RunPod GPU types.""" if _runpod_trainer is None: return {"gpus": {}} return {"gpus": _runpod_trainer.list_gpu_options()} @router.post("/install") async def install_sd_scripts(): """Install Kohya sd-scripts for LoRA training.""" if _trainer is None: raise HTTPException(503, "Trainer not initialized") try: msg = await _trainer.install_sd_scripts() return {"status": "ok", "message": msg} except Exception as e: raise HTTPException(500, f"Installation failed: {e}") @router.post("/start") async def start_training( images: list[UploadFile] = File(...), name: str = Form(...), trigger_word: str = Form(""), captions_json: str = Form("{}"), base_model: str = Form("flux2_dev"), # Model registry key (flux2_dev, sd15_realistic, sdxl_base) resolution: int | None = Form(None), # None = use model default num_epochs: int = Form(100), # High default — max_steps controls actual limit max_steps: int = Form(1500), # Primary training length control learning_rate: float | None = Form(None), # None = use model default network_rank: int | None = Form(None), # None = use model default network_alpha: int | None = Form(None), # None = use model default optimizer: str | None = Form(None), # None = use model default train_batch_size: int = Form(1), save_every_n_steps: int = Form(500), backend: str = Form("runpod"), # Default to runpod for cloud training gpu_type: str = Form("NVIDIA GeForce RTX 4090"), ): """Start a LoRA training job (local or RunPod cloud). Parameters like resolution, learning_rate, network_rank will use model registry defaults if not specified. Use base_model to select the model type. """ import json if len(images) < 5: raise HTTPException(400, "Need at least 5 training images for reasonable results") # Parse captions try: captions = json.loads(captions_json) if captions_json else {} except json.JSONDecodeError: captions = {} # Save uploaded images to temp directory import uuid from content_engine.config import settings upload_dir = settings.paths.data_dir / "training_uploads" / str(uuid.uuid4())[:8] upload_dir.mkdir(parents=True, exist_ok=True) image_paths = [] for img in images: file_path = upload_dir / img.filename content = await img.read() file_path.write_bytes(content) image_paths.append(str(file_path)) # Write caption .txt file alongside the image caption_text = captions.get(img.filename, trigger_word or "") caption_path = file_path.with_suffix(".txt") caption_path.write_text(caption_text, encoding="utf-8") logger.info("Saved caption for %s: %s", img.filename, caption_text[:80]) # Route to RunPod cloud trainer if backend == "runpod": if _runpod_trainer is None: raise HTTPException(503, "RunPod not configured — set RUNPOD_API_KEY in .env") # Validate model exists model_cfg = _runpod_trainer.get_model_config(base_model) if not model_cfg: available = list(_runpod_trainer.list_training_models().keys()) raise HTTPException(400, f"Unknown base model: {base_model}. Available: {available}") job_id = await _runpod_trainer.start_training( name=name, image_paths=image_paths, trigger_word=trigger_word, base_model=base_model, resolution=resolution, num_epochs=num_epochs, max_train_steps=max_steps, learning_rate=learning_rate, network_rank=network_rank, network_alpha=network_alpha, optimizer=optimizer, save_every_n_steps=save_every_n_steps, gpu_type=gpu_type, ) job = _runpod_trainer.get_job(job_id) return { "job_id": job_id, "status": job.status if job else "unknown", "name": name, "backend": "runpod", "base_model": base_model, "model_type": model_cfg.get("model_type", "unknown"), } # Local training (uses local GPU with Kohya sd-scripts) if _trainer is None: raise HTTPException(503, "Trainer not initialized") # For local training, use model registry defaults if available model_cfg = {} if _runpod_trainer: model_cfg = _runpod_trainer.get_model_config(base_model) or {} # Resolve local model path local_model_path = model_cfg.get("local_path") if model_cfg else None if not local_model_path: # Fall back to default local path local_model_path = str(settings.paths.checkpoint_dir / "realisticVisionV51_v51VAE.safetensors") config = TrainingConfig( name=name, trigger_word=trigger_word, base_model=local_model_path, resolution=resolution or model_cfg.get("resolution", 512), num_epochs=num_epochs, learning_rate=learning_rate or model_cfg.get("learning_rate", 1e-4), network_rank=network_rank or model_cfg.get("network_rank", 32), network_alpha=network_alpha or model_cfg.get("network_alpha", 16), optimizer=optimizer or model_cfg.get("optimizer", "AdamW8bit"), train_batch_size=train_batch_size, save_every_n_epochs=save_every_n_steps, # Local trainer uses epoch-based saving ) job_id = await _trainer.start_training(config, image_paths) job = _trainer.get_job(job_id) return { "job_id": job_id, "status": job.status if job else "unknown", "name": name, "backend": "local", "base_model": base_model, } @router.get("/jobs") async def list_training_jobs(): """List all training jobs (local + cloud).""" jobs = [] if _trainer: for j in _trainer.list_jobs(): jobs.append({ "id": j.id, "name": j.name, "status": j.status, "progress": round(j.progress, 3), "current_epoch": j.current_epoch, "total_epochs": j.total_epochs, "current_step": j.current_step, "total_steps": j.total_steps, "loss": j.loss, "started_at": j.started_at, "completed_at": j.completed_at, "output_path": j.output_path, "error": j.error, "backend": "local", "log_lines": j.log_lines[-50:] if hasattr(j, 'log_lines') else [], }) if _runpod_trainer: await _runpod_trainer.ensure_loaded() for j in _runpod_trainer.list_jobs(): jobs.append({ "id": j.id, "name": j.name, "status": j.status, "progress": round(j.progress, 3), "current_epoch": j.current_epoch, "total_epochs": j.total_epochs, "current_step": j.current_step, "total_steps": j.total_steps, "loss": j.loss, "started_at": j.started_at, "completed_at": j.completed_at, "output_path": j.output_path, "error": j.error, "backend": "runpod", "base_model": j.base_model, "model_type": j.model_type, "log_lines": j.log_lines[-50:], }) return jobs @router.get("/jobs/{job_id}") async def get_training_job(job_id: str): """Get details of a specific training job including logs.""" # Check RunPod jobs first if _runpod_trainer: await _runpod_trainer.ensure_loaded() job = _runpod_trainer.get_job(job_id) if job: return { "id": job.id, "name": job.name, "status": job.status, "progress": round(job.progress, 3), "current_epoch": job.current_epoch, "total_epochs": job.total_epochs, "current_step": job.current_step, "total_steps": job.total_steps, "loss": job.loss, "started_at": job.started_at, "completed_at": job.completed_at, "output_path": job.output_path, "error": job.error, "log_lines": job.log_lines[-50:], "backend": "runpod", "base_model": job.base_model, } # Then check local trainer if _trainer: job = _trainer.get_job(job_id) if job: return { "id": job.id, "name": job.name, "status": job.status, "progress": round(job.progress, 3), "current_epoch": job.current_epoch, "total_epochs": job.total_epochs, "current_step": job.current_step, "total_steps": job.total_steps, "loss": job.loss, "started_at": job.started_at, "completed_at": job.completed_at, "output_path": job.output_path, "error": job.error, "log_lines": job.log_lines[-50:], } raise HTTPException(404, f"Training job not found: {job_id}") @router.post("/jobs/{job_id}/cancel") async def cancel_training_job(job_id: str): """Cancel a running training job (local or cloud).""" if _runpod_trainer and _runpod_trainer.get_job(job_id): cancelled = await _runpod_trainer.cancel_job(job_id) if cancelled: return {"status": "cancelled", "job_id": job_id} if _trainer: cancelled = await _trainer.cancel_job(job_id) if cancelled: return {"status": "cancelled", "job_id": job_id} raise HTTPException(404, "Job not found or not running") @router.delete("/jobs/{job_id}") async def delete_training_job(job_id: str): """Delete a training job from history.""" if _runpod_trainer: deleted = await _runpod_trainer.delete_job(job_id) if deleted: return {"status": "deleted", "job_id": job_id} raise HTTPException(404, f"Training job not found: {job_id}") @router.delete("/jobs") async def delete_failed_jobs(): """Delete all failed training jobs.""" if _runpod_trainer: count = await _runpod_trainer.delete_failed_jobs() return {"status": "ok", "deleted": count} return {"status": "ok", "deleted": 0}