"""WaveSpeed.ai cloud provider — integrates NanoBanana, SeeDream and other models. WaveSpeed provides fast cloud inference for text-to-image and image editing models including Google NanoBanana and ByteDance SeeDream series. Text-to-image models: - google-nano-banana-text-to-image - google-nano-banana-pro-text-to-image - bytedance-seedream-v3 / v3.1 / v4 / v4.5 Image editing models (accept reference images): - bytedance-seedream-v4.5-edit - bytedance-seedream-v4-edit - google-nano-banana-edit - google-nano-banana-pro-edit SDK: pip install wavespeed Docs: https://wavespeed.ai/docs """ from __future__ import annotations import base64 import logging import time import uuid from typing import Any import httpx try: from wavespeed import Client as WaveSpeedClient _SDK_AVAILABLE = True except ImportError: WaveSpeedClient = None _SDK_AVAILABLE = False from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider logger = logging.getLogger(__name__) # Map friendly names to WaveSpeed model IDs (text-to-image) # Based on https://wavespeed.ai/models MODEL_MAP = { # SeeDream (ByteDance) - NSFW OK "seedream-4.5": "bytedance/seedream-v4.5", "seedream-4": "bytedance/seedream-v4", "seedream-3.1": "bytedance/seedream-v3.1", # NanoBanana (Google) "nano-banana-pro": "google/nano-banana-pro", "nano-banana": "google/nano-banana", # WAN (Alibaba) "wan-2.6": "alibaba/wan-2.6/text-to-image", "wan-2.5": "alibaba/wan-2.5/text-to-image", # Z-Image (WaveSpeed) — supports LoRA, ultra fast "z-image-turbo": "wavespeed-ai/z-image/turbo", "z-image-turbo-lora": "wavespeed-ai/z-image/turbo-lora", "z-image-base-lora": "wavespeed-ai/z-image/base-lora", # Qwen (WaveSpeed) "qwen-image": "wavespeed-ai/qwen-image/text-to-image", # GPT Image (OpenAI) "gpt-image-1.5": "openai/gpt-image-1.5/text-to-image", "gpt-image-1": "openai/gpt-image-1/text-to-image", "gpt-image-1-mini": "openai/gpt-image-1-mini/text-to-image", # Dreamina (ByteDance) "dreamina-3.1": "bytedance/dreamina-v3.1/text-to-image", "dreamina-3": "bytedance/dreamina-v3.0/text-to-image", # Kling (Kuaishou) "kling-image-o3": "kwaivgi/kling-image-o3/text-to-image", # Default "default": "bytedance/seedream-v4.5", } # Image-to-Video models # Based on https://wavespeed.ai/models VIDEO_MODEL_MAP = { # Higgsfield DoP (Cinematic Motion) "higgsfield-dop": "higgsfield/dop/image-to-video", "higgsfield-dop-lite": "higgsfield/dop/image-to-video", # Use options param "higgsfield-dop-turbo": "higgsfield/dop/image-to-video", # Use options param # WAN 2.6 I2V (Alibaba) "wan-2.6-i2v-pro": "alibaba/wan-2.6/image-to-video-pro", "wan-2.6-i2v": "alibaba/wan-2.6/image-to-video", "wan-2.6-i2v-flash": "alibaba/wan-2.6/image-to-video-flash", # WAN 2.5 I2V (Alibaba) "wan-2.5-i2v": "alibaba/wan-2.5/image-to-video", # WAN 2.2 I2V "wan-2.2-i2v-1080p": "alibaba/wan-2.2/i2v-plus-1080p", "wan-2.2-i2v-720p": "wavespeed-ai/wan-2.2/i2v-720p", # Kling (Kuaishou) "kling-o3-pro": "kwaivgi/kling-video-o3-pro/image-to-video", "kling-o3": "kwaivgi/kling-video-o3-std/image-to-video", "kling-motion": "kwaivgi/kling-v2.6-pro/motion-control", # Veo (Google) "veo-3.1": "google/veo-3.1", # Seedance (ByteDance) "seedance-1.5-pro": "bytedance/seedance-v1.5-pro/image-to-video", # Dreamina I2V (ByteDance) "dreamina-i2v-1080p": "bytedance/dreamina-v3.0/image-to-video-1080p", "dreamina-i2v-720p": "bytedance/dreamina-v3.0/image-to-video-720p", # Sora (OpenAI) "sora-2": "openai/sora-2/image-to-video", # Grok (xAI) "grok-imagine-i2v": "x-ai/grok-imagine-video/image-to-video", # Vidu "vidu-q3": "vidu/q3-turbo/image-to-video", # Default "default": "alibaba/wan-2.6/image-to-video", } # Map friendly names to WaveSpeed edit model API paths # Based on https://wavespeed.ai/models EDIT_MODEL_MAP = { # Higgsfield Soul (Character Consistency) "higgsfield-soul": "higgsfield/soul/image-to-image", # SeeDream Edit (ByteDance) - NSFW OK "seedream-4.5-edit": "bytedance/seedream-v4.5/edit", "seedream-4-edit": "bytedance/seedream-v4/edit", # SeeDream Multi-Image (Character Consistency across images) "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential", "seedream-4-multi": "bytedance/seedream-v4/edit-sequential", # WAN Edit (Alibaba) "wan-2.6-edit": "alibaba/wan-2.6/image-edit", "wan-2.5-edit": "alibaba/wan-2.5/image-edit", "wan-2.2-edit": "wavespeed-ai/wan-2.2/image-to-image", # Qwen Edit (WaveSpeed) "qwen-edit-lora": "wavespeed-ai/qwen-image/edit-plus-lora", "qwen-edit-angles": "wavespeed-ai/qwen-image/edit-multiple-angles", "qwen-layered": "wavespeed-ai/qwen-image/layered", # GPT Image Edit (OpenAI) "gpt-image-1.5-edit": "openai/gpt-image-1.5/edit", "gpt-image-1-edit": "openai/gpt-image-1/edit", "gpt-image-1-mini-edit": "openai/gpt-image-1-mini/edit", # NanoBanana Edit (Google) "nano-banana-pro-edit": "google/nano-banana-pro/edit", "nano-banana-edit": "google/nano-banana/edit", # Dreamina Edit (ByteDance) "dreamina-3-edit": "bytedance/dreamina-v3.0/edit", # Kling Edit (Kuaishou) "kling-o3-edit": "kwaivgi/kling-image-o3/edit", # Default edit model "default": "bytedance/seedream-v4.5/edit", } # Models that support multiple reference images MULTI_REF_MODELS = { # SeeDream Sequential (up to 3 images for character consistency) "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential", "seedream-4-multi": "bytedance/seedream-v4/edit-sequential", # NanoBanana Pro (Google) - multi-reference edit "nano-banana-pro-multi": "google/nano-banana-pro/edit", # Kling O1 (up to 10 reference images) "kling-o1-multi": "kwaivgi/kling-o1/image-to-image", # Qwen Multi-Angle (multiple angles of same subject) "qwen-multi-angle": "wavespeed-ai/qwen-image/edit-multiple-angles", } # Reference-to-Video models (character + pose reference) REF_TO_VIDEO_MAP = { # WAN 2.6 Reference-to-Video (multi-view identity consistency) "wan-2.6-ref": "alibaba/wan-2.6/reference-to-video", "wan-2.6-ref-flash": "alibaba/wan-2.6/reference-to-video-flash", # Kling O3 Reference-to-Video "kling-o3-ref": "kwaivgi/kling-video-o3-pro/reference-to-video", "kling-o3-std-ref": "kwaivgi/kling-video-o3-std/reference-to-video", } WAVESPEED_API_BASE = "https://api.wavespeed.ai/api/v3" class WaveSpeedProvider(CloudProvider): """Cloud provider using WaveSpeed.ai for NanoBanana and SeeDream models.""" def __init__(self, api_key: str): self._api_key = api_key self._client = WaveSpeedClient(api_key=api_key) if _SDK_AVAILABLE else None self._http_client = httpx.AsyncClient(timeout=300) @property def name(self) -> str: return "wavespeed" def _resolve_model(self, model_name: str | None) -> str: """Resolve a friendly model name to a WaveSpeed model ID.""" if model_name and model_name in MODEL_MAP: return MODEL_MAP[model_name] if model_name: return model_name return MODEL_MAP["default"] def _resolve_edit_model(self, model_name: str | None) -> str: """Resolve a friendly name to a WaveSpeed edit model API path.""" if model_name and model_name in EDIT_MODEL_MAP: return EDIT_MODEL_MAP[model_name] # Check multi-reference models if model_name and model_name in MULTI_REF_MODELS: return MULTI_REF_MODELS[model_name] if model_name: return model_name return EDIT_MODEL_MAP["default"] def _resolve_video_model(self, model_name: str | None) -> str: """Resolve a friendly name to a WaveSpeed video model API path.""" if model_name and model_name in VIDEO_MODEL_MAP: return VIDEO_MODEL_MAP[model_name] if model_name: return model_name return VIDEO_MODEL_MAP["default"] async def _poll_for_result(self, poll_url: str, max_attempts: int = 60, interval: float = 2.0) -> str: """Poll the WaveSpeed async job URL until outputs are ready. Returns the first output URL when available. """ import asyncio for attempt in range(max_attempts): try: resp = await self._http_client.get( poll_url, headers={"Authorization": f"Bearer {self._api_key}"}, ) resp.raise_for_status() result = resp.json() data = result.get("data", result) status = data.get("status", "") if status == "failed": error_msg = data.get("error", "Unknown error") raise RuntimeError(f"WaveSpeed job failed: {error_msg}") outputs = data.get("outputs", []) if outputs: logger.info("WaveSpeed job completed after %d polls", attempt + 1) return outputs[0] # Also check for 'output' field if "output" in data: out = data["output"] if isinstance(out, list) and out: return out[0] elif isinstance(out, str): return out if status == "completed" and not outputs: raise RuntimeError(f"WaveSpeed job completed but no outputs: {data}") logger.debug("WaveSpeed job pending (attempt %d/%d)", attempt + 1, max_attempts) await asyncio.sleep(interval) except httpx.HTTPStatusError as e: logger.warning("Poll request failed: %s", e) await asyncio.sleep(interval) raise RuntimeError(f"WaveSpeed job timed out after {max_attempts * interval}s") @staticmethod def _ensure_min_image_size(image_bytes: bytes, min_pixels: int = 3686400) -> bytes: """Upscale image if total pixel count is below the minimum required by the API. WaveSpeed edit APIs require images to be at least 3686400 pixels (~1920x1920). Uses Lanczos resampling for quality. """ import io from PIL import Image img = Image.open(io.BytesIO(image_bytes)) w, h = img.size current_pixels = w * h if current_pixels >= min_pixels: return image_bytes # Scale up proportionally to meet minimum scale = (min_pixels / current_pixels) ** 0.5 new_w = int(w * scale) + 1 # +1 to ensure we exceed minimum new_h = int(h * scale) + 1 logger.info("Upscaling image from %dx%d (%d px) to %dx%d (%d px) for API minimum", w, h, current_pixels, new_w, new_h, new_w * new_h) img = img.resize((new_w, new_h), Image.LANCZOS) buf = io.BytesIO() img.save(buf, format="PNG") return buf.getvalue() async def _upload_temp_image(self, image_bytes: bytes) -> str: """Upload image to a temporary public host and return the URL. Uses catbox.moe (anonymous, no account needed, 1hr expiry for temp). Falls back to base64 data URI if upload fails. """ try: # Try catbox.moe litterbox (temporary file hosting, 1h expiry) import aiohttp async with aiohttp.ClientSession() as session: data = aiohttp.FormData() data.add_field("reqtype", "fileupload") data.add_field("time", "1h") data.add_field( "fileToUpload", image_bytes, filename="ref_image.png", content_type="image/png", ) async with session.post( "https://litterbox.catbox.moe/resources/internals/api.php", data=data, ) as resp: if resp.status == 200: url = (await resp.text()).strip() if url.startswith("http"): logger.info("Uploaded temp image: %s", url) return url except Exception as e: logger.warning("Catbox upload failed: %s", e) # Fallback: try imgbb (free, no key needed for anonymous uploads) try: b64 = base64.b64encode(image_bytes).decode() resp = await self._http_client.post( "https://api.imgbb.com/1/upload", data={"image": b64, "expiration": 3600}, params={"key": ""}, # Anonymous upload ) if resp.status_code == 200: url = resp.json()["data"]["url"] logger.info("Uploaded temp image to imgbb: %s", url) return url except Exception as e: logger.warning("imgbb upload failed: %s", e) # Last resort: use 0x0.st try: import aiohttp async with aiohttp.ClientSession() as session: data = aiohttp.FormData() data.add_field( "file", image_bytes, filename="ref_image.png", content_type="image/png", ) async with session.post("https://0x0.st", data=data) as resp: if resp.status == 200: url = (await resp.text()).strip() if url.startswith("http"): logger.info("Uploaded temp image to 0x0.st: %s", url) return url except Exception as e: logger.warning("0x0.st upload failed: %s", e) raise RuntimeError( "Failed to upload reference image to a public host. " "WaveSpeed edit APIs require publicly accessible image URLs." ) async def submit_generation( self, *, positive_prompt: str, negative_prompt: str = "", checkpoint: str = "", lora_name: str | None = None, lora_strength: float = 0.85, seed: int = -1, steps: int = 28, cfg: float = 7.0, width: int = 832, height: int = 1216, model: str | None = None, ) -> str: """Submit a generation job to WaveSpeed. Returns a job ID.""" wavespeed_model = self._resolve_model(model) payload: dict[str, Any] = { "prompt": positive_prompt, "output_format": "png", } if negative_prompt: payload["negative_prompt"] = negative_prompt payload["width"] = width payload["height"] = height if seed >= 0: payload["seed"] = seed if lora_name: payload["loras"] = [{"path": lora_name, "scale": lora_strength}] logger.info("Submitting to WaveSpeed model=%s", wavespeed_model) try: output = self._client.run( wavespeed_model, payload, timeout=300.0, poll_interval=2.0, ) job_id = str(uuid.uuid4()) self._last_result = { "job_id": job_id, "output": output, "timestamp": time.time(), } return job_id except Exception as e: logger.error("WaveSpeed generation failed: %s", e) raise async def submit_edit( self, *, prompt: str, image_urls: list[str], model: str | None = None, size: str | None = None, ) -> str: """Submit an image editing job to WaveSpeed. Returns a job ID. Uses the SeeDream Edit or NanoBanana Edit APIs which accept reference images and apply prompt-guided transformations while preserving identity. """ edit_model_path = self._resolve_edit_model(model) endpoint = f"{WAVESPEED_API_BASE}/{edit_model_path}" payload: dict[str, Any] = { "prompt": prompt, "images": image_urls, "enable_sync_mode": True, "output_format": "png", } if size: payload["size"] = size logger.info("Submitting edit to WaveSpeed model=%s images=%d", edit_model_path, len(image_urls)) try: resp = await self._http_client.post( endpoint, json=payload, headers={ "Authorization": f"Bearer {self._api_key}", "Content-Type": "application/json", }, ) resp.raise_for_status() result_data = resp.json() job_id = str(uuid.uuid4()) self._last_result = { "job_id": job_id, "output": result_data, "timestamp": time.time(), } return job_id except httpx.HTTPStatusError as e: body = e.response.text logger.error("WaveSpeed edit failed (HTTP %d): %s", e.response.status_code, body[:500]) raise RuntimeError(f"WaveSpeed edit API error: {body[:200]}") from e except Exception as e: logger.error("WaveSpeed edit failed: %s", e) raise async def edit_image( self, *, prompt: str, image_bytes: bytes, image_bytes_2: bytes | None = None, model: str | None = None, size: str | None = None, ) -> CloudGenerationResult: """Full edit flow: upload image(s) to temp host, call edit API, download result. Args: prompt: The edit prompt image_bytes: Primary reference image (character/subject) image_bytes_2: Optional second reference image (pose/style reference) model: Model name (some models support multiple references) size: Output size (widthxheight) """ start = time.time() # WaveSpeed edit APIs require minimum image size (3686400 pixels = ~1920x1920) # Auto-upscale small images to meet the requirement image_bytes = self._ensure_min_image_size(image_bytes, min_pixels=3686400) # Upload reference image(s) to public URLs image_urls = [await self._upload_temp_image(image_bytes)] # Upload second reference if provided (for multi-ref models) if image_bytes_2: image_bytes_2 = self._ensure_min_image_size(image_bytes_2, min_pixels=3686400) image_urls.append(await self._upload_temp_image(image_bytes_2)) logger.info("Multi-reference edit: uploading 2 images for model=%s", model) # Submit edit job job_id = await self.submit_edit( prompt=prompt, image_urls=image_urls, model=model, size=size, ) # Get result (already cached by submit_edit with sync mode) return await self.get_result(job_id) async def check_status(self, job_id: str) -> str: """Check job status. WaveSpeed SDK polls internally, so completed jobs are immediate.""" if hasattr(self, '_last_result') and self._last_result.get("job_id") == job_id: return "completed" return "unknown" async def get_result(self, job_id: str) -> CloudGenerationResult: """Get the generation result including image bytes.""" if not hasattr(self, '_last_result') or self._last_result.get("job_id") != job_id: raise RuntimeError(f"No cached result for job {job_id}") output = self._last_result["output"] elapsed = time.time() - self._last_result["timestamp"] # Extract image URL from output — handle various response shapes image_url = None if isinstance(output, dict): # Check for failed status (API may return 200 with status:failed inside) data = output.get("data", output) logger.info("WaveSpeed response data keys: %s", list(data.keys()) if isinstance(data, dict) else type(data)) if data.get("status") == "failed": error_msg = data.get("error", "Unknown error") raise RuntimeError(f"WaveSpeed generation failed: {error_msg}") # Direct API response: {"data": {"outputs": [url, ...]}} outputs = data.get("outputs", []) # Check for async response first (outputs empty but urls.get exists) urls_data = data.get("urls", {}) if not outputs and urls_data and urls_data.get("get"): poll_url = urls_data["get"] logger.info("WaveSpeed returned async job, polling: %s", poll_url[:80]) image_url = await self._poll_for_result(poll_url) elif outputs: image_url = outputs[0] elif "output" in data: out = data["output"] if isinstance(out, list) and out: image_url = out[0] elif isinstance(out, str): image_url = out elif isinstance(output, list) and output: image_url = output[0] elif isinstance(output, str): image_url = output if not image_url: raise RuntimeError(f"No image URL in WaveSpeed output: {output}") # Download the image logger.info("Downloading from WaveSpeed: %s", image_url[:80]) response = await self._http_client.get(image_url) response.raise_for_status() return CloudGenerationResult( job_id=job_id, image_bytes=response.content, generation_time_seconds=elapsed, ) async def generate( self, *, positive_prompt: str, negative_prompt: str = "", model: str | None = None, width: int = 1024, height: int = 1024, seed: int = -1, lora_name: str | None = None, lora_strength: float = 0.85, ) -> CloudGenerationResult: """Convenience method: submit + get result in one call.""" job_id = await self.submit_generation( positive_prompt=positive_prompt, negative_prompt=negative_prompt, model=model, width=width, height=height, seed=seed, lora_name=lora_name, lora_strength=lora_strength, ) return await self.get_result(job_id) async def is_available(self) -> bool: """Check if WaveSpeed API is reachable with valid credentials.""" try: test = self._client.run( "wavespeed-ai/z-image/turbo", {"prompt": "test"}, enable_sync_mode=True, timeout=10.0, ) return True except Exception: try: resp = await self._http_client.get( "https://api.wavespeed.ai/api/v3/health", headers={"Authorization": f"Bearer {self._api_key}"}, ) return resp.status_code < 500 except Exception: return False