"""LoRA training service — train custom LoRA models from reference images. Wraps Kohya's sd-scripts for LoRA training with sensible defaults for character LoRAs on SD 1.5 / RealisticVision. Manages the full pipeline: dataset preparation, config generation, training launch, and output handling. Requirements (installed automatically on first use): - kohya sd-scripts (cloned from GitHub) - accelerate, lion-pytorch, prodigy-optimizer """ from __future__ import annotations import asyncio import json import logging import os import shutil import subprocess import sys import time import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Any logger = logging.getLogger(__name__) IS_HF_SPACES = os.environ.get("HF_SPACES") == "1" or os.environ.get("SPACE_ID") is not None if IS_HF_SPACES: TRAINING_BASE_DIR = Path("/app/data/training") LORA_OUTPUT_DIR = Path("/app/data/loras") else: TRAINING_BASE_DIR = Path("D:/AI automation/content_engine/training") LORA_OUTPUT_DIR = Path("D:/ComfyUI/Models/Lora") SD_SCRIPTS_DIR = TRAINING_BASE_DIR / "sd-scripts" def _default_base_model() -> str: """Get default base model path based on environment.""" if IS_HF_SPACES: return "/app/models/realisticVisionV51_v51VAE.safetensors" return "D:/ComfyUI/Models/StableDiffusion/realisticVisionV51_v51VAE.safetensors" @dataclass class TrainingConfig: """Configuration for a LoRA training job.""" name: str base_model: str = "" # Set in __post_init__ resolution: int = 512 train_batch_size: int = 1 num_epochs: int = 10 learning_rate: float = 1e-4 network_rank: int = 32 # LoRA rank (dim) network_alpha: int = 16 optimizer: str = "AdamW8bit" # AdamW8bit, Lion, Prodigy lr_scheduler: str = "cosine_with_restarts" max_train_steps: int | None = None # If set, overrides epochs save_every_n_epochs: int = 2 clip_skip: int = 1 mixed_precision: str = "fp16" seed: int = 42 caption_extension: str = ".txt" trigger_word: str = "" extra_args: dict[str, Any] = field(default_factory=dict) def __post_init__(self): if not self.base_model: self.base_model = _default_base_model() @dataclass class TrainingJob: """Tracks state of a running or completed training job.""" id: str name: str config: TrainingConfig status: str = "pending" # pending, preparing, training, completed, failed progress: float = 0.0 current_epoch: int = 0 total_epochs: int = 0 current_step: int = 0 total_steps: int = 0 loss: float | None = None started_at: float | None = None completed_at: float | None = None output_path: str | None = None error: str | None = None log_lines: list[str] = field(default_factory=list) class LoRATrainer: """Manages LoRA training jobs using Kohya sd-scripts.""" def __init__(self): self._jobs: dict[str, TrainingJob] = {} self._processes: dict[str, asyncio.subprocess.Process] = {} TRAINING_BASE_DIR.mkdir(parents=True, exist_ok=True) @property def sd_scripts_installed(self) -> bool: return (SD_SCRIPTS_DIR / "train_network.py").exists() async def install_sd_scripts(self) -> str: """Clone and set up Kohya sd-scripts. Returns status message.""" if self.sd_scripts_installed: return "sd-scripts already installed" SD_SCRIPTS_DIR.parent.mkdir(parents=True, exist_ok=True) logger.info("Cloning kohya sd-scripts...") proc = await asyncio.create_subprocess_exec( "git", "clone", "--depth", "1", "https://github.com/kohya-ss/sd-scripts.git", str(SD_SCRIPTS_DIR), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() if proc.returncode != 0: raise RuntimeError(f"Failed to clone sd-scripts: {stderr.decode()}") # Install requirements logger.info("Installing sd-scripts requirements...") proc = await asyncio.create_subprocess_exec( sys.executable, "-m", "pip", "install", "accelerate", "lion-pytorch", "prodigy-optimizer", "safetensors", "diffusers", "transformers", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) await proc.communicate() logger.info("sd-scripts installation complete") return "sd-scripts installed successfully" def prepare_dataset(self, job_id: str, image_paths: list[str], trigger_word: str = "") -> Path: """Prepare a training dataset directory with proper structure. Creates: training/{job_id}/dataset/{num_repeats}_{trigger_word}/ Each image gets a caption file with the trigger word. """ dataset_dir = TRAINING_BASE_DIR / job_id / "dataset" # Convention: {repeats}_{concept_name} repeats = 10 concept_dir = dataset_dir / f"{repeats}_{trigger_word or 'character'}" concept_dir.mkdir(parents=True, exist_ok=True) for img_path in image_paths: src = Path(img_path) if not src.exists(): logger.warning("Image not found: %s", img_path) continue dst = concept_dir / src.name shutil.copy2(src, dst) # Create caption file caption_file = dst.with_suffix(".txt") caption_file.write_text(trigger_word or "") return dataset_dir async def start_training(self, config: TrainingConfig, image_paths: list[str]) -> str: """Start a LoRA training job. Returns the job ID.""" job_id = str(uuid.uuid4())[:8] if not self.sd_scripts_installed: await self.install_sd_scripts() job = TrainingJob( id=job_id, name=config.name, config=config, status="preparing", total_epochs=config.num_epochs, ) self._jobs[job_id] = job # Prepare dataset try: dataset_dir = self.prepare_dataset(job_id, image_paths, config.trigger_word) except Exception as e: job.status = "failed" job.error = f"Dataset preparation failed: {e}" return job_id # Create output directory output_dir = TRAINING_BASE_DIR / job_id / "output" output_dir.mkdir(parents=True, exist_ok=True) # Build training command cmd = self._build_training_command(config, dataset_dir, output_dir) job.log_lines.append(f"Command: {' '.join(cmd)}") # Launch training process job.status = "training" job.started_at = time.time() asyncio.create_task(self._run_training(job_id, cmd, output_dir, config)) return job_id def _build_training_command( self, config: TrainingConfig, dataset_dir: Path, output_dir: Path ) -> list[str]: """Build the training command for Kohya sd-scripts.""" cmd = [ sys.executable, str(SD_SCRIPTS_DIR / "train_network.py"), f"--pretrained_model_name_or_path={config.base_model}", f"--train_data_dir={dataset_dir}", f"--output_dir={output_dir}", f"--output_name={config.name}", f"--resolution={config.resolution}", f"--train_batch_size={config.train_batch_size}", f"--max_train_epochs={config.num_epochs}", f"--learning_rate={config.learning_rate}", f"--network_module=networks.lora", f"--network_dim={config.network_rank}", f"--network_alpha={config.network_alpha}", f"--optimizer_type={config.optimizer}", f"--lr_scheduler={config.lr_scheduler}", f"--save_every_n_epochs={config.save_every_n_epochs}", f"--clip_skip={config.clip_skip}", f"--mixed_precision={config.mixed_precision}", f"--seed={config.seed}", f"--caption_extension={config.caption_extension}", "--cache_latents", "--enable_bucket", "--xformers", "--save_model_as=safetensors", ] if config.max_train_steps: cmd.append(f"--max_train_steps={config.max_train_steps}") return cmd async def _run_training( self, job_id: str, cmd: list[str], output_dir: Path, config: TrainingConfig ): """Run the training process and monitor progress.""" job = self._jobs[job_id] try: proc = await asyncio.create_subprocess_exec( *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, cwd=str(SD_SCRIPTS_DIR), ) self._processes[job_id] = proc # Read output lines and parse progress async for line_bytes in proc.stdout: line = line_bytes.decode("utf-8", errors="replace").strip() if not line: continue job.log_lines.append(line) # Keep last 200 lines if len(job.log_lines) > 200: job.log_lines = job.log_lines[-200:] # Parse progress from Kohya output if "epoch" in line.lower() and "/" in line: try: # Look for patterns like "epoch 3/10" parts = line.lower().split("epoch") if len(parts) > 1: ep_part = parts[1].strip().split()[0] if "/" in ep_part: current, total = ep_part.split("/") job.current_epoch = int(current) job.total_epochs = int(total) job.progress = job.current_epoch / max(job.total_epochs, 1) except (ValueError, IndexError): pass if "loss=" in line or "loss:" in line: try: loss_str = line.split("loss")[1].strip("=: ").split()[0].strip(",") job.loss = float(loss_str) except (ValueError, IndexError): pass if "steps:" in line.lower() or "step " in line.lower(): try: import re step_match = re.search(r"(\d+)/(\d+)", line) if step_match: job.current_step = int(step_match.group(1)) job.total_steps = int(step_match.group(2)) if job.total_steps > 0: job.progress = job.current_step / job.total_steps except (ValueError, IndexError): pass await proc.wait() if proc.returncode == 0: job.status = "completed" job.progress = 1.0 job.completed_at = time.time() # Find the output LoRA file and copy to ComfyUI lora_file = output_dir / f"{config.name}.safetensors" if lora_file.exists(): dest = LORA_OUTPUT_DIR / lora_file.name LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) shutil.copy2(lora_file, dest) job.output_path = str(dest) logger.info("Training complete! LoRA saved to %s", dest) else: # Check for epoch-saved versions for f in sorted(output_dir.glob("*.safetensors")): dest = LORA_OUTPUT_DIR / f.name shutil.copy2(f, dest) job.output_path = str(dest) logger.info("Training complete! Output in %s", output_dir) else: job.status = "failed" job.error = f"Training process exited with code {proc.returncode}" logger.error("Training failed: %s", job.error) except Exception as e: job.status = "failed" job.error = str(e) logger.error("Training error: %s", e, exc_info=True) finally: self._processes.pop(job_id, None) def get_job(self, job_id: str) -> TrainingJob | None: return self._jobs.get(job_id) def list_jobs(self) -> list[TrainingJob]: return list(self._jobs.values()) async def cancel_job(self, job_id: str) -> bool: """Cancel a running training job.""" proc = self._processes.get(job_id) if proc: proc.terminate() self._processes.pop(job_id, None) job = self._jobs.get(job_id) if job: job.status = "failed" job.error = "Cancelled by user" return True return False