"""RunPod Pod-based generation provider. Spins up a GPU pod with ComfyUI + FLUX.2 on demand, generates images, then optionally shuts down. Simpler than serverless (no custom Docker needed). The pod uses a pre-built ComfyUI image with FLUX.2 support. """ from __future__ import annotations import asyncio import logging import time from typing import Any import httpx import runpod from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider logger = logging.getLogger(__name__) # Pre-built ComfyUI template with FLUX support COMFYUI_TEMPLATE = "runpod/comfyui:flux" # RunPod's official ComfyUI + FLUX image DOCKER_IMAGE = "ghcr.io/ai-dock/comfyui:v2-cuda-12.1.1-base" # Default GPU for FLUX.2 (needs 24GB VRAM) DEFAULT_GPU = "NVIDIA GeForce RTX 4090" # ComfyUI API port COMFYUI_PORT = 8188 class RunPodPodProvider(CloudProvider): """Generate images using an on-demand RunPod pod with ComfyUI.""" def __init__(self, api_key: str, auto_shutdown_minutes: int = 10): self._api_key = api_key runpod.api_key = api_key self._auto_shutdown_minutes = auto_shutdown_minutes self._pod_id: str | None = None self._pod_ip: str | None = None self._pod_port: int | None = None self._last_activity: float = 0 self._http = httpx.AsyncClient(timeout=120) self._shutdown_task: asyncio.Task | None = None @property def name(self) -> str: return "runpod-pod" async def _ensure_pod_running(self) -> tuple[str, int]: """Ensure a ComfyUI pod is running. Returns (ip, port).""" self._last_activity = time.time() # Check if existing pod is still running if self._pod_id: try: pod = await asyncio.to_thread(runpod.get_pod, self._pod_id) if pod and pod.get("desiredStatus") == "RUNNING": runtime = pod.get("runtime", {}) ports = runtime.get("ports", []) for p in ports: if p.get("privatePort") == COMFYUI_PORT: self._pod_ip = p.get("ip") self._pod_port = p.get("publicPort") if self._pod_ip and self._pod_port: return self._pod_ip, self._pod_port except Exception as e: logger.warning("Failed to check pod status: %s", e) self._pod_id = None # Create new pod logger.info("Starting ComfyUI pod with FLUX.2...") pod = await asyncio.to_thread( runpod.create_pod, name="content-engine-comfyui", image_name=DOCKER_IMAGE, gpu_type_id=DEFAULT_GPU, volume_in_gb=50, container_disk_in_gb=20, ports=f"{COMFYUI_PORT}/http", env={ "PROVISIONING_SCRIPT": "https://raw.githubusercontent.com/ai-dock/comfyui/main/config/provisioning/flux.sh", }, ) self._pod_id = pod["id"] logger.info("Pod created: %s", self._pod_id) # Wait for pod to be ready ip, port = await self._wait_for_pod_ready() self._pod_ip = ip self._pod_port = port # Wait for ComfyUI to be responsive await self._wait_for_comfyui(ip, port) # Schedule auto-shutdown self._schedule_shutdown() return ip, port async def _wait_for_pod_ready(self, timeout: int = 300) -> tuple[str, int]: """Wait for pod to be running and return ComfyUI endpoint.""" start = time.time() while time.time() - start < timeout: try: pod = await asyncio.to_thread(runpod.get_pod, self._pod_id) if pod.get("desiredStatus") == "RUNNING": runtime = pod.get("runtime", {}) ports = runtime.get("ports", []) for p in ports: if p.get("privatePort") == COMFYUI_PORT: ip = p.get("ip") port = p.get("publicPort") if ip and port: logger.info("Pod ready at %s:%s", ip, port) return ip, int(port) except Exception as e: logger.debug("Waiting for pod: %s", e) await asyncio.sleep(5) raise TimeoutError(f"Pod did not become ready within {timeout}s") async def _wait_for_comfyui(self, ip: str, port: int, timeout: int = 300): """Wait for ComfyUI API to be responsive.""" start = time.time() url = f"http://{ip}:{port}/system_stats" while time.time() - start < timeout: try: resp = await self._http.get(url) if resp.status_code == 200: logger.info("ComfyUI is ready!") return except Exception: pass await asyncio.sleep(5) logger.info("Waiting for ComfyUI to start...") raise TimeoutError("ComfyUI did not become ready") def _schedule_shutdown(self): """Schedule auto-shutdown after idle period.""" if self._shutdown_task: self._shutdown_task.cancel() async def shutdown_if_idle(): while True: await asyncio.sleep(60) # Check every minute idle_time = time.time() - self._last_activity if idle_time > self._auto_shutdown_minutes * 60: logger.info("Auto-shutting down idle pod...") await self.shutdown_pod() break self._shutdown_task = asyncio.create_task(shutdown_if_idle()) async def shutdown_pod(self): """Manually shut down the pod.""" if self._pod_id: try: await asyncio.to_thread(runpod.stop_pod, self._pod_id) logger.info("Pod stopped: %s", self._pod_id) except Exception as e: logger.warning("Failed to stop pod: %s", e) self._pod_id = None self._pod_ip = None self._pod_port = None async def submit_generation( self, *, positive_prompt: str, negative_prompt: str = "", checkpoint: str = "flux1-dev.safetensors", lora_name: str | None = None, lora_strength: float = 0.85, seed: int = -1, steps: int = 28, cfg: float = 3.5, width: int = 1024, height: int = 1024, ) -> str: """Submit generation to ComfyUI on the pod.""" ip, port = await self._ensure_pod_running() self._last_activity = time.time() # Build ComfyUI workflow for FLUX workflow = self._build_flux_workflow( prompt=positive_prompt, negative_prompt=negative_prompt, width=width, height=height, steps=steps, cfg=cfg, seed=seed, lora_name=lora_name, lora_strength=lora_strength, ) # Submit to ComfyUI url = f"http://{ip}:{port}/prompt" resp = await self._http.post(url, json={"prompt": workflow}) resp.raise_for_status() data = resp.json() prompt_id = data["prompt_id"] logger.info("ComfyUI job submitted: %s", prompt_id) return prompt_id def _build_flux_workflow( self, prompt: str, negative_prompt: str, width: int, height: int, steps: int, cfg: float, seed: int, lora_name: str | None, lora_strength: float, ) -> dict: """Build a ComfyUI workflow for FLUX generation.""" import random if seed < 0: seed = random.randint(0, 2**32 - 1) # Basic FLUX workflow workflow = { "3": { "class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "flux1-dev.safetensors"}, }, "6": { "class_type": "CLIPTextEncode", "inputs": { "text": prompt, "clip": ["3", 1], }, }, "7": { "class_type": "CLIPTextEncode", "inputs": { "text": negative_prompt or "", "clip": ["3", 1], }, }, "5": { "class_type": "EmptyLatentImage", "inputs": { "width": width, "height": height, "batch_size": 1, }, }, "10": { "class_type": "KSampler", "inputs": { "seed": seed, "steps": steps, "cfg": cfg, "sampler_name": "euler", "scheduler": "simple", "denoise": 1.0, "model": ["3", 0], "positive": ["6", 0], "negative": ["7", 0], "latent_image": ["5", 0], }, }, "8": { "class_type": "VAEDecode", "inputs": { "samples": ["10", 0], "vae": ["3", 2], }, }, "9": { "class_type": "SaveImage", "inputs": { "filename_prefix": "flux_gen", "images": ["8", 0], }, }, } # Add LoRA if specified if lora_name: workflow["4"] = { "class_type": "LoraLoader", "inputs": { "lora_name": lora_name, "strength_model": lora_strength, "strength_clip": lora_strength, "model": ["3", 0], "clip": ["3", 1], }, } # Rewire sampler to use LoRA output workflow["10"]["inputs"]["model"] = ["4", 0] workflow["6"]["inputs"]["clip"] = ["4", 1] workflow["7"]["inputs"]["clip"] = ["4", 1] return workflow async def check_status(self, job_id: str) -> str: """Check ComfyUI job status.""" if not self._pod_ip or not self._pod_port: return "failed" try: url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}" resp = await self._http.get(url) if resp.status_code == 200: data = resp.json() if job_id in data: outputs = data[job_id].get("outputs", {}) if outputs: return "completed" status = data[job_id].get("status", {}) if status.get("completed"): return "completed" if status.get("status_str") == "error": return "failed" return "running" return "pending" except Exception as e: logger.error("Status check failed: %s", e) return "running" async def get_result(self, job_id: str) -> CloudGenerationResult: """Get the generated image from ComfyUI.""" if not self._pod_ip or not self._pod_port: raise RuntimeError("Pod not running") # Get history to find output filename url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}" resp = await self._http.get(url) resp.raise_for_status() data = resp.json() job_data = data.get(job_id, {}) outputs = job_data.get("outputs", {}) # Find the SaveImage output for node_id, node_output in outputs.items(): if "images" in node_output: image_info = node_output["images"][0] filename = image_info["filename"] subfolder = image_info.get("subfolder", "") # Download the image img_url = f"http://{self._pod_ip}:{self._pod_port}/view" params = {"filename": filename} if subfolder: params["subfolder"] = subfolder img_resp = await self._http.get(img_url, params=params) img_resp.raise_for_status() return CloudGenerationResult( job_id=job_id, image_bytes=img_resp.content, generation_time_seconds=0, # TODO: track actual time ) raise RuntimeError(f"No image output found for job {job_id}") async def wait_for_completion( self, job_id: str, timeout: int = 300, poll_interval: float = 2.0, ) -> CloudGenerationResult: """Wait for job completion.""" start = time.time() while time.time() - start < timeout: status = await self.check_status(job_id) if status == "completed": return await self.get_result(job_id) elif status == "failed": raise RuntimeError(f"ComfyUI job {job_id} failed") await asyncio.sleep(poll_interval) raise TimeoutError(f"Job {job_id} timed out after {timeout}s") async def is_available(self) -> bool: """Check if RunPod API is accessible.""" return bool(self._api_key) async def close(self): """Cleanup.""" if self._shutdown_task: self._shutdown_task.cancel() await self._http.aclose()