"""RunPod cloud LoRA training — offload training to RunPod GPU pods. Creates a temporary GPU pod, uploads training images, runs Kohya sd-scripts, downloads the finished LoRA, then terminates the pod. No local GPU needed. Supports multiple base models (FLUX, SD 1.5, SDXL) via model registry. Usage: Set RUNPOD_API_KEY in .env Select "Cloud (RunPod)" in the training UI """ from __future__ import annotations import asyncio import logging import time import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Any import runpod import yaml logger = logging.getLogger(__name__) import os from content_engine.config import settings, IS_HF_SPACES from content_engine.models.database import catalog_session_factory, TrainingJob as TrainingJobDB LORA_OUTPUT_DIR = settings.paths.lora_dir if IS_HF_SPACES: CONFIG_DIR = Path("/app/config") else: CONFIG_DIR = Path("D:/AI automation/content_engine/config") # RunPod GPU options (id -> display name, approx cost/hr) GPU_OPTIONS = { # 24GB - SD 1.5, SDXL, FLUX.1 only (NOT enough for FLUX.2) "NVIDIA GeForce RTX 3090": "RTX 3090 24GB (~$0.22/hr)", "NVIDIA GeForce RTX 4090": "RTX 4090 24GB (~$0.44/hr)", "NVIDIA GeForce RTX 5090": "RTX 5090 32GB (~$0.69/hr)", "NVIDIA RTX A4000": "RTX A4000 16GB (~$0.20/hr)", "NVIDIA RTX A5000": "RTX A5000 24GB (~$0.28/hr)", # 48GB+ - Required for FLUX.2 Dev (Mistral text encoder needs ~48GB) "NVIDIA RTX A6000": "RTX A6000 48GB (~$0.76/hr)", "NVIDIA A40": "A40 48GB (~$0.64/hr)", "NVIDIA L40": "L40 48GB (~$0.89/hr)", "NVIDIA L40S": "L40S 48GB (~$1.09/hr)", "NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)", "NVIDIA A100-SXM4-80GB": "A100 SXM 80GB (~$1.64/hr)", "NVIDIA H100 80GB HBM3": "H100 80GB (~$3.89/hr)", } DEFAULT_GPU = "NVIDIA GeForce RTX 4090" # Network volume for persistent model storage (avoids re-downloading models each run) # Set RUNPOD_VOLUME_ID in .env to use a persistent volume # Set RUNPOD_VOLUME_DC to the datacenter ID where the volume lives (e.g. "EU-RO-1") NETWORK_VOLUME_ID = os.environ.get("RUNPOD_VOLUME_ID", "") NETWORK_VOLUME_DC = os.environ.get("RUNPOD_VOLUME_DC", "") # Docker image with PyTorch + CUDA pre-installed DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04" def load_model_registry() -> dict[str, dict]: """Load training model configurations from config/models.yaml.""" models_file = CONFIG_DIR / "models.yaml" if not models_file.exists(): logger.warning("Model registry not found: %s", models_file) return {} with open(models_file) as f: config = yaml.safe_load(f) return config.get("training_models", {}) @dataclass class CloudTrainingJob: """Tracks state of a RunPod cloud training job.""" id: str name: str status: str = "pending" # pending, creating_pod, uploading, installing, training, downloading, completed, failed progress: float = 0.0 current_epoch: int = 0 total_epochs: int = 0 current_step: int = 0 total_steps: int = 0 loss: float | None = None started_at: float | None = None completed_at: float | None = None output_path: str | None = None error: str | None = None log_lines: list[str] = field(default_factory=list) pod_id: str | None = None gpu_type: str = DEFAULT_GPU cost_estimate: str | None = None base_model: str = "sd15_realistic" model_type: str = "sd15" _db_callback: Any = None # called on state changes to persist to DB def _log(self, msg: str): self.log_lines.append(msg) if len(self.log_lines) > 200: self.log_lines = self.log_lines[-200:] logger.info("[%s] %s", self.id, msg) if self._db_callback: self._db_callback(self) class RunPodTrainer: """Manages LoRA training on RunPod cloud GPUs.""" def __init__(self, api_key: str): self._api_key = api_key runpod.api_key = api_key self._jobs: dict[str, CloudTrainingJob] = {} self._model_registry = load_model_registry() self._loaded_from_db = False @property def available(self) -> bool: """Check if RunPod is configured.""" # Re-set module-level key in case uvicorn reload cleared it if self._api_key: runpod.api_key = self._api_key return bool(self._api_key) def list_gpu_options(self) -> dict[str, str]: return GPU_OPTIONS def list_training_models(self) -> dict[str, dict]: """List available base models for training with their parameters.""" return { key: { "name": cfg.get("name", key), "description": cfg.get("description", ""), "model_type": cfg.get("model_type", "sd15"), "resolution": cfg.get("resolution", 512), "learning_rate": cfg.get("learning_rate", 1e-4), "network_rank": cfg.get("network_rank", 32), "network_alpha": cfg.get("network_alpha", 16), "optimizer": cfg.get("optimizer", "AdamW8bit"), "lr_scheduler": cfg.get("lr_scheduler", "cosine"), "vram_required_gb": cfg.get("vram_required_gb", 8), "recommended_images": cfg.get("recommended_images", "15-30 photos"), } for key, cfg in self._model_registry.items() } def get_model_config(self, model_key: str) -> dict | None: """Get configuration for a specific training model.""" return self._model_registry.get(model_key) async def start_training( self, *, name: str, image_paths: list[str], trigger_word: str = "", base_model: str = "sd15_realistic", resolution: int | None = None, num_epochs: int = 10, max_train_steps: int | None = None, learning_rate: float | None = None, network_rank: int | None = None, network_alpha: int | None = None, optimizer: str | None = None, save_every_n_epochs: int = 2, save_every_n_steps: int = 500, gpu_type: str = DEFAULT_GPU, ) -> str: """Start a cloud training job. Returns job ID. Parameters use model registry defaults if not specified. """ job_id = str(uuid.uuid4())[:8] # Get model config (fall back to sd15_realistic if not found) model_cfg = self._model_registry.get(base_model, self._model_registry.get("sd15_realistic", {})) model_type = model_cfg.get("model_type", "sd15") # Use provided values or model defaults final_resolution = resolution or model_cfg.get("resolution", 512) final_lr = learning_rate or model_cfg.get("learning_rate", 1e-4) final_rank = network_rank or model_cfg.get("network_rank", 32) final_alpha = network_alpha or model_cfg.get("network_alpha", 16) final_optimizer = optimizer or model_cfg.get("optimizer", "AdamW8bit") final_steps = max_train_steps or model_cfg.get("max_train_steps") job = CloudTrainingJob( id=job_id, name=name, status="pending", total_epochs=num_epochs, total_steps=final_steps, gpu_type=gpu_type, started_at=time.time(), base_model=base_model, model_type=model_type, ) self._jobs[job_id] = job job._db_callback = self._schedule_db_save asyncio.ensure_future(self._save_to_db(job)) # Launch the full pipeline as a background task asyncio.create_task(self._run_cloud_training( job=job, image_paths=image_paths, trigger_word=trigger_word, model_cfg=model_cfg, resolution=final_resolution, num_epochs=num_epochs, max_train_steps=final_steps, learning_rate=final_lr, network_rank=final_rank, network_alpha=final_alpha, optimizer=final_optimizer, save_every_n_epochs=save_every_n_epochs, save_every_n_steps=save_every_n_steps, )) return job_id async def _run_cloud_training( self, job: CloudTrainingJob, image_paths: list[str], trigger_word: str, model_cfg: dict, resolution: int, num_epochs: int, max_train_steps: int | None, learning_rate: float, network_rank: int, network_alpha: int, optimizer: str, save_every_n_epochs: int, save_every_n_steps: int = 500, ): """Full cloud training pipeline: create pod -> upload -> train -> download -> cleanup.""" ssh = None sftp = None model_type = model_cfg.get("model_type", "sd15") name = job.name try: # Step 1: Create pod job.status = "creating_pod" job._log(f"Creating RunPod with {job.gpu_type}...") # Use network volume if configured (persists models across runs) pod_kwargs = { "container_disk_in_gb": 30, "ports": "22/tcp", "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'", } if NETWORK_VOLUME_ID: pod_kwargs["network_volume_id"] = NETWORK_VOLUME_ID if NETWORK_VOLUME_DC: pod_kwargs["data_center_id"] = NETWORK_VOLUME_DC job._log(f"Using persistent network volume: {NETWORK_VOLUME_ID} (DC: {NETWORK_VOLUME_DC or 'auto'})") else: pod_kwargs["volume_in_gb"] = 75 pod = await asyncio.to_thread( runpod.create_pod, f"lora-train-{job.id}", DOCKER_IMAGE, job.gpu_type, **pod_kwargs, ) job.pod_id = pod["id"] job._log(f"Pod created: {job.pod_id}") # Wait for pod to be ready and get SSH info job._log("Waiting for pod to start...") ssh_host, ssh_port = await self._wait_for_pod_ready(job) job._log(f"Pod ready at {ssh_host}:{ssh_port}") # Step 2: Connect via SSH import paramiko ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) for attempt in range(30): try: await asyncio.to_thread( ssh.connect, ssh_host, port=ssh_port, username="root", password="runpod", timeout=10, ) break except Exception: if attempt == 29: raise RuntimeError("Could not SSH into pod after 30 attempts") await asyncio.sleep(5) job._log("SSH connected") # If using network volume, symlink to /workspace so all paths work if NETWORK_VOLUME_ID: await self._ssh_exec(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models") job._log("Network volume symlinked to /workspace") # Enable keepalive to prevent SSH timeout during uploads transport = ssh.get_transport() transport.set_keepalive(30) sftp = ssh.open_sftp() sftp.get_channel().settimeout(300) # 5 min timeout per file # Step 3: Upload training images (compress first to speed up transfer) job.status = "uploading" resolution = model_cfg.get("resolution", 1024) job._log(f"Compressing and uploading {len(image_paths)} training images...") import tempfile from PIL import Image tmp_dir = Path(tempfile.mkdtemp(prefix="lora_upload_")) folder_name = f"10_{trigger_word or 'character'}" await self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}") for i, img_path in enumerate(image_paths): p = Path(img_path) if p.exists(): # Resize and convert to JPEG to reduce upload size try: img = Image.open(p) img.thumbnail((resolution * 2, resolution * 2), Image.LANCZOS) compressed = tmp_dir / f"{p.stem}.jpg" img.save(compressed, "JPEG", quality=95) upload_path = compressed except Exception: upload_path = p # fallback to original remote_name = f"{p.stem}.jpg" if upload_path.suffix == ".jpg" else p.name remote_path = f"/workspace/dataset/{folder_name}/{remote_name}" for attempt in range(3): try: await asyncio.to_thread(sftp.put, str(upload_path), remote_path) break except (EOFError, OSError): if attempt == 2: raise job._log(f"Upload retry {attempt+1} for {p.name}") sftp.close() sftp = ssh.open_sftp() sftp.get_channel().settimeout(300) # Upload matching caption .txt file if it exists locally local_caption = p.with_suffix(".txt") if local_caption.exists(): remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt" await asyncio.to_thread(sftp.put, str(local_caption), remote_caption) else: # Fallback: create caption from trigger word remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt" def _write_caption(): with sftp.open(remote_caption, "w") as f: f.write(trigger_word or "") await asyncio.to_thread(_write_caption) job._log(f"Uploaded {i+1}/{len(image_paths)}: {p.name}") # Cleanup temp compressed images import shutil shutil.rmtree(tmp_dir, ignore_errors=True) job._log("Images uploaded") # Step 4: Install training framework on the pod (skip if cached on volume) job.status = "installing" job.progress = 0.05 training_framework = model_cfg.get("training_framework", "sd-scripts") if training_framework == "musubi-tuner": # FLUX.2 uses musubi-tuner (Kohya's newer framework) tuner_dir = "/workspace/musubi-tuner" install_cmds = [] # Check if already present in workspace tuner_exist = (await self._ssh_exec(ssh, f"test -f {tuner_dir}/pyproject.toml && echo EXISTS || echo MISSING")).strip() if tuner_exist == "EXISTS": job._log("musubi-tuner found in workspace") else: # Check volume cache vol_exist = (await self._ssh_exec(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip() if vol_exist == "EXISTS": job._log("Restoring musubi-tuner from volume cache...") await self._ssh_exec(ssh, f"rm -rf {tuner_dir} 2>/dev/null; cp -r /runpod-volume/musubi-tuner {tuner_dir}") else: job._log("Cloning musubi-tuner from GitHub...") await self._ssh_exec(ssh, f"rm -rf {tuner_dir} /runpod-volume/musubi-tuner 2>/dev/null; true") install_cmds.append(f"cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git") # Save to volume for future pods if NETWORK_VOLUME_ID: install_cmds.append(f"cp -r {tuner_dir} /runpod-volume/musubi-tuner") # Always install pip deps (they are pod-local, lost on every new pod) job._log("Installing pip dependencies (accelerate, torch, etc.)...") install_cmds.extend([ f"cd {tuner_dir} && pip install -e . 2>&1 | tail -5", "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes 2>&1 | tail -5", ]) else: # SD 1.5 / SDXL / FLUX.1 use sd-scripts scripts_exist = (await self._ssh_exec(ssh, "test -f /workspace/sd-scripts/setup.py && echo EXISTS || echo MISSING")).strip() if scripts_exist == "EXISTS": job._log("Kohya sd-scripts already cached on volume, updating...") install_cmds = [ "cd /workspace/sd-scripts && git pull 2>&1 | tail -1", ] else: job._log("Installing Kohya sd-scripts (this takes a few minutes)...") install_cmds = [ "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git", ] # Always install pip deps (pod-local, lost on new pods) install_cmds.extend([ "cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1", "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes xformers 2>&1 | tail -1", ]) for cmd in install_cmds: out = await self._ssh_exec(ssh, cmd, timeout=600) job._log(out[:200] if out else "done") # Download base model from HuggingFace (skip if already on network volume) hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE") hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors") model_name = model_cfg.get("name", job.base_model) job.progress = 0.1 await self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120) if model_type == "flux2": # FLUX.2 models are stored in a directory structure on the volume flux2_dir = "/workspace/models/FLUX.2-dev" dit_path = f"{flux2_dir}/flux2-dev.safetensors" vae_path = f"{flux2_dir}/ae.safetensors" # Original BFL format (not diffusers) te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" dit_exists = (await self._ssh_exec(ssh, f"test -f {dit_path} && echo EXISTS || echo MISSING")).strip() vae_exists = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip() te_exists = (await self._ssh_exec(ssh, f"test -f {te_path} && echo EXISTS || echo MISSING")).strip() if dit_exists != "EXISTS" or te_exists != "EXISTS": missing = [] if dit_exists != "EXISTS": missing.append("DiT") if te_exists != "EXISTS": missing.append("text encoder") raise RuntimeError(f"FLUX.2 Dev missing on volume: {', '.join(missing)}. Please download models to the network volume first.") # Download ae.safetensors (original format VAE) if not present if vae_exists != "EXISTS": job._log("Downloading FLUX.2 VAE (ae.safetensors, 336MB)...") await self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120) await self._ssh_exec(ssh, f"""python -c " from huggingface_hub import hf_hub_download hf_hub_download('black-forest-labs/FLUX.2-dev', 'ae.safetensors', local_dir='{flux2_dir}') print('Downloaded ae.safetensors') " 2>&1 | tail -5""", timeout=600) # Verify download vae_check = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip() if vae_check != "EXISTS": raise RuntimeError("Failed to download ae.safetensors") job._log("VAE downloaded") job._log("FLUX.2 Dev models ready") elif model_type == "wan22": # WAN 2.2 T2V — 4 model files stored in /workspace/models/WAN2.2/ wan_dir = "/workspace/models/WAN2.2" await self._ssh_exec(ssh, f"mkdir -p {wan_dir}") wan_files = { "DiT low-noise": { "path": f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors", "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "filename": "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors", }, "DiT high-noise": { "path": f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors", "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "filename": "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors", }, "VAE": { "path": f"{wan_dir}/Wan2.1_VAE.pth", "repo": "Wan-AI/Wan2.1-I2V-14B-720P", "filename": "Wan2.1_VAE.pth", }, "T5 text encoder": { "path": f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth", "repo": "Wan-AI/Wan2.1-I2V-14B-720P", "filename": "models_t5_umt5-xxl-enc-bf16.pth", }, } for label, info in wan_files.items(): exists = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() if exists == "EXISTS": job._log(f"WAN 2.2 {label} already cached") else: job._log(f"Downloading WAN 2.2 {label}...") await self._ssh_exec(ssh, f"""python -c " from huggingface_hub import hf_hub_download hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}') # hf_hub_download puts files in subdirs matching the filename path — move to root import os, shutil downloaded = os.path.join('{wan_dir}', '{info['filename']}') target = '{info['path']}' if os.path.exists(downloaded) and downloaded != target: shutil.move(downloaded, target) print('Downloaded {label}') " 2>&1 | tail -5""", timeout=1800) # Verify check = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() if check != "EXISTS": raise RuntimeError(f"Failed to download WAN 2.2 {label}") job._log("WAN 2.2 models ready") else: # SD 1.5 / SDXL / FLUX.1 — download single model file model_exists = (await self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING")).strip() if model_exists == "EXISTS": job._log(f"Base model already cached on volume: {model_name}") else: job._log(f"Downloading base model: {model_name}...") await self._ssh_exec(ssh, f""" python -c " from huggingface_hub import hf_hub_download hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models') " 2>&1 | tail -5 """, timeout=1200) # For FLUX.1, download additional required models (CLIP, T5, VAE) if model_type == "flux": flux_files_check = (await self._ssh_exec(ssh, "test -f /workspace/models/clip_l.safetensors && test -f /workspace/models/t5xxl_fp16.safetensors && test -f /workspace/models/ae.safetensors && echo EXISTS || echo MISSING")).strip() if flux_files_check == "EXISTS": job._log("FLUX.1 auxiliary models already cached on volume") else: job._log("Downloading FLUX.1 auxiliary models (CLIP, T5, VAE)...") job.progress = 0.12 await self._ssh_exec(ssh, """ python -c " from huggingface_hub import hf_hub_download hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models') hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models') hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models') " 2>&1 | tail -5 """, timeout=1200) job._log("Base model ready") job.progress = 0.15 # Step 5: Run training job.status = "training" job._log(f"Starting {model_type.upper()} LoRA training...") if model_type == "flux2": model_path = f"/workspace/models/FLUX.2-dev/flux2-dev.safetensors" elif model_type == "wan22": model_path = "/workspace/models/WAN2.2/wan2.2_t2v_low_noise_14B_fp16.safetensors" else: model_path = f"/workspace/models/{hf_filename}" # For musubi-tuner, create TOML dataset config if training_framework == "musubi-tuner": folder_name = f"10_{trigger_word or 'character'}" toml_content = f"""[[datasets]] image_directory = "/workspace/dataset/{folder_name}" caption_extension = ".txt" batch_size = 1 num_repeats = 10 resolution = [{resolution}, {resolution}] """ await self._ssh_exec(ssh, f"cat > /workspace/dataset.toml << 'TOMLEOF'\n{toml_content}TOMLEOF") job._log("Created dataset.toml config") # musubi-tuner requires pre-caching latents and text encoder outputs if model_type == "wan22": wan_dir = "/workspace/models/WAN2.2" vae_path = f"{wan_dir}/Wan2.1_VAE.pth" te_path = f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth" job._log("Caching WAN 2.2 latents (VAE encoding)...") job.progress = 0.15 self._schedule_db_save(job) cache_latents_cmd = ( f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" f" python src/musubi_tuner/wan_cache_latents.py" f" --dataset_config /workspace/dataset.toml" f" --vae {vae_path}" f" --vae_dtype bfloat16" f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}" ) out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600) last_lines = out.split('\n')[-30:] job._log('\n'.join(last_lines)) if "EXIT_CODE=0" not in out: err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10") job._log(f"Cache error details: {err_log}") raise RuntimeError(f"WAN latent caching failed") job._log("Caching WAN 2.2 text encoder outputs (T5)...") job.progress = 0.25 self._schedule_db_save(job) cache_te_cmd = ( f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" f" python src/musubi_tuner/wan_cache_text_encoder_outputs.py" f" --dataset_config /workspace/dataset.toml" f" --t5 {te_path}" f" --batch_size 16" f" 2>&1; echo EXIT_CODE=$?" ) out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600) job._log(out[-500:] if out else "done") if "EXIT_CODE=0" not in out: raise RuntimeError(f"WAN text encoder caching failed: {out[-200:]}") else: # FLUX.2 caching flux2_dir = "/workspace/models/FLUX.2-dev" vae_path = f"{flux2_dir}/ae.safetensors" te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" job._log("Caching latents (VAE encoding)...") job.progress = 0.15 self._schedule_db_save(job) cache_latents_cmd = ( f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py" f" --dataset_config /workspace/dataset.toml" f" --vae {vae_path}" f" --model_version dev" f" --vae_dtype bfloat16" f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}" ) out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600) last_lines = out.split('\n')[-30:] job._log('\n'.join(last_lines)) if "EXIT_CODE=0" not in out: err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10") job._log(f"Cache error details: {err_log}") raise RuntimeError(f"Latent caching failed") job._log("Caching text encoder outputs (bf16)...") job.progress = 0.25 self._schedule_db_save(job) cache_te_cmd = ( f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py" f" --dataset_config /workspace/dataset.toml" f" --text_encoder {te_path}" f" --model_version dev" f" --batch_size 1" f" 2>&1; echo EXIT_CODE=$?" ) out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600) job._log(out[-500:] if out else "done") if "EXIT_CODE=0" not in out: raise RuntimeError(f"Text encoder caching failed: {out[-200:]}") # Build training command based on model type train_cmd = self._build_training_command( model_type=model_type, model_path=model_path, name=name, resolution=resolution, num_epochs=num_epochs, max_train_steps=max_train_steps, learning_rate=learning_rate, network_rank=network_rank, network_alpha=network_alpha, optimizer=optimizer, save_every_n_epochs=save_every_n_epochs, save_every_n_steps=save_every_n_steps, model_cfg=model_cfg, gpu_type=job.gpu_type, ) # Execute training in a detached process (survives SSH disconnect) job._log("Starting training (detached — survives disconnects)...") log_file = "/tmp/training.log" pid_file = "/tmp/training.pid" exit_file = "/tmp/training.exit" await self._ssh_exec(ssh, f"rm -f {log_file} {exit_file} {pid_file}") # Write training command to a script file (avoids quoting issues with nohup) script_file = "/tmp/train.sh" await self._ssh_exec(ssh, f"cat > {script_file} << 'TRAINEOF'\n#!/bin/bash\n{train_cmd} > {log_file} 2>&1\necho $? > {exit_file}\nTRAINEOF") await self._ssh_exec(ssh, f"chmod +x {script_file}") # Verify script was written script_check = (await self._ssh_exec(ssh, f"wc -l < {script_file}")).strip() job._log(f"Training script written ({script_check} lines)") # Launch fully detached: close all FDs so SSH channel doesn't hang await self._ssh_exec( ssh, f"setsid {script_file} /dev/null 2>&1 &\necho $! > {pid_file}", timeout=15, ) await asyncio.sleep(3) pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip() if not pid: # Fallback: find the process by script name pid = (await self._ssh_exec(ssh, "pgrep -f train.sh 2>/dev/null | head -1")).strip() job._log(f"Training PID: {pid}") # Verify process is actually running if pid: running = (await self._ssh_exec(ssh, f"kill -0 {pid} 2>&1 && echo RUNNING || echo DEAD")).strip() job._log(f"Process status: {running}") if "DEAD" in running: # Check if it already wrote an exit code (fast failure) early_exit = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip() raise RuntimeError(f"Training process died immediately. Exit: {early_exit}\nLog: {early_log}") else: early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip() raise RuntimeError(f"Failed to start training process.\nLog: {early_log}") # Monitor the log file (reconnect-safe) last_offset = 0 while True: # Check if training finished exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() if exit_check: exit_code = int(exit_check) # Read remaining log remaining = (await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)) if remaining: for line in remaining.split("\n"): line = line.strip() if line: job._log(line) self._parse_progress(job, line) if exit_code != 0: raise RuntimeError(f"Training failed with exit code {exit_code}") break # Read new log output try: new_output = (await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)) if new_output: last_offset += len(new_output.encode("utf-8")) for line in new_output.replace("\r", "\n").split("\n"): line = line.strip() if not line: continue job._log(line) self._parse_progress(job, line) self._schedule_db_save(job) except Exception: job._log("Log read failed, retrying...") await asyncio.sleep(5) job._log("Training completed on RunPod!") job.progress = 0.9 # Step 6: Save LoRA to network volume and download locally job.status = "downloading" # First, copy to network volume for persistence job._log("Saving LoRA to network volume...") await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras") remote_output = f"/workspace/output/{name}.safetensors" # Find the output file check = (await self._ssh_exec(ssh, f"test -f {remote_output} && echo EXISTS || echo MISSING")).strip() if check == "MISSING": remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip() if remote_files: remote_output = remote_files.split("\n")[-1].strip() else: raise RuntimeError("No .safetensors output found") await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors") job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors") # Also save intermediate checkpoints (step 500, 1000, 1500, etc.) checkpoint_files = (await self._ssh_exec(ssh, f"ls /workspace/output/{name}-step*.safetensors 2>/dev/null")).strip() if checkpoint_files: for ckpt in checkpoint_files.split("\n"): ckpt = ckpt.strip() if ckpt: ckpt_name = ckpt.split("/")[-1] await self._ssh_exec(ssh, f"cp {ckpt} /runpod-volume/loras/{ckpt_name}") job._log(f"Checkpoint saved: /runpod-volume/loras/{ckpt_name}") # Download locally (skip on HF Spaces — limited storage) if IS_HF_SPACES: job.output_path = f"/runpod-volume/loras/{name}.safetensors" job._log("LoRA saved on RunPod volume (ready for generation)") else: job._log("Downloading LoRA to local machine...") LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) local_path = LORA_OUTPUT_DIR / f"{name}.safetensors" await asyncio.to_thread(sftp.get, remote_output, str(local_path)) job.output_path = str(local_path) job._log(f"LoRA saved locally to {local_path}") # Done! job.status = "completed" job.progress = 1.0 job.completed_at = time.time() elapsed = (job.completed_at - job.started_at) / 60 job._log(f"Cloud training complete in {elapsed:.1f} minutes") except Exception as e: job.status = "failed" job.error = str(e) job._log(f"ERROR: {e}") logger.error("Cloud training failed: %s", e, exc_info=True) finally: # Cleanup: close SSH and terminate pod if sftp: try: sftp.close() except Exception: pass if ssh: try: ssh.close() except Exception: pass # Clean up local training images (saves HF Spaces storage) if image_paths: import shutil first_image_dir = Path(image_paths[0]).parent if first_image_dir.exists() and "training_uploads" in str(first_image_dir): shutil.rmtree(first_image_dir, ignore_errors=True) if job.pod_id: try: job._log("Terminating RunPod...") await asyncio.to_thread(runpod.terminate_pod, job.pod_id) job._log("Pod terminated") except Exception as e: job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}") def _schedule_db_save(self, job: CloudTrainingJob): """Schedule a DB save (non-blocking).""" try: asyncio.get_event_loop().create_task(self._save_to_db(job)) except RuntimeError: pass # no event loop async def _save_to_db(self, job: CloudTrainingJob): """Persist job state to database.""" try: from sqlalchemy import text async with catalog_session_factory() as session: # Use raw INSERT OR REPLACE for SQLite upsert await session.execute( text("""INSERT OR REPLACE INTO training_jobs (id, name, status, progress, current_epoch, total_epochs, current_step, total_steps, loss, started_at, completed_at, output_path, error, log_text, pod_id, gpu_type, backend, base_model, model_type, created_at) VALUES (:id, :name, :status, :progress, :current_epoch, :total_epochs, :current_step, :total_steps, :loss, :started_at, :completed_at, :output_path, :error, :log_text, :pod_id, :gpu_type, :backend, :base_model, :model_type, COALESCE((SELECT created_at FROM training_jobs WHERE id = :id), CURRENT_TIMESTAMP)) """), { "id": job.id, "name": job.name, "status": job.status, "progress": job.progress, "current_epoch": job.current_epoch, "total_epochs": job.total_epochs, "current_step": job.current_step, "total_steps": job.total_steps, "loss": job.loss, "started_at": job.started_at, "completed_at": job.completed_at, "output_path": job.output_path, "error": job.error, "log_text": "\n".join(job.log_lines[-200:]), "pod_id": job.pod_id, "gpu_type": job.gpu_type, "backend": "runpod", "base_model": job.base_model, "model_type": job.model_type, } ) await session.commit() except Exception as e: logger.warning("Failed to save training job to DB: %s", e) async def _load_jobs_from_db(self): """Load previously saved jobs from database on startup.""" try: from sqlalchemy import select async with catalog_session_factory() as session: result = await session.execute( select(TrainingJobDB).order_by(TrainingJobDB.created_at.desc()).limit(20) ) db_jobs = result.scalars().all() for db_job in db_jobs: if db_job.id not in self._jobs: job = CloudTrainingJob( id=db_job.id, name=db_job.name, status=db_job.status, progress=db_job.progress or 0.0, current_epoch=db_job.current_epoch or 0, total_epochs=db_job.total_epochs or 0, current_step=db_job.current_step or 0, total_steps=db_job.total_steps or 0, loss=db_job.loss, started_at=db_job.started_at, completed_at=db_job.completed_at, output_path=db_job.output_path, error=db_job.error, log_lines=(db_job.log_text or "").split("\n") if db_job.log_text else [], pod_id=db_job.pod_id, gpu_type=db_job.gpu_type or DEFAULT_GPU, base_model=db_job.base_model or "sd15_realistic", model_type=db_job.model_type or "sd15", ) self._jobs[db_job.id] = job # Try to reconnect to running training pods if job.status not in ("completed", "failed") and job.pod_id: try: pod = await asyncio.to_thread(runpod.get_pod, job.pod_id) if pod and pod.get("desiredStatus") == "RUNNING": job.status = "training" job.error = None job._log("Reconnecting to running training pod after restart...") asyncio.create_task(self._reconnect_training(job)) logger.info("Reconnecting to training pod %s for job %s", job.pod_id, job.id) else: job.status = "failed" job.error = "Pod terminated during server restart" except Exception as e: logger.warning("Could not check pod %s: %s", job.pod_id, e) job.status = "failed" job.error = "Interrupted by server restart" elif job.status not in ("completed", "failed"): job.status = "failed" job.error = "Interrupted by server restart" except Exception as e: logger.warning("Failed to load training jobs from DB: %s", e) async def ensure_loaded(self): """Load jobs from DB on first access.""" if not self._loaded_from_db: self._loaded_from_db = True await self._load_jobs_from_db() async def _reconnect_training(self, job: CloudTrainingJob): """Reconnect to a training pod after server restart and resume log monitoring.""" import paramiko ssh = None try: # Get SSH info from RunPod pod = await asyncio.to_thread(runpod.get_pod, job.pod_id) if not pod: raise RuntimeError("Pod not found") runtime = pod.get("runtime") or {} ports = runtime.get("ports") or [] ssh_host = ssh_port = None for p in ports: if p.get("privatePort") == 22: ssh_host = p.get("ip") ssh_port = p.get("publicPort") if not ssh_host or not ssh_port: raise RuntimeError("SSH port not available") # Connect SSH ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) await asyncio.to_thread( ssh.connect, ssh_host, port=int(ssh_port), username="root", password="runpod", timeout=10, ) transport = ssh.get_transport() transport.set_keepalive(30) job._log(f"Reconnected to pod {job.pod_id}") # Check if training is still running log_file = "/tmp/training.log" exit_file = "/tmp/training.exit" pid_file = "/tmp/training.pid" exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() if exit_check: # Training already finished while we were disconnected exit_code = int(exit_check) log_tail = await self._ssh_exec(ssh, f"tail -50 {log_file} 2>/dev/null") for line in log_tail.split("\n"): line = line.strip() if line: job._log(line) self._parse_progress(job, line) if exit_code == 0: job._log("Training completed while disconnected!") # Copy LoRA to volume name = job.name await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras") remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip() if remote_files: remote_output = remote_files.split("\n")[-1].strip() await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors") job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors") job.output_path = f"/runpod-volume/loras/{name}.safetensors" job.status = "completed" job.progress = 1.0 job.completed_at = time.time() else: raise RuntimeError(f"Training failed with exit code {exit_code}") else: # Training still running — resume log monitoring pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip() job._log(f"Training still running (PID: {pid}), resuming monitoring...") last_offset = 0 while True: exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() if exit_check: exit_code = int(exit_check) remaining = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30) if remaining: for line in remaining.split("\n"): line = line.strip() if line: job._log(line) self._parse_progress(job, line) if exit_code == 0: # Copy LoRA to volume name = job.name await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras") remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip() if remote_files: remote_output = remote_files.split("\n")[-1].strip() await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors") job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors") job.output_path = f"/runpod-volume/loras/{name}.safetensors" job.status = "completed" job.progress = 1.0 job.completed_at = time.time() break else: raise RuntimeError(f"Training failed with exit code {exit_code}") try: new_output = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30) if new_output: last_offset += len(new_output.encode("utf-8")) for line in new_output.replace("\r", "\n").split("\n"): line = line.strip() if line: job._log(line) self._parse_progress(job, line) self._schedule_db_save(job) except Exception: pass await asyncio.sleep(5) job._log("Training complete!") except Exception as e: job.status = "failed" job.error = str(e) job._log(f"Reconnect failed: {e}") logger.error("Training reconnect failed for %s: %s", job.id, e) finally: if ssh: try: ssh.close() except Exception: pass # Terminate pod if job.pod_id: try: await asyncio.to_thread(runpod.terminate_pod, job.pod_id) job._log("Pod terminated") except Exception: pass self._schedule_db_save(job) def _build_training_command( self, *, model_type: str, model_path: str, name: str, resolution: int, num_epochs: int, max_train_steps: int | None, learning_rate: float, network_rank: int, network_alpha: int, optimizer: str, save_every_n_epochs: int, save_every_n_steps: int = 500, model_cfg: dict, gpu_type: str = "", ) -> str: """Build the training command based on model type.""" # Common parameters base_args = f""" --train_data_dir="/workspace/dataset" \ --output_dir="/workspace/output" \ --output_name="{name}" \ --resolution={resolution} \ --train_batch_size=1 \ --learning_rate={learning_rate} \ --network_module=networks.lora \ --network_dim={network_rank} \ --network_alpha={network_alpha} \ --optimizer_type={optimizer} \ --save_every_n_epochs={save_every_n_epochs} \ --mixed_precision=fp16 \ --seed=42 \ --caption_extension=.txt \ --cache_latents \ --enable_bucket \ --save_model_as=safetensors""" # Steps vs epochs if max_train_steps: base_args += f" \\\n --max_train_steps={max_train_steps}" else: base_args += f" \\\n --max_train_epochs={num_epochs}" # LR scheduler lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts") base_args += f" \\\n --lr_scheduler={lr_scheduler}" if model_type == "flux2": # FLUX.2 training via musubi-tuner flux2_dir = "/workspace/models/FLUX.2-dev" dit_path = f"{flux2_dir}/flux2-dev.safetensors" vae_path = f"{flux2_dir}/ae.safetensors" te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" network_mod = model_cfg.get("network_module", "networks.lora_flux_2") ts_sampling = model_cfg.get("timestep_sampling", "flux2_shift") lr_scheduler = model_cfg.get("lr_scheduler", "cosine") # Build as list of args to avoid shell escaping issues args = [ "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True", "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16", "src/musubi_tuner/flux_2_train_network.py", "--model_version dev", f"--dit {dit_path}", f"--vae {vae_path}", f"--text_encoder {te_path}", "--dataset_config /workspace/dataset.toml", "--sdpa --mixed_precision bf16", f"--timestep_sampling {ts_sampling} --weighting_scheme none", f"--network_module {network_mod}", f"--network_dim={network_rank}", f"--network_alpha={network_alpha}", "--gradient_checkpointing", ] # Only use fp8_base on GPUs with native fp8 support (RTX 4090, H100) # A100 and A6000 don't support fp8 tensor ops, and have enough VRAM without it if gpu_type and ("4090" in gpu_type or "5090" in gpu_type or "L40S" in gpu_type or "H100" in gpu_type): args.append("--fp8_base") # Handle Prodigy optimizer (needs special class path and args) if optimizer.lower() == "prodigy": args.extend([ "--optimizer_type=prodigyopt.Prodigy", f"--learning_rate={learning_rate}", '--optimizer_args "weight_decay=0.01" "decouple=True" "use_bias_correction=True" "safeguard_warmup=True" "d_coef=2"', ]) else: args.extend([ f"--optimizer_type={optimizer}", f"--learning_rate={learning_rate}", ]) args.extend([ "--seed=42", '--output_dir=/workspace/output', f'--output_name={name}', f"--lr_scheduler={lr_scheduler}", ]) if max_train_steps: args.append(f"--max_train_steps={max_train_steps}") if save_every_n_steps: args.append(f"--save_every_n_steps={save_every_n_steps}") else: args.append(f"--save_every_n_epochs={save_every_n_epochs}") else: args.append(f"--max_train_epochs={num_epochs}") args.append(f"--save_every_n_epochs={save_every_n_epochs}") return " ".join(args) + " 2>&1" elif model_type == "wan22": # WAN 2.2 T2V LoRA training via musubi-tuner wan_dir = "/workspace/models/WAN2.2" dit_low = f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors" dit_high = f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors" network_mod = model_cfg.get("network_module", "networks.lora_wan") ts_sampling = model_cfg.get("timestep_sampling", "shift") discrete_shift = model_cfg.get("discrete_flow_shift", 5.0) args = [ "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True", "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision fp16", "src/musubi_tuner/wan_train_network.py", "--task t2v-A14B", f"--dit {dit_low}", f"--dit_high_noise {dit_high}", "--dataset_config /workspace/dataset.toml", "--sdpa --mixed_precision fp16", "--gradient_checkpointing", f"--timestep_sampling {ts_sampling}", f"--discrete_flow_shift {discrete_shift}", f"--network_module {network_mod}", f"--network_dim={network_rank}", f"--network_alpha={network_alpha}", f"--optimizer_type={optimizer}", f"--learning_rate={learning_rate}", "--seed=42", "--output_dir=/workspace/output", f"--output_name={name}", ] if max_train_steps: args.append(f"--max_train_steps={max_train_steps}") if save_every_n_steps: args.append(f"--save_every_n_steps={save_every_n_steps}") else: args.append(f"--max_train_epochs={num_epochs}") args.append(f"--save_every_n_epochs={save_every_n_epochs}") return " ".join(args) + " 2>&1" elif model_type == "flux": # FLUX.1 training via sd-scripts script = "flux_train_network.py" flux_args = f""" --pretrained_model_name_or_path="{model_path}" \ --clip_l="/workspace/models/clip_l.safetensors" \ --t5xxl="/workspace/models/t5xxl_fp16.safetensors" \ --ae="/workspace/models/ae.safetensors" \ --cache_text_encoder_outputs \ --cache_text_encoder_outputs_to_disk \ --fp8_base \ --split_mode""" # Text encoder learning rate text_encoder_lr = model_cfg.get("text_encoder_lr", 4e-5) flux_args += f" \\\n --text_encoder_lr={text_encoder_lr}" # Min SNR gamma if specified min_snr = model_cfg.get("min_snr_gamma") if min_snr: flux_args += f" \\\n --min_snr_gamma={min_snr}" return f"cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} {flux_args} {base_args} 2>&1" elif model_type == "sdxl": # SDXL-specific training script = "sdxl_train_network.py" clip_skip = model_cfg.get("clip_skip", 2) return f"""cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} \ --pretrained_model_name_or_path="{model_path}" \ --clip_skip={clip_skip} \ --xformers {base_args} 2>&1""" else: # SD 1.5 / default training script = "train_network.py" clip_skip = model_cfg.get("clip_skip", 1) return f"""cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} \ --pretrained_model_name_or_path="{model_path}" \ --clip_skip={clip_skip} \ --xformers {base_args} 2>&1""" async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 600) -> tuple[str, int]: """Wait for pod to be running and return (ssh_host, ssh_port).""" start = time.time() while time.time() - start < timeout: try: pod = await asyncio.to_thread(runpod.get_pod, job.pod_id) except Exception as e: job._log(f" API error: {e}") await asyncio.sleep(10) continue status = pod.get("desiredStatus", "") runtime = pod.get("runtime") if status == "RUNNING" and runtime: ports = runtime.get("ports") or [] for port_info in (ports or []): if port_info.get("privatePort") == 22: ip = port_info.get("ip") public_port = port_info.get("publicPort") if ip and public_port: return ip, int(public_port) elapsed = int(time.time() - start) if elapsed % 30 < 6: job._log(f" Status: {status} | runtime: {'ports pending' if runtime else 'not ready yet'} ({elapsed}s)") await asyncio.sleep(5) raise RuntimeError(f"Pod did not become ready within {timeout}s") def _ssh_exec_sync(self, ssh, cmd: str, timeout: int = 120) -> str: """Execute a command over SSH and return stdout (blocking).""" _, stdout, stderr = ssh.exec_command(cmd, timeout=timeout) out = stdout.read().decode("utf-8", errors="replace") err = stderr.read().decode("utf-8", errors="replace") exit_code = stdout.channel.recv_exit_status() if exit_code != 0 and "warning" not in err.lower(): logger.warning("SSH cmd failed (code %d): %s\nstderr: %s", exit_code, cmd[:100], err[:500]) return out.strip() async def _ssh_exec(self, ssh, cmd: str, timeout: int = 120) -> str: """Execute a command over SSH without blocking the event loop.""" return await asyncio.to_thread(self._ssh_exec_sync, ssh, cmd, timeout) def _parse_progress(self, job: CloudTrainingJob, line: str): """Parse Kohya training output for progress info.""" lower = line.lower() if "epoch" in lower and "/" in line: try: parts = lower.split("epoch") if len(parts) > 1: ep_part = parts[1].strip().split()[0] if "/" in ep_part: current, total = ep_part.split("/") job.current_epoch = int(current) job.total_epochs = int(total) job.progress = 0.15 + 0.75 * (job.current_epoch / max(job.total_epochs, 1)) except (ValueError, IndexError): pass if "loss=" in line or "loss:" in line: try: loss_str = line.split("loss")[1].strip("=: ").split()[0].strip(",") job.loss = float(loss_str) except (ValueError, IndexError): pass if "steps:" in lower or "step " in lower: try: import re step_match = re.search(r"(\d+)/(\d+)", line) if step_match: job.current_step = int(step_match.group(1)) job.total_steps = int(step_match.group(2)) except (ValueError, IndexError): pass def get_job(self, job_id: str) -> CloudTrainingJob | None: return self._jobs.get(job_id) def list_jobs(self) -> list[CloudTrainingJob]: return list(self._jobs.values()) async def cancel_job(self, job_id: str) -> bool: """Cancel a cloud training job and terminate its pod.""" job = self._jobs.get(job_id) if not job: return False if job.pod_id: try: await asyncio.to_thread(runpod.terminate_pod, job.pod_id) except Exception: pass job.status = "failed" job.error = "Cancelled by user" return True async def delete_job(self, job_id: str) -> bool: """Delete a training job from memory and database.""" if job_id not in self._jobs: return False del self._jobs[job_id] try: async with catalog_session_factory() as session: result = await session.execute( __import__('sqlalchemy').select(TrainingJobDB).where(TrainingJobDB.id == job_id) ) db_job = result.scalar_one_or_none() if db_job: await session.delete(db_job) await session.commit() except Exception as e: logger.warning("Failed to delete job from DB: %s", e) return True async def delete_failed_jobs(self) -> int: """Delete all failed/error training jobs.""" failed_ids = [jid for jid, j in self._jobs.items() if j.status in ("failed", "error")] for jid in failed_ids: await self.delete_job(jid) return len(failed_ids)