dippoo's picture
Sync all local changes: video routes, pod management, wavespeed, UI updates
e808ae1
raw
history blame
66.5 kB
"""RunPod cloud LoRA training — offload training to RunPod GPU pods.

Creates a temporary GPU pod, uploads the training images, runs the training
framework configured for the chosen base model (Kohya sd-scripts or
musubi-tuner), saves/downloads the finished LoRA, then terminates the pod.
No local GPU is needed.

Supports multiple base models (SD 1.5, SDXL, FLUX.1, FLUX.2, WAN 2.2) via the
model registry in config/models.yaml.

Usage:
    1. Set RUNPOD_API_KEY in .env.
    2. Select "Cloud (RunPod)" in the training UI.
"""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import runpod
import yaml
logger = logging.getLogger(__name__)
import os
from content_engine.config import settings, IS_HF_SPACES
from content_engine.models.database import catalog_session_factory, TrainingJob as TrainingJobDB
# Local directory that finished LoRA files are downloaded into.
LORA_OUTPUT_DIR = settings.paths.lora_dir

# Directory containing models.yaml (the training model registry).
# CONTENT_ENGINE_CONFIG_DIR overrides the built-in defaults so the module can
# run on machines other than the original development box; when it is unset
# the previous behaviour is unchanged.
_CONFIG_DIR_OVERRIDE = os.environ.get("CONTENT_ENGINE_CONFIG_DIR", "")
if _CONFIG_DIR_OVERRIDE:
    CONFIG_DIR = Path(_CONFIG_DIR_OVERRIDE)
elif IS_HF_SPACES:
    CONFIG_DIR = Path("/app/config")
else:
    CONFIG_DIR = Path("D:/AI automation/content_engine/config")
# RunPod GPU options: RunPod GPU type id -> display label with approximate
# hourly cost. NOTE(review): prices are hard-coded snapshots and may be stale.
GPU_OPTIONS = {
    # 16-32GB cards - SD 1.5, SDXL, FLUX.1 only (NOT enough for FLUX.2)
    "NVIDIA GeForce RTX 3090": "RTX 3090 24GB (~$0.22/hr)",
    "NVIDIA GeForce RTX 4090": "RTX 4090 24GB (~$0.44/hr)",
    "NVIDIA GeForce RTX 5090": "RTX 5090 32GB (~$0.69/hr)",
    "NVIDIA RTX A4000": "RTX A4000 16GB (~$0.20/hr)",
    "NVIDIA RTX A5000": "RTX A5000 24GB (~$0.28/hr)",
    # 48GB+ - Required for FLUX.2 Dev (Mistral text encoder needs ~48GB)
    "NVIDIA RTX A6000": "RTX A6000 48GB (~$0.76/hr)",
    "NVIDIA A40": "A40 48GB (~$0.64/hr)",
    "NVIDIA L40": "L40 48GB (~$0.89/hr)",
    "NVIDIA L40S": "L40S 48GB (~$1.09/hr)",
    "NVIDIA A100 80GB PCIe": "A100 80GB (~$1.89/hr)",
    "NVIDIA A100-SXM4-80GB": "A100 SXM 80GB (~$1.64/hr)",
    "NVIDIA H100 80GB HBM3": "H100 80GB (~$3.89/hr)",
}
# GPU type used when the caller does not pick one.
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
# Network volume for persistent model storage (avoids re-downloading models each run)
# Set RUNPOD_VOLUME_ID in .env to use a persistent volume
# Set RUNPOD_VOLUME_DC to the datacenter ID where the volume lives (e.g. "EU-RO-1")
# Both default to "" (disabled / let RunPod choose the datacenter).
NETWORK_VOLUME_ID = os.environ.get("RUNPOD_VOLUME_ID", "")
NETWORK_VOLUME_DC = os.environ.get("RUNPOD_VOLUME_DC", "")
# Docker image with PyTorch + CUDA pre-installed (used for every training pod)
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
def load_model_registry() -> dict[str, dict]:
    """Load training model configurations from ``CONFIG_DIR / "models.yaml"``.

    Returns:
        The mapping stored under the top-level ``training_models`` key
        (model key -> config dict). Returns an empty dict when the file is
        missing, empty, or has no (or an empty) ``training_models`` section.
    """
    models_file = CONFIG_DIR / "models.yaml"
    if not models_file.exists():
        logger.warning("Model registry not found: %s", models_file)
        return {}
    # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) mangles or
    # rejects non-ASCII text in model descriptions.
    with open(models_file, encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f) or {}
    # "training_models:" with no entries also parses to None.
    return config.get("training_models") or {}
@dataclass
class CloudTrainingJob:
    """State of one RunPod cloud training job.

    The training pipeline mutates an instance in place. ``_log`` appends to a
    rolling log and, when ``_db_callback`` is set, calls it with the job so the
    current state can be persisted.
    """

    id: str
    name: str
    # pending, creating_pod, uploading, installing, training, downloading, completed, failed
    status: str = "pending"
    # Fraction complete; the pipeline sets values between 0.0 and 1.0.
    progress: float = 0.0
    current_epoch: int = 0
    total_epochs: int = 0
    current_step: int = 0
    # None when no max_train_steps was given or configured for the model.
    total_steps: int | None = 0
    loss: float | None = None
    started_at: float | None = None  # time.time() timestamp
    completed_at: float | None = None  # time.time() timestamp
    output_path: str | None = None
    error: str | None = None
    log_lines: list[str] = field(default_factory=list)  # last 200 lines only
    pod_id: str | None = None
    gpu_type: str = DEFAULT_GPU
    cost_estimate: str | None = None
    base_model: str = "sd15_realistic"  # key into the training model registry
    model_type: str = "sd15"
    # Hook called with the job after every log line (used to persist state).
    # Excluded from repr/eq: it is wiring rather than job state. Otherwise it
    # would drag the owning trainer into repr() and make two jobs with identical
    # state compare unequal depending on who is listening.
    _db_callback: Any = field(default=None, repr=False, compare=False)

    def _log(self, msg: str):
        """Record *msg*: keep the last 200 lines, log it, then notify ``_db_callback``."""
        self.log_lines.append(msg)
        if len(self.log_lines) > 200:
            # Rebind (not in-place) so lists already handed out stay unchanged.
            self.log_lines = self.log_lines[-200:]
        logger.info("[%s] %s", self.id, msg)
        if self._db_callback:
            self._db_callback(self)
class RunPodTrainer:
"""Manages LoRA training on RunPod cloud GPUs."""
def __init__(self, api_key: str):
    """Configure the RunPod SDK with *api_key* and load the training model registry."""
    self._api_key = api_key
    # The runpod SDK reads its credentials from a module-level attribute.
    runpod.api_key = self._api_key
    self._jobs: dict[str, CloudTrainingJob] = {}  # job id -> job
    self._loaded_from_db = False  # flipped by ensure_loaded()
    self._model_registry = load_model_registry()
@property
def available(self) -> bool:
    """Whether a RunPod API key was supplied.

    When one was, it is pushed to the runpod module again, because a uvicorn
    reload can reset the module-level key.
    """
    configured = bool(self._api_key)
    if configured:
        runpod.api_key = self._api_key
    return configured
def list_gpu_options(self) -> dict[str, str]:
    """Return the selectable GPU types (RunPod GPU id -> display label).

    A shallow copy is returned so callers cannot mutate the module-level
    ``GPU_OPTIONS`` table shared by every trainer.
    """
    return dict(GPU_OPTIONS)
def list_training_models(self) -> dict[str, dict]:
    """Summarise every registered base model for the training UI.

    Each entry reports the model's display name (defaulting to its registry
    key) plus its training parameters, filling gaps with fixed fallbacks.
    """
    fallbacks = {
        "description": "",
        "model_type": "sd15",
        "resolution": 512,
        "learning_rate": 1e-4,
        "network_rank": 32,
        "network_alpha": 16,
        "optimizer": "AdamW8bit",
        "lr_scheduler": "cosine",
        "vram_required_gb": 8,
        "recommended_images": "15-30 photos",
    }
    summaries: dict[str, dict] = {}
    for model_key, model_cfg in self._model_registry.items():
        summary = {"name": model_cfg.get("name", model_key)}
        for option, fallback in fallbacks.items():
            summary[option] = model_cfg.get(option, fallback)
        summaries[model_key] = summary
    return summaries
def get_model_config(self, model_key: str) -> dict | None:
    """Return the raw registry entry for *model_key*, or None if it is not registered."""
    registry = self._model_registry
    return registry[model_key] if model_key in registry else None
async def start_training(
    self,
    *,
    name: str,
    image_paths: list[str],
    trigger_word: str = "",
    base_model: str = "sd15_realistic",
    resolution: int | None = None,
    num_epochs: int = 10,
    max_train_steps: int | None = None,
    learning_rate: float | None = None,
    network_rank: int | None = None,
    network_alpha: int | None = None,
    optimizer: str | None = None,
    save_every_n_epochs: int = 2,
    save_every_n_steps: int = 500,
    gpu_type: str = DEFAULT_GPU,
) -> str:
    """Start a cloud training job in the background and return its job ID.

    Parameters use model registry defaults if not specified: resolution,
    learning_rate, network_rank, network_alpha, optimizer and max_train_steps
    fall back to the registry entry for *base_model* when None (or falsy). An
    unknown *base_model* falls back to the ``sd15_realistic`` entry.

    The job record is persisted and the full pipeline
    (``_run_cloud_training``) runs as a background task; this coroutine
    returns immediately.
    """
    job_id = str(uuid.uuid4())[:8]

    # Get model config (fall back to sd15_realistic if not found)
    model_cfg = self._model_registry.get(base_model, self._model_registry.get("sd15_realistic", {}))
    model_type = model_cfg.get("model_type", "sd15")

    # Use provided values or model defaults
    final_resolution = resolution or model_cfg.get("resolution", 512)
    final_lr = learning_rate or model_cfg.get("learning_rate", 1e-4)
    final_rank = network_rank or model_cfg.get("network_rank", 32)
    final_alpha = network_alpha or model_cfg.get("network_alpha", 16)
    final_optimizer = optimizer or model_cfg.get("optimizer", "AdamW8bit")
    final_steps = max_train_steps or model_cfg.get("max_train_steps")

    job = CloudTrainingJob(
        id=job_id,
        name=name,
        status="pending",
        total_epochs=num_epochs,
        total_steps=final_steps,
        gpu_type=gpu_type,
        started_at=time.time(),
        base_model=base_model,
        model_type=model_type,
    )
    self._jobs[job_id] = job
    job._db_callback = self._schedule_db_save

    # The event loop only keeps weak references to tasks, so an unreferenced
    # fire-and-forget task can be garbage-collected mid-run. Keep a strong
    # reference until each task finishes.
    background = self.__dict__.setdefault("_background_tasks", set())
    coroutines = (
        self._save_to_db(job),
        # Launch the full pipeline as a background task
        self._run_cloud_training(
            job=job,
            image_paths=image_paths,
            trigger_word=trigger_word,
            model_cfg=model_cfg,
            resolution=final_resolution,
            num_epochs=num_epochs,
            max_train_steps=final_steps,
            learning_rate=final_lr,
            network_rank=final_rank,
            network_alpha=final_alpha,
            optimizer=final_optimizer,
            save_every_n_epochs=save_every_n_epochs,
            save_every_n_steps=save_every_n_steps,
        ),
    )
    for coro in coroutines:
        task = asyncio.ensure_future(coro)
        background.add(task)
        task.add_done_callback(background.discard)
    return job_id
async def _run_cloud_training(
    self,
    job: CloudTrainingJob,
    image_paths: list[str],
    trigger_word: str,
    model_cfg: dict,
    resolution: int,
    num_epochs: int,
    max_train_steps: int | None,
    learning_rate: float,
    network_rank: int,
    network_alpha: int,
    optimizer: str,
    save_every_n_epochs: int,
    save_every_n_steps: int = 500,
):
    """Full cloud training pipeline: create pod -> upload -> train -> download -> cleanup.

    Runs as a background task. Progress and errors are reported through *job*
    (status, progress, log lines). Exceptions raised by the pipeline are
    caught and mark the job ``failed``. On exit the SSH/SFTP connections are
    closed, local temp/upload directories are removed, and the pod is
    terminated.

    Args:
        job: Job record updated as the pipeline advances.
        image_paths: Local training images. A sibling ``.txt`` file next to an
            image is uploaded as its caption; otherwise the trigger word is used.
        trigger_word: Used in the dataset folder name and as the fallback caption.
        model_cfg: Registry entry for the base model (model_type, hf_repo,
            hf_filename, training_framework, ...).
        resolution: Training resolution; also bounds image pre-compression
            (longest side <= 2 * resolution).
        num_epochs, max_train_steps, learning_rate, network_rank,
        network_alpha, optimizer, save_every_n_epochs, save_every_n_steps:
            Forwarded to ``_build_training_command``.
    """
    import shlex
    import shutil
    import tempfile

    ssh = None
    sftp = None
    tmp_dir: Path | None = None  # local dir of compressed images; removed in finally
    model_type = model_cfg.get("model_type", "sd15")
    name = job.name
    # The job name is user-supplied; quote it wherever it enters a remote shell command.
    q_name = shlex.quote(name)
    try:
        # Step 1: Create pod
        job.status = "creating_pod"
        job._log(f"Creating RunPod with {job.gpu_type}...")

        # NOTE(review): sshd is exposed on a public port with a fixed root
        # password. Other code (e.g. reconnect after restart) may rely on it, so
        # it is unchanged here; consider per-pod credentials.
        pod_kwargs = {
            "container_disk_in_gb": 30,
            "ports": "22/tcp",
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }
        # Use network volume if configured (persists models across runs)
        if NETWORK_VOLUME_ID:
            pod_kwargs["network_volume_id"] = NETWORK_VOLUME_ID
            if NETWORK_VOLUME_DC:
                pod_kwargs["data_center_id"] = NETWORK_VOLUME_DC
            job._log(f"Using persistent network volume: {NETWORK_VOLUME_ID} (DC: {NETWORK_VOLUME_DC or 'auto'})")
        else:
            pod_kwargs["volume_in_gb"] = 75

        pod = await asyncio.to_thread(
            runpod.create_pod,
            f"lora-train-{job.id}",
            DOCKER_IMAGE,
            job.gpu_type,
            **pod_kwargs,
        )
        job.pod_id = pod["id"]
        job._log(f"Pod created: {job.pod_id}")

        # Wait for pod to be ready and get SSH info
        job._log("Waiting for pod to start...")
        ssh_host, ssh_port = await self._wait_for_pod_ready(job)
        job._log(f"Pod ready at {ssh_host}:{ssh_port}")

        # Step 2: Connect via SSH (sshd may still be starting; retry for ~5 min)
        import paramiko

        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(30):
            try:
                await asyncio.to_thread(
                    ssh.connect,
                    ssh_host, port=ssh_port,
                    username="root", password="runpod",
                    timeout=10,
                )
                break
            except Exception as exc:
                if attempt == 29:
                    raise RuntimeError("Could not SSH into pod after 30 attempts") from exc
                await asyncio.sleep(5)
        job._log("SSH connected")

        # If using network volume, symlink to /workspace so all paths work
        if NETWORK_VOLUME_ID:
            await self._ssh_exec(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
            job._log("Network volume symlinked to /workspace")

        # Enable keepalive to prevent SSH timeout during uploads
        transport = ssh.get_transport()
        transport.set_keepalive(30)
        sftp = ssh.open_sftp()
        sftp.get_channel().settimeout(300)  # 5 min timeout per file

        # Step 3: Upload training images (compress first to speed up transfer).
        # Uses the *resolution* argument (caller value or registry default) —
        # it must not be re-read from model_cfg, which would discard the caller's choice.
        job.status = "uploading"
        job._log(f"Compressing and uploading {len(image_paths)} training images...")
        from PIL import Image

        tmp_dir = Path(tempfile.mkdtemp(prefix="lora_upload_"))
        folder_name = f"10_{trigger_word or 'character'}"
        remote_dataset_dir = f"/workspace/dataset/{folder_name}"
        await self._ssh_exec(ssh, f"mkdir -p {shlex.quote(remote_dataset_dir)}")
        for i, img_path in enumerate(image_paths):
            p = Path(img_path)
            if not p.exists():
                job._log(f"Skipping missing image: {p}")
                continue

            # Resize and convert to JPEG to reduce upload size
            try:
                with Image.open(p) as img:
                    img.thumbnail((resolution * 2, resolution * 2), Image.LANCZOS)
                    compressed = tmp_dir / f"{p.stem}.jpg"
                    img.save(compressed, "JPEG", quality=95)
                upload_path = compressed
            except Exception:
                upload_path = p  # fallback to original (e.g. modes JPEG cannot store)

            remote_name = f"{p.stem}.jpg" if upload_path.suffix == ".jpg" else p.name
            remote_path = f"{remote_dataset_dir}/{remote_name}"
            for attempt in range(3):
                try:
                    await asyncio.to_thread(sftp.put, str(upload_path), remote_path)
                    break
                except (EOFError, OSError):
                    if attempt == 2:
                        raise
                    job._log(f"Upload retry {attempt+1} for {p.name}")
                    # A failed transfer can leave the SFTP channel unusable; reopen it.
                    sftp.close()
                    sftp = ssh.open_sftp()
                    sftp.get_channel().settimeout(300)

            # Upload matching caption .txt file if it exists locally
            remote_caption = f"{remote_dataset_dir}/{p.stem}.txt"
            local_caption = p.with_suffix(".txt")
            if local_caption.exists():
                await asyncio.to_thread(sftp.put, str(local_caption), remote_caption)
            else:
                # Fallback: create caption from trigger word
                def _write_caption():
                    with sftp.open(remote_caption, "w") as f:
                        f.write(trigger_word or "")

                await asyncio.to_thread(_write_caption)
            job._log(f"Uploaded {i+1}/{len(image_paths)}: {p.name}")

        # Cleanup temp compressed images now (the finally block covers failures)
        shutil.rmtree(tmp_dir, ignore_errors=True)
        tmp_dir = None
        job._log("Images uploaded")

        # Step 4: Install training framework on the pod (skip if cached on volume)
        job.status = "installing"
        job.progress = 0.05
        training_framework = model_cfg.get("training_framework", "sd-scripts")

        if training_framework == "musubi-tuner":
            # FLUX.2 uses musubi-tuner (Kohya's newer framework)
            tuner_dir = "/workspace/musubi-tuner"
            install_cmds = []

            # Check if already present in workspace
            tuner_exist = (await self._ssh_exec(ssh, f"test -f {tuner_dir}/pyproject.toml && echo EXISTS || echo MISSING")).strip()
            if tuner_exist == "EXISTS":
                job._log("musubi-tuner found in workspace")
            else:
                # Check volume cache
                vol_exist = (await self._ssh_exec(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip()
                if vol_exist == "EXISTS":
                    job._log("Restoring musubi-tuner from volume cache...")
                    await self._ssh_exec(ssh, f"rm -rf {tuner_dir} 2>/dev/null; cp -r /runpod-volume/musubi-tuner {tuner_dir}")
                else:
                    job._log("Cloning musubi-tuner from GitHub...")
                    await self._ssh_exec(ssh, f"rm -rf {tuner_dir} /runpod-volume/musubi-tuner 2>/dev/null; true")
                    install_cmds.append("cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git")
                    # Save to volume for future pods
                    if NETWORK_VOLUME_ID:
                        install_cmds.append(f"cp -r {tuner_dir} /runpod-volume/musubi-tuner")

            # Always install pip deps (they are pod-local, lost on every new pod)
            job._log("Installing pip dependencies (accelerate, torch, etc.)...")
            install_cmds.extend([
                f"cd {tuner_dir} && pip install -e . 2>&1 | tail -5",
                "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes 2>&1 | tail -5",
            ])
        else:
            # SD 1.5 / SDXL / FLUX.1 use sd-scripts
            scripts_exist = (await self._ssh_exec(ssh, "test -f /workspace/sd-scripts/setup.py && echo EXISTS || echo MISSING")).strip()
            if scripts_exist == "EXISTS":
                job._log("Kohya sd-scripts already cached on volume, updating...")
                install_cmds = [
                    "cd /workspace/sd-scripts && git pull 2>&1 | tail -1",
                ]
            else:
                job._log("Installing Kohya sd-scripts (this takes a few minutes)...")
                install_cmds = [
                    "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/sd-scripts.git",
                ]
            # Always install pip deps (pod-local, lost on new pods)
            install_cmds.extend([
                "cd /workspace/sd-scripts && pip install -r requirements.txt 2>&1 | tail -1",
                "pip install accelerate lion-pytorch prodigyopt safetensors bitsandbytes xformers 2>&1 | tail -1",
            ])

        for cmd in install_cmds:
            out = await self._ssh_exec(ssh, cmd, timeout=600)
            job._log(out[:200] if out else "done")

        # Download base model from HuggingFace (skip if already on network volume)
        hf_repo = model_cfg.get("hf_repo", "SG161222/Realistic_Vision_V5.1_noVAE")
        hf_filename = model_cfg.get("hf_filename", "Realistic_Vision_V5.1_fp16-no-ema.safetensors")
        model_name = model_cfg.get("name", job.base_model)
        job.progress = 0.1
        await self._ssh_exec(ssh, """pip install huggingface_hub 2>&1 | tail -1""", timeout=120)

        if model_type == "flux2":
            # FLUX.2 models are stored in a directory structure on the volume
            flux2_dir = "/workspace/models/FLUX.2-dev"
            dit_path = f"{flux2_dir}/flux2-dev.safetensors"
            vae_path = f"{flux2_dir}/ae.safetensors"  # Original BFL format (not diffusers)
            te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"

            dit_exists = (await self._ssh_exec(ssh, f"test -f {dit_path} && echo EXISTS || echo MISSING")).strip()
            vae_exists = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip()
            te_exists = (await self._ssh_exec(ssh, f"test -f {te_path} && echo EXISTS || echo MISSING")).strip()

            if dit_exists != "EXISTS" or te_exists != "EXISTS":
                missing = []
                if dit_exists != "EXISTS":
                    missing.append("DiT")
                if te_exists != "EXISTS":
                    missing.append("text encoder")
                raise RuntimeError(f"FLUX.2 Dev missing on volume: {', '.join(missing)}. Please download models to the network volume first.")

            # Download ae.safetensors (original format VAE) if not present
            # (huggingface_hub was installed above)
            if vae_exists != "EXISTS":
                job._log("Downloading FLUX.2 VAE (ae.safetensors, 336MB)...")
                await self._ssh_exec(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('black-forest-labs/FLUX.2-dev', 'ae.safetensors', local_dir='{flux2_dir}')
print('Downloaded ae.safetensors')
" 2>&1 | tail -5""", timeout=600)
                # Verify download
                vae_check = (await self._ssh_exec(ssh, f"test -f {vae_path} && echo EXISTS || echo MISSING")).strip()
                if vae_check != "EXISTS":
                    raise RuntimeError("Failed to download ae.safetensors")
                job._log("VAE downloaded")

            job._log("FLUX.2 Dev models ready")

        elif model_type == "wan22":
            # WAN 2.2 T2V — 4 model files stored in /workspace/models/WAN2.2/
            wan_dir = "/workspace/models/WAN2.2"
            await self._ssh_exec(ssh, f"mkdir -p {wan_dir}")

            wan_files = {
                "DiT low-noise": {
                    "path": f"{wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors",
                    "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
                    "filename": "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp16.safetensors",
                },
                "DiT high-noise": {
                    "path": f"{wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors",
                    "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
                    "filename": "split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp16.safetensors",
                },
                "VAE": {
                    "path": f"{wan_dir}/Wan2.1_VAE.pth",
                    "repo": "Wan-AI/Wan2.1-I2V-14B-720P",
                    "filename": "Wan2.1_VAE.pth",
                },
                "T5 text encoder": {
                    "path": f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth",
                    "repo": "Wan-AI/Wan2.1-I2V-14B-720P",
                    "filename": "models_t5_umt5-xxl-enc-bf16.pth",
                },
            }

            for label, info in wan_files.items():
                exists = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    job._log(f"WAN 2.2 {label} already cached")
                else:
                    job._log(f"Downloading WAN 2.2 {label}...")
                    await self._ssh_exec(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}')
# hf_hub_download puts files in subdirs matching the filename path — move to root
import os, shutil
downloaded = os.path.join('{wan_dir}', '{info['filename']}')
target = '{info['path']}'
if os.path.exists(downloaded) and downloaded != target:
    shutil.move(downloaded, target)
print('Downloaded {label}')
" 2>&1 | tail -5""", timeout=1800)
                    # Verify
                    check = (await self._ssh_exec(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
                    if check != "EXISTS":
                        raise RuntimeError(f"Failed to download WAN 2.2 {label}")

            job._log("WAN 2.2 models ready")

        else:
            # SD 1.5 / SDXL / FLUX.1 — download single model file
            model_file = f"/workspace/models/{hf_filename}"
            model_check_cmd = f"test -f {model_file} && echo EXISTS || echo MISSING"
            model_exists = (await self._ssh_exec(ssh, model_check_cmd)).strip()
            if model_exists == "EXISTS":
                job._log(f"Base model already cached on volume: {model_name}")
            else:
                job._log(f"Downloading base model: {model_name}...")
                await self._ssh_exec(ssh, f"""
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('{hf_repo}', '{hf_filename}', local_dir='/workspace/models')
" 2>&1 | tail -5
""", timeout=1200)
                # Verify download (the script's errors are only tailed, never raised)
                if (await self._ssh_exec(ssh, model_check_cmd)).strip() != "EXISTS":
                    raise RuntimeError(f"Failed to download base model {hf_repo}/{hf_filename}")

            # For FLUX.1, download additional required models (CLIP, T5, VAE)
            if model_type == "flux":
                flux_aux_check_cmd = "test -f /workspace/models/clip_l.safetensors && test -f /workspace/models/t5xxl_fp16.safetensors && test -f /workspace/models/ae.safetensors && echo EXISTS || echo MISSING"
                flux_files_check = (await self._ssh_exec(ssh, flux_aux_check_cmd)).strip()
                if flux_files_check == "EXISTS":
                    job._log("FLUX.1 auxiliary models already cached on volume")
                else:
                    job._log("Downloading FLUX.1 auxiliary models (CLIP, T5, VAE)...")
                    job.progress = 0.12
                    await self._ssh_exec(ssh, """
python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('comfyanonymous/flux_text_encoders', 'clip_l.safetensors', local_dir='/workspace/models')
hf_hub_download('comfyanonymous/flux_text_encoders', 't5xxl_fp16.safetensors', local_dir='/workspace/models')
hf_hub_download('black-forest-labs/FLUX.1-dev', 'ae.safetensors', local_dir='/workspace/models')
" 2>&1 | tail -5
""", timeout=1200)
                    if (await self._ssh_exec(ssh, flux_aux_check_cmd)).strip() != "EXISTS":
                        raise RuntimeError("Failed to download FLUX.1 auxiliary models")

        job._log("Base model ready")
        job.progress = 0.15

        # Step 5: Run training
        job.status = "training"
        job._log(f"Starting {model_type.upper()} LoRA training...")

        if model_type == "flux2":
            model_path = "/workspace/models/FLUX.2-dev/flux2-dev.safetensors"
        elif model_type == "wan22":
            model_path = "/workspace/models/WAN2.2/wan2.2_t2v_low_noise_14B_fp16.safetensors"
        else:
            model_path = f"/workspace/models/{hf_filename}"

        # For musubi-tuner, create TOML dataset config
        if training_framework == "musubi-tuner":
            toml_content = f"""[[datasets]]
image_directory = "{remote_dataset_dir}"
caption_extension = ".txt"
batch_size = 1
num_repeats = 10
resolution = [{resolution}, {resolution}]
"""
            await self._ssh_exec(ssh, f"cat > /workspace/dataset.toml << 'TOMLEOF'\n{toml_content}TOMLEOF")
            job._log("Created dataset.toml config")

            # musubi-tuner requires pre-caching latents and text encoder outputs
            if model_type == "wan22":
                wan_dir = "/workspace/models/WAN2.2"
                vae_path = f"{wan_dir}/Wan2.1_VAE.pth"
                te_path = f"{wan_dir}/models_t5_umt5-xxl-enc-bf16.pth"

                job._log("Caching WAN 2.2 latents (VAE encoding)...")
                job.progress = 0.15
                self._schedule_db_save(job)
                cache_latents_cmd = (
                    f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
                    f" python src/musubi_tuner/wan_cache_latents.py"
                    f" --dataset_config /workspace/dataset.toml"
                    f" --vae {vae_path}"
                    f" --vae_dtype bfloat16"
                    f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
                )
                out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
                last_lines = out.split('\n')[-30:]
                job._log('\n'.join(last_lines))
                if "EXIT_CODE=0" not in out:
                    err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
                    job._log(f"Cache error details: {err_log}")
                    raise RuntimeError("WAN latent caching failed")

                job._log("Caching WAN 2.2 text encoder outputs (T5)...")
                job.progress = 0.25
                self._schedule_db_save(job)
                cache_te_cmd = (
                    f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
                    f" python src/musubi_tuner/wan_cache_text_encoder_outputs.py"
                    f" --dataset_config /workspace/dataset.toml"
                    f" --t5 {te_path}"
                    f" --batch_size 16"
                    f" 2>&1; echo EXIT_CODE=$?"
                )
                out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600)
                job._log(out[-500:] if out else "done")
                if "EXIT_CODE=0" not in out:
                    raise RuntimeError(f"WAN text encoder caching failed: {out[-200:]}")
            else:
                # FLUX.2 caching
                flux2_dir = "/workspace/models/FLUX.2-dev"
                vae_path = f"{flux2_dir}/ae.safetensors"
                te_path = f"{flux2_dir}/text_encoder/model-00001-of-00010.safetensors"

                job._log("Caching latents (VAE encoding)...")
                job.progress = 0.15
                self._schedule_db_save(job)
                cache_latents_cmd = (
                    f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python src/musubi_tuner/flux_2_cache_latents.py"
                    f" --dataset_config /workspace/dataset.toml"
                    f" --vae {vae_path}"
                    f" --model_version dev"
                    f" --vae_dtype bfloat16"
                    f" 2>&1 | tee /tmp/cache_latents.log; echo EXIT_CODE=${{PIPESTATUS[0]}}"
                )
                out = await self._ssh_exec(ssh, cache_latents_cmd, timeout=600)
                last_lines = out.split('\n')[-30:]
                job._log('\n'.join(last_lines))
                if "EXIT_CODE=0" not in out:
                    err_log = await self._ssh_exec(ssh, "grep -i 'error\\|exception\\|traceback\\|failed' /tmp/cache_latents.log | tail -10")
                    job._log(f"Cache error details: {err_log}")
                    raise RuntimeError("Latent caching failed")

                job._log("Caching text encoder outputs (bf16)...")
                job.progress = 0.25
                self._schedule_db_save(job)
                cache_te_cmd = (
                    f"cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
                    f" python src/musubi_tuner/flux_2_cache_text_encoder_outputs.py"
                    f" --dataset_config /workspace/dataset.toml"
                    f" --text_encoder {te_path}"
                    f" --model_version dev"
                    f" --batch_size 1"
                    f" 2>&1; echo EXIT_CODE=$?"
                )
                out = await self._ssh_exec(ssh, cache_te_cmd, timeout=600)
                job._log(out[-500:] if out else "done")
                if "EXIT_CODE=0" not in out:
                    raise RuntimeError(f"Text encoder caching failed: {out[-200:]}")

        # Build training command based on model type
        train_cmd = self._build_training_command(
            model_type=model_type,
            model_path=model_path,
            name=name,
            resolution=resolution,
            num_epochs=num_epochs,
            max_train_steps=max_train_steps,
            learning_rate=learning_rate,
            network_rank=network_rank,
            network_alpha=network_alpha,
            optimizer=optimizer,
            save_every_n_epochs=save_every_n_epochs,
            save_every_n_steps=save_every_n_steps,
            model_cfg=model_cfg,
            gpu_type=job.gpu_type,
        )

        # Execute training in a detached process (survives SSH disconnect)
        job._log("Starting training (detached — survives disconnects)...")
        log_file = "/tmp/training.log"
        pid_file = "/tmp/training.pid"
        exit_file = "/tmp/training.exit"
        await self._ssh_exec(ssh, f"rm -f {log_file} {exit_file} {pid_file}")

        # Write training command to a script file (avoids quoting issues with nohup)
        script_file = "/tmp/train.sh"
        await self._ssh_exec(ssh, f"cat > {script_file} << 'TRAINEOF'\n#!/bin/bash\n{train_cmd} > {log_file} 2>&1\necho $? > {exit_file}\nTRAINEOF")
        await self._ssh_exec(ssh, f"chmod +x {script_file}")

        # Verify script was written
        script_check = (await self._ssh_exec(ssh, f"wc -l < {script_file}")).strip()
        job._log(f"Training script written ({script_check} lines)")

        # Launch fully detached: close all FDs so SSH channel doesn't hang
        await self._ssh_exec(
            ssh,
            f"setsid {script_file} </dev/null >/dev/null 2>&1 &\necho $! > {pid_file}",
            timeout=15,
        )
        await asyncio.sleep(3)
        pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip()
        if not pid:
            # Fallback: find the process by script name
            pid = (await self._ssh_exec(ssh, "pgrep -f train.sh 2>/dev/null | head -1")).strip()
        job._log(f"Training PID: {pid}")

        # Verify process is actually running
        if pid:
            running = (await self._ssh_exec(ssh, f"kill -0 {pid} 2>&1 && echo RUNNING || echo DEAD")).strip()
            job._log(f"Process status: {running}")
            if "DEAD" in running:
                # Check if it already wrote an exit code (fast failure)
                early_exit = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
                early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip()
                raise RuntimeError(f"Training process died immediately. Exit: {early_exit}\nLog: {early_log}")
        else:
            early_log = (await self._ssh_exec(ssh, f"cat {log_file} 2>/dev/null | tail -20")).strip()
            raise RuntimeError(f"Failed to start training process.\nLog: {early_log}")

        # Monitor the log file by byte offset (reconnect-safe: only new bytes are read)
        last_offset = 0
        while True:
            # Check if training finished
            exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
            if exit_check:
                exit_code = int(exit_check)
                # Read remaining log; split on \r as well, like the incremental
                # reads below, so progress-bar updates become separate lines.
                remaining = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)
                if remaining:
                    for line in remaining.replace("\r", "\n").split("\n"):
                        line = line.strip()
                        if line:
                            job._log(line)
                            self._parse_progress(job, line)
                if exit_code != 0:
                    raise RuntimeError(f"Training failed with exit code {exit_code}")
                break

            # Read new log output
            try:
                new_output = await self._ssh_exec(ssh, f"tail -c +{last_offset + 1} {log_file} 2>/dev/null", timeout=30)
                if new_output:
                    last_offset += len(new_output.encode("utf-8"))
                    for line in new_output.replace("\r", "\n").split("\n"):
                        line = line.strip()
                        if not line:
                            continue
                        job._log(line)
                        self._parse_progress(job, line)
                    self._schedule_db_save(job)
            except Exception:
                job._log("Log read failed, retrying...")

            await asyncio.sleep(5)

        job._log("Training completed on RunPod!")
        job.progress = 0.9

        # Step 6: Save LoRA to network volume and download locally
        job.status = "downloading"

        # First, copy to network volume for persistence
        job._log("Saving LoRA to network volume...")
        await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
        remote_output = f"/workspace/output/{name}.safetensors"
        # Find the output file (fall back to the last *.safetensors listed)
        check = (await self._ssh_exec(ssh, f"test -f {shlex.quote(remote_output)} && echo EXISTS || echo MISSING")).strip()
        if check != "EXISTS":
            remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip()
            if remote_files:
                remote_output = remote_files.split("\n")[-1].strip()
            else:
                raise RuntimeError("No .safetensors output found")

        volume_lora = f"/runpod-volume/loras/{name}.safetensors"
        await self._ssh_exec(ssh, f"cp {shlex.quote(remote_output)} {shlex.quote(volume_lora)}")
        job._log(f"LoRA saved to volume: {volume_lora}")

        # Also save intermediate checkpoints (step 500, 1000, 1500, etc.).
        # The name is quoted but the glob stays outside the quotes so it still expands.
        checkpoint_files = (await self._ssh_exec(ssh, f"ls /workspace/output/{q_name}-step*.safetensors 2>/dev/null")).strip()
        if checkpoint_files:
            for ckpt in checkpoint_files.split("\n"):
                ckpt = ckpt.strip()
                if ckpt:
                    ckpt_name = ckpt.split("/")[-1]
                    ckpt_dest = f"/runpod-volume/loras/{ckpt_name}"
                    await self._ssh_exec(ssh, f"cp {shlex.quote(ckpt)} {shlex.quote(ckpt_dest)}")
                    job._log(f"Checkpoint saved: {ckpt_dest}")

        # Skip the local download on HF Spaces (limited storage) only when the
        # LoRA is on a persistent network volume. Without one, /runpod-volume is
        # pod-local and disappears when the pod is terminated below.
        if IS_HF_SPACES and NETWORK_VOLUME_ID:
            job.output_path = volume_lora
            job._log("LoRA saved on RunPod volume (ready for generation)")
        else:
            job._log("Downloading LoRA to local machine...")
            LORA_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
            local_path = LORA_OUTPUT_DIR / f"{name}.safetensors"
            await asyncio.to_thread(sftp.get, remote_output, str(local_path))
            job.output_path = str(local_path)
            job._log(f"LoRA saved locally to {local_path}")

        # Done!
        job.status = "completed"
        job.progress = 1.0
        job.completed_at = time.time()
        elapsed = (job.completed_at - job.started_at) / 60
        job._log(f"Cloud training complete in {elapsed:.1f} minutes")

    except Exception as e:
        job.status = "failed"
        job.error = str(e)
        job._log(f"ERROR: {e}")
        logger.error("Cloud training failed: %s", e, exc_info=True)

    finally:
        # Cleanup: close SSH and terminate pod
        if sftp is not None:
            try:
                sftp.close()
            except Exception:
                pass
        if ssh is not None:
            try:
                ssh.close()
            except Exception:
                pass
        # Compressed-image temp dir (still set only if the upload step was interrupted)
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
        # Clean up local training images (saves HF Spaces storage)
        if image_paths:
            first_image_dir = Path(image_paths[0]).parent
            if first_image_dir.exists() and "training_uploads" in str(first_image_dir):
                shutil.rmtree(first_image_dir, ignore_errors=True)
        if job.pod_id:
            try:
                job._log("Terminating RunPod...")
                await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
                job._log("Pod terminated")
            except Exception as e:
                job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
def _schedule_db_save(self, job: CloudTrainingJob):
    """Schedule a non-blocking save of *job* on the running event loop.

    Does nothing when there is no running loop (e.g. when called from a worker
    thread). The deprecated ``get_event_loop()`` could return, or create, a loop
    that never runs, leaving the save coroutine un-awaited.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # no running event loop
    task = loop.create_task(self._save_to_db(job))
    # The loop holds only weak references to tasks; keep one until it finishes.
    pending = self.__dict__.setdefault("_background_tasks", set())
    pending.add(task)
    task.add_done_callback(pending.discard)
async def _save_to_db(self, job: CloudTrainingJob):
    """Write *job*'s current state to the ``training_jobs`` table.

    Uses SQLite ``INSERT OR REPLACE`` so the row is created or overwritten,
    carrying ``created_at`` over from any existing row with the same id. Only
    the last 200 log lines are stored. Any failure is logged as a warning and
    not raised.
    """
    try:
        from sqlalchemy import text

        async with catalog_session_factory() as session:
            upsert = text("""INSERT OR REPLACE INTO training_jobs
(id, name, status, progress, current_epoch, total_epochs,
current_step, total_steps, loss, started_at, completed_at,
output_path, error, log_text, pod_id, gpu_type, backend,
base_model, model_type, created_at)
VALUES (:id, :name, :status, :progress, :current_epoch, :total_epochs,
:current_step, :total_steps, :loss, :started_at, :completed_at,
:output_path, :error, :log_text, :pod_id, :gpu_type, :backend,
:base_model, :model_type, COALESCE((SELECT created_at FROM training_jobs WHERE id = :id), CURRENT_TIMESTAMP))
""")
            row = {
                "id": job.id,
                "name": job.name,
                "status": job.status,
                "progress": job.progress,
                "current_epoch": job.current_epoch,
                "total_epochs": job.total_epochs,
                "current_step": job.current_step,
                "total_steps": job.total_steps,
                "loss": job.loss,
                "started_at": job.started_at,
                "completed_at": job.completed_at,
                "output_path": job.output_path,
                "error": job.error,
                "log_text": "\n".join(job.log_lines[-200:]),
                "pod_id": job.pod_id,
                "gpu_type": job.gpu_type,
                "backend": "runpod",
                "base_model": job.base_model,
                "model_type": job.model_type,
            }
            await session.execute(upsert, row)
            await session.commit()
    except Exception as e:
        logger.warning("Failed to save training job to DB: %s", e)
async def _load_jobs_from_db(self):
    """Restore the 20 most recently created jobs from the database.

    Jobs already tracked in memory are skipped. Each restored job gets the
    DB-save callback so later log lines and state changes are persisted.
    Unfinished jobs are reconciled against RunPod:

    * the pod is still RUNNING -> marked ``training``, and a background
      reconnect task is started;
    * the pod is gone or could not be checked, or no pod was ever created ->
      marked ``failed``.

    The reconciled status is saved back, so the next restart does not
    re-check the same job. Errors are logged as warnings and not raised.
    """
    try:
        from sqlalchemy import select

        restored: list[CloudTrainingJob] = []
        async with catalog_session_factory() as session:
            result = await session.execute(
                select(TrainingJobDB).order_by(TrainingJobDB.created_at.desc()).limit(20)
            )
            for db_job in result.scalars().all():
                if db_job.id in self._jobs:
                    continue
                restored.append(CloudTrainingJob(
                    id=db_job.id,
                    name=db_job.name,
                    status=db_job.status,
                    progress=db_job.progress or 0.0,
                    current_epoch=db_job.current_epoch or 0,
                    total_epochs=db_job.total_epochs or 0,
                    current_step=db_job.current_step or 0,
                    total_steps=db_job.total_steps or 0,
                    loss=db_job.loss,
                    started_at=db_job.started_at,
                    completed_at=db_job.completed_at,
                    output_path=db_job.output_path,
                    error=db_job.error,
                    log_lines=db_job.log_text.split("\n") if db_job.log_text else [],
                    pod_id=db_job.pod_id,
                    gpu_type=db_job.gpu_type or DEFAULT_GPU,
                    base_model=db_job.base_model or "sd15_realistic",
                    model_type=db_job.model_type or "sd15",
                ))

        # Reconcile after the session is closed. RunPod API calls can be slow,
        # and the DB saves scheduled below should not contend with an open session.
        background = self.__dict__.setdefault("_background_tasks", set())
        for job in restored:
            self._jobs[job.id] = job
            job._db_callback = self._schedule_db_save
            if job.status in ("completed", "failed"):
                continue

            if not job.pod_id:
                job.status = "failed"
                job.error = "Interrupted by server restart"
            else:
                # Try to reconnect to running training pods
                try:
                    pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
                    if pod and pod.get("desiredStatus") == "RUNNING":
                        job.status = "training"
                        job.error = None
                        job._log("Reconnecting to running training pod after restart...")
                        task = asyncio.create_task(self._reconnect_training(job))
                        # Keep a strong reference; the loop only holds weak ones.
                        background.add(task)
                        task.add_done_callback(background.discard)
                        logger.info("Reconnecting to training pod %s for job %s", job.pod_id, job.id)
                    else:
                        job.status = "failed"
                        job.error = "Pod terminated during server restart"
                except Exception as e:
                    logger.warning("Could not check pod %s: %s", job.pod_id, e)
                    job.status = "failed"
                    job.error = "Interrupted by server restart"

            # Persist the reconciled status.
            self._schedule_db_save(job)
    except Exception as e:
        logger.warning("Failed to load training jobs from DB: %s", e)
async def ensure_loaded(self):
    """Populate the in-memory job table from the database on first use.

    The loaded flag is set *before* the load is awaited, so any later call
    returns immediately instead of triggering a second load.
    """
    if self._loaded_from_db:
        return
    self._loaded_from_db = True
    await self._load_jobs_from_db()
async def _reconnect_training(self, job: CloudTrainingJob):
    """Re-attach to a RunPod training pod after a server restart.

    Looks up the public address mapped to the pod's port 22, connects over
    SSH as ``root``, then checks ``/tmp/training.exit``:

    - If it already holds an exit code, the last 50 log lines are ingested
      and the result is collected.
    - Otherwise ``/tmp/training.log`` is tailed every 5 s (starting from the
      beginning of the file) until the exit file appears.

    On exit code 0 the newest ``/workspace/output/*.safetensors`` is copied
    to ``/runpod-volume/loras/<job.name>.safetensors`` and the job is marked
    ``completed``. A non-zero exit code or any error marks the job
    ``failed``. In every case the SSH connection is closed, the pod is
    terminated and the job is persisted.
    """
    import paramiko

    log_file = "/tmp/training.log"
    exit_file = "/tmp/training.exit"
    pid_file = "/tmp/training.pid"
    ssh = None
    try:
        # Find the public endpoint mapped to the pod's SSH port (22).
        pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
        if not pod:
            raise RuntimeError("Pod not found")
        runtime = pod.get("runtime") or {}
        ssh_host = ssh_port = None
        for port_info in runtime.get("ports") or []:
            if port_info.get("privatePort") == 22:
                ssh_host = port_info.get("ip")
                ssh_port = port_info.get("publicPort")
        if not ssh_host or not ssh_port:
            raise RuntimeError("SSH port not available")

        ssh = paramiko.SSHClient()
        # NOTE(review): accepts any host key and uses fixed root/runpod
        # credentials; presumably acceptable because pods are short-lived.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        await asyncio.to_thread(
            ssh.connect, ssh_host, port=int(ssh_port),
            username="root", password="runpod", timeout=10,
        )
        transport = ssh.get_transport()
        if transport is not None:
            transport.set_keepalive(30)
        job._log(f"Reconnected to pod {job.pod_id}")

        exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
        if exit_check:
            # Training already finished while we were disconnected.
            exit_code = int(exit_check)
            log_tail = await self._ssh_exec(ssh, f"tail -50 {log_file} 2>/dev/null")
            self._reconnect_ingest_log(job, log_tail)
            if exit_code != 0:
                raise RuntimeError(f"Training failed with exit code {exit_code}")
            job._log("Training completed while disconnected!")
            await self._reconnect_save_lora(ssh, job)
        else:
            # Training still running: resume log monitoring.
            pid = (await self._ssh_exec(ssh, f"cat {pid_file} 2>/dev/null")).strip()
            job._log(f"Training still running (PID: {pid}), resuming monitoring...")
            last_offset = 0
            while True:
                exit_check = (await self._ssh_exec(ssh, f"cat {exit_file} 2>/dev/null")).strip()
                if exit_check:
                    exit_code = int(exit_check)
                    # Drain whatever was written after the previous poll.
                    await self._reconnect_read_new_log(ssh, job, log_file, last_offset)
                    if exit_code != 0:
                        raise RuntimeError(f"Training failed with exit code {exit_code}")
                    await self._reconnect_save_lora(ssh, job)
                    break
                try:
                    new_offset = await self._reconnect_read_new_log(ssh, job, log_file, last_offset)
                except Exception as e:
                    # Transient read failures are retried on the next poll.
                    logger.debug("Log poll failed for job %s: %s", job.id, e)
                else:
                    if new_offset != last_offset:
                        last_offset = new_offset
                        self._schedule_db_save(job)
                await asyncio.sleep(5)
        job._log("Training complete!")
    except Exception as e:
        job.status = "failed"
        job.error = str(e)
        job._log(f"Reconnect failed: {e}")
        logger.error("Training reconnect failed for %s: %s", job.id, e)
    finally:
        if ssh is not None:
            try:
                ssh.close()
            except Exception:
                pass  # best effort; the pod is terminated below anyway
        if job.pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
                job._log("Pod terminated")
            except Exception as e:
                job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
        self._schedule_db_save(job)

def _reconnect_ingest_log(self, job: CloudTrainingJob, text: str) -> None:
    """Append each non-empty line of remote log ``text`` to ``job`` and parse progress.

    Carriage returns (progress-bar redraws) are treated as line breaks, so
    every progress update is seen by ``_parse_progress``.
    """
    for raw_line in text.replace("\r", "\n").split("\n"):
        line = raw_line.strip()
        if line:
            job._log(line)
            self._parse_progress(job, line)

async def _reconnect_read_new_log(self, ssh, job: CloudTrainingJob, log_file: str, offset: int) -> int:
    """Ingest the bytes appended to remote ``log_file`` after byte ``offset``.

    Returns the offset to use on the next poll. The offset comes from the
    remote file size (``wc -c``), not from the length of the returned text.
    ``_ssh_exec`` strips whitespace and decodes with replacement, so counting
    the returned text would drift and re-read lines.
    """
    size_out = await self._ssh_exec(ssh, f"wc -c < {log_file} 2>/dev/null || echo 0", timeout=30)
    try:
        size = int(size_out.strip() or 0)
    except ValueError:
        return offset
    if size < offset:
        offset = 0  # log was truncated or recreated: start from the top
    if size == offset:
        return offset
    chunk = await self._ssh_exec(
        ssh,
        f"tail -c +{offset + 1} {log_file} 2>/dev/null | head -c {size - offset}",
        timeout=30,
    )
    self._reconnect_ingest_log(job, chunk)
    return size

async def _reconnect_save_lora(self, ssh, job: CloudTrainingJob) -> None:
    """Copy the newest LoRA from /workspace/output to the volume and mark ``job`` completed.

    The last entry of ``ls /workspace/output/*.safetensors`` is copied to
    ``/runpod-volume/loras/<job.name>.safetensors``. If no output file
    exists, a warning is logged and the job is still marked completed, but
    without an ``output_path``.
    """
    import shlex

    dest = f"/runpod-volume/loras/{job.name}.safetensors"
    await self._ssh_exec(ssh, "mkdir -p /runpod-volume/loras")
    remote_files = (await self._ssh_exec(ssh, "ls /workspace/output/*.safetensors 2>/dev/null")).strip()
    if remote_files:
        remote_output = remote_files.split("\n")[-1].strip()
        # Quote both paths: the job name is user-supplied.
        await self._ssh_exec(ssh, f"cp {shlex.quote(remote_output)} {shlex.quote(dest)}")
        job._log(f"LoRA saved to volume: {dest}")
        job.output_path = dest
    else:
        job._log("Warning: no .safetensors output found in /workspace/output")
    job.status = "completed"
    job.progress = 1.0
    job.completed_at = time.time()
def _build_training_command(
    self,
    *,
    model_type: str,
    model_path: str,
    name: str,
    resolution: int,
    num_epochs: int,
    max_train_steps: int | None,
    learning_rate: float,
    network_rank: int,
    network_alpha: int,
    optimizer: str,
    save_every_n_epochs: int,
    save_every_n_steps: int = 500,
    model_cfg: dict,
    gpu_type: str = "",
) -> str:
    """Build the single-line shell command that runs LoRA training on the pod.

    ``model_type`` selects the trainer:

    - ``"flux2"`` and ``"wan22"`` use musubi-tuner (``/workspace/musubi-tuner``).
    - ``"flux"``, ``"sdxl"`` and anything else (SD 1.5) use Kohya sd-scripts
      (``/workspace/sd-scripts``).

    Training length is ``max_train_steps`` when truthy, otherwise
    ``num_epochs``. Checkpoint cadence:

    - musubi runs in step mode save every ``save_every_n_steps`` steps, or
      every ``save_every_n_epochs`` epochs when that is 0/None;
    - musubi runs in epoch mode, and all Kohya runs, save every
      ``save_every_n_epochs`` epochs.

    ``model_cfg`` may override ``lr_scheduler``, ``network_module``,
    ``timestep_sampling``, ``discrete_flow_shift``, ``clip_skip``,
    ``text_encoder_lr`` and ``min_snr_gamma``.

    The command is built as one line because a bare newline would end the
    shell command early. stderr is merged into stdout (``2>&1``), and the
    caller-supplied ``name`` and ``model_path`` are shell-quoted.
    """
    import shlex

    quote = shlex.quote

    def musubi_schedule() -> list[str]:
        # Step mode prefers step checkpoints, falling back to per-epoch saves.
        if max_train_steps and save_every_n_steps:
            return [f"--max_train_steps={max_train_steps}", f"--save_every_n_steps={save_every_n_steps}"]
        if max_train_steps:
            return [f"--max_train_steps={max_train_steps}", f"--save_every_n_epochs={save_every_n_epochs}"]
        return [f"--max_train_epochs={num_epochs}", f"--save_every_n_epochs={save_every_n_epochs}"]

    if model_type == "flux2":
        # FLUX.2 training via musubi-tuner
        flux2_dir = "/workspace/models/FLUX.2-dev"
        network_mod = model_cfg.get("network_module", "networks.lora_flux_2")
        ts_sampling = model_cfg.get("timestep_sampling", "flux2_shift")
        lr_scheduler = model_cfg.get("lr_scheduler", "cosine")
        args = [
            "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
            "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16",
            "src/musubi_tuner/flux_2_train_network.py",
            "--model_version dev",
            f"--dit {flux2_dir}/flux2-dev.safetensors",
            f"--vae {flux2_dir}/ae.safetensors",
            f"--text_encoder {flux2_dir}/text_encoder/model-00001-of-00010.safetensors",
            "--dataset_config /workspace/dataset.toml",
            "--sdpa --mixed_precision bf16",
            f"--timestep_sampling {ts_sampling} --weighting_scheme none",
            f"--network_module {network_mod}",
            f"--network_dim={network_rank}",
            f"--network_alpha={network_alpha}",
            "--gradient_checkpointing",
        ]
        # fp8_base is only enabled on GPUs from this list (RTX 4090/5090, L40S,
        # H100); the original notes A100/A6000 lack fp8 tensor ops but have
        # enough VRAM without it.
        if gpu_type and any(tag in gpu_type for tag in ("4090", "5090", "L40S", "H100")):
            args.append("--fp8_base")
        if optimizer.lower() == "prodigy":
            # Prodigy needs its full class path plus tuned optimizer args
            args += [
                "--optimizer_type=prodigyopt.Prodigy",
                f"--learning_rate={learning_rate}",
                '--optimizer_args "weight_decay=0.01" "decouple=True" "use_bias_correction=True" "safeguard_warmup=True" "d_coef=2"',
            ]
        else:
            args += [f"--optimizer_type={optimizer}", f"--learning_rate={learning_rate}"]
        args += [
            "--seed=42",
            "--output_dir=/workspace/output",
            f"--output_name={quote(name)}",
            f"--lr_scheduler={lr_scheduler}",
        ]
        args += musubi_schedule()
        return " ".join(args) + " 2>&1"

    if model_type == "wan22":
        # WAN 2.2 T2V LoRA training via musubi-tuner
        wan_dir = "/workspace/models/WAN2.2"
        network_mod = model_cfg.get("network_module", "networks.lora_wan")
        ts_sampling = model_cfg.get("timestep_sampling", "shift")
        discrete_shift = model_cfg.get("discrete_flow_shift", 5.0)
        args = [
            "cd /workspace/musubi-tuner && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True",
            "accelerate launch --num_cpu_threads_per_process 1 --mixed_precision fp16",
            "src/musubi_tuner/wan_train_network.py",
            "--task t2v-A14B",
            f"--dit {wan_dir}/wan2.2_t2v_low_noise_14B_fp16.safetensors",
            f"--dit_high_noise {wan_dir}/wan2.2_t2v_high_noise_14B_fp16.safetensors",
            "--dataset_config /workspace/dataset.toml",
            "--sdpa --mixed_precision fp16",
            "--gradient_checkpointing",
            f"--timestep_sampling {ts_sampling}",
            f"--discrete_flow_shift {discrete_shift}",
            f"--network_module {network_mod}",
            f"--network_dim={network_rank}",
            f"--network_alpha={network_alpha}",
            f"--optimizer_type={optimizer}",
            f"--learning_rate={learning_rate}",
            "--seed=42",
            "--output_dir=/workspace/output",
            f"--output_name={quote(name)}",
        ]
        args += musubi_schedule()
        return " ".join(args) + " 2>&1"

    # Kohya sd-scripts: options shared by FLUX.1, SDXL and SD 1.5.
    length_arg = f"--max_train_steps={max_train_steps}" if max_train_steps else f"--max_train_epochs={num_epochs}"
    kohya_lr_scheduler = model_cfg.get("lr_scheduler", "cosine_with_restarts")
    common_args = [
        "--train_data_dir=/workspace/dataset",
        "--output_dir=/workspace/output",
        f"--output_name={quote(name)}",
        f"--resolution={resolution}",
        "--train_batch_size=1",
        f"--learning_rate={learning_rate}",
        "--network_module=networks.lora",
        f"--network_dim={network_rank}",
        f"--network_alpha={network_alpha}",
        f"--optimizer_type={optimizer}",
        f"--save_every_n_epochs={save_every_n_epochs}",
        "--mixed_precision=fp16",
        "--seed=42",
        "--caption_extension=.txt",
        "--cache_latents",
        "--enable_bucket",
        "--save_model_as=safetensors",
        length_arg,
        f"--lr_scheduler={kohya_lr_scheduler}",
    ]

    if model_type == "flux":
        # FLUX.1 training via sd-scripts
        script = "flux_train_network.py"
        text_encoder_lr = model_cfg.get("text_encoder_lr", 4e-5)
        model_args = [
            f"--pretrained_model_name_or_path={quote(model_path)}",
            "--clip_l=/workspace/models/clip_l.safetensors",
            "--t5xxl=/workspace/models/t5xxl_fp16.safetensors",
            "--ae=/workspace/models/ae.safetensors",
            "--cache_text_encoder_outputs",
            "--cache_text_encoder_outputs_to_disk",
            "--fp8_base",
            "--split_mode",
            f"--text_encoder_lr={text_encoder_lr}",
        ]
        min_snr = model_cfg.get("min_snr_gamma")
        if min_snr:
            model_args.append(f"--min_snr_gamma={min_snr}")
    elif model_type == "sdxl":
        # SDXL-specific training
        script = "sdxl_train_network.py"
        clip_skip = model_cfg.get("clip_skip", 2)
        model_args = [
            f"--pretrained_model_name_or_path={quote(model_path)}",
            f"--clip_skip={clip_skip}",
            "--xformers",
        ]
    else:
        # SD 1.5 / default training
        script = "train_network.py"
        clip_skip = model_cfg.get("clip_skip", 1)
        model_args = [
            f"--pretrained_model_name_or_path={quote(model_path)}",
            f"--clip_skip={clip_skip}",
            "--xformers",
        ]

    parts = [
        "cd /workspace/sd-scripts && accelerate launch --num_cpu_threads_per_process 1",
        script,
        *model_args,
        *common_args,
    ]
    return " ".join(parts) + " 2>&1"
async def _wait_for_pod_ready(self, job: CloudTrainingJob, timeout: int = 600) -> tuple[str, int]:
    """Poll RunPod until ``job.pod_id`` is running and exposes a public SSH port.

    Polls every 5 s (10 s after an API error) and logs a status line about
    every 30 s.

    Returns:
        ``(ssh_host, ssh_port)``: the public IP and port mapped to the pod's
        private port 22.

    Raises:
        RuntimeError: if RunPod does not return the pod, or if no SSH
            endpoint appears within ``timeout`` seconds.
    """
    start = time.monotonic()  # monotonic: immune to wall-clock adjustments
    while time.monotonic() - start < timeout:
        try:
            pod = await asyncio.to_thread(runpod.get_pod, job.pod_id)
        except Exception as e:
            job._log(f" API error: {e}")
            await asyncio.sleep(10)
            continue
        if not pod:
            # Fail clearly instead of crashing on pod.get() below.
            raise RuntimeError(f"Pod {job.pod_id} not found")
        status = pod.get("desiredStatus", "")
        runtime = pod.get("runtime")
        if status == "RUNNING" and runtime:
            for port_info in runtime.get("ports") or []:
                if port_info.get("privatePort") != 22:
                    continue
                ip = port_info.get("ip")
                public_port = port_info.get("publicPort")
                if ip and public_port:
                    return ip, int(public_port)
        elapsed = int(time.monotonic() - start)
        if elapsed % 30 < 6:  # with 5 s polls this logs roughly once per 30 s
            job._log(f" Status: {status} | runtime: {'ports pending' if runtime else 'not ready yet'} ({elapsed}s)")
        await asyncio.sleep(5)
    raise RuntimeError(f"Pod did not become ready within {timeout}s")
def _ssh_exec_sync(self, ssh, cmd: str, timeout: int = 120) -> str:
    """Run ``cmd`` on the remote host and block until it finishes.

    stdout and then stderr are read to completion before the exit status is
    collected. A non-zero exit is only logged, and only when stderr does not
    mention a "warning"; it is never raised. Returns stdout with surrounding
    whitespace stripped.
    """
    _stdin, out_stream, err_stream = ssh.exec_command(cmd, timeout=timeout)
    stdout_text = out_stream.read().decode("utf-8", errors="replace")
    stderr_text = err_stream.read().decode("utf-8", errors="replace")
    status = out_stream.channel.recv_exit_status()
    if status == 0 or "warning" in stderr_text.lower():
        return stdout_text.strip()
    logger.warning("SSH cmd failed (code %d): %s\nstderr: %s", status, cmd[:100], stderr_text[:500])
    return stdout_text.strip()
async def _ssh_exec(self, ssh, cmd: str, timeout: int = 120) -> str:
    """Awaitable wrapper around ``_ssh_exec_sync``.

    The blocking SSH call runs in a worker thread so the event loop stays
    responsive. Returns the same stripped stdout.
    """
    blocking_call = self._ssh_exec_sync
    output = await asyncio.to_thread(blocking_call, ssh, cmd, timeout)
    return output
def _parse_progress(self, job: CloudTrainingJob, line: str):
    """Update ``job``'s progress fields from one line of trainer output.

    Recognised patterns (several may appear in the same line):

    - ``epoch N/M``, case-insensitive, tolerating separators such as
      ``Epoch [2/4],``. Sets ``current_epoch`` and ``total_epochs``, and maps
      ``progress`` linearly onto 0.15-0.90.
    - ``loss=X`` / ``loss:X``, including tqdm postfixes like
      ``avr_loss=0.09]``. Sets ``loss`` from the first match.
    - Lines containing ``steps:`` or ``step ``. The first ``N/M`` pair sets
      ``current_step`` and ``total_steps``.

    Lines that don't match are ignored.
    """
    import re

    lower = line.lower()

    epoch_match = re.search(r"epoch\W*(\d+)\s*/\s*(\d+)", lower)
    if epoch_match:
        job.current_epoch = int(epoch_match.group(1))
        job.total_epochs = int(epoch_match.group(2))
        job.progress = 0.15 + 0.75 * (job.current_epoch / max(job.total_epochs, 1))

    # Stop the value at whitespace, ',' or a closing bracket, so tqdm
    # postfixes like "avr_loss=0.0912]" parse.
    loss_match = re.search(r"loss[=:]\s*([^\s,\]\)]+)", line)
    if loss_match:
        try:
            job.loss = float(loss_match.group(1))
        except ValueError:
            pass

    if "steps:" in lower or "step " in lower:
        step_match = re.search(r"(\d+)/(\d+)", line)
        if step_match:
            job.current_step = int(step_match.group(1))
            job.total_steps = int(step_match.group(2))
def get_job(self, job_id: str) -> CloudTrainingJob | None:
    """Return the in-memory job registered under ``job_id``, or ``None`` if unknown."""
    jobs = self._jobs
    return jobs[job_id] if job_id in jobs else None
def list_jobs(self) -> list[CloudTrainingJob]:
    """Return a new list holding every job currently kept in memory."""
    return [*self._jobs.values()]
async def cancel_job(self, job_id: str) -> bool:
    """Cancel a cloud training job and terminate its RunPod pod.

    Returns ``False`` if no job with ``job_id`` is known, otherwise ``True``.
    The pod, if any, is always asked to terminate; a failure is logged on the
    job. Only a job that is not already ``completed``/``failed`` is marked
    failed with "Cancelled by user". This way a finished run's status and
    real error message are not overwritten. The resulting state is persisted.
    """
    job = self._jobs.get(job_id)
    if not job:
        return False
    if job.pod_id:
        try:
            await asyncio.to_thread(runpod.terminate_pod, job.pod_id)
        except Exception as e:
            # Best effort: still mark the job cancelled, but leave a trace.
            job._log(f"Warning: Failed to terminate pod {job.pod_id}: {e}")
    if job.status not in ("completed", "failed"):
        job.status = "failed"
        job.error = "Cancelled by user"
    self._schedule_db_save(job)
    return True
async def delete_job(self, job_id: str) -> bool:
    """Remove a training job from memory and delete its database row.

    Returns ``False`` when ``job_id`` is not among the in-memory jobs; the
    database is not touched in that case. Otherwise returns ``True``.
    Database failures are logged as a warning and do not change the return
    value.

    NOTE(review): this neither terminates the job's pod nor stops a
    monitoring task (e.g. ``_reconnect_training``) that still holds the job
    object and may save it again. Confirm callers only delete finished jobs.
    """
    if job_id not in self._jobs:
        return False
    del self._jobs[job_id]
    try:
        from sqlalchemy import select

        async with catalog_session_factory() as session:
            result = await session.execute(
                select(TrainingJobDB).where(TrainingJobDB.id == job_id)
            )
            db_job = result.scalar_one_or_none()
            if db_job is not None:
                await session.delete(db_job)
                await session.commit()
    except Exception as e:
        logger.warning("Failed to delete job from DB: %s", e)
    return True
async def delete_failed_jobs(self) -> int:
    """Delete every job whose status is ``failed`` or ``error``.

    The matching ids are collected first, so ``self._jobs`` is not mutated
    while it is being iterated. Returns how many jobs were selected.
    """
    failure_states = {"failed", "error"}
    doomed = [job_id for job_id, job in self._jobs.items() if job.status in failure_states]
    for job_id in doomed:
        await self.delete_job(job_id)
    return len(doomed)