dippoo's picture
Initial deployment - Content Engine
ed37502
raw
history blame
13.8 kB
"""RunPod Pod-based generation provider.
Spins up a GPU pod with ComfyUI + FLUX.2 on demand, generates images,
then optionally shuts down. Simpler than serverless (no custom Docker needed).
The pod uses a pre-built ComfyUI image with FLUX.2 support.
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
import httpx
import runpod
from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider
logger = logging.getLogger(__name__)
# Pre-built ComfyUI template with FLUX support.
# NOTE(review): COMFYUI_TEMPLATE is not referenced anywhere in this module —
# pods are created from DOCKER_IMAGE below. Confirm whether it can be removed.
COMFYUI_TEMPLATE = "runpod/comfyui:flux" # RunPod's official ComfyUI + FLUX image
# Container image actually used when creating pods (ai-dock ComfyUI build;
# FLUX support is installed at boot via the PROVISIONING_SCRIPT env var).
DOCKER_IMAGE = "ghcr.io/ai-dock/comfyui:v2-cuda-12.1.1-base"
# Default GPU for FLUX.2 (needs 24GB VRAM)
DEFAULT_GPU = "NVIDIA GeForce RTX 4090"
# Port the ComfyUI HTTP API listens on inside the pod
COMFYUI_PORT = 8188
class RunPodPodProvider(CloudProvider):
    """Generate images using an on-demand RunPod pod running ComfyUI.

    Lifecycle:
      * A pod is created lazily on the first generation request and reused
        while it stays up.
      * A background watchdog stops the pod after ``auto_shutdown_minutes``
        of inactivity so a GPU is not billed while idle.
      * ``close()`` stops the watchdog, the pod, and the HTTP client.
    """

    def __init__(self, api_key: str, auto_shutdown_minutes: int = 10):
        """Create the provider.

        Args:
            api_key: RunPod API key. Also installed on the ``runpod`` SDK,
                which authenticates through a module-level global.
            auto_shutdown_minutes: Idle minutes before the pod is auto-stopped.
        """
        self._api_key = api_key
        runpod.api_key = api_key  # the SDK reads this global for auth
        self._auto_shutdown_minutes = auto_shutdown_minutes
        self._pod_id: str | None = None
        self._pod_ip: str | None = None
        self._pod_port: int | None = None
        self._last_activity: float = 0.0
        self._http = httpx.AsyncClient(timeout=120)
        self._shutdown_task: asyncio.Task | None = None

    @property
    def name(self) -> str:
        """Provider identifier used for logging/configuration."""
        return "runpod-pod"

    @staticmethod
    def _extract_endpoint(pod: dict[str, Any]) -> tuple[str, int] | None:
        """Return the public (ip, port) mapped to COMFYUI_PORT, or None if the
        pod has not exposed the HTTP port mapping yet."""
        # Guard with `or`: the API may return runtime/ports as null while the
        # pod is still provisioning.
        runtime = pod.get("runtime") or {}
        for mapping in runtime.get("ports") or []:
            if mapping.get("privatePort") == COMFYUI_PORT:
                ip = mapping.get("ip")
                port = mapping.get("publicPort")
                if ip and port:
                    return ip, int(port)
        return None

    async def _ensure_pod_running(self) -> tuple[str, int]:
        """Ensure a ComfyUI pod is running. Returns (ip, port).

        Reuses the existing pod when possible; otherwise creates a new one,
        waits for both the pod and the ComfyUI API to come up, and arms the
        idle-shutdown watchdog.
        """
        self._last_activity = time.time()

        # Reuse the existing pod if it is still running.
        if self._pod_id:
            try:
                pod = await asyncio.to_thread(runpod.get_pod, self._pod_id)
            except Exception as e:
                logger.warning("Failed to check pod status: %s", e)
                pod = None
            if pod and pod.get("desiredStatus") == "RUNNING":
                endpoint = self._extract_endpoint(pod)
                if endpoint is None:
                    # BUGFIX: the pod is running but has not exposed the HTTP
                    # port yet. Previously this fell through and created a
                    # *second* pod, orphaning the first (still billable).
                    # Wait for the existing pod's endpoint instead.
                    endpoint = await self._wait_for_pod_ready()
                self._pod_ip, self._pod_port = endpoint
                return endpoint
            # Stale, stopped, or unreachable pod: forget it and provision anew.
            self._pod_id = None

        # Create a new pod.
        logger.info("Starting ComfyUI pod with FLUX.2...")
        pod = await asyncio.to_thread(
            runpod.create_pod,
            name="content-engine-comfyui",
            image_name=DOCKER_IMAGE,
            gpu_type_id=DEFAULT_GPU,
            volume_in_gb=50,
            container_disk_in_gb=20,
            ports=f"{COMFYUI_PORT}/http",
            env={
                # ai-dock images run this script at boot to install FLUX assets.
                "PROVISIONING_SCRIPT": "https://raw.githubusercontent.com/ai-dock/comfyui/main/config/provisioning/flux.sh",
            },
        )
        self._pod_id = pod["id"]
        logger.info("Pod created: %s", self._pod_id)

        ip, port = await self._wait_for_pod_ready()
        self._pod_ip = ip
        self._pod_port = port
        # Pod RUNNING does not imply ComfyUI is up: provisioning and model
        # loading happen after boot, so poll the API separately.
        await self._wait_for_comfyui(ip, port)
        self._schedule_shutdown()
        return ip, port

    async def _wait_for_pod_ready(self, timeout: int = 300) -> tuple[str, int]:
        """Poll until the pod is RUNNING and exposes the ComfyUI port.

        Returns:
            The public (ip, port) of the ComfyUI endpoint.
        Raises:
            TimeoutError: if the pod is not ready within ``timeout`` seconds.
        """
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                pod = await asyncio.to_thread(runpod.get_pod, self._pod_id)
                if pod.get("desiredStatus") == "RUNNING":
                    endpoint = self._extract_endpoint(pod)
                    if endpoint:
                        logger.info("Pod ready at %s:%s", *endpoint)
                        return endpoint
            except Exception as e:
                # Transient API errors while the pod spins up — keep polling.
                logger.debug("Waiting for pod: %s", e)
            await asyncio.sleep(5)
        raise TimeoutError(f"Pod did not become ready within {timeout}s")

    async def _wait_for_comfyui(self, ip: str, port: int, timeout: int = 300):
        """Poll ComfyUI's /system_stats endpoint until it responds.

        Provisioning and model download can take minutes after the pod boots.

        Raises:
            TimeoutError: if ComfyUI never answers within ``timeout`` seconds.
        """
        deadline = time.time() + timeout
        url = f"http://{ip}:{port}/system_stats"
        while time.time() < deadline:
            try:
                resp = await self._http.get(url)
                if resp.status_code == 200:
                    logger.info("ComfyUI is ready!")
                    return
            except Exception:
                # Connection refused while the server boots — expected; retry.
                pass
            logger.info("Waiting for ComfyUI to start...")
            await asyncio.sleep(5)
        raise TimeoutError("ComfyUI did not become ready")

    def _schedule_shutdown(self):
        """(Re)arm the background watchdog that stops the pod once it has been
        idle for the configured period."""
        if self._shutdown_task:
            self._shutdown_task.cancel()

        async def shutdown_if_idle():
            while True:
                await asyncio.sleep(60)  # check once a minute
                idle_time = time.time() - self._last_activity
                if idle_time > self._auto_shutdown_minutes * 60:
                    logger.info("Auto-shutting down idle pod...")
                    await self.shutdown_pod()
                    break

        self._shutdown_task = asyncio.create_task(shutdown_if_idle())

    async def shutdown_pod(self):
        """Stop the pod (idempotent; safe to call with no pod running).

        Cached endpoint state is cleared even if the stop call fails, so the
        next request provisions a fresh pod instead of talking to a dead one.
        """
        if self._pod_id:
            try:
                # stop (not terminate): NOTE(review) — this presumably keeps
                # the attached volume so restarts skip re-provisioning; confirm
                # the storage-billing tradeoff is intended.
                await asyncio.to_thread(runpod.stop_pod, self._pod_id)
                logger.info("Pod stopped: %s", self._pod_id)
            except Exception as e:
                logger.warning("Failed to stop pod: %s", e)
            self._pod_id = None
            self._pod_ip = None
            self._pod_port = None

    async def submit_generation(
        self,
        *,
        positive_prompt: str,
        negative_prompt: str = "",
        checkpoint: str = "flux1-dev.safetensors",
        lora_name: str | None = None,
        lora_strength: float = 0.85,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 3.5,
        width: int = 1024,
        height: int = 1024,
    ) -> str:
        """Submit a generation job to ComfyUI on the pod.

        Starts (or reuses) the pod, builds a FLUX workflow graph and POSTs it
        to ComfyUI's /prompt endpoint.

        Returns:
            The ComfyUI prompt id, usable with check_status()/get_result().
        """
        ip, port = await self._ensure_pod_running()
        self._last_activity = time.time()

        workflow = self._build_flux_workflow(
            prompt=positive_prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            steps=steps,
            cfg=cfg,
            seed=seed,
            lora_name=lora_name,
            lora_strength=lora_strength,
            # BUGFIX: checkpoint was accepted here but silently ignored — the
            # workflow always hard-coded flux1-dev.safetensors.
            checkpoint=checkpoint,
        )

        url = f"http://{ip}:{port}/prompt"
        resp = await self._http.post(url, json={"prompt": workflow})
        resp.raise_for_status()
        prompt_id = resp.json()["prompt_id"]
        logger.info("ComfyUI job submitted: %s", prompt_id)
        return prompt_id

    def _build_flux_workflow(
        self,
        prompt: str,
        negative_prompt: str,
        width: int,
        height: int,
        steps: int,
        cfg: float,
        seed: int,
        lora_name: str | None,
        lora_strength: float,
        checkpoint: str = "flux1-dev.safetensors",
    ) -> dict:
        """Build a ComfyUI API-format workflow graph for FLUX generation.

        Node ids are arbitrary strings; edges are ``["node_id", output_index]``
        pairs. A negative ``seed`` is replaced with a random 32-bit seed.
        When ``lora_name`` is given, a LoraLoader node is inserted between the
        checkpoint and the sampler/text encoders.
        """
        import random

        if seed < 0:
            seed = random.randint(0, 2**32 - 1)

        # Basic FLUX workflow: checkpoint -> (CLIP encodes, latent) -> sampler
        # -> VAE decode -> save.
        workflow = {
            "3": {
                "class_type": "CheckpointLoaderSimple",
                "inputs": {"ckpt_name": checkpoint},
            },
            "6": {
                "class_type": "CLIPTextEncode",
                "inputs": {
                    "text": prompt,
                    "clip": ["3", 1],
                },
            },
            "7": {
                "class_type": "CLIPTextEncode",
                "inputs": {
                    "text": negative_prompt or "",
                    "clip": ["3", 1],
                },
            },
            "5": {
                "class_type": "EmptyLatentImage",
                "inputs": {
                    "width": width,
                    "height": height,
                    "batch_size": 1,
                },
            },
            "10": {
                "class_type": "KSampler",
                "inputs": {
                    "seed": seed,
                    "steps": steps,
                    "cfg": cfg,
                    "sampler_name": "euler",
                    "scheduler": "simple",
                    "denoise": 1.0,
                    "model": ["3", 0],
                    "positive": ["6", 0],
                    "negative": ["7", 0],
                    "latent_image": ["5", 0],
                },
            },
            "8": {
                "class_type": "VAEDecode",
                "inputs": {
                    "samples": ["10", 0],
                    "vae": ["3", 2],
                },
            },
            "9": {
                "class_type": "SaveImage",
                "inputs": {
                    "filename_prefix": "flux_gen",
                    "images": ["8", 0],
                },
            },
        }

        # Add LoRA if specified, and rewire model/clip edges through it.
        if lora_name:
            workflow["4"] = {
                "class_type": "LoraLoader",
                "inputs": {
                    "lora_name": lora_name,
                    "strength_model": lora_strength,
                    "strength_clip": lora_strength,
                    "model": ["3", 0],
                    "clip": ["3", 1],
                },
            }
            workflow["10"]["inputs"]["model"] = ["4", 0]
            workflow["6"]["inputs"]["clip"] = ["4", 1]
            workflow["7"]["inputs"]["clip"] = ["4", 1]

        return workflow

    async def check_status(self, job_id: str) -> str:
        """Check ComfyUI job status.

        Returns:
            One of "pending" (not in history yet), "running", "completed",
            or "failed". Transient network errors report "running" so the
            caller's poll loop keeps retrying.
        """
        if not self._pod_ip or not self._pod_port:
            return "failed"
        try:
            url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}"
            resp = await self._http.get(url)
            if resp.status_code == 200:
                data = resp.json()
                if job_id in data:
                    entry = data[job_id]
                    # Presence of outputs is the most reliable completion sign.
                    if entry.get("outputs", {}):
                        return "completed"
                    status = entry.get("status", {})
                    if status.get("completed"):
                        return "completed"
                    if status.get("status_str") == "error":
                        return "failed"
                    return "running"
            return "pending"
        except Exception as e:
            logger.error("Status check failed: %s", e)
            return "running"

    async def get_result(self, job_id: str) -> CloudGenerationResult:
        """Download the generated image for a completed ComfyUI job.

        Raises:
            RuntimeError: if the pod is not running or no image output exists.
        """
        if not self._pod_ip or not self._pod_port:
            raise RuntimeError("Pod not running")

        # Look up the job history to find the saved output filename.
        url = f"http://{self._pod_ip}:{self._pod_port}/history/{job_id}"
        resp = await self._http.get(url)
        resp.raise_for_status()
        outputs = resp.json().get(job_id, {}).get("outputs", {})

        # Find the SaveImage node's output and fetch the file via /view.
        for node_id, node_output in outputs.items():
            if "images" in node_output:
                image_info = node_output["images"][0]
                params = {"filename": image_info["filename"]}
                if image_info.get("subfolder", ""):
                    params["subfolder"] = image_info["subfolder"]
                img_url = f"http://{self._pod_ip}:{self._pod_port}/view"
                img_resp = await self._http.get(img_url, params=params)
                img_resp.raise_for_status()
                return CloudGenerationResult(
                    job_id=job_id,
                    image_bytes=img_resp.content,
                    generation_time_seconds=0,  # TODO: track actual time
                )
        raise RuntimeError(f"No image output found for job {job_id}")

    async def wait_for_completion(
        self,
        job_id: str,
        timeout: int = 300,
        poll_interval: float = 2.0,
    ) -> CloudGenerationResult:
        """Poll until the job completes, then return its result.

        Raises:
            RuntimeError: if the job reports failure.
            TimeoutError: if the job does not finish within ``timeout``.
        """
        start = time.time()
        while time.time() - start < timeout:
            status = await self.check_status(job_id)
            if status == "completed":
                return await self.get_result(job_id)
            if status == "failed":
                raise RuntimeError(f"ComfyUI job {job_id} failed")
            await asyncio.sleep(poll_interval)
        raise TimeoutError(f"Job {job_id} timed out after {timeout}s")

    async def is_available(self) -> bool:
        """Report whether the provider is configured (an API key is present).
        Does not verify the key against the RunPod API."""
        return bool(self._api_key)

    async def close(self):
        """Release all resources held by this provider.

        BUGFIX: previously this cancelled the idle watchdog but left the pod
        running; with the watchdog gone and the instance discarded, nothing
        could ever stop the billable GPU pod. Stop it explicitly here.
        """
        if self._shutdown_task:
            self._shutdown_task.cancel()
        await self.shutdown_pod()
        await self._http.aclose()