"""RunPod Pod management routes — start/stop GPU pods for generation. Starts a persistent ComfyUI pod with network volume access. Models and LoRAs are loaded from the shared network volume. """ from __future__ import annotations import asyncio import json import logging import os import time import uuid from pathlib import Path from typing import Any import runpod from fastapi import APIRouter, File, HTTPException, UploadFile from pydantic import BaseModel logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/pod", tags=["pod"]) # Persist pod state to disk so it survives server restarts _POD_STATE_FILE = Path(__file__).parent.parent.parent.parent / "pod_state.json" def _save_pod_state(): """Save pod state to disk.""" try: data = {k: v for k, v in _pod_state.items() if k != "setup_status"} _POD_STATE_FILE.write_text(json.dumps(data)) except Exception as e: logger.warning("Failed to save pod state: %s", e) def _load_pod_state(): """Load pod state from disk on startup.""" try: if _POD_STATE_FILE.exists(): data = json.loads(_POD_STATE_FILE.read_text()) for k, v in data.items(): if k in _pod_state: _pod_state[k] = v logger.info("Restored pod state: pod_id=%s status=%s", _pod_state.get("pod_id"), _pod_state.get("status")) except Exception as e: logger.warning("Failed to load pod state: %s", e) def _get_volume_config() -> tuple[str, str]: """Get network volume config at runtime (after dotenv loads).""" return ( os.environ.get("RUNPOD_VOLUME_ID", ""), os.environ.get("RUNPOD_VOLUME_DC", ""), ) # Docker image — PyTorch base with CUDA, we install ComfyUI ourselves DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04" # Pod state _pod_state = { "pod_id": None, "status": "stopped", # stopped, starting, setting_up, running, stopping "ip": None, "ssh_port": None, "comfyui_port": None, "gpu_type": "NVIDIA RTX A6000", "model_type": "flux2", "started_at": None, "cost_per_hour": 0.76, "setup_status": None, } _load_pod_state() # GPU options (same as 
training) GPU_OPTIONS = { "NVIDIA A40": {"name": "A40 48GB", "vram": 48, "cost": 0.64}, "NVIDIA RTX A6000": {"name": "RTX A6000 48GB", "vram": 48, "cost": 0.76}, "NVIDIA L40": {"name": "L40 48GB", "vram": 48, "cost": 0.89}, "NVIDIA L40S": {"name": "L40S 48GB", "vram": 48, "cost": 1.09}, "NVIDIA A100-SXM4-80GB": {"name": "A100 SXM 80GB", "vram": 80, "cost": 1.64}, "NVIDIA A100 80GB PCIe": {"name": "A100 PCIe 80GB", "vram": 80, "cost": 1.89}, "NVIDIA H100 80GB HBM3": {"name": "H100 80GB", "vram": 80, "cost": 3.89}, "NVIDIA GeForce RTX 5090": {"name": "RTX 5090 32GB", "vram": 32, "cost": 0.69}, "NVIDIA GeForce RTX 4090": {"name": "RTX 4090 24GB", "vram": 24, "cost": 0.44}, "NVIDIA GeForce RTX 3090": {"name": "RTX 3090 24GB", "vram": 24, "cost": 0.22}, } def _get_comfyui_url() -> str | None: """Get the ComfyUI URL via RunPod's HTTPS proxy. RunPod HTTP ports are only accessible through their proxy at https://{pod_id}-{private_port}.proxy.runpod.net The raw IP:port from the API is an internal address, not publicly routable. 
""" pod_id = _pod_state.get("pod_id") if pod_id: return f"https://{pod_id}-8188.proxy.runpod.net" return None def _get_api_key() -> str: key = os.environ.get("RUNPOD_API_KEY") if not key: raise HTTPException(503, "RUNPOD_API_KEY not configured") runpod.api_key = key return key class StartPodRequest(BaseModel): gpu_type: str = "NVIDIA RTX A6000" model_type: str = "flux2" class PodStatus(BaseModel): status: str pod_id: str | None = None ip: str | None = None port: int | None = None gpu_type: str | None = None model_type: str | None = None cost_per_hour: float | None = None setup_status: str | None = None uptime_minutes: float | None = None comfyui_url: str | None = None @router.get("/status", response_model=PodStatus) async def get_pod_status(): """Get current pod status.""" _get_api_key() if _pod_state["pod_id"]: try: pod = await asyncio.wait_for( asyncio.to_thread(runpod.get_pod, _pod_state["pod_id"]), timeout=10, ) if pod: desired = pod.get("desiredStatus", "") if desired == "RUNNING": runtime = pod.get("runtime") or {} ports = runtime.get("ports") or [] for p in ports: if p.get("privatePort") == 22: _pod_state["ssh_ip"] = p.get("ip") _pod_state["ssh_port"] = p.get("publicPort") if p.get("privatePort") == 8188: _pod_state["comfyui_ip"] = p.get("ip") _pod_state["comfyui_port"] = p.get("publicPort") # Use SSH IP as the main IP for display _pod_state["ip"] = _pod_state.get("ssh_ip") or _pod_state.get("comfyui_ip") elif desired == "EXITED": _pod_state["status"] = "stopped" _pod_state["pod_id"] = None else: _pod_state["status"] = "stopped" _pod_state["pod_id"] = None except asyncio.TimeoutError: logger.warning("RunPod API timeout checking pod status") except Exception as e: logger.warning("Failed to check pod: %s", e) uptime = None if _pod_state["started_at"] and _pod_state["status"] in ("running", "setting_up"): uptime = (time.time() - _pod_state["started_at"]) / 60 comfyui_url = _get_comfyui_url() return PodStatus( status=_pod_state["status"], 
pod_id=_pod_state["pod_id"], ip=_pod_state["ip"], port=_pod_state.get("comfyui_port"), gpu_type=_pod_state["gpu_type"], model_type=_pod_state.get("model_type", "flux2"), cost_per_hour=_pod_state["cost_per_hour"], setup_status=_pod_state.get("setup_status"), uptime_minutes=uptime, comfyui_url=comfyui_url, ) @router.get("/gpu-options") async def list_gpu_options(): """List available GPU types.""" return {"gpus": GPU_OPTIONS} @router.get("/model-options") async def list_model_options(): """List available model types for the pod.""" return { "models": { "flux2": {"name": "FLUX.2 Dev", "description": "Best for realistic txt2img (requires 48GB+ VRAM)", "use_case": "txt2img"}, "flux1": {"name": "FLUX.1 Dev", "description": "Previous gen FLUX txt2img", "use_case": "txt2img"}, "wan22": {"name": "WAN 2.2 Remix", "description": "Realistic generation — dual-DiT MoE split-step (NSFW OK)", "use_case": "txt2img"}, "wan22_i2v": {"name": "WAN 2.2 I2V", "description": "Image-to-video generation", "use_case": "img2video"}, "wan22_animate": {"name": "WAN 2.2 Animate", "description": "Dance/motion transfer — animate a character from a driving video", "use_case": "animate"}, } } @router.post("/start") async def start_pod(request: StartPodRequest): """Start a GPU pod with ComfyUI for generation.""" _get_api_key() if _pod_state["status"] in ("running", "setting_up"): return {"status": "already_running", "pod_id": _pod_state["pod_id"]} if _pod_state["status"] == "starting": return {"status": "starting", "message": "Pod is already starting"} gpu_info = GPU_OPTIONS.get(request.gpu_type) if not gpu_info: raise HTTPException(400, f"Unknown GPU type: {request.gpu_type}") _pod_state["status"] = "starting" _pod_state["gpu_type"] = request.gpu_type _pod_state["cost_per_hour"] = gpu_info["cost"] _pod_state["model_type"] = request.model_type _pod_state["setup_status"] = "Creating pod..." 
try: logger.info("Starting RunPod with %s for %s...", request.gpu_type, request.model_type) pod_kwargs = { "container_disk_in_gb": 30, "ports": "22/tcp,8188/http", "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'", } volume_id, volume_dc = _get_volume_config() if volume_id: pod_kwargs["network_volume_id"] = volume_id if volume_dc: pod_kwargs["data_center_id"] = volume_dc logger.info("Using network volume: %s (DC: %s)", volume_id, volume_dc) else: pod_kwargs["volume_in_gb"] = 75 logger.warning("No network volume configured — using temporary volume") pod = await asyncio.to_thread( runpod.create_pod, f"comfyui-gen-{request.model_type}", DOCKER_IMAGE, request.gpu_type, **pod_kwargs, ) _pod_state["pod_id"] = pod["id"] _pod_state["started_at"] = time.time() _save_pod_state() logger.info("Pod created: %s", pod["id"]) asyncio.create_task(_wait_and_setup_pod(pod["id"], request.model_type)) return { "status": "starting", "pod_id": pod["id"], "message": f"Starting {gpu_info['name']} pod (~5-8 min for setup)", } except Exception as e: _pod_state["status"] = "stopped" _pod_state["setup_status"] = None logger.error("Failed to start pod: %s", e) raise HTTPException(500, f"Failed to start pod: {e}") async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600): """Wait for pod to be ready, then install ComfyUI and link models via SSH.""" start = time.time() ssh_host = None ssh_port = None # Phase 1: Wait for SSH to be available _pod_state["setup_status"] = "Waiting for pod to start..." 
while time.time() - start < timeout: try: pod = await asyncio.to_thread(runpod.get_pod, pod_id) if pod and pod.get("desiredStatus") == "RUNNING": runtime = pod.get("runtime") or {} ports = runtime.get("ports") or [] for p in ports: if p.get("privatePort") == 22: ssh_host = p.get("ip") ssh_port = p.get("publicPort") _pod_state["ssh_ip"] = ssh_host _pod_state["ssh_port"] = ssh_port _pod_state["ip"] = ssh_host if p.get("privatePort") == 8188: _pod_state["comfyui_ip"] = p.get("ip") _pod_state["comfyui_port"] = p.get("publicPort") if ssh_host and ssh_port: break except Exception as e: logger.debug("Waiting for pod: %s", e) await asyncio.sleep(5) if not ssh_host or not ssh_port: logger.error("Pod did not become ready within %ds", timeout) _pod_state["status"] = "stopped" _pod_state["setup_status"] = "Failed: pod did not start" return # Phase 2: SSH in and set up ComfyUI _pod_state["status"] = "setting_up" _pod_state["setup_status"] = "Connecting via SSH..." import paramiko async def _ssh_connect_new() -> "paramiko.SSHClient": """Create a fresh SSH connection to the pod.""" client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) for attempt in range(10): try: await asyncio.to_thread( client.connect, ssh_host, port=int(ssh_port), username="root", password="runpod", timeout=15, banner_timeout=30, ) client.get_transport().set_keepalive(30) return client except Exception: if attempt == 9: raise await asyncio.sleep(5) raise RuntimeError("SSH connection failed after retries") async def _ssh_exec_r(cmd: str, timeout: int = 120) -> str: """Execute SSH command, reconnecting once if the session dropped.""" nonlocal ssh try: t = ssh.get_transport() if t is None or not t.is_active(): logger.info("SSH session dropped, reconnecting...") ssh = await _ssh_connect_new() return await _ssh_exec_async(ssh, cmd, timeout) except Exception as e: if "not active" in str(e).lower() or "session" in str(e).lower(): logger.info("SSH error '%s', reconnecting and 
retrying...", e) ssh = await _ssh_connect_new() return await _ssh_exec_async(ssh, cmd, timeout) raise for attempt in range(30): try: ssh = await _ssh_connect_new() break except Exception: if attempt == 29: _pod_state["setup_status"] = "Failed: SSH connection error" _pod_state["status"] = "stopped" return await asyncio.sleep(5) try: # Symlink network volume volume_id, _ = _get_volume_config() if volume_id: await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models /runpod-volume/loras") await _ssh_exec_async(ssh, "rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models") # Install ComfyUI (cache on volume for reuse) comfy_dir = "/workspace/ComfyUI" _pod_state["setup_status"] = "Installing ComfyUI..." comfy_exists = (await _ssh_exec_async(ssh, f"test -f {comfy_dir}/main.py && echo EXISTS || echo MISSING")).strip() if comfy_exists == "EXISTS": logger.info("ComfyUI already installed") _pod_state["setup_status"] = "ComfyUI found, updating..." await _ssh_exec_async(ssh, f"cd {comfy_dir} && git pull 2>&1 | tail -3", timeout=120) else: # Check volume cache vol_comfy = (await _ssh_exec_async(ssh, "test -f /runpod-volume/ComfyUI/main.py && echo EXISTS || echo MISSING")).strip() if vol_comfy == "EXISTS": _pod_state["setup_status"] = "Restoring ComfyUI from volume..." await _ssh_exec_async(ssh, f"cp -r /runpod-volume/ComfyUI {comfy_dir}", timeout=300) else: _pod_state["setup_status"] = "Cloning ComfyUI (first time, ~2 min)..." await _ssh_exec_async(ssh, f"cd /workspace && git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git", timeout=300) await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600) # Cache to volume volume_id, _ = _get_volume_config() if volume_id: await _ssh_exec_async(ssh, f"cp -r {comfy_dir} /runpod-volume/ComfyUI", timeout=300) # Install pip deps that aren't in ComfyUI requirements _pod_state["setup_status"] = "Installing dependencies..." 
await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600) await _ssh_exec_async(ssh, "pip install aiohttp einops sqlalchemy 2>&1 | tail -3", timeout=120) # Symlink models into ComfyUI directories _pod_state["setup_status"] = "Linking models..." await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/checkpoints {comfy_dir}/models/vae {comfy_dir}/models/loras {comfy_dir}/models/text_encoders") if model_type == "flux2": # FLUX.2 Dev — separate UNet, text encoder, and VAE await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models") await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/flux2-dev.safetensors {comfy_dir}/models/diffusion_models/flux2-dev.safetensors") await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/ae.safetensors {comfy_dir}/models/vae/ae.safetensors") # Text encoder — use Comfy-Org's pre-converted single-file version # (HF sharded format is incompatible with ComfyUI's CLIPLoader) te_file = "/runpod-volume/models/mistral_3_small_flux2_fp8.safetensors" te_exists = (await _ssh_exec_async(ssh, f"test -f {te_file} && echo EXISTS || echo MISSING")).strip() if te_exists != "EXISTS": _pod_state["setup_status"] = "Downloading FLUX.2 text encoder (~12GB, first time only)..." 
await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download hf_hub_download( repo_id='Comfy-Org/flux2-dev', filename='split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', local_dir='/tmp/flux2_te', ) import shutil shutil.move('/tmp/flux2_te/split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', '{te_file}') print('Text encoder downloaded') " 2>&1 | tail -5""", timeout=1800) await _ssh_exec_async(ssh, f"ln -sf {te_file} {comfy_dir}/models/text_encoders/mistral_3_small_flux2_fp8.safetensors") # Remove old sharded loader patch if present await _ssh_exec_async(ssh, f"rm -f {comfy_dir}/custom_nodes/sharded_loader.py") elif model_type == "flux1": await _ssh_exec_async(ssh, f"ln -sf /workspace/models/flux1-dev.safetensors {comfy_dir}/models/checkpoints/flux1-dev.safetensors") await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors") await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors") await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors") elif model_type == "z_image": # Z-Image Turbo — 6B param model by Tongyi-MAI, runs in 16GB VRAM z_dir = "/runpod-volume/models/z_image" await _ssh_exec_async(ssh, f"mkdir -p {z_dir}") await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) # Delete FLUX.2 from volume to free space _pod_state["setup_status"] = "Cleaning up FLUX.2 from volume..." 
await _ssh_exec_async(ssh, "rm -rf /runpod-volume/models/FLUX.2-dev /runpod-volume/models/mistral_3_small_flux2_fp8.safetensors 2>/dev/null; echo done") # Download diffusion model (~12GB) diff_model = f"{z_dir}/z_image_turbo_bf16.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {diff_model} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": _pod_state["setup_status"] = "Downloading Z-Image Turbo diffusion model (~12GB)..." await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import shutil, os p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/diffusion_models/z_image_turbo_bf16.safetensors', local_dir='/tmp/z_image') shutil.move(p, '{diff_model}') print('Diffusion model downloaded') " 2>&1 | tail -5""", timeout=3600) # Download text encoder (~8GB Qwen 3 4B) te_model = f"{z_dir}/qwen_3_4b.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {te_model} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": _pod_state["setup_status"] = "Downloading Z-Image text encoder (~8GB)..." await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import shutil p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/text_encoders/qwen_3_4b.safetensors', local_dir='/tmp/z_image') shutil.move(p, '{te_model}') print('Text encoder downloaded') " 2>&1 | tail -5""", timeout=3600) # Download VAE (~335MB) vae_model = f"{z_dir}/ae.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {vae_model} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": _pod_state["setup_status"] = "Downloading Z-Image VAE..." 
await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import shutil p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/vae/ae.safetensors', local_dir='/tmp/z_image') shutil.move(p, '{vae_model}') print('VAE downloaded') " 2>&1 | tail -5""", timeout=600) # Symlink into ComfyUI directories await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders {comfy_dir}/models/vae") await _ssh_exec_async(ssh, f"ln -sf {diff_model} {comfy_dir}/models/diffusion_models/z_image_turbo_bf16.safetensors") await _ssh_exec_async(ssh, f"ln -sf {te_model} {comfy_dir}/models/text_encoders/qwen_3_4b.safetensors") await _ssh_exec_async(ssh, f"ln -sf {vae_model} {comfy_dir}/models/vae/ae_z_image.safetensors") elif model_type == "wan22": # WAN 2.2 Remix NSFW — dual-DiT MoE split-step for realistic generation wan_dir = "/workspace/models/WAN2.2" await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}") civitai_token = os.environ.get("CIVITAI_API_TOKEN", "") token_param = f"&token={civitai_token}" if civitai_token else "" # CivitAI Remix models (fp8 ~14GB each) civitai_models = { "Remix T2V High-noise": { "path": f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors", "url": f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}", }, "Remix T2V Low-noise": { "path": f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors", "url": f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}", }, } # HuggingFace models (T5 fp8 ~7GB, VAE ~1GB) hf_models = { "T5 text encoder (fp8)": { "path": f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors", "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "filename": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", }, "VAE": { "path": f"{wan_dir}/wan_2.1_vae.safetensors", "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "filename": "split_files/vae/wan_2.1_vae.safetensors", }, } # 
Download CivitAI Remix models for label, info in civitai_models.items(): exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() if exists == "EXISTS": logger.info("WAN 2.2 %s already cached", label) else: _pod_state["setup_status"] = f"Downloading {label} (~14GB)..." await _ssh_exec_async(ssh, f"wget -q -O '{info['path']}' '{info['url']}'", timeout=1800) # Verify download check = (await _ssh_exec_async(ssh, f"test -f {info['path']} && stat -c%s {info['path']} || echo 0")).strip() if check == "0" or int(check) < 1000000: logger.error("Failed to download %s (size: %s). CivitAI API token may be required.", label, check) _pod_state["setup_status"] = f"Failed: {label} download failed. Set CIVITAI_API_TOKEN env var for NSFW models." return # Download HuggingFace models await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) for label, info in hf_models.items(): exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() if exists == "EXISTS": logger.info("WAN 2.2 %s already cached", label) else: _pod_state["setup_status"] = f"Downloading {label}..." 
await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import os, shutil hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}') downloaded = os.path.join('{wan_dir}', '{info['filename']}') target = '{info['path']}' if os.path.exists(downloaded) and downloaded != target: os.makedirs(os.path.dirname(target), exist_ok=True) shutil.move(downloaded, target) print('Downloaded {label}') " 2>&1 | tail -5""", timeout=1800) # Symlink models into ComfyUI await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders") await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_high_fp8.safetensors {comfy_dir}/models/diffusion_models/") await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_low_fp8.safetensors {comfy_dir}/models/diffusion_models/") await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan_2.1_vae.safetensors {comfy_dir}/models/vae/") await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors {comfy_dir}/models/text_encoders/") # Install wanBlockSwap custom node (VRAM optimization for dual-DiT on 24GB GPUs) _pod_state["setup_status"] = "Installing WAN 2.2 custom nodes..." blockswap_dir = f"{comfy_dir}/custom_nodes/ComfyUI-wanBlockswap" blockswap_exists = (await _ssh_exec_async(ssh, f"test -d {blockswap_dir} && echo EXISTS || echo MISSING")).strip() if blockswap_exists != "EXISTS": await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/orssorbit/ComfyUI-wanBlockswap.git", timeout=120) elif model_type == "wan22_i2v": # WAN 2.2 Image-to-Video (14B params) — full model snapshot wan_dir = "/workspace/models/Wan2.2-I2V-A14B" wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip() if wan_exists != "EXISTS": _pod_state["setup_status"] = "Downloading WAN 2.2 I2V model (~28GB, first time only)..." 
await _ssh_exec_async(ssh, f"pip install huggingface_hub 2>&1 | tail -1", timeout=60) await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import snapshot_download snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt']) print('WAN 2.2 I2V downloaded') " 2>&1 | tail -10""", timeout=3600) await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models") await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B") await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B") # Install ComfyUI-WanVideoWrapper custom nodes _pod_state["setup_status"] = "Installing WAN 2.2 ComfyUI nodes..." wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper" wan_nodes_exist = (await _ssh_exec_async(ssh, f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip() if wan_nodes_exist != "EXISTS": await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120) await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) elif model_type == "wan22_animate": # WAN 2.2 Animate (14B fp8) — dance/motion transfer via pose skeleton animate_dir = "/workspace/models/WAN2.2-Animate" wan22_dir = "/workspace/models/WAN2.2" await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}") await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) # Download main Animate model (~28GB bf16 — only version available) animate_model = f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {animate_model} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": _pod_state["setup_status"] = "Downloading WAN 2.2 Animate model (~28GB, first time only)..." 
await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import os, shutil hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors', local_dir='{animate_dir}') src = os.path.join('{animate_dir}', 'split_files', 'diffusion_models', 'wan2.2_animate_14B_bf16.safetensors') if os.path.exists(src): shutil.move(src, '{animate_model}') print('Animate model downloaded') " 2>&1 | tail -5""", timeout=7200) # CLIP Vision H (~2.5GB) — ViT-H vision encoder clip_vision_target = f"{animate_dir}/clip_vision_h.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {clip_vision_target} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": _pod_state["setup_status"] = "Downloading CLIP Vision H (~2.5GB)..." await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import os, shutil result = hf_hub_download('h94/IP-Adapter', 'models/image_encoder/model.safetensors', local_dir='{animate_dir}/tmp_clip') shutil.move(result, '{clip_vision_target}') shutil.rmtree('{animate_dir}/tmp_clip', ignore_errors=True) print('CLIP Vision H downloaded') " 2>&1 | tail -5""", timeout=1800) # VAE — reuse from WAN2.2 dir if available, else download (~1GB) vae_target = f"{animate_dir}/wan_2.1_vae.safetensors" exists = (await _ssh_exec_async(ssh, f"test -f {vae_target} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": vae_from_wan22 = (await _ssh_exec_async(ssh, f"test -f {wan22_dir}/wan_2.1_vae.safetensors && echo EXISTS || echo MISSING")).strip() if vae_from_wan22 == "EXISTS": await _ssh_exec_async(ssh, f"ln -sf {wan22_dir}/wan_2.1_vae.safetensors {vae_target}") else: _pod_state["setup_status"] = "Downloading VAE (~1GB)..." 
await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download import os, shutil hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/vae/wan_2.1_vae.safetensors', local_dir='{animate_dir}') src = os.path.join('{animate_dir}', 'split_files', 'vae', 'wan_2.1_vae.safetensors') if os.path.exists(src): shutil.move(src, '{vae_target}') print('VAE downloaded') " 2>&1 | tail -5""", timeout=600) # UMT5 T5 encoder fp8 (non-scaled) — use Kijai/WanVideo_comfy version # which is compatible with LoadWanVideoT5TextEncoder (scaled_fp8 is not supported) t5_filename = "umt5-xxl-enc-fp8_e4m3fn.safetensors" t5_target = f"{animate_dir}/{t5_filename}" t5_comfy_path = f"{comfy_dir}/models/text_encoders/{t5_filename}" t5_in_comfy = (await _ssh_exec_async(ssh, f"test -f {t5_comfy_path} && echo EXISTS || echo MISSING")).strip() t5_in_vol = (await _ssh_exec_async(ssh, f"test -f {t5_target} && echo EXISTS || echo MISSING")).strip() if t5_in_comfy != "EXISTS" and t5_in_vol != "EXISTS": _pod_state["setup_status"] = "Downloading UMT5 text encoder (~6.3GB, first time only)..." 
await _ssh_exec_async(ssh, f"""python -c " from huggingface_hub import hf_hub_download hf_hub_download('Kijai/WanVideo_comfy', '{t5_filename}', local_dir='{animate_dir}') print('UMT5 text encoder downloaded') " 2>&1 | tail -5""", timeout=1800) t5_in_vol = "EXISTS" # Symlink models into ComfyUI directories await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/vae {comfy_dir}/models/clip_vision {comfy_dir}/models/text_encoders") await _ssh_exec_async(ssh, f"ln -sf {animate_model} {comfy_dir}/models/diffusion_models/") await _ssh_exec_async(ssh, f"ln -sf {vae_target} {comfy_dir}/models/vae/") await _ssh_exec_async(ssh, f"ln -sf {clip_vision_target} {comfy_dir}/models/clip_vision/") if t5_in_vol == "EXISTS" and t5_in_comfy != "EXISTS": await _ssh_exec_async(ssh, f"ln -sf {t5_target} {t5_comfy_path}") # Reconnect SSH before custom node setup — connection may have dropped during long downloads ssh = await _ssh_connect_new() # Install required custom nodes _pod_state["setup_status"] = "Installing WAN Animate custom nodes..." # ComfyUI-WanVideoWrapper (WanVideoAnimateEmbeds, WanVideoSampler, etc.) 
wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper" exists = (await _ssh_exec_r(f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120) await _ssh_exec_r(f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) # ComfyUI-VideoHelperSuite (VHS_LoadVideo, VHS_VideoCombine) vhs_dir = f"{comfy_dir}/custom_nodes/ComfyUI-VideoHelperSuite" exists = (await _ssh_exec_r(f"test -d {vhs_dir} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", timeout=120) await _ssh_exec_r(f"cd {vhs_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) # comfyui_controlnet_aux (DWPreprocessor for pose extraction) aux_dir = f"{comfy_dir}/custom_nodes/comfyui_controlnet_aux" exists = (await _ssh_exec_r(f"test -d {aux_dir} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Fannovel16/comfyui_controlnet_aux.git", timeout=120) await _ssh_exec_r(f"cd {aux_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) # ComfyUI-KJNodes (ImageResizeKJv2 used in animate workflow) kj_dir = f"{comfy_dir}/custom_nodes/ComfyUI-KJNodes" exists = (await _ssh_exec_r(f"test -d {kj_dir} && echo EXISTS || echo MISSING")).strip() if exists != "EXISTS": await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-KJNodes.git", timeout=120) await _ssh_exec_r(f"cd {kj_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) # Symlink all LoRAs from volume await _ssh_exec_r(f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" 
{comfy_dir}/models/loras/; done") # Start ComfyUI in background (fire-and-forget — don't wait for output) _pod_state["setup_status"] = "Starting ComfyUI..." await asyncio.to_thread( _ssh_exec_fire_and_forget, ssh, f"cd {comfy_dir} && python main.py --listen 0.0.0.0 --port 8188 --fp8_e4m3fn-unet > /tmp/comfyui.log 2>&1", ) await asyncio.sleep(2) # Give it a moment to start # Wait for ComfyUI HTTP to respond _pod_state["setup_status"] = "Waiting for ComfyUI to load model..." import httpx comfyui_url = _get_comfyui_url() for attempt in range(120): # Up to 10 minutes try: async with httpx.AsyncClient(timeout=5) as client: resp = await client.get(f"{comfyui_url}/system_stats") if resp.status_code == 200: _pod_state["status"] = "running" _pod_state["setup_status"] = "Ready" _save_pod_state() logger.info("ComfyUI ready at %s", comfyui_url) return except Exception: pass await asyncio.sleep(5) # If we get here, ComfyUI didn't start # Check the log for errors log_tail = await _ssh_exec_async(ssh, "tail -20 /tmp/comfyui.log") logger.error("ComfyUI didn't start. Log: %s", log_tail) _pod_state["setup_status"] = f"ComfyUI failed to start. Check logs." 
_pod_state["status"] = "setting_up" # Keep pod running so user can debug except Exception as e: import traceback err_msg = f"{type(e).__name__}: {e}" logger.error("Pod setup failed: %s\n%s", err_msg, traceback.format_exc()) _pod_state["setup_status"] = f"Setup failed: {err_msg}" _pod_state["status"] = "setting_up" # Keep pod running so user can debug finally: try: ssh.close() except Exception: pass
# NOTE(review): the line above is the unmodified tail of a coroutine that starts
# before this section; it is carried over verbatim because its beginning is not
# visible here. Everything below is complete, reformatted code.


# --- SSH helpers ---

def _ssh_exec(ssh, cmd: str, timeout: int = 120) -> str:
    """Run *cmd* over an open SSH connection and return its stripped stdout.

    Blocking: call it through ``asyncio.to_thread`` (see ``_ssh_exec_async``)
    when running inside the event loop. stderr is not read and the exit status
    is not checked, so callers encode success/failure in the command's stdout.
    """
    _stdin, stdout, _stderr = ssh.exec_command(cmd, timeout=timeout)
    return stdout.read().decode("utf-8", errors="replace").strip()


async def _ssh_exec_async(ssh, cmd: str, timeout: int = 120) -> str:
    """Async wrapper around ``_ssh_exec`` that does not block the event loop."""
    return await asyncio.to_thread(_ssh_exec, ssh, cmd, timeout)


def _ssh_exec_fire_and_forget(ssh, cmd: str):
    """Start *cmd* on a fresh SSH channel without waiting for its output."""
    channel = ssh.get_transport().open_session()
    # Deliberately never read stdout/stderr: the remote process keeps running.
    channel.exec_command(cmd)


# --- Background task bookkeeping ---

# The event loop only keeps weak references to tasks, so a task whose result is
# discarded can be garbage-collected before it finishes. Hold strong references
# here until each task completes.
_background_tasks: set = set()


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it is done."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# --- Pre-download models to network volume (saves money during training) ---

_download_state = {
    "status": "idle",  # idle, downloading, completed, failed
    "pod_id": None,
    "progress": "",
    "error": None,
}

# A CivitAI checkpoint smaller than this is treated as a failed download
# (e.g. an error page saved instead of the model).
_MIN_CIVITAI_BYTES = 1_000_000


class DownloadModelsRequest(BaseModel):
    model_type: str = "wan22"
    gpu_type: str = "NVIDIA GeForce RTX 3090"  # Cheapest GPU, just for downloading


async def _remote_complete_size(ssh, path: str) -> int:
    """Return the size in bytes of a fully downloaded remote file, else 0.

    A file that still has an aria2 control file (``<path>.aria2``) next to it
    is an interrupted download — aria2 may preallocate it to full size — so it
    counts as 0. Missing files and unparsable output also yield 0.
    """
    raw = await _ssh_exec_async(
        ssh,
        f"test -e '{path}.aria2' && echo 0 || stat -c%s '{path}' 2>/dev/null || echo 0",
    )
    tokens = raw.split()
    try:
        return int(tokens[-1])
    except (IndexError, ValueError):
        return 0


async def _create_download_pod(model_type: str, gpu_type: str, volume_id: str, volume_dc: str):
    """Create the download pod, falling back through GPU types on supply errors.

    Returns ``(pod, gpu_used)``. Raises ``RuntimeError`` when no GPU type could
    be allocated; any non-supply error from RunPod is re-raised unchanged.
    """
    pod_kwargs = {
        "container_disk_in_gb": 10,
        "ports": "22/tcp",
        "network_volume_id": volume_id,
        "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
    }
    if volume_dc:
        pod_kwargs["data_center_id"] = volume_dc

    # dict.fromkeys de-duplicates while preserving order, so a requested GPU
    # that is also a fallback is only tried once.
    gpu_candidates = list(dict.fromkeys([
        gpu_type,
        "NVIDIA RTX A4000",
        "NVIDIA RTX A5000",
        "NVIDIA GeForce RTX 4090",
        "NVIDIA GeForce RTX 4080",
        "NVIDIA A100-SXM4-80GB",
    ]))

    for try_gpu in gpu_candidates:
        try:
            pod = await asyncio.to_thread(
                runpod.create_pod,
                f"model-download-{model_type}",
                DOCKER_IMAGE,
                try_gpu,
                **pod_kwargs,
            )
        except Exception as e:
            reason = str(e)
            if "SUPPLY_CONSTRAINT" in reason or "no longer any instances" in reason.lower():
                logger.info("GPU %s unavailable, trying next...", try_gpu)
                continue
            raise
        if pod is None:
            break
        logger.info("Download pod created with %s", try_gpu)
        return pod, try_gpu

    raise RuntimeError("No GPU available for download pod. Try again later.")


async def _wait_for_download_pod_ssh(pod_id: str):
    """Poll RunPod for up to 5 minutes until the pod exposes its SSH port.

    Returns ``(host, public_port)``; raises ``RuntimeError`` on timeout.
    """
    start = time.time()
    while time.time() - start < 300:
        try:
            pod = await asyncio.to_thread(runpod.get_pod, pod_id)
        except Exception as e:
            # Transient API errors are expected while the pod boots.
            logger.debug("Polling download pod %s failed: %s", pod_id, e)
            pod = None
        if pod and pod.get("desiredStatus") == "RUNNING":
            for port in (pod.get("runtime") or {}).get("ports") or []:
                if port.get("privatePort") == 22:
                    host, public_port = port.get("ip"), port.get("publicPort")
                    if host and public_port:
                        return host, public_port
        await asyncio.sleep(5)
    raise RuntimeError("Pod SSH not available after 5 min")


async def _open_download_ssh(host: str, port):
    """Connect to the download pod over SSH, retrying up to 20 times.

    The credentials match the ``chpasswd`` call in the pod's ``docker_args``.
    The client is closed before raising, so a failure leaks no connection.
    """
    import paramiko

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    for attempt in range(20):
        try:
            await asyncio.to_thread(
                ssh.connect, host, port=int(port), username="root", password="runpod", timeout=10
            )
            break
        except Exception as e:
            if attempt == 19:
                ssh.close()
                raise RuntimeError("SSH connection failed after 20 attempts") from e
            await asyncio.sleep(5)
    ssh.get_transport().set_keepalive(30)
    return ssh


async def _download_wan22_models(ssh) -> None:
    """Download the WAN 2.2 Remix checkpoints, text encoder, VAE and musubi-tuner."""
    wan_dir = "/workspace/models/WAN2.2"
    await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")

    civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
    token_param = f"&token={civitai_token}" if civitai_token else ""

    # CivitAI Remix models (fp8): (label, url, target)
    civitai_files = [
        (
            "Remix T2V High-noise",
            f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}",
            f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors",
        ),
        (
            "Remix T2V Low-noise",
            f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}",
            f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors",
        ),
    ]
    # HuggingFace models: (label, repo, filename, target)
    hf_files = [
        (
            "T5 text encoder (fp8)",
            "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
            "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
            f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
        ),
        (
            "VAE",
            "Comfy-Org/Wan_2.2_ComfyUI_Repackaged",
            "split_files/vae/wan_2.1_vae.safetensors",
            f"{wan_dir}/wan_2.1_vae.safetensors",
        ),
    ]

    total = len(civitai_files) + len(hf_files)
    idx = 0

    for label, url, target in civitai_files:
        idx += 1
        if await _remote_complete_size(ssh, target) >= _MIN_CIVITAI_BYTES:
            _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
            logger.info("WAN 2.2 %s already on volume", label)
            continue
        _download_state["progress"] = f"[{idx}/{total}] Downloading {label} (~14GB)..."
        # Download to a side file and move it into place only on success, so an
        # interrupted transfer is never mistaken for a cached model later.
        part = f"{target}.part"
        await _ssh_exec_async(
            ssh,
            f"wget -q -O '{part}' '{url}' && mv -f '{part}' '{target}' || rm -f '{part}'",
            timeout=1800,
        )
        if await _remote_complete_size(ssh, target) < _MIN_CIVITAI_BYTES:
            raise RuntimeError(f"Failed to download {label}. Set CIVITAI_API_TOKEN for NSFW models.")
        _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

    for label, repo, filename, target in hf_files:
        idx += 1
        if await _remote_complete_size(ssh, target) > 0:
            _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
            logger.info("WAN 2.2 %s already on volume", label)
            continue
        _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
        hf_url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
        tdir, _, fname = target.rpartition("/")
        await _ssh_exec_async(
            ssh,
            f"aria2c -x 16 -s 16 -c -o '{fname}' --dir='{tdir}' '{hf_url}' 2>&1 | tail -3",
            timeout=1800,
        )
        if await _remote_complete_size(ssh, target) <= 0:
            raise RuntimeError(f"Failed to download {label}")
        _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

    # Also pre-clone musubi-tuner to volume (for training)
    _download_state["progress"] = "Caching musubi-tuner to volume..."
    tuner_exists = await _ssh_exec_async(
        ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING"
    )
    if tuner_exists.strip() != "EXISTS":
        await _ssh_exec_async(
            ssh,
            "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git && cp -r /workspace/musubi-tuner /runpod-volume/musubi-tuner",
            timeout=300,
        )
        _download_state["progress"] = "musubi-tuner cached"
    else:
        _download_state["progress"] = "musubi-tuner already cached"


async def _download_wan22_animate_models(ssh) -> None:
    """Download the WAN 2.2 Animate model files, reusing WAN 2.2 files where possible."""
    animate_dir = "/workspace/models/WAN2.2-Animate"
    wan22_dir = "/workspace/models/WAN2.2"
    hf_base = "https://huggingface.co"
    await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")

    # (label, url, target, timeout_s, min_bytes, reuse_from_wan22)
    # reuse_from_wan22 replaces the old label-string match, which named a VAE
    # label that did not exist and therefore never reused the VAE.
    files = [
        (
            "WAN 2.2 Animate model (~32GB)",
            f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors",
            f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors",
            7200,
            30_000_000_000,  # 30GB min — partial downloads get resumed
            False,
        ),
        (
            "UMT5 text encoder fp8 (~6.3GB)",
            f"{hf_base}/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors",
            f"{animate_dir}/umt5-xxl-enc-fp8_e4m3fn.safetensors",
            1800,
            6_000_000_000,
            True,
        ),
        (
            "VAE (~242MB)",
            f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors",
            f"{animate_dir}/wan_2.1_vae.safetensors",
            300,
            200_000_000,
            True,
        ),
        (
            "CLIP Vision H (~2.4GB)",
            f"{hf_base}/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors",
            f"{animate_dir}/clip_vision_h.safetensors",
            900,
            2_000_000_000,
            False,
        ),
    ]

    total = len(files)
    for idx, (label, url, target, dl_timeout, min_bytes, reuse_from_wan22) in enumerate(files, 1):
        target_dir, _, filename = target.rpartition("/")

        if reuse_from_wan22:
            # Same filename in the WAN2.2 dir, if present and complete.
            wan22_candidate = f"{wan22_dir}/{filename}"
            if await _remote_complete_size(ssh, wan22_candidate) >= min_bytes:
                _download_state["progress"] = f"[{idx}/{total}] {label} — reusing from WAN2.2 dir"
                await _ssh_exec_async(
                    ssh,
                    f"ln -sf {wan22_candidate} {target} 2>/dev/null || cp {wan22_candidate} {target}",
                )
                continue

        if await _remote_complete_size(ssh, target) >= min_bytes:
            _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
            continue

        _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
        # Remove stale symlinks before downloading (can't resume through a symlink)
        await _ssh_exec_async(ssh, f"test -L '{target}' && rm -f '{target}'; true")
        await _ssh_exec_async(
            ssh,
            f"aria2c -x 16 -s 16 -c -o '{filename}' --dir='{target_dir}' '{url}' 2>&1 | tail -3",
            timeout=dl_timeout,
        )
        size = await _remote_complete_size(ssh, target)
        if size < min_bytes:
            raise RuntimeError(f"Failed to download {label} (size {size} < {min_bytes})")
        _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"


# model_type -> coroutine that downloads that model family over an SSH client.
_MODEL_DOWNLOADERS = {
    "wan22": _download_wan22_models,
    "wan22_animate": _download_wan22_animate_models,
}


@router.post("/download-models")
async def download_models_to_volume(request: DownloadModelsRequest):
    """Pre-download model files to network volume using a cheap pod.

    This saves expensive GPU time during training — models are cached on the
    shared volume and reused across all future training/generation pods.

    Returns immediately; the download runs as a background task whose progress
    is exposed by ``/download-models/status``. Raises 400 when no network
    volume is configured or ``model_type`` has no downloader.
    """
    _get_api_key()
    volume_id, volume_dc = _get_volume_config()
    if not volume_id:
        raise HTTPException(400, "No network volume configured (set RUNPOD_VOLUME_ID)")
    if request.model_type not in _MODEL_DOWNLOADERS:
        # Reject before renting a pod that would have nothing to download.
        raise HTTPException(
            400,
            f"Unsupported model_type '{request.model_type}' (expected one of: {', '.join(_MODEL_DOWNLOADERS)})",
        )

    if _download_state["status"] == "downloading":
        return {"status": "already_downloading", "progress": _download_state["progress"]}

    _download_state["status"] = "downloading"
    _download_state["progress"] = "Creating cheap download pod..."
    _download_state["error"] = None

    _spawn_background(
        _download_models_task(request.model_type, request.gpu_type, volume_id, volume_dc)
    )
    return {
        "status": "started",
        "message": f"Downloading {request.model_type} models to volume (using {request.gpu_type})",
    }


@router.get("/download-models/status")
async def download_models_status():
    """Check model download progress."""
    return _download_state


async def _download_models_task(model_type: str, gpu_type: str, volume_id: str, volume_dc: str):
    """Background task: spin up cheap pod, download models, terminate.

    Progress and failures are reported through ``_download_state``; nothing is
    raised to the caller. The pod is always terminated and the SSH connection
    closed, whether the download succeeded or not.
    """
    ssh = None
    pod_id = None
    try:
        downloader = _MODEL_DOWNLOADERS.get(model_type)
        if downloader is None:
            raise ValueError(f"Unsupported model_type: {model_type}")

        pod, used_gpu = await _create_download_pod(model_type, gpu_type, volume_id, volume_dc)
        pod_id = pod["id"]
        _download_state["pod_id"] = pod_id
        _download_state["progress"] = f"Pod created with {used_gpu} ({pod_id}), waiting for SSH..."

        ssh_host, ssh_port = await _wait_for_download_pod_ssh(pod_id)
        ssh = await _open_download_ssh(ssh_host, ssh_port)
        _download_state["progress"] = "SSH connected, setting up tools..."

        # Symlink volume
        await _ssh_exec_async(
            ssh,
            "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models",
        )
        await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=120)
        await _ssh_exec_async(ssh, "which aria2c || apt-get install -y aria2 2>&1 | tail -1", timeout=120)

        await downloader(ssh)

        _download_state["status"] = "completed"
        _download_state["progress"] = "All models downloaded to volume! Ready for training."
        logger.info("Model pre-download complete for %s", model_type)

    except Exception as e:
        _download_state["status"] = "failed"
        _download_state["error"] = str(e)
        _download_state["progress"] = f"Failed: {e}"
        logger.exception("Model download failed: %s", e)
    finally:
        if ssh:
            try:
                ssh.close()
            except Exception as e:
                logger.debug("Closing download SSH connection failed: %s", e)
        if pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, pod_id)
                logger.info("Download pod terminated: %s", pod_id)
            except Exception as e:
                logger.warning("Failed to terminate download pod: %s", e)
        _download_state["pod_id"] = None


@router.post("/stop")
async def stop_pod():
    """Stop the GPU pod.

    Terminates the pod via RunPod and clears the persisted pod state. If
    termination fails, the status held before the call is restored and a 500
    is raised.
    """
    _get_api_key()

    if not _pod_state["pod_id"]:
        return {"status": "already_stopped"}

    if _pod_state["status"] == "stopping":
        return {"status": "stopping", "message": "Pod is already stopping"}

    previous_status = _pod_state["status"]
    _pod_state["status"] = "stopping"

    try:
        pod_id = _pod_state["pod_id"]
        logger.info("Stopping pod: %s", pod_id)

        await asyncio.to_thread(runpod.terminate_pod, pod_id)

        _pod_state["pod_id"] = None
        _pod_state["ip"] = None
        _pod_state["ssh_port"] = None
        _pod_state["comfyui_port"] = None
        _pod_state["status"] = "stopped"
        _pod_state["started_at"] = None
        _pod_state["setup_status"] = None
        _save_pod_state()

        logger.info("Pod stopped")
        return {"status": "stopped", "message": "Pod terminated"}

    except Exception as e:
        logger.error("Failed to stop pod: %s", e)
        # The pod is still in whatever state it was in before this call
        # (e.g. "setting_up"), not necessarily "running".
        _pod_state["status"] = previous_status
        raise HTTPException(500, f"Failed to stop pod: {e}") from e


@router.get("/loras")
async def list_pod_loras():
    """List LoRAs available on the pod.

    Reads the ``lora_name`` choices ComfyUI reports for its ``LoraLoader``
    node. Returns an empty list when the pod is not running or the query fails.
    """
    if _pod_state["status"] != "running" or not _pod_state["ip"]:
        return {"loras": [], "message": "Pod not running"}

    comfyui_url = _get_comfyui_url()

    try:
        import httpx

        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(f"{comfyui_url}/object_info/LoraLoader")
            if resp.status_code == 200:
                required = resp.json().get("LoraLoader", {}).get("input", {}).get("required", {})
                loras = required.get("lora_name", [[]])[0]
                return {"loras": loras if isinstance(loras, list) else []}
    except Exception as e:
        logger.warning("Failed to list pod LoRAs: %s", e)

    return {"loras": [], "comfyui_url": comfyui_url}


def _validated_lora_filename(name: str | None) -> str:
    """Return the bare ``.safetensors`` filename from *name*, or raise 400.

    Directory components are stripped so a client-supplied name can never
    address a path outside the LoRA directory.
    """
    base = Path((name or "").replace("\\", "/")).name
    if not base.endswith(".safetensors"):
        raise HTTPException(400, "Only .safetensors files supported")
    return base


def _push_lora_to_volume(ip: str, port: int, dest_name: str, send) -> str:
    """Upload one LoRA to the pod's volume over SFTP and link it into ComfyUI.

    *send* is called as ``send(sftp, dest_path)`` and performs the transfer.
    Blocking — run via ``asyncio.to_thread``. Returns the remote path. The SSH
    and SFTP handles are closed even when the transfer fails.
    """
    import shlex

    import paramiko

    dest_path = f"/runpod-volume/loras/{dest_name}"
    comfy_link = f"/workspace/ComfyUI/models/loras/{dest_name}"

    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.connect(ip, port=port, username="root", timeout=30)
        # Ensure dir exists
        client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
        sftp = client.open_sftp()
        try:
            send(sftp, dest_path)
        finally:
            sftp.close()
        # Symlink into ComfyUI; quoted because the name originates from a client.
        client.exec_command(
            f"ln -sf {shlex.quote(dest_path)} {shlex.quote(comfy_link)}"
        )[1].read()
    finally:
        client.close()
    return dest_path


@router.post("/upload-lora")
async def upload_lora_to_pod(file: UploadFile = File(...)):
    """Upload a LoRA file directly to /runpod-volume/loras/ via SFTP so it persists."""
    import io

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    dest_name = _validated_lora_filename(file.filename)

    ip = _pod_state.get("ip")
    port = int(_pod_state.get("ssh_port") or 22)
    if not ip:
        raise HTTPException(500, "No SSH IP available")

    content = await file.read()

    def _send(sftp, dest_path):
        sftp.putfo(io.BytesIO(content), dest_path)

    try:
        dest_path = await asyncio.to_thread(_push_lora_to_volume, ip, port, dest_name, _send)
        logger.info("LoRA uploaded to volume: %s (%d bytes)", dest_name, len(content))
        return {"status": "uploaded", "filename": dest_name, "path": dest_path}
    except Exception as e:
        logger.error("LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e


@router.post("/upload-lora-local")
async def upload_lora_from_local(local_path: str, filename: str | None = None):
    """Upload a LoRA from a local server path directly to the volume via SFTP.

    NOTE(review): ``local_path`` is chosen by the caller, so this endpoint can
    copy any ``.safetensors``-named server file to the pod — confirm it is only
    reachable by trusted clients.
    """
    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    src = Path(local_path)
    if not src.is_file():
        raise HTTPException(404, f"Local file not found: {local_path}")

    dest_name = _validated_lora_filename(filename or src.name)

    ip = _pod_state.get("ip")
    port = int(_pod_state.get("ssh_port") or 22)
    if not ip:
        raise HTTPException(500, "No SSH IP available")

    def _send(sftp, dest_path):
        sftp.put(str(src), dest_path)

    try:
        size_mb = src.stat().st_size / 1024 / 1024
        dest_path = await asyncio.to_thread(_push_lora_to_volume, ip, port, dest_name, _send)
        logger.info("LoRA uploaded from local: %s (%.1f MB)", dest_name, size_mb)
        return {
            "status": "uploaded",
            "filename": dest_name,
            "path": dest_path,
            "size_mb": round(size_mb, 1),
        }
    except Exception as e:
        logger.error("Local LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e


class PodGenerateRequest(BaseModel):
    prompt: str
    negative_prompt: str = ""
    width: int = 1024
    height: int = 1024
    steps: int = 28
    cfg: float = 3.5
    seed: int = -1
    lora_name: str | None = None
    lora_strength: float = 0.85
    lora_name_2: str | None = None
    lora_strength_2: float = 0.85
    character_id: str | None = None
    template_id: str | None = None
    content_rating: str = "sfw"


# In-memory job tracking for pod generation
_pod_jobs: dict[str, dict] = {}


@router.post("/generate")
async def generate_on_pod(request: PodGenerateRequest):
    """Generate an image using the running pod's ComfyUI.

    Queues a workflow on ComfyUI, records the job in ``_pod_jobs`` and starts a
    background poller. Returns the job id and the seed actually used (a random
    one when the request's seed is negative).
    """
    import httpx
    import random

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    # content_rating becomes a directory name when the image is saved, so it
    # must not be able to escape the output directory.
    rating = request.content_rating
    if "/" in rating or "\\" in rating or rating == "..":
        raise HTTPException(400, "Invalid content_rating")

    job_id = str(uuid.uuid4())[:8]
    seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)

    model_type = _pod_state.get("model_type", "flux2")
    if model_type == "wan22":
        workflow = _build_wan_t2i_workflow(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            steps=request.steps,
            cfg=request.cfg,
            seed=seed,
            lora_name=request.lora_name,
            lora_strength=request.lora_strength,
            lora_name_2=request.lora_name_2,
            lora_strength_2=request.lora_strength_2,
        )
    else:
        workflow = _build_flux_workflow(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            width=request.width,
            height=request.height,
            steps=request.steps,
            cfg=request.cfg,
            seed=seed,
            lora_name=request.lora_name,
            lora_strength=request.lora_strength,
            model_type=model_type,
        )

    comfyui_url = _get_comfyui_url()

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(f"{comfyui_url}/prompt", json={"prompt": workflow})
            resp.raise_for_status()
            prompt_id = resp.json()["prompt_id"]

        now = time.time()
        _pod_jobs[job_id] = {
            "prompt_id": prompt_id,
            "status": "running",
            "seed": seed,
            "created_at": now,
            "started_at": now,
            "positive_prompt": request.prompt,
            "negative_prompt": request.negative_prompt,
            "steps": request.steps,
            "cfg": request.cfg,
            "width": request.width,
            "height": request.height,
        }

        logger.info("Pod generation started: %s -> %s", job_id, prompt_id)

        _spawn_background(_poll_pod_job(job_id, prompt_id, request.content_rating))

        return {"job_id": job_id, "status": "running", "seed": seed}

    except Exception as e:
        logger.error("Pod generation failed: %s", e)
        raise HTTPException(500, f"Generation failed: {e}") from e


async def _record_queue_progress(client, comfyui_url: str, job_id: str, elapsed: float) -> None:
    """Store a human-readable queue status for the job; never raises."""
    try:
        q_resp = await client.get(f"{comfyui_url}/queue")
        if q_resp.status_code != 200:
            return
        q_data = q_resp.json()
        running = q_data.get("queue_running", [])
        pending = len(q_data.get("queue_pending", []))

        status_msg = f"{int(elapsed)}s elapsed"
        if running:
            # Try to get node execution progress
            try:
                p_resp = await client.get(f"{comfyui_url}/prompt")
                if p_resp.status_code == 200:
                    exec_info = p_resp.json().get("exec_info", {})
                    if exec_info:
                        status_msg += f" | nodes: {exec_info}"
            except Exception as e:
                logger.debug("Pod gen %s: exec info unavailable: %s", job_id, e)
            status_msg += " | generating..."
        elif pending:
            status_msg += " | loading models..."
        else:
            status_msg += " | waiting..."

        _pod_jobs[job_id]["progress_msg"] = status_msg
        logger.info("Pod gen %s: %s", job_id, status_msg)
    except Exception as e:
        logger.debug("Pod gen %s: queue progress check failed: %s", job_id, e)


def _comfy_error_message(history_entry: dict) -> str | None:
    """Return an error description if the history entry reports a failed run.

    Assumes ComfyUI's history schema (``status.status_str`` and
    ``status.messages`` as ``[event, data]`` pairs); anything that does not
    match falls back to a generic message or ``None``.
    """
    status = history_entry.get("status") or {}
    if status.get("status_str") != "error":
        return None
    for message in status.get("messages") or []:
        if isinstance(message, (list, tuple)) and len(message) == 2 and message[0] == "execution_error":
            detail = message[1].get("exception_message") if isinstance(message[1], dict) else None
            if detail:
                return f"ComfyUI execution error: {detail}"
    return "ComfyUI execution error"


async def _store_pod_image(job_id: str, content_rating: str, image_bytes: bytes) -> None:
    """Write the generated image to disk, mark the job completed and catalog it."""
    from content_engine.config import settings

    output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
    output_dir.mkdir(parents=True, exist_ok=True)
    local_path = output_dir / f"pod_{job_id}.png"
    # File I/O runs in a worker thread so the event loop keeps serving requests.
    await asyncio.to_thread(local_path.write_bytes, image_bytes)

    job_info = _pod_jobs[job_id]
    job_info["status"] = "completed"
    job_info["output_path"] = str(local_path)
    job_info["completed_at"] = time.time()
    logger.info("Pod generation completed: %s -> %s", job_id, local_path)

    # Cataloging is best-effort: the image is already saved and served.
    try:
        from content_engine.services.catalog import CatalogService

        catalog = CatalogService()
        await catalog.insert_image(
            file_path=str(local_path),
            image_bytes=image_bytes,
            content_rating=content_rating,
            positive_prompt=job_info.get("positive_prompt"),
            negative_prompt=job_info.get("negative_prompt"),
            seed=job_info.get("seed"),
            steps=job_info.get("steps"),
            cfg=job_info.get("cfg"),
            width=job_info.get("width"),
            height=job_info.get("height"),
            generation_backend="runpod-pod",
            generation_time_seconds=time.time() - job_info.get("created_at", time.time()),
        )
        logger.info("Pod image cataloged: %s", job_id)
    except Exception as e:
        logger.warning("Failed to catalog pod image: %s", e)


async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
    """Poll ComfyUI for job completion and save the result.

    Polls every 2 seconds for up to 15 minutes. The job is marked ``completed``
    once an output image is saved, or ``failed`` when ComfyUI reports an
    execution error or the timeout expires.
    """
    import httpx

    start = time.time()
    timeout = 900  # 15 min — first gen loads model (~12GB) + samples
    comfyui_url = _get_comfyui_url()
    last_log_time = 0

    async with httpx.AsyncClient(timeout=60) as client:
        while time.time() - start < timeout:
            try:
                # Log queue progress every 15 seconds and store in job
                elapsed = time.time() - start
                if elapsed - last_log_time >= 15:
                    last_log_time = elapsed
                    await _record_queue_progress(client, comfyui_url, job_id, elapsed)

                resp = await client.get(f"{comfyui_url}/history/{prompt_id}")
                entry = resp.json().get(prompt_id) if resp.status_code == 200 else None
                if entry is not None:
                    for node_output in entry.get("outputs", {}).values():
                        images = node_output.get("images")
                        if not images:
                            continue
                        image_info = images[0]
                        params = {"filename": image_info["filename"]}
                        subfolder = image_info.get("subfolder", "")
                        if subfolder:
                            params["subfolder"] = subfolder

                        img_resp = await client.get(f"{comfyui_url}/view", params=params)
                        if img_resp.status_code == 200:
                            await _store_pod_image(job_id, content_rating, img_resp.content)
                            return

                    # Finished without a retrievable image: stop right away if
                    # ComfyUI says the run failed instead of waiting 15 minutes.
                    error = _comfy_error_message(entry)
                    if error:
                        _pod_jobs[job_id]["status"] = "failed"
                        _pod_jobs[job_id]["error"] = error
                        logger.error("Pod generation failed: %s: %s", job_id, error)
                        return

            except Exception as e:
                logger.debug("Polling pod job: %s", e)

            await asyncio.sleep(2)

    _pod_jobs[job_id]["status"] = "failed"
    _pod_jobs[job_id]["error"] = "Timeout waiting for generation"
    logger.error("Pod generation timed out: %s", job_id)


@router.get("/jobs/{job_id}")
async def get_pod_job(job_id: str):
    """Get status of a pod generation job."""
    job = _pod_jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    return job


@router.get("/jobs/{job_id}/image")
async def get_pod_job_image(job_id: str):
    """Serve the generated image for a completed pod job."""
    from fastapi.responses import FileResponse

    job = _pod_jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    output_path = job.get("output_path")
    if not output_path:
        raise HTTPException(404, "No image yet")

    image_file = Path(output_path)
    if not image_file.exists():
        raise HTTPException(404, "Image file not found")

    return FileResponse(image_file, media_type="image/png")


def _comfy_node(class_type: str, /, **inputs) -> dict:
    """Build one ComfyUI workflow node: a class type plus its named inputs."""
    return {"class_type": class_type, "inputs": inputs}


def _build_flux_workflow(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    cfg: float,
    seed: int,
    lora_name: str | None,
    lora_strength: float,
    model_type: str = "flux2",
) -> dict:
    """Build a ComfyUI workflow for FLUX generation.

    FLUX.2 Dev uses separate model components (not a single checkpoint):
    - UNETLoader for the diffusion model
    - CLIPLoader (type=flux2) for the Mistral text encoder
    - VAELoader for the autoencoder

    Any ``model_type`` other than ``"flux2"`` selects the FLUX.1 file names.
    When ``lora_name`` is given, a LoraLoader (node "20") is inserted between
    the loaders and the sampler / text encoders.

    Returns the workflow as a dict of node id -> node definition.
    """
    if model_type == "flux2":
        unet_name = "flux2-dev.safetensors"
        clip_type = "flux2"
        clip_name = "mistral_3_small_flux2_fp8.safetensors"
    else:
        unet_name = "flux1-dev.safetensors"
        clip_type = "flux"
        clip_name = "t5xxl_fp16.safetensors"

    # Model node ID references
    model_out = ["1", 0]  # UNETLoader -> MODEL
    clip_out = ["2", 0]  # CLIPLoader -> CLIP
    vae_out = ["3", 0]  # VAELoader -> VAE

    workflow = {
        # Load diffusion model (UNet)
        "1": _comfy_node("UNETLoader", unet_name=unet_name, weight_dtype="fp8_e4m3fn"),
        # Load text encoder
        "2": _comfy_node("CLIPLoader", clip_name=clip_name, type=clip_type),
        # Load VAE
        "3": _comfy_node("VAELoader", vae_name="ae.safetensors"),
        # Positive prompt
        "6": _comfy_node("CLIPTextEncode", text=prompt, clip=clip_out),
        # Negative prompt
        "7": _comfy_node("CLIPTextEncode", text=negative_prompt or "", clip=clip_out),
        # Empty latent
        "5": _comfy_node("EmptyLatentImage", width=width, height=height, batch_size=1),
        # Sampler
        "10": _comfy_node(
            "KSampler",
            seed=seed,
            steps=steps,
            cfg=cfg,
            sampler_name="euler",
            scheduler="simple",
            denoise=1.0,
            model=model_out,
            positive=["6", 0],
            negative=["7", 0],
            latent_image=["5", 0],
        ),
        # Decode
        "8": _comfy_node("VAEDecode", samples=["10", 0], vae=vae_out),
        # Save
        "9": _comfy_node("SaveImage", filename_prefix="flux_pod", images=["8", 0]),
    }

    # Add LoRA if specified
    if lora_name:
        workflow["20"] = _comfy_node(
            "LoraLoader",
            lora_name=lora_name,
            strength_model=lora_strength,
            strength_clip=lora_strength,
            model=model_out,
            clip=clip_out,
        )
        # Rewire sampler and text encoders to use LoRA output
        workflow["10"]["inputs"]["model"] = ["20", 0]
        workflow["6"]["inputs"]["clip"] = ["20", 1]
        workflow["7"]["inputs"]["clip"] = ["20", 1]

    return workflow


def _build_wan_t2i_workflow(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    cfg: float,
    seed: int,
    lora_name: str | None,
    lora_strength: float,
    lora_name_2: str | None = None,
    lora_strength_2: float = 0.85,
) -> dict:
    """Build a ComfyUI workflow for WAN 2.2 Remix — dual-DiT MoE split-step.

    Based on the WAN 2.2 Remix workflow from CivitAI:
    - Two UNETLoaders: high-noise + low-noise Remix models (fp8)
    - wanBlockSwap on both (offloads blocks to CPU for 24GB GPUs)
    - ModelSamplingSD3 with shift=5 on both
    - Dual KSamplerAdvanced: high-noise runs first half, low-noise finishes
    - CLIPLoader (type=wan) + CLIPTextEncode for T5 text encoding
    - Standard VAELoader + VAEDecode
    - EmptyHunyuanLatentVideo for latent (1 frame = image, 81+ = video)

    LoRAs are chained onto both DiTs in the order given (``lora_name`` then
    ``lora_name_2``), using node ids 20/21 for the first LoRA present and
    22/23 for the second. A LoRA passed only as ``lora_name_2`` is applied
    too, instead of being dropped.

    Returns the workflow as a dict of node id -> node definition.
    """
    high_dit = "wan22_remix_t2v_high_fp8.safetensors"
    low_dit = "wan22_remix_t2v_low_fp8.safetensors"
    t5_name = "umt5_xxl_fp8_e4m3fn_scaled.safetensors"
    vae_name = "wan_2.1_vae.safetensors"

    total_steps = steps
    split_step = total_steps // 2  # high-noise does first half, low-noise does rest
    shift = 5.0
    block_swap = 20  # blocks offloaded to CPU (0-40, higher = less VRAM)

    workflow = {
        # ── Load high-noise DiT ──
        "1": _comfy_node("UNETLoader", unet_name=high_dit, weight_dtype="fp8_e4m3fn"),
        # ── Load low-noise DiT ──
        "2": _comfy_node("UNETLoader", unet_name=low_dit, weight_dtype="fp8_e4m3fn"),
        # ── wanBlockSwap on high-noise (VRAM optimization) ──
        "11": _comfy_node(
            "wanBlockSwap",
            model=["1", 0],
            blocks_to_swap=block_swap,
            offload_img_emb=False,
            offload_txt_emb=False,
        ),
        # ── wanBlockSwap on low-noise ──
        "12": _comfy_node(
            "wanBlockSwap",
            model=["2", 0],
            blocks_to_swap=block_swap,
            offload_img_emb=False,
            offload_txt_emb=False,
        ),
        # ── ModelSamplingSD3 shift on high-noise ──
        "13": _comfy_node("ModelSamplingSD3", model=["11", 0], shift=shift),
        # ── ModelSamplingSD3 shift on low-noise ──
        "14": _comfy_node("ModelSamplingSD3", model=["12", 0], shift=shift),
        # ── Load T5 text encoder ──
        "3": _comfy_node("CLIPLoader", clip_name=t5_name, type="wan"),
        # ── Positive prompt ──
        "6": _comfy_node("CLIPTextEncode", text=prompt, clip=["3", 0]),
        # ── Negative prompt ──
        "7": _comfy_node("CLIPTextEncode", text=negative_prompt or "", clip=["3", 0]),
        # ── VAE ──
        "4": _comfy_node("VAELoader", vae_name=vae_name),
        # ── Empty latent (1 frame = single image) ──
        "5": _comfy_node(
            "EmptyHunyuanLatentVideo", width=width, height=height, length=1, batch_size=1
        ),
        # ── KSamplerAdvanced #1: High-noise model (first half of steps) ──
        "15": _comfy_node(
            "KSamplerAdvanced",
            model=["13", 0],
            positive=["6", 0],
            negative=["7", 0],
            latent_image=["5", 0],
            add_noise="enable",
            noise_seed=seed,
            steps=total_steps,
            cfg=cfg,
            sampler_name="euler",
            scheduler="simple",
            start_at_step=0,
            end_at_step=split_step,
            return_with_leftover_noise="enable",
        ),
        # ── KSamplerAdvanced #2: Low-noise model (second half of steps) ──
        "16": _comfy_node(
            "KSamplerAdvanced",
            model=["14", 0],
            positive=["6", 0],
            negative=["7", 0],
            latent_image=["15", 0],
            add_noise="disable",
            noise_seed=seed,
            steps=total_steps,
            cfg=cfg,
            sampler_name="euler",
            scheduler="simple",
            start_at_step=split_step,
            end_at_step=10000,
            return_with_leftover_noise="disable",
        ),
        # ── VAE Decode ──
        "8": _comfy_node("VAEDecode", samples=["16", 0], vae=["4", 0]),
        # ── Save Image ──
        "9": _comfy_node("SaveImage", filename_prefix="wan_remix_pod", images=["8", 0]),
    }

    # Add LoRA(s) to both models if specified — chained: DiT → LoRA1 → LoRA2 → Sampler
    requested = ((lora_name, lora_strength), (lora_name_2, lora_strength_2))
    loras = [(name, strength) for name, strength in requested if name]

    high_model, low_model = ["13", 0], ["14", 0]
    high_clip, low_clip = ["3", 0], ["3", 0]
    for position, (name, strength) in enumerate(loras):
        high_id = str(20 + 2 * position)
        low_id = str(21 + 2 * position)
        workflow[high_id] = _comfy_node(
            "LoraLoader",
            lora_name=name,
            strength_model=strength,
            strength_clip=1.0,
            model=high_model,
            clip=high_clip,
        )
        workflow[low_id] = _comfy_node(
            "LoraLoader",
            lora_name=name,
            strength_model=strength,
            strength_clip=1.0,
            model=low_model,
            clip=low_clip,
        )
        high_model, low_model = [high_id, 0], [low_id, 0]
        high_clip, low_clip = [high_id, 1], [low_id, 1]

    if loras:
        # Rewire samplers and CLIP encoding; text is encoded with the CLIP
        # output of the high-noise LoRA chain.
        workflow["15"]["inputs"]["model"] = high_model
        workflow["16"]["inputs"]["model"] = low_model
        workflow["6"]["inputs"]["clip"] = high_clip
        workflow["7"]["inputs"]["clip"] = high_clip

    return workflow