Spaces:
Running
Running
| """RunPod serverless generation provider. | |
| Uses RunPod's serverless GPU endpoints for image generation. | |
| Requires a pre-deployed endpoint with ComfyUI or an SD model. | |
| Setup: | |
| 1. Deploy a serverless endpoint on RunPod with your model | |
| 2. Set RUNPOD_API_KEY and RUNPOD_ENDPOINT_ID in .env | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import base64 | |
| import logging | |
| import time | |
| from typing import Any | |
| import httpx | |
| import runpod | |
| from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider | |
| # Module-level logger, named after this module | |
| logger = logging.getLogger(__name__) | |
| # Default number of seconds wait_for_completion() keeps polling a job before raising TimeoutError | |
| GENERATION_TIMEOUT = 300 | |
class RunPodProvider(CloudProvider):
    """Cloud provider using RunPod serverless endpoints for image generation.

    Jobs are submitted through the ``runpod`` SDK and tracked in an in-memory
    registry keyed by job ID, so a job can only be polled or fetched through
    the provider instance that submitted it.
    """

    # RunPod job states mapped onto the status strings check_status() returns.
    # States missing from this table are reported as "running".
    _STATUS_MAP = {
        "IN_QUEUE": "pending",
        "IN_PROGRESS": "running",
        "COMPLETED": "completed",
        "FAILED": "failed",
        "CANCELLED": "failed",
        "TIMED_OUT": "failed",
    }

    def __init__(self, api_key: str, endpoint_id: str):
        """Bind the provider to one RunPod endpoint.

        Args:
            api_key: RunPod API key.
            endpoint_id: ID of the pre-deployed serverless endpoint.
        """
        self._api_key = api_key
        self._endpoint_id = endpoint_id
        # NOTE: this sets a module-level attribute on the SDK, so the key is
        # shared by everything in the process that uses ``runpod``.
        runpod.api_key = api_key
        self._endpoint = runpod.Endpoint(endpoint_id)
        # job_id -> {"request": SDK job handle, "start_time": monotonic
        # timestamp taken at submission, "status": last normalized status}
        self._jobs: dict[str, dict[str, Any]] = {}
        # Not used by any method of this class except close().
        self._http = httpx.AsyncClient(timeout=60)

    def name(self) -> str:
        """Return the provider identifier, ``"runpod"``."""
        return "runpod"

    async def submit_generation(
        self,
        *,
        positive_prompt: str,
        negative_prompt: str,
        checkpoint: str,
        lora_name: str | None = None,
        lora_strength: float = 0.85,
        seed: int = -1,
        steps: int = 28,
        cfg: float = 7.0,
        width: int = 832,
        height: int = 1216,
    ) -> str:
        """Submit a generation job to RunPod serverless.

        The parameters are forwarded to the worker as a flat input dict; the
        worker is assumed to be a ComfyUI or SD worker that understands these
        keys (TODO confirm against the deployed worker).

        Args:
            positive_prompt: Sent as ``prompt``.
            negative_prompt: Sent as ``negative_prompt``.
            checkpoint: Sent as ``checkpoint``.
            lora_name: When truthy, a ``lora`` entry with this name is added.
            lora_strength: Strength stored in the ``lora`` entry.
            seed: Sent unchanged as ``seed``.
            steps: Sent as ``steps``.
            cfg: Sent as ``cfg_scale``.
            width: Sent as ``width``.
            height: Sent as ``height``.

        Returns:
            The job ID used by check_status(), get_result() and
            wait_for_completion().

        Raises:
            RuntimeError: If the SDK call fails; the original exception is
                attached as ``__cause__``.
        """
        job_input: dict[str, Any] = {
            "prompt": positive_prompt,
            "negative_prompt": negative_prompt,
            "checkpoint": checkpoint,
            "width": width,
            "height": height,
            "steps": steps,
            "cfg_scale": cfg,
            "seed": seed,
        }
        if lora_name:
            job_input["lora"] = {
                "name": lora_name,
                "strength": lora_strength,
            }

        # Monotonic clock: only used to measure elapsed time, so it must not
        # jump when the system clock is adjusted.
        started_at = time.monotonic()
        try:
            # The SDK call is synchronous; run it off the event loop.
            run_request = await asyncio.to_thread(self._endpoint.run, job_input)
            job_id = run_request.job_id
        except Exception as e:
            logger.error("RunPod submit failed: %s", e)
            raise RuntimeError(f"Failed to submit to RunPod: {e}") from e

        self._jobs[job_id] = {
            "request": run_request,
            "start_time": started_at,
            "status": "pending",
        }
        logger.info("RunPod job submitted: %s", job_id)
        return job_id

    async def check_status(self, job_id: str) -> str:
        """Check job status. Returns: 'pending', 'running', 'completed', 'failed'.

        Unknown job IDs and errors raised while querying the SDK are both
        reported as ``'failed'``. A job whose remote state maps to
        ``'failed'`` is dropped from the registry so it does not accumulate;
        later calls keep returning ``'failed'`` through the unknown-job path.
        """
        job_info = self._jobs.get(job_id)
        if not job_info:
            return "failed"

        try:
            status = await asyncio.to_thread(job_info["request"].status)
        except Exception as e:
            # The job is kept in the registry here: the query failed, which
            # says nothing about the remote job itself.
            logger.error("Status check failed for %s: %s", job_id, e)
            return "failed"

        normalized = self._STATUS_MAP.get(status, "running")
        job_info["status"] = normalized
        if normalized == "failed":
            self._jobs.pop(job_id, None)
        return normalized

    async def get_result(self, job_id: str) -> CloudGenerationResult:
        """Download the completed generation result.

        On success the job is removed from the registry. On failure it is
        kept, so the call can be repeated.

        Raises:
            RuntimeError: If the job ID is unknown, or if fetching or decoding
                the output fails (original exception attached as
                ``__cause__``).
        """
        job_info = self._jobs.get(job_id)
        if not job_info:
            raise RuntimeError(f"Job not found: {job_id}")

        try:
            # NOTE(review): whether output() waits for the job to finish
            # depends on the SDK version - callers should poll check_status()
            # (or use wait_for_completion()) first.
            output = await asyncio.to_thread(job_info["request"].output)
            generation_time = time.monotonic() - job_info["start_time"]

            # Output format depends on the worker implementation; see
            # _extract_image_from_output for the accepted shapes.
            image_bytes = self._extract_image_from_output(output)

            self._jobs.pop(job_id, None)
            return CloudGenerationResult(
                job_id=job_id,
                image_bytes=image_bytes,
                generation_time_seconds=generation_time,
            )
        except Exception as e:
            logger.error("Failed to get result for %s: %s", job_id, e)
            raise RuntimeError(f"Failed to get RunPod result: {e}") from e

    def _extract_image_from_output(self, output: Any) -> bytes:
        """Extract image bytes from various output formats.

        Accepted shapes (each value may itself be any accepted shape):
            * a base64 string or ``data:...;base64,...`` URL
            * ``{"image_url": ...}``, ``{"image": ...}``
            * ``{"images": [...]}`` (first entry is used)
            * ``{"output": ...}``
            * a non-empty list (first entry is used)

        Raises:
            ValueError: If no image can be located in ``output``.
        """
        if isinstance(output, str):
            return self._decode_data_url(output)
        if isinstance(output, (bytes, bytearray)):
            return base64.b64decode(output)

        if isinstance(output, dict):
            # Values are routed back through this method instead of straight
            # into b64decode: a data URL handed to the non-validating decoder
            # yields corrupt bytes, and a nested dict/list would raise
            # TypeError.
            for key in ("image_url", "image"):
                if key in output:
                    return self._extract_image_from_output(output[key])
            if output.get("images"):
                return self._extract_image_from_output(output["images"])
            if "output" in output:
                return self._extract_image_from_output(output["output"])
        elif isinstance(output, list) and output:
            return self._extract_image_from_output(output[0])

        raise ValueError(f"Could not extract image from output: {type(output)}")

    def _decode_data_url(self, data_url: str) -> bytes:
        """Decode a data:image/xxx;base64,... URL to bytes.

        A string without a comma is treated as bare base64.

        Raises:
            ValueError: If ``data_url`` is an http(s) link; this class has no
                download step, and base64-decoding a link would produce
                garbage.
        """
        if data_url.startswith(("http://", "https://")):
            raise ValueError(
                "Remote image URLs are not supported; expected base64 image data"
            )
        _, comma, encoded = data_url.partition(",")
        return base64.b64decode(encoded if comma else data_url)

    async def is_available(self) -> bool:
        """Report whether RunPod credentials are configured.

        No network request is made, so ``True`` means only that an API key
        and endpoint ID are set - not that the endpoint is reachable.
        """
        if not self._api_key or not self._endpoint_id:
            return False
        # Re-apply the key in case the module-level SDK setting was changed
        # after this provider was constructed.
        runpod.api_key = self._api_key
        return True

    async def wait_for_completion(
        self,
        job_id: str,
        timeout: int = GENERATION_TIMEOUT,
        poll_interval: float = 2.0,
    ) -> CloudGenerationResult:
        """Wait for job completion and return result.

        Args:
            job_id: ID returned by submit_generation().
            timeout: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between status checks.

        Raises:
            RuntimeError: If the job is reported as failed, or fetching the
                result fails.
            TimeoutError: If the job has not completed within ``timeout``.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            status = await self.check_status(job_id)
            if status == "completed":
                return await self.get_result(job_id)
            if status == "failed":
                raise RuntimeError(f"RunPod job {job_id} failed")
            await asyncio.sleep(poll_interval)

        raise TimeoutError(f"RunPod job {job_id} timed out after {timeout}s")

    async def close(self) -> None:
        """Close HTTP client."""
        await self._http.aclose()