dippoo's picture
Sync all local changes: video routes, pod management, wavespeed, UI updates
e808ae1
raw
history blame
80.2 kB
"""RunPod Pod management routes — start/stop GPU pods for generation.
Starts a persistent ComfyUI pod with network volume access.
Models and LoRAs are loaded from the shared network volume.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
import uuid
from pathlib import Path
from typing import Any
import runpod
from fastapi import APIRouter, File, HTTPException, UploadFile
from pydantic import BaseModel
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Every route registered on this router is served under the /api/pod prefix.
router = APIRouter(prefix="/api/pod", tags=["pod"])
# Persist pod state to disk so it survives server restarts.
# The path is resolved by walking four `.parent` steps up from this file.
_POD_STATE_FILE = Path(__file__).parent.parent.parent.parent / "pod_state.json"
def _save_pod_state():
    """Write the in-memory pod state to ``_POD_STATE_FILE`` as JSON.

    The transient ``setup_status`` entry is left out of the persisted data.
    This is best-effort: any failure is logged as a warning and swallowed so
    that callers are never interrupted by a persistence problem.
    """
    try:
        persistable = dict(_pod_state)
        persistable.pop("setup_status", None)
        serialized = json.dumps(persistable)
        _POD_STATE_FILE.write_text(serialized)
    except Exception as exc:
        logger.warning("Failed to save pod state: %s", exc)
def _load_pod_state():
    """Restore pod state from ``_POD_STATE_FILE`` if the file exists.

    Only keys already present in ``_pod_state`` are restored; unknown keys in
    the file are ignored. Any failure (unreadable file, invalid JSON, ...) is
    logged as a warning and the in-memory defaults are kept.
    """
    try:
        if not _POD_STATE_FILE.exists():
            return
        saved = json.loads(_POD_STATE_FILE.read_text())
        recognised = {key: value for key, value in saved.items() if key in _pod_state}
        _pod_state.update(recognised)
        logger.info(
            "Restored pod state: pod_id=%s status=%s",
            _pod_state.get("pod_id"),
            _pod_state.get("status"),
        )
    except Exception as exc:
        logger.warning("Failed to load pod state: %s", exc)
def _get_volume_config() -> tuple[str, str]:
    """Read the network volume settings from the environment at call time.

    Reading lazily (rather than at import) means values loaded by dotenv
    after this module was imported are still picked up.

    Returns:
        ``(volume_id, data_center_id)`` taken from ``RUNPOD_VOLUME_ID`` and
        ``RUNPOD_VOLUME_DC``; each is an empty string when unset.
    """
    volume_id = os.environ.get("RUNPOD_VOLUME_ID", "")
    data_center = os.environ.get("RUNPOD_VOLUME_DC", "")
    return volume_id, data_center
# Docker image — PyTorch base with CUDA, we install ComfyUI ourselves
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"
# Pod state (module-level, shared by every route in this file).
# All keys that other code writes are declared here, because
# `_load_pod_state` only restores keys that already exist in this dict.
_pod_state = {
    "pod_id": None,
    "status": "stopped",  # stopped, starting, setting_up, running, stopping
    "ip": None,  # address shown to the user (SSH IP, else ComfyUI IP)
    "ssh_ip": None,
    "ssh_port": None,
    "comfyui_ip": None,
    "comfyui_port": None,
    "gpu_type": "NVIDIA RTX A6000",
    "model_type": "flux2",
    "started_at": None,  # time.time() at pod creation
    "cost_per_hour": 0.76,
    "setup_status": None,  # human-readable progress text; never persisted
}
# Overlay any state saved by a previous server process.
_load_pod_state()
# GPU options (same as training)
# Keyed by the GPU type string that is passed to RunPod when creating a pod.
# "name" is the display label, "vram" the memory size matching that label,
# and "cost" is copied into the pod state as cost_per_hour when a pod starts.
GPU_OPTIONS = {
    "NVIDIA A40": {"name": "A40 48GB", "vram": 48, "cost": 0.64},
    "NVIDIA RTX A6000": {"name": "RTX A6000 48GB", "vram": 48, "cost": 0.76},
    "NVIDIA L40": {"name": "L40 48GB", "vram": 48, "cost": 0.89},
    "NVIDIA L40S": {"name": "L40S 48GB", "vram": 48, "cost": 1.09},
    "NVIDIA A100-SXM4-80GB": {"name": "A100 SXM 80GB", "vram": 80, "cost": 1.64},
    "NVIDIA A100 80GB PCIe": {"name": "A100 PCIe 80GB", "vram": 80, "cost": 1.89},
    "NVIDIA H100 80GB HBM3": {"name": "H100 80GB", "vram": 80, "cost": 3.89},
    "NVIDIA GeForce RTX 5090": {"name": "RTX 5090 32GB", "vram": 32, "cost": 0.69},
    "NVIDIA GeForce RTX 4090": {"name": "RTX 4090 24GB", "vram": 24, "cost": 0.44},
    "NVIDIA GeForce RTX 3090": {"name": "RTX 3090 24GB", "vram": 24, "cost": 0.22},
}
def _get_comfyui_url() -> str | None:
    """Build the public ComfyUI URL for the current pod.

    RunPod HTTP ports are only reachable through the RunPod proxy at
    https://{pod_id}-{private_port}.proxy.runpod.net; the raw IP:port
    reported by the API is an internal address and not publicly routable.

    Returns:
        The proxy URL for private port 8188, or ``None`` when no pod id is
        recorded in the pod state.
    """
    pod_id = _pod_state.get("pod_id")
    if not pod_id:
        return None
    return f"https://{pod_id}-8188.proxy.runpod.net"
def _get_api_key() -> str:
    """Return the RunPod API key and install it on the ``runpod`` module.

    Returns:
        The value of the ``RUNPOD_API_KEY`` environment variable.

    Raises:
        HTTPException: 503 when the variable is unset or empty.
    """
    api_key = os.environ.get("RUNPOD_API_KEY")
    if api_key:
        runpod.api_key = api_key
        return api_key
    raise HTTPException(503, "RUNPOD_API_KEY not configured")
class StartPodRequest(BaseModel):
    """Request body for POST /api/pod/start."""

    # GPU type to rent; start_pod rejects values that are not keys of GPU_OPTIONS.
    gpu_type: str = "NVIDIA RTX A6000"
    # Which model family the pod setup should prepare (e.g. "flux2", "wan22").
    model_type: str = "flux2"
class PodStatus(BaseModel):
    """Response model for GET /api/pod/status."""

    # One of the status strings kept in the module-level pod state.
    status: str
    pod_id: str | None = None
    ip: str | None = None
    # Filled from the pod state's "comfyui_port" entry.
    port: int | None = None
    gpu_type: str | None = None
    model_type: str | None = None
    cost_per_hour: float | None = None
    # Human-readable setup progress message.
    setup_status: str | None = None
    # Minutes since the pod was started; only set while running or setting up.
    uptime_minutes: float | None = None
    # RunPod proxy URL for ComfyUI, when a pod id is known.
    comfyui_url: str | None = None
@router.get("/status", response_model=PodStatus)
async def get_pod_status():
    """Get current pod status.

    When a pod id is recorded, RunPod is queried (10 second timeout) and the
    cached state is refreshed:

    * desired status ``RUNNING``: SSH (22) and ComfyUI (8188) port mappings
      are copied into the pod state.
    * desired status ``EXITED`` or pod not found: the pod is marked stopped,
      its id is cleared, and the change is persisted so a server restart
      does not resurrect the stale pod id.

    API timeouts and errors are logged and the last known state is returned.

    Raises:
        HTTPException: 503 when ``RUNPOD_API_KEY`` is not configured.
    """
    _get_api_key()
    if _pod_state["pod_id"]:
        try:
            pod = await asyncio.wait_for(
                asyncio.to_thread(runpod.get_pod, _pod_state["pod_id"]),
                timeout=10,
            )
            desired = pod.get("desiredStatus", "") if pod else None
            if pod and desired == "RUNNING":
                runtime = pod.get("runtime") or {}
                for p in runtime.get("ports") or []:
                    if p.get("privatePort") == 22:
                        _pod_state["ssh_ip"] = p.get("ip")
                        _pod_state["ssh_port"] = p.get("publicPort")
                    if p.get("privatePort") == 8188:
                        _pod_state["comfyui_ip"] = p.get("ip")
                        _pod_state["comfyui_port"] = p.get("publicPort")
                # Use SSH IP as the main IP for display
                _pod_state["ip"] = _pod_state.get("ssh_ip") or _pod_state.get("comfyui_ip")
            elif not pod or desired == "EXITED":
                # Pod is gone (deleted) or has exited: forget it and persist that.
                _pod_state["status"] = "stopped"
                _pod_state["pod_id"] = None
                _save_pod_state()
        except asyncio.TimeoutError:
            logger.warning("RunPod API timeout checking pod status")
        except Exception as e:
            logger.warning("Failed to check pod: %s", e)

    uptime = None
    if _pod_state["started_at"] and _pod_state["status"] in ("running", "setting_up"):
        uptime = (time.time() - _pod_state["started_at"]) / 60

    comfyui_url = _get_comfyui_url()
    return PodStatus(
        status=_pod_state["status"],
        pod_id=_pod_state["pod_id"],
        ip=_pod_state["ip"],
        port=_pod_state.get("comfyui_port"),
        gpu_type=_pod_state["gpu_type"],
        model_type=_pod_state.get("model_type", "flux2"),
        cost_per_hour=_pod_state["cost_per_hour"],
        setup_status=_pod_state.get("setup_status"),
        uptime_minutes=uptime,
        comfyui_url=comfyui_url,
    )
@router.get("/gpu-options")
async def list_gpu_options():
    """Return the GPU catalogue wrapped under the ``gpus`` key."""
    payload = {"gpus": GPU_OPTIONS}
    return payload
@router.get("/model-options")
async def list_model_options():
    """List available model types for the pod.

    Returns a mapping under ``models`` keyed by the ``model_type`` value that
    can be sent to the start route; each entry carries a display name, a
    description and a use-case tag.
    """
    return {
        "models": {
            "flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"},
            "flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"},
            "wan22": {"name": "WAN 2.2 Remix", "description": "Realistic generation — dual-DiT MoE split-step (NSFW OK)", "use_case": "txt2img"},
            "wan22_i2v": {"name": "WAN 2.2 I2V", "description": "Image-to-video generation", "use_case": "img2video"},
            "wan22_animate": {"name": "WAN 2.2 Animate", "description": "Dance/motion transfer — animate a character from a driving video", "use_case": "animate"},
        }
    }
# Strong references to in-flight setup tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_pod_setup_tasks: set = set()


@router.post("/start")
async def start_pod(request: StartPodRequest):
    """Start a GPU pod with ComfyUI for generation.

    Creates the pod through the RunPod API, records it in the pod state,
    persists that state and schedules ``_wait_and_setup_pod`` in the
    background. Returns immediately; setup progress is reported by the
    status route.

    Returns:
        A dict with ``status`` (``starting`` or ``already_running``) plus the
        pod id and/or a message.

    Raises:
        HTTPException: 503 when the API key is missing, 400 for an unknown
            GPU type, 500 when pod creation fails.
    """
    _get_api_key()
    if _pod_state["status"] in ("running", "setting_up"):
        return {"status": "already_running", "pod_id": _pod_state["pod_id"]}
    if _pod_state["status"] == "starting":
        return {"status": "starting", "message": "Pod is already starting"}

    gpu_info = GPU_OPTIONS.get(request.gpu_type)
    if not gpu_info:
        raise HTTPException(400, f"Unknown GPU type: {request.gpu_type}")

    _pod_state["status"] = "starting"
    _pod_state["gpu_type"] = request.gpu_type
    _pod_state["cost_per_hour"] = gpu_info["cost"]
    _pod_state["model_type"] = request.model_type
    _pod_state["setup_status"] = "Creating pod..."

    try:
        logger.info("Starting RunPod with %s for %s...", request.gpu_type, request.model_type)
        # NOTE(security): the container start command sets a fixed, well-known
        # root password and enables root SSH login on a publicly mapped port.
        # The setup routine logs in with the same password, so both must be
        # changed together (ideally to key-based auth).
        pod_kwargs = {
            "container_disk_in_gb": 30,
            "ports": "22/tcp,8188/http",
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }
        volume_id, volume_dc = _get_volume_config()
        if volume_id:
            pod_kwargs["network_volume_id"] = volume_id
            if volume_dc:
                pod_kwargs["data_center_id"] = volume_dc
            logger.info("Using network volume: %s (DC: %s)", volume_id, volume_dc)
        else:
            pod_kwargs["volume_in_gb"] = 75
            logger.warning("No network volume configured — using temporary volume")

        pod = await asyncio.to_thread(
            runpod.create_pod,
            f"comfyui-gen-{request.model_type}",
            DOCKER_IMAGE,
            request.gpu_type,
            **pod_kwargs,
        )
        _pod_state["pod_id"] = pod["id"]
        _pod_state["started_at"] = time.time()
        _save_pod_state()
        logger.info("Pod created: %s", pod["id"])

        # Keep a reference until the task completes so it cannot be collected.
        setup_task = asyncio.create_task(_wait_and_setup_pod(pod["id"], request.model_type))
        _pod_setup_tasks.add(setup_task)
        setup_task.add_done_callback(_pod_setup_tasks.discard)

        return {
            "status": "starting",
            "pod_id": pod["id"],
            "message": f"Starting {gpu_info['name']} pod (~5-8 min for setup)",
        }
    except Exception as e:
        _pod_state["status"] = "stopped"
        _pod_state["setup_status"] = None
        logger.error("Failed to start pod: %s", e)
        raise HTTPException(500, f"Failed to start pod: {e}") from e
async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600):
    """Wait for pod to be ready, then install ComfyUI and link models via SSH.

    Runs as a background task and reports progress through
    ``_pod_state["setup_status"]`` / ``_pod_state["status"]`` rather than by
    returning a value.

    Phases:
        1. Poll RunPod until the pod exposes its SSH port (up to ``timeout``
           seconds); on failure the state is set to ``stopped``.
        2. Connect over SSH, install or restore ComfyUI, download any missing
           model files for ``model_type``, symlink them into ComfyUI, install
           custom nodes, start ComfyUI, and poll its HTTP endpoint until it
           answers.

    If setup fails after SSH is up, the status is left at ``setting_up`` so
    the pod stays available for debugging.

    Args:
        pod_id: RunPod pod identifier returned by pod creation.
        model_type: One of ``flux2``, ``flux1``, ``z_image``, ``wan22``,
            ``wan22_i2v``, ``wan22_animate``; other values skip model linking.
        timeout: Seconds to wait for the pod's SSH port in phase 1.
    """
    start = time.time()
    ssh_host = None
    ssh_port = None

    # Phase 1: Wait for SSH to be available
    _pod_state["setup_status"] = "Waiting for pod to start..."
    while time.time() - start < timeout:
        try:
            pod = await asyncio.to_thread(runpod.get_pod, pod_id)
            if pod and pod.get("desiredStatus") == "RUNNING":
                runtime = pod.get("runtime") or {}
                ports = runtime.get("ports") or []
                for p in ports:
                    if p.get("privatePort") == 22:
                        ssh_host = p.get("ip")
                        ssh_port = p.get("publicPort")
                        _pod_state["ssh_ip"] = ssh_host
                        _pod_state["ssh_port"] = ssh_port
                        _pod_state["ip"] = ssh_host
                    if p.get("privatePort") == 8188:
                        _pod_state["comfyui_ip"] = p.get("ip")
                        _pod_state["comfyui_port"] = p.get("publicPort")
                if ssh_host and ssh_port:
                    break
        except Exception as e:
            logger.debug("Waiting for pod: %s", e)
        await asyncio.sleep(5)

    if not ssh_host or not ssh_port:
        logger.error("Pod did not become ready within %ds", timeout)
        _pod_state["status"] = "stopped"
        _pod_state["setup_status"] = "Failed: pod did not start"
        return

    # Phase 2: SSH in and set up ComfyUI
    _pod_state["status"] = "setting_up"
    _pod_state["setup_status"] = "Connecting via SSH..."

    import paramiko

    async def _ssh_connect_new() -> "paramiko.SSHClient":
        """Create a fresh SSH connection to the pod."""
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(10):
            try:
                await asyncio.to_thread(
                    client.connect, ssh_host, port=int(ssh_port),
                    username="root", password="runpod", timeout=15,
                    banner_timeout=30,
                )
                client.get_transport().set_keepalive(30)
                return client
            except Exception:
                if attempt == 9:
                    raise
                await asyncio.sleep(5)
        raise RuntimeError("SSH connection failed after retries")

    async def _ssh_reconnect() -> None:
        """Replace the shared SSH client with a fresh one and close the old one."""
        nonlocal ssh
        stale = ssh
        # Connect first: if this raises, `ssh` still refers to the old client
        # and the outer `finally` closes it.
        ssh = await _ssh_connect_new()
        try:
            stale.close()
        except Exception as close_err:
            logger.debug("Ignoring error closing stale SSH client: %s", close_err)

    async def _ssh_exec_r(cmd: str, timeout: int = 120) -> str:
        """Execute SSH command, reconnecting once if the session dropped."""
        try:
            t = ssh.get_transport()
            if t is None or not t.is_active():
                logger.info("SSH session dropped, reconnecting...")
                await _ssh_reconnect()
            return await _ssh_exec_async(ssh, cmd, timeout)
        except Exception as e:
            if "not active" in str(e).lower() or "session" in str(e).lower():
                logger.info("SSH error '%s', reconnecting and retrying...", e)
                await _ssh_reconnect()
                return await _ssh_exec_async(ssh, cmd, timeout)
            raise

    for attempt in range(30):
        try:
            ssh = await _ssh_connect_new()
            break
        except Exception:
            if attempt == 29:
                _pod_state["setup_status"] = "Failed: SSH connection error"
                _pod_state["status"] = "stopped"
                return
            await asyncio.sleep(5)

    try:
        # Symlink network volume
        volume_id, _ = _get_volume_config()
        if volume_id:
            await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models /runpod-volume/loras")
            await _ssh_exec_async(ssh, "rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")

        # Install ComfyUI (cache on volume for reuse)
        comfy_dir = "/workspace/ComfyUI"
        _pod_state["setup_status"] = "Installing ComfyUI..."
        comfy_exists = (await _ssh_exec_async(ssh, f"test -f {comfy_dir}/main.py && echo EXISTS || echo MISSING")).strip()
        if comfy_exists == "EXISTS":
            logger.info("ComfyUI already installed")
            _pod_state["setup_status"] = "ComfyUI found, updating..."
            await _ssh_exec_async(ssh, f"cd {comfy_dir} && git pull 2>&1 | tail -3", timeout=120)
        else:
            # Check volume cache
            vol_comfy = (await _ssh_exec_async(ssh, "test -f /runpod-volume/ComfyUI/main.py && echo EXISTS || echo MISSING")).strip()
            if vol_comfy == "EXISTS":
                _pod_state["setup_status"] = "Restoring ComfyUI from volume..."
                await _ssh_exec_async(ssh, f"cp -r /runpod-volume/ComfyUI {comfy_dir}", timeout=300)
            else:
                _pod_state["setup_status"] = "Cloning ComfyUI (first time, ~2 min)..."
                await _ssh_exec_async(ssh, "cd /workspace && git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git", timeout=300)
                await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
                # Cache to volume
                volume_id, _ = _get_volume_config()
                if volume_id:
                    await _ssh_exec_async(ssh, f"cp -r {comfy_dir} /runpod-volume/ComfyUI", timeout=300)

        # Install pip deps that aren't in ComfyUI requirements
        _pod_state["setup_status"] = "Installing dependencies..."
        await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600)
        await _ssh_exec_async(ssh, "pip install aiohttp einops sqlalchemy 2>&1 | tail -3", timeout=120)

        # Symlink models into ComfyUI directories
        _pod_state["setup_status"] = "Linking models..."
        await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/checkpoints {comfy_dir}/models/vae {comfy_dir}/models/loras {comfy_dir}/models/text_encoders")

        # NOTE: the remote `python -c` scripts below are passed verbatim to the
        # pod, so their lines must start at column 0 and nested statements must
        # keep their own indentation.
        if model_type == "flux2":
            # FLUX.2 Dev — separate UNet, text encoder, and VAE
            await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/flux2-dev.safetensors {comfy_dir}/models/diffusion_models/flux2-dev.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
            # Text encoder — use Comfy-Org's pre-converted single-file version
            # (HF sharded format is incompatible with ComfyUI's CLIPLoader)
            te_file = "/runpod-volume/models/mistral_3_small_flux2_fp8.safetensors"
            te_exists = (await _ssh_exec_async(ssh, f"test -f {te_file} && echo EXISTS || echo MISSING")).strip()
            if te_exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading FLUX.2 text encoder (~12GB, first time only)..."
                await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
hf_hub_download(
    repo_id='Comfy-Org/flux2-dev',
    filename='split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors',
    local_dir='/tmp/flux2_te',
)
import shutil
shutil.move('/tmp/flux2_te/split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', '{te_file}')
print('Text encoder downloaded')
" 2>&1 | tail -5""", timeout=1800)
            await _ssh_exec_async(ssh, f"ln -sf {te_file} {comfy_dir}/models/text_encoders/mistral_3_small_flux2_fp8.safetensors")
            # Remove old sharded loader patch if present
            await _ssh_exec_async(ssh, f"rm -f {comfy_dir}/custom_nodes/sharded_loader.py")
        elif model_type == "flux1":
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/flux1-dev.safetensors {comfy_dir}/models/checkpoints/flux1-dev.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors")
        elif model_type == "z_image":
            # Z-Image Turbo — 6B param model by Tongyi-MAI, runs in 16GB VRAM
            z_dir = "/runpod-volume/models/z_image"
            await _ssh_exec_async(ssh, f"mkdir -p {z_dir}")
            await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
            # Delete FLUX.2 from volume to free space
            _pod_state["setup_status"] = "Cleaning up FLUX.2 from volume..."
            await _ssh_exec_async(ssh, "rm -rf /runpod-volume/models/FLUX.2-dev /runpod-volume/models/mistral_3_small_flux2_fp8.safetensors 2>/dev/null; echo done")
            # Download diffusion model (~12GB)
            diff_model = f"{z_dir}/z_image_turbo_bf16.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {diff_model} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading Z-Image Turbo diffusion model (~12GB)..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import shutil, os
p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/diffusion_models/z_image_turbo_bf16.safetensors', local_dir='/tmp/z_image')
shutil.move(p, '{diff_model}')
print('Diffusion model downloaded')
" 2>&1 | tail -5""", timeout=3600)
            # Download text encoder (~8GB Qwen 3 4B)
            te_model = f"{z_dir}/qwen_3_4b.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {te_model} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading Z-Image text encoder (~8GB)..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import shutil
p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/text_encoders/qwen_3_4b.safetensors', local_dir='/tmp/z_image')
shutil.move(p, '{te_model}')
print('Text encoder downloaded')
" 2>&1 | tail -5""", timeout=3600)
            # Download VAE (~335MB)
            vae_model = f"{z_dir}/ae.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {vae_model} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading Z-Image VAE..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import shutil
p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/vae/ae.safetensors', local_dir='/tmp/z_image')
shutil.move(p, '{vae_model}')
print('VAE downloaded')
" 2>&1 | tail -5""", timeout=600)
            # Symlink into ComfyUI directories
            await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders {comfy_dir}/models/vae")
            await _ssh_exec_async(ssh, f"ln -sf {diff_model} {comfy_dir}/models/diffusion_models/z_image_turbo_bf16.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf {te_model} {comfy_dir}/models/text_encoders/qwen_3_4b.safetensors")
            await _ssh_exec_async(ssh, f"ln -sf {vae_model} {comfy_dir}/models/vae/ae_z_image.safetensors")
        elif model_type == "wan22":
            # WAN 2.2 Remix NSFW — dual-DiT MoE split-step for realistic generation
            wan_dir = "/workspace/models/WAN2.2"
            await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")
            civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
            token_param = f"&token={civitai_token}" if civitai_token else ""
            # CivitAI Remix models (fp8 ~14GB each)
            civitai_models = {
                "Remix T2V High-noise": {
                    "path": f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors",
                    "url": f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}",
                },
                "Remix T2V Low-noise": {
                    "path": f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors",
                    "url": f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}",
                },
            }
            # HuggingFace models (T5 fp8 ~7GB, VAE ~1GB)
            hf_models = {
                "T5 text encoder (fp8)": {
                    "path": f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
                    "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
                    "filename": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
                },
                "VAE": {
                    "path": f"{wan_dir}/wan_2.1_vae.safetensors",
                    "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
                    "filename": "split_files/vae/wan_2.1_vae.safetensors",
                },
            }
            # Download CivitAI Remix models
            for label, info in civitai_models.items():
                exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    logger.info("WAN 2.2 %s already cached", label)
                else:
                    _pod_state["setup_status"] = f"Downloading {label} (~14GB)..."
                    await _ssh_exec_async(ssh, f"wget -q -O '{info['path']}' '{info['url']}'", timeout=1800)
                    # Verify download
                    check = (await _ssh_exec_async(ssh, f"test -f {info['path']} && stat -c%s {info['path']} || echo 0")).strip()
                    # Treat non-numeric output (e.g. an empty reply) as a failed
                    # download instead of crashing on int().
                    if not check.isdigit() or int(check) < 1000000:
                        logger.error("Failed to download %s (size: %s). CivitAI API token may be required.", label, check)
                        _pod_state["setup_status"] = f"Failed: {label} download failed. Set CIVITAI_API_TOKEN env var for NSFW models."
                        return
            # Download HuggingFace models
            await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
            for label, info in hf_models.items():
                exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    logger.info("WAN 2.2 %s already cached", label)
                else:
                    _pod_state["setup_status"] = f"Downloading {label}..."
                    await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import os, shutil
hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}')
downloaded = os.path.join('{wan_dir}', '{info['filename']}')
target = '{info['path']}'
if os.path.exists(downloaded) and downloaded != target:
    os.makedirs(os.path.dirname(target), exist_ok=True)
    shutil.move(downloaded, target)
print('Downloaded {label}')
" 2>&1 | tail -5""", timeout=1800)
            # Symlink models into ComfyUI
            await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_high_fp8.safetensors {comfy_dir}/models/diffusion_models/")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_low_fp8.safetensors {comfy_dir}/models/diffusion_models/")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan_2.1_vae.safetensors {comfy_dir}/models/vae/")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors {comfy_dir}/models/text_encoders/")
            # Install wanBlockSwap custom node (VRAM optimization for dual-DiT on 24GB GPUs)
            _pod_state["setup_status"] = "Installing WAN 2.2 custom nodes..."
            blockswap_dir = f"{comfy_dir}/custom_nodes/ComfyUI-wanBlockswap"
            blockswap_exists = (await _ssh_exec_async(ssh, f"test -d {blockswap_dir} && echo EXISTS || echo MISSING")).strip()
            if blockswap_exists != "EXISTS":
                await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/orssorbit/ComfyUI-wanBlockswap.git", timeout=120)
        elif model_type == "wan22_i2v":
            # WAN 2.2 Image-to-Video (14B params) — full model snapshot
            wan_dir = "/workspace/models/Wan2.2-I2V-A14B"
            wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip()
            if wan_exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading WAN 2.2 I2V model (~28GB, first time only)..."
                await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import snapshot_download
snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt'])
print('WAN 2.2 I2V downloaded')
" 2>&1 | tail -10""", timeout=3600)
            await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B")
            await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B")
            # Install ComfyUI-WanVideoWrapper custom nodes
            _pod_state["setup_status"] = "Installing WAN 2.2 ComfyUI nodes..."
            wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper"
            wan_nodes_exist = (await _ssh_exec_async(ssh, f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip()
            if wan_nodes_exist != "EXISTS":
                await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
                await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
        elif model_type == "wan22_animate":
            # WAN 2.2 Animate (14B fp8) — dance/motion transfer via pose skeleton
            animate_dir = "/workspace/models/WAN2.2-Animate"
            wan22_dir = "/workspace/models/WAN2.2"
            await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")
            await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60)
            # Download main Animate model (~28GB bf16 — only version available)
            animate_model = f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {animate_model} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading WAN 2.2 Animate model (~28GB, first time only)..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import os, shutil
hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors', local_dir='{animate_dir}')
src = os.path.join('{animate_dir}', 'split_files', 'diffusion_models', 'wan2.2_animate_14B_bf16.safetensors')
if os.path.exists(src):
    shutil.move(src, '{animate_model}')
print('Animate model downloaded')
" 2>&1 | tail -5""", timeout=7200)
            # CLIP Vision H (~2.5GB) — ViT-H vision encoder
            clip_vision_target = f"{animate_dir}/clip_vision_h.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {clip_vision_target} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                _pod_state["setup_status"] = "Downloading CLIP Vision H (~2.5GB)..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import os, shutil
result = hf_hub_download('h94/IP-Adapter', 'models/image_encoder/model.safetensors', local_dir='{animate_dir}/tmp_clip')
shutil.move(result, '{clip_vision_target}')
shutil.rmtree('{animate_dir}/tmp_clip', ignore_errors=True)
print('CLIP Vision H downloaded')
" 2>&1 | tail -5""", timeout=1800)
            # VAE — reuse from WAN2.2 dir if available, else download (~1GB)
            vae_target = f"{animate_dir}/wan_2.1_vae.safetensors"
            exists = (await _ssh_exec_async(ssh, f"test -f {vae_target} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                vae_from_wan22 = (await _ssh_exec_async(ssh, f"test -f {wan22_dir}/wan_2.1_vae.safetensors && echo EXISTS || echo MISSING")).strip()
                if vae_from_wan22 == "EXISTS":
                    await _ssh_exec_async(ssh, f"ln -sf {wan22_dir}/wan_2.1_vae.safetensors {vae_target}")
                else:
                    _pod_state["setup_status"] = "Downloading VAE (~1GB)..."
                    await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
import os, shutil
hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/vae/wan_2.1_vae.safetensors', local_dir='{animate_dir}')
src = os.path.join('{animate_dir}', 'split_files', 'vae', 'wan_2.1_vae.safetensors')
if os.path.exists(src):
    shutil.move(src, '{vae_target}')
print('VAE downloaded')
" 2>&1 | tail -5""", timeout=600)
            # UMT5 T5 encoder fp8 (non-scaled) — use Kijai/WanVideo_comfy version
            # which is compatible with LoadWanVideoT5TextEncoder (scaled_fp8 is not supported)
            t5_filename = "umt5-xxl-enc-fp8_e4m3fn.safetensors"
            t5_target = f"{animate_dir}/{t5_filename}"
            t5_comfy_path = f"{comfy_dir}/models/text_encoders/{t5_filename}"
            t5_in_comfy = (await _ssh_exec_async(ssh, f"test -f {t5_comfy_path} && echo EXISTS || echo MISSING")).strip()
            t5_in_vol = (await _ssh_exec_async(ssh, f"test -f {t5_target} && echo EXISTS || echo MISSING")).strip()
            if t5_in_comfy != "EXISTS" and t5_in_vol != "EXISTS":
                _pod_state["setup_status"] = "Downloading UMT5 text encoder (~6.3GB, first time only)..."
                await _ssh_exec_async(ssh, f"""python -c "
from huggingface_hub import hf_hub_download
hf_hub_download('Kijai/WanVideo_comfy', '{t5_filename}', local_dir='{animate_dir}')
print('UMT5 text encoder downloaded')
" 2>&1 | tail -5""", timeout=1800)
                t5_in_vol = "EXISTS"
            # Symlink models into ComfyUI directories
            await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/vae {comfy_dir}/models/clip_vision {comfy_dir}/models/text_encoders")
            await _ssh_exec_async(ssh, f"ln -sf {animate_model} {comfy_dir}/models/diffusion_models/")
            await _ssh_exec_async(ssh, f"ln -sf {vae_target} {comfy_dir}/models/vae/")
            await _ssh_exec_async(ssh, f"ln -sf {clip_vision_target} {comfy_dir}/models/clip_vision/")
            if t5_in_vol == "EXISTS" and t5_in_comfy != "EXISTS":
                await _ssh_exec_async(ssh, f"ln -sf {t5_target} {t5_comfy_path}")
            # Reconnect SSH before custom node setup — connection may have dropped during long downloads
            await _ssh_reconnect()
            # Install required custom nodes
            _pod_state["setup_status"] = "Installing WAN Animate custom nodes..."
            # ComfyUI-WanVideoWrapper (WanVideoAnimateEmbeds, WanVideoSampler, etc.)
            wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper"
            exists = (await _ssh_exec_r(f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120)
                await _ssh_exec_r(f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
            # ComfyUI-VideoHelperSuite (VHS_LoadVideo, VHS_VideoCombine)
            vhs_dir = f"{comfy_dir}/custom_nodes/ComfyUI-VideoHelperSuite"
            exists = (await _ssh_exec_r(f"test -d {vhs_dir} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", timeout=120)
                await _ssh_exec_r(f"cd {vhs_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
            # comfyui_controlnet_aux (DWPreprocessor for pose extraction)
            aux_dir = f"{comfy_dir}/custom_nodes/comfyui_controlnet_aux"
            exists = (await _ssh_exec_r(f"test -d {aux_dir} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Fannovel16/comfyui_controlnet_aux.git", timeout=120)
                await _ssh_exec_r(f"cd {aux_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)
            # ComfyUI-KJNodes (ImageResizeKJv2 used in animate workflow)
            kj_dir = f"{comfy_dir}/custom_nodes/ComfyUI-KJNodes"
            exists = (await _ssh_exec_r(f"test -d {kj_dir} && echo EXISTS || echo MISSING")).strip()
            if exists != "EXISTS":
                await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-KJNodes.git", timeout=120)
                await _ssh_exec_r(f"cd {kj_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300)

        # Symlink all LoRAs from volume
        await _ssh_exec_r(f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done")

        # Start ComfyUI in background (fire-and-forget — don't wait for output)
        _pod_state["setup_status"] = "Starting ComfyUI..."
        await asyncio.to_thread(
            _ssh_exec_fire_and_forget,
            ssh,
            f"cd {comfy_dir} && python main.py --listen 0.0.0.0 --port 8188 --fp8_e4m3fn-unet > /tmp/comfyui.log 2>&1",
        )
        await asyncio.sleep(2)  # Give it a moment to start

        # Wait for ComfyUI HTTP to respond
        _pod_state["setup_status"] = "Waiting for ComfyUI to load model..."
        import httpx
        comfyui_url = _get_comfyui_url()
        for attempt in range(120):  # Up to 10 minutes
            try:
                async with httpx.AsyncClient(timeout=5) as client:
                    resp = await client.get(f"{comfyui_url}/system_stats")
                    if resp.status_code == 200:
                        _pod_state["status"] = "running"
                        _pod_state["setup_status"] = "Ready"
                        _save_pod_state()
                        logger.info("ComfyUI ready at %s", comfyui_url)
                        return
            except Exception:
                # Expected while ComfyUI is still booting; keep polling.
                pass
            await asyncio.sleep(5)

        # If we get here, ComfyUI didn't start
        # Check the log for errors
        log_tail = await _ssh_exec_async(ssh, "tail -20 /tmp/comfyui.log")
        logger.error("ComfyUI didn't start. Log: %s", log_tail)
        _pod_state["setup_status"] = "ComfyUI failed to start. Check logs."
        _pod_state["status"] = "setting_up"  # Keep pod running so user can debug
    except Exception as e:
        import traceback
        err_msg = f"{type(e).__name__}: {e}"
        logger.error("Pod setup failed: %s\n%s", err_msg, traceback.format_exc())
        _pod_state["setup_status"] = f"Setup failed: {err_msg}"
        _pod_state["status"] = "setting_up"  # Keep pod running so user can debug
    finally:
        try:
            ssh.close()
        except Exception:
            pass
def _ssh_exec(ssh, cmd: str, timeout: int = 120) -> str:
    """Run ``cmd`` over SSH and return its stdout as whitespace-stripped text.

    Blocking: call it from async code via ``asyncio.to_thread`` or from a
    background task.  Only stdout is read; stderr is not captured, so callers
    that want it must redirect it (``2>&1``) inside ``cmd``.
    """
    _stdin, stdout_stream, _stderr = ssh.exec_command(cmd, timeout=timeout)
    raw_output = stdout_stream.read()
    # Undecodable bytes are substituted instead of raising.
    text = raw_output.decode("utf-8", errors="replace")
    return text.strip()
async def _ssh_exec_async(ssh, cmd: str, timeout: int = 120) -> str:
    """Run ``_ssh_exec`` in a worker thread so the event loop is not blocked.

    Takes the same arguments as ``_ssh_exec`` and returns its result.
    """
    output = await asyncio.to_thread(_ssh_exec, ssh, cmd, timeout)
    return output
def _ssh_exec_fire_and_forget(ssh, cmd: str):
    """Launch ``cmd`` on a fresh SSH session channel and return immediately.

    Meant for background processes: stdout/stderr are never read and the
    command's exit status is not awaited.
    """
    session = ssh.get_transport().open_session()
    session.exec_command(cmd)
# --- Pre-download models to network volume (saves money during training) ---
# Module-level progress record for the model pre-download background task.
_download_state: dict[str, Any] = {
    # One of: idle, downloading, completed, failed
    "status": "idle",
    # ID of the temporary download pod while one exists
    "pod_id": None,
    # Human-readable description of the current step
    "progress": "",
    # Error text of the last failed run, if any
    "error": None,
}
class DownloadModelsRequest(BaseModel):
    """Request body for the ``/download-models`` endpoint."""

    # Which model set to cache on the volume.
    model_type: str = "wan22"
    # Preferred GPU for the temporary pod: cheapest option, it only downloads.
    gpu_type: str = "NVIDIA GeForce RTX 3090"
# Strong references to in-flight download tasks.  The event loop keeps only
# weak references to tasks, so a task nobody else references can be garbage
# collected before it finishes.
_download_tasks: set[asyncio.Task] = set()


@router.post("/download-models")
async def download_models_to_volume(request: DownloadModelsRequest):
    """Pre-download model files to the network volume using a cheap pod.

    This saves expensive GPU time during training: models are cached on the
    shared volume and reused across all future training/generation pods.

    The download itself runs as a background task; poll
    ``/download-models/status`` for progress.

    Returns:
        ``{"status": "started", ...}`` when a download was kicked off, or
        ``{"status": "already_downloading", ...}`` when one is in progress.

    Raises:
        HTTPException: 400 when no network volume is configured.
    """
    _get_api_key()
    volume_id, volume_dc = _get_volume_config()
    if not volume_id:
        raise HTTPException(400, "No network volume configured (set RUNPOD_VOLUME_ID)")
    if _download_state["status"] == "downloading":
        return {"status": "already_downloading", "progress": _download_state["progress"]}

    _download_state["status"] = "downloading"
    _download_state["progress"] = "Creating cheap download pod..."
    _download_state["error"] = None

    task = asyncio.create_task(
        _download_models_task(request.model_type, request.gpu_type, volume_id, volume_dc)
    )
    # Keep the task alive until it completes, then drop the reference.
    _download_tasks.add(task)
    task.add_done_callback(_download_tasks.discard)
    return {"status": "started", "message": f"Downloading {request.model_type} models to volume (using {request.gpu_type})"}
@router.get("/download-models/status")
async def download_models_status():
    """Report the current state of the model pre-download task.

    Returns the module-level ``_download_state`` dict itself (not a copy).
    """
    current_state = _download_state
    return current_state
async def _download_models_task(model_type: str, gpu_type: str, volume_id: str, volume_dc: str):
    """Background task: spin up a cheap pod, download models to the volume, terminate.

    Progress and the final outcome are reported through the module-level
    ``_download_state`` dict.  Failures are recorded there rather than raised,
    and the pod is terminated in every case once it has been created.

    Args:
        model_type: Model set to fetch: ``"wan22"`` or ``"wan22_animate"``.
        gpu_type: Preferred GPU type; a fixed fallback list is tried when it
            is unavailable.
        volume_id: Network volume attached to the pod.
        volume_dc: Data center of the volume; when non-empty the pod is
            pinned to it.
    """
    import paramiko

    ssh = None
    pod_id = None
    try:
        # Fail before renting a pod: an unrecognised model type would otherwise
        # download nothing and still be reported as "completed".
        if model_type not in ("wan22", "wan22_animate"):
            raise ValueError(f"Unsupported model_type for download: {model_type!r}")

        # Create cheap pod with network volume — try multiple GPU types if first unavailable
        pod_kwargs = {
            "container_disk_in_gb": 10,
            "ports": "22/tcp",
            "network_volume_id": volume_id,
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }
        if volume_dc:
            pod_kwargs["data_center_id"] = volume_dc
        gpu_fallbacks = [
            gpu_type,
            "NVIDIA RTX A4000",
            "NVIDIA RTX A5000",
            "NVIDIA GeForce RTX 4090",
            "NVIDIA GeForce RTX 4080",
            "NVIDIA A100-SXM4-80GB",
        ]
        pod = None
        used_gpu = gpu_type
        for try_gpu in gpu_fallbacks:
            try:
                pod = await asyncio.to_thread(
                    runpod.create_pod,
                    f"model-download-{model_type}",
                    DOCKER_IMAGE,
                    try_gpu,
                    **pod_kwargs,
                )
                used_gpu = try_gpu
                logger.info("Download pod created with %s", try_gpu)
                break
            except Exception as e:
                # Only capacity errors move on to the next GPU type.
                if "SUPPLY_CONSTRAINT" in str(e) or "no longer any instances" in str(e).lower():
                    logger.info("GPU %s unavailable, trying next...", try_gpu)
                    continue
                raise
        if pod is None:
            raise RuntimeError("No GPU available for download pod. Try again later.")
        pod_id = pod["id"]
        _download_state["pod_id"] = pod_id
        _download_state["progress"] = f"Pod created with {used_gpu} ({pod_id}), waiting for SSH..."

        # Wait (up to 5 min) for the pod to expose its SSH port
        ssh_host = ssh_port = None
        start = time.time()
        while time.time() - start < 300:
            try:
                p = await asyncio.to_thread(runpod.get_pod, pod_id)
                if p and p.get("desiredStatus") == "RUNNING":
                    for port in (p.get("runtime") or {}).get("ports") or []:
                        if port.get("privatePort") == 22:
                            ssh_host = port.get("ip")
                            ssh_port = port.get("publicPort")
                if ssh_host and ssh_port:
                    break
            except Exception as e:
                # Best effort: a failed status query is simply retried.
                logger.debug("Pod status query failed, retrying: %s", e)
            await asyncio.sleep(5)
        # Both values are required below; a host without a port is unusable.
        if not (ssh_host and ssh_port):
            raise RuntimeError("Pod SSH not available after 5 min")

        # Connect SSH (sshd may still be installing, so retry)
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(20):
            try:
                await asyncio.to_thread(ssh.connect, ssh_host, port=int(ssh_port), username="root", password="runpod", timeout=10)
                break
            except Exception as e:
                if attempt == 19:
                    raise RuntimeError("SSH connection failed after 20 attempts") from e
                await asyncio.sleep(5)
        ssh.get_transport().set_keepalive(30)
        _download_state["progress"] = "SSH connected, setting up tools..."

        # Point /workspace/models at the volume so downloads land on it
        await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
        await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=120)
        await _ssh_exec_async(ssh, "which aria2c || apt-get install -y aria2 2>&1 | tail -1", timeout=120)

        if model_type == "wan22":
            wan_dir = "/workspace/models/WAN2.2"
            await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")
            civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
            token_param = f"&token={civitai_token}" if civitai_token else ""
            # CivitAI Remix models (fp8): (label, url, target)
            civitai_files = [
                ("Remix T2V High-noise", f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors"),
                ("Remix T2V Low-noise", f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors"),
            ]
            # HuggingFace models: (label, repo, filename, target)
            hf_files = [
                ("T5 text encoder (fp8)", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors"),
                ("VAE", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan_2.1_vae.safetensors", f"{wan_dir}/wan_2.1_vae.safetensors"),
            ]
            total = len(civitai_files) + len(hf_files)
            idx = 0
            for label, url, target in civitai_files:
                idx += 1
                exists = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                else:
                    _download_state["progress"] = f"[{idx}/{total}] Downloading {label} (~14GB)..."
                    await _ssh_exec_async(ssh, f"wget -q -O '{target}' '{url}'", timeout=1800)
                    # A tiny file means the server returned an error page, not the model.
                    check = (await _ssh_exec_async(ssh, f"test -f {target} && stat -c%s {target} || echo 0")).strip()
                    if check == "0" or int(check) < 1000000:
                        raise RuntimeError(f"Failed to download {label}. Set CIVITAI_API_TOKEN for NSFW models.")
                    _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"
            for label, repo, filename, target in hf_files:
                idx += 1
                exists = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                if exists == "EXISTS":
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                else:
                    _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                    hf_url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
                    fname = target.split("/")[-1]
                    tdir = "/".join(target.split("/")[:-1])
                    await _ssh_exec_async(ssh, f"aria2c -x 16 -s 16 -c -o '{fname}' --dir='{tdir}' '{hf_url}' 2>&1 | tail -3", timeout=1800)
                    check = (await _ssh_exec_async(ssh, f"test -f {target} && echo EXISTS || echo MISSING")).strip()
                    if check != "EXISTS":
                        raise RuntimeError(f"Failed to download {label}")
                    _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

            # Also pre-clone musubi-tuner to volume (for training)
            _download_state["progress"] = "Caching musubi-tuner to volume..."
            tuner_exists = (await _ssh_exec_async(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip()
            if tuner_exists != "EXISTS":
                await _ssh_exec_async(ssh, "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git && cp -r /workspace/musubi-tuner /runpod-volume/musubi-tuner", timeout=300)
                _download_state["progress"] = "musubi-tuner cached"
            else:
                _download_state["progress"] = "musubi-tuner already cached"

        elif model_type == "wan22_animate":
            animate_dir = "/workspace/models/WAN2.2-Animate"
            wan22_dir = "/workspace/models/WAN2.2"
            hf_base = "https://huggingface.co"
            await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")
            # Files to download:
            # (label, url, target, timeout_s, min_bytes, reuse_from_wan22)
            # reuse_from_wan22 marks files that may be linked from the WAN2.2
            # dir when a complete same-named copy already exists there.  It is
            # an explicit flag so the check cannot drift from the labels.
            wget_files = [
                (
                    "WAN 2.2 Animate model (~32GB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors",
                    f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors",
                    7200,
                    30_000_000_000,  # 30GB min — partial downloads get resumed
                    False,
                ),
                (
                    "UMT5 text encoder fp8 (~6.3GB)",
                    f"{hf_base}/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    f"{animate_dir}/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    1800,
                    6_000_000_000,
                    True,
                ),
                (
                    "VAE (~242MB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors",
                    f"{animate_dir}/wan_2.1_vae.safetensors",
                    300,
                    200_000_000,
                    True,
                ),
                (
                    "CLIP Vision H (~2.4GB)",
                    f"{hf_base}/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors",
                    f"{animate_dir}/clip_vision_h.safetensors",
                    900,
                    2_000_000_000,
                    False,
                ),
            ]
            total = len(wget_files)
            for idx, (label, url, target, dl_timeout, min_bytes, reuse_ok) in enumerate(wget_files, 1):
                reused = False
                if reuse_ok:
                    # Reuse from the WAN2.2 dir if already present (and complete)
                    wan22_candidate = f"{wan22_dir}/{target.split('/')[-1]}"
                    wan22_size = (await _ssh_exec_async(ssh, f"stat -c%s {wan22_candidate} 2>/dev/null || echo 0")).strip()
                    if int(wan22_size) >= min_bytes:
                        _download_state["progress"] = f"[{idx}/{total}] {label} — reusing from WAN2.2 dir"
                        await _ssh_exec_async(ssh, f"ln -sf {wan22_candidate} {target} 2>/dev/null || cp {wan22_candidate} {target}")
                        reused = True
                if not reused:
                    size_str = (await _ssh_exec_async(ssh, f"stat -c%s {target} 2>/dev/null || echo 0")).strip()
                    if int(size_str) >= min_bytes:
                        _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    else:
                        _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                        filename = target.split("/")[-1]
                        target_dir = "/".join(target.split("/")[:-1])
                        # Remove stale symlinks before downloading (can't resume through a symlink)
                        await _ssh_exec_async(ssh, f"test -L '{target}' && rm -f '{target}'; true")
                        await _ssh_exec_async(
                            ssh,
                            f"aria2c -x 16 -s 16 -c -o '{filename}' --dir='{target_dir}' '{url}' 2>&1 | tail -3",
                            timeout=dl_timeout,
                        )
                        size_str = (await _ssh_exec_async(ssh, f"stat -c%s {target} 2>/dev/null || echo 0")).strip()
                        if int(size_str) < min_bytes:
                            raise RuntimeError(f"Failed to download {label} (size {size_str} < {min_bytes})")
                        _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

        _download_state["status"] = "completed"
        _download_state["progress"] = "All models downloaded to volume! Ready for training."
        logger.info("Model pre-download complete for %s", model_type)
    except Exception as e:
        _download_state["status"] = "failed"
        _download_state["error"] = str(e)
        _download_state["progress"] = f"Failed: {e}"
        logger.error("Model download failed: %s", e)
    finally:
        if ssh:
            try:
                ssh.close()
            except Exception as e:
                logger.debug("Ignoring error while closing SSH: %s", e)
        # Always terminate the pod so a failed run does not keep billing.
        if pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, pod_id)
                logger.info("Download pod terminated: %s", pod_id)
            except Exception as e:
                logger.warning("Failed to terminate download pod: %s", e)
        _download_state["pod_id"] = None
@router.post("/stop")
async def stop_pod():
    """Terminate the GPU pod and reset the tracked pod state.

    Returns:
        ``{"status": "already_stopped"}`` when no pod is tracked,
        ``{"status": "stopping", ...}`` when a stop is already in progress,
        otherwise ``{"status": "stopped", ...}`` after termination.

    Raises:
        HTTPException: 500 when the terminate call fails.  The pod status is
            then restored to what it was before the attempt.
    """
    _get_api_key()
    pod_id = _pod_state["pod_id"]
    if not pod_id:
        return {"status": "already_stopped"}
    if _pod_state["status"] == "stopping":
        return {"status": "stopping", "message": "Pod is already stopping"}

    # Remember the real status: the pod may still be starting or setting up,
    # and a failed stop must not promote it to "running".
    previous_status = _pod_state["status"]
    _pod_state["status"] = "stopping"
    try:
        logger.info("Stopping pod: %s", pod_id)
        await asyncio.to_thread(runpod.terminate_pod, pod_id)
        _pod_state["pod_id"] = None
        _pod_state["ip"] = None
        _pod_state["ssh_port"] = None
        _pod_state["comfyui_port"] = None
        _pod_state["status"] = "stopped"
        _pod_state["started_at"] = None
        _pod_state["setup_status"] = None
        _save_pod_state()
        logger.info("Pod stopped")
        return {"status": "stopped", "message": "Pod terminated"}
    except Exception as e:
        logger.error("Failed to stop pod: %s", e)
        _pod_state["status"] = previous_status
        raise HTTPException(500, f"Failed to stop pod: {e}") from e
@router.get("/loras")
async def list_pod_loras():
    """Return the LoRA names that the pod's ComfyUI instance reports.

    The names are read from ComfyUI's ``/object_info/LoraLoader`` response.
    When the pod is not running, or the query fails or returns a non-200
    status, an empty list is returned instead of raising.
    """
    pod_ready = _pod_state["status"] == "running" and _pod_state["ip"]
    if not pod_ready:
        return {"loras": [], "message": "Pod not running"}

    comfyui_url = _get_comfyui_url()
    try:
        import httpx

        async with httpx.AsyncClient(timeout=30) as http:
            response = await http.get(f"{comfyui_url}/object_info/LoraLoader")
            if response.status_code == 200:
                loader_info = response.json().get("LoraLoader", {})
                required_inputs = loader_info.get("input", {}).get("required", {})
                # The first element of the "lora_name" entry is the list of choices.
                names = required_inputs.get("lora_name", [[]])[0]
                if not isinstance(names, list):
                    names = []
                return {"loras": names}
    except Exception as exc:
        logger.warning("Failed to list pod LoRAs: %s", exc)
    return {"loras": [], "comfyui_url": comfyui_url}
@router.post("/upload-lora")
async def upload_lora_to_pod(file: UploadFile = File(...)):
    """Upload a LoRA file directly to /runpod-volume/loras/ via SFTP so it persists.

    After the upload the file is symlinked into ComfyUI's ``models/loras``
    directory on the pod.

    Returns:
        ``{"status": "uploaded", "filename": ..., "path": ...}``.

    Raises:
        HTTPException: 400 when the pod is not running or the file is not a
            ``.safetensors`` file; 500 when no SSH IP is known or the upload
            fails.
    """
    import io
    import shlex

    import paramiko

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")
    # The name is client-controlled: drop any directory components so it
    # cannot escape the loras directory.  ``filename`` may also be missing.
    safe_name = Path(file.filename or "").name
    if not safe_name.endswith(".safetensors"):
        raise HTTPException(400, "Only .safetensors files supported")
    ip = _pod_state.get("ip")
    port = _pod_state.get("ssh_port") or 22
    if not ip:
        raise HTTPException(500, "No SSH IP available")
    content = await file.read()
    dest_path = f"/runpod-volume/loras/{safe_name}"
    comfy_link = f"/workspace/ComfyUI/models/loras/{safe_name}"

    def _sftp_upload():
        """Blocking upload + symlink; always releases the SSH connection."""
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(ip, port=port, username="root", timeout=30)
            # Ensure dir exists (reading stdout waits for the command to end)
            client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
            sftp = client.open_sftp()
            try:
                sftp.putfo(io.BytesIO(content), dest_path)
            finally:
                sftp.close()
            # Symlink into ComfyUI.  Paths are quoted because the file name
            # may contain spaces or shell metacharacters.
            client.exec_command(f"ln -sf {shlex.quote(dest_path)} {shlex.quote(comfy_link)}")[1].read()
        finally:
            client.close()

    try:
        await asyncio.to_thread(_sftp_upload)
        logger.info("LoRA uploaded to volume: %s (%d bytes)", safe_name, len(content))
        return {"status": "uploaded", "filename": safe_name, "path": dest_path}
    except Exception as e:
        logger.error("LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e
@router.post("/upload-lora-local")
async def upload_lora_from_local(local_path: str, filename: str | None = None):
    """Upload a LoRA from a local server path directly to the volume via SFTP.

    Args:
        local_path: Path of the source file on this server.
        filename: Optional destination file name; defaults to the source
            file's name.

    Returns:
        ``{"status": "uploaded", "filename": ..., "path": ..., "size_mb": ...}``.

    Raises:
        HTTPException: 400 when the pod is not running or the destination is
            not a ``.safetensors`` name; 404 when the local file is missing;
            500 when no SSH IP is known or the upload fails.
    """
    import shlex

    import paramiko

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")
    src = Path(local_path)
    # is_file() also rejects directories, which SFTP could not upload anyway.
    if not src.is_file():
        raise HTTPException(404, f"Local file not found: {local_path}")
    # Keep only the final path component so the name stays inside loras/.
    dest_name = Path(filename or src.name).name
    if not dest_name.endswith(".safetensors"):
        raise HTTPException(400, "Only .safetensors files supported")
    ip = _pod_state.get("ip")
    port = _pod_state.get("ssh_port") or 22
    if not ip:
        raise HTTPException(500, "No SSH IP available")
    dest_path = f"/runpod-volume/loras/{dest_name}"
    comfy_link = f"/workspace/ComfyUI/models/loras/{dest_name}"

    def _sftp_upload():
        """Blocking upload + symlink; always releases the SSH connection."""
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(ip, port=port, username="root", timeout=30)
            client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
            sftp = client.open_sftp()
            try:
                sftp.put(str(src), dest_path)
            finally:
                sftp.close()
            # Quote the paths: the name may contain shell metacharacters.
            client.exec_command(f"ln -sf {shlex.quote(dest_path)} {shlex.quote(comfy_link)}")[1].read()
        finally:
            client.close()

    try:
        await asyncio.to_thread(_sftp_upload)
        size_mb = src.stat().st_size / 1024 / 1024
        logger.info("LoRA uploaded from local: %s (%.1f MB)", dest_name, size_mb)
        return {"status": "uploaded", "filename": dest_name, "path": dest_path, "size_mb": round(size_mb, 1)}
    except Exception as e:
        logger.error("Local LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e
class PodGenerateRequest(BaseModel):
    """Request body for the ``/generate`` endpoint."""

    # Text conditioning
    prompt: str
    negative_prompt: str = ""

    # Image size and sampler settings
    width: int = 1024
    height: int = 1024
    steps: int = 28
    cfg: float = 3.5
    # A negative seed makes the endpoint pick a random one.
    seed: int = -1

    # Optional LoRAs; the second one is only passed to the wan22 workflow.
    lora_name: str | None = None
    lora_strength: float = 0.85
    lora_name_2: str | None = None
    lora_strength_2: float = 0.85

    # Accepted but not read by the generate endpoint itself.
    character_id: str | None = None
    template_id: str | None = None

    # Used as a sub-directory of the output path and stored in the catalog.
    content_rating: str = "sfw"
# In-memory job tracking for pod generation: job_id -> job record.
# Not persisted anywhere in this module.
_pod_jobs: dict[str, dict[str, Any]] = {}
# Strong references to running poll tasks.  The event loop keeps only weak
# references to tasks, so an unreferenced poller could be garbage collected
# before the image is fetched.
_pod_poll_tasks: set[asyncio.Task] = set()


@router.post("/generate")
async def generate_on_pod(request: PodGenerateRequest):
    """Generate an image using the running pod's ComfyUI.

    Builds a workflow for the pod's model type, queues it on ComfyUI and
    starts a background task that polls for the result.

    Returns:
        ``{"job_id": ..., "status": "running", "seed": ...}``; poll
        ``/jobs/{job_id}`` for completion.

    Raises:
        HTTPException: 400 when the pod is not running; 500 when queueing
            the workflow fails.
    """
    import random

    import httpx

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    job_id = str(uuid.uuid4())[:8]
    # A negative seed requests a random one.
    seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)
    model_type = _pod_state.get("model_type", "flux2")

    # Arguments shared by both workflow builders.
    common_args = {
        "prompt": request.prompt,
        "negative_prompt": request.negative_prompt,
        "width": request.width,
        "height": request.height,
        "steps": request.steps,
        "cfg": request.cfg,
        "seed": seed,
        "lora_name": request.lora_name,
        "lora_strength": request.lora_strength,
    }
    if model_type == "wan22":
        workflow = _build_wan_t2i_workflow(
            **common_args,
            lora_name_2=request.lora_name_2,
            lora_strength_2=request.lora_strength_2,
        )
    else:
        workflow = _build_flux_workflow(**common_args, model_type=model_type)

    comfyui_url = _get_comfyui_url()
    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(f"{comfyui_url}/prompt", json={"prompt": workflow})
            resp.raise_for_status()
            data = resp.json()
            prompt_id = data["prompt_id"]
        now = time.time()
        _pod_jobs[job_id] = {
            "prompt_id": prompt_id,
            "status": "running",
            "seed": seed,
            "created_at": now,
            "started_at": now,
            "positive_prompt": request.prompt,
            "negative_prompt": request.negative_prompt,
            "steps": request.steps,
            "cfg": request.cfg,
            "width": request.width,
            "height": request.height,
        }
        logger.info("Pod generation started: %s -> %s", job_id, prompt_id)
        poll_task = asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))
        # Keep the poller alive until it finishes, then drop the reference.
        _pod_poll_tasks.add(poll_task)
        poll_task.add_done_callback(_pod_poll_tasks.discard)
        return {"job_id": job_id, "status": "running", "seed": seed}
    except Exception as e:
        logger.error("Pod generation failed: %s", e)
        raise HTTPException(500, f"Generation failed: {e}") from e
async def _pod_queue_status_msg(client, comfyui_url: str, elapsed: float) -> str | None:
    """Build a one-line progress message from ComfyUI's ``/queue`` endpoint.

    Returns ``None`` when the queue endpoint does not answer with 200.
    """
    q_resp = await client.get(f"{comfyui_url}/queue")
    if q_resp.status_code != 200:
        return None
    q_data = q_resp.json()
    running = q_data.get("queue_running", [])
    pending = len(q_data.get("queue_pending", []))
    status_msg = f"{int(elapsed)}s elapsed"
    if running:
        # Try to get node execution progress
        try:
            p_resp = await client.get(f"{comfyui_url}/prompt")
            if p_resp.status_code == 200:
                exec_info = p_resp.json().get("exec_info", {})
                if exec_info:
                    status_msg += f" | nodes: {exec_info}"
        except Exception:
            # Extra detail only; the base message is still useful without it.
            pass
        status_msg += " | generating..."
    elif pending:
        status_msg += " | loading models..."
    else:
        status_msg += " | waiting..."
    return status_msg


async def _store_pod_job_image(job_id: str, image_bytes: bytes, content_rating: str) -> None:
    """Write the image to the output dir, mark the job completed, catalog it.

    A cataloging failure is logged and does not undo the completed status.
    """
    from content_engine.config import settings

    output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
    output_dir.mkdir(parents=True, exist_ok=True)
    local_path = output_dir / f"pod_{job_id}.png"
    local_path.write_bytes(image_bytes)

    job_info = _pod_jobs[job_id]
    job_info["status"] = "completed"
    job_info["output_path"] = str(local_path)
    job_info["completed_at"] = time.time()
    logger.info("Pod generation completed: %s -> %s", job_id, local_path)
    try:
        from content_engine.services.catalog import CatalogService

        catalog = CatalogService()
        await catalog.insert_image(
            file_path=str(local_path),
            image_bytes=image_bytes,
            content_rating=content_rating,
            positive_prompt=job_info.get("positive_prompt"),
            negative_prompt=job_info.get("negative_prompt"),
            seed=job_info.get("seed"),
            steps=job_info.get("steps"),
            cfg=job_info.get("cfg"),
            width=job_info.get("width"),
            height=job_info.get("height"),
            generation_backend="runpod-pod",
            generation_time_seconds=time.time() - job_info.get("created_at", time.time()),
        )
        logger.info("Pod image cataloged: %s", job_id)
    except Exception as e:
        logger.warning("Failed to catalog pod image: %s", e)


def _comfy_error_detail(status_info: dict) -> str | None:
    """Extract an exception message from a history ``status`` record, if any.

    NOTE(review): assumes ComfyUI stores messages as ``[event, data]`` pairs
    with an ``"execution_error"`` event carrying ``exception_message`` —
    confirm against the ComfyUI version on the pod.  Any other shape simply
    yields ``None``.
    """
    for message in status_info.get("messages") or []:
        if (
            isinstance(message, (list, tuple))
            and len(message) == 2
            and message[0] == "execution_error"
            and isinstance(message[1], dict)
        ):
            detail = message[1].get("exception_message")
            if detail:
                return str(detail)
    return None


async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
    """Poll ComfyUI for job completion and save the result.

    Polls ``/history/{prompt_id}`` every 2 seconds for up to 15 minutes.  The
    first output image found is downloaded and stored, and the job record in
    ``_pod_jobs`` is marked ``completed``.  The job is marked ``failed`` when
    ComfyUI reports an execution error or when the timeout expires.  A queue
    progress message is refreshed roughly every 15 seconds.
    """
    import httpx

    start = time.time()
    timeout = 900  # 15 min — the first generation also has to load the model
    comfyui_url = _get_comfyui_url()
    last_log_time = 0
    async with httpx.AsyncClient(timeout=60) as client:
        while time.time() - start < timeout:
            try:
                # Log queue progress every 15 seconds and store it in the job
                elapsed = time.time() - start
                if elapsed - last_log_time >= 15:
                    last_log_time = elapsed
                    try:
                        status_msg = await _pod_queue_status_msg(client, comfyui_url, elapsed)
                        if status_msg is not None:
                            _pod_jobs[job_id]["progress_msg"] = status_msg
                            logger.info("Pod gen %s: %s", job_id, status_msg)
                    except Exception as e:
                        # Progress reporting is best effort.
                        logger.debug("Pod gen %s: queue status unavailable: %s", job_id, e)

                resp = await client.get(f"{comfyui_url}/history/{prompt_id}")
                if resp.status_code == 200:
                    data = resp.json()
                    if prompt_id in data:
                        entry = data[prompt_id]
                        # A failed prompt never produces images; without this
                        # check the job would sit "running" until the timeout.
                        status_info = entry.get("status") or {}
                        if status_info.get("status_str") == "error":
                            detail = _comfy_error_detail(status_info)
                            error_text = "ComfyUI reported an execution error"
                            if detail:
                                error_text += f": {detail}"
                            _pod_jobs[job_id]["status"] = "failed"
                            _pod_jobs[job_id]["error"] = error_text
                            logger.error("Pod generation failed: %s: %s", job_id, error_text)
                            return
                        outputs = entry.get("outputs", {})
                        for node_output in outputs.values():
                            images = node_output.get("images")
                            if not images:
                                continue
                            image_info = images[0]
                            params = {"filename": image_info["filename"]}
                            subfolder = image_info.get("subfolder", "")
                            if subfolder:
                                params["subfolder"] = subfolder
                            img_resp = await client.get(f"{comfyui_url}/view", params=params)
                            if img_resp.status_code == 200:
                                await _store_pod_job_image(job_id, img_resp.content, content_rating)
                                return
            except Exception as e:
                logger.debug("Polling pod job: %s", e)
            await asyncio.sleep(2)
    _pod_jobs[job_id]["status"] = "failed"
    _pod_jobs[job_id]["error"] = "Timeout waiting for generation"
    logger.error("Pod generation timed out: %s", job_id)
@router.get("/jobs/{job_id}")
async def get_pod_job(job_id: str):
    """Return the tracked record of a pod generation job.

    Raises:
        HTTPException: 404 when no record exists for ``job_id``.
    """
    if not (record := _pod_jobs.get(job_id)):
        raise HTTPException(404, "Job not found")
    return record
@router.get("/jobs/{job_id}/image")
async def get_pod_job_image(job_id: str):
    """Serve the PNG produced by a completed pod job.

    Raises:
        HTTPException: 404 when the job is unknown, has no output path yet,
            or the file no longer exists on disk.
    """
    from fastapi.responses import FileResponse

    if not (record := _pod_jobs.get(job_id)):
        raise HTTPException(404, "Job not found")
    if not (stored_path := record.get("output_path")):
        raise HTTPException(404, "No image yet")
    image_file = Path(stored_path)
    if image_file.exists():
        return FileResponse(image_file, media_type="image/png")
    raise HTTPException(404, "Image file not found")
def _build_flux_workflow(
prompt: str,
negative_prompt: str,
width: int,
height: int,
steps: int,
cfg: float,
seed: int,
lora_name: str | None,
lora_strength: float,
model_type: str = "flux2",
) -> dict:
"""Build a ComfyUI workflow for FLUX generation.
FLUX.2 Dev uses separate model components (not a single checkpoint):
- UNETLoader for the diffusion model
- CLIPLoader (type=flux2) for the Mistral text encoder
- VAELoader for the autoencoder
"""
if model_type == "flux2":
unet_name = "flux2-dev.safetensors"
clip_type = "flux2"
clip_name = "mistral_3_small_flux2_fp8.safetensors"
else:
unet_name = "flux1-dev.safetensors"
clip_type = "flux"
clip_name = "t5xxl_fp16.safetensors"
# Model node ID references
model_out = ["1", 0] # UNETLoader -> MODEL
clip_out = ["2", 0] # CLIPLoader -> CLIP
vae_out = ["3", 0] # VAELoader -> VAE
workflow = {
# Load diffusion model (UNet)
"1": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": unet_name,
"weight_dtype": "fp8_e4m3fn",
},
},
# Load text encoder
"2": {
"class_type": "CLIPLoader",
"inputs": {
"clip_name": clip_name,
"type": clip_type,
},
},
# Load VAE
"3": {
"class_type": "VAELoader",
"inputs": {"vae_name": "ae.safetensors"},
},
# Positive prompt
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": prompt,
"clip": clip_out,
},
},
# Negative prompt
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": negative_prompt or "",
"clip": clip_out,
},
},
# Empty latent
"5": {
"class_type": "EmptyLatentImage",
"inputs": {
"width": width,
"height": height,
"batch_size": 1,
},
},
# Sampler
"10": {
"class_type": "KSampler",
"inputs": {
"seed": seed,
"steps": steps,
"cfg": cfg,
"sampler_name": "euler",
"scheduler": "simple",
"denoise": 1.0,
"model": model_out,
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
},
},
# Decode
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["10", 0],
"vae": vae_out,
},
},
# Save
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "flux_pod",
"images": ["8", 0],
},
},
}
# Add LoRA if specified
if lora_name:
workflow["20"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora_name,
"strength_model": lora_strength,
"strength_clip": lora_strength,
"model": model_out,
"clip": clip_out,
},
}
# Rewire sampler and text encoders to use LoRA output
workflow["10"]["inputs"]["model"] = ["20", 0]
workflow["6"]["inputs"]["clip"] = ["20", 1]
workflow["7"]["inputs"]["clip"] = ["20", 1]
return workflow
def _build_wan_t2i_workflow(
prompt: str,
negative_prompt: str,
width: int,
height: int,
steps: int,
cfg: float,
seed: int,
lora_name: str | None,
lora_strength: float,
lora_name_2: str | None = None,
lora_strength_2: float = 0.85,
) -> dict:
"""Build a ComfyUI workflow for WAN 2.2 Remix β€” dual-DiT MoE split-step.
Based on the WAN 2.2 Remix workflow from CivitAI:
- Two UNETLoaders: high-noise + low-noise Remix models (fp8)
- wanBlockSwap on both (offloads blocks to CPU for 24GB GPUs)
- ModelSamplingSD3 with shift=5 on both
- Dual KSamplerAdvanced: high-noise runs first half, low-noise finishes
- CLIPLoader (type=wan) + CLIPTextEncode for T5 text encoding
- Standard VAELoader + VAEDecode
- EmptyHunyuanLatentVideo for latent (1 frame = image, 81+ = video)
"""
high_dit = "wan22_remix_t2v_high_fp8.safetensors"
low_dit = "wan22_remix_t2v_low_fp8.safetensors"
t5_name = "umt5_xxl_fp8_e4m3fn_scaled.safetensors"
vae_name = "wan_2.1_vae.safetensors"
total_steps = steps # default 8
split_step = total_steps // 2 # high-noise does first half, low-noise does rest
shift = 5.0
block_swap = 20 # blocks offloaded to CPU (0-40, higher = less VRAM)
workflow = {
# ── Load high-noise DiT ──
"1": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": high_dit,
"weight_dtype": "fp8_e4m3fn",
},
},
# ── Load low-noise DiT ──
"2": {
"class_type": "UNETLoader",
"inputs": {
"unet_name": low_dit,
"weight_dtype": "fp8_e4m3fn",
},
},
# ── wanBlockSwap on high-noise (VRAM optimization) ──
"11": {
"class_type": "wanBlockSwap",
"inputs": {
"model": ["1", 0],
"blocks_to_swap": block_swap,
"offload_img_emb": False,
"offload_txt_emb": False,
},
},
# ── wanBlockSwap on low-noise ──
"12": {
"class_type": "wanBlockSwap",
"inputs": {
"model": ["2", 0],
"blocks_to_swap": block_swap,
"offload_img_emb": False,
"offload_txt_emb": False,
},
},
# ── ModelSamplingSD3 shift on high-noise ──
"13": {
"class_type": "ModelSamplingSD3",
"inputs": {
"model": ["11", 0],
"shift": shift,
},
},
# ── ModelSamplingSD3 shift on low-noise ──
"14": {
"class_type": "ModelSamplingSD3",
"inputs": {
"model": ["12", 0],
"shift": shift,
},
},
# ── Load T5 text encoder ──
"3": {
"class_type": "CLIPLoader",
"inputs": {
"clip_name": t5_name,
"type": "wan",
},
},
# ── Positive prompt ──
"6": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": prompt,
"clip": ["3", 0],
},
},
# ── Negative prompt ──
"7": {
"class_type": "CLIPTextEncode",
"inputs": {
"text": negative_prompt or "",
"clip": ["3", 0],
},
},
# ── VAE ──
"4": {
"class_type": "VAELoader",
"inputs": {"vae_name": vae_name},
},
# ── Empty latent (1 frame = single image) ──
"5": {
"class_type": "EmptyHunyuanLatentVideo",
"inputs": {
"width": width,
"height": height,
"length": 1,
"batch_size": 1,
},
},
# ── KSamplerAdvanced #1: High-noise model (first half of steps) ──
"15": {
"class_type": "KSamplerAdvanced",
"inputs": {
"model": ["13", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0],
"add_noise": "enable",
"noise_seed": seed,
"steps": total_steps,
"cfg": cfg,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": split_step,
"return_with_leftover_noise": "enable",
},
},
# ── KSamplerAdvanced #2: Low-noise model (second half of steps) ──
"16": {
"class_type": "KSamplerAdvanced",
"inputs": {
"model": ["14", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["15", 0],
"add_noise": "disable",
"noise_seed": seed,
"steps": total_steps,
"cfg": cfg,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": split_step,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
},
},
# ── VAE Decode ──
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": ["16", 0],
"vae": ["4", 0],
},
},
# ── Save Image ──
"9": {
"class_type": "SaveImage",
"inputs": {
"filename_prefix": "wan_remix_pod",
"images": ["8", 0],
},
},
}
# Add LoRA(s) to both models if specified β€” chained: DiT β†’ LoRA1 β†’ LoRA2 β†’ Sampler
if lora_name:
# LoRA 1 (body) on high-noise and low-noise DiT
workflow["20"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora_name,
"strength_model": lora_strength,
"strength_clip": 1.0,
"model": ["13", 0],
"clip": ["3", 0],
},
}
workflow["21"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora_name,
"strength_model": lora_strength,
"strength_clip": 1.0,
"model": ["14", 0],
"clip": ["3", 0],
},
}
# Determine what the samplers and CLIP read from (LoRA2 if present, else LoRA1)
high_model_out = ["20", 0]
low_model_out = ["21", 0]
clip_out = ["20", 1]
if lora_name_2:
# LoRA 2 (face) chained after LoRA 1 on both models
workflow["22"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora_name_2,
"strength_model": lora_strength_2,
"strength_clip": 1.0,
"model": ["20", 0],
"clip": ["20", 1],
},
}
workflow["23"] = {
"class_type": "LoraLoader",
"inputs": {
"lora_name": lora_name_2,
"strength_model": lora_strength_2,
"strength_clip": 1.0,
"model": ["21", 0],
"clip": ["21", 1],
},
}
high_model_out = ["22", 0]
low_model_out = ["23", 0]
clip_out = ["22", 1]
# Rewire samplers and CLIP encoding
workflow["15"]["inputs"]["model"] = high_model_out
workflow["16"]["inputs"]["model"] = low_model_out
workflow["6"]["inputs"]["clip"] = clip_out
workflow["7"]["inputs"]["clip"] = clip_out
return workflow