Spaces:
Running
Running
"""RunPod Pod management routes — start/stop GPU pods for generation.

Starts a persistent ComfyUI pod with network volume access.
Models and LoRAs are loaded from the shared network volume.
"""
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import runpod | |
| from fastapi import APIRouter, File, HTTPException, UploadFile | |
| from pydantic import BaseModel | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for the pod endpoints: "/api/pod" prefix, "pod" tag.
router = APIRouter(
    prefix="/api/pod",
    tags=["pod"],
)

# Pod state is written to disk so that it survives server restarts.
# The file lives three directories above the one containing this module.
_POD_STATE_FILE = (
    Path(__file__).parent.parent.parent.parent / "pod_state.json"
)
def _save_pod_state() -> None:
    """Persist the in-memory pod state to ``_POD_STATE_FILE`` as JSON.

    The transient ``setup_status`` message is left out of the saved data.

    The JSON is written to a sibling ``*.tmp`` file first and then renamed
    over the real file, so an interrupted write does not leave a truncated
    state file behind (which would make the next load fail and lose the
    tracked pod id).

    Best-effort: any failure is logged as a warning and swallowed.
    """
    try:
        data = {k: v for k, v in _pod_state.items() if k != "setup_status"}
        tmp_file = _POD_STATE_FILE.with_name(_POD_STATE_FILE.name + ".tmp")
        tmp_file.write_text(json.dumps(data))
        # Rename over the target; the old file stays intact until this point.
        tmp_file.replace(_POD_STATE_FILE)
    except Exception as e:
        logger.warning("Failed to save pod state: %s", e)
def _load_pod_state():
    """Restore pod state from ``_POD_STATE_FILE``, if that file exists.

    Only keys already present in ``_pod_state`` are taken from the file;
    anything else in the saved JSON is ignored. Any failure (unreadable
    file, invalid JSON, unexpected shape) is logged as a warning and the
    in-memory state is left as it was.
    """
    try:
        if not _POD_STATE_FILE.exists():
            return
        saved = json.loads(_POD_STATE_FILE.read_text())
        recognised = {
            key: value for key, value in saved.items() if key in _pod_state
        }
        _pod_state.update(recognised)
        logger.info(
            "Restored pod state: pod_id=%s status=%s",
            _pod_state.get("pod_id"),
            _pod_state.get("status"),
        )
    except Exception as exc:
        logger.warning("Failed to load pod state: %s", exc)
def _get_volume_config() -> tuple[str, str]:
    """Return ``(volume_id, data_center_id)`` for the network volume.

    The environment is read on every call rather than at import time, so
    values that are set after this module is imported (for example by a
    dotenv loader) are picked up. A missing variable yields an empty string.
    """
    volume_id = os.environ.get("RUNPOD_VOLUME_ID", "")
    data_center = os.environ.get("RUNPOD_VOLUME_DC", "")
    return volume_id, data_center
# Docker image — PyTorch base with CUDA, we install ComfyUI ourselves
DOCKER_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04"

# Pod state: the single in-memory record of the pod managed by this module.
# Every key that is written anywhere in this module is declared here, because
# _load_pod_state() only restores keys that already exist in this dict.
_pod_state = {
    "pod_id": None,
    "status": "stopped",  # stopped, starting, setting_up, running, stopping
    "ip": None,  # address shown to clients (SSH IP, else ComfyUI IP)
    "ssh_ip": None,  # endpoint mapped to private port 22
    "ssh_port": None,
    "comfyui_ip": None,  # endpoint mapped to private port 8188
    "comfyui_port": None,
    "gpu_type": "NVIDIA RTX A6000",
    "model_type": "flux2",
    "started_at": None,  # time.time() value recorded when the pod is created
    "cost_per_hour": 0.76,
    "setup_status": None,  # progress message; not written to the state file
}

# Restore any state saved by a previous server process.
_load_pod_state()
# GPU options (same as training)
# Keyed by GPU type identifier. Each entry carries a display name, the VRAM
# size as an integer, and the hourly cost that start_pod copies into
# _pod_state["cost_per_hour"].
GPU_OPTIONS = {
    gpu_id: {"name": label, "vram": vram, "cost": hourly_cost}
    for gpu_id, label, vram, hourly_cost in (
        ("NVIDIA A40", "A40 48GB", 48, 0.64),
        ("NVIDIA RTX A6000", "RTX A6000 48GB", 48, 0.76),
        ("NVIDIA L40", "L40 48GB", 48, 0.89),
        ("NVIDIA L40S", "L40S 48GB", 48, 1.09),
        ("NVIDIA A100-SXM4-80GB", "A100 SXM 80GB", 80, 1.64),
        ("NVIDIA A100 80GB PCIe", "A100 PCIe 80GB", 80, 1.89),
        ("NVIDIA H100 80GB HBM3", "H100 80GB", 80, 3.89),
        ("NVIDIA GeForce RTX 5090", "RTX 5090 32GB", 32, 0.69),
        ("NVIDIA GeForce RTX 4090", "RTX 4090 24GB", 24, 0.44),
        ("NVIDIA GeForce RTX 3090", "RTX 3090 24GB", 24, 0.22),
    )
}
def _get_comfyui_url() -> str | None:
    """Build the ComfyUI URL that goes through RunPod's HTTPS proxy.

    RunPod HTTP ports are only reachable through the proxy at
    https://{pod_id}-{private_port}.proxy.runpod.net; the raw IP:port
    reported by the API is an internal address and not publicly routable.

    Returns:
        The proxy URL for private port 8188 of the tracked pod, or None
        when no pod id is currently tracked.
    """
    pod_id = _pod_state.get("pod_id")
    return f"https://{pod_id}-8188.proxy.runpod.net" if pod_id else None
def _get_api_key() -> str:
    """Read the RunPod API key from the environment and hand it to the SDK.

    Returns:
        The value of ``RUNPOD_API_KEY``; it is also assigned to
        ``runpod.api_key`` as a side effect.

    Raises:
        HTTPException: 503 when ``RUNPOD_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("RUNPOD_API_KEY")
    if api_key:
        runpod.api_key = api_key
        return api_key
    raise HTTPException(503, "RUNPOD_API_KEY not configured")
class StartPodRequest(BaseModel):
    # Parameters accepted by start_pod(). Both fields have defaults.

    # GPU type identifier; start_pod() rejects values that are not keys of
    # GPU_OPTIONS.
    gpu_type: str = "NVIDIA RTX A6000"

    # Which model setup to run on the pod. Not validated by this model.
    model_type: str = "flux2"
class PodStatus(BaseModel):
    # Snapshot of the tracked pod as returned by get_pod_status().
    # Only `status` is required; every other field defaults to None.

    # Lifecycle
    status: str
    pod_id: str | None = None

    # Network endpoint (get_pod_status() fills `port` from the ComfyUI port)
    ip: str | None = None
    port: int | None = None

    # Configuration chosen when the pod was started
    gpu_type: str | None = None
    model_type: str | None = None
    cost_per_hour: float | None = None

    # Progress and runtime information
    setup_status: str | None = None
    uptime_minutes: float | None = None
    comfyui_url: str | None = None
async def get_pod_status():
    """Return the current pod status, refreshed from the RunPod API.

    When a pod id is tracked, the pod is looked up with a 10 second timeout:

    * desired status ``RUNNING``: the SSH (private port 22) and ComfyUI
      (private port 8188) endpoints stored in ``_pod_state`` are refreshed;
    * desired status ``EXITED``, or no pod returned: the pod is marked
      stopped, its id is forgotten, and the change is persisted so that a
      server restart does not bring the stale pod id back.

    A timeout or API error is logged and the last known state is reported.

    Returns:
        A ``PodStatus`` built from ``_pod_state``. ``uptime_minutes`` is only
        set while the status is "running" or "setting_up".

    Raises:
        HTTPException: 503 (raised by ``_get_api_key``) when no API key is
            configured.
    """
    _get_api_key()

    pod_id = _pod_state["pod_id"]
    if pod_id:
        try:
            # The SDK call is blocking, so it runs in a worker thread.
            pod = await asyncio.wait_for(
                asyncio.to_thread(runpod.get_pod, pod_id),
                timeout=10,
            )
            desired = pod.get("desiredStatus", "") if pod else None

            if desired == "RUNNING":
                runtime = pod.get("runtime") or {}
                ports = runtime.get("ports") or []
                for p in ports:
                    if p.get("privatePort") == 22:
                        _pod_state["ssh_ip"] = p.get("ip")
                        _pod_state["ssh_port"] = p.get("publicPort")
                    if p.get("privatePort") == 8188:
                        _pod_state["comfyui_ip"] = p.get("ip")
                        _pod_state["comfyui_port"] = p.get("publicPort")
                # Use SSH IP as the main IP for display
                _pod_state["ip"] = _pod_state.get("ssh_ip") or _pod_state.get("comfyui_ip")
            elif not pod or desired == "EXITED":
                # The pod is gone or has exited: stop tracking it, and write
                # that to disk so the stale id is not restored after a restart.
                _pod_state["status"] = "stopped"
                _pod_state["pod_id"] = None
                _save_pod_state()
        except asyncio.TimeoutError:
            logger.warning("RunPod API timeout checking pod status")
        except Exception as e:
            logger.warning("Failed to check pod: %s", e)

    uptime = None
    if _pod_state["started_at"] and _pod_state["status"] in ("running", "setting_up"):
        # started_at is a time.time() value; report the difference in minutes.
        uptime = (time.time() - _pod_state["started_at"]) / 60

    comfyui_url = _get_comfyui_url()
    return PodStatus(
        status=_pod_state["status"],
        pod_id=_pod_state["pod_id"],
        ip=_pod_state["ip"],
        port=_pod_state.get("comfyui_port"),
        gpu_type=_pod_state["gpu_type"],
        model_type=_pod_state.get("model_type", "flux2"),
        cost_per_hour=_pod_state["cost_per_hour"],
        setup_status=_pod_state.get("setup_status"),
        uptime_minutes=uptime,
        comfyui_url=comfyui_url,
    )
async def list_gpu_options():
    """Return the GPU catalogue, wrapped under the "gpus" key."""
    return dict(gpus=GPU_OPTIONS)
async def list_model_options():
    """List available model types for the pod.

    Returns:
        ``{"models": {...}}`` where each key is a model type identifier and
        each value holds a display ``name``, a human-readable
        ``description`` and a ``use_case`` tag.
    """
    # NOTE(review): the pod setup routine also has a branch for a "z_image"
    # model type that is not listed here — confirm whether that is intended.
    return {
        "models": {
            "flux2": {
                "name": "FLUX.2 Dev",
                "description": "Best for realistic txt2img (requires 48GB+ VRAM)",
                "use_case": "txt2img",
            },
            "flux1": {
                "name": "FLUX.1 Dev",
                "description": "Previous gen FLUX txt2img",
                "use_case": "txt2img",
            },
            "wan22": {
                "name": "WAN 2.2 Remix",
                "description": "Realistic generation — dual-DiT MoE split-step (NSFW OK)",
                "use_case": "txt2img",
            },
            "wan22_i2v": {
                "name": "WAN 2.2 I2V",
                "description": "Image-to-video generation",
                "use_case": "img2video",
            },
            "wan22_animate": {
                "name": "WAN 2.2 Animate",
                "description": "Dance/motion transfer — animate a character from a driving video",
                "use_case": "animate",
            },
        }
    }
# Strong references to in-flight setup tasks. The event loop itself only
# holds weak references to tasks, so a task that nothing else references can
# be garbage collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


async def start_pod(request: StartPodRequest):
    """Start a GPU pod with ComfyUI for generation.

    Creates the pod through the RunPod SDK, records it in ``_pod_state``
    (persisted to disk), and schedules ``_wait_and_setup_pod`` in the
    background to finish the setup.

    Args:
        request: GPU type and model type to start the pod with.

    Returns:
        A dict with a ``status`` key: "already_running" or "starting" when a
        pod is already tracked in one of those states, otherwise "starting"
        together with the new ``pod_id`` and a human-readable ``message``.

    Raises:
        HTTPException: 503 when no API key is configured, 400 when
            ``request.gpu_type`` is not a key of ``GPU_OPTIONS``, 500 when
            creating the pod fails.
    """
    _get_api_key()

    if _pod_state["status"] in ("running", "setting_up"):
        return {"status": "already_running", "pod_id": _pod_state["pod_id"]}
    if _pod_state["status"] == "starting":
        return {"status": "starting", "message": "Pod is already starting"}

    gpu_info = GPU_OPTIONS.get(request.gpu_type)
    if not gpu_info:
        raise HTTPException(400, f"Unknown GPU type: {request.gpu_type}")

    _pod_state["status"] = "starting"
    _pod_state["gpu_type"] = request.gpu_type
    _pod_state["cost_per_hour"] = gpu_info["cost"]
    _pod_state["model_type"] = request.model_type
    _pod_state["setup_status"] = "Creating pod..."

    try:
        logger.info("Starting RunPod with %s for %s...", request.gpu_type, request.model_type)
        # NOTE(security): the container start command below sets a fixed root
        # password and enables root login over SSH on an exposed port.
        # Flagged for review; left unchanged because the setup code elsewhere
        # in this file logs in with the same credentials.
        pod_kwargs = {
            "container_disk_in_gb": 30,
            "ports": "22/tcp,8188/http",
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }

        volume_id, volume_dc = _get_volume_config()
        if volume_id:
            pod_kwargs["network_volume_id"] = volume_id
            if volume_dc:
                pod_kwargs["data_center_id"] = volume_dc
            logger.info("Using network volume: %s (DC: %s)", volume_id, volume_dc)
        else:
            pod_kwargs["volume_in_gb"] = 75
            logger.warning("No network volume configured — using temporary volume")

        # The SDK call is blocking, so it runs in a worker thread.
        pod = await asyncio.to_thread(
            runpod.create_pod,
            f"comfyui-gen-{request.model_type}",
            DOCKER_IMAGE,
            request.gpu_type,
            **pod_kwargs,
        )

        _pod_state["pod_id"] = pod["id"]
        _pod_state["started_at"] = time.time()
        _save_pod_state()
        logger.info("Pod created: %s", pod["id"])

        # Keep a reference to the task until it completes (see the comment on
        # _background_tasks); the callback drops it again afterwards.
        setup_task = asyncio.create_task(_wait_and_setup_pod(pod["id"], request.model_type))
        _background_tasks.add(setup_task)
        setup_task.add_done_callback(_background_tasks.discard)

        return {
            "status": "starting",
            "pod_id": pod["id"],
            "message": f"Starting {gpu_info['name']} pod (~5-8 min for setup)",
        }
    except Exception as e:
        # Roll the state back so a later start request is not blocked.
        _pod_state["status"] = "stopped"
        _pod_state["setup_status"] = None
        logger.error("Failed to start pod: %s", e)
        raise HTTPException(500, f"Failed to start pod: {e}") from e
| async def _wait_and_setup_pod(pod_id: str, model_type: str, timeout: int = 600): | |
| """Wait for pod to be ready, then install ComfyUI and link models via SSH.""" | |
| start = time.time() | |
| ssh_host = None | |
| ssh_port = None | |
| # Phase 1: Wait for SSH to be available | |
| _pod_state["setup_status"] = "Waiting for pod to start..." | |
| while time.time() - start < timeout: | |
| try: | |
| pod = await asyncio.to_thread(runpod.get_pod, pod_id) | |
| if pod and pod.get("desiredStatus") == "RUNNING": | |
| runtime = pod.get("runtime") or {} | |
| ports = runtime.get("ports") or [] | |
| for p in ports: | |
| if p.get("privatePort") == 22: | |
| ssh_host = p.get("ip") | |
| ssh_port = p.get("publicPort") | |
| _pod_state["ssh_ip"] = ssh_host | |
| _pod_state["ssh_port"] = ssh_port | |
| _pod_state["ip"] = ssh_host | |
| if p.get("privatePort") == 8188: | |
| _pod_state["comfyui_ip"] = p.get("ip") | |
| _pod_state["comfyui_port"] = p.get("publicPort") | |
| if ssh_host and ssh_port: | |
| break | |
| except Exception as e: | |
| logger.debug("Waiting for pod: %s", e) | |
| await asyncio.sleep(5) | |
| if not ssh_host or not ssh_port: | |
| logger.error("Pod did not become ready within %ds", timeout) | |
| _pod_state["status"] = "stopped" | |
| _pod_state["setup_status"] = "Failed: pod did not start" | |
| return | |
| # Phase 2: SSH in and set up ComfyUI | |
| _pod_state["status"] = "setting_up" | |
| _pod_state["setup_status"] = "Connecting via SSH..." | |
| import paramiko | |
| async def _ssh_connect_new() -> "paramiko.SSHClient": | |
| """Create a fresh SSH connection to the pod.""" | |
| client = paramiko.SSHClient() | |
| client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) | |
| for attempt in range(10): | |
| try: | |
| await asyncio.to_thread( | |
| client.connect, ssh_host, port=int(ssh_port), | |
| username="root", password="runpod", timeout=15, | |
| banner_timeout=30, | |
| ) | |
| client.get_transport().set_keepalive(30) | |
| return client | |
| except Exception: | |
| if attempt == 9: | |
| raise | |
| await asyncio.sleep(5) | |
| raise RuntimeError("SSH connection failed after retries") | |
| async def _ssh_exec_r(cmd: str, timeout: int = 120) -> str: | |
| """Execute SSH command, reconnecting once if the session dropped.""" | |
| nonlocal ssh | |
| try: | |
| t = ssh.get_transport() | |
| if t is None or not t.is_active(): | |
| logger.info("SSH session dropped, reconnecting...") | |
| ssh = await _ssh_connect_new() | |
| return await _ssh_exec_async(ssh, cmd, timeout) | |
| except Exception as e: | |
| if "not active" in str(e).lower() or "session" in str(e).lower(): | |
| logger.info("SSH error '%s', reconnecting and retrying...", e) | |
| ssh = await _ssh_connect_new() | |
| return await _ssh_exec_async(ssh, cmd, timeout) | |
| raise | |
| for attempt in range(30): | |
| try: | |
| ssh = await _ssh_connect_new() | |
| break | |
| except Exception: | |
| if attempt == 29: | |
| _pod_state["setup_status"] = "Failed: SSH connection error" | |
| _pod_state["status"] = "stopped" | |
| return | |
| await asyncio.sleep(5) | |
| try: | |
| # Symlink network volume | |
| volume_id, _ = _get_volume_config() | |
| if volume_id: | |
| await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models /runpod-volume/loras") | |
| await _ssh_exec_async(ssh, "rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models") | |
| # Install ComfyUI (cache on volume for reuse) | |
| comfy_dir = "/workspace/ComfyUI" | |
| _pod_state["setup_status"] = "Installing ComfyUI..." | |
| comfy_exists = (await _ssh_exec_async(ssh, f"test -f {comfy_dir}/main.py && echo EXISTS || echo MISSING")).strip() | |
| if comfy_exists == "EXISTS": | |
| logger.info("ComfyUI already installed") | |
| _pod_state["setup_status"] = "ComfyUI found, updating..." | |
| await _ssh_exec_async(ssh, f"cd {comfy_dir} && git pull 2>&1 | tail -3", timeout=120) | |
| else: | |
| # Check volume cache | |
| vol_comfy = (await _ssh_exec_async(ssh, "test -f /runpod-volume/ComfyUI/main.py && echo EXISTS || echo MISSING")).strip() | |
| if vol_comfy == "EXISTS": | |
| _pod_state["setup_status"] = "Restoring ComfyUI from volume..." | |
| await _ssh_exec_async(ssh, f"cp -r /runpod-volume/ComfyUI {comfy_dir}", timeout=300) | |
| else: | |
| _pod_state["setup_status"] = "Cloning ComfyUI (first time, ~2 min)..." | |
| await _ssh_exec_async(ssh, f"cd /workspace && git clone --depth 1 https://github.com/comfyanonymous/ComfyUI.git", timeout=300) | |
| await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600) | |
| # Cache to volume | |
| volume_id, _ = _get_volume_config() | |
| if volume_id: | |
| await _ssh_exec_async(ssh, f"cp -r {comfy_dir} /runpod-volume/ComfyUI", timeout=300) | |
| # Install pip deps that aren't in ComfyUI requirements | |
| _pod_state["setup_status"] = "Installing dependencies..." | |
| await _ssh_exec_async(ssh, f"cd {comfy_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=600) | |
| await _ssh_exec_async(ssh, "pip install aiohttp einops sqlalchemy 2>&1 | tail -3", timeout=120) | |
| # Symlink models into ComfyUI directories | |
| _pod_state["setup_status"] = "Linking models..." | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/checkpoints {comfy_dir}/models/vae {comfy_dir}/models/loras {comfy_dir}/models/text_encoders") | |
| if model_type == "flux2": | |
| # FLUX.2 Dev β separate UNet, text encoder, and VAE | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models") | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/flux2-dev.safetensors {comfy_dir}/models/diffusion_models/flux2-dev.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/FLUX.2-dev/ae.safetensors {comfy_dir}/models/vae/ae.safetensors") | |
| # Text encoder β use Comfy-Org's pre-converted single-file version | |
| # (HF sharded format is incompatible with ComfyUI's CLIPLoader) | |
| te_file = "/runpod-volume/models/mistral_3_small_flux2_fp8.safetensors" | |
| te_exists = (await _ssh_exec_async(ssh, f"test -f {te_file} && echo EXISTS || echo MISSING")).strip() | |
| if te_exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading FLUX.2 text encoder (~12GB, first time only)..." | |
| await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download( | |
| repo_id='Comfy-Org/flux2-dev', | |
| filename='split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', | |
| local_dir='/tmp/flux2_te', | |
| ) | |
| import shutil | |
| shutil.move('/tmp/flux2_te/split_files/text_encoders/mistral_3_small_flux2_fp8.safetensors', '{te_file}') | |
| print('Text encoder downloaded') | |
| " 2>&1 | tail -5""", timeout=1800) | |
| await _ssh_exec_async(ssh, f"ln -sf {te_file} {comfy_dir}/models/text_encoders/mistral_3_small_flux2_fp8.safetensors") | |
| # Remove old sharded loader patch if present | |
| await _ssh_exec_async(ssh, f"rm -f {comfy_dir}/custom_nodes/sharded_loader.py") | |
| elif model_type == "flux1": | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/flux1-dev.safetensors {comfy_dir}/models/checkpoints/flux1-dev.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/ae.safetensors {comfy_dir}/models/vae/ae.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/clip_l.safetensors {comfy_dir}/models/text_encoders/clip_l.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf /workspace/models/t5xxl_fp16.safetensors {comfy_dir}/models/text_encoders/t5xxl_fp16.safetensors") | |
| elif model_type == "z_image": | |
| # Z-Image Turbo β 6B param model by Tongyi-MAI, runs in 16GB VRAM | |
| z_dir = "/runpod-volume/models/z_image" | |
| await _ssh_exec_async(ssh, f"mkdir -p {z_dir}") | |
| await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) | |
| # Delete FLUX.2 from volume to free space | |
| _pod_state["setup_status"] = "Cleaning up FLUX.2 from volume..." | |
| await _ssh_exec_async(ssh, "rm -rf /runpod-volume/models/FLUX.2-dev /runpod-volume/models/mistral_3_small_flux2_fp8.safetensors 2>/dev/null; echo done") | |
| # Download diffusion model (~12GB) | |
| diff_model = f"{z_dir}/z_image_turbo_bf16.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {diff_model} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading Z-Image Turbo diffusion model (~12GB)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import shutil, os | |
| p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/diffusion_models/z_image_turbo_bf16.safetensors', local_dir='/tmp/z_image') | |
| shutil.move(p, '{diff_model}') | |
| print('Diffusion model downloaded') | |
| " 2>&1 | tail -5""", timeout=3600) | |
| # Download text encoder (~8GB Qwen 3 4B) | |
| te_model = f"{z_dir}/qwen_3_4b.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {te_model} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading Z-Image text encoder (~8GB)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import shutil | |
| p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/text_encoders/qwen_3_4b.safetensors', local_dir='/tmp/z_image') | |
| shutil.move(p, '{te_model}') | |
| print('Text encoder downloaded') | |
| " 2>&1 | tail -5""", timeout=3600) | |
| # Download VAE (~335MB) | |
| vae_model = f"{z_dir}/ae.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {vae_model} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading Z-Image VAE..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import shutil | |
| p = hf_hub_download('Comfy-Org/z_image_turbo', 'split_files/vae/ae.safetensors', local_dir='/tmp/z_image') | |
| shutil.move(p, '{vae_model}') | |
| print('VAE downloaded') | |
| " 2>&1 | tail -5""", timeout=600) | |
| # Symlink into ComfyUI directories | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders {comfy_dir}/models/vae") | |
| await _ssh_exec_async(ssh, f"ln -sf {diff_model} {comfy_dir}/models/diffusion_models/z_image_turbo_bf16.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf {te_model} {comfy_dir}/models/text_encoders/qwen_3_4b.safetensors") | |
| await _ssh_exec_async(ssh, f"ln -sf {vae_model} {comfy_dir}/models/vae/ae_z_image.safetensors") | |
| elif model_type == "wan22": | |
| # WAN 2.2 Remix NSFW β dual-DiT MoE split-step for realistic generation | |
| wan_dir = "/workspace/models/WAN2.2" | |
| await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}") | |
| civitai_token = os.environ.get("CIVITAI_API_TOKEN", "") | |
| token_param = f"&token={civitai_token}" if civitai_token else "" | |
| # CivitAI Remix models (fp8 ~14GB each) | |
| civitai_models = { | |
| "Remix T2V High-noise": { | |
| "path": f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors", | |
| "url": f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}", | |
| }, | |
| "Remix T2V Low-noise": { | |
| "path": f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors", | |
| "url": f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}", | |
| }, | |
| } | |
| # HuggingFace models (T5 fp8 ~7GB, VAE ~1GB) | |
| hf_models = { | |
| "T5 text encoder (fp8)": { | |
| "path": f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors", | |
| "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", | |
| "filename": "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", | |
| }, | |
| "VAE": { | |
| "path": f"{wan_dir}/wan_2.1_vae.safetensors", | |
| "repo": "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", | |
| "filename": "split_files/vae/wan_2.1_vae.safetensors", | |
| }, | |
| } | |
| # Download CivitAI Remix models | |
| for label, info in civitai_models.items(): | |
| exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() | |
| if exists == "EXISTS": | |
| logger.info("WAN 2.2 %s already cached", label) | |
| else: | |
| _pod_state["setup_status"] = f"Downloading {label} (~14GB)..." | |
| await _ssh_exec_async(ssh, f"wget -q -O '{info['path']}' '{info['url']}'", timeout=1800) | |
| # Verify download | |
| check = (await _ssh_exec_async(ssh, f"test -f {info['path']} && stat -c%s {info['path']} || echo 0")).strip() | |
| if check == "0" or int(check) < 1000000: | |
| logger.error("Failed to download %s (size: %s). CivitAI API token may be required.", label, check) | |
| _pod_state["setup_status"] = f"Failed: {label} download failed. Set CIVITAI_API_TOKEN env var for NSFW models." | |
| return | |
| # Download HuggingFace models | |
| await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) | |
| for label, info in hf_models.items(): | |
| exists = (await _ssh_exec_async(ssh, f"test -f {info['path']} && echo EXISTS || echo MISSING")).strip() | |
| if exists == "EXISTS": | |
| logger.info("WAN 2.2 %s already cached", label) | |
| else: | |
| _pod_state["setup_status"] = f"Downloading {label}..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import os, shutil | |
| hf_hub_download('{info['repo']}', '{info['filename']}', local_dir='{wan_dir}') | |
| downloaded = os.path.join('{wan_dir}', '{info['filename']}') | |
| target = '{info['path']}' | |
| if os.path.exists(downloaded) and downloaded != target: | |
| os.makedirs(os.path.dirname(target), exist_ok=True) | |
| shutil.move(downloaded, target) | |
| print('Downloaded {label}') | |
| " 2>&1 | tail -5""", timeout=1800) | |
| # Symlink models into ComfyUI | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/text_encoders") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_high_fp8.safetensors {comfy_dir}/models/diffusion_models/") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan22_remix_t2v_low_fp8.safetensors {comfy_dir}/models/diffusion_models/") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/wan_2.1_vae.safetensors {comfy_dir}/models/vae/") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors {comfy_dir}/models/text_encoders/") | |
| # Install wanBlockSwap custom node (VRAM optimization for dual-DiT on 24GB GPUs) | |
| _pod_state["setup_status"] = "Installing WAN 2.2 custom nodes..." | |
| blockswap_dir = f"{comfy_dir}/custom_nodes/ComfyUI-wanBlockswap" | |
| blockswap_exists = (await _ssh_exec_async(ssh, f"test -d {blockswap_dir} && echo EXISTS || echo MISSING")).strip() | |
| if blockswap_exists != "EXISTS": | |
| await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/orssorbit/ComfyUI-wanBlockswap.git", timeout=120) | |
| elif model_type == "wan22_i2v": | |
| # WAN 2.2 Image-to-Video (14B params) β full model snapshot | |
| wan_dir = "/workspace/models/Wan2.2-I2V-A14B" | |
| wan_exists = (await _ssh_exec_async(ssh, f"test -d {wan_dir} && echo EXISTS || echo MISSING")).strip() | |
| if wan_exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading WAN 2.2 I2V model (~28GB, first time only)..." | |
| await _ssh_exec_async(ssh, f"pip install huggingface_hub 2>&1 | tail -1", timeout=60) | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import snapshot_download | |
| snapshot_download('Wan-AI/Wan2.2-I2V-A14B', local_dir='{wan_dir}', ignore_patterns=['*.md', '*.txt']) | |
| print('WAN 2.2 I2V downloaded') | |
| " 2>&1 | tail -10""", timeout=3600) | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/diffusion_models/Wan2.2-I2V-A14B") | |
| await _ssh_exec_async(ssh, f"ln -sf {wan_dir} {comfy_dir}/models/checkpoints/Wan2.2-I2V-A14B") | |
| # Install ComfyUI-WanVideoWrapper custom nodes | |
| _pod_state["setup_status"] = "Installing WAN 2.2 ComfyUI nodes..." | |
| wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper" | |
| wan_nodes_exist = (await _ssh_exec_async(ssh, f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip() | |
| if wan_nodes_exist != "EXISTS": | |
| await _ssh_exec_async(ssh, f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120) | |
| await _ssh_exec_async(ssh, f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) | |
| elif model_type == "wan22_animate": | |
| # WAN 2.2 Animate (14B fp8) β dance/motion transfer via pose skeleton | |
| animate_dir = "/workspace/models/WAN2.2-Animate" | |
| wan22_dir = "/workspace/models/WAN2.2" | |
| await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}") | |
| await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=60) | |
| # Download main Animate model (~28GB bf16 β only version available) | |
| animate_model = f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {animate_model} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading WAN 2.2 Animate model (~28GB, first time only)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import os, shutil | |
| hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors', local_dir='{animate_dir}') | |
| src = os.path.join('{animate_dir}', 'split_files', 'diffusion_models', 'wan2.2_animate_14B_bf16.safetensors') | |
| if os.path.exists(src): | |
| shutil.move(src, '{animate_model}') | |
| print('Animate model downloaded') | |
| " 2>&1 | tail -5""", timeout=7200) | |
| # CLIP Vision H (~2.5GB) β ViT-H vision encoder | |
| clip_vision_target = f"{animate_dir}/clip_vision_h.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {clip_vision_target} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading CLIP Vision H (~2.5GB)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import os, shutil | |
| result = hf_hub_download('h94/IP-Adapter', 'models/image_encoder/model.safetensors', local_dir='{animate_dir}/tmp_clip') | |
| shutil.move(result, '{clip_vision_target}') | |
| shutil.rmtree('{animate_dir}/tmp_clip', ignore_errors=True) | |
| print('CLIP Vision H downloaded') | |
| " 2>&1 | tail -5""", timeout=1800) | |
| # VAE β reuse from WAN2.2 dir if available, else download (~1GB) | |
| vae_target = f"{animate_dir}/wan_2.1_vae.safetensors" | |
| exists = (await _ssh_exec_async(ssh, f"test -f {vae_target} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| vae_from_wan22 = (await _ssh_exec_async(ssh, f"test -f {wan22_dir}/wan_2.1_vae.safetensors && echo EXISTS || echo MISSING")).strip() | |
| if vae_from_wan22 == "EXISTS": | |
| await _ssh_exec_async(ssh, f"ln -sf {wan22_dir}/wan_2.1_vae.safetensors {vae_target}") | |
| else: | |
| _pod_state["setup_status"] = "Downloading VAE (~1GB)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| import os, shutil | |
| hf_hub_download('Comfy-Org/Wan_2.2_ComfyUI_Repackaged', 'split_files/vae/wan_2.1_vae.safetensors', local_dir='{animate_dir}') | |
| src = os.path.join('{animate_dir}', 'split_files', 'vae', 'wan_2.1_vae.safetensors') | |
| if os.path.exists(src): | |
| shutil.move(src, '{vae_target}') | |
| print('VAE downloaded') | |
| " 2>&1 | tail -5""", timeout=600) | |
| # UMT5 T5 encoder fp8 (non-scaled) β use Kijai/WanVideo_comfy version | |
| # which is compatible with LoadWanVideoT5TextEncoder (scaled_fp8 is not supported) | |
| t5_filename = "umt5-xxl-enc-fp8_e4m3fn.safetensors" | |
| t5_target = f"{animate_dir}/{t5_filename}" | |
| t5_comfy_path = f"{comfy_dir}/models/text_encoders/{t5_filename}" | |
| t5_in_comfy = (await _ssh_exec_async(ssh, f"test -f {t5_comfy_path} && echo EXISTS || echo MISSING")).strip() | |
| t5_in_vol = (await _ssh_exec_async(ssh, f"test -f {t5_target} && echo EXISTS || echo MISSING")).strip() | |
| if t5_in_comfy != "EXISTS" and t5_in_vol != "EXISTS": | |
| _pod_state["setup_status"] = "Downloading UMT5 text encoder (~6.3GB, first time only)..." | |
| await _ssh_exec_async(ssh, f"""python -c " | |
| from huggingface_hub import hf_hub_download | |
| hf_hub_download('Kijai/WanVideo_comfy', '{t5_filename}', local_dir='{animate_dir}') | |
| print('UMT5 text encoder downloaded') | |
| " 2>&1 | tail -5""", timeout=1800) | |
| t5_in_vol = "EXISTS" | |
| # Symlink models into ComfyUI directories | |
| await _ssh_exec_async(ssh, f"mkdir -p {comfy_dir}/models/diffusion_models {comfy_dir}/models/vae {comfy_dir}/models/clip_vision {comfy_dir}/models/text_encoders") | |
| await _ssh_exec_async(ssh, f"ln -sf {animate_model} {comfy_dir}/models/diffusion_models/") | |
| await _ssh_exec_async(ssh, f"ln -sf {vae_target} {comfy_dir}/models/vae/") | |
| await _ssh_exec_async(ssh, f"ln -sf {clip_vision_target} {comfy_dir}/models/clip_vision/") | |
| if t5_in_vol == "EXISTS" and t5_in_comfy != "EXISTS": | |
| await _ssh_exec_async(ssh, f"ln -sf {t5_target} {t5_comfy_path}") | |
| # Reconnect SSH before custom node setup β connection may have dropped during long downloads | |
| ssh = await _ssh_connect_new() | |
| # Install required custom nodes | |
| _pod_state["setup_status"] = "Installing WAN Animate custom nodes..." | |
| # ComfyUI-WanVideoWrapper (WanVideoAnimateEmbeds, WanVideoSampler, etc.) | |
| wan_nodes_dir = f"{comfy_dir}/custom_nodes/ComfyUI-WanVideoWrapper" | |
| exists = (await _ssh_exec_r(f"test -d {wan_nodes_dir} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-WanVideoWrapper.git", timeout=120) | |
| await _ssh_exec_r(f"cd {wan_nodes_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) | |
| # ComfyUI-VideoHelperSuite (VHS_LoadVideo, VHS_VideoCombine) | |
| vhs_dir = f"{comfy_dir}/custom_nodes/ComfyUI-VideoHelperSuite" | |
| exists = (await _ssh_exec_r(f"test -d {vhs_dir} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", timeout=120) | |
| await _ssh_exec_r(f"cd {vhs_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) | |
| # comfyui_controlnet_aux (DWPreprocessor for pose extraction) | |
| aux_dir = f"{comfy_dir}/custom_nodes/comfyui_controlnet_aux" | |
| exists = (await _ssh_exec_r(f"test -d {aux_dir} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/Fannovel16/comfyui_controlnet_aux.git", timeout=120) | |
| await _ssh_exec_r(f"cd {aux_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) | |
| # ComfyUI-KJNodes (ImageResizeKJv2 used in animate workflow) | |
| kj_dir = f"{comfy_dir}/custom_nodes/ComfyUI-KJNodes" | |
| exists = (await _ssh_exec_r(f"test -d {kj_dir} && echo EXISTS || echo MISSING")).strip() | |
| if exists != "EXISTS": | |
| await _ssh_exec_r(f"cd {comfy_dir}/custom_nodes && git clone --depth 1 https://github.com/kijai/ComfyUI-KJNodes.git", timeout=120) | |
| await _ssh_exec_r(f"cd {kj_dir} && pip install -r requirements.txt 2>&1 | tail -5", timeout=300) | |
| # Symlink all LoRAs from volume | |
| await _ssh_exec_r(f"ls /runpod-volume/loras/*.safetensors 2>/dev/null | while read f; do ln -sf \"$f\" {comfy_dir}/models/loras/; done") | |
| # Start ComfyUI in background (fire-and-forget β don't wait for output) | |
| _pod_state["setup_status"] = "Starting ComfyUI..." | |
| await asyncio.to_thread( | |
| _ssh_exec_fire_and_forget, | |
| ssh, | |
| f"cd {comfy_dir} && python main.py --listen 0.0.0.0 --port 8188 --fp8_e4m3fn-unet > /tmp/comfyui.log 2>&1", | |
| ) | |
| await asyncio.sleep(2) # Give it a moment to start | |
| # Wait for ComfyUI HTTP to respond | |
| _pod_state["setup_status"] = "Waiting for ComfyUI to load model..." | |
| import httpx | |
| comfyui_url = _get_comfyui_url() | |
| for attempt in range(120): # Up to 10 minutes | |
| try: | |
| async with httpx.AsyncClient(timeout=5) as client: | |
| resp = await client.get(f"{comfyui_url}/system_stats") | |
| if resp.status_code == 200: | |
| _pod_state["status"] = "running" | |
| _pod_state["setup_status"] = "Ready" | |
| _save_pod_state() | |
| logger.info("ComfyUI ready at %s", comfyui_url) | |
| return | |
| except Exception: | |
| pass | |
| await asyncio.sleep(5) | |
| # If we get here, ComfyUI didn't start | |
| # Check the log for errors | |
| log_tail = await _ssh_exec_async(ssh, "tail -20 /tmp/comfyui.log") | |
| logger.error("ComfyUI didn't start. Log: %s", log_tail) | |
| _pod_state["setup_status"] = f"ComfyUI failed to start. Check logs." | |
| _pod_state["status"] = "setting_up" # Keep pod running so user can debug | |
| except Exception as e: | |
| import traceback | |
| err_msg = f"{type(e).__name__}: {e}" | |
| logger.error("Pod setup failed: %s\n%s", err_msg, traceback.format_exc()) | |
| _pod_state["setup_status"] = f"Setup failed: {err_msg}" | |
| _pod_state["status"] = "setting_up" # Keep pod running so user can debug | |
| finally: | |
| try: | |
| ssh.close() | |
| except Exception: | |
| pass | |
| def _ssh_exec(ssh, cmd: str, timeout: int = 120) -> str: | |
| """Execute a command over SSH and return stdout (blocking β call from async via to_thread or background task).""" | |
| _, stdout, stderr = ssh.exec_command(cmd, timeout=timeout) | |
| out = stdout.read().decode("utf-8", errors="replace") | |
| return out.strip() | |
async def _ssh_exec_async(ssh, cmd: str, timeout: int = 120) -> str:
    """Await ``_ssh_exec`` in a worker thread so the event loop is not blocked."""
    output = await asyncio.to_thread(_ssh_exec, ssh, cmd, timeout)
    return output
| def _ssh_exec_fire_and_forget(ssh, cmd: str): | |
| """Start a command over SSH without waiting for output (for background processes).""" | |
| transport = ssh.get_transport() | |
| channel = transport.open_session() | |
| channel.exec_command(cmd) | |
| # Don't read stdout/stderr β just let it run | |
| # --- Pre-download models to network volume (saves money during training) --- | |
| _download_state = { | |
| "status": "idle", # idle, downloading, completed, failed | |
| "pod_id": None, | |
| "progress": "", | |
| "error": None, | |
| } | |
class DownloadModelsRequest(BaseModel):
    """Request body for pre-downloading model files to the network volume."""

    # Which model bundle to fetch.
    model_type: str = "wan22"
    # Cheapest GPU, just for downloading
    gpu_type: str = "NVIDIA GeForce RTX 3090"
async def download_models_to_volume(request: DownloadModelsRequest):
    """Pre-download model files to the network volume using a cheap pod.

    This saves expensive GPU time during training: models are cached on the
    shared volume and reused across all future training/generation pods.

    The download itself runs as a background task; this function only
    validates the configuration, marks ``_download_state`` as "downloading"
    and returns.  Progress is reported through ``_download_state``.

    Raises:
        HTTPException: 400 when ``RUNPOD_VOLUME_ID`` is not configured.
            ``_get_api_key()`` is called first and presumably raises as well
            when no API key is available (defined outside this block).
    """
    _get_api_key()
    volume_id, volume_dc = _get_volume_config()
    if not volume_id:
        raise HTTPException(400, "No network volume configured (set RUNPOD_VOLUME_ID)")

    if _download_state["status"] == "downloading":
        return {"status": "already_downloading", "progress": _download_state["progress"]}

    _download_state["status"] = "downloading"
    _download_state["progress"] = "Creating cheap download pod..."
    _download_state["error"] = None

    task = asyncio.create_task(
        _download_models_task(request.model_type, request.gpu_type, volume_id, volume_dc)
    )
    # The event loop only keeps weak references to tasks; hold a strong one
    # until the task finishes so it cannot be garbage-collected mid-download.
    _download_tasks.add(task)
    task.add_done_callback(_download_tasks.discard)

    return {
        "status": "started",
        "message": f"Downloading {request.model_type} models to volume (using {request.gpu_type})",
    }


# Strong references to in-flight model-download tasks (see above).
_download_tasks: set = set()
async def download_models_status():
    """Report model download progress.

    Returns the shared ``_download_state`` record itself (status, pod_id,
    progress, error), not a copy.
    """
    current_state = _download_state
    return current_state
async def _download_models_task(model_type: str, gpu_type: str, volume_id: str, volume_dc: str):
    """Background task: spin up a cheap pod, download models to the volume, terminate.

    Progress and the final outcome are written to ``_download_state``; any
    ``Exception`` is caught and reported there instead of being raised.  The
    SSH connection is closed and the pod terminated in ``finally``.

    Args:
        model_type: ``"wan22"`` or ``"wan22_animate"``.  Any other value is
            reported as a failure before a pod is created.
        gpu_type: Preferred GPU type; a fixed list of fallbacks is tried when
            RunPod reports it as unavailable.
        volume_id: Network volume attached to the pod.
        volume_dc: Data center of the volume; passed as ``data_center_id``
            when non-empty.
    """
    import paramiko

    ssh = None
    pod_id = None

    async def _remote_size(path: str) -> int:
        """Return the size in bytes of ``path`` on the pod, 0 when it is missing."""
        out = (await _ssh_exec_async(ssh, f"stat -c%s '{path}' 2>/dev/null || echo 0")).strip()
        return int(out) if out.isdigit() else 0

    async def _aria2_incomplete(path: str) -> bool:
        """Return True while an aria2 control file marks ``path`` as unfinished.

        aria2 may pre-allocate the full file size, so size alone cannot tell a
        finished download from an interrupted one.
        """
        out = (await _ssh_exec_async(ssh, f"test -e '{path}.aria2' && echo PARTIAL || echo DONE")).strip()
        return out == "PARTIAL"

    try:
        # Fail before paying for a pod that would have nothing to download.
        if model_type not in ("wan22", "wan22_animate"):
            raise ValueError(f"Unknown model_type: {model_type!r}")

        # Create cheap pod with network volume — try multiple GPU types if first unavailable
        pod_kwargs = {
            "container_disk_in_gb": 10,
            "ports": "22/tcp",
            "network_volume_id": volume_id,
            "docker_args": "bash -c 'apt-get update && apt-get install -y openssh-server && mkdir -p /run/sshd && echo root:runpod | chpasswd && /usr/sbin/sshd -o PermitRootLogin=yes && sleep infinity'",
        }
        if volume_dc:
            pod_kwargs["data_center_id"] = volume_dc

        # dict.fromkeys drops duplicates (the requested GPU may already be in
        # the fallback list) while keeping the order.
        gpu_fallbacks = list(dict.fromkeys([
            gpu_type,
            "NVIDIA RTX A4000",
            "NVIDIA RTX A5000",
            "NVIDIA GeForce RTX 4090",
            "NVIDIA GeForce RTX 4080",
            "NVIDIA A100-SXM4-80GB",
        ]))

        pod = None
        used_gpu = gpu_type
        for try_gpu in gpu_fallbacks:
            try:
                pod = await asyncio.to_thread(
                    runpod.create_pod,
                    f"model-download-{model_type}",
                    DOCKER_IMAGE,
                    try_gpu,
                    **pod_kwargs,
                )
            except Exception as e:
                message = str(e)
                # Only availability errors fall through to the next GPU type.
                if "SUPPLY_CONSTRAINT" in message or "no longer any instances" in message.lower():
                    logger.info("GPU %s unavailable, trying next...", try_gpu)
                    continue
                raise
            used_gpu = try_gpu
            logger.info("Download pod created with %s", try_gpu)
            break

        if pod is None:
            raise RuntimeError("No GPU available for download pod. Try again later.")

        pod_id = pod["id"]
        _download_state["pod_id"] = pod_id
        _download_state["progress"] = f"Pod created with {used_gpu} ({pod_id}), waiting for SSH..."

        # Wait (up to 5 minutes) for the pod to expose its SSH port
        ssh_host = ssh_port = None
        start = time.time()
        while time.time() - start < 300:
            try:
                p = await asyncio.to_thread(runpod.get_pod, pod_id)
                if p and p.get("desiredStatus") == "RUNNING":
                    for port in (p.get("runtime") or {}).get("ports") or []:
                        if port.get("privatePort") == 22:
                            ssh_host = port.get("ip")
                            ssh_port = port.get("publicPort")
                if ssh_host and ssh_port:
                    break
            except Exception as e:
                # Transient API errors are expected while the pod boots.
                logger.debug("Polling download pod %s failed: %s", pod_id, e)
            await asyncio.sleep(5)

        # Both values are needed below; a host without a port is not usable.
        if not (ssh_host and ssh_port):
            raise RuntimeError("Pod SSH not available after 5 min")

        # Connect SSH (sshd may still be installing, so retry)
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(20):
            try:
                await asyncio.to_thread(
                    ssh.connect, ssh_host, port=int(ssh_port), username="root", password="runpod", timeout=10
                )
                break
            except Exception as e:
                if attempt == 19:
                    raise RuntimeError("SSH connection failed after 20 attempts") from e
                await asyncio.sleep(5)

        ssh.get_transport().set_keepalive(30)
        _download_state["progress"] = "SSH connected, setting up tools..."

        # Symlink volume
        await _ssh_exec_async(ssh, "mkdir -p /runpod-volume/models && rm -rf /workspace/models 2>/dev/null; ln -sf /runpod-volume/models /workspace/models")
        await _ssh_exec_async(ssh, "pip install huggingface_hub 2>&1 | tail -1", timeout=120)
        await _ssh_exec_async(ssh, "which aria2c || apt-get install -y aria2 2>&1 | tail -1", timeout=120)

        if model_type == "wan22":
            wan_dir = "/workspace/models/WAN2.2"
            await _ssh_exec_async(ssh, f"mkdir -p {wan_dir}")

            civitai_token = os.environ.get("CIVITAI_API_TOKEN", "")
            token_param = f"&token={civitai_token}" if civitai_token else ""

            # CivitAI Remix models (fp8)
            civitai_files = [
                ("Remix T2V High-noise", f"https://civitai.com/api/download/models/2424167?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_high_fp8.safetensors"),
                ("Remix T2V Low-noise", f"https://civitai.com/api/download/models/2424912?type=Model&format=SafeTensor&size=pruned{token_param}", f"{wan_dir}/wan22_remix_t2v_low_fp8.safetensors"),
            ]
            # HuggingFace models
            hf_files = [
                ("T5 text encoder (fp8)", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", f"{wan_dir}/umt5_xxl_fp8_e4m3fn_scaled.safetensors"),
                ("VAE", "Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/vae/wan_2.1_vae.safetensors", f"{wan_dir}/wan_2.1_vae.safetensors"),
            ]

            total = len(civitai_files) + len(hf_files)
            idx = 0
            # Anything smaller than this is an error page or empty stub, not a model.
            min_civitai_bytes = 1_000_000

            for label, url, target in civitai_files:
                idx += 1
                if await _remote_size(target) >= min_civitai_bytes:
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                    continue

                _download_state["progress"] = f"[{idx}/{total}] Downloading {label} (~14GB)..."
                # Download to a temporary name and rename on success, so an
                # interrupted transfer is never mistaken for a cached model.
                await _ssh_exec_async(
                    ssh,
                    f"wget -q -O '{target}.part' '{url}' && mv -f '{target}.part' '{target}'",
                    timeout=1800,
                )
                if await _remote_size(target) < min_civitai_bytes:
                    await _ssh_exec_async(ssh, f"rm -f '{target}' '{target}.part'")
                    raise RuntimeError(f"Failed to download {label}. Set CIVITAI_API_TOKEN for NSFW models.")
                _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

            for label, repo, filename, target in hf_files:
                idx += 1
                if await _remote_size(target) > 0 and not await _aria2_incomplete(target):
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    logger.info("WAN 2.2 %s already on volume", label)
                    continue

                _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                hf_url = f"https://huggingface.co/{repo}/resolve/main/{filename}"
                tdir, _, fname = target.rpartition("/")
                await _ssh_exec_async(
                    ssh,
                    f"aria2c -x 16 -s 16 -c -o '{fname}' --dir='{tdir}' '{hf_url}' 2>&1 | tail -3",
                    timeout=1800,
                )
                if await _remote_size(target) == 0 or await _aria2_incomplete(target):
                    raise RuntimeError(f"Failed to download {label}")
                _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

            # Also pre-clone musubi-tuner to volume (for training)
            _download_state["progress"] = "Caching musubi-tuner to volume..."
            tuner_exists = (await _ssh_exec_async(ssh, "test -f /runpod-volume/musubi-tuner/pyproject.toml && echo EXISTS || echo MISSING")).strip()
            if tuner_exists != "EXISTS":
                await _ssh_exec_async(ssh, "cd /workspace && git clone --depth 1 https://github.com/kohya-ss/musubi-tuner.git && cp -r /workspace/musubi-tuner /runpod-volume/musubi-tuner", timeout=300)
                _download_state["progress"] = "musubi-tuner cached"
            else:
                _download_state["progress"] = "musubi-tuner already cached"

        elif model_type == "wan22_animate":
            animate_dir = "/workspace/models/WAN2.2-Animate"
            wan22_dir = "/workspace/models/WAN2.2"
            hf_base = "https://huggingface.co"
            await _ssh_exec_async(ssh, f"mkdir -p {animate_dir}")

            # Files to download:
            # (label, url, target, timeout_s, min_bytes, reuse_from_wan22)
            # reuse_from_wan22 marks files that may be linked from the WAN2.2
            # directory instead of being downloaded again.
            wget_files = [
                (
                    "WAN 2.2 Animate model (~32GB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_animate_14B_bf16.safetensors",
                    f"{animate_dir}/wan2.2_animate_14B_bf16.safetensors",
                    7200,
                    30_000_000_000,  # 30GB min — partial downloads get resumed
                    False,
                ),
                (
                    "UMT5 text encoder fp8 (~6.3GB)",
                    f"{hf_base}/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    f"{animate_dir}/umt5-xxl-enc-fp8_e4m3fn.safetensors",
                    1800,
                    6_000_000_000,
                    True,
                ),
                (
                    "VAE (~242MB)",
                    f"{hf_base}/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors",
                    f"{animate_dir}/wan_2.1_vae.safetensors",
                    300,
                    200_000_000,
                    True,
                ),
                (
                    "CLIP Vision H (~2.4GB)",
                    f"{hf_base}/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors",
                    f"{animate_dir}/clip_vision_h.safetensors",
                    900,
                    2_000_000_000,
                    False,
                ),
            ]

            total = len(wget_files)
            for idx, (label, url, target, dl_timeout, min_bytes, reuse_from_wan22) in enumerate(wget_files, 1):
                target_dir, _, filename = target.rpartition("/")

                # For T5 and VAE, reuse from wan22 dir if already present (and complete)
                if reuse_from_wan22:
                    wan22_candidate = f"{wan22_dir}/{filename}"
                    candidate_ok = (
                        await _remote_size(wan22_candidate) >= min_bytes
                        and not await _aria2_incomplete(wan22_candidate)
                    )
                    if candidate_ok:
                        _download_state["progress"] = f"[{idx}/{total}] {label} — reusing from WAN2.2 dir"
                        await _ssh_exec_async(ssh, f"ln -sf {wan22_candidate} {target} 2>/dev/null || cp {wan22_candidate} {target}")
                        continue

                if await _remote_size(target) >= min_bytes and not await _aria2_incomplete(target):
                    _download_state["progress"] = f"[{idx}/{total}] {label} already cached"
                    continue

                _download_state["progress"] = f"[{idx}/{total}] Downloading {label}..."
                # Remove stale symlinks before downloading (can't resume through a symlink)
                await _ssh_exec_async(ssh, f"test -L '{target}' && rm -f '{target}'; true")
                await _ssh_exec_async(
                    ssh,
                    f"aria2c -x 16 -s 16 -c -o '{filename}' --dir='{target_dir}' '{url}' 2>&1 | tail -3",
                    timeout=dl_timeout,
                )
                final_size = await _remote_size(target)
                if final_size < min_bytes or await _aria2_incomplete(target):
                    raise RuntimeError(f"Failed to download {label} (size {final_size} < {min_bytes})")
                _download_state["progress"] = f"[{idx}/{total}] {label} downloaded"

        _download_state["status"] = "completed"
        _download_state["progress"] = "All models downloaded to volume! Ready for training."
        logger.info("Model pre-download complete for %s", model_type)

    except Exception as e:
        _download_state["status"] = "failed"
        _download_state["error"] = str(e)
        _download_state["progress"] = f"Failed: {e}"
        logger.exception("Model download failed: %s", e)
    finally:
        if ssh:
            try:
                ssh.close()
            except Exception:
                pass
        # Always terminate the pod so a failed download does not keep billing.
        if pod_id:
            try:
                await asyncio.to_thread(runpod.terminate_pod, pod_id)
                logger.info("Download pod terminated: %s", pod_id)
            except Exception as e:
                logger.warning("Failed to terminate download pod: %s", e)
        _download_state["pod_id"] = None
async def stop_pod():
    """Stop the GPU pod.

    Terminates the pod recorded in ``_pod_state`` through the RunPod API,
    clears the connection details and persists the new state.

    Returns:
        ``{"status": "already_stopped"}`` when no pod id is recorded,
        ``{"status": "stopping", ...}`` when a stop is already in progress,
        otherwise ``{"status": "stopped", ...}``.

    Raises:
        HTTPException: 500 when the terminate call fails.  The status that was
            in effect before the call is restored, so a pod that was still
            being set up is not reported as "running".
    """
    _get_api_key()

    pod_id = _pod_state["pod_id"]
    if not pod_id:
        return {"status": "already_stopped"}

    if _pod_state["status"] == "stopping":
        return {"status": "stopping", "message": "Pod is already stopping"}

    previous_status = _pod_state["status"]
    _pod_state["status"] = "stopping"

    try:
        logger.info("Stopping pod: %s", pod_id)
        await asyncio.to_thread(runpod.terminate_pod, pod_id)
    except Exception as e:
        logger.error("Failed to stop pod: %s", e)
        _pod_state["status"] = previous_status
        raise HTTPException(500, f"Failed to stop pod: {e}") from e

    _pod_state.update(
        {
            "pod_id": None,
            "ip": None,
            "ssh_port": None,
            "comfyui_port": None,
            "status": "stopped",
            "started_at": None,
            "setup_status": None,
        }
    )
    _save_pod_state()

    logger.info("Pod stopped")
    return {"status": "stopped", "message": "Pod terminated"}
async def list_pod_loras():
    """List LoRAs available on the pod.

    Asks the pod's ComfyUI for the ``LoraLoader`` node description and returns
    the choices of its ``lora_name`` input.  Any failure is logged and yields
    an empty list together with the ComfyUI URL that was tried.
    """
    pod_ready = _pod_state["status"] == "running" and _pod_state["ip"]
    if not pod_ready:
        return {"loras": [], "message": "Pod not running"}

    comfyui_url = _get_comfyui_url()

    try:
        import httpx

        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.get(f"{comfyui_url}/object_info/LoraLoader")
            if resp.status_code == 200:
                node_info = resp.json().get("LoraLoader", {})
                required_inputs = node_info.get("input", {}).get("required", {})
                # The first element of the input spec holds the list of choices.
                choices = required_inputs.get("lora_name", [[]])[0]
                if isinstance(choices, list):
                    return {"loras": choices}
                return {"loras": []}
    except Exception as e:
        logger.warning("Failed to list pod LoRAs: %s", e)

    return {"loras": [], "comfyui_url": comfyui_url}
async def upload_lora_to_pod(file: UploadFile = File(...)):
    """Upload a LoRA file directly to /runpod-volume/loras/ via SFTP so it persists.

    After the upload the file is symlinked into ComfyUI's ``models/loras``
    directory on the pod.

    Raises:
        HTTPException: 400 when the pod is not running or the file is not a
            ``.safetensors`` file; 500 when no SSH address is known or the
            upload fails.
    """
    import io
    import shlex
    from pathlib import PurePosixPath

    import paramiko

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    # The filename comes from the client: keep only its last path component so
    # it cannot escape the loras directory ("../"), and tolerate a missing name.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = PurePosixPath(raw_name).name
    if not safe_name.endswith(".safetensors"):
        raise HTTPException(400, "Only .safetensors files supported")

    ip = _pod_state.get("ip")
    port = _pod_state.get("ssh_port") or 22
    if not ip:
        raise HTTPException(500, "No SSH IP available")

    content = await file.read()
    dest_path = f"/runpod-volume/loras/{safe_name}"
    comfy_link = f"/workspace/ComfyUI/models/loras/{safe_name}"

    def _sftp_upload():
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(ip, port=port, username="root", timeout=30)
            # Ensure dir exists (reading stdout waits for the command to finish)
            client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
            sftp = client.open_sftp()
            try:
                sftp.putfo(io.BytesIO(content), dest_path)
            finally:
                sftp.close()
            # Symlink into ComfyUI; quote both paths because the name is user input
            client.exec_command(f"ln -sf {shlex.quote(dest_path)} {shlex.quote(comfy_link)}")[1].read()
        finally:
            # Close the connection even when the upload fails part-way.
            client.close()

    try:
        await asyncio.to_thread(_sftp_upload)
        logger.info("LoRA uploaded to volume: %s (%d bytes)", safe_name, len(content))
        return {"status": "uploaded", "filename": safe_name, "path": dest_path}
    except Exception as e:
        logger.error("LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e
async def upload_lora_from_local(local_path: str, filename: str | None = None):
    """Upload a LoRA from a local server path directly to the volume via SFTP.

    Args:
        local_path: Path of the file on this server.
        filename: Name to store the file under on the volume; defaults to the
            name of ``local_path``.  Only the last path component is used.

    Raises:
        HTTPException: 400 when the pod is not running or the destination name
            is not a ``.safetensors`` file; 404 when ``local_path`` is not an
            existing file; 500 when no SSH address is known or the upload fails.
    """
    import shlex
    from pathlib import Path, PurePosixPath

    import paramiko

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    src = Path(local_path)
    # is_file() rather than exists(): a directory cannot be uploaded with put().
    if not src.is_file():
        raise HTTPException(404, f"Local file not found: {local_path}")

    # Keep only the last path component so the destination cannot escape the
    # loras directory.
    requested_name = (filename or src.name).replace("\\", "/")
    dest_name = PurePosixPath(requested_name).name
    if not dest_name.endswith(".safetensors"):
        raise HTTPException(400, "Only .safetensors files supported")

    ip = _pod_state.get("ip")
    port = _pod_state.get("ssh_port") or 22
    if not ip:
        raise HTTPException(500, "No SSH IP available")

    dest_path = f"/runpod-volume/loras/{dest_name}"
    comfy_link = f"/workspace/ComfyUI/models/loras/{dest_name}"

    def _sftp_upload():
        client = paramiko.SSHClient()
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            client.connect(ip, port=port, username="root", timeout=30)
            # Reading stdout waits for the remote command to finish.
            client.exec_command("mkdir -p /runpod-volume/loras")[1].read()
            sftp = client.open_sftp()
            try:
                sftp.put(str(src), dest_path)
            finally:
                sftp.close()
            # Quote both paths: the destination name may come from the caller.
            client.exec_command(f"ln -sf {shlex.quote(dest_path)} {shlex.quote(comfy_link)}")[1].read()
        finally:
            # Close the connection even when the upload fails part-way.
            client.close()

    try:
        await asyncio.to_thread(_sftp_upload)
        size_mb = src.stat().st_size / 1024 / 1024
        logger.info("LoRA uploaded from local: %s (%.1f MB)", dest_name, size_mb)
        return {"status": "uploaded", "filename": dest_name, "path": dest_path, "size_mb": round(size_mb, 1)}
    except Exception as e:
        logger.error("Local LoRA upload failed: %s", e)
        raise HTTPException(500, f"Upload failed: {e}") from e
class PodGenerateRequest(BaseModel):
    """Request body for generating an image on the running pod."""

    # Prompting
    prompt: str
    negative_prompt: str = ""

    # Image size and sampling settings
    width: int = 1024
    height: int = 1024
    steps: int = 28
    cfg: float = 3.5
    # A negative seed is replaced by a random one when the job is submitted.
    seed: int = -1

    # Optional LoRAs; the second one is only passed to the WAN 2.2 workflow.
    lora_name: str | None = None
    lora_strength: float = 0.85
    lora_name_2: str | None = None
    lora_strength_2: float = 0.85

    # Catalog metadata; content_rating also selects the output sub-directory.
    character_id: str | None = None
    template_id: str | None = None
    content_rating: str = "sfw"
| # In-memory job tracking for pod generation | |
| _pod_jobs: dict[str, dict] = {} | |
async def generate_on_pod(request: PodGenerateRequest):
    """Generate an image using the running pod's ComfyUI.

    Builds a workflow for the pod's model type, submits it to ComfyUI's
    ``/prompt`` endpoint, records the job in ``_pod_jobs`` and starts a
    background task that polls for the result.

    Returns:
        ``{"job_id": ..., "status": "running", "seed": ...}``.

    Raises:
        HTTPException: 400 when the pod is not running; 500 when submitting
            the workflow fails.
    """
    import random

    import httpx

    if _pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first")

    job_id = str(uuid.uuid4())[:8]
    # A negative seed means "pick one for me".
    seed = request.seed if request.seed >= 0 else random.randint(0, 2**32 - 1)

    model_type = _pod_state.get("model_type", "flux2")

    # Arguments shared by both workflow builders.
    common_args = {
        "prompt": request.prompt,
        "negative_prompt": request.negative_prompt,
        "width": request.width,
        "height": request.height,
        "steps": request.steps,
        "cfg": request.cfg,
        "seed": seed,
        "lora_name": request.lora_name,
        "lora_strength": request.lora_strength,
    }
    if model_type == "wan22":
        workflow = _build_wan_t2i_workflow(
            **common_args,
            lora_name_2=request.lora_name_2,
            lora_strength_2=request.lora_strength_2,
        )
    else:
        workflow = _build_flux_workflow(**common_args, model_type=model_type)

    comfyui_url = _get_comfyui_url()

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post(f"{comfyui_url}/prompt", json={"prompt": workflow})
            resp.raise_for_status()
            prompt_id = resp.json()["prompt_id"]

        now = time.time()
        _pod_jobs[job_id] = {
            "prompt_id": prompt_id,
            "status": "running",
            "seed": seed,
            "created_at": now,
            "started_at": now,
            "positive_prompt": request.prompt,
            "negative_prompt": request.negative_prompt,
            "steps": request.steps,
            "cfg": request.cfg,
            "width": request.width,
            "height": request.height,
        }

        logger.info("Pod generation started: %s -> %s", job_id, prompt_id)

        poll_task = asyncio.create_task(_poll_pod_job(job_id, prompt_id, request.content_rating))
        # The event loop only keeps weak references to tasks; hold a strong one
        # until polling finishes so the task cannot be garbage-collected early.
        _pod_generation_tasks.add(poll_task)
        poll_task.add_done_callback(_pod_generation_tasks.discard)

        return {"job_id": job_id, "status": "running", "seed": seed}

    except Exception as e:
        logger.error("Pod generation failed: %s", e)
        raise HTTPException(500, f"Generation failed: {e}") from e


# Strong references to in-flight polling tasks (see generate_on_pod).
_pod_generation_tasks: set = set()
async def _poll_pod_job(job_id: str, prompt_id: str, content_rating: str):
    """Poll ComfyUI for job completion and save the result.

    Every 2 seconds the prompt's history entry is fetched.  When an output
    node lists an image, the first image is downloaded, written under the
    configured output directory, recorded in ``_pod_jobs[job_id]`` and
    (best effort) inserted into the catalog.  Roughly every 15 seconds a
    human-readable progress message is stored in the job as well.

    The job is marked "failed" when ComfyUI reports an execution error for
    the prompt or when nothing was produced within 15 minutes.
    """
    import httpx

    start = time.time()
    timeout = 900  # 15 min — first gen loads model (~12GB) + samples
    comfyui_url = _get_comfyui_url()
    last_log_time = 0
    job = _pod_jobs[job_id]

    async with httpx.AsyncClient(timeout=60) as client:
        while time.time() - start < timeout:
            try:
                # Log queue progress every 15 seconds and store in job
                elapsed = time.time() - start
                if elapsed - last_log_time >= 15:
                    last_log_time = elapsed
                    try:
                        q_resp = await client.get(f"{comfyui_url}/queue")
                        if q_resp.status_code == 200:
                            q_data = q_resp.json()
                            running = q_data.get("queue_running", [])
                            pending = len(q_data.get("queue_pending", []))
                            status_msg = f"{int(elapsed)}s elapsed"
                            if running:
                                # Try to get node execution progress
                                try:
                                    p_resp = await client.get(f"{comfyui_url}/prompt")
                                    if p_resp.status_code == 200:
                                        exec_info = p_resp.json().get("exec_info", {})
                                        if exec_info:
                                            status_msg += f" | nodes: {exec_info}"
                                except Exception:
                                    # Progress detail is optional; keep the basic message.
                                    pass
                                status_msg += " | generating..."
                            elif pending:
                                status_msg += " | loading models..."
                            else:
                                status_msg += " | waiting..."
                            job["progress_msg"] = status_msg
                            logger.info("Pod gen %s: %s", job_id, status_msg)
                    except Exception as e:
                        # Progress reporting must never abort the polling itself.
                        logger.debug("Pod gen %s: progress check failed: %s", job_id, e)

                resp = await client.get(f"{comfyui_url}/history/{prompt_id}")
                history = resp.json() if resp.status_code == 200 else {}
                if prompt_id in history:
                    entry = history[prompt_id]
                    for node_output in entry.get("outputs", {}).values():
                        if "images" not in node_output:
                            continue
                        image_info = node_output["images"][0]
                        params = {"filename": image_info["filename"]}
                        subfolder = image_info.get("subfolder", "")
                        if subfolder:
                            params["subfolder"] = subfolder

                        img_resp = await client.get(f"{comfyui_url}/view", params=params)
                        if img_resp.status_code != 200:
                            continue

                        from content_engine.config import settings

                        output_dir = settings.paths.output_dir / "pod" / content_rating / "raw"
                        output_dir.mkdir(parents=True, exist_ok=True)
                        local_path = output_dir / f"pod_{job_id}.png"
                        local_path.write_bytes(img_resp.content)

                        job["status"] = "completed"
                        job["output_path"] = str(local_path)
                        job["completed_at"] = time.time()
                        logger.info("Pod generation completed: %s -> %s", job_id, local_path)

                        # Cataloging is best effort: the image is already saved.
                        try:
                            from content_engine.services.catalog import CatalogService

                            catalog = CatalogService()
                            await catalog.insert_image(
                                file_path=str(local_path),
                                image_bytes=img_resp.content,
                                content_rating=content_rating,
                                positive_prompt=job.get("positive_prompt"),
                                negative_prompt=job.get("negative_prompt"),
                                seed=job.get("seed"),
                                steps=job.get("steps"),
                                cfg=job.get("cfg"),
                                width=job.get("width"),
                                height=job.get("height"),
                                generation_backend="runpod-pod",
                                generation_time_seconds=time.time() - job.get("created_at", time.time()),
                            )
                            logger.info("Pod image cataloged: %s", job_id)
                        except Exception as e:
                            logger.warning("Failed to catalog pod image: %s", e)
                        return

                    # No image was retrieved.  If ComfyUI recorded the run as
                    # failed there is nothing to wait for, so stop instead of
                    # polling until the timeout.
                    # NOTE(review): assumes the history entry carries
                    # status.status_str == "error" on failure — confirm against
                    # the ComfyUI version deployed on the pod.
                    run_status = entry.get("status") or {}
                    if run_status.get("status_str") == "error":
                        job["status"] = "failed"
                        job["error"] = "ComfyUI reported an execution error"
                        logger.error(
                            "Pod generation failed in ComfyUI: %s (%s)",
                            job_id,
                            run_status.get("messages"),
                        )
                        return
            except Exception as e:
                logger.debug("Polling pod job: %s", e)

            await asyncio.sleep(2)

    job["status"] = "failed"
    job["error"] = "Timeout waiting for generation"
    logger.error("Pod generation timed out: %s", job_id)
async def get_pod_job(job_id: str):
    """Return the tracked record of a pod generation job.

    Raises:
        HTTPException: 404 when no job with this id is tracked.
    """
    job = _pod_jobs.get(job_id)
    if job:
        return job
    raise HTTPException(404, "Job not found")
async def get_pod_job_image(job_id: str):
    """Serve the generated image for a completed pod job.

    Raises:
        HTTPException: 404 when the job is unknown, has no output path yet, or
            the recorded file no longer exists.
    """
    from fastapi.responses import FileResponse
    from pathlib import Path

    job = _pod_jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    recorded_path = job.get("output_path")
    if not recorded_path:
        raise HTTPException(404, "No image yet")

    image_file = Path(recorded_path)
    if image_file.exists():
        return FileResponse(image_file, media_type="image/png")
    raise HTTPException(404, "Image file not found")
| def _build_flux_workflow( | |
| prompt: str, | |
| negative_prompt: str, | |
| width: int, | |
| height: int, | |
| steps: int, | |
| cfg: float, | |
| seed: int, | |
| lora_name: str | None, | |
| lora_strength: float, | |
| model_type: str = "flux2", | |
| ) -> dict: | |
| """Build a ComfyUI workflow for FLUX generation. | |
| FLUX.2 Dev uses separate model components (not a single checkpoint): | |
| - UNETLoader for the diffusion model | |
| - CLIPLoader (type=flux2) for the Mistral text encoder | |
| - VAELoader for the autoencoder | |
| """ | |
| if model_type == "flux2": | |
| unet_name = "flux2-dev.safetensors" | |
| clip_type = "flux2" | |
| clip_name = "mistral_3_small_flux2_fp8.safetensors" | |
| else: | |
| unet_name = "flux1-dev.safetensors" | |
| clip_type = "flux" | |
| clip_name = "t5xxl_fp16.safetensors" | |
| # Model node ID references | |
| model_out = ["1", 0] # UNETLoader -> MODEL | |
| clip_out = ["2", 0] # CLIPLoader -> CLIP | |
| vae_out = ["3", 0] # VAELoader -> VAE | |
| workflow = { | |
| # Load diffusion model (UNet) | |
| "1": { | |
| "class_type": "UNETLoader", | |
| "inputs": { | |
| "unet_name": unet_name, | |
| "weight_dtype": "fp8_e4m3fn", | |
| }, | |
| }, | |
| # Load text encoder | |
| "2": { | |
| "class_type": "CLIPLoader", | |
| "inputs": { | |
| "clip_name": clip_name, | |
| "type": clip_type, | |
| }, | |
| }, | |
| # Load VAE | |
| "3": { | |
| "class_type": "VAELoader", | |
| "inputs": {"vae_name": "ae.safetensors"}, | |
| }, | |
| # Positive prompt | |
| "6": { | |
| "class_type": "CLIPTextEncode", | |
| "inputs": { | |
| "text": prompt, | |
| "clip": clip_out, | |
| }, | |
| }, | |
| # Negative prompt | |
| "7": { | |
| "class_type": "CLIPTextEncode", | |
| "inputs": { | |
| "text": negative_prompt or "", | |
| "clip": clip_out, | |
| }, | |
| }, | |
| # Empty latent | |
| "5": { | |
| "class_type": "EmptyLatentImage", | |
| "inputs": { | |
| "width": width, | |
| "height": height, | |
| "batch_size": 1, | |
| }, | |
| }, | |
| # Sampler | |
| "10": { | |
| "class_type": "KSampler", | |
| "inputs": { | |
| "seed": seed, | |
| "steps": steps, | |
| "cfg": cfg, | |
| "sampler_name": "euler", | |
| "scheduler": "simple", | |
| "denoise": 1.0, | |
| "model": model_out, | |
| "positive": ["6", 0], | |
| "negative": ["7", 0], | |
| "latent_image": ["5", 0], | |
| }, | |
| }, | |
| # Decode | |
| "8": { | |
| "class_type": "VAEDecode", | |
| "inputs": { | |
| "samples": ["10", 0], | |
| "vae": vae_out, | |
| }, | |
| }, | |
| # Save | |
| "9": { | |
| "class_type": "SaveImage", | |
| "inputs": { | |
| "filename_prefix": "flux_pod", | |
| "images": ["8", 0], | |
| }, | |
| }, | |
| } | |
| # Add LoRA if specified | |
| if lora_name: | |
| workflow["20"] = { | |
| "class_type": "LoraLoader", | |
| "inputs": { | |
| "lora_name": lora_name, | |
| "strength_model": lora_strength, | |
| "strength_clip": lora_strength, | |
| "model": model_out, | |
| "clip": clip_out, | |
| }, | |
| } | |
| # Rewire sampler and text encoders to use LoRA output | |
| workflow["10"]["inputs"]["model"] = ["20", 0] | |
| workflow["6"]["inputs"]["clip"] = ["20", 1] | |
| workflow["7"]["inputs"]["clip"] = ["20", 1] | |
| return workflow | |
def _build_wan_t2i_workflow(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    steps: int,
    cfg: float,
    seed: int,
    lora_name: str | None,
    lora_strength: float,
    lora_name_2: str | None = None,
    lora_strength_2: float = 0.85,
) -> dict:
    """Build a ComfyUI workflow for WAN 2.2 Remix -- dual-DiT MoE split-step.

    Based on the WAN 2.2 Remix workflow from CivitAI:
    - Two UNETLoaders: high-noise + low-noise Remix models (fp8)
    - wanBlockSwap on both (offloads blocks to CPU for 24GB GPUs)
    - ModelSamplingSD3 with shift=5 on both
    - Dual KSamplerAdvanced: high-noise runs first half, low-noise finishes
    - CLIPLoader (type=wan) + CLIPTextEncode for T5 text encoding
    - Standard VAELoader + VAEDecode
    - EmptyHunyuanLatentVideo for latent (1 frame = image, 81+ = video)

    Args:
        prompt: Positive prompt text (node "6").
        negative_prompt: Negative prompt text (node "7"); a falsy value is
            sent as an empty string.
        width: Latent width passed to EmptyHunyuanLatentVideo.
        height: Latent height passed to EmptyHunyuanLatentVideo.
        steps: Total sampler steps; the high-noise sampler covers steps
            ``[0, steps // 2)`` and the low-noise sampler the remainder.
        cfg: CFG value given to both samplers.
        seed: Noise seed given to both samplers.
        lora_name: Optional first LoRA, applied to both DiT branches
            (nodes "20" high-noise / "21" low-noise).
        lora_strength: ``strength_model`` for the first LoRA.
        lora_name_2: Optional second LoRA (nodes "22" / "23"). It is chained
            after the first LoRA when both are given, and attached directly
            to the ModelSamplingSD3 outputs when it is the only one.
        lora_strength_2: ``strength_model`` for the second LoRA.

    Returns:
        The workflow as a dict mapping node-id strings to node definitions
        (``class_type`` + ``inputs``).
    """
    high_dit = "wan22_remix_t2v_high_fp8.safetensors"
    low_dit = "wan22_remix_t2v_low_fp8.safetensors"
    t5_name = "umt5_xxl_fp8_e4m3fn_scaled.safetensors"
    vae_name = "wan_2.1_vae.safetensors"
    total_steps = steps
    # High-noise does the first half, low-noise does the rest.
    # NOTE(review): for steps == 1 this is 0, so the high-noise sampler gets
    # start_at_step == end_at_step == 0 -- confirm that is acceptable.
    split_step = total_steps // 2
    shift = 5.0
    block_swap = 20  # blocks offloaded to CPU (0-40, higher = less VRAM)

    workflow = {
        # -- Load high-noise DiT --
        "1": {
            "class_type": "UNETLoader",
            "inputs": {
                "unet_name": high_dit,
                "weight_dtype": "fp8_e4m3fn",
            },
        },
        # -- Load low-noise DiT --
        "2": {
            "class_type": "UNETLoader",
            "inputs": {
                "unet_name": low_dit,
                "weight_dtype": "fp8_e4m3fn",
            },
        },
        # -- wanBlockSwap on high-noise (VRAM optimization) --
        "11": {
            "class_type": "wanBlockSwap",
            "inputs": {
                "model": ["1", 0],
                "blocks_to_swap": block_swap,
                "offload_img_emb": False,
                "offload_txt_emb": False,
            },
        },
        # -- wanBlockSwap on low-noise --
        "12": {
            "class_type": "wanBlockSwap",
            "inputs": {
                "model": ["2", 0],
                "blocks_to_swap": block_swap,
                "offload_img_emb": False,
                "offload_txt_emb": False,
            },
        },
        # -- ModelSamplingSD3 shift on high-noise --
        "13": {
            "class_type": "ModelSamplingSD3",
            "inputs": {
                "model": ["11", 0],
                "shift": shift,
            },
        },
        # -- ModelSamplingSD3 shift on low-noise --
        "14": {
            "class_type": "ModelSamplingSD3",
            "inputs": {
                "model": ["12", 0],
                "shift": shift,
            },
        },
        # -- Load T5 text encoder --
        "3": {
            "class_type": "CLIPLoader",
            "inputs": {
                "clip_name": t5_name,
                "type": "wan",
            },
        },
        # -- Positive prompt --
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": prompt,
                "clip": ["3", 0],
            },
        },
        # -- Negative prompt --
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": negative_prompt or "",
                "clip": ["3", 0],
            },
        },
        # -- VAE --
        "4": {
            "class_type": "VAELoader",
            "inputs": {"vae_name": vae_name},
        },
        # -- Empty latent (1 frame = single image) --
        "5": {
            "class_type": "EmptyHunyuanLatentVideo",
            "inputs": {
                "width": width,
                "height": height,
                "length": 1,
                "batch_size": 1,
            },
        },
        # -- KSamplerAdvanced #1: high-noise model (first half of steps) --
        "15": {
            "class_type": "KSamplerAdvanced",
            "inputs": {
                "model": ["13", 0],
                "positive": ["6", 0],
                "negative": ["7", 0],
                "latent_image": ["5", 0],
                "add_noise": "enable",
                "noise_seed": seed,
                "steps": total_steps,
                "cfg": cfg,
                "sampler_name": "euler",
                "scheduler": "simple",
                "start_at_step": 0,
                "end_at_step": split_step,
                "return_with_leftover_noise": "enable",
            },
        },
        # -- KSamplerAdvanced #2: low-noise model (second half of steps) --
        "16": {
            "class_type": "KSamplerAdvanced",
            "inputs": {
                "model": ["14", 0],
                "positive": ["6", 0],
                "negative": ["7", 0],
                "latent_image": ["15", 0],
                "add_noise": "disable",
                "noise_seed": seed,
                "steps": total_steps,
                "cfg": cfg,
                "sampler_name": "euler",
                "scheduler": "simple",
                "start_at_step": split_step,
                "end_at_step": 10000,
                "return_with_leftover_noise": "disable",
            },
        },
        # -- VAE Decode --
        "8": {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": ["16", 0],
                "vae": ["4", 0],
            },
        },
        # -- Save Image --
        "9": {
            "class_type": "SaveImage",
            "inputs": {
                "filename_prefix": "wan_remix_pod",
                "images": ["8", 0],
            },
        },
    }

    # Add LoRA(s) to both models if specified -- chained:
    # DiT -> LoRA1 -> LoRA2 -> Sampler.
    # Each stage is (lora file, model strength, high-noise id, low-noise id).
    # A stage with no file name is skipped, so a lone second LoRA is attached
    # directly to the DiT outputs instead of being silently dropped.
    lora_stages = (
        (lora_name, lora_strength, "20", "21"),
        (lora_name_2, lora_strength_2, "22", "23"),
    )
    high_model, high_clip = ["13", 0], ["3", 0]
    low_model, low_clip = ["14", 0], ["3", 0]
    any_lora = False
    for stage_name, stage_strength, high_id, low_id in lora_stages:
        if not stage_name:
            continue
        workflow[high_id] = _wan_lora_loader_node(
            stage_name, stage_strength, high_model, high_clip
        )
        workflow[low_id] = _wan_lora_loader_node(
            stage_name, stage_strength, low_model, low_clip
        )
        # The next stage (or the samplers) reads from this stage's outputs.
        high_model, high_clip = [high_id, 0], [high_id, 1]
        low_model, low_clip = [low_id, 0], [low_id, 1]
        any_lora = True

    if any_lora:
        # Rewire samplers and CLIP encoding to the end of the LoRA chain.
        # Text encoding uses the CLIP output of the high-noise chain only.
        workflow["15"]["inputs"]["model"] = high_model
        workflow["16"]["inputs"]["model"] = low_model
        workflow["6"]["inputs"]["clip"] = high_clip
        workflow["7"]["inputs"]["clip"] = high_clip
    return workflow


def _wan_lora_loader_node(
    lora_file: str,
    model_strength: float,
    model_ref: list,
    clip_ref: list,
) -> dict:
    """Return one LoraLoader node reading from the given model/CLIP outputs.

    ``strength_clip`` is fixed at 1.0 for every WAN LoRA node; only the model
    strength is caller-controlled.
    """
    return {
        "class_type": "LoraLoader",
        "inputs": {
            "lora_name": lora_file,
            "strength_model": model_strength,
            "strength_clip": 1.0,
            "model": model_ref,
            "clip": clip_ref,
        },
    }