# NOTE: Hugging Face file-viewer chrome removed from this scrape (author
# avatar, "raw / history / blame" links, file size 23.4 kB). Provenance:
# commit b02f80a — "Add Kling Motion Control: character image + driving
# video via WaveSpeed".
"""WaveSpeed.ai cloud provider — integrates NanoBanana, SeeDream and other models.
WaveSpeed provides fast cloud inference for text-to-image and image editing models
including Google NanoBanana and ByteDance SeeDream series.
Text-to-image models:
- google-nano-banana-text-to-image
- google-nano-banana-pro-text-to-image
- bytedance-seedream-v3 / v3.1 / v4 / v4.5
Image editing models (accept reference images):
- bytedance-seedream-v4.5-edit
- bytedance-seedream-v4-edit
- google-nano-banana-edit
- google-nano-banana-pro-edit
SDK: pip install wavespeed
Docs: https://wavespeed.ai/docs
"""
from __future__ import annotations
import base64
import logging
import time
import uuid
from typing import Any
import httpx
try:
from wavespeed import Client as WaveSpeedClient
_SDK_AVAILABLE = True
except ImportError:
WaveSpeedClient = None
_SDK_AVAILABLE = False
from content_engine.services.cloud_providers.base import CloudGenerationResult, CloudProvider
logger = logging.getLogger(__name__)
# Map friendly names to WaveSpeed model IDs (text-to-image)
# Based on https://wavespeed.ai/models
# Unknown names are passed through unchanged by _resolve_model, so callers
# may also supply a full WaveSpeed model path directly.
MODEL_MAP: dict[str, str] = {
    # SeeDream (ByteDance) - NSFW OK
    "seedream-4.5": "bytedance/seedream-v4.5",
    "seedream-4": "bytedance/seedream-v4",
    "seedream-3.1": "bytedance/seedream-v3.1",
    # NanoBanana (Google)
    "nano-banana-pro": "google/nano-banana-pro",
    "nano-banana": "google/nano-banana",
    # WAN (Alibaba)
    "wan-2.6": "alibaba/wan-2.6/text-to-image",
    "wan-2.5": "alibaba/wan-2.5/text-to-image",
    # Z-Image (WaveSpeed) — supports LoRA, ultra fast
    "z-image-turbo": "wavespeed-ai/z-image/turbo",
    "z-image-turbo-lora": "wavespeed-ai/z-image/turbo-lora",
    "z-image-base-lora": "wavespeed-ai/z-image/base-lora",
    # Qwen (WaveSpeed)
    "qwen-image": "wavespeed-ai/qwen-image/text-to-image",
    # GPT Image (OpenAI)
    "gpt-image-1.5": "openai/gpt-image-1.5/text-to-image",
    "gpt-image-1": "openai/gpt-image-1/text-to-image",
    "gpt-image-1-mini": "openai/gpt-image-1-mini/text-to-image",
    # Dreamina (ByteDance)
    "dreamina-3.1": "bytedance/dreamina-v3.1/text-to-image",
    "dreamina-3": "bytedance/dreamina-v3.0/text-to-image",
    # Kling (Kuaishou)
    "kling-image-o3": "kwaivgi/kling-image-o3/text-to-image",
    # Default
    "default": "bytedance/seedream-v4.5",
}
# Image-to-Video models
# Based on https://wavespeed.ai/models
# Friendly name -> WaveSpeed API path; resolved by _resolve_video_model.
VIDEO_MODEL_MAP: dict[str, str] = {
    # Higgsfield DoP (Cinematic Motion)
    "higgsfield-dop": "higgsfield/dop/image-to-video",
    # NOTE: lite/turbo share the base DoP path; the variant is selected via
    # the request's options parameter, not the model path.
    "higgsfield-dop-lite": "higgsfield/dop/image-to-video",  # Use options param
    "higgsfield-dop-turbo": "higgsfield/dop/image-to-video",  # Use options param
    # WAN 2.6 I2V (Alibaba)
    "wan-2.6-i2v-pro": "alibaba/wan-2.6/image-to-video-pro",
    "wan-2.6-i2v": "alibaba/wan-2.6/image-to-video",
    "wan-2.6-i2v-flash": "alibaba/wan-2.6/image-to-video-flash",
    # WAN 2.5 I2V (Alibaba)
    "wan-2.5-i2v": "alibaba/wan-2.5/image-to-video",
    # WAN 2.2 I2V
    "wan-2.2-i2v-1080p": "alibaba/wan-2.2/i2v-plus-1080p",
    "wan-2.2-i2v-720p": "wavespeed-ai/wan-2.2/i2v-720p",
    # Kling (Kuaishou)
    "kling-o3-pro": "kwaivgi/kling-video-o3-pro/image-to-video",
    "kling-o3": "kwaivgi/kling-video-o3-std/image-to-video",
    # Kling Motion Control: character image + driving video
    "kling-motion": "kwaivgi/kling-v2.6-pro/motion-control",
    # Veo (Google)
    "veo-3.1": "google/veo-3.1",
    # Seedance (ByteDance)
    "seedance-1.5-pro": "bytedance/seedance-v1.5-pro/image-to-video",
    # Dreamina I2V (ByteDance)
    "dreamina-i2v-1080p": "bytedance/dreamina-v3.0/image-to-video-1080p",
    "dreamina-i2v-720p": "bytedance/dreamina-v3.0/image-to-video-720p",
    # Sora (OpenAI)
    "sora-2": "openai/sora-2/image-to-video",
    # Grok (xAI)
    "grok-imagine-i2v": "x-ai/grok-imagine-video/image-to-video",
    # Vidu
    "vidu-q3": "vidu/q3-turbo/image-to-video",
    # Default
    "default": "alibaba/wan-2.6/image-to-video",
}
# Map friendly names to WaveSpeed edit model API paths
# Based on https://wavespeed.ai/models
# Friendly name -> WaveSpeed API path; resolved by _resolve_edit_model.
EDIT_MODEL_MAP: dict[str, str] = {
    # Higgsfield Soul (Character Consistency)
    "higgsfield-soul": "higgsfield/soul/image-to-image",
    # SeeDream Edit (ByteDance) - NSFW OK
    "seedream-4.5-edit": "bytedance/seedream-v4.5/edit",
    "seedream-4-edit": "bytedance/seedream-v4/edit",
    # SeeDream Multi-Image (Character Consistency across images)
    "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential",
    "seedream-4-multi": "bytedance/seedream-v4/edit-sequential",
    # WAN Edit (Alibaba)
    "wan-2.6-edit": "alibaba/wan-2.6/image-edit",
    "wan-2.5-edit": "alibaba/wan-2.5/image-edit",
    "wan-2.2-edit": "wavespeed-ai/wan-2.2/image-to-image",
    # Qwen Edit (WaveSpeed)
    "qwen-edit-lora": "wavespeed-ai/qwen-image/edit-plus-lora",
    "qwen-edit-angles": "wavespeed-ai/qwen-image/edit-multiple-angles",
    "qwen-layered": "wavespeed-ai/qwen-image/layered",
    # GPT Image Edit (OpenAI)
    "gpt-image-1.5-edit": "openai/gpt-image-1.5/edit",
    "gpt-image-1-edit": "openai/gpt-image-1/edit",
    "gpt-image-1-mini-edit": "openai/gpt-image-1-mini/edit",
    # NanoBanana Edit (Google)
    "nano-banana-pro-edit": "google/nano-banana-pro/edit",
    "nano-banana-edit": "google/nano-banana/edit",
    # Dreamina Edit (ByteDance)
    "dreamina-3-edit": "bytedance/dreamina-v3.0/edit",
    # Kling Edit (Kuaishou)
    "kling-o3-edit": "kwaivgi/kling-image-o3/edit",
    # Default edit model
    "default": "bytedance/seedream-v4.5/edit",
}
# Models that support multiple reference images
# Checked by _resolve_edit_model after EDIT_MODEL_MAP; some keys here
# (e.g. the seedream "*-multi" names) intentionally also appear there.
MULTI_REF_MODELS: dict[str, str] = {
    # SeeDream Sequential (up to 3 images for character consistency)
    "seedream-4.5-multi": "bytedance/seedream-v4.5/edit-sequential",
    "seedream-4-multi": "bytedance/seedream-v4/edit-sequential",
    # NanoBanana Pro (Google) - multi-reference edit
    "nano-banana-pro-multi": "google/nano-banana-pro/edit",
    # Kling O1 (up to 10 reference images)
    "kling-o1-multi": "kwaivgi/kling-o1/image-to-image",
    # Qwen Multi-Angle (multiple angles of same subject)
    "qwen-multi-angle": "wavespeed-ai/qwen-image/edit-multiple-angles",
}
# Reference-to-Video models (character + pose reference)
# NOTE(review): no _resolve_* helper visible in this chunk consults this
# table — presumably used elsewhere in the file; verify before removing.
REF_TO_VIDEO_MAP: dict[str, str] = {
    # WAN 2.6 Reference-to-Video (multi-view identity consistency)
    "wan-2.6-ref": "alibaba/wan-2.6/reference-to-video",
    "wan-2.6-ref-flash": "alibaba/wan-2.6/reference-to-video-flash",
    # Kling O3 Reference-to-Video
    "kling-o3-ref": "kwaivgi/kling-video-o3-pro/reference-to-video",
    "kling-o3-std-ref": "kwaivgi/kling-video-o3-std/reference-to-video",
}
# Base URL for direct REST calls (edit endpoints, polling, health check).
WAVESPEED_API_BASE: str = "https://api.wavespeed.ai/api/v3"
class WaveSpeedProvider(CloudProvider):
"""Cloud provider using WaveSpeed.ai for NanoBanana and SeeDream models."""
def __init__(self, api_key: str):
self._api_key = api_key
self._client = WaveSpeedClient(api_key=api_key) if _SDK_AVAILABLE else None
self._http_client = httpx.AsyncClient(timeout=300)
    @property
    def name(self) -> str:
        """Stable provider identifier used by the cloud-provider registry."""
        return "wavespeed"
def _resolve_model(self, model_name: str | None) -> str:
"""Resolve a friendly model name to a WaveSpeed model ID."""
if model_name and model_name in MODEL_MAP:
return MODEL_MAP[model_name]
if model_name:
return model_name
return MODEL_MAP["default"]
def _resolve_edit_model(self, model_name: str | None) -> str:
"""Resolve a friendly name to a WaveSpeed edit model API path."""
if model_name and model_name in EDIT_MODEL_MAP:
return EDIT_MODEL_MAP[model_name]
# Check multi-reference models
if model_name and model_name in MULTI_REF_MODELS:
return MULTI_REF_MODELS[model_name]
if model_name:
return model_name
return EDIT_MODEL_MAP["default"]
def _resolve_video_model(self, model_name: str | None) -> str:
"""Resolve a friendly name to a WaveSpeed video model API path."""
if model_name and model_name in VIDEO_MODEL_MAP:
return VIDEO_MODEL_MAP[model_name]
if model_name:
return model_name
return VIDEO_MODEL_MAP["default"]
async def _poll_for_result(self, poll_url: str, max_attempts: int = 60, interval: float = 2.0) -> str:
"""Poll the WaveSpeed async job URL until outputs are ready.
Returns the first output URL when available.
"""
import asyncio
for attempt in range(max_attempts):
try:
resp = await self._http_client.get(
poll_url,
headers={"Authorization": f"Bearer {self._api_key}"},
)
resp.raise_for_status()
result = resp.json()
data = result.get("data", result)
status = data.get("status", "")
if status == "failed":
error_msg = data.get("error", "Unknown error")
raise RuntimeError(f"WaveSpeed job failed: {error_msg}")
outputs = data.get("outputs", [])
if outputs:
logger.info("WaveSpeed job completed after %d polls", attempt + 1)
return outputs[0]
# Also check for 'output' field
if "output" in data:
out = data["output"]
if isinstance(out, list) and out:
return out[0]
elif isinstance(out, str):
return out
if status == "completed" and not outputs:
raise RuntimeError(f"WaveSpeed job completed but no outputs: {data}")
logger.debug("WaveSpeed job pending (attempt %d/%d)", attempt + 1, max_attempts)
await asyncio.sleep(interval)
except httpx.HTTPStatusError as e:
logger.warning("Poll request failed: %s", e)
await asyncio.sleep(interval)
raise RuntimeError(f"WaveSpeed job timed out after {max_attempts * interval}s")
@staticmethod
def _ensure_min_image_size(image_bytes: bytes, min_pixels: int = 3686400) -> bytes:
"""Upscale image if total pixel count is below the minimum required by the API.
WaveSpeed edit APIs require images to be at least 3686400 pixels (~1920x1920).
Uses Lanczos resampling for quality.
"""
import io
from PIL import Image
img = Image.open(io.BytesIO(image_bytes))
w, h = img.size
current_pixels = w * h
if current_pixels >= min_pixels:
return image_bytes
# Scale up proportionally to meet minimum
scale = (min_pixels / current_pixels) ** 0.5
new_w = int(w * scale) + 1 # +1 to ensure we exceed minimum
new_h = int(h * scale) + 1
logger.info("Upscaling image from %dx%d (%d px) to %dx%d (%d px) for API minimum",
w, h, current_pixels, new_w, new_h, new_w * new_h)
img = img.resize((new_w, new_h), Image.LANCZOS)
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
async def _upload_temp_image(self, image_bytes: bytes) -> str:
"""Upload image to a temporary public host and return the URL.
Uses catbox.moe (anonymous, no account needed, 1hr expiry for temp).
Falls back to base64 data URI if upload fails.
"""
try:
# Try catbox.moe litterbox (temporary file hosting, 1h expiry)
import aiohttp
async with aiohttp.ClientSession() as session:
data = aiohttp.FormData()
data.add_field("reqtype", "fileupload")
data.add_field("time", "1h")
data.add_field(
"fileToUpload",
image_bytes,
filename="ref_image.png",
content_type="image/png",
)
async with session.post(
"https://litterbox.catbox.moe/resources/internals/api.php",
data=data,
) as resp:
if resp.status == 200:
url = (await resp.text()).strip()
if url.startswith("http"):
logger.info("Uploaded temp image: %s", url)
return url
except Exception as e:
logger.warning("Catbox upload failed: %s", e)
# Fallback: try imgbb (free, no key needed for anonymous uploads)
try:
b64 = base64.b64encode(image_bytes).decode()
resp = await self._http_client.post(
"https://api.imgbb.com/1/upload",
data={"image": b64, "expiration": 3600},
params={"key": ""}, # Anonymous upload
)
if resp.status_code == 200:
url = resp.json()["data"]["url"]
logger.info("Uploaded temp image to imgbb: %s", url)
return url
except Exception as e:
logger.warning("imgbb upload failed: %s", e)
# Last resort: use 0x0.st
try:
import aiohttp
async with aiohttp.ClientSession() as session:
data = aiohttp.FormData()
data.add_field(
"file",
image_bytes,
filename="ref_image.png",
content_type="image/png",
)
async with session.post("https://0x0.st", data=data) as resp:
if resp.status == 200:
url = (await resp.text()).strip()
if url.startswith("http"):
logger.info("Uploaded temp image to 0x0.st: %s", url)
return url
except Exception as e:
logger.warning("0x0.st upload failed: %s", e)
raise RuntimeError(
"Failed to upload reference image to a public host. "
"WaveSpeed edit APIs require publicly accessible image URLs."
)
async def submit_generation(
self,
*,
positive_prompt: str,
negative_prompt: str = "",
checkpoint: str = "",
lora_name: str | None = None,
lora_strength: float = 0.85,
seed: int = -1,
steps: int = 28,
cfg: float = 7.0,
width: int = 832,
height: int = 1216,
model: str | None = None,
) -> str:
"""Submit a generation job to WaveSpeed. Returns a job ID."""
wavespeed_model = self._resolve_model(model)
payload: dict[str, Any] = {
"prompt": positive_prompt,
"output_format": "png",
}
if negative_prompt:
payload["negative_prompt"] = negative_prompt
payload["width"] = width
payload["height"] = height
if seed >= 0:
payload["seed"] = seed
if lora_name:
payload["loras"] = [{"path": lora_name, "scale": lora_strength}]
logger.info("Submitting to WaveSpeed model=%s", wavespeed_model)
try:
output = self._client.run(
wavespeed_model,
payload,
timeout=300.0,
poll_interval=2.0,
)
job_id = str(uuid.uuid4())
self._last_result = {
"job_id": job_id,
"output": output,
"timestamp": time.time(),
}
return job_id
except Exception as e:
logger.error("WaveSpeed generation failed: %s", e)
raise
async def submit_edit(
self,
*,
prompt: str,
image_urls: list[str],
model: str | None = None,
size: str | None = None,
) -> str:
"""Submit an image editing job to WaveSpeed. Returns a job ID.
Uses the SeeDream Edit or NanoBanana Edit APIs which accept reference
images and apply prompt-guided transformations while preserving identity.
"""
edit_model_path = self._resolve_edit_model(model)
endpoint = f"{WAVESPEED_API_BASE}/{edit_model_path}"
payload: dict[str, Any] = {
"prompt": prompt,
"images": image_urls,
"enable_sync_mode": True,
"output_format": "png",
}
if size:
payload["size"] = size
logger.info("Submitting edit to WaveSpeed model=%s images=%d", edit_model_path, len(image_urls))
try:
resp = await self._http_client.post(
endpoint,
json=payload,
headers={
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json",
},
)
resp.raise_for_status()
result_data = resp.json()
job_id = str(uuid.uuid4())
self._last_result = {
"job_id": job_id,
"output": result_data,
"timestamp": time.time(),
}
return job_id
except httpx.HTTPStatusError as e:
body = e.response.text
logger.error("WaveSpeed edit failed (HTTP %d): %s", e.response.status_code, body[:500])
raise RuntimeError(f"WaveSpeed edit API error: {body[:200]}") from e
except Exception as e:
logger.error("WaveSpeed edit failed: %s", e)
raise
async def edit_image(
self,
*,
prompt: str,
image_bytes: bytes,
image_bytes_2: bytes | None = None,
model: str | None = None,
size: str | None = None,
) -> CloudGenerationResult:
"""Full edit flow: upload image(s) to temp host, call edit API, download result.
Args:
prompt: The edit prompt
image_bytes: Primary reference image (character/subject)
image_bytes_2: Optional second reference image (pose/style reference)
model: Model name (some models support multiple references)
size: Output size (widthxheight)
"""
start = time.time()
# WaveSpeed edit APIs require minimum image size (3686400 pixels = ~1920x1920)
# Auto-upscale small images to meet the requirement
image_bytes = self._ensure_min_image_size(image_bytes, min_pixels=3686400)
# Upload reference image(s) to public URLs
image_urls = [await self._upload_temp_image(image_bytes)]
# Upload second reference if provided (for multi-ref models)
if image_bytes_2:
image_bytes_2 = self._ensure_min_image_size(image_bytes_2, min_pixels=3686400)
image_urls.append(await self._upload_temp_image(image_bytes_2))
logger.info("Multi-reference edit: uploading 2 images for model=%s", model)
# Submit edit job
job_id = await self.submit_edit(
prompt=prompt,
image_urls=image_urls,
model=model,
size=size,
)
# Get result (already cached by submit_edit with sync mode)
return await self.get_result(job_id)
async def check_status(self, job_id: str) -> str:
"""Check job status. WaveSpeed SDK polls internally, so completed jobs are immediate."""
if hasattr(self, '_last_result') and self._last_result.get("job_id") == job_id:
return "completed"
return "unknown"
async def get_result(self, job_id: str) -> CloudGenerationResult:
"""Get the generation result including image bytes."""
if not hasattr(self, '_last_result') or self._last_result.get("job_id") != job_id:
raise RuntimeError(f"No cached result for job {job_id}")
output = self._last_result["output"]
elapsed = time.time() - self._last_result["timestamp"]
# Extract image URL from output — handle various response shapes
image_url = None
if isinstance(output, dict):
# Check for failed status (API may return 200 with status:failed inside)
data = output.get("data", output)
logger.info("WaveSpeed response data keys: %s", list(data.keys()) if isinstance(data, dict) else type(data))
if data.get("status") == "failed":
error_msg = data.get("error", "Unknown error")
raise RuntimeError(f"WaveSpeed generation failed: {error_msg}")
# Direct API response: {"data": {"outputs": [url, ...]}}
outputs = data.get("outputs", [])
# Check for async response first (outputs empty but urls.get exists)
urls_data = data.get("urls", {})
if not outputs and urls_data and urls_data.get("get"):
poll_url = urls_data["get"]
logger.info("WaveSpeed returned async job, polling: %s", poll_url[:80])
image_url = await self._poll_for_result(poll_url)
elif outputs:
image_url = outputs[0]
elif "output" in data:
out = data["output"]
if isinstance(out, list) and out:
image_url = out[0]
elif isinstance(out, str):
image_url = out
elif isinstance(output, list) and output:
image_url = output[0]
elif isinstance(output, str):
image_url = output
if not image_url:
raise RuntimeError(f"No image URL in WaveSpeed output: {output}")
# Download the image
logger.info("Downloading from WaveSpeed: %s", image_url[:80])
response = await self._http_client.get(image_url)
response.raise_for_status()
return CloudGenerationResult(
job_id=job_id,
image_bytes=response.content,
generation_time_seconds=elapsed,
)
async def generate(
self,
*,
positive_prompt: str,
negative_prompt: str = "",
model: str | None = None,
width: int = 1024,
height: int = 1024,
seed: int = -1,
lora_name: str | None = None,
lora_strength: float = 0.85,
) -> CloudGenerationResult:
"""Convenience method: submit + get result in one call."""
job_id = await self.submit_generation(
positive_prompt=positive_prompt,
negative_prompt=negative_prompt,
model=model,
width=width,
height=height,
seed=seed,
lora_name=lora_name,
lora_strength=lora_strength,
)
return await self.get_result(job_id)
async def is_available(self) -> bool:
"""Check if WaveSpeed API is reachable with valid credentials."""
try:
test = self._client.run(
"wavespeed-ai/z-image/turbo",
{"prompt": "test"},
enable_sync_mode=True,
timeout=10.0,
)
return True
except Exception:
try:
resp = await self._http_client.get(
"https://api.wavespeed.ai/api/v3/health",
headers={"Authorization": f"Bearer {self._api_key}"},
)
return resp.status_code < 500
except Exception:
return False