Spaces:
Running
Running
| """RunPod cloud LoRA training — offload training to RunPod GPU pods. | |
| Creates a temporary GPU pod, uploads training images, runs Kohya sd-scripts, | |
| downloads the finished LoRA, then terminates the pod. No local GPU needed. | |
| Supports multiple base models (FLUX, SD 1.5, SDXL) via model registry. | |
| Usage: | |
| Set RUNPOD_API_KEY in .env | |
| Select "Cloud (RunPod)" in the training UI | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
| import runpod | |
| import yaml | |
import os

from content_engine.config import settings, IS_HF_SPACES
from content_engine.models.database import catalog_session_factory, TrainingJob as TrainingJobDB

logger = logging.getLogger(__name__)

# Finished LoRAs are downloaded here when not running on HF Spaces.
LORA_OUTPUT_DIR = settings.paths.lora_dir

# Directory containing models.yaml (the training model registry).
# CONTENT_ENGINE_CONFIG_DIR, when set, overrides the built-in locations so the
# module can run outside HF Spaces without editing this file. When it is unset,
# behaviour is unchanged.
_config_dir_override = os.environ.get("CONTENT_ENGINE_CONFIG_DIR", "")
if _config_dir_override:
    CONFIG_DIR = Path(_config_dir_override)
elif IS_HF_SPACES:
    CONFIG_DIR = Path("/app/config")
else:
    # NOTE(review): machine-specific fallback kept for backward compatibility.
    CONFIG_DIR = Path("D:/AI automation/content_engine/config")
# RunPod GPU options: RunPod GPU type ID (passed to runpod.create_pod) -> UI
# label with approximate hourly price. Prices are indicative and may be stale.
GPU_OPTIONS = {
    # 16-32GB cards - SD 1.5 / SDXL / FLUX.1 class workloads only (NOT enough
    # for FLUX.2). Check each model's vram_required_gb in models.yaml; 16GB may
    # be too tight for some models.
    "NVIDIA GeForce RTX 3090": "RTX 3090 24GB (~$0.22/hr)",
    "NVIDIA GeForce RTX 4090": "RTX 4090 24GB (~$0.44/hr)",
    "NVIDIA GeForce RTX 5090": "RTX 5090 32GB (~$0.69/hr)",
    "NVIDIA RTX A4000": "RTX A4000 16GB (~$0.20/hr)",
    "NVIDIA RTX A5000": "RTX A5000 24GB (~$0.28/hr)",
    # 48GB+ - Required for FLUX.2 Dev (Mistral text encoder needs ~48GB)
    "NVIDIA RTX A6000": "RTX A6000 48GB (~$0.76/hr)",
    "NVIDIA A40": "A40 48GB (~$0.64/hr)",
    "NVIDIA L40": "L40 48GB (~$0.89/hr)",
    "NVIDIA L40S": "L40S 48GB (~$1.09/hr)",
    "NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
    "NVIDIA A100-SXM4-80GB": "A100 SXM 80GB (~$1.64/hr)",
    "NVIDIA H100 80GB HBM3": "H100 80GB (~$3.89/hr)",
}
# GPU type used when the caller does not choose one.
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
# Optional RunPod network volume. When configured, base models, the
# musubi-tuner checkout and finished LoRAs persist between pods instead of
# being fetched again on every run. Configure in .env:
#   RUNPOD_VOLUME_ID - ID of the network volume to attach
#   RUNPOD_VOLUME_DC - datacenter ID where that volume lives (e.g. "EU-RO-1")
# An empty string means "not configured".
NETWORK_VOLUME_ID = os.getenv("RUNPOD_VOLUME_ID", "")
NETWORK_VOLUME_DC = os.getenv("RUNPOD_VOLUME_DC", "")

# Container image for training pods; ships with PyTorch and the CUDA toolkit.
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
def load_model_registry() -> dict[str, dict]:
    """Load training model configurations from ``CONFIG_DIR / "models.yaml"``.

    Returns:
        The mapping stored under the top-level ``training_models`` key
        (model key -> config dict). Returns an empty dict when the file is
        missing or empty, when its top level is not a mapping, or when
        ``training_models`` is absent, null or not a mapping. Every case
        except a plain empty section is logged as a warning.
    """
    models_file = CONFIG_DIR / "models.yaml"
    if not models_file.exists():
        logger.warning("Model registry not found: %s", models_file)
        return {}
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII model descriptions.
    with open(models_file, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document; treat that (and any
    # non-mapping top level) as "no models" instead of crashing on .get().
    if not isinstance(config, dict):
        if config is not None:
            logger.warning("Model registry %s is not a mapping; ignoring it", models_file)
        return {}
    models = config.get("training_models") or {}
    if not isinstance(models, dict):
        logger.warning("'training_models' in %s is not a mapping; ignoring it", models_file)
        return {}
    return models
@dataclass
class CloudTrainingJob:
    """Tracks state of a RunPod cloud training job.

    This must be a dataclass. Jobs are constructed with keyword arguments, and
    ``field(default_factory=list)`` only gives each job its own ``log_lines``
    list under ``@dataclass``. Without the decorator the class takes no
    constructor arguments at all.
    """

    id: str  # short job identifier
    name: str  # LoRA name; also used as the output file stem
    status: str = "pending"  # pending, creating_pod, uploading, installing, training, downloading, completed, failed
    progress: float = 0.0  # fraction of the pipeline done, 0.0 .. 1.0
    current_epoch: int = 0
    total_epochs: int = 0
    current_step: int = 0
    total_steps: int = 0
    loss: float | None = None
    started_at: float | None = None  # time.time() when the job was created
    completed_at: float | None = None  # time.time() when the job completed
    output_path: str | None = None  # where the finished LoRA was saved
    error: str | None = None  # message of the exception that failed the job
    log_lines: list[str] = field(default_factory=list)  # most recent 200 log messages
    pod_id: str | None = None  # RunPod pod ID once the pod exists
    gpu_type: str = DEFAULT_GPU  # RunPod GPU type ID
    cost_estimate: str | None = None
    base_model: str = "sd15_realistic"  # model registry key
    model_type: str = "sd15"
    # Called with the job on every _log() so state can be persisted to the DB.
    # Kept out of repr/eq: it is a hook, not job state.
    _db_callback: Any = field(default=None, repr=False, compare=False)

    def _log(self, msg: str):
        """Record *msg* in the job log, emit it via the module logger and notify the DB hook.

        Only the last 200 messages are retained in ``log_lines``.
        """
        self.log_lines.append(msg)
        if len(self.log_lines) > 200:
            # Trim in place so the list object stays the same.
            del self.log_lines[:-200]
        logger.info("[%s] %s", self.id, msg)
        if self._db_callback:
            self._db_callback(self)
| class RunPodTrainer: | |
| """Manages LoRA training on RunPod cloud GPUs.""" | |
| def __init__(self, api_key: str): | |
| self._api_key = api_key | |
| runpod.api_key = api_key | |
| self._jobs: dict[str, CloudTrainingJob] = {} | |
| self._model_registry = load_model_registry() | |
| self._loaded_from_db = False | |
| def available(self) -> bool: | |
| """Check if RunPod is configured.""" | |
| # Re-set module-level key in case uvicorn reload cleared it | |
| if self._api_key: | |
| runpod.api_key = self._api_key | |
| return bool(self._api_key) | |
| def list_gpu_options(self) -> dict[str, str]: | |
| return GPU_OPTIONS | |
| def list_training_models(self) -> dict[str, dict]: | |
| """List available base models for training with their parameters.""" | |
| return { | |
| key: { | |
| "name": cfg.get("name", key), | |
| "description": cfg.get("description", ""), | |
| "model_type": cfg.get("model_type", "sd15"), | |
| "resolution": cfg.get("resolution", 512), | |
| "learning_rate": cfg.get("learning_rate", 1e-4), | |
| "network_rank": cfg.get("network_rank", 32), | |
| "network_alpha": cfg.get("network_alpha", 16), | |
| "optimizer": cfg.get("optimizer", "AdamW8bit"), | |
| "lr_scheduler": cfg.get("lr_scheduler", "cosine"), | |
| "vram_required_gb": cfg.get("vram_required_gb", 8), | |
| "recommended_images": cfg.get("recommended_images", "15-30 photos"), | |
| } | |
| for key, cfg in self._model_registry.items() | |
| } | |
| def get_model_config(self, model_key: str) -> dict | None: | |
| """Get configuration for a specific training model.""" | |
| return self._model_registry.get(model_key) | |
| async def start_training( | |
| self, | |
| *, | |
| name: str, | |
| image_paths: list[str], | |
| trigger_word: str = "", | |
| base_model: str = "sd15_realistic", | |
| resolution: int | None = None, | |
| num_epochs: int = 10, | |
| max_train_steps: int | None = None, | |
| learning_rate: float | None = None, | |
| network_rank: int | None = None, | |
| network_alpha: int | None = None, | |
| optimizer: str | None = None, | |
| save_every_n_epochs: int = 2, | |
| save_every_n_steps: int = 500, | |
| gpu_type: str = DEFAULT_GPU, | |
| ) -> str: | |
| """Start a cloud training job. Returns job ID. | |
| Parameters use model registry defaults if not specified. | |
| """ | |
| job_id = str(uuid.uuid4())[:8] | |
| # Get model config (fall back to sd15_realistic if not found) | |
| model_cfg = self._model_registry.get(base_model, self._model_registry.get("sd15_realistic", {})) | |
| model_type = model_cfg.get("model_type", "sd15") | |
| # Use provided values or model defaults | |
| final_resolution = resolution or model_cfg.get("resolution", 512) | |
| final_lr = learning_rate or model_cfg.get("learning_rate", 1e-4) | |
| final_rank = network_rank or model_cfg.get("network_rank", 32) | |
| final_alpha = network_alpha or model_cfg.get("network_alpha", 16) | |
| final_optimizer = optimizer or model_cfg.get("optimizer", "AdamW8bit") | |
| final_steps = max_train_steps or model_cfg.get("max_train_steps") | |
| job = CloudTrainingJob( | |
| id=job_id, | |
| name=name, | |
| status="pending", | |
| total_epochs=num_epochs, | |
| total_steps=final_steps, | |
| gpu_type=gpu_type, | |
| started_at=time.time(), | |
| base_model=base_model, | |
| model_type=model_type, | |
| ) | |
| self._jobs[job_id] = job | |
| job._db_callback = self._schedule_db_save | |
| asyncio.ensure_future(self._save_to_db(job)) | |
| # Launch the full pipeline as a background task | |
| asyncio.create_task(self._run_cloud_training( | |
| job=job, | |
| image_paths=image_paths, | |
| trigger_word=trigger_word, | |
| model_cfg=model_cfg, | |
| resolution=final_resolution, | |
| num_epochs=num_epochs, | |
| max_train_steps=final_steps, | |
| learning_rate=final_lr, | |
| network_rank=final_rank, | |
| network_alpha=final_alpha, | |
| optimizer=final_optimizer, | |
| save_every_n_epochs=save_every_n_epochs, | |
| save_every_n_steps=save_every_n_steps, | |
| )) | |
| return job_id | |
| async def _run_cloud_training( | |
| self, | |
| job: CloudTrainingJob, | |
| image_paths: list[str], | |
| trigger_word: str, | |
| model_cfg: dict, | |
| resolution: int, | |
| num_epochs: int, | |
| max_train_steps: int | None, | |
| learning_rate: float, | |
| network_rank: int, | |
| network_alpha: int, | |
| optimizer: str, | |
| save_every_n_epochs: int, | |
| save_every_n_steps: int = 500, | |
| ): | |
| """Full cloud training pipeline: create pod -> upload -> train -> download -> cleanup.""" | |
| ssh = None | |
| sftp = None | |
| model_type = model_cfg.get("model_type", "sd15") | |
| name = job.name | |
| try: | |
| # Step 1: Create pod | |
| job.status = "creating_pod" | |
| job._log(f"Creating RunPod with {job.gpu_type}...") | |
| # Use network volume if configured (persists models across runs) | |
| pod_kwargs = { | |
| "container_disk_in_gb": 30, | |
| "ports": "22/tcp", | |
| "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'", | |
| } | |
| if NETWORK_VOLUME_ID: | |
| pod_kwargs["network_volume_id"] = NETWORK_VOLUME_ID | |
| if NETWORK_VOLUME_DC: | |
| pod_kwargs["data_center_id"] = NETWORK_VOLUME_DC | |
| job._log(f"Using persistent network volume: {NETWORK_VOLUME_ID} (DC: {NETWORK_VOLUME_DC or 'auto'})") | |
| else: | |
| pod_kwargs["volume_in_gb"] = 75 | |
| pod = await asyncio.to_thread( | |
| runpod.create_pod, | |
| f"lora-train-{job.id}", | |
| DOCKER_IMAGE, | |
| job.gpu_type, | |
| **pod_kwargs, | |
| ) | |
| job.pod_id = pod["id"] | |
| job._log(f"Pod created: {job.pod_id}") | |
| # Wait for pod to be ready and get SSH info | |
| job._log("Waiting for pod to start...") | |
| ssh_host, ssh_port = await self._wait_for_pod_ready(job) | |
| job._log(f"Pod ready at {ssh_host}:{ssh_port}") | |
| # Step 2: Connect via SSH | |
| import paramiko | |
| ssh = paramiko.SSHClient() | |
| ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) | |
| for attempt in range(30): | |
| try: | |
| await asyncio.to_thread( | |
| ssh.connect, | |
| ssh_host, port=ssh_port, | |
| username="root", password="runpod", | |
| timeout=10, | |
| ) | |
| break | |
| except Exception: | |
| if attempt == 29: | |
| raise RuntimeError("Could not SSH into pod after 30 attempts") | |
| await asyncio.sleep(5) | |
| job._log("SSH connected") | |
| # If using network volume, symlink to /workspace so all paths work | |
| if NETWORK_VOLUME_ID: | |
| await self._ssh_exec(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models") | |
| job._log("Network volume symlinked to /workspace") | |
| # Enable keepalive to prevent SSH timeout during uploads | |
| transport = ssh.get_transport() | |
| transport.set_keepalive(30) | |
| sftp = ssh.open_sftp() | |
| sftp.get_channel().settimeout(300) # 5 min timeout per file | |
| # Step 3: Upload training images (compress first to speed up transfer) | |
| job.status = "uploading" | |
| resolution = model_cfg.get("resolution", 1024) | |
| job._log(f"Compressing and uploading {len(image_paths)} training images...") | |
| import tempfile | |
| from PIL import Image | |
| tmp_dir = Path(tempfile.mkdtemp(prefix="lora_upload_")) | |
| folder_name = f"10_{trigger_word or 'character'}" | |
| await self._ssh_exec(ssh, f"mkdir -p /workspace/dataset/{folder_name}") | |
| for i, img_path in enumerate(image_paths): | |
| p = Path(img_path) | |
| if p.exists(): | |
| # Resize and convert to JPEG to reduce upload size | |
| try: | |
| img = Image.open(p) | |
| img.thumbnail((resolution * 2, resolution * 2), Image.LANCZOS) | |
| compressed = tmp_dir / f"{p.stem}.jpg" | |
| img.save(compressed, "JPEG", quality=95) | |
| upload_path = compressed | |
| except Exception: | |
| upload_path = p # fallback to original | |
| remote_name = f"{p.stem}.jpg" if upload_path.suffix == ".jpg" else p.name | |
| remote_path = f"/workspace/dataset/{folder_name}/{remote_name}" | |
| for attempt in range(3): | |
| try: | |
| await asyncio.to_thread(sftp.put, str(upload_path), remote_path) | |
| break | |
| except (EOFError, OSError): | |
| if attempt == 2: | |
| raise | |
| job._log(f"Upload retry {attempt+1} for {p.name}") | |
| sftp.close() | |
| sftp = ssh.open_sftp() | |
| sftp.get_channel().settimeout(300) | |
| # Upload matching caption .txt file if it exists locally | |
| local_caption = p.with_suffix(".txt") | |
| if local_caption.exists(): | |
| remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt" | |
| await asyncio.to_thread(sftp.put, str(local_caption), remote_caption) | |
| else: | |
| # Fallback: create caption from trigger word | |
| remote_caption = f"/workspace/dataset/{folder_name}/{p.stem}.txt" | |
| def _write_caption(): | |
| with sftp.open(remote_caption, "w") as f: | |
| f.write(trigger_word or "") | |
| await asyncio.to_thread(_write_caption) | |
| job._log(f"Uploaded {i+1}/{len(image_paths)}: {p.name}") | |
| # Cleanup temp compressed images | |
| import shutil | |
| shutil.rmtree(tmp_dir, ignore_errors=True) | |
| job._log("Images uploaded") | |
| # Step 4: Install training framework on the pod (skip if cached on volume) | |
| job.status = "installing" | |
| job.progress = 0.05 | |
| training_framework = model_cfg.get("training_framework", "sd-scripts") | |
| if training_framework == "musubi-tuner": | |
| # FLUX.2 uses musubi-tuner (Kohya's newer framework) | |
| tuner_dir = "/workspace/musubi-tuner" | |
| install_cmds = [] | |
| # Check if already present in workspace | |
| tuner_exist = (await self._ssh_exec(ssh, f"test -f {tuner_dir}/pyproject.toml && echo EXISTS || echo MISSING")).strip() | |
| if tuner_exist == "EXISTS": | |
| job._log("musubi-tuner found in workspace") | |
| else: | |
| # Check volume cache | |
| vol_exist = (await self._ssh_exec(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip() | |
| if vol_exist == "EXISTS": | |
| job._log("Restoring musubi-tuner from volume cache...") | |
| await self._ssh_exec(ssh, f"rm -rf {tuner_dir} 2>/dev/null; cp -r /runpod-volume/musubi-tuner {tuner_dir}") | |
| else: | |
| job._log("Cloning musubi-tuner from GitHub...") | |
| await self._ssh_exec(ssh, f"rm -rf {tuner_dir} /runpod-volume/musubi-tuner 2>/dev/null; true") | |
| install_cmds.append(f"cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git") | |
| # Save to volume for future pods | |
| if NETWORK_VOLUME_ID: | |
| install_cmds.append(f"cp -r {tuner_dir} /runpod-volume/musubi-tuner") | |
| # Always install pip deps (they are pod-local, lost on every new pod) | |
| job._log("Installing pip dependencies (accelerate, torch, etc.)...") | |
| install_cmds.extend([ | |
| f"cd {tuner_dir} && pip install -e . 2>&1 | tail -5", | |
| "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes 2>&1 | tail -5", | |
| ]) | |
| else: | |
| # SD 1.5 / SDXL / FLUX.1 use sd-scripts | |
| scripts_exist = (await self._ssh_exec(ssh, "test -f /workspace/sd-scripts/setup.py && echo EXISTS || echo MISSING")).strip() | |
| if scripts_exist == "EXISTS": | |
| job._log("Kohya sd-scripts already cached on volume, updating...") | |
| install_cmds = [ | |
| "cd /workspace/sd-scripts && git pull 2>&1 | tail -1", | |
| ] | |
| else: | |
| job._log("Installing Kohya sd-scripts (this takes a few minutes)...") | |
| install_cmds = [ | |
| "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git", | |
| ] | |
| # Always install pip deps (pod-local, lost on new pods) | |
| install_cmds.extend([ | |
| "cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1", | |
| "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes xformers 2>&1 | tail -1", | |
| ]) | |
| for cmd in install_cmds: | |
| out = await self._ssh_exec(ssh, cmd, timeout=600) | |
| job._log(out[:200] if out else "done") | |
| # Download base model from HuggingFace (skip if already on network volume) | |
| hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE") | |
| hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors") | |
| model_name = model_cfg.get("name", job.base_model) | |
| job.progress = 0.1 | |
| await self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120) | |
| if model_type == "flux2": | |
| # FLUX.2 models are stored in a directory structure on the volume | |
| flux2_dir = "/workspace/models/FLUX.2-dev" | |
| dit_path = f"{flux2_dir}/flux2-dev.safetensors" | |
| vae_path = f"{flux2_dir}/ae.safetensors" # Original BFL format (not diffusers) | |
| te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" | |
| dit_exists = (await self._ssh_exec(ssh, f"test -f {dit_path} && echo EXISTS || echo MISSING")).strip() | |
| vae_exists = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip() | |
| te_exists = (await self._ssh_exec(ssh, f"test -f {te_path} && echo EXISTS || echo MISSING")).strip() | |
| if dit_exists != "EXISTS" or te_exists != "EXISTS": | |
| missing = [] | |
| if dit_exists != "EXISTS": | |
| missing.append("DiT") | |
| if te_exists != "EXISTS": | |
| missing.append("text encoder") | |
| raise RuntimeError(f"FLUX.2 Dev missing on volume: {', '.join(missing)}. Please download models to the network volume first.") | |
| # Download ae.safetensors (original format VAE) if not present | |
| if vae_exists != "EXISTS": | |
| job._log("Downloading FLUX.2 VAE (ae.safetensors, 336MB)...") | |
| await self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120) | |
| await self._ssh_exec(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download('black-forest-labs/FLUX.2-dev', 'ae.safetensors', local_dir='{flux2_dir}') | |
| print('Downloaded ae.safetensors') | |
| " 2>&1 | tail -5""", timeout=600) | |
| # Verify download | |
| vae_check = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip() | |
| if vae_check != "EXISTS": | |
| raise RuntimeError("Failed to download ae.safetensors") | |
| job._log("VAE downloaded") | |
| job._log("FLUX.2 Dev models ready") | |
| elif model_type == "wan22": | |
| # WAN 2.2 T2V — 4 model files stored in /workspace/models/WAN2.2/ | |
| wan_dir = "/workspace/models/WAN2.2" | |
| await self._ssh_exec(ssh, f"mkdir -p {wan_dir}") | |
| wan_files = { | |
| "DiT low-noise": { | |
| "path": f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors", | |
| "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", | |
| "filename": "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors", | |
| }, | |
| "DiT high-noise": { | |
| "path": f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors", | |
| "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", | |
| "filename": "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors", | |
| }, | |
| "VAE": { | |
| "path": f"{wan_dir}/Wan2.1_VAE.pth", | |
| "repo": "Wan-AI/Wan2.1-I2V-14B-720P", | |
| "filename": "Wan2.1_VAE.pth", | |
| }, | |
| "T5 text encoder": { | |
| "path": f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth", | |
| "repo": "Wan-AI/Wan2.1-I2V-14B-720P", | |
| "filename": "models_t5_umt5-xxl-enc-bf16.pth", | |
| }, | |
| } | |
| for label, info in wan_files.items(): | |
| exists = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() | |
| if exists == "EXISTS": | |
| job._log(f"WAN 2.2 {label} already cached") | |
| else: | |
| job._log(f"Downloading WAN 2.2 {label}...") | |
| await self._ssh_exec(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}') | |
| # hf_hub_download puts files in subdirs matching the filename path — move to root | |
| import os, shutil | |
| downloaded = os.path.join('{wan_dir}', '{info['filename']}') | |
| target = '{info['path']}' | |
| if os.path.exists(downloaded) and downloaded != target: | |
| shutil.move(downloaded, target) | |
| print('Downloaded {label}') | |
| " 2>&1 | tail -5""", timeout=1800) | |
| # Verify | |
| check = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() | |
| if check != "EXISTS": | |
| raise RuntimeError(f"Failed to download WAN 2.2 {label}") | |
| job._log("WAN 2.2 models ready") | |
| else: | |
| # SD 1.5 / SDXL / FLUX.1 — download single model file | |
| model_exists = (await self._ssh_exec(ssh, f"test -f /workspace/models/{hf_filename} && echo EXISTS || echo MISSING")).strip() | |
| if model_exists == "EXISTS": | |
| job._log(f"Base model already cached on volume: {model_name}") | |
| else: | |
| job._log(f"Downloading base model: {model_name}...") | |
| await self._ssh_exec(ssh, f""" | |
| python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models') | |
| " 2>&1 | tail -5 | |
| """, timeout=1200) | |
| # For FLUX.1, download additional required models (CLIP, T5, VAE) | |
| if model_type == "flux": | |
| flux_files_check = (await self._ssh_exec(ssh, "test -f /workspace/models/clip_l.safetensors && test -f /workspace/models/t5xxl_fp16.safetensors && test -f /workspace/models/ae.safetensors && echo EXISTS || echo MISSING")).strip() | |
| if flux_files_check == "EXISTS": | |
| job._log("FLUX.1 auxiliary models already cached on volume") | |
| else: | |
| job._log("Downloading FLUX.1 auxiliary models (CLIP, T5, VAE)...") | |
| job.progress = 0.12 | |
| await self._ssh_exec(ssh, """ | |
| python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models') | |
| hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models') | |
| hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models') | |
| " 2>&1 | tail -5 | |
| """, timeout=1200) | |
| job._log("Base model ready") | |
| job.progress = 0.15 | |
| # Step 5: Run training | |
| job.status = "training" | |
| job._log(f"Starting {model_type.upper()} LoRA training...") | |
| if model_type == "flux2": | |
| model_path = f"/workspace/models/FLUX.2-dev/flux2-dev.safetensors" | |
| elif model_type == "wan22": | |
| model_path = "/workspace/models/WAN2.2/wan2.2_t2v_low_noise_14B_fp16.safetensors" | |
| else: | |
| model_path = f"/workspace/models/{hf_filename}" | |
| # For musubi-tuner, create TOML dataset config | |
| if training_framework == "musubi-tuner": | |
| folder_name = f"10_{trigger_word or 'character'}" | |
| toml_content = f"""[[datasets]] | |
| image_directory = "/workspace/dataset/{folder_name}" | |
| caption_extension = ".txt" | |
| batch_size = 1 | |
| num_repeats = 10 | |
| resolution = [{resolution}, {resolution}] | |
| """ | |
| await self._ssh_exec(ssh, f"cat > /workspace/dataset.toml << 'TOMLEOF'\n{toml_content}TOMLEOF") | |
| job._log("Created dataset.toml config") | |
| # musubi-tuner requires pre-caching latents and text encoder outputs | |
| if model_type == "wan22": | |
| wan_dir = "/workspace/models/WAN2.2" | |
| vae_path = f"{wan_dir}/Wan2.1_VAE.pth" | |
| te_path = f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth" | |
| job._log("Caching WAN 2.2 latents (VAE encoding)...") | |
| job.progress = 0.15 | |
| self._schedule_db_save(job) | |
| cache_latents_cmd = ( | |
| f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" | |
| f" python src/musubi_tuner/wan_cache_latents.py" | |
| f" --dataset_config /workspace/dataset.toml" | |
| f" --vae {vae_path}" | |
| f" --vae_dtype bfloat16" | |
| f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}" | |
| ) | |
| out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600) | |
| last_lines = out.split('\n')[-30:] | |
| job._log('\n'.join(last_lines)) | |
| if "EXIT_CODE=0" not in out: | |
| err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10") | |
| job._log(f"Cache error details: {err_log}") | |
| raise RuntimeError(f"WAN latent caching failed") | |
| job._log("Caching WAN 2.2 text encoder outputs (T5)...") | |
| job.progress = 0.25 | |
| self._schedule_db_save(job) | |
| cache_te_cmd = ( | |
| f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" | |
| f" python src/musubi_tuner/wan_cache_text_encoder_outputs.py" | |
| f" --dataset_config /workspace/dataset.toml" | |
| f" --t5 {te_path}" | |
| f" --batch_size 16" | |
| f" 2>&1; echo EXIT_CODE=$?" | |
| ) | |
| out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600) | |
| job._log(out[-500:] if out else "done") | |
| if "EXIT_CODE=0" not in out: | |
| raise RuntimeError(f"WAN text encoder caching failed: {out[-200:]}") | |
| else: | |
| # FLUX.2 caching | |
| flux2_dir = "/workspace/models/FLUX.2-dev" | |
| vae_path = f"{flux2_dir}/ae.safetensors" | |
| te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" | |
| job._log("Caching latents (VAE encoding)...") | |
| job.progress = 0.15 | |
| self._schedule_db_save(job) | |
| cache_latents_cmd = ( | |
| f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py" | |
| f" --dataset_config /workspace/dataset.toml" | |
| f" --vae {vae_path}" | |
| f" --model_version dev" | |
| f" --vae_dtype bfloat16" | |
| f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}" | |
| ) | |
| out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600) | |
| last_lines = out.split('\n')[-30:] | |
| job._log('\n'.join(last_lines)) | |
| if "EXIT_CODE=0" not in out: | |
| err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10") | |
| job._log(f"Cache error details: {err_log}") | |
| raise RuntimeError(f"Latent caching failed") | |
| job._log("Caching text encoder outputs (bf16)...") | |
| job.progress = 0.25 | |
| self._schedule_db_save(job) | |
| cache_te_cmd = ( | |
| f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" | |
| f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py" | |
| f" --dataset_config /workspace/dataset.toml" | |
| f" --text_encoder {te_path}" | |
| f" --model_version dev" | |
| f" --batch_size 1" | |
| f" 2>&1; echo EXIT_CODE=$?" | |
| ) | |
| out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600) | |
| job._log(out[-500:] if out else "done") | |
| if "EXIT_CODE=0" not in out: | |
| raise RuntimeError(f"Text encoder caching failed: {out[-200:]}") | |
| # Build training command based on model type | |
| train_cmd = self._build_training_command( | |
| model_type=model_type, | |
| model_path=model_path, | |
| name=name, | |
| resolution=resolution, | |
| num_epochs=num_epochs, | |
| max_train_steps=max_train_steps, | |
| learning_rate=learning_rate, | |
| network_rank=network_rank, | |
| network_alpha=network_alpha, | |
| optimizer=optimizer, | |
| save_every_n_epochs=save_every_n_epochs, | |
| save_every_n_steps=save_every_n_steps, | |
| model_cfg=model_cfg, | |
| gpu_type=job.gpu_type, | |
| ) | |
| # Execute training in a detached process (survives SSH disconnect) | |
| job._log("Starting training (detached — survives disconnects)...") | |
| log_file = "/tmp/training.log" | |
| pid_file = "/tmp/training.pid" | |
| exit_file = "/tmp/training.exit" | |
| await self._ssh_exec(ssh, f"rm -f {log_file} {exit_file} {pid_file}") | |
| # Write training command to a script file (avoids quoting issues with nohup) | |
| script_file = "/tmp/train.sh" | |
| await self._ssh_exec(ssh, f"cat > {script_file} << 'TRAINEOF'\n#!/bin/bash\n{train_cmd} > {log_file} 2>&1\necho $? > {exit_file}\nTRAINEOF") | |
| await self._ssh_exec(ssh, f"chmod +x {script_file}") | |
| # Verify script was written | |
| script_check = (await self._ssh_exec(ssh, f"wc -l < {script_file}")).strip() | |
| job._log(f"Training script written ({script_check} lines)") | |
| # Launch fully detached: close all FDs so SSH channel doesn't hang | |
| await self._ssh_exec( | |
| ssh, | |
| f"setsid {script_file} </dev/null >/dev/null 2>&1 &\necho $! > {pid_file}", | |
| timeout=15, | |
| ) | |
| await asyncio.sleep(3) | |
| pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip() | |
| if not pid: | |
| # Fallback: find the process by script name | |
| pid = (await self._ssh_exec(ssh, "pgrep -f train.sh 2>/dev/null | head -1")).strip() | |
| job._log(f"Training PID: {pid}") | |
| # Verify process is actually running | |
| if pid: | |
| running = (await self._ssh_exec(ssh, f"kill -0 {pid} 2>&1 && echo RUNNING || echo DEAD")).strip() | |
| job._log(f"Process status: {running}") | |
| if "DEAD" in running: | |
| # Check if it already wrote an exit code (fast failure) | |
| early_exit = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() | |
| early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip() | |
| raise RuntimeError(f"Training process died immediately. Exit: {early_exit}\nLog: {early_log}") | |
| else: | |
| early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip() | |
| raise RuntimeError(f"Failed to start training process.\nLog: {early_log}") | |
| # Monitor the log file (reconnect-safe) | |
| last_offset = 0 | |
| while True: | |
| # Check if training finished | |
| exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip() | |
| if exit_check: | |
| exit_code = int(exit_check) | |
| # Read remaining log | |
| remaining = (await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)) | |
| if remaining: | |
| for line in remaining.split("\n"): | |
| line = line.strip() | |
| if line: | |
| job._log(line) | |
| self._parse_progress(job, line) | |
| if exit_code != 0: | |
| raise RuntimeError(f"Training failed with exit code {exit_code}") | |
| break | |
| # Read new log output | |
| try: | |
| new_output = (await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)) | |
| if new_output: | |
| last_offset += len(new_output.encode("utf-8")) | |
| for line in new_output.replace("\r", "\n").split("\n"): | |
| line = line.strip() | |
| if not line: | |
| continue | |
| job._log(line) | |
| self._parse_progress(job, line) | |
| self._schedule_db_save(job) | |
| except Exception: | |
| job._log("Log read failed, retrying...") | |
| await asyncio.sleep(5) | |
| job._log("Training completed on RunPod!") | |
| job.progress = 0.9 | |
| # Step 6: Save LoRA to network volume and download locally | |
| job.status = "downloading" | |
| # First, copy to network volume for persistence | |
| job._log("Saving LoRA to network volume...") | |
| await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras") | |
| remote_output = f"/workspace/output/{name}.safetensors" | |
| # Find the output file | |
| check = (await self._ssh_exec(ssh, f"test -f {remote_output} && echo EXISTS || echo MISSING")).strip() | |
| if check == "MISSING": | |
| remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip() | |
| if remote_files: | |
| remote_output = remote_files.split("\n")[-1].strip() | |
| else: | |
| raise RuntimeError("No .safetensors output found") | |
| await self._ssh_exec(ssh, f"cp {remote_output} /runpod-volume/loras/{name}.safetensors") | |
| job._log(f"LoRA saved to volume: /runpod-volume/loras/{name}.safetensors") | |
| # Also save intermediate checkpoints (step 500, 1000, 1500, etc.) | |
| checkpoint_files = (await self._ssh_exec(ssh, f"ls /workspace/output/{name}-step*.safetensors 2>/dev/null")).strip() | |
| if checkpoint_files: | |
| for ckpt in checkpoint_files.split("\n"): | |
| ckpt = ckpt.strip() | |
| if ckpt: | |
| ckpt_name = ckpt.split("/")[-1] | |
| await self._ssh_exec(ssh, f"cp {ckpt} /runpod-volume/loras/{ckpt_name}") | |
| job._log(f"Checkpoint saved: /runpod-volume/loras/{ckpt_name}") | |
| # Download locally (skip on HF Spaces — limited storage) | |
| if IS_HF_SPACES: | |
| job.output_path = f"/runpod-volume/loras/{name}.safetensors" | |
| job._log("LoRA saved on RunPod volume (ready for generation)") | |
| else: | |
| job._log("Downloading LoRA to local machine...") | |
| LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) | |
| local_path = LORA_OUTPUT_DIR / f"{name}.safetensors" | |
| await asyncio.to_thread(sftp.get, remote_output, str(local_path)) | |
| job.output_path = str(local_path) | |
| job._log(f"LoRA saved locally to {local_path}") | |
| # Done! | |
| job.status = "completed" | |
| job.progress = 1.0 | |
| job.completed_at = time.time() | |
| elapsed = (job.completed_at - job.started_at) / 60 | |
| job._log(f"Cloud training complete in {elapsed:.1f} minutes") | |
| except Exception as e: | |
| job.status = "failed" | |
| job.error = str(e) | |
| job._log(f"ERROR: {e}") | |
| logger.error("Cloud training failed: %s", e, exc_info=True) | |
| finally: | |
| # Cleanup: close SSH and terminate pod | |
| if sftp: | |
| try: | |
| sftp.close() | |
| except Exception: | |
| pass | |
| if ssh: | |
| try: | |
| ssh.close() | |
| except Exception: | |
| pass | |
| # Clean up local training images (saves HF Spaces storage) | |
| if image_paths: | |
| import shutil | |
| first_image_dir = Path(image_paths[0]).parent | |
| if first_image_dir.exists() and "training_uploads" in str(first_image_dir): | |
| shutil.rmtree(first_image_dir, ignore_errors=True) | |
| if job.pod_id: | |
| try: | |
| job._log("Terminating RunPod...") | |
| await asyncio.to_thread(runpod.terminate_pod, job.pod_id) | |
| job._log("Pod terminated") | |
| except Exception as e: | |
| job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}") | |
| def _schedule_db_save(self, job: CloudTrainingJob): | |
| """Schedule a DB save (non-blocking).""" | |
| try: | |
| asyncio.get_event_loop().create_task(self._save_to_db(job)) | |
| except RuntimeError: | |
| pass # no event loop | |
async def _save_to_db(self, job: CloudTrainingJob):
    """Write the current in-memory state of ``job`` to the ``training_jobs`` table.

    This is an SQLite upsert via raw ``INSERT OR REPLACE``. An existing row keeps its
    ``created_at`` through a sub-select; a new row gets ``CURRENT_TIMESTAMP``.
    Only the last 200 log lines are stored, and ``backend`` is always "runpod".
    Any failure is logged as a warning and swallowed, so callers may schedule this
    coroutine fire-and-forget.
    """
    try:
        from sqlalchemy import text

        row = {
            "id": job.id,
            "name": job.name,
            "status": job.status,
            "progress": job.progress,
            "current_epoch": job.current_epoch,
            "total_epochs": job.total_epochs,
            "current_step": job.current_step,
            "total_steps": job.total_steps,
            "loss": job.loss,
            "started_at": job.started_at,
            "completed_at": job.completed_at,
            "output_path": job.output_path,
            "error": job.error,
            "log_text": "\n".join(job.log_lines[-200:]),
            "pod_id": job.pod_id,
            "gpu_type": job.gpu_type,
            "backend": "runpod",
            "base_model": job.base_model,
            "model_type": job.model_type,
        }
        upsert = text("""INSERT OR REPLACE INTO training_jobs
(id, name, status, progress, current_epoch, total_epochs,
current_step, total_steps, loss, started_at, completed_at,
output_path, error, log_text, pod_id, gpu_type, backend,
base_model, model_type, created_at)
VALUES (:id, :name, :status, :progress, :current_epoch, :total_epochs,
:current_step, :total_steps, :loss, :started_at, :completed_at,
:output_path, :error, :log_text, :pod_id, :gpu_type, :backend,
:base_model, :model_type, COALESCE((SELECT created_at FROM training_jobs WHERE id = :id), CURRENT_TIMESTAMP))
""")
        async with catalog_session_factory() as session:
            await session.execute(upsert, row)
            await session.commit()
    except Exception as e:
        logger.warning("Failed to save training job to DB: %s", e)
async def _load_jobs_from_db(self):
    """Restore recently saved training jobs from the database on startup.

    The 20 most recently created rows are copied into ``self._jobs``; jobs already
    in memory are left untouched. Restored jobs that were still in progress are
    handled as follows:

    * the pod's ``desiredStatus`` is RUNNING: the job is set back to "training" and
      a background ``_reconnect_training`` task is started;
    * otherwise (no pod, pod gone or not running, or the RunPod check failed): the
      job is marked failed and that state is saved, so the next restart does not
      re-examine it.

    Errors are logged as warnings and never raised.
    """
    try:
        from sqlalchemy import select

        restored = []
        async with catalog_session_factory() as session:
            result = await session.execute(
                select(TrainingJobDB).order_by(TrainingJobDB.created_at.desc()).limit(20)
            )
            for db_job in result.scalars().all():
                if db_job.id in self._jobs:
                    continue
                job = CloudTrainingJob(
                    id=db_job.id,
                    name=db_job.name,
                    status=db_job.status,
                    progress=db_job.progress or 0.0,
                    current_epoch=db_job.current_epoch or 0,
                    total_epochs=db_job.total_epochs or 0,
                    current_step=db_job.current_step or 0,
                    total_steps=db_job.total_steps or 0,
                    loss=db_job.loss,
                    started_at=db_job.started_at,
                    completed_at=db_job.completed_at,
                    output_path=db_job.output_path,
                    error=db_job.error,
                    log_lines=db_job.log_text.split("\n") if db_job.log_text else [],
                    pod_id=db_job.pod_id,
                    gpu_type=db_job.gpu_type or DEFAULT_GPU,
                    base_model=db_job.base_model or "sd15_realistic",
                    model_type=db_job.model_type or "sd15",
                )
                self._jobs[db_job.id] = job
                restored.append(job)

        # RunPod calls and state saves happen after the read session is closed,
        # so no DB session is held open across network round-trips.
        for job in restored:
            if job.status in ("completed", "failed"):
                continue
            if not job.pod_id:
                job.status = "failed"
                job.error = "Interrupted by server restart"
            else:
                try:
                    pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
                except Exception as e:
                    logger.warning("Could not check pod %s: %s", job.pod_id, e)
                    job.status = "failed"
                    job.error = "Interrupted by server restart"
                else:
                    if pod and pod.get("desiredStatus") == "RUNNING":
                        # Try to reconnect to running training pods
                        job.status = "training"
                        job.error = None
                        job._log("Reconnecting to running training pod after restart...")
                        # Keep a strong reference: asyncio only holds weak refs to tasks.
                        tasks = getattr(self, "_reconnect_tasks", None)
                        if tasks is None:
                            tasks = self._reconnect_tasks = set()
                        task = asyncio.create_task(self._reconnect_training(job))
                        tasks.add(task)
                        task.add_done_callback(tasks.discard)
                        logger.info("Reconnecting to training pod %s for job %s", job.pod_id, job.id)
                        continue
                    job.status = "failed"
                    job.error = "Pod terminated during server restart"
            # Persist the failure so it is not re-checked on every restart.
            await self._save_to_db(job)
    except Exception as e:
        logger.warning("Failed to load training jobs from DB: %s", e)
async def ensure_loaded(self):
    """Populate the job table from the database the first time it is needed.

    The "loaded" flag is flipped before the load runs, so loading is attempted
    at most once, even if it fails or is awaited again concurrently.
    """
    if self._loaded_from_db:
        return
    self._loaded_from_db = True
    await self._load_jobs_from_db()
async def _reconnect_training(self, job: CloudTrainingJob):
    """Re-attach to a training pod after a server restart and see the job through.

    Opens SSH to the pod (public mapping of port 22), then either:

    * collects the result of a run that already wrote ``/tmp/training.exit``, or
    * resumes tailing ``/tmp/training.log`` until that exit file appears.

    Exit code 0 copies the LoRA (plus ``-step`` checkpoints) to
    ``/runpod-volume/loras`` and marks the job completed. The job is marked failed on:

    * a non-zero exit code;
    * a missing PID file;
    * a training process that disappears without writing an exit code;
    * a missing ``.safetensors`` output;
    * any other error.

    In every case the SSH session is closed, the pod is terminated and a DB save is
    scheduled.
    """
    import paramiko

    log_file = "/tmp/training.log"
    exit_file = "/tmp/training.exit"
    pid_file = "/tmp/training.pid"
    ssh = None
    try:
        # Get SSH info from RunPod
        pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
        if not pod:
            raise RuntimeError("Pod not found")
        ssh_host = ssh_port = None
        for p in (pod.get("runtime") or {}).get("ports") or []:
            if p.get("privatePort") == 22:
                ssh_host = p.get("ip")
                ssh_port = p.get("publicPort")
        if not ssh_host or not ssh_port:
            raise RuntimeError("SSH port not available")

        # Connect SSH
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        await asyncio.to_thread(
            ssh.connect, ssh_host, port=int(ssh_port),
            username="root", password="runpod", timeout=10,
        )
        ssh.get_transport().set_keepalive(30)
        job._log(f"Reconnected to pod {job.pod_id}")

        exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
        if exit_check:
            # Training already finished while we were disconnected
            exit_code = int(exit_check)
            log_tail = await self._ssh_exec(ssh, f"tail -50 {log_file} 2>/dev/null")
            self._ingest_remote_log(job, log_tail)
            if exit_code != 0:
                raise RuntimeError(f"Training failed with exit code {exit_code}")
            job._log("Training completed while disconnected!")
        else:
            pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip()
            if not pid:
                # Nothing to monitor; without this check the loop below would poll
                # forever and keep the (billed) pod alive.
                raise RuntimeError("No training process found on pod (no exit code or PID file)")
            job._log(f"Training still running (PID: {pid}), resuming monitoring...")
            last_offset = 0
            dead_polls = 0
            while True:
                exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
                if exit_check:
                    exit_code = int(exit_check)
                    try:
                        last_offset, remaining = await self._read_remote_log_since(ssh, log_file, last_offset)
                        self._ingest_remote_log(job, remaining)
                    except Exception:
                        job._log("Could not read final log output")
                    if exit_code != 0:
                        raise RuntimeError(f"Training failed with exit code {exit_code}")
                    break
                try:
                    last_offset, new_output = await self._read_remote_log_since(ssh, log_file, last_offset)
                    if new_output:
                        self._ingest_remote_log(job, new_output)
                        self._schedule_db_save(job)
                except Exception:
                    job._log("Log read failed, retrying...")
                if pid.isdigit():
                    alive = (await self._ssh_exec(
                        ssh, f"kill -0 {pid} 2>/dev/null && echo ALIVE || echo DEAD"
                    )).strip()
                    # One poll of grace: the exit file may be written just after the process ends.
                    dead_polls = dead_polls + 1 if alive == "DEAD" else 0
                    if dead_polls >= 2:
                        raise RuntimeError(f"Training process {pid} exited without writing an exit code")
                await asyncio.sleep(5)

        job.output_path = await self._copy_remote_lora_to_volume(job, ssh)
        job.status = "completed"
        job.progress = 1.0
        job.completed_at = time.time()
        job._log("Training complete!")
    except Exception as e:
        job.status = "failed"
        job.error = str(e)
        job._log(f"Reconnect failed: {e}")
        logger.error("Training reconnect failed for %s: %s", job.id, e)
    finally:
        if ssh:
            try:
                ssh.close()
            except Exception:
                pass
        # Terminate pod; surface failures since a surviving pod keeps billing.
        if job.pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
                job._log("Pod terminated")
            except Exception as e:
                job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
                logger.warning("Failed to terminate pod %s: %s", job.pod_id, e)
        self._schedule_db_save(job)

async def _read_remote_log_since(self, ssh, log_file: str, offset: int) -> tuple[int, str]:
    """Return ``(new_offset, text)`` for the bytes appended to ``log_file`` after ``offset``.

    The new offset is the byte size reported by the pod, not the length of the text
    returned. ``_ssh_exec`` strips its output and decodes with ``errors="replace"``,
    so counting returned characters drifts and makes later reads re-emit old lines.
    ``head -c`` caps the read at that size; bytes appended meanwhile are read next time.
    """
    out = await self._ssh_exec(
        ssh,
        f"size=$(wc -c 2>/dev/null < {log_file} || echo 0); echo $size; "
        f"tail -c +{offset + 1} {log_file} 2>/dev/null | head -c $((size - {offset}))",
        timeout=30,
    )
    size_line, _, text = out.partition("\n")
    try:
        size = int(size_line)
    except ValueError:
        return offset, ""
    if size <= offset:
        return offset, ""
    return size, text

def _ingest_remote_log(self, job: CloudTrainingJob, text: str) -> None:
    """Log each non-blank line of ``text`` on ``job`` and feed it to ``_parse_progress``.

    Carriage returns (progress-bar redraws) are treated as line breaks.
    """
    for line in text.replace("\r", "\n").split("\n"):
        line = line.strip()
        if line:
            job._log(line)
            self._parse_progress(job, line)

async def _copy_remote_lora_to_volume(self, job: CloudTrainingJob, ssh) -> str:
    """Copy the finished LoRA and any ``-step`` checkpoints to ``/runpod-volume/loras``.

    Prefers ``/workspace/output/<name>.safetensors`` and falls back to the last file
    listed by ``ls /workspace/output/*.safetensors``. Remote paths are shell-quoted.
    Returns the volume path of the final LoRA.

    Raises:
        RuntimeError: if no ``.safetensors`` output exists on the pod.
    """
    import shlex

    name = job.name
    volume_dir = "/runpod-volume/loras"
    volume_path = f"{volume_dir}/{name}.safetensors"
    await self._ssh_exec(ssh, f"mkdir -p {volume_dir}")
    remote_output = f"/workspace/output/{name}.safetensors"
    check = (await self._ssh_exec(
        ssh, f"test -f {shlex.quote(remote_output)} && echo EXISTS || echo MISSING"
    )).strip()
    if check != "EXISTS":
        listing = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip()
        if not listing:
            raise RuntimeError("No .safetensors output found")
        remote_output = listing.split("\n")[-1].strip()
    await self._ssh_exec(ssh, f"cp {shlex.quote(remote_output)} {shlex.quote(volume_path)}")
    job._log(f"LoRA saved to volume: {volume_path}")

    checkpoints = (await self._ssh_exec(
        ssh, f"ls /workspace/output/{shlex.quote(name)}-step*.safetensors 2>/dev/null"
    )).strip()
    for ckpt in checkpoints.split("\n"):
        ckpt = ckpt.strip()
        if not ckpt:
            continue
        ckpt_path = f"{volume_dir}/{ckpt.rsplit('/', 1)[-1]}"
        await self._ssh_exec(ssh, f"cp {shlex.quote(ckpt)} {shlex.quote(ckpt_path)}")
        job._log(f"Checkpoint saved: {ckpt_path}")
    return volume_path
| def _build_training_command( | |
| self, | |
| *, | |
| model_type: str, | |
| model_path: str, | |
| name: str, | |
| resolution: int, | |
| num_epochs: int, | |
| max_train_steps: int | None, | |
| learning_rate: float, | |
| network_rank: int, | |
| network_alpha: int, | |
| optimizer: str, | |
| save_every_n_epochs: int, | |
| save_every_n_steps: int = 500, | |
| model_cfg: dict, | |
| gpu_type: str = "", | |
| ) -> str: | |
| """Build the training command based on model type.""" | |
| # Common parameters | |
| base_args = f""" | |
| --train_data_dir="/workspace/dataset" \ | |
| --output_dir="/workspace/output" \ | |
| --output_name="{name}" \ | |
| --resolution={resolution} \ | |
| --train_batch_size=1 \ | |
| --learning_rate={learning_rate} \ | |
| --network_module=networks.lora \ | |
| --network_dim={network_rank} \ | |
| --network_alpha={network_alpha} \ | |
| --optimizer_type={optimizer} \ | |
| --save_every_n_epochs={save_every_n_epochs} \ | |
| --mixed_precision=fp16 \ | |
| --seed=42 \ | |
| --caption_extension=.txt \ | |
| --cache_latents \ | |
| --enable_bucket \ | |
| --save_model_as=safetensors""" | |
| # Steps vs epochs | |
| if max_train_steps: | |
| base_args += f" \\\n --max_train_steps={max_train_steps}" | |
| else: | |
| base_args += f" \\\n --max_train_epochs={num_epochs}" | |
| # LR scheduler | |
| lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts") | |
| base_args += f" \\\n --lr_scheduler={lr_scheduler}" | |
| if model_type == "flux2": | |
| # FLUX.2 training via musubi-tuner | |
| flux2_dir = "/workspace/models/FLUX.2-dev" | |
| dit_path = f"{flux2_dir}/flux2-dev.safetensors" | |
| vae_path = f"{flux2_dir}/ae.safetensors" | |
| te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors" | |
| network_mod = model_cfg.get("network_module", "networks.lora_flux_2") | |
| ts_sampling = model_cfg.get("timestep_sampling", "flux2_shift") | |
| lr_scheduler = model_cfg.get("lr_scheduler", "cosine") | |
| # Build as list of args to avoid shell escaping issues | |
| args = [ | |
| "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True", | |
| "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16", | |
| "src/musubi_tuner/flux_2_train_network.py", | |
| "--model_version dev", | |
| f"--dit {dit_path}", | |
| f"--vae {vae_path}", | |
| f"--text_encoder {te_path}", | |
| "--dataset_config /workspace/dataset.toml", | |
| "--sdpa --mixed_precision bf16", | |
| f"--timestep_sampling {ts_sampling} --weighting_scheme none", | |
| f"--network_module {network_mod}", | |
| f"--network_dim={network_rank}", | |
| f"--network_alpha={network_alpha}", | |
| "--gradient_checkpointing", | |
| ] | |
| # Only use fp8_base on GPUs with native fp8 support (RTX 4090, H100) | |
| # A100 and A6000 don't support fp8 tensor ops, and have enough VRAM without it | |
| if gpu_type and ("4090" in gpu_type or "5090" in gpu_type or "L40S" in gpu_type or "H100" in gpu_type): | |
| args.append("--fp8_base") | |
| # Handle Prodigy optimizer (needs special class path and args) | |
| if optimizer.lower() == "prodigy": | |
| args.extend([ | |
| "--optimizer_type=prodigyopt.Prodigy", | |
| f"--learning_rate={learning_rate}", | |
| '--optimizer_args "weight_decay=0.01" "decouple=True" "use_bias_correction=True" "safeguard_warmup=True" "d_coef=2"', | |
| ]) | |
| else: | |
| args.extend([ | |
| f"--optimizer_type={optimizer}", | |
| f"--learning_rate={learning_rate}", | |
| ]) | |
| args.extend([ | |
| "--seed=42", | |
| '--output_dir=/workspace/output', | |
| f'--output_name={name}', | |
| f"--lr_scheduler={lr_scheduler}", | |
| ]) | |
| if max_train_steps: | |
| args.append(f"--max_train_steps={max_train_steps}") | |
| if save_every_n_steps: | |
| args.append(f"--save_every_n_steps={save_every_n_steps}") | |
| else: | |
| args.append(f"--save_every_n_epochs={save_every_n_epochs}") | |
| else: | |
| args.append(f"--max_train_epochs={num_epochs}") | |
| args.append(f"--save_every_n_epochs={save_every_n_epochs}") | |
| return " ".join(args) + " 2>&1" | |
| elif model_type == "wan22": | |
| # WAN 2.2 T2V LoRA training via musubi-tuner | |
| wan_dir = "/workspace/models/WAN2.2" | |
| dit_low = f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors" | |
| dit_high = f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors" | |
| network_mod = model_cfg.get("network_module", "networks.lora_wan") | |
| ts_sampling = model_cfg.get("timestep_sampling", "shift") | |
| discrete_shift = model_cfg.get("discrete_flow_shift", 5.0) | |
| args = [ | |
| "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True", | |
| "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision fp16", | |
| "src/musubi_tuner/wan_train_network.py", | |
| "--task t2v-A14B", | |
| f"--dit {dit_low}", | |
| f"--dit_high_noise {dit_high}", | |
| "--dataset_config /workspace/dataset.toml", | |
| "--sdpa --mixed_precision fp16", | |
| "--gradient_checkpointing", | |
| f"--timestep_sampling {ts_sampling}", | |
| f"--discrete_flow_shift {discrete_shift}", | |
| f"--network_module {network_mod}", | |
| f"--network_dim={network_rank}", | |
| f"--network_alpha={network_alpha}", | |
| f"--optimizer_type={optimizer}", | |
| f"--learning_rate={learning_rate}", | |
| "--seed=42", | |
| "--output_dir=/workspace/output", | |
| f"--output_name={name}", | |
| ] | |
| if max_train_steps: | |
| args.append(f"--max_train_steps={max_train_steps}") | |
| if save_every_n_steps: | |
| args.append(f"--save_every_n_steps={save_every_n_steps}") | |
| else: | |
| args.append(f"--max_train_epochs={num_epochs}") | |
| args.append(f"--save_every_n_epochs={save_every_n_epochs}") | |
| return " ".join(args) + " 2>&1" | |
| elif model_type == "flux": | |
| # FLUX.1 training via sd-scripts | |
| script = "flux_train_network.py" | |
| flux_args = f""" | |
| --pretrained_model_name_or_path="{model_path}" \ | |
| --clip_l="/workspace/models/clip_l.safetensors" \ | |
| --t5xxl="/workspace/models/t5xxl_fp16.safetensors" \ | |
| --ae="/workspace/models/ae.safetensors" \ | |
| --cache_text_encoder_outputs \ | |
| --cache_text_encoder_outputs_to_disk \ | |
| --fp8_base \ | |
| --split_mode""" | |
| # Text encoder learning rate | |
| text_encoder_lr = model_cfg.get("text_encoder_lr", 4e-5) | |
| flux_args += f" \\\n --text_encoder_lr={text_encoder_lr}" | |
| # Min SNR gamma if specified | |
| min_snr = model_cfg.get("min_snr_gamma") | |
| if min_snr: | |
| flux_args += f" \\\n --min_snr_gamma={min_snr}" | |
| return f"cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} {flux_args} {base_args} 2>&1" | |
| elif model_type == "sdxl": | |
| # SDXL-specific training | |
| script = "sdxl_train_network.py" | |
| clip_skip = model_cfg.get("clip_skip", 2) | |
| return f"""cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} \ | |
| --pretrained_model_name_or_path="{model_path}" \ | |
| --clip_skip={clip_skip} \ | |
| --xformers {base_args} 2>&1""" | |
| else: | |
| # SD 1.5 / default training | |
| script = "train_network.py" | |
| clip_skip = model_cfg.get("clip_skip", 1) | |
| return f"""cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1 {script} \ | |
| --pretrained_model_name_or_path="{model_path}" \ | |
| --clip_skip={clip_skip} \ | |
| --xformers {base_args} 2>&1""" | |
| async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 600) -> tuple[str, int]: | |
| """Wait for pod to be running and return (ssh_host, ssh_port).""" | |
| start = time.time() | |
| while time.time() - start < timeout: | |
| try: | |
| pod = await asyncio.to_thread(runpod.get_pod, job.pod_id) | |
| except Exception as e: | |
| job._log(f" API error: {e}") | |
| await asyncio.sleep(10) | |
| continue | |
| status = pod.get("desiredStatus", "") | |
| runtime = pod.get("runtime") | |
| if status == "RUNNING" and runtime: | |
| ports = runtime.get("ports") or [] | |
| for port_info in (ports or []): | |
| if port_info.get("privatePort") == 22: | |
| ip = port_info.get("ip") | |
| public_port = port_info.get("publicPort") | |
| if ip and public_port: | |
| return ip, int(public_port) | |
| elapsed = int(time.time() - start) | |
| if elapsed % 30 < 6: | |
| job._log(f" Status: {status} | runtime: {'ports pending' if runtime else 'not ready yet'} ({elapsed}s)") | |
| await asyncio.sleep(5) | |
| raise RuntimeError(f"Pod did not become ready within {timeout}s") | |
| def _ssh_exec_sync(self, ssh, cmd: str, timeout: int = 120) -> str: | |
| """Execute a command over SSH and return stdout (blocking).""" | |
| _, stdout, stderr = ssh.exec_command(cmd, timeout=timeout) | |
| out = stdout.read().decode("utf-8", errors="replace") | |
| err = stderr.read().decode("utf-8", errors="replace") | |
| exit_code = stdout.channel.recv_exit_status() | |
| if exit_code != 0 and "warning" not in err.lower(): | |
| logger.warning("SSH cmd failed (code %d): %s\nstderr: %s", exit_code, cmd[:100], err[:500]) | |
| return out.strip() | |
| async def _ssh_exec(self, ssh, cmd: str, timeout: int = 120) -> str: | |
| """Execute a command over SSH without blocking the event loop.""" | |
| return await asyncio.to_thread(self._ssh_exec_sync, ssh, cmd, timeout) | |
| def _parse_progress(self, job: CloudTrainingJob, line: str): | |
| """Parse Kohya training output for progress info.""" | |
| lower = line.lower() | |
| if "epoch" in lower and "/" in line: | |
| try: | |
| parts = lower.split("epoch") | |
| if len(parts) > 1: | |
| ep_part = parts[1].strip().split()[0] | |
| if "/" in ep_part: | |
| current, total = ep_part.split("/") | |
| job.current_epoch = int(current) | |
| job.total_epochs = int(total) | |
| job.progress = 0.15 + 0.75 * (job.current_epoch / max(job.total_epochs, 1)) | |
| except (ValueError, IndexError): | |
| pass | |
| if "loss=" in line or "loss:" in line: | |
| try: | |
| loss_str = line.split("loss")[1].strip("=: ").split()[0].strip(",") | |
| job.loss = float(loss_str) | |
| except (ValueError, IndexError): | |
| pass | |
| if "steps:" in lower or "step " in lower: | |
| try: | |
| import re | |
| step_match = re.search(r"(\d+)/(\d+)", line) | |
| if step_match: | |
| job.current_step = int(step_match.group(1)) | |
| job.total_steps = int(step_match.group(2)) | |
| except (ValueError, IndexError): | |
| pass | |
def get_job(self, job_id: str) -> CloudTrainingJob | None:
    """Look up a tracked job by id; ``None`` when the id is unknown."""
    found = self._jobs.get(job_id, None)
    return found
def list_jobs(self) -> list[CloudTrainingJob]:
    """Return a snapshot list of every tracked job, in insertion order."""
    return [*self._jobs.values()]
async def cancel_job(self, job_id: str) -> bool:
    """Cancel a cloud training job and terminate its pod.

    Returns False if ``job_id`` is unknown, True otherwise. The job is marked failed
    with error "Cancelled by user", and that state is persisted. Without saving, a
    restart would reload the old status and report a misleading error. A failure to
    terminate the pod does not abort the cancel. It is logged, because a surviving
    pod keeps billing.
    """
    job = self._jobs.get(job_id)
    if not job:
        return False
    if job.pod_id:
        try:
            await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
        except Exception as e:
            logger.warning("Failed to terminate pod %s for job %s: %s", job.pod_id, job_id, e)
            job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
    job.status = "failed"
    job.error = "Cancelled by user"
    await self._save_to_db(job)
    return True
async def delete_job(self, job_id: str) -> bool:
    """Delete a training job from memory and from the ``training_jobs`` table.

    Returns False if ``job_id`` is not tracked in memory. Otherwise the in-memory
    entry is removed and True is returned, even if the database delete fails; that
    failure is only logged.
    """
    if job_id not in self._jobs:
        return False
    del self._jobs[job_id]
    try:
        from sqlalchemy import select

        async with catalog_session_factory() as session:
            result = await session.execute(
                select(TrainingJobDB).where(TrainingJobDB.id == job_id)
            )
            db_job = result.scalar_one_or_none()
            if db_job:
                await session.delete(db_job)
                await session.commit()
    except Exception as e:
        logger.warning("Failed to delete job from DB: %s", e)
    return True
async def delete_failed_jobs(self) -> int:
    """Delete every tracked job whose status is "failed" or "error"; return how many were deleted.

    The matching ids are collected first, because ``delete_job`` mutates ``self._jobs``.
    """
    doomed = []
    for job_id, job in self._jobs.items():
        if job.status in ("failed", "error"):
            doomed.append(job_id)
    for job_id in doomed:
        await self.delete_job(job_id)
    return len(doomed)