Spaces:
Running
Running
| """Training API routes — LoRA model training management.""" | |
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from fastapi import APIRouter, File, Form, HTTPException, UploadFile | |
| from content_engine.services.lora_trainer import LoRATrainer, TrainingConfig | |
# Module-level logger; handlers/config are supplied by the application.
logger = logging.getLogger(__name__)

# All routes in this module are mounted under /api/training.
router = APIRouter(prefix="/api/training", tags=["training"])

# Backend singletons, populated by init_routes() at application startup.
# Until then every handler treats training as unavailable.
_trainer: LoRATrainer | None = None
_runpod_trainer = None  # RunPodTrainer | None — left untyped here; see init_routes()
def init_routes(trainer: LoRATrainer, runpod_trainer=None):
    """Install the backend singletons that the route handlers read.

    Called once at application startup; ``runpod_trainer`` is optional and
    stays ``None`` when cloud training is not configured.
    """
    global _trainer, _runpod_trainer
    # The two assignments are independent; order is irrelevant.
    _runpod_trainer = runpod_trainer
    _trainer = trainer
async def training_status():
    """Check if training infrastructure is ready."""
    # Nothing wired up yet — report everything as unavailable.
    if _trainer is None:
        return {
            "ready": False,
            "sd_scripts_installed": False,
            "runpod_available": False,
        }
    runpod_ok = _runpod_trainer is not None and _runpod_trainer.available
    return {
        "ready": True,
        "sd_scripts_installed": _trainer.sd_scripts_installed,
        "runpod_available": runpod_ok,
    }
async def list_training_models():
    """List available base models for LoRA training with their recommended parameters."""
    default_key = "flux2_dev"  # FLUX 2 recommended for realistic person
    if _runpod_trainer is None:
        # No cloud trainer configured — empty registry, same default key.
        return {"models": {}, "default": default_key}
    return {
        "models": _runpod_trainer.list_training_models(),
        "default": default_key,
    }
async def list_gpu_options():
    """List available RunPod GPU types."""
    cloud = _runpod_trainer
    if cloud is None:
        return {"gpus": {}}
    return {"gpus": cloud.list_gpu_options()}
async def install_sd_scripts():
    """Install Kohya sd-scripts for LoRA training.

    Returns:
        ``{"status": "ok", "message": ...}`` on success.

    Raises:
        HTTPException: 503 when the trainer is not initialized,
            500 when the installation itself fails.
    """
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")
    try:
        msg = await _trainer.install_sd_scripts()
    except Exception as e:
        # Log the full traceback server-side and chain the cause so the
        # original error is preserved (instead of being swallowed).
        logger.exception("sd-scripts installation failed")
        raise HTTPException(500, f"Installation failed: {e}") from e
    return {"status": "ok", "message": msg}
async def start_training(
    images: list[UploadFile] = File(...),
    name: str = Form(...),
    trigger_word: str = Form(""),
    captions_json: str = Form("{}"),
    base_model: str = Form("flux2_dev"),  # Model registry key (flux2_dev, sd15_realistic, sdxl_base)
    resolution: int | None = Form(None),  # None = use model default
    num_epochs: int = Form(100),  # High default — max_steps controls actual limit
    max_steps: int = Form(1500),  # Primary training length control
    learning_rate: float | None = Form(None),  # None = use model default
    network_rank: int | None = Form(None),  # None = use model default
    network_alpha: int | None = Form(None),  # None = use model default
    optimizer: str | None = Form(None),  # None = use model default
    train_batch_size: int = Form(1),
    save_every_n_steps: int = Form(500),
    backend: str = Form("runpod"),  # Default to runpod for cloud training
    gpu_type: str = Form("NVIDIA GeForce RTX 4090"),
):
    """Start a LoRA training job (local or RunPod cloud).

    Parameters like resolution, learning_rate, network_rank will use model
    registry defaults if not specified. Use base_model to select the model type.

    Raises:
        HTTPException: 400 for too few images or an unknown base model,
            503 when the requested backend is not configured.
    """
    import json
    import uuid

    from content_engine.config import settings

    if len(images) < 5:
        raise HTTPException(400, "Need at least 5 training images for reasonable results")

    # Parse per-image captions; a malformed payload degrades to "no captions".
    try:
        captions = json.loads(captions_json) if captions_json else {}
    except json.JSONDecodeError:
        captions = {}

    # Save uploads (plus caption .txt sidecars) to a unique temp directory.
    upload_dir = settings.paths.data_dir / "training_uploads" / str(uuid.uuid4())[:8]
    upload_dir.mkdir(parents=True, exist_ok=True)
    image_paths: list[str] = []
    for idx, img in enumerate(images):
        # SECURITY: use only the basename of the client-supplied filename so a
        # crafted name like "../../evil.png" cannot escape upload_dir; also
        # guards against a missing filename.
        safe_name = Path(img.filename or "").name or f"image_{idx}.png"
        file_path = upload_dir / safe_name
        file_path.write_bytes(await img.read())
        image_paths.append(str(file_path))
        # Caption lookup still keys on the original client filename; an
        # explicit caption wins, else fall back to the trigger word.
        caption_text = captions.get(img.filename, trigger_word or "")
        file_path.with_suffix(".txt").write_text(caption_text, encoding="utf-8")
        logger.info("Saved caption for %s: %s", img.filename, caption_text[:80])

    # ---- RunPod cloud backend -------------------------------------------
    if backend == "runpod":
        if _runpod_trainer is None:
            raise HTTPException(503, "RunPod not configured — set RUNPOD_API_KEY in .env")
        # Validate model exists in the registry before launching anything.
        model_cfg = _runpod_trainer.get_model_config(base_model)
        if not model_cfg:
            available = list(_runpod_trainer.list_training_models().keys())
            raise HTTPException(400, f"Unknown base model: {base_model}. Available: {available}")
        job_id = await _runpod_trainer.start_training(
            name=name,
            image_paths=image_paths,
            trigger_word=trigger_word,
            base_model=base_model,
            resolution=resolution,
            num_epochs=num_epochs,
            max_train_steps=max_steps,
            learning_rate=learning_rate,
            network_rank=network_rank,
            network_alpha=network_alpha,
            optimizer=optimizer,
            save_every_n_steps=save_every_n_steps,
            gpu_type=gpu_type,
        )
        job = _runpod_trainer.get_job(job_id)
        return {
            "job_id": job_id,
            "status": job.status if job else "unknown",
            "name": name,
            "backend": "runpod",
            "base_model": base_model,
            "model_type": model_cfg.get("model_type", "unknown"),
        }

    # ---- Local backend (local GPU with Kohya sd-scripts) ----------------
    if _trainer is None:
        raise HTTPException(503, "Trainer not initialized")
    # Reuse registry defaults for local training when the registry exists.
    model_cfg = {}
    if _runpod_trainer:
        model_cfg = _runpod_trainer.get_model_config(base_model) or {}
    local_model_path = model_cfg.get("local_path") if model_cfg else None
    if not local_model_path:
        # Fall back to the default local checkpoint.
        local_model_path = str(settings.paths.checkpoint_dir / "realisticVisionV51_v51VAE.safetensors")

    def _resolved(value, key, fallback):
        # Explicit None check (not `or`) so an explicit falsy override is
        # never silently replaced by the registry default.
        return value if value is not None else model_cfg.get(key, fallback)

    config = TrainingConfig(
        name=name,
        trigger_word=trigger_word,
        base_model=local_model_path,
        resolution=_resolved(resolution, "resolution", 512),
        num_epochs=num_epochs,
        learning_rate=_resolved(learning_rate, "learning_rate", 1e-4),
        network_rank=_resolved(network_rank, "network_rank", 32),
        network_alpha=_resolved(network_alpha, "network_alpha", 16),
        optimizer=_resolved(optimizer, "optimizer", "AdamW8bit"),
        train_batch_size=train_batch_size,
        save_every_n_epochs=save_every_n_steps,  # Local trainer uses epoch-based saving
    )
    job_id = await _trainer.start_training(config, image_paths)
    job = _trainer.get_job(job_id)
    return {
        "job_id": job_id,
        "status": job.status if job else "unknown",
        "name": name,
        "backend": "local",
        "base_model": base_model,
    }
async def list_training_jobs():
    """List all training jobs (local + cloud)."""

    def _base_fields(j, backend: str) -> dict:
        # Serialization shared by both backends — previously duplicated.
        # Key order matches the original response payloads.
        return {
            "id": j.id, "name": j.name, "status": j.status,
            "progress": round(j.progress, 3),
            "current_epoch": j.current_epoch, "total_epochs": j.total_epochs,
            "current_step": j.current_step, "total_steps": j.total_steps,
            "loss": j.loss, "started_at": j.started_at,
            "completed_at": j.completed_at, "output_path": j.output_path,
            "error": j.error, "backend": backend,
        }

    jobs = []
    if _trainer:
        for j in _trainer.list_jobs():
            entry = _base_fields(j, "local")
            # Older local job records may predate log capture.
            entry["log_lines"] = j.log_lines[-50:] if hasattr(j, "log_lines") else []
            jobs.append(entry)
    if _runpod_trainer:
        await _runpod_trainer.ensure_loaded()
        for j in _runpod_trainer.list_jobs():
            entry = _base_fields(j, "runpod")
            entry["base_model"] = j.base_model
            entry["model_type"] = j.model_type
            entry["log_lines"] = j.log_lines[-50:]
            jobs.append(entry)
    return jobs
async def get_training_job(job_id: str):
    """Get details of a specific training job including logs.

    RunPod jobs are checked first, then local jobs; raises 404 when the
    id is unknown to both backends.
    """
    # Check RunPod jobs first
    if _runpod_trainer:
        await _runpod_trainer.ensure_loaded()
        job = _runpod_trainer.get_job(job_id)
        if job:
            return {
                "id": job.id, "name": job.name, "status": job.status,
                "progress": round(job.progress, 3),
                "current_epoch": job.current_epoch, "total_epochs": job.total_epochs,
                "current_step": job.current_step, "total_steps": job.total_steps,
                "loss": job.loss, "started_at": job.started_at,
                "completed_at": job.completed_at, "output_path": job.output_path,
                "error": job.error, "log_lines": job.log_lines[-50:],
                "backend": "runpod", "base_model": job.base_model,
            }
    # Then check local trainer
    if _trainer:
        job = _trainer.get_job(job_id)
        if job:
            return {
                "id": job.id, "name": job.name, "status": job.status,
                "progress": round(job.progress, 3),
                "current_epoch": job.current_epoch, "total_epochs": job.total_epochs,
                "current_step": job.current_step, "total_steps": job.total_steps,
                "loss": job.loss, "started_at": job.started_at,
                "completed_at": job.completed_at, "output_path": job.output_path,
                "error": job.error,
                # Guard matches list_training_jobs: old local jobs may lack logs.
                "log_lines": job.log_lines[-50:] if hasattr(job, "log_lines") else [],
                # Fix: local branch previously omitted the backend field that
                # the runpod branch and list_training_jobs both include.
                "backend": "local",
            }
    raise HTTPException(404, f"Training job not found: {job_id}")
async def cancel_training_job(job_id: str):
    """Cancel a running training job (local or cloud)."""
    # Try the cloud trainer first, but only when it actually knows this id.
    if _runpod_trainer and _runpod_trainer.get_job(job_id):
        if await _runpod_trainer.cancel_job(job_id):
            return {"status": "cancelled", "job_id": job_id}
    # Fall through to the local trainer.
    if _trainer:
        if await _trainer.cancel_job(job_id):
            return {"status": "cancelled", "job_id": job_id}
    raise HTTPException(404, "Job not found or not running")
async def delete_training_job(job_id: str):
    """Delete a training job from history."""
    # NOTE(review): only RunPod-backed jobs are deletable here; local jobs
    # fall through to 404 — confirm whether that is intentional.
    if _runpod_trainer and await _runpod_trainer.delete_job(job_id):
        return {"status": "deleted", "job_id": job_id}
    raise HTTPException(404, f"Training job not found: {job_id}")
async def delete_failed_jobs():
    """Delete all failed training jobs."""
    # Guard clause: no cloud trainer means there is nothing to clean up.
    if not _runpod_trainer:
        return {"status": "ok", "deleted": 0}
    removed = await _runpod_trainer.delete_failed_jobs()
    return {"status": "ok", "deleted": removed}