"""RunPod serverless generation provider. Uses RunPod's serverless GPU endpoints for image generation. Requires a pre-deployed endpoint with ComfyUI or an SD model. Setup: 1. Deploy a serverless endpoint on RunPod with your model 2. Set RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env """ from __future__ import annotations import asyncio import base64 import logging import time from typing import Any import httpx import runpod from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider logger = logging.getLogger(__name__) # Default timeout for generation (seconds) GENERATION_TIMEOUT = 300 class RunPodProvider(CloudProvider): """Cloud provider using RunPod serverless endpoints for image generation.""" def __init__(self, api_key: str, endpoint_id: str): self._api_key = api_key self._endpoint_id = endpoint_id runpod.api_key = api_key self._endpoint = runpod.Endpoint(endpoint_id) self._jobs: dict[str, dict[str, Any]] = {} self._http = httpx.AsyncClient(timeout=60) @property def name(self) -> str: return "runpod" async def submit_generation( self, *, positive_prompt: str, negative_prompt: str, checkpoint: str, lora_name: str | None = None, lora_strength: float = 0.85, seed: int = -1, steps: int = 28, cfg: float = 7.0, width: int = 832, height: int = 1216, ) -> str: """Submit a generation job to RunPod serverless. Returns a job ID for tracking. """ # Build input payload for the serverless worker # This assumes a ComfyUI or SD worker that accepts these parameters payload = { "input": { "prompt": positive_prompt, "negative_prompt": negative_prompt, "checkpoint": checkpoint, "width": width, "height": height, "steps": steps, "cfg_scale": cfg, "seed": seed, } } # Add LoRA if specified if lora_name: payload["input"]["lora"] = { "name": lora_name, "strength": lora_strength, } start_time = time.time() try: # Submit async job run_request = await asyncio.to_thread( self._endpoint.run, payload["input"] ) job_id = run_request.job_id self._jobs[job_id] = { "request": run_request, "start_time": start_time, "status": "pending", } logger.info("RunPod job submitted: %s", job_id) return job_id except Exception as e: logger.error("RunPod submit failed: %s", e) raise RuntimeError(f"Failed to submit to RunPod: {e}") async def check_status(self, job_id: str) -> str: """Check job status. Returns: 'pending', 'running', 'completed', 'failed'.""" job_info = self._jobs.get(job_id) if not job_info: return "failed" try: run_request = job_info["request"] status = await asyncio.to_thread(run_request.status) # Map RunPod statuses to our standard statuses status_map = { "IN_QUEUE": "pending", "IN_PROGRESS": "running", "COMPLETED": "completed", "FAILED": "failed", "CANCELLED": "failed", "TIMED_OUT": "failed", } normalized = status_map.get(status, "running") job_info["status"] = normalized return normalized except Exception as e: logger.error("Status check failed for %s: %s", job_id, e) return "failed" async def get_result(self, job_id: str) -> CloudGenerationResult: """Download the completed generation result.""" job_info = self._jobs.get(job_id) if not job_info: raise RuntimeError(f"Job not found: {job_id}") try: run_request = job_info["request"] start_time = job_info["start_time"] # Get output (blocks until complete or timeout) output = await asyncio.to_thread(run_request.output) generation_time = time.time() - start_time # Parse output - format depends on worker implementation # Common formats: # 1. {"image_url": "data:image/png;base64,..."} # 2. {"images": ["base64..."]} # 3. {"output": [{"image": "base64..."}]} image_bytes = self._extract_image_from_output(output) # Cleanup self._jobs.pop(job_id, None) return CloudGenerationResult( job_id=job_id, image_bytes=image_bytes, generation_time_seconds=generation_time, ) except Exception as e: logger.error("Failed to get result for %s: %s", job_id, e) raise RuntimeError(f"Failed to get RunPod result: {e}") def _extract_image_from_output(self, output: Any) -> bytes: """Extract image bytes from various output formats.""" if isinstance(output, dict): # Format: {"image_url": "data:image/png;base64,..."} if "image_url" in output: return self._decode_data_url(output["image_url"]) # Format: {"image": "base64..."} if "image" in output: return base64.b64decode(output["image"]) # Format: {"images": ["base64..."]} if "images" in output and output["images"]: return base64.b64decode(output["images"][0]) # Format: {"output": {"image": "..."}} if "output" in output: return self._extract_image_from_output(output["output"]) elif isinstance(output, list) and output: # Format: [{"image_url": "..."}] return self._extract_image_from_output(output[0]) elif isinstance(output, str): # Direct base64 string or data URL if output.startswith("data:image"): return self._decode_data_url(output) return base64.b64decode(output) raise ValueError(f"Could not extract image from output: {type(output)}") def _decode_data_url(self, data_url: str) -> bytes: """Decode a data:image/xxx;base64,... URL to bytes.""" if "," in data_url: _, base64_data = data_url.split(",", 1) return base64.b64decode(base64_data) return base64.b64decode(data_url) async def is_available(self) -> bool: """Check if RunPod is configured and reachable.""" if not self._api_key or not self._endpoint_id: return False try: # Try to check endpoint health # RunPod SDK doesn't have a direct health check, so we verify the API key works runpod.api_key = self._api_key # This is a lightweight check - just verify we can make API calls return True except Exception: return False async def wait_for_completion( self, job_id: str, timeout: int = GENERATION_TIMEOUT, poll_interval: float = 2.0, ) -> CloudGenerationResult: """Wait for job completion and return result.""" start = time.time() while time.time() - start < timeout: status = await self.check_status(job_id) if status == "completed": return await self.get_result(job_id) elif status == "failed": raise RuntimeError(f"RunPod job {job_id} failed") await asyncio.sleep(poll_interval) raise TimeoutError(f"RunPod job {job_id} timed out after {timeout}s") async def close(self): """Close HTTP client.""" await self._http.aclose()