Spaces:
Running
Running
"""LoRA training service — train custom LoRA models from reference images.

Wraps Kohya's sd-scripts for LoRA training with sensible defaults for
character LoRAs on SD 1.5 / RealisticVision. Manages the full pipeline:
dataset preparation, training-command construction, training launch, and
output handling.

Requirements (installed automatically on first use):
  - kohya sd-scripts (cloned from GitHub)
  - accelerate, lion-pytorch, prodigy-optimizer, safetensors, diffusers and
    transformers (pip-installed into the running interpreter)

The training command also passes ``--xformers`` and defaults to the AdamW8bit
optimizer, which rely on the xformers and bitsandbytes packages; those are not
installed automatically.
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# True on a Hugging Face Spaces deployment (SPACE_ID present) or when a
# deployment opts in explicitly with HF_SPACES=1.
IS_HF_SPACES = os.environ.get("HF_SPACES") == "1" or os.environ.get("SPACE_ID") is not None

if IS_HF_SPACES:
    _DEFAULT_TRAINING_DIR = "/app/data/training"
    _DEFAULT_LORA_DIR = "/app/data/loras"
else:
    # Local defaults for the original developer machine.
    _DEFAULT_TRAINING_DIR = "D:/AI automation/content_engine/training"
    _DEFAULT_LORA_DIR = "D:/ComfyUI/Models/Lora"

# Working area for per-job datasets/outputs and the sd-scripts checkout.
# LORA_TRAINING_DIR overrides the environment-specific default above.
TRAINING_BASE_DIR = Path(os.environ.get("LORA_TRAINING_DIR") or _DEFAULT_TRAINING_DIR)
# Destination for finished LoRA files; LORA_OUTPUT_DIR overrides the default.
LORA_OUTPUT_DIR = Path(os.environ.get("LORA_OUTPUT_DIR") or _DEFAULT_LORA_DIR)
# Location of the cloned Kohya sd-scripts repository.
SD_SCRIPTS_DIR = TRAINING_BASE_DIR / "sd-scripts"
def _default_base_model() -> str:
    """Return the base checkpoint path used when a config leaves it empty.

    The ``LORA_BASE_MODEL`` environment variable, when set to a non-empty
    value, takes precedence over the environment-specific defaults.
    """
    override = os.environ.get("LORA_BASE_MODEL")
    if override:
        return override
    if IS_HF_SPACES:
        return "/app/models/realisticVisionV51_v51VAE.safetensors"
    return "D:/ComfyUI/Models/StableDiffusion/realisticVisionV51_v51VAE.safetensors"
@dataclass
class TrainingConfig:
    """Configuration for a LoRA training job.

    Most fields are forwarded by the trainer as Kohya sd-scripts
    ``train_network.py`` command-line options.
    """

    name: str  # output LoRA file name, without the .safetensors extension
    base_model: str = ""  # checkpoint path; empty -> _default_base_model()
    resolution: int = 512
    train_batch_size: int = 1
    num_epochs: int = 10
    learning_rate: float = 1e-4
    network_rank: int = 32  # LoRA rank (dim)
    network_alpha: int = 16
    optimizer: str = "AdamW8bit"  # AdamW8bit, Lion, Prodigy
    lr_scheduler: str = "cosine_with_restarts"
    max_train_steps: int | None = None  # if set, meant to replace num_epochs as the stop condition
    save_every_n_epochs: int = 2
    clip_skip: int = 1
    mixed_precision: str = "fp16"
    seed: int = 42
    caption_extension: str = ".txt"
    trigger_word: str = ""  # caption text written next to every training image
    extra_args: dict[str, Any] = field(default_factory=dict)  # extra options keyed by sd-scripts flag name

    def __post_init__(self):
        # Resolve the default at construction time so it reflects the
        # current environment rather than import time.
        if not self.base_model:
            self.base_model = _default_base_model()
@dataclass
class TrainingJob:
    """Tracks state of a running or completed training job."""

    id: str
    name: str
    config: TrainingConfig
    status: str = "pending"  # pending, preparing, training, completed, failed
    progress: float = 0.0  # fraction in [0, 1]
    current_epoch: int = 0
    total_epochs: int = 0
    current_step: int = 0
    total_steps: int = 0
    loss: float | None = None  # most recent loss value parsed from the log
    started_at: float | None = None  # time.time() timestamp
    completed_at: float | None = None  # time.time() timestamp
    output_path: str | None = None  # path of the copied LoRA file, if any
    error: str | None = None
    log_lines: list[str] = field(default_factory=list)  # recent trainer output
class LoRATrainer:
    """Manages LoRA training jobs using Kohya sd-scripts.

    Job state lives in memory only. Each job gets ``TRAINING_BASE_DIR/<job_id>``
    with a ``dataset`` and an ``output`` subdirectory; finished
    ``.safetensors`` files are copied into ``LORA_OUTPUT_DIR``.
    """

    # Number of most recent log lines kept per job.
    _MAX_LOG_LINES = 200
    # Bytes requested per read from the trainer's merged stdout/stderr pipe.
    _READ_CHUNK_SIZE = 4096
    # Progress bars (e.g. tqdm) redraw with carriage returns, so both CR and LF
    # end a line. Splitting raw bytes is safe for UTF-8 output because these
    # byte values never occur inside a multi-byte character.
    _LINE_SPLIT_RE = re.compile(rb"[\r\n]")
    # Matches "epoch 3/10".
    _EPOCH_RE = re.compile(r"\bepoch\s+(\d+)\s*/\s*(\d+)", re.IGNORECASE)
    # Matches "loss=0.12", "avr_loss=0.0912]" or "loss: 1e-3".
    _LOSS_RE = re.compile(r"loss[=:]\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)")
    # Matches "100/1000" in a step progress line.
    _STEP_RE = re.compile(r"(\d+)/(\d+)")
    # Characters that would create nested or invalid directory names.
    _UNSAFE_PATH_CHARS_RE = re.compile(r'[\\/:*?"<>|]')

    def __init__(self):
        self._jobs: dict[str, TrainingJob] = {}
        self._processes: dict[str, asyncio.subprocess.Process] = {}
        # Strong references to background tasks: the event loop only keeps
        # weak references, so an unreferenced task can vanish mid-run.
        self._tasks: set[asyncio.Task] = set()
        # Jobs cancelled via cancel_job; _run_training must not overwrite
        # their outcome with the process exit code.
        self._cancelled: set[str] = set()
        TRAINING_BASE_DIR.mkdir(parents=True, exist_ok=True)

    @property
    def sd_scripts_installed(self) -> bool:
        """True when the sd-scripts checkout contains ``train_network.py``."""
        return (SD_SCRIPTS_DIR / "train_network.py").exists()

    async def install_sd_scripts(self) -> str:
        """Clone Kohya sd-scripts and pip-install its Python dependencies.

        Returns:
            A short status message.

        Raises:
            RuntimeError: If ``git clone`` exits with a non-zero status.
        """
        if self.sd_scripts_installed:
            return "sd-scripts already installed"

        SD_SCRIPTS_DIR.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Cloning kohya sd-scripts...")
        proc = await asyncio.create_subprocess_exec(
            "git", "clone", "--depth", "1",
            "https://github.com/kohya-ss/sd-scripts.git",
            str(SD_SCRIPTS_DIR),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"Failed to clone sd-scripts: {stderr.decode()}")

        logger.info("Installing sd-scripts requirements...")
        # NOTE(review): the Prodigy optimizer is usually published on PyPI as
        # "prodigyopt" — confirm "prodigy-optimizer" resolves, otherwise pip
        # rejects the whole install.
        proc = await asyncio.create_subprocess_exec(
            sys.executable, "-m", "pip", "install",
            "accelerate", "lion-pytorch", "prodigy-optimizer",
            "safetensors", "diffusers", "transformers",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            # Best effort: the packages may already be present, so keep going,
            # but surface the failure instead of silently ignoring it.
            logger.warning(
                "pip install of sd-scripts dependencies failed (exit code %s): %s",
                proc.returncode,
                stderr.decode("utf-8", errors="replace")[-2000:],
            )
            return "sd-scripts cloned, but dependency installation failed (see logs)"

        logger.info("sd-scripts installation complete")
        return "sd-scripts installed successfully"

    def prepare_dataset(
        self,
        job_id: str,
        image_paths: list[str],
        trigger_word: str = "",
        caption_extension: str = ".txt",
    ) -> Path:
        """Prepare a training dataset directory in Kohya's layout.

        Creates ``training/{job_id}/dataset/{num_repeats}_{concept}/``, copies
        every existing image into it and writes a caption file containing the
        trigger word next to each image (``<image stem><caption_extension>``,
        UTF-8 encoded).

        Args:
            job_id: Job identifier; names the per-job directory.
            image_paths: Source image files. Missing paths are skipped with a
                warning.
            trigger_word: Caption text, also used (sanitized) as concept name.
            caption_extension: Caption file extension; must match the value
                passed to sd-scripts as ``--caption_extension``.

        Returns:
            The dataset root directory (parent of the concept directory).

        Raises:
            ValueError: If none of ``image_paths`` could be copied.
        """
        dataset_dir = TRAINING_BASE_DIR / job_id / "dataset"
        # Convention: {repeats}_{concept_name}. Path separators or reserved
        # characters in the trigger word would otherwise create a nested or
        # invalid folder that the trainer cannot find.
        repeats = 10
        concept = self._UNSAFE_PATH_CHARS_RE.sub("_", trigger_word).strip() or "character"
        concept_dir = dataset_dir / f"{repeats}_{concept}"
        concept_dir.mkdir(parents=True, exist_ok=True)

        used_stems: set[str] = set()
        copied = 0
        for img_path in image_paths:
            src = Path(img_path)
            if not src.is_file():
                logger.warning("Image not found: %s", img_path)
                continue
            # Captions are matched to images by stem, so two inputs with the
            # same stem (even from different folders) must not collide.
            stem = src.stem
            counter = 1
            while stem.lower() in used_stems:
                stem = f"{src.stem}_{counter}"
                counter += 1
            used_stems.add(stem.lower())

            shutil.copy2(src, concept_dir / f"{stem}{src.suffix}")
            (concept_dir / f"{stem}{caption_extension}").write_text(
                trigger_word or "", encoding="utf-8"
            )
            copied += 1

        if not copied:
            raise ValueError("No valid training images were provided")
        return dataset_dir

    async def start_training(self, config: TrainingConfig, image_paths: list[str]) -> str:
        """Start a LoRA training job in the background.

        Installs sd-scripts on first use, prepares the dataset and launches
        the trainer as a background task. Dataset problems are recorded on the
        job (status "failed") instead of being raised.

        Returns:
            The job ID, usable with :meth:`get_job` and :meth:`cancel_job`.

        Raises:
            RuntimeError: If sd-scripts must be installed and cloning fails.
        """
        job_id = str(uuid.uuid4())[:8]

        if not self.sd_scripts_installed:
            await self.install_sd_scripts()

        job = TrainingJob(
            id=job_id,
            name=config.name,
            config=config,
            status="preparing",
            total_epochs=config.num_epochs,
        )
        self._jobs[job_id] = job

        try:
            dataset_dir = self.prepare_dataset(
                job_id, image_paths, config.trigger_word, config.caption_extension
            )
        except Exception as e:
            job.status = "failed"
            job.error = f"Dataset preparation failed: {e}"
            logger.warning("Dataset preparation failed for job %s: %s", job_id, e)
            return job_id

        output_dir = TRAINING_BASE_DIR / job_id / "output"
        output_dir.mkdir(parents=True, exist_ok=True)

        cmd = self._build_training_command(config, dataset_dir, output_dir)
        job.log_lines.append(f"Command: {' '.join(cmd)}")

        job.status = "training"
        job.started_at = time.time()
        task = asyncio.create_task(self._run_training(job_id, cmd, output_dir, config))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return job_id

    def _build_training_command(
        self, config: TrainingConfig, dataset_dir: Path, output_dir: Path
    ) -> list[str]:
        """Build the ``train_network.py`` argument list for ``config``."""
        cmd = [
            sys.executable,
            str(SD_SCRIPTS_DIR / "train_network.py"),
            f"--pretrained_model_name_or_path={config.base_model}",
            f"--train_data_dir={dataset_dir}",
            f"--output_dir={output_dir}",
            f"--output_name={config.name}",
            f"--resolution={config.resolution}",
            f"--train_batch_size={config.train_batch_size}",
            f"--learning_rate={config.learning_rate}",
            "--network_module=networks.lora",
            f"--network_dim={config.network_rank}",
            f"--network_alpha={config.network_alpha}",
            f"--optimizer_type={config.optimizer}",
            f"--lr_scheduler={config.lr_scheduler}",
            f"--save_every_n_epochs={config.save_every_n_epochs}",
            f"--clip_skip={config.clip_skip}",
            f"--mixed_precision={config.mixed_precision}",
            f"--seed={config.seed}",
            f"--caption_extension={config.caption_extension}",
            "--cache_latents",
            "--enable_bucket",
            "--xformers",
            "--save_model_as=safetensors",
        ]
        if config.max_train_steps:
            # sd-scripts recomputes max_train_steps from max_train_epochs
            # whenever the latter is given, so pass only one of the two.
            cmd.append(f"--max_train_steps={config.max_train_steps}")
        else:
            cmd.append(f"--max_train_epochs={config.num_epochs}")
        cmd.extend(self._format_extra_args(config.extra_args))
        return cmd

    @staticmethod
    def _format_extra_args(extra_args: dict[str, Any]) -> list[str]:
        """Convert ``{flag: value}`` pairs into sd-scripts CLI arguments.

        ``True`` -> ``--flag``; ``False``/``None`` -> omitted; list/tuple ->
        ``--flag v1 v2 ...``; anything else -> ``--flag=value``.
        """
        args: list[str] = []
        for key, value in extra_args.items():
            flag = f"--{str(key).lstrip('-')}"
            if value is True:
                args.append(flag)
            elif value is None or value is False:
                continue
            elif isinstance(value, (list, tuple)):
                args.append(flag)
                args.extend(str(item) for item in value)
            else:
                args.append(f"{flag}={value}")
        return args

    @classmethod
    async def _iter_output_lines(cls, stream: asyncio.StreamReader):
        """Yield stripped, non-empty text lines from ``stream`` until EOF.

        Reads fixed-size chunks rather than using ``readline`` so that a long
        run of carriage-return-only progress updates cannot exceed the
        StreamReader line-length limit (which raises and stops monitoring).
        """
        pending = b""
        while True:
            chunk = await stream.read(cls._READ_CHUNK_SIZE)
            if not chunk:
                break
            pieces = cls._LINE_SPLIT_RE.split(pending + chunk)
            pending = pieces.pop()  # possibly incomplete last line
            for raw in pieces:
                text = raw.decode("utf-8", errors="replace").strip()
                if text:
                    yield text
        tail = pending.decode("utf-8", errors="replace").strip()
        if tail:
            yield tail

    @classmethod
    def _parse_progress(cls, job: TrainingJob, line: str) -> None:
        """Update epoch, step, loss and progress on ``job`` from one log line."""
        epoch_match = cls._EPOCH_RE.search(line)
        if epoch_match:
            job.current_epoch = int(epoch_match.group(1))
            job.total_epochs = int(epoch_match.group(2))
            job.progress = job.current_epoch / max(job.total_epochs, 1)

        loss_match = cls._LOSS_RE.search(line)
        if loss_match:
            job.loss = float(loss_match.group(1))

        lowered = line.lower()
        if "steps:" in lowered or "step " in lowered:
            step_match = cls._STEP_RE.search(line)
            if step_match:
                job.current_step = int(step_match.group(1))
                job.total_steps = int(step_match.group(2))
                if job.total_steps > 0:
                    job.progress = job.current_step / job.total_steps

    @staticmethod
    def _publish_outputs(output_dir: Path, config: TrainingConfig) -> str | None:
        """Copy trained LoRA file(s) into ``LORA_OUTPUT_DIR``.

        Prefers the final ``{config.name}.safetensors``; otherwise copies every
        ``*.safetensors`` checkpoint in ``output_dir``.

        Returns:
            Path of the last file copied, or None if nothing was found.
        """
        LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        final_file = output_dir / f"{config.name}.safetensors"
        if final_file.exists():
            dest = LORA_OUTPUT_DIR / final_file.name
            shutil.copy2(final_file, dest)
            logger.info("Training complete! LoRA saved to %s", dest)
            return str(dest)

        # Fall back to per-epoch checkpoints; sorted so the reported path is
        # the last name in sort order.
        published: str | None = None
        for checkpoint in sorted(output_dir.glob("*.safetensors")):
            dest = LORA_OUTPUT_DIR / checkpoint.name
            shutil.copy2(checkpoint, dest)
            published = str(dest)
        logger.info("Training complete! Output in %s", output_dir)
        return published

    async def _run_training(
        self, job_id: str, cmd: list[str], output_dir: Path, config: TrainingConfig
    ):
        """Run the training process for ``job_id`` and keep its record current.

        On exit code 0 the LoRA output is copied to ``LORA_OUTPUT_DIR`` and the
        job is marked completed; otherwise it is marked failed. A job that was
        cancelled through :meth:`cancel_job` keeps its cancellation error. If
        monitoring stops while the process is still running, it is killed.
        """
        job = self._jobs[job_id]
        proc: asyncio.subprocess.Process | None = None
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                cwd=str(SD_SCRIPTS_DIR),
                # Unbuffered child output so progress lines arrive when they
                # are printed, not when the pipe buffer fills.
                env={**os.environ, "PYTHONUNBUFFERED": "1"},
            )
            self._processes[job_id] = proc

            async for line in self._iter_output_lines(proc.stdout):
                job.log_lines.append(line)
                if len(job.log_lines) > self._MAX_LOG_LINES:
                    del job.log_lines[: -self._MAX_LOG_LINES]
                self._parse_progress(job, line)

            await proc.wait()

            if job_id in self._cancelled:
                # cancel_job already recorded the outcome.
                logger.info("Training job %s was cancelled", job_id)
            elif proc.returncode == 0:
                job.output_path = self._publish_outputs(output_dir, config)
                job.status = "completed"
                job.progress = 1.0
                job.completed_at = time.time()
            else:
                job.status = "failed"
                job.error = f"Training process exited with code {proc.returncode}"
                logger.error("Training failed: %s", job.error)

        except Exception as e:
            if job_id not in self._cancelled:
                job.status = "failed"
                job.error = str(e)
            logger.error("Training error: %s", e, exc_info=True)
        finally:
            self._processes.pop(job_id, None)
            self._cancelled.discard(job_id)
            if proc is not None and proc.returncode is None:
                # Monitoring stopped early: do not leave an orphaned trainer.
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass
                await proc.wait()

    def get_job(self, job_id: str) -> TrainingJob | None:
        """Return the job with ``job_id``, or None if unknown."""
        return self._jobs.get(job_id)

    def list_jobs(self) -> list[TrainingJob]:
        """Return all jobs known to this trainer, in insertion order."""
        return list(self._jobs.values())

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a running training job.

        Returns:
            True if a running process was found and terminated, else False.
        """
        proc = self._processes.pop(job_id, None)
        if proc is None:
            return False
        # Record the cancellation before signalling so _run_training keeps
        # this outcome instead of reporting the termination exit code.
        self._cancelled.add(job_id)
        try:
            proc.terminate()
        except ProcessLookupError:
            pass  # the process already exited on its own
        job = self._jobs.get(job_id)
        if job:
            job.status = "failed"
            job.error = "Cancelled by user"
        return True