Spaces:
Running
Running
| """RunPod Pod-based generation provider. | |
| Spins up a GPU pod with ComfyUI + FLUX.2 on demand, generates images, | |
| then optionally shuts down. Simpler than serverless (no custom Docker needed). | |
| The pod uses a pre-built ComfyUI image with FLUX.2 support. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import time | |
| from typing import Any | |
| import httpx | |
| import runpod | |
| from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider | |
| logger = logging.getLogger(__name__) | |
# Pre-built ComfyUI template with FLUX support.
# NOTE(review): COMFYUI_TEMPLATE is not referenced anywhere in this module —
# possibly used by importers elsewhere; confirm before removing.
COMFYUI_TEMPLATE = "runpod/comfyui:flux"  # RunPod's official ComfyUI + FLUX image
# Image actually passed to runpod.create_pod (ai-dock ComfyUI, CUDA 12.1 base).
DOCKER_IMAGE = "ghcr.io/ai-dock/comfyui:v2-cuda-12.1.1-base"
# Default GPU for FLUX.2 (needs 24GB VRAM)
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
# ComfyUI API port (the pod's private port; RunPod maps it to a public port)
COMFYUI_PORT = 8188
class RunPodPodProvider(CloudProvider):
    """Generate images using an on-demand RunPod pod running ComfyUI.

    Lifecycle: ``submit_generation`` lazily starts (or reuses) a GPU pod,
    posts a FLUX workflow to the pod's ComfyUI HTTP API, and an idle
    watchdog task stops the pod after ``auto_shutdown_minutes`` of
    inactivity. Status polling counts as activity so a pod is never
    stopped underneath a job that is still being waited on.
    """

    def __init__(self, api_key: str, auto_shutdown_minutes: int = 10):
        """
        Args:
            api_key: RunPod API key. Also installed on the ``runpod``
                SDK module, which uses a module-global key.
            auto_shutdown_minutes: Idle minutes before the pod is
                automatically stopped.
        """
        self._api_key = api_key
        runpod.api_key = api_key  # the runpod SDK reads a module-global key
        self._auto_shutdown_minutes = auto_shutdown_minutes
        self._pod_id: str | None = None
        self._pod_ip: str | None = None
        self._pod_port: int | None = None
        self._last_activity: float = 0
        # job_id -> submit timestamp, used to report real generation time
        self._job_started_at: dict[str, float] = {}
        self._http = httpx.AsyncClient(timeout=120)
        self._shutdown_task: asyncio.Task | None = None

    def name(self) -> str:
        """Provider identifier."""
        return "runpod-pod"

    @staticmethod
    def _extract_endpoint(pod: dict[str, Any]) -> tuple[str, int] | None:
        """Return (ip, public_port) for the ComfyUI port mapping, if published.

        RunPod reports published ports under ``runtime.ports`` as dicts with
        ``privatePort``/``publicPort``/``ip`` keys; the mapping may be absent
        while the pod is still starting.
        """
        runtime = pod.get("runtime") or {}
        for mapping in runtime.get("ports") or []:
            if mapping.get("privatePort") == COMFYUI_PORT:
                ip = mapping.get("ip")
                port = mapping.get("publicPort")
                if ip and port:
                    return ip, int(port)
        return None

    async def _ensure_pod_running(self) -> tuple[str, int]:
        """Ensure a ComfyUI pod is running. Returns (ip, port).

        Reuses the existing pod when possible; otherwise creates a new one
        and waits until both the pod and ComfyUI itself are responsive.

        Raises:
            TimeoutError: if the pod or ComfyUI never becomes ready.
        """
        self._last_activity = time.time()

        # Reuse the existing pod if it is still running.
        if self._pod_id:
            pod = None
            try:
                pod = await asyncio.to_thread(runpod.get_pod, self._pod_id)
            except Exception as e:
                logger.warning("Failed to check pod status: %s", e)
            if pod and pod.get("desiredStatus") == "RUNNING":
                endpoint = self._extract_endpoint(pod)
                if endpoint is None:
                    # Bug fix: the pod is RUNNING but its port mapping is not
                    # published yet. Wait for it instead of falling through to
                    # create_pod, which leaked (and kept billing) this pod.
                    endpoint = await self._wait_for_pod_ready()
                self._pod_ip, self._pod_port = endpoint
                return endpoint
            self._pod_id = None

        # Create a new pod.
        logger.info("Starting ComfyUI pod with FLUX.2...")
        pod = await asyncio.to_thread(
            runpod.create_pod,
            name="content-engine-comfyui",
            image_name=DOCKER_IMAGE,
            gpu_type_id=DEFAULT_GPU,
            volume_in_gb=50,
            container_disk_in_gb=20,
            ports=f"{COMFYUI_PORT}/http",
            env={
                "PROVISIONING_SCRIPT": "https://raw.githubusercontent.com/ai-dock/comfyui/main/config/provisioning/flux.sh",
            },
        )
        self._pod_id = pod["id"]
        logger.info("Pod created: %s", self._pod_id)

        # Wait for the pod itself, then for the ComfyUI server inside it
        # (first boot runs the provisioning script and downloads models).
        ip, port = await self._wait_for_pod_ready()
        self._pod_ip = ip
        self._pod_port = port
        await self._wait_for_comfyui(ip, port)

        self._schedule_shutdown()
        return ip, port

    async def _wait_for_pod_ready(self, timeout: int = 300) -> tuple[str, int]:
        """Poll until the pod is RUNNING with a published ComfyUI port.

        Raises:
            TimeoutError: if not ready within ``timeout`` seconds.
        """
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                pod = await asyncio.to_thread(runpod.get_pod, self._pod_id)
                if pod.get("desiredStatus") == "RUNNING":
                    endpoint = self._extract_endpoint(pod)
                    if endpoint:
                        logger.info("Pod ready at %s:%s", endpoint[0], endpoint[1])
                        return endpoint
            except Exception as e:
                logger.debug("Waiting for pod: %s", e)
            await asyncio.sleep(5)
        raise TimeoutError(f"Pod did not become ready within {timeout}s")

    async def _wait_for_comfyui(self, ip: str, port: int, timeout: int = 300):
        """Poll ComfyUI's /system_stats endpoint until it answers 200.

        Raises:
            TimeoutError: if ComfyUI is not responsive within ``timeout``
                seconds (model provisioning can take several minutes).
        """
        deadline = time.time() + timeout
        url = f"http://{ip}:{port}/system_stats"
        while time.time() < deadline:
            try:
                resp = await self._http.get(url)
                if resp.status_code == 200:
                    logger.info("ComfyUI is ready!")
                    return
            except Exception:
                pass  # connection refused while the server boots — keep polling
            # Bug fix: log before sleeping so the very first wait is not silent.
            logger.info("Waiting for ComfyUI to start...")
            await asyncio.sleep(5)
        raise TimeoutError("ComfyUI did not become ready")

    def _schedule_shutdown(self):
        """(Re)start the idle watchdog that stops the pod after inactivity."""
        if self._shutdown_task:
            self._shutdown_task.cancel()

        async def shutdown_if_idle():
            # Check every minute; stop the pod once idle long enough.
            while True:
                await asyncio.sleep(60)
                idle_time = time.time() - self._last_activity
                if idle_time > self._auto_shutdown_minutes * 60:
                    logger.info("Auto-shutting down idle pod...")
                    await self.shutdown_pod()
                    break

        self._shutdown_task = asyncio.create_task(shutdown_if_idle())

    async def shutdown_pod(self):
        """Stop the pod (best-effort) and forget its endpoint.

        Uses ``stop_pod`` (not terminate), so the pod's volume persists;
        state is cleared even if the stop call fails.
        """
        if self._pod_id:
            try:
                await asyncio.to_thread(runpod.stop_pod, self._pod_id)
                logger.info("Pod stopped: %s", self._pod_id)
            except Exception as e:
                logger.warning("Failed to stop pod: %s", e)
            self._pod_id = None
            self._pod_ip = None
            self._pod_port = None

    async def submit_generation(
        self,
        *,
        positive_prompt: str,
        negative_prompt: str = "",
        checkpoint: str = "flux1-dev.safetensors",
        lora_name: str | None = None,
        lora_strength: float = 0.85,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 3.5,
        width: int = 1024,
        height: int = 1024,
    ) -> str:
        """Submit a generation job to ComfyUI on the pod.

        Args mirror a standard FLUX txt2img setup; ``seed=-1`` means pick
        a random seed.

        Returns:
            The ComfyUI prompt id, usable with ``check_status``/``get_result``.
        """
        ip, port = await self._ensure_pod_running()
        self._last_activity = time.time()

        workflow = self._build_flux_workflow(
            prompt=positive_prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            cfg=cfg,
            seed=seed,
            lora_name=lora_name,
            lora_strength=lora_strength,
            # Bug fix: the checkpoint argument was previously ignored and the
            # workflow always hardcoded flux1-dev.safetensors.
            checkpoint=checkpoint,
        )

        url = f"http://{ip}:{port}/prompt"
        resp = await self._http.post(url, json={"prompt": workflow})
        resp.raise_for_status()
        prompt_id = resp.json()["prompt_id"]
        # Remember when the job started so get_result can report real timing.
        self._job_started_at[prompt_id] = time.time()
        logger.info("ComfyUI job submitted: %s", prompt_id)
        return prompt_id

    def _build_flux_workflow(
        self,
        prompt: str,
        negative_prompt: str,
        width: int,
        height: int,
        steps: int,
        cfg: float,
        seed: int,
        lora_name: str | None,
        lora_strength: float,
        checkpoint: str = "flux1-dev.safetensors",
    ) -> dict:
        """Build a ComfyUI API-format workflow graph for FLUX generation.

        Node ids are arbitrary strings; node inputs reference other nodes
        as ``[node_id, output_index]`` pairs. When ``lora_name`` is given,
        a LoraLoader node is inserted between the checkpoint and the
        sampler/text encoders.
        """
        import random

        if seed < 0:
            # ComfyUI needs a concrete seed; -1 means "choose randomly".
            seed = random.randint(0, 2**32 - 1)

        workflow = {
            "3": {
                "class_type": "CheckpointLoaderSimple",
                "inputs": {"ckpt_name": checkpoint},
            },
            "6": {
                "class_type": "CLIPTextEncode",
                "inputs": {
                    "text": prompt,
                    "clip": ["3", 1],
                },
            },
            "7": {
                "class_type": "CLIPTextEncode",
                "inputs": {
                    "text": negative_prompt or "",
                    "clip": ["3", 1],
                },
            },
            "5": {
                "class_type": "EmptyLatentImage",
                "inputs": {
                    "width": width,
                    "height": height,
                    "batch_size": 1,
                },
            },
            "10": {
                "class_type": "KSampler",
                "inputs": {
                    "seed": seed,
                    "steps": steps,
                    "cfg": cfg,
                    "sampler_name": "euler",
                    "scheduler": "simple",
                    "denoise": 1.0,
                    "model": ["3", 0],
                    "positive": ["6", 0],
                    "negative": ["7", 0],
                    "latent_image": ["5", 0],
                },
            },
            "8": {
                "class_type": "VAEDecode",
                "inputs": {
                    "samples": ["10", 0],
                    "vae": ["3", 2],
                },
            },
            "9": {
                "class_type": "SaveImage",
                "inputs": {
                    "filename_prefix": "flux_gen",
                    "images": ["8", 0],
                },
            },
        }

        if lora_name:
            workflow["4"] = {
                "class_type": "LoraLoader",
                "inputs": {
                    "lora_name": lora_name,
                    "strength_model": lora_strength,
                    "strength_clip": lora_strength,
                    "model": ["3", 0],
                    "clip": ["3", 1],
                },
            }
            # Rewire the sampler and text encoders to the LoRA outputs.
            workflow["10"]["inputs"]["model"] = ["4", 0]
            workflow["6"]["inputs"]["clip"] = ["4", 1]
            workflow["7"]["inputs"]["clip"] = ["4", 1]

        return workflow

    async def check_status(self, job_id: str) -> str:
        """Return 'pending' | 'running' | 'completed' | 'failed' for a job.

        Network errors are treated as transient ('running') so callers
        keep polling instead of aborting a live job.
        """
        if not self._pod_ip or not self._pod_port:
            return "failed"
        # Bug fix: polling counts as activity, otherwise the idle watchdog
        # could stop the pod while a long generation is still in flight.
        self._last_activity = time.time()
        try:
            url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}"
            resp = await self._http.get(url)
            if resp.status_code == 200:
                data = resp.json()
                if job_id in data:
                    entry = data[job_id]
                    # A history entry with outputs means the job finished.
                    if entry.get("outputs"):
                        return "completed"
                    status = entry.get("status", {})
                    if status.get("completed"):
                        return "completed"
                    if status.get("status_str") == "error":
                        return "failed"
                    return "running"
            # Not yet in history: still queued.
            return "pending"
        except Exception as e:
            logger.error("Status check failed: %s", e)
            return "running"

    async def get_result(self, job_id: str) -> CloudGenerationResult:
        """Download the generated image for a completed job.

        Raises:
            RuntimeError: if the pod is gone or the job produced no image.
        """
        if not self._pod_ip or not self._pod_port:
            raise RuntimeError("Pod not running")

        # Look up the job's history entry to find the saved image filename.
        url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}"
        resp = await self._http.get(url)
        resp.raise_for_status()
        outputs = resp.json().get(job_id, {}).get("outputs", {})

        # Find the first node output that contains images (the SaveImage node).
        for node_output in outputs.values():
            if "images" not in node_output:
                continue
            image_info = node_output["images"][0]
            params = {"filename": image_info["filename"]}
            if image_info.get("subfolder"):
                params["subfolder"] = image_info["subfolder"]
            img_url = f"http://{self._pod_ip}:{self._pod_port}/view"
            img_resp = await self._http.get(img_url, params=params)
            img_resp.raise_for_status()
            # Report real elapsed time when we saw the job being submitted
            # (0 for jobs submitted by a previous provider instance).
            started = self._job_started_at.pop(job_id, None)
            elapsed = time.time() - started if started is not None else 0
            return CloudGenerationResult(
                job_id=job_id,
                image_bytes=img_resp.content,
                generation_time_seconds=elapsed,
            )
        raise RuntimeError(f"No image output found for job {job_id}")

    async def wait_for_completion(
        self,
        job_id: str,
        timeout: int = 300,
        poll_interval: float = 2.0,
    ) -> CloudGenerationResult:
        """Poll until the job completes, then fetch its result.

        Raises:
            RuntimeError: if the job fails.
            TimeoutError: if the job does not complete within ``timeout``s.
        """
        deadline = time.time() + timeout
        while time.time() < deadline:
            status = await self.check_status(job_id)
            if status == "completed":
                return await self.get_result(job_id)
            if status == "failed":
                raise RuntimeError(f"ComfyUI job {job_id} failed")
            await asyncio.sleep(poll_interval)
        raise TimeoutError(f"Job {job_id} timed out after {timeout}s")

    async def is_available(self) -> bool:
        """Cheap availability check: an API key is configured.

        NOTE(review): does not actually probe the RunPod API.
        """
        return bool(self._api_key)

    async def close(self):
        """Cancel the idle watchdog and close the HTTP client.

        Does NOT stop the pod — call ``shutdown_pod`` explicitly for that.
        """
        if self._shutdown_task:
            self._shutdown_task.cancel()
        await self._http.aclose()