Spaces:
Running
Running
| """Video generation routes — WAN 2.2 img2video on RunPod pod or WaveSpeed cloud.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import base64 | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| from pathlib import Path | |
| import runpod | |
| from fastapi import APIRouter, File, Form, HTTPException, UploadFile | |
| from pydantic import BaseModel | |
# Module-level logger and router for all /api/video endpoints.
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/video", tags=["video"])
# Video jobs tracking: in-memory, keyed by short job id.
# NOTE(review): not persisted — jobs are lost on process restart.
_video_jobs: dict[str, dict] = {}
# Cloud providers (initialized from main.py via init_wavespeed / init_higgsfield)
_wavespeed_provider = None
_higgsfield_provider = None
def init_wavespeed(provider):
    """Register the WaveSpeed provider used for cloud video generation.

    Called once from main.py at startup; stored in a module-level global.
    """
    global _wavespeed_provider
    _wavespeed_provider = provider
def init_higgsfield(provider):
    """Register the Higgsfield provider (Kling 3.0, Sora 2, etc.).

    Called once from main.py at startup; stored in a module-level global.
    """
    global _higgsfield_provider
    _higgsfield_provider = provider
| # Pod state is shared from routes_pod | |
def _get_pod_state():
    """Return the pod-state dict shared with routes_pod.

    The import is deferred to call time to avoid a circular import between
    the route modules.
    """
    from content_engine.api.routes_pod import _pod_state
    return _pod_state
def _get_comfyui_url():
    """Return the ComfyUI base URL resolved by routes_pod.

    Deferred import for the same circular-import reason as _get_pod_state.
    """
    from content_engine.api.routes_pod import _get_comfyui_url as _gcurl
    return _gcurl()
class VideoGenerateRequest(BaseModel):
    """Request schema for video generation.

    NOTE(review): the endpoints in this module take multipart Form fields
    rather than a JSON body; this model looks unused here — confirm against
    other callers before removing.
    """
    prompt: str
    negative_prompt: str = ""
    num_frames: int = 81  # ~3 seconds at 24fps
    fps: int = 24
    seed: int = -1  # -1 means "pick a random seed"
async def generate_video(
    image: UploadFile = File(...),
    prompt: str = Form(...),
    negative_prompt: str = Form(""),
    num_frames: int = Form(81),
    fps: int = Form(24),
    seed: int = Form(-1),
):
    """Generate a video from an image using WAN 2.2 I2V on the RunPod pod.

    Uploads the image to the pod's ComfyUI instance, submits an I2V workflow,
    records the job in _video_jobs, and polls for completion in a background
    task.

    Raises:
        HTTPException: 400 if the pod is not running, 500 on upload/submit
            failures.
    """
    import httpx
    import random

    pod_state = _get_pod_state()
    if pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running - start it first in Status page")
    job_id = str(uuid.uuid4())[:8]
    # -1 requests a random seed
    seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
    # Read the image
    image_bytes = await image.read()
    try:
        comfyui_url = _get_comfyui_url()
        async with httpx.AsyncClient(timeout=30) as client:
            # First upload the image to ComfyUI so the workflow can reference
            # it by filename
            upload_url = f"{comfyui_url}/upload/image"
            files = {"image": (f"input_{job_id}.png", image_bytes, "image/png")}
            upload_resp = await client.post(upload_url, files=files)
            if upload_resp.status_code != 200:
                raise HTTPException(500, "Failed to upload image to pod")
            upload_data = upload_resp.json()
            uploaded_filename = upload_data.get("name", f"input_{job_id}.png")
            # Build the workflow with the uploaded filename. (A previous
            # version also built a base64 variant first; that result was
            # always discarded, so the dead call was removed.)
            workflow = _build_wan_i2v_workflow(
                uploaded_filename=uploaded_filename,
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_frames=num_frames,
                fps=fps,
                seed=seed,
            )
            # Submit workflow
            url = f"{comfyui_url}/prompt"
            resp = await client.post(url, json={"prompt": workflow})
            resp.raise_for_status()
            data = resp.json()
            prompt_id = data["prompt_id"]
            _video_jobs[job_id] = {
                "prompt_id": prompt_id,
                "status": "running",
                "seed": seed,
                "started_at": time.time(),
                "num_frames": num_frames,
                "fps": fps,
            }
            logger.info("Video generation started: %s -> %s", job_id, prompt_id)
            # Start background task to poll for completion
            asyncio.create_task(_poll_video_job(job_id, prompt_id))
            return {
                "job_id": job_id,
                "status": "running",
                "seed": seed,
                "estimated_time": f"~{num_frames * 2} seconds",
            }
    except httpx.HTTPError as e:
        logger.error("Video generation failed: %s", e)
        raise HTTPException(500, f"Generation failed: {e}")
async def generate_video_cloud(
    image: UploadFile = File(...),
    prompt: str = Form("smooth motion, high quality video"),
    negative_prompt: str = Form(""),
    model: str = Form("wan-2.6-i2v"),
    num_frames: int = Form(81),
    fps: int = Form(24),
    seed: int = Form(-1),
    backend: str = Form("wavespeed"),  # wavespeed or higgsfield
):
    """Generate a video using cloud API (WaveSpeed or Higgsfield).

    Kling 3.x models (or backend="higgsfield") are delegated to
    generate_video_higgsfield; everything else is dispatched to WaveSpeed
    via a background task tracked in _video_jobs.

    Raises:
        HTTPException: 500 if the WaveSpeed provider was never initialized.
    """
    import random

    logger.info("Video cloud generation request: model=%s, backend=%s, frames=%d", model, backend, num_frames)
    # Route to Higgsfield for Kling 3.0 models
    if backend == "higgsfield" or model.startswith("kling-3"):
        logger.info("Routing to Higgsfield for model: %s", model)
        return await generate_video_higgsfield(
            image=image,
            prompt=prompt,
            model=model,
            duration=max(3, num_frames // 24),  # Convert frames to seconds
            seed=seed,
        )
    if not _wavespeed_provider:
        logger.error("WaveSpeed provider not configured!")
        raise HTTPException(500, "WaveSpeed API not configured")
    job_id = str(uuid.uuid4())[:8]
    seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
    # Read the image. The raw bytes are handed to the background task;
    # a dead base64-encoding step was removed here.
    image_bytes = await image.read()
    # Create job entry
    _video_jobs[job_id] = {
        "status": "running",
        "seed": seed,
        "started_at": time.time(),
        "num_frames": num_frames,
        "fps": fps,
        "model": model,
        "backend": "cloud",
    }
    logger.info("Cloud video generation started: %s (model=%s)", job_id, model)
    # Start background task for cloud video generation
    asyncio.create_task(_generate_cloud_video(job_id, image_bytes, prompt, negative_prompt, model, seed))
    return {
        "job_id": job_id,
        "status": "running",
        "seed": seed,
        "model": model,
        "estimated_time": "~30-120 seconds",
    }
async def generate_video_higgsfield(
    image: UploadFile = File(...),
    prompt: str = Form("smooth cinematic motion"),
    model: str = Form("kling-3.0"),
    duration: int = Form(5),
    resolution: str = Form("720p"),
    enable_audio: bool = Form(False),
    seed: int = Form(-1),
):
    """Generate a video using Higgsfield (Kling 3.0, Sora 2, Veo 3.1)."""
    import random

    if not _higgsfield_provider:
        raise HTTPException(500, "Higgsfield API not configured - set HIGGSFIELD_API_KEY")

    if seed < 0:
        seed = random.randint(0, 2**32 - 1)
    job_id = str(uuid.uuid4())[:8]
    image_bytes = await image.read()

    # Register the job before kicking off the background work so status
    # queries succeed immediately.
    _video_jobs[job_id] = {
        "status": "running",
        "seed": seed,
        "started_at": time.time(),
        "duration": duration,
        "model": model,
        "backend": "higgsfield",
        "message": "Starting Higgsfield video generation...",
    }
    logger.info("Higgsfield video generation started: %s (model=%s)", job_id, model)

    # Run the actual generation without blocking the response.
    asyncio.create_task(_generate_higgsfield_video(
        job_id, image_bytes, prompt, model, duration, resolution, enable_audio, seed
    ))

    return {
        "job_id": job_id,
        "status": "running",
        "seed": seed,
        "model": model,
        "backend": "higgsfield",
        "estimated_time": f"~{duration * 10}-{duration * 20} seconds",
    }
async def generate_kling_motion(
    image: UploadFile = File(...),
    driving_video: UploadFile = File(...),
    prompt: str = Form("smooth motion, high quality video"),
    duration: int = Form(5),
    seed: int = Form(-1),
    character_orientation: str = Form("image"),
):
    """Generate video using Kling Motion Control (character image + driving video)."""
    import random

    if not _wavespeed_provider:
        raise HTTPException(500, "WaveSpeed API not configured")

    job_id = str(uuid.uuid4())[:8]
    if seed < 0:
        seed = random.randint(0, 2**32 - 1)

    # Pull both uploads fully into memory before returning the response.
    image_bytes = await image.read()
    video_bytes = await driving_video.read()

    _video_jobs[job_id] = {
        "status": "running",
        "seed": seed,
        "started_at": time.time(),
        "model": "kling-motion",
        "backend": "cloud",
        "message": "Uploading files...",
    }
    # Heavy lifting happens in the background task.
    asyncio.create_task(_generate_kling_motion_video(job_id, image_bytes, video_bytes, prompt, duration, character_orientation))
    return {"job_id": job_id, "status": "running", "estimated_time": "~60-120 seconds"}
async def _generate_kling_motion_video(
    job_id: str,
    image_bytes: bytes,
    video_bytes: bytes,
    prompt: str,
    duration: int,
    character_orientation: str = "image",
):
    """Background task: upload image + driving video, call Kling Motion Control API.

    Progress, completion, and errors are all recorded in _video_jobs[job_id];
    the finished MP4 is written under the configured output directory.
    (Unused local httpx/aiohttp imports were removed — all HTTP goes through
    the provider's shared client.)

    Args:
        job_id: Key of the tracked job in _video_jobs.
        image_bytes: Raw character reference image.
        video_bytes: Raw driving (motion source) video.
        prompt: Text prompt forwarded to the API.
        duration: Requested clip duration in seconds.
        character_orientation: Forwarded to the API ("image" by default).
    """
    try:
        # Resize image if too large (Kling Motion limit)
        from PIL import Image
        import io
        img = Image.open(io.BytesIO(image_bytes))
        max_size = 1280
        if img.width > max_size or img.height > max_size:
            img.thumbnail((max_size, max_size), Image.LANCZOS)
            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=90)
            image_bytes = buf.getvalue()
            logger.info("Resized character image to %dx%d", img.width, img.height)
        _video_jobs[job_id]["message"] = "Uploading character image..."
        image_url = await _wavespeed_provider._upload_temp_image(image_bytes)
        logger.info("Kling motion: character image uploaded: %s", image_url[:80])
        _video_jobs[job_id]["message"] = "Uploading driving video..."
        video_url = await _upload_temp_video(video_bytes)
        logger.info("Kling motion: driving video uploaded: %s", video_url[:80])
        api_key = _wavespeed_provider._api_key
        endpoint = "https://api.wavespeed.ai/api/v3/kwaivgi/kling-v2.6-pro/motion-control"
        payload = {
            "image": image_url,
            "video": video_url,
            "prompt": prompt,
            "duration": duration,
            "character_orientation": character_orientation,
            "enable_sync_mode": False,
        }
        _video_jobs[job_id]["message"] = "Calling Kling Motion Control API..."
        logger.info("Calling Kling Motion Control: %s", endpoint)
        # Reuse the provider's existing httpx client (avoids SSL reconnect issues)
        http = _wavespeed_provider._http_client
        resp = await http.post(
            endpoint,
            json=payload,
            headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
        )
        if resp.status_code != 200:
            error_text = resp.text[:500]
            logger.error("Kling Motion API error: %s", error_text)
            _video_jobs[job_id]["status"] = "failed"
            _video_jobs[job_id]["error"] = f"API error: {error_text[:200]}"
            return
        result = resp.json()
        # Some responses wrap the payload in a "data" envelope.
        data = result.get("data", result)
        logger.info("Kling Motion API response: %s", str(result)[:300])
        # Poll if async (synchronous responses carry "outputs" immediately)
        outputs = data.get("outputs", [])
        urls_data = data.get("urls", {})
        if not outputs and urls_data.get("get"):
            _video_jobs[job_id]["message"] = "Waiting for Kling Motion to complete..."
            video_url_out = await _poll_wavespeed_video(urls_data["get"], api_key, job_id, max_attempts=300, interval=5.0)
        elif outputs:
            video_url_out = outputs[0] if isinstance(outputs[0], str) else outputs[0].get("url")
        else:
            _video_jobs[job_id]["status"] = "failed"
            _video_jobs[job_id]["error"] = "No output URL in response"
            return
        if not video_url_out:
            return  # poll already set status
        # Download and save
        _video_jobs[job_id]["message"] = "Downloading video..."
        from content_engine.config import settings
        output_dir = settings.paths.output_dir / "videos"
        output_dir.mkdir(parents=True, exist_ok=True)
        filename = f"kling_motion_{job_id}.mp4"
        output_path = output_dir / filename
        r = await http.get(video_url_out)
        output_path.write_bytes(r.content)
        _video_jobs[job_id]["status"] = "completed"
        _video_jobs[job_id]["filename"] = filename
        _video_jobs[job_id]["message"] = "Done"
        logger.info("Kling Motion video saved: %s", filename)
    except Exception as e:
        logger.exception("Kling Motion generation failed: %s", e)
        _video_jobs[job_id]["status"] = "failed"
        _video_jobs[job_id]["error"] = str(e)
async def _upload_temp_video(video_bytes: bytes) -> str:
    """Upload a video file to litterbox.catbox.moe and return the URL."""
    import aiohttp

    # Litterbox expects a multipart form with a 1-hour expiry.
    form = aiohttp.FormData()
    form.add_field("reqtype", "fileupload")
    form.add_field("time", "1h")
    form.add_field("fileToUpload", video_bytes, filename="driving.mp4", content_type="video/mp4")
    async with aiohttp.ClientSession() as session:
        async with session.post("https://litterbox.catbox.moe/resources/internals/api.php", data=form) as resp:
            if resp.status == 200:
                uploaded_url = (await resp.text()).strip()
                if uploaded_url.startswith("http"):
                    return uploaded_url
    raise RuntimeError("Failed to upload driving video to litterbox.catbox.moe")
async def _generate_higgsfield_video(
    job_id: str,
    image_bytes: bytes,
    prompt: str,
    model: str,
    duration: int,
    resolution: str,
    enable_audio: bool,
    seed: int,
):
    """Background task to generate video via Higgsfield API.

    Uploads the source image (falling back to an inline base64 data URL),
    requests generation, downloads the result, and records progress and
    outcome in _video_jobs[job_id].

    Note: ``seed`` is accepted for interface consistency with the caller but
    is not forwarded to the provider call below — TODO confirm intent.
    """
    try:
        _video_jobs[job_id]["message"] = "Uploading image to Higgsfield..."
        # Upload image to temp URL when the provider supports it
        image_url = await _higgsfield_provider._upload_temp_image(image_bytes) if hasattr(_higgsfield_provider, '_upload_temp_image') else None
        if not image_url:
            # Fall back to base64 data URL (uses the module-level base64
            # import; the redundant local import was removed)
            image_b64 = base64.b64encode(image_bytes).decode("utf-8")
            image_url = f"data:image/png;base64,{image_b64}"
        _video_jobs[job_id]["message"] = f"Generating video with {model}..."
        # Generate video
        result = await _higgsfield_provider.generate_video(
            prompt=prompt,
            model=model,
            duration=duration,
            resolution=resolution,
            enable_audio=enable_audio,
            image_url=image_url,
        )
        video_url = result.get("video_url")
        if not video_url:
            _video_jobs[job_id]["status"] = "failed"
            _video_jobs[job_id]["error"] = "No video URL in response"
            return
        # Download the video
        _video_jobs[job_id]["message"] = "Downloading generated video..."
        import httpx
        async with httpx.AsyncClient(timeout=120) as client:
            video_resp = await client.get(video_url)
            if video_resp.status_code != 200:
                _video_jobs[job_id]["status"] = "failed"
                _video_jobs[job_id]["error"] = "Failed to download video"
                return
        # Save to local output directory
        from content_engine.config import settings
        output_dir = settings.paths.output_dir / "videos"
        output_dir.mkdir(parents=True, exist_ok=True)
        # Keep the source extension when the URL reveals it
        ext = ".mp4"
        if video_url.endswith(".webm"):
            ext = ".webm"
        local_path = output_dir / f"video_{job_id}{ext}"
        local_path.write_bytes(video_resp.content)
        _video_jobs[job_id]["status"] = "completed"
        _video_jobs[job_id]["output_path"] = str(local_path)
        _video_jobs[job_id]["completed_at"] = time.time()
        _video_jobs[job_id]["filename"] = local_path.name
        _video_jobs[job_id]["message"] = "Video generated successfully!"
        logger.info("Higgsfield video saved: %s", local_path)
    except Exception as e:
        logger.error("Higgsfield video generation failed: %s", e)
        _video_jobs[job_id]["status"] = "failed"
        _video_jobs[job_id]["error"] = str(e)
async def _poll_wavespeed_video(poll_url: str, api_key: str, job_id: str, max_attempts: int = 120, interval: float = 3.0) -> str | None:
    """Poll the WaveSpeed async video job URL until outputs are ready.

    Returns the first output URL when available, or None on failure.

    Side effects: writes progress messages into _video_jobs[job_id], and
    sets its status/error on job failure (timeout leaves status untouched —
    callers treat a None return as failure).
    """
    import httpx
    async with httpx.AsyncClient(timeout=60) as client:
        for attempt in range(max_attempts):
            try:
                resp = await client.get(
                    poll_url,
                    headers={"Authorization": f"Bearer {api_key}"},
                )
                resp.raise_for_status()
                result = resp.json()
                # Some responses wrap the payload in a "data" envelope.
                data = result.get("data", result)
                status = data.get("status", "")
                if status == "failed":
                    error_msg = data.get("error", "Unknown error")
                    logger.error("WaveSpeed video job failed: %s", error_msg)
                    _video_jobs[job_id]["status"] = "failed"
                    _video_jobs[job_id]["error"] = error_msg
                    return None
                outputs = data.get("outputs", [])
                if outputs:
                    logger.info("WaveSpeed video job completed after %d polls", attempt + 1)
                    return outputs[0]
                # Also check for 'output' field (alternate response shape)
                if "output" in data:
                    out = data["output"]
                    if isinstance(out, list) and out:
                        return out[0]
                    elif isinstance(out, str):
                        return out
                # Update job status with progress
                _video_jobs[job_id]["message"] = f"Generating video... (poll {attempt + 1}/{max_attempts})"
                logger.debug("WaveSpeed video job pending (attempt %d/%d)", attempt + 1, max_attempts)
                await asyncio.sleep(interval)
            except Exception as e:
                # Transient network errors: log and keep polling until max_attempts
                logger.warning("Video poll request failed: %s", e)
                await asyncio.sleep(interval)
    logger.error("WaveSpeed video job timed out after %d attempts", max_attempts)
    return None
async def _generate_cloud_video(
    job_id: str,
    image_bytes: bytes,
    prompt: str,
    negative_prompt: str,
    model: str,
    seed: int,
):
    """Background task to generate video via WaveSpeed cloud API.

    Uploads the image to temporary hosting, calls the resolved WaveSpeed
    model endpoint, polls when the API answers asynchronously, then downloads
    the result and records everything in _video_jobs[job_id].
    (The unused local aiohttp import was removed.)

    Note: ``seed`` is accepted for interface consistency but is not included
    in the API payload below — TODO confirm intent.
    """
    import httpx
    logger.info("Starting cloud video generation: job=%s, model=%s, image_size=%d bytes", job_id, model, len(image_bytes))
    _video_jobs[job_id]["message"] = "Uploading image..."
    try:
        # Upload image to temporary hosting (WaveSpeed needs URL)
        logger.info("Uploading image to temp host...")
        image_url = await _wavespeed_provider._upload_temp_image(image_bytes)
        logger.info("Image uploaded: %s", image_url[:80] if image_url else "FAILED")
        # Resolve model to WaveSpeed model ID
        from content_engine.services.cloud_providers.wavespeed_provider import VIDEO_MODEL_MAP
        wavespeed_model = VIDEO_MODEL_MAP.get(model, VIDEO_MODEL_MAP.get("default", "alibaba/wan-2.6-i2v-720p"))
        # Call WaveSpeed video API
        api_key = _wavespeed_provider._api_key
        endpoint = f"https://api.wavespeed.ai/api/v3/{wavespeed_model}"
        payload = {
            "image": image_url,
            "prompt": prompt,
            "enable_sync_mode": True,
        }
        if negative_prompt:
            payload["negative_prompt"] = negative_prompt
        # Grok Imagine Video uses duration (6 or 10s) instead of frame counts
        if model == "grok-imagine-i2v":
            num_frames = _video_jobs[job_id].get("num_frames", 81)
            payload["duration"] = 10 if num_frames > 150 else 6
        _video_jobs[job_id]["message"] = f"Calling WaveSpeed API ({wavespeed_model})..."
        logger.info("Calling WaveSpeed video API: %s", endpoint)
        async with httpx.AsyncClient(timeout=300) as client:
            resp = await client.post(
                endpoint,
                json=payload,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json",
                },
            )
            if resp.status_code != 200:
                error_text = resp.text[:500]
                logger.error("WaveSpeed video API error: %s", error_text)
                _video_jobs[job_id]["status"] = "failed"
                _video_jobs[job_id]["error"] = f"API error: {error_text[:200]}"
                return
            result = resp.json()
            logger.info("WaveSpeed video API response: %s", str(result)[:500])
            data = result.get("data", result)
            # Check for failed status
            if data.get("status") == "failed":
                error_msg = data.get("error", "Unknown error")
                _video_jobs[job_id]["status"] = "failed"
                _video_jobs[job_id]["error"] = error_msg
                return
            # Extract video URL - handle async response (outputs empty, urls.get present)
            video_url = None
            outputs = data.get("outputs", [])
            urls_data = data.get("urls", {})
            # Check for async response first
            if not outputs and urls_data and urls_data.get("get"):
                poll_url = urls_data["get"]
                logger.info("WaveSpeed video returned async job, polling: %s", poll_url[:80])
                _video_jobs[job_id]["message"] = "Polling for video result..."
                video_url = await _poll_wavespeed_video(poll_url, api_key, job_id)
            elif outputs:
                video_url = outputs[0]
            elif "output" in data:
                out = data["output"]
                if isinstance(out, list) and out:
                    video_url = out[0]
                elif isinstance(out, str):
                    video_url = out
            if not video_url:
                _video_jobs[job_id]["status"] = "failed"
                _video_jobs[job_id]["error"] = f"No video URL in response: {data}"
                return
            # Download the video
            logger.info("Downloading cloud video: %s", video_url[:80])
            video_resp = await client.get(video_url)
            if video_resp.status_code != 200:
                _video_jobs[job_id]["status"] = "failed"
                _video_jobs[job_id]["error"] = "Failed to download video"
                return
            # Save to local output directory
            from content_engine.config import settings
            output_dir = settings.paths.output_dir / "videos"
            output_dir.mkdir(parents=True, exist_ok=True)
            # Determine extension from URL or default to mp4
            ext = ".mp4"
            if video_url.endswith(".webm"):
                ext = ".webm"
            elif video_url.endswith(".webp"):
                ext = ".webp"
            local_path = output_dir / f"video_{job_id}{ext}"
            local_path.write_bytes(video_resp.content)
            _video_jobs[job_id]["status"] = "completed"
            _video_jobs[job_id]["output_path"] = str(local_path)
            _video_jobs[job_id]["completed_at"] = time.time()
            _video_jobs[job_id]["filename"] = local_path.name
            logger.info("Cloud video saved: %s", local_path)
    except Exception as e:
        logger.error("Cloud video generation failed: %s", e)
        _video_jobs[job_id]["status"] = "failed"
        _video_jobs[job_id]["error"] = str(e)
async def _poll_video_job(job_id: str, prompt_id: str):
    """Poll ComfyUI for video job completion.

    Checks /history/<prompt_id> every 3 seconds; once the history entry
    contains a video output node, downloads it via _download_video. On
    timeout the job is marked failed.
    """
    import httpx
    start = time.time()
    timeout = 1800  # 30 minutes for video (WAN 2.2 needs time to load 14B model first run)
    comfyui_url = _get_comfyui_url()
    # Bug fix: pod_state was previously referenced below without ever being
    # defined, raising NameError exactly when a finished job was found.
    pod_state = _get_pod_state()
    async with httpx.AsyncClient(timeout=60) as client:
        while time.time() - start < timeout:
            try:
                url = f"{comfyui_url}/history/{prompt_id}"
                resp = await client.get(url)
                if resp.status_code == 200:
                    data = resp.json()
                    if prompt_id in data:
                        outputs = data[prompt_id].get("outputs", {})
                        # Find video output (SaveAnimatedWEBP or VHS_VideoCombine)
                        for node_id, node_output in outputs.items():
                            # Check for gifs/videos
                            if "gifs" in node_output:
                                video_info = node_output["gifs"][0]
                                await _download_video(client, job_id, video_info, pod_state)
                                return
                            # Check for images (animated)
                            if "images" in node_output:
                                img_info = node_output["images"][0]
                                if img_info.get("type") == "output":
                                    await _download_video(client, job_id, img_info, pod_state)
                                    return
            except Exception as e:
                logger.debug("Polling video job: %s", e)
            await asyncio.sleep(3)
    _video_jobs[job_id]["status"] = "failed"
    _video_jobs[job_id]["error"] = "Timeout waiting for video generation"
    logger.error("Video generation timed out: %s", job_id)
async def _download_video(client, job_id: str, video_info: dict, pod_state: dict):
    """Fetch a finished video from ComfyUI's /view endpoint and store it locally.

    Note: ``pod_state`` is accepted for signature compatibility but not read.
    """
    filename = video_info.get("filename")
    params = {"filename": filename, "type": video_info.get("type", "output")}
    subfolder = video_info.get("subfolder", "")
    if subfolder:
        params["subfolder"] = subfolder

    # Download video
    response = await client.get(f"{_get_comfyui_url()}/view", params=params)
    if response.status_code != 200:
        _video_jobs[job_id]["status"] = "failed"
        _video_jobs[job_id]["error"] = "Failed to download video"
        return

    # Save to local output directory, keeping the original extension.
    from content_engine.config import settings
    target_dir = settings.paths.output_dir / "videos"
    target_dir.mkdir(parents=True, exist_ok=True)
    extension = Path(filename).suffix or ".webp"
    local_path = target_dir / f"video_{job_id}{extension}"
    local_path.write_bytes(response.content)

    job = _video_jobs[job_id]
    job["status"] = "completed"
    job["output_path"] = str(local_path)
    job["completed_at"] = time.time()
    job["filename"] = local_path.name
    logger.info("Video saved: %s", local_path)
async def list_video_jobs():
    """Return the status dicts of every tracked video generation job."""
    return [*_video_jobs.values()]
async def get_video_job(job_id: str):
    """Return the status dict for one video job, or 404 if it is unknown."""
    job = _video_jobs.get(job_id)
    if job:
        return job
    raise HTTPException(404, "Job not found")
async def get_video_file(filename: str):
    """Serve a generated video file from the videos output directory.

    Security: ``filename`` comes from the request, so the resolved path is
    required to stay inside the videos directory — traversal attempts like
    "../../secret" now get a 404 instead of reading arbitrary files.
    """
    from fastapi.responses import FileResponse
    from content_engine.config import settings
    videos_dir = (settings.paths.output_dir / "videos").resolve()
    video_path = (videos_dir / filename).resolve()
    if not video_path.is_relative_to(videos_dir):
        raise HTTPException(404, "Video not found")
    if not video_path.exists():
        raise HTTPException(404, "Video not found")
    # Pick a media type from the extension; non-video outputs are animated webp
    if filename.endswith(".webm"):
        media_type = "video/webm"
    elif filename.endswith(".mp4"):
        media_type = "video/mp4"
    else:
        media_type = "image/webp"
    return FileResponse(video_path, media_type=media_type)
async def generate_video_animate(
    image: UploadFile = File(...),
    driving_video: UploadFile = File(...),
    prompt: str = Form("a person dancing, smooth motion, high quality"),
    negative_prompt: str = Form(""),
    width: int = Form(832),
    height: int = Form(480),
    num_frames: int = Form(81),
    fps: int = Form(16),
    seed: int = Form(-1),
    steps: int = Form(20),
    cfg: float = Form(6.0),
    bg_mode: str = Form("keep"),  # keep | driving_video | auto
):
    """Generate a dance animation via WAN 2.2 Animate on RunPod ComfyUI pod.

    Requires on the pod:
      - models/diffusion_models/Wan2_2-Animate-14B_fp8_e4m3fn_scaled_KJ.safetensors
      - models/vae/wan_2.1_vae.safetensors
      - models/clip_vision/clip_vision_h.safetensors
      - models/text_encoders/umt5-xxl-enc-bf16.safetensors
      - Custom nodes: ComfyUI-WanVideoWrapper, ComfyUI-VideoHelperSuite, comfyui_controlnet_aux
    """
    import httpx
    import random
    # The pod must already be up; this endpoint does not start it.
    pod_state = _get_pod_state()
    if pod_state["status"] != "running":
        raise HTTPException(400, "Pod not running — start it first in Status page")
    job_id = str(uuid.uuid4())[:8]
    # -1 requests a random seed
    seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
    image_bytes = await image.read()
    video_bytes = await driving_video.read()
    try:
        base_url = _get_comfyui_url()
        async with httpx.AsyncClient(timeout=60) as client:
            # Upload character reference image
            img_resp = await client.post(
                f"{base_url}/upload/image",
                files={"image": (f"ref_{job_id}.png", image_bytes, "image/png")},
            )
            if img_resp.status_code != 200:
                raise HTTPException(500, f"Failed to upload character image: {img_resp.text[:200]}")
            img_filename = img_resp.json().get("name", f"ref_{job_id}.png")
            logger.info("Uploaded character image: %s", img_filename)
            # Upload driving video (same /upload/image endpoint; the file is
            # posted with its original extension but a video/mp4 content type)
            vid_ext = "mp4"
            if driving_video.filename and "." in driving_video.filename:
                vid_ext = driving_video.filename.rsplit(".", 1)[-1].lower()
            vid_resp = await client.post(
                f"{base_url}/upload/image",
                files={"image": (f"drive_{job_id}.{vid_ext}", video_bytes, "video/mp4")},
            )
            if vid_resp.status_code != 200:
                raise HTTPException(500, f"Failed to upload driving video: {vid_resp.text[:200]}")
            vid_filename = vid_resp.json().get("name", f"drive_{job_id}.{vid_ext}")
            logger.info("Uploaded driving video: %s", vid_filename)
            workflow = _build_wan_animate_workflow(
                ref_image_filename=img_filename,
                driving_video_filename=vid_filename,
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_frames=num_frames,
                fps=fps,
                seed=seed,
                steps=steps,
                cfg=cfg,
                bg_mode=bg_mode,
            )
            resp = await client.post(f"{base_url}/prompt", json={"prompt": workflow})
            if resp.status_code != 200:
                # Log ComfyUI's verbose rejection body before raising
                logger.error("ComfyUI /prompt rejected workflow: %s", resp.text[:2000])
            resp.raise_for_status()
            prompt_id = resp.json()["prompt_id"]
            _video_jobs[job_id] = {
                "prompt_id": prompt_id,
                "status": "running",
                "seed": seed,
                "started_at": time.time(),
                "num_frames": num_frames,
                "fps": fps,
                "mode": "animate",
                "message": "WAN 2.2 Animate submitted...",
            }
            logger.info("WAN Animate job started: %s -> %s", job_id, prompt_id)
            # Completion is tracked by the shared ComfyUI history poller
            asyncio.create_task(_poll_video_job(job_id, prompt_id))
            return {
                "job_id": job_id,
                "status": "running",
                "seed": seed,
                "estimated_time": f"~{num_frames * 3} seconds",
            }
    except httpx.HTTPError as e:
        logger.error("WAN Animate generation failed: %s", e)
        raise HTTPException(500, f"Generation failed: {e}")
def _build_wan_animate_workflow(
    ref_image_filename: str,
    driving_video_filename: str,
    prompt: str = "a person dancing, smooth motion",
    negative_prompt: str = "",
    width: int = 832,
    height: int = 480,
    num_frames: int = 81,
    fps: int = 16,
    seed: int = 42,
    steps: int = 20,
    cfg: float = 6.0,
    bg_mode: str = "auto",
) -> dict:
    """Build ComfyUI API workflow for WAN 2.2 Animate (motion transfer from driving video).

    Pipeline:
        reference image -> CLIP encode + resize
        driving video -> DWPreprocessor (pose skeleton)
        both -> WanVideoAnimateEmbeds -> WanVideoSampler -> decode -> MP4

    Args:
        ref_image_filename: Character image filename, already present in the
            ComfyUI server's input folder (consumed by the LoadImage node).
        driving_video_filename: Driving/motion video filename, already present
            on the ComfyUI server (consumed by VHS_LoadVideo).
        prompt: Positive text prompt for the T5 text encoder.
        negative_prompt: Negative prompt; empty string falls back to a default
            quality-guard prompt (see `neg` below).
        width / height: Target output resolution; the reference image is
            resized/padded to this size (divisible by 16).
        num_frames: Frames to generate. 0 means "match driving video length":
            a GetImageSizeAndCount node is added and its frame-count output
            is wired into the embeds node instead of a literal.
        fps: Frame rate used both to resample the driving video on load and
            for the final MP4.
        seed: Diffusion sampler seed (fixed default 42 for reproducibility).
        steps: Sampler step count.
        cfg: Classifier-free guidance scale.
        bg_mode: Background handling —
            "keep"          -> use the resized reference image as background
            "driving_video" -> use the raw driving-video frames as background
            "auto"          -> omit bg_images entirely; model decides

    Returns:
        dict mapping node-id strings to node specs in ComfyUI API ("prompt")
        format, ready to POST to the server's /prompt endpoint.
    """
    # Fallback negative prompt guards against common failure modes.
    neg = negative_prompt or "blurry, static, low quality, watermark, text"
    workflow = {
        # VAE
        "1": {
            "class_type": "WanVideoVAELoader",
            "inputs": {
                "model_name": "wan_2.1_vae.safetensors",
                "precision": "bf16",
            },
        },
        # CLIP Vision
        "2": {
            "class_type": "CLIPVisionLoader",
            "inputs": {"clip_name": "clip_vision_h.safetensors"},
        },
        # Diffusion model
        # NOTE(review): fp8 quantization + offload_device chosen to fit the
        # 14B model in pod VRAM — confirm against the pod's GPU before changing.
        "3": {
            "class_type": "WanVideoModelLoader",
            "inputs": {
                "model": "wan2.2_animate_14B_bf16.safetensors",
                "base_precision": "bf16",
                "quantization": "fp8_e4m3fn",
                "load_device": "offload_device",
                "attention_mode": "sdpa",
            },
        },
        # Load T5 text encoder
        "4": {
            "class_type": "LoadWanVideoT5TextEncoder",
            "inputs": {
                "model_name": "umt5-xxl-enc-fp8_e4m3fn.safetensors",
                "precision": "bf16",
            },
        },
        # Encode text prompts
        "16": {
            "class_type": "WanVideoTextEncode",
            "inputs": {
                "positive_prompt": prompt,
                "negative_prompt": neg,
                "t5": ["4", 0],
                "force_offload": True,
            },
        },
        # Load reference character image
        "5": {
            "class_type": "LoadImage",
            "inputs": {"image": ref_image_filename},
        },
        # Resize to target resolution (pad rather than crop to preserve the
        # full character; dimensions forced divisible by 16 for the VAE)
        "6": {
            "class_type": "ImageResizeKJv2",
            "inputs": {
                "image": ["5", 0],
                "width": width,
                "height": height,
                "upscale_method": "lanczos",
                "keep_proportion": "pad_edge_pixel",
                "pad_color": "0, 0, 0",
                "crop_position": "top",
                "divisible_by": 16,
            },
        },
        # CLIP Vision encode reference
        "7": {
            "class_type": "WanVideoClipVisionEncode",
            "inputs": {
                "clip_vision": ["2", 0],
                "image_1": ["6", 0],
                "strength_1": 1.0,
                "strength_2": 1.0,
                "crop": "center",
                "combine_embeds": "average",
                "force_offload": True,
            },
        },
        # Load driving video (dance moves); force_rate resamples to `fps`,
        # frame_load_cap 0 means "load every frame"
        "8": {
            "class_type": "VHS_LoadVideo",
            "inputs": {
                "video": driving_video_filename,
                "force_rate": fps,
                "custom_width": 0,
                "custom_height": 0,
                "frame_load_cap": num_frames if num_frames > 0 else 0,
                "skip_first_frames": 0,
                "select_every_nth": 1,
            },
        },
        # Extract pose skeleton from driving video (body only — hands/face
        # detection disabled here)
        "9": {
            "class_type": "DWPreprocessor",
            "inputs": {
                "image": ["8", 0],
                "detect_hand": "disable",
                "detect_body": "enable",
                "detect_face": "disable",
                "resolution": max(width, height),
                "bbox_detector": "yolox_l.torchscript.pt",
                "pose_estimator": "dw-ll_ucoco_384_bs5.torchscript.pt",
                "scale_stick_for_xinsr_cn": "disable",
            },
        },
        # Animate embeddings: combine ref image + pose + optional background
        "10": {
            "class_type": "WanVideoAnimateEmbeds",
            "inputs": {
                "vae": ["1", 0],
                "clip_embeds": ["7", 0],
                "ref_images": ["6", 0],
                "pose_images": ["9", 0],
                # bg_mode: "keep" = ref image bg, "driving_video" = video frames bg, "auto" = model decides
                **({} if bg_mode == "auto" else {
                    "bg_images": ["6", 0] if bg_mode == "keep" else ["8", 0],
                }),
                "width": width,
                "height": height,
                # When num_frames==0 ("Match video"), link to GetImageSizeAndCount output slot 3
                "num_frames": ["15", 3] if num_frames == 0 else num_frames,
                "force_offload": True,
                "frame_window_size": 77,
                "colormatch": "disabled",
                "pose_strength": 1.0,
                "face_strength": 1.0,
            },
        },
        # Diffusion sampler (no context_options — WanAnim handles looping internally)
        "12": {
            "class_type": "WanVideoSampler",
            "inputs": {
                "model": ["3", 0],
                "image_embeds": ["10", 0],
                "text_embeds": ["16", 0],
                "steps": steps,
                "cfg": cfg,
                "shift": 5.0,
                "seed": seed,
                "force_offload": True,
                "scheduler": "dpm++_sde",
                "riflex_freq_index": 0,
                "denoise_strength": 1.0,
            },
        },
        # Decode latents to frames (tiled VAE decode to bound VRAM usage)
        "13": {
            "class_type": "WanVideoDecode",
            "inputs": {
                "vae": ["1", 0],
                "samples": ["12", 0],
                "enable_vae_tiling": True,
                "tile_x": 272,
                "tile_y": 272,
                "tile_stride_x": 144,
                "tile_stride_y": 128,
            },
        },
        # Combine frames into MP4 (h264, crf 19, saved to ComfyUI output dir)
        "14": {
            "class_type": "VHS_VideoCombine",
            "inputs": {
                "images": ["13", 0],
                "frame_rate": fps,
                "loop_count": 0,
                "filename_prefix": "WanAnimate",
                "format": "video/h264-mp4",
                "pix_fmt": "yuv420p",
                "crf": 19,
                "save_metadata": True,
                "trim_to_audio": False,
                "pingpong": False,
                "save_output": True,
            },
        },
    }
    # "Match video" mode (num_frames=0): detect actual frame count from posed video
    # GetImageSizeAndCount outputs: (IMAGE, width, height, count) — slot 3 = frame count
    if num_frames == 0:
        workflow["15"] = {
            "class_type": "GetImageSizeAndCount",
            "inputs": {"image": ["9", 0]},
        }
    return workflow
| def _build_wan_i2v_workflow( | |
| uploaded_filename: str = None, | |
| image_b64: str = None, | |
| prompt: str = "", | |
| negative_prompt: str = "", | |
| num_frames: int = 81, | |
| fps: int = 24, | |
| seed: int = -1, | |
| ) -> dict: | |
| """Build ComfyUI workflow for WAN 2.2 Image-to-Video.""" | |
| # WAN 2.2 I2V workflow | |
| # This assumes the WAN 2.2 nodes are installed on the pod | |
| workflow = { | |
| # Load the input image | |
| "1": { | |
| "class_type": "LoadImage", | |
| "inputs": { | |
| "image": uploaded_filename or "input.png", | |
| }, | |
| }, | |
| # WAN 2.2 model loader | |
| "2": { | |
| "class_type": "DownloadAndLoadWanModel", | |
| "inputs": { | |
| "model": "Wan2.2-I2V-14B-480P", | |
| }, | |
| }, | |
| # Text encoder | |
| "3": { | |
| "class_type": "WanTextEncode", | |
| "inputs": { | |
| "prompt": prompt, | |
| "negative_prompt": negative_prompt, | |
| "wan_model": ["2", 0], | |
| }, | |
| }, | |
| # Image-to-Video generation | |
| "4": { | |
| "class_type": "WanImageToVideo", | |
| "inputs": { | |
| "image": ["1", 0], | |
| "wan_model": ["2", 0], | |
| "conditioning": ["3", 0], | |
| "num_frames": num_frames, | |
| "seed": seed, | |
| "steps": 30, | |
| "cfg": 5.0, | |
| }, | |
| }, | |
| # Decode to frames | |
| "5": { | |
| "class_type": "WanDecode", | |
| "inputs": { | |
| "samples": ["4", 0], | |
| "wan_model": ["2", 0], | |
| }, | |
| }, | |
| # Save as animated WEBP | |
| "6": { | |
| "class_type": "SaveAnimatedWEBP", | |
| "inputs": { | |
| "images": ["5", 0], | |
| "filename_prefix": "wan_video", | |
| "fps": fps, | |
| "lossless": False, | |
| "quality": 85, | |
| }, | |
| }, | |
| } | |
| return workflow | |