# Repository web-page residue captured with this file (not part of the module):
#   dippoo's picture
#   Initial deployment - Content Engine
#   ed37502
#   raw
#   history blame
#   8.09 kB
"""RunPod serverless generation provider.
Uses RunPod's serverless GPU endpoints for image generation.
Requires a pre-deployed endpoint with ComfyUI or an SD model.
Setup:
1. Deploy a serverless endpoint on RunPod with your model
2. Set RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env
"""
from __future__ import annotations
import asyncio
import base64
import logging
import time
from typing import Any
import httpx
import runpod
from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default timeout for generation (seconds); used as the default ``timeout``
# argument of RunPodProvider.wait_for_completion.
GENERATION_TIMEOUT = 300
class RunPodProvider(CloudProvider):
    """Cloud provider using RunPod serverless endpoints for image generation.

    Submitted jobs are tracked in an in-memory registry keyed by RunPod job
    ID. An entry is removed once its result has been fetched successfully by
    ``get_result``.
    """

    # Map RunPod statuses to our standard statuses. Any status not listed
    # here is reported as "running" by ``check_status``.
    _STATUS_MAP: dict[str, str] = {
        "IN_QUEUE": "pending",
        "IN_PROGRESS": "running",
        "COMPLETED": "completed",
        "FAILED": "failed",
        "CANCELLED": "failed",
        "TIMED_OUT": "failed",
    }

    def __init__(self, api_key: str, endpoint_id: str):
        """Create a provider bound to one RunPod serverless endpoint.

        Args:
            api_key: RunPod API key.
            endpoint_id: ID of the pre-deployed serverless endpoint.
        """
        self._api_key = api_key
        self._endpoint_id = endpoint_id
        # NOTE: this sets the SDK's module-level key, so it is process-wide.
        runpod.api_key = api_key
        self._endpoint = runpod.Endpoint(endpoint_id)
        # job_id -> {"request": SDK job handle, "start_time": float, "status": str}
        self._jobs: dict[str, dict[str, Any]] = {}
        self._http = httpx.AsyncClient(timeout=60)

    @property
    def name(self) -> str:
        """Short identifier of this provider."""
        return "runpod"

    async def submit_generation(
        self,
        *,
        positive_prompt: str,
        negative_prompt: str,
        checkpoint: str,
        lora_name: str | None = None,
        lora_strength: float = 0.85,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 7.0,
        width: int = 832,
        height: int = 1216,
    ) -> str:
        """Submit a generation job to RunPod serverless.

        Args:
            positive_prompt: Sent to the worker as ``prompt``.
            negative_prompt: Sent to the worker as ``negative_prompt``.
            checkpoint: Sent to the worker as ``checkpoint``.
            lora_name: Optional LoRA name; when falsy no ``lora`` entry is sent.
            lora_strength: Strength sent alongside ``lora_name``.
            seed: Sent to the worker as ``seed``.
            steps: Sent to the worker as ``steps``.
            cfg: Sent to the worker as ``cfg_scale``.
            width: Sent to the worker as ``width``.
            height: Sent to the worker as ``height``.

        Returns:
            The RunPod job ID, usable with ``check_status``, ``get_result``
            and ``wait_for_completion``.

        Raises:
            RuntimeError: If the submission fails for any reason; the
                original exception is attached as ``__cause__``.
        """
        # Input for the serverless worker. This assumes a ComfyUI or SD
        # worker that accepts these parameter names.
        job_input: dict[str, Any] = {
            "prompt": positive_prompt,
            "negative_prompt": negative_prompt,
            "checkpoint": checkpoint,
            "width": width,
            "height": height,
            "steps": steps,
            "cfg_scale": cfg,
            "seed": seed,
        }

        # Add LoRA if specified
        if lora_name:
            job_input["lora"] = {
                "name": lora_name,
                "strength": lora_strength,
            }

        # Monotonic clock: only used to measure elapsed time, so it must not
        # jump when the system clock is adjusted.
        start_time = time.monotonic()

        try:
            # The SDK call is synchronous; run it off the event loop.
            run_request = await asyncio.to_thread(self._endpoint.run, job_input)
            job_id = run_request.job_id
            self._jobs[job_id] = {
                "request": run_request,
                "start_time": start_time,
                "status": "pending",
            }
            logger.info("RunPod job submitted: %s", job_id)
            return job_id
        except Exception as e:
            logger.error("RunPod submit failed: %s", e)
            raise RuntimeError(f"Failed to submit to RunPod: {e}") from e

    async def check_status(self, job_id: str) -> str:
        """Check job status.

        Returns:
            One of 'pending', 'running', 'completed', 'failed'. Unknown job
            IDs and errors raised while querying the status are both
            reported as 'failed'.
        """
        job_info = self._jobs.get(job_id)
        if not job_info:
            return "failed"

        try:
            run_request = job_info["request"]
            status = await asyncio.to_thread(run_request.status)
            normalized = self._STATUS_MAP.get(status, "running")
            job_info["status"] = normalized
            return normalized
        except Exception as e:
            logger.error("Status check failed for %s: %s", job_id, e)
            return "failed"

    async def get_result(self, job_id: str) -> CloudGenerationResult:
        """Fetch the output of a job and decode it into image bytes.

        On success the job is removed from the internal registry; on failure
        it is left in place.

        Raises:
            RuntimeError: If the job ID is unknown, or if fetching or
                decoding the output fails (original exception attached as
                ``__cause__``).
        """
        job_info = self._jobs.get(job_id)
        if not job_info:
            raise RuntimeError(f"Job not found: {job_id}")

        try:
            run_request = job_info["request"]
            start_time = job_info["start_time"]

            # NOTE(review): output() is called without a timeout argument;
            # confirm against the SDK whether this waits for completion or
            # returns whatever is available. wait_for_completion only calls
            # this after the job has reported "completed".
            output = await asyncio.to_thread(run_request.output)
            generation_time = time.monotonic() - start_time

            # Output format depends on the worker implementation; see
            # _extract_image_from_output for the shapes that are recognised.
            image_bytes = self._extract_image_from_output(output)

            # Cleanup
            self._jobs.pop(job_id, None)

            return CloudGenerationResult(
                job_id=job_id,
                image_bytes=image_bytes,
                generation_time_seconds=generation_time,
            )
        except Exception as e:
            logger.error("Failed to get result for %s: %s", job_id, e)
            raise RuntimeError(f"Failed to get RunPod result: {e}") from e

    def _extract_image_from_output(self, output: Any) -> bytes:
        """Extract image bytes from various output formats.

        Recognised shapes (nested values are resolved recursively, so e.g.
        ``{"images": [{"image": "..."}]}`` also works):

        * ``{"image_url": "data:image/png;base64,..."}``
        * ``{"image": <str | nested>}``
        * ``{"images": [<str | nested>, ...]}`` (first entry is used)
        * ``{"output": <nested>}``
        * ``[<nested>, ...]`` (first entry is used)
        * a base64 string or ``data:image`` URL
        * base64-encoded ``bytes``

        Raises:
            ValueError: If no image can be located in ``output``.
        """
        if isinstance(output, dict):
            if "image_url" in output:
                return self._decode_data_url(output["image_url"])
            if "image" in output:
                # Recurse so that data URLs and nested dicts are handled too.
                return self._extract_image_from_output(output["image"])
            if output.get("images"):
                # Entries may be plain base64 strings or nested dicts.
                return self._extract_image_from_output(output["images"][0])
            if "output" in output:
                return self._extract_image_from_output(output["output"])
        elif isinstance(output, list) and output:
            return self._extract_image_from_output(output[0])
        elif isinstance(output, str):
            # Direct base64 string or data URL
            if output.startswith("data:image"):
                return self._decode_data_url(output)
            return base64.b64decode(output)
        elif isinstance(output, (bytes, bytearray)):
            # Base64 payload delivered as bytes rather than text.
            return base64.b64decode(output)

        raise ValueError(f"Could not extract image from output: {type(output)}")

    def _decode_data_url(self, data_url: str) -> bytes:
        """Decode a data:image/xxx;base64,... URL to bytes.

        A string without a comma is treated as bare base64.
        """
        _, separator, encoded = data_url.partition(",")
        return base64.b64decode(encoded if separator else data_url)

    async def is_available(self) -> bool:
        """Report whether RunPod is configured.

        Returns False when the API key or endpoint ID is missing. No network
        request is made, so True does not guarantee the endpoint is
        reachable.
        """
        if not self._api_key or not self._endpoint_id:
            return False

        # Re-apply the key on the SDK module in case it was changed elsewhere.
        runpod.api_key = self._api_key
        # TODO: probe the endpoint's health here if a real reachability
        # check is required.
        return True

    async def wait_for_completion(
        self,
        job_id: str,
        timeout: int = GENERATION_TIMEOUT,
        poll_interval: float = 2.0,
    ) -> CloudGenerationResult:
        """Wait for job completion and return result.

        Args:
            job_id: ID returned by ``submit_generation``.
            timeout: Maximum time to wait, in seconds.
            poll_interval: Delay between status checks, in seconds.

        Raises:
            RuntimeError: If the job is reported as failed, or if fetching
                the result fails.
            TimeoutError: If the job does not complete within ``timeout``.
        """
        # Monotonic deadline so system clock adjustments cannot shorten or
        # extend the wait.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            status = await self.check_status(job_id)

            if status == "completed":
                return await self.get_result(job_id)
            if status == "failed":
                raise RuntimeError(f"RunPod job {job_id} failed")

            await asyncio.sleep(poll_interval)

        raise TimeoutError(f"RunPod job {job_id} timed out after {timeout}s")

    async def close(self) -> None:
        """Close HTTP client."""
        await self._http.aclose()